From 35822ff97ad550c5f802d3dc24f75b13b1f7502c Mon Sep 17 00:00:00 2001 From: tigercl Date: Mon, 16 Sep 2019 13:51:50 +0800 Subject: [PATCH] Fix handling for MQTT spec (#2892) Fix invalid QoS and protocol name, fix handling for Topic-Alias-Maximum and Maximum-Packet-Size, and send DISCONNECT Packet to client when the session is taken over --- src/emqx_channel.erl | 13 ++++++---- src/emqx_connection.erl | 47 +++++++++++++++++++++++++++--------- src/emqx_metrics.erl | 3 ++- src/emqx_mqtt_props.erl | 14 +++++++++++ src/emqx_packet.erl | 9 ++++--- src/emqx_protocol.erl | 38 +++++++++++++++-------------- src/emqx_ws_connection.erl | 43 ++++++++++++++++++++++++--------- test/emqx_channel_SUITE.erl | 2 +- test/emqx_packet_SUITE.erl | 2 +- test/emqx_protocol_SUITE.erl | 10 ++++---- 10 files changed, 122 insertions(+), 59 deletions(-) diff --git a/src/emqx_channel.erl b/src/emqx_channel.erl index 05f22edfb..4c918253b 100644 --- a/src/emqx_channel.erl +++ b/src/emqx_channel.erl @@ -525,7 +525,7 @@ handle_out({connack, ReasonCode}, Channel = #channel{client = Client, }) -> ok = emqx_hooks:run('client.connected', [Client, ReasonCode, attrs(Channel)]), ProtoVer = case Protocol of - undefined -> undefined; + undefined -> ?MQTT_PROTO_V5; _ -> emqx_protocol:info(proto_ver, Protocol) end, ReasonCode1 = if @@ -630,7 +630,10 @@ handle_out({Type, Data}, Channel) -> handle_call(kick, Channel) -> {stop, {shutdown, kicked}, ok, Channel}; -handle_call(discard, Channel) -> +handle_call(discard, Channel = #channel{connected = true}) -> + Packet = ?DISCONNECT_PACKET(?RC_SESSION_TAKEN_OVER), + {stop, {shutdown, discarded}, Packet, ok, Channel}; +handle_call(discard, Channel = #channel{connected = false}) -> {stop, {shutdown, discarded}, ok, Channel}; %% Session Takeover @@ -852,7 +855,7 @@ validate_packet(Packet, _Channel) -> error:topic_filters_invalid -> {error, ?RC_TOPIC_FILTER_INVALID}; error:topic_name_invalid -> - {error, ?RC_TOPIC_FILTER_INVALID}; + {error, ?RC_TOPIC_NAME_INVALID}; error:_Reason -> {error, ?RC_MALFORMED_PACKET} end. @@ -916,8 +919,8 @@ check_will_retain(#mqtt_packet_connect{will_retain = true}, false -> {error, ?RC_RETAIN_NOT_SUPPORTED} end. -init_protocol(ConnPkt, Channel) -> - {ok, Channel#channel{protocol = emqx_protocol:init(ConnPkt)}}. +init_protocol(ConnPkt, Channel = #channel{client = #{zone := Zone}}) -> + {ok, Channel#channel{protocol = emqx_protocol:init(ConnPkt, Zone)}}. %%-------------------------------------------------------------------- %% Enrich client diff --git a/src/emqx_connection.erl b/src/emqx_connection.erl index 54da647fb..b840cad48 100644 --- a/src/emqx_connection.erl +++ b/src/emqx_connection.erl @@ -191,7 +191,8 @@ init({Transport, RawSocket, Options}) -> rate_limit = RateLimit, pub_limit = PubLimit, parse_state = ParseState, - chan_state = ChanState + chan_state = ChanState, + serialize = serialize_fun(?MQTT_PROTO_V5, undefined) }, gen_statem:enter_loop(?MODULE, [{hibernate_after, 2 * IdleTimout}], idle, State, self(), [IdleTimout]). @@ -217,8 +218,9 @@ idle(timeout, _Timeout, State) -> shutdown(idle_timeout, State); idle(cast, {incoming, Packet = ?CONNECT_PACKET(ConnPkt)}, State) -> - #mqtt_packet_connect{proto_ver = ProtoVer} = ConnPkt, - NState = State#connection{serialize = serialize_fun(ProtoVer)}, + #mqtt_packet_connect{proto_ver = ProtoVer, properties = Properties} = ConnPkt, + MaxPacketSize = emqx_mqtt_props:get_property('Maximum-Packet-Size', Properties, undefined), + NState = State#connection{serialize = serialize_fun(ProtoVer, MaxPacketSize)}, SuccFun = fun(NewSt) -> {next_state, connected, NewSt} end, handle_incoming(Packet, SuccFun, NState); @@ -283,6 +285,10 @@ handle({call, From}, Req, State = #connection{chan_state = ChanState}) -> {ok, Reply, NChanState} -> reply(From, Reply, State#connection{chan_state = NChanState}); {stop, Reason, Reply, NChanState} -> + ok = gen_statem:reply(From, Reply), + stop(Reason, State#connection{chan_state = NChanState}); + {stop, Reason, Packet, Reply, NChanState} -> + handle_outgoing(Packet, fun (_) -> ok end, State#connection{chan_state = NChanState}), ok = gen_statem:reply(From, Reply), stop(Reason, State#connection{chan_state = NChanState}) end; @@ -414,15 +420,25 @@ process_incoming(Data, Packets, State = #connection{parse_state = ParseState, ch shutdown(Reason, State) catch error:Reason:Stk -> - ?LOG(error, "Parse failed for ~p~n\ - Stacktrace:~p~nError data:~p", [Reason, Stk, Data]), - case emqx_channel:handle_out({disconnect, emqx_reason_codes:mqtt_frame_error(Reason)}, ChanState) of + ?LOG(error, "Parse failed for ~p~nStacktrace:~p~nError data:~p", [Reason, Stk, Data]), + Result = + case emqx_channel:info(connected, ChanState) of + undefined -> + emqx_channel:handle_out({connack, emqx_reason_codes:mqtt_frame_error(Reason)}, ChanState); + true -> + emqx_channel:handle_out({disconnect, emqx_reason_codes:mqtt_frame_error(Reason)}, ChanState); + _ -> + ignore + end, + case Result of {stop, Reason0, OutPackets, NChanState} -> Shutdown = fun(NewSt) -> stop(Reason0, NewSt) end, NState = State#connection{chan_state = NChanState}, handle_outgoing(OutPackets, Shutdown, NState); {stop, Reason0, NChanState} -> - stop(Reason0, State#connection{chan_state = NChanState}) + stop(Reason0, State#connection{chan_state = NChanState}); + ignore -> + keep_state(State) end end. @@ -479,12 +495,19 @@ handle_outgoing(Packet, SuccFun, State = #connection{serialize = Serialize}) -> %%-------------------------------------------------------------------- %% Serialize fun -serialize_fun(ProtoVer) -> +serialize_fun(ProtoVer, MaxPacketSize) -> fun(Packet = ?PACKET(Type)) -> - ?LOG(debug, "SEND ~s", [emqx_packet:format(Packet)]), - _ = inc_outgoing_stats(Type), - _ = emqx_metrics:inc_sent(Packet), - emqx_frame:serialize(Packet, ProtoVer) + IoData = emqx_frame:serialize(Packet, ProtoVer), + case Type =/= ?PUBLISH orelse MaxPacketSize =:= undefined orelse iolist_size(IoData) =< MaxPacketSize of + true -> + ?LOG(debug, "SEND ~s", [emqx_packet:format(Packet)]), + _ = inc_outgoing_stats(Type), + _ = emqx_metrics:inc_sent(Packet), + IoData; + false -> + ?LOG(warning, "DROP ~s due to oversize packet size", [emqx_packet:format(Packet)]), + <<"">> + end end. %%-------------------------------------------------------------------- diff --git a/src/emqx_metrics.erl b/src/emqx_metrics.erl index c7f347dc5..db42cd1e8 100644 --- a/src/emqx_metrics.erl +++ b/src/emqx_metrics.erl @@ -280,7 +280,8 @@ do_inc_recv(?PUBLISH_PACKET(QoS, _PktId)) -> case QoS of ?QOS_0 -> inc('messages.qos0.received'); ?QOS_1 -> inc('messages.qos1.received'); - ?QOS_2 -> inc('messages.qos2.received') + ?QOS_2 -> inc('messages.qos2.received'); + _ -> ignore end, inc('packets.publish.received'); do_inc_recv(?PACKET(?PUBACK)) -> diff --git a/src/emqx_mqtt_props.erl b/src/emqx_mqtt_props.erl index 163d9baf1..241db9a2e 100644 --- a/src/emqx_mqtt_props.erl +++ b/src/emqx_mqtt_props.erl @@ -28,6 +28,10 @@ %% For tests -export([all/0]). +-export([ set_property/3 + , get_property/3 + ]). + -type(prop_name() :: atom()). -type(prop_id() :: pos_integer()). @@ -179,3 +183,13 @@ validate_value(_Type, _Val) -> false. -spec(all() -> map()). all() -> ?PROPS_TABLE. +set_property(Name, Value, undefined) -> + #{Name => Value}; +set_property(Name, Value, Props) -> + Props#{Name => Value}. + +get_property(_Name, undefined, Default) -> + Default; +get_property(Name, Props, Default) -> + maps:get(Name, Props, Default). + diff --git a/src/emqx_packet.erl b/src/emqx_packet.erl index 9f9fc7610..a6876c1ef 100644 --- a/src/emqx_packet.erl +++ b/src/emqx_packet.erl @@ -75,10 +75,11 @@ validate(?UNSUBSCRIBE_PACKET(PacketId, TopicFilters)) -> validate(?PUBLISH_PACKET(_QoS, <<>>, _, #{'Topic-Alias':= _I}, _)) -> true; validate(?PUBLISH_PACKET(_QoS, <<>>, _, _, _)) -> - error(topic_name_invalid); -validate(?PUBLISH_PACKET(_QoS, Topic, _, Properties, _)) -> - ((not emqx_topic:wildcard(Topic)) orelse error(topic_name_invalid)) - andalso validate_properties(?PUBLISH, Properties); + error(protocol_error); +validate(?PUBLISH_PACKET(QoS, Topic, _, Properties, _)) -> + ((not (QoS =:= 3)) orelse error(qos_invalid)) + andalso ((not emqx_topic:wildcard(Topic)) orelse error(topic_name_invalid)) + andalso validate_properties(?PUBLISH, Properties); validate(?CONNECT_PACKET(#mqtt_packet_connect{properties = Properties})) -> validate_properties(?CONNECT, Properties); diff --git a/src/emqx_protocol.erl b/src/emqx_protocol.erl index 7f9d69666..54ff6056c 100644 --- a/src/emqx_protocol.erl +++ b/src/emqx_protocol.erl @@ -20,7 +20,7 @@ -include("types.hrl"). -include("emqx_mqtt.hrl"). --export([ init/1 +-export([ init/2 , info/1 , info/2 , attrs/1 @@ -48,10 +48,10 @@ username :: emqx_types:username(), %% MQTT Will Msg will_msg :: emqx_types:message(), - %% MQTT Conn Properties - conn_props :: maybe(emqx_types:properties()), %% MQTT Topic Aliases - topic_aliases :: maybe(map()) + topic_aliases :: maybe(map()), + %% MQTT Topic Alias Maximum + alias_maximum :: maybe(map()) }). -opaque(protocol() :: #protocol{}). @@ -60,23 +60,24 @@ -define(ATTR_KEYS, [proto_name, proto_ver, clean_start, keepalive]). --spec(init(#mqtt_packet_connect{}) -> protocol()). +-spec(init(#mqtt_packet_connect{}, atom()) -> protocol()). init(#mqtt_packet_connect{proto_name = ProtoName, proto_ver = ProtoVer, clean_start = CleanStart, keepalive = Keepalive, properties = Properties, client_id = ClientId, - username = Username} = ConnPkt) -> + username = Username} = ConnPkt, Zone) -> WillMsg = emqx_packet:will_msg(ConnPkt), - #protocol{proto_name = ProtoName, - proto_ver = ProtoVer, - clean_start = CleanStart, - keepalive = Keepalive, - client_id = ClientId, - username = Username, - will_msg = WillMsg, - conn_props = Properties + #protocol{proto_name = ProtoName, + proto_ver = ProtoVer, + clean_start = CleanStart, + keepalive = Keepalive, + client_id = ClientId, + username = Username, + will_msg = WillMsg, + alias_maximum = #{outbound => emqx_mqtt_props:get_property('Topic-Alias-Maximum', Properties, 0), + inbound => maps:get(max_topic_alias, emqx_mqtt_caps:get_caps(Zone), 0)} }. -spec(info(protocol()) -> emqx_types:infos()). @@ -104,10 +105,10 @@ info(will_delay_interval, #protocol{will_msg = undefined}) -> 0; info(will_delay_interval, #protocol{will_msg = WillMsg}) -> emqx_message:get_header('Will-Delay-Interval', WillMsg, 0); -info(conn_props, #protocol{conn_props = ConnProps}) -> - ConnProps; info(topic_aliases, #protocol{topic_aliases = Aliases}) -> - Aliases. + Aliases; +info(alias_maximum, #protocol{alias_maximum = AliasMaximum}) -> + AliasMaximum. -spec(attrs(protocol()) -> emqx_types:attrs()). attrs(Proto) -> @@ -128,4 +129,5 @@ save_alias(AliasId, Topic, Proto = #protocol{topic_aliases = Aliases}) -> Proto#protocol{topic_aliases = maps:put(AliasId, Topic, Aliases)}. clear_will_msg(Protocol) -> - Protocol#protocol{will_msg = undefined}. \ No newline at end of file + Protocol#protocol{will_msg = undefined}. + diff --git a/src/emqx_ws_connection.erl b/src/emqx_ws_connection.erl index 4e6baebdf..58c989453 100644 --- a/src/emqx_ws_connection.erl +++ b/src/emqx_ws_connection.erl @@ -192,7 +192,8 @@ websocket_init([Req, Opts]) -> fsm_state = idle, parse_state = ParseState, chan_state = ChanState, - pendings = []}}. + pendings = [], + serialize = serialize_fun(?MQTT_PROTO_V5, undefined)}}. websocket_handle({binary, Data}, State) when is_list(Data) -> websocket_handle({binary, iolist_to_binary(Data)}, State); @@ -255,8 +256,9 @@ websocket_info({cast, Msg}, State = #ws_connection{chan_state = ChanState}) -> websocket_info({incoming, Packet = ?CONNECT_PACKET(ConnPkt)}, State = #ws_connection{fsm_state = idle}) -> - #mqtt_packet_connect{proto_ver = ProtoVer} = ConnPkt, - NState = State#ws_connection{serialize = serialize_fun(ProtoVer)}, + #mqtt_packet_connect{proto_ver = ProtoVer, properties = Properties} = ConnPkt, + MaxPacketSize = emqx_mqtt_props:get_property('Maximum-Packet-Size', Properties, undefined), + NState = State#ws_connection{serialize = serialize_fun(ProtoVer, MaxPacketSize)}, handle_incoming(Packet, fun connected/1, NState); websocket_info({incoming, Packet}, State = #ws_connection{fsm_state = idle}) -> @@ -348,14 +350,24 @@ process_incoming(Data, State = #ws_connection{parse_state = ParseState, chan_sta stop(Reason, State) catch error:Reason:Stk -> - ?LOG(error, "Parse failed for ~p~n\ - Stacktrace:~p~nFrame data: ~p", [Reason, Stk, Data]), - case emqx_channel:handle_out({disconnect, emqx_reason_codes:mqtt_frame_error(Reason)}, ChanState) of + ?LOG(error, "Parse failed for ~p~nStacktrace:~p~nFrame data: ~p", [Reason, Stk, Data]), + Result = + case emqx_channel:info(connected, ChanState) of + undefined -> + emqx_channel:handle_out({connack, emqx_reason_codes:mqtt_frame_error(Reason)}, ChanState); + true -> + emqx_channel:handle_out({disconnect, emqx_reason_codes:mqtt_frame_error(Reason)}, ChanState); + _ -> + ignore + end, + case Result of {stop, Reason0, OutPackets, NChanState} -> NState = State#ws_connection{chan_state = NChanState}, stop(Reason0, enqueue(OutPackets, NState)); {stop, Reason0, NChanState} -> - stop(Reason0, State#ws_connection{chan_state = NChanState}) + stop(Reason0, State#ws_connection{chan_state = NChanState}); + ignore -> + {ok, State} end end. @@ -394,12 +406,19 @@ handle_outgoing(Packets, State = #ws_connection{serialize = Serialize, %%-------------------------------------------------------------------- %% Serialize fun -serialize_fun(ProtoVer) -> +serialize_fun(ProtoVer, MaxPacketSize) -> fun(Packet = ?PACKET(Type)) -> - ?LOG(debug, "SEND ~s", [emqx_packet:format(Packet)]), - _ = inc_outgoing_stats(Type), - _ = emqx_metrics:inc_sent(Packet), - emqx_frame:serialize(Packet, ProtoVer) + IoData = emqx_frame:serialize(Packet, ProtoVer), + case Type =/= ?PUBLISH orelse MaxPacketSize =:= undefined orelse iolist_size(IoData) =< MaxPacketSize of + true -> + ?LOG(debug, "SEND ~s", [emqx_packet:format(Packet)]), + _ = inc_outgoing_stats(Type), + _ = emqx_metrics:inc_sent(Packet), + IoData; + false -> + ?LOG(warning, "DROP ~s due to oversize packet size", [emqx_packet:format(Packet)]), + <<"">> + end end. %%-------------------------------------------------------------------- diff --git a/test/emqx_channel_SUITE.erl b/test/emqx_channel_SUITE.erl index 187d2e302..82f52c11b 100644 --- a/test/emqx_channel_SUITE.erl +++ b/test/emqx_channel_SUITE.erl @@ -289,7 +289,7 @@ with_channel(Fun) -> username = <<"username">>, password = <<"passwd">> }, - Protocol = emqx_protocol:init(ConnPkt), + Protocol = emqx_protocol:init(ConnPkt, testing), Session = emqx_session:init(#{zone => testing}, #{max_inflight => 100, expiry_interval => 0 diff --git a/test/emqx_packet_SUITE.erl b/test/emqx_packet_SUITE.erl index 732c48fe3..f3857ff2e 100644 --- a/test/emqx_packet_SUITE.erl +++ b/test/emqx_packet_SUITE.erl @@ -49,7 +49,7 @@ t_validate(_) -> [{<<"topic">>, #{qos => ?QOS_0}}]))), ?assertError(topic_filters_invalid, emqx_packet:validate(?UNSUBSCRIBE_PACKET(1,[]))), - ?assertError(topic_name_invalid, + ?assertError(protocol_error, emqx_packet:validate(?PUBLISH_PACKET(1,<<>>,1,#{},<<"payload">>))), ?assertError(topic_name_invalid, emqx_packet:validate(?PUBLISH_PACKET diff --git a/test/emqx_protocol_SUITE.erl b/test/emqx_protocol_SUITE.erl index 2cd091e30..2c7e9479f 100644 --- a/test/emqx_protocol_SUITE.erl +++ b/test/emqx_protocol_SUITE.erl @@ -38,7 +38,7 @@ init_protocol() -> client_id = <<"clientid">>, username = <<"username">>, password = <<"passwd">> - }). + }, testing). end_per_suite(_Config) -> ok. @@ -48,11 +48,11 @@ t_init_info_1(Config) -> proto_ver => ?MQTT_PROTO_V5, clean_start => true, keepalive => 30, - conn_props => #{}, will_msg => undefined, client_id => <<"clientid">>, username => <<"username">>, - topic_aliases => undefined + topic_aliases => undefined, + alias_maximum => #{outbound => 0, inbound => 0} }, emqx_protocol:info(Proto)). t_init_info_2(Config) -> @@ -65,8 +65,8 @@ t_init_info_2(Config) -> ?assertEqual(<<"username">>, emqx_protocol:info(username, Proto)), ?assertEqual(undefined, emqx_protocol:info(will_msg, Proto)), ?assertEqual(0, emqx_protocol:info(will_delay_interval, Proto)), - ?assertEqual(#{}, emqx_protocol:info(conn_props, Proto)), - ?assertEqual(undefined, emqx_protocol:info(topic_aliases, Proto)). + ?assertEqual(undefined, emqx_protocol:info(topic_aliases, Proto)), + ?assertEqual(#{outbound => 0, inbound => 0}, emqx_protocol:info(alias_maximum, Proto)). t_find_save_alias(Config) -> Proto = proplists:get_value(proto, Config),