From 008eae5a8ee2ee6f6433320ef54c25b7f77bff90 Mon Sep 17 00:00:00 2001 From: Tobias Lindahl Date: Thu, 1 Jul 2021 17:04:15 +0200 Subject: [PATCH] feat: More reliable persistent sessions Add sessions with expiry interval set to > 0 to a mnesia table to avoid losing the session even if the connection process dies or the node goes down. Messages sent after the process dies are still lost. --- apps/emqx/src/emqx_channel.erl | 51 ++++++++++++--------- apps/emqx/src/emqx_cm.erl | 24 +++++++++- apps/emqx/src/emqx_session.erl | 65 +++++++++++++++++++++++++++ apps/emqx/test/emqx_channel_SUITE.erl | 1 + 4 files changed, 117 insertions(+), 24 deletions(-) diff --git a/apps/emqx/src/emqx_channel.erl b/apps/emqx/src/emqx_channel.erl index 5e4d11953..f5207b4ef 100644 --- a/apps/emqx/src/emqx_channel.erl +++ b/apps/emqx/src/emqx_channel.erl @@ -183,7 +183,11 @@ set_conn_state(ConnState, Channel) -> get_session(#channel{session = Session}) -> Session. -set_session(Session, Channel) -> +set_session(Session, Channel = #channel{clientinfo = ClientInfo, conninfo = ConnInfo}) -> + %% Assume that this is also an updated session. Allow side effect. + ClientID = maps:get(clientid, ClientInfo, undefined), + ExpiryInterval = maps:get(expiry_interval, ConnInfo, 0), + emqx_session:db_put(ClientID, ExpiryInterval, Session), Channel#channel{session = Session}. %% TODO: Add more stats. @@ -367,10 +371,10 @@ handle_in(?PUBACK_PACKET(PacketId, _ReasonCode, Properties), Channel case emqx_session:puback(PacketId, Session) of {ok, Msg, NSession} -> ok = after_message_acked(ClientInfo, Msg, Properties), - {ok, Channel#channel{session = NSession}}; + {ok, set_session(NSession, Channel)}; {ok, Msg, Publishes, NSession} -> ok = after_message_acked(ClientInfo, Msg, Properties), - handle_out(publish, Publishes, Channel#channel{session = NSession}); + handle_out(publish, Publishes, set_session(NSession, Channel)); {error, ?RC_PACKET_IDENTIFIER_IN_USE} -> ?LOG(warning, "The PUBACK PacketId ~w is inuse.", [PacketId]), ok = emqx_metrics:inc('packets.puback.inuse'), @@ -386,7 +390,7 @@ handle_in(?PUBREC_PACKET(PacketId, _ReasonCode, Properties), Channel case emqx_session:pubrec(PacketId, Session) of {ok, Msg, NSession} -> ok = after_message_acked(ClientInfo, Msg, Properties), - NChannel = Channel#channel{session = NSession}, + NChannel = set_session(NSession, Channel), handle_out(pubrel, {PacketId, ?RC_SUCCESS}, NChannel); {error, RC = ?RC_PACKET_IDENTIFIER_IN_USE} -> ?LOG(warning, "The PUBREC PacketId ~w is inuse.", [PacketId]), @@ -401,7 +405,7 @@ handle_in(?PUBREC_PACKET(PacketId, _ReasonCode, Properties), Channel handle_in(?PUBREL_PACKET(PacketId, _ReasonCode), Channel = #channel{session = Session}) -> case emqx_session:pubrel(PacketId, Session) of {ok, NSession} -> - NChannel = Channel#channel{session = NSession}, + NChannel = set_session(NSession, Channel), handle_out(pubcomp, {PacketId, ?RC_SUCCESS}, NChannel); {error, RC = ?RC_PACKET_IDENTIFIER_NOT_FOUND} -> ?LOG(warning, "The PUBREL PacketId ~w is not found.", [PacketId]), @@ -412,9 +416,9 @@ handle_in(?PUBREL_PACKET(PacketId, _ReasonCode), Channel = #channel{session = Se handle_in(?PUBCOMP_PACKET(PacketId, _ReasonCode), Channel = #channel{session = Session}) -> case emqx_session:pubcomp(PacketId, Session) of {ok, NSession} -> - {ok, Channel#channel{session = NSession}}; + {ok, set_session(NSession, Channel)}; {ok, Publishes, NSession} -> - handle_out(publish, Publishes, Channel#channel{session = NSession}); + handle_out(publish, Publishes, set_session(NSession, Channel)); {error, ?RC_PACKET_IDENTIFIER_IN_USE} -> ok = emqx_metrics:inc('packets.pubcomp.inuse'), {ok, Channel}; @@ -614,7 +618,8 @@ do_publish(PacketId, Msg = #message{qos = ?QOS_2}, case emqx_session:publish(PacketId, Msg, Session) of {ok, PubRes, NSession} -> RC = puback_reason_code(PubRes), - NChannel1 = ensure_timer(await_timer, Channel#channel{session = NSession}), + NChannel0 = set_session(NSession, Channel), + NChannel1 = ensure_timer(await_timer, NChannel0), NChannel2 = ensure_quota(PubRes, NChannel1), handle_out(pubrec, {PacketId, RC}, NChannel2); {error, RC = ?RC_PACKET_IDENTIFIER_IN_USE} -> @@ -683,7 +688,7 @@ do_subscribe(TopicFilter, SubOpts = #{qos := QoS}, Channel = NSubOpts = enrich_subopts(maps:merge(?DEFAULT_SUBOPTS, SubOpts), Channel), case emqx_session:subscribe(ClientInfo, NTopicFilter, NSubOpts, Session) of {ok, NSession} -> - {QoS, Channel#channel{session = NSession}}; + {QoS, set_session(NSession, Channel)}; {error, RC} -> ?LOG(warning, "Cannot subscribe ~s due to ~s.", [TopicFilter, emqx_reason_codes:text(RC)]), @@ -711,7 +716,7 @@ do_unsubscribe(TopicFilter, SubOpts, Channel = TopicFilter1 = emqx_mountpoint:mount(MountPoint, TopicFilter), case emqx_session:unsubscribe(ClientInfo, TopicFilter1, SubOpts, Session) of {ok, NSession} -> - {?RC_SUCCESS, Channel#channel{session = NSession}}; + {?RC_SUCCESS, set_session(NSession, Channel)}; {error, RC} -> {RC, Channel} end. %%-------------------------------------------------------------------- @@ -736,7 +741,9 @@ process_disconnect(ReasonCode, Properties, Channel) -> maybe_update_expiry_interval(#{'Session-Expiry-Interval' := Interval}, Channel = #channel{conninfo = ConnInfo}) -> - Channel#channel{conninfo = ConnInfo#{expiry_interval => timer:seconds(Interval)}}; + NChannel = Channel#channel{conninfo = ConnInfo#{expiry_interval => timer:seconds(Interval)}}, + %% We need to update the expiry interval on the session as well + set_session(NChannel#channel.session, NChannel); maybe_update_expiry_interval(_Properties, Channel) -> Channel. %%-------------------------------------------------------------------- @@ -749,7 +756,7 @@ handle_deliver(Delivers, Channel = #channel{conn_state = disconnected, session = Session, clientinfo = #{clientid := ClientId}}) -> NSession = emqx_session:enqueue(ignore_local(maybe_nack(Delivers), ClientId, Session), Session), - {ok, Channel#channel{session = NSession}}; + {ok, set_session(NSession, Channel)}; handle_deliver(Delivers, Channel = #channel{takeover = true, pendings = Pendings, @@ -762,10 +769,10 @@ handle_deliver(Delivers, Channel = #channel{session = Session, clientinfo = #{clientid := ClientId}}) -> case emqx_session:deliver(ignore_local(Delivers, ClientId, Session), Session) of {ok, Publishes, NSession} -> - NChannel = Channel#channel{session = NSession}, + NChannel = set_session(NSession, Channel), handle_out(publish, Publishes, ensure_timer(retry_timer, NChannel)); {ok, NSession} -> - {ok, Channel#channel{session = NSession}} + {ok, set_session(NSession, Channel)} end. ignore_local(Delivers, Subscriber, Session) -> @@ -881,13 +888,13 @@ return_connack(AckPacket, Channel) -> case maybe_resume_session(Channel) of ignore -> {ok, Replies, Channel}; {ok, Publishes, NSession} -> - NChannel = Channel#channel{session = NSession, - resuming = false, + NChannel0 = Channel#channel{resuming = false, pendings = [] }, - {Packets, NChannel1} = do_deliver(Publishes, NChannel), + NChannel1 = set_session(NSession, NChannel0), + {Packets, NChannel2} = do_deliver(Publishes, NChannel1), Outgoing = [{outgoing, Packets} || length(Packets) > 0], - {ok, Replies ++ Outgoing, NChannel1} + {ok, Replies ++ Outgoing, NChannel2} end. %%-------------------------------------------------------------------- @@ -1047,9 +1054,9 @@ handle_timeout(_TRef, retry_delivery, Channel = #channel{session = Session}) -> case emqx_session:retry(Session) of {ok, NSession} -> - {ok, clean_timer(retry_timer, Channel#channel{session = NSession})}; + {ok, clean_timer(retry_timer, set_session(NSession, Channel))}; {ok, Publishes, Timeout, NSession} -> - NChannel = Channel#channel{session = NSession}, + NChannel = set_session(NSession, Channel), handle_out(publish, Publishes, reset_timer(retry_timer, Timeout, NChannel)) end; @@ -1060,9 +1067,9 @@ handle_timeout(_TRef, expire_awaiting_rel, Channel = #channel{session = Session}) -> case emqx_session:expire(awaiting_rel, Session) of {ok, NSession} -> - {ok, clean_timer(await_timer, Channel#channel{session = NSession})}; + {ok, clean_timer(await_timer, set_session(NSession, Channel))}; {ok, Timeout, NSession} -> - {ok, reset_timer(await_timer, Timeout, Channel#channel{session = NSession})} + {ok, reset_timer(await_timer, Timeout, set_session(NSession, Channel))} end; handle_timeout(_TRef, expire_session, Channel) -> diff --git a/apps/emqx/src/emqx_cm.erl b/apps/emqx/src/emqx_cm.erl index f4f5f3981..8676c9a8a 100644 --- a/apps/emqx/src/emqx_cm.erl +++ b/apps/emqx/src/emqx_cm.erl @@ -211,21 +211,33 @@ set_chan_stats(ClientId, ChanPid, Stats) -> pendings => list()}} | {error, Reason :: term()}). open_session(true, ClientInfo = #{clientid := ClientId}, ConnInfo) -> + EI = maps:get(expiry_interval, ConnInfo, 0), Self = self(), CleanStart = fun(_) -> ok = discard_session(ClientId), Session = create_session(ClientInfo, ConnInfo), + emqx_session:db_put(ClientId, EI, Session), register_channel(ClientId, Self, ConnInfo), {ok, #{session => Session, present => false}} end, emqx_cm_locker:trans(ClientId, CleanStart); open_session(false, ClientInfo = #{clientid := ClientId}, ConnInfo) -> + EI = maps:get(expiry_interval, ConnInfo, 0), Self = self(), ResumeStart = fun(_) -> case takeover_session(ClientId) of + {ok, Session} -> + %% TODO: Any messages in the mean time was lost. + ok = emqx_session:resume(ClientInfo, Session), + emqx_session:db_put(ClientId, EI, Session), + register_channel(ClientId, Self, ConnInfo), + {ok, #{session => Session, + present => true, + pendings => []}}; {ok, ConnMod, ChanPid, Session} -> ok = emqx_session:resume(ClientInfo, Session), + emqx_session:db_put(ClientId, EI, Session), Pendings = ConnMod:call(ChanPid, {takeover, 'end'}, ?T_TAKEOVER), register_channel(ClientId, Self, ConnInfo), {ok, #{session => Session, @@ -233,6 +245,7 @@ open_session(false, ClientInfo = #{clientid := ClientId}, ConnInfo) -> pendings => Pendings}}; {error, not_found} -> Session = create_session(ClientInfo, ConnInfo), + emqx_session:db_put(ClientId, EI, Session), register_channel(ClientId, Self, ConnInfo), {ok, #{session => Session, present => false}} end @@ -271,7 +284,11 @@ get_mqtt_conf(Zone, Key) -> | {ok, atom(), pid(), emqx_session:session()}). takeover_session(ClientId) -> case lookup_channels(ClientId) of - [] -> {error, not_found}; + [] -> + case emqx_session:db_get(ClientId) of + [] -> {error, not_found}; + [Session] -> {ok, Session} + end; [ChanPid] -> takeover_session(ClientId, ChanPid); ChanPids -> @@ -286,7 +303,10 @@ takeover_session(ClientId) -> takeover_session(ClientId, ChanPid) when node(ChanPid) == node() -> case get_chann_conn_mod(ClientId, ChanPid) of undefined -> - {error, not_found}; + case emqx_session:db_get(ClientId) of + [] -> {error, not_found}; + [Session] -> {ok, Session} + end; ConnMod when is_atom(ConnMod) -> Session = ConnMod:call(ChanPid, {takeover, 'begin'}, ?T_TAKEOVER), {ok, ConnMod, ChanPid, Session} diff --git a/apps/emqx/src/emqx_session.erl b/apps/emqx/src/emqx_session.erl index f915155cb..d4cbdae73 100644 --- a/apps/emqx/src/emqx_session.erl +++ b/apps/emqx/src/emqx_session.erl @@ -54,6 +54,15 @@ -compile(nowarn_export_all). -endif. +%% DB API +-export([ mnesia/1 + , db_get/1 + , db_put/3 + ]). + +-boot_mnesia({mnesia, [boot]}). +-copy_mnesia({mnesia, [copy]}). + -export([init/1]). -export([ info/1 @@ -159,6 +168,27 @@ , mqueue => emqx_mqueue:options() }. +%%-------------------------------------------------------------------- +%% Mnesia bootstrap +%%-------------------------------------------------------------------- + +-define(SESSION_STORE, emqx_session_store). +-record(session_store, { id :: binary() + , expiry_interval :: non_neg_integer() + , ts :: non_neg_integer() + , session :: #session{}}). + +mnesia(boot) -> + ok = ekka_mnesia:create_table(?SESSION_STORE, [ + {type, set}, + {ram_copies, [node()]}, + {record_name, session_store}, + {attributes, record_info(fields, session_store)}, + {storage_properties, [{ets, [{read_concurrency, true}]}]}]); + +mnesia(copy) -> + ok = ekka_mnesia:copy_table(?SESSION_STORE, ram_copies). + %%-------------------------------------------------------------------- %% Init a Session %%-------------------------------------------------------------------- @@ -184,6 +214,41 @@ init(Opts) -> created_at = erlang:system_time(millisecond) }. +%%-------------------------------------------------------------------- +%% DB API +%%-------------------------------------------------------------------- + +db_put(undefined,_ExpiryInterval, #session{}) -> + ok; +db_put(SessionID, ExpiryInterval, #session{} = Session) when is_binary(SessionID), + is_integer(ExpiryInterval) -> + SS = #session_store{ id = SessionID + , expiry_interval = ExpiryInterval + , ts = erlang:system_time(millisecond) + , session = Session}, + case use_db_session(SS) of + false -> ekka_mnesia:dirty_delete(?SESSION_STORE, SessionID); + true -> ekka_mnesia:dirty_write(?SESSION_STORE, SS) + end. + +db_get(SessionID) when is_binary(SessionID) -> + case mnesia:dirty_read(?SESSION_STORE, SessionID) of + [] -> []; + [#session_store{session = S} = SS] -> + case use_db_session(SS) of + true -> [S]; + false -> [] + end + end. + +%% @private [MQTT-3.1.2-23] +use_db_session(#session_store{expiry_interval = 0}) -> + false; +use_db_session(#session_store{expiry_interval = 16#FFFFFFFF}) -> + true; +use_db_session(#session_store{expiry_interval = E, ts = TS}) -> + E*1000 + TS > erlang:system_time(millisecond). + %%-------------------------------------------------------------------- %% Info, Stats %%-------------------------------------------------------------------- diff --git a/apps/emqx/test/emqx_channel_SUITE.erl b/apps/emqx/test/emqx_channel_SUITE.erl index be7c94ede..48380190f 100644 --- a/apps/emqx/test/emqx_channel_SUITE.erl +++ b/apps/emqx/test/emqx_channel_SUITE.erl @@ -191,6 +191,7 @@ init_per_suite(Config) -> ok = meck:expect(emqx_hooks, run_fold, fun(_Hook, _Args, Acc) -> Acc end), %% Session Meck ok = meck:new(emqx_session, [passthrough, no_history, no_link]), + meck:expect(emqx_session, db_put, fun(_, _, _) -> ok end), %% Metrics ok = meck:new(emqx_metrics, [passthrough, no_history, no_link]), ok = meck:expect(emqx_metrics, inc, fun(_) -> ok end),