refactor(sessds): move parts of message processing to replayer

To simplify the processing flow, reducing the number of back-and-forth
between the session and the replayer.
This commit is contained in:
Andrew Mayorov 2023-12-01 15:23:58 +03:00
parent 29ec73847a
commit fd26e690b8
No known key found for this signature in database
GPG Key ID: 2837C62ACFBFED5D
2 changed files with 125 additions and 77 deletions

View File

@ -33,6 +33,8 @@
-export_type([inflight/0, seqno/0]).
-include_lib("emqx/include/logger.hrl").
-include_lib("emqx/include/emqx_mqtt.hrl").
-include_lib("emqx_utils/include/emqx_message.hrl").
-include("emqx_persistent_session_ds.hrl").
-ifdef(TEST).
@ -46,6 +48,8 @@
-define(COMP, 1).
-define(TRACK_FLAG(WHICH), (1 bsl WHICH)).
-define(TRACK_FLAGS_ALL, ?TRACK_FLAG(?ACK) bor ?TRACK_FLAG(?COMP)).
-define(TRACK_FLAGS_NONE, 0).
%%================================================================================
%% Type declarations
@ -66,9 +70,10 @@
-opaque inflight() :: #inflight{}.
-type replies() :: reply() | [replies()].
-type reply() :: emqx_session:reply() | fun((emqx_types:packet_id()) -> emqx_session:replies()).
-type reply_fun() :: fun((seqno(), emqx_types:message()) -> replies()).
-type message() :: emqx_types:message().
-type replies() :: [emqx_session:reply()].
-type preproc_fun() :: fun((message()) -> message() | [message()]).
%%================================================================================
%% API funcions
@ -115,11 +120,11 @@ n_inflight(#inflight{offset_ranges = Ranges}) ->
Ranges
).
-spec replay(reply_fun(), inflight()) -> {emqx_session:replies(), inflight()}.
replay(ReplyFun, Inflight0 = #inflight{offset_ranges = Ranges0}) ->
-spec replay(preproc_fun(), inflight()) -> {emqx_session:replies(), inflight()}.
replay(PreprocFunFun, Inflight0 = #inflight{offset_ranges = Ranges0, commits = Commits}) ->
{Ranges, Replies} = lists:mapfoldr(
fun(Range, Acc) ->
replay_range(ReplyFun, Range, Acc)
replay_range(PreprocFunFun, Commits, Range, Acc)
end,
[],
Ranges0
@ -165,9 +170,9 @@ commit_offset(
{false, Inflight0}
end.
-spec poll(reply_fun(), emqx_persistent_session_ds:id(), inflight(), pos_integer()) ->
-spec poll(preproc_fun(), emqx_persistent_session_ds:id(), inflight(), pos_integer()) ->
{emqx_session:replies(), inflight()}.
poll(ReplyFun, SessionId, Inflight0, WindowSize) when WindowSize > 0, WindowSize < ?EPOCH_SIZE ->
poll(PreprocFun, SessionId, Inflight0, WindowSize) when WindowSize > 0, WindowSize < ?EPOCH_SIZE ->
MinBatchSize = emqx_config:get([session_persistence, min_batch_size]),
FetchThreshold = min(MinBatchSize, ceil(WindowSize / 2)),
FreeSpace = WindowSize - n_inflight(Inflight0),
@ -181,7 +186,7 @@ poll(ReplyFun, SessionId, Inflight0, WindowSize) when WindowSize > 0, WindowSize
true ->
%% TODO: Wrap this in `mria:async_dirty/2`?
Streams = shuffle(get_streams(SessionId)),
fetch(ReplyFun, SessionId, Inflight0, Streams, FreeSpace, [])
fetch(PreprocFun, SessionId, Inflight0, Streams, FreeSpace, [])
end.
%% Which seqno this track is committed until.
@ -248,22 +253,22 @@ get_ranges(SessionId) ->
),
mnesia:match_object(?SESSION_PUBRANGE_TAB, Pat, read).
fetch(ReplyFun, SessionId, Inflight0, [DSStream | Streams], N, Acc) when N > 0 ->
fetch(PreprocFun, SessionId, Inflight0, [DSStream | Streams], N, Acc) when N > 0 ->
#inflight{next_seqno = FirstSeqno, offset_ranges = Ranges} = Inflight0,
ItBegin = get_last_iterator(DSStream, Ranges),
{ok, ItEnd, Messages} = emqx_ds:next(?PERSISTENT_MESSAGE_DB, ItBegin, N),
case Messages of
[] ->
fetch(ReplyFun, SessionId, Inflight0, Streams, N, Acc);
fetch(PreprocFun, SessionId, Inflight0, Streams, N, Acc);
_ ->
%% We need to preserve the iterator pointing to the beginning of the
%% range, so that we can replay it if needed.
{Publishes, {UntilSeqno, Tracks}} = publish(ReplyFun, FirstSeqno, Messages),
{Publishes, UntilSeqno} = publish_fetch(PreprocFun, FirstSeqno, Messages),
Size = range_size(FirstSeqno, UntilSeqno),
Range0 = #ds_pubrange{
id = {SessionId, FirstSeqno},
type = ?T_INFLIGHT,
tracks = Tracks,
tracks = compute_pub_tracks(Publishes),
until = UntilSeqno,
stream = DSStream#ds_stream.ref,
iterator = ItBegin
@ -277,7 +282,7 @@ fetch(ReplyFun, SessionId, Inflight0, [DSStream | Streams], N, Acc) when N > 0 -
next_seqno = UntilSeqno,
offset_ranges = Ranges ++ [Range]
},
fetch(ReplyFun, SessionId, Inflight, Streams, N - Size, [Publishes | Acc])
fetch(PreprocFun, SessionId, Inflight, Streams, N - Size, [Publishes | Acc])
end;
fetch(_ReplyFun, _SessionId, Inflight, _Streams, _N, Acc) ->
Publishes = lists:append(lists:reverse(Acc)),
@ -374,19 +379,20 @@ discard_tracks(#{ack := AckedUntil, comp := CompUntil}, Until, Tracks) ->
TAck bor TComp.
replay_range(
ReplyFun,
PreprocFun,
Commits,
Range0 = #ds_pubrange{type = ?T_INFLIGHT, id = {_, First}, until = Until, iterator = It},
Acc
) ->
Size = range_size(First, Until),
{ok, ItNext, MessagesUnacked} = emqx_ds:next(?PERSISTENT_MESSAGE_DB, It, Size),
%% Asserting that range is consistent with the message storage state.
{Replies, {Until, _TracksInitial}} = publish(ReplyFun, First, MessagesUnacked),
{Replies, Until} = publish_replay(PreprocFun, Commits, First, MessagesUnacked),
%% Again, we need to keep the iterator pointing past the end of the
%% range, so that we can pick up where we left off.
Range = keep_next_iterator(ItNext, Range0),
{Range, Replies ++ Acc};
replay_range(_ReplyFun, Range0 = #ds_pubrange{type = ?T_CHECKPOINT}, Acc) ->
replay_range(_PreprocFun, _Commits, Range0 = #ds_pubrange{type = ?T_CHECKPOINT}, Acc) ->
{Range0, Acc}.
validate_commit(
@ -419,33 +425,88 @@ get_commit_next(rec, #inflight{next_seqno = NextSeqno}) ->
get_commit_next(comp, #inflight{commits = Commits}) ->
maps:get(rec, Commits).
publish(ReplyFun, FirstSeqno, Messages) ->
lists:mapfoldl(
fun(Message, Acc = {Seqno, _Tracks}) ->
Reply = ReplyFun(Seqno, Message),
publish_reply(Reply, Acc)
publish_fetch(PreprocFun, FirstSeqno, Messages) ->
flatmapfoldl(
fun(MessageIn, Acc) ->
Message = PreprocFun(MessageIn),
publish_fetch(Message, Acc)
end,
{FirstSeqno, 0},
FirstSeqno,
Messages
).
publish_reply(Replies = [_ | _], Acc) ->
lists:mapfoldl(fun publish_reply/2, Acc, Replies);
publish_reply(Reply, {Seqno, Tracks}) when is_function(Reply) ->
Pub = Reply(seqno_to_packet_id(Seqno)),
NextSeqno = next_seqno(Seqno),
NextTracks = add_pub_track(Pub, Tracks),
{Pub, {NextSeqno, NextTracks}};
publish_reply(Reply, Acc) ->
{Reply, Acc}.
publish_fetch(#message{qos = ?QOS_0} = Message, Seqno) ->
{{undefined, Message}, Seqno};
publish_fetch(#message{} = Message, Seqno) ->
PacketId = seqno_to_packet_id(Seqno),
{{PacketId, Message}, next_seqno(Seqno)};
publish_fetch(Messages, Seqno) ->
flatmapfoldl(fun publish_fetch/2, Seqno, Messages).
add_pub_track({PacketId, Message}, Tracks) when is_integer(PacketId) ->
case emqx_message:qos(Message) of
1 -> ?TRACK_FLAG(?ACK) bor Tracks;
2 -> ?TRACK_FLAG(?COMP) bor Tracks;
_ -> Tracks
publish_replay(PreprocFun, Commits, FirstSeqno, Messages) ->
#{ack := AckedUntil, comp := CompUntil, rec := RecUntil} = Commits,
flatmapfoldl(
fun(MessageIn, Acc) ->
Message = PreprocFun(MessageIn),
publish_replay(Message, AckedUntil, CompUntil, RecUntil, Acc)
end,
FirstSeqno,
Messages
).
publish_replay(#message{qos = ?QOS_0}, _, _, _, Seqno) ->
%% QoS 0 (at most once) messages should not be replayed.
{[], Seqno};
publish_replay(#message{qos = Qos} = Message, AckedUntil, CompUntil, RecUntil, Seqno) ->
case Qos of
?QOS_1 when Seqno < AckedUntil ->
%% This message has already been acked, so we can skip it.
%% We still need to advance seqno, because previously we assigned this message
%% a unique Packet Id.
{[], next_seqno(Seqno)};
?QOS_2 when Seqno < CompUntil ->
%% This message's flow has already been fully completed, so we can skip it.
%% We still need to advance seqno, because previously we assigned this message
%% a unique Packet Id.
{[], next_seqno(Seqno)};
?QOS_2 when Seqno < RecUntil ->
%% This message's flow has been partially completed, we need to resend a PUBREL.
PacketId = seqno_to_packet_id(Seqno),
Pub = {pubrel, PacketId},
{Pub, next_seqno(Seqno)};
_ ->
%% This message flow hasn't been acked and/or received, we need to resend it.
PacketId = seqno_to_packet_id(Seqno),
Pub = {PacketId, emqx_message:set_flag(dup, true, Message)},
{Pub, next_seqno(Seqno)}
end;
add_pub_track(_Pub, Tracks) ->
publish_replay([], _, _, _, Seqno) ->
{[], Seqno};
publish_replay(Messages, AckedUntil, CompUntil, RecUntil, Seqno) ->
flatmapfoldl(
fun(Message, Acc) ->
publish_replay(Message, AckedUntil, CompUntil, RecUntil, Acc)
end,
Seqno,
Messages
).
-spec compute_pub_tracks(replies()) -> non_neg_integer().
compute_pub_tracks(Pubs) ->
compute_pub_tracks(Pubs, ?TRACK_FLAGS_NONE).
compute_pub_tracks(_Pubs, Tracks = ?TRACK_FLAGS_ALL) ->
Tracks;
compute_pub_tracks([Pub | Rest], Tracks) ->
Track =
case Pub of
{_PacketId, #message{qos = ?QOS_1}} -> ?TRACK_FLAG(?ACK);
{_PacketId, #message{qos = ?QOS_2}} -> ?TRACK_FLAG(?COMP);
{pubrel, _PacketId} -> ?TRACK_FLAG(?COMP);
_ -> ?TRACK_FLAGS_NONE
end,
compute_pub_tracks(Rest, Track bor Tracks);
compute_pub_tracks([], Tracks) ->
Tracks.
keep_next_iterator(ItNext, Range = #ds_pubrange{iterator = ItFirst, misc = Misc}) ->
@ -550,6 +611,19 @@ shuffle(L0) ->
{_, L} = lists:unzip(L2),
L.
-spec flatmapfoldl(fun((X, Acc) -> {Y | [Y], Acc}), Acc, [X]) -> {[Y], Acc}.
flatmapfoldl(_Fun, Acc, []) ->
{[], Acc};
flatmapfoldl(Fun, Acc, [X | Xs]) ->
{Ys, NAcc} = Fun(X, Acc),
{Zs, FAcc} = flatmapfoldl(Fun, NAcc, Xs),
case is_list(Ys) of
true ->
{Ys ++ Zs, FAcc};
_ ->
{[Ys | Zs], FAcc}
end.
ro_transaction(Fun) ->
{atomic, Res} = mria:ro_transaction(?DS_MRIA_SHARD, Fun),
Res.

View File

@ -152,9 +152,8 @@
-spec create(clientinfo(), conninfo(), emqx_session:conf()) ->
session().
create(#{clientid := ClientID}, ConnInfo, Conf) ->
% TODO: expiration
Session = ensure_timers(session_ensure_new(ClientID, ConnInfo)),
preserve_conf(ConnInfo, Conf, Session).
Session = session_ensure_new(ClientID, ConnInfo),
apply_conf(ConnInfo, Conf, ensure_timers(Session)).
-spec open(clientinfo(), conninfo(), emqx_session:conf()) ->
{_IsPresent :: true, session(), []} | false.
@ -168,13 +167,13 @@ open(#{clientid := ClientID} = _ClientInfo, ConnInfo, Conf) ->
ok = emqx_cm:discard_session(ClientID),
case session_open(ClientID, ConnInfo) of
Session0 = #{} ->
Session = preserve_conf(ConnInfo, Conf, Session0),
Session = apply_conf(ConnInfo, Conf, Session0),
{true, ensure_timers(Session), []};
false ->
false
end.
preserve_conf(ConnInfo, Conf, Session) ->
apply_conf(ConnInfo, Conf, Session) ->
Session#{
receive_maximum => receive_maximum(ConnInfo),
props => Conf
@ -399,7 +398,7 @@ deliver(_ClientInfo, _Delivers, Session) ->
{ok, replies(), session()} | {ok, replies(), timeout(), session()}.
handle_timeout(
ClientInfo,
pull,
?TIMER_PULL,
Session0 = #{
id := Id,
inflight := Inflight0,
@ -411,14 +410,9 @@ handle_timeout(
MaxBatchSize = emqx_config:get([session_persistence, max_batch_size]),
BatchSize = min(ReceiveMaximum, MaxBatchSize),
UpgradeQoS = maps:get(upgrade_qos, Conf),
ReplyFun = make_reply_fun(ClientInfo, Subs, UpgradeQoS, fun
(_Seqno, Message = #message{qos = ?QOS_0}) ->
{undefined, Message};
(_Seqno, Message) ->
fun(PacketId) -> {PacketId, Message} end
end),
PreprocFun = make_preproc_fun(ClientInfo, Subs, UpgradeQoS),
{Publishes, Inflight} = emqx_persistent_message_ds_replayer:poll(
ReplyFun,
PreprocFun,
Id,
Inflight0,
BatchSize
@ -455,22 +449,8 @@ replay(
Session = #{inflight := Inflight0, subscriptions := Subs, props := Conf}
) ->
UpgradeQoS = maps:get(upgrade_qos, Conf),
AckedUntil = emqx_persistent_message_ds_replayer:committed_until(ack, Inflight0),
RecUntil = emqx_persistent_message_ds_replayer:committed_until(rec, Inflight0),
CompUntil = emqx_persistent_message_ds_replayer:committed_until(comp, Inflight0),
ReplyFun = make_reply_fun(ClientInfo, Subs, UpgradeQoS, fun
(_Seqno, #message{qos = ?QOS_0}) ->
[];
(Seqno, #message{qos = ?QOS_1}) when Seqno < AckedUntil ->
fun(_) -> [] end;
(Seqno, #message{qos = ?QOS_2}) when Seqno < CompUntil ->
fun(_) -> [] end;
(Seqno, #message{qos = ?QOS_2}) when Seqno < RecUntil ->
fun(PacketId) -> {pubrel, PacketId} end;
(_Seqno, Message) ->
fun(PacketId) -> {PacketId, emqx_message:set_flag(dup, true, Message)} end
end),
{Replies, Inflight} = emqx_persistent_message_ds_replayer:replay(ReplyFun, Inflight0),
PreprocFun = make_preproc_fun(ClientInfo, Subs, UpgradeQoS),
{Replies, Inflight} = emqx_persistent_message_ds_replayer:replay(PreprocFun, Inflight0),
{ok, Replies, Session#{inflight := Inflight}}.
%%--------------------------------------------------------------------
@ -486,23 +466,17 @@ terminate(_Reason, _Session = #{}) ->
%%--------------------------------------------------------------------
make_reply_fun(ClientInfo, Subs, UpgradeQoS, InnerFun) ->
fun(Seqno, Message0 = #message{topic = Topic}) ->
make_preproc_fun(ClientInfo, Subs, UpgradeQoS) ->
fun(Message = #message{topic = Topic}) ->
emqx_utils:flattermap(
fun(Match) ->
emqx_utils:flattermap(
fun(Message) -> InnerFun(Seqno, Message) end,
enrich_message(ClientInfo, Message0, Match, Subs, UpgradeQoS)
)
#{props := SubOpts} = subs_get_match(Match, Subs),
emqx_session:enrich_message(ClientInfo, Message, SubOpts, UpgradeQoS)
end,
subs_matches(Topic, Subs)
)
end.
enrich_message(ClientInfo, Message, SubMatch, Subs, UpgradeQoS) ->
#{props := SubOpts} = subs_get_match(SubMatch, Subs),
emqx_session:enrich_message(ClientInfo, Message, SubOpts, UpgradeQoS).
%%--------------------------------------------------------------------
-spec add_subscription(topic_filter(), emqx_types:subopts(), id()) ->