diff --git a/src/emqttd_sm.erl b/src/emqttd_sm.erl index 6c81ee33f..d08b28d50 100644 --- a/src/emqttd_sm.erl +++ b/src/emqttd_sm.erl @@ -51,11 +51,11 @@ %% gen_server2 priorities -export([prioritise_call/4, prioritise_cast/3, prioritise_info/3]). --record(state, {pool, id}). +-record(state, {pool, id, monitors}). -define(POOL, ?MODULE). --define(TIMEOUT, 60000). +-define(TIMEOUT, 120000). -define(LOG(Level, Format, Args, Session), lager:Level("SM(~s): " ++ Format, [Session#mqtt_session.client_id | Args])). @@ -67,12 +67,11 @@ mnesia(boot) -> %% Global Session Table ok = emqttd_mnesia:create_table(session, [ - {type, ordered_set}, + {type, set}, {ram_copies, [node()]}, {record_name, mqtt_session}, - {attributes, record_info(fields, mqtt_session)}, - %% TODO: index_read is slow... - {index, [sess_pid]}]); + {attributes, record_info(fields, mqtt_session)}]); + mnesia(copy) -> ok = emqttd_mnesia:copy_table(session). @@ -144,7 +143,8 @@ call(SM, Req) -> init([Pool, Id]) -> ?GPROC_POOL(join, Pool, Id), - {ok, #state{pool = Pool, id = Id}}. + {ok, #state{pool = Pool, id = Id, + monitors = dict:new()}}. prioritise_call(_Msg, _From, _Len, _State) -> 1. @@ -155,43 +155,57 @@ prioritise_cast(_Msg, _Len, _State) -> prioritise_info(_Msg, _Len, _State) -> 2. -%% persistent session -handle_call({start_session, {false, ClientId, ClientPid}}, _From, State) -> +%% Persistent Session +handle_call({start_session, Client = {false, ClientId, ClientPid}}, _From, State) -> case lookup_session(ClientId) of undefined -> - %% create session locally - reply(create_session(false, ClientId, ClientPid), false, State); + %% Create session locally + create_session(Client, State); Session -> - reply(resume_session(Session, ClientPid), true, State) + case resume_session(Session, ClientPid) of + {ok, SessPid} -> + {reply, {ok, SessPid, true}, State}; + {error, Erorr} -> + {reply, {error, Erorr}, State} + end end; -%% transient session -handle_call({start_session, {true, ClientId, ClientPid}}, _From, State) -> +%% Transient Session +handle_call({start_session, Client = {true, ClientId, _ClientPid}}, _From, State) -> case lookup_session(ClientId) of undefined -> - reply(create_session(true, ClientId, ClientPid), false, State); + create_session(Client, State); Session -> case destroy_session(Session) of ok -> - reply(create_session(true, ClientId, ClientPid), false, State); + create_session(Client, State); {error, Error} -> {reply, {error, Error}, State} end end; -handle_call(_Request, _From, State) -> - {reply, ok, State}. +handle_call(Req, _From, State) -> + ?UNEXPECTED_REQ(Req, State). handle_cast(Msg, State) -> ?UNEXPECTED_MSG(Msg, State). -%%TODO: fix this issue that index_read is really slow... -handle_info({'DOWN', _MRef, process, DownPid, _Reason}, State) -> - mnesia:transaction(fun() -> - [mnesia:delete_object(session, Sess, write) || Sess - <- mnesia:index_read(session, DownPid, #mqtt_session.sess_pid)] - end), - {noreply, State}; +handle_info({'DOWN', MRef, process, DownPid, _Reason}, State) -> + case dict:find(MRef, State#state.monitors) of + {ok, ClientId} -> + mnesia:transaction(fun() -> + case mnesia:wread({session, ClientId}) of + [] -> ok; + [Sess = #mqtt_session{sess_pid = DownPid}] -> + mnesia:delete_object(session, Sess, write); + [_Sess] -> ok + end + end), + {noreply, erase_monitor(MRef, State)}; + error -> + lager:error("MRef of session ~p not found", [DownPid]), + {noreply, State} + end; handle_info(Info, State) -> ?UNEXPECTED_INFO(Info, State). @@ -206,6 +220,16 @@ code_change(_OldVsn, State, _Extra) -> %%% Internal functions %%%============================================================================= +%% Create Session Locally +create_session({CleanSess, ClientId, ClientPid}, State) -> + case create_session(CleanSess, ClientId, ClientPid) of + {ok, SessPid} -> + {reply, {ok, SessPid, false}, + monitor_session(ClientId, SessPid, State)}; + {error, Error} -> + {reply, {error, Error}, State} + end. + create_session(CleanSess, ClientId, ClientPid) -> case emqttd_session_sup:start_session(CleanSess, ClientId, ClientPid) of {ok, SessPid} -> @@ -218,7 +242,6 @@ create_session(CleanSess, ClientId, ClientPid) -> lager:error("SM(~s): Conflict with ~p", [ClientId, ConflictPid]), {error, mnesia_conflict}; {atomic, ok} -> - erlang:monitor(process, SessPid), {ok, SessPid} end; {error, Error} -> @@ -293,8 +316,10 @@ remove_session(Session) -> {aborted, Error} -> {error, Error} end. -reply({ok, SessPid}, SP, State) -> - {reply, {ok, SessPid, SP}, State}; -reply({error, Error}, _SP, State) -> - {reply, {error, Error}, State}. +monitor_session(ClientId, SessPid, State = #state{monitors = Monitors}) -> + MRef = erlang:monitor(process, SessPid), + State#state{monitors = dict:store(MRef, ClientId, Monitors)}. + +erase_monitor(MRef, State = #state{monitors = Monitors}) -> + State#state{monitors = dict:erase(MRef, Monitors)}.