emqx/src/emqx_hooks.erl

292 lines
9.5 KiB
Erlang

%%--------------------------------------------------------------------
%% Copyright (c) 2017-2021 EMQ Technologies Co., Ltd. All Rights Reserved.
%%
%% Licensed under the Apache License, Version 2.0 (the "License");
%% you may not use this file except in compliance with the License.
%% You may obtain a copy of the License at
%%
%% http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing, software
%% distributed under the License is distributed on an "AS IS" BASIS,
%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
%% See the License for the specific language governing permissions and
%% limitations under the License.
%%--------------------------------------------------------------------
-module(emqx_hooks).
-behaviour(gen_server).
-include("logger.hrl").
-include("types.hrl").
-logger_header("[Hooks]").
-export([ start_link/0
, stop/0
]).
%% Hooks API
-export([ add/2
, add/3
, add/4
, put/2
, del/2
, run/2
, run_fold/3
, lookup/1
]).
-export([ callback_action/1
, callback_filter/1
, callback_priority/1
]).
%% gen_server Function Exports
-export([ init/1
, handle_call/3
, handle_cast/2
, handle_info/2
, terminate/2
, code_change/3
]).
-export_type([ hookpoint/0
, action/0
, filter/0
]).
%% Multiple callbacks can be registered on a hookpoint.
%% The execution order depends on the priority value:
%% - Callbacks with greater priority values will be run before
%% the ones with lower priority values. e.g. A Callback with
%% priority = 2 precedes the callback with priority = 1.
%% - The execution order is the adding order of callbacks if they have
%% equal priority values.
-type(hookpoint() :: atom()).
-type(action() :: function() | {function(), [term()]} | mfargs()).
-type(filter() :: function() | mfargs()).
-record(callback, {
action :: action(),
filter :: maybe(filter()),
priority :: integer()
}).
-record(hook, {
name :: hookpoint(),
callbacks :: list(#callback{})
}).
-define(TAB, ?MODULE).
-define(SERVER, ?MODULE).
-spec(start_link() -> startlink_ret()).
start_link() ->
gen_server:start_link({local, ?SERVER},
?MODULE, [], [{hibernate_after, 1000}]).
-spec(stop() -> ok).
stop() ->
gen_server:stop(?SERVER, normal, infinity).
%%--------------------------------------------------------------------
%% Test APIs
%%--------------------------------------------------------------------
%% @doc Get callback action.
callback_action(#callback{action = A}) -> A.
%% @doc Get callback filter.
callback_filter(#callback{filter = F}) -> F.
%% @doc Get callback priority.
callback_priority(#callback{priority= P}) -> P.
%%--------------------------------------------------------------------
%% Hooks API
%%--------------------------------------------------------------------
%% @doc Register a callback
-spec(add(hookpoint(), action() | #callback{}) -> ok_or_error(already_exists)).
add(HookPoint, Callback) when is_record(Callback, callback) ->
gen_server:call(?SERVER, {add, HookPoint, Callback}, infinity);
add(HookPoint, Action) when is_function(Action); is_tuple(Action) ->
add(HookPoint, #callback{action = Action, priority = 0}).
-spec(add(hookpoint(), action(), filter() | integer() | list())
-> ok_or_error(already_exists)).
add(HookPoint, Action, InitArgs) when is_function(Action), is_list(InitArgs) ->
add(HookPoint, #callback{action = {Action, InitArgs}, priority = 0});
add(HookPoint, Action, Filter) when is_function(Filter); is_tuple(Filter) ->
add(HookPoint, #callback{action = Action, filter = Filter, priority = 0});
add(HookPoint, Action, Priority) when is_integer(Priority) ->
add(HookPoint, #callback{action = Action, priority = Priority}).
-spec(add(hookpoint(), action(), filter(), integer())
-> ok_or_error(already_exists)).
add(HookPoint, Action, Filter, Priority) when is_integer(Priority) ->
add(HookPoint, #callback{action = Action, filter = Filter, priority = Priority}).
%% @doc Like add/2, it register a callback, discard 'already_exists' error.
-spec(put(hookpoint(), action() | #callback{}) -> ok).
put(HookPoint, Callback) ->
case add(HookPoint, Callback) of
ok -> ok;
{error, already_exists} -> ok
end.
%% @doc Unregister a callback.
-spec(del(hookpoint(), action() | {module(), atom()}) -> ok).
del(HookPoint, Action) ->
gen_server:cast(?SERVER, {del, HookPoint, Action}).
%% @doc Run hooks.
-spec(run(atom(), list(Arg::term())) -> ok).
run(HookPoint, Args) ->
do_run(lookup(HookPoint), Args).
%% @doc Run hooks with Accumulator.
-spec(run_fold(atom(), list(Arg::term()), Acc::term()) -> Acc::term()).
run_fold(HookPoint, Args, Acc) ->
do_run_fold(lookup(HookPoint), Args, Acc).
do_run([#callback{action = Action, filter = Filter} | Callbacks], Args) ->
case filter_passed(Filter, Args) andalso safe_execute(Action, Args) of
%% stop the hook chain and return
stop -> ok;
%% continue the hook chain, in following cases:
%% - the filter validation failed with 'false'
%% - the callback returns any term other than 'stop'
_ -> do_run(Callbacks, Args)
end;
do_run([], _Args) ->
ok.
do_run_fold([#callback{action = Action, filter = Filter} | Callbacks], Args, Acc) ->
Args1 = Args ++ [Acc],
case filter_passed(Filter, Args1) andalso safe_execute(Action, Args1) of
%% stop the hook chain
stop -> Acc;
%% stop the hook chain with NewAcc
{stop, NewAcc} -> NewAcc;
%% continue the hook chain with NewAcc
{ok, NewAcc} -> do_run_fold(Callbacks, Args, NewAcc);
%% continue the hook chain, in following cases:
%% - the filter validation failed with 'false'
%% - the callback returns any term other than 'stop' or {'stop', NewAcc}
_ -> do_run_fold(Callbacks, Args, Acc)
end;
do_run_fold([], _Args, Acc) ->
Acc.
-spec(filter_passed(filter(), Args::term()) -> true | false).
filter_passed(undefined, _Args) -> true;
filter_passed(Filter, Args) ->
execute(Filter, Args).
safe_execute(Fun, Args) ->
try execute(Fun, Args) of
Result -> Result
catch
Error:Reason:Stacktrace ->
?LOG(error, "Failed to execute ~0p: ~0p", [Fun, {Error, Reason, Stacktrace}]),
ok
end.
%% @doc execute a function.
execute(Fun, Args) when is_function(Fun) ->
erlang:apply(Fun, Args);
execute({Fun, InitArgs}, Args) when is_function(Fun) ->
erlang:apply(Fun, Args ++ InitArgs);
execute({M, F, A}, Args) ->
erlang:apply(M, F, Args ++ A).
%% @doc Lookup callbacks.
-spec(lookup(hookpoint()) -> [#callback{}]).
lookup(HookPoint) ->
case ets:lookup(?TAB, HookPoint) of
[#hook{callbacks = Callbacks}] ->
Callbacks;
[] -> []
end.
%%--------------------------------------------------------------------
%% gen_server callbacks
%%--------------------------------------------------------------------
init([]) ->
ok = emqx_tables:new(?TAB, [{keypos, #hook.name}, {read_concurrency, true}]),
{ok, #{}}.
handle_call({add, HookPoint, Callback = #callback{action = Action}}, _From, State) ->
Reply = case lists:keymember(Action, #callback.action, Callbacks = lookup(HookPoint)) of
true ->
{error, already_exists};
false ->
insert_hook(HookPoint, add_callback(Callback, Callbacks))
end,
{reply, Reply, State};
handle_call(Req, _From, State) ->
?LOG(error, "Unexpected call: ~p", [Req]),
{reply, ignored, State}.
handle_cast({del, HookPoint, Action}, State) ->
case del_callback(Action, lookup(HookPoint)) of
[] ->
ets:delete(?TAB, HookPoint);
Callbacks ->
insert_hook(HookPoint, Callbacks)
end,
{noreply, State};
handle_cast(Msg, State) ->
?LOG(error, "Unexpected msg: ~p", [Msg]),
{noreply, State}.
handle_info(Info, State) ->
?LOG(error, "Unexpected info: ~p", [Info]),
{noreply, State}.
terminate(_Reason, _State) ->
ok.
code_change(_OldVsn, State, _Extra) ->
{ok, State}.
%%------------------------------------------------------------------------------
%% Internal functions
%%------------------------------------------------------------------------------
insert_hook(HookPoint, Callbacks) ->
ets:insert(?TAB, #hook{name = HookPoint, callbacks = Callbacks}), ok.
add_callback(C, Callbacks) ->
add_callback(C, Callbacks, []).
add_callback(C, [], Acc) ->
lists:reverse([C|Acc]);
add_callback(C1 = #callback{priority = P1}, [C2 = #callback{priority = P2}|More], Acc)
when P1 =< P2 ->
add_callback(C1, More, [C2|Acc]);
add_callback(C1, More, Acc) ->
lists:append(lists:reverse(Acc), [C1 | More]).
del_callback(Action, Callbacks) ->
del_callback(Action, Callbacks, []).
del_callback(_Action, [], Acc) ->
lists:reverse(Acc);
del_callback(Action, [#callback{action = Action} | Callbacks], Acc) ->
del_callback(Action, Callbacks, Acc);
del_callback(Action = {M, F}, [#callback{action = {M, F, _A}} | Callbacks], Acc) ->
del_callback(Action, Callbacks, Acc);
del_callback(Func, [#callback{action = {Func, _A}} | Callbacks], Acc) ->
del_callback(Func, Callbacks, Acc);
del_callback(Action, [Callback | Callbacks], Acc) ->
del_callback(Action, Callbacks, [Callback | Acc]).