Merge pull request #11958 from keynslug/ft/EMQX-9745/preserve-acks

feat(sessds): preserve acks / replays in session state
Andrew Mayorov 2023-11-20 23:40:22 +07:00 committed by GitHub
commit 8e107ffe45
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 626 additions and 337 deletions

View File

@ -11,12 +11,6 @@
-include_lib("snabbkaffe/include/snabbkaffe.hrl").
-include_lib("emqx/include/emqx_mqtt.hrl").
-include_lib("emqx/src/emqx_persistent_session_ds.hrl").
-define(DEFAULT_KEYSPACE, default).
-define(DS_SHARD_ID, <<"local">>).
-define(DS_SHARD, {?DEFAULT_KEYSPACE, ?DS_SHARD_ID}).
-import(emqx_common_test_helpers, [on_exit/1]).
%%------------------------------------------------------------------------------
@ -92,12 +86,6 @@ get_mqtt_port(Node, Type) ->
{_IP, Port} = erpc:call(Node, emqx_config, get, [[listeners, Type, default, bind]]),
Port.
get_all_iterator_ids(Node) ->
Fn = fun(K, _V, Acc) -> [K | Acc] end,
erpc:call(Node, fun() ->
emqx_ds_storage_layer:foldl_iterator_prefix(?DS_SHARD, <<>>, Fn, [])
end).
wait_nodeup(Node) ->
?retry(
_Sleep0 = 500,
@ -233,9 +221,8 @@ t_session_subscription_idempotency(Config) ->
end,
fun(Trace) ->
ct:pal("trace:\n ~p", [Trace]),
SubTopicFilterWords = emqx_topic:words(SubTopicFilter),
?assertMatch(
{ok, #{}, #{SubTopicFilterWords := #{}}},
#{subscriptions := #{SubTopicFilter := #{}}},
erpc:call(Node1, emqx_persistent_session_ds, session_open, [ClientId])
)
end
@ -308,7 +295,7 @@ t_session_unsubscription_idempotency(Config) ->
fun(Trace) ->
ct:pal("trace:\n ~p", [Trace]),
?assertMatch(
{ok, #{}, Subs = #{}} when map_size(Subs) =:= 0,
#{subscriptions := Subs = #{}} when map_size(Subs) =:= 0,
erpc:call(Node1, emqx_persistent_session_ds, session_open, [ClientId])
),
ok
@ -370,18 +357,12 @@ do_t_session_discard(Params) ->
_Attempts0 = 50,
true = map_size(emqx_persistent_session_ds:list_all_streams()) > 0
),
?retry(
_Sleep0 = 100,
_Attempts0 = 50,
true = map_size(emqx_persistent_session_ds:list_all_iterators()) > 0
),
ok = emqtt:stop(Client0),
?tp(notice, "disconnected", #{}),
?tp(notice, "reconnecting", #{}),
%% we still have iterators and streams
%% we still have streams
?assert(map_size(emqx_persistent_session_ds:list_all_streams()) > 0),
?assert(map_size(emqx_persistent_session_ds:list_all_iterators()) > 0),
Client1 = start_client(ReconnectOpts),
{ok, _} = emqtt:connect(Client1),
?assertEqual([], emqtt:subscriptions(Client1)),
@ -394,7 +375,7 @@ do_t_session_discard(Params) ->
?assertEqual(#{}, emqx_persistent_session_ds:list_all_subscriptions()),
?assertEqual([], emqx_persistent_session_ds_router:topics()),
?assertEqual(#{}, emqx_persistent_session_ds:list_all_streams()),
?assertEqual(#{}, emqx_persistent_session_ds:list_all_iterators()),
?assertEqual(#{}, emqx_persistent_session_ds:list_all_pubranges()),
ok = emqtt:stop(Client1),
?tp(notice, "disconnected", #{}),

View File

@ -19,12 +19,12 @@
-module(emqx_persistent_message_ds_replayer).
%% API:
-export([new/0, next_packet_id/1, replay/2, commit_offset/3, poll/3, n_inflight/1]).
-export([new/0, open/1, next_packet_id/1, replay/1, commit_offset/3, poll/3, n_inflight/1]).
%% internal exports:
-export([]).
-export_type([inflight/0]).
-export_type([inflight/0, seqno/0]).
-include_lib("emqx/include/logger.hrl").
-include("emqx_persistent_session_ds.hrl").
@ -41,19 +41,11 @@
%% Note: sequence numbers are monotonic; they don't wrap around:
-type seqno() :: non_neg_integer().
-record(range, {
stream :: emqx_ds:stream(),
first :: seqno(),
last :: seqno(),
iterator_next :: emqx_ds:iterator() | undefined
}).
-type range() :: #range{}.
-record(inflight, {
next_seqno = 0 :: seqno(),
acked_seqno = 0 :: seqno(),
offset_ranges = [] :: [range()]
next_seqno = 1 :: seqno(),
acked_until = 1 :: seqno(),
%% Ranges are sorted in ascending order of their sequence numbers.
offset_ranges = [] :: [ds_pubrange()]
}).
-opaque inflight() :: #inflight{}.
@ -66,34 +58,37 @@
new() ->
#inflight{}.
-spec open(emqx_persistent_session_ds:id()) -> inflight().
open(SessionId) ->
Ranges = ro_transaction(fun() -> get_ranges(SessionId) end),
{AckedUntil, NextSeqno} = compute_inflight_range(Ranges),
#inflight{
acked_until = AckedUntil,
next_seqno = NextSeqno,
offset_ranges = Ranges
}.
-spec next_packet_id(inflight()) -> {emqx_types:packet_id(), inflight()}.
next_packet_id(Inflight0 = #inflight{next_seqno = LastSeqNo}) ->
Inflight = Inflight0#inflight{next_seqno = LastSeqNo + 1},
case LastSeqNo rem 16#10000 of
0 ->
%% We skip sequence numbers that lead to PacketId = 0 to
%% simplify math. Note: it leads to occasional gaps in the
%% sequence numbers.
next_packet_id(Inflight);
PacketId ->
{PacketId, Inflight}
end.
next_packet_id(Inflight0 = #inflight{next_seqno = LastSeqno}) ->
Inflight = Inflight0#inflight{next_seqno = next_seqno(LastSeqno)},
{seqno_to_packet_id(LastSeqno), Inflight}.
-spec n_inflight(inflight()) -> non_neg_integer().
n_inflight(#inflight{next_seqno = NextSeqNo, acked_seqno = AckedSeqno}) ->
%% NOTE: this function assumes that gaps in the sequence ID occur
%% _only_ when the packet ID wraps:
case AckedSeqno >= ((NextSeqNo bsr 16) bsl 16) of
true ->
NextSeqNo - AckedSeqno;
false ->
NextSeqNo - AckedSeqno - 1
end.
n_inflight(#inflight{next_seqno = NextSeqno, acked_until = AckedUntil}) ->
range_size(AckedUntil, NextSeqno).
-spec replay(emqx_persistent_session_ds:id(), inflight()) ->
emqx_session:replies().
replay(_SessionId, _Inflight = #inflight{offset_ranges = _Ranges}) ->
[].
-spec replay(inflight()) ->
{emqx_session:replies(), inflight()}.
replay(Inflight0 = #inflight{acked_until = AckedUntil, offset_ranges = Ranges0}) ->
{Ranges, Replies} = lists:mapfoldr(
fun(Range, Acc) ->
replay_range(Range, AckedUntil, Acc)
end,
[],
Ranges0
),
Inflight = Inflight0#inflight{offset_ranges = Ranges},
{Replies, Inflight}.
-spec commit_offset(emqx_persistent_session_ds:id(), emqx_types:packet_id(), inflight()) ->
{_IsValidOffset :: boolean(), inflight()}.
@ -101,47 +96,34 @@ commit_offset(
SessionId,
PacketId,
Inflight0 = #inflight{
acked_seqno = AckedSeqno0, next_seqno = NextSeqNo, offset_ranges = Ranges0
acked_until = AckedUntil, next_seqno = NextSeqno
}
) ->
AckedSeqno =
case packet_id_to_seqno(NextSeqNo, PacketId) of
N when N > AckedSeqno0; AckedSeqno0 =:= 0 ->
N;
OutOfRange ->
?SLOG(warning, #{
msg => "out-of-order_ack",
prev_seqno => AckedSeqno0,
acked_seqno => OutOfRange,
next_seqno => NextSeqNo,
packet_id => PacketId
}),
AckedSeqno0
end,
Ranges = lists:filter(
fun(#range{stream = Stream, last = LastSeqno, iterator_next = ItNext}) ->
case LastSeqno =< AckedSeqno of
true ->
%% This range has been fully
%% acked. Remove it and replace saved
%% iterator with the trailing iterator.
update_iterator(SessionId, Stream, ItNext),
false;
false ->
%% This range still has unacked
%% messages:
true
end
end,
Ranges0
),
Inflight = Inflight0#inflight{acked_seqno = AckedSeqno, offset_ranges = Ranges},
{true, Inflight}.
case packet_id_to_seqno(NextSeqno, PacketId) of
Seqno when Seqno >= AckedUntil andalso Seqno < NextSeqno ->
%% TODO
%% We do not preserve `acked_until` in the database. Instead, we discard
%% fully acked ranges from the database. In effect, this means that the
%% most recent `acked_until` the client has sent may be lost in case of a
%% crash or client loss.
Inflight1 = Inflight0#inflight{acked_until = next_seqno(Seqno)},
Inflight = discard_acked(SessionId, Inflight1),
{true, Inflight};
OutOfRange ->
?SLOG(warning, #{
msg => "out-of-order_ack",
acked_until => AckedUntil,
acked_seqno => OutOfRange,
next_seqno => NextSeqno,
packet_id => PacketId
}),
{false, Inflight0}
end.
-spec poll(emqx_persistent_session_ds:id(), inflight(), pos_integer()) ->
{emqx_session:replies(), inflight()}.
poll(SessionId, Inflight0, WindowSize) when WindowSize > 0, WindowSize < 16#7fff ->
#inflight{next_seqno = NextSeqNo0, acked_seqno = AckedSeqno} =
#inflight{next_seqno = NextSeqNo0, acked_until = AckedSeqno} =
Inflight0,
FetchThreshold = max(1, WindowSize div 2),
FreeSpace = AckedSeqno + WindowSize - NextSeqNo0,
@ -153,6 +135,7 @@ poll(SessionId, Inflight0, WindowSize) when WindowSize > 0, WindowSize < 16#7fff
%% client get stuck even?
{[], Inflight0};
true ->
%% TODO: Wrap this in `mria:async_dirty/2`?
Streams = shuffle(get_streams(SessionId)),
fetch(SessionId, Inflight0, Streams, FreeSpace, [])
end.
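An illustrative sketch of the window arithmetic above (not part of the diff; the function name and concrete numbers are invented): with a window of 10, three messages in flight and everything below seqno 6 acked, there is room for seven more unacked messages, comfortably above the fetch threshold of five.

%% Mirrors the flow-control computation in poll/3 above.
window_math_example() ->
    WindowSize = 10,
    NextSeqno = 9,
    AckedUntil = 6,
    FetchThreshold = max(1, WindowSize div 2),
    %% Free space left in the client's receive window:
    FreeSpace = AckedUntil + WindowSize - NextSeqno,
    {FreeSpace, FetchThreshold}. % {7, 5}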
@ -165,75 +148,188 @@ poll(SessionId, Inflight0, WindowSize) when WindowSize > 0, WindowSize < 16#7fff
%% Internal functions
%%================================================================================
fetch(_SessionId, Inflight, _Streams = [], _N, Acc) ->
{lists:reverse(Acc), Inflight};
fetch(_SessionId, Inflight, _Streams, 0, Acc) ->
{lists:reverse(Acc), Inflight};
fetch(SessionId, Inflight0, [Stream | Streams], N, Publishes0) ->
#inflight{next_seqno = FirstSeqNo, offset_ranges = Ranges0} = Inflight0,
ItBegin = get_last_iterator(SessionId, Stream, Ranges0),
compute_inflight_range([]) ->
{1, 1};
compute_inflight_range(Ranges) ->
_RangeLast = #ds_pubrange{until = LastSeqno} = lists:last(Ranges),
RangesUnacked = lists:dropwhile(
fun(#ds_pubrange{type = T}) -> T == checkpoint end,
Ranges
),
case RangesUnacked of
[#ds_pubrange{id = {_, AckedUntil}} | _] ->
{AckedUntil, LastSeqno};
[] ->
{LastSeqno, LastSeqno}
end.
-spec get_ranges(emqx_persistent_session_ds:id()) -> [ds_pubrange()].
get_ranges(SessionId) ->
Pat = erlang:make_tuple(
record_info(size, ds_pubrange),
'_',
[{1, ds_pubrange}, {#ds_pubrange.id, {SessionId, '_'}}]
),
mnesia:match_object(?SESSION_PUBRANGE_TAB, Pat, read).
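The erlang:make_tuple/3 call above is merely a compact way to build a mnesia:match_object/3 pattern without naming every field. A sketch of the equivalent record expression (hypothetical helper, same semantics):

%% Equivalent pattern, spelled as a record expression: the key's
%% session part is fixed, everything else is a wildcard.
get_ranges_pat(SessionId) ->
    #ds_pubrange{id = {SessionId, '_'}, _ = '_'}.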
fetch(SessionId, Inflight0, [DSStream | Streams], N, Acc) when N > 0 ->
#inflight{next_seqno = FirstSeqno, offset_ranges = Ranges} = Inflight0,
ItBegin = get_last_iterator(DSStream, Ranges),
{ok, ItEnd, Messages} = emqx_ds:next(?PERSISTENT_MESSAGE_DB, ItBegin, N),
{NMessages, Publishes, Inflight1} =
lists:foldl(
fun(Msg, {N0, PubAcc0, InflightAcc0}) ->
{PacketId, InflightAcc} = next_packet_id(InflightAcc0),
PubAcc = [{PacketId, Msg} | PubAcc0],
{N0 + 1, PubAcc, InflightAcc}
end,
{0, Publishes0, Inflight0},
Messages
),
#inflight{next_seqno = LastSeqNo} = Inflight1,
case NMessages > 0 of
true ->
Range = #range{
first = FirstSeqNo,
last = LastSeqNo - 1,
stream = Stream,
iterator_next = ItEnd
{Publishes, UntilSeqno} = publish(FirstSeqno, Messages),
case range_size(FirstSeqno, UntilSeqno) of
Size when Size > 0 ->
%% We need to preserve the iterator pointing to the beginning of the
%% range, so that we can replay it if needed.
Range0 = #ds_pubrange{
id = {SessionId, FirstSeqno},
type = inflight,
until = UntilSeqno,
stream = DSStream#ds_stream.ref,
iterator = ItBegin
},
Inflight = Inflight1#inflight{offset_ranges = Ranges0 ++ [Range]},
fetch(SessionId, Inflight, Streams, N - NMessages, Publishes);
false ->
fetch(SessionId, Inflight1, Streams, N, Publishes)
end.
ok = preserve_range(Range0),
%% ...Yet we need to keep the iterator pointing past the end of the
%% range, so that we can pick up where we left off: it will become
%% `ItBegin` of the next range for this stream.
Range = Range0#ds_pubrange{iterator = ItEnd},
Inflight = Inflight0#inflight{
next_seqno = UntilSeqno,
offset_ranges = Ranges ++ [Range]
},
fetch(SessionId, Inflight, Streams, N - Size, [Publishes | Acc]);
0 ->
fetch(SessionId, Inflight0, Streams, N, Acc)
end;
fetch(_SessionId, Inflight, _Streams, _N, Acc) ->
Publishes = lists:append(lists:reverse(Acc)),
{Publishes, Inflight}.
-spec update_iterator(emqx_persistent_session_ds:id(), emqx_ds:stream(), emqx_ds:iterator()) -> ok.
update_iterator(DSSessionId, Stream, Iterator) ->
%% Workaround: we convert `Stream' to a binary before attempting to store it in
%% mnesia(rocksdb) because of a bug in `mnesia_rocksdb' when trying to do
%% `mnesia:dirty_all_keys' later.
StreamBin = term_to_binary(Stream),
mria:dirty_write(?SESSION_ITER_TAB, #ds_iter{id = {DSSessionId, StreamBin}, iter = Iterator}).
discard_acked(
SessionId,
Inflight0 = #inflight{acked_until = AckedUntil, offset_ranges = Ranges0}
) ->
%% TODO: This could be kept and incrementally updated in the inflight state.
Checkpoints = find_checkpoints(Ranges0),
%% TODO: Wrap this in `mria:async_dirty/2`?
Ranges = discard_acked_ranges(SessionId, AckedUntil, Checkpoints, Ranges0),
Inflight0#inflight{offset_ranges = Ranges}.
get_last_iterator(SessionId, Stream, Ranges) ->
case lists:keyfind(Stream, #range.stream, lists:reverse(Ranges)) of
false ->
get_iterator(SessionId, Stream);
#range{iterator_next = Next} ->
Next
end.
-spec get_iterator(emqx_persistent_session_ds:id(), emqx_ds:stream()) -> emqx_ds:iterator().
get_iterator(DSSessionId, Stream) ->
%% See comment in `update_iterator'.
StreamBin = term_to_binary(Stream),
Id = {DSSessionId, StreamBin},
[#ds_iter{iter = It}] = mnesia:dirty_read(?SESSION_ITER_TAB, Id),
It.
-spec get_streams(emqx_persistent_session_ds:id()) -> [emqx_ds:stream()].
get_streams(SessionId) ->
lists:map(
fun(#ds_stream{stream = Stream}) ->
Stream
find_checkpoints(Ranges) ->
lists:foldl(
fun(#ds_pubrange{stream = StreamRef, until = Until}, Acc) ->
%% For each stream, remember the last range over this stream.
Acc#{StreamRef => Until}
end,
mnesia:dirty_read(?SESSION_STREAM_TAB, SessionId)
#{},
Ranges
).
discard_acked_ranges(
SessionId,
AckedUntil,
Checkpoints,
[Range = #ds_pubrange{until = Until, stream = StreamRef} | Rest]
) when Until =< AckedUntil ->
%% This range has been fully acked.
%% Either discard it completely, or preserve the iterator for the next range
%% over this stream (i.e. a checkpoint).
RangeKept =
case maps:get(StreamRef, Checkpoints) of
CP when CP > Until ->
discard_range(Range),
[];
Until ->
[checkpoint_range(Range)]
end,
%% Since we're (intentionally) not using transactions here, it's important to
%% issue database writes in the same order in which ranges are stored: from
%% the oldest to the newest. This is also why we need to compute which ranges
%% should become checkpoints before we start writing anything.
RangeKept ++ discard_acked_ranges(SessionId, AckedUntil, Checkpoints, Rest);
discard_acked_ranges(_SessionId, _AckedUntil, _Checkpoints, Ranges) ->
%% The rest of ranges (if any) still have unacked messages.
Ranges.
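An illustrative sketch of the checkpoint bookkeeping (hypothetical function and values): find_checkpoints/1 records, per stream, the `until` of that stream's newest range; a fully acked range is then discarded unless it is that newest range, in which case it is demoted to a checkpoint so the stream's trailing iterator survives.

find_checkpoints_example() ->
    %% Iterators omitted for brevity.
    Ranges = [
        #ds_pubrange{id = {<<"s">>, 1}, until = 5, stream = 1, type = inflight},
        #ds_pubrange{id = {<<"s">>, 5}, until = 9, stream = 1, type = inflight},
        #ds_pubrange{id = {<<"s">>, 9}, until = 12, stream = 2, type = inflight}
    ],
    %% Stream 1's newest range ends at 9; stream 2's at 12. If all three
    %% ranges were acked, the first would be discarded (9 > 5) and the
    %% other two would become checkpoints.
    #{1 := 9, 2 := 12} = find_checkpoints(Ranges),
    ok.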
replay_range(
Range0 = #ds_pubrange{type = inflight, id = {_, First}, until = Until, iterator = It},
AckedUntil,
Acc
) ->
Size = range_size(First, Until),
FirstUnacked = max(First, AckedUntil),
{ok, ItNext, Messages} = emqx_ds:next(?PERSISTENT_MESSAGE_DB, It, Size),
MessagesUnacked =
case FirstUnacked of
First ->
Messages;
_ ->
lists:nthtail(range_size(First, FirstUnacked), Messages)
end,
MessagesReplay = [emqx_message:set_flag(dup, true, Msg) || Msg <- MessagesUnacked],
%% The match on `Until` below asserts that the range is consistent
%% with the message storage state.
{Replies, Until} = publish(FirstUnacked, MessagesReplay),
%% Again, we need to keep the iterator pointing past the end of the
%% range, so that we can pick up where we left off.
Range = Range0#ds_pubrange{iterator = ItNext},
{Range, Replies ++ Acc};
replay_range(Range0 = #ds_pubrange{type = checkpoint}, _AckedUntil, Acc) ->
{Range0, Acc}.
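A small worked example of the prefix trimming above (hypothetical function and numbers): the range covers seqnos [5, 10) and the client has acked everything below 7. Within a single packet-id epoch, range_size/2 is a plain difference, so two messages are dropped and three are replayed (with the dup flag set in the real code).

replay_trim_example() ->
    Messages = [msg5, msg6, msg7, msg8, msg9],
    First = 5,
    AckedUntil = 7,
    FirstUnacked = max(First, AckedUntil),
    %% Drop the already-acked prefix, keep the tail for replay:
    [msg7, msg8, msg9] = lists:nthtail(FirstUnacked - First, Messages),
    ok.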
publish(FirstSeqno, Messages) ->
lists:mapfoldl(
fun(Message, Seqno) ->
PacketId = seqno_to_packet_id(Seqno),
{{PacketId, Message}, next_seqno(Seqno)}
end,
FirstSeqno,
Messages
).
-spec preserve_range(ds_pubrange()) -> ok.
preserve_range(Range = #ds_pubrange{type = inflight}) ->
mria:dirty_write(?SESSION_PUBRANGE_TAB, Range).
-spec discard_range(ds_pubrange()) -> ok.
discard_range(#ds_pubrange{id = RangeId}) ->
mria:dirty_delete(?SESSION_PUBRANGE_TAB, RangeId).
-spec checkpoint_range(ds_pubrange()) -> ds_pubrange().
checkpoint_range(Range0 = #ds_pubrange{type = inflight}) ->
Range = Range0#ds_pubrange{type = checkpoint},
ok = mria:dirty_write(?SESSION_PUBRANGE_TAB, Range),
Range;
checkpoint_range(Range = #ds_pubrange{type = checkpoint}) ->
%% This range should have been checkpointed already.
Range.
get_last_iterator(DSStream = #ds_stream{ref = StreamRef}, Ranges) ->
case lists:keyfind(StreamRef, #ds_pubrange.stream, lists:reverse(Ranges)) of
false ->
DSStream#ds_stream.beginning;
#ds_pubrange{iterator = ItNext} ->
ItNext
end.
-spec get_streams(emqx_persistent_session_ds:id()) -> [ds_stream()].
get_streams(SessionId) ->
mnesia:dirty_read(?SESSION_STREAM_TAB, SessionId).
next_seqno(Seqno) ->
NextSeqno = Seqno + 1,
case seqno_to_packet_id(NextSeqno) of
0 ->
%% We skip sequence numbers that lead to PacketId = 0 to
%% simplify math. Note: it leads to occasional gaps in the
%% sequence numbers.
NextSeqno + 1;
_ ->
NextSeqno
end.
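A worked example of the skip above (hypothetical test-style function; it relies only on next_seqno/1 and seqno_to_packet_id/1 defined in this module): seqno 16#FFFF maps to packet id 16#FFFF, but 16#10000 would map to packet id 0, which MQTT forbids, so the counter jumps straight to 16#10001.

seqno_skip_example() ->
    16#10001 = next_seqno(16#FFFF),
    1 = seqno_to_packet_id(16#10001),
    ok.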
%% Reconstruct session counter by adding most significant bits from
%% the current counter to the packet id.
-spec packet_id_to_seqno(non_neg_integer(), emqx_types:packet_id()) -> non_neg_integer().
-spec packet_id_to_seqno(_Next :: seqno(), emqx_types:packet_id()) -> seqno().
packet_id_to_seqno(NextSeqNo, PacketId) ->
Epoch = NextSeqNo bsr 16,
case packet_id_to_seqno_(Epoch, PacketId) of
@ -243,10 +339,20 @@ packet_id_to_seqno(NextSeqNo, PacketId) ->
packet_id_to_seqno_(Epoch - 1, PacketId)
end.
-spec packet_id_to_seqno_(non_neg_integer(), emqx_types:packet_id()) -> non_neg_integer().
-spec packet_id_to_seqno_(non_neg_integer(), emqx_types:packet_id()) -> seqno().
packet_id_to_seqno_(Epoch, PacketId) ->
(Epoch bsl 16) + PacketId.
-spec seqno_to_packet_id(seqno()) -> emqx_types:packet_id() | 0.
seqno_to_packet_id(Seqno) ->
Seqno rem 16#10000.
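A worked example of the epoch reconstruction (hypothetical test-style function, assuming the fallback guard elided by the hunk boundary compares the candidate against the next seqno): with next_seqno = 16#10005 the current epoch is 1. Packet id 3 reconstructs to seqno 16#10003, which is consistent with next_seqno; packet id 16#FFF0 would land past next_seqno in epoch 1, so it is resolved in epoch 0 instead.

packet_id_to_seqno_example() ->
    16#10003 = packet_id_to_seqno(16#10005, 3),
    16#FFF0 = packet_id_to_seqno(16#10005, 16#FFF0),
    ok.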
range_size(FirstSeqno, UntilSeqno) ->
%% This function assumes that gaps in the sequence ID occur _only_ when the
%% packet ID wraps.
Size = UntilSeqno - FirstSeqno,
Size + (FirstSeqno bsr 16) - (UntilSeqno bsr 16).
-spec shuffle([A]) -> [A].
shuffle(L0) ->
L1 = lists:map(
@ -259,6 +365,10 @@ shuffle(L0) ->
{_, L} = lists:unzip(L2),
L.
ro_transaction(Fun) ->
{atomic, Res} = mria:ro_transaction(?DS_MRIA_SHARD, Fun),
Res.
-ifdef(TEST).
%% This test only tests boundary conditions (to make sure property-based test didn't skip them):
@ -311,4 +421,40 @@ seqno_gen(NextSeqNo) ->
Max = max(0, NextSeqNo - 1),
range(Min, Max).
range_size_test_() ->
[
?_assertEqual(0, range_size(42, 42)),
?_assertEqual(1, range_size(42, 43)),
?_assertEqual(1, range_size(16#ffff, 16#10001)),
?_assertEqual(16#ffff - 456 + 123, range_size(16#1f0000 + 456, 16#200000 + 123))
].
compute_inflight_range_test_() ->
[
?_assertEqual(
{1, 1},
compute_inflight_range([])
),
?_assertEqual(
{12, 42},
compute_inflight_range([
#ds_pubrange{id = {<<>>, 1}, until = 2, type = checkpoint},
#ds_pubrange{id = {<<>>, 4}, until = 8, type = checkpoint},
#ds_pubrange{id = {<<>>, 11}, until = 12, type = checkpoint},
#ds_pubrange{id = {<<>>, 12}, until = 13, type = inflight},
#ds_pubrange{id = {<<>>, 13}, until = 20, type = inflight},
#ds_pubrange{id = {<<>>, 20}, until = 42, type = inflight}
])
),
?_assertEqual(
{13, 13},
compute_inflight_range([
#ds_pubrange{id = {<<>>, 1}, until = 2, type = checkpoint},
#ds_pubrange{id = {<<>>, 4}, until = 8, type = checkpoint},
#ds_pubrange{id = {<<>>, 11}, until = 12, type = checkpoint},
#ds_pubrange{id = {<<>>, 12}, until = 13, type = checkpoint}
])
)
].
-endif.

View File

@ -76,18 +76,19 @@
list_all_sessions/0,
list_all_subscriptions/0,
list_all_streams/0,
list_all_iterators/0
list_all_pubranges/0
]).
-endif.
%% Currently, this is the clientid. We avoid `emqx_types:clientid()' because that can be
%% an atom, in theory (?).
-type id() :: binary().
-type topic_filter() :: emqx_ds:topic_filter().
-type topic_filter() :: emqx_types:topic().
-type topic_filter_words() :: emqx_ds:topic_filter().
-type subscription_id() :: {id(), topic_filter()}.
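This PR keys session subscriptions by the binary topic filter (topic_filter()) and converts to the word form (topic_filter_words()) only when talking to emqx_ds. A minimal sketch of the two representations, using emqx_topic conversions that appear elsewhere in this diff:

topic_filter_words_example() ->
    Words = emqx_topic:words(<<"msg/feed/#">>),
    [<<"msg">>, <<"feed">>, '#'] = Words,
    <<"msg/feed/#">> = emqx_topic:join(Words),
    ok.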
-type subscription() :: #{
start_time := emqx_ds:time(),
propts := map(),
props := map(),
extra := map()
}.
-type session() :: #{
@ -98,7 +99,7 @@
%% When the session should expire
expires_at := timestamp() | never,
%% Clients Subscriptions.
iterators := #{topic() => subscription()},
subscriptions := #{topic_filter() => subscription()},
%% Inflight messages
inflight := emqx_persistent_message_ds_replayer:inflight(),
%% Receive maximum
@ -108,7 +109,6 @@
}.
-type timestamp() :: emqx_utils_calendar:epoch_millisecond().
-type topic() :: emqx_types:topic().
-type clientinfo() :: emqx_types:clientinfo().
-type conninfo() :: emqx_session:conninfo().
-type replies() :: emqx_session:replies().
@ -142,7 +142,7 @@ open(#{clientid := ClientID} = _ClientInfo, ConnInfo) ->
%% somehow isolate those idling not-yet-expired sessions into a separate process
%% space, and move this call back into `emqx_cm` where it belongs.
ok = emqx_cm:discard_session(ClientID),
case open_session(ClientID) of
case session_open(ClientID) of
Session0 = #{} ->
ensure_timers(),
ReceiveMaximum = receive_maximum(ConnInfo),
@ -153,24 +153,9 @@ open(#{clientid := ClientID} = _ClientInfo, ConnInfo) ->
end.
ensure_session(ClientID, ConnInfo, Conf) ->
{ok, Session, #{}} = session_ensure_new(ClientID, Conf),
Session = session_ensure_new(ClientID, Conf),
ReceiveMaximum = receive_maximum(ConnInfo),
Session#{iterators => #{}, receive_maximum => ReceiveMaximum}.
open_session(ClientID) ->
case session_open(ClientID) of
{ok, Session, Subscriptions} ->
Session#{iterators => prep_subscriptions(Subscriptions)};
false ->
false
end.
prep_subscriptions(Subscriptions) ->
maps:fold(
fun(Topic, Subscription, Acc) -> Acc#{emqx_topic:join(Topic) => Subscription} end,
#{},
Subscriptions
).
Session#{subscriptions => #{}, receive_maximum => ReceiveMaximum}.
-spec destroy(session() | clientinfo()) -> ok.
destroy(#{id := ClientID}) ->
@ -195,9 +180,9 @@ info(created_at, #{created_at := CreatedAt}) ->
CreatedAt;
info(is_persistent, #{}) ->
true;
info(subscriptions, #{iterators := Iters}) ->
info(subscriptions, #{subscriptions := Iters}) ->
maps:map(fun(_, #{props := SubOpts}) -> SubOpts end, Iters);
info(subscriptions_cnt, #{iterators := Iters}) ->
info(subscriptions_cnt, #{subscriptions := Iters}) ->
maps:size(Iters);
info(subscriptions_max, #{props := Conf}) ->
maps:get(max_subscriptions, Conf);
@ -239,47 +224,47 @@ stats(Session) ->
%% Client -> Broker: SUBSCRIBE / UNSUBSCRIBE
%%--------------------------------------------------------------------
-spec subscribe(topic(), emqx_types:subopts(), session()) ->
-spec subscribe(topic_filter(), emqx_types:subopts(), session()) ->
{ok, session()} | {error, emqx_types:reason_code()}.
subscribe(
TopicFilter,
SubOpts,
Session = #{id := ID, iterators := Iters}
) when is_map_key(TopicFilter, Iters) ->
Iterator = maps:get(TopicFilter, Iters),
NIterator = update_subscription(TopicFilter, Iterator, SubOpts, ID),
{ok, Session#{iterators := Iters#{TopicFilter => NIterator}}};
Session = #{id := ID, subscriptions := Subs}
) when is_map_key(TopicFilter, Subs) ->
Subscription = maps:get(TopicFilter, Subs),
NSubscription = update_subscription(TopicFilter, Subscription, SubOpts, ID),
{ok, Session#{subscriptions := Subs#{TopicFilter => NSubscription}}};
subscribe(
TopicFilter,
SubOpts,
Session = #{id := ID, iterators := Iters}
Session = #{id := ID, subscriptions := Subs}
) ->
% TODO: max_subscriptions
Iterator = add_subscription(TopicFilter, SubOpts, ID),
{ok, Session#{iterators := Iters#{TopicFilter => Iterator}}}.
Subscription = add_subscription(TopicFilter, SubOpts, ID),
{ok, Session#{subscriptions := Subs#{TopicFilter => Subscription}}}.
-spec unsubscribe(topic(), session()) ->
-spec unsubscribe(topic_filter(), session()) ->
{ok, session(), emqx_types:subopts()} | {error, emqx_types:reason_code()}.
unsubscribe(
TopicFilter,
Session = #{id := ID, iterators := Iters}
) when is_map_key(TopicFilter, Iters) ->
Iterator = maps:get(TopicFilter, Iters),
SubOpts = maps:get(props, Iterator),
Session = #{id := ID, subscriptions := Subs}
) when is_map_key(TopicFilter, Subs) ->
Subscription = maps:get(TopicFilter, Subs),
SubOpts = maps:get(props, Subscription),
ok = del_subscription(TopicFilter, ID),
{ok, Session#{iterators := maps:remove(TopicFilter, Iters)}, SubOpts};
{ok, Session#{subscriptions := maps:remove(TopicFilter, Subs)}, SubOpts};
unsubscribe(
_TopicFilter,
_Session = #{}
) ->
{error, ?RC_NO_SUBSCRIPTION_EXISTED}.
-spec get_subscription(topic(), session()) ->
-spec get_subscription(topic_filter(), session()) ->
emqx_types:subopts() | undefined.
get_subscription(TopicFilter, #{iterators := Iters}) ->
case maps:get(TopicFilter, Iters, undefined) of
Iterator = #{} ->
maps:get(props, Iterator);
get_subscription(TopicFilter, #{subscriptions := Subs}) ->
case maps:get(TopicFilter, Subs, undefined) of
Subscription = #{} ->
maps:get(props, Subscription);
undefined ->
undefined
end.
@ -292,7 +277,7 @@ get_subscription(TopicFilter, #{iterators := Iters}) ->
{ok, emqx_types:publish_result(), replies(), session()}
| {error, emqx_types:reason_code()}.
publish(_PacketId, Msg, Session) ->
%% TODO:
%% TODO: QoS2
Result = emqx_broker:publish(Msg),
{ok, Result, [], Session}.
@ -374,15 +359,16 @@ handle_timeout(
end,
ensure_timer(pull, Timeout),
{ok, Publishes, Session#{inflight => Inflight}};
handle_timeout(_ClientInfo, get_streams, Session = #{id := Id}) ->
renew_streams(Id),
handle_timeout(_ClientInfo, get_streams, Session) ->
renew_streams(Session),
ensure_timer(get_streams),
{ok, [], Session}.
-spec replay(clientinfo(), [], session()) ->
{ok, replies(), session()}.
replay(_ClientInfo, [], Session = #{}) ->
{ok, [], Session}.
replay(_ClientInfo, [], Session = #{inflight := Inflight0}) ->
{Replies, Inflight} = emqx_persistent_message_ds_replayer:replay(Inflight0),
{ok, Replies, Session#{inflight := Inflight}}.
%%--------------------------------------------------------------------
@ -392,14 +378,13 @@ disconnect(Session = #{}) ->
-spec terminate(Reason :: term(), session()) -> ok.
terminate(_Reason, _Session = #{}) ->
% TODO: close iterators
ok.
%%--------------------------------------------------------------------
-spec add_subscription(topic(), emqx_types:subopts(), id()) ->
-spec add_subscription(topic_filter(), emqx_types:subopts(), id()) ->
subscription().
add_subscription(TopicFilterBin, SubOpts, DSSessionID) ->
add_subscription(TopicFilter, SubOpts, DSSessionID) ->
%% N.B.: we chose to update the router before adding the subscription to the
%% session/iterator table. The reasoning for this is as follows:
%%
@ -418,8 +403,7 @@ add_subscription(TopicFilterBin, SubOpts, DSSessionID) ->
%% since it is guarded by a transaction context: we consider a subscription
%% operation to be successful if it ended up changing this table. Both router
%% and iterator information can be reconstructed from this table, if needed.
ok = emqx_persistent_session_ds_router:do_add_route(TopicFilterBin, DSSessionID),
TopicFilter = emqx_topic:words(TopicFilterBin),
ok = emqx_persistent_session_ds_router:do_add_route(TopicFilter, DSSessionID),
{ok, DSSubExt, IsNew} = session_add_subscription(
DSSessionID, TopicFilter, SubOpts
),
@ -427,20 +411,19 @@ add_subscription(TopicFilterBin, SubOpts, DSSessionID) ->
%% we'll list streams and open iterators when implementing message replay.
DSSubExt.
-spec update_subscription(topic(), subscription(), emqx_types:subopts(), id()) ->
-spec update_subscription(topic_filter(), subscription(), emqx_types:subopts(), id()) ->
subscription().
update_subscription(TopicFilterBin, DSSubExt, SubOpts, DSSessionID) ->
TopicFilter = emqx_topic:words(TopicFilterBin),
update_subscription(TopicFilter, DSSubExt, SubOpts, DSSessionID) ->
{ok, NDSSubExt, false} = session_add_subscription(
DSSessionID, TopicFilter, SubOpts
),
ok = ?tp(persistent_session_ds_iterator_updated, #{sub => DSSubExt}),
NDSSubExt.
-spec del_subscription(topic(), id()) ->
-spec del_subscription(topic_filter(), id()) ->
ok.
del_subscription(TopicFilterBin, DSSessionId) ->
TopicFilter = emqx_topic:words(TopicFilterBin),
del_subscription(TopicFilter, DSSessionId) ->
%% TODO: transaction?
?tp_span(
persistent_session_ds_subscription_delete,
#{session_id => DSSessionId},
@ -449,7 +432,7 @@ del_subscription(TopicFilterBin, DSSessionId) ->
?tp_span(
persistent_session_ds_subscription_route_delete,
#{session_id => DSSessionId},
ok = emqx_persistent_session_ds_router:do_delete_route(TopicFilterBin, DSSessionId)
ok = emqx_persistent_session_ds_router:do_delete_route(TopicFilter, DSSessionId)
).
%%--------------------------------------------------------------------
@ -492,17 +475,20 @@ create_tables() ->
]
),
ok = mria:create_table(
?SESSION_ITER_TAB,
?SESSION_PUBRANGE_TAB,
[
{rlog_shard, ?DS_MRIA_SHARD},
{type, set},
{type, ordered_set},
{storage, storage()},
{record_name, ds_iter},
{attributes, record_info(fields, ds_iter)}
{record_name, ds_pubrange},
{attributes, record_info(fields, ds_pubrange)}
]
),
ok = mria:wait_for_tables([
?SESSION_TAB, ?SESSION_SUBSCRIPTIONS_TAB, ?SESSION_STREAM_TAB, ?SESSION_ITER_TAB
?SESSION_TAB,
?SESSION_SUBSCRIPTIONS_TAB,
?SESSION_STREAM_TAB,
?SESSION_PUBRANGE_TAB
]),
ok.
@ -522,27 +508,34 @@ storage() ->
%% Note: session API doesn't handle session takeovers, it's the job of
%% the broker.
-spec session_open(id()) ->
{ok, session(), #{topic() => subscription()}} | false.
session() | false.
session_open(SessionId) ->
transaction(fun() ->
ro_transaction(fun() ->
case mnesia:read(?SESSION_TAB, SessionId, write) of
[Record = #session{}] ->
Session = export_session(Record),
DSSubs = session_read_subscriptions(SessionId),
Subscriptions = export_subscriptions(DSSubs),
{ok, Session, Subscriptions};
Inflight = emqx_persistent_message_ds_replayer:open(SessionId),
Session#{
subscriptions => Subscriptions,
inflight => Inflight
};
[] ->
false
end
end).
-spec session_ensure_new(id(), _Props :: map()) ->
{ok, session(), #{topic() => subscription()}}.
session().
session_ensure_new(SessionId, Props) ->
transaction(fun() ->
ok = session_drop_subscriptions(SessionId),
Session = export_session(session_create(SessionId, Props)),
{ok, Session, #{}}
Session#{
subscriptions => #{},
inflight => emqx_persistent_message_ds_replayer:new()
}
end).
session_create(SessionId, Props) ->
@ -550,8 +543,7 @@ session_create(SessionId, Props) ->
id = SessionId,
created_at = erlang:system_time(millisecond),
expires_at = never,
props = Props,
inflight = emqx_persistent_message_ds_replayer:new()
props = Props
},
ok = mnesia:write(?SESSION_TAB, Session, write),
Session.
@ -562,7 +554,7 @@ session_create(SessionId, Props) ->
session_drop(DSSessionId) ->
transaction(fun() ->
ok = session_drop_subscriptions(DSSessionId),
ok = session_drop_iterators(DSSessionId),
ok = session_drop_pubranges(DSSessionId),
ok = session_drop_streams(DSSessionId),
ok = mnesia:delete(?SESSION_TAB, DSSessionId, write)
end).
@ -573,8 +565,7 @@ session_drop_subscriptions(DSSessionId) ->
lists:foreach(
fun(#ds_sub{id = DSSubId} = DSSub) ->
TopicFilter = subscription_id_to_topic_filter(DSSubId),
TopicFilterBin = emqx_topic:join(TopicFilter),
ok = emqx_persistent_session_ds_router:do_delete_route(TopicFilterBin, DSSessionId),
ok = emqx_persistent_session_ds_router:do_delete_route(TopicFilter, DSSessionId),
ok = session_del_subscription(DSSub)
end,
Subscriptions
@ -677,77 +668,82 @@ do_ensure_all_iterators_closed(_DSSessionID) ->
%% Reading batches
%%--------------------------------------------------------------------
-spec renew_streams(id()) -> ok.
renew_streams(DSSessionId) ->
Subscriptions = ro_transaction(fun() -> session_read_subscriptions(DSSessionId) end),
ExistingStreams = ro_transaction(fun() -> mnesia:read(?SESSION_STREAM_TAB, DSSessionId) end),
lists:foreach(
fun(#ds_sub{id = {_, TopicFilter}, start_time = StartTime}) ->
renew_streams(DSSessionId, ExistingStreams, TopicFilter, StartTime)
-spec renew_streams(session()) -> ok.
renew_streams(#{id := SessionId, subscriptions := Subscriptions}) ->
transaction(fun() ->
ExistingStreams = mnesia:read(?SESSION_STREAM_TAB, SessionId, write),
maps:fold(
fun(TopicFilter, #{start_time := StartTime}, Streams) ->
TopicFilterWords = emqx_topic:words(TopicFilter),
renew_topic_streams(SessionId, TopicFilterWords, StartTime, Streams)
end,
ExistingStreams,
Subscriptions
)
end),
ok.
-spec renew_topic_streams(id(), topic_filter_words(), emqx_ds:time(), _Acc :: [ds_stream()]) -> ok.
renew_topic_streams(DSSessionId, TopicFilter, StartTime, ExistingStreams) ->
TopicStreams = emqx_ds:get_streams(?PERSISTENT_MESSAGE_DB, TopicFilter, StartTime),
lists:foldl(
fun({Rank, Stream}, Streams) ->
case lists:keymember(Stream, #ds_stream.stream, Streams) of
true ->
Streams;
false ->
StreamRef = length(Streams) + 1,
DSStream = session_store_stream(
DSSessionId,
StreamRef,
Stream,
Rank,
TopicFilter,
StartTime
),
[DSStream | Streams]
end
end,
Subscriptions
ExistingStreams,
TopicStreams
).
-spec renew_streams(id(), [ds_stream()], emqx_ds:topic_filter(), emqx_ds:time()) -> ok.
renew_streams(DSSessionId, ExistingStreams, TopicFilter, StartTime) ->
AllStreams = emqx_ds:get_streams(?PERSISTENT_MESSAGE_DB, TopicFilter, StartTime),
transaction(
fun() ->
lists:foreach(
fun({Rank, Stream}) ->
Rec = #ds_stream{
session = DSSessionId,
topic_filter = TopicFilter,
stream = Stream,
rank = Rank
},
case lists:member(Rec, ExistingStreams) of
true ->
ok;
false ->
mnesia:write(?SESSION_STREAM_TAB, Rec, write),
{ok, Iterator} = emqx_ds:make_iterator(
?PERSISTENT_MESSAGE_DB, Stream, TopicFilter, StartTime
),
%% Workaround: we convert `Stream' to a binary before
%% attempting to store it in mnesia(rocksdb) because of a bug
%% in `mnesia_rocksdb' when trying to do
%% `mnesia:dirty_all_keys' later.
StreamBin = term_to_binary(Stream),
IterRec = #ds_iter{id = {DSSessionId, StreamBin}, iter = Iterator},
mnesia:write(?SESSION_ITER_TAB, IterRec, write)
end
end,
AllStreams
)
end
).
session_store_stream(DSSessionId, StreamRef, Stream, Rank, TopicFilter, StartTime) ->
{ok, ItBegin} = emqx_ds:make_iterator(
?PERSISTENT_MESSAGE_DB,
Stream,
TopicFilter,
StartTime
),
DSStream = #ds_stream{
session = DSSessionId,
ref = StreamRef,
stream = Stream,
rank = Rank,
beginning = ItBegin
},
mnesia:write(?SESSION_STREAM_TAB, DSStream, write),
DSStream.
%% must be called inside a transaction
-spec session_drop_streams(id()) -> ok.
session_drop_streams(DSSessionId) ->
MS = ets:fun2ms(
fun(#ds_stream{session = DSSessionId0}) when DSSessionId0 =:= DSSessionId ->
DSSessionId0
end
),
StreamIDs = mnesia:select(?SESSION_STREAM_TAB, MS, write),
lists:foreach(fun(Key) -> mnesia:delete(?SESSION_STREAM_TAB, Key, write) end, StreamIDs).
mnesia:delete(?SESSION_STREAM_TAB, DSSessionId, write).
%% must be called inside a transaction
-spec session_drop_iterators(id()) -> ok.
session_drop_iterators(DSSessionId) ->
-spec session_drop_pubranges(id()) -> ok.
session_drop_pubranges(DSSessionId) ->
MS = ets:fun2ms(
fun(#ds_iter{id = {DSSessionId0, StreamBin}}) when DSSessionId0 =:= DSSessionId ->
StreamBin
fun(#ds_pubrange{id = {DSSessionId0, First}}) when DSSessionId0 =:= DSSessionId ->
{DSSessionId, First}
end
),
StreamBins = mnesia:select(?SESSION_ITER_TAB, MS, write),
RangeIds = mnesia:select(?SESSION_PUBRANGE_TAB, MS, write),
lists:foreach(
fun(StreamBin) ->
mnesia:delete(?SESSION_ITER_TAB, {DSSessionId, StreamBin}, write)
fun(RangeId) ->
mnesia:delete(?SESSION_PUBRANGE_TAB, RangeId, write)
end,
StreamBins
RangeIds
).
%%--------------------------------------------------------------------------------
@ -772,7 +768,7 @@ export_subscriptions(DSSubs) ->
).
export_session(#session{} = Record) ->
export_record(Record, #session.id, [id, created_at, expires_at, inflight, props], #{}).
export_record(Record, #session.id, [id, created_at, expires_at, props], #{}).
export_subscription(#ds_sub{} = Record) ->
export_record(Record, #ds_sub.start_time, [start_time, props, extra], #{}).
@ -808,10 +804,7 @@ receive_maximum(ConnInfo) ->
list_all_sessions() ->
DSSessionIds = mnesia:dirty_all_keys(?SESSION_TAB),
Sessions = lists:map(
fun(SessionID) ->
{ok, Session, Subscriptions} = session_open(SessionID),
{SessionID, #{session => Session, subscriptions => Subscriptions}}
end,
fun(SessionID) -> {SessionID, session_open(SessionID)} end,
DSSessionIds
),
maps:from_list(Sessions).
@ -850,16 +843,18 @@ list_all_streams() ->
),
maps:from_list(DSStreams).
list_all_iterators() ->
DSIterIds = mnesia:dirty_all_keys(?SESSION_ITER_TAB),
DSIters = lists:map(
fun(DSIterId) ->
[Record] = mnesia:dirty_read(?SESSION_ITER_TAB, DSIterId),
{DSIterId, export_record(Record, #ds_iter.id, [id, iter], #{})}
list_all_pubranges() ->
DSPubranges = mnesia:dirty_match_object(?SESSION_PUBRANGE_TAB, #ds_pubrange{_ = '_'}),
lists:foldl(
fun(Record = #ds_pubrange{id = {SessionId, First}}, Acc) ->
Range = export_record(
Record, #ds_pubrange.until, [until, stream, type, iterator], #{first => First}
),
maps:put(SessionId, maps:get(SessionId, Acc, []) ++ [Range], Acc)
end,
DSIterIds
),
maps:from_list(DSIters).
#{},
DSPubranges
).
%% ifdef(TEST)
-endif.

View File

@ -21,7 +21,7 @@
-define(SESSION_TAB, emqx_ds_session).
-define(SESSION_SUBSCRIPTIONS_TAB, emqx_ds_session_subscriptions).
-define(SESSION_STREAM_TAB, emqx_ds_stream_tab).
-define(SESSION_ITER_TAB, emqx_ds_iter_tab).
-define(SESSION_PUBRANGE_TAB, emqx_ds_pubrange_tab).
-define(DS_MRIA_SHARD, emqx_ds_session_shard).
-record(ds_sub, {
@ -34,17 +34,39 @@
-record(ds_stream, {
session :: emqx_persistent_session_ds:id(),
topic_filter :: emqx_ds:topic_filter(),
ref :: _StreamRef,
stream :: emqx_ds:stream(),
rank :: emqx_ds:stream_rank()
rank :: emqx_ds:stream_rank(),
beginning :: emqx_ds:iterator()
}).
-type ds_stream() :: #ds_stream{}.
-type ds_stream_bin() :: binary().
-record(ds_iter, {
id :: {emqx_persistent_session_ds:id(), ds_stream_bin()},
iter :: emqx_ds:iterator()
-record(ds_pubrange, {
id :: {
%% What session this range belongs to.
_Session :: emqx_persistent_session_ds:id(),
%% Where this range starts.
_First :: emqx_persistent_message_ds_replayer:seqno()
},
%% Where this range ends: the first seqno that is not included in the range.
until :: emqx_persistent_message_ds_replayer:seqno(),
%% Which stream this range is over.
stream :: _StreamRef,
%% Type of a range:
%% * Inflight range is a range of yet unacked messages from this stream.
%% * Checkpoint range was already acked, its purpose is to keep track of the
%% very last iterator for this stream.
type :: inflight | checkpoint,
%% Meaning of this depends on the type of the range:
%% * For inflight range, this is the iterator pointing to the first message in
%% the range.
%% * For checkpoint range, this is the iterator pointing right past the last
%% message in the range.
iterator :: emqx_ds:iterator(),
%% Reserved for future use.
misc = #{} :: map()
}).
-type ds_pubrange() :: #ds_pubrange{}.
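A hypothetical example of how the two range types relate for a single stream, after seqnos 1..4 were acked while 5..7 remain in flight (session id and iterators invented for illustration):

#ds_pubrange{id = {<<"client">>, 1}, until = 5, stream = 1,
             type = checkpoint, iterator = It5},
#ds_pubrange{id = {<<"client">>, 5}, until = 8, stream = 1,
             type = inflight, iterator = It5}.

Both ranges hold the same iterator position: right past the checkpoint's last message, which is exactly the first message of the inflight range.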
-record(session, {
%% same as clientid
@ -52,7 +74,6 @@
%% creation time
created_at :: _Millisecond :: non_neg_integer(),
expires_at = never :: _Millisecond :: non_neg_integer() | never,
inflight :: emqx_persistent_message_ds_replayer:inflight(),
%% for future usage
props = #{} :: map()
}).

View File

@ -181,18 +181,23 @@ client_info(Key, Client) ->
maps:get(Key, maps:from_list(emqtt:info(Client)), undefined).
receive_messages(Count) ->
receive_messages(Count, []).
receive_messages(Count, 15000).
receive_messages(0, Msgs) ->
Msgs;
receive_messages(Count, Msgs) ->
receive_messages(Count, Timeout) ->
Deadline = erlang:monotonic_time(millisecond) + Timeout,
receive_message_loop(Count, Deadline).
receive_message_loop(0, _Deadline) ->
[];
receive_message_loop(Count, Deadline) ->
Timeout = max(0, Deadline - erlang:monotonic_time(millisecond)),
receive
{publish, Msg} ->
receive_messages(Count - 1, [Msg | Msgs]);
[Msg | receive_message_loop(Count - 1, Deadline)];
_Other ->
receive_messages(Count, Msgs)
after 15000 ->
Msgs
receive_message_loop(Count, Deadline)
after Timeout ->
[]
end.
maybe_kill_connection_process(ClientId, Config) ->
@ -229,16 +234,28 @@ wait_for_cm_unregister(ClientId, N) ->
wait_for_cm_unregister(ClientId, N - 1)
end.
publish(Topic, Payloads) ->
publish(Topic, Payloads, false, 2).
messages(Topic, Payloads) ->
messages(Topic, Payloads, ?QOS_2).
publish(Topic, Payloads, WaitForUnregister, QoS) ->
Fun = fun(Client, Payload) ->
{ok, _} = emqtt:publish(Client, Topic, Payload, QoS)
messages(Topic, Payloads, QoS) ->
[#mqtt_msg{topic = Topic, payload = P, qos = QoS} || P <- Payloads].
publish(Topic, Payload) ->
publish(Topic, Payload, ?QOS_2).
publish(Topic, Payload, QoS) ->
publish_many(messages(Topic, [Payload], QoS)).
publish_many(Messages) ->
publish_many(Messages, false).
publish_many(Messages, WaitForUnregister) ->
Fun = fun(Client, Message) ->
{ok, _} = emqtt:publish(Client, Message)
end,
do_publish(Payloads, Fun, WaitForUnregister).
do_publish(Messages, Fun, WaitForUnregister).
do_publish(Payloads = [_ | _], PublishFun, WaitForUnregister) ->
do_publish(Messages = [_ | _], PublishFun, WaitForUnregister) ->
%% Publish from another process to avoid connection confusion.
{Pid, Ref} =
spawn_monitor(
@ -252,7 +269,7 @@ do_publish(Payloads = [_ | _], PublishFun, WaitForUnregister) ->
{port, 1883}
]),
{ok, _} = emqtt:connect(Client),
lists:foreach(fun(Payload) -> PublishFun(Client, Payload) end, Payloads),
lists:foreach(fun(Message) -> PublishFun(Client, Message) end, Messages),
ok = emqtt:disconnect(Client),
%% Snabbkaffe sometimes fails unless all processes are gone.
case WaitForUnregister of
@ -277,9 +294,7 @@ do_publish(Payloads = [_ | _], PublishFun, WaitForUnregister) ->
receive
{'DOWN', Ref, process, Pid, normal} -> ok;
{'DOWN', Ref, process, Pid, What} -> error({failed_publish, What})
end;
do_publish(Payload, PublishFun, WaitForUnregister) ->
do_publish([Payload], PublishFun, WaitForUnregister).
end.
%%--------------------------------------------------------------------
%% Test Cases
@ -494,7 +509,7 @@ t_process_dies_session_expires(Config) ->
maybe_kill_connection_process(ClientId, Config),
ok = publish(Topic, [Payload]),
ok = publish(Topic, Payload),
timer:sleep(1100),
@ -535,7 +550,7 @@ t_publish_while_client_is_gone_qos1(Config) ->
ok = emqtt:disconnect(Client1),
maybe_kill_connection_process(ClientId, Config),
ok = publish(Topic, [Payload1, Payload2], false, 1),
ok = publish_many(messages(Topic, [Payload1, Payload2], ?QOS_1)),
{ok, Client2} = emqtt:start_link([
{proto_ver, v5},
@ -547,7 +562,7 @@ t_publish_while_client_is_gone_qos1(Config) ->
{ok, _} = emqtt:ConnFun(Client2),
Msgs = receive_messages(2),
?assertMatch([_, _], Msgs),
[Msg2, Msg1] = Msgs,
[Msg1, Msg2] = Msgs,
?assertEqual({ok, iolist_to_binary(Payload1)}, maps:find(payload, Msg1)),
?assertEqual({ok, 1}, maps:find(qos, Msg1)),
?assertEqual({ok, iolist_to_binary(Payload2)}, maps:find(payload, Msg2)),
@ -555,6 +570,137 @@ t_publish_while_client_is_gone_qos1(Config) ->
ok = emqtt:disconnect(Client2).
t_publish_many_while_client_is_gone_qos1(Config) ->
%% A persistent session should receive all of the still unacked messages
%% for its subscriptions after the client dies or reconnects, in addition
%% to new messages that were published while the client was gone. The order
%% of the messages should be consistent across reconnects.
ClientId = ?config(client_id, Config),
ConnFun = ?config(conn_fun, Config),
{ok, Client1} = emqtt:start_link([
{proto_ver, v5},
{clientid, ClientId},
{properties, #{'Session-Expiry-Interval' => 30}},
{clean_start, true},
{auto_ack, false}
| Config
]),
{ok, _} = emqtt:ConnFun(Client1),
STopics = [
<<"t/+/foo">>,
<<"msg/feed/#">>,
<<"loc/+/+/+">>
],
[{ok, _, [?QOS_1]} = emqtt:subscribe(Client1, ST, ?QOS_1) || ST <- STopics],
Pubs1 = [
#mqtt_msg{topic = <<"t/42/foo">>, payload = <<"M1">>, qos = 1},
#mqtt_msg{topic = <<"t/42/foo">>, payload = <<"M2">>, qos = 1},
#mqtt_msg{topic = <<"msg/feed/me">>, payload = <<"M3">>, qos = 1},
#mqtt_msg{topic = <<"loc/1/2/42">>, payload = <<"M4">>, qos = 1},
#mqtt_msg{topic = <<"t/42/foo">>, payload = <<"M5">>, qos = 1},
#mqtt_msg{topic = <<"loc/3/4/5">>, payload = <<"M6">>, qos = 1},
#mqtt_msg{topic = <<"msg/feed/me">>, payload = <<"M7">>, qos = 1}
],
ok = publish_many(Pubs1),
NPubs1 = length(Pubs1),
Msgs1 = receive_messages(NPubs1),
NMsgs1 = length(Msgs1),
?assertEqual(NPubs1, NMsgs1),
ct:pal("Msgs1 = ~p", [Msgs1]),
%% TODO
%% This assertion doesn't currently hold because `emqx_ds` doesn't enforce
%% strict ordering reflecting client publishing order. Instead, per-topic
%% ordering is guaranteed for each client. In fact, this violates the MQTT
%% specification, but we deemed it acceptable for now.
%% ?assertMatch([
%% #{payload := <<"M1">>},
%% #{payload := <<"M2">>},
%% #{payload := <<"M3">>},
%% #{payload := <<"M4">>},
%% #{payload := <<"M5">>},
%% #{payload := <<"M6">>},
%% #{payload := <<"M7">>}
%% ], Msgs1),
?assertEqual(
get_topicwise_order(Pubs1),
get_topicwise_order(Msgs1),
Msgs1
),
NAcked = 4,
[ok = emqtt:puback(Client1, PktId) || #{packet_id := PktId} <- lists:sublist(Msgs1, NAcked)],
%% Ensure that PUBACKs are propagated to the channel.
pong = emqtt:ping(Client1),
ok = emqtt:disconnect(Client1),
maybe_kill_connection_process(ClientId, Config),
Pubs2 = [
#mqtt_msg{topic = <<"loc/3/4/5">>, payload = <<"M8">>, qos = 1},
#mqtt_msg{topic = <<"t/100/foo">>, payload = <<"M9">>, qos = 1},
#mqtt_msg{topic = <<"t/100/foo">>, payload = <<"M10">>, qos = 1},
#mqtt_msg{topic = <<"msg/feed/friend">>, payload = <<"M11">>, qos = 1},
#mqtt_msg{topic = <<"msg/feed/me">>, payload = <<"M12">>, qos = 1}
],
ok = publish_many(Pubs2),
NPubs2 = length(Pubs2),
{ok, Client2} = emqtt:start_link([
{proto_ver, v5},
{clientid, ClientId},
{properties, #{'Session-Expiry-Interval' => 30}},
{clean_start, false},
{auto_ack, false}
| Config
]),
{ok, _} = emqtt:ConnFun(Client2),
%% Try to receive _at most_ `NPubs` messages.
%% There shouldn't be that many unacked messages in the replay anyway,
%% but it's an easy number to pick.
NPubs = NPubs1 + NPubs2,
Msgs2 = receive_messages(NPubs, _Timeout = 2000),
NMsgs2 = length(Msgs2),
ct:pal("Msgs2 = ~p", [Msgs2]),
?assert(NMsgs2 < NPubs, Msgs2),
?assert(NMsgs2 > NPubs2, Msgs2),
?assert(NMsgs2 >= NPubs - NAcked, Msgs2),
NSame = NMsgs2 - NPubs2,
?assert(
lists:all(fun(#{dup := Dup}) -> Dup end, lists:sublist(Msgs2, NSame))
),
?assertNot(
lists:all(fun(#{dup := Dup}) -> Dup end, lists:nthtail(NSame, Msgs2))
),
?assertEqual(
[maps:with([packet_id, topic, payload], M) || M <- lists:nthtail(NMsgs1 - NSame, Msgs1)],
[maps:with([packet_id, topic, payload], M) || M <- lists:sublist(Msgs2, NSame)]
),
ok = emqtt:disconnect(Client2).
get_topicwise_order(Msgs) ->
maps:groups_from_list(fun get_msgpub_topic/1, fun get_msgpub_payload/1, Msgs).
get_msgpub_topic(#mqtt_msg{topic = Topic}) ->
Topic;
get_msgpub_topic(#{topic := Topic}) ->
Topic.
get_msgpub_payload(#mqtt_msg{payload = Payload}) ->
Payload;
get_msgpub_payload(#{payload := Payload}) ->
Payload.
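For reference, a sketch of what get_topicwise_order/1 yields (hypothetical messages; note that maps:groups_from_list/3 requires OTP 25 or newer): both the published and the received messages reduce to the same per-topic payload sequences even when their global order differs.

get_topicwise_order_example() ->
    Pubs = [
        #mqtt_msg{topic = <<"t/1">>, payload = <<"A">>},
        #mqtt_msg{topic = <<"t/2">>, payload = <<"B">>},
        #mqtt_msg{topic = <<"t/1">>, payload = <<"C">>}
    ],
    Msgs = [
        #{topic => <<"t/2">>, payload => <<"B">>},
        #{topic => <<"t/1">>, payload => <<"A">>},
        #{topic => <<"t/1">>, payload => <<"C">>}
    ],
    Expected = #{<<"t/1">> => [<<"A">>, <<"C">>], <<"t/2">> => [<<"B">>]},
    Expected = get_topicwise_order(Pubs),
    Expected = get_topicwise_order(Msgs),
    ok.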
t_publish_while_client_is_gone(init, Config) -> skip_ds_tc(Config);
t_publish_while_client_is_gone('end', _Config) -> ok.
t_publish_while_client_is_gone(Config) ->
@ -579,7 +725,7 @@ t_publish_while_client_is_gone(Config) ->
ok = emqtt:disconnect(Client1),
maybe_kill_connection_process(ClientId, Config),
ok = publish(Topic, [Payload1, Payload2]),
ok = publish_many(messages(Topic, [Payload1, Payload2])),
{ok, Client2} = emqtt:start_link([
{proto_ver, v5},
@ -591,7 +737,7 @@ t_publish_while_client_is_gone(Config) ->
{ok, _} = emqtt:ConnFun(Client2),
Msgs = receive_messages(2),
?assertMatch([_, _], Msgs),
[Msg2, Msg1] = Msgs,
[Msg1, Msg2] = Msgs,
?assertEqual({ok, iolist_to_binary(Payload1)}, maps:find(payload, Msg1)),
?assertEqual({ok, 2}, maps:find(qos, Msg1)),
?assertEqual({ok, iolist_to_binary(Payload2)}, maps:find(payload, Msg2)),