From 8ab682151d675d4b68b3a2dd3b394bc985d45fcd Mon Sep 17 00:00:00 2001 From: Feng Lee Date: Sun, 29 Sep 2019 10:22:02 +0800 Subject: [PATCH] Improve the connection and channel modules - Rename the 'client_id' field to 'clientid' - Support publish stats in channel module - Update test cases for frame and channel modules --- src/emqx_channel.erl | 396 ++++++++++++++++++--------------- src/emqx_connection.erl | 367 +++++++++++++++--------------- src/emqx_frame.erl | 114 +++++++--- src/emqx_session.erl | 69 ++++-- src/emqx_stats.erl | 4 +- src/emqx_ws_connection.erl | 301 +++++++++++-------------- src/emqx_zone.erl | 35 ++- test/emqx_channel_SUITE.erl | 114 ++++++---- test/emqx_client_SUITE.erl | 17 +- test/emqx_connection_SUITE.erl | 2 +- test/emqx_frame_SUITE.erl | 10 +- test/emqx_misc_SUITE.erl | 2 +- test/emqx_session_SUITE.erl | 4 +- test/emqx_shared_sub_SUITE.erl | 12 +- 14 files changed, 781 insertions(+), 666 deletions(-) diff --git a/src/emqx_channel.erl b/src/emqx_channel.erl index 5d76f8414..cad4111d4 100644 --- a/src/emqx_channel.erl +++ b/src/emqx_channel.erl @@ -38,9 +38,9 @@ , handle_in/2 , handle_out/2 , handle_call/2 - , handle_cast/2 , handle_info/2 , handle_timeout/3 + , disconnect/2 , terminate/2 ]). @@ -60,7 +60,7 @@ %% MQTT ConnInfo conninfo :: emqx_types:conninfo(), %% MQTT ClientInfo - client_info :: emqx_types:client_info(), + clientinfo :: emqx_types:clientinfo(), %% MQTT Session session :: emqx_session:session(), %% Keepalive @@ -71,18 +71,14 @@ topic_aliases :: maybe(map()), %% MQTT Topic Alias Maximum alias_maximum :: maybe(map()), + %% Publish Stats + pub_stats :: emqx_types:stats(), %% Timers timers :: #{atom() => disabled | maybe(reference())}, + %% Fsm State + fsm_state :: fsm_state(), %% GC State gc_state :: maybe(emqx_gc:gc_state()), - %% OOM Policy TODO: should be removed from channel. - oom_policy :: maybe(emqx_oom:oom_policy()), - %% Connected - connected :: undefined | boolean(), - %% Connected at - connected_at :: erlang:timestamp(), - %% Disconnected at - disconnected_at :: erlang:timestamp(), %% Takeover takeover :: boolean(), %% Resume @@ -93,6 +89,14 @@ -opaque(channel() :: #channel{}). +-type(fsm_state() :: #{state_name := initialized + | connecting + | connected + | disconnected, + connected_at := pos_integer(), + disconnected := pos_integer() + }). + -define(TIMER_TABLE, #{ stats_timer => emit_stats, alive_timer => keepalive, @@ -102,8 +106,10 @@ will_timer => will_message }). --define(ATTR_KEYS, [conninfo, client_info, session, connected, connected_at, disconnected_at]). --define(INFO_KEYS, ?ATTR_KEYS ++ [keepalive, topic_aliases, alias_maximum, gc_state, disconnected_at]). +-define(ATTR_KEYS, [conninfo, client, session]). + +-define(INFO_KEYS, ?ATTR_KEYS ++ [conninfo, client, session, keepalive, + will_msg, topic_aliases, alias_maximum, gc_state]). %%-------------------------------------------------------------------- %% Info, Attrs and Caps @@ -119,7 +125,7 @@ info(Keys, Channel) when is_list(Keys) -> [{Key, info(Key, Channel)} || Key <- Keys]; info(conninfo, #channel{conninfo = ConnInfo}) -> ConnInfo; -info(client_info, #channel{client_info = ClientInfo}) -> +info(client, #channel{clientinfo = ClientInfo}) -> ClientInfo; info(session, #channel{session = Session}) -> maybe_apply(fun emqx_session:info/1, Session); @@ -129,36 +135,31 @@ info(topic_aliases, #channel{topic_aliases = Aliases}) -> Aliases; info(alias_maximum, #channel{alias_maximum = Limits}) -> Limits; +info(will_msg, #channel{will_msg = undefined}) -> + undefined; info(will_msg, #channel{will_msg = WillMsg}) -> - WillMsg; + emqx_message:to_map(WillMsg); +info(pub_stats, #channel{pub_stats = PubStats}) -> + PubStats; info(gc_state, #channel{gc_state = GcState}) -> - maybe_apply(fun emqx_gc:info/1, GcState); -info(oom_policy, #channel{oom_policy = OomPolicy}) -> - maybe_apply(fun emqx_oom:info/1, OomPolicy); -info(connected, #channel{connected = Connected}) -> - Connected; -info(connected_at, #channel{connected_at = ConnectedAt}) -> - ConnectedAt; -info(disconnected_at, #channel{disconnected_at = DisconnectedAt}) -> - DisconnectedAt. + maybe_apply(fun emqx_gc:info/1, GcState). %% @doc Get attrs of the channel. -spec(attrs(channel()) -> emqx_types:attrs()). attrs(Channel) -> - maps:from_list([{Key, attr(Key, Channel)} || Key <- ?ATTR_KEYS]). + Attrs = [{Key, attrs(Key, Channel)} || Key <- ?ATTR_KEYS], + maps:from_list(Attrs). -attr(conninfo, #channel{conninfo = ConnInfo}) -> - ConnInfo; -attr(session, #channel{session = Session}) -> +attrs(session, #channel{session = Session}) -> maybe_apply(fun emqx_session:attrs/1, Session); -attr(Key, Channel) -> info(Key, Channel). +attrs(Key, Channel) -> info(Key, Channel). -spec(stats(channel()) -> emqx_types:stats()). -stats(#channel{session = Session}) -> - emqx_session:stats(Session). +stats(#channel{pub_stats = PubStats, session = Session}) -> + maps:to_list(PubStats) ++ emqx_session:stats(Session). -spec(caps(channel()) -> emqx_types:caps()). -caps(#channel{client_info = #{zone := Zone}}) -> +caps(#channel{clientinfo = #{zone := Zone}}) -> emqx_mqtt_caps:get_caps(Zone). %% For tests @@ -172,7 +173,7 @@ set_field(Name, Val, Channel) -> %%-------------------------------------------------------------------- -spec(init(emqx_types:conninfo(), proplists:proplist()) -> channel()). -init(ConnInfo = #{peername := {PeerHost, _Port}, protocol := Protocol}, Options) -> +init(ConnInfo = #{peername := {PeerHost, _Port}}, Options) -> Zone = proplists:get_value(zone, Options), Peercert = maps:get(peercert, ConnInfo, undefined), Username = case peer_cert_as_username(Options) of @@ -181,12 +182,13 @@ init(ConnInfo = #{peername := {PeerHost, _Port}, protocol := Protocol}, Options) crt -> Peercert; _ -> undefined end, + Protocol = maps:get(protocol, ConnInfo, mqtt), MountPoint = emqx_zone:get_env(Zone, mountpoint), ClientInfo = #{zone => Zone, protocol => Protocol, peerhost => PeerHost, peercert => Peercert, - client_id => undefined, + clientid => undefined, username => Username, mountpoint => MountPoint, is_bridge => false, @@ -196,15 +198,15 @@ init(ConnInfo = #{peername := {PeerHost, _Port}, protocol := Protocol}, Options) true -> undefined; false -> disabled end, - #channel{conninfo = ConnInfo, - client_info = ClientInfo, - gc_state = init_gc_state(Zone), - oom_policy = init_oom_policy(Zone), - timers = #{stats_timer => StatsTimer}, - connected = undefined, - takeover = false, - resuming = false, - pendings = [] + #channel{conninfo = ConnInfo, + clientinfo = ClientInfo, + pub_stats = #{}, + timers = #{stats_timer => StatsTimer}, + fsm_state = #{state_name => initialized}, + gc_state = init_gc_state(Zone), + takeover = false, + resuming = false, + pendings = [] }. peer_cert_as_username(Options) -> @@ -213,9 +215,6 @@ peer_cert_as_username(Options) -> init_gc_state(Zone) -> maybe_apply(fun emqx_gc:init/1, emqx_zone:force_gc_policy(Zone)). -init_oom_policy(Zone) -> - maybe_apply(fun emqx_oom:init/1, emqx_zone:force_shutdown_policy(Zone)). - %%-------------------------------------------------------------------- %% Handle incoming packet %%-------------------------------------------------------------------- @@ -224,9 +223,11 @@ init_oom_policy(Zone) -> -> {ok, channel()} | {ok, emqx_types:packet(), channel()} | {ok, list(emqx_types:packet()), channel()} + | {close, channel()} + | {close, emqx_types:packet(), channel()} | {stop, Error :: term(), channel()} | {stop, Error :: term(), emqx_types:packet(), channel()}). -handle_in(?CONNECT_PACKET(_), Channel = #channel{connected = true}) -> +handle_in(?CONNECT_PACKET(_), Channel = #channel{fsm_state = #{state_name := connected}}) -> handle_out({disconnect, ?RC_PROTOCOL_ERROR}, Channel); handle_in(?CONNECT_PACKET(ConnPkt), Channel) -> @@ -244,73 +245,77 @@ handle_in(?CONNECT_PACKET(ConnPkt), Channel) -> end; handle_in(Packet = ?PUBLISH_PACKET(_QoS), Channel) -> + Channel1 = inc_pub_stats(publish_in, Channel), case emqx_packet:check(Packet) of - ok -> - handle_publish(Packet, Channel); + ok -> handle_publish(Packet, Channel1); {error, ReasonCode} -> - handle_out({disconnect, ReasonCode}, Channel) + handle_out({disconnect, ReasonCode}, Channel1) end; handle_in(?PUBACK_PACKET(PacketId, _ReasonCode), - Channel = #channel{client_info = ClientInfo, session = Session}) -> + Channel = #channel{clientinfo = ClientInfo, session = Session}) -> + Channel1 = inc_pub_stats(puback_in, Channel), case emqx_session:puback(PacketId, Session) of {ok, Msg, Publishes, NSession} -> ok = emqx_hooks:run('message.acked', [ClientInfo, Msg]), - handle_out({publish, Publishes}, Channel#channel{session = NSession}); + handle_out({publish, Publishes}, Channel1#channel{session = NSession}); {ok, Msg, NSession} -> ok = emqx_hooks:run('message.acked', [ClientInfo, Msg]), - {ok, Channel#channel{session = NSession}}; + {ok, Channel1#channel{session = NSession}}; {error, ?RC_PACKET_IDENTIFIER_IN_USE} -> ?LOG(warning, "The PUBACK PacketId ~w is inuse.", [PacketId]), ok = emqx_metrics:inc('packets.puback.inuse'), - {ok, Channel}; + {ok, Channel1}; {error, ?RC_PACKET_IDENTIFIER_NOT_FOUND} -> ?LOG(warning, "The PUBACK PacketId ~w is not found", [PacketId]), ok = emqx_metrics:inc('packets.puback.missed'), - {ok, Channel} + {ok, Channel1} end; handle_in(?PUBREC_PACKET(PacketId, _ReasonCode), - Channel = #channel{client_info = ClientInfo, session = Session}) -> + Channel = #channel{clientinfo = ClientInfo, session = Session}) -> + Channel1 = inc_pub_stats(pubrec_in, Channel), case emqx_session:pubrec(PacketId, Session) of {ok, Msg, NSession} -> ok = emqx_hooks:run('message.acked', [ClientInfo, Msg]), - NChannel = Channel#channel{session = NSession}, + NChannel = Channel1#channel{session = NSession}, handle_out({pubrel, PacketId, ?RC_SUCCESS}, NChannel); {error, RC = ?RC_PACKET_IDENTIFIER_IN_USE} -> ?LOG(warning, "The PUBREC PacketId ~w is inuse.", [PacketId]), ok = emqx_metrics:inc('packets.pubrec.inuse'), - handle_out({pubrel, PacketId, RC}, Channel); + handle_out({pubrel, PacketId, RC}, Channel1); {error, RC = ?RC_PACKET_IDENTIFIER_NOT_FOUND} -> ?LOG(warning, "The PUBREC ~w is not found.", [PacketId]), ok = emqx_metrics:inc('packets.pubrec.missed'), - handle_out({pubrel, PacketId, RC}, Channel) + handle_out({pubrel, PacketId, RC}, Channel1) end; handle_in(?PUBREL_PACKET(PacketId, _ReasonCode), Channel = #channel{session = Session}) -> + Channel1 = inc_pub_stats(pubrel_in, Channel), case emqx_session:pubrel(PacketId, Session) of {ok, NSession} -> - handle_out({pubcomp, PacketId, ?RC_SUCCESS}, Channel#channel{session = NSession}); + handle_out({pubcomp, PacketId, ?RC_SUCCESS}, Channel1#channel{session = NSession}); {error, NotFound} -> - ?LOG(warning, "The PUBREL PacketId ~w is not found", [PacketId]), ok = emqx_metrics:inc('packets.pubrel.missed'), - handle_out({pubcomp, PacketId, NotFound}, Channel) + ?LOG(warning, "The PUBREL PacketId ~w is not found", [PacketId]), + handle_out({pubcomp, PacketId, NotFound}, Channel1) end; handle_in(?PUBCOMP_PACKET(PacketId, _ReasonCode), Channel = #channel{session = Session}) -> + Channel1 = inc_pub_stats(pubcomp_in, Channel), case emqx_session:pubcomp(PacketId, Session) of {ok, Publishes, NSession} -> - handle_out({publish, Publishes}, Channel#channel{session = NSession}); + handle_out({publish, Publishes}, Channel1#channel{session = NSession}); {ok, NSession} -> - {ok, Channel#channel{session = NSession}}; + {ok, Channel1#channel{session = NSession}}; {error, ?RC_PACKET_IDENTIFIER_NOT_FOUND} -> ?LOG(warning, "The PUBCOMP PacketId ~w is not found", [PacketId]), ok = emqx_metrics:inc('packets.pubcomp.missed'), - {ok, Channel} + {ok, Channel1} end; handle_in(Packet = ?SUBSCRIBE_PACKET(PacketId, Properties, TopicFilters), - Channel = #channel{client_info = ClientInfo}) -> + Channel = #channel{clientinfo = ClientInfo}) -> case emqx_packet:check(Packet) of ok -> TopicFilters1 = emqx_hooks:run_fold('client.subscribe', [ClientInfo, Properties], @@ -323,7 +328,7 @@ handle_in(Packet = ?SUBSCRIBE_PACKET(PacketId, Properties, TopicFilters), end; handle_in(Packet = ?UNSUBSCRIBE_PACKET(PacketId, Properties, TopicFilters), - Channel = #channel{client_info = ClientInfo}) -> + Channel = #channel{clientinfo = ClientInfo}) -> case emqx_packet:check(Packet) of ok -> TopicFilters1 = emqx_hooks:run_fold('client.unsubscribe', [ClientInfo, Properties], @@ -339,37 +344,49 @@ handle_in(?PACKET(?PINGREQ), Channel) -> handle_in(?DISCONNECT_PACKET(ReasonCode, Properties), Channel = #channel{conninfo = ConnInfo}) -> #{proto_ver := ProtoVer, expiry_interval := OldInterval} = ConnInfo, + {ReasonName, Channel1} = case ReasonCode of + ?RC_SUCCESS -> + {normal, Channel#channel{will_msg = undefined}}; + _Other -> + {emqx_reason_codes:name(ReasonCode, ProtoVer), Channel} + end, Interval = emqx_mqtt_props:get('Session-Expiry-Interval', Properties, OldInterval), - case OldInterval =:= 0 andalso Interval =/= OldInterval of + if + OldInterval == 0 andalso Interval > OldInterval -> + handle_out({disconnect, ?RC_PROTOCOL_ERROR}, Channel1); + Interval == 0 -> + {stop, ReasonName, Channel1}; true -> - handle_out({disconnect, ?RC_PROTOCOL_ERROR}, Channel); - false -> - Reason = case ReasonCode of - ?RC_SUCCESS -> normal; - _ -> emqx_reason_codes:name(ReasonCode, ProtoVer) - end, - Channel1 = Channel#channel{conninfo = ConnInfo#{expiry_interval := Interval}}, - Channel2 = case ReasonCode of - ?RC_SUCCESS -> Channel1#channel{will_msg = undefined}; - _ -> Channel1 - end, - {wait_session_expire, {shutdown, Reason}, Channel2} + Channel2 = Channel1#channel{conninfo = ConnInfo#{expiry_interval => Interval}}, + {close, ReasonName, Channel2} end; handle_in(?AUTH_PACKET(), Channel) -> - %%TODO: implement later. - {ok, Channel}; + handle_out({disconnect, ?RC_IMPLEMENTATION_SPECIFIC_ERROR}, Channel); + +handle_in({frame_error, Reason}, Channel = #channel{fsm_state = FsmState}) -> + case FsmState of + #{state_name := initialized} -> + {stop, {shutdown, Reason}, Channel}; + #{state_name := connecting} -> + {stop, {shutdown, Reason}, ?CONNACK_PACKET(?RC_MALFORMED_PACKET), Channel}; + #{state_name := connected} -> + handle_out({disconnect, ?RC_MALFORMED_PACKET}, Channel); + #{state_name := disconnected} -> + ?LOG(error, "Unexpected frame error: ~p", [Reason]), + {ok, Channel} + end; handle_in(Packet, Channel) -> ?LOG(error, "Unexpected incoming: ~p", [Packet]), - handle_out({disconnect, ?RC_MALFORMED_PACKET}, Channel). + handle_out({disconnect, ?RC_PROTOCOL_ERROR}, Channel). %%-------------------------------------------------------------------- %% Process Connect %%-------------------------------------------------------------------- process_connect(ConnPkt = #mqtt_packet_connect{clean_start = CleanStart}, - Channel = #channel{conninfo = ConnInfo, client_info = ClientInfo}) -> + Channel = #channel{conninfo = ConnInfo, clientinfo = ClientInfo}) -> case emqx_cm:open_session(CleanStart, ClientInfo, ConnInfo) of {ok, #{session := Session, present := false}} -> NChannel = Channel#channel{session = Session}, @@ -391,6 +408,11 @@ process_connect(ConnPkt = #mqtt_packet_connect{clean_start = CleanStart}, %% Process Publish %%-------------------------------------------------------------------- +inc_pub_stats(Key, Channel) -> inc_pub_stats(Key, 1, Channel). +inc_pub_stats(Key, I, Channel = #channel{pub_stats = PubStats}) -> + NPubStats = maps:update_with(Key, fun(V) -> V+I end, I, PubStats), + Channel#channel{pub_stats = NPubStats}. + handle_publish(Packet = ?PUBLISH_PACKET(_QoS, Topic, _PacketId), Channel = #channel{conninfo = #{proto_ver := ProtoVer}}) -> case pipeline([fun process_alias/2, @@ -440,7 +462,7 @@ process_publish(PacketId, Msg = #message{qos = ?QOS_2}, end. publish_to_msg(Packet, #channel{conninfo = #{proto_ver := ProtoVer}, - client_info = ClientInfo = #{mountpoint := MountPoint}}) -> + clientinfo = ClientInfo = #{mountpoint := MountPoint}}) -> Msg = emqx_packet:to_message(ClientInfo, Packet), Msg1 = emqx_message:set_flag(dup, false, Msg), Msg2 = emqx_message:set_header(proto_ver, ProtoVer, Msg1), @@ -461,7 +483,7 @@ process_subscribe([{TopicFilter, SubOpts}|More], Acc, Channel) -> process_subscribe(More, [RC|Acc], NChannel). do_subscribe(TopicFilter, SubOpts = #{qos := QoS}, Channel = - #channel{client_info = ClientInfo = #{mountpoint := MountPoint}, + #channel{clientinfo = ClientInfo = #{mountpoint := MountPoint}, session = Session}) -> case check_subscribe(TopicFilter, SubOpts, Channel) of ok -> @@ -491,7 +513,7 @@ process_unsubscribe([{TopicFilter, SubOpts}|More], Acc, Channel) -> process_unsubscribe(More, [RC|Acc], NChannel). do_unsubscribe(TopicFilter, _SubOpts, Channel = - #channel{client_info = ClientInfo = #{mountpoint := MountPoint}, + #channel{clientinfo = ClientInfo = #{mountpoint := MountPoint}, session = Session}) -> TopicFilter1 = emqx_mountpoint:mount(MountPoint, TopicFilter), case emqx_session:unsubscribe(ClientInfo, TopicFilter1, Session) of @@ -506,22 +528,24 @@ do_unsubscribe(TopicFilter, _SubOpts, Channel = %%TODO: RunFold or Pipeline handle_out({connack, ?RC_SUCCESS, SP, ConnPkt}, - Channel = #channel{conninfo = ConnInfo, client_info = ClientInfo}) -> + Channel = #channel{conninfo = ConnInfo, + clientinfo = ClientInfo, + fsm_state = FsmState}) -> AckProps = run_fold([fun enrich_caps/2, fun enrich_server_keepalive/2, - fun enrich_assigned_clientid/2 - ], #{}, Channel), - Channel1 = Channel#channel{will_msg = emqx_packet:will_msg(ConnPkt), - alias_maximum = init_alias_maximum(ConnPkt, ClientInfo), - connected = true, - connected_at = os:timestamp() + fun enrich_assigned_clientid/2], #{}, Channel), + FsmState1 = FsmState#{state_name => connected, + connected_at => erlang:system_time(second) + }, + Channel1 = Channel#channel{fsm_state = FsmState1, + will_msg = emqx_packet:will_msg(ConnPkt), + alias_maximum = init_alias_maximum(ConnPkt, ClientInfo) }, Channel2 = ensure_keepalive(AckProps, Channel1), ok = emqx_hooks:run('client.connected', [ClientInfo, ?RC_SUCCESS, ConnInfo]), AckPacket = ?CONNACK_PACKET(?RC_SUCCESS, SP, AckProps), case maybe_resume_session(Channel2) of - ignore -> - {ok, AckPacket, Channel2}; + ignore -> {ok, AckPacket, Channel2}; {ok, Publishes, NSession} -> Channel3 = Channel2#channel{session = NSession, resuming = false, @@ -531,7 +555,7 @@ handle_out({connack, ?RC_SUCCESS, SP, ConnPkt}, end; handle_out({connack, ReasonCode, _ConnPkt}, Channel = #channel{conninfo = ConnInfo, - client_info = ClientInfo}) -> + clientinfo = ClientInfo}) -> ok = emqx_hooks:run('client.connected', [ClientInfo, ReasonCode, ConnInfo]), ReasonCode1 = case ProtoVer = maps:get(proto_ver, ConnInfo) of ?MQTT_PROTO_V5 -> ReasonCode; @@ -540,8 +564,8 @@ handle_out({connack, ReasonCode, _ConnPkt}, Channel = #channel{conninfo = ConnIn Reason = emqx_reason_codes:name(ReasonCode1, ProtoVer), {stop, {shutdown, Reason}, ?CONNACK_PACKET(ReasonCode1), Channel}; -handle_out({deliver, Delivers}, Channel = #channel{session = Session, - connected = false}) -> +handle_out({deliver, Delivers}, Channel = #channel{fsm_state = #{state_name := disconnected}, + session = Session}) -> NSession = emqx_session:enqueue(Delivers, Session), {ok, Channel#channel{session = NSession}}; @@ -567,32 +591,33 @@ handle_out({publish, Publishes}, Channel) when is_list(Publishes) -> {ok, _Ch} -> Acc end end, [], Publishes), - {ok, lists:reverse(Packets), Channel}; + NChannel = inc_pub_stats(publish_out, length(Packets), Channel), + {ok, lists:reverse(Packets), NChannel}; %% Ignore loop deliver handle_out({publish, _PacketId, #message{from = ClientId, flags = #{nl := true}}}, - Channel = #channel{client_info = #{client_id := ClientId}}) -> + Channel = #channel{clientinfo = #{clientid := ClientId}}) -> {ok, Channel}; handle_out({publish, PacketId, Msg}, Channel = - #channel{client_info = ClientInfo = #{mountpoint := MountPoint}}) -> + #channel{clientinfo = ClientInfo = #{mountpoint := MountPoint}}) -> Msg1 = emqx_message:update_expiry(Msg), Msg2 = emqx_hooks:run_fold('message.delivered', [ClientInfo], Msg1), Msg3 = emqx_mountpoint:unmount(MountPoint, Msg2), {ok, emqx_message:to_packet(PacketId, Msg3), Channel}; handle_out({puback, PacketId, ReasonCode}, Channel) -> - {ok, ?PUBACK_PACKET(PacketId, ReasonCode), Channel}; + {ok, ?PUBACK_PACKET(PacketId, ReasonCode), inc_pub_stats(puback_out, Channel)}; handle_out({pubrel, PacketId, ReasonCode}, Channel) -> - {ok, ?PUBREL_PACKET(PacketId, ReasonCode), Channel}; + {ok, ?PUBREL_PACKET(PacketId, ReasonCode), inc_pub_stats(pubrel_out, Channel)}; handle_out({pubrec, PacketId, ReasonCode}, Channel) -> - {ok, ?PUBREC_PACKET(PacketId, ReasonCode), Channel}; + {ok, ?PUBREC_PACKET(PacketId, ReasonCode), inc_pub_stats(pubrec_out, Channel)}; handle_out({pubcomp, PacketId, ReasonCode}, Channel) -> - {ok, ?PUBCOMP_PACKET(PacketId, ReasonCode), Channel}; + {ok, ?PUBCOMP_PACKET(PacketId, ReasonCode), inc_pub_stats(pubcomp_out, Channel)}; handle_out({suback, PacketId, ReasonCodes}, Channel = #channel{conninfo = #{proto_ver := ?MQTT_PROTO_V5}}) -> @@ -609,15 +634,28 @@ handle_out({unsuback, PacketId, ReasonCodes}, handle_out({unsuback, PacketId, _ReasonCodes}, Channel) -> {ok, ?UNSUBACK_PACKET(PacketId), Channel}; -handle_out({disconnect, ReasonCode}, Channel = #channel{conninfo = ConnInfo}) -> - case maps:get(proto_ver, ConnInfo) of - ?MQTT_PROTO_V5 -> - Reason = emqx_reason_codes:name(ReasonCode), - Packet = ?DISCONNECT_PACKET(ReasonCode), - {wait_session_expire, {shutdown, Reason}, Packet, Channel}; - ProtoVer -> - Reason = emqx_reason_codes:name(ReasonCode, ProtoVer), - {wait_session_expire, {shutdown, Reason}, Channel} +handle_out({disconnect, ReasonCode}, Channel = #channel{conninfo = #{proto_ver := ProtoVer}}) -> + ReasonName = emqx_reason_codes:name(ReasonCode, ProtoVer), + handle_out({disconnect, ReasonCode, ReasonName}, Channel); + +%%TODO: Improve later... +handle_out({disconnect, ReasonCode, ReasonName}, + Channel = #channel{conninfo = #{proto_ver := ProtoVer, + expiry_interval := ExpiryInterval}}) -> + case {ExpiryInterval, ProtoVer} of + {0, ?MQTT_PROTO_V5} -> + {stop, ReasonName, ?DISCONNECT_PACKET(ReasonCode), Channel}; + {0, _Ver} -> + {stop, ReasonName, Channel}; + {?UINT_MAX, ?MQTT_PROTO_V5} -> + {close, ReasonName, ?DISCONNECT_PACKET(ReasonCode), Channel}; + {?UINT_MAX, _Ver} -> + {close, ReasonName, Channel}; + {Interval, ?MQTT_PROTO_V5} -> + NChannel = ensure_timer(expire_timer, Interval, Channel), + {close, ReasonName, ?DISCONNECT_PACKET(ReasonCode), NChannel}; + {Interval, _Ver} -> + {close, ReasonName, ensure_timer(expire_timer, Interval, Channel)} end; handle_out({Type, Data}, Channel) -> @@ -631,10 +669,10 @@ handle_out({Type, Data}, Channel) -> handle_call(kick, Channel) -> {stop, {shutdown, kicked}, ok, Channel}; -handle_call(discard, Channel = #channel{connected = true}) -> +handle_call(discard, Channel = #channel{fsm_state = #{state_name := connected}}) -> Packet = ?DISCONNECT_PACKET(?RC_SESSION_TAKEN_OVER), {stop, {shutdown, discarded}, Packet, ok, Channel}; -handle_call(discard, Channel = #channel{connected = false}) -> +handle_call(discard, Channel = #channel{fsm_state = #{state_name := disconnected}}) -> {stop, {shutdown, discarded}, ok, Channel}; %% Session Takeover @@ -651,49 +689,40 @@ handle_call(Req, Channel) -> ?LOG(error, "Unexpected call: ~p", [Req]), {ok, ignored, Channel}. -%%-------------------------------------------------------------------- -%% Handle cast -%%-------------------------------------------------------------------- - --spec(handle_cast(Msg :: term(), channel()) - -> ok | {ok, channel()} | {stop, Reason :: term(), channel()}). -handle_cast({register, Attrs, Stats}, #channel{client_info = #{client_id := ClientId}}) -> - ok = emqx_cm:register_channel(ClientId), - emqx_cm:set_chan_attrs(ClientId, Attrs), - emqx_cm:set_chan_stats(ClientId, Stats); - -handle_cast(Msg, Channel) -> - ?LOG(error, "Unexpected cast: ~p", [Msg]), - {ok, Channel}. - %%-------------------------------------------------------------------- %% Handle Info %%-------------------------------------------------------------------- -spec(handle_info(Info :: term(), channel()) - -> {ok, channel()} | {stop, Reason :: term(), channel()}). -handle_info({subscribe, TopicFilters}, Channel = #channel{client_info = ClientInfo}) -> + -> ok | {ok, channel()} | {stop, Reason :: term(), channel()}). +handle_info({subscribe, TopicFilters}, Channel = #channel{clientinfo = ClientInfo}) -> TopicFilters1 = emqx_hooks:run_fold('client.subscribe', [ClientInfo, #{'Internal' => true}], parse_topic_filters(TopicFilters)), {_ReasonCodes, NChannel} = process_subscribe(TopicFilters1, Channel), {ok, NChannel}; -handle_info({unsubscribe, TopicFilters}, Channel = #channel{client_info = ClientInfo}) -> +handle_info({unsubscribe, TopicFilters}, Channel = #channel{clientinfo = ClientInfo}) -> TopicFilters1 = emqx_hooks:run_fold('client.unsubscribe', [ClientInfo, #{'Internal' => true}], parse_topic_filters(TopicFilters)), {_ReasonCodes, NChannel} = process_unsubscribe(TopicFilters1, Channel), {ok, NChannel}; -handle_info(disconnected, Channel = #channel{connected = undefined}) -> - shutdown(closed, Channel); +handle_info({register, Attrs, Stats}, #channel{clientinfo = #{clientid := ClientId}}) -> + ok = emqx_cm:register_channel(ClientId), + emqx_cm:set_chan_attrs(ClientId, Attrs), + emqx_cm:set_chan_stats(ClientId, Stats); -handle_info(disconnected, Channel = #channel{connected = false}) -> +%%TODO: Fixme later +%%handle_info(disconnected, Channel = #channel{connected = undefined}) -> +%% shutdown(closed, Channel); + +handle_info(disconnected, Channel = #channel{fsm_state = #{state_name := disconnected}}) -> {ok, Channel}; handle_info(disconnected, Channel = #channel{conninfo = #{expiry_interval := ExpiryInterval}, - client_info = ClientInfo = #{zone := Zone}, + clientinfo = ClientInfo = #{zone := Zone}, will_msg = WillMsg}) -> emqx_zone:enable_flapping_detect(Zone) andalso emqx_flapping:detect(ClientInfo), Channel1 = ensure_disconnected(Channel), @@ -726,7 +755,7 @@ handle_info(Info, Channel) -> | {ok, Result :: term(), channel()} | {stop, Reason :: term(), channel()}). handle_timeout(TRef, {emit_stats, Stats}, - Channel = #channel{client_info = #{client_id := ClientId}, + Channel = #channel{clientinfo = #{clientid := ClientId}, timers = #{stats_timer := TRef}}) -> ok = emqx_cm:set_chan_stats(ClientId, Stats), {ok, clean_timer(stats_timer, Channel)}; @@ -739,7 +768,7 @@ handle_timeout(TRef, {keepalive, StatVal}, NChannel = Channel#channel{keepalive = NKeepalive}, {ok, reset_timer(alive_timer, NChannel)}; {error, timeout} -> - {wait_session_expire, {shutdown, keepalive_timeout}, Channel} + handle_out({disconnect, ?RC_KEEP_ALIVE_TIMEOUT}, Channel) end; handle_timeout(TRef, retry_delivery, @@ -810,7 +839,7 @@ reset_timer(Name, Time, Channel) -> clean_timer(Name, Channel = #channel{timers = Timers}) -> Channel#channel{timers = maps:remove(Name, Timers)}. -interval(stats_timer, #channel{client_info = #{zone := Zone}}) -> +interval(stats_timer, #channel{clientinfo = #{zone := Zone}}) -> emqx_zone:get_env(Zone, idle_timeout, 30000); interval(alive_timer, #channel{keepalive = KeepAlive}) -> emqx_keepalive:info(interval, KeepAlive); @@ -828,18 +857,21 @@ will_delay_interval(undefined) -> 0; will_delay_interval(WillMsg) -> emqx_message:get_header('Will-Delay-Interval', WillMsg, 0). +%% TODO: Implement later. +disconnect(_Reason, Channel) -> {ok, Channel}. + %%-------------------------------------------------------------------- %% Terminate %%-------------------------------------------------------------------- -terminate(_, #channel{connected = undefined}) -> +terminate(_, #channel{fsm_state = #{state_name := initialized}}) -> ok; -terminate(normal, #channel{conninfo = ConnInfo, client_info = ClientInfo}) -> +terminate(normal, #channel{conninfo = ConnInfo, clientinfo = ClientInfo}) -> ok = emqx_hooks:run('client.disconnected', [ClientInfo, normal, ConnInfo]); -terminate({shutdown, Reason}, #channel{conninfo = ConnInfo, client_info = ClientInfo}) +terminate({shutdown, Reason}, #channel{conninfo = ConnInfo, clientinfo = ClientInfo}) when Reason =:= kicked orelse Reason =:= discarded orelse Reason =:= takeovered -> ok = emqx_hooks:run('client.disconnected', [ClientInfo, Reason, ConnInfo]); -terminate(Reason, #channel{conninfo = ConnInfo, client_info = ClientInfo, will_msg = WillMsg}) -> +terminate(Reason, #channel{conninfo = ConnInfo, clientinfo = ClientInfo, will_msg = WillMsg}) -> publish_will_msg(WillMsg), ok = emqx_hooks:run('client.disconnected', [ClientInfo, Reason, ConnInfo]). @@ -864,9 +896,9 @@ enrich_conninfo(#mqtt_packet_connect{ clean_start = CleanStart, keepalive = Keepalive, properties = ConnProps, - client_id = ClientId, + clientid = ClientId, username = Username}, Channel) -> - #channel{conninfo = ConnInfo, client_info = #{zone := Zone}} = Channel, + #channel{conninfo = ConnInfo, clientinfo = #{zone := Zone}} = Channel, MaxInflight = emqx_mqtt_props:get('Receive-Maximum', ConnProps, emqx_zone:max_inflight(Zone)), Interval = if ProtoVer == ?MQTT_PROTO_V5 -> @@ -880,7 +912,7 @@ enrich_conninfo(#mqtt_packet_connect{ proto_ver => ProtoVer, clean_start => CleanStart, keepalive => Keepalive, - client_id => ClientId, + clientid => ClientId, username => Username, conn_props => ConnProps, receive_maximum => MaxInflight, @@ -889,18 +921,18 @@ enrich_conninfo(#mqtt_packet_connect{ {ok, Channel#channel{conninfo = NConnInfo}}. %% @doc Check connect packet. -check_connect(ConnPkt, #channel{client_info = #{zone := Zone}}) -> +check_connect(ConnPkt, #channel{clientinfo = #{zone := Zone}}) -> emqx_packet:check(ConnPkt, emqx_mqtt_caps:get_caps(Zone)). %% @doc Enrich client -enrich_client(ConnPkt, Channel = #channel{client_info = ClientInfo}) -> +enrich_client(ConnPkt, Channel = #channel{clientinfo = ClientInfo}) -> {ok, NConnPkt, NClientInfo} = pipeline([fun set_username/2, fun set_bridge_mode/2, fun maybe_username_as_clientid/2, fun maybe_assign_clientid/2, fun fix_mountpoint/2], ConnPkt, ClientInfo), - {ok, NConnPkt, Channel#channel{client_info = NClientInfo}}. + {ok, NConnPkt, Channel#channel{clientinfo = NClientInfo}}. set_username(#mqtt_packet_connect{username = Username}, ClientInfo = #{username := undefined}) -> @@ -916,35 +948,35 @@ maybe_username_as_clientid(_ConnPkt, ClientInfo = #{username := undefined}) -> {ok, ClientInfo}; maybe_username_as_clientid(_ConnPkt, ClientInfo = #{zone := Zone, username := Username}) -> case emqx_zone:use_username_as_clientid(Zone) of - true -> {ok, ClientInfo#{client_id => Username}}; + true -> {ok, ClientInfo#{clientid => Username}}; false -> ok end. -maybe_assign_clientid(#mqtt_packet_connect{client_id = <<>>}, ClientInfo) -> +maybe_assign_clientid(#mqtt_packet_connect{clientid = <<>>}, ClientInfo) -> %% Generate a rand clientId - {ok, ClientInfo#{client_id => emqx_guid:to_base62(emqx_guid:gen())}}; -maybe_assign_clientid(#mqtt_packet_connect{client_id = ClientId}, ClientInfo) -> - {ok, ClientInfo#{client_id => ClientId}}. + {ok, ClientInfo#{clientid => emqx_guid:to_base62(emqx_guid:gen())}}; +maybe_assign_clientid(#mqtt_packet_connect{clientid = ClientId}, ClientInfo) -> + {ok, ClientInfo#{clientid => ClientId}}. fix_mountpoint(_ConnPkt, #{mountpoint := undefined}) -> ok; fix_mountpoint(_ConnPkt, ClientInfo = #{mountpoint := Mountpoint}) -> {ok, ClientInfo#{mountpoint := emqx_mountpoint:replvar(Mountpoint, ClientInfo)}}. %% @doc Set logger metadata. -set_logger_meta(_ConnPkt, #channel{client_info = #{client_id := ClientId}}) -> - emqx_logger:set_metadata_client_id(ClientId). +set_logger_meta(_ConnPkt, #channel{clientinfo = #{clientid := ClientId}}) -> + emqx_logger:set_metadata_clientid(ClientId). %%-------------------------------------------------------------------- %% Check banned/flapping %%-------------------------------------------------------------------- -check_banned(_ConnPkt, #channel{client_info = ClientInfo = #{zone := Zone}}) -> +check_banned(_ConnPkt, #channel{clientinfo = ClientInfo = #{zone := Zone}}) -> case emqx_zone:enable_ban(Zone) andalso emqx_banned:check(ClientInfo) of true -> {error, ?RC_BANNED}; false -> ok end. -check_flapping(_ConnPkt, #channel{client_info = ClientInfo = #{zone := Zone}}) -> +check_flapping(_ConnPkt, #channel{clientinfo = ClientInfo = #{zone := Zone}}) -> case emqx_zone:enable_flapping_detect(Zone) andalso emqx_flapping:check(ClientInfo) of true -> {error, ?RC_CONNECTION_RATE_EXCEEDED}; @@ -955,13 +987,13 @@ check_flapping(_ConnPkt, #channel{client_info = ClientInfo = #{zone := Zone}}) - %% Auth Connect %%-------------------------------------------------------------------- -auth_connect(#mqtt_packet_connect{client_id = ClientId, +auth_connect(#mqtt_packet_connect{clientid = ClientId, username = Username, password = Password}, - Channel = #channel{client_info = ClientInfo}) -> + Channel = #channel{clientinfo = ClientInfo}) -> case emqx_access_control:authenticate(ClientInfo#{password => Password}) of {ok, AuthResult} -> - {ok, Channel#channel{client_info = maps:merge(ClientInfo, AuthResult)}}; + {ok, Channel#channel{clientinfo = maps:merge(ClientInfo, AuthResult)}}; {error, Reason} -> ?LOG(warning, "Client ~s (Username: '~s') login failed for ~0p", [ClientId, Username, Reason]), @@ -1004,7 +1036,7 @@ save_alias(AliasId, Topic, Aliases) -> maps:put(AliasId, Topic, Aliases). %% Check Pub ACL check_pub_acl(#mqtt_packet{variable = #mqtt_packet_publish{topic_name = Topic}}, - #channel{client_info = ClientInfo}) -> + #channel{clientinfo = ClientInfo}) -> case is_acl_enabled(ClientInfo) andalso emqx_access_control:check_acl(ClientInfo, publish, Topic) of false -> ok; @@ -1033,7 +1065,7 @@ check_pub_caps(#mqtt_packet{header = #mqtt_packet_header{qos = QoS, retain = Retain } }, - #channel{client_info = #{zone := Zone}}) -> + #channel{clientinfo = #{zone := Zone}}) -> emqx_mqtt_caps:check_pub(Zone, #{qos => QoS, retain => Retain}). %% Check Sub @@ -1044,7 +1076,7 @@ check_subscribe(TopicFilter, SubOpts, Channel) -> end. %% Check Sub ACL -check_sub_acl(TopicFilter, #channel{client_info = ClientInfo}) -> +check_sub_acl(TopicFilter, #channel{clientinfo = ClientInfo}) -> case is_acl_enabled(ClientInfo) andalso emqx_access_control:check_acl(ClientInfo, subscribe, TopicFilter) of false -> allow; @@ -1052,7 +1084,7 @@ check_sub_acl(TopicFilter, #channel{client_info = ClientInfo}) -> end. %% Check Sub Caps -check_sub_caps(TopicFilter, SubOpts, #channel{client_info = #{zone := Zone}}) -> +check_sub_caps(TopicFilter, SubOpts, #channel{clientinfo = #{zone := Zone}}) -> emqx_mqtt_caps:check_sub(Zone, TopicFilter, SubOpts). enrich_subid(#{'Subscription-Identifier' := SubId}, TopicFilters) -> @@ -1063,12 +1095,12 @@ enrich_subid(_Properties, TopicFilters) -> enrich_subopts(SubOpts, #channel{conninfo = #{proto_ver := ?MQTT_PROTO_V5}}) -> SubOpts; -enrich_subopts(SubOpts, #channel{client_info = #{zone := Zone, is_bridge := IsBridge}}) -> +enrich_subopts(SubOpts, #channel{clientinfo = #{zone := Zone, is_bridge := IsBridge}}) -> NL = flag(emqx_zone:ignore_loop_deliver(Zone)), SubOpts#{rap => flag(IsBridge), nl => NL}. enrich_caps(AckProps, #channel{conninfo = #{proto_ver := ?MQTT_PROTO_V5}, - client_info = #{zone := Zone}}) -> + clientinfo = #{zone := Zone}}) -> #{max_packet_size := MaxPktSize, max_qos_allowed := MaxQoS, retain_available := Retain, @@ -1087,16 +1119,16 @@ enrich_caps(AckProps, #channel{conninfo = #{proto_ver := ?MQTT_PROTO_V5}, enrich_caps(AckProps, _Channel) -> AckProps. -enrich_server_keepalive(AckProps, #channel{client_info = #{zone := Zone}}) -> +enrich_server_keepalive(AckProps, #channel{clientinfo = #{zone := Zone}}) -> case emqx_zone:server_keepalive(Zone) of undefined -> AckProps; Keepalive -> AckProps#{'Server-Keep-Alive' => Keepalive} end. enrich_assigned_clientid(AckProps, #channel{conninfo = ConnInfo, - client_info = #{client_id := ClientId} + clientinfo = #{clientid := ClientId} }) -> - case maps:get(client_id, ConnInfo) of + case maps:get(clientid, ConnInfo) of <<>> -> %% Original ClientId is null. AckProps#{'Assigned-Client-Identifier' => ClientId}; _Origin -> AckProps @@ -1108,8 +1140,10 @@ init_alias_maximum(#mqtt_packet_connect{proto_ver = ?MQTT_PROTO_V5, inbound => emqx_mqtt_caps:get_caps(Zone, max_topic_alias, 0)}; init_alias_maximum(_ConnPkt, _ClientInfo) -> undefined. -ensure_disconnected(Channel) -> - Channel#channel{connected = false, disconnected_at = os:timestamp()}. +ensure_disconnected(Channel = #channel{fsm_state = FsmState}) -> + Channel#channel{fsm_state = FsmState#{state_name := disconnected, + disconnected_at => erlang:system_time(second) + }}. ensure_keepalive(#{'Server-Keep-Alive' := Interval}, Channel) -> ensure_keepalive_timer(Interval, Channel); @@ -1117,7 +1151,7 @@ ensure_keepalive(_AckProps, Channel = #channel{conninfo = ConnInfo}) -> ensure_keepalive_timer(maps:get(keepalive, ConnInfo), Channel). ensure_keepalive_timer(0, Channel) -> Channel; -ensure_keepalive_timer(Interval, Channel = #channel{client_info = #{zone := Zone}}) -> +ensure_keepalive_timer(Interval, Channel = #channel{clientinfo = #{zone := Zone}}) -> Backoff = emqx_zone:get_env(Zone, keepalive_backoff, 0.75), Keepalive = emqx_keepalive:init(round(timer:seconds(Interval) * Backoff)), ensure_timer(alive_timer, Channel#channel{keepalive = Keepalive}). @@ -1150,19 +1184,13 @@ parse_topic_filters(TopicFilters) -> maybe_gc_and_check_oom(_Oct, Channel = #channel{gc_state = undefined}) -> Channel; -maybe_gc_and_check_oom(Oct, Channel = #channel{gc_state = GCSt, - oom_policy = OomPolicy}) -> +maybe_gc_and_check_oom(Oct, Channel = #channel{clientinfo = #{zone := Zone}, + gc_state = GCSt}) -> {IsGC, GCSt1} = emqx_gc:run(1, Oct, GCSt), IsGC andalso emqx_metrics:inc('channel.gc.cnt'), - IsGC andalso maybe_apply(fun check_oom/1, OomPolicy), + IsGC andalso emqx_zone:check_oom(Zone, fun(Shutdown) -> self() ! Shutdown end), Channel#channel{gc_state = GCSt1}. -check_oom(OomPolicy) -> - case emqx_oom:check(OomPolicy) of - ok -> ok; - Shutdown -> self() ! Shutdown - end. - %%-------------------------------------------------------------------- %% Helper functions %%-------------------------------------------------------------------- diff --git a/src/emqx_connection.erl b/src/emqx_connection.erl index 3adee23f5..a0408dea6 100644 --- a/src/emqx_connection.erl +++ b/src/emqx_connection.erl @@ -30,9 +30,7 @@ %% APIs -export([ info/1 - , attrs/1 , stats/1 - , state/1 ]). -export([call/2]). @@ -50,7 +48,7 @@ , terminate/3 ]). --record(connection, { +-record(state, { %% TCP/TLS Transport transport :: esockd:transport(), %% TCP/TLS Socket @@ -63,27 +61,26 @@ active_n :: pos_integer(), %% The active state active_state :: running | blocked, - %% Rate Limit - rate_limit :: maybe(esockd_rate_limit:bucket()), %% Publish Limit pub_limit :: maybe(esockd_rate_limit:bucket()), + %% Rate Limit + rate_limit :: maybe(esockd_rate_limit:bucket()), %% Limit Timer limit_timer :: maybe(reference()), %% Parser State parse_state :: emqx_frame:parse_state(), %% Serialize function - serialize :: fun((emqx_types:packet()) -> iodata()), + serialize :: emqx_frame:serialize_fun(), %% Channel State chan_state :: emqx_channel:channel() }). --type(connection() :: #connection{}). +-type(state() :: #state{}). -define(ACTIVE_N, 100). -define(HANDLE(T, C, D), handle((T), (C), (D))). --define(ATTR_KEYS, [socktype, peername, sockname]). -define(INFO_KEYS, [socktype, peername, sockname, active_n, active_state, - rate_limit, pub_limit]). + pub_limit, rate_limit]). -define(CONN_STATS, [recv_pkt, recv_msg, send_pkt, send_msg]). -define(SOCK_STATS, [recv_oct, recv_cnt, send_oct, send_cnt, send_pend]). @@ -98,64 +95,52 @@ start_link(Transport, Socket, Options) -> %%-------------------------------------------------------------------- %% @doc Get infos of the connection. --spec(info(pid()|connection()) -> emqx_types:infos()). +-spec(info(pid()|state()) -> emqx_types:infos()). info(CPid) when is_pid(CPid) -> call(CPid, info); -info(Conn = #connection{chan_state = ChanState}) -> +info(Conn = #state{chan_state = ChanState}) -> ChanInfo = emqx_channel:info(ChanState), SockInfo = maps:from_list(info(?INFO_KEYS, Conn)), maps:merge(ChanInfo, #{sockinfo => SockInfo}). info(Keys, Conn) when is_list(Keys) -> [{Key, info(Key, Conn)} || Key <- Keys]; -info(socktype, #connection{transport = Transport, socket = Socket}) -> +info(socktype, #state{transport = Transport, socket = Socket}) -> Transport:type(Socket); -info(peername, #connection{peername = Peername}) -> +info(peername, #state{peername = Peername}) -> Peername; -info(sockname, #connection{sockname = Sockname}) -> +info(sockname, #state{sockname = Sockname}) -> Sockname; -info(active_n, #connection{active_n = ActiveN}) -> +info(active_n, #state{active_n = ActiveN}) -> ActiveN; -info(active_state, #connection{active_state = ActiveSt}) -> +info(active_state, #state{active_state = ActiveSt}) -> ActiveSt; -info(rate_limit, #connection{rate_limit = RateLimit}) -> - limit_info(RateLimit); -info(pub_limit, #connection{pub_limit = PubLimit}) -> +info(pub_limit, #state{pub_limit = PubLimit}) -> limit_info(PubLimit); -info(chan_state, #connection{chan_state = ChanState}) -> +info(rate_limit, #state{rate_limit = RateLimit}) -> + limit_info(RateLimit); +info(chan_state, #state{chan_state = ChanState}) -> emqx_channel:info(ChanState). limit_info(Limit) -> emqx_misc:maybe_apply(fun esockd_rate_limit:info/1, Limit). -%% @doc Get attrs of the connection. --spec(attrs(pid()|connection()) -> emqx_types:attrs()). -attrs(CPid) when is_pid(CPid) -> - call(CPid, attrs); -attrs(Conn = #connection{chan_state = ChanState}) -> - ChanAttrs = emqx_channel:attrs(ChanState), - SockAttrs = maps:from_list(info(?ATTR_KEYS, Conn)), - maps:merge(ChanAttrs, #{sockinfo => SockAttrs}). - %% @doc Get stats of the channel. --spec(stats(pid()|connection()) -> emqx_types:stats()). +-spec(stats(pid()|state()) -> emqx_types:stats()). stats(CPid) when is_pid(CPid) -> call(CPid, stats); -stats(#connection{transport = Transport, - socket = Socket, - chan_state = ChanState}) -> - ProcStats = emqx_misc:proc_stats(), +stats(#state{transport = Transport, + socket = Socket, + chan_state = ChanState}) -> SockStats = case Transport:getstat(Socket, ?SOCK_STATS) of {ok, Ss} -> Ss; {error, _} -> [] end, - ConnStats = [{Name, emqx_pd:get_counter(Name)} || Name <- ?CONN_STATS], + ConnStats = emqx_pd:get_counters(?CONN_STATS), ChanStats = emqx_channel:stats(ChanState), - lists:append([ProcStats, SockStats, ConnStats, ChanStats]). - -%% For debug --spec(state(pid()) -> connection()). -state(CPid) -> call(CPid, state). + ProcStats = emqx_misc:proc_stats(), + [{sock_stats, SockStats}, {conn_stats, ConnStats}, + {chan_stats, ChanStats}, {proc_stats, ProcStats}]. %% kick|discard|takeover -spec(call(pid(), Req :: term()) -> Reply :: term()). @@ -170,38 +155,43 @@ init({Transport, RawSocket, Options}) -> {ok, Peername} = Transport:ensure_ok_or_exit(peername, [Socket]), {ok, Sockname} = Transport:ensure_ok_or_exit(sockname, [Socket]), Peercert = Transport:ensure_ok_or_exit(peercert, [Socket]), + ConnInfo = #{socktype => Transport:type(Socket), + peername => Peername, + sockname => Sockname, + peercert => Peercert, + conn_mod => ?MODULE + }, emqx_logger:set_metadata_peername(esockd_net:format(Peername)), Zone = proplists:get_value(zone, Options), - RateLimit = init_limiter(proplists:get_value(rate_limit, Options)), - PubLimit = init_limiter(emqx_zone:get_env(Zone, publish_limit)), ActiveN = proplists:get_value(active_n, Options, ?ACTIVE_N), - MaxSize = emqx_zone:get_env(Zone, max_packet_size, ?MAX_PACKET_SIZE), - ParseState = emqx_frame:initial_parse_state(#{max_size => MaxSize}), - ChanState = emqx_channel:init(#{peername => Peername, - sockname => Sockname, - peercert => Peercert, - protocol => mqtt, - conn_mod => ?MODULE}, Options), + PubLimit = init_limiter(emqx_zone:get_env(Zone, publish_limit)), + RateLimit = init_limiter(proplists:get_value(rate_limit, Options)), + FrameOpts = emqx_zone:frame_options(Zone), + ParseState = emqx_frame:initial_parse_state(FrameOpts), + Serialize = emqx_frame:serialize_fun(), + ChanState = emqx_channel:init(ConnInfo, Options), + State = #state{transport = Transport, + socket = Socket, + peername = Peername, + sockname = Sockname, + active_n = ActiveN, + active_state = running, + pub_limit = PubLimit, + rate_limit = RateLimit, + parse_state = ParseState, + serialize = Serialize, + chan_state = ChanState + }, IdleTimout = emqx_zone:get_env(Zone, idle_timeout, 30000), - State = #connection{transport = Transport, - socket = Socket, - peername = Peername, - sockname = Sockname, - active_n = ActiveN, - active_state = running, - rate_limit = RateLimit, - pub_limit = PubLimit, - parse_state = ParseState, - chan_state = ChanState, - serialize = serialize_fun(?MQTT_PROTO_V5, undefined) - }, gen_statem:enter_loop(?MODULE, [{hibernate_after, 2 * IdleTimout}], idle, State, self(), [IdleTimout]). +-compile({inline, [init_limiter/1]}). init_limiter(undefined) -> undefined; init_limiter({Rate, Burst}) -> esockd_rate_limit:new(Rate, Burst). +-compile({inline, [callback_mode/0]}). callback_mode() -> [state_functions, state_enter]. @@ -219,18 +209,17 @@ idle(timeout, _Timeout, State) -> shutdown(idle_timeout, State); idle(cast, {incoming, Packet = ?CONNECT_PACKET(ConnPkt)}, State) -> - #mqtt_packet_connect{proto_ver = ProtoVer, properties = Properties} = ConnPkt, - MaxPacketSize = emqx_mqtt_props:get('Maximum-Packet-Size', Properties, undefined), - NState = State#connection{serialize = serialize_fun(ProtoVer, MaxPacketSize)}, SuccFun = fun(NewSt) -> {next_state, connected, NewSt} end, + Serialize = emqx_frame:serialize_fun(ConnPkt), + NState = State#state{serialize = Serialize}, handle_incoming(Packet, SuccFun, NState); idle(cast, {incoming, Packet}, State) when is_record(Packet, mqtt_packet) -> - ?LOG(warning, "Unexpected incoming: ~p", [Packet]), - shutdown(unexpected_incoming_packet, State); + SuccFun = fun(NewSt) -> {next_state, connected, NewSt} end, + handle_incoming(Packet, SuccFun, State); -idle(cast, {incoming, {error, Reason}}, State) -> - shutdown(Reason, State); +idle(cast, {incoming, FrameError = {frame_error, _Reason}}, State) -> + handle_incoming(FrameError, State); idle(EventType, Content, State) -> ?HANDLE(EventType, Content, State). @@ -245,16 +234,8 @@ connected(enter, _PrevSt, State) -> connected(cast, {incoming, Packet}, State) when is_record(Packet, mqtt_packet) -> handle_incoming(Packet, fun keep_state/1, State); -connected(cast, {incoming, {error, Reason}}, State = #connection{chan_state = ChanState}) -> - case emqx_channel:handle_out({disconnect, emqx_reason_codes:mqtt_frame_error(Reason)}, ChanState) of - {wait_session_expire, _, NChanState} -> - ?LOG(debug, "Disconnect and wait for session to expire due to ~p", [Reason]), - {next_state, disconnected, State#connection{chan_state= NChanState}}; - {wait_session_expire, _, OutPackets, NChanState} -> - ?LOG(debug, "Disconnect and wait for session to expire due to ~p", [Reason]), - NState = State#connection{chan_state= NChanState}, - {next_state, disconnected, handle_outgoing(OutPackets, fun(NewSt) -> NewSt end, NState)} - end; +connected(cast, {incoming, FrameError = {frame_error, _Reason}}, State) -> + handle_incoming(FrameError, State); connected(info, Deliver = {deliver, _Topic, _Msg}, State) -> handle_deliver(emqx_misc:drain_deliver([Deliver]), State); @@ -265,13 +246,13 @@ connected(EventType, Content, State) -> %%-------------------------------------------------------------------- %% Disconnected State -disconnected(enter, _, State = #connection{chan_state = ChanState}) -> +disconnected(enter, _, State = #state{chan_state = ChanState}) -> case emqx_channel:handle_info(disconnected, ChanState) of {ok, NChanState} -> - ok = register_self(State#connection{chan_state = NChanState}), - keep_state(State#connection{chan_state = NChanState}); + ok = register_self(State#state{chan_state = NChanState}), + keep_state(State#state{chan_state = NChanState}); {stop, Reason, NChanState} -> - stop(Reason, State#connection{chan_state = NChanState}) + stop(Reason, State#state{chan_state = NChanState}) end; disconnected(info, Deliver = {deliver, _Topic, _Msg}, State) -> @@ -286,51 +267,49 @@ disconnected(EventType, Content, State) -> handle({call, From}, info, State) -> reply(From, info(State), State); -handle({call, From}, attrs, State) -> - reply(From, attrs(State), State); - handle({call, From}, stats, State) -> reply(From, stats(State), State); handle({call, From}, state, State) -> reply(From, State, State); -handle({call, From}, Req, State = #connection{chan_state = ChanState}) -> +handle({call, From}, Req, State = #state{chan_state = ChanState}) -> case emqx_channel:handle_call(Req, ChanState) of {ok, Reply, NChanState} -> - reply(From, Reply, State#connection{chan_state = NChanState}); + reply(From, Reply, State#state{chan_state = NChanState}); {stop, Reason, Reply, NChanState} -> ok = gen_statem:reply(From, Reply), - stop(Reason, State#connection{chan_state = NChanState}); + stop(Reason, State#state{chan_state = NChanState}); {stop, Reason, Packet, Reply, NChanState} -> - handle_outgoing(Packet, fun (_) -> ok end, State#connection{chan_state = NChanState}), + handle_outgoing(Packet, State#state{chan_state = NChanState}), ok = gen_statem:reply(From, Reply), - stop(Reason, State#connection{chan_state = NChanState}) + stop(Reason, State#state{chan_state = NChanState}) end; %%-------------------------------------------------------------------- %% Handle cast -handle(cast, Msg, State = #connection{chan_state = ChanState}) -> - case emqx_channel:handle_cast(Msg, ChanState) of +handle(cast, Msg, State = #state{chan_state = ChanState}) -> + case emqx_channel:handle_info(Msg, ChanState) of + ok -> {ok, State}; {ok, NChanState} -> - keep_state(State#connection{chan_state = NChanState}); + keep_state(State#state{chan_state = NChanState}); {stop, Reason, NChanState} -> - stop(Reason, State#connection{chan_state = NChanState}) + stop(Reason, State#state{chan_state = NChanState}) end; %%-------------------------------------------------------------------- %% Handle info %% Handle incoming data -handle(info, {Inet, _Sock, Data}, State = #connection{chan_state = ChanState}) +handle(info, {Inet, _Sock, Data}, State = #state{chan_state = ChanState}) when Inet == tcp; Inet == ssl -> ?LOG(debug, "RECV ~p", [Data]), Oct = iolist_size(Data), emqx_pd:update_counter(incoming_bytes, Oct), ok = emqx_metrics:inc('bytes.received', Oct), NChanState = emqx_channel:received(Oct, ChanState), - NState = State#connection{chan_state = NChanState}, + NState = State#state{chan_state = NChanState}, process_incoming(Data, NState); handle(info, {Error, _Sock, Reason}, State) @@ -353,9 +332,9 @@ handle(info, {Passive, _Sock}, State) handle(info, activate_socket, State) -> %% Rate limit timer expired. - NState = State#connection{active_state = running, - limit_timer = undefined - }, + NState = State#state{active_state = running, + limit_timer = undefined + }, case activate_socket(NState) of ok -> keep_state(NState); {error, Reason} -> @@ -370,7 +349,7 @@ handle(info, {inet_reply, _Sock, {error, Reason}}, State) -> shutdown(Reason, State); handle(info, {timeout, TRef, keepalive}, - State = #connection{transport = Transport, socket = Socket}) -> + State = #state{transport = Transport, socket = Socket}) -> case Transport:getstat(Socket, [recv_oct]) of {ok, [{recv_oct, RecvOct}]} -> handle_timeout(TRef, {keepalive, RecvOct}, State); @@ -387,21 +366,21 @@ handle(info, {timeout, TRef, Msg}, State) -> handle(info, {shutdown, Reason}, State) -> shutdown(Reason, State); -handle(info, Info, State = #connection{chan_state = ChanState}) -> +handle(info, Info, State = #state{chan_state = ChanState}) -> case emqx_channel:handle_info(Info, ChanState) of {ok, NChanState} -> - keep_state(State#connection{chan_state = NChanState}); + keep_state(State#state{chan_state = NChanState}); {stop, Reason, NChanState} -> - stop(Reason, State#connection{chan_state = NChanState}) + stop(Reason, State#state{chan_state = NChanState}) end. code_change(_Vsn, State, Data, _Extra) -> {ok, State, Data}. -terminate(Reason, _StateName, State) -> - #connection{transport = Transport, - socket = Socket, - chan_state = ChanState} = State, +terminate(Reason, _StateName, #state{transport = Transport, + socket = Socket, + chan_state = ChanState + }) -> ?LOG(debug, "Terminated for ~p", [Reason]), ok = Transport:fast_close(Socket), emqx_channel:terminate(Reason, ChanState). @@ -410,8 +389,16 @@ terminate(Reason, _StateName, State) -> %% Internal functions %%-------------------------------------------------------------------- -register_self(State = #connection{chan_state = ChanState}) -> - emqx_channel:handle_cast({register, attrs(State), stats(State)}, ChanState). +register_self(State = #state{active_n = ActiveN, + active_state = ActiveSt, + chan_state = ChanState + }) -> + ChanAttrs = emqx_channel:attrs(ChanState), + SockAttrs = #{active_n => ActiveN, + active_state => ActiveSt + }, + Attrs = maps:merge(ChanAttrs, #{sockinfo => SockAttrs}), + emqx_channel:handle_info({register, Attrs, stats(State)}, ChanState). %%-------------------------------------------------------------------- %% Process incoming data @@ -421,107 +408,109 @@ process_incoming(Data, State) -> process_incoming(Data, [], State). process_incoming(<<>>, Packets, State) -> - {keep_state, State, next_incoming_events(Packets)}; + keep_state(State, next_incoming_events(Packets)); -process_incoming(Data, Packets, State = #connection{parse_state = ParseState}) -> +process_incoming(Data, Packets, State = #state{parse_state = ParseState}) -> try emqx_frame:parse(Data, ParseState) of {more, NParseState} -> - NState = State#connection{parse_state = NParseState}, - {keep_state, NState, next_incoming_events(Packets)}; + NState = State#state{parse_state = NParseState}, + keep_state(NState, next_incoming_events(Packets)); {ok, Packet, Rest, NParseState} -> - NState = State#connection{parse_state = NParseState}, - process_incoming(Rest, [Packet|Packets], NState); - {error, Reason} -> - {keep_state, State, next_incoming_events({error, Reason})} + NState = State#state{parse_state = NParseState}, + process_incoming(Rest, [Packet|Packets], NState) catch error:Reason:Stk -> - ?LOG(error, "~nParse failed for ~p~nStacktrace: ~p~nError data:~p", [Reason, Stk, Data]), - {keep_state, State, next_incoming_events({error, Reason})} + ?LOG(error, "~nParse failed for ~p~nStacktrace: ~p~nFrame data:~p", + [Reason, Stk, Data]), + keep_state(State, next_incoming_events(Packets++[{frame_error, Reason}])) end. -compile({inline, [next_incoming_events/1]}). -next_incoming_events({error, Reason}) -> - [next_event(cast, {incoming, {error, Reason}})]; +next_incoming_events([]) -> []; next_incoming_events(Packets) -> [next_event(cast, {incoming, Packet}) || Packet <- Packets]. %%-------------------------------------------------------------------- %% Handle incoming packet -handle_incoming(Packet = ?PACKET(Type), SuccFun, - State = #connection{chan_state = ChanState}) -> +handle_incoming(Packet = ?PACKET(Type), SuccFun, State = #state{chan_state = ChanState}) -> _ = inc_incoming_stats(Type), - ok = emqx_metrics:inc_recv(Packet), + _ = emqx_metrics:inc_recv(Packet), ?LOG(debug, "RECV ~s", [emqx_packet:format(Packet)]), case emqx_channel:handle_in(Packet, ChanState) of {ok, NChanState} -> - SuccFun(State#connection{chan_state= NChanState}); + SuccFun(State#state{chan_state= NChanState}); {ok, OutPackets, NChanState} -> - handle_outgoing(OutPackets, SuccFun, State#connection{chan_state = NChanState}); - {wait_session_expire, Reason, NChanState} -> - ?LOG(debug, "Disconnect and wait for session to expire due to ~p", [Reason]), - {next_state, disconnected, State#connection{chan_state = NChanState}}; - {wait_session_expire, Reason, OutPackets, NChanState} -> - ?LOG(debug, "Disconnect and wait for session to expire due to ~p", [Reason]), - NState = State#connection{chan_state= NChanState}, - {next_state, disconnected, handle_outgoing(OutPackets, fun(NewSt) -> NewSt end, NState)}; + NState = State#state{chan_state = NChanState}, + handle_outgoing(OutPackets, SuccFun, NState); {stop, Reason, NChanState} -> - stop(Reason, State#connection{chan_state = NChanState}); + stop(Reason, State#state{chan_state = NChanState}); {stop, Reason, OutPackets, NChanState} -> - NState = State#connection{chan_state= NChanState}, + NState = State#state{chan_state= NChanState}, + stop(Reason, handle_outgoing(OutPackets, fun(NewSt) -> NewSt end, NState)) + end. + +handle_incoming(FrameError = {frame_error, _Reason}, State = #state{chan_state = ChanState}) -> + case emqx_channel:handle_in(FrameError, ChanState) of + {close, Reason, NChanState} -> + close(Reason, State#state{chan_state = NChanState}); + {close, Reason, OutPackets, NChanState} -> + NState = State#state{chan_state= NChanState}, + close(Reason, handle_outgoing(OutPackets, fun(NewSt) -> NewSt end, NState)); + {stop, Reason, NChanState} -> + stop(Reason, State#state{chan_state = NChanState}); + {stop, Reason, OutPackets, NChanState} -> + NState = State#state{chan_state= NChanState}, stop(Reason, handle_outgoing(OutPackets, fun(NewSt) -> NewSt end, NState)) end. %%------------------------------------------------------------------- %% Handle deliver -handle_deliver(Delivers, State = #connection{chan_state = ChanState}) -> +handle_deliver(Delivers, State = #state{chan_state = ChanState}) -> case emqx_channel:handle_out({deliver, Delivers}, ChanState) of {ok, NChanState} -> - keep_state(State#connection{chan_state = NChanState}); + keep_state(State#state{chan_state = NChanState}); {ok, Packets, NChanState} -> - handle_outgoing(Packets, fun keep_state/1, State#connection{chan_state = NChanState}) + handle_outgoing(Packets, fun keep_state/1, State#state{chan_state = NChanState}) end. %%-------------------------------------------------------------------- %% Handle outgoing packets -handle_outgoing(Packets, SuccFun, State = #connection{serialize = Serialize}) - when is_list(Packets) -> - send(lists:map(Serialize, Packets), SuccFun, State); +handle_outgoing(Packet, State) -> + handle_outgoing(Packet, fun (_) -> ok end, State). -handle_outgoing(Packet, SuccFun, State = #connection{serialize = Serialize}) -> - send(Serialize(Packet), SuccFun, State). +handle_outgoing(Packets, SuccFun, State) when is_list(Packets) -> + send(lists:map(serialize_and_inc_stats_fun(State), Packets), SuccFun, State); -%%-------------------------------------------------------------------- -%% Serialize fun +handle_outgoing(Packet, SuccFun, State) -> + send((serialize_and_inc_stats_fun(State))(Packet), SuccFun, State). -serialize_fun(ProtoVer, MaxPacketSize) -> +serialize_and_inc_stats_fun(#state{serialize = Serialize}) -> fun(Packet = ?PACKET(Type)) -> - IoData = emqx_frame:serialize(Packet, ProtoVer), - case Type =/= ?PUBLISH orelse MaxPacketSize =:= undefined orelse iolist_size(IoData) =< MaxPacketSize of - true -> - ?LOG(debug, "SEND ~s", [emqx_packet:format(Packet)]), - _ = inc_outgoing_stats(Type), - _ = emqx_metrics:inc_sent(Packet), - IoData; - false -> - ?LOG(warning, "DROP ~s due to oversize packet size", [emqx_packet:format(Packet)]), - <<"">> + case Serialize(Packet) of + <<>> -> ?LOG(warning, "~s is discarded due to the frame is too large!", + [emqx_packet:format(Packet)]), + <<>>; + Data -> _ = inc_outgoing_stats(Type), + _ = emqx_metrics:inc_sent(Packet), + ?LOG(debug, "SEND ~s", [emqx_packet:format(Packet)]), + Data end end. %%-------------------------------------------------------------------- %% Send data -send(IoData, SuccFun, State = #connection{transport = Transport, - socket = Socket, - chan_state = ChanState}) -> +send(IoData, SuccFun, State = #state{transport = Transport, + socket = Socket, + chan_state = ChanState}) -> Oct = iolist_size(IoData), ok = emqx_metrics:inc('bytes.sent', Oct), case Transport:async_send(Socket, IoData) of ok -> NChanState = emqx_channel:sent(Oct, ChanState), - SuccFun(State#connection{chan_state = NChanState}); + SuccFun(State#state{chan_state = NChanState}); {error, Reason} -> shutdown(Reason, State) end. @@ -529,17 +518,22 @@ send(IoData, SuccFun, State = #connection{transport = Transport, %%-------------------------------------------------------------------- %% Handle timeout -handle_timeout(TRef, Msg, State = #connection{chan_state = ChanState}) -> +handle_timeout(TRef, Msg, State = #state{chan_state = ChanState}) -> case emqx_channel:handle_timeout(TRef, Msg, ChanState) of {ok, NChanState} -> - keep_state(State#connection{chan_state = NChanState}); + keep_state(State#state{chan_state = NChanState}); {ok, Packets, NChanState} -> - handle_outgoing(Packets, fun keep_state/1, State#connection{chan_state = NChanState}); - {wait_session_expire, Reason, NChanState} -> - ?LOG(debug, "Disconnect and wait for session to expire due to ~p", [Reason]), - {next_state, disconnected, State#connection{chan_state = NChanState}}; + handle_outgoing(Packets, fun keep_state/1, State#state{chan_state = NChanState}); + {close, Reason, NChanState} -> + close(Reason, State#state{chan_state = NChanState}); + {close, Reason, OutPackets, NChanState} -> + NState = State#state{chan_state= NChanState}, + close(Reason, handle_outgoing(OutPackets, fun(NewSt) -> NewSt end, NState)); {stop, Reason, NChanState} -> - stop(Reason, State#connection{chan_state = NChanState}) + stop(Reason, State#state{chan_state = NChanState}); + {stop, Reason, OutPackets, NChanState} -> + NState = State#state{chan_state= NChanState}, + stop(Reason, handle_outgoing(OutPackets, fun(NewSt) -> NewSt end, NState)) end. %%-------------------------------------------------------------------- @@ -547,11 +541,11 @@ handle_timeout(TRef, Msg, State = #connection{chan_state = ChanState}) -> -define(ENABLED(Rl), (Rl =/= undefined)). -ensure_rate_limit(State = #connection{rate_limit = Rl, pub_limit = Pl}) -> +ensure_rate_limit(State = #state{rate_limit = Rl, pub_limit = Pl}) -> Pubs = emqx_pd:reset_counter(incoming_pubs), Bytes = emqx_pd:reset_counter(incoming_bytes), - Limiters = [{Pl, #connection.pub_limit, Pubs} || ?ENABLED(Pl)] ++ - [{Rl, #connection.rate_limit, Bytes} || ?ENABLED(Rl)], + Limiters = [{Pl, #state.pub_limit, Pubs} || ?ENABLED(Pl)] ++ + [{Rl, #state.rate_limit, Bytes} || ?ENABLED(Rl)], ensure_rate_limit(Limiters, State). ensure_rate_limit([], State) -> @@ -563,30 +557,27 @@ ensure_rate_limit([{Rl, Pos, Cnt}|Limiters], State) -> {Pause, Rl1} -> ?LOG(debug, "Pause ~pms due to rate limit", [Pause]), TRef = erlang:send_after(Pause, self(), activate_socket), - NState = State#connection{active_state = blocked, - limit_timer = TRef}, + NState = State#state{active_state = blocked, + limit_timer = TRef + }, setelement(Pos, NState, Rl1) end. %%-------------------------------------------------------------------- %% Activate Socket -activate_socket(#connection{active_state = blocked}) -> - ok; -activate_socket(#connection{transport = Transport, - socket = Socket, - active_n = N}) -> +-compile({inline, [activate_socket/1]}). +activate_socket(#state{active_state = blocked}) -> ok; +activate_socket(#state{transport = Transport, + socket = Socket, + active_n = N}) -> Transport:setopts(Socket, [{active, N}]). %%-------------------------------------------------------------------- %% Inc incoming/outgoing stats --compile({inline, - [ inc_incoming_stats/1 - , inc_outgoing_stats/1 - ]}). - -inc_incoming_stats(Type) -> +-compile({inline, [inc_incoming_stats/1]}). +inc_incoming_stats(Type) when is_integer(Type) -> emqx_pd:update_counter(recv_pkt, 1), if Type == ?PUBLISH -> @@ -595,6 +586,7 @@ inc_incoming_stats(Type) -> true -> ok end. +-compile({inline, [inc_outgoing_stats/1]}). inc_outgoing_stats(Type) -> emqx_pd:update_counter(send_pkt, 1), (Type == ?PUBLISH) @@ -606,6 +598,7 @@ inc_outgoing_stats(Type) -> -compile({inline, [ reply/3 , keep_state/1 + , keep_state/2 , next_event/2 , shutdown/2 , stop/2 @@ -617,9 +610,17 @@ reply(From, Reply, State) -> keep_state(State) -> {keep_state, State}. +keep_state(State, Events) -> + {keep_state, State, Events}. + next_event(Type, Content) -> {next_event, Type, Content}. +close(Reason, State = #state{transport = Transport, socket = Socket}) -> + ?LOG(warning, "Closed for ~p", [Reason]), + ok = Transport:fast_close(Socket), + {next_state, disconnected, State}. + shutdown(Reason, State) -> stop({shutdown, Reason}, State). diff --git a/src/emqx_frame.erl b/src/emqx_frame.erl index e65f5d7c4..b402ad989 100644 --- a/src/emqx_frame.erl +++ b/src/emqx_frame.erl @@ -25,6 +25,8 @@ -export([ parse/1 , parse/2 + , serialize_fun/0 + , serialize_fun/1 , serialize/1 , serialize/2 ]). @@ -32,6 +34,7 @@ -export_type([ options/0 , parse_state/0 , parse_result/0 + , serialize_fun/0 ]). -type(options() :: #{strict_mode => boolean(), @@ -46,7 +49,9 @@ -type(cont_fun() :: fun((binary()) -> parse_result())). --define(none(Opts), {none, Opts}). +-type(serialize_fun() :: fun((emqx_types:packet()) -> iodata())). + +-define(none(Options), {none, Options}). -define(DEFAULT_OPTIONS, #{strict_mode => false, @@ -84,12 +89,12 @@ parse(<<>>, {none, Options}) -> parse(<>, {none, Options = #{strict_mode := StrictMode}}) -> %% Validate header if strict mode. - StrictMode andalso validate_header(Type, Dup, QoS, Retain), Header = #mqtt_packet_header{type = Type, dup = bool(Dup), qos = QoS, retain = bool(Retain) }, + StrictMode andalso validate_header(Type, Dup, QoS, Retain), Header1 = case fixqos(Type, QoS) of QoS -> Header; FixedQoS -> Header#mqtt_packet_header{qos = FixedQoS} @@ -105,7 +110,7 @@ parse_remaining_len(Rest, Header, Options) -> parse_remaining_len(_Bin, _Header, _Multiplier, Length, #{max_size := MaxSize}) when Length > MaxSize -> - error(mqtt_frame_too_large); + error(frame_too_large); parse_remaining_len(<<>>, Header, Multiplier, Length, Options) -> {more, fun(Bin) -> parse_remaining_len(Bin, Header, Multiplier, Length, Options) end}; %% Match DISCONNECT without payload @@ -124,7 +129,7 @@ parse_remaining_len(<<0:1, Len:7, Rest/binary>>, Header, Multiplier, Value, Options = #{max_size := MaxSize}) -> FrameLen = Value + Len * Multiplier, if - FrameLen > MaxSize -> error(mqtt_frame_too_large); + FrameLen > MaxSize -> error(frame_too_large); true -> parse_frame(Rest, Header, FrameLen, Options) end. @@ -148,6 +153,7 @@ parse_frame(Bin, Header, Length, Options) -> end} end. +-compile({inline, [packet/1, packet/2, packet/3]}). packet(Header) -> #mqtt_packet{header = Header}. packet(Header, Variable) -> @@ -180,7 +186,8 @@ parse_packet(#mqtt_packet_header{type = ?CONNECT}, FrameBin, _Options) -> will_retain = bool(WillRetain), keepalive = KeepAlive, properties = Properties, - clientid = ClientId}, + clientid = ClientId + }, {ConnPacket1, Rest5} = parse_will_message(ConnPacket, Rest4), {Username, Rest6} = parse_utf8_string(Rest5, bool(UsernameFlag)), {Passsword, <<>>} = parse_utf8_string(Rest6, bool(PasswordFlag)), @@ -191,10 +198,10 @@ parse_packet(#mqtt_packet_header{type = ?CONNACK}, {Properties, <<>>} = parse_properties(Rest, Ver), #mqtt_packet_connack{ack_flags = AckFlags, reason_code = ReasonCode, - properties = Properties}; + properties = Properties + }; -parse_packet(#mqtt_packet_header{type = ?PUBLISH, qos = QoS}, Bin, - #{version := Ver}) -> +parse_packet(#mqtt_packet_header{type = ?PUBLISH, qos = QoS}, Bin, #{version := Ver}) -> {TopicName, Rest} = parse_utf8_string(Bin), {PacketId, Rest1} = case QoS of ?QOS_0 -> {undefined, Rest}; @@ -202,14 +209,17 @@ parse_packet(#mqtt_packet_header{type = ?PUBLISH, qos = QoS}, Bin, end, (PacketId =/= undefined) andalso validate_packet_id(PacketId), {Properties, Payload} = parse_properties(Rest1, Ver), - {#mqtt_packet_publish{topic_name = TopicName, - packet_id = PacketId, - properties = Properties}, Payload}; + Publish = #mqtt_packet_publish{topic_name = TopicName, + packet_id = PacketId, + properties = Properties + }, + {Publish, Payload}; parse_packet(#mqtt_packet_header{type = PubAck}, <>, _Options) when ?PUBACK =< PubAck, PubAck =< ?PUBCOMP -> ok = validate_packet_id(PacketId), #mqtt_packet_puback{packet_id = PacketId, reason_code = 0}; + parse_packet(#mqtt_packet_header{type = PubAck}, <>, #{version := Ver = ?MQTT_PROTO_V5}) when ?PUBACK =< PubAck, PubAck =< ?PUBCOMP -> @@ -217,24 +227,29 @@ parse_packet(#mqtt_packet_header{type = PubAck}, <>} = parse_properties(Rest, Ver), #mqtt_packet_puback{packet_id = PacketId, reason_code = ReasonCode, - properties = Properties}; + properties = Properties + }; parse_packet(#mqtt_packet_header{type = ?SUBSCRIBE}, <>, #{version := Ver}) -> ok = validate_packet_id(PacketId), {Properties, Rest1} = parse_properties(Rest, Ver), TopicFilters = parse_topic_filters(subscribe, Rest1), + ok = validate_subqos([QoS || {_, #{qos := QoS}} <- TopicFilters]), #mqtt_packet_subscribe{packet_id = PacketId, properties = Properties, - topic_filters = TopicFilters}; + topic_filters = TopicFilters + }; parse_packet(#mqtt_packet_header{type = ?SUBACK}, <>, #{version := Ver}) -> ok = validate_packet_id(PacketId), {Properties, Rest1} = parse_properties(Rest, Ver), + ReasonCodes = parse_reason_codes(Rest1), #mqtt_packet_suback{packet_id = PacketId, properties = Properties, - reason_codes = parse_reason_codes(Rest1)}; + reason_codes = ReasonCodes + }; parse_packet(#mqtt_packet_header{type = ?UNSUBSCRIBE}, <>, #{version := Ver}) -> @@ -243,11 +258,13 @@ parse_packet(#mqtt_packet_header{type = ?UNSUBSCRIBE}, <>, _Options) -> ok = validate_packet_id(PacketId), #mqtt_packet_unsuback{packet_id = PacketId}; + parse_packet(#mqtt_packet_header{type = ?UNSUBACK}, <>, #{version := Ver}) -> ok = validate_packet_id(PacketId), @@ -255,13 +272,15 @@ parse_packet(#mqtt_packet_header{type = ?UNSUBACK}, <>, #{version := ?MQTT_PROTO_V5}) -> {Properties, <<>>} = parse_properties(Rest, ?MQTT_PROTO_V5), #mqtt_packet_disconnect{reason_code = ReasonCode, - properties = Properties}; + properties = Properties + }; parse_packet(#mqtt_packet_header{type = ?AUTH}, <>, #{version := ?MQTT_PROTO_V5}) -> @@ -275,16 +294,15 @@ parse_will_message(Packet = #mqtt_packet_connect{will_flag = true, {Payload, Rest2} = parse_binary_data(Rest1), {Packet#mqtt_packet_connect{will_props = Props, will_topic = Topic, - will_payload = Payload}, Rest2}; + will_payload = Payload + }, Rest2}; parse_will_message(Packet, Bin) -> {Packet, Bin}. +-compile({inline, [parse_packet_id/1]}). parse_packet_id(<>) -> {PacketId, Rest}. -validate_packet_id(0) -> error(bad_packet_id); -validate_packet_id(_) -> ok. - parse_properties(Bin, Ver) when Ver =/= ?MQTT_PROTO_V5 -> {undefined, Bin}; %% TODO: version mess? @@ -377,7 +395,7 @@ parse_variable_byte_integer(<<0:1, Len:7, Rest/binary>>, Multiplier, Value) -> {Value + Len * Multiplier, Rest}. parse_topic_filters(subscribe, Bin) -> - [{Topic, #{rh => Rh, rap => Rap, nl => Nl, qos => validate_subqos(QoS), rc => 0}} + [{Topic, #{rh => Rh, rap => Rap, nl => Nl, qos => QoS}} || <> <= Bin]; parse_topic_filters(unsubscribe, Bin) -> @@ -405,9 +423,23 @@ parse_binary_data(<>) -> %% Serialize MQTT Packet %%-------------------------------------------------------------------- +serialize_fun() -> serialize_fun(?DEFAULT_OPTIONS). + +serialize_fun(#mqtt_packet_connect{proto_ver = ProtoVer, properties = ConnProps}) -> + MaxSize = get_property('Maximum-Packet-Size', ConnProps, ?MAX_PACKET_SIZE), + serialize_fun(#{version => ProtoVer, max_size => MaxSize}); + +serialize_fun(#{version := Ver, max_size := MaxSize}) -> + fun(Packet) -> + IoData = serialize(Packet, Ver), + case is_too_large(IoData, MaxSize) of + true -> <<>>; + false -> IoData + end + end. + -spec(serialize(emqx_types:packet()) -> iodata()). -serialize(Packet) -> - serialize(Packet, ?MQTT_PROTO_V4). +serialize(Packet) -> serialize(Packet, ?MQTT_PROTO_V4). -spec(serialize(emqx_types:packet(), emqx_types:version()) -> iodata()). serialize(#mqtt_packet{header = Header, @@ -418,10 +450,10 @@ serialize(#mqtt_packet{header = Header, serialize(#mqtt_packet_header{type = Type, dup = Dup, qos = QoS, - retain = Retain}, VariableBin, PayloadBin) + retain = Retain + }, VariableBin, PayloadBin) when ?CONNECT =< Type andalso Type =< ?AUTH -> Len = iolist_size(VariableBin) + iolist_size(PayloadBin), - (Len =< ?MAX_PACKET_SIZE) orelse error(mqtt_frame_too_large), [<>, serialize_remaining_len(Len), VariableBin, PayloadBin]. @@ -485,10 +517,11 @@ serialize_variable(#mqtt_packet_puback{packet_id = PacketId}, Ver) <>; serialize_variable(#mqtt_packet_puback{packet_id = PacketId, reason_code = ReasonCode, - properties = Properties}, - ?MQTT_PROTO_V5) -> + properties = Properties + }, + Ver = ?MQTT_PROTO_V5) -> [<>, ReasonCode, - serialize_properties(Properties, ?MQTT_PROTO_V5)]; + serialize_properties(Properties, Ver)]; serialize_variable(#mqtt_packet_subscribe{packet_id = PacketId, properties = Properties, @@ -616,8 +649,7 @@ serialize_property('Shared-Subscription-Available', Val) -> serialize_topic_filters(subscribe, TopicFilters, ?MQTT_PROTO_V5) -> << <<(serialize_utf8_string(Topic))/binary, ?RESERVED:2, Rh:2, (flag(Rap)):1,(flag(Nl)):1, QoS:2 >> - || {Topic, #{rh := Rh, rap := Rap, nl := Nl, qos := QoS}} - <- TopicFilters >>; + || {Topic, #{rh := Rh, rap := Rap, nl := Nl, qos := QoS}} <- TopicFilters >>; serialize_topic_filters(subscribe, TopicFilters, _Ver) -> << <<(serialize_utf8_string(Topic))/binary, ?RESERVED:6, QoS:2>> @@ -658,6 +690,16 @@ serialize_variable_byte_integer(N) when N =< ?LOWBITS -> serialize_variable_byte_integer(N) -> <<1:1, (N rem ?HIGHBIT):7, (serialize_variable_byte_integer(N div ?HIGHBIT))/binary>>. +%% Is the frame too large? +-spec(is_too_large(iodata(), pos_integer()) -> boolean()). +is_too_large(IoData, MaxSize) -> + iolist_size(IoData) >= MaxSize. + +get_property(_Key, undefined, Default) -> + Default; +get_property(Key, Props, Default) -> + maps:get(Key, Props, Default). + %% Validate header if sctrict mode. See: mqtt-v5.0: 2.1.3 Flags validate_header(?CONNECT, 0, 0, 0) -> ok; validate_header(?CONNACK, 0, 0, 0) -> ok; @@ -678,9 +720,12 @@ validate_header(?DISCONNECT, 0, 0, 0) -> ok; validate_header(?AUTH, 0, 0, 0) -> ok; validate_header(_Type, _Dup, _QoS, _Rt) -> error(bad_frame_header). -validate_subqos(QoS) when ?QOS_0 =< QoS, QoS =< ?QOS_2 -> - QoS; -validate_subqos(_) -> error(bad_subqos). +validate_packet_id(0) -> error(bad_packet_id); +validate_packet_id(_) -> ok. + +validate_subqos([3|_]) -> error(bad_subqos); +validate_subqos([_|T]) -> validate_subqos(T); +validate_subqos([]) -> ok. bool(0) -> false; bool(1) -> true. @@ -695,3 +740,4 @@ fixqos(?PUBREL, 0) -> 1; fixqos(?SUBSCRIBE, 0) -> 1; fixqos(?UNSUBSCRIBE, 0) -> 1; fixqos(_Type, QoS) -> QoS. + diff --git a/src/emqx_session.erl b/src/emqx_session.erl index 7b3de773c..a945e23f7 100644 --- a/src/emqx_session.erl +++ b/src/emqx_session.erl @@ -117,7 +117,7 @@ %% Enqueue Count enqueue_cnt :: non_neg_integer(), %% Created at - created_at :: erlang:timestamp() + created_at :: pos_integer() }). -opaque(session() :: #session{}). @@ -125,22 +125,49 @@ -type(publish() :: {publish, emqx_types:packet_id(), emqx_types:message()}). -define(DEFAULT_BATCH_N, 1000). --define(ATTR_KEYS, [max_inflight, max_mqueue, retry_interval, - max_awaiting_rel, await_rel_timeout, created_at]). --define(INFO_KEYS, [subscriptions, max_subscriptions, upgrade_qos, inflight, - max_inflight, retry_interval, mqueue_len, max_mqueue, - mqueue_dropped, next_pkt_id, awaiting_rel, max_awaiting_rel, - await_rel_timeout, created_at]). --define(STATS_KEYS, [subscriptions_cnt, max_subscriptions, inflight, max_inflight, - mqueue_len, max_mqueue, mqueue_dropped, awaiting_rel, - max_awaiting_rel, enqueue_cnt]). + +-define(ATTR_KEYS, [inflight_max, + mqueue_max, + retry_interval, + awaiting_rel_max, + await_rel_timeout, + created_at + ]). + +-define(INFO_KEYS, [subscriptions, + subscriptions_max, + upgrade_qos, + inflight, + inflight_max, + retry_interval, + mqueue_len, + mqueue_max, + mqueue_dropped, + next_pkt_id, + awaiting_rel, + awaiting_rel_max, + await_rel_timeout, + created_at + ]). + +-define(STATS_KEYS, [subscriptions_cnt, + subscriptions_max, + inflight, + inflight_max, + mqueue_len, + mqueue_max, + mqueue_dropped, + awaiting_rel, + awaiting_rel_max, + enqueue_cnt + ]). %%-------------------------------------------------------------------- %% Init a session %%-------------------------------------------------------------------- %% @doc Init a session. --spec(init(emqx_types:client_info(), emqx_types:conninfo()) -> session()). +-spec(init(emqx_types:clientinfo(), emqx_types:conninfo()) -> session()). init(#{zone := Zone}, #{receive_maximum := MaxInflight}) -> #session{max_subscriptions = get_env(Zone, max_subscriptions, 0), subscriptions = #{}, @@ -153,7 +180,7 @@ init(#{zone := Zone}, #{receive_maximum := MaxInflight}) -> max_awaiting_rel = get_env(Zone, max_awaiting_rel, 100), await_rel_timeout = get_env(Zone, await_rel_timeout, 3600*1000), enqueue_cnt = 0, - created_at = os:timestamp() + created_at = erlang:system_time(second) }. init_mqueue(Zone) -> @@ -183,19 +210,19 @@ info(subscriptions, #session{subscriptions = Subs}) -> Subs; info(subscriptions_cnt, #session{subscriptions = Subs}) -> maps:size(Subs); -info(max_subscriptions, #session{max_subscriptions = MaxSubs}) -> +info(subscriptions_max, #session{max_subscriptions = MaxSubs}) -> MaxSubs; info(upgrade_qos, #session{upgrade_qos = UpgradeQoS}) -> UpgradeQoS; info(inflight, #session{inflight = Inflight}) -> emqx_inflight:size(Inflight); -info(max_inflight, #session{inflight = Inflight}) -> +info(inflight_max, #session{inflight = Inflight}) -> emqx_inflight:max_size(Inflight); info(retry_interval, #session{retry_interval = Interval}) -> Interval; info(mqueue_len, #session{mqueue = MQueue}) -> emqx_mqueue:len(MQueue); -info(max_mqueue, #session{mqueue = MQueue}) -> +info(mqueue_max, #session{mqueue = MQueue}) -> emqx_mqueue:max_len(MQueue); info(mqueue_dropped, #session{mqueue = MQueue}) -> emqx_mqueue:dropped(MQueue); @@ -203,7 +230,7 @@ info(next_pkt_id, #session{next_pkt_id = PacketId}) -> PacketId; info(awaiting_rel, #session{awaiting_rel = AwaitingRel}) -> maps:size(AwaitingRel); -info(max_awaiting_rel, #session{max_awaiting_rel = MaxAwaitingRel}) -> +info(awaiting_rel_max, #session{max_awaiting_rel = MaxAwaitingRel}) -> MaxAwaitingRel; info(await_rel_timeout, #session{await_rel_timeout = Timeout}) -> Timeout; @@ -224,14 +251,14 @@ takeover(#session{subscriptions = Subs}) -> ok = emqx_broker:unsubscribe(TopicFilter) end, maps:to_list(Subs)). --spec(resume(emqx_types:client_id(), session()) -> ok). +-spec(resume(emqx_types:clientid(), session()) -> ok). resume(ClientId, #session{subscriptions = Subs}) -> %% 1. Subscribe again. lists:foreach(fun({TopicFilter, SubOpts}) -> ok = emqx_broker:subscribe(TopicFilter, ClientId, SubOpts) end, maps:to_list(Subs)). %% 2. Run hooks. - %% ok = emqx_hooks:run('session.resumed', [#{client_id => ClientId}, attrs(Session)]), + %% ok = emqx_hooks:run('session.resumed', [#{clientid => ClientId}, attrs(Session)]), %% TODO: 3. Redeliver: Replay delivery and Dequeue pending messages %%Session. @@ -252,7 +279,7 @@ redeliver(Session = #session{inflight = Inflight}) -> %% Client -> Broker: SUBSCRIBE %%-------------------------------------------------------------------- --spec(subscribe(emqx_types:client_info(), emqx_types:topic(), emqx_types:subopts(), session()) +-spec(subscribe(emqx_types:clientinfo(), emqx_types:topic(), emqx_types:subopts(), session()) -> {ok, session()} | {error, emqx_types:reason_code()}). subscribe(ClientInfo, TopicFilter, SubOpts, Session = #session{subscriptions = Subs}) -> case is_subscriptions_full(Session) @@ -269,7 +296,7 @@ is_subscriptions_full(#session{max_subscriptions = MaxLimit, maps:size(Subs) >= MaxLimit. -compile({inline, [do_subscribe/4]}). -do_subscribe(Client = #{client_id := ClientId}, TopicFilter, SubOpts, +do_subscribe(Client = #{clientid := ClientId}, TopicFilter, SubOpts, Session = #session{subscriptions = Subs}) -> case IsNew = (not maps:is_key(TopicFilter, Subs)) of true -> @@ -285,7 +312,7 @@ do_subscribe(Client = #{client_id := ClientId}, TopicFilter, SubOpts, %% Client -> Broker: UNSUBSCRIBE %%-------------------------------------------------------------------- --spec(unsubscribe(emqx_types:client_info(), emqx_types:topic(), session()) +-spec(unsubscribe(emqx_types:clientinfo(), emqx_types:topic(), session()) -> {ok, session()} | {error, emqx_types:reason_code()}). unsubscribe(ClientInfo, TopicFilter, Session = #session{subscriptions = Subs}) -> case maps:find(TopicFilter, Subs) of diff --git a/src/emqx_stats.erl b/src/emqx_stats.erl index d3454c6f2..b6a61d458 100644 --- a/src/emqx_stats.erl +++ b/src/emqx_stats.erl @@ -37,7 +37,9 @@ , setstat/3 , statsfun/1 , statsfun/2 - , update_interval/2 + ]). + +-export([ update_interval/2 , update_interval/3 , cancel_update/1 ]). diff --git a/src/emqx_ws_connection.erl b/src/emqx_ws_connection.erl index 7207669e3..267dbad60 100644 --- a/src/emqx_ws_connection.erl +++ b/src/emqx_ws_connection.erl @@ -25,9 +25,7 @@ -logger_header("[WsConnection]"). -export([ info/1 - , attrs/1 , stats/1 - , state/1 ]). -export([call/2]). @@ -40,17 +38,17 @@ , terminate/3 ]). --record(ws_connection, { +-record(state, { %% Peername of the ws connection. peername :: emqx_types:peername(), %% Sockname of the ws connection sockname :: emqx_types:peername(), - %% FSM state - fsm_state :: idle | connected | disconnected, + %% Conn state + conn_state :: idle | connected | disconnected, %% Parser State parse_state :: emqx_frame:parse_state(), %% Serialize function - serialize :: fun((emqx_types:packet()) -> iodata()), + serialize :: emqx_frame:serialize_fun(), %% Channel State chan_state :: emqx_channel:channel(), %% Out Pending Packets @@ -59,10 +57,9 @@ stop_reason :: term() }). --type(ws_connection() :: #ws_connection{}). +-type(state() :: #state{}). -define(INFO_KEYS, [socktype, peername, sockname, active_state]). --define(ATTR_KEYS, [socktype, peername, sockname]). -define(SOCK_STATS, [recv_oct, recv_cnt, send_oct, send_cnt]). -define(CONN_STATS, [recv_pkt, recv_msg, send_pkt, send_msg]). @@ -70,53 +67,35 @@ %% API %%-------------------------------------------------------------------- --spec(info(pid()|ws_connection()) -> emqx_types:infos()). +-spec(info(pid()|state()) -> emqx_types:infos()). info(WsPid) when is_pid(WsPid) -> call(WsPid, info); -info(WsConn = #ws_connection{chan_state = ChanState}) -> +info(WsConn = #state{chan_state = ChanState}) -> ChanInfo = emqx_channel:info(ChanState), SockInfo = maps:from_list(info(?INFO_KEYS, WsConn)), maps:merge(ChanInfo, #{sockinfo => SockInfo}). info(Keys, WsConn) when is_list(Keys) -> [{Key, info(Key, WsConn)} || Key <- Keys]; -info(socktype, #ws_connection{}) -> - websocket; -info(peername, #ws_connection{peername = Peername}) -> +info(socktype, _State) -> + ws; +info(peername, #state{peername = Peername}) -> Peername; -info(sockname, #ws_connection{sockname = Sockname}) -> +info(sockname, #state{sockname = Sockname}) -> Sockname; -info(active_state, #ws_connection{}) -> +info(active_state, _State) -> running; -info(chan_state, #ws_connection{chan_state = ChanState}) -> +info(chan_state, #state{chan_state = ChanState}) -> emqx_channel:info(ChanState). --spec(attrs(pid()|ws_connection()) -> emqx_types:attrs()). -attrs(WsPid) when is_pid(WsPid) -> - call(WsPid, attrs); -attrs(WsConn = #ws_connection{chan_state = ChanState}) -> - ChanAttrs = emqx_channel:attrs(ChanState), - SockAttrs = maps:from_list(info(?ATTR_KEYS, WsConn)), - maps:merge(ChanAttrs, #{sockinfo => SockAttrs}). - --spec(stats(pid()|ws_connection()) -> emqx_types:stats()). +-spec(stats(pid()|state()) -> emqx_types:stats()). stats(WsPid) when is_pid(WsPid) -> call(WsPid, stats); -stats(#ws_connection{chan_state = ChanState}) -> - ProcStats = emqx_misc:proc_stats(), - SockStats = wsock_stats(), - ConnStats = conn_stats(), - ChanStats = emqx_channel:stats(ChanState), - lists:append([ProcStats, SockStats, ConnStats, ChanStats]). - -wsock_stats() -> - [{Key, emqx_pd:get_counter(Key)} || Key <- ?SOCK_STATS]. - -conn_stats() -> - [{Name, emqx_pd:get_counter(Name)} || Name <- ?CONN_STATS]. - --spec(state(pid()) -> ws_connection()). -state(WsPid) -> call(WsPid, state). +stats(#state{chan_state = ChanState}) -> + [{sock_stats, emqx_pd:get_counters(?SOCK_STATS)}, + {conn_stats, emqx_pd:get_counters(?CONN_STATS)}, + {chan_stats, emqx_channel:stats(ChanState)}, + {proc_stats, emqx_misc:proc_stats()}]. %% kick|discard|takeover -spec(call(pid(), Req :: term()) -> Reply :: term()). @@ -177,44 +156,46 @@ websocket_init([Req, Opts]) -> [Error, Reason]), undefined end, - ChanState = emqx_channel:init(#{peername => Peername, - sockname => Sockname, - peercert => Peercert, - ws_cookie => WsCookie, - protocol => mqtt, - conn_mod => ?MODULE - }, Opts), + ConnInfo = #{socktype => ws, + peername => Peername, + sockname => Sockname, + peercert => Peercert, + ws_cookie => WsCookie, + conn_mod => ?MODULE + }, Zone = proplists:get_value(zone, Opts), - MaxSize = emqx_zone:get_env(Zone, max_packet_size, ?MAX_PACKET_SIZE), - ParseState = emqx_frame:initial_parse_state(#{max_size => MaxSize}), + FrameOpts = emqx_zone:frame_options(Zone), + ParseState = emqx_frame:initial_parse_state(FrameOpts), + Serialize = emqx_frame:serialize_fun(), + ChanState = emqx_channel:init(ConnInfo, Opts), emqx_logger:set_metadata_peername(esockd_net:format(Peername)), - {ok, #ws_connection{peername = Peername, - sockname = Sockname, - fsm_state = idle, - parse_state = ParseState, - chan_state = ChanState, - pendings = [], - serialize = serialize_fun(?MQTT_PROTO_V5, undefined)}}. + {ok, #state{peername = Peername, + sockname = Sockname, + conn_state = idle, + parse_state = ParseState, + serialize = Serialize, + chan_state = ChanState, + pendings = [] + }}. websocket_handle({binary, Data}, State) when is_list(Data) -> websocket_handle({binary, iolist_to_binary(Data)}, State); -websocket_handle({binary, Data}, State = #ws_connection{chan_state = ChanState}) -> +websocket_handle({binary, Data}, State = #state{chan_state = ChanState}) -> ?LOG(debug, "RECV ~p", [Data]), Oct = iolist_size(Data), ok = inc_recv_stats(1, Oct), NChanState = emqx_channel:received(Oct, ChanState), - NState = State#ws_connection{chan_state = NChanState}, + NState = State#state{chan_state = NChanState}, process_incoming(Data, NState); %% Pings should be replied with pongs, cowboy does it automatically %% Pongs can be safely ignored. Clause here simply prevents crash. -websocket_handle(Frame, State) - when Frame =:= ping; Frame =:= pong -> +websocket_handle(Frame, State) when Frame =:= ping; Frame =:= pong -> {ok, State}; -websocket_handle({FrameType, _}, State) - when FrameType =:= ping; FrameType =:= pong -> +websocket_handle({FrameType, _}, State) when FrameType =:= ping; + FrameType =:= pong -> {ok, State}; websocket_handle({FrameType, _}, State) -> @@ -225,10 +206,6 @@ websocket_info({call, From, info}, State) -> gen_server:reply(From, info(State)), {ok, State}; -websocket_info({call, From, attrs}, State) -> - gen_server:reply(From, attrs(State)), - {ok, State}; - websocket_info({call, From, stats}, State) -> gen_server:reply(From, stats(State)), {ok, State}; @@ -237,63 +214,43 @@ websocket_info({call, From, state}, State) -> gen_server:reply(From, State), {ok, State}; -websocket_info({call, From, Req}, State = #ws_connection{chan_state = ChanState}) -> +websocket_info({call, From, Req}, State = #state{chan_state = ChanState}) -> case emqx_channel:handle_call(Req, ChanState) of {ok, Reply, NChanState} -> _ = gen_server:reply(From, Reply), - {ok, State#ws_connection{chan_state = NChanState}}; + {ok, State#state{chan_state = NChanState}}; {stop, Reason, Reply, NChanState} -> _ = gen_server:reply(From, Reply), - stop(Reason, State#ws_connection{chan_state = NChanState}) + stop(Reason, State#state{chan_state = NChanState}) end; -websocket_info({cast, Msg}, State = #ws_connection{chan_state = ChanState}) -> - case emqx_channel:handle_cast(Msg, ChanState) of +websocket_info({cast, Msg}, State = #state{chan_state = ChanState}) -> + case emqx_channel:handle_info(Msg, ChanState) of + ok -> {ok, State}; {ok, NChanState} -> - {ok, State#ws_connection{chan_state = NChanState}}; + {ok, State#state{chan_state = NChanState}}; {stop, Reason, NChanState} -> - stop(Reason, State#ws_connection{chan_state = NChanState}) + stop(Reason, State#state{chan_state = NChanState}) end; -websocket_info({incoming, {error, Reason}}, State = #ws_connection{fsm_state = idle}) -> - stop({shutdown, Reason}, State); - -websocket_info({incoming, {error, Reason}}, State = #ws_connection{fsm_state = connected, chan_state = ChanState}) -> - case emqx_channel:handle_out({disconnect, emqx_reason_codes:mqtt_frame_error(Reason)}, ChanState) of - {wait_session_expire, _, NChanState} -> - ?LOG(debug, "Disconnect and wait for session to expire due to ~p", [Reason]), - disconnected(State#ws_connection{chan_state= NChanState}); - {wait_session_expire, _, OutPackets, NChanState} -> - ?LOG(debug, "Disconnect and wait for session to expire due to ~p", [Reason]), - disconnected(enqueue(OutPackets, State#ws_connection{chan_state = NChanState})) - end; - -websocket_info({incoming, {error, _Reason}}, State = #ws_connection{fsm_state = disconnected}) -> - reply(State); - -websocket_info({incoming, Packet = ?CONNECT_PACKET(ConnPkt)}, - State = #ws_connection{fsm_state = idle}) -> - #mqtt_packet_connect{proto_ver = ProtoVer, properties = Properties} = ConnPkt, - MaxPacketSize = emqx_mqtt_props:get('Maximum-Packet-Size', Properties, undefined), - NState = State#ws_connection{serialize = serialize_fun(ProtoVer, MaxPacketSize)}, +websocket_info({incoming, Packet = ?CONNECT_PACKET(ConnPkt)}, State) -> + NState = State#state{serialize = emqx_frame:serialize_fun(ConnPkt)}, handle_incoming(Packet, fun connected/1, NState); -websocket_info({incoming, Packet}, State = #ws_connection{fsm_state = idle}) -> - ?LOG(warning, "Unexpected incoming: ~p", [Packet]), - stop(unexpected_incoming_packet, State); - -websocket_info({incoming, Packet}, State = #ws_connection{fsm_state = connected}) - when is_record(Packet, mqtt_packet) -> +websocket_info({incoming, Packet}, State) when is_record(Packet, mqtt_packet) -> handle_incoming(Packet, fun reply/1, State); +websocket_info({incoming, FrameError = {frame_error, _Reason}}, State) -> + handle_incoming(FrameError, State); + websocket_info(Deliver = {deliver, _Topic, _Msg}, - State = #ws_connection{chan_state = ChanState}) -> + State = #state{chan_state = ChanState}) -> Delivers = emqx_misc:drain_deliver([Deliver]), case emqx_channel:handle_out({deliver, Delivers}, ChanState) of {ok, NChanState} -> - reply(State#ws_connection{chan_state = NChanState}); + reply(State#state{chan_state = NChanState}); {ok, Packets, NChanState} -> - reply(enqueue(Packets, State#ws_connection{chan_state = NChanState})) + reply(enqueue(Packets, State#state{chan_state = NChanState})) end; websocket_info({timeout, TRef, keepalive}, State) when is_reference(TRef) -> @@ -312,15 +269,15 @@ websocket_info({shutdown, Reason}, State) -> websocket_info({stop, Reason}, State) -> stop(Reason, State); -websocket_info(Info, State = #ws_connection{chan_state = ChanState}) -> +websocket_info(Info, State = #state{chan_state = ChanState}) -> case emqx_channel:handle_info(Info, ChanState) of {ok, NChanState} -> - {ok, State#ws_connection{chan_state = NChanState}}; + {ok, State#state{chan_state = NChanState}}; {stop, Reason, NChanState} -> - stop(Reason, State#ws_connection{chan_state = NChanState}) + stop(Reason, State#state{chan_state = NChanState}) end. -terminate(SockError, _Req, #ws_connection{chan_state = ChanState, +terminate(SockError, _Req, #state{chan_state = ChanState, stop_reason = Reason}) -> ?LOG(debug, "Terminated for ~p, sockerror: ~p", [Reason, SockError]), emqx_channel:terminate(Reason, ChanState). @@ -328,31 +285,37 @@ terminate(SockError, _Req, #ws_connection{chan_state = ChanState, %%-------------------------------------------------------------------- %% Connected callback -connected(State = #ws_connection{chan_state = ChanState}) -> - ok = emqx_channel:handle_cast({register, attrs(State), stats(State)}, ChanState), - reply(State#ws_connection{fsm_state = connected}). +connected(State = #state{chan_state = ChanState}) -> + ChanAttrs = emqx_channel:attrs(ChanState), + SockAttrs = #{active_state => running}, + Attrs = maps:merge(ChanAttrs, #{sockinfo => SockAttrs}), + ok = emqx_channel:handle_info({register, Attrs, stats(State)}, ChanState), + reply(State#state{conn_state = connected}). %%-------------------------------------------------------------------- -%% Disconnected callback +%% Close -disconnected(State) -> - reply(State#ws_connection{fsm_state = disconnected}). +close(Reason, State) -> + ?LOG(warning, "Closed for ~p", [Reason]), + reply(State#state{conn_state = disconnected}). %%-------------------------------------------------------------------- %% Handle timeout -handle_timeout(TRef, Msg, State = #ws_connection{chan_state = ChanState}) -> +handle_timeout(TRef, Msg, State = #state{chan_state = ChanState}) -> case emqx_channel:handle_timeout(TRef, Msg, ChanState) of {ok, NChanState} -> - {ok, State#ws_connection{chan_state = NChanState}}; + {ok, State#state{chan_state = NChanState}}; {ok, Packets, NChanState} -> - NState = State#ws_connection{chan_state = NChanState}, + NState = State#state{chan_state = NChanState}, reply(enqueue(Packets, NState)); - {wait_session_expire, Reason, NChanState} -> - ?LOG(debug, "Disconnect and wait for session to expire due to ~p", [Reason]), - disconnected(State#ws_connection{chan_state = NChanState}); + {close, Reason, NChanState} -> + close(Reason, State#state{chan_state = NChanState}); + {close, Reason, OutPackets, NChanState} -> + NState = State#state{chan_state= NChanState}, + close(Reason, enqueue(OutPackets, NState)); {stop, Reason, NChanState} -> - stop(Reason, State#ws_connection{chan_state = NChanState}) + stop(Reason, State#state{chan_state = NChanState}) end. %%-------------------------------------------------------------------- @@ -361,20 +324,18 @@ handle_timeout(TRef, Msg, State = #ws_connection{chan_state = ChanState}) -> process_incoming(<<>>, State) -> {ok, State}; -process_incoming(Data, State = #ws_connection{parse_state = ParseState}) -> +process_incoming(Data, State = #state{parse_state = ParseState}) -> try emqx_frame:parse(Data, ParseState) of {more, NParseState} -> - {ok, State#ws_connection{parse_state = NParseState}}; + {ok, State#state{parse_state = NParseState}}; {ok, Packet, Rest, NParseState} -> self() ! {incoming, Packet}, - process_incoming(Rest, State#ws_connection{parse_state = NParseState}); - {error, Reason} -> - self() ! {incoming, {error, Reason}}, - {ok, State} + process_incoming(Rest, State#state{parse_state = NParseState}) catch error:Reason:Stk -> - ?LOG(error, "~nParse failed for ~p~nStacktrace: ~p~nFrame data: ~p", [Reason, Stk, Data]), - self() ! {incoming, {error, Reason}}, + ?LOG(error, "~nParse failed for ~p~nStacktrace: ~p~nFrame data: ~p", + [Reason, Stk, Data]), + self() ! {incoming, {frame_error, Reason}}, {ok, State} end. @@ -382,55 +343,59 @@ process_incoming(Data, State = #ws_connection{parse_state = ParseState}) -> %% Handle incoming packets handle_incoming(Packet = ?PACKET(Type), SuccFun, - State = #ws_connection{chan_state = ChanState}) -> + State = #state{chan_state = ChanState}) -> _ = inc_incoming_stats(Type), - ok = emqx_metrics:inc_recv(Packet), + _ = emqx_metrics:inc_recv(Packet), ?LOG(debug, "RECV ~s", [emqx_packet:format(Packet)]), case emqx_channel:handle_in(Packet, ChanState) of {ok, NChanState} -> - SuccFun(State#ws_connection{chan_state= NChanState}); + SuccFun(State#state{chan_state= NChanState}); {ok, OutPackets, NChanState} -> - NState = State#ws_connection{chan_state= NChanState}, + NState = State#state{chan_state= NChanState}, SuccFun(enqueue(OutPackets, NState)); - {wait_session_expire, Reason, NChanState} -> - ?LOG(debug, "Disconnect and wait for session to expire due to ~p", [Reason]), - disconnected(State#ws_connection{chan_state = NChanState}); - {wait_session_expire, Reason, OutPackets, NChanState} -> - ?LOG(debug, "Disconnect and wait for session to expire due to ~p", [Reason]), - disconnected(enqueue(OutPackets, State#ws_connection{chan_state = NChanState})); + {close, Reason, NChanState} -> + close(Reason, State#state{chan_state = NChanState}); + {close, Reason, OutPackets, NChanState} -> + NState = State#state{chan_state= NChanState}, + close(Reason, enqueue(OutPackets, NState)); {stop, Reason, NChanState} -> - stop(Reason, State#ws_connection{chan_state = NChanState}); + stop(Reason, State#state{chan_state = NChanState}); {stop, Reason, OutPackets, NChanState} -> - NState = State#ws_connection{chan_state= NChanState}, + NState = State#state{chan_state= NChanState}, + stop(Reason, enqueue(OutPackets, NState)) + end. + +handle_incoming(FrameError = {frame_error, _Reason}, + State = #state{chan_state = ChanState}) -> + case emqx_channel:handle_in(FrameError, ChanState) of + {stop, Reason, NChanState} -> + stop(Reason, State#state{chan_state = NChanState}); + {stop, Reason, OutPackets, NChanState} -> + NState = State#state{chan_state = NChanState}, stop(Reason, enqueue(OutPackets, NState)) end. %%-------------------------------------------------------------------- %% Handle outgoing packets -handle_outgoing(Packets, State = #ws_connection{serialize = Serialize, - chan_state = ChanState}) -> - Data = lists:map(Serialize, Packets), - Oct = iolist_size(Data), +handle_outgoing(Packets, State = #state{chan_state = ChanState}) -> + IoData = lists:map(serialize_and_inc_stats_fun(State), Packets), + Oct = iolist_size(IoData), ok = inc_sent_stats(length(Packets), Oct), NChanState = emqx_channel:sent(Oct, ChanState), - {{binary, Data}, State#ws_connection{chan_state = NChanState}}. + {{binary, IoData}, State#state{chan_state = NChanState}}. -%%-------------------------------------------------------------------- -%% Serialize fun - -serialize_fun(ProtoVer, MaxPacketSize) -> +%% TODO: Duplicated with emqx_channel:serialize_and_inc_stats_fun/1 +serialize_and_inc_stats_fun(#state{serialize = Serialize}) -> fun(Packet = ?PACKET(Type)) -> - IoData = emqx_frame:serialize(Packet, ProtoVer), - case Type =/= ?PUBLISH orelse MaxPacketSize =:= undefined orelse iolist_size(IoData) =< MaxPacketSize of - true -> - ?LOG(debug, "SEND ~s", [emqx_packet:format(Packet)]), - _ = inc_outgoing_stats(Type), - _ = emqx_metrics:inc_sent(Packet), - IoData; - false -> - ?LOG(warning, "DROP ~s due to oversize packet size", [emqx_packet:format(Packet)]), - <<"">> + case Serialize(Packet) of + <<>> -> ?LOG(warning, "~s is discarded due to the frame is too large!", + [emqx_packet:format(Packet)]), + <<>>; + Data -> _ = inc_outgoing_stats(Type), + _ = emqx_metrics:inc_sent(Packet), + ?LOG(debug, "SEND ~s", [emqx_packet:format(Packet)]), + Data end end. @@ -469,21 +434,21 @@ inc_sent_stats(Cnt, Oct) -> -compile({inline, [reply/1]}). -reply(State = #ws_connection{pendings = []}) -> +reply(State = #state{pendings = []}) -> {ok, State}; -reply(State = #ws_connection{pendings = Pendings}) -> +reply(State = #state{pendings = Pendings}) -> {Reply, NState} = handle_outgoing(Pendings, State), - {reply, Reply, NState#ws_connection{pendings = []}}. + {reply, Reply, NState#state{pendings = []}}. -stop(Reason, State = #ws_connection{pendings = []}) -> - {stop, State#ws_connection{stop_reason = Reason}}; -stop(Reason, State = #ws_connection{pendings = Pendings}) -> +stop(Reason, State = #state{pendings = []}) -> + {stop, State#state{stop_reason = Reason}}; +stop(Reason, State = #state{pendings = Pendings}) -> {Reply, State1} = handle_outgoing(Pendings, State), - State2 = State1#ws_connection{pendings = [], stop_reason = Reason}, + State2 = State1#state{pendings = [], stop_reason = Reason}, {reply, [Reply, close], State2}. enqueue(Packet, State) when is_record(Packet, mqtt_packet) -> enqueue([Packet], State); -enqueue(Packets, State = #ws_connection{pendings = Pendings}) -> - State#ws_connection{pendings = lists:append(Pendings, Packets)}. +enqueue(Packets, State = #state{pendings = Pendings}) -> + State#state{pendings = lists:append(Pendings, Packets)}. diff --git a/src/emqx_zone.erl b/src/emqx_zone.erl index 9d45c64e1..90d56eb0e 100644 --- a/src/emqx_zone.erl +++ b/src/emqx_zone.erl @@ -19,6 +19,7 @@ -behaviour(gen_server). -include("emqx.hrl"). +-include("emqx_mqtt.hrl"). -include("logger.hrl"). -include("types.hrl"). @@ -27,7 +28,10 @@ %% APIs -export([start_link/0, stop/0]). --export([ use_username_as_clientid/1 +-export([ frame_options/1 + , mqtt_strict_mode/1 + , max_packet_size/1 + , use_username_as_clientid/1 , enable_stats/1 , enable_acl/1 , enable_ban/1 @@ -40,6 +44,8 @@ , force_shutdown_policy/1 ]). +-export([check_oom/2]). + -export([ get_env/2 , get_env/3 , set_env/3 @@ -76,6 +82,20 @@ start_link() -> gen_server:start_link({local, ?SERVER}, ?MODULE, [], []). +-spec(frame_options(zone()) -> emqx_frame:options()). +frame_options(Zone) -> + #{strict_mode => mqtt_strict_mode(Zone), + max_size => max_packet_size(Zone) + }. + +-spec(mqtt_strict_mode(zone()) -> boolean()). +mqtt_strict_mode(Zone) -> + get_env(Zone, mqtt_strict_mode, false). + +-spec(max_packet_size(zone()) -> integer()). +max_packet_size(Zone) -> + get_env(Zone, max_packet_size, ?MAX_PACKET_SIZE). + -spec(use_username_as_clientid(zone()) -> boolean()). use_username_as_clientid(Zone) -> get_env(Zone, use_username_as_clientid, false). @@ -120,6 +140,19 @@ force_gc_policy(Zone) -> force_shutdown_policy(Zone) -> get_env(Zone, force_shutdown_policy). +-spec(check_oom(zone(), fun()) -> ok | term()). +check_oom(Zone, Action) -> + case emqx_zone:force_shutdown_policy(Zone) of + undefined -> ok; + Policy -> do_check_oom(emqx_oom:init(Policy), Action) + end. + +do_check_oom(OomPolicy, Action) -> + case emqx_oom:check(OomPolicy) of + ok -> ok; + Shutdown -> Action(Shutdown) + end. + -spec(get_env(maybe(zone()), atom()) -> maybe(term())). get_env(undefined, Key) -> emqx:get_env(Key); get_env(Zone, Key) -> diff --git a/test/emqx_channel_SUITE.erl b/test/emqx_channel_SUITE.erl index fdd7c1b56..d6bf61b58 100644 --- a/test/emqx_channel_SUITE.erl +++ b/test/emqx_channel_SUITE.erl @@ -28,6 +28,21 @@ -include("emqx_mqtt.hrl"). -include_lib("eunit/include/eunit.hrl"). +-define(DEFAULT_CONNINFO, + #{peername => {{127,0,0,1}, 3456}, + sockname => {{127,0,0,1}, 1883}, + conn_mod => emqx_connection, + proto_name => <<"MQTT">>, + proto_ver => ?MQTT_PROTO_V5, + clean_start => true, + keepalive => 30, + clientid => <<"clientid">>, + username => <<"username">>, + conn_props => #{}, + receive_maximum => 100, + expiry_interval => 0 + }). + all() -> emqx_ct:all(?MODULE). init_per_suite(Config) -> @@ -50,7 +65,7 @@ t_handle_connect(_) -> clean_start = true, keepalive = 30, properties = #{}, - client_id = <<"clientid">>, + clientid = <<"clientid">>, username = <<"username">>, password = <<"passwd">> }, @@ -58,20 +73,21 @@ t_handle_connect(_) -> fun(Channel) -> {ok, ?CONNACK_PACKET(?RC_SUCCESS), Channel1} = handle_in(?CONNECT_PACKET(ConnPkt), Channel), - #{client_id := ClientId, username := Username} + #{clientid := ClientId, username := Username} = emqx_channel:info(client, Channel1), ?assertEqual(<<"clientid">>, ClientId), ?assertEqual(<<"username">>, Username) end). -t_handle_publish_qos0(_) -> +t_handle_in_publish_qos0(_) -> with_channel( fun(Channel) -> Publish = ?PUBLISH_PACKET(?QOS_0, <<"topic">>, undefined, <<"payload">>), - {ok, Channel} = handle_in(Publish, Channel) + {ok, Channel1} = handle_in(Publish, Channel), + ?assertEqual(#{publish_in => 1}, emqx_channel:info(pub_stats, Channel1)) end). -t_handle_publish_qos1(_) -> +t_handle_in_publish_qos1(_) -> with_channel( fun(Channel) -> Publish = ?PUBLISH_PACKET(?QOS_1, <<"topic">>, 1, <<"payload">>), @@ -91,30 +107,34 @@ t_handle_publish_qos2(_) -> ?assertEqual(2, AwaitingRel) end). -t_handle_puback(_) -> +t_handle_in_puback(_) -> with_channel( fun(Channel) -> - {ok, Channel} = handle_in(?PUBACK_PACKET(1, ?RC_SUCCESS), Channel) + {ok, Channel1} = handle_in(?PUBACK_PACKET(1, ?RC_SUCCESS), Channel), + ?assertEqual(#{puback_in => 1}, emqx_channel:info(pub_stats, Channel1)) end). -t_handle_pubrec(_) -> +t_handle_in_pubrec(_) -> with_channel( fun(Channel) -> - {ok, ?PUBREL_PACKET(1, ?RC_PACKET_IDENTIFIER_NOT_FOUND), Channel} - = handle_in(?PUBREC_PACKET(1, ?RC_SUCCESS), Channel) + {ok, ?PUBREL_PACKET(1, ?RC_PACKET_IDENTIFIER_NOT_FOUND), Channel1} + = handle_in(?PUBREC_PACKET(1, ?RC_SUCCESS), Channel), + ?assertEqual(#{pubrec_in => 1, pubrel_out => 1}, emqx_channel:info(pub_stats, Channel1)) end). -t_handle_pubrel(_) -> +t_handle_in_pubrel(_) -> with_channel( fun(Channel) -> - {ok, ?PUBCOMP_PACKET(1, ?RC_PACKET_IDENTIFIER_NOT_FOUND), Channel} - = handle_in(?PUBREL_PACKET(1, ?RC_SUCCESS), Channel) + {ok, ?PUBCOMP_PACKET(1, ?RC_PACKET_IDENTIFIER_NOT_FOUND), Channel1} + = handle_in(?PUBREL_PACKET(1, ?RC_SUCCESS), Channel), + ?assertEqual(#{pubrel_in => 1, pubcomp_out => 1}, emqx_channel:info(pub_stats, Channel1)) end). -t_handle_pubcomp(_) -> +t_handle_in_pubcomp(_) -> with_channel( fun(Channel) -> - {ok, Channel} = handle_in(?PUBCOMP_PACKET(1, ?RC_SUCCESS), Channel) + {ok, Channel1} = handle_in(?PUBCOMP_PACKET(1, ?RC_SUCCESS), Channel), + ?assertEqual(#{pubcomp_in => 1}, emqx_channel:info(pub_stats, Channel1)) end). t_handle_subscribe(_) -> @@ -144,14 +164,15 @@ t_handle_pingreq(_) -> t_handle_disconnect(_) -> with_channel( fun(Channel) -> - {wait_session_expire, {shutdown, normal}, Channel1} = handle_in(?DISCONNECT_PACKET(?RC_SUCCESS), Channel), + {stop, normal, Channel1} = handle_in(?DISCONNECT_PACKET(?RC_SUCCESS), Channel), ?assertEqual(undefined, emqx_channel:info(will_msg, Channel1)) end). -t_handle_auth(_) -> +t_handle_in_auth(_) -> with_channel( fun(Channel) -> - {ok, Channel} = handle_in(?AUTH_PACKET(), Channel) + Packet = ?DISCONNECT_PACKET(?RC_IMPLEMENTATION_SPECIFIC_ERROR), + {stop, implementation_specific_error, Packet, Channel} = handle_in(?AUTH_PACKET(), Channel) end). %%-------------------------------------------------------------------- @@ -175,13 +196,13 @@ t_handle_deliver(_) -> %% Test cases for handle_out %%-------------------------------------------------------------------- -t_handle_connack(_) -> +t_handle_out_connack(_) -> ConnPkt = #mqtt_packet_connect{ proto_name = <<"MQTT">>, proto_ver = ?MQTT_PROTO_V4, clean_start = true, properties = #{}, - client_id = <<"clientid">> + clientid = <<"clientid">> }, with_channel( fun(Channel) -> @@ -199,39 +220,44 @@ t_handle_out_publish(_) -> Pub1 = {publish, 1, emqx_message:make(<<"c">>, ?QOS_1, <<"t">>, <<"qos1">>)}, {ok, ?PUBLISH_PACKET(?QOS_0), Channel} = handle_out(Pub0, Channel), {ok, ?PUBLISH_PACKET(?QOS_1), Channel} = handle_out(Pub1, Channel), - {ok, Packets, Channel} = handle_out({publish, [Pub0, Pub1]}, Channel), - ?assertEqual(2, length(Packets)) + {ok, Packets, Channel1} = handle_out({publish, [Pub0, Pub1]}, Channel), + ?assertEqual(2, length(Packets)), + ?assertEqual(#{publish_out => 2}, emqx_channel:info(pub_stats, Channel1)) end). t_handle_out_puback(_) -> with_channel( fun(Channel) -> {ok, Channel} = handle_out({puberr, ?RC_NOT_AUTHORIZED}, Channel), - {ok, ?PUBACK_PACKET(1, ?RC_SUCCESS), Channel} - = handle_out({puback, 1, ?RC_SUCCESS}, Channel) + {ok, ?PUBACK_PACKET(1, ?RC_SUCCESS), Channel1} + = handle_out({puback, 1, ?RC_SUCCESS}, Channel), + ?assertEqual(#{puback_out => 1}, emqx_channel:info(pub_stats, Channel1)) end). t_handle_out_pubrec(_) -> with_channel( fun(Channel) -> - {ok, ?PUBREC_PACKET(4, ?RC_SUCCESS), Channel} - = handle_out({pubrec, 4, ?RC_SUCCESS}, Channel) + {ok, ?PUBREC_PACKET(4, ?RC_SUCCESS), Channel1} + = handle_out({pubrec, 4, ?RC_SUCCESS}, Channel), + ?assertEqual(#{pubrec_out => 1}, emqx_channel:info(pub_stats, Channel1)) end). t_handle_out_pubrel(_) -> with_channel( fun(Channel) -> - {ok, ?PUBREL_PACKET(2), Channel} + {ok, ?PUBREL_PACKET(2), Channel1} = handle_out({pubrel, 2, ?RC_SUCCESS}, Channel), - {ok, ?PUBREL_PACKET(3, ?RC_SUCCESS), Channel} - = handle_out({pubrel, 3, ?RC_SUCCESS}, Channel) + {ok, ?PUBREL_PACKET(3, ?RC_SUCCESS), Channel2} + = handle_out({pubrel, 3, ?RC_SUCCESS}, Channel1), + ?assertEqual(#{pubrel_out => 2}, emqx_channel:info(pub_stats, Channel2)) end). t_handle_out_pubcomp(_) -> with_channel( fun(Channel) -> - {ok, ?PUBCOMP_PACKET(5, ?RC_SUCCESS), Channel} - = handle_out({pubcomp, 5, ?RC_SUCCESS}, Channel) + {ok, ?PUBCOMP_PACKET(5, ?RC_SUCCESS), Channel1} + = handle_out({pubcomp, 5, ?RC_SUCCESS}, Channel), + ?assertEqual(#{pubcomp_out => 1}, emqx_channel:info(pub_stats, Channel1)) end). t_handle_out_suback(_) -> @@ -279,32 +305,22 @@ t_terminate(_) -> %%-------------------------------------------------------------------- with_channel(TestFun) -> - ConnInfo = #{peername => {{127,0,0,1}, 3456}, - sockname => {{127,0,0,1}, 1883}, - protocol => mqtt, - conn_mod => emqx_connection, - proto_name => <<"MQTT">>, - proto_ver => ?MQTT_PROTO_V5, - clean_start => true, - keepalive => 30, - client_id => <<"clientid">>, - username => <<"username">>, - conn_props => #{}, - receive_maximum => 100, - expiry_interval => 60 - }, + with_channel(#{}, TestFun). + +with_channel(ConnInfo, TestFun) -> + ConnInfo1 = maps:merge(?DEFAULT_CONNINFO, ConnInfo), ClientInfo = #{zone => <<"external">>, protocol => mqtt, peerhost => {127,0,0,1}, - client_id => <<"clientid">>, + clientid => <<"clientid">>, username => <<"username">>, peercert => undefined, is_bridge => false, is_superuser => false, mountpoint => undefined }, - Channel = emqx_channel:init(ConnInfo, [{zone, testing}]), - Session = emqx_session:init(ClientInfo, ConnInfo), - Channel1 = emqx_channel:set_field(client, ClientInfo, Channel), + Channel = emqx_channel:init(ConnInfo1, [{zone, testing}]), + Session = emqx_session:init(ClientInfo, ConnInfo1), + Channel1 = emqx_channel:set_field(clientinfo, ClientInfo, Channel), TestFun(emqx_channel:set_field(session, Session, Channel1)). diff --git a/test/emqx_client_SUITE.erl b/test/emqx_client_SUITE.erl index 19557143b..e83aef057 100644 --- a/test/emqx_client_SUITE.erl +++ b/test/emqx_client_SUITE.erl @@ -95,10 +95,10 @@ t_cm(_) -> IdleTimeout = emqx_zone:get_env(external, idle_timeout, 30000), emqx_zone:set_env(external, idle_timeout, 1000), ClientId = <<"myclient">>, - {ok, C} = emqtt:start_link([{client_id, ClientId}]), + {ok, C} = emqtt:start_link([{clientid, ClientId}]), {ok, _} = emqtt:connect(C), ct:sleep(50), - #{client := #{client_id := ClientId}} = emqx_cm:get_chan_attrs(ClientId), + #{client := #{clientid := ClientId}} = emqx_cm:get_chan_attrs(ClientId), emqtt:subscribe(C, <<"mytopic">>, 0), ct:sleep(1200), Stats = emqx_cm:get_chan_stats(ClientId), @@ -135,13 +135,13 @@ t_will_message(_Config) -> t_offline_message_queueing(_) -> {ok, C1} = emqtt:start_link([{clean_start, false}, - {client_id, <<"c1">>}]), + {clientid, <<"c1">>}]), {ok, _} = emqtt:connect(C1), {ok, _, [2]} = emqtt:subscribe(C1, nth(6, ?WILD_TOPICS), 2), ok = emqtt:disconnect(C1), {ok, C2} = emqtt:start_link([{clean_start, true}, - {client_id, <<"c2">>}]), + {clientid, <<"c2">>}]), {ok, _} = emqtt:connect(C2), ok = emqtt:publish(C2, nth(2, ?TOPICS), <<"qos 0">>, 0), @@ -149,8 +149,7 @@ t_offline_message_queueing(_) -> {ok, _} = emqtt:publish(C2, nth(4, ?TOPICS), <<"qos 2">>, 2), timer:sleep(10), emqtt:disconnect(C2), - {ok, C3} = emqtt:start_link([{clean_start, false}, - {client_id, <<"c1">>}]), + {ok, C3} = emqtt:start_link([{clean_start, false}, {clientid, <<"c1">>}]), {ok, _} = emqtt:connect(C3), timer:sleep(10), @@ -198,8 +197,7 @@ t_overlapping_subscriptions(_) -> t_redelivery_on_reconnect(_) -> ct:pal("Redelivery on reconnect test starting"), - {ok, C1} = emqtt:start_link([{clean_start, false}, - {client_id, <<"c">>}]), + {ok, C1} = emqtt:start_link([{clean_start, false}, {clientid, <<"c">>}]), {ok, _} = emqtt:connect(C1), {ok, _, [2]} = emqtt:subscribe(C1, nth(7, ?WILD_TOPICS), 2), @@ -212,8 +210,7 @@ t_redelivery_on_reconnect(_) -> timer:sleep(10), ok = emqtt:disconnect(C1), ?assertEqual(0, length(recv_msgs(2))), - {ok, C2} = emqtt:start_link([{clean_start, false}, - {client_id, <<"c">>}]), + {ok, C2} = emqtt:start_link([{clean_start, false}, {clientid, <<"c">>}]), {ok, _} = emqtt:connect(C2), timer:sleep(10), diff --git a/test/emqx_connection_SUITE.erl b/test/emqx_connection_SUITE.erl index d321d2c85..39cc67dee 100644 --- a/test/emqx_connection_SUITE.erl +++ b/test/emqx_connection_SUITE.erl @@ -33,7 +33,7 @@ end_per_suite(_Config) -> t_basic(_) -> Topic = <<"TopicA">>, - {ok, C} = emqtt:start_link([{port, 1883}, {client_id, <<"hello">>}]), + {ok, C} = emqtt:start_link([{port, 1883}, {clientid, <<"hello">>}]), {ok, _} = emqtt:connect(C), {ok, _, [1]} = emqtt:subscribe(C, Topic, qos1), {ok, _, [2]} = emqtt:subscribe(C, Topic, qos2), diff --git a/test/emqx_frame_SUITE.erl b/test/emqx_frame_SUITE.erl index 09ef7d901..f70a815fe 100644 --- a/test/emqx_frame_SUITE.erl +++ b/test/emqx_frame_SUITE.erl @@ -129,8 +129,8 @@ t_parse_cont(_) -> t_parse_frame_too_large(_) -> Packet = ?PUBLISH_PACKET(?QOS_1, <<"t">>, 1, payload(1000)), - ?catch_error(mqtt_frame_too_large, parse_serialize(Packet, #{max_size => 256})), - ?catch_error(mqtt_frame_too_large, parse_serialize(Packet, #{max_size => 512})), + ?catch_error(frame_too_large, parse_serialize(Packet, #{max_size => 256})), + ?catch_error(frame_too_large, parse_serialize(Packet, #{max_size => 512})), ?assertEqual(Packet, parse_serialize(Packet, #{max_size => 2048, version => ?MQTT_PROTO_V4})). t_serialize_parse_connect(_) -> @@ -411,7 +411,7 @@ t_serialize_parse_pubcomp_v5(_) -> t_serialize_parse_subscribe(_) -> %% SUBSCRIBE(Q1, R0, D0, PacketId=2, TopicTable=[{<<"TopicA">>,2}]) Bin = <>, - TopicOpts = #{nl => 0 , rap => 0, rc => 0, rh => 0, qos => 2}, + TopicOpts = #{nl => 0 , rap => 0, rh => 0, qos => 2}, TopicFilters = [{<<"TopicA">>, TopicOpts}], Packet = ?SUBSCRIBE_PACKET(2, TopicFilters), ?assertEqual(Bin, serialize_to_binary(Packet)), @@ -424,8 +424,8 @@ t_serialize_parse_subscribe(_) -> ?catch_error(bad_subqos, parse_serialize(?SUBSCRIBE_PACKET(1, [{<<"t">>, #{qos => 3}}]))). t_serialize_parse_subscribe_v5(_) -> - TopicFilters = [{<<"TopicQos0">>, #{rh => 1, qos => ?QOS_2, rap => 0, nl => 0, rc => 0}}, - {<<"TopicQos1">>, #{rh => 1, qos => ?QOS_2, rap => 0, nl => 0, rc => 0}}], + TopicFilters = [{<<"TopicQos0">>, #{rh => 1, qos => ?QOS_2, rap => 0, nl => 0}}, + {<<"TopicQos1">>, #{rh => 1, qos => ?QOS_2, rap => 0, nl => 0}}], Packet = ?SUBSCRIBE_PACKET(3, #{'Subscription-Identifier' => 16#FFFFFFF}, TopicFilters), ?assertEqual(Packet, parse_serialize(Packet, #{version => ?MQTT_PROTO_V5})). diff --git a/test/emqx_misc_SUITE.erl b/test/emqx_misc_SUITE.erl index 677b25c17..6f77ec86c 100644 --- a/test/emqx_misc_SUITE.erl +++ b/test/emqx_misc_SUITE.erl @@ -87,7 +87,7 @@ t_proc_stats(_) -> Pid2 = spawn(fun() -> timer:sleep(100) end), Pid2 ! msg, timer:sleep(10), - ?assertMatch([{mailbox_len, 1}|_], emqx_misc:proc_stats(Pid2)). + ?assertMatch([{message_queue_len, 1}|_], emqx_misc:proc_stats(Pid2)). t_drain_deliver(_) -> self() ! {deliver, t1, m1}, diff --git a/test/emqx_session_SUITE.erl b/test/emqx_session_SUITE.erl index da0c0115a..93a4e1873 100644 --- a/test/emqx_session_SUITE.erl +++ b/test/emqx_session_SUITE.erl @@ -188,12 +188,12 @@ info_args() -> sub_args() -> ?LET({ClientId, TopicFilter, SubOpts}, {clientid(), topic(), sub_opts()}, - {#{client_id => ClientId}, TopicFilter, SubOpts}). + {#{clientid => ClientId}, TopicFilter, SubOpts}). unsub_args() -> ?LET({ClientId, TopicFilter}, {clientid(), topic()}, - {#{client_id => ClientId}, TopicFilter}). + {#{clientid => ClientId}, TopicFilter}). publish_args() -> ?LET({PacketId, Message}, diff --git a/test/emqx_shared_sub_SUITE.erl b/test/emqx_shared_sub_SUITE.erl index 14baf8a76..9fec11171 100644 --- a/test/emqx_shared_sub_SUITE.erl +++ b/test/emqx_shared_sub_SUITE.erl @@ -80,9 +80,9 @@ t_no_connection_nack(_) -> ShareTopic = <<"$share/", Group/binary, $/, Topic/binary>>, ExpProp = [{properties, #{'Session-Expiry-Interval' => timer:seconds(30)}}], - {ok, SubConnPid1} = emqtt:start_link([{client_id, Subscriber1}] ++ ExpProp), + {ok, SubConnPid1} = emqtt:start_link([{clientid, Subscriber1}] ++ ExpProp), {ok, _Props} = emqtt:connect(SubConnPid1), - {ok, SubConnPid2} = emqtt:start_link([{client_id, Subscriber2}] ++ ExpProp), + {ok, SubConnPid2} = emqtt:start_link([{clientid, Subscriber2}] ++ ExpProp), {ok, _Props} = emqtt:connect(SubConnPid2), emqtt:subscribe(SubConnPid1, ShareTopic, QoS), emqtt:subscribe(SubConnPid1, ShareTopic, QoS), @@ -151,9 +151,9 @@ t_not_so_sticky(_) -> ok = ensure_config(sticky), ClientId1 = <<"ClientId1">>, ClientId2 = <<"ClientId2">>, - {ok, C1} = emqtt:start_link([{client_id, ClientId1}]), + {ok, C1} = emqtt:start_link([{clientid, ClientId1}]), {ok, _} = emqtt:connect(C1), - {ok, C2} = emqtt:start_link([{client_id, ClientId2}]), + {ok, C2} = emqtt:start_link([{clientid, ClientId2}]), {ok, _} = emqtt:connect(C2), emqtt:subscribe(C1, {<<"$share/group1/foo/bar">>, 0}), @@ -179,9 +179,9 @@ test_two_messages(Strategy, WithAck) -> Topic = <<"foo/bar">>, ClientId1 = <<"ClientId1">>, ClientId2 = <<"ClientId2">>, - {ok, ConnPid1} = emqtt:start_link([{client_id, ClientId1}]), + {ok, ConnPid1} = emqtt:start_link([{clientid, ClientId1}]), {ok, _} = emqtt:connect(ConnPid1), - {ok, ConnPid2} = emqtt:start_link([{client_id, ClientId2}]), + {ok, ConnPid2} = emqtt:start_link([{clientid, ClientId2}]), {ok, _} = emqtt:connect(ConnPid2), Message1 = emqx_message:make(ClientId1, 0, Topic, <<"hello1">>),