From 517c7eb7b6b7c5826f980a3e869279a6d44ceb4c Mon Sep 17 00:00:00 2001 From: Feng Lee Date: Thu, 11 Jun 2015 23:34:53 +0800 Subject: [PATCH] session upgrade --- apps/emqttd/src/emqttd_protocol.erl | 2 +- apps/emqttd/src/emqttd_session.erl | 205 ++++++++++++++++++---------- rel/files/emqttd.config | 4 +- 3 files changed, 134 insertions(+), 77 deletions(-) diff --git a/apps/emqttd/src/emqttd_protocol.erl b/apps/emqttd/src/emqttd_protocol.erl index d519a182e..37e19f16e 100644 --- a/apps/emqttd/src/emqttd_protocol.erl +++ b/apps/emqttd/src/emqttd_protocol.erl @@ -272,7 +272,7 @@ send({_From = SessPid, Message}, State = #proto_state{session = SessPid}) when i %% message(qos1, qos2) not from session send({_From, Message = #mqtt_message{qos = Qos}}, State = #proto_state{session = Session}) when (Qos =:= ?QOS_1) orelse (Qos =:= ?QOS_2) -> - {Message1, NewSession} = emqttd_session:store(Session, Message), + {Message1, NewSession} = emqttd_session:await_ack(Session, Message), send(emqtt_message:to_packet(Message1), State#proto_state{session = NewSession}); send(Packet, State = #proto_state{sendfun = SendFun, peername = Peername}) when is_record(Packet, mqtt_packet) -> diff --git a/apps/emqttd/src/emqttd_session.erl b/apps/emqttd/src/emqttd_session.erl index 8350b89b8..6b6d961a4 100644 --- a/apps/emqttd/src/emqttd_session.erl +++ b/apps/emqttd/src/emqttd_session.erl @@ -24,6 +24,7 @@ %%% %%% @end %%%----------------------------------------------------------------------------- + -module(emqttd_session). -author("Feng Lee "). @@ -53,7 +54,7 @@ -export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]). --record(session_state, { +-record(session, { %% ClientId: Identifier of Session clientid :: binary(), @@ -100,14 +101,17 @@ %% 4, 8, 16 seconds if 3 retries:) unack_retry_after = 4, - %% session expired - sess_expired_after = 48, + %% Awaiting PUBREL timeout + await_rel_timeout = 8, + + %% session expired after 48 hours + sess_expired_after = 172800, sess_expired_timer, - timestamp}). + timestamp }). --type session() :: #session_state{} | pid(). +-type session() :: #session{} | pid(). %%%============================================================================= %%% Session API @@ -132,7 +136,7 @@ start({false = _CleanSess, ClientId, ClientPid}) -> %% @end %%------------------------------------------------------------------------------ -spec resume(session(), binary(), pid()) -> session(). -resume(SessState = #session_state{}, _ClientId, _ClientPid) -> +resume(SessState = #session{}, _ClientId, _ClientPid) -> SessState; resume(SessPid, ClientId, ClientPid) when is_pid(SessPid) -> gen_server:cast(SessPid, {resume, ClientId, ClientPid}), @@ -149,10 +153,12 @@ publish(Session, ClientId, {?QOS_0, Message}) -> publish(Session, ClientId, {?QOS_1, Message}) -> emqttd_pubsub:publish(ClientId, Message), Session; -publish(SessState = #session_state{awaiting_rel = AwaitingRel}, _ClientId, +publish(SessState = #session{awaiting_rel = AwaitingRel, + await_rel_timeout = Timeout}, _ClientId, {?QOS_2, Message = #mqtt_message{msgid = MsgId}}) -> %% store in awaiting_rel - SessState#session_state{awaiting_rel = maps:put(MsgId, Message, AwaitingRel)}; + TRef = erlang:send_after(Timeout * 1000, self(), {timeout, awaiting_rel, MsgId}), + SessState#session{awaiting_rel = maps:put(MsgId, {Message, TRef}, AwaitingRel)}; publish(SessPid, ClientId, {?QOS_2, Message}) when is_pid(SessPid) -> gen_server:cast(SessPid, {publish, ClientId, {?QOS_2, Message}}), @@ -163,59 +169,72 @@ publish(SessPid, ClientId, {?QOS_2, Message}) when is_pid(SessPid) -> %% @end %%------------------------------------------------------------------------------ -spec puback(session(), {mqtt_packet_type(), mqtt_packet_id()}) -> session(). -puback(SessState = #session_state{clientid = ClientId, awaiting_ack = Awaiting}, {?PUBACK, PacketId}) -> +puback(SessState = #session{clientid = ClientId, awaiting_ack = Awaiting}, {?PUBACK, PacketId}) -> case maps:is_key(PacketId, Awaiting) of true -> ok; false -> lager:warning("Session ~s: PUBACK PacketId '~p' not found!", [ClientId, PacketId]) end, - SessState#session_state{awaiting_ack = maps:remove(PacketId, Awaiting)}; + SessState#session{awaiting_ack = maps:remove(PacketId, Awaiting)}; puback(SessPid, {?PUBACK, PacketId}) when is_pid(SessPid) -> gen_server:cast(SessPid, {puback, PacketId}), SessPid; %% PUBREC -puback(SessState = #session_state{clientid = ClientId, +puback(SessState = #session{clientid = ClientId, awaiting_ack = AwaitingAck, awaiting_comp = AwaitingComp}, {?PUBREC, PacketId}) -> case maps:is_key(PacketId, AwaitingAck) of true -> ok; false -> lager:warning("Session ~s: PUBREC PacketId '~p' not found!", [ClientId, PacketId]) end, - SessState#session_state{awaiting_ack = maps:remove(PacketId, AwaitingAck), - awaiting_comp = maps:put(PacketId, true, AwaitingComp)}; + SessState#session{awaiting_ack = maps:remove(PacketId, AwaitingAck), + awaiting_comp = maps:put(PacketId, true, AwaitingComp)}; puback(SessPid, {?PUBREC, PacketId}) when is_pid(SessPid) -> gen_server:cast(SessPid, {pubrec, PacketId}), SessPid; %% PUBREL -puback(SessState = #session_state{clientid = ClientId, - awaiting_rel = Awaiting}, {?PUBREL, PacketId}) -> +puback(SessState = #session{clientid = ClientId, + awaiting_rel = Awaiting}, {?PUBREL, PacketId}) -> case maps:find(PacketId, Awaiting) of - {ok, Msg} -> emqttd_pubsub:publish(ClientId, Msg); - error -> lager:warning("Session ~s: PUBREL PacketId '~p' not found!", [ClientId, PacketId]) + {ok, {Msg, TRef}} -> + catch erlang:cancel_timer(TRef), + emqttd_pubsub:publish(ClientId, Msg); + error -> + lager:error("Session ~s PUBREL PacketId '~p' not found!", [ClientId, PacketId]) end, - SessState#session_state{awaiting_rel = maps:remove(PacketId, Awaiting)}; + SessState#session{awaiting_rel = maps:remove(PacketId, Awaiting)}; puback(SessPid, {?PUBREL, PacketId}) when is_pid(SessPid) -> gen_server:cast(SessPid, {pubrel, PacketId}), SessPid; %% PUBCOMP -puback(SessState = #session_state{clientid = ClientId, +puback(SessState = #session{clientid = ClientId, awaiting_comp = AwaitingComp}, {?PUBCOMP, PacketId}) -> case maps:is_key(PacketId, AwaitingComp) of true -> ok; false -> lager:warning("Session ~s: PUBREC PacketId '~p' not exist", [ClientId, PacketId]) end, - SessState#session_state{awaiting_comp = maps:remove(PacketId, AwaitingComp)}; + SessState#session{awaiting_comp = maps:remove(PacketId, AwaitingComp)}; puback(SessPid, {?PUBCOMP, PacketId}) when is_pid(SessPid) -> gen_server:cast(SessPid, {pubcomp, PacketId}), SessPid. +timeout(awaiting_rel, MsgId, SessState = #session{clientid = ClientId, awaiting_rel = Awaiting}) -> + case maps:find(MsgId, Awaiting) of + {ok, {Msg, _TRef}} -> + lager:error([{client, ClientId}], "Session ~s Awaiting Rel Timout!~nDrop Message:~p", [ClientId, Msg]), + SessState#session{awaiting_rel = maps:remove(MsgId, Awaiting)}; + error -> + lager:error([{client, ClientId}], "Session ~s Cannot find Awaiting Rel: MsgId=~p", [ClientId, MsgId]), + SessState + end. + %%------------------------------------------------------------------------------ %% @doc Subscribe Topics %% @end %%------------------------------------------------------------------------------ -spec subscribe(session(), [{binary(), mqtt_qos()}]) -> {ok, session(), [mqtt_qos()]}. -subscribe(SessState = #session_state{clientid = ClientId, subscriptions = Subscriptions}, Topics) -> +subscribe(SessState = #session{clientid = ClientId, subscriptions = Subscriptions}, Topics) -> %% subscribe first and don't care if the subscriptions have been existed {ok, GrantedQos} = emqttd_pubsub:subscribe(Topics), @@ -242,7 +261,7 @@ subscribe(SessState = #session_state{clientid = ClientId, subscriptions = Subscr end end, Subscriptions, Topics), - {ok, SessState#session_state{subscriptions = Subscriptions1}, GrantedQos}; + {ok, SessState#session{subscriptions = Subscriptions1}, GrantedQos}; subscribe(SessPid, Topics) when is_pid(SessPid) -> {ok, GrantedQos} = gen_server:call(SessPid, {subscribe, Topics}), @@ -253,17 +272,23 @@ subscribe(SessPid, Topics) when is_pid(SessPid) -> %% @end %%------------------------------------------------------------------------------ -spec unsubscribe(session(), [binary()]) -> {ok, session()}. -unsubscribe(SessState = #session_state{clientid = ClientId, subscriptions = Subscriptions}, Topics) -> - %%TODO: refactor later. - case Topics -- maps:keys(SubMap) of - [] -> ok; - BadUnsubs -> lager:warning("~s should not unsubscribe ~p", [ClientId, BadUnsubs]) - end, +unsubscribe(SessState = #session{clientid = ClientId, subscriptions = Subscriptions}, Topics) -> + %%unsubscribe from topic tree ok = emqttd_pubsub:unsubscribe(Topics), lager:info([{client, ClientId}], "Client ~s unsubscribe ~p.", [ClientId, Topics]), - SubMap1 = lists:foldl(fun(Topic, Acc) -> maps:remove(Topic, Acc) end, SubMap, Topics), - {ok, SessState#session_state{submap = SubMap1}}; + + Subscriptions1 = + lists:foldl(fun(Topic, Acc) -> + case lists:keyfind(Topic, 1, Acc) of + {Topic, _Qos} -> + lists:keydelete(Topic, 1, Acc); + false -> + lager:warning([{client, ClientId}], "~s not subscribe ~s", [ClientId, Topic]), Acc + end + end, Subscriptions, Topics), + + {ok, SessState#session{subscriptions = Subscriptions1}}; unsubscribe(SessPid, Topics) when is_pid(SessPid) -> gen_server:call(SessPid, {unsubscribe, Topics}), @@ -277,31 +302,45 @@ unsubscribe(SessPid, Topics) when is_pid(SessPid) -> destroy(SessPid, ClientId) when is_pid(SessPid) -> gen_server:cast(SessPid, {destroy, ClientId}). -%store message(qos1) that sent to client -store(SessState = #session_state{message_id = MsgId, awaiting_ack = Awaiting}, - Message = #mqtt_message{qos = Qos}) when (Qos =:= ?QOS_1) orelse (Qos =:= ?QOS_2) -> +% message(qos1) is awaiting ack +await_ack(Msg = #mqtt_message{qos = ?QOS_1}, SessState = #session{message_id = MsgId, + inflight_queue = InflightQ, + awaiting_ack = Awaiting, + unack_retry_after = Time, + max_unack_retries = Retries}) -> + %% assign msgid before send + Msg1 = Msg#mqtt_message{msgid = MsgId}, + TRef = erlang:send_after(Time * 1000, self(), {retry, MsgId}), + Awaiting1 = maps:put(MsgId, {TRef, Retries, Time}, Awaiting), + {Msg1, next_msgid(SessState#session{inflight_queue = [{MsgId, Msg1} | InflightQ], + awaiting_ack = Awaiting1})}. + +% message(qos2) is awaiting ack +await_ack(Message = #mqtt_message{qos = Qos}, SessState = #session{message_id = MsgId, awaiting_ack = Awaiting},) + when (Qos =:= ?QOS_1) orelse (Qos =:= ?QOS_2) -> %%assign msgid before send - Message1 = Message#mqtt_message{msgid = MsgId}, + Message1 = Message#mqtt_message{msgid = MsgId, dup = false}, Message2 = if Qos =:= ?QOS_2 -> Message1#mqtt_message{dup = false}; true -> Message1 end, Awaiting1 = maps:put(MsgId, Message2, Awaiting), - {Message1, next_msg_id(SessState#session_state{awaiting_ack = Awaiting1})}. + {Message1, next_msgid(SessState#session{awaiting_ack = Awaiting1})}. initial_state(ClientId) -> - #session_state{clientid = ClientId, - subscriptions = [], - inflight_queue = [], - awaiting_queue = [], - awaiting_ack = #{}, - awaiting_rel = #{}, - awaiting_comp = #{}}. + %%TODO: init session options. + #session{clientid = ClientId, + subscriptions = [], + inflight_queue = [], + awaiting_queue = [], + awaiting_ack = #{}, + awaiting_rel = #{}, + awaiting_comp = #{}}. initial_state(ClientId, ClientPid) -> State = initial_state(ClientId), - State#session_state{client_pid = ClientPid}. + State#session{client_pid = ClientPid}. %%------------------------------------------------------------------------------ %% @doc Start a session process. @@ -319,7 +358,7 @@ init([ClientId, ClientPid]) -> true = link(ClientPid), State = initial_state(ClientId, ClientPid), MQueue = emqttd_mqueue:new(ClientId, emqttd:env(mqtt, queue)), - State1 = State#session_state{pending_queue = MQueue, + State1 = State#session{pending_queue = MQueue, timestamp = os:timestamp()}, {ok, init(emqttd:env(mqtt, session), State1), hibernate}. @@ -328,19 +367,23 @@ init([], State) -> %% Session expired after hours init([{expired_after, Hours} | Opts], State) -> - init(Opts, State#session_state{sess_expired_after = Hours * 3600 * 1000}); + init(Opts, State#session{sess_expired_after = Hours * 3600}); %% Max number of QoS 1 and 2 messages that can be “inflight” at one time. init([{max_inflight_messages, MaxInflight} | Opts], State) -> - init(Opts, State#session_state{inflight_window = MaxInflight}); + init(Opts, State#session{inflight_window = MaxInflight}); %% Max retries for unacknolege Qos1/2 messages init([{max_unack_retries, Retries} | Opts], State) -> - init(Opts, State#session_state{max_unack_retries = Retries}); + init(Opts, State#session{max_unack_retries = Retries}); %% Retry after 4, 8, 16 seconds init([{unack_retry_after, Secs} | Opts], State) -> - init(Opts, State#session_state{unack_retry_after = Secs * 1000}); + init(Opts, State#session{unack_retry_after = Secs}); + +%% Awaiting PUBREL timeout +init([{await_rel_timeout, Secs} | Opts], State) -> + init(Opts, State#session{await_rel_timeout = Secs}); init([Opt | Opts], State) -> lager:error("Bad Session Option: ~p", [Opt]), @@ -358,7 +401,7 @@ handle_call(Req, _From, State) -> lager:error("Unexpected request: ~p", [Req]), {reply, error, State}. -handle_cast({resume, ClientId, ClientPid}, State = #session_state{ +handle_cast({resume, ClientId, ClientPid}, State = #session{ clientid = ClientId, client_pid = OldClientPid, msg_queue = Queue, @@ -399,7 +442,7 @@ handle_cast({resume, ClientId, ClientPid}, State = #session_state{ ClientPid ! {dispatch, {self(), Msg}} end, emqttd_queue:all(Queue)), - {noreply, State#session_state{client_pid = ClientPid, + {noreply, State#session{client_pid = ClientPid, msg_queue = emqttd_queue:clear(Queue), expire_timer = undefined}, hibernate}; @@ -423,7 +466,7 @@ handle_cast({pubcomp, PacketId}, State) -> NewState = puback(State, {?PUBCOMP, PacketId}), {noreply, NewState}; -handle_cast({destroy, ClientId}, State = #session_state{clientid = ClientId}) -> +handle_cast({destroy, ClientId}, State = #session{clientid = ClientId}) -> lager:warning("Session ~s destroyed", [ClientId]), {stop, normal, State}; @@ -438,19 +481,23 @@ handle_info({dispatch, {_From, Messages}}, State) when is_list(Messages) -> handle_info({dispatch, {_From, Message}}, State) -> {noreply, dispatch(Message, State)}; -handle_info({'EXIT', ClientPid, Reason}, State = #session_state{clientid = ClientId, +handle_info({'EXIT', ClientPid, Reason}, State = #session{clientid = ClientId, client_pid = ClientPid}) -> lager:info("Session: client ~s@~p exited for ~p", [ClientId, ClientPid, Reason]), - {noreply, start_expire_timer(State#session_state{client_pid = undefined})}; + {noreply, start_expire_timer(State#session{client_pid = undefined})}; -handle_info({'EXIT', ClientPid0, _Reason}, State = #session_state{client_pid = ClientPid}) -> +handle_info({'EXIT', ClientPid0, _Reason}, State = #session{client_pid = ClientPid}) -> lager:error("Unexpected Client EXIT: pid=~p, pid(state): ~p", [ClientPid0, ClientPid]), {noreply, State}; -handle_info(session_expired, State = #session_state{clientid = ClientId}) -> +handle_info(session_expired, State = #session{clientid = ClientId}) -> lager:warning("Session ~s expired!", [ClientId]), {stop, {shutdown, expired}, State}; +handle_info({timeout, awaiting_rel, MsgId}, SessState) -> + NewState = timeout(awaiting_rel, MsgId, SessState), + {noreply, NewState}; + handle_info(Info, State) -> lager:critical("Unexpected Info: ~p, State: ~p", [Info, State]), {noreply, State}. @@ -465,32 +512,40 @@ code_change(_OldVsn, State, _Extra) -> %%% Internal functions %%%============================================================================= -dispatch(Message, State = #session_state{clientid = ClientId, - client_pid = undefined}) -> - queue(ClientId, Message, State); +%% client is offline +dispatch(Msg, SessState = #session{client_pid = undefined}) -> + queue(Msg, SessState); -dispatch(Message = #mqtt_message{qos = ?QOS_0}, State = #session_state{client_pid = ClientPid}) -> - ClientPid ! {dispatch, {self(), Message}}, - State; +%% dispatch qos0 directly +dispatch(Msg = #mqtt_message{qos = ?QOS_0}, SessState = #session{client_pid = ClientPid}) -> + ClientPid ! {dispatch, {self(), Msg}}, SessState; -dispatch(Message = #mqtt_message{qos = Qos}, State = #session_state{client_pid = ClientPid}) - when (Qos =:= ?QOS_1) orelse (Qos =:= ?QOS_2) -> - {Message1, NewState} = store(State, Message), - ClientPid ! {dispatch, {self(), Message1}}, - NewState. +%% queue if inflight_queue is full +dispatch(Msg = #mqtt_message{qos = Qos}, SessState = #session{inflight_window = InflightWin, + inflight_queue = InflightQ}) + when (Qos > ?QOS_0) andalso (length(InflightQ) >= InflightWin) -> + %%TODO: set alarms + lager:error([{clientid, ClientId}], "Session ~s inflight_queue is full!", [ClientId]), + queue(Msg, SessState); -queue(ClientId, Message, State = #session_state{msg_queue = Queue}) -> - State#session_state{msg_queue = emqttd_queue:in(ClientId, Message, Queue)}. +%% dispatch and await ack +dispatch(Msg = #mqtt_message{qos = Qos}, SessState = #session{client_pid = ClientPid}) + when (Qos =:= ?QOS_1) orelse (Qos =:= ?QOS_2) -> + %% assign msgid and await + {NewMsg, NewState} = await_ack(Msg, SessState), + ClientPid ! {dispatch, {self(), NewMsg}}, -next_msg_id(State = #session_state{message_id = 16#ffff}) -> - State#session_state{message_id = 1}; +queue(Msg, SessState = #session{pending_queue = Queue}) -> + SessState#session{pending_queue = emqttd_mqueue:in(Msg, Queue)}. -next_msg_id(State = #session_state{message_id = MsgId}) -> - State#session_state{message_id = MsgId + 1}. +next_msgid(State = #session{message_id = 16#ffff}) -> + State#session{message_id = 1}; -start_expire_timer(State = #session_state{expires = Expires, - expire_timer = OldTimer}) -> +next_msgid(State = #session{message_id = MsgId}) -> + State#session{message_id = MsgId + 1}. + +start_expire_timer(State = #session{expires = Expires, expire_timer = OldTimer}) -> emqttd_util:cancel_timer(OldTimer), Timer = erlang:send_after(Expires * 1000, self(), session_expired), - State#session_state{expire_timer = Timer}. + State#session{expire_timer = Timer}. diff --git a/rel/files/emqttd.config b/rel/files/emqttd.config index 5de150b4b..da32fbfc9 100644 --- a/rel/files/emqttd.config +++ b/rel/files/emqttd.config @@ -94,7 +94,9 @@ %% Max retries for unacknolege Qos1/2 messages {max_unack_retries, 3}, %% Retry after 4, 8, 16 seconds - {unack_retry_after, 4} + {unack_retry_after, 4}, + %% Awaiting PUBREL timeout + {await_rel_timeout, 8} ]}, {queue, [ %% Max messages queued when client is disconnected, or inflight messsage window is overload