diff --git a/include/emqx.hrl b/include/emqx.hrl index 0a9bd646e..b68c54b72 100644 --- a/include/emqx.hrl +++ b/include/emqx.hrl @@ -154,15 +154,13 @@ %% Banned %%-------------------------------------------------------------------- --type(banned_who() :: {clientid, binary()} - | {username, binary()} - | {ip_address, inet:ip_address()}). - -record(banned, { - who :: banned_who(), - reason :: binary(), + who :: {clientid, binary()} + | {username, binary()} + | {ip_address, inet:ip_address()}, by :: binary(), - desc :: binary(), + reason :: binary(), + at :: integer(), until :: integer() }). diff --git a/src/emqx_banned.erl b/src/emqx_banned.erl index e9044c8c1..6f2a31766 100644 --- a/src/emqx_banned.erl +++ b/src/emqx_banned.erl @@ -33,7 +33,7 @@ -export([start_link/0, stop/0]). -export([ check/1 - , add/1 + , create/1 , delete/1 , info/1 ]). @@ -74,21 +74,39 @@ start_link() -> stop() -> gen_server:stop(?MODULE). -spec(check(emqx_types:clientinfo()) -> boolean()). -check(#{clientid := ClientId, - username := Username, - peerhost := IPAddr}) -> - ets:member(?BANNED_TAB, {clientid, ClientId}) - orelse ets:member(?BANNED_TAB, {username, Username}) - orelse ets:member(?BANNED_TAB, {ipaddr, IPAddr}). +check(ClientInfo) -> + do_check({clientid, maps:get(clientid, ClientInfo, undefined)}) + orelse do_check({username, maps:get(username, ClientInfo, undefined)}) + orelse do_check({peerhost, maps:get(peerhost, ClientInfo, undefined)}). --spec(add(emqx_types:banned()) -> ok). -add(Banned) when is_record(Banned, banned) -> +do_check({_, undefined}) -> + false; +do_check(Who) when is_tuple(Who) -> + case mnesia:dirty_read(?BANNED_TAB, Who) of + [] -> false; + [#banned{until = Until}] -> + Until > erlang:system_time(millisecond) + end. + +-spec(create(emqx_types:banned()) -> ok). +create(#{who := Who, + by := By, + reason := Reason, + at := At, + until := Until}) -> + mnesia:dirty_write(?BANNED_TAB, #banned{who = Who, + by = By, + reason = Reason, + at = At, + until = Until}); +create(Banned) when is_record(Banned, banned) -> mnesia:dirty_write(?BANNED_TAB, Banned). -spec(delete({clientid, emqx_types:clientid()} | {username, emqx_types:username()} | {peerhost, emqx_types:peerhost()}) -> ok). -delete(Key) -> mnesia:dirty_delete(?BANNED_TAB, Key). +delete(Who) -> + mnesia:dirty_delete(?BANNED_TAB, Who). info(InfoKey) -> mnesia:table_info(?BANNED_TAB, InfoKey). diff --git a/src/emqx_channel.erl b/src/emqx_channel.erl index d3c15d8df..435cba730 100644 --- a/src/emqx_channel.erl +++ b/src/emqx_channel.erl @@ -219,7 +219,6 @@ handle_in(?CONNECT_PACKET(ConnPkt), Channel) -> fun enrich_client/2, fun set_logger_meta/2, fun check_banned/2, - fun check_flapping/2, fun auth_connect/2], ConnPkt, Channel) of {ok, NConnPkt, NChannel} -> process_connect(NConnPkt, NChannel); @@ -942,7 +941,7 @@ set_logger_meta(_ConnPkt, #channel{clientinfo = #{clientid := ClientId}}) -> emqx_logger:set_metadata_clientid(ClientId). %%-------------------------------------------------------------------- -%% Check banned/flapping +%% Check banned %%-------------------------------------------------------------------- check_banned(_ConnPkt, #channel{clientinfo = ClientInfo = #{zone := Zone}}) -> @@ -951,13 +950,6 @@ check_banned(_ConnPkt, #channel{clientinfo = ClientInfo = #{zone := Zone}}) -> false -> ok end. -check_flapping(_ConnPkt, #channel{clientinfo = ClientInfo = #{zone := Zone}}) -> - case emqx_zone:enable_flapping_detect(Zone) - andalso emqx_flapping:check(ClientInfo) of - true -> {error, ?RC_CONNECTION_RATE_EXCEEDED}; - false -> ok - end. - %%-------------------------------------------------------------------- %% Auth Connect %%-------------------------------------------------------------------- diff --git a/src/emqx_flapping.erl b/src/emqx_flapping.erl index ef1e0e3e9..52a9074c9 100644 --- a/src/emqx_flapping.erl +++ b/src/emqx_flapping.erl @@ -27,7 +27,7 @@ -export([start_link/0, stop/0]). %% API --export([check/1, detect/1]). +-export([detect/1]). %% gen_server callbacks -export([ init/1 @@ -54,8 +54,7 @@ clientid :: emqx_types:clientid(), peerhost :: emqx_types:peerhost(), started_at :: pos_integer(), - detect_cnt :: pos_integer(), - banned_at :: pos_integer() + detect_cnt :: pos_integer() }). -opaque(flapping() :: #flapping{}). @@ -68,27 +67,14 @@ start_link() -> stop() -> gen_server:stop(?MODULE). -%% @doc Check flapping when a MQTT client connected. --spec(check(emqx_types:clientinfo()) -> boolean()). -check(#{clientid := ClientId}) -> - check(ClientId, get_policy()). - -check(ClientId, #{banned_interval := Interval}) -> - case ets:lookup(?FLAPPING_TAB, {banned, ClientId}) of - [] -> false; - [#flapping{banned_at = BannedAt}] -> - now_diff(BannedAt) < Interval - end. - %% @doc Detect flapping when a MQTT client disconnected. -spec(detect(emqx_types:clientinfo()) -> boolean()). detect(Client) -> detect(Client, get_policy()). -detect(#{clientid := ClientId, peerhost := PeerHost}, - Policy = #{threshold := Threshold}) -> +detect(#{clientid := ClientId, peerhost := PeerHost}, Policy = #{threshold := Threshold}) -> try ets:update_counter(?FLAPPING_TAB, ClientId, {#flapping.detect_cnt, 1}) of Cnt when Cnt < Threshold -> false; - _Cnt -> case ets:lookup(?FLAPPING_TAB, ClientId) of + _Cnt -> case ets:take(?FLAPPING_TAB, ClientId) of [Flapping] -> ok = gen_server:cast(?MODULE, {detected, Flapping, Policy}), true; @@ -118,52 +104,44 @@ now_diff(TS) -> erlang:system_time(millisecond) - TS. %%-------------------------------------------------------------------- init([]) -> - #{duration := Duration, banned_interval := Interval} = get_policy(), ok = emqx_tables:new(?FLAPPING_TAB, [public, set, {keypos, 2}, {read_concurrency, true}, {write_concurrency, true} ]), - State = #{time => max(Duration, Interval) + 1, tref => undefined}, - {ok, ensure_timer(State), hibernate}. + {ok, #{}, hibernate}. handle_call(Req, _From, State) -> ?LOG(error, "Unexpected call: ~p", [Req]), {reply, ignored, State}. -handle_cast({detected, Flapping = #flapping{clientid = ClientId, - peerhost = PeerHost, - started_at = StartedAt, - detect_cnt = DetectCnt}, - #{duration := Duration}}, State) -> - case (Interval = now_diff(StartedAt)) < Duration of +handle_cast({detected, #flapping{clientid = ClientId, + peerhost = PeerHost, + started_at = StartedAt, + detect_cnt = DetectCnt}, + #{duration := Duration, banned_interval := Interval}}, State) -> + case now_diff(StartedAt) < Duration of true -> %% Flapping happened:( - %% Log first ?LOG(error, "Flapping detected: ~s(~s) disconnected ~w times in ~wms", [ClientId, esockd_net:ntoa(PeerHost), DetectCnt, Duration]), - %% Banned. - BannedFlapping = Flapping#flapping{clientid = {banned, ClientId}, - banned_at = erlang:system_time(millisecond) - }, - alarm_handler:set_alarm({{flapping_detected, ClientId}, BannedFlapping}), - ets:insert(?FLAPPING_TAB, BannedFlapping); + Now = erlang:system_time(millisecond), + Banned = #banned{who = {clientid, ClientId}, + by = <<"flapping detector">>, + reason = <<"flapping is detected">>, + at = Now, + until = Now + Interval}, + alarm_handler:set_alarm({{flapping_detected, ClientId}, Banned}), + emqx_banned:create(Banned); false -> ?LOG(warning, "~s(~s) disconnected ~w times in ~wms", [ClientId, esockd_net:ntoa(PeerHost), DetectCnt, Interval]) end, - ets:delete_object(?FLAPPING_TAB, Flapping), {noreply, State}; handle_cast(Msg, State) -> ?LOG(error, "Unexpected cast: ~p", [Msg]), {noreply, State}. -handle_info({timeout, TRef, expire_flapping}, State = #{tref := TRef}) -> - with_flapping_tab(fun expire_flapping/2, - [erlang:system_time(millisecond), - get_policy()]), - {noreply, ensure_timer(State#{tref => undefined}), hibernate}; - handle_info(Info, State) -> ?LOG(error, "Unexpected info: ~p", [Info]), {noreply, State}. @@ -173,34 +151,3 @@ terminate(_Reason, _State) -> code_change(_OldVsn, State, _Extra) -> {ok, State}. - -%%-------------------------------------------------------------------- -%% Internal functions -%%-------------------------------------------------------------------- - -ensure_timer(State = #{time := Time, tref := undefined}) -> - State#{tref => emqx_misc:start_timer(Time, expire_flapping)}; -ensure_timer(State) -> State. - -with_flapping_tab(Fun, Args) -> - case ets:info(?FLAPPING_TAB, size) of - undefined -> ok; - 0 -> ok; - _Size -> erlang:apply(Fun, Args) - end. - -expire_flapping(NowTime, #{duration := Duration, banned_interval := Interval}) -> - case ets:select(?FLAPPING_TAB, - [{#flapping{started_at = '$1', banned_at = undefined, _ = '_'}, - [{'<', '$1', NowTime-Duration}], ['$_']}, - {#flapping{clientid = {banned, '_'}, banned_at = '$1', _ = '_'}, - [{'<', '$1', NowTime-Interval}], ['$_']}]) of - [] -> ok; - Flappings -> - lists:foreach(fun(Flapping = #flapping{clientid = {banned, ClientId}}) -> - ets:delete_object(?FLAPPING_TAB, Flapping), - alarm_handler:clear_alarm({flapping_detected, ClientId}); - (_) -> ok - end, Flappings) - end. - diff --git a/test/emqx_banned_SUITE.erl b/test/emqx_banned_SUITE.erl index 7801aea9a..849bcef90 100644 --- a/test/emqx_banned_SUITE.erl +++ b/test/emqx_banned_SUITE.erl @@ -38,20 +38,20 @@ end_per_suite(_Config) -> t_add_delete(_) -> Banned = #banned{who = {clientid, <<"TestClient">>}, - reason = <<"test">>, by = <<"banned suite">>, - desc = <<"test">>, + reason = <<"test">>, + at = erlang:system_time(second), until = erlang:system_time(second) + 1000 }, - ok = emqx_banned:add(Banned), + ok = emqx_banned:create(Banned), ?assertEqual(1, emqx_banned:info(size)), ok = emqx_banned:delete({clientid, <<"TestClient">>}), ?assertEqual(0, emqx_banned:info(size)). t_check(_) -> - ok = emqx_banned:add(#banned{who = {clientid, <<"BannedClient">>}}), - ok = emqx_banned:add(#banned{who = {username, <<"BannedUser">>}}), - ok = emqx_banned:add(#banned{who = {ipaddr, {192,168,0,1}}}), + ok = emqx_banned:create(#banned{who = {clientid, <<"BannedClient">>}}), + ok = emqx_banned:create(#banned{who = {username, <<"BannedUser">>}}), + ok = emqx_banned:create(#banned{who = {peerhost, {192,168,0,1}}}), ?assertEqual(3, emqx_banned:info(size)), ClientInfo1 = #{clientid => <<"BannedClient">>, username => <<"user">>, @@ -75,7 +75,7 @@ t_check(_) -> ?assertNot(emqx_banned:check(ClientInfo4)), ok = emqx_banned:delete({clientid, <<"BannedClient">>}), ok = emqx_banned:delete({username, <<"BannedUser">>}), - ok = emqx_banned:delete({ipaddr, {192,168,0,1}}), + ok = emqx_banned:delete({peerhost, {192,168,0,1}}), ?assertNot(emqx_banned:check(ClientInfo1)), ?assertNot(emqx_banned:check(ClientInfo2)), ?assertNot(emqx_banned:check(ClientInfo3)), @@ -84,9 +84,8 @@ t_check(_) -> t_unused(_) -> {ok, Banned} = emqx_banned:start_link(), - ok = emqx_banned:add(#banned{who = {clientid, <<"BannedClient">>}, - until = erlang:system_time(second) - }), + ok = emqx_banned:create(#banned{who = {clientid, <<"BannedClient">>}, + until = erlang:system_time(second)}), ?assertEqual(ignored, gen_server:call(Banned, unexpected_req)), ?assertEqual(ok, gen_server:cast(Banned, unexpected_msg)), ?assertEqual(ok, Banned ! ok), diff --git a/test/emqx_flapping_SUITE.erl b/test/emqx_flapping_SUITE.erl index a9234c541..e860d5703 100644 --- a/test/emqx_flapping_SUITE.erl +++ b/test/emqx_flapping_SUITE.erl @@ -45,13 +45,13 @@ t_detect_check(_) -> peerhost => {127,0,0,1} }, false = emqx_flapping:detect(ClientInfo), - false = emqx_flapping:check(ClientInfo), + false = emqx_banned:check(ClientInfo), false = emqx_flapping:detect(ClientInfo), - false = emqx_flapping:check(ClientInfo), + false = emqx_banned:check(ClientInfo), true = emqx_flapping:detect(ClientInfo), timer:sleep(100), - true = emqx_flapping:check(ClientInfo), - timer:sleep(300), - false = emqx_flapping:check(ClientInfo), + true = emqx_banned:check(ClientInfo), + timer:sleep(200), + false = emqx_banned:check(ClientInfo), ok = emqx_flapping:stop().