emqx/apps/emqx_retainer/src/emqx_retainer.erl

297 lines
10 KiB
Erlang

%%--------------------------------------------------------------------
%% Copyright (c) 2020 EMQ Technologies Co., Ltd. All Rights Reserved.
%%
%% Licensed under the Apache License, Version 2.0 (the "License");
%% you may not use this file except in compliance with the License.
%% You may obtain a copy of the License at
%%
%% http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing, software
%% distributed under the License is distributed on an "AS IS" BASIS,
%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
%% See the License for the specific language governing permissions and
%% limitations under the License.
%%--------------------------------------------------------------------
-module(emqx_retainer).
-behaviour(gen_server).
-include("emqx_retainer.hrl").
-include_lib("emqx/include/emqx.hrl").
-include_lib("emqx/include/logger.hrl").
-include_lib("stdlib/include/ms_transform.hrl").
-logger_header("[Retainer]").
-export([start_link/1]).
-export([ load/1
, unload/0
]).
-export([ on_session_subscribed/3
, on_message_publish/2
]).
-export([clean/1]).
%% for emqx_pool task func
-export([dispatch/2]).
%% gen_server callbacks
-export([ init/1
, handle_call/3
, handle_cast/2
, handle_info/2
, terminate/2
, code_change/3
]).
%% gen_server state:
%%   stats_fun    - fun obtained from emqx_stats:statsfun/2 in init/1; called
%%                  with the current retained-message count on every `stats' tick
%%   stats_timer  - reference of the 1-second `stats' interval timer
%%   expiry_timer - reference of the `expire' interval timer; stays `undefined'
%%                  when start_expire_timer/2 does not start one
-record(state, {stats_fun, stats_timer, expiry_timer}).
%%--------------------------------------------------------------------
%% Load/Unload
%%--------------------------------------------------------------------
%% @doc Register the retainer's hook callbacks.
%% `Env' is passed as the extra argument to the 'message.publish' callback;
%% the 'session.subscribed' callback gets no extra arguments.
%% Returns whatever the last emqx:hook/3 call returns.
load(Env) ->
    SubscribedCb = fun ?MODULE:on_session_subscribed/3,
    PublishCb = fun ?MODULE:on_message_publish/2,
    emqx:hook('session.subscribed', SubscribedCb, []),
    emqx:hook('message.publish', PublishCb, [Env]).
%% @doc Remove the hook callbacks installed by load/1
%% ('message.publish' first, then 'session.subscribed').
%% Returns whatever the last emqx:unhook/2 call returns.
unload() ->
    PublishCb = fun ?MODULE:on_message_publish/2,
    SubscribedCb = fun ?MODULE:on_session_subscribed/3,
    emqx:unhook('message.publish', PublishCb),
    emqx:unhook('session.subscribed', SubscribedCb).
%% @doc 'session.subscribed' hook callback.
%%
%% Subscriptions whose options carry a `share' value other than `undefined'
%% never receive retained messages. Otherwise retained messages are dispatched
%% asynchronously (through emqx_pool) to the calling process when `rh' is 0,
%% or when `rh' is 1 and the subscription is new (`is_new' is `true').
%% NOTE(review): `rh' is presumably the MQTT 5 Retain Handling option - confirm.
on_session_subscribed(_ClientInfo, _Topic, #{share := Group}) when Group =/= undefined ->
    ok;
on_session_subscribed(_ClientInfo, Topic, #{rh := Rh, is_new := IsNew})
  when Rh =:= 0; Rh =:= 1, IsNew =:= true ->
    emqx_pool:async_submit(fun ?MODULE:dispatch/2, [self(), Topic]);
on_session_subscribed(_ClientInfo, _Topic, #{rh := _, is_new := _}) ->
    %% Any other rh/is_new combination: nothing to deliver.
    ok.
%% @private
%% @doc Pool task: look up the retained messages for `Topic' (exact read for a
%% plain topic, pattern match for a wildcard filter) and send each one to `Pid'
%% as `{deliver, Topic, Msg}', in ascending timestamp order.
dispatch(Pid, Topic) ->
    Retained =
        case emqx_topic:wildcard(Topic) of
            true -> match_messages(Topic);
            false -> read_messages(Topic)
        end,
    Deliver = fun(Msg) -> Pid ! {deliver, Topic, Msg} end,
    lists:map(Deliver, sort_retained(Retained)).
%% @doc 'message.publish' hook callback.
%%
%% * retain flag set and empty payload: the retained entry for the topic is
%%   deleted;
%% * retain flag set, any other payload: a copy of the message with the
%%   `retained' header set to `true' is handed to store_retained/2;
%% * anything else: untouched.
%%
%% In every case the original message is returned as `{ok, Msg}'.
on_message_publish(#message{flags = #{retain := true},
                            topic = Topic,
                            payload = <<>>} = Msg, _Env) ->
    Key = topic2tokens(Topic),
    mnesia:dirty_delete(?TAB, Key),
    {ok, Msg};
on_message_publish(#message{flags = #{retain := true}} = Msg, Env) ->
    Marked = emqx_message:set_header(retained, true, Msg),
    store_retained(Marked, Env),
    {ok, Msg};
on_message_publish(Msg, _Env) ->
    {ok, Msg}.
%%--------------------------------------------------------------------
%% APIs
%%--------------------------------------------------------------------
%% @doc Start the retainer server, registered locally under the module name.
%% `Env' is forwarded to init/1 wrapped in a single-element list.
-spec(start_link(Env :: list()) -> emqx_types:startlink_ret()).
start_link(Env) ->
    ServerName = {local, ?MODULE},
    InitArgs = [Env],
    gen_server:start_link(ServerName, ?MODULE, InitArgs, []).
%% @doc Delete retained messages and return how many were removed.
%% A wildcard filter removes every matching entry; a plain topic removes at
%% most one entry, inside a mnesia transaction.
-spec(clean(emqx_types:topic()) -> non_neg_integer()).
clean(Topic) when is_binary(Topic) ->
    case emqx_topic:wildcard(Topic) of
        false ->
            Key = topic2tokens(Topic),
            DeleteIfPresent =
                fun() ->
                    case mnesia:read({?TAB, Key}) of
                        [_Retained] ->
                            mnesia:delete({?TAB, Key}),
                            1;
                        [] ->
                            0
                    end
                end,
            {atomic, Deleted} = mnesia:transaction(DeleteIfPresent),
            Deleted;
        true ->
            match_delete_messages(Topic)
    end.
%%--------------------------------------------------------------------
%% gen_server callbacks
%%--------------------------------------------------------------------
%% @doc gen_server init callback.
%%
%% Creates (or joins) the retained-message table with the copy type selected
%% by the `storage_type' option (`ram' | `disc' | `disc_only', default `disc'),
%% converts an existing local copy if its storage type differs, then starts
%% the 1-second `stats' timer and, when configured, the `expire' timer.
init([Env]) ->
    StorageType = proplists:get_value(storage_type, Env, disc),
    CopyType =
        case StorageType of
            disc_only -> disc_only_copies;
            disc -> disc_copies;
            ram -> ram_copies
        end,
    EtsOpts = [compressed, {read_concurrency, true}, {write_concurrency, true}],
    DetsOpts = [{auto_save, 1000}],
    TabDef = [{type, set},
              {CopyType, [node()]},
              {record_name, retained},
              {attributes, record_info(fields, retained)},
              {storage_properties, [{ets, EtsOpts}, {dets, DetsOpts}]}],
    ok = ekka_mnesia:create_table(?TAB, TabDef),
    ok = ekka_mnesia:copy_table(?TAB),
    %% The table may already exist with another storage type (e.g. after a
    %% configuration change); align the local copy with the requested one.
    case mnesia:table_info(?TAB, storage_type) =:= CopyType of
        true ->
            ok;
        false ->
            {atomic, ok} = mnesia:change_table_copy_type(?TAB, node(), CopyType)
    end,
    StatsFun = emqx_stats:statsfun('retained.count', 'retained.max'),
    {ok, StatsTimer} = timer:send_interval(timer:seconds(1), stats),
    ExpiryInterval = proplists:get_value(expiry_interval, Env, 0),
    State0 = #state{stats_fun = StatsFun, stats_timer = StatsTimer},
    {ok, start_expire_timer(ExpiryInterval, State0)}.
%% @private
%% @doc Start the periodic `expire' timer and record it in the state.
%% An interval of `0' or `undefined' means no timer is started and the state
%% is returned unchanged; any other value is passed to timer:send_interval/2
%% as the period in milliseconds.
start_expire_timer(Interval, State) when Interval =:= 0; Interval =:= undefined ->
    State;
start_expire_timer(Interval, State) ->
    {ok, TRef} = timer:send_interval(Interval, expire),
    State#state{expiry_timer = TRef}.
%% @doc No synchronous requests are supported: log the request and reply
%% `ignored'.
handle_call(Request, _From, State) ->
    ?LOG(error, "Unexpected call: ~p", [Request]),
    {reply, ignored, State}.
%% @doc No casts are supported: log the message and carry on.
handle_cast(Cast, State) ->
    ?LOG(error, "Unexpected cast: ~p", [Cast]),
    {noreply, State}.
%% @doc Timer-driven housekeeping.
%%   `stats'  - report the current number of retained messages;
%%   `expire' - delete retained messages whose expiry time has passed.
%% Both hibernate the process afterwards; any other message is logged.
handle_info(stats, #state{stats_fun = Report} = State) ->
    Count = retained_count(),
    Report(Count),
    {noreply, State, hibernate};
handle_info(expire, State) ->
    expire_messages(),
    {noreply, State, hibernate};
handle_info(Unexpected, State) ->
    ?LOG(error, "Unexpected info: ~p", [Unexpected]),
    {noreply, State}.
%% @doc Cancel both interval timers on shutdown. The results of the
%% cancellations are not checked (the expiry timer may never have been started).
terminate(_Reason, #state{stats_timer = StatsTimer, expiry_timer = ExpiryTimer}) ->
    timer:cancel(StatsTimer),
    timer:cancel(ExpiryTimer).
%% @doc Hot code upgrade callback: the state is carried over unchanged.
code_change(_FromVsn, CurrentState, _Extra) ->
    {ok, CurrentState}.
%%--------------------------------------------------------------------
%% Internal functions
%%--------------------------------------------------------------------
%% @private
%% @doc Order messages by ascending timestamp (oldest first). The sort is
%% stable, and empty or single-element lists come back as they are because the
%% comparison fun is never invoked for them.
sort_retained(Msgs) ->
    OlderFirst =
        fun(#message{timestamp = TsA}, #message{timestamp = TsB}) ->
            TsA =< TsB
        end,
    lists:sort(OlderFirst, Msgs).
%% @private
%% @doc Persist a retained message unless a configured limit rejects it.
%%
%% Outcome by `{TableFull, PayloadTooBig}':
%%   {false, false} - the message is written (insert or overwrite);
%%   {true,  false} - only a topic that is already retained may be
%%                    overwritten, since that does not grow the table;
%%   {true,  true}  - rejected and logged as "table is full";
%%   {false, true}  - rejected and logged as "payload is too big".
store_retained(Msg = #message{topic = Topic, payload = Payload}, Env) ->
    %% Records are keyed by the tokenized topic, so every table access must
    %% use Tokens; looking up the raw topic binary never finds anything.
    Tokens = topic2tokens(Topic),
    %% iolist_size/1 (as already used in the log line below) accepts both
    %% binaries and iolists, whereas size/1 raises badarg on a list.
    PayloadSize = iolist_size(Payload),
    case {is_table_full(Env), is_too_big(PayloadSize, Env)} of
        {false, false} ->
            ok = emqx_metrics:set('messages.retained', retained_count()),
            write_retained(Tokens, Msg, Env);
        {true, false} ->
            case mnesia:dirty_read(?TAB, Tokens) of
                [_] ->
                    write_retained(Tokens, Msg, Env);
                [] ->
                    ?LOG(error, "Cannot retain message(topic=~s) for table is full!", [Topic])
            end;
        {true, _} ->
            ?LOG(error, "Cannot retain message(topic=~s) for table is full!", [Topic]);
        {_, true} ->
            ?LOG(error, "Cannot retain message(topic=~s, payload_size=~p) "
                        "for payload is too big!", [Topic, PayloadSize])
    end.

%% @private
%% @doc Write (insert or overwrite) the retained record stored under `Tokens'.
write_retained(Tokens, Msg, Env) ->
    mnesia:dirty_write(?TAB, #retained{topic = Tokens,
                                       msg = Msg,
                                       expiry_time = get_expiry_time(Msg, Env)}).
%% @private
%% @doc Tell whether the retained table has reached `max_retained_messages'.
%% A limit of 0 (the default) disables the check. The comparison is `>=' so
%% the table is reported full as soon as it holds `Limit' entries; with a
%% strict `>' one extra message could be stored beyond the limit.
is_table_full(Env) ->
    Limit = proplists:get_value(max_retained_messages, Env, 0),
    Limit > 0 andalso (retained_count() >= Limit).
%% @private
%% @doc Tell whether a payload of `Size' exceeds `max_payload_size'.
%% A limit of 0 (the default) disables the check.
is_too_big(Size, Env) ->
    case proplists:get_value(max_payload_size, Env, 0) of
        Limit when Limit > 0 -> Size > Limit;
        _Disabled -> false
    end.
%% @private
%% @doc Compute the absolute expiry time stored with a retained message.
%% `0' means "never expires" (expire_messages/0 and read_messages/1 treat it
%% that way).
%%
%% A 'Message-Expiry-Interval' property on the message takes precedence: it is
%% multiplied by 1000 and added to the message timestamp (presumably seconds
%% converted to milliseconds - confirm against the protocol layer). Otherwise
%% the `expiry_interval' option is added to the timestamp as is. An option
%% value of `undefined' is treated like `0', matching start_expire_timer/2,
%% instead of failing with badarith.
get_expiry_time(#message{headers = #{properties := #{'Message-Expiry-Interval' := 0}}}, _Env) ->
    0;
get_expiry_time(#message{headers = #{properties := #{'Message-Expiry-Interval' := Interval}},
                         timestamp = Ts}, _Env) ->
    Ts + Interval * 1000;
get_expiry_time(#message{timestamp = Ts}, Env) ->
    case proplists:get_value(expiry_interval, Env, 0) of
        0 -> 0;
        undefined -> 0;
        Interval -> Ts + Interval
    end.
%%--------------------------------------------------------------------
%% Internal funcs
%%--------------------------------------------------------------------
%% @private
%% @doc Number of records currently held in the retained table.
-spec(retained_count() -> non_neg_integer()).
retained_count() ->
    Size = mnesia:table_info(?TAB, size),
    Size.
%% @private
%% @doc Split a topic into its words; the resulting list is the key under
%% which the retained record is stored.
topic2tokens(Topic) ->
    Words = emqx_topic:words(Topic),
    Words.
%% @private
%% @doc Delete, in one transaction, every retained record whose expiry time is
%% non-zero and earlier than the current system time in milliseconds.
%% Records with an expiry time of 0 are never removed here.
expire_messages() ->
    Now = erlang:system_time(millisecond),
    Head = #retained{topic = '$1', msg = '_', expiry_time = '$3'},
    Guards = [{'=/=', '$3', 0}, {'<', '$3', Now}],
    ExpiredKeysSpec = [{Head, Guards, ['$1']}],
    DeleteKey = fun(Key) -> mnesia:delete({?TAB, Key}) end,
    Purge =
        fun() ->
            %% Select under a write lock since every hit is deleted right after.
            Expired = mnesia:select(?TAB, ExpiredKeysSpec, write),
            lists:foreach(DeleteKey, Expired)
        end,
    mnesia:transaction(Purge).
%% @private
%% @doc Fetch the retained message stored for an exact (non-wildcard) topic.
%% Returns a list with the message, or `[]' when nothing is stored or the
%% stored message has expired. An expiry time of 0 never expires.
-spec(read_messages(emqx_types:topic())
      -> [emqx_types:message()]).
read_messages(Topic) ->
    case mnesia:dirty_read(?TAB, topic2tokens(Topic)) of
        [#retained{msg = Msg, expiry_time = 0}] ->
            [Msg];
        [#retained{msg = Msg, expiry_time = ExpiresAt}] ->
            Now = erlang:system_time(millisecond),
            if
                ExpiresAt >= Now -> [Msg];
                true -> []
            end;
        [] ->
            []
    end.
%% @private
%% @doc Fetch all non-expired retained messages whose topic matches a wildcard
%% filter. The filter is turned into a match pattern by condition/1.
%%
%% A record qualifies when its expiry time is 0 (never expires) or is not
%% earlier than the current time. The bound is inclusive (`>=') so that it
%% agrees with read_messages/1 and with expire_messages/0, which only removes
%% records whose expiry time is strictly in the past.
-spec(match_messages(emqx_types:topic())
      -> [emqx_types:message()]).
match_messages(Filter) ->
    NowMs = erlang:system_time(millisecond),
    Cond = condition(emqx_topic:words(Filter)),
    MsHd = #retained{topic = Cond, msg = '$2', expiry_time = '$3'},
    Ms = [{MsHd, [{'=:=', '$3', 0}], ['$2']},
          {MsHd, [{'>=', '$3', NowMs}], ['$2']}],
    mnesia:dirty_select(?TAB, Ms).
%% @private
%% @doc Delete every retained record whose topic matches a wildcard filter,
%% regardless of expiry, using dirty operations. Returns the number of records
%% that were selected for deletion.
-spec(match_delete_messages(emqx_types:topic())
      -> DeletedCnt :: non_neg_integer()).
match_delete_messages(Filter) ->
    TopicPattern = condition(emqx_topic:words(Filter)),
    Head = #retained{topic = TopicPattern, msg = '_', expiry_time = '_'},
    Matched = mnesia:dirty_select(?TAB, [{Head, [], ['$_']}]),
    DeleteAndCount =
        fun(Record, Count) ->
            mnesia:dirty_delete_object(?TAB, Record),
            Count + 1
        end,
    lists:foldl(DeleteAndCount, 0, Matched).
%% @private
%% @doc Turn the words of a topic filter into a match-specification pattern
%% for the stored token list.
%%
%% Every '+' becomes '_' (any single word). When the last word is '#', it is
%% dropped and '_' becomes the *tail* of the list, deliberately producing an
%% improper list such as `[a | '_']' so that the pattern matches any number of
%% remaining words.
condition(Words) ->
    Pattern = lists:map(fun('+') -> '_';
                           (Word) -> Word
                        end, Words),
    case lists:last(Pattern) of
        '#' -> lists:delete('#', Pattern) ++ '_';
        _Other -> Pattern
    end.