diff --git a/src/emqx_channel.erl b/src/emqx_channel.erl index a3804298b..aa9a9f897 100644 --- a/src/emqx_channel.erl +++ b/src/emqx_channel.erl @@ -24,13 +24,19 @@ -include("logger.hrl"). -include("types.hrl"). +-logger_header("[Channel]"). + -export([start_link/3]). %% APIs -export([ info/1 + , attrs/1 , stats/1 ]). +%% for Debug +-export([state/1]). + %% state callbacks -export([ idle/3 , connected/3 @@ -60,10 +66,12 @@ gc_state :: emqx_gc:gc_state(), keepalive :: maybe(emqx_keepalive:keepalive()), stats_timer :: disabled | maybe(reference()), - idle_timeout :: timeout() - }). + idle_timeout :: timeout(), + connected :: boolean(), + connected_at :: erlang:timestamp() + }). --logger_header("[Channel]"). +-type(state() :: #state{}). -define(ACTIVE_N, 100). -define(HANDLE(T, C, D), handle((T), (C), (D))). @@ -79,55 +87,81 @@ start_link(Transport, Socket, Options) -> %% API %%-------------------------------------------------------------------- -%% @doc Get channel's info. --spec(info(pid() | #state{}) -> proplists:proplist()). +%% @doc Get infos of the channel. +-spec(info(pid() | state()) -> emqx_types:infos()). info(CPid) when is_pid(CPid) -> call(CPid, info); +info(#state{transport = Transport, + socket = Socket, + peername = Peername, + sockname = Sockname, + conn_state = ConnState, + active_n = ActiveN, + rate_limit = RateLimit, + pub_limit = PubLimit, + proto_state = ProtoState, + gc_state = GCState, + stats_timer = StatsTimer, + idle_timeout = IdleTimeout, + connected = Connected, + connected_at = ConnectedAt}) -> + ChanInfo = #{socktype => Transport:type(Socket), + peername => Peername, + sockname => Sockname, + conn_state => ConnState, + active_n => ActiveN, + rate_limit => limit_info(RateLimit), + pub_limit => limit_info(PubLimit), + gc_state => emqx_gc:info(GCState), + enable_stats => case StatsTimer of + disabled -> false; + _Otherwise -> true + end, + idle_timeout => IdleTimeout, + connected => Connected, + connected_at => ConnectedAt + }, + maps:merge(ChanInfo, emqx_protocol:info(ProtoState)). -info(#state{transport = Transport, - socket = Socket, - peername = Peername, - sockname = Sockname, - conn_state = ConnState, - active_n = ActiveN, - rate_limit = RateLimit, - pub_limit = PubLimit, - proto_state = ProtoState, - gc_state = GCState, - stats_timer = StatsTimer, - idle_timeout = IdleTimeout}) -> - [{socktype, Transport:type(Socket)}, - {peername, Peername}, - {sockname, Sockname}, - {conn_state, ConnState}, - {active_n, ActiveN}, - {rate_limit, rate_limit_info(RateLimit)}, - {pub_limit, rate_limit_info(PubLimit)}, - {gc_state, emqx_gc:info(GCState)}, - {enable_stats, case StatsTimer of - disabled -> false; - _Otherwise -> true - end}, - {idle_timeout, IdleTimeout} | - emqx_protocol:info(ProtoState)]. - -rate_limit_info(undefined) -> +limit_info(undefined) -> undefined; -rate_limit_info(Limit) -> +limit_info(Limit) -> esockd_rate_limit:info(Limit). -%% @doc Get channel's stats. --spec(stats(pid() | #state{}) -> proplists:proplist()). +%% @doc Get attrs of the channel. +-spec(attrs(pid() | state()) -> emqx_types:attrs()). +attrs(CPid) when is_pid(CPid) -> + call(CPid, attrs); +attrs(#state{transport = Transport, + socket = Socket, + peername = Peername, + sockname = Sockname, + proto_state = ProtoState, + connected = Connected, + connected_at = ConnectedAt}) -> + ConnAttrs = #{socktype => Transport:type(Socket), + peername => Peername, + sockname => Sockname, + connected => Connected, + connected_at => ConnectedAt}, + maps:merge(ConnAttrs, emqx_protocol:attrs(ProtoState)). + +%% @doc Get stats of the channel. +-spec(stats(pid() | state()) -> emqx_types:stats()). stats(CPid) when is_pid(CPid) -> call(CPid, stats); - -stats(#state{transport = Transport, socket = Socket}) -> +stats(#state{transport = Transport, + socket = Socket, + proto_state = ProtoState}) -> SockStats = case Transport:getstat(Socket, ?SOCK_STATS) of {ok, Ss} -> Ss; {error, _} -> [] end, ChanStats = [{Name, emqx_pd:get_counter(Name)} || Name <- ?CHAN_STATS], - lists:append([SockStats, ChanStats, emqx_misc:proc_stats()]). + SessStats = emqx_session:stats(emqx_protocol:info(session, ProtoState)), + lists:append([SockStats, ChanStats, SessStats, emqx_misc:proc_stats()]). + +state(CPid) -> call(CPid, get_state). %% @private call(CPid, Req) -> @@ -162,6 +196,7 @@ init({Transport, RawSocket, Options}) -> State = #state{transport = Transport, socket = Socket, peername = Peername, + sockname = Sockname, conn_state = running, active_n = ActiveN, rate_limit = RateLimit, @@ -170,7 +205,8 @@ init({Transport, RawSocket, Options}) -> proto_state = ProtoState, gc_state = GcState, stats_timer = StatsTimer, - idle_timeout = IdleTimout + idle_timeout = IdleTimout, + connected = false }, gen_statem:enter_loop(?MODULE, [{hibernate_after, 2 * IdleTimout}], idle, State, self(), [IdleTimout]). @@ -216,16 +252,19 @@ idle(EventType, Content, State) -> %% Connected State connected(enter, _PrevSt, State = #state{proto_state = ProtoState}) -> - ClientId = emqx_protocol:client_id(ProtoState), - ok = emqx_cm:set_chan_attrs(ClientId, info(State)), + NState = State#state{connected = true, + connected_at = os:timestamp()}, + ClientId = emqx_protocol:info(client_id, ProtoState), + ok = emqx_cm:set_chan_attrs(ClientId, attrs(NState)), %% Ensure keepalive after connected successfully. Interval = emqx_protocol:info(keepalive, ProtoState), - case ensure_keepalive(Interval, State) of - ignore -> keep_state_and_data; + case ensure_keepalive(Interval, NState) of + ignore -> + keep_state(NState); {ok, KeepAlive} -> - keep_state(State#state{keepalive = KeepAlive}); + keep_state(NState#state{keepalive = KeepAlive}); {error, Reason} -> - shutdown(Reason, State) + shutdown(Reason, NState) end; connected(cast, {incoming, Packet = ?PACKET(?CONNECT)}, State) -> @@ -279,9 +318,15 @@ disconnected(EventType, Content, State) -> handle({call, From}, info, State) -> reply(From, info(State), State); +handle({call, From}, attrs, State) -> + reply(From, attrs(State), State); + handle({call, From}, stats, State) -> reply(From, stats(State), State); +handle({call, From}, get_state, State) -> + reply(From, State, State); + %%handle({call, From}, kick, State) -> %% ok = gen_statem:reply(From, ok), %% shutdown(kicked, State); @@ -309,12 +354,12 @@ handle(info, {Inet, _Sock, Data}, State) when Inet == tcp; NState = maybe_gc(1, Oct, State), process_incoming(Data, ensure_stats_timer(NState)); -handle(info, {Error, _Sock, Reason}, State) - when Error == tcp_error; Error == ssl_error -> +handle(info, {Error, _Sock, Reason}, State) when Error == tcp_error; + Error == ssl_error -> shutdown(Reason, State); -handle(info, {Closed, _Sock}, State) - when Closed == tcp_closed; Closed == ssl_closed -> +handle(info, {Closed, _Sock}, State) when Closed == tcp_closed; + Closed == ssl_closed -> shutdown(closed, State); handle(info, {Passive, _Sock}, State) when Passive == tcp_passive; @@ -348,7 +393,7 @@ handle(info, {timeout, Timer, emit_stats}, State = #state{stats_timer = Timer, proto_state = ProtoState, gc_state = GcState}) -> - ClientId = emqx_protocol:client_id(ProtoState), + ClientId = emqx_protocol:info(client_id, ProtoState), ok = emqx_cm:set_chan_stats(ClientId, stats(State)), NState = State#state{stats_timer = undefined}, Limits = erlang:get(force_shutdown_policy), @@ -474,7 +519,7 @@ serialize_fun(ProtoVer) -> ?LOG(debug, "SEND ~s", [emqx_packet:format(Packet)]), _ = inc_outgoing_stats(Type), emqx_frame:serialize(Packet, ProtoVer) - end. + end. %%-------------------------------------------------------------------- %% Send data diff --git a/src/emqx_endpoint.erl b/src/emqx_endpoint.erl deleted file mode 100644 index dee80a099..000000000 --- a/src/emqx_endpoint.erl +++ /dev/null @@ -1,99 +0,0 @@ -%%-------------------------------------------------------------------- -%% Copyright (c) 2019 EMQ Technologies Co., Ltd. All Rights Reserved. -%% -%% Licensed under the Apache License, Version 2.0 (the "License"); -%% you may not use this file except in compliance with the License. -%% You may obtain a copy of the License at -%% -%% http://www.apache.org/licenses/LICENSE-2.0 -%% -%% Unless required by applicable law or agreed to in writing, software -%% distributed under the License is distributed on an "AS IS" BASIS, -%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -%% See the License for the specific language governing permissions and -%% limitations under the License. -%%-------------------------------------------------------------------- - --module(emqx_endpoint). - --include("types.hrl"). - -%% APIs --export([ new/0 - , new/1 - , info/1 - ]). - --export([ zone/1 - , client_id/1 - , mountpoint/1 - , is_superuser/1 - , credentials/1 - ]). - --export([update/2]). - --export([to_map/1]). - --export_type([endpoint/0]). - --type(st() :: #{zone := emqx_types:zone(), - conn_mod := maybe(module()), - peername := emqx_types:peername(), - sockname := emqx_types:peername(), - client_id := emqx_types:client_id(), - username := emqx_types:username(), - peercert := esockd_peercert:peercert(), - is_superuser := boolean(), - mountpoint := maybe(binary()), - ws_cookie := maybe(list()), - password => binary(), - auth_result => emqx_types:auth_result(), - anonymous => boolean(), - atom() => term() - }). - --opaque(endpoint() :: {endpoint, st()}). - --define(Endpoint(St), {endpoint, St}). - --define(Default, #{is_superuser => false, - anonymous => false - }). - --spec(new() -> endpoint()). -new() -> - ?Endpoint(?Default). - --spec(new(map()) -> endpoint()). -new(M) when is_map(M) -> - ?Endpoint(maps:merge(?Default, M)). - -info(?Endpoint(M)) -> - maps:to_list(M). - --spec(zone(endpoint()) -> emqx_zone:zone()). -zone(?Endpoint(#{zone := Zone})) -> - Zone. - -client_id(?Endpoint(#{client_id := ClientId})) -> - ClientId. - --spec(mountpoint(endpoint()) -> maybe(binary())). -mountpoint(?Endpoint(#{mountpoint := Mountpoint})) -> - Mountpoint; -mountpoint(_) -> undefined. - -is_superuser(?Endpoint(#{is_superuser := B})) -> - B. - -update(Attrs, ?Endpoint(M)) -> - ?Endpoint(maps:merge(M, Attrs)). - -credentials(?Endpoint(M)) -> - M. %% TODO: ... - --spec(to_map(endpoint()) -> map()). -to_map(?Endpoint(M)) -> - M. - diff --git a/src/emqx_ws_channel.erl b/src/emqx_ws_channel.erl index 4da81a2d7..3c04e4cea 100644 --- a/src/emqx_ws_channel.erl +++ b/src/emqx_ws_channel.erl @@ -38,8 +38,6 @@ ]). -record(state, { - request :: cowboy_req:req(), - options :: proplists:proplist(), peername :: emqx_types:peername(), sockname :: emqx_types:peername(), fsm_state :: idle | connected | disconnected, @@ -48,10 +46,15 @@ proto_state :: emqx_protocol:proto_state(), gc_state :: emqx_gc:gc_state(), keepalive :: maybe(emqx_keepalive:keepalive()), + pendings :: list(), stats_timer :: disabled | maybe(reference()), idle_timeout :: timeout(), - shutdown - }). + connected :: boolean(), + connected_at :: erlang:timestamp(), + reason :: term() + }). + +-type(state() :: #state{}). -define(SOCK_STATS, [recv_oct, recv_cnt, send_oct, send_cnt]). -define(CHAN_STATS, [recv_pkt, recv_msg, send_pkt, send_msg]). @@ -60,36 +63,57 @@ %% API %%-------------------------------------------------------------------- -%% for debug +-spec(info(pid() | state()) -> emqx_types:infos()). info(WSPid) when is_pid(WSPid) -> call(WSPid, info); - info(#state{peername = Peername, sockname = Sockname, - proto_state = ProtoState}) -> - [{socktype, websocket}, - {conn_state, running}, - {peername, Peername}, - {sockname, Sockname} | - emqx_protocol:info(ProtoState)]. + proto_state = ProtoState, + gc_state = GCState, + stats_timer = StatsTimer, + idle_timeout = IdleTimeout, + connected = Connected, + connected_at = ConnectedAt}) -> + ChanInfo = #{socktype => websocket, + peername => Peername, + sockname => Sockname, + conn_state => running, + gc_state => emqx_gc:info(GCState), + enable_stats => enable_stats(StatsTimer), + idle_timeout => IdleTimeout, + connected => Connected, + connected_at => ConnectedAt + }, + maps:merge(ChanInfo, emqx_protocol:info(ProtoState)). -%% for dashboard +enable_stats(disabled) -> false; +enable_stats(_MaybeRef) -> true. + +-spec(attrs(pid() | state()) -> emqx_types:attrs()). attrs(WSPid) when is_pid(WSPid) -> call(WSPid, attrs); - attrs(#state{peername = Peername, sockname = Sockname, - proto_state = ProtoState}) -> - [{peername, Peername}, - {sockname, Sockname} | - emqx_protocol:attrs(ProtoState)]. + proto_state = ProtoState, + connected = Connected, + connected_at = ConnectedAt}) -> + ConnAttrs = #{socktype => websocket, + peername => Peername, + sockname => Sockname, + connected => Connected, + connected_at => ConnectedAt + }, + maps:merge(ConnAttrs, emqx_protocol:attrs(ProtoState)). +-spec(stats(pid() | state()) -> emqx_types:stats()). stats(WSPid) when is_pid(WSPid) -> call(WSPid, stats); +stats(#state{proto_state = ProtoState}) -> + ProcStats = emqx_misc:proc_stats(), + SessStats = emqx_session:stats(emqx_protocol:info(session, ProtoState)), + lists:append([ProcStats, SessStats, chan_stats(), wsock_stats()]). -stats(#state{}) -> - lists:append([chan_stats(), wsock_stats(), emqx_misc:proc_stats()]). - +%% @private call(WSPid, Req) when is_pid(WSPid) -> Mref = erlang:monitor(process, WSPid), WSPid ! {call, {self(), Mref}, Req}, @@ -116,21 +140,24 @@ init(Req, Opts) -> I -> I end, Compress = proplists:get_value(compress, Opts, false), - Options = #{compress => Compress, - deflate_opts => DeflateOptions, - max_frame_size => MaxFrameSize, - idle_timeout => IdleTimeout}, + WsOpts = #{compress => Compress, + deflate_opts => DeflateOptions, + max_frame_size => MaxFrameSize, + idle_timeout => IdleTimeout + }, case cowboy_req:parse_header(<<"sec-websocket-protocol">>, Req) of undefined -> - {cowboy_websocket, Req, #state{}, Options}; + %% TODO: why not reply 500??? + {cowboy_websocket, Req, [Req, Opts], WsOpts}; [<<"mqtt", Vsn/binary>>] -> - Resp = cowboy_req:set_resp_header(<<"sec-websocket-protocol">>, <<"mqtt", Vsn/binary>>, Req), - {cowboy_websocket, Resp, #state{request = Req, options = Opts}, Options}; + Resp = cowboy_req:set_resp_header( + <<"sec-websocket-protocol">>, <<"mqtt", Vsn/binary>>, Req), + {cowboy_websocket, Resp, [Req, Opts], WsOpts}; _ -> {ok, cowboy_req:reply(400, Req), #state{}} end. -websocket_init(#state{request = Req, options = Options}) -> +websocket_init([Req, Opts]) -> Peername = cowboy_req:peer(Req), Sockname = cowboy_req:sock(Req), Peercert = cowboy_req:cert(Req), @@ -148,8 +175,8 @@ websocket_init(#state{request = Req, options = Options}) -> sockname => Sockname, peercert => Peercert, ws_cookie => WsCookie, - conn_mod => ?MODULE}, Options), - Zone = proplists:get_value(zone, Options), + conn_mod => ?MODULE}, Opts), + Zone = proplists:get_value(zone, Opts), MaxSize = emqx_zone:get_env(Zone, max_packet_size, ?MAX_PACKET_SIZE), ParseState = emqx_frame:initial_parse_state(#{max_size => MaxSize}), GcPolicy = emqx_zone:get_env(Zone, force_gc_policy, false), @@ -159,15 +186,16 @@ websocket_init(#state{request = Req, options = Options}) -> IdleTimout = emqx_zone:get_env(Zone, idle_timeout, 30000), emqx_logger:set_metadata_peername(esockd_net:format(Peername)), ok = emqx_misc:init_proc_mng_policy(Zone), - %% TODO: Idle timeout? {ok, #state{peername = Peername, sockname = Sockname, fsm_state = idle, parse_state = ParseState, proto_state = ProtoState, gc_state = GcState, + pendings = [], stats_timer = StatsTimer, - idle_timeout = IdleTimout + idle_timeout = IdleTimout, + connected = false }}. stat_fun() -> @@ -196,7 +224,7 @@ websocket_handle({FrameType, _}, State) %% According to mqtt spec[https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901285] websocket_handle({FrameType, _}, State) -> ?LOG(error, "Frame error: unexpected frame - ~p", [FrameType]), - shutdown(unexpected_ws_frame, State). + stop(unexpected_ws_frame, State). websocket_info({call, From, info}, State) -> gen_server:reply(From, info(State)), @@ -212,62 +240,41 @@ websocket_info({call, From, stats}, State) -> websocket_info({call, From, kick}, State) -> gen_server:reply(From, ok), - shutdown(kick, State); + stop(kick, State); websocket_info({incoming, Packet = ?CONNECT_PACKET( #mqtt_packet_connect{ proto_ver = ProtoVer} )}, State = #state{fsm_state = idle}) -> - State1 = State#state{serialize = serialize_fun(ProtoVer)}, - %% TODO: Fixme later - case handle_incoming(Packet, State1) of - Rep = {reply, _Data, _State} -> - self() ! {enter, connected}, - Rep; - Other -> Other - end; + handle_incoming(Packet, fun connected/1, + State#state{serialize = serialize_fun(ProtoVer)}); websocket_info({incoming, Packet}, State = #state{fsm_state = idle}) -> ?LOG(warning, "Unexpected incoming: ~p", [Packet]), - shutdown(unexpected_incoming_packet, State); - -websocket_info({enter, connected}, State = #state{proto_state = ProtoState}) -> - ClientId = emqx_protocol:client_id(ProtoState), - ok = emqx_cm:set_chan_attrs(ClientId, info(State)), - %% Ensure keepalive after connected successfully. - Interval = emqx_protocol:info(keepalive, ProtoState), - State1 = State#state{fsm_state = connected}, - case ensure_keepalive(Interval, State1) of - ignore -> {ok, State1}; - {ok, KeepAlive} -> - {ok, State1#state{keepalive = KeepAlive}}; - {error, Reason} -> - shutdown(Reason, State1) - end; + stop(unexpected_incoming_packet, State); websocket_info({incoming, Packet = ?PACKET(?CONNECT)}, State = #state{fsm_state = connected}) -> ?LOG(warning, "Unexpected connect: ~p", [Packet]), - shutdown(unexpected_incoming_connect, State); + stop(unexpected_incoming_connect, State); websocket_info({incoming, Packet}, State = #state{fsm_state = connected}) when is_record(Packet, mqtt_packet) -> - handle_incoming(Packet, State); + handle_incoming(Packet, fun reply/1, State); -websocket_info(Deliver = {deliver, _Topic, _Msg}, +websocket_info(Deliver = {deliver, _Topic, _Msg}, State = #state{proto_state = ProtoState}) -> Delivers = emqx_misc:drain_deliver([Deliver]), case emqx_protocol:handle_deliver(Delivers, ProtoState) of {ok, NProtoState} -> - {ok, State#state{proto_state = NProtoState}}; + reply(State#state{proto_state = NProtoState}); {ok, Packets, NProtoState} -> - NState = State#state{proto_state = NProtoState}, - handle_outgoing(Packets, NState); + reply(Packets, State#state{proto_state = NProtoState}); {error, Reason} -> - shutdown(Reason, State); + stop(Reason, State); {error, Reason, NProtoState} -> - shutdown(Reason, State#state{proto_state = NProtoState}) + stop(Reason, State#state{proto_state = NProtoState}) end; websocket_info({keepalive, check}, State = #state{keepalive = KeepAlive}) -> @@ -276,17 +283,17 @@ websocket_info({keepalive, check}, State = #state{keepalive = KeepAlive}) -> {ok, State#state{keepalive = KeepAlive1}}; {error, timeout} -> ?LOG(debug, "Keepalive Timeout!"), - shutdown(keepalive_timeout, State); + stop(keepalive_timeout, State); {error, Error} -> ?LOG(error, "Keepalive error: ~p", [Error]), - shutdown(keepalive_error, State) + stop(keepalive_error, State) end; websocket_info({timeout, Timer, emit_stats}, State = #state{stats_timer = Timer, proto_state = ProtoState, gc_state = GcState}) -> - ClientId = emqx_protocol:client_id(ProtoState), + ClientId = emqx_protocol:info(client_id, ProtoState), ok = emqx_cm:set_chan_stats(ClientId, stats(State)), NState = State#state{stats_timer = undefined}, Limits = erlang:get(force_shutdown_policy), @@ -299,7 +306,7 @@ websocket_info({timeout, Timer, emit_stats}, {ok, NState#state{gc_state = GcState1}, hibernate}; {shutdown, Reason} -> ?LOG(error, "Shutdown exceptionally due to ~p", [Reason]), - shutdown(Reason, NState) + stop(Reason, NState) end; websocket_info({timeout, Timer, Msg}, @@ -308,29 +315,29 @@ websocket_info({timeout, Timer, Msg}, {ok, NProtoState} -> {ok, State#state{proto_state = NProtoState}}; {ok, Packets, NProtoState} -> - handle_outgoing(Packets, State#state{proto_state = NProtoState}); + reply(Packets, State#state{proto_state = NProtoState}); {error, Reason} -> - shutdown(Reason, State); + stop(Reason, State); {error, Reason, NProtoState} -> - shutdown(Reason, State#state{proto_state = NProtoState}) + stop(Reason, State#state{proto_state = NProtoState}) end; websocket_info({shutdown, discard, {ClientId, ByPid}}, State) -> ?LOG(warning, "Discarded by ~s:~p", [ClientId, ByPid]), - shutdown(discard, State); + stop(discard, State); websocket_info({shutdown, conflict, {ClientId, NewPid}}, State) -> ?LOG(warning, "Clientid '~s' conflict with ~p", [ClientId, NewPid]), - shutdown(conflict, State); + stop(conflict, State); %% websocket_info({binary, Data}, State) -> %% {reply, {binary, Data}, State}; websocket_info({shutdown, Reason}, State) -> - shutdown(Reason, State); + stop(Reason, State); websocket_info({stop, Reason}, State) -> - {stop, State#state{shutdown = Reason}}; + stop(Reason, State); websocket_info(Info, State) -> ?LOG(error, "Unexpected info: ~p", [Info]), @@ -338,16 +345,31 @@ websocket_info(Info, State) -> terminate(SockError, _Req, #state{keepalive = Keepalive, proto_state = ProtoState, - shutdown = Shutdown}) -> + reason = Reason}) -> ?LOG(debug, "Terminated for ~p, sockerror: ~p", - [Shutdown, SockError]), + [Reason, SockError]), emqx_keepalive:cancel(Keepalive), - case {ProtoState, Shutdown} of - {undefined, _} -> ok; - {_, {shutdown, Reason}} -> - emqx_protocol:terminate(Reason, ProtoState); - {_, Error} -> - emqx_protocol:terminate(Error, ProtoState) + emqx_protocol:terminate(Reason, ProtoState). + +%%-------------------------------------------------------------------- +%% Connected callback + +connected(State = #state{proto_state = ProtoState}) -> + NState = State#state{fsm_state = connected, + connected = true, + connected_at = os:timestamp() + }, + ClientId = emqx_protocol:info(client_id, ProtoState), + ok = emqx_cm:set_chan_attrs(ClientId, info(NState)), + %% Ensure keepalive after connected successfully. + Interval = emqx_protocol:info(keepalive, ProtoState), + case ensure_keepalive(Interval, NState) of + ignore -> + reply(NState); + {ok, KeepAlive} -> + reply(NState#state{keepalive = KeepAlive}); + {error, Reason} -> + stop(Reason, NState) end. %%-------------------------------------------------------------------- @@ -361,9 +383,9 @@ ensure_keepalive(Interval, State = #state{proto_state = ProtoState}) -> case emqx_keepalive:start(stat_fun(), round(Interval * Backoff), {keepalive, check}) of {ok, KeepAlive} -> {ok, State#state{keepalive = KeepAlive}}; - {error, Error} -> - ?LOG(warning, "Keepalive error: ~p", [Error]), - shutdown(Error, State) + {error, Reason} -> + ?LOG(warning, "Keepalive error: ~p", [Reason]), + stop(Reason, State) end. %%-------------------------------------------------------------------- @@ -381,41 +403,46 @@ process_incoming(Data, State = #state{parse_state = ParseState}) -> process_incoming(Rest, State#state{parse_state = NParseState}); {error, Reason} -> ?LOG(error, "Frame error: ~p", [Reason]), - shutdown(Reason, State) + stop(Reason, State) catch error:Reason:Stk -> ?LOG(error, "Parse failed for ~p~n\ Stacktrace:~p~nFrame data: ~p", [Reason, Stk, Data]), - shutdown(parse_error, State) + stop(parse_error, State) end. %%-------------------------------------------------------------------- %% Handle incoming packets -handle_incoming(Packet = ?PACKET(Type), State = #state{proto_state = ProtoState}) -> +handle_incoming(Packet = ?PACKET(Type), SuccFun, + State = #state{proto_state = ProtoState, + pendings = Pendings}) -> _ = inc_incoming_stats(Type), ok = emqx_metrics:inc_recv(Packet), ?LOG(debug, "RECV ~s", [emqx_packet:format(Packet)]), case emqx_protocol:handle_in(Packet, ProtoState) of {ok, NProtoState} -> - {ok, State#state{proto_state = NProtoState}}; + SuccFun(State#state{proto_state = NProtoState}); {ok, OutPackets, NProtoState} -> - handle_outgoing(OutPackets, State#state{proto_state = NProtoState}); + Pendings1 = lists:append(Pendings, OutPackets), + SuccFun(State#state{proto_state = NProtoState, + pendings = Pendings1}); {error, Reason, NProtoState} -> - shutdown(Reason, State#state{proto_state = NProtoState}); + stop(Reason, State#state{proto_state = NProtoState}); {stop, Error, NProtoState} -> - shutdown(Error, State#state{proto_state = NProtoState}) + stop(Error, State#state{proto_state = NProtoState}) end. %%-------------------------------------------------------------------- %% Handle outgoing packets -handle_outgoing(Packets, State = #state{serialize = Serialize}) - when is_list(Packets) -> - reply(lists:map(Serialize, Packets), State); +handle_outgoing(Packet, State) when is_tuple(Packet) -> + handle_outgoing([Packet], State); -handle_outgoing(Packet, State = #state{serialize = Serialize}) -> - reply(Serialize(Packet), State). +handle_outgoing(Packets, #state{serialize = Serialize}) -> + Data = lists:map(Serialize, Packets), + emqx_pd:update_counter(send_oct, iolist_size(Data)), + {binary, Data}. %%-------------------------------------------------------------------- %% Serialize fun @@ -442,13 +469,24 @@ inc_outgoing_stats(Type) -> andalso emqx_pd:update_counter(send_msg, 1). %%-------------------------------------------------------------------- -%% Reply data +%% Reply or Stop --compile({inline, [reply/2]}). -reply(Data, State) -> - BinSize = iolist_size(Data), - emqx_pd:update_counter(send_oct, BinSize), - {reply, {binary, Data}, State}. +reply(Packets, State = #state{pendings = Pendings}) -> + Pendings1 = lists:append(Pendings, Packets), + reply(State#state{pendings = Pendings1}). + +reply(State = #state{pendings = []}) -> + {ok, State}; +reply(State = #state{pendings = Pendings}) -> + Reply = handle_outgoing(Pendings, State), + {reply, Reply, State#state{pendings = []}}. + +stop(Reason, State = #state{pendings = []}) -> + {stop, State#state{reason = Reason}}; +stop(Reason, State = #state{pendings = Pendings}) -> + Reply = handle_outgoing(Pendings, State), + {reply, [Reply, close], + State#state{pendings = [], reason = Reason}}. %%-------------------------------------------------------------------- %% Ensure stats timer @@ -460,6 +498,12 @@ ensure_stats_timer(State = #state{stats_timer = undefined, %% disabled or timer existed ensure_stats_timer(State) -> State. +wsock_stats() -> + [{Key, emqx_pd:get_counter(Key)} || Key <- ?SOCK_STATS]. + +chan_stats() -> + [{Name, emqx_pd:get_counter(Name)} || Name <- ?CHAN_STATS]. + %%-------------------------------------------------------------------- %% Maybe GC @@ -470,15 +514,3 @@ maybe_gc(Cnt, Oct, State = #state{gc_state = GCSt}) -> Ok andalso emqx_metrics:inc('channel.gc.cnt'), State#state{gc_state = GCSt1}. --compile({inline, [shutdown/2]}). -shutdown(Reason, State) -> - %% Fix the issue#2591(https://github.com/emqx/emqx/issues/2591#issuecomment-500278696) - %% self() ! {stop, Reason}, - {stop, State#state{shutdown = Reason}}. - -wsock_stats() -> - [{Key, emqx_pd:get_counter(Key)} || Key <- ?SOCK_STATS]. - -chan_stats() -> - [{Name, emqx_pd:get_counter(Name)} || Name <- ?CHAN_STATS]. -