Add function 'get_flags/1' for test

This commit is contained in:
Feng Lee 2019-09-19 11:19:10 +08:00
parent 2ed9e9480e
commit 3d6b96d321
1 changed files with 6 additions and 2 deletions

View File

@ -38,6 +38,7 @@
%% Flags %% Flags
-export([ get_flag/2 -export([ get_flag/2
, get_flag/3 , get_flag/3
, get_flags/1
, set_flag/2 , set_flag/2
, set_flag/3 , set_flag/3
, set_flags/2 , set_flags/2
@ -85,6 +86,7 @@ make(From, QoS, Topic, Payload) when ?QOS_0 =< QoS, QoS =< ?QOS_2 ->
qos = QoS, qos = QoS,
from = From, from = From,
flags = #{dup => false}, flags = #{dup => false},
headers = #{},
topic = Topic, topic = Topic,
payload = Payload, payload = Payload,
timestamp = os:timestamp()}. timestamp = os:timestamp()}.
@ -119,6 +121,9 @@ get_flag(Flag, Msg) ->
get_flag(Flag, #message{flags = Flags}, Default) -> get_flag(Flag, #message{flags = Flags}, Default) ->
maps:get(Flag, Flags, Default). maps:get(Flag, Flags, Default).
-spec(get_flags(emqx_types:message()) -> maybe(map())).
get_flags(#message{flags = Flags}) -> Flags.
-spec(set_flag(flag(), emqx_types:message()) -> emqx_types:message()). -spec(set_flag(flag(), emqx_types:message()) -> emqx_types:message()).
set_flag(Flag, Msg = #message{flags = undefined}) when is_atom(Flag) -> set_flag(Flag, Msg = #message{flags = undefined}) when is_atom(Flag) ->
Msg#message{flags = #{Flag => true}}; Msg#message{flags = #{Flag => true}};
@ -144,8 +149,7 @@ unset_flag(Flag, Msg = #message{flags = Flags}) ->
set_headers(Headers, Msg = #message{headers = undefined}) when is_map(Headers) -> set_headers(Headers, Msg = #message{headers = undefined}) when is_map(Headers) ->
Msg#message{headers = Headers}; Msg#message{headers = Headers};
set_headers(New, Msg = #message{headers = Old}) when is_map(New) -> set_headers(New, Msg = #message{headers = Old}) when is_map(New) ->
Msg#message{headers = maps:merge(Old, New)}; Msg#message{headers = maps:merge(Old, New)}.
set_headers(undefined, Msg) -> Msg.
-spec(get_headers(emqx_types:message()) -> map()). -spec(get_headers(emqx_types:message()) -> map()).
get_headers(Msg) -> get_headers(Msg) ->