From a2ddd9d5f5e243a740aa4c2ef944610af55bdd9d Mon Sep 17 00:00:00 2001 From: Andrew Mayorov Date: Wed, 20 Sep 2023 14:21:52 +0400 Subject: [PATCH] fix(session): respect existing session even if expiry interval = 0 If the original connection had Session-Expiry-Interval > 0, and the new connection set Session-Expiry-Interval = 0, the MQTTv5 spec says that (supposedly) we still have to continue with the existing session (if it hasn't expired yet). Co-Authored-By: Thales Macedo Garitezi --- apps/emqx/integration_test/emqx_ds_SUITE.erl | 4 +- apps/emqx/src/emqx_persistent_session_ds.erl | 48 +++++++------ apps/emqx/src/emqx_session.erl | 72 ++++++++++++++----- apps/emqx/src/emqx_session_mem.erl | 6 +- .../test/emqx_persistent_session_SUITE.erl | 25 ++++++- apps/emqx_durable_storage/src/emqx_ds.erl | 32 ++++++--- 6 files changed, 129 insertions(+), 58 deletions(-) diff --git a/apps/emqx/integration_test/emqx_ds_SUITE.erl b/apps/emqx/integration_test/emqx_ds_SUITE.erl index a35790897..b042aa87a 100644 --- a/apps/emqx/integration_test/emqx_ds_SUITE.erl +++ b/apps/emqx/integration_test/emqx_ds_SUITE.erl @@ -245,8 +245,8 @@ t_session_subscription_idempotency(Config) -> ?assertEqual([{ClientId, SubTopicFilterWords}], get_all_iterator_refs(Node1)), ?assertMatch({ok, [_]}, get_all_iterator_ids(Node1)), ?assertMatch( - {_IsNew = false, #{}, #{SubTopicFilterWords := #{}}}, - erpc:call(Node1, emqx_ds, session_open, [ClientId, #{}]) + {ok, #{}, #{SubTopicFilterWords := #{}}}, + erpc:call(Node1, emqx_ds, session_open, [ClientId]) ) end ), diff --git a/apps/emqx/src/emqx_persistent_session_ds.erl b/apps/emqx/src/emqx_persistent_session_ds.erl index 35e0677c2..e56a05484 100644 --- a/apps/emqx/src/emqx_persistent_session_ds.erl +++ b/apps/emqx/src/emqx_persistent_session_ds.erl @@ -24,7 +24,7 @@ %% Session API -export([ create/3, - open/3, + open/2, destroy/1 ]). @@ -98,12 +98,11 @@ session(). create(#{clientid := ClientID}, _ConnInfo, Conf) -> % TODO: expiration - {true, Session} = open_session(ClientID, Conf), - Session. + ensure_session(ClientID, Conf). --spec open(clientinfo(), conninfo(), emqx_session:conf()) -> - {_IsPresent :: true, session(), []} | {_IsPresent :: false, session()}. -open(#{clientid := ClientID}, _ConnInfo, Conf) -> +-spec open(clientinfo(), conninfo()) -> + {_IsPresent :: true, session(), []} | false. +open(#{clientid := ClientID}, _ConnInfo) -> %% NOTE %% The fact that we need to concern about discarding all live channels here %% is essentially a consequence of the in-memory session design, where we @@ -111,24 +110,31 @@ open(#{clientid := ClientID}, _ConnInfo, Conf) -> %% somehow isolate those idling not-yet-expired sessions into a separate process %% space, and move this call back into `emqx_cm` where it belongs. ok = emqx_cm:discard_session(ClientID), - {IsNew, Session} = open_session(ClientID, Conf), - IsPresent = not IsNew, - case IsPresent of - true -> - {IsPresent, Session, []}; + case open_session(ClientID) of + Session = #{} -> + {true, Session, []}; false -> - {IsPresent, Session} + false end. -open_session(ClientID, Conf) -> - {IsNew, Session, Iterators} = emqx_ds:session_open(ClientID, Conf), - {IsNew, Session#{ - iterators => maps:fold( - fun(Topic, Iterator, Acc) -> Acc#{emqx_topic:join(Topic) => Iterator} end, - #{}, - Iterators - ) - }}. +ensure_session(ClientID, Conf) -> + {ok, Session, #{}} = emqx_ds:session_ensure_new(ClientID, Conf), + Session#{iterators => #{}}. + +open_session(ClientID) -> + case emqx_ds:session_open(ClientID) of + {ok, Session, Iterators} -> + Session#{iterators => prep_iterators(Iterators)}; + false -> + false + end. + +prep_iterators(Iterators) -> + maps:fold( + fun(Topic, Iterator, Acc) -> Acc#{emqx_topic:join(Topic) => Iterator} end, + #{}, + Iterators + ). -spec destroy(session() | clientinfo()) -> ok. destroy(#{id := ClientID}) -> diff --git a/apps/emqx/src/emqx_session.erl b/apps/emqx/src/emqx_session.erl index 71a0e8eea..092c4483a 100644 --- a/apps/emqx/src/emqx_session.erl +++ b/apps/emqx/src/emqx_session.erl @@ -156,6 +156,15 @@ -define(IMPL(S), (get_impl_mod(S))). +%%-------------------------------------------------------------------- +%% Behaviour +%% ------------------------------------------------------------------- + +-callback create(clientinfo(), conninfo(), conf()) -> + t(). +-callback open(clientinfo(), conninfo()) -> + {_IsPresent :: true, t(), _ReplayContext} | false. + %%-------------------------------------------------------------------- %% Create a Session %%-------------------------------------------------------------------- @@ -167,7 +176,11 @@ create(ClientInfo, ConnInfo) -> create(ClientInfo, ConnInfo, Conf) -> % FIXME error conditions - Session = (choose_impl_mod(ConnInfo)):create(ClientInfo, ConnInfo, Conf), + create(choose_impl_mod(ConnInfo), ClientInfo, ConnInfo, Conf). + +create(Mod, ClientInfo, ConnInfo, Conf) -> + % FIXME error conditions + Session = Mod:create(ClientInfo, ConnInfo, Conf), ok = emqx_metrics:inc('session.created'), ok = emqx_hooks:run('session.created', [ClientInfo, info(Session)]), Session. @@ -176,17 +189,29 @@ create(ClientInfo, ConnInfo, Conf) -> {_IsPresent :: true, t(), _ReplayContext} | {_IsPresent :: false, t()}. open(ClientInfo, ConnInfo) -> Conf = get_session_conf(ClientInfo, ConnInfo), - case (choose_impl_mod(ConnInfo)):open(ClientInfo, ConnInfo, Conf) of - {_IsPresent = true, Session, ReplayContext} -> - {true, Session, ReplayContext}; - {_IsPresent = false, NewSession} -> - ok = emqx_metrics:inc('session.created'), - ok = emqx_hooks:run('session.created', [ClientInfo, info(NewSession)]), - {false, NewSession}; - _IsPresent = false -> - {false, create(ClientInfo, ConnInfo, Conf)} + Mods = [Default | _] = choose_impl_candidates(ConnInfo), + %% NOTE + %% Try to look the existing session up in session stores corresponding to the given + %% `Mods` in order, starting from the last one. + case try_open(Mods, ClientInfo, ConnInfo) of + {_IsPresent = true, _, _} = Present -> + Present; + false -> + %% NOTE + %% Nothing was found, create a new session with the `Default` implementation. + {false, create(Default, ClientInfo, ConnInfo, Conf)} end. +try_open([Mod | Rest], ClientInfo, ConnInfo) -> + case try_open(Rest, ClientInfo, ConnInfo) of + {_IsPresent = true, _, _} = Present -> + Present; + false -> + Mod:open(ClientInfo, ConnInfo) + end; +try_open([], _ClientInfo, _ConnInfo) -> + false. + -spec get_session_conf(clientinfo(), conninfo()) -> conf(). get_session_conf( #{zone := Zone}, @@ -527,15 +552,24 @@ get_impl_mod(Session) when ?IS_SESSION_IMPL_DS(Session) -> emqx_persistent_session_ds. -spec choose_impl_mod(conninfo()) -> module(). -choose_impl_mod(#{expiry_interval := 0}) -> - emqx_session_mem; -choose_impl_mod(#{expiry_interval := EI}) when EI > 0 -> - case emqx_persistent_message:is_store_enabled() of - true -> - emqx_persistent_session_ds; - false -> - emqx_session_mem - end. +choose_impl_mod(#{expiry_interval := EI}) -> + hd(choose_impl_candidates(EI, emqx_persistent_message:is_store_enabled())). + +-spec choose_impl_candidates(conninfo()) -> [module()]. +choose_impl_candidates(#{expiry_interval := EI}) -> + choose_impl_candidates(EI, emqx_persistent_message:is_store_enabled()). + +choose_impl_candidates(_, _IsPSStoreEnabled = false) -> + [emqx_session_mem]; +choose_impl_candidates(0, _IsPSStoreEnabled = true) -> + %% NOTE + %% If ExpiryInterval is 0, the natural choice is `emqx_session_mem`. Yet we still + %% need to look the existing session up in the `emqx_persistent_session_ds` store + %% first, because previous connection may have set ExpiryInterval to a non-zero + %% value. + [emqx_session_mem, emqx_persistent_session_ds]; +choose_impl_candidates(EI, _IsPSStoreEnabled = true) when EI > 0 -> + [emqx_persistent_session_ds]. -compile({inline, [run_hook/2]}). run_hook(Name, Args) -> diff --git a/apps/emqx/src/emqx_session_mem.erl b/apps/emqx/src/emqx_session_mem.erl index 578a4fb68..e72feffd5 100644 --- a/apps/emqx/src/emqx_session_mem.erl +++ b/apps/emqx/src/emqx_session_mem.erl @@ -57,7 +57,7 @@ -export([ create/3, - open/3, + open/2, destroy/1 ]). @@ -193,9 +193,9 @@ destroy(_Session) -> %% Open a (possibly existing) Session %%-------------------------------------------------------------------- --spec open(clientinfo(), conninfo(), emqx_session:conf()) -> +-spec open(clientinfo(), conninfo()) -> {_IsPresent :: true, session(), replayctx()} | _IsPresent :: false. -open(ClientInfo = #{clientid := ClientId}, _ConnInfo, _Conf) -> +open(ClientInfo = #{clientid := ClientId}, _ConnInfo) -> case emqx_cm:takeover_session_begin(ClientId) of {ok, SessionRemote, TakeoverState} -> Session = resume(ClientInfo, SessionRemote), diff --git a/apps/emqx/test/emqx_persistent_session_SUITE.erl b/apps/emqx/test/emqx_persistent_session_SUITE.erl index c1ba6a60c..89fba9738 100644 --- a/apps/emqx/test/emqx_persistent_session_SUITE.erl +++ b/apps/emqx/test/emqx_persistent_session_SUITE.erl @@ -50,13 +50,14 @@ all() -> groups() -> TCs = emqx_common_test_helpers:all(?MODULE), + TCsNonGeneric = [t_choose_impl], [ {persistent_store_disabled, [{group, no_kill_connection_process}]}, {persistent_store_ds, [{group, no_kill_connection_process}]}, {no_kill_connection_process, [], [{group, tcp}, {group, quic}, {group, ws}]}, {tcp, [], TCs}, - {quic, [], TCs}, - {ws, [], TCs} + {quic, [], TCs -- TCsNonGeneric}, + {ws, [], TCs -- TCsNonGeneric} ]. init_per_group(persistent_store_disabled, Config) -> @@ -276,6 +277,25 @@ do_publish(Payload, PublishFun, WaitForUnregister) -> %% Test Cases %%-------------------------------------------------------------------- +t_choose_impl(Config) -> + ClientId = ?config(client_id, Config), + ConnFun = ?config(conn_fun, Config), + {ok, Client} = emqtt:start_link([ + {clientid, ClientId}, + {proto_ver, v5}, + {properties, #{'Session-Expiry-Interval' => 30}} + | Config + ]), + {ok, _} = emqtt:ConnFun(Client), + [ChanPid] = emqx_cm:lookup_channels(ClientId), + ?assertEqual( + case ?config(persistent_store, Config) of + false -> emqx_session_mem; + ds -> emqx_persistent_session_ds + end, + emqx_connection:info({channel, {session, impl}}, sys:get_state(ChanPid)) + ). + t_connect_discards_existing_client(Config) -> ClientId = ?config(client_id, Config), ConnFun = ?config(conn_fun, Config), @@ -372,7 +392,6 @@ t_assigned_clientid_persistent_session(Config) -> {ok, Client2} = emqtt:start_link([ {clientid, AssignedClientId}, {proto_ver, v5}, - {properties, #{'Session-Expiry-Interval' => 30}}, {clean_start, false} | Config ]), diff --git a/apps/emqx_durable_storage/src/emqx_ds.erl b/apps/emqx_durable_storage/src/emqx_ds.erl index e7890a3a1..b311d2550 100644 --- a/apps/emqx_durable_storage/src/emqx_ds.erl +++ b/apps/emqx_durable_storage/src/emqx_ds.erl @@ -26,7 +26,8 @@ -export([iterator_update/2, iterator_next/1, iterator_stats/0]). %% Session: -export([ - session_open/2, + session_open/1, + session_ensure_new/2, session_drop/1, session_suspend/1, session_add_iterator/3, @@ -148,28 +149,36 @@ message_stats() -> %%-------------------------------------------------------------------------------- %% @doc Called when a client connects. This function looks up a -%% session or creates a new one if previous one couldn't be found. +%% session or returns `false` if previous one couldn't be found. %% %% This function also spawns replay agents for each iterator. %% %% Note: session API doesn't handle session takeovers, it's the job of %% the broker. --spec session_open(session_id(), _Props :: map()) -> - {_New :: boolean(), session(), iterators()}. -session_open(SessionId, Props) -> +-spec session_open(session_id()) -> + {ok, session(), iterators()} | false. +session_open(SessionId) -> transaction(fun() -> case mnesia:read(?SESSION_TAB, SessionId, write) of [Record = #session{}] -> Session = export_record(Record), IteratorRefs = session_read_iterators(SessionId), Iterators = export_iterators(IteratorRefs), - {false, Session, Iterators}; + {ok, Session, Iterators}; [] -> - Session = export_record(session_create(SessionId, Props)), - {true, Session, #{}} + false end end). +-spec session_ensure_new(session_id(), _Props :: map()) -> + {ok, session(), iterators()}. +session_ensure_new(SessionId, Props) -> + transaction(fun() -> + ok = session_drop_iterators(SessionId), + Session = export_record(session_create(SessionId, Props)), + {ok, Session, #{}} + end). + session_create(SessionId, Props) -> Session = #session{ id = SessionId, @@ -186,11 +195,14 @@ session_create(SessionId, Props) -> session_drop(DSSessionId) -> transaction(fun() -> %% TODO: ensure all iterators from this clientid are closed? - IteratorRefs = session_read_iterators(DSSessionId), - ok = lists:foreach(fun session_del_iterator/1, IteratorRefs), + ok = session_drop_iterators(DSSessionId), ok = mnesia:delete(?SESSION_TAB, DSSessionId, write) end). +session_drop_iterators(DSSessionId) -> + IteratorRefs = session_read_iterators(DSSessionId), + ok = lists:foreach(fun session_del_iterator/1, IteratorRefs). + %% @doc Called when a client disconnects. This function terminates all %% active processes related to the session. -spec session_suspend(session_id()) -> ok | {error, session_not_found}.