diff --git a/src/emqttd_client.erl b/src/emqttd_client.erl index 6f15bdf8c..1e4b3db4a 100644 --- a/src/emqttd_client.erl +++ b/src/emqttd_client.erl @@ -1,5 +1,5 @@ %%-------------------------------------------------------------------- -%% Copyright (c) 2012-2017 Feng Lee . +%% Copyright (c) 2013-2017 EMQ Enterprise, Inc. (http://emqtt.io) %% %% Licensed under the Apache License, Version 2.0 (the "License"); %% you may not use this file except in compliance with the License. @@ -14,24 +14,36 @@ %% limitations under the License. %%-------------------------------------------------------------------- -%% @doc MQTT Client Connection +%% @doc MQTT/TCP Connection + -module(emqttd_client). -behaviour(gen_server). +-author("Feng Lee "). + -include("emqttd.hrl"). -include("emqttd_protocol.hrl"). -include("emqttd_internal.hrl"). +-import(proplists, [get_value/2, get_value/3]). + %% API Function Exports --export([start_link/2, session/1, info/1, kick/1, - set_rate_limit/2, get_rate_limit/1]). +-export([start_link/2]). + +%% Management and Monitor API +-export([info/1, stats/1, kick/1]). + +-export([set_rate_limit/2, get_rate_limit/1]). %% SUB/UNSUB Asynchronously. Called by plugins. -export([subscribe/2, unsubscribe/2]). +%% Get the session proc? +-export([session/1]). + %% gen_server Function Exports -export([init/1, handle_call/3, handle_cast/2, handle_info/2, code_change/3, terminate/2]). @@ -39,24 +51,25 @@ %% Client State -record(client_state, {connection, connname, peername, peerhost, peerport, await_recv, conn_state, rate_limit, parser_fun, - proto_state, packet_opts, keepalive, mountpoint}). + proto_state, packet_opts, keepalive, enable_stats, + stats_timer}). --define(INFO_KEYS, [peername, peerhost, peerport, await_recv, conn_state]). +-define(INFO_KEYS, [connname, peername, peerhost, peerport, await_recv, conn_state]). --define(SOCK_STATS, [recv_oct, recv_cnt, send_oct, send_cnt]). +-define(SOCK_STATS, [recv_oct, recv_cnt, send_oct, send_cnt, send_pend]). -define(LOG(Level, Format, Args, State), lager:Level("Client(~s): " ++ Format, [State#client_state.connname | Args])). -start_link(Connection, MqttEnv) -> - {ok, proc_lib:spawn_link(?MODULE, init, [[Connection, MqttEnv]])}. - -session(CPid) -> - gen_server:call(CPid, session, infinity). +start_link(Conn, Env) -> + {ok, proc_lib:spawn_link(?MODULE, init, [[Conn, Env]])}. info(CPid) -> gen_server:call(CPid, info, infinity). +stats(CPid) -> + gen_server:call(CPid, stats). + kick(CPid) -> gen_server:call(CPid, kick). @@ -67,42 +80,49 @@ get_rate_limit(Cpid) -> gen_server:call(Cpid, get_rate_limit). subscribe(CPid, TopicTable) -> - gen_server:cast(CPid, {subscribe, TopicTable}). + CPid ! {subscribe, TopicTable}. unsubscribe(CPid, Topics) -> - gen_server:cast(CPid, {unsubscribe, Topics}). + CPid ! {unsubscribe, Topics}. -init([OriginConn, MqttEnv]) -> - {ok, Connection} = OriginConn:wait(), +session(CPid) -> + gen_server2:call(CPid, session, infinity). + +%%-------------------------------------------------------------------- +%% gen_server Callbacks +%%-------------------------------------------------------------------- + +init([Conn0, Env]) -> + {ok, Conn} = Conn0:wait(), {PeerHost, PeerPort, PeerName} = - case Connection:peername() of + case Conn:peername() of {ok, Peer = {Host, Port}} -> {Host, Port, Peer}; {error, enotconn} -> - Connection:fast_close(), + Conn:fast_close(), exit(normal); {error, Reason} -> - Connection:fast_close(), + Conn:fast_close(), exit({shutdown, Reason}) end, ConnName = esockd_net:format(PeerName), Self = self(), - %% Send Packet... SendFun = fun(Packet) -> Data = emqttd_serializer:serialize(Packet), ?LOG(debug, "SEND ~p", [Data], #client_state{connname = ConnName}), emqttd_metrics:inc('bytes/sent', iolist_size(Data)), - try Connection:async_send(Data) of + try Conn:async_send(Data) of true -> ok catch error:Error -> Self ! {shutdown, Error} end end, - ParserFun = emqttd_parser:new(MqttEnv), - ProtoState = emqttd_protocol:init(PeerName, SendFun, MqttEnv), - RateLimit = proplists:get_value(rate_limit, Connection:opts()), - State = run_socket(#client_state{connection = Connection, + ParserFun = emqttd_parser:new(Env), + ProtoState = emqttd_protocol:init(PeerName, SendFun, Env), + RateLimit = get_value(rate_limit, Conn:opts()), + EnableStats = get_value(client_enable_stats, Env, false), + State = run_socket(#client_state{connection = Conn, connname = ConnName, peername = PeerName, peerhost = PeerHost, @@ -112,20 +132,21 @@ init([OriginConn, MqttEnv]) -> rate_limit = RateLimit, parser_fun = ParserFun, proto_state = ProtoState, - packet_opts = MqttEnv}), - IdleTimout = proplists:get_value(client_idle_timeout, MqttEnv, 30), - gen_server:enter_loop(?MODULE, [], State, timer:seconds(IdleTimout)). + packet_opts = Env, + enable_stats = EnableStats}), + IdleTimout = get_value(client_idle_timeout, Env, 30000), + gen_server:enter_loop(?MODULE, [], maybe_enable_stats(State), IdleTimout). -handle_call(session, _From, State = #client_state{proto_state = ProtoState}) -> - {reply, emqttd_protocol:session(ProtoState), State}; - -handle_call(info, _From, State = #client_state{connection = Connection, - proto_state = ProtoState}) -> - ClientInfo = ?record_to_proplist(client_state, State, ?INFO_KEYS), +handle_call(info, From, State = #client_state{proto_state = ProtoState}) -> ProtoInfo = emqttd_protocol:info(ProtoState), - {ok, SockStats} = Connection:getstat(?SOCK_STATS), - {reply, lists:append([ClientInfo, [{proto_info, ProtoInfo}, - {sock_stats, SockStats}]]), State}; + ClientInfo = ?record_to_proplist(client_state, State, ?INFO_KEYS), + {reply, Stats, _} = handle_call(stats, From, State), + {reply, lists:append([ClientInfo, ProtoInfo, Stats]), State}; + +handle_call(stats, _From, State = #client_state{proto_state = ProtoState}) -> + {reply, lists:append([emqttd_misc:proc_stats(), + emqttd_protocol:stats(ProtoState), + sock_stats(State)]), State}; handle_call(kick, _From, State) -> {stop, {shutdown, kick}, ok, State}; @@ -136,45 +157,56 @@ handle_call({set_rate_limit, Rl}, _From, State) -> handle_call(get_rate_limit, _From, State = #client_state{rate_limit = Rl}) -> {reply, Rl, State}; +handle_call(session, _From, State = #client_state{proto_state = ProtoState}) -> + {reply, emqttd_protocol:session(ProtoState), State}; + handle_call(Req, _From, State) -> ?UNEXPECTED_REQ(Req, State). -handle_cast({subscribe, TopicTable}, State) -> - with_proto_state(fun(ProtoState) -> - emqttd_protocol:handle({subscribe, TopicTable}, ProtoState) - end, State); - -handle_cast({unsubscribe, Topics}, State) -> - with_proto_state(fun(ProtoState) -> - emqttd_protocol:handle({unsubscribe, Topics}, ProtoState) - end, State); - handle_cast(Msg, State) -> ?UNEXPECTED_MSG(Msg, State). -handle_info(timeout, State) -> - shutdown(idle_timeout, State); +handle_info({subscribe, TopicTable}, State) -> + with_proto( + fun(ProtoState) -> + emqttd_protocol:subscribe(TopicTable, ProtoState) + end, State); -%% fix issue #535 -handle_info({shutdown, Error}, State) -> - shutdown(Error, State); +handle_info({unsubscribe, Topics}, State) -> + with_proto( + fun(ProtoState) -> + emqttd_protocol:unsubscribe(Topics, ProtoState) + end, State); %% Asynchronous SUBACK handle_info({suback, PacketId, GrantedQos}, State) -> - with_proto_state(fun(ProtoState) -> - Packet = ?SUBACK_PACKET(PacketId, GrantedQos), - emqttd_protocol:send(Packet, ProtoState) - end, State); + with_proto( + fun(ProtoState) -> + Packet = ?SUBACK_PACKET(PacketId, GrantedQos), + emqttd_protocol:send(Packet, ProtoState) + end, State); handle_info({deliver, Message}, State) -> - with_proto_state(fun(ProtoState) -> - emqttd_protocol:send(Message, ProtoState) - end, State); + with_proto( + fun(ProtoState) -> + emqttd_protocol:send(Message, ProtoState) + end, State); handle_info({redeliver, {?PUBREL, PacketId}}, State) -> - with_proto_state(fun(ProtoState) -> - emqttd_protocol:redeliver({?PUBREL, PacketId}, ProtoState) - end, State); + with_proto( + fun(ProtoState) -> + emqttd_protocol:pubrel(PacketId, ProtoState) + end, State); + +handle_info(timeout, State) -> + shutdown(idle_timeout, State); + +handle_info({timeout, _Timer, emit_stats}, State) -> + hibernate(maybe_enable_stats(emit_stats(State))); + +%% Fix issue #535 +handle_info({shutdown, Error}, State) -> + shutdown(Error, State); handle_info({shutdown, conflict, {ClientId, NewPid}}, State) -> ?LOG(warning, "clientid '~s' conflict with ~p", [ClientId, NewPid], State), @@ -193,26 +225,26 @@ handle_info({inet_async, _Sock, _Ref, {error, Reason}}, State) -> shutdown(Reason, State); handle_info({inet_reply, _Sock, ok}, State) -> - hibernate(State); + {noreply, State}; handle_info({inet_reply, _Sock, {error, Reason}}, State) -> shutdown(Reason, State); -handle_info({keepalive, start, Interval}, State = #client_state{connection = Connection}) -> +handle_info({keepalive, start, Interval}, State = #client_state{connection = Conn}) -> ?LOG(debug, "Keepalive at the interval of ~p", [Interval], State), StatFun = fun() -> - case Connection:getstat([recv_oct]) of + case Conn:getstat([recv_oct]) of {ok, [{recv_oct, RecvOct}]} -> {ok, RecvOct}; {error, Error} -> {error, Error} end end, KeepAlive = emqttd_keepalive:start(StatFun, Interval, {keepalive, check}), - hibernate(State#client_state{keepalive = KeepAlive}); + {noreply, stats_by_keepalive(State#client_state{keepalive = KeepAlive})}; handle_info({keepalive, check}, State = #client_state{keepalive = KeepAlive}) -> case emqttd_keepalive:check(KeepAlive) of {ok, KeepAlive1} -> - hibernate(State#client_state{keepalive = KeepAlive1}); + hibernate(emit_stats(State#client_state{keepalive = KeepAlive1})); {error, timeout} -> ?LOG(debug, "Keepalive timeout", [], State), shutdown(keepalive_timeout, State); @@ -224,10 +256,10 @@ handle_info({keepalive, check}, State = #client_state{keepalive = KeepAlive}) -> handle_info(Info, State) -> ?UNEXPECTED_INFO(Info, State). -terminate(Reason, #client_state{connection = Connection, +terminate(Reason, #client_state{connection = Conn, keepalive = KeepAlive, proto_state = ProtoState}) -> - Connection:fast_close(), + Conn:fast_close(), emqttd_keepalive:cancel(KeepAlive), case {ProtoState, Reason} of {undefined, _} -> @@ -245,10 +277,6 @@ code_change(_OldVsn, State, _Extra) -> %% Internal functions %%-------------------------------------------------------------------- -with_proto_state(Fun, State = #client_state{proto_state = ProtoState}) -> - {ok, ProtoState1} = Fun(ProtoState), - hibernate(State#client_state{proto_state = ProtoState1}). - %% Receive and parse tcp data received(<<>>, State) -> hibernate(State); @@ -258,7 +286,7 @@ received(Bytes, State = #client_state{parser_fun = ParserFun, proto_state = ProtoState}) -> case catch ParserFun(Bytes) of {more, NewParser} -> - noreply(run_socket(State#client_state{parser_fun = NewParser})); + {noreply, run_socket(State#client_state{parser_fun = NewParser})}; {ok, Packet, Rest} -> emqttd_metrics:received(Packet), case emqttd_protocol:received(Packet, ProtoState) of @@ -289,7 +317,7 @@ rate_limit(Size, State = #client_state{rate_limit = Rl}) -> {0, Rl1} -> run_socket(State#client_state{conn_state = running, rate_limit = Rl1}); {Pause, Rl1} -> - ?LOG(error, "Rate limiter pause for ~p", [Pause], State), + ?LOG(warning, "Rate limiter pause for ~p", [Pause], State), erlang:send_after(Pause, self(), activate_sock), State#client_state{conn_state = blocked, rate_limit = Rl1} end. @@ -298,12 +326,36 @@ run_socket(State = #client_state{conn_state = blocked}) -> State; run_socket(State = #client_state{await_recv = true}) -> State; -run_socket(State = #client_state{connection = Connection}) -> - Connection:async_recv(0, infinity), +run_socket(State = #client_state{connection = Conn}) -> + Conn:async_recv(0, infinity), State#client_state{await_recv = true}. -noreply(State) -> - {noreply, State}. +with_proto(Fun, State = #client_state{proto_state = ProtoState}) -> + {ok, ProtoState1} = Fun(ProtoState), + {noreply, State#client_state{proto_state = ProtoState1}}. + +maybe_enable_stats(State = #client_state{enable_stats = false}) -> + State; +maybe_enable_stats(State = #client_state{enable_stats = keepalive}) -> + State; +maybe_enable_stats(State = #client_state{enable_stats = Interval}) -> + State#client_state{stats_timer = emqttd_misc:start_timer(Interval, self(), emit_stats)}. + +stats_by_keepalive(State) -> + State#client_state{enable_stats = keepalive}. + +emit_stats(State = #client_state{enable_stats = false}) -> + State; +emit_stats(State = #client_state{proto_state = ProtoState}) -> + {reply, Stats, _} = handle_call(stats, undefined, State), + emqttd_stats:set_client_stats(emqttd_protocol:clientid(ProtoState), Stats), + State. + +sock_stats(#client_state{connection = Conn}) -> + case Conn:getstat(?SOCK_STATS) of + {ok, Ss} -> Ss; + {error, _} -> [] + end. hibernate(State) -> {noreply, State, hibernate}.