From 37a89d009480294a4c37b2d9446dd8a159063ad5 Mon Sep 17 00:00:00 2001 From: JimMoen Date: Mon, 15 Jul 2024 11:38:10 +0800 Subject: [PATCH] fix: enrich parse_state and connection serialize opts --- apps/emqx/include/emqx_mqtt.hrl | 1 + apps/emqx/src/emqx_channel.erl | 78 +++++++++++++++++++++------- apps/emqx/src/emqx_connection.erl | 9 +++- apps/emqx/src/emqx_frame.erl | 27 +++++++--- apps/emqx/src/emqx_ws_connection.erl | 11 +++- 5 files changed, 100 insertions(+), 26 deletions(-) diff --git a/apps/emqx/include/emqx_mqtt.hrl b/apps/emqx/include/emqx_mqtt.hrl index 09f7495ea..1c3fd770c 100644 --- a/apps/emqx/include/emqx_mqtt.hrl +++ b/apps/emqx/include/emqx_mqtt.hrl @@ -683,6 +683,7 @@ end). -define(FRAME_PARSE_ERROR, frame_parse_error). -define(FRAME_SERIALIZE_ERROR, frame_serialize_error). + -define(THROW_FRAME_ERROR(Reason), erlang:throw({?FRAME_PARSE_ERROR, Reason})). -define(THROW_SERIALIZE_ERROR(Reason), erlang:throw({?FRAME_SERIALIZE_ERROR, Reason})). diff --git a/apps/emqx/src/emqx_channel.erl b/apps/emqx/src/emqx_channel.erl index f7b210a22..07aad9b24 100644 --- a/apps/emqx/src/emqx_channel.erl +++ b/apps/emqx/src/emqx_channel.erl @@ -37,6 +37,7 @@ get_mqtt_conf/2, get_mqtt_conf/3, set_conn_state/2, + set_conninfo_proto_ver/2, stats/1, caps/1 ]). @@ -219,6 +220,9 @@ info(impl, #channel{session = Session}) -> set_conn_state(ConnState, Channel) -> Channel#channel{conn_state = ConnState}. +set_conninfo_proto_ver({none, #{version := ProtoVer}}, Channel = #channel{conninfo = ConnInfo}) -> + Channel#channel{conninfo = ConnInfo#{proto_ver => ProtoVer}}. + -spec stats(channel()) -> emqx_types:stats(). stats(#channel{session = undefined}) -> emqx_pd:get_counters(?CHANNEL_METRICS); @@ -1003,29 +1007,60 @@ not_nacked({deliver, _Topic, Msg}) -> %%-------------------------------------------------------------------- handle_frame_error( - Reason, - Channel = #channel{conn_state = idle} -) -> - shutdown(shutdown_count(frame_error, Reason), Channel); -handle_frame_error( - #{cause := frame_too_large} = R, Channel = #channel{conn_state = connecting} -) -> - shutdown( - shutdown_count(frame_error, R), ?CONNACK_PACKET(?RC_PACKET_TOO_LARGE), Channel - ); -handle_frame_error(Reason, Channel = #channel{conn_state = connecting}) -> - shutdown(shutdown_count(frame_error, Reason), ?CONNACK_PACKET(?RC_MALFORMED_PACKET), Channel); -handle_frame_error( - #{cause := frame_too_large}, Channel = #channel{conn_state = ConnState} + Reason = #{cause := frame_too_large}, + Channel = #channel{conn_state = ConnState, conninfo = ConnInfo} ) when ?IS_CONNECTED_OR_REAUTHENTICATING(ConnState) -> - handle_out(disconnect, {?RC_PACKET_TOO_LARGE, frame_too_large}, Channel); -handle_frame_error(Reason, Channel = #channel{conn_state = ConnState}) when + ShutdownCount = shutdown_count(frame_error, Reason), + case proto_ver(Reason, ConnInfo) of + ?MQTT_PROTO_V5 -> + handle_out(disconnect, {?RC_PACKET_TOO_LARGE, frame_too_large}, Channel); + _ -> + shutdown(ShutdownCount, Channel) + end; +%% Only send CONNACK with reason code `frame_too_large` for MQTT-v5.0 when connecting, +%% otherwise DONOT send any CONNACK or DISCONNECT packet. +handle_frame_error( + Reason, + Channel = #channel{conn_state = ConnState, conninfo = ConnInfo} +) when + is_map(Reason) andalso + (ConnState == idle orelse ConnState == connecting) +-> + ShutdownCount = shutdown_count(frame_error, Reason), + ProtoVer = proto_ver(Reason, ConnInfo), + NChannel = Channel#channel{conninfo = ConnInfo#{proto_ver => ProtoVer}}, + case ProtoVer of + ?MQTT_PROTO_V5 -> + shutdown(ShutdownCount, ?CONNACK_PACKET(?RC_PACKET_TOO_LARGE), NChannel); + _ -> + shutdown(ShutdownCount, NChannel) + end; +handle_frame_error( + Reason, + Channel = #channel{conn_state = connecting} +) -> + shutdown( + shutdown_count(frame_error, Reason), + ?CONNACK_PACKET(?RC_MALFORMED_PACKET), + Channel + ); +handle_frame_error( + Reason, + Channel = #channel{conn_state = ConnState} +) when ?IS_CONNECTED_OR_REAUTHENTICATING(ConnState) -> - handle_out(disconnect, {?RC_MALFORMED_PACKET, Reason}, Channel); -handle_frame_error(Reason, Channel = #channel{conn_state = disconnected}) -> + handle_out( + disconnect, + {?RC_MALFORMED_PACKET, Reason}, + Channel + ); +handle_frame_error( + Reason, + Channel = #channel{conn_state = disconnected} +) -> ?SLOG(error, #{msg => "malformed_mqtt_message", reason => Reason}), {ok, Channel}. @@ -2726,6 +2761,13 @@ is_durable_session(#channel{session = Session}) -> false end. +proto_ver(#{proto_ver := ProtoVer}, _ConnInfo) -> + ProtoVer; +proto_ver(_Reason, #{proto_ver := ProtoVer}) -> + ProtoVer; +proto_ver(_, _) -> + ?MQTT_PROTO_V4. + %%-------------------------------------------------------------------- %% For CT tests %%-------------------------------------------------------------------- diff --git a/apps/emqx/src/emqx_connection.erl b/apps/emqx/src/emqx_connection.erl index f378b700e..ecb962f08 100644 --- a/apps/emqx/src/emqx_connection.erl +++ b/apps/emqx/src/emqx_connection.erl @@ -782,7 +782,8 @@ parse_incoming(Data, Packets, State = #state{parse_state = ParseState}) -> input_bytes => Data, parsed_packets => Packets }), - {[{frame_error, Reason} | Packets], State}; + NState = enrich_state(Reason, State), + {[{frame_error, Reason} | Packets], NState}; error:Reason:Stacktrace -> ?LOG(error, #{ at_state => emqx_frame:describe_state(ParseState), @@ -1204,6 +1205,12 @@ inc_counter(Key, Inc) -> _ = emqx_pd:inc_counter(Key, Inc), ok. +enrich_state(#{parse_state := NParseState}, State) -> + Serialize = emqx_frame:serialize_opts(NParseState), + State#state{parse_state = NParseState, serialize = Serialize}; +enrich_state(_, State) -> + State. + set_tcp_keepalive({quic, _Listener}) -> ok; set_tcp_keepalive({Type, Id}) -> diff --git a/apps/emqx/src/emqx_frame.erl b/apps/emqx/src/emqx_frame.erl index 398a5f35c..554847d67 100644 --- a/apps/emqx/src/emqx_frame.erl +++ b/apps/emqx/src/emqx_frame.erl @@ -266,20 +266,33 @@ packet(Header, Variable) -> packet(Header, Variable, Payload) -> #mqtt_packet{header = Header, variable = Variable, payload = Payload}. -parse_connect(FrameBin, StrictMode) -> +parse_connect(FrameBin, Options = #{strict_mode := StrictMode}) -> {ProtoName, Rest0} = parse_utf8_string_with_cause(FrameBin, StrictMode, invalid_proto_name), %% No need to parse and check proto_ver if proto_name is invalid, check it first %% And the matching check of `proto_name` and `proto_ver` fields will be done in `emqx_packet:check_proto_ver/2` _ = validate_proto_name(ProtoName), {IsBridge, ProtoVer, Rest2} = parse_connect_proto_ver(Rest0), - Meta = #{proto_name => ProtoName, proto_ver => ProtoVer}, + NOptions = Options#{version => ProtoVer}, try do_parse_connect(ProtoName, IsBridge, ProtoVer, Rest2, StrictMode) catch throw:{?FRAME_PARSE_ERROR, ReasonM} when is_map(ReasonM) -> - ?PARSE_ERR(maps:merge(ReasonM, Meta)); + ?PARSE_ERR( + ReasonM#{ + proto_ver => ProtoVer, + proto_name => ProtoName, + parse_state => ?NONE(NOptions) + } + ); throw:{?FRAME_PARSE_ERROR, Reason} -> - ?PARSE_ERR(Meta#{cause => Reason}) + ?PARSE_ERR( + #{ + cause => Reason, + proto_ver => ProtoVer, + proto_name => ProtoName, + parse_state => ?NONE(NOptions) + } + ) end. do_parse_connect( @@ -358,9 +371,9 @@ do_parse_connect(_ProtoName, _IsBridge, _ProtoVer, Bin, _StrictMode) -> parse_packet( #mqtt_packet_header{type = ?CONNECT}, FrameBin, - #{strict_mode := StrictMode} + Options ) -> - parse_connect(FrameBin, StrictMode); + parse_connect(FrameBin, Options); parse_packet( #mqtt_packet_header{type = ?CONNACK}, <>, @@ -753,6 +766,8 @@ serialize_fun(#{version := Ver, max_size := MaxSize}) -> serialize_opts() -> ?DEFAULT_OPTIONS. +serialize_opts(?NONE(Options)) -> + maps:merge(?DEFAULT_OPTIONS, Options); serialize_opts(#mqtt_packet_connect{proto_ver = ProtoVer, properties = ConnProps}) -> MaxSize = get_property('Maximum-Packet-Size', ConnProps, ?MAX_PACKET_SIZE), #{version => ProtoVer, max_size => MaxSize}. diff --git a/apps/emqx/src/emqx_ws_connection.erl b/apps/emqx/src/emqx_ws_connection.erl index 5d04b3304..4765fdace 100644 --- a/apps/emqx/src/emqx_ws_connection.erl +++ b/apps/emqx/src/emqx_ws_connection.erl @@ -436,6 +436,7 @@ websocket_handle({Frame, _}, State) -> %% TODO: should not close the ws connection ?LOG(error, #{msg => "unexpected_frame", frame => Frame}), shutdown(unexpected_ws_frame, State). + websocket_info({call, From, Req}, State) -> handle_call(From, Req, State); websocket_info({cast, rate_limit}, State) -> @@ -725,7 +726,8 @@ parse_incoming(Data, Packets, State = #state{parse_state = ParseState}) -> input_bytes => Data }), FrameError = {frame_error, Reason}, - {[{incoming, FrameError} | Packets], State}; + NState = enrich_state(Reason, State), + {[{incoming, FrameError} | Packets], NState}; error:Reason:Stacktrace -> ?LOG(error, #{ at_state => emqx_frame:describe_state(ParseState), @@ -1059,6 +1061,13 @@ check_max_connection(Type, Listener) -> {denny, Reason} end end. + +enrich_state(#{parse_state := NParseState}, State) -> + Serialize = emqx_frame:serialize_opts(NParseState), + State#state{parse_state = NParseState, serialize = Serialize}; +enrich_state(_, State) -> + State. + %%-------------------------------------------------------------------- %% For CT tests %%--------------------------------------------------------------------