Fix handling for MQTT spec (#2892)

Fix invalid QoS and protocol name, fix handling for Topic-Alias-Maximum and Maximum-Packet-Size, and send DISCONNECT Packet to client when the session is taken over
This commit is contained in:
tigercl 2019-09-16 13:51:50 +08:00 committed by GitHub
parent 59309e6c11
commit 35822ff97a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 122 additions and 59 deletions

View File

@ -525,7 +525,7 @@ handle_out({connack, ReasonCode}, Channel = #channel{client = Client,
}) -> }) ->
ok = emqx_hooks:run('client.connected', [Client, ReasonCode, attrs(Channel)]), ok = emqx_hooks:run('client.connected', [Client, ReasonCode, attrs(Channel)]),
ProtoVer = case Protocol of ProtoVer = case Protocol of
undefined -> undefined; undefined -> ?MQTT_PROTO_V5;
_ -> emqx_protocol:info(proto_ver, Protocol) _ -> emqx_protocol:info(proto_ver, Protocol)
end, end,
ReasonCode1 = if ReasonCode1 = if
@ -630,7 +630,10 @@ handle_out({Type, Data}, Channel) ->
handle_call(kick, Channel) -> handle_call(kick, Channel) ->
{stop, {shutdown, kicked}, ok, Channel}; {stop, {shutdown, kicked}, ok, Channel};
handle_call(discard, Channel) -> handle_call(discard, Channel = #channel{connected = true}) ->
Packet = ?DISCONNECT_PACKET(?RC_SESSION_TAKEN_OVER),
{stop, {shutdown, discarded}, Packet, ok, Channel};
handle_call(discard, Channel = #channel{connected = false}) ->
{stop, {shutdown, discarded}, ok, Channel}; {stop, {shutdown, discarded}, ok, Channel};
%% Session Takeover %% Session Takeover
@ -852,7 +855,7 @@ validate_packet(Packet, _Channel) ->
error:topic_filters_invalid -> error:topic_filters_invalid ->
{error, ?RC_TOPIC_FILTER_INVALID}; {error, ?RC_TOPIC_FILTER_INVALID};
error:topic_name_invalid -> error:topic_name_invalid ->
{error, ?RC_TOPIC_FILTER_INVALID}; {error, ?RC_TOPIC_NAME_INVALID};
error:_Reason -> error:_Reason ->
{error, ?RC_MALFORMED_PACKET} {error, ?RC_MALFORMED_PACKET}
end. end.
@ -916,8 +919,8 @@ check_will_retain(#mqtt_packet_connect{will_retain = true},
false -> {error, ?RC_RETAIN_NOT_SUPPORTED} false -> {error, ?RC_RETAIN_NOT_SUPPORTED}
end. end.
init_protocol(ConnPkt, Channel) -> init_protocol(ConnPkt, Channel = #channel{client = #{zone := Zone}}) ->
{ok, Channel#channel{protocol = emqx_protocol:init(ConnPkt)}}. {ok, Channel#channel{protocol = emqx_protocol:init(ConnPkt, Zone)}}.
%%-------------------------------------------------------------------- %%--------------------------------------------------------------------
%% Enrich client %% Enrich client

View File

@ -191,7 +191,8 @@ init({Transport, RawSocket, Options}) ->
rate_limit = RateLimit, rate_limit = RateLimit,
pub_limit = PubLimit, pub_limit = PubLimit,
parse_state = ParseState, parse_state = ParseState,
chan_state = ChanState chan_state = ChanState,
serialize = serialize_fun(?MQTT_PROTO_V5, undefined)
}, },
gen_statem:enter_loop(?MODULE, [{hibernate_after, 2 * IdleTimout}], gen_statem:enter_loop(?MODULE, [{hibernate_after, 2 * IdleTimout}],
idle, State, self(), [IdleTimout]). idle, State, self(), [IdleTimout]).
@ -217,8 +218,9 @@ idle(timeout, _Timeout, State) ->
shutdown(idle_timeout, State); shutdown(idle_timeout, State);
idle(cast, {incoming, Packet = ?CONNECT_PACKET(ConnPkt)}, State) -> idle(cast, {incoming, Packet = ?CONNECT_PACKET(ConnPkt)}, State) ->
#mqtt_packet_connect{proto_ver = ProtoVer} = ConnPkt, #mqtt_packet_connect{proto_ver = ProtoVer, properties = Properties} = ConnPkt,
NState = State#connection{serialize = serialize_fun(ProtoVer)}, MaxPacketSize = emqx_mqtt_props:get_property('Maximum-Packet-Size', Properties, undefined),
NState = State#connection{serialize = serialize_fun(ProtoVer, MaxPacketSize)},
SuccFun = fun(NewSt) -> {next_state, connected, NewSt} end, SuccFun = fun(NewSt) -> {next_state, connected, NewSt} end,
handle_incoming(Packet, SuccFun, NState); handle_incoming(Packet, SuccFun, NState);
@ -283,6 +285,10 @@ handle({call, From}, Req, State = #connection{chan_state = ChanState}) ->
{ok, Reply, NChanState} -> {ok, Reply, NChanState} ->
reply(From, Reply, State#connection{chan_state = NChanState}); reply(From, Reply, State#connection{chan_state = NChanState});
{stop, Reason, Reply, NChanState} -> {stop, Reason, Reply, NChanState} ->
ok = gen_statem:reply(From, Reply),
stop(Reason, State#connection{chan_state = NChanState});
{stop, Reason, Packet, Reply, NChanState} ->
handle_outgoing(Packet, fun (_) -> ok end, State#connection{chan_state = NChanState}),
ok = gen_statem:reply(From, Reply), ok = gen_statem:reply(From, Reply),
stop(Reason, State#connection{chan_state = NChanState}) stop(Reason, State#connection{chan_state = NChanState})
end; end;
@ -414,15 +420,25 @@ process_incoming(Data, Packets, State = #connection{parse_state = ParseState, ch
shutdown(Reason, State) shutdown(Reason, State)
catch catch
error:Reason:Stk -> error:Reason:Stk ->
?LOG(error, "Parse failed for ~p~n\ ?LOG(error, "Parse failed for ~p~nStacktrace:~p~nError data:~p", [Reason, Stk, Data]),
Stacktrace:~p~nError data:~p", [Reason, Stk, Data]), Result =
case emqx_channel:handle_out({disconnect, emqx_reason_codes:mqtt_frame_error(Reason)}, ChanState) of case emqx_channel:info(connected, ChanState) of
undefined ->
emqx_channel:handle_out({connack, emqx_reason_codes:mqtt_frame_error(Reason)}, ChanState);
true ->
emqx_channel:handle_out({disconnect, emqx_reason_codes:mqtt_frame_error(Reason)}, ChanState);
_ ->
ignore
end,
case Result of
{stop, Reason0, OutPackets, NChanState} -> {stop, Reason0, OutPackets, NChanState} ->
Shutdown = fun(NewSt) -> stop(Reason0, NewSt) end, Shutdown = fun(NewSt) -> stop(Reason0, NewSt) end,
NState = State#connection{chan_state = NChanState}, NState = State#connection{chan_state = NChanState},
handle_outgoing(OutPackets, Shutdown, NState); handle_outgoing(OutPackets, Shutdown, NState);
{stop, Reason0, NChanState} -> {stop, Reason0, NChanState} ->
stop(Reason0, State#connection{chan_state = NChanState}) stop(Reason0, State#connection{chan_state = NChanState});
ignore ->
keep_state(State)
end end
end. end.
@ -479,12 +495,19 @@ handle_outgoing(Packet, SuccFun, State = #connection{serialize = Serialize}) ->
%%-------------------------------------------------------------------- %%--------------------------------------------------------------------
%% Serialize fun %% Serialize fun
serialize_fun(ProtoVer) -> serialize_fun(ProtoVer, MaxPacketSize) ->
fun(Packet = ?PACKET(Type)) -> fun(Packet = ?PACKET(Type)) ->
IoData = emqx_frame:serialize(Packet, ProtoVer),
case Type =/= ?PUBLISH orelse MaxPacketSize =:= undefined orelse iolist_size(IoData) =< MaxPacketSize of
true ->
?LOG(debug, "SEND ~s", [emqx_packet:format(Packet)]), ?LOG(debug, "SEND ~s", [emqx_packet:format(Packet)]),
_ = inc_outgoing_stats(Type), _ = inc_outgoing_stats(Type),
_ = emqx_metrics:inc_sent(Packet), _ = emqx_metrics:inc_sent(Packet),
emqx_frame:serialize(Packet, ProtoVer) IoData;
false ->
?LOG(warning, "DROP ~s due to oversize packet size", [emqx_packet:format(Packet)]),
<<"">>
end
end. end.
%%-------------------------------------------------------------------- %%--------------------------------------------------------------------

View File

@ -280,7 +280,8 @@ do_inc_recv(?PUBLISH_PACKET(QoS, _PktId)) ->
case QoS of case QoS of
?QOS_0 -> inc('messages.qos0.received'); ?QOS_0 -> inc('messages.qos0.received');
?QOS_1 -> inc('messages.qos1.received'); ?QOS_1 -> inc('messages.qos1.received');
?QOS_2 -> inc('messages.qos2.received') ?QOS_2 -> inc('messages.qos2.received');
_ -> ignore
end, end,
inc('packets.publish.received'); inc('packets.publish.received');
do_inc_recv(?PACKET(?PUBACK)) -> do_inc_recv(?PACKET(?PUBACK)) ->

View File

@ -28,6 +28,10 @@
%% For tests %% For tests
-export([all/0]). -export([all/0]).
-export([ set_property/3
, get_property/3
]).
-type(prop_name() :: atom()). -type(prop_name() :: atom()).
-type(prop_id() :: pos_integer()). -type(prop_id() :: pos_integer()).
@ -179,3 +183,13 @@ validate_value(_Type, _Val) -> false.
-spec(all() -> map()). -spec(all() -> map()).
all() -> ?PROPS_TABLE. all() -> ?PROPS_TABLE.
set_property(Name, Value, undefined) ->
#{Name => Value};
set_property(Name, Value, Props) ->
Props#{Name => Value}.
get_property(_Name, undefined, Default) ->
Default;
get_property(Name, Props, Default) ->
maps:get(Name, Props, Default).

View File

@ -75,9 +75,10 @@ validate(?UNSUBSCRIBE_PACKET(PacketId, TopicFilters)) ->
validate(?PUBLISH_PACKET(_QoS, <<>>, _, #{'Topic-Alias':= _I}, _)) -> validate(?PUBLISH_PACKET(_QoS, <<>>, _, #{'Topic-Alias':= _I}, _)) ->
true; true;
validate(?PUBLISH_PACKET(_QoS, <<>>, _, _, _)) -> validate(?PUBLISH_PACKET(_QoS, <<>>, _, _, _)) ->
error(topic_name_invalid); error(protocol_error);
validate(?PUBLISH_PACKET(_QoS, Topic, _, Properties, _)) -> validate(?PUBLISH_PACKET(QoS, Topic, _, Properties, _)) ->
((not emqx_topic:wildcard(Topic)) orelse error(topic_name_invalid)) ((not (QoS =:= 3)) orelse error(qos_invalid))
andalso ((not emqx_topic:wildcard(Topic)) orelse error(topic_name_invalid))
andalso validate_properties(?PUBLISH, Properties); andalso validate_properties(?PUBLISH, Properties);
validate(?CONNECT_PACKET(#mqtt_packet_connect{properties = Properties})) -> validate(?CONNECT_PACKET(#mqtt_packet_connect{properties = Properties})) ->

View File

@ -20,7 +20,7 @@
-include("types.hrl"). -include("types.hrl").
-include("emqx_mqtt.hrl"). -include("emqx_mqtt.hrl").
-export([ init/1 -export([ init/2
, info/1 , info/1
, info/2 , info/2
, attrs/1 , attrs/1
@ -48,10 +48,10 @@
username :: emqx_types:username(), username :: emqx_types:username(),
%% MQTT Will Msg %% MQTT Will Msg
will_msg :: emqx_types:message(), will_msg :: emqx_types:message(),
%% MQTT Conn Properties
conn_props :: maybe(emqx_types:properties()),
%% MQTT Topic Aliases %% MQTT Topic Aliases
topic_aliases :: maybe(map()) topic_aliases :: maybe(map()),
%% MQTT Topic Alias Maximum
alias_maximum :: maybe(map())
}). }).
-opaque(protocol() :: #protocol{}). -opaque(protocol() :: #protocol{}).
@ -60,14 +60,14 @@
-define(ATTR_KEYS, [proto_name, proto_ver, clean_start, keepalive]). -define(ATTR_KEYS, [proto_name, proto_ver, clean_start, keepalive]).
-spec(init(#mqtt_packet_connect{}) -> protocol()). -spec(init(#mqtt_packet_connect{}, atom()) -> protocol()).
init(#mqtt_packet_connect{proto_name = ProtoName, init(#mqtt_packet_connect{proto_name = ProtoName,
proto_ver = ProtoVer, proto_ver = ProtoVer,
clean_start = CleanStart, clean_start = CleanStart,
keepalive = Keepalive, keepalive = Keepalive,
properties = Properties, properties = Properties,
client_id = ClientId, client_id = ClientId,
username = Username} = ConnPkt) -> username = Username} = ConnPkt, Zone) ->
WillMsg = emqx_packet:will_msg(ConnPkt), WillMsg = emqx_packet:will_msg(ConnPkt),
#protocol{proto_name = ProtoName, #protocol{proto_name = ProtoName,
proto_ver = ProtoVer, proto_ver = ProtoVer,
@ -76,7 +76,8 @@ init(#mqtt_packet_connect{proto_name = ProtoName,
client_id = ClientId, client_id = ClientId,
username = Username, username = Username,
will_msg = WillMsg, will_msg = WillMsg,
conn_props = Properties alias_maximum = #{outbound => emqx_mqtt_props:get_property('Topic-Alias-Maximum', Properties, 0),
inbound => maps:get(max_topic_alias, emqx_mqtt_caps:get_caps(Zone), 0)}
}. }.
-spec(info(protocol()) -> emqx_types:infos()). -spec(info(protocol()) -> emqx_types:infos()).
@ -104,10 +105,10 @@ info(will_delay_interval, #protocol{will_msg = undefined}) ->
0; 0;
info(will_delay_interval, #protocol{will_msg = WillMsg}) -> info(will_delay_interval, #protocol{will_msg = WillMsg}) ->
emqx_message:get_header('Will-Delay-Interval', WillMsg, 0); emqx_message:get_header('Will-Delay-Interval', WillMsg, 0);
info(conn_props, #protocol{conn_props = ConnProps}) ->
ConnProps;
info(topic_aliases, #protocol{topic_aliases = Aliases}) -> info(topic_aliases, #protocol{topic_aliases = Aliases}) ->
Aliases. Aliases;
info(alias_maximum, #protocol{alias_maximum = AliasMaximum}) ->
AliasMaximum.
-spec(attrs(protocol()) -> emqx_types:attrs()). -spec(attrs(protocol()) -> emqx_types:attrs()).
attrs(Proto) -> attrs(Proto) ->
@ -129,3 +130,4 @@ save_alias(AliasId, Topic, Proto = #protocol{topic_aliases = Aliases}) ->
clear_will_msg(Protocol) -> clear_will_msg(Protocol) ->
Protocol#protocol{will_msg = undefined}. Protocol#protocol{will_msg = undefined}.

View File

@ -192,7 +192,8 @@ websocket_init([Req, Opts]) ->
fsm_state = idle, fsm_state = idle,
parse_state = ParseState, parse_state = ParseState,
chan_state = ChanState, chan_state = ChanState,
pendings = []}}. pendings = [],
serialize = serialize_fun(?MQTT_PROTO_V5, undefined)}}.
websocket_handle({binary, Data}, State) when is_list(Data) -> websocket_handle({binary, Data}, State) when is_list(Data) ->
websocket_handle({binary, iolist_to_binary(Data)}, State); websocket_handle({binary, iolist_to_binary(Data)}, State);
@ -255,8 +256,9 @@ websocket_info({cast, Msg}, State = #ws_connection{chan_state = ChanState}) ->
websocket_info({incoming, Packet = ?CONNECT_PACKET(ConnPkt)}, websocket_info({incoming, Packet = ?CONNECT_PACKET(ConnPkt)},
State = #ws_connection{fsm_state = idle}) -> State = #ws_connection{fsm_state = idle}) ->
#mqtt_packet_connect{proto_ver = ProtoVer} = ConnPkt, #mqtt_packet_connect{proto_ver = ProtoVer, properties = Properties} = ConnPkt,
NState = State#ws_connection{serialize = serialize_fun(ProtoVer)}, MaxPacketSize = emqx_mqtt_props:get_property('Maximum-Packet-Size', Properties, undefined),
NState = State#ws_connection{serialize = serialize_fun(ProtoVer, MaxPacketSize)},
handle_incoming(Packet, fun connected/1, NState); handle_incoming(Packet, fun connected/1, NState);
websocket_info({incoming, Packet}, State = #ws_connection{fsm_state = idle}) -> websocket_info({incoming, Packet}, State = #ws_connection{fsm_state = idle}) ->
@ -348,14 +350,24 @@ process_incoming(Data, State = #ws_connection{parse_state = ParseState, chan_sta
stop(Reason, State) stop(Reason, State)
catch catch
error:Reason:Stk -> error:Reason:Stk ->
?LOG(error, "Parse failed for ~p~n\ ?LOG(error, "Parse failed for ~p~nStacktrace:~p~nFrame data: ~p", [Reason, Stk, Data]),
Stacktrace:~p~nFrame data: ~p", [Reason, Stk, Data]), Result =
case emqx_channel:handle_out({disconnect, emqx_reason_codes:mqtt_frame_error(Reason)}, ChanState) of case emqx_channel:info(connected, ChanState) of
undefined ->
emqx_channel:handle_out({connack, emqx_reason_codes:mqtt_frame_error(Reason)}, ChanState);
true ->
emqx_channel:handle_out({disconnect, emqx_reason_codes:mqtt_frame_error(Reason)}, ChanState);
_ ->
ignore
end,
case Result of
{stop, Reason0, OutPackets, NChanState} -> {stop, Reason0, OutPackets, NChanState} ->
NState = State#ws_connection{chan_state = NChanState}, NState = State#ws_connection{chan_state = NChanState},
stop(Reason0, enqueue(OutPackets, NState)); stop(Reason0, enqueue(OutPackets, NState));
{stop, Reason0, NChanState} -> {stop, Reason0, NChanState} ->
stop(Reason0, State#ws_connection{chan_state = NChanState}) stop(Reason0, State#ws_connection{chan_state = NChanState});
ignore ->
{ok, State}
end end
end. end.
@ -394,12 +406,19 @@ handle_outgoing(Packets, State = #ws_connection{serialize = Serialize,
%%-------------------------------------------------------------------- %%--------------------------------------------------------------------
%% Serialize fun %% Serialize fun
serialize_fun(ProtoVer) -> serialize_fun(ProtoVer, MaxPacketSize) ->
fun(Packet = ?PACKET(Type)) -> fun(Packet = ?PACKET(Type)) ->
IoData = emqx_frame:serialize(Packet, ProtoVer),
case Type =/= ?PUBLISH orelse MaxPacketSize =:= undefined orelse iolist_size(IoData) =< MaxPacketSize of
true ->
?LOG(debug, "SEND ~s", [emqx_packet:format(Packet)]), ?LOG(debug, "SEND ~s", [emqx_packet:format(Packet)]),
_ = inc_outgoing_stats(Type), _ = inc_outgoing_stats(Type),
_ = emqx_metrics:inc_sent(Packet), _ = emqx_metrics:inc_sent(Packet),
emqx_frame:serialize(Packet, ProtoVer) IoData;
false ->
?LOG(warning, "DROP ~s due to oversize packet size", [emqx_packet:format(Packet)]),
<<"">>
end
end. end.
%%-------------------------------------------------------------------- %%--------------------------------------------------------------------

View File

@ -289,7 +289,7 @@ with_channel(Fun) ->
username = <<"username">>, username = <<"username">>,
password = <<"passwd">> password = <<"passwd">>
}, },
Protocol = emqx_protocol:init(ConnPkt), Protocol = emqx_protocol:init(ConnPkt, testing),
Session = emqx_session:init(#{zone => testing}, Session = emqx_session:init(#{zone => testing},
#{max_inflight => 100, #{max_inflight => 100,
expiry_interval => 0 expiry_interval => 0

View File

@ -49,7 +49,7 @@ t_validate(_) ->
[{<<"topic">>, #{qos => ?QOS_0}}]))), [{<<"topic">>, #{qos => ?QOS_0}}]))),
?assertError(topic_filters_invalid, ?assertError(topic_filters_invalid,
emqx_packet:validate(?UNSUBSCRIBE_PACKET(1,[]))), emqx_packet:validate(?UNSUBSCRIBE_PACKET(1,[]))),
?assertError(topic_name_invalid, ?assertError(protocol_error,
emqx_packet:validate(?PUBLISH_PACKET(1,<<>>,1,#{},<<"payload">>))), emqx_packet:validate(?PUBLISH_PACKET(1,<<>>,1,#{},<<"payload">>))),
?assertError(topic_name_invalid, ?assertError(topic_name_invalid,
emqx_packet:validate(?PUBLISH_PACKET emqx_packet:validate(?PUBLISH_PACKET

View File

@ -38,7 +38,7 @@ init_protocol() ->
client_id = <<"clientid">>, client_id = <<"clientid">>,
username = <<"username">>, username = <<"username">>,
password = <<"passwd">> password = <<"passwd">>
}). }, testing).
end_per_suite(_Config) -> ok. end_per_suite(_Config) -> ok.
@ -48,11 +48,11 @@ t_init_info_1(Config) ->
proto_ver => ?MQTT_PROTO_V5, proto_ver => ?MQTT_PROTO_V5,
clean_start => true, clean_start => true,
keepalive => 30, keepalive => 30,
conn_props => #{},
will_msg => undefined, will_msg => undefined,
client_id => <<"clientid">>, client_id => <<"clientid">>,
username => <<"username">>, username => <<"username">>,
topic_aliases => undefined topic_aliases => undefined,
alias_maximum => #{outbound => 0, inbound => 0}
}, emqx_protocol:info(Proto)). }, emqx_protocol:info(Proto)).
t_init_info_2(Config) -> t_init_info_2(Config) ->
@ -65,8 +65,8 @@ t_init_info_2(Config) ->
?assertEqual(<<"username">>, emqx_protocol:info(username, Proto)), ?assertEqual(<<"username">>, emqx_protocol:info(username, Proto)),
?assertEqual(undefined, emqx_protocol:info(will_msg, Proto)), ?assertEqual(undefined, emqx_protocol:info(will_msg, Proto)),
?assertEqual(0, emqx_protocol:info(will_delay_interval, Proto)), ?assertEqual(0, emqx_protocol:info(will_delay_interval, Proto)),
?assertEqual(#{}, emqx_protocol:info(conn_props, Proto)), ?assertEqual(undefined, emqx_protocol:info(topic_aliases, Proto)),
?assertEqual(undefined, emqx_protocol:info(topic_aliases, Proto)). ?assertEqual(#{outbound => 0, inbound => 0}, emqx_protocol:info(alias_maximum, Proto)).
t_find_save_alias(Config) -> t_find_save_alias(Config) ->
Proto = proplists:get_value(proto, Config), Proto = proplists:get_value(proto, Config),