From 8f00e28576a5aac45879e8ff30173598ba759a4c Mon Sep 17 00:00:00 2001 From: Feng Lee Date: Fri, 17 Nov 2017 20:51:51 +0800 Subject: [PATCH] Improve the pubsub design and fix the race-condition issue --- src/emqttd.erl | 48 +++++----- src/emqttd_pubsub.erl | 77 +++++++++-------- src/emqttd_router.erl | 197 +++++++++++++++++++++--------------------- src/emqttd_server.erl | 196 ++++++++++++++++++++++------------------- src/emqttd_trie.erl | 96 ++++++++++---------- 5 files changed, 321 insertions(+), 293 deletions(-) diff --git a/src/emqttd.erl b/src/emqttd.erl index d4cdd8437..f012fdc23 100644 --- a/src/emqttd.erl +++ b/src/emqttd.erl @@ -31,8 +31,7 @@ unsubscribe/1, unsubscribe/2]). %% PubSub Management API --export([setqos/3, topics/0, subscriptions/1, subscribers/1, - is_subscribed/2, subscriber_down/1]). +-export([setqos/3, topics/0, subscriptions/1, subscribers/1, subscribed/2]). %% Hooks API -export([hook/4, hook/3, unhook/2, run_hooks/2, run_hooks/3]). @@ -43,14 +42,13 @@ %% Shutdown and reboot -export([shutdown/0, shutdown/1, reboot/0]). --type(subscriber() :: pid() | binary()). +-type(subid() :: binary()). + +-type(subscriber() :: pid() | subid() | {subid(), pid()}). -type(suboption() :: local | {qos, non_neg_integer()} | {share, {'$queue' | binary()}}). --type(pubsub_error() :: {error, {already_subscribed, binary()} - | {subscription_not_found, binary()}}). - --export_type([subscriber/0, suboption/0, pubsub_error/0]). +-export_type([subscriber/0, suboption/0]). -define(APP, ?MODULE). @@ -59,19 +57,19 @@ %%-------------------------------------------------------------------- %% @doc Start emqttd application. --spec(start() -> ok | {error, any()}). +-spec(start() -> ok | {error, term()}). start() -> application:start(?APP). %% @doc Stop emqttd application. --spec(stop() -> ok | {error, any()}). +-spec(stop() -> ok | {error, term()}). stop() -> application:stop(?APP). %% @doc Environment --spec(env(Key:: atom()) -> {ok, any()} | undefined). +-spec(env(Key :: atom()) -> {ok, any()} | undefined). env(Key) -> application:get_env(?APP, Key). %% @doc Get environment --spec(env(Key:: atom(), Default:: any()) -> undefined | any()). +-spec(env(Key :: atom(), Default :: any()) -> undefined | any()). env(Key, Default) -> application:get_env(?APP, Key, Default). %% @doc Is running? @@ -88,15 +86,15 @@ is_running(Node) -> %%-------------------------------------------------------------------- %% @doc Subscribe --spec(subscribe(iodata()) -> ok | {error, any()}). +-spec(subscribe(iodata()) -> ok | {error, term()}). subscribe(Topic) -> - subscribe(Topic, self()). + emqttd_server:subscribe(iolist_to_binary(Topic)). --spec(subscribe(iodata(), subscriber()) -> ok | {error, any()}). +-spec(subscribe(iodata(), subscriber()) -> ok | {error, term()}). subscribe(Topic, Subscriber) -> - subscribe(Topic, Subscriber, []). + emqttd_server:subscribe(iolist_to_binary(Topic), Subscriber). --spec(subscribe(iodata(), subscriber(), [suboption()]) -> ok | pubsub_error()). +-spec(subscribe(iodata(), subscriber(), [suboption()]) -> ok | {error, term()}). subscribe(Topic, Subscriber, Options) -> emqttd_server:subscribe(iolist_to_binary(Topic), Subscriber, Options). @@ -106,11 +104,11 @@ publish(Msg) -> emqttd_server:publish(Msg). %% @doc Unsubscribe --spec(unsubscribe(iodata()) -> ok | pubsub_error()). +-spec(unsubscribe(iodata()) -> ok | {error, term()}). unsubscribe(Topic) -> - unsubscribe(Topic, self()). + emqttd_server:unsubscribe(iolist_to_binary(Topic)). --spec(unsubscribe(iodata(), subscriber()) -> ok | pubsub_error()). +-spec(unsubscribe(iodata(), subscriber()) -> ok | {error, term()}). unsubscribe(Topic, Subscriber) -> emqttd_server:unsubscribe(iolist_to_binary(Topic), Subscriber). @@ -125,17 +123,13 @@ topics() -> emqttd_router:topics(). subscribers(Topic) -> emqttd_server:subscribers(iolist_to_binary(Topic)). --spec(subscriptions(subscriber()) -> [{binary(), binary(), list(suboption())}]). +-spec(subscriptions(subscriber()) -> [{emqttd:subscriber(), binary(), list(emqttd:suboption())}]). subscriptions(Subscriber) -> emqttd_server:subscriptions(Subscriber). --spec(is_subscribed(iodata(), subscriber()) -> boolean()). -is_subscribed(Topic, Subscriber) -> - emqttd_server:is_subscribed(iolist_to_binary(Topic), Subscriber). - --spec(subscriber_down(subscriber()) -> ok). -subscriber_down(Subscriber) -> - emqttd_server:subscriber_down(Subscriber). +-spec(subscribed(iodata(), subscriber()) -> boolean()). +subscribed(Topic, Subscriber) -> + emqttd_server:subscribed(iolist_to_binary(Topic), Subscriber). %%-------------------------------------------------------------------- %% Hooks API diff --git a/src/emqttd_pubsub.erl b/src/emqttd_pubsub.erl index d976618cd..994ef6230 100644 --- a/src/emqttd_pubsub.erl +++ b/src/emqttd_pubsub.erl @@ -46,7 +46,7 @@ %% Start PubSub %%-------------------------------------------------------------------- --spec(start_link(atom(), pos_integer(), list()) -> {ok, pid()} | ignore | {error, any()}). +-spec(start_link(atom(), pos_integer(), list()) -> {ok, pid()} | ignore | {error, term()}). start_link(Pool, Id, Env) -> gen_server2:start_link({local, ?PROC_NAME(?MODULE, Id)}, ?MODULE, [Pool, Id, Env], []). @@ -54,7 +54,7 @@ start_link(Pool, Id, Env) -> %% PubSub API %%-------------------------------------------------------------------- -%% @doc Subscribe a Topic +%% @doc Subscribe to a Topic -spec(subscribe(binary(), emqttd:subscriber(), [emqttd:suboption()]) -> ok). subscribe(Topic, Subscriber, Options) -> call(pick(Topic), {subscribe, Topic, Subscriber, Options}). @@ -63,8 +63,8 @@ subscribe(Topic, Subscriber, Options) -> async_subscribe(Topic, Subscriber, Options) -> cast(pick(Topic), {subscribe, Topic, Subscriber, Options}). -%% @doc Publish MQTT Message to Topic --spec(publish(binary(), any()) -> {ok, mqtt_delivery()} | ignore). +%% @doc Publish MQTT Message to Topic. +-spec(publish(binary(), mqtt_message()) -> {ok, mqtt_delivery()} | ignore). publish(Topic, Msg) -> route(lists:append(emqttd_router:match(Topic), emqttd_router:match_local(Topic)), delivery(Msg)). @@ -72,7 +72,7 @@ publish(Topic, Msg) -> route([], #mqtt_delivery{message = #mqtt_message{topic = Topic}}) -> dropped(Topic), ignore; -%% Dispatch on the local node +%% Dispatch on the local node. route([#mqtt_route{topic = To, node = Node}], Delivery = #mqtt_delivery{flows = Flows}) when Node =:= node() -> dispatch(To, Delivery#mqtt_delivery{flows = [{route, Node, To} | Flows]}); @@ -82,8 +82,8 @@ route([#mqtt_route{topic = To, node = Node}], Delivery = #mqtt_delivery{flows = forward(Node, To, Delivery#mqtt_delivery{flows = [{route, Node, To}|Flows]}); route(Routes, Delivery) -> - {ok, lists:foldl(fun(Route, DelAcc) -> - {ok, DelAcc1} = route([Route], DelAcc), DelAcc1 + {ok, lists:foldl(fun(Route, Acc) -> + {ok, Acc1} = route([Route], Acc), Acc1 end, Delivery, Routes)}. delivery(Msg) -> #mqtt_delivery{sender = self(), message = Msg, flows = []}. @@ -92,7 +92,7 @@ delivery(Msg) -> #mqtt_delivery{sender = self(), message = Msg, flows = []}. forward(Node, To, Delivery) -> rpc:cast(Node, ?PUBSUB, dispatch, [To, Delivery]), {ok, Delivery}. -%% @doc Dispatch Message to Subscribers +%% @doc Dispatch Message to Subscribers. -spec(dispatch(binary(), mqtt_delivery()) -> mqtt_delivery()). dispatch(Topic, Delivery = #mqtt_delivery{message = Msg, flows = Flows}) -> case subscribers(Topic) of @@ -107,16 +107,16 @@ dispatch(Topic, Delivery = #mqtt_delivery{message = Msg, flows = Flows}) -> {ok, Delivery#mqtt_delivery{flows = Flows1}} end. -dispatch(Pid, Topic, Msg) when is_pid(Pid) -> - Pid ! {dispatch, Topic, Msg}; -dispatch(SubId, Topic, Msg) when is_binary(SubId) -> - emqttd_sm:dispatch(SubId, Topic, Msg); -dispatch({_Share, [Sub]}, Topic, Msg) -> +%%TODO: Is SubPid aliving??? +dispatch(SubPid, Topic, Msg) when is_pid(SubPid) -> + SubPid ! {dispatch, Topic, Msg}; +dispatch({SubId, SubPid}, Topic, Msg) when is_binary(SubId), is_pid(SubPid) -> + SubPid ! {dispatch, Topic, Msg}; +dispatch({{share, _Share}, [Sub]}, Topic, Msg) -> dispatch(Sub, Topic, Msg); -dispatch({_Share, []}, _Topic, _Msg) -> +dispatch({{share, _Share}, []}, _Topic, _Msg) -> ok; -%%TODO: round-robbin -dispatch({_Share, Subs}, Topic, Msg) -> +dispatch({{share, _Share}, Subs}, Topic, Msg) -> %% round-robbin? dispatch(lists:nth(rand:uniform(length(Subs)), Subs), Topic, Msg). subscribers(Topic) -> @@ -126,8 +126,8 @@ group_by_share([]) -> []; group_by_share(Subscribers) -> {Subs1, Shares1} = - lists:foldl(fun({Share, Sub}, {Subs, Shares}) -> - {Subs, dict:append(Share, Sub, Shares)}; + lists:foldl(fun({share, Share, Sub}, {Subs, Shares}) -> + {Subs, dict:append({share, Share}, Sub, Shares)}; (Sub, {Subs, Shares}) -> {[Sub|Subs], Shares} end, {[], dict:new()}, Subscribers), @@ -155,8 +155,8 @@ call(PubSub, Req) when is_pid(PubSub) -> cast(PubSub, Msg) when is_pid(PubSub) -> gen_server2:cast(PubSub, Msg). -pick(Subscriber) -> - gproc_pool:pick_worker(pubsub, Subscriber). +pick(Topic) -> + gproc_pool:pick_worker(pubsub, Topic). %%-------------------------------------------------------------------- %% gen_server Callbacks @@ -169,22 +169,22 @@ init([Pool, Id, Env]) -> handle_call({subscribe, Topic, Subscriber, Options}, _From, State) -> add_subscriber(Topic, Subscriber, Options), - {reply, ok, setstats(State), hibernate}; + reply(ok, setstats(State)); handle_call({unsubscribe, Topic, Subscriber, Options}, _From, State) -> del_subscriber(Topic, Subscriber, Options), - {reply, ok, setstats(State), hibernate}; + reply(ok, setstats(State)); handle_call(Req, _From, State) -> ?UNEXPECTED_REQ(Req, State). handle_cast({subscribe, Topic, Subscriber, Options}, State) -> add_subscriber(Topic, Subscriber, Options), - {noreply, setstats(State), hibernate}; + noreply(setstats(State)); handle_cast({unsubscribe, Topic, Subscriber, Options}, State) -> del_subscriber(Topic, Subscriber, Options), - {noreply, setstats(State), hibernate}; + noreply(setstats(State)); handle_cast(Msg, State) -> ?UNEXPECTED_MSG(Msg, State). @@ -205,39 +205,48 @@ code_change(_OldVsn, State, _Extra) -> add_subscriber(Topic, Subscriber, Options) -> Share = proplists:get_value(share, Options), case ?is_local(Options) of - false -> add_subscriber_(Share, Topic, Subscriber); - true -> add_local_subscriber_(Share, Topic, Subscriber) + false -> add_global_subscriber(Share, Topic, Subscriber); + true -> add_local_subscriber(Share, Topic, Subscriber) end. -add_subscriber_(Share, Topic, Subscriber) -> - (not ets:member(mqtt_subscriber, Topic)) andalso emqttd_router:add_route(Topic), +add_global_subscriber(Share, Topic, Subscriber) -> + case ets:member(mqtt_subscriber, Topic) and emqttd_router:has_route(Topic) of + true -> ok; + false -> emqttd_router:add_route(Topic) + end, ets:insert(mqtt_subscriber, {Topic, shared(Share, Subscriber)}). -add_local_subscriber_(Share, Topic, Subscriber) -> +add_local_subscriber(Share, Topic, Subscriber) -> (not ets:member(mqtt_subscriber, {local, Topic})) andalso emqttd_router:add_local_route(Topic), ets:insert(mqtt_subscriber, {{local, Topic}, shared(Share, Subscriber)}). del_subscriber(Topic, Subscriber, Options) -> Share = proplists:get_value(share, Options), case ?is_local(Options) of - false -> del_subscriber_(Share, Topic, Subscriber); - true -> del_local_subscriber_(Share, Topic, Subscriber) + false -> del_global_subscriber(Share, Topic, Subscriber); + true -> del_local_subscriber(Share, Topic, Subscriber) end. -del_subscriber_(Share, Topic, Subscriber) -> +del_global_subscriber(Share, Topic, Subscriber) -> ets:delete_object(mqtt_subscriber, {Topic, shared(Share, Subscriber)}), (not ets:member(mqtt_subscriber, Topic)) andalso emqttd_router:del_route(Topic). -del_local_subscriber_(Share, Topic, Subscriber) -> +del_local_subscriber(Share, Topic, Subscriber) -> ets:delete_object(mqtt_subscriber, {{local, Topic}, shared(Share, Subscriber)}), (not ets:member(mqtt_subscriber, {local, Topic})) andalso emqttd_router:del_local_route(Topic). shared(undefined, Subscriber) -> Subscriber; shared(Share, Subscriber) -> - {Share, Subscriber}. + {share, Share, Subscriber}. setstats(State) -> emqttd_stats:setstats('subscribers/count', 'subscribers/max', ets:info(mqtt_subscriber, size)), State. +reply(Reply, State) -> + {reply, Reply, State, hibernate}. + +noreply(State) -> + {noreply, State, hibernate}. + diff --git a/src/emqttd_router.erl b/src/emqttd_router.erl index b3dd8b4ad..4d4e22160 100644 --- a/src/emqttd_router.erl +++ b/src/emqttd_router.erl @@ -28,15 +28,17 @@ -boot_mnesia({mnesia, [boot]}). -copy_mnesia({mnesia, [copy]}). -%% Start/Stop --export([start_link/0, topics/0, local_topics/0, stop/0]). +-export([start_link/0, topics/0, local_topics/0]). + +%% For eunit tests +-export([start/0, stop/0]). %% Route APIs --export([add_route/1, add_route/2, add_routes/1, match/1, print/1, - del_route/1, del_route/2, del_routes/1, has_route/1]). +-export([add_route/1, del_route/1, match/1, print/1, has_route/1]). %% Local Route API --export([add_local_route/1, del_local_route/1, match_local/1]). +-export([get_local_routes/0, add_local_route/1, match_local/1, + del_local_route/1, clean_local_routes/0]). %% gen_server Function Exports -export([init/1, handle_call/3, handle_cast/2, handle_info/2, @@ -55,10 +57,6 @@ %%-------------------------------------------------------------------- mnesia(boot) -> - ok = ekka_mnesia:create_table(mqtt_topic, [ - {ram_copies, [node()]}, - {record_name, mqtt_topic}, - {attributes, record_info(fields, mqtt_topic)}]), ok = ekka_mnesia:create_table(mqtt_route, [ {type, bag}, {ram_copies, [node()]}, @@ -66,7 +64,6 @@ mnesia(boot) -> {attributes, record_info(fields, mqtt_route)}]); mnesia(copy) -> - ok = ekka_mnesia:copy_table(mqtt_topic), ok = ekka_mnesia:copy_table(mqtt_route, ram_copies). %%-------------------------------------------------------------------- @@ -77,19 +74,26 @@ start_link() -> gen_server:start_link({local, ?ROUTER}, ?MODULE, [], []). %%-------------------------------------------------------------------- -%% API +%% Topics %%-------------------------------------------------------------------- +-spec(topics() -> list(binary())). topics() -> mnesia:dirty_all_keys(mqtt_route). +-spec(local_topics() -> list(binary())). local_topics() -> ets:select(mqtt_local_route, [{{'$1', '_'}, [], ['$1']}]). +%%-------------------------------------------------------------------- +%% Match API +%%-------------------------------------------------------------------- + %% @doc Match Routes. -spec(match(Topic:: binary()) -> [mqtt_route()]). match(Topic) when is_binary(Topic) -> - Matched = mnesia:async_dirty(fun emqttd_trie:match/1, [Topic]), + %% Optimize: ets??? + Matched = mnesia:ets(fun emqttd_trie:match/1, [Topic]), %% Optimize: route table will be replicated to all nodes. lists:append([ets:lookup(mqtt_route, To) || To <- [Topic | Matched]]). @@ -99,93 +103,68 @@ print(Topic) -> [io:format("~s -> ~s~n", [To, Node]) || #mqtt_route{topic = To, node = Node} <- match(Topic)]. -%% @doc Add Route --spec(add_route(binary() | mqtt_route()) -> ok | {error, Reason :: any()}). +%%-------------------------------------------------------------------- +%% Route Management API +%%-------------------------------------------------------------------- + +%% @doc Add Route. +-spec(add_route(binary() | mqtt_route()) -> ok | {error, Reason :: term()}). add_route(Topic) when is_binary(Topic) -> add_route(#mqtt_route{topic = Topic, node = node()}); -add_route(Route) when is_record(Route, mqtt_route) -> - add_routes([Route]). - --spec(add_route(Topic :: binary(), Node :: node()) -> ok | {error, Reason :: any()}). -add_route(Topic, Node) when is_binary(Topic), is_atom(Node) -> - add_route(#mqtt_route{topic = Topic, node = Node}). - -%% @doc Add Routes --spec(add_routes([mqtt_route()]) -> ok | {error, Reason :: any()}). -add_routes(Routes) -> - AddFun = fun() -> [add_route_(Route) || Route <- Routes] end, - case mnesia:is_transaction() of - true -> AddFun(); - false -> trans(AddFun) +add_route(Route = #mqtt_route{topic = Topic}) -> + case emqttd_topic:wildcard(Topic) of + true -> case mnesia:is_transaction() of + true -> add_trie_route(Route); + false -> trans(fun add_trie_route/1, [Route]) + end; + false -> add_direct_route(Route) end. -%% @private -add_route_(Route = #mqtt_route{topic = Topic}) -> +add_direct_route(Route) -> + mnesia:async_dirty(fun mnesia:write/1, [Route]). + +add_trie_route(Route = #mqtt_route{topic = Topic}) -> case mnesia:wread({mqtt_route, Topic}) of - [] -> - case emqttd_topic:wildcard(Topic) of - true -> emqttd_trie:insert(Topic); - false -> ok - end, - mnesia:write(Route), - mnesia:write(#mqtt_topic{topic = Topic}); - Records -> - case lists:member(Route, Records) of - true -> ok; - false -> mnesia:write(Route) - end - end. + [] -> emqttd_trie:insert(Topic); + _ -> ok + end, + mnesia:write(Route). %% @doc Delete Route --spec(del_route(binary() | mqtt_route()) -> ok | {error, Reason :: any()}). +-spec(del_route(binary() | mqtt_route()) -> ok | {error, Reason :: term()}). del_route(Topic) when is_binary(Topic) -> del_route(#mqtt_route{topic = Topic, node = node()}); -del_route(Route) when is_record(Route, mqtt_route) -> - del_routes([Route]). - --spec(del_route(Topic :: binary(), Node :: node()) -> ok | {error, Reason :: any()}). -del_route(Topic, Node) when is_binary(Topic), is_atom(Node) -> - del_route(#mqtt_route{topic = Topic, node = Node}). - -%% @doc Delete Routes --spec(del_routes([mqtt_route()]) -> ok | {error, any()}). -del_routes(Routes) -> - DelFun = fun() -> [del_route_(Route) || Route <- Routes] end, - case mnesia:is_transaction() of - true -> DelFun(); - false -> trans(DelFun) +del_route(Route = #mqtt_route{topic = Topic}) -> + case emqttd_topic:wildcard(Topic) of + true -> case mnesia:is_transaction() of + true -> del_trie_route(Route); + false -> trans(fun del_trie_route/1, [Route]) + end; + false -> del_direct_route(Route) end. -del_route_(Route = #mqtt_route{topic = Topic}) -> +del_direct_route(Route) -> + mnesia:async_dirty(fun mnesia:delete_object/1, [Route]). + +del_trie_route(Route = #mqtt_route{topic = Topic}) -> case mnesia:wread({mqtt_route, Topic}) of - [] -> - ok; - [Route] -> - %% Remove route and trie - mnesia:delete_object(Route), - case emqttd_topic:wildcard(Topic) of - true -> emqttd_trie:delete(Topic); - false -> ok - end, - mnesia:delete({mqtt_topic, Topic}); - _More -> - %% Remove route only - mnesia:delete_object(Route) + [Route] -> %% Remove route and trie + mnesia:delete_object(Route), + emqttd_trie:delete(Topic); + [_|_] -> %% Remove route only + mnesia:delete_object(Route); + [] -> ok end. -%% @doc Has Route? +%% @doc Has route? -spec(has_route(binary()) -> boolean()). -has_route(Topic) -> - Routes = case mnesia:is_transaction() of - true -> mnesia:read(mqtt_route, Topic); - false -> mnesia:dirty_read(mqtt_route, Topic) - end, - length(Routes) > 0. +has_route(Topic) when is_binary(Topic) -> + ets:member(mqtt_route, Topic). %% @private --spec(trans(function()) -> ok | {error, any()}). -trans(Fun) -> - case mnesia:transaction(Fun) of +-spec(trans(function(), list(any())) -> ok | {error, term()}). +trans(Fun, Args) -> + case mnesia:transaction(Fun, Args) of {atomic, _} -> ok; {aborted, Error} -> {error, Error} end. @@ -194,24 +173,44 @@ trans(Fun) -> %% Local Route API %%-------------------------------------------------------------------- +-spec(get_local_routes() -> list({binary(), node()})). +get_local_routes() -> + ets:tab2list(mqtt_local_route). + -spec(add_local_route(binary()) -> ok). add_local_route(Topic) -> - gen_server:cast(?ROUTER, {add_local_route, Topic}). + gen_server:call(?ROUTER, {add_local_route, Topic}). -spec(del_local_route(binary()) -> ok). del_local_route(Topic) -> - gen_server:cast(?ROUTER, {del_local_route, Topic}). + gen_server:call(?ROUTER, {del_local_route, Topic}). -spec(match_local(binary()) -> [mqtt_route()]). match_local(Name) -> - [#mqtt_route{topic = {local, Filter}, node = Node} - || {Filter, Node} <- ets:tab2list(mqtt_local_route), - emqttd_topic:match(Name, Filter)]. + case ets:info(mqtt_local_route, size) of + 0 -> []; + _ -> ets:foldl( + fun({Filter, Node}, Matched) -> + case emqttd_topic:match(Name, Filter) of + true -> [#mqtt_route{topic = {local, Filter}, node = Node} | Matched]; + false -> Matched + end + end, [], mqtt_local_route) + end. + +-spec(clean_local_routes() -> ok). +clean_local_routes() -> + gen_server:call(?ROUTER, clean_local_routes). dump() -> [{route, ets:tab2list(mqtt_route)}, {local_route, ets:tab2list(mqtt_local_route)}]. -stop() -> gen_server:call(?ROUTER, stop). +%% For unit test. +start() -> + gen_server:start({local, ?ROUTER}, ?MODULE, [], []). + +stop() -> + gen_server:call(?ROUTER, stop). %%-------------------------------------------------------------------- %% gen_server Callbacks @@ -223,21 +222,25 @@ init([]) -> {ok, TRef} = timer:send_interval(timer:seconds(1), stats), {ok, #state{stats_timer = TRef}}. +handle_call({add_local_route, Topic}, _From, State) -> + %% why node()...? + ets:insert(mqtt_local_route, {Topic, node()}), + {reply, ok, State}; + +handle_call({del_local_route, Topic}, _From, State) -> + ets:delete(mqtt_local_route, Topic), + {reply, ok, State}; + +handle_call(clean_local_routes, _From, State) -> + ets:delete_all_objects(mqtt_local_route), + {reply, ok, State}; + handle_call(stop, _From, State) -> {stop, normal, ok, State}; handle_call(_Req, _From, State) -> {reply, ignore, State}. -handle_cast({add_local_route, Topic}, State) -> - %% why node()...? - ets:insert(mqtt_local_route, {Topic, node()}), - {noreply, State}; - -handle_cast({del_local_route, Topic}, State) -> - ets:delete(mqtt_local_route, Topic), - {noreply, State}; - handle_cast(_Msg, State) -> {noreply, State}. diff --git a/src/emqttd_server.erl b/src/emqttd_server.erl index 69d18e1e4..4e05c00aa 100644 --- a/src/emqttd_server.erl +++ b/src/emqttd_server.erl @@ -37,8 +37,7 @@ async_unsubscribe/1, async_unsubscribe/2]). %% Management API. --export([setqos/3, subscriptions/1, subscribers/1, is_subscribed/2, - subscriber_down/1]). +-export([setqos/3, subscriptions/1, subscribers/1, subscribed/2]). %% Debug API -export([dump/0]). @@ -47,10 +46,10 @@ -export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]). --record(state, {pool, id, env, submon :: emqttd_pmon:pmon()}). +-record(state, {pool, id, env, subids :: map(), submon :: emqttd_pmon:pmon()}). -%% @doc Start server --spec(start_link(atom(), pos_integer(), list()) -> {ok, pid()} | ignore | {error, any()}). +%% @doc Start the server +-spec(start_link(atom(), pos_integer(), list()) -> {ok, pid()} | ignore | {error, term()}). start_link(Pool, Id, Env) -> gen_server2:start_link({local, ?PROC_NAME(?MODULE, Id)}, ?MODULE, [Pool, Id, Env], []). @@ -58,21 +57,21 @@ start_link(Pool, Id, Env) -> %% PubSub API %%-------------------------------------------------------------------- -%% @doc Subscribe a Topic --spec(subscribe(binary()) -> ok | emqttd:pubsub_error()). +%% @doc Subscribe to a Topic. +-spec(subscribe(binary()) -> ok | {error, term()}). subscribe(Topic) when is_binary(Topic) -> subscribe(Topic, self()). --spec(subscribe(binary(), emqttd:subscriber()) -> ok | emqttd:pubsub_error()). +-spec(subscribe(binary(), emqttd:subscriber()) -> ok | {error, term()}). subscribe(Topic, Subscriber) when is_binary(Topic) -> subscribe(Topic, Subscriber, []). -spec(subscribe(binary(), emqttd:subscriber(), [emqttd:suboption()]) -> - ok | emqttd:pubsub_error()). + ok | {error, term()}). subscribe(Topic, Subscriber, Options) when is_binary(Topic) -> - call(pick(Subscriber), {subscribe, Topic, Subscriber, Options}). + call(pick(Subscriber), {subscribe, Topic, with_subpid(Subscriber), Options}). -%% @doc Subscribe a Topic Asynchronously +%% @doc Subscribe to a Topic asynchronously. -spec(async_subscribe(binary()) -> ok). async_subscribe(Topic) when is_binary(Topic) -> async_subscribe(Topic, self()). @@ -83,7 +82,7 @@ async_subscribe(Topic, Subscriber) when is_binary(Topic) -> -spec(async_subscribe(binary(), emqttd:subscriber(), [emqttd:suboption()]) -> ok). async_subscribe(Topic, Subscriber, Options) when is_binary(Topic) -> - cast(pick(Subscriber), {subscribe, Topic, Subscriber, Options}). + cast(pick(Subscriber), {subscribe, Topic, with_subpid(Subscriber), Options}). %% @doc Publish message to Topic. -spec(publish(mqtt_message()) -> {ok, mqtt_delivery()} | ignore). @@ -109,14 +108,14 @@ trace(publish, From, #mqtt_message{topic = Topic, payload = Payload}) -> "~s PUBLISH to ~s: ~p", [From, Topic, Payload]). %% @doc Unsubscribe --spec(unsubscribe(binary()) -> ok | emqttd:pubsub_error()). +-spec(unsubscribe(binary()) -> ok | {error, term()}). unsubscribe(Topic) when is_binary(Topic) -> unsubscribe(Topic, self()). %% @doc Unsubscribe --spec(unsubscribe(binary(), emqttd:subscriber()) -> ok | emqttd:pubsub_error()). +-spec(unsubscribe(binary(), emqttd:subscriber()) -> ok | {error, term()}). unsubscribe(Topic, Subscriber) when is_binary(Topic) -> - call(pick(Subscriber), {unsubscribe, Topic, Subscriber}). + call(pick(Subscriber), {unsubscribe, Topic, with_subpid(Subscriber)}). %% @doc Async Unsubscribe -spec(async_unsubscribe(binary()) -> ok). @@ -125,32 +124,47 @@ async_unsubscribe(Topic) when is_binary(Topic) -> -spec(async_unsubscribe(binary(), emqttd:subscriber()) -> ok). async_unsubscribe(Topic, Subscriber) when is_binary(Topic) -> - cast(pick(Subscriber), {unsubscribe, Topic, Subscriber}). + cast(pick(Subscriber), {unsubscribe, Topic, with_subpid(Subscriber)}). +-spec(setqos(binary(), emqttd:subscriber(), mqtt_qos()) -> ok). setqos(Topic, Subscriber, Qos) when is_binary(Topic) -> - call(pick(Subscriber), {setqos, Topic, Subscriber, Qos}). + call(pick(Subscriber), {setqos, Topic, with_subpid(Subscriber), Qos}). --spec(subscriptions(emqttd:subscriber()) -> [{binary(), binary(), list(emqttd:suboption())}]). -subscriptions(Subscriber) -> - lists:map(fun({_, {_Share, Topic}}) -> - subscription(Topic, Subscriber); - ({_, Topic}) -> - subscription(Topic, Subscriber) - end, ets:lookup(mqtt_subscription, Subscriber)). +with_subpid(SubPid) when is_pid(SubPid) -> + SubPid; +with_subpid(SubId) when is_binary(SubId) -> + {SubId, self()}; +with_subpid({SubId, SubPid}) when is_binary(SubId), is_pid(SubPid) -> + {SubId, SubPid}. -subscription(Topic, Subscriber) -> - {Topic, Subscriber, ets:lookup_element(mqtt_subproperty, {Topic, Subscriber}, 2)}. +-spec(subscriptions(emqttd:subscriber()) -> [{emqttd:subscriber(), binary(), list(emqttd:suboption())}]). +subscriptions(SubPid) when is_pid(SubPid) -> + with_subproperty(ets:lookup(mqtt_subscription, SubPid)); -subscribers(Topic) -> +subscriptions(SubId) when is_binary(SubId) -> + with_subproperty(ets:match_object(mqtt_subscription, {{SubId, '_'}, '_'})); + +subscriptions({SubId, SubPid}) when is_binary(SubId), is_pid(SubPid) -> + with_subproperty(ets:lookup(mqtt_subscription, {SubId, SubPid})). + +with_subproperty({Subscriber, {share, _Share, Topic}}) -> + with_subproperty({Subscriber, Topic}); +with_subproperty({Subscriber, Topic}) -> + {Subscriber, Topic, ets:lookup_element(mqtt_subproperty, {Topic, Subscriber}, 2)}; +with_subproperty(Subscriptions) when is_list(Subscriptions) -> + [with_subproperty(Subscription) || Subscription <- Subscriptions]. + +-spec(subscribers(binary()) -> list(emqttd:subscriber())). +subscribers(Topic) when is_binary(Topic) -> emqttd_pubsub:subscribers(Topic). --spec(is_subscribed(binary(), emqttd:subscriber()) -> boolean()). -is_subscribed(Topic, Subscriber) when is_binary(Topic) -> - ets:member(mqtt_subproperty, {Topic, Subscriber}). - --spec(subscriber_down(emqttd:subscriber()) -> ok). -subscriber_down(Subscriber) -> - cast(pick(Subscriber), {subscriber_down, Subscriber}). +-spec(subscribed(binary(), emqttd:subscriber()) -> boolean()). +subscribed(Topic, SubPid) when is_binary(Topic), is_pid(SubPid) -> + ets:member(mqtt_subproperty, {Topic, SubPid}); +subscribed(Topic, SubId) when is_binary(Topic), is_binary(SubId) -> + length(ets:match_object(mqtt_subproperty, {{Topic, {SubId, '_'}}, '_'}, 1)) == 1; +subscribed(Topic, {SubId, SubPid}) when is_binary(Topic), is_binary(SubId), is_pid(SubPid) -> + ets:member(mqtt_subproperty, {Topic, {SubId, SubPid}}). call(Server, Req) -> gen_server2:call(Server, Req, infinity). @@ -158,8 +172,12 @@ call(Server, Req) -> cast(Server, Msg) when is_pid(Server) -> gen_server2:cast(Server, Msg). -pick(Subscriber) -> - gproc_pool:pick_worker(server, Subscriber). +pick(SubPid) when is_pid(SubPid) -> + gproc_pool:pick_worker(server, SubPid); +pick(SubId) when is_binary(SubId) -> + gproc_pool:pick_worker(server, SubId); +pick({SubId, SubPid}) when is_binary(SubId), is_pid(SubPid) -> + pick(SubId). dump() -> [{Tab, ets:tab2list(Tab)} || Tab <- [mqtt_subproperty, mqtt_subscription, mqtt_subscriber]]. @@ -170,18 +188,20 @@ dump() -> init([Pool, Id, Env]) -> ?GPROC_POOL(join, Pool, Id), - {ok, #state{pool = Pool, id = Id, env = Env, submon = emqttd_pmon:new()}}. + State = #state{pool = Pool, id = Id, env = Env, + subids = #{}, submon = emqttd_pmon:new()}, + {ok, State, hibernate, {backoff, 2000, 2000, 20000}}. handle_call({subscribe, Topic, Subscriber, Options}, _From, State) -> - case do_subscribe_(Topic, Subscriber, Options, State) of - {ok, NewState} -> {reply, ok, setstats(NewState)}; - {error, Error} -> {reply, {error, Error}, State} + case do_subscribe(Topic, Subscriber, Options, State) of + {ok, NewState} -> reply(ok, setstats(NewState)); + {error, Error} -> reply({error, Error}, State) end; handle_call({unsubscribe, Topic, Subscriber}, _From, State) -> - case do_unsubscribe_(Topic, Subscriber, State) of - {ok, NewState} -> {reply, ok, setstats(NewState), hibernate}; - {error, Error} -> {reply, {error, Error}, State} + case do_unsubscribe(Topic, Subscriber, State) of + {ok, NewState} -> reply(ok, setstats(NewState)); + {error, Error} -> reply({error, Error}, State) end; handle_call({setqos, Topic, Subscriber, Qos}, _From, State) -> @@ -190,36 +210,37 @@ handle_call({setqos, Topic, Subscriber, Qos}, _From, State) -> [{_, Opts}] -> Opts1 = lists:ukeymerge(1, [{qos, Qos}], Opts), ets:insert(mqtt_subproperty, {Key, Opts1}), - {reply, ok, State}; + reply(ok, State); [] -> - {reply, {error, {subscription_not_found, Topic}}, State} + reply({error, {subscription_not_found, Topic}}, State) end; handle_call(Req, _From, State) -> ?UNEXPECTED_REQ(Req, State). handle_cast({subscribe, Topic, Subscriber, Options}, State) -> - case do_subscribe_(Topic, Subscriber, Options, State) of - {ok, NewState} -> {noreply, setstats(NewState)}; - {error, _Error} -> {noreply, State} + case do_subscribe(Topic, Subscriber, Options, State) of + {ok, NewState} -> noreply(setstats(NewState)); + {error, _Error} -> noreply(State) end; handle_cast({unsubscribe, Topic, Subscriber}, State) -> - case do_unsubscribe_(Topic, Subscriber, State) of - {ok, NewState} -> {noreply, setstats(NewState), hibernate}; - {error, _Error} -> {noreply, State} + case do_unsubscribe(Topic, Subscriber, State) of + {ok, NewState} -> noreply(setstats(NewState)); + {error, _Error} -> noreply(State) end; -handle_cast({subscriber_down, Subscriber}, State) -> - subscriber_down_(Subscriber), - {noreply, setstats(State)}; - handle_cast(Msg, State) -> ?UNEXPECTED_MSG(Msg, State). -handle_info({'DOWN', _MRef, process, DownPid, _Reason}, State = #state{submon = PM}) -> - subscriber_down_(DownPid), - {noreply, setstats(State#state{submon = PM:erase(DownPid)}), hibernate}; +handle_info({'DOWN', _MRef, process, DownPid, _Reason}, State = #state{subids = SubIds}) -> + case maps:find(DownPid, SubIds) of + {ok, SubId} -> + clean_subscriber({SubId, DownPid}); + error -> + clean_subscriber(DownPid) + end, + noreply(setstats(demonitor_subscriber(DownPid, State))); handle_info(Info, State) -> ?UNEXPECTED_INFO(Info, State). @@ -234,62 +255,54 @@ code_change(_OldVsn, State, _Extra) -> %% Internal Functions %%-------------------------------------------------------------------- -do_subscribe_(Topic, Subscriber, Options, State) -> +do_subscribe(Topic, Subscriber, Options, State) -> case ets:lookup(mqtt_subproperty, {Topic, Subscriber}) of [] -> emqttd_pubsub:async_subscribe(Topic, Subscriber, Options), Share = proplists:get_value(share, Options), - add_subscription_(Share, Subscriber, Topic), + add_subscription(Share, Subscriber, Topic), ets:insert(mqtt_subproperty, {{Topic, Subscriber}, Options}), - {ok, monitor_subpid(Subscriber, State)}; + {ok, monitor_subscriber(Subscriber, State)}; [_] -> {error, {already_subscribed, Topic}} end. -add_subscription_(undefined, Subscriber, Topic) -> +add_subscription(undefined, Subscriber, Topic) -> ets:insert(mqtt_subscription, {Subscriber, Topic}); -add_subscription_(Share, Subscriber, Topic) -> - ets:insert(mqtt_subscription, {Subscriber, {Share, Topic}}). +add_subscription(Share, Subscriber, Topic) -> + ets:insert(mqtt_subscription, {Subscriber, {share, Share, Topic}}). -monitor_subpid(SubPid, State = #state{submon = PMon}) when is_pid(SubPid) -> - State#state{submon = PMon:monitor(SubPid)}; -monitor_subpid(_SubPid, State) -> - State. +monitor_subscriber(SubPid, State = #state{submon = SubMon}) when is_pid(SubPid) -> + State#state{submon = SubMon:monitor(SubPid)}; +monitor_subscriber({SubId, SubPid}, State = #state{subids = SubIds, submon = SubMon}) -> + State#state{subids = maps:put(SubPid, SubId, SubIds), submon = SubMon:monitor(SubPid)}. -do_unsubscribe_(Topic, Subscriber, State) -> +do_unsubscribe(Topic, Subscriber, State) -> case ets:lookup(mqtt_subproperty, {Topic, Subscriber}) of [{_, Options}] -> emqttd_pubsub:async_unsubscribe(Topic, Subscriber, Options), Share = proplists:get_value(share, Options), - del_subscription_(Share, Subscriber, Topic), + del_subscription(Share, Subscriber, Topic), ets:delete(mqtt_subproperty, {Topic, Subscriber}), - {ok, case ets:member(mqtt_subscription, Subscriber) of - true -> State; - false -> demonitor_subpid(Subscriber, State) - end}; + {ok, State}; [] -> {error, {subscription_not_found, Topic}} end. -del_subscription_(undefined, Subscriber, Topic) -> +del_subscription(undefined, Subscriber, Topic) -> ets:delete_object(mqtt_subscription, {Subscriber, Topic}); -del_subscription_(Share, Subscriber, Topic) -> - ets:delete_object(mqtt_subscription, {Subscriber, {Share, Topic}}). +del_subscription(Share, Subscriber, Topic) -> + ets:delete_object(mqtt_subscription, {Subscriber, {share, Share, Topic}}). -demonitor_subpid(SubPid, State = #state{submon = PMon}) when is_pid(SubPid) -> - State#state{submon = PMon:demonitor(SubPid)}; -demonitor_subpid(_SubPid, State) -> - State. - -subscriber_down_(Subscriber) -> - lists:foreach(fun({_, {Share, Topic}}) -> - subscriber_down_(Share, Subscriber, Topic); +clean_subscriber(Subscriber) -> + lists:foreach(fun({_, {share, Share, Topic}}) -> + clean_subscriber(Share, Subscriber, Topic); ({_, Topic}) -> - subscriber_down_(undefined, Subscriber, Topic) + clean_subscriber(undefined, Subscriber, Topic) end, ets:lookup(mqtt_subscription, Subscriber)), ets:delete(mqtt_subscription, Subscriber). -subscriber_down_(Share, Subscriber, Topic) -> +clean_subscriber(Share, Subscriber, Topic) -> case ets:lookup(mqtt_subproperty, {Topic, Subscriber}) of [] -> %% TODO:....??? @@ -300,7 +313,16 @@ subscriber_down_(Share, Subscriber, Topic) -> ets:delete(mqtt_subproperty, {Topic, Subscriber}) end. +demonitor_subscriber(SubPid, State = #state{subids = SubIds, submon = SubMon}) -> + State#state{subids = maps:remove(SubPid, SubIds), submon = SubMon:demonitor(SubPid)}. + setstats(State) -> emqttd_stats:setstats('subscriptions/count', 'subscriptions/max', ets:info(mqtt_subscription, size)), State. +reply(Reply, State) -> + {reply, Reply, State, hibernate}. + +noreply(State) -> + {noreply, State, hibernate}. + diff --git a/src/emqttd_trie.erl b/src/emqttd_trie.erl index 5b36e6e04..0bb6ec63e 100644 --- a/src/emqttd_trie.erl +++ b/src/emqttd_trie.erl @@ -31,7 +31,7 @@ -copy_mnesia({mnesia, [copy]}). %% Trie API --export([insert/1, match/1, delete/1, lookup/1]). +-export([insert/1, match/1, lookup/1, delete/1]). %%-------------------------------------------------------------------- %% Mnesia Callbacks @@ -65,22 +65,22 @@ mnesia(copy) -> -spec(insert(Topic :: binary()) -> ok). insert(Topic) when is_binary(Topic) -> case mnesia:read(mqtt_trie_node, Topic) of - [#trie_node{topic=Topic}] -> - ok; - [TrieNode=#trie_node{topic=undefined}] -> - write_trie_node(TrieNode#trie_node{topic=Topic}); - [] -> - % Add trie path - lists:foreach(fun add_path/1, emqttd_topic:triples(Topic)), - % Add last node - write_trie_node(#trie_node{node_id=Topic, topic=Topic}) + [#trie_node{topic = Topic}] -> + ok; + [TrieNode = #trie_node{topic = undefined}] -> + write_trie_node(TrieNode#trie_node{topic = Topic}); + [] -> + % Add trie path + lists:foreach(fun add_path/1, emqttd_topic:triples(Topic)), + % Add last node + write_trie_node(#trie_node{node_id = Topic, topic = Topic}) end. %% @doc Find trie nodes that match topic -spec(match(Topic :: binary()) -> list(MatchedTopic :: binary())). match(Topic) when is_binary(Topic) -> TrieNodes = match_node(root, emqttd_topic:words(Topic)), - [Name || #trie_node{topic=Name} <- TrieNodes, Name =/= undefined]. + [Name || #trie_node{topic = Name} <- TrieNodes, Name =/= undefined]. %% @doc Lookup a Trie Node -spec(lookup(NodeId :: binary()) -> [#trie_node{}]). @@ -91,13 +91,13 @@ lookup(NodeId) -> -spec(delete(Topic :: binary()) -> ok). delete(Topic) when is_binary(Topic) -> case mnesia:read(mqtt_trie_node, Topic) of - [#trie_node{edge_count=0}] -> - mnesia:delete({mqtt_trie_node, Topic}), - delete_path(lists:reverse(emqttd_topic:triples(Topic))); - [TrieNode] -> - write_trie_node(TrieNode#trie_node{topic = undefined}); - [] -> - ok + [#trie_node{edge_count = 0}] -> + mnesia:delete({mqtt_trie_node, Topic}), + delete_path(lists:reverse(emqttd_topic:triples(Topic))); + [TrieNode] -> + write_trie_node(TrieNode#trie_node{topic = undefined}); + [] -> + ok end. %%-------------------------------------------------------------------- @@ -107,19 +107,19 @@ delete(Topic) when is_binary(Topic) -> %% @private %% @doc Add path to trie tree. add_path({Node, Word, Child}) -> - Edge = #trie_edge{node_id=Node, word=Word}, + Edge = #trie_edge{node_id = Node, word = Word}, case mnesia:read(mqtt_trie_node, Node) of - [TrieNode = #trie_node{edge_count=Count}] -> - case mnesia:wread({mqtt_trie, Edge}) of - [] -> - write_trie_node(TrieNode#trie_node{edge_count=Count+1}), - write_trie(#trie{edge=Edge, node_id=Child}); - [_] -> - ok - end; - [] -> - write_trie_node(#trie_node{node_id=Node, edge_count=1}), - write_trie(#trie{edge=Edge, node_id=Child}) + [TrieNode = #trie_node{edge_count = Count}] -> + case mnesia:wread({mqtt_trie, Edge}) of + [] -> + write_trie_node(TrieNode#trie_node{edge_count = Count+1}), + write_trie(#trie{edge = Edge, node_id = Child}); + [_] -> + ok + end; + [] -> + write_trie_node(#trie_node{node_id = Node, edge_count = 1}), + write_trie(#trie{edge = Edge, node_id = Child}) end. %% @private @@ -135,20 +135,20 @@ match_node(NodeId, [], ResAcc) -> match_node(NodeId, [W|Words], ResAcc) -> lists:foldl(fun(WArg, Acc) -> - case mnesia:read(mqtt_trie, #trie_edge{node_id=NodeId, word=WArg}) of - [#trie{node_id=ChildId}] -> match_node(ChildId, Words, Acc); - [] -> Acc + case mnesia:read(mqtt_trie, #trie_edge{node_id = NodeId, word = WArg}) of + [#trie{node_id = ChildId}] -> match_node(ChildId, Words, Acc); + [] -> Acc end end, 'match_#'(NodeId, ResAcc), [W, '+']). %% @private %% @doc Match node with '#'. 'match_#'(NodeId, ResAcc) -> - case mnesia:read(mqtt_trie, #trie_edge{node_id=NodeId, word = '#'}) of - [#trie{node_id=ChildId}] -> - mnesia:read(mqtt_trie_node, ChildId) ++ ResAcc; - [] -> - ResAcc + case mnesia:read(mqtt_trie, #trie_edge{node_id = NodeId, word = '#'}) of + [#trie{node_id = ChildId}] -> + mnesia:read(mqtt_trie_node, ChildId) ++ ResAcc; + [] -> + ResAcc end. %% @private @@ -156,17 +156,17 @@ match_node(NodeId, [W|Words], ResAcc) -> delete_path([]) -> ok; delete_path([{NodeId, Word, _} | RestPath]) -> - mnesia:delete({mqtt_trie, #trie_edge{node_id=NodeId, word=Word}}), + mnesia:delete({mqtt_trie, #trie_edge{node_id = NodeId, word = Word}}), case mnesia:read(mqtt_trie_node, NodeId) of - [#trie_node{edge_count=1, topic=undefined}] -> - mnesia:delete({mqtt_trie_node, NodeId}), - delete_path(RestPath); - [TrieNode=#trie_node{edge_count=1, topic=_}] -> - write_trie_node(TrieNode#trie_node{edge_count=0}); - [TrieNode=#trie_node{edge_count=C}] -> - write_trie_node(TrieNode#trie_node{edge_count=C-1}); - [] -> - throw({notfound, NodeId}) + [#trie_node{edge_count = 1, topic = undefined}] -> + mnesia:delete({mqtt_trie_node, NodeId}), + delete_path(RestPath); + [TrieNode = #trie_node{edge_count = 1, topic = _}] -> + write_trie_node(TrieNode#trie_node{edge_count = 0}); + [TrieNode = #trie_node{edge_count = C}] -> + write_trie_node(TrieNode#trie_node{edge_count = C-1}); + [] -> + mnesia:abort({node_not_found, NodeId}) end. %% @private