diff --git a/apps/emqx_auth_jwt/src/emqx_auth_jwt_svr.erl b/apps/emqx_auth_jwt/src/emqx_auth_jwt_svr.erl index 08c60d8ed..ac07a8640 100644 --- a/apps/emqx_auth_jwt/src/emqx_auth_jwt_svr.erl +++ b/apps/emqx_auth_jwt/src/emqx_auth_jwt_svr.erl @@ -73,8 +73,9 @@ verify(JwsCompacted) when is_binary(JwsCompacted) -> init([Options]) -> ok = jose:json_module(jiffy), _ = ets:new(?TAB, [set, protected, named_table]), - {Static, Remote} = do_init_jwks(Options), - true = ets:insert(?TAB, [{static, Static}, {remote, Remote}]), + Static = do_init_jwks(Options), + to_request_jwks(Options), + true = ets:insert(?TAB, [{static, Static}, {remote, undefined}]), Intv = proplists:get_value(interval, Options, ?INTERVAL), {ok, reset_timer( #state{ @@ -83,29 +84,13 @@ init([Options]) -> %% @private do_init_jwks(Options) -> - K2J = fun(K, F) -> - case proplists:get_value(K, Options) of - undefined -> undefined; - V -> - try F(V) of - {error, Reason} -> - ?LOG(warning, "Build ~p JWK ~p failed: {error, ~p}~n", - [K, V, Reason]), - undefined; - J -> J - catch T:R -> - ?LOG(warning, "Build ~p JWK ~p failed: {~p, ~p}~n", - [K, V, T, R]), - undefined - end - end - end, - OctJwk = K2J(secret, fun(V) -> - jose_jwk:from_oct(list_to_binary(V)) - end), - PemJwk = K2J(pubkey, fun jose_jwk:from_pem_file/1), - Remote = K2J(jwks_addr, fun request_jwks/1), - {[J ||J <- [OctJwk, PemJwk], J /= undefined], Remote}. + OctJwk = key2jwt_value(secret, + fun(V) -> + jose_jwk:from_oct(list_to_binary(V)) + end, + Options), + PemJwk = key2jwt_value(pubkey, fun jose_jwk:from_pem_file/1, Options), + [J ||J <- [OctJwk, PemJwk], J /= undefined]. handle_call(_Req, _From, State) -> {reply, ok, State}. @@ -122,6 +107,11 @@ handle_info({timeout, _TRef, refresh}, State = #state{addr = Addr}) -> end, {noreply, reset_timer(NState)}; +handle_info({request_jwks, Options}, State) -> + Remote = key2jwt_value(jwks_addr, fun request_jwks/1, Options), + true = ets:insert(?TAB, {remote, Remote}), + {noreply, State}; + handle_info(_Info, State) -> {noreply, State}. @@ -249,3 +239,23 @@ do_check_claim([{K, F}|More], Claims) -> _ -> do_check_claim(More, Claims) end. + +to_request_jwks(Options) -> + erlang:send(self(), {request_jwks, Options}). + +key2jwt_value(Key, Func, Options) -> + case proplists:get_value(Key, Options) of + undefined -> undefined; + V -> + try Func(V) of + {error, Reason} -> + ?LOG(warning, "Build ~p JWK ~p failed: {error, ~p}~n", + [Key, V, Reason]), + undefined; + J -> J + catch T:R -> + ?LOG(warning, "Build ~p JWK ~p failed: {~p, ~p}~n", + [Key, V, T, R]), + undefined + end + end.