diff --git a/apps/emqx_rule_engine/src/emqx_rule_engine_jwt_sup.erl b/apps/emqx_rule_engine/src/emqx_rule_engine_jwt_sup.erl index 18cc48c9e..4f3e24cce 100644 --- a/apps/emqx_rule_engine/src/emqx_rule_engine_jwt_sup.erl +++ b/apps/emqx_rule_engine/src/emqx_rule_engine_jwt_sup.erl @@ -19,8 +19,8 @@ -behaviour(supervisor). -export([ start_link/0 - , start_worker/2 - , stop_worker/1 + , ensure_worker_present/2 + , ensure_worker_deleted/1 ]). -export([init/1]). @@ -45,32 +45,31 @@ init([]) -> %% @doc Starts a new JWT worker. The worker will send the caller a %% message when it creates and stores its first JWT, or if it fails to %% do so, using a generated reference. --spec start_worker(worker_id(), map()) -> - {ok, {reference(), supervisor:child()}} - | {error, already_present} - | {error, {already_started, supervisor:child()}}. -start_worker(Id, Config) -> - Ref = erlang:alias([reply]), - ChildSpec = jwt_worker_child_spec(Id, Config, Ref), +-spec ensure_worker_present(worker_id(), map()) -> + {ok, supervisor:child()}. +ensure_worker_present(Id, Config) -> + ChildSpec = jwt_worker_child_spec(Id, Config), case supervisor:start_child(?MODULE, ChildSpec) of {ok, Pid} -> - {ok, {Ref, Pid}}; - Error -> - Error + {ok, Pid}; + {error, {already_started, Pid}} -> + {ok, Pid}; + {error, already_present} -> + supervisor:restart_child(?MODULE, Id) end. %% @doc Stops a given JWT worker by its id. --spec stop_worker(worker_id()) -> ok. -stop_worker(Id) -> +-spec ensure_worker_deleted(worker_id()) -> ok. +ensure_worker_deleted(Id) -> case supervisor:terminate_child(?MODULE, Id) of ok -> ok; {error, not_found} -> ok end. -jwt_worker_child_spec(Id, Config, Ref) -> +jwt_worker_child_spec(Id, Config) -> #{ id => Id - , start => {emqx_rule_engine_jwt_worker, start_link, [Config, Ref]} - , restart => permanent + , start => {emqx_rule_engine_jwt_worker, start_link, [Config]} + , restart => transient , type => worker , significant => false , shutdown => brutal_kill diff --git a/apps/emqx_rule_engine/src/emqx_rule_engine_jwt_worker.erl b/apps/emqx_rule_engine/src/emqx_rule_engine_jwt_worker.erl index 855c3e076..4190a3536 100644 --- a/apps/emqx_rule_engine/src/emqx_rule_engine_jwt_worker.erl +++ b/apps/emqx_rule_engine/src/emqx_rule_engine_jwt_worker.erl @@ -19,7 +19,8 @@ -behaviour(gen_server). %% API --export([ start_link/2 +-export([ start_link/1 + , ensure_jwt/1 ]). %% gen_server API @@ -68,7 +69,7 @@ %% API %%----------------------------------------------------------------------------------------- --spec start_link(config(), reference()) -> gen_server:start_ret(). +-spec start_link(config()) -> gen_server:start_ret(). start_link(#{ private_key := _ , expiration := _ , resource_id := _ @@ -78,60 +79,68 @@ start_link(#{ private_key := _ , aud := _ , kid := _ , alg := _ - } = Config, - Ref) -> - gen_server:start_link(?MODULE, {Config, Ref}, []). + } = Config) -> + gen_server:start_link(?MODULE, Config, []). + +-spec ensure_jwt(pid()) -> reference(). +ensure_jwt(Worker) -> + Ref = alias([reply]), + gen_server:cast(Worker, {ensure_jwt, Ref}), + Ref. %%----------------------------------------------------------------------------------------- %% gen_server API %%----------------------------------------------------------------------------------------- --spec init({config(), Ref}) -> {ok, state(), {continue, {make_key, binary(), Ref}}} - | {stop, {error, term()}} - when Ref :: reference(). -init({#{private_key := PrivateKeyPEM} = Config, Ref}) -> +-spec init(config()) -> {ok, state(), {continue, {make_key, binary()}}} + | {stop, {error, term()}}. +init(#{private_key := PrivateKeyPEM} = Config) -> State0 = maps:without([private_key], Config), State = State0#{ jwk => undefined , jwt => undefined , refresh_timer => undefined }, - {ok, State, {continue, {make_key, PrivateKeyPEM, Ref}}}. + {ok, State, {continue, {make_key, PrivateKeyPEM}}}. -handle_continue({make_key, PrivateKeyPEM, Ref}, State0) -> +handle_continue({make_key, PrivateKeyPEM}, State0) -> case jose_jwk:from_pem(PrivateKeyPEM) of JWK = #jose_jwk{} -> State = State0#{jwk := JWK}, - {noreply, State, {continue, {create_token, Ref}}}; + {noreply, State, {continue, create_token}}; [] -> - Ref ! {Ref, {error, {invalid_private_key, empty_key}}}, - {stop, {error, empty_key}, State0}; + ?tp(rule_engine_jwt_worker_startup_error, #{error => empty_key}), + {stop, {shutdown, {error, empty_key}}, State0}; {error, Reason} -> - Ref ! {Ref, {error, {invalid_private_key, Reason}}}, - {stop, {error, Reason}, State0}; - Error -> - Ref ! {Ref, {error, {invalid_private_key, Error}}}, - {stop, {error, Error}, State0} + Error = {invalid_private_key, Reason}, + ?tp(rule_engine_jwt_worker_startup_error, #{error => Error}), + {stop, {shutdown, {error, Error}}, State0}; + Error0 -> + Error = {invalid_private_key, Error0}, + ?tp(rule_engine_jwt_worker_startup_error, #{error => Error}), + {stop, {shutdown, {error, Error}}, State0} end; -handle_continue({create_token, Ref}, State0) -> - JWT = do_generate_jwt(State0), - store_jwt(State0, JWT), - State1 = State0#{jwt := JWT}, - State = ensure_timer(State1), - Ref ! {Ref, token_created}, +handle_continue(create_token, State0) -> + State = generate_and_store_jwt(State0), {noreply, State}. handle_call(_Req, _From, State) -> {reply, {error, bad_call}, State}. +handle_cast({ensure_jwt, From}, State0 = #{jwt := JWT}) -> + State = + case JWT of + undefined -> + generate_and_store_jwt(State0); + _ -> + State0 + end, + From ! {From, token_created}, + {noreply, State}; handle_cast(_Req, State) -> {noreply, State}. handle_info({timeout, TRef, ?refresh_jwt}, State0 = #{refresh_timer := TRef}) -> - JWT = do_generate_jwt(State0), - store_jwt(State0, JWT), - ?tp(rule_engine_jwt_worker_refresh, #{}), - State1 = State0#{jwt := JWT}, - State = ensure_timer(State1#{refresh_timer := undefined}), + State = generate_and_store_jwt(State0), {noreply, State}; handle_info(_Msg, State) -> {noreply, State}. @@ -171,10 +180,18 @@ do_generate_jwt(#{ expiration := ExpirationMS {_, JWT} = jose_jws:compact(JWT0), JWT. +-spec generate_and_store_jwt(state()) -> state(). +generate_and_store_jwt(State0) -> + JWT = do_generate_jwt(State0), + store_jwt(State0, JWT), + ?tp(rule_engine_jwt_worker_refresh, #{jwt => JWT}), + State1 = State0#{jwt := JWT}, + ensure_timer(State1). + -spec store_jwt(state(), jwt()) -> ok. store_jwt(#{resource_id := ResourceId, table := TId}, JWT) -> true = ets:insert(TId, {{ResourceId, jwt}, JWT}), - ?tp(jwt_worker_token_stored, #{resource_id => ResourceId}), + ?tp(rule_engine_jwt_worker_token_stored, #{resource_id => ResourceId}), ok. -spec ensure_timer(state()) -> state(). diff --git a/apps/emqx_rule_engine/test/emqx_rule_engine_jwt_worker_SUITE.erl b/apps/emqx_rule_engine/test/emqx_rule_engine_jwt_worker_SUITE.erl index bd701765b..fc84293e3 100644 --- a/apps/emqx_rule_engine/test/emqx_rule_engine_jwt_worker_SUITE.erl +++ b/apps/emqx_rule_engine/test/emqx_rule_engine_jwt_worker_SUITE.erl @@ -73,9 +73,11 @@ is_expired(JWT) -> %%----------------------------------------------------------------------------- t_create_success(_Config) -> - Ref = alias([reply]), Config = generate_config(), - ?assertMatch({ok, _}, emqx_rule_engine_jwt_worker:start_link(Config, Ref)), + Res = emqx_rule_engine_jwt_worker:start_link(Config), + ?assertMatch({ok, _}, Res), + {ok, Worker} = Res, + Ref = emqx_rule_engine_jwt_worker:ensure_jwt(Worker), receive {Ref, token_created} -> ok @@ -87,41 +89,40 @@ t_create_success(_Config) -> ok. t_empty_key(_Config) -> - Ref = alias([reply]), Config0 = generate_config(), Config = Config0#{private_key := <<>>}, process_flag(trap_exit, true), - ?assertMatch({ok, _}, emqx_rule_engine_jwt_worker:start_link(Config, Ref)), - receive - {Ref, {error, {invalid_private_key, empty_key}}} -> - ok - after - 1_000 -> - ct:fail("should have errored; msgs: ~0p", - [process_info(self(), messages)]) - end, + ?check_trace( + ?wait_async_action( + ?assertMatch({ok, _}, emqx_rule_engine_jwt_worker:start_link(Config)), + #{?snk_kind := rule_engine_jwt_worker_startup_error}, + 1_000), + fun(Trace) -> + ?assertMatch([#{error := empty_key}], + ?of_kind(rule_engine_jwt_worker_startup_error, Trace)), + ok + end), ok. t_invalid_pem(_Config) -> - Ref = alias([reply]), Config0 = generate_config(), InvalidPEM = public_key:pem_encode([{'PrivateKeyInfo', <<"xxxxxx">>, not_encrypted}, {'PrivateKeyInfo', <<"xxxxxx">>, not_encrypted}]), Config = Config0#{private_key := InvalidPEM}, process_flag(trap_exit, true), - ?assertMatch({ok, _}, emqx_rule_engine_jwt_worker:start_link(Config, Ref)), - receive - {Ref, {error, {invalid_private_key, _}}} -> - ok - after - 1_000 -> - ct:fail("should have errored; msgs: ~0p", - [process_info(self(), messages)]) - end, + ?check_trace( + ?wait_async_action( + ?assertMatch({ok, _}, emqx_rule_engine_jwt_worker:start_link(Config)), + #{?snk_kind := rule_engine_jwt_worker_startup_error}, + 1_000), + fun(Trace) -> + ?assertMatch([#{error := {invalid_private_key, _}}], + ?of_kind(rule_engine_jwt_worker_startup_error, Trace)), + ok + end), ok. t_refresh(_Config) -> - Ref = alias([reply]), Config0 = #{ table := Table , resource_id := ResourceId } = generate_config(), @@ -130,11 +131,12 @@ t_refresh(_Config) -> begin {{ok, _Pid}, {ok, _Event}} = ?wait_async_action( - emqx_rule_engine_jwt_worker:start_link(Config, Ref), - #{?snk_kind := jwt_worker_token_stored}, + emqx_rule_engine_jwt_worker:start_link(Config), + #{?snk_kind := rule_engine_jwt_worker_token_stored}, 5_000), {ok, FirstJWT} = emqx_rule_engine_jwt:lookup_jwt(Table, ResourceId), - ?block_until(#{?snk_kind := rule_engine_jwt_worker_refresh}, 15_000), + ?block_until(#{?snk_kind := rule_engine_jwt_worker_refresh, + jwt := JWT0} when JWT0 =/= FirstJWT, 15_000), {ok, SecondJWT} = emqx_rule_engine_jwt:lookup_jwt(Table, ResourceId), ?assertNot(is_expired(SecondJWT)), ?assert(is_expired(FirstJWT)), @@ -142,16 +144,15 @@ t_refresh(_Config) -> end, fun({FirstJWT, SecondJWT}, Trace) -> ?assertMatch([_, _ | _], - ?of_kind(jwt_worker_token_stored, Trace)), + ?of_kind(rule_engine_jwt_worker_token_stored, Trace)), ?assertNotEqual(FirstJWT, SecondJWT), ok end), ok. t_format_status(_Config) -> - Ref = alias([reply]), Config = generate_config(), - {ok, Pid} = emqx_rule_engine_jwt_worker:start_link(Config, Ref), + {ok, Pid} = emqx_rule_engine_jwt_worker:start_link(Config), {status, _, _, Props} = sys:get_status(Pid), [State] = [State || Info = [_ | _] <- Props, @@ -165,7 +166,6 @@ t_format_status(_Config) -> ok. t_lookup_ok(_Config) -> - Ref = alias([reply]), Config = #{ table := Table , resource_id := ResourceId , private_key := PrivateKeyPEM @@ -174,7 +174,8 @@ t_lookup_ok(_Config) -> , sub := Sub , kid := KId } = generate_config(), - {ok, _} = emqx_rule_engine_jwt_worker:start_link(Config, Ref), + {ok, Worker} = emqx_rule_engine_jwt_worker:start_link(Config), + Ref = emqx_rule_engine_jwt_worker:ensure_jwt(Worker), receive {Ref, token_created} -> ok @@ -225,7 +226,8 @@ t_lookup_badarg(_Config) -> t_start_supervised_worker(_Config) -> {ok, _} = emqx_rule_engine_jwt_sup:start_link(), Config = #{resource_id := ResourceId} = generate_config(), - {ok, {Ref, Pid}} = emqx_rule_engine_jwt_sup:start_worker(ResourceId, Config), + {ok, Pid} = emqx_rule_engine_jwt_sup:ensure_worker_present(ResourceId, Config), + Ref = emqx_rule_engine_jwt_worker:ensure_jwt(Pid), receive {Ref, token_created} -> ok @@ -235,7 +237,7 @@ t_start_supervised_worker(_Config) -> end, MRef = monitor(process, Pid), ?assert(is_process_alive(Pid)), - ok = emqx_rule_engine_jwt_sup:stop_worker(ResourceId), + ok = emqx_rule_engine_jwt_sup:ensure_worker_deleted(ResourceId), receive {'DOWN', MRef, process, Pid, _} -> ok