Merge pull request #7716 from lafirest/fix/delayed_disable_action

fix(delayed): unify and optimize the enable/disable codes
This commit is contained in:
JianBo He 2022-04-28 09:39:41 +08:00 committed by GitHub
commit 2dded74584
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 61 additions and 35 deletions

View File

@ -143,10 +143,10 @@ store(DelayedMsg) ->
gen_server:call(?SERVER, {store, DelayedMsg}, infinity). gen_server:call(?SERVER, {store, DelayedMsg}, infinity).
enable() -> enable() ->
gen_server:call(?SERVER, enable). enable(true).
disable() -> disable() ->
gen_server:call(?SERVER, disable). enable(false).
set_max_delayed_messages(Max) -> set_max_delayed_messages(Max) ->
gen_server:call(?SERVER, {set_max_delayed_messages, Max}). gen_server:call(?SERVER, {set_max_delayed_messages, Max}).
@ -238,21 +238,7 @@ update_config(Config) ->
emqx_conf:update([delayed], Config, #{rawconf_with_defaults => true, override_to => cluster}). emqx_conf:update([delayed], Config, #{rawconf_with_defaults => true, override_to => cluster}).
post_config_update(_KeyPath, Config, _NewConf, _OldConf, _AppEnvs) -> post_config_update(_KeyPath, Config, _NewConf, _OldConf, _AppEnvs) ->
case maps:get(<<"enable">>, Config, undefined) of gen_server:call(?SERVER, {update_config, Config}).
undefined ->
ignore;
true ->
emqx_delayed:enable();
false ->
emqx_delayed:disable()
end,
case maps:get(<<"max_delayed_messages">>, Config, undefined) of
undefined ->
ignore;
Max ->
ok = emqx_delayed:set_max_delayed_messages(Max)
end,
ok.
%%-------------------------------------------------------------------- %%--------------------------------------------------------------------
%% gen_server callback %% gen_server callback
@ -262,7 +248,7 @@ init([Opts]) ->
erlang:process_flag(trap_exit, true), erlang:process_flag(trap_exit, true),
emqx_conf:add_handler([delayed], ?MODULE), emqx_conf:add_handler([delayed], ?MODULE),
MaxDelayedMessages = maps:get(max_delayed_messages, Opts, 0), MaxDelayedMessages = maps:get(max_delayed_messages, Opts, 0),
{ok, State =
ensure_stats_event( ensure_stats_event(
ensure_publish_timer(#{ ensure_publish_timer(#{
publish_timer => undefined, publish_timer => undefined,
@ -271,7 +257,8 @@ init([Opts]) ->
stats_fun => undefined, stats_fun => undefined,
max_delayed_messages => MaxDelayedMessages max_delayed_messages => MaxDelayedMessages
}) })
)}. ),
{ok, ensure_enable(emqx:get_config([delayed, enable]), State)}.
handle_call({set_max_delayed_messages, Max}, _From, State) -> handle_call({set_max_delayed_messages, Max}, _From, State) ->
{reply, ok, State#{max_delayed_messages => Max}}; {reply, ok, State#{max_delayed_messages => Max}};
@ -293,12 +280,11 @@ handle_call(
emqx_metrics:inc('messages.delayed'), emqx_metrics:inc('messages.delayed'),
{reply, ok, ensure_publish_timer(Key, State)} {reply, ok, ensure_publish_timer(Key, State)}
end; end;
handle_call(enable, _From, State) -> handle_call({update_config, Config}, _From, #{max_delayed_messages := Max} = State) ->
emqx_hooks:put('message.publish', {?MODULE, on_message_publish, []}), Max2 = maps:get(<<"max_delayed_messages">>, Config, Max),
{reply, ok, State}; State2 = State#{max_delayed_messages := Max2},
handle_call(disable, _From, State) -> State3 = ensure_enable(maps:get(<<"enable">>, Config, undefined), State2),
emqx_hooks:del('message.publish', {?MODULE, on_message_publish}), {reply, ok, State3};
{reply, ok, State};
handle_call(Req, _From, State) -> handle_call(Req, _From, State) ->
?tp(error, emqx_delayed_unexpected_call, #{call => Req}), ?tp(error, emqx_delayed_unexpected_call, #{call => Req}),
{reply, ignored, State}. {reply, ignored, State}.
@ -320,10 +306,10 @@ handle_info(Info, State) ->
?tp(error, emqx_delayed_unexpected_info, #{info => Info}), ?tp(error, emqx_delayed_unexpected_info, #{info => Info}),
{noreply, State}. {noreply, State}.
terminate(_Reason, #{publish_timer := PublishTimer, stats_timer := StatsTimer}) -> terminate(_Reason, #{stats_timer := StatsTimer} = State) ->
emqx_conf:remove_handler([delayed]), emqx_conf:remove_handler([delayed]),
emqx_misc:cancel_timer(PublishTimer), emqx_misc:cancel_timer(StatsTimer),
emqx_misc:cancel_timer(StatsTimer). ensure_enable(false, State).
code_change(_Vsn, State, _Extra) -> code_change(_Vsn, State, _Extra) ->
{ok, State}. {ok, State}.
@ -378,3 +364,23 @@ do_publish(Key = {Ts, _Id}, Now, Acc) when Ts =< Now ->
-spec delayed_count() -> non_neg_integer(). -spec delayed_count() -> non_neg_integer().
delayed_count() -> mnesia:table_info(?TAB, size). delayed_count() -> mnesia:table_info(?TAB, size).
enable(Enable) ->
case emqx:get_raw_config([delayed]) of
#{<<"enable">> := Enable} ->
ok;
Cfg ->
{ok, _} = update_config(Cfg#{<<"enable">> := Enable}),
ok
end.
ensure_enable(true, State) ->
emqx_hooks:put('message.publish', {?MODULE, on_message_publish, []}),
State;
ensure_enable(false, #{publish_timer := PubTimer} = State) ->
emqx_hooks:del('message.publish', {?MODULE, on_message_publish}),
emqx_misc:cancel_timer(PubTimer),
ets:delete_all_objects(?TAB),
State#{publish_timer := undefined, publish_at := 0};
ensure_enable(_, State) ->
State.

View File

@ -55,13 +55,26 @@ end_per_testcase(_Case, _Config) ->
%% Test cases %% Test cases
%%-------------------------------------------------------------------- %%--------------------------------------------------------------------
t_load_case(_) -> t_enable_disable_case(_) ->
emqx_delayed:disable(),
Hooks = emqx_hooks:lookup('message.publish'), Hooks = emqx_hooks:lookup('message.publish'),
MFA = {emqx_delayed, on_message_publish, []}, MFA = {emqx_delayed, on_message_publish, []},
?assertEqual(false, lists:keyfind(MFA, 2, Hooks)), ?assertEqual(false, lists:keyfind(MFA, 2, Hooks)),
ok = emqx_delayed:enable(), ok = emqx_delayed:enable(),
Hooks1 = emqx_hooks:lookup('message.publish'), Hooks1 = emqx_hooks:lookup('message.publish'),
?assertNotEqual(false, lists:keyfind(MFA, 2, Hooks1)), ?assertNotEqual(false, lists:keyfind(MFA, 2, Hooks1)),
Ts0 = integer_to_binary(erlang:system_time(second) + 10),
DelayedMsg0 = emqx_message:make(
?MODULE, 1, <<"$delayed/", Ts0/binary, "/publish">>, <<"delayed_abs">>
),
_ = on_message_publish(DelayedMsg0),
?assertMatch(#{data := Datas} when Datas =/= [], emqx_delayed:list(#{})),
emqx_delayed:disable(),
?assertEqual(false, lists:keyfind(MFA, 2, Hooks)),
?assertMatch(#{data := []}, emqx_delayed:list(#{})),
ok. ok.
t_delayed_message(_) -> t_delayed_message(_) ->
@ -76,7 +89,7 @@ t_delayed_message(_) ->
[#delayed_message{msg = #message{payload = Payload}}] = ets:tab2list(emqx_delayed), [#delayed_message{msg = #message{payload = Payload}}] = ets:tab2list(emqx_delayed),
?assertEqual(<<"delayed_m">>, Payload), ?assertEqual(<<"delayed_m">>, Payload),
ct:sleep(2000), ct:sleep(2500),
EmptyKey = mnesia:dirty_all_keys(emqx_delayed), EmptyKey = mnesia:dirty_all_keys(emqx_delayed),
?assertEqual([], EmptyKey). ?assertEqual([], EmptyKey).

View File

@ -98,6 +98,7 @@ t_status(_Config) ->
t_messages(_) -> t_messages(_) ->
clear_all_record(), clear_all_record(),
emqx_delayed:enable(),
{ok, C1} = emqtt:start_link([{clean_start, true}]), {ok, C1} = emqtt:start_link([{clean_start, true}]),
{ok, _} = emqtt:connect(C1), {ok, _} = emqtt:connect(C1),
@ -114,7 +115,7 @@ t_messages(_) ->
end, end,
lists:foreach(Each, lists:seq(1, 5)), lists:foreach(Each, lists:seq(1, 5)),
timer:sleep(500), timer:sleep(1000),
Msgs = get_messages(5), Msgs = get_messages(5),
[First | _] = Msgs, [First | _] = Msgs,
@ -197,6 +198,7 @@ t_messages(_) ->
t_large_payload(_) -> t_large_payload(_) ->
clear_all_record(), clear_all_record(),
emqx_delayed:enable(),
{ok, C1} = emqtt:start_link([{clean_start, true}]), {ok, C1} = emqtt:start_link([{clean_start, true}]),
{ok, _} = emqtt:connect(C1), {ok, _} = emqtt:connect(C1),
@ -209,7 +211,7 @@ t_large_payload(_) ->
[{qos, 0}, {retain, true}] [{qos, 0}, {retain, true}]
), ),
timer:sleep(500), timer:sleep(1000),
[#{msgid := MsgId}] = get_messages(1), [#{msgid := MsgId}] = get_messages(1),
@ -241,8 +243,13 @@ get_messages(Len) ->
{ok, 200, MsgsJson} = request(get, uri(["mqtt", "delayed", "messages"])), {ok, 200, MsgsJson} = request(get, uri(["mqtt", "delayed", "messages"])),
#{data := Msgs} = decode_json(MsgsJson), #{data := Msgs} = decode_json(MsgsJson),
MsgLen = erlang:length(Msgs), MsgLen = erlang:length(Msgs),
?assert( ?assertEqual(
MsgLen =:= Len, Len,
lists:flatten(io_lib:format("message length is:~p~n", [MsgLen])) MsgLen,
lists:flatten(
io_lib:format("message length is:~p~nWhere:~p~nHooks:~p~n", [
MsgLen, erlang:whereis(emqx_delayed), ets:tab2list(emqx_hooks)
])
)
), ),
Msgs. Msgs.