diff --git a/apps/emqx/integration_test/emqx_persistent_session_ds_SUITE.erl b/apps/emqx/integration_test/emqx_persistent_session_ds_SUITE.erl index 72775228c..6c5fdc56e 100644 --- a/apps/emqx/integration_test/emqx_persistent_session_ds_SUITE.erl +++ b/apps/emqx/integration_test/emqx_persistent_session_ds_SUITE.erl @@ -357,18 +357,12 @@ do_t_session_discard(Params) -> _Attempts0 = 50, true = map_size(emqx_persistent_session_ds:list_all_streams()) > 0 ), - ?retry( - _Sleep0 = 100, - _Attempts0 = 50, - true = map_size(emqx_persistent_session_ds:list_all_iterators()) > 0 - ), ok = emqtt:stop(Client0), ?tp(notice, "disconnected", #{}), ?tp(notice, "reconnecting", #{}), - %% we still have iterators and streams + %% we still have streams ?assert(map_size(emqx_persistent_session_ds:list_all_streams()) > 0), - ?assert(map_size(emqx_persistent_session_ds:list_all_iterators()) > 0), Client1 = start_client(ReconnectOpts), {ok, _} = emqtt:connect(Client1), ?assertEqual([], emqtt:subscriptions(Client1)), @@ -381,7 +375,7 @@ do_t_session_discard(Params) -> ?assertEqual(#{}, emqx_persistent_session_ds:list_all_subscriptions()), ?assertEqual([], emqx_persistent_session_ds_router:topics()), ?assertEqual(#{}, emqx_persistent_session_ds:list_all_streams()), - ?assertEqual(#{}, emqx_persistent_session_ds:list_all_iterators()), + ?assertEqual(#{}, emqx_persistent_session_ds:list_all_pubranges()), ok = emqtt:stop(Client1), ?tp(notice, "disconnected", #{}), diff --git a/apps/emqx/src/emqx_persistent_message_ds_replayer.erl b/apps/emqx/src/emqx_persistent_message_ds_replayer.erl index 69b6675d8..a95e1c152 100644 --- a/apps/emqx/src/emqx_persistent_message_ds_replayer.erl +++ b/apps/emqx/src/emqx_persistent_message_ds_replayer.erl @@ -19,12 +19,12 @@ -module(emqx_persistent_message_ds_replayer). %% API: --export([new/0, next_packet_id/1, replay/2, commit_offset/3, poll/3, n_inflight/1]). +-export([new/0, open/1, next_packet_id/1, replay/1, commit_offset/3, poll/3, n_inflight/1]). %% internal exports: -export([]). --export_type([inflight/0]). +-export_type([inflight/0, seqno/0]). -include_lib("emqx/include/logger.hrl"). -include("emqx_persistent_session_ds.hrl"). @@ -42,17 +42,28 @@ -type seqno() :: non_neg_integer(). -record(range, { - stream :: emqx_ds:stream(), + stream :: _StreamRef, first :: seqno(), - last :: seqno(), - iterator_next :: emqx_ds:iterator() | undefined + until :: seqno(), + %% Type of a range: + %% * Inflight range is a range of yet unacked messages from this stream. + %% * Checkpoint range was already acked, its purpose is to keep track of the + %% very last iterator for this stream. + type :: inflight | checkpoint, + %% Meaning of this depends on the type of the range: + %% * For inflight range, this is the iterator pointing to the first message in + %% the range. + %% * For checkpoint range, this is the iterator pointing right past the last + %% message in the range. + iterator :: emqx_ds:iterator() }). -type range() :: #range{}. -record(inflight, { - next_seqno = 0 :: seqno(), - acked_seqno = 0 :: seqno(), + next_seqno = 1 :: seqno(), + acked_until = 1 :: seqno(), + %% Ranges are sorted in ascending order of their sequence numbers. offset_ranges = [] :: [range()] }). @@ -66,34 +77,37 @@ new() -> #inflight{}. +-spec open(emqx_persistent_session_ds:id()) -> inflight(). +open(SessionId) -> + Ranges = ro_transaction(fun() -> get_ranges(SessionId) end), + {AckedUntil, NextSeqno} = compute_inflight_range(Ranges), + #inflight{ + acked_until = AckedUntil, + next_seqno = NextSeqno, + offset_ranges = Ranges + }. + -spec next_packet_id(inflight()) -> {emqx_types:packet_id(), inflight()}. -next_packet_id(Inflight0 = #inflight{next_seqno = LastSeqNo}) -> - Inflight = Inflight0#inflight{next_seqno = LastSeqNo + 1}, - case LastSeqNo rem 16#10000 of - 0 -> - %% We skip sequence numbers that lead to PacketId = 0 to - %% simplify math. Note: it leads to occasional gaps in the - %% sequence numbers. - next_packet_id(Inflight); - PacketId -> - {PacketId, Inflight} - end. +next_packet_id(Inflight0 = #inflight{next_seqno = LastSeqno}) -> + Inflight = Inflight0#inflight{next_seqno = next_seqno(LastSeqno)}, + {seqno_to_packet_id(LastSeqno), Inflight}. -spec n_inflight(inflight()) -> non_neg_integer(). -n_inflight(#inflight{next_seqno = NextSeqNo, acked_seqno = AckedSeqno}) -> - %% NOTE: this function assumes that gaps in the sequence ID occur - %% _only_ when the packet ID wraps: - case AckedSeqno >= ((NextSeqNo bsr 16) bsl 16) of - true -> - NextSeqNo - AckedSeqno; - false -> - NextSeqNo - AckedSeqno - 1 - end. +n_inflight(#inflight{next_seqno = NextSeqno, acked_until = AckedUntil}) -> + range_size(AckedUntil, NextSeqno). --spec replay(emqx_persistent_session_ds:id(), inflight()) -> - emqx_session:replies(). -replay(_SessionId, _Inflight = #inflight{offset_ranges = _Ranges}) -> - []. +-spec replay(inflight()) -> + {emqx_session:replies(), inflight()}. +replay(Inflight0 = #inflight{acked_until = AckedUntil, offset_ranges = Ranges0}) -> + {Ranges, Replies} = lists:mapfoldr( + fun(Range, Acc) -> + replay_range(Range, AckedUntil, Acc) + end, + [], + Ranges0 + ), + Inflight = Inflight0#inflight{offset_ranges = Ranges}, + {Replies, Inflight}. -spec commit_offset(emqx_persistent_session_ds:id(), emqx_types:packet_id(), inflight()) -> {_IsValidOffset :: boolean(), inflight()}. @@ -101,47 +115,34 @@ commit_offset( SessionId, PacketId, Inflight0 = #inflight{ - acked_seqno = AckedSeqno0, next_seqno = NextSeqNo, offset_ranges = Ranges0 + acked_until = AckedUntil, next_seqno = NextSeqno } ) -> - AckedSeqno = - case packet_id_to_seqno(NextSeqNo, PacketId) of - N when N > AckedSeqno0; AckedSeqno0 =:= 0 -> - N; - OutOfRange -> - ?SLOG(warning, #{ - msg => "out-of-order_ack", - prev_seqno => AckedSeqno0, - acked_seqno => OutOfRange, - next_seqno => NextSeqNo, - packet_id => PacketId - }), - AckedSeqno0 - end, - Ranges = lists:filter( - fun(#range{stream = Stream, last = LastSeqno, iterator_next = ItNext}) -> - case LastSeqno =< AckedSeqno of - true -> - %% This range has been fully - %% acked. Remove it and replace saved - %% iterator with the trailing iterator. - update_iterator(SessionId, Stream, ItNext), - false; - false -> - %% This range still has unacked - %% messages: - true - end - end, - Ranges0 - ), - Inflight = Inflight0#inflight{acked_seqno = AckedSeqno, offset_ranges = Ranges}, - {true, Inflight}. + case packet_id_to_seqno(NextSeqno, PacketId) of + Seqno when Seqno >= AckedUntil andalso Seqno < NextSeqno -> + %% TODO + %% We do not preserve `acked_until` in the database. Instead, we discard + %% fully acked ranges from the database. In effect, this means that the + %% most recent `acked_until` the client has sent may be lost in case of a + %% crash or client loss. + Inflight1 = Inflight0#inflight{acked_until = next_seqno(Seqno)}, + Inflight = discard_acked(SessionId, Inflight1), + {true, Inflight}; + OutOfRange -> + ?SLOG(warning, #{ + msg => "out-of-order_ack", + acked_until => AckedUntil, + acked_seqno => OutOfRange, + next_seqno => NextSeqno, + packet_id => PacketId + }), + {false, Inflight0} + end. -spec poll(emqx_persistent_session_ds:id(), inflight(), pos_integer()) -> {emqx_session:replies(), inflight()}. poll(SessionId, Inflight0, WindowSize) when WindowSize > 0, WindowSize < 16#7fff -> - #inflight{next_seqno = NextSeqNo0, acked_seqno = AckedSeqno} = + #inflight{next_seqno = NextSeqNo0, acked_until = AckedSeqno} = Inflight0, FetchThreshold = max(1, WindowSize div 2), FreeSpace = AckedSeqno + WindowSize - NextSeqNo0, @@ -153,6 +154,7 @@ poll(SessionId, Inflight0, WindowSize) when WindowSize > 0, WindowSize < 16#7fff %% client get stuck even? {[], Inflight0}; true -> + %% TODO: Wrap this in `mria:async_dirty/2`? Streams = shuffle(get_streams(SessionId)), fetch(SessionId, Inflight0, Streams, FreeSpace, []) end. @@ -165,75 +167,206 @@ poll(SessionId, Inflight0, WindowSize) when WindowSize > 0, WindowSize < 16#7fff %% Internal functions %%================================================================================ -fetch(_SessionId, Inflight, _Streams = [], _N, Acc) -> - {lists:reverse(Acc), Inflight}; -fetch(_SessionId, Inflight, _Streams, 0, Acc) -> - {lists:reverse(Acc), Inflight}; -fetch(SessionId, Inflight0, [Stream | Streams], N, Publishes0) -> - #inflight{next_seqno = FirstSeqNo, offset_ranges = Ranges0} = Inflight0, - ItBegin = get_last_iterator(SessionId, Stream, Ranges0), +compute_inflight_range([]) -> + {1, 1}; +compute_inflight_range(Ranges) -> + _RangeLast = #range{until = LastSeqno} = lists:last(Ranges), + RangesUnacked = lists:dropwhile(fun(#range{type = T}) -> T == checkpoint end, Ranges), + case RangesUnacked of + [#range{first = AckedUntil} | _] -> + {AckedUntil, LastSeqno}; + [] -> + {LastSeqno, LastSeqno} + end. + +get_ranges(SessionId) -> + DSRanges = mnesia:match_object( + ?SESSION_PUBRANGE_TAB, + #ds_pubrange{id = {SessionId, '_'}, _ = '_'}, + read + ), + lists:map(fun export_range/1, DSRanges). + +export_range(#ds_pubrange{ + type = Type, id = {_, First}, until = Until, stream = StreamRef, iterator = It +}) -> + #range{type = Type, stream = StreamRef, first = First, until = Until, iterator = It}. + +fetch(SessionId, Inflight0, [DSStream | Streams], N, Acc) when N > 0 -> + #inflight{next_seqno = FirstSeqno, offset_ranges = Ranges0} = Inflight0, + ItBegin = get_last_iterator(DSStream, Ranges0), {ok, ItEnd, Messages} = emqx_ds:next(?PERSISTENT_MESSAGE_DB, ItBegin, N), - {NMessages, Publishes, Inflight1} = - lists:foldl( - fun(Msg, {N0, PubAcc0, InflightAcc0}) -> - {PacketId, InflightAcc} = next_packet_id(InflightAcc0), - PubAcc = [{PacketId, Msg} | PubAcc0], - {N0 + 1, PubAcc, InflightAcc} - end, - {0, Publishes0, Inflight0}, - Messages - ), - #inflight{next_seqno = LastSeqNo} = Inflight1, - case NMessages > 0 of - true -> - Range = #range{ - first = FirstSeqNo, - last = LastSeqNo - 1, - stream = Stream, - iterator_next = ItEnd + {Publishes, UntilSeqno} = publish(FirstSeqno, Messages), + case range_size(FirstSeqno, UntilSeqno) of + Size when Size > 0 -> + Range0 = #range{ + type = inflight, + first = FirstSeqno, + until = UntilSeqno, + stream = DSStream#ds_stream.ref, + iterator = ItBegin }, - Inflight = Inflight1#inflight{offset_ranges = Ranges0 ++ [Range]}, - fetch(SessionId, Inflight, Streams, N - NMessages, Publishes); - false -> - fetch(SessionId, Inflight1, Streams, N, Publishes) - end. + %% We need to preserve the iterator pointing to the beginning of the + %% range, so that we can replay it if needed. + ok = preserve_range(SessionId, Range0), + %% ...Yet we need to keep the iterator pointing past the end of the + %% range, so that we can pick up where we left off: it will become + %% `ItBegin` of the next range for this stream. + Range = Range0#range{iterator = ItEnd}, + Ranges = Ranges0 ++ [Range#range{iterator = ItEnd}], + Inflight = Inflight0#inflight{ + next_seqno = UntilSeqno, + offset_ranges = Ranges + }, + fetch(SessionId, Inflight, Streams, N - Size, [Publishes | Acc]); + 0 -> + fetch(SessionId, Inflight0, Streams, N, Acc) + end; +fetch(_SessionId, Inflight, _Streams, _N, Acc) -> + Publishes = lists:append(lists:reverse(Acc)), + {Publishes, Inflight}. --spec update_iterator(emqx_persistent_session_ds:id(), emqx_ds:stream(), emqx_ds:iterator()) -> ok. -update_iterator(DSSessionId, Stream, Iterator) -> - %% Workaround: we convert `Stream' to a binary before attempting to store it in - %% mnesia(rocksdb) because of a bug in `mnesia_rocksdb' when trying to do - %% `mnesia:dirty_all_keys' later. - StreamBin = term_to_binary(Stream), - mria:dirty_write(?SESSION_ITER_TAB, #ds_iter{id = {DSSessionId, StreamBin}, iter = Iterator}). +discard_acked( + SessionId, + Inflight0 = #inflight{acked_until = AckedUntil, offset_ranges = Ranges0} +) -> + %% TODO: This could be kept and incrementally updated in the inflight state. + Checkpoints = find_checkpoints(Ranges0), + %% TODO: Wrap this in `mria:async_dirty/2`? + Ranges = discard_acked_ranges(SessionId, AckedUntil, Checkpoints, Ranges0), + Inflight0#inflight{offset_ranges = Ranges}. -get_last_iterator(SessionId, Stream, Ranges) -> - case lists:keyfind(Stream, #range.stream, lists:reverse(Ranges)) of - false -> - get_iterator(SessionId, Stream); - #range{iterator_next = Next} -> - Next - end. - --spec get_iterator(emqx_persistent_session_ds:id(), emqx_ds:stream()) -> emqx_ds:iterator(). -get_iterator(DSSessionId, Stream) -> - %% See comment in `update_iterator'. - StreamBin = term_to_binary(Stream), - Id = {DSSessionId, StreamBin}, - [#ds_iter{iter = It}] = mnesia:dirty_read(?SESSION_ITER_TAB, Id), - It. - --spec get_streams(emqx_persistent_session_ds:id()) -> [emqx_ds:stream()]. -get_streams(SessionId) -> - lists:map( - fun(#ds_stream{stream = Stream}) -> - Stream +find_checkpoints(Ranges) -> + lists:foldl( + fun(#range{stream = StreamRef, until = Until}, Acc) -> + %% For each stream, remember the last range over this stream. + Acc#{StreamRef => Until} end, - mnesia:dirty_read(?SESSION_STREAM_TAB, SessionId) + #{}, + Ranges ). +discard_acked_ranges( + SessionId, + AckedUntil, + Checkpoints, + [Range = #range{until = Until, stream = StreamRef} | Rest] +) when Until =< AckedUntil -> + %% This range has been fully acked. + %% Either discard it completely, or preserve the iterator for the next range + %% over this stream (i.e. a checkpoint). + RangeKept = + case maps:get(StreamRef, Checkpoints) of + CP when CP > Until -> + discard_range(SessionId, Range), + []; + Until -> + checkpoint_range(SessionId, Range), + [Range#range{type = checkpoint}] + end, + %% Since we're (intentionally) not using transactions here, it's important to + %% issue database writes in the same order in which ranges are stored: from + %% the oldest to the newest. This is also why we need to compute which ranges + %% should become checkpoints before we start writing anything. + RangeKept ++ discard_acked_ranges(SessionId, AckedUntil, Checkpoints, Rest); +discard_acked_ranges(_SessionId, _AckedUntil, _Checkpoints, Ranges) -> + %% The rest of ranges (if any) still have unacked messages. + Ranges. + +replay_range( + Range0 = #range{type = inflight, first = First, until = Until, iterator = It}, + AckedUntil, + Acc +) -> + Size = range_size(First, Until), + FirstUnacked = max(First, AckedUntil), + {ok, ItNext, Messages} = emqx_ds:next(?PERSISTENT_MESSAGE_DB, It, Size), + MessagesUnacked = + case FirstUnacked of + First -> + Messages; + _ -> + lists:nthtail(range_size(First, FirstUnacked), Messages) + end, + %% Asserting that range is consistent with the message storage state. + {Replies, Until} = publish(FirstUnacked, MessagesUnacked), + Range = Range0#range{iterator = ItNext}, + {Range, Replies ++ Acc}; +replay_range(Range0 = #range{type = checkpoint}, _AckedUntil, Acc) -> + {Range0, Acc}. + +publish(FirstSeqno, Messages) -> + lists:mapfoldl( + fun(Message, Seqno) -> + PacketId = seqno_to_packet_id(Seqno), + {{PacketId, Message}, next_seqno(Seqno)} + end, + FirstSeqno, + Messages + ). + +-spec preserve_range(emqx_persistent_session_ds:id(), range()) -> ok. +preserve_range( + SessionId, + #range{first = First, until = Until, stream = StreamRef, iterator = It} +) -> + DSRange = #ds_pubrange{ + id = {SessionId, First}, + until = Until, + stream = StreamRef, + type = inflight, + iterator = It + }, + mria:dirty_write(?SESSION_PUBRANGE_TAB, DSRange). + +-spec discard_range(emqx_persistent_session_ds:id(), range()) -> ok. +discard_range(SessionId, #range{first = First}) -> + mria:dirty_delete(?SESSION_PUBRANGE_TAB, {SessionId, First}). + +-spec checkpoint_range(emqx_persistent_session_ds:id(), range()) -> ok. +checkpoint_range( + SessionId, + #range{type = inflight, first = First, until = Until, stream = StreamRef, iterator = ItNext} +) -> + DSRange = #ds_pubrange{ + id = {SessionId, First}, + until = Until, + stream = StreamRef, + type = checkpoint, + iterator = ItNext + }, + mria:dirty_write(?SESSION_PUBRANGE_TAB, DSRange); +checkpoint_range(_SessionId, #range{type = checkpoint}) -> + %% This range should have been checkpointed already. + ok. + +get_last_iterator(DSStream = #ds_stream{ref = StreamRef}, Ranges) -> + case lists:keyfind(StreamRef, #range.stream, lists:reverse(Ranges)) of + false -> + DSStream#ds_stream.beginning; + #range{iterator = ItNext} -> + ItNext + end. + +-spec get_streams(emqx_persistent_session_ds:id()) -> [ds_stream()]. +get_streams(SessionId) -> + mnesia:dirty_read(?SESSION_STREAM_TAB, SessionId). + +next_seqno(Seqno) -> + NextSeqno = Seqno + 1, + case seqno_to_packet_id(NextSeqno) of + 0 -> + %% We skip sequence numbers that lead to PacketId = 0 to + %% simplify math. Note: it leads to occasional gaps in the + %% sequence numbers. + NextSeqno + 1; + _ -> + NextSeqno + end. + %% Reconstruct session counter by adding most significant bits from %% the current counter to the packet id. --spec packet_id_to_seqno(non_neg_integer(), emqx_types:packet_id()) -> non_neg_integer(). +-spec packet_id_to_seqno(_Next :: seqno(), emqx_types:packet_id()) -> seqno(). packet_id_to_seqno(NextSeqNo, PacketId) -> Epoch = NextSeqNo bsr 16, case packet_id_to_seqno_(Epoch, PacketId) of @@ -243,10 +376,20 @@ packet_id_to_seqno(NextSeqNo, PacketId) -> packet_id_to_seqno_(Epoch - 1, PacketId) end. --spec packet_id_to_seqno_(non_neg_integer(), emqx_types:packet_id()) -> non_neg_integer(). +-spec packet_id_to_seqno_(non_neg_integer(), emqx_types:packet_id()) -> seqno(). packet_id_to_seqno_(Epoch, PacketId) -> (Epoch bsl 16) + PacketId. +-spec seqno_to_packet_id(seqno()) -> emqx_types:packet_id(). +seqno_to_packet_id(Seqno) -> + Seqno rem 16#10000. + +range_size(FirstSeqno, UntilSeqno) -> + %% This function assumes that gaps in the sequence ID occur _only_ when the + %% packet ID wraps. + Size = UntilSeqno - FirstSeqno, + Size + (FirstSeqno bsr 16) - (UntilSeqno bsr 16). + -spec shuffle([A]) -> [A]. shuffle(L0) -> L1 = lists:map( @@ -259,6 +402,10 @@ shuffle(L0) -> {_, L} = lists:unzip(L2), L. +ro_transaction(Fun) -> + {atomic, Res} = mria:ro_transaction(?DS_MRIA_SHARD, Fun), + Res. + -ifdef(TEST). %% This test only tests boundary conditions (to make sure property-based test didn't skip them): @@ -311,4 +458,40 @@ seqno_gen(NextSeqNo) -> Max = max(0, NextSeqNo - 1), range(Min, Max). +range_size_test_() -> + [ + ?_assertEqual(0, range_size(42, 42)), + ?_assertEqual(1, range_size(42, 43)), + ?_assertEqual(1, range_size(16#ffff, 16#10001)), + ?_assertEqual(16#ffff - 456 + 123, range_size(16#1f0000 + 456, 16#200000 + 123)) + ]. + +compute_inflight_range_test_() -> + [ + ?_assertEqual( + {1, 1}, + compute_inflight_range([]) + ), + ?_assertEqual( + {12, 42}, + compute_inflight_range([ + #range{first = 1, until = 2, type = checkpoint}, + #range{first = 4, until = 8, type = checkpoint}, + #range{first = 11, until = 12, type = checkpoint}, + #range{first = 12, until = 13, type = inflight}, + #range{first = 13, until = 20, type = inflight}, + #range{first = 20, until = 42, type = inflight} + ]) + ), + ?_assertEqual( + {13, 13}, + compute_inflight_range([ + #range{first = 1, until = 2, type = checkpoint}, + #range{first = 4, until = 8, type = checkpoint}, + #range{first = 11, until = 12, type = checkpoint}, + #range{first = 12, until = 13, type = checkpoint} + ]) + ) + ]. + -endif. diff --git a/apps/emqx/src/emqx_persistent_session_ds.erl b/apps/emqx/src/emqx_persistent_session_ds.erl index 3a7232747..7ba5aa527 100644 --- a/apps/emqx/src/emqx_persistent_session_ds.erl +++ b/apps/emqx/src/emqx_persistent_session_ds.erl @@ -76,7 +76,7 @@ list_all_sessions/0, list_all_subscriptions/0, list_all_streams/0, - list_all_iterators/0 + list_all_pubranges/0 ]). -endif. @@ -359,15 +359,16 @@ handle_timeout( end, ensure_timer(pull, Timeout), {ok, Publishes, Session#{inflight => Inflight}}; -handle_timeout(_ClientInfo, get_streams, Session = #{id := Id}) -> - renew_streams(Id), +handle_timeout(_ClientInfo, get_streams, Session) -> + renew_streams(Session), ensure_timer(get_streams), {ok, [], Session}. -spec replay(clientinfo(), [], session()) -> {ok, replies(), session()}. -replay(_ClientInfo, [], Session = #{}) -> - {ok, [], Session}. +replay(_ClientInfo, [], Session = #{inflight := Inflight0}) -> + {Replies, Inflight} = emqx_persistent_message_ds_replayer:replay(Inflight0), + {ok, Replies, Session#{inflight := Inflight}}. %%-------------------------------------------------------------------- @@ -474,17 +475,20 @@ create_tables() -> ] ), ok = mria:create_table( - ?SESSION_ITER_TAB, + ?SESSION_PUBRANGE_TAB, [ {rlog_shard, ?DS_MRIA_SHARD}, - {type, set}, + {type, ordered_set}, {storage, storage()}, - {record_name, ds_iter}, - {attributes, record_info(fields, ds_iter)} + {record_name, ds_pubrange}, + {attributes, record_info(fields, ds_pubrange)} ] ), ok = mria:wait_for_tables([ - ?SESSION_TAB, ?SESSION_SUBSCRIPTIONS_TAB, ?SESSION_STREAM_TAB, ?SESSION_ITER_TAB + ?SESSION_TAB, + ?SESSION_SUBSCRIPTIONS_TAB, + ?SESSION_STREAM_TAB, + ?SESSION_PUBRANGE_TAB ]), ok. @@ -512,9 +516,10 @@ session_open(SessionId) -> Session = export_session(Record), DSSubs = session_read_subscriptions(SessionId), Subscriptions = export_subscriptions(DSSubs), + Inflight = emqx_persistent_message_ds_replayer:open(SessionId), Session#{ subscriptions => Subscriptions, - inflight => emqx_persistent_message_ds_replayer:new() + inflight => Inflight }; [] -> false @@ -549,7 +554,7 @@ session_create(SessionId, Props) -> session_drop(DSSessionId) -> transaction(fun() -> ok = session_drop_subscriptions(DSSessionId), - ok = session_drop_iterators(DSSessionId), + ok = session_drop_pubranges(DSSessionId), ok = session_drop_streams(DSSessionId), ok = mnesia:delete(?SESSION_TAB, DSSessionId, write) end). @@ -663,77 +668,82 @@ do_ensure_all_iterators_closed(_DSSessionID) -> %% Reading batches %%-------------------------------------------------------------------- --spec renew_streams(id()) -> ok. -renew_streams(DSSessionId) -> - Subscriptions = ro_transaction(fun() -> session_read_subscriptions(DSSessionId) end), - ExistingStreams = ro_transaction(fun() -> mnesia:read(?SESSION_STREAM_TAB, DSSessionId) end), - lists:foreach( - fun(#ds_sub{id = {_, TopicFilter}, start_time = StartTime}) -> - renew_streams(DSSessionId, ExistingStreams, TopicFilter, StartTime) +-spec renew_streams(session()) -> ok. +renew_streams(#{id := SessionId, subscriptions := Subscriptions}) -> + transaction(fun() -> + ExistingStreams = mnesia:read(?SESSION_STREAM_TAB, SessionId, write), + maps:fold( + fun(TopicFilter, #{start_time := StartTime}, Streams) -> + TopicFilterWords = emqx_topic:words(TopicFilter), + renew_topic_streams(SessionId, TopicFilterWords, StartTime, Streams) + end, + ExistingStreams, + Subscriptions + ) + end), + ok. + +-spec renew_topic_streams(id(), topic_filter_words(), emqx_ds:time(), _Acc :: [ds_stream()]) -> ok. +renew_topic_streams(DSSessionId, TopicFilter, StartTime, ExistingStreams) -> + TopicStreams = emqx_ds:get_streams(?PERSISTENT_MESSAGE_DB, TopicFilter, StartTime), + lists:foldl( + fun({Rank, Stream}, Streams) -> + case lists:keymember(Stream, #ds_stream.stream, Streams) of + true -> + Streams; + false -> + StreamRef = length(Streams) + 1, + DSStream = session_store_stream( + DSSessionId, + StreamRef, + Stream, + Rank, + TopicFilter, + StartTime + ), + [DSStream | Streams] + end end, - Subscriptions + ExistingStreams, + TopicStreams ). --spec renew_streams(id(), [ds_stream()], topic_filter_words(), emqx_ds:time()) -> ok. -renew_streams(DSSessionId, ExistingStreams, TopicFilter, StartTime) -> - AllStreams = emqx_ds:get_streams(?PERSISTENT_MESSAGE_DB, TopicFilter, StartTime), - transaction( - fun() -> - lists:foreach( - fun({Rank, Stream}) -> - Rec = #ds_stream{ - session = DSSessionId, - topic_filter = TopicFilter, - stream = Stream, - rank = Rank - }, - case lists:member(Rec, ExistingStreams) of - true -> - ok; - false -> - mnesia:write(?SESSION_STREAM_TAB, Rec, write), - {ok, Iterator} = emqx_ds:make_iterator( - ?PERSISTENT_MESSAGE_DB, Stream, TopicFilter, StartTime - ), - %% Workaround: we convert `Stream' to a binary before - %% attempting to store it in mnesia(rocksdb) because of a bug - %% in `mnesia_rocksdb' when trying to do - %% `mnesia:dirty_all_keys' later. - StreamBin = term_to_binary(Stream), - IterRec = #ds_iter{id = {DSSessionId, StreamBin}, iter = Iterator}, - mnesia:write(?SESSION_ITER_TAB, IterRec, write) - end - end, - AllStreams - ) - end - ). +session_store_stream(DSSessionId, StreamRef, Stream, Rank, TopicFilter, StartTime) -> + {ok, ItBegin} = emqx_ds:make_iterator( + ?PERSISTENT_MESSAGE_DB, + Stream, + TopicFilter, + StartTime + ), + DSStream = #ds_stream{ + session = DSSessionId, + ref = StreamRef, + stream = Stream, + rank = Rank, + beginning = ItBegin + }, + mnesia:write(?SESSION_STREAM_TAB, DSStream, write), + DSStream. %% must be called inside a transaction -spec session_drop_streams(id()) -> ok. session_drop_streams(DSSessionId) -> - MS = ets:fun2ms( - fun(#ds_stream{session = DSSessionId0}) when DSSessionId0 =:= DSSessionId -> - DSSessionId0 - end - ), - StreamIDs = mnesia:select(?SESSION_STREAM_TAB, MS, write), - lists:foreach(fun(Key) -> mnesia:delete(?SESSION_STREAM_TAB, Key, write) end, StreamIDs). + mnesia:delete(?SESSION_STREAM_TAB, DSSessionId, write). %% must be called inside a transaction --spec session_drop_iterators(id()) -> ok. -session_drop_iterators(DSSessionId) -> +-spec session_drop_pubranges(id()) -> ok. +session_drop_pubranges(DSSessionId) -> MS = ets:fun2ms( - fun(#ds_iter{id = {DSSessionId0, StreamBin}}) when DSSessionId0 =:= DSSessionId -> - StreamBin + fun(#ds_pubrange{id = {DSSessionId0, First}}) when DSSessionId0 =:= DSSessionId -> + {DSSessionId, First} end ), - StreamBins = mnesia:select(?SESSION_ITER_TAB, MS, write), + RangeIds = mnesia:select(?SESSION_PUBRANGE_TAB, MS, write), lists:foreach( - fun(StreamBin) -> - mnesia:delete(?SESSION_ITER_TAB, {DSSessionId, StreamBin}, write) + fun(RangeId) -> + mnesia:delete(?SESSION_PUBRANGE_TAB, RangeId, write) end, - StreamBins + RangeIds ). %%-------------------------------------------------------------------------------- @@ -758,7 +768,7 @@ export_subscriptions(DSSubs) -> ). export_session(#session{} = Record) -> - export_record(Record, #session.id, [id, created_at, expires_at, inflight, props], #{}). + export_record(Record, #session.id, [id, created_at, expires_at, props], #{}). export_subscription(#ds_sub{} = Record) -> export_record(Record, #ds_sub.start_time, [start_time, props, extra], #{}). @@ -833,16 +843,18 @@ list_all_streams() -> ), maps:from_list(DSStreams). -list_all_iterators() -> - DSIterIds = mnesia:dirty_all_keys(?SESSION_ITER_TAB), - DSIters = lists:map( - fun(DSIterId) -> - [Record] = mnesia:dirty_read(?SESSION_ITER_TAB, DSIterId), - {DSIterId, export_record(Record, #ds_iter.id, [id, iter], #{})} +list_all_pubranges() -> + DSPubranges = mnesia:dirty_match_object(?SESSION_PUBRANGE_TAB, #ds_pubrange{_ = '_'}), + lists:foldl( + fun(Record = #ds_pubrange{id = {SessionId, First}}, Acc) -> + Range = export_record( + Record, #ds_pubrange.until, [until, stream, type, iterator], #{first => First} + ), + maps:put(SessionId, maps:get(SessionId, Acc, []) ++ [Range], Acc) end, - DSIterIds - ), - maps:from_list(DSIters). + #{}, + DSPubranges + ). %% ifdef(TEST) -endif. diff --git a/apps/emqx/src/emqx_persistent_session_ds.hrl b/apps/emqx/src/emqx_persistent_session_ds.hrl index cc995ce66..a3ea5a662 100644 --- a/apps/emqx/src/emqx_persistent_session_ds.hrl +++ b/apps/emqx/src/emqx_persistent_session_ds.hrl @@ -21,7 +21,7 @@ -define(SESSION_TAB, emqx_ds_session). -define(SESSION_SUBSCRIPTIONS_TAB, emqx_ds_session_subscriptions). -define(SESSION_STREAM_TAB, emqx_ds_stream_tab). --define(SESSION_ITER_TAB, emqx_ds_iter_tab). +-define(SESSION_PUBRANGE_TAB, emqx_ds_pubrange_tab). -define(DS_MRIA_SHARD, emqx_ds_session_shard). -record(ds_sub, { @@ -34,17 +34,24 @@ -record(ds_stream, { session :: emqx_persistent_session_ds:id(), - topic_filter :: emqx_ds:topic_filter(), + ref :: _StreamRef, stream :: emqx_ds:stream(), - rank :: emqx_ds:stream_rank() + rank :: emqx_ds:stream_rank(), + beginning :: emqx_ds:iterator() }). -type ds_stream() :: #ds_stream{}. --type ds_stream_bin() :: binary(). --record(ds_iter, { - id :: {emqx_persistent_session_ds:id(), ds_stream_bin()}, - iter :: emqx_ds:iterator() +-record(ds_pubrange, { + id :: { + _Session :: emqx_persistent_session_ds:id(), + _First :: emqx_persistent_message_ds_replayer:seqno() + }, + until :: emqx_persistent_message_ds_replayer:seqno(), + stream :: _StreamRef, + type :: inflight | checkpoint, + iterator :: emqx_ds:iterator() }). +-type ds_pubrange() :: #ds_pubrange{}. -record(session, { %% same as clientid @@ -52,7 +59,7 @@ %% creation time created_at :: _Millisecond :: non_neg_integer(), expires_at = never :: _Millisecond :: non_neg_integer() | never, - inflight :: emqx_persistent_message_ds_replayer:inflight(), + % last_ack = 0 :: emqx_persistent_message_ds_replayer:seqno(), %% for future usage props = #{} :: map() }).