feat(sessds): provide QoS2 message replay support

This commit is contained in:
Andrew Mayorov 2023-11-26 22:34:41 +03:00
parent ce59cb71bb
commit 46475fac66
No known key found for this signature in database
GPG Key ID: 2837C62ACFBFED5D
5 changed files with 682 additions and 230 deletions

View File

@ -19,7 +19,13 @@
-module(emqx_persistent_message_ds_replayer).
%% API:
-export([new/0, open/1, next_packet_id/1, replay/1, commit_offset/3, poll/3, n_inflight/1]).
-export([new/0, open/1, next_packet_id/1, n_inflight/1]).
-export([poll/4, replay/2, commit_offset/4, commit_marker/4]).
-export([seqno_to_packet_id/1, packet_id_to_seqno/2]).
-export([committed_until/2]).
%% internal exports:
-export([]).
@ -27,7 +33,6 @@
-export_type([inflight/0, seqno/0]).
-include_lib("emqx/include/logger.hrl").
-include_lib("emqx_utils/include/emqx_message.hrl").
-include("emqx_persistent_session_ds.hrl").
-ifdef(TEST).
@ -35,6 +40,13 @@
-include_lib("eunit/include/eunit.hrl").
-endif.
-define(EPOCH_SIZE, 16#10000).
-define(ACK, 0).
-define(COMP, 1).
-define(TRACK_FLAG(WHICH), (1 bsl WHICH)).
%%================================================================================
%% Type declarations
%%================================================================================
@ -42,15 +54,20 @@
%% Note: sequence numbers are monotonic; they don't wrap around:
-type seqno() :: non_neg_integer().
-type track() :: ack | comp.
-type marker() :: rec.
-record(inflight, {
next_seqno = 1 :: seqno(),
acked_until = 1 :: seqno(),
commits = #{ack => 1, comp => 1, rec => 1} :: #{track() | marker() => seqno()},
%% Ranges are sorted in ascending order of their sequence numbers.
offset_ranges = [] :: [ds_pubrange()]
}).
-opaque inflight() :: #inflight{}.
-type reply_fun() :: fun((seqno(), emqx_types:message()) -> emqx_session:reply()).
%%================================================================================
%% API functions
%%================================================================================
@ -61,10 +78,12 @@ new() ->
-spec open(emqx_persistent_session_ds:id()) -> inflight().
open(SessionId) ->
Ranges = ro_transaction(fun() -> get_ranges(SessionId) end),
{AckedUntil, NextSeqno} = compute_inflight_range(Ranges),
{Ranges, RecUntil} = ro_transaction(
fun() -> {get_ranges(SessionId), get_marker(SessionId, rec)} end
),
{Commits, NextSeqno} = compute_inflight_range(Ranges),
#inflight{
acked_until = AckedUntil,
commits = Commits#{rec => RecUntil},
next_seqno = NextSeqno,
offset_ranges = Ranges
}.
@ -75,15 +94,30 @@ next_packet_id(Inflight0 = #inflight{next_seqno = LastSeqno}) ->
{seqno_to_packet_id(LastSeqno), Inflight}.
-spec n_inflight(inflight()) -> non_neg_integer().
n_inflight(#inflight{next_seqno = NextSeqno, acked_until = AckedUntil}) ->
range_size(AckedUntil, NextSeqno).
n_inflight(#inflight{offset_ranges = Ranges}) ->
    %% Adds up the sizes of all `inflight` ranges; `checkpoint` ranges
    %% contribute nothing to the total.
    %% TODO
    %% Walking every range is not very efficient. A cheaper alternative is to
    %% take the maximum of `range_size(AckedUntil, NextSeqno)` and
    %% `range_size(CompUntil, NextSeqno)`. That yields a pessimistic estimate
    %% rather than an exact number, but it penalizes clients that PUBACK QoS 1
    %% messages yet never PUBCOMP QoS 2 messages. For that to work, the actual
    %% `AckedUntil` / `CompUntil` would additionally have to be tracked during
    %% `commit_offset/4`.
    SizeOf = fun
        (#ds_pubrange{type = checkpoint}) ->
            0;
        (#ds_pubrange{type = inflight, id = {_, First}, until = Until}) ->
            range_size(First, Until)
    end,
    lists:sum([SizeOf(Range) || Range <- Ranges]).
-spec replay(inflight()) ->
{emqx_session:replies(), inflight()}.
replay(Inflight0 = #inflight{acked_until = AckedUntil, offset_ranges = Ranges0}) ->
-spec replay(reply_fun(), inflight()) -> {emqx_session:replies(), inflight()}.
replay(ReplyFun, Inflight0 = #inflight{offset_ranges = Ranges0}) ->
{Ranges, Replies} = lists:mapfoldr(
fun(Range, Acc) ->
replay_range(Range, AckedUntil, Acc)
replay_range(ReplyFun, Range, Acc)
end,
[],
Ranges0
@ -91,43 +125,50 @@ replay(Inflight0 = #inflight{acked_until = AckedUntil, offset_ranges = Ranges0})
Inflight = Inflight0#inflight{offset_ranges = Ranges},
{Replies, Inflight}.
-spec commit_offset(emqx_persistent_session_ds:id(), emqx_types:packet_id(), inflight()) ->
-spec commit_offset(emqx_persistent_session_ds:id(), track(), emqx_types:packet_id(), inflight()) ->
{_IsValidOffset :: boolean(), inflight()}.
commit_offset(
SessionId,
Track,
PacketId,
Inflight0 = #inflight{
acked_until = AckedUntil, next_seqno = NextSeqno
}
Inflight0 = #inflight{commits = Commits}
) ->
case packet_id_to_seqno(NextSeqno, PacketId) of
Seqno when Seqno >= AckedUntil andalso Seqno < NextSeqno ->
case validate_commit(Track, PacketId, Inflight0) of
CommitUntil when is_integer(CommitUntil) ->
%% TODO
%% We do not preserve `acked_until` in the database. Instead, we discard
%% We do not preserve `CommitUntil` in the database. Instead, we discard
%% fully acked ranges from the database. In effect, this means that the
%% most recent `acked_until` the client has sent may be lost in case of a
%% most recent `CommitUntil` the client has sent may be lost in case of a
%% crash or client loss.
Inflight1 = Inflight0#inflight{acked_until = next_seqno(Seqno)},
Inflight = discard_acked(SessionId, Inflight1),
Inflight1 = Inflight0#inflight{commits = Commits#{Track := CommitUntil}},
Inflight = discard_committed(SessionId, Inflight1),
{true, Inflight};
OutOfRange ->
?SLOG(warning, #{
msg => "out-of-order_ack",
acked_until => AckedUntil,
acked_seqno => OutOfRange,
next_seqno => NextSeqno,
packet_id => PacketId
}),
false ->
{false, Inflight0}
end.
-spec poll(emqx_persistent_session_ds:id(), inflight(), pos_integer()) ->
%% Commits a marker (only `rec` is accepted) for the given packet id.
%% When `validate_commit/3` accepts the packet id, the new marker position is
%% written through `update_marker/3` and mirrored in the in-memory `commits`
%% map. Otherwise the inflight state is returned unchanged.
-spec commit_marker(emqx_persistent_session_ds:id(), marker(), emqx_types:packet_id(), inflight()) ->
    {_IsValidMarker :: boolean(), inflight()}.
commit_marker(SessionId, rec = Marker, PacketId, Inflight = #inflight{commits = Commits}) ->
    case validate_commit(Marker, PacketId, Inflight) of
        false ->
            {false, Inflight};
        Until when is_integer(Until) ->
            update_marker(SessionId, Marker, Until),
            {true, Inflight#inflight{commits = Commits#{Marker := Until}}}
    end.
-spec poll(reply_fun(), emqx_persistent_session_ds:id(), inflight(), pos_integer()) ->
{emqx_session:replies(), inflight()}.
poll(SessionId, Inflight0, WindowSize) when WindowSize > 0, WindowSize < 16#7fff ->
#inflight{next_seqno = NextSeqNo0, acked_until = AckedSeqno} =
Inflight0,
poll(ReplyFun, SessionId, Inflight0, WindowSize) when WindowSize > 0, WindowSize < ?EPOCH_SIZE ->
FetchThreshold = max(1, WindowSize div 2),
FreeSpace = AckedSeqno + WindowSize - NextSeqNo0,
FreeSpace = WindowSize - n_inflight(Inflight0),
case FreeSpace >= FetchThreshold of
false ->
%% TODO: this branch is meant to avoid fetching data from
@ -138,9 +179,23 @@ poll(SessionId, Inflight0, WindowSize) when WindowSize > 0, WindowSize < 16#7fff
true ->
%% TODO: Wrap this in `mria:async_dirty/2`?
Streams = shuffle(get_streams(SessionId)),
fetch(SessionId, Inflight0, Streams, FreeSpace, [])
fetch(ReplyFun, SessionId, Inflight0, Streams, FreeSpace, [])
end.
%% Returns the sequence number stored under `Track` in the `commits` map of
%% the inflight state. `maps:get/2` raises if the key is not present.
-spec committed_until(track() | marker(), inflight()) -> seqno().
committed_until(Track, #inflight{commits = Commits}) ->
maps:get(Track, Commits).
%% Maps a session sequence number to a packet id by taking the remainder of
%% the division by `?EPOCH_SIZE`. The result is 0 whenever `Seqno` is a
%% multiple of `?EPOCH_SIZE`, hence the `| 0` in the spec.
-spec seqno_to_packet_id(seqno()) -> emqx_types:packet_id() | 0.
seqno_to_packet_id(Seqno) ->
Seqno rem ?EPOCH_SIZE.
%% Reconstruct the session sequence number for a packet id by adding the most
%% significant bits of the current counter (`next_seqno` of the inflight
%% state) to the packet id. The actual computation is delegated to
%% `packet_id_to_seqno_/2`, which takes its arguments in the opposite order.
-spec packet_id_to_seqno(emqx_types:packet_id(), inflight()) -> seqno().
packet_id_to_seqno(PacketId, #inflight{next_seqno = NextSeqno}) ->
packet_id_to_seqno_(NextSeqno, PacketId).
%%================================================================================
%% Internal exports
%%================================================================================
@ -150,18 +205,34 @@ poll(SessionId, Inflight0, WindowSize) when WindowSize > 0, WindowSize < 16#7fff
%%================================================================================
compute_inflight_range([]) ->
{1, 1};
{#{ack => 1, comp => 1}, 1};
compute_inflight_range(Ranges) ->
_RangeLast = #ds_pubrange{until = LastSeqno} = lists:last(Ranges),
RangesUnacked = lists:dropwhile(
fun(#ds_pubrange{type = T}) -> T == checkpoint end,
AckedUntil = find_committed_until(ack, Ranges),
CompUntil = find_committed_until(comp, Ranges),
Commits = #{
ack => emqx_maybe:define(AckedUntil, LastSeqno),
comp => emqx_maybe:define(CompUntil, LastSeqno)
},
{Commits, LastSeqno}.
find_committed_until(Track, Ranges) ->
RangesUncommitted = lists:dropwhile(
fun(Range) ->
case Range of
#ds_pubrange{type = checkpoint} ->
true;
#ds_pubrange{type = inflight} = Range ->
not has_range_track(Track, Range)
end
end,
Ranges
),
case RangesUnacked of
[#ds_pubrange{id = {_, AckedUntil}} | _] ->
{AckedUntil, LastSeqno};
case RangesUncommitted of
[#ds_pubrange{id = {_, CommittedUntil}} | _] ->
CommittedUntil;
[] ->
{LastSeqno, LastSeqno}
undefined
end.
-spec get_ranges(emqx_persistent_session_ds:id()) -> [ds_pubrange()].
@ -173,18 +244,18 @@ get_ranges(SessionId) ->
),
mnesia:match_object(?SESSION_PUBRANGE_TAB, Pat, read).
fetch(SessionId, Inflight0, [DSStream | Streams], N, Acc) when N > 0 ->
fetch(ReplyFun, SessionId, Inflight0, [DSStream | Streams], N, Acc) when N > 0 ->
#inflight{next_seqno = FirstSeqno, offset_ranges = Ranges} = Inflight0,
ItBegin = get_last_iterator(DSStream, Ranges),
{ok, ItEnd, Messages} = emqx_ds:next(?PERSISTENT_MESSAGE_DB, ItBegin, N),
case Messages of
[] ->
fetch(SessionId, Inflight0, Streams, N, Acc);
fetch(ReplyFun, SessionId, Inflight0, Streams, N, Acc);
_ ->
{Publishes, UntilSeqno} = publish(FirstSeqno, Messages, _PreserveQoS0 = true),
Size = range_size(FirstSeqno, UntilSeqno),
%% We need to preserve the iterator pointing to the beginning of the
%% range, so that we can replay it if needed.
{Publishes, {UntilSeqno, Tracks}} = publish(ReplyFun, FirstSeqno, Messages),
Size = range_size(FirstSeqno, UntilSeqno),
Range0 = #ds_pubrange{
id = {SessionId, FirstSeqno},
type = inflight,
@ -192,29 +263,30 @@ fetch(SessionId, Inflight0, [DSStream | Streams], N, Acc) when N > 0 ->
stream = DSStream#ds_stream.ref,
iterator = ItBegin
},
ok = preserve_range(Range0),
Range1 = update_range_tracks(Tracks, Range0),
ok = preserve_range(Range1),
%% ...Yet we need to keep the iterator pointing past the end of the
%% range, so that we can pick up where we left off: it will become
%% `ItBegin` of the next range for this stream.
Range = Range0#ds_pubrange{iterator = ItEnd},
Range = keep_next_iterator(ItEnd, Range1),
Inflight = Inflight0#inflight{
next_seqno = UntilSeqno,
offset_ranges = Ranges ++ [Range]
},
fetch(SessionId, Inflight, Streams, N - Size, [Publishes | Acc])
fetch(ReplyFun, SessionId, Inflight, Streams, N - Size, [Publishes | Acc])
end;
fetch(_SessionId, Inflight, _Streams, _N, Acc) ->
fetch(_ReplyFun, _SessionId, Inflight, _Streams, _N, Acc) ->
Publishes = lists:append(lists:reverse(Acc)),
{Publishes, Inflight}.
discard_acked(
discard_committed(
SessionId,
Inflight0 = #inflight{acked_until = AckedUntil, offset_ranges = Ranges0}
Inflight0 = #inflight{commits = Commits, offset_ranges = Ranges0}
) ->
%% TODO: This could be kept and incrementally updated in the inflight state.
Checkpoints = find_checkpoints(Ranges0),
%% TODO: Wrap this in `mria:async_dirty/2`?
Ranges = discard_acked_ranges(SessionId, AckedUntil, Checkpoints, Ranges0),
Ranges = discard_committed_ranges(SessionId, Commits, Checkpoints, Ranges0),
Inflight0#inflight{offset_ranges = Ranges}.
find_checkpoints(Ranges) ->
@ -227,13 +299,15 @@ find_checkpoints(Ranges) ->
Ranges
).
discard_acked_ranges(
discard_committed_ranges(
SessionId,
AckedUntil,
Commits,
Checkpoints,
[Range = #ds_pubrange{until = Until, stream = StreamRef} | Rest]
) when Until =< AckedUntil ->
%% This range has been fully acked.
Ranges = [Range = #ds_pubrange{until = Until, stream = StreamRef} | Rest]
) ->
case discard_committed_range(Commits, Range) of
discard ->
%% This range has been fully committed.
%% Either discard it completely, or preserve the iterator for the next range
%% over this stream (i.e. a checkpoint).
RangeKept =
@ -248,60 +322,174 @@ discard_acked_ranges(
%% issue database writes in the same order in which ranges are stored: from
%% the oldest to the newest. This is also why we need to compute which ranges
%% should become checkpoints before we start writing anything.
RangeKept ++ discard_acked_ranges(SessionId, AckedUntil, Checkpoints, Rest);
discard_acked_ranges(_SessionId, _AckedUntil, _Checkpoints, Ranges) ->
%% The rest of ranges (if any) still have unacked messages.
Ranges.
RangeKept ++ discard_committed_ranges(SessionId, Commits, Checkpoints, Rest);
keep ->
%% This range has not been fully committed.
[Range | discard_committed_ranges(SessionId, Commits, Checkpoints, Rest)];
keep_all ->
%% The rest of ranges (if any) still have uncommitted messages.
Ranges;
TracksLeft ->
%% Only some track has been committed.
%% Preserve the uncommitted tracks in the database.
RangeKept = update_range_tracks(TracksLeft, Range),
preserve_range(restore_first_iterator(RangeKept)),
[RangeKept | discard_committed_ranges(SessionId, Commits, Checkpoints, Rest)]
end;
discard_committed_ranges(_SessionId, _Commits, _Checkpoints, []) ->
[].
%% Decides what to do with a single range given the current commit positions.
%% Returns one of:
%%   `discard`  - the range is a checkpoint, or none of its tracks is pending;
%%   `keep_all` - the range ends beyond both the `ack` and `comp` positions;
%%   `keep`     - the set of pending tracks is unchanged;
%%   an integer - bit flags of the tracks that are still pending.
%% Clause order matters: the `keep_all` clause must be tried before the
%% general one.
discard_committed_range(_Commits, #ds_pubrange{type = checkpoint}) ->
discard;
discard_committed_range(
#{ack := AckedUntil, comp := CompUntil},
#ds_pubrange{until = Until}
) when Until > AckedUntil andalso Until > CompUntil ->
keep_all;
discard_committed_range(
Commits,
Range = #ds_pubrange{until = Until}
) ->
Tracks = get_range_tracks(Range),
case discard_tracks(Commits, Until, Tracks) of
0 ->
%% No track of this range is pending anymore.
discard;
%% `Tracks` is already bound here: this clause matches only when the pending
%% flags equal the flags the range currently has, i.e. nothing changed.
Tracks ->
keep;
TracksLeft ->
TracksLeft
end.
%% Computes which of the `Tracks` bit flags are still pending for a range that
%% ends at `Until`: a track's flag is kept only while `Until` lies beyond the
%% corresponding commit position. Returns the remaining flags (0 if none).
discard_tracks(#{ack := AckedUntil, comp := CompUntil}, Until, Tracks) ->
    Pending = fun
        (Flag, CommittedUntil) when Until > CommittedUntil ->
            Flag band Tracks;
        (_Flag, _CommittedUntil) ->
            0
    end,
    AckLeft = Pending(?TRACK_FLAG(?ACK), AckedUntil),
    CompLeft = Pending(?TRACK_FLAG(?COMP), CompUntil),
    AckLeft bor CompLeft.
replay_range(
ReplyFun,
Range0 = #ds_pubrange{type = inflight, id = {_, First}, until = Until, iterator = It},
AckedUntil,
Acc
) ->
Size = range_size(First, Until),
FirstUnacked = max(First, AckedUntil),
{ok, ItNext, Messages} = emqx_ds:next(?PERSISTENT_MESSAGE_DB, It, Size),
MessagesUnacked =
case FirstUnacked of
First ->
Messages;
_ ->
lists:nthtail(range_size(First, FirstUnacked), Messages)
end,
MessagesReplay = [emqx_message:set_flag(dup, true, Msg) || Msg <- MessagesUnacked],
{ok, ItNext, MessagesUnacked} = emqx_ds:next(?PERSISTENT_MESSAGE_DB, It, Size),
%% Asserting that range is consistent with the message storage state.
{Replies, Until} = publish(FirstUnacked, MessagesReplay, _PreserveQoS0 = false),
{Replies, {Until, Tracks}} = publish(ReplyFun, First, MessagesUnacked),
%% Again, we need to keep the iterator pointing past the end of the
%% range, so that we can pick up where we left off.
Range = Range0#ds_pubrange{iterator = ItNext},
Range = keep_next_iterator(ItNext, ensure_range_tracks(Tracks, Range0)),
{Range, Replies ++ Acc};
replay_range(Range0 = #ds_pubrange{type = checkpoint}, _AckedUntil, Acc) ->
replay_range(_ReplyFun, Range0 = #ds_pubrange{type = checkpoint}, Acc) ->
{Range0, Acc}.
publish(FirstSeqNo, Messages, PreserveQos0) ->
do_publish(FirstSeqNo, Messages, PreserveQos0, []).
%% Checks whether a commit for `PacketId` is acceptable on the given track or
%% marker. The packet id is first converted to a seqno; the commit is valid
%% when that seqno is at or past the current commit position and strictly
%% below the bound returned by `get_commit_next/2`.
%% Returns `next_seqno(Seqno)` on success. Otherwise logs a warning and
%% returns `false`.
validate_commit(
    Track,
    PacketId,
    Inflight = #inflight{commits = Commits, next_seqno = NextSeqno}
) ->
    Seqno = packet_id_to_seqno_(NextSeqno, PacketId),
    CommittedUntil = maps:get(Track, Commits),
    CommitNext = get_commit_next(Track, Inflight),
    if
        Seqno >= CommittedUntil, Seqno < CommitNext ->
            next_seqno(Seqno);
        true ->
            ?SLOG(warning, #{
                msg => "out-of-order_commit",
                track => Track,
                packet_id => PacketId,
                commit_seqno => Seqno,
                committed_until => CommittedUntil,
                commit_next => CommitNext
            }),
            false
    end.
do_publish(SeqNo, [], _, Acc) ->
{lists:reverse(Acc), SeqNo};
do_publish(SeqNo, [#message{qos = 0} | Messages], false, Acc) ->
do_publish(SeqNo, Messages, false, Acc);
do_publish(SeqNo, [#message{qos = 0} = Message | Messages], true, Acc) ->
do_publish(SeqNo, Messages, true, [{undefined, Message} | Acc]);
do_publish(SeqNo, [Message | Messages], PreserveQos0, Acc) ->
PacketId = seqno_to_packet_id(SeqNo),
do_publish(next_seqno(SeqNo), Messages, PreserveQos0, [{PacketId, Message} | Acc]).
%% Exclusive upper bound for a commit on the given track or marker:
%% `ack` and `rec` may go up to `next_seqno`, while `comp` may only go up to
%% the position of the `rec` marker.
get_commit_next(Track, #inflight{next_seqno = NextSeqno}) when Track =:= ack; Track =:= rec ->
    NextSeqno;
get_commit_next(comp, #inflight{commits = Commits}) ->
    maps:get(rec, Commits).
%% Turns messages into replies by calling `ReplyFun(Seqno, Message)` for each
%% message in order, starting at `FirstSeqno`.
%% If `ReplyFun` returns `{false, Reply}`, the reply is emitted but the seqno
%% and the track flags stay as they are. Any other return value is taken as
%% the reply itself; the seqno then advances and the message's track is added.
%% Returns `{Replies, {UntilSeqno, Tracks}}`.
publish(ReplyFun, FirstSeqno, Messages) ->
    Step = fun(Msg, {SeqnoIn, TracksIn} = AccIn) ->
        case ReplyFun(SeqnoIn, Msg) of
            {false, Reply} ->
                {Reply, AccIn};
            Reply ->
                {Reply, {next_seqno(SeqnoIn), add_msg_track(Msg, TracksIn)}}
        end
    end,
    lists:mapfoldl(Step, {FirstSeqno, 0}, Messages).
%% Adds the track flag implied by the message's QoS to `Tracks`:
%% QoS 1 sets the ACK flag, QoS 2 sets the COMP flag, and any other QoS
%% leaves `Tracks` unchanged.
add_msg_track(Message, Tracks) ->
    QoS = emqx_message:qos(Message),
    if
        QoS =:= 1 -> ?TRACK_FLAG(?ACK) bor Tracks;
        QoS =:= 2 -> ?TRACK_FLAG(?COMP) bor Tracks;
        true -> Tracks
    end.
%% Replaces the range's iterator with `ItNext`, stashing the previous one in
%% `misc` under `iterator_first`.
keep_next_iterator(ItNext, Range = #ds_pubrange{iterator = ItFirst, misc = Misc0}) ->
    %% The first iterator has to stay around in case this range must be
    %% preserved again later, when the still uncommitted tracks it is part of
    %% get updated.
    Misc = Misc0#{iterator_first => ItFirst},
    Range#ds_pubrange{iterator = ItNext, misc = Misc}.
%% Puts the iterator stashed under `iterator_first` in `misc` back into the
%% range's `iterator` field and drops the stash key from `misc`.
%% Only matches ranges whose `misc` contains `iterator_first`.
restore_first_iterator(Range = #ds_pubrange{misc = Misc0 = #{iterator_first := ItFirst}}) ->
    Misc = maps:remove(iterator_first, Misc0),
    Range#ds_pubrange{iterator = ItFirst, misc = Misc}.
%% Sets the tracks of a range only if none are recorded yet: a range whose
%% `misc` already has a `?T_tracks` entry is returned unchanged, otherwise
%% `Tracks` is applied through `update_range_tracks/2`.
ensure_range_tracks(_Tracks, Range = #ds_pubrange{misc = #{?T_tracks := _Existing}}) ->
Range;
ensure_range_tracks(Tracks, Range = #ds_pubrange{}) ->
update_range_tracks(Tracks, Range).
%% Records which commit tracks a range is part of, in the range's `misc` map.
%% The ACK-only flag set is not stored: the `?T_tracks` key is removed instead.
%% Every other value is stored under `?T_tracks`.
update_range_tracks(?TRACK_FLAG(?ACK), Range = #ds_pubrange{misc = Misc}) ->
%% This is assumed as the default value for the tracks field.
Range#ds_pubrange{misc = maps:remove(?T_tracks, Misc)};
update_range_tracks(Tracks, Range = #ds_pubrange{misc = Misc}) ->
Range#ds_pubrange{misc = Misc#{?T_tracks => Tracks}}.
%% Returns the track bit flags stored under `?T_tracks` in the range's `misc`
%% map, falling back to the ACK-only flag when the key is absent.
get_range_tracks(#ds_pubrange{misc = Misc}) ->
%% This is assumed as the default value for the tracks field.
maps:get(?T_tracks, Misc, ?TRACK_FLAG(?ACK)).
%% Writes a range of type `inflight` to `?SESSION_PUBRANGE_TAB` using
%% `mria:dirty_write/2`. Ranges of any other type do not match the head.
-spec preserve_range(ds_pubrange()) -> ok.
preserve_range(Range = #ds_pubrange{type = inflight}) ->
mria:dirty_write(?SESSION_PUBRANGE_TAB, Range).
%% Tells whether the given range is part of the given commit track.
has_range_track(Track, Range) ->
    RangeTracks = get_range_tracks(Range),
    has_track(Track, RangeTracks).
%% Tests whether the bit flag of the named track (`ack` or `comp`) is set in
%% the `Tracks` bit set.
has_track(ack, Tracks) ->
    AckBits = ?TRACK_FLAG(?ACK) band Tracks,
    AckBits > 0;
has_track(comp, Tracks) ->
    CompBits = ?TRACK_FLAG(?COMP) band Tracks,
    CompBits > 0.
%% Deletes the record keyed by the range's id from `?SESSION_PUBRANGE_TAB`
%% using `mria:dirty_delete/2`.
-spec discard_range(ds_pubrange()) -> ok.
discard_range(#ds_pubrange{id = RangeId}) ->
mria:dirty_delete(?SESSION_PUBRANGE_TAB, RangeId).
-spec checkpoint_range(ds_pubrange()) -> ds_pubrange().
checkpoint_range(Range0 = #ds_pubrange{type = inflight}) ->
Range = Range0#ds_pubrange{type = checkpoint},
Range = Range0#ds_pubrange{type = checkpoint, misc = #{}},
ok = mria:dirty_write(?SESSION_PUBRANGE_TAB, Range),
Range;
checkpoint_range(Range = #ds_pubrange{type = checkpoint}) ->
@ -320,6 +508,19 @@ get_last_iterator(DSStream = #ds_stream{ref = StreamRef}, Ranges) ->
get_streams(SessionId) ->
mnesia:dirty_read(?SESSION_STREAM_TAB, SessionId).
%% Reads the position of the named marker of a session from
%% `?SESSION_MARKER_TAB`. Returns 1 when no marker record exists.
%% Uses `mnesia:read/2`; the caller in `open/1` wraps it in `ro_transaction/1`.
-spec get_marker(emqx_persistent_session_ds:id(), _Name) -> seqno().
get_marker(SessionId, Name) ->
    Key = {SessionId, Name},
    case mnesia:read(?SESSION_MARKER_TAB, Key) of
        [#ds_marker{until = Seqno}] ->
            Seqno;
        [] ->
            1
    end.
%% Stores the position `Until` of the named marker of a session by writing a
%% `#ds_marker{}` record keyed by `{SessionId, Name}` to `?SESSION_MARKER_TAB`
%% with `mria:dirty_write/2`.
-spec update_marker(emqx_persistent_session_ds:id(), _Name, seqno()) -> ok.
update_marker(SessionId, Name, Until) ->
mria:dirty_write(?SESSION_MARKER_TAB, #ds_marker{id = {SessionId, Name}, until = Until}).
next_seqno(Seqno) ->
NextSeqno = Seqno + 1,
case seqno_to_packet_id(NextSeqno) of
@ -332,26 +533,15 @@ next_seqno(Seqno) ->
NextSeqno
end.
%% Reconstruct session counter by adding most significant bits from
%% the current counter to the packet id.
-spec packet_id_to_seqno(_Next :: seqno(), emqx_types:packet_id()) -> seqno().
packet_id_to_seqno(NextSeqNo, PacketId) ->
Epoch = NextSeqNo bsr 16,
case packet_id_to_seqno_(Epoch, PacketId) of
N when N =< NextSeqNo ->
packet_id_to_seqno_(NextSeqno, PacketId) ->
Epoch = NextSeqno bsr 16,
case (Epoch bsl 16) + PacketId of
N when N =< NextSeqno ->
N;
_ ->
packet_id_to_seqno_(Epoch - 1, PacketId)
N ->
N - ?EPOCH_SIZE
end.
-spec packet_id_to_seqno_(non_neg_integer(), emqx_types:packet_id()) -> seqno().
packet_id_to_seqno_(Epoch, PacketId) ->
(Epoch bsl 16) + PacketId.
-spec seqno_to_packet_id(seqno()) -> emqx_types:packet_id() | 0.
seqno_to_packet_id(Seqno) ->
Seqno rem 16#10000.
range_size(FirstSeqno, UntilSeqno) ->
%% This function assumes that gaps in the sequence ID occur _only_ when the
%% packet ID wraps.
@ -379,19 +569,19 @@ ro_transaction(Fun) ->
%% This test only tests boundary conditions (to make sure property-based test didn't skip them):
packet_id_to_seqno_test() ->
%% Packet ID = 1; first epoch:
?assertEqual(1, packet_id_to_seqno(1, 1)),
?assertEqual(1, packet_id_to_seqno(10, 1)),
?assertEqual(1, packet_id_to_seqno(1 bsl 16 - 1, 1)),
?assertEqual(1, packet_id_to_seqno(1 bsl 16, 1)),
?assertEqual(1, packet_id_to_seqno_(1, 1)),
?assertEqual(1, packet_id_to_seqno_(10, 1)),
?assertEqual(1, packet_id_to_seqno_(1 bsl 16 - 1, 1)),
?assertEqual(1, packet_id_to_seqno_(1 bsl 16, 1)),
%% Packet ID = 1; second and 3rd epochs:
?assertEqual(1 bsl 16 + 1, packet_id_to_seqno(1 bsl 16 + 1, 1)),
?assertEqual(1 bsl 16 + 1, packet_id_to_seqno(2 bsl 16, 1)),
?assertEqual(2 bsl 16 + 1, packet_id_to_seqno(2 bsl 16 + 1, 1)),
?assertEqual(1 bsl 16 + 1, packet_id_to_seqno_(1 bsl 16 + 1, 1)),
?assertEqual(1 bsl 16 + 1, packet_id_to_seqno_(2 bsl 16, 1)),
?assertEqual(2 bsl 16 + 1, packet_id_to_seqno_(2 bsl 16 + 1, 1)),
%% Packet ID = 16#ffff:
PID = 1 bsl 16 - 1,
?assertEqual(PID, packet_id_to_seqno(PID, PID)),
?assertEqual(PID, packet_id_to_seqno(1 bsl 16, PID)),
?assertEqual(1 bsl 16 + PID, packet_id_to_seqno(2 bsl 16, PID)),
?assertEqual(PID, packet_id_to_seqno_(PID, PID)),
?assertEqual(PID, packet_id_to_seqno_(1 bsl 16, PID)),
?assertEqual(1 bsl 16 + PID, packet_id_to_seqno_(2 bsl 16, PID)),
ok.
packet_id_to_seqno_test_() ->
@ -406,8 +596,8 @@ packet_id_to_seqno_prop() ->
SeqNo,
seqno_gen(NextSeqNo),
begin
PacketId = SeqNo rem 16#10000,
?assertEqual(SeqNo, packet_id_to_seqno(NextSeqNo, PacketId)),
PacketId = seqno_to_packet_id(SeqNo),
?assertEqual(SeqNo, packet_id_to_seqno_(NextSeqNo, PacketId)),
true
end
)
@ -437,22 +627,37 @@ range_size_test_() ->
compute_inflight_range_test_() ->
[
?_assertEqual(
{1, 1},
{#{ack => 1, comp => 1}, 1},
compute_inflight_range([])
),
?_assertEqual(
{12, 42},
{#{ack => 12, comp => 13}, 42},
compute_inflight_range([
#ds_pubrange{id = {<<>>, 1}, until = 2, type = checkpoint},
#ds_pubrange{id = {<<>>, 4}, until = 8, type = checkpoint},
#ds_pubrange{id = {<<>>, 11}, until = 12, type = checkpoint},
#ds_pubrange{id = {<<>>, 12}, until = 13, type = inflight},
#ds_pubrange{id = {<<>>, 13}, until = 20, type = inflight},
#ds_pubrange{id = {<<>>, 20}, until = 42, type = inflight}
#ds_pubrange{
id = {<<>>, 12},
until = 13,
type = inflight,
misc = #{}
},
#ds_pubrange{
id = {<<>>, 13},
until = 20,
type = inflight,
misc = #{?T_tracks => ?TRACK_FLAG(?COMP)}
},
#ds_pubrange{
id = {<<>>, 20},
until = 42,
type = inflight,
misc = #{?T_tracks => ?TRACK_FLAG(?ACK) bor ?TRACK_FLAG(?COMP)}
}
])
),
?_assertEqual(
{13, 13},
{#{ack => 13, comp => 13}, 13},
compute_inflight_range([
#ds_pubrange{id = {<<>>, 1}, until = 2, type = checkpoint},
#ds_pubrange{id = {<<>>, 4}, until = 8, type = checkpoint},

View File

@ -239,6 +239,7 @@ print_session(ClientId) ->
session => Session,
streams => mnesia:read(?SESSION_STREAM_TAB, ClientId),
pubranges => session_read_pubranges(ClientId),
markers => session_read_markers(ClientId),
subscriptions => session_read_subscriptions(ClientId)
};
[] ->
@ -319,12 +320,13 @@ publish(_PacketId, Msg, Session) ->
{ok, emqx_types:message(), replies(), session()}
| {error, emqx_types:reason_code()}.
puback(_ClientInfo, PacketId, Session = #{id := Id, inflight := Inflight0}) ->
case emqx_persistent_message_ds_replayer:commit_offset(Id, PacketId, Inflight0) of
case emqx_persistent_message_ds_replayer:commit_offset(Id, ack, PacketId, Inflight0) of
{true, Inflight} ->
%% TODO
Msg = #message{},
{ok, Msg, [], Session#{inflight => Inflight}};
{false, _} ->
%% Invalid Packet Id
{error, ?RC_PACKET_IDENTIFIER_NOT_FOUND}
end.
@ -335,9 +337,16 @@ puback(_ClientInfo, PacketId, Session = #{id := Id, inflight := Inflight0}) ->
-spec pubrec(emqx_types:packet_id(), session()) ->
{ok, emqx_types:message(), session()}
| {error, emqx_types:reason_code()}.
pubrec(_PacketId, _Session = #{}) ->
% TODO: stub
{error, ?RC_PACKET_IDENTIFIER_NOT_FOUND}.
%% Handles a PUBREC from the client by committing the `rec` marker for the
%% packet id in the replayer. An accepted commit yields the updated session;
%% a rejected one is reported as "packet identifier not found".
pubrec(PacketId, Session = #{id := Id, inflight := Inflight0}) ->
    Result = emqx_persistent_message_ds_replayer:commit_marker(Id, rec, PacketId, Inflight0),
    case Result of
        {false, _} ->
            %% Invalid Packet Id
            {error, ?RC_PACKET_IDENTIFIER_NOT_FOUND};
        {true, Inflight} ->
            %% TODO
            %% An empty message record is returned for now.
            {ok, #message{}, Session#{inflight => Inflight}}
    end.
%%--------------------------------------------------------------------
%% Client -> Broker: PUBREL
@ -356,9 +365,16 @@ pubrel(_PacketId, Session = #{}) ->
-spec pubcomp(clientinfo(), emqx_types:packet_id(), session()) ->
{ok, emqx_types:message(), replies(), session()}
| {error, emqx_types:reason_code()}.
pubcomp(_ClientInfo, _PacketId, _Session = #{}) ->
% TODO: stub
{error, ?RC_PACKET_IDENTIFIER_NOT_FOUND}.
%% Handles a PUBCOMP from the client by committing the offset on the `comp`
%% track in the replayer. An accepted commit yields the updated session with
%% no replies; a rejected one is reported as "packet identifier not found".
pubcomp(_ClientInfo, PacketId, Session = #{id := Id, inflight := Inflight0}) ->
    Result = emqx_persistent_message_ds_replayer:commit_offset(Id, comp, PacketId, Inflight0),
    case Result of
        {false, _} ->
            %% Invalid Packet Id
            {error, ?RC_PACKET_IDENTIFIER_NOT_FOUND};
        {true, Inflight} ->
            %% TODO
            %% An empty message record is returned for now.
            {ok, #message{}, [], Session#{inflight => Inflight}}
    end.
%%--------------------------------------------------------------------
@ -375,7 +391,18 @@ handle_timeout(
pull,
Session = #{id := Id, inflight := Inflight0, receive_maximum := ReceiveMaximum}
) ->
{Publishes, Inflight} = emqx_persistent_message_ds_replayer:poll(Id, Inflight0, ReceiveMaximum),
{Publishes, Inflight} = emqx_persistent_message_ds_replayer:poll(
fun
(_Seqno, Message = #message{qos = ?QOS_0}) ->
{false, {undefined, Message}};
(Seqno, Message) ->
PacketId = emqx_persistent_message_ds_replayer:seqno_to_packet_id(Seqno),
{PacketId, Message}
end,
Id,
Inflight0,
ReceiveMaximum
),
IdlePollInterval = emqx_config:get([session_persistence, idle_poll_interval]),
Timeout =
case Publishes of
@ -385,7 +412,7 @@ handle_timeout(
0
end,
ensure_timer(pull, Timeout),
{ok, Publishes, Session#{inflight => Inflight}};
{ok, Publishes, Session#{inflight := Inflight}};
handle_timeout(_ClientInfo, get_streams, Session) ->
renew_streams(Session),
ensure_timer(get_streams),
@ -394,7 +421,24 @@ handle_timeout(_ClientInfo, get_streams, Session) ->
-spec replay(clientinfo(), [], session()) ->
{ok, replies(), session()}.
replay(_ClientInfo, [], Session = #{inflight := Inflight0}) ->
{Replies, Inflight} = emqx_persistent_message_ds_replayer:replay(Inflight0),
AckedUntil = emqx_persistent_message_ds_replayer:committed_until(ack, Inflight0),
RecUntil = emqx_persistent_message_ds_replayer:committed_until(rec, Inflight0),
CompUntil = emqx_persistent_message_ds_replayer:committed_until(comp, Inflight0),
ReplyFun = fun
(_Seqno, #message{qos = ?QOS_0}) ->
{false, []};
(Seqno, #message{qos = ?QOS_1}) when Seqno < AckedUntil ->
[];
(Seqno, #message{qos = ?QOS_2}) when Seqno < CompUntil ->
[];
(Seqno, #message{qos = ?QOS_2}) when Seqno < RecUntil ->
PacketId = emqx_persistent_message_ds_replayer:seqno_to_packet_id(Seqno),
{pubrel, PacketId};
(Seqno, Message) ->
PacketId = emqx_persistent_message_ds_replayer:seqno_to_packet_id(Seqno),
{PacketId, emqx_message:set_flag(dup, true, Message)}
end,
{Replies, Inflight} = emqx_persistent_message_ds_replayer:replay(ReplyFun, Inflight0),
{ok, Replies, Session#{inflight := Inflight}}.
%%--------------------------------------------------------------------
@ -507,11 +551,22 @@ create_tables() ->
{attributes, record_info(fields, ds_pubrange)}
]
),
ok = mria:create_table(
?SESSION_MARKER_TAB,
[
{rlog_shard, ?DS_MRIA_SHARD},
{type, set},
{storage, storage()},
{record_name, ds_marker},
{attributes, record_info(fields, ds_marker)}
]
),
ok = mria:wait_for_tables([
?SESSION_TAB,
?SESSION_SUBSCRIPTIONS_TAB,
?SESSION_STREAM_TAB,
?SESSION_PUBRANGE_TAB
?SESSION_PUBRANGE_TAB,
?SESSION_MARKER_TAB
]),
ok.
@ -578,6 +633,7 @@ session_drop(DSSessionId) ->
transaction(fun() ->
ok = session_drop_subscriptions(DSSessionId),
ok = session_drop_pubranges(DSSessionId),
ok = session_drop_markers(DSSessionId),
ok = session_drop_streams(DSSessionId),
ok = mnesia:delete(?SESSION_TAB, DSSessionId, write)
end).
@ -669,6 +725,17 @@ session_read_pubranges(DSSessionId, LockKind) ->
),
mnesia:select(?SESSION_PUBRANGE_TAB, MS, LockKind).
%% Same as `session_read_markers/2` with the `read` lock kind.
session_read_markers(DSSessionID) ->
session_read_markers(DSSessionID, read).
%% Selects, under the given lock kind, the ids (`{SessionId, Name}` tuples) of
%% all marker records in `?SESSION_MARKER_TAB` that belong to the session.
%% Note that the ids are returned, not the full records.
session_read_markers(DSSessionId, LockKind) ->
MS = ets:fun2ms(
fun(#ds_marker{id = {Sess, Name}}) when Sess =:= DSSessionId ->
{DSSessionId, Name}
end
),
mnesia:select(?SESSION_MARKER_TAB, MS, LockKind).
-spec new_subscription_id(id(), topic_filter()) -> {subscription_id(), integer()}.
new_subscription_id(DSSessionId, TopicFilter) ->
%% Note: here we use _milliseconds_ to match with the timestamp
@ -778,6 +845,17 @@ session_drop_pubranges(DSSessionId) ->
RangeIds
).
%% Deletes every marker of the session from `?SESSION_MARKER_TAB`.
%% Must be called inside a transaction.
-spec session_drop_markers(id()) -> ok.
session_drop_markers(DSSessionId) ->
    DropMarker = fun(MarkerId) ->
        mnesia:delete(?SESSION_MARKER_TAB, MarkerId, write)
    end,
    lists:foreach(DropMarker, session_read_markers(DSSessionId, write)).
%%--------------------------------------------------------------------------------
transaction(Fun) ->

View File

@ -22,8 +22,12 @@
-define(SESSION_SUBSCRIPTIONS_TAB, emqx_ds_session_subscriptions).
-define(SESSION_STREAM_TAB, emqx_ds_stream_tab).
-define(SESSION_PUBRANGE_TAB, emqx_ds_pubrange_tab).
-define(SESSION_MARKER_TAB, emqx_ds_marker_tab).
-define(DS_MRIA_SHARD, emqx_ds_session_shard).
%% Integer tags for `misc` maps keys.
-define(T_tracks, 1).
-record(ds_sub, {
id :: emqx_persistent_session_ds:subscription_id(),
start_time :: emqx_ds:time(),
@ -64,10 +68,27 @@
%% message in the range.
iterator :: emqx_ds:iterator(),
%% Reserved for future use.
misc = #{} :: map()
misc = #{} :: #{
%% What commit tracks this range is part of.
%% This is rarely stored: we only need to persist it when the range
%% contains QoS 2 messages.
?T_tracks => non_neg_integer(),
_ => _
}
}).
-type ds_pubrange() :: #ds_pubrange{}.
%% A named marker of a session: a sequence number position keyed by
%% session id and marker name.
-record(ds_marker, {
id :: {
%% What session this marker belongs to.
_Session :: emqx_persistent_session_ds:id(),
%% Marker name.
_MarkerName
},
%% Where this marker is pointing to: the first seqno that is not marked.
until :: emqx_persistent_message_ds_replayer:seqno()
}).
-record(session, {
%% same as clientid
id :: emqx_persistent_session_ds:id(),

View File

@ -233,7 +233,7 @@ t_session_subscription_iterators(Config) ->
),
ok.
t_qos0(Config) ->
t_qos0(_Config) ->
Sub = connect(<<?MODULE_STRING "1">>, true, 30),
Pub = connect(<<?MODULE_STRING "2">>, true, 0),
try
@ -258,7 +258,7 @@ t_qos0(Config) ->
emqtt:stop(Pub)
end.
t_publish_as_persistent(Config) ->
t_publish_as_persistent(_Config) ->
Sub = connect(<<?MODULE_STRING "1">>, true, 30),
Pub = connect(<<?MODULE_STRING "2">>, true, 30),
try
@ -272,9 +272,8 @@ t_publish_as_persistent(Config) ->
?assertMatch(
[
#{qos := 0, topic := <<"t/1">>, payload := <<"1">>},
#{qos := 1, topic := <<"t/1">>, payload := <<"2">>}
%% TODO: QoS 2
%% #{qos := 2, topic := <<"t/1">>, payload := <<"3">>}
#{qos := 1, topic := <<"t/1">>, payload := <<"2">>},
#{qos := 2, topic := <<"t/1">>, payload := <<"3">>}
],
receive_messages(3)
)

View File

@ -17,6 +17,7 @@
-module(emqx_persistent_session_SUITE).
-include_lib("stdlib/include/assert.hrl").
-include_lib("emqx/include/asserts.hrl").
-include_lib("common_test/include/ct.hrl").
-include_lib("snabbkaffe/include/snabbkaffe.hrl").
-include_lib("emqx/include/emqx_mqtt.hrl").
@ -53,10 +54,10 @@ all() ->
groups() ->
TCs = emqx_common_test_helpers:all(?MODULE),
TCsNonGeneric = [t_choose_impl],
TCGroups = [{group, tcp}, {group, quic}, {group, ws}],
[
{persistence_disabled, [{group, no_kill_connection_process}]},
{persistence_enabled, [{group, no_kill_connection_process}]},
{no_kill_connection_process, [], [{group, tcp}, {group, quic}, {group, ws}]},
{persistence_disabled, TCGroups},
{persistence_enabled, TCGroups},
{tcp, [], TCs},
{quic, [], TCs -- TCsNonGeneric},
{ws, [], TCs -- TCsNonGeneric}
@ -74,7 +75,7 @@ init_per_group(persistence_enabled, Config) ->
{persistence, ds}
| Config
];
init_per_group(Group, Config) when Group == tcp ->
init_per_group(tcp, Config) ->
Apps = emqx_cth_suite:start(
[{emqx, ?config(emqx_config, Config)}],
#{work_dir => emqx_cth_suite:work_dir(Config)}
@ -85,7 +86,7 @@ init_per_group(Group, Config) when Group == tcp ->
{group_apps, Apps}
| Config
];
init_per_group(Group, Config) when Group == ws ->
init_per_group(ws, Config) ->
Apps = emqx_cth_suite:start(
[{emqx, ?config(emqx_config, Config)}],
#{work_dir => emqx_cth_suite:work_dir(Config)}
@ -99,7 +100,7 @@ init_per_group(Group, Config) when Group == ws ->
{group_apps, Apps}
| Config
];
init_per_group(Group, Config) when Group == quic ->
init_per_group(quic, Config) ->
Apps = emqx_cth_suite:start(
[
{emqx,
@ -118,11 +119,7 @@ init_per_group(Group, Config) when Group == quic ->
{ssl, true},
{group_apps, Apps}
| Config
];
init_per_group(no_kill_connection_process, Config) ->
[{kill_connection_process, false} | Config];
init_per_group(kill_connection_process, Config) ->
[{kill_connection_process, true} | Config].
].
get_listener_port(Type, Name) ->
case emqx_config:get([listeners, Type, Name, bind]) of
@ -194,6 +191,8 @@ receive_message_loop(Count, Deadline) ->
receive
{publish, Msg} ->
[Msg | receive_message_loop(Count - 1, Deadline)];
{pubrel, Msg} ->
[{pubrel, Msg} | receive_message_loop(Count - 1, Deadline)];
_Other ->
receive_message_loop(Count, Deadline)
after Timeout ->
@ -201,38 +200,43 @@ receive_message_loop(Count, Deadline) ->
end.
maybe_kill_connection_process(ClientId, Config) ->
case ?config(kill_connection_process, Config) of
true ->
Persistence = ?config(persistence, Config),
case emqx_cm:lookup_channels(ClientId) of
[] ->
ok;
[ConnectionPid] when Persistence == ds ->
Ref = monitor(process, ConnectionPid),
ConnectionPid ! die_if_test,
?assertReceive(
{'DOWN', Ref, process, ConnectionPid, Reason} when
Reason == normal orelse Reason == noproc,
3000
),
wait_connection_process_unregistered(ClientId);
_ ->
ok
end.
wait_connection_process_dies(ClientId) ->
case emqx_cm:lookup_channels(ClientId) of
[] ->
ok;
[ConnectionPid] ->
?assert(is_pid(ConnectionPid)),
Ref = monitor(process, ConnectionPid),
ConnectionPid ! die_if_test,
receive
{'DOWN', Ref, process, ConnectionPid, normal} -> ok
after 3000 -> error(process_did_not_die)
end,
wait_for_cm_unregister(ClientId)
end;
false ->
ok
?assertReceive(
{'DOWN', Ref, process, ConnectionPid, Reason} when
Reason == normal orelse Reason == noproc,
3000
),
wait_connection_process_unregistered(ClientId)
end.
%% Polls `emqx_cm` until no channel is registered for `ClientId`.
%% Performs at most 100 lookups, sleeping 100 ms after each lookup that still
%% finds a channel, and fails with `cm_did_not_unregister` once the attempts
%% are used up.
wait_for_cm_unregister(ClientId) ->
    wait_for_cm_unregister(ClientId, 100).

wait_for_cm_unregister(_ClientId, 0) ->
    error(cm_did_not_unregister);
wait_for_cm_unregister(ClientId, AttemptsLeft) ->
    Channels = emqx_cm:lookup_channels(ClientId),
    recheck_cm_unregister(Channels, ClientId, AttemptsLeft).

%% Decides on the outcome of a single lookup: done when nothing is registered,
%% otherwise pause and spend one more attempt.
recheck_cm_unregister([], _ClientId, _AttemptsLeft) ->
    ok;
recheck_cm_unregister([_], ClientId, AttemptsLeft) ->
    timer:sleep(100),
    wait_for_cm_unregister(ClientId, AttemptsLeft - 1).
%% Asserts that `emqx_cm` has no channel registered for `ClientId`.
%% NOTE(review): `?retry` presumably re-evaluates the assertion up to
%% `_Retries` times, pausing `_Timeout` ms between attempts — confirm against
%% the macro definition in `emqx/include/asserts.hrl`.
wait_connection_process_unregistered(ClientId) ->
?retry(
_Timeout = 100,
_Retries = 20,
?assertEqual([], emqx_cm:lookup_channels(ClientId))
).
%% Convenience wrapper: delegates to `messages/3`, passing `?QOS_2` as the
%% third argument.
messages(Topic, Payloads) ->
messages(Topic, Payloads, ?QOS_2).
@ -272,23 +276,7 @@ do_publish(Messages = [_ | _], PublishFun, WaitForUnregister) ->
lists:foreach(fun(Message) -> PublishFun(Client, Message) end, Messages),
ok = emqtt:disconnect(Client),
%% Snabbkaffe sometimes fails unless all processes are gone.
case WaitForUnregister of
false ->
ok;
true ->
case emqx_cm:lookup_channels(ClientID) of
[] ->
ok;
[ConnectionPid] ->
?assert(is_pid(ConnectionPid)),
Ref1 = monitor(process, ConnectionPid),
receive
{'DOWN', Ref1, process, ConnectionPid, _} -> ok
after 3000 -> error(process_did_not_die)
end,
wait_for_cm_unregister(ClientID)
end
end
WaitForUnregister andalso wait_connection_process_dies(ClientID)
end
),
receive
@ -438,7 +426,7 @@ t_cancel_on_disconnect(Config) ->
{ok, _} = emqtt:ConnFun(Client1),
ok = emqtt:disconnect(Client1, 0, #{'Session-Expiry-Interval' => 0}),
wait_for_cm_unregister(ClientId),
wait_connection_process_unregistered(ClientId),
{ok, Client2} = emqtt:start_link([
{clientid, ClientId},
@ -470,7 +458,7 @@ t_persist_on_disconnect(Config) ->
%% Strangely enough, the disconnect is reported as successful by emqtt.
ok = emqtt:disconnect(Client1, 0, #{'Session-Expiry-Interval' => 30}),
wait_for_cm_unregister(ClientId),
wait_connection_process_unregistered(ClientId),
{ok, Client2} = emqtt:start_link([
{clientid, ClientId},
@ -582,7 +570,7 @@ t_publish_many_while_client_is_gone_qos1(Config) ->
{clientid, ClientId},
{properties, #{'Session-Expiry-Interval' => 30}},
{clean_start, true},
{auto_ack, false}
{auto_ack, never}
| Config
]),
{ok, _} = emqtt:ConnFun(Client1),
@ -629,8 +617,7 @@ t_publish_many_while_client_is_gone_qos1(Config) ->
?assertEqual(
get_topicwise_order(Pubs1),
get_topicwise_order(Msgs1),
Msgs1
get_topicwise_order(Msgs1)
),
NAcked = 4,
@ -688,21 +675,6 @@ t_publish_many_while_client_is_gone_qos1(Config) ->
ok = emqtt:disconnect(Client2).
%% Groups message payloads by topic. Within each topic the payloads keep the
%% order in which the messages appear in `Msgs`.
get_topicwise_order(Msgs) ->
    %% Folding from the right while prepending keeps the original order.
    lists:foldr(
        fun(Msg, Groups) ->
            Topic = get_msgpub_topic(Msg),
            Payload = get_msgpub_payload(Msg),
            maps:update_with(Topic, fun(Rest) -> [Payload | Rest] end, [Payload], Groups)
        end,
        #{},
        Msgs
    ).
%% Extracts the topic from either a received message map or an `#mqtt_msg{}`
%% record.
get_msgpub_topic(#{topic := T}) ->
    T;
get_msgpub_topic(#mqtt_msg{topic = T}) ->
    T.
%% Extracts the payload from either a received message map or an `#mqtt_msg{}`
%% record.
get_msgpub_payload(#{payload := P}) ->
    P;
get_msgpub_payload(#mqtt_msg{payload = P}) ->
    P.
%% Two-argument form of the test case: `'end'` is a no-op, `init` defers to
%% `skip_ds_tc/1`.
t_publish_while_client_is_gone('end', _Config) ->
    ok;
t_publish_while_client_is_gone(init, Config) ->
    skip_ds_tc(Config).
t_publish_while_client_is_gone(Config) ->
%% A persistent session should receive messages in its
%% subscription even if the process owning the session dies.
@ -745,6 +717,157 @@ t_publish_while_client_is_gone(Config) ->
ok = emqtt:disconnect(Client2).
%% Exercises replay of a persistent session over two reconnects with mixed
%% QoS 1 / QoS 2 traffic. The client never acknowledges automatically
%% (`auto_ack` is `never`), so every PUBACK / PUBREC / PUBCOMP below is sent
%% explicitly by the test.
t_publish_many_while_client_is_gone(Config) ->
%% A persistent session should receive all of the still unacked messages
%% for its subscriptions after the client dies or reconnects, in addition
%% to PUBRELs for the messages it has PUBRECed. While the client must send
%% PUBACKs and PUBRECs in order, those orders are independent of each other.
ClientId = ?config(client_id, Config),
ConnFun = ?config(conn_fun, Config),
ClientOpts = [
{proto_ver, v5},
{clientid, ClientId},
{properties, #{'Session-Expiry-Interval' => 30}},
{auto_ack, never}
| Config
],
{ok, Client1} = emqtt:start_link([{clean_start, true} | ClientOpts]),
{ok, _} = emqtt:ConnFun(Client1),
{ok, _, [?QOS_1]} = emqtt:subscribe(Client1, <<"t/+/foo">>, ?QOS_1),
{ok, _, [?QOS_2]} = emqtt:subscribe(Client1, <<"msg/feed/#">>, ?QOS_2),
{ok, _, [?QOS_2]} = emqtt:subscribe(Client1, <<"loc/+/+/+">>, ?QOS_2),
%% Note that M5 is published at QoS 2 to a topic that is subscribed at QoS 1.
Pubs1 = [
#mqtt_msg{topic = <<"t/42/foo">>, payload = <<"M1">>, qos = 1},
#mqtt_msg{topic = <<"t/42/foo">>, payload = <<"M2">>, qos = 1},
#mqtt_msg{topic = <<"msg/feed/me">>, payload = <<"M3">>, qos = 2},
#mqtt_msg{topic = <<"loc/1/2/42">>, payload = <<"M4">>, qos = 2},
#mqtt_msg{topic = <<"t/100/foo">>, payload = <<"M5">>, qos = 2},
#mqtt_msg{topic = <<"t/100/foo">>, payload = <<"M6">>, qos = 1},
#mqtt_msg{topic = <<"loc/3/4/5">>, payload = <<"M7">>, qos = 2},
#mqtt_msg{topic = <<"t/100/foo">>, payload = <<"M8">>, qos = 1},
#mqtt_msg{topic = <<"msg/feed/me">>, payload = <<"M9">>, qos = 2}
],
ok = publish_many(Pubs1),
NPubs1 = length(Pubs1),
Msgs1 = receive_messages(NPubs1),
ct:pal("Msgs1 = ~p", [Msgs1]),
NMsgs1 = length(Msgs1),
?assertEqual(NPubs1, NMsgs1),
%% Per-topic payload order of the received messages must match the published one.
?assertEqual(
get_topicwise_order(Pubs1),
get_topicwise_order(Msgs1)
),
%% PUBACK every QoS 1 message.
lists:foreach(
fun(PktId) -> ok = emqtt:puback(Client1, PktId) end,
[PktId || #{qos := 1, packet_id := PktId} <- Msgs1]
),
%% PUBREC first `NRecs` QoS 2 messages.
NRecs = 3,
PubRecs1 = lists:sublist([PktId || #{qos := 2, packet_id := PktId} <- Msgs1], NRecs),
lists:foreach(
fun(PktId) -> ok = emqtt:pubrec(Client1, PktId) end,
PubRecs1
),
%% Ensure that PUBACKs / PUBRECs are propagated to the channel.
pong = emqtt:ping(Client1),
%% Receive PUBRELs for the sent PUBRECs.
PubRels1 = receive_messages(NRecs),
ct:pal("PubRels1 = ~p", [PubRels1]),
?assertEqual(
PubRecs1,
[PktId || {pubrel, #{packet_id := PktId}} <- PubRels1],
PubRels1
),
%% First disconnect: no PUBCOMP has been sent for any of the PUBRELs yet.
ok = emqtt:disconnect(Client1),
maybe_kill_connection_process(ClientId, Config),
%% Publish more messages while the client is away.
Pubs2 = [
#mqtt_msg{topic = <<"loc/3/4/5">>, payload = <<"M10">>, qos = 2},
#mqtt_msg{topic = <<"t/100/foo">>, payload = <<"M11">>, qos = 1},
#mqtt_msg{topic = <<"msg/feed/friend">>, payload = <<"M12">>, qos = 2}
],
ok = publish_many(Pubs2),
NPubs2 = length(Pubs2),
%% Reconnect with the same client id, keeping the session (`clean_start` is `false`).
{ok, Client2} = emqtt:start_link([{clean_start, false} | ClientOpts]),
{ok, _} = emqtt:ConnFun(Client2),
%% Try to receive _at most_ `NPubs` messages.
%% There shouldn't be that many unacked messages in the replay anyway,
%% but it's an easy number to pick.
NPubs = NPubs1 + NPubs2,
Msgs2 = receive_messages(NPubs, _Timeout = 2000),
ct:pal("Msgs2 = ~p", [Msgs2]),
%% We should again receive PUBRELs for the PUBRECs we sent earlier.
?assertEqual(
get_msgs_essentials(PubRels1),
[get_msg_essentials(PubRel) || PubRel = {pubrel, _} <- Msgs2]
),
%% We should receive duplicates only for QoS 2 messages where PUBRELs were
%% not sent, in the same order as the original messages.
Msgs2Dups = [get_msg_essentials(M) || M = #{dup := true} <- Msgs2],
?assertEqual(
Msgs2Dups,
[M || M = #{qos := 2} <- Msgs2Dups]
),
?assertEqual(
get_msgs_essentials(pick_respective_msgs(Msgs2Dups, Msgs1)),
Msgs2Dups
),
%% Now complete all yet incomplete QoS 2 message flows instead.
%% `{pubrel, _}` entries of `Msgs2` are not maps, so the generator pattern skips them.
PubRecs2 = [PktId || #{qos := 2, packet_id := PktId} <- Msgs2],
lists:foreach(
fun(PktId) -> ok = emqtt:pubrec(Client2, PktId) end,
PubRecs2
),
PubRels2 = receive_messages(length(PubRecs2)),
ct:pal("PubRels2 = ~p", [PubRels2]),
?assertEqual(
PubRecs2,
[PktId || {pubrel, #{packet_id := PktId}} <- PubRels2],
PubRels2
),
%% PUBCOMP every PUBREL.
PubComps = [PktId || {pubrel, #{packet_id := PktId}} <- PubRels1 ++ PubRels2],
lists:foreach(
fun(PktId) -> ok = emqtt:pubcomp(Client2, PktId) end,
PubComps
),
%% Ensure that PUBCOMPs are propagated to the channel.
pong = emqtt:ping(Client2),
%% Second disconnect: no PUBACK was sent on `Client2`, so M11 is still unacked.
ok = emqtt:disconnect(Client2),
maybe_kill_connection_process(ClientId, Config),
{ok, Client3} = emqtt:start_link([{clean_start, false} | ClientOpts]),
{ok, _} = emqtt:ConnFun(Client3),
%% Only the last unacked QoS 1 message should be retransmitted.
Msgs3 = receive_messages(NPubs, _Timeout = 2000),
ct:pal("Msgs3 = ~p", [Msgs3]),
?assertMatch(
[#{topic := <<"t/100/foo">>, payload := <<"M11">>, qos := 1, dup := true}],
Msgs3
),
ok = emqtt:disconnect(Client3).
t_clean_start_drops_subscriptions(Config) ->
%% 1. A persistent session is started and disconnected.
%% 2. While disconnected, a message is published and persisted.
@ -795,6 +918,7 @@ t_clean_start_drops_subscriptions(Config) ->
[Msg1] = receive_messages(1),
?assertEqual({ok, iolist_to_binary(Payload2)}, maps:find(payload, Msg1)),
pong = emqtt:ping(Client2),
ok = emqtt:disconnect(Client2),
maybe_kill_connection_process(ClientId, Config),
@ -812,6 +936,7 @@ t_clean_start_drops_subscriptions(Config) ->
[Msg2] = receive_messages(1),
?assertEqual({ok, iolist_to_binary(Payload3)}, maps:find(payload, Msg2)),
pong = emqtt:ping(Client3),
ok = emqtt:disconnect(Client3).
t_unsubscribe(Config) ->
@ -875,6 +1000,30 @@ t_multiple_subscription_matches(Config) ->
?assertEqual({ok, 2}, maps:find(qos, Msg2)),
ok = emqtt:disconnect(Client2).
%% Groups message payloads by topic. Within each topic the payloads keep the
%% order in which the messages appear in `Msgs`.
get_topicwise_order(Msgs) ->
    %% Folding from the right while prepending keeps the original order.
    lists:foldr(
        fun(Msg, Groups) ->
            Topic = get_msgpub_topic(Msg),
            Payload = get_msgpub_payload(Msg),
            maps:update_with(Topic, fun(Rest) -> [Payload | Rest] end, [Payload], Groups)
        end,
        #{},
        Msgs
    ).
%% Extracts the topic from either a received message map or an `#mqtt_msg{}`
%% record.
get_msgpub_topic(#{topic := T}) ->
    T;
get_msgpub_topic(#mqtt_msg{topic = T}) ->
    T.
%% Extracts the payload from either a received message map or an `#mqtt_msg{}`
%% record.
get_msgpub_payload(#{payload := P}) ->
    P;
get_msgpub_payload(#mqtt_msg{payload = P}) ->
    P.
%% Reduces a received item to the fields the tests compare on:
%% * `{pubrel, Map}` keeps only `packet_id` and `reason_code`;
%% * a plain message map keeps only `packet_id`, `topic`, `payload` and `qos`.
get_msg_essentials({pubrel, PubRel}) ->
    Essentials = maps:with([packet_id, reason_code], PubRel),
    {pubrel, Essentials};
get_msg_essentials(Msg) when is_map(Msg) ->
    Keys = [packet_id, topic, payload, qos],
    maps:with(Keys, Msg).
%% Applies `get_msg_essentials/1` to every element of `Msgs`, preserving order.
get_msgs_essentials(Msgs) ->
    lists:map(fun get_msg_essentials/1, Msgs).
%% Selects from `Msgs` the messages whose `packet_id` equals the `packet_id`
%% of some element of `MsgRefs`. The result follows the order of `Msgs`; a
%% message appears once per matching reference.
pick_respective_msgs(MsgRefs, Msgs) ->
    SamePacketId = fun(Msg, Ref) ->
        maps:get(packet_id, Msg) =:= maps:get(packet_id, Ref)
    end,
    lists:flatmap(
        fun(Msg) -> [Msg || Ref <- MsgRefs, SamePacketId(Msg, Ref)] end,
        Msgs
    ).
skip_ds_tc(Config) ->
case ?config(persistence, Config) of
ds ->