From 39548cc3998e420b337e61ebcf30a1ae755a0ce4 Mon Sep 17 00:00:00 2001 From: Feng Lee Date: Wed, 4 Apr 2018 15:28:01 +0800 Subject: [PATCH] Improve the session management --- include/emqx.hrl | 4 +- src/emqx_access_rule.erl | 6 +- src/emqx_banned.erl | 72 ++++++++++++++ src/emqx_flapping.erl | 74 ++++++++++++++ src/emqx_mod_presence.erl | 8 +- src/emqx_mod_subscription.erl | 6 +- src/emqx_protocol.erl | 8 +- src/emqx_router.erl | 3 +- src/emqx_session.erl | 2 +- src/emqx_sm.erl | 182 ++++++++++++++++++---------------- src/emqx_sm_locker.erl | 44 ++++++-- src/emqx_tracer.erl | 2 +- 12 files changed, 295 insertions(+), 116 deletions(-) create mode 100644 src/emqx_banned.erl create mode 100644 src/emqx_flapping.erl diff --git a/include/emqx.hrl b/include/emqx.hrl index ddd91dbca..0639ae19d 100644 --- a/include/emqx.hrl +++ b/include/emqx.hrl @@ -81,8 +81,8 @@ -type(zone() :: atom()). -record(client, - { id :: client_id(), - pid :: pid(), + { client_id :: client_id(), + client_pid :: pid(), zone :: zone(), node :: node(), username :: username(), diff --git a/src/emqx_access_rule.erl b/src/emqx_access_rule.erl index 63bf3b9ad..a96094065 100644 --- a/src/emqx_access_rule.erl +++ b/src/emqx_access_rule.erl @@ -99,7 +99,7 @@ match_who(_Client, {user, all}) -> true; match_who(_Client, {client, all}) -> true; -match_who(#client{id = ClientId}, {client, ClientId}) -> +match_who(#client{client_id = ClientId}, {client, ClientId}) -> true; match_who(#client{username = Username}, {user, Username}) -> true; @@ -137,9 +137,9 @@ feed_var(Client, Pattern) -> feed_var(Client, Pattern, []). feed_var(_Client, [], Acc) -> lists:reverse(Acc); -feed_var(Client = #client{id = undefined}, [<<"%c">>|Words], Acc) -> +feed_var(Client = #client{client_id = undefined}, [<<"%c">>|Words], Acc) -> feed_var(Client, Words, [<<"%c">>|Acc]); -feed_var(Client = #client{id = ClientId}, [<<"%c">>|Words], Acc) -> +feed_var(Client = #client{client_id = ClientId}, [<<"%c">>|Words], Acc) -> feed_var(Client, Words, [ClientId |Acc]); feed_var(Client = #client{username = undefined}, [<<"%u">>|Words], Acc) -> feed_var(Client, Words, [<<"%u">>|Acc]); diff --git a/src/emqx_banned.erl b/src/emqx_banned.erl new file mode 100644 index 000000000..66cdf87b5 --- /dev/null +++ b/src/emqx_banned.erl @@ -0,0 +1,72 @@ +%%-------------------------------------------------------------------- +%% Copyright © 2013-2018 EMQ Inc. All rights reserved. +%% +%% Licensed under the Apache License, Version 2.0 (the "License"); +%% you may not use this file except in compliance with the License. +%% You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, software +%% distributed under the License is distributed on an "AS IS" BASIS, +%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +%% See the License for the specific language governing permissions and +%% limitations under the License. +%%-------------------------------------------------------------------- + +%% Banned an IP Address, ClientId? +-module(emqx_banned). + +-behaviour(gen_server). + +%% API +-export([start_link/0]). + +%% gen_server callbacks +-export([init/1, handle_call/3, handle_cast/2, handle_info/2, + terminate/2, code_change/3]). + +-define(SERVER, ?MODULE). + +-record(state, {}). + +%%%=================================================================== +%%% API +%%%=================================================================== + +%% @doc Starts the server +-spec(start_link() -> {ok, pid()} | ignore | {error, any()}). +start_link() -> + gen_server:start_link({local, ?SERVER}, ?MODULE, [], []). + +%%%=================================================================== +%%% gen_server callbacks +%%%=================================================================== + +init([]) -> + {ok, #state{}}. + +handle_call(_Request, _From, State) -> + Reply = ok, + {reply, Reply, State}. + +handle_cast(_Msg, State) -> + {noreply, State}. + +handle_info(_Info, State) -> + {noreply, State}. + +terminate(_Reason, _State) -> + ok. + +code_change(_OldVsn, State, _Extra) -> + {ok, State}. + +%%%=================================================================== +%%% Internal functions +%%%=================================================================== + + + + + diff --git a/src/emqx_flapping.erl b/src/emqx_flapping.erl new file mode 100644 index 000000000..a0c1c3d45 --- /dev/null +++ b/src/emqx_flapping.erl @@ -0,0 +1,74 @@ +%%-------------------------------------------------------------------- +%% Copyright (C) 2013-2018 EMQ Inc. All rights reserved. +%% +%% Licensed under the Apache License, Version 2.0 (the "License"); +%% you may not use this file except in compliance with the License. +%% You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, software +%% distributed under the License is distributed on an "AS IS" BASIS, +%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +%% See the License for the specific language governing permissions and +%% limitations under the License. +%%-------------------------------------------------------------------- + +%% 1. Flapping Detection +%% 2. Conflict Detection? +-module(emqx_flapping). + +%% Use ets:update_counter??? + +-behaviour(gen_server). + +-export([start_link/0]). + +-export([is_banned/1, banned/1]). + +-export([init/1, handle_call/3, handle_cast/2, handle_info/2, + terminate/2, code_change/3]). + +-define(SERVER, ?MODULE). + +-record(state, {}). + +-spec(start_link() -> {ok, pid()} | ignore | {error, any()}). +start_link() -> + gen_server:start_link({local, ?SERVER}, ?MODULE, [], []). + +is_banned(ClientId) -> + ets:member(banned, ClientId). + +banned(ClientId) -> + ets:insert(banned, {ClientId, os:timestamp()}). + +%%-------------------------------------------------------------------- +%% gen_server callbacks +%%-------------------------------------------------------------------- + +init([]) -> + _ = ets:new(banned, [public, ordered_set, named_table]), + {ok, #state{}}. + +handle_call(_Request, _From, State) -> + Reply = ok, + {reply, Reply, State}. + +handle_cast(_Msg, State) -> + {noreply, State}. + +handle_info(_Info, State) -> + {noreply, State}. + +terminate(_Reason, _State) -> + ok. + +code_change(_OldVsn, State, _Extra) -> + {ok, State}. + +%%-------------------------------------------------------------------- +%% Internal functions +%%-------------------------------------------------------------------- + + diff --git a/src/emqx_mod_presence.erl b/src/emqx_mod_presence.erl index ab9c553b1..91cf654fc 100644 --- a/src/emqx_mod_presence.erl +++ b/src/emqx_mod_presence.erl @@ -28,9 +28,9 @@ load(Env) -> emqx:hook('client.connected', fun ?MODULE:on_client_connected/3, [Env]), emqx:hook('client.disconnected', fun ?MODULE:on_client_disconnected/3, [Env]). -on_client_connected(ConnAck, Client = #client{id = ClientId, - username = Username, - peername = {IpAddr, _} +on_client_connected(ConnAck, Client = #client{client_id = ClientId, + username = Username, + peername = {IpAddr, _} %%clean_sess = CleanSess, %%proto_ver = ProtoVer }, Env) -> @@ -49,7 +49,7 @@ on_client_connected(ConnAck, Client = #client{id = ClientId, end, {ok, Client}. -on_client_disconnected(Reason, #client{id = ClientId, +on_client_disconnected(Reason, #client{client_id = ClientId, username = Username}, Env) -> case catch emqx_json:encode([{clientid, ClientId}, {username, Username}, diff --git a/src/emqx_mod_subscription.erl b/src/emqx_mod_subscription.erl index d37a1bec4..83badcef6 100644 --- a/src/emqx_mod_subscription.erl +++ b/src/emqx_mod_subscription.erl @@ -33,9 +33,9 @@ load(Topics) -> emqx:hook('client.connected', fun ?MODULE:on_client_connected/3, [Topics]). -on_client_connected(?CONNACK_ACCEPT, Client = #client{id = ClientId, - pid = ClientPid, - username = Username}, Topics) -> +on_client_connected(?CONNACK_ACCEPT, Client = #client{client_id = ClientId, + client_pid = ClientPid, + username = Username}, Topics) -> Replace = fun(Topic) -> rep(<<"%u">>, Username, rep(<<"%c">>, ClientId, Topic)) end, TopicTable = [{Replace(Topic), Qos} || {Topic, Qos} <- Topics], diff --git a/src/emqx_protocol.erl b/src/emqx_protocol.erl index 1623432a0..2976c79a5 100644 --- a/src/emqx_protocol.erl +++ b/src/emqx_protocol.erl @@ -122,8 +122,8 @@ client(#proto_state{client_id = ClientId, WillMsg =:= undefined -> undefined; true -> WillMsg#message.topic end, - #client{id = ClientId, - pid = ClientPid, + #client{client_id = ClientId, + client_pid = ClientPid, username = Username, peername = Peername, mountpoint = MountPoint}. @@ -327,7 +327,7 @@ publish(Packet = ?PUBLISH_PACKET(?QOS_0, _PacketId), mountpoint = MountPoint, session = Session}) -> Msg = emqx_packet:to_message(Packet), - Msg1 = Msg#message{from = #client{id = ClientId, username = Username}}, + Msg1 = Msg#message{from = #client{client_id = ClientId, username = Username}}, emqx_session:publish(Session, mount(replvar(MountPoint, State), Msg1)); publish(Packet = ?PUBLISH_PACKET(?QOS_1, _PacketId), State) -> @@ -343,7 +343,7 @@ with_puback(Type, Packet = ?PUBLISH_PACKET(_Qos, PacketId), session = Session}) -> %% TODO: ... Msg = emqx_packet:to_message(Packet), - Msg1 = Msg#message{from = #client{id = ClientId, username = Username}}, + Msg1 = Msg#message{from = #client{client_id = ClientId, username = Username}}, case emqx_session:publish(Session, mount(replvar(MountPoint, State), Msg1)) of ok -> send(?PUBACK_PACKET(Type, PacketId), State); diff --git a/src/emqx_router.erl b/src/emqx_router.erl index 3903dc005..e4741c8c4 100644 --- a/src/emqx_router.erl +++ b/src/emqx_router.erl @@ -153,9 +153,8 @@ cast(Router, Msg) -> pick(Topic) -> gproc_pool:pick_worker(router, Topic). -%%FIXME: OOM? dump() -> - [{route, [{To, Dest} || #route{topic = To, dest = Dest} <- ets:tab2list(route)]}]. + ets:tab2list(route). %%-------------------------------------------------------------------- %% gen_server callbacks diff --git a/src/emqx_session.erl b/src/emqx_session.erl index b9c1c87f0..d1b691d32 100644 --- a/src/emqx_session.erl +++ b/src/emqx_session.erl @@ -299,7 +299,7 @@ init(#{clean_start := CleanStart, force_gc_count = ForceGcCount, ignore_loop_deliver = IgnoreLoopDeliver, created_at = os:timestamp()}, - %%emqx_sm:register_session(ClientId, info(State)), + emqx_sm:register_session(ClientId, self()), emqx_hooks:run('session.created', [ClientId, Username]), io:format("Session started: ~p~n", [self()]), {ok, emit_stats(State), hibernate}. diff --git a/src/emqx_sm.erl b/src/emqx_sm.erl index d0e974456..f4ea1468c 100644 --- a/src/emqx_sm.erl +++ b/src/emqx_sm.erl @@ -24,128 +24,134 @@ -export([open_session/1, lookup_session/1, close_session/1]). -export([resume_session/1, discard_session/1]). --export([register_session/1, unregister_session/1, unregister_session/2]). +-export([register_session/1, register_session/2]). +-export([unregister_session/1, unregister_session/2]). -%% lock_session/1, create_session/1, unlock_session/1, - --export([dispatch/3]). +%% Internal functions for rpc +-export([lookup/1, dispatch/3]). -export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]). --record(state, {stats_fun, stats_timer, monitors = #{}}). +-record(state, {stats, pids = #{}}). --spec(start_link(StatsFun :: fun()) -> {ok, pid()} | ignore | {error, term()}). +-spec(start_link(fun()) -> {ok, pid()} | ignore | {error, term()}). start_link(StatsFun) -> gen_server:start_link({local, ?MODULE}, ?MODULE, [StatsFun], []). -open_session(Session = #{client_id := ClientId, clean_start := true}) -> - with_lock(ClientId, - fun() -> - io:format("Nodelist: ~p~n", [ekka_membership:nodelist()]), - case rpc:multicall(ekka_membership:nodelist(), ?MODULE, discard_session, [ClientId]) of - {_Res, []} -> ok; - {_Res, BadNodes} -> emqx_log:error("[SM] Bad nodes found when lock a session: ~p", [BadNodes]) +open_session(Attrs = #{clean_start := true, + client_id := ClientId, client_pid := ClientPid}) -> + CleanStart = fun(_) -> + discard_session(ClientId, ClientPid), + emqx_session_sup:start_session(Attrs) + end, + emqx_sm_locker:trans(ClientId, CleanStart); + +open_session(Attrs = #{clean_start := false, + client_id := ClientId, client_pid := ClientPid}) -> + ResumeStart = fun(_) -> + case resume_session(ClientId, ClientPid) of + {ok, SessionPid} -> + {ok, SessionPid}; + {error, not_found} -> + emqx_session_sup:start_session(Attrs); + {error, Reason} -> + {error, Reason} + end end, - io:format("Begin to start session: ~p~n", [Session]), - emqx_session_sup:start_session(Session) - end); - -open_session(Session = #{client_id := ClientId, clean_start := false}) -> - with_lock(ClientId, - fun() -> - {ResL, _BadNodes} = rpc:multicall(ekka_membership:nodelist(), ?MODULE, lookup_session, [ClientId]), - case lists:flatten([Pid || Pid <- ResL, Pid =/= undefined]) of - [] -> - {ok, emqx_session_sup:start_session(Session)}; - [SessPid|_] -> - case resume_session(SessPid) of - ok -> {ok, SessPid}; - {error, Reason} -> - emqx_log:error("[SM] Failed to resume session: ~p, ~p", [Session, Reason]), - emqx_session_sup:start_session(Session) - end - end - end). - -resume_session(SessPid) when node(SessPid) == node() -> - case is_process_alive(SessPid) of - true -> - emqx_session:resume(SessPid, self()); - false -> - emqx_log:error("Cannot resume ~p which seems already dead!", [SessPid]), - {error, session_died} - end; - -resume_session(SessPid) -> - case rpc:call(node(SessPid), emqx_session, resume, [SessPid]) of - ok -> {ok, SessPid}; - {badrpc, Reason} -> - {error, Reason}; - {error, Reason} -> - {error, Reason} - end. + emqx_sm_locker:trans(ClientId, ResumeStart). discard_session(ClientId) -> + discard_session(ClientId, self()). + +discard_session(ClientId, ClientPid) -> + lists:foreach(fun({_, SessionPid}) -> + catch emqx_session:discard(SessionPid, ClientPid) + end, lookup_session(ClientId)). + +resume_session(ClientId) -> + resume_session(ClientId, self()). + +resume_session(ClientId, ClientPid) -> case lookup_session(ClientId) of - undefined -> ok; - Pid -> emqx_session:discard(Pid) + [] -> {error, not_found}; + [{_, SessionPid}] -> + ok = emqx_session:resume(SessionPid, ClientPid), + {ok, SessionPid}; + [{_, SessionPid}|_More] = Sessions -> + emqx_log:error("[SM] More than one session found: ~p", [Sessions]), + ok = emqx_session:resume(SessionPid, ClientPid), + {ok, SessionPid} end. lookup_session(ClientId) -> - try ets:lookup_element(session, ClientId, 2) catch error:badarg -> undefined end. + {ResL, _} = multicall(?MODULE, lookup, [ClientId]), + lists:append(ResL). -close_session(SessPid) -> - emqx_session:close(SessPid). +close_session(ClientId) -> + lists:foreach(fun(#session{pid = SessionPid}) -> + emqx_session:close(SessionPid) + end, lookup_session(ClientId)). -with_lock(ClientId, Fun) -> - case emqx_sm_locker:lock(ClientId) of - true -> Result = Fun(), - emqx_sm_locker:unlock(ClientId), - Result; - false -> {error, client_id_unavailable}; - {error, Reason} -> {error, Reason} - end. - --spec(register_session(client_id()) -> true). register_session(ClientId) -> - ets:insert(session, {ClientId, self()}). + register_session(ClientId, self()). + +register_session(ClientId, SessionPid) -> + ets:insert(session, {ClientId, SessionPid}). unregister_session(ClientId) -> unregister_session(ClientId, self()). -unregister_session(ClientId, Pid) -> +unregister_session(ClientId, SessionPid) -> case ets:lookup(session, ClientId) of - [{_, Pid}] -> - ets:delete_object(session, {ClientId, Pid}); + [Session = {ClientId, SessionPid}] -> + ets:delete(session_attrs, Session), + ets:delete(session_stats, Session), + ets:delete_object(session, Session); _ -> false end. dispatch(ClientId, Topic, Msg) -> - case lookup_session(ClientId) of - Pid when is_pid(Pid) -> + case lookup(ClientId) of + [{_, Pid}] -> Pid ! {dispatch, Topic, Msg}; - undefined -> + [] -> emqx_hooks:run('message.dropped', [ClientId, Msg]) end. +lookup(ClientId) -> + ets:lookup(session, ClientId). + +multicall(Mod, Fun, Args) -> + multicall(ekka:nodelist(up), Mod, Fun, Args). + +multicall([Node], Mod, Fun, Args) when Node == node() -> + Res = erlang:apply(Mod, Fun, Args), [Res]; + +multicall(Nodes, Mod, Fun, Args) -> + {ResL, _} = emqx_rpc:multicall(Nodes, Mod, Fun, Args), + ResL. + %%-------------------------------------------------------------------- %% gen_server callbacks %%-------------------------------------------------------------------- init([StatsFun]) -> + {ok, sched_stats(StatsFun, #state{pids = #{}})}. + +sched_stats(Fun, State) -> {ok, TRef} = timer:send_interval(timer:seconds(1), stats), - {ok, #state{stats_fun = StatsFun, stats_timer = TRef}}. + State#state{stats = #{func => Fun, timer => TRef}}. handle_call(Req, _From, State) -> emqx_log:error("[SM] Unexpected request: ~p", [Req]), {reply, ignore, State}. -handle_cast({monitor_session, SessionPid, ClientId}, - State = #state{monitors = Monitors}) -> - MRef = erlang:monitor(process, SessionPid), - {noreply, State#state{monitors = maps:put(MRef, ClientId, Monitors)}}; +handle_cast({registered, ClientId, SessionPid}, + State = #state{pids = Pids}) -> + _ = erlang:monitor(process, SessionPid), + {noreply, State#state{pids = maps:put(SessionPid, ClientId, Pids)}}; handle_cast(Msg, State) -> emqx_log:error("[SM] Unexpected msg: ~p", [Msg]), @@ -154,14 +160,14 @@ handle_cast(Msg, State) -> handle_info(stats, State) -> {noreply, setstats(State), hibernate}; -handle_info({'DOWN', MRef, process, DownPid, _Reason}, - State = #state{monitors = Monitors}) -> - case maps:find(MRef, Monitors) of - {ok, {ClientId, Pid}} -> - ets:delete_object(session, {ClientId, Pid}), - {noreply, setstats(State#state{monitors = maps:remove(MRef, Monitors)})}; +handle_info({'DOWN', _MRef, process, DownPid, _Reason}, + State = #state{pids = Pids}) -> + case maps:find(DownPid, Pids) of + {ok, ClientId} -> + unregister_session(ClientId, DownPid), + {noreply, State#state{pids = maps:remove(DownPid, Pids)}}; error -> - emqx_log:error("session ~p not found", [DownPid]), + emqx_log:error("[SM] Session ~p not found", [DownPid]), {noreply, State} end; @@ -169,7 +175,7 @@ handle_info(Info, State) -> emqx_log:error("[SM] Unexpected info: ~p", [Info]), {noreply, State}. -terminate(_Reason, _State = #state{stats_timer = TRef}) -> +terminate(_Reason, _State = #state{stats = #{timer := TRef}}) -> timer:cancel(TRef). code_change(_OldVsn, State, _Extra) -> @@ -179,6 +185,6 @@ code_change(_OldVsn, State, _Extra) -> %% Internal functions %%-------------------------------------------------------------------- -setstats(State = #state{stats_fun = StatsFun}) -> - StatsFun(ets:info(session, size)), State. +setstats(State = #state{stats = #{func := Fun}}) -> + Fun(ets:info(session, size)), State. diff --git a/src/emqx_sm_locker.erl b/src/emqx_sm_locker.erl index 150481979..116b75bd7 100644 --- a/src/emqx_sm_locker.erl +++ b/src/emqx_sm_locker.erl @@ -18,16 +18,44 @@ -include("emqx.hrl"). -%% Lock/Unlock API based on canal-lock. --export([lock/1, unlock/1]). +-export([start_link/0]). -%% @doc Lock a clientid --spec(lock(client_id()) -> boolean() | {error, term()}). +-export([trans/2, trans/3]). + +-export([lock/1, lock/2, unlock/1]). + +-spec(start_link() -> {ok, pid()} | ignore | {error, term()}). +start_link() -> + ekka_locker:start_link(?MODULE). + +-spec(trans(client_id(), fun(([node()]) -> any())) -> any()). +trans(ClientId, Fun) -> + trans(ClientId, Fun, undefined). + +-spec(trans(client_id(), fun(([node()]) -> any()), + ekka_locker:piggyback()) -> any()). +trans(ClientId, Fun, Piggyback) -> + case lock(ClientId, Piggyback) of + {true, Nodes} -> + try Fun(Nodes) after unlock(ClientId) end; + {false, _Nodes} -> + {error, client_id_unavailable} + end. + +-spec(lock(client_id()) -> ekka_locker:lock_result()). lock(ClientId) -> - rpc:call(ekka_membership:leader(), emqx_locker, lock, [ClientId]). + ekka_locker:aquire(?MODULE, ClientId, strategy()). -%% @doc Unlock a clientid --spec(unlock(client_id()) -> ok). +-spec(lock(client_id(), ekka_locker:piggyback()) + -> ekka_locker:lock_result()). +lock(ClientId, Piggyback) -> + ekka_locker:aquire(?MODULE, ClientId, strategy(), Piggyback). + +-spec(unlock(client_id()) -> {boolean(), [node()]}). unlock(ClientId) -> - rpc:call(ekka_membership:leader(), emqx_locker, unlock, [ClientId]). + ekka_locker:release(?MODULE, ClientId, strategy()). + +-spec(strategy() -> local | one | quorum | all). +strategy() -> + application:get_env(emqx, session_locking_strategy, quorum). diff --git a/src/emqx_tracer.erl b/src/emqx_tracer.erl index fa4c25892..74284c35a 100644 --- a/src/emqx_tracer.erl +++ b/src/emqx_tracer.erl @@ -51,7 +51,7 @@ start_link() -> trace(publish, From, _Msg) when is_atom(From) -> %% Dont' trace '$SYS' publish ignore; -trace(publish, #client{id = ClientId, username = Username}, +trace(publish, #client{client_id = ClientId, username = Username}, #message{topic = Topic, payload = Payload}) -> lager:info([{client, ClientId}, {topic, Topic}], "~s/~s PUBLISH to ~s: ~p", [ClientId, Username, Topic, Payload]);