From f26505d40ae036617581977a06fd6e0e2243052b Mon Sep 17 00:00:00 2001 From: Feng Lee Date: Sat, 17 Aug 2019 13:37:48 +0800 Subject: [PATCH] Implement session takeover and resumption. - Implement session takeover - Remove `init_proc_mng_policy/1' from emqx_misc module - Remove `conn_proc_mng_policy/1' from emqx_misc module - Add `emqx_oom' module to monitor OOM of channel process --- src/emqx_channel.erl | 1577 +++++++++++------ src/emqx_cm.erl | 63 +- src/emqx_connection.erl | 632 +++++++ src/emqx_listeners.erl | 4 +- src/emqx_misc.erl | 54 - src/emqx_mod_presence.erl | 7 +- src/emqx_oom.erl | 102 ++ src/emqx_protocol.erl | 924 ---------- src/emqx_session.erl | 25 +- ..._ws_channel.erl => emqx_ws_connection.erl} | 281 ++- test/emqx_channel_SUITE.erl | 276 ++- test/emqx_connection_SUITE.erl | 57 + test/emqx_oom_SUITE.erl | 34 + test/emqx_protocol_SUITE.erl | 287 --- ...SUITE.erl => emqx_ws_connection_SUITE.erl} | 2 +- 15 files changed, 2251 insertions(+), 2074 deletions(-) create mode 100644 src/emqx_connection.erl create mode 100644 src/emqx_oom.erl delete mode 100644 src/emqx_protocol.erl rename src/{emqx_ws_channel.erl => emqx_ws_connection.erl} (60%) create mode 100644 test/emqx_connection_SUITE.erl create mode 100644 test/emqx_oom_SUITE.erl delete mode 100644 test/emqx_protocol_SUITE.erl rename test/{emqx_ws_channel_SUITE.erl => emqx_ws_connection_SUITE.erl} (98%) diff --git a/src/emqx_channel.erl b/src/emqx_channel.erl index 630c4de6a..b8cf74ee1 100644 --- a/src/emqx_channel.erl +++ b/src/emqx_channel.erl @@ -14,11 +14,9 @@ %% limitations under the License. %%-------------------------------------------------------------------- -%% MQTT TCP/SSL Channel +%% MQTT Channel -module(emqx_channel). --behaviour(gen_statem). - -include("emqx.hrl"). -include("emqx_mqtt.hrl"). -include("logger.hrl"). @@ -26,652 +24,1063 @@ -logger_header("[Channel]"). --export([start_link/3]). - -%% APIs -export([ info/1 + , info/2 , attrs/1 - , stats/1 + , caps/1 ]). -%% for Debug --export([state/1]). +%% for tests +-export([set/3]). -%% state callbacks --export([ idle/3 - , connected/3 - , disconnected/3 +-export([takeover/2]). + +-export([ init/2 + , handle_in/2 + , handle_out/2 + , handle_out/3 + , handle_call/2 + , handle_cast/2 + , handle_info/2 + , timeout/3 + , terminate/2 ]). -%% gen_statem callbacks --export([ init/1 - , callback_mode/0 - , code_change/4 - , terminate/3 +-export([gc/3]). + +-import(emqx_access_control, + [ authenticate/1 + , check_acl/3 ]). --record(state, { - transport :: esockd:transport(), - socket :: esockd:socket(), - peername :: emqx_types:peername(), - sockname :: emqx_types:peername(), - conn_state :: running | blocked, - active_n :: pos_integer(), - rate_limit :: maybe(esockd_rate_limit:bucket()), - pub_limit :: maybe(esockd_rate_limit:bucket()), - limit_timer :: maybe(reference()), - serialize :: fun((emqx_types:packet()) -> iodata()), - parse_state :: emqx_frame:parse_state(), - proto_state :: emqx_protocol:proto_state(), - gc_state :: emqx_gc:gc_state(), - keepalive :: maybe(emqx_keepalive:keepalive()), - stats_timer :: disabled | maybe(reference()), - idle_timeout :: timeout(), - connected :: boolean(), - connected_at :: erlang:timestamp() - }). +-export_type([channel/0]). --type(state() :: #state{}). +-record(channel, { + client :: emqx_types:client(), + session :: emqx_session:session(), + proto_name :: binary(), + proto_ver :: emqx_types:ver(), + keepalive :: non_neg_integer(), + will_msg :: emqx_types:message(), + topic_aliases :: maybe(map()), + alias_maximum :: maybe(map()), + ack_props :: maybe(emqx_types:properties()), + idle_timeout :: timeout(), + retry_timer :: maybe(reference()), + alive_timer :: maybe(reference()), + stats_timer :: disabled | maybe(reference()), + expiry_timer :: maybe(reference()), + gc_state :: emqx_gc:gc_state(), %% GC State + oom_policy :: emqx_oom:oom_policy(), %% OOM Policy + connected :: boolean(), + connected_at :: erlang:timestamp(), + resuming :: boolean(), + pendings :: list() + }). --define(ACTIVE_N, 100). --define(HANDLE(T, C, D), handle((T), (C), (D))). --define(CHAN_STATS, [recv_pkt, recv_msg, send_pkt, send_msg]). --define(SOCK_STATS, [recv_oct, recv_cnt, send_oct, send_cnt, send_pend]). +-opaque(channel() :: #channel{}). --spec(start_link(esockd:transport(), esockd:socket(), proplists:proplist()) - -> {ok, pid()}). -start_link(Transport, Socket, Options) -> - {ok, proc_lib:spawn_link(?MODULE, init, [{Transport, Socket, Options}])}. +-define(NO_PROPS, undefined). %%-------------------------------------------------------------------- -%% API +%% Info, Attrs and Caps %%-------------------------------------------------------------------- -%% @doc Get infos of the channel. --spec(info(pid() | state()) -> emqx_types:infos()). -info(CPid) when is_pid(CPid) -> - call(CPid, info); -info(#state{transport = Transport, - socket = Socket, - peername = Peername, - sockname = Sockname, - conn_state = ConnState, - active_n = ActiveN, - rate_limit = RateLimit, - pub_limit = PubLimit, - proto_state = ProtoState, - gc_state = GCState, - stats_timer = StatsTimer, - idle_timeout = IdleTimeout, - connected = Connected, - connected_at = ConnectedAt}) -> - ChanInfo = #{socktype => Transport:type(Socket), - peername => Peername, - sockname => Sockname, - conn_state => ConnState, - active_n => ActiveN, - rate_limit => limit_info(RateLimit), - pub_limit => limit_info(PubLimit), - gc_state => emqx_gc:info(GCState), - enable_stats => case StatsTimer of - disabled -> false; - _Otherwise -> true - end, - idle_timeout => IdleTimeout, - connected => Connected, - connected_at => ConnectedAt - }, - maps:merge(ChanInfo, emqx_protocol:info(ProtoState)). +-spec(info(channel()) -> emqx_types:infos()). +info(#channel{client = Client, + session = Session, + proto_name = ProtoName, + proto_ver = ProtoVer, + keepalive = Keepalive, + will_msg = WillMsg, + topic_aliases = Aliases, + stats_timer = StatsTimer, + idle_timeout = IdleTimeout, + gc_state = GCState, + connected = Connected, + connected_at = ConnectedAt}) -> + #{client => Client, + session => if Session == undefined -> + undefined; + true -> emqx_session:info(Session) + end, + proto_name => ProtoName, + proto_ver => ProtoVer, + keepalive => Keepalive, + will_msg => WillMsg, + topic_aliases => Aliases, + enable_stats => case StatsTimer of + disabled -> false; + _Otherwise -> true + end, + idle_timeout => IdleTimeout, + gc_state => emqx_gc:info(GCState), + connected => Connected, + connected_at => ConnectedAt, + resuming => false, + pendings => [] + }. -limit_info(undefined) -> - undefined; -limit_info(Limit) -> - esockd_rate_limit:info(Limit). +-spec(info(atom(), channel()) -> term()). +info(client, #channel{client = Client}) -> + Client; +info(zone, #channel{client = #{zone := Zone}}) -> + Zone; +info(client_id, #channel{client = #{client_id := ClientId}}) -> + ClientId; +info(session, #channel{session = Session}) -> + Session; +info(proto_name, #channel{proto_name = ProtoName}) -> + ProtoName; +info(proto_ver, #channel{proto_ver = ProtoVer}) -> + ProtoVer; +info(keepalive, #channel{keepalive = Keepalive}) -> + Keepalive; +info(will_msg, #channel{will_msg = WillMsg}) -> + WillMsg; +info(topic_aliases, #channel{topic_aliases = Aliases}) -> + Aliases; +info(enable_stats, #channel{stats_timer = disabled}) -> + false; +info(enable_stats, #channel{stats_timer = _TRef}) -> + true; +info(idle_timeout, #channel{idle_timeout = IdleTimeout}) -> + IdleTimeout; +info(gc_state, #channel{gc_state = GCState}) -> + emqx_gc:info(GCState); +info(connected, #channel{connected = Connected}) -> + Connected; +info(connected_at, #channel{connected_at = ConnectedAt}) -> + ConnectedAt. -%% @doc Get attrs of the channel. --spec(attrs(pid() | state()) -> emqx_types:attrs()). -attrs(CPid) when is_pid(CPid) -> - call(CPid, attrs); -attrs(#state{transport = Transport, - socket = Socket, - peername = Peername, - sockname = Sockname, - proto_state = ProtoState, - connected = Connected, - connected_at = ConnectedAt}) -> - ConnAttrs = #{socktype => Transport:type(Socket), - peername => Peername, - sockname => Sockname, - connected => Connected, - connected_at => ConnectedAt}, - maps:merge(ConnAttrs, emqx_protocol:attrs(ProtoState)). +-spec(attrs(channel()) -> emqx_types:attrs()). +attrs(#channel{client = Client, + session = Session, + proto_name = ProtoName, + proto_ver = ProtoVer, + keepalive = Keepalive, + connected = Connected, + connected_at = ConnectedAt}) -> + #{client => Client, + session => if Session == undefined -> + undefined; + true -> emqx_session:attrs(Session) + end, + proto_name => ProtoName, + proto_ver => ProtoVer, + keepalive => Keepalive, + connected => Connected, + connected_at => ConnectedAt + }. -%% @doc Get stats of the channel. --spec(stats(pid() | state()) -> emqx_types:stats()). -stats(CPid) when is_pid(CPid) -> - call(CPid, stats); -stats(#state{transport = Transport, - socket = Socket, - proto_state = ProtoState}) -> - SockStats = case Transport:getstat(Socket, ?SOCK_STATS) of - {ok, Ss} -> Ss; - {error, _} -> [] - end, - ChanStats = [{Name, emqx_pd:get_counter(Name)} || Name <- ?CHAN_STATS], - SessStats = emqx_session:stats(emqx_protocol:info(session, ProtoState)), - lists:append([SockStats, ChanStats, SessStats, emqx_misc:proc_stats()]). - -state(CPid) -> call(CPid, get_state). - -%% @private -call(CPid, Req) -> - gen_statem:call(CPid, Req, infinity). +-spec(caps(channel()) -> emqx_types:caps()). +caps(#channel{client = #{zone := Zone}}) -> + emqx_mqtt_caps:get_caps(Zone). %%-------------------------------------------------------------------- -%% gen_statem callbacks +%% For unit tests %%-------------------------------------------------------------------- -init({Transport, RawSocket, Options}) -> - {ok, Socket} = Transport:wait(RawSocket), - {ok, Peername} = Transport:ensure_ok_or_exit(peername, [Socket]), - {ok, Sockname} = Transport:ensure_ok_or_exit(sockname, [Socket]), - Peercert = Transport:ensure_ok_or_exit(peercert, [Socket]), - emqx_logger:set_metadata_peername(esockd_net:format(Peername)), +set(client, Client, Channel) -> + Channel#channel{client = Client}; +set(session, Session, Channel) -> + Channel#channel{session = Session}. + +%%-------------------------------------------------------------------- +%% Takeover session +%%-------------------------------------------------------------------- + +takeover('begin', Channel = #channel{session = Session}) -> + {ok, Session, Channel#channel{resuming = true}}; + +takeover('end', Channel = #channel{session = Session, + pendings = Pendings}) -> + ok = emqx_session:takeover(Session), + {ok, Pendings, Channel}. + +%%-------------------------------------------------------------------- +%% Init a channel +%%-------------------------------------------------------------------- + +-spec(init(emqx_types:conn(), proplists:proplist()) -> channel()). +init(ConnInfo, Options) -> Zone = proplists:get_value(zone, Options), - RateLimit = init_limiter(proplists:get_value(rate_limit, Options)), - PubLimit = init_limiter(emqx_zone:get_env(Zone, publish_limit)), - ActiveN = proplists:get_value(active_n, Options, ?ACTIVE_N), - MaxSize = emqx_zone:get_env(Zone, max_packet_size, ?MAX_PACKET_SIZE), - ParseState = emqx_frame:initial_parse_state(#{max_size => MaxSize}), - ProtoState = emqx_protocol:init(#{peername => Peername, - sockname => Sockname, - peercert => Peercert, - conn_mod => ?MODULE}, Options), - GcPolicy = emqx_zone:get_env(Zone, force_gc_policy, false), - GcState = emqx_gc:init(GcPolicy), - EnableStats = emqx_zone:get_env(Zone, enable_stats, true), - StatsTimer = if EnableStats -> undefined; ?Otherwise -> disabled end, + Peercert = maps:get(peercert, ConnInfo, undefined), + Username = case peer_cert_as_username(Options) of + cn -> esockd_peercert:common_name(Peercert); + dn -> esockd_peercert:subject(Peercert); + crt -> Peercert; + _ -> undefined + end, + MountPoint = emqx_zone:get_env(Zone, mountpoint), + Client = maps:merge(#{zone => Zone, + username => Username, + mountpoint => MountPoint, + is_bridge => false, + is_superuser => false}, ConnInfo), IdleTimout = emqx_zone:get_env(Zone, idle_timeout, 30000), - ok = emqx_misc:init_proc_mng_policy(Zone), - State = #state{transport = Transport, - socket = Socket, - peername = Peername, - sockname = Sockname, - conn_state = running, - active_n = ActiveN, - rate_limit = RateLimit, - pub_limit = PubLimit, - parse_state = ParseState, - proto_state = ProtoState, - gc_state = GcState, - stats_timer = StatsTimer, - idle_timeout = IdleTimout, - connected = false - }, - gen_statem:enter_loop(?MODULE, [{hibernate_after, 2 * IdleTimout}], - idle, State, self(), [IdleTimout]). + EnableStats = emqx_zone:get_env(Zone, enable_stats, true), + StatsTimer = if EnableStats -> undefined; + ?Otherwise -> disabled + end, + GcState = emqx_gc:init(emqx_zone:get_env(Zone, force_gc_policy, false)), + OomPolicy = emqx_oom:init(emqx_zone:get_env(Zone, force_shutdown_policy)), + #channel{client = Client, + proto_name = <<"MQTT">>, + proto_ver = ?MQTT_PROTO_V4, + keepalive = 0, + idle_timeout = IdleTimout, + stats_timer = StatsTimer, + gc_state = GcState, + oom_policy = OomPolicy, + connected = false + }. -init_limiter(undefined) -> - undefined; -init_limiter({Rate, Burst}) -> - esockd_rate_limit:new(Rate, Burst). - -callback_mode() -> - [state_functions, state_enter]. - -%%-------------------------------------------------------------------- -%% Idle State - -idle(enter, _, State) -> - case activate_socket(State) of - ok -> keep_state_and_data; - {error, Reason} -> - shutdown(Reason, State) - end; - -idle(timeout, _Timeout, State) -> - stop(idle_timeout, State); - -idle(cast, {incoming, Packet = ?CONNECT_PACKET( - #mqtt_packet_connect{ - proto_ver = ProtoVer} - )}, State) -> - State1 = State#state{serialize = serialize_fun(ProtoVer)}, - handle_incoming(Packet, fun(NewSt) -> - {next_state, connected, NewSt} - end, State1); - -idle(cast, {incoming, Packet}, State) -> - ?LOG(warning, "Unexpected incoming: ~p", [Packet]), - shutdown(unexpected_incoming_packet, State); - -idle(EventType, Content, State) -> - ?HANDLE(EventType, Content, State). - -%%-------------------------------------------------------------------- -%% Connected State - -connected(enter, _PrevSt, State = #state{proto_state = ProtoState}) -> - NState = State#state{connected = true, - connected_at = os:timestamp()}, - ClientId = emqx_protocol:info(client_id, ProtoState), - ok = emqx_cm:register_channel(ClientId), - ok = emqx_cm:set_chan_attrs(ClientId, info(NState)), - %% Ensure keepalive after connected successfully. - Interval = emqx_protocol:info(keepalive, ProtoState), - case ensure_keepalive(Interval, NState) of - ignore -> keep_state(NState); - {ok, KeepAlive} -> - keep_state(NState#state{keepalive = KeepAlive}); - {error, Reason} -> - shutdown(Reason, NState) - end; - -connected(cast, {incoming, Packet = ?PACKET(?CONNECT)}, State) -> - ?LOG(warning, "Unexpected connect: ~p", [Packet]), - shutdown(unexpected_incoming_connect, State); - -connected(cast, {incoming, Packet}, State) when is_record(Packet, mqtt_packet) -> - handle_incoming(Packet, fun keep_state/1, State); - -connected(info, Deliver = {deliver, _Topic, _Msg}, - State = #state{proto_state = ProtoState}) -> - Delivers = emqx_misc:drain_deliver([Deliver]), - case emqx_protocol:handle_deliver(Delivers, ProtoState) of - {ok, NProtoState} -> - keep_state(State#state{proto_state = NProtoState}); - {ok, Packets, NProtoState} -> - NState = State#state{proto_state = NProtoState}, - handle_outgoing(Packets, fun keep_state/1, NState); - {error, Reason} -> - shutdown(Reason, State); - {error, Reason, NProtoState} -> - shutdown(Reason, State#state{proto_state = NProtoState}) - end; - -%% TODO: Improve later. -connected(info, {subscribe, TopicFilters}, State) -> - handle_request({subscribe, TopicFilters}, State); - -connected(info, {unsubscribe, TopicFilters}, State) -> - handle_request({unsubscribe, TopicFilters}, State); - -%% Keepalive timer -connected(info, {keepalive, check}, State = #state{keepalive = KeepAlive}) -> - case emqx_keepalive:check(KeepAlive) of - {ok, KeepAlive1} -> - keep_state(State#state{keepalive = KeepAlive1}); - {error, timeout} -> - shutdown(keepalive_timeout, State); - {error, Reason} -> - shutdown(Reason, State) - end; - -connected(EventType, Content, State) -> - ?HANDLE(EventType, Content, State). - -%%-------------------------------------------------------------------- -%% Disconnected State - -disconnected(enter, _, _State) -> - %% TODO: What to do? - %% CleanStart is true - keep_state_and_data; - -disconnected(EventType, Content, State) -> - ?HANDLE(EventType, Content, State). - -%% Handle call -handle({call, From}, info, State) -> - reply(From, info(State), State); - -handle({call, From}, attrs, State) -> - reply(From, attrs(State), State); - -handle({call, From}, stats, State) -> - reply(From, stats(State), State); - -handle({call, From}, get_state, State) -> - reply(From, State, State); - -%%handle({call, From}, kick, State) -> -%% ok = gen_statem:reply(From, ok), -%% shutdown(kicked, State); - -%%handle({call, From}, discard, State) -> -%% ok = gen_statem:reply(From, ok), -%% shutdown(discard, State); - -handle({call, From}, Req, State) -> - ?LOG(error, "Unexpected call: ~p", [Req]), - reply(From, ignored, State); - -%% Handle cast -handle(cast, Msg, State) -> - ?LOG(error, "Unexpected cast: ~p", [Msg]), - keep_state(State); - -%% Handle incoming data -handle(info, {Inet, _Sock, Data}, State) when Inet == tcp; - Inet == ssl -> - Oct = iolist_size(Data), - ?LOG(debug, "RECV ~p", [Data]), - emqx_pd:update_counter(incoming_bytes, Oct), - ok = emqx_metrics:inc('bytes.received', Oct), - NState = maybe_gc(1, Oct, State), - process_incoming(Data, ensure_stats_timer(NState)); - -handle(info, {Error, _Sock, Reason}, State) when Error == tcp_error; - Error == ssl_error -> - shutdown(Reason, State); - -handle(info, {Closed, _Sock}, State) when Closed == tcp_closed; - Closed == ssl_closed -> - shutdown(closed, State); - -handle(info, {Passive, _Sock}, State) when Passive == tcp_passive; - Passive == ssl_passive -> - %% Rate limit here:) - NState = ensure_rate_limit(State), - case activate_socket(NState) of - ok -> keep_state(NState); - {error, Reason} -> - shutdown(Reason, NState) - end; - -handle(info, activate_socket, State) -> - %% Rate limit timer expired. - NState = State#state{conn_state = running}, - case activate_socket(NState) of - ok -> - keep_state(NState#state{limit_timer = undefined}); - {error, Reason} -> - shutdown(Reason, NState) - end; - -handle(info, {inet_reply, _Sock, ok}, State) -> - %% something sent - keep_state(ensure_stats_timer(State)); - -handle(info, {inet_reply, _Sock, {error, Reason}}, State) -> - shutdown(Reason, State); - -handle(info, {timeout, Timer, emit_stats}, - State = #state{stats_timer = Timer, - proto_state = ProtoState, - gc_state = GcState}) -> - ClientId = emqx_protocol:info(client_id, ProtoState), - ok = emqx_cm:set_chan_stats(ClientId, stats(State)), - NState = State#state{stats_timer = undefined}, - Limits = erlang:get(force_shutdown_policy), - case emqx_misc:conn_proc_mng_policy(Limits) of - continue -> - keep_state(NState); - hibernate -> - %% going to hibernate, reset gc stats - GcState1 = emqx_gc:reset(GcState), - {keep_state, NState#state{gc_state = GcState1}, hibernate}; - {shutdown, Reason} -> - ?LOG(error, "Shutdown exceptionally due to ~p", [Reason]), - shutdown(Reason, NState) - end; - -handle(info, {timeout, Timer, Msg}, - State = #state{proto_state = ProtoState}) -> - case emqx_protocol:handle_timeout(Timer, Msg, ProtoState) of - {ok, NProtoState} -> - keep_state(State#state{proto_state = NProtoState}); - {ok, Packets, NProtoState} -> - handle_outgoing(Packets, fun keep_state/1, - State#state{proto_state = NProtoState}); - {error, Reason} -> - shutdown(Reason, State); - {error, Reason, NProtoState} -> - shutdown(Reason, State#state{proto_state = NProtoState}) - end; - -handle(info, {shutdown, discard, {ClientId, ByPid}}, State) -> - ?LOG(error, "Discarded by ~s:~p", [ClientId, ByPid]), - shutdown(discard, State); - -handle(info, {shutdown, conflict, {ClientId, NewPid}}, State) -> - ?LOG(warning, "Clientid '~s' conflict with ~p", [ClientId, NewPid]), - shutdown(conflict, State); - -handle(info, {shutdown, Reason}, State) -> - shutdown(Reason, State); - -handle(info, Info, State) -> - ?LOG(error, "Unexpected info: ~p", [Info]), - keep_state(State). - -code_change(_Vsn, State, Data, _Extra) -> - {ok, State, Data}. - -terminate(Reason, _StateName, #state{transport = Transport, - socket = Socket, - keepalive = KeepAlive, - proto_state = ProtoState}) -> - ?LOG(debug, "Terminated for ~p", [Reason]), - ok = Transport:fast_close(Socket), - ok = emqx_keepalive:cancel(KeepAlive), - emqx_protocol:terminate(Reason, ProtoState). - -%%-------------------------------------------------------------------- -%% Handle internal request - -handle_request(Req, State = #state{proto_state = ProtoState}) -> - case emqx_protocol:handle_req(Req, ProtoState) of - {ok, _Result, NProtoState} -> %% TODO:: how to handle the result? - keep_state(State#state{proto_state = NProtoState}); - {error, Reason, NProtoState} -> - shutdown(Reason, State#state{proto_state = NProtoState}) - end. - -%%-------------------------------------------------------------------- -%% Process incoming data - --compile({inline, [process_incoming/2]}). -process_incoming(Data, State) -> - process_incoming(Data, [], State). - -process_incoming(<<>>, Packets, State) -> - {keep_state, State, next_incoming_events(Packets)}; - -process_incoming(Data, Packets, State = #state{parse_state = ParseState}) -> - try emqx_frame:parse(Data, ParseState) of - {ok, NParseState} -> - NState = State#state{parse_state = NParseState}, - {keep_state, NState, next_incoming_events(Packets)}; - {ok, Packet, Rest, NParseState} -> - NState = State#state{parse_state = NParseState}, - process_incoming(Rest, [Packet|Packets], NState); - {error, Reason} -> - shutdown(Reason, State) - catch - error:Reason:Stk -> - ?LOG(error, "Parse failed for ~p~n\ - Stacktrace:~p~nError data:~p", [Reason, Stk, Data]), - shutdown(parse_error, State) - end. - -next_incoming_events(Packets) when is_list(Packets) -> - [next_event(cast, {incoming, Packet}) - || Packet <- lists:reverse(Packets)]. +peer_cert_as_username(Options) -> + proplists:get_value(peer_cert_as_username, Options). %%-------------------------------------------------------------------- %% Handle incoming packet +%%-------------------------------------------------------------------- -handle_incoming(Packet = ?PACKET(Type), SuccFun, - State = #state{proto_state = ProtoState}) -> - _ = inc_incoming_stats(Type), - ok = emqx_metrics:inc_recv(Packet), - ?LOG(debug, "RECV ~s", [emqx_packet:format(Packet)]), - case emqx_protocol:handle_in(Packet, ProtoState) of - {ok, NProtoState} -> - SuccFun(State#state{proto_state = NProtoState}); - {ok, OutPackets, NProtoState} -> - handle_outgoing(OutPackets, SuccFun, - State#state{proto_state = NProtoState}); - {error, Reason, NProtoState} -> - shutdown(Reason, State#state{proto_state = NProtoState}); - {error, Reason, OutPacket, NProtoState} -> - Shutdown = fun(NewSt) -> shutdown(Reason, NewSt) end, - handle_outgoing(OutPacket, Shutdown, State#state{proto_state = NProtoState}); - {stop, Error, NProtoState} -> - stop(Error, State#state{proto_state = NProtoState}) - end. +-spec(handle_in(emqx_types:packet(), channel()) + -> {ok, channel()} + | {ok, emqx_types:packet(), channel()} + | {ok, list(emqx_types:packet()), channel()} + | {stop, Error :: term(), channel()} + | {stop, Error :: term(), emqx_types:packet(), channel()}). +handle_in(?CONNECT_PACKET(_), Channel = #channel{connected = true}) -> + handle_out(disconnect, ?RC_PROTOCOL_ERROR, Channel); + +handle_in(?CONNECT_PACKET( + #mqtt_packet_connect{proto_name = ProtoName, + proto_ver = ProtoVer, + keepalive = Keepalive, + client_id = ClientId + } = ConnPkt), Channel) -> + Channel1 = Channel#channel{proto_name = ProtoName, + proto_ver = ProtoVer, + keepalive = Keepalive + }, + ok = emqx_logger:set_metadata_client_id(ClientId), + case pipeline([fun validate_in/2, + fun process_props/2, + fun check_connect/2, + fun enrich_client/2, + fun auth_connect/2], ConnPkt, Channel1) of + {ok, NConnPkt, NChannel} -> + process_connect(NConnPkt, maybe_assign_clientid(NChannel)); + {error, ReasonCode, NChannel} -> + handle_out(disconnect, ReasonCode, NChannel) + end; + +handle_in(Packet = ?PUBLISH_PACKET(QoS, Topic, PacketId), Channel) -> + case pipeline([fun validate_in/2, + fun process_alias/2, + fun check_publish/2], Packet, Channel) of + {ok, NPacket, NChannel} -> + process_publish(NPacket, NChannel); + {error, ReasonCode, NChannel} -> + ?LOG(warning, "Cannot publish message to ~s due to ~s", + [Topic, emqx_reason_codes:text(ReasonCode)]), + case QoS of + ?QOS_0 -> handle_out(puberr, ReasonCode, NChannel); + ?QOS_1 -> handle_out(puback, {PacketId, ReasonCode}, NChannel); + ?QOS_2 -> handle_out(pubrec, {PacketId, ReasonCode}, NChannel) + end + end; + +%%TODO: How to handle the ReasonCode? +handle_in(?PUBACK_PACKET(PacketId, _ReasonCode), Channel = #channel{session = Session}) -> + case emqx_session:puback(PacketId, Session) of + {ok, Publishes, NSession} -> + handle_out(publish, Publishes, Channel#channel{session = NSession}); + {ok, NSession} -> + {ok, Channel#channel{session = NSession}}; + {error, _NotFound} -> + %%TODO: How to handle NotFound, inc metrics? + {ok, Channel} + end; + +%%TODO: How to handle the ReasonCode? +handle_in(?PUBREC_PACKET(PacketId, _ReasonCode), Channel = #channel{session = Session}) -> + case emqx_session:pubrec(PacketId, Session) of + {ok, NSession} -> + handle_out(pubrel, {PacketId, ?RC_SUCCESS}, Channel#channel{session = NSession}); + {error, ReasonCode} -> + handle_out(pubrel, {PacketId, ReasonCode}, Channel) + end; + +%%TODO: How to handle the ReasonCode? +handle_in(?PUBREL_PACKET(PacketId, _ReasonCode), Channel = #channel{session = Session}) -> + case emqx_session:pubrel(PacketId, Session) of + {ok, NSession} -> + handle_out(pubcomp, {PacketId, ?RC_SUCCESS}, Channel#channel{session = NSession}); + {error, ReasonCode} -> + handle_out(pubcomp, {PacketId, ReasonCode}, Channel) + end; + +handle_in(?PUBCOMP_PACKET(PacketId, _ReasonCode), Channel = #channel{session = Session}) -> + case emqx_session:pubcomp(PacketId, Session) of + {ok, Publishes, NSession} -> + handle_out(publish, Publishes, Channel#channel{session = NSession}); + {ok, NSession} -> + {ok, Channel#channel{session = NSession}}; + {error, _NotFound} -> + %% TODO: how to handle NotFound? + {ok, Channel} + end; + +handle_in(Packet = ?SUBSCRIBE_PACKET(PacketId, Properties, TopicFilters), + Channel = #channel{client = Client}) -> + case validate_in(Packet, Channel) of + ok -> + TopicFilters1 = [emqx_topic:parse(TopicFilter, SubOpts) + || {TopicFilter, SubOpts} <- TopicFilters], + TopicFilters2 = emqx_hooks:run_fold('client.subscribe', + [Client, Properties], + TopicFilters1), + TopicFilters3 = enrich_subid(Properties, TopicFilters2), + {ReasonCodes, NChannel} = process_subscribe(TopicFilters3, Channel), + handle_out(suback, {PacketId, ReasonCodes}, NChannel); + {error, ReasonCode} -> + handle_out(disconnect, ReasonCode, Channel) + end; + +handle_in(Packet = ?UNSUBSCRIBE_PACKET(PacketId, Properties, TopicFilters), + Channel = #channel{client = Client}) -> + case validate_in(Packet, Channel) of + ok -> + TopicFilters1 = lists:map(fun emqx_topic:parse/1, TopicFilters), + TopicFilters2 = emqx_hooks:run_fold('client.unsubscribe', + [Client, Properties], + TopicFilters1), + {ReasonCodes, NChannel} = process_unsubscribe(TopicFilters2, Channel), + handle_out(unsuback, {PacketId, ReasonCodes}, NChannel); + {error, ReasonCode} -> + handle_out(disconnect, ReasonCode, Channel) + end; + +handle_in(?PACKET(?PINGREQ), Channel) -> + {ok, ?PACKET(?PINGRESP), Channel}; + +handle_in(?DISCONNECT_PACKET(?RC_SUCCESS), Channel) -> + %% Clear will msg + {stop, normal, Channel#channel{will_msg = undefined}}; + +handle_in(?DISCONNECT_PACKET(RC), Channel = #channel{proto_ver = Ver}) -> + {stop, {shutdown, emqx_reason_codes:name(RC, Ver)}, Channel}; + +handle_in(?AUTH_PACKET(), Channel) -> + %%TODO: implement later. + {ok, Channel}; + +handle_in(Packet, Channel) -> + ?LOG(error, "Unexpected incoming: ~p", [Packet]), + {stop, {shutdown, unexpected_incoming_packet}, Channel}. %%-------------------------------------------------------------------- -%% Handle outgoing packets - -handle_outgoing(Packets, SuccFun, State = #state{serialize = Serialize}) - when is_list(Packets) -> - send(lists:map(Serialize, Packets), SuccFun, State); - -handle_outgoing(Packet, SuccFun, State = #state{serialize = Serialize}) -> - send(Serialize(Packet), SuccFun, State). - +%% Process Connect %%-------------------------------------------------------------------- -%% Serialize fun -serialize_fun(ProtoVer) -> - fun(Packet = ?PACKET(Type)) -> - ?LOG(debug, "SEND ~s", [emqx_packet:format(Packet)]), - _ = inc_outgoing_stats(Type), - emqx_frame:serialize(Packet, ProtoVer) - end. - -%%-------------------------------------------------------------------- -%% Send data - -send(IoData, SuccFun, State = #state{transport = Transport, - socket = Socket}) -> - Oct = iolist_size(IoData), - ok = emqx_metrics:inc('bytes.sent', Oct), - case Transport:async_send(Socket, IoData) of - ok -> SuccFun(maybe_gc(1, Oct, State)); +process_connect(ConnPkt, Channel) -> + case open_session(ConnPkt, Channel) of + {ok, Session, SP} -> + WillMsg = emqx_packet:will_msg(ConnPkt), + NChannel = Channel#channel{session = Session, + will_msg = WillMsg, + connected = true, + connected_at = os:timestamp() + }, + handle_out(connack, {?RC_SUCCESS, sp(SP)}, NChannel); {error, Reason} -> - shutdown(Reason, State) + %% TODO: Unknown error? + ?LOG(error, "Failed to open session: ~p", [Reason]), + handle_out(connack, ?RC_UNSPECIFIED_ERROR, Channel) end. %%-------------------------------------------------------------------- -%% Ensure keepalive - -ensure_keepalive(0, _State) -> - ignore; -ensure_keepalive(Interval, #state{transport = Transport, - socket = Socket, - proto_state = ProtoState}) -> - StatFun = fun() -> - case Transport:getstat(Socket, [recv_oct]) of - {ok, [{recv_oct, RecvOct}]} -> - {ok, RecvOct}; - Error -> Error - end - end, - Backoff = emqx_zone:get_env(emqx_protocol:info(zone, ProtoState), - keepalive_backoff, 0.75), - emqx_keepalive:start(StatFun, round(Interval * Backoff), {keepalive, check}). - +%% Process Publish %%-------------------------------------------------------------------- -%% Ensure rate limit -ensure_rate_limit(State = #state{rate_limit = Rl, pub_limit = Pl}) -> - Limiters = [{Pl, #state.pub_limit, emqx_pd:reset_counter(incoming_pubs)}, - {Rl, #state.rate_limit, emqx_pd:reset_counter(incoming_bytes)}], - ensure_rate_limit(Limiters, State). +%% Process Publish +process_publish(Packet = ?PUBLISH_PACKET(_QoS, _Topic, PacketId), + Channel = #channel{client = Client}) -> + Msg = emqx_packet:to_message(Client, Packet), + %%TODO: Improve later. + Msg1 = emqx_message:set_flag(dup, false, Msg), + process_publish(PacketId, mount(Client, Msg1), Channel). -ensure_rate_limit([], State) -> - State; -ensure_rate_limit([{undefined, _Pos, _Cnt}|Limiters], State) -> - ensure_rate_limit(Limiters, State); -ensure_rate_limit([{Rl, Pos, Cnt}|Limiters], State) -> - case esockd_rate_limit:check(Cnt, Rl) of - {0, Rl1} -> - ensure_rate_limit(Limiters, setelement(Pos, State, Rl1)); - {Pause, Rl1} -> - ?LOG(debug, "Rate limit pause connection ~pms", [Pause]), - TRef = erlang:send_after(Pause, self(), activate_socket), - setelement(Pos, State#state{conn_state = blocked, - limit_timer = TRef}, Rl1) +process_publish(_PacketId, Msg = #message{qos = ?QOS_0}, Channel) -> + _ = emqx_broker:publish(Msg), + {ok, Channel}; + +process_publish(PacketId, Msg = #message{qos = ?QOS_1}, Channel) -> + Deliveries = emqx_broker:publish(Msg), + ReasonCode = emqx_reason_codes:puback(Deliveries), + handle_out(puback, {PacketId, ReasonCode}, Channel); + +process_publish(PacketId, Msg = #message{qos = ?QOS_2}, + Channel = #channel{session = Session}) -> + case emqx_session:publish(PacketId, Msg, Session) of + {ok, Deliveries, NSession} -> + ReasonCode = emqx_reason_codes:puback(Deliveries), + handle_out(pubrec, {PacketId, ReasonCode}, + Channel#channel{session = NSession}); + {error, ReasonCode} -> + handle_out(pubrec, {PacketId, ReasonCode}, Channel) end. %%-------------------------------------------------------------------- -%% Activate Socket +%% Process Subscribe +%%-------------------------------------------------------------------- -activate_socket(#state{conn_state = blocked}) -> +process_subscribe(TopicFilters, Channel) -> + process_subscribe(TopicFilters, [], Channel). + +process_subscribe([], Acc, Channel) -> + {lists:reverse(Acc), Channel}; + +process_subscribe([{TopicFilter, SubOpts}|More], Acc, Channel) -> + {RC, NChannel} = do_subscribe(TopicFilter, SubOpts, Channel), + process_subscribe(More, [RC|Acc], NChannel). + +do_subscribe(TopicFilter, SubOpts = #{qos := QoS}, + Channel = #channel{client = Client, session = Session}) -> + case check_subscribe(TopicFilter, SubOpts, Channel) of + ok -> TopicFilter1 = mount(Client, TopicFilter), + SubOpts1 = enrich_subopts(maps:merge(?DEFAULT_SUBOPTS, SubOpts), Channel), + case emqx_session:subscribe(Client, TopicFilter1, SubOpts1, Session) of + {ok, NSession} -> + {QoS, Channel#channel{session = NSession}}; + {error, RC} -> {RC, Channel} + end; + {error, RC} -> {RC, Channel} + end. + +%%-------------------------------------------------------------------- +%% Process Unsubscribe +%%-------------------------------------------------------------------- + +process_unsubscribe(TopicFilters, Channel) -> + process_unsubscribe(TopicFilters, [], Channel). + +process_unsubscribe([], Acc, Channel) -> + {lists:reverse(Acc), Channel}; + +process_unsubscribe([{TopicFilter, SubOpts}|More], Acc, Channel) -> + {RC, Channel1} = do_unsubscribe(TopicFilter, SubOpts, Channel), + process_unsubscribe(More, [RC|Acc], Channel1). + +do_unsubscribe(TopicFilter, _SubOpts, Channel = #channel{client = Client, + session = Session}) -> + case emqx_session:unsubscribe(Client, mount(Client, TopicFilter), Session) of + {ok, NSession} -> + {?RC_SUCCESS, Channel#channel{session = NSession}}; + {error, RC} -> {RC, Channel} + end. + +%%-------------------------------------------------------------------- +%% Handle outgoing packet +%%-------------------------------------------------------------------- + +handle_out(Deliver = {deliver, _Topic, _Msg}, + Channel = #channel{resuming = true, pendings = Pendings}) -> + Delivers = emqx_misc:drain_deliver([Deliver]), + {ok, Channel#channel{pendings = lists:append(Pendings, Delivers)}}; + +handle_out(Deliver = {deliver, _Topic, _Msg}, Channel = #channel{session = Session}) -> + Delivers = emqx_misc:drain_deliver([Deliver]), + case emqx_session:deliver(Delivers, Session) of + {ok, Publishes, NSession} -> + handle_out(publish, Publishes, Channel#channel{session = NSession}); + {ok, NSession} -> + {ok, Channel#channel{session = NSession}} + end; + +handle_out({publish, PacketId, Msg}, Channel = #channel{client = Client}) -> + Msg1 = emqx_hooks:run_fold('message.deliver', [Client], + emqx_message:update_expiry(Msg)), + Packet = emqx_packet:from_message(PacketId, unmount(Client, Msg1)), + {ok, Packet, Channel}. + +handle_out(connack, {?RC_SUCCESS, SP}, + Channel = #channel{client = Client = #{zone := Zone}, + ack_props = AckProps, + alias_maximum = AliasMaximum}) -> + ok = emqx_hooks:run('client.connected', [Client, ?RC_SUCCESS, attrs(Channel)]), + #{max_packet_size := MaxPktSize, + max_qos_allowed := MaxQoS, + retain_available := Retain, + max_topic_alias := MaxAlias, + shared_subscription := Shared, + wildcard_subscription := Wildcard + } = caps(Channel), + %% Response-Information is so far not set by broker. + %% i.e. It's a Client-to-Client contract for the request-response topic naming scheme. + %% According to MQTT 5.0 spec: + %% A common use of this is to pass a globally unique portion of the topic tree which + %% is reserved for this Client for at least the lifetime of its Session. + %% This often cannot just be a random name as both the requesting Client and the + %% responding Client need to be authorized to use it. + %% If we are to support it in the feature, the implementation should be flexible + %% to allow prefixing the response topic based on different ACL config. + %% e.g. prefix by username or client-id, so that unauthorized clients can not + %% subscribe requests or responses that are not intended for them. + AckProps1 = if AckProps == undefined -> #{}; true -> AckProps end, + AckProps2 = AckProps1#{'Retain-Available' => flag(Retain), + 'Maximum-Packet-Size' => MaxPktSize, + 'Topic-Alias-Maximum' => MaxAlias, + 'Wildcard-Subscription-Available' => flag(Wildcard), + 'Subscription-Identifier-Available' => 1, + %'Response-Information' => + 'Shared-Subscription-Available' => flag(Shared), + 'Maximum-QoS' => MaxQoS + }, + AckProps3 = case emqx_zone:get_env(Zone, server_keepalive) of + undefined -> AckProps2; + Keepalive -> AckProps2#{'Server-Keep-Alive' => Keepalive} + end, + AliasMaximum1 = set_property(inbound, MaxAlias, AliasMaximum), + Channel1 = Channel#channel{alias_maximum = AliasMaximum1, + ack_props = undefined + }, + {ok, ?CONNACK_PACKET(?RC_SUCCESS, SP, AckProps3), Channel1}; + +handle_out(connack, ReasonCode, Channel = #channel{client = Client, + proto_ver = ProtoVer}) -> + ok = emqx_hooks:run('client.connected', [Client, ReasonCode, attrs(Channel)]), + ReasonCode1 = if + ProtoVer == ?MQTT_PROTO_V5 -> ReasonCode; + true -> emqx_reason_codes:compat(connack, ReasonCode) + end, + Reason = emqx_reason_codes:name(ReasonCode1, ProtoVer), + {stop, {shutdown, Reason}, ?CONNACK_PACKET(ReasonCode1), Channel}; + +handle_out(publish, Publishes, Channel) -> + Packets = [element(2, handle_out(Publish, Channel)) || Publish <- Publishes], + {ok, Packets, Channel}; + +%% TODO: How to handle the puberr? +handle_out(puberr, _ReasonCode, Channel) -> + {ok, Channel}; + +handle_out(puback, {PacketId, ReasonCode}, Channel) -> + {ok, ?PUBACK_PACKET(PacketId, ReasonCode), Channel}; + +handle_out(pubrel, {PacketId, ReasonCode}, Channel) -> + {ok, ?PUBREL_PACKET(PacketId, ReasonCode), Channel}; + +handle_out(pubrec, {PacketId, ReasonCode}, Channel) -> + {ok, ?PUBREC_PACKET(PacketId, ReasonCode), Channel}; + +handle_out(pubcomp, {PacketId, ReasonCode}, Channel) -> + {ok, ?PUBCOMP_PACKET(PacketId, ReasonCode), Channel}; + +handle_out(suback, {PacketId, ReasonCodes}, + Channel = #channel{proto_ver = ?MQTT_PROTO_V5}) -> + %% TODO: ACL Deny + {ok, ?SUBACK_PACKET(PacketId, ReasonCodes), Channel}; + +handle_out(suback, {PacketId, ReasonCodes}, Channel) -> + %% TODO: ACL Deny + ReasonCodes1 = [emqx_reason_codes:compat(suback, RC) || RC <- ReasonCodes], + {ok, ?SUBACK_PACKET(PacketId, ReasonCodes1), Channel}; + +handle_out(unsuback, {PacketId, ReasonCodes}, + Channel = #channel{proto_ver = ?MQTT_PROTO_V5}) -> + {ok, ?UNSUBACK_PACKET(PacketId, ReasonCodes), Channel}; + +%% Ignore reason codes if not MQTT5 +handle_out(unsuback, {PacketId, _ReasonCodes}, Channel) -> + {ok, ?UNSUBACK_PACKET(PacketId), Channel}; + +handle_out(disconnect, ReasonCode, Channel = #channel{proto_ver = ?MQTT_PROTO_V5}) -> + Reason = emqx_reason_codes:name(ReasonCode), + {stop, {shutdown, Reason}, ?DISCONNECT_PACKET(ReasonCode), Channel}; + +handle_out(disconnect, ReasonCode, Channel = #channel{proto_ver = ProtoVer}) -> + {stop, {shutdown, emqx_reason_codes:name(ReasonCode, ProtoVer)}, Channel}; + +handle_out(Type, Data, Channel) -> + ?LOG(error, "Unexpected outgoing: ~s, ~p", [Type, Data]), + {ok, Channel}. + +%%-------------------------------------------------------------------- +%% Handle call +%%-------------------------------------------------------------------- + +handle_call(Req, Channel) -> + ?LOG(error, "Unexpected call: Req", [Req]), + {ok, ignored, Channel}. + +%%-------------------------------------------------------------------- +%% Handle cast +%%-------------------------------------------------------------------- + +handle_cast(discard, Channel) -> + {stop, {shutdown, discarded}, Channel}; + +handle_cast(Msg, Channel) -> + ?LOG(error, "Unexpected cast: ~p", [Msg]), + {ok, Channel}. + +%%-------------------------------------------------------------------- +%% Handle Info +%%-------------------------------------------------------------------- + +-spec(handle_info(Info :: term(), channel()) + -> {ok, channel()} | {stop, Reason :: term(), channel()}). +handle_info({subscribe, TopicFilters}, Channel = #channel{client = Client}) -> + TopicFilters1 = emqx_hooks:run_fold('client.subscribe', + [Client, #{'Internal' => true}], + parse(subscribe, TopicFilters)), + {_ReasonCodes, NChannel} = process_subscribe(TopicFilters1, Channel), + {ok, NChannel}; + +handle_info({unsubscribe, TopicFilters}, Channel = #channel{client = Client}) -> + TopicFilters1 = emqx_hooks:run_fold('client.unsubscribe', + [Client, #{'Internal' => true}], + parse(unsubscribe, TopicFilters)), + {_ReasonCodes, NChannel} = process_unsubscribe(TopicFilters1, Channel), + {ok, NChannel}; + +handle_info(Info, Channel) -> + ?LOG(error, "Unexpected info: ~p~n", [Info]), + {ok, Channel}. + +%%-------------------------------------------------------------------- +%% Handle timeout +%%-------------------------------------------------------------------- + +-spec(timeout(reference(), Msg :: term(), channel()) + -> {ok, channel()} + | {ok, Result :: term(), channel()} + | {stop, Reason :: term(), channel()}). +timeout(TRef, retry_deliver, Channel = #channel{%%session = Session, + retry_timer = TRef}) -> + %% case emqx_session:retry(Session) of + %% TODO: ... + {ok, Channel#channel{retry_timer = undefined}}; + +timeout(TRef, emit_stats, Channel = #channel{stats_timer = TRef}) -> + ClientId = info(client_id, Channel), + %% ok = emqx_cm:set_chan_stats(ClientId, stats(Channel)), + {ok, Channel#channel{stats_timer = undefined}}; + +timeout(_TRef, Msg, Channel) -> + ?LOG(error, "Unexpected timeout: ~p~n", [Msg]), + {ok, Channel}. + +%%-------------------------------------------------------------------- +%% Ensure timers +%%-------------------------------------------------------------------- + +ensure_timer(retry, Channel = #channel{session = Session, + retry_timer = undefined}) -> + Interval = emqx_session:info(retry_interval, Session), + TRef = emqx_misc:start_timer(Interval, retry_deliver), + Channel#channel{retry_timer = TRef}; + +ensure_timer(stats, Channel = #channel{stats_timer = undefined, + idle_timeout = IdleTimeout}) -> + TRef = emqx_misc:start_timer(IdleTimeout, emit_stats), + Channel#channel{stats_timer = TRef}; + +%% disabled or timer existed +ensure_timer(_Name, Channel) -> + Channel. + +%%-------------------------------------------------------------------- +%% Terminate +%%-------------------------------------------------------------------- + +terminate(normal, #channel{client = Client}) -> + ok = emqx_hooks:run('client.disconnected', [Client, normal]); +terminate(Reason, #channel{client = Client, will_msg = WillMsg}) -> + ok = emqx_hooks:run('client.disconnected', [Client, Reason]), + publish_will_msg(WillMsg). + +%%TODO: Improve will msg:) +publish_will_msg(undefined) -> ok; -activate_socket(#state{transport = Transport, - socket = Socket, - active_n = N}) -> - Transport:setopts(Socket, [{active, N}]). +publish_will_msg(Msg) -> + emqx_broker:publish(Msg). %%-------------------------------------------------------------------- -%% Inc incoming/outgoing stats +%% GC the channel. +%%-------------------------------------------------------------------- --compile({inline, - [ inc_incoming_stats/1 - , inc_outgoing_stats/1 - ]}). +gc(_Cnt, _Oct, Channel = #channel{gc_state = undefined}) -> + Channel; +gc(Cnt, Oct, Channel = #channel{gc_state = GCSt}) -> + {Ok, GCSt1} = emqx_gc:run(Cnt, Oct, GCSt), + Ok andalso emqx_metrics:inc('channel.gc.cnt'), + Channel#channel{gc_state = GCSt1}. -inc_incoming_stats(Type) -> - emqx_pd:update_counter(recv_pkt, 1), - case Type == ?PUBLISH of +%%-------------------------------------------------------------------- +%% Validate incoming packet +%%-------------------------------------------------------------------- + +-spec(validate_in(emqx_types:packet(), channel()) + -> ok | {error, emqx_types:reason_code()}). +validate_in(Packet, _Channel) -> + try emqx_packet:validate(Packet) of + true -> ok + catch + error:protocol_error -> + {error, ?RC_PROTOCOL_ERROR}; + error:subscription_identifier_invalid -> + {error, ?RC_SUBSCRIPTION_IDENTIFIERS_NOT_SUPPORTED}; + error:topic_alias_invalid -> + {error, ?RC_TOPIC_ALIAS_INVALID}; + error:topic_filters_invalid -> + {error, ?RC_TOPIC_FILTER_INVALID}; + error:topic_name_invalid -> + {error, ?RC_TOPIC_FILTER_INVALID}; + error:_Reason -> + {error, ?RC_MALFORMED_PACKET} + end. + +%%-------------------------------------------------------------------- +%% Preprocess properties +%%-------------------------------------------------------------------- + +process_props(#mqtt_packet_connect{ + properties = #{'Topic-Alias-Maximum' := Max} + }, + Channel = #channel{alias_maximum = AliasMaximum}) -> + NAliasMaximum = if AliasMaximum == undefined -> + #{outbound => Max}; + true -> AliasMaximum#{outbound => Max} + end, + {ok, Channel#channel{alias_maximum = NAliasMaximum}}; + +process_props(Packet, Channel) -> + {ok, Packet, Channel}. + +%%-------------------------------------------------------------------- +%% Check connect packet +%%-------------------------------------------------------------------- + +check_connect(ConnPkt, Channel) -> + case pipeline([fun check_proto_ver/2, + fun check_client_id/2, + %%fun check_flapping/2, + fun check_banned/2, + fun check_will_topic/2, + fun check_will_retain/2], ConnPkt, Channel) of + ok -> {ok, Channel}; + Error -> Error + end. + +check_proto_ver(#mqtt_packet_connect{proto_ver = Ver, + proto_name = Name}, _Channel) -> + case lists:member({Ver, Name}, ?PROTOCOL_NAMES) of + true -> ok; + false -> {error, ?RC_PROTOCOL_ERROR} + end. + +%% MQTT3.1 does not allow null clientId +check_client_id(#mqtt_packet_connect{proto_ver = ?MQTT_PROTO_V3, + client_id = <<>> + }, _Channel) -> + {error, ?RC_CLIENT_IDENTIFIER_NOT_VALID}; + +%% Issue#599: Null clientId and clean_start = false +check_client_id(#mqtt_packet_connect{client_id = <<>>, + clean_start = false}, _Channel) -> + {error, ?RC_CLIENT_IDENTIFIER_NOT_VALID}; + +check_client_id(#mqtt_packet_connect{client_id = <<>>, + clean_start = true}, _Channel) -> + ok; + +check_client_id(#mqtt_packet_connect{client_id = ClientId}, + #channel{client = #{zone := Zone}}) -> + Len = byte_size(ClientId), + MaxLen = emqx_zone:get_env(Zone, max_clientid_len), + case (1 =< Len) andalso (Len =< MaxLen) of + true -> ok; + false -> {error, ?RC_CLIENT_IDENTIFIER_NOT_VALID} + end. + +%%TODO: check banned... +check_banned(#mqtt_packet_connect{client_id = ClientId, + username = Username}, + #channel{client = Client = #{zone := Zone}}) -> + case emqx_zone:get_env(Zone, enable_ban, false) of true -> - emqx_pd:update_counter(recv_msg, 1), - emqx_pd:update_counter(incoming_pubs, 1); + case emqx_banned:check(Client#{client_id => ClientId, + username => Username}) of + true -> {error, ?RC_BANNED}; + false -> ok + end; false -> ok end. -inc_outgoing_stats(Type) -> - emqx_pd:update_counter(send_pkt, 1), - (Type == ?PUBLISH) - andalso emqx_pd:update_counter(send_msg, 1). +check_will_topic(#mqtt_packet_connect{will_flag = false}, _Channel) -> + ok; +check_will_topic(#mqtt_packet_connect{will_topic = WillTopic}, _Channel) -> + try emqx_topic:validate(WillTopic) of + true -> ok + catch error:_Error -> + {error, ?RC_TOPIC_NAME_INVALID} + end. + +check_will_retain(#mqtt_packet_connect{will_retain = false}, _Channel) -> + ok; +check_will_retain(#mqtt_packet_connect{will_retain = true}, + #channel{client = #{zone := Zone}}) -> + case emqx_zone:get_env(Zone, mqtt_retain_available, true) of + true -> ok; + false -> {error, ?RC_RETAIN_NOT_SUPPORTED} + end. %%-------------------------------------------------------------------- -%% Ensure stats timer +%% Enrich client +%%-------------------------------------------------------------------- -ensure_stats_timer(State = #state{stats_timer = undefined, - idle_timeout = IdleTimeout}) -> - TRef = emqx_misc:start_timer(IdleTimeout, emit_stats), - State#state{stats_timer = TRef}; -%% disabled or timer existed -ensure_stats_timer(State) -> State. +enrich_client(#mqtt_packet_connect{client_id = ClientId, + username = Username, + is_bridge = IsBridge + }, + Channel = #channel{client = Client}) -> + Client1 = set_username(Username, Client#{client_id => ClientId, + is_bridge => IsBridge + }), + {ok, Channel#channel{client = maybe_username_as_clientid(Client1)}}. + +%% Username maybe not undefined if peer_cert_as_username +set_username(Username, Client = #{username := undefined}) -> + Client#{username => Username}; +set_username(_Username, Client) -> Client. + +maybe_username_as_clientid(Client = #{username := undefined}) -> + Client; +maybe_username_as_clientid(Client = #{zone := Zone, + username := Username}) -> + case emqx_zone:get_env(Zone, use_username_as_clientid, false) of + true -> Client#{client_id => Username}; + false -> Client + end. %%-------------------------------------------------------------------- -%% Maybe GC +%% Auth Connect +%%-------------------------------------------------------------------- -maybe_gc(_Cnt, _Oct, State = #state{gc_state = undefined}) -> - State; -maybe_gc(Cnt, Oct, State = #state{gc_state = GCSt}) -> - {Ok, GCSt1} = emqx_gc:run(Cnt, Oct, GCSt), - Ok andalso emqx_metrics:inc('channel.gc.cnt'), - State#state{gc_state = GCSt1}. +auth_connect(#mqtt_packet_connect{client_id = ClientId, + username = Username, + password = Password}, + Channel = #channel{client = Client}) -> + case authenticate(Client#{password => Password}) of + {ok, AuthResult} -> + {ok, Channel#channel{client = maps:merge(Client, AuthResult)}}; + {error, Reason} -> + ?LOG(warning, "Client ~s (Username: '~s') login failed for ~0p", + [ClientId, Username, Reason]), + {error, emqx_reason_codes:connack_error(Reason)} + end. + +%%-------------------------------------------------------------------- +%% Open session +%%-------------------------------------------------------------------- + +open_session(#mqtt_packet_connect{clean_start = CleanStart, + properties = ConnProps}, + #channel{client = Client = #{zone := Zone}}) -> + MaxInflight = get_property('Receive-Maximum', ConnProps, + emqx_zone:get_env(Zone, max_inflight, 65535)), + Interval = get_property('Session-Expiry-Interval', ConnProps, + emqx_zone:get_env(Zone, session_expiry_interval, 0)), + emqx_cm:open_session(CleanStart, Client, #{max_inflight => MaxInflight, + expiry_interval => Interval + }). + +%%-------------------------------------------------------------------- +%% Assign a random clientId +%%-------------------------------------------------------------------- + +maybe_assign_clientid(Channel = #channel{client = Client = #{client_id := <<>>}, + ack_props = AckProps}) -> + ClientId = emqx_guid:to_base62(emqx_guid:gen()), + Client1 = Client#{client_id => ClientId}, + AckProps1 = set_property('Assigned-Client-Identifier', ClientId, AckProps), + Channel#channel{client = Client1, ack_props = AckProps1}; +maybe_assign_clientid(Channel) -> Channel. + +%%-------------------------------------------------------------------- +%% Process publish message: Client -> Broker +%%-------------------------------------------------------------------- + +process_alias(Packet = #mqtt_packet{ + variable = #mqtt_packet_publish{topic_name = <<>>, + properties = #{'Topic-Alias' := AliasId} + } = Publish + }, Channel = #channel{topic_aliases = Aliases}) -> + case find_alias(AliasId, Aliases) of + {ok, Topic} -> + {ok, Packet#mqtt_packet{ + variable = Publish#mqtt_packet_publish{ + topic_name = Topic}}, Channel}; + false -> {error, ?RC_TOPIC_ALIAS_INVALID} + end; + +process_alias(#mqtt_packet{ + variable = #mqtt_packet_publish{topic_name = Topic, + properties = #{'Topic-Alias' := AliasId} + } + }, Channel = #channel{topic_aliases = Aliases}) -> + {ok, Channel#channel{topic_aliases = save_alias(AliasId, Topic, Aliases)}}; + +process_alias(_Packet, Channel) -> + {ok, Channel}. + +find_alias(_AliasId, undefined) -> + false; +find_alias(AliasId, Aliases) -> + maps:find(AliasId, Aliases). + +save_alias(AliasId, Topic, undefined) -> + #{AliasId => Topic}; +save_alias(AliasId, Topic, Aliases) -> + maps:put(AliasId, Topic, Aliases). + +%% Check Publish +check_publish(Packet, Channel) -> + pipeline([fun check_pub_acl/2, + fun check_pub_alias/2, + fun check_pub_caps/2], Packet, Channel). + +%% Check Pub ACL +check_pub_acl(#mqtt_packet{variable = #mqtt_packet_publish{topic_name = Topic}}, + #channel{client = Client}) -> + case is_acl_enabled(Client) andalso check_acl(Client, publish, Topic) of + false -> ok; + allow -> ok; + deny -> {error, ?RC_NOT_AUTHORIZED} + end. + +%% Check Pub Alias +check_pub_alias(#mqtt_packet{ + variable = #mqtt_packet_publish{ + properties = #{'Topic-Alias' := AliasId} + } + }, + #channel{alias_maximum = Limits}) -> + case (Limits == undefined) + orelse (Max = maps:get(inbound, Limits, 0)) == 0 + orelse (AliasId > Max) of + false -> ok; + true -> {error, ?RC_TOPIC_ALIAS_INVALID} + end; +check_pub_alias(_Packet, _Channel) -> ok. + +%% Check Pub Caps +check_pub_caps(#mqtt_packet{header = #mqtt_packet_header{qos = QoS, + retain = Retain + } + }, + #channel{client = #{zone := Zone}}) -> + emqx_mqtt_caps:check_pub(Zone, #{qos => QoS, retain => Retain}). + +%% Check Sub +check_subscribe(TopicFilter, SubOpts, Channel) -> + case check_sub_acl(TopicFilter, Channel) of + allow -> check_sub_caps(TopicFilter, SubOpts, Channel); + deny -> {error, ?RC_NOT_AUTHORIZED} + end. + +%% Check Sub ACL +check_sub_acl(TopicFilter, #channel{client = Client}) -> + case is_acl_enabled(Client) andalso + check_acl(Client, subscribe, TopicFilter) of + false -> allow; + Result -> Result + end. + +%% Check Sub Caps +check_sub_caps(TopicFilter, SubOpts, #channel{client = #{zone := Zone}}) -> + emqx_mqtt_caps:check_sub(Zone, TopicFilter, SubOpts). + +enrich_subid(#{'Subscription-Identifier' := SubId}, TopicFilters) -> + [{Topic, SubOpts#{subid => SubId}} || {Topic, SubOpts} <- TopicFilters]; +enrich_subid(_Properties, TopicFilters) -> + TopicFilters. + +enrich_subopts(SubOpts, #channel{proto_ver = ?MQTT_PROTO_V5}) -> + SubOpts; +enrich_subopts(SubOpts, #channel{client = #{zone := Zone, is_bridge := IsBridge}}) -> + Rap = flag(IsBridge), + Nl = flag(emqx_zone:get_env(Zone, ignore_loop_deliver, false)), + SubOpts#{rap => Rap, nl => Nl}. + +%%-------------------------------------------------------------------- +%% Is ACL enabled? +%%-------------------------------------------------------------------- + +is_acl_enabled(#{zone := Zone, is_superuser := IsSuperuser}) -> + (not IsSuperuser) andalso emqx_zone:get_env(Zone, enable_acl, true). + +%%-------------------------------------------------------------------- +%% Parse Topic Filters +%%-------------------------------------------------------------------- + +parse(subscribe, TopicFilters) -> + [emqx_topic:parse(TopicFilter, SubOpts) || {TopicFilter, SubOpts} <- TopicFilters]; + +parse(unsubscribe, TopicFilters) -> + lists:map(fun emqx_topic:parse/1, TopicFilters). + +%%-------------------------------------------------------------------- +%% Mount/Unmount +%%-------------------------------------------------------------------- + +mount(Client = #{mountpoint := MountPoint}, TopicOrMsg) -> + emqx_mountpoint:mount( + emqx_mountpoint:replvar(MountPoint, Client), TopicOrMsg). + +unmount(Client = #{mountpoint := MountPoint}, TopicOrMsg) -> + emqx_mountpoint:unmount( + emqx_mountpoint:replvar(MountPoint, Client), TopicOrMsg). + +%%-------------------------------------------------------------------- +%% Pipeline +%%-------------------------------------------------------------------- + +pipeline([], Packet, Channel) -> + {ok, Packet, Channel}; + +pipeline([Fun|More], Packet, Channel) -> + case Fun(Packet, Channel) of + ok -> pipeline(More, Packet, Channel); + {ok, NChannel} -> + pipeline(More, Packet, NChannel); + {ok, NPacket, NChannel} -> + pipeline(More, NPacket, NChannel); + {error, ReasonCode} -> + {error, ReasonCode, Channel}; + {error, ReasonCode, NChannel} -> + {error, ReasonCode, NChannel} + end. %%-------------------------------------------------------------------- %% Helper functions +%%-------------------------------------------------------------------- --compile({inline, - [ reply/3 - , keep_state/1 - , next_event/2 - , shutdown/2 - , stop/2 - ]}). +set_property(Name, Value, ?NO_PROPS) -> + #{Name => Value}; +set_property(Name, Value, Props) -> + Props#{Name => Value}. -reply(From, Reply, State) -> - {keep_state, State, [{reply, From, Reply}]}. +get_property(_Name, undefined, Default) -> + Default; +get_property(Name, Props, Default) -> + maps:get(Name, Props, Default). -keep_state(State) -> - {keep_state, State}. +sp(true) -> 1; +sp(false) -> 0. -next_event(Type, Content) -> - {next_event, Type, Content}. - -shutdown(Reason, State) -> - stop({shutdown, Reason}, State). - -stop(Reason, State) -> - {stop, Reason, State}. +flag(true) -> 1; +flag(false) -> 0. diff --git a/src/emqx_cm.erl b/src/emqx_cm.erl index a4800e839..43831eea5 100644 --- a/src/emqx_cm.erl +++ b/src/emqx_cm.erl @@ -44,7 +44,7 @@ -export([ open_session/3 , discard_session/1 - , resume_session/1 + , takeover_session/1 ]). -export([ lookup_channels/1 @@ -179,35 +179,47 @@ open_session(true, Client = #{client_id := ClientId}, Options) -> open_session(false, Client = #{client_id := ClientId}, Options) -> ResumeStart = fun(_) -> - case resume_session(ClientId) of - {ok, Session} -> - {ok, Session, true}; + case takeover_session(ClientId) of + {ok, ConnMod, ChanPid, Session} -> + {ok, NSession} = emqx_session:resume(ClientId, Session), + {ok, Pendings} = ConnMod:takeover(ChanPid, 'end'), + io:format("Pending Delivers: ~p~n", [Pendings]), + {ok, NSession, true}; {error, not_found} -> {ok, emqx_session:init(false, Client, Options), false} end end, emqx_cm_locker:trans(ClientId, ResumeStart). -%% @doc Try to resume a session. --spec(resume_session(emqx_types:client_id()) +%% @doc Try to takeover a session. +-spec(takeover_session(emqx_types:client_id()) -> {ok, emqx_session:session()} | {error, Reason :: term()}). -resume_session(ClientId) -> +takeover_session(ClientId) -> case lookup_channels(ClientId) of [] -> {error, not_found}; - [_ChanPid] -> - ok; - % emqx_channel:resume(ChanPid); + [ChanPid] -> + takeover_session(ClientId, ChanPid); ChanPids -> - [_ChanPid|StalePids] = lists:reverse(ChanPids), + [ChanPid|StalePids] = lists:reverse(ChanPids), ?LOG(error, "[SM] More than one channel found: ~p", [ChanPids]), - lists:foreach(fun(_StalePid) -> - % catch emqx_channel:discard(StalePid) - ok + lists:foreach(fun(StalePid) -> + catch discard_session(ClientId, StalePid) end, StalePids), - % emqx_channel:resume(ChanPid) - ok + takeover_session(ClientId, ChanPid) end. +takeover_session(ClientId, ChanPid) when node(ChanPid) == node() -> + case get_chan_attrs(ClientId, ChanPid) of + #{client := #{conn_mod := ConnMod}} -> + {ok, Session} = ConnMod:takeover(ChanPid, 'begin'), + {ok, ConnMod, ChanPid, Session}; + undefined -> + {error, not_found} + end; + +takeover_session(ClientId, ChanPid) -> + rpc_call(node(ChanPid), takeover_session, [ClientId, ChanPid]). + %% @doc Discard all the sessions identified by the ClientId. -spec(discard_session(emqx_types:client_id()) -> ok). discard_session(ClientId) when is_binary(ClientId) -> @@ -216,15 +228,25 @@ discard_session(ClientId) when is_binary(ClientId) -> ChanPids -> lists:foreach( fun(ChanPid) -> - try ok - % emqx_channel:discard(ChanPid) + try + discard_session(ClientId, ChanPid) catch _:Error:_Stk -> - ?LOG(warning, "[SM] Failed to discard ~p: ~p", [ChanPid, Error]) + ?LOG(error, "[SM] Failed to discard ~p: ~p", [ChanPid, Error]) end end, ChanPids) end. +discard_session(ClientId, ChanPid) when node(ChanPid) == node() -> + case get_chan_attrs(ClientId, ChanPid) of + #{conn_mod := ConnMod} -> + ConnMod:discard(ChanPid); + undefined -> ok + end; + +discard_session(ClientId, ChanPid) -> + rpc_call(node(ChanPid), discard_session, [ClientId, ChanPid]). + %% @doc Is clean start? % is_clean_start(#{clean_start := false}) -> false; % is_clean_start(_Attrs) -> true. @@ -314,8 +336,7 @@ code_change(_OldVsn, State, _Extra) -> %%-------------------------------------------------------------------- clean_down({ChanPid, ClientId}) -> - Chan = {ClientId, ChanPid}, - do_unregister_channel(Chan). + do_unregister_channel({ClientId, ChanPid}). stats_fun() -> lists:foreach(fun update_stats/1, ?CHAN_STATS). diff --git a/src/emqx_connection.erl b/src/emqx_connection.erl new file mode 100644 index 000000000..c525a1fc7 --- /dev/null +++ b/src/emqx_connection.erl @@ -0,0 +1,632 @@ +%%-------------------------------------------------------------------- +%% Copyright (c) 2019 EMQ Technologies Co., Ltd. All Rights Reserved. +%% +%% Licensed under the Apache License, Version 2.0 (the "License"); +%% you may not use this file except in compliance with the License. +%% You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, software +%% distributed under the License is distributed on an "AS IS" BASIS, +%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +%% See the License for the specific language governing permissions and +%% limitations under the License. +%%-------------------------------------------------------------------- + +%% MQTT TCP/SSL Connection +-module(emqx_connection). + +-behaviour(gen_statem). + +-include("emqx.hrl"). +-include("emqx_mqtt.hrl"). +-include("logger.hrl"). +-include("types.hrl"). + +-logger_header("[Connection]"). + +-export([start_link/3]). + +%% APIs +-export([ info/1 + , attrs/1 + , stats/1 + ]). + +%% For Debug +-export([state/1]). + +-export([ kick/1 + , discard/1 + , takeover/2 + ]). + +%% state callbacks +-export([ idle/3 + , connected/3 + , disconnected/3 + , takeovering/3 + ]). + +%% gen_statem callbacks +-export([ init/1 + , callback_mode/0 + , code_change/4 + , terminate/3 + ]). + +-record(state, { + transport :: esockd:transport(), + socket :: esockd:socket(), + peername :: emqx_types:peername(), + sockname :: emqx_types:peername(), + conn_state :: running | blocked, + active_n :: pos_integer(), + rate_limit :: maybe(esockd_rate_limit:bucket()), + pub_limit :: maybe(esockd_rate_limit:bucket()), + limit_timer :: maybe(reference()), + parse_state :: emqx_frame:parse_state(), + serialize :: fun((emqx_types:packet()) -> iodata()), + chan_state :: emqx_channel:channel(), + keepalive :: maybe(emqx_keepalive:keepalive()) + }). + +-type(state() :: #state{}). + +-define(ACTIVE_N, 100). +-define(HANDLE(T, C, D), handle((T), (C), (D))). +-define(CHAN_STATS, [recv_pkt, recv_msg, send_pkt, send_msg]). +-define(SOCK_STATS, [recv_oct, recv_cnt, send_oct, send_cnt, send_pend]). + +-spec(start_link(esockd:transport(), esockd:socket(), proplists:proplist()) + -> {ok, pid()}). +start_link(Transport, Socket, Options) -> + {ok, proc_lib:spawn_link(?MODULE, init, [{Transport, Socket, Options}])}. + +%%-------------------------------------------------------------------- +%% API +%%-------------------------------------------------------------------- + +%% @doc Get infos of the channel. +-spec(info(pid() | state()) -> emqx_types:infos()). +info(CPid) when is_pid(CPid) -> + call(CPid, info); +info(#state{transport = Transport, + socket = Socket, + peername = Peername, + sockname = Sockname, + conn_state = ConnState, + active_n = ActiveN, + rate_limit = RateLimit, + pub_limit = PubLimit, + chan_state = ChanState}) -> + ConnInfo = #{socktype => Transport:type(Socket), + peername => Peername, + sockname => Sockname, + conn_state => ConnState, + active_n => ActiveN, + rate_limit => limit_info(RateLimit), + pub_limit => limit_info(PubLimit) + }, + maps:merge(ConnInfo, emqx_channel:info(ChanState)). + +limit_info(undefined) -> + undefined; +limit_info(Limit) -> + esockd_rate_limit:info(Limit). + +%% @doc Get attrs of the channel. +-spec(attrs(pid() | state()) -> emqx_types:attrs()). +attrs(CPid) when is_pid(CPid) -> + call(CPid, attrs); +attrs(#state{transport = Transport, + socket = Socket, + peername = Peername, + sockname = Sockname, + chan_state = ChanState}) -> + ConnAttrs = #{socktype => Transport:type(Socket), + peername => Peername, + sockname => Sockname + }, + maps:merge(ConnAttrs, emqx_channel:attrs(ChanState)). + +%% @doc Get stats of the channel. +-spec(stats(pid() | state()) -> emqx_types:stats()). +stats(CPid) when is_pid(CPid) -> + call(CPid, stats); +stats(#state{transport = Transport, + socket = Socket, + chan_state = ChanState}) -> + SockStats = case Transport:getstat(Socket, ?SOCK_STATS) of + {ok, Ss} -> Ss; + {error, _} -> [] + end, + ChanStats = [{Name, emqx_pd:get_counter(Name)} || Name <- ?CHAN_STATS], + SessStats = emqx_session:stats(emqx_channel:info(session, ChanState)), + lists:append([SockStats, ChanStats, SessStats, emqx_misc:proc_stats()]). + +state(CPid) -> + call(CPid, get_state). + +-spec(kick(pid()) -> ok). +kick(CPid) -> + call(CPid, kick). + +-spec(discard(pid()) -> ok). +discard(CPid) -> + gen_statem:cast(CPid, discard). + +%% TODO: +-spec(takeover(pid(), 'begin'|'end') -> {ok, Result :: term()}). +takeover(CPid, Phase) -> + gen_statem:call(CPid, {takeover, Phase}). + +%% @private +call(CPid, Req) -> + gen_statem:call(CPid, Req, infinity). + +%%-------------------------------------------------------------------- +%% gen_statem callbacks +%%-------------------------------------------------------------------- + +init({Transport, RawSocket, Options}) -> + {ok, Socket} = Transport:wait(RawSocket), + {ok, Peername} = Transport:ensure_ok_or_exit(peername, [Socket]), + {ok, Sockname} = Transport:ensure_ok_or_exit(sockname, [Socket]), + Peercert = Transport:ensure_ok_or_exit(peercert, [Socket]), + emqx_logger:set_metadata_peername(esockd_net:format(Peername)), + Zone = proplists:get_value(zone, Options), + RateLimit = init_limiter(proplists:get_value(rate_limit, Options)), + PubLimit = init_limiter(emqx_zone:get_env(Zone, publish_limit)), + ActiveN = proplists:get_value(active_n, Options, ?ACTIVE_N), + MaxSize = emqx_zone:get_env(Zone, max_packet_size, ?MAX_PACKET_SIZE), + ParseState = emqx_frame:initial_parse_state(#{max_size => MaxSize}), + ChanState = emqx_channel:init(#{peername => Peername, + sockname => Sockname, + peercert => Peercert, + conn_mod => ?MODULE}, Options), + IdleTimout = emqx_zone:get_env(Zone, idle_timeout, 30000), + State = #state{transport = Transport, + socket = Socket, + peername = Peername, + sockname = Sockname, + conn_state = running, + active_n = ActiveN, + rate_limit = RateLimit, + pub_limit = PubLimit, + parse_state = ParseState, + chan_state = ChanState + }, + gen_statem:enter_loop(?MODULE, [{hibernate_after, 2 * IdleTimout}], + idle, State, self(), [IdleTimout]). + +init_limiter(undefined) -> + undefined; +init_limiter({Rate, Burst}) -> + esockd_rate_limit:new(Rate, Burst). + +callback_mode() -> + [state_functions, state_enter]. + +%%-------------------------------------------------------------------- +%% Idle State + +idle(enter, _, State) -> + case activate_socket(State) of + ok -> keep_state_and_data; + {error, Reason} -> + shutdown(Reason, State) + end; + +idle(timeout, _Timeout, State) -> + stop(idle_timeout, State); + +idle(cast, {incoming, Packet = ?CONNECT_PACKET( + #mqtt_packet_connect{ + proto_ver = ProtoVer} + )}, State) -> + State1 = State#state{serialize = serialize_fun(ProtoVer)}, + handle_incoming(Packet, fun(NewSt) -> + {next_state, connected, NewSt} + end, State1); + +idle(cast, {incoming, Packet}, State) -> + ?LOG(warning, "Unexpected incoming: ~p", [Packet]), + shutdown(unexpected_incoming_packet, State); + +idle(EventType, Content, State) -> + ?HANDLE(EventType, Content, State). + +%%-------------------------------------------------------------------- +%% Connected State + +connected(enter, _PrevSt, State = #state{chan_state = ChanState}) -> + ClientId = emqx_channel:info(client_id, ChanState), + ok = emqx_cm:register_channel(ClientId), + ok = emqx_cm:set_chan_attrs(ClientId, info(State)), + %% Ensure keepalive after connected successfully. + Interval = emqx_channel:info(keepalive, ChanState), + case ensure_keepalive(Interval, State) of + ignore -> keep_state(State); + {ok, KeepAlive} -> + keep_state(State#state{keepalive = KeepAlive}); + {error, Reason} -> + shutdown(Reason, State) + end; + +connected(cast, {incoming, Packet = ?PACKET(?CONNECT)}, State) -> + ?LOG(warning, "Unexpected connect: ~p", [Packet]), + shutdown(unexpected_incoming_connect, State); + +connected(cast, {incoming, Packet}, State) when is_record(Packet, mqtt_packet) -> + handle_incoming(Packet, fun keep_state/1, State); + +connected(info, Deliver = {deliver, _Topic, _Msg}, + State = #state{chan_state = ChanState}) -> + case emqx_channel:handle_out(Deliver, ChanState) of + {ok, NChanState} -> + keep_state(State#state{chan_state = NChanState}); + {ok, Packets, NChanState} -> + NState = State#state{chan_state = NChanState}, + handle_outgoing(Packets, fun keep_state/1, NState); + {stop, Reason, NChanState} -> + stop(Reason, State#state{chan_state = NChanState}) + end; + +%% Keepalive timer +connected(info, {keepalive, check}, State = #state{keepalive = KeepAlive}) -> + case emqx_keepalive:check(KeepAlive) of + {ok, KeepAlive1} -> + keep_state(State#state{keepalive = KeepAlive1}); + {error, timeout} -> + shutdown(keepalive_timeout, State); + {error, Reason} -> + shutdown(Reason, State) + end; + +connected(EventType, Content, State) -> + ?HANDLE(EventType, Content, State). + +%%-------------------------------------------------------------------- +%% Disconnected State + +disconnected(enter, _, _State) -> + %% TODO: What to do? + %% CleanStart is true + keep_state_and_data; + +disconnected(EventType, Content, State) -> + ?HANDLE(EventType, Content, State). + +%%-------------------------------------------------------------------- +%% Takeovering State + +takeovering(enter, _PreState, State) -> + {keep_state, State}; + +takeovering(EventType, Content, State) -> + ?HANDLE(EventType, Content, State). + +%% Handle call +handle({call, From}, info, State) -> + reply(From, info(State), State); + +handle({call, From}, attrs, State) -> + reply(From, attrs(State), State); + +handle({call, From}, stats, State) -> + reply(From, stats(State), State); + +handle({call, From}, get_state, State) -> + reply(From, State, State); + +handle({call, From}, kick, State) -> + ok = gen_statem:reply(From, ok), + shutdown(kicked, State); + +handle({call, From}, {takeover, 'begin'}, State = #state{chan_state = ChanState}) -> + {ok, Session, NChanState} = emqx_channel:takeover('begin', ChanState), + ok = gen_statem:reply(From, {ok, Session}), + {next_state, takeovering, State#state{chan_state = NChanState}}; + +handle({call, From}, {takeover, 'end'}, State = #state{chan_state = ChanState}) -> + {ok, Delivers, NChanState} = emqx_channel:takeover('end', ChanState), + ok = gen_statem:reply(From, {ok, Delivers}), + shutdown(takeovered, State#state{chan_state = NChanState}); + +handle({call, From}, Req, State = #state{chan_state = ChanState}) -> + case emqx_channel:handle_call(Req, ChanState) of + {ok, Reply, NChanState} -> + reply(From, Reply, State#state{chan_state = NChanState}); + {stop, Reason, Reply, NChanState} -> + ok = gen_statem:reply(From, Reply), + stop(Reason, State#state{chan_state = NChanState}) + end; + +%% Handle cast +handle(cast, Msg, State = #state{chan_state = ChanState}) -> + case emqx_channel:handle_cast(Msg, ChanState) of + {ok, NChanState} -> + keep_state(State#state{chan_state = NChanState}); + {stop, Reason, NChanState} -> + stop(Reason, State#state{chan_state = NChanState}) + end; + +%% Handle incoming data +handle(info, {Inet, _Sock, Data}, State = #state{chan_state = ChanState}) + when Inet == tcp; Inet == ssl -> + Oct = iolist_size(Data), + ?LOG(debug, "RECV ~p", [Data]), + emqx_pd:update_counter(incoming_bytes, Oct), + ok = emqx_metrics:inc('bytes.received', Oct), + NChanState = emqx_channel:gc(1, Oct, ChanState), + process_incoming(Data, State#state{chan_state = NChanState}); + +handle(info, {Error, _Sock, Reason}, State) + when Error == tcp_error; Error == ssl_error -> + shutdown(Reason, State); + +handle(info, {Closed, _Sock}, State = #state{chan_state = ChanState}) + when Closed == tcp_closed; Closed == ssl_closed -> + case emqx_channel:info(session, ChanState) of + undefined -> shutdown(closed, State); + Session -> + case emqx_session:info(clean_start, Session) of + true -> shutdown(closed, State); + false -> {next_state, disconnected, State} + end + end; + +handle(info, {Passive, _Sock}, State) when Passive == tcp_passive; + Passive == ssl_passive -> + %% Rate limit here:) + NState = ensure_rate_limit(State), + case activate_socket(NState) of + ok -> keep_state(NState); + {error, Reason} -> + shutdown(Reason, NState) + end; + +handle(info, activate_socket, State) -> + %% Rate limit timer expired. + NState = State#state{conn_state = running}, + case activate_socket(NState) of + ok -> + keep_state(NState#state{limit_timer = undefined}); + {error, Reason} -> + shutdown(Reason, NState) + end; + +handle(info, {inet_reply, _Sock, ok}, State) -> + %% something sent + keep_state(State); + +handle(info, {inet_reply, _Sock, {error, Reason}}, State) -> + shutdown(Reason, State); + +handle(info, {timeout, TRef, Msg}, State = #state{chan_state = ChanState}) + when is_reference(TRef) -> + case emqx_channel:timeout(TRef, Msg, ChanState) of + {ok, NChanState} -> + keep_state(State#state{chan_state = NChanState}); + {ok, Packets, NChanState} -> + handle_outgoing(Packets, fun keep_state/1, + State#state{chan_state = NChanState}); + {stop, Reason, NChanState} -> + stop(Reason, State#state{chan_state = NChanState}) + end; + +handle(info, {shutdown, conflict, {ClientId, NewPid}}, State) -> + ?LOG(warning, "Clientid '~s' conflict with ~p", [ClientId, NewPid]), + shutdown(conflict, State); + +handle(info, {shutdown, Reason}, State) -> + shutdown(Reason, State); + +handle(info, Info, State = #state{chan_state = ChanState}) -> + case emqx_channel:handle_info(Info, ChanState) of + {ok, NChanState} -> + keep_state(State#state{chan_state = NChanState}); + {stop, Reason, NChanState} -> + stop(Reason, State#state{chan_state = NChanState}) + end. + +code_change(_Vsn, State, Data, _Extra) -> + {ok, State, Data}. + +terminate(Reason, _StateName, #state{transport = Transport, + socket = Socket, + keepalive = KeepAlive, + chan_state = ChanState}) -> + ?LOG(debug, "Terminated for ~p", [Reason]), + ok = Transport:fast_close(Socket), + KeepAlive =/= undefined + andalso emqx_keepalive:cancel(KeepAlive), + emqx_channel:terminate(Reason, ChanState). + +%%-------------------------------------------------------------------- +%% Process incoming data + +-compile({inline, [process_incoming/2]}). +process_incoming(Data, State) -> + process_incoming(Data, [], State). + +process_incoming(<<>>, Packets, State) -> + {keep_state, State, next_incoming_events(Packets)}; + +process_incoming(Data, Packets, State = #state{parse_state = ParseState}) -> + try emqx_frame:parse(Data, ParseState) of + {ok, NParseState} -> + NState = State#state{parse_state = NParseState}, + {keep_state, NState, next_incoming_events(Packets)}; + {ok, Packet, Rest, NParseState} -> + NState = State#state{parse_state = NParseState}, + process_incoming(Rest, [Packet|Packets], NState); + {error, Reason} -> + shutdown(Reason, State) + catch + error:Reason:Stk -> + ?LOG(error, "Parse failed for ~p~n\ + Stacktrace:~p~nError data:~p", [Reason, Stk, Data]), + shutdown(parse_error, State) + end. + +next_incoming_events(Packets) when is_list(Packets) -> + [next_event(cast, {incoming, Packet}) || Packet <- Packets]. + +%%-------------------------------------------------------------------- +%% Handle incoming packet + +handle_incoming(Packet = ?PACKET(Type), SuccFun, + State = #state{chan_state = ChanState}) -> + _ = inc_incoming_stats(Type), + ok = emqx_metrics:inc_recv(Packet), + ?LOG(debug, "RECV ~s", [emqx_packet:format(Packet)]), + case emqx_channel:handle_in(Packet, ChanState) of + {ok, NChanState} -> + SuccFun(State#state{chan_state= NChanState}); + {ok, OutPackets, NChanState} -> + handle_outgoing(OutPackets, SuccFun, State#state{chan_state = NChanState}); + {stop, Reason, NChanState} -> + stop(Reason, State#state{chan_state = NChanState}); + {stop, Reason, OutPacket, NChanState} -> + Shutdown = fun(NewSt) -> shutdown(Reason, NewSt) end, + handle_outgoing(OutPacket, Shutdown, State#state{chan_state = NChanState}) + end. + +%%-------------------------------------------------------------------- +%% Handle outgoing packets + +handle_outgoing(Packets, SuccFun, State = #state{serialize = Serialize}) + when is_list(Packets) -> + send(lists:map(Serialize, Packets), SuccFun, State); + +handle_outgoing(Packet, SuccFun, State = #state{serialize = Serialize}) -> + send(Serialize(Packet), SuccFun, State). + +%%-------------------------------------------------------------------- +%% Serialize fun + +serialize_fun(ProtoVer) -> + fun(Packet = ?PACKET(Type)) -> + ?LOG(debug, "SEND ~s", [emqx_packet:format(Packet)]), + _ = inc_outgoing_stats(Type), + emqx_frame:serialize(Packet, ProtoVer) + end. + +%%-------------------------------------------------------------------- +%% Send data + +send(IoData, SuccFun, State = #state{transport = Transport, + socket = Socket}) -> + Oct = iolist_size(IoData), + ok = emqx_metrics:inc('bytes.sent', Oct), + case Transport:async_send(Socket, IoData) of + ok -> SuccFun(State); + {error, Reason} -> + shutdown(Reason, State) + end. + +%% TODO: maybe_gc(1, Oct, State) + +%%-------------------------------------------------------------------- +%% Ensure keepalive + +ensure_keepalive(0, _State) -> + ignore; +ensure_keepalive(Interval, #state{transport = Transport, + socket = Socket, + chan_state = ChanState}) -> + StatFun = fun() -> + case Transport:getstat(Socket, [recv_oct]) of + {ok, [{recv_oct, RecvOct}]} -> + {ok, RecvOct}; + Error -> Error + end + end, + Backoff = emqx_zone:get_env(emqx_channel:info(zone, ChanState), + keepalive_backoff, 0.75), + emqx_keepalive:start(StatFun, round(Interval * Backoff), {keepalive, check}). + +%%-------------------------------------------------------------------- +%% Ensure rate limit + +ensure_rate_limit(State = #state{rate_limit = Rl, pub_limit = Pl}) -> + Limiters = [{Pl, #state.pub_limit, emqx_pd:reset_counter(incoming_pubs)}, + {Rl, #state.rate_limit, emqx_pd:reset_counter(incoming_bytes)}], + ensure_rate_limit(Limiters, State). + +ensure_rate_limit([], State) -> + State; +ensure_rate_limit([{undefined, _Pos, _Cnt}|Limiters], State) -> + ensure_rate_limit(Limiters, State); +ensure_rate_limit([{Rl, Pos, Cnt}|Limiters], State) -> + case esockd_rate_limit:check(Cnt, Rl) of + {0, Rl1} -> + ensure_rate_limit(Limiters, setelement(Pos, State, Rl1)); + {Pause, Rl1} -> + ?LOG(debug, "Rate limit pause connection ~pms", [Pause]), + TRef = erlang:send_after(Pause, self(), activate_socket), + setelement(Pos, State#state{conn_state = blocked, + limit_timer = TRef}, Rl1) + end. + +%%-------------------------------------------------------------------- +%% Activate Socket + +activate_socket(#state{conn_state = blocked}) -> + ok; +activate_socket(#state{transport = Transport, + socket = Socket, + active_n = N}) -> + Transport:setopts(Socket, [{active, N}]). + +%%-------------------------------------------------------------------- +%% Inc incoming/outgoing stats + +-compile({inline, + [ inc_incoming_stats/1 + , inc_outgoing_stats/1 + ]}). + +inc_incoming_stats(Type) -> + emqx_pd:update_counter(recv_pkt, 1), + case Type == ?PUBLISH of + true -> + emqx_pd:update_counter(recv_msg, 1), + emqx_pd:update_counter(incoming_pubs, 1); + false -> ok + end. + +inc_outgoing_stats(Type) -> + emqx_pd:update_counter(send_pkt, 1), + (Type == ?PUBLISH) + andalso emqx_pd:update_counter(send_msg, 1). + +%%-------------------------------------------------------------------- +%% Helper functions + +-compile({inline, + [ reply/3 + , keep_state/1 + , next_event/2 + , shutdown/2 + , stop/2 + ]}). + +reply(From, Reply, State) -> + {keep_state, State, [{reply, From, Reply}]}. + +keep_state(State) -> + {keep_state, State}. + +next_event(Type, Content) -> + {next_event, Type, Content}. + +shutdown(Reason, State) -> + stop({shutdown, Reason}, State). + +stop(Reason, State) -> + {stop, Reason, State}. + diff --git a/src/emqx_listeners.erl b/src/emqx_listeners.erl index b39873879..802179065 100644 --- a/src/emqx_listeners.erl +++ b/src/emqx_listeners.erl @@ -79,7 +79,7 @@ start_listener(Proto, ListenOn, Options) when Proto == https; Proto == wss -> start_mqtt_listener(Name, ListenOn, Options) -> SockOpts = esockd:parse_opt(Options), esockd:open(Name, ListenOn, merge_default(SockOpts), - {emqx_channel, start_link, [Options -- SockOpts]}). + {emqx_connection, start_link, [Options -- SockOpts]}). start_http_listener(Start, Name, ListenOn, RanchOpts, ProtoOpts) -> Start(Name, with_port(ListenOn, RanchOpts), ProtoOpts). @@ -88,7 +88,7 @@ mqtt_path(Options) -> proplists:get_value(mqtt_path, Options, "/mqtt"). ws_opts(Options) -> - WsPaths = [{mqtt_path(Options), emqx_ws_channel, Options}], + WsPaths = [{mqtt_path(Options), emqx_ws_connection, Options}], Dispatch = cowboy_router:compile([{'_', WsPaths}]), ProxyProto = proplists:get_value(proxy_protocol, Options, false), #{env => #{dispatch => Dispatch}, proxy_header => ProxyProto}. diff --git a/src/emqx_misc.erl b/src/emqx_misc.erl index 42e88d850..007f444f4 100644 --- a/src/emqx_misc.erl +++ b/src/emqx_misc.erl @@ -25,10 +25,6 @@ , proc_stats/1 ]). --export([ init_proc_mng_policy/1 - , conn_proc_mng_policy/1 - ]). - -export([ drain_deliver/1 , drain_down/1 ]). @@ -82,56 +78,6 @@ proc_stats(Pid) -> [{mailbox_len, Len}|Stats] end. --define(DISABLED, 0). - -init_proc_mng_policy(undefined) -> ok; -init_proc_mng_policy(Zone) -> - #{max_heap_size := MaxHeapSizeInBytes} - = ShutdownPolicy - = emqx_zone:get_env(Zone, force_shutdown_policy), - MaxHeapSize = MaxHeapSizeInBytes div erlang:system_info(wordsize), - _ = erlang:process_flag(max_heap_size, MaxHeapSize), % zero is discarded - erlang:put(force_shutdown_policy, ShutdownPolicy), - ok. - -%% @doc Check self() process status against connection/session process management policy, -%% return `continue | hibernate | {shutdown, Reason}' accordingly. -%% `continue': There is nothing out of the ordinary. -%% `hibernate': Nothing to process in my mailbox, and since this check is triggered -%% by a timer, we assume it is a fat chance to continue idel, hence hibernate. -%% `shutdown': Some numbers (message queue length hit the limit), -%% hence shutdown for greater good (system stability). --spec(conn_proc_mng_policy(#{message_queue_len => integer()} | false) -> - continue | hibernate | {shutdown, _}). -conn_proc_mng_policy(#{message_queue_len := MaxMsgQueueLen}) -> - Qlength = proc_info(message_queue_len), - Checks = - [{fun() -> is_message_queue_too_long(Qlength, MaxMsgQueueLen) end, - {shutdown, message_queue_too_long}}, - {fun() -> Qlength > 0 end, continue}, - {fun() -> true end, hibernate} - ], - check(Checks); -conn_proc_mng_policy(_) -> - %% disable by default - conn_proc_mng_policy(#{message_queue_len => 0}). - -check([{Pred, Result} | Rest]) -> - case Pred() of - true -> Result; - false -> check(Rest) - end. - -is_message_queue_too_long(Qlength, Max) -> - is_enabled(Max) andalso Qlength > Max. - -is_enabled(Max) -> - is_integer(Max) andalso Max > ?DISABLED. - -proc_info(Key) -> - {Key, Value} = erlang:process_info(self(), Key), - Value. - %% @doc Drain delivers from the channel's mailbox. drain_deliver(Acc) -> receive diff --git a/src/emqx_mod_presence.erl b/src/emqx_mod_presence.erl index 97f7a9929..30f0ad334 100644 --- a/src/emqx_mod_presence.erl +++ b/src/emqx_mod_presence.erl @@ -37,9 +37,10 @@ %% APIs %%-------------------------------------------------------------------- -load(Env) -> - emqx_hooks:add('client.connected', {?MODULE, on_client_connected, [Env]}), - emqx_hooks:add('client.disconnected', {?MODULE, on_client_disconnected, [Env]}). +load(_Env) -> + ok. + %% emqx_hooks:add('client.connected', {?MODULE, on_client_connected, [Env]}), + %% emqx_hooks:add('client.disconnected', {?MODULE, on_client_disconnected, [Env]}). on_client_connected(#{client_id := ClientId, username := Username, diff --git a/src/emqx_oom.erl b/src/emqx_oom.erl new file mode 100644 index 000000000..5e14547f0 --- /dev/null +++ b/src/emqx_oom.erl @@ -0,0 +1,102 @@ +%%-------------------------------------------------------------------- +%% Copyright (c) 2019 EMQ Technologies Co., Ltd. All Rights Reserved. +%% +%% Licensed under the Apache License, Version 2.0 (the "License"); +%% you may not use this file except in compliance with the License. +%% You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, software +%% distributed under the License is distributed on an "AS IS" BASIS, +%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +%% See the License for the specific language governing permissions and +%% limitations under the License. +%%-------------------------------------------------------------------- + +%%-------------------------------------------------------------------- +%% @doc OOM (Out Of Memory) monitor for the channel process. +%% @end +%%-------------------------------------------------------------------- + +-module(emqx_oom). + +-include("types.hrl"). + +-export([ init/1 + , check/1 + , info/1 + ]). + +-export_type([oom_policy/0]). + +-type(opts() :: #{message_queue_len => non_neg_integer(), + max_heap_size => non_neg_integer() + }). + +-opaque(oom_policy() :: {oom_policy, opts()}). + +-type(reason() :: message_queue_too_long|proc_heap_too_large). + +-define(DISABLED, 0). + +%% @doc Init the OOM policy. +-spec(init(maybe(opts())) -> oom_policy()). +init(undefined) -> undefined; +init(#{message_queue_len := MaxQLen, + max_heap_size := MaxHeapSizeInBytes}) -> + MaxHeapSize = MaxHeapSizeInBytes div erlang:system_info(wordsize), + %% If set to zero, the limit is disabled. + _ = erlang:process_flag(max_heap_size, #{size => MaxHeapSize, + kill => false, + error_logger => true + }), + {oom_policy, #{message_queue_len => MaxQLen, + max_heap_size => MaxHeapSize + }}. + +%% @doc Check self() process status against channel process management policy, +%% return `ok | {shutdown, Reason}' accordingly. +%% `ok': There is nothing out of the ordinary. +%% `shutdown': Some numbers (message queue length hit the limit), +%% hence shutdown for greater good (system stability). +-spec(check(maybe(oom_policy())) -> ok | {shutdown, reason()}). +check(undefined) -> ok; +check({oom_policy, #{message_queue_len := MaxQLen, + max_heap_size := MaxHeapSize}}) -> + Qlength = proc_info(message_queue_len), + HeapSize = proc_info(total_heap_size), + do_check([{fun() -> is_exceeded(Qlength, MaxQLen) end, + {shutdown, message_queue_too_long}}, + {fun() -> is_exceeded(HeapSize, MaxHeapSize) end, + {shutdown, proc_heap_too_large}}]). + +do_check([]) -> + ok; +do_check([{Pred, Result} | Rest]) -> + case Pred() of + true -> Result; + false -> do_check(Rest) + end. + +-spec(info(maybe(oom_policy())) -> maybe(opts())). +info(undefined) -> undefined; +info({oom_policy, Opts}) -> + Opts. + +-compile({inline, + [ is_exceeded/2 + , is_enabled/1 + , proc_info/1 + ]}). + +is_exceeded(Val, Max) -> + is_enabled(Max) andalso Val > Max. + +is_enabled(Max) -> + is_integer(Max) andalso Max > ?DISABLED. + +proc_info(Key) -> + {Key, Value} = erlang:process_info(self(), Key), + Value. + diff --git a/src/emqx_protocol.erl b/src/emqx_protocol.erl deleted file mode 100644 index 3fa92ac70..000000000 --- a/src/emqx_protocol.erl +++ /dev/null @@ -1,924 +0,0 @@ -%%-------------------------------------------------------------------- -%% Copyright (c) 2019 EMQ Technologies Co., Ltd. All Rights Reserved. -%% -%% Licensed under the Apache License, Version 2.0 (the "License"); -%% you may not use this file except in compliance with the License. -%% You may obtain a copy of the License at -%% -%% http://www.apache.org/licenses/LICENSE-2.0 -%% -%% Unless required by applicable law or agreed to in writing, software -%% distributed under the License is distributed on an "AS IS" BASIS, -%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -%% See the License for the specific language governing permissions and -%% limitations under the License. -%%-------------------------------------------------------------------- - -%% MQTT Protocol --module(emqx_protocol). - --include("emqx.hrl"). --include("emqx_mqtt.hrl"). --include("logger.hrl"). --include("types.hrl"). - --logger_header("[Protocol]"). - --export([ info/1 - , info/2 - , attrs/1 - , caps/1 - ]). - -%% for tests --export([set/3]). - --export([ init/2 - , handle_in/2 - , handle_req/2 - , handle_deliver/2 - , handle_out/2 - , handle_timeout/3 - , terminate/2 - ]). - --import(emqx_access_control, - [ authenticate/1 - , check_acl/3 - ]). - --export_type([proto_state/0]). - --record(protocol, { - client :: emqx_types:client(), - session :: emqx_session:session(), - proto_name :: binary(), - proto_ver :: emqx_types:ver(), - keepalive :: non_neg_integer(), - will_msg :: emqx_types:message(), - topic_aliases :: maybe(map()), - alias_maximum :: maybe(map()), - ack_props :: maybe(emqx_types:properties()) %% Tmp props - }). - --opaque(proto_state() :: #protocol{}). - --define(NO_PROPS, undefined). - --spec(info(proto_state()) -> emqx_types:infos()). -info(#protocol{client = Client, - session = Session, - proto_name = ProtoName, - proto_ver = ProtoVer, - keepalive = Keepalive, - will_msg = WillMsg, - topic_aliases = Aliases}) -> - #{client => Client, - session => session_info(Session), - proto_name => ProtoName, - proto_ver => ProtoVer, - keepalive => Keepalive, - will_msg => WillMsg, - topic_aliases => Aliases - }. - --spec(info(atom(), proto_state()) -> term()). -info(client, #protocol{client = Client}) -> - Client; -info(zone, #protocol{client = #{zone := Zone}}) -> - Zone; -info(client_id, #protocol{client = #{client_id := ClientId}}) -> - ClientId; -info(session, #protocol{session = Session}) -> - Session; -info(proto_name, #protocol{proto_name = ProtoName}) -> - ProtoName; -info(proto_ver, #protocol{proto_ver = ProtoVer}) -> - ProtoVer; -info(keepalive, #protocol{keepalive = Keepalive}) -> - Keepalive; -info(will_msg, #protocol{will_msg = WillMsg}) -> - WillMsg; -info(topic_aliases, #protocol{topic_aliases = Aliases}) -> - Aliases. - -%% For tests -set(client, Client, PState) -> - PState#protocol{client = Client}; -set(session, Session, PState) -> - PState#protocol{session = Session}. - -attrs(#protocol{client = Client, - session = Session, - proto_name = ProtoName, - proto_ver = ProtoVer, - keepalive = Keepalive}) -> - #{client => Client, - session => emqx_session:attrs(Session), - proto_name => ProtoName, - proto_ver => ProtoVer, - keepalive => Keepalive - }. - -caps(#protocol{client = #{zone := Zone}}) -> - emqx_mqtt_caps:get_caps(Zone). - - --spec(init(emqx_types:conn(), proplists:proplist()) -> proto_state()). -init(ConnInfo, Options) -> - Zone = proplists:get_value(zone, Options), - Peercert = maps:get(peercert, ConnInfo, undefined), - Username = case peer_cert_as_username(Options) of - cn -> esockd_peercert:common_name(Peercert); - dn -> esockd_peercert:subject(Peercert); - crt -> Peercert; - _ -> undefined - end, - MountPoint = emqx_zone:get_env(Zone, mountpoint), - Client = maps:merge(#{zone => Zone, - username => Username, - mountpoint => MountPoint, - is_bridge => false, - is_superuser => false - }, ConnInfo), - #protocol{client = Client, - proto_name = <<"MQTT">>, - proto_ver = ?MQTT_PROTO_V4 - }. - -peer_cert_as_username(Options) -> - proplists:get_value(peer_cert_as_username, Options). - -%%-------------------------------------------------------------------- -%% Handle incoming packet -%%-------------------------------------------------------------------- - --spec(handle_in(emqx_types:packet(), proto_state()) - -> {ok, proto_state()} - | {ok, emqx_types:packet(), proto_state()} - | {ok, list(emqx_types:packet()), proto_state()} - | {error, Reason :: term(), proto_state()} - | {stop, Error :: atom(), proto_state()}). -handle_in(?CONNECT_PACKET( - #mqtt_packet_connect{proto_name = ProtoName, - proto_ver = ProtoVer, - keepalive = Keepalive, - client_id = ClientId - } = ConnPkt), PState) -> - PState1 = PState#protocol{proto_name = ProtoName, - proto_ver = ProtoVer, - keepalive = Keepalive - }, - ok = emqx_logger:set_metadata_client_id(ClientId), - case pipeline([fun validate_in/2, - fun process_props/2, - fun check_connect/2, - fun enrich_client/2, - fun auth_connect/2], ConnPkt, PState1) of - {ok, NConnPkt, NPState} -> - process_connect(NConnPkt, maybe_assign_clientid(NPState)); - {error, ReasonCode, NPState} -> - handle_out({disconnect, ReasonCode}, NPState) - end; - -handle_in(Packet = ?PUBLISH_PACKET(QoS, Topic, PacketId), PState) -> - case pipeline([fun validate_in/2, - fun process_alias/2, - fun check_publish/2], Packet, PState) of - {ok, NPacket, NPState} -> - process_publish(NPacket, NPState); - {error, ReasonCode, NPState} -> - ?LOG(warning, "Cannot publish message to ~s due to ~s", - [Topic, emqx_reason_codes:text(ReasonCode)]), - puback(QoS, PacketId, ReasonCode, NPState) - end; - -handle_in(?PUBACK_PACKET(PacketId, _ReasonCode), PState = #protocol{session = Session}) -> - case emqx_session:puback(PacketId, Session) of - {ok, Publishes, NSession} -> - handle_out({publish, Publishes}, PState#protocol{session = NSession}); - {ok, NSession} -> - {ok, PState#protocol{session = NSession}}; - {error, _NotFound} -> - {ok, PState} - end; - -handle_in(?PUBREC_PACKET(PacketId, _ReasonCode), PState = #protocol{session = Session}) -> - case emqx_session:pubrec(PacketId, Session) of - {ok, NSession} -> - handle_out({pubrel, PacketId}, PState#protocol{session = NSession}); - {error, ReasonCode1} -> - handle_out({pubrel, PacketId, ReasonCode1}, PState) - end; - -handle_in(?PUBREL_PACKET(PacketId, _ReasonCode), PState = #protocol{session = Session}) -> - case emqx_session:pubrel(PacketId, Session) of - {ok, NSession} -> - handle_out({pubcomp, PacketId}, PState#protocol{session = NSession}); - {error, ReasonCode1} -> - handle_out({pubcomp, PacketId, ReasonCode1}, PState) - end; - -handle_in(?PUBCOMP_PACKET(PacketId, _ReasonCode), PState = #protocol{session = Session}) -> - case emqx_session:pubcomp(PacketId, Session) of - {ok, Publishes, NSession} -> - handle_out({publish, Publishes}, PState#protocol{session = NSession}); - {ok, NSession} -> - {ok, PState#protocol{session = NSession}}; - {error, _NotFound} -> - {ok, PState} - end; - -handle_in(Packet = ?SUBSCRIBE_PACKET(PacketId, Properties, TopicFilters), - PState = #protocol{client = Client}) -> - case validate_in(Packet, PState) of - ok -> TopicFilters1 = [emqx_topic:parse(TopicFilter, SubOpts) - || {TopicFilter, SubOpts} <- TopicFilters], - TopicFilters2 = emqx_hooks:run_fold('client.subscribe', - [Client, Properties], - TopicFilters1), - TopicFilters3 = enrich_subid(Properties, TopicFilters2), - {ReasonCodes, NPState} = process_subscribe(TopicFilters3, PState), - handle_out({suback, PacketId, ReasonCodes}, NPState); - {error, ReasonCode} -> - handle_out({disconnect, ReasonCode}, PState) - end; - -handle_in(Packet = ?UNSUBSCRIBE_PACKET(PacketId, Properties, TopicFilters), - PState = #protocol{client = Client}) -> - case validate_in(Packet, PState) of - ok -> TopicFilters1 = lists:map(fun emqx_topic:parse/1, TopicFilters), - TopicFilters2 = emqx_hooks:run_fold('client.unsubscribe', - [Client, Properties], - TopicFilters1), - {ReasonCodes, NPState} = process_unsubscribe(TopicFilters2, PState), - handle_out({unsuback, PacketId, ReasonCodes}, NPState); - {error, ReasonCode} -> - handle_out({disconnect, ReasonCode}, PState) - end; - -handle_in(?PACKET(?PINGREQ), PState) -> - {ok, ?PACKET(?PINGRESP), PState}; - -handle_in(?DISCONNECT_PACKET(?RC_SUCCESS), PState) -> - %% Clear will msg - {stop, normal, PState#protocol{will_msg = undefined}}; - -handle_in(?DISCONNECT_PACKET(RC), PState = #protocol{proto_ver = Ver}) -> - {stop, {shutdown, emqx_reason_codes:name(RC, Ver)}, PState}; - -handle_in(?AUTH_PACKET(), PState) -> - %%TODO: implement later. - {ok, PState}; - -handle_in(Packet, PState) -> - io:format("In: ~p~n", [Packet]), - {ok, PState}. - -%%-------------------------------------------------------------------- -%% Handle internal request -%%-------------------------------------------------------------------- - --spec(handle_req(Req:: term(), proto_state()) - -> {ok, Result :: term(), proto_state()} | - {error, Reason :: term(), proto_state()}). -handle_req({subscribe, TopicFilters}, PState = #protocol{client = Client}) -> - TopicFilters1 = emqx_hooks:run_fold('client.subscribe', - [Client, #{'Internal' => true}], - parse(subscribe, TopicFilters)), - {ReasonCodes, NPState} = process_subscribe(TopicFilters1, PState), - {ok, ReasonCodes, NPState}; - -handle_req({unsubscribe, TopicFilters}, PState = #protocol{client = Client}) -> - TopicFilters1 = emqx_hooks:run_fold('client.unsubscribe', - [Client, #{'Internal' => true}], - parse(unsubscribe, TopicFilters)), - {ReasonCodes, NPState} = process_unsubscribe(TopicFilters1, PState), - {ok, ReasonCodes, NPState}; - -handle_req(Req, PState) -> - ?LOG(error, "Unexpected request: ~p~n", [Req]), - {ok, ignored, PState}. - -%%-------------------------------------------------------------------- -%% Handle delivers -%%-------------------------------------------------------------------- - -handle_deliver(Delivers, PState = #protocol{session = Session}) - when is_list(Delivers) -> - case emqx_session:deliver(Delivers, Session) of - {ok, Publishes, NSession} -> - handle_out({publish, Publishes}, PState#protocol{session = NSession}); - {ok, NSession} -> - {ok, PState#protocol{session = NSession}} - end. - -%%-------------------------------------------------------------------- -%% Handle outgoing packet -%%-------------------------------------------------------------------- - -handle_out({connack, ?RC_SUCCESS, SP}, - PState = #protocol{client = Client = #{zone := Zone}, - ack_props = AckProps, - alias_maximum = AliasMaximum}) -> - ok = emqx_hooks:run('client.connected', [Client, ?RC_SUCCESS, attrs(PState)]), - #{max_packet_size := MaxPktSize, - max_qos_allowed := MaxQoS, - retain_available := Retain, - max_topic_alias := MaxAlias, - shared_subscription := Shared, - wildcard_subscription := Wildcard - } = caps(PState), - %% Response-Information is so far not set by broker. - %% i.e. It's a Client-to-Client contract for the request-response topic naming scheme. - %% According to MQTT 5.0 spec: - %% A common use of this is to pass a globally unique portion of the topic tree which - %% is reserved for this Client for at least the lifetime of its Session. - %% This often cannot just be a random name as both the requesting Client and the - %% responding Client need to be authorized to use it. - %% If we are to support it in the feature, the implementation should be flexible - %% to allow prefixing the response topic based on different ACL config. - %% e.g. prefix by username or client-id, so that unauthorized clients can not - %% subscribe requests or responses that are not intended for them. - AckProps1 = if AckProps == undefined -> #{}; true -> AckProps end, - AckProps2 = AckProps1#{'Retain-Available' => flag(Retain), - 'Maximum-Packet-Size' => MaxPktSize, - 'Topic-Alias-Maximum' => MaxAlias, - 'Wildcard-Subscription-Available' => flag(Wildcard), - 'Subscription-Identifier-Available' => 1, - %'Response-Information' => - 'Shared-Subscription-Available' => flag(Shared), - 'Maximum-QoS' => MaxQoS - }, - AckProps3 = case emqx_zone:get_env(Zone, server_keepalive) of - undefined -> AckProps2; - Keepalive -> AckProps2#{'Server-Keep-Alive' => Keepalive} - end, - AliasMaximum1 = set_property(inbound, MaxAlias, AliasMaximum), - PState1 = PState#protocol{alias_maximum = AliasMaximum1, - ack_props = undefined - }, - {ok, ?CONNACK_PACKET(?RC_SUCCESS, SP, AckProps3), PState1}; - -handle_out({connack, ReasonCode}, PState = #protocol{client = Client, - proto_ver = ProtoVer}) -> - ok = emqx_hooks:run('client.connected', [Client, ReasonCode, attrs(PState)]), - ReasonCode1 = if - ProtoVer == ?MQTT_PROTO_V5 -> ReasonCode; - true -> emqx_reason_codes:compat(connack, ReasonCode) - end, - Reason = emqx_reason_codes:name(ReasonCode1, ProtoVer), - {error, Reason, ?CONNACK_PACKET(ReasonCode1), PState}; - -handle_out({publish, Publishes}, PState) -> - Packets = [element(2, handle_out(Publish, PState)) || Publish <- Publishes], - {ok, Packets, PState}; - -handle_out({publish, PacketId, Msg}, PState = #protocol{client = Client}) -> - Msg1 = emqx_hooks:run_fold('message.deliver', [Client], - emqx_message:update_expiry(Msg)), - Packet = emqx_packet:from_message(PacketId, unmount(Client, Msg1)), - {ok, Packet, PState}; - -%% TODO: How to handle the err? -handle_out({puberr, _ReasonCode}, PState) -> - {ok, PState}; - -handle_out({puback, PacketId, ReasonCode}, PState) -> - {ok, ?PUBACK_PACKET(PacketId, ReasonCode), PState}; - -handle_out({pubrel, PacketId}, PState) -> - {ok, ?PUBREL_PACKET(PacketId), PState}; -handle_out({pubrel, PacketId, ReasonCode}, PState) -> - {ok, ?PUBREL_PACKET(PacketId, ReasonCode), PState}; - -handle_out({pubrec, PacketId, ReasonCode}, PState) -> - {ok, ?PUBREC_PACKET(PacketId, ReasonCode), PState}; - -handle_out({pubcomp, PacketId}, PState) -> - {ok, ?PUBCOMP_PACKET(PacketId), PState}; -handle_out({pubcomp, PacketId, ReasonCode}, PState) -> - {ok, ?PUBCOMP_PACKET(PacketId, ReasonCode), PState}; - -handle_out({suback, PacketId, ReasonCodes}, PState = #protocol{proto_ver = ?MQTT_PROTO_V5}) -> - %% TODO: ACL Deny - {ok, ?SUBACK_PACKET(PacketId, ReasonCodes), PState}; -handle_out({suback, PacketId, ReasonCodes}, PState) -> - %% TODO: ACL Deny - ReasonCodes1 = [emqx_reason_codes:compat(suback, RC) || RC <- ReasonCodes], - {ok, ?SUBACK_PACKET(PacketId, ReasonCodes1), PState}; - -handle_out({unsuback, PacketId, ReasonCodes}, PState = #protocol{proto_ver = ?MQTT_PROTO_V5}) -> - {ok, ?UNSUBACK_PACKET(PacketId, ReasonCodes), PState}; -%% Ignore reason codes if not MQTT5 -handle_out({unsuback, PacketId, _ReasonCodes}, PState) -> - {ok, ?UNSUBACK_PACKET(PacketId), PState}; - -handle_out({disconnect, ReasonCode}, PState = #protocol{proto_ver = ?MQTT_PROTO_V5}) -> - Reason = emqx_reason_codes:name(ReasonCode), - {error, Reason, ?DISCONNECT_PACKET(ReasonCode), PState}; - -handle_out({disconnect, ReasonCode}, PState = #protocol{proto_ver = ProtoVer}) -> - {error, emqx_reason_codes:name(ReasonCode, ProtoVer), PState}; - -handle_out(Packet, PState) -> - ?LOG(error, "Unexpected out:~p", [Packet]), - {ok, PState}. - -%%-------------------------------------------------------------------- -%% Handle timeout -%%-------------------------------------------------------------------- - -handle_timeout(TRef, Msg, PState = #protocol{session = Session}) -> - case emqx_session:timeout(TRef, Msg, Session) of - {ok, NSession} -> - {ok, PState#protocol{session = NSession}}; - {ok, Publishes, NSession} -> - handle_out({publish, Publishes}, PState#protocol{session = NSession}) - end. - -terminate(normal, #protocol{client = Client}) -> - ok = emqx_hooks:run('client.disconnected', [Client, normal]); -terminate(Reason, #protocol{client = Client, will_msg = WillMsg}) -> - ok = emqx_hooks:run('client.disconnected', [Client, Reason]), - publish_will_msg(WillMsg). - -publish_will_msg(undefined) -> - ok; -publish_will_msg(Msg) -> - emqx_broker:publish(Msg). - -%%-------------------------------------------------------------------- -%% Validate incoming packet -%%-------------------------------------------------------------------- - --spec(validate_in(emqx_types:packet(), proto_state()) - -> ok | {error, emqx_types:reason_code()}). -validate_in(Packet, _PState) -> - try emqx_packet:validate(Packet) of - true -> ok - catch - error:protocol_error -> - {error, ?RC_PROTOCOL_ERROR}; - error:subscription_identifier_invalid -> - {error, ?RC_SUBSCRIPTION_IDENTIFIERS_NOT_SUPPORTED}; - error:topic_alias_invalid -> - {error, ?RC_TOPIC_ALIAS_INVALID}; - error:topic_filters_invalid -> - {error, ?RC_TOPIC_FILTER_INVALID}; - error:topic_name_invalid -> - {error, ?RC_TOPIC_FILTER_INVALID}; - error:_Reason -> - {error, ?RC_MALFORMED_PACKET} - end. - -%%-------------------------------------------------------------------- -%% Preprocess properties -%%-------------------------------------------------------------------- - -process_props(#mqtt_packet_connect{ - properties = #{'Topic-Alias-Maximum' := Max} - }, - PState = #protocol{alias_maximum = AliasMaximum}) -> - NAliasMaximum = if AliasMaximum == undefined -> - #{outbound => Max}; - true -> AliasMaximum#{outbound => Max} - end, - {ok, PState#protocol{alias_maximum = NAliasMaximum}}; - -process_props(Packet, PState) -> - {ok, Packet, PState}. - -%%-------------------------------------------------------------------- -%% Check Connect Packet -%%-------------------------------------------------------------------- - -check_connect(ConnPkt, PState) -> - case pipeline([fun check_proto_ver/2, - fun check_client_id/2, - %%fun check_flapping/2, - fun check_banned/2, - fun check_will_topic/2, - fun check_will_retain/2], ConnPkt, PState) of - ok -> {ok, PState}; - Error -> Error - end. - -check_proto_ver(#mqtt_packet_connect{proto_ver = Ver, - proto_name = Name}, _PState) -> - case lists:member({Ver, Name}, ?PROTOCOL_NAMES) of - true -> ok; - false -> {error, ?RC_PROTOCOL_ERROR} - end. - -%% MQTT3.1 does not allow null clientId -check_client_id(#mqtt_packet_connect{proto_ver = ?MQTT_PROTO_V3, - client_id = <<>> - }, _PState) -> - {error, ?RC_CLIENT_IDENTIFIER_NOT_VALID}; - -%% Issue#599: Null clientId and clean_start = false -check_client_id(#mqtt_packet_connect{client_id = <<>>, - clean_start = false}, _PState) -> - {error, ?RC_CLIENT_IDENTIFIER_NOT_VALID}; - -check_client_id(#mqtt_packet_connect{client_id = <<>>, - clean_start = true}, _PState) -> - ok; - -check_client_id(#mqtt_packet_connect{client_id = ClientId}, - #protocol{client = #{zone := Zone}}) -> - Len = byte_size(ClientId), - MaxLen = emqx_zone:get_env(Zone, max_clientid_len), - case (1 =< Len) andalso (Len =< MaxLen) of - true -> ok; - false -> {error, ?RC_CLIENT_IDENTIFIER_NOT_VALID} - end. - -%%TODO: check banned... -check_banned(#mqtt_packet_connect{client_id = ClientId, - username = Username}, - #protocol{client = Client = #{zone := Zone}}) -> - case emqx_zone:get_env(Zone, enable_ban, false) of - true -> - case emqx_banned:check(Client#{client_id => ClientId, - username => Username}) of - true -> {error, ?RC_BANNED}; - false -> ok - end; - false -> ok - end. - -check_will_topic(#mqtt_packet_connect{will_flag = false}, _PState) -> - ok; -check_will_topic(#mqtt_packet_connect{will_topic = WillTopic}, _PState) -> - try emqx_topic:validate(WillTopic) of - true -> ok - catch error:_Error -> - {error, ?RC_TOPIC_NAME_INVALID} - end. - -check_will_retain(#mqtt_packet_connect{will_retain = false}, _PState) -> - ok; -check_will_retain(#mqtt_packet_connect{will_retain = true}, - #protocol{client = #{zone := Zone}}) -> - case emqx_zone:get_env(Zone, mqtt_retain_available, true) of - true -> ok; - false -> {error, ?RC_RETAIN_NOT_SUPPORTED} - end. - -%%-------------------------------------------------------------------- -%% Enrich client -%%-------------------------------------------------------------------- - -enrich_client(#mqtt_packet_connect{client_id = ClientId, - username = Username, - is_bridge = IsBridge - }, - PState = #protocol{client = Client}) -> - Client1 = set_username(Username, Client#{client_id => ClientId, - is_bridge => IsBridge - }), - {ok, PState#protocol{client = maybe_username_as_clientid(Client1)}}. - -%% Username maybe not undefined if peer_cert_as_username -set_username(Username, Client = #{username := undefined}) -> - Client#{username => Username}; -set_username(_Username, Client) -> Client. - -maybe_username_as_clientid(Client = #{username := undefined}) -> - Client; -maybe_username_as_clientid(Client = #{zone := Zone, - username := Username}) -> - case emqx_zone:get_env(Zone, use_username_as_clientid, false) of - true -> Client#{client_id => Username}; - false -> Client - end. - -%%-------------------------------------------------------------------- -%% Auth Connect -%%-------------------------------------------------------------------- - -auth_connect(#mqtt_packet_connect{client_id = ClientId, - username = Username, - password = Password}, - PState = #protocol{client = Client}) -> - case authenticate(Client#{password => Password}) of - {ok, AuthResult} -> - {ok, PState#protocol{client = maps:merge(Client, AuthResult)}}; - {error, Reason} -> - ?LOG(warning, "Client ~s (Username: '~s') login failed for ~0p", - [ClientId, Username, Reason]), - {error, emqx_reason_codes:connack_error(Reason)} - end. - -%%-------------------------------------------------------------------- -%% Assign a random clientId -%%-------------------------------------------------------------------- - -maybe_assign_clientid(PState = #protocol{client = Client = #{client_id := <<>>}, - ack_props = AckProps}) -> - ClientId = emqx_guid:to_base62(emqx_guid:gen()), - Client1 = Client#{client_id => ClientId}, - AckProps1 = set_property('Assigned-Client-Identifier', ClientId, AckProps), - PState#protocol{client = Client1, ack_props = AckProps1}; -maybe_assign_clientid(PState) -> PState. - -%%-------------------------------------------------------------------- -%% Process Connect -%%-------------------------------------------------------------------- - -process_connect(ConnPkt, PState) -> - case open_session(ConnPkt, PState) of - {ok, Session, SP} -> - WillMsg = emqx_packet:will_msg(ConnPkt), - NPState = PState#protocol{session = Session, - will_msg = WillMsg - }, - handle_out({connack, ?RC_SUCCESS, sp(SP)}, NPState); - {error, Reason} -> - %% TODO: Unknown error? - ?LOG(error, "Failed to open session: ~p", [Reason]), - handle_out({connack, ?RC_UNSPECIFIED_ERROR}, PState) - end. - -%%-------------------------------------------------------------------- -%% Open session -%%-------------------------------------------------------------------- - -open_session(#mqtt_packet_connect{clean_start = CleanStart, - properties = ConnProps}, - #protocol{client = Client = #{zone := Zone}}) -> - MaxInflight = get_property('Receive-Maximum', ConnProps, - emqx_zone:get_env(Zone, max_inflight, 65535)), - Interval = get_property('Session-Expiry-Interval', ConnProps, - emqx_zone:get_env(Zone, session_expiry_interval, 0)), - emqx_cm:open_session(CleanStart, Client, #{max_inflight => MaxInflight, - expiry_interval => Interval - }). - -%%-------------------------------------------------------------------- -%% Process publish message: Client -> Broker -%%-------------------------------------------------------------------- - -process_alias(Packet = #mqtt_packet{ - variable = #mqtt_packet_publish{topic_name = <<>>, - properties = #{'Topic-Alias' := AliasId} - } = Publish - }, PState = #protocol{topic_aliases = Aliases}) -> - case find_alias(AliasId, Aliases) of - {ok, Topic} -> - {ok, Packet#mqtt_packet{ - variable = Publish#mqtt_packet_publish{ - topic_name = Topic}}, PState}; - false -> {error, ?RC_TOPIC_ALIAS_INVALID} - end; - -process_alias(#mqtt_packet{ - variable = #mqtt_packet_publish{topic_name = Topic, - properties = #{'Topic-Alias' := AliasId} - } - }, PState = #protocol{topic_aliases = Aliases}) -> - {ok, PState#protocol{topic_aliases = save_alias(AliasId, Topic, Aliases)}}; - -process_alias(_Packet, PState) -> - {ok, PState}. - -find_alias(_AliasId, undefined) -> - false; -find_alias(AliasId, Aliases) -> - maps:find(AliasId, Aliases). - -save_alias(AliasId, Topic, undefined) -> - #{AliasId => Topic}; -save_alias(AliasId, Topic, Aliases) -> - maps:put(AliasId, Topic, Aliases). - -%% Check Publish -check_publish(Packet, PState) -> - pipeline([fun check_pub_acl/2, - fun check_pub_alias/2, - fun check_pub_caps/2], Packet, PState). - -%% Check Pub ACL -check_pub_acl(#mqtt_packet{variable = #mqtt_packet_publish{topic_name = Topic}}, - #protocol{client = Client}) -> - case is_acl_enabled(Client) andalso check_acl(Client, publish, Topic) of - false -> ok; - allow -> ok; - deny -> {error, ?RC_NOT_AUTHORIZED} - end. - -%% Check Pub Alias -check_pub_alias(#mqtt_packet{ - variable = #mqtt_packet_publish{ - properties = #{'Topic-Alias' := AliasId} - } - }, - #protocol{alias_maximum = Limits}) -> - case (Limits == undefined) - orelse (Max = maps:get(inbound, Limits, 0)) == 0 - orelse (AliasId > Max) of - false -> ok; - true -> {error, ?RC_TOPIC_ALIAS_INVALID} - end; -check_pub_alias(_Packet, _PState) -> ok. - -%% Check Pub Caps -check_pub_caps(#mqtt_packet{header = #mqtt_packet_header{qos = QoS, - retain = Retain - } - }, - #protocol{client = #{zone := Zone}}) -> - emqx_mqtt_caps:check_pub(Zone, #{qos => QoS, retain => Retain}). - -%% Process Publish -process_publish(Packet = ?PUBLISH_PACKET(_QoS, _Topic, PacketId), - PState = #protocol{client = Client}) -> - Msg = emqx_packet:to_message(Client, Packet), - %%TODO: Improve later. - Msg1 = emqx_message:set_flag(dup, false, Msg), - process_publish(PacketId, mount(Client, Msg1), PState). - -process_publish(_PacketId, Msg = #message{qos = ?QOS_0}, PState) -> - _ = emqx_broker:publish(Msg), - {ok, PState}; - -process_publish(PacketId, Msg = #message{qos = ?QOS_1}, PState) -> - Deliveries = emqx_broker:publish(Msg), - ReasonCode = emqx_reason_codes:puback(Deliveries), - handle_out({puback, PacketId, ReasonCode}, PState); - -process_publish(PacketId, Msg = #message{qos = ?QOS_2}, - PState = #protocol{session = Session}) -> - case emqx_session:publish(PacketId, Msg, Session) of - {ok, Deliveries, NSession} -> - ReasonCode = emqx_reason_codes:puback(Deliveries), - handle_out({pubrec, PacketId, ReasonCode}, - PState#protocol{session = NSession}); - {error, ReasonCode} -> - handle_out({pubrec, PacketId, ReasonCode}, PState) - end. - -%%-------------------------------------------------------------------- -%% Puback -%%-------------------------------------------------------------------- - -puback(?QOS_0, _PacketId, ReasonCode, PState) -> - handle_out({puberr, ReasonCode}, PState); -puback(?QOS_1, PacketId, ReasonCode, PState) -> - handle_out({puback, PacketId, ReasonCode}, PState); -puback(?QOS_2, PacketId, ReasonCode, PState) -> - handle_out({pubrec, PacketId, ReasonCode}, PState). - -%%-------------------------------------------------------------------- -%% Process subscribe request -%%-------------------------------------------------------------------- - -process_subscribe(TopicFilters, PState) -> - process_subscribe(TopicFilters, [], PState). - -process_subscribe([], Acc, PState) -> - {lists:reverse(Acc), PState}; - -process_subscribe([{TopicFilter, SubOpts}|More], Acc, PState) -> - {RC, NPState} = do_subscribe(TopicFilter, SubOpts, PState), - process_subscribe(More, [RC|Acc], NPState). - -do_subscribe(TopicFilter, SubOpts = #{qos := QoS}, - PState = #protocol{client = Client, session = Session}) -> - case check_subscribe(TopicFilter, SubOpts, PState) of - ok -> TopicFilter1 = mount(Client, TopicFilter), - SubOpts1 = enrich_subopts(maps:merge(?DEFAULT_SUBOPTS, SubOpts), PState), - case emqx_session:subscribe(Client, TopicFilter1, SubOpts1, Session) of - {ok, NSession} -> - {QoS, PState#protocol{session = NSession}}; - {error, RC} -> {RC, PState} - end; - {error, RC} -> {RC, PState} - end. - -enrich_subid(#{'Subscription-Identifier' := SubId}, TopicFilters) -> - [{Topic, SubOpts#{subid => SubId}} || {Topic, SubOpts} <- TopicFilters]; -enrich_subid(_Properties, TopicFilters) -> - TopicFilters. - -enrich_subopts(SubOpts, #protocol{proto_ver = ?MQTT_PROTO_V5}) -> - SubOpts; -enrich_subopts(SubOpts, #protocol{client = #{zone := Zone, is_bridge := IsBridge}}) -> - Rap = flag(IsBridge), - Nl = flag(emqx_zone:get_env(Zone, ignore_loop_deliver, false)), - SubOpts#{rap => Rap, nl => Nl}. - -%% Check Sub -check_subscribe(TopicFilter, SubOpts, PState) -> - case check_sub_acl(TopicFilter, PState) of - allow -> check_sub_caps(TopicFilter, SubOpts, PState); - deny -> {error, ?RC_NOT_AUTHORIZED} - end. - -%% Check Sub ACL -check_sub_acl(TopicFilter, #protocol{client = Client}) -> - case is_acl_enabled(Client) andalso - check_acl(Client, subscribe, TopicFilter) of - false -> allow; - Result -> Result - end. - -%% Check Sub Caps -check_sub_caps(TopicFilter, SubOpts, #protocol{client = #{zone := Zone}}) -> - emqx_mqtt_caps:check_sub(Zone, TopicFilter, SubOpts). - -%%-------------------------------------------------------------------- -%% Process unsubscribe request -%%-------------------------------------------------------------------- - -process_unsubscribe(TopicFilters, PState) -> - process_unsubscribe(TopicFilters, [], PState). - -process_unsubscribe([], Acc, PState) -> - {lists:reverse(Acc), PState}; - -process_unsubscribe([{TopicFilter, SubOpts}|More], Acc, PState) -> - {RC, PState1} = do_unsubscribe(TopicFilter, SubOpts, PState), - process_unsubscribe(More, [RC|Acc], PState1). - -do_unsubscribe(TopicFilter, _SubOpts, PState = #protocol{client = Client, - session = Session}) -> - case emqx_session:unsubscribe(Client, mount(Client, TopicFilter), Session) of - {ok, NSession} -> - {?RC_SUCCESS, PState#protocol{session = NSession}}; - {error, RC} -> {RC, PState} - end. - -%%-------------------------------------------------------------------- -%% Is ACL enabled? -%%-------------------------------------------------------------------- - -is_acl_enabled(#{zone := Zone, is_superuser := IsSuperuser}) -> - (not IsSuperuser) andalso emqx_zone:get_env(Zone, enable_acl, true). - -%%-------------------------------------------------------------------- -%% Parse topic filters -%%-------------------------------------------------------------------- - -parse(subscribe, TopicFilters) -> - [emqx_topic:parse(TopicFilter, SubOpts) || {TopicFilter, SubOpts} <- TopicFilters]; - -parse(unsubscribe, TopicFilters) -> - lists:map(fun emqx_topic:parse/1, TopicFilters). - -%%-------------------------------------------------------------------- -%% Mount/Unmount -%%-------------------------------------------------------------------- - -mount(Client = #{mountpoint := MountPoint}, TopicOrMsg) -> - emqx_mountpoint:mount(emqx_mountpoint:replvar(MountPoint, Client), TopicOrMsg). - -unmount(Client = #{mountpoint := MountPoint}, TopicOrMsg) -> - emqx_mountpoint:unmount(emqx_mountpoint:replvar(MountPoint, Client), TopicOrMsg). - -%%-------------------------------------------------------------------- -%% Pipeline -%%-------------------------------------------------------------------- - -pipeline([], Packet, PState) -> - {ok, Packet, PState}; - -pipeline([Fun|More], Packet, PState) -> - case Fun(Packet, PState) of - ok -> pipeline(More, Packet, PState); - {ok, NPState} -> - pipeline(More, Packet, NPState); - {ok, NPacket, NPState} -> - pipeline(More, NPacket, NPState); - {error, ReasonCode} -> - {error, ReasonCode, PState}; - {error, ReasonCode, NPState} -> - {error, ReasonCode, NPState} - end. - -%%-------------------------------------------------------------------- -%% Helper functions -%%-------------------------------------------------------------------- - -set_property(Name, Value, ?NO_PROPS) -> - #{Name => Value}; -set_property(Name, Value, Props) -> - Props#{Name => Value}. - -get_property(_Name, undefined, Default) -> - Default; -get_property(Name, Props, Default) -> - maps:get(Name, Props, Default). - -sp(true) -> 1; -sp(false) -> 0. - -flag(true) -> 1; -flag(false) -> 0. - -session_info(undefined) -> - undefined; -session_info(Session) -> - emqx_session:info(Session). diff --git a/src/emqx_session.erl b/src/emqx_session.erl index 2530f3a42..fcc404a18 100644 --- a/src/emqx_session.erl +++ b/src/emqx_session.erl @@ -58,6 +58,10 @@ , stats/1 ]). +-export([ takeover/1 + , resume/2 + ]). + -export([ subscribe/4 , unsubscribe/3 ]). @@ -278,6 +282,25 @@ stats(#session{subscriptions = Subscriptions, {awaiting_rel, maps:size(AwaitingRel)}, {max_awaiting_rel, MaxAwaitingRel}]. +-spec(takeover(session()) -> ok). +takeover(#session{subscriptions = Subs}) -> + lists:foreach(fun({TopicFilter, _SubOpts}) -> + ok = emqx_broker:unsubscribe(TopicFilter) + end, maps:to_list(Subs)). + +-spec(resume(emqx_types:client_id(), session()) -> {ok, session()}). +resume(ClientId, Session = #session{subscriptions = Subs}) -> + ?LOG(info, "Session is resumed."), + %% 1. Subscribe again + ok = lists:foreach(fun({TopicFilter, SubOpts}) -> + ok = emqx_broker:subscribe(TopicFilter, ClientId, SubOpts) + end, maps:to_list(Subs)), + %% 2. Run hooks. + ok = emqx_hooks:run('session.resumed', [#{client_id => ClientId}, attrs(Session)]), + %% TODO: 3. Redeliver: Replay delivery and Dequeue pending messages + %% noreply(ensure_stats_timer(dequeue(retry_delivery(true, State1)))); + {ok, Session}. + %%-------------------------------------------------------------------- %% Client -> Broker: SUBSCRIBE %%-------------------------------------------------------------------- @@ -683,6 +706,6 @@ next_pkt_id(Session = #session{next_pkt_id = Id}) -> %% For Test case %%--------------------------------------------------------------------- - set_pkt_id(Session, PktId) -> Session#session{next_pkt_id = PktId}. + diff --git a/src/emqx_ws_channel.erl b/src/emqx_ws_connection.erl similarity index 60% rename from src/emqx_ws_channel.erl rename to src/emqx_ws_connection.erl index 3bc067525..2597ec7fe 100644 --- a/src/emqx_ws_channel.erl +++ b/src/emqx_ws_connection.erl @@ -14,21 +14,26 @@ %% limitations under the License. %%-------------------------------------------------------------------- -%% MQTT WebSocket Channel --module(emqx_ws_channel). +%% MQTT WebSocket Connection +-module(emqx_ws_connection). -include("emqx.hrl"). -include("emqx_mqtt.hrl"). -include("logger.hrl"). -include("types.hrl"). --logger_header("[WsChannel]"). +-logger_header("[WsConn]"). -export([ info/1 , attrs/1 , stats/1 ]). +-export([ kick/1 + , discard/1 + , takeover/2 + ]). + %% WebSocket callbacks -export([ init/2 , websocket_init/1 @@ -38,20 +43,15 @@ ]). -record(state, { - peername :: emqx_types:peername(), - sockname :: emqx_types:peername(), - fsm_state :: idle | connected | disconnected, - serialize :: fun((emqx_types:packet()) -> iodata()), - parse_state :: emqx_frame:parse_state(), - proto_state :: emqx_protocol:proto_state(), - gc_state :: emqx_gc:gc_state(), - keepalive :: maybe(emqx_keepalive:keepalive()), - pendings :: list(), - stats_timer :: disabled | maybe(reference()), - idle_timeout :: timeout(), - connected :: boolean(), - connected_at :: erlang:timestamp(), - reason :: term() + peername :: emqx_types:peername(), + sockname :: emqx_types:peername(), + fsm_state :: idle | connected | disconnected, + serialize :: fun((emqx_types:packet()) -> iodata()), + parse_state :: emqx_frame:parse_state(), + chan_state :: emqx_channel:channel(), + keepalive :: maybe(emqx_keepalive:keepalive()), + pendings :: list(), + reason :: term() }). -type(state() :: #state{}). @@ -68,51 +68,47 @@ info(WSPid) when is_pid(WSPid) -> call(WSPid, info); info(#state{peername = Peername, sockname = Sockname, - proto_state = ProtoState, - gc_state = GCState, - stats_timer = StatsTimer, - idle_timeout = IdleTimeout, - connected = Connected, - connected_at = ConnectedAt}) -> - ChanInfo = #{socktype => websocket, + chan_state = ChanState + }) -> + ConnInfo = #{socktype => websocket, peername => Peername, sockname => Sockname, - conn_state => running, - gc_state => emqx_gc:info(GCState), - enable_stats => enable_stats(StatsTimer), - idle_timeout => IdleTimeout, - connected => Connected, - connected_at => ConnectedAt + conn_state => running }, - maps:merge(ChanInfo, emqx_protocol:info(ProtoState)). - -enable_stats(disabled) -> false; -enable_stats(_MaybeRef) -> true. + maps:merge(ConnInfo, emqx_channel:info(ChanState)). -spec(attrs(pid() | state()) -> emqx_types:attrs()). attrs(WSPid) when is_pid(WSPid) -> call(WSPid, attrs); attrs(#state{peername = Peername, sockname = Sockname, - proto_state = ProtoState, - connected = Connected, - connected_at = ConnectedAt}) -> + chan_state = ChanState}) -> ConnAttrs = #{socktype => websocket, peername => Peername, - sockname => Sockname, - connected => Connected, - connected_at => ConnectedAt + sockname => Sockname }, - maps:merge(ConnAttrs, emqx_protocol:attrs(ProtoState)). + maps:merge(ConnAttrs, emqx_channel:attrs(ChanState)). -spec(stats(pid() | state()) -> emqx_types:stats()). stats(WSPid) when is_pid(WSPid) -> call(WSPid, stats); -stats(#state{proto_state = ProtoState}) -> +stats(#state{chan_state = ChanState}) -> ProcStats = emqx_misc:proc_stats(), - SessStats = emqx_session:stats(emqx_protocol:info(session, ProtoState)), + SessStats = emqx_session:stats(emqx_channel:info(session, ChanState)), lists:append([ProcStats, SessStats, chan_stats(), wsock_stats()]). +-spec(kick(pid()) -> ok). +kick(CPid) -> + call(CPid, kick). + +-spec(discard(pid()) -> ok). +discard(WSPid) -> + WSPid ! {cast, discard}, ok. + +-spec(takeover(pid(), 'begin'|'end') -> {ok, Result :: term()}). +takeover(CPid, Phase) -> + call(CPid, {takeover, Phase}). + %% @private call(WSPid, Req) when is_pid(WSPid) -> Mref = erlang:monitor(process, WSPid), @@ -171,31 +167,23 @@ websocket_init([Req, Opts]) -> [Error, Reason]), undefined end, - ProtoState = emqx_protocol:init(#{peername => Peername, - sockname => Sockname, - peercert => Peercert, - ws_cookie => WsCookie, - conn_mod => ?MODULE}, Opts), + ChanState = emqx_channel:init(#{peername => Peername, + sockname => Sockname, + peercert => Peercert, + ws_cookie => WsCookie, + conn_mod => ?MODULE + }, Opts), Zone = proplists:get_value(zone, Opts), MaxSize = emqx_zone:get_env(Zone, max_packet_size, ?MAX_PACKET_SIZE), ParseState = emqx_frame:initial_parse_state(#{max_size => MaxSize}), - GcPolicy = emqx_zone:get_env(Zone, force_gc_policy, false), - GcState = emqx_gc:init(GcPolicy), - EnableStats = emqx_zone:get_env(Zone, enable_stats, true), - StatsTimer = if EnableStats -> undefined; ?Otherwise-> disabled end, - IdleTimout = emqx_zone:get_env(Zone, idle_timeout, 30000), emqx_logger:set_metadata_peername(esockd_net:format(Peername)), ok = emqx_misc:init_proc_mng_policy(Zone), {ok, #state{peername = Peername, sockname = Sockname, fsm_state = idle, parse_state = ParseState, - proto_state = ProtoState, - gc_state = GcState, - pendings = [], - stats_timer = StatsTimer, - idle_timeout = IdleTimout, - connected = false + chan_state = ChanState, + pendings = [] }}. stat_fun() -> @@ -204,14 +192,15 @@ stat_fun() -> websocket_handle({binary, Data}, State) when is_list(Data) -> websocket_handle({binary, iolist_to_binary(Data)}, State); -websocket_handle({binary, Data}, State) when is_binary(Data) -> +websocket_handle({binary, Data}, State = #state{chan_state = ChanState}) + when is_binary(Data) -> ?LOG(debug, "RECV ~p", [Data]), Oct = iolist_size(Data), emqx_pd:update_counter(recv_cnt, 1), emqx_pd:update_counter(recv_oct, Oct), ok = emqx_metrics:inc('bytes.received', Oct), - NState = maybe_gc(1, Oct, State), - process_incoming(Data, ensure_stats_timer(NState)); + NChanState = emqx_channel:gc(1, Oct, ChanState), + process_incoming(Data, State#state{chan_state = NChanState}); %% Pings should be replied with pongs, cowboy does it automatically %% Pongs can be safely ignored. Clause here simply prevents crash. @@ -240,7 +229,15 @@ websocket_info({call, From, stats}, State) -> websocket_info({call, From, kick}, State) -> gen_server:reply(From, ok), - stop(kick, State); + stop(kicked, State); + +websocket_info({cast, Msg}, State = #state{chan_state = ChanState}) -> + case emqx_channel:handle_cast(Msg, ChanState) of + {ok, NChanState} -> + {ok, State#state{chan_state = NChanState}}; + {stop, Reason, NChanState} -> + stop(Reason, State#state{chan_state = NChanState}) + end; websocket_info({incoming, Packet = ?CONNECT_PACKET( #mqtt_packet_connect{ @@ -264,17 +261,14 @@ websocket_info({incoming, Packet}, State = #state{fsm_state = connected}) handle_incoming(Packet, fun reply/1, State); websocket_info(Deliver = {deliver, _Topic, _Msg}, - State = #state{proto_state = ProtoState}) -> - Delivers = emqx_misc:drain_deliver([Deliver]), - case emqx_protocol:handle_deliver(Delivers, ProtoState) of - {ok, NProtoState} -> - reply(State#state{proto_state = NProtoState}); - {ok, Packets, NProtoState} -> - reply(enqueue(Packets, State#state{proto_state = NProtoState})); - {error, Reason} -> - stop(Reason, State); - {error, Reason, NProtoState} -> - stop(Reason, State#state{proto_state = NProtoState}) + State = #state{chan_state = ChanState}) -> + case emqx_channel:handle_out(Deliver, ChanState) of + {ok, NChanState} -> + reply(State#state{chan_state = NChanState}); + {ok, Packets, NChanState} -> + reply(enqueue(Packets, State#state{chan_state = NChanState})); + {stop, Reason, NChanState} -> + stop(Reason, State#state{chan_state = NChanState}) end; websocket_info({keepalive, check}, State = #state{keepalive = KeepAlive}) -> @@ -288,45 +282,17 @@ websocket_info({keepalive, check}, State = #state{keepalive = KeepAlive}) -> stop(keepalive_error, State) end; -websocket_info({timeout, Timer, emit_stats}, - State = #state{stats_timer = Timer, - proto_state = ProtoState, - gc_state = GcState}) -> - ClientId = emqx_protocol:info(client_id, ProtoState), - ok = emqx_cm:set_chan_stats(ClientId, stats(State)), - NState = State#state{stats_timer = undefined}, - Limits = erlang:get(force_shutdown_policy), - case emqx_misc:conn_proc_mng_policy(Limits) of - continue -> - {ok, NState}; - hibernate -> - %% going to hibernate, reset gc stats - GcState1 = emqx_gc:reset(GcState), - {ok, NState#state{gc_state = GcState1}, hibernate}; - {shutdown, Reason} -> - ?LOG(error, "Shutdown exceptionally due to ~p", [Reason]), - stop(Reason, NState) +websocket_info({timeout, TRef, Msg}, State = #state{chan_state = ChanState}) + when is_reference(TRef) -> + case emqx_channel:timeout(TRef, Msg, ChanState) of + {ok, NChanState} -> + {ok, State#state{chan_state = NChanState}}; + {ok, Packets, NChanState} -> + reply(enqueue(Packets, State#state{chan_state = NChanState})); + {stop, Reason, NChanState} -> + stop(Reason, State#state{chan_state = NChanState}) end; -websocket_info({timeout, Timer, Msg}, - State = #state{proto_state = ProtoState}) -> - case emqx_protocol:handle_timeout(Timer, Msg, ProtoState) of - {ok, NProtoState} -> - {ok, State#state{proto_state = NProtoState}}; - {ok, Packets, NProtoState} -> - reply(enqueue(Packets, State#state{proto_state = NProtoState})); - {error, Reason} -> - stop(Reason, State); - {error, Reason, NProtoState} -> - stop(Reason, State#state{proto_state = NProtoState}) - end; - -websocket_info({subscribe, TopicFilters}, State) -> - handle_request({subscribe, TopicFilters}, State); - -websocket_info({unsubscribe, TopicFilters}, State) -> - handle_request({unsubscribe, TopicFilters}, State); - websocket_info({shutdown, discard, {ClientId, ByPid}}, State) -> ?LOG(warning, "Discarded by ~s:~p", [ClientId, ByPid]), stop(discard, State); @@ -335,40 +301,39 @@ websocket_info({shutdown, conflict, {ClientId, NewPid}}, State) -> ?LOG(warning, "Clientid '~s' conflict with ~p", [ClientId, NewPid]), stop(conflict, State); -%% websocket_info({binary, Data}, State) -> -%% {reply, {binary, Data}, State}; - websocket_info({shutdown, Reason}, State) -> stop(Reason, State); websocket_info({stop, Reason}, State) -> stop(Reason, State); -websocket_info(Info, State) -> - ?LOG(error, "Unexpected info: ~p", [Info]), - {ok, State}. +websocket_info(Info, State = #state{chan_state = ChanState}) -> + case emqx_channel:handle_info(Info, ChanState) of + {ok, NChanState} -> + {ok, State#state{chan_state = NChanState}}; + {stop, Reason, NChanState} -> + stop(Reason, State#state{chan_state = NChanState}) + end. -terminate(SockError, _Req, #state{keepalive = Keepalive, - proto_state = ProtoState, - reason = Reason}) -> +terminate(SockError, _Req, #state{keepalive = KeepAlive, + chan_state = ChanState, + reason = Reason}) -> ?LOG(debug, "Terminated for ~p, sockerror: ~p", [Reason, SockError]), - emqx_keepalive:cancel(Keepalive), - emqx_protocol:terminate(Reason, ProtoState). + KeepAlive =/= undefined + andalso emqx_keepalive:cancel(KeepAlive), + emqx_channel:terminate(Reason, ChanState). %%-------------------------------------------------------------------- %% Connected callback -connected(State = #state{proto_state = ProtoState}) -> - NState = State#state{fsm_state = connected, - connected = true, - connected_at = os:timestamp() - }, - ClientId = emqx_protocol:info(client_id, ProtoState), +connected(State = #state{chan_state = ChanState}) -> + NState = State#state{fsm_state = connected}, + ClientId = emqx_channel:info(client_id, ChanState), ok = emqx_cm:register_channel(ClientId), ok = emqx_cm:set_chan_attrs(ClientId, info(NState)), %% Ensure keepalive after connected successfully. - Interval = emqx_protocol:info(keepalive, ProtoState), + Interval = emqx_channel:info(keepalive, ChanState), case ensure_keepalive(Interval, NState) of ignore -> reply(NState); {ok, KeepAlive} -> @@ -382,22 +347,11 @@ connected(State = #state{proto_state = ProtoState}) -> ensure_keepalive(0, _State) -> ignore; -ensure_keepalive(Interval, #state{proto_state = ProtoState}) -> - Backoff = emqx_zone:get_env(emqx_protocol:info(zone, ProtoState), +ensure_keepalive(Interval, #state{chan_state = ChanState}) -> + Backoff = emqx_zone:get_env(emqx_channel:info(zone, ChanState), keepalive_backoff, 0.75), emqx_keepalive:start(stat_fun(), round(Interval * Backoff), {keepalive, check}). -%%-------------------------------------------------------------------- -%% Handle internal request - -handle_request(Req, State = #state{proto_state = ProtoState}) -> - case emqx_protocol:handle_req(Req, ProtoState) of - {ok, _Result, NProtoState} -> %% TODO:: how to handle the result? - {ok, State#state{proto_state = NProtoState}}; - {error, Reason, NProtoState} -> - stop(Reason, State#state{proto_state = NProtoState}) - end. - %%-------------------------------------------------------------------- %% Process incoming data @@ -424,22 +378,19 @@ process_incoming(Data, State = #state{parse_state = ParseState}) -> %%-------------------------------------------------------------------- %% Handle incoming packets -handle_incoming(Packet = ?PACKET(Type), SuccFun, - State = #state{proto_state = ProtoState}) -> +handle_incoming(Packet = ?PACKET(Type), SuccFun, State = #state{chan_state = ChanState}) -> _ = inc_incoming_stats(Type), ok = emqx_metrics:inc_recv(Packet), ?LOG(debug, "RECV ~s", [emqx_packet:format(Packet)]), - case emqx_protocol:handle_in(Packet, ProtoState) of - {ok, NProtoState} -> - SuccFun(State#state{proto_state = NProtoState}); - {ok, OutPackets, NProtoState} -> - SuccFun(enqueue(OutPackets, State#state{proto_state = NProtoState})); - {error, Reason, NProtoState} -> - stop(Reason, State#state{proto_state = NProtoState}); - {error, Reason, OutPacket, NProtoState} -> - stop(Reason, enqueue(OutPacket, State#state{proto_state = NProtoState})); - {stop, Error, NProtoState} -> - stop(Error, State#state{proto_state = NProtoState}) + case emqx_channel:handle_in(Packet, ChanState) of + {ok, NChanState} -> + SuccFun(State#state{chan_state= NChanState}); + {ok, OutPackets, NChanState} -> + SuccFun(enqueue(OutPackets, State#state{chan_state= NChanState})); + {stop, Reason, NChanState} -> + stop(Reason, State#state{chan_state= NChanState}); + {stop, Reason, OutPacket, NChanState} -> + stop(Reason, enqueue(OutPacket, State#state{chan_state= NChanState})) end. %%-------------------------------------------------------------------- @@ -495,29 +446,9 @@ enqueue(Packet, State) when is_record(Packet, mqtt_packet) -> enqueue(Packets, State = #state{pendings = Pendings}) -> State#state{pendings = lists:append(Pendings, Packets)}. -%%-------------------------------------------------------------------- -%% Ensure stats timer - -ensure_stats_timer(State = #state{stats_timer = undefined, - idle_timeout = IdleTimeout}) -> - TRef = emqx_misc:start_timer(IdleTimeout, emit_stats), - State#state{stats_timer = TRef}; -%% disabled or timer existed -ensure_stats_timer(State) -> State. - wsock_stats() -> [{Key, emqx_pd:get_counter(Key)} || Key <- ?SOCK_STATS]. chan_stats() -> [{Name, emqx_pd:get_counter(Name)} || Name <- ?CHAN_STATS]. -%%-------------------------------------------------------------------- -%% Maybe GC - -maybe_gc(_Cnt, _Oct, State = #state{gc_state = undefined}) -> - State; -maybe_gc(Cnt, Oct, State = #state{gc_state = GCSt}) -> - {Ok, GCSt1} = emqx_gc:run(Cnt, Oct, GCSt), - Ok andalso emqx_metrics:inc('channel.gc.cnt'), - State#state{gc_state = GCSt1}. - diff --git a/test/emqx_channel_SUITE.erl b/test/emqx_channel_SUITE.erl index 580b7a0fa..55b6f0323 100644 --- a/test/emqx_channel_SUITE.erl +++ b/test/emqx_channel_SUITE.erl @@ -19,6 +19,14 @@ -compile(export_all). -compile(nowarn_export_all). +-import(emqx_channel, + [ handle_in/2 + , handle_out/2 + , handle_out/3 + ]). + +-include("emqx.hrl"). +-include("emqx_mqtt.hrl"). -include_lib("eunit/include/eunit.hrl"). all() -> emqx_ct:all(?MODULE). @@ -30,28 +38,252 @@ init_per_suite(Config) -> end_per_suite(_Config) -> emqx_ct_helpers:stop_apps([]). -t_basic(_) -> - Topic = <<"TopicA">>, - {ok, C} = emqtt:start_link([{port, 1883}, {client_id, <<"hello">>}]), - {ok, _} = emqtt:connect(C), - {ok, _, [1]} = emqtt:subscribe(C, Topic, qos1), - {ok, _, [2]} = emqtt:subscribe(C, Topic, qos2), - {ok, _} = emqtt:publish(C, Topic, <<"qos 2">>, 2), - {ok, _} = emqtt:publish(C, Topic, <<"qos 2">>, 2), - {ok, _} = emqtt:publish(C, Topic, <<"qos 2">>, 2), - ?assertEqual(3, length(recv_msgs(3))), - ok = emqtt:disconnect(C). +%%-------------------------------------------------------------------- +%% Test cases for handle_in +%%-------------------------------------------------------------------- -recv_msgs(Count) -> - recv_msgs(Count, []). +t_handle_connect(_) -> + ConnPkt = #mqtt_packet_connect{ + proto_name = <<"MQTT">>, + proto_ver = ?MQTT_PROTO_V4, + is_bridge = false, + clean_start = true, + keepalive = 30, + properties = #{}, + client_id = <<"clientid">>, + username = <<"username">>, + password = <<"passwd">> + }, + with_channel( + fun(Channel) -> + {ok, ?CONNACK_PACKET(?RC_SUCCESS), Channel1} + = handle_in(?CONNECT_PACKET(ConnPkt), Channel), + Client = emqx_channel:info(client, Channel1), + ?assertEqual(<<"clientid">>, maps:get(client_id, Client)), + ?assertEqual(<<"username">>, maps:get(username, Client)) + end). -recv_msgs(0, Msgs) -> - Msgs; -recv_msgs(Count, Msgs) -> - receive - {publish, Msg} -> - recv_msgs(Count-1, [Msg|Msgs]) - after 100 -> - Msgs - end. +t_handle_publish_qos0(_) -> + with_channel( + fun(Channel) -> + Publish = ?PUBLISH_PACKET(?QOS_0, <<"topic">>, undefined, <<"payload">>), + {ok, Channel} = handle_in(Publish, Channel) + end). + +t_handle_publish_qos1(_) -> + with_channel( + fun(Channel) -> + Publish = ?PUBLISH_PACKET(?QOS_1, <<"topic">>, 1, <<"payload">>), + {ok, ?PUBACK_PACKET(1, RC), _} = handle_in(Publish, Channel), + ?assert((RC == ?RC_SUCCESS) orelse (RC == ?RC_NO_MATCHING_SUBSCRIBERS)) + end). + +t_handle_publish_qos2(_) -> + with_channel( + fun(Channel) -> + Publish1 = ?PUBLISH_PACKET(?QOS_2, <<"topic">>, 1, <<"payload">>), + {ok, ?PUBREC_PACKET(1, RC), Channel1} = handle_in(Publish1, Channel), + Publish2 = ?PUBLISH_PACKET(?QOS_2, <<"topic">>, 2, <<"payload">>), + {ok, ?PUBREC_PACKET(2, RC), Channel2} = handle_in(Publish2, Channel1), + ?assert((RC == ?RC_SUCCESS) orelse (RC == ?RC_NO_MATCHING_SUBSCRIBERS)), + Session = emqx_channel:info(session, Channel2), + ?assertEqual(2, emqx_session:info(awaiting_rel, Session)) + end). + +t_handle_puback(_) -> + with_channel( + fun(Channel) -> + {ok, Channel} = handle_in(?PUBACK_PACKET(1, ?RC_SUCCESS), Channel) + end). + +t_handle_pubrec(_) -> + with_channel( + fun(Channel) -> + {ok, ?PUBREL_PACKET(1, ?RC_PACKET_IDENTIFIER_NOT_FOUND), Channel} + = handle_in(?PUBREC_PACKET(1, ?RC_SUCCESS), Channel) + end). + +t_handle_pubrel(_) -> + with_channel( + fun(Channel) -> + {ok, ?PUBCOMP_PACKET(1, ?RC_PACKET_IDENTIFIER_NOT_FOUND), Channel} + = handle_in(?PUBREL_PACKET(1, ?RC_SUCCESS), Channel) + end). + +t_handle_pubcomp(_) -> + with_channel( + fun(Channel) -> + {ok, Channel} = handle_in(?PUBCOMP_PACKET(1, ?RC_SUCCESS), Channel) + end). + +t_handle_subscribe(_) -> + with_channel( + fun(Channel) -> + TopicFilters = [{<<"+">>, ?DEFAULT_SUBOPTS}], + {ok, ?SUBACK_PACKET(10, [?QOS_0]), Channel1} + = handle_in(?SUBSCRIBE_PACKET(10, #{}, TopicFilters), Channel), + Session = emqx_channel:info(session, Channel1), + ?assertEqual(maps:from_list(TopicFilters), + emqx_session:info(subscriptions, Session)) + + end). + +t_handle_unsubscribe(_) -> + with_channel( + fun(Channel) -> + {ok, ?UNSUBACK_PACKET(11), Channel} + = handle_in(?UNSUBSCRIBE_PACKET(11, #{}, [<<"+">>]), Channel) + end). + +t_handle_pingreq(_) -> + with_channel( + fun(Channel) -> + {ok, ?PACKET(?PINGRESP), Channel} = handle_in(?PACKET(?PINGREQ), Channel) + end). + +t_handle_disconnect(_) -> + with_channel( + fun(Channel) -> + {stop, normal, Channel1} = handle_in(?DISCONNECT_PACKET(?RC_SUCCESS), Channel), + ?assertEqual(undefined, emqx_channel:info(will_msg, Channel1)) + end). + +t_handle_auth(_) -> + with_channel( + fun(Channel) -> + {ok, Channel} = handle_in(?AUTH_PACKET(), Channel) + end). + +%%-------------------------------------------------------------------- +%% Test cases for handle_deliver +%%-------------------------------------------------------------------- + +t_handle_deliver(_) -> + with_channel( + fun(Channel) -> + TopicFilters = [{<<"+">>, ?DEFAULT_SUBOPTS#{qos => ?QOS_2}}], + {ok, ?SUBACK_PACKET(1, [?QOS_2]), Channel1} + = handle_in(?SUBSCRIBE_PACKET(1, #{}, TopicFilters), Channel), + Msg0 = emqx_message:make(<<"clientx">>, ?QOS_0, <<"t0">>, <<"qos0">>), + Msg1 = emqx_message:make(<<"clientx">>, ?QOS_1, <<"t1">>, <<"qos1">>), + Delivers = [{deliver, <<"+">>, Msg0}, + {deliver, <<"+">>, Msg1}], + {ok, Packets, _Channel2} = emqx_channel:handle_deliver(Delivers, Channel1), + ?assertMatch([?PUBLISH_PACKET(?QOS_0, <<"t0">>, undefined, <<"qos0">>), + ?PUBLISH_PACKET(?QOS_1, <<"t1">>, 1, <<"qos1">>) + ], Packets) + end). + +%%-------------------------------------------------------------------- +%% Test cases for handle_out +%%-------------------------------------------------------------------- + +t_handle_conack(_) -> + with_channel( + fun(Channel) -> + {ok, ?CONNACK_PACKET(?RC_SUCCESS, SP, _), _} + = handle_out(connack, {?RC_SUCCESS, 0}, Channel), + {error, unauthorized_client, ?CONNACK_PACKET(5), _} + = handle_out(connack, ?RC_NOT_AUTHORIZED, Channel) + end). + +t_handle_out_publish(_) -> + with_channel( + fun(Channel) -> + Pub0 = {publish, undefined, emqx_message:make(<<"t">>, <<"qos0">>)}, + Pub1 = {publish, 1, emqx_message:make(<<"c">>, ?QOS_1, <<"t">>, <<"qos1">>)}, + {ok, ?PUBLISH_PACKET(?QOS_0), Channel} = handle_out(Pub0, Channel), + {ok, ?PUBLISH_PACKET(?QOS_1), Channel} = handle_out(Pub1, Channel), + {ok, Packets, Channel} = handle_out(publish, [Pub0, Pub1], Channel), + ?assertEqual(2, length(Packets)) + end). + +t_handle_out_puback(_) -> + with_channel( + fun(Channel) -> + {ok, Channel} = handle_out(puberr, ?RC_NOT_AUTHORIZED, Channel), + {ok, ?PUBACK_PACKET(1, ?RC_SUCCESS), Channel} + = handle_out(puback, {1, ?RC_SUCCESS}, Channel) + end). + +t_handle_out_pubrec(_) -> + with_channel( + fun(Channel) -> + {ok, ?PUBREC_PACKET(4, ?RC_SUCCESS), Channel} + = handle_out(pubrec, {4, ?RC_SUCCESS}, Channel) + end). + +t_handle_out_pubrel(_) -> + with_channel( + fun(Channel) -> + {ok, ?PUBREL_PACKET(2), Channel} + = handle_out(pubrel, {2, ?RC_SUCCESS}, Channel), + {ok, ?PUBREL_PACKET(3, ?RC_SUCCESS), Channel} + = handle_out(pubrel, {3, ?RC_SUCCESS}, Channel) + end). + +t_handle_out_pubcomp(_) -> + with_channel( + fun(Channel) -> + {ok, ?PUBCOMP_PACKET(5, ?RC_SUCCESS), Channel} + = handle_out(pubcomp, {5, ?RC_SUCCESS}, Channel) + end). + +t_handle_out_suback(_) -> + with_channel( + fun(Channel) -> + {ok, ?SUBACK_PACKET(1, [?QOS_2]), Channel} + = handle_out(suback, {1, [?QOS_2]}, Channel) + end). + +t_handle_out_unsuback(_) -> + with_channel( + fun(Channel) -> + {ok, ?UNSUBACK_PACKET(1), Channel} + = handle_out(unsuback, {1, [?RC_SUCCESS]}, Channel) + end). + +t_handle_out_disconnect(_) -> + with_channel( + fun(Channel) -> + handle_out(disconnect, ?RC_SUCCESS, Channel) + end). + +%%-------------------------------------------------------------------- +%% Test cases for handle_timeout +%%-------------------------------------------------------------------- + +t_handle_timeout(_) -> + with_channel( + fun(Channel) -> + 'TODO' + end). + +%%-------------------------------------------------------------------- +%% Test cases for terminate +%%-------------------------------------------------------------------- + +t_terminate(_) -> + with_channel( + fun(Channel) -> + 'TODO' + end). + +%%-------------------------------------------------------------------- +%% Helper functions +%%-------------------------------------------------------------------- + +with_channel(Fun) -> + ConnInfo = #{peername => {{127,0,0,1}, 3456}, + sockname => {{127,0,0,1}, 1883}, + client_id => <<"clientid">>, + username => <<"username">> + }, + Options = [{zone, testing}], + Channel = emqx_channel:init(ConnInfo, Options), + Session = emqx_session:init(false, #{zone => testing}, + #{max_inflight => 100, + expiry_interval => 0 + }), + Fun(emqx_channel:set(session, Session, Channel)). diff --git a/test/emqx_connection_SUITE.erl b/test/emqx_connection_SUITE.erl new file mode 100644 index 000000000..8e595b8b2 --- /dev/null +++ b/test/emqx_connection_SUITE.erl @@ -0,0 +1,57 @@ +%%-------------------------------------------------------------------- +%% Copyright (c) 2019 EMQ Technologies Co., Ltd. All Rights Reserved. +%% +%% Licensed under the Apache License, Version 2.0 (the "License"); +%% you may not use this file except in compliance with the License. +%% You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, software +%% distributed under the License is distributed on an "AS IS" BASIS, +%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +%% See the License for the specific language governing permissions and +%% limitations under the License. +%%-------------------------------------------------------------------- + +-module(emqx_connection_SUITE). + +-compile(export_all). +-compile(nowarn_export_all). + +-include_lib("eunit/include/eunit.hrl"). + +all() -> emqx_ct:all(?MODULE). + +init_per_suite(Config) -> + emqx_ct_helpers:start_apps([]), + Config. + +end_per_suite(_Config) -> + emqx_ct_helpers:stop_apps([]). + +t_basic(_) -> + Topic = <<"TopicA">>, + {ok, C} = emqtt:start_link([{port, 1883}, {client_id, <<"hello">>}]), + {ok, _} = emqtt:connect(C), + {ok, _, [1]} = emqtt:subscribe(C, Topic, qos1), + {ok, _, [2]} = emqtt:subscribe(C, Topic, qos2), + {ok, _} = emqtt:publish(C, Topic, <<"qos 2">>, 2), + {ok, _} = emqtt:publish(C, Topic, <<"qos 2">>, 2), + {ok, _} = emqtt:publish(C, Topic, <<"qos 2">>, 2), + ?assertEqual(3, length(recv_msgs(3))), + ok = emqtt:disconnect(C). + +recv_msgs(Count) -> + recv_msgs(Count, []). + +recv_msgs(0, Msgs) -> + Msgs; +recv_msgs(Count, Msgs) -> + receive + {publish, Msg} -> + recv_msgs(Count-1, [Msg|Msgs]) + after 100 -> + Msgs + end. + diff --git a/test/emqx_oom_SUITE.erl b/test/emqx_oom_SUITE.erl new file mode 100644 index 000000000..a4be93129 --- /dev/null +++ b/test/emqx_oom_SUITE.erl @@ -0,0 +1,34 @@ +%%-------------------------------------------------------------------- +%% Copyright (c) 2019 EMQ Technologies Co., Ltd. All Rights Reserved. +%% +%% Licensed under the Apache License, Version 2.0 (the "License"); +%% you may not use this file except in compliance with the License. +%% You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, software +%% distributed under the License is distributed on an "AS IS" BASIS, +%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +%% See the License for the specific language governing permissions and +%% limitations under the License. +%%-------------------------------------------------------------------- + +-module(emqx_oom_SUITE). + +-compile(export_all). +-compile(nowarn_export_all). + +-include_lib("eunit/include/eunit.hrl"). + +all() -> emqx_ct:all(?MODULE). + +t_init(_) -> + 'TODO'. + +t_check(_) -> + 'TODO'. + +t_info(_) -> + 'TODO'. + diff --git a/test/emqx_protocol_SUITE.erl b/test/emqx_protocol_SUITE.erl deleted file mode 100644 index 4edcbe3f7..000000000 --- a/test/emqx_protocol_SUITE.erl +++ /dev/null @@ -1,287 +0,0 @@ -%%-------------------------------------------------------------------- -%% Copyright (c) 2019 EMQ Technologies Co., Ltd. All Rights Reserved. -%% -%% Licensed under the Apache License, Version 2.0 (the "License"); -%% you may not use this file except in compliance with the License. -%% You may obtain a copy of the License at -%% -%% http://www.apache.org/licenses/LICENSE-2.0 -%% -%% Unless required by applicable law or agreed to in writing, software -%% distributed under the License is distributed on an "AS IS" BASIS, -%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -%% See the License for the specific language governing permissions and -%% limitations under the License. -%%-------------------------------------------------------------------- - --module(emqx_protocol_SUITE). - --compile(export_all). --compile(nowarn_export_all). - --import(emqx_protocol, - [ handle_in/2 - , handle_out/2 - ]). - --include("emqx.hrl"). --include("emqx_mqtt.hrl"). - --include_lib("eunit/include/eunit.hrl"). - -all() -> emqx_ct:all(?MODULE). - -init_per_suite(Config) -> - emqx_ct_helpers:start_apps([]), - Config. - -end_per_suite(_Config) -> - emqx_ct_helpers:stop_apps([]). - -%%-------------------------------------------------------------------- -%% Test cases for handle_in -%%-------------------------------------------------------------------- - -t_handle_connect(_) -> - ConnPkt = #mqtt_packet_connect{ - proto_name = <<"MQTT">>, - proto_ver = ?MQTT_PROTO_V4, - is_bridge = false, - clean_start = true, - keepalive = 30, - properties = #{}, - client_id = <<"clientid">>, - username = <<"username">>, - password = <<"passwd">> - }, - with_proto( - fun(PState) -> - {ok, ?CONNACK_PACKET(?RC_SUCCESS), PState1} - = handle_in(?CONNECT_PACKET(ConnPkt), PState), - Client = emqx_protocol:info(client, PState1), - ?assertEqual(<<"clientid">>, maps:get(client_id, Client)), - ?assertEqual(<<"username">>, maps:get(username, Client)) - end). - -t_handle_publish_qos0(_) -> - with_proto( - fun(PState) -> - Publish = ?PUBLISH_PACKET(?QOS_0, <<"topic">>, undefined, <<"payload">>), - {ok, PState} = handle_in(Publish, PState) - end). - -t_handle_publish_qos1(_) -> - with_proto( - fun(PState) -> - Publish = ?PUBLISH_PACKET(?QOS_1, <<"topic">>, 1, <<"payload">>), - {ok, ?PUBACK_PACKET(1, RC), _} = handle_in(Publish, PState), - ?assert((RC == ?RC_SUCCESS) orelse (RC == ?RC_NO_MATCHING_SUBSCRIBERS)) - end). - -t_handle_publish_qos2(_) -> - with_proto( - fun(PState) -> - Publish1 = ?PUBLISH_PACKET(?QOS_2, <<"topic">>, 1, <<"payload">>), - {ok, ?PUBREC_PACKET(1, RC), PState1} = handle_in(Publish1, PState), - Publish2 = ?PUBLISH_PACKET(?QOS_2, <<"topic">>, 2, <<"payload">>), - {ok, ?PUBREC_PACKET(2, RC), PState2} = handle_in(Publish2, PState1), - ?assert((RC == ?RC_SUCCESS) orelse (RC == ?RC_NO_MATCHING_SUBSCRIBERS)), - Session = emqx_protocol:info(session, PState2), - ?assertEqual(2, emqx_session:info(awaiting_rel, Session)) - end). - -t_handle_puback(_) -> - with_proto( - fun(PState) -> - {ok, PState} = handle_in(?PUBACK_PACKET(1, ?RC_SUCCESS), PState) - end). - -t_handle_pubrec(_) -> - with_proto( - fun(PState) -> - {ok, ?PUBREL_PACKET(1, ?RC_PACKET_IDENTIFIER_NOT_FOUND), PState} - = handle_in(?PUBREC_PACKET(1, ?RC_SUCCESS), PState) - end). - -t_handle_pubrel(_) -> - with_proto( - fun(PState) -> - {ok, ?PUBCOMP_PACKET(1, ?RC_PACKET_IDENTIFIER_NOT_FOUND), PState} - = handle_in(?PUBREL_PACKET(1, ?RC_SUCCESS), PState) - end). - -t_handle_pubcomp(_) -> - with_proto( - fun(PState) -> - {ok, PState} = handle_in(?PUBCOMP_PACKET(1, ?RC_SUCCESS), PState) - end). - -t_handle_subscribe(_) -> - with_proto( - fun(PState) -> - TopicFilters = [{<<"+">>, ?DEFAULT_SUBOPTS}], - {ok, ?SUBACK_PACKET(10, [?QOS_0]), PState1} - = handle_in(?SUBSCRIBE_PACKET(10, #{}, TopicFilters), PState), - Session = emqx_protocol:info(session, PState1), - ?assertEqual(maps:from_list(TopicFilters), - emqx_session:info(subscriptions, Session)) - - end). - -t_handle_unsubscribe(_) -> - with_proto( - fun(PState) -> - {ok, ?UNSUBACK_PACKET(11), PState} - = handle_in(?UNSUBSCRIBE_PACKET(11, #{}, [<<"+">>]), PState) - end). - -t_handle_pingreq(_) -> - with_proto( - fun(PState) -> - {ok, ?PACKET(?PINGRESP), PState} = handle_in(?PACKET(?PINGREQ), PState) - end). - -t_handle_disconnect(_) -> - with_proto( - fun(PState) -> - {stop, normal, PState1} = handle_in(?DISCONNECT_PACKET(?RC_SUCCESS), PState), - ?assertEqual(undefined, emqx_protocol:info(will_msg, PState1)) - end). - -t_handle_auth(_) -> - with_proto( - fun(PState) -> - {ok, PState} = handle_in(?AUTH_PACKET(), PState) - end). - -%%-------------------------------------------------------------------- -%% Test cases for handle_deliver -%%-------------------------------------------------------------------- - -t_handle_deliver(_) -> - with_proto( - fun(PState) -> - TopicFilters = [{<<"+">>, ?DEFAULT_SUBOPTS#{qos => ?QOS_2}}], - {ok, ?SUBACK_PACKET(1, [?QOS_2]), PState1} - = handle_in(?SUBSCRIBE_PACKET(1, #{}, TopicFilters), PState), - Msg0 = emqx_message:make(<<"clientx">>, ?QOS_0, <<"t0">>, <<"qos0">>), - Msg1 = emqx_message:make(<<"clientx">>, ?QOS_1, <<"t1">>, <<"qos1">>), - Delivers = [{deliver, <<"+">>, Msg0}, - {deliver, <<"+">>, Msg1}], - {ok, Packets, _PState2} = emqx_protocol:handle_deliver(Delivers, PState1), - ?assertMatch([?PUBLISH_PACKET(?QOS_0, <<"t0">>, undefined, <<"qos0">>), - ?PUBLISH_PACKET(?QOS_1, <<"t1">>, 1, <<"qos1">>) - ], Packets) - end). - -%%-------------------------------------------------------------------- -%% Test cases for handle_out -%%-------------------------------------------------------------------- - -t_handle_conack(_) -> - with_proto( - fun(PState) -> - {ok, ?CONNACK_PACKET(?RC_SUCCESS, SP, _), _} - = handle_out({connack, ?RC_SUCCESS, 0}, PState), - {error, unauthorized_client, ?CONNACK_PACKET(5), _} - = handle_out({connack, ?RC_NOT_AUTHORIZED}, PState) - end). - -t_handle_out_publish(_) -> - with_proto( - fun(PState) -> - Pub0 = {publish, undefined, emqx_message:make(<<"t">>, <<"qos0">>)}, - Pub1 = {publish, 1, emqx_message:make(<<"c">>, ?QOS_1, <<"t">>, <<"qos1">>)}, - {ok, ?PUBLISH_PACKET(?QOS_0), PState} = handle_out(Pub0, PState), - {ok, ?PUBLISH_PACKET(?QOS_1), PState} = handle_out(Pub1, PState), - {ok, Packets, PState} = handle_out({publish, [Pub0, Pub1]}, PState), - ?assertEqual(2, length(Packets)) - end). - -t_handle_out_puback(_) -> - with_proto( - fun(PState) -> - {ok, PState} = handle_out({puberr, ?RC_NOT_AUTHORIZED}, PState), - {ok, ?PUBACK_PACKET(1, ?RC_SUCCESS), PState} - = handle_out({puback, 1, ?RC_SUCCESS}, PState) - end). - -t_handle_out_pubrec(_) -> - with_proto( - fun(PState) -> - {ok, ?PUBREC_PACKET(4, ?RC_SUCCESS), PState} - = handle_out({pubrec, 4, ?RC_SUCCESS}, PState) - end). - -t_handle_out_pubrel(_) -> - with_proto( - fun(PState) -> - {ok, ?PUBREL_PACKET(2), PState} = handle_out({pubrel, 2}, PState), - {ok, ?PUBREL_PACKET(3, ?RC_SUCCESS), PState} - = handle_out({pubrel, 3, ?RC_SUCCESS}, PState) - end). - -t_handle_out_pubcomp(_) -> - with_proto( - fun(PState) -> - {ok, ?PUBCOMP_PACKET(5, ?RC_SUCCESS), PState} - = handle_out({pubcomp, 5, ?RC_SUCCESS}, PState) - end). - -t_handle_out_suback(_) -> - with_proto( - fun(PState) -> - {ok, ?SUBACK_PACKET(1, [?QOS_2]), PState} - = handle_out({suback, 1, [?QOS_2]}, PState) - end). - -t_handle_out_unsuback(_) -> - with_proto( - fun(PState) -> - {ok, ?UNSUBACK_PACKET(1), PState} = handle_out({unsuback, 1, [?RC_SUCCESS]}, PState) - end). - -t_handle_out_disconnect(_) -> - with_proto( - fun(PState) -> - handle_out({disconnect, 0}, PState) - end). - -%%-------------------------------------------------------------------- -%% Test cases for handle_timeout -%%-------------------------------------------------------------------- - -t_handle_timeout(_) -> - with_proto( - fun(PState) -> - 'TODO' - end). - -%%-------------------------------------------------------------------- -%% Test cases for terminate -%%-------------------------------------------------------------------- - -t_terminate(_) -> - with_proto( - fun(PState) -> - 'TODO' - end). - -%%-------------------------------------------------------------------- -%% Helper functions -%%-------------------------------------------------------------------- - -with_proto(Fun) -> - ConnInfo = #{peername => {{127,0,0,1}, 3456}, - sockname => {{127,0,0,1}, 1883}, - client_id => <<"clientid">>, - username => <<"username">> - }, - Options = [{zone, testing}], - PState = emqx_protocol:init(ConnInfo, Options), - Session = emqx_session:init(false, #{zone => testing}, - #{max_inflight => 100, - expiry_interval => 0 - }), - Fun(emqx_protocol:set(session, Session, PState)). - diff --git a/test/emqx_ws_channel_SUITE.erl b/test/emqx_ws_connection_SUITE.erl similarity index 98% rename from test/emqx_ws_channel_SUITE.erl rename to test/emqx_ws_connection_SUITE.erl index f634e633e..2e2db7728 100644 --- a/test/emqx_ws_channel_SUITE.erl +++ b/test/emqx_ws_connection_SUITE.erl @@ -14,7 +14,7 @@ %% limitations under the License. %%-------------------------------------------------------------------- --module(emqx_ws_channel_SUITE). +-module(emqx_ws_connection_SUITE). -compile(export_all). -compile(nowarn_export_all).