diff --git a/src/emqttd.erl b/src/emqttd.erl index 63b53d4e1..037c0de1a 100644 --- a/src/emqttd.erl +++ b/src/emqttd.erl @@ -138,17 +138,20 @@ subscriber_down(Subscriber) -> %% Hooks API %%-------------------------------------------------------------------- --spec(hook(atom(), function(), list(any())) -> ok | {error, any()}). -hook(Hook, Function, InitArgs) -> - emqttd_hooks:add(Hook, Function, InitArgs). +-spec(hook(atom(), function() | {emqttd_hooks:hooktag(), function()}, list(any())) + -> ok | {error, any()}). +hook(Hook, TagFunction, InitArgs) -> + emqttd_hooks:add(Hook, TagFunction, InitArgs). --spec(hook(atom(), function(), list(any()), integer()) -> ok | {error, any()}). -hook(Hook, Function, InitArgs, Priority) -> - emqttd_hooks:add(Hook, Function, InitArgs, Priority). +-spec(hook(atom(), function() | {emqttd_hooks:hooktag(), function()}, list(any()), integer()) + -> ok | {error, any()}). +hook(Hook, TagFunction, InitArgs, Priority) -> + emqttd_hooks:add(Hook, TagFunction, InitArgs, Priority). --spec(unhook(atom(), function()) -> ok | {error, any()}). -unhook(Hook, Function) -> - emqttd_hooks:delete(Hook, Function). +-spec(unhook(atom(), function() | {emqttd_hooks:hooktag(), function()}) + -> ok | {error, any()}). +unhook(Hook, TagFunction) -> + emqttd_hooks:delete(Hook, TagFunction). -spec(run_hooks(atom(), list(any())) -> ok | stop). run_hooks(Hook, Args) -> diff --git a/src/emqttd_hooks.erl b/src/emqttd_hooks.erl index ce2691894..693a67ff7 100644 --- a/src/emqttd_hooks.erl +++ b/src/emqttd_hooks.erl @@ -32,7 +32,12 @@ -record(state, {}). --record(callback, {function :: function(), +-type(hooktag() :: atom() | string() | binary()). + +-export_type([hooktag/0]). + +-record(callback, {tag :: hooktag(), + function :: function(), init_args = [] :: list(any()), priority = 0 :: integer()}). @@ -47,17 +52,24 @@ start_link() -> %% Hooks API %%-------------------------------------------------------------------- --spec(add(atom(), function(), list(any())) -> ok). -add(HookPoint, Function, InitArgs) -> - add(HookPoint, Function, InitArgs, 0). +-spec(add(atom(), function() | {hooktag(), function()}, list(any())) -> ok). +add(HookPoint, Function, InitArgs) when is_function(Function) -> + add(HookPoint, {undefined, Function}, InitArgs, 0); --spec(add(atom(), function(), list(any()), integer()) -> ok). -add(HookPoint, Function, InitArgs, Priority) -> - gen_server:call(?MODULE, {add, HookPoint, Function, InitArgs, Priority}). +add(HookPoint, {Tag, Function}, InitArgs) when is_function(Function) -> + add(HookPoint, {Tag, Function}, InitArgs, 0). --spec(delete(atom(), function()) -> ok). -delete(HookPoint, Function) -> - gen_server:call(?MODULE, {delete, HookPoint, Function}). +-spec(add(atom(), function() | {hooktag(), function()}, list(any()), integer()) -> ok). +add(HookPoint, Function, InitArgs, Priority) when is_function(Function) -> + add(HookPoint, {undefined, Function}, InitArgs, Priority); +add(HookPoint, {Tag, Function}, InitArgs, Priority) when is_function(Function) -> + gen_server:call(?MODULE, {add, HookPoint, {Tag, Function}, InitArgs, Priority}). + +-spec(delete(atom(), function() | {hooktag(), function()}) -> ok). +delete(HookPoint, Function) when is_function(Function) -> + delete(HookPoint, {undefined, Function}); +delete(HookPoint, {Tag, Function}) when is_function(Function) -> + gen_server:call(?MODULE, {delete, HookPoint, {Tag, Function}}). %% @doc Run hooks without Acc. -spec(run(atom(), list(Arg :: any())) -> ok | stop). @@ -85,7 +97,8 @@ run_([#callback{function = Fun, init_args = InitArgs} | Callbacks], Args, Acc) - ok -> run_(Callbacks, Args, Acc); {ok, NewAcc} -> run_(Callbacks, Args, NewAcc); stop -> {stop, Acc}; - {stop, NewAcc} -> {stop, NewAcc} + {stop, NewAcc} -> {stop, NewAcc}; + _Any -> run_(Callbacks, Args, Acc) end; run_([], _Args, Acc) -> @@ -94,8 +107,8 @@ run_([], _Args, Acc) -> -spec(lookup(atom()) -> [#callback{}]). lookup(HookPoint) -> case ets:lookup(?HOOK_TAB, HookPoint) of - [] -> []; - [#hook{callbacks = Callbacks}] -> Callbacks + [#hook{callbacks = Callbacks}] -> Callbacks; + [] -> [] end. %%-------------------------------------------------------------------- @@ -106,39 +119,38 @@ init([]) -> ets:new(?HOOK_TAB, [set, protected, named_table, {keypos, #hook.name}]), {ok, #state{}}. -handle_call({add, HookPoint, Function, InitArgs, Priority}, _From, State) -> - Reply = - case ets:lookup(?HOOK_TAB, HookPoint) of - [#hook{callbacks = Callbacks}] -> - case lists:keyfind(Function, #callback.function, Callbacks) of - false -> - Callback = #callback{function = Function, - init_args = InitArgs, - priority = Priority}, - insert_hook_(HookPoint, add_callback_(Callback, Callbacks)); - _Callback -> - {error, already_hooked} - end; - [] -> - Callback = #callback{function = Function, - init_args = InitArgs, - priority = Priority}, - insert_hook_(HookPoint, [Callback]) - end, - {reply, Reply, State}; +handle_call({add, HookPoint, {Tag, Function}, InitArgs, Priority}, _From, State) -> + Callback = #callback{tag = Tag, function = Function, + init_args = InitArgs, priority = Priority}, + {reply, + case ets:lookup(?HOOK_TAB, HookPoint) of + [#hook{callbacks = Callbacks}] -> + case contain_(Tag, Function, Callbacks) of + false -> + insert_hook_(HookPoint, add_callback_(Callback, Callbacks)); + true -> + {error, already_hooked} + end; + [] -> + insert_hook_(HookPoint, [Callback]) + end, State}; -handle_call({delete, HookPoint, Function}, _From, State) -> - Reply = - case ets:lookup(?HOOK_TAB, HookPoint) of - [#hook{callbacks = Callbacks}] -> - insert_hook_(HookPoint, del_callback_(Function, Callbacks)); - [] -> - {error, not_found} - end, - {reply, Reply, State}; +handle_call({delete, HookPoint, {Tag, Function}}, _From, State) -> + {reply, + case ets:lookup(?HOOK_TAB, HookPoint) of + [#hook{callbacks = Callbacks}] -> + case contain_(Tag, Function, Callbacks) of + true -> + insert_hook_(HookPoint, del_callback_(Tag, Function, Callbacks)); + false -> + {error, not_found} + end; + [] -> + {error, not_found} + end, State}; -handle_call(_Req, _From, State) -> - {reply, ignore, State}. +handle_call(Req, _From, State) -> + {reply, {error, {unexpected_request, Req}}, State}. handle_cast(_Msg, State) -> {noreply, State}. @@ -162,6 +174,16 @@ insert_hook_(HookPoint, Callbacks) -> add_callback_(Callback, Callbacks) -> lists:keymerge(#callback.priority, Callbacks, [Callback]). -del_callback_(Function, Callbacks) -> - lists:keydelete(Function, #callback.function, Callbacks). +del_callback_(Tag, Function, Callbacks) -> + lists:filter( + fun(#callback{tag = Tag1, function = Func1}) -> + not ((Tag =:= Tag1) andalso (Function =:= Func1)) + end, Callbacks). + +contain_(_Tag, _Function, []) -> + false; +contain_(Tag, Function, [#callback{tag = Tag, function = Function}|_Callbacks]) -> + true; +contain_(Tag, Function, [_Callback | Callbacks]) -> + contain_(Tag, Function, Callbacks). diff --git a/test/emqttd_SUITE.erl b/test/emqttd_SUITE.erl index d7af619df..afa1a1f06 100644 --- a/test/emqttd_SUITE.erl +++ b/test/emqttd_SUITE.erl @@ -366,37 +366,39 @@ set_get_stat(_) -> %%-------------------------------------------------------------------- add_delete_hook(_) -> - emqttd:hook(test_hook, fun ?MODULE:hook_fun1/1, []), - emqttd:hook(test_hook, fun ?MODULE:hook_fun2/1, []), - {error, already_hooked} = emqttd:hook(test_hook, fun ?MODULE:hook_fun2/1, []), - Callbacks = [{callback, fun ?MODULE:hook_fun1/1, [], 0}, - {callback, fun ?MODULE:hook_fun2/1, [], 0}], + ok = emqttd:hook(test_hook, fun ?MODULE:hook_fun1/1, []), + ok = emqttd:hook(test_hook, {tag, fun ?MODULE:hook_fun2/1}, []), + {error, already_hooked} = emqttd:hook(test_hook, {tag, fun ?MODULE:hook_fun2/1}, []), + Callbacks = [{callback, undefined, fun ?MODULE:hook_fun1/1, [], 0}, + {callback, tag, fun ?MODULE:hook_fun2/1, [], 0}], Callbacks = emqttd_hooks:lookup(test_hook), - emqttd:unhook(test_hook, fun ?MODULE:hook_fun1/1), - emqttd:unhook(test_hook, fun ?MODULE:hook_fun2/1), - ok = emqttd:unhook(test_hook, fun ?MODULE:hook_fun2/1), - {error, not_found} = emqttd:unhook(test_hook1, fun ?MODULE:hook_fun2/1), + ok = emqttd:unhook(test_hook, fun ?MODULE:hook_fun1/1), + ct:print("Callbacks: ~p~n", [emqttd_hooks:lookup(test_hook)]), + ok = emqttd:unhook(test_hook, {tag, fun ?MODULE:hook_fun2/1}), + {error, not_found} = emqttd:unhook(test_hook1, {tag, fun ?MODULE:hook_fun2/1}), [] = emqttd_hooks:lookup(test_hook), - emqttd:hook(emqttd_hook, fun ?MODULE:hook_fun1/1, [], 9), - emqttd:hook(emqttd_hook, fun ?MODULE:hook_fun2/1, [], 8), - Callbacks2 = [{callback, fun ?MODULE:hook_fun2/1, [], 8}, - {callback, fun ?MODULE:hook_fun1/1, [], 9}], + ok = emqttd:hook(emqttd_hook, fun ?MODULE:hook_fun1/1, [], 9), + ok = emqttd:hook(emqttd_hook, {"tag", fun ?MODULE:hook_fun2/1}, [], 8), + Callbacks2 = [{callback, "tag", fun ?MODULE:hook_fun2/1, [], 8}, + {callback, undefined, fun ?MODULE:hook_fun1/1, [], 9}], Callbacks2 = emqttd_hooks:lookup(emqttd_hook), - emqttd:unhook(emqttd_hook, fun ?MODULE:hook_fun1/1), - emqttd:unhook(emqttd_hook, fun ?MODULE:hook_fun2/1), + ok = emqttd:unhook(emqttd_hook, fun ?MODULE:hook_fun1/1), + ok = emqttd:unhook(emqttd_hook, {"tag", fun ?MODULE:hook_fun2/1}), [] = emqttd_hooks:lookup(emqttd_hook). run_hooks(_) -> - emqttd:hook(foldl_hook, fun ?MODULE:hook_fun3/4, [init]), - emqttd:hook(foldl_hook, fun ?MODULE:hook_fun4/4, [init]), - emqttd:hook(foldl_hook, fun ?MODULE:hook_fun5/4, [init]), + ok = emqttd:hook(foldl_hook, fun ?MODULE:hook_fun3/4, [init]), + ok = emqttd:hook(foldl_hook, {tag, fun ?MODULE:hook_fun3/4}, [init]), + ok = emqttd:hook(foldl_hook, fun ?MODULE:hook_fun4/4, [init]), + ok = emqttd:hook(foldl_hook, fun ?MODULE:hook_fun5/4, [init]), {stop, [r3, r2]} = emqttd:run_hooks(foldl_hook, [arg1, arg2], []), {ok, []} = emqttd:run_hooks(unknown_hook, [], []), - emqttd:hook(foreach_hook, fun ?MODULE:hook_fun6/2, [initArg]), - emqttd:hook(foreach_hook, fun ?MODULE:hook_fun7/2, [initArg]), - emqttd:hook(foreach_hook, fun ?MODULE:hook_fun8/2, [initArg]), + ok = emqttd:hook(foreach_hook, fun ?MODULE:hook_fun6/2, [initArg]), + ok = emqttd:hook(foreach_hook, {tag, fun ?MODULE:hook_fun6/2}, [initArg]), + ok = emqttd:hook(foreach_hook, fun ?MODULE:hook_fun7/2, [initArg]), + ok = emqttd:hook(foreach_hook, fun ?MODULE:hook_fun8/2, [initArg]), stop = emqttd:run_hooks(foreach_hook, [arg]). hook_fun1([]) -> ok.