diff --git a/apps/emqx/src/emqx_cm.erl b/apps/emqx/src/emqx_cm.erl index a26d35969..6de05dabe 100644 --- a/apps/emqx/src/emqx_cm.erl +++ b/apps/emqx/src/emqx_cm.erl @@ -19,6 +19,7 @@ -behaviour(gen_server). +-include("emqx.hrl"). -include("logger.hrl"). -include("types.hrl"). -include_lib("snabbkaffe/include/snabbkaffe.hrl"). @@ -297,9 +298,9 @@ open_session(false, ClientInfo = #{clientid := ClientId}, ConnInfo) -> register_channel(ClientId, Self, ConnInfo), {ok, #{ - session => Session1, + session => clean_session(Session1), present => true, - pendings => Pendings + pendings => clean_pendings(Pendings) }}; {living, ConnMod, ChanPid, Session} -> ok = emqx_session:resume(ClientInfo, Session), @@ -316,9 +317,9 @@ open_session(false, ClientInfo = #{clientid := ClientId}, ConnInfo) -> ), register_channel(ClientId, Self, ConnInfo), {ok, #{ - session => Session1, + session => clean_session(Session1), present => true, - pendings => Pendings + pendings => clean_pendings(Pendings) }}; {error, _} -> CreateSess() @@ -732,3 +733,14 @@ get_connected_client_count() -> undefined -> 0; Size -> Size end. + +clean_session(Session) -> + emqx_session:filter_queue(fun is_banned_msg/1, Session). + +clean_pendings(Pendings) -> + lists:filter(fun is_banned_msg/1, Pendings). + +is_banned_msg(#message{from = ClientId}) -> + [] =:= emqx_banned:look_up({clientid, ClientId}); +is_banned_msg({deliver, _Topic, Msg}) -> + is_banned_msg(Msg). diff --git a/apps/emqx/src/emqx_mqueue.erl b/apps/emqx/src/emqx_mqueue.erl index 494e2b33e..fbf29d754 100644 --- a/apps/emqx/src/emqx_mqueue.erl +++ b/apps/emqx/src/emqx_mqueue.erl @@ -67,7 +67,8 @@ out/1, stats/1, dropped/1, - to_list/1 + to_list/1, + filter/2 ]). -define(NO_PRIORITY_TABLE, disabled). @@ -158,6 +159,19 @@ max_len(#mqueue{max_len = MaxLen}) -> MaxLen. to_list(MQ) -> to_list(MQ, []). +-spec filter(fun((any()) -> boolean()), mqueue()) -> mqueue(). +filter(_Pred, #mqueue{len = 0} = MQ) -> + MQ; +filter(Pred, #mqueue{q = Q, len = Len, dropped = Droppend} = MQ) -> + Q2 = ?PQUEUE:filter(Pred, Q), + case ?PQUEUE:len(Q2) of + Len -> + MQ; + Len2 -> + Diff = Len - Len2, + MQ#mqueue{q = Q2, len = Len2, dropped = Droppend + Diff} + end. + to_list(MQ, Acc) -> case out(MQ) of {empty, _MQ} -> diff --git a/apps/emqx/src/emqx_session.erl b/apps/emqx/src/emqx_session.erl index b3a8ecebc..a13dfe491 100644 --- a/apps/emqx/src/emqx_session.erl +++ b/apps/emqx/src/emqx_session.erl @@ -82,6 +82,7 @@ deliver/3, enqueue/3, dequeue/2, + filter_queue/2, ignore_local/4, retry/2, terminate/3 @@ -529,6 +530,9 @@ dequeue(ClientInfo, Cnt, Msgs, Q) -> end end. +filter_queue(Pred, #session{mqueue = Q} = Session) -> + Session#session{mqueue = emqx_mqueue:filter(Pred, Q)}. + acc_cnt(#message{qos = ?QOS_0}, Cnt) -> Cnt; acc_cnt(_Msg, Cnt) -> Cnt - 1. diff --git a/apps/emqx/test/emqx_banned_SUITE.erl b/apps/emqx/test/emqx_banned_SUITE.erl index ed22a019a..605c1de6d 100644 --- a/apps/emqx/test/emqx_banned_SUITE.erl +++ b/apps/emqx/test/emqx_banned_SUITE.erl @@ -141,3 +141,73 @@ t_kick(_) -> snabbkaffe:stop(), emqx_banned:delete(Who), ?assertEqual(1, length(?of_kind(kick_session_due_to_banned, Trace))). + +t_session_taken(_) -> + erlang:process_flag(trap_exit, true), + Topic = <<"t/banned">>, + ClientId2 = <<"t_session_taken">>, + MsgNum = 3, + Connect = fun() -> + {ok, C} = emqtt:start_link([ + {clientid, <<"client1">>}, + {proto_ver, v5}, + {clean_start, false}, + {properties, #{'Session-Expiry-Interval' => 120}} + ]), + {ok, _} = emqtt:connect(C), + {ok, _, [0]} = emqtt:subscribe(C, Topic, []), + C + end, + + Publish = fun() -> + lists:foreach( + fun(_) -> + Msg = emqx_message:make(ClientId2, Topic, <<"payload">>), + emqx_broker:safe_publish(Msg) + end, + lists:seq(1, MsgNum) + ) + end, + + C1 = Connect(), + ok = emqtt:disconnect(C1), + + Publish(), + + C2 = Connect(), + ?assertEqual(MsgNum, length(receive_messages(MsgNum + 1))), + ok = emqtt:disconnect(C2), + + Publish(), + + Now = erlang:system_time(second), + Who = {clientid, ClientId2}, + emqx_banned:create(#{ + who => Who, + by => <<"test">>, + reason => <<"test">>, + at => Now, + until => Now + 120 + }), + + C3 = Connect(), + ?assertEqual(0, length(receive_messages(MsgNum + 1))), + emqx_banned:delete(Who), + {ok, #{}, [0]} = emqtt:unsubscribe(C3, Topic), + ok = emqtt:disconnect(C3). + +receive_messages(Count) -> + receive_messages(Count, []). +receive_messages(0, Msgs) -> + Msgs; +receive_messages(Count, Msgs) -> + receive + {publish, Msg} -> + ct:log("Msg: ~p ~n", [Msg]), + receive_messages(Count - 1, [Msg | Msgs]); + Other -> + ct:log("Other Msg: ~p~n", [Other]), + receive_messages(Count, Msgs) + after 1200 -> + Msgs + end. diff --git a/changes/v5.0.16/feat-9893.en.md b/changes/v5.0.16/feat-9893.en.md new file mode 100644 index 000000000..590d82a0f --- /dev/null +++ b/changes/v5.0.16/feat-9893.en.md @@ -0,0 +1 @@ +When connecting with the flag `clean_start=false`, the new session will filter out banned messages from the `mqueue` before deliver. diff --git a/changes/v5.0.16/feat-9893.zh.md b/changes/v5.0.16/feat-9893.zh.md new file mode 100644 index 000000000..30286a679 --- /dev/null +++ b/changes/v5.0.16/feat-9893.zh.md @@ -0,0 +1 @@ +当使用 `clean_start=false` 标志连接时,新会话将在传递之前从 `mqueue` 中过滤掉被封禁的消息。