diff --git a/src/emqttd_pubsub.erl b/src/emqttd_pubsub.erl index bb8032c79..308be6184 100644 --- a/src/emqttd_pubsub.erl +++ b/src/emqttd_pubsub.erl @@ -39,11 +39,13 @@ -boot_mnesia({mnesia, [boot]}). -copy_mnesia({mnesia, [copy]}). -%% API Exports +%% API Exports -export([start_link/4]). --export([create/2, subscribe/1, subscribe/2, - unsubscribe/1, unsubscribe/2, publish/1]). +-export([create/2, lookup/2, subscribe/1, subscribe/2, + unsubscribe/1, unsubscribe/2, publish/1, delete/2]). + +%% Subscriptions API %% Local node -export([match/1]). @@ -154,6 +156,30 @@ create(subscription, {SubId, Topic, Qos}) -> {aborted, Error} -> {error, Error} end. +%%------------------------------------------------------------------------------ +%% @doc Lookup Topic or Subscription. +%% @end +%%------------------------------------------------------------------------------ +-spec lookup(topic | subscription, binary()) -> list(). +lookup(topic, Topic) -> + mnesia:dirty_read(topic, Topic); + +lookup(subscription, ClientId) -> + mnesia:dirty_read(subscription, ClientId). + +%%------------------------------------------------------------------------------ +%% @doc Delete Topic or Subscription. +%% @end +%%------------------------------------------------------------------------------ +delete(topic, _Topic) -> + {error, unsupported}; + +delete(subscription, ClientId) when is_binary(ClientId) -> + mnesia:dirty_deleate({subscription, ClientId}); + +delete(subscription, {ClientId, Topic}) when is_binary(ClientId) -> + mnesia:async_dirty(fun remove_subscriptions/2, [ClientId, [Topic]]). + %%------------------------------------------------------------------------------ %% @doc Subscribe Topics %% @end @@ -363,7 +389,7 @@ remove_subscriptions(SubId, Topics) -> lists:foreach(fun(Topic) -> Pattern = #mqtt_subscription{subid = SubId, topic = Topic, qos = '_'}, Records = mnesia:match_object(subscription, Pattern, write), - [delete_subscription(Record) || Record <- Records] + lists:foreach(fun delete_subscription/1, Records) end, Topics). delete_subscription(Record) -> diff --git a/src/emqttd_session.erl b/src/emqttd_session.erl index e2606ae5d..dde925ab0 100644 --- a/src/emqttd_session.erl +++ b/src/emqttd_session.erl @@ -483,7 +483,7 @@ handle_cast(Msg, State) -> %% Dispatch Message handle_info({dispatch, Topic, Msg}, Session = #session{subscriptions = Subscriptions}) when is_record(Msg, mqtt_message) -> - dispatch(fixqos(Topic, Msg, Subscriptions), Session); + dispatch(tune_qos(Topic, Msg, Subscriptions), Session); handle_info({timeout, awaiting_ack, PktId}, Session = #session{client_pid = undefined, awaiting_ack = AwaitingAck}) -> @@ -603,7 +603,7 @@ dispatch(Msg = #mqtt_message{qos = QoS}, Session = #session{message_queue = MsgQ hibernate(Session#session{message_queue = emqttd_mqueue:in(Msg, MsgQ)}) end. -fixqos(Topic, Msg = #mqtt_message{qos = PubQos}, Subscriptions) -> +tune_qos(Topic, Msg = #mqtt_message{qos = PubQos}, Subscriptions) -> case dict:find(Topic, Subscriptions) of {ok, SubQos} when PubQos > SubQos -> Msg#mqtt_message{qos = SubQos};