From 0f9b5ff3a124bf4b8585a85c576fa74028a5b025 Mon Sep 17 00:00:00 2001 From: JianBo He Date: Tue, 20 Jul 2021 14:56:15 +0800 Subject: [PATCH] refactor(gw): use typical conn&channel to implement mqtt-sn gateway --- .../src/mqttsn/emqx_sn_channel.erl | 1384 +++++++++++++++++ apps/emqx_gateway/src/mqttsn/emqx_sn_conn.erl | 802 ++++++++++ .../emqx_gateway/src/mqttsn/emqx_sn_frame.erl | 92 +- apps/emqx_gateway/src/mqttsn/emqx_sn_impl.erl | 2 +- .../src/mqttsn/include/emqx_sn.hrl | 35 +- 5 files changed, 2274 insertions(+), 41 deletions(-) create mode 100644 apps/emqx_gateway/src/mqttsn/emqx_sn_channel.erl create mode 100644 apps/emqx_gateway/src/mqttsn/emqx_sn_conn.erl diff --git a/apps/emqx_gateway/src/mqttsn/emqx_sn_channel.erl b/apps/emqx_gateway/src/mqttsn/emqx_sn_channel.erl new file mode 100644 index 000000000..f2d548409 --- /dev/null +++ b/apps/emqx_gateway/src/mqttsn/emqx_sn_channel.erl @@ -0,0 +1,1384 @@ +%%-------------------------------------------------------------------- +%% Copyright (c) 2020-2021 EMQ Technologies Co., Ltd. All Rights Reserved. +%% +%% Licensed under the Apache License, Version 2.0 (the "License"); +%% you may not use this file except in compliance with the License. +%% You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, software +%% distributed under the License is distributed on an "AS IS" BASIS, +%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +%% See the License for the specific language governing permissions and +%% limitations under the License. +%%-------------------------------------------------------------------- + +-module(emqx_sn_channel). + +-include("src/mqttsn/include/emqx_sn.hrl"). +-include_lib("emqx/include/emqx.hrl"). +-include_lib("emqx/include/emqx_mqtt.hrl"). +-include_lib("emqx/include/logger.hrl"). + +-logger_header("[SN-Proto]"). + +%% API +-export([ info/1 + , info/2 + , stats/1 + ]). + +-export([ init/2 + , handle_in/2 + , handle_out/3 + , handle_deliver/2 + , handle_timeout/3 + , terminate/2 + , set_conn_state/2 + ]). + +-export([ handle_call/2 + , handle_cast/2 + , handle_info/2 + ]). + +-record(channel, { + %% Context + ctx :: emqx_gateway_ctx:context(), + %% Registry + registry :: pid(), + %% Gateway Id + gateway_id :: integer(), + %% Enable QoS3 + enable_qos3 :: boolean(), %% XXX: Get confs from ctx ? + %% MQTT-SN Connection Info + conninfo :: emqx_types:conninfo(), + %% MQTT-SN Client Info + clientinfo :: emqx_types:clientinfo(), + %% Session + session :: emqx_session:session() | undefined, + %% Keepalive + keepalive :: emqx_keepalive:keepalive() | undefined, + %% Will Msg + will_msg :: emqx_types:message() | undefined, + %% ClientInfo override specs + clientinfo_override :: map(), + %% Connection State + conn_state :: conn_state(), + %% Timer + timers :: #{atom() => disable | undefined | reference()}, + %%% Takeover + takeover :: boolean(), + %% Resume + resuming :: boolean(), + %% Pending delivers when takeovering + pendings :: list() + }). + +-type(channel() :: #channel{}). + +-type(conn_state() :: idle | connecting | connected | asleep | disconnected). + +-type(reply() :: {outgoing, mqtt_sn_message()} + | {outgoing, [mqtt_sn_message()]} + | {event, conn_state()|updated} + | {close, Reason :: atom()}). + +-type(replies() :: emqx_sn_frame:packet() | reply() | [reply()]). + +-define(TIMER_TABLE, #{ + alive_timer => keepalive, + retry_timer => retry_delivery, + await_timer => expire_awaiting_rel, + expire_timer => expire_session, + asleep_timer => expire_asleep + }). + +-define(DEFAULT_OVERRIDE, + #{ clientid => <<"">> %% Generate clientid by default + , username => <<"${Packet.headers.login}">> + , password => <<"${Packet.headers.passcode}">> + }). + +-define(INFO_KEYS, [conninfo, conn_state, clientinfo, session, will_msg]). + +-define(NEG_QOS_CLIENT_ID, <<"NegQoS-Client">>). + +%%-------------------------------------------------------------------- +%% Init the channel +%%-------------------------------------------------------------------- + +%% @doc Init protocol +init(ConnInfo = #{peername := {PeerHost, _}, + sockname := {_, SockPort}}, Option) -> + Peercert = maps:get(peercert, ConnInfo, undefined), + Mountpoint = maps:get(mountpoint, Option, undefined), + Registry = maps:get(registry, Option), + GwId = maps:get(gateway_id, Option), + EnableQoS3 = maps:get(enable_qos3, Option, true), + ClientInfo = set_peercert_infos( + Peercert, + #{ zone => undefined %% XXX: + , protocol => 'mqtt-sn' + , peerhost => PeerHost + , sockport => SockPort + , clientid => undefined + , username => undefined + , is_bridge => false + , is_superuser => false + , mountpoint => Mountpoint + } + ), + + Ctx = maps:get(ctx, Option), + Override = maps:merge(?DEFAULT_OVERRIDE, + maps:get(clientinfo_override, Option, #{}) + ), + #channel{ ctx = Ctx + , registry = Registry + , gateway_id = GwId + , enable_qos3 = EnableQoS3 + , conninfo = ConnInfo + , clientinfo = ClientInfo + , clientinfo_override = Override + , conn_state = idle + , timers = #{} + , takeover = false + , resuming = false + , pendings = [] + }. + +set_peercert_infos(NoSSL, ClientInfo) + when NoSSL =:= nossl; + NoSSL =:= undefined -> + ClientInfo; +set_peercert_infos(Peercert, ClientInfo) -> + {DN, CN} = {esockd_peercert:subject(Peercert), + esockd_peercert:common_name(Peercert)}, + ClientInfo#{dn => DN, cn => CN}. + +-spec info(channel()) -> emqx_types:infos(). +info(Channel) -> + maps:from_list(info(?INFO_KEYS, Channel)). + +-spec info(list(atom())|atom(), channel()) -> term(). +info(Keys, Channel) when is_list(Keys) -> + [{Key, info(Key, Channel)} || Key <- Keys]; + +info(conninfo, #channel{conninfo = ConnInfo}) -> + ConnInfo; +info(conn_state, #channel{conn_state = ConnState}) -> + ConnState; +info(clientinfo, #channel{clientinfo = ClientInfo}) -> + ClientInfo; +info(session, #channel{session = Session}) -> + Session; +info(will_msg, #channel{will_msg = WillMsg}) -> + WillMsg; +info(clientid, #channel{clientinfo = #{clientid := ClientId}}) -> + ClientId; +info(ctx, #channel{ctx = Ctx}) -> + Ctx. + +stats(_Channel) -> + []. + +set_conn_state(ConnState, Channel) -> + Channel#channel{conn_state = ConnState}. + +enrich_conninfo(?SN_CONNECT_MSG(_Flags, _ProtoId, Duration, _ClientId), + Channel = #channel{conninfo = ConnInfo}) -> + NConnInfo = ConnInfo#{ proto_name => <<"MQTT-SN">> + , proto_ver => "1.2" + , clean_start => true + , keepalive => Duration + , expiry_interval => 0 + }, + {ok, Channel#channel{conninfo = NConnInfo}}. + +run_conn_hooks(Packet, Channel = #channel{ctx = Ctx, + conninfo = ConnInfo}) -> + %% XXX: Assign headers of Packet to ConnProps + ConnProps = #{}, + case run_hooks(Ctx, 'client.connect', [ConnInfo], ConnProps) of + Error = {error, _Reason} -> Error; + _NConnProps -> + {ok, Packet, Channel} + end. + +enrich_clientinfo(Packet, + Channel = #channel{ + conninfo = ConnInfo, + clientinfo = ClientInfo0, + clientinfo_override = Override}) -> + ClientInfo = write_clientinfo( + feedvar(Override, Packet, ConnInfo, ClientInfo0), + ClientInfo0 + ), + {ok, NPacket, NClientInfo} = emqx_misc:pipeline( + [ fun maybe_assign_clientid/2 + %% FIXME: CALL After authentication successfully + , fun fix_mountpoint/2 + ], Packet, ClientInfo + ), + {ok, NPacket, Channel#channel{clientinfo = NClientInfo}}. + +feedvar(Override, Packet, ConnInfo, ClientInfo) -> + Envs = #{ 'ConnInfo' => ConnInfo + , 'ClientInfo' => ClientInfo + , 'Packet' => connect_packet_to_map(Packet) + }, + maps:map(fun(_K, V) -> + Tokens = emqx_rule_utils:preproc_tmpl(V), + emqx_rule_utils:proc_tmpl(Tokens, Envs) + end, Override). + +connect_packet_to_map(#mqtt_sn_message{}) -> + %% XXX: Empty now + #{}. + +write_clientinfo(Override, ClientInfo) -> + Override1 = maps:with([username, password, clientid], Override), + maps:merge(ClientInfo, Override1). + +maybe_assign_clientid(_Packet, ClientInfo = #{clientid := ClientId}) + when ClientId == undefined; + ClientId == <<>> -> + {ok, ClientInfo#{clientid => emqx_guid:to_base62(emqx_guid:gen())}}; + +maybe_assign_clientid(_Packet, ClientInfo) -> + {ok, ClientInfo}. + +fix_mountpoint(_Packet, #{mountpoint := undefined}) -> ok; +fix_mountpoint(_Packet, ClientInfo = #{mountpoint := Mountpoint}) -> + %% TODO: Enrich the varibale replacement???? + %% i.e: ${ClientInfo.auth_result.productKey} + Mountpoint1 = emqx_mountpoint:replvar(Mountpoint, ClientInfo), + {ok, ClientInfo#{mountpoint := Mountpoint1}}. + +set_log_meta(_Packet, #channel{clientinfo = #{clientid := ClientId}}) -> + emqx_logger:set_metadata_clientid(ClientId), + ok. + +maybe_require_will_msg(?SN_CONNECT_MSG(Flags, _, _, _), Channel) -> + #mqtt_sn_flags{will = Will} = Flags, + case Will of + true -> + {error, need_will_msg, Channel}; + _ -> + ok + end. + +auth_connect(_Packet, Channel = #channel{ctx = Ctx, + clientinfo = ClientInfo}) -> + #{clientid := ClientId, + username := Username} = ClientInfo, + case emqx_gateway_ctx:authenticate(Ctx, ClientInfo) of + {ok, NClientInfo} -> + {ok, Channel#channel{clientinfo = NClientInfo}}; + {error, Reason} -> + ?LOG(warning, "Client ~s (Username: '~s') login failed for ~0p", + [ClientId, Username, Reason]), + %% FIXME: ReasonCode? + {error, Reason} + end. + +ensure_connected(Channel = #channel{ + ctx = Ctx, + conninfo = ConnInfo, + clientinfo = ClientInfo}) -> + NConnInfo = ConnInfo#{connected_at => erlang:system_time(millisecond)}, + ok = run_hooks(Ctx, 'client.connected', [ClientInfo, NConnInfo]), + Channel#channel{conninfo = NConnInfo, + conn_state = connected + }. + +process_connect(Channel = #channel{ + ctx = Ctx, + conninfo = ConnInfo, + clientinfo = ClientInfo + }) -> + SessFun = fun(_,_) -> + %% TODO: + emqx_session:init(#{zone => undefined}, + #{receive_maximum => 100} + ) + end, + case emqx_gateway_ctx:open_session( + Ctx, + true, + ClientInfo, + ConnInfo, + SessFun + ) of + {ok, #{session := Session, + present := _Present}} -> + handle_out(connack, ?SN_RC_ACCEPTED, + Channel#channel{session = Session}); + {error, Reason} -> + ?LOG(error, "Failed to open session due to ~p", [Reason]), + handle_out(connack, ?SN_RC_FAILED_SESSION, Channel) + end. + +%%-------------------------------------------------------------------- +%% Enrich Keepalive + +ensure_keepalive(Channel = #channel{conninfo = ConnInfo}) -> + ensure_keepalive_timer(maps:get(keepalive, ConnInfo), Channel). + +ensure_keepalive_timer(0, Channel) -> Channel; +ensure_keepalive_timer(Interval, Channel) -> + Keepalive = emqx_keepalive:init(round(timer:seconds(Interval))), + ensure_timer(alive_timer, Channel#channel{keepalive = Keepalive}). + +%%-------------------------------------------------------------------- +%% Handle incoming packet +%%-------------------------------------------------------------------- + +-spec handle_in(emqx_types:packet(), channel()) + -> {ok, channel()} + | {ok, replies(), channel()} + | {shutdown, Reason :: term(), channel()} + | {shutdown, Reason :: term(), replies(), channel()}. + +%% SEARCHGW, GWINFO +handle_in(?SN_SEARCHGW_MSG(_Radius), + Channel = #channel{gateway_id = GwId}) -> + {ok, {outgoing, ?SN_GWINFO_MSG(GwId, <<>>)}, Channel}; + +handle_in(?SN_ADVERTISE_MSG(_GwId, _Radius), Channel) -> + % ingore + {ok, Channel}; + +handle_in(?SN_PUBLISH_MSG(#mqtt_sn_flags{qos = ?QOS_NEG1, + topic_id_type = TopicIdType + }, + TopicId, _MsgId, Data), + Channel = #channel{conn_state = idle, registry = Registry}) -> + %% FIXME: check enable_qos3 ?? + ClientId = undefined, + TopicName = case (TopicIdType =:= ?SN_SHORT_TOPIC) of + false -> + emqx_sn_registry:lookup_topic( + Registry, + ClientId, + TopicId + ); + true -> + <> + end, + _ = case TopicName =/= undefined of + true -> + Msg = emqx_message:make( + ?NEG_QOS_CLIENT_ID, + ?QOS_0, + TopicName, + Data + ), + emqx_broker:publish(Msg); + false -> + ok + end, + ?LOG(debug, "Client id=~p receives a publish with QoS=-1 in idle mode!", + [ClientId]), + {ok, Channel}; + +handle_in(Pkt = #mqtt_sn_message{type = Type}, + Channel = #channel{conn_state = idle}) + when Type /= ?SN_CONNECT -> + ?LOG(warning, "Receive unknown packet ~0p in idle state", [Pkt]), + shutdown(normal, Channel); + +handle_in(?SN_CONNECT_MSG(_Flags, _ProtoId, _Duration, _ClientId), + Channel = #channel{conn_state = connecting}) -> + ?LOG(warning, "Receive connect packet in connecting state"), + {ok, Channel}; + +handle_in(?SN_CONNECT_MSG(_Flags, _ProtoId, _Duration, _ClientId), + Channel = #channel{conn_state = connected}) -> + {error, unexpected_connect, Channel}; + +handle_in(?SN_WILLTOPIC_EMPTY_MSG, + Channel = #channel{conn_state = connecting}) -> + %% 6.3: + %% Note that if a client wants to delete only its Will data at + %% connection setup, it could send a CONNECT message with + %% 'CleanSession=false' and 'Will=true', + %% and sends an empty WILLTOPIC message to the GW when prompted to do so + case auth_connect(fake_packet, Channel#channel{will_msg = undefined}) of + {ok, NChannel} -> + process_connect(ensure_connected(NChannel)); + {error, ReasonCode, NChannel} -> + handle_out(connack, ReasonCode, NChannel) + end; + +handle_in(?SN_WILLTOPIC_MSG(Flags, Topic), + Channel = #channel{conn_state = connecting, + clientinfo = #{clientid := ClientId}}) -> + #mqtt_sn_flags{qos = QoS, retain = Retain} = Flags, + WillMsg0 = emqx_message:make(ClientId, QoS, Topic, <<>>), + WillMsg = emqx_message:set_flag(retain, Retain, WillMsg0), + NChannel = Channel#channel{will_msg = WillMsg}, + {ok, {outgoing, ?SN_WILLMSGREQ_MSG()}, NChannel}; + +handle_in(?SN_WILLMSG_MSG(Payload), + Channel = #channel{conn_state = connecting, + will_msg = WillMsg}) -> + NWillMsg = WillMsg#message{payload = Payload}, + case auth_connect(fake_packet, Channel#channel{will_msg = NWillMsg}) of + {ok, NChannel} -> + process_connect(ensure_connected(NChannel)); + {error, ReasonCode, NChannel} -> + handle_out(connack, ReasonCode, NChannel) + end; + +handle_in(Packet = ?SN_CONNECT_MSG(_Flags, _ProtoId, _Duration, _ClientId), + Channel) -> + case emqx_misc:pipeline( + [ fun enrich_conninfo/2 + , fun run_conn_hooks/2 + , fun enrich_clientinfo/2 + , fun set_log_meta/2 + %% TODO: How to implement the banned in the gateway instance? + %, fun check_banned/2 + , fun maybe_require_will_msg/2 + , fun auth_connect/2 + ], Packet, Channel#channel{conn_state = connecting}) of + {ok, _NPacket, NChannel} -> + process_connect(ensure_connected(NChannel)); + {error, need_will_msg, NChannel} -> + {ok, {outgoing, ?SN_WILLTOPICREQ_MSG()}, NChannel}; + {error, ReasonCode, NChannel} -> + handle_out(connack, ReasonCode, NChannel) + end; + +handle_in(?SN_REGISTER_MSG(_TopicId, MsgId, TopicName), + Channel = #channel{registry = Registry, + clientinfo = #{clientid := ClientId}}) -> + case emqx_sn_registry:register_topic(Registry, ClientId, TopicName) of + TopicId when is_integer(TopicId) -> + ?LOG(debug, "register TopicName=~p, TopicId=~p", + [TopicName, TopicId]), + AckPacket = ?SN_REGACK_MSG(TopicId, MsgId, ?SN_RC_ACCEPTED), + {ok, {outgoing, AckPacket}, Channel}; + {error, too_large} -> + ?LOG(error, "TopicId is full! TopicName=~p", [TopicName]), + AckPacket = ?SN_REGACK_MSG( + ?SN_INVALID_TOPIC_ID, + MsgId, + ?SN_RC_NOT_SUPPORTED + ), + {ok, {outgoing, AckPacket}, Channel}; + {error, wildcard_topic} -> + ?LOG(error, "wildcard topic can not be registered! TopicName=~p", + [TopicName]), + AckPacket = ?SN_REGACK_MSG( + ?SN_INVALID_TOPIC_ID, + MsgId, + ?SN_RC_NOT_SUPPORTED + ), + {ok, {outgoing, AckPacket}, Channel} + end; + +handle_in(PubPkt = ?SN_PUBLISH_MSG(_Flags, TopicId0, MsgId, _Data), Channel) -> + TopicId = case is_integer(TopicId0) of + true -> TopicId0; + _ -> <> = TopicId0, Id + end, + case emqx_misc:pipeline( + [ fun check_qos3_enable/2 + , fun preproc_pub_pkt/2 + , fun convert_topic_id_to_name/2 + , fun check_pub_acl/2 + , fun convert_pub_to_msg/2 + ], PubPkt, Channel) of + {ok, Msg, NChannel} -> + do_publish(TopicId, MsgId, Msg, NChannel); + {error, ReturnCode, NChannel} -> + handle_out(puback, {TopicId, MsgId, ReturnCode}, NChannel) + end; + +handle_in(?SN_PUBACK_MSG(TopicId, MsgId, ReturnCode), + Channel = #channel{ + registry = Registry, + session = Session, + clientinfo = ClientInfo = #{clientid := ClientId}}) -> + case ReturnCode of + ?SN_RC_ACCEPTED -> + case emqx_session:puback(MsgId, Session) of + {ok, Msg, NSession} -> + ok = after_message_acked(ClientInfo, Msg, Channel), + {ok, Channel#channel{session = NSession}}; + {ok, Msg, Publishes, NSession} -> + ok = after_message_acked(ClientInfo, Msg, Channel), + handle_out(publish, + Publishes, + Channel#channel{session = NSession}); + {error, ?RC_PACKET_IDENTIFIER_IN_USE} -> + ?LOG(warning, "The PUBACK MsgId ~w is inuse.", + [MsgId]), + ok = metrics_inc('packets.puback.inuse', Channel), + {ok, Channel}; + {error, ?RC_PACKET_IDENTIFIER_NOT_FOUND} -> + ?LOG(warning, "The PUBACK MsgId ~w is not found.", + [MsgId]), + ok = metrics_inc('packets.puback.missed', Channel), + {ok, Channel} + end; + ?SN_RC_INVALID_TOPIC_ID -> + case emqx_sn_registry:lookup_topic(Registry, ClientId, TopicId) of + undefined -> + {ok, Channel}; + TopicName -> + %% notice that this TopicName maybe normal or predefined, + %% involving the predefined topic name in register to + %% enhance the gateway's robustness even inconsistent + %% with MQTT-SN channels + RegPkt = ?SN_REGISTER_MSG(TopicId, MsgId, TopicName), + {ok, {outgoing, RegPkt}, Channel} + end; + _ -> + ?LOG(error, "CAN NOT handle PUBACK ReturnCode=~p", [ReturnCode]), + {ok, Channel} + end; + +handle_in(?SN_PUBREC_MSG(?SN_PUBREC, MsgId), + Channel = #channel{session = Session, clientinfo = ClientInfo}) -> + case emqx_session:pubrec(MsgId, Session) of + {ok, Msg, NSession} -> + ok = after_message_acked(ClientInfo, Msg, Channel), + NChannel = Channel#channel{session = NSession}, + handle_out(pubrel, MsgId, NChannel); + {error, ?RC_PACKET_IDENTIFIER_IN_USE} -> + ?LOG(warning, "The PUBREC MsgId ~w is inuse.", [MsgId]), + ok = metrics_inc('packets.pubrec.inuse', Channel), + handle_out(pubrel, MsgId, Channel); + {error, ?RC_PACKET_IDENTIFIER_NOT_FOUND} -> + ?LOG(warning, "The PUBREC ~w is not found.", [MsgId]), + ok = metrics_inc('packets.pubrec.missed', Channel), + handle_out(pubrel, MsgId, Channel) + end; + +handle_in(?SN_PUBREC_MSG(?SN_PUBREL, MsgId), + Channel = #channel{session = Session}) -> + case emqx_session:pubrel(MsgId, Session) of + {ok, NSession} -> + NChannel = Channel#channel{session = NSession}, + handle_out(pubcomp, MsgId, NChannel); + {error, ?RC_PACKET_IDENTIFIER_NOT_FOUND} -> + ?LOG(warning, "The PUBREL MsgId ~w is not found.", [MsgId]), + ok = metrics_inc('packets.pubrel.missed', Channel), + handle_out(pubcomp, MsgId, Channel) + end; + +handle_in(?SN_PUBREC_MSG(?SN_PUBCOMP, MsgId), + Channel = #channel{session = Session}) -> + case emqx_session:pubcomp(MsgId, Session) of + {ok, NSession} -> + {ok, Channel#channel{session = NSession}}; + {ok, Publishes, NSession} -> + handle_out(publish, Publishes, + Channel#channel{session = NSession}); + {error, ?RC_PACKET_IDENTIFIER_IN_USE} -> + ok = metrics_inc('packets.pubcomp.inuse', Channel), + {ok, Channel}; + {error, ?RC_PACKET_IDENTIFIER_NOT_FOUND} -> + ?LOG(warning, "The PUBCOMP MsgId ~w is not found", [MsgId]), + ok = metrics_inc('packets.pubcomp.missed', Channel), + {ok, Channel} + end; + +handle_in(SubPkt = ?SN_SUBSCRIBE_MSG(_, MsgId, _), Channel) -> + case emqx_misc:pipeline( + [ fun preproc_subs_type/2 + , fun check_subscribe_acl/2 + , fun do_subscribe/2 + ], SubPkt, Channel) of + {ok, {TopicId, GrantedQoS}, NChannel} -> + SubAck = ?SN_SUBACK_MSG(#mqtt_sn_flags{qos = GrantedQoS}, + TopicId, MsgId, ?SN_RC_ACCEPTED), + {ok, {outgoing, SubAck}, NChannel}; + {error, ReturnCode, NChannel} -> + SubAck = ?SN_SUBACK_MSG(#mqtt_sn_flags{}, + ?SN_INVALID_TOPIC_ID, + MsgId, + ReturnCode), + {ok, {outgoing, SubAck}, NChannel} + end; + +handle_in(UnsubPkt = ?SN_UNSUBSCRIBE_MSG(_, MsgId, TopicIdOrName), + Channel) -> + case emqx_misc:pipeline( + [ fun preproc_unsub_type/2 + , fun do_unsubscribe/2 + ], UnsubPkt, Channel) of + {ok, _TopicName, NChannel} -> + UnsubAck = ?SN_UNSUBACK_MSG(MsgId), + {ok, {outgoing, UnsubAck}, NChannel}; + {error, Reason, NChannel} -> + ?LOG(warning, "Unsubscribe ~p failed: ~0p", + [TopicIdOrName, Reason]), + %% XXX: Even if it fails, the reply is successful. + UnsubAck = ?SN_UNSUBACK_MSG(MsgId), + {ok, {outgoing, UnsubAck}, NChannel} + end; + +handle_in(?SN_PINGREQ_MSG(_ClientId), + Channel = #channel{conn_state = asleep}) -> + {ok, Outgoing, NChannel} = awake(Channel), + NOutgoings = Outgoing ++ [{outgoing, ?SN_PINGRESP_MSG()}], + {ok, NOutgoings, NChannel}; + +handle_in(?SN_PINGREQ_MSG(_ClientId), Channel) -> + {ok, {outgoing, ?SN_PINGRESP_MSG()}, Channel}; + +handle_in(?SN_PINGRESP_MSG(), Channel) -> + {ok, Channel}; + +handle_in(?SN_DISCONNECT_MSG(Duration), Channel) -> + AckPkt = ?SN_DISCONNECT_MSG(undefined), + case Duration of + undefined -> + shutdown(normal, AckPkt, Channel); + _ -> + %% TODO: asleep mechnisa + {ok, {outgoing, AckPkt}, asleep(Duration, Channel)} + end; + +handle_in(?SN_WILLTOPICUPD_MSG(Flags, Topic), + Channel = #channel{will_msg = WillMsg, + clientinfo = #{clientid := ClientId}}) -> + NWillMsg = case Topic of + undefined -> undefined; + _ -> + update_will_topic(WillMsg, Flags, Topic, ClientId) + end, + AckPkt = ?SN_WILLTOPICRESP_MSG(?SN_RC_ACCEPTED), + {ok, {outgoing, AckPkt}, Channel#channel{will_msg = NWillMsg}}; + +handle_in(?SN_WILLMSGUPD_MSG(Payload), + Channel = #channel{will_msg = WillMsg}) -> + AckPkt = ?SN_WILLMSGRESP_MSG(?SN_RC_ACCEPTED), + NWillMsg = update_will_msg(WillMsg, Payload), + {ok, {outgoing, AckPkt}, Channel#channel{will_msg = NWillMsg}}; + +handle_in({frame_error, Reason}, + Channel = #channel{conn_state = _ConnState}) -> + ?LOG(error, "Unexpected frame error: ~p", [Reason]), + shutdown(Reason, Channel). + +after_message_acked(ClientInfo, Msg, + Channel = #channel{ctx = Ctx}) -> + ok = metrics_inc('messages.acked', Channel), + run_hooks_without_metrics(Ctx, + 'message.acked', + [ClientInfo, emqx_message:set_header(puback_props, #{}, Msg)]). + +%%-------------------------------------------------------------------- +%% Handle Publish + +check_qos3_enable(?SN_PUBLISH_MSG(Flags, _, _, _), + #channel{enable_qos3 = EnableQoS3}) -> + #mqtt_sn_flags{qos = QoS} = Flags, + case EnableQoS3 =:= false andalso QoS =:= ?QOS_NEG1 of + true -> + ?LOG(debug, "The enable_qos3 is false, ignore the received " + "publish with QoS=-1 in connected mode!"), + {error, ?SN_RC_NOT_SUPPORTED}; + false -> + ok + end. + +preproc_pub_pkt(?SN_PUBLISH_MSG(Flags, Topic0, _MsgId, Data), + Channel) -> + #mqtt_sn_flags{topic_id_type = TopicIdType} = Flags, + case TopicIdType of + ?SN_NORMAL_TOPIC -> + <> = Topic0, + TopicIndicator = {id, TopicId}, + {ok, {TopicIndicator, Flags, Data}, Channel}; + ?SN_PREDEFINED_TOPIC -> + TopicIndicator = {id, Topic0}, + {ok, {TopicIndicator, Flags, Data}, Channel}; + ?SN_SHORT_TOPIC -> + case emqx_topic:wildcard(Topic0) of + true -> + {error, ?SN_RC_NOT_SUPPORTED}; + false -> + TopicIndicator = {name, Topic0}, + {ok, {TopicIndicator, Flags, Data}, Channel} + end + end. + +convert_topic_id_to_name({{name, TopicName}, Flags, Data}, Channel) -> + {ok, {TopicName, Flags, Data}, Channel}; + +convert_topic_id_to_name({{id, TopicId}, Flags, Data}, + Channel = #channel{ + registry = Registry, + clientinfo = #{clientid := ClientId}} + ) -> + case emqx_sn_registry:lookup_topic(Registry, ClientId, TopicId) of + undefined -> + {error, ?SN_RC_INVALID_TOPIC_ID}; + TopicName -> + {ok, {TopicName, Flags, Data}, Channel} + end. + +check_pub_acl({TopicName, _Flags, _Data}, + #channel{clientinfo = ClientInfo}) -> + case emqx_access_control:authorize(ClientInfo, publish, TopicName) of + allow -> ok; + deny -> {error, ?SN_RC_NOT_AUTHORIZE} + end. + +convert_pub_to_msg({TopicName, Flags, Data}, + Channel = #channel{ + clientinfo = #{clientid := ClientId}}) -> + #mqtt_sn_flags{qos = QoS, dup = Dup, retain = Retain} = Flags, + NewQoS = get_corrected_qos(QoS), + Message = emqx_message:make(ClientId, NewQoS, TopicName, Data), + NMessage = emqx_message:set_flags( + #{dup => Dup, retain => Retain}, + Message + ), + {ok, NMessage, Channel}. + +get_corrected_qos(?QOS_NEG1) -> ?QOS_0; +get_corrected_qos(QoS) -> QoS. + +do_publish(_TopicId, _MsgId, Msg = #message{qos = ?QOS_0}, Channel) -> + _ = emqx_broker:publish(Msg), + {ok, Channel}; + +do_publish(TopicId, MsgId, Msg = #message{qos = ?QOS_1}, Channel) -> + _ = emqx_broker:publish(Msg), + handle_out(puback, {TopicId, MsgId, ?SN_RC_ACCEPTED}, Channel); + +do_publish(TopicId, MsgId, Msg = #message{qos = ?QOS_2}, + Channel = #channel{session = Session}) -> + case emqx_session:publish(MsgId, Msg, Session) of + {ok, _PubRes, NSession} -> + NChannel1 = ensure_timer(await_timer, + Channel#channel{session = NSession} + ), + handle_out(pubrec, MsgId, NChannel1); + {error, ?RC_PACKET_IDENTIFIER_IN_USE} -> + ok = metrics_inc('packets.publish.inuse', Channel), + %% XXX: Use PUBACK to reply a PUBLISH Error Code + handle_out(puback , {TopicId, MsgId, ?SN_RC_NOT_SUPPORTED}, + Channel); + {error, ?RC_RECEIVE_MAXIMUM_EXCEEDED} -> + ?LOG(warning, "Dropped the qos2 packet ~w " + "due to awaiting_rel is full.", [MsgId]), + ok = emqx_metrics:inc('packets.publish.dropped'), + handle_out(puback, {TopicId, MsgId, ?SN_RC_CONGESTION}, Channel) + end. + +%%-------------------------------------------------------------------- +%% Handle Susbscribe + +preproc_subs_type(?SN_SUBSCRIBE_MSG_TYPE(?SN_NORMAL_TOPIC, + TopicName, QoS), + Channel = #channel{ + registry = Registry, + clientinfo = #{clientid := ClientId} + }) -> + %% If the gateway is able accept the subscription, + %% it assigns a topic id to the received topic name + %% and returns it within a SUBACK message + case emqx_sn_registry:register_topic(Registry, ClientId, TopicName) of + {error, too_large} -> + {error, ?SN_EXCEED_LIMITATION}; + {error, wildcard_topic} -> + %% If the client subscribes to a topic name which contains a + %% wildcard character, the returning SUBACK message will contain + %% the topic id value 0x0000. The GW will the use the registration + %% procedure to inform the client about the to-be-used topic id + %% value when it has the first PUBLISH message with a matching + %% topic name to be sent to the client, see also Section 6.10. + {ok, {?SN_INVALID_TOPIC_ID, TopicName, QoS}, Channel}; + TopicId when is_integer(TopicId) -> + {ok, {TopicId, TopicName, QoS}, Channel} + end; + +preproc_subs_type(?SN_SUBSCRIBE_MSG_TYPE(?SN_PREDEFINED_TOPIC, + TopicId, QoS), + Channel = #channel{ + registry = Registry, + clientinfo = #{clientid := ClientId} + }) -> + case emqx_sn_registry:lookup_topic(Registry, + ClientId, TopicId) of + undefined -> + {error, ?SN_RC_INVALID_TOPIC_ID}; + TopicName -> + {ok, {TopicId, TopicName, QoS}, Channel} + end; + +preproc_subs_type(?SN_SUBSCRIBE_MSG_TYPE(?SN_SHORT_TOPIC, + TopicId, QoS), + Channel) -> + TopicName = case is_binary(TopicId) of + true -> TopicId; + false -> <> + end, + %% XXX: ?SN_INVALID_TOPIC_ID ??? + {ok, {?SN_INVALID_TOPIC_ID, TopicName, QoS}, Channel}; + +preproc_subs_type(?SN_SUBSCRIBE_MSG_TYPE(_Reserved, _TopicId, _QoS), + _Channel) -> + {error, ?SN_RC_NOT_SUPPORTED}. + +check_subscribe_acl({_TopicId, TopicName, _QoS}, + Channel = #channel{clientinfo = ClientInfo}) -> + case emqx_access_control:authorize( + ClientInfo, subscribe, TopicName) of + allow -> + {ok, Channel}; + _ -> + {error, ?SN_RC_NOT_AUTHORIZE} + end. + +do_subscribe({TopicId, TopicName, QoS}, + Channel = #channel{ + session = Session, + clientinfo = ClientInfo + = #{mountpoint := Mountpoint}}) -> + NTopicName = emqx_mountpoint:mount(Mountpoint, TopicName), + SubOpts = maps:merge(?DEFAULT_SUBOPTS, #{qos => QoS}), + case emqx_session:subscribe(ClientInfo, NTopicName, SubOpts, Session) of + {ok, NSession} -> + {ok, {TopicId, QoS}, + Channel#channel{session = NSession}}; + {error, ?RC_QUOTA_EXCEEDED} -> + ?LOG(warning, "Cannot subscribe ~s due to ~s.", + [TopicName, emqx_reason_codes:text(?RC_QUOTA_EXCEEDED)]), + {error, ?SN_EXCEED_LIMITATION} + end. + +%%-------------------------------------------------------------------- +%% Handle Unsubscribe + +preproc_unsub_type(?SN_UNSUBSCRIBE_MSG_TYPE(?SN_NORMAL_TOPIC, + TopicName), + Channel) -> + {ok, TopicName, Channel}; +preproc_unsub_type(?SN_UNSUBSCRIBE_MSG_TYPE(?SN_PREDEFINED_TOPIC, + TopicId), + Channel = #channel{ + registry = Registry, + clientinfo = #{clientid := ClientId} + }) -> + case emqx_sn_registry:lookup_topic(Registry, ClientId, + TopicId) of + undefined -> + {error, not_found}; + TopicName -> + {ok, TopicName, Channel} + end; +preproc_unsub_type(?SN_UNSUBSCRIBE_MSG_TYPE(?SN_SHORT_TOPIC, + TopicId), + Channel) -> + TopicName = case is_binary(TopicId) of + true -> TopicId; + false -> <> + end, + {ok, TopicName, Channel}. + +do_unsubscribe(TopicName, + Channel = #channel{ + session = Session, + clientinfo = ClientInfo + = #{mountpoint := Mountpoint}}) -> + SubOpts = #{}, + NTopicName = emqx_mountpoint:mount(Mountpoint, TopicName), + case emqx_session:unsubscribe(ClientInfo, NTopicName, SubOpts, Session) of + {ok, NSession} -> + {ok, Channel#channel{session = NSession}}; + {error, ?RC_NO_SUBSCRIPTION_EXISTED} -> + {ok, Channel} + end. + +%%-------------------------------------------------------------------- +%% Awake & Asleep + +awake(Channel = #channel{session = Session}) -> + {ok, Publishes, Session1} = emqx_session:replay(Session), + {NPublishes, NSession} = case emqx_session:deliver([], Session1) of + {ok, Session2} -> + {Publishes, Session2}; + {ok, More, Session2} -> + {lists:append(Publishes, More), Session2} + end, + {Packets, NChannel} = do_deliver(NPublishes, + Channel#channel{session = NSession}), + Outgoing = [{outgoing, Packets} || length(Packets) > 0], + {ok, Outgoing, NChannel}. + +asleep(Duration, Channel = #channel{conn_state = asleep}) -> + %% 6.14: The client can also modify its sleep duration + %% by sending a DISCONNECT message with a new value of + %% the sleep duration + ensure_timer(asleep_timer, Duration, + cancel_timer(asleep_timer, Channel) + ); + +asleep(Duration, Channel = #channel{conn_state = connected}) -> + ensure_timer(asleep_timer, Duration, + Channel#channel{conn_state = asleep} + ). + +%%-------------------------------------------------------------------- +%% Handle outgoing packet +%%-------------------------------------------------------------------- + +-spec handle_out(atom(), term(), channel()) + -> {ok, channel()} + | {ok, replies(), channel()} + | {shutdown, Reason :: term(), channel()} + | {shutdown, Reason :: term(), replies(), channel()}. + +handle_out(connack, ?SN_RC_ACCEPTED, + Channel = #channel{ctx = Ctx, conninfo = ConnInfo}) -> + _ = run_hooks(Ctx, 'client.connack', + [ConnInfo, ?SN_RC_NAME(?SN_RC_ACCEPTED)], + #{} + ), + return_connack(?SN_CONNACK_MSG(?SN_RC_ACCEPTED), + ensure_keepalive(Channel)); + +handle_out(connack, ReasonCode, + Channel = #channel{ctx = Ctx, conninfo = ConnInfo}) -> + Reason = ?SN_RC_NAME(ReasonCode), + _ = run_hooks(Ctx, 'client.connack', [ConnInfo, Reason], #{}), + AckPacket = ?SN_CONNACK_MSG(ReasonCode), + shutdown(Reason, AckPacket, Channel); + +handle_out(publish, Publishes, Channel) -> + {Packets, NChannel} = do_deliver(Publishes, Channel), + {ok, {outgoing, Packets}, NChannel}; + +handle_out(puback, {TopicId, MsgId, Rc}, Channel) -> + {ok, {outgoing, ?SN_PUBACK_MSG(TopicId, MsgId, Rc)}, Channel}; + +handle_out(pubrec, MsgId, Channel) -> + {ok, {outgoing, ?SN_PUBREC_MSG(?SN_PUBREC, MsgId)}, Channel}; + +handle_out(pubrel, MsgId, Channel) -> + {ok, {outgoing, ?SN_PUBREC_MSG(?SN_PUBREL, MsgId)}, Channel}; + +handle_out(pubcomp, MsgId, Channel) -> + {ok, {outgoing, ?SN_PUBREC_MSG(?SN_PUBCOMP, MsgId)}, Channel}; + +handle_out(disconnect, RC, Channel) -> + DisPkt = ?SN_DISCONNECT_MSG(undefined), + {ok, [{outgoing, DisPkt}, {close, RC}], Channel}. + +%%-------------------------------------------------------------------- +%% Return ConnAck +%%-------------------------------------------------------------------- + +return_connack(AckPacket, Channel) -> + Replies = [{event, connected}, {outgoing, AckPacket}], + case maybe_resume_session(Channel) of + ignore -> {ok, Replies, Channel}; + {ok, Publishes, NSession} -> + NChannel = Channel#channel{session = NSession, + resuming = false, + pendings = [] + }, + {Packets, NChannel1} = do_deliver(Publishes, NChannel), + Outgoing = [{outgoing, Packets} || length(Packets) > 0], + {ok, Replies ++ Outgoing, NChannel1} + end. + +%%-------------------------------------------------------------------- +%% Maybe Resume Session + +maybe_resume_session(#channel{resuming = false}) -> + ignore; +maybe_resume_session(#channel{session = Session, + resuming = true, + pendings = Pendings}) -> + {ok, Publishes, Session1} = emqx_session:replay(Session), + case emqx_session:deliver(Pendings, Session1) of + {ok, Session2} -> + {ok, Publishes, Session2}; + {ok, More, Session2} -> + {ok, lists:append(Publishes, More), Session2} + end. + +%%-------------------------------------------------------------------- +%% Deliver publish: broker -> client +%%-------------------------------------------------------------------- + +%% return list(emqx_types:packet()) +do_deliver({pubrel, MsgId}, Channel) -> + {[?SN_PUBREC_MSG(?SN_PUBREL, MsgId)], Channel}; + +do_deliver({MsgId, Msg}, + Channel = #channel{ + ctx = Ctx, + clientinfo = ClientInfo + = #{mountpoint := Mountpoint}}) -> + metrics_inc('messages.delivered', Channel), + Msg1 = run_hooks_without_metrics( + Ctx, + 'message.delivered', + [ClientInfo], + emqx_message:update_expiry(Msg) + ), + Msg2 = emqx_mountpoint:unmount(Mountpoint, Msg1), + Packet = message_to_packet(MsgId, Msg2, Channel), + {[Packet], Channel}; + +do_deliver([Publish], Channel) -> + do_deliver(Publish, Channel); + +do_deliver(Publishes, Channel) when is_list(Publishes) -> + {Packets, NChannel} = + lists:foldl(fun(Publish, {Acc, Chann}) -> + {Packets, NChann} = do_deliver(Publish, Chann), + {Packets ++ Acc, NChann} + end, {[], Channel}, Publishes), + {lists:reverse(Packets), NChannel}. + +message_to_packet(MsgId, Message, + #channel{registry = Registry, + clientinfo = #{clientid := ClientId}}) -> + QoS = emqx_message:qos(Message), + Topic = emqx_message:topic(Message), + Payload = emqx_message:payload(Message), + NMsgId = case QoS of + ?QOS_0 -> 0; + _ -> MsgId + end, + {TopicIdType, NTopicId} = + case emqx_sn_registry:lookup_topic_id(Registry, ClientId, Topic) of + {predef, PredefTopicId} -> + {?SN_PREDEFINED_TOPIC, PredefTopicId}; + TopicId when is_integer(TopicId) -> + {?SN_NORMAL_TOPIC, TopicId}; + undefined -> + {?SN_SHORT_TOPIC, Topic} + end, + Flags = #mqtt_sn_flags{qos = QoS, topic_id_type = TopicIdType}, + ?SN_PUBLISH_MSG(Flags, NTopicId, NMsgId, Payload). + +%%-------------------------------------------------------------------- +%% Handle call +%%-------------------------------------------------------------------- + +-spec handle_call(Req :: term(), channel()) + -> {reply, Reply :: term(), channel()} + | {shutdown, Reason :: term(), Reply :: term(), channel()} + | {shutdown, Reason :: term(), Reply :: term(), + emqx_types:packet(), channel()}. +handle_call(kick, Channel) -> + NChannel = ensure_disconnected(kicked, Channel), + shutdown_and_reply(kicked, ok, NChannel); + +handle_call(discard, Channel) -> + shutdown_and_reply(discarded, ok, Channel); + +%% XXX: No Session Takeover +%handle_call({takeover, 'begin'}, Channel = #channel{session = Session}) -> +% reply(Session, Channel#channel{takeover = true}); +% +%handle_call({takeover, 'end'}, Channel = #channel{session = Session, +% pendings = Pendings}) -> +% ok = emqx_session:takeover(Session), +% %% TODO: Should not drain deliver here (side effect) +% Delivers = emqx_misc:drain_deliver(), +% AllPendings = lists:append(Delivers, Pendings), +% shutdown_and_reply(takeovered, AllPendings, Channel); + +handle_call(list_acl_cache, Channel) -> + {reply, emqx_acl_cache:list_acl_cache(), Channel}; + +%% XXX: No Quota Now +% handle_call({quota, Policy}, Channel) -> +% Zone = info(zone, Channel), +% Quota = emqx_limiter:init(Zone, Policy), +% reply(ok, Channel#channel{quota = Quota}); + +handle_call(Req, Channel) -> + ?LOG(error, "Unexpected call: ~p", [Req]), + reply(ignored, Channel). + +%%-------------------------------------------------------------------- +%% Handle Cast +%%-------------------------------------------------------------------- + +-spec handle_cast(Req :: term(), channel()) + -> ok | {ok, channel()} | {shutdown, Reason :: term(), channel()}. +handle_cast(_Req, Channel) -> + {ok, Channel}. + +%%-------------------------------------------------------------------- +%% Handle Info +%%-------------------------------------------------------------------- + +-spec handle_info(Info :: term(), channel()) + -> ok | {ok, channel()} | {shutdown, Reason :: term(), channel()}. + +%% XXX: Received from the emqx-management ??? +%handle_info({subscribe, TopicFilters}, Channel ) -> +% {_, NChannel} = lists:foldl( +% fun({TopicFilter, SubOpts}, {_, ChannelAcc}) -> +% do_subscribe(TopicFilter, SubOpts, ChannelAcc) +% end, {[], Channel}, parse_topic_filters(TopicFilters)), +% {ok, NChannel}; +% +%handle_info({unsubscribe, TopicFilters}, Channel) -> +% {_RC, NChannel} = process_unsubscribe(TopicFilters, #{}, Channel), +% {ok, NChannel}; + +handle_info({sock_closed, Reason}, + Channel = #channel{conn_state = idle}) -> + shutdown(Reason, Channel); + +handle_info({sock_closed, Reason}, + Channel = #channel{conn_state = connecting}) -> + shutdown(Reason, Channel); + +handle_info({sock_closed, Reason}, + Channel = #channel{conn_state = connected, + clientinfo = _ClientInfo}) -> + %% XXX: Flapping detect ??? + %% How to get the flapping detect policy ??? + %emqx_zone:enable_flapping_detect(Zone) + % andalso emqx_flapping:detect(ClientInfo), + NChannel = ensure_disconnected(Reason, mabye_publish_will_msg(Channel)), + %% XXX: Session keepper detect here + shutdown(Reason, NChannel); + +handle_info({sock_closed, Reason}, + Channel = #channel{conn_state = disconnected}) -> + ?LOG(error, "Unexpected sock_closed: ~p", [Reason]), + {ok, Channel}; + +handle_info(clean_acl_cache, Channel) -> + ok = emqx_acl_cache:empty_acl_cache(), + {ok, Channel}; + +handle_info(Info, Channel) -> + ?LOG(error, "Unexpected info: ~p", [Info]), + {ok, Channel}. + +%%-------------------------------------------------------------------- +%% Ensure disconnected + +ensure_disconnected(Reason, Channel = #channel{ + ctx = Ctx, + conninfo = ConnInfo, + clientinfo = ClientInfo}) -> + NConnInfo = ConnInfo#{disconnected_at => erlang:system_time(millisecond)}, + ok = run_hooks(Ctx, 'client.disconnected', + [ClientInfo, Reason, NConnInfo]), + Channel#channel{conninfo = NConnInfo, conn_state = disconnected}. + +mabye_publish_will_msg(Channel = #channel{will_msg = undefined}) -> + Channel; +mabye_publish_will_msg(Channel = #channel{will_msg = WillMsg}) -> + ok = publish_will_msg(WillMsg), + Channel#channel{will_msg = undefined}. + +publish_will_msg(Msg) -> + _ = emqx_broker:publish(Msg), + ok. + +%%-------------------------------------------------------------------- +%% Handle Delivers from broker to client +%%-------------------------------------------------------------------- + +-spec handle_deliver(list(emqx_types:deliver()), channel()) + -> {ok, channel()} + | {ok, replies(), channel()}. +handle_deliver(Delivers, Channel = #channel{ + conn_state = ConnState, + session = Session, + clientinfo = #{clientid := ClientId}}) + when ConnState =:= disconnected; + ConnState =:= asleep -> + NSession = emqx_session:enqueue( + ignore_local(maybe_nack(Delivers), ClientId, Session), + Session + ), + {ok, Channel#channel{session = NSession}}; + +handle_deliver(Delivers, Channel = #channel{ + takeover = true, + pendings = Pendings, + session = Session, + clientinfo = #{clientid := ClientId}}) -> + NPendings = lists:append( + Pendings, + ignore_local(maybe_nack(Delivers), ClientId, Session) + ), + {ok, Channel#channel{pendings = NPendings}}; + +handle_deliver(Delivers, Channel = #channel{ + session = Session, + clientinfo = #{clientid := ClientId}}) -> + case emqx_session:deliver( + ignore_local(Delivers, ClientId, Session), + Session + ) of + {ok, Publishes, NSession} -> + NChannel = Channel#channel{session = NSession}, + handle_out(publish, Publishes, + ensure_timer(retry_timer, NChannel)); + {ok, NSession} -> + {ok, Channel#channel{session = NSession}} + end. + +ignore_local(Delivers, Subscriber, Session) -> + Subs = emqx_session:info(subscriptions, Session), + lists:dropwhile(fun({deliver, Topic, #message{from = Publisher}}) -> + case maps:find(Topic, Subs) of + {ok, #{nl := 1}} when Subscriber =:= Publisher -> + ok = emqx_metrics:inc('delivery.dropped'), + ok = emqx_metrics:inc('delivery.dropped.no_local'), + true; + _ -> + false + end + end, Delivers). + +%% Nack delivers from shared subscription +maybe_nack(Delivers) -> + lists:filter(fun not_nacked/1, Delivers). + +not_nacked({deliver, _Topic, Msg}) -> + not (emqx_shared_sub:is_ack_required(Msg) + andalso (ok == emqx_shared_sub:nack_no_connection(Msg))). + +%%-------------------------------------------------------------------- +%% Handle timeout +%%-------------------------------------------------------------------- + +-spec handle_timeout(reference(), Msg :: term(), channel()) + -> {ok, channel()} + | {ok, replies(), channel()} + | {shutdown, Reason :: term(), channel()}. + +handle_timeout(_TRef, {keepalive, _StatVal}, + Channel = #channel{keepalive = undefined}) -> + {ok, Channel}; +handle_timeout(_TRef, {keepalive, _StatVal}, + Channel = #channel{conn_state = ConnState}) + when ConnState =:= disconnected; + ConnState =:= asleep -> + {ok, Channel}; +handle_timeout(_TRef, {keepalive, StatVal}, + Channel = #channel{keepalive = Keepalive}) -> + case emqx_keepalive:check(StatVal, Keepalive) of + {ok, NKeepalive} -> + NChannel = Channel#channel{keepalive = NKeepalive}, + {ok, reset_timer(alive_timer, NChannel)}; + {error, timeout} -> + handle_out(disconnect, ?RC_KEEP_ALIVE_TIMEOUT, Channel) + end. + +%%-------------------------------------------------------------------- +%% Terminate +%%-------------------------------------------------------------------- + +terminate(_Reason, _Channel) -> + ok. + +reply(Reply, Channel) -> + {reply, Reply, Channel}. + +shutdown(Reason, Channel) -> + {shutdown, Reason, Channel}. + +shutdown(Reason, AckFrame, Channel) -> + {shutdown, Reason, AckFrame, Channel}. + +shutdown_and_reply(Reason, Reply, Channel) -> + {shutdown, Reason, Reply, Channel}. + +%%-------------------------------------------------------------------- +%% Will + +update_will_topic(undefined, #mqtt_sn_flags{qos = QoS, retain = Retain}, + Topic, ClientId) -> + WillMsg0 = emqx_message:make(ClientId, QoS, Topic, <<>>), + emqx_message:set_flag(retain, Retain, WillMsg0); +update_will_topic(Will, #mqtt_sn_flags{qos = QoS, retain = Retain}, + Topic, _ClientId) -> + emqx_message:set_flag(retain, Retain, + Will#message{qos = QoS, topic = Topic}). + +update_will_msg(Will, Payload) -> + Will#message{payload = Payload}. + +%%-------------------------------------------------------------------- +%% Timer + +cancel_timer(Name, Channel = #channel{timers = Timers}) -> + case maps:get(Name, Timers, undefined) of + undefined -> + Channel; + TRef -> + emqx_misc:cancel_timer(TRef), + Channel#channel{timers = maps:without([Name], Timers)} + end. + +ensure_timer([Name], Channel) -> + ensure_timer(Name, Channel); +ensure_timer([Name | Rest], Channel) -> + ensure_timer(Rest, ensure_timer(Name, Channel)); + +ensure_timer(Name, Channel = #channel{timers = Timers}) -> + TRef = maps:get(Name, Timers, undefined), + Time = interval(Name, Channel), + case TRef == undefined andalso is_integer(Time) andalso Time > 0 of + true -> ensure_timer(Name, Time, Channel); + false -> Channel %% Timer disabled or exists + end. + +ensure_timer(Name, Time, Channel = #channel{timers = Timers}) -> + Msg = maps:get(Name, ?TIMER_TABLE), + TRef = emqx_misc:start_timer(Time, Msg), + Channel#channel{timers = Timers#{Name => TRef}}. + +reset_timer(Name, Channel) -> + ensure_timer(Name, clean_timer(Name, Channel)). + +clean_timer(Name, Channel = #channel{timers = Timers}) -> + Channel#channel{timers = maps:remove(Name, Timers)}. + +interval(alive_timer, #channel{keepalive = KeepAlive}) -> + emqx_keepalive:info(interval, KeepAlive); +interval(retry_timer, #channel{session = Session}) -> + timer:seconds(emqx_session:info(retry_interval, Session)); +interval(await_timer, #channel{session = Session}) -> + timer:seconds(emqx_session:info(await_rel_timeout, Session)); +interval(expire_timer, #channel{conninfo = ConnInfo}) -> + timer:seconds(maps:get(expiry_interval, ConnInfo)). + +%%-------------------------------------------------------------------- +%% Helper functions +%%-------------------------------------------------------------------- + +run_hooks(Ctx, Name, Args) -> + emqx_gateway_ctx:metrics_inc(Ctx, Name), + emqx_hooks:run(Name, Args). + +run_hooks(Ctx, Name, Args, Acc) -> + emqx_gateway_ctx:metrics_inc(Ctx, Name), + emqx_hooks:run_fold(Name, Args, Acc). + +run_hooks_without_metrics(_Ctx, Name, Args) -> + emqx_hooks:run_fold(Name, Args). + +run_hooks_without_metrics(_Ctx, Name, Args, Acc) -> + emqx_hooks:run_fold(Name, Args, Acc). + +metrics_inc(Name, #channel{ctx = Ctx}) -> + emqx_gateway_ctx:metrics_inc(Ctx, Name). diff --git a/apps/emqx_gateway/src/mqttsn/emqx_sn_conn.erl b/apps/emqx_gateway/src/mqttsn/emqx_sn_conn.erl new file mode 100644 index 000000000..a8780b5e0 --- /dev/null +++ b/apps/emqx_gateway/src/mqttsn/emqx_sn_conn.erl @@ -0,0 +1,802 @@ +%%-------------------------------------------------------------------- +%% Copyright (c) 2020-2021 EMQ Technologies Co., Ltd. All Rights Reserved. +%% +%% Licensed under the Apache License, Version 2.0 (the "License"); +%% you may not use this file except in compliance with the License. +%% You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, software +%% distributed under the License is distributed on an "AS IS" BASIS, +%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +%% See the License for the specific language governing permissions and +%% limitations under the License. +%%-------------------------------------------------------------------- + +%% @doc MQTT-SN Connection process +-module(emqx_sn_conn). + +-include_lib("emqx/include/types.hrl"). +-include_lib("emqx/include/logger.hrl"). + +-logger_header("[SN-Conn]"). + +%% API +-export([ start_link/3 + , stop/1 + ]). + +-export([ info/1 + , stats/1 + ]). + +-export([ call/2 + , call/3 + , cast/2 + ]). + +%% Callback +-export([init/4]). + +%% Sys callbacks +-export([ system_continue/3 + , system_terminate/4 + , system_code_change/4 + , system_get_state/1 + ]). + +%% Internal callback +-export([wakeup_from_hib/2, recvloop/2]). + +-import(emqx_misc, [start_timer/2]). + +-record(state, { + %% TCP/SSL/UDP/DTLS Wrapped Socket + socket :: {esockd_transport, esockd:socket()} | {udp, _, _}, + %% Peername of the connection + peername :: emqx_types:peername(), + %% Sockname of the connection + sockname :: emqx_types:peername(), + %% Sock State + sockstate :: emqx_types:sockstate(), + %% The {active, N} option + active_n :: pos_integer(), + %% Limiter + limiter :: maybe(emqx_limiter:limiter()), + %% Limit Timer + limit_timer :: maybe(reference()), + %% Parse State + parse_state :: emqx_sn_frame:parse_state(), + %% Serialize options + serialize :: emqx_sn_frame:serialize_opts(), + %% Channel State + channel :: emqx_sn_channel:channel(), + %% GC State + gc_state :: maybe(emqx_gc:gc_state()), + %% Stats Timer + stats_timer :: disabled | maybe(reference()), + %% Idle Timeout + idle_timeout :: integer(), + %% Idle Timer + idle_timer :: maybe(reference()) + }). + +-type(state() :: #state{}). + +-define(INFO_KEYS, [socktype, peername, sockname, sockstate, active_n]). +-define(CONN_STATS, [recv_pkt, recv_msg, send_pkt, send_msg]). +-define(SOCK_STATS, [recv_oct, recv_cnt, send_oct, send_cnt, send_pend]). + +-define(ENABLED(X), (X =/= undefined)). + +-define(DEFAULT_GC_OPTS, #{count => 1000, bytes => 1024*1024}). +-define(DEFAULT_IDLE_TIMEOUT, 30000). +-define(DEFAULT_OOM_POLICY, #{max_heap_size => 4194304, + message_queue_len => 32000}). + +-dialyzer({nowarn_function, + [ system_terminate/4 + , handle_call/3 + , handle_msg/2 + , shutdown/3 + , stop/3 + ]}). + +%% udp +start_link(Socket = {udp, _SockPid, _Sock}, Peername, Options) -> + Args = [self(), Socket, Peername, Options], + {ok, proc_lib:spawn_link(?MODULE, init, Args)}; + +%% tcp/ssl/dtls +start_link(esockd_transport, Sock, Options) -> + Socket = {esockd_transport, Sock}, + case esockd_transport:peername(Sock) of + {ok, Peername} -> + Args = [self(), Socket, Peername, Options], + {ok, proc_lib:spawn_link(?MODULE, init, Args)}; + R = {error, _} -> R + end. + +%%-------------------------------------------------------------------- +%% API +%%-------------------------------------------------------------------- + +%% @doc Get infos of the connection/channel. +-spec(info(pid()|state()) -> emqx_types:infos()). +info(CPid) when is_pid(CPid) -> + call(CPid, info); +info(State = #state{channel = Channel}) -> + ChanInfo = emqx_sn_channel:info(Channel), + SockInfo = maps:from_list( + info(?INFO_KEYS, State)), + ChanInfo#{sockinfo => SockInfo}. + +info(Keys, State) when is_list(Keys) -> + [{Key, info(Key, State)} || Key <- Keys]; +info(socktype, #state{socket = Socket}) -> + esockd_type(Socket); +info(peername, #state{peername = Peername}) -> + Peername; +info(sockname, #state{sockname = Sockname}) -> + Sockname; +info(sockstate, #state{sockstate = SockSt}) -> + SockSt; +info(active_n, #state{active_n = ActiveN}) -> + ActiveN. + +-spec(stats(pid()|state()) -> emqx_types:stats()). +stats(CPid) when is_pid(CPid) -> + call(CPid, stats); +stats(#state{socket = Socket, + channel = Channel}) -> + SockStats = case esockd_getstat(Socket, ?SOCK_STATS) of + {ok, Ss} -> Ss; + {error, _} -> [] + end, + ConnStats = emqx_pd:get_counters(?CONN_STATS), + ChanStats = emqx_sn_channel:stats(Channel), + ProcStats = emqx_misc:proc_stats(), + lists:append([SockStats, ConnStats, ChanStats, ProcStats]). + +call(Pid, Req) -> + call(Pid, Req, infinity). + +call(Pid, Req, Timeout) -> + gen_server:call(Pid, Req, Timeout). + +cast(Pid, Req) -> + gen_server:cast(Pid, Req). + +stop(Pid) -> + gen_server:stop(Pid). + +%%-------------------------------------------------------------------- +%% Wrapped funcs +%%-------------------------------------------------------------------- + +esockd_wait(Socket = {udp, _SockPid, _Sock}) -> + {ok, Socket}; +esockd_wait({esockd_transport, Sock}) -> + case esockd_transport:wait(Sock) of + {ok, NSock} -> {ok, {esockd_transport, NSock}}; + R = {error, _} -> R + end. + +esockd_close({udp, _SockPid, _Sock}) -> + %% nothing to do for udp socket + %%gen_udp:close(Sock); + ok; +esockd_close({esockd_transport, Sock}) -> + esockd_transport:fast_close(Sock). + +esockd_ensure_ok_or_exit(peercert, {udp, _SockPid, _Sock}) -> + nossl; +esockd_ensure_ok_or_exit(Fun, {udp, _SockPid, Sock}) -> + esockd_transport:ensure_ok_or_exit(Fun, [Sock]); +esockd_ensure_ok_or_exit(Fun, {esockd_transport, Socket}) -> + esockd_transport:ensure_ok_or_exit(Fun, [Socket]). + +esockd_type({udp, _, _}) -> + udp; +esockd_type({esockd_transport, Socket}) -> + esockd_transport:type(Socket). + +esockd_setopts({udp, _, _}, _) -> + ok; +esockd_setopts({esockd_transport, Socket}, Opts) -> + %% FIXME: DTLS works?? + esockd_transport:setopts(Socket, Opts). + +esockd_getstat({udp, _SockPid, Sock}, Stats) -> + inet:getstat(Sock, Stats); +esockd_getstat({esockd_transport, Sock}, Stats) -> + esockd_transport:getstat(Sock, Stats). + +esockd_send(Data, #state{socket = {udp, _SockPid, Sock}, + peername = {Ip, Port}}) -> + gen_udp:send(Sock, Ip, Port, Data); +esockd_send(Data, #state{socket = {esockd_transport, Sock}}) -> + esockd_transport:async_send(Sock, Data). + +%%-------------------------------------------------------------------- +%% callbacks +%%-------------------------------------------------------------------- + +init(Parent, WrappedSock, Peername, Options) -> + case esockd_wait(WrappedSock) of + {ok, NWrappedSock} -> + run_loop(Parent, init_state(NWrappedSock, Peername, Options)); + {error, Reason} -> + ok = esockd_close(WrappedSock), + exit_on_sock_error(Reason) + end. + +init_state(WrappedSock, Peername, Options) -> + {ok, Sockname} = esockd_ensure_ok_or_exit(sockname, WrappedSock), + Peercert = esockd_ensure_ok_or_exit(peercert, WrappedSock), + ConnInfo = #{socktype => esockd_type(WrappedSock), + peername => Peername, + sockname => Sockname, + peercert => Peercert, + conn_mod => ?MODULE + }, + ActiveN = emqx_gateway_utils:active_n(Options), + %% FIXME: + %%Limiter = emqx_limiter:init(Options), + Limiter = undefined, + FrameOpts = emqx_gateway_utils:frame_options(Options), + ParseState = emqx_sn_frame:initial_parse_state(FrameOpts), + Serialize = emqx_sn_frame:serialize_opts(), + Channel = emqx_sn_channel:init(ConnInfo, Options), + GcState = emqx_gateway_utils:init_gc_state(Options), + StatsTimer = emqx_gateway_utils:stats_timer(Options), + IdleTimeout = emqx_gateway_utils:idle_timeout(Options), + IdleTimer = start_timer(IdleTimeout, idle_timeout), + #state{socket = WrappedSock, + peername = Peername, + sockname = Sockname, + sockstate = idle, + active_n = ActiveN, + limiter = Limiter, + parse_state = ParseState, + serialize = Serialize, + channel = Channel, + gc_state = GcState, + stats_timer = StatsTimer, + idle_timeout = IdleTimeout, + idle_timer = IdleTimer + }. + +run_loop(Parent, State = #state{socket = Socket, + peername = Peername}) -> + emqx_logger:set_metadata_peername(esockd:format(Peername)), + _ = emqx_misc:tune_heap_size(?DEFAULT_OOM_POLICY), + case activate_socket(State) of + {ok, NState} -> + hibernate(Parent, NState); + {error, Reason} -> + ok = esockd_close(Socket), + exit_on_sock_error(Reason) + end. + +-spec exit_on_sock_error(atom()) -> no_return(). +exit_on_sock_error(Reason) when Reason =:= einval; + Reason =:= enotconn; + Reason =:= closed -> + erlang:exit(normal); +exit_on_sock_error(timeout) -> + erlang:exit({shutdown, ssl_upgrade_timeout}); +exit_on_sock_error(Reason) -> + erlang:exit({shutdown, Reason}). + +%%-------------------------------------------------------------------- +%% Recv Loop + +recvloop(Parent, State = #state{idle_timeout = IdleTimeout}) -> + receive + Msg -> + handle_recv(Msg, Parent, State) + after + IdleTimeout + 100 -> + hibernate(Parent, cancel_stats_timer(State)) + end. + +handle_recv({system, From, Request}, Parent, State) -> + sys:handle_system_msg(Request, From, Parent, ?MODULE, [], State); +handle_recv({'EXIT', Parent, Reason}, Parent, State) -> + %% FIXME: it's not trapping exit, should never receive an EXIT + terminate(Reason, State); +handle_recv(Msg, Parent, State = #state{idle_timeout = IdleTimeout}) -> + case process_msg([Msg], ensure_stats_timer(IdleTimeout, State)) of + {ok, NewState} -> + ?MODULE:recvloop(Parent, NewState); + {stop, Reason, NewSate} -> + terminate(Reason, NewSate) + end. + +hibernate(Parent, State) -> + proc_lib:hibernate(?MODULE, wakeup_from_hib, [Parent, State]). + +%% Maybe do something here later. +wakeup_from_hib(Parent, State) -> + ?MODULE:recvloop(Parent, State). + +%%-------------------------------------------------------------------- +%% Ensure/cancel stats timer + +ensure_stats_timer(Timeout, State = #state{stats_timer = undefined}) -> + State#state{stats_timer = start_timer(Timeout, emit_stats)}; +ensure_stats_timer(_Timeout, State) -> State. + +cancel_stats_timer(State = #state{stats_timer = TRef}) + when is_reference(TRef) -> + ok = emqx_misc:cancel_timer(TRef), + State#state{stats_timer = undefined}; +cancel_stats_timer(State) -> State. + +%%-------------------------------------------------------------------- +%% Process next Msg + +process_msg([], State) -> + {ok, State}; +process_msg([Msg|More], State) -> + try + case handle_msg(Msg, State) of + ok -> + process_msg(More, State); + {ok, NState} -> + process_msg(More, NState); + {ok, Msgs, NState} -> + process_msg(append_msg(More, Msgs), NState); + {stop, Reason, NState} -> + {stop, Reason, NState} + end + catch + exit : normal -> + {stop, normal, State}; + exit : shutdown -> + {stop, shutdown, State}; + exit : {shutdown, _} = Shutdown -> + {stop, Shutdown, State}; + Exception : Context : Stack -> + {stop, #{exception => Exception, + context => Context, + stacktrace => Stack}, State} + end. + +append_msg([], Msgs) when is_list(Msgs) -> + Msgs; +append_msg([], Msg) -> [Msg]; +append_msg(Q, Msgs) when is_list(Msgs) -> + lists:append(Q, Msgs); +append_msg(Q, Msg) -> + lists:append(Q, [Msg]). + +%%-------------------------------------------------------------------- +%% Handle a Msg + +handle_msg({'$gen_call', From, Req}, State) -> + case handle_call(From, Req, State) of + {reply, Reply, NState} -> + gen_server:reply(From, Reply), + {ok, NState}; + {reply, Reply, Msgs, NState} -> + gen_server:reply(From, Reply), + {ok, next_msgs(Msgs), NState}; + {stop, Reason, Reply, NState} -> + gen_server:reply(From, Reply), + stop(Reason, NState) + end; + +handle_msg({'$gen_cast', Req}, State) -> + with_channel(handle_cast, [Req], State); + +handle_msg({datagram, _SockPid, Data}, State) -> + parse_incoming(Data, State); + +handle_msg({Inet, _Sock, Data}, State) + when Inet == tcp; + Inet == ssl -> + parse_incoming(Data, State); + +handle_msg({incoming, Packet}, + State = #state{idle_timer = IdleTimer}) -> + IdleTimer /= undefined andalso + emqx_misc:cancel_timer(IdleTimer), + NState = State#state{idle_timer = undefined}, + handle_incoming(Packet, NState); + +handle_msg({outgoing, Data}, State) -> + handle_outgoing(Data, State); + +handle_msg({Error, _Sock, Reason}, State) + when Error == tcp_error; Error == ssl_error -> + handle_info({sock_error, Reason}, State); + +handle_msg({Closed, _Sock}, State) + when Closed == tcp_closed; Closed == ssl_closed -> + handle_info({sock_closed, Closed}, close_socket(State)); + +%% TODO: udp_passive??? +handle_msg({Passive, _Sock}, State) + when Passive == tcp_passive; Passive == ssl_passive -> + %% In Stats + Bytes = emqx_pd:reset_counter(incoming_bytes), + Pubs = emqx_pd:reset_counter(incoming_pkt), + InStats = #{cnt => Pubs, oct => Bytes}, + %% Ensure Rate Limit + NState = ensure_rate_limit(InStats, State), + %% Run GC and Check OOM + NState1 = check_oom(run_gc(InStats, NState)), + handle_info(activate_socket, NState1); + +handle_msg(Deliver = {deliver, _Topic, _Msg}, + State = #state{active_n = ActiveN}) -> + Delivers = [Deliver|emqx_misc:drain_deliver(ActiveN)], + with_channel(handle_deliver, [Delivers], State); + +%% Something sent +%% TODO: Who will deliver this message? +handle_msg({inet_reply, _Sock, ok}, State = #state{active_n = ActiveN}) -> + case emqx_pd:get_counter(outgoing_pkt) > ActiveN of + true -> + Pubs = emqx_pd:reset_counter(outgoing_pkt), + Bytes = emqx_pd:reset_counter(outgoing_bytes), + OutStats = #{cnt => Pubs, oct => Bytes}, + {ok, check_oom(run_gc(OutStats, State))}; + false -> ok + end; + +handle_msg({inet_reply, _Sock, {error, Reason}}, State) -> + handle_info({sock_error, Reason}, State); + +handle_msg({close, Reason}, State) -> + ?LOG(debug, "Force to close the socket due to ~p", [Reason]), + handle_info({sock_closed, Reason}, close_socket(State)); + +handle_msg({event, connected}, State = #state{channel = Channel}) -> + Ctx = emqx_sn_channel:info(ctx, Channel), + ClientId = emqx_sn_channel:info(clientid, Channel), + emqx_gateway_ctx:insert_channel_info( + Ctx, + ClientId, + info(State), + stats(State) + ); + +handle_msg({event, disconnected}, State = #state{channel = Channel}) -> + Ctx = emqx_sn_channel:info(ctx, Channel), + ClientId = emqx_sn_channel:info(clientid, Channel), + emqx_gateway_ctx:set_chan_info(Ctx, ClientId, info(State)), + emqx_gateway_ctx:connection_closed(Ctx, ClientId), + {ok, State}; + +handle_msg({event, _Other}, State = #state{channel = Channel}) -> + Ctx = emqx_sn_channel:info(ctx, Channel), + ClientId = emqx_sn_channel:info(clientid, Channel), + emqx_gateway_ctx:set_chan_info(Ctx, ClientId, info(State)), + emqx_gateway_ctx:set_chan_stats(Ctx, ClientId, stats(State)), + {ok, State}; + +handle_msg({timeout, TRef, TMsg}, State) -> + handle_timeout(TRef, TMsg, State); + +handle_msg(Shutdown = {shutdown, _Reason}, State) -> + stop(Shutdown, State); + +handle_msg(Msg, State) -> + handle_info(Msg, State). + +%%-------------------------------------------------------------------- +%% Terminate + +-spec terminate(atom(), state()) -> no_return(). +terminate(Reason, State = #state{channel = Channel}) -> + ?LOG(debug, "Terminated due to ~p", [Reason]), + _ = emqx_sn_channel:terminate(Reason, Channel), + _ = close_socket(State), + exit(Reason). + +%%-------------------------------------------------------------------- +%% Sys callbacks + +system_continue(Parent, _Debug, State) -> + recvloop(Parent, State). + +system_terminate(Reason, _Parent, _Debug, State) -> + terminate(Reason, State). + +system_code_change(State, _Mod, _OldVsn, _Extra) -> + {ok, State}. + +system_get_state(State) -> {ok, State}. + +%%-------------------------------------------------------------------- +%% Handle call + +handle_call(_From, info, State) -> + {reply, info(State), State}; + +handle_call(_From, stats, State) -> + {reply, stats(State), State}; + +handle_call(_From, Req, State = #state{channel = Channel}) -> + case emqx_sn_channel:handle_call(Req, Channel) of + {reply, Reply, NChannel} -> + {reply, Reply, State#state{channel = NChannel}}; + {reply, Reply, Replies, NChannel} -> + {reply, Reply, Replies, State#state{channel = NChannel}}; + {shutdown, Reason, Reply, NChannel} -> + shutdown(Reason, Reply, State#state{channel = NChannel}) + end. + +%%-------------------------------------------------------------------- +%% Handle timeout + +handle_timeout(_TRef, idle_timeout, State) -> + shutdown(idle_timeout, State); + +handle_timeout(_TRef, limit_timeout, State) -> + NState = State#state{sockstate = idle, + limit_timer = undefined + }, + handle_info(activate_socket, NState); +handle_timeout(TRef, keepalive, State = #state{socket = Socket, + channel = Channel})-> + case emqx_sn_channel:info(conn_state, Channel) of + disconnected -> {ok, State}; + _ -> + case esockd_getstat(Socket, [recv_oct]) of + {ok, [{recv_oct, RecvOct}]} -> + handle_timeout(TRef, {keepalive, RecvOct}, State); + {error, Reason} -> + handle_info({sock_error, Reason}, State) + end + end; +handle_timeout(_TRef, emit_stats, State = + #state{channel = Channel}) -> + Ctx = emqx_sn_channel:info(ctx, Channel), + ClientId = emqx_sn_channel:info(clientid, Channel), + emqx_gateway_ctx:set_chan_stats(Ctx, ClientId, stats(State)), + {ok, State#state{stats_timer = undefined}}; + +handle_timeout(TRef, Msg, State) -> + with_channel(handle_timeout, [TRef, Msg], State). + +%%-------------------------------------------------------------------- +%% Parse incoming data + +parse_incoming(Data, State = #state{channel = Channel}) -> + ?LOG(debug, "RECV ~0p", [Data]), + Oct = iolist_size(Data), + inc_counter(incoming_bytes, Oct), + Ctx = emqx_sn_channel:info(ctx, Channel), + ok = emqx_gateway_ctx:metrics_inc(Ctx, 'bytes.received', Oct), + {Packets, NState} = parse_incoming(Data, [], State), + {ok, next_incoming_msgs(Packets), NState}. + +parse_incoming(<<>>, Packets, State) -> + {Packets, State}; + +parse_incoming(Data, Packets, State = #state{parse_state = ParseState}) -> + try emqx_sn_frame:parse(Data, ParseState) of + {more, NParseState} -> + {Packets, State#state{parse_state = NParseState}}; + {ok, Packet, Rest, NParseState} -> + NState = State#state{parse_state = NParseState}, + parse_incoming(Rest, [Packet|Packets], NState) + catch + error:Reason:Stk -> + ?LOG(error, "~nParse failed for ~0p~n~0p~nFrame data:~0p", + [Reason, Stk, Data]), + {[{frame_error, Reason}|Packets], State} + end. + +next_incoming_msgs([Packet]) -> + {incoming, Packet}; +next_incoming_msgs(Packets) -> + [{incoming, Packet} || Packet <- lists:reverse(Packets)]. + +%%-------------------------------------------------------------------- +%% Handle incoming packet + +handle_incoming(Packet, State) -> + ok = inc_incoming_stats(Packet), + ?LOG(debug, "RECV ~s", [emqx_sn_frame:format(Packet)]), + with_channel(handle_in, [Packet], State). + +%%-------------------------------------------------------------------- +%% With Channel + +with_channel(Fun, Args, State = #state{channel = Channel}) -> + case erlang:apply(emqx_sn_channel, Fun, Args ++ [Channel]) of + ok -> {ok, State}; + {ok, NChannel} -> + {ok, State#state{channel = NChannel}}; + {ok, Replies, NChannel} -> + {ok, next_msgs(Replies), State#state{channel = NChannel}}; + {shutdown, Reason, NChannel} -> + shutdown(Reason, State#state{channel = NChannel}); + {shutdown, Reason, Packet, NChannel} -> + NState = State#state{channel = NChannel}, + ok = handle_outgoing(Packet, NState), + shutdown(Reason, NState) + end. + +%%-------------------------------------------------------------------- +%% Handle outgoing packets + +handle_outgoing(Packets, State) when is_list(Packets) -> + send(lists:map(serialize_and_inc_stats_fun(State), Packets), State); + +handle_outgoing(Packet, State) -> + send((serialize_and_inc_stats_fun(State))(Packet), State). + +serialize_and_inc_stats_fun(#state{serialize = Serialize, channel = Channel}) -> + Ctx = emqx_sn_channel:info(ctx, Channel), + fun(Packet) -> + case emqx_sn_frame:serialize_pkt(Packet, Serialize) of + <<>> -> ?LOG(warning, "~s is discarded due to the frame is too large!", + [emqx_sn_frame:format(Packet)]), + ok = emqx_gateway_ctx:metrics_inc(Ctx, 'delivery.dropped.too_large'), + ok = emqx_gateway_ctx:metrics_inc(Ctx, 'delivery.dropped'), + <<>>; + Data -> ?LOG(debug, "SEND ~s", [emqx_sn_frame:format(Packet)]), + ok = inc_outgoing_stats(Packet), + Data + end + end. + +%%-------------------------------------------------------------------- +%% Send data + +-spec(send(iodata(), state()) -> ok). +send(IoData, State = #state{socket = Socket, channel = Channel}) -> + Ctx = emqx_sn_channel:info(ctx, Channel), + Oct = iolist_size(IoData), + ok = emqx_gateway_ctx:metrics_inc(Ctx, 'bytes.sent', Oct), + inc_counter(outgoing_bytes, Oct), + case esockd_send(IoData, State) of + ok -> ok; + Error = {error, _Reason} -> + %% Send an inet_reply to postpone handling the error + self() ! {inet_reply, Socket, Error}, + ok + end. + +%%-------------------------------------------------------------------- +%% Handle Info + +handle_info(activate_socket, State = #state{sockstate = OldSst}) -> + case activate_socket(State) of + {ok, NState = #state{sockstate = NewSst}} -> + if OldSst =/= NewSst -> + {ok, {event, NewSst}, NState}; + true -> {ok, NState} + end; + {error, Reason} -> + handle_info({sock_error, Reason}, State) + end; + +handle_info({sock_error, Reason}, State) -> + ?LOG(debug, "Socket error: ~p", [Reason]), + handle_info({sock_closed, Reason}, close_socket(State)); + +handle_info(Info, State) -> + with_channel(handle_info, [Info], State). + +%%-------------------------------------------------------------------- +%% Ensure rate limit + +ensure_rate_limit(Stats, State = #state{limiter = Limiter}) -> + case ?ENABLED(Limiter) andalso emqx_limiter:check(Stats, Limiter) of + false -> State; + {ok, Limiter1} -> + State#state{limiter = Limiter1}; + {pause, Time, Limiter1} -> + ?LOG(warning, "Pause ~pms due to rate limit", [Time]), + TRef = start_timer(Time, limit_timeout), + State#state{sockstate = blocked, + limiter = Limiter1, + limit_timer = TRef + } + end. + +%%-------------------------------------------------------------------- +%% Run GC and Check OOM + +run_gc(Stats, State = #state{gc_state = GcSt}) -> + case ?ENABLED(GcSt) andalso emqx_gc:run(Stats, GcSt) of + false -> State; + {_IsGC, GcSt1} -> + State#state{gc_state = GcSt1} + end. + +check_oom(State) -> + OomPolicy = ?DEFAULT_OOM_POLICY, + case ?ENABLED(OomPolicy) andalso emqx_misc:check_oom(OomPolicy) of + Shutdown = {shutdown, _Reason} -> + erlang:send(self(), Shutdown); + _Other -> ok + end, + State. + +%%-------------------------------------------------------------------- +%% Activate Socket + +activate_socket(State = #state{sockstate = closed}) -> + {ok, State}; +activate_socket(State = #state{sockstate = blocked}) -> + {ok, State}; +activate_socket(State = #state{socket = Socket, + active_n = N}) -> + %% FIXME: Works on dtls/udp ??? + %% How to hanlde buffer? + case esockd_setopts(Socket, [{active, N}]) of + ok -> {ok, State#state{sockstate = running}}; + Error -> Error + end. + +%%-------------------------------------------------------------------- +%% Close Socket + +close_socket(State = #state{sockstate = closed}) -> State; +close_socket(State = #state{socket = Socket}) -> + ok = esockd_close(Socket), + State#state{sockstate = closed}. + +%%-------------------------------------------------------------------- +%% Inc incoming/outgoing stats + +%% XXX: How to stats? +inc_incoming_stats(_Packet) -> + inc_counter(recv_pkt, 1), + ok. + %case Type =:= ?CMD_SEND of + % true -> + % inc_counter(recv_msg, 1), + % inc_counter(incoming_pubs, 1); + % false -> + % ok + %end, + %emqx_metrics:inc_recv(Packet). + +inc_outgoing_stats(_Packet) -> + inc_counter(send_pkt, 1), + ok. + %case Type =:= ?CMD_MESSAGE of + % true -> + % inc_counter(send_msg, 1), + % inc_counter(outgoing_pubs, 1); + % false -> + % ok + %end, + %emqx_metrics:inc_sent(Packet). + +%%-------------------------------------------------------------------- +%% Helper functions + +-compile({inline, [next_msgs/1]}). +next_msgs(Event) when is_tuple(Event) -> + Event; +next_msgs(More) when is_list(More) -> + More. + +-compile({inline, [shutdown/2, shutdown/3]}). +shutdown(Reason, State) -> + stop({shutdown, Reason}, State). + +shutdown(Reason, Reply, State) -> + stop({shutdown, Reason}, Reply, State). + +-compile({inline, [stop/2, stop/3]}). +stop(Reason, State) -> + {stop, Reason, State}. + +stop(Reason, Reply, State) -> + {stop, Reason, Reply, State}. + +inc_counter(Name, Value) -> + _ = emqx_pd:inc_counter(Name, Value), + ok. diff --git a/apps/emqx_gateway/src/mqttsn/emqx_sn_frame.erl b/apps/emqx_gateway/src/mqttsn/emqx_sn_frame.erl index 301247fbc..c9b9b137d 100644 --- a/apps/emqx_gateway/src/mqttsn/emqx_sn_frame.erl +++ b/apps/emqx_gateway/src/mqttsn/emqx_sn_frame.erl @@ -19,8 +19,10 @@ -include("src/mqttsn/include/emqx_sn.hrl"). --export([ parse/1 - , serialize/1 +-export([ initial_parse_state/1 + , serialize_opts/0 + , parse/2 + , serialize_pkt/2 , message_type/1 , format/1 ]). @@ -29,17 +31,33 @@ -define(byte, 8/big-integer). -define(short, 16/big-integer). +-type parse_state() :: #{}. +-type serialize_opts() :: #{}. + +-export_type([ parse_state/0 + , serialize_opts/0 + ]). + +%%-------------------------------------------------------------------- +%% Initial + +initial_parse_state(_) -> + #{}. + +serialize_opts() -> + #{}. + %%-------------------------------------------------------------------- %% Parse MQTT-SN Message %%-------------------------------------------------------------------- -parse(<<16#01:?byte, Len:?short, Type:?byte, Var/binary>>) -> - parse(Type, Len - 4, Var); -parse(<>) -> - parse(Type, Len - 2, Var). +parse(<<16#01:?byte, Len:?short, Type:?byte, Var/binary>>, _State) -> + {ok, parse(Type, Len - 4, Var), <<>>, _State}; +parse(<>, _State) -> + {ok, parse(Type, Len - 2, Var), <<>>, _State}. parse(Type, Len, Var) when Len =:= size(Var) -> - {ok, #mqtt_sn_message{type = Type, variable = parse_var(Type, Var)}}; + #mqtt_sn_message{type = Type, variable = parse_var(Type, Var)}; parse(_Type, _Len, _Var) -> error(malformed_message_len). @@ -127,70 +145,70 @@ parse_topic(2#11, Topic) -> Topic. %% Serialize MQTT-SN Message %%-------------------------------------------------------------------- -serialize(#mqtt_sn_message{type = Type, variable = Var}) -> - VarBin = serialize(Type, Var), VarLen = size(VarBin), +serialize_pkt(#mqtt_sn_message{type = Type, variable = Var}, Opts) -> + VarBin = serialize(Type, Var, Opts), VarLen = size(VarBin), if VarLen < 254 -> <<(VarLen + 2), Type, VarBin/binary>>; true -> <<16#01, (VarLen + 4):?short, Type, VarBin/binary>> end. -serialize(?SN_ADVERTISE, {GwId, Duration}) -> +serialize(?SN_ADVERTISE, {GwId, Duration}, _Opts) -> <>; -serialize(?SN_SEARCHGW, Radius) -> +serialize(?SN_SEARCHGW, Radius, _Opts) -> <>; -serialize(?SN_GWINFO, {GwId, GwAdd}) -> +serialize(?SN_GWINFO, {GwId, GwAdd}, _Opts) -> <>; -serialize(?SN_CONNECT, {Flags, ProtocolId, Duration, ClientId}) -> +serialize(?SN_CONNECT, {Flags, ProtocolId, Duration, ClientId}, _Opts) -> <<(serialize_flags(Flags))/binary, ProtocolId, Duration:?short, ClientId/binary>>; -serialize(?SN_CONNACK, ReturnCode) -> +serialize(?SN_CONNACK, ReturnCode, _Opts) -> <>; -serialize(?SN_WILLTOPICREQ, _) -> +serialize(?SN_WILLTOPICREQ, _, _Opts) -> <<>>; -serialize(?SN_WILLTOPIC, undefined) -> +serialize(?SN_WILLTOPIC, undefined, _Opts) -> <<>>; -serialize(?SN_WILLTOPIC, {Flags, Topic}) -> +serialize(?SN_WILLTOPIC, {Flags, Topic}, _Opts) -> %% The WillTopic must a short topic name <<(serialize_flags(Flags))/binary, Topic/binary>>; -serialize(?SN_WILLMSGREQ, _) -> +serialize(?SN_WILLMSGREQ, _, _Opts) -> <<>>; -serialize(?SN_WILLMSG, WillMsg) -> +serialize(?SN_WILLMSG, WillMsg, _Opts) -> WillMsg; -serialize(?SN_REGISTER, {TopicId, MsgId, TopicName}) -> +serialize(?SN_REGISTER, {TopicId, MsgId, TopicName}, _Opts) -> <>; -serialize(?SN_REGACK, {TopicId, MsgId, ReturnCode}) -> +serialize(?SN_REGACK, {TopicId, MsgId, ReturnCode}, _Opts) -> <>; -serialize(?SN_PUBLISH, {Flags=#mqtt_sn_flags{topic_id_type = ?SN_NORMAL_TOPIC}, TopicId, MsgId, Data}) -> +serialize(?SN_PUBLISH, {Flags=#mqtt_sn_flags{topic_id_type = ?SN_NORMAL_TOPIC}, TopicId, MsgId, Data}, _Opts) -> <<(serialize_flags(Flags))/binary, TopicId:?short, MsgId:?short, Data/binary>>; -serialize(?SN_PUBLISH, {Flags=#mqtt_sn_flags{topic_id_type = ?SN_PREDEFINED_TOPIC}, TopicId, MsgId, Data}) -> +serialize(?SN_PUBLISH, {Flags=#mqtt_sn_flags{topic_id_type = ?SN_PREDEFINED_TOPIC}, TopicId, MsgId, Data}, _Opts) -> <<(serialize_flags(Flags))/binary, TopicId:?short, MsgId:?short, Data/binary>>; -serialize(?SN_PUBLISH, {Flags=#mqtt_sn_flags{topic_id_type = ?SN_SHORT_TOPIC}, STopicName, MsgId, Data}) -> +serialize(?SN_PUBLISH, {Flags=#mqtt_sn_flags{topic_id_type = ?SN_SHORT_TOPIC}, STopicName, MsgId, Data}, _Opts) -> <<(serialize_flags(Flags))/binary, STopicName:2/binary, MsgId:?short, Data/binary>>; -serialize(?SN_PUBACK, {TopicId, MsgId, ReturnCode}) -> +serialize(?SN_PUBACK, {TopicId, MsgId, ReturnCode}, _Opts) -> <>; -serialize(PubRec, MsgId) when PubRec == ?SN_PUBREC; PubRec == ?SN_PUBREL; PubRec == ?SN_PUBCOMP -> +serialize(PubRec, MsgId, _Opts) when PubRec == ?SN_PUBREC; PubRec == ?SN_PUBREL; PubRec == ?SN_PUBCOMP -> <>; -serialize(Sub, {Flags = #mqtt_sn_flags{topic_id_type = IdType}, MsgId, Topic}) +serialize(Sub, {Flags = #mqtt_sn_flags{topic_id_type = IdType}, MsgId, Topic}, _Opts) when Sub == ?SN_SUBSCRIBE; Sub == ?SN_UNSUBSCRIBE -> <<(serialize_flags(Flags))/binary, MsgId:16, (serialize_topic(IdType, Topic))/binary>>; -serialize(?SN_SUBACK, {Flags, TopicId, MsgId, ReturnCode}) -> +serialize(?SN_SUBACK, {Flags, TopicId, MsgId, ReturnCode}, _Opts) -> <<(serialize_flags(Flags))/binary, TopicId:?short, MsgId:?short, ReturnCode>>; -serialize(?SN_UNSUBACK, MsgId) -> +serialize(?SN_UNSUBACK, MsgId, _Opts) -> <>; -serialize(?SN_PINGREQ, ClientId) -> +serialize(?SN_PINGREQ, ClientId, _Opts) -> ClientId; -serialize(?SN_PINGRESP, _) -> +serialize(?SN_PINGRESP, _, _Opts) -> <<>>; -serialize(?SN_WILLTOPICUPD, {Flags, WillTopic}) -> +serialize(?SN_WILLTOPICUPD, {Flags, WillTopic}, _Opts) -> <<(serialize_flags(Flags))/binary, WillTopic/binary>>; -serialize(?SN_WILLMSGUPD, WillMsg) -> +serialize(?SN_WILLMSGUPD, WillMsg, _Opts) -> WillMsg; -serialize(?SN_WILLTOPICRESP, ReturnCode) -> +serialize(?SN_WILLTOPICRESP, ReturnCode, _Opts) -> <>; -serialize(?SN_WILLMSGRESP, ReturnCode) -> +serialize(?SN_WILLMSGRESP, ReturnCode, _Opts) -> <>; -serialize(?SN_DISCONNECT, undefined) -> +serialize(?SN_DISCONNECT, undefined, _Opts) -> <<>>; -serialize(?SN_DISCONNECT, Duration) -> +serialize(?SN_DISCONNECT, Duration, _Opts) -> <>. serialize_flags(#mqtt_sn_flags{dup = Dup, qos = QoS, retain = Retain, will = Will, diff --git a/apps/emqx_gateway/src/mqttsn/emqx_sn_impl.erl b/apps/emqx_gateway/src/mqttsn/emqx_sn_impl.erl index 1e5c5d8cd..7fe257d5c 100644 --- a/apps/emqx_gateway/src/mqttsn/emqx_sn_impl.erl +++ b/apps/emqx_gateway/src/mqttsn/emqx_sn_impl.erl @@ -129,7 +129,7 @@ start_listener(InstaId, Ctx, {Type, ListenOn, SocketOpts, Cfg}) -> start_listener(InstaId, Ctx, Type, ListenOn, SocketOpts, Cfg) -> Name = name(InstaId, Type), esockd:open_udp(Name, ListenOn, merge_default(SocketOpts), - {emqx_sn_gateway, start_link, [Cfg#{ctx => Ctx}]}). + {emqx_sn_conn, start_link, [Cfg#{ctx => Ctx}]}). name(InstaId, Type) -> list_to_atom(lists:concat([InstaId, ":", Type])). diff --git a/apps/emqx_gateway/src/mqttsn/include/emqx_sn.hrl b/apps/emqx_gateway/src/mqttsn/include/emqx_sn.hrl index 29c5b2c86..c7d9ce6b7 100644 --- a/apps/emqx_gateway/src/mqttsn/include/emqx_sn.hrl +++ b/apps/emqx_gateway/src/mqttsn/include/emqx_sn.hrl @@ -49,14 +49,32 @@ -type(mqtt_sn_type() :: ?SN_ADVERTISE..?SN_WILLMSGRESP). --define(SN_RC_ACCEPTED, 16#00). +-define(SN_RC_ACCEPTED, 16#00). -define(SN_RC_CONGESTION, 16#01). -define(SN_RC_INVALID_TOPIC_ID, 16#02). -define(SN_RC_NOT_SUPPORTED, 16#03). +%% Custome Reason code by emqx +-define(SN_RC_NOT_AUTHORIZE, 16#04). +-define(SN_RC_FAILED_SESSION, 16#05). +-define(SN_EXCEED_LIMITATION, 16#06). + +-define(SN_RC_NAME(Rc), + (begin + case Rc of + ?SN_RC_ACCEPTED -> accepted; + ?SN_RC_CONGESTION -> rejected_congestion; + ?SN_RC_INVALID_TOPIC_ID -> rejected_invaild_topic_id; + ?SN_RC_NOT_SUPPORTED -> rejected_not_supported; + ?SN_RC_NOT_AUTHORIZE -> rejected_not_authorize; + ?SN_RC_FAILED_SESSION -> rejected_failed_open_session; + ?SN_EXCEED_LIMITATION -> rejected_exceed_limitation; + _ -> reserved + end + end)). -define(QOS_NEG1, 3). --type(mqtt_sn_return_code() :: ?SN_RC_ACCEPTED .. ?SN_RC_NOT_SUPPORTED). +-type(mqtt_sn_return_code() :: ?SN_RC_ACCEPTED .. ?SN_EXCEED_LIMITATION). %%-------------------------------------------------------------------- %% MQTT-SN Message @@ -139,6 +157,12 @@ #mqtt_sn_message{type = ?SN_SUBSCRIBE, variable = {Flags, MsgId, Topic}}). +-define(SN_SUBSCRIBE_MSG_TYPE(Type, Topic, QoS), + #mqtt_sn_message{type = ?SN_SUBSCRIBE, + variable = { + #mqtt_sn_flags{qos = QoS, topic_id_type = Type}, + _, Topic}}). + -define(SN_SUBACK_MSG(Flags, TopicId, MsgId, ReturnCode), #mqtt_sn_message{type = ?SN_SUBACK, variable = {Flags, TopicId, MsgId, ReturnCode}}). @@ -147,6 +171,12 @@ #mqtt_sn_message{type = ?SN_UNSUBSCRIBE, variable = {Flags, MsgId, Topic}}). +-define(SN_UNSUBSCRIBE_MSG_TYPE(Type, Topic), + #mqtt_sn_message{type = ?SN_UNSUBSCRIBE, + variable = { + #mqtt_sn_flags{topic_id_type = Type}, + _, Topic}}). + -define(SN_UNSUBACK_MSG(MsgId), #mqtt_sn_message{type = ?SN_UNSUBACK, variable = MsgId}). @@ -181,5 +211,4 @@ -define(SN_SHORT_TOPIC, 2). -define(SN_RESERVED_TOPIC, 3). - -define(SN_INVALID_TOPIC_ID, 0).