feat: refactor flapping detect conf

This commit is contained in:
zhongwencool 2023-06-06 17:56:36 +08:00
parent d8190caa49
commit 48381d4c86
6 changed files with 75 additions and 63 deletions

View File

@ -1632,10 +1632,9 @@ check_banned(_ConnPkt, #channel{clientinfo = ClientInfo}) ->
%%-------------------------------------------------------------------- %%--------------------------------------------------------------------
%% Flapping %% Flapping
count_flapping_event(_ConnPkt, Channel = #channel{clientinfo = ClientInfo = #{zone := Zone}}) -> count_flapping_event(_ConnPkt, #channel{clientinfo = ClientInfo}) ->
is_integer(emqx_config:get_zone_conf(Zone, [flapping_detect, window_time])) andalso _ = emqx_flapping:detect(ClientInfo),
emqx_flapping:detect(ClientInfo), ok.
{ok, Channel}.
%%-------------------------------------------------------------------- %%--------------------------------------------------------------------
%% Authenticate %% Authenticate

View File

@ -707,7 +707,10 @@ do_put(Type, Putter, [], DeepValue) ->
do_put(Type, Putter, [RootName | KeyPath], DeepValue) -> do_put(Type, Putter, [RootName | KeyPath], DeepValue) ->
OldValue = do_get(Type, [RootName], #{}), OldValue = do_get(Type, [RootName], #{}),
NewValue = do_deep_put(Type, Putter, KeyPath, OldValue, DeepValue), NewValue = do_deep_put(Type, Putter, KeyPath, OldValue, DeepValue),
persistent_term:put(?PERSIS_KEY(Type, RootName), NewValue). Key = ?PERSIS_KEY(Type, RootName),
persistent_term:put(Key, NewValue),
post_save_config_hook(Key, NewValue),
ok.
do_deep_get(?CONF, AtomKeyPath, Map, Default) -> do_deep_get(?CONF, AtomKeyPath, Map, Default) ->
emqx_utils_maps:deep_get(AtomKeyPath, Map, Default); emqx_utils_maps:deep_get(AtomKeyPath, Map, Default);
@ -829,15 +832,12 @@ merge_with_global_defaults(GlobalDefaults, ZoneVal) ->
maybe_update_zone([zones | T], ZonesValue, Value) -> maybe_update_zone([zones | T], ZonesValue, Value) ->
%% note, do not write to PT, return *New value* instead %% note, do not write to PT, return *New value* instead
NewZonesValue = emqx_utils_maps:deep_put(T, ZonesValue, Value), NewZonesValue = emqx_utils_maps:deep_put(T, ZonesValue, Value),
ExistingZoneNames = maps:keys(?MODULE:get([zones], #{})),
%% Update only new zones with global defaults
GLD = zone_global_defaults(), GLD = zone_global_defaults(),
maps:fold( maps:map(
fun(ZoneName, ZoneValue, Acc) -> fun(_ZoneName, ZoneValue) ->
Acc#{ZoneName := merge_with_global_defaults(GLD, ZoneValue)} merge_with_global_defaults(GLD, ZoneValue)
end, end,
NewZonesValue, NewZonesValue
maps:without(ExistingZoneNames, NewZonesValue)
); );
maybe_update_zone([RootName | T], RootValue, Value) when is_atom(RootName) -> maybe_update_zone([RootName | T], RootValue, Value) when is_atom(RootName) ->
NewRootValue = emqx_utils_maps:deep_put(T, RootValue, Value), NewRootValue = emqx_utils_maps:deep_put(T, RootValue, Value),
@ -911,3 +911,11 @@ rawconf_to_conf(SchemaModule, RawPath, RawValue) ->
), ),
AtomPath = to_atom_conf_path(RawPath, {raise_error, maybe_update_zone_error}), AtomPath = to_atom_conf_path(RawPath, {raise_error, maybe_update_zone_error}),
emqx_utils_maps:deep_get(AtomPath, RawUserDefinedValues). emqx_utils_maps:deep_get(AtomPath, RawUserDefinedValues).
%% Hook invoked with the persistent-term key and the value that was just saved.
%% When the global zone defaults change, the zones are updated with the new
%% global values. The zones config has no config_handler callback, so the
%% flapping server is notified of the change from here instead.
%% Saves under any other key need no follow-up action.
post_save_config_hook(SavedKey, _SavedValue) ->
    case SavedKey of
        ?PERSIS_KEY(?CONF, zones) ->
            emqx_flapping:update_config(),
            ok;
        _OtherKey ->
            ok
    end.

View File

@ -22,13 +22,13 @@
-include("types.hrl"). -include("types.hrl").
-include("logger.hrl"). -include("logger.hrl").
-export([start_link/0, stop/0]). -export([start_link/0, update_config/0, stop/0]).
%% API %% API
-export([detect/1]). -export([detect/1]).
-ifdef(TEST). -ifdef(TEST).
-export([get_policy/2]). -export([get_policy/1]).
-endif. -endif.
%% gen_server callbacks %% gen_server callbacks
@ -59,12 +59,17 @@
start_link() -> start_link() ->
gen_server:start_link({local, ?MODULE}, ?MODULE, [], []). gen_server:start_link({local, ?MODULE}, ?MODULE, [], []).
%% @doc Notify the flapping server that the configuration has changed.
%% The request is sent as a cast to the locally registered server, so the
%% caller does not wait for it to be processed.
update_config() ->
    Notification = update_config,
    gen_server:cast(?MODULE, Notification).
stop() -> gen_server:stop(?MODULE). stop() -> gen_server:stop(?MODULE).
%% @doc Detect flapping when a MQTT client disconnected. %% @doc Detect flapping when a MQTT client disconnected.
-spec detect(emqx_types:clientinfo()) -> boolean(). -spec detect(emqx_types:clientinfo()) -> boolean().
detect(#{clientid := ClientId, peerhost := PeerHost, zone := Zone}) -> detect(#{clientid := ClientId, peerhost := PeerHost, zone := Zone}) ->
Policy = #{max_count := Threshold} = get_policy([max_count, window_time, ban_time], Zone), detect(ClientId, PeerHost, get_policy(Zone)).
detect(ClientId, PeerHost, #{enable := true, max_count := Threshold} = Policy) ->
%% The initial flapping record sets the detect_cnt to 0. %% The initial flapping record sets the detect_cnt to 0.
InitVal = #flapping{ InitVal = #flapping{
clientid = ClientId, clientid = ClientId,
@ -82,24 +87,12 @@ detect(#{clientid := ClientId, peerhost := PeerHost, zone := Zone}) ->
[] -> [] ->
false false
end end
end. end;
detect(_ClientId, _PeerHost, #{enable := false}) ->
false.
get_policy(Keys, Zone) when is_list(Keys) -> get_policy(Zone) ->
RootKey = flapping_detect, emqx_config:get_zone_conf(Zone, [flapping_detect]).
Conf = emqx_config:get_zone_conf(Zone, [RootKey]),
lists:foldl(
fun(Key, Acc) ->
case maps:find(Key, Conf) of
{ok, V} -> Acc#{Key => V};
error -> Acc#{Key => emqx_config:get([RootKey, Key])}
end
end,
#{},
Keys
);
get_policy(Key, Zone) ->
#{Key := Conf} = get_policy([Key], Zone),
Conf.
now_diff(TS) -> erlang:system_time(millisecond) - TS. now_diff(TS) -> erlang:system_time(millisecond) - TS.
@ -115,8 +108,8 @@ init([]) ->
{read_concurrency, true}, {read_concurrency, true},
{write_concurrency, true} {write_concurrency, true}
]), ]),
start_timers(), Timers = start_timers(),
{ok, #{}, hibernate}. {ok, Timers, hibernate}.
handle_call(Req, _From, State) -> handle_call(Req, _From, State) ->
?SLOG(error, #{msg => "unexpected_call", call => Req}), ?SLOG(error, #{msg => "unexpected_call", call => Req}),
@ -169,17 +162,20 @@ handle_cast(
) )
end, end,
{noreply, State}; {noreply, State};
handle_cast(update_config, State) ->
NState = update_timer(State),
{noreply, NState};
handle_cast(Msg, State) -> handle_cast(Msg, State) ->
?SLOG(error, #{msg => "unexpected_cast", cast => Msg}), ?SLOG(error, #{msg => "unexpected_cast", cast => Msg}),
{noreply, State}. {noreply, State}.
handle_info({timeout, _TRef, {garbage_collect, Zone}}, State) -> handle_info({timeout, _TRef, {garbage_collect, Zone}}, State) ->
Timestamp = Policy = #{window_time := WindowTime} = get_policy(Zone),
erlang:system_time(millisecond) - get_policy(window_time, Zone), Timestamp = erlang:system_time(millisecond) - WindowTime,
MatchSpec = [{{'_', '_', '_', '$1', '_'}, [{'<', '$1', Timestamp}], [true]}], MatchSpec = [{{'_', '_', '_', '$1', '_'}, [{'<', '$1', Timestamp}], [true]}],
ets:select_delete(?FLAPPING_TAB, MatchSpec), ets:select_delete(?FLAPPING_TAB, MatchSpec),
_ = start_timer(Zone), Timer = start_timer(Policy, Zone),
{noreply, State, hibernate}; {noreply, State#{Zone => Timer}, hibernate};
handle_info(Info, State) -> handle_info(Info, State) ->
?SLOG(error, #{msg => "unexpected_info", info => Info}), ?SLOG(error, #{msg => "unexpected_info", info => Info}),
{noreply, State}. {noreply, State}.
@ -190,18 +186,27 @@ terminate(_Reason, _State) ->
code_change(_OldVsn, State, _Extra) -> code_change(_OldVsn, State, _Extra) ->
{ok, State}. {ok, State}.
start_timer(Zone) -> start_timer(#{enable := true, window_time := WindowTime}, Zone) ->
case get_policy(window_time, Zone) of
WindowTime when is_integer(WindowTime) ->
emqx_utils:start_timer(WindowTime, {garbage_collect, Zone}); emqx_utils:start_timer(WindowTime, {garbage_collect, Zone});
disabled -> start_timer(_Policy, _Zone) ->
ok undefined.
end.
start_timers() -> start_timers() ->
maps:foreach( maps:map(
fun(Zone, _ZoneConf) -> fun(ZoneName, #{flapping_detect := FlappingDetect}) ->
start_timer(Zone) start_timer(FlappingDetect, ZoneName)
end,
emqx:get_config([zones], #{})
).
update_timer(Timers) ->
maps:map(
fun(ZoneName, #{flapping_detect := FlappingDetect}) ->
case maps:get(ZoneName, Timers, undefined) of
undefined -> start_timer(FlappingDetect, ZoneName);
%% Don't reset this timer, it will be updated after next timeout.
TRef -> TRef
end
end, end,
emqx:get_config([zones], #{}) emqx:get_config([zones], #{})
). ).

View File

@ -275,7 +275,7 @@ roots(low) ->
{"flapping_detect", {"flapping_detect",
sc( sc(
ref("flapping_detect"), ref("flapping_detect"),
#{importance => ?IMPORTANCE_HIDDEN} #{importance => ?DEFAULT_IMPORTANCE}
)}, )},
{"persistent_session_store", {"persistent_session_store",
sc( sc(
@ -685,15 +685,14 @@ fields("flapping_detect") ->
boolean(), boolean(),
#{ #{
default => false, default => false,
deprecated => {since, "5.0.23"},
desc => ?DESC(flapping_detect_enable) desc => ?DESC(flapping_detect_enable)
} }
)}, )},
{"window_time", {"window_time",
sc( sc(
hoconsc:union([disabled, duration()]), duration(),
#{ #{
default => disabled, default => "1m",
importance => ?IMPORTANCE_HIGH, importance => ?IMPORTANCE_HIGH,
desc => ?DESC(flapping_detect_window_time) desc => ?DESC(flapping_detect_window_time)
} }

View File

@ -58,8 +58,7 @@ hidden() ->
[ [
"stats", "stats",
"overload_protection", "overload_protection",
"conn_congestion", "conn_congestion"
"flapping_detect"
]. ].
%% zone schemas are clones from the same name from root level %% zone schemas are clones from the same name from root level

View File

@ -30,6 +30,7 @@ init_per_suite(Config) ->
default, default,
[flapping_detect], [flapping_detect],
#{ #{
enable => true,
max_count => 3, max_count => 3,
% 0.1s % 0.1s
window_time => 100, window_time => 100,
@ -102,20 +103,21 @@ t_expired_detecting(_) ->
) )
). ).
t_conf_without_window_time(_) -> t_conf_update(_) ->
%% enable is deprecated, so we need to make sure it won't be used.
Global = emqx_config:get([flapping_detect]), Global = emqx_config:get([flapping_detect]),
?assertNot(maps:is_key(enable, Global)), #{
%% zones don't have default value, so we need to make sure fallback to global conf. ban_time := _BanTime,
%% this new_zone will fallback to global conf. enable := _Enable,
max_count := _MaxCount,
window_time := _WindowTime
} = Global,
emqx_config:put_zone_conf(new_zone, [flapping_detect], #{}), emqx_config:put_zone_conf(new_zone, [flapping_detect], #{}),
?assertEqual(Global, get_policy(new_zone)), ?assertEqual(Global, get_policy(new_zone)),
emqx_config:put_zone_conf(new_zone_1, [flapping_detect], #{window_time => 100}), emqx_config:put_zone_conf(new_zone_1, [flapping_detect], #{window_time => 100}),
?assertEqual(100, emqx_flapping:get_policy(window_time, new_zone_1)), ?assertEqual(Global#{window_time := 100}, emqx_flapping:get_policy(new_zone_1)),
?assertEqual(maps:get(ban_time, Global), emqx_flapping:get_policy(ban_time, new_zone_1)),
?assertEqual(maps:get(max_count, Global), emqx_flapping:get_policy(max_count, new_zone_1)),
ok. ok.
get_policy(Zone) -> get_policy(Zone) ->
emqx_flapping:get_policy([window_time, ban_time, max_count], Zone). emqx_flapping:get_policy(Zone).