refactor(authn): register providers in batch

This commit is contained in:
Zaiming Shi 2021-09-16 22:58:42 +02:00
parent 0b432a6a77
commit 0877fb5569
3 changed files with 93 additions and 61 deletions

View File

@ -40,8 +40,10 @@
, stop/0 , stop/0
]). ]).
-export([ add_provider/2 -export([ register_provider/2
, remove_provider/1 , register_providers/1
, deregister_provider/1
, deregister_providers/1
, create_chain/1 , create_chain/1
, delete_chain/1 , delete_chain/1
, lookup_chain/1 , lookup_chain/1
@ -334,13 +336,27 @@ stop() ->
get_refs() -> get_refs() ->
gen_server:call(?MODULE, get_refs). gen_server:call(?MODULE, get_refs).
-spec add_provider(authn_type(), module()) -> ok. %% @doc Register authentication providers.
add_provider(AuthNType, Provider) -> %% A provider is a tuple of `AuthNType' the module which implements
gen_server:call(?MODULE, {add_provider, AuthNType, Provider}). %% the authenticator callbacks.
%% For example, ``[{{'password-based', redis}, emqx_authn_redis}]''
%% NOTE: Later registered provider may override earlier registered if they
%% happen to clash the same `AuthNType'.
-spec register_providers([{authn_type(), module()}]) -> ok.
register_providers(Providers) ->
gen_server:call(?MODULE, {register_providers, Providers}).
-spec remove_provider(authn_type()) -> ok. -spec register_provider(authn_type(), module()) -> ok.
remove_provider(AuthNType) -> register_provider(AuthNType, Provider) ->
gen_server:call(?MODULE, {remove_provider, AuthNType}). register_providers([{AuthNType, Provider}]).
-spec deregister_providers([authn_type()]) -> ok.
deregister_providers(AuthNTypes) when is_list(AuthNTypes) ->
gen_server:call(?MODULE, {deregister_providers, AuthNTypes}).
-spec deregister_provider(authn_type()) -> ok.
deregister_provider(AuthNType) ->
deregister_providers([AuthNType]).
-spec create_chain(chain_name()) -> {ok, chain()} | {error, term()}. -spec create_chain(chain_name()) -> {ok, chain()} | {error, term()}.
create_chain(Name) -> create_chain(Name) ->
@ -447,11 +463,20 @@ init(_Opts) ->
ok = emqx_config_handler:add_handler([listeners, '?', '?', authentication], ?MODULE), ok = emqx_config_handler:add_handler([listeners, '?', '?', authentication], ?MODULE),
{ok, #{hooked => false, providers => #{}}}. {ok, #{hooked => false, providers => #{}}}.
handle_call({add_provider, AuthNType, Provider}, _From, #{providers := Providers} = State) -> handle_call({register_providers, Providers}, _From,
reply(ok, State#{providers := Providers#{AuthNType => Provider}}); #{providers := Reg0} = State) ->
case lists:filter(fun({T, _}) -> maps:is_key(T, Reg0) end, Providers) of
[] ->
Reg = lists:foldl(fun({AuthNType, Module}, Pin) ->
Pin#{AuthNType => Module}
end, Reg0, Providers),
reply(ok, State#{providers := Reg});
Clashes ->
reply({error, {authentication_type_clash, Clashes}}, State)
end;
handle_call({remove_provider, AuthNType}, _From, #{providers := Providers} = State) -> handle_call({deregister_providers, AuthNTypes}, _From, #{providers := Providers} = State) ->
reply(ok, State#{providers := maps:remove(AuthNType, Providers)}); reply(ok, State#{providers := maps:without(AuthNTypes, Providers)});
handle_call(get_refs, _From, #{providers := Providers} = State) -> handle_call(get_refs, _From, #{providers := Providers} = State) ->
Refs = lists:foldl(fun({_, Provider}, Acc) -> Refs = lists:foldl(fun({_, Provider}, Acc) ->

View File

@ -36,6 +36,7 @@
]). ]).
-define(AUTHN, emqx_authentication). -define(AUTHN, emqx_authentication).
-define(config(KEY), (fun() -> {KEY, _V_} = lists:keyfind(KEY, 1, Config), _V_ end)()).
%%------------------------------------------------------------------------------ %%------------------------------------------------------------------------------
%% Hocon Schema %% Hocon Schema
@ -92,20 +93,22 @@ end_per_suite(_) ->
emqx_ct_helpers:stop_apps([]), emqx_ct_helpers:stop_apps([]),
ok. ok.
init_per_testcase(_, Config) -> init_per_testcase(Case, Config) ->
meck:new(emqx, [non_strict, passthrough, no_history, no_link]), meck:new(emqx, [non_strict, passthrough, no_history, no_link]),
meck:expect(emqx, get_config, fun([node, data_dir]) -> meck:expect(emqx, get_config, fun([node, data_dir]) ->
{data_dir, Data} = lists:keyfind(data_dir, 1, Config), {data_dir, Data} = lists:keyfind(data_dir, 1, Config),
Data; Data;
(C) -> meck:passthrough([C]) (C) -> meck:passthrough([C])
end), end),
Config. ?MODULE:Case({'init', Config}).
end_per_testcase(_, _Config) -> end_per_testcase(Case, Config) ->
_ = ?MODULE:Case({'end', Config}),
meck:unload(emqx), meck:unload(emqx),
ok. ok.
t_chain(_) -> t_chain({_, Config}) -> Config;
t_chain(Config) when is_list(Config) ->
% CRUD of authentication chain % CRUD of authentication chain
ChainName = 'test', ChainName = 'test',
?assertMatch({ok, []}, ?AUTHN:list_chains()), ?assertMatch({ok, []}, ?AUTHN:list_chains()),
@ -117,7 +120,10 @@ t_chain(_) ->
?assertMatch({error, {not_found, {chain, ChainName}}}, ?AUTHN:lookup_chain(ChainName)), ?assertMatch({error, {not_found, {chain, ChainName}}}, ?AUTHN:lookup_chain(ChainName)),
ok. ok.
t_authenticator(_) -> t_authenticator({'init', Config}) ->
[{"auth1", {'password-based', 'built-in-database'}},
{"auth2", {'password-based', mysql}} | Config];
t_authenticator(Config) when is_list(Config) ->
ChainName = 'test', ChainName = 'test',
AuthenticatorConfig1 = #{mechanism => 'password-based', AuthenticatorConfig1 = #{mechanism => 'password-based',
backend => 'built-in-database', backend => 'built-in-database',
@ -129,8 +135,8 @@ t_authenticator(_) ->
% Create an authenticator when the provider does not exist % Create an authenticator when the provider does not exist
?assertEqual({error, no_available_provider}, ?AUTHN:create_authenticator(ChainName, AuthenticatorConfig1)), ?assertEqual({error, no_available_provider}, ?AUTHN:create_authenticator(ChainName, AuthenticatorConfig1)),
AuthNType1 = {'password-based', 'built-in-database'}, AuthNType1 = ?config("auth1"),
?AUTHN:add_provider(AuthNType1, ?MODULE), register_provider(AuthNType1, ?MODULE),
ID1 = <<"password-based:built-in-database">>, ID1 = <<"password-based:built-in-database">>,
% CRUD of authencaticator % CRUD of authencaticator
@ -144,8 +150,8 @@ t_authenticator(_) ->
?assertMatch({ok, []}, ?AUTHN:list_authenticators(ChainName)), ?assertMatch({ok, []}, ?AUTHN:list_authenticators(ChainName)),
% Multiple authenticators exist at the same time % Multiple authenticators exist at the same time
AuthNType2 = {'password-based', mysql}, AuthNType2 = ?config("auth2"),
?AUTHN:add_provider(AuthNType2, ?MODULE), register_provider(AuthNType2, ?MODULE),
ID2 = <<"password-based:mysql">>, ID2 = <<"password-based:mysql">>,
AuthenticatorConfig2 = #{mechanism => 'password-based', AuthenticatorConfig2 = #{mechanism => 'password-based',
backend => mysql, backend => mysql,
@ -160,15 +166,18 @@ t_authenticator(_) ->
?assertEqual(ok, ?AUTHN:move_authenticator(ChainName, ID2, bottom)), ?assertEqual(ok, ?AUTHN:move_authenticator(ChainName, ID2, bottom)),
?assertMatch({ok, [#{id := ID1}, #{id := ID2}]}, ?AUTHN:list_authenticators(ChainName)), ?assertMatch({ok, [#{id := ID1}, #{id := ID2}]}, ?AUTHN:list_authenticators(ChainName)),
?assertEqual(ok, ?AUTHN:move_authenticator(ChainName, ID2, {before, ID1})), ?assertEqual(ok, ?AUTHN:move_authenticator(ChainName, ID2, {before, ID1})),
?assertMatch({ok, [#{id := ID2}, #{id := ID1}]}, ?AUTHN:list_authenticators(ChainName)), ?assertMatch({ok, [#{id := ID2}, #{id := ID1}]}, ?AUTHN:list_authenticators(ChainName));
t_authenticator({'end', Config}) ->
?AUTHN:delete_chain(ChainName), ?AUTHN:delete_chain(test),
?AUTHN:remove_provider(AuthNType1), ?AUTHN:deregister_providers([?config("auth1"), ?config("auth2")]),
?AUTHN:remove_provider(AuthNType2),
ok. ok.
t_authenticate(_) -> t_authenticate({init, Config}) ->
ListenerID = 'tcp:default', [{listener_id, 'tcp:default'},
{authn_type, {'password-based', 'built-in-database'}} | Config];
t_authenticate(Config) when is_list(Config) ->
ListenerID = ?config(listener_id),
AuthNType = ?config(authn_type),
ClientInfo = #{zone => default, ClientInfo = #{zone => default,
listener => ListenerID, listener => ListenerID,
protocol => mqtt, protocol => mqtt,
@ -176,8 +185,7 @@ t_authenticate(_) ->
password => <<"any">>}, password => <<"any">>},
?assertEqual({ok, #{is_superuser => false}}, emqx_access_control:authenticate(ClientInfo)), ?assertEqual({ok, #{is_superuser => false}}, emqx_access_control:authenticate(ClientInfo)),
AuthNType = {'password-based', 'built-in-database'}, register_provider(AuthNType, ?MODULE),
?AUTHN:add_provider(AuthNType, ?MODULE),
AuthenticatorConfig = #{mechanism => 'password-based', AuthenticatorConfig = #{mechanism => 'password-based',
backend => 'built-in-database', backend => 'built-in-database',
@ -185,21 +193,24 @@ t_authenticate(_) ->
?AUTHN:create_chain(ListenerID), ?AUTHN:create_chain(ListenerID),
?assertMatch({ok, _}, ?AUTHN:create_authenticator(ListenerID, AuthenticatorConfig)), ?assertMatch({ok, _}, ?AUTHN:create_authenticator(ListenerID, AuthenticatorConfig)),
?assertEqual({ok, #{is_superuser => true}}, emqx_access_control:authenticate(ClientInfo)), ?assertEqual({ok, #{is_superuser => true}}, emqx_access_control:authenticate(ClientInfo)),
?assertEqual({error, bad_username_or_password}, emqx_access_control:authenticate(ClientInfo#{username => <<"bad">>})), ?assertEqual({error, bad_username_or_password}, emqx_access_control:authenticate(ClientInfo#{username => <<"bad">>}));
t_authenticate({'end', Config}) ->
?AUTHN:delete_chain(ListenerID), ?AUTHN:delete_chain(?config(listener_id)),
?AUTHN:remove_provider(AuthNType), ?AUTHN:deregister_provider(?config(authn_type)),
ok. ok.
t_update_config(_) -> t_update_config({init, Config}) ->
emqx_config_handler:add_handler([authentication], emqx_authentication), Global = 'mqtt:global',
AuthNType1 = {'password-based', 'built-in-database'}, AuthNType1 = {'password-based', 'built-in-database'},
AuthNType2 = {'password-based', mysql}, AuthNType2 = {'password-based', mysql},
?AUTHN:add_provider(AuthNType1, ?MODULE), [{global, Global},
?AUTHN:add_provider(AuthNType2, ?MODULE), {"auth1", AuthNType1},
{"auth2", AuthNType2} | Config];
Global = 'mqtt:global', t_update_config(Config) when is_list(Config) ->
emqx_config_handler:add_handler([authentication], emqx_authentication),
ok = register_provider(?config("auth1"), ?MODULE),
ok = register_provider(?config("auth2"), ?MODULE),
Global = ?config(global),
AuthenticatorConfig1 = #{mechanism => 'password-based', AuthenticatorConfig1 = #{mechanism => 'password-based',
backend => 'built-in-database', backend => 'built-in-database',
enable => true}, enable => true},
@ -240,14 +251,14 @@ t_update_config(_) ->
?assertMatch({ok, [#{id := ID2}, #{id := ID1}]}, ?AUTHN:list_authenticators(ListenerID)), ?assertMatch({ok, [#{id := ID2}, #{id := ID1}]}, ?AUTHN:list_authenticators(ListenerID)),
?assertMatch({ok, _}, update_config(ConfKeyPath, {delete_authenticator, ListenerID, ID1})), ?assertMatch({ok, _}, update_config(ConfKeyPath, {delete_authenticator, ListenerID, ID1})),
?assertEqual({error, {not_found, {authenticator, ID1}}}, ?AUTHN:lookup_authenticator(ListenerID, ID1)), ?assertEqual({error, {not_found, {authenticator, ID1}}}, ?AUTHN:lookup_authenticator(ListenerID, ID1));
t_update_config({'end', Config}) ->
?AUTHN:delete_chain(Global), ?AUTHN:delete_chain(?config(global)),
?AUTHN:remove_provider(AuthNType1), ?AUTHN:deregister_providers([?config("auth1"), ?config("auth2")]),
?AUTHN:remove_provider(AuthNType2),
ok. ok.
t_convert_cert_options(_) -> t_convert_cert_options({_, Config}) -> Config;
t_convert_cert_options(Config) when is_list(Config) ->
Certs = certs([ {<<"keyfile">>, "key.pem"} Certs = certs([ {<<"keyfile">>, "key.pem"}
, {<<"certfile">>, "cert.pem"} , {<<"certfile">>, "cert.pem"}
, {<<"cacertfile">>, "cacert.pem"} , {<<"cacertfile">>, "cacert.pem"}
@ -285,3 +296,6 @@ certs(Certs) ->
diff_cert(CertFile, CertPem2) -> diff_cert(CertFile, CertPem2) ->
{ok, CertPem1} = file:read_file(CertFile), {ok, CertPem1} = file:read_file(CertFile),
?AUTHN:diff_cert(CertPem1, CertPem2). ?AUTHN:diff_cert(CertPem1, CertPem2).
register_provider(Type, Module) ->
ok = ?AUTHN:register_providers([{Type, Module}]).

View File

@ -32,34 +32,27 @@
start(_StartType, _StartArgs) -> start(_StartType, _StartArgs) ->
ok = ekka_rlog:wait_for_shards([?AUTH_SHARD], infinity), ok = ekka_rlog:wait_for_shards([?AUTH_SHARD], infinity),
{ok, Sup} = emqx_authn_sup:start_link(), {ok, Sup} = emqx_authn_sup:start_link(),
ok = add_providers(), ok = ?AUTHN:register_providers(providers()),
ok = initialize(), ok = initialize(),
{ok, Sup}. {ok, Sup}.
stop(_State) -> stop(_State) ->
ok = remove_providers(), ok = ?AUTHN:deregister_providers(provider_types()),
ok. ok.
%%------------------------------------------------------------------------------ %%------------------------------------------------------------------------------
%% Internal functions %% Internal functions
%%------------------------------------------------------------------------------ %%------------------------------------------------------------------------------
add_providers() ->
lists:foreach(fun(AuthNType, Provider}) ->
?AUTHN:add_provider(AuthNType, Provider)
end, providers()).
remove_providers() ->
lists:foreach(fun({AuthNType, _}) ->
?AUTHN:remove_provider(AuthNType)
end, providers()).
initialize() -> initialize() ->
?AUTHN:initialize_authentication(?GLOBAL, emqx:get_raw_config([authentication], [])), ?AUTHN:initialize_authentication(?GLOBAL, emqx:get_raw_config([authentication], [])),
lists:foreach(fun({ListenerID, ListenerConfig}) -> lists:foreach(fun({ListenerID, ListenerConfig}) ->
?AUTHN:initialize_authentication(ListenerID, maps:get(authentication, ListenerConfig, [])) ?AUTHN:initialize_authentication(ListenerID, maps:get(authentication, ListenerConfig, []))
end, emqx_listeners:list()). end, emqx_listeners:list()).
provider_types() ->
lists:map(fun({Type, _Module}) -> Type end, providers()).
providers() -> providers() ->
[ {{'password-based', 'built-in-database'}, emqx_authn_mnesia} [ {{'password-based', 'built-in-database'}, emqx_authn_mnesia}
, {{'password-based', mysql}, emqx_authn_mysql} , {{'password-based', mysql}, emqx_authn_mysql}