%% emqx/apps/emqx_psk/src/emqx_psk.erl
%%
%% 248 lines
%% 7.5 KiB
%% Erlang

%%--------------------------------------------------------------------
%% Copyright (c) 2020-2021 EMQ Technologies Co., Ltd. All Rights Reserved.
%%
%% Licensed under the Apache License, Version 2.0 (the "License");
%% you may not use this file except in compliance with the License.
%% You may obtain a copy of the License at
%%
%% http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing, software
%% distributed under the License is distributed on an "AS IS" BASIS,
%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
%% See the License for the specific language governing permissions and
%% limitations under the License.
%%--------------------------------------------------------------------
%% TLS-PSK support for EMQX: stores PSK identity -> shared secret pairs in a
%% mria/mnesia table, answers the 'tls_handshake.psk_lookup' hook from that
%% table, and bulk-imports pairs from a delimiter-separated text file.
-module(emqx_psk).
-behaviour(gen_server).
-include_lib("emqx/include/logger.hrl").
%% Public API: hook (un)registration, the hook callback and file import.
-export([ load/0
, unload/0
, on_psk_lookup/2
, import/1
]).
%% Lifecycle of the locally registered server process.
-export([ start_link/0
, stop/0
]).
%% gen_server callbacks
-export([ init/1
, handle_call/3
, handle_cast/2
, handle_info/2
, terminate/2
, code_change/3
]).
%% One row of ?TAB. psk_id is the first attribute, so it is the table key
%% (on_psk_lookup/2 reads rows with mnesia:dirty_read(?TAB, PSKIdentity)).
%% `extra' is never set by the import path in this module, so it stays
%% undefined for imported entries.
-record(psk_entry, {psk_id :: binary(),
shared_secret :: binary(),
extra :: term()
}).
%% Table bootstrap entry point.
-export([mnesia/1]).
%% NOTE(review): presumably read at node boot to invoke mnesia(boot) and
%% create ?TAB -- confirm against the mria/emqx boot code.
-boot_mnesia({mnesia, [boot]}).
%% The table is named after this module.
-define(TAB, ?MODULE).
%% mria replication shard for ?TAB; also used for the import transactions.
-define(PSK_SHARD, emqx_psk_shard).
%% Identity/secret separator used when `psk.separator' is not configured.
-define(DEFAULT_DELIMITER, <<":">>).
%% Carriage return and line feed byte values, used to strip line endings
%% from imported secrets.
-define(CR, 13).
-define(LF, 10).
%%------------------------------------------------------------------------------
%% Mnesia bootstrap
%%------------------------------------------------------------------------------
%% @doc Create the PSK table (?TAB) on the ?PSK_SHARD mria shard.
%% Only the `boot' phase exists (it is the only phase named in the
%% -boot_mnesia attribute); there is no `copy' clause, so the spec lists
%% `boot' alone.
-spec(mnesia(boot) -> ok).
mnesia(boot) ->
    ok = mria:create_table(?TAB, [
        {rlog_shard, ?PSK_SHARD},
        %% Keyed by psk_id, the first field of #psk_entry{}.
        {type, ordered_set},
        {storage, disc_copies},
        {record_name, psk_entry},
        {attributes, record_info(fields, psk_entry)},
        %% The table is read from the psk_lookup hook (on_psk_lookup/2),
        %% hence read concurrency on the underlying ets table.
        {storage_properties, [{ets, [{read_concurrency, true}]}]}]).
%%------------------------------------------------------------------------------
%% APIs
%%------------------------------------------------------------------------------
%% @doc Register on_psk_lookup/2 as a callback of the
%% 'tls_handshake.psk_lookup' hook.
load() ->
    Callback = {?MODULE, on_psk_lookup, []},
    emqx:hook('tls_handshake.psk_lookup', Callback).
%% @doc Remove the on_psk_lookup/2 callback from the
%% 'tls_handshake.psk_lookup' hook.
unload() ->
    Callback = {?MODULE, on_psk_lookup, []},
    emqx:unhook('tls_handshake.psk_lookup', Callback).
%% @doc Callback for the 'tls_handshake.psk_lookup' hook.
%% Looks PSKIdentity up in ?TAB with a dirty read. If an entry exists, it
%% returns {stop, {ok, SharedSecret}}; otherwise it returns `ignore'.
%% The second argument is not used.
on_psk_lookup(PSKIdentity, _UserState) ->
    Rows = mnesia:dirty_read(?TAB, PSKIdentity),
    case Rows of
        [#psk_entry{shared_secret = Secret}] -> {stop, {ok, Secret}};
        _NoMatch -> ignore
    end.
%% @doc Import PSK entries from SrcFile through the server process.
%% Returns ok or {error, Reason}. If the server does not reply in time,
%% the result is {error, timeout}.
-spec import(file:name_all()) -> ok | {error, term()}.
import(SrcFile) ->
    Request = {import, SrcFile},
    call(Request).
%% @doc Start the server, locally registered under ?MODULE and linked to
%% the caller.
-spec start_link() -> {ok, pid()} | ignore | {error, term()}.
start_link() ->
    ServerName = {local, ?MODULE},
    gen_server:start_link(ServerName, ?MODULE, [], []).
%% @doc Stop the locally registered server. terminate/2 runs and unloads
%% the hook.
-spec stop() -> ok.
stop() ->
    ok = gen_server:stop(?MODULE).
%%--------------------------------------------------------------------
%% gen_server callbacks
%%--------------------------------------------------------------------
%% gen_server init: register the psk_lookup hook when `psk.enable' is true,
%% then import the optional `psk.init_file'.
%%
%% Exits are trapped so that terminate/2 (which unloads the hook) also runs
%% when the supervisor shuts this server down. Without trap_exit, OTP skips
%% terminate/2 on a supervisor `shutdown' and the hook stays registered.
%%
%% The import runs synchronously, so start_link/0 returns only after the
%% init file has been processed. Import errors are logged by import_psks/1
%% and do not prevent startup.
init(_Opts) ->
    process_flag(trap_exit, true),
    _ = case get_config(enable) of
            true -> load();
            false -> ?SLOG(info, #{msg => "emqx_psk_disabled"})
        end,
    _ = case get_config(init_file) of
            undefined -> ok;
            InitFile -> import_psks(InitFile)
        end,
    {ok, #{}}.
%% {import, SrcFile} runs the file import in the server process and replies
%% with its result. Any other request is logged and answered with an error.
handle_call({import, SrcFile}, _From, State) ->
    Reply = import_psks(SrcFile),
    {reply, Reply, State};
handle_call(Request, _From, State) ->
    ?SLOG(info, #{msg => "unexpected_call_discarded", req => Request}),
    %% NOTE: the reply atom is misspelled (`unexecpted'). It is kept as-is
    %% because callers may match on it.
    {reply, {error, unexecpted}, State}.
%% No casts are part of this server's protocol. Anything cast to it is
%% logged at info level and dropped.
handle_cast(Request, St) ->
    ?SLOG(info, #{msg => "unexpected_cast_discarded", req => Request}),
    {noreply, St}.
%% Any message not sent via call/cast is logged at info level and dropped.
handle_info(Msg, St) ->
    ?SLOG(info, #{msg => "unexpected_info_discarded", info => Msg}),
    {noreply, St}.
%% Unregister the psk_lookup hook when the server stops. The return value
%% of unload/0 is deliberately ignored.
terminate(_Why, _St) ->
    _ = unload(),
    ok.
%% No state migration: the state is carried over unchanged across upgrades.
code_change(_FromVsn, St, _Extra) ->
    {ok, St}.
%%------------------------------------------------------------------------------
%% Internal functions
%%------------------------------------------------------------------------------
%% Read a value from the `psk' config root.
%% `init_file' falls back to undefined and `separator' to ?DEFAULT_DELIMITER.
%% `enable' and `chunk_size' have no fallback here. Any other key raises
%% function_clause.
get_config(Key) when Key =:= enable; Key =:= chunk_size ->
    emqx_conf:get([psk, Key]);
get_config(init_file) ->
    emqx_conf:get([psk, init_file], undefined);
get_config(separator) ->
    emqx_conf:get([psk, separator], ?DEFAULT_DELIMITER).
%% Open SrcFile and import every PSK line from it (see import_psks/4),
%% using the configured separator and chunk size.
%% Returns ok, or {error, Reason} after logging the failure. Any exception
%% raised during the import is caught, logged with its stacktrace and
%% turned into {error, Reason}. The file is always closed afterwards.
import_psks(SrcFile) ->
    case file:open(SrcFile, [read, raw, binary, read_ahead]) of
        {ok, Io} ->
            try
                import_psks(Io, get_config(separator), get_config(chunk_size), 0)
            of
                ok ->
                    ok;
                {error, Reason} = Error ->
                    ?SLOG(error, #{msg => "failed_to_import_psk_file",
                                   file => SrcFile,
                                   reason => Reason}),
                    Error
            catch
                Class:Reason:Stacktrace ->
                    ?SLOG(error, #{msg => "failed_to_import_psk_file",
                                   file => SrcFile,
                                   exception => Class,
                                   reason => Reason,
                                   stacktrace => Stacktrace}),
                    {error, Reason}
            after
                _ = file:close(Io)
            end;
        {error, Reason} = Error ->
            ?SLOG(error, #{msg => "failed_to_open_psk_file",
                           file => SrcFile,
                           reason => Reason}),
            Error
    end.
%% Import PSK lines from the open device Io, writing at most ChunkSize lines
%% per mria transaction. NChunk is the zero-based index of the chunk about
%% to be read. It is only used to convert the chunk-relative line number of
%% a bad_format error into a file-absolute one.
%%
%% Returns ok once end of file is reached and every chunk has been written.
%% Returns {error, Reason} on a read error, a malformed line, or an aborted
%% chunk transaction. Chunks committed before the failure stay committed.
%%
%% NOTE(review): with ChunkSize =:= 0, get_psks/3 returns {ok, []} without
%% reading, so this function would recurse forever. Confirm the config
%% schema rejects 0.
import_psks(Io, Delimiter, ChunkSize, NChunk) ->
    case get_psks(Io, Delimiter, ChunkSize) of
        {ok, Entries} ->
            case insert_chunk(Entries) of
                ok -> import_psks(Io, Delimiter, ChunkSize, NChunk + 1);
                {error, _} = Error -> Error
            end;
        {eof, Entries} ->
            insert_chunk(Entries);
        {error, {bad_format, {line, N}}} ->
            {error, {bad_format, {line, NChunk * ChunkSize + N}}};
        {error, Reason} ->
            {error, Reason}
    end.

%% Write one chunk of entries in a single transaction. Returns ok on commit
%% and {error, Reason} if the transaction aborted. An empty chunk (e.g. the
%% final read hitting end of file) needs no transaction.
insert_chunk([]) ->
    ok;
insert_chunk(Entries) ->
    case trans(fun insert_psks/1, [Entries]) of
        {error, Reason} -> {error, Reason};
        _Committed -> ok
    end.
%% Read up to Max lines from Io. Each line is split on the first occurrence
%% of Delimiter into {PSKIdentity, SharedSecret}, and the trailing CRLF/LF
%% is stripped from the secret with trim_crlf/1.
%%
%% Returns:
%%   {ok, Entries}  - Max lines were read (more may follow);
%%   {eof, Entries} - end of file was reached after reading Entries;
%%   {error, {bad_format, {line, N}}} - line N of this chunk (1-based)
%%                    contains no delimiter;
%%   {error, Reason} - file:read_line/1 failed.
%%
%% Entries are returned in file order. When an identity appears more than
%% once, the last occurrence therefore wins, regardless of where the chunk
%% boundaries fall.
get_psks(Io, Delimiter, Max) ->
    get_psks(Io, Delimiter, Max, []).

%% Acc holds the entries read so far, most recent first.
get_psks(_Io, _Delimiter, 0, Acc) ->
    {ok, lists:reverse(Acc)};
get_psks(Io, Delimiter, Remaining, Acc) ->
    case file:read_line(Io) of
        {ok, Line} ->
            case binary:split(Line, Delimiter) of
                [PSKIdentity, SharedSecret] ->
                    Entry = {PSKIdentity, trim_crlf(SharedSecret)},
                    get_psks(Io, Delimiter, Remaining - 1, [Entry | Acc]);
                _ ->
                    %% Acc is still in reverse order here; only its length
                    %% matters for the line number.
                    {error, {bad_format, {line, length(Acc) + 1}}}
            end;
        eof ->
            {eof, lists:reverse(Acc)};
        {error, Reason} ->
            {error, Reason}
    end.
%% Write every {PSKIdentity, SharedSecret} pair with insert_psk/1. Used as
%% the transaction fun passed to trans/2. Returns ok.
insert_psks(Entries) ->
    lists:foreach(fun insert_psk/1, Entries).
%% Write a single PSK entry into ?TAB with a write lock. The `extra' field
%% is left at its record default (undefined).
insert_psk({PSKIdentity, SharedSecret}) ->
    Record = #psk_entry{psk_id = PSKIdentity, shared_secret = SharedSecret},
    mnesia:write(?TAB, Record, write).
%% Strip one trailing line terminator (CR LF or a lone LF) from Bin.
%% Any other binary, including one ending in a lone CR, is returned
%% unchanged. The function is total: binaries shorter than the terminator,
%% such as <<>> or <<"\n">> (the secret part of a line like "id:" + LF),
%% are handled instead of raising badarg.
trim_crlf(Bin) ->
    Size = byte_size(Bin),
    if
        Size >= 2, binary_part(Bin, Size - 2, 2) =:= <<?CR, ?LF>> ->
            binary:part(Bin, 0, Size - 2);
        Size >= 1, binary_part(Bin, Size - 1, 1) =:= <<?LF>> ->
            binary:part(Bin, 0, Size - 1);
        true ->
            Bin
    end.
%% Run apply(Fun, Args) in a mria transaction on ?PSK_SHARD.
%% Returns the fun's result on commit, or {error, Reason} if the
%% transaction aborted.
trans(Fun, Args) ->
    Outcome = mria:transaction(?PSK_SHARD, Fun, Args),
    case Outcome of
        {aborted, Why} -> {error, Why};
        {atomic, Result} -> Result
    end.
%% Synchronous call to the locally registered server with a 10 s timeout.
%% A call timeout is reported as {error, timeout}. Every other exit
%% (e.g. the server not running) propagates to the caller.
call(Request) ->
    Timeout = 10000,
    try gen_server:call(?MODULE, Request, Timeout)
    catch
        exit:{timeout, _} -> {error, timeout}
    end.