From 2dae8020ec7f5b22290fc86b9ecd801e082b7f98 Mon Sep 17 00:00:00 2001 From: Andrew Mayorov Date: Fri, 15 Sep 2023 21:01:55 +0400 Subject: [PATCH] refactor(cm): avoid deep indirection in `emqx_session_mem` --- apps/emqx/src/emqx_cm.erl | 37 ++++++++++++++++++------------ apps/emqx/src/emqx_session_mem.erl | 21 ++++++++--------- apps/emqx/test/emqx_cm_SUITE.erl | 29 +++++++++++------------ 3 files changed, 45 insertions(+), 42 deletions(-) diff --git a/apps/emqx/src/emqx_cm.erl b/apps/emqx/src/emqx_cm.erl index cbe1a8f55..89300a4f6 100644 --- a/apps/emqx/src/emqx_cm.erl +++ b/apps/emqx/src/emqx_cm.erl @@ -52,7 +52,8 @@ open_session/3, discard_session/1, discard_session/2, - takeover_channel_session/2, + takeover_session_begin/1, + takeover_session_end/1, kick_session/1, kick_session/2 ]). @@ -118,6 +119,8 @@ _Stats :: emqx_types:stats() }. +-type takeover_state() :: {_ConnMod :: module(), _ChanPid :: pid()}. + -define(CHAN_STATS, [ {?CHAN_TAB, 'channels.count', 'channels.max'}, {?CHAN_TAB, 'sessions.count', 'sessions.max'}, @@ -289,28 +292,32 @@ create_register_session(ClientInfo = #{clientid := ClientId}, ConnInfo, ChanPid) {ok, #{session => Session, present => false}}. %% @doc Try to takeover a session from existing channel. -%% Naming is wierd, because `takeover_session/2` is an RPC target and cannot be renamed. --spec takeover_channel_session(emqx_types:clientid(), _TODO) -> - {ok, emqx_session:session(), _ReplayContext} | none | {error, _Reason}. -takeover_channel_session(ClientId, OnTakeover) -> - takeover_channel_session(ClientId, pick_channel(ClientId), OnTakeover). +-spec takeover_session_begin(emqx_types:clientid()) -> + {ok, emqx_session_mem:session(), takeover_state()} | none. +takeover_session_begin(ClientId) -> + takeover_session_begin(ClientId, pick_channel(ClientId)). -takeover_channel_session(ClientId, ChanPid, OnTakeover) when is_pid(ChanPid) -> +takeover_session_begin(ClientId, ChanPid) when is_pid(ChanPid) -> case takeover_session(ClientId, ChanPid) of {living, ConnMod, Session} -> - Session1 = OnTakeover(Session), - case wrap_rpc(emqx_cm_proto_v2:takeover_finish(ConnMod, ChanPid)) of - {ok, Pendings} -> - {ok, Session1, Pendings}; - {error, _} = Error -> - Error - end; + {ok, Session, {ConnMod, ChanPid}}; none -> none end; -takeover_channel_session(_ClientId, undefined, _OnTakeover) -> +takeover_session_begin(_ClientId, undefined) -> none. +%% @doc Conclude the session takeover process. +-spec takeover_session_end(takeover_state()) -> + {ok, _ReplayContext} | {error, _Reason}. +takeover_session_end({ConnMod, ChanPid}) -> + case wrap_rpc(emqx_cm_proto_v2:takeover_finish(ConnMod, ChanPid)) of + {ok, Pendings} -> + {ok, Pendings}; + {error, _} = Error -> + Error + end. + -spec pick_channel(emqx_types:clientid()) -> maybe(pid()). pick_channel(ClientId) -> diff --git a/apps/emqx/src/emqx_session_mem.erl b/apps/emqx/src/emqx_session_mem.erl index e5e76ed31..42f261321 100644 --- a/apps/emqx/src/emqx_session_mem.erl +++ b/apps/emqx/src/emqx_session_mem.erl @@ -196,17 +196,16 @@ destroy(_Session) -> -spec open(clientinfo(), emqx_types:conninfo()) -> {true, session(), replayctx()} | false. open(ClientInfo = #{clientid := ClientId}, _ConnInfo) -> - case - emqx_cm:takeover_channel_session( - ClientId, - fun(Session) -> resume(ClientInfo, Session) end - ) - of - {ok, Session, Pendings} -> - clean_session(ClientInfo, Session, Pendings); - {error, _} -> - % TODO log error? - false; + case emqx_cm:takeover_session_begin(ClientId) of + {ok, SessionRemote, TakeoverState} -> + Session = resume(ClientInfo, SessionRemote), + case emqx_cm:takeover_session_end(TakeoverState) of + {ok, Pendings} -> + clean_session(ClientInfo, Session, Pendings); + {error, _} -> + % TODO log error? + false + end; none -> false end. diff --git a/apps/emqx/test/emqx_cm_SUITE.erl b/apps/emqx/test/emqx_cm_SUITE.erl index ea874987b..8c6712c5e 100644 --- a/apps/emqx/test/emqx_cm_SUITE.erl +++ b/apps/emqx/test/emqx_cm_SUITE.erl @@ -321,7 +321,7 @@ test_stepdown_session(Action, Reason) -> discard -> emqx_cm:discard_session(ClientId); {takeover, _} -> - none = emqx_cm:takeover_channel_session(ClientId, fun ident/1), + none = emqx_cm:takeover_session_begin(ClientId), ok end, case Reason =:= timeout orelse Reason =:= noproc of @@ -381,10 +381,11 @@ t_discard_session_race(_) -> t_takeover_session(_) -> #{conninfo := ConnInfo} = ?ChanInfo, - none = emqx_cm:takeover_channel_session(<<"clientid">>, fun ident/1), + ClientId = <<"clientid">>, + none = emqx_cm:takeover_session_begin(ClientId), Parent = self(), - erlang:spawn_link(fun() -> - ok = emqx_cm:register_channel(<<"clientid">>, self(), ConnInfo), + ChanPid = erlang:spawn_link(fun() -> + ok = emqx_cm:register_channel(ClientId, self(), ConnInfo), Parent ! registered, receive {'$gen_call', From1, {takeover, 'begin'}} -> @@ -398,16 +399,17 @@ t_takeover_session(_) -> receive registered -> ok end, - {ok, test, []} = emqx_cm:takeover_channel_session(<<"clientid">>, fun ident/1), - emqx_cm:unregister_channel(<<"clientid">>). + {ok, test, State = {emqx_connection, ChanPid}} = emqx_cm:takeover_session_begin(ClientId), + {ok, []} = emqx_cm:takeover_session_end(State), + emqx_cm:unregister_channel(ClientId). t_takeover_session_process_gone(_) -> #{conninfo := ConnInfo} = ?ChanInfo, ClientIDTcp = <<"clientidTCP">>, ClientIDWs = <<"clientidWs">>, ClientIDRpc = <<"clientidRPC">>, - none = emqx_cm:takeover_channel_session(ClientIDTcp, fun ident/1), - none = emqx_cm:takeover_channel_session(ClientIDWs, fun ident/1), + none = emqx_cm:takeover_session_begin(ClientIDTcp), + none = emqx_cm:takeover_session_begin(ClientIDWs), meck:new(emqx_connection, [passthrough, no_history]), meck:expect( emqx_connection, @@ -420,7 +422,7 @@ t_takeover_session_process_gone(_) -> end ), ok = emqx_cm:register_channel(ClientIDTcp, self(), ConnInfo), - none = emqx_cm:takeover_channel_session(ClientIDTcp, fun ident/1), + none = emqx_cm:takeover_session_begin(ClientIDTcp), meck:expect( emqx_connection, call, @@ -432,7 +434,7 @@ t_takeover_session_process_gone(_) -> end ), ok = emqx_cm:register_channel(ClientIDWs, self(), ConnInfo), - none = emqx_cm:takeover_channel_session(ClientIDWs, fun ident/1), + none = emqx_cm:takeover_session_begin(ClientIDWs), meck:expect( emqx_connection, call, @@ -444,7 +446,7 @@ t_takeover_session_process_gone(_) -> end ), ok = emqx_cm:register_channel(ClientIDRpc, self(), ConnInfo), - none = emqx_cm:takeover_channel_session(ClientIDRpc, fun ident/1), + none = emqx_cm:takeover_session_begin(ClientIDRpc), emqx_cm:unregister_channel(ClientIDTcp), emqx_cm:unregister_channel(ClientIDWs), emqx_cm:unregister_channel(ClientIDRpc), @@ -463,8 +465,3 @@ t_message(_) -> ?CM ! testing, gen_server:cast(?CM, testing), gen_server:call(?CM, testing). - -%% - -ident(V) -> - V.