From 63ef00a20878633447c8342e997e48e756595a58 Mon Sep 17 00:00:00 2001 From: JianBo He Date: Wed, 2 Mar 2022 16:43:53 +0800 Subject: [PATCH] fix(gw): add takeover_session/3 for cm_proto_v1 --- apps/emqx_gateway/src/emqx_gateway_cm.erl | 16 ++++++++-------- .../src/proto/emqx_gateway_cm_proto_v1.erl | 7 +++++++ apps/emqx_gateway/test/emqx_gateway_cm_SUITE.erl | 9 +++++---- 3 files changed, 20 insertions(+), 12 deletions(-) diff --git a/apps/emqx_gateway/src/emqx_gateway_cm.erl b/apps/emqx_gateway/src/emqx_gateway_cm.erl index d5b437527..24ec4415f 100644 --- a/apps/emqx_gateway/src/emqx_gateway_cm.erl +++ b/apps/emqx_gateway/src/emqx_gateway_cm.erl @@ -78,6 +78,7 @@ , do_get_chan_stats/3 , do_set_chan_stats/4 , do_kick_session/4 + , do_takeover_session/3 , do_get_chann_conn_mod/3 , do_call/4 , do_call/5 @@ -301,7 +302,7 @@ open_session(GwName, true = _CleanStart, ClientInfo, ConnInfo, CreateSessionFun, Self = self(), ClientId = maps:get(clientid, ClientInfo), Fun = fun(_) -> - ok = discard_session(GwName, ClientId), + _ = discard_session(GwName, ClientId), Session = create_session(GwName, ClientInfo, ConnInfo, @@ -394,7 +395,7 @@ takeover_session(GwName, ClientId) -> , chan_pids => ChanPids }), lists:foreach(fun(StalePid) -> - catch discard_session(ClientId, StalePid) + catch discard_session(GwName, ClientId, StalePid) end, StalePids), do_takeover_session(GwName, ClientId, ChanPid) end. @@ -415,21 +416,20 @@ do_takeover_session(GwName, ClientId, ChanPid) -> wrap_rpc(emqx_gateway_cm_proto_v1:takeover_session(GwName, ClientId, ChanPid)). %% @doc Discard all the sessions identified by the ClientId. --spec discard_session(GwName :: gateway_name(), binary()) -> ok. +-spec discard_session(GwName :: gateway_name(), binary()) -> ok | {error, not_found}. discard_session(GwName, ClientId) when is_binary(ClientId) -> case lookup_channels(GwName, ClientId) of - [] -> ok; + [] -> {error, not_found}; ChanPids -> lists:foreach(fun(Pid) -> discard_session(GwName, ClientId, Pid) end, ChanPids) end. discard_session(GwName, ClientId, ChanPid) -> kick_session(GwName, discard, ClientId, ChanPid). --spec kick_session(gateway_name(), emqx_types:clientid()) -> ok. - +-spec kick_session(gateway_name(), emqx_types:clientid()) -> ok | {error, not_found}. kick_session(GwName, ClientId) -> case lookup_channels(GwName, ClientId) of - [] -> ok; + [] -> {error, not_found}; ChanPids -> ChanPids > 1 andalso begin ?SLOG(warning, #{ msg => "more_than_one_channel_found" @@ -438,7 +438,7 @@ kick_session(GwName, ClientId) -> #{clientid => ClientId}) end, lists:foreach(fun(Pid) -> - kick_session(GwName, ClientId, Pid) + _ = kick_session(GwName, ClientId, Pid) end, ChanPids) end. diff --git a/apps/emqx_gateway/src/proto/emqx_gateway_cm_proto_v1.erl b/apps/emqx_gateway/src/proto/emqx_gateway_cm_proto_v1.erl index d82255066..08c713b14 100644 --- a/apps/emqx_gateway/src/proto/emqx_gateway_cm_proto_v1.erl +++ b/apps/emqx_gateway/src/proto/emqx_gateway_cm_proto_v1.erl @@ -27,6 +27,7 @@ , kick_session/4 , get_chann_conn_mod/3 , lookup_by_clientid/3 + , takeover_session/3 , call/4 , call/5 , cast/4 @@ -81,6 +82,12 @@ get_chann_conn_mod(GwName, ClientId, ChanPid) -> rpc:call(node(ChanPid), emqx_gateway_cm, do_get_chann_conn_mod, [GwName, ClientId, ChanPid]). +-spec takeover_session(emqx_gateway_cm:gateway_name(), + emqx_types:clientid(), + pid()) -> boolean() | {badrpc, _}. +takeover_session(GwName, ClientId, ChanPid) -> + rpc:call(node(ChanPid), emqx_gateway_cm, do_takeover_session, [GwName, ClientId, ChanPid]). + -spec call(emqx_gateway_cm:gateway_name(), emqx_types:clientid(), pid(), diff --git a/apps/emqx_gateway/test/emqx_gateway_cm_SUITE.erl b/apps/emqx_gateway/test/emqx_gateway_cm_SUITE.erl index 1f715774f..dabea7c95 100644 --- a/apps/emqx_gateway/test/emqx_gateway_cm_SUITE.erl +++ b/apps/emqx_gateway/test/emqx_gateway_cm_SUITE.erl @@ -61,9 +61,10 @@ end_per_testcase(_TestCase, Conf) -> %%-------------------------------------------------------------------- t_open_session(_) -> - {error, not_supported_now} = emqx_gateway_cm:open_session( - ?GWNAME, false, clientinfo(), conninfo(), - fun(_, _) -> #{} end), + {ok, #{present := false, + session := #{}}} = emqx_gateway_cm:open_session( + ?GWNAME, false, clientinfo(), conninfo(), + fun(_, _) -> #{} end), {ok, SessionRes} = emqx_gateway_cm:open_session( ?GWNAME, true, clientinfo(), conninfo(), @@ -189,7 +190,7 @@ t_kick_session(_) -> ok = emqx_gateway_cm:kick_session(?GWNAME, ?CLIENTID), - receive discard -> ok + receive kick -> ok after 100 -> ?assert(false, "waiting discard msg timeout") end, receive