diff --git a/apps/emqx_modules/src/emqx_delayed.erl b/apps/emqx_modules/src/emqx_delayed.erl index 8adf86e5d..479786eba 100644 --- a/apps/emqx_modules/src/emqx_delayed.erl +++ b/apps/emqx_modules/src/emqx_delayed.erl @@ -143,10 +143,10 @@ store(DelayedMsg) -> gen_server:call(?SERVER, {store, DelayedMsg}, infinity). enable() -> - gen_server:call(?SERVER, enable). + enable(true). disable() -> - gen_server:call(?SERVER, disable). + enable(false). set_max_delayed_messages(Max) -> gen_server:call(?SERVER, {set_max_delayed_messages, Max}). @@ -238,21 +238,7 @@ update_config(Config) -> emqx_conf:update([delayed], Config, #{rawconf_with_defaults => true, override_to => cluster}). post_config_update(_KeyPath, Config, _NewConf, _OldConf, _AppEnvs) -> - case maps:get(<<"enable">>, Config, undefined) of - undefined -> - ignore; - true -> - emqx_delayed:enable(); - false -> - emqx_delayed:disable() - end, - case maps:get(<<"max_delayed_messages">>, Config, undefined) of - undefined -> - ignore; - Max -> - ok = emqx_delayed:set_max_delayed_messages(Max) - end, - ok. + gen_server:call(?SERVER, {update_config, Config}). %%-------------------------------------------------------------------- %% gen_server callback @@ -262,7 +248,7 @@ init([Opts]) -> erlang:process_flag(trap_exit, true), emqx_conf:add_handler([delayed], ?MODULE), MaxDelayedMessages = maps:get(max_delayed_messages, Opts, 0), - {ok, + State = ensure_stats_event( ensure_publish_timer(#{ publish_timer => undefined, @@ -271,7 +257,8 @@ init([Opts]) -> stats_fun => undefined, max_delayed_messages => MaxDelayedMessages }) - )}. + ), + {ok, ensure_enable(emqx:get_config([delayed, enable]), State)}. handle_call({set_max_delayed_messages, Max}, _From, State) -> {reply, ok, State#{max_delayed_messages => Max}}; @@ -293,12 +280,11 @@ handle_call( emqx_metrics:inc('messages.delayed'), {reply, ok, ensure_publish_timer(Key, State)} end; -handle_call(enable, _From, State) -> - emqx_hooks:put('message.publish', {?MODULE, on_message_publish, []}), - {reply, ok, State}; -handle_call(disable, _From, State) -> - emqx_hooks:del('message.publish', {?MODULE, on_message_publish}), - {reply, ok, State}; +handle_call({update_config, Config}, _From, #{max_delayed_messages := Max} = State) -> + Max2 = maps:get(<<"max_delayed_messages">>, Config, Max), + State2 = State#{max_delayed_messages := Max2}, + State3 = ensure_enable(maps:get(<<"enable">>, Config, undefined), State2), + {reply, ok, State3}; handle_call(Req, _From, State) -> ?tp(error, emqx_delayed_unexpected_call, #{call => Req}), {reply, ignored, State}. @@ -320,10 +306,10 @@ handle_info(Info, State) -> ?tp(error, emqx_delayed_unexpected_info, #{info => Info}), {noreply, State}. -terminate(_Reason, #{publish_timer := PublishTimer, stats_timer := StatsTimer}) -> +terminate(_Reason, #{stats_timer := StatsTimer} = State) -> emqx_conf:remove_handler([delayed]), - emqx_misc:cancel_timer(PublishTimer), - emqx_misc:cancel_timer(StatsTimer). + emqx_misc:cancel_timer(StatsTimer), + ensure_enable(false, State). code_change(_Vsn, State, _Extra) -> {ok, State}. @@ -378,3 +364,23 @@ do_publish(Key = {Ts, _Id}, Now, Acc) when Ts =< Now -> -spec delayed_count() -> non_neg_integer(). delayed_count() -> mnesia:table_info(?TAB, size). + +enable(Enable) -> + case emqx:get_raw_config([delayed]) of + #{<<"enable">> := Enable} -> + ok; + Cfg -> + {ok, _} = update_config(Cfg#{<<"enable">> := Enable}), + ok + end. + +ensure_enable(true, State) -> + emqx_hooks:put('message.publish', {?MODULE, on_message_publish, []}), + State; +ensure_enable(false, #{publish_timer := PubTimer} = State) -> + emqx_hooks:del('message.publish', {?MODULE, on_message_publish}), + emqx_misc:cancel_timer(PubTimer), + ets:delete_all_objects(?TAB), + State#{publish_timer := undefined, publish_at := 0}; +ensure_enable(_, State) -> + State. diff --git a/apps/emqx_modules/test/emqx_delayed_SUITE.erl b/apps/emqx_modules/test/emqx_delayed_SUITE.erl index c582a9722..2f11c9ba2 100644 --- a/apps/emqx_modules/test/emqx_delayed_SUITE.erl +++ b/apps/emqx_modules/test/emqx_delayed_SUITE.erl @@ -55,13 +55,26 @@ end_per_testcase(_Case, _Config) -> %% Test cases %%-------------------------------------------------------------------- -t_load_case(_) -> +t_enable_disable_case(_) -> + emqx_delayed:disable(), Hooks = emqx_hooks:lookup('message.publish'), MFA = {emqx_delayed, on_message_publish, []}, ?assertEqual(false, lists:keyfind(MFA, 2, Hooks)), + ok = emqx_delayed:enable(), Hooks1 = emqx_hooks:lookup('message.publish'), ?assertNotEqual(false, lists:keyfind(MFA, 2, Hooks1)), + + Ts0 = integer_to_binary(erlang:system_time(second) + 10), + DelayedMsg0 = emqx_message:make( + ?MODULE, 1, <<"$delayed/", Ts0/binary, "/publish">>, <<"delayed_abs">> + ), + _ = on_message_publish(DelayedMsg0), + ?assertMatch(#{data := Datas} when Datas =/= [], emqx_delayed:list(#{})), + + emqx_delayed:disable(), + ?assertEqual(false, lists:keyfind(MFA, 2, Hooks)), + ?assertMatch(#{data := []}, emqx_delayed:list(#{})), ok. t_delayed_message(_) -> @@ -76,7 +89,7 @@ t_delayed_message(_) -> [#delayed_message{msg = #message{payload = Payload}}] = ets:tab2list(emqx_delayed), ?assertEqual(<<"delayed_m">>, Payload), - ct:sleep(2000), + ct:sleep(2500), EmptyKey = mnesia:dirty_all_keys(emqx_delayed), ?assertEqual([], EmptyKey). diff --git a/apps/emqx_modules/test/emqx_delayed_api_SUITE.erl b/apps/emqx_modules/test/emqx_delayed_api_SUITE.erl index 41c1e10b9..590fe24e6 100644 --- a/apps/emqx_modules/test/emqx_delayed_api_SUITE.erl +++ b/apps/emqx_modules/test/emqx_delayed_api_SUITE.erl @@ -98,6 +98,7 @@ t_status(_Config) -> t_messages(_) -> clear_all_record(), + emqx_delayed:enable(), {ok, C1} = emqtt:start_link([{clean_start, true}]), {ok, _} = emqtt:connect(C1), @@ -114,7 +115,7 @@ t_messages(_) -> end, lists:foreach(Each, lists:seq(1, 5)), - timer:sleep(500), + timer:sleep(1000), Msgs = get_messages(5), [First | _] = Msgs, @@ -197,6 +198,7 @@ t_messages(_) -> t_large_payload(_) -> clear_all_record(), + emqx_delayed:enable(), {ok, C1} = emqtt:start_link([{clean_start, true}]), {ok, _} = emqtt:connect(C1), @@ -209,7 +211,7 @@ t_large_payload(_) -> [{qos, 0}, {retain, true}] ), - timer:sleep(500), + timer:sleep(1000), [#{msgid := MsgId}] = get_messages(1), @@ -241,8 +243,13 @@ get_messages(Len) -> {ok, 200, MsgsJson} = request(get, uri(["mqtt", "delayed", "messages"])), #{data := Msgs} = decode_json(MsgsJson), MsgLen = erlang:length(Msgs), - ?assert( - MsgLen =:= Len, - lists:flatten(io_lib:format("message length is:~p~n", [MsgLen])) + ?assertEqual( + Len, + MsgLen, + lists:flatten( + io_lib:format("message length is:~p~nWhere:~p~nHooks:~p~n", [ + MsgLen, erlang:whereis(emqx_delayed), ets:tab2list(emqx_hooks) + ]) + ) ), Msgs.