diff --git a/apps/emqx/src/emqx_persistent_message_ds_replayer.erl b/apps/emqx/src/emqx_persistent_message_ds_replayer.erl index 32934492a..723f02a01 100644 --- a/apps/emqx/src/emqx_persistent_message_ds_replayer.erl +++ b/apps/emqx/src/emqx_persistent_message_ds_replayer.erl @@ -33,6 +33,8 @@ -export_type([inflight/0, seqno/0]). -include_lib("emqx/include/logger.hrl"). +-include_lib("emqx/include/emqx_mqtt.hrl"). +-include_lib("emqx_utils/include/emqx_message.hrl"). -include("emqx_persistent_session_ds.hrl"). -ifdef(TEST). @@ -46,6 +48,8 @@ -define(COMP, 1). -define(TRACK_FLAG(WHICH), (1 bsl WHICH)). +-define(TRACK_FLAGS_ALL, ?TRACK_FLAG(?ACK) bor ?TRACK_FLAG(?COMP)). +-define(TRACK_FLAGS_NONE, 0). %%================================================================================ %% Type declarations @@ -66,9 +70,10 @@ -opaque inflight() :: #inflight{}. --type replies() :: reply() | [replies()]. --type reply() :: emqx_session:reply() | fun((emqx_types:packet_id()) -> emqx_session:replies()). --type reply_fun() :: fun((seqno(), emqx_types:message()) -> replies()). +-type message() :: emqx_types:message(). +-type replies() :: [emqx_session:reply()]. + +-type preproc_fun() :: fun((message()) -> message() | [message()]). %%================================================================================ %% API funcions @@ -115,11 +120,11 @@ n_inflight(#inflight{offset_ranges = Ranges}) -> Ranges ). --spec replay(reply_fun(), inflight()) -> {emqx_session:replies(), inflight()}. -replay(ReplyFun, Inflight0 = #inflight{offset_ranges = Ranges0}) -> +-spec replay(preproc_fun(), inflight()) -> {emqx_session:replies(), inflight()}. +replay(PreprocFunFun, Inflight0 = #inflight{offset_ranges = Ranges0, commits = Commits}) -> {Ranges, Replies} = lists:mapfoldr( fun(Range, Acc) -> - replay_range(ReplyFun, Range, Acc) + replay_range(PreprocFunFun, Commits, Range, Acc) end, [], Ranges0 @@ -165,9 +170,9 @@ commit_offset( {false, Inflight0} end. --spec poll(reply_fun(), emqx_persistent_session_ds:id(), inflight(), pos_integer()) -> +-spec poll(preproc_fun(), emqx_persistent_session_ds:id(), inflight(), pos_integer()) -> {emqx_session:replies(), inflight()}. -poll(ReplyFun, SessionId, Inflight0, WindowSize) when WindowSize > 0, WindowSize < ?EPOCH_SIZE -> +poll(PreprocFun, SessionId, Inflight0, WindowSize) when WindowSize > 0, WindowSize < ?EPOCH_SIZE -> MinBatchSize = emqx_config:get([session_persistence, min_batch_size]), FetchThreshold = min(MinBatchSize, ceil(WindowSize / 2)), FreeSpace = WindowSize - n_inflight(Inflight0), @@ -181,7 +186,7 @@ poll(ReplyFun, SessionId, Inflight0, WindowSize) when WindowSize > 0, WindowSize true -> %% TODO: Wrap this in `mria:async_dirty/2`? Streams = shuffle(get_streams(SessionId)), - fetch(ReplyFun, SessionId, Inflight0, Streams, FreeSpace, []) + fetch(PreprocFun, SessionId, Inflight0, Streams, FreeSpace, []) end. %% Which seqno this track is committed until. @@ -248,22 +253,22 @@ get_ranges(SessionId) -> ), mnesia:match_object(?SESSION_PUBRANGE_TAB, Pat, read). -fetch(ReplyFun, SessionId, Inflight0, [DSStream | Streams], N, Acc) when N > 0 -> +fetch(PreprocFun, SessionId, Inflight0, [DSStream | Streams], N, Acc) when N > 0 -> #inflight{next_seqno = FirstSeqno, offset_ranges = Ranges} = Inflight0, ItBegin = get_last_iterator(DSStream, Ranges), {ok, ItEnd, Messages} = emqx_ds:next(?PERSISTENT_MESSAGE_DB, ItBegin, N), case Messages of [] -> - fetch(ReplyFun, SessionId, Inflight0, Streams, N, Acc); + fetch(PreprocFun, SessionId, Inflight0, Streams, N, Acc); _ -> %% We need to preserve the iterator pointing to the beginning of the %% range, so that we can replay it if needed. - {Publishes, {UntilSeqno, Tracks}} = publish(ReplyFun, FirstSeqno, Messages), + {Publishes, UntilSeqno} = publish_fetch(PreprocFun, FirstSeqno, Messages), Size = range_size(FirstSeqno, UntilSeqno), Range0 = #ds_pubrange{ id = {SessionId, FirstSeqno}, type = ?T_INFLIGHT, - tracks = Tracks, + tracks = compute_pub_tracks(Publishes), until = UntilSeqno, stream = DSStream#ds_stream.ref, iterator = ItBegin @@ -277,7 +282,7 @@ fetch(ReplyFun, SessionId, Inflight0, [DSStream | Streams], N, Acc) when N > 0 - next_seqno = UntilSeqno, offset_ranges = Ranges ++ [Range] }, - fetch(ReplyFun, SessionId, Inflight, Streams, N - Size, [Publishes | Acc]) + fetch(PreprocFun, SessionId, Inflight, Streams, N - Size, [Publishes | Acc]) end; fetch(_ReplyFun, _SessionId, Inflight, _Streams, _N, Acc) -> Publishes = lists:append(lists:reverse(Acc)), @@ -374,19 +379,20 @@ discard_tracks(#{ack := AckedUntil, comp := CompUntil}, Until, Tracks) -> TAck bor TComp. replay_range( - ReplyFun, + PreprocFun, + Commits, Range0 = #ds_pubrange{type = ?T_INFLIGHT, id = {_, First}, until = Until, iterator = It}, Acc ) -> Size = range_size(First, Until), {ok, ItNext, MessagesUnacked} = emqx_ds:next(?PERSISTENT_MESSAGE_DB, It, Size), %% Asserting that range is consistent with the message storage state. - {Replies, {Until, _TracksInitial}} = publish(ReplyFun, First, MessagesUnacked), + {Replies, Until} = publish_replay(PreprocFun, Commits, First, MessagesUnacked), %% Again, we need to keep the iterator pointing past the end of the %% range, so that we can pick up where we left off. Range = keep_next_iterator(ItNext, Range0), {Range, Replies ++ Acc}; -replay_range(_ReplyFun, Range0 = #ds_pubrange{type = ?T_CHECKPOINT}, Acc) -> +replay_range(_PreprocFun, _Commits, Range0 = #ds_pubrange{type = ?T_CHECKPOINT}, Acc) -> {Range0, Acc}. validate_commit( @@ -419,33 +425,88 @@ get_commit_next(rec, #inflight{next_seqno = NextSeqno}) -> get_commit_next(comp, #inflight{commits = Commits}) -> maps:get(rec, Commits). -publish(ReplyFun, FirstSeqno, Messages) -> - lists:mapfoldl( - fun(Message, Acc = {Seqno, _Tracks}) -> - Reply = ReplyFun(Seqno, Message), - publish_reply(Reply, Acc) +publish_fetch(PreprocFun, FirstSeqno, Messages) -> + flatmapfoldl( + fun(MessageIn, Acc) -> + Message = PreprocFun(MessageIn), + publish_fetch(Message, Acc) end, - {FirstSeqno, 0}, + FirstSeqno, Messages ). -publish_reply(Replies = [_ | _], Acc) -> - lists:mapfoldl(fun publish_reply/2, Acc, Replies); -publish_reply(Reply, {Seqno, Tracks}) when is_function(Reply) -> - Pub = Reply(seqno_to_packet_id(Seqno)), - NextSeqno = next_seqno(Seqno), - NextTracks = add_pub_track(Pub, Tracks), - {Pub, {NextSeqno, NextTracks}}; -publish_reply(Reply, Acc) -> - {Reply, Acc}. +publish_fetch(#message{qos = ?QOS_0} = Message, Seqno) -> + {{undefined, Message}, Seqno}; +publish_fetch(#message{} = Message, Seqno) -> + PacketId = seqno_to_packet_id(Seqno), + {{PacketId, Message}, next_seqno(Seqno)}; +publish_fetch(Messages, Seqno) -> + flatmapfoldl(fun publish_fetch/2, Seqno, Messages). -add_pub_track({PacketId, Message}, Tracks) when is_integer(PacketId) -> - case emqx_message:qos(Message) of - 1 -> ?TRACK_FLAG(?ACK) bor Tracks; - 2 -> ?TRACK_FLAG(?COMP) bor Tracks; - _ -> Tracks +publish_replay(PreprocFun, Commits, FirstSeqno, Messages) -> + #{ack := AckedUntil, comp := CompUntil, rec := RecUntil} = Commits, + flatmapfoldl( + fun(MessageIn, Acc) -> + Message = PreprocFun(MessageIn), + publish_replay(Message, AckedUntil, CompUntil, RecUntil, Acc) + end, + FirstSeqno, + Messages + ). + +publish_replay(#message{qos = ?QOS_0}, _, _, _, Seqno) -> + %% QoS 0 (at most once) messages should not be replayed. + {[], Seqno}; +publish_replay(#message{qos = Qos} = Message, AckedUntil, CompUntil, RecUntil, Seqno) -> + case Qos of + ?QOS_1 when Seqno < AckedUntil -> + %% This message has already been acked, so we can skip it. + %% We still need to advance seqno, because previously we assigned this message + %% a unique Packet Id. + {[], next_seqno(Seqno)}; + ?QOS_2 when Seqno < CompUntil -> + %% This message's flow has already been fully completed, so we can skip it. + %% We still need to advance seqno, because previously we assigned this message + %% a unique Packet Id. + {[], next_seqno(Seqno)}; + ?QOS_2 when Seqno < RecUntil -> + %% This message's flow has been partially completed, we need to resend a PUBREL. + PacketId = seqno_to_packet_id(Seqno), + Pub = {pubrel, PacketId}, + {Pub, next_seqno(Seqno)}; + _ -> + %% This message flow hasn't been acked and/or received, we need to resend it. + PacketId = seqno_to_packet_id(Seqno), + Pub = {PacketId, emqx_message:set_flag(dup, true, Message)}, + {Pub, next_seqno(Seqno)} end; -add_pub_track(_Pub, Tracks) -> +publish_replay([], _, _, _, Seqno) -> + {[], Seqno}; +publish_replay(Messages, AckedUntil, CompUntil, RecUntil, Seqno) -> + flatmapfoldl( + fun(Message, Acc) -> + publish_replay(Message, AckedUntil, CompUntil, RecUntil, Acc) + end, + Seqno, + Messages + ). + +-spec compute_pub_tracks(replies()) -> non_neg_integer(). +compute_pub_tracks(Pubs) -> + compute_pub_tracks(Pubs, ?TRACK_FLAGS_NONE). + +compute_pub_tracks(_Pubs, Tracks = ?TRACK_FLAGS_ALL) -> + Tracks; +compute_pub_tracks([Pub | Rest], Tracks) -> + Track = + case Pub of + {_PacketId, #message{qos = ?QOS_1}} -> ?TRACK_FLAG(?ACK); + {_PacketId, #message{qos = ?QOS_2}} -> ?TRACK_FLAG(?COMP); + {pubrel, _PacketId} -> ?TRACK_FLAG(?COMP); + _ -> ?TRACK_FLAGS_NONE + end, + compute_pub_tracks(Rest, Track bor Tracks); +compute_pub_tracks([], Tracks) -> Tracks. keep_next_iterator(ItNext, Range = #ds_pubrange{iterator = ItFirst, misc = Misc}) -> @@ -550,6 +611,19 @@ shuffle(L0) -> {_, L} = lists:unzip(L2), L. +-spec flatmapfoldl(fun((X, Acc) -> {Y | [Y], Acc}), Acc, [X]) -> {[Y], Acc}. +flatmapfoldl(_Fun, Acc, []) -> + {[], Acc}; +flatmapfoldl(Fun, Acc, [X | Xs]) -> + {Ys, NAcc} = Fun(X, Acc), + {Zs, FAcc} = flatmapfoldl(Fun, NAcc, Xs), + case is_list(Ys) of + true -> + {Ys ++ Zs, FAcc}; + _ -> + {[Ys | Zs], FAcc} + end. + ro_transaction(Fun) -> {atomic, Res} = mria:ro_transaction(?DS_MRIA_SHARD, Fun), Res. diff --git a/apps/emqx/src/emqx_persistent_session_ds.erl b/apps/emqx/src/emqx_persistent_session_ds.erl index a25e4da3b..bedd72a16 100644 --- a/apps/emqx/src/emqx_persistent_session_ds.erl +++ b/apps/emqx/src/emqx_persistent_session_ds.erl @@ -152,9 +152,8 @@ -spec create(clientinfo(), conninfo(), emqx_session:conf()) -> session(). create(#{clientid := ClientID}, ConnInfo, Conf) -> - % TODO: expiration - Session = ensure_timers(session_ensure_new(ClientID, ConnInfo)), - preserve_conf(ConnInfo, Conf, Session). + Session = session_ensure_new(ClientID, ConnInfo), + apply_conf(ConnInfo, Conf, ensure_timers(Session)). -spec open(clientinfo(), conninfo(), emqx_session:conf()) -> {_IsPresent :: true, session(), []} | false. @@ -168,13 +167,13 @@ open(#{clientid := ClientID} = _ClientInfo, ConnInfo, Conf) -> ok = emqx_cm:discard_session(ClientID), case session_open(ClientID, ConnInfo) of Session0 = #{} -> - Session = preserve_conf(ConnInfo, Conf, Session0), + Session = apply_conf(ConnInfo, Conf, Session0), {true, ensure_timers(Session), []}; false -> false end. -preserve_conf(ConnInfo, Conf, Session) -> +apply_conf(ConnInfo, Conf, Session) -> Session#{ receive_maximum => receive_maximum(ConnInfo), props => Conf @@ -399,7 +398,7 @@ deliver(_ClientInfo, _Delivers, Session) -> {ok, replies(), session()} | {ok, replies(), timeout(), session()}. handle_timeout( ClientInfo, - pull, + ?TIMER_PULL, Session0 = #{ id := Id, inflight := Inflight0, @@ -411,14 +410,9 @@ handle_timeout( MaxBatchSize = emqx_config:get([session_persistence, max_batch_size]), BatchSize = min(ReceiveMaximum, MaxBatchSize), UpgradeQoS = maps:get(upgrade_qos, Conf), - ReplyFun = make_reply_fun(ClientInfo, Subs, UpgradeQoS, fun - (_Seqno, Message = #message{qos = ?QOS_0}) -> - {undefined, Message}; - (_Seqno, Message) -> - fun(PacketId) -> {PacketId, Message} end - end), + PreprocFun = make_preproc_fun(ClientInfo, Subs, UpgradeQoS), {Publishes, Inflight} = emqx_persistent_message_ds_replayer:poll( - ReplyFun, + PreprocFun, Id, Inflight0, BatchSize @@ -455,22 +449,8 @@ replay( Session = #{inflight := Inflight0, subscriptions := Subs, props := Conf} ) -> UpgradeQoS = maps:get(upgrade_qos, Conf), - AckedUntil = emqx_persistent_message_ds_replayer:committed_until(ack, Inflight0), - RecUntil = emqx_persistent_message_ds_replayer:committed_until(rec, Inflight0), - CompUntil = emqx_persistent_message_ds_replayer:committed_until(comp, Inflight0), - ReplyFun = make_reply_fun(ClientInfo, Subs, UpgradeQoS, fun - (_Seqno, #message{qos = ?QOS_0}) -> - []; - (Seqno, #message{qos = ?QOS_1}) when Seqno < AckedUntil -> - fun(_) -> [] end; - (Seqno, #message{qos = ?QOS_2}) when Seqno < CompUntil -> - fun(_) -> [] end; - (Seqno, #message{qos = ?QOS_2}) when Seqno < RecUntil -> - fun(PacketId) -> {pubrel, PacketId} end; - (_Seqno, Message) -> - fun(PacketId) -> {PacketId, emqx_message:set_flag(dup, true, Message)} end - end), - {Replies, Inflight} = emqx_persistent_message_ds_replayer:replay(ReplyFun, Inflight0), + PreprocFun = make_preproc_fun(ClientInfo, Subs, UpgradeQoS), + {Replies, Inflight} = emqx_persistent_message_ds_replayer:replay(PreprocFun, Inflight0), {ok, Replies, Session#{inflight := Inflight}}. %%-------------------------------------------------------------------- @@ -486,23 +466,17 @@ terminate(_Reason, _Session = #{}) -> %%-------------------------------------------------------------------- -make_reply_fun(ClientInfo, Subs, UpgradeQoS, InnerFun) -> - fun(Seqno, Message0 = #message{topic = Topic}) -> +make_preproc_fun(ClientInfo, Subs, UpgradeQoS) -> + fun(Message = #message{topic = Topic}) -> emqx_utils:flattermap( fun(Match) -> - emqx_utils:flattermap( - fun(Message) -> InnerFun(Seqno, Message) end, - enrich_message(ClientInfo, Message0, Match, Subs, UpgradeQoS) - ) + #{props := SubOpts} = subs_get_match(Match, Subs), + emqx_session:enrich_message(ClientInfo, Message, SubOpts, UpgradeQoS) end, subs_matches(Topic, Subs) ) end. -enrich_message(ClientInfo, Message, SubMatch, Subs, UpgradeQoS) -> - #{props := SubOpts} = subs_get_match(SubMatch, Subs), - emqx_session:enrich_message(ClientInfo, Message, SubOpts, UpgradeQoS). - %%-------------------------------------------------------------------- -spec add_subscription(topic_filter(), emqx_types:subopts(), id()) ->