diff --git a/apps/emqx_auth_mnesia/include/emqx_auth_mnesia.hrl b/apps/emqx_auth_mnesia/include/emqx_auth_mnesia.hrl index 14fdbf0a6..143f6b61e 100644 --- a/apps/emqx_auth_mnesia/include/emqx_auth_mnesia.hrl +++ b/apps/emqx_auth_mnesia/include/emqx_auth_mnesia.hrl @@ -5,6 +5,8 @@ -type(acl_target() :: login() | all). +-type(acl_target_type() :: clientid | username | all). + -type(access():: allow | deny). -type(action():: pub | sub). -type(legacy_action():: action() | pubsub). diff --git a/apps/emqx_auth_mnesia/src/emqx_acl_mnesia_api.erl b/apps/emqx_auth_mnesia/src/emqx_acl_mnesia_api.erl index 3e9a8fe93..10615b3e0 100644 --- a/apps/emqx_auth_mnesia/src/emqx_acl_mnesia_api.erl +++ b/apps/emqx_auth_mnesia/src/emqx_acl_mnesia_api.erl @@ -97,11 +97,11 @@ ]). list_clientid(_Bindings, Params) -> - Table = emqx_acl_mnesia_db:login_acl_table({clientid, '_'}), + Table = emqx_acl_mnesia_db:login_acl_table(clientid), return({ok, emqx_auth_mnesia_api:paginate_qh(Table, count(Table), Params, fun emqx_acl_mnesia_db:comparing/2, fun format/1)}). list_username(_Bindings, Params) -> - Table = emqx_acl_mnesia_db:login_acl_table({username, '_'}), + Table = emqx_acl_mnesia_db:login_acl_table(username), return({ok, emqx_auth_mnesia_api:paginate_qh(Table, count(Table), Params, fun emqx_acl_mnesia_db:comparing/2, fun format/1)}). list_all(_Bindings, Params) -> diff --git a/apps/emqx_auth_mnesia/src/emqx_acl_mnesia_cli.erl b/apps/emqx_auth_mnesia/src/emqx_acl_mnesia_cli.erl index e7ffe928b..145f0ede8 100644 --- a/apps/emqx_auth_mnesia/src/emqx_acl_mnesia_cli.erl +++ b/apps/emqx_auth_mnesia/src/emqx_acl_mnesia_cli.erl @@ -26,10 +26,10 @@ cli(["list"]) -> [print_acl(Acl) || Acl <- emqx_acl_mnesia_db:all_acls()]; cli(["list", "clientid"]) -> - [print_acl(Acl) || Acl <- emqx_acl_mnesia_db:all_acls({clientid, '_'})]; + [print_acl(Acl) || Acl <- emqx_acl_mnesia_db:all_acls(clientid)]; cli(["list", "username"]) -> - [print_acl(Acl) || Acl <- emqx_acl_mnesia_db:all_acls({username, '_'})]; + [print_acl(Acl) || Acl <- emqx_acl_mnesia_db:all_acls(username)]; cli(["list", "_all"]) -> [print_acl(Acl) || Acl <- emqx_acl_mnesia_db:all_acls(all)]; diff --git a/apps/emqx_auth_mnesia/src/emqx_acl_mnesia_db.erl b/apps/emqx_auth_mnesia/src/emqx_acl_mnesia_db.erl index dc728db3b..b483e59df 100644 --- a/apps/emqx_auth_mnesia/src/emqx_acl_mnesia_db.erl +++ b/apps/emqx_auth_mnesia/src/emqx_acl_mnesia_db.erl @@ -20,7 +20,7 @@ -include_lib("stdlib/include/ms_transform.hrl"). -include_lib("stdlib/include/qlc.hrl"). -%% Acl APIs +%% ACL APIs -export([ create_table/0 , create_table2/0 ]). @@ -39,7 +39,7 @@ -export([comparing/2]). %%-------------------------------------------------------------------- -%% Acl API +%% ACL API %%-------------------------------------------------------------------- %% @doc Create table `emqx_acl` of old format rules @@ -87,7 +87,7 @@ lookup_acl(Login) -> MergedAcl = merge_acl_records(Login, OldRecs, NewAcls), lists:sort(fun comparing/2, acl_to_list(MergedAcl)). -%% @doc Remove acl +%% @doc Remove ACL -spec remove_acl(acl_target(), emqx_topic:topic()) -> ok | {error, any()}. remove_acl(Login, Topic) -> ret(mnesia:transaction(fun() -> @@ -103,19 +103,24 @@ remove_acl(Login, Topic) -> end end)). -%% @doc All Acl rules +%% @doc All ACL rules -spec(all_acls() -> list(acl_record())). all_acls() -> - all_acls({username, '_'}) ++ - all_acls({clientid, '_'}) ++ + all_acls(username) ++ + all_acls(clientid) ++ all_acls(all). -%% @doc All Acl rules transactionally +%% @doc All ACL rules of specified type +-spec(all_acls(acl_target_type()) -> list(acl_record())). +all_acls(AclTargetType) -> + lists:sort(fun comparing/2, qlc:eval(login_acl_table(AclTargetType))). + +%% @doc All ACL rules fetched transactionally -spec(all_acls_export() -> list(acl_record())). all_acls_export() -> - LoginSpecs = [{username, '_'}, {clientid, '_'}, all], - MatchSpecNew = lists:flatmap(fun login_match_spec_new/1, LoginSpecs), - MatchSpecOld = lists:flatmap(fun login_match_spec_old/1, LoginSpecs), + AclTargetTypes = [username, clientid, all], + MatchSpecNew = lists:flatmap(fun login_match_spec_new/1, AclTargetTypes), + MatchSpecOld = lists:flatmap(fun login_match_spec_old/1, AclTargetTypes), {atomic, Records} = mnesia:transaction( fun() -> @@ -125,10 +130,10 @@ all_acls_export() -> Records. %% @doc QLC table of logins matching spec --spec(login_acl_table(ets:match_pattern()) -> qlc:query_handle()). -login_acl_table(LoginSpec) -> - MatchSpecNew = login_match_spec_new(LoginSpec), - MatchSpecOld = login_match_spec_old(LoginSpec), +-spec(login_acl_table(acl_target_type()) -> qlc:query_handle()). +login_acl_table(AclTargetType) -> + MatchSpecNew = login_match_spec_new(AclTargetType), + MatchSpecOld = login_match_spec_old(AclTargetType), acl_table(MatchSpecNew, MatchSpecOld, fun ets:table/2, fun lookup_ets/2). %% @doc Combine old `emqx_acl` ACL records with a new `emqx_acl2` ACL record for a given login @@ -183,9 +188,6 @@ add_acl_old(Login, Topic, Action, Access) -> update_permission(Action, Acl, OldRecords) end. -all_acls(LoginSpec) -> - lists:sort(fun comparing/2, qlc:eval(login_acl_table(LoginSpec))). - old_recs_to_rules(OldRecs) -> lists:flatmap(fun old_rec_to_rules/1, OldRecs). @@ -221,11 +223,25 @@ comparing({_, _, _, _, CreatedAt1}, {_, _, _, _, CreatedAt2}) -> CreatedAt1 >= CreatedAt2. -login_match_spec_new(LoginSpec) -> - [{{?ACL_TABLE2, LoginSpec, '_'}, [], ['$_']}]. +login_match_spec_old(all) -> + ets:fun2ms(fun(#?ACL_TABLE{filter = {all, _}} = Record) -> + Record + end); -login_match_spec_old(LoginSpec) -> - [{{?ACL_TABLE, {LoginSpec, '_'}, '_', '_', '_'}, [], ['$_']}]. +login_match_spec_old(Type) when (Type =:= username) or (Type =:= clientid) -> + ets:fun2ms(fun(#?ACL_TABLE{filter = {{RecordType, _}, _}} = Record) + when RecordType =:= Type -> Record + end). + +login_match_spec_new(all) -> + ets:fun2ms(fun(#?ACL_TABLE2{who = all} = Record) -> + Record + end); + +login_match_spec_new(Type) when (Type =:= username) or (Type =:= clientid) -> + ets:fun2ms(fun(#?ACL_TABLE2{who = {RecordType, _}} = Record) + when RecordType =:= Type -> Record + end). acl_table(MatchSpecNew, MatchSpecOld, TableFun, LookupFun) -> TraverseFun =