From 36c1ecd9b7313d7de5f74637b88f0ed2255ec935 Mon Sep 17 00:00:00 2001 From: Ilya Averyanov Date: Mon, 11 Apr 2022 19:15:14 +0300 Subject: [PATCH] chore(emqx_auth_jwt): do not use gen_server call for jwt verification --- apps/emqx_auth_jwt/src/emqx_auth_jwt.erl | 5 +- apps/emqx_auth_jwt/src/emqx_auth_jwt_app.erl | 7 +-- apps/emqx_auth_jwt/src/emqx_auth_jwt_svr.erl | 63 +++++++++++--------- 3 files changed, 39 insertions(+), 36 deletions(-) diff --git a/apps/emqx_auth_jwt/src/emqx_auth_jwt.erl b/apps/emqx_auth_jwt/src/emqx_auth_jwt.erl index 9635b185d..8221ad8e6 100644 --- a/apps/emqx_auth_jwt/src/emqx_auth_jwt.erl +++ b/apps/emqx_auth_jwt/src/emqx_auth_jwt.erl @@ -46,8 +46,7 @@ register_metrics() -> %% Authentication callbacks %%-------------------------------------------------------------------- -check(ClientInfo, AuthResult, #{pid := Pid, - from := From, +check(ClientInfo, AuthResult, #{from := From, checklists := Checklists}) -> case maps:find(From, ClientInfo) of error -> @@ -55,7 +54,7 @@ check(ClientInfo, AuthResult, #{pid := Pid, {ok, undefined} -> ok = emqx_metrics:inc(?AUTH_METRICS(ignore)); {ok, Token} -> - case emqx_auth_jwt_svr:verify(Pid, Token) of + case emqx_auth_jwt_svr:verify(Token) of {error, not_found} -> ok = emqx_metrics:inc(?AUTH_METRICS(ignore)); {error, not_token} -> diff --git a/apps/emqx_auth_jwt/src/emqx_auth_jwt_app.erl b/apps/emqx_auth_jwt/src/emqx_auth_jwt_app.erl index 7d89ddd5f..07a09d4d1 100644 --- a/apps/emqx_auth_jwt/src/emqx_auth_jwt_app.erl +++ b/apps/emqx_auth_jwt/src/emqx_auth_jwt_app.erl @@ -33,11 +33,10 @@ start(_Type, _Args) -> {ok, Pid} = start_auth_server(jwks_svr_options()), ok = emqx_auth_jwt:register_metrics(), - AuthEnv0 = auth_env(), - AuthEnv1 = AuthEnv0#{pid => Pid}, + AuthEnv = auth_env(), - _ = emqx:hook('client.authenticate', {emqx_auth_jwt, check, [AuthEnv1]}), - {ok, Sup, AuthEnv1}. + _ = emqx:hook('client.authenticate', {emqx_auth_jwt, check, [AuthEnv]}), + {ok, Sup, AuthEnv}. stop(AuthEnv) -> emqx:unhook('client.authenticate', {emqx_auth_jwt, check, [AuthEnv]}). diff --git a/apps/emqx_auth_jwt/src/emqx_auth_jwt_svr.erl b/apps/emqx_auth_jwt/src/emqx_auth_jwt_svr.erl index d550ac590..179fbded5 100644 --- a/apps/emqx_auth_jwt/src/emqx_auth_jwt_svr.erl +++ b/apps/emqx_auth_jwt/src/emqx_auth_jwt_svr.erl @@ -26,7 +26,7 @@ %% APIs -export([start_link/1]). --export([verify/2]). +-export([verify/1]). %% gen_server callbacks -export([ init/1 @@ -44,8 +44,9 @@ | {interval, pos_integer()}. -define(INTERVAL, 300000). +-define(TAB, ?MODULE). --record(state, {static, remote, addr, tref, intv}). +-record(state, {addr, tref, intv}). %%-------------------------------------------------------------------- %% APIs @@ -55,13 +56,13 @@ start_link(Options) -> gen_server:start_link(?MODULE, [Options], []). --spec verify(pid(), binary()) +-spec verify(binary()) -> {error, term()} | {ok, Payload :: map()}. -verify(S, JwsCompacted) when is_binary(JwsCompacted) -> +verify(JwsCompacted) when is_binary(JwsCompacted) -> case catch jose_jws:peek(JwsCompacted) of {'EXIT', _} -> {error, not_token}; - _ -> gen_server:call(S, {verify, JwsCompacted}) + _ -> do_verify(JwsCompacted) end. %%-------------------------------------------------------------------- @@ -70,12 +71,13 @@ verify(S, JwsCompacted) when is_binary(JwsCompacted) -> init([Options]) -> ok = jose:json_module(jiffy), + _ = ets:new(?TAB, [set, protected, named_table]), {Static, Remote} = do_init_jwks(Options), + true = ets:insert(?TAB, {static, Static}), + true = ets:insert(?TAB, {remote, Remote}), Intv = proplists:get_value(interval, Options, ?INTERVAL), {ok, reset_timer( #state{ - static = Static, - remote = Remote, addr = proplists:get_value(jwks_addr, Options), intv = Intv})}. @@ -105,9 +107,6 @@ do_init_jwks(Options) -> Remote = K2J(jwks_addr, fun request_jwks/1), {[J ||J <- [OctJwk, PemJwk], J /= undefined], Remote}. -handle_call({verify, JwsCompacted}, _From, State) -> - handle_verify(JwsCompacted, State); - handle_call(_Req, _From, State) -> {reply, ok, State}. @@ -116,7 +115,7 @@ handle_cast(_Msg, State) -> handle_info({timeout, _TRef, refresh}, State = #state{addr = Addr}) -> NState = try - State#state{remote = request_jwks(Addr)} + true = ets:insert(?TAB, {remote, request_jwks(Addr)}) catch _:_ -> State end, @@ -136,24 +135,10 @@ code_change(_OldVsn, State, _Extra) -> %% Internal funcs %%-------------------------------------------------------------------- -handle_verify(JwsCompacted, - State = #state{static = Static, remote = Remote}) -> - try - Jwks = case emqx_json:decode(jose_jws:peek_protected(JwsCompacted), [return_maps]) of - #{<<"kid">> := Kid} when Remote /= undefined -> - [J || J <- Remote, maps:get(<<"kid">>, J#jose_jwk.fields, undefined) =:= Kid]; - _ -> Static - end, - case Jwks of - [] -> {reply, {error, not_found}, State}; - _ -> - {reply, do_verify(JwsCompacted, Jwks), State} - end - catch - Class : Reason : Stk -> - ?LOG(error, "Handle JWK crashed: ~p, ~p, stacktrace: ~p~n", - [Class, Reason, Stk]), - {reply, {error, invalid_signature}, State} +keys(Type) -> + case ets:lookup(?TAB, Type) of + [{_, Keys}] -> Keys; + [] -> [] end. request_jwks(Addr) -> @@ -181,6 +166,26 @@ cancel_timer(State = #state{tref = TRef}) -> _ = erlang:cancel_timer(TRef), State#state{tref = undefined}. +do_verify(JwsCompacted) -> + try + Remote = keys(remote), + Jwks = case emqx_json:decode(jose_jws:peek_protected(JwsCompacted), [return_maps]) of + #{<<"kid">> := Kid} when Remote /= undefined -> + [J || J <- Remote, maps:get(<<"kid">>, J#jose_jwk.fields, undefined) =:= Kid]; + _ -> keys(static) + end, + case Jwks of + [] -> {error, not_found}; + _ -> + do_verify(JwsCompacted, Jwks) + end + catch + Class : Reason : Stk -> + ?LOG(error, "Handle JWK crashed: ~p, ~p, stacktrace: ~p~n", + [Class, Reason, Stk]), + {error, invalid_signature} + end. + do_verify(_JwsCompated, []) -> {error, invalid_signature}; do_verify(JwsCompacted, [Jwk|More]) ->