refactor(auth_mnesia): Export transaction funs

This commit is contained in:
ieQu1 2022-08-18 13:52:07 +02:00
parent c89de81e19
commit 9449e3cb32
2 changed files with 137 additions and 132 deletions

View File

@ -52,6 +52,14 @@
group_match_spec/1 group_match_spec/1
]). ]).
%% Internal exports (RPC)
-export([
do_destroy/1,
do_add_user/2,
do_delete_user/2,
do_update_user/3
]).
-define(TAB, ?MODULE). -define(TAB, ?MODULE).
-define(AUTHN_QSCHEMA, [ -define(AUTHN_QSCHEMA, [
{<<"like_user_id">>, binary}, {<<"like_user_id">>, binary},
@ -170,83 +178,79 @@ authenticate(_Credential, _State) ->
ignore. ignore.
destroy(#{user_group := UserGroup}) -> destroy(#{user_group := UserGroup}) ->
trans(fun ?MODULE:do_destroy/1, [UserGroup]).
do_destroy(UserGroup) ->
MatchSpec = group_match_spec(UserGroup), MatchSpec = group_match_spec(UserGroup),
trans( ok = lists:foreach(
fun() -> fun(UserInfo) ->
ok = lists:foreach( mnesia:delete_object(?TAB, UserInfo, write)
fun(UserInfo) -> end,
mnesia:delete_object(?TAB, UserInfo, write) mnesia:select(?TAB, MatchSpec, write)
end,
mnesia:select(?TAB, MatchSpec, write)
)
end
). ).
add_user( add_user(UserInfo, State) ->
trans(fun ?MODULE:do_add_user/2, [UserInfo, State]).
do_add_user(
#{ #{
user_id := UserID, user_id := UserID,
password := Password password := Password
} = UserInfo, } = UserInfo,
#{user_group := UserGroup} = State #{user_group := UserGroup} = State
) -> ) ->
trans( case mnesia:read(?TAB, {UserGroup, UserID}, write) of
fun() -> [] ->
case mnesia:read(?TAB, {UserGroup, UserID}, write) of IsSuperuser = maps:get(is_superuser, UserInfo, false),
[] -> add_user(UserGroup, UserID, Password, IsSuperuser, State),
IsSuperuser = maps:get(is_superuser, UserInfo, false), {ok, #{user_id => UserID, is_superuser => IsSuperuser}};
add_user(UserGroup, UserID, Password, IsSuperuser, State), [_] ->
{ok, #{user_id => UserID, is_superuser => IsSuperuser}}; {error, already_exist}
[_] -> end.
{error, already_exist}
end
end
).
delete_user(UserID, #{user_group := UserGroup}) -> delete_user(UserID, State) ->
trans( trans(fun ?MODULE:do_delete_user/2, [UserID, State]).
fun() ->
case mnesia:read(?TAB, {UserGroup, UserID}, write) of
[] ->
{error, not_found};
[_] ->
mnesia:delete(?TAB, {UserGroup, UserID}, write)
end
end
).
update_user( do_delete_user(UserID, #{user_group := UserGroup}) ->
case mnesia:read(?TAB, {UserGroup, UserID}, write) of
[] ->
{error, not_found};
[_] ->
mnesia:delete(?TAB, {UserGroup, UserID}, write)
end.
update_user(UserID, User, State) ->
trans(fun ?MODULE:do_update_user/3, [UserID, User, State]).
do_update_user(
UserID, UserID,
User, User,
#{user_group := UserGroup} = State #{user_group := UserGroup} = State
) -> ) ->
trans( case mnesia:read(?TAB, {UserGroup, UserID}, write) of
fun() -> [] ->
case mnesia:read(?TAB, {UserGroup, UserID}, write) of {error, not_found};
[] -> [#user_info{is_superuser = IsSuperuser} = UserInfo] ->
{error, not_found}; UserInfo1 = UserInfo#user_info{
[#user_info{is_superuser = IsSuperuser} = UserInfo] -> is_superuser = maps:get(is_superuser, User, IsSuperuser)
UserInfo1 = UserInfo#user_info{ },
is_superuser = maps:get(is_superuser, User, IsSuperuser) UserInfo2 =
}, case maps:get(password, User, undefined) of
UserInfo2 = undefined ->
case maps:get(password, User, undefined) of UserInfo1;
undefined -> Password ->
UserInfo1; {StoredKey, ServerKey, Salt} = esasl_scram:generate_authentication_info(
Password -> Password, State
{StoredKey, ServerKey, Salt} = esasl_scram:generate_authentication_info( ),
Password, State UserInfo1#user_info{
), stored_key = StoredKey,
UserInfo1#user_info{ server_key = ServerKey,
stored_key = StoredKey, salt = Salt
server_key = ServerKey, }
salt = Salt end,
} mnesia:write(?TAB, UserInfo2, write),
end, {ok, format_user_info(UserInfo2)}
mnesia:write(?TAB, UserInfo2, write), end.
{ok, format_user_info(UserInfo2)}
end
end
).
lookup_user(UserID, #{user_group := UserGroup}) -> lookup_user(UserID, #{user_group := UserGroup}) ->
case mnesia:dirty_read(?TAB, {UserGroup, UserID}) of case mnesia:dirty_read(?TAB, {UserGroup, UserID}) of
@ -386,12 +390,10 @@ retrieve(UserID, #{user_group := UserGroup}) ->
end. end.
%% TODO: Move to emqx_authn_utils.erl %% TODO: Move to emqx_authn_utils.erl
trans(Fun) ->
trans(Fun, []).
trans(Fun, Args) -> trans(Fun, Args) ->
case mria:transaction(?AUTH_SHARD, Fun, Args) of case mria:transaction(?AUTH_SHARD, Fun, Args) of
{atomic, Res} -> Res; {atomic, Res} -> Res;
{aborted, {function_clause, Stack}} -> erlang:raise(error, function_clause, Stack);
{aborted, Reason} -> {error, Reason} {aborted, Reason} -> {error, Reason}
end. end.

View File

@ -54,6 +54,16 @@
group_match_spec/1 group_match_spec/1
]). ]).
%% Internal exports (RPC)
-export([
do_destroy/1,
do_add_user/2,
do_delete_user/2,
do_update_user/3,
import/2,
import_csv/3
]).
-type user_group() :: binary(). -type user_group() :: binary().
-type user_id() :: binary(). -type user_id() :: binary().
@ -175,15 +185,14 @@ authenticate(
end. end.
destroy(#{user_group := UserGroup}) -> destroy(#{user_group := UserGroup}) ->
trans( trans(fun ?MODULE:do_destroy/1, [UserGroup]).
fun() ->
ok = lists:foreach( do_destroy(UserGroup) ->
fun(User) -> ok = lists:foreach(
mnesia:delete_object(?TAB, User, write) fun(User) ->
end, mnesia:delete_object(?TAB, User, write)
mnesia:select(?TAB, group_match_spec(UserGroup), write) end,
) mnesia:select(?TAB, group_match_spec(UserGroup), write)
end
). ).
import_users({Filename0, FileData}, State) -> import_users({Filename0, FileData}, State) ->
@ -200,7 +209,10 @@ import_users({Filename0, FileData}, State) ->
{error, {unsupported_file_format, Extension}} {error, {unsupported_file_format, Extension}}
end. end.
add_user( add_user(UserInfo, State) ->
trans(fun ?MODULE:do_add_user/2, [UserInfo, State]).
do_add_user(
#{ #{
user_id := UserID, user_id := UserID,
password := Password password := Password
@ -210,33 +222,31 @@ add_user(
password_hash_algorithm := Algorithm password_hash_algorithm := Algorithm
} }
) -> ) ->
trans( case mnesia:read(?TAB, {UserGroup, UserID}, write) of
fun() -> [] ->
case mnesia:read(?TAB, {UserGroup, UserID}, write) of {PasswordHash, Salt} = emqx_authn_password_hashing:hash(Algorithm, Password),
[] -> IsSuperuser = maps:get(is_superuser, UserInfo, false),
{PasswordHash, Salt} = emqx_authn_password_hashing:hash(Algorithm, Password), insert_user(UserGroup, UserID, PasswordHash, Salt, IsSuperuser),
IsSuperuser = maps:get(is_superuser, UserInfo, false), {ok, #{user_id => UserID, is_superuser => IsSuperuser}};
insert_user(UserGroup, UserID, PasswordHash, Salt, IsSuperuser), [_] ->
{ok, #{user_id => UserID, is_superuser => IsSuperuser}}; {error, already_exist}
[_] -> end.
{error, already_exist}
end
end
).
delete_user(UserID, #{user_group := UserGroup}) -> delete_user(UserID, State) ->
trans( trans(fun ?MODULE:do_delete_user/2, [UserID, State]).
fun() ->
case mnesia:read(?TAB, {UserGroup, UserID}, write) of
[] ->
{error, not_found};
[_] ->
mnesia:delete(?TAB, {UserGroup, UserID}, write)
end
end
).
update_user( do_delete_user(UserID, #{user_group := UserGroup}) ->
case mnesia:read(?TAB, {UserGroup, UserID}, write) of
[] ->
{error, not_found};
[_] ->
mnesia:delete(?TAB, {UserGroup, UserID}, write)
end.
update_user(UserID, UserInfo, State) ->
trans(fun ?MODULE:do_update_user/3, [UserID, UserInfo, State]).
do_update_user(
UserID, UserID,
UserInfo, UserInfo,
#{ #{
@ -244,33 +254,29 @@ update_user(
password_hash_algorithm := Algorithm password_hash_algorithm := Algorithm
} }
) -> ) ->
trans( case mnesia:read(?TAB, {UserGroup, UserID}, write) of
fun() -> [] ->
case mnesia:read(?TAB, {UserGroup, UserID}, write) of {error, not_found};
[] -> [
{error, not_found}; #user_info{
[ password_hash = PasswordHash,
#user_info{ salt = Salt,
password_hash = PasswordHash, is_superuser = IsSuperuser
salt = Salt, }
is_superuser = IsSuperuser ] ->
} NSuperuser = maps:get(is_superuser, UserInfo, IsSuperuser),
] -> {NPasswordHash, NSalt} =
NSuperuser = maps:get(is_superuser, UserInfo, IsSuperuser), case UserInfo of
{NPasswordHash, NSalt} = #{password := Password} ->
case UserInfo of emqx_authn_password_hashing:hash(
#{password := Password} -> Algorithm, Password
emqx_authn_password_hashing:hash( );
Algorithm, Password #{} ->
); {PasswordHash, Salt}
#{} -> end,
{PasswordHash, Salt} insert_user(UserGroup, UserID, NPasswordHash, NSalt, NSuperuser),
end, {ok, #{user_id => UserID, is_superuser => NSuperuser}}
insert_user(UserGroup, UserID, NPasswordHash, NSalt, NSuperuser), end.
{ok, #{user_id => UserID, is_superuser => NSuperuser}}
end
end
).
lookup_user(UserID, #{user_group := UserGroup}) -> lookup_user(UserID, #{user_group := UserGroup}) ->
case mnesia:dirty_read(?TAB, {UserGroup, UserID}) of case mnesia:dirty_read(?TAB, {UserGroup, UserID}) of
@ -335,7 +341,7 @@ run_fuzzy_filter(
import_users_from_json(Bin, #{user_group := UserGroup}) -> import_users_from_json(Bin, #{user_group := UserGroup}) ->
case emqx_json:safe_decode(Bin, [return_maps]) of case emqx_json:safe_decode(Bin, [return_maps]) of
{ok, List} -> {ok, List} ->
trans(fun import/2, [UserGroup, List]); trans(fun ?MODULE:import/2, [UserGroup, List]);
{error, Reason} -> {error, Reason} ->
{error, Reason} {error, Reason}
end. end.
@ -344,7 +350,7 @@ import_users_from_json(Bin, #{user_group := UserGroup}) ->
import_users_from_csv(CSV, #{user_group := UserGroup}) -> import_users_from_csv(CSV, #{user_group := UserGroup}) ->
case get_csv_header(CSV) of case get_csv_header(CSV) of
{ok, Seq, NewCSV} -> {ok, Seq, NewCSV} ->
trans(fun import_csv/3, [UserGroup, NewCSV, Seq]); trans(fun ?MODULE:import_csv/3, [UserGroup, NewCSV, Seq]);
{error, Reason} -> {error, Reason} ->
{error, Reason} {error, Reason}
end. end.
@ -435,9 +441,6 @@ get_user_identity(#{clientid := ClientID}, clientid) ->
get_user_identity(_, Type) -> get_user_identity(_, Type) ->
{error, {bad_user_identity_type, Type}}. {error, {bad_user_identity_type, Type}}.
trans(Fun) ->
trans(Fun, []).
trans(Fun, Args) -> trans(Fun, Args) ->
case mria:transaction(?AUTH_SHARD, Fun, Args) of case mria:transaction(?AUTH_SHARD, Fun, Args) of
{atomic, Res} -> Res; {atomic, Res} -> Res;