diff --git a/include/emqttd.hrl b/include/emqttd.hrl index ab4b6ee72..51bce2f6d 100644 --- a/include/emqttd.hrl +++ b/include/emqttd.hrl @@ -91,7 +91,6 @@ -record(mqtt_client, { client_id :: binary() | undefined, client_pid :: pid(), - client_mon :: reference(), username :: binary() | undefined, peername :: {inet:ip_address(), integer()}, clean_sess :: boolean(), diff --git a/src/emqttd_cm.erl b/src/emqttd_cm.erl index b3bfb242b..4c2c2b252 100644 --- a/src/emqttd_cm.erl +++ b/src/emqttd_cm.erl @@ -33,7 +33,7 @@ %% API Exports -export([start_link/2, pool/0]). --export([lookup/1, register/1, unregister/1]). +-export([lookup/1, lookup_proc/1, register/1, unregister/1]). -behaviour(gen_server2). @@ -44,7 +44,7 @@ %% gen_server2 priorities -export([prioritise_call/4, prioritise_cast/3, prioritise_info/3]). --record(state, {id, statsfun}). +-record(state, {id, statsfun, monitors}). -define(CM_POOL, ?MODULE). @@ -68,16 +68,28 @@ start_link(Id, StatsFun) -> pool() -> ?CM_POOL. %%------------------------------------------------------------------------------ -%% @doc Lookup client pid with clientId +%% @doc Lookup client by clientId %% @end %%------------------------------------------------------------------------------ -spec lookup(ClientId :: binary()) -> mqtt_client() | undefined. lookup(ClientId) when is_binary(ClientId) -> case ets:lookup(mqtt_client, ClientId) of - [Client] -> Client; - [] -> undefined + [Client] -> Client; + [] -> undefined end. +%%------------------------------------------------------------------------------ +%% @doc Lookup client pid by clientId +%% @end +%%------------------------------------------------------------------------------ +-spec lookup_proc(ClientId :: binary()) -> pid() | undefined. +lookup_proc(ClientId) when is_binary(ClientId) -> + try ets:lookup_element(mqtt_client, ClientId, #mqtt_client.client_pid) of + Pid -> Pid + catch + error:badarg -> undefined + end. + %%------------------------------------------------------------------------------ %% @doc Register clientId with pid. %% @end @@ -102,15 +114,15 @@ unregister(ClientId) when is_binary(ClientId) -> init([Id, StatsFun]) -> gproc_pool:connect_worker(?CM_POOL, {?MODULE, Id}), - {ok, #state{id = Id, statsfun = StatsFun}}. + {ok, #state{id = Id, statsfun = StatsFun, monitors = dict:new()}}. prioritise_call(_Req, _From, _Len, _State) -> 1. prioritise_cast(Msg, _Len, _State) -> case Msg of - {register, _Client} -> 2; - {unregister, _ClientId, _Pid} -> 3; + {register, _Client} -> 2; + {unregister, _ClientId, _Pid} -> 9; _ -> 1 end. @@ -123,28 +135,20 @@ handle_call(Req, _From, State) -> handle_cast({register, Client = #mqtt_client{client_id = ClientId, client_pid = Pid}}, State) -> - case ets:lookup(mqtt_client, ClientId) of - [#mqtt_client{client_pid = Pid}] -> - ignore; - [#mqtt_client{client_pid = _OldPid, client_mon = MRef}] -> - %% demonitor - erlang:demonitor(MRef, [flush]); - [] -> - ok - end, - ets:insert(mqtt_client, Client#mqtt_client{client_mon = erlang:monitor(process, Pid)}), - {noreply, setstats(State)}; + case lookup_proc(ClientId) of + Pid -> + {noreply, State}; + _None -> + ets:insert(mqtt_client, Client), + {noreply, setstats(monitor_client(ClientId, Pid, State))} + end; handle_cast({unregister, ClientId, Pid}, State) -> - case ets:lookup(mqtt_client, ClientId) of - [#mqtt_client{client_pid = Pid, client_mon = MRef}] -> - erlang:demonitor(MRef, [flush]), + case lookup_proc(ClientId) of + Pid -> ets:delete(mqtt_client, ClientId), {noreply, setstats(State)}; - [_] -> - {noreply, State}; - [] -> - lager:warning("CM(~s): Cannot find pid ~p", [ClientId, Pid]), + undefined -> {noreply, State} end; @@ -152,16 +156,20 @@ handle_cast(Msg, State) -> lager:error("Unexpected Msg: ~p", [Msg]), {noreply, State}. -handle_info({'DOWN', MRef, process, DownPid, Reason}, State) -> - MP = #mqtt_client{client_pid = DownPid, client_mon = MRef, _ = '_'}, - case ets:match_object(mqtt_client, MP) of - [Client] -> - ?LOG(warning, "client ~p DOWN for ~p", [DownPid, Reason], Client), - ets:delete_object(mqtt_client, Client); - [] -> - ignore - end, - {noreply, setstats(State)}; +handle_info({'DOWN', MRef, process, DownPid, _Reason}, State = #state{monitors = MonDict}) -> + case dict:find(MRef, MonDict) of + {ok, {ClientId, DownPid}} -> + case lookup_proc(ClientId) of + DownPid -> + ets:delete(mqtt_client, ClientId); + _ -> + ignore + end, + {noreply, setstats(erase_monitor(MRef, State))}; + error -> + lager:error("MRef of client ~p not found", [DownPid]), + {noreply, State} + end; handle_info(Info, State) -> lager:error("Unexpected Info: ~p", [Info]), @@ -178,6 +186,13 @@ code_change(_OldVsn, State, _Extra) -> %%% Internal functions %%%============================================================================= +monitor_client(ClientId, Pid, State = #state{monitors = Monintors}) -> + MRef = erlang:monitor(process, Pid), + State#state{monitors = dict:store(MRef, {ClientId, Pid}, Monintors)}. + +erase_monitor(MRef, State = #state{monitors = Monintors}) -> + State#state{monitors = dict:erase(MRef, Monintors)}. + setstats(State = #state{statsfun = StatsFun}) -> StatsFun(ets:info(mqtt_client, size)), State. diff --git a/src/emqttd_protocol.erl b/src/emqttd_protocol.erl index 053b74e1a..29985f044 100644 --- a/src/emqttd_protocol.erl +++ b/src/emqttd_protocol.erl @@ -282,14 +282,18 @@ redeliver({?PUBREL, PacketId}, State) -> shutdown(_Error, #proto_state{client_id = undefined}) -> ignore; -shutdown(conflict, #proto_state{client_id = ClientId}) -> - emqttd_cm:unregister(ClientId); +shutdown(conflict, #proto_state{client_id = _ClientId}) -> + %% let it down + %% emqttd_cm:unregister(ClientId); + ignore; shutdown(Error, State = #proto_state{client_id = ClientId, will_msg = WillMsg}) -> ?LOG(info, "Shutdown for ~p", [Error], State), send_willmsg(ClientId, WillMsg), emqttd_broker:foreach_hooks('client.disconnected', [Error, ClientId]), - emqttd_cm:unregister(ClientId). + %% let it down + %% emqttd_cm:unregister(ClientId). + ok. willmsg(Packet) when is_record(Packet, mqtt_packet_connect) -> emqttd_message:from_packet(Packet).