diff --git a/apps/emqtt/include/emqtt.hrl b/apps/emqtt/include/emqtt.hrl index bbd1f95c4..bb17eef64 100644 --- a/apps/emqtt/include/emqtt.hrl +++ b/apps/emqtt/include/emqtt.hrl @@ -58,6 +58,10 @@ -record(mqtt_message, { %% topic is first for message may be retained topic :: binary(), + %% clientid from + from :: binary() | atom(), + %% sender pid ?? + sender :: pid(), qos = ?QOS_0 :: mqtt_qos(), retain = false :: boolean(), dup = false :: boolean(), diff --git a/apps/emqttd/src/emqttd_session.erl b/apps/emqttd/src/emqttd_session.erl index fce1c766b..01a910e3c 100644 --- a/apps/emqttd/src/emqttd_session.erl +++ b/apps/emqttd/src/emqttd_session.erl @@ -225,7 +225,7 @@ puback(Session = #session{clientid = ClientId, awaiting_ack = Awaiting}, {?PUBAC Session#session{awaiting_ack = maps:remove(PacketId, Awaiting)}; puback(SessPid, {?PUBACK, PacketId}) when is_pid(SessPid) -> - gen_server:cast(SessPid, {puback, PacketId}); + gen_server:cast(SessPid, {puback, {?PUBACK, PacketId}); %% PUBREC puback(Session = #session{clientid = ClientId, @@ -239,7 +239,7 @@ puback(Session = #session{clientid = ClientId, awaiting_comp = maps:put(PacketId, true, AwaitingComp)}; puback(SessPid, {?PUBREC, PacketId}) when is_pid(SessPid) -> - gen_server:cast(SessPid, {pubrec, PacketId}), SessPid; + gen_server:cast(SessPid, {puback, {?PUBREC, PacketId}); %% PUBREL puback(Session = #session{clientid = ClientId, awaiting_rel = Awaiting}, {?PUBREL, PacketId}) -> @@ -253,7 +253,7 @@ puback(Session = #session{clientid = ClientId, awaiting_rel = Awaiting}, {?PUBRE Session#session{awaiting_rel = maps:remove(PacketId, Awaiting)}; puback(SessPid, {?PUBREL, PacketId}) when is_pid(SessPid) -> - cast(SessPid, {pubrel, PacketId}); + gen_server:cast(SessPid, {puback, {?PUBREL, PacketId}); %% PUBCOMP puback(Session = #session{clientid = ClientId, @@ -265,7 +265,9 @@ puback(Session = #session{clientid = ClientId, Session#session{awaiting_comp = maps:remove(PacketId, AwaitingComp)}; puback(SessPid, {?PUBCOMP, PacketId}) when is_pid(SessPid) -> - cast(SessPid, {pubcomp, PacketId}). + gen_server:cast(SessPid, {puback, {?PUBCOMP, PacketId}); + +wait_ack timeout(awaiting_rel, MsgId, Session = #session{clientid = ClientId, awaiting_rel = Awaiting}) -> case maps:find(MsgId, Awaiting) of @@ -440,48 +442,34 @@ handle_cast({resume, ClientId, ClientPid}, State = #session{ end, emqttd_queue:all(Queue)), {noreply, State#session{client_pid = ClientPid, - msg_queue = emqttd_queue:clear(Queue), - expire_timer = undefined}, hibernate}; + msg_queue = emqttd_queue:clear(Queue), + expire_timer = undefined}, hibernate}; -handle_cast({publish, ClientId, {?QOS_2, Message}}, State) -> - NewState = publish(State, ClientId, {?QOS_2, Message}), - {noreply, NewState}; +handle_cast({publish, ClientId, {?QOS_2, Message}}, Session) -> + {noreply, publish(Session, ClientId, {?QOS_2, Message})}; -handle_cast({puback, PacketId}, State) -> - NewState = puback(State, {?PUBACK, PacketId}), - {noreply, NewState}; +handle_cast({puback, {PubAck, PacketId}, Session) -> + {noreply, puback(Session, {PubAck, PacketId})}; -handle_cast({pubrec, PacketId}, State) -> - NewState = puback(State, {?PUBREC, PacketId}), - {noreply, NewState}; - -handle_cast({pubrel, PacketId}, State) -> - NewState = puback(State, {?PUBREL, PacketId}), - {noreply, NewState}; - -handle_cast({pubcomp, PacketId}, State) -> - NewState = puback(State, {?PUBCOMP, PacketId}), - {noreply, NewState}; - -handle_cast({destroy, ClientId}, State = #session{clientid = ClientId}) -> +handle_cast({destroy, ClientId}, Session = #session{clientid = ClientId}) -> lager:warning("Session ~s destroyed", [ClientId]), - {stop, normal, State}; + {stop, normal, Session}; handle_cast(Msg, State) -> lager:critical("Unexpected Msg: ~p, State: ~p", [Msg, State]), {noreply, State}. -handle_info({dispatch, {_From, Messages}}, State) when is_list(Messages) -> +handle_info({dispatch, {_From, Messages}}, Session) when is_list(Messages) -> F = fun(Message, S) -> dispatch(Message, S) end, - {noreply, lists:foldl(F, State, Messages)}; + {noreply, lists:foldl(F, Session, Messages)}; handle_info({dispatch, {_From, Message}}, State) -> {noreply, dispatch(Message, State)}; -handle_info({'EXIT', ClientPid, Reason}, State = #session{clientid = ClientId, - client_pid = ClientPid}) -> +handle_info({'EXIT', ClientPid, Reason}, Session = #session{clientid = ClientId, + client_pid = ClientPid}) -> lager:info("Session: client ~s@~p exited for ~p", [ClientId, ClientPid, Reason]), - {noreply, start_expire_timer(State#session{client_pid = undefined})}; + {noreply, start_expire_timer(Session#session{client_pid = undefined})}; handle_info({'EXIT', ClientPid0, _Reason}, State = #session{client_pid = ClientPid}) -> lager:error("Unexpected Client EXIT: pid=~p, pid(state): ~p", [ClientPid0, ClientPid]), @@ -491,51 +479,55 @@ handle_info(session_expired, State = #session{clientid = ClientId}) -> lager:warning("Session ~s expired!", [ClientId]), {stop, {shutdown, expired}, State}; -handle_info({timeout, awaiting_rel, MsgId}, SessState) -> - NewState = timeout(awaiting_rel, MsgId, SessState), - {noreply, NewState}; +handle_info({timeout, awaiting_rel, MsgId}, Session) -> + {noreply, timeout(awaiting_rel, MsgId, Session)}; -handle_info(Info, State) -> - lager:critical("Unexpected Info: ~p, State: ~p", [Info, State]), - {noreply, State}. +handle_info(Info, Session) -> + lager:critical("Unexpected Info: ~p, Session: ~p", [Info, Session]), + {noreply, Session}. -terminate(_Reason, _State) -> +terminate(_Reason, _Session) -> ok. -code_change(_OldVsn, State, _Extra) -> - {ok, State}. - - - +code_change(_OldVsn, Session, _Extra) -> + {ok, Session}. %%%============================================================================= -%%% Internal functions +%%% Dispatch message from broker -> client. %%%============================================================================= -%% client is offline +%% queued the message if client is offline dispatch(Msg, Session = #session{client_pid = undefined}) -> queue(Msg, Session); -%% dispatch qos0 directly +%% dispatch qos0 directly to client process dispatch(Msg = #mqtt_message{qos = ?QOS_0}, Session = #session{client_pid = ClientPid}) -> ClientPid ! {dispatch, {self(), Msg}}, Session; -%% queue if inflight_queue is full -dispatch(Msg = #mqtt_message{qos = Qos}, Session = #session{inflight_window = InflightWin, - inflight_queue = InflightQ}) - when (Qos > ?QOS_0) andalso (length(InflightQ) >= InflightWin) -> - %%TODO: set alarms - lager:error([{clientid, ClientId}], "Session ~s inflight_queue is full!", [ClientId]), - queue(Msg, Session); - -%% dispatch and await ack -dispatch(Msg = #mqtt_message{qos = Qos}, Session = #session{client_pid = ClientPid}) +%% dispatch qos1/2 messages and wait for puback +dispatch(Msg = #mqtt_message{qos = Qos}, Session = #session{clientid = ClientId, + message_id = MsgId, + pending_queue = Q, + inflight_window = Win}) when (Qos =:= ?QOS_1) orelse (Qos =:= ?QOS_2) -> - %% assign msgid and await - {NewMsg, NewState} = await_ack(Msg, Session), - ClientPid ! {dispatch, {self(), NewMsg}}, -queue(Msg, Session = #session{pending_queue = Queue}) -> + case emqttd_mqwin:is_full(InflightWin) of + true -> + lager:error("Session ~s inflight window is full!", [ClientId]), + Session#session{pending_queue = emqttd_mqueue:in(Msg, Q)}; + false -> + Msg1 = Msg#mqtt_message{msgid = MsgId}, + Msg2 = + if + Qos =:= ?QOS_2 -> Msg1#mqtt_message{dup = false}; + true -> Msg1 + end, + ClientPid ! {dispatch, {self(), Msg2}}, + NewWin = emqttd_mqwin:in(Msg2, Win), + await_ack(Msg2, next_msgid(Session#session{inflight_window = NewWin})) + end. + +queue(Msg, Session = #session{pending_queue= Queue}) -> Session#session{pending_queue = emqttd_mqueue:in(Msg, Queue)}. next_msgid(State = #session{message_id = 16#ffff}) -> @@ -544,8 +536,9 @@ next_msgid(State = #session{message_id = 16#ffff}) -> next_msgid(State = #session{message_id = MsgId}) -> State#session{message_id = MsgId + 1}. -start_expire_timer(State = #session{expires = Expires, expire_timer = OldTimer}) -> +start_expire_timer(Session = #session{expired_after = Expires, + expired_timer = OldTimer}) -> emqttd_util:cancel_timer(OldTimer), Timer = erlang:send_after(Expires * 1000, self(), session_expired), - State#session{expire_timer = Timer}. + Session#session{expired_timer = Timer}.