diff --git a/include/emqttd_protocol.hrl b/include/emqttd_protocol.hrl index b7f2c5743..ee7b553d9 100644 --- a/include/emqttd_protocol.hrl +++ b/include/emqttd_protocol.hrl @@ -209,6 +209,11 @@ #mqtt_packet{header = #mqtt_packet_header{type = ?CONNACK}, variable = #mqtt_packet_connack{return_code = ReturnCode}}). +-define(CONNACK_PACKET(ReturnCode, SessPresent), + #mqtt_packet{header = #mqtt_packet_header{type = ?CONNACK}, + variable = #mqtt_packet_connack{ack_flags = SessPresent, + return_code = ReturnCode}}). + -define(PUBLISH_PACKET(Qos, PacketId), #mqtt_packet{header = #mqtt_packet_header{type = ?PUBLISH, qos = Qos}, diff --git a/src/emqttd_protocol.erl b/src/emqttd_protocol.erl index 33571e213..053b74e1a 100644 --- a/src/emqttd_protocol.erl +++ b/src/emqttd_protocol.erl @@ -149,7 +149,7 @@ process(Packet = ?CONNECT_PACKET(Var), State0) -> trace(recv, Packet, State1), - {ReturnCode1, State3} = + {ReturnCode1, SessPresent, State3} = case validate_connect(Var, State1) of ?CONNACK_ACCEPT -> case emqttd_access_control:auth(client(State1), Password) of @@ -159,27 +159,27 @@ process(Packet = ?CONNECT_PACKET(Var), State0) -> %% Start session case emqttd_sm:start_session(CleanSess, clientid(State2)) of - {ok, Session} -> + {ok, Session, SP} -> %% Register the client emqttd_cm:register(client(State2)), %% Start keepalive start_keepalive(KeepAlive), %% ACCEPT - {?CONNACK_ACCEPT, State2#proto_state{session = Session}}; + {?CONNACK_ACCEPT, SP, State2#proto_state{session = Session}}; {error, Error} -> exit({shutdown, Error}) end; {error, Reason}-> ?LOG(error, "Username '~s' login failed for ~s", [Username, Reason], State1), - {?CONNACK_CREDENTIALS, State1} + {?CONNACK_CREDENTIALS, false, State1} end; ReturnCode -> - {ReturnCode, State1} + {ReturnCode, false, State1} end, %% Run hooks emqttd_broker:foreach_hooks('client.connected', [ReturnCode1, client(State3)]), %% Send connack - send(?CONNACK_PACKET(ReturnCode1), State3); + send(?CONNACK_PACKET(ReturnCode1, sp(SessPresent)), State3); process(Packet = ?PUBLISH_PACKET(_Qos, Topic, _PacketId, _Payload), State) -> case check_acl(publish, Topic, client(State)) of @@ -405,3 +405,6 @@ check_acl(publish, Topic, Client) -> check_acl(subscribe, Topic, Client) -> emqttd_access_control:check_acl(Client, subscribe, Topic). +sp(true) -> 1; +sp(false) -> 0. + diff --git a/src/emqttd_session.erl b/src/emqttd_session.erl index e11b9a127..f06037238 100644 --- a/src/emqttd_session.erl +++ b/src/emqttd_session.erl @@ -585,7 +585,8 @@ kick(_ClientId, Pid, Pid) -> ignore; kick(ClientId, OldPid, Pid) -> unlink(OldPid), - OldPid ! {shutdown, conflict, {ClientId, Pid}}. + OldPid ! {shutdown, conflict, {ClientId, Pid}}, + ok. %%------------------------------------------------------------------------------ %% Check inflight and awaiting_rel diff --git a/src/emqttd_sm.erl b/src/emqttd_sm.erl index d62e85c86..f3211cae9 100644 --- a/src/emqttd_sm.erl +++ b/src/emqttd_sm.erl @@ -24,7 +24,6 @@ %%% %%% @end %%%----------------------------------------------------------------------------- - -module(emqttd_sm). -author("Feng Lee "). @@ -57,7 +56,7 @@ -define(SM_POOL, ?MODULE). --define(CALL_TIMEOUT, 60000). +-define(TIMEOUT, 60000). -define(LOG(Level, Format, Args, Session), lager:Level("SM(~s): " ++ Format, [Session#mqtt_session.client_id | Args])). @@ -103,7 +102,7 @@ pool() -> ?SM_POOL. %% @doc Start a session %% @end %%------------------------------------------------------------------------------ --spec start_session(CleanSess :: boolean(), binary()) -> {ok, pid()} | {error, any()}. +-spec start_session(CleanSess :: boolean(), binary()) -> {ok, pid(), boolean()} | {error, any()}. start_session(CleanSess, ClientId) -> SM = gproc_pool:pick_worker(?SM_POOL, ClientId), call(SM, {start_session, {CleanSess, ClientId, self()}}). @@ -144,7 +143,7 @@ sesstab(true) -> mqtt_transient_session; sesstab(false) -> mqtt_persistent_session. call(SM, Req) -> - gen_server2:call(SM, Req, ?CALL_TIMEOUT). %%infinity). + gen_server2:call(SM, Req, ?TIMEOUT). %%infinity). %%%============================================================================= %%% gen_server callbacks @@ -168,20 +167,20 @@ handle_call({start_session, {false, ClientId, ClientPid}}, _From, State) -> case lookup_session(ClientId) of undefined -> %% create session locally - {reply, create_session(false, ClientId, ClientPid), State}; + reply(create_session(false, ClientId, ClientPid), false, State); Session -> - {reply, resume_session(Session, ClientPid), State} + reply(resume_session(Session, ClientPid), true, State) end; %% transient session handle_call({start_session, {true, ClientId, ClientPid}}, _From, State) -> case lookup_session(ClientId) of undefined -> - {reply, create_session(true, ClientId, ClientPid), State}; + reply(create_session(true, ClientId, ClientPid), false, State); Session -> case destroy_session(Session) of ok -> - {reply, create_session(true, ClientId, ClientPid), State}; + reply(create_session(true, ClientId, ClientPid), false, State); {error, Error} -> {reply, {error, Error}, State} end @@ -302,3 +301,8 @@ remove_session(Session) -> {aborted, Error} -> {error, Error} end. +reply({ok, SessPid}, SP, State) -> + {reply, {ok, SessPid, SP}, State}; +reply({error, Error}, _SP, State) -> + {reply, {error, Error}, State}. +