diff --git a/apps/emqx_rule_engine/src/emqx_rule_engine_jwt.erl b/apps/emqx_rule_engine/src/emqx_rule_engine_jwt.erl index 828c77f93..237c08c72 100644 --- a/apps/emqx_rule_engine/src/emqx_rule_engine_jwt.erl +++ b/apps/emqx_rule_engine/src/emqx_rule_engine_jwt.erl @@ -22,6 +22,7 @@ %% API -export([ lookup_jwt/1 , lookup_jwt/2 + , delete_jwt/2 ]). -type jwt() :: binary(). @@ -43,3 +44,13 @@ lookup_jwt(TId, ResourceId) -> error:badarg -> {error, not_found} end. + +-spec delete_jwt(ets:table(), resource_id()) -> ok. +delete_jwt(TId, ResourceId) -> + try + ets:delete(TId, {ResourceId, jwt}), + ok + catch + error:badarg -> + ok + end. diff --git a/apps/emqx_rule_engine/src/emqx_rule_engine_jwt_sup.erl b/apps/emqx_rule_engine/src/emqx_rule_engine_jwt_sup.erl index b393dd08b..989be2304 100644 --- a/apps/emqx_rule_engine/src/emqx_rule_engine_jwt_sup.erl +++ b/apps/emqx_rule_engine/src/emqx_rule_engine_jwt_sup.erl @@ -72,7 +72,7 @@ jwt_worker_child_spec(Id, Config) -> , restart => transient , type => worker , significant => false - , shutdown => brutal_kill + , shutdown => 5_000 , modules => [emqx_rule_engine_jwt_worker] }. diff --git a/apps/emqx_rule_engine/src/emqx_rule_engine_jwt_worker.erl b/apps/emqx_rule_engine/src/emqx_rule_engine_jwt_worker.erl index 4190a3536..7e3604701 100644 --- a/apps/emqx_rule_engine/src/emqx_rule_engine_jwt_worker.erl +++ b/apps/emqx_rule_engine/src/emqx_rule_engine_jwt_worker.erl @@ -31,8 +31,11 @@ , handle_info/2 , format_status/1 , format_status/2 + , terminate/2 ]). +-export([force_refresh/1]). + -include_lib("jose/include/jose_jwk.hrl"). -include_lib("emqx_rule_engine/include/rule_engine.hrl"). -include_lib("emqx_rule_engine/include/rule_actions.hrl"). @@ -49,7 +52,7 @@ , alg := binary() }. -type jwt() :: binary(). --type state() :: #{ refresh_timer := undefined | timer:tref() +-type state() :: #{ refresh_timer := undefined | timer:tref() | reference() , resource_id := resource_id() , expiration := timer:time() , table := ets:table() @@ -88,6 +91,11 @@ ensure_jwt(Worker) -> gen_server:cast(Worker, {ensure_jwt, Ref}), Ref. +-spec force_refresh(pid()) -> ok. +force_refresh(Worker) -> + _ = erlang:send(Worker, {timeout, force_refresh, ?refresh_jwt}), + ok. + %%----------------------------------------------------------------------------------------- %% gen_server API %%----------------------------------------------------------------------------------------- @@ -95,6 +103,7 @@ ensure_jwt(Worker) -> -spec init(config()) -> {ok, state(), {continue, {make_key, binary()}}} | {stop, {error, term()}}. init(#{private_key := PrivateKeyPEM} = Config) -> + process_flag(trap_exit, true), State0 = maps:without([private_key], Config), State = State0#{ jwk => undefined , jwt => undefined @@ -139,7 +148,7 @@ handle_cast({ensure_jwt, From}, State0 = #{jwt := JWT}) -> handle_cast(_Req, State) -> {noreply, State}. -handle_info({timeout, TRef, ?refresh_jwt}, State0 = #{refresh_timer := TRef}) -> +handle_info({timeout, _TRef, ?refresh_jwt}, State0) -> State = generate_and_store_jwt(State0), {noreply, State}; handle_info(_Msg, State) -> @@ -152,6 +161,11 @@ format_status(_Opt, [_PDict, State0]) -> State = censor_secrets(State0), [{data, [{"State", State}]}]. +terminate(_Reason, State) -> + #{resource_id := ResourceId, table := TId} = State, + emqx_rule_engine_jwt:delete_jwt(TId, ResourceId), + ok. + %%----------------------------------------------------------------------------------------- %% Helper fns %%----------------------------------------------------------------------------------------- @@ -195,14 +209,13 @@ store_jwt(#{resource_id := ResourceId, table := TId}, JWT) -> ok. -spec ensure_timer(state()) -> state(). -ensure_timer(State = #{ refresh_timer := undefined +ensure_timer(State = #{ refresh_timer := OldTimer , expiration := ExpirationMS0 }) -> - ExpirationMS = max(5_000, ExpirationMS0 - 5_000), + ok = cancel_timer(OldTimer), + ExpirationMS = max(5_000, ExpirationMS0 - 60_000), TRef = erlang:start_timer(ExpirationMS, self(), ?refresh_jwt), - State#{refresh_timer => TRef}; -ensure_timer(State) -> - State. + State#{refresh_timer => TRef}. -spec censor_secrets(state()) -> map(). censor_secrets(State) -> @@ -214,3 +227,10 @@ censor_secrets(State) -> Value end, State). + +-spec cancel_timer(undefined | timer:tref() | reference()) -> ok. +cancel_timer(undefined) -> + ok; +cancel_timer(TRef) -> + _ = erlang:cancel_timer(TRef), + ok. diff --git a/apps/emqx_rule_engine/test/emqx_rule_engine_jwt_worker_SUITE.erl b/apps/emqx_rule_engine/test/emqx_rule_engine_jwt_worker_SUITE.erl index fc84293e3..0de58df40 100644 --- a/apps/emqx_rule_engine/test/emqx_rule_engine_jwt_worker_SUITE.erl +++ b/apps/emqx_rule_engine/test/emqx_rule_engine_jwt_worker_SUITE.erl @@ -140,12 +140,22 @@ t_refresh(_Config) -> {ok, SecondJWT} = emqx_rule_engine_jwt:lookup_jwt(Table, ResourceId), ?assertNot(is_expired(SecondJWT)), ?assert(is_expired(FirstJWT)), - {FirstJWT, SecondJWT} + %% check yet another refresh to ensure the timer was properly + %% reset. + ?block_until(#{?snk_kind := rule_engine_jwt_worker_refresh, + jwt := JWT1} when JWT1 =/= SecondJWT + andalso JWT1 =/= FirstJWT, 15_000), + {ok, ThirdJWT} = emqx_rule_engine_jwt:lookup_jwt(Table, ResourceId), + ?assertNot(is_expired(ThirdJWT)), + ?assert(is_expired(SecondJWT)), + {FirstJWT, SecondJWT, ThirdJWT} end, - fun({FirstJWT, SecondJWT}, Trace) -> - ?assertMatch([_, _ | _], + fun({FirstJWT, SecondJWT, ThirdJWT}, Trace) -> + ?assertMatch([_, _, _ | _], ?of_kind(rule_engine_jwt_worker_token_stored, Trace)), ?assertNotEqual(FirstJWT, SecondJWT), + ?assertNotEqual(SecondJWT, ThirdJWT), + ?assertNotEqual(FirstJWT, ThirdJWT), ok end), ok. @@ -225,7 +235,7 @@ t_lookup_badarg(_Config) -> t_start_supervised_worker(_Config) -> {ok, _} = emqx_rule_engine_jwt_sup:start_link(), - Config = #{resource_id := ResourceId} = generate_config(), + Config = #{resource_id := ResourceId, table := TId} = generate_config(), {ok, Pid} = emqx_rule_engine_jwt_sup:ensure_worker_present(ResourceId, Config), Ref = emqx_rule_engine_jwt_worker:ensure_jwt(Pid), receive @@ -237,6 +247,7 @@ t_start_supervised_worker(_Config) -> end, MRef = monitor(process, Pid), ?assert(is_process_alive(Pid)), + ?assertMatch({ok, _}, emqx_rule_engine_jwt:lookup_jwt(TId, ResourceId)), ok = emqx_rule_engine_jwt_sup:ensure_worker_deleted(ResourceId), receive {'DOWN', MRef, process, Pid, _} -> @@ -245,4 +256,7 @@ t_start_supervised_worker(_Config) -> 1_000 -> ct:fail("timeout") end, + %% ensure it cleans up its own tokens to avoid leakage when + %% probing/testing rule resources. + ?assertEqual({error, not_found}, emqx_rule_engine_jwt:lookup_jwt(TId, ResourceId)), ok.