diff --git a/src/emqx_channel.erl b/src/emqx_channel.erl index 03544b9da..2419f6884 100644 --- a/src/emqx_channel.erl +++ b/src/emqx_channel.erl @@ -68,7 +68,7 @@ %% MQTT Will Msg will_msg :: maybe(emqx_types:message()), %% MQTT Topic Aliases - topic_aliases :: maybe(map()), + topic_aliases :: emqx_types:topic_aliases(), %% MQTT Topic Alias Maximum alias_maximum :: maybe(map()), %% Timers @@ -180,6 +180,9 @@ init(ConnInfo = #{peername := {PeerHost, _Port}, }, #channel{conninfo = ConnInfo, clientinfo = ClientInfo, + topic_aliases = #{inbound => #{}, + outbound => #{} + }, timers = #{}, conn_state = idle, takeover = false, @@ -599,8 +602,8 @@ handle_out(publish, [], Channel) -> {ok, Channel}; handle_out(publish, Publishes, Channel) -> - Packets = do_deliver(Publishes, Channel), - {ok, {outgoing, Packets}, Channel}; + {Packets, NChannel} = do_deliver(Publishes, Channel), + {ok, {outgoing, Packets}, NChannel}; handle_out(puback, {PacketId, ReasonCode}, Channel) -> {ok, ?PUBACK_PACKET(PacketId, ReasonCode), Channel}; @@ -658,9 +661,9 @@ return_connack(AckPacket, Channel) -> resuming = false, pendings = [] }, - Packets = do_deliver(Publishes, NChannel), + {Packets, NChannel1} = do_deliver(Publishes, NChannel), Outgoing = [{outgoing, Packets} || length(Packets) > 0], - {ok, Replies ++ Outgoing, NChannel} + {ok, Replies ++ Outgoing, NChannel1} end. %%-------------------------------------------------------------------- @@ -668,16 +671,16 @@ return_connack(AckPacket, Channel) -> %%-------------------------------------------------------------------- %% return list(emqx_types:packet()) -do_deliver({pubrel, PacketId}, _Channel) -> - [?PUBREL_PACKET(PacketId, ?RC_SUCCESS)]; +do_deliver({pubrel, PacketId}, Channel) -> + {[?PUBREL_PACKET(PacketId, ?RC_SUCCESS)], Channel}; -do_deliver({PacketId, Msg}, #channel{clientinfo = ClientInfo = +do_deliver({PacketId, Msg}, Channel = #channel{clientinfo = ClientInfo = #{mountpoint := MountPoint}}) -> case ignore_local(Msg, ClientInfo) of true -> ok = emqx_metrics:inc('delivery.dropped'), ok = emqx_metrics:inc('delivery.dropped.no_local'), - []; + {[], Channel}; false -> ok = emqx_metrics:inc('messages.delivered'), Msg1 = emqx_hooks:run_fold('message.delivered', @@ -685,18 +688,21 @@ do_deliver({PacketId, Msg}, #channel{clientinfo = ClientInfo = emqx_message:update_expiry(Msg) ), Msg2 = emqx_mountpoint:unmount(MountPoint, Msg1), - [emqx_message:to_packet(PacketId, Msg2)] + Packet = emqx_message:to_packet(PacketId, Msg2), + {NPacket, NChannel} = packing_alias(Packet, Channel), + {[NPacket], NChannel} end; do_deliver([Publish], Channel) -> do_deliver(Publish, Channel); do_deliver(Publishes, Channel) when is_list(Publishes) -> - lists:reverse( - lists:foldl( - fun(Publish, Acc) -> - lists:append(do_deliver(Publish, Channel), Acc) - end, [], Publishes)). + {Packets, NChannel} = + lists:foldl(fun(Publish, {Acc, Chann}) -> + {Packets, NChann} = do_deliver(Publish, Chann), + {Packets ++ Acc, NChann} + end, {[], Channel}, Publishes), + {lists:reverse(Packets), NChannel}. ignore_local(#message{flags = #{nl := true}, from = ClientId}, #{clientid := ClientId}) -> @@ -1072,8 +1078,8 @@ process_alias(Packet = #mqtt_packet{ properties = #{'Topic-Alias' := AliasId} } = Publish }, - Channel = #channel{topic_aliases = Aliases}) -> - case find_alias(AliasId, Aliases) of + Channel = #channel{topic_aliases = TopicAliases}) -> + case find_alias(inbound, AliasId, TopicAliases) of {ok, Topic} -> NPublish = Publish#mqtt_packet_publish{topic_name = Topic}, {ok, Packet#mqtt_packet{variable = NPublish}, Channel}; @@ -1085,12 +1091,44 @@ process_alias(#mqtt_packet{ properties = #{'Topic-Alias' := AliasId} } }, - Channel = #channel{topic_aliases = Aliases}) -> - NAliases = save_alias(AliasId, Topic, Aliases), - {ok, Channel#channel{topic_aliases = NAliases}}; + Channel = #channel{topic_aliases = TopicAliases}) -> + NTopicAliases = save_alias(inbound, AliasId, Topic, TopicAliases), + {ok, Channel#channel{topic_aliases = NTopicAliases}}; process_alias(_Packet, Channel) -> {ok, Channel}. +%%-------------------------------------------------------------------- +%% Packing Topic Alias + +packing_alias(Packet = #mqtt_packet{ + variable = #mqtt_packet_publish{topic_name = Topic} = Publish + }, + Channel = #channel{topic_aliases = TopicAliases, alias_maximum = Limits}) -> + case find_alias(outbound, Topic, TopicAliases) of + {ok, AliasId} -> + NPublish = Publish#mqtt_packet_publish{ + topic_name = <<>>, + properties = #{'Topic-Alias' => AliasId} + }, + {Packet#mqtt_packet{variable = NPublish}, Channel}; + error -> + #{outbound := Aliases} = TopicAliases, + AliasId = maps:size(Aliases) + 1, + case (Limits =:= undefined) orelse + (AliasId =< maps:get(outbound, Limits, 0)) of + true -> + NTopicAliases = save_alias(outbound, AliasId, Topic, TopicAliases), + NChannel = Channel#channel{topic_aliases = NTopicAliases}, + NPublish = Publish#mqtt_packet_publish{ + topic_name = Topic, + properties = #{'Topic-Alias' => AliasId} + }, + {Packet#mqtt_packet{variable = NPublish}, NChannel}; + false -> {Packet, Channel} + end + end; +packing_alias(Packet, Channel) -> {Packet, Channel}. + %%-------------------------------------------------------------------- %% Check Pub Alias @@ -1346,16 +1384,21 @@ run_hooks(Name, Args) -> run_hooks(Name, Args, Acc) -> ok = emqx_metrics:inc(Name), emqx_hooks:run_fold(Name, Args, Acc). --compile({inline, [find_alias/2, save_alias/3]}). +-compile({inline, [find_alias/3, save_alias/4]}). -find_alias(_AliasId, undefined) -> false; -find_alias(AliasId, Aliases) -> - maps:find(AliasId, Aliases). +find_alias(_, _ ,undefined) -> false; +find_alias(inbound, AliasId, _TopicAliases = #{inbound := Aliases}) -> + maps:find(AliasId, Aliases); +find_alias(outbound, Topic, _TopicAliases = #{outbound := Aliases}) -> + maps:find(Topic, Aliases). -save_alias(AliasId, Topic, undefined) -> - #{AliasId => Topic}; -save_alias(AliasId, Topic, Aliases) -> - maps:put(AliasId, Topic, Aliases). +save_alias(_, _, _, undefined) -> false; +save_alias(inbound, AliasId, Topic, TopicAliases = #{inbound := Aliases}) -> + NAliases = maps:put(AliasId, Topic, Aliases), + TopicAliases#{inbound => NAliases}; +save_alias(outbound, AliasId, Topic, TopicAliases = #{outbound := Aliases}) -> + NAliases = maps:put(Topic, AliasId, Aliases), + TopicAliases#{outbound => NAliases}. -compile({inline, [reply/2, shutdown/2, shutdown/3, sp/1, flag/1]}). diff --git a/src/emqx_types.erl b/src/emqx_types.erl index f00500409..e304eafcf 100644 --- a/src/emqx_types.erl +++ b/src/emqx_types.erl @@ -47,6 +47,7 @@ , subopts/0 , reason_code/0 , alias_id/0 + , topic_aliases/0 , properties/0 ]). @@ -165,6 +166,8 @@ -type(reason_code() :: 0..16#FF). -type(packet_id() :: 1..16#FFFF). -type(alias_id() :: 0..16#FFFF). +-type(topic_aliases() :: #{inbound => maybe(map()), + outbound => maybe(map())}). -type(properties() :: #{atom() => term()}). -type(topic_filters() :: list({topic(), subopts()})). -type(packet() :: #mqtt_packet{}). diff --git a/test/emqx_channel_SUITE.erl b/test/emqx_channel_SUITE.erl index b52d9ed07..342ba8bae 100644 --- a/test/emqx_channel_SUITE.erl +++ b/test/emqx_channel_SUITE.erl @@ -482,10 +482,26 @@ t_auth_connect(_) -> t_process_alias(_) -> Publish = #mqtt_packet_publish{topic_name = <<>>, properties = #{'Topic-Alias' => 1}}, - Channel = emqx_channel:set_field(topic_aliases, #{1 => <<"t">>}, channel()), + Channel = emqx_channel:set_field(topic_aliases, #{inbound => #{1 => <<"t">>}}, channel()), {ok, #mqtt_packet{variable = #mqtt_packet_publish{topic_name = <<"t">>}}, _Chan} = emqx_channel:process_alias(#mqtt_packet{variable = Publish}, Channel). +t_packing_alias(_) -> + Packet1 = #mqtt_packet{variable = #mqtt_packet_publish{topic_name = <<"x">>}}, + Packet2 = #mqtt_packet{variable = #mqtt_packet_publish{topic_name = <<"y">>}}, + Channel = emqx_channel:set_field(alias_maximum, #{outbound => 1}, channel()), + + {RePacket1, NChannel1} = emqx_channel:packing_alias(Packet1, Channel), + ?assertEqual(#mqtt_packet{variable = #mqtt_packet_publish{topic_name = <<"x">>, properties = #{'Topic-Alias' => 1}}}, RePacket1), + + {RePacket2, NChannel2} = emqx_channel:packing_alias(Packet1, NChannel1), + ?assertEqual(#mqtt_packet{variable = #mqtt_packet_publish{topic_name = <<>>, properties = #{'Topic-Alias' => 1}}}, RePacket2), + + {RePacket3, _} = emqx_channel:packing_alias(Packet2, NChannel2), + ?assertEqual(#mqtt_packet{variable = #mqtt_packet_publish{topic_name = <<"y">>, properties = undefined}}, RePacket3), + + ?assertMatch({#mqtt_packet{variable = #mqtt_packet_publish{topic_name = <<"z">>}}, _}, emqx_channel:packing_alias(#mqtt_packet{variable = #mqtt_packet_publish{topic_name = <<"z">>}}, channel())). + t_check_pub_acl(_) -> ok = meck:new(emqx_zone, [passthrough, no_history]), ok = meck:expect(emqx_zone, enable_acl, fun(_) -> true end), diff --git a/test/mqtt_protocol_v5_SUITE.erl b/test/mqtt_protocol_v5_SUITE.erl index 3c758f993..cd32e9bd8 100644 --- a/test/mqtt_protocol_v5_SUITE.erl +++ b/test/mqtt_protocol_v5_SUITE.erl @@ -62,7 +62,7 @@ receive_messages(Count, Msgs) -> receive_messages(Count-1, [Msg|Msgs]); _Other -> receive_messages(Count, Msgs) - after 100 -> + after 1000 -> Msgs end. @@ -605,6 +605,33 @@ t_publish_overlapping_subscriptions(_) -> %% Subsctibe %%-------------------------------------------------------------------- +t_subscribe_topic_alias(_) -> + Topic1 = nth(1, ?TOPICS), + Topic2 = nth(2, ?TOPICS), + {ok, Client1} = emqtt:start_link([{proto_ver, v5}, + {properties, #{'Topic-Alias-Maximum' => 1}} + ]), + {ok, _} = emqtt:connect(Client1), + {ok, _, [2]} = emqtt:subscribe(Client1, Topic1, qos2), + {ok, _, [2]} = emqtt:subscribe(Client1, Topic2, qos2), + + ok = emqtt:publish(Client1, Topic1, #{}, <<"Topic-Alias">>, [{qos, ?QOS_0}]), + [Msg1] = receive_messages(1), + ?assertEqual({ok, #{'Topic-Alias' => 1}}, maps:find(properties, Msg1)), + ?assertEqual({ok, Topic1}, maps:find(topic, Msg1)), + + ok = emqtt:publish(Client1, Topic1, #{}, <<"Topic-Alias">>, [{qos, ?QOS_0}]), + [Msg2] = receive_messages(1), + ?assertEqual({ok, #{'Topic-Alias' => 1}}, maps:find(properties, Msg2)), + ?assertEqual({ok, <<>>}, maps:find(topic, Msg2)), + + ok = emqtt:publish(Client1, Topic2, #{}, <<"Topic-Alias">>, [{qos, ?QOS_0}]), + [Msg3] = receive_messages(1), + ?assertEqual({ok, #{}}, maps:find(properties, Msg3)), + ?assertEqual({ok, Topic2}, maps:find(topic, Msg3)), + + ok = emqtt:disconnect(Client1). + t_subscribe_no_local(_) -> Topic = nth(1, ?TOPICS),