From 8b03371a4a42d7d86ff656e167ac825415dedcb6 Mon Sep 17 00:00:00 2001 From: Feng Lee Date: Thu, 22 Aug 2019 16:38:25 +0800 Subject: [PATCH] Improve the keepalive, connection, channel and session modules (#2813) --- src/emqx_channel.erl | 788 ++++++++++++++++++---------------- src/emqx_cm.erl | 10 +- src/emqx_connection.erl | 169 +++----- src/emqx_keepalive.erl | 106 ++--- src/emqx_misc.erl | 19 +- src/emqx_packet.erl | 18 +- src/emqx_protocol.erl | 136 ++++++ src/emqx_session.erl | 149 ++----- src/emqx_ws_connection.erl | 112 ++--- test/emqx_channel_SUITE.erl | 69 +-- test/emqx_keepalive_SUITE.erl | 35 +- test/emqx_net_SUITE.erl | 45 -- test/emqx_packet_SUITE.erl | 6 +- test/emqx_protocol_SUITE.erl | 49 +++ test/emqx_session_SUITE.erl | 15 +- 15 files changed, 881 insertions(+), 845 deletions(-) create mode 100644 src/emqx_protocol.erl delete mode 100644 test/emqx_net_SUITE.erl create mode 100644 test/emqx_protocol_SUITE.erl diff --git a/src/emqx_channel.erl b/src/emqx_channel.erl index 8c4412f42..b1671489e 100644 --- a/src/emqx_channel.erl +++ b/src/emqx_channel.erl @@ -24,21 +24,20 @@ -logger_header("[Channel]"). +-export([init/2]). + -export([ info/1 , info/2 , attrs/1 + , stats/1 , caps/1 ]). %% for tests -export([set/3]). --export([takeover/2]). - --export([ init/2 - , handle_in/2 +-export([ handle_in/2 , handle_out/2 - , handle_out/3 , handle_call/2 , handle_cast/2 , handle_info/2 @@ -46,164 +45,52 @@ , terminate/2 ]). +%% Ensure timer -export([ensure_timer/2]). -export([gc/3]). --import(emqx_access_control, - [ authenticate/1 - , check_acl/3 - ]). +-import(emqx_misc, [maybe_apply/2]). --import(emqx_misc, [start_timer/2]). +-import(emqx_access_control, [check_acl/3]). -export_type([channel/0]). -record(channel, { - client :: emqx_types:client(), - session :: emqx_session:session(), - proto_name :: binary(), - proto_ver :: emqx_types:ver(), - keepalive :: non_neg_integer(), - will_msg :: emqx_types:message(), - topic_aliases :: maybe(map()), - alias_maximum :: maybe(map()), - ack_props :: maybe(emqx_types:properties()), - idle_timeout :: timeout(), - retry_timer :: maybe(reference()), - alive_timer :: maybe(reference()), - stats_timer :: disabled | maybe(reference()), - expiry_timer :: maybe(reference()), - gc_state :: emqx_gc:gc_state(), %% GC State - oom_policy :: emqx_oom:oom_policy(), %% OOM Policy - connected :: boolean(), - connected_at :: erlang:timestamp(), - resuming :: boolean(), - pendings :: list() + %% MQTT Client + client :: emqx_types:client(), + %% MQTT Session + session :: emqx_session:session(), + %% MQTT Protocol + protocol :: emqx_protocol:protocol(), + %% Keepalive + keepalive :: emqx_keepalive:keepalive(), + %% Timers + timers :: #{atom() => disabled | maybe(reference())}, + %% GC State + gc_state :: emqx_gc:gc_state(), + %% OOM Policy + oom_policy :: emqx_oom:oom_policy(), + %% Connected + connected :: boolean(), + connected_at :: erlang:timestamp(), + %% Takeover/Resume + resuming :: boolean(), + pendings :: list() }). -opaque(channel() :: #channel{}). --define(NO_PROPS, undefined). +-define(TIMER_TABLE, #{ + stats_timer => emit_stats, + alive_timer => keepalive, + retry_timer => retry_delivery, + await_timer => expire_awaiting_rel, + expire_timer => expire_session + }). %%-------------------------------------------------------------------- -%% Info, Attrs and Caps -%%-------------------------------------------------------------------- - --spec(info(channel()) -> emqx_types:infos()). -info(#channel{client = Client, - session = Session, - proto_name = ProtoName, - proto_ver = ProtoVer, - keepalive = Keepalive, - will_msg = WillMsg, - topic_aliases = Aliases, - stats_timer = StatsTimer, - idle_timeout = IdleTimeout, - gc_state = GCState, - connected = Connected, - connected_at = ConnectedAt}) -> - #{client => Client, - session => if Session == undefined -> - undefined; - true -> emqx_session:info(Session) - end, - proto_name => ProtoName, - proto_ver => ProtoVer, - keepalive => Keepalive, - will_msg => WillMsg, - topic_aliases => Aliases, - enable_stats => case StatsTimer of - disabled -> false; - _Otherwise -> true - end, - idle_timeout => IdleTimeout, - gc_state => emqx_gc:info(GCState), - connected => Connected, - connected_at => ConnectedAt, - resuming => false, - pendings => [] - }. - --spec(info(atom(), channel()) -> term()). -info(client, #channel{client = Client}) -> - Client; -info(zone, #channel{client = #{zone := Zone}}) -> - Zone; -info(client_id, #channel{client = #{client_id := ClientId}}) -> - ClientId; -info(session, #channel{session = Session}) -> - Session; -info(proto_name, #channel{proto_name = ProtoName}) -> - ProtoName; -info(proto_ver, #channel{proto_ver = ProtoVer}) -> - ProtoVer; -info(keepalive, #channel{keepalive = Keepalive}) -> - Keepalive; -info(will_msg, #channel{will_msg = WillMsg}) -> - WillMsg; -info(topic_aliases, #channel{topic_aliases = Aliases}) -> - Aliases; -info(enable_stats, #channel{stats_timer = disabled}) -> - false; -info(enable_stats, #channel{stats_timer = _TRef}) -> - true; -info(idle_timeout, #channel{idle_timeout = IdleTimeout}) -> - IdleTimeout; -info(gc_state, #channel{gc_state = GCState}) -> - emqx_gc:info(GCState); -info(connected, #channel{connected = Connected}) -> - Connected; -info(connected_at, #channel{connected_at = ConnectedAt}) -> - ConnectedAt. - --spec(attrs(channel()) -> emqx_types:attrs()). -attrs(#channel{client = Client, - session = Session, - proto_name = ProtoName, - proto_ver = ProtoVer, - keepalive = Keepalive, - connected = Connected, - connected_at = ConnectedAt}) -> - #{client => Client, - session => if Session == undefined -> - undefined; - true -> emqx_session:attrs(Session) - end, - proto_name => ProtoName, - proto_ver => ProtoVer, - keepalive => Keepalive, - connected => Connected, - connected_at => ConnectedAt - }. - --spec(caps(channel()) -> emqx_types:caps()). -caps(#channel{client = #{zone := Zone}}) -> - emqx_mqtt_caps:get_caps(Zone). - -%%-------------------------------------------------------------------- -%% For unit tests -%%-------------------------------------------------------------------- - -set(client, Client, Channel) -> - Channel#channel{client = Client}; -set(session, Session, Channel) -> - Channel#channel{session = Session}. - -%%-------------------------------------------------------------------- -%% Takeover session -%%-------------------------------------------------------------------- - -takeover('begin', Channel = #channel{session = Session}) -> - {ok, Session, Channel#channel{resuming = true}}; - -takeover('end', Channel = #channel{session = Session, - pendings = Pendings}) -> - ok = emqx_session:takeover(Session), - {ok, Pendings, Channel}. - -%%-------------------------------------------------------------------- -%% Init a channel +%% Init the channel %%-------------------------------------------------------------------- -spec(init(emqx_types:conn(), proplists:proplist()) -> channel()). @@ -223,27 +110,100 @@ init(ConnInfo, Options) -> mountpoint => MountPoint, is_bridge => false, is_superuser => false}, ConnInfo), - IdleTimout = emqx_zone:get_env(Zone, idle_timeout, 30000), EnableStats = emqx_zone:get_env(Zone, enable_stats, true), - StatsTimer = if EnableStats -> undefined; - ?Otherwise -> disabled + StatsTimer = if + EnableStats -> undefined; + ?Otherwise -> disabled end, GcState = emqx_gc:init(emqx_zone:get_env(Zone, force_gc_policy, false)), OomPolicy = emqx_oom:init(emqx_zone:get_env(Zone, force_shutdown_policy)), - #channel{client = Client, - proto_name = <<"MQTT">>, - proto_ver = ?MQTT_PROTO_V4, - keepalive = 0, - idle_timeout = IdleTimout, - stats_timer = StatsTimer, - gc_state = GcState, - oom_policy = OomPolicy, - connected = false + #channel{client = Client, + session = undefined, + protocol = undefined, + gc_state = GcState, + oom_policy = OomPolicy, + timers = #{stats_timer => StatsTimer}, + connected = false }. peer_cert_as_username(Options) -> proplists:get_value(peer_cert_as_username, Options). +%%-------------------------------------------------------------------- +%% Info, Attrs and Caps +%%-------------------------------------------------------------------- + +-spec(info(channel()) -> emqx_types:infos()). +info(#channel{client = Client, + session = Session, + protocol = Protocol, + keepalive = Keepalive, + gc_state = GCState, + oom_policy = OomPolicy, + connected = Connected, + connected_at = ConnectedAt + }) -> + #{client => Client, + session => maybe_apply(fun emqx_session:info/1, Session), + protocol => maybe_apply(fun emqx_protocol:info/1, Protocol), + keepalive => maybe_apply(fun emqx_keepalive:info/1, Keepalive), + gc_state => emqx_gc:info(GCState), + oom_policy => emqx_oom:info(OomPolicy), + connected => Connected, + connected_at => ConnectedAt + }. + +-spec(info(atom(), channel()) -> term()). +info(client, #channel{client = Client}) -> + Client; +info(session, #channel{session = Session}) -> + maybe_apply(fun emqx_session:info/1, Session); +info(protocol, #channel{protocol = Protocol}) -> + maybe_apply(fun emqx_protocol:info/1, Protocol); +info(keepalive, #channel{keepalive = Keepalive}) -> + maybe_apply(fun emqx_keepalive:info/1, Keepalive); +info(gc_state, #channel{gc_state = GCState}) -> + emqx_gc:info(GCState); +info(oom_policy, #channel{oom_policy = Policy}) -> + emqx_oom:info(Policy); +info(connected, #channel{connected = Connected}) -> + Connected; +info(connected_at, #channel{connected_at = ConnectedAt}) -> + ConnectedAt. + +-spec(attrs(channel()) -> emqx_types:attrs()). +attrs(#channel{client = Client, + session = Session, + protocol = Protocol, + connected = Connected, + connected_at = ConnectedAt}) -> + #{client => Client, + session => maybe_apply(fun emqx_session:attrs/1, Session), + protocol => maybe_apply(fun emqx_protocol:attrs/1, Protocol), + connected => Connected, + connected_at => ConnectedAt + }. + +%%TODO: ChanStats? +-spec(stats(channel()) -> emqx_types:stats()). +stats(#channel{session = Session}) -> + emqx_session:stats(Session). + +-spec(caps(channel()) -> emqx_types:caps()). +caps(#channel{client = #{zone := Zone}}) -> + emqx_mqtt_caps:get_caps(Zone). + +%%-------------------------------------------------------------------- +%% For unit tests +%%-------------------------------------------------------------------- + +set(client, Client, Channel) -> + Channel#channel{client = Client}; +set(session, Session, Channel) -> + Channel#channel{session = Session}; +set(protocol, Protocol, Channel) -> + Channel#channel{protocol = Protocol}. + %%-------------------------------------------------------------------- %% Handle incoming packet %%-------------------------------------------------------------------- @@ -255,48 +215,43 @@ peer_cert_as_username(Options) -> | {stop, Error :: term(), channel()} | {stop, Error :: term(), emqx_types:packet(), channel()}). handle_in(?CONNECT_PACKET(_), Channel = #channel{connected = true}) -> - handle_out(disconnect, ?RC_PROTOCOL_ERROR, Channel); + handle_out({disconnect, ?RC_PROTOCOL_ERROR}, Channel); -handle_in(?CONNECT_PACKET( - #mqtt_packet_connect{proto_name = ProtoName, - proto_ver = ProtoVer, - keepalive = Keepalive, - client_id = ClientId - } = ConnPkt), Channel) -> - Channel1 = Channel#channel{proto_name = ProtoName, - proto_ver = ProtoVer, - keepalive = Keepalive - }, - ok = emqx_logger:set_metadata_client_id(ClientId), - case pipeline([fun validate_in/2, - fun process_props/2, +handle_in(?CONNECT_PACKET(ConnPkt), Channel) -> + case pipeline([fun validate_packet/2, fun check_connect/2, + fun init_protocol/2, fun enrich_client/2, - fun auth_connect/2], ConnPkt, Channel1) of - {ok, NConnPkt, NChannel = #channel{client = #{client_id := ClientId1}}} -> - ok = emqx_logger:set_metadata_client_id(ClientId1), + fun set_logger_meta/2, + fun auth_connect/2], ConnPkt, Channel) of + {ok, NConnPkt, NChannel} -> process_connect(NConnPkt, NChannel); {error, ReasonCode, NChannel} -> - handle_out(connack, ReasonCode, NChannel) + handle_out({connack, ReasonCode}, NChannel) end; -handle_in(Packet = ?PUBLISH_PACKET(_QoS, Topic, _PacketId), Channel = #channel{proto_ver = Ver}) -> - case pipeline([fun validate_in/2, +handle_in(Packet = ?PUBLISH_PACKET(QoS, Topic, PacketId), Channel = #channel{protocol = Protocol}) -> + case pipeline([fun validate_packet/2, fun process_alias/2, fun check_publish/2], Packet, Channel) of {ok, NPacket, NChannel} -> process_publish(NPacket, NChannel); {error, ReasonCode, NChannel} -> + ProtoVer = emqx_protocol:info(proto_ver, Protocol), ?LOG(warning, "Cannot publish message to ~s due to ~s", - [Topic, emqx_reason_codes:text(ReasonCode, Ver)]), - handle_out(disconnect, ReasonCode, NChannel) + [Topic, emqx_reason_codes:text(ReasonCode, ProtoVer)]), + case QoS of + ?QOS_0 -> handle_out({puberr, ReasonCode}, NChannel); + ?QOS_1 -> handle_out({puback, PacketId, ReasonCode}, NChannel); + ?QOS_2 -> handle_out({pubrec, PacketId, ReasonCode}, NChannel) + end end; %%TODO: How to handle the ReasonCode? handle_in(?PUBACK_PACKET(PacketId, _ReasonCode), Channel = #channel{session = Session}) -> case emqx_session:puback(PacketId, Session) of {ok, Publishes, NSession} -> - handle_out(publish, Publishes, Channel#channel{session = NSession}); + handle_out({publish, Publishes}, Channel#channel{session = NSession}); {ok, NSession} -> {ok, Channel#channel{session = NSession}}; {error, _NotFound} -> @@ -308,24 +263,24 @@ handle_in(?PUBACK_PACKET(PacketId, _ReasonCode), Channel = #channel{session = Se handle_in(?PUBREC_PACKET(PacketId, _ReasonCode), Channel = #channel{session = Session}) -> case emqx_session:pubrec(PacketId, Session) of {ok, NSession} -> - handle_out(pubrel, {PacketId, ?RC_SUCCESS}, Channel#channel{session = NSession}); + handle_out({pubrel, PacketId, ?RC_SUCCESS}, Channel#channel{session = NSession}); {error, ReasonCode} -> - handle_out(pubrel, {PacketId, ReasonCode}, Channel) + handle_out({pubrel, PacketId, ReasonCode}, Channel) end; %%TODO: How to handle the ReasonCode? handle_in(?PUBREL_PACKET(PacketId, _ReasonCode), Channel = #channel{session = Session}) -> case emqx_session:pubrel(PacketId, Session) of {ok, NSession} -> - handle_out(pubcomp, {PacketId, ?RC_SUCCESS}, Channel#channel{session = NSession}); + handle_out({pubcomp, PacketId, ?RC_SUCCESS}, Channel#channel{session = NSession}); {error, ReasonCode} -> - handle_out(pubcomp, {PacketId, ReasonCode}, Channel) + handle_out({pubcomp, PacketId, ReasonCode}, Channel) end; handle_in(?PUBCOMP_PACKET(PacketId, _ReasonCode), Channel = #channel{session = Session}) -> case emqx_session:pubcomp(PacketId, Session) of {ok, Publishes, NSession} -> - handle_out(publish, Publishes, Channel#channel{session = NSession}); + handle_out({publish, Publishes}, Channel#channel{session = NSession}); {ok, NSession} -> {ok, Channel#channel{session = NSession}}; {error, _NotFound} -> @@ -335,7 +290,7 @@ handle_in(?PUBCOMP_PACKET(PacketId, _ReasonCode), Channel = #channel{session = S handle_in(Packet = ?SUBSCRIBE_PACKET(PacketId, Properties, TopicFilters), Channel = #channel{client = Client}) -> - case validate_in(Packet, Channel) of + case validate_packet(Packet, Channel) of ok -> TopicFilters1 = [emqx_topic:parse(TopicFilter, SubOpts) || {TopicFilter, SubOpts} <- TopicFilters], @@ -344,23 +299,23 @@ handle_in(Packet = ?SUBSCRIBE_PACKET(PacketId, Properties, TopicFilters), TopicFilters1), TopicFilters3 = enrich_subid(Properties, TopicFilters2), {ReasonCodes, NChannel} = process_subscribe(TopicFilters3, Channel), - handle_out(suback, {PacketId, ReasonCodes}, NChannel); + handle_out({suback, PacketId, ReasonCodes}, NChannel); {error, ReasonCode} -> - handle_out(disconnect, ReasonCode, Channel) + handle_out({disconnect, ReasonCode}, Channel) end; handle_in(Packet = ?UNSUBSCRIBE_PACKET(PacketId, Properties, TopicFilters), Channel = #channel{client = Client}) -> - case validate_in(Packet, Channel) of + case validate_packet(Packet, Channel) of ok -> TopicFilters1 = lists:map(fun emqx_topic:parse/1, TopicFilters), TopicFilters2 = emqx_hooks:run_fold('client.unsubscribe', [Client, Properties], TopicFilters1), {ReasonCodes, NChannel} = process_unsubscribe(TopicFilters2, Channel), - handle_out(unsuback, {PacketId, ReasonCodes}, NChannel); + handle_out({unsuback, PacketId, ReasonCodes}, NChannel); {error, ReasonCode} -> - handle_out(disconnect, ReasonCode, Channel) + handle_out({disconnect, ReasonCode}, Channel) end; handle_in(?PACKET(?PINGREQ), Channel) -> @@ -368,9 +323,10 @@ handle_in(?PACKET(?PINGREQ), Channel) -> handle_in(?DISCONNECT_PACKET(?RC_SUCCESS), Channel) -> %% Clear will msg - {stop, normal, Channel#channel{will_msg = undefined}}; + {stop, normal, Channel}; -handle_in(?DISCONNECT_PACKET(RC), Channel = #channel{proto_ver = Ver}) -> +handle_in(?DISCONNECT_PACKET(RC), Channel = #channel{protocol = Protocol}) -> + Ver = emqx_protocol:info(proto_ver, Protocol), {stop, {shutdown, emqx_reason_codes:name(RC, Ver)}, Channel}; handle_in(?AUTH_PACKET(), Channel) -> @@ -388,17 +344,12 @@ handle_in(Packet, Channel) -> process_connect(ConnPkt, Channel) -> case open_session(ConnPkt, Channel) of {ok, Session, SP} -> - WillMsg = emqx_packet:will_msg(ConnPkt), - NChannel = Channel#channel{session = Session, - will_msg = WillMsg, - connected = true, - connected_at = os:timestamp() - }, - handle_out(connack, {?RC_SUCCESS, sp(SP)}, NChannel); + NChannel = Channel#channel{session = Session}, + handle_out({connack, ?RC_SUCCESS, sp(SP)}, NChannel); {error, Reason} -> %% TODO: Unknown error? ?LOG(error, "Failed to open session: ~p", [Reason]), - handle_out(connack, ?RC_UNSPECIFIED_ERROR, Channel) + handle_out({connack, ?RC_UNSPECIFIED_ERROR}, Channel) end. %%-------------------------------------------------------------------- @@ -420,17 +371,18 @@ process_publish(_PacketId, Msg = #message{qos = ?QOS_0}, Channel) -> process_publish(PacketId, Msg = #message{qos = ?QOS_1}, Channel) -> Deliveries = emqx_broker:publish(Msg), ReasonCode = emqx_reason_codes:puback(Deliveries), - handle_out(puback, {PacketId, ReasonCode}, Channel); + handle_out({puback, PacketId, ReasonCode}, Channel); process_publish(PacketId, Msg = #message{qos = ?QOS_2}, Channel = #channel{session = Session}) -> case emqx_session:publish(PacketId, Msg, Session) of {ok, Deliveries, NSession} -> ReasonCode = emqx_reason_codes:puback(Deliveries), - handle_out(pubrec, {PacketId, ReasonCode}, - Channel#channel{session = NSession}); + NChannel = Channel#channel{session = NSession}, + handle_out({pubrec, PacketId, ReasonCode}, + ensure_timer(await_timer, NChannel)); {error, ReasonCode} -> - handle_out(pubrec, {PacketId, ReasonCode}, Channel) + handle_out({pubrec, PacketId, ReasonCode}, Channel) end. %%-------------------------------------------------------------------- @@ -474,8 +426,8 @@ process_unsubscribe([{TopicFilter, SubOpts}|More], Acc, Channel) -> {RC, Channel1} = do_unsubscribe(TopicFilter, SubOpts, Channel), process_unsubscribe(More, [RC|Acc], Channel1). -do_unsubscribe(TopicFilter, _SubOpts, Channel = #channel{client = Client, - session = Session}) -> +do_unsubscribe(TopicFilter, _SubOpts, + Channel = #channel{client = Client, session = Session}) -> case emqx_session:unsubscribe(Client, mount(Client, TopicFilter), Session) of {ok, NSession} -> {?RC_SUCCESS, Channel#channel{session = NSession}}; @@ -486,72 +438,21 @@ do_unsubscribe(TopicFilter, _SubOpts, Channel = #channel{client = Client, %% Handle outgoing packet %%-------------------------------------------------------------------- -handle_out(Deliver = {deliver, _Topic, _Msg}, - Channel = #channel{resuming = true, pendings = Pendings}) -> - Delivers = emqx_misc:drain_deliver([Deliver]), - {ok, Channel#channel{pendings = lists:append(Pendings, Delivers)}}; +handle_out({connack, ?RC_SUCCESS, SP}, Channel = #channel{client = Client}) -> + ok = emqx_hooks:run('client.connected', + [Client, ?RC_SUCCESS, attrs(Channel)]), + AckProps = emqx_misc:run_fold([fun enrich_caps/2, + fun enrich_server_keepalive/2, + fun enrich_assigned_clientid/2 + ], #{}, Channel), + NChannel = ensure_keepalive(AckProps, ensure_connected(Channel)), + {ok, ?CONNACK_PACKET(?RC_SUCCESS, SP, AckProps), NChannel}; -handle_out(Deliver = {deliver, _Topic, _Msg}, Channel = #channel{session = Session}) -> - Delivers = emqx_misc:drain_deliver([Deliver]), - case emqx_session:deliver(Delivers, Session) of - {ok, Publishes, NSession} -> - handle_out(publish, Publishes, Channel#channel{session = NSession}); - {ok, NSession} -> - {ok, Channel#channel{session = NSession}} - end; - -handle_out({publish, PacketId, Msg}, Channel = #channel{client = Client}) -> - Msg1 = emqx_hooks:run_fold('message.deliver', [Client], - emqx_message:update_expiry(Msg)), - Packet = emqx_packet:from_message(PacketId, unmount(Client, Msg1)), - {ok, Packet, Channel}. - -handle_out(connack, {?RC_SUCCESS, SP}, - Channel = #channel{client = Client = #{zone := Zone}, - ack_props = AckProps, - alias_maximum = AliasMaximum}) -> - ok = emqx_hooks:run('client.connected', [Client, ?RC_SUCCESS, attrs(Channel)]), - #{max_packet_size := MaxPktSize, - max_qos_allowed := MaxQoS, - retain_available := Retain, - max_topic_alias := MaxAlias, - shared_subscription := Shared, - wildcard_subscription := Wildcard - } = caps(Channel), - %% Response-Information is so far not set by broker. - %% i.e. It's a Client-to-Client contract for the request-response topic naming scheme. - %% According to MQTT 5.0 spec: - %% A common use of this is to pass a globally unique portion of the topic tree which - %% is reserved for this Client for at least the lifetime of its Session. - %% This often cannot just be a random name as both the requesting Client and the - %% responding Client need to be authorized to use it. - %% If we are to support it in the feature, the implementation should be flexible - %% to allow prefixing the response topic based on different ACL config. - %% e.g. prefix by username or client-id, so that unauthorized clients can not - %% subscribe requests or responses that are not intended for them. - AckProps1 = if AckProps == undefined -> #{}; true -> AckProps end, - AckProps2 = AckProps1#{'Retain-Available' => flag(Retain), - 'Maximum-Packet-Size' => MaxPktSize, - 'Topic-Alias-Maximum' => MaxAlias, - 'Wildcard-Subscription-Available' => flag(Wildcard), - 'Subscription-Identifier-Available' => 1, - %'Response-Information' => - 'Shared-Subscription-Available' => flag(Shared), - 'Maximum-QoS' => MaxQoS - }, - AckProps3 = case emqx_zone:get_env(Zone, server_keepalive) of - undefined -> AckProps2; - Keepalive -> AckProps2#{'Server-Keep-Alive' => Keepalive} - end, - AliasMaximum1 = set_property(inbound, MaxAlias, AliasMaximum), - Channel1 = Channel#channel{alias_maximum = AliasMaximum1, - ack_props = undefined - }, - {ok, ?CONNACK_PACKET(?RC_SUCCESS, SP, AckProps3), Channel1}; - -handle_out(connack, ReasonCode, Channel = #channel{client = Client, - proto_ver = ProtoVer}) -> +handle_out({connack, ReasonCode}, Channel = #channel{client = Client, + protocol = Protocol + }) -> ok = emqx_hooks:run('client.connected', [Client, ReasonCode, attrs(Channel)]), + ProtoVer = emqx_protocol:info(proto_ver, Protocol), ReasonCode1 = if ProtoVer == ?MQTT_PROTO_V5 -> ReasonCode; true -> emqx_reason_codes:compat(connack, ReasonCode) @@ -559,52 +460,81 @@ handle_out(connack, ReasonCode, Channel = #channel{client = Client, Reason = emqx_reason_codes:name(ReasonCode1, ProtoVer), {stop, {shutdown, Reason}, ?CONNACK_PACKET(ReasonCode1), Channel}; -handle_out(publish, Publishes, Channel) -> - Packets = [element(2, handle_out(Publish, Channel)) || Publish <- Publishes], +handle_out({deliver, Delivers}, Channel = #channel{resuming = true, + pendings = Pendings + }) -> + {ok, Channel#channel{pendings = lists:append(Pendings, Delivers)}}; + +handle_out({deliver, Delivers}, Channel = #channel{session = Session}) -> + case emqx_session:deliver(Delivers, Session) of + {ok, Publishes, NSession} -> + NChannel = Channel#channel{session = NSession}, + handle_out({publish, Publishes}, ensure_timer(retry_timer, NChannel)); + {ok, NSession} -> + {ok, Channel#channel{session = NSession}} + end; + +handle_out({publish, Publishes}, Channel) -> + Packets = lists:map( + fun(Publish) -> + element(2, handle_out(Publish, Channel)) + end, Publishes), {ok, Packets, Channel}; +handle_out({publish, PacketId, Msg}, Channel = #channel{client = Client}) -> + Msg1 = emqx_hooks:run_fold('message.deliver', [Client], + emqx_message:update_expiry(Msg)), + Packet = emqx_packet:from_message(PacketId, unmount(Client, Msg1)), + {ok, Packet, Channel}; + %% TODO: How to handle the puberr? -handle_out(puberr, _ReasonCode, Channel) -> +handle_out({puberr, _ReasonCode}, Channel) -> {ok, Channel}; -handle_out(puback, {PacketId, ReasonCode}, Channel) -> +handle_out({puback, PacketId, ReasonCode}, Channel) -> {ok, ?PUBACK_PACKET(PacketId, ReasonCode), Channel}; -handle_out(pubrel, {PacketId, ReasonCode}, Channel) -> +handle_out({pubrel, PacketId, ReasonCode}, Channel) -> {ok, ?PUBREL_PACKET(PacketId, ReasonCode), Channel}; -handle_out(pubrec, {PacketId, ReasonCode}, Channel) -> +handle_out({pubrec, PacketId, ReasonCode}, Channel) -> {ok, ?PUBREC_PACKET(PacketId, ReasonCode), Channel}; -handle_out(pubcomp, {PacketId, ReasonCode}, Channel) -> +handle_out({pubcomp, PacketId, ReasonCode}, Channel) -> {ok, ?PUBCOMP_PACKET(PacketId, ReasonCode), Channel}; -handle_out(suback, {PacketId, ReasonCodes}, - Channel = #channel{proto_ver = ?MQTT_PROTO_V5}) -> - %% TODO: ACL Deny - {ok, ?SUBACK_PACKET(PacketId, ReasonCodes), Channel}; - -handle_out(suback, {PacketId, ReasonCodes}, Channel) -> - %% TODO: ACL Deny - ReasonCodes1 = [emqx_reason_codes:compat(suback, RC) || RC <- ReasonCodes], +handle_out({suback, PacketId, ReasonCodes}, + Channel = #channel{protocol = Protocol}) -> + ReasonCodes1 = + case emqx_protocol:info(proto_ver, Protocol) of + ?MQTT_PROTO_V5 -> ReasonCodes; + _Ver -> + [emqx_reason_codes:compat(suback, RC) || RC <- ReasonCodes] + end, {ok, ?SUBACK_PACKET(PacketId, ReasonCodes1), Channel}; -handle_out(unsuback, {PacketId, ReasonCodes}, - Channel = #channel{proto_ver = ?MQTT_PROTO_V5}) -> - {ok, ?UNSUBACK_PACKET(PacketId, ReasonCodes), Channel}; +handle_out({unsuback, PacketId, ReasonCodes}, + Channel = #channel{protocol = Protocol}) -> + Packet = case emqx_protocol:info(proto_ver, Protocol) of + ?MQTT_PROTO_V5 -> + ?UNSUBACK_PACKET(PacketId, ReasonCodes); + %% Ignore reason codes if not MQTT5 + _Ver -> ?UNSUBACK_PACKET(PacketId) + end, + {ok, Packet, Channel}; -%% Ignore reason codes if not MQTT5 -handle_out(unsuback, {PacketId, _ReasonCodes}, Channel) -> - {ok, ?UNSUBACK_PACKET(PacketId), Channel}; +handle_out({disconnect, ReasonCode}, Channel = #channel{protocol = Protocol}) -> + case emqx_protocol:info(proto_ver, Protocol) of + ?MQTT_PROTO_V5 -> + Reason = emqx_reason_codes:name(ReasonCode), + Packet = ?DISCONNECT_PACKET(ReasonCode), + {stop, {shutdown, Reason}, Packet, Channel}; + ProtoVer -> + Reason = emqx_reason_codes:name(ReasonCode, ProtoVer), + {stop, {shutdown, Reason}, Channel} + end; -handle_out(disconnect, ReasonCode, Channel = #channel{proto_ver = ?MQTT_PROTO_V5}) -> - Reason = emqx_reason_codes:name(ReasonCode), - {stop, {shutdown, Reason}, ?DISCONNECT_PACKET(ReasonCode), Channel}; - -handle_out(disconnect, ReasonCode, Channel = #channel{proto_ver = ProtoVer}) -> - {stop, {shutdown, emqx_reason_codes:name(ReasonCode, ProtoVer)}, Channel}; - -handle_out(Type, Data, Channel) -> +handle_out({Type, Data}, Channel) -> ?LOG(error, "Unexpected outgoing: ~s, ~p", [Type, Data]), {ok, Channel}. @@ -612,6 +542,18 @@ handle_out(Type, Data, Channel) -> %% Handle call %%-------------------------------------------------------------------- +%%-------------------------------------------------------------------- +%% Takeover session +%%-------------------------------------------------------------------- + +handle_call({takeover, 'begin'}, Channel = #channel{session = Session}) -> + {ok, Session, Channel#channel{resuming = true}}; + +handle_call({takeover, 'end'}, Channel = #channel{session = Session, + pendings = Pendings}) -> + ok = emqx_session:takeover(Session), + {stop, {shutdown, takeovered}, Pendings, Channel}; + handle_call(Req, Channel) -> ?LOG(error, "Unexpected call: Req", [Req]), {ok, ignored, Channel}. @@ -659,16 +601,49 @@ handle_info(Info, Channel) -> -> {ok, channel()} | {ok, Result :: term(), channel()} | {stop, Reason :: term(), channel()}). -timeout(TRef, {emit_stats, Stats}, Channel = #channel{stats_timer = TRef}) -> - ClientId = info(client_id, Channel), +timeout(TRef, {emit_stats, Stats}, + Channel = #channel{client = #{client_id := ClientId}, + timers = #{stats_timer := TRef} + }) -> ok = emqx_cm:set_chan_stats(ClientId, Stats), - {ok, Channel#channel{stats_timer = undefined}}; + {ok, clean_timer(stats_timer, Channel)}; -timeout(TRef, retry_deliver, Channel = #channel{%%session = Session, - retry_timer = TRef}) -> - %% case emqx_session:retry(Session) of - %% TODO: ... - {ok, Channel#channel{retry_timer = undefined}}; +timeout(TRef, {keepalive, StatVal}, Channel = #channel{keepalive = Keepalive, + timers = #{alive_timer := TRef} + }) -> + case emqx_keepalive:check(StatVal, Keepalive) of + {ok, NKeepalive} -> + NChannel = Channel#channel{keepalive = NKeepalive}, + {ok, reset_timer(alive_timer, NChannel)}; + {error, timeout} -> + {stop, {shutdown, keepalive_timeout}, Channel} + end; + +timeout(TRef, retry_delivery, Channel = #channel{session = Session, + timers = #{retry_timer := TRef} + }) -> + case emqx_session:retry(Session) of + {ok, NSession} -> + {ok, clean_timer(retry_timer, Channel#channel{session = NSession})}; + {ok, Publishes, NSession} -> + NChannel = Channel#channel{session = NSession}, + handle_out({publish, Publishes}, reset_timer(retry_timer, NChannel)); + {ok, Publishes, Timeout, NSession} -> + NChannel = Channel#channel{session = NSession}, + handle_out({publish, Publishes}, reset_timer(retry_timer, Timeout, NChannel)) + end; + +timeout(TRef, expire_awaiting_rel, Channel = #channel{session = Session, + timers = #{await_timer := TRef}}) -> + case emqx_session:expire(awaiting_rel, Session) of + {ok, Session} -> + {ok, clean_timer(await_timer, Channel#channel{session = Session})}; + {ok, Timeout, Session} -> + {ok, reset_timer(await_timer, Timeout, Channel#channel{session = Session})} + end; + +timeout(_TRef, expire_session, Channel) -> + {ok, Channel}; timeout(_TRef, Msg, Channel) -> ?LOG(error, "Unexpected timeout: ~p~n", [Msg]), @@ -678,20 +653,39 @@ timeout(_TRef, Msg, Channel) -> %% Ensure timers %%-------------------------------------------------------------------- -ensure_timer(emit_stats, Channel = #channel{stats_timer = undefined, - idle_timeout = IdleTimeout - }) -> - Channel#channel{stats_timer = start_timer(IdleTimeout, emit_stats)}; +ensure_timer(Name, Channel = #channel{timers = Timers}) -> + TRef = maps:get(Name, Timers, undefined), + Time = interval(Name, Channel), + case TRef == undefined andalso Time > 0 of + true -> + ensure_timer(Name, Time, Channel); + false -> Channel %% Timer disabled or exists + end. -ensure_timer(retry, Channel = #channel{session = Session, - retry_timer = undefined}) -> - Interval = emqx_session:info(retry_interval, Session), - TRef = emqx_misc:start_timer(Interval, retry_deliver), - Channel#channel{retry_timer = TRef}; +ensure_timer(Name, Time, Channel = #channel{timers = Timers}) -> + Msg = maps:get(Name, ?TIMER_TABLE), + TRef = emqx_misc:start_timer(Time, Msg), + Channel#channel{timers = Timers#{Name => TRef}}. -%% disabled or timer existed -ensure_timer(_Name, Channel) -> - Channel. +reset_timer(Name, Channel) -> + ensure_timer(Name, clean_timer(Name, Channel)). + +reset_timer(Name, Time, Channel) -> + ensure_timer(Name, Time, clean_timer(Name, Channel)). + +clean_timer(Name, Channel = #channel{timers = Timers}) -> + Channel#channel{timers = maps:remove(Name, Timers)}. + +interval(stats_timer, #channel{client = #{zone := Zone}}) -> + emqx_zone:get_env(Zone, idle_timeout, 30000); +interval(alive_timer, #channel{keepalive = KeepAlive}) -> + emqx_keepalive:info(interval, KeepAlive); +interval(retry_timer, #channel{session = Session}) -> + emqx_session:info(retry_interval, Session); +interval(await_timer, #channel{session = Session}) -> + emqx_session:info(await_rel_timeout, Session); +interval(expire_timer, #channel{session = Session}) -> + emqx_session:info(expiry_interval, Session). %%-------------------------------------------------------------------- %% Terminate @@ -699,9 +693,14 @@ ensure_timer(_Name, Channel) -> terminate(normal, #channel{client = Client}) -> ok = emqx_hooks:run('client.disconnected', [Client, normal]); -terminate(Reason, #channel{client = Client, will_msg = WillMsg}) -> +terminate(Reason, #channel{client = Client, + protocol = Protocol + }) -> ok = emqx_hooks:run('client.disconnected', [Client, Reason]), - publish_will_msg(WillMsg). + if + Protocol == undefined -> ok; + true -> publish_will_msg(emqx_protocol:info(will_msg, Protocol)) + end. %%TODO: Improve will msg:) publish_will_msg(undefined) -> @@ -720,13 +719,10 @@ gc(Cnt, Oct, Channel = #channel{gc_state = GCSt}) -> Ok andalso emqx_metrics:inc('channel.gc.cnt'), Channel#channel{gc_state = GCSt1}. -%%-------------------------------------------------------------------- -%% Validate incoming packet -%%-------------------------------------------------------------------- - --spec(validate_in(emqx_types:packet(), channel()) +%% @doc Validate incoming packet. +-spec(validate_packet(emqx_types:packet(), channel()) -> ok | {error, emqx_types:reason_code()}). -validate_in(Packet, _Channel) -> +validate_packet(Packet, _Channel) -> try emqx_packet:validate(Packet) of true -> ok catch @@ -744,23 +740,6 @@ validate_in(Packet, _Channel) -> {error, ?RC_MALFORMED_PACKET} end. -%%-------------------------------------------------------------------- -%% Preprocess properties -%%-------------------------------------------------------------------- - -process_props(#mqtt_packet_connect{ - properties = #{'Topic-Alias-Maximum' := Max} - }, - Channel = #channel{alias_maximum = AliasMaximum}) -> - NAliasMaximum = if AliasMaximum == undefined -> - #{outbound => Max}; - true -> AliasMaximum#{outbound => Max} - end, - {ok, Channel#channel{alias_maximum = NAliasMaximum}}; - -process_props(Packet, Channel) -> - {ok, Packet, Channel}. - %%-------------------------------------------------------------------- %% Check connect packet %%-------------------------------------------------------------------- @@ -836,6 +815,9 @@ check_will_retain(#mqtt_packet_connect{will_retain = true}, false -> {error, ?RC_RETAIN_NOT_SUPPORTED} end. +init_protocol(ConnPkt, Channel) -> + {ok, Channel#channel{protocol = emqx_protocol:init(ConnPkt)}}. + %%-------------------------------------------------------------------- %% Enrich client %%-------------------------------------------------------------------- @@ -858,11 +840,10 @@ maybe_use_username_as_clientid(_ConnPkt, Channel = #channel{client = Client = #{ {ok, Channel#channel{client = NClient}}. maybe_assign_clientid(#mqtt_packet_connect{client_id = <<>>}, - Channel = #channel{client = Client, - ack_props = AckProps}) -> - ClientId = emqx_guid:to_base62(emqx_guid:gen()), - AckProps1 = set_property('Assigned-Client-Identifier', ClientId, AckProps), - {ok, Channel#channel{client = Client#{client_id => ClientId}, ack_props = AckProps1}}; + Channel = #channel{client = Client}) -> + RandClientId = emqx_guid:to_base62(emqx_guid:gen()), + {ok, Channel#channel{client = Client#{client_id => RandClientId}}}; + maybe_assign_clientid(#mqtt_packet_connect{client_id = ClientId}, Channel = #channel{client = Client}) -> {ok, Channel#channel{client = Client#{client_id => ClientId}}}. @@ -878,6 +859,10 @@ set_rest_client_fields(#mqtt_packet_connect{is_bridge = IsBridge}, Channel = #channel{client = Client}) -> {ok, Channel#channel{client = Client#{is_bridge => IsBridge}}}. +%% @doc Set logger metadata. +set_logger_meta(_ConnPkt, #channel{client = #{client_id := ClientId}}) -> + emqx_logger:set_metadata_client_id(ClientId). + %%-------------------------------------------------------------------- %% Auth Connect %%-------------------------------------------------------------------- @@ -886,7 +871,7 @@ auth_connect(#mqtt_packet_connect{client_id = ClientId, username = Username, password = Password}, Channel = #channel{client = Client}) -> - case authenticate(Client#{password => Password}) of + case emqx_access_control:authenticate(Client#{password => Password}) of {ok, AuthResult} -> {ok, Channel#channel{client = maps:merge(Client, AuthResult)}}; {error, Reason} -> @@ -906,7 +891,7 @@ open_session(#mqtt_packet_connect{clean_start = CleanStart, emqx_zone:get_env(Zone, max_inflight, 65535)), Interval = get_property('Session-Expiry-Interval', ConnProps, emqx_zone:get_env(Zone, session_expiry_interval, 0)), - emqx_cm:open_session(CleanStart, Client, #{max_inflight => MaxInflight, + emqx_cm:open_session(CleanStart, Client, #{max_inflight => MaxInflight, expiry_interval => Interval }). @@ -918,8 +903,9 @@ process_alias(Packet = #mqtt_packet{ variable = #mqtt_packet_publish{topic_name = <<>>, properties = #{'Topic-Alias' := AliasId} } = Publish - }, Channel = #channel{topic_aliases = Aliases}) -> - case find_alias(AliasId, Aliases) of + }, + Channel = #channel{protocol = Protocol}) -> + case emqx_protocol:find_alias(AliasId, Protocol) of {ok, Topic} -> {ok, Packet#mqtt_packet{ variable = Publish#mqtt_packet_publish{ @@ -931,22 +917,12 @@ process_alias(#mqtt_packet{ variable = #mqtt_packet_publish{topic_name = Topic, properties = #{'Topic-Alias' := AliasId} } - }, Channel = #channel{topic_aliases = Aliases}) -> - {ok, Channel#channel{topic_aliases = save_alias(AliasId, Topic, Aliases)}}; + }, Channel = #channel{protocol = Protocol}) -> + {ok, Channel#channel{protocol = emqx_protocol:save_alias(AliasId, Topic, Protocol)}}; process_alias(_Packet, Channel) -> {ok, Channel}. -find_alias(_AliasId, undefined) -> - false; -find_alias(AliasId, Aliases) -> - maps:find(AliasId, Aliases). - -save_alias(AliasId, Topic, undefined) -> - #{AliasId => Topic}; -save_alias(AliasId, Topic, Aliases) -> - maps:put(AliasId, Topic, Aliases). - %% Check Publish check_publish(Packet, Channel) -> pipeline([fun check_pub_acl/2, @@ -968,7 +944,9 @@ check_pub_alias(#mqtt_packet{ properties = #{'Topic-Alias' := AliasId} } }, - #channel{alias_maximum = Limits}) -> + #channel{protocol = Protocol}) -> + %% TODO: Move to Protocol + Limits = emqx_protocol:info(alias_maximum, Protocol), case (Limits == undefined) orelse (Max = maps:get(inbound, Limits, 0)) == 0 orelse (AliasId > Max) of @@ -1009,12 +987,65 @@ enrich_subid(#{'Subscription-Identifier' := SubId}, TopicFilters) -> enrich_subid(_Properties, TopicFilters) -> TopicFilters. -enrich_subopts(SubOpts, #channel{proto_ver = ?MQTT_PROTO_V5}) -> - SubOpts; -enrich_subopts(SubOpts, #channel{client = #{zone := Zone, is_bridge := IsBridge}}) -> - Rap = flag(IsBridge), - Nl = flag(emqx_zone:get_env(Zone, ignore_loop_deliver, false)), - SubOpts#{rap => Rap, nl => Nl}. +enrich_subopts(SubOpts, #channel{client = Client, protocol = Proto}) -> + #{zone := Zone, is_bridge := IsBridge} = Client, + case emqx_protocol:info(proto_ver, Proto) of + ?MQTT_PROTO_V5 -> SubOpts; + _Ver -> Rap = flag(IsBridge), + Nl = flag(emqx_zone:get_env(Zone, ignore_loop_deliver, false)), + SubOpts#{rap => Rap, nl => Nl} + end. + +enrich_caps(AckProps, #channel{client = #{zone := Zone}, protocol = Protocol}) -> + case emqx_protocol:info(proto_ver, Protocol) of + ?MQTT_PROTO_V5 -> + #{max_packet_size := MaxPktSize, + max_qos_allowed := MaxQoS, + retain_available := Retain, + max_topic_alias := MaxAlias, + shared_subscription := Shared, + wildcard_subscription := Wildcard + } = emqx_mqtt_caps:get_caps(Zone), + AckProps#{'Retain-Available' => flag(Retain), + 'Maximum-Packet-Size' => MaxPktSize, + 'Topic-Alias-Maximum' => MaxAlias, + 'Wildcard-Subscription-Available' => flag(Wildcard), + 'Subscription-Identifier-Available' => 1, + 'Shared-Subscription-Available' => flag(Shared), + 'Maximum-QoS' => MaxQoS + }; + _Ver -> AckProps + end. + +enrich_server_keepalive(AckProps, #channel{client = #{zone := Zone}}) -> + case emqx_zone:get_env(Zone, server_keepalive) of + undefined -> AckProps; + Keepalive -> AckProps#{'Server-Keep-Alive' => Keepalive} + end. + +enrich_assigned_clientid(AckProps, #channel{client = #{client_id := ClientId}, + protocol = Protocol}) -> + case emqx_protocol:info(client_id, Protocol) of + <<>> -> %% Original ClientId. + AckProps#{'Assigned-Client-Identifier' => ClientId}; + _Origin -> AckProps + end. + +ensure_connected(Channel) -> + Channel#channel{connected = true, connected_at = os:timestamp()}. + +ensure_keepalive(#{'Server-Keep-Alive' := Interval}, Channel) -> + ensure_keepalive_timer(Interval, Channel); +ensure_keepalive(_AckProp, Channel = #channel{protocol = Protocol}) -> + case emqx_protocol:info(keepalive, Protocol) of + 0 -> Channel; + Interval -> ensure_keepalive_timer(Interval, Channel) + end. + +ensure_keepalive_timer(Interval, Channel = #channel{client = #{zone := Zone}}) -> + Backoff = emqx_zone:get_env(Zone, keepalive_backoff, 0.75), + Keepalive = emqx_keepalive:init(round(timer:seconds(Interval) * Backoff)), + ensure_timer(alive_timer, Channel#channel{keepalive = Keepalive}). %%-------------------------------------------------------------------- %% Is ACL enabled? @@ -1069,11 +1100,6 @@ pipeline([Fun|More], Packet, Channel) -> %% Helper functions %%-------------------------------------------------------------------- -set_property(Name, Value, ?NO_PROPS) -> - #{Name => Value}; -set_property(Name, Value, Props) -> - Props#{Name => Value}. - get_property(_Name, undefined, Default) -> Default; get_property(Name, Props, Default) -> diff --git a/src/emqx_cm.erl b/src/emqx_cm.erl index c393d6ab0..1e5842ec4 100644 --- a/src/emqx_cm.erl +++ b/src/emqx_cm.erl @@ -161,7 +161,7 @@ set_chan_stats(ClientId, ChanPid, Stats) -> open_session(true, Client = #{client_id := ClientId}, Options) -> CleanStart = fun(_) -> ok = discard_session(ClientId), - {ok, emqx_session:init(true, Client, Options), false} + {ok, emqx_session:init(Client, Options), false} end, emqx_cm_locker:trans(ClientId, CleanStart); @@ -169,12 +169,12 @@ open_session(false, Client = #{client_id := ClientId}, Options) -> ResumeStart = fun(_) -> case takeover_session(ClientId) of {ok, ConnMod, ChanPid, Session} -> - {ok, NSession} = emqx_session:resume(ClientId, Session), - {ok, Pendings} = ConnMod:takeover(ChanPid, 'end'), + NSession = emqx_session:resume(ClientId, Session), + Pendings = ConnMod:takeover(ChanPid, 'end'), io:format("Pending Delivers: ~p~n", [Pendings]), {ok, NSession, true}; {error, not_found} -> - {ok, emqx_session:init(false, Client, Options), false} + {ok, emqx_session:init(Client, Options), false} end end, emqx_cm_locker:trans(ClientId, ResumeStart). @@ -199,7 +199,7 @@ takeover_session(ClientId) -> takeover_session(ClientId, ChanPid) when node(ChanPid) == node() -> case get_chan_attrs(ClientId, ChanPid) of #{client := #{conn_mod := ConnMod}} -> - {ok, Session} = ConnMod:takeover(ChanPid, 'begin'), + Session = ConnMod:takeover(ChanPid, 'begin'), {ok, ConnMod, ChanPid, Session}; undefined -> {error, not_found} diff --git a/src/emqx_connection.erl b/src/emqx_connection.erl index 5ff534a2b..f6cd96108 100644 --- a/src/emqx_connection.erl +++ b/src/emqx_connection.erl @@ -35,7 +35,7 @@ ]). %% For Debug --export([state/1]). +-export([get_state/1]). -export([ kick/1 , discard/1 @@ -68,15 +68,14 @@ limit_timer :: maybe(reference()), parse_state :: emqx_frame:parse_state(), serialize :: fun((emqx_types:packet()) -> iodata()), - chan_state :: emqx_channel:channel(), - keepalive :: maybe(emqx_keepalive:keepalive()) + chan_state :: emqx_channel:channel() }). -type(state() :: #state{}). -define(ACTIVE_N, 100). -define(HANDLE(T, C, D), handle((T), (C), (D))). --define(CHAN_STATS, [recv_pkt, recv_msg, send_pkt, send_msg]). +-define(CONN_STATS, [recv_pkt, recv_msg, send_pkt, send_msg]). -define(SOCK_STATS, [recv_oct, recv_cnt, send_oct, send_cnt, send_pend]). -spec(start_link(esockd:transport(), esockd:socket(), proplists:proplist()) @@ -92,61 +91,63 @@ start_link(Transport, Socket, Options) -> -spec(info(pid() | state()) -> emqx_types:infos()). info(CPid) when is_pid(CPid) -> call(CPid, info); -info(#state{transport = Transport, - socket = Socket, - peername = Peername, - sockname = Sockname, +info(#state{transport = Transport, + socket = Socket, + peername = Peername, + sockname = Sockname, conn_state = ConnState, - active_n = ActiveN, + active_n = ActiveN, rate_limit = RateLimit, - pub_limit = PubLimit, + pub_limit = PubLimit, chan_state = ChanState}) -> - ConnInfo = #{socktype => Transport:type(Socket), - peername => Peername, - sockname => Sockname, + ConnInfo = #{socktype => Transport:type(Socket), + peername => Peername, + sockname => Sockname, conn_state => ConnState, - active_n => ActiveN, + active_n => ActiveN, rate_limit => limit_info(RateLimit), - pub_limit => limit_info(PubLimit) + pub_limit => limit_info(PubLimit) }, - maps:merge(ConnInfo, emqx_channel:info(ChanState)). + ChanInfo = emqx_channel:info(ChanState), + maps:merge(ConnInfo, ChanInfo). -limit_info(undefined) -> - undefined; limit_info(Limit) -> - esockd_rate_limit:info(Limit). + emqx_misc:maybe_apply(fun esockd_rate_limit:info/1, Limit). %% @doc Get attrs of the channel. -spec(attrs(pid() | state()) -> emqx_types:attrs()). attrs(CPid) when is_pid(CPid) -> call(CPid, attrs); -attrs(#state{transport = Transport, - socket = Socket, - peername = Peername, - sockname = Sockname, +attrs(#state{transport = Transport, + socket = Socket, + peername = Peername, + sockname = Sockname, chan_state = ChanState}) -> ConnAttrs = #{socktype => Transport:type(Socket), peername => Peername, sockname => Sockname }, - maps:merge(ConnAttrs, emqx_channel:attrs(ChanState)). + ChanAttrs = emqx_channel:attrs(ChanState), + maps:merge(ConnAttrs, ChanAttrs). %% @doc Get stats of the channel. -spec(stats(pid() | state()) -> emqx_types:stats()). stats(CPid) when is_pid(CPid) -> call(CPid, stats); -stats(#state{transport = Transport, - socket = Socket, +stats(#state{transport = Transport, + socket = Socket, chan_state = ChanState}) -> + ProcStats = emqx_misc:proc_stats(), SockStats = case Transport:getstat(Socket, ?SOCK_STATS) of {ok, Ss} -> Ss; {error, _} -> [] end, - ChanStats = [{Name, emqx_pd:get_counter(Name)} || Name <- ?CHAN_STATS], - SessStats = emqx_session:stats(emqx_channel:info(session, ChanState)), - lists:append([SockStats, ChanStats, SessStats, emqx_misc:proc_stats()]). + ConnStats = [{Name, emqx_pd:get_counter(Name)} || Name <- ?CONN_STATS], + ChanStats = emqx_channel:stats(ChanState), + lists:append([ProcStats, SockStats, ConnStats, ChanStats]). -state(CPid) -> +-spec(get_state(pid()) -> state()). +get_state(CPid) -> call(CPid, get_state). -spec(kick(pid()) -> ok). @@ -157,8 +158,7 @@ kick(CPid) -> discard(CPid) -> gen_statem:cast(CPid, discard). -%% TODO: --spec(takeover(pid(), 'begin'|'end') -> {ok, Result :: term()}). +-spec(takeover(pid(), 'begin'|'end') -> Result :: term()). takeover(CPid, Phase) -> gen_statem:call(CPid, {takeover, Phase}). @@ -187,16 +187,16 @@ init({Transport, RawSocket, Options}) -> peercert => Peercert, conn_mod => ?MODULE}, Options), IdleTimout = emqx_zone:get_env(Zone, idle_timeout, 30000), - State = #state{transport = Transport, - socket = Socket, - peername = Peername, - sockname = Sockname, - conn_state = running, - active_n = ActiveN, - rate_limit = RateLimit, - pub_limit = PubLimit, - parse_state = ParseState, - chan_state = ChanState + State = #state{transport = Transport, + socket = Socket, + peername = Peername, + sockname = Sockname, + conn_state = running, + active_n = ActiveN, + rate_limit = RateLimit, + pub_limit = PubLimit, + parse_state = ParseState, + chan_state = ChanState }, gen_statem:enter_loop(?MODULE, [{hibernate_after, 2 * IdleTimout}], idle, State, self(), [IdleTimout]). @@ -242,18 +242,10 @@ idle(EventType, Content, State) -> %% Connected State connected(enter, _PrevSt, State = #state{chan_state = ChanState}) -> - ClientId = emqx_channel:info(client_id, ChanState), + #{client_id := ClientId} = emqx_channel:info(client, ChanState), ok = emqx_cm:register_channel(ClientId), - ok = emqx_cm:set_chan_attrs(ClientId, info(State)), - %% Ensure keepalive after connected successfully. - Interval = emqx_channel:info(keepalive, ChanState), - case ensure_keepalive(Interval, State) of - ignore -> keep_state(State); - {ok, KeepAlive} -> - keep_state(State#state{keepalive = KeepAlive}); - {error, Reason} -> - shutdown(Reason, State) - end; + ok = emqx_cm:set_chan_attrs(ClientId, attrs(State)), + keep_state_and_data; connected(cast, {incoming, Packet = ?PACKET(?CONNECT)}, State) -> ?LOG(warning, "Unexpected connect: ~p", [Packet]), @@ -265,7 +257,8 @@ connected(cast, {incoming, Packet}, State) when is_record(Packet, mqtt_packet) - connected(info, Deliver = {deliver, _Topic, _Msg}, State = #state{chan_state = ChanState}) -> - case emqx_channel:handle_out(Deliver, ChanState) of + Delivers = emqx_misc:drain_deliver([Deliver]), + case emqx_channel:handle_out({deliver, Delivers}, ChanState) of {ok, NChanState} -> keep_state(State#state{chan_state = NChanState}); {ok, Packets, NChanState} -> @@ -275,17 +268,6 @@ connected(info, Deliver = {deliver, _Topic, _Msg}, stop(Reason, State#state{chan_state = NChanState}) end; -%% Keepalive timer -connected(info, {keepalive, check}, State = #state{keepalive = KeepAlive}) -> - case emqx_keepalive:check(KeepAlive) of - {ok, KeepAlive1} -> - keep_state(State#state{keepalive = KeepAlive1}); - {error, timeout} -> - shutdown(keepalive_timeout, State); - {error, Reason} -> - shutdown(Reason, State) - end; - connected(EventType, Content, State) -> ?HANDLE(EventType, Content, State). @@ -326,16 +308,6 @@ handle({call, From}, kick, State) -> ok = gen_statem:reply(From, ok), shutdown(kicked, State); -handle({call, From}, {takeover, 'begin'}, State = #state{chan_state = ChanState}) -> - {ok, Session, NChanState} = emqx_channel:takeover('begin', ChanState), - ok = gen_statem:reply(From, {ok, Session}), - {next_state, takeovering, State#state{chan_state = NChanState}}; - -handle({call, From}, {takeover, 'end'}, State = #state{chan_state = ChanState}) -> - {ok, Delivers, NChanState} = emqx_channel:takeover('end', ChanState), - ok = gen_statem:reply(From, {ok, Delivers}), - shutdown(takeovered, State#state{chan_state = NChanState}); - handle({call, From}, Req, State = #state{chan_state = ChanState}) -> case emqx_channel:handle_call(Req, ChanState) of {ok, Reply, NChanState} -> @@ -362,22 +334,22 @@ handle(info, {Inet, _Sock, Data}, State = #state{chan_state = ChanState}) emqx_pd:update_counter(incoming_bytes, Oct), ok = emqx_metrics:inc('bytes.received', Oct), NChanState = emqx_channel:ensure_timer( - emit_stats, emqx_channel:gc(1, Oct, ChanState)), + stats_timer, emqx_channel:gc(1, Oct, ChanState)), process_incoming(Data, State#state{chan_state = NChanState}); handle(info, {Error, _Sock, Reason}, State) when Error == tcp_error; Error == ssl_error -> shutdown(Reason, State); +%%TODO: fixme later. handle(info, {Closed, _Sock}, State = #state{chan_state = ChanState}) when Closed == tcp_closed; Closed == ssl_closed -> - case emqx_channel:info(session, ChanState) of + case emqx_channel:info(protocol, ChanState) of undefined -> shutdown(closed, State); - Session -> - case emqx_session:info(clean_start, Session) of - true -> shutdown(closed, State); - false -> {next_state, disconnected, State} - end + #{clean_start := true} -> + shutdown(closed, State); + #{clean_start := false} -> + {next_state, disconnected, State} end; handle(info, {Passive, _Sock}, State) when Passive == tcp_passive; @@ -402,12 +374,22 @@ handle(info, activate_socket, State) -> handle(info, {inet_reply, _Sock, ok}, State = #state{chan_state = ChanState}) -> %% something sent - NChanState = emqx_channel:ensure_timer(emit_stats, ChanState), + NChanState = emqx_channel:ensure_timer(stats_timer, ChanState), keep_state(State#state{chan_state = NChanState}); handle(info, {inet_reply, _Sock, {error, Reason}}, State) -> shutdown(Reason, State); +handle(info, {timeout, TRef, keepalive}, + State = #state{transport = Transport, socket = Socket}) + when is_reference(TRef) -> + case Transport:getstat(Socket, [recv_oct]) of + {ok, [{recv_oct, RecvOct}]} -> + handle_timeout(TRef, {keepalive, RecvOct}, State); + {error, Reason} -> + shutdown(Reason, State) + end; + handle(info, {timeout, TRef, emit_stats}, State) when is_reference(TRef) -> handle_timeout(TRef, {emit_stats, stats(State)}, State); @@ -434,12 +416,9 @@ code_change(_Vsn, State, Data, _Extra) -> terminate(Reason, _StateName, #state{transport = Transport, socket = Socket, - keepalive = KeepAlive, chan_state = ChanState}) -> ?LOG(debug, "Terminated for ~p", [Reason]), ok = Transport:fast_close(Socket), - KeepAlive =/= undefined - andalso emqx_keepalive:cancel(KeepAlive), emqx_channel:terminate(Reason, ChanState). %%-------------------------------------------------------------------- @@ -539,24 +518,6 @@ handle_timeout(TRef, Msg, State = #state{chan_state = ChanState}) -> stop(Reason, State#state{chan_state = NChanState}) end. -%%-------------------------------------------------------------------- -%% Ensure keepalive - -ensure_keepalive(0, _State) -> - ignore; -ensure_keepalive(Interval, #state{transport = Transport, - socket = Socket, - chan_state = ChanState}) -> - StatFun = fun() -> - case Transport:getstat(Socket, [recv_oct]) of - {ok, [{recv_oct, RecvOct}]} -> - {ok, RecvOct}; - Error -> Error - end - end, - Backoff = emqx_zone:get_env(emqx_channel:info(zone, ChanState), - keepalive_backoff, 0.75), - emqx_keepalive:start(StatFun, round(Interval * Backoff), {keepalive, check}). %%-------------------------------------------------------------------- %% Ensure rate limit diff --git a/src/emqx_keepalive.erl b/src/emqx_keepalive.erl index 88848f7ac..6ce970b54 100644 --- a/src/emqx_keepalive.erl +++ b/src/emqx_keepalive.erl @@ -16,78 +16,58 @@ -module(emqx_keepalive). -%% APIs --export([ start/3 - , check/1 - , cancel/1 +-export([ init/1 + , info/1 + , info/2 + , check/2 ]). -export_type([keepalive/0]). -record(keepalive, { - statfun :: statfun(), - statval :: integer(), - tsec :: pos_integer(), - tmsg :: term(), - tref :: reference(), - repeat = 0 :: non_neg_integer() + interval :: pos_integer(), + statval :: non_neg_integer(), + repeat :: non_neg_integer() }). --type(statfun() :: fun(() -> {ok, integer()} | {error, term()})). - -opaque(keepalive() :: #keepalive{}). -%%-------------------------------------------------------------------- -%% APIs -%%-------------------------------------------------------------------- +%% @doc Init keepalive. +-spec(init(Interval :: non_neg_integer()) -> keepalive()). +init(Interval) when Interval > 0 -> + #keepalive{interval = Interval, + statval = 0, + repeat = 0}. -%% @doc Start a keepalive --spec(start(statfun(), pos_integer(), term()) - -> {ok, keepalive()} | {error, term()}). -start(StatFun, TimeoutSec, TimeoutMsg) when TimeoutSec > 0 -> - try StatFun() of - {ok, StatVal} -> - TRef = timer(TimeoutSec, TimeoutMsg), - {ok, #keepalive{statfun = StatFun, - statval = StatVal, - tsec = TimeoutSec, - tmsg = TimeoutMsg, - tref = TRef}}; - {error, Error} -> - {error, Error} - catch - _Error:Reason -> - {error, Reason} +%% @doc Get Info of the keepalive. +-spec(info(keepalive()) -> emqx_types:infos()). +info(#keepalive{interval = Interval, + statval = StatVal, + repeat = Repeat}) -> + #{interval => Interval, + statval => StatVal, + repeat => Repeat + }. + +-spec(info(interval|statval|repeat, keepalive()) + -> non_neg_integer()). +info(interval, #keepalive{interval = Interval}) -> + Interval; +info(statval, #keepalive{statval = StatVal}) -> + StatVal; +info(repeat, #keepalive{repeat = Repeat}) -> + Repeat. + +%% @doc Check keepalive. +-spec(check(non_neg_integer(), keepalive()) + -> {ok, keepalive()} | {error, timeout}). +check(NewVal, KeepAlive = #keepalive{statval = OldVal, + repeat = Repeat}) -> + if + NewVal =/= OldVal -> + {ok, KeepAlive#keepalive{statval = NewVal, repeat = 0}}; + Repeat < 1 -> + {ok, KeepAlive#keepalive{repeat = Repeat + 1}}; + true -> {error, timeout} end. -%% @doc Check keepalive, called when timeout... --spec(check(keepalive()) -> {ok, keepalive()} | {error, term()}). -check(KeepAlive = #keepalive{statfun = StatFun, statval = LastVal, repeat = Repeat}) -> - try StatFun() of - {ok, NewVal} -> - if NewVal =/= LastVal -> - {ok, resume(KeepAlive#keepalive{statval = NewVal, repeat = 0})}; - Repeat < 1 -> - {ok, resume(KeepAlive#keepalive{statval = NewVal, repeat = Repeat + 1})}; - true -> - {error, timeout} - end; - {error, Error} -> - {error, Error} - catch - _Error:Reason -> - {error, Reason} - end. - --spec(resume(keepalive()) -> keepalive()). -resume(KeepAlive = #keepalive{tsec = TimeoutSec, tmsg = TimeoutMsg}) -> - KeepAlive#keepalive{tref = timer(TimeoutSec, TimeoutMsg)}. - -%% @doc Cancel Keepalive --spec(cancel(keepalive()) -> ok). -cancel(#keepalive{tref = TRef}) when is_reference(TRef) -> - catch erlang:cancel_timer(TRef), ok. - -timer(Secs, Msg) -> - erlang:send_after(timer:seconds(Secs), self(), Msg). - diff --git a/src/emqx_misc.erl b/src/emqx_misc.erl index 007f444f4..94f859a82 100644 --- a/src/emqx_misc.erl +++ b/src/emqx_misc.erl @@ -16,7 +16,11 @@ -module(emqx_misc). +-include("types.hrl"). + -export([ merge_opts/2 + , maybe_apply/2 + , run_fold/3 , start_timer/2 , start_timer/3 , cancel_timer/1 @@ -44,6 +48,19 @@ merge_opts(Defaults, Options) -> lists:usort([Opt | Acc]) end, Defaults, Options). +%% @doc Apply a function to a maybe argument. +-spec(maybe_apply(fun((maybe(A)) -> maybe(A)), maybe(A)) + -> maybe(A) when A :: any()). +maybe_apply(_Fun, undefined) -> + undefined; +maybe_apply(Fun, Arg) when is_function(Fun) -> + erlang:apply(Fun, [Arg]). + +run_fold([], Acc, _State) -> + Acc; +run_fold([Fun|More], Acc, State) -> + run_fold(More, Fun(Acc, State), State). + -spec(start_timer(integer(), term()) -> reference()). start_timer(Interval, Msg) -> start_timer(Interval, self(), Msg). @@ -52,7 +69,7 @@ start_timer(Interval, Msg) -> start_timer(Interval, Dest, Msg) -> erlang:start_timer(Interval, Dest, Msg). --spec(cancel_timer(undefined | reference()) -> ok). +-spec(cancel_timer(maybe(reference())) -> ok). cancel_timer(Timer) when is_reference(Timer) -> case erlang:cancel_timer(Timer) of false -> diff --git a/src/emqx_packet.erl b/src/emqx_packet.erl index ea5657f41..4b0912d3f 100644 --- a/src/emqx_packet.erl +++ b/src/emqx_packet.erl @@ -19,7 +19,7 @@ -include("emqx.hrl"). -include("emqx_mqtt.hrl"). --export([ protocol_name/1 +-export([ proto_name/1 , type_name/1 , validate/1 , format/1 @@ -28,18 +28,20 @@ , will_msg/1 ]). -%% @doc Protocol name of version --spec(protocol_name(emqx_types:version()) -> binary()). -protocol_name(?MQTT_PROTO_V3) -> +-compile(inline). + +%% @doc Protocol name of the version. +-spec(proto_name(emqx_types:version()) -> binary()). +proto_name(?MQTT_PROTO_V3) -> <<"MQIsdp">>; -protocol_name(?MQTT_PROTO_V4) -> +proto_name(?MQTT_PROTO_V4) -> <<"MQTT">>; -protocol_name(?MQTT_PROTO_V5) -> +proto_name(?MQTT_PROTO_V5) -> <<"MQTT">>. -%% @doc Name of MQTT packet type +%% @doc Name of MQTT packet type. -spec(type_name(emqx_types:packet_type()) -> atom()). -type_name(Type) when Type > ?RESERVED andalso Type =< ?AUTH -> +type_name(Type) when ?RESERVED < Type, Type =< ?AUTH -> lists:nth(Type, ?TYPE_NAMES). %%-------------------------------------------------------------------- diff --git a/src/emqx_protocol.erl b/src/emqx_protocol.erl new file mode 100644 index 000000000..ebd59106d --- /dev/null +++ b/src/emqx_protocol.erl @@ -0,0 +1,136 @@ +%%-------------------------------------------------------------------- +%% Copyright (c) 2019 EMQ Technologies Co., Ltd. All Rights Reserved. +%% +%% Licensed under the Apache License, Version 2.0 (the "License"); +%% you may not use this file except in compliance with the License. +%% You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, software +%% distributed under the License is distributed on an "AS IS" BASIS, +%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +%% See the License for the specific language governing permissions and +%% limitations under the License. +%%-------------------------------------------------------------------- + +%% MQTT Protocol +-module(emqx_protocol). + +-include("types.hrl"). +-include("emqx_mqtt.hrl"). + +-export([ init/1 + , info/1 + , info/2 + , attrs/1 + ]). + +-export([ find_alias/2 + , save_alias/3 + ]). + +-export_type([protocol/0]). + +-record(protocol, { + %% MQTT Proto Name + proto_name :: binary(), + %% MQTT Proto Version + proto_ver :: emqx_types:ver(), + %% Clean Start Flag + clean_start :: boolean(), + %% MQTT Keepalive interval + keepalive :: non_neg_integer(), + %% ClientId in CONNECT Packet + client_id :: emqx_types:client_id(), + %% Username in CONNECT Packet + username :: emqx_types:username(), + %% MQTT Will Msg + will_msg :: emqx_types:message(), + %% MQTT Conn Properties + conn_props :: maybe(emqx_types:properties()), + %% MQTT Topic Aliases + topic_aliases :: maybe(map()) + }). + +-opaque(protocol() :: #protocol{}). + +-spec(init(#mqtt_packet_connect{}) -> protocol()). +init(#mqtt_packet_connect{proto_name = ProtoName, + proto_ver = ProtoVer, + clean_start = CleanStart, + keepalive = Keepalive, + properties = Properties, + client_id = ClientId, + username = Username + } = ConnPkt) -> + WillMsg = emqx_packet:will_msg(ConnPkt), + #protocol{proto_name = ProtoName, + proto_ver = ProtoVer, + clean_start = CleanStart, + keepalive = Keepalive, + client_id = ClientId, + username = Username, + will_msg = WillMsg, + conn_props = Properties + }. + +info(#protocol{proto_name = ProtoName, + proto_ver = ProtoVer, + clean_start = CleanStart, + keepalive = Keepalive, + client_id = ClientId, + username = Username, + will_msg = WillMsg, + conn_props = ConnProps, + topic_aliases = Aliases }) -> + #{proto_name => ProtoName, + proto_ver => ProtoVer, + clean_start => CleanStart, + keepalive => Keepalive, + client_id => ClientId, + username => Username, + will_msg => WillMsg, + conn_props => ConnProps, + topic_aliases => Aliases + }. + +info(proto_name, #protocol{proto_name = ProtoName}) -> + ProtoName; +info(proto_ver, #protocol{proto_ver = ProtoVer}) -> + ProtoVer; +info(clean_start, #protocol{clean_start = CleanStart}) -> + CleanStart; +info(keepalive, #protocol{keepalive = Keepalive}) -> + Keepalive; +info(client_id, #protocol{client_id = ClientId}) -> + ClientId; +info(username, #protocol{username = Username}) -> + Username; +info(will_msg, #protocol{will_msg = WillMsg}) -> + WillMsg; +info(conn_props, #protocol{conn_props = ConnProps}) -> + ConnProps; +info(topic_aliases, #protocol{topic_aliases = Aliases}) -> + Aliases. + +attrs(#protocol{proto_name = ProtoName, + proto_ver = ProtoVer, + clean_start = CleanStart, + keepalive = Keepalive}) -> + #{proto_name => ProtoName, + proto_ver => ProtoVer, + clean_start => CleanStart, + keepalive => Keepalive + }. + +find_alias(_AliasId, #protocol{topic_aliases = undefined}) -> + false; +find_alias(AliasId, #protocol{topic_aliases = Aliases}) -> + maps:find(AliasId, Aliases). + +save_alias(AliasId, Topic, Protocol = #protocol{topic_aliases = undefined}) -> + Protocol#protocol{topic_aliases = #{AliasId => Topic}}; +save_alias(AliasId, Topic, Protocol = #protocol{topic_aliases = Aliases}) -> + Protocol#protocol{topic_aliases = maps:put(AliasId, Topic, Aliases)}. + diff --git a/src/emqx_session.erl b/src/emqx_session.erl index 72a683b2a..98afade52 100644 --- a/src/emqx_session.erl +++ b/src/emqx_session.erl @@ -50,7 +50,7 @@ -logger_header("[Session]"). --export([init/3]). +-export([init/2]). -export([ info/1 , info/2 @@ -58,10 +58,6 @@ , stats/1 ]). --export([ takeover/1 - , resume/2 - ]). - -export([ subscribe/4 , unsubscribe/3 ]). @@ -73,71 +69,51 @@ , pubcomp/2 ]). --export([deliver/2]). +-export([ deliver/2 + , retry/1 + ]). --export([timeout/3]). +-export([ takeover/1 + , resume/2 + ]). + +-export([expire/2]). -export_type([session/0]). --import(emqx_zone, - [ get_env/2 - , get_env/3 - ]). - %% For test case -export([set_pkt_id/2]). --record(session, { - %% Clean Start Flag - clean_start :: boolean(), +-import(emqx_zone, [get_env/3]). +-record(session, { %% Client’s Subscriptions. subscriptions :: map(), - %% Max subscriptions allowed max_subscriptions :: non_neg_integer(), - %% Upgrade QoS? upgrade_qos :: boolean(), - %% Client <- Broker: %% Inflight QoS1, QoS2 messages sent to the client but unacked. inflight :: emqx_inflight:inflight(), - %% All QoS1, QoS2 messages published to when client is disconnected. %% QoS 1 and QoS 2 messages pending transmission to the Client. %% %% Optionally, QoS 0 messages pending transmission to the Client. mqueue :: emqx_mqueue:mqueue(), - %% Next packet id of the session next_pkt_id = 1 :: emqx_types:packet_id(), - %% Retry interval for redelivering QoS1/2 messages retry_interval :: timeout(), - - %% Retry delivery timer - retry_timer :: maybe(reference()), - %% Client -> Broker: %% Inflight QoS2 messages received from client and waiting for pubrel. awaiting_rel :: map(), - %% Max Packets Awaiting PUBREL max_awaiting_rel :: non_neg_integer(), - - %% Awaiting PUBREL Timer - await_rel_timer :: maybe(reference()), - %% Awaiting PUBREL Timeout await_rel_timeout :: timeout(), - %% Session Expiry Interval expiry_interval :: timeout(), - - %% Expired Timer - expiry_timer :: maybe(reference()), - %% Created at created_at :: erlang:timestamp() }). @@ -153,11 +129,10 @@ %%-------------------------------------------------------------------- %% @doc Init a session. --spec(init(boolean(), emqx_types:client(), Options :: map()) -> session()). -init(CleanStart, #{zone := Zone}, #{max_inflight := MaxInflight, - expiry_interval := ExpiryInterval}) -> - #session{clean_start = CleanStart, - max_subscriptions = get_env(Zone, max_subscriptions, 0), +-spec(init(emqx_types:client(), Options :: map()) -> session()). +init(#{zone := Zone}, #{max_inflight := MaxInflight, + expiry_interval := ExpiryInterval}) -> + #session{max_subscriptions = get_env(Zone, max_subscriptions, 0), subscriptions = #{}, upgrade_qos = get_env(Zone, upgrade_qos, false), inflight = emqx_inflight:new(MaxInflight), @@ -183,8 +158,7 @@ init_mqueue(Zone) -> %%-------------------------------------------------------------------- -spec(info(session()) -> emqx_types:infos()). -info(#session{clean_start = CleanStart, - max_subscriptions = MaxSubscriptions, +info(#session{max_subscriptions = MaxSubscriptions, subscriptions = Subscriptions, upgrade_qos = UpgradeQoS, inflight = Inflight, @@ -196,8 +170,7 @@ info(#session{clean_start = CleanStart, await_rel_timeout = AwaitRelTimeout, expiry_interval = ExpiryInterval, created_at = CreatedAt}) -> - #{clean_start => CleanStart, - subscriptions => Subscriptions, + #{subscriptions => Subscriptions, max_subscriptions => MaxSubscriptions, upgrade_qos => UpgradeQoS, inflight => emqx_inflight:size(Inflight), @@ -214,8 +187,6 @@ info(#session{clean_start = CleanStart, created_at => CreatedAt }. -info(clean_start, #session{clean_start = CleanStart}) -> - CleanStart; info(subscriptions, #session{subscriptions = Subs}) -> Subs; info(max_subscriptions, #session{max_subscriptions = MaxSubs}) -> @@ -254,11 +225,9 @@ info(created_at, #session{created_at = CreatedAt}) -> -spec(attrs(session()) -> emqx_types:attrs()). attrs(undefined) -> #{}; -attrs(#session{clean_start = CleanStart, - expiry_interval = ExpiryInterval, +attrs(#session{expiry_interval = ExpiryInterval, created_at = CreatedAt}) -> - #{clean_start => CleanStart, - expiry_interval => ExpiryInterval, + #{expiry_interval => ExpiryInterval, created_at => CreatedAt }. @@ -290,7 +259,7 @@ takeover(#session{subscriptions = Subs}) -> ok = emqx_broker:unsubscribe(TopicFilter) end, maps:to_list(Subs)). --spec(resume(emqx_types:client_id(), session()) -> {ok, session()}). +-spec(resume(emqx_types:client_id(), session()) -> session()). resume(ClientId, Session = #session{subscriptions = Subs}) -> ?LOG(info, "Session is resumed."), %% 1. Subscribe again @@ -300,8 +269,8 @@ resume(ClientId, Session = #session{subscriptions = Subs}) -> %% 2. Run hooks. ok = emqx_hooks:run('session.resumed', [#{client_id => ClientId}, attrs(Session)]), %% TODO: 3. Redeliver: Replay delivery and Dequeue pending messages - %% noreply(ensure_stats_timer(dequeue(retry_delivery(true, State1)))); - {ok, Session}. + %% noreply(dequeue(retry_delivery(true, State1))); + Session. %%-------------------------------------------------------------------- %% Client -> Broker: SUBSCRIBE @@ -388,7 +357,7 @@ do_publish(PacketId, Msg = #message{timestamp = Ts}, DeliverResults = emqx_broker:publish(Msg), AwaitingRel1 = maps:put(PacketId, Ts, AwaitingRel), Session1 = Session#session{awaiting_rel = AwaitingRel1}, - {ok, DeliverResults, ensure_await_rel_timer(Session1)}; + {ok, DeliverResults, Session1}; true -> {error, ?RC_PACKET_IDENTIFIER_IN_USE} end. @@ -544,9 +513,8 @@ enqueue(Msg, Session = #session{mqueue = Q}) -> %%-------------------------------------------------------------------- await(PacketId, Msg, Session = #session{inflight = Inflight}) -> - Inflight1 = emqx_inflight:insert( - PacketId, {Msg, os:timestamp()}, Inflight), - ensure_retry_timer(Session#session{inflight = Inflight1}). + Inflight1 = emqx_inflight:insert(PacketId, {Msg, os:timestamp()}, Inflight), + Session#session{inflight = Inflight1}. get_subopts(Topic, SubMap) -> case maps:find(Topic, SubMap) of @@ -578,44 +546,12 @@ enrich([{rap, _}|Opts], Msg = #message{flags = Flags}, Session) -> enrich([{subid, SubId}|Opts], Msg, Session) -> enrich(Opts, emqx_message:set_header('Subscription-Identifier', SubId, Msg), Session). -%%-------------------------------------------------------------------- -%% Handle timeout -%%-------------------------------------------------------------------- - --spec(timeout(reference(), atom(), session()) - -> {ok, session()} | {ok, list(), session()}). -timeout(TRef, retry_delivery, Session = #session{retry_timer = TRef}) -> - retry_delivery(Session#session{retry_timer = undefined}); - -timeout(TRef, check_awaiting_rel, Session = #session{await_rel_timer = TRef}) -> - expire_awaiting_rel(Session); - -timeout(TRef, Msg, Session) -> - ?LOG(error, "unexpected timeout - ~p: ~p", [TRef, Msg]), - {ok, Session}. - -%%-------------------------------------------------------------------- -%% Ensure retry timer -%%-------------------------------------------------------------------- - -ensure_retry_timer(Session = #session{retry_interval = Interval, - retry_timer = undefined}) -> - ensure_retry_timer(Interval, Session); -ensure_retry_timer(Session) -> - Session. - -ensure_retry_timer(Interval, Session = #session{retry_timer = undefined}) -> - TRef = emqx_misc:start_timer(Interval, retry_delivery), - Session#session{retry_timer = TRef}; -ensure_retry_timer(_Interval, Session) -> - Session. - %%-------------------------------------------------------------------- %% Retry Delivery %%-------------------------------------------------------------------- %% Redeliver at once if force is true -retry_delivery(Session = #session{inflight = Inflight}) -> +retry(Session = #session{inflight = Inflight}) -> case emqx_inflight:is_empty(Inflight) of true -> {ok, Session}; false -> @@ -626,10 +562,11 @@ retry_delivery(Session = #session{inflight = Inflight}) -> retry_delivery([], _Now, Acc, Session) -> %% Retry again... - {ok, lists:reverse(Acc), ensure_retry_timer(Session)}; + {ok, lists:reverse(Acc), Session}; retry_delivery([{PacketId, {Val, Ts}}|More], Now, Acc, - Session = #session{retry_interval = Interval, inflight = Inflight}) -> + Session = #session{retry_interval = Interval, + inflight = Inflight}) -> %% Microseconds -> MilliSeconds Age = timer:now_diff(Now, Ts) div 1000, if @@ -637,7 +574,7 @@ retry_delivery([{PacketId, {Val, Ts}}|More], Now, Acc, {Acc1, Inflight1} = retry_delivery(PacketId, Val, Now, Acc, Inflight), retry_delivery(More, Now, Acc1, Session#session{inflight = Inflight1}); true -> - {ok, lists:reverse(Acc), ensure_retry_timer(Interval - max(0, Age), Session)} + {ok, lists:reverse(Acc), Interval - max(0, Age), Session} end. retry_delivery(PacketId, Msg, Now, Acc, Inflight) when is_record(Msg, message) -> @@ -654,34 +591,20 @@ retry_delivery(PacketId, pubrel, Now, Acc, Inflight) -> Inflight1 = emqx_inflight:update(PacketId, {pubrel, Now}, Inflight), {[{pubrel, PacketId}|Acc], Inflight1}. -%%-------------------------------------------------------------------- -%% Ensure await_rel timer -%%-------------------------------------------------------------------- - -ensure_await_rel_timer(Session = #session{await_rel_timeout = Timeout, - await_rel_timer = undefined}) -> - ensure_await_rel_timer(Timeout, Session); -ensure_await_rel_timer(Session) -> - Session. - -ensure_await_rel_timer(Timeout, Session = #session{await_rel_timer = undefined}) -> - TRef = emqx_misc:start_timer(Timeout, check_awaiting_rel), - Session#session{await_rel_timer = TRef}; -ensure_await_rel_timer(_Timeout, Session) -> - Session. - %%-------------------------------------------------------------------- %% Expire Awaiting Rel %%-------------------------------------------------------------------- -expire_awaiting_rel(Session = #session{awaiting_rel = AwaitingRel}) -> +expire(awaiting_rel, Session = #session{awaiting_rel = AwaitingRel}) -> case maps:size(AwaitingRel) of 0 -> {ok, Session}; - _ -> expire_awaiting_rel(lists:keysort(2, maps:to_list(AwaitingRel)), os:timestamp(), Session) + _ -> + AwaitingRel1 = lists:keysort(2, maps:to_list(AwaitingRel)), + expire_awaiting_rel(AwaitingRel1, os:timestamp(), Session) end. expire_awaiting_rel([], _Now, Session) -> - {ok, Session#session{await_rel_timer = undefined}}; + {ok, Session}; expire_awaiting_rel([{PacketId, Ts} | More], Now, Session = #session{awaiting_rel = AwaitingRel, @@ -693,7 +616,7 @@ expire_awaiting_rel([{PacketId, Ts} | More], Now, Session1 = Session#session{awaiting_rel = maps:remove(PacketId, AwaitingRel)}, expire_awaiting_rel(More, Now, Session1); Age -> - {ok, ensure_await_rel_timer(Timeout - max(0, Age), Session)} + {ok, Timeout - max(0, Age), Session} end. %%-------------------------------------------------------------------- diff --git a/src/emqx_ws_connection.erl b/src/emqx_ws_connection.erl index 530a300f8..6dc709000 100644 --- a/src/emqx_ws_connection.erl +++ b/src/emqx_ws_connection.erl @@ -22,7 +22,7 @@ -include("logger.hrl"). -include("types.hrl"). --logger_header("[WsConn]"). +-logger_header("[WsConnection]"). -export([ info/1 , attrs/1 @@ -49,7 +49,6 @@ serialize :: fun((emqx_types:packet()) -> iodata()), parse_state :: emqx_frame:parse_state(), chan_state :: emqx_channel:channel(), - keepalive :: maybe(emqx_keepalive:keepalive()), pendings :: list(), reason :: term() }). @@ -57,7 +56,7 @@ -type(state() :: #state{}). -define(SOCK_STATS, [recv_oct, recv_cnt, send_oct, send_cnt]). --define(CHAN_STATS, [recv_pkt, recv_msg, send_pkt, send_msg]). +-define(CONN_STATS, [recv_pkt, recv_msg, send_pkt, send_msg]). %%-------------------------------------------------------------------- %% API @@ -66,36 +65,37 @@ -spec(info(pid() | state()) -> emqx_types:infos()). info(WSPid) when is_pid(WSPid) -> call(WSPid, info); -info(#state{peername = Peername, - sockname = Sockname, - chan_state = ChanState - }) -> - ConnInfo = #{socktype => websocket, - peername => Peername, - sockname => Sockname, +info(#state{peername = Peername, + sockname = Sockname, + chan_state = ChanState}) -> + ConnInfo = #{socktype => websocket, + peername => Peername, + sockname => Sockname, conn_state => running }, - maps:merge(ConnInfo, emqx_channel:info(ChanState)). + ChanInfo = emqx_channel:info(ChanState), + maps:merge(ConnInfo, ChanInfo). -spec(attrs(pid() | state()) -> emqx_types:attrs()). attrs(WSPid) when is_pid(WSPid) -> call(WSPid, attrs); -attrs(#state{peername = Peername, - sockname = Sockname, +attrs(#state{peername = Peername, + sockname = Sockname, chan_state = ChanState}) -> ConnAttrs = #{socktype => websocket, peername => Peername, sockname => Sockname }, - maps:merge(ConnAttrs, emqx_channel:attrs(ChanState)). + ChanAttrs = emqx_channel:attrs(ChanState), + maps:merge(ConnAttrs, ChanAttrs). -spec(stats(pid() | state()) -> emqx_types:stats()). stats(WSPid) when is_pid(WSPid) -> call(WSPid, stats); stats(#state{chan_state = ChanState}) -> ProcStats = emqx_misc:proc_stats(), - SessStats = emqx_session:stats(emqx_channel:info(session, ChanState)), - lists:append([ProcStats, SessStats, chan_stats(), wsock_stats()]). + ChanStats = emqx_channel:stats(ChanState), + lists:append([ProcStats, wsock_stats(), conn_stats(), ChanStats]). -spec(kick(pid()) -> ok). kick(CPid) -> @@ -105,7 +105,7 @@ kick(CPid) -> discard(WSPid) -> WSPid ! {cast, discard}, ok. --spec(takeover(pid(), 'begin'|'end') -> {ok, Result :: term()}). +-spec(takeover(pid(), 'begin'|'end') -> Result :: term()). takeover(CPid, Phase) -> call(CPid, {takeover, Phase}). @@ -177,17 +177,14 @@ websocket_init([Req, Opts]) -> MaxSize = emqx_zone:get_env(Zone, max_packet_size, ?MAX_PACKET_SIZE), ParseState = emqx_frame:initial_parse_state(#{max_size => MaxSize}), emqx_logger:set_metadata_peername(esockd_net:format(Peername)), - {ok, #state{peername = Peername, - sockname = Sockname, - fsm_state = idle, - parse_state = ParseState, - chan_state = ChanState, - pendings = [] + {ok, #state{peername = Peername, + sockname = Sockname, + fsm_state = idle, + parse_state = ParseState, + chan_state = ChanState, + pendings = [] }}. -stat_fun() -> - fun() -> {ok, emqx_pd:get_counter(recv_oct)} end. - websocket_handle({binary, Data}, State) when is_list(Data) -> websocket_handle({binary, iolist_to_binary(Data)}, State); @@ -199,7 +196,7 @@ websocket_handle({binary, Data}, State = #state{chan_state = ChanState}) emqx_pd:update_counter(recv_oct, Oct), ok = emqx_metrics:inc('bytes.received', Oct), NChanState = emqx_channel:ensure_timer( - emit_stats, emqx_channel:gc(1, Oct, ChanState)), + stats_timer, emqx_channel:gc(1, Oct, ChanState)), process_incoming(Data, State#state{chan_state = NChanState}); %% Pings should be replied with pongs, cowboy does it automatically @@ -231,6 +228,16 @@ websocket_info({call, From, kick}, State) -> gen_server:reply(From, ok), stop(kicked, State); +websocket_info({call, From, Req}, State = #state{chan_state = ChanState}) -> + case emqx_channel:handle_call(Req, ChanState) of + {ok, Reply, NChanState} -> + _ = gen_server:reply(From, Reply), + {ok, State#state{chan_state = NChanState}}; + {stop, Reason, Reply, NChanState} -> + _ = gen_server:reply(From, Reply), + stop(Reason, State#state{chan_state = NChanState}) + end; + websocket_info({cast, Msg}, State = #state{chan_state = ChanState}) -> case emqx_channel:handle_cast(Msg, ChanState) of {ok, NChanState} -> @@ -262,7 +269,8 @@ websocket_info({incoming, Packet}, State = #state{fsm_state = connected}) websocket_info(Deliver = {deliver, _Topic, _Msg}, State = #state{chan_state = ChanState}) -> - case emqx_channel:handle_out(Deliver, ChanState) of + Delivers = emqx_misc:drain_deliver([Deliver]), + case emqx_channel:handle_out({deliver, Delivers}, ChanState) of {ok, NChanState} -> reply(State#state{chan_state = NChanState}); {ok, Packets, NChanState} -> @@ -271,16 +279,9 @@ websocket_info(Deliver = {deliver, _Topic, _Msg}, stop(Reason, State#state{chan_state = NChanState}) end; -websocket_info({keepalive, check}, State = #state{keepalive = KeepAlive}) -> - case emqx_keepalive:check(KeepAlive) of - {ok, KeepAlive1} -> - {ok, State#state{keepalive = KeepAlive1}}; - {error, timeout} -> - stop(keepalive_timeout, State); - {error, Error} -> - ?LOG(error, "Keepalive error: ~p", [Error]), - stop(keepalive_error, State) - end; +websocket_info({timeout, TRef, keepalive}, State) when is_reference(TRef) -> + RecvOct = emqx_pd:get_counter(recv_oct), + handle_timeout(TRef, {keepalive, RecvOct}, State); websocket_info({timeout, TRef, emit_stats}, State) when is_reference(TRef) -> handle_timeout(TRef, {emit_stats, stats(State)}, State); @@ -310,13 +311,10 @@ websocket_info(Info, State = #state{chan_state = ChanState}) -> stop(Reason, State#state{chan_state = NChanState}) end. -terminate(SockError, _Req, #state{keepalive = KeepAlive, - chan_state = ChanState, +terminate(SockError, _Req, #state{chan_state = ChanState, reason = Reason}) -> ?LOG(debug, "Terminated for ~p, sockerror: ~p", [Reason, SockError]), - KeepAlive =/= undefined - andalso emqx_keepalive:cancel(KeepAlive), emqx_channel:terminate(Reason, ChanState). %%-------------------------------------------------------------------- @@ -324,18 +322,10 @@ terminate(SockError, _Req, #state{keepalive = KeepAlive, connected(State = #state{chan_state = ChanState}) -> NState = State#state{fsm_state = connected}, - ClientId = emqx_channel:info(client_id, ChanState), + #{client_id := ClientId} = emqx_channel:info(client, ChanState), ok = emqx_cm:register_channel(ClientId), - ok = emqx_cm:set_chan_attrs(ClientId, info(NState)), - %% Ensure keepalive after connected successfully. - Interval = emqx_channel:info(keepalive, ChanState), - case ensure_keepalive(Interval, NState) of - ignore -> reply(NState); - {ok, KeepAlive} -> - reply(NState#state{keepalive = KeepAlive}); - {error, Reason} -> - stop(Reason, NState) - end. + ok = emqx_cm:set_chan_attrs(ClientId, attrs(NState)), + reply(NState). %%-------------------------------------------------------------------- %% Handle timeout @@ -350,16 +340,6 @@ handle_timeout(TRef, Msg, State = #state{chan_state = ChanState}) -> stop(Reason, State#state{chan_state = NChanState}) end. -%%-------------------------------------------------------------------- -%% Ensure keepalive - -ensure_keepalive(0, _State) -> - ignore; -ensure_keepalive(Interval, #state{chan_state = ChanState}) -> - Backoff = emqx_zone:get_env(emqx_channel:info(zone, ChanState), - keepalive_backoff, 0.75), - emqx_keepalive:start(stat_fun(), round(Interval * Backoff), {keepalive, check}). - %%-------------------------------------------------------------------- %% Process incoming data @@ -440,7 +420,7 @@ reply(State = #state{pendings = []}) -> {ok, State}; reply(State = #state{chan_state = ChanState, pendings = Pendings}) -> Reply = handle_outgoing(Pendings, State), - NChanState = emqx_channel:ensure_timer(emit_stats, ChanState), + NChanState = emqx_channel:ensure_timer(stats_timer, ChanState), {reply, Reply, State#state{chan_state = NChanState, pendings = []}}. stop(Reason, State = #state{pendings = []}) -> @@ -458,6 +438,6 @@ enqueue(Packets, State = #state{pendings = Pendings}) -> wsock_stats() -> [{Key, emqx_pd:get_counter(Key)} || Key <- ?SOCK_STATS]. -chan_stats() -> - [{Name, emqx_pd:get_counter(Name)} || Name <- ?CHAN_STATS]. +conn_stats() -> + [{Name, emqx_pd:get_counter(Name)} || Name <- ?CONN_STATS]. diff --git a/test/emqx_channel_SUITE.erl b/test/emqx_channel_SUITE.erl index 02f9aa975..e7326cd76 100644 --- a/test/emqx_channel_SUITE.erl +++ b/test/emqx_channel_SUITE.erl @@ -22,7 +22,6 @@ -import(emqx_channel, [ handle_in/2 , handle_out/2 - , handle_out/3 ]). -include("emqx.hrl"). @@ -58,9 +57,10 @@ t_handle_connect(_) -> fun(Channel) -> {ok, ?CONNACK_PACKET(?RC_SUCCESS), Channel1} = handle_in(?CONNECT_PACKET(ConnPkt), Channel), - Client = emqx_channel:info(client, Channel1), - ?assertEqual(<<"clientid">>, maps:get(client_id, Client)), - ?assertEqual(<<"username">>, maps:get(username, Client)) + #{client_id := ClientId, username := Username} + = emqx_channel:info(client, Channel1), + ?assertEqual(<<"clientid">>, ClientId), + ?assertEqual(<<"username">>, Username) end). t_handle_publish_qos0(_) -> @@ -86,8 +86,8 @@ t_handle_publish_qos2(_) -> Publish2 = ?PUBLISH_PACKET(?QOS_2, <<"topic">>, 2, <<"payload">>), {ok, ?PUBREC_PACKET(2, RC), Channel2} = handle_in(Publish2, Channel1), ?assert((RC == ?RC_SUCCESS) orelse (RC == ?RC_NO_MATCHING_SUBSCRIBERS)), - Session = emqx_channel:info(session, Channel2), - ?assertEqual(2, emqx_session:info(awaiting_rel, Session)) + #{awaiting_rel := AwaitingRel} = emqx_channel:info(session, Channel2), + ?assertEqual(2, AwaitingRel) end). t_handle_puback(_) -> @@ -122,10 +122,9 @@ t_handle_subscribe(_) -> TopicFilters = [{<<"+">>, ?DEFAULT_SUBOPTS}], {ok, ?SUBACK_PACKET(10, [?QOS_0]), Channel1} = handle_in(?SUBSCRIBE_PACKET(10, #{}, TopicFilters), Channel), - Session = emqx_channel:info(session, Channel1), - ?assertEqual(maps:from_list(TopicFilters), - emqx_session:info(subscriptions, Session)) - + #{subscriptions := Subscriptions} + = emqx_channel:info(session, Channel1), + ?assertEqual(maps:from_list(TopicFilters), Subscriptions) end). t_handle_unsubscribe(_) -> @@ -145,7 +144,7 @@ t_handle_disconnect(_) -> with_channel( fun(Channel) -> {stop, normal, Channel1} = handle_in(?DISCONNECT_PACKET(?RC_SUCCESS), Channel), - ?assertEqual(undefined, emqx_channel:info(will_msg, Channel1)) + ?assertMatch(#{will_msg := undefined}, emqx_channel:info(protocol, Channel1)) end). t_handle_auth(_) -> @@ -166,9 +165,8 @@ t_handle_deliver(_) -> = handle_in(?SUBSCRIBE_PACKET(1, #{}, TopicFilters), Channel), Msg0 = emqx_message:make(<<"clientx">>, ?QOS_0, <<"t0">>, <<"qos0">>), Msg1 = emqx_message:make(<<"clientx">>, ?QOS_1, <<"t1">>, <<"qos1">>), - %% TODO: Fixme later. - self() ! {deliver, <<"+">>, Msg1}, - {ok, Packets, _Channel2} = emqx_channel:handle_out({deliver, <<"+">>, Msg0}, Channel1), + Delivers = [{deliver, <<"+">>, Msg0}, {deliver, <<"+">>, Msg1}], + {ok, Packets, _Ch} = emqx_channel:handle_out({deliver, Delivers}, Channel1), ?assertMatch([?PUBLISH_PACKET(?QOS_0, <<"t0">>, undefined, <<"qos0">>), ?PUBLISH_PACKET(?QOS_1, <<"t1">>, 1, <<"qos1">>) ], Packets) @@ -178,13 +176,13 @@ t_handle_deliver(_) -> %% Test cases for handle_out %%-------------------------------------------------------------------- -t_handle_conack(_) -> +t_handle_connack(_) -> with_channel( fun(Channel) -> {ok, ?CONNACK_PACKET(?RC_SUCCESS, SP, _), _} - = handle_out(connack, {?RC_SUCCESS, 0}, Channel), + = handle_out({connack, ?RC_SUCCESS, 0}, Channel), {stop, {shutdown, unauthorized_client}, ?CONNACK_PACKET(5), _} - = handle_out(connack, ?RC_NOT_AUTHORIZED, Channel) + = handle_out({connack, ?RC_NOT_AUTHORIZED}, Channel) end). t_handle_out_publish(_) -> @@ -194,59 +192,59 @@ t_handle_out_publish(_) -> Pub1 = {publish, 1, emqx_message:make(<<"c">>, ?QOS_1, <<"t">>, <<"qos1">>)}, {ok, ?PUBLISH_PACKET(?QOS_0), Channel} = handle_out(Pub0, Channel), {ok, ?PUBLISH_PACKET(?QOS_1), Channel} = handle_out(Pub1, Channel), - {ok, Packets, Channel} = handle_out(publish, [Pub0, Pub1], Channel), + {ok, Packets, Channel} = handle_out({publish, [Pub0, Pub1]}, Channel), ?assertEqual(2, length(Packets)) end). t_handle_out_puback(_) -> with_channel( fun(Channel) -> - {ok, Channel} = handle_out(puberr, ?RC_NOT_AUTHORIZED, Channel), + {ok, Channel} = handle_out({puberr, ?RC_NOT_AUTHORIZED}, Channel), {ok, ?PUBACK_PACKET(1, ?RC_SUCCESS), Channel} - = handle_out(puback, {1, ?RC_SUCCESS}, Channel) + = handle_out({puback, 1, ?RC_SUCCESS}, Channel) end). t_handle_out_pubrec(_) -> with_channel( fun(Channel) -> {ok, ?PUBREC_PACKET(4, ?RC_SUCCESS), Channel} - = handle_out(pubrec, {4, ?RC_SUCCESS}, Channel) + = handle_out({pubrec, 4, ?RC_SUCCESS}, Channel) end). t_handle_out_pubrel(_) -> with_channel( fun(Channel) -> {ok, ?PUBREL_PACKET(2), Channel} - = handle_out(pubrel, {2, ?RC_SUCCESS}, Channel), + = handle_out({pubrel, 2, ?RC_SUCCESS}, Channel), {ok, ?PUBREL_PACKET(3, ?RC_SUCCESS), Channel} - = handle_out(pubrel, {3, ?RC_SUCCESS}, Channel) + = handle_out({pubrel, 3, ?RC_SUCCESS}, Channel) end). t_handle_out_pubcomp(_) -> with_channel( fun(Channel) -> {ok, ?PUBCOMP_PACKET(5, ?RC_SUCCESS), Channel} - = handle_out(pubcomp, {5, ?RC_SUCCESS}, Channel) + = handle_out({pubcomp, 5, ?RC_SUCCESS}, Channel) end). t_handle_out_suback(_) -> with_channel( fun(Channel) -> {ok, ?SUBACK_PACKET(1, [?QOS_2]), Channel} - = handle_out(suback, {1, [?QOS_2]}, Channel) + = handle_out({suback, 1, [?QOS_2]}, Channel) end). t_handle_out_unsuback(_) -> with_channel( fun(Channel) -> {ok, ?UNSUBACK_PACKET(1), Channel} - = handle_out(unsuback, {1, [?RC_SUCCESS]}, Channel) + = handle_out({unsuback, 1, [?RC_SUCCESS]}, Channel) end). t_handle_out_disconnect(_) -> with_channel( fun(Channel) -> - handle_out(disconnect, ?RC_SUCCESS, Channel) + handle_out({disconnect, ?RC_SUCCESS}, Channel) end). %%-------------------------------------------------------------------- @@ -281,9 +279,20 @@ with_channel(Fun) -> }, Options = [{zone, testing}], Channel = emqx_channel:init(ConnInfo, Options), - Session = emqx_session:init(false, #{zone => testing}, - #{max_inflight => 100, + ConnPkt = #mqtt_packet_connect{ + proto_name = <<"MQTT">>, + proto_ver = ?MQTT_PROTO_V4, + clean_start = true, + keepalive = 30, + properties = #{}, + client_id = <<"clientid">>, + username = <<"username">> + }, + Protocol = emqx_protocol:init(ConnPkt), + Session = emqx_session:init(#{zone => testing}, + #{max_inflight => 100, expiry_interval => 0 }), - Fun(emqx_channel:set(session, Session, Channel)). + Fun(emqx_channel:set(protocol, Protocol, + emqx_channel:set(session, Session, Channel))). diff --git a/test/emqx_keepalive_SUITE.erl b/test/emqx_keepalive_SUITE.erl index f140913ec..0bdc79f60 100644 --- a/test/emqx_keepalive_SUITE.erl +++ b/test/emqx_keepalive_SUITE.erl @@ -19,23 +19,24 @@ -compile(export_all). -compile(nowarn_export_all). +-include_lib("eunit/include/eunit.hrl"). + all() -> emqx_ct:all(?MODULE). -t_keepalive(_) -> - {ok, KA} = emqx_keepalive:start(fun() -> {ok, 1} end, 1, {keepalive, timeout}), - [resumed, timeout] = lists:reverse(keepalive_recv(KA, [])). - -keepalive_recv(KA, Acc) -> - receive - {keepalive, timeout} -> - case emqx_keepalive:check(KA) of - {ok, KA1} -> keepalive_recv(KA1, [resumed | Acc]); - {error, timeout} -> [timeout | Acc] - end - after 4000 -> Acc - end. - -t_cancel(_) -> - {ok, KA} = emqx_keepalive:start(fun() -> {ok, 1} end, 1, {keepalive, timeout}), - ok = emqx_keepalive:cancel(KA). +t_check(_) -> + Keepalive = emqx_keepalive:init(60), + ?assertEqual(60, emqx_keepalive:info(interval, Keepalive)), + ?assertEqual(0, emqx_keepalive:info(statval, Keepalive)), + ?assertEqual(0, emqx_keepalive:info(repeat, Keepalive)), + Info = emqx_keepalive:info(Keepalive), + ?assertEqual(#{interval => 60, + statval => 0, + repeat => 0}, Info), + {ok, Keepalive1} = emqx_keepalive:check(1, Keepalive), + ?assertEqual(1, emqx_keepalive:info(statval, Keepalive1)), + ?assertEqual(0, emqx_keepalive:info(repeat, Keepalive1)), + {ok, Keepalive2} = emqx_keepalive:check(1, Keepalive1), + ?assertEqual(1, emqx_keepalive:info(statval, Keepalive2)), + ?assertEqual(1, emqx_keepalive:info(repeat, Keepalive2)), + ?assertEqual({error, timeout}, emqx_keepalive:check(1, Keepalive2)). diff --git a/test/emqx_net_SUITE.erl b/test/emqx_net_SUITE.erl deleted file mode 100644 index 439ac6c70..000000000 --- a/test/emqx_net_SUITE.erl +++ /dev/null @@ -1,45 +0,0 @@ -%%-------------------------------------------------------------------- -%% Copyright (c) 2019 EMQ Technologies Co., Ltd. All Rights Reserved. -%% -%% Licensed under the Apache License, Version 2.0 (the "License"); -%% you may not use this file except in compliance with the License. -%% You may obtain a copy of the License at -%% -%% http://www.apache.org/licenses/LICENSE-2.0 -%% -%% Unless required by applicable law or agreed to in writing, software -%% distributed under the License is distributed on an "AS IS" BASIS, -%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -%% See the License for the specific language governing permissions and -%% limitations under the License. -%%-------------------------------------------------------------------- - --module(emqx_net_SUITE). - -%% CT --compile(export_all). --compile(nowarn_export_all). - -all() -> [{group, keepalive}]. - -groups() -> [{keepalive, [], [t_keepalive]}]. - -%%-------------------------------------------------------------------- -%% Keepalive -%%-------------------------------------------------------------------- - -t_keepalive(_) -> - {ok, KA} = emqx_keepalive:start(fun() -> {ok, 1} end, 1, {keepalive, timeout}), - [resumed, timeout] = lists:reverse(keepalive_recv(KA, [])). - -keepalive_recv(KA, Acc) -> - receive - {keepalive, timeout} -> - case emqx_keepalive:check(KA) of - {ok, KA1} -> keepalive_recv(KA1, [resumed | Acc]); - {error, timeout} -> [timeout | Acc] - end - after 4000 -> - Acc - end. - diff --git a/test/emqx_packet_SUITE.erl b/test/emqx_packet_SUITE.erl index b334093b8..732c48fe3 100644 --- a/test/emqx_packet_SUITE.erl +++ b/test/emqx_packet_SUITE.erl @@ -27,9 +27,9 @@ all() -> emqx_ct:all(?MODULE). t_proto_name(_) -> - ?assertEqual(<<"MQIsdp">>, emqx_packet:protocol_name(3)), - ?assertEqual(<<"MQTT">>, emqx_packet:protocol_name(4)), - ?assertEqual(<<"MQTT">>, emqx_packet:protocol_name(5)). + ?assertEqual(<<"MQIsdp">>, emqx_packet:proto_name(3)), + ?assertEqual(<<"MQTT">>, emqx_packet:proto_name(4)), + ?assertEqual(<<"MQTT">>, emqx_packet:proto_name(5)). t_type_name(_) -> ?assertEqual('CONNECT', emqx_packet:type_name(?CONNECT)), diff --git a/test/emqx_protocol_SUITE.erl b/test/emqx_protocol_SUITE.erl new file mode 100644 index 000000000..89f1d7344 --- /dev/null +++ b/test/emqx_protocol_SUITE.erl @@ -0,0 +1,49 @@ +%%-------------------------------------------------------------------- +%% Copyright (c) 2019 EMQ Technologies Co., Ltd. All Rights Reserved. +%% +%% Licensed under the Apache License, Version 2.0 (the "License"); +%% you may not use this file except in compliance with the License. +%% You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, software +%% distributed under the License is distributed on an "AS IS" BASIS, +%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +%% See the License for the specific language governing permissions and +%% limitations under the License. +%%-------------------------------------------------------------------- + +-module(emqx_protocol_SUITE). + +-compile(export_all). +-compile(nowarn_export_all). + +-include("emqx_mqtt.hrl"). +-include_lib("eunit/include/eunit.hrl"). + +all() -> emqx_ct:all(?MODULE). + +t_init_and_info(_) -> + ConnPkt = #mqtt_packet_connect{ + proto_name = <<"MQTT">>, + proto_ver = ?MQTT_PROTO_V4, + is_bridge = false, + clean_start = true, + keepalive = 30, + properties = #{}, + client_id = <<"clientid">>, + username = <<"username">>, + password = <<"passwd">> + }, + Proto = emqx_protocol:init(ConnPkt), + ?assertEqual(<<"MQTT">>, emqx_protocol:info(proto_name, Proto)), + ?assertEqual(?MQTT_PROTO_V4, emqx_protocol:info(proto_ver, Proto)), + ?assertEqual(true, emqx_protocol:info(clean_start, Proto)), + ?assertEqual(<<"clientid">>, emqx_protocol:info(client_id, Proto)), + ?assertEqual(<<"username">>, emqx_protocol:info(username, Proto)), + ?assertEqual(undefined, emqx_protocol:info(will_msg, Proto)), + ?assertEqual(#{}, emqx_protocol:info(conn_props, Proto)). + + + diff --git a/test/emqx_session_SUITE.erl b/test/emqx_session_SUITE.erl index 224b8afaf..c142284c1 100644 --- a/test/emqx_session_SUITE.erl +++ b/test/emqx_session_SUITE.erl @@ -181,8 +181,7 @@ timeout_args() -> {tref(), timeout_msg()}. info_args() -> - oneof([clean_start, - subscriptions, + oneof([subscriptions, max_subscriptions, upgrade_qos, inflight, @@ -292,16 +291,14 @@ expiry_interval() -> ?LET(EI, choose(1, 10), EI * 3600). option() -> ?LET(Option, [{max_inflight, max_inflight()}, - {expiry_interval, expiry_interval()}] - , maps:from_list(Option)). - -cleanstart() -> bool(). + {expiry_interval, expiry_interval()}], + maps:from_list(Option)). session() -> - ?LET({CleanStart, Zone, Options}, - {cleanstart(), zone(), option()}, + ?LET({Zone, Options}, + {zone(), option()}, begin - Session = emqx_session:init(CleanStart, #{zone => Zone}, Options), + Session = emqx_session:init(#{zone => Zone}, Options), emqx_session:set_pkt_id(Session, 16#ffff) end).