refactor(tls): move ssl files handling to emqx_tls_lib

This is an attempt ot make it more generic for other APPs to use.
Aslo added test cases to cover most of the code paths.
This commit is contained in:
Zaiming Shi 2021-10-23 18:13:31 +02:00 committed by x1001100011
parent d796f815d7
commit fdf439bc7b
6 changed files with 261 additions and 116 deletions

View File

@ -65,6 +65,7 @@
, remove_config/1
, remove_config/2
, reset_config/2
, data_dir/0
]).
-define(APP, ?MODULE).
@ -246,3 +247,6 @@ reset_config([RootName | _] = KeyPath, Opts) ->
{error, _} = Error ->
Error
end.
data_dir() ->
application:get_env(emqx, data_dir, "data").

View File

@ -27,9 +27,8 @@
, authn_type/1
]).
%% TODO: certs handling should be moved out of emqx app
-ifdef(TEST).
-export([convert_certs/2, convert_certs/3, diff_cert/2, clear_certs/2]).
-export([convert_certs/2, convert_certs/3, clear_certs/2]).
-endif.
-export_type([config/0]).
@ -64,7 +63,7 @@ pre_config_update(UpdateReq, OldConfig) ->
do_pre_config_update({create_authenticator, ChainName, Config}, OldConfig) ->
try
CertsDir = certs_dir([to_bin(ChainName), authenticator_id(Config)]),
CertsDir = certs_dir(ChainName, Config),
NConfig = convert_certs(CertsDir, Config),
{ok, OldConfig ++ [NConfig]}
catch
@ -80,7 +79,7 @@ do_pre_config_update({delete_authenticator, _ChainName, AuthenticatorID}, OldCon
{ok, NewConfig};
do_pre_config_update({update_authenticator, ChainName, AuthenticatorID, Config}, OldConfig) ->
try
CertsDir = certs_dir([to_bin(ChainName), AuthenticatorID]),
CertsDir = certs_dir(ChainName, AuthenticatorID),
NewConfig = lists:map(
fun(OldConfig0) ->
case AuthenticatorID =:= authenticator_id(OldConfig0) of
@ -127,9 +126,8 @@ do_post_config_update({delete_authenticator, ChainName, AuthenticatorID}, _NewCo
case emqx_authentication:delete_authenticator(ChainName, AuthenticatorID) of
ok ->
[Config] = [Config0 || Config0 <- to_list(OldConfig), AuthenticatorID == authenticator_id(Config0)],
CertsDir = certs_dir([to_bin(ChainName), AuthenticatorID]),
clear_certs(CertsDir, Config),
ok;
CertsDir = certs_dir(ChainName, AuthenticatorID),
ok = clear_certs(CertsDir, Config);
{error, Reason} ->
{error, Reason}
end;
@ -193,105 +191,33 @@ to_list(M) when M =:= #{} -> [];
to_list(M) when is_map(M) -> [M];
to_list(L) when is_list(L) -> L.
certs_dir(Dirs) when is_list(Dirs) ->
to_bin(filename:join([emqx:get_config([node, data_dir]), "certs", "authn"] ++ Dirs)).
convert_certs(CertsDir, Config) ->
case maps:get(<<"ssl">>, Config, undefined) of
undefined ->
Config;
SSLOpts ->
NSSLOPts = lists:foldl(fun(K, Acc) ->
case maps:get(K, Acc, undefined) of
undefined -> Acc;
PemBin ->
CertFile = generate_filename(CertsDir, K),
ok = save_cert_to_file(CertFile, PemBin),
Acc#{K => CertFile}
end
end, SSLOpts, [<<"certfile">>, <<"keyfile">>, <<"cacertfile">>]),
Config#{<<"ssl">> => NSSLOPts}
case emqx_tls_lib:ensure_ssl_files(CertsDir, maps:get(<<"ssl">>, Config, undefined)) of
{ok, SSL} ->
new_ssl_config(Config, SSL);
{error, Reason} ->
?SLOG(error, Reason#{msg => bad_ssl_config}),
throw(bad_ssl_config)
end.
convert_certs(CertsDir, NewConfig, OldConfig) ->
case maps:get(<<"ssl">>, NewConfig, undefined) of
undefined ->
NewConfig;
NewSSLOpts ->
OldSSLOpts = maps:get(<<"ssl">>, OldConfig, #{}),
Diff = diff_certs(NewSSLOpts, OldSSLOpts),
NSSLOpts = lists:foldl(fun({identical, K}, Acc) ->
Acc#{K => maps:get(K, OldSSLOpts)};
({_, K}, Acc) ->
CertFile = generate_filename(CertsDir, K),
ok = save_cert_to_file(CertFile, maps:get(K, NewSSLOpts)),
Acc#{K => CertFile}
end, NewSSLOpts, Diff),
NewConfig#{<<"ssl">> => NSSLOpts}
OldSSL = maps:get(<<"ssl">>, OldConfig, undefined),
NewSSL = maps:get(<<"ssl">>, NewConfig, undefined),
case emqx_tls_lib:ensure_ssl_files(CertsDir, NewSSL) of
{ok, NewSSL1} ->
ok = emqx_tls_lib:delete_ssl_files(CertsDir, NewSSL1, OldSSL),
new_ssl_config(NewConfig, NewSSL1);
{error, Reason} ->
?SLOG(error, Reason#{msg => bad_ssl_config}),
throw(bad_ssl_config)
end.
new_ssl_config(Config, undefined) -> Config;
new_ssl_config(Config, SSL) -> Config#{<<"ssl">> => SSL}.
clear_certs(CertsDir, Config) ->
case maps:get(<<"ssl">>, Config, undefined) of
undefined ->
ok;
SSLOpts ->
lists:foreach(
fun({_, Filename}) ->
_ = file:delete(filename:join([CertsDir, Filename]))
end,
maps:to_list(maps:with([<<"certfile">>, <<"keyfile">>, <<"cacertfile">>], SSLOpts)))
end.
save_cert_to_file(Filename, PemBin) ->
case public_key:pem_decode(PemBin) =/= [] of
true ->
case filelib:ensure_dir(Filename) of
ok ->
case file:write_file(Filename, PemBin) of
ok -> ok;
{error, Reason} -> error({save_cert_to_file, {write_file, Reason}})
end;
{error, Reason} ->
error({save_cert_to_file, {ensure_dir, Reason}})
end;
false ->
error({save_cert_to_file, invalid_certificate})
end.
generate_filename(CertsDir, Key) ->
Prefix = case Key of
<<"keyfile">> -> "key-";
<<"certfile">> -> "cert-";
<<"cacertfile">> -> "cacert-"
end,
to_bin(filename:join([CertsDir, Prefix ++ emqx_misc:gen_id() ++ ".pem"])).
diff_certs(NewSSLOpts, OldSSLOpts) ->
Keys = [<<"cacertfile">>, <<"certfile">>, <<"keyfile">>],
CertPems = maps:with(Keys, NewSSLOpts),
CertFiles = maps:with(Keys, OldSSLOpts),
Diff = lists:foldl(fun({K, CertFile}, Acc) ->
case maps:find(K, CertPems) of
error -> Acc;
{ok, PemBin1} ->
{ok, PemBin2} = file:read_file(CertFile),
case diff_cert(PemBin1, PemBin2) of
true ->
[{changed, K} | Acc];
false ->
[{identical, K} | Acc]
end
end
end,
[], maps:to_list(CertFiles)),
Added = [{added, K} || K <- maps:keys(maps:without(maps:keys(CertFiles), CertPems))],
Diff ++ Added.
diff_cert(Pem1, Pem2) ->
cal_md5_for_cert(Pem1) =/= cal_md5_for_cert(Pem2).
cal_md5_for_cert(Pem) ->
crypto:hash(md5, term_to_binary(public_key:pem_decode(Pem))).
OldSSL = maps:get(<<"ssl">>, Config, undefined),
ok = emqx_tls_lib:delete_ssl_files(CertsDir, undefined, OldSSL).
split_by_id(ID, AuthenticatorsConfig) ->
case lists:foldl(
@ -342,3 +268,9 @@ authn_type(#{<<"mechanism">> := M}) -> atom(M).
atom(Bin) ->
binary_to_existing_atom(Bin, utf8).
%% The relative dir for ssl files.
certs_dir(ChainName, ID) when is_binary(ID) ->
filename:join([to_bin(ChainName), ID]);
certs_dir(ChainName, Config) when is_map(Config) ->
certs_dir(ChainName, authenticator_id(Config)).

View File

@ -277,7 +277,7 @@ init_load(SchemaMod, RawConf0) when is_map(RawConf0) ->
maps:with(get_root_names(), RawConf0)).
include_dirs() ->
[filename:join(application:get_env(emqx, data_dir, "data/"), "configs") ++ "/"].
[filename:join(emqx:data_dir(), "configs")].
-spec check_config(module(), raw_config()) -> {AppEnvs, CheckedConf}
when AppEnvs :: app_envs(), CheckedConf :: config().

View File

@ -16,6 +16,7 @@
-module(emqx_tls_lib).
%% version & cipher suites
-export([ default_versions/0
, integral_versions/1
, default_ciphers/0
@ -25,6 +26,15 @@
, all_ciphers/0
]).
%% files
-export([ ensure_ssl_files/2
, delete_ssl_files/3
]).
-include("logger.hrl").
-define(SSL_FILE_OPT_NAMES, [<<"keyfile">>, <<"certfile">>, <<"cacertfile">>]).
%% non-empty string
-define(IS_STRING(L), (is_list(L) andalso L =/= [] andalso is_integer(hd(L)))).
%% non-empty list of strings
@ -212,6 +222,129 @@ drop_tls13(SslOpts0) ->
SslOpts1#{ciphers => Ciphers -- ?TLSV13_EXCLUSIVE_CIPHERS}
end.
%% @doc The input map is a HOCON decoded result of a struct defined as
%% emqx_schema:server_ssl_opts_schema. (NOTE: before schema-checked).
%% `keyfile', `certfile' and `cacertfile' can be either pem format key or certificates,
%% or file path.
%% When PEM format key or certificate is given, it tries to to save them in the given
%% sub-dir in emqx's data_dir, and replace saved file paths for SSL options.
-spec ensure_ssl_files(file:name_all(), undefined | map()) ->
{ok, undefined | map()} | {error, map()}.
ensure_ssl_files(Dir, Opts) ->
ensure_ssl_files(Dir, Opts, _DryRun = false).
ensure_ssl_files(_Dir, undefined, _DryRun) -> {ok, undefined};
ensure_ssl_files(_Dir, #{<<"enable">> := false} = Opts, _DryRun) -> {ok, Opts};
ensure_ssl_files(Dir, Opts, DryRun) ->
ensure_ssl_files(Dir, Opts, ?SSL_FILE_OPT_NAMES, DryRun).
ensure_ssl_files(_Dir,Opts, [], _DryRun) -> {ok, Opts};
ensure_ssl_files(Dir, Opts, [Key | Keys], DryRun) ->
case ensure_ssl_file(Dir, Key, Opts, maps:get(Key, Opts, undefined), DryRun) of
{ok, NewOpts} ->
ensure_ssl_files(Dir, NewOpts, Keys, DryRun);
{error, Reason} ->
{error, Reason#{which_option => Key}}
end.
%% @doc Compare old and new config, delete the ones in old but not in new.
-spec delete_ssl_files(file:name_all(), undefiend | map(), undefined | map()) -> ok.
delete_ssl_files(Dir, NewOpts0, OldOpts0) ->
DryRun = true,
{ok, NewOpts} = ensure_ssl_files(Dir, NewOpts0, DryRun),
{ok, OldOpts} = ensure_ssl_files(Dir, OldOpts0, DryRun),
Get = fun(_K, undefined) -> undefined;
(K, Opts) -> maps:get(K, Opts, undefined)
end,
lists:foreach(fun(Key) -> delete_old_file(Get(Key, NewOpts), Get(Key, OldOpts)) end,
?SSL_FILE_OPT_NAMES).
delete_old_file(New, Old) when New =:= Old -> ok;
delete_old_file(_New, _Old = undefined) -> ok;
delete_old_file(_New, Old) ->
case filelib:is_regular(Old) andalso file:delete(Old) of
ok -> ok;
false -> ok; %% already deleted
{error, Reason} ->
?SLOG(error, #{msg => "failed_to_delete_ssl_file", file_path => Old, reason => Reason})
end.
ensure_ssl_file(_Dir, _Key, Opts, undefined, _DryRun) ->
{ok, Opts};
ensure_ssl_file(Dir, Key, Opts, MaybePem, DryRun) ->
case is_valid_string(MaybePem) of
true ->
do_ensure_ssl_file(Dir, Key, Opts, MaybePem, DryRun);
false ->
{error, #{reason => invalid_file_path_or_pem_string}}
end.
do_ensure_ssl_file(Dir, Key, Opts, MaybePem, DryRun) ->
case is_pem(MaybePem) of
true ->
case save_pem_file(Dir, Key, MaybePem, DryRun) of
{ok, Path} -> {ok, Opts#{Key => Path}};
{error, Reason} -> {error, Reason}
end;
false ->
case is_valid_pem_file(MaybePem) of
true -> {ok, Opts};
{error, enoent} when DryRun -> {ok, Opts};
{error, Reason} ->
{error, #{file_path => MaybePem,
reason => Reason
}}
end
end.
is_valid_string(String) when is_list(String) ->
io_lib:printable_unicode_list(String);
is_valid_string(Binary) when is_binary(Binary) ->
case unicode:characters_to_list(Binary, utf8) of
String when is_list(String) -> is_valid_string(String);
_Otherwise -> false
end.
%% Check if it is a valid PEM formated key.
is_pem(MaybePem) ->
try public_key:pem_decode(MaybePem) =/= []
catch _ : _ -> false
end.
%% Write the pem file to the given dir.
%% To make it simple, the file is always overwritten.
%% Also a potentially half-written PEM file (e.g. due to power outage)
%% can be corrected with an overwrite.
save_pem_file(Dir, Key, Pem, DryRun) ->
Path = pem_file_name(Dir, Key, Pem),
case filelib:ensure_dir(Path) of
ok when DryRun ->
{ok, Path};
ok ->
case file:write_file(Path, Pem) of
ok -> {ok, Path};
{error, Reason} ->
{error, #{failed_to_write_file => Reason, file_path => Path}}
end;
{error, Reason} ->
{error, #{failed_to_create_dir_for => Path, reason => Reason}}
end.
%% compute the filename for a PEM format key/certificate
%% the filename is prefixed by the option name without the 'file' part
%% and suffixed with the first 8 byets of base64 encode result of the PEM content's
%% md5 checksum. e.g. key-EKjjO9um, cert-TwuCW1vh, and cacert-6ZaWqNuC
pem_file_name(Dir, Key, Pem) ->
<<CK:8/binary, _/binary>> = base64:encode(crypto:hash(md5, Pem)),
FileName = binary:replace(Key, <<"file">>, <<"-", CK/binary>>),
filename:join([emqx:data_dir(), Dir, FileName]).
is_valid_pem_file(Path) ->
case file:read_file(Path) of
{ok, Pem} -> is_pem(Pem) orelse {error, not_pem};
{error, Reason} -> {error, Reason}
end.
-if(?OTP_RELEASE > 22).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").

View File

@ -97,17 +97,10 @@ end_per_suite(_) ->
ok.
init_per_testcase(Case, Config) ->
meck:new(emqx, [non_strict, passthrough, no_history, no_link]),
meck:expect(emqx, get_config, fun([node, data_dir]) ->
{data_dir, Data} = lists:keyfind(data_dir, 1, Config),
Data;
(C) -> meck:passthrough([C])
end),
?MODULE:Case({'init', Config}).
end_per_testcase(Case, Config) ->
_ = ?MODULE:Case({'end', Config}),
meck:unload(emqx),
ok.
t_chain({_, Config}) -> Config;
@ -119,7 +112,7 @@ t_chain(Config) when is_list(Config) ->
?assertEqual({error, {already_exists, {chain, ChainName}}}, ?AUTHN:create_chain(ChainName)),
?assertMatch({ok, #{name := ChainName, authenticators := []}}, ?AUTHN:lookup_chain(ChainName)),
?assertMatch({ok, [#{name := ChainName}]}, ?AUTHN:list_chains()),
?assertEqual(ok, ?AUTHN:delete_chain(ChainName)),
?assertEqual(ok, ?AUTHN:delete_chain(ChainName)),
?assertMatch({error, {not_found, {chain, ChainName}}}, ?AUTHN:lookup_chain(ChainName)),
ok.
@ -273,13 +266,11 @@ t_convert_certs(Config) when is_list(Config) ->
CertsDir = certs_dir(Config, [Global, <<"password-based:built-in-database">>]),
#{<<"ssl">> := NCerts} = convert_certs(CertsDir, #{<<"ssl">> => Certs}),
?assertEqual(false, diff_cert(maps:get(<<"keyfile">>, NCerts), maps:get(<<"keyfile">>, Certs))),
Certs2 = certs([ {<<"keyfile">>, "key.pem"}
, {<<"certfile">>, "cert.pem"}
]),
#{<<"ssl">> := NCerts2} = convert_certs(CertsDir, #{<<"ssl">> => Certs2}, #{<<"ssl">> => NCerts}),
?assertEqual(false, diff_cert(maps:get(<<"keyfile">>, NCerts2), maps:get(<<"keyfile">>, Certs2))),
?assertEqual(maps:get(<<"keyfile">>, NCerts), maps:get(<<"keyfile">>, NCerts2)),
?assertEqual(maps:get(<<"certfile">>, NCerts), maps:get(<<"certfile">>, NCerts2)),
@ -288,7 +279,6 @@ t_convert_certs(Config) when is_list(Config) ->
, {<<"cacertfile">>, "cacert.pem"}
]),
#{<<"ssl">> := NCerts3} = convert_certs(CertsDir, #{<<"ssl">> => Certs3}, #{<<"ssl">> => NCerts2}),
?assertEqual(false, diff_cert(maps:get(<<"keyfile">>, NCerts3), maps:get(<<"keyfile">>, Certs3))),
?assertNotEqual(maps:get(<<"keyfile">>, NCerts2), maps:get(<<"keyfile">>, NCerts3)),
?assertNotEqual(maps:get(<<"certfile">>, NCerts2), maps:get(<<"certfile">>, NCerts3)),
@ -306,10 +296,6 @@ certs(Certs) ->
Acc#{Key => Bin}
end, #{}, Certs).
diff_cert(CertFile, CertPem2) ->
{ok, CertPem1} = file:read_file(CertFile),
emqx_authentication_config:diff_cert(CertPem1, CertPem2).
register_provider(Type, Module) ->
ok = ?AUTHN:register_providers([{Type, Module}]).

View File

@ -38,7 +38,7 @@ use_default_ciphers_test() ->
ciphers_format_test_() ->
String = ?TLS_13_CIPHER ++ "," ++ ?TLS_12_CIPHER,
Binary = iolist_to_binary(String),
Binary = bin(String),
List = [?TLS_13_CIPHER, ?TLS_12_CIPHER],
[ {"string", fun() -> test_cipher_format(String) end}
, {"binary", fun() -> test_cipher_format(Binary) end}
@ -66,3 +66,93 @@ cipher_suites_no_duplication_test() ->
AllCiphers = emqx_tls_lib:default_ciphers(),
?assertEqual(length(AllCiphers), length(lists:usort(AllCiphers))).
ssl_files_failure_test_() ->
[{"undefined_is_undefined",
fun() ->
?assertEqual({ok, undefined},
emqx_tls_lib:ensure_ssl_files("dir", undefined)) end},
{"no_op_if_disabled",
fun() ->
Disabled = #{<<"enable">> => false, foo => bar},
?assertEqual({ok, Disabled},
emqx_tls_lib:ensure_ssl_files("dir", Disabled)) end},
{"enoent_key_file",
fun() ->
NonExistingFile = filename:join("/tmp", integer_to_list(erlang:system_time(microsecond))),
?assertMatch({error, #{reason := enoent}},
emqx_tls_lib:ensure_ssl_files("/tmp", #{<<"keyfile">> => NonExistingFile}))
end},
{"bad_pem_string",
fun() ->
%% not valid unicode
?assertMatch({error, #{reason := invalid_file_path_or_pem_string, which_option := <<"keyfile">>}},
emqx_tls_lib:ensure_ssl_files("/tmp", #{<<"keyfile">> => <<255, 255>>})),
%% not printable
?assertMatch({error, #{reason := invalid_file_path_or_pem_string}},
emqx_tls_lib:ensure_ssl_files("/tmp", #{<<"keyfile">> => <<33, 22>>})),
TmpFile = filename:join("/tmp", integer_to_list(erlang:system_time(microsecond))),
try
ok = file:write_file(TmpFile, <<"not a valid pem">>),
?assertMatch({error, #{file_path := _, reason := not_pem}},
emqx_tls_lib:ensure_ssl_files("/tmp", #{<<"cacertfile">> => bin(TmpFile)}))
after
file:delete(TmpFile)
end
end}
].
ssl_files_save_delete_test() ->
SSL0 = #{<<"keyfile">> => bin(test_key())},
Dir = filename:join(["/tmp", "ssl-test-dir"]),
{ok, SSL} = emqx_tls_lib:ensure_ssl_files(Dir, SSL0),
File = maps:get(<<"keyfile">>, SSL),
?assertMatch(<<"/tmp/ssl-test-dir/key-", _:8/binary>>, File),
?assertEqual({ok, bin(test_key())}, file:read_file(File)),
%% no old file to delete
ok = emqx_tls_lib:delete_ssl_files(Dir, SSL, undefined),
?assertEqual({ok, bin(test_key())}, file:read_file(File)),
%% old and new identical, no delete
ok = emqx_tls_lib:delete_ssl_files(Dir, SSL, SSL),
?assertEqual({ok, bin(test_key())}, file:read_file(File)),
%% new is gone, delete old
ok = emqx_tls_lib:delete_ssl_files(Dir, undefined, SSL),
?assertEqual({error, enoent}, file:read_file(File)),
%% test idempotence
ok = emqx_tls_lib:delete_ssl_files(Dir, undefined, SSL),
ok.
ssl_file_replace_test() ->
SSL0 = #{<<"keyfile">> => bin(test_key())},
SSL1 = #{<<"keyfile">> => bin(test_key2())},
Dir = filename:join(["/tmp", "ssl-test-dir2"]),
{ok, SSL2} = emqx_tls_lib:ensure_ssl_files(Dir, SSL0),
{ok, SSL3} = emqx_tls_lib:ensure_ssl_files(Dir, SSL1),
File1 = maps:get(<<"keyfile">>, SSL2),
File2 = maps:get(<<"keyfile">>, SSL3),
?assert(filelib:is_regular(File1)),
?assert(filelib:is_regular(File2)),
%% delete old file (File1, in SSL2)
ok = emqx_tls_lib:delete_ssl_files(Dir, SSL3, SSL2),
?assertNot(filelib:is_regular(File1)),
?assert(filelib:is_regular(File2)),
ok.
bin(X) -> iolist_to_binary(X).
test_key() ->
"""
-----BEGIN EC PRIVATE KEY-----
MHQCAQEEICKTbbathzvD8zvgjL7qRHhW4alS0+j0Loo7WeYX9AxaoAcGBSuBBAAK
oUQDQgAEJBdF7MIdam5T4YF3JkEyaPKdG64TVWCHwr/plC0QzNVJ67efXwxlVGTo
ju0VBj6tOX1y6C0U+85VOM0UU5xqvw==
-----END EC PRIVATE KEY-----
""".
test_key2() ->
"""
-----BEGIN EC PRIVATE KEY-----
MHQCAQEEID9UlIyAlLFw0irkRHX29N+ZGivGtDjlVJvATY3B0TTmoAcGBSuBBAAK
oUQDQgAEUwiarudRNAT25X11js8gE9G+q0GdsT53QJQjRtBO+rTwuCW1vhLzN0Ve
AbToUD4JmV9m/XwcSVH06ZaWqNuC5w==
-----END EC PRIVATE KEY-----
""".