%%--------------------------------------------------------------------
%% Copyright (c) 2021-2023 EMQ Technologies Co., Ltd. All Rights Reserved.
%%
%% Licensed under the Apache License, Version 2.0 (the "License");
%% you may not use this file except in compliance with the License.
%% You may obtain a copy of the License at
%%
%%     http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing, software
%% distributed under the License is distributed on an "AS IS" BASIS,
%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
%% See the License for the specific language governing permissions and
%% limitations under the License.
%%--------------------------------------------------------------------

%% @doc The gateway runtime
-module(emqx_gateway_insta_sup).

-behaviour(gen_server).

-include("include/emqx_gateway.hrl").
-include_lib("emqx/include/logger.hrl").

%% APIs
-export([
    start_link/3,
    info/1,
    disable/1,
    enable/1,
    update/2
]).

%% gen_server callbacks
-export([
    init/1,
    handle_call/3,
    handle_cast/2,
    handle_info/2,
    terminate/2,
    code_change/3
]).

%% Server state of one gateway instance.
%%  - name:       gateway name taken from the `Gateway' map given to init/1
%%  - config:     gateway configuration currently in effect
%%  - ctx:        context passed on to the callback module when loading
%%  - status:     `running' after a successful load, `stopped' otherwise
%%  - child_pids: processes belonging to the loaded gateway; their 'EXIT'
%%                messages are tracked in handle_info/2
%%  - gw_state:   state term returned by the callback module, `undefined'
%%                while the gateway is not loaded
%%  - created_at / started_at / stopped_at:
%%                erlang:system_time(millisecond) timestamps
-record(state, {
    name :: gateway_name(),
    config :: emqx_config:config(),
    ctx :: emqx_gateway_ctx:context(),
    status :: stopped | running,
    child_pids :: [pid()],
    gw_state :: emqx_gateway_impl:state() | undefined,
    created_at :: integer(),
    started_at :: integer() | undefined,
    stopped_at :: integer() | undefined
}).

-elvis([{elvis_style, invalid_dynamic_call, disable}]).

%%--------------------------------------------------------------------
%% APIs
%%--------------------------------------------------------------------

%% @doc Start an unregistered gateway instance process linked to the caller.
%% The three arguments are handed to init/1 as a list.
start_link(Gateway, Ctx, GwDscrptr) ->
    gen_server:start_link(
        ?MODULE,
        [Gateway, Ctx, GwDscrptr],
        []
    ).

%% @doc Fetch the gateway info from the instance process. Unlike the other
%% API calls this one goes through gen_server:call/2 directly and therefore
%% uses its default timeout.
-spec info(pid()) -> gateway().
info(Pid) ->
    gen_server:call(Pid, info).

%% @doc Stop gateway
-spec disable(pid()) -> ok | {error, any()}.
disable(Pid) ->
    call(Pid, disable).

%% @doc Start gateway
-spec enable(pid()) -> ok | {error, any()}.
enable(Pid) ->
    call(Pid, enable).

%% @doc Update the gateway configurations
-spec update(pid(), emqx_config:config()) -> ok | {error, any()}.
update(Pid, Config) ->
    Request = {update, Config},
    call(Pid, Request).

%% Synchronous request to the instance process.
%% The timeout is deliberately large: the request may have to wait for
%% changes to dependent resources to complete before it can be answered.
call(Pid, Req) ->
    TimeoutMs = 15000,
    gen_server:call(Pid, Req, TimeoutMs).

%%--------------------------------------------------------------------
%% gen_server callbacks
%%--------------------------------------------------------------------

%% @doc Initialise the instance. Exits are trapped so that child process
%% terminations arrive as messages. The authenticators are created with the
%% configured `enable' flag (default `true'); the gateway itself is only
%% loaded when it is enabled, and a load failure stops the server.
init([Gateway, Ctx, _GwDscrptr]) ->
    process_flag(trap_exit, true),
    #{name := Name, config := Conf} = Gateway,
    InitState = #state{
        name = Name,
        config = Conf,
        ctx = Ctx,
        status = stopped,
        child_pids = [],
        created_at = erlang:system_time(millisecond)
    },
    Enabled = maps:get(enable, Conf, true),
    ok = ensure_authn_running(InitState, Enabled),
    case Enabled of
        true ->
            load_gateway_on_init(InitState);
        false ->
            ?SLOG(info, #{
                msg => "skip_to_start_gateway_due_to_disabled",
                gateway_name => Name
            }),
            {ok, InitState}
    end.

