chore(emqx_auth_jwt): do not use gen_server call for jwt verification

This commit is contained in:
Ilya Averyanov 2022-04-11 19:15:14 +03:00
parent 59660e99db
commit 36c1ecd9b7
3 changed files with 39 additions and 36 deletions

View File

@ -46,8 +46,7 @@ register_metrics() ->
%% Authentication callbacks
%%--------------------------------------------------------------------
check(ClientInfo, AuthResult, #{pid := Pid,
from := From,
check(ClientInfo, AuthResult, #{from := From,
checklists := Checklists}) ->
case maps:find(From, ClientInfo) of
error ->
@ -55,7 +54,7 @@ check(ClientInfo, AuthResult, #{pid := Pid,
{ok, undefined} ->
ok = emqx_metrics:inc(?AUTH_METRICS(ignore));
{ok, Token} ->
case emqx_auth_jwt_svr:verify(Pid, Token) of
case emqx_auth_jwt_svr:verify(Token) of
{error, not_found} ->
ok = emqx_metrics:inc(?AUTH_METRICS(ignore));
{error, not_token} ->

View File

@ -33,11 +33,10 @@ start(_Type, _Args) ->
{ok, Pid} = start_auth_server(jwks_svr_options()),
ok = emqx_auth_jwt:register_metrics(),
AuthEnv0 = auth_env(),
AuthEnv1 = AuthEnv0#{pid => Pid},
AuthEnv = auth_env(),
_ = emqx:hook('client.authenticate', {emqx_auth_jwt, check, [AuthEnv1]}),
{ok, Sup, AuthEnv1}.
_ = emqx:hook('client.authenticate', {emqx_auth_jwt, check, [AuthEnv]}),
{ok, Sup, AuthEnv}.
stop(AuthEnv) ->
emqx:unhook('client.authenticate', {emqx_auth_jwt, check, [AuthEnv]}).

View File

@ -26,7 +26,7 @@
%% APIs
-export([start_link/1]).
-export([verify/2]).
-export([verify/1]).
%% gen_server callbacks
-export([ init/1
@ -44,8 +44,9 @@
| {interval, pos_integer()}.
-define(INTERVAL, 300000).
-define(TAB, ?MODULE).
-record(state, {static, remote, addr, tref, intv}).
-record(state, {addr, tref, intv}).
%%--------------------------------------------------------------------
%% APIs
@ -55,13 +56,13 @@
start_link(Options) ->
gen_server:start_link(?MODULE, [Options], []).
-spec verify(pid(), binary())
-spec verify(binary())
-> {error, term()}
| {ok, Payload :: map()}.
verify(S, JwsCompacted) when is_binary(JwsCompacted) ->
verify(JwsCompacted) when is_binary(JwsCompacted) ->
case catch jose_jws:peek(JwsCompacted) of
{'EXIT', _} -> {error, not_token};
_ -> gen_server:call(S, {verify, JwsCompacted})
_ -> do_verify(JwsCompacted)
end.
%%--------------------------------------------------------------------
@ -70,12 +71,13 @@ verify(S, JwsCompacted) when is_binary(JwsCompacted) ->
init([Options]) ->
ok = jose:json_module(jiffy),
_ = ets:new(?TAB, [set, protected, named_table]),
{Static, Remote} = do_init_jwks(Options),
true = ets:insert(?TAB, {static, Static}),
true = ets:insert(?TAB, {remote, Remote}),
Intv = proplists:get_value(interval, Options, ?INTERVAL),
{ok, reset_timer(
#state{
static = Static,
remote = Remote,
addr = proplists:get_value(jwks_addr, Options),
intv = Intv})}.
@ -105,9 +107,6 @@ do_init_jwks(Options) ->
Remote = K2J(jwks_addr, fun request_jwks/1),
{[J ||J <- [OctJwk, PemJwk], J /= undefined], Remote}.
handle_call({verify, JwsCompacted}, _From, State) ->
handle_verify(JwsCompacted, State);
handle_call(_Req, _From, State) ->
{reply, ok, State}.
@ -116,7 +115,7 @@ handle_cast(_Msg, State) ->
handle_info({timeout, _TRef, refresh}, State = #state{addr = Addr}) ->
NState = try
State#state{remote = request_jwks(Addr)}
true = ets:insert(?TAB, {remote, request_jwks(Addr)})
catch _:_ ->
State
end,
@ -136,24 +135,10 @@ code_change(_OldVsn, State, _Extra) ->
%% Internal funcs
%%--------------------------------------------------------------------
handle_verify(JwsCompacted,
State = #state{static = Static, remote = Remote}) ->
try
Jwks = case emqx_json:decode(jose_jws:peek_protected(JwsCompacted), [return_maps]) of
#{<<"kid">> := Kid} when Remote /= undefined ->
[J || J <- Remote, maps:get(<<"kid">>, J#jose_jwk.fields, undefined) =:= Kid];
_ -> Static
end,
case Jwks of
[] -> {reply, {error, not_found}, State};
_ ->
{reply, do_verify(JwsCompacted, Jwks), State}
end
catch
Class : Reason : Stk ->
?LOG(error, "Handle JWK crashed: ~p, ~p, stacktrace: ~p~n",
[Class, Reason, Stk]),
{reply, {error, invalid_signature}, State}
keys(Type) ->
case ets:lookup(?TAB, Type) of
[{_, Keys}] -> Keys;
[] -> []
end.
request_jwks(Addr) ->
@ -181,6 +166,26 @@ cancel_timer(State = #state{tref = TRef}) ->
_ = erlang:cancel_timer(TRef),
State#state{tref = undefined}.
do_verify(JwsCompacted) ->
try
Remote = keys(remote),
Jwks = case emqx_json:decode(jose_jws:peek_protected(JwsCompacted), [return_maps]) of
#{<<"kid">> := Kid} when Remote /= undefined ->
[J || J <- Remote, maps:get(<<"kid">>, J#jose_jwk.fields, undefined) =:= Kid];
_ -> keys(static)
end,
case Jwks of
[] -> {error, not_found};
_ ->
do_verify(JwsCompacted, Jwks)
end
catch
Class : Reason : Stk ->
?LOG(error, "Handle JWK crashed: ~p, ~p, stacktrace: ~p~n",
[Class, Reason, Stk]),
{error, invalid_signature}
end.
do_verify(_JwsCompated, []) ->
{error, invalid_signature};
do_verify(JwsCompacted, [Jwk|More]) ->