diff --git a/apps/emqx/src/emqx_authentication.erl b/apps/emqx/src/emqx_authentication.erl index 9b249c8a7..7df2036ac 100644 --- a/apps/emqx/src/emqx_authentication.erl +++ b/apps/emqx/src/emqx_authentication.erl @@ -40,8 +40,10 @@ , stop/0 ]). --export([ add_provider/2 - , remove_provider/1 +-export([ register_provider/2 + , register_providers/1 + , deregister_provider/1 + , deregister_providers/1 , create_chain/1 , delete_chain/1 , lookup_chain/1 @@ -334,13 +336,27 @@ stop() -> get_refs() -> gen_server:call(?MODULE, get_refs). --spec add_provider(authn_type(), module()) -> ok. -add_provider(AuthNType, Provider) -> - gen_server:call(?MODULE, {add_provider, AuthNType, Provider}). +%% @doc Register authentication providers. +%% A provider is a tuple of `AuthNType' the module which implements +%% the authenticator callbacks. +%% For example, ``[{{'password-based', redis}, emqx_authn_redis}]'' +%% NOTE: Later registered provider may override earlier registered if they +%% happen to clash the same `AuthNType'. +-spec register_providers([{authn_type(), module()}]) -> ok. +register_providers(Providers) -> + gen_server:call(?MODULE, {register_providers, Providers}). --spec remove_provider(authn_type()) -> ok. -remove_provider(AuthNType) -> - gen_server:call(?MODULE, {remove_provider, AuthNType}). +-spec register_provider(authn_type(), module()) -> ok. +register_provider(AuthNType, Provider) -> + register_providers([{AuthNType, Provider}]). + +-spec deregister_providers([authn_type()]) -> ok. +deregister_providers(AuthNTypes) when is_list(AuthNTypes) -> + gen_server:call(?MODULE, {deregister_providers, AuthNTypes}). + +-spec deregister_provider(authn_type()) -> ok. +deregister_provider(AuthNType) -> + deregister_providers([AuthNType]). -spec create_chain(chain_name()) -> {ok, chain()} | {error, term()}. create_chain(Name) -> @@ -447,11 +463,20 @@ init(_Opts) -> ok = emqx_config_handler:add_handler([listeners, '?', '?', authentication], ?MODULE), {ok, #{hooked => false, providers => #{}}}. -handle_call({add_provider, AuthNType, Provider}, _From, #{providers := Providers} = State) -> - reply(ok, State#{providers := Providers#{AuthNType => Provider}}); +handle_call({register_providers, Providers}, _From, + #{providers := Reg0} = State) -> + case lists:filter(fun({T, _}) -> maps:is_key(T, Reg0) end, Providers) of + [] -> + Reg = lists:foldl(fun({AuthNType, Module}, Pin) -> + Pin#{AuthNType => Module} + end, Reg0, Providers), + reply(ok, State#{providers := Reg}); + Clashes -> + reply({error, {authentication_type_clash, Clashes}}, State) + end; -handle_call({remove_provider, AuthNType}, _From, #{providers := Providers} = State) -> - reply(ok, State#{providers := maps:remove(AuthNType, Providers)}); +handle_call({deregister_providers, AuthNTypes}, _From, #{providers := Providers} = State) -> + reply(ok, State#{providers := maps:without(AuthNTypes, Providers)}); handle_call(get_refs, _From, #{providers := Providers} = State) -> Refs = lists:foldl(fun({_, Provider}, Acc) -> diff --git a/apps/emqx/test/emqx_authentication_SUITE.erl b/apps/emqx/test/emqx_authentication_SUITE.erl index ff219e64c..15aecc269 100644 --- a/apps/emqx/test/emqx_authentication_SUITE.erl +++ b/apps/emqx/test/emqx_authentication_SUITE.erl @@ -36,6 +36,7 @@ ]). -define(AUTHN, emqx_authentication). +-define(config(KEY), (fun() -> {KEY, _V_} = lists:keyfind(KEY, 1, Config), _V_ end)()). %%------------------------------------------------------------------------------ %% Hocon Schema @@ -92,20 +93,22 @@ end_per_suite(_) -> emqx_ct_helpers:stop_apps([]), ok. -init_per_testcase(_, Config) -> +init_per_testcase(Case, Config) -> meck:new(emqx, [non_strict, passthrough, no_history, no_link]), meck:expect(emqx, get_config, fun([node, data_dir]) -> {data_dir, Data} = lists:keyfind(data_dir, 1, Config), Data; (C) -> meck:passthrough([C]) end), - Config. + ?MODULE:Case({'init', Config}). -end_per_testcase(_, _Config) -> +end_per_testcase(Case, Config) -> + _ = ?MODULE:Case({'end', Config}), meck:unload(emqx), ok. -t_chain(_) -> +t_chain({_, Config}) -> Config; +t_chain(Config) when is_list(Config) -> % CRUD of authentication chain ChainName = 'test', ?assertMatch({ok, []}, ?AUTHN:list_chains()), @@ -117,7 +120,10 @@ t_chain(_) -> ?assertMatch({error, {not_found, {chain, ChainName}}}, ?AUTHN:lookup_chain(ChainName)), ok. -t_authenticator(_) -> +t_authenticator({'init', Config}) -> + [{"auth1", {'password-based', 'built-in-database'}}, + {"auth2", {'password-based', mysql}} | Config]; +t_authenticator(Config) when is_list(Config) -> ChainName = 'test', AuthenticatorConfig1 = #{mechanism => 'password-based', backend => 'built-in-database', @@ -129,8 +135,8 @@ t_authenticator(_) -> % Create an authenticator when the provider does not exist ?assertEqual({error, no_available_provider}, ?AUTHN:create_authenticator(ChainName, AuthenticatorConfig1)), - AuthNType1 = {'password-based', 'built-in-database'}, - ?AUTHN:add_provider(AuthNType1, ?MODULE), + AuthNType1 = ?config("auth1"), + register_provider(AuthNType1, ?MODULE), ID1 = <<"password-based:built-in-database">>, % CRUD of authencaticator @@ -144,8 +150,8 @@ t_authenticator(_) -> ?assertMatch({ok, []}, ?AUTHN:list_authenticators(ChainName)), % Multiple authenticators exist at the same time - AuthNType2 = {'password-based', mysql}, - ?AUTHN:add_provider(AuthNType2, ?MODULE), + AuthNType2 = ?config("auth2"), + register_provider(AuthNType2, ?MODULE), ID2 = <<"password-based:mysql">>, AuthenticatorConfig2 = #{mechanism => 'password-based', backend => mysql, @@ -160,15 +166,18 @@ t_authenticator(_) -> ?assertEqual(ok, ?AUTHN:move_authenticator(ChainName, ID2, bottom)), ?assertMatch({ok, [#{id := ID1}, #{id := ID2}]}, ?AUTHN:list_authenticators(ChainName)), ?assertEqual(ok, ?AUTHN:move_authenticator(ChainName, ID2, {before, ID1})), - ?assertMatch({ok, [#{id := ID2}, #{id := ID1}]}, ?AUTHN:list_authenticators(ChainName)), - - ?AUTHN:delete_chain(ChainName), - ?AUTHN:remove_provider(AuthNType1), - ?AUTHN:remove_provider(AuthNType2), + ?assertMatch({ok, [#{id := ID2}, #{id := ID1}]}, ?AUTHN:list_authenticators(ChainName)); +t_authenticator({'end', Config}) -> + ?AUTHN:delete_chain(test), + ?AUTHN:deregister_providers([?config("auth1"), ?config("auth2")]), ok. -t_authenticate(_) -> - ListenerID = 'tcp:default', +t_authenticate({init, Config}) -> + [{listener_id, 'tcp:default'}, + {authn_type, {'password-based', 'built-in-database'}} | Config]; +t_authenticate(Config) when is_list(Config) -> + ListenerID = ?config(listener_id), + AuthNType = ?config(authn_type), ClientInfo = #{zone => default, listener => ListenerID, protocol => mqtt, @@ -176,8 +185,7 @@ t_authenticate(_) -> password => <<"any">>}, ?assertEqual({ok, #{is_superuser => false}}, emqx_access_control:authenticate(ClientInfo)), - AuthNType = {'password-based', 'built-in-database'}, - ?AUTHN:add_provider(AuthNType, ?MODULE), + register_provider(AuthNType, ?MODULE), AuthenticatorConfig = #{mechanism => 'password-based', backend => 'built-in-database', @@ -185,21 +193,24 @@ t_authenticate(_) -> ?AUTHN:create_chain(ListenerID), ?assertMatch({ok, _}, ?AUTHN:create_authenticator(ListenerID, AuthenticatorConfig)), ?assertEqual({ok, #{is_superuser => true}}, emqx_access_control:authenticate(ClientInfo)), - ?assertEqual({error, bad_username_or_password}, emqx_access_control:authenticate(ClientInfo#{username => <<"bad">>})), - - ?AUTHN:delete_chain(ListenerID), - ?AUTHN:remove_provider(AuthNType), + ?assertEqual({error, bad_username_or_password}, emqx_access_control:authenticate(ClientInfo#{username => <<"bad">>})); +t_authenticate({'end', Config}) -> + ?AUTHN:delete_chain(?config(listener_id)), + ?AUTHN:deregister_provider(?config(authn_type)), ok. -t_update_config(_) -> - emqx_config_handler:add_handler([authentication], emqx_authentication), - +t_update_config({init, Config}) -> + Global = 'mqtt:global', AuthNType1 = {'password-based', 'built-in-database'}, AuthNType2 = {'password-based', mysql}, - ?AUTHN:add_provider(AuthNType1, ?MODULE), - ?AUTHN:add_provider(AuthNType2, ?MODULE), - - Global = 'mqtt:global', + [{global, Global}, + {"auth1", AuthNType1}, + {"auth2", AuthNType2} | Config]; +t_update_config(Config) when is_list(Config) -> + emqx_config_handler:add_handler([authentication], emqx_authentication), + ok = register_provider(?config("auth1"), ?MODULE), + ok = register_provider(?config("auth2"), ?MODULE), + Global = ?config(global), AuthenticatorConfig1 = #{mechanism => 'password-based', backend => 'built-in-database', enable => true}, @@ -208,7 +219,7 @@ t_update_config(_) -> enable => true}, ID1 = <<"password-based:built-in-database">>, ID2 = <<"password-based:mysql">>, - + ?assertMatch({ok, []}, ?AUTHN:list_chains()), ?assertMatch({ok, _}, update_config([authentication], {create_authenticator, Global, AuthenticatorConfig1})), ?assertMatch({ok, #{id := ID1, state := #{mark := 1}}}, ?AUTHN:lookup_authenticator(Global, ID1)), @@ -240,14 +251,14 @@ t_update_config(_) -> ?assertMatch({ok, [#{id := ID2}, #{id := ID1}]}, ?AUTHN:list_authenticators(ListenerID)), ?assertMatch({ok, _}, update_config(ConfKeyPath, {delete_authenticator, ListenerID, ID1})), - ?assertEqual({error, {not_found, {authenticator, ID1}}}, ?AUTHN:lookup_authenticator(ListenerID, ID1)), - - ?AUTHN:delete_chain(Global), - ?AUTHN:remove_provider(AuthNType1), - ?AUTHN:remove_provider(AuthNType2), + ?assertEqual({error, {not_found, {authenticator, ID1}}}, ?AUTHN:lookup_authenticator(ListenerID, ID1)); +t_update_config({'end', Config}) -> + ?AUTHN:delete_chain(?config(global)), + ?AUTHN:deregister_providers([?config("auth1"), ?config("auth2")]), ok. -t_convert_cert_options(_) -> +t_convert_cert_options({_, Config}) -> Config; +t_convert_cert_options(Config) when is_list(Config) -> Certs = certs([ {<<"keyfile">>, "key.pem"} , {<<"certfile">>, "cert.pem"} , {<<"cacertfile">>, "cacert.pem"} @@ -284,4 +295,7 @@ certs(Certs) -> diff_cert(CertFile, CertPem2) -> {ok, CertPem1} = file:read_file(CertFile), - ?AUTHN:diff_cert(CertPem1, CertPem2). \ No newline at end of file + ?AUTHN:diff_cert(CertPem1, CertPem2). + +register_provider(Type, Module) -> + ok = ?AUTHN:register_providers([{Type, Module}]). diff --git a/apps/emqx_authn/src/emqx_authn_app.erl b/apps/emqx_authn/src/emqx_authn_app.erl index 98d53e438..d297c9042 100644 --- a/apps/emqx_authn/src/emqx_authn_app.erl +++ b/apps/emqx_authn/src/emqx_authn_app.erl @@ -32,34 +32,27 @@ start(_StartType, _StartArgs) -> ok = ekka_rlog:wait_for_shards([?AUTH_SHARD], infinity), {ok, Sup} = emqx_authn_sup:start_link(), - ok = add_providers(), + ok = ?AUTHN:register_providers(providers()), ok = initialize(), {ok, Sup}. stop(_State) -> - ok = remove_providers(), + ok = ?AUTHN:deregister_providers(provider_types()), ok. %%------------------------------------------------------------------------------ %% Internal functions %%------------------------------------------------------------------------------ -add_providers() -> - lists:foreach(fun(AuthNType, Provider}) -> - ?AUTHN:add_provider(AuthNType, Provider) - end, providers()). - -remove_providers() -> - lists:foreach(fun({AuthNType, _}) -> - ?AUTHN:remove_provider(AuthNType) - end, providers()). - initialize() -> ?AUTHN:initialize_authentication(?GLOBAL, emqx:get_raw_config([authentication], [])), lists:foreach(fun({ListenerID, ListenerConfig}) -> ?AUTHN:initialize_authentication(ListenerID, maps:get(authentication, ListenerConfig, [])) end, emqx_listeners:list()). +provider_types() -> + lists:map(fun({Type, _Module}) -> Type end, providers()). + providers() -> [ {{'password-based', 'built-in-database'}, emqx_authn_mnesia} , {{'password-based', mysql}, emqx_authn_mysql}