diff --git a/src/emqx_cm.erl b/src/emqx_cm.erl index dd8b02e54..9ab737add 100644 --- a/src/emqx_cm.erl +++ b/src/emqx_cm.erl @@ -27,9 +27,7 @@ -export([start_link/0]). --export([ register_channel/1 - , register_channel/2 - , register_channel/3 +-export([ register_channel/3 , unregister_channel/1 ]). @@ -45,6 +43,8 @@ , set_chan_stats/2 ]). +-export([get_chann_conn_mod/2]). + -export([ open_session/3 , discard_session/1 , discard_session/2 @@ -98,28 +98,29 @@ start_link() -> %% API %%-------------------------------------------------------------------- -%% @doc Register a channel. --spec(register_channel(emqx_types:clientid()) -> ok). -register_channel(ClientId) -> - register_channel(ClientId, self()). - -%% @doc Register a channel with pid. --spec(register_channel(emqx_types:clientid(), chan_pid()) -> ok). -register_channel(ClientId, ChanPid) when is_pid(ChanPid) -> - Chan = {ClientId, ChanPid}, - true = ets:insert(?CHAN_TAB, Chan), - true = ets:insert(?CHAN_CONN_TAB, Chan), - ok = emqx_cm_registry:register_channel(Chan), - cast({registered, Chan}). - %% @doc Register a channel with info and stats. -spec(register_channel(emqx_types:clientid(), emqx_types:infos(), emqx_types:stats()) -> ok). -register_channel(ClientId, Info, Stats) -> +register_channel(ClientId, Info = #{conninfo := ConnInfo}, Stats) -> Chan = {ClientId, ChanPid = self()}, true = ets:insert(?CHAN_INFO_TAB, {Chan, Info, Stats}), - register_channel(ClientId, ChanPid). + register_channel(ClientId, ChanPid, ConnInfo); + +%% @private +%% @doc Register a channel with pid and conn_mod. +%% +%% There is a Race-Condition on one node or cluster when many connections +%% login to Broker with the same clientid. We should register it and save +%% the conn_mod first for taking up the clientid access right. +%% +%% Note that: It should be called on a lock transaction +register_channel(ClientId, ChanPid, #{conn_mod := ConnMod}) when is_pid(ChanPid) -> + Chan = {ClientId, ChanPid}, + true = ets:insert(?CHAN_TAB, Chan), + true = ets:insert(?CHAN_CONN_TAB, {Chan, ConnMod}), + ok = emqx_cm_registry:register_channel(Chan), + cast({registered, Chan}). %% @doc Unregister a channel. -spec(unregister_channel(emqx_types:clientid()) -> ok). @@ -130,7 +131,7 @@ unregister_channel(ClientId) when is_binary(ClientId) -> %% @private do_unregister_channel(Chan) -> ok = emqx_cm_registry:unregister_channel(Chan), - true = ets:delete_object(?CHAN_CONN_TAB, Chan), + true = ets:delete(?CHAN_CONN_TAB, Chan), true = ets:delete(?CHAN_INFO_TAB, Chan), ets:delete_object(?CHAN_TAB, Chan). @@ -204,24 +205,29 @@ set_chan_stats(ClientId, ChanPid, Stats) -> pendings => list()}} | {error, Reason :: term()}). open_session(true, ClientInfo = #{clientid := ClientId}, ConnInfo) -> + Self = self(), CleanStart = fun(_) -> ok = discard_session(ClientId), Session = create_session(ClientInfo, ConnInfo), + register_channel(ClientId, Self, ConnInfo), {ok, #{session => Session, present => false}} end, emqx_cm_locker:trans(ClientId, CleanStart); open_session(false, ClientInfo = #{clientid := ClientId}, ConnInfo) -> + Self = self(), ResumeStart = fun(_) -> case takeover_session(ClientId) of {ok, ConnMod, ChanPid, Session} -> ok = emqx_session:resume(ClientInfo, Session), Pendings = ConnMod:call(ChanPid, {takeover, 'end'}), + register_channel(ClientId, Self, ConnInfo), {ok, #{session => Session, present => true, pendings => Pendings}}; {error, not_found} -> Session = create_session(ClientInfo, ConnInfo), + register_channel(ClientId, Self, ConnInfo), {ok, #{session => Session, present => false}} end end, @@ -251,8 +257,8 @@ takeover_session(ClientId) -> end. takeover_session(ClientId, ChanPid) when node(ChanPid) == node() -> - case get_chan_info(ClientId, ChanPid) of - #{conninfo := #{conn_mod := ConnMod}} -> + case get_chann_conn_mod(ClientId, ChanPid) of + ConnMod when is_atom(ConnMod) -> Session = ConnMod:call(ChanPid, {takeover, 'begin'}), {ok, ConnMod, ChanPid, Session}; undefined -> @@ -282,8 +288,8 @@ discard_session(ClientId) when is_binary(ClientId) -> end. discard_session(ClientId, ChanPid) when node(ChanPid) == node() -> - case get_chan_info(ClientId, ChanPid) of - #{conninfo := #{conn_mod := ConnMod}} -> + case get_chann_conn_mod(ClientId, ChanPid) of + ConnMod when is_atom(ConnMod) -> ConnMod:call(ChanPid, discard); undefined -> ok end; @@ -411,3 +417,12 @@ update_stats({Tab, Stat, MaxStat}) -> Size -> emqx_stats:setstat(Stat, MaxStat, Size) end. +get_chann_conn_mod(ClientId, ChanPid) when node(ChanPid) == node() -> + Chan = {ClientId, ChanPid}, + try [ConnMod] = ets:lookup_element(?CHAN_CONN_TAB, Chan, 2), ConnMod + catch + error:badarg -> undefined + end; +get_chann_conn_mod(ClientId, ChanPid) -> + rpc_call(node(ChanPid), get_chann_conn_mod, [ClientId, ChanPid]). + diff --git a/test/emqx_cm_SUITE.erl b/test/emqx_cm_SUITE.erl index 183dd5f30..8d143b563 100644 --- a/test/emqx_cm_SUITE.erl +++ b/test/emqx_cm_SUITE.erl @@ -23,6 +23,13 @@ -include_lib("eunit/include/eunit.hrl"). -define(CM, emqx_cm). +-define(ChanInfo,#{conninfo => + #{socktype => tcp, + peername => {{127,0,0,1}, 5000}, + sockname => {{127,0,0,1}, 1883}, + peercert => nossl, + conn_mod => emqx_connection, + receive_maximum => 100}}). %%-------------------------------------------------------------------- %% CT callbacks @@ -43,13 +50,13 @@ end_per_suite(_Config) -> %%-------------------------------------------------------------------- t_reg_unreg_channel(_) -> - ok = emqx_cm:register_channel(<<"clientid">>), + ok = emqx_cm:register_channel(<<"clientid">>, ?ChanInfo, []), ?assertEqual([self()], emqx_cm:lookup_channels(<<"clientid">>)), ok = emqx_cm:unregister_channel(<<"clientid">>), ?assertEqual([], emqx_cm:lookup_channels(<<"clientid">>)). t_get_set_chan_info(_) -> - Info = #{proto_ver => 4, proto_name => <<"MQTT">>}, + Info = ?ChanInfo, ok = emqx_cm:register_channel(<<"clientid">>, Info, []), ?assertEqual(Info, emqx_cm:get_chan_info(<<"clientid">>)), Info1 = Info#{proto_ver => 5}, @@ -60,7 +67,7 @@ t_get_set_chan_info(_) -> t_get_set_chan_stats(_) -> Stats = [{recv_oct, 10}, {send_oct, 8}], - ok = emqx_cm:register_channel(<<"clientid">>, #{}, Stats), + ok = emqx_cm:register_channel(<<"clientid">>, ?ChanInfo, Stats), ?assertEqual(Stats, emqx_cm:get_chan_stats(<<"clientid">>)), Stats1 = [{recv_oct, 10}|Stats], true = emqx_cm:set_chan_stats(<<"clientid">>, Stats1), @@ -69,27 +76,89 @@ t_get_set_chan_stats(_) -> ?assertEqual(undefined, emqx_cm:get_chan_stats(<<"clientid">>)). t_open_session(_) -> + ok = meck:new(emqx_connection, [passthrough, no_history]), + ok = meck:expect(emqx_connection, call, fun(_, _) -> ok end), + ClientInfo = #{zone => external, clientid => <<"clientid">>, username => <<"username">>, peerhost => {127,0,0,1}}, - ConnInfo = #{peername => {{127,0,0,1}, 5000}, + ConnInfo = #{socktype => tcp, + peername => {{127,0,0,1}, 5000}, + sockname => {{127,0,0,1}, 1883}, + peercert => nossl, + conn_mod => emqx_connection, receive_maximum => 100}, {ok, #{session := Session1, present := false}} = emqx_cm:open_session(true, ClientInfo, ConnInfo), ?assertEqual(100, emqx_session:info(inflight_max, Session1)), {ok, #{session := Session2, present := false}} - = emqx_cm:open_session(false, ClientInfo, ConnInfo), - ?assertEqual(100, emqx_session:info(inflight_max, Session2)). + = emqx_cm:open_session(true, ClientInfo, ConnInfo), + ?assertEqual(100, emqx_session:info(inflight_max, Session2)), + + emqx_cm:unregister_channel(<<"clientid">>), + ok = meck:unload(emqx_connection). + +t_open_session_race_condition(_) -> + ClientInfo = #{zone => external, + clientid => <<"clientid">>, + username => <<"username">>, + peerhost => {127,0,0,1}}, + ConnInfo = #{socktype => tcp, + peername => {{127,0,0,1}, 5000}, + sockname => {{127,0,0,1}, 1883}, + peercert => nossl, + conn_mod => emqx_connection, + receive_maximum => 100}, + + Parent = self(), + OpenASession = fun() -> + timer:sleep(rand:uniform(100)), + OpenR = (emqx_cm:open_session(true, ClientInfo, ConnInfo)), + Parent ! OpenR, + case OpenR of + {ok, _} -> + receive + {'$gen_call', From, discard} -> + gen_server:reply(From, ok), ok + end; + {error, Reason} -> + exit(Reason) + end + end, + [spawn( + fun() -> + spawn(OpenASession), + spawn(OpenASession) + end) || _ <- lists:seq(1, 1000)], + + WaitingRecv = fun _Wr(N1, N2, 0) -> + {N1, N2}; + _Wr(N1, N2, Rest) -> + receive + {ok, _} -> _Wr(N1+1, N2, Rest-1); + {error, _} -> _Wr(N1, N2+1, Rest-1) + end + end, + + ct:pal("Race condition status: ~p~n", [WaitingRecv(0, 0, 2000)]), + + ?assertEqual(1, ets:info(emqx_channel, size)), + ?assertEqual(1, ets:info(emqx_channel_conn, size)), + ?assertEqual(1, ets:info(emqx_channel_registry, size)), + + [Pid] = emqx_cm:lookup_channels(<<"clientid">>), + exit(Pid, kill), timer:sleep(100), + ?assertEqual([], emqx_cm:lookup_channels(<<"clientid">>)). t_discard_session(_) -> ok = meck:new(emqx_connection, [passthrough, no_history]), ok = meck:expect(emqx_connection, call, fun(_, _) -> ok end), ok = emqx_cm:discard_session(<<"clientid">>), - ok = emqx_cm:register_channel(<<"clientid">>), + ok = emqx_cm:register_channel(<<"clientid">>, ?ChanInfo, []), ok = emqx_cm:discard_session(<<"clientid">>), ok = emqx_cm:unregister_channel(<<"clientid">>), - ok = emqx_cm:register_channel(<<"clientid">>, #{conninfo => #{conn_mod => emqx_connection}}, []), + ok = emqx_cm:register_channel(<<"clientid">>, ?ChanInfo, []), ok = emqx_cm:discard_session(<<"clientid">>), ok = meck:expect(emqx_connection, call, fun(_, _) -> error(testing) end), ok = emqx_cm:discard_session(<<"clientid">>), @@ -97,35 +166,26 @@ t_discard_session(_) -> ok = meck:unload(emqx_connection). t_takeover_session(_) -> - ok = meck:new(emqx_connection, [passthrough, no_history]), - ok = meck:expect(emqx_connection, call, fun(_, _) -> test end), {error, not_found} = emqx_cm:takeover_session(<<"clientid">>), - ok = emqx_cm:register_channel(<<"clientid">>), - {error, not_found} = emqx_cm:takeover_session(<<"clientid">>), - ok = emqx_cm:unregister_channel(<<"clientid">>), - ok = emqx_cm:register_channel(<<"clientid">>, #{conninfo => #{conn_mod => emqx_connection}}, []), - Pid = self(), - {ok, emqx_connection, Pid, test} = emqx_cm:takeover_session(<<"clientid">>), erlang:spawn(fun() -> - ok = emqx_cm:register_channel(<<"clientid">>, #{conninfo => #{conn_mod => emqx_connection}}, []), - timer:sleep(1000) + ok = emqx_cm:register_channel(<<"clientid">>, ?ChanInfo, []), + receive + {'$gen_call', From, {takeover, 'begin'}} -> + gen_server:reply(From, test), ok + end end), - ct:sleep(100), + timer:sleep(100), {ok, emqx_connection, _, test} = emqx_cm:takeover_session(<<"clientid">>), - ok = emqx_cm:unregister_channel(<<"clientid">>), - ok = meck:unload(emqx_connection). + emqx_cm:unregister_channel(<<"clientid">>). t_kick_session(_) -> ok = meck:new(emqx_connection, [passthrough, no_history]), ok = meck:expect(emqx_connection, call, fun(_, _) -> test end), {error, not_found} = emqx_cm:kick_session(<<"clientid">>), - ok = emqx_cm:register_channel(<<"clientid">>), - {error, not_found} = emqx_cm:kick_session(<<"clientid">>), - ok = emqx_cm:unregister_channel(<<"clientid">>), - ok = emqx_cm:register_channel(<<"clientid">>, #{conninfo => #{conn_mod => emqx_connection}}, []), + ok = emqx_cm:register_channel(<<"clientid">>, ?ChanInfo, []), test = emqx_cm:kick_session(<<"clientid">>), erlang:spawn(fun() -> - ok = emqx_cm:register_channel(<<"clientid">>, #{conninfo => #{conn_mod => emqx_connection}}, []), + ok = emqx_cm:register_channel(<<"clientid">>, ?ChanInfo, []), timer:sleep(1000) end), ct:sleep(100),