diff --git a/src/emqx_channel.erl b/src/emqx_channel.erl index 791433b37..daaec16e1 100644 --- a/src/emqx_channel.erl +++ b/src/emqx_channel.erl @@ -56,9 +56,12 @@ , clear_keepalive/1 ]). -%% Exports for CT -export([set_field/3]). +-ifdef(TEST). +-export([ensure_timer/3]). +-endif. + -import(emqx_misc, [ run_fold/3 , pipeline/3 @@ -622,20 +625,20 @@ do_publish(PacketId, Msg = #message{qos = ?QOS_1}, Channel) -> NChannel = ensure_quota(PubRes, Channel), handle_out(puback, {PacketId, RC}, NChannel); -do_publish(PacketId, Msg = #message{qos = ?QOS_2}, - Channel = #channel{session = Session}) -> +do_publish(PacketId, Msg = #message{qos = ?QOS_2}, Channel0) -> + #channel{session = Session} = NChannel = maybe_clean_expired_awaiting_rel(Channel0), case emqx_session:publish(PacketId, Msg, Session) of {ok, PubRes, NSession} -> RC = puback_reason_code(PubRes), - NChannel1 = ensure_timer(await_timer, Channel#channel{session = NSession}), + NChannel1 = ensure_timer(await_timer, NChannel#channel{session = NSession}), NChannel2 = ensure_quota(PubRes, NChannel1), handle_out(pubrec, {PacketId, RC}, NChannel2); {error, RC = ?RC_PACKET_IDENTIFIER_IN_USE} -> ok = emqx_metrics:inc('packets.publish.inuse'), - handle_out(pubrec, {PacketId, RC}, Channel); + handle_out(pubrec, {PacketId, RC}, NChannel); {error, RC = ?RC_RECEIVE_MAXIMUM_EXCEEDED} -> ok = emqx_metrics:inc('packets.publish.dropped'), - handle_out(disconnect, RC, Channel) + handle_out(disconnect, RC, NChannel) end. ensure_quota(_, Channel = #channel{quota = undefined}) -> @@ -841,7 +844,6 @@ handle_out(connack, {?RC_SUCCESS, SP, Props}, Channel = #channel{conninfo = Conn [ConnInfo, emqx_reason_codes:name(?RC_SUCCESS)], AckProps ), - return_connack(?CONNACK_PACKET(?RC_SUCCESS, SP, NAckProps), ensure_keepalive(NAckProps, Channel)); @@ -923,7 +925,7 @@ return_connack(AckPacket, Channel) -> }, {Packets, NChannel1} = do_deliver(Publishes, NChannel), Outgoing = [{outgoing, Packets} || length(Packets) > 0], - {ok, Replies ++ Outgoing, NChannel1} + {ok, Replies ++ Outgoing, ensure_timer(retry_timer, NChannel1)} end. %%-------------------------------------------------------------------- @@ -1127,17 +1129,8 @@ handle_timeout(_TRef, retry_delivery, handle_out(publish, Publishes, reset_timer(retry_timer, Timeout, NChannel)) end; -handle_timeout(_TRef, expire_awaiting_rel, - Channel = #channel{conn_state = disconnected}) -> - {ok, Channel}; -handle_timeout(_TRef, expire_awaiting_rel, - Channel = #channel{session = Session}) -> - case emqx_session:expire(awaiting_rel, Session) of - {ok, NSession} -> - {ok, clean_timer(await_timer, Channel#channel{session = NSession})}; - {ok, Timeout, NSession} -> - {ok, reset_timer(await_timer, Timeout, Channel#channel{session = NSession})} - end; +handle_timeout(_TRef, expire_awaiting_rel, Channel) -> + {ok, clean_expired_awaiting_rel(Channel)}; handle_timeout(_TRef, expire_session, Channel) -> shutdown(expired, Channel); @@ -1182,6 +1175,26 @@ reset_timer(Name, Time, Channel) -> clean_timer(Name, Channel = #channel{timers = Timers}) -> Channel#channel{timers = maps:remove(Name, Timers)}. +is_timer_alive(Name, #channel{timers = Timers}) -> + case maps:find(Name, Timers) of + error -> false; + {ok, _TRef} -> true + end. + +maybe_clean_expired_awaiting_rel(Channel) -> + case is_timer_alive(await_timer, Channel) of + true -> Channel; + false -> clean_expired_awaiting_rel(Channel) + end. + +clean_expired_awaiting_rel(Channel = #channel{session = Session}) -> + case emqx_session:expire(awaiting_rel, Session) of + {ok, NSession} -> + clean_timer(await_timer, Channel#channel{session = NSession}); + {ok, Timeout, NSession} -> + reset_timer(await_timer, Timeout, Channel#channel{session = NSession}) + end. + -spec interval(channel_timer(), channel()) -> timeout(). interval(alive_timer, #channel{keepalive = KeepAlive}) -> emqx_keepalive:info(interval, KeepAlive); @@ -1878,10 +1891,6 @@ is_disconnect_event_enabled(discarded) -> is_disconnect_event_enabled(takeovered) -> emqx:get_env(client_disconnect_takeovered, false). -%%-------------------------------------------------------------------- -%% For CT tests -%%-------------------------------------------------------------------- - set_field(Name, Value, Channel) -> Pos = emqx_misc:index_of(Name, record_info(fields, channel)), setelement(Pos+1, Channel, Value).