diff --git a/rebar.config b/rebar.config index 5306e0ae0..aedddfbe9 100644 --- a/rebar.config +++ b/rebar.config @@ -1,4 +1,4 @@ {deps, [ -{gproc,".*",{git,"https://github.com/uwiger/gproc.git",""}},{lager,".*",{git,"https://github.com/basho/lager.git",""}},{gen_logger,".*",{git,"https://github.com/emqtt/gen_logger.git",""}},{gen_conf,".*",{git,"https://github.com/emqtt/gen_conf.git",""}},{esockd,".*",{git,"https://github.com/emqtt/esockd.git","udp"}},{mochiweb,".*",{git,"https://github.com/emqtt/mochiweb.git",""}} +{gproc,".*",{git,"https://github.com/uwiger/gproc.git",""}},{lager,".*",{git,"https://github.com/basho/lager.git",""}},{gen_logger,".*",{git,"https://github.com/emqtt/gen_logger.git",""}},{gen_conf,".*",{git,"https://github.com/emqtt/gen_conf.git",""}},{esockd,".*",{git,"https://github.com/emqtt/esockd.git","emq20"}},{mochiweb,".*",{git,"https://github.com/emqtt/mochiweb.git",""}} ]}. {erl_opts, [{parse_transform,lager_transform}]}. diff --git a/src/emqttd_app.erl b/src/emqttd_app.erl index 2db44c351..65af78ec3 100644 --- a/src/emqttd_app.erl +++ b/src/emqttd_app.erl @@ -81,7 +81,6 @@ start_servers(Sup) -> {"emqttd pubsub", {supervisor, emqttd_pubsub_sup}}, {"emqttd stats", emqttd_stats}, {"emqttd metrics", emqttd_metrics}, - {"emqttd retainer", emqttd_retainer}, {"emqttd pooler", {supervisor, emqttd_pooler}}, {"emqttd trace", {supervisor, emqttd_trace_sup}}, {"emqttd client manager", {supervisor, emqttd_cm_sup}}, diff --git a/src/emqttd_mod_presence.erl b/src/emqttd_mod_presence.erl index 7815e88be..29fe2915d 100644 --- a/src/emqttd_mod_presence.erl +++ b/src/emqttd_mod_presence.erl @@ -14,7 +14,6 @@ %% limitations under the License. %%-------------------------------------------------------------------- -%% @doc emqttd presence management module -module(emqttd_mod_presence). -behaviour(emqttd_gen_mod). diff --git a/src/emqttd_retainer.erl b/src/emqttd_mod_retainer.erl similarity index 68% rename from src/emqttd_retainer.erl rename to src/emqttd_mod_retainer.erl index 435bd7abc..32887bab4 100644 --- a/src/emqttd_retainer.erl +++ b/src/emqttd_mod_retainer.erl @@ -14,28 +14,26 @@ %% limitations under the License. %%-------------------------------------------------------------------- -%% @doc MQTT retained message. --module(emqttd_retainer). +-module(emqttd_mod_retainer). -behaviour(gen_server). +-behaviour(emqttd_gen_mod). + -include("emqttd.hrl"). -include("emqttd_internal.hrl"). -include_lib("stdlib/include/ms_transform.hrl"). -%% Mnesia Callbacks --export([mnesia/1]). +%% gen_mod Callbacks +-export([load/1, unload/1]). --boot_mnesia({mnesia, [boot]}). --copy_mnesia({mnesia, [copy]}). +%% Hook Callbacks +-export([on_session_subscribed/4, on_message_publish/2]). %% API Function Exports --export([retain/1, read_messages/1, dispatch/2]). - -%% API Function Exports --export([start_link/0]). +-export([start_link/1]). %% gen_server Function Exports -export([init/1, handle_call/3, handle_cast/2, handle_info/2, @@ -46,83 +44,90 @@ -record(state, {stats_fun, expired_after, stats_timer, expire_timer}). %%-------------------------------------------------------------------- -%% Mnesia callbacks +%% Load/Unload %%-------------------------------------------------------------------- -mnesia(boot) -> - ok = emqttd_mnesia:create_table(retained_message, [ - {type, ordered_set}, - {disc_copies, [node()]}, - {record_name, retained_message}, - {attributes, record_info(fields, retained_message)}, - {storage_properties, [{ets, [compressed]}, - {dets, [{auto_save, 1000}]}]}]); +load(Env) -> + emqttd_mod_sup:start_child(spec(Env)), + emqttd:hook('session.subscribed', fun ?MODULE:on_session_subscribed/4, [Env]), + emqttd:hook('message.publish', fun ?MODULE:on_message_publish/2, [Env]). -mnesia(copy) -> - ok = emqttd_mnesia:copy_table(retained_message). +on_session_subscribed(_ClientId, _Username, {Topic, _Opts}, _Env) -> + SessPid = self(), + Msgs = case emqttd_topic:wildcard(Topic) of + false -> read_messages(Topic); + true -> match_messages(Topic) + end, + lists:foreach(fun(Msg) -> SessPid ! {dispatch, Topic, Msg} end, lists:reverse(Msgs)). + +on_message_publish(Msg = #mqtt_message{retain = false}, _Env) -> + {ok, Msg}; + +%% RETAIN flag set to 1 and payload containing zero bytes +on_message_publish(Msg = #mqtt_message{retain = true, topic = Topic, payload = <<>>}, _Env) -> + mnesia:dirty_delete(retained_message, Topic), + {stop, Msg}; + +on_message_publish(Msg = #mqtt_message{topic = Topic, retain = true, payload = Payload}, Env) -> + case {is_table_full(Env), is_too_big(size(Payload), Env)} of + {false, false} -> + mnesia:dirty_write(#retained_message{topic = Topic, msg = Msg}), + emqttd_metrics:set('messages/retained', retained_count()); + {true, _}-> + lager:error("Cannot retain message(topic=~s) for table is full!", [Topic]); + {_, true}-> + lager:error("Cannot retain message(topic=~s, payload_size=~p)" + " for payload is too big!", [Topic, size(Payload)]) + end, + {ok, Msg#mqtt_message{retain = false}}. + +is_table_full(Env) -> + Limit = proplists:get_value(max_message_num, Env, 0), + Limit > 0 andalso (retained_count() > Limit). + +is_too_big(Size, Env) -> + Limit = proplists:get_value(max_payload_size, Env, 0), + Limit > 0 andalso (Size > Limit). + +unload(_Env) -> + emqttd:unhook('session.subscribed', fun ?MODULE:on_session_subscribed/4), + emqttd:unhook('message.publish', fun ?MODULE:on_message_publish/2), + emqttd_mod_sup:stop_child(?MODULE). + +spec(Env) -> + {?MODULE, {?MODULE, start_link, [Env]}, permanent, 5000, worker, [?MODULE]}. %%-------------------------------------------------------------------- %% API %%-------------------------------------------------------------------- %% @doc Start the retainer --spec(start_link() -> {ok, pid()} | ignore | {error, any()}). -start_link() -> - gen_server:start_link({local, ?MODULE}, ?MODULE, [], []). +-spec(start_link(Env :: list()) -> {ok, pid()} | ignore | {error, any()}). +start_link(Env) -> + gen_server:start_link({local, ?MODULE}, ?MODULE, [Env], []). -%% @doc Retain a message --spec(retain(mqtt_message()) -> ok | ignore). -retain(#mqtt_message{retain = false}) -> ignore; +%%-------------------------------------------------------------------- +%% gen_server Callbacks +%%-------------------------------------------------------------------- -%% RETAIN flag set to 1 and payload containing zero bytes -retain(#mqtt_message{retain = true, topic = Topic, payload = <<>>}) -> - delete_message(Topic); - -retain(Msg = #mqtt_message{topic = Topic, retain = true, payload = Payload}) -> - TabSize = retained_count(), - case {TabSize < limit(table), size(Payload) < limit(payload)} of - {true, true} -> - retain_message(Msg), - emqttd_metrics:set('messages/retained', retained_count()); - {false, _}-> - lager:error("Cannot retain message(topic=~s) for table is full!", [Topic]); - {_, false}-> - lager:error("Cannot retain message(topic=~s, payload_size=~p)" - " for payload is too big!", [Topic, size(Payload)]) - end, ok. - -limit(table) -> env(max_message_num); -limit(payload) -> env(max_playload_size). - -env(Key) -> - case get({retained, Key}) of - undefined -> - Env = emqttd_conf:retained(), - Val = proplists:get_value(Key, Env), - put({retained, Key}, Val), Val; - Val -> - Val - end. - -%% @doc Deliver retained messages to the subscriber --spec(dispatch(Topic :: binary(), CPid :: pid()) -> any()). -dispatch(Topic, CPid) when is_binary(Topic) -> - Msgs = case emqttd_topic:wildcard(Topic) of - false -> read_messages(Topic); - true -> match_messages(Topic) +init([Env]) -> + Copy = case proplists:get_value(storage, Env, disc) of + disc -> disc_copies; + ram -> ram_copies end, - lists:foreach(fun(Msg) -> CPid ! {dispatch, Topic, Msg} end, lists:reverse(Msgs)). - -%%-------------------------------------------------------------------- -%% gen_server callbacks -%%-------------------------------------------------------------------- - -init([]) -> + ok = emqttd_mnesia:create_table(retained_message, [ + {type, ordered_set}, + {Copy, [node()]}, + {record_name, retained_message}, + {attributes, record_info(fields, retained_message)}, + {storage_properties, [{ets, [compressed]}, + {dets, [{auto_save, 1000}]}]}]), + ok = emqttd_mnesia:copy_table(retained_message), StatsFun = emqttd_stats:statsfun('retained/count', 'retained/max'), %% One second {ok, StatsTimer} = timer:send_interval(timer:seconds(1), stats), State = #state{stats_fun = StatsFun, stats_timer = StatsTimer}, - {ok, init_expire_timer(env(expired_after), State)}. + {ok, init_expire_timer(proplists:get_value(expired_after, Env, 0), State)}. init_expire_timer(0, State) -> State; @@ -164,10 +169,6 @@ code_change(_OldVsn, State, _Extra) -> %% Internal Functions %%-------------------------------------------------------------------- --spec(retain_message(mqtt_message()) -> ok). -retain_message(Msg = #mqtt_message{topic = Topic}) -> - mnesia:dirty_write(#retained_message{topic = Topic, msg = Msg}). - -spec(read_messages(binary()) -> [mqtt_message()]). read_messages(Topic) -> [Msg || #retained_message{msg = Msg} <- mnesia:dirty_read(retained_message, Topic)]. @@ -183,10 +184,6 @@ match_messages(Filter) -> end, mnesia:async_dirty(fun mnesia:foldl/3, [Fun, [], retained_message]). --spec(delete_message(binary()) -> ok). -delete_message(Topic) -> - mnesia:dirty_delete(retained_message, Topic). - -spec(expire_messages(pos_integer()) -> any()). expire_messages(Time) when is_integer(Time) -> mnesia:transaction( diff --git a/src/emqttd_mod_sup.erl b/src/emqttd_mod_sup.erl index 759886fd3..a0f3cb9d3 100644 --- a/src/emqttd_mod_sup.erl +++ b/src/emqttd_mod_sup.erl @@ -21,7 +21,7 @@ -include("emqttd.hrl"). %% API --export([start_link/0, start_child/1, start_child/2]). +-export([start_link/0, start_child/1, start_child/2, stop_child/1]). %% Supervisor callbacks -export([init/1]). @@ -46,6 +46,13 @@ start_child(ChildSpec) when is_tuple(ChildSpec) -> start_child(Mod, Type) when is_atom(Mod) and is_atom(Type) -> supervisor:start_child(?MODULE, ?CHILD(Mod, Type)). +-spec(stop_child(any()) -> ok | {error, any()}). +stop_child(ChildId) -> + case supervisor:terminate_child(?MODULE, ChildId) of + ok -> supervisor:delete_child(?MODULE, ChildId); + Error -> Error + end. + %%-------------------------------------------------------------------- %% Supervisor callbacks %%-------------------------------------------------------------------- diff --git a/src/emqttd_protocol.erl b/src/emqttd_protocol.erl index 3448ce387..d8bc05034 100644 --- a/src/emqttd_protocol.erl +++ b/src/emqttd_protocol.erl @@ -197,23 +197,32 @@ process(?PUBACK_PACKET(?PUBCOMP, PacketId), State = #proto_state{session = Sessi process(?SUBSCRIBE_PACKET(PacketId, []), State) -> send(?SUBACK_PACKET(PacketId, []), State); -process(?SUBSCRIBE_PACKET(PacketId, TopicTable), State = #proto_state{session = Session}) -> +process(?SUBSCRIBE_PACKET(PacketId, RawTopicTable), State = #proto_state{ + client_id = ClientId, username = Username, session = Session}) -> Client = client(State), - AllowDenies = [check_acl(subscribe, Topic, Client) || {Topic, _Qos} <- TopicTable], + TopicTable = parse_topic_table(RawTopicTable), + AllowDenies = [check_acl(subscribe, Topic, Client) || {Topic, _Opts} <- TopicTable], case lists:member(deny, AllowDenies) of true -> ?LOG(error, "Cannot SUBSCRIBE ~p for ACL Deny", [TopicTable], State), send(?SUBACK_PACKET(PacketId, [16#80 || _ <- TopicTable]), State); false -> - emqttd_session:subscribe(Session, PacketId, TopicTable), {ok, State} + case emqttd:run_hooks('client.subscribe', [ClientId, Username], TopicTable) of + {ok, TopicTable1} -> + emqttd_session:subscribe(Session, PacketId, TopicTable1), {ok, State}; + {stop, _} -> + {ok, State} + end end; %% Protect from empty topic list process(?UNSUBSCRIBE_PACKET(PacketId, []), State) -> send(?UNSUBACK_PACKET(PacketId), State); -process(?UNSUBSCRIBE_PACKET(PacketId, Topics), State = #proto_state{session = Session}) -> - emqttd_session:unsubscribe(Session, Topics), +process(?UNSUBSCRIBE_PACKET(PacketId, RawTopics), State = #proto_state{ + client_id = ClientId, username = Username, session = Session}) -> + {ok, TopicTable} = emqttd:run_hooks('client.unsubscribe', [ClientId, Username], parse_topics(RawTopics)), + emqttd_session:unsubscribe(Session, TopicTable), send(?UNSUBACK_PACKET(PacketId), State); process(?PACKET(?PINGREQ), State) -> @@ -249,7 +258,7 @@ with_puback(Type, Packet = ?PUBLISH_PACKET(_Qos, PacketId), -spec(send(mqtt_message() | mqtt_packet(), proto_state()) -> {ok, proto_state()}). send(Msg, State = #proto_state{client_id = ClientId, username = Username}) when is_record(Msg, mqtt_message) -> - emqttd:run_hooks('message.delivered', [{ClientId, Username}], Msg), + emqttd:run_hooks('message.delivered', [ClientId, Username], Msg), send(emqttd_message:to_packet(Msg), State); send(Packet, State = #proto_state{sendfun = SendFun}) @@ -393,6 +402,15 @@ validate_qos(Qos) when ?IS_QOS(Qos) -> validate_qos(_) -> false. +parse_topic_table(TopicTable) -> + lists:map(fun({Topic0, Qos}) -> + {Topic, Opts} = emqttd_topic:parse(Topic0), + {Topic, [{qos, Qos}|Opts]} + end, TopicTable). + +parse_topics(Topics) -> + [emqttd_topic:parse(Topic) || Topic <- Topics]. + %% PUBLISH ACL is cached in process dictionary. check_acl(publish, Topic, Client) -> IfCache = emqttd:conf(cache_acl, true), @@ -412,4 +430,3 @@ check_acl(subscribe, Topic, Client) -> sp(true) -> 1; sp(false) -> 0. - diff --git a/src/emqttd_server.erl b/src/emqttd_server.erl index 6e6c8620f..7cf4c3007 100644 --- a/src/emqttd_server.erl +++ b/src/emqttd_server.erl @@ -91,12 +91,7 @@ publish(Msg = #mqtt_message{from = From}) -> trace(publish, From, Msg), case emqttd_hook:run('message.publish', [], Msg) of {ok, Msg1 = #mqtt_message{topic = Topic}} -> - %% Retain message first. Don't create retained topic. - Msg2 = case emqttd_retainer:retain(Msg1) of - ok -> emqttd_message:unset_flag(Msg1); - ignore -> Msg1 - end, - emqttd_pubsub:publish(Topic, Msg2); + emqttd_pubsub:publish(Topic, Msg1); {stop, Msg1} -> lager:warning("Stop publishing: ~s", [emqttd_message:format(Msg1)]), ignore diff --git a/src/emqttd_session.erl b/src/emqttd_session.erl index bf475e4f6..6c4ab5c97 100644 --- a/src/emqttd_session.erl +++ b/src/emqttd_session.erl @@ -162,16 +162,14 @@ destroy(SessPid, ClientId) -> %%-------------------------------------------------------------------- %% @doc Subscribe Topics --spec(subscribe(pid(), [{binary(), mqtt_qos()}]) -> ok). +-spec(subscribe(pid(), [{binary(), [emqttd_topic:option()]}]) -> ok). subscribe(SessPid, TopicTable) -> gen_server2:cast(SessPid, {subscribe, TopicTable, fun(_) -> ok end}). --spec(subscribe(pid(), mqtt_packet_id(), [{binary(), mqtt_qos()}]) -> ok). -subscribe(SessPid, PacketId, TopicTable) -> - From = self(), - AckFun = fun(GrantedQos) -> - From ! {suback, PacketId, GrantedQos} - end, +-spec(subscribe(pid(), mqtt_pktid(), [{binary(), [emqttd_topic:option()]}]) -> ok). +subscribe(SessPid, PktId, TopicTable) -> + From = self(), + AckFun = fun(GrantedQos) -> From ! {suback, PktId, GrantedQos} end, gen_server2:cast(SessPid, {subscribe, TopicTable, AckFun}). %% @doc Publish message @@ -206,9 +204,9 @@ pubcomp(SessPid, PktId) -> gen_server2:cast(SessPid, {pubcomp, PktId}). %% @doc Unsubscribe Topics --spec(unsubscribe(pid(), [binary()]) -> ok). -unsubscribe(SessPid, Topics) -> - gen_server2:cast(SessPid, {unsubscribe, Topics}). +-spec(unsubscribe(pid(), [{binary(), [emqttd_topic:option()]}]) -> ok). +unsubscribe(SessPid, TopicTable) -> + gen_server2:cast(SessPid, {unsubscribe, TopicTable}). %%-------------------------------------------------------------------- %% gen_server Callbacks @@ -223,7 +221,7 @@ init([CleanSess, {ClientId, Username}, ClientPid]) -> client_id = ClientId, client_pid = ClientPid, username = Username, - subscriptions = dict:new(), + subscriptions = #{}, inflight_queue = [], max_inflight = get_value(max_inflight, SessEnv, 0), message_queue = emqttd_mqueue:new(ClientId, emqttd_conf:queue(), emqttd_alarm:alarm_fun()), @@ -250,10 +248,10 @@ prioritise_cast(Msg, _Len, _State) -> case Msg of {destroy, _} -> 10; {resume, _, _} -> 9; - {pubrel, _PktId} -> 8; - {pubcomp, _PktId} -> 8; - {pubrec, _PktId} -> 8; - {puback, _PktId} -> 7; + {pubrel, _} -> 8; + {pubcomp, _} -> 8; + {pubrec, _} -> 8; + {puback, _} -> 7; {unsubscribe, _, _} -> 6; {subscribe, _, _} -> 5; _ -> 0 @@ -288,67 +286,48 @@ handle_call({publish, Msg = #mqtt_message{qos = ?QOS_2, pktid = PktId}}, handle_call(Req, _From, State) -> ?UNEXPECTED_REQ(Req, State). -%%TODO: 2.0 FIX - handle_cast({subscribe, TopicTable, AckFun}, Session = #session{client_id = ClientId, username = Username, subscriptions = Subscriptions}) -> ?LOG(info, "Subscribe ~p", [TopicTable], Session), {GrantedQos, Subscriptions1} = - lists:foldl(fun({RawTopic, Qos}, {QosAcc, SubDict}) -> - {Topic, Opts} = emqttd_topic:strip(RawTopic), - case emqttd:run_hooks('client.subscribe', [{ClientId, Username}], {Topic, Opts}) of - {ok, {Topic1, Opts1}} -> - NewQos = proplists:get_value(qos, Opts1, Qos), - {[NewQos | QosAcc], case dict:find(Topic, SubDict) of - {ok, NewQos} -> - ?LOG(warning, "duplicated subscribe: ~s, qos = ~w", [Topic, NewQos], Session), - SubDict; - {ok, OldQos} -> - emqttd:setqos(Topic, ClientId, NewQos), - ?LOG(warning, "duplicated subscribe ~s, old_qos=~w, new_qos=~w", [Topic, OldQos, NewQos], Session), - dict:store(Topic, NewQos, SubDict); - error -> - emqttd:subscribe(Topic1, ClientId, Opts1), - %%TODO: the design is ugly... - %% : 3.8.4 - %% Where the Topic Filter is not identical to any existing Subscription’s filter, - %% a new Subscription is created and all matching retained messages are sent. - emqttd_retainer:dispatch(Topic1, self()), - emqttd:run_hooks('client.subscribe.after', [{ClientId, Username}], {Topic1, Opts1}), - - dict:store(Topic1, NewQos, SubDict) - end}; - {stop, _} -> - ?LOG(error, "Cannot subscribe: ~p", [Topic], Session), - {[128 | QosAcc], SubDict} - end - end, {[], Subscriptions}, TopicTable), + lists:foldl(fun({Topic, Opts}, {QosAcc, SubMap}) -> + NewQos = proplists:get_value(qos, Opts), + SubMap1 = + case maps:find(Topic, SubMap) of + {ok, NewQos} -> + ?LOG(warning, "duplicated subscribe: ~s, qos = ~w", [Topic, NewQos], Session), + SubMap; + {ok, OldQos} -> + emqttd:setqos(Topic, ClientId, NewQos), + ?LOG(warning, "duplicated subscribe ~s, old_qos=~w, new_qos=~w", + [Topic, OldQos, NewQos], Session), + maps:put(Topic, NewQos, SubMap); + error -> + emqttd:subscribe(Topic, ClientId, Opts), + emqttd:run_hooks('session.subscribed', [ClientId, Username], {Topic, Opts}), + maps:put(Topic, NewQos, SubMap) + end, + {[NewQos|QosAcc], SubMap1} + end, {[], Subscriptions}, TopicTable), AckFun(lists:reverse(GrantedQos)), hibernate(Session#session{subscriptions = Subscriptions1}); -%%TODO: 2.0 FIX - -handle_cast({unsubscribe, Topics}, Session = #session{client_id = ClientId, - username = Username, - subscriptions = Subscriptions}) -> - ?LOG(info, "unsubscribe ~p", [Topics], Session), +handle_cast({unsubscribe, TopicTable}, Session = #session{client_id = ClientId, + username = Username, + subscriptions = Subscriptions}) -> + ?LOG(info, "unsubscribe ~p", [TopicTable], Session), Subscriptions1 = - lists:foldl(fun(RawTopic, SubDict) -> - {Topic0, _Opts} = emqttd_topic:strip(RawTopic), - case emqttd:run_hooks('client.unsubscribe', [ClientId, Username], Topic0) of - {ok, Topic1} -> - case dict:find(Topic1, SubDict) of - {ok, _Qos} -> - emqttd:unsubscribe(Topic1, ClientId), - dict:erase(Topic1, SubDict); - error -> - SubDict - end; - {stop, _} -> - SubDict - end - end, Subscriptions, Topics), + lists:foldl(fun({Topic, Opts}, SubMap) -> + case maps:find(Topic, SubMap) of + {ok, _Qos} -> + emqttd:unsubscribe(Topic, ClientId), + emqttd:run_hooks('session.unsubscribed', [ClientId, Username], {Topic, Opts}), + dict:erase(Topic, SubMap); + error -> + SubMap + end + end, Subscriptions, TopicTable), hibernate(Session#session{subscriptions = Subscriptions1}); handle_cast({destroy, ClientId}, Session = #session{client_id = ClientId}) -> @@ -664,7 +643,7 @@ acked(PktId, Session = #session{client_id = ClientId, awaiting_ack = Awaiting}) -> case lists:keyfind(PktId, 1, InflightQ) of {_, Msg} -> - emqttd:run_hooks('message.acked', [{ClientId, Username}], Msg); + emqttd:run_hooks('message.acked', [ClientId, Username], Msg); false -> ?LOG(error, "Cannot find acked pktid: ~p", [PktId], Session) end, diff --git a/src/emqttd_topic.erl b/src/emqttd_topic.erl index ebd16714d..2a198c2e1 100644 --- a/src/emqttd_topic.erl +++ b/src/emqttd_topic.erl @@ -16,23 +16,27 @@ -module(emqttd_topic). +-include("emqttd_protocol.hrl"). + -import(lists, [reverse/1]). -export([match/2, validate/1, triples/1, words/1, wildcard/1]). -export([join/1, feed_var/3, systop/1]). --export([strip/1, strip/2]). +-export([parse/1, parse/2]). -type(topic() :: binary()). +-type(option() :: local | {qos, mqtt_qos()} | {share, '$queue' | binary()}). + -type(word() :: '' | '+' | '#' | binary()). -type(words() :: list(word())). -type(triple() :: {root | binary(), word(), binary()}). --export_type([topic/0, word/0, triple/0]). +-export_type([topic/0, option/0, word/0, triple/0]). -define(MAX_TOPIC_LEN, 4096). @@ -172,28 +176,28 @@ join(Words) -> end, {true, <<>>}, [bin(W) || W <- Words]), Bin. --spec(strip(topic()) -> {topic(), [local | {share, binary()}]}). -strip(Topic) when is_binary(Topic) -> - strip(Topic, []). +-spec(parse(topic()) -> {topic(), [option()]}). +parse(Topic) when is_binary(Topic) -> + parse(Topic, []). -strip(Topic = <<"$local/", Topic1/binary>>, Options) -> +parse(Topic = <<"$local/", Topic1/binary>>, Options) -> case lists:member(local, Options) of true -> error({invalid_topic, Topic}); - false -> strip(Topic1, [local | Options]) + false -> parse(Topic1, [local | Options]) end; -strip(Topic = <<"$queue/", Topic1/binary>>, Options) -> +parse(Topic = <<"$queue/", Topic1/binary>>, Options) -> case lists:keyfind(share, 1, Options) of {share, _} -> error({invalid_topic, Topic}); - false -> strip(Topic1, [{share, '$queue'} | Options]) + false -> parse(Topic1, [{share, '$queue'} | Options]) end; -strip(Topic = <<"$share/", Topic1/binary>>, Options) -> +parse(Topic = <<"$share/", Topic1/binary>>, Options) -> case lists:keyfind(share, 1, Options) of {share, _} -> error({invalid_topic, Topic}); false -> [Share, Topic2] = binary:split(Topic1, <<"/">>), {Topic2, [{share, Share} | Options]} end; -strip(Topic, Options) -> {Topic, Options}. +parse(Topic, Options) -> {Topic, Options}. diff --git a/test/emqttd_SUITE.erl b/test/emqttd_SUITE.erl index 06923ee8e..0cdfbd8c2 100644 --- a/test/emqttd_SUITE.erl +++ b/test/emqttd_SUITE.erl @@ -61,7 +61,7 @@ groups() -> [add_delete_hook, run_hooks]}, {retainer, [sequence], - [dispatch_retained_messages]}, + [t_retained_messages]}, {backend, [sequence], []}, {http, [sequence], @@ -307,7 +307,7 @@ hook_fun5(arg1, arg2, Acc, init) -> {stop, [r3 | Acc]}. %% Retainer Test %%-------------------------------------------------------------------- -dispatch_retained_messages(_) -> +t_retained_messages(_) -> Msg = #mqtt_message{retain = true, topic = <<"a/b/c">>, payload = <<"payload">>}, emqttd_retainer:retain(Msg), diff --git a/test/emqttd_topic_SUITE.erl b/test/emqttd_topic_SUITE.erl index 5e9608e00..5692dbb43 100644 --- a/test/emqttd_topic_SUITE.erl +++ b/test/emqttd_topic_SUITE.erl @@ -22,14 +22,14 @@ -compile(export_all). -import(emqttd_topic, [wildcard/1, match/2, validate/1, triples/1, join/1, - words/1, systop/1, feed_var/3, strip/1, strip/2]). + words/1, systop/1, feed_var/3, parse/1, parse/2]). -define(N, 10000). all() -> [t_wildcard, t_match, t_match2, t_validate, t_triples, t_join, t_words, t_systop, t_feed_var, t_sys_match, 't_#_match', t_sigle_level_validate, t_sigle_level_match, t_match_perf, - t_triples_perf, t_strip]. + t_triples_perf, t_parse]. t_wildcard(_) -> true = wildcard(<<"a/b/#">>), @@ -171,11 +171,11 @@ t_feed_var(_) -> long_topic() -> iolist_to_binary([[integer_to_list(I), "/"] || I <- lists:seq(0, 10000)]). -t_strip(_) -> - ?assertEqual({<<"a/b/+/#">>, []}, strip(<<"a/b/+/#">>)), - ?assertEqual({<<"topic">>, [{share, '$queue'}]}, strip(<<"$queue/topic">>)), - ?assertEqual({<<"topic">>, [{share, <<"group">>}]}, strip(<<"$share/group/topic">>)), - ?assertEqual({<<"topic">>, [local]}, strip(<<"$local/topic">>)), - ?assertEqual({<<"topic">>, [{share, '$queue'}, local]}, strip(<<"$local/$queue/topic">>)), - ?assertEqual({<<"/a/b/c">>, [{share, <<"group">>}, local]}, strip(<<"$local/$share/group//a/b/c">>)). +t_parse(_) -> + ?assertEqual({<<"a/b/+/#">>, []}, parse(<<"a/b/+/#">>)), + ?assertEqual({<<"topic">>, [{share, '$queue'}]}, parse(<<"$queue/topic">>)), + ?assertEqual({<<"topic">>, [{share, <<"group">>}]}, parse(<<"$share/group/topic">>)), + ?assertEqual({<<"topic">>, [local]}, parse(<<"$local/topic">>)), + ?assertEqual({<<"topic">>, [{share, '$queue'}, local]}, parse(<<"$local/$queue/topic">>)), + ?assertEqual({<<"/a/b/c">>, [{share, <<"group">>}, local]}, parse(<<"$local/$share/group//a/b/c">>)).