From c313aa89f07c9918c87f6938795020368baca42d Mon Sep 17 00:00:00 2001 From: JimMoen Date: Thu, 27 Jun 2024 16:39:11 +0800 Subject: [PATCH] fix: try throw proto_ver and proto_name when parsing CONNECT packet --- apps/emqx/src/emqx_channel.erl | 16 ++++--- apps/emqx/src/emqx_frame.erl | 78 ++++++++++++++++++++++++---------- 2 files changed, 64 insertions(+), 30 deletions(-) diff --git a/apps/emqx/src/emqx_channel.erl b/apps/emqx/src/emqx_channel.erl index f7c76cade..f7b210a22 100644 --- a/apps/emqx/src/emqx_channel.erl +++ b/apps/emqx/src/emqx_channel.erl @@ -145,7 +145,9 @@ -type replies() :: emqx_types:packet() | reply() | [reply()]. -define(IS_MQTT_V5, #channel{conninfo = #{proto_ver := ?MQTT_PROTO_V5}}). - +-define(IS_CONNECTED_OR_REAUTHENTICATING(ConnState), + ((ConnState == connected) orelse (ConnState == reauthenticating)) +). -define(IS_COMMON_SESSION_TIMER(N), ((N == retry_delivery) orelse (N == expire_awaiting_rel)) ). @@ -333,7 +335,7 @@ take_conn_info_fields(Fields, ClientInfo, ConnInfo) -> | {shutdown, Reason :: term(), channel()} | {shutdown, Reason :: term(), replies(), channel()}. handle_in(?CONNECT_PACKET(), Channel = #channel{conn_state = ConnState}) when - ConnState =:= connected orelse ConnState =:= reauthenticating + ?IS_CONNECTED_OR_REAUTHENTICATING(ConnState) -> handle_out(disconnect, ?RC_PROTOCOL_ERROR, Channel); handle_in(?CONNECT_PACKET(), Channel = #channel{conn_state = connecting}) -> @@ -1016,11 +1018,11 @@ handle_frame_error(Reason, Channel = #channel{conn_state = connecting}) -> handle_frame_error( #{cause := frame_too_large}, Channel = #channel{conn_state = ConnState} ) when - ConnState =:= connected orelse ConnState =:= reauthenticating + ?IS_CONNECTED_OR_REAUTHENTICATING(ConnState) -> handle_out(disconnect, {?RC_PACKET_TOO_LARGE, frame_too_large}, Channel); handle_frame_error(Reason, Channel = #channel{conn_state = ConnState}) when - ConnState =:= connected orelse ConnState =:= reauthenticating + ?IS_CONNECTED_OR_REAUTHENTICATING(ConnState) -> handle_out(disconnect, {?RC_MALFORMED_PACKET, Reason}, Channel); handle_frame_error(Reason, Channel = #channel{conn_state = disconnected}) -> @@ -1295,7 +1297,7 @@ handle_info( session = Session } ) when - ConnState =:= connected orelse ConnState =:= reauthenticating + ?IS_CONNECTED_OR_REAUTHENTICATING(ConnState) -> {Intent, Session1} = session_disconnect(ClientInfo, ConnInfo, Session), Channel1 = ensure_disconnected(Reason, maybe_publish_will_msg(sock_closed, Channel)), @@ -2675,13 +2677,13 @@ disconnect_and_shutdown( ?IS_MQTT_V5 = #channel{conn_state = ConnState} ) when - ConnState =:= connected orelse ConnState =:= reauthenticating + ?IS_CONNECTED_OR_REAUTHENTICATING(ConnState) -> NChannel = ensure_disconnected(Reason, Channel), shutdown(Reason, Reply, ?DISCONNECT_PACKET(reason_code(Reason)), NChannel); %% mqtt v3/v4 connected sessions disconnect_and_shutdown(Reason, Reply, Channel = #channel{conn_state = ConnState}) when - ConnState =:= connected orelse ConnState =:= reauthenticating + ?IS_CONNECTED_OR_REAUTHENTICATING(ConnState) -> NChannel = ensure_disconnected(Reason, Channel), shutdown(Reason, Reply, NChannel); diff --git a/apps/emqx/src/emqx_frame.erl b/apps/emqx/src/emqx_frame.erl index f83a739ad..398a5f35c 100644 --- a/apps/emqx/src/emqx_frame.erl +++ b/apps/emqx/src/emqx_frame.erl @@ -267,27 +267,36 @@ packet(Header, Variable, Payload) -> #mqtt_packet{header = Header, variable = Variable, payload = Payload}. parse_connect(FrameBin, StrictMode) -> - {ProtoName, Rest} = parse_utf8_string_with_cause(FrameBin, StrictMode, invalid_proto_name), - case ProtoName of - <<"MQTT">> -> - ok; - <<"MQIsdp">> -> - ok; - _ -> - %% from spec: the server MAY send disconnect with reason code 0x84 - %% we chose to close socket because the client is likely not talking MQTT anyway - ?PARSE_ERR(#{ - cause => invalid_proto_name, - expected => <<"'MQTT' or 'MQIsdp'">>, - received => ProtoName - }) - end, - parse_connect2(ProtoName, Rest, StrictMode). + {ProtoName, Rest0} = parse_utf8_string_with_cause(FrameBin, StrictMode, invalid_proto_name), + %% No need to parse and check proto_ver if proto_name is invalid, check it first + %% And the matching check of `proto_name` and `proto_ver` fields will be done in `emqx_packet:check_proto_ver/2` + _ = validate_proto_name(ProtoName), + {IsBridge, ProtoVer, Rest2} = parse_connect_proto_ver(Rest0), + Meta = #{proto_name => ProtoName, proto_ver => ProtoVer}, + try + do_parse_connect(ProtoName, IsBridge, ProtoVer, Rest2, StrictMode) + catch + throw:{?FRAME_PARSE_ERROR, ReasonM} when is_map(ReasonM) -> + ?PARSE_ERR(maps:merge(ReasonM, Meta)); + throw:{?FRAME_PARSE_ERROR, Reason} -> + ?PARSE_ERR(Meta#{cause => Reason}) + end. -parse_connect2( +do_parse_connect( ProtoName, - <>, + IsBridge, + ProtoVer, + << + UsernameFlagB:1, + PasswordFlagB:1, + WillRetainB:1, + WillQoS:2, + WillFlagB:1, + CleanStart:1, + Reserved:1, + KeepAlive:16/big, + Rest/binary + >>, StrictMode ) -> _ = validate_connect_reserved(Reserved), @@ -302,14 +311,14 @@ parse_connect2( UsernameFlag = bool(UsernameFlagB), PasswordFlag = bool(PasswordFlagB) ), - {Properties, Rest3} = parse_properties(Rest2, ProtoVer, StrictMode), + {Properties, Rest3} = parse_properties(Rest, ProtoVer, StrictMode), {ClientId, Rest4} = parse_utf8_string_with_cause(Rest3, StrictMode, invalid_clientid), ConnPacket = #mqtt_packet_connect{ proto_name = ProtoName, proto_ver = ProtoVer, %% For bridge mode, non-standard implementation %% Invented by mosquitto, named 'try_private': https://mosquitto.org/man/mosquitto-conf-5.html - is_bridge = (BridgeTag =:= 8), + is_bridge = IsBridge, clean_start = bool(CleanStart), will_flag = WillFlag, will_qos = WillQoS, @@ -342,8 +351,8 @@ parse_connect2( unexpected_trailing_bytes => size(Rest7) }) end; -parse_connect2(_ProtoName, Bin, _StrictMode) -> - %% sent less than 32 bytes +do_parse_connect(_ProtoName, _IsBridge, _ProtoVer, Bin, _StrictMode) -> + %% sent less than 24 bytes ?PARSE_ERR(#{cause => malformed_connect, header_bytes => Bin}). parse_packet( @@ -515,6 +524,12 @@ parse_packet_id(<>) -> parse_packet_id(_) -> ?PARSE_ERR(invalid_packet_id). +parse_connect_proto_ver(<>) -> + {_IsBridge = (BridgeTag =:= 8), ProtoVer, Rest}; +parse_connect_proto_ver(Bin) -> + %% sent less than 1 bytes or empty + ?PARSE_ERR(#{cause => malformed_connect, header_bytes => Bin}). + parse_properties(Bin, Ver, _StrictMode) when Ver =/= ?MQTT_PROTO_V5 -> {#{}, Bin}; %% TODO: version mess? @@ -1129,10 +1144,25 @@ validate_subqos([3 | _]) -> ?PARSE_ERR(bad_subqos); validate_subqos([_ | T]) -> validate_subqos(T); validate_subqos([]) -> ok. +%% from spec: the server MAY send disconnect with reason code 0x84 +%% we chose to close socket because the client is likely not talking MQTT anyway +validate_proto_name(<<"MQTT">>) -> + ok; +validate_proto_name(<<"MQIsdp">>) -> + ok; +validate_proto_name(ProtoName) -> + ?PARSE_ERR(#{ + cause => invalid_proto_name, + expected => <<"'MQTT' or 'MQIsdp'">>, + received => ProtoName + }). + %% MQTT-v3.1.1-[MQTT-3.1.2-3], MQTT-v5.0-[MQTT-3.1.2-3] +-compile({inline, [validate_connect_reserved/1]}). validate_connect_reserved(0) -> ok; validate_connect_reserved(1) -> ?PARSE_ERR(reserved_connect_flag). +-compile({inline, [validate_connect_will/3]}). %% MQTT-v3.1.1-[MQTT-3.1.2-13], MQTT-v5.0-[MQTT-3.1.2-11] validate_connect_will(false, _, WillQoS) when WillQoS > 0 -> ?PARSE_ERR(invalid_will_qos); %% MQTT-v3.1.1-[MQTT-3.1.2-14], MQTT-v5.0-[MQTT-3.1.2-12] @@ -1141,6 +1171,7 @@ validate_connect_will(true, _, WillQoS) when WillQoS > 2 -> ?PARSE_ERR(invalid_w validate_connect_will(false, WillRetain, _) when WillRetain -> ?PARSE_ERR(invalid_will_retain); validate_connect_will(_, _, _) -> ok. +-compile({inline, [validate_connect_password_flag/4]}). %% MQTT-v3.1 %% Username flag and password flag are not strongly related %% https://public.dhe.ibm.com/software/dw/webservices/ws-mqtt/mqtt-v3r1.html#connect @@ -1155,6 +1186,7 @@ validate_connect_password_flag(true, ?MQTT_PROTO_V5, _, _) -> validate_connect_password_flag(_, _, _, _) -> ok. +-compile({inline, [bool/1]}). bool(0) -> false; bool(1) -> true.