feat(sessds): preserve acks / ranges in mnesia for replays

This commit is contained in:
Andrew Mayorov 2023-11-20 13:28:20 +07:00
parent 5b40304d1f
commit 1246d714c5
No known key found for this signature in database
GPG Key ID: 2837C62ACFBFED5D
4 changed files with 418 additions and 222 deletions

View File

@ -357,18 +357,12 @@ do_t_session_discard(Params) ->
_Attempts0 = 50,
true = map_size(emqx_persistent_session_ds:list_all_streams()) > 0
),
?retry(
_Sleep0 = 100,
_Attempts0 = 50,
true = map_size(emqx_persistent_session_ds:list_all_iterators()) > 0
),
ok = emqtt:stop(Client0),
?tp(notice, "disconnected", #{}),
?tp(notice, "reconnecting", #{}),
%% we still have iterators and streams
%% we still have streams
?assert(map_size(emqx_persistent_session_ds:list_all_streams()) > 0),
?assert(map_size(emqx_persistent_session_ds:list_all_iterators()) > 0),
Client1 = start_client(ReconnectOpts),
{ok, _} = emqtt:connect(Client1),
?assertEqual([], emqtt:subscriptions(Client1)),
@ -381,7 +375,7 @@ do_t_session_discard(Params) ->
?assertEqual(#{}, emqx_persistent_session_ds:list_all_subscriptions()),
?assertEqual([], emqx_persistent_session_ds_router:topics()),
?assertEqual(#{}, emqx_persistent_session_ds:list_all_streams()),
?assertEqual(#{}, emqx_persistent_session_ds:list_all_iterators()),
?assertEqual(#{}, emqx_persistent_session_ds:list_all_pubranges()),
ok = emqtt:stop(Client1),
?tp(notice, "disconnected", #{}),

View File

@ -19,12 +19,12 @@
-module(emqx_persistent_message_ds_replayer).
%% API:
-export([new/0, next_packet_id/1, replay/2, commit_offset/3, poll/3, n_inflight/1]).
-export([new/0, open/1, next_packet_id/1, replay/1, commit_offset/3, poll/3, n_inflight/1]).
%% internal exports:
-export([]).
-export_type([inflight/0]).
-export_type([inflight/0, seqno/0]).
-include_lib("emqx/include/logger.hrl").
-include("emqx_persistent_session_ds.hrl").
@ -42,17 +42,28 @@
-type seqno() :: non_neg_integer().
-record(range, {
stream :: emqx_ds:stream(),
stream :: _StreamRef,
first :: seqno(),
last :: seqno(),
iterator_next :: emqx_ds:iterator() | undefined
until :: seqno(),
%% Type of a range:
%% * Inflight range is a range of yet unacked messages from this stream.
%% * Checkpoint range was already acked, its purpose is to keep track of the
%% very last iterator for this stream.
type :: inflight | checkpoint,
%% Meaning of this depends on the type of the range:
%% * For inflight range, this is the iterator pointing to the first message in
%% the range.
%% * For checkpoint range, this is the iterator pointing right past the last
%% message in the range.
iterator :: emqx_ds:iterator()
}).
-type range() :: #range{}.
-record(inflight, {
next_seqno = 0 :: seqno(),
acked_seqno = 0 :: seqno(),
next_seqno = 1 :: seqno(),
acked_until = 1 :: seqno(),
%% Ranges are sorted in ascending order of their sequence numbers.
offset_ranges = [] :: [range()]
}).
@ -66,34 +77,37 @@
new() ->
#inflight{}.
%% Rebuild the in-memory inflight state of a session from the publish ranges
%% stored in the database. The acked / next sequence numbers are not stored
%% explicitly; they are derived from the stored ranges.
-spec open(emqx_persistent_session_ds:id()) -> inflight().
open(SessionId) ->
    ReadRanges = fun() -> get_ranges(SessionId) end,
    StoredRanges = ro_transaction(ReadRanges),
    {AckedUntil, NextSeqno} = compute_inflight_range(StoredRanges),
    #inflight{
        offset_ranges = StoredRanges,
        next_seqno = NextSeqno,
        acked_until = AckedUntil
    }.
-spec next_packet_id(inflight()) -> {emqx_types:packet_id(), inflight()}.
next_packet_id(Inflight0 = #inflight{next_seqno = LastSeqNo}) ->
Inflight = Inflight0#inflight{next_seqno = LastSeqNo + 1},
case LastSeqNo rem 16#10000 of
0 ->
%% We skip sequence numbers that lead to PacketId = 0 to
%% simplify math. Note: it leads to occasional gaps in the
%% sequence numbers.
next_packet_id(Inflight);
PacketId ->
{PacketId, Inflight}
end.
%% Hand out the packet id that corresponds to the current `next_seqno` and
%% advance the counter to the following usable sequence number.
next_packet_id(#inflight{next_seqno = Seqno} = Inflight) ->
    PacketId = seqno_to_packet_id(Seqno),
    {PacketId, Inflight#inflight{next_seqno = next_seqno(Seqno)}}.
-spec n_inflight(inflight()) -> non_neg_integer().
n_inflight(#inflight{next_seqno = NextSeqNo, acked_seqno = AckedSeqno}) ->
%% NOTE: this function assumes that gaps in the sequence ID occur
%% _only_ when the packet ID wraps:
case AckedSeqno >= ((NextSeqNo bsr 16) bsl 16) of
true ->
NextSeqNo - AckedSeqno;
false ->
NextSeqNo - AckedSeqno - 1
end.
%% Size of the sequence number range starting at `acked_until` and ending
%% right before `next_seqno`.
n_inflight(#inflight{acked_until = AckedUntil, next_seqno = NextSeqno}) ->
    InflightSize = range_size(AckedUntil, NextSeqno),
    InflightSize.
-spec replay(emqx_persistent_session_ds:id(), inflight()) ->
emqx_session:replies().
replay(_SessionId, _Inflight = #inflight{offset_ranges = _Ranges}) ->
[].
%% Re-read the messages of every inflight range and produce the replies that
%% need to be sent again. Ranges are traversed right-to-left, each range
%% prepending its replies, so the resulting list is in ascending seqno order.
-spec replay(inflight()) ->
    {emqx_session:replies(), inflight()}.
replay(Inflight0 = #inflight{acked_until = AckedUntil, offset_ranges = Ranges0}) ->
    ReplayOne = fun(Range, RepliesAcc) ->
        replay_range(Range, AckedUntil, RepliesAcc)
    end,
    {Ranges, Replies} = lists:mapfoldr(ReplayOne, [], Ranges0),
    {Replies, Inflight0#inflight{offset_ranges = Ranges}}.
-spec commit_offset(emqx_persistent_session_ds:id(), emqx_types:packet_id(), inflight()) ->
{_IsValidOffset :: boolean(), inflight()}.
@ -101,47 +115,34 @@ commit_offset(
SessionId,
PacketId,
Inflight0 = #inflight{
acked_seqno = AckedSeqno0, next_seqno = NextSeqNo, offset_ranges = Ranges0
acked_until = AckedUntil, next_seqno = NextSeqno
}
) ->
AckedSeqno =
case packet_id_to_seqno(NextSeqNo, PacketId) of
N when N > AckedSeqno0; AckedSeqno0 =:= 0 ->
N;
OutOfRange ->
?SLOG(warning, #{
msg => "out-of-order_ack",
prev_seqno => AckedSeqno0,
acked_seqno => OutOfRange,
next_seqno => NextSeqNo,
packet_id => PacketId
}),
AckedSeqno0
end,
Ranges = lists:filter(
fun(#range{stream = Stream, last = LastSeqno, iterator_next = ItNext}) ->
case LastSeqno =< AckedSeqno of
true ->
%% This range has been fully
%% acked. Remove it and replace saved
%% iterator with the trailing iterator.
update_iterator(SessionId, Stream, ItNext),
false;
false ->
%% This range still has unacked
%% messages:
true
end
end,
Ranges0
),
Inflight = Inflight0#inflight{acked_seqno = AckedSeqno, offset_ranges = Ranges},
{true, Inflight}.
case packet_id_to_seqno(NextSeqno, PacketId) of
Seqno when Seqno >= AckedUntil andalso Seqno < NextSeqno ->
%% TODO
%% We do not preserve `acked_until` in the database. Instead, we discard
%% fully acked ranges from the database. In effect, this means that the
%% most recent `acked_until` the client has sent may be lost in case of a
%% crash or client loss.
Inflight1 = Inflight0#inflight{acked_until = next_seqno(Seqno)},
Inflight = discard_acked(SessionId, Inflight1),
{true, Inflight};
OutOfRange ->
?SLOG(warning, #{
msg => "out-of-order_ack",
acked_until => AckedUntil,
acked_seqno => OutOfRange,
next_seqno => NextSeqno,
packet_id => PacketId
}),
{false, Inflight0}
end.
-spec poll(emqx_persistent_session_ds:id(), inflight(), pos_integer()) ->
{emqx_session:replies(), inflight()}.
poll(SessionId, Inflight0, WindowSize) when WindowSize > 0, WindowSize < 16#7fff ->
#inflight{next_seqno = NextSeqNo0, acked_seqno = AckedSeqno} =
#inflight{next_seqno = NextSeqNo0, acked_until = AckedSeqno} =
Inflight0,
FetchThreshold = max(1, WindowSize div 2),
FreeSpace = AckedSeqno + WindowSize - NextSeqNo0,
@ -153,6 +154,7 @@ poll(SessionId, Inflight0, WindowSize) when WindowSize > 0, WindowSize < 16#7fff
%% client get stuck even?
{[], Inflight0};
true ->
%% TODO: Wrap this in `mria:async_dirty/2`?
Streams = shuffle(get_streams(SessionId)),
fetch(SessionId, Inflight0, Streams, FreeSpace, [])
end.
@ -165,75 +167,206 @@ poll(SessionId, Inflight0, WindowSize) when WindowSize > 0, WindowSize < 16#7fff
%% Internal functions
%%================================================================================
fetch(_SessionId, Inflight, _Streams = [], _N, Acc) ->
{lists:reverse(Acc), Inflight};
fetch(_SessionId, Inflight, _Streams, 0, Acc) ->
{lists:reverse(Acc), Inflight};
fetch(SessionId, Inflight0, [Stream | Streams], N, Publishes0) ->
#inflight{next_seqno = FirstSeqNo, offset_ranges = Ranges0} = Inflight0,
ItBegin = get_last_iterator(SessionId, Stream, Ranges0),
%% Derive `{AckedUntil, NextSeqno}` from a list of ranges sorted by seqno.
%% * `NextSeqno` is the `until` of the very last range.
%% * `AckedUntil` is the `first` of the earliest range that follows the leading
%%   run of checkpoints; if there is none, everything up to `NextSeqno` is acked.
compute_inflight_range([]) ->
    {1, 1};
compute_inflight_range(Ranges) ->
    #range{until = LastSeqno} = lists:last(Ranges),
    IsCheckpoint = fun(#range{type = Type}) -> Type == checkpoint end,
    AckedUntil =
        case lists:dropwhile(IsCheckpoint, Ranges) of
            [] ->
                LastSeqno;
            [#range{first = First} | _] ->
                First
        end,
    {AckedUntil, LastSeqno}.
%% Read all persisted publish ranges of a session and convert them to the
%% in-memory representation. Uses `mnesia:match_object/3`, so it has to run
%% inside a mnesia activity (`open/1` wraps it in a read-only transaction).
%% NOTE(review): callers treat the result as sorted by first seqno; this
%% presumably relies on the table being an `ordered_set` -- confirm.
get_ranges(SessionId) ->
    Pattern = #ds_pubrange{id = {SessionId, '_'}, _ = '_'},
    Stored = mnesia:match_object(?SESSION_PUBRANGE_TAB, Pattern, read),
    [export_range(DSRange) || DSRange <- Stored].
%% Convert a stored `#ds_pubrange{}` into a `#range{}`; the first seqno is
%% taken from the second element of the record key.
export_range(#ds_pubrange{id = {_SessionId, First}} = DSRange) ->
    #range{
        stream = DSRange#ds_pubrange.stream,
        first = First,
        until = DSRange#ds_pubrange.until,
        type = DSRange#ds_pubrange.type,
        iterator = DSRange#ds_pubrange.iterator
    }.
fetch(SessionId, Inflight0, [DSStream | Streams], N, Acc) when N > 0 ->
#inflight{next_seqno = FirstSeqno, offset_ranges = Ranges0} = Inflight0,
ItBegin = get_last_iterator(DSStream, Ranges0),
{ok, ItEnd, Messages} = emqx_ds:next(?PERSISTENT_MESSAGE_DB, ItBegin, N),
{NMessages, Publishes, Inflight1} =
lists:foldl(
fun(Msg, {N0, PubAcc0, InflightAcc0}) ->
{PacketId, InflightAcc} = next_packet_id(InflightAcc0),
PubAcc = [{PacketId, Msg} | PubAcc0],
{N0 + 1, PubAcc, InflightAcc}
end,
{0, Publishes0, Inflight0},
Messages
),
#inflight{next_seqno = LastSeqNo} = Inflight1,
case NMessages > 0 of
true ->
Range = #range{
first = FirstSeqNo,
last = LastSeqNo - 1,
stream = Stream,
iterator_next = ItEnd
{Publishes, UntilSeqno} = publish(FirstSeqno, Messages),
case range_size(FirstSeqno, UntilSeqno) of
Size when Size > 0 ->
Range0 = #range{
type = inflight,
first = FirstSeqno,
until = UntilSeqno,
stream = DSStream#ds_stream.ref,
iterator = ItBegin
},
Inflight = Inflight1#inflight{offset_ranges = Ranges0 ++ [Range]},
fetch(SessionId, Inflight, Streams, N - NMessages, Publishes);
false ->
fetch(SessionId, Inflight1, Streams, N, Publishes)
end.
%% We need to preserve the iterator pointing to the beginning of the
%% range, so that we can replay it if needed.
ok = preserve_range(SessionId, Range0),
%% ...Yet we need to keep the iterator pointing past the end of the
%% range, so that we can pick up where we left off: it will become
%% `ItBegin` of the next range for this stream.
Range = Range0#range{iterator = ItEnd},
Ranges = Ranges0 ++ [Range#range{iterator = ItEnd}],
Inflight = Inflight0#inflight{
next_seqno = UntilSeqno,
offset_ranges = Ranges
},
fetch(SessionId, Inflight, Streams, N - Size, [Publishes | Acc]);
0 ->
fetch(SessionId, Inflight0, Streams, N, Acc)
end;
fetch(_SessionId, Inflight, _Streams, _N, Acc) ->
Publishes = lists:append(lists:reverse(Acc)),
{Publishes, Inflight}.
-spec update_iterator(emqx_persistent_session_ds:id(), emqx_ds:stream(), emqx_ds:iterator()) -> ok.
update_iterator(DSSessionId, Stream, Iterator) ->
%% Workaround: we convert `Stream' to a binary before attempting to store it in
%% mnesia(rocksdb) because of a bug in `mnesia_rocksdb' when trying to do
%% `mnesia:dirty_all_keys' later.
StreamBin = term_to_binary(Stream),
mria:dirty_write(?SESSION_ITER_TAB, #ds_iter{id = {DSSessionId, StreamBin}, iter = Iterator}).
%% Drop (or turn into checkpoints) all ranges that lie entirely below
%% `acked_until`, both in the database and in the inflight state.
discard_acked(
    SessionId,
    Inflight = #inflight{acked_until = AckedUntil, offset_ranges = RangesBefore}
) ->
    %% TODO: This could be kept and incrementally updated in the inflight state.
    LastPerStream = find_checkpoints(RangesBefore),
    %% TODO: Wrap this in `mria:async_dirty/2`?
    RangesAfter = discard_acked_ranges(SessionId, AckedUntil, LastPerStream, RangesBefore),
    Inflight#inflight{offset_ranges = RangesAfter}.
get_last_iterator(SessionId, Stream, Ranges) ->
case lists:keyfind(Stream, #range.stream, lists:reverse(Ranges)) of
false ->
get_iterator(SessionId, Stream);
#range{iterator_next = Next} ->
Next
end.
-spec get_iterator(emqx_persistent_session_ds:id(), emqx_ds:stream()) -> emqx_ds:iterator().
get_iterator(DSSessionId, Stream) ->
%% See comment in `update_iterator'.
StreamBin = term_to_binary(Stream),
Id = {DSSessionId, StreamBin},
[#ds_iter{iter = It}] = mnesia:dirty_read(?SESSION_ITER_TAB, Id),
It.
-spec get_streams(emqx_persistent_session_ds:id()) -> [emqx_ds:stream()].
get_streams(SessionId) ->
lists:map(
fun(#ds_stream{stream = Stream}) ->
Stream
find_checkpoints(Ranges) ->
lists:foldl(
fun(#range{stream = StreamRef, until = Until}, Acc) ->
%% For each stream, remember the last range over this stream.
Acc#{StreamRef => Until}
end,
mnesia:dirty_read(?SESSION_STREAM_TAB, SessionId)
#{},
Ranges
).
%% Walk the ranges from oldest to newest while they are fully acked.
%% `Checkpoints` maps a stream ref to the `until` of the last range over that
%% stream; it must be computed up front because database writes are issued
%% without a transaction and therefore strictly in range order.
discard_acked_ranges(
    SessionId,
    AckedUntil,
    Checkpoints,
    [Range = #range{until = Until, stream = StreamRef} | Rest]
) when Until =< AckedUntil ->
    %% Fully acked range: it either is the last range over its stream (then it
    %% is kept as a checkpoint holding the iterator) or it is dropped entirely.
    Kept =
        case maps:get(StreamRef, Checkpoints) of
            Until ->
                checkpoint_range(SessionId, Range),
                [Range#range{type = checkpoint}];
            LastUntil when LastUntil > Until ->
                discard_range(SessionId, Range),
                []
        end,
    %% The write for this range has been issued; only now move on to the next.
    Kept ++ discard_acked_ranges(SessionId, AckedUntil, Checkpoints, Rest);
discard_acked_ranges(_SessionId, _AckedUntil, _Checkpoints, Ranges) ->
    %% Whatever is left (possibly nothing) still contains unacked messages.
    Ranges.
%% Replay a single range.
%% * Checkpoint ranges carry no unacked messages and are passed through.
%% * Inflight ranges are re-read from storage starting at the stored iterator;
%%   messages below `AckedUntil` are skipped, the rest is prepended to `Acc`.
replay_range(Range = #range{type = checkpoint}, _AckedUntil, Acc) ->
    {Range, Acc};
replay_range(
    Range = #range{type = inflight, first = First, until = Until, iterator = It},
    AckedUntil,
    Acc
) ->
    NMessages = range_size(First, Until),
    FirstUnacked = max(First, AckedUntil),
    {ok, ItNext, Messages} = emqx_ds:next(?PERSISTENT_MESSAGE_DB, It, NMessages),
    MessagesUnacked =
        if
            FirstUnacked =:= First ->
                Messages;
            true ->
                lists:nthtail(range_size(First, FirstUnacked), Messages)
        end,
    %% Matching on the already bound `Until` asserts that the storage returned
    %% exactly as many messages as the range claims to cover.
    {Replies, Until} = publish(FirstUnacked, MessagesUnacked),
    {Range#range{iterator = ItNext}, Replies ++ Acc}.
%% Assign consecutive sequence numbers, starting at `FirstSeqno`, to the given
%% messages. Returns the `{PacketId, Message}` pairs in the original message
%% order together with the sequence number following the last one used.
publish(FirstSeqno, Messages) ->
    publish_acc(FirstSeqno, Messages, []).

publish_acc(Seqno, [], Acc) ->
    {lists:reverse(Acc), Seqno};
publish_acc(Seqno, [Message | Rest], Acc) ->
    Publish = {seqno_to_packet_id(Seqno), Message},
    publish_acc(next_seqno(Seqno), Rest, [Publish | Acc]).
%% Store a range in the database as an `inflight` range, keyed by session id
%% and first seqno. The `type` field of the given range is not consulted.
-spec preserve_range(emqx_persistent_session_ds:id(), range()) -> ok.
preserve_range(SessionId, Range = #range{first = First}) ->
    mria:dirty_write(?SESSION_PUBRANGE_TAB, #ds_pubrange{
        id = {SessionId, First},
        type = inflight,
        until = Range#range.until,
        stream = Range#range.stream,
        iterator = Range#range.iterator
    }).
%% Remove the database record of a range, identified by session id and first seqno.
-spec discard_range(emqx_persistent_session_ds:id(), range()) -> ok.
discard_range(SessionId, #range{first = First}) ->
    RangeId = {SessionId, First},
    mria:dirty_delete(?SESSION_PUBRANGE_TAB, RangeId).
%% Overwrite the database record of an inflight range with a `checkpoint`
%% record carrying the iterator currently held by the range. Ranges that are
%% checkpoints already are left untouched.
-spec checkpoint_range(emqx_persistent_session_ds:id(), range()) -> ok.
checkpoint_range(_SessionId, #range{type = checkpoint}) ->
    %% This range should have been checkpointed already.
    ok;
checkpoint_range(SessionId, Range = #range{type = inflight, first = First}) ->
    Checkpoint = #ds_pubrange{
        id = {SessionId, First},
        type = checkpoint,
        until = Range#range.until,
        stream = Range#range.stream,
        iterator = Range#range.iterator
    },
    mria:dirty_write(?SESSION_PUBRANGE_TAB, Checkpoint).
%% Pick the iterator to continue reading a stream from: the iterator of the
%% most recent range over this stream, or the stream's `beginning` iterator
%% when no range refers to the stream yet.
get_last_iterator(#ds_stream{ref = StreamRef, beginning = ItBeginning}, Ranges) ->
    NewestFirst = lists:reverse(Ranges),
    case lists:keyfind(StreamRef, #range.stream, NewestFirst) of
        #range{iterator = ItNext} ->
            ItNext;
        false ->
            ItBeginning
    end.
%% Dirty-read all stream records stored under the given session id.
-spec get_streams(emqx_persistent_session_ds:id()) -> [ds_stream()].
get_streams(SessionId) ->
    DSStreams = mnesia:dirty_read(?SESSION_STREAM_TAB, SessionId),
    DSStreams.
%% Sequence number that follows `Seqno`.
%% Sequence numbers that would map to packet id 0 are skipped to simplify the
%% math; this leaves an occasional gap in the sequence numbers.
next_seqno(Seqno) ->
    Candidate = Seqno + 1,
    case seqno_to_packet_id(Candidate) =:= 0 of
        true ->
            Candidate + 1;
        false ->
            Candidate
    end.
%% Reconstruct session counter by adding most significant bits from
%% the current counter to the packet id.
-spec packet_id_to_seqno(non_neg_integer(), emqx_types:packet_id()) -> non_neg_integer().
-spec packet_id_to_seqno(_Next :: seqno(), emqx_types:packet_id()) -> seqno().
packet_id_to_seqno(NextSeqNo, PacketId) ->
Epoch = NextSeqNo bsr 16,
case packet_id_to_seqno_(Epoch, PacketId) of
@ -243,10 +376,20 @@ packet_id_to_seqno(NextSeqNo, PacketId) ->
packet_id_to_seqno_(Epoch - 1, PacketId)
end.
-spec packet_id_to_seqno_(non_neg_integer(), emqx_types:packet_id()) -> non_neg_integer().
-spec packet_id_to_seqno_(non_neg_integer(), emqx_types:packet_id()) -> seqno().
packet_id_to_seqno_(Epoch, PacketId) ->
(Epoch bsl 16) + PacketId.
%% Packet id of a sequence number: the remainder of dividing it by 2^16.
-spec seqno_to_packet_id(seqno()) -> emqx_types:packet_id().
seqno_to_packet_id(Seqno) ->
    Seqno rem (1 bsl 16).
%% Number of usable sequence numbers between `FirstSeqno` (inclusive) and
%% `UntilSeqno` (exclusive).
%% Assumes gaps in the sequence occur _only_ where the packet id wraps, i.e.
%% exactly one skipped number per 2^16 boundary crossed.
range_size(FirstSeqno, UntilSeqno) ->
    WrapsCrossed = (UntilSeqno bsr 16) - (FirstSeqno bsr 16),
    UntilSeqno - FirstSeqno - WrapsCrossed.
-spec shuffle([A]) -> [A].
shuffle(L0) ->
L1 = lists:map(
@ -259,6 +402,10 @@ shuffle(L0) ->
{_, L} = lists:unzip(L2),
L.
%% Run `Fun` in a read-only transaction on the session shard and unwrap the
%% result; anything other than `{atomic, _}` fails the match.
ro_transaction(Fun) ->
    TxResult = mria:ro_transaction(?DS_MRIA_SHARD, Fun),
    {atomic, Res} = TxResult,
    Res.
-ifdef(TEST).
%% This test only tests boundary conditions (to make sure property-based test didn't skip them):
@ -311,4 +458,40 @@ seqno_gen(NextSeqNo) ->
Max = max(0, NextSeqNo - 1),
range(Min, Max).
%% Table-driven checks of `range_size/2`: empty range, single element, a range
%% that spans one packet id wrap, and a large range that crosses a wrap.
range_size_test_() ->
    Cases = [
        {0, 42, 42},
        {1, 42, 43},
        {1, 16#ffff, 16#10001},
        {16#ffff - 456 + 123, 16#1f0000 + 456, 16#200000 + 123}
    ],
    [
        ?_assertEqual(Expected, range_size(First, Until))
     || {Expected, First, Until} <- Cases
    ].
%% Checks of `compute_inflight_range/1`: no ranges, checkpoints followed by
%% inflight ranges, and checkpoints only.
compute_inflight_range_test_() ->
    LeadingCheckpoints = [
        #range{first = 1, until = 2, type = checkpoint},
        #range{first = 4, until = 8, type = checkpoint},
        #range{first = 11, until = 12, type = checkpoint}
    ],
    InflightTail = [
        #range{first = 12, until = 13, type = inflight},
        #range{first = 13, until = 20, type = inflight},
        #range{first = 20, until = 42, type = inflight}
    ],
    CheckpointTail = [
        #range{first = 12, until = 13, type = checkpoint}
    ],
    [
        ?_assertEqual(
            {1, 1},
            compute_inflight_range([])
        ),
        ?_assertEqual(
            {12, 42},
            compute_inflight_range(LeadingCheckpoints ++ InflightTail)
        ),
        ?_assertEqual(
            {13, 13},
            compute_inflight_range(LeadingCheckpoints ++ CheckpointTail)
        )
    ].
-endif.

View File

@ -76,7 +76,7 @@
list_all_sessions/0,
list_all_subscriptions/0,
list_all_streams/0,
list_all_iterators/0
list_all_pubranges/0
]).
-endif.
@ -359,15 +359,16 @@ handle_timeout(
end,
ensure_timer(pull, Timeout),
{ok, Publishes, Session#{inflight => Inflight}};
handle_timeout(_ClientInfo, get_streams, Session = #{id := Id}) ->
renew_streams(Id),
handle_timeout(_ClientInfo, get_streams, Session) ->
renew_streams(Session),
ensure_timer(get_streams),
{ok, [], Session}.
-spec replay(clientinfo(), [], session()) ->
{ok, replies(), session()}.
replay(_ClientInfo, [], Session = #{}) ->
{ok, [], Session}.
%% Ask the replayer for the replies to resend and store the inflight state it
%% returns back into the session.
replay(_ClientInfo, [], #{inflight := InflightBefore} = Session0) ->
    {Replies, InflightAfter} = emqx_persistent_message_ds_replayer:replay(InflightBefore),
    Session = Session0#{inflight := InflightAfter},
    {ok, Replies, Session}.
%%--------------------------------------------------------------------
@ -474,17 +475,20 @@ create_tables() ->
]
),
ok = mria:create_table(
?SESSION_ITER_TAB,
?SESSION_PUBRANGE_TAB,
[
{rlog_shard, ?DS_MRIA_SHARD},
{type, set},
{type, ordered_set},
{storage, storage()},
{record_name, ds_iter},
{attributes, record_info(fields, ds_iter)}
{record_name, ds_pubrange},
{attributes, record_info(fields, ds_pubrange)}
]
),
ok = mria:wait_for_tables([
?SESSION_TAB, ?SESSION_SUBSCRIPTIONS_TAB, ?SESSION_STREAM_TAB, ?SESSION_ITER_TAB
?SESSION_TAB,
?SESSION_SUBSCRIPTIONS_TAB,
?SESSION_STREAM_TAB,
?SESSION_PUBRANGE_TAB
]),
ok.
@ -512,9 +516,10 @@ session_open(SessionId) ->
Session = export_session(Record),
DSSubs = session_read_subscriptions(SessionId),
Subscriptions = export_subscriptions(DSSubs),
Inflight = emqx_persistent_message_ds_replayer:open(SessionId),
Session#{
subscriptions => Subscriptions,
inflight => emqx_persistent_message_ds_replayer:new()
inflight => Inflight
};
[] ->
false
@ -549,7 +554,7 @@ session_create(SessionId, Props) ->
session_drop(DSSessionId) ->
transaction(fun() ->
ok = session_drop_subscriptions(DSSessionId),
ok = session_drop_iterators(DSSessionId),
ok = session_drop_pubranges(DSSessionId),
ok = session_drop_streams(DSSessionId),
ok = mnesia:delete(?SESSION_TAB, DSSessionId, write)
end).
@ -663,77 +668,82 @@ do_ensure_all_iterators_closed(_DSSessionID) ->
%% Reading batches
%%--------------------------------------------------------------------
-spec renew_streams(id()) -> ok.
renew_streams(DSSessionId) ->
Subscriptions = ro_transaction(fun() -> session_read_subscriptions(DSSessionId) end),
ExistingStreams = ro_transaction(fun() -> mnesia:read(?SESSION_STREAM_TAB, DSSessionId) end),
lists:foreach(
fun(#ds_sub{id = {_, TopicFilter}, start_time = StartTime}) ->
renew_streams(DSSessionId, ExistingStreams, TopicFilter, StartTime)
%% Make sure every subscription of the session has its streams recorded.
%% The stream records of the session are read with a write lock and threaded
%% through all subscriptions as the accumulator, all within one transaction.
-spec renew_streams(session()) -> ok.
renew_streams(#{id := SessionId, subscriptions := Subscriptions}) ->
    RenewTopic = fun(TopicFilter, #{start_time := StartTime}, Streams) ->
        Words = emqx_topic:words(TopicFilter),
        renew_topic_streams(SessionId, Words, StartTime, Streams)
    end,
    transaction(fun() ->
        ExistingStreams = mnesia:read(?SESSION_STREAM_TAB, SessionId, write),
        maps:fold(RenewTopic, ExistingStreams, Subscriptions)
    end),
    ok.
%% Store the streams of a topic filter that the session does not know yet.
%%
%% Queries the message DB for the streams matching `TopicFilter` starting at
%% `StartTime`. Each stream not yet present in the accumulator is stored via
%% `session_store_stream/6` and prepended to the accumulator.
%%
%% Returns the updated list of stream records (NOT `ok`): `renew_streams/1`
%% threads this list through `maps:fold/3` as the accumulator.
-spec renew_topic_streams(id(), topic_filter_words(), emqx_ds:time(), _Acc :: [ds_stream()]) ->
    [ds_stream()].
renew_topic_streams(DSSessionId, TopicFilter, StartTime, ExistingStreams) ->
    TopicStreams = emqx_ds:get_streams(?PERSISTENT_MESSAGE_DB, TopicFilter, StartTime),
    lists:foldl(
        fun({Rank, Stream}, Streams) ->
            case lists:keymember(Stream, #ds_stream.stream, Streams) of
                true ->
                    Streams;
                false ->
                    %% The ref is the 1-based position of the stream among the
                    %% streams known so far.
                    StreamRef = length(Streams) + 1,
                    DSStream = session_store_stream(
                        DSSessionId,
                        StreamRef,
                        Stream,
                        Rank,
                        TopicFilter,
                        StartTime
                    ),
                    [DSStream | Streams]
            end
        end,
        ExistingStreams,
        TopicStreams
    ).
-spec renew_streams(id(), [ds_stream()], topic_filter_words(), emqx_ds:time()) -> ok.
renew_streams(DSSessionId, ExistingStreams, TopicFilter, StartTime) ->
AllStreams = emqx_ds:get_streams(?PERSISTENT_MESSAGE_DB, TopicFilter, StartTime),
transaction(
fun() ->
lists:foreach(
fun({Rank, Stream}) ->
Rec = #ds_stream{
session = DSSessionId,
topic_filter = TopicFilter,
stream = Stream,
rank = Rank
},
case lists:member(Rec, ExistingStreams) of
true ->
ok;
false ->
mnesia:write(?SESSION_STREAM_TAB, Rec, write),
{ok, Iterator} = emqx_ds:make_iterator(
?PERSISTENT_MESSAGE_DB, Stream, TopicFilter, StartTime
),
%% Workaround: we convert `Stream' to a binary before
%% attempting to store it in mnesia(rocksdb) because of a bug
%% in `mnesia_rocksdb' when trying to do
%% `mnesia:dirty_all_keys' later.
StreamBin = term_to_binary(Stream),
IterRec = #ds_iter{id = {DSSessionId, StreamBin}, iter = Iterator},
mnesia:write(?SESSION_ITER_TAB, IterRec, write)
end
end,
AllStreams
)
end
).
%% Create the initial iterator for a stream, write the stream record to the
%% stream table and return the record. Uses `mnesia:write/3`, so it has to be
%% called inside a transaction.
session_store_stream(DSSessionId, StreamRef, Stream, Rank, TopicFilter, StartTime) ->
    {ok, ItBeginning} =
        emqx_ds:make_iterator(?PERSISTENT_MESSAGE_DB, Stream, TopicFilter, StartTime),
    Record = #ds_stream{
        session = DSSessionId,
        ref = StreamRef,
        rank = Rank,
        stream = Stream,
        beginning = ItBeginning
    },
    mnesia:write(?SESSION_STREAM_TAB, Record, write),
    Record.
%% must be called inside a transaction
-spec session_drop_streams(id()) -> ok.
session_drop_streams(DSSessionId) ->
MS = ets:fun2ms(
fun(#ds_stream{session = DSSessionId0}) when DSSessionId0 =:= DSSessionId ->
DSSessionId0
end
),
StreamIDs = mnesia:select(?SESSION_STREAM_TAB, MS, write),
lists:foreach(fun(Key) -> mnesia:delete(?SESSION_STREAM_TAB, Key, write) end, StreamIDs).
mnesia:delete(?SESSION_STREAM_TAB, DSSessionId, write).
%% must be called inside a transaction
-spec session_drop_iterators(id()) -> ok.
session_drop_iterators(DSSessionId) ->
-spec session_drop_pubranges(id()) -> ok.
session_drop_pubranges(DSSessionId) ->
MS = ets:fun2ms(
fun(#ds_iter{id = {DSSessionId0, StreamBin}}) when DSSessionId0 =:= DSSessionId ->
StreamBin
fun(#ds_pubrange{id = {DSSessionId0, First}}) when DSSessionId0 =:= DSSessionId ->
{DSSessionId, First}
end
),
StreamBins = mnesia:select(?SESSION_ITER_TAB, MS, write),
RangeIds = mnesia:select(?SESSION_PUBRANGE_TAB, MS, write),
lists:foreach(
fun(StreamBin) ->
mnesia:delete(?SESSION_ITER_TAB, {DSSessionId, StreamBin}, write)
fun(RangeId) ->
mnesia:delete(?SESSION_PUBRANGE_TAB, RangeId, write)
end,
StreamBins
RangeIds
).
%%--------------------------------------------------------------------------------
@ -758,7 +768,7 @@ export_subscriptions(DSSubs) ->
).
export_session(#session{} = Record) ->
export_record(Record, #session.id, [id, created_at, expires_at, inflight, props], #{}).
export_record(Record, #session.id, [id, created_at, expires_at, props], #{}).
export_subscription(#ds_sub{} = Record) ->
export_record(Record, #ds_sub.start_time, [start_time, props, extra], #{}).
@ -833,16 +843,18 @@ list_all_streams() ->
),
maps:from_list(DSStreams).
list_all_iterators() ->
DSIterIds = mnesia:dirty_all_keys(?SESSION_ITER_TAB),
DSIters = lists:map(
fun(DSIterId) ->
[Record] = mnesia:dirty_read(?SESSION_ITER_TAB, DSIterId),
{DSIterId, export_record(Record, #ds_iter.id, [id, iter], #{})}
list_all_pubranges() ->
DSPubranges = mnesia:dirty_match_object(?SESSION_PUBRANGE_TAB, #ds_pubrange{_ = '_'}),
lists:foldl(
fun(Record = #ds_pubrange{id = {SessionId, First}}, Acc) ->
Range = export_record(
Record, #ds_pubrange.until, [until, stream, type, iterator], #{first => First}
),
maps:put(SessionId, maps:get(SessionId, Acc, []) ++ [Range], Acc)
end,
DSIterIds
),
maps:from_list(DSIters).
#{},
DSPubranges
).
%% ifdef(TEST)
-endif.

View File

@ -21,7 +21,7 @@
-define(SESSION_TAB, emqx_ds_session).
-define(SESSION_SUBSCRIPTIONS_TAB, emqx_ds_session_subscriptions).
-define(SESSION_STREAM_TAB, emqx_ds_stream_tab).
-define(SESSION_ITER_TAB, emqx_ds_iter_tab).
-define(SESSION_PUBRANGE_TAB, emqx_ds_pubrange_tab).
-define(DS_MRIA_SHARD, emqx_ds_session_shard).
-record(ds_sub, {
@ -34,17 +34,24 @@
-record(ds_stream, {
session :: emqx_persistent_session_ds:id(),
topic_filter :: emqx_ds:topic_filter(),
ref :: _StreamRef,
stream :: emqx_ds:stream(),
rank :: emqx_ds:stream_rank()
rank :: emqx_ds:stream_rank(),
beginning :: emqx_ds:iterator()
}).
-type ds_stream() :: #ds_stream{}.
-type ds_stream_bin() :: binary().
-record(ds_iter, {
id :: {emqx_persistent_session_ds:id(), ds_stream_bin()},
iter :: emqx_ds:iterator()
%% Persistent record of a range of sequence numbers published to a session
%% out of a single stream.
-record(ds_pubrange, {
%% Key: session id plus the first sequence number of the range.
id :: {
_Session :: emqx_persistent_session_ds:id(),
_First :: emqx_persistent_message_ds_replayer:seqno()
},
%% Sequence number right past the last one in the range (exclusive bound).
until :: emqx_persistent_message_ds_replayer:seqno(),
%% Reference to the stream (`#ds_stream.ref`) the messages were read from.
stream :: _StreamRef,
%% * `inflight`: the range still contains unacked messages.
%% * `checkpoint`: the range is fully acked and is kept only for its iterator.
type :: inflight | checkpoint,
%% * For an `inflight` range: iterator pointing to the first message in the range.
%% * For a `checkpoint` range: iterator pointing right past the last message.
iterator :: emqx_ds:iterator()
}).
-type ds_pubrange() :: #ds_pubrange{}.
-record(session, {
%% same as clientid
@ -52,7 +59,7 @@
%% creation time
created_at :: _Millisecond :: non_neg_integer(),
expires_at = never :: _Millisecond :: non_neg_integer() | never,
inflight :: emqx_persistent_message_ds_replayer:inflight(),
% last_ack = 0 :: emqx_persistent_message_ds_replayer:seqno(),
%% for future usage
props = #{} :: map()
}).