%% Load the gateway during init/1 and translate the outcome into an
%% init/1 return value.
load_gateway_on_init(State) ->
    case cb_gateway_load(State) of
        {ok, LoadedState} -> {ok, LoadedState};
        {error, Reason} -> {stop, Reason}
    end.
%% @doc Serve the synchronous requests issued by the API functions.
%%  - `info':             reply with the map built by detailed_gateway_info/1
%%  - `disable':          unload a running gateway, then disable its
%%                        authenticators; `{error, already_stopped}' if the
%%                        gateway is not running
%%  - `enable':           make sure the authenticators exist, then load a
%%                        stopped gateway; `{error, already_started}' if the
%%                        gateway is not stopped
%%  - `{update, Config}': apply a new configuration; on error the previous
%%                        state is kept
%%  - any other request is answered with `ok'
handle_call(info, _From, State) ->
    {reply, detailed_gateway_info(State), State};
handle_call(disable, _From, State = #state{status = running}) ->
    case cb_gateway_unload(State) of
        {ok, NState} ->
            ok = disable_authns(State),
            {reply, ok, NState};
        {error, Reason} ->
            {reply, {error, Reason}, State}
    end;
handle_call(disable, _From, State = #state{}) ->
    {reply, {error, already_stopped}, State};
handle_call(enable, _From, State = #state{status = stopped}) ->
    case enable_gateway(State) of
        {ok, NState} ->
            {reply, ok, NState};
        {error, Reason} ->
            {reply, {error, Reason}, State}
    end;
handle_call(enable, _From, State = #state{}) ->
    {reply, {error, already_started}, State};
handle_call({update, Config}, _From, State) ->
    case do_update_one_by_one(Config, State) of
        {ok, NState} ->
            {reply, ok, NState};
        {error, Reason} ->
            %% If something wrong, nothing to update
            {reply, {error, Reason}, State}
    end;
handle_call(_Request, _From, State) ->
    Reply = ok,
    {reply, Reply, State}.

%% Ensure the authenticators are in place and then load the gateway.
enable_gateway(State) ->
    case ensure_authn_running(State) of
        ok -> cb_gateway_load(State);
        {error, Reason} -> {error, Reason}
    end.

%% @doc No casts are part of the protocol; they are ignored.
handle_cast(_Msg, State) ->
    {noreply, State}.

%% @doc Track the exits of the gateway's child processes (exits are trapped
%% in init/1). Once the last tracked child has exited the gateway is marked
%% as stopped. Any other message is logged and dropped.
handle_info(
    {'EXIT', Pid, Reason},
    State = #state{
        name = Name,
        child_pids = Pids
    }
) ->
    case lists:member(Pid, Pids) of
        true ->
            ?SLOG(info, #{
                msg => "child_process_exited",
                child => Pid,
                reason => Reason
            }),
            case lists:delete(Pid, Pids) of
                [] ->
                    ?SLOG(info, #{
                        msg => "gateway_all_children_process_existed",
                        gateway_name => Name
                    }),
                    %% Keep the timestamps in line with what
                    %% cb_gateway_unload/1 records for a stopped gateway, so
                    %% that the reported info does not show a stopped
                    %% gateway with a start time and no stop time.
                    {noreply, State#state{
                        status = stopped,
                        child_pids = [],
                        gw_state = undefined,
                        started_at = undefined,
                        stopped_at = erlang:system_time(millisecond)
                    }};
                RemainPids ->
                    {noreply, State#state{child_pids = RemainPids}}
            end;
        _ ->
            ?SLOG(info, #{
                msg => "gateway_catch_a_unknown_process_exited",
                child => Pid,
                reason => Reason,
                gateway_name => Name
            }),
            {noreply, State}
    end;
handle_info(Info, State) ->
    ?SLOG(warning, #{
        msg => "unexcepted_info",
        info => Info
    }),
    {noreply, State}.

%% @doc Unload the gateway if it still has child processes and remove all
%% of its authentication chains. Failures are ignored on purpose: the
%% process is going down anyway.
terminate(_Reason, State = #state{child_pids = Pids}) ->
    Pids =/= [] andalso (_ = cb_gateway_unload(State)),
    _ = remove_all_authns(State),
    ok.
%% @doc Code upgrade callback; the state is carried over unchanged.
code_change(_OldVsn, State, _Extra) ->
    {ok, State}.

%% Build the gateway info map from the server state. Entries whose value
%% is `undefined' are left out of the result.
detailed_gateway_info(#state{
    name = Name,
    config = Config,
    status = Status,
    created_at = CreatedAt,
    started_at = StartedAt,
    stopped_at = StoppedAt
}) ->
    Fields = [
        {name, Name},
        {config, Config},
        {status, Status},
        {created_at, CreatedAt},
        {started_at, StartedAt},
        {stopped_at, StoppedAt}
    ],
    maps:from_list([Field || {_Key, Value} = Field <- Fields, Value =/= undefined]).

