From d5a400c308b17c21ff7361d38ba3a75bf238a5b8 Mon Sep 17 00:00:00 2001 From: Feng Date: Sun, 4 Oct 2015 19:48:50 +0800 Subject: [PATCH] fix issue #292 - async sub/unsub --- src/emqttd_client.erl | 63 +++++++++++++++++--------- src/emqttd_protocol.erl | 17 +++---- src/emqttd_session.erl | 96 +++++++++++++++++++++------------------- src/emqttd_ws_client.erl | 77 +++++++++++++++++++++++++------- 4 files changed, 159 insertions(+), 94 deletions(-) diff --git a/src/emqttd_client.erl b/src/emqttd_client.erl index da38d9d8d..7eb4be8f4 100644 --- a/src/emqttd_client.erl +++ b/src/emqttd_client.erl @@ -34,7 +34,10 @@ -include("emqttd_protocol.hrl"). %% API Function Exports --export([start_link/2, session/1, info/1, kick/1, subscribe/2]). +-export([start_link/2, session/1, info/1, kick/1]). + +%% SUB/UNSUB Asynchronously +-export([subscribe/2, unsubscribe/2]). -behaviour(gen_server). @@ -59,7 +62,7 @@ start_link(SockArgs, MqttEnv) -> {ok, proc_lib:spawn_link(?MODULE, init, [[SockArgs, MqttEnv]])}. session(CPid) -> - gen_server:call(CPid, session). + gen_server:call(CPid, session, infinity). info(CPid) -> gen_server:call(CPid, info, infinity). @@ -70,6 +73,9 @@ kick(CPid) -> subscribe(CPid, TopicTable) -> gen_server:cast(CPid, {subscribe, TopicTable}). +unsubscribe(CPid, Topics) -> + gen_server:cast(CPid, {unsubscribe, Topics}). + init([SockArgs = {Transport, Sock, _SockFun}, MqttEnv]) -> % Transform if ssl. {ok, NewSock} = esockd_connection:accept(SockArgs), @@ -107,9 +113,11 @@ handle_call(Req, _From, State = #state{peername = Peername}) -> lager:critical("Client ~s: unexpected request - ~p", [emqttd_net:format(Peername), Req]), {reply, {error, unsupported_request}, State}. -handle_cast({subscribe, TopicTable}, State = #state{proto_state = ProtoState}) -> - {ok, ProtoState1} = emqttd_protocol:handle({subscribe, TopicTable}, ProtoState), - noreply(State#state{proto_state = ProtoState1}); +handle_cast({subscribe, TopicTable}, State) -> + with_session(fun(SessPid) -> emqttd_session:subscribe(SessPid, TopicTable) end, State); + +handle_cast({unsubscribe, Topics}, State) -> + with_session(fun(SessPid) -> emqttd_session:unsubscribe(SessPid, Topics) end, State); handle_cast(Msg, State = #state{peername = Peername}) -> lager:critical("Client ~s: unexpected msg - ~p",[emqttd_net:format(Peername), Msg]), @@ -149,17 +157,26 @@ handle_info({inet_reply, _Sock, {error, Reason}}, State = #state{peername = Peer handle_info({keepalive, start, TimeoutSec}, State = #state{transport = Transport, socket = Socket, peername = Peername}) -> lager:debug("Client ~s: Start KeepAlive with ~p seconds", [emqttd_net:format(Peername), TimeoutSec]), - KeepAlive = emqttd_keepalive:new({Transport, Socket}, TimeoutSec, {keepalive, timeout}), + StatFun = fun() -> + case Transport:getstat(Socket, [recv_oct]) of + {ok, [{recv_oct, RecvOct}]} -> {ok, RecvOct}; + {error, Error} -> {error, Error} + end + end, + KeepAlive = emqttd_keepalive:start(StatFun, TimeoutSec, {keepalive, check}), noreply(State#state{keepalive = KeepAlive}); -handle_info({keepalive, timeout}, State = #state{peername = Peername, keepalive = KeepAlive}) -> - case emqttd_keepalive:resume(KeepAlive) of - timeout -> +handle_info({keepalive, check}, State = #state{peername = Peername, keepalive = KeepAlive}) -> + case emqttd_keepalive:check(KeepAlive) of + {ok, KeepAlive1} -> + lager:debug("Client ~s: Keepalive Resumed", [emqttd_net:format(Peername)]), + noreply(State#state{keepalive = KeepAlive1}); + {error, timeout} -> lager:debug("Client ~s: Keepalive Timeout!", [emqttd_net:format(Peername)]), stop({shutdown, keepalive_timeout}, State#state{keepalive = undefined}); - {resumed, KeepAlive1} -> - lager:debug("Client ~s: Keepalive Resumed", [emqttd_net:format(Peername)]), - noreply(State#state{keepalive = KeepAlive1}) + {error, Error} -> + lager:debug("Client ~s: Keepalive Error: ~p!", [emqttd_net:format(Peername), Error]), + stop({shutdown, keepalive_error}, State#state{keepalive = undefined}) end; handle_info(Info, State = #state{peername = Peername}) -> @@ -188,12 +205,20 @@ terminate(Reason, #state{peername = Peername, code_change(_OldVsn, State, _Extra) -> {ok, State}. +%%%============================================================================= +%%% Internal functions +%%%============================================================================= + noreply(State) -> {noreply, State, hibernate}. - -%------------------------------------------------------- -% receive and parse tcp data -%------------------------------------------------------- + +stop(Reason, State) -> + {stop, Reason, State}. + +with_session(Fun, State = #state{proto_state = ProtoState}) -> + Fun(emqttd_protocol:session(ProtoState)), noreply(State). + +%% receive and parse tcp data received(<<>>, State) -> {noreply, State, hibernate}; @@ -244,12 +269,8 @@ control_throttle(State = #state{conn_state = Flow, {_, _} -> run_socket(State) end. -stop(Reason, State) -> - {stop, Reason, State}. - received_stats(?PACKET(Type)) -> - emqttd_metrics:inc('packets/received'), - inc(Type). + emqttd_metrics:inc('packets/received'), inc(Type). inc(?CONNECT) -> emqttd_metrics:inc('packets/connect'); inc(?PUBLISH) -> diff --git a/src/emqttd_protocol.erl b/src/emqttd_protocol.erl index dcd120035..840339819 100644 --- a/src/emqttd_protocol.erl +++ b/src/emqttd_protocol.erl @@ -239,16 +239,11 @@ handle(?SUBSCRIBE_PACKET(PacketId, TopicTable), State = #proto_state{client_id = case lists:member(deny, AllowDenies) of true -> %%TODO: return 128 QoS when deny... no need to SUBACK? - lager:error("SUBSCRIBE from '~s' Denied: ~p", [ClientId, TopicTable]), - {ok, State}; + lager:error("SUBSCRIBE from '~s' Denied: ~p", [ClientId, TopicTable]); false -> - %%TODO: GrantedQos should be renamed. - {ok, GrantedQos} = emqttd_session:subscribe(Session, TopicTable), - send(?SUBACK_PACKET(PacketId, GrantedQos), State) - end; - -handle({subscribe, TopicTable}, State = #proto_state{session = Session}) -> - {ok, _GrantedQos} = emqttd_session:subscribe(Session, TopicTable), + Callback = fun(GrantedQos) -> send(?SUBACK_PACKET(PacketId, GrantedQos), State) end, + emqttd_session:subscribe(Session, TopicTable, Callback) + end, {ok, State}; %% protect from empty topic list @@ -256,7 +251,7 @@ handle(?UNSUBSCRIBE_PACKET(PacketId, []), State) -> send(?UNSUBACK_PACKET(PacketId), State); handle(?UNSUBSCRIBE_PACKET(PacketId, Topics), State = #proto_state{session = Session}) -> - ok = emqttd_session:unsubscribe(Session, Topics), + emqttd_session:unsubscribe(Session, Topics), send(?UNSUBACK_PACKET(PacketId), State); handle(?PACKET(?PINGREQ), State) -> @@ -349,7 +344,7 @@ send_willmsg(ClientId, WillMsg) -> start_keepalive(0) -> ignore; start_keepalive(Sec) when Sec > 0 -> - self() ! {keepalive, start, round(Sec * 1.5)}. + self() ! {keepalive, start, round(Sec * 1.2)}. %%---------------------------------------------------------------------------- %% Validate Packets diff --git a/src/emqttd_session.erl b/src/emqttd_session.erl index 1d3caa98f..a23f76f7d 100644 --- a/src/emqttd_session.erl +++ b/src/emqttd_session.erl @@ -59,7 +59,7 @@ %% PubSub APIs -export([publish/2, puback/2, pubrec/2, pubrel/2, pubcomp/2, - subscribe/2, unsubscribe/2]). + subscribe/2, subscribe/3, unsubscribe/2]). -behaviour(gen_server2). @@ -166,9 +166,13 @@ destroy(SessPid, ClientId) -> %% @doc Subscribe Topics %% @end %%------------------------------------------------------------------------------ --spec subscribe(pid(), [{binary(), mqtt_qos()}]) -> {ok, [mqtt_qos()]}. +-spec subscribe(pid(), [{binary(), mqtt_qos()}]) -> ok. subscribe(SessPid, TopicTable) -> - gen_server2:call(SessPid, {subscribe, TopicTable}, ?PUBSUB_TIMEOUT). + subscribe(SessPid, TopicTable, fun(_) -> ok end). + +-spec subscribe(pid(), [{binary(), mqtt_qos()}], Callback :: fun()) -> ok. +subscribe(SessPid, TopicTable, Callback) -> + gen_server2:cast(SessPid, {subscribe, TopicTable, Callback}). %%------------------------------------------------------------------------------ %% @doc Publish message @@ -213,7 +217,7 @@ pubcomp(SessPid, PktId) -> %%------------------------------------------------------------------------------ -spec unsubscribe(pid(), [binary()]) -> ok. unsubscribe(SessPid, Topics) -> - gen_server2:call(SessPid, {unsubscribe, Topics}, ?PUBSUB_TIMEOUT). + gen_server2:cast(SessPid, {unsubscribe, Topics}). %%%============================================================================= %%% gen_server callbacks @@ -247,26 +251,24 @@ init([CleanSess, ClientId, ClientPid]) -> {ok, start_collector(Session#session{client_mon = MRef}), hibernate}. prioritise_call(Msg, _From, _Len, _State) -> - case Msg of - {unsubscribe, _} -> 2; - {subscribe, _} -> 1; - _ -> 0 - end. + case Msg of _ -> 0 end. prioritise_cast(Msg, _Len, _State) -> case Msg of - {destroy, _} -> 10; - {resume, _, _} -> 9; - {pubrel, _PktId} -> 8; - {pubcomp, _PktId} -> 8; - {pubrec, _PktId} -> 8; - {puback, _PktId} -> 7; - _ -> 0 + {destroy, _} -> 10; + {resume, _, _} -> 9; + {pubrel, _PktId} -> 8; + {pubcomp, _PktId} -> 8; + {pubrec, _PktId} -> 8; + {puback, _PktId} -> 7; + {unsubscribe, _, _} -> 6; + {subscribe, _, _} -> 5; + _ -> 0 end. prioritise_info(Msg, _Len, _State) -> case Msg of - {'DOWN', _, process, _, _} -> 10; + {'DOWN', _, _, _, _} -> 10; {'EXIT', _, _} -> 10; session_expired -> 10; {timeout, _, _} -> 5; @@ -275,17 +277,40 @@ prioritise_info(Msg, _Len, _State) -> _ -> 0 end. -handle_call({subscribe, TopicTable0}, _From, Session = #session{client_id = ClientId, - subscriptions = Subscriptions}) -> +handle_call({publish, Msg = #mqtt_message{qos = ?QOS_2, pktid = PktId}}, _From, + Session = #session{client_id = ClientId, + awaiting_rel = AwaitingRel, + await_rel_timeout = Timeout}) -> + case check_awaiting_rel(Session) of + true -> + TRef = timer(Timeout, {timeout, awaiting_rel, PktId}), + AwaitingRel1 = maps:put(PktId, {Msg, TRef}, AwaitingRel), + {reply, ok, Session#session{awaiting_rel = AwaitingRel1}}; + false -> + lager:critical([{client, ClientId}], "Session(~s) dropped Qos2 message " + "for too many awaiting_rel: ~p", [ClientId, Msg]), + {reply, {error, dropped}, Session} + end; - case TopicTable0 -- Subscriptions of +handle_call(Req, _From, State) -> + lager:critical("Unexpected Request: ~p", [Req]), + {reply, ok, State}. + +handle_cast({subscribe, TopicTable0, Callback}, Session = #session{ + client_id = ClientId, subscriptions = Subscriptions}) -> + + TopicTable = emqttd_broker:foldl_hooks('client.subscribe', [ClientId], TopicTable0), + + case TopicTable -- Subscriptions of [] -> - {reply, {ok, [Qos || {_, Qos} <- TopicTable0]}, Session}; + catch Callback([Qos || {_, Qos} <- TopicTable]), + noreply(Session); _ -> - TopicTable = emqttd_broker:foldl_hooks('client.subscribe', [ClientId], TopicTable0), %% subscribe first and don't care if the subscriptions have been existed {ok, GrantedQos} = emqttd_pubsub:subscribe(TopicTable), + catch Callback(GrantedQos), + emqttd_broker:foreach_hooks('client.subscribe.after', [ClientId, TopicTable]), lager:info([{client, ClientId}], "Session(~s): subscribe ~p, Granted QoS: ~p", @@ -310,11 +335,11 @@ handle_call({subscribe, TopicTable0}, _From, Session = #session{client_id = Clie [{Topic, Qos} | Acc] end end, Subscriptions, TopicTable), - {reply, {ok, GrantedQos}, Session#session{subscriptions = Subscriptions1}} + noreply(Session#session{subscriptions = Subscriptions1}) end; -handle_call({unsubscribe, Topics0}, _From, Session = #session{client_id = ClientId, - subscriptions = Subscriptions}) -> +handle_cast({unsubscribe, Topics0}, Session = #session{client_id = ClientId, + subscriptions = Subscriptions}) -> Topics = emqttd_broker:foldl_hooks('client.unsubscribe', [ClientId], Topics0), @@ -333,26 +358,7 @@ handle_call({unsubscribe, Topics0}, _From, Session = #session{client_id = Client end end, Subscriptions, Topics), - {reply, ok, Session#session{subscriptions = Subscriptions1}}; - -handle_call({publish, Msg = #mqtt_message{qos = ?QOS_2, pktid = PktId}}, _From, - Session = #session{client_id = ClientId, - awaiting_rel = AwaitingRel, - await_rel_timeout = Timeout}) -> - case check_awaiting_rel(Session) of - true -> - TRef = timer(Timeout, {timeout, awaiting_rel, PktId}), - AwaitingRel1 = maps:put(PktId, {Msg, TRef}, AwaitingRel), - {reply, ok, Session#session{awaiting_rel = AwaitingRel1}}; - false -> - lager:critical([{client, ClientId}], "Session(~s) dropped Qos2 message " - "for too many awaiting_rel: ~p", [ClientId, Msg]), - {reply, {error, dropped}, Session} - end; - -handle_call(Req, _From, State) -> - lager:critical("Unexpected Request: ~p", [Req]), - {reply, ok, State}. + noreply(Session#session{subscriptions = Subscriptions1}); handle_cast({destroy, ClientId}, Session = #session{client_id = ClientId}) -> lager:warning([{client, ClientId}], "Session(~s) destroyed", [ClientId]), diff --git a/src/emqttd_ws_client.erl b/src/emqttd_ws_client.erl index 6b4a001b2..827d7de18 100644 --- a/src/emqttd_ws_client.erl +++ b/src/emqttd_ws_client.erl @@ -34,7 +34,10 @@ -include("emqttd_protocol.hrl"). %% API Exports --export([start_link/1, ws_loop/3, subscribe/2]). +-export([start_link/1, ws_loop/3, session/1, info/1, kick/1]). + +%% SUB/UNSUB Asynchronously +-export([subscribe/2, unsubscribe/2]). -behaviour(gen_server). @@ -61,9 +64,21 @@ start_link(Req) -> packet_opts = PktOpts, parser = emqttd_parser:new(PktOpts)}). +session(CPid) -> + gen_server:call(CPid, session, infinity). + +info(CPid) -> + gen_server:call(CPid, info, infinity). + +kick(CPid) -> + gen_server:call(CPid, kick). + subscribe(CPid, TopicTable) -> gen_server:cast(CPid, {subscribe, TopicTable}). +unsubscribe(CPid, Topics) -> + gen_server:cast(CPid, {unsubscribe, Topics}). + %%------------------------------------------------------------------------------ %% @private %% @doc Start WebSocket client. @@ -112,17 +127,30 @@ init([WsPid, Req, ReplyChannel, PktOpts]) -> ProtoState = emqttd_protocol:init(Peername, SendFun, [{ws_initial_headers, HeadersList}|PktOpts]), {ok, #client_state{ws_pid = WsPid, request = Req, proto_state = ProtoState}}. +handle_call(session, _From, State = #client_state{proto_state = ProtoState}) -> + {reply, emqttd_protocol:session(ProtoState), State}; + +handle_call(info, _From, State = #client_state{request = Req, + proto_state = ProtoState}) -> + {reply, [{websocket, true}, {peer, Req:get(peer)} + | emqttd_protocol:info(ProtoState)], State}; + +handle_call(kick, _From, State) -> + {stop, {shutdown, kick}, ok, State}; + handle_call(_Req, _From, State) -> {reply, error, State}. -handle_cast({subscribe, TopicTable}, State = #client_state{proto_state = ProtoState}) -> - {ok, ProtoState1} = emqttd_protocol:handle({subscribe, TopicTable}, ProtoState), - {noreply, State#client_state{proto_state = ProtoState1}, hibernate}; +handle_cast({subscribe, TopicTable}, State) -> + with_session(fun(SessPid) -> emqttd_session:subscribe(SessPid, TopicTable) end, State); + +handle_cast({unsubscribe, Topics}, State) -> + with_session(fun(SessPid) -> emqttd_session:unsubscribe(SessPid, Topics) end, State); handle_cast({received, Packet}, State = #client_state{proto_state = ProtoState}) -> case emqttd_protocol:received(Packet, ProtoState) of {ok, ProtoState1} -> - {noreply, State#client_state{proto_state = ProtoState1}}; + noreply(State#client_state{proto_state = ProtoState1}); {error, Error} -> lager:error("MQTT protocol error ~p", [Error]), stop({shutdown, Error}, State); @@ -137,11 +165,11 @@ handle_cast(_Msg, State) -> handle_info({deliver, Message}, State = #client_state{proto_state = ProtoState}) -> {ok, ProtoState1} = emqttd_protocol:send(Message, ProtoState), - {noreply, State#client_state{proto_state = ProtoState1}}; + noreply(State#client_state{proto_state = ProtoState1}); handle_info({redeliver, {?PUBREL, PacketId}}, State = #client_state{proto_state = ProtoState}) -> {ok, ProtoState1} = emqttd_protocol:redeliver({?PUBREL, PacketId}, ProtoState), - {noreply, State#client_state{proto_state = ProtoState1}}; + noreply(State#client_state{proto_state = ProtoState1}); handle_info({stop, duplicate_id, _NewPid}, State = #client_state{proto_state = ProtoState}) -> lager:error("Shutdown for duplicate clientid: ~s", [emqttd_protocol:clientid(ProtoState)]), @@ -149,18 +177,27 @@ handle_info({stop, duplicate_id, _NewPid}, State = #client_state{proto_state = P handle_info({keepalive, start, TimeoutSec}, State = #client_state{request = Req}) -> lager:debug("Client(WebSocket) ~s: Start KeepAlive with ~p seconds", [Req:get(peer), TimeoutSec]), - KeepAlive = emqttd_keepalive:new({esockd_transport, Req:get(socket)}, - TimeoutSec, {keepalive, timeout}), - {noreply, State#client_state{keepalive = KeepAlive}}; + Socket = Req:get(socket), + StatFun = fun() -> + case esockd_transport:getstat(Socket, [recv_oct]) of + {ok, [{recv_oct, RecvOct}]} -> {ok, RecvOct}; + {error, Error} -> {error, Error} + end + end, + KeepAlive = emqttd_keepalive:start(StatFun, TimeoutSec, {keepalive, check}), + noreply(State#client_state{keepalive = KeepAlive}); -handle_info({keepalive, timeout}, State = #client_state{request = Req, keepalive = KeepAlive}) -> - case emqttd_keepalive:resume(KeepAlive) of - timeout -> +handle_info({keepalive, check}, State = #client_state{request = Req, keepalive = KeepAlive}) -> + case emqttd_keepalive:check(KeepAlive) of + {ok, KeepAlive1} -> + lager:debug("Client(WebSocket) ~s: Keepalive Resumed", [Req:get(peer)]), + noreply(State#client_state{keepalive = KeepAlive1}); + {error, timeout} -> lager:debug("Client(WebSocket) ~s: Keepalive Timeout!", [Req:get(peer)]), stop({shutdown, keepalive_timeout}, State#client_state{keepalive = undefined}); - {resumed, KeepAlive1} -> - lager:debug("Client(WebSocket) ~s: Keepalive Resumed", [Req:get(peer)]), - {noreply, State#client_state{keepalive = KeepAlive1}} + {error, Error} -> + lager:debug("Client(WebSocket) ~s: Keepalive Error: ~p", [Req:get(peer), Error]), + stop({shutdown, keepalive_error}, State#client_state{keepalive = undefined}) end; handle_info({'EXIT', WsPid, Reason}, State = #client_state{ws_pid = WsPid, proto_state = ProtoState}) -> @@ -170,7 +207,7 @@ handle_info({'EXIT', WsPid, Reason}, State = #client_state{ws_pid = WsPid, proto handle_info(Info, State = #client_state{request = Req}) -> lager:critical("Client(WebSocket) ~s: Unexpected Info - ~p", [Req:get(peer), Info]), - {noreply, State}. + noreply(State). terminate(Reason, #client_state{proto_state = ProtoState, keepalive = KeepAlive}) -> lager:info("WebSocket client terminated: ~p", [Reason]), @@ -189,6 +226,12 @@ code_change(_OldVsn, State, _Extra) -> %%% Internal functions %%%============================================================================= +noreply(State) -> + {noreply, State, hibernate}. + stop(Reason, State ) -> {stop, Reason, State}. +with_session(Fun, State = #client_state{proto_state = ProtoState}) -> + Fun(emqttd_protocol:session(ProtoState)), noreply(State). +