diff --git a/src/emqx_message.erl b/src/emqx_message.erl index 25e6492bd..aeae37acc 100644 --- a/src/emqx_message.erl +++ b/src/emqx_message.erl @@ -17,13 +17,34 @@ -include("emqx.hrl"). -include("emqx_mqtt.hrl"). --export([make/2, make/3, make/4]). +-export([ make/2 + , make/3 + , make/4 ]). + +-export([ get_flag/2 + , get_flag/3 + , set_flag/2 + , set_flag/3 + , unset_flag/2 + ]). -export([set_flags/2]). --export([get_flag/2, get_flag/3, set_flag/2, set_flag/3, unset_flag/2]). + +-export([ get_headers/1 + , get_header/2 + , get_header/3 + , set_header/3 + , remove_header/2 + ]). -export([set_headers/2]). --export([get_header/2, get_header/3, set_header/3]). --export([is_expired/1, update_expiry/1]). --export([remove_topic_alias/1]). + +-export([ is_expired/1 + , update_expiry/1 + ]). + +-export([ to_map/1 + , to_list/1 + ]). + -export([format/1]). -type(flag() :: atom()). @@ -40,13 +61,13 @@ make(From, Topic, Payload) -> -spec(make(atom() | emqx_types:client_id(), emqx_mqtt_types:qos(), emqx_topic:topic(), emqx_types:payload()) -> emqx_types:message()). make(From, QoS, Topic, Payload) -> - #message{id = emqx_guid:gen(), - qos = QoS, - from = From, - flags = #{dup => false}, - topic = Topic, - payload = Payload, - timestamp = os:timestamp()}. + #message{id = emqx_guid:gen(), + qos = QoS, + from = From, + flags = #{dup => false}, + topic = Topic, + payload = Payload, + timestamp = os:timestamp()}. -spec(set_flags(map(), emqx_types:message()) -> emqx_types:message()). set_flags(Flags, Msg = #message{flags = undefined}) when is_map(Flags) -> @@ -88,6 +109,10 @@ set_headers(New, Msg = #message{headers = Old}) when is_map(New) -> Msg#message{headers = maps:merge(Old, New)}; set_headers(undefined, Msg) -> Msg. +-spec(get_headers(emqx_types:message()) -> map()). +get_headers(Msg) -> + Msg#message.headers. + -spec(get_header(term(), emqx_types:message()) -> term()). get_header(Hdr, Msg) -> get_header(Hdr, Msg, undefined). @@ -101,14 +126,24 @@ set_header(Hdr, Val, Msg = #message{headers = undefined}) -> set_header(Hdr, Val, Msg = #message{headers = Headers}) -> Msg#message{headers = maps:put(Hdr, Val, Headers)}. +-spec(remove_header(term(), emqx_types:message()) -> emqx_types:message()). +remove_header(Hdr, Msg = #message{headers = Headers}) -> + case maps:is_key(Hdr, Headers) of + true -> + Msg#message{headers = maps:remove(Hdr, Headers)}; + false -> Msg + end. + -spec(is_expired(emqx_types:message()) -> boolean()). -is_expired(#message{headers = #{'Message-Expiry-Interval' := Interval}, timestamp = CreatedAt}) -> +is_expired(#message{headers = #{'Message-Expiry-Interval' := Interval}, + timestamp = CreatedAt}) -> elapsed(CreatedAt) > timer:seconds(Interval); is_expired(_Msg) -> false. -spec(update_expiry(emqx_types:message()) -> emqx_types:message()). -update_expiry(Msg = #message{headers = #{'Message-Expiry-Interval' := Interval}, timestamp = CreatedAt}) -> +update_expiry(Msg = #message{headers = #{'Message-Expiry-Interval' := Interval}, + timestamp = CreatedAt}) -> case elapsed(CreatedAt) of Elapsed when Elapsed > 0 -> set_header('Message-Expiry-Interval', max(1, Interval - (Elapsed div 1000)), Msg); @@ -116,14 +151,21 @@ update_expiry(Msg = #message{headers = #{'Message-Expiry-Interval' := Interval}, end; update_expiry(Msg) -> Msg. -remove_topic_alias(Msg = #message{headers = Headers}) -> - Msg#message{headers = maps:remove('Topic-Alias', Headers)}. +%% @doc Message to map +-spec(to_map(emqx_types:message()) -> map()). +to_map(Msg) -> + maps:from_list(to_list(Msg)). + +%% @doc Message to tuple list +-spec(to_list(emqx_types:message()) -> map()). +to_list(Msg) -> + lists:zip(record_info(fields, message), tl(tuple_to_list(Msg))). %% MilliSeconds elapsed(Since) -> max(0, timer:now_diff(os:timestamp(), Since) div 1000). -format(#message{id = Id, qos = QoS, topic = Topic, from = From, flags = Flags, headers = Headers}) -> +format(#message{id = Id,qos = QoS, topic = Topic, from = From, flags = Flags, headers = Headers}) -> io_lib:format("Message(Id=~s, QoS=~w, Topic=~s, From=~p, Flags=~s, Headers=~s)", [Id, QoS, Topic, From, format(flags, Flags), format(headers, Headers)]). @@ -133,3 +175,4 @@ format(flags, Flags) -> io_lib:format("~p", [[Flag || {Flag, true} <- maps:to_list(Flags)]]); format(headers, Headers) -> io_lib:format("~p", [Headers]). + diff --git a/src/emqx_protocol.erl b/src/emqx_protocol.erl index 4548821b0..2b4ec6fbf 100644 --- a/src/emqx_protocol.erl +++ b/src/emqx_protocol.erl @@ -654,7 +654,7 @@ deliver({publish, PacketId, Msg}, PState = #pstate{mountpoint = MountPoint}) -> Msg0 = emqx_hooks:run_fold('message.deliver', [credentials(PState)], Msg), Msg1 = emqx_message:update_expiry(Msg0), Msg2 = emqx_mountpoint:unmount(MountPoint, Msg1), - send(emqx_packet:from_message(PacketId, emqx_message:remove_topic_alias(Msg2)), PState); + send(emqx_packet:from_message(PacketId, Msg2), PState); deliver({puback, PacketId, ReasonCode}, PState) -> send(?PUBACK_PACKET(PacketId, ReasonCode), PState); diff --git a/test/emqx_message_SUITE.erl b/test/emqx_message_SUITE.erl index 0975585cc..74f6702ce 100644 --- a/test/emqx_message_SUITE.erl +++ b/test/emqx_message_SUITE.erl @@ -24,12 +24,12 @@ -include_lib("eunit/include/eunit.hrl"). all() -> - [ - message_make, - message_flag, - message_header, - message_format, - message_expired + [ message_make + , message_flag + , message_header + , message_format + , message_expired + , message_to_map ]. message_make(_) -> @@ -60,7 +60,9 @@ message_header(_) -> Msg1 = emqx_message:set_headers(#{a => 1, b => 2}, Msg), Msg2 = emqx_message:set_header(c, 3, Msg1), ?assertEqual(1, emqx_message:get_header(a, Msg2)), - ?assertEqual(4, emqx_message:get_header(d, Msg2, 4)). + ?assertEqual(4, emqx_message:get_header(d, Msg2, 4)), + Msg3 = emqx_message:remove_header(a, Msg2), + ?assertEqual(#{b => 2, c => 3}, emqx_message:get_headers(Msg3)). message_format(_) -> io:format("~s", [emqx_message:format(emqx_message:make(<<"clientid">>, <<"topic">>, <<"payload">>))]). @@ -75,3 +77,17 @@ message_expired(_) -> timer:sleep(1000), Msg2 = emqx_message:update_expiry(Msg1), ?assertEqual(1, emqx_message:get_header('Message-Expiry-Interval', Msg2)). + +message_to_map(_) -> + Msg = emqx_message:make(<<"clientid">>, ?QOS_1, <<"topic">>, <<"payload">>), + List = [{id, Msg#message.id}, + {qos, ?QOS_1}, + {from, <<"clientid">>}, + {flags, #{dup => false}}, + {headers, #{}}, + {topic, <<"topic">>}, + {payload, <<"payload">>}, + {timestamp, Msg#message.timestamp}], + ?assertEqual(List, emqx_message:to_list(Msg)), + ?assertEqual(maps:from_list(List), emqx_message:to_map(Msg)). +