diff --git a/apps/emqx/src/emqx_channel.erl b/apps/emqx/src/emqx_channel.erl index 880226031..e3d65b8ea 100644 --- a/apps/emqx/src/emqx_channel.erl +++ b/apps/emqx/src/emqx_channel.erl @@ -1338,6 +1338,17 @@ handle_timeout( NChannel = Channel#channel{session = NSession}, handle_out(publish, Publishes, reset_timer(TimerName, Timeout, NChannel)) end; +handle_timeout( + _TRef, + {emqx_session, TimerName}, + Channel = #channel{session = Session, clientinfo = ClientInfo} +) -> + case emqx_session:handle_timeout(ClientInfo, TimerName, Session) of + {ok, [], NSession} -> + {ok, Channel#channel{session = NSession}}; + {ok, Replies, NSession} -> + handle_out(publish, Replies, Channel#channel{session = NSession}) + end; handle_timeout(_TRef, expire_session, Channel) -> shutdown(expired, Channel); handle_timeout( diff --git a/apps/emqx/src/emqx_session.erl b/apps/emqx/src/emqx_session.erl index 759ecab58..8bdd47392 100644 --- a/apps/emqx/src/emqx_session.erl +++ b/apps/emqx/src/emqx_session.erl @@ -88,6 +88,13 @@ terminate/3 ]). +% Timers +-export([ + ensure_timer/3, + reset_timer/3, + cancel_timer/2 +]). + % Foreign session implementations -export([enrich_delivers/3]). @@ -103,7 +110,9 @@ conninfo/0, reply/0, replies/0, - common_timer_name/0 + common_timer_name/0, + custom_timer_name/0, + timerset/0 ]). -type session_id() :: _TODO. @@ -118,6 +127,7 @@ }. -type common_timer_name() :: retry_delivery | expire_awaiting_rel. +-type custom_timer_name() :: atom(). -type message() :: emqx_types:message(). -type publish() :: {maybe(emqx_types:packet_id()), emqx_types:message()}. @@ -144,6 +154,8 @@ emqx_session_mem:session() | emqx_persistent_session_ds:session(). +-type timerset() :: #{custom_timer_name() => _TimerRef :: reference()}. + -define(INFO_KEYS, [ id, created_at, @@ -442,14 +454,41 @@ enrich_subopts(_Opt, _V, Msg, _) -> %% Timeouts %%-------------------------------------------------------------------- --spec handle_timeout(clientinfo(), common_timer_name(), t()) -> +-spec handle_timeout(clientinfo(), common_timer_name() | custom_timer_name(), t()) -> {ok, replies(), t()} + %% NOTE: only relevant for `common_timer_name()` | {ok, replies(), timeout(), t()}. handle_timeout(ClientInfo, Timer, Session) -> ?IMPL(Session):handle_timeout(ClientInfo, Timer, Session). %%-------------------------------------------------------------------- +-spec ensure_timer(custom_timer_name(), timeout(), timerset()) -> + timerset(). +ensure_timer(Name, _Time, Timers = #{}) when is_map_key(Name, Timers) -> + Timers; +ensure_timer(Name, Time, Timers = #{}) when Time > 0 -> + TRef = emqx_utils:start_timer(Time, {?MODULE, Name}), + Timers#{Name => TRef}. + +-spec reset_timer(custom_timer_name(), timeout(), timerset()) -> + timerset(). +reset_timer(Name, Time, Channel) -> + ensure_timer(Name, Time, cancel_timer(Name, Channel)). + +-spec cancel_timer(custom_timer_name(), timerset()) -> + timerset(). +cancel_timer(Name, Timers) -> + case maps:take(Name, Timers) of + {TRef, NTimers} -> + ok = emqx_utils:cancel_timer(TRef), + NTimers; + error -> + Timers + end. + +%%-------------------------------------------------------------------- + -spec disconnect(clientinfo(), t()) -> {idle | shutdown, t()}. disconnect(_ClientInfo, Session) -> @@ -549,7 +588,18 @@ is_banned_msg(#message{from = ClientId}) -> get_impl_mod(Session) when ?IS_SESSION_IMPL_MEM(Session) -> emqx_session_mem; get_impl_mod(Session) when ?IS_SESSION_IMPL_DS(Session) -> - emqx_persistent_session_ds. + emqx_persistent_session_ds; +get_impl_mod(Session) -> + maybe_mock_impl_mod(Session). + +-ifdef(TEST). +maybe_mock_impl_mod({Mock, _State}) when is_atom(Mock) -> + Mock. +-else. +-spec maybe_mock_impl_mod(_Session) -> no_return(). +maybe_mock_impl_mod(_) -> + error(noimpl). +-endif. -spec choose_impl_mod(conninfo()) -> module(). choose_impl_mod(#{expiry_interval := EI}) -> diff --git a/apps/emqx/test/emqx_channel_SUITE.erl b/apps/emqx/test/emqx_channel_SUITE.erl index 408ae0014..8f6a2baaa 100644 --- a/apps/emqx/test/emqx_channel_SUITE.erl +++ b/apps/emqx/test/emqx_channel_SUITE.erl @@ -22,6 +22,7 @@ -include_lib("emqx/include/emqx.hrl"). -include_lib("emqx/include/emqx_mqtt.hrl"). -include_lib("eunit/include/eunit.hrl"). +-include_lib("emqx/include/asserts.hrl"). -include_lib("common_test/include/ct.hrl"). all() -> @@ -729,6 +730,10 @@ t_handle_info_sock_closed(_) -> %% Test cases for handle_timeout %%-------------------------------------------------------------------- +-define(CUSTOM_TIMER1_TIMEOUT, 100). +-define(CUSTOM_TIMER2_TIMEOUT, 20). +-define(CUSTOM_TIMER3_TIMEOUT, 50). + t_handle_timeout_keepalive(_) -> TRef = make_ref(), Channel = emqx_channel:set_field(timers, #{keepalive => TRef}, channel()), @@ -752,6 +757,54 @@ t_handle_timeout_expire_session(_) -> t_handle_timeout_will_message(_) -> {ok, _Chan} = emqx_channel:handle_timeout(make_ref(), will_message, channel()). +t_handle_custom_timers(_) -> + Channel = channel(#{ + conn_state => connected, + session => {?MODULE, #{}} + }), + {ok, [{outgoing, ?SUBACK_PACKET(1, [?QOS_0])} | _], Chan1} = + emqx_channel:handle_in( + ?SUBSCRIBE_PACKET(1, #{}, [{<<"+/+">>, ?DEFAULT_SUBOPTS}]), + Channel + ), + {timeout, T1Ref, T1Msg} = ?assertReceive({timeout, _, _}, ?CUSTOM_TIMER1_TIMEOUT * 2), + {ok, {outgoing, [?PUBLISH_PACKET(0, <<"a/b">>, 1, <<"t1">>)]}, Chan2} = + emqx_channel:handle_timeout(T1Ref, T1Msg, Chan1), + {timeout, T2Ref, T2Msg} = ?assertReceive({timeout, _, _}, ?CUSTOM_TIMER2_TIMEOUT * 2), + {ok, {outgoing, [?PUBLISH_PACKET(0, <<"c/d">>, 2, <<"t2">>)]}, _Chan} = + emqx_channel:handle_timeout(T2Ref, T2Msg, Chan2), + ok = ?assertNotReceive({timeout, _, _}, ?CUSTOM_TIMER3_TIMEOUT * 2). + +%%-------------------------------------------------------------------- +%% Mocked session module +%%-------------------------------------------------------------------- + +subscribe(_TopicFilter, _SubOpts = #{}, {?MODULE, Session0}) -> + % NOTE: Only this one should be triggered + Session1 = emqx_session:ensure_timer(t1, ?CUSTOM_TIMER1_TIMEOUT, Session0), + Session = emqx_session:ensure_timer(t1, ?CUSTOM_TIMER1_TIMEOUT * 5, Session1), + {ok, {?MODULE, Session}}. + +get_subscription(_TopicFilter, {?MODULE, _Session}) -> + undefined. + +handle_timeout(_ClientInfo, t1, {?MODULE, Session0}) -> + Msg = emqx_message:make(<<"a/b">>, <<"t1">>), + Session1 = maps:remove(t1, Session0), + % NOTE: Only this one should be reset by the second call. + Session2 = emqx_session:reset_timer(t2, ?CUSTOM_TIMER2_TIMEOUT * 5, Session1), + Session3 = emqx_session:reset_timer(t2, ?CUSTOM_TIMER2_TIMEOUT, Session2), + Session = emqx_session:reset_timer(t3, ?CUSTOM_TIMER3_TIMEOUT, Session3), + {ok, [{1, Msg}], {?MODULE, Session}}; +handle_timeout(_ClientInfo, t2, {?MODULE, Session0}) -> + Msg = emqx_message:make(<<"c/d">>, <<"t2">>), + Session1 = maps:remove(t2, Session0), + Session2 = emqx_session:cancel_timer(t2, Session1), + % NOTE: Thus `t3` should not be triggered, see `?assertNotReceive` above. + Session = emqx_session:cancel_timer(t3, Session2), + ok = ?assertEqual(#{}, Session), + {ok, [{2, Msg}], {?MODULE, Session}}. + %%-------------------------------------------------------------------- %% Test cases for internal functions %%-------------------------------------------------------------------- diff --git a/apps/emqx_gateway_mqttsn/src/emqx_mqttsn_channel.erl b/apps/emqx_gateway_mqttsn/src/emqx_mqttsn_channel.erl index 087187379..e7061e4a5 100644 --- a/apps/emqx_gateway_mqttsn/src/emqx_mqttsn_channel.erl +++ b/apps/emqx_gateway_mqttsn/src/emqx_mqttsn_channel.erl @@ -2116,6 +2116,9 @@ handle_timeout(_TRef, expire_session, Channel) -> handle_timeout(_TRef, expire_asleep, Channel) -> shutdown(asleep_timeout, Channel); handle_timeout(_TRef, Msg, Channel) -> + %% NOTE + %% We do not expect `emqx_mqttsn_session` to set up any custom timers (i.e with + %% `emqx_session:ensure_timer/3`), because `emqx_session_mem` doesn't use any. ?SLOG(error, #{ msg => "unexpected_timeout", timeout_msg => Msg