%%--------------------------------------------------------------------
%% Internal funcs
%%--------------------------------------------------------------------

%%--------------------------------------------------------------------
%% Authn resources managing funcs

%% Apply `Fun' to every element in order and stop at the first error.
pipeline(_Fun, []) ->
    ok;
pipeline(Fun, [Item | Rest]) ->
    case Fun(Item) of
        ok -> pipeline(Fun, Rest);
        {error, _} = Error -> Error
    end.

%% ensure authentication chain, authenticator created and keep its configured
%% status
ensure_authn_running(#state{name = GwName, config = Config}) ->
    EnsureOne = fun({ChainName, AuthConf}) ->
        ensure_authenticator_created(ChainName, AuthConf)
    end,
    pipeline(EnsureOne, authns(GwName, Config)).

%% ensure authentication chain, authenticator created and keep its status
%% as given
ensure_authn_running(#state{name = GwName, config = Config}, Enable) ->
    EnsureOne = fun({ChainName, AuthConf}) ->
        ensure_authenticator_created(ChainName, maps:put(enable, Enable, AuthConf))
    end,
    pipeline(EnsureOne, authns(GwName, Config)).

%% temporarily disable authenticators after gateway disabled
disable_authns(State) ->
    ensure_authn_running(State, false).

%% remove all authns if gateway unloaded
remove_all_authns(#state{name = GwName, config = Config}) ->
    ChainNames = [ChainName || {ChainName, _} <- authns(GwName, Config)],
    lists:foreach(fun delete_authn_chain/1, ChainNames).

%% Delete one authentication chain. A chain that does not exist counts as
%% already removed; any other failure is only logged.
delete_authn_chain(ChainName) ->
    case emqx_authn_chains:delete_chain(ChainName) of
        ok ->
            ok;
        {error, {not_found, _}} ->
            ok;
        {error, Reason} ->
            ?SLOG(error, #{
                msg => "failed_to_clean_authn_chain",
                chain_name => ChainName,
                reason => Reason
            })
    end.
%% Make sure the chain holds an authenticator with the given config:
%% update the existing one, or create it when the chain is empty or does
%% not exist yet.
ensure_authenticator_created(ChainName, Confs) ->
    case emqx_authn_chains:list_authenticators(ChainName) of
        {error, {not_found, {chain, _}}} ->
            do_create_authenticator(ChainName, Confs);
        {ok, []} ->
            do_create_authenticator(ChainName, Confs);
        {ok, [#{id := AuthenticatorId}]} ->
            update_existing_authenticator(ChainName, AuthenticatorId, Confs)
    end.

%% Push the config to an authenticator that already exists in the chain.
update_existing_authenticator(ChainName, AuthenticatorId, Confs) ->
    case emqx_authn_chains:update_authenticator(ChainName, AuthenticatorId, Confs) of
        {ok, _} -> ok;
        {error, Reason} -> {error, {badauth, Reason}}
    end.

%% Collect `{ChainName, AuthnConf}' for every listener of the gateway
%% followed by the gateway-wide chain, dropping the entries that have no
%% `authentication' config.
authns(GwName, Config) ->
    Listeners = maps:to_list(maps:get(listeners, Config, #{})),
    ListenerAuthns = [
        {emqx_gateway_utils:listener_chain(GwName, LisType, LisName), authn_conf(Opts)}
     || {LisType, LisNames} <- Listeners,
        {LisName, Opts} <- maps:to_list(LisNames)
    ],
    GlobalAuthn = {emqx_gateway_utils:global_chain(GwName), authn_conf(Config)},
    [
        Authn
     || {_ChainName, AuthConf} = Authn <- ListenerAuthns ++ [GlobalAuthn],
        AuthConf =/= undefined
    ].

%% The `authentication' entry of a config map, or `undefined' if absent.
authn_conf(Conf) ->
    case maps:find(authentication, Conf) of
        {ok, AuthConf} -> AuthConf;
        error -> undefined
    end.

