refactor(eqmx_limiter): use the new hierarchical token bucket to replace the old ratelimit code

This commit is contained in:
firest 2021-12-10 00:46:11 +08:00 committed by lafirest
parent 540e5dbc0b
commit 8493b61cb5
35 changed files with 2672 additions and 1280 deletions

View File

@ -82,7 +82,7 @@
%% Authentication Data Cache
auth_cache :: maybe(map()),
%% Quota checkers
quota :: maybe(emqx_limiter:limiter()),
quota :: maybe(emqx_limiter_container:limiter()),
%% Timers
timers :: #{atom() => disabled | maybe(reference())},
%% Conn State
@ -120,6 +120,7 @@
}).
-define(INFO_KEYS, [conninfo, conn_state, clientinfo, session, will_msg]).
-define(LIMITER_ROUTING, message_routing).
-dialyzer({no_match, [shutdown/4, ensure_timer/2, interval/2]}).
@ -200,14 +201,13 @@ caps(#channel{clientinfo = #{zone := Zone}}) ->
-spec(init(emqx_types:conninfo(), opts()) -> channel()).
init(ConnInfo = #{peername := {PeerHost, _Port},
sockname := {_Host, SockPort}},
#{zone := Zone, listener := {Type, Listener}}) ->
#{zone := Zone, limiter := LimiterCfg, listener := {Type, Listener}}) ->
Peercert = maps:get(peercert, ConnInfo, undefined),
Protocol = maps:get(protocol, ConnInfo, mqtt),
MountPoint = case emqx_config:get_listener_conf(Type, Listener, [mountpoint]) of
<<>> -> undefined;
MP -> MP
end,
QuotaPolicy = emqx_config:get_zone_conf(Zone, [quota], #{}),
ClientInfo = set_peercert_infos(
Peercert,
#{zone => Zone,
@ -228,7 +228,7 @@ init(ConnInfo = #{peername := {PeerHost, _Port},
outbound => #{}
},
auth_cache = #{},
quota = emqx_limiter:init(Zone, quota_policy(QuotaPolicy)),
quota = emqx_limiter_container:get_limiter_by_names([?LIMITER_ROUTING], LimiterCfg),
timers = #{},
conn_state = idle,
takeover = false,
@ -236,11 +236,6 @@ init(ConnInfo = #{peername := {PeerHost, _Port},
pendings = []
}.
quota_policy(RawPolicy) ->
[{Name, {list_to_integer(StrCount),
erlang:trunc(hocon_postprocess:duration(StrWind) / 1000)}}
|| {Name, [StrCount, StrWind]} <- maps:to_list(RawPolicy)].
set_peercert_infos(NoSSL, ClientInfo, _)
when NoSSL =:= nossl;
NoSSL =:= undefined ->
@ -653,10 +648,10 @@ ensure_quota(PubRes, Channel = #channel{quota = Limiter}) ->
({_, _, {ok, I}}, N) -> N + I;
(_, N) -> N
end, 1, PubRes),
case emqx_limiter:check(#{cnt => Cnt, oct => 0}, Limiter) of
case emqx_limiter_container:check(Cnt, ?LIMITER_ROUTING, Limiter) of
{ok, NLimiter} ->
Channel#channel{quota = NLimiter};
{pause, Intv, NLimiter} ->
{_, Intv, NLimiter} ->
ensure_timer(quota_timer, Intv, Channel#channel{quota = NLimiter})
end.
@ -1005,10 +1000,9 @@ handle_call({takeover, 'end'}, Channel = #channel{session = Session,
handle_call(list_authz_cache, Channel) ->
{reply, emqx_authz_cache:list_authz_cache(), Channel};
handle_call({quota, Policy}, Channel) ->
Zone = info(zone, Channel),
Quota = emqx_limiter:init(Zone, Policy),
reply(ok, Channel#channel{quota = Quota});
handle_call({quota, Bucket}, #channel{quota = Quota} = Channel) ->
Quota2 = emqx_limiter_container:update_by_name(message_routing, Bucket, Quota),
reply(ok, Channel#channel{quota = Quota2});
handle_call({keepalive, Interval}, Channel = #channel{keepalive = KeepAlive,
conninfo = ConnInfo}) ->
@ -1147,8 +1141,15 @@ handle_timeout(_TRef, will_message, Channel = #channel{will_msg = WillMsg}) ->
(WillMsg =/= undefined) andalso publish_will_msg(WillMsg),
{ok, clean_timer(will_timer, Channel#channel{will_msg = undefined})};
handle_timeout(_TRef, expire_quota_limit, Channel) ->
{ok, clean_timer(quota_timer, Channel)};
handle_timeout(_TRef, expire_quota_limit,
#channel{quota = Quota} = Channel) ->
case emqx_limiter_container:retry(?LIMITER_ROUTING, Quota) of
{_, Intv, Quota2} ->
Channel2 = ensure_timer(quota_timer, Intv, Channel#channel{quota = Quota2}),
{ok, Channel2};
{_, Quota2} ->
{ok, clean_timer(quota_timer, Channel#channel{quota = Quota2})}
end;
handle_timeout(_TRef, Msg, Channel) ->
?SLOG(error, #{msg => "unexpected_timeout", timeout_msg => Msg}),

View File

@ -67,8 +67,7 @@
-export([set_field/3]).
-import(emqx_misc,
[ maybe_apply/2
, start_timer/2
[ start_timer/2
]).
-record(state, {
@ -82,11 +81,6 @@
sockname :: emqx_types:peername(),
%% Sock State
sockstate :: emqx_types:sockstate(),
%% Limiter
limiter :: maybe(emqx_limiter:limiter()),
%% Limit Timer
limit_timer :: maybe(reference()),
%% Parse State
parse_state :: emqx_frame:parse_state(),
%% Serialize options
serialize :: emqx_frame:serialize_opts(),
@ -103,10 +97,30 @@
%% Zone name
zone :: atom(),
%% Listener Type and Name
listener :: {Type::atom(), Name::atom()}
}).
listener :: {Type::atom(), Name::atom()},
%% Limiter
limiter :: maybe(limiter()),
%% cache operation when overload
limiter_cache :: queue:queue(cache()),
%% limiter timers
limiter_timer :: undefined | reference()
}).
-record(retry, { types :: list(limiter_type())
, data :: any()
, next :: check_succ_handler()
}).
-record(cache, { need :: list({pos_integer(), limiter_type()})
, data :: any()
, next :: check_succ_handler()
}).
-type(state() :: #state{}).
-type cache() :: #cache{}.
-define(ACTIVE_N, 100).
-define(INFO_KEYS, [socktype, peername, sockname, sockstate]).
@ -127,6 +141,11 @@
-define(ALARM_SOCK_STATS_KEYS, [send_pend, recv_cnt, recv_oct, send_cnt, send_oct]).
-define(ALARM_SOCK_OPTS_KEYS, [high_watermark, high_msgq_watermark, sndbuf, recbuf, buffer]).
%% use macro to do compile time limiter's type check
-define(LIMITER_BYTES_IN, bytes_in).
-define(LIMITER_MESSAGE_IN, message_in).
-define(EMPTY_QUEUE, {[], []}).
-dialyzer({no_match, [info/2]}).
-dialyzer({nowarn_function, [ init/4
, init_state/3
@ -170,10 +189,10 @@ info(sockstate, #state{sockstate = SockSt}) ->
SockSt;
info(stats_timer, #state{stats_timer = StatsTimer}) ->
StatsTimer;
info(limit_timer, #state{limit_timer = LimitTimer}) ->
LimitTimer;
info(limiter, #state{limiter = Limiter}) ->
maybe_apply(fun emqx_limiter:info/1, Limiter).
Limiter;
info(limiter_timer, #state{limiter_timer = Timer}) ->
Timer.
%% @doc Get stats of the connection/channel.
-spec(stats(pid() | state()) -> emqx_types:stats()).
@ -244,7 +263,8 @@ init(Parent, Transport, RawSocket, Options) ->
exit_on_sock_error(Reason)
end.
init_state(Transport, Socket, #{zone := Zone, listener := Listener} = Opts) ->
init_state(Transport, Socket,
#{zone := Zone, limiter := LimiterCfg, listener := Listener} = Opts) ->
{ok, Peername} = Transport:ensure_ok_or_exit(peername, [Socket]),
{ok, Sockname} = Transport:ensure_ok_or_exit(sockname, [Socket]),
Peercert = Transport:ensure_ok_or_exit(peercert, [Socket]),
@ -254,7 +274,10 @@ init_state(Transport, Socket, #{zone := Zone, listener := Listener} = Opts) ->
peercert => Peercert,
conn_mod => ?MODULE
},
Limiter = emqx_limiter:init(Zone, undefined, undefined, []),
LimiterTypes = [?LIMITER_BYTES_IN, ?LIMITER_MESSAGE_IN],
Limiter = emqx_limiter_container:get_limiter_by_names(LimiterTypes, LimiterCfg),
FrameOpts = #{
strict_mode => emqx_config:get_zone_conf(Zone, [mqtt, strict_mode]),
max_size => emqx_config:get_zone_conf(Zone, [mqtt, max_packet_size])
@ -286,7 +309,9 @@ init_state(Transport, Socket, #{zone := Zone, listener := Listener} = Opts) ->
idle_timeout = IdleTimeout,
idle_timer = IdleTimer,
zone = Zone,
listener = Listener
listener = Listener,
limiter_cache = queue:new(),
limiter_timer = undefined
}.
run_loop(Parent, State = #state{transport = Transport,
@ -428,14 +453,23 @@ handle_msg({Inet, _Sock, Data}, State) when Inet == tcp; Inet == ssl ->
Oct = iolist_size(Data),
inc_counter(incoming_bytes, Oct),
ok = emqx_metrics:inc('bytes.received', Oct),
parse_incoming(Data, State);
when_bytes_in(Oct, Data, State);
handle_msg({quic, Data, _Sock, _, _, _}, State) ->
?SLOG(debug, #{msg => "RECV_data", data => Data, transport => quic}),
Oct = iolist_size(Data),
inc_counter(incoming_bytes, Oct),
ok = emqx_metrics:inc('bytes.received', Oct),
parse_incoming(Data, State);
when_bytes_in(Oct, Data, State);
handle_msg(check_cache, #state{limiter_cache = Cache} = State) ->
case queue:peek(Cache) of
empty ->
activate_socket(State);
{value, #cache{need = Needs, data = Data, next = Next}} ->
State2 = State#state{limiter_cache = queue:drop(Cache)},
check_limiter(Needs, Data, Next, [check_cache], State2)
end;
handle_msg({incoming, Packet = ?CONNECT_PACKET(ConnPkt)},
State = #state{idle_timer = IdleTimer}) ->
@ -466,14 +500,12 @@ handle_msg({Passive, _Sock}, State)
Pubs = emqx_pd:reset_counter(incoming_pubs),
Bytes = emqx_pd:reset_counter(incoming_bytes),
InStats = #{cnt => Pubs, oct => Bytes},
%% Ensure Rate Limit
NState = ensure_rate_limit(InStats, State),
%% Run GC and Check OOM
NState1 = check_oom(run_gc(InStats, NState)),
NState1 = check_oom(run_gc(InStats, State)),
handle_info(activate_socket, NState1);
handle_msg(Deliver = {deliver, _Topic, _Msg}, #state{
listener = {Type, Listener}} = State) ->
handle_msg(Deliver = {deliver, _Topic, _Msg},
#state{listener = {Type, Listener}} = State) ->
ActiveN = get_active_n(Type, Listener),
Delivers = [Deliver | emqx_misc:drain_deliver(ActiveN)],
with_channel(handle_deliver, [Delivers], State);
@ -579,10 +611,12 @@ handle_call(_From, info, State) ->
handle_call(_From, stats, State) ->
{reply, stats(State), State};
handle_call(_From, {ratelimit, Policy}, State = #state{channel = Channel}) ->
Zone = emqx_channel:info(zone, Channel),
Limiter = emqx_limiter:init(Zone, Policy),
{reply, ok, State#state{limiter = Limiter}};
handle_call(_From, {ratelimit, Changes}, State = #state{limiter = Limiter}) ->
Fun = fun({Type, Bucket}, Acc) ->
emqx_limiter_container:update_by_name(Type, Bucket, Acc)
end,
Limiter2 = lists:foldl(Fun, Limiter, Changes),
{reply, ok, State#state{limiter = Limiter2}};
handle_call(_From, Req, State = #state{channel = Channel}) ->
case emqx_channel:handle_call(Req, Channel) of
@ -603,10 +637,7 @@ handle_timeout(_TRef, idle_timeout, State) ->
shutdown(idle_timeout, State);
handle_timeout(_TRef, limit_timeout, State) ->
NState = State#state{sockstate = idle,
limit_timer = undefined
},
handle_info(activate_socket, NState);
retry_limiter(State);
handle_timeout(_TRef, emit_stats, State = #state{channel = Channel, transport = Transport,
socket = Socket}) ->
@ -634,11 +665,23 @@ handle_timeout(TRef, Msg, State) ->
%%--------------------------------------------------------------------
%% Parse incoming data
-compile({inline, [parse_incoming/2]}).
parse_incoming(Data, State) ->
-compile({inline, [when_bytes_in/3]}).
when_bytes_in(Oct, Data, State) ->
{Packets, NState} = parse_incoming(Data, [], State),
{ok, next_incoming_msgs(Packets), NState}.
Len = erlang:length(Packets),
check_limiter([{Oct, ?LIMITER_BYTES_IN}, {Len, ?LIMITER_MESSAGE_IN}],
Packets,
fun next_incoming_msgs/3,
[],
NState).
-compile({inline, [next_incoming_msgs/3]}).
next_incoming_msgs([Packet], Msgs, State) ->
{ok, [{incoming, Packet} | Msgs], State};
next_incoming_msgs(Packets, Msgs, State) ->
Fun = fun(Packet, Acc) -> [{incoming, Packet} | Acc] end,
Msgs2 = lists:foldl(Fun, Msgs, Packets),
{ok, Msgs2, State}.
parse_incoming(<<>>, Packets, State) ->
{Packets, State};
@ -668,12 +711,6 @@ parse_incoming(Data, Packets, State = #state{parse_state = ParseState}) ->
{[{frame_error, Reason} | Packets], State}
end.
-compile({inline, [next_incoming_msgs/1]}).
next_incoming_msgs([Packet]) ->
{incoming, Packet};
next_incoming_msgs(Packets) ->
[{incoming, Packet} || Packet <- lists:reverse(Packets)].
%%--------------------------------------------------------------------
%% Handle incoming packet
@ -810,20 +847,82 @@ handle_cast(Req, State) ->
State.
%%--------------------------------------------------------------------
%% Ensure rate limit
%% rate limit
ensure_rate_limit(Stats, State = #state{limiter = Limiter}) ->
case ?ENABLED(Limiter) andalso emqx_limiter:check(Stats, Limiter) of
false -> State;
{ok, Limiter1} ->
State#state{limiter = Limiter1};
{pause, Time, Limiter1} ->
?SLOG(warning, #{msg => "pause_time_due_to_rate_limit", time_in_ms => Time}),
TRef = start_timer(Time, limit_timeout),
State#state{sockstate = blocked,
limiter = Limiter1,
limit_timer = TRef
}
-type limiter_type() :: emqx_limiter_container:limiter_type().
-type limiter() :: emqx_limiter_container:limiter().
-type check_succ_handler() ::
fun((any(), list(any()), state()) -> _).
%% check limiters, if successed call WhenOk with Data and Msgs
%% Data is the data to be processed
%% Msgs include the next msg which after Data processed
-spec check_limiter(list({pos_integer(), limiter_type()}),
any(),
check_succ_handler(),
list(any()),
state()) -> _.
check_limiter(Needs,
Data,
WhenOk,
Msgs,
#state{limiter = Limiter,
limiter_timer = LimiterTimer,
limiter_cache = Cache} = State) when Limiter =/= undefined ->
case LimiterTimer of
undefined ->
case emqx_limiter_container:check_list(Needs, Limiter) of
{ok, Limiter2} ->
WhenOk(Data, Msgs, State#state{limiter = Limiter2});
{pause, Time, Limiter2} ->
?SLOG(warning, #{msg => "pause time dueto rate limit",
needs => Needs,
time_in_ms => Time}),
Retry = #retry{types = [Type || {_, Type} <- Needs],
data = Data,
next = WhenOk},
Limiter3 = emqx_limiter_container:set_retry_context(Retry, Limiter2),
TRef = start_timer(Time, limit_timeout),
{ok, State#state{limiter = Limiter3,
limiter_timer = TRef}};
{drop, Limiter2} ->
{ok, State#state{limiter = Limiter2}}
end;
_ ->
%% if there has a retry timer, cache the operation and execute it after the retry is over
%% TODO: maybe we need to set socket to passive if size of queue is very large
%% because we queue up lots of ops that checks with the limiters.
New = #cache{need = Needs, data = Data, next = WhenOk},
{ok, State#state{limiter_cache = queue:in(New, Cache)}}
end;
check_limiter(_, Data, WhenOk, Msgs, State) ->
WhenOk(Data, Msgs, State).
%% try to perform a retry
-spec retry_limiter(state()) -> _.
retry_limiter(#state{limiter = Limiter} = State) ->
#retry{types = Types, data = Data, next = Next} = emqx_limiter_container:get_retry_context(Limiter),
case emqx_limiter_container:retry_list(Types, Limiter) of
{ok, Limiter2} ->
Next(Data,
[check_cache],
State#state{ limiter = Limiter2
, limiter_timer = undefined
});
{pause, Time, Limiter2} ->
?SLOG(warning, #{msg => "pause time dueto rate limit",
types => Types,
time_in_ms => Time}),
TRef = start_timer(Time, limit_timeout),
{ok, State#state{limiter = Limiter2,
limiter_timer = TRef}}
end.
%%--------------------------------------------------------------------
@ -852,19 +951,25 @@ check_oom(State = #state{channel = Channel}) ->
%%--------------------------------------------------------------------
%% Activate Socket
%% TODO: maybe we could keep socket passive for receiving socket closed event.
-compile({inline, [activate_socket/1]}).
activate_socket(State = #state{sockstate = closed}) ->
{ok, State};
activate_socket(State = #state{sockstate = blocked}) ->
{ok, State};
activate_socket(State = #state{transport = Transport, socket = Socket,
listener = {Type, Listener}}) ->
activate_socket(#state{limiter_timer = Timer} = State)
when Timer =/= undefined ->
{ok, State#state{sockstate = blocked}};
activate_socket(#state{transport = Transport,
sockstate = SockState,
socket = Socket,
listener = {Type, Listener}} = State)
when SockState =/= closed ->
ActiveN = get_active_n(Type, Listener),
case Transport:setopts(Socket, [{active, ActiveN}]) of
ok -> {ok, State#state{sockstate = running}};
Error -> Error
end.
end;
activate_socket(State) ->
{ok, State}.
%%--------------------------------------------------------------------
%% Close Socket
@ -943,6 +1048,6 @@ get_state(Pid) ->
maps:from_list(lists:zip(record_info(fields, state),
tl(tuple_to_list(State)))).
get_active_n(quic, _Listener) -> 100;
get_active_n(quic, _Listener) -> ?ACTIVE_N;
get_active_n(Type, Listener) ->
emqx_config:get_listener_conf(Type, Listener, [tcp, active_n]).

View File

@ -0,0 +1,52 @@
##--------------------------------------------------------------------
## Emq X Rate Limiter
##--------------------------------------------------------------------
emqx_limiter {
bytes_in {
global.rate = infinity # token generation rate
zone.default.rate = infinity
bucket.default {
zone = default
aggregated.rate = infinity
aggregated.capacity = infinity
per_client.rate = infinity
per_client.capacity = infinity
}
}
message_in {
global.rate = infinity
zone.default.rate = infinity
bucket.default {
zone = default
aggregated.rate = infinity
aggregated.capacity = infinity
per_client.rate = infinity
per_client.capacity = infinity
}
}
connection {
global.rate = infinity
zone.default.rate = infinity
bucket.default {
zone = default
aggregated.rate = infinity
aggregated.capacity = infinity
per_client.rate = infinity
per_client.capacity = infinity
}
}
message_routing {
global.rate = infinity
zone.default.rate = infinity
bucket.default {
zone = default
aggregated.rate = infinity
aggregated.capacity = infinity
per_client.rate = infinity
per_client.capacity = infinity
}
}
}

View File

@ -0,0 +1,358 @@
%%--------------------------------------------------------------------
%% Copyright (c) 2021 EMQ Technologies Co., Ltd. All Rights Reserved.
%%
%% Licensed under the Apache License, Version 2.0 (the "License");
%% you may not use this file except in compliance with the License.
%% You may obtain a copy of the License at
%%
%% http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing, software
%% distributed under the License is distributed on an "AS IS" BASIS,
%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
%% See the License for the specific language governing permissions and
%% limitations under the License.
%%--------------------------------------------------------------------
-module(emqx_htb_limiter).
%% @doc the limiter of the hierarchical token limiter system
%% this module provides api for creating limiters, consume tokens, check tokens and retry
%% @end
%% API
-export([ make_token_bucket_limiter/2, make_ref_limiter/2, check/2
, consume/2, set_retry/2, retry/1, make_infinity_limiter/1
, make_future/1, available/1
]).
-export_type([token_bucket_limiter/0]).
%% a token bucket limiter with a limiter server's bucket reference
-type token_bucket_limiter() :: #{ tokens := non_neg_integer() %% the number of tokens currently available
, rate := decimal()
, capacity := decimal()
, lasttime := millisecond()
, max_retry_time := non_neg_integer() %% @see emqx_limiter_schema
, failure_strategy := failure_strategy() %% @see emqx_limiter_schema
, divisible := boolean() %% @see emqx_limiter_schema
, low_water_mark := non_neg_integer() %% @see emqx_limiter_schema
, bucket := bucket() %% the limiter server's bucket
%% retry contenxt
, retry_ctx => undefined %% undefined meaning there is no retry context or no need to retry
| retry_context(token_bucket_limiter()) %% the retry context
, atom => any() %% allow to add other keys
}.
%% a limiter server's bucket reference
-type ref_limiter() :: #{ max_retry_time := non_neg_integer()
, failure_strategy := failure_strategy()
, divisible := boolean()
, low_water_mark := non_neg_integer()
, bucket := bucket()
, retry_ctx => undefined | retry_context(ref_limiter())
, atom => any() %% allow to add other keys
}.
-type retry_fun(Limiter) :: fun((pos_integer(), Limiter) -> inner_check_result(Limiter)).
-type acquire_type(Limiter) :: integer() | retry_context(Limiter).
-type retry_context(Limiter) :: #{ continuation := undefined | retry_fun(Limiter)
, diff := non_neg_integer() %% how many tokens are left to obtain
, need => pos_integer()
, start => millisecond()
}.
-type bucket() :: emqx_limiter_bucket_ref:bucket_ref().
-type limiter() :: token_bucket_limiter() | ref_limiter() | infinity.
-type millisecond() :: non_neg_integer().
-type pause_type() :: pause | partial.
-type check_result_ok(Limiter) :: {ok, Limiter}.
-type check_result_pause(Limiter) :: {pause_type(), millisecond(), retry_context(Limiter), Limiter}.
-type result_drop(Limiter) :: {drop, Limiter}.
-type check_result(Limiter) :: check_result_ok(Limiter)
| check_result_pause(Limiter)
| result_drop(Limiter).
-type inner_check_result(Limiter) :: check_result_ok(Limiter)
| check_result_pause(Limiter).
-type consume_result(Limiter) :: check_result_ok(Limiter)
| result_drop(Limiter).
-type decimal() :: emqx_limiter_decimal:decimal().
-type failure_strategy() :: emqx_limiter_schema:failure_strategy().
-type limiter_bucket_cfg() :: #{ rate := decimal()
, initial := non_neg_integer()
, low_water_mark := non_neg_integer()
, capacity := decimal()
, divisible := boolean()
, max_retry_time := non_neg_integer()
, failure_strategy := failure_strategy()
}.
-type future() :: pos_integer().
-define(NOW, erlang:monotonic_time(millisecond)).
-define(MINIMUM_PAUSE, 50).
-define(MAXIMUM_PAUSE, 5000).
-import(emqx_limiter_decimal, [sub/2, mul/2, floor_div/2, add/2]).
%%--------------------------------------------------------------------
%% API
%%--------------------------------------------------------------------
%%@doc create a limiter
-spec make_token_bucket_limiter(limiter_bucket_cfg(), bucket()) -> _.
make_token_bucket_limiter(Cfg, Bucket) ->
Cfg#{ tokens => emqx_limiter_server:get_initial_val(Cfg)
, lasttime => ?NOW
, bucket => Bucket
}.
%%@doc create a limiter server's reference
-spec make_ref_limiter(limiter_bucket_cfg(), bucket()) -> ref_limiter().
make_ref_limiter(Cfg, Bucket) when Bucket =/= infinity ->
Cfg#{bucket => Bucket}.
-spec make_infinity_limiter(limiter_bucket_cfg()) -> infinity.
make_infinity_limiter(_) ->
infinity.
%% @doc request some tokens
%% it will automatically retry when failed until the maximum retry time is reached
%% @end
-spec consume(integer(), Limiter) -> consume_result(Limiter)
when Limiter :: limiter().
consume(Need, #{max_retry_time := RetryTime} = Limiter) when Need > 0 ->
try_consume(RetryTime, Need, Limiter);
consume(_, Limiter) ->
{ok, Limiter}.
%% @doc try to request the token and return the result without automatically retrying
-spec check(acquire_type(Limiter), Limiter) -> check_result(Limiter)
when Limiter :: limiter().
check(_, infinity) ->
{ok, infinity};
check(Need, Limiter) when is_integer(Need), Need > 0 ->
case do_check(Need, Limiter) of
{ok, _} = Done ->
Done;
{PauseType, Pause, Ctx, Limiter2} ->
{PauseType,
Pause,
Ctx#{start => ?NOW, need => Need}, Limiter2}
end;
%% check with retry context.
%% when continuation = undefined, the diff will be 0
%% so there is no need to check continuation here
check(#{continuation := Cont,
diff := Diff,
start := Start} = Retry,
#{failure_strategy := Failure,
max_retry_time := RetryTime} = Limiter) when Diff > 0 ->
case Cont(Diff, Limiter) of
{ok, _} = Done ->
Done;
{PauseType, Pause, Ctx, Limiter2} ->
IsFailed = ?NOW - Start >= RetryTime,
Retry2 = maps:merge(Retry, Ctx),
case IsFailed of
false ->
{PauseType, Pause, Retry2, Limiter2};
_ ->
on_failure(Failure, try_restore(Retry2, Limiter2))
end
end;
check(_, Limiter) ->
{ok, Limiter}.
%% @doc pack the retry context into the limiter data
-spec set_retry(retry_context(Limiter), Limiter) -> Limiter
when Limiter :: limiter().
set_retry(Retry, Limiter) ->
Limiter#{retry_ctx => Retry}.
%% @doc check if there is a retry context, and try again if there is
-spec retry(Limiter) -> check_result(Limiter) when Limiter :: limiter().
retry(#{retry_ctx := Retry} = Limiter) when is_map(Retry) ->
check(Retry, Limiter#{retry_ctx := undefined});
retry(Limiter) ->
{ok, Limiter}.
%% @doc make a future value
%% this similar to retry context, but represents a value that will be checked in the future
%% @end
-spec make_future(pos_integer()) -> future().
make_future(Need) ->
Need.
%% @doc get the number of tokens currently available
-spec available(limiter()) -> decimal().
available(#{tokens := Tokens,
rate := Rate,
lasttime := LastTime,
capacity := Capacity,
bucket := Bucket}) ->
Tokens2 = apply_elapsed_time(Rate, ?NOW - LastTime, Tokens, Capacity),
erlang:min(Tokens2, emqx_limiter_bucket_ref:available(Bucket));
available(#{bucket := Bucket}) ->
emqx_limiter_bucket_ref:available(Bucket);
available(infinity) ->
infinity.
%%--------------------------------------------------------------------
%% Internal functions
%%--------------------------------------------------------------------
-spec try_consume(millisecond(),
acquire_type(Limiter),
Limiter) -> consume_result(Limiter) when Limiter :: limiter().
try_consume(LeftTime, Retry, #{failure_strategy := Failure} = Limiter)
when LeftTime =< 0, is_map(Retry) ->
on_failure(Failure, try_restore(Retry, Limiter));
try_consume(LeftTime, Need, Limiter) when is_integer(Need) ->
case do_check(Need, Limiter) of
{ok, _} = Done ->
Done;
{_, Pause, Ctx, Limiter2} ->
timer:sleep(erlang:min(LeftTime, Pause)),
try_consume(LeftTime - Pause, Ctx#{need => Need}, Limiter2)
end;
try_consume(LeftTime,
#{continuation := Cont,
diff := Diff} = Retry, Limiter) when Diff > 0 ->
case Cont(Diff, Limiter) of
{ok, _} = Done ->
Done;
{_, Pause, Ctx, Limiter2} ->
timer:sleep(erlang:min(LeftTime, Pause)),
try_consume(LeftTime - Pause, maps:merge(Retry, Ctx), Limiter2)
end;
try_consume(_, _, Limiter) ->
{ok, Limiter}.
-spec do_check(acquire_type(Limiter), Limiter) -> inner_check_result(Limiter)
when Limiter :: limiter().
do_check(Need, #{tokens := Tokens} = Limiter) ->
if Need =< Tokens ->
do_check_with_parent_limiter(Need, Limiter);
true ->
do_reset(Need, Limiter)
end;
do_check(Need, #{divisible := Divisible,
bucket := Bucket} = Ref) ->
case emqx_limiter_bucket_ref:check(Need, Bucket, Divisible) of
{ok, Tokens} ->
may_return_or_pause(Tokens, Ref);
{PauseType, Rate, Obtained} ->
return_pause(Rate,
PauseType,
fun ?FUNCTION_NAME/2, Need - Obtained, Ref)
end.
on_failure(force, Limiter) ->
{ok, Limiter};
on_failure(drop, Limiter) ->
{drop, Limiter};
on_failure(throw, Limiter) ->
Message = io_lib:format("limiter consume failed, limiter:~p~n", [Limiter]),
erlang:throw({rate_check_fail, Message}).
-spec do_check_with_parent_limiter(pos_integer(), token_bucket_limiter()) -> inner_check_result(token_bucket_limiter()).
do_check_with_parent_limiter(Need,
#{tokens := Tokens,
divisible := Divisible,
bucket := Bucket} = Limiter) ->
case emqx_limiter_bucket_ref:check(Need, Bucket, Divisible) of
{ok, RefLeft} ->
Left = sub(Tokens, Need),
may_return_or_pause(erlang:min(RefLeft, Left), Limiter#{tokens := Left});
{PauseType, Rate, Obtained} ->
return_pause(Rate,
PauseType,
fun ?FUNCTION_NAME/2,
Need - Obtained,
Limiter#{tokens := sub(Tokens, Obtained)})
end.
-spec do_reset(pos_integer(), token_bucket_limiter()) -> inner_check_result(token_bucket_limiter()).
do_reset(Need,
#{tokens := Tokens,
rate := Rate,
lasttime := LastTime,
divisible := Divisible,
capacity := Capacity} = Limiter) ->
Now = ?NOW,
Tokens2 = apply_elapsed_time(Rate, Now - LastTime, Tokens, Capacity),
if Tokens2 >= Need ->
Limiter2 = Limiter#{tokens := Tokens2, lasttime := Now},
do_check_with_parent_limiter(Need, Limiter2);
Divisible andalso Tokens2 > 0 ->
%% must be allocated here, because may be Need > Capacity
return_pause(Rate,
partial,
fun do_reset/2,
Need - Tokens2,
Limiter#{tokens := 0, lasttime := Now});
true ->
return_pause(Rate, pause, fun do_reset/2, Need, Limiter)
end.
-spec return_pause(decimal(), pause_type(), retry_fun(Limiter), pos_integer(), Limiter)
-> check_result_pause(Limiter) when Limiter :: limiter().
return_pause(infinity, PauseType, Fun, Diff, Limiter) ->
%% workaround when emqx_limiter_server's rate is infinity
{PauseType, ?MINIMUM_PAUSE, make_retry_context(Fun, Diff), Limiter};
return_pause(Rate, PauseType, Fun, Diff, Limiter) ->
Val = erlang:round(Diff * emqx_limiter_schema:minimum_period() / Rate),
Pause = emqx_misc:clamp(Val, ?MINIMUM_PAUSE, ?MAXIMUM_PAUSE),
{PauseType, Pause, make_retry_context(Fun, Diff), Limiter}.
-spec make_retry_context(undefined | retry_fun(Limiter), non_neg_integer()) -> retry_context(Limiter)
when Limiter :: limiter().
make_retry_context(Fun, Diff) ->
#{continuation => Fun, diff => Diff}.
-spec try_restore(retry_context(Limiter), Limiter) -> Limiter
when Limiter :: limiter().
try_restore(#{need := Need, diff := Diff},
#{tokens := Tokens, capcacity := Capacity, bucket := Bucket} = Limiter) ->
Back = Need - Diff,
Tokens2 = erlang:min(Capacity, Back + Tokens),
emqx_limiter_bucket_ref:try_restore(Back, Bucket),
Limiter#{tokens := Tokens2};
try_restore(#{need := Need, diff := Diff}, #{bucket := Bucket} = Limiter) ->
emqx_limiter_bucket_ref:try_restore(Need - Diff, Bucket),
Limiter.
-spec may_return_or_pause(non_neg_integer(), Limiter) -> check_result(Limiter)
when Limiter :: limiter().
may_return_or_pause(Left, #{low_water_mark := Mark} = Limiter) when Left >= Mark ->
{ok, Limiter};
may_return_or_pause(_, Limiter) ->
{pause, ?MINIMUM_PAUSE, make_retry_context(undefined, 0), Limiter}.
%% @doc apply the elapsed time to the limiter
apply_elapsed_time(Rate, Elapsed, Tokens, Capacity) ->
Inc = floor_div(mul(Elapsed, Rate), emqx_limiter_schema:minimum_period()),
erlang:min(add(Tokens, Inc), Capacity).

View File

@ -9,7 +9,5 @@
{env, []},
{licenses, ["Apache-2.0"]},
{maintainers, ["EMQ X Team <contact@emqx.io>"]},
{links, [{"Homepage", "https://emqx.io/"},
{"Github", "https://github.com/emqx/emqx-retainer"}
]}
{links, []}
]}.

View File

@ -1,5 +1,5 @@
%%--------------------------------------------------------------------
%% Copyright (c) 2020-2021 EMQ Technologies Co., Ltd. All Rights Reserved.
%% Copyright (c) 2021 EMQ Technologies Co., Ltd. All Rights Reserved.
%%
%% Licensed under the Apache License, Version 2.0 (the "License");
%% you may not use this file except in compliance with the License.

View File

@ -0,0 +1,102 @@
%%--------------------------------------------------------------------
%% Copyright (c) 2021 EMQ Technologies Co., Ltd. All Rights Reserved.
%%
%% Licensed under the Apache License, Version 2.0 (the "License");
%% you may not use this file except in compliance with the License.
%% You may obtain a copy of the License at
%%
%% http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing, software
%% distributed under the License is distributed on an "AS IS" BASIS,
%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
%% See the License for the specific language governing permissions and
%% limitations under the License.
%%--------------------------------------------------------------------
-module(emqx_limiter_bucket_ref).
%% @doc limiter bucket reference
%% this module is used to manage the bucket reference of the limiter server
%% @end
%% API
-export([ new/3, check/3, try_restore/2
, available/1]).
-export_type([bucket_ref/0]).
-type infinity_bucket_ref() :: infinity.
-type finite_bucket_ref() :: #{ counter := counters:counters_ref()
, index := index()
, rate := rate()}.
-type bucket_ref() :: infinity_bucket_ref()
| finite_bucket_ref().
-type index() :: emqx_limiter_server:index().
-type rate() :: emqx_limiter_decimal:decimal().
-type check_failure_type() :: partial | pause.
%%--------------------------------------------------------------------
%% API
%%--------------------------------------------------------------------
-spec new(undefined | counters:countres_ref(),
undefined | index(),
rate()) -> bucket_ref().
new(undefined, _, _) ->
infinity;
new(Counter, Index, Rate) ->
#{counter => Counter,
index => Index,
rate => Rate}.
%% @doc check tokens
-spec check(pos_integer(), bucket_ref(), Disivisble :: boolean()) ->
HasToken :: {ok, emqx_limiter_decimal:decimal()}
| {check_failure_type(), rate(), pos_integer()}.
check(_, infinity, _) ->
{ok, infinity};
check(Need,
#{counter := Counter,
index := Index,
rate := Rate},
Divisible)->
RefToken = counters:get(Counter, Index),
if RefToken >= Need ->
counters:sub(Counter, Index, Need),
{ok, RefToken - Need};
Divisible andalso RefToken > 0 ->
counters:sub(Counter, Index, RefToken),
{partial, Rate, RefToken};
true ->
{pause, Rate, 0}
end.
%% @doc try to restore token when consume failed
-spec try_restore(non_neg_integer(), bucket_ref()) -> ok.
try_restore(0, _) ->
ok;
try_restore(_, infinity) ->
ok;
try_restore(Inc, #{counter := Counter, index := Index}) ->
case counters:get(Counter, Index) of
Tokens when Tokens < 0 ->
counters:add(Counter, Index, Inc);
_ ->
ok
end.
%% @doc get the number of tokens currently available
-spec available(bucket_ref()) -> emqx_limiter_decimal:decimal().
available(#{counter := Counter, index := Index}) ->
counters:get(Counter, Index);
available(infinity) ->
infinity.
%%--------------------------------------------------------------------
%% Internal functions
%%--------------------------------------------------------------------

View File

@ -0,0 +1,157 @@
%%--------------------------------------------------------------------
%% Copyright (c) 2021 EMQ Technologies Co., Ltd. All Rights Reserved.
%%
%% Licensed under the Apache License, Version 2.0 (the "License");
%% you may not use this file except in compliance with the License.
%% You may obtain a copy of the License at
%%
%% http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing, software
%% distributed under the License is distributed on an "AS IS" BASIS,
%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
%% See the License for the specific language governing permissions and
%% limitations under the License.
%%--------------------------------------------------------------------
-module(emqx_limiter_container).
%% @doc the container of emqx_htb_limiter
%% used to merge limiters of different type of limiters to simplify operations
%% @end
%% API
-export([ new/0, new/1, get_limiter_by_names/2
, add_new/3, update_by_name/3, set_retry_context/2
, check/3, retry/2, get_retry_context/1
, check_list/2, retry_list/2
]).
-export_type([container/0, check_result/0]).
-type container() :: #{ limiter_type() => undefined | limiter()
, retry_key() => undefined | retry_context() | future() %% the retry context of the limiter
, retry_ctx := undefined | any() %% the retry context of the container
}.
-type future() :: pos_integer().
-type limiter_type() :: emqx_limiter_schema:limiter_type().
-type limiter() :: emqx_htb_limiter:limiter().
-type retry_context() :: emqx_htb_limiter:retry_context().
-type bucket_name() :: emqx_limiter_schema:bucket_name().
-type millisecond() :: non_neg_integer().
-type check_result() :: {ok, container()}
| {drop, container()}
| {pause, millisecond(), container()}.
-define(RETRY_KEY(Type), {retry, Type}).
-type retry_key() :: ?RETRY_KEY(limiter_type()).
%%--------------------------------------------------------------------
%% API
%%--------------------------------------------------------------------
-spec new() -> container().
new() ->
new([]).
%% @doc generate default data according to the type of limiter
-spec new(list(limiter_type())) -> container().
new(Types) ->
get_limiter_by_names(Types, #{}).
%% @doc generate a container
%% according to the type of limiter and the bucket name configuration of the limiter
%% @end
-spec get_limiter_by_names(list(limiter_type()), #{limiter_type() => emqx_limiter_schema:bucket_name()}) -> container().
get_limiter_by_names(Types, BucketNames) ->
Init = fun(Type, Acc) ->
Limiter = emqx_limiter_server:connect(Type, BucketNames),
add_new(Type, Limiter, Acc)
end,
lists:foldl(Init, #{retry_ctx => undefined}, Types).
%% @doc add the specified type of limiter to the container
-spec update_by_name(limiter_type(),
bucket_name() | #{limiter_type() => bucket_name()},
container()) -> container().
update_by_name(Type, Buckets, Container) ->
Limiter = emqx_limiter_server:connect(Type, Buckets),
add_new(Type, Limiter, Container).
-spec add_new(limiter_type(), limiter(), container()) -> container().
add_new(Type, Limiter, Container) ->
Container#{ Type => Limiter
, ?RETRY_KEY(Type) => undefined
}.
%% @doc check the specified limiter
-spec check(pos_integer(), limiter_type(), container()) -> check_result().
check(Need, Type, Container) ->
check_list([{Need, Type}], Container).
%% @doc check multiple limiters
-spec check_list(list({pos_integer(), limiter_type()}), container()) -> check_result().
check_list([{Need, Type} | T], Container) ->
Limiter = maps:get(Type, Container),
case emqx_htb_limiter:check(Need, Limiter) of
{ok, Limiter2} ->
check_list(T, Container#{Type := Limiter2});
{_, PauseMs, Ctx, Limiter2} ->
Fun = fun({FN, FT}, Acc) ->
Future = emqx_htb_limiter:make_future(FN),
Acc#{?RETRY_KEY(FT) := Future}
end,
C2 = lists:foldl(Fun,
Container#{Type := Limiter2,
?RETRY_KEY(Type) := Ctx},
T),
{pause, PauseMs, C2};
{drop, Limiter2} ->
{drop, Container#{Type := Limiter2}}
end;
check_list([], Container) ->
{ok, Container}.
%% @doc retry the specified limiter
-spec retry(limiter_type(), container()) -> check_result().
retry(Type, Container) ->
retry_list([Type], Container).
%% @doc retry multiple limiters
-spec retry_list(list(limiter_type()), container()) -> check_result().
retry_list([Type | T], Container) ->
Key = ?RETRY_KEY(Type),
case Container of
#{Type := Limiter,
Key := Retry} when Retry =/= undefined ->
case emqx_htb_limiter:check(Retry, Limiter) of
{ok, Limiter2} ->
%% undefined meaning there is no retry context or there is no need to retry
%% when a limiter has a undefined retry context, the check will always success
retry_list(T, Container#{Type := Limiter2, Key := undefined});
{_, PauseMs, Ctx, Limiter2} ->
{pause,
PauseMs,
Container#{Type := Limiter2, Key := Ctx}};
{drop, Limiter2} ->
{drop, Container#{Type := Limiter2}}
end;
_ ->
retry_list(T, Container)
end;
retry_list([], Container) ->
{ok, Container}.
-spec set_retry_context(any(), container()) -> container().
set_retry_context(Data, Container) ->
Container#{retry_ctx := Data}.
-spec get_retry_context(container()) -> any().
get_retry_context(#{retry_ctx := Data}) ->
Data.
%%--------------------------------------------------------------------
%% Internal functions
%%--------------------------------------------------------------------

View File

@ -20,7 +20,7 @@
%% API
-export([ add/2, sub/2, mul/2
, add_to_counter/3, put_to_counter/3]).
, add_to_counter/3, put_to_counter/3, floor_div/2]).
-export_type([decimal/0, zero_or_float/0]).
-type decimal() :: infinity | number().
@ -53,6 +53,13 @@ mul(A, B) when A =:= infinity
mul(A, B) ->
A * B.
-spec floor_div(decimal(), number()) -> decimal().
floor_div(infinity, _) ->
infinity;
floor_div(A, B) ->
erlang:floor(A / B).
-spec add_to_counter(counters:counters_ref(), pos_integer(), decimal()) ->
{zero_or_float(), zero_or_float()}.
add_to_counter(_, _, infinity) ->

View File

@ -22,29 +22,27 @@
-include_lib("stdlib/include/ms_transform.hrl").
%% API
-export([ start_link/0, start_server/1, find_counter/1
, find_counter/3, insert_counter/4, insert_counter/6
-export([ start_link/0, start_server/1, find_bucket/1
, find_bucket/3, insert_bucket/2, insert_bucket/4
, make_path/3, restart_server/1]).
%% gen_server callbacks
-export([init/1, handle_call/3, handle_cast/2, handle_info/2,
terminate/2, code_change/3, format_status/2]).
-export_type([path/0]).
-type path() :: list(atom()).
-type limiter_type() :: emqx_limiter_schema:limiter_type().
-type zone_name() :: emqx_limiter_schema:zone_name().
-type bucket_name() :: emqx_limiter_schema:bucket_name().
%% counter record in ets table
-record(element, {path :: path(),
counter :: counters:counters_ref(),
index :: index(),
rate :: rate()
}).
-record(bucket, { path :: path()
, bucket :: bucket_ref()
}).
-type index() :: emqx_limiter_server:index().
-type rate() :: emqx_limiter_decimal:decimal().
-type bucket_ref() :: emqx_limiter_bucket_ref:bucket_ref().
-define(TAB, emqx_limiter_counters).
@ -59,43 +57,32 @@ start_server(Type) ->
restart_server(Type) ->
emqx_limiter_server_sup:restart(Type).
-spec find_counter(limiter_type(), zone_name(), bucket_name()) ->
{ok, counters:counters_ref(), index(), rate()} | undefined.
find_counter(Type, Zone, BucketId) ->
find_counter(make_path(Type, Zone, BucketId)).
-spec find_bucket(limiter_type(), zone_name(), bucket_name()) ->
{ok, bucket_ref()} | undefined.
find_bucket(Type, Zone, BucketId) ->
find_bucket(make_path(Type, Zone, BucketId)).
-spec find_counter(path()) ->
{ok, counters:counters_ref(), index(), rate()} | undefined.
find_counter(Path) ->
-spec find_bucket(path()) -> {ok, bucket_ref()} | undefined.
find_bucket(Path) ->
case ets:lookup(?TAB, Path) of
[#element{counter = Counter, index = Index, rate = Rate}] ->
{ok, Counter, Index, Rate};
[#bucket{bucket = Bucket}] ->
{ok, Bucket};
_ ->
undefined
end.
-spec insert_counter(limiter_type(),
zone_name(),
bucket_name(),
counters:counters_ref(),
index(),
rate()) -> boolean().
insert_counter(Type, Zone, BucketId, Counter, Index, Rate) ->
insert_counter(make_path(Type, Zone, BucketId),
Counter,
Index,
Rate).
-spec insert_bucket(limiter_type(),
zone_name(),
bucket_name(),
bucket_ref()) -> boolean().
insert_bucket(Type, Zone, BucketId, Bucket) ->
inner_insert_bucket(make_path(Type, Zone, BucketId),
Bucket).
-spec insert_counter(path(),
counters:counters_ref(),
index(),
rate()) -> boolean().
insert_counter(Path, Counter, Index, Rate) ->
ets:insert(?TAB,
#element{path = Path,
counter = Counter,
index = Index,
rate = Rate}).
-spec insert_bucket(path(), bucket_ref()) -> true.
insert_bucket(Path, Bucket) ->
inner_insert_bucket(Path, Bucket).
-spec make_path(limiter_type(), zone_name(), bucket_name()) -> path().
make_path(Type, Name, BucketId) ->
@ -129,7 +116,7 @@ start_link() ->
{stop, Reason :: term()} |
ignore.
init([]) ->
_ = ets:new(?TAB, [ set, public, named_table, {keypos, #element.path}
_ = ets:new(?TAB, [ set, public, named_table, {keypos, #bucket.path}
, {write_concurrency, true}, {read_concurrency, true}
, {heir, erlang:whereis(emqx_limiter_sup), none}
]),
@ -227,3 +214,7 @@ format_status(_Opt, Status) ->
%%--------------------------------------------------------------------
%% Internal functions
%%--------------------------------------------------------------------
-spec inner_insert_bucket(path(), bucket_ref()) -> true.
inner_insert_bucket(Path, Bucket) ->
ets:insert(?TAB,
#bucket{path = Path, bucket = Bucket}).

View File

@ -0,0 +1,176 @@
%%--------------------------------------------------------------------
%% Copyright (c) 2021 EMQ Technologies Co., Ltd. All Rights Reserved.
%%
%% Licensed under the Apache License, Version 2.0 (the "License");
%% you may not use this file except in compliance with the License.
%% You may obtain a copy of the License at
%%
%% http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing, software
%% distributed under the License is distributed on an "AS IS" BASIS,
%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
%% See the License for the specific language governing permissions and
%% limitations under the License.
%%--------------------------------------------------------------------
-module(emqx_limiter_schema).
-include_lib("typerefl/include/types.hrl").
-export([ roots/0, fields/1, to_rate/1, to_capacity/1
, minimum_period/0, to_burst_rate/1, to_initial/1]).
-define(KILOBYTE, 1024).
-type limiter_type() :: bytes_in
| message_in
| connection
| message_routing.
-type bucket_name() :: atom().
-type zone_name() :: atom().
-type rate() :: infinity | float().
-type burst_rate() :: 0 | float().
-type capacity() :: infinity | number(). %% the capacity of the token bucket
-type initial() :: non_neg_integer(). %% initial capacity of the token bucket
%% the processing strategy after the failure of the token request
-type failure_strategy() :: force %% Forced to pass
| drop %% discard the current request
| throw. %% throw an exception
-typerefl_from_string({rate/0, ?MODULE, to_rate}).
-typerefl_from_string({burst_rate/0, ?MODULE, to_burst_rate}).
-typerefl_from_string({capacity/0, ?MODULE, to_capacity}).
-typerefl_from_string({initial/0, ?MODULE, to_initial}).
-reflect_type([ rate/0
, burst_rate/0
, capacity/0
, initial/0
, failure_strategy/0
]).
-export_type([limiter_type/0, bucket_name/0, zone_name/0]).
-import(emqx_schema, [sc/2, map/2]).
roots() -> [emqx_limiter].
fields(emqx_limiter) ->
[ {bytes_in, sc(ref(limiter), #{})}
, {message_in, sc(ref(limiter), #{})}
, {connection, sc(ref(limiter), #{})}
, {message_routing, sc(ref(limiter), #{})}
];
fields(limiter) ->
[ {global, sc(ref(rate_burst), #{})}
, {zone, sc(map("zone name", ref(rate_burst)), #{})}
, {bucket, sc(map("bucket id", ref(bucket)),
#{desc => "token bucket"})}
];
fields(rate_burst) ->
[ {rate, sc(rate(), #{})}
, {burst, sc(burst_rate(), #{default => "0/0s"})}
];
fields(bucket) ->
[ {zone, sc(atom(), #{desc => "the zone which the bucket in"})}
, {aggregated, sc(ref(bucket_aggregated), #{})}
, {per_client, sc(ref(client_bucket), #{})}
];
fields(bucket_aggregated) ->
[ {rate, sc(rate(), #{})}
, {initial, sc(initial(), #{default => "0"})}
, {capacity, sc(capacity(), #{})}
];
fields(client_bucket) ->
[ {rate, sc(rate(), #{})}
, {initial, sc(initial(), #{default => "0"})}
%% low_water_mark add for emqx_channel and emqx_session
%% both modules consume first and then check
%% so we need to use this value to prevent excessive consumption (e.g, consumption from an empty bucket)
, {low_water_mark, sc(initial(),
#{desc => "if the remaining tokens are lower than this value,
the check/consume will succeed, but it will be forced to hang for a short period of time",
default => "0"})}
, {capacity, sc(capacity(), #{desc => "the capacity of the token bucket"})}
, {divisible, sc(boolean(),
#{desc => "is it possible to split the number of tokens requested",
default => false})}
, {max_retry_time, sc(emqx_schema:duration(),
#{ desc => "the maximum retry time when acquire failed"
, default => "5s"})}
, {failure_strategy, sc(failure_strategy(),
#{ desc => "the strategy when all retry failed"
, default => force})}
].
%% minimum period is 100ms
minimum_period() ->
100.
%%--------------------------------------------------------------------
%% Internal functions
%%--------------------------------------------------------------------
ref(Field) -> hoconsc:ref(?MODULE, Field).
to_rate(Str) ->
to_rate(Str, true, false).
to_burst_rate(Str) ->
to_rate(Str, false, true).
to_rate(Str, CanInfinity, CanZero) ->
Tokens = [string:trim(T) || T <- string:tokens(Str, "/")],
case Tokens of
["infinity"] when CanInfinity ->
{ok, infinity};
["0", _] when CanZero ->
{ok, 0}; %% for burst
[Quota, Interval] ->
{ok, Val} = to_capacity(Quota),
case emqx_schema:to_duration_ms(Interval) of
{ok, Ms} when Ms > 0 ->
{ok, Val * minimum_period() / Ms};
_ ->
{error, Str}
end;
_ ->
{error, Str}
end.
to_capacity(Str) ->
Regex = "^\s*(?:(?:([1-9][0-9]*)([a-zA-z]*))|infinity)\s*$",
to_quota(Str, Regex).
to_initial(Str) ->
Regex = "^\s*([0-9]+)([a-zA-z]*)\s*$",
to_quota(Str, Regex).
to_quota(Str, Regex) ->
{ok, MP} = re:compile(Regex),
Result = re:run(Str, MP, [{capture, all_but_first, list}]),
case Result of
{match, [Quota, Unit]} ->
Val = erlang:list_to_integer(Quota),
Unit2 = string:to_lower(Unit),
{ok, apply_unit(Unit2, Val)};
{match, [Quota]} ->
{ok, erlang:list_to_integer(Quota)};
{match, []} ->
{ok, infinity};
_ ->
{error, Str}
end.
apply_unit("", Val) -> Val;
apply_unit("kb", Val) -> Val * ?KILOBYTE;
apply_unit("mb", Val) -> Val * ?KILOBYTE * ?KILOBYTE;
apply_unit("gb", Val) -> Val * ?KILOBYTE * ?KILOBYTE * ?KILOBYTE;
apply_unit(Unit, _) -> throw("invalid unit:" ++ Unit).

View File

@ -0,0 +1,582 @@
%%--------------------------------------------------------------------
%% Copyright (c) 2021 EMQ Technologies Co., Ltd. All Rights Reserved.
%%
%% Licensed under the Apache License, Version 2.0 (the "License");
%% you may not use this file except in compliance with the License.
%% You may obtain a copy of the License at
%%
%% http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing, software
%% distributed under the License is distributed on an "AS IS" BASIS,
%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
%% See the License for the specific language governing permissions and
%% limitations under the License.
%%--------------------------------------------------------------------
%% A hierarchical token bucket algorithm
%% Note: this is not the linux HTB algorithm(http://luxik.cdi.cz/~devik/qos/htb/manual/theory.htm)
%% Algorithm:
%% 1. the root node periodically generates tokens and then distributes them
%% just like the oscillation of water waves
%% 2. the leaf node has a counter, which is the place where the token is actually held.
%% 3. other nodes only play the role of transmission, and the rate of the node is like a valve,
%% limiting the oscillation transmitted from the parent node
-module(emqx_limiter_server).
-behaviour(gen_server).
-include_lib("emqx/include/logger.hrl").
%% gen_server callbacks
-export([init/1, handle_call/3, handle_cast/2, handle_info/2,
terminate/2, code_change/3, format_status/2]).
-export([ start_link/1, connect/2, info/2
, name/1, get_initial_val/1]).
-type root() :: #{ rate := rate() %% number of tokens generated per period
, burst := rate()
, period := pos_integer() %% token generation interval(second)
, childs := list(node_id()) %% node children
, consumed := non_neg_integer()
}.
-type zone() :: #{ id := node_id()
, name := zone_name()
, rate := rate()
, burst := rate()
, obtained := non_neg_integer() %% number of tokens obtained
, childs := list(node_id())
}.
-type bucket() :: #{ id := node_id()
, name := bucket_name()
, zone := zone_name() %% pointer to zone node, use for burst
, rate := rate()
, obtained := non_neg_integer()
, correction := emqx_limiter_decimal:zero_or_float() %% token correction value
, capacity := capacity()
, counter := undefined | counters:counters_ref()
, index := undefined | index()
}.
-type state() :: #{ root := undefined | root()
, counter := undefined | counters:counters_ref() %% current counter to alloc
, index := index()
, zones := #{zone_name() => node_id()}
, buckets := list(node_id())
, nodes := nodes()
, type := limiter_type()
}.
-type node_id() :: pos_integer().
-type node_data() :: zone() | bucket().
-type nodes() :: #{node_id() => node_data()}.
-type zone_name() :: emqx_limiter_schema:zone_name().
-type limiter_type() :: emqx_limiter_schema:limiter_type().
-type bucket_name() :: emqx_limiter_schema:bucket_name().
-type rate() :: decimal().
-type flow() :: decimal().
-type capacity() :: decimal().
-type decimal() :: emqx_limiter_decimal:decimal().
-type index() :: pos_integer().
-define(CALL(Type, Msg), gen_server:call(name(Type), {?FUNCTION_NAME, Msg})).
-define(OVERLOAD_MIN_ALLOC, 0.3). %% minimum coefficient for overloaded limiter
-export_type([index/0]).
-import(emqx_limiter_decimal, [add/2, sub/2, mul/2, add_to_counter/3, put_to_counter/3]).
%%--------------------------------------------------------------------
%% API
%%--------------------------------------------------------------------
-spec connect(limiter_type(),
bucket_name() | #{limiter_type() => bucket_name()}) -> emqx_htb_limiter:limiter().
connect(Type, BucketName) when is_atom(BucketName) ->
Path = [emqx_limiter, Type, bucket, BucketName],
case emqx:get_config(Path, undefined) of
undefined ->
?LOG(error, "can't find the config of this bucket: ~p~n", [Path]),
throw("bucket's config not found");
#{zone := Zone,
aggregated := #{rate := AggrRate, capacity := AggrSize},
per_client := #{rate := CliRate, capacity := CliSize} = Cfg} ->
case emqx_limiter_manager:find_bucket(Type, Zone, BucketName) of
{ok, Bucket} ->
if CliRate < AggrRate orelse CliSize < AggrSize ->
emqx_htb_limiter:make_token_bucket_limiter(Cfg, Bucket);
Bucket =:= infinity ->
emqx_htb_limiter:make_infinity_limiter(Cfg);
true ->
emqx_htb_limiter:make_ref_limiter(Cfg, Bucket)
end;
undefined ->
?LOG(error, "can't find the bucket:~p~n", [Path]),
throw("invalid bucket")
end
end;
connect(Type, Names) ->
connect(Type, maps:get(Type, Names, default)).
-spec info(limiter_type(), atom()) -> term().
info(Type, Info) ->
?CALL(Type, Info).
-spec name(limiter_type()) -> atom().
name(Type) ->
erlang:list_to_atom(io_lib:format("~s_~s", [?MODULE, Type])).
%%--------------------------------------------------------------------
%% @doc
%% Starts the server
%% @end
%%--------------------------------------------------------------------
-spec start_link(limiter_type()) -> _.
start_link(Type) ->
gen_server:start_link({local, name(Type)}, ?MODULE, [Type], []).
%%--------------------------------------------------------------------
%%% gen_server callbacks
%%--------------------------------------------------------------------
%%--------------------------------------------------------------------
%% @private
%% @doc
%% Initializes the server
%% @end
%%--------------------------------------------------------------------
-spec init(Args :: term()) -> {ok, State :: term()} |
{ok, State :: term(), Timeout :: timeout()} |
{ok, State :: term(), hibernate} |
{stop, Reason :: term()} |
ignore.
init([Type]) ->
State = #{root => undefined,
counter => undefined,
index => 1,
zones => #{},
nodes => #{},
buckets => [],
type => Type},
State2 = init_tree(Type, State),
#{root := #{period := Perido}} = State2,
oscillate(Perido),
{ok, State2}.
%%--------------------------------------------------------------------
%% @private
%% @doc
%% Handling call messages
%% @end
%%--------------------------------------------------------------------
-spec handle_call(Request :: term(), From :: {pid(), term()}, State :: term()) ->
{reply, Reply :: term(), NewState :: term()} |
{reply, Reply :: term(), NewState :: term(), Timeout :: timeout()} |
{reply, Reply :: term(), NewState :: term(), hibernate} |
{noreply, NewState :: term()} |
{noreply, NewState :: term(), Timeout :: timeout()} |
{noreply, NewState :: term(), hibernate} |
{stop, Reason :: term(), Reply :: term(), NewState :: term()} |
{stop, Reason :: term(), NewState :: term()}.
handle_call(Req, _From, State) ->
?LOG(error, "Unexpected call: ~p", [Req]),
{reply, ignored, State}.
%%--------------------------------------------------------------------
%% @private
%% @doc
%% Handling cast messages
%% @end
%%--------------------------------------------------------------------
-spec handle_cast(Request :: term(), State :: term()) ->
{noreply, NewState :: term()} |
{noreply, NewState :: term(), Timeout :: timeout()} |
{noreply, NewState :: term(), hibernate} |
{stop, Reason :: term(), NewState :: term()}.
handle_cast(Req, State) ->
?LOG(error, "Unexpected cast: ~p", [Req]),
{noreply, State}.
%%--------------------------------------------------------------------
%% @private
%% @doc
%% Handling all non call/cast messages
%% @end
%%--------------------------------------------------------------------
-spec handle_info(Info :: timeout() | term(), State :: term()) ->
{noreply, NewState :: term()} |
{noreply, NewState :: term(), Timeout :: timeout()} |
{noreply, NewState :: term(), hibernate} |
{stop, Reason :: normal | term(), NewState :: term()}.
handle_info(oscillate, State) ->
{noreply, oscillation(State)};
handle_info(Info, State) ->
?LOG(error, "Unexpected info: ~p", [Info]),
{noreply, State}.
%%--------------------------------------------------------------------
%% @private
%% @doc
%% This function is called by a gen_server when it is about to
%% terminate. It should be the opposite of Module:init/1 and do any
%% necessary cleaning up. When it returns, the gen_server terminates
%% with Reason. The return value is ignored.
%% @end
%%--------------------------------------------------------------------
-spec terminate(Reason :: normal | shutdown | {shutdown, term()} | term(),
State :: term()) -> any().
terminate(_Reason, _State) ->
ok.
%%--------------------------------------------------------------------
%% @private
%% @doc
%% Convert process state when code is changed
%% @end
%%--------------------------------------------------------------------
-spec code_change(OldVsn :: term() | {down, term()},
State :: term(),
Extra :: term()) -> {ok, NewState :: term()} |
{error, Reason :: term()}.
code_change(_OldVsn, State, _Extra) ->
{ok, State}.
%%--------------------------------------------------------------------
%% @private
%% @doc
%% This function is called for changing the form and appearance
%% of gen_server status when it is returned from sys:get_status/1,2
%% or when it appears in termination error logs.
%% @end
%%--------------------------------------------------------------------
-spec format_status(Opt :: normal | terminate,
Status :: list()) -> Status :: term().
format_status(_Opt, Status) ->
Status.
%%--------------------------------------------------------------------
%%% Internal functions
%%--------------------------------------------------------------------
oscillate(Interval) ->
erlang:send_after(Interval, self(), ?FUNCTION_NAME).
%% @doc generate tokens, and then spread to leaf nodes
-spec oscillation(state()) -> state().
oscillation(#{root := #{rate := Flow,
period := Interval,
childs := ChildIds,
consumed := Consumed} = Root,
nodes := Nodes} = State) ->
oscillate(Interval),
Childs = get_ordered_childs(ChildIds, Nodes),
{Alloced, Nodes2} = transverse(Childs, Flow, 0, Nodes),
maybe_burst(State#{nodes := Nodes2,
root := Root#{consumed := Consumed + Alloced}}).
%% @doc horizontal spread
-spec transverse(list(node_data()),
flow(),
non_neg_integer(),
nodes()) -> {non_neg_integer(), nodes()}.
transverse([H | T], InFlow, Alloced, Nodes) when InFlow > 0 ->
{NodeAlloced, Nodes2} = longitudinal(H, InFlow, Nodes),
InFlow2 = sub(InFlow, NodeAlloced),
Alloced2 = Alloced + NodeAlloced,
transverse(T, InFlow2, Alloced2, Nodes2);
transverse(_, _, Alloced, Nodes) ->
{Alloced, Nodes}.
%% @doc vertical spread
-spec longitudinal(node_data(), flow(), nodes()) ->
{non_neg_integer(), nodes()}.
longitudinal(#{id := Id,
rate := Rate,
obtained := Obtained,
childs := ChildIds} = Node, InFlow, Nodes) ->
Flow = erlang:min(InFlow, Rate),
if Flow > 0 ->
Childs = get_ordered_childs(ChildIds, Nodes),
{Alloced, Nodes2} = transverse(Childs, Flow, 0, Nodes),
if Alloced > 0 ->
{Alloced,
Nodes2#{Id => Node#{obtained := Obtained + Alloced}}};
true ->
%% childs are empty or all counter childs are full
{0, Nodes2}
end;
true ->
{0, Nodes}
end;
longitudinal(#{id := Id,
rate := Rate,
capacity := Capacity,
correction := Correction,
counter := Counter,
index := Index,
obtained := Obtained} = Node,
InFlow, Nodes) when Counter =/= undefined ->
Flow = add(erlang:min(InFlow, Rate), Correction),
ShouldAlloc =
case counters:get(Counter, Index) of
Tokens when Tokens < 0 ->
%% toknes's value mayb be a negative value(stolen from the future)
%% because x. add(Capacity, x) < 0, so here we must compare with minimum value
erlang:max(add(Capacity, Tokens),
mul(Capacity, ?OVERLOAD_MIN_ALLOC));
Tokens ->
%% is it possible that Tokens > Capacity ???
erlang:max(sub(Capacity, Tokens), 0)
end,
case lists:min([ShouldAlloc, Flow, Capacity]) of
Avaiable when Avaiable > 0 ->
%% XXX if capacity is infinity, and flow always > 0, the value in counter
%% will be overflow at some point in the future, do we need to deal with this situation???
{Alloced, Decimal} = add_to_counter(Counter, Index, Avaiable),
{Alloced,
Nodes#{Id := Node#{obtained := Obtained + Alloced,
correction := Decimal}}};
_ ->
{0, Nodes}
end;
longitudinal(_, _, Nodes) ->
{0, Nodes}.
-spec get_ordered_childs(list(node_id()), nodes()) -> list(node_data()).
get_ordered_childs(Ids, Nodes) ->
Childs = [maps:get(Id, Nodes) || Id <- Ids],
%% sort by obtained, avoid node goes hungry
lists:sort(fun(#{obtained := A}, #{obtained := B}) ->
A < B
end,
Childs).
-spec maybe_burst(state()) -> state().
maybe_burst(#{buckets := Buckets,
zones := Zones,
root := #{burst := Burst},
nodes := Nodes} = State) when Burst > 0 ->
%% find empty buckets and group by zone name
GroupFun = fun(Id, Groups) ->
#{counter := Counter,
index := Index,
zone := Zone} = maps:get(Id, Nodes),
case counters:get(Counter, Index) of
Any when Any =< 0 ->
Group = maps:get(Zone, Groups, []),
maps:put(Zone, [Id | Group], Groups);
_ ->
Groups
end
end,
case lists:foldl(GroupFun, #{}, Buckets) of
Groups when map_size(Groups) > 0 ->
%% remove the zone which don't support burst
Filter = fun({Name, Childs}, Acc) ->
ZoneId = maps:get(Name, Zones),
#{burst := ZoneBurst} = Zone = maps:get(ZoneId, Nodes),
case ZoneBurst > 0 of
true ->
[{Zone, Childs} | Acc];
_ ->
Acc
end
end,
FilterL = lists:foldl(Filter, [], maps:to_list(Groups)),
dispatch_burst(FilterL, State);
_ ->
State
end;
maybe_burst(State) ->
State.
-spec dispatch_burst(list({zone(), list(node_id())}), state()) -> state().
dispatch_burst([], State) ->
State;
dispatch_burst(GroupL,
#{root := #{burst := Burst},
nodes := Nodes} = State) ->
InFlow = erlang:floor(Burst / erlang:length(GroupL)),
Dispatch = fun({Zone, Childs}, NodeAcc) ->
#{id := ZoneId,
burst := ZoneBurst,
obtained := Obtained} = Zone,
ZoneFlow = erlang:min(InFlow, ZoneBurst),
EachFlow = ZoneFlow div erlang:length(Childs),
Zone2 = Zone#{obtained := Obtained + ZoneFlow},
NodeAcc2 = NodeAcc#{ZoneId := Zone2},
dispatch_burst_to_buckets(Childs, EachFlow, NodeAcc2)
end,
State#{nodes := lists:foldl(Dispatch, Nodes, GroupL)}.
-spec dispatch_burst_to_buckets(list(node_id()),
non_neg_integer(), nodes()) -> nodes().
dispatch_burst_to_buckets(Childs, InFlow, Nodes) ->
Each = fun(ChildId, NodeAcc) ->
#{counter := Counter,
index := Index,
obtained := Obtained} = Bucket = maps:get(ChildId, NodeAcc),
counters:add(Counter, Index, InFlow),
NodeAcc#{ChildId := Bucket#{obtained := Obtained + InFlow}}
end,
lists:foldl(Each, Nodes, Childs).
-spec init_tree(emqx_limiter_schema:limiter_type(), state()) -> state().
init_tree(Type, State) ->
#{global := Global,
zone := Zone,
bucket := Bucket} = emqx:get_config([emqx_limiter, Type]),
{Factor, Root} = make_root(Global, Zone),
State2 = State#{root := Root},
{NodeId, State3} = make_zone(maps:to_list(Zone), Factor, 1, State2),
State4 = State3#{counter := counters:new(maps:size(Bucket),
[write_concurrency])},
make_bucket(maps:to_list(Bucket), Global, Zone, Factor, NodeId, [], State4).
-spec make_root(hocons:confg(), hocon:config()) -> {number(), root()}.
make_root(#{rate := Rate, burst := Burst}, Zone) ->
ZoneNum = maps:size(Zone),
Childs = lists:seq(1, ZoneNum),
MiniPeriod = emqx_limiter_schema:minimum_period(),
if Rate >= 1 ->
{1, #{rate => Rate,
burst => Burst,
period => MiniPeriod,
childs => Childs,
consumed => 0}};
true ->
Factor = 1 / Rate,
{Factor, #{rate => 1,
burst => Burst * Factor,
period => erlang:floor(Factor * MiniPeriod),
childs => Childs,
consumed => 0}}
end.
make_zone([{Name, ZoneCfg} | T], Factor, NodeId, State) ->
#{rate := Rate, burst := Burst} = ZoneCfg,
#{zones := Zones, nodes := Nodes} = State,
Zone = #{id => NodeId,
name => Name,
rate => mul(Rate, Factor),
burst => Burst,
obtained => 0,
childs => []},
State2 = State#{zones := Zones#{Name => NodeId},
nodes := Nodes#{NodeId => Zone}},
make_zone(T, Factor, NodeId + 1, State2);
make_zone([], _, NodeId, State2) ->
{NodeId, State2}.
make_bucket([{Name, Conf} | T], Global, Zone, Factor, Id, Buckets, #{type := Type} = State) ->
#{zone := ZoneName,
aggregated := Aggregated} = Conf,
Path = emqx_limiter_manager:make_path(Type, ZoneName, Name),
case get_counter_rate(Conf, Zone, Global) of
infinity ->
State2 = State,
Rate = infinity,
Capacity = infinity,
Counter = undefined,
Index = undefined,
Ref = emqx_limiter_bucket_ref:new(Counter, Index, Rate),
emqx_limiter_manager:insert_bucket(Path, Ref);
RawRate ->
#{capacity := Capacity} = Aggregated,
Initial = get_initial_val(Aggregated),
{Counter, Index, State2} = alloc_counter(Path, RawRate, Initial, State),
Rate = mul(RawRate, Factor)
end,
Node = #{ id => Id
, name => Name
, zone => ZoneName
, rate => Rate
, obtained => 0
, correction => 0
, capacity => Capacity
, counter => Counter
, index => Index},
State3 = add_zone_child(Id, Node, ZoneName, State2),
make_bucket(T, Global, Zone, Factor, Id + 1, [Id | Buckets], State3);
make_bucket([], _, _, _, _, Buckets, State) ->
State#{buckets := Buckets}.
-spec alloc_counter(emqx_limiter_manager:path(), rate(), capacity(), state()) ->
{counters:counters_ref(), pos_integer(), state()}.
alloc_counter(Path, Rate, Initial,
#{counter := Counter, index := Index} = State) ->
case emqx_limiter_manager:find_bucket(Path) of
{ok, #{counter := ECounter,
index := EIndex}} when ECounter =/= undefined ->
init_counter(Path, ECounter, EIndex, Rate, Initial, State);
_ ->
init_counter(Path, Counter, Index,
Rate, Initial, State#{index := Index + 1})
end.
init_counter(Path, Counter, Index, Rate, Initial, State) ->
_ = put_to_counter(Counter, Index, Initial),
Ref = emqx_limiter_bucket_ref:new(Counter, Index, Rate),
emqx_limiter_manager:insert_bucket(Path, Ref),
{Counter, Index, State}.
-spec add_zone_child(node_id(), bucket(), zone_name(), state()) -> state().
add_zone_child(NodeId, Bucket, Name, #{zones := Zones, nodes := Nodes} = State) ->
ZoneId = maps:get(Name, Zones),
#{childs := Childs} = Zone = maps:get(ZoneId, Nodes),
Nodes2 = Nodes#{ZoneId => Zone#{childs := [NodeId | Childs]},
NodeId => Bucket},
State#{nodes := Nodes2}.
%% @doc find first limited node
get_counter_rate(#{zone := ZoneName,
aggregated := Cfg}, ZoneCfg, Global) ->
Zone = maps:get(ZoneName, ZoneCfg),
Search = lists:search(fun(E) -> is_limited(E) end,
[Cfg, Zone, Global]),
case Search of
{value, #{rate := Rate}} ->
Rate;
false ->
infinity
end.
is_limited(#{rate := Rate, capacity := Capacity}) ->
Rate =/= infinity orelse Capacity =/= infinity;
is_limited(#{rate := Rate}) ->
Rate =/= infinity.
get_initial_val(#{initial := Initial,
rate := Rate,
capacity := Capacity}) ->
%% initial will nevner be infinity(see the emqx_limiter_schema)
if Initial > 0 ->
Initial;
Rate =/= infinity ->
erlang:min(Rate, Capacity);
Capacity =/= infinity ->
Capacity;
true ->
0
end.

View File

@ -1,5 +1,5 @@
%%--------------------------------------------------------------------
%% Copyright (c) 2020-2021 EMQ Technologies Co., Ltd. All Rights Reserved.
%% Copyright (c) 2021 EMQ Technologies Co., Ltd. All Rights Reserved.
%%
%% Licensed under the Apache License, Version 2.0 (the "License");
%% you may not use this file except in compliance with the License.
@ -24,9 +24,9 @@
%% Supervisor callbacks
-export([init/1]).
%%--==================================================================
%%--------------------------------------------------------------------
%% API functions
%%--==================================================================
%%--------------------------------------------------------------------
%%--------------------------------------------------------------------
%% @doc
@ -52,9 +52,9 @@ restart(Type) ->
_ = supervisor:terminate_child(?MODULE, Id),
supervisor:restart_child(?MODULE, Id).
%%--==================================================================
%%--------------------------------------------------------------------
%% Supervisor callbacks
%%--==================================================================
%%--------------------------------------------------------------------
%%--------------------------------------------------------------------
%% @private

View File

@ -1,5 +1,5 @@
%%--------------------------------------------------------------------
%% Copyright (c) 2020-2021 EMQ Technologies Co., Ltd. All Rights Reserved.
%% Copyright (c) 2021 EMQ Technologies Co., Ltd. All Rights Reserved.
%%
%% Licensed under the Apache License, Version 2.0 (the "License");
%% you may not use this file except in compliance with the License.

View File

@ -228,7 +228,8 @@ do_start_listener(Type, ListenerName, #{bind := ListenOn} = Opts)
esockd:open(listener_id(Type, ListenerName), ListenOn, merge_default(esockd_opts(Type, Opts)),
{emqx_connection, start_link,
[#{listener => {Type, ListenerName},
zone => zone(Opts)}]});
zone => zone(Opts),
limiter => limiter(Opts)}]});
%% Start MQTT/WS listener
do_start_listener(Type, ListenerName, #{bind := ListenOn} = Opts)
@ -260,6 +261,7 @@ do_start_listener(quic, ListenerName, #{bind := ListenOn} = Opts) ->
, peer_bidi_stream_count => 10
, zone => zone(Opts)
, listener => {quic, ListenerName}
, limiter => limiter(Opts)
},
StreamOpts = [{stream_callback, emqx_quic_stream}],
quicer:start_listener(listener_id(quic, ListenerName),
@ -315,7 +317,9 @@ esockd_opts(Type, Opts0) ->
ws_opts(Type, ListenerName, Opts) ->
WsPaths = [{maps:get(mqtt_path, Opts, "/mqtt"), emqx_ws_connection,
#{zone => zone(Opts), listener => {Type, ListenerName}}}],
#{zone => zone(Opts),
listener => {Type, ListenerName},
limiter => limiter(Opts)}}],
Dispatch = cowboy_router:compile([{'_', WsPaths}]),
ProxyProto = maps:get(proxy_protocol, Opts, false),
#{env => #{dispatch => Dispatch}, proxy_header => ProxyProto}.
@ -380,6 +384,9 @@ parse_listener_id(Id) ->
zone(Opts) ->
maps:get(zone, Opts, undefined).
limiter(Opts) ->
maps:get(limiter, Opts).
ssl_opts(Opts) ->
maps:to_list(
emqx_tls_lib:drop_tls13_for_old_otp(

View File

@ -55,6 +55,8 @@
, hexstr2bin/1
]).
-export([clamp/3]).
-define(SHORT, 8).
%% @doc Parse v4 or v6 string format address to tuple.
@ -305,6 +307,13 @@ gen_id(Len) ->
<<R:BitLen>> = crypto:strong_rand_bytes(Len div 2),
int_to_hex(R, Len).
-spec clamp(number(), number(), number()) -> number().
clamp(Val, Min, Max) ->
if Val < Min -> Min;
Val > Max -> Max;
true -> Val
end.
%%------------------------------------------------------------------------------
%% Internal Functions
%%------------------------------------------------------------------------------

View File

@ -1017,6 +1017,8 @@ base_listener() ->
sc(atom(),
#{ default => 'default'
})}
, {"limiter",
sc(map("ratelimit bucket's name", atom()), #{default => #{}})}
].
%% utils

View File

@ -68,12 +68,13 @@ init([]) ->
SessionSup = child_spec(emqx_persistent_session_sup, supervisor),
CMSup = child_spec(emqx_cm_sup, supervisor),
SysSup = child_spec(emqx_sys_sup, supervisor),
Limiter = child_spec(emqx_limiter_sup, supervisor),
Children = [KernelSup] ++
[SessionSup || emqx_persistent_session:is_store_enabled()] ++
[RouterSup || emqx_boot:is_enabled(router)] ++
[BrokerSup || emqx_boot:is_enabled(broker)] ++
[CMSup || emqx_boot:is_enabled(broker)] ++
[SysSup],
[SysSup, Limiter],
SupFlags = #{strategy => one_for_all,
intensity => 0,
period => 1

View File

@ -63,10 +63,6 @@
sockstate :: emqx_types:sockstate(),
%% MQTT Piggyback
mqtt_piggyback :: single | multiple,
%% Limiter
limiter :: maybe(emqx_limiter:limiter()),
%% Limit Timer
limit_timer :: maybe(reference()),
%% Parse State
parse_state :: emqx_frame:parse_state(),
%% Serialize options
@ -86,10 +82,30 @@
%% Zone name
zone :: atom(),
%% Listener Type and Name
listener :: {Type::atom(), Name::atom()}
}).
listener :: {Type::atom(), Name::atom()},
%% Limiter
limiter :: maybe(container()),
%% cache operation when overload
limiter_cache :: queue:queue(cache()),
%% limiter timers
limiter_timer :: undefined | reference()
}).
-record(retry, { types :: list(limiter_type())
, data :: any()
, next :: check_succ_handler()
}).
-record(cache, { need :: list({pos_integer(), limiter_type()})
, data :: any()
, next :: check_succ_handler()
}).
-type(state() :: #state{}).
-type cache() :: #cache{}.
-type(ws_cmd() :: {active, boolean()}|close).
@ -99,6 +115,8 @@
-define(CONN_STATS, [recv_pkt, recv_msg, send_pkt, send_msg]).
-define(ENABLED(X), (X =/= undefined)).
-define(LIMITER_BYTES_IN, bytes_in).
-define(LIMITER_MESSAGE_IN, message_in).
-dialyzer({no_match, [info/2]}).
-dialyzer({nowarn_function, [websocket_init/1]}).
@ -126,7 +144,7 @@ info(sockname, #state{sockname = Sockname}) ->
info(sockstate, #state{sockstate = SockSt}) ->
SockSt;
info(limiter, #state{limiter = Limiter}) ->
maybe_apply(fun emqx_limiter:info/1, Limiter);
Limiter;
info(channel, #state{channel = Channel}) ->
emqx_channel:info(Channel);
info(gc_state, #state{gc_state = GcSt}) ->
@ -242,7 +260,8 @@ check_origin_header(Req, #{listener := {Type, Listener}} = Opts) ->
false -> ok
end.
websocket_init([Req, #{zone := Zone, listener := {Type, Listener}} = Opts]) ->
websocket_init([Req,
#{zone := Zone, limiter := LimiterCfg, listener := {Type, Listener}} = Opts]) ->
{Peername, Peercert} =
case emqx_config:get_listener_conf(Type, Listener, [proxy_protocol]) andalso
maps:get(proxy_header, Req) of
@ -279,7 +298,7 @@ websocket_init([Req, #{zone := Zone, listener := {Type, Listener}} = Opts]) ->
ws_cookie => WsCookie,
conn_mod => ?MODULE
},
Limiter = emqx_limiter:init(Zone, undefined, undefined, []),
Limiter = emqx_limiter_container:get_limiter_by_names([?LIMITER_BYTES_IN, ?LIMITER_MESSAGE_IN], LimiterCfg),
MQTTPiggyback = get_ws_opts(Type, Listener, mqtt_piggyback),
FrameOpts = #{
strict_mode => emqx_config:get_zone_conf(Zone, [mqtt, strict_mode]),
@ -319,7 +338,9 @@ websocket_init([Req, #{zone := Zone, listener := {Type, Listener}} = Opts]) ->
idle_timeout = IdleTimeout,
idle_timer = IdleTimer,
zone = Zone,
listener = {Type, Listener}
listener = {Type, Listener},
limiter_timer = undefined,
limiter_cache = queue:new()
}, hibernate}.
websocket_handle({binary, Data}, State) when is_list(Data) ->
@ -327,9 +348,17 @@ websocket_handle({binary, Data}, State) when is_list(Data) ->
websocket_handle({binary, Data}, State) ->
?SLOG(debug, #{msg => "RECV_data", data => Data, transport => websocket}),
ok = inc_recv_stats(1, iolist_size(Data)),
NState = ensure_stats_timer(State),
return(parse_incoming(Data, NState));
State2 = ensure_stats_timer(State),
{Packets, State3} = parse_incoming(Data, [], State2),
LenMsg = erlang:length(Packets),
ByteSize = erlang:iolist_size(Data),
inc_recv_stats(LenMsg, ByteSize),
State4 = check_limiter([{ByteSize, ?LIMITER_BYTES_IN}, {LenMsg, ?LIMITER_MESSAGE_IN}],
Packets,
fun when_msg_in/3,
[],
State3),
return(State4);
%% Pings should be replied with pongs, cowboy does it automatically
%% Pongs can be safely ignored. Clause here simply prevents crash.
@ -343,7 +372,6 @@ websocket_handle({Frame, _}, State) ->
%% TODO: should not close the ws connection
?SLOG(error, #{msg => "unexpected_frame", frame => Frame}),
shutdown(unexpected_ws_frame, State).
websocket_info({call, From, Req}, State) ->
handle_call(From, Req, State);
@ -351,8 +379,7 @@ websocket_info({cast, rate_limit}, State) ->
Stats = #{cnt => emqx_pd:reset_counter(incoming_pubs),
oct => emqx_pd:reset_counter(incoming_bytes)
},
NState = postpone({check_gc, Stats}, State),
return(ensure_rate_limit(Stats, NState));
return(postpone({check_gc, Stats}, State));
websocket_info({cast, Msg}, State) ->
handle_info(Msg, State);
@ -377,12 +404,18 @@ websocket_info(Deliver = {deliver, _Topic, _Msg},
Delivers = [Deliver|emqx_misc:drain_deliver(ActiveN)],
with_channel(handle_deliver, [Delivers], State);
websocket_info({timeout, TRef, limit_timeout},
State = #state{limit_timer = TRef}) ->
NState = State#state{sockstate = running,
limit_timer = undefined
},
return(enqueue({active, true}, NState));
websocket_info({timeout, _, limit_timeout},
State) ->
return(retry_limiter(State));
websocket_info(check_cache, #state{limiter_cache = Cache} = State) ->
case queue:peek(Cache) of
empty ->
return(enqueue({active, true}, State#state{sockstate = running}));
{value, #cache{need = Needs, data = Data, next = Next}} ->
State2 = State#state{limiter_cache = queue:drop(Cache)},
return(check_limiter(Needs, Data, Next, [check_cache], State2))
end;
websocket_info({timeout, TRef, Msg}, State) when is_reference(TRef) ->
handle_timeout(TRef, Msg, State);
@ -421,10 +454,9 @@ handle_call(From, stats, State) ->
gen_server:reply(From, stats(State)),
return(State);
handle_call(_From, {ratelimit, Policy}, State = #state{channel = Channel}) ->
Zone = emqx_channel:info(zone, Channel),
Limiter = emqx_limiter:init(Zone, Policy),
{reply, ok, State#state{limiter = Limiter}};
handle_call(_From, {ratelimit, Type, Bucket}, State = #state{limiter = Limiter}) ->
Limiter2 = emqx_limiter_container:update_by_name(Type, Bucket, Limiter),
{reply, ok, State#state{limiter = Limiter2}};
handle_call(From, Req, State = #state{channel = Channel}) ->
case emqx_channel:handle_call(Req, Channel) of
@ -495,21 +527,80 @@ handle_timeout(TRef, TMsg, State) ->
%% Ensure rate limit
%%--------------------------------------------------------------------
ensure_rate_limit(Stats, State = #state{limiter = Limiter}) ->
case ?ENABLED(Limiter) andalso emqx_limiter:check(Stats, Limiter) of
false -> State;
{ok, Limiter1} ->
State#state{limiter = Limiter1};
{pause, Time, Limiter1} ->
?SLOG(warning, #{msg => "pause_due_to_rate_limit", time => Time}),
TRef = start_timer(Time, limit_timeout),
NState = State#state{sockstate = blocked,
limiter = Limiter1,
limit_timer = TRef
},
enqueue({active, false}, NState)
-type limiter_type() :: emqx_limiter_container:limiter_type().
-type container() :: emqx_limiter_container:container().
-type check_succ_handler() ::
fun((any(), list(any()), state()) -> state()).
-spec check_limiter(list({pos_integer(), limiter_type()}),
any(),
check_succ_handler(),
list(any()),
state()) -> state().
check_limiter(Needs,
Data,
WhenOk,
Msgs,
#state{limiter = Limiter,
limiter_timer = LimiterTimer,
limiter_cache = Cache} = State) ->
case LimiterTimer of
undefined ->
case emqx_limiter_container:check_list(Needs, Limiter) of
{ok, Limiter2} ->
WhenOk(Data, Msgs, State#state{limiter = Limiter2});
{pause, Time, Limiter2} ->
?SLOG(warning, #{msg => "pause time dueto rate limit",
needs => Needs,
time_in_ms => Time}),
Retry = #retry{types = [Type || {_, Type} <- Needs],
data = Data,
next = WhenOk},
Limiter3 = emqx_limiter_container:set_retry_context(Retry, Limiter2),
TRef = start_timer(Time, limit_timeout),
enqueue({active, false},
State#state{sockstate = blocked,
limiter = Limiter3,
limiter_timer = TRef});
{drop, Limiter2} ->
{ok, State#state{limiter = Limiter2}}
end;
_ ->
New = #cache{need = Needs, data = Data, next = WhenOk},
State#state{limiter_cache = queue:in(New, Cache)}
end.
-spec retry_limiter(state()) -> state().
retry_limiter(#state{limiter = Limiter} = State) ->
#retry{types = Types, data = Data, next = Next} = emqx_limiter_container:get_retry_context(Limiter),
case emqx_limiter_container:retry_list(Types, Limiter) of
{ok, Limiter2} ->
Next(Data,
[check_cache],
State#state{ limiter = Limiter2
, limiter_timer = undefined
});
{pause, Time, Limiter2} ->
?SLOG(warning, #{msg => "pause time dueto rate limit",
types => Types,
time_in_ms => Time}),
TRef = start_timer(Time, limit_timeout),
State#state{limiter = Limiter2, limiter_timer = TRef}
end.
when_msg_in(Packets, [], State) ->
postpone(Packets, State);
when_msg_in(Packets, Msgs, State) ->
postpone(Packets, enqueue(Msgs, State)).
%%--------------------------------------------------------------------
%% Run GC, Check OOM
%%--------------------------------------------------------------------
@ -538,16 +629,16 @@ check_oom(State = #state{channel = Channel}) ->
%% Parse incoming data
%%--------------------------------------------------------------------
parse_incoming(<<>>, State) ->
State;
parse_incoming(<<>>, Packets, State) ->
{Packets, State};
parse_incoming(Data, State = #state{parse_state = ParseState}) ->
parse_incoming(Data, Packets, State = #state{parse_state = ParseState}) ->
try emqx_frame:parse(Data, ParseState) of
{more, NParseState} ->
State#state{parse_state = NParseState};
{Packets, State#state{parse_state = NParseState}};
{ok, Packet, Rest, NParseState} ->
NState = State#state{parse_state = NParseState},
parse_incoming(Rest, postpone({incoming, Packet}, NState))
parse_incoming(Rest, [{incoming, Packet} | Packets], NState)
catch
throw : ?FRAME_PARSE_ERROR(Reason) ->
?SLOG(info, #{ reason => Reason
@ -555,7 +646,7 @@ parse_incoming(Data, State = #state{parse_state = ParseState}) ->
, input_bytes => Data
}),
FrameError = {frame_error, Reason},
postpone({incoming, FrameError}, State);
{[{incoming, FrameError} | Packets], State};
error : Reason : Stacktrace ->
?SLOG(error, #{ at_state => emqx_frame:describe_state(ParseState)
, input_bytes => Data
@ -563,7 +654,7 @@ parse_incoming(Data, State = #state{parse_state = ParseState}) ->
, stacktrace => Stacktrace
}),
FrameError = {frame_error, Reason},
postpone({incoming, FrameError}, State)
{[{incoming, FrameError} | Packets], State}
end.
%%--------------------------------------------------------------------

View File

@ -129,7 +129,8 @@ basic_conf() ->
rpc => rpc_conf(),
stats => stats_conf(),
listeners => listeners_conf(),
zones => zone_conf()
zones => zone_conf(),
emqx_limiter => emqx:get_config([emqx_limiter])
}.
set_test_listener_confs() ->
@ -178,14 +179,48 @@ end_per_suite(_Config) ->
emqx_banned
]).
init_per_testcase(_TestCase, Config) ->
init_per_testcase(TestCase, Config) ->
NewConf = set_test_listener_confs(),
emqx_common_test_helpers:start_apps([]),
modify_limiter(TestCase, NewConf),
[{config, NewConf}|Config].
end_per_testcase(_TestCase, Config) ->
emqx_config:put(?config(config, Config)),
emqx_common_test_helpers:stop_apps([]),
Config.
modify_limiter(TestCase, NewConf) ->
Checks = [t_quota_qos0, t_quota_qos1, t_quota_qos2],
case lists:member(TestCase, Checks) of
true ->
modify_limiter(NewConf);
_ ->
ok
end.
%% per_client 5/1s,5
%% aggregated 10/1s,10
modify_limiter(#{emqx_limiter := Limiter} = NewConf) ->
#{message_routing := #{bucket := Bucket} = Routing} = Limiter,
#{default := #{per_client := Client} = Default} = Bucket,
Client2 = Client#{rate := 5,
initial := 0,
capacity := 5,
low_water_mark := 1},
Default2 = Default#{per_client := Client2,
aggregated := #{rate => 10,
initial => 0,
capacity => 10
}},
Bucket2 = Bucket#{default := Default2},
Routing2 = Routing#{bucket := Bucket2},
NewConf2 = NewConf#{emqx_limiter := Limiter#{message_routing := Routing2}},
emqx_config:put(NewConf2),
emqx_limiter_manager:restart_server(message_routing),
ok.
%%--------------------------------------------------------------------
%% Test cases for channel info/stats/caps
%%--------------------------------------------------------------------
@ -547,6 +582,7 @@ t_quota_qos0(_) ->
{ok, Chann1} = emqx_channel:handle_in(Pub, Chann),
{ok, Chann2} = emqx_channel:handle_in(Pub, Chann1),
M1 = emqx_metrics:val('packets.publish.dropped') - 1,
timer:sleep(1000),
{ok, Chann3} = emqx_channel:handle_timeout(ref, expire_quota_limit, Chann2),
{ok, _} = emqx_channel:handle_in(Pub, Chann3),
M1 = emqx_metrics:val('packets.publish.dropped') - 1,
@ -718,7 +754,7 @@ t_handle_call_takeover_end(_) ->
t_handle_call_quota(_) ->
{reply, ok, _Chan} = emqx_channel:handle_call(
{quota, [{conn_messages_routing, {100,1}}]},
{quota, default},
channel()
).
@ -886,7 +922,7 @@ t_ws_cookie_init(_) ->
conn_mod => emqx_ws_connection,
ws_cookie => WsCookie
},
Channel = emqx_channel:init(ConnInfo, #{zone => default, listener => {tcp, default}}),
Channel = emqx_channel:init(ConnInfo, #{zone => default, limiter => limiter_cfg(), listener => {tcp, default}}),
?assertMatch(#{ws_cookie := WsCookie}, emqx_channel:info(clientinfo, Channel)).
%%--------------------------------------------------------------------
@ -911,7 +947,7 @@ channel(InitFields) ->
maps:fold(fun(Field, Value, Channel) ->
emqx_channel:set_field(Field, Value, Channel)
end,
emqx_channel:init(ConnInfo, #{zone => default, listener => {tcp, default}}),
emqx_channel:init(ConnInfo, #{zone => default, limiter => limiter_cfg(), listener => {tcp, default}}),
maps:merge(#{clientinfo => clientinfo(),
session => session(),
conn_state => connected
@ -957,5 +993,6 @@ session(InitFields) when is_map(InitFields) ->
%% conn: 5/s; overall: 10/s
quota() ->
emqx_limiter:init(zone, [{conn_messages_routing, {5, 1}},
{overall_messages_routing, {10, 1}}]).
emqx_limiter_container:get_limiter_by_names([message_routing], limiter_cfg()).
limiter_cfg() -> #{}.

View File

@ -134,6 +134,7 @@ start_apps(Apps, Handler) when is_function(Handler) ->
%% Because, minirest, ekka etc.. application will scan these modules
lists:foreach(fun load/1, [emqx | Apps]),
ekka:start(),
ok = emqx_ratelimiter_SUITE:base_conf(),
lists:foreach(fun(App) -> start_app(App, Handler) end, [emqx | Apps]).
load(App) ->

View File

@ -39,7 +39,7 @@ init_per_suite(Config) ->
ok = meck:expect(emqx_cm, mark_channel_connected, fun(_) -> ok end),
ok = meck:expect(emqx_cm, mark_channel_disconnected, fun(_) -> ok end),
%% Meck Limiter
ok = meck:new(emqx_limiter, [passthrough, no_history, no_link]),
ok = meck:new(emqx_htb_limiter, [passthrough, no_history, no_link]),
%% Meck Pd
ok = meck:new(emqx_pd, [passthrough, no_history, no_link]),
%% Meck Metrics
@ -60,17 +60,19 @@ init_per_suite(Config) ->
ok = meck:expect(emqx_alarm, deactivate, fun(_, _) -> ok end),
emqx_channel_SUITE:set_test_listener_confs(),
emqx_common_test_helpers:start_apps([]),
Config.
end_per_suite(_Config) ->
ok = meck:unload(emqx_transport),
catch meck:unload(emqx_channel),
ok = meck:unload(emqx_cm),
ok = meck:unload(emqx_limiter),
ok = meck:unload(emqx_htb_limiter),
ok = meck:unload(emqx_pd),
ok = meck:unload(emqx_metrics),
ok = meck:unload(emqx_hooks),
ok = meck:unload(emqx_alarm),
emqx_common_test_helpers:stop_apps([]),
ok.
init_per_testcase(TestCase, Config) when
@ -129,8 +131,9 @@ t_info(_) ->
socktype := tcp}, SockInfo).
t_info_limiter(_) ->
St = st(#{limiter => emqx_limiter:init(default, [])}),
?assertEqual(undefined, emqx_connection:info(limiter, St)).
Limiter = init_limiter(),
St = st(#{limiter => Limiter}),
?assertEqual(Limiter, emqx_connection:info(limiter, St)).
t_stats(_) ->
CPid = spawn(fun() ->
@ -250,24 +253,22 @@ t_handle_msg_shutdown(_) ->
?assertMatch({stop, {shutdown, for_testing}, _St}, handle_msg({shutdown, for_testing}, st())).
t_handle_call(_) ->
St = st(),
St = st(#{limiter => init_limiter()}),
?assertMatch({ok, _St}, handle_msg({event, undefined}, St)),
?assertMatch({reply, _Info, _NSt}, handle_call(self(), info, St)),
?assertMatch({reply, _Stats, _NSt}, handle_call(self(), stats, St)),
?assertMatch({reply, ok, _NSt}, handle_call(self(), {ratelimit, []}, St)),
?assertMatch({reply, ok, _NSt},
handle_call(self(), {ratelimit, [{conn_messages_in, {100, 1}}]}, St)),
handle_call(self(), {ratelimit, [{bytes_in, default}]}, St)),
?assertEqual({reply, ignored, St}, handle_call(self(), for_testing, St)),
?assertMatch({stop, {shutdown,kicked}, ok, _NSt},
handle_call(self(), kick, St)).
t_handle_timeout(_) ->
TRef = make_ref(),
State = st(#{idle_timer => TRef, limit_timer => TRef, stats_timer => TRef}),
State = st(#{idle_timer => TRef, stats_timer => TRef, limiter => init_limiter()}),
?assertMatch({stop, {shutdown,idle_timeout}, _NState},
emqx_connection:handle_timeout(TRef, idle_timeout, State)),
?assertMatch({ok, {event,running}, _NState},
emqx_connection:handle_timeout(TRef, limit_timeout, State)),
?assertMatch({ok, _NState},
emqx_connection:handle_timeout(TRef, emit_stats, State)),
?assertMatch({ok, _NState},
@ -279,13 +280,15 @@ t_handle_timeout(_) ->
?assertMatch({ok, _NState}, emqx_connection:handle_timeout(TRef, undefined, State)).
t_parse_incoming(_) ->
?assertMatch({ok, [], _NState}, emqx_connection:parse_incoming(<<>>, st())),
?assertMatch({[], _NState}, emqx_connection:parse_incoming(<<>>, [], st())),
?assertMatch({[], _NState}, emqx_connection:parse_incoming(<<"for_testing">>, [], st())).
t_next_incoming_msgs(_) ->
?assertEqual({incoming, packet}, emqx_connection:next_incoming_msgs([packet])),
?assertEqual([{incoming, packet2}, {incoming, packet1}],
emqx_connection:next_incoming_msgs([packet1, packet2])).
State = st(#{}),
?assertEqual({ok, [{incoming, packet}], State},
emqx_connection:next_incoming_msgs([packet], [], State)),
?assertEqual({ok, [{incoming, packet2}, {incoming, packet1}], State},
emqx_connection:next_incoming_msgs([packet1, packet2], [], State)).
t_handle_incoming(_) ->
?assertMatch({ok, _Out, _NState},
@ -331,26 +334,28 @@ t_handle_info(_) ->
?assertMatch({ok, _NState}, emqx_connection:handle_info(for_testing, st())).
t_ensure_rate_limit(_) ->
State = emqx_connection:ensure_rate_limit(#{}, st(#{limiter => undefined})),
WhenOk = fun emqx_connection:next_incoming_msgs/3,
{ok, [], State} = emqx_connection:check_limiter([], [], WhenOk, [], st(#{limiter => undefined})),
?assertEqual(undefined, emqx_connection:info(limiter, State)),
ok = meck:expect(emqx_limiter, check,
fun(_, _) -> {ok, emqx_limiter:init(default, [])} end),
State1 = emqx_connection:ensure_rate_limit(#{}, st(#{limiter => #{}})),
?assertEqual(undefined, emqx_connection:info(limiter, State1)),
Limiter = init_limiter(),
{ok, [], State1} = emqx_connection:check_limiter([], [], WhenOk, [], st(#{limiter => Limiter})),
?assertEqual(Limiter, emqx_connection:info(limiter, State1)),
ok = meck:expect(emqx_limiter, check,
fun(_, _) -> {pause, 3000, emqx_limiter:init(default, [])} end),
State2 = emqx_connection:ensure_rate_limit(#{}, st(#{limiter => #{}})),
?assertEqual(undefined, emqx_connection:info(limiter, State2)),
?assertEqual(blocked, emqx_connection:info(sockstate, State2)).
ok = meck:expect(emqx_htb_limiter, check,
fun(_, Client) -> {pause, 3000, undefined, Client} end),
{ok, State2} = emqx_connection:check_limiter([{1000, bytes_in}], [], WhenOk, [], st(#{limiter => Limiter})),
meck:unload(emqx_htb_limiter),
ok = meck:new(emqx_htb_limiter, [passthrough, no_history, no_link]),
?assertNotEqual(undefined, emqx_connection:info(limiter_timer, State2)).
t_activate_socket(_) ->
State = st(),
Limiter = init_limiter(),
State = st(#{limiter => Limiter}),
{ok, NStats} = emqx_connection:activate_socket(State),
?assertEqual(running, emqx_connection:info(sockstate, NStats)),
State1 = st(#{sockstate => blocked}),
State1 = st(#{sockstate => blocked, limiter_timer => any_timer}),
?assertEqual({ok, State1}, emqx_connection:activate_socket(State1)),
State2 = st(#{sockstate => closed}),
@ -458,7 +463,10 @@ with_conn(TestFun, Opts) when is_map(Opts) ->
TrapExit = maps:get(trap_exit, Opts, false),
process_flag(trap_exit, TrapExit),
{ok, CPid} = emqx_connection:start_link(emqx_transport, sock,
maps:merge(Opts, #{zone => default, listener => {tcp, default}})),
maps:merge(Opts,
#{zone => default,
limiter => limiter_cfg(),
listener => {tcp, default}})),
TestFun(CPid),
TrapExit orelse emqx_connection:stop(CPid),
ok.
@ -481,7 +489,8 @@ st(InitFields) when is_map(InitFields) ->
st(InitFields, #{}).
st(InitFields, ChannelFields) when is_map(InitFields) ->
St = emqx_connection:init_state(emqx_transport, sock, #{zone => default,
listener => {tcp, default}}),
limiter => limiter_cfg(),
listener => {tcp, default}}),
maps:fold(fun(N, V, S) -> emqx_connection:set_field(N, V, S) end,
emqx_connection:set_field(channel, channel(ChannelFields), St),
InitFields
@ -515,7 +524,7 @@ channel(InitFields) ->
maps:fold(fun(Field, Value, Channel) ->
emqx_channel:set_field(Field, Value, Channel)
end,
emqx_channel:init(ConnInfo, #{zone => default, listener => {tcp, default}}),
emqx_channel:init(ConnInfo, #{zone => default, limiter => limiter_cfg(), listener => {tcp, default}}),
maps:merge(#{clientinfo => ClientInfo,
session => Session,
conn_state => connected
@ -524,3 +533,8 @@ channel(InitFields) ->
handle_msg(Msg, St) -> emqx_connection:handle_msg(Msg, St).
handle_call(Pid, Call, St) -> emqx_connection:handle_call(Pid, Call, St).
limiter_cfg() -> #{}.
init_limiter() ->
emqx_limiter_container:get_limiter_by_names([bytes_in, message_in], limiter_cfg()).

View File

@ -46,6 +46,7 @@ init_per_testcase(Case, Config)
emqx_config:put([listeners, tcp], #{ listener_test =>
#{ bind => {"127.0.0.1", 9999}
, max_connections => 4321
, limiter => #{}
}
}),
emqx_config:put([rate_limit], #{max_conn_rate => 1000}),

View File

@ -0,0 +1,659 @@
%%--------------------------------------------------------------------
%% Copyright (c) 2021 EMQ Technologies Co., Ltd. All Rights Reserved.
%%
%% Licensed under the Apache License, Version 2.0 (the "License");
%% you may not use this file except in compliance with the License.
%% You may obtain a copy of the License at
%%
%% http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing, software
%% distributed under the License is distributed on an "AS IS" BASIS,
%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
%% See the License for the specific language governing permissions and
%% limitations under the License.
%%--------------------------------------------------------------------
-module(emqx_ratelimiter_SUITE).
-compile(export_all).
-compile(nowarn_export_all).
-define(APP, emqx).
-include_lib("eunit/include/eunit.hrl").
-include_lib("common_test/include/ct.hrl").
-define(BASE_CONF, <<"""
emqx_limiter {
bytes_in {
global.rate = infinity
zone.default.rate = infinity
bucket.default {
zone = default
aggregated.rate = infinity
aggregated.capacity = infinity
per_client.rate = \"100MB/1s\"
per_client.capacity = infinity
}
}
message_in {
global.rate = infinity
zone.default.rate = infinity
bucket.default {
zone = default
aggregated.rate = infinity
aggregated.capacity = infinity
per_client.rate = infinity
per_client.capacity = infinity
}
}
connection {
global.rate = infinity
zone.default.rate = infinity
bucket.default {
zone = default
aggregated.rate = infinity
aggregated.capacity = infinity
per_client.rate = infinity
per_client.capacity = infinity
}
}
message_routing {
global.rate = infinity
zone.default.rate = infinity
bucket.default {
zone = default
aggregated.rate = infinity
aggregated.capacity = infinity
per_client.rate = infinity
per_client.capacity = infinity
}
}
}
""">>).
-record(client, { counter :: counters:counter_ref()
, start :: pos_integer()
, endtime :: pos_integer()
, obtained :: pos_integer()
, rate :: float()
, client :: emqx_htb_limiter:client()
}).
-define(LOGT(Format, Args), ct:pal("TEST_SUITE: " ++ Format, Args)).
-define(RATE(Rate), to_rate(Rate)).
-define(NOW, erlang:system_time(millisecond)).
%%--------------------------------------------------------------------
%% Setups
%%--------------------------------------------------------------------
all() ->
emqx_common_test_helpers:all(?MODULE).
init_per_suite(Config) ->
ok = emqx_config:init_load(emqx_limiter_schema, ?BASE_CONF),
emqx_common_test_helpers:start_apps([?APP]),
Config.
end_per_suite(_Config) ->
emqx_common_test_helpers:stop_apps([?APP]).
init_per_testcase(_TestCase, Config) ->
Config.
base_conf() ->
emqx_config:init_load(emqx_limiter_schema, ?BASE_CONF).
%%--------------------------------------------------------------------
%% Test Cases Bucket Level
%%--------------------------------------------------------------------
t_max_retry_time(_) ->
Cfg = fun(Cfg) ->
Cfg#{rate := 1,
capacity := 1,
max_retry_time := 500,
failure_strategy := drop}
end,
Case = fun() ->
Client = connect(default),
Begin = ?NOW,
Result = emqx_htb_limiter:consume(101, Client),
?assertMatch({drop, _}, Result),
Time = ?NOW - Begin,
?assert(Time >= 500 andalso Time < 550)
end,
with_per_client(default, Cfg, Case).
t_divisible(_) ->
Cfg = fun(Cfg) ->
Cfg#{divisible := true,
rate := ?RATE("1000/1s"),
initial := 600,
capacity := 600}
end,
Case = fun() ->
Client = connect(default),
Result = emqx_htb_limiter:check(1000, Client),
?assertMatch({partial,
400,
#{continuation := _,
diff := 400,
start := _,
need := 1000},
_}, Result)
end,
with_per_client(default, Cfg, Case).
t_low_water_mark(_) ->
Cfg = fun(Cfg) ->
Cfg#{low_water_mark := 400,
rate := ?RATE("1000/1s"),
initial := 1000,
capacity := 1000}
end,
Case = fun() ->
Client = connect(default),
Result = emqx_htb_limiter:check(500, Client),
?assertMatch({ok, _}, Result),
{_, Client2} = Result,
Result2 = emqx_htb_limiter:check(101, Client2),
?assertMatch({pause,
_,
#{continuation := undefined,
diff := 0},
_}, Result2)
end,
with_per_client(default, Cfg, Case).
t_infinity_client(_) ->
Fun = fun(#{aggregated := Aggr, per_client := Cli} = Bucket) ->
Aggr2 = Aggr#{rate := infinity,
capacity := infinity},
Cli2 = Cli#{rate := infinity, capacity := infinity},
Bucket#{aggregated := Aggr2,
per_client := Cli2}
end,
Case = fun() ->
Client = connect(default),
?assertEqual(infinity, Client),
Result = emqx_htb_limiter:check(100000, Client),
?assertEqual({ok, Client}, Result)
end,
with_bucket(default, Fun, Case).
t_short_board(_) ->
Fun = fun(#{aggregated := Aggr, per_client := Cli} = Bucket) ->
Aggr2 = Aggr#{rate := ?RATE("100/1s"),
initial := 0,
capacity := 100},
Cli2 = Cli#{rate := ?RATE("600/1s"),
capacity := 600,
initial := 600},
Bucket#{aggregated := Aggr2,
per_client := Cli2}
end,
Case = fun() ->
Counter = counters:new(1, [write_concurrency]),
start_client(default, ?NOW + 2000, Counter, 20),
timer:sleep(2100),
check_average_rate(Counter, 2, 100, 20)
end,
with_bucket(default, Fun, Case).
t_rate(_) ->
Fun = fun(#{aggregated := Aggr, per_client := Cli} = Bucket) ->
Aggr2 = Aggr#{rate := ?RATE("100/100ms"),
initial := 0,
capacity := infinity},
Cli2 = Cli#{rate := infinity,
capacity := infinity,
initial := 0},
Bucket#{aggregated := Aggr2,
per_client := Cli2}
end,
Case = fun() ->
Client = connect(default),
Ts1 = erlang:system_time(millisecond),
C1 = emqx_htb_limiter:available(Client),
timer:sleep(1000),
Ts2 = erlang:system_time(millisecond),
C2 = emqx_htb_limiter:available(Client),
ShouldInc = floor((Ts2 - Ts1) / 100) * 100,
Inc = C2 - C1,
?assert(in_range(Inc, ShouldInc - 100, ShouldInc + 100), "test bucket rate")
end,
with_bucket(default, Fun, Case).
t_capacity(_) ->
Capacity = 600,
Fun = fun(#{aggregated := Aggr, per_client := Cli} = Bucket) ->
Aggr2 = Aggr#{rate := ?RATE("100/100ms"),
initial := 0,
capacity := 600},
Cli2 = Cli#{rate := infinity,
capacity := infinity,
initial := 0},
Bucket#{aggregated := Aggr2,
per_client := Cli2}
end,
Case = fun() ->
Client = connect(default),
timer:sleep(1000),
C1 = emqx_htb_limiter:available(Client),
?assertEqual(Capacity, C1, "test bucket capacity")
end,
with_bucket(default, Fun, Case).
%%--------------------------------------------------------------------
%% Test Cases Zone Level
%%--------------------------------------------------------------------
t_limit_zone_with_unlimit_bucket(_) ->
ZoneMod = fun(Cfg) ->
Cfg#{rate := ?RATE("600/1s"),
burst := ?RATE("60/1s")}
end,
Bucket = fun(#{aggregated := Aggr, per_client := Cli} = Bucket) ->
Aggr2 = Aggr#{rate := infinity,
initial := 0,
capacity := infinity},
Cli2 = Cli#{rate := infinity,
initial := 0,
capacity := infinity,
divisible := true},
Bucket#{aggregated := Aggr2, per_client := Cli2}
end,
Case = fun() ->
C1 = counters:new(1, [write_concurrency]),
start_client(b1, ?NOW + 2000, C1, 20),
timer:sleep(2100),
check_average_rate(C1, 2, 600, 1000)
end,
with_zone(default, ZoneMod, [{b1, Bucket}], Case).
%%--------------------------------------------------------------------
%% Test Cases Global Level
%%--------------------------------------------------------------------
t_burst_and_fairness(_) ->
GlobalMod = fun(Cfg) ->
Cfg#{burst := ?RATE("60/1s")}
end,
ZoneMod = fun(Cfg) ->
Cfg#{rate := ?RATE("600/1s"),
burst := ?RATE("60/1s")}
end,
Bucket = fun(#{aggregated := Aggr, per_client := Cli} = Bucket) ->
Aggr2 = Aggr#{rate := ?RATE("500/1s"),
initial := 0,
capacity := 500},
Cli2 = Cli#{rate := ?RATE("600/1s"),
capacity := 600,
initial := 600},
Bucket#{aggregated := Aggr2,
per_client := Cli2}
end,
Case = fun() ->
C1 = counters:new(1, [write_concurrency]),
C2 = counters:new(1, [write_concurrency]),
start_client(b1, ?NOW + 2000, C1, 20),
start_client(b2, ?NOW + 2000, C2, 30),
timer:sleep(2100),
check_average_rate(C1, 2, 330, 25),
check_average_rate(C2, 2, 330, 25)
end,
with_global(GlobalMod,
default,
ZoneMod,
[{b1, Bucket}, {b2, Bucket}],
Case).
t_limit_global_with_unlimit_other(_) ->
GlobalMod = fun(Cfg) ->
Cfg#{rate := ?RATE("600/1s")}
end,
ZoneMod = fun(Cfg) -> Cfg#{rate := infinity} end,
Bucket = fun(#{aggregated := Aggr, per_client := Cli} = Bucket) ->
Aggr2 = Aggr#{rate := infinity,
initial := 0,
capacity := infinity},
Cli2 = Cli#{rate := infinity,
capacity := infinity,
initial := 0},
Bucket#{aggregated := Aggr2,
per_client := Cli2}
end,
Case = fun() ->
C1 = counters:new(1, [write_concurrency]),
start_client(b1, ?NOW + 2000, C1, 20),
timer:sleep(2100),
check_average_rate(C1, 2, 600, 100)
end,
with_global(GlobalMod,
default,
ZoneMod,
[{b1, Bucket}],
Case).
t_multi_zones(_) ->
GlobalMod = fun(Cfg) ->
Cfg#{rate := ?RATE("600/1s")}
end,
Zone1 = fun(Cfg) ->
Cfg#{rate := ?RATE("400/1s")}
end,
Zone2 = fun(Cfg) ->
Cfg#{rate := ?RATE("500/1s")}
end,
Bucket = fun(Zone, Rate) ->
fun(#{aggregated := Aggr, per_client := Cli} = Bucket) ->
Aggr2 = Aggr#{rate := infinity,
initial := 0,
capacity := infinity},
Cli2 = Cli#{rate := Rate,
capacity := infinity,
initial := 0},
Bucket#{aggregated := Aggr2,
per_client := Cli2,
zone := Zone}
end
end,
Case = fun() ->
C1 = counters:new(1, [write_concurrency]),
C2 = counters:new(1, [write_concurrency]),
start_client(b1, ?NOW + 2000, C1, 25),
start_client(b2, ?NOW + 2000, C2, 20),
timer:sleep(2100),
check_average_rate(C1, 2, 300, 25),
check_average_rate(C2, 2, 300, 25)
end,
with_global(GlobalMod,
[z1, z2],
[Zone1, Zone2],
[{b1, Bucket(z1, ?RATE("400/1s"))}, {b2, Bucket(z2, ?RATE("500/1s"))}],
Case).
%% because the simulated client will try to reach the maximum rate
%% when divisiable = true, a large number of divided tokens will be generated
%% so this is not an accurate test
t_multi_zones_with_divisible(_) ->
GlobalMod = fun(Cfg) ->
Cfg#{rate := ?RATE("600/1s")}
end,
Zone1 = fun(Cfg) ->
Cfg#{rate := ?RATE("400/1s")}
end,
Zone2 = fun(Cfg) ->
Cfg#{rate := ?RATE("500/1s")}
end,
Bucket = fun(Zone, Rate) ->
fun(#{aggregated := Aggr, per_client := Cli} = Bucket) ->
Aggr2 = Aggr#{rate := Rate,
initial := 0,
capacity := infinity},
Cli2 = Cli#{rate := Rate,
divisible := true,
capacity := infinity,
initial := 0},
Bucket#{aggregated := Aggr2,
per_client := Cli2,
zone := Zone}
end
end,
Case = fun() ->
C1 = counters:new(1, [write_concurrency]),
C2 = counters:new(1, [write_concurrency]),
start_client(b1, ?NOW + 2000, C1, 25),
start_client(b2, ?NOW + 2000, C2, 20),
timer:sleep(2100),
check_average_rate(C1, 2, 300, 120),
check_average_rate(C2, 2, 300, 120)
end,
with_global(GlobalMod,
[z1, z2],
[Zone1, Zone2],
[{b1, Bucket(z1, ?RATE("400/1s"))}, {b2, Bucket(z2, ?RATE("500/1s"))}],
Case).
t_zone_hunger_and_fair(_) ->
GlobalMod = fun(Cfg) ->
Cfg#{rate := ?RATE("600/1s")}
end,
Zone1 = fun(Cfg) ->
Cfg#{rate := ?RATE("600/1s")}
end,
Zone2 = fun(Cfg) ->
Cfg#{rate := ?RATE("50/1s")}
end,
Bucket = fun(Zone, Rate) ->
fun(#{aggregated := Aggr, per_client := Cli} = Bucket) ->
Aggr2 = Aggr#{rate := infinity,
initial := 0,
capacity := infinity},
Cli2 = Cli#{rate := Rate,
capacity := infinity,
initial := 0},
Bucket#{aggregated := Aggr2,
per_client := Cli2,
zone := Zone}
end
end,
Case = fun() ->
C1 = counters:new(1, [write_concurrency]),
C2 = counters:new(1, [write_concurrency]),
start_client(b1, ?NOW + 2000, C1, 20),
start_client(b2, ?NOW + 2000, C2, 20),
timer:sleep(2100),
check_average_rate(C1, 2, 550, 25),
check_average_rate(C2, 2, 50, 25)
end,
with_global(GlobalMod,
[z1, z2],
[Zone1, Zone2],
[{b1, Bucket(z1, ?RATE("600/1s"))}, {b2, Bucket(z2, ?RATE("50/1s"))}],
Case).
%%--------------------------------------------------------------------
%%% Internal functions
%%--------------------------------------------------------------------
start_client(Name, EndTime, Counter, Number) ->
lists:foreach(fun(_) ->
spawn(fun() ->
start_client(Name, EndTime, Counter)
end)
end,
lists:seq(1, Number)).
start_client(Name, EndTime, Counter) ->
#{per_client := PerClient} =
emqx_config:get([emqx_limiter, message_routing, bucket, Name]),
#{rate := Rate} = PerClient,
Client = #client{start = ?NOW,
endtime = EndTime,
counter = Counter,
obtained = 0,
rate = Rate,
client = connect(Name)
},
client_loop(Client).
%% the simulated client will try to reach the configured rate as much as possible
%% note this client will not considered the capacity, so must make sure rate < capacity
client_loop(#client{start = Start,
endtime = EndTime,
obtained = Obtained,
rate = Rate} = State) ->
Now = ?NOW,
Period = emqx_limiter_schema:minimum_period(),
MinPeriod = erlang:ceil(0.25 * Period),
if Now >= EndTime ->
stop;
Now - Start < MinPeriod ->
timer:sleep(client_random_val(MinPeriod)),
client_loop(State);
Obtained =< 0 ->
Rand = client_random_val(Rate),
client_try_check(Rand, State);
true ->
Span = Now - Start,
CurrRate = Obtained * Period / Span,
if CurrRate < Rate ->
Rand = client_random_val(Rate),
client_try_check(Rand, State);
true ->
LeftTime = EndTime - Now,
CanSleep = erlang:min(LeftTime, client_random_val(MinPeriod div 2)),
timer:sleep(CanSleep),
client_loop(State)
end
end.
client_try_check(Need, #client{counter = Counter,
endtime = EndTime,
obtained = Obtained,
client = Client} = State) ->
case emqx_htb_limiter:check(Need, Client) of
{ok, Client2} ->
case Need of
#{need := Val} -> ok;
Val -> ok
end,
counters:add(Counter, 1, Val),
client_loop(State#client{obtained = Obtained + Val, client = Client2});
{_, Pause, Retry, Client2} ->
LeftTime = EndTime - ?NOW,
if LeftTime =< 0 ->
stop;
true ->
timer:sleep(erlang:min(Pause, LeftTime)),
client_try_check(Retry, State#client{client = Client2})
end
end.
%% XXX not a god test, because client's rate maybe bigger than global rate
%% so if client' rate = infinity
%% client's divisible should be true or capacity must be bigger than number of each comsume
client_random_val(infinity) ->
1000;
%% random in 0.5Range ~ 1Range
client_random_val(Range) ->
Half = erlang:floor(Range) div 2,
Rand = rand:uniform(Half + 1) + Half,
erlang:max(1, Rand).
to_rate(Str) ->
{ok, Rate} = emqx_limiter_schema:to_rate(Str),
Rate.
with_global(Modifier, ZoneName, ZoneModifier, Buckets, Case) ->
Path = [emqx_limiter, message_routing],
#{global := Global} = Cfg = emqx_config:get(Path),
Cfg2 = Cfg#{global := Modifier(Global)},
with_zone(Cfg2, ZoneName, ZoneModifier, Buckets, Case).
with_zone(Name, Modifier, Buckets, Case) ->
Path = [emqx_limiter, message_routing],
Cfg = emqx_config:get(Path),
with_zone(Cfg, Name, Modifier, Buckets, Case).
with_zone(Cfg, Name, Modifier, Buckets, Case) ->
Path = [emqx_limiter, message_routing],
#{zone := ZoneCfgs,
bucket := BucketCfgs} = Cfg,
ZoneCfgs2 = apply_modifier(Name, Modifier, ZoneCfgs),
BucketCfgs2 = apply_modifier(Buckets, BucketCfgs),
Cfg2 = Cfg#{zone := ZoneCfgs2, bucket := BucketCfgs2},
with_config(Path, fun(_) -> Cfg2 end, Case).
with_bucket(Bucket, Modifier, Case) ->
Path = [emqx_limiter, message_routing, bucket, Bucket],
with_config(Path, Modifier, Case).
with_per_client(Bucket, Modifier, Case) ->
Path = [emqx_limiter, message_routing, bucket, Bucket, per_client],
with_config(Path, Modifier, Case).
with_config(Path, Modifier, Case) ->
Cfg = emqx_config:get(Path),
NewCfg = Modifier(Cfg),
ct:pal("test with config:~p~n", [NewCfg]),
emqx_config:put(Path, NewCfg),
emqx_limiter_manager:restart_server(message_routing),
timer:sleep(100),
DelayReturn
= try
Return = Case(),
fun() -> Return end
catch Type:Reason:Trace ->
fun() -> erlang:raise(Type, Reason, Trace) end
end,
emqx_config:put(Path, Cfg),
DelayReturn().
connect(Name) ->
emqx_limiter_server:connect(message_routing, Name).
check_average_rate(Counter, Second, Rate, Margin) ->
Cost = counters:get(Counter, 1),
PerSec = Cost / Second,
?LOGT(">>>> Cost:~p PerSec:~p Rate:~p ~n", [Cost, PerSec, Rate]),
?assert(in_range(PerSec, Rate - Margin, Rate + Margin)).
print_average_rate(Counter, Second) ->
Cost = counters:get(Counter, 1),
PerSec = Cost / Second,
ct:pal(">>>> Cost:~p PerSec:~p ~n", [Cost, PerSec]).
in_range(Val, Min, _Max) when Val < Min ->
ct:pal("Val:~p smaller than min bound:~p~n", [Val, Min]),
false;
in_range(Val, _Min, Max) when Val > Max->
ct:pal("Val:~p bigger than max bound:~p~n", [Val, Max]),
false;
in_range(_, _, _) ->
true.
apply_modifier(Name, Modifier, Cfg) when is_list(Name) ->
Pairs = lists:zip(Name, Modifier),
apply_modifier(Pairs, Cfg);
apply_modifier(Name, Modifier, #{default := Template} = Cfg) ->
Cfg#{Name => Modifier(Template)}.
apply_modifier(Pairs, #{default := Template}) ->
Fun = fun({N, M}, Acc) ->
Acc#{N => M(Template)}
end,
lists:foldl(Fun, #{}, Pairs).

View File

@ -105,6 +105,15 @@ end_per_testcase(_, Config) ->
emqx_common_test_helpers:stop_apps([]),
Config.
init_per_suite(Config) ->
emqx_channel_SUITE:set_test_listener_confs(),
emqx_common_test_helpers:start_apps([]),
Config.
end_per_suite(_) ->
emqx_common_test_helpers:stop_apps([]),
ok.
%%--------------------------------------------------------------------
%% Test Cases
%%--------------------------------------------------------------------
@ -131,7 +140,9 @@ t_header(_) ->
(<<"x-forwarded-port">>, _, _) -> <<"1000">> end),
set_ws_opts(proxy_address_header, <<"x-forwarded-for">>),
set_ws_opts(proxy_port_header, <<"x-forwarded-port">>),
{ok, St, _} = ?ws_conn:websocket_init([req, #{zone => default, listener => {ws, default}}]),
{ok, St, _} = ?ws_conn:websocket_init([req, #{zone => default,
limiter => limiter_cfg(),
listener => {ws, default}}]),
WsPid = spawn(fun() ->
receive {call, From, info} ->
gen_server:reply(From, ?ws_conn:info(St))
@ -143,8 +154,9 @@ t_header(_) ->
} = SockInfo.
t_info_limiter(_) ->
St = st(#{limiter => emqx_limiter:init(external, [])}),
?assertEqual(undefined, ?ws_conn:info(limiter, St)).
Limiter = init_limiter(),
St = st(#{limiter => Limiter}),
?assertEqual(Limiter, ?ws_conn:info(limiter, St)).
t_info_channel(_) ->
#{conn_state := connected} = ?ws_conn:info(channel, st()).
@ -249,7 +261,7 @@ t_ws_non_check_origin(_) ->
headers => [{<<"origin">>, <<"http://localhost:18080">>}]})).
t_init(_) ->
Opts = #{listener => {ws, default}, zone => default},
Opts = #{listener => {ws, default}, zone => default, limiter => limiter_cfg()},
ok = meck:expect(cowboy_req, parse_header, fun(_, req) -> undefined end),
ok = meck:expect(cowboy_req, reply, fun(_, Req) -> Req end),
{ok, req, _} = ?ws_conn:init(req, Opts),
@ -329,8 +341,11 @@ t_websocket_info_deliver(_) ->
t_websocket_info_timeout_limiter(_) ->
Ref = make_ref(),
LimiterT = init_limiter(),
Next = fun emqx_ws_connection:when_msg_in/3,
Limiter = emqx_limiter_container:set_retry_context({retry, [], [], Next}, LimiterT),
Event = {timeout, Ref, limit_timeout},
{[{active, true}], St} = websocket_info(Event, st(#{limit_timer => Ref})),
{ok, St} = websocket_info(Event, st(#{limiter => Limiter})),
?assertEqual([], ?ws_conn:info(postponed, St)).
t_websocket_info_timeout_keepalive(_) ->
@ -389,23 +404,27 @@ t_handle_timeout_emit_stats(_) ->
?assertEqual(undefined, ?ws_conn:info(stats_timer, St)).
t_ensure_rate_limit(_) ->
Limiter = emqx_limiter:init(external, {1, 10}, {100, 1000}, []),
Limiter = init_limiter(),
St = st(#{limiter => Limiter}),
St1 = ?ws_conn:ensure_rate_limit(#{cnt => 0, oct => 0}, St),
St2 = ?ws_conn:ensure_rate_limit(#{cnt => 11, oct => 1200}, St1),
?assertEqual(blocked, ?ws_conn:info(sockstate, St2)),
?assertEqual([{active, false}], ?ws_conn:info(postponed, St2)).
{ok, Need} = emqx_limiter_schema:to_capacity("1GB"), %% must bigger than value in emqx_ratelimit_SUITE
St1 = ?ws_conn:check_limiter([{Need, bytes_in}],
[],
fun(_, _, S) -> S end,
[],
St),
?assertEqual(blocked, ?ws_conn:info(sockstate, St1)),
?assertEqual([{active, false}], ?ws_conn:info(postponed, St1)).
t_parse_incoming(_) ->
St = ?ws_conn:parse_incoming(<<48,3>>, st()),
St1 = ?ws_conn:parse_incoming(<<0,1,116>>, St),
{Packets, St} = ?ws_conn:parse_incoming(<<48,3>>, [], st()),
{Packets1, _} = ?ws_conn:parse_incoming(<<0,1,116>>, Packets, St),
Packet = ?PUBLISH_PACKET(?QOS_0, <<"t">>, undefined, <<>>),
?assertMatch([{incoming, Packet}], ?ws_conn:info(postponed, St1)).
?assertMatch([{incoming, Packet}], Packets1).
t_parse_incoming_frame_error(_) ->
St = ?ws_conn:parse_incoming(<<3,2,1,0>>, st()),
{Packets, _St} = ?ws_conn:parse_incoming(<<3,2,1,0>>, [], st()),
FrameError = {frame_error, function_clause},
[{incoming, FrameError}] = ?ws_conn:info(postponed, St).
[{incoming, FrameError}] = Packets.
t_handle_incomming_frame_error(_) ->
FrameError = {frame_error, bad_qos},
@ -440,7 +459,9 @@ t_shutdown(_) ->
st() -> st(#{}).
st(InitFields) when is_map(InitFields) ->
{ok, St, _} = ?ws_conn:websocket_init([req, #{zone => default, listener => {ws, default}}]),
{ok, St, _} = ?ws_conn:websocket_init([req, #{zone => default,
listener => {ws, default},
limiter => limiter_cfg()}]),
maps:fold(fun(N, V, S) -> ?ws_conn:set_field(N, V, S) end,
?ws_conn:set_field(channel, channel(), St),
InitFields
@ -474,7 +495,9 @@ channel(InitFields) ->
maps:fold(fun(Field, Value, Channel) ->
emqx_channel:set_field(Field, Value, Channel)
end,
emqx_channel:init(ConnInfo, #{zone => default, listener => {ws, default}}),
emqx_channel:init(ConnInfo, #{zone => default,
listener => {ws, default},
limiter => limiter_cfg()}),
maps:merge(#{clientinfo => ClientInfo,
session => Session,
conn_state => connected
@ -533,3 +556,8 @@ ws_client(State) ->
after 1000 ->
ct:fail(ws_timeout)
end.
limiter_cfg() -> #{}.
init_limiter() ->
emqx_limiter_container:get_limiter_by_names([bytes_in, message_in], limiter_cfg()).

View File

@ -434,8 +434,15 @@ typename_to_spec("log_level()", _Mod) ->
};
typename_to_spec("rate()", _Mod) ->
#{type => string, example => <<"10M/s">>};
typename_to_spec("bucket_rate()", _Mod) ->
#{type => string, example => <<"10M/s, 100M">>};
typename_to_spec("capacity()", _Mod) ->
#{type => string, example => <<"100M">>};
typename_to_spec("burst_rate()", _Mod) ->
%% 0/0s = no burst
#{type => string, example => <<"10M/1s">>};
typename_to_spec("failure_strategy()", _Mod) ->
#{type => string, example => <<"force">>};
typename_to_spec("initial()", _Mod) ->
#{type => string, example => <<"0M">>};
typename_to_spec(Name, Mod) ->
Spec = range(Name),
Spec1 = remote_module_type(Spec, Name, Mod),

View File

@ -70,13 +70,14 @@ all() ->
init_per_suite(Config) ->
ok = emqx_config:init_load(emqx_gateway_schema, ?CONF_DEFAULT),
emqx_mgmt_api_test_util:init_suite([emqx_conf, emqx_gateway]),
application:load(emqx_gateway),
emqx_mgmt_api_test_util:init_suite([emqx_conf]),
Config.
end_per_suite(Config) ->
timer:sleep(300),
{ok, _} = emqx_conf:remove([<<"gateway">>,<<"lwm2m">>], #{}),
emqx_mgmt_api_test_util:end_suite([emqx_gateway, emqx_conf]),
emqx_mgmt_api_test_util:end_suite([emqx_conf]),
Config.
init_per_testcase(_AllTestCase, Config) ->

View File

@ -1,50 +0,0 @@
##--------------------------------------------------------------------
## Emq X Rate Limiter
##--------------------------------------------------------------------
emqx_limiter {
bytes_in {
global = "100KB/10s" # token generation rate
zone.default = "100kB/10s"
zone.external = "20kB/10s"
bucket.tcp {
zone = default
aggregated = "100kB/10s,1Mb"
per_client = "100KB/10s,10Kb"
}
bucket.ssl {
zone = external
aggregated = "100kB/10s,1Mb"
per_client = "100KB/10s,10Kb"
}
}
message_in {
global = "100/10s"
zone.default = "100/10s"
bucket.bucket1 {
zone = default
aggregated = "100/10s,1000"
per_client = "100/10s,100"
}
}
connection {
global = "100/10s"
zone.default = "100/10s"
bucket.bucket1 {
zone = default
aggregated = "100/10s,1000"
per_client = "100/10s,100"
}
}
message_routing {
global = "100/10s"
zone.default = "100/10s"
bucket.bucket1 {
zone = default
aggregated = "100/10s,100"
per_client = "100/10s,10"
}
}
}

View File

@ -1,144 +0,0 @@
%%--------------------------------------------------------------------
%% Copyright (c) 2019-2021 EMQ Technologies Co., Ltd. All Rights Reserved.
%%
%% Licensed under the Apache License, Version 2.0 (the "License");
%% you may not use this file except in compliance with the License.
%% You may obtain a copy of the License at
%%
%% http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing, software
%% distributed under the License is distributed on an "AS IS" BASIS,
%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
%% See the License for the specific language governing permissions and
%% limitations under the License.
%%--------------------------------------------------------------------
-module(emqx_limiter_client).
%% API
-export([create/5, make_ref/3, consume/2]).
-export_type([limiter/0]).
%% tocket bucket algorithm
-record(limiter, { tokens :: non_neg_integer()
, rate :: float()
, capacity :: decimal()
, lasttime :: millisecond()
, ref :: ref_limiter()
}).
-record(ref, { counter :: counters:counters_ref()
, index :: index()
, rate :: decimal()
, obtained :: non_neg_integer()
}).
%% TODO
%% we should add a nop-limiter, when all the upper layers (global, zone, and buckets ) are infinity
-type limiter() :: #limiter{}.
-type ref_limiter() :: #ref{}.
-type client() :: limiter() | ref_limiter().
-type millisecond() :: non_neg_integer().
-type pause_result(Client) :: {pause, millisecond(), Client}.
-type consume_result(Client) :: {ok, Client}
| pause_result(Client).
-type decimal() :: emqx_limiter_decimal:decimal().
-type index() :: emqx_limiter_server:index().
-define(NOW, erlang:monotonic_time(millisecond)).
-define(MINIUMN_PAUSE, 100).
-import(emqx_limiter_decimal, [sub/2]).
%%--------------------------------------------------------------------
%% API
%%--------------------------------------------------------------------
-spec create(float(),
decimal(),
counters:counters_ref(),
index(),
decimal()) -> limiter().
create(Rate, Capacity, Counter, Index, CounterRate) ->
#limiter{ tokens = Capacity
, rate = Rate
, capacity = Capacity
, lasttime = ?NOW
, ref = make_ref(Counter, Index, CounterRate)
}.
-spec make_ref(counters:counters_ref(), index(), decimal()) -> ref_limiter().
make_ref(Counter, Idx, Rate) ->
#ref{counter = Counter, index = Idx, rate = Rate, obtained = 0}.
-spec consume(pos_integer(), Client) -> consume_result(Client)
when Client :: client().
consume(Need, #limiter{tokens = Tokens,
capacity = Capacity} = Limiter) ->
if Need =< Tokens ->
try_consume_counter(Need, Limiter);
Need > Capacity ->
%% FIXME
%% The client should be able to send 4kb data if the rate is configured to be 2kb/s, it just needs 2s to complete.
throw("too big request"); %% FIXME how to deal this?
true ->
try_reset(Need, Limiter)
end;
consume(Need, #ref{counter = Counter,
index = Index,
rate = Rate,
obtained = Obtained} = Ref) ->
Tokens = counters:get(Counter, Index),
if Tokens >= Need ->
counters:sub(Counter, Index, Need),
{ok, Ref#ref{obtained = Obtained + Need}};
true ->
return_pause(Need - Tokens, Rate, Ref)
end.
%%--------------------------------------------------------------------
%% Internal functions
%%--------------------------------------------------------------------
-spec try_consume_counter(pos_integer(), limiter()) -> consume_result(limiter()).
try_consume_counter(Need,
#limiter{tokens = Tokens,
ref = #ref{counter = Counter,
index = Index,
obtained = Obtained,
rate = CounterRate} = Ref} = Limiter) ->
CT = counters:get(Counter, Index),
if CT >= Need ->
counters:sub(Counter, Index, Need),
{ok, Limiter#limiter{tokens = sub(Tokens, Need),
ref = Ref#ref{obtained = Obtained + Need}}};
true ->
return_pause(Need - CT, CounterRate, Limiter)
end.
-spec try_reset(pos_integer(), limiter()) -> consume_result(limiter()).
try_reset(Need,
#limiter{tokens = Tokens,
rate = Rate,
lasttime = LastTime,
capacity = Capacity} = Limiter) ->
Now = ?NOW,
Inc = erlang:floor((Now - LastTime) * Rate / emqx_limiter_schema:minimum_period()),
Tokens2 = erlang:min(Tokens + Inc, Capacity),
if Need > Tokens2 ->
return_pause(Need, Rate, Limiter);
true ->
Limiter2 = Limiter#limiter{tokens = Tokens2,
lasttime = Now},
try_consume_counter(Need, Limiter2)
end.
-spec return_pause(pos_integer(), decimal(), Client) -> pause_result(Client)
when Client :: client().
return_pause(_, infinity, Limiter) ->
%% workaround when emqx_limiter_server's rate is infinity
{pause, ?MINIUMN_PAUSE, Limiter};
return_pause(Diff, Rate, Limiter) ->
Pause = erlang:round(Diff * emqx_limiter_schema:minimum_period() / Rate),
{pause, erlang:max(Pause, ?MINIUMN_PAUSE), Limiter}.

View File

@ -1,140 +0,0 @@
%%--------------------------------------------------------------------
%% Copyright (c) 2020-2021 EMQ Technologies Co., Ltd. All Rights Reserved.
%%
%% Licensed under the Apache License, Version 2.0 (the "License");
%% you may not use this file except in compliance with the License.
%% You may obtain a copy of the License at
%%
%% http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing, software
%% distributed under the License is distributed on an "AS IS" BASIS,
%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
%% See the License for the specific language governing permissions and
%% limitations under the License.
%%--------------------------------------------------------------------
-module(emqx_limiter_schema).
-include_lib("typerefl/include/types.hrl").
-export([ roots/0, fields/1, to_rate/1
, to_bucket_rate/1, minimum_period/0]).
-define(KILOBYTE, 1024).
-type limiter_type() :: bytes_in
| message_in
| connection
| message_routing.
-type bucket_name() :: atom().
-type zone_name() :: atom().
-type rate() :: infinity | float().
-type bucket_rate() :: list(infinity | number()).
-typerefl_from_string({rate/0, ?MODULE, to_rate}).
-typerefl_from_string({bucket_rate/0, ?MODULE, to_bucket_rate}).
-reflect_type([ rate/0
, bucket_rate/0
]).
-export_type([limiter_type/0, bucket_name/0, zone_name/0]).
-import(emqx_schema, [sc/2, map/2]).
roots() -> [emqx_limiter].
fields(emqx_limiter) ->
[ {bytes_in, sc(ref(limiter), #{})}
, {message_in, sc(ref(limiter), #{})}
, {connection, sc(ref(limiter), #{})}
, {message_routing, sc(ref(limiter), #{})}
];
fields(limiter) ->
[ {global, sc(rate(), #{})}
, {zone, sc(map("zone name", rate()), #{})}
, {bucket, sc(map("bucket id", ref(bucket)),
#{desc => "Token Buckets"})}
];
fields(bucket) ->
[ {zone, sc(atom(), #{desc => "the zone which the bucket in"})}
, {aggregated, sc(bucket_rate(), #{})}
, {per_client, sc(bucket_rate(), #{})}
].
%% minimum period is 100ms
minimum_period() ->
100.
%%--------------------------------------------------------------------
%% Internal functions
%%--------------------------------------------------------------------
ref(Field) -> hoconsc:ref(?MODULE, Field).
to_rate(Str) ->
Tokens = [string:trim(T) || T <- string:tokens(Str, "/")],
case Tokens of
["infinity"] ->
{ok, infinity};
[Quota, Interval] ->
{ok, Val} = to_quota(Quota),
case emqx_schema:to_duration_ms(Interval) of
{ok, Ms} when Ms > 0 ->
{ok, Val * minimum_period() / Ms};
_ ->
{error, Str}
end;
_ ->
{error, Str}
end.
to_bucket_rate(Str) ->
Tokens = [string:trim(T) || T <- string:tokens(Str, "/,")],
case Tokens of
[Rate, Capa] ->
{ok, infinity} = to_quota(Rate),
{ok, CapaVal} = to_quota(Capa),
if CapaVal =/= infinity ->
{ok, [infinity, CapaVal]};
true ->
{error, Str}
end;
[Quota, Interval, Capacity] ->
{ok, Val} = to_quota(Quota),
case emqx_schema:to_duration_ms(Interval) of
{ok, Ms} when Ms > 0 ->
{ok, CapaVal} = to_quota(Capacity),
{ok, [Val * minimum_period() / Ms, CapaVal]};
_ ->
{error, Str}
end;
_ ->
{error, Str}
end.
to_quota(Str) ->
{ok, MP} = re:compile("^\s*(?:(?:([1-9][0-9]*)([a-zA-z]*))|infinity)\s*$"),
Result = re:run(Str, MP, [{capture, all_but_first, list}]),
case Result of
{match, [Quota, Unit]} ->
Val = erlang:list_to_integer(Quota),
Unit2 = string:to_lower(Unit),
{ok, apply_unit(Unit2, Val)};
{match, [Quota]} ->
{ok, erlang:list_to_integer(Quota)};
{match, []} ->
{ok, infinity};
_ ->
{error, Str}
end.
apply_unit("", Val) -> Val;
apply_unit("kb", Val) -> Val * ?KILOBYTE;
apply_unit("mb", Val) -> Val * ?KILOBYTE * ?KILOBYTE;
apply_unit("gb", Val) -> Val * ?KILOBYTE * ?KILOBYTE * ?KILOBYTE;
apply_unit(Unit, _) -> throw("invalid unit:" ++ Unit).

View File

@ -1,426 +0,0 @@
%%--------------------------------------------------------------------
%% Copyright (c) 2020-2021 EMQ Technologies Co., Ltd. All Rights Reserved.
%%
%% Licensed under the Apache License, Version 2.0 (the "License");
%% you may not use this file except in compliance with the License.
%% You may obtain a copy of the License at
%%
%% http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing, software
%% distributed under the License is distributed on an "AS IS" BASIS,
%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
%% See the License for the specific language governing permissions and
%% limitations under the License.
%%--------------------------------------------------------------------
%% A hierachical token bucket algorithm
%% Note: this is not the linux HTB algorithm(http://luxik.cdi.cz/~devik/qos/htb/manual/theory.htm)
%% Algorithm:
%% 1. the root node periodically generates tokens and then distributes them
%% just like the oscillation of water waves
%% 2. the leaf node has a counter, which is the place where the token is actually held.
%% 3. other nodes only play the role of transmission, and the rate of the node is like a valve,
%% limiting the oscillation transmitted from the parent node
-module(emqx_limiter_server).
-behaviour(gen_server).
-include_lib("emqx/include/logger.hrl").
%% gen_server callbacks
-export([init/1, handle_call/3, handle_cast/2, handle_info/2,
terminate/2, code_change/3, format_status/2]).
-export([ start_link/1, connect/2, info/2
, name/1]).
-record(root, { rate :: rate() %% number of tokens generated per period
, period :: pos_integer() %% token generation interval(second)
, childs :: list(node_id()) %% node children
, consumed :: non_neg_integer()
}).
-record(zone, { id :: pos_integer()
, name :: zone_name()
, rate :: rate()
, obtained :: non_neg_integer() %% number of tokens obtained
, childs :: list(node_id())
}).
-record(bucket, { id :: pos_integer()
, name :: bucket_name()
, rate :: rate()
, obtained :: non_neg_integer()
, correction :: emqx_limiter_decimal:zero_or_float() %% token correction value
, capacity :: capacity()
, counter :: counters:counters_ref()
, index :: index()
}).
-record(state, { root :: undefined | root()
, counter :: undefined | counters:counters_ref() %% current counter to alloc
, index :: index()
, zones :: #{zone_name() => node_id()}
, nodes :: nodes()
, type :: limiter_type()
}).
%% maybe use maps is better, but record is fastter
-define(FIELD_OBTAINED, #zone.obtained).
-define(GET_FIELD(F, Node), element(F, Node)).
-define(CALL(Type, Msg), gen_server:call(name(Type), {?FUNCTION_NAME, Msg})).
-type node_id() :: pos_integer().
-type root() :: #root{}.
-type zone() :: #zone{}.
-type bucket() :: #bucket{}.
-type node_data() :: zone() | bucket().
-type nodes() :: #{node_id() => node_data()}.
-type zone_name() :: emqx_limiter_schema:zone_name().
-type limiter_type() :: emqx_limiter_schema:limiter_type().
-type bucket_name() :: emqx_limiter_schema:bucket_name().
-type rate() :: decimal().
-type flow() :: decimal().
-type capacity() :: decimal().
-type decimal() :: emqx_limiter_decimal:decimal().
-type state() :: #state{}.
-type index() :: pos_integer().
-export_type([index/0]).
-import(emqx_limiter_decimal, [add/2, sub/2, mul/2, add_to_counter/3, put_to_counter/3]).
%%--------------------------------------------------------------------
%% API
%%--------------------------------------------------------------------
-spec connect(limiter_type(), bucket_name()) -> emqx_limiter_client:client().
connect(Type, Bucket) ->
#{zone := Zone,
aggregated := [Aggr, Capacity],
per_client := [Client, ClientCapa]} = emqx:get_config([emqx_limiter, Type, bucket, Bucket]),
case emqx_limiter_manager:find_counter(Type, Zone, Bucket) of
{ok, Counter, Idx, Rate} ->
if Client =/= infinity andalso (Client < Aggr orelse ClientCapa < Capacity) ->
emqx_limiter_client:create(Client, ClientCapa, Counter, Idx, Rate);
true ->
emqx_limiter_client:make_ref(Counter, Idx, Rate)
end;
_ ->
?LOG(error, "can't find the bucket:~p which type is:~p~n", [Bucket, Type]),
throw("invalid bucket")
end.
-spec info(limiter_type(), atom()) -> term().
info(Type, Info) ->
?CALL(Type, Info).
-spec name(limiter_type()) -> atom().
name(Type) ->
erlang:list_to_atom(io_lib:format("~s_~s", [?MODULE, Type])).
%%--------------------------------------------------------------------
%% @doc
%% Starts the server
%% @end
%%--------------------------------------------------------------------
-spec start_link(limiter_type()) -> _.
start_link(Type) ->
gen_server:start_link({local, name(Type)}, ?MODULE, [Type], []).
%%--------------------------------------------------------------------
%%% gen_server callbacks
%%--------------------------------------------------------------------
%%--------------------------------------------------------------------
%% @private
%% @doc
%% Initializes the server
%% @end
%%--------------------------------------------------------------------
-spec init(Args :: term()) -> {ok, State :: term()} |
{ok, State :: term(), Timeout :: timeout()} |
{ok, State :: term(), hibernate} |
{stop, Reason :: term()} |
ignore.
init([Type]) ->
State = #state{zones = #{},
nodes = #{},
type = Type,
index = 1},
State2 = init_tree(Type, State),
oscillate(State2#state.root#root.period),
{ok, State2}.
%%--------------------------------------------------------------------
%% @private
%% @doc
%% Handling call messages
%% @end
%%--------------------------------------------------------------------
-spec handle_call(Request :: term(), From :: {pid(), term()}, State :: term()) ->
{reply, Reply :: term(), NewState :: term()} |
{reply, Reply :: term(), NewState :: term(), Timeout :: timeout()} |
{reply, Reply :: term(), NewState :: term(), hibernate} |
{noreply, NewState :: term()} |
{noreply, NewState :: term(), Timeout :: timeout()} |
{noreply, NewState :: term(), hibernate} |
{stop, Reason :: term(), Reply :: term(), NewState :: term()} |
{stop, Reason :: term(), NewState :: term()}.
handle_call(Req, _From, State) ->
?LOG(error, "Unexpected call: ~p", [Req]),
{reply, ignored, State}.
%%--------------------------------------------------------------------
%% @private
%% @doc
%% Handling cast messages
%% @end
%%--------------------------------------------------------------------
-spec handle_cast(Request :: term(), State :: term()) ->
{noreply, NewState :: term()} |
{noreply, NewState :: term(), Timeout :: timeout()} |
{noreply, NewState :: term(), hibernate} |
{stop, Reason :: term(), NewState :: term()}.
handle_cast(Req, State) ->
?LOG(error, "Unexpected cast: ~p", [Req]),
{noreply, State}.
%%--------------------------------------------------------------------
%% @private
%% @doc
%% Handling all non call/cast messages
%% @end
%%--------------------------------------------------------------------
-spec handle_info(Info :: timeout() | term(), State :: term()) ->
{noreply, NewState :: term()} |
{noreply, NewState :: term(), Timeout :: timeout()} |
{noreply, NewState :: term(), hibernate} |
{stop, Reason :: normal | term(), NewState :: term()}.
handle_info(oscillate, State) ->
{noreply, oscillation(State)};
handle_info(Info, State) ->
?LOG(error, "Unexpected info: ~p", [Info]),
{noreply, State}.
%%--------------------------------------------------------------------
%% @private
%% @doc
%% This function is called by a gen_server when it is about to
%% terminate. It should be the opposite of Module:init/1 and do any
%% necessary cleaning up. When it returns, the gen_server terminates
%% with Reason. The return value is ignored.
%% @end
%%--------------------------------------------------------------------
-spec terminate(Reason :: normal | shutdown | {shutdown, term()} | term(),
State :: term()) -> any().
terminate(_Reason, _State) ->
ok.
%%--------------------------------------------------------------------
%% @private
%% @doc
%% Convert process state when code is changed
%% @end
%%--------------------------------------------------------------------
-spec code_change(OldVsn :: term() | {down, term()},
State :: term(),
Extra :: term()) -> {ok, NewState :: term()} |
{error, Reason :: term()}.
code_change(_OldVsn, State, _Extra) ->
{ok, State}.
%%--------------------------------------------------------------------
%% @private
%% @doc
%% This function is called for changing the form and appearance
%% of gen_server status when it is returned from sys:get_status/1,2
%% or when it appears in termination error logs.
%% @end
%%--------------------------------------------------------------------
-spec format_status(Opt :: normal | terminate,
Status :: list()) -> Status :: term().
format_status(_Opt, Status) ->
Status.
%%--------------------------------------------------------------------
%%% Internal functions
%%--------------------------------------------------------------------
oscillate(Interval) ->
erlang:send_after(Interval, self(), ?FUNCTION_NAME).
%% @doc generate tokens, and then spread to leaf nodes
-spec oscillation(state()) -> state().
oscillation(#state{root = #root{rate = Flow,
period = Interval,
childs = ChildIds,
consumed = Consumed} = Root,
nodes = Nodes} = State) ->
oscillate(Interval),
Childs = get_orderd_childs(ChildIds, Nodes),
{Alloced, Nodes2} = transverse(Childs, Flow, 0, Nodes),
State#state{nodes = Nodes2,
root = Root#root{consumed = Consumed + Alloced}}.
%% @doc horizontal spread
-spec transverse(list(node_data()),
flow(),
non_neg_integer(),
nodes()) -> {non_neg_integer(), nodes()}.
transverse([H | T], InFlow, Alloced, Nodes) when InFlow > 0 ->
{NodeAlloced, Nodes2} = longitudinal(H, InFlow, Nodes),
InFlow2 = sub(InFlow, NodeAlloced),
Alloced2 = Alloced + NodeAlloced,
transverse(T, InFlow2, Alloced2, Nodes2);
transverse(_, _, Alloced, Nodes) ->
{Alloced, Nodes}.
%% @doc vertical spread
-spec longitudinal(node_data(), flow(), nodes()) ->
{non_neg_integer(), nodes()}.
longitudinal(#zone{id = Id,
rate = Rate,
obtained = Obtained,
childs = ChildIds} = Node, InFlow, Nodes) ->
Flow = erlang:min(InFlow, Rate),
if Flow > 0 ->
Childs = get_orderd_childs(ChildIds, Nodes),
{Alloced, Nodes2} = transverse(Childs, Flow, 0, Nodes),
if Alloced > 0 ->
{Alloced,
Nodes2#{Id => Node#zone{obtained = Obtained + Alloced}}};
true ->
%% childs are empty or all counter childs are full
{0, Nodes}
end;
true ->
{0, Nodes}
end;
longitudinal(#bucket{id = Id,
rate = Rate,
capacity = Capacity,
correction = Correction,
counter = Counter,
index = Index,
obtained = Obtained} = Node, InFlow, Nodes) ->
Flow = add(erlang:min(InFlow, Rate), Correction),
Tokens = counters:get(Counter, Index),
%% toknes's value mayb be a negative value(stolen from the future)
Avaiable = erlang:min(if Tokens < 0 ->
add(Capacity, Tokens);
true ->
sub(Capacity, Tokens)
end, Flow),
FixAvaiable = erlang:min(Capacity, Avaiable),
if FixAvaiable > 0 ->
{Alloced, Decimal} = add_to_counter(Counter, Index, FixAvaiable),
{Alloced,
Nodes#{Id => Node#bucket{obtained = Obtained + Alloced,
correction = Decimal}}};
true ->
{0, Nodes}
end.
-spec get_orderd_childs(list(node_id()), nodes()) -> list(node_data()).
get_orderd_childs(Ids, Nodes) ->
Childs = [maps:get(Id, Nodes) || Id <- Ids],
%% sort by obtained, avoid node goes hungry
lists:sort(fun(A, B) ->
?GET_FIELD(?FIELD_OBTAINED, A) < ?GET_FIELD(?FIELD_OBTAINED, B)
end,
Childs).
-spec init_tree(emqx_limiter_schema:limiter_type(), state()) -> state().
init_tree(Type, State) ->
#{global := Global,
zone := Zone,
bucket := Bucket} = emqx:get_config([emqx_limiter, Type]),
{Factor, Root} = make_root(Global, Zone),
State2 = State#state{root = Root},
{NodeId, State3} = make_zone(maps:to_list(Zone), Factor, 1, State2),
State4 = State3#state{counter = counters:new(maps:size(Bucket),
[write_concurrency])},
make_bucket(maps:to_list(Bucket), Factor, NodeId, State4).
-spec make_root(decimal(), hocon:config()) -> {number(), root()}.
make_root(Rate, Zone) ->
ZoneNum = maps:size(Zone),
Childs = lists:seq(1, ZoneNum),
MiniPeriod = emqx_limiter_schema:minimum_period(),
if Rate >= 1 ->
{1, #root{rate = Rate,
period = MiniPeriod,
childs = Childs,
consumed = 0}};
true ->
Factor = 1 / Rate,
{Factor, #root{rate = 1,
period = erlang:floor(Factor * MiniPeriod),
childs = Childs,
consumed = 0}}
end.
make_zone([{Name, Rate} | T], Factor, NodeId, State) ->
#state{zones = Zones, nodes = Nodes} = State,
Zone = #zone{id = NodeId,
name = Name,
rate = mul(Rate, Factor),
obtained = 0,
childs = []},
State2 = State#state{zones = Zones#{Name => NodeId},
nodes = Nodes#{NodeId => Zone}},
make_zone(T, Factor, NodeId + 1, State2);
make_zone([], _, NodeId, State2) ->
{NodeId, State2}.
make_bucket([{Name, Conf} | T], Factor, NodeId, State) ->
#{zone := ZoneName,
aggregated := [Rate, Capacity]} = Conf,
{Counter, Idx, State2} = alloc_counter(ZoneName, Name, Rate, State),
Node = #bucket{ id = NodeId
, name = Name
, rate = mul(Rate, Factor)
, obtained = 0
, correction = 0
, capacity = Capacity
, counter = Counter
, index = Idx},
State3 = add_zone_child(NodeId, Node, ZoneName, State2),
make_bucket(T, Factor, NodeId + 1, State3);
make_bucket([], _, _, State) ->
State.
-spec alloc_counter(zone_name(), bucket_name(), rate(), state()) ->
{counters:counters_ref(), pos_integer(), state()}.
alloc_counter(Zone, Bucket, Rate,
#state{type = Type, counter = Counter, index = Index} = State) ->
Path = emqx_limiter_manager:make_path(Type, Zone, Bucket),
case emqx_limiter_manager:find_counter(Path) of
undefined ->
init_counter(Path, Counter, Index,
Rate, State#state{index = Index + 1});
{ok, ECounter, EIndex, _} ->
init_counter(Path, ECounter, EIndex, Rate, State)
end.
init_counter(Path, Counter, Index, Rate, State) ->
_ = put_to_counter(Counter, Index, 0),
emqx_limiter_manager:insert_counter(Path, Counter, Index, Rate),
{Counter, Index, State}.
-spec add_zone_child(node_id(), bucket(), zone_name(), state()) -> state().
add_zone_child(NodeId, Bucket, Name, #state{zones = Zones, nodes = Nodes} = State) ->
ZoneId = maps:get(Name, Zones),
#zone{childs = Childs} = Zone = maps:get(ZoneId, Nodes),
Nodes2 = Nodes#{ZoneId => Zone#zone{childs = [NodeId | Childs]},
NodeId => Bucket},
State#state{nodes = Nodes2}.

View File

@ -1,272 +0,0 @@
%%--------------------------------------------------------------------
%% Copyright (c) 2020-2021 EMQ Technologies Co., Ltd. All Rights Reserved.
%%
%% Licensed under the Apache License, Version 2.0 (the "License");
%% you may not use this file except in compliance with the License.
%% You may obtain a copy of the License at
%%
%% http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing, software
%% distributed under the License is distributed on an "AS IS" BASIS,
%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
%% See the License for the specific language governing permissions and
%% limitations under the License.
%%--------------------------------------------------------------------
-module(emqx_limiter_SUITE).
-compile(export_all).
-compile(nowarn_export_all).
-define(APP, emqx_limiter).
-include_lib("eunit/include/eunit.hrl").
-include_lib("common_test/include/ct.hrl").
-define(BASE_CONF, <<"""
emqx_limiter {
bytes_in {global = \"100KB/10s\"
zone.default = \"100kB/10s\"
zone.external = \"20kB/10s\"
bucket.tcp {zone = default
aggregated = \"100kB/10s,1Mb\"
per_client = \"100KB/10s,10Kb\"}
bucket.ssl {zone = external
aggregated = \"100kB/10s,1Mb\"
per_client = \"100KB/10s,10Kb\"}
}
message_in {global = \"100/10s\"
zone.default = \"100/10s\"
bucket.bucket1 {zone = default
aggregated = \"100/10s,1000\"
per_client = \"100/10s,100\"}
}
connection {global = \"100/10s\"
zone.default = \"100/10s\"
bucket.bucket1 {zone = default
aggregated = \"100/10s,100\"
per_client = \"100/10s,10\"
}
}
message_routing {global = \"100/10s\"
zone.default = \"100/10s\"
bucket.bucket1 {zone = default
aggregated = \"100/10s,100\"
per_client = \"100/10s,10\"
}
}
}""">>).
-define(LOGT(Format, Args), ct:pal("TEST_SUITE: " ++ Format, Args)).
-record(client_options, { interval :: non_neg_integer()
, per_cost :: non_neg_integer()
, type :: atom()
, bucket :: atom()
, lifetime :: non_neg_integer()
, rates :: list(tuple())
}).
-record(client_state, { client :: emqx_limiter_client:limiter()
, pid :: pid()
, got :: non_neg_integer()
, options :: #client_options{}}).
%%--------------------------------------------------------------------
%% Setups
%%--------------------------------------------------------------------
all() -> emqx_common_test_helpers:all(?MODULE).
init_per_suite(Config) ->
ok = emqx_config:init_load(emqx_limiter_schema, ?BASE_CONF),
emqx_common_test_helpers:start_apps([?APP]),
Config.
end_per_suite(_Config) ->
emqx_common_test_helpers:stop_apps([?APP]).
init_per_testcase(_TestCase, Config) ->
Config.
%%--------------------------------------------------------------------
%% Test Cases
%%--------------------------------------------------------------------
t_un_overload(_) ->
Conf = emqx:get_config([emqx_limiter]),
Conn = #{global => to_rate("infinity"),
zone => #{z1 => to_rate("1000/1s"),
z2 => to_rate("1000/1s")},
bucket => #{b1 => #{zone => z1,
aggregated => to_bucket_rate("100/1s, 500"),
per_client => to_bucket_rate("10/1s, 50")},
b2 => #{zone => z2,
aggregated => to_bucket_rate("500/1s, 500"),
per_client => to_bucket_rate("100/1s, infinity")
}}},
Conf2 = Conf#{connection => Conn},
emqx_config:put([emqx_limiter], Conf2),
{ok, _} = emqx_limiter_manager:restart_server(connection),
timer:sleep(200),
B1C = #client_options{interval = 100,
per_cost = 1,
type = connection,
bucket = b1,
lifetime = timer:seconds(3),
rates = [{fun erlang:'=<'/2, ["1000/1s", "100/1s"]},
{fun erlang:'=:='/2, ["10/1s"]}]},
B2C = #client_options{interval = 100,
per_cost = 10,
type = connection,
bucket = b2,
lifetime = timer:seconds(3),
rates = [{fun erlang:'=<'/2, ["1000/1s", "500/1s"]},
{fun erlang:'=:='/2, ["100/1s"]}]},
lists:foreach(fun(_) -> start_client(B1C) end,
lists:seq(1, 10)),
lists:foreach(fun(_) -> start_client(B2C) end,
lists:seq(1, 5)),
?assert(check_client_result(10 + 5)).
t_infinity(_) ->
Conf = emqx:get_config([emqx_limiter]),
Conn = #{global => to_rate("infinity"),
zone => #{z1 => to_rate("1000/1s"),
z2 => to_rate("infinity")},
bucket => #{b1 => #{zone => z1,
aggregated => to_bucket_rate("100/1s, infinity"),
per_client => to_bucket_rate("10/1s, 100")},
b2 => #{zone => z2,
aggregated => to_bucket_rate("infinity, 600"),
per_client => to_bucket_rate("100/1s, infinity")
}}},
Conf2 = Conf#{connection => Conn},
emqx_config:put([emqx_limiter], Conf2),
{ok, _} = emqx_limiter_manager:restart_server(connection),
timer:sleep(200),
B1C = #client_options{interval = 100,
per_cost = 1,
type = connection,
bucket = b1,
lifetime = timer:seconds(3),
rates = [{fun erlang:'=<'/2, ["1000/1s", "100/1s"]},
{fun erlang:'=:='/2, ["10/1s"]}]},
B2C = #client_options{interval = 100,
per_cost = 10,
type = connection,
bucket = b2,
lifetime = timer:seconds(3),
rates = [{fun erlang:'=:='/2, ["100/1s"]}]},
lists:foreach(fun(_) -> start_client(B1C) end,
lists:seq(1, 8)),
lists:foreach(fun(_) -> start_client(B2C) end,
lists:seq(1, 4)),
?assert(check_client_result(8 + 4)).
%%--------------------------------------------------------------------
%%% Internal functions
%%--------------------------------------------------------------------
start_client(Opts) ->
Pid = self(),
erlang:spawn(fun() -> enter_client(Opts, Pid) end).
enter_client(#client_options{type = Type,
bucket = Bucket,
lifetime = Lifetime} = Opts,
Pid) ->
erlang:send_after(Lifetime, self(), stop),
erlang:send(self(), consume),
Client = emqx_limiter_server:connect(Type, Bucket),
client_loop(#client_state{client = Client,
pid = Pid,
got = 0,
options = Opts}).
client_loop(#client_state{client = Client,
got = Got,
pid = Pid,
options = #client_options{interval = Interval,
per_cost = PerCost,
lifetime = Lifetime,
rates = Rates}} = State) ->
receive
consume ->
case emqx_limiter_client:consume(PerCost, Client) of
{ok, Client2} ->
erlang:send_after(Interval, self(), consume),
client_loop(State#client_state{client = Client2,
got = Got + PerCost});
{pause, MS, Client2} ->
erlang:send_after(MS, self(), {resume, erlang:system_time(millisecond)}),
client_loop(State#client_state{client = Client2})
end;
stop ->
Rate = Got * emqx_limiter_schema:minimum_period() / Lifetime,
?LOGT("Got:~p, Rate is:~p Checks:~p~n", [Got, Rate, Rate]),
Check = check_rates(Rate, Rates),
erlang:send(Pid, {client, Check});
{resume, Begin} ->
case emqx_limiter_client:consume(PerCost, Client) of
{ok, Client2} ->
Now = erlang:system_time(millisecond),
Diff = erlang:max(0, Interval - (Now - Begin)),
erlang:send_after(Diff, self(), consume),
client_loop(State#client_state{client = Client2,
got = Got + PerCost});
{pause, MS, Client2} ->
erlang:send_after(MS, self(), {resume, Begin}),
client_loop(State#client_state{client = Client2})
end
end.
check_rates(Rate, [{Fun, Rates} | T]) ->
case lists:all(fun(E) -> Fun(Rate, to_rate(E)) end, Rates) of
true ->
check_rates(Rate, T);
false ->
false
end;
check_rates(_, _) ->
true.
check_client_result(0) ->
true;
check_client_result(N) ->
?LOGT("check_client_result:~p~n", [N]),
receive
{client, true} ->
check_client_result(N - 1);
{client, false} ->
false;
Any ->
?LOGT(">>>> other:~p~n", [Any])
after 3500 ->
?LOGT(">>>> timeout~n", []),
false
end.
to_rate(Str) ->
{ok, Rate} = emqx_limiter_schema:to_rate(Str),
Rate.
to_bucket_rate(Str) ->
{ok, Result} = emqx_limiter_schema:to_bucket_rate(Str),
Result.

View File

@ -177,6 +177,6 @@ t_keepalive(_Config) ->
[Pid] = emqx_cm:lookup_channels(list_to_binary(ClientId)),
State = sys:get_state(Pid),
ct:pal("~p~n", [State]),
?assertEqual(11000, element(2, element(5, element(11, State)))),
?assertEqual(11000, element(2, element(5, element(9, State)))),
emqtt:disconnect(C1),
ok.

View File

@ -305,7 +305,6 @@ relx_apps(ReleaseType, Edition) ->
, emqx_statsd
, emqx_prometheus
, emqx_psk
, emqx_limiter
]
++ [quicer || is_quicer_supported()]
%++ [emqx_license || is_enterprise(Edition)]

View File

@ -12,16 +12,54 @@
main(_) ->
{ok, BaseConf} = file:read_file("apps/emqx_conf/etc/emqx_conf.conf"),
Apps = filelib:wildcard("*", "apps/") -- ["emqx_machine", "emqx_conf"],
Conf = lists:foldl(fun(App, Acc) ->
Filename = filename:join([apps, App, "etc", App]) ++ ".conf",
case filelib:is_regular(Filename) of
true ->
{ok, Bin1} = file:read_file(Filename),
[Acc, io_lib:nl(), Bin1];
false -> Acc
end
end, BaseConf, Apps),
Cfgs = get_all_cfgs("apps/"),
Conf = lists:foldl(fun(CfgFile, Acc) ->
case filelib:is_regular(CfgFile) of
true ->
{ok, Bin1} = file:read_file(CfgFile),
[Acc, io_lib:nl(), Bin1];
false -> Acc
end
end, BaseConf, Cfgs),
ClusterInc = "include \"cluster-override.conf\"\n",
LocalInc = "include \"local-override.conf\"\n",
ok = file:write_file("apps/emqx_conf/etc/emqx.conf.all", [Conf, ClusterInc, LocalInc]).
get_all_cfgs(Root) ->
Apps = filelib:wildcard("*", Root) -- ["emqx_machine", "emqx_conf"],
Dirs = [filename:join([Root, App]) || App <- Apps],
lists:foldl(fun get_cfgs/2, [], Dirs).
get_all_cfgs(Dir, Cfgs) ->
Fun = fun(E, Acc) ->
Path = filename:join([Dir, E]),
get_cfgs(Path, Acc)
end,
lists:foldl(Fun, Cfgs, filelib:wildcard("*", Dir)).
get_cfgs(Dir, Cfgs) ->
case filelib:is_dir(Dir) of
false ->
Cfgs;
_ ->
Files = filelib:wildcard("*", Dir),
case lists:member("etc", Files) of
false ->
try_enter_child(Dir, Files, Cfgs);
true ->
EtcDir = filename:join([Dir, "etc"]),
%% the conf name must start with emqx
%% because there are some other conf, and these conf don't start with emqx
Confs = filelib:wildcard("emqx*.conf", EtcDir),
NewCfgs = [filename:join([EtcDir, Name]) || Name <- Confs],
try_enter_child(Dir, Files, NewCfgs ++ Cfgs)
end
end.
try_enter_child(Dir, Files, Cfgs) ->
case lists:member("src", Files) of
false ->
Cfgs;
true ->
get_all_cfgs(filename:join([Dir, "src"]), Cfgs)
end.