From f297c36929815a0a178fd6960c3672a3156c0e7b Mon Sep 17 00:00:00 2001 From: zhouzb Date: Tue, 8 Jun 2021 11:40:46 +0800 Subject: [PATCH] feat(jwt auth): support JWT auth --- .../src/emqx_authentication.erl | 1 - .../src/emqx_authentication_api.erl | 82 ++-- .../emqx_authentication_jwks_connector.erl | 171 ++++++++ .../src/emqx_authentication_jwt.erl | 391 ++++++++++++++++++ .../src/emqx_authentication_mnesia.erl | 2 - 5 files changed, 603 insertions(+), 44 deletions(-) create mode 100644 apps/emqx_authentication/src/emqx_authentication_jwks_connector.erl create mode 100644 apps/emqx_authentication/src/emqx_authentication_jwt.erl diff --git a/apps/emqx_authentication/src/emqx_authentication.erl b/apps/emqx_authentication/src/emqx_authentication.erl index c41ca9249..90a234386 100644 --- a/apps/emqx_authentication/src/emqx_authentication.erl +++ b/apps/emqx_authentication/src/emqx_authentication.erl @@ -180,7 +180,6 @@ add_services(ChainID, ServiceParams) -> ok -> case create_services(ChainID, NServiceParams) of {ok, NServices} -> - io:format("~p~n", [NServices]), NChain = Chain#chain{services = Services ++ NServices}, ok = mnesia:write(?CHAIN_TAB, NChain, write), {ok, serialize_services(NServices)}; diff --git a/apps/emqx_authentication/src/emqx_authentication_api.erl b/apps/emqx_authentication/src/emqx_authentication_api.erl index f43e5c0df..74887a0b2 100644 --- a/apps/emqx_authentication/src/emqx_authentication_api.erl +++ b/apps/emqx_authentication/src/emqx_authentication_api.erl @@ -45,14 +45,14 @@ -rest_api(#{name => delete_chain, method => 'DELETE', - path => "/authentication/chains/:bin:chain_id", + path => "/authentication/chains/:bin:id", func => delete_chain, descr => "Delete chain" }). -rest_api(#{name => lookup_chain, method => 'GET', - path => "/authentication/chains/:bin:chain_id", + path => "/authentication/chains/:bin:id", func => lookup_chain, descr => "Lookup chain" }). @@ -66,77 +66,77 @@ -rest_api(#{name => add_service, method => 'POST', - path => "/authentication/chains/:bin:chain_id/services", + path => "/authentication/chains/:bin:id/services", func => add_service, descr => "Add service to chain" }). -rest_api(#{name => delete_service, method => 'DELETE', - path => "/authentication/chains/:bin:chain_id/services/:bin:service_name", + path => "/authentication/chains/:bin:id/services/:bin:service_name", func => delete_service, descr => "Delete service from chain" }). -rest_api(#{name => update_service, method => 'PUT', - path => "/authentication/chains/:bin:chain_id/services/:bin:service_name", + path => "/authentication/chains/:bin:id/services/:bin:service_name", func => update_service, descr => "Update service in chain" }). -rest_api(#{name => lookup_service, method => 'GET', - path => "/authentication/chains/:bin:chain_id/services/:bin:service_name", + path => "/authentication/chains/:bin:id/services/:bin:service_name", func => lookup_service, descr => "Lookup service in chain" }). -rest_api(#{name => list_services, method => 'GET', - path => "/authentication/chains/:bin:chain_id/services", + path => "/authentication/chains/:bin:id/services", func => list_services, descr => "List services in chain" }). -rest_api(#{name => move_service, method => 'POST', - path => "/authentication/chains/:bin:chain_id/services/:bin:service_name/position", + path => "/authentication/chains/:bin:id/services/:bin:service_name/position", func => move_service, descr => "Change the order of services" }). -rest_api(#{name => import_users, method => 'POST', - path => "/authentication/chains/:bin:chain_id/services/:bin:service_name/import-users", + path => "/authentication/chains/:bin:id/services/:bin:service_name/import-users", func => import_users, descr => "Import users" }). -rest_api(#{name => add_user, method => 'POST', - path => "/authentication/chains/:bin:chain_id/services/:bin:service_name/users", + path => "/authentication/chains/:bin:id/services/:bin:service_name/users", func => add_user, descr => "Add user" }). -rest_api(#{name => delete_user, method => 'DELETE', - path => "/authentication/chains/:bin:chain_id/services/:bin:service_name/users/:bin:user_id", + path => "/authentication/chains/:bin:id/services/:bin:service_name/users/:bin:user_id", func => delete_user, descr => "Delete user" }). -rest_api(#{name => update_user, method => 'PUT', - path => "/authentication/chains/:bin:chain_id/services/:bin:service_name/users/:bin:user_id", + path => "/authentication/chains/:bin:id/services/:bin:service_name/users/:bin:user_id", func => update_user, descr => "Update user" }). -rest_api(#{name => lookup_user, method => 'GET', - path => "/authentication/chains/:bin:chain_id/services/:bin:service_name/users/:bin:user_id", + path => "/authentication/chains/:bin:id/services/:bin:service_name/users/:bin:user_id", func => lookup_user, descr => "Lookup user" }). @@ -144,7 +144,7 @@ %% TODO: Support pagination -rest_api(#{name => list_users, method => 'GET', - path => "/authentication/chains/:bin:chain_id/services/:bin:service_name/users", + path => "/authentication/chains/:bin:id/services/:bin:service_name/users", func => list_users, descr => "List all users" }). @@ -152,20 +152,20 @@ create_chain(Binding, Params) -> do_create_chain(uri_decode(Binding), maps:from_list(Params)). -do_create_chain(_Binding, #{<<"chain_id">> := ChainID}) -> - case emqx_authentication:create_chain(ChainID) of +do_create_chain(_Binding, #{<<"id">> := ChainID}) -> + case emqx_authentication:create_chain(#{id => ChainID}) of {ok, Chain} -> return({ok, Chain}); {error, Reason} -> return(serialize_error(Reason)) end; do_create_chain(_Binding, _Params) -> - return(serialize_error({missing_parameter, chain_id})). + return(serialize_error({missing_parameter, id})). delete_chain(Binding, Params) -> do_delete_chain(uri_decode(Binding), maps:from_list(Params)). -do_delete_chain(#{chain_id := ChainID}, _Params) -> +do_delete_chain(#{id := ChainID}, _Params) -> case emqx_authentication:delete_chain(ChainID) of ok -> return(ok); @@ -176,7 +176,7 @@ do_delete_chain(#{chain_id := ChainID}, _Params) -> lookup_chain(Binding, Params) -> do_lookup_chain(uri_decode(Binding), maps:from_list(Params)). -do_lookup_chain(#{chain_id := ChainID}, _Params) -> +do_lookup_chain(#{id := ChainID}, _Params) -> case emqx_authentication:lookup_chain(ChainID) of {ok, Chain} -> return({ok, Chain}); @@ -194,9 +194,9 @@ do_list_chains(_Binding, _Params) -> add_service(Binding, Params) -> do_add_service(uri_decode(Binding), maps:from_list(Params)). -do_add_service(#{chain_id := ChainID}, #{<<"name">> := Name, - <<"type">> := Type, - <<"params">> := Params}) -> +do_add_service(#{id := ChainID}, #{<<"name">> := Name, + <<"type">> := Type, + <<"params">> := Params}) -> case emqx_authentication:add_services(ChainID, [#{name => Name, type => binary_to_existing_atom(Type, utf8), params => maps:from_list(Params)}]) of @@ -213,7 +213,7 @@ do_add_service(_Binding, Params) -> delete_service(Binding, Params) -> do_delete_service(uri_decode(Binding), maps:from_list(Params)). -do_delete_service(#{chain_id := ChainID, +do_delete_service(#{id := ChainID, service_name := ServiceName}, _Params) -> case emqx_authentication:delete_services(ChainID, [ServiceName]) of ok -> @@ -225,8 +225,8 @@ do_delete_service(#{chain_id := ChainID, update_service(Binding, Params) -> do_update_service(uri_decode(Binding), maps:from_list(Params)). -%% TOOD: PUT 方法支持创建和更新 -do_update_service(#{chain_id := ChainID, +%% TOOD: PUT method supports creation and update +do_update_service(#{id := ChainID, service_name := ServiceName}, Params) -> case emqx_authentication:update_service(ChainID, ServiceName, Params) of {ok, Service} -> @@ -238,7 +238,7 @@ do_update_service(#{chain_id := ChainID, lookup_service(Binding, Params) -> do_lookup_service(uri_decode(Binding), maps:from_list(Params)). -do_lookup_service(#{chain_id := ChainID, +do_lookup_service(#{id := ChainID, service_name := ServiceName}, _Params) -> case emqx_authentication:lookup_service(ChainID, ServiceName) of {ok, Service} -> @@ -250,7 +250,7 @@ do_lookup_service(#{chain_id := ChainID, list_services(Binding, Params) -> do_list_services(uri_decode(Binding), maps:from_list(Params)). -do_list_services(#{chain_id := ChainID}, _Params) -> +do_list_services(#{id := ChainID}, _Params) -> case emqx_authentication:list_services(ChainID) of {ok, Services} -> return({ok, Services}); @@ -261,7 +261,7 @@ do_list_services(#{chain_id := ChainID}, _Params) -> move_service(Binding, Params) -> do_move_service(uri_decode(Binding), maps:from_list(Params)). -do_move_service(#{chain_id := ChainID, +do_move_service(#{id := ChainID, service_name := ServiceName}, #{<<"position">> := <<"the front">>}) -> case emqx_authentication:move_service_to_the_front(ChainID, ServiceName) of ok -> @@ -269,7 +269,7 @@ do_move_service(#{chain_id := ChainID, {error, Reason} -> return(serialize_error(Reason)) end; -do_move_service(#{chain_id := ChainID, +do_move_service(#{id := ChainID, service_name := ServiceName}, #{<<"position">> := <<"the end">>}) -> case emqx_authentication:move_service_to_the_end(ChainID, ServiceName) of ok -> @@ -277,7 +277,7 @@ do_move_service(#{chain_id := ChainID, {error, Reason} -> return(serialize_error(Reason)) end; -do_move_service(#{chain_id := ChainID, +do_move_service(#{id := ChainID, service_name := ServiceName}, #{<<"position">> := N}) when is_number(N) -> case emqx_authentication:move_service_to_the_nth(ChainID, ServiceName, N) of ok -> @@ -291,7 +291,7 @@ do_move_service(_Binding, _Params) -> import_users(Binding, Params) -> do_import_users(uri_decode(Binding), maps:from_list(Params)). -do_import_users(#{chain_id := ChainID, service_name := ServiceName}, +do_import_users(#{id := ChainID, service_name := ServiceName}, #{<<"filename">> := Filename}) -> case emqx_authentication:import_users(ChainID, ServiceName, Filename) of ok -> @@ -306,7 +306,7 @@ do_import_users(_Binding, Params) -> add_user(Binding, Params) -> do_add_user(uri_decode(Binding), maps:from_list(Params)). -do_add_user(#{chain_id := ChainID, +do_add_user(#{id := ChainID, service_name := ServiceName}, UserInfo) -> case emqx_authentication:add_user(ChainID, ServiceName, UserInfo) of {ok, User} -> @@ -318,7 +318,7 @@ do_add_user(#{chain_id := ChainID, delete_user(Binding, Params) -> do_delete_user(uri_decode(Binding), maps:from_list(Params)). -do_delete_user(#{chain_id := ChainID, +do_delete_user(#{id := ChainID, service_name := ServiceName, user_id := UserID}, _Params) -> case emqx_authentication:delete_user(ChainID, ServiceName, UserID) of @@ -331,7 +331,7 @@ do_delete_user(#{chain_id := ChainID, update_user(Binding, Params) -> do_update_user(uri_decode(Binding), maps:from_list(Params)). -do_update_user(#{chain_id := ChainID, +do_update_user(#{id := ChainID, service_name := ServiceName, user_id := UserID}, NewUserInfo) -> case emqx_authentication:update_user(ChainID, ServiceName, UserID, NewUserInfo) of @@ -344,7 +344,7 @@ do_update_user(#{chain_id := ChainID, lookup_user(Binding, Params) -> do_lookup_user(uri_decode(Binding), maps:from_list(Params)). -do_lookup_user(#{chain_id := ChainID, +do_lookup_user(#{id := ChainID, service_name := ServiceName, user_id := UserID}, _Params) -> case emqx_authentication:lookup_user(ChainID, ServiceName, UserID) of @@ -357,7 +357,7 @@ do_lookup_user(#{chain_id := ChainID, list_users(Binding, Params) -> do_list_users(uri_decode(Binding), maps:from_list(Params)). -do_list_users(#{chain_id := ChainID, +do_list_users(#{id := ChainID, service_name := ServiceName}, _Params) -> case emqx_authentication:list_users(ChainID, ServiceName) of {ok, Users} -> @@ -376,17 +376,17 @@ uri_decode(Params) -> end, #{}, Params). serialize_error({already_exists, {Type, ID}}) -> - {error, <<"ALREADY_EXISTS">>, list_to_binary(io_lib:format("~p ~p already exists", [serialize_type(Type), ID]))}; + {error, <<"ALREADY_EXISTS">>, list_to_binary(io_lib:format("~s '~s' already exists", [serialize_type(Type), ID]))}; serialize_error({not_found, {Type, ID}}) -> - {error, <<"NOT_FOUND">>, list_to_binary(io_lib:format("~p ~p not found", [serialize_type(Type), ID]))}; + {error, <<"NOT_FOUND">>, list_to_binary(io_lib:format("~s '~s' not found", [serialize_type(Type), ID]))}; serialize_error({duplicate, Name}) -> - {error, <<"INVALID_PARAMETER">>, list_to_binary(io_lib:format("Service name ~p is duplicated", [Name]))}; + {error, <<"INVALID_PARAMETER">>, list_to_binary(io_lib:format("Service name '~s' is duplicated", [Name]))}; serialize_error({missing_parameter, Names = [_ | Rest]}) -> - Format = ["~p," || _ <- Rest] ++ ["~p"], + Format = ["~s," || _ <- Rest] ++ ["~s"], NFormat = binary_to_list(iolist_to_binary(Format)), {error, <<"MISSING_PARAMETER">>, list_to_binary(io_lib:format("The input parameters " ++ NFormat ++ " that are mandatory for processing this request are not supplied.", Names))}; serialize_error({missing_parameter, Name}) -> - {error, <<"MISSING_PARAMETER">>, list_to_binary(io_lib:format("The input parameter ~p that is mandatory for processing this request is not supplied.", [Name]))}; + {error, <<"MISSING_PARAMETER">>, list_to_binary(io_lib:format("The input parameter '~s' that is mandatory for processing this request is not supplied.", [Name]))}; serialize_error(_) -> {error, <<"UNKNOWN_ERROR">>, <<"Unknown error">>}. diff --git a/apps/emqx_authentication/src/emqx_authentication_jwks_connector.erl b/apps/emqx_authentication/src/emqx_authentication_jwks_connector.erl new file mode 100644 index 000000000..9dafc9f5e --- /dev/null +++ b/apps/emqx_authentication/src/emqx_authentication_jwks_connector.erl @@ -0,0 +1,171 @@ +%%-------------------------------------------------------------------- +%% Copyright (c) 2021 EMQ Technologies Co., Ltd. All Rights Reserved. +%% +%% Licensed under the Apache License, Version 2.0 (the "License"); +%% you may not use this file except in compliance with the License. +%% You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, software +%% distributed under the License is distributed on an "AS IS" BASIS, +%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +%% See the License for the specific language governing permissions and +%% limitations under the License. +%%-------------------------------------------------------------------- + +-module(emqx_authentication_jwks_connector). + +-behaviour(gen_server). + +-include_lib("emqx/include/logger.hrl"). +-include_lib("jose/include/jose_jwk.hrl"). + +-export([ start_link/1 + , stop/1 + ]). + +-export([ get_jwks/1 + , update/2 + ]). + +%% gen_server callbacks +-export([ init/1 + , handle_call/3 + , handle_cast/2 + , handle_info/2 + , terminate/2 + , code_change/3 + ]). + +%%-------------------------------------------------------------------- +%% APIs +%%-------------------------------------------------------------------- + +start_link(Opts) -> + gen_server:start_link(?MODULE, [Opts], []). + +stop(Pid) -> + gen_server:stop(Pid). + +get_jwks(Pid) -> + gen_server:call(Pid, get_cached_jwks, 5000). + +update(Pid, Opts) -> + gen_server:call(Pid, {update, Opts}, 5000). + +%%-------------------------------------------------------------------- +%% gen_server callbacks +%%-------------------------------------------------------------------- + +init([Opts]) -> + ok = jose:json_module(jiffy), + State = handle_options(Opts), + {ok, refresh_jwks(State)}. + +handle_call(get_cached_jwks, _From, #{jwks := Jwks} = State) -> + {reply, {ok, Jwks}, State}; + +handle_call({update, Opts}, _From, State) -> + State = handle_options(Opts), + {reply, ok, refresh_jwks(State)}; + +handle_call(_Req, _From, State) -> + {reply, ok, State}. + +handle_cast(_Msg, State) -> + {noreply, State}. + +handle_info({refresh_jwks, _TRef, refresh}, #{request_id := RequestID} = State) -> + case RequestID of + undefined -> ok; + _ -> + ok = httpc:cancel_request(RequestID), + receive + {http, _} -> ok + after 0 -> + ok + end + end, + {noreply, refresh_jwks(State)}; + +handle_info({http, {RequestID, Result}}, + #{request_id := RequestID, endpoint := Endpoint} = State0) -> + State1 = State0#{request_id := undefined}, + case Result of + {error, Reason} -> + ?LOG(error, "Failed to request jwks endpoint(~s): ~p", [Endpoint, Reason]), + State1; + {_StatusLine, _Headers, Body} -> + try + JWKS = jose_jwk:from(emqx_json:decode(Body, [return_maps])), + {_, JWKs} = JWKS#jose_jwk.keys, + State1#{jwks := JWKs} + catch _:_ -> + ?LOG(error, "Invalid jwks returned from jwks endpoint(~s): ~p~n", [Endpoint, Body]), + State1 + end + end; + +handle_info({http, {_, _}}, State) -> + %% ignore + {noreply, State}; + +handle_info(_Info, State) -> + {noreply, State}. + +terminate(_Reason, State) -> + _ = cancel_timer(State), + ok. + +code_change(_OldVsn, State, _Extra) -> + {ok, State}. + +%%-------------------------------------------------------------------- +%% Internal functions +%%-------------------------------------------------------------------- + +handle_options(Opts) -> + #{endpoint => proplists:get_value(jwks_endpoint, Opts), + refresh_interval => limit_refresh_interval(proplists:get_value(refresh_interval, Opts)), + ssl_opts => get_ssl_opts(Opts), + jwks => [], + request_id => undefined}. + +get_ssl_opts(Opts) -> + case proplists:get_value(enable_ssl, Opts) of + false -> []; + true -> + maps:to_list(maps:with([cacertfile, + keyfile, + certfile, + verify, + server_name_indication], maps:from_list(Opts))) + end. + +refresh_jwks(#{endpoint := Endpoint, + ssl_opts := SSLOpts} = State) -> + HTTPOpts = [{timeout, 5000}, {connect_timeout, 5000}, {ssl, SSLOpts}], + NState = case httpc:request(get, {Endpoint, [{"Accept", "application/json"}]}, HTTPOpts, + [{body_format, binary}, {sync, false}, {receiver, self()}]) of + {error, Reason} -> + ?LOG(error, "Failed to request jwks endpoint(~s): ~p", [Endpoint, Reason]), + State; + {ok, RequestID} -> + State#{request_id := RequestID} + end, + ensure_expiry_timer(NState). + +ensure_expiry_timer(State = #{refresh_interval := Interval}) -> + State#{refresh_timer := emqx_misc:start_timer(timer:seconds(Interval), refresh_jwks)}. + +cancel_timer(State = #{refresh_timer := undefined}) -> + State; +cancel_timer(State = #{refresh_timer := TRef}) -> + _ = emqx_misc:cancel_timer(TRef), + State#{refresh_timer := undefined}. + +limit_refresh_interval(Interval) when Interval < 10 -> + 10; +limit_refresh_interval(Interval) -> + Interval. diff --git a/apps/emqx_authentication/src/emqx_authentication_jwt.erl b/apps/emqx_authentication/src/emqx_authentication_jwt.erl new file mode 100644 index 000000000..f80c90a36 --- /dev/null +++ b/apps/emqx_authentication/src/emqx_authentication_jwt.erl @@ -0,0 +1,391 @@ +%%-------------------------------------------------------------------- +%% Copyright (c) 2021 EMQ Technologies Co., Ltd. All Rights Reserved. +%% +%% Licensed under the Apache License, Version 2.0 (the "License"); +%% you may not use this file except in compliance with the License. +%% You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, software +%% distributed under the License is distributed on an "AS IS" BASIS, +%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +%% See the License for the specific language governing permissions and +%% limitations under the License. +%%-------------------------------------------------------------------- + +-module(emqx_authentication_jwt). + +-export([ create/3 + , update/4 + , authenticate/2 + , destroy/1 + ]). + +-service_type(#{ + name => jwt, + params_spec => #{ + use_jwks => #{ + order => 1, + type => boolean + }, + jwks_endpoint => #{ + order => 2, + type => string + }, + refresh_interval => #{ + order => 3, + type => number + }, + algorithm => #{ + order => 3, + type => string, + enum => [<<"hmac-based">>, <<"public-key">>] + }, + secret => #{ + order => 4, + type => string + }, + jwt_certfile => #{ + order => 5, + type => file + }, + cacertfile => #{ + order => 6, + type => file + }, + keyfile => #{ + order => 7, + type => file + }, + certfile => #{ + order => 8, + type => file + }, + verify => #{ + order => 9, + type => boolean + }, + server_name_indication => #{ + order => 10, + type => string + } + } +}). + +-define(RULES, + #{ + use_jwks => [], + jwks_endpoint => [use_jwks], + refresh_interval => [use_jwks], + algorithm => [use_jwks], + secret => [algorithm], + jwt_certfile => [algorithm], + cacertfile => [jwks_endpoint], + keyfile => [jwks_endpoint], + certfile => [jwks_endpoint], + verify => [jwks_endpoint], + server_name_indication => [jwks_endpoint], + verify_claims => [] + }). + +create(_ChainID, _ServiceName, Params) -> + try handle_options(Params) of + Opts -> + do_create(Opts) + catch + {error, Reason} -> + {error, Reason} + end. + +update(_ChainID, _ServiceName, Params, State) -> + try handle_options(Params) of + Opts -> + do_update(Opts, State) + catch + {error, Reason} -> + {error, Reason} + end. + +authenticate(ClientInfo = #{password := JWT}, #{jwk := JWK, + jwks_connector := Connector, + verify_claims := VerifyClaims0}) -> + JWKs = case Connector of + undefined -> + [JWK]; + _ -> + {ok, JWKs0} = emqx_authentication_jwks_connector:get_jwks(Connector), + JWKs0 + end, + VerifyClaims = replace_placeholder(VerifyClaims0, ClientInfo), + verify(JWT, JWKs, VerifyClaims). + +destroy(#{jwks_connector := undefined}) -> + ok; +destroy(#{jwks_connector := Connector}) -> + emqx_authentication_jwks_connector:stop(Connector), + ok. + +%%-------------------------------------------------------------------- +%% Internal functions +%%-------------------------------------------------------------------- + +do_create(#{use_jwks := false, + algorithm := 'hmac-based', + secret := Secret, + verify_claims := VerifyClaims}) -> + JWK = jose_jwk:from_oct(Secret), + {ok, #{jwk => JWK, + jwks_connector => undefined, + verify_claims => VerifyClaims}}; +do_create(#{use_jwks := false, + algorithm := 'public-key', + jwt_certfile := Certfile, + verify_claims := VerifyClaims}) -> + JWK = jose_jwk:from_pem_file(Certfile), + {ok, #{jwk => JWK, + jwks_connector => undefined, + verify_claims => VerifyClaims}}; +do_create(#{use_jwks := true, + verify_claims := VerifyClaims} = Opts) -> + case emqx_authentication_jwks_connector:start_link(Opts) of + {ok, Connector} -> + {ok, #{jwk => undefined, + jwks_connector => Connector, + verify_claims => VerifyClaims}}; + {error, Reason} -> + {error, Reason} + end. + +do_update(Opts, #{jwk_connector := undefined}) -> + do_create(Opts); +do_update(#{use_jwks := false} = Opts, #{jwk_connector := Connector}) -> + emqx_authentication_jwks_connector:stop(Connector), + do_create(Opts); +do_update(#{use_jwks := true} = Opts, #{jwk_connector := Connector} = State) -> + ok = emqx_authentication_jwks_connector:update(Connector, Opts), + {ok, State}. + +replace_placeholder(L, Variables) -> + replace_placeholder(L, Variables, []). + +replace_placeholder([], _Variables, Acc) -> + Acc; +replace_placeholder([{Name, {placeholder, PL}} | More], Variables, Acc) -> + Value = maps:get(PL, Variables), + replace_placeholder(More, Variables, [{Name, Value} | Acc]); +replace_placeholder([{Name, Value} | More], Variables, Acc) -> + replace_placeholder(More, Variables, [{Name, Value} | Acc]). + +verify(_JWS, [], _VerifyClaims) -> + {error, invalid_signature}; +verify(JWS, [JWK | More], VerifyClaims) -> + case jose_jws:verify(JWK, JWS) of + {true, Payload, _JWS} -> + Claims = emqx_json:decode(Payload, [return_maps]), + verify_claims(Claims, VerifyClaims); + {false, _, _} -> + verify(JWS, More, VerifyClaims) + end. + +verify_claims(Claims, VerifyClaims0) -> + Now = os:system_time(seconds), + VerifyClaims = [{<<"exp">>, fun(ExpireTime) -> + Now < ExpireTime + end}, + {<<"iat">>, fun(IssueAt) -> + IssueAt =< Now + end}, + {<<"nbf">>, fun(NotBefore) -> + NotBefore =< Now + end}] ++ VerifyClaims0, + do_verify_claims(Claims, VerifyClaims). + +do_verify_claims(_Claims, []) -> + ok; +do_verify_claims(Claims, [{Name, Fun} | More]) when is_function(Fun) -> + case maps:take(Name, Claims) of + error -> + do_verify_claims(Claims, More); + {Value, NClaims} -> + case Fun(Value) of + true -> + do_verify_claims(NClaims, More); + _ -> + {error, {claims, {Name, Value}}} + end + end; +do_verify_claims(Claims, [{Name, Value} | More]) -> + case maps:take(Name, Claims) of + error -> + do_verify_claims(Claims, More); + {Value, NClaims} -> + do_verify_claims(NClaims, More); + {Value0, _} -> + {error, {claims, {Name, Value0}}} + end. + +handle_options(Opts0) when is_map(Opts0) -> + Ks = maps:fold(fun(K, _, Acc) -> + [atom_to_binary(K, utf8) | Acc] + end, [], ?RULES), + Opts1 = maps:to_list(maps:with(Ks, Opts0)), + handle_options([{binary_to_existing_atom(K, utf8), V} || {K, V} <- Opts1]); + +handle_options(Opts0) when is_list(Opts0) -> + Opts1 = add_missing_options(Opts0), + process_options({Opts1, [], length(Opts1)}, #{}). + +add_missing_options(Opts) -> + AllOpts = maps:keys(?RULES), + Fun = fun(K, Acc) -> + case proplists:is_defined(K, Acc) of + true -> + Acc; + false -> + [{K, unbound} | Acc] + end + end, + lists:foldl(Fun, Opts, AllOpts). + +process_options({[], [], _}, OptsMap) -> + OptsMap; +process_options({[], Skipped, Counter}, OptsMap) + when length(Skipped) < Counter -> + process_options({Skipped, [], length(Skipped)}, OptsMap); +process_options({[], _Skipped, _Counter}, _OptsMap) -> + throw({error, faulty_configuration}); +process_options({[{K, V} = Opt | More], Skipped, Counter}, OptsMap0) -> + case check_dependencies(K, OptsMap0) of + true -> + OptsMap1 = handle_option(K, V, OptsMap0), + process_options({More, Skipped, Counter}, OptsMap1); + false -> + process_options({More, [Opt | Skipped], Counter}, OptsMap0) + end. + +%% TODO: This is not a particularly good implementation(K => needless), it needs to be improved +handle_option(use_jwks, true, OptsMap) -> + OptsMap#{use_jwks => true, + algorithm => needless}; +handle_option(use_jwks, false, OptsMap) -> + OptsMap#{use_jwks => false, + jwks_endpoint => needless}; +handle_option(jwks_endpoint = Opt, unbound, #{use_jwks := true}) -> + throw({error, {options, {Opt, unbound}}}); +handle_option(jwks_endpoint, Value, #{use_jwks := true} = OptsMap) + when Value =/= unbound -> + case emqx_http_lib:uri_parse(Value) of + {ok, #{scheme := http}} -> + OptsMap#{enable_ssl => false, + jwks_endpoint => Value}; + {ok, #{scheme := https}} -> + OptsMap#{enable_ssl => true, + jwks_endpoint => Value}; + {error, _Reason} -> + throw({error, {options, {jwks_endpoint, Value}}}) + end; +handle_option(refresh_interval = Opt, Value0, #{use_jwks := true} = OptsMap) -> + Value = validate_option(Opt, Value0), + OptsMap#{Opt => Value}; +handle_option(algorithm = Opt, Value0, #{use_jwks := false} = OptsMap) -> + Value = validate_option(Opt, Value0), + OptsMap#{Opt => Value}; +handle_option(secret = Opt, unbound, #{algorithm := 'hmac-based'}) -> + throw({error, {options, {Opt, unbound}}}); +handle_option(secret = Opt, Value, #{algorithm := 'hmac-based'} = OptsMap) -> + OptsMap#{Opt => Value}; +handle_option(jwt_certfile = Opt, unbound, #{algorithm := 'public-key'}) -> + throw({error, {options, {Opt, unbound}}}); +handle_option(jwt_certfile = Opt, Value, #{algorithm := 'public-key'} = OptsMap) -> + OptsMap#{Opt => Value}; +handle_option(verify = Opt, Value0, #{enable_ssl := true} = OptsMap) -> + Value = validate_option(Opt, Value0), + OptsMap#{Opt => Value}; +handle_option(cacertfile = Opt, Value, #{enable_ssl := true} = OptsMap) + when Value =/= unbound -> + OptsMap#{Opt => Value}; +handle_option(certfile, unbound, #{enable_ssl := true} = OptsMap) -> + OptsMap; +handle_option(certfile = Opt, Value, #{enable_ssl := true} = OptsMap) -> + OptsMap#{Opt => Value}; +handle_option(keyfile, unbound, #{enable_ssl := true} = OptsMap) -> + OptsMap; +handle_option(keyfile = Opt, Value, #{enable_ssl := true} = OptsMap) -> + OptsMap#{Opt => Value}; +handle_option(server_name_indication = Opt, Value0, #{enable_ssl := true} = OptsMap) -> + Value = validate_option(Opt, Value0), + OptsMap#{Opt => Value}; +handle_option(verify_claims = Opt, Value0, OptsMap) -> + Value = handle_verify_claims(Value0), + OptsMap#{Opt => Value}; +handle_option(_Opt, _Value, OptsMap) -> + OptsMap. + +validate_option(refresh_interval, unbound) -> + 300; +validate_option(refresh_interval, Value) when is_integer(Value) -> + Value; +validate_option(algorithm, <<"hmac-based">>) -> + 'hmac-based'; +validate_option(algorithm, <<"public-key">>) -> + 'public-key'; +validate_option(verify, unbound) -> + verify_none; +validate_option(verify, true) -> + verify_peer; +validate_option(verify, false) -> + verify_none; +validate_option(server_name_indication, unbound) -> + disable; +validate_option(server_name_indication, <<"disable">>) -> + disable; +validate_option(server_name_indication, Value) when is_list(Value) -> + Value; +validate_option(Opt, Value) -> + throw({error, {options, {Opt, Value}}}). + +handle_verify_claims(Opts0) -> + try handle_verify_claims(Opts0, []) + catch + error:_ -> + throw({error, {options, {verify_claims, Opts0}}}) + end. + +handle_verify_claims([], Acc) -> + Acc; +handle_verify_claims([{Name, Expected0} | More], Acc) + when is_binary(Name) andalso is_binary(Expected0) -> + Expected = handle_placeholder(Expected0), + handle_verify_claims(More, [{Name, Expected} | Acc]). + +handle_placeholder(Placeholder0) -> + case re:run(Placeholder0, "^\\$\\{[a-z0-9\\_]+\\}$", [{capture, all}]) of + {match, [{Offset, Length}]} -> + Placeholder1 = binary:part(Placeholder0, Offset + 2, Length - 3), + Placeholder2 = validate_placeholder(Placeholder1), + {placeholder, Placeholder2}; + nomatch -> + Placeholder0 + end. + +validate_placeholder(<<"clientid">>) -> + clientid; +validate_placeholder(<<"username">>) -> + username. + +check_dependencies(Opt, OptsMap) -> + case maps:get(Opt, ?RULES) of + [] -> + true; + Deps -> + option_already_defined(Opt, OptsMap) orelse + dependecies_already_defined(Deps, OptsMap) + end. + +option_already_defined(Opt, OptsMap) -> + maps:get(Opt, OptsMap, unbound) =/= unbound. + +dependecies_already_defined(Deps, OptsMap) -> + Fun = fun(Opt) -> option_already_defined(Opt, OptsMap) end, + lists:all(Fun, Deps). diff --git a/apps/emqx_authentication/src/emqx_authentication_mnesia.erl b/apps/emqx_authentication/src/emqx_authentication_mnesia.erl index 0fa28d178..d671a234d 100644 --- a/apps/emqx_authentication/src/emqx_authentication_mnesia.erl +++ b/apps/emqx_authentication/src/emqx_authentication_mnesia.erl @@ -52,8 +52,6 @@ } }). - - -record(user_info, { user_id :: {user_group(), user_id()} , password_hash :: binary()