%% Create the authenticator in the chain; a failure is logged and reported
%% as `{error, {badauth, Reason}}'.
do_create_authenticator(ChainName, AuthConf) ->
    Result = emqx_authn_chains:create_authenticator(ChainName, AuthConf),
    case Result of
        {error, Reason} ->
            ?SLOG(error, #{
                msg => "failed_to_create_authenticator",
                chain_name => ChainName,
                reason => Reason,
                config => AuthConf
            }),
            {error, {badauth, Reason}};
        {ok, _} ->
            ok
    end.
%% Apply a new configuration. Chains that are no longer configured are
%% deleted first; what happens next depends on the current status and on
%% the `enable' flag of the new config (default `true').
do_update_one_by_one(
    NCfg,
    State = #state{
        name = GwName,
        config = OCfg,
        status = Status
    }
) ->
    NEnable = maps:get(enable, NCfg, true),
    OldAuthns = authns(GwName, OCfg),
    NewAuthns = authns(GwName, NCfg),
    ok = remove_deleted_authns(NewAuthns, OldAuthns),
    apply_update(Status, NEnable, NCfg, State).

%% One clause per {current status, new enable flag} combination.
apply_update(stopped, true, NCfg, State) ->
    Pending = State#state{config = NCfg},
    case ensure_authn_running(Pending) of
        ok -> cb_gateway_load(Pending);
        {error, _} = Error -> Error
    end;
apply_update(stopped, false, NCfg, State) ->
    Pending = State#state{config = NCfg},
    case disable_authns(Pending) of
        ok -> {ok, Pending};
        {error, _} = Error -> Error
    end;
apply_update(running, true, NCfg, State) ->
    %% FIXME: minimum impact update
    case ensure_authn_running(State#state{config = NCfg}) of
        ok -> cb_gateway_update(NCfg, State);
        {error, _} = Error -> Error
    end;
apply_update(running, false, NCfg, State) ->
    case cb_gateway_unload(State) of
        {ok, Unloaded} ->
            ok = disable_authns(State#state{config = NCfg}),
            {ok, Unloaded#state{config = NCfg}};
        {error, _} = Error ->
            Error
    end;
apply_update(_Status, _Enable, _NCfg, _State) ->
    throw(nomatch).

%% Delete the chains that exist in the old authn list but not in the new
%% one. The result of each deletion is ignored.
remove_deleted_authns(NAuthns, OAuthns) ->
    KeptNames = proplists:get_keys(NAuthns),
    PrevNames = proplists:get_keys(OAuthns),
    Obsolete = [Name || Name <- PrevNames, not lists:member(Name, KeptNames)],
    lists:foreach(
        fun(ChainName) ->
            _ = emqx_authn_chains:delete_chain(ChainName)
        end,
        Obsolete
    ).

%% Call `on_gateway_unload/2' of the gateway's callback module and return
%% the state of a stopped gateway. Any exception raised on the way is
%% logged and turned into `{error, Reason}'.
cb_gateway_unload(
    State = #state{
        name = GwName,
        gw_state = GwState
    }
) ->
    Gateway = detailed_gateway_info(State),
    try
        #{cbkmod := CbMod} = emqx_gateway_registry:lookup(GwName),
        CbMod:on_gateway_unload(Gateway, GwState),
        StoppedAt = erlang:system_time(millisecond),
        StoppedState = State#state{
            status = stopped,
            child_pids = [],
            gw_state = undefined,
            started_at = undefined,
            stopped_at = StoppedAt
        },
        {ok, StoppedState}
    catch
        Class:Reason:Stacktrace ->
            ?SLOG(error, #{
                msg => "unload_gateway_crashed",
                gateway_name => GwName,
                inner_state => GwState,
                reason => {Class, Reason},
                stacktrace => Stacktrace
            }),
            {error, Reason}
    end.

