From 811184ddad9a14ec4956e8d96ccb1c45ce61c743 Mon Sep 17 00:00:00 2001 From: Thales Macedo Garitezi Date: Mon, 8 Jul 2024 13:29:21 -0300 Subject: [PATCH] feat(jwks): allow specifying custom request headers Fixes https://emqx.atlassian.net/browse/EMQX-12655 --- apps/emqx/src/emqx_schema.erl | 4 +- .../src/emqx_authn_jwks_client.erl | 15 ++- .../src/emqx_authn_jwt_schema.erl | 32 +++++ .../test/emqx_authn_jwt_SUITE.erl | 116 +++++++++++++++++- changes/ce/feat-13436.en.md | 1 + rel/i18n/emqx_authn_jwt_schema.hocon | 5 + 6 files changed, 170 insertions(+), 3 deletions(-) create mode 100644 changes/ce/feat-13436.en.md diff --git a/apps/emqx/src/emqx_schema.erl b/apps/emqx/src/emqx_schema.erl index b6be28d21..d639523bb 100644 --- a/apps/emqx/src/emqx_schema.erl +++ b/apps/emqx/src/emqx_schema.erl @@ -63,6 +63,7 @@ -type json_binary() :: binary(). -type template() :: binary(). -type template_str() :: string(). +-type binary_kv() :: #{binary() => binary()}. -typerefl_from_string({duration/0, emqx_schema, to_duration}). -typerefl_from_string({duration_s/0, emqx_schema, to_duration_s}). @@ -167,7 +168,8 @@ json_binary/0, port_number/0, template/0, - template_str/0 + template_str/0, + binary_kv/0 ]). -export([namespace/0, roots/0, roots/1, fields/1, desc/1, tags/0]). diff --git a/apps/emqx_auth_jwt/src/emqx_authn_jwks_client.erl b/apps/emqx_auth_jwt/src/emqx_authn_jwks_client.erl index dc03cafef..f574a421f 100644 --- a/apps/emqx_auth_jwt/src/emqx_authn_jwks_client.erl +++ b/apps/emqx_auth_jwt/src/emqx_authn_jwks_client.erl @@ -133,11 +133,13 @@ code_change(_OldVsn, State, _Extra) -> handle_options(#{ endpoint := Endpoint, + headers := Headers, refresh_interval := RefreshInterval0, ssl_opts := SSLOpts }) -> #{ endpoint => Endpoint, + headers => to_httpc_headers(Headers), refresh_interval => limit_refresh_interval(RefreshInterval0), ssl_opts => maps:to_list(SSLOpts), jwks => [], @@ -147,6 +149,7 @@ handle_options(#{ refresh_jwks( #{ endpoint := Endpoint, + headers := Headers, ssl_opts := SSLOpts } = State ) -> @@ -159,7 +162,7 @@ refresh_jwks( case httpc:request( get, - {Endpoint, [{"Accept", "application/json"}]}, + {Endpoint, Headers}, HTTPOpts, [{body_format, binary}, {sync, false}, {receiver, self()}] ) @@ -185,6 +188,9 @@ limit_refresh_interval(Interval) when Interval < 10 -> limit_refresh_interval(Interval) -> Interval. +to_httpc_headers(Headers) -> + [{binary_to_list(bin(K)), V} || {K, V} <- maps:to_list(Headers)]. + cancel_http_request(#{request_id := undefined} = State) -> State; cancel_http_request(#{request_id := RequestID} = State) -> @@ -195,3 +201,10 @@ cancel_http_request(#{request_id := RequestID} = State) -> ok end, State#{request_id => undefined}. + +bin(List) when is_list(List) -> + unicode:characters_to_binary(List, utf8); +bin(Atom) when is_atom(Atom) -> + erlang:atom_to_binary(Atom); +bin(Bin) when is_binary(Bin) -> + Bin. diff --git a/apps/emqx_auth_jwt/src/emqx_authn_jwt_schema.erl b/apps/emqx_auth_jwt/src/emqx_authn_jwt_schema.erl index aff0f12c7..2c49121f9 100644 --- a/apps/emqx_auth_jwt/src/emqx_authn_jwt_schema.erl +++ b/apps/emqx_auth_jwt/src/emqx_authn_jwt_schema.erl @@ -95,6 +95,15 @@ fields(jwt_jwks) -> [ {use_jwks, sc(hoconsc:enum([true]), #{required => true, desc => ?DESC(use_jwks)})}, {endpoint, fun endpoint/1}, + {headers, + sc( + typerefl:alias("map", emqx_schema:binary_kv()), + #{ + default => #{<<"Accept">> => <<"application/json">>}, + validator => fun validate_headers/1, + desc => ?DESC("jwks_headers") + } + )}, {pool_size, fun emqx_connector_schema_lib:pool_size/1}, {refresh_interval, fun refresh_interval/1}, {ssl, #{ @@ -225,3 +234,26 @@ to_binary(B) when is_binary(B) -> B. sc(Type, Meta) -> hoconsc:mk(Type, Meta). + +validate_headers(undefined) -> + ok; +validate_headers(Headers) -> + BadKeys0 = + lists:filter( + fun(K) -> + re:run(K, <<"[^-0-9a-zA-Z_ ]">>, [{capture, none}]) =:= match + end, + maps:keys(Headers) + ), + case BadKeys0 of + [] -> + ok; + _ -> + BadKeys = lists:join(", ", BadKeys0), + Msg0 = io_lib:format( + "headers should contain only characters matching [-0-9a-zA-Z_ ]; bad headers: ~s", + [BadKeys] + ), + Msg = iolist_to_binary(Msg0), + {error, Msg} + end. diff --git a/apps/emqx_auth_jwt/test/emqx_authn_jwt_SUITE.erl b/apps/emqx_auth_jwt/test/emqx_authn_jwt_SUITE.erl index 8bf0cc68a..ef6efcdf2 100644 --- a/apps/emqx_auth_jwt/test/emqx_authn_jwt_SUITE.erl +++ b/apps/emqx_auth_jwt/test/emqx_authn_jwt_SUITE.erl @@ -22,18 +22,21 @@ -include_lib("common_test/include/ct.hrl"). -include_lib("eunit/include/eunit.hrl"). -include_lib("snabbkaffe/include/snabbkaffe.hrl"). +-include_lib("emqx/include/asserts.hrl"). -define(AUTHN_ID, <<"mechanism:jwt">>). -define(JWKS_PORT, 31333). -define(JWKS_PATH, "/jwks.json"). +-import(emqx_common_test_helpers, [on_exit/1]). + all() -> emqx_common_test_helpers:all(?MODULE). init_per_suite(Config) -> Apps = emqx_cth_suite:start([emqx, emqx_conf, emqx_auth, emqx_auth_jwt], #{ - work_dir => ?config(priv_dir, Config) + work_dir => emqx_cth_suite:work_dir(Config) }), [{apps, Apps} | Config]. @@ -41,6 +44,10 @@ end_per_suite(Config) -> ok = emqx_cth_suite:stop(?config(apps, Config)), ok. +end_per_testcase(_TestCase, _Config) -> + emqx_common_test_helpers:call_janitor(), + ok. + %%------------------------------------------------------------------------------ %% Tests %%------------------------------------------------------------------------------ @@ -244,6 +251,7 @@ t_jwks_renewal(_Config) -> disconnect_after_expire => false, use_jwks => true, endpoint => "https://127.0.0.1:" ++ integer_to_list(?JWKS_PORT + 1) ++ ?JWKS_PATH, + headers => #{<<"Accept">> => <<"application/json">>}, refresh_interval => 1000, pool_size => 1 }, @@ -328,6 +336,102 @@ t_jwks_renewal(_Config) -> ?assertEqual(ok, emqx_authn_jwt:destroy(State2)), ok = emqx_authn_http_test_server:stop(). +t_jwks_custom_headers(_Config) -> + {ok, _} = emqx_authn_http_test_server:start_link(?JWKS_PORT, ?JWKS_PATH, server_ssl_opts()), + on_exit(fun() -> ok = emqx_authn_http_test_server:stop() end), + ok = emqx_authn_http_test_server:set_handler(jwks_handler_spy()), + + PrivateKey = test_rsa_key(private), + Payload = #{ + <<"username">> => <<"myuser">>, + <<"foo">> => <<"myuser">>, + <<"exp">> => erlang:system_time(second) + 10 + }, + Endpoint = iolist_to_binary("https://127.0.0.1:" ++ integer_to_list(?JWKS_PORT) ++ ?JWKS_PATH), + Config0 = #{ + <<"mechanism">> => <<"jwt">>, + <<"use_jwks">> => true, + <<"from">> => <<"password">>, + <<"endpoint">> => Endpoint, + <<"headers">> => #{ + <<"Accept">> => <<"application/json">>, + <<"Content-Type">> => <<>>, + <<"foo">> => <<"bar">> + }, + <<"pool_size">> => 1, + <<"refresh_interval">> => 1_000, + <<"ssl">> => #{ + <<"keyfile">> => cert_file("client.key"), + <<"certfile">> => cert_file("client.crt"), + <<"cacertfile">> => cert_file("ca.crt"), + <<"enable">> => true, + <<"verify">> => <<"verify_peer">>, + <<"server_name_indication">> => <<"authn-server">> + }, + <<"verify_claims">> => #{<<"foo">> => <<"${username}">>} + }, + {ok, Config} = hocon:binary(hocon_pp:do(Config0, #{})), + ChainName = 'mqtt:global', + AuthenticatorId = <<"jwt">>, + ?check_trace( + #{timetrap => 10_000}, + begin + %% bad header keys + BadConfig1 = emqx_utils_maps:deep_put( + [<<"headers">>, <<"ça-va"/utf8>>], Config, <<"bien">> + ), + ?assertMatch( + {error, #{ + kind := validation_error, + reason := <<"headers should contain only characters matching ", _/binary>> + }}, + emqx_authn_api:update_config( + [authentication], + {create_authenticator, ChainName, BadConfig1} + ) + ), + BadConfig2 = emqx_utils_maps:deep_put( + [<<"headers">>, <<"test_哈哈"/utf8>>], + Config, + <<"test_haha">> + ), + ?assertMatch( + {error, #{ + kind := validation_error, + reason := <<"headers should contain only characters matching ", _/binary>> + }}, + emqx_authn_api:update_config( + [authentication], + {create_authenticator, ChainName, BadConfig2} + ) + ), + {{ok, _}, {ok, _}} = + ?wait_async_action( + emqx_authn_api:update_config( + [authentication], + {create_authenticator, ChainName, Config} + ), + #{?snk_kind := jwks_endpoint_response}, + 5_000 + ), + ?assertReceive( + {http_request, #{ + headers := #{ + <<"accept">> := <<"application/json">>, + <<"foo">> := <<"bar">> + } + }} + ), + {ok, _} = emqx_authn_api:update_config( + [authentication], + {delete_authenticator, ChainName, AuthenticatorId} + ), + ok + end, + [] + ), + ok. + t_verify_claims(_) -> Secret = <<"abcdef">>, Config0 = #{ @@ -469,6 +573,16 @@ jwks_handler(Req0, State) -> ), {ok, Req, State}. +jwks_handler_spy() -> + TestPid = self(), + fun(Req, State) -> + ReqHeaders = cowboy_req:headers(Req), + ReqMap = #{headers => ReqHeaders}, + ct:pal("jwks request:\n ~p", [ReqMap]), + TestPid ! {http_request, ReqMap}, + jwks_handler(Req, State) + end. + test_rsa_key(public) -> data_file("public_key.pem"); test_rsa_key(private) -> diff --git a/changes/ce/feat-13436.en.md b/changes/ce/feat-13436.en.md new file mode 100644 index 000000000..cda52a137 --- /dev/null +++ b/changes/ce/feat-13436.en.md @@ -0,0 +1 @@ +Added the option to add custom request headers to JWKS requests. diff --git a/rel/i18n/emqx_authn_jwt_schema.hocon b/rel/i18n/emqx_authn_jwt_schema.hocon index a7a0aad09..aadab1b68 100644 --- a/rel/i18n/emqx_authn_jwt_schema.hocon +++ b/rel/i18n/emqx_authn_jwt_schema.hocon @@ -145,4 +145,9 @@ disconnect_after_expire.desc: disconnect_after_expire.label: """Disconnect After Expire""" +jwks_headers.label: +"""HTTP Headers""" +jwks_headers.desc: +"""List of HTTP headers to send with the JWKS request.""" + }