chore(emqx_auth_jwt): do not use gen_server call for jwt verification

This commit is contained in:
Ilya Averyanov 2022-04-11 19:15:14 +03:00
parent 59660e99db
commit 36c1ecd9b7
3 changed files with 39 additions and 36 deletions

View File

@ -46,8 +46,7 @@ register_metrics() ->
%% Authentication callbacks %% Authentication callbacks
%%-------------------------------------------------------------------- %%--------------------------------------------------------------------
check(ClientInfo, AuthResult, #{pid := Pid, check(ClientInfo, AuthResult, #{from := From,
from := From,
checklists := Checklists}) -> checklists := Checklists}) ->
case maps:find(From, ClientInfo) of case maps:find(From, ClientInfo) of
error -> error ->
@ -55,7 +54,7 @@ check(ClientInfo, AuthResult, #{pid := Pid,
{ok, undefined} -> {ok, undefined} ->
ok = emqx_metrics:inc(?AUTH_METRICS(ignore)); ok = emqx_metrics:inc(?AUTH_METRICS(ignore));
{ok, Token} -> {ok, Token} ->
case emqx_auth_jwt_svr:verify(Pid, Token) of case emqx_auth_jwt_svr:verify(Token) of
{error, not_found} -> {error, not_found} ->
ok = emqx_metrics:inc(?AUTH_METRICS(ignore)); ok = emqx_metrics:inc(?AUTH_METRICS(ignore));
{error, not_token} -> {error, not_token} ->

View File

@ -33,11 +33,10 @@ start(_Type, _Args) ->
{ok, Pid} = start_auth_server(jwks_svr_options()), {ok, Pid} = start_auth_server(jwks_svr_options()),
ok = emqx_auth_jwt:register_metrics(), ok = emqx_auth_jwt:register_metrics(),
AuthEnv0 = auth_env(), AuthEnv = auth_env(),
AuthEnv1 = AuthEnv0#{pid => Pid},
_ = emqx:hook('client.authenticate', {emqx_auth_jwt, check, [AuthEnv1]}), _ = emqx:hook('client.authenticate', {emqx_auth_jwt, check, [AuthEnv]}),
{ok, Sup, AuthEnv1}. {ok, Sup, AuthEnv}.
stop(AuthEnv) -> stop(AuthEnv) ->
emqx:unhook('client.authenticate', {emqx_auth_jwt, check, [AuthEnv]}). emqx:unhook('client.authenticate', {emqx_auth_jwt, check, [AuthEnv]}).

View File

@ -26,7 +26,7 @@
%% APIs %% APIs
-export([start_link/1]). -export([start_link/1]).
-export([verify/2]). -export([verify/1]).
%% gen_server callbacks %% gen_server callbacks
-export([ init/1 -export([ init/1
@ -44,8 +44,9 @@
| {interval, pos_integer()}. | {interval, pos_integer()}.
-define(INTERVAL, 300000). -define(INTERVAL, 300000).
-define(TAB, ?MODULE).
-record(state, {static, remote, addr, tref, intv}). -record(state, {addr, tref, intv}).
%%-------------------------------------------------------------------- %%--------------------------------------------------------------------
%% APIs %% APIs
@ -55,13 +56,13 @@
start_link(Options) -> start_link(Options) ->
gen_server:start_link(?MODULE, [Options], []). gen_server:start_link(?MODULE, [Options], []).
-spec verify(pid(), binary()) -spec verify(binary())
-> {error, term()} -> {error, term()}
| {ok, Payload :: map()}. | {ok, Payload :: map()}.
verify(S, JwsCompacted) when is_binary(JwsCompacted) -> verify(JwsCompacted) when is_binary(JwsCompacted) ->
case catch jose_jws:peek(JwsCompacted) of case catch jose_jws:peek(JwsCompacted) of
{'EXIT', _} -> {error, not_token}; {'EXIT', _} -> {error, not_token};
_ -> gen_server:call(S, {verify, JwsCompacted}) _ -> do_verify(JwsCompacted)
end. end.
%%-------------------------------------------------------------------- %%--------------------------------------------------------------------
@ -70,12 +71,13 @@ verify(S, JwsCompacted) when is_binary(JwsCompacted) ->
init([Options]) -> init([Options]) ->
ok = jose:json_module(jiffy), ok = jose:json_module(jiffy),
_ = ets:new(?TAB, [set, protected, named_table]),
{Static, Remote} = do_init_jwks(Options), {Static, Remote} = do_init_jwks(Options),
true = ets:insert(?TAB, {static, Static}),
true = ets:insert(?TAB, {remote, Remote}),
Intv = proplists:get_value(interval, Options, ?INTERVAL), Intv = proplists:get_value(interval, Options, ?INTERVAL),
{ok, reset_timer( {ok, reset_timer(
#state{ #state{
static = Static,
remote = Remote,
addr = proplists:get_value(jwks_addr, Options), addr = proplists:get_value(jwks_addr, Options),
intv = Intv})}. intv = Intv})}.
@ -105,9 +107,6 @@ do_init_jwks(Options) ->
Remote = K2J(jwks_addr, fun request_jwks/1), Remote = K2J(jwks_addr, fun request_jwks/1),
{[J ||J <- [OctJwk, PemJwk], J /= undefined], Remote}. {[J ||J <- [OctJwk, PemJwk], J /= undefined], Remote}.
handle_call({verify, JwsCompacted}, _From, State) ->
handle_verify(JwsCompacted, State);
handle_call(_Req, _From, State) -> handle_call(_Req, _From, State) ->
{reply, ok, State}. {reply, ok, State}.
@ -116,7 +115,7 @@ handle_cast(_Msg, State) ->
handle_info({timeout, _TRef, refresh}, State = #state{addr = Addr}) -> handle_info({timeout, _TRef, refresh}, State = #state{addr = Addr}) ->
NState = try NState = try
State#state{remote = request_jwks(Addr)} true = ets:insert(?TAB, {remote, request_jwks(Addr)})
catch _:_ -> catch _:_ ->
State State
end, end,
@ -136,24 +135,10 @@ code_change(_OldVsn, State, _Extra) ->
%% Internal funcs %% Internal funcs
%%-------------------------------------------------------------------- %%--------------------------------------------------------------------
handle_verify(JwsCompacted, keys(Type) ->
State = #state{static = Static, remote = Remote}) -> case ets:lookup(?TAB, Type) of
try [{_, Keys}] -> Keys;
Jwks = case emqx_json:decode(jose_jws:peek_protected(JwsCompacted), [return_maps]) of [] -> []
#{<<"kid">> := Kid} when Remote /= undefined ->
[J || J <- Remote, maps:get(<<"kid">>, J#jose_jwk.fields, undefined) =:= Kid];
_ -> Static
end,
case Jwks of
[] -> {reply, {error, not_found}, State};
_ ->
{reply, do_verify(JwsCompacted, Jwks), State}
end
catch
Class : Reason : Stk ->
?LOG(error, "Handle JWK crashed: ~p, ~p, stacktrace: ~p~n",
[Class, Reason, Stk]),
{reply, {error, invalid_signature}, State}
end. end.
request_jwks(Addr) -> request_jwks(Addr) ->
@ -181,6 +166,26 @@ cancel_timer(State = #state{tref = TRef}) ->
_ = erlang:cancel_timer(TRef), _ = erlang:cancel_timer(TRef),
State#state{tref = undefined}. State#state{tref = undefined}.
do_verify(JwsCompacted) ->
try
Remote = keys(remote),
Jwks = case emqx_json:decode(jose_jws:peek_protected(JwsCompacted), [return_maps]) of
#{<<"kid">> := Kid} when Remote /= undefined ->
[J || J <- Remote, maps:get(<<"kid">>, J#jose_jwk.fields, undefined) =:= Kid];
_ -> keys(static)
end,
case Jwks of
[] -> {error, not_found};
_ ->
do_verify(JwsCompacted, Jwks)
end
catch
Class : Reason : Stk ->
?LOG(error, "Handle JWK crashed: ~p, ~p, stacktrace: ~p~n",
[Class, Reason, Stk]),
{error, invalid_signature}
end.
do_verify(_JwsCompated, []) -> do_verify(_JwsCompated, []) ->
{error, invalid_signature}; {error, invalid_signature};
do_verify(JwsCompacted, [Jwk|More]) -> do_verify(JwsCompacted, [Jwk|More]) ->