%% @doc 1. Create Authentication Context
%%      2.
%%         Callback to Mod:on_gateway_load/2
%%
%% Notes: If failed, rollback
%%
%% Returns `{ok, State}' with the gateway marked as running, or
%% `{error, Reason}' when the callback reports an error or anything
%% inside the protected section raises (the exception is logged).
cb_gateway_load(
    State = #state{
        name = GwName,
        ctx = Ctx
    }
) ->
    Gateway = detailed_gateway_info(State),
    try
        #{cbkmod := CbMod} = emqx_gateway_registry:lookup(GwName),
        case CbMod:on_gateway_load(Gateway, Ctx) of
            {ok, PidsOrSpecs, GwState} ->
                Children = start_child_process(PidsOrSpecs),
                StartedAt = erlang:system_time(millisecond),
                {ok, State#state{
                    status = running,
                    child_pids = Children,
                    gw_state = GwState,
                    started_at = StartedAt,
                    stopped_at = undefined
                }};
            {error, Reason} ->
                {error, Reason}
        end
    catch
        Class:Error:Stacktrace ->
            ?SLOG(error, #{
                msg => "load_gateway_crashed",
                gateway_name => GwName,
                gateway => Gateway,
                reason => {Class, Error},
                stacktrace => Stacktrace
            }),
            {error, Error}
    end.

%% Call `on_gateway_update/3' of the gateway's callback module with the new
%% config. On success the new config, child processes and gateway state
%% are stored; exceptions are logged and returned as `{error, Reason}'.
cb_gateway_update(
    Config,
    State = #state{
        name = GwName,
        gw_state = GwState
    }
) ->
    try
        #{cbkmod := CbMod} = emqx_gateway_registry:lookup(GwName),
        case CbMod:on_gateway_update(Config, detailed_gateway_info(State), GwState) of
            {ok, PidsOrSpecs, NewGwState} ->
                Children = start_child_process(PidsOrSpecs),
                {ok, State#state{
                    config = Config,
                    child_pids = Children,
                    gw_state = NewGwState
                }};
            {error, Reason} ->
                {error, Reason}
        end
    catch
        Class:Error:Stacktrace ->
            ?SLOG(error, #{
                msg => "update_gateway_crashed",
                gateway_name => GwName,
                new_config => Config,
                reason => {Class, Error},
                stacktrace => Stacktrace
            }),
            {error, Error}
    end.

%% Turn the callback's child list into a list of pids. The first element
%% decides how the whole list is treated: if it is a pid the list is
%% returned as is, otherwise every element is started as a child spec.
start_child_process([]) ->
    [];
start_child_process([First | _] = ChildPids) when is_pid(First) ->
    ChildPids;
start_child_process([_ | _] = ChildSpecs) ->
    do_start_child_process(ChildSpecs).

%% Start one child spec (or each spec of a list) through its `start' MFA
%% and return the pid(s). A start error is thrown as
%% `{start_child_process, Reason}'.
do_start_child_process(ChildSpecs) when is_list(ChildSpecs) ->
    lists:map(fun(Spec) -> do_start_child_process(Spec) end, ChildSpecs);
do_start_child_process(#{start := {Mod, Fun, Args}}) ->
    case apply(Mod, Fun, Args) of
        {ok, Pid} ->
            Pid;
        {error, Reason} ->
            throw({start_child_process, Reason})
    end.