diff --git a/src/emqtt_client.erl b/src/emqtt_client.erl index d76617f10..ea2eb9772 100644 --- a/src/emqtt_client.erl +++ b/src/emqtt_client.erl @@ -123,7 +123,8 @@ handle_info(keep_alive_timeout, #state{keep_alive=KeepAlive}=State) -> handle_info(Info, State) -> {stop, {badinfo, Info}, State}. -terminate(_Reason, #state{keep_alive=KeepAlive}) -> +terminate(_Reason, #state{client_id=ClientId, keep_alive=KeepAlive}) -> + ok = emqtt_registry:unregister(ClientId), emqtt_keep_alive:cancel(KeepAlive), ok. @@ -201,6 +202,7 @@ process_request(?CONNECT, ?ERROR_MSG("MQTT login failed - no credentials"), {?CONNACK_CREDENTIALS, State}; true -> + ok = emqtt_registry:register(ClientId, self()), KeepAlive = emqtt_keep_alive:new(AlivePeriod*1500, keep_alive_timeout), {?CONNACK_ACCEPT, State #state{ will_msg = make_will_msg(Var), @@ -355,6 +357,7 @@ control_throttle(State = #state{ connection_state = Flow, end. stop(Reason, State ) -> + {stop, Reason, State}. valid_client_id(ClientId) -> diff --git a/src/emqtt_registry.erl b/src/emqtt_registry.erl new file mode 100644 index 000000000..7f437234f --- /dev/null +++ b/src/emqtt_registry.erl @@ -0,0 +1,83 @@ +-module(emqtt_registry). + +-include("emqtt.hrl"). + +-export([start_link/0, + size/0, + register/2, + unregister/1]). + +-behaviour(gen_server). + +-export([init/1, + handle_call/3, + handle_cast/2, + handle_info/2, + terminate/2, + code_change/3]). + +-record(state, {}). + +-define(SERVER, ?MODULE). + +%%---------------------------------------------------------------------------- + +start_link() -> + gen_server2:start_link({local, ?SERVER}, ?MODULE, [], []). + +size() -> + ets:info(client, size). + +register(ClientId, Pid) -> + gen_server2:cast(?SERVER, {register, ClientId, Pid}). + +unregister(ClientId) -> + gen_server2:cast(?SERVER, {unregister, ClientId}). + +%%---------------------------------------------------------------------------- + +init([]) -> + ets:new(client, [set, protected, named_table]), + ?INFO("~p is started.", [?MODULE]), + {ok, #state{}}. % clientid -> {pid, monitor} + +%%-------------------------------------------------------------------------- +handle_call(Req, _From, State) -> + {stop, {badreq, Req}, State}. + +handle_cast({register, ClientId, Pid}, State) -> + ?INFO("register ~p ~p", [ClientId, Pid]), + case ets:lookup(client, ClientId) of + [{_, {OldPid, MRef}}] -> + catch gen_server2:call(OldPid, duplicate_id), + erlang:demonitor(MRef); + [] -> + ignore + end, + ets:insert(client, {ClientId, {Pid, erlang:monitor(process, Pid)}}), + {noreply, State}; + +handle_cast({unregister, ClientId}, State) -> + ?INFO("unregister ~p", [ClientId]), + case ets:lookup(client, ClientId) of + [{_, {_Pid, MRef}}] -> + erlang:demonitor(MRef), + ets:delete(client, ClientId); + [] -> + ignore + end, + {noreply, State}; + +handle_cast(Msg, State) -> + {stop, {badmsg, Msg}, State}. + +handle_info({'DOWN', MRef, process, DownPid, _Reason}, State) -> + ets:match_delete(client, {'_', {DownPid, MRef}}), + {noreply, State}. + +terminate(_Reason, _State) -> + ok. + +code_change(_OldVsn, State, _Extra) -> + {ok, State}. + diff --git a/src/emqtt_sup.erl b/src/emqtt_sup.erl index 59ec7fdd7..5bdb3699a 100644 --- a/src/emqtt_sup.erl +++ b/src/emqtt_sup.erl @@ -29,6 +29,7 @@ init([Listeners]) -> ?CHILD(emqtt_auth, worker), ?CHILD(emqtt_retained, worker), ?CHILD(emqtt_router, worker), + ?CHILD(emqtt_registry, worker), ?CHILD(emqtt_client_sup, supervisor) | listener_children(Listeners) ]} }.