diff --git a/include/emqx_client.hrl b/include/emqx_client.hrl index bf2f49283..b23a02f65 100644 --- a/include/emqx_client.hrl +++ b/include/emqx_client.hrl @@ -12,7 +12,6 @@ %% See the License for the specific language governing permissions and %% limitations under the License. - -ifndef(EMQX_CLIENT_HRL). -define(EMQX_CLIENT_HRL, true). -include("emqx_mqtt.hrl"). diff --git a/src/emqx_channel.erl b/src/emqx_channel.erl new file mode 100644 index 000000000..9799719b6 --- /dev/null +++ b/src/emqx_channel.erl @@ -0,0 +1,581 @@ +%%-------------------------------------------------------------------- +%% Copyright (c) 2019 EMQ Technologies Co., Ltd. All Rights Reserved. +%% +%% Licensed under the Apache License, Version 2.0 (the "License"); +%% you may not use this file except in compliance with the License. +%% You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, software +%% distributed under the License is distributed on an "AS IS" BASIS, +%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +%% See the License for the specific language governing permissions and +%% limitations under the License. +%%-------------------------------------------------------------------- + +%% MQTT Channel +-module(emqx_channel). + +-include("emqx.hrl"). +-include("emqx_mqtt.hrl"). +-include("logger.hrl"). +-include("types.hrl"). + +-logger_header("[Channel]"). + +-export([ attrs/1 ]). + +-export([ zone/1 + , client_id/1 + , conn_mod/1 + , endpoint/1 + , proto_ver/1 + , keepalive/1 + , session/1 + ]). + +-export([ init/2 + , handle_in/2 + , handle_out/2 + , handle_timeout/3 + , terminate/2 + ]). + +-export_type([channel/0]). + +-record(channel, { + conn_mod :: maybe(module()), + endpoint :: emqx_endpoint:endpoint(), + proto_name :: binary(), + proto_ver :: emqx_mqtt:version(), + keepalive :: non_neg_integer(), + session :: emqx_session:session(), + will_msg :: emqx_types:message(), + enable_acl :: boolean(), + is_bridge :: boolean(), + connected :: boolean(), + topic_aliases :: map(), + alias_maximum :: map(), + connected_at :: erlang:timestamp() + }). + +-opaque(channel() :: #channel{}). + +attrs(#channel{endpoint = Endpoint, session = Session}) -> + maps:merge(emqx_endpoint:to_map(Endpoint), + emqx_session:attrs(Session)). + +zone(#channel{endpoint = Endpoint}) -> + emqx_endpoint:zone(Endpoint). + +-spec(client_id(channel()) -> emqx_types:client_id()). +client_id(#channel{endpoint = Endpoint}) -> + emqx_endpoint:client_id(Endpoint). + +-spec(conn_mod(channel()) -> module()). +conn_mod(#channel{conn_mod = ConnMod}) -> + ConnMod. + +-spec(endpoint(channel()) -> emqx_endpoint:endpoint()). +endpoint(#channel{endpoint = Endpoint}) -> + Endpoint. + +-spec(proto_ver(channel()) -> emqx_mqtt:version()). +proto_ver(#channel{proto_ver = ProtoVer}) -> + ProtoVer. + +keepalive(#channel{keepalive = Keepalive}) -> + Keepalive. + +-spec(session(channel()) -> emqx_session:session()). +session(#channel{session = Session}) -> + Session. + +-spec(init(map(), proplists:proplist()) -> channel()). +init(ConnInfo = #{peername := Peername, + sockname := Sockname, + conn_mod := ConnMod}, Options) -> + Zone = proplists:get_value(zone, Options), + Peercert = maps:get(peercert, ConnInfo, nossl), + Username = peer_cert_as_username(Peercert, Options), + Mountpoint = emqx_zone:get_env(Zone, mountpoint), + WsCookie = maps:get(ws_cookie, ConnInfo, undefined), + Endpoint = emqx_endpoint:new(#{zone => Zone, + peername => Peername, + sockname => Sockname, + username => Username, + peercert => Peercert, + mountpoint => Mountpoint, + ws_cookie => WsCookie + }), + EnableAcl = emqx_zone:get_env(Zone, enable_acl, false), + #channel{conn_mod = ConnMod, + endpoint = Endpoint, + enable_acl = EnableAcl, + is_bridge = false, + connected = false + }. + +peer_cert_as_username(Peercert, Options) -> + case proplists:get_value(peer_cert_as_username, Options) of + cn -> esockd_peercert:common_name(Peercert); + dn -> esockd_peercert:subject(Peercert); + crt -> Peercert; + _ -> undefined + end. + +%%-------------------------------------------------------------------- +%% Handle incoming packet +%%-------------------------------------------------------------------- + +-spec(handle_in(emqx_mqtt:packet(), channel()) + -> {ok, channel()} + | {ok, emqx_mqtt:packet(), channel()} + | {error, Reason :: term(), channel()} + | {stop, Error :: atom(), channel()}). +handle_in(?CONNECT_PACKET( + #mqtt_packet_connect{proto_name = ProtoName, + proto_ver = ProtoVer, + is_bridge = IsBridge, + client_id = ClientId, + username = Username, + password = Password, + keepalive = Keepalive} = ConnPkt), + Channel = #channel{endpoint = Endpoint}) -> + Endpoint1 = emqx_endpoint:update(#{client_id => ClientId, + username => Username, + password => Password + }, Endpoint), + emqx_logger:set_metadata_client_id(ClientId), + WillMsg = emqx_packet:will_msg(ConnPkt), + Channel1 = Channel#channel{endpoint = Endpoint1, + proto_name = ProtoName, + proto_ver = ProtoVer, + is_bridge = IsBridge, + keepalive = Keepalive, + will_msg = WillMsg + }, + %% fun validate_packet/2, + case pipeline([fun check_connect/2, + fun handle_connect/2], ConnPkt, Channel1) of + {ok, SP, Channel2} -> + handle_out({connack, ?RC_SUCCESS, sp(SP)}, Channel2); + {error, ReasonCode} -> + handle_out({connack, ReasonCode}, Channel1); + {error, ReasonCode, Channel2} -> + handle_out({connack, ReasonCode}, Channel2) + end; + +handle_in(Packet = ?PUBLISH_PACKET(QoS, Topic, PacketId), Channel) -> + case pipeline([fun validate_packet/2, + fun check_pub_caps/2, + fun check_pub_acl/2, + fun handle_publish/2], Packet, Channel) of + {error, ReasonCode} -> + ?LOG(warning, "Cannot publish qos~w message to ~s due to ~s", + [QoS, Topic, emqx_reason_codes:text(ReasonCode)]), + handle_out(case QoS of + ?QOS_0 -> {puberr, ReasonCode}; + ?QOS_1 -> {puback, PacketId, ReasonCode}; + ?QOS_2 -> {pubrec, PacketId, ReasonCode} + end, Channel); + Result -> Result + end; + +handle_in(?PUBACK_PACKET(PacketId, ReasonCode), Channel = #channel{session = Session}) -> + case emqx_session:puback(PacketId, ReasonCode, Session) of + {ok, NSession} -> + {ok, Channel#channel{session = NSession}}; + {error, _NotFound} -> + %% TODO: metrics? error msg? + {ok, Channel} + end; + +handle_in(?PUBREC_PACKET(PacketId, ReasonCode), Channel = #channel{session = Session}) -> + case emqx_session:pubrec(PacketId, ReasonCode, Session) of + {ok, NSession} -> + handle_out({pubrel, PacketId}, Channel#channel{session = NSession}); + {error, ReasonCode} -> + handle_out({pubrel, PacketId, ReasonCode}, Channel) + end; + +handle_in(?PUBREL_PACKET(PacketId, ReasonCode), Channel = #channel{session = Session}) -> + case emqx_session:pubrel(PacketId, ReasonCode, Session) of + {ok, NSession} -> + handle_out({pubcomp, PacketId}, Channel#channel{session = NSession}); + {error, ReasonCode} -> + handle_out({pubcomp, PacketId, ReasonCode}, Channel) + end; + +handle_in(?PUBCOMP_PACKET(PacketId, ReasonCode), Channel = #channel{session = Session}) -> + case emqx_session:pubcomp(PacketId, ReasonCode, Session) of + {ok, NSession} -> + {ok, Channel#channel{session = NSession}}; + {error, _ReasonCode} -> + %% TODO: How to handle the reason code? + {ok, Channel} + end; + +handle_in(?SUBSCRIBE_PACKET(PacketId, Properties, RawTopicFilters), + Channel = #channel{endpoint = Endpoint, session = Session}) -> + case check_subscribe(parse_topic_filters(?SUBSCRIBE, RawTopicFilters), Channel) of + {ok, TopicFilters} -> + TopicFilters1 = preprocess_topic_filters(?SUBSCRIBE, Endpoint, + enrich_subopts(TopicFilters, Channel)), + {ok, ReasonCodes, NSession} = emqx_session:subscribe(emqx_endpoint:to_map(Endpoint), + TopicFilters1, Session), + handle_out({suback, PacketId, ReasonCodes}, Channel#channel{session = NSession}); + {error, TopicFilters} -> + {Topics, ReasonCodes} = lists:unzip([{Topic, RC} || {Topic, #{rc := RC}} <- TopicFilters]), + ?LOG(warning, "Cannot subscribe ~p due to ~p", + [Topics, [emqx_reason_codes:text(R) || R <- ReasonCodes]]), + %% Tell the client that all subscriptions has been failed. + ReasonCodes1 = lists:map(fun(?RC_SUCCESS) -> + ?RC_IMPLEMENTATION_SPECIFIC_ERROR; + (RC) -> RC + end, ReasonCodes), + handle_out({suback, PacketId, ReasonCodes1}, Channel) + end; + +handle_in(?UNSUBSCRIBE_PACKET(PacketId, Properties, RawTopicFilters), + Channel = #channel{endpoint = Endpoint, session = Session}) -> + TopicFilters = preprocess_topic_filters( + ?UNSUBSCRIBE, Endpoint, + parse_topic_filters(?UNSUBSCRIBE, RawTopicFilters)), + {ok, ReasonCodes, NSession} = emqx_session:unsubscribe(emqx_endpoint:to_map(Endpoint), TopicFilters, Session), + handle_out({unsuback, PacketId, ReasonCodes}, Channel#channel{session = NSession}); + +handle_in(?PACKET(?PINGREQ), Channel) -> + {ok, ?PACKET(?PINGRESP), Channel}; + +handle_in(?DISCONNECT_PACKET(?RC_SUCCESS), Channel) -> + %% Clear will msg + {stop, normal, Channel#channel{will_msg = undefined}}; + +handle_in(?DISCONNECT_PACKET(RC), Channel = #channel{proto_ver = Ver}) -> + %% TODO: + %% {stop, {shutdown, abnormal_disconnet}, Channel}; + {stop, {shutdown, emqx_reason_codes:name(RC, Ver)}, Channel}; + +handle_in(?AUTH_PACKET(), Channel) -> + %%TODO: implement later. + {ok, Channel}; + +handle_in(Packet, Channel) -> + io:format("In: ~p~n", [Packet]), + {ok, Channel}. + +%%-------------------------------------------------------------------- +%% Handle outgoing packet +%%-------------------------------------------------------------------- + +handle_out({connack, ?RC_SUCCESS, SP}, Channel = #channel{endpoint = Endpoint}) -> + ok = emqx_hooks:run('client.connected', + [emqx_endpoint:to_map(Endpoint), ?RC_SUCCESS, attrs(Channel)]), + Props = #{}, %% TODO: ... + {ok, ?CONNACK_PACKET(?RC_SUCCESS, SP, Props), Channel}; + +handle_out({connack, ReasonCode}, Channel = #channel{endpoint = Endpoint, + proto_ver = ProtoVer}) -> + ok = emqx_hooks:run('client.connected', + [emqx_endpoint:to_map(Endpoint), ReasonCode, attrs(Channel)]), + ReasonCode1 = if + ProtoVer == ?MQTT_PROTO_V5 -> ReasonCode; + true -> emqx_reason_codes:compat(connack, ReasonCode) + end, + Reason = emqx_reason_codes:name(ReasonCode1, ProtoVer), + {error, Reason, ?CONNACK_PACKET(ReasonCode1), Channel}; + +handle_out(Delivers, Channel = #channel{endpoint = Endpoint, session = Session}) + when is_list(Delivers) -> + case emqx_session:deliver([{Topic, Msg} || {deliver, Topic, Msg} <- Delivers], Session) of + {ok, Publishes, NSession} -> + Credentials = emqx_endpoint:credentials(Endpoint), + Packets = lists:map(fun({publish, PacketId, Msg}) -> + Msg0 = emqx_hooks:run_fold('message.deliver', [Credentials], Msg), + Msg1 = emqx_message:update_expiry(Msg0), + Msg2 = emqx_mountpoint:unmount(emqx_endpoint:mountpoint(Endpoint), Msg1), + emqx_packet:from_message(PacketId, Msg2) + end, Publishes), + {ok, Packets, Channel#channel{session = NSession}}; + {ok, NSession} -> + {ok, Channel#channel{session = NSession}} + end; + +handle_out({publish, PacketId, Msg}, Channel = #channel{endpoint = Endpoint}) -> + Credentials = emqx_endpoint:credentials(Endpoint), + Msg0 = emqx_hooks:run_fold('message.deliver', [Credentials], Msg), + Msg1 = emqx_message:update_expiry(Msg0), + Msg2 = emqx_mountpoint:unmount( + emqx_endpoint:mountpoint(Credentials), Msg1), + {ok, emqx_packet:from_message(PacketId, Msg2), Channel}; + +handle_out({puberr, ReasonCode}, Channel) -> + {ok, Channel}; + +handle_out({puback, PacketId, ReasonCode}, Channel) -> + {ok, ?PUBACK_PACKET(PacketId, ReasonCode), Channel}; + +handle_out({pubrel, PacketId}, Channel) -> + {ok, ?PUBREL_PACKET(PacketId), Channel}; +handle_out({pubrel, PacketId, ReasonCode}, Channel) -> + {ok, ?PUBREL_PACKET(PacketId, ReasonCode), Channel}; + +handle_out({pubrec, PacketId, ReasonCode}, Channel) -> + {ok, ?PUBREC_PACKET(PacketId, ReasonCode), Channel}; + +handle_out({pubcomp, PacketId}, Channel) -> + {ok, ?PUBCOMP_PACKET(PacketId), Channel}; +handle_out({pubcomp, PacketId, ReasonCode}, Channel) -> + {ok, ?PUBCOMP_PACKET(PacketId, ReasonCode), Channel}; + +handle_out({suback, PacketId, ReasonCodes}, Channel = #channel{proto_ver = ?MQTT_PROTO_V5}) -> + %% TODO: ACL Deny + {ok, ?SUBACK_PACKET(PacketId, ReasonCodes), Channel}; +handle_out({suback, PacketId, ReasonCodes}, Channel) -> + %% TODO: ACL Deny + ReasonCodes1 = [emqx_reason_codes:compat(suback, RC) || RC <- ReasonCodes], + {ok, ?SUBACK_PACKET(PacketId, ReasonCodes1), Channel}; + +handle_out({unsuback, PacketId, ReasonCodes}, Channel = #channel{proto_ver = ?MQTT_PROTO_V5}) -> + {ok, ?UNSUBACK_PACKET(PacketId, ReasonCodes), Channel}; +%% Ignore reason codes if not MQTT5 +handle_out({unsuback, PacketId, _ReasonCodes}, Channel) -> + {ok, ?UNSUBACK_PACKET(PacketId), Channel}; + +handle_out(Packet, State) -> + io:format("Out: ~p~n", [Packet]), + {ok, State}. + +handle_deliver(Msg, State) -> + io:format("Msg: ~p~n", [Msg]), + %% Msg -> Pub + {ok, State}. + +handle_timeout(Name, TRef, State) -> + io:format("Timeout: ~s ~p~n", [Name, TRef]), + {ok, State}. + +terminate(Reason, _State) -> + %%io:format("Terminated for ~p~n", [Reason]), + ok. + +%%-------------------------------------------------------------------- +%% Check Connect Packet +%%-------------------------------------------------------------------- + +check_connect(_ConnPkt, Channel) -> + {ok, Channel}. + +%%-------------------------------------------------------------------- +%% Handle Connect Packet +%%-------------------------------------------------------------------- + +handle_connect(#mqtt_packet_connect{proto_name = ProtoName, + proto_ver = ProtoVer, + is_bridge = IsBridge, + clean_start = CleanStart, + keepalive = Keepalive, + properties = ConnProps, + client_id = ClientId, + username = Username, + password = Password} = ConnPkt, + Channel = #channel{endpoint = Endpoint}) -> + Credentials = emqx_endpoint:credentials(Endpoint), + case emqx_access_control:authenticate( + Credentials#{password => Password}) of + {ok, Credentials1} -> + Endpoint1 = emqx_endpoint:update(Credentials1, Endpoint), + %% Open session + case open_session(ConnPkt, Channel) of + {ok, Session, SP} -> + Channel1 = Channel#channel{endpoint = Endpoint1, + session = Session, + connected = true, + connected_at = os:timestamp() + }, + ok = emqx_cm:register_channel(ClientId), + {ok, SP, Channel1}; + {error, Error} -> + ?LOG(error, "Failed to open session: ~p", [Error]), + {error, ?RC_UNSPECIFIED_ERROR, Channel#channel{endpoint = Endpoint1}} + end; + {error, Reason} -> + ?LOG(warning, "Client ~s (Username: '~s') login failed for ~p", + [ClientId, Username, Reason]), + {error, emqx_reason_codes:connack_error(Reason), Channel} + end. + +open_session(#mqtt_packet_connect{clean_start = CleanStart, + %%properties = ConnProps, + client_id = ClientId, + username = Username} = ConnPkt, + Channel = #channel{endpoint = Endpoint}) -> + emqx_cm:open_session(maps:merge(emqx_endpoint:to_map(Endpoint), + #{clean_start => CleanStart, + max_inflight => 0, + expiry_interval => 0})). + +%%-------------------------------------------------------------------- +%% Handle Publish Message: Client -> Broker +%%-------------------------------------------------------------------- + +handle_publish(Packet = ?PUBLISH_PACKET(QoS, Topic, PacketId), + Channel = #channel{endpoint = Endpoint}) -> + Credentials = emqx_endpoint:credentials(Endpoint), + %% TODO: ugly... publish_to_msg(...) + Msg = emqx_packet:to_message(Credentials, Packet), + Msg1 = emqx_mountpoint:mount( + emqx_endpoint:mountpoint(Endpoint), Msg), + Msg2 = emqx_message:set_flag(dup, false, Msg1), + handle_publish(PacketId, Msg2, Channel). + +handle_publish(_PacketId, Msg = #message{qos = ?QOS_0}, Channel) -> + _ = emqx_broker:publish(Msg), + {ok, Channel}; + +handle_publish(PacketId, Msg = #message{qos = ?QOS_1}, Channel) -> + Results = emqx_broker:publish(Msg), + ReasonCode = emqx_reason_codes:puback(Results), + handle_out({puback, PacketId, ReasonCode}, Channel); + +handle_publish(PacketId, Msg = #message{qos = ?QOS_2}, + Channel = #channel{session = Session}) -> + case emqx_session:publish(PacketId, Msg, Session) of + {ok, Results, NSession} -> + ReasonCode = emqx_reason_codes:puback(Results), + handle_out({pubrec, PacketId, ReasonCode}, + Channel#channel{session = NSession}); + {error, ReasonCode} -> + handle_out({pubrec, PacketId, ReasonCode}, Channel) + end. + +%%-------------------------------------------------------------------- +%% Validate Incoming Packet +%%-------------------------------------------------------------------- + +-spec(validate_packet(emqx_mqtt:packet(), channel()) -> ok). +validate_packet(Packet, _Channel) -> + try emqx_packet:validate(Packet) of + true -> ok + catch + error:protocol_error -> + {error, ?RC_PROTOCOL_ERROR}; + error:subscription_identifier_invalid -> + {error, ?RC_SUBSCRIPTION_IDENTIFIERS_NOT_SUPPORTED}; + error:topic_alias_invalid -> + {error, ?RC_TOPIC_ALIAS_INVALID}; + error:topic_filters_invalid -> + {error, ?RC_TOPIC_FILTER_INVALID}; + error:topic_name_invalid -> + {error, ?RC_TOPIC_FILTER_INVALID}; + error:_Reason -> + {error, ?RC_MALFORMED_PACKET} + end. + +%%-------------------------------------------------------------------- +%% Preprocess MQTT Properties +%%-------------------------------------------------------------------- + +%% TODO:... + +%%-------------------------------------------------------------------- +%% Check Publish +%%-------------------------------------------------------------------- + +check_pub_caps(#mqtt_packet{header = #mqtt_packet_header{qos = QoS, + retain = Retain}, + variable = #mqtt_packet_publish{}}, + #channel{endpoint = Endpoint}) -> + emqx_mqtt_caps:check_pub(emqx_endpoint:zone(Endpoint), + #{qos => QoS, retain => Retain}). + +check_pub_acl(_Packet, #channel{enable_acl = false}) -> + ok; +check_pub_acl(#mqtt_packet{variable = #mqtt_packet_publish{topic_name = Topic}}, + #channel{endpoint = Endpoint}) -> + case emqx_endpoint:is_superuser(Endpoint) of + true -> ok; + false -> + do_acl_check(Endpoint, publish, Topic) + end. + +check_sub_acl(_Packet, #channel{enable_acl = false}) -> + ok. + +do_acl_check(Endpoint, PubSub, Topic) -> + case emqx_access_control:check_acl( + emqx_endpoint:to_map(Endpoint), PubSub, Topic) of + allow -> ok; + deny -> {error, ?RC_NOT_AUTHORIZED} + end. + +%%-------------------------------------------------------------------- +%% Check Subscribe Packet +%%-------------------------------------------------------------------- + +check_subscribe(TopicFilters, _Channel) -> + {ok, TopicFilters}. + +%%-------------------------------------------------------------------- +%% Pipeline +%%-------------------------------------------------------------------- + +pipeline([Fun], Packet, Channel) -> + Fun(Packet, Channel); +pipeline([Fun|More], Packet, Channel) -> + case Fun(Packet, Channel) of + ok -> pipeline(More, Packet, Channel); + {ok, NChannel} -> + pipeline(More, Packet, NChannel); + {ok, NPacket, NChannel} -> + pipeline(More, NPacket, NChannel); + {error, Reason} -> + {error, Reason} + end. + +%%-------------------------------------------------------------------- +%% Preprocess topic filters +%%-------------------------------------------------------------------- + +preprocess_topic_filters(Type, Endpoint, TopicFilters) -> + TopicFilters1 = emqx_hooks:run_fold(case Type of + ?SUBSCRIBE -> 'client.subscribe'; + ?UNSUBSCRIBE -> 'client.unsubscribe' + end, + [emqx_endpoint:credentials(Endpoint)], + TopicFilters), + emqx_mountpoint:mount(emqx_endpoint:mountpoint(Endpoint), TopicFilters1). + +%%-------------------------------------------------------------------- +%% Enrich subopts +%%-------------------------------------------------------------------- + +enrich_subopts(TopicFilters, #channel{proto_ver = ?MQTT_PROTO_V5}) -> + TopicFilters; +enrich_subopts(TopicFilters, #channel{endpoint = Endpoint, is_bridge = IsBridge}) -> + Rap = flag(IsBridge), + Nl = flag(emqx_zone:get_env(emqx_endpoint:zone(Endpoint), ignore_loop_deliver, false)), + [{Topic, SubOpts#{rap => Rap, nl => Nl}} || {Topic, SubOpts} <- TopicFilters]. + +%%-------------------------------------------------------------------- +%% Parse topic filters +%%-------------------------------------------------------------------- + +parse_topic_filters(?SUBSCRIBE, TopicFilters) -> + [emqx_topic:parse(Topic, SubOpts) || {Topic, SubOpts} <- TopicFilters]; + +parse_topic_filters(?UNSUBSCRIBE, TopicFilters) -> + lists:map(fun emqx_topic:parse/1, TopicFilters). + +%%-------------------------------------------------------------------- +%% Helper functions +%%-------------------------------------------------------------------- + +sp(true) -> 1; +sp(false) -> 0. + +flag(true) -> 1; +flag(false) -> 0. + diff --git a/src/emqx_cm.erl b/src/emqx_cm.erl index 39b19d9ba..7f230e841 100644 --- a/src/emqx_cm.erl +++ b/src/emqx_cm.erl @@ -208,7 +208,7 @@ open_session(Attrs = #{clean_start := true, client_id := ClientId}) -> CleanStart = fun(_) -> ok = discard_session(ClientId), - {ok, emqx_session:new(Attrs)} + {ok, emqx_session:new(Attrs), false} end, emqx_cm_locker:trans(ClientId, CleanStart); @@ -219,7 +219,7 @@ open_session(Attrs = #{clean_start := false, {ok, Session} -> {ok, Session, true}; {error, not_found} -> - {ok, emqx_session:new(Attrs)} + {ok, emqx_session:new(Attrs), false} end end, emqx_cm_locker:trans(ClientId, ResumeStart). diff --git a/src/emqx_connection.erl b/src/emqx_connection.erl index be192f22e..f58e2fa57 100644 --- a/src/emqx_connection.erl +++ b/src/emqx_connection.erl @@ -14,8 +14,8 @@ %% limitations under the License. %%-------------------------------------------------------------------- -%% MQTT TCP/SSL Channel --module(emqx_channel). +%% MQTT TCP/SSL Connection +-module(emqx_connection). -behaviour(gen_statem). @@ -24,7 +24,7 @@ -include("logger.hrl"). -include("types.hrl"). --logger_header("[Channel]"). +-logger_header("[Conn]"). -export([start_link/3]). @@ -48,21 +48,19 @@ -record(state, { transport :: esockd:transport(), - socket :: esockd:sock(), - peername :: {inet:ip_address(), inet:port_number()}, - sockname :: {inet:ip_address(), inet:port_number()}, + socket :: esockd:socket(), + peername :: emqx_types:peername(), + sockname :: emqx_types:peername(), conn_state :: running | blocked, active_n :: pos_integer(), rate_limit :: maybe(esockd_rate_limit:bucket()), pub_limit :: maybe(esockd_rate_limit:bucket()), limit_timer :: maybe(reference()), - serializer :: emqx_frame:serializer(), %% TODO: remove it later. parse_state :: emqx_frame:parse_state(), - proto_state :: emqx_protocol:protocol(), + chan_state :: emqx_channel:channel(), gc_state :: emqx_gc:gc_state(), - keepalive :: maybe(reference()), - enable_stats :: boolean(), - stats_timer :: maybe(reference()), + keepalive :: maybe(emqx_keepalive:keepalive()), + stats_timer :: disabled | maybe(reference()), idle_timeout :: timeout() }). @@ -71,7 +69,7 @@ -define(CHAN_STATS, [recv_pkt, recv_msg, send_pkt, send_msg]). -define(SOCK_STATS, [recv_oct, recv_cnt, send_oct, send_cnt, send_pend]). --spec(start_link(esockd:transport(), esockd:sock(), proplists:proplist()) +-spec(start_link(esockd:transport(), esockd:socket(), proplists:proplist()) -> {ok, pid()}). start_link(Transport, Socket, Options) -> {ok, proc_lib:spawn_link(?MODULE, init, [{Transport, Socket, Options}])}. @@ -93,7 +91,7 @@ info(#state{transport = Transport, active_n = ActiveN, rate_limit = RateLimit, pub_limit = PubLimit, - proto_state = ProtoState}) -> + chan_state = ChanState}) -> ConnInfo = #{socktype => Transport:type(Socket), peername => Peername, sockname => Sockname, @@ -102,11 +100,11 @@ info(#state{transport = Transport, rate_limit => rate_limit_info(RateLimit), pub_limit => rate_limit_info(PubLimit) }, - ProtoInfo = emqx_protocol:info(ProtoState), - maps:merge(ConnInfo, ProtoInfo). + ChanInfo = emqx_channel:info(ChanState), + maps:merge(ConnInfo, ChanInfo). rate_limit_info(undefined) -> - #{}; + undefined; rate_limit_info(Limit) -> esockd_rate_limit:info(Limit). @@ -116,13 +114,16 @@ attrs(CPid) when is_pid(CPid) -> attrs(#state{peername = Peername, sockname = Sockname, - proto_state = ProtoState}) -> + conn_state = ConnState, + chan_state = ChanState}) -> SockAttrs = #{peername => Peername, - sockname => Sockname}, - ProtoAttrs = emqx_protocol:attrs(ProtoState), - maps:merge(SockAttrs, ProtoAttrs). + sockname => Sockname, + conn_state => ConnState + }, + ChanAttrs = emqx_channel:attrs(ChanState), + maps:merge(SockAttrs, ChanAttrs). -%% Conn stats +%% @doc Get connection stats stats(CPid) when is_pid(CPid) -> call(CPid, stats); @@ -153,15 +154,16 @@ init({Transport, RawSocket, Options}) -> ActiveN = proplists:get_value(active_n, Options, ?ACTIVE_N), MaxSize = emqx_zone:get_env(Zone, max_packet_size, ?MAX_PACKET_SIZE), ParseState = emqx_frame:initial_parse_state(#{max_size => MaxSize}), - ProtoState = emqx_protocol:init(#{peername => Peername, - sockname => Sockname, - peercert => Peercert, - conn_mod => ?MODULE}, Options), + ChanState = emqx_channel:init(#{peername => Peername, + sockname => Sockname, + peercert => Peercert, + conn_mod => ?MODULE}, Options), GcPolicy = emqx_zone:get_env(Zone, force_gc_policy, false), GcState = emqx_gc:init(GcPolicy), - ok = emqx_misc:init_proc_mng_policy(Zone), EnableStats = emqx_zone:get_env(Zone, enable_stats, true), + StatsTimer = if EnableStats -> undefined; ?Otherwise-> disabled end, IdleTimout = emqx_zone:get_env(Zone, idle_timeout, 30000), + ok = emqx_misc:init_proc_mng_policy(Zone), State = #state{transport = Transport, socket = Socket, peername = Peername, @@ -170,9 +172,9 @@ init({Transport, RawSocket, Options}) -> rate_limit = RateLimit, pub_limit = PubLimit, parse_state = ParseState, - proto_state = ProtoState, + chan_state = ChanState, gc_state = GcState, - enable_stats = EnableStats, + stats_timer = StatsTimer, idle_timeout = IdleTimout }, gen_statem:enter_loop(?MODULE, [{hibernate_after, 2 * IdleTimout}], @@ -190,17 +192,23 @@ callback_mode() -> %% Idle State idle(enter, _, State) -> - ok = activate_socket(State), - keep_state_and_data; + case activate_socket(State) of + ok -> keep_state_and_data; + {error, Reason} -> + shutdown(Reason, State) + end; idle(timeout, _Timeout, State) -> stop(idle_timeout, State); idle(cast, {incoming, Packet = ?CONNECT_PACKET(ConnVar)}, State) -> - #mqtt_packet_connect{proto_ver = ProtoVer} = ConnVar, - Serializer = emqx_frame:init_serializer(#{version => ProtoVer}), - NState = State#state{serializer = Serializer}, - handle_incoming(Packet, fun(St) -> {next_state, connected, St} end, NState); + handle_incoming(Packet, + fun(St = #state{chan_state = ChanState}) -> + %% Ensure keepalive after connected successfully. + Interval = emqx_channel:keepalive(ChanState), + NextEvent = {next_event, info, {keepalive, start, Interval}}, + {next_state, connected, St, NextEvent} + end, State); idle(cast, {incoming, Packet}, State) -> ?LOG(warning, "Unexpected incoming: ~p", [Packet]), @@ -221,47 +229,45 @@ connected(cast, {incoming, Packet = ?PACKET(?CONNECT)}, State) -> shutdown(unexpected_incoming_connect, State); connected(cast, {incoming, Packet = ?PACKET(Type)}, State) -> - ok = emqx_metrics:inc_recv(Packet), - (Type == ?PUBLISH) andalso emqx_pd:update_counter(incoming_pubs, 1), - handle_incoming(Packet, fun(St) -> {keep_state, St} end, State); + handle_incoming(Packet, fun keep_state/1, State); -%% Handle delivery -connected(info, Devliery = {deliver, _Topic, Msg}, State = #state{proto_state = ProtoState}) -> - case emqx_protocol:handle_out(Devliery, ProtoState) of - {ok, NProtoState} -> - {keep_state, State#state{proto_state = NProtoState}}; - {ok, Packet, NProtoState} -> - NState = State#state{proto_state = NProtoState}, - handle_outgoing(Packet, fun(St) -> {keep_state, St} end, NState); +connected(info, Deliver = {deliver, _Topic, _Msg}, + State = #state{chan_state = ChanState}) -> + Delivers = emqx_misc:drain_deliver([Deliver]), + %% TODO: ... + case BatchLen = length(Delivers) of + 1 -> ok; + N -> io:format("Batch Deliver: ~w~n", [N]) + end, + case emqx_channel:handle_out(Delivers, ChanState) of + {ok, NChanState} -> + keep_state(State#state{chan_state = NChanState}); + {ok, Packets, NChanState} -> + NState = State#state{chan_state = NChanState}, + handle_outgoing(Packets, fun keep_state/1, NState); {error, Reason} -> shutdown(Reason, State) end; %% Start Keepalive -connected(info, {keepalive, start, Interval}, - State = #state{transport = Transport, socket = Socket}) -> - StatFun = fun() -> - case Transport:getstat(Socket, [recv_oct]) of - {ok, [{recv_oct, RecvOct}]} -> {ok, RecvOct}; - Error -> Error - end - end, - case emqx_keepalive:start(StatFun, Interval, {keepalive, check}) of +connected(info, {keepalive, start, Interval}, State) -> + case ensure_keepalive(Interval, State) of + ignore -> keep_state(State); {ok, KeepAlive} -> - {keep_state, State#state{keepalive = KeepAlive}}; - {error, Error} -> - shutdown(Error, State) + keep_state(State#state{keepalive = KeepAlive}); + {error, Reason} -> + shutdown(Reason, State) end; %% Keepalive timer connected(info, {keepalive, check}, State = #state{keepalive = KeepAlive}) -> case emqx_keepalive:check(KeepAlive) of {ok, KeepAlive1} -> - {keep_state, State#state{keepalive = KeepAlive1}}; + keep_state(State#state{keepalive = KeepAlive1}); {error, timeout} -> shutdown(keepalive_timeout, State); - {error, Error} -> - shutdown(Error, State) + {error, Reason} -> + shutdown(Reason, State) end; connected(EventType, Content, State) -> @@ -287,13 +293,13 @@ handle({call, From}, attrs, State) -> handle({call, From}, stats, State) -> reply(From, stats(State), State); -handle({call, From}, kick, State) -> - ok = gen_statem:reply(From, ok), - shutdown(kicked, State); +%%handle({call, From}, kick, State) -> +%% ok = gen_statem:reply(From, ok), +%% shutdown(kicked, State); -handle({call, From}, discard, State) -> - ok = gen_statem:reply(From, ok), - shutdown(discard, State); +%%handle({call, From}, discard, State) -> +%% ok = gen_statem:reply(From, ok), +%% shutdown(discard, State); handle({call, From}, Req, State) -> ?LOG(error, "Unexpected call: ~p", [Req]), @@ -302,16 +308,16 @@ handle({call, From}, Req, State) -> %% Handle cast handle(cast, Msg, State) -> ?LOG(error, "Unexpected cast: ~p", [Msg]), - {keep_state, State}; + keep_state(State); %% Handle Incoming handle(info, {Inet, _Sock, Data}, State) when Inet == tcp; Inet == ssl -> - Oct = iolist_size(Data), ?LOG(debug, "RECV ~p", [Data]), + Oct = iolist_size(Data), emqx_pd:update_counter(incoming_bytes, Oct), ok = emqx_metrics:inc('bytes.received', Oct), - NState = ensure_stats_timer(maybe_gc({1, Oct}, State)), + NState = ensure_stats_timer(maybe_gc(1, Oct, State)), process_incoming(Data, [], NState); handle(info, {Error, _Sock, Reason}, State) @@ -326,32 +332,40 @@ handle(info, {Passive, _Sock}, State) when Passive == tcp_passive; Passive == ssl_passive -> %% Rate limit here:) NState = ensure_rate_limit(State), - ok = activate_socket(NState), - {keep_state, NState}; + case activate_socket(NState) of + ok -> keep_state(NState); + {error, Reason} -> + shutdown(Reason, NState) + end; handle(info, activate_socket, State) -> %% Rate limit timer expired. - ok = activate_socket(State#state{conn_state = running}), - {keep_state, State#state{conn_state = running, limit_timer = undefined}}; + NState = State#state{conn_state = running}, + case activate_socket(NState) of + ok -> + keep_state(NState#state{limit_timer = undefined}); + {error, Reason} -> + shutdown(Reason, NState) + end; handle(info, {inet_reply, _Sock, ok}, State) -> %% something sent - {keep_state, ensure_stats_timer(State)}; + keep_state(ensure_stats_timer(State)); handle(info, {inet_reply, _Sock, {error, Reason}}, State) -> shutdown(Reason, State); handle(info, {timeout, Timer, emit_stats}, State = #state{stats_timer = Timer, - proto_state = ProtoState, - gc_state = GcState}) -> - ClientId = emqx_protocol:client_id(ProtoState), - emqx_cm:set_conn_stats(ClientId, stats(State)), + chan_state = ChanState, + gc_state = GcState}) -> + ClientId = emqx_channel:client_id(ChanState), + ok = emqx_cm:set_conn_stats(ClientId, stats(State)), NState = State#state{stats_timer = undefined}, Limits = erlang:get(force_shutdown_policy), case emqx_misc:conn_proc_mng_policy(Limits) of continue -> - {keep_state, NState}; + keep_state(NState); hibernate -> %% going to hibernate, reset gc stats GcState1 = emqx_gc:reset(GcState), @@ -374,7 +388,7 @@ handle(info, {shutdown, Reason}, State) -> handle(info, Info, State) -> ?LOG(error, "Unexpected info: ~p", [Info]), - {keep_state, State}. + keep_state(State). code_change(_Vsn, State, Data, _Extra) -> {ok, State, Data}. @@ -382,11 +396,11 @@ code_change(_Vsn, State, Data, _Extra) -> terminate(Reason, _StateName, #state{transport = Transport, socket = Socket, keepalive = KeepAlive, - proto_state = ProtoState}) -> + chan_state = ChanState}) -> ?LOG(debug, "Terminated for ~p", [Reason]), ok = Transport:fast_close(Socket), ok = emqx_keepalive:cancel(KeepAlive), - emqx_protocol:terminate(Reason, ProtoState). + emqx_channel:terminate(Reason, ChanState). %%-------------------------------------------------------------------- %% Process incoming data @@ -420,39 +434,74 @@ next_events(Packet) -> %% Handle incoming packet handle_incoming(Packet = ?PACKET(Type), SuccFun, - State = #state{proto_state = ProtoState}) -> + State = #state{chan_state = ChanState}) -> _ = inc_incoming_stats(Type), + ok = emqx_metrics:inc_recv(Packet), ?LOG(debug, "RECV ~s", [emqx_packet:format(Packet)]), - case emqx_protocol:handle_in(Packet, ProtoState) of - {ok, NProtoState} -> - SuccFun(State#state{proto_state = NProtoState}); - {ok, OutPacket, NProtoState} -> + case emqx_channel:handle_in(Packet, ChanState) of + {ok, NChanState} -> + SuccFun(State#state{chan_state = NChanState}); + {ok, OutPacket, NChanState} -> handle_outgoing(OutPacket, SuccFun, - State#state{proto_state = NProtoState}); - {error, Reason} -> - shutdown(Reason, State); - {error, Reason, NProtoState} -> - shutdown(Reason, State#state{proto_state = NProtoState}); - {stop, Error, NProtoState} -> - stop(Error, State#state{proto_state = NProtoState}) + State#state{chan_state = NChanState}); + {error, Reason, NChanState} -> + shutdown(Reason, State#state{chan_state = NChanState}); + {stop, Error, NChanState} -> + stop(Error, State#state{chan_state = NChanState}) end. %%-------------------------------------------------------------------- -%% Handle outgoing packet +%% Handle outgoing packets -handle_outgoing(Packet = ?PACKET(Type), SuccFun, - State = #state{transport = Transport, - socket = Socket, - serializer = Serializer}) -> - _ = inc_outgoing_stats(Type), +handle_outgoing(Packets, SuccFun, State = #state{chan_state = ChanState}) + when is_list(Packets) -> + ProtoVer = emqx_channel:proto_ver(ChanState), + IoData = lists:foldl( + fun(Packet = ?PACKET(Type), Acc) -> + ?LOG(debug, "SEND ~s", [emqx_packet:format(Packet)]), + _ = inc_outgoing_stats(Type), + [emqx_frame:serialize(Packet, ProtoVer)|Acc] + end, [], Packets), + send(lists:reverse(IoData), SuccFun, State); + +handle_outgoing(Packet = ?PACKET(Type), SuccFun, State = #state{chan_state = ChanState}) -> ?LOG(debug, "SEND ~s", [emqx_packet:format(Packet)]), - Data = Serializer(Packet), - case Transport:async_send(Socket, Data) of - ok -> SuccFun(State); + _ = inc_outgoing_stats(Type), + ProtoVer = emqx_channel:proto_ver(ChanState), + IoData = emqx_frame:serialize(Packet, ProtoVer), + send(IoData, SuccFun, State). + +%%-------------------------------------------------------------------- +%% Send data + +send(IoData, SuccFun, State = #state{transport = Transport, socket = Socket}) -> + Oct = iolist_size(IoData), + ok = emqx_metrics:inc('bytes.sent', Oct), + case Transport:async_send(Socket, IoData) of + ok -> SuccFun(maybe_gc(1, Oct, State)); {error, Reason} -> shutdown(Reason, State) end. +%%-------------------------------------------------------------------- +%% Ensure keepalive + +ensure_keepalive(0, State) -> + ignore; +ensure_keepalive(Interval, State = #state{transport = Transport, + socket = Socket, + chan_state = ChanState}) -> + StatFun = fun() -> + case Transport:getstat(Socket, [recv_oct]) of + {ok, [{recv_oct, RecvOct}]} -> + {ok, RecvOct}; + Error -> Error + end + end, + Backoff = emqx_zone:get_env(emqx_channel:zone(ChanState), + keepalive_backoff, 0.75), + emqx_keepalive:start(StatFun, round(Interval * Backoff), {keepalive, check}). + %%-------------------------------------------------------------------- %% Ensure rate limit @@ -466,82 +515,77 @@ ensure_rate_limit([], State) -> ensure_rate_limit([{undefined, _Pos, _Cnt}|Limiters], State) -> ensure_rate_limit(Limiters, State); ensure_rate_limit([{Rl, Pos, Cnt}|Limiters], State) -> - case esockd_rate_limit:check(Cnt, Rl) of - {0, Rl1} -> - ensure_rate_limit(Limiters, setelement(Pos, State, Rl1)); - {Pause, Rl1} -> - ?LOG(debug, "Rate limit pause connection ~pms", [Pause]), - TRef = erlang:send_after(Pause, self(), activate_socket), - setelement(Pos, State#state{conn_state = blocked, limit_timer = TRef}, Rl1) - end. - -%% start_keepalive(0, _PState) -> -%% ignore; -%% start_keepalive(Secs, #pstate{zone = Zone}) when Secs > 0 -> -%% Backoff = emqx_zone:get_env(Zone, keepalive_backoff, 0.75), -%% self() ! {keepalive, start, round(Secs * Backoff)}. + case esockd_rate_limit:check(Cnt, Rl) of + {0, Rl1} -> + ensure_rate_limit(Limiters, setelement(Pos, State, Rl1)); + {Pause, Rl1} -> + ?LOG(debug, "Rate limit pause connection ~pms", [Pause]), + TRef = erlang:send_after(Pause, self(), activate_socket), + setelement(Pos, State#state{conn_state = blocked, limit_timer = TRef}, Rl1) + end. %%-------------------------------------------------------------------- -%% Activate socket +%% Activate Socket activate_socket(#state{conn_state = blocked}) -> ok; - -activate_socket(#state{transport = Transport, socket = Socket, active_n = N}) -> - case Transport:setopts(Socket, [{active, N}]) of - ok -> ok; - {error, Reason} -> - self() ! {shutdown, Reason}, - ok - end. +activate_socket(#state{transport = Transport, + socket = Socket, + active_n = N}) -> + Transport:setopts(Socket, [{active, N}]). %%-------------------------------------------------------------------- %% Inc incoming/outgoing stats inc_incoming_stats(Type) -> emqx_pd:update_counter(recv_pkt, 1), - Type =:= ?PUBLISH andalso emqx_pd:update_counter(recv_msg, 1). + case Type == ?PUBLISH of + true -> + emqx_pd:update_counter(recv_msg, 1), + emqx_pd:update_counter(incoming_pubs, 1); + false -> ok + end. inc_outgoing_stats(Type) -> emqx_pd:update_counter(send_pkt, 1), - Type =:= ?PUBLISH andalso emqx_pd:update_counter(send_msg, 1). + (Type == ?PUBLISH) + andalso emqx_pd:update_counter(send_msg, 1). %%-------------------------------------------------------------------- %% Ensure stats timer -ensure_stats_timer(State = #state{enable_stats = true, - stats_timer = undefined, +ensure_stats_timer(State = #state{stats_timer = undefined, idle_timeout = IdleTimeout}) -> State#state{stats_timer = emqx_misc:start_timer(IdleTimeout, emit_stats)}; +%% disabled or timer existed ensure_stats_timer(State) -> State. %%-------------------------------------------------------------------- %% Maybe GC -maybe_gc(_, State = #state{gc_state = undefined}) -> +maybe_gc(_Cnt, _Oct, State = #state{gc_state = undefined}) -> State; -maybe_gc({publish, _, #message{payload = Payload}}, State) -> - Oct = iolist_size(Payload), - maybe_gc({1, Oct}, State); -maybe_gc(Packets, State) when is_list(Packets) -> - {Cnt, Oct} = - lists:unzip([{1, iolist_size(Payload)} - || {publish, _, #message{payload = Payload}} <- Packets]), - maybe_gc({lists:sum(Cnt), lists:sum(Oct)}, State); -maybe_gc({Cnt, Oct}, State = #state{gc_state = GCSt}) -> +maybe_gc(Cnt, Oct, State = #state{gc_state = GCSt}) -> {_, GCSt1} = emqx_gc:run(Cnt, Oct, GCSt), - State#state{gc_state = GCSt1}; -maybe_gc(_, State) -> State. + %% TODO: gc metric? + State#state{gc_state = GCSt1}. %%-------------------------------------------------------------------- %% Helper functions +-compile({inline, [reply/3]}). reply(From, Reply, State) -> {keep_state, State, [{reply, From, Reply}]}. +-compile({inline, [keep_state/1]}). +keep_state(State) -> + {keep_state, State}. + +-compile({inline, [shutdown/2]}). shutdown(Reason, State) -> stop({shutdown, Reason}, State). +-compile({inline, [stop/2]}). stop(Reason, State) -> {stop, Reason, State}. diff --git a/src/emqx_endpoint.erl b/src/emqx_endpoint.erl new file mode 100644 index 000000000..1529698aa --- /dev/null +++ b/src/emqx_endpoint.erl @@ -0,0 +1,95 @@ +%%-------------------------------------------------------------------- +%% Copyright (c) 2019 EMQ Technologies Co., Ltd. All Rights Reserved. +%% +%% Licensed under the Apache License, Version 2.0 (the "License"); +%% you may not use this file except in compliance with the License. +%% You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, software +%% distributed under the License is distributed on an "AS IS" BASIS, +%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +%% See the License for the specific language governing permissions and +%% limitations under the License. +%%-------------------------------------------------------------------- + +-module(emqx_endpoint). + +-include("types.hrl"). + +%% APIs +-export([ new/0 + , new/1 + ]). + +-export([ zone/1 + , client_id/1 + , mountpoint/1 + , is_superuser/1 + , credentials/1 + ]). + +-export([update/2]). + +-export([to_map/1]). + +-export_type([endpoint/0]). + +-opaque(endpoint() :: + {endpoint, + #{zone := emqx_types:zone(), + peername := emqx_types:peername(), + sockname => emqx_types:peername(), + client_id := emqx_types:client_id(), + username := emqx_types:username(), + peercert := esockd_peercert:peercert(), + is_superuser := boolean(), + mountpoint := maybe(binary()), + ws_cookie := maybe(list()), + password => binary(), + auth_result => emqx_types:auth_result(), + anonymous => boolean(), + atom() => term() + } + }). + +-define(Endpoint(M), {endpoint, M}). + +-define(Default, #{is_superuser => false, + anonymous => false + }). + +-spec(new() -> endpoint()). +new() -> + ?Endpoint(?Default). + +-spec(new(map()) -> endpoint()). +new(M) when is_map(M) -> + ?Endpoint(maps:merge(?Default, M)). + +-spec(zone(endpoint()) -> emqx_zone:zone()). +zone(?Endpoint(#{zone := Zone})) -> + Zone. + +client_id(?Endpoint(#{client_id := ClientId})) -> + ClientId. + +-spec(mountpoint(endpoint()) -> maybe(binary())). +mountpoint(?Endpoint(#{mountpoint := Mountpoint})) -> + Mountpoint; +mountpoint(_) -> undefined. + +is_superuser(?Endpoint(#{is_superuser := B})) -> + B. + +update(Attrs, ?Endpoint(M)) -> + ?Endpoint(maps:merge(M, Attrs)). + +credentials(?Endpoint(M)) -> + M. %% TODO: ... + +-spec(to_map(endpoint()) -> map()). +to_map(?Endpoint(M)) -> + M. + diff --git a/src/emqx_frame.erl b/src/emqx_frame.erl index 3c46e50db..3a18ab297 100644 --- a/src/emqx_frame.erl +++ b/src/emqx_frame.erl @@ -21,7 +21,6 @@ -export([ initial_parse_state/0 , initial_parse_state/1 - , init_serializer/1 ]). -export([ parse/1 @@ -386,18 +385,15 @@ parse_binary_data(<>) -> %% Serialize MQTT Packet %%-------------------------------------------------------------------- -init_serializer(Options) -> - fun(Packet) -> serialize(Packet, Options) end. - -spec(serialize(emqx_mqtt:packet()) -> iodata()). serialize(Packet) -> - serialize(Packet, ?DEFAULT_OPTIONS). + serialize(Packet, ?MQTT_PROTO_V4). --spec(serialize(emqx_mqtt:packet(), options()) -> iodata()). +-spec(serialize(emqx_mqtt:packet(), emqx_mqtt:version()) -> iodata()). serialize(#mqtt_packet{header = Header, variable = Variable, - payload = Payload}, Options) when is_map(Options) -> - serialize(Header, serialize_variable(Variable, merge_opts(Options)), serialize_payload(Payload)). + payload = Payload}, Ver) -> + serialize(Header, serialize_variable(Variable, Ver), serialize_payload(Payload)). serialize(#mqtt_packet_header{type = Type, dup = Dup, @@ -424,7 +420,7 @@ serialize_variable(#mqtt_packet_connect{ will_topic = WillTopic, will_payload = WillPayload, username = Username, - password = Password}, _Options) -> + password = Password}, _Ver) -> [serialize_binary_data(ProtoName), <<(case IsBridge of true -> 16#80 + ProtoVer; @@ -451,14 +447,12 @@ serialize_variable(#mqtt_packet_connect{ serialize_variable(#mqtt_packet_connack{ack_flags = AckFlags, reason_code = ReasonCode, - properties = Properties}, - #{version := Ver}) -> + properties = Properties}, Ver) -> [AckFlags, ReasonCode, serialize_properties(Properties, Ver)]; serialize_variable(#mqtt_packet_publish{topic_name = TopicName, packet_id = PacketId, - properties = Properties}, - #{version := Ver}) -> + properties = Properties}, Ver) -> [serialize_utf8_string(TopicName), if PacketId =:= undefined -> <<>>; @@ -466,59 +460,54 @@ serialize_variable(#mqtt_packet_publish{topic_name = TopicName, end, serialize_properties(Properties, Ver)]; -serialize_variable(#mqtt_packet_puback{packet_id = PacketId}, - #{version := Ver}) +serialize_variable(#mqtt_packet_puback{packet_id = PacketId}, Ver) when Ver == ?MQTT_PROTO_V3; Ver == ?MQTT_PROTO_V4 -> <>; serialize_variable(#mqtt_packet_puback{packet_id = PacketId, reason_code = ReasonCode, properties = Properties}, - #{version := ?MQTT_PROTO_V5}) -> + ?MQTT_PROTO_V5) -> [<>, ReasonCode, serialize_properties(Properties, ?MQTT_PROTO_V5)]; serialize_variable(#mqtt_packet_subscribe{packet_id = PacketId, properties = Properties, - topic_filters = TopicFilters}, - #{version := Ver}) -> + topic_filters = TopicFilters}, Ver) -> [<>, serialize_properties(Properties, Ver), serialize_topic_filters(subscribe, TopicFilters, Ver)]; serialize_variable(#mqtt_packet_suback{packet_id = PacketId, properties = Properties, - reason_codes = ReasonCodes}, - #{version := Ver}) -> + reason_codes = ReasonCodes}, Ver) -> [<>, serialize_properties(Properties, Ver), serialize_reason_codes(ReasonCodes)]; serialize_variable(#mqtt_packet_unsubscribe{packet_id = PacketId, properties = Properties, - topic_filters = TopicFilters}, - #{version := Ver}) -> + topic_filters = TopicFilters}, Ver) -> [<>, serialize_properties(Properties, Ver), serialize_topic_filters(unsubscribe, TopicFilters, Ver)]; serialize_variable(#mqtt_packet_unsuback{packet_id = PacketId, properties = Properties, - reason_codes = ReasonCodes}, - #{version := Ver}) -> + reason_codes = ReasonCodes}, Ver) -> [<>, serialize_properties(Properties, Ver), serialize_reason_codes(ReasonCodes)]; -serialize_variable(#mqtt_packet_disconnect{}, #{version := Ver}) +serialize_variable(#mqtt_packet_disconnect{}, Ver) when Ver == ?MQTT_PROTO_V3; Ver == ?MQTT_PROTO_V4 -> <<>>; serialize_variable(#mqtt_packet_disconnect{reason_code = ReasonCode, properties = Properties}, - #{version := Ver = ?MQTT_PROTO_V5}) -> + Ver = ?MQTT_PROTO_V5) -> [ReasonCode, serialize_properties(Properties, Ver)]; serialize_variable(#mqtt_packet_disconnect{}, _Ver) -> <<>>; serialize_variable(#mqtt_packet_auth{reason_code = ReasonCode, properties = Properties}, - #{version := Ver = ?MQTT_PROTO_V5}) -> + Ver = ?MQTT_PROTO_V5) -> [ReasonCode, serialize_properties(Properties, Ver)]; serialize_variable(PacketId, ?MQTT_PROTO_V3) when is_integer(PacketId) -> diff --git a/src/emqx_listeners.erl b/src/emqx_listeners.erl index 043b55fcc..94babe7fd 100644 --- a/src/emqx_listeners.erl +++ b/src/emqx_listeners.erl @@ -75,7 +75,7 @@ start_listener(Proto, ListenOn, Options) when Proto == https; Proto == wss -> start_mqtt_listener(Name, ListenOn, Options) -> SockOpts = esockd:parse_opt(Options), esockd:open(Name, ListenOn, merge_default(SockOpts), - {emqx_channel, start_link, [Options -- SockOpts]}). + {emqx_connection, start_link, [Options -- SockOpts]}). start_http_listener(Start, Name, ListenOn, RanchOpts, ProtoOpts) -> Start(Name, with_port(ListenOn, RanchOpts), ProtoOpts). @@ -84,7 +84,7 @@ mqtt_path(Options) -> proplists:get_value(mqtt_path, Options, "/mqtt"). ws_opts(Options) -> - Dispatch = cowboy_router:compile([{'_', [{mqtt_path(Options), emqx_ws_channel, Options}]}]), + Dispatch = cowboy_router:compile([{'_', [{mqtt_path(Options), emqx_ws_connection, Options}]}]), #{env => #{dispatch => Dispatch}, proxy_header => proplists:get_value(proxy_protocol, Options, false)}. ranch_opts(Options) -> diff --git a/src/emqx_misc.erl b/src/emqx_misc.erl index 26c97c92f..e04fd606d 100644 --- a/src/emqx_misc.erl +++ b/src/emqx_misc.erl @@ -29,7 +29,14 @@ , conn_proc_mng_policy/1 ]). --export([drain_down/1]). +-export([ drain_deliver/1 + , drain_down/1 + ]). + +-compile({inline, + [ start_timer/2 + , start_timer/3 + ]}). %% @doc Merge options -spec(merge_opts(list(), list()) -> list()). @@ -121,6 +128,16 @@ proc_info(Key) -> {Key, Value} = erlang:process_info(self(), Key), Value. +%% @doc Drain delivers from channel's mailbox. +drain_deliver(Acc) -> + receive + Deliver = {deliver, _Topic, _Msg} -> + drain_deliver([Deliver|Acc]) + after 0 -> + lists:reverse(Acc) + end. + +%% @doc Drain process down events. -spec(drain_down(pos_integer()) -> list(pid())). drain_down(Cnt) when Cnt > 0 -> drain_down(Cnt, []). diff --git a/src/emqx_mountpoint.erl b/src/emqx_mountpoint.erl index 1e52d9d25..80a65b743 100644 --- a/src/emqx_mountpoint.erl +++ b/src/emqx_mountpoint.erl @@ -19,8 +19,6 @@ -include("emqx.hrl"). -include("logger.hrl"). --logger_header("[Mountpoint]"). - -export([ mount/2 , unmount/2 ]). @@ -41,7 +39,8 @@ mount(MountPoint, Msg = #message{topic = Topic}) -> Msg#message{topic = <>}; mount(MountPoint, TopicFilters) when is_list(TopicFilters) -> - [{<>, SubOpts} || {Topic, SubOpts} <- TopicFilters]. + [{<>, SubOpts} + || {Topic, SubOpts} <- TopicFilters]. unmount(undefined, Msg) -> Msg; @@ -49,8 +48,7 @@ unmount(MountPoint, Msg = #message{topic = Topic}) -> try split_binary(Topic, byte_size(MountPoint)) of {MountPoint, Topic1} -> Msg#message{topic = Topic1} catch - _Error:Reason -> - ?LOG(error, "Unmount error : ~p", [Reason]), + error:badarg-> Msg end. diff --git a/src/emqx_pd.erl b/src/emqx_pd.erl index 5d1277833..9800d9d57 100644 --- a/src/emqx_pd.erl +++ b/src/emqx_pd.erl @@ -24,6 +24,12 @@ , reset_counter/1 ]). +-compile({inline, + [ update_counter/2 + , get_counter/1 + , reset_counter/1 + ]}). + -type(key() :: term()). -spec(update_counter(key(), number()) -> maybe(number())). diff --git a/src/emqx_protocol.erl b/src/emqx_protocol.erl deleted file mode 100644 index b1396db3c..000000000 --- a/src/emqx_protocol.erl +++ /dev/null @@ -1,1010 +0,0 @@ -%%-------------------------------------------------------------------- -%% Copyright (c) 2019 EMQ Technologies Co., Ltd. All Rights Reserved. -%% -%% Licensed under the Apache License, Version 2.0 (the "License"); -%% you may not use this file except in compliance with the License. -%% You may obtain a copy of the License at -%% -%% http://www.apache.org/licenses/LICENSE-2.0 -%% -%% Unless required by applicable law or agreed to in writing, software -%% distributed under the License is distributed on an "AS IS" BASIS, -%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -%% See the License for the specific language governing permissions and -%% limitations under the License. -%%-------------------------------------------------------------------- - -%% MQTT Protocol --module(emqx_protocol). - --include("emqx.hrl"). --include("emqx_mqtt.hrl"). --include("logger.hrl"). --include("types.hrl"). - --logger_header("[Protocol]"). - --export([ info/1 - , attrs/1 - , attr/2 - , caps/1 - , caps/2 - , client_id/1 - , credentials/1 - , session/1 - ]). - --export([ init/2 - , handle_in/2 - , handle_out/2 - , handle_timeout/3 - , terminate/2 - ]). - --export_type([protocol/0]). - --record(protocol, { - zone :: emqx_zone:zone(), - conn_mod :: module(), - sendfun, - sockname, - peername, - peercert, - proto_ver :: emqx_mqtt:version(), - proto_name, - client_id :: maybe(emqx_types:client_id()), - is_assigned, - username :: maybe(emqx_types:username()), - conn_props, - ack_props, - credentials :: map(), - session :: maybe(emqx_session:session()), - clean_start, - topic_aliases, - will_topic, - will_msg, - keepalive, - is_bridge :: boolean(), - connected :: boolean(), - connected_at :: erlang:timestamp(), - topic_alias_maximum, - ws_cookie - }). - --opaque(protocol() :: #protocol{}). - --ifdef(TEST). --compile(export_all). --compile(nowarn_export_all). --endif. - --define(NO_PROPS, undefined). - -%%-------------------------------------------------------------------- -%% Init -%%-------------------------------------------------------------------- - --spec(init(map(), list()) -> protocol()). -init(SocketOpts = #{sockname := Sockname, - peername := Peername, - peercert := Peercert}, Options) -> - Zone = proplists:get_value(zone, Options), - #protocol{zone = Zone, - %%sendfun = SendFun, - sockname = Sockname, - peername = Peername, - peercert = Peercert, - proto_ver = ?MQTT_PROTO_V4, - proto_name = <<"MQTT">>, - client_id = <<>>, - is_assigned = false, - %%conn_pid = self(), - username = init_username(Peercert, Options), - clean_start = false, - topic_aliases = #{}, - is_bridge = false, - connected = false, - topic_alias_maximum = #{to_client => 0, from_client => 0}, - conn_mod = maps:get(conn_mod, SocketOpts, undefined), - credentials = #{}, - ws_cookie = maps:get(ws_cookie, SocketOpts, undefined) - }. - -init_username(Peercert, Options) -> - case proplists:get_value(peer_cert_as_username, Options) of - cn -> esockd_peercert:common_name(Peercert); - dn -> esockd_peercert:subject(Peercert); - crt -> Peercert; - _ -> undefined - end. - -set_username(Username, PState = #protocol{username = undefined}) -> - PState#protocol{username = Username}; -set_username(_Username, PState) -> - PState. - -%%-------------------------------------------------------------------- -%% API -%%-------------------------------------------------------------------- - -info(PState = #protocol{zone = Zone, - conn_props = ConnProps, - ack_props = AckProps, - session = Session, - topic_aliases = Aliases, - will_msg = WillMsg}) -> - maps:merge(attrs(PState), #{conn_props => ConnProps, - ack_props => AckProps, - session => Session, - topic_aliases => Aliases, - will_msg => WillMsg, - enable_acl => emqx_zone:get_env(Zone, enable_acl, false) - }). - -attrs(#protocol{zone = Zone, - client_id = ClientId, - username = Username, - peername = Peername, - peercert = Peercert, - clean_start = CleanStart, - proto_ver = ProtoVer, - proto_name = ProtoName, - keepalive = Keepalive, - is_bridge = IsBridge, - connected_at = ConnectedAt, - conn_mod = ConnMod, - credentials = Credentials}) -> - #{zone => Zone, - client_id => ClientId, - username => Username, - peername => Peername, - peercert => Peercert, - proto_ver => ProtoVer, - proto_name => ProtoName, - clean_start => CleanStart, - keepalive => Keepalive, - is_bridge => IsBridge, - connected_at => ConnectedAt, - conn_mod => ConnMod, - credentials => Credentials - }. - -attr(proto_ver, #protocol{proto_ver = ProtoVer}) -> - ProtoVer; -attr(max_inflight, #protocol{proto_ver = ?MQTT_PROTO_V5, conn_props = ConnProps}) -> - get_property('Receive-Maximum', ConnProps, 65535); -attr(max_inflight, #protocol{zone = Zone}) -> - emqx_zone:get_env(Zone, max_inflight, 65535); -attr(expiry_interval, #protocol{proto_ver = ?MQTT_PROTO_V5, conn_props = ConnProps}) -> - get_property('Session-Expiry-Interval', ConnProps, 0); -attr(expiry_interval, #protocol{zone = Zone, clean_start = CleanStart}) -> - case CleanStart of - true -> 0; - false -> emqx_zone:get_env(Zone, session_expiry_interval, 16#ffffffff) - end; -attr(topic_alias_maximum, #protocol{proto_ver = ?MQTT_PROTO_V5, conn_props = ConnProps}) -> - get_property('Topic-Alias-Maximum', ConnProps, 0); -attr(topic_alias_maximum, #protocol{zone = Zone}) -> - emqx_zone:get_env(Zone, max_topic_alias, 0); -attr(Name, PState) -> - Attrs = lists:zip(record_info(fields, protocol), tl(tuple_to_list(PState))), - case lists:keyfind(Name, 1, Attrs) of - {_, Value} -> Value; - false -> undefined - end. - -caps(Name, PState) -> - maps:get(Name, caps(PState)). - -caps(#protocol{zone = Zone}) -> - emqx_mqtt_caps:get_caps(Zone). - -client_id(#protocol{client_id = ClientId}) -> - ClientId. - -credentials(#protocol{zone = Zone, - client_id = ClientId, - username = Username, - sockname = Sockname, - peername = Peername, - peercert = Peercert, - ws_cookie = WsCookie}) -> - with_cert(#{zone => Zone, - client_id => ClientId, - sockname => Sockname, - username => Username, - peername => Peername, - ws_cookie => WsCookie, - mountpoint => emqx_zone:get_env(Zone, mountpoint)}, Peercert). - -with_cert(Credentials, undefined) -> Credentials; -with_cert(Credentials, Peercert) -> - Credentials#{dn => esockd_peercert:subject(Peercert), - cn => esockd_peercert:common_name(Peercert)}. - -keepsafety(Credentials) -> - maps:filter(fun(password, _) -> false; - (dn, _) -> false; - (cn, _) -> false; - (_, _) -> true end, Credentials). - -session(#protocol{session = Session}) -> - Session. - -%%-------------------------------------------------------------------- -%% Packet Received -%%-------------------------------------------------------------------- - -set_protover(?CONNECT_PACKET(#mqtt_packet_connect{proto_ver = ProtoVer}), PState) -> - PState#protocol{proto_ver = ProtoVer}; -set_protover(_Packet, PState) -> - PState. - -handle_in(?PACKET(Type), PState = #protocol{connected = false}) when Type =/= ?CONNECT -> - {error, proto_not_connected, PState}; - -handle_in(?PACKET(?CONNECT), PState = #protocol{connected = true}) -> - {error, proto_unexpected_connect, PState}; - -handle_in(Packet = ?PACKET(_Type), PState) -> - PState1 = set_protover(Packet, PState), - try emqx_packet:validate(Packet) of - true -> - case preprocess_properties(Packet, PState1) of - {ok, Packet1, PState2} -> - process(Packet1, PState2); - {error, ReasonCode} -> - handle_out({disconnect, ReasonCode}, PState1) - end - catch - error:protocol_error -> - handle_out({disconnect, ?RC_PROTOCOL_ERROR}, PState1); - error:subscription_identifier_invalid -> - handle_out({disconnect, ?RC_SUBSCRIPTION_IDENTIFIERS_NOT_SUPPORTED}, PState1); - error:topic_alias_invalid -> - handle_out({disconnect, ?RC_TOPIC_ALIAS_INVALID}, PState1); - error:topic_filters_invalid -> - handle_out({disconnect, ?RC_TOPIC_FILTER_INVALID}, PState1); - error:topic_name_invalid -> - handle_out({disconnect, ?RC_TOPIC_FILTER_INVALID}, PState1); - error:_Reason -> - %% TODO: {error, Reason, PState1} - handle_out({disconnect, ?RC_MALFORMED_PACKET}, PState1) - end. - -%%-------------------------------------------------------------------- -%% Preprocess MQTT Properties -%%-------------------------------------------------------------------- - -preprocess_properties(Packet = #mqtt_packet{ - variable = #mqtt_packet_connect{ - properties = #{'Topic-Alias-Maximum' := ToClient} - } - }, - PState = #protocol{topic_alias_maximum = TopicAliasMaximum}) -> - {ok, Packet, PState#protocol{topic_alias_maximum = TopicAliasMaximum#{to_client => ToClient}}}; - -%% Subscription Identifier -preprocess_properties(Packet = #mqtt_packet{ - variable = Subscribe = #mqtt_packet_subscribe{ - properties = #{'Subscription-Identifier' := SubId}, - topic_filters = TopicFilters - } - }, - PState = #protocol{proto_ver = ?MQTT_PROTO_V5}) -> - TopicFilters1 = [{Topic, SubOpts#{subid => SubId}} || {Topic, SubOpts} <- TopicFilters], - {ok, Packet#mqtt_packet{variable = Subscribe#mqtt_packet_subscribe{topic_filters = TopicFilters1}}, PState}; - -%% Topic Alias Mapping -preprocess_properties(#mqtt_packet{ - variable = #mqtt_packet_publish{ - properties = #{'Topic-Alias' := 0}} - }, - PState) -> - {error, ?RC_TOPIC_ALIAS_INVALID}; - -preprocess_properties(Packet = #mqtt_packet{ - variable = Publish = #mqtt_packet_publish{ - topic_name = <<>>, - properties = #{'Topic-Alias' := AliasId}} - }, - PState = #protocol{proto_ver = ?MQTT_PROTO_V5, - topic_aliases = Aliases, - topic_alias_maximum = #{from_client := TopicAliasMaximum}}) -> - case AliasId =< TopicAliasMaximum of - true -> - {ok, Packet#mqtt_packet{variable = Publish#mqtt_packet_publish{ - topic_name = maps:get(AliasId, Aliases, <<>>)}}, PState}; - false -> - {error, ?RC_TOPIC_ALIAS_INVALID} - end; - -preprocess_properties(Packet = #mqtt_packet{ - variable = #mqtt_packet_publish{ - topic_name = Topic, - properties = #{'Topic-Alias' := AliasId}} - }, - PState = #protocol{proto_ver = ?MQTT_PROTO_V5, - topic_aliases = Aliases, - topic_alias_maximum = #{from_client := TopicAliasMaximum}}) -> - case AliasId =< TopicAliasMaximum of - true -> - {ok, Packet, PState#protocol{topic_aliases = maps:put(AliasId, Topic, Aliases)}}; - false -> - {error, ?RC_TOPIC_ALIAS_INVALID} - end; - -preprocess_properties(Packet, PState) -> - {ok, Packet, PState}. - -%%-------------------------------------------------------------------- -%% Process MQTT Packet -%%-------------------------------------------------------------------- - -process(?CONNECT_PACKET( - #mqtt_packet_connect{proto_name = ProtoName, - proto_ver = ProtoVer, - is_bridge = IsBridge, - clean_start = CleanStart, - keepalive = Keepalive, - properties = ConnProps, - client_id = ClientId, - username = Username, - password = Password} = ConnPkt), PState) -> - - %% TODO: Mountpoint... - %% Msg -> emqx_mountpoint:mount(MountPoint, Msg) - PState0 = maybe_use_username_as_clientid(ClientId, - set_username(Username, - PState#protocol{proto_ver = ProtoVer, - proto_name = ProtoName, - clean_start = CleanStart, - keepalive = Keepalive, - conn_props = ConnProps, - is_bridge = IsBridge, - connected_at = os:timestamp()})), - - NewClientId = PState0#protocol.client_id, - - emqx_logger:set_metadata_client_id(NewClientId), - - Credentials = credentials(PState0), - PState1 = PState0#protocol{credentials = Credentials}, - connack( - case check_connect(ConnPkt, PState1) of - ok -> - case emqx_access_control:authenticate(Credentials#{password => Password}) of - {ok, Credentials0} -> - PState3 = maybe_assign_client_id(PState1), - emqx_logger:set_metadata_client_id(PState3#protocol.client_id), - %% Open session - SessAttrs = #{will_msg => make_will_msg(ConnPkt)}, - case try_open_session(SessAttrs, PState3) of - {ok, Session, SP} -> - PState4 = PState3#protocol{session = Session, connected = true, - credentials = keepsafety(Credentials0)}, - ok = emqx_cm:register_channel(client_id(PState4)), - ok = emqx_cm:set_conn_attrs(client_id(PState4), attrs(PState4)), - %% Start keepalive - start_keepalive(Keepalive, PState4), - %% Success - {?RC_SUCCESS, SP, PState4}; - {error, Error} -> - ?LOG(error, "Failed to open session: ~p", [Error]), - {?RC_UNSPECIFIED_ERROR, PState1#protocol{credentials = Credentials0}} - end; - {error, Reason} -> - ?LOG(warning, "Client ~s (Username: '~s') login failed for ~p", [NewClientId, Username, Reason]), - {emqx_reason_codes:connack_error(Reason), PState1#protocol{credentials = Credentials}} - end; - {error, ReasonCode} -> - {ReasonCode, PState1} - end); - -process(Packet = ?PUBLISH_PACKET(?QOS_0, Topic, _PacketId, _Payload), PState = #protocol{zone = Zone}) -> - case check_publish(Packet, PState) of - ok -> - do_publish(Packet, PState); - {error, ReasonCode} -> - ?LOG(warning, "Cannot publish qos0 message to ~s for ~s", - [Topic, emqx_reason_codes:text(ReasonCode)]), - %% TODO: ... - AclDenyAction = emqx_zone:get_env(Zone, acl_deny_action, ignore), - do_acl_deny_action(AclDenyAction, Packet, ReasonCode, PState) - end; - -process(Packet = ?PUBLISH_PACKET(?QOS_1, Topic, PacketId, _Payload), PState = #protocol{zone = Zone}) -> - case check_publish(Packet, PState) of - ok -> - do_publish(Packet, PState); - {error, ReasonCode} -> - ?LOG(warning, "Cannot publish qos1 message to ~s for ~s", - [Topic, emqx_reason_codes:text(ReasonCode)]), - handle_out({puback, PacketId, ReasonCode}, PState) - end; - -process(Packet = ?PUBLISH_PACKET(?QOS_2, Topic, PacketId, _Payload), PState = #protocol{zone = Zone}) -> - case check_publish(Packet, PState) of - ok -> - do_publish(Packet, PState); - {error, ReasonCode} -> - ?LOG(warning, "Cannot publish qos2 message to ~s for ~s", - [Topic, emqx_reason_codes:text(ReasonCode)]), - handle_out({pubrec, PacketId, ReasonCode}, PState) - end; - -process(?PUBACK_PACKET(PacketId, ReasonCode), PState = #protocol{session = Session}) -> - case emqx_session:puback(PacketId, ReasonCode, Session) of - {ok, NSession} -> - {ok, PState#protocol{session = NSession}}; - {error, _NotFound} -> - {ok, PState} %% TODO: Fixme later - end; - -process(?PUBREC_PACKET(PacketId, ReasonCode), PState = #protocol{session = Session}) -> - case emqx_session:pubrec(PacketId, ReasonCode, Session) of - {ok, NSession} -> - {ok, ?PUBREL_PACKET(PacketId), PState#protocol{session = NSession}}; - {error, NotFound} -> - {ok, ?PUBREL_PACKET(PacketId, NotFound), PState} - end; - -process(?PUBREL_PACKET(PacketId, ReasonCode), PState = #protocol{session = Session}) -> - case emqx_session:pubrel(PacketId, ReasonCode, Session) of - {ok, NSession} -> - {ok, ?PUBCOMP_PACKET(PacketId), PState#protocol{session = NSession}}; - {error, NotFound} -> - {ok, ?PUBCOMP_PACKET(PacketId, NotFound), PState} - end; - -process(?PUBCOMP_PACKET(PacketId, ReasonCode), PState = #protocol{session = Session}) -> - case emqx_session:pubcomp(PacketId, ReasonCode, Session) of - {ok, NSession} -> - {ok, PState#protocol{session = NSession}}; - {error, _NotFound} -> ok - %% TODO: How to handle NotFound? - end; - -process(Packet = ?SUBSCRIBE_PACKET(PacketId, Properties, RawTopicFilters), - PState = #protocol{zone = Zone, session = Session, credentials = Credentials}) -> - case check_subscribe(parse_topic_filters(?SUBSCRIBE, raw_topic_filters(PState, RawTopicFilters)), PState) of - {ok, TopicFilters} -> - TopicFilters0 = emqx_hooks:run_fold('client.subscribe', [Credentials], TopicFilters), - TopicFilters1 = emqx_mountpoint:mount(mountpoint(Credentials), TopicFilters0), - {ok, ReasonCodes, NSession} = emqx_session:subscribe(Credentials, TopicFilters1, Session), - handle_out({suback, PacketId, ReasonCodes}, PState#protocol{session = NSession}); - {error, TopicFilters} -> - {SubTopics, ReasonCodes} = - lists:foldr(fun({Topic, #{rc := ?RC_SUCCESS}}, {Topics, Codes}) -> - {[Topic|Topics], [?RC_IMPLEMENTATION_SPECIFIC_ERROR | Codes]}; - ({Topic, #{rc := Code}}, {Topics, Codes}) -> - {[Topic|Topics], [Code|Codes]} - end, {[], []}, TopicFilters), - ?LOG(warning, "Cannot subscribe ~p for ~p", - [SubTopics, [emqx_reason_codes:text(R) || R <- ReasonCodes]]), - handle_out({suback, PacketId, ReasonCodes}, PState) - end; - -process(?UNSUBSCRIBE_PACKET(PacketId, Properties, RawTopicFilters), - PState = #protocol{session = Session, credentials = Credentials}) -> - TopicFilters = emqx_hooks:run_fold('client.unsubscribe', [Credentials], - parse_topic_filters(?UNSUBSCRIBE, RawTopicFilters)), - TopicFilters1 = emqx_mountpoint:mount(mountpoint(Credentials), TopicFilters), - {ok, ReasonCodes, NSession} = emqx_session:unsubscribe(Credentials, TopicFilters1, Session), - handle_out({unsuback, PacketId, ReasonCodes}, PState#protocol{session = NSession}); - -process(?PACKET(?PINGREQ), PState) -> - {ok, ?PACKET(?PINGRESP), PState}; - -process(?DISCONNECT_PACKET(?RC_SUCCESS, #{'Session-Expiry-Interval' := Interval}), - PState = #protocol{session = Session, conn_props = #{'Session-Expiry-Interval' := OldInterval}}) -> - case Interval =/= 0 andalso OldInterval =:= 0 of - true -> - handle_out({disconnect, ?RC_PROTOCOL_ERROR}, PState#protocol{will_msg = undefined}); - false -> - %% TODO: - %% emqx_session:update_expiry_interval(SPid, Interval), - %% Clean willmsg - {stop, normal, PState#protocol{will_msg = undefined}} - end; - -process(?DISCONNECT_PACKET(?RC_SUCCESS), PState) -> - {stop, normal, PState#protocol{will_msg = undefined}}; - -process(?DISCONNECT_PACKET(_), PState) -> - {stop, {shutdown, abnormal_disconnet}, PState}; - -process(?AUTH_PACKET(), State) -> - %%TODO: implement later. - {ok, State}. - -%%-------------------------------------------------------------------- -%% ConnAck --> Client -%%-------------------------------------------------------------------- - -connack({?RC_SUCCESS, SP, PState = #protocol{credentials = Credentials}}) -> - ok = emqx_hooks:run('client.connected', [Credentials, ?RC_SUCCESS, attrs(PState)]), - handle_out({connack, ?RC_SUCCESS, sp(SP)}, PState); - -connack({ReasonCode, PState = #protocol{proto_ver = ProtoVer, credentials = Credentials}}) -> - ok = emqx_hooks:run('client.connected', [Credentials, ReasonCode, attrs(PState)]), - [ReasonCode1] = reason_codes_compat(connack, [ReasonCode], ProtoVer), - handle_out({connack, ReasonCode1}, PState). - -%%------------------------------------------------------------------------------ -%% Publish Message -> Broker -%%------------------------------------------------------------------------------ - -do_publish(Packet = ?PUBLISH_PACKET(QoS, PacketId), - PState = #protocol{session = Session, credentials = Credentials}) -> - Msg = emqx_mountpoint:mount(mountpoint(Credentials), - emqx_packet:to_message(Credentials, Packet)), - Msg1 = emqx_message:set_flag(dup, false, Msg), - case emqx_session:publish(PacketId, Msg1, Session) of - {ok, Results} -> - puback(QoS, PacketId, Results, PState); - {ok, Results, NSession} -> - puback(QoS, PacketId, Results, PState#protocol{session = NSession}); - {error, Reason} -> - puback(QoS, PacketId, {error, Reason}, PState) - end. - -%%------------------------------------------------------------------------------ -%% Puback -> Client -%%------------------------------------------------------------------------------ - -puback(?QOS_0, _PacketId, _Result, PState) -> - {ok, PState}; -puback(?QOS_1, PacketId, [], PState) -> - handle_out({puback, PacketId, ?RC_NO_MATCHING_SUBSCRIBERS}, PState); -%%TODO: calc the deliver count? -puback(?QOS_1, PacketId, Result, PState) when is_list(Result) -> - handle_out({puback, PacketId, ?RC_SUCCESS}, PState); -puback(?QOS_1, PacketId, {error, ReasonCode}, PState) -> - handle_out({puback, PacketId, ReasonCode}, PState); -puback(?QOS_2, PacketId, [], PState) -> - handle_out({pubrec, PacketId, ?RC_NO_MATCHING_SUBSCRIBERS}, PState); -puback(?QOS_2, PacketId, Result, PState) when is_list(Result) -> - handle_out({pubrec, PacketId, ?RC_SUCCESS}, PState); -puback(?QOS_2, PacketId, {error, ReasonCode}, PState) -> - handle_out({pubrec, PacketId, ReasonCode}, PState). - -%%-------------------------------------------------------------------- -%% Handle outgoing -%%-------------------------------------------------------------------- - -handle_out({connack, ?RC_SUCCESS, SP}, PState = #protocol{zone = Zone, - proto_ver = ?MQTT_PROTO_V5, - client_id = ClientId, - is_assigned = IsAssigned, - topic_alias_maximum = TopicAliasMaximum}) -> - #{max_packet_size := MaxPktSize, - max_qos_allowed := MaxQoS, - mqtt_retain_available := Retain, - max_topic_alias := MaxAlias, - mqtt_shared_subscription := Shared, - mqtt_wildcard_subscription := Wildcard} = caps(PState), - %% Response-Information is so far not set by broker. - %% i.e. It's a Client-to-Client contract for the request-response topic naming scheme. - %% According to MQTT 5.0 spec: - %% A common use of this is to pass a globally unique portion of the topic tree which - %% is reserved for this Client for at least the lifetime of its Session. - %% This often cannot just be a random name as both the requesting Client and the - %% responding Client need to be authorized to use it. - %% If we are to support it in the feature, the implementation should be flexible - %% to allow prefixing the response topic based on different ACL config. - %% e.g. prefix by username or client-id, so that unauthorized clients can not - %% subscribe requests or responses that are not intended for them. - Props = #{'Retain-Available' => flag(Retain), - 'Maximum-Packet-Size' => MaxPktSize, - 'Topic-Alias-Maximum' => MaxAlias, - 'Wildcard-Subscription-Available' => flag(Wildcard), - 'Subscription-Identifier-Available' => 1, - %'Response-Information' => - 'Shared-Subscription-Available' => flag(Shared)}, - - Props1 = if - MaxQoS =:= ?QOS_2 -> - Props; - true -> - maps:put('Maximum-QoS', MaxQoS, Props) - end, - - Props2 = if IsAssigned -> - Props1#{'Assigned-Client-Identifier' => ClientId}; - true -> Props1 - - end, - - Props3 = case emqx_zone:get_env(Zone, server_keepalive) of - undefined -> Props2; - Keepalive -> Props2#{'Server-Keep-Alive' => Keepalive} - end, - - PState1 = PState#protocol{topic_alias_maximum = TopicAliasMaximum#{from_client => MaxAlias}}, - - {ok, ?CONNACK_PACKET(?RC_SUCCESS, SP, Props3), PState1}; - -handle_out({connack, ?RC_SUCCESS, SP}, PState) -> - {ok, ?CONNACK_PACKET(?RC_SUCCESS, SP), PState}; - -handle_out({connack, ReasonCode}, PState = #protocol{proto_ver = ProtoVer}) -> - Reason = emqx_reason_codes:name(ReasonCode, ProtoVer), - {error, Reason, ?CONNACK_PACKET(ReasonCode), PState}; - -handle_out({deliver, Topic, Msg}, PState = #protocol{session = Session}) -> - case emqx_session:deliver(Topic, Msg, Session) of - {ok, Publish, NSession} -> - handle_out(Publish, PState#protocol{session = NSession}); - {ok, NSession} -> - {ok, PState#protocol{session = NSession}} - end; - -handle_out({publish, PacketId, Msg}, PState = #protocol{credentials = Credentials}) -> - Msg0 = emqx_hooks:run_fold('message.deliver', [Credentials], Msg), - Msg1 = emqx_message:update_expiry(Msg0), - Msg2 = emqx_mountpoint:unmount(mountpoint(Credentials), Msg1), - {ok, emqx_packet:from_message(PacketId, Msg2), PState}; - -handle_out({puback, PacketId, ReasonCode}, PState) -> - {ok, ?PUBACK_PACKET(PacketId, ReasonCode), PState}; - %% TODO: - %% AclDenyAction = emqx_zone:get_env(Zone, acl_deny_action, ignore), - %% do_acl_deny_action(AclDenyAction, Packet, ReasonCode, PState1); - -handle_out({pubrel, PacketId}, PState) -> - {ok, ?PUBREL_PACKET(PacketId), PState}; - -handle_out({pubrec, PacketId, ReasonCode}, PState) -> - %% TODO: - %% AclDenyAction = emqx_zone:get_env(Zone, acl_deny_action, ignore), - %% do_acl_deny_action(AclDenyAction, Packet, ReasonCode, PState1); - {ok, ?PUBREC_PACKET(PacketId, ReasonCode), PState}; - -%%handle_out({pubrec, PacketId, ReasonCode}, PState) -> -%% {ok, ?PUBREC_PACKET(PacketId, ReasonCode), PState}; - -handle_out({suback, PacketId, ReasonCodes}, PState = #protocol{proto_ver = ProtoVer}) -> - %% TODO: ACL Deny - {ok, ?SUBACK_PACKET(PacketId, reason_codes_compat(suback, ReasonCodes, ProtoVer)), PState}; - -handle_out({unsuback, PacketId, ReasonCodes}, PState = #protocol{proto_ver = ProtoVer}) -> - {ok, ?UNSUBACK_PACKET(PacketId, reason_codes_compat(unsuback, ReasonCodes, ProtoVer)), PState}; - -%% Deliver a disconnect for mqtt 5.0 -handle_out({disconnect, RC}, PState = #protocol{proto_ver = ?MQTT_PROTO_V5}) -> - {error, emqx_reason_codes:name(RC), ?DISCONNECT_PACKET(RC), PState}; - -handle_out({disconnect, RC}, PState) -> - {error, emqx_reason_codes:name(RC), PState}. - -handle_timeout(Timer, Name, PState) -> - {ok, PState}. - -%%------------------------------------------------------------------------------ -%% Maybe use username replace client id - -maybe_use_username_as_clientid(ClientId, PState = #protocol{username = undefined}) -> - PState#protocol{client_id = ClientId}; -maybe_use_username_as_clientid(ClientId, PState = #protocol{username = Username, zone = Zone}) -> - case emqx_zone:get_env(Zone, use_username_as_clientid, false) of - true -> - PState#protocol{client_id = Username}; - false -> - PState#protocol{client_id = ClientId} - end. - -%%------------------------------------------------------------------------------ -%% Assign a clientId - -maybe_assign_client_id(PState = #protocol{client_id = <<>>, ack_props = AckProps}) -> - ClientId = emqx_guid:to_base62(emqx_guid:gen()), - AckProps1 = set_property('Assigned-Client-Identifier', ClientId, AckProps), - PState#protocol{client_id = ClientId, is_assigned = true, ack_props = AckProps1}; -maybe_assign_client_id(PState) -> - PState. - -try_open_session(SessAttrs, PState = #protocol{zone = Zone, - client_id = ClientId, - username = Username, - clean_start = CleanStart}) -> - case emqx_cm:open_session( - maps:merge(#{zone => Zone, - client_id => ClientId, - username => Username, - clean_start => CleanStart, - max_inflight => attr(max_inflight, PState), - expiry_interval => attr(expiry_interval, PState), - topic_alias_maximum => attr(topic_alias_maximum, PState)}, - SessAttrs)) of - {ok, Session} -> - {ok, Session, false}; - Other -> Other - end. - -set_property(Name, Value, ?NO_PROPS) -> - #{Name => Value}; -set_property(Name, Value, Props) -> - Props#{Name => Value}. - -get_property(_Name, undefined, Default) -> - Default; -get_property(Name, Props, Default) -> - maps:get(Name, Props, Default). - -make_will_msg(#mqtt_packet_connect{proto_ver = ProtoVer, - will_props = WillProps} = ConnPkt) -> - emqx_packet:will_msg( - case ProtoVer of - ?MQTT_PROTO_V5 -> - WillDelayInterval = get_property('Will-Delay-Interval', WillProps, 0), - ConnPkt#mqtt_packet_connect{ - will_props = set_property('Will-Delay-Interval', WillDelayInterval, WillProps)}; - _ -> - ConnPkt - end). - -%%-------------------------------------------------------------------- -%% Check Packet -%%-------------------------------------------------------------------- - -check_connect(Packet, PState) -> - run_check_steps([fun check_proto_ver/2, - fun check_client_id/2, - fun check_flapping/2, - fun check_banned/2, - fun check_will_topic/2, - fun check_will_retain/2], Packet, PState). - -check_proto_ver(#mqtt_packet_connect{proto_ver = Ver, - proto_name = Name}, _PState) -> - case lists:member({Ver, Name}, ?PROTOCOL_NAMES) of - true -> ok; - false -> {error, ?RC_PROTOCOL_ERROR} - end. - -%% MQTT3.1 does not allow null clientId -check_client_id(#mqtt_packet_connect{proto_ver = ?MQTT_PROTO_V3, - client_id = <<>>}, _PState) -> - {error, ?RC_CLIENT_IDENTIFIER_NOT_VALID}; - -%% Issue#599: Null clientId and clean_start = false -check_client_id(#mqtt_packet_connect{client_id = <<>>, - clean_start = false}, _PState) -> - {error, ?RC_CLIENT_IDENTIFIER_NOT_VALID}; - -check_client_id(#mqtt_packet_connect{client_id = <<>>, - clean_start = true}, _PState) -> - ok; - -check_client_id(#mqtt_packet_connect{client_id = ClientId}, #protocol{zone = Zone}) -> - Len = byte_size(ClientId), - MaxLen = emqx_zone:get_env(Zone, max_clientid_len), - case (1 =< Len) andalso (Len =< MaxLen) of - true -> ok; - false -> {error, ?RC_CLIENT_IDENTIFIER_NOT_VALID} - end. - -check_flapping(#mqtt_packet_connect{}, PState) -> - do_flapping_detect(connect, PState). - -check_banned(#mqtt_packet_connect{client_id = ClientId, username = Username}, - #protocol{zone = Zone, peername = Peername}) -> - Credentials = #{client_id => ClientId, - username => Username, - peername => Peername}, - EnableBan = emqx_zone:get_env(Zone, enable_ban, false), - do_check_banned(EnableBan, Credentials). - -check_will_topic(#mqtt_packet_connect{will_flag = false}, _PState) -> - ok; -check_will_topic(#mqtt_packet_connect{will_topic = WillTopic} = ConnPkt, PState) -> - try emqx_topic:validate(WillTopic) of - true -> check_will_acl(ConnPkt, PState) - catch error : _Error -> - {error, ?RC_TOPIC_NAME_INVALID} - end. - -check_will_retain(#mqtt_packet_connect{will_retain = false, proto_ver = ?MQTT_PROTO_V5}, _PState) -> - ok; -check_will_retain(#mqtt_packet_connect{will_retain = true, proto_ver = ?MQTT_PROTO_V5}, #protocol{zone = Zone}) -> - case emqx_zone:get_env(Zone, mqtt_retain_available, true) of - true -> {error, ?RC_RETAIN_NOT_SUPPORTED}; - false -> ok - end; -check_will_retain(_Packet, _PState) -> - ok. - -check_will_acl(#mqtt_packet_connect{will_topic = WillTopic}, - #protocol{zone = Zone, credentials = Credentials}) -> - EnableAcl = emqx_zone:get_env(Zone, enable_acl, false), - case do_acl_check(EnableAcl, publish, Credentials, WillTopic) of - ok -> ok; - Other -> - ?LOG(warning, "Cannot publish will message to ~p for acl denied", [WillTopic]), - Other - end. - -check_publish(Packet, PState) -> - run_check_steps([fun check_pub_caps/2, - fun check_pub_acl/2], Packet, PState). - -check_pub_caps(#mqtt_packet{header = #mqtt_packet_header{qos = QoS, retain = Retain}, - variable = #mqtt_packet_publish{properties = _Properties}}, - #protocol{zone = Zone}) -> - emqx_mqtt_caps:check_pub(Zone, #{qos => QoS, retain => Retain}). - -check_pub_acl(_Packet, #protocol{credentials = #{is_superuser := IsSuper}}) - when IsSuper -> - ok; -check_pub_acl(#mqtt_packet{variable = #mqtt_packet_publish{topic_name = Topic}}, - #protocol{zone = Zone, credentials = Credentials}) -> - EnableAcl = emqx_zone:get_env(Zone, enable_acl, false), - do_acl_check(EnableAcl, publish, Credentials, Topic). - -run_check_steps([], _Packet, _PState) -> - ok; -run_check_steps([Check|Steps], Packet, PState) -> - case Check(Packet, PState) of - ok -> - run_check_steps(Steps, Packet, PState); - Error = {error, _RC} -> - Error - end. - -check_subscribe(TopicFilters, PState = #protocol{zone = Zone}) -> - case emqx_mqtt_caps:check_sub(Zone, TopicFilters) of - {ok, TopicFilter1} -> - check_sub_acl(TopicFilter1, PState); - {error, TopicFilter1} -> - {error, TopicFilter1} - end. - -check_sub_acl(TopicFilters, #protocol{credentials = #{is_superuser := IsSuper}}) - when IsSuper -> - {ok, TopicFilters}; -check_sub_acl(TopicFilters, #protocol{zone = Zone, credentials = Credentials}) -> - EnableAcl = emqx_zone:get_env(Zone, enable_acl, false), - lists:foldr( - fun({Topic, SubOpts}, {Ok, Acc}) when EnableAcl -> - AllowTerm = {Ok, [{Topic, SubOpts}|Acc]}, - DenyTerm = {error, [{Topic, SubOpts#{rc := ?RC_NOT_AUTHORIZED}}|Acc]}, - do_acl_check(subscribe, Credentials, Topic, AllowTerm, DenyTerm); - (TopicFilter, Acc) -> - {ok, [TopicFilter | Acc]} - end, {ok, []}, TopicFilters). - -terminate(_Reason, #protocol{client_id = undefined}) -> - ok; -terminate(_Reason, PState = #protocol{connected = false}) -> - do_flapping_detect(disconnect, PState), - ok; -terminate(Reason, PState) when Reason =:= conflict; - Reason =:= discard -> - do_flapping_detect(disconnect, PState), - ok; - -terminate(Reason, PState = #protocol{credentials = Credentials}) -> - do_flapping_detect(disconnect, PState), - ?LOG(info, "Shutdown for ~p", [Reason]), - ok = emqx_hooks:run('client.disconnected', [Credentials, Reason]). - -start_keepalive(0, _PState) -> - ignore; -start_keepalive(Secs, #protocol{zone = Zone}) when Secs > 0 -> - Backoff = emqx_zone:get_env(Zone, keepalive_backoff, 0.75), - self() ! {keepalive, start, round(Secs * Backoff)}. - -%%-------------------------------------------------------------------- -%% Parse topic filters -%%-------------------------------------------------------------------- - -parse_topic_filters(?SUBSCRIBE, RawTopicFilters) -> - [emqx_topic:parse(RawTopic, SubOpts) || {RawTopic, SubOpts} <- RawTopicFilters]; - -parse_topic_filters(?UNSUBSCRIBE, RawTopicFilters) -> - lists:map(fun emqx_topic:parse/1, RawTopicFilters). - -sp(true) -> 1; -sp(false) -> 0. - -flag(false) -> 0; -flag(true) -> 1. - -%%-------------------------------------------------------------------- -%% Execute actions in case acl deny - -do_flapping_detect(Action, #protocol{zone = Zone, - client_id = ClientId}) -> - ok = case emqx_zone:get_env(Zone, enable_flapping_detect, false) of - true -> - Threshold = emqx_zone:get_env(Zone, flapping_threshold, {10, 60}), - case emqx_flapping:check(Action, ClientId, Threshold) of - flapping -> - BanExpiryInterval = emqx_zone:get_env(Zone, flapping_banned_expiry_interval, 3600000), - Until = erlang:system_time(second) + BanExpiryInterval, - emqx_banned:add(#banned{who = {client_id, ClientId}, - reason = <<"flapping">>, - by = <<"flapping_checker">>, - until = Until}), - ok; - _Other -> - ok - end; - _EnableFlappingDetect -> ok - end. - -do_acl_deny_action(disconnect, ?PUBLISH_PACKET(?QOS_0, _Topic, _PacketId, _Payload), - ?RC_NOT_AUTHORIZED, PState = #protocol{proto_ver = ProtoVer}) -> - {error, emqx_reason_codes:name(?RC_NOT_AUTHORIZED, ProtoVer), PState}; - -do_acl_deny_action(disconnect, ?PUBLISH_PACKET(QoS, _Topic, _PacketId, _Payload), - ?RC_NOT_AUTHORIZED, PState = #protocol{proto_ver = ProtoVer}) - when QoS =:= ?QOS_1; QoS =:= ?QOS_2 -> - %% TODO:... - %% deliver({disconnect, ?RC_NOT_AUTHORIZED}, PState), - {error, emqx_reason_codes:name(?RC_NOT_AUTHORIZED, ProtoVer), PState}; - -do_acl_deny_action(Action, ?SUBSCRIBE_PACKET(_PacketId, _Properties, _RawTopicFilters), ReasonCodes, PState) - when is_list(ReasonCodes) -> - traverse_reason_codes(ReasonCodes, Action, PState); -do_acl_deny_action(_OtherAction, _PubSubPacket, ?RC_NOT_AUTHORIZED, PState) -> - {ok, PState}; -do_acl_deny_action(_OtherAction, _PubSubPacket, ReasonCode, PState = #protocol{proto_ver = ProtoVer}) -> - {error, emqx_reason_codes:name(ReasonCode, ProtoVer), PState}. - -traverse_reason_codes([], _Action, PState) -> - {ok, PState}; -traverse_reason_codes([?RC_SUCCESS | LeftReasonCodes], Action, PState) -> - traverse_reason_codes(LeftReasonCodes, Action, PState); -traverse_reason_codes([?RC_NOT_AUTHORIZED | _LeftReasonCodes], disconnect, PState = #protocol{proto_ver = ProtoVer}) -> - {error, emqx_reason_codes:name(?RC_NOT_AUTHORIZED, ProtoVer), PState}; -traverse_reason_codes([?RC_NOT_AUTHORIZED | LeftReasonCodes], Action, PState) -> - traverse_reason_codes(LeftReasonCodes, Action, PState); -traverse_reason_codes([OtherCode | _LeftReasonCodes], _Action, PState = #protocol{proto_ver = ProtoVer}) -> - {error, emqx_reason_codes:name(OtherCode, ProtoVer), PState}. - -%% Reason code compat -reason_codes_compat(_PktType, ReasonCodes, ?MQTT_PROTO_V5) -> - ReasonCodes; -reason_codes_compat(unsuback, _ReasonCodes, _ProtoVer) -> - undefined; -reason_codes_compat(PktType, ReasonCodes, _ProtoVer) -> - [emqx_reason_codes:compat(PktType, RC) || RC <- ReasonCodes]. - -raw_topic_filters(#protocol{zone = Zone, proto_ver = ProtoVer, is_bridge = IsBridge}, RawTopicFilters) -> - IgnoreLoop = emqx_zone:get_env(Zone, ignore_loop_deliver, false), - case ProtoVer < ?MQTT_PROTO_V5 of - true -> - IfIgnoreLoop = case IgnoreLoop of true -> 1; false -> 0 end, - case IsBridge of - true -> [{RawTopic, SubOpts#{rap => 1, nl => IfIgnoreLoop}} || {RawTopic, SubOpts} <- RawTopicFilters]; - false -> [{RawTopic, SubOpts#{rap => 0, nl => IfIgnoreLoop}} || {RawTopic, SubOpts} <- RawTopicFilters] - end; - false -> - RawTopicFilters - end. - -mountpoint(Credentials) -> - maps:get(mountpoint, Credentials, undefined). - -do_check_banned(_EnableBan = true, Credentials) -> - case emqx_banned:check(Credentials) of - true -> {error, ?RC_BANNED}; - false -> ok - end; -do_check_banned(_EnableBan, _Credentials) -> ok. - -do_acl_check(_EnableAcl = true, Action, Credentials, Topic) -> - AllowTerm = ok, - DenyTerm = {error, ?RC_NOT_AUTHORIZED}, - do_acl_check(Action, Credentials, Topic, AllowTerm, DenyTerm); -do_acl_check(_EnableAcl, _Action, _Credentials, _Topic) -> - ok. - -do_acl_check(Action, Credentials, Topic, AllowTerm, DenyTerm) -> - case emqx_access_control:check_acl(Credentials, Action, Topic) of - allow -> AllowTerm; - deny -> DenyTerm - end. - diff --git a/src/emqx_reason_codes.erl b/src/emqx_reason_codes.erl index a99406f53..327b96018 100644 --- a/src/emqx_reason_codes.erl +++ b/src/emqx_reason_codes.erl @@ -22,6 +22,7 @@ -export([ name/2 , text/1 , connack_error/1 + , puback/1 ]). -export([compat/2]). @@ -161,3 +162,6 @@ connack_error(server_busy) -> ?RC_SERVER_BUSY; connack_error(banned) -> ?RC_BANNED; connack_error(bad_authentication_method) -> ?RC_BAD_AUTHENTICATION_METHOD; connack_error(_) -> ?RC_NOT_AUTHORIZED. + +puback([]) -> ?RC_NO_MATCHING_SUBSCRIBERS; +puback(L) when is_list(L) -> ?RC_SUCCESS. diff --git a/src/emqx_session.erl b/src/emqx_session.erl index f5e7414d5..e699a7252 100644 --- a/src/emqx_session.erl +++ b/src/emqx_session.erl @@ -69,7 +69,7 @@ , pubcomp/3 ]). --export([ deliver/3 +-export([ deliver/2 , await/3 , enqueue/2 ]). @@ -397,31 +397,29 @@ pubcomp(PacketId, ReasonCode, Session = #session{inflight = Inflight, mqueue = Q %% Handle delivery %%-------------------------------------------------------------------- -deliver(Topic, Msg, Session = #session{subscriptions = SubMap}) -> - SubOpts = get_subopts(Topic, SubMap), - case enrich(SubOpts, Msg, Session) of - {ok, Msg1} -> - deliver(Msg1, Session); - ignore -> ignore - end. +deliver(Delivers, Session = #session{subscriptions = SubMap}) + when is_list(Delivers) -> + Msgs = [enrich(get_subopts(Topic, SubMap), Msg, Session) + || {Topic, Msg} <- Delivers], + deliver(Msgs, [], Session). -%% Enqueue message if the client has been disconnected -%% process_msg(Msg, Session = #session{conn_pid = undefined}) -> -%% {ignore, enqueue_msg(Msg, Session)}; -deliver(Msg = #message{qos = ?QOS_0}, Session) -> - {ok, {publish, undefined, Msg}, Session}; +deliver([], Publishes, Session) -> + {ok, lists:reverse(Publishes), Session}; -deliver(Msg = #message{qos = QoS}, +deliver([Msg = #message{qos = ?QOS_0}|More], Acc, Session) -> + deliver(More, [{publish, undefined, Msg}|Acc], Session); + +deliver([Msg = #message{qos = QoS}|More], Acc, Session = #session{next_pkt_id = PacketId, inflight = Inflight}) when QoS =:= ?QOS_1 orelse QoS =:= ?QOS_2 -> case emqx_inflight:is_full(Inflight) of true -> - {ignore, enqueue(Msg, Session)}; + deliver(More, Acc, enqueue(Msg, Session)); false -> Publish = {publish, PacketId, Msg}, NSession = await(PacketId, Msg, Session), - {ok, Publish, next_pkt_id(NSession)} + deliver(More, [Publish|Acc], next_pkt_id(NSession)) end. enqueue(Msg, Session = #session{mqueue = Q}) -> @@ -454,7 +452,7 @@ get_subopts(Topic, SubMap) -> end. enrich([], Msg, _Session) -> - {ok, Msg}; + Msg; %%enrich([{nl, 1}|_Opts], #message{from = ClientId}, #session{client_id = ClientId}) -> %% ignore; enrich([{nl, _}|Opts], Msg, Session) -> diff --git a/src/emqx_ws_connection.erl b/src/emqx_ws_connection.erl index 73c37fc04..b21897b5e 100644 --- a/src/emqx_ws_connection.erl +++ b/src/emqx_ws_connection.erl @@ -14,20 +14,19 @@ %% limitations under the License. %%-------------------------------------------------------------------- -%% MQTT WebSocket Channel --module(emqx_ws_channel). +%% MQTT WebSocket Connection +-module(emqx_ws_connection). -include("emqx.hrl"). -include("emqx_mqtt.hrl"). -include("logger.hrl"). +-include("types.hrl"). --logger_header("[WS Channel]"). +-logger_header("[WS Conn]"). -export([ info/1 , attrs/1 , stats/1 - , kick/1 - , session/1 ]). %% websocket callbacks @@ -41,18 +40,19 @@ -record(state, { request, options, - peername, - sockname, - proto_state, - parse_state, - keepalive, - enable_stats, - stats_timer, - idle_timeout, + peername :: {inet:ip_address(), inet:port_number()}, + sockname :: {inet:ip_address(), inet:port_number()}, + parse_state :: emqx_frame:parse_state(), + packets :: list(emqx_mqtt:packet()), + chan_state :: emqx_channel:channel(), + keepalive :: maybe(emqx_keepalive:keepalive()), + stats_timer :: disabled | maybe(reference()), + idle_timeout :: timeout(), shutdown }). -define(SOCK_STATS, [recv_oct, recv_cnt, send_oct, send_cnt]). +-define(CHAN_STATS, [recv_pkt, recv_msg, send_pkt, send_msg]). %%-------------------------------------------------------------------- %% API @@ -62,42 +62,40 @@ info(WSPid) when is_pid(WSPid) -> call(WSPid, info); -info(#state{peername = Peername, - sockname = Sockname, - proto_state = ProtoState}) -> - ProtoInfo = emqx_protocol:info(ProtoState), +info(#state{peername = Peername, + sockname = Sockname, + chan_state = ChanState}) -> ConnInfo = #{socktype => websocket, conn_state => running, peername => Peername, - sockname => Sockname}, - maps:merge(ProtoInfo, ConnInfo). + sockname => Sockname + }, + ChanInfo = emqx_channel:info(ChanState), + maps:merge(ConnInfo, ChanInfo). %% for dashboard attrs(WSPid) when is_pid(WSPid) -> call(WSPid, attrs); -attrs(#state{peername = Peername, - sockname = Sockname, - proto_state = ProtoState}) -> +attrs(#state{peername = Peername, + sockname = Sockname, + chan_state = ChanState}) -> SockAttrs = #{peername => Peername, sockname => Sockname}, - ProtoAttrs = emqx_protocol:attrs(ProtoState), - maps:merge(SockAttrs, ProtoAttrs). + ChanAttrs = emqx_channel:attrs(ChanState), + maps:merge(SockAttrs, ChanAttrs). stats(WSPid) when is_pid(WSPid) -> call(WSPid, stats); -stats(#state{proto_state = ProtoState}) -> - lists:append([wsock_stats(), - emqx_misc:proc_stats(), - emqx_protocol:stats(ProtoState) - ]). +stats(#state{}) -> + lists:append([chan_stats(), wsock_stats(), emqx_misc:proc_stats()]). -kick(WSPid) when is_pid(WSPid) -> - call(WSPid, kick). +%%kick(WSPid) when is_pid(WSPid) -> +%% call(WSPid, kick). -session(WSPid) when is_pid(WSPid) -> - call(WSPid, session). +%%session(WSPid) when is_pid(WSPid) -> +%% call(WSPid, session). call(WSPid, Req) when is_pid(WSPid) -> Mref = erlang:monitor(process, WSPid), @@ -153,24 +151,24 @@ websocket_init(#state{request = Req, options = Options}) -> [Error, Reason]), undefined end, - ProtoState = emqx_protocol:init(#{peername => Peername, - sockname => Sockname, - peercert => Peercert, - sendfun => send_fun(self()), - ws_cookie => WsCookie, - conn_mod => ?MODULE}, Options), + ChanState = emqx_channel:init(#{peername => Peername, + sockname => Sockname, + peercert => Peercert, + ws_cookie => WsCookie, + conn_mod => ?MODULE}, Options), Zone = proplists:get_value(zone, Options), MaxSize = emqx_zone:get_env(Zone, max_packet_size, ?MAX_PACKET_SIZE), ParseState = emqx_frame:initial_parse_state(#{max_size => MaxSize}), EnableStats = emqx_zone:get_env(Zone, enable_stats, true), + StatsTimer = if EnableStats -> undefined; ?Otherwise-> disabled end, IdleTimout = emqx_zone:get_env(Zone, idle_timeout, 30000), emqx_logger:set_metadata_peername(esockd_net:format(Peername)), ok = emqx_misc:init_proc_mng_policy(Zone), {ok, #state{peername = Peername, sockname = Sockname, parse_state = ParseState, - proto_state = ProtoState, - enable_stats = EnableStats, + chan_state = ChanState, + stats_timer = StatsTimer, idle_timeout = IdleTimout }}. @@ -244,24 +242,21 @@ websocket_info({call, From, kick}, State) -> gen_server:reply(From, ok), shutdown(kick, State); -websocket_info({call, From, session}, State = #state{proto_state = ProtoState}) -> - gen_server:reply(From, emqx_protocol:session(ProtoState)), - {ok, State}; - -websocket_info(Delivery, State = #state{proto_state = ProtoState}) +websocket_info(Delivery, State = #state{chan_state = ChanState}) when element(1, Delivery) =:= deliver -> - case emqx_protocol:handle_out(Delivery, ProtoState) of - {ok, NProtoState} -> - {ok, State#state{proto_state = NProtoState}}; - {ok, Packet, NProtoState} -> - handle_outgoing(Packet, State#state{proto_state = NProtoState}); + case emqx_channel:handle_out(Delivery, ChanState) of + {ok, NChanState} -> + {ok, State#state{chan_state = NChanState}}; + {ok, Packet, NChanState} -> + handle_outgoing(Packet, State#state{chan_state = NChanState}); {error, Reason} -> shutdown(Reason, State) end; websocket_info({timeout, Timer, emit_stats}, - State = #state{stats_timer = Timer, proto_state = ProtoState}) -> - emqx_cm:set_conn_stats(emqx_protocol:client_id(ProtoState), stats(State)), + State = #state{stats_timer = Timer, chan_state = ChanState}) -> + ClientId = emqx_channel:client_id(ChanState), + ok = emqx_cm:set_conn_stats(ClientId, stats(State)), {ok, State#state{stats_timer = undefined}, hibernate}; websocket_info({keepalive, start, Interval}, State) -> @@ -307,59 +302,74 @@ websocket_info(Info, State) -> ?LOG(error, "Unexpected info: ~p", [Info]), {ok, State}. -terminate(SockError, _Req, #state{keepalive = Keepalive, - proto_state = ProtoState, - shutdown = Shutdown}) -> +terminate(SockError, _Req, #state{keepalive = Keepalive, + chan_state = ChanState, + shutdown = Shutdown}) -> ?LOG(debug, "Terminated for ~p, sockerror: ~p", [Shutdown, SockError]), emqx_keepalive:cancel(Keepalive), - case {ProtoState, Shutdown} of + case {ChanState, Shutdown} of {undefined, _} -> ok; {_, {shutdown, Reason}} -> - emqx_protocol:terminate(Reason, ProtoState); + emqx_channel:terminate(Reason, ChanState); {_, Error} -> - emqx_protocol:terminate(Error, ProtoState) + emqx_channel:terminate(Error, ChanState) end. %%-------------------------------------------------------------------- %% Internal functions %%-------------------------------------------------------------------- -handle_incoming(Packet, SuccFun, State = #state{proto_state = ProtoState}) -> - case emqx_protocol:handle_in(Packet, ProtoState) of - {ok, NProtoState} -> - SuccFun(State#state{proto_state = NProtoState}); - {ok, OutPacket, NProtoState} -> - %% TODO: How to call SuccFun??? - handle_outgoing(OutPacket, State#state{proto_state = NProtoState}); - {error, Reason} -> - ?LOG(error, "Protocol error: ~p", [Reason]), - shutdown(Reason, State); - {error, Reason, NProtoState} -> - shutdown(Reason, State#state{proto_state = NProtoState}); - {stop, Error, NProtoState} -> - shutdown(Error, State#state{proto_state = NProtoState}) +handle_incoming(Packet = ?PACKET(Type), SuccFun, + State = #state{chan_state = ChanState}) -> + _ = inc_incoming_stats(Type), + case emqx_channel:handle_in(Packet, ChanState) of + {ok, NChanState} -> + SuccFun(State#state{chan_state = NChanState}); + {ok, OutPacket, NChanState} -> + %% TODO: SuccFun, + handle_outgoing(OutPacket, State#state{chan_state = NChanState}); + {error, Reason, NChanState} -> + shutdown(Reason, State#state{chan_state = NChanState}); + {stop, Error, NChanState} -> + shutdown(Error, State#state{chan_state = NChanState}) end. -handle_outgoing(Packet, State = #state{proto_state = _NProtoState}) -> - Data = emqx_frame:serialize(Packet), %% TODO:, Options), +handle_outgoing(Packet = ?PACKET(Type), State = #state{chan_state = ChanState}) -> + ProtoVer = emqx_channel:info(proto_ver, ChanState), + Data = emqx_frame:serialize(Packet, ProtoVer), BinSize = iolist_size(Data), - emqx_pd:update_counter(send_cnt, 1), - emqx_pd:update_counter(send_oct, BinSize), + _ = inc_outgoing_stats(Type, BinSize), {reply, {binary, Data}, ensure_stats_timer(State)}. -ensure_stats_timer(State = #state{enable_stats = true, - stats_timer = undefined, - idle_timeout = IdleTimeout}) -> - State#state{stats_timer = emqx_misc:start_timer(IdleTimeout, emit_stats)}; -ensure_stats_timer(State) -> - State. +inc_incoming_stats(Type) -> + emqx_pd:update_counter(recv_pkt, 1), + (Type == ?PUBLISH) + andalso emqx_pd:update_counter(recv_msg, 1). +inc_outgoing_stats(Type, BinSize) -> + emqx_pd:update_counter(send_cnt, 1), + emqx_pd:update_counter(send_oct, BinSize), + emqx_pd:update_counter(send_pkt, 1), + (Type == ?PUBLISH) + andalso emqx_pd:update_counter(send_msg, 1). + +ensure_stats_timer(State = #state{stats_timer = undefined, + idle_timeout = IdleTimeout}) -> + TRef = emqx_misc:start_timer(IdleTimeout, emit_stats), + State#state{stats_timer = TRef}; +%% disabled or timer existed +ensure_stats_timer(State) -> State. + +-compile({inline, [shutdown/2]}). shutdown(Reason, State) -> %% Fix the issue#2591(https://github.com/emqx/emqx/issues/2591#issuecomment-500278696) - self() ! {stop, Reason}, - {ok, State}. + %% self() ! {stop, Reason}, + {stop, State#state{shutdown = Reason}}. wsock_stats() -> [{Key, emqx_pd:get_counter(Key)} || Key <- ?SOCK_STATS]. +chan_stats() -> + [{Name, emqx_pd:get_counter(Name)} || Name <- ?CHAN_STATS]. +