diff --git a/include/emqx_mqtt.hrl b/include/emqx_mqtt.hrl index 7e7670112..3bba42216 100644 --- a/include/emqx_mqtt.hrl +++ b/include/emqx_mqtt.hrl @@ -171,10 +171,16 @@ -define(RC_WILDCARD_SUBSCRIPTIONS_NOT_SUPPORTED, 16#A2). %%-------------------------------------------------------------------- -%% Maximum MQTT Packet Length +%% Maximum MQTT Packet ID and Length %%-------------------------------------------------------------------- +-define(MAX_PACKET_ID, 16#ffff). -define(MAX_PACKET_SIZE, 16#fffffff). +-define(BUMP_PACKET_ID(Base, Bump), + case Base + Bump of + __I__ when __I__ > ?MAX_PACKET_ID -> __I__ - ?MAX_PACKET_ID; + __I__ -> __I__ + end). %%-------------------------------------------------------------------- %% MQTT Frame Mask diff --git a/src/emqx_client.erl b/src/emqx_client.erl index 389c9e902..71527871f 100644 --- a/src/emqx_client.erl +++ b/src/emqx_client.erl @@ -389,8 +389,8 @@ publish(Client, Topic, Properties, Payload, Opts) props = Properties, payload = iolist_to_binary(Payload)}). --spec(publish(client(), #mqtt_msg{}) -> ok | {ok, packet_id()} | {error, term()}). -publish(Client, Msg) when is_record(Msg, mqtt_msg) -> +-spec(publish(client(), #mqtt_msg{} | [#mqtt_msg{}]) -> ok | {ok, packet_id()} | {error, term()}). +publish(Client, Msg) -> gen_statem:call(Client, {publish, Msg}). -spec(unsubscribe(client(), topic() | [topic()]) -> subscribe_ret()). @@ -756,9 +756,6 @@ connected({call, From}, pause, State) -> connected({call, From}, resume, State) -> {keep_state, State#state{paused = false}, [{reply, From, ok}]}; -connected({call, From}, stop, _State) -> - {stop_and_reply, normal, [{reply, From, ok}]}; - connected({call, From}, get_properties, State = #state{properties = Properties}) -> {keep_state, State, [{reply, From, Properties}]}; @@ -790,19 +787,22 @@ connected({call, From}, {publish, Msg = #mqtt_msg{qos = ?QOS_0}}, State) -> {stop_and_reply, Reason, [{reply, From, Error}]} end; -connected({call, From}, {publish, Msg = #mqtt_msg{qos = QoS}}, - State = #state{inflight = Inflight, last_packet_id = PacketId}) +connected({call, From}, {publish, Msg = #mqtt_msg{qos = QoS}}, State) when (QoS =:= ?QOS_1); (QoS =:= ?QOS_2) -> + connected({call, From}, {publish, [Msg]}, State); + +%% when publishing a batch, {ok, BasePacketId} is returned, +%% following packet ids for the batch tail are mod (1 bsl 16) consecutive +connected({call, From}, {publish, Msgs}, + State = #state{inflight = Inflight, last_packet_id = PacketId}) when is_list(Msgs) -> + %% NOTE: to ensure API call atomicity, inflight buffer may overflow case emqx_inflight:is_full(Inflight) of true -> - {keep_state, State, [{reply, From, {error, {PacketId, inflight_full}}}]}; + {keep_state, State, [{reply, From, {error, inflight_full}}]}; false -> - Msg1 = Msg#mqtt_msg{packet_id = PacketId}, - case send(Msg1, State) of + case send_batch(assign_packet_id(Msgs, PacketId), State) of {ok, NewState} -> - Inflight1 = emqx_inflight:insert(PacketId, {publish, Msg1, os:timestamp()}, Inflight), - {keep_state, ensure_retry_timer(NewState#state{inflight = Inflight1}), - [{reply, From, {ok, PacketId}}]}; + {keep_state, ensure_retry_timer(NewState), [{reply, From, {ok, PacketId}}]}; {error, Reason} -> {stop_and_reply, Reason, [{reply, From, {error, {PacketId, Reason}}}]} end @@ -1011,6 +1011,8 @@ should_ping(Sock) -> Error end. +handle_event({call, From}, stop, _StateName, _State) -> + {stop_and_reply, normal, [{reply, From, ok}]}; handle_event(info, {TcpOrSsL, _Sock, Data}, _StateName, State) when TcpOrSsL =:= tcp; TcpOrSsL =:= ssl -> emqx_logger:debug("RECV Data: ~p", [Data]), @@ -1333,6 +1335,17 @@ send_puback(Packet, State) -> {error, Reason} -> {stop, {shutdown, Reason}} end. +send_batch([], State) -> {ok, State}; +send_batch([Msg = #mqtt_msg{packet_id = PacketId} | Rest], + State = #state{inflight = Inflight}) -> + case send(Msg, State) of + {ok, NewState} -> + Inflight1 = emqx_inflight:insert(PacketId, {publish, Msg, os:timestamp()}, Inflight), + send_batch(Rest, NewState#state{inflight = Inflight1}); + {error, Reason} -> + {error, Reason} + end. + send(Msg, State) when is_record(Msg, mqtt_msg) -> send(msg_to_packet(Msg), State); @@ -1375,10 +1388,17 @@ next_events(Packets) -> [{next_event, cast, Packet} || Packet <- lists:reverse(Packets)]. %%------------------------------------------------------------------------------ -%% Next packet id +%% packet_id generation and assignment -next_packet_id(State = #state{last_packet_id = 16#ffff}) -> - State#state{last_packet_id = 1}; +assign_packet_id(Msg = #mqtt_msg{}, Id) -> + Msg#mqtt_msg{packet_id = Id}; +assign_packet_id([H | T], Id) -> + [assign_packet_id(H, Id) | assign_packet_id(T, next_packet_id(Id))]; +assign_packet_id([], _Id) -> + []. next_packet_id(State = #state{last_packet_id = Id}) -> - State#state{last_packet_id = Id + 1}. + State#state{last_packet_id = next_packet_id(Id)}; +next_packet_id(16#ffff) -> 1; +next_packet_id(Id) -> Id + 1. + diff --git a/src/emqx_topic.erl b/src/emqx_topic.erl index 4c90c3f39..bb615ccfe 100644 --- a/src/emqx_topic.erl +++ b/src/emqx_topic.erl @@ -144,7 +144,8 @@ prepend(Parent0, W) -> bin('') -> <<>>; bin('+') -> <<"+">>; bin('#') -> <<"#">>; -bin(B) when is_binary(B) -> B. +bin(B) when is_binary(B) -> B; +bin(L) when is_list(L) -> list_to_binary(L). levels(Topic) when is_binary(Topic) -> length(words(Topic)). diff --git a/src/portal/emqx_portal.erl b/src/portal/emqx_portal.erl index 623f7f233..51ce4a4dc 100644 --- a/src/portal/emqx_portal.erl +++ b/src/portal/emqx_portal.erl @@ -78,7 +78,7 @@ ]). -type config() :: map(). --type batch() :: [emqx_portal_msg:msg()]. +-type batch() :: [emqx_portal_msg:exp_msg()]. -type ack_ref() :: term(). -include("logger.hrl"). @@ -352,7 +352,7 @@ maybe_send(#{connect_module := Module, connection := Connection, mountpoint := Mountpoint }, Batch) -> - Module:send(Connection, [emqx_portal_msg:to_export(M, Mountpoint) || M <- Batch]). + Module:send(Connection, [emqx_portal_msg:to_export(Module, Mountpoint, M) || M <- Batch]). format_mountpoint(undefined) -> undefined; diff --git a/src/portal/emqx_portal_mqtt.erl b/src/portal/emqx_portal_mqtt.erl index 0ce5140b0..f01633111 100644 --- a/src/portal/emqx_portal_mqtt.erl +++ b/src/portal/emqx_portal_mqtt.erl @@ -28,7 +28,8 @@ -define(ACK_REF(ClientPid, PktId), {ClientPid, PktId}). %% Messages towards ack collector process --define(SENT(MaxPktId), {sent, MaxPktId}). +-define(RANGE(Min, Max), {Min, Max}). +-define(SENT(PktIdRange), {sent, PktIdRange}). -define(ACKED(AnyPktId), {acked, AnyPktId}). -define(STOP(Ref), {stop, Ref}). @@ -41,10 +42,17 @@ start(Config) -> {ok, Pid} -> case emqx_client:connect(Pid) of {ok, _} -> - %% ack collector is always a new pid every reconnect. - %% use it as a connection reference - {ok, Ref, #{ack_collector => AckCollector, - client_pid => Pid}}; + try + subscribe_remote_topics(Pid, maps:get(subscriptions, Config, [])), + %% ack collector is always a new pid every reconnect. + %% use it as a connection reference + {ok, Ref, #{ack_collector => AckCollector, + client_pid => Pid}} + catch + throw : Reason -> + ok = stop(AckCollector, Pid), + {error, Reason} + end; {error, Reason} -> ok = stop(AckCollector, Pid), {error, Reason} @@ -53,72 +61,79 @@ start(Config) -> {error, Reason} end. -stop(Ref, #{ack_collector := AckCollector, - client_pid := Pid}) -> - MRef = monitor(process, AckCollector), - unlink(AckCollector), - _ = AckCollector ! ?STOP(Ref), +stop(Ref, #{ack_collector := AckCollector, client_pid := Pid}) -> + safe_stop(AckCollector, fun() -> AckCollector ! ?STOP(Ref) end, 1000), + safe_stop(Pid, fun() -> emqx_client:stop(Pid) end, 1000), + ok. + +safe_stop(Pid, StopF, Timeout) -> + MRef = monitor(process, Pid), + unlink(Pid), + try + StopF() + catch + _ : _ -> + ok + end, receive {'DOWN', MRef, _, _, _} -> ok after - 1000 -> - exit(AckCollector, kill) - end, - _ = emqx_client:stop(Pid), - ok. + Timeout -> + exit(Pid, kill) + end. -send(#{client_pid := ClientPid, ack_collector := AckCollector}, Batch) -> - send_loop(ClientPid, AckCollector, Batch). - -send_loop(ClientPid, AckCollector, [Msg | Rest]) -> - case emqx_client:publish(ClientPid, Msg) of - {ok, PktId} when Rest =:= [] -> - Rest =:= [] andalso AckCollector ! ?SENT(PktId), - {ok, PktId}; - {ok, _PktId} -> - send_loop(ClientPid, AckCollector, Rest); +send(#{client_pid := ClientPid, ack_collector := AckCollector} = Conn, Batch) -> + case emqx_client:publish(ClientPid, Batch) of + {ok, BasePktId} -> + LastPktId = ?BUMP_PACKET_ID(BasePktId, length(Batch) - 1), + AckCollector ! ?SENT(?RANGE(BasePktId, LastPktId)), + %% return last pakcet id as batch reference + {ok, LastPktId}; {error, {_PacketId, inflight_full}} -> timer:sleep(100), - send_loop(ClientPid, AckCollector, [Msg | Rest]); + send(Conn, Batch); {error, Reason} -> - %% There is no partial sucess of a batch and recover from the middle + %% NOTE: There is no partial sucess of a batch and recover from the middle %% only to retry all messages in one batch {error, Reason} end. ack_collector(Parent, ConnRef) -> - ack_collector(Parent, ConnRef, []). + ack_collector(Parent, ConnRef, queue:new(), []). -ack_collector(Parent, ConnRef, PktIds) -> - NewIds = +ack_collector(Parent, ConnRef, Acked, Sent) -> + {NewAcked, NewSent} = receive ?STOP(ConnRef) -> exit(normal); - ?SENT(PktId) -> - %% this ++ only happens per-BATCH, hence no optimization - PktIds ++ [PktId]; ?ACKED(PktId) -> - handle_ack(Parent, PktId, PktIds) + match_acks(Parent, queue:in(PktId, Acked), Sent); + ?SENT(Range) -> + %% this message only happens per-batch, hence ++ is ok + match_acks(Parent, Acked, Sent ++ [Range]) after 200 -> - PktIds + {Acked, Sent} end, - ack_collector(Parent, ConnRef, NewIds). + ack_collector(Parent, ConnRef, NewAcked, NewSent). -handle_ack(Parent, PktId, [PktId | Rest]) -> - %% A batch is finished, time to ack portal +match_acks(_Parent, Acked, []) -> {Acked, []}; +match_acks(Parent, Acked, Sent) -> + match_acks_1(Parent, queue:out(Acked), Sent). + +match_acks_1(_Parent, {empty, Empty}, Sent) -> {Empty, Sent}; +match_acks_1(Parent, {{value, PktId}, Acked}, [?RANGE(PktId, PktId) | Sent]) -> + %% batch finished ok = emqx_portal:handle_ack(Parent, PktId), - Rest; -handle_ack(_Parent, PktId, [BatchMaxPktId | _] = All) -> - %% partial ack of a batch, terminate here. - true = (PktId < BatchMaxPktId), %% bad order otherwise - All. + match_acks(Parent, Acked, Sent); +match_acks_1(Parent, {{value, PktId}, Acked}, [?RANGE(PktId, Max) | Sent]) -> + match_acks(Parent, Acked, [?RANGE(PktId + 1, Max) | Sent]). %% When puback for QoS-1 message is received from remote MQTT broker %% NOTE: no support for QoS-2 handle_puback(AckCollector, #{packet_id := PktId, reason_code := RC}) -> - RC =:= ?RC_SUCCESS andalso error(RC), + RC =:= ?RC_SUCCESS orelse error({puback_error_code, RC}), AckCollector ! ?ACKED(PktId), ok. @@ -133,3 +148,10 @@ make_hdlr(Parent, AckCollector, Ref) -> disconnected => fun(RC, _Properties) -> Parent ! {disconnected, Ref, RC}, ok end }. +subscribe_remote_topics(ClientPid, Subscriptions) -> + [case emqx_client:subscribe(ClientPid, {bin(Topic), Qos}) of + {ok, _, _} -> ok; + Error -> throw(Error) + end || {Topic, Qos} <- Subscriptions, emqx_topic:validate({filter, bin(Topic)})]. + +bin(L) -> iolist_to_binary(L). diff --git a/src/portal/emqx_portal_msg.erl b/src/portal/emqx_portal_msg.erl index 12f5926a3..f8554f0b6 100644 --- a/src/portal/emqx_portal_msg.erl +++ b/src/portal/emqx_portal_msg.erl @@ -16,7 +16,7 @@ -export([ to_binary/1 , from_binary/1 - , to_export/2 + , to_export/3 , to_broker_msgs/1 , estimate_size/1 ]). @@ -25,14 +25,32 @@ -include("emqx.hrl"). -include("emqx_mqtt.hrl"). +-include("emqx_client.hrl"). -type msg() :: emqx_types:message(). +-type exp_msg() :: emqx_types:message() | #mqtt_msg{}. %% @doc Make export format: %% 1. Mount topic to a prefix -%% 2. fix QoS to 1 --spec to_export(msg(), undefined | binary()) -> msg(). -to_export(#message{topic = Topic} = Msg, Mountpoint) -> +%% 2. Fix QoS to 1 +%% @end +%% Shame that we have to know the callback module here +%% would be great if we can get rid of #mqtt_msg{} record +%% and use #message{} in all places. +-spec to_export(emqx_portal_rpc | emqx_portal_mqtt, + undefined | binary(), msg()) -> exp_msg(). +to_export(emqx_portal_mqtt, Mountpoint, + #message{topic = Topic, + payload = Payload, + flags = Flags + }) -> + Retain = maps:get(retain, Flags, false), + #mqtt_msg{qos = ?QOS_1, + retain = Retain, + topic = topic(Mountpoint, Topic), + payload = Payload}; +to_export(_Module, Mountpoint, + #message{topic = Topic} = Msg) -> Msg#message{topic = topic(Mountpoint, Topic), qos = 1}. %% @doc Make `binary()' in order to make iodata to be persisted on disk. diff --git a/test/emqx_portal_SUITE.erl b/test/emqx_portal_SUITE.erl index 21effdb08..f8ca04ea2 100644 --- a/test/emqx_portal_SUITE.erl +++ b/test/emqx_portal_SUITE.erl @@ -43,6 +43,7 @@ init_per_suite(Config) -> end_per_suite(_Config) -> emqx_ct_broker_helpers:run_teardown_steps(). +%% A loopback RPC to local node t_rpc(Config) when is_list(Config) -> Cfg = #{address => node(), forwards => [<<"t_rpc/#">>], @@ -68,6 +69,74 @@ t_rpc(Config) when is_list(Config) -> ok = emqx_portal:stop(Pid) end. -t_mqtt(Config) when is_list(Config) -> ok. +t_mqtt(Config) when is_list(Config) -> + SendToTopic = <<"t_mqtt/one">>, + Mountpoint = <<"forwarded/${node}/">>, + ForwardedTopic = emqx_topic:join(["forwarded", atom_to_list(node()), SendToTopic]), + Cfg = #{address => "127.0.0.1:1883", + forwards => [SendToTopic], + connect_module => emqx_portal_mqtt, + mountpoint => Mountpoint, + username => "user", + clean_start => true, + client_id => "bridge_aws", + keepalive => 60000, + max_inflight => 32, + password => "passwd", + proto_ver => mqttv4, + queue => #{replayq_dir => "data/t_mqtt/", + replayq_seg_bytes => 10000, + batch_bytes_limit => 1000, + batch_count_limit => 10 + }, + reconnect_delay_ms => 1000, + ssl => false, + start_type => manual, + %% Consume back to forwarded message for verification + %% NOTE: this is a indefenite loopback without mocking emqx_portal:import_batch/2 + subscriptions => [{ForwardedTopic, 1}] + }, + Tester = self(), + Ref = make_ref(), + meck:new(emqx_portal, [passthrough, no_history]), + meck:expect(emqx_portal, import_batch, 2, + fun(Batch, AckFun) -> + Tester ! {Ref, Batch}, + AckFun() + end), + {ok, Pid} = emqx_portal:start_link(?FUNCTION_NAME, Cfg), + ClientId = <<"client-1">>, + try + {ok, ConnPid} = emqx_mock_client:start_link(ClientId), + {ok, SPid} = emqx_mock_client:open_session(ConnPid, ClientId, internal), + %% message from a different client, to avoid getting terminated by no-local + Msgs = lists:seq(1, 10), + lists:foreach(fun(I) -> + Msg = emqx_message:make(<<"client-2">>, ?QOS_1, SendToTopic, integer_to_binary(I)), + emqx_session:publish(SPid, I, Msg) + end, Msgs), + ok = receive_and_match_messages(Ref, Msgs), + emqx_mock_client:close_session(ConnPid) + after + ok = emqx_portal:stop(Pid), + meck:unload(emqx_portal) + end. +receive_and_match_messages(Ref, Msgs) -> + TRef = erlang:send_after(timer:seconds(4), self(), {Ref, timeout}), + try + do_receive_and_match_messages(Ref, Msgs) + after + erlang:cancel_timer(TRef) + end, + ok. + +do_receive_and_match_messages(_Ref, []) -> ok; +do_receive_and_match_messages(Ref, [I | Rest]) -> + receive + {Ref, timeout} -> erlang:error(timeout); + {Ref, [#{payload := P}]} -> + ?assertEqual(I, binary_to_integer(P)), + do_receive_and_match_messages(Ref, Rest) + end. diff --git a/test/emqx_portal_mqtt_tests.erl b/test/emqx_portal_mqtt_tests.erl index 0312bca49..8f513a853 100644 --- a/test/emqx_portal_mqtt_tests.erl +++ b/test/emqx_portal_mqtt_tests.erl @@ -14,23 +14,28 @@ -module(emqx_portal_mqtt_tests). -include_lib("eunit/include/eunit.hrl"). +-include("emqx_mqtt.hrl"). send_and_ack_test() -> %% delegate from gen_rpc to rpc for unit test Tester = self(), meck:new(emqx_client, [passthrough, no_history]), meck:expect(emqx_client, start_link, 1, - fun(#{msg_handler := Hdlr}) -> {ok, Hdlr} end), + fun(#{msg_handler := Hdlr}) -> + {ok, spawn_link(fun() -> fake_client(Hdlr) end)} + end), meck:expect(emqx_client, connect, 1, {ok, dummy}), - meck:expect(emqx_client, stop, 1, ok), + meck:expect(emqx_client, stop, 1, + fun(Pid) -> Pid ! stop end), meck:expect(emqx_client, publish, 2, - fun(_Conn, Msg) -> + fun(_Conn, Msgs) -> case rand:uniform(100) of 1 -> {error, {dummy, inflight_full}}; _ -> - Tester ! {published, Msg}, - {ok, Msg} + BaseId = hd(Msgs), + Tester ! {published, Msgs}, + {ok, BaseId} end end), try @@ -39,24 +44,38 @@ send_and_ack_test() -> {ok, Ref, Conn} = emqx_portal_mqtt:start(#{}), %% return last packet id as batch reference {ok, AckRef} = emqx_portal_mqtt:send(Conn, Batch), + %% as if the remote broker replied with puback + ok = fake_pubacks(Conn), %% expect batch ack - {ok, LastId} = collect_acks(Conn, Batch), + AckRef1= receive {batch_ack, Id} -> Id end, %% asset received ack matches the batch ref returned in send API - ?assertEqual(AckRef, LastId), + ?assertEqual(AckRef, AckRef1), ok = emqx_portal_mqtt:stop(Ref, Conn) after meck:unload(emqx_client) end. -collect_acks(_Conn, []) -> - receive {batch_ack, Id} -> {ok, Id} end; -collect_acks(#{client_pid := Client} = Conn, [Id | Rest]) -> - %% mocked for testing, should be a pid() at runtime - #{puback := PubAckCallback} = Client, +fake_pubacks(#{client_pid := Client}) -> + #{puback := PubAckCallback} = get_hdlr(Client), receive - {published, Id} -> - PubAckCallback(#{packet_id => Id, reason_code => dummy}), - collect_acks(Conn, Rest) + {published, Msgs} -> + lists:foreach( + fun(Id) -> + PubAckCallback(#{packet_id => Id, reason_code => ?RC_SUCCESS}) + end, Msgs) + end. + +get_hdlr(Client) -> + Client ! {get_hdlr, self()}, + receive {hdr, Hdlr} -> Hdlr end. + +fake_client(Hdlr) -> + receive + {get_hdlr, Pid} -> + Pid ! {hdr, Hdlr}, + fake_client(Hdlr); + stop -> + exit(normal) end. diff --git a/test/emqx_portal_tests.erl b/test/emqx_portal_tests.erl index 85975da3e..3b2879b12 100644 --- a/test/emqx_portal_tests.erl +++ b/test/emqx_portal_tests.erl @@ -16,6 +16,7 @@ -behaviour(emqx_portal_connect). -include_lib("eunit/include/eunit.hrl"). +-include("emqx.hrl"). -include("emqx_mqtt.hrl"). -define(PORTAL_NAME, test). @@ -120,7 +121,7 @@ random_sleep(MaxInterval) -> end. match_nums([], Rest) -> Rest; -match_nums([#{payload := P} | Rest], Nums) -> +match_nums([#message{payload = P} | Rest], Nums) -> I = binary_to_integer(P), case Nums of [I | NumsLeft] -> match_nums(Rest, NumsLeft); @@ -137,11 +138,5 @@ make_config(Ref, TestPid, Result) -> make_msg(I) -> Payload = integer_to_binary(I), - #{qos => ?QOS_1, - dup => false, - retain => false, - topic => <<"test/topic">>, - properties => [], - payload => Payload - }. + emqx_message:make(<<"test/topic">>, Payload).