diff --git a/apps/emqx/src/emqx_flapping.erl b/apps/emqx/src/emqx_flapping.erl index a0eab9c18..b48f43094 100644 --- a/apps/emqx/src/emqx_flapping.erl +++ b/apps/emqx/src/emqx_flapping.erl @@ -45,9 +45,9 @@ -define(FLAPPING_DURATION, 60000). -define(FLAPPING_BANNED_INTERVAL, 300000). -define(DEFAULT_DETECT_POLICY, - #{threshold => ?FLAPPING_THRESHOLD, - duration => ?FLAPPING_DURATION, - banned_interval => ?FLAPPING_BANNED_INTERVAL + #{max_count => ?FLAPPING_THRESHOLD, + window_time => ?FLAPPING_DURATION, + ban_time => ?FLAPPING_BANNED_INTERVAL }). -record(flapping, { @@ -69,33 +69,28 @@ stop() -> gen_server:stop(?MODULE). %% @doc Detect flapping when a MQTT client disconnected. -spec(detect(emqx_types:clientinfo()) -> boolean()). -detect(Client) -> detect(Client, get_policy()). - -detect(#{clientid := ClientId, peerhost := PeerHost}, Policy = #{threshold := Threshold}) -> - try ets:update_counter(?FLAPPING_TAB, ClientId, {#flapping.detect_cnt, 1}) of +detect(#{clientid := ClientId, peerhost := PeerHost, zone := Zone, listener := Listener}) -> + Policy = #{max_count := Threshold} = get_policy(Zone, Listener), + %% The initial flapping record sets the detect_cnt to 0. + InitVal = #flapping{ + clientid = ClientId, + peerhost = PeerHost, + started_at = erlang:system_time(millisecond), + detect_cnt = 0 + }, + case ets:update_counter(?FLAPPING_TAB, ClientId, {#flapping.detect_cnt, 1}, InitVal) of Cnt when Cnt < Threshold -> false; - _Cnt -> case ets:take(?FLAPPING_TAB, ClientId) of - [Flapping] -> - ok = gen_server:cast(?MODULE, {detected, Flapping, Policy}), - true; - [] -> false - end - catch - error:badarg -> - %% Create a flapping record. - Flapping = #flapping{clientid = ClientId, - peerhost = PeerHost, - started_at = erlang:system_time(millisecond), - detect_cnt = 1 - }, - true = ets:insert(?FLAPPING_TAB, Flapping), - false + _Cnt -> + case ets:take(?FLAPPING_TAB, ClientId) of + [Flapping] -> + ok = gen_server:cast(?MODULE, {detected, Flapping, Policy}), + true; + [] -> false + end end. --compile({inline, [get_policy/0, now_diff/1]}). - -get_policy() -> - emqx:get_env(flapping_detect_policy, ?DEFAULT_DETECT_POLICY). +get_policy(Zone, Listener) -> + emqx_config:get_listener_conf(Zone, Listener, [flapping_detect]). now_diff(TS) -> erlang:system_time(millisecond) - TS. @@ -105,11 +100,12 @@ now_diff(TS) -> erlang:system_time(millisecond) - TS. init([]) -> ok = emqx_tables:new(?FLAPPING_TAB, [public, set, - {keypos, 2}, + {keypos, #flapping.clientid}, {read_concurrency, true}, {write_concurrency, true} ]), - {ok, ensure_timer(#{}), hibernate}. + start_timers(), + {ok, #{}, hibernate}. handle_call(Req, _From, State) -> ?LOG(error, "Unexpected call: ~p", [Req]), @@ -119,11 +115,11 @@ handle_cast({detected, #flapping{clientid = ClientId, peerhost = PeerHost, started_at = StartedAt, detect_cnt = DetectCnt}, - #{duration := Duration, banned_interval := Interval}}, State) -> - case now_diff(StartedAt) < Duration of + #{window_time := WindTime, ban_time := Interval}}, State) -> + case now_diff(StartedAt) < WindTime of true -> %% Flapping happened:( ?LOG(error, "Flapping detected: ~s(~s) disconnected ~w times in ~wms", - [ClientId, inet:ntoa(PeerHost), DetectCnt, Duration]), + [ClientId, inet:ntoa(PeerHost), DetectCnt, WindTime]), Now = erlang:system_time(second), Banned = #banned{who = {clientid, ClientId}, by = <<"flapping detector">>, @@ -141,11 +137,13 @@ handle_cast(Msg, State) -> ?LOG(error, "Unexpected cast: ~p", [Msg]), {noreply, State}. -handle_info({timeout, TRef, expired_detecting}, State = #{expired_timer := TRef}) -> - Timestamp = erlang:system_time(millisecond) - maps:get(duration, get_policy()), +handle_info({timeout, _TRef, {garbage_collect, Zone, Listener}}, State) -> + Timestamp = erlang:system_time(millisecond) + - maps:get(window_time, get_policy(Zone, Listener)), MatchSpec = [{{'_', '_', '_', '$1', '_'},[{'<', '$1', Timestamp}], [true]}], ets:select_delete(?FLAPPING_TAB, MatchSpec), - {noreply, ensure_timer(State), hibernate}; + start_timer(Zone, Listener), + {noreply, State, hibernate}; handle_info(Info, State) -> ?LOG(error, "Unexpected info: ~p", [Info]), @@ -157,7 +155,13 @@ terminate(_Reason, _State) -> code_change(_OldVsn, State, _Extra) -> {ok, State}. -ensure_timer(State) -> - Timeout = maps:get(duration, get_policy()), - TRef = emqx_misc:start_timer(Timeout, expired_detecting), - State#{expired_timer => TRef}. \ No newline at end of file +start_timer(Zone, Listener) -> + WindTime = maps:get(window_time, get_policy(Zone, Listener)), + emqx_misc:start_timer(WindTime, {garbage_collect, Zone, Listener}). + +start_timers() -> + lists:foreach(fun({Zone, ZoneConf}) -> + lists:foreach(fun({Listener, _}) -> + start_timer(Zone, Listener) + end, maps:to_list(maps:get(listeners, ZoneConf, #{}))) + end, maps:to_list(emqx_config:get([zones], #{}))). \ No newline at end of file diff --git a/apps/emqx/test/emqx_flapping_SUITE.erl b/apps/emqx/test/emqx_flapping_SUITE.erl index 79eb64b45..e5b12a122 100644 --- a/apps/emqx/test/emqx_flapping_SUITE.erl +++ b/apps/emqx/test/emqx_flapping_SUITE.erl @@ -26,7 +26,11 @@ all() -> emqx_ct:all(?MODULE). init_per_suite(Config) -> emqx_ct_helpers:boot_modules(all), emqx_ct_helpers:start_apps([]), - emqx_config:put_listener_conf(default, mqtt_tcp, [flapping_detect, enable], true), + emqx_config:put_listener_conf(default, mqtt_tcp, [flapping_detect], + #{max_count => 3, + window_time => 100, + ban_time => 2 + }), Config. end_per_suite(_Config) -> @@ -35,7 +39,8 @@ end_per_suite(_Config) -> ok. t_detect_check(_) -> - ClientInfo = #{zone => external, + ClientInfo = #{zone => default, + listener => mqtt_tcp, clientid => <<"clientid">>, peerhost => {127,0,0,1} }, @@ -56,7 +61,8 @@ t_detect_check(_) -> ok = emqx_flapping:stop(). t_expired_detecting(_) -> - ClientInfo = #{zone => external, + ClientInfo = #{zone => default, + listener => mqtt_tcp, clientid => <<"clientid">>, peerhost => {127,0,0,1}}, false = emqx_flapping:detect(ClientInfo),