feat(jwks): allow specifying custom request headers

Fixes https://emqx.atlassian.net/browse/EMQX-12655
This commit is contained in:
Thales Macedo Garitezi 2024-07-08 13:29:21 -03:00
parent d34fc7a03a
commit 811184ddad
6 changed files with 170 additions and 3 deletions

View File

@ -63,6 +63,7 @@
-type json_binary() :: binary(). -type json_binary() :: binary().
-type template() :: binary(). -type template() :: binary().
-type template_str() :: string(). -type template_str() :: string().
-type binary_kv() :: #{binary() => binary()}.
-typerefl_from_string({duration/0, emqx_schema, to_duration}). -typerefl_from_string({duration/0, emqx_schema, to_duration}).
-typerefl_from_string({duration_s/0, emqx_schema, to_duration_s}). -typerefl_from_string({duration_s/0, emqx_schema, to_duration_s}).
@ -167,7 +168,8 @@
json_binary/0, json_binary/0,
port_number/0, port_number/0,
template/0, template/0,
template_str/0 template_str/0,
binary_kv/0
]). ]).
-export([namespace/0, roots/0, roots/1, fields/1, desc/1, tags/0]). -export([namespace/0, roots/0, roots/1, fields/1, desc/1, tags/0]).

View File

@ -133,11 +133,13 @@ code_change(_OldVsn, State, _Extra) ->
handle_options(#{ handle_options(#{
endpoint := Endpoint, endpoint := Endpoint,
headers := Headers,
refresh_interval := RefreshInterval0, refresh_interval := RefreshInterval0,
ssl_opts := SSLOpts ssl_opts := SSLOpts
}) -> }) ->
#{ #{
endpoint => Endpoint, endpoint => Endpoint,
headers => to_httpc_headers(Headers),
refresh_interval => limit_refresh_interval(RefreshInterval0), refresh_interval => limit_refresh_interval(RefreshInterval0),
ssl_opts => maps:to_list(SSLOpts), ssl_opts => maps:to_list(SSLOpts),
jwks => [], jwks => [],
@ -147,6 +149,7 @@ handle_options(#{
refresh_jwks( refresh_jwks(
#{ #{
endpoint := Endpoint, endpoint := Endpoint,
headers := Headers,
ssl_opts := SSLOpts ssl_opts := SSLOpts
} = State } = State
) -> ) ->
@ -159,7 +162,7 @@ refresh_jwks(
case case
httpc:request( httpc:request(
get, get,
{Endpoint, [{"Accept", "application/json"}]}, {Endpoint, Headers},
HTTPOpts, HTTPOpts,
[{body_format, binary}, {sync, false}, {receiver, self()}] [{body_format, binary}, {sync, false}, {receiver, self()}]
) )
@ -185,6 +188,9 @@ limit_refresh_interval(Interval) when Interval < 10 ->
limit_refresh_interval(Interval) -> limit_refresh_interval(Interval) ->
Interval. Interval.
to_httpc_headers(Headers) ->
[{binary_to_list(bin(K)), V} || {K, V} <- maps:to_list(Headers)].
cancel_http_request(#{request_id := undefined} = State) -> cancel_http_request(#{request_id := undefined} = State) ->
State; State;
cancel_http_request(#{request_id := RequestID} = State) -> cancel_http_request(#{request_id := RequestID} = State) ->
@ -195,3 +201,10 @@ cancel_http_request(#{request_id := RequestID} = State) ->
ok ok
end, end,
State#{request_id => undefined}. State#{request_id => undefined}.
bin(List) when is_list(List) ->
unicode:characters_to_binary(List, utf8);
bin(Atom) when is_atom(Atom) ->
erlang:atom_to_binary(Atom);
bin(Bin) when is_binary(Bin) ->
Bin.

View File

@ -95,6 +95,15 @@ fields(jwt_jwks) ->
[ [
{use_jwks, sc(hoconsc:enum([true]), #{required => true, desc => ?DESC(use_jwks)})}, {use_jwks, sc(hoconsc:enum([true]), #{required => true, desc => ?DESC(use_jwks)})},
{endpoint, fun endpoint/1}, {endpoint, fun endpoint/1},
{headers,
sc(
typerefl:alias("map", emqx_schema:binary_kv()),
#{
default => #{<<"Accept">> => <<"application/json">>},
validator => fun validate_headers/1,
desc => ?DESC("jwks_headers")
}
)},
{pool_size, fun emqx_connector_schema_lib:pool_size/1}, {pool_size, fun emqx_connector_schema_lib:pool_size/1},
{refresh_interval, fun refresh_interval/1}, {refresh_interval, fun refresh_interval/1},
{ssl, #{ {ssl, #{
@ -225,3 +234,26 @@ to_binary(B) when is_binary(B) ->
B. B.
sc(Type, Meta) -> hoconsc:mk(Type, Meta). sc(Type, Meta) -> hoconsc:mk(Type, Meta).
validate_headers(undefined) ->
ok;
validate_headers(Headers) ->
BadKeys0 =
lists:filter(
fun(K) ->
re:run(K, <<"[^-0-9a-zA-Z_ ]">>, [{capture, none}]) =:= match
end,
maps:keys(Headers)
),
case BadKeys0 of
[] ->
ok;
_ ->
BadKeys = lists:join(", ", BadKeys0),
Msg0 = io_lib:format(
"headers should contain only characters matching [-0-9a-zA-Z_ ]; bad headers: ~s",
[BadKeys]
),
Msg = iolist_to_binary(Msg0),
{error, Msg}
end.

View File

@ -22,18 +22,21 @@
-include_lib("common_test/include/ct.hrl"). -include_lib("common_test/include/ct.hrl").
-include_lib("eunit/include/eunit.hrl"). -include_lib("eunit/include/eunit.hrl").
-include_lib("snabbkaffe/include/snabbkaffe.hrl"). -include_lib("snabbkaffe/include/snabbkaffe.hrl").
-include_lib("emqx/include/asserts.hrl").
-define(AUTHN_ID, <<"mechanism:jwt">>). -define(AUTHN_ID, <<"mechanism:jwt">>).
-define(JWKS_PORT, 31333). -define(JWKS_PORT, 31333).
-define(JWKS_PATH, "/jwks.json"). -define(JWKS_PATH, "/jwks.json").
-import(emqx_common_test_helpers, [on_exit/1]).
all() -> all() ->
emqx_common_test_helpers:all(?MODULE). emqx_common_test_helpers:all(?MODULE).
init_per_suite(Config) -> init_per_suite(Config) ->
Apps = emqx_cth_suite:start([emqx, emqx_conf, emqx_auth, emqx_auth_jwt], #{ Apps = emqx_cth_suite:start([emqx, emqx_conf, emqx_auth, emqx_auth_jwt], #{
work_dir => ?config(priv_dir, Config) work_dir => emqx_cth_suite:work_dir(Config)
}), }),
[{apps, Apps} | Config]. [{apps, Apps} | Config].
@ -41,6 +44,10 @@ end_per_suite(Config) ->
ok = emqx_cth_suite:stop(?config(apps, Config)), ok = emqx_cth_suite:stop(?config(apps, Config)),
ok. ok.
end_per_testcase(_TestCase, _Config) ->
emqx_common_test_helpers:call_janitor(),
ok.
%%------------------------------------------------------------------------------ %%------------------------------------------------------------------------------
%% Tests %% Tests
%%------------------------------------------------------------------------------ %%------------------------------------------------------------------------------
@ -244,6 +251,7 @@ t_jwks_renewal(_Config) ->
disconnect_after_expire => false, disconnect_after_expire => false,
use_jwks => true, use_jwks => true,
endpoint => "https://127.0.0.1:" ++ integer_to_list(?JWKS_PORT + 1) ++ ?JWKS_PATH, endpoint => "https://127.0.0.1:" ++ integer_to_list(?JWKS_PORT + 1) ++ ?JWKS_PATH,
headers => #{<<"Accept">> => <<"application/json">>},
refresh_interval => 1000, refresh_interval => 1000,
pool_size => 1 pool_size => 1
}, },
@ -328,6 +336,102 @@ t_jwks_renewal(_Config) ->
?assertEqual(ok, emqx_authn_jwt:destroy(State2)), ?assertEqual(ok, emqx_authn_jwt:destroy(State2)),
ok = emqx_authn_http_test_server:stop(). ok = emqx_authn_http_test_server:stop().
t_jwks_custom_headers(_Config) ->
{ok, _} = emqx_authn_http_test_server:start_link(?JWKS_PORT, ?JWKS_PATH, server_ssl_opts()),
on_exit(fun() -> ok = emqx_authn_http_test_server:stop() end),
ok = emqx_authn_http_test_server:set_handler(jwks_handler_spy()),
PrivateKey = test_rsa_key(private),
Payload = #{
<<"username">> => <<"myuser">>,
<<"foo">> => <<"myuser">>,
<<"exp">> => erlang:system_time(second) + 10
},
Endpoint = iolist_to_binary("https://127.0.0.1:" ++ integer_to_list(?JWKS_PORT) ++ ?JWKS_PATH),
Config0 = #{
<<"mechanism">> => <<"jwt">>,
<<"use_jwks">> => true,
<<"from">> => <<"password">>,
<<"endpoint">> => Endpoint,
<<"headers">> => #{
<<"Accept">> => <<"application/json">>,
<<"Content-Type">> => <<>>,
<<"foo">> => <<"bar">>
},
<<"pool_size">> => 1,
<<"refresh_interval">> => 1_000,
<<"ssl">> => #{
<<"keyfile">> => cert_file("client.key"),
<<"certfile">> => cert_file("client.crt"),
<<"cacertfile">> => cert_file("ca.crt"),
<<"enable">> => true,
<<"verify">> => <<"verify_peer">>,
<<"server_name_indication">> => <<"authn-server">>
},
<<"verify_claims">> => #{<<"foo">> => <<"${username}">>}
},
{ok, Config} = hocon:binary(hocon_pp:do(Config0, #{})),
ChainName = 'mqtt:global',
AuthenticatorId = <<"jwt">>,
?check_trace(
#{timetrap => 10_000},
begin
%% bad header keys
BadConfig1 = emqx_utils_maps:deep_put(
[<<"headers">>, <<"ça-va"/utf8>>], Config, <<"bien">>
),
?assertMatch(
{error, #{
kind := validation_error,
reason := <<"headers should contain only characters matching ", _/binary>>
}},
emqx_authn_api:update_config(
[authentication],
{create_authenticator, ChainName, BadConfig1}
)
),
BadConfig2 = emqx_utils_maps:deep_put(
[<<"headers">>, <<"test_哈哈"/utf8>>],
Config,
<<"test_haha">>
),
?assertMatch(
{error, #{
kind := validation_error,
reason := <<"headers should contain only characters matching ", _/binary>>
}},
emqx_authn_api:update_config(
[authentication],
{create_authenticator, ChainName, BadConfig2}
)
),
{{ok, _}, {ok, _}} =
?wait_async_action(
emqx_authn_api:update_config(
[authentication],
{create_authenticator, ChainName, Config}
),
#{?snk_kind := jwks_endpoint_response},
5_000
),
?assertReceive(
{http_request, #{
headers := #{
<<"accept">> := <<"application/json">>,
<<"foo">> := <<"bar">>
}
}}
),
{ok, _} = emqx_authn_api:update_config(
[authentication],
{delete_authenticator, ChainName, AuthenticatorId}
),
ok
end,
[]
),
ok.
t_verify_claims(_) -> t_verify_claims(_) ->
Secret = <<"abcdef">>, Secret = <<"abcdef">>,
Config0 = #{ Config0 = #{
@ -469,6 +573,16 @@ jwks_handler(Req0, State) ->
), ),
{ok, Req, State}. {ok, Req, State}.
jwks_handler_spy() ->
TestPid = self(),
fun(Req, State) ->
ReqHeaders = cowboy_req:headers(Req),
ReqMap = #{headers => ReqHeaders},
ct:pal("jwks request:\n ~p", [ReqMap]),
TestPid ! {http_request, ReqMap},
jwks_handler(Req, State)
end.
test_rsa_key(public) -> test_rsa_key(public) ->
data_file("public_key.pem"); data_file("public_key.pem");
test_rsa_key(private) -> test_rsa_key(private) ->

View File

@ -0,0 +1 @@
Added the option to add custom request headers to JWKS requests.

View File

@ -145,4 +145,9 @@ disconnect_after_expire.desc:
disconnect_after_expire.label: disconnect_after_expire.label:
"""Disconnect After Expire""" """Disconnect After Expire"""
jwks_headers.label:
"""HTTP Headers"""
jwks_headers.desc:
"""List of HTTP headers to send with the JWKS request."""
} }