fix issue #292 - async sub/unsub

This commit is contained in:
Feng 2015-10-04 19:48:50 +08:00
parent 9f643ea267
commit d5a400c308
4 changed files with 159 additions and 94 deletions

View File

@ -34,7 +34,10 @@
-include("emqttd_protocol.hrl").
%% API Function Exports
-export([start_link/2, session/1, info/1, kick/1, subscribe/2]).
-export([start_link/2, session/1, info/1, kick/1]).
%% SUB/UNSUB Asynchronously
-export([subscribe/2, unsubscribe/2]).
-behaviour(gen_server).
@ -59,7 +62,7 @@ start_link(SockArgs, MqttEnv) ->
{ok, proc_lib:spawn_link(?MODULE, init, [[SockArgs, MqttEnv]])}.
session(CPid) ->
gen_server:call(CPid, session).
gen_server:call(CPid, session, infinity).
info(CPid) ->
gen_server:call(CPid, info, infinity).
@ -70,6 +73,9 @@ kick(CPid) ->
subscribe(CPid, TopicTable) ->
gen_server:cast(CPid, {subscribe, TopicTable}).
unsubscribe(CPid, Topics) ->
gen_server:cast(CPid, {unsubscribe, Topics}).
init([SockArgs = {Transport, Sock, _SockFun}, MqttEnv]) ->
% Transform if ssl.
{ok, NewSock} = esockd_connection:accept(SockArgs),
@ -107,9 +113,11 @@ handle_call(Req, _From, State = #state{peername = Peername}) ->
lager:critical("Client ~s: unexpected request - ~p", [emqttd_net:format(Peername), Req]),
{reply, {error, unsupported_request}, State}.
handle_cast({subscribe, TopicTable}, State = #state{proto_state = ProtoState}) ->
{ok, ProtoState1} = emqttd_protocol:handle({subscribe, TopicTable}, ProtoState),
noreply(State#state{proto_state = ProtoState1});
handle_cast({subscribe, TopicTable}, State) ->
with_session(fun(SessPid) -> emqttd_session:subscribe(SessPid, TopicTable) end, State);
handle_cast({unsubscribe, Topics}, State) ->
with_session(fun(SessPid) -> emqttd_session:unsubscribe(SessPid, Topics) end, State);
handle_cast(Msg, State = #state{peername = Peername}) ->
lager:critical("Client ~s: unexpected msg - ~p",[emqttd_net:format(Peername), Msg]),
@ -149,17 +157,26 @@ handle_info({inet_reply, _Sock, {error, Reason}}, State = #state{peername = Peer
handle_info({keepalive, start, TimeoutSec}, State = #state{transport = Transport, socket = Socket, peername = Peername}) ->
lager:debug("Client ~s: Start KeepAlive with ~p seconds", [emqttd_net:format(Peername), TimeoutSec]),
KeepAlive = emqttd_keepalive:new({Transport, Socket}, TimeoutSec, {keepalive, timeout}),
StatFun = fun() ->
case Transport:getstat(Socket, [recv_oct]) of
{ok, [{recv_oct, RecvOct}]} -> {ok, RecvOct};
{error, Error} -> {error, Error}
end
end,
KeepAlive = emqttd_keepalive:start(StatFun, TimeoutSec, {keepalive, check}),
noreply(State#state{keepalive = KeepAlive});
handle_info({keepalive, timeout}, State = #state{peername = Peername, keepalive = KeepAlive}) ->
case emqttd_keepalive:resume(KeepAlive) of
timeout ->
handle_info({keepalive, check}, State = #state{peername = Peername, keepalive = KeepAlive}) ->
case emqttd_keepalive:check(KeepAlive) of
{ok, KeepAlive1} ->
lager:debug("Client ~s: Keepalive Resumed", [emqttd_net:format(Peername)]),
noreply(State#state{keepalive = KeepAlive1});
{error, timeout} ->
lager:debug("Client ~s: Keepalive Timeout!", [emqttd_net:format(Peername)]),
stop({shutdown, keepalive_timeout}, State#state{keepalive = undefined});
{resumed, KeepAlive1} ->
lager:debug("Client ~s: Keepalive Resumed", [emqttd_net:format(Peername)]),
noreply(State#state{keepalive = KeepAlive1})
{error, Error} ->
lager:debug("Client ~s: Keepalive Error: ~p!", [emqttd_net:format(Peername), Error]),
stop({shutdown, keepalive_error}, State#state{keepalive = undefined})
end;
handle_info(Info, State = #state{peername = Peername}) ->
@ -188,12 +205,20 @@ terminate(Reason, #state{peername = Peername,
code_change(_OldVsn, State, _Extra) ->
{ok, State}.
%%%=============================================================================
%%% Internal functions
%%%=============================================================================
noreply(State) ->
{noreply, State, hibernate}.
%-------------------------------------------------------
% receive and parse tcp data
%-------------------------------------------------------
stop(Reason, State) ->
{stop, Reason, State}.
with_session(Fun, State = #state{proto_state = ProtoState}) ->
Fun(emqttd_protocol:session(ProtoState)), noreply(State).
%% receive and parse tcp data
received(<<>>, State) ->
{noreply, State, hibernate};
@ -244,12 +269,8 @@ control_throttle(State = #state{conn_state = Flow,
{_, _} -> run_socket(State)
end.
stop(Reason, State) ->
{stop, Reason, State}.
received_stats(?PACKET(Type)) ->
emqttd_metrics:inc('packets/received'),
inc(Type).
emqttd_metrics:inc('packets/received'), inc(Type).
inc(?CONNECT) ->
emqttd_metrics:inc('packets/connect');
inc(?PUBLISH) ->

View File

@ -239,16 +239,11 @@ handle(?SUBSCRIBE_PACKET(PacketId, TopicTable), State = #proto_state{client_id =
case lists:member(deny, AllowDenies) of
true ->
%%TODO: return 128 QoS when deny... no need to SUBACK?
lager:error("SUBSCRIBE from '~s' Denied: ~p", [ClientId, TopicTable]),
{ok, State};
lager:error("SUBSCRIBE from '~s' Denied: ~p", [ClientId, TopicTable]);
false ->
%%TODO: GrantedQos should be renamed.
{ok, GrantedQos} = emqttd_session:subscribe(Session, TopicTable),
send(?SUBACK_PACKET(PacketId, GrantedQos), State)
end;
handle({subscribe, TopicTable}, State = #proto_state{session = Session}) ->
{ok, _GrantedQos} = emqttd_session:subscribe(Session, TopicTable),
Callback = fun(GrantedQos) -> send(?SUBACK_PACKET(PacketId, GrantedQos), State) end,
emqttd_session:subscribe(Session, TopicTable, Callback)
end,
{ok, State};
%% protect from empty topic list
@ -256,7 +251,7 @@ handle(?UNSUBSCRIBE_PACKET(PacketId, []), State) ->
send(?UNSUBACK_PACKET(PacketId), State);
handle(?UNSUBSCRIBE_PACKET(PacketId, Topics), State = #proto_state{session = Session}) ->
ok = emqttd_session:unsubscribe(Session, Topics),
emqttd_session:unsubscribe(Session, Topics),
send(?UNSUBACK_PACKET(PacketId), State);
handle(?PACKET(?PINGREQ), State) ->
@ -349,7 +344,7 @@ send_willmsg(ClientId, WillMsg) ->
start_keepalive(0) -> ignore;
start_keepalive(Sec) when Sec > 0 ->
self() ! {keepalive, start, round(Sec * 1.5)}.
self() ! {keepalive, start, round(Sec * 1.2)}.
%%----------------------------------------------------------------------------
%% Validate Packets

View File

@ -59,7 +59,7 @@
%% PubSub APIs
-export([publish/2,
puback/2, pubrec/2, pubrel/2, pubcomp/2,
subscribe/2, unsubscribe/2]).
subscribe/2, subscribe/3, unsubscribe/2]).
-behaviour(gen_server2).
@ -166,9 +166,13 @@ destroy(SessPid, ClientId) ->
%% @doc Subscribe Topics
%% @end
%%------------------------------------------------------------------------------
-spec subscribe(pid(), [{binary(), mqtt_qos()}]) -> {ok, [mqtt_qos()]}.
-spec subscribe(pid(), [{binary(), mqtt_qos()}]) -> ok.
subscribe(SessPid, TopicTable) ->
gen_server2:call(SessPid, {subscribe, TopicTable}, ?PUBSUB_TIMEOUT).
subscribe(SessPid, TopicTable, fun(_) -> ok end).
-spec subscribe(pid(), [{binary(), mqtt_qos()}], Callback :: fun()) -> ok.
subscribe(SessPid, TopicTable, Callback) ->
gen_server2:cast(SessPid, {subscribe, TopicTable, Callback}).
%%------------------------------------------------------------------------------
%% @doc Publish message
@ -213,7 +217,7 @@ pubcomp(SessPid, PktId) ->
%%------------------------------------------------------------------------------
-spec unsubscribe(pid(), [binary()]) -> ok.
unsubscribe(SessPid, Topics) ->
gen_server2:call(SessPid, {unsubscribe, Topics}, ?PUBSUB_TIMEOUT).
gen_server2:cast(SessPid, {unsubscribe, Topics}).
%%%=============================================================================
%%% gen_server callbacks
@ -247,26 +251,24 @@ init([CleanSess, ClientId, ClientPid]) ->
{ok, start_collector(Session#session{client_mon = MRef}), hibernate}.
prioritise_call(Msg, _From, _Len, _State) ->
case Msg of
{unsubscribe, _} -> 2;
{subscribe, _} -> 1;
_ -> 0
end.
case Msg of _ -> 0 end.
prioritise_cast(Msg, _Len, _State) ->
case Msg of
{destroy, _} -> 10;
{resume, _, _} -> 9;
{pubrel, _PktId} -> 8;
{pubcomp, _PktId} -> 8;
{pubrec, _PktId} -> 8;
{puback, _PktId} -> 7;
_ -> 0
{destroy, _} -> 10;
{resume, _, _} -> 9;
{pubrel, _PktId} -> 8;
{pubcomp, _PktId} -> 8;
{pubrec, _PktId} -> 8;
{puback, _PktId} -> 7;
{unsubscribe, _, _} -> 6;
{subscribe, _, _} -> 5;
_ -> 0
end.
prioritise_info(Msg, _Len, _State) ->
case Msg of
{'DOWN', _, process, _, _} -> 10;
{'DOWN', _, _, _, _} -> 10;
{'EXIT', _, _} -> 10;
session_expired -> 10;
{timeout, _, _} -> 5;
@ -275,17 +277,40 @@ prioritise_info(Msg, _Len, _State) ->
_ -> 0
end.
handle_call({subscribe, TopicTable0}, _From, Session = #session{client_id = ClientId,
subscriptions = Subscriptions}) ->
handle_call({publish, Msg = #mqtt_message{qos = ?QOS_2, pktid = PktId}}, _From,
Session = #session{client_id = ClientId,
awaiting_rel = AwaitingRel,
await_rel_timeout = Timeout}) ->
case check_awaiting_rel(Session) of
true ->
TRef = timer(Timeout, {timeout, awaiting_rel, PktId}),
AwaitingRel1 = maps:put(PktId, {Msg, TRef}, AwaitingRel),
{reply, ok, Session#session{awaiting_rel = AwaitingRel1}};
false ->
lager:critical([{client, ClientId}], "Session(~s) dropped Qos2 message "
"for too many awaiting_rel: ~p", [ClientId, Msg]),
{reply, {error, dropped}, Session}
end;
case TopicTable0 -- Subscriptions of
handle_call(Req, _From, State) ->
lager:critical("Unexpected Request: ~p", [Req]),
{reply, ok, State}.
handle_cast({subscribe, TopicTable0, Callback}, Session = #session{
client_id = ClientId, subscriptions = Subscriptions}) ->
TopicTable = emqttd_broker:foldl_hooks('client.subscribe', [ClientId], TopicTable0),
case TopicTable -- Subscriptions of
[] ->
{reply, {ok, [Qos || {_, Qos} <- TopicTable0]}, Session};
catch Callback([Qos || {_, Qos} <- TopicTable]),
noreply(Session);
_ ->
TopicTable = emqttd_broker:foldl_hooks('client.subscribe', [ClientId], TopicTable0),
%% subscribe first and don't care if the subscriptions have been existed
{ok, GrantedQos} = emqttd_pubsub:subscribe(TopicTable),
catch Callback(GrantedQos),
emqttd_broker:foreach_hooks('client.subscribe.after', [ClientId, TopicTable]),
lager:info([{client, ClientId}], "Session(~s): subscribe ~p, Granted QoS: ~p",
@ -310,11 +335,11 @@ handle_call({subscribe, TopicTable0}, _From, Session = #session{client_id = Clie
[{Topic, Qos} | Acc]
end
end, Subscriptions, TopicTable),
{reply, {ok, GrantedQos}, Session#session{subscriptions = Subscriptions1}}
noreply(Session#session{subscriptions = Subscriptions1})
end;
handle_call({unsubscribe, Topics0}, _From, Session = #session{client_id = ClientId,
subscriptions = Subscriptions}) ->
handle_cast({unsubscribe, Topics0}, Session = #session{client_id = ClientId,
subscriptions = Subscriptions}) ->
Topics = emqttd_broker:foldl_hooks('client.unsubscribe', [ClientId], Topics0),
@ -333,26 +358,7 @@ handle_call({unsubscribe, Topics0}, _From, Session = #session{client_id = Client
end
end, Subscriptions, Topics),
{reply, ok, Session#session{subscriptions = Subscriptions1}};
handle_call({publish, Msg = #mqtt_message{qos = ?QOS_2, pktid = PktId}}, _From,
Session = #session{client_id = ClientId,
awaiting_rel = AwaitingRel,
await_rel_timeout = Timeout}) ->
case check_awaiting_rel(Session) of
true ->
TRef = timer(Timeout, {timeout, awaiting_rel, PktId}),
AwaitingRel1 = maps:put(PktId, {Msg, TRef}, AwaitingRel),
{reply, ok, Session#session{awaiting_rel = AwaitingRel1}};
false ->
lager:critical([{client, ClientId}], "Session(~s) dropped Qos2 message "
"for too many awaiting_rel: ~p", [ClientId, Msg]),
{reply, {error, dropped}, Session}
end;
handle_call(Req, _From, State) ->
lager:critical("Unexpected Request: ~p", [Req]),
{reply, ok, State}.
noreply(Session#session{subscriptions = Subscriptions1});
handle_cast({destroy, ClientId}, Session = #session{client_id = ClientId}) ->
lager:warning([{client, ClientId}], "Session(~s) destroyed", [ClientId]),

View File

@ -34,7 +34,10 @@
-include("emqttd_protocol.hrl").
%% API Exports
-export([start_link/1, ws_loop/3, subscribe/2]).
-export([start_link/1, ws_loop/3, session/1, info/1, kick/1]).
%% SUB/UNSUB Asynchronously
-export([subscribe/2, unsubscribe/2]).
-behaviour(gen_server).
@ -61,9 +64,21 @@ start_link(Req) ->
packet_opts = PktOpts,
parser = emqttd_parser:new(PktOpts)}).
session(CPid) ->
gen_server:call(CPid, session, infinity).
info(CPid) ->
gen_server:call(CPid, info, infinity).
kick(CPid) ->
gen_server:call(CPid, kick).
subscribe(CPid, TopicTable) ->
gen_server:cast(CPid, {subscribe, TopicTable}).
unsubscribe(CPid, Topics) ->
gen_server:cast(CPid, {unsubscribe, Topics}).
%%------------------------------------------------------------------------------
%% @private
%% @doc Start WebSocket client.
@ -112,17 +127,30 @@ init([WsPid, Req, ReplyChannel, PktOpts]) ->
ProtoState = emqttd_protocol:init(Peername, SendFun, [{ws_initial_headers, HeadersList}|PktOpts]),
{ok, #client_state{ws_pid = WsPid, request = Req, proto_state = ProtoState}}.
handle_call(session, _From, State = #client_state{proto_state = ProtoState}) ->
{reply, emqttd_protocol:session(ProtoState), State};
handle_call(info, _From, State = #client_state{request = Req,
proto_state = ProtoState}) ->
{reply, [{websocket, true}, {peer, Req:get(peer)}
| emqttd_protocol:info(ProtoState)], State};
handle_call(kick, _From, State) ->
{stop, {shutdown, kick}, ok, State};
handle_call(_Req, _From, State) ->
{reply, error, State}.
handle_cast({subscribe, TopicTable}, State = #client_state{proto_state = ProtoState}) ->
{ok, ProtoState1} = emqttd_protocol:handle({subscribe, TopicTable}, ProtoState),
{noreply, State#client_state{proto_state = ProtoState1}, hibernate};
handle_cast({subscribe, TopicTable}, State) ->
with_session(fun(SessPid) -> emqttd_session:subscribe(SessPid, TopicTable) end, State);
handle_cast({unsubscribe, Topics}, State) ->
with_session(fun(SessPid) -> emqttd_session:unsubscribe(SessPid, Topics) end, State);
handle_cast({received, Packet}, State = #client_state{proto_state = ProtoState}) ->
case emqttd_protocol:received(Packet, ProtoState) of
{ok, ProtoState1} ->
{noreply, State#client_state{proto_state = ProtoState1}};
noreply(State#client_state{proto_state = ProtoState1});
{error, Error} ->
lager:error("MQTT protocol error ~p", [Error]),
stop({shutdown, Error}, State);
@ -137,11 +165,11 @@ handle_cast(_Msg, State) ->
handle_info({deliver, Message}, State = #client_state{proto_state = ProtoState}) ->
{ok, ProtoState1} = emqttd_protocol:send(Message, ProtoState),
{noreply, State#client_state{proto_state = ProtoState1}};
noreply(State#client_state{proto_state = ProtoState1});
handle_info({redeliver, {?PUBREL, PacketId}}, State = #client_state{proto_state = ProtoState}) ->
{ok, ProtoState1} = emqttd_protocol:redeliver({?PUBREL, PacketId}, ProtoState),
{noreply, State#client_state{proto_state = ProtoState1}};
noreply(State#client_state{proto_state = ProtoState1});
handle_info({stop, duplicate_id, _NewPid}, State = #client_state{proto_state = ProtoState}) ->
lager:error("Shutdown for duplicate clientid: ~s", [emqttd_protocol:clientid(ProtoState)]),
@ -149,18 +177,27 @@ handle_info({stop, duplicate_id, _NewPid}, State = #client_state{proto_state = P
handle_info({keepalive, start, TimeoutSec}, State = #client_state{request = Req}) ->
lager:debug("Client(WebSocket) ~s: Start KeepAlive with ~p seconds", [Req:get(peer), TimeoutSec]),
KeepAlive = emqttd_keepalive:new({esockd_transport, Req:get(socket)},
TimeoutSec, {keepalive, timeout}),
{noreply, State#client_state{keepalive = KeepAlive}};
Socket = Req:get(socket),
StatFun = fun() ->
case esockd_transport:getstat(Socket, [recv_oct]) of
{ok, [{recv_oct, RecvOct}]} -> {ok, RecvOct};
{error, Error} -> {error, Error}
end
end,
KeepAlive = emqttd_keepalive:start(StatFun, TimeoutSec, {keepalive, check}),
noreply(State#client_state{keepalive = KeepAlive});
handle_info({keepalive, timeout}, State = #client_state{request = Req, keepalive = KeepAlive}) ->
case emqttd_keepalive:resume(KeepAlive) of
timeout ->
handle_info({keepalive, check}, State = #client_state{request = Req, keepalive = KeepAlive}) ->
case emqttd_keepalive:check(KeepAlive) of
{ok, KeepAlive1} ->
lager:debug("Client(WebSocket) ~s: Keepalive Resumed", [Req:get(peer)]),
noreply(State#client_state{keepalive = KeepAlive1});
{error, timeout} ->
lager:debug("Client(WebSocket) ~s: Keepalive Timeout!", [Req:get(peer)]),
stop({shutdown, keepalive_timeout}, State#client_state{keepalive = undefined});
{resumed, KeepAlive1} ->
lager:debug("Client(WebSocket) ~s: Keepalive Resumed", [Req:get(peer)]),
{noreply, State#client_state{keepalive = KeepAlive1}}
{error, Error} ->
lager:debug("Client(WebSocket) ~s: Keepalive Error: ~p", [Req:get(peer), Error]),
stop({shutdown, keepalive_error}, State#client_state{keepalive = undefined})
end;
handle_info({'EXIT', WsPid, Reason}, State = #client_state{ws_pid = WsPid, proto_state = ProtoState}) ->
@ -170,7 +207,7 @@ handle_info({'EXIT', WsPid, Reason}, State = #client_state{ws_pid = WsPid, proto
handle_info(Info, State = #client_state{request = Req}) ->
lager:critical("Client(WebSocket) ~s: Unexpected Info - ~p", [Req:get(peer), Info]),
{noreply, State}.
noreply(State).
terminate(Reason, #client_state{proto_state = ProtoState, keepalive = KeepAlive}) ->
lager:info("WebSocket client terminated: ~p", [Reason]),
@ -189,6 +226,12 @@ code_change(_OldVsn, State, _Extra) ->
%%% Internal functions
%%%=============================================================================
noreply(State) ->
{noreply, State, hibernate}.
stop(Reason, State ) ->
{stop, Reason, State}.
with_session(Fun, State = #client_state{proto_state = ProtoState}) ->
Fun(emqttd_protocol:session(ProtoState)), noreply(State).