248 lines
7.6 KiB
Erlang
248 lines
7.6 KiB
Erlang
%%--------------------------------------------------------------------
|
|
%% Copyright (c) 2020-2021 EMQ Technologies Co., Ltd. All Rights Reserved.
|
|
%%
|
|
%% Licensed under the Apache License, Version 2.0 (the "License");
|
|
%% you may not use this file except in compliance with the License.
|
|
%% You may obtain a copy of the License at
|
|
%%
|
|
%% http://www.apache.org/licenses/LICENSE-2.0
|
|
%%
|
|
%% Unless required by applicable law or agreed to in writing, software
|
|
%% distributed under the License is distributed on an "AS IS" BASIS,
|
|
%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
%% See the License for the specific language governing permissions and
|
|
%% limitations under the License.
|
|
%%--------------------------------------------------------------------
|
|
|
|
-module(emqx_psk).
|
|
|
|
-behaviour(gen_server).
|
|
|
|
-include_lib("emqx/include/logger.hrl").
|
|
|
|
-export([ load/0
|
|
, unload/0
|
|
, on_psk_lookup/2
|
|
, import/1
|
|
]).
|
|
|
|
-export([ start_link/0
|
|
, stop/0
|
|
]).
|
|
|
|
%% gen_server callbacks
|
|
-export([ init/1
|
|
, handle_call/3
|
|
, handle_cast/2
|
|
, handle_info/2
|
|
, terminate/2
|
|
, code_change/3
|
|
]).
|
|
|
|
-record(psk_entry, {psk_id :: binary(),
|
|
shared_secret :: binary(),
|
|
extra :: term()
|
|
}).
|
|
|
|
-export([mnesia/1]).
|
|
|
|
-boot_mnesia({mnesia, [boot]}).
|
|
|
|
-define(TAB, ?MODULE).
|
|
-define(PSK_SHARD, emqx_psk_shard).
|
|
|
|
-define(DEFAULT_DELIMITER, <<":">>).
|
|
|
|
-define(CR, 13).
|
|
-define(LF, 10).
|
|
|
|
%%------------------------------------------------------------------------------
|
|
%% Mnesia bootstrap
|
|
%%------------------------------------------------------------------------------
|
|
|
|
%% @doc Create or replicate tables.
|
|
-spec(mnesia(boot | copy) -> ok).
|
|
mnesia(boot) ->
|
|
ok = mria:create_table(?TAB, [
|
|
{rlog_shard, ?PSK_SHARD},
|
|
{type, ordered_set},
|
|
{storage, disc_copies},
|
|
{record_name, psk_entry},
|
|
{attributes, record_info(fields, psk_entry)},
|
|
{storage_properties, [{ets, [{read_concurrency, true}]}]}]).
|
|
|
|
%%------------------------------------------------------------------------------
|
|
%% APIs
|
|
%%------------------------------------------------------------------------------
|
|
|
|
load() ->
|
|
emqx:hook('tls_handshake.psk_lookup', {?MODULE, on_psk_lookup, []}).
|
|
|
|
unload() ->
|
|
emqx:unhook('tls_handshake.psk_lookup', {?MODULE, on_psk_lookup, []}).
|
|
|
|
on_psk_lookup(PSKIdentity, _UserState) ->
|
|
case mnesia:dirty_read(?TAB, PSKIdentity) of
|
|
[#psk_entry{shared_secret = SharedSecret}] ->
|
|
{stop, {ok, SharedSecret}};
|
|
_ ->
|
|
ignore
|
|
end.
|
|
|
|
import(SrcFile) ->
|
|
call({import, SrcFile}).
|
|
|
|
-spec start_link() -> {ok, pid()} | ignore | {error, term()}.
|
|
start_link() ->
|
|
gen_server:start_link({local, ?MODULE}, ?MODULE, [], []).
|
|
|
|
-spec stop() -> ok.
|
|
stop() ->
|
|
gen_server:stop(?MODULE).
|
|
|
|
%%--------------------------------------------------------------------
|
|
%% gen_server callbacks
|
|
%%--------------------------------------------------------------------
|
|
|
|
init(_Opts) ->
|
|
_ = case get_config(enable) of
|
|
true -> load();
|
|
false -> ?SLOG(info, #{msg => "emqx_psk_disabled"})
|
|
end,
|
|
_ = case get_config(init_file) of
|
|
undefined -> ok;
|
|
InitFile -> import_psks(InitFile)
|
|
end,
|
|
{ok, #{}}.
|
|
|
|
handle_call({import, SrcFile}, _From, State) ->
|
|
{reply, import_psks(SrcFile), State};
|
|
|
|
handle_call(Req, _From, State) ->
|
|
?SLOG(info, #{msg => "unexpected_call_discarded", req => Req}),
|
|
{reply, {error, unexecpted}, State}.
|
|
|
|
handle_cast(Req, State) ->
|
|
?SLOG(info, #{msg => "unexpected_cast_discarded", req => Req}),
|
|
{noreply, State}.
|
|
|
|
handle_info(Info, State) ->
|
|
?SLOG(info, #{msg => "unexpected_info_discarded", info => Info}),
|
|
{noreply, State}.
|
|
|
|
terminate(_Reason, _State) ->
|
|
unload(),
|
|
ok.
|
|
|
|
code_change(_OldVsn, State, _Extra) ->
|
|
{ok, State}.
|
|
|
|
%%------------------------------------------------------------------------------
|
|
%% Internal functions
|
|
%%------------------------------------------------------------------------------
|
|
|
|
get_config(enable) ->
|
|
emqx_conf:get([psk_authentication, enable]);
|
|
get_config(init_file) ->
|
|
emqx_conf:get([psk_authentication, init_file], undefined);
|
|
get_config(separator) ->
|
|
emqx_conf:get([psk_authentication, separator], ?DEFAULT_DELIMITER);
|
|
get_config(chunk_size) ->
|
|
emqx_conf:get([psk_authentication, chunk_size]).
|
|
|
|
import_psks(SrcFile) ->
|
|
case file:open(SrcFile, [read, raw, binary, read_ahead]) of
|
|
{error, Reason} ->
|
|
?SLOG(error, #{msg => "failed_to_open_psk_file",
|
|
file => SrcFile,
|
|
reason => Reason}),
|
|
{error, Reason};
|
|
{ok, Io} ->
|
|
try import_psks(Io, get_config(separator), get_config(chunk_size), 0) of
|
|
ok -> ok;
|
|
{error, Reason} ->
|
|
?SLOG(error, #{msg => "failed_to_import_psk_file",
|
|
file => SrcFile,
|
|
reason => Reason}),
|
|
{error, Reason}
|
|
catch
|
|
Exception:Reason:Stacktrace ->
|
|
?SLOG(error, #{msg => "failed_to_import_psk_file",
|
|
file => SrcFile,
|
|
exception => Exception,
|
|
reason => Reason,
|
|
stacktrace => Stacktrace}),
|
|
{error, Reason}
|
|
after
|
|
_ = file:close(Io)
|
|
end
|
|
end.
|
|
|
|
import_psks(Io, Delimiter, ChunkSize, NChunk) ->
|
|
case get_psks(Io, Delimiter, ChunkSize) of
|
|
{ok, Entries} ->
|
|
_ = trans(fun insert_psks/1, [Entries]),
|
|
import_psks(Io, Delimiter, ChunkSize, NChunk + 1);
|
|
{eof, Entries} ->
|
|
_ = trans(fun insert_psks/1, [Entries]),
|
|
ok;
|
|
{error, {bad_format, {line, N}}} ->
|
|
{error, {bad_format, {line, NChunk * ChunkSize + N}}};
|
|
{error, Reaosn} ->
|
|
{error, Reaosn}
|
|
end.
|
|
|
|
get_psks(Io, Delimiter, Max) ->
|
|
get_psks(Io, Delimiter, Max, []).
|
|
|
|
get_psks(_Io, _Delimiter, 0, Acc) ->
|
|
{ok, Acc};
|
|
get_psks(Io, Delimiter, Remaining, Acc) ->
|
|
case file:read_line(Io) of
|
|
{ok, Line} ->
|
|
case binary:split(Line, Delimiter) of
|
|
[PSKIdentity, SharedSecret] ->
|
|
NSharedSecret = trim_crlf(SharedSecret),
|
|
get_psks(Io, Delimiter, Remaining - 1, [{PSKIdentity, NSharedSecret} | Acc]);
|
|
_ ->
|
|
{error, {bad_format, {line, length(Acc) + 1}}}
|
|
end;
|
|
eof ->
|
|
{eof, Acc};
|
|
{error, Reason} ->
|
|
{error, Reason}
|
|
end.
|
|
|
|
insert_psks(Entries) ->
|
|
lists:foreach(fun(Entry) ->
|
|
insert_psk(Entry)
|
|
end, Entries).
|
|
|
|
insert_psk({PSKIdentity, SharedSecret}) ->
|
|
mnesia:write(?TAB, #psk_entry{psk_id = PSKIdentity, shared_secret = SharedSecret}, write).
|
|
|
|
trim_crlf(Bin) ->
|
|
Size = byte_size(Bin),
|
|
case binary:at(Bin, Size - 1) of
|
|
?LF ->
|
|
case binary:at(Bin, Size - 2) of
|
|
?CR -> binary:part(Bin, 0, Size - 2);
|
|
_ -> binary:part(Bin, 0, Size - 1)
|
|
end;
|
|
_ -> Bin
|
|
end.
|
|
|
|
trans(Fun, Args) ->
|
|
case mria:transaction(?PSK_SHARD, Fun, Args) of
|
|
{atomic, Res} -> Res;
|
|
{aborted, Reason} -> {error, Reason}
|
|
end.
|
|
|
|
call(Request) ->
|
|
try
|
|
gen_server:call(?MODULE, Request, 10000)
|
|
catch
|
|
exit:{timeout, _Details} ->
|
|
{error, timeout}
|
|
end.
|