diff --git a/apps/emqx_authentication/data/user-credentials.csv b/apps/emqx_authentication/data/user-credentials.csv new file mode 100644 index 000000000..2543d39ca --- /dev/null +++ b/apps/emqx_authentication/data/user-credentials.csv @@ -0,0 +1,3 @@ +user_id,password_hash,salt +myuser3,b6c743545a7817ae8c8f624371d5f5f0373234bb0ff36b8ffbf19bce0e06ab75,de1024f462fb83910fd13151bd4bd235 +myuser4,ee68c985a69208b6eda8c6c9b4c7c2d2b15ee2352cdd64a903171710a99182e8,ad773b5be9dd0613fe6c2f4d8c403139 diff --git a/apps/emqx_authentication/data/user-credentials.json b/apps/emqx_authentication/data/user-credentials.json new file mode 100644 index 000000000..169122bd2 --- /dev/null +++ b/apps/emqx_authentication/data/user-credentials.json @@ -0,0 +1,12 @@ +[ + { + "user_id":"myuser1", + "password_hash":"c5e46903df45e5dc096dc74657610dbee8deaacae656df88a1788f1847390242", + "salt": "e378187547bf2d6f0545a3f441aa4d8a" + }, + { + "user_id":"myuser2", + "password_hash":"f4d17f300b11e522fd33f497c11b126ef1ea5149c74d2220f9a16dc876d4567b", + "salt": "6d3f9bd5b54d94b98adbcfe10b6d181f" + } +] diff --git a/apps/emqx_authentication/etc/emqx_authentication.conf b/apps/emqx_authentication/etc/emqx_authentication.conf new file mode 100644 index 000000000..e69de29bb diff --git a/apps/emqx_authentication/include/emqx_authentication.hrl b/apps/emqx_authentication/include/emqx_authentication.hrl new file mode 100644 index 000000000..09d3c5fc4 --- /dev/null +++ b/apps/emqx_authentication/include/emqx_authentication.hrl @@ -0,0 +1,41 @@ +%%-------------------------------------------------------------------- +%% Copyright (c) 2020 EMQ Technologies Co., Ltd. All Rights Reserved. +%% +%% Licensed under the Apache License, Version 2.0 (the "License"); +%% you may not use this file except in compliance with the License. +%% You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, software +%% distributed under the License is distributed on an "AS IS" BASIS, +%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +%% See the License for the specific language governing permissions and +%% limitations under the License. +%%-------------------------------------------------------------------- + +-define(APP, emqx_authentication). + +-type(service_type_name() :: atom()). +-type(service_name() :: binary()). +-type(chain_id() :: binary()). + +-record(service_type, + { name :: service_type_name() + , provider :: module() + , params_spec :: #{atom() => term()} + }). + +-record(service, + { name :: service_name() + , type :: service_type_name() + , provider :: module() + , params :: map() + , state :: map() + }). + +-record(chain, + { id :: chain_id() + , services :: [{service_name(), #service{}}] + , created_at :: integer() + }). diff --git a/apps/emqx_authentication/priv/emqx_authentication.schema b/apps/emqx_authentication/priv/emqx_authentication.schema new file mode 100644 index 000000000..e69de29bb diff --git a/apps/emqx_authentication/rebar.config b/apps/emqx_authentication/rebar.config new file mode 100644 index 000000000..0a0af8c29 --- /dev/null +++ b/apps/emqx_authentication/rebar.config @@ -0,0 +1,18 @@ +{deps, []}. + +{edoc_opts, [{preprocess, true}]}. +{erl_opts, [warn_unused_vars, + warn_shadow_vars, + warnings_as_errors, + warn_unused_import, + warn_obsolete_guard, + debug_info, + {parse_transform}]}. + +{xref_checks, [undefined_function_calls, undefined_functions, + locals_not_used, deprecated_function_calls, + warnings_as_errors, deprecated_functions]}. + +{cover_enabled, true}. +{cover_opts, [verbose]}. +{cover_export_enabled, true}. \ No newline at end of file diff --git a/apps/emqx_authentication/src/emqx_authentication.app.src b/apps/emqx_authentication/src/emqx_authentication.app.src new file mode 100644 index 000000000..4f55ca0a7 --- /dev/null +++ b/apps/emqx_authentication/src/emqx_authentication.app.src @@ -0,0 +1,12 @@ +{application, emqx_authentication, + [{description, "EMQ X Authentication"}, + {vsn, "0.1.0"}, + {modules, []}, + {registered, [emqx_authentication_sup, emqx_authentication_registry]}, + {applications, [kernel,stdlib]}, + {mod, {emqx_authentication_app,[]}}, + {env, []}, + {licenses, ["Apache-2.0"]}, + {maintainers, ["EMQ X Team "]}, + {links, [{"Homepage", "https://emqx.io/"}]} + ]}. diff --git a/apps/emqx_authentication/src/emqx_authentication.erl b/apps/emqx_authentication/src/emqx_authentication.erl new file mode 100644 index 000000000..90a234386 --- /dev/null +++ b/apps/emqx_authentication/src/emqx_authentication.erl @@ -0,0 +1,519 @@ +%%-------------------------------------------------------------------- +%% Copyright (c) 2021 EMQ Technologies Co., Ltd. All Rights Reserved. +%% +%% Licensed under the Apache License, Version 2.0 (the "License"); +%% you may not use this file except in compliance with the License. +%% You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, software +%% distributed under the License is distributed on an "AS IS" BASIS, +%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +%% See the License for the specific language governing permissions and +%% limitations under the License. +%%-------------------------------------------------------------------- + +-module(emqx_authentication). + +-include("emqx_authentication.hrl"). + +-export([ enable/0 + , disable/0 + ]). + +-export([authenticate/1]). + +-export([register_service_types/0]). + +-export([ create_chain/1 + , delete_chain/1 + , lookup_chain/1 + , list_chains/0 + , add_services/2 + , delete_services/2 + , update_service/3 + , lookup_service/2 + , list_services/1 + , move_service_to_the_front/2 + , move_service_to_the_end/2 + , move_service_to_the_nth/3 + ]). + +-export([ import_users/3 + , add_user/3 + , delete_user/3 + , update_user/4 + , lookup_user/3 + , list_users/2 + ]). + +-export([mnesia/1]). + +-boot_mnesia({mnesia, [boot]}). +-copy_mnesia({mnesia, [copy]}). + +-define(CHAIN_TAB, emqx_authentication_chain). +-define(SERVICE_TYPE_TAB, emqx_authentication_service_type). + +%%------------------------------------------------------------------------------ +%% Mnesia bootstrap +%%------------------------------------------------------------------------------ + +%% @doc Create or replicate tables. +-spec(mnesia(boot | copy) -> ok). +mnesia(boot) -> + %% Optimize storage + StoreProps = [{ets, [{read_concurrency, true}]}], + %% Chain table + ok = ekka_mnesia:create_table(?CHAIN_TAB, [ + {disc_copies, [node()]}, + {record_name, chain}, + {attributes, record_info(fields, chain)}, + {storage_properties, StoreProps}]), + %% Service type table + ok = ekka_mnesia:create_table(?SERVICE_TYPE_TAB, [ + {ram_copies, [node()]}, + {record_name, service_type}, + {attributes, record_info(fields, service_type)}, + {storage_properties, StoreProps}]); + +mnesia(copy) -> + %% Copy chain table + ok = ekka_mnesia:copy_table(?CHAIN_TAB, disc_copies), + %% Copy service type table + ok = ekka_mnesia:copy_table(?SERVICE_TYPE_TAB, ram_copies). + +enable() -> + case emqx:hook('client.authenticate', fun emqx_authentication:authenticate/1) of + ok -> ok; + {error, already_exists} -> ok + end. + +disable() -> + emqx:unhook('client.authenticate', fun emqx_authentication:authenticate/1), + ok. + +authenticate(#{chain_id := ChainID} = ClientInfo) -> + case mnesia:dirty_read(?CHAIN_TAB, ChainID) of + [#chain{services = []}] -> + {error, no_services}; + [#chain{services = Services}] -> + do_authenticate(Services, ClientInfo); + [] -> + {error, todo} + end. + +do_authenticate([], _) -> + {error, user_not_found}; +do_authenticate([{_, #service{provider = Provider, state = State}} | More], ClientInfo) -> + case Provider:authenticate(ClientInfo, State) of + ignore -> do_authenticate(More, ClientInfo); + ok -> ok; + {ok, NewClientInfo} -> {ok, NewClientInfo}; + {stop, Reason} -> {error, Reason} + end. + +register_service_types() -> + Attrs = find_attrs(?APP, service_type), + register_service_types(Attrs). + +register_service_types(Attrs) -> + register_service_types(Attrs, []). + +register_service_types([], Acc) -> + do_register_service_types(Acc); +register_service_types([{_App, Mod, #{name := Name, + params_spec := ParamsSpec}} | Types], Acc) -> + %% TODO: Temporary realization + ok = emqx_rule_validator:validate_spec(ParamsSpec), + ServiceType = #service_type{name = Name, + provider = Mod, + params_spec = ParamsSpec}, + register_service_types(Types, [ServiceType | Acc]). + +create_chain(#{id := ID}) -> + trans( + fun() -> + case mnesia:read(?CHAIN_TAB, ID, write) of + [] -> + Chain = #chain{id = ID, + services = [], + created_at = erlang:system_time(millisecond)}, + mnesia:write(?CHAIN_TAB, Chain, write), + {ok, serialize_chain(Chain)}; + [_ | _] -> + {error, {already_exists, {chain, ID}}} + end + end). + +delete_chain(ID) -> + trans( + fun() -> + case mnesia:read(?CHAIN_TAB, ID, write) of + [] -> + {error, {not_found, {chain, ID}}}; + [#chain{services = Services}] -> + ok = delete_services_(Services), + mnesia:delete(?CHAIN_TAB, ID, write) + end + end). + +lookup_chain(ID) -> + case mnesia:dirty_read(?CHAIN_TAB, ID) of + [] -> + {error, {not_found, {chain, ID}}}; + [Chain] -> + {ok, serialize_chain(Chain)} + end. + +list_chains() -> + Chains = ets:tab2list(?CHAIN_TAB), + {ok, [serialize_chain(Chain) || Chain <- Chains]}. + +add_services(ChainID, ServiceParams) -> + case validate_service_params(ServiceParams) of + {ok, NServiceParams} -> + UpdateFun = fun(Chain = #chain{services = Services}) -> + Names = [Name || {Name, _} <- Services] ++ [Name || #{name := Name} <- NServiceParams], + case no_duplicate_names(Names) of + ok -> + case create_services(ChainID, NServiceParams) of + {ok, NServices} -> + NChain = Chain#chain{services = Services ++ NServices}, + ok = mnesia:write(?CHAIN_TAB, NChain, write), + {ok, serialize_services(NServices)}; + {error, Reason} -> + {error, Reason} + end; + {error, {duplicate, Name}} -> + {error, {already_exists, {service, Name}}} + end + end, + update_chain(ChainID, UpdateFun); + {error, Reason} -> + {error, Reason} + end. + +delete_services(ChainID, ServiceNames) -> + case no_duplicate_names(ServiceNames) of + ok -> + UpdateFun = fun(Chain = #chain{services = Services}) -> + case extract_services(ServiceNames, Services) of + {ok, Extracted, Rest} -> + ok = delete_services_(Extracted), + NChain = Chain#chain{services = Rest}, + mnesia:write(?CHAIN_TAB, NChain, write); + {error, Reason} -> + {error, Reason} + end + end, + update_chain(ChainID, UpdateFun); + {error, Reason} -> + {error, Reason} + end. + +update_service(ChainID, ServiceName, NewParams) -> + UpdateFun = fun(Chain = #chain{services = Services}) -> + case proplists:get_value(ServiceName, Services, undefined) of + undefined -> + {error, {not_found, {service, ServiceName}}}; + #service{type = Type, + provider = Provider, + params = OriginalParams, + state = State} = Service -> + Params = maps:merge(OriginalParams, NewParams), + {ok, #service_type{params_spec = ParamsSpec}} = find_service_type(Type), + NParams = emqx_rule_validator:validate_params(Params, ParamsSpec), + case Provider:update(ChainID, ServiceName, NParams, State) of + {ok, NState} -> + NService = Service#service{params = Params, + state = NState}, + NServices = lists:keyreplace(ServiceName, 1, Services, {ServiceName, NService}), + ok = mnesia:write(?CHAIN_TAB, Chain#chain{services = NServices}, write), + {ok, serialize_service({ServiceName, NService})}; + {error, Reason} -> + {error, Reason} + end + end + end, + update_chain(ChainID, UpdateFun). + +lookup_service(ChainID, ServiceName) -> + case mnesia:dirty_read(?CHAIN_TAB, ChainID) of + [] -> + {error, {not_found, {chain, ChainID}}}; + [#chain{services = Services}] -> + case lists:keytake(ServiceName, 1, Services) of + {value, Service, _} -> + {ok, serialize_service(Service)}; + false -> + {error, {not_found, {service, ServiceName}}} + end + end. + +list_services(ChainID) -> + case mnesia:dirty_read(?CHAIN_TAB, ChainID) of + [] -> + {error, {not_found, {chain, ChainID}}}; + [#chain{services = Services}] -> + {ok, serialize_services(Services)} + end. + +move_service_to_the_front(ChainID, ServiceName) -> + UpdateFun = fun(Chain = #chain{services = Services}) -> + case move_service_to_the_front_(ServiceName, Services) of + {ok, NServices} -> + NChain = Chain#chain{services = NServices}, + mnesia:write(?CHAIN_TAB, NChain, write); + {error, Reason} -> + {error, Reason} + end + end, + update_chain(ChainID, UpdateFun). + +move_service_to_the_end(ChainID, ServiceName) -> + UpdateFun = fun(Chain = #chain{services = Services}) -> + case move_service_to_the_end_(ServiceName, Services) of + {ok, NServices} -> + NChain = Chain#chain{services = NServices}, + mnesia:write(?CHAIN_TAB, NChain, write); + {error, Reason} -> + {error, Reason} + end + end, + update_chain(ChainID, UpdateFun). + +move_service_to_the_nth(ChainID, ServiceName, N) -> + UpdateFun = fun(Chain = #chain{services = Services}) -> + case move_service_to_the_nth_(ServiceName, Services, N) of + {ok, NServices} -> + NChain = Chain#chain{services = NServices}, + mnesia:write(?CHAIN_TAB, NChain, write); + {error, Reason} -> + {error, Reason} + end + end, + update_chain(ChainID, UpdateFun). + +import_users(ChainID, ServiceName, Filename) -> + call_service(ChainID, ServiceName, import_users, [Filename]). + +add_user(ChainID, ServiceName, UserInfo) -> + call_service(ChainID, ServiceName, add_user, [UserInfo]). + +delete_user(ChainID, ServiceName, UserID) -> + call_service(ChainID, ServiceName, delete_user, [UserID]). + +update_user(ChainID, ServiceName, UserID, NewUserInfo) -> + call_service(ChainID, ServiceName, update_user, [UserID, NewUserInfo]). + +lookup_user(ChainID, ServiceName, UserID) -> + call_service(ChainID, ServiceName, lookup_user, [UserID]). + +list_users(ChainID, ServiceName) -> + call_service(ChainID, ServiceName, list_users, []). + +%%------------------------------------------------------------------------------ +%% Internal functions +%%------------------------------------------------------------------------------ + +find_attrs(App, AttrName) -> + [{App, Mod, Attr} || {ok, Modules} <- [application:get_key(App, modules)], + Mod <- Modules, + {Name, Attrs} <- module_attributes(Mod), Name =:= AttrName, + Attr <- Attrs]. + +module_attributes(Module) -> + try Module:module_info(attributes) + catch + error:undef -> [] + end. + +do_register_service_types(ServiceTypes) -> + trans(fun lists:foreach/2, [fun insert_service_type/1, ServiceTypes]). + +insert_service_type(ServiceType) -> + mnesia:write(?SERVICE_TYPE_TAB, ServiceType, write). + +find_service_type(Name) -> + case mnesia:dirty_read(?SERVICE_TYPE_TAB, Name) of + [ServiceType] -> {ok, ServiceType}; + [] -> {error, not_found} + end. + +validate_service_params(ServiceParams) -> + case validate_service_names(ServiceParams) of + ok -> + validate_other_service_params(ServiceParams); + {error, Reason} -> + {error, Reason} + end. + +validate_service_names(ServiceParams) -> + Names = [Name || #{name := Name} <- ServiceParams], + no_duplicate_names(Names). + +validate_other_service_params(ServiceParams) -> + validate_other_service_params(ServiceParams, []). + +validate_other_service_params([], Acc) -> + {ok, lists:reverse(Acc)}; +validate_other_service_params([#{type := Type, params := Params} = ServiceParams | More], Acc) -> + case find_service_type(Type) of + {ok, #service_type{provider = Provider, params_spec = ParamsSpec}} -> + NParams = emqx_rule_validator:validate_params(Params, ParamsSpec), + validate_other_service_params(More, + [ServiceParams#{params => NParams, + original_params => Params, + provider => Provider} | Acc]); + {error, not_found} -> + {error, {not_found, {service_type, Type}}} + end. + +no_duplicate_names(Names) -> + no_duplicate_names(Names, #{}). + +no_duplicate_names([], _) -> + ok; +no_duplicate_names([Name | More], Acc) -> + case maps:is_key(Name, Acc) of + false -> no_duplicate_names(More, Acc#{Name => true}); + true -> {error, {duplicate, Name}} + end. + +create_services(ChainID, ServiceParams) -> + create_services(ChainID, ServiceParams, []). + +create_services(_ChainID, [], Acc) -> + {ok, lists:reverse(Acc)}; +create_services(ChainID, [#{name := Name, + type := Type, + provider := Provider, + params := Params, + original_params := OriginalParams} | More], Acc) -> + case Provider:create(ChainID, Name, Params) of + {ok, State} -> + Service = #service{name = Name, + type = Type, + provider = Provider, + params = OriginalParams, + state = State}, + create_services(ChainID, More, [{Name, Service} | Acc]); + {error, Reason} -> + delete_services_(Acc), + {error, Reason} + end. + +delete_services_([]) -> + ok; +delete_services_([{_, #service{provider = Provider, state = State}} | More]) -> + Provider:destroy(State), + delete_services_(More). + +extract_services(ServiceNames, Services) -> + extract_services(ServiceNames, Services, []). + +extract_services([], Rest, Extracted) -> + {ok, lists:reverse(Extracted), Rest}; +extract_services([ServiceName | More], Services, Acc) -> + case lists:keytake(ServiceName, 1, Services) of + {value, Extracted, Rest} -> + extract_services(More, Rest, [Extracted | Acc]); + false -> + {error, {not_found, {service, ServiceName}}} + end. + +move_service_to_the_front_(ServiceName, Services) -> + move_service_to_the_front_(ServiceName, Services, []). + +move_service_to_the_front_(ServiceName, [], _) -> + {error, {not_found, {service, ServiceName}}}; +move_service_to_the_front_(ServiceName, [{ServiceName, _} = Service | More], Passed) -> + {ok, [Service | (lists:reverse(Passed) ++ More)]}; +move_service_to_the_front_(ServiceName, [Service | More], Passed) -> + move_service_to_the_front_(ServiceName, More, [Service | Passed]). + +move_service_to_the_end_(ServiceName, Services) -> + move_service_to_the_end_(ServiceName, Services, []). + +move_service_to_the_end_(ServiceName, [], _) -> + {error, {not_found, {service, ServiceName}}}; +move_service_to_the_end_(ServiceName, [{ServiceName, _} = Service | More], Passed) -> + {ok, lists:reverse(Passed) ++ More ++ [Service]}; +move_service_to_the_end_(ServiceName, [Service | More], Passed) -> + move_service_to_the_end_(ServiceName, More, [Service | Passed]). + +move_service_to_the_nth_(ServiceName, Services, N) + when N =< length(Services) andalso N > 0 -> + move_service_to_the_nth_(ServiceName, Services, N, []); +move_service_to_the_nth_(_, _, _) -> + {error, out_of_range}. + +move_service_to_the_nth_(ServiceName, [], _, _) -> + {error, {not_found, {service, ServiceName}}}; +move_service_to_the_nth_(ServiceName, [{ServiceName, _} = Service | More], N, Passed) + when N =< length(Passed) -> + {L1, L2} = lists:split(N - 1, lists:reverse(Passed)), + {ok, L1 ++ [Service] ++ L2 ++ More}; +move_service_to_the_nth_(ServiceName, [{ServiceName, _} = Service | More], N, Passed) -> + {L1, L2} = lists:split(N - length(Passed) - 1, More), + {ok, lists:reverse(Passed) ++ L1 ++ [Service] ++ L2}; +move_service_to_the_nth_(ServiceName, [Service | More], N, Passed) -> + move_service_to_the_nth_(ServiceName, More, N, [Service | Passed]). + +update_chain(ChainID, UpdateFun) -> + trans( + fun() -> + case mnesia:read(?CHAIN_TAB, ChainID, write) of + [] -> + {error, {not_found, {chain, ChainID}}}; + [Chain] -> + UpdateFun(Chain) + end + end). + +call_service(ChainID, ServiceName, Func, Args) -> + case mnesia:dirty_read(?CHAIN_TAB, ChainID) of + [] -> + {error, {not_found, {chain, ChainID}}}; + [#chain{services = Services}] -> + case proplists:get_value(ServiceName, Services, undefined) of + undefined -> + {error, {not_found, {service, ServiceName}}}; + #service{provider = Provider, + state = State} -> + case erlang:function_exported(Provider, Func, length(Args) + 1) of + true -> + erlang:apply(Provider, Func, Args ++ [State]); + false -> + {error, unsupported_feature} + end + end + end. + +serialize_chain(#chain{id = ID, + services = Services, + created_at = CreatedAt}) -> + #{id => ID, + services => serialize_services(Services), + created_at => CreatedAt}. + +serialize_services(Services) -> + [serialize_service(Service) || Service <- Services]. + +serialize_service({_, #service{name = Name, + type = Type, + params = Params}}) -> + #{name => Name, + type => Type, + params => Params}. + +trans(Fun) -> + trans(Fun, []). + +trans(Fun, Args) -> + case mnesia:transaction(Fun, Args) of + {atomic, Res} -> Res; + {aborted, Reason} -> {error, Reason} + end. diff --git a/apps/emqx_authentication/src/emqx_authentication_api.erl b/apps/emqx_authentication/src/emqx_authentication_api.erl new file mode 100644 index 000000000..74887a0b2 --- /dev/null +++ b/apps/emqx_authentication/src/emqx_authentication_api.erl @@ -0,0 +1,407 @@ +%%-------------------------------------------------------------------- +%% Copyright (c) 2021 EMQ Technologies Co., Ltd. All Rights Reserved. +%% +%% Licensed under the Apache License, Version 2.0 (the "License"); +%% you may not use this file except in compliance with the License. +%% You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, software +%% distributed under the License is distributed on an "AS IS" BASIS, +%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +%% See the License for the specific language governing permissions and +%% limitations under the License. +%%-------------------------------------------------------------------- + +-module(emqx_authentication_api). + +-export([ create_chain/2 + , delete_chain/2 + , lookup_chain/2 + , list_chains/2 + , add_service/2 + , delete_service/2 + , update_service/2 + , lookup_service/2 + , list_services/2 + , move_service/2 + , import_users/2 + , add_user/2 + , delete_user/2 + , update_user/2 + , lookup_user/2 + , list_users/2 + ]). + +-import(minirest, [return/1]). + +-rest_api(#{name => create_chain, + method => 'POST', + path => "/authentication/chains", + func => create_chain, + descr => "Create a chain" + }). + +-rest_api(#{name => delete_chain, + method => 'DELETE', + path => "/authentication/chains/:bin:id", + func => delete_chain, + descr => "Delete chain" + }). + +-rest_api(#{name => lookup_chain, + method => 'GET', + path => "/authentication/chains/:bin:id", + func => lookup_chain, + descr => "Lookup chain" + }). + +-rest_api(#{name => list_chains, + method => 'GET', + path => "/authentication/chains", + func => list_chains, + descr => "List all chains" + }). + +-rest_api(#{name => add_service, + method => 'POST', + path => "/authentication/chains/:bin:id/services", + func => add_service, + descr => "Add service to chain" + }). + +-rest_api(#{name => delete_service, + method => 'DELETE', + path => "/authentication/chains/:bin:id/services/:bin:service_name", + func => delete_service, + descr => "Delete service from chain" + }). + +-rest_api(#{name => update_service, + method => 'PUT', + path => "/authentication/chains/:bin:id/services/:bin:service_name", + func => update_service, + descr => "Update service in chain" + }). + +-rest_api(#{name => lookup_service, + method => 'GET', + path => "/authentication/chains/:bin:id/services/:bin:service_name", + func => lookup_service, + descr => "Lookup service in chain" + }). + +-rest_api(#{name => list_services, + method => 'GET', + path => "/authentication/chains/:bin:id/services", + func => list_services, + descr => "List services in chain" + }). + +-rest_api(#{name => move_service, + method => 'POST', + path => "/authentication/chains/:bin:id/services/:bin:service_name/position", + func => move_service, + descr => "Change the order of services" + }). + +-rest_api(#{name => import_users, + method => 'POST', + path => "/authentication/chains/:bin:id/services/:bin:service_name/import-users", + func => import_users, + descr => "Import users" + }). + +-rest_api(#{name => add_user, + method => 'POST', + path => "/authentication/chains/:bin:id/services/:bin:service_name/users", + func => add_user, + descr => "Add user" + }). + +-rest_api(#{name => delete_user, + method => 'DELETE', + path => "/authentication/chains/:bin:id/services/:bin:service_name/users/:bin:user_id", + func => delete_user, + descr => "Delete user" + }). + +-rest_api(#{name => update_user, + method => 'PUT', + path => "/authentication/chains/:bin:id/services/:bin:service_name/users/:bin:user_id", + func => update_user, + descr => "Update user" + }). + +-rest_api(#{name => lookup_user, + method => 'GET', + path => "/authentication/chains/:bin:id/services/:bin:service_name/users/:bin:user_id", + func => lookup_user, + descr => "Lookup user" + }). + +%% TODO: Support pagination +-rest_api(#{name => list_users, + method => 'GET', + path => "/authentication/chains/:bin:id/services/:bin:service_name/users", + func => list_users, + descr => "List all users" + }). + +create_chain(Binding, Params) -> + do_create_chain(uri_decode(Binding), maps:from_list(Params)). + +do_create_chain(_Binding, #{<<"id">> := ChainID}) -> + case emqx_authentication:create_chain(#{id => ChainID}) of + {ok, Chain} -> + return({ok, Chain}); + {error, Reason} -> + return(serialize_error(Reason)) + end; +do_create_chain(_Binding, _Params) -> + return(serialize_error({missing_parameter, id})). + +delete_chain(Binding, Params) -> + do_delete_chain(uri_decode(Binding), maps:from_list(Params)). + +do_delete_chain(#{id := ChainID}, _Params) -> + case emqx_authentication:delete_chain(ChainID) of + ok -> + return(ok); + {error, Reason} -> + return(serialize_error(Reason)) + end. + +lookup_chain(Binding, Params) -> + do_lookup_chain(uri_decode(Binding), maps:from_list(Params)). + +do_lookup_chain(#{id := ChainID}, _Params) -> + case emqx_authentication:lookup_chain(ChainID) of + {ok, Chain} -> + return({ok, Chain}); + {error, Reason} -> + return(serialize_error(Reason)) + end. + +list_chains(Binding, Params) -> + do_list_chains(uri_decode(Binding), maps:from_list(Params)). + +do_list_chains(_Binding, _Params) -> + {ok, Chains} = emqx_authentication:list_chains(), + return({ok, Chains}). + +add_service(Binding, Params) -> + do_add_service(uri_decode(Binding), maps:from_list(Params)). + +do_add_service(#{id := ChainID}, #{<<"name">> := Name, + <<"type">> := Type, + <<"params">> := Params}) -> + case emqx_authentication:add_services(ChainID, [#{name => Name, + type => binary_to_existing_atom(Type, utf8), + params => maps:from_list(Params)}]) of + {ok, Services} -> + return({ok, Services}); + {error, Reason} -> + return(serialize_error(Reason)) + end; +%% TODO: Check missed field in params +do_add_service(_Binding, Params) -> + Missed = get_missed_params(Params, [<<"name">>, <<"type">>, <<"params">>]), + return(serialize_error({missing_parameter, Missed})). + +delete_service(Binding, Params) -> + do_delete_service(uri_decode(Binding), maps:from_list(Params)). + +do_delete_service(#{id := ChainID, + service_name := ServiceName}, _Params) -> + case emqx_authentication:delete_services(ChainID, [ServiceName]) of + ok -> + return(ok); + {error, Reason} -> + return(serialize_error(Reason)) + end. + +update_service(Binding, Params) -> + do_update_service(uri_decode(Binding), maps:from_list(Params)). + +%% TOOD: PUT method supports creation and update +do_update_service(#{id := ChainID, + service_name := ServiceName}, Params) -> + case emqx_authentication:update_service(ChainID, ServiceName, Params) of + {ok, Service} -> + return({ok, Service}); + {error, Reason} -> + return(serialize_error(Reason)) + end. + +lookup_service(Binding, Params) -> + do_lookup_service(uri_decode(Binding), maps:from_list(Params)). + +do_lookup_service(#{id := ChainID, + service_name := ServiceName}, _Params) -> + case emqx_authentication:lookup_service(ChainID, ServiceName) of + {ok, Service} -> + return({ok, Service}); + {error, Reason} -> + return(serialize_error(Reason)) + end. + +list_services(Binding, Params) -> + do_list_services(uri_decode(Binding), maps:from_list(Params)). + +do_list_services(#{id := ChainID}, _Params) -> + case emqx_authentication:list_services(ChainID) of + {ok, Services} -> + return({ok, Services}); + {error, Reason} -> + return(serialize_error(Reason)) + end. + +move_service(Binding, Params) -> + do_move_service(uri_decode(Binding), maps:from_list(Params)). + +do_move_service(#{id := ChainID, + service_name := ServiceName}, #{<<"position">> := <<"the front">>}) -> + case emqx_authentication:move_service_to_the_front(ChainID, ServiceName) of + ok -> + return(ok); + {error, Reason} -> + return(serialize_error(Reason)) + end; +do_move_service(#{id := ChainID, + service_name := ServiceName}, #{<<"position">> := <<"the end">>}) -> + case emqx_authentication:move_service_to_the_end(ChainID, ServiceName) of + ok -> + return(ok); + {error, Reason} -> + return(serialize_error(Reason)) + end; +do_move_service(#{id := ChainID, + service_name := ServiceName}, #{<<"position">> := N}) when is_number(N) -> + case emqx_authentication:move_service_to_the_nth(ChainID, ServiceName, N) of + ok -> + return(ok); + {error, Reason} -> + return(serialize_error(Reason)) + end; +do_move_service(_Binding, _Params) -> + return(serialize_error({missing_parameter, <<"position">>})). + +import_users(Binding, Params) -> + do_import_users(uri_decode(Binding), maps:from_list(Params)). + +do_import_users(#{id := ChainID, service_name := ServiceName}, + #{<<"filename">> := Filename}) -> + case emqx_authentication:import_users(ChainID, ServiceName, Filename) of + ok -> + return(ok); + {error, Reason} -> + return(serialize_error(Reason)) + end; +do_import_users(_Binding, Params) -> + Missed = get_missed_params(Params, [<<"filename">>, <<"file_format">>]), + return(serialize_error({missing_parameter, Missed})). + +add_user(Binding, Params) -> + do_add_user(uri_decode(Binding), maps:from_list(Params)). + +do_add_user(#{id := ChainID, + service_name := ServiceName}, UserInfo) -> + case emqx_authentication:add_user(ChainID, ServiceName, UserInfo) of + {ok, User} -> + return({ok, User}); + {error, Reason} -> + return(serialize_error(Reason)) + end. + +delete_user(Binding, Params) -> + do_delete_user(uri_decode(Binding), maps:from_list(Params)). + +do_delete_user(#{id := ChainID, + service_name := ServiceName, + user_id := UserID}, _Params) -> + case emqx_authentication:delete_user(ChainID, ServiceName, UserID) of + ok -> + return(ok); + {error, Reason} -> + return(serialize_error(Reason)) + end. + +update_user(Binding, Params) -> + do_update_user(uri_decode(Binding), maps:from_list(Params)). + +do_update_user(#{id := ChainID, + service_name := ServiceName, + user_id := UserID}, NewUserInfo) -> + case emqx_authentication:update_user(ChainID, ServiceName, UserID, NewUserInfo) of + {ok, User} -> + return({ok, User}); + {error, Reason} -> + return(serialize_error(Reason)) + end. + +lookup_user(Binding, Params) -> + do_lookup_user(uri_decode(Binding), maps:from_list(Params)). + +do_lookup_user(#{id := ChainID, + service_name := ServiceName, + user_id := UserID}, _Params) -> + case emqx_authentication:lookup_user(ChainID, ServiceName, UserID) of + {ok, User} -> + return({ok, User}); + {error, Reason} -> + return(serialize_error(Reason)) + end. + +list_users(Binding, Params) -> + do_list_users(uri_decode(Binding), maps:from_list(Params)). + +do_list_users(#{id := ChainID, + service_name := ServiceName}, _Params) -> + case emqx_authentication:list_users(ChainID, ServiceName) of + {ok, Users} -> + return({ok, Users}); + {error, Reason} -> + return(serialize_error(Reason)) + end. + +%%------------------------------------------------------------------------------ +%% Internal functions +%%------------------------------------------------------------------------------ + +uri_decode(Params) -> + maps:fold(fun(K, V, Acc) -> + Acc#{K => emqx_http_lib:uri_decode(V)} + end, #{}, Params). + +serialize_error({already_exists, {Type, ID}}) -> + {error, <<"ALREADY_EXISTS">>, list_to_binary(io_lib:format("~s '~s' already exists", [serialize_type(Type), ID]))}; +serialize_error({not_found, {Type, ID}}) -> + {error, <<"NOT_FOUND">>, list_to_binary(io_lib:format("~s '~s' not found", [serialize_type(Type), ID]))}; +serialize_error({duplicate, Name}) -> + {error, <<"INVALID_PARAMETER">>, list_to_binary(io_lib:format("Service name '~s' is duplicated", [Name]))}; +serialize_error({missing_parameter, Names = [_ | Rest]}) -> + Format = ["~s," || _ <- Rest] ++ ["~s"], + NFormat = binary_to_list(iolist_to_binary(Format)), + {error, <<"MISSING_PARAMETER">>, list_to_binary(io_lib:format("The input parameters " ++ NFormat ++ " that are mandatory for processing this request are not supplied.", Names))}; +serialize_error({missing_parameter, Name}) -> + {error, <<"MISSING_PARAMETER">>, list_to_binary(io_lib:format("The input parameter '~s' that is mandatory for processing this request is not supplied.", [Name]))}; +serialize_error(_) -> + {error, <<"UNKNOWN_ERROR">>, <<"Unknown error">>}. + +serialize_type(service) -> + "Service"; +serialize_type(chain) -> + "Chain"; +serialize_type(service_type) -> + "Service type". + +get_missed_params(Actual, Expected) -> + Keys = lists:foldl(fun(Key, Acc) -> + case maps:is_key(Key, Actual) of + true -> Acc; + false -> [Key | Acc] + end + end, [], Expected), + lists:reverse(Keys). diff --git a/apps/emqx_authentication/src/emqx_authentication_app.erl b/apps/emqx_authentication/src/emqx_authentication_app.erl new file mode 100644 index 000000000..2d395def7 --- /dev/null +++ b/apps/emqx_authentication/src/emqx_authentication_app.erl @@ -0,0 +1,34 @@ +%%-------------------------------------------------------------------- +%% Copyright (c) 2021 EMQ Technologies Co., Ltd. All Rights Reserved. +%% +%% Licensed under the Apache License, Version 2.0 (the "License"); +%% you may not use this file except in compliance with the License. +%% You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, software +%% distributed under the License is distributed on an "AS IS" BASIS, +%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +%% See the License for the specific language governing permissions and +%% limitations under the License. +%%-------------------------------------------------------------------- + +-module(emqx_authentication_app). + +-behaviour(application). + +-emqx_plugin(?MODULE). + +%% Application callbacks +-export([ start/2 + , stop/1 + ]). + +start(_StartType, _StartArgs) -> + {ok, Sup} = emqx_authentication_sup:start_link(), + ok = emqx_authentication:register_service_types(), + {ok, Sup}. + +stop(_State) -> + ok. diff --git a/apps/emqx_authentication/src/emqx_authentication_jwks_connector.erl b/apps/emqx_authentication/src/emqx_authentication_jwks_connector.erl new file mode 100644 index 000000000..9dafc9f5e --- /dev/null +++ b/apps/emqx_authentication/src/emqx_authentication_jwks_connector.erl @@ -0,0 +1,171 @@ +%%-------------------------------------------------------------------- +%% Copyright (c) 2021 EMQ Technologies Co., Ltd. All Rights Reserved. +%% +%% Licensed under the Apache License, Version 2.0 (the "License"); +%% you may not use this file except in compliance with the License. +%% You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, software +%% distributed under the License is distributed on an "AS IS" BASIS, +%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +%% See the License for the specific language governing permissions and +%% limitations under the License. +%%-------------------------------------------------------------------- + +-module(emqx_authentication_jwks_connector). + +-behaviour(gen_server). + +-include_lib("emqx/include/logger.hrl"). +-include_lib("jose/include/jose_jwk.hrl"). + +-export([ start_link/1 + , stop/1 + ]). + +-export([ get_jwks/1 + , update/2 + ]). + +%% gen_server callbacks +-export([ init/1 + , handle_call/3 + , handle_cast/2 + , handle_info/2 + , terminate/2 + , code_change/3 + ]). + +%%-------------------------------------------------------------------- +%% APIs +%%-------------------------------------------------------------------- + +start_link(Opts) -> + gen_server:start_link(?MODULE, [Opts], []). + +stop(Pid) -> + gen_server:stop(Pid). + +get_jwks(Pid) -> + gen_server:call(Pid, get_cached_jwks, 5000). + +update(Pid, Opts) -> + gen_server:call(Pid, {update, Opts}, 5000). + +%%-------------------------------------------------------------------- +%% gen_server callbacks +%%-------------------------------------------------------------------- + +init([Opts]) -> + ok = jose:json_module(jiffy), + State = handle_options(Opts), + {ok, refresh_jwks(State)}. + +handle_call(get_cached_jwks, _From, #{jwks := Jwks} = State) -> + {reply, {ok, Jwks}, State}; + +handle_call({update, Opts}, _From, State) -> + State = handle_options(Opts), + {reply, ok, refresh_jwks(State)}; + +handle_call(_Req, _From, State) -> + {reply, ok, State}. + +handle_cast(_Msg, State) -> + {noreply, State}. + +handle_info({refresh_jwks, _TRef, refresh}, #{request_id := RequestID} = State) -> + case RequestID of + undefined -> ok; + _ -> + ok = httpc:cancel_request(RequestID), + receive + {http, _} -> ok + after 0 -> + ok + end + end, + {noreply, refresh_jwks(State)}; + +handle_info({http, {RequestID, Result}}, + #{request_id := RequestID, endpoint := Endpoint} = State0) -> + State1 = State0#{request_id := undefined}, + case Result of + {error, Reason} -> + ?LOG(error, "Failed to request jwks endpoint(~s): ~p", [Endpoint, Reason]), + State1; + {_StatusLine, _Headers, Body} -> + try + JWKS = jose_jwk:from(emqx_json:decode(Body, [return_maps])), + {_, JWKs} = JWKS#jose_jwk.keys, + State1#{jwks := JWKs} + catch _:_ -> + ?LOG(error, "Invalid jwks returned from jwks endpoint(~s): ~p~n", [Endpoint, Body]), + State1 + end + end; + +handle_info({http, {_, _}}, State) -> + %% ignore + {noreply, State}; + +handle_info(_Info, State) -> + {noreply, State}. + +terminate(_Reason, State) -> + _ = cancel_timer(State), + ok. + +code_change(_OldVsn, State, _Extra) -> + {ok, State}. + +%%-------------------------------------------------------------------- +%% Internal functions +%%-------------------------------------------------------------------- + +handle_options(Opts) -> + #{endpoint => proplists:get_value(jwks_endpoint, Opts), + refresh_interval => limit_refresh_interval(proplists:get_value(refresh_interval, Opts)), + ssl_opts => get_ssl_opts(Opts), + jwks => [], + request_id => undefined}. + +get_ssl_opts(Opts) -> + case proplists:get_value(enable_ssl, Opts) of + false -> []; + true -> + maps:to_list(maps:with([cacertfile, + keyfile, + certfile, + verify, + server_name_indication], maps:from_list(Opts))) + end. + +refresh_jwks(#{endpoint := Endpoint, + ssl_opts := SSLOpts} = State) -> + HTTPOpts = [{timeout, 5000}, {connect_timeout, 5000}, {ssl, SSLOpts}], + NState = case httpc:request(get, {Endpoint, [{"Accept", "application/json"}]}, HTTPOpts, + [{body_format, binary}, {sync, false}, {receiver, self()}]) of + {error, Reason} -> + ?LOG(error, "Failed to request jwks endpoint(~s): ~p", [Endpoint, Reason]), + State; + {ok, RequestID} -> + State#{request_id := RequestID} + end, + ensure_expiry_timer(NState). + +ensure_expiry_timer(State = #{refresh_interval := Interval}) -> + State#{refresh_timer := emqx_misc:start_timer(timer:seconds(Interval), refresh_jwks)}. + +cancel_timer(State = #{refresh_timer := undefined}) -> + State; +cancel_timer(State = #{refresh_timer := TRef}) -> + _ = emqx_misc:cancel_timer(TRef), + State#{refresh_timer := undefined}. + +limit_refresh_interval(Interval) when Interval < 10 -> + 10; +limit_refresh_interval(Interval) -> + Interval. diff --git a/apps/emqx_authentication/src/emqx_authentication_jwt.erl b/apps/emqx_authentication/src/emqx_authentication_jwt.erl new file mode 100644 index 000000000..2b8024e1c --- /dev/null +++ b/apps/emqx_authentication/src/emqx_authentication_jwt.erl @@ -0,0 +1,409 @@ +%%-------------------------------------------------------------------- +%% Copyright (c) 2021 EMQ Technologies Co., Ltd. All Rights Reserved. +%% +%% Licensed under the Apache License, Version 2.0 (the "License"); +%% you may not use this file except in compliance with the License. +%% You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, software +%% distributed under the License is distributed on an "AS IS" BASIS, +%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +%% See the License for the specific language governing permissions and +%% limitations under the License. +%%-------------------------------------------------------------------- + +-module(emqx_authentication_jwt). + +-export([ create/3 + , update/4 + , authenticate/2 + , destroy/1 + ]). + +-service_type(#{ + name => jwt, + params_spec => #{ + use_jwks => #{ + order => 1, + type => boolean + }, + jwks_endpoint => #{ + order => 2, + type => string + }, + refresh_interval => #{ + order => 3, + type => number + }, + algorithm => #{ + order => 3, + type => string, + enum => [<<"hmac-based">>, <<"public-key">>] + }, + secret => #{ + order => 4, + type => string + }, + secret_base64_encoded => #{ + order => 5, + type => boolean + }, + jwt_certfile => #{ + order => 6, + type => file + }, + cacertfile => #{ + order => 7, + type => file + }, + keyfile => #{ + order => 8, + type => file + }, + certfile => #{ + order => 9, + type => file + }, + verify => #{ + order => 10, + type => boolean + }, + server_name_indication => #{ + order => 11, + type => string + } + } +}). + +-define(RULES, + #{ + use_jwks => [], + jwks_endpoint => [use_jwks], + refresh_interval => [use_jwks], + algorithm => [use_jwks], + secret => [algorithm], + secret_base64_encoded => [algorithm], + jwt_certfile => [algorithm], + cacertfile => [jwks_endpoint], + keyfile => [jwks_endpoint], + certfile => [jwks_endpoint], + verify => [jwks_endpoint], + server_name_indication => [jwks_endpoint], + verify_claims => [] + }). + +create(_ChainID, _ServiceName, Params) -> + try handle_options(Params) of + Opts -> + do_create(Opts) + catch + {error, Reason} -> + {error, Reason} + end. + +update(_ChainID, _ServiceName, Params, State) -> + try handle_options(Params) of + Opts -> + do_update(Opts, State) + catch + {error, Reason} -> + {error, Reason} + end. + +authenticate(ClientInfo = #{password := JWT}, #{jwk := JWK, + verify_claims := VerifyClaims0}) -> + JWKs = case erlang:is_pid(JWK) of + false -> + [JWK]; + true -> + {ok, JWKs0} = emqx_authentication_jwks_connector:get_jwks(JWK), + JWKs0 + end, + VerifyClaims = replace_placeholder(VerifyClaims0, ClientInfo), + case verify(JWT, JWKs, VerifyClaims) of + ok -> ok; + {error, invalid_signature} -> ignore; + {error, {claims, _}} -> {stop, bad_passowrd} + end. + +destroy(#{jwks_connector := undefined}) -> + ok; +destroy(#{jwks_connector := Connector}) -> + _ = emqx_authentication_jwks_connector:stop(Connector), + ok. + +%%-------------------------------------------------------------------- +%% Internal functions +%%-------------------------------------------------------------------- + +do_create(#{use_jwks := false, + algorithm := 'hmac-based', + secret := Secret0, + secret_base64_encoded := Base64Encoded} = Opts) -> + Secret = case Base64Encoded of + true -> + base64:decode(Secret0); + false -> + Secret0 + end, + JWK = jose_jwk:from_oct(Secret), + {ok, #{jwk => JWK, + verify_claims => maps:get(verify_claims, Opts)}}; + +do_create(#{use_jwks := false, + algorithm := 'public-key', + jwt_certfile := Certfile} = Opts) -> + JWK = jose_jwk:from_pem_file(Certfile), + {ok, #{jwk => JWK, + verify_claims => maps:get(verify_claims, Opts)}}; + +do_create(#{use_jwks := true} = Opts) -> + case emqx_authentication_jwks_connector:start_link(Opts) of + {ok, Connector} -> + {ok, #{jwk => Connector, + verify_claims => maps:get(verify_claims, Opts)}}; + {error, Reason} -> + {error, Reason} + end. + +do_update(Opts, #{jwk_connector := undefined}) -> + do_create(Opts); +do_update(#{use_jwks := false} = Opts, #{jwk_connector := Connector}) -> + _ = emqx_authentication_jwks_connector:stop(Connector), + do_create(Opts); +do_update(#{use_jwks := true} = Opts, #{jwk_connector := Connector} = State) -> + ok = emqx_authentication_jwks_connector:update(Connector, Opts), + {ok, State}. + +replace_placeholder(L, Variables) -> + replace_placeholder(L, Variables, []). + +replace_placeholder([], _Variables, Acc) -> + Acc; +replace_placeholder([{Name, {placeholder, PL}} | More], Variables, Acc) -> + Value = maps:get(PL, Variables), + replace_placeholder(More, Variables, [{Name, Value} | Acc]); +replace_placeholder([{Name, Value} | More], Variables, Acc) -> + replace_placeholder(More, Variables, [{Name, Value} | Acc]). + +verify(_JWS, [], _VerifyClaims) -> + {error, invalid_signature}; +verify(JWS, [JWK | More], VerifyClaims) -> + case jose_jws:verify(JWK, JWS) of + {true, Payload, _JWS} -> + Claims = emqx_json:decode(Payload, [return_maps]), + verify_claims(Claims, VerifyClaims); + {false, _, _} -> + verify(JWS, More, VerifyClaims) + end. + +verify_claims(Claims, VerifyClaims0) -> + Now = os:system_time(seconds), + VerifyClaims = [{<<"exp">>, fun(ExpireTime) -> + Now < ExpireTime + end}, + {<<"iat">>, fun(IssueAt) -> + IssueAt =< Now + end}, + {<<"nbf">>, fun(NotBefore) -> + NotBefore =< Now + end}] ++ VerifyClaims0, + do_verify_claims(Claims, VerifyClaims). + +do_verify_claims(_Claims, []) -> + ok; +do_verify_claims(Claims, [{Name, Fun} | More]) when is_function(Fun) -> + case maps:take(Name, Claims) of + error -> + do_verify_claims(Claims, More); + {Value, NClaims} -> + case Fun(Value) of + true -> + do_verify_claims(NClaims, More); + _ -> + {error, {claims, {Name, Value}}} + end + end; +do_verify_claims(Claims, [{Name, Value} | More]) -> + case maps:take(Name, Claims) of + error -> + do_verify_claims(Claims, More); + {Value, NClaims} -> + do_verify_claims(NClaims, More); + {Value0, _} -> + {error, {claims, {Name, Value0}}} + end. + +handle_options(Opts0) when is_map(Opts0) -> + Ks = maps:fold(fun(K, _, Acc) -> + [atom_to_binary(K, utf8) | Acc] + end, [], ?RULES), + Opts1 = maps:to_list(maps:with(Ks, Opts0)), + handle_options([{binary_to_existing_atom(K, utf8), V} || {K, V} <- Opts1]); + +handle_options(Opts0) when is_list(Opts0) -> + Opts1 = add_missing_options(Opts0), + process_options({Opts1, [], length(Opts1)}, #{}). + +add_missing_options(Opts) -> + AllOpts = maps:keys(?RULES), + Fun = fun(K, Acc) -> + case proplists:is_defined(K, Acc) of + true -> + Acc; + false -> + [{K, unbound} | Acc] + end + end, + lists:foldl(Fun, Opts, AllOpts). + +process_options({[], [], _}, OptsMap) -> + OptsMap; +process_options({[], Skipped, Counter}, OptsMap) + when length(Skipped) < Counter -> + process_options({Skipped, [], length(Skipped)}, OptsMap); +process_options({[], _Skipped, _Counter}, _OptsMap) -> + throw({error, faulty_configuration}); +process_options({[{K, V} = Opt | More], Skipped, Counter}, OptsMap0) -> + case check_dependencies(K, OptsMap0) of + true -> + OptsMap1 = handle_option(K, V, OptsMap0), + process_options({More, Skipped, Counter}, OptsMap1); + false -> + process_options({More, [Opt | Skipped], Counter}, OptsMap0) + end. + +%% TODO: This is not a particularly good implementation(K => needless), it needs to be improved +handle_option(use_jwks, true, OptsMap) -> + OptsMap#{use_jwks => true, + algorithm => needless}; +handle_option(use_jwks, false, OptsMap) -> + OptsMap#{use_jwks => false, + jwks_endpoint => needless}; +handle_option(jwks_endpoint = Opt, unbound, #{use_jwks := true}) -> + throw({error, {options, {Opt, unbound}}}); +handle_option(jwks_endpoint, Value, #{use_jwks := true} = OptsMap) + when Value =/= unbound -> + case emqx_http_lib:uri_parse(Value) of + {ok, #{scheme := http}} -> + OptsMap#{enable_ssl => false, + jwks_endpoint => Value}; + {ok, #{scheme := https}} -> + OptsMap#{enable_ssl => true, + jwks_endpoint => Value}; + {error, _Reason} -> + throw({error, {options, {jwks_endpoint, Value}}}) + end; +handle_option(refresh_interval = Opt, Value0, #{use_jwks := true} = OptsMap) -> + Value = validate_option(Opt, Value0), + OptsMap#{Opt => Value}; +handle_option(algorithm = Opt, Value0, #{use_jwks := false} = OptsMap) -> + Value = validate_option(Opt, Value0), + OptsMap#{Opt => Value}; +handle_option(secret = Opt, unbound, #{algorithm := 'hmac-based'}) -> + throw({error, {options, {Opt, unbound}}}); +handle_option(secret = Opt, Value, #{algorithm := 'hmac-based'} = OptsMap) -> + OptsMap#{Opt => Value}; +handle_option(secret_base64_encoded = Opt, Value0, #{algorithm := 'hmac-based'} = OptsMap) -> + Value = validate_option(Opt, Value0), + OptsMap#{Opt => Value}; +handle_option(jwt_certfile = Opt, unbound, #{algorithm := 'public-key'}) -> + throw({error, {options, {Opt, unbound}}}); +handle_option(jwt_certfile = Opt, Value, #{algorithm := 'public-key'} = OptsMap) -> + OptsMap#{Opt => Value}; +handle_option(verify = Opt, Value0, #{enable_ssl := true} = OptsMap) -> + Value = validate_option(Opt, Value0), + OptsMap#{Opt => Value}; +handle_option(cacertfile = Opt, Value, #{enable_ssl := true} = OptsMap) + when Value =/= unbound -> + OptsMap#{Opt => Value}; +handle_option(certfile, unbound, #{enable_ssl := true} = OptsMap) -> + OptsMap; +handle_option(certfile = Opt, Value, #{enable_ssl := true} = OptsMap) -> + OptsMap#{Opt => Value}; +handle_option(keyfile, unbound, #{enable_ssl := true} = OptsMap) -> + OptsMap; +handle_option(keyfile = Opt, Value, #{enable_ssl := true} = OptsMap) -> + OptsMap#{Opt => Value}; +handle_option(server_name_indication = Opt, Value0, #{enable_ssl := true} = OptsMap) -> + Value = validate_option(Opt, Value0), + OptsMap#{Opt => Value}; +handle_option(verify_claims = Opt, Value0, OptsMap) -> + Value = handle_verify_claims(Value0), + OptsMap#{Opt => Value}; +handle_option(_Opt, _Value, OptsMap) -> + OptsMap. + +validate_option(refresh_interval, unbound) -> + 300; +validate_option(refresh_interval, Value) when is_integer(Value) -> + Value; +validate_option(algorithm, <<"hmac-based">>) -> + 'hmac-based'; +validate_option(algorithm, <<"public-key">>) -> + 'public-key'; +validate_option(secret_base64_encoded, unbound) -> + false; +validate_option(secret_base64_encoded, Value) when is_boolean(Value) -> + Value; +validate_option(verify, unbound) -> + verify_none; +validate_option(verify, true) -> + verify_peer; +validate_option(verify, false) -> + verify_none; +validate_option(server_name_indication, unbound) -> + disable; +validate_option(server_name_indication, <<"disable">>) -> + disable; +validate_option(server_name_indication, Value) when is_list(Value) -> + Value; +validate_option(Opt, Value) -> + throw({error, {options, {Opt, Value}}}). + +handle_verify_claims(Opts0) -> + try handle_verify_claims(Opts0, []) + catch + error:_ -> + throw({error, {options, {verify_claims, Opts0}}}) + end. + +handle_verify_claims([], Acc) -> + Acc; +handle_verify_claims([{Name, Expected0} | More], Acc) + when is_binary(Name) andalso is_binary(Expected0) -> + Expected = handle_placeholder(Expected0), + handle_verify_claims(More, [{Name, Expected} | Acc]). + +handle_placeholder(Placeholder0) -> + case re:run(Placeholder0, "^\\$\\{[a-z0-9\\_]+\\}$", [{capture, all}]) of + {match, [{Offset, Length}]} -> + Placeholder1 = binary:part(Placeholder0, Offset + 2, Length - 3), + Placeholder2 = validate_placeholder(Placeholder1), + {placeholder, Placeholder2}; + nomatch -> + Placeholder0 + end. + +validate_placeholder(<<"clientid">>) -> + clientid; +validate_placeholder(<<"username">>) -> + username. + +check_dependencies(Opt, OptsMap) -> + case maps:get(Opt, ?RULES) of + [] -> + true; + Deps -> + option_already_defined(Opt, OptsMap) orelse + dependecies_already_defined(Deps, OptsMap) + end. + +option_already_defined(Opt, OptsMap) -> + maps:get(Opt, OptsMap, unbound) =/= unbound. + +dependecies_already_defined(Deps, OptsMap) -> + Fun = fun(Opt) -> option_already_defined(Opt, OptsMap) end, + lists:all(Fun, Deps). diff --git a/apps/emqx_authentication/src/emqx_authentication_mnesia.erl b/apps/emqx_authentication/src/emqx_authentication_mnesia.erl new file mode 100644 index 000000000..53dc4dd73 --- /dev/null +++ b/apps/emqx_authentication/src/emqx_authentication_mnesia.erl @@ -0,0 +1,342 @@ +%%-------------------------------------------------------------------- +%% Copyright (c) 2021 EMQ Technologies Co., Ltd. All Rights Reserved. +%% +%% Licensed under the Apache License, Version 2.0 (the "License"); +%% you may not use this file except in compliance with the License. +%% You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, software +%% distributed under the License is distributed on an "AS IS" BASIS, +%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +%% See the License for the specific language governing permissions and +%% limitations under the License. +%%-------------------------------------------------------------------- + +-module(emqx_authentication_mnesia). + +-include("emqx_authentication.hrl"). + +-export([ create/3 + , update/4 + , authenticate/2 + , destroy/1 + ]). + +-export([ import_users/2 + , add_user/2 + , delete_user/2 + , update_user/3 + , lookup_user/2 + , list_users/1 + ]). + +%% TODO: support bcrypt +-service_type(#{ + name => mnesia, + params_spec => #{ + user_id_type => #{ + order => 1, + type => string, + enum => [<<"username">>, <<"clientid">>, <<"ip">>, <<"common name">>, <<"issuer">>], + default => <<"username">> + }, + password_hash_algorithm => #{ + order => 2, + type => string, + enum => [<<"plain">>, <<"md5">>, <<"sha">>, <<"sha256">>, <<"sha512">>, <<"bcrypt">>], + default => <<"sha256">> + }, + salt_rounds => #{ + order => 3, + type => number, + default => 10 + } + } +}). + +-record(user_info, + { user_id :: {user_group(), user_id()} + , password_hash :: binary() + , salt :: binary() + }). + +-type(user_group() :: {chain_id(), service_name()}). +-type(user_id() :: binary()). + +-export([mnesia/1]). + +-boot_mnesia({mnesia, [boot]}). +-copy_mnesia({mnesia, [copy]}). + +-define(TAB, mnesia_basic_auth). + +%%------------------------------------------------------------------------------ +%% Mnesia bootstrap +%%------------------------------------------------------------------------------ + +%% @doc Create or replicate tables. +-spec(mnesia(boot | copy) -> ok). +mnesia(boot) -> + ok = ekka_mnesia:create_table(?TAB, [ + {disc_copies, [node()]}, + {record_name, user_info}, + {attributes, record_info(fields, user_info)}, + {storage_properties, [{ets, [{read_concurrency, true}]}]}]); + +mnesia(copy) -> + ok = ekka_mnesia:copy_table(?TAB, disc_copies). + +create(ChainID, ServiceName, #{<<"user_id_type">> := Type, + <<"password_hash_algorithm">> := Algorithm, + <<"salt_rounds">> := SaltRounds}) -> + Algorithm =:= <<"bcrypt">> andalso ({ok, _} = application:ensure_all_started(bcrypt)), + State = #{user_group => {ChainID, ServiceName}, + user_id_type => binary_to_atom(Type, utf8), + password_hash_algorithm => binary_to_atom(Algorithm, utf8), + salt_rounds => SaltRounds}, + {ok, State}. + +update(ChainID, ServiceName, Params, _State) -> + create(ChainID, ServiceName, Params). + +authenticate(ClientInfo = #{password := Password}, + #{user_group := UserGroup, + user_id_type := Type, + password_hash_algorithm := Algorithm}) -> + UserID = get_user_identity(ClientInfo, Type), + case mnesia:dirty_read(?TAB, {UserGroup, UserID}) of + [] -> + ignore; + [#user_info{password_hash = PasswordHash, salt = Salt}] -> + case PasswordHash =:= hash(Algorithm, Password, Salt) of + true -> ok; + false -> {stop, bad_password} + end + end. + +destroy(#{user_group := UserGroup}) -> + trans( + fun() -> + MatchSpec = [{{user_info, {UserGroup, '_'}, '_', '_'}, [], ['$_']}], + ok = lists:foreach(fun delete_user2/1, mnesia:select(?TAB, MatchSpec, write)) + end). + +import_users(Filename0, State) -> + Filename = to_binary(Filename0), + case filename:extension(Filename) of + <<".json">> -> + import_users_from_json(Filename, State); + <<".csv">> -> + import_users_from_csv(Filename, State); + <<>> -> + {error, unknown_file_format}; + Extension -> + {error, {unsupported_file_format, Extension}} + end. + +add_user(#{<<"user_id">> := UserID, + <<"password">> := Password}, + #{user_group := UserGroup} = State) -> + trans( + fun() -> + case mnesia:read(?TAB, {UserGroup, UserID}, write) of + [] -> + add(UserID, Password, State), + {ok, #{user_id => UserID}}; + [_] -> + {error, already_exist} + end + end). + +delete_user(UserID, #{user_group := UserGroup}) -> + trans( + fun() -> + case mnesia:read(?TAB, {UserGroup, UserID}, write) of + [] -> + {error, not_found}; + [_] -> + mnesia:delete(?TAB, {UserGroup, UserID}, write) + end + end). + +update_user(UserID, #{<<"password">> := Password}, + #{user_group := UserGroup} = State) -> + trans( + fun() -> + case mnesia:read(?TAB, {UserGroup, UserID}, write) of + [] -> + {error, not_found}; + [_] -> + add(UserID, Password, State), + {ok, #{user_id => UserID}} + end + end). + +lookup_user(UserID, #{user_group := UserGroup}) -> + case mnesia:dirty_read(?TAB, {UserGroup, UserID}) of + [#user_info{user_id = {_, UserID}}] -> + {ok, #{user_id => UserID}}; + [] -> + {error, not_found} + end. + +list_users(#{user_group := UserGroup}) -> + Users = [#{user_id => UserID} || #user_info{user_id = {UserGroup0, UserID}} <- ets:tab2list(?TAB), UserGroup0 =:= UserGroup], + {ok, Users}. + +%%------------------------------------------------------------------------------ +%% Internal functions +%%------------------------------------------------------------------------------ + +%% Example: data/user-credentials.json +import_users_from_json(Filename, #{user_group := UserGroup}) -> + case file:read_file(Filename) of + {ok, Bin} -> + case emqx_json:safe_decode(Bin, [return_maps]) of + {ok, List} -> + trans(fun import/2, [UserGroup, List]); + {error, Reason} -> + {error, Reason} + end; + {error, Reason} -> + {error, Reason} + end. + +%% Example: data/user-credentials.csv +import_users_from_csv(Filename, #{user_group := UserGroup}) -> + case file:open(Filename, [read, binary]) of + {ok, File} -> + case get_csv_header(File) of + {ok, Seq} -> + Result = trans(fun import/3, [UserGroup, File, Seq]), + _ = file:close(File), + Result; + {error, Reason} -> + {error, Reason} + end; + {error, Reason} -> + {error, Reason} + end. + +import(_UserGroup, []) -> + ok; +import(UserGroup, [#{<<"user_id">> := UserID, + <<"password_hash">> := PasswordHash} = UserInfo | More]) + when is_binary(UserID) andalso is_binary(PasswordHash) -> + Salt = maps:get(<<"salt">>, UserInfo, <<>>), + insert_user(UserGroup, UserID, PasswordHash, Salt), + import(UserGroup, More); +import(_UserGroup, [_ | _More]) -> + {error, bad_format}. + +%% Importing 5w users needs 1.7 seconds +import(UserGroup, File, Seq) -> + case file:read_line(File) of + {ok, Line} -> + Fields = binary:split(Line, [<<",">>, <<" ">>, <<"\n">>], [global, trim_all]), + case get_user_info_by_seq(Fields, Seq) of + {ok, #{user_id := UserID, + password_hash := PasswordHash} = UserInfo} -> + Salt = maps:get(salt, UserInfo, <<>>), + insert_user(UserGroup, UserID, PasswordHash, Salt), + import(UserGroup, File, Seq); + {error, Reason} -> + {error, Reason} + end; + eof -> + ok; + {error, Reason} -> + {error, Reason} + end. + +get_csv_header(File) -> + case file:read_line(File) of + {ok, Line} -> + Seq = binary:split(Line, [<<",">>, <<" ">>, <<"\n">>], [global, trim_all]), + {ok, Seq}; + eof -> + {error, empty_file}; + {error, Reason} -> + {error, Reason} + end. + +get_user_info_by_seq(Fields, Seq) -> + get_user_info_by_seq(Fields, Seq, #{}). + +get_user_info_by_seq([], [], #{user_id := _, password_hash := _, salt := _} = Acc) -> + {ok, Acc}; +get_user_info_by_seq([], [], #{user_id := _, password_hash := _} = Acc) -> + {ok, Acc}; +get_user_info_by_seq(_, [], _) -> + {error, bad_format}; +get_user_info_by_seq([UserID | More1], [<<"user_id">> | More2], Acc) -> + get_user_info_by_seq(More1, More2, Acc#{user_id => UserID}); +get_user_info_by_seq([PasswordHash | More1], [<<"password_hash">> | More2], Acc) -> + get_user_info_by_seq(More1, More2, Acc#{password_hash => PasswordHash}); +get_user_info_by_seq([Salt | More1], [<<"salt">> | More2], Acc) -> + get_user_info_by_seq(More1, More2, Acc#{salt => Salt}); +get_user_info_by_seq(_, _, _) -> + {error, bad_format}. + +-compile({inline, [add/3]}). +add(UserID, Password, #{user_group := UserGroup, + password_hash_algorithm := Algorithm} = State) -> + Salt = gen_salt(State), + PasswordHash = hash(Algorithm, Password, Salt), + case Algorithm of + bcrypt -> insert_user(UserGroup, UserID, PasswordHash); + _ -> insert_user(UserGroup, UserID, PasswordHash, Salt) + end. + +gen_salt(#{password_hash_algorithm := plain}) -> + <<>>; +gen_salt(#{password_hash_algorithm := bcrypt, + salt_rounds := Rounds}) -> + {ok, Salt} = bcrypt:gen_salt(Rounds), + Salt; +gen_salt(_) -> + <> = crypto:strong_rand_bytes(16), + iolist_to_binary(io_lib:format("~32.16.0b", [X])). + +hash(bcrypt, Password, Salt) -> + {ok, Hash} = bcrypt:hashpw(Password, Salt), + list_to_binary(Hash); +hash(Algorithm, Password, Salt) -> + emqx_passwd:hash(Algorithm, <>). + +insert_user(UserGroup, UserID, PasswordHash) -> + insert_user(UserGroup, UserID, PasswordHash, <<>>). + +insert_user(UserGroup, UserID, PasswordHash, Salt) -> + Credential = #user_info{user_id = {UserGroup, UserID}, + password_hash = PasswordHash, + salt = Salt}, + mnesia:write(?TAB, Credential, write). + +delete_user2(UserInfo) -> + mnesia:delete_object(?TAB, UserInfo, write). + +%% TODO: Support other type +get_user_identity(#{username := Username}, username) -> + Username; +get_user_identity(#{clientid := ClientID}, clientid) -> + ClientID; +get_user_identity(_, Type) -> + {error, {bad_user_identity_type, Type}}. + +trans(Fun) -> + trans(Fun, []). + +trans(Fun, Args) -> + case mnesia:transaction(Fun, Args) of + {atomic, Res} -> Res; + {aborted, Reason} -> {error, Reason} + end. + + +to_binary(B) when is_binary(B) -> + B; +to_binary(L) when is_list(L) -> + iolist_to_binary(L). diff --git a/apps/emqx_authentication/src/emqx_authentication_sup.erl b/apps/emqx_authentication/src/emqx_authentication_sup.erl new file mode 100644 index 000000000..06e12ce6c --- /dev/null +++ b/apps/emqx_authentication/src/emqx_authentication_sup.erl @@ -0,0 +1,29 @@ +%%-------------------------------------------------------------------- +%% Copyright (c) 2021 EMQ Technologies Co., Ltd. All Rights Reserved. +%% +%% Licensed under the Apache License, Version 2.0 (the "License"); +%% you may not use this file except in compliance with the License. +%% You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, software +%% distributed under the License is distributed on an "AS IS" BASIS, +%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +%% See the License for the specific language governing permissions and +%% limitations under the License. +%%-------------------------------------------------------------------- + +-module(emqx_authentication_sup). + +-behaviour(supervisor). + +-export([ start_link/0 + , init/1 + ]). + +start_link() -> + supervisor:start_link({local, ?MODULE}, ?MODULE, []). + +init([]) -> + {ok, {{one_for_one, 10, 10}, []}}. diff --git a/apps/emqx_authentication/test/data/user-credentials.csv b/apps/emqx_authentication/test/data/user-credentials.csv new file mode 100644 index 000000000..2543d39ca --- /dev/null +++ b/apps/emqx_authentication/test/data/user-credentials.csv @@ -0,0 +1,3 @@ +user_id,password_hash,salt +myuser3,b6c743545a7817ae8c8f624371d5f5f0373234bb0ff36b8ffbf19bce0e06ab75,de1024f462fb83910fd13151bd4bd235 +myuser4,ee68c985a69208b6eda8c6c9b4c7c2d2b15ee2352cdd64a903171710a99182e8,ad773b5be9dd0613fe6c2f4d8c403139 diff --git a/apps/emqx_authentication/test/data/user-credentials.json b/apps/emqx_authentication/test/data/user-credentials.json new file mode 100644 index 000000000..169122bd2 --- /dev/null +++ b/apps/emqx_authentication/test/data/user-credentials.json @@ -0,0 +1,12 @@ +[ + { + "user_id":"myuser1", + "password_hash":"c5e46903df45e5dc096dc74657610dbee8deaacae656df88a1788f1847390242", + "salt": "e378187547bf2d6f0545a3f441aa4d8a" + }, + { + "user_id":"myuser2", + "password_hash":"f4d17f300b11e522fd33f497c11b126ef1ea5149c74d2220f9a16dc876d4567b", + "salt": "6d3f9bd5b54d94b98adbcfe10b6d181f" + } +] diff --git a/apps/emqx_authentication/test/emqx_authentication_SUITE.erl b/apps/emqx_authentication/test/emqx_authentication_SUITE.erl new file mode 100644 index 000000000..d110d940a --- /dev/null +++ b/apps/emqx_authentication/test/emqx_authentication_SUITE.erl @@ -0,0 +1,191 @@ +%%-------------------------------------------------------------------- +%% Copyright (c) 2020-2021 EMQ Technologies Co., Ltd. All Rights Reserved. +%% +%% Licensed under the Apache License, Version 2.0 (the "License"); +%% you may not use this file except in compliance with the License. +%% You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, software +%% distributed under the License is distributed on an "AS IS" BASIS, +%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +%% See the License for the specific language governing permissions and +%% limitations under the License. +%%-------------------------------------------------------------------- + +-module(emqx_authentication_SUITE). + +-compile(export_all). +-compile(nowarn_export_all). + +-include_lib("common_test/include/ct.hrl"). +-include_lib("eunit/include/eunit.hrl"). + +-define(AUTH, emqx_authentication). + +all() -> + emqx_ct:all(?MODULE). + +init_per_suite(Config) -> + emqx_ct_helpers:start_apps([emqx_authentication]), + Config. + +end_per_suite(_) -> + emqx_ct_helpers:stop_apps([emqx_authentication]), + ok. + +t_chain(_) -> + ChainID = <<"mychain">>, + ?assertMatch({ok, #{id := ChainID, services := []}}, ?AUTH:create_chain(#{id => ChainID})), + ?assertEqual({error, {already_exists, {chain, ChainID}}}, ?AUTH:create_chain(#{id => ChainID})), + ?assertMatch({ok, #{id := ChainID, services := []}}, ?AUTH:lookup_chain(ChainID)), + ?assertEqual(ok, ?AUTH:delete_chain(ChainID)), + ?assertMatch({error, {not_found, {chain, ChainID}}}, ?AUTH:lookup_chain(ChainID)), + ok. + +t_service(_) -> + ChainID = <<"mychain">>, + ?assertMatch({ok, #{id := ChainID, services := []}}, ?AUTH:create_chain(#{id => ChainID})), + ?assertMatch({ok, #{id := ChainID, services := []}}, ?AUTH:lookup_chain(ChainID)), + + ServiceName1 = <<"myservice1">>, + ServiceParams1 = #{name => ServiceName1, + type => mnesia, + params => #{ + user_id_type => <<"username">>, + password_hash_algorithm => <<"sha256">>}}, + ?assertEqual({ok, [ServiceParams1]}, ?AUTH:add_services(ChainID, [ServiceParams1])), + ?assertEqual({ok, ServiceParams1}, ?AUTH:lookup_service(ChainID, ServiceName1)), + ?assertEqual({ok, [ServiceParams1]}, ?AUTH:list_services(ChainID)), + ?assertEqual({error, {already_exists, {service, ServiceName1}}}, ?AUTH:add_services(ChainID, [ServiceParams1])), + + ServiceName2 = <<"myservice2">>, + ServiceParams2 = ServiceParams1#{name => ServiceName2}, + ?assertEqual({ok, [ServiceParams2]}, ?AUTH:add_services(ChainID, [ServiceParams2])), + ?assertMatch({ok, #{id := ChainID, services := [ServiceParams1, ServiceParams2]}}, ?AUTH:lookup_chain(ChainID)), + ?assertEqual({ok, ServiceParams2}, ?AUTH:lookup_service(ChainID, ServiceName2)), + ?assertEqual({ok, [ServiceParams1, ServiceParams2]}, ?AUTH:list_services(ChainID)), + + ?assertEqual(ok, ?AUTH:move_service_to_the_front(ChainID, ServiceName2)), + ?assertEqual({ok, [ServiceParams2, ServiceParams1]}, ?AUTH:list_services(ChainID)), + ?assertEqual(ok, ?AUTH:move_service_to_the_end(ChainID, ServiceName2)), + ?assertEqual({ok, [ServiceParams1, ServiceParams2]}, ?AUTH:list_services(ChainID)), + ?assertEqual(ok, ?AUTH:move_service_to_the_nth(ChainID, ServiceName2, 1)), + ?assertEqual({ok, [ServiceParams2, ServiceParams1]}, ?AUTH:list_services(ChainID)), + ?assertEqual({error, out_of_range}, ?AUTH:move_service_to_the_nth(ChainID, ServiceName2, 3)), + ?assertEqual({error, out_of_range}, ?AUTH:move_service_to_the_nth(ChainID, ServiceName2, 0)), + ?assertEqual(ok, ?AUTH:delete_services(ChainID, [ServiceName1, ServiceName2])), + ?assertEqual({ok, []}, ?AUTH:list_services(ChainID)), + ?assertEqual(ok, ?AUTH:delete_chain(ChainID)), + ok. + +t_mnesia_service(_) -> + ChainID = <<"mychain">>, + ?assertMatch({ok, #{id := ChainID, services := []}}, ?AUTH:create_chain(#{id => ChainID})), + + ServiceName = <<"myservice">>, + ServiceParams = #{name => ServiceName, + type => mnesia, + params => #{ + user_id_type => <<"username">>, + password_hash_algorithm => <<"sha256">>}}, + ?assertEqual({ok, [ServiceParams]}, ?AUTH:add_services(ChainID, [ServiceParams])), + + UserInfo = #{<<"user_id">> => <<"myuser">>, + <<"password">> => <<"mypass">>}, + ?assertEqual({ok, #{user_id => <<"myuser">>}}, ?AUTH:add_user(ChainID, ServiceName, UserInfo)), + ?assertEqual({ok, #{user_id => <<"myuser">>}}, ?AUTH:lookup_user(ChainID, ServiceName, <<"myuser">>)), + ClientInfo = #{chain_id => ChainID, + username => <<"myuser">>, + password => <<"mypass">>}, + ?assertEqual(ok, ?AUTH:authenticate(ClientInfo)), + ClientInfo2 = ClientInfo#{username => <<"baduser">>}, + ?assertEqual({error, user_not_found}, ?AUTH:authenticate(ClientInfo2)), + ClientInfo3 = ClientInfo#{password => <<"badpass">>}, + ?assertEqual({error, bad_password}, ?AUTH:authenticate(ClientInfo3)), + UserInfo2 = UserInfo#{<<"password">> => <<"mypass2">>}, + ?assertEqual({ok, #{user_id => <<"myuser">>}}, ?AUTH:update_user(ChainID, ServiceName, <<"myuser">>, UserInfo2)), + ClientInfo4 = ClientInfo#{password => <<"mypass2">>}, + ?assertEqual(ok, ?AUTH:authenticate(ClientInfo4)), + ?assertEqual(ok, ?AUTH:delete_user(ChainID, ServiceName, <<"myuser">>)), + ?assertEqual({error, not_found}, ?AUTH:lookup_user(ChainID, ServiceName, <<"myuser">>)), + + ?assertEqual({ok, #{user_id => <<"myuser">>}}, ?AUTH:add_user(ChainID, ServiceName, UserInfo)), + ?assertMatch({ok, #{user_id := <<"myuser">>}}, ?AUTH:lookup_user(ChainID, ServiceName, <<"myuser">>)), + ?assertEqual(ok, ?AUTH:delete_services(ChainID, [ServiceName])), + ?assertEqual({ok, [ServiceParams]}, ?AUTH:add_services(ChainID, [ServiceParams])), + ?assertMatch({error, not_found}, ?AUTH:lookup_user(ChainID, ServiceName, <<"myuser">>)), + + ?assertEqual(ok, ?AUTH:delete_chain(ChainID)), + ?assertEqual([], ets:tab2list(mnesia_basic_auth)), + ok. + +t_import(_) -> + ChainID = <<"mychain">>, + ?assertMatch({ok, #{id := ChainID, services := []}}, ?AUTH:create_chain(#{id => ChainID})), + + ServiceName = <<"myservice">>, + ServiceParams = #{name => ServiceName, + type => mnesia, + params => #{ + user_id_type => <<"username">>, + password_hash_algorithm => <<"sha256">>}}, + ?assertEqual({ok, [ServiceParams]}, ?AUTH:add_services(ChainID, [ServiceParams])), + + Dir = code:lib_dir(emqx_authentication, test), + ?assertEqual(ok, ?AUTH:import_users(ChainID, ServiceName, filename:join([Dir, "data/user-credentials.json"]))), + ?assertEqual(ok, ?AUTH:import_users(ChainID, ServiceName, filename:join([Dir, "data/user-credentials.csv"]))), + ?assertMatch({ok, #{user_id := <<"myuser1">>}}, ?AUTH:lookup_user(ChainID, ServiceName, <<"myuser1">>)), + ?assertMatch({ok, #{user_id := <<"myuser3">>}}, ?AUTH:lookup_user(ChainID, ServiceName, <<"myuser3">>)), + ClientInfo1 = #{chain_id => ChainID, + username => <<"myuser1">>, + password => <<"mypassword1">>}, + ?assertEqual(ok, ?AUTH:authenticate(ClientInfo1)), + ClientInfo2 = ClientInfo1#{username => <<"myuser3">>, + password => <<"mypassword3">>}, + ?assertEqual(ok, ?AUTH:authenticate(ClientInfo2)), + ?assertEqual(ok, ?AUTH:delete_chain(ChainID)), + ok. + +t_multi_mnesia_service(_) -> + ChainID = <<"mychain">>, + ?assertMatch({ok, #{id := ChainID, services := []}}, ?AUTH:create_chain(#{id => ChainID})), + + ServiceName1 = <<"myservice1">>, + ServiceParams1 = #{name => ServiceName1, + type => mnesia, + params => #{ + user_id_type => <<"username">>, + password_hash_algorithm => <<"sha256">>}}, + ServiceName2 = <<"myservice2">>, + ServiceParams2 = #{name => ServiceName2, + type => mnesia, + params => #{ + user_id_type => <<"clientid">>, + password_hash_algorithm => <<"sha256">>}}, + ?assertEqual({ok, [ServiceParams1]}, ?AUTH:add_services(ChainID, [ServiceParams1])), + ?assertEqual({ok, [ServiceParams2]}, ?AUTH:add_services(ChainID, [ServiceParams2])), + + ?assertEqual({ok, #{user_id => <<"myuser">>}}, + ?AUTH:add_user(ChainID, ServiceName1, + #{<<"user_id">> => <<"myuser">>, + <<"password">> => <<"mypass1">>})), + ?assertEqual({ok, #{user_id => <<"myclient">>}}, + ?AUTH:add_user(ChainID, ServiceName2, + #{<<"user_id">> => <<"myclient">>, + <<"password">> => <<"mypass2">>})), + ClientInfo1 = #{chain_id => ChainID, + username => <<"myuser">>, + clientid => <<"myclient">>, + password => <<"mypass1">>}, + ?assertEqual(ok, ?AUTH:authenticate(ClientInfo1)), + ?assertEqual(ok, ?AUTH:move_service_to_the_front(ChainID, ServiceName2)), + ?assertEqual({error, bad_password}, ?AUTH:authenticate(ClientInfo1)), + ClientInfo2 = ClientInfo1#{password => <<"mypass2">>}, + ?assertEqual(ok, ?AUTH:authenticate(ClientInfo2)), + ?assertEqual(ok, ?AUTH:delete_chain(ChainID)), + ok. + + + diff --git a/rebar.config.erl b/rebar.config.erl index 5cc857398..a38aa610a 100644 --- a/rebar.config.erl +++ b/rebar.config.erl @@ -279,6 +279,7 @@ relx_plugin_apps(ReleaseType) -> , emqx_sn , emqx_coap , emqx_stomp + , emqx_authentication , emqx_auth_http , emqx_auth_mysql , emqx_auth_jwt