fix: try throw proto_ver and proto_name when parsing CONNECT packet

This commit is contained in:
JimMoen 2024-06-27 16:39:11 +08:00
parent 6db1c0a446
commit c313aa89f0
No known key found for this signature in database
2 changed files with 64 additions and 30 deletions

View File

@ -145,7 +145,9 @@
-type replies() :: emqx_types:packet() | reply() | [reply()]. -type replies() :: emqx_types:packet() | reply() | [reply()].
-define(IS_MQTT_V5, #channel{conninfo = #{proto_ver := ?MQTT_PROTO_V5}}). -define(IS_MQTT_V5, #channel{conninfo = #{proto_ver := ?MQTT_PROTO_V5}}).
-define(IS_CONNECTED_OR_REAUTHENTICATING(ConnState),
((ConnState == connected) orelse (ConnState == reauthenticating))
).
-define(IS_COMMON_SESSION_TIMER(N), -define(IS_COMMON_SESSION_TIMER(N),
((N == retry_delivery) orelse (N == expire_awaiting_rel)) ((N == retry_delivery) orelse (N == expire_awaiting_rel))
). ).
@ -333,7 +335,7 @@ take_conn_info_fields(Fields, ClientInfo, ConnInfo) ->
| {shutdown, Reason :: term(), channel()} | {shutdown, Reason :: term(), channel()}
| {shutdown, Reason :: term(), replies(), channel()}. | {shutdown, Reason :: term(), replies(), channel()}.
handle_in(?CONNECT_PACKET(), Channel = #channel{conn_state = ConnState}) when handle_in(?CONNECT_PACKET(), Channel = #channel{conn_state = ConnState}) when
ConnState =:= connected orelse ConnState =:= reauthenticating ?IS_CONNECTED_OR_REAUTHENTICATING(ConnState)
-> ->
handle_out(disconnect, ?RC_PROTOCOL_ERROR, Channel); handle_out(disconnect, ?RC_PROTOCOL_ERROR, Channel);
handle_in(?CONNECT_PACKET(), Channel = #channel{conn_state = connecting}) -> handle_in(?CONNECT_PACKET(), Channel = #channel{conn_state = connecting}) ->
@ -1016,11 +1018,11 @@ handle_frame_error(Reason, Channel = #channel{conn_state = connecting}) ->
handle_frame_error( handle_frame_error(
#{cause := frame_too_large}, Channel = #channel{conn_state = ConnState} #{cause := frame_too_large}, Channel = #channel{conn_state = ConnState}
) when ) when
ConnState =:= connected orelse ConnState =:= reauthenticating ?IS_CONNECTED_OR_REAUTHENTICATING(ConnState)
-> ->
handle_out(disconnect, {?RC_PACKET_TOO_LARGE, frame_too_large}, Channel); handle_out(disconnect, {?RC_PACKET_TOO_LARGE, frame_too_large}, Channel);
handle_frame_error(Reason, Channel = #channel{conn_state = ConnState}) when handle_frame_error(Reason, Channel = #channel{conn_state = ConnState}) when
ConnState =:= connected orelse ConnState =:= reauthenticating ?IS_CONNECTED_OR_REAUTHENTICATING(ConnState)
-> ->
handle_out(disconnect, {?RC_MALFORMED_PACKET, Reason}, Channel); handle_out(disconnect, {?RC_MALFORMED_PACKET, Reason}, Channel);
handle_frame_error(Reason, Channel = #channel{conn_state = disconnected}) -> handle_frame_error(Reason, Channel = #channel{conn_state = disconnected}) ->
@ -1295,7 +1297,7 @@ handle_info(
session = Session session = Session
} }
) when ) when
ConnState =:= connected orelse ConnState =:= reauthenticating ?IS_CONNECTED_OR_REAUTHENTICATING(ConnState)
-> ->
{Intent, Session1} = session_disconnect(ClientInfo, ConnInfo, Session), {Intent, Session1} = session_disconnect(ClientInfo, ConnInfo, Session),
Channel1 = ensure_disconnected(Reason, maybe_publish_will_msg(sock_closed, Channel)), Channel1 = ensure_disconnected(Reason, maybe_publish_will_msg(sock_closed, Channel)),
@ -2675,13 +2677,13 @@ disconnect_and_shutdown(
?IS_MQTT_V5 = ?IS_MQTT_V5 =
#channel{conn_state = ConnState} #channel{conn_state = ConnState}
) when ) when
ConnState =:= connected orelse ConnState =:= reauthenticating ?IS_CONNECTED_OR_REAUTHENTICATING(ConnState)
-> ->
NChannel = ensure_disconnected(Reason, Channel), NChannel = ensure_disconnected(Reason, Channel),
shutdown(Reason, Reply, ?DISCONNECT_PACKET(reason_code(Reason)), NChannel); shutdown(Reason, Reply, ?DISCONNECT_PACKET(reason_code(Reason)), NChannel);
%% mqtt v3/v4 connected sessions %% mqtt v3/v4 connected sessions
disconnect_and_shutdown(Reason, Reply, Channel = #channel{conn_state = ConnState}) when disconnect_and_shutdown(Reason, Reply, Channel = #channel{conn_state = ConnState}) when
ConnState =:= connected orelse ConnState =:= reauthenticating ?IS_CONNECTED_OR_REAUTHENTICATING(ConnState)
-> ->
NChannel = ensure_disconnected(Reason, Channel), NChannel = ensure_disconnected(Reason, Channel),
shutdown(Reason, Reply, NChannel); shutdown(Reason, Reply, NChannel);

View File

@ -267,27 +267,36 @@ packet(Header, Variable, Payload) ->
#mqtt_packet{header = Header, variable = Variable, payload = Payload}. #mqtt_packet{header = Header, variable = Variable, payload = Payload}.
parse_connect(FrameBin, StrictMode) -> parse_connect(FrameBin, StrictMode) ->
{ProtoName, Rest} = parse_utf8_string_with_cause(FrameBin, StrictMode, invalid_proto_name), {ProtoName, Rest0} = parse_utf8_string_with_cause(FrameBin, StrictMode, invalid_proto_name),
case ProtoName of %% No need to parse and check proto_ver if proto_name is invalid, check it first
<<"MQTT">> -> %% And the matching check of `proto_name` and `proto_ver` fields will be done in `emqx_packet:check_proto_ver/2`
ok; _ = validate_proto_name(ProtoName),
<<"MQIsdp">> -> {IsBridge, ProtoVer, Rest2} = parse_connect_proto_ver(Rest0),
ok; Meta = #{proto_name => ProtoName, proto_ver => ProtoVer},
_ -> try
%% from spec: the server MAY send disconnect with reason code 0x84 do_parse_connect(ProtoName, IsBridge, ProtoVer, Rest2, StrictMode)
%% we chose to close socket because the client is likely not talking MQTT anyway catch
?PARSE_ERR(#{ throw:{?FRAME_PARSE_ERROR, ReasonM} when is_map(ReasonM) ->
cause => invalid_proto_name, ?PARSE_ERR(maps:merge(ReasonM, Meta));
expected => <<"'MQTT' or 'MQIsdp'">>, throw:{?FRAME_PARSE_ERROR, Reason} ->
received => ProtoName ?PARSE_ERR(Meta#{cause => Reason})
}) end.
end,
parse_connect2(ProtoName, Rest, StrictMode).
parse_connect2( do_parse_connect(
ProtoName, ProtoName,
<<BridgeTag:4, ProtoVer:4, UsernameFlagB:1, PasswordFlagB:1, WillRetainB:1, WillQoS:2, IsBridge,
WillFlagB:1, CleanStart:1, Reserved:1, KeepAlive:16/big, Rest2/binary>>, ProtoVer,
<<
UsernameFlagB:1,
PasswordFlagB:1,
WillRetainB:1,
WillQoS:2,
WillFlagB:1,
CleanStart:1,
Reserved:1,
KeepAlive:16/big,
Rest/binary
>>,
StrictMode StrictMode
) -> ) ->
_ = validate_connect_reserved(Reserved), _ = validate_connect_reserved(Reserved),
@ -302,14 +311,14 @@ parse_connect2(
UsernameFlag = bool(UsernameFlagB), UsernameFlag = bool(UsernameFlagB),
PasswordFlag = bool(PasswordFlagB) PasswordFlag = bool(PasswordFlagB)
), ),
{Properties, Rest3} = parse_properties(Rest2, ProtoVer, StrictMode), {Properties, Rest3} = parse_properties(Rest, ProtoVer, StrictMode),
{ClientId, Rest4} = parse_utf8_string_with_cause(Rest3, StrictMode, invalid_clientid), {ClientId, Rest4} = parse_utf8_string_with_cause(Rest3, StrictMode, invalid_clientid),
ConnPacket = #mqtt_packet_connect{ ConnPacket = #mqtt_packet_connect{
proto_name = ProtoName, proto_name = ProtoName,
proto_ver = ProtoVer, proto_ver = ProtoVer,
%% For bridge mode, non-standard implementation %% For bridge mode, non-standard implementation
%% Invented by mosquitto, named 'try_private': https://mosquitto.org/man/mosquitto-conf-5.html %% Invented by mosquitto, named 'try_private': https://mosquitto.org/man/mosquitto-conf-5.html
is_bridge = (BridgeTag =:= 8), is_bridge = IsBridge,
clean_start = bool(CleanStart), clean_start = bool(CleanStart),
will_flag = WillFlag, will_flag = WillFlag,
will_qos = WillQoS, will_qos = WillQoS,
@ -342,8 +351,8 @@ parse_connect2(
unexpected_trailing_bytes => size(Rest7) unexpected_trailing_bytes => size(Rest7)
}) })
end; end;
parse_connect2(_ProtoName, Bin, _StrictMode) -> do_parse_connect(_ProtoName, _IsBridge, _ProtoVer, Bin, _StrictMode) ->
%% sent less than 32 bytes %% sent less than 24 bytes
?PARSE_ERR(#{cause => malformed_connect, header_bytes => Bin}). ?PARSE_ERR(#{cause => malformed_connect, header_bytes => Bin}).
parse_packet( parse_packet(
@ -515,6 +524,12 @@ parse_packet_id(<<PacketId:16/big, Rest/binary>>) ->
parse_packet_id(_) -> parse_packet_id(_) ->
?PARSE_ERR(invalid_packet_id). ?PARSE_ERR(invalid_packet_id).
parse_connect_proto_ver(<<BridgeTag:4, ProtoVer:4, Rest/binary>>) ->
{_IsBridge = (BridgeTag =:= 8), ProtoVer, Rest};
parse_connect_proto_ver(Bin) ->
%% sent less than 1 bytes or empty
?PARSE_ERR(#{cause => malformed_connect, header_bytes => Bin}).
parse_properties(Bin, Ver, _StrictMode) when Ver =/= ?MQTT_PROTO_V5 -> parse_properties(Bin, Ver, _StrictMode) when Ver =/= ?MQTT_PROTO_V5 ->
{#{}, Bin}; {#{}, Bin};
%% TODO: version mess? %% TODO: version mess?
@ -1129,10 +1144,25 @@ validate_subqos([3 | _]) -> ?PARSE_ERR(bad_subqos);
validate_subqos([_ | T]) -> validate_subqos(T); validate_subqos([_ | T]) -> validate_subqos(T);
validate_subqos([]) -> ok. validate_subqos([]) -> ok.
%% from spec: the server MAY send disconnect with reason code 0x84
%% we chose to close socket because the client is likely not talking MQTT anyway
validate_proto_name(<<"MQTT">>) ->
ok;
validate_proto_name(<<"MQIsdp">>) ->
ok;
validate_proto_name(ProtoName) ->
?PARSE_ERR(#{
cause => invalid_proto_name,
expected => <<"'MQTT' or 'MQIsdp'">>,
received => ProtoName
}).
%% MQTT-v3.1.1-[MQTT-3.1.2-3], MQTT-v5.0-[MQTT-3.1.2-3] %% MQTT-v3.1.1-[MQTT-3.1.2-3], MQTT-v5.0-[MQTT-3.1.2-3]
-compile({inline, [validate_connect_reserved/1]}).
validate_connect_reserved(0) -> ok; validate_connect_reserved(0) -> ok;
validate_connect_reserved(1) -> ?PARSE_ERR(reserved_connect_flag). validate_connect_reserved(1) -> ?PARSE_ERR(reserved_connect_flag).
-compile({inline, [validate_connect_will/3]}).
%% MQTT-v3.1.1-[MQTT-3.1.2-13], MQTT-v5.0-[MQTT-3.1.2-11] %% MQTT-v3.1.1-[MQTT-3.1.2-13], MQTT-v5.0-[MQTT-3.1.2-11]
validate_connect_will(false, _, WillQoS) when WillQoS > 0 -> ?PARSE_ERR(invalid_will_qos); validate_connect_will(false, _, WillQoS) when WillQoS > 0 -> ?PARSE_ERR(invalid_will_qos);
%% MQTT-v3.1.1-[MQTT-3.1.2-14], MQTT-v5.0-[MQTT-3.1.2-12] %% MQTT-v3.1.1-[MQTT-3.1.2-14], MQTT-v5.0-[MQTT-3.1.2-12]
@ -1141,6 +1171,7 @@ validate_connect_will(true, _, WillQoS) when WillQoS > 2 -> ?PARSE_ERR(invalid_w
validate_connect_will(false, WillRetain, _) when WillRetain -> ?PARSE_ERR(invalid_will_retain); validate_connect_will(false, WillRetain, _) when WillRetain -> ?PARSE_ERR(invalid_will_retain);
validate_connect_will(_, _, _) -> ok. validate_connect_will(_, _, _) -> ok.
-compile({inline, [validate_connect_password_flag/4]}).
%% MQTT-v3.1 %% MQTT-v3.1
%% Username flag and password flag are not strongly related %% Username flag and password flag are not strongly related
%% https://public.dhe.ibm.com/software/dw/webservices/ws-mqtt/mqtt-v3r1.html#connect %% https://public.dhe.ibm.com/software/dw/webservices/ws-mqtt/mqtt-v3r1.html#connect
@ -1155,6 +1186,7 @@ validate_connect_password_flag(true, ?MQTT_PROTO_V5, _, _) ->
validate_connect_password_flag(_, _, _, _) -> validate_connect_password_flag(_, _, _, _) ->
ok. ok.
-compile({inline, [bool/1]}).
bool(0) -> false; bool(0) -> false;
bool(1) -> true. bool(1) -> true.