diff --git a/src/emqx_channel.erl b/src/emqx_channel.erl index 67ce784b1..6fbd7d217 100644 --- a/src/emqx_channel.erl +++ b/src/emqx_channel.erl @@ -118,6 +118,8 @@ info(Keys, Channel) when is_list(Keys) -> [{Key, info(Key, Channel)} || Key <- Keys]; info(conninfo, #channel{conninfo = ConnInfo}) -> ConnInfo; +info(zone, #channel{clientinfo = #{zone := Zone}}) -> + Zone; info(clientid, #channel{clientinfo = #{clientid := ClientId}}) -> ClientId; info(clientinfo, #channel{clientinfo = ClientInfo}) -> @@ -147,11 +149,6 @@ stats(#channel{session = Session})-> caps(#channel{clientinfo = #{zone := Zone}}) -> emqx_mqtt_caps:get_caps(Zone). -%% For tests -set_field(Name, Val, Channel) -> - Fields = record_info(fields, channel), - Pos = emqx_misc:index_of(Name, Fields), - setelement(Pos+1, Channel, Val). %%-------------------------------------------------------------------- %% Init the channel @@ -1309,3 +1306,11 @@ sp(false) -> 0. flag(true) -> 1; flag(false) -> 0. +%%-------------------------------------------------------------------- +%% For CT tests +%%-------------------------------------------------------------------- + +set_field(Name, Value, Channel) -> + Pos = emqx_misc:index_of(Name, record_info(fields, channel)), + setelement(Pos+1, Channel, Value). + diff --git a/src/emqx_connection.erl b/src/emqx_connection.erl index 182a12c73..b88a1dfb4 100644 --- a/src/emqx_connection.erl +++ b/src/emqx_connection.erl @@ -53,6 +53,9 @@ %% Internal callback -export([wakeup_from_hib/3]). +%% Export for CT +-export([set_field/3]). + -import(emqx_misc, [ maybe_apply/2 , start_timer/2 @@ -181,7 +184,8 @@ do_init(Parent, Transport, Socket, Options) -> }, Zone = proplists:get_value(zone, Options), ActiveN = proplists:get_value(active_n, Options, ?ACTIVE_N), - Limiter = emqx_limiter:init(Options), + PubLimit = emqx_zone:publish_limit(Zone), + Limiter = emqx_limiter:init([{pub_limit, PubLimit}|Options]), FrameOpts = emqx_zone:mqtt_frame_options(Zone), ParseState = emqx_frame:initial_parse_state(FrameOpts), Serialize = emqx_frame:serialize_fun(), @@ -491,7 +495,7 @@ parse_incoming(Data, Packets, State = #state{parse_state = ParseState}) -> parse_incoming(Rest, [Packet|Packets], NState) catch error:Reason:Stk -> - ?LOG(error, "~nParse failed for ~p~nStacktrace: ~p~nFrame data:~p", + ?LOG(error, "~nParse failed for ~p~n~p~nFrame data:~p", [Reason, Stk, Data]), {[{frame_error, Reason}|Packets], State} end. @@ -613,12 +617,12 @@ run_gc(Stats, State = #state{gc_state = GcSt}) -> case ?ENABLED(GcSt) andalso emqx_gc:run(Stats, GcSt) of false -> State; {IsGC, GcSt1} -> - IsGC andalso emqx_metrics:inc('channel.gc.cnt'), + IsGC andalso emqx_metrics:inc('channel.gc'), State#state{gc_state = GcSt1} end. check_oom(State = #state{channel = Channel}) -> - #{zone := Zone} = emqx_channel:info(clientinfo, Channel), + Zone = emqx_channel:info(zone, Channel), OomPolicy = emqx_zone:oom_policy(Zone), case ?ENABLED(OomPolicy) andalso emqx_misc:check_oom(OomPolicy) of Shutdown = {shutdown, _Reason} -> @@ -706,3 +710,11 @@ stop(Reason, State) -> stop(Reason, Reply, State) -> {stop, Reason, Reply, State}. +%%-------------------------------------------------------------------- +%% For CT tests +%%-------------------------------------------------------------------- + +set_field(Name, Value, State) -> + Pos = emqx_misc:index_of(Name, record_info(fields, state)), + setelement(Pos+1, State, Value). + diff --git a/src/emqx_limiter.erl b/src/emqx_limiter.erl index a063bcad2..ee8867613 100644 --- a/src/emqx_limiter.erl +++ b/src/emqx_limiter.erl @@ -37,8 +37,7 @@ -spec(init(proplists:proplist()) -> maybe(limiter())). init(Options) -> - Zone = proplists:get_value(zone, Options), - Pl = emqx_zone:publish_limit(Zone), + Pl = proplists:get_value(pub_limit, Options), Rl = proplists:get_value(rate_limit, Options), case ?ENABLED(Pl) or ?ENABLED(Rl) of true -> #limiter{pub_limit = init_limit(Pl), diff --git a/src/emqx_session.erl b/src/emqx_session.erl index 234ba2041..9561efef9 100644 --- a/src/emqx_session.erl +++ b/src/emqx_session.erl @@ -651,8 +651,7 @@ age(Now, Ts) -> Now - Ts. %% For CT tests %%-------------------------------------------------------------------- -set_field(Name, Val, Channel) -> - Fields = record_info(fields, session), - Pos = emqx_misc:index_of(Name, Fields), - setelement(Pos+1, Channel, Val). +set_field(Name, Value, Session) -> + Pos = emqx_misc:index_of(Name, record_info(fields, session)), + setelement(Pos+1, Session, Value). diff --git a/src/emqx_ws_connection.erl b/src/emqx_ws_connection.erl index ee55857da..a67d5717c 100644 --- a/src/emqx_ws_connection.erl +++ b/src/emqx_ws_connection.erl @@ -45,13 +45,16 @@ , terminate/3 ]). +%% Export for CT +-export([set_field/3]). + -import(emqx_misc, [ maybe_apply/2 , start_timer/2 ]). -record(state, { - %% Peername of the ws connection. + %% Peername of the ws connection peername :: emqx_types:peername(), %% Sockname of the ws connection sockname :: emqx_types:peername(), @@ -65,26 +68,28 @@ limit_timer :: maybe(reference()), %% Parse State parse_state :: emqx_frame:parse_state(), - %% Serialize function + %% Serialize Fun serialize :: emqx_frame:serialize_fun(), %% Channel channel :: emqx_channel:channel(), %% GC State gc_state :: maybe(emqx_gc:gc_state()), - %% Out Pending Packets - pendings :: list(emqx_types:packet()), + %% Postponed Packets|Cmds|Events + postponed :: list(emqx_types:packet()|ws_cmd()|tuple()), %% Stats Timer stats_timer :: disabled | maybe(reference()), %% Idle Timeout idle_timeout :: timeout(), %% Idle Timer idle_timer :: reference(), - %% The stop reason + %% The Stop Reason stop_reason :: term() }). -type(state() :: #state{}). +-type(ws_cmd() :: {active, boolean()}|close). + -define(ACTIVE_N, 100). -define(INFO_KEYS, [socktype, peername, sockname, sockstate, active_n]). -define(SOCK_STATS, [recv_oct, recv_cnt, send_oct, send_cnt]). @@ -93,7 +98,7 @@ -define(ENABLED(X), (X =/= undefined)). %%-------------------------------------------------------------------- -%% API +%% Info, Stats %%-------------------------------------------------------------------- -spec(info(pid()|state()) -> emqx_types:infos()). @@ -120,6 +125,12 @@ info(limiter, #state{limiter = Limiter}) -> maybe_apply(fun emqx_limiter:info/1, Limiter); info(channel, #state{channel = Channel}) -> emqx_channel:info(Channel); +info(gc_state, #state{gc_state = GcSt}) -> + maybe_apply(fun emqx_gc:info/1, GcSt); +info(postponed, #state{postponed = Postponed}) -> + Postponed; +info(stats_timer, #state{stats_timer = TRef}) -> + TRef; info(stop_reason, #state{stop_reason = Reason}) -> Reason. @@ -201,8 +212,9 @@ websocket_init([Req, Opts]) -> conn_mod => ?MODULE }, Zone = proplists:get_value(zone, Opts), + PubLimit = emqx_zone:publish_limit(Zone), + Limiter = emqx_limiter:init([{pub_limit, PubLimit}|Opts]), ActiveN = proplists:get_value(active_n, Opts, ?ACTIVE_N), - Limiter = emqx_limiter:init(Opts), FrameOpts = emqx_zone:mqtt_frame_options(Zone), ParseState = emqx_frame:initial_parse_state(FrameOpts), Serialize = emqx_frame:serialize_fun(), @@ -223,7 +235,7 @@ websocket_init([Req, Opts]) -> serialize = Serialize, channel = Channel, gc_state = GcState, - pendings = [], + postponed = [], stats_timer = StatsTimer, idle_timeout = IdleTimeout, idle_timer = IdleTimer @@ -235,63 +247,60 @@ websocket_handle({binary, Data}, State) when is_list(Data) -> websocket_handle({binary, Data}, State) -> ?LOG(debug, "RECV ~p", [Data]), ok = inc_recv_stats(1, iolist_size(Data)), - parse_incoming(Data, ensure_stats_timer(State)); + NState = ensure_stats_timer(State), + return(parse_incoming(Data, NState)); %% Pings should be replied with pongs, cowboy does it automatically %% Pongs can be safely ignored. Clause here simply prevents crash. websocket_handle(Frame, State) when Frame =:= ping; Frame =:= pong -> - {ok, State}; + return(State); -websocket_handle({FrameType, _}, State) when FrameType =:= ping; - FrameType =:= pong -> - {ok, State}; +websocket_handle({Frame, _}, State) when Frame =:= ping; Frame =:= pong -> + return(State); -websocket_handle({FrameType, _}, State) -> - ?LOG(error, "Unexpected frame - ~p", [FrameType]), - stop({shutdown, unexpected_ws_frame}, State). +websocket_handle({Frame, _}, State) -> + %% TODO: should not close the ws connection + ?LOG(error, "Unexpected frame - ~p", [Frame]), + shutdown(unexpected_ws_frame, State). websocket_info({call, From, Req}, State) -> handle_call(From, Req, State); -websocket_info({cast, Msg}, State = #state{channel = Channel}) -> - handle_chan_return(emqx_channel:handle_info(Msg, Channel), State); +websocket_info({cast, rate_limit}, State) -> + Stats = #{cnt => emqx_pd:reset_counter(incoming_pubs), + oct => emqx_pd:reset_counter(incoming_bytes) + }, + NState = postpone({check_gc, Stats}, State), + return(ensure_rate_limit(Stats, NState)); -websocket_info({incoming, Packet = ?CONNECT_PACKET(ConnPkt)}, - State = #state{idle_timer = IdleTimer}) -> - ok = emqx_misc:cancel_timer(IdleTimer), +websocket_info({cast, Msg}, State) -> + handle_info(Msg, State); + +websocket_info({incoming, Packet = ?CONNECT_PACKET(ConnPkt)}, State) -> Serialize = emqx_frame:serialize_fun(ConnPkt), - NState = State#state{serialize = Serialize, - idle_timer = undefined - }, - handle_incoming(Packet, NState); + NState = State#state{serialize = Serialize}, + handle_incoming(Packet, cancel_idle_timer(NState)); websocket_info({incoming, ?PACKET(?PINGREQ)}, State) -> - reply(?PACKET(?PINGRESP), State); + return(enqueue(?PACKET(?PINGRESP), State)); websocket_info({incoming, Packet}, State) -> handle_incoming(Packet, State); -websocket_info(rate_limit, State) -> - InStats = #{cnt => emqx_pd:reset_counter(incoming_pubs), - oct => emqx_pd:reset_counter(incoming_bytes) - }, - erlang:send(self(), {check_gc, InStats}), - ensure_rate_limit(InStats, State); - websocket_info({check_gc, Stats}, State) -> - {ok, check_oom(run_gc(Stats, State))}; + return(check_oom(run_gc(Stats, State))); websocket_info(Deliver = {deliver, _Topic, _Msg}, - State = #state{active_n = ActiveN, channel = Channel}) -> + State = #state{active_n = ActiveN}) -> Delivers = [Deliver|emqx_misc:drain_deliver(ActiveN)], - Ret = emqx_channel:handle_deliver(Delivers, Channel), - handle_chan_return(Ret, State); + with_channel(handle_deliver, [Delivers], State); -websocket_info({timeout, TRef, limit_timeout}, State = #state{limit_timer = TRef}) -> +websocket_info({timeout, TRef, limit_timeout}, + State = #state{limit_timer = TRef}) -> NState = State#state{sockstate = running, limit_timer = undefined }, - {reply, [{active, true}], NState}; + return(enqueue({active, true}, NState)); websocket_info({timeout, TRef, Msg}, State) when is_reference(TRef) -> handle_timeout(TRef, Msg, State); @@ -300,7 +309,7 @@ websocket_info(Close = {close, _Reason}, State) -> handle_info(Close, State); websocket_info({shutdown, Reason}, State) -> - stop({shutdown, Reason}, State); + shutdown(Reason, State); websocket_info({stop, Reason}, State) -> stop(Reason, State); @@ -312,68 +321,70 @@ websocket_close(Reason, State) -> ?LOG(debug, "WebSocket closed due to ~p~n", [Reason]), handle_info({sock_closed, Reason}, State). -terminate(SockError, _Req, #state{channel = Channel, - stop_reason = Reason}) -> - ?LOG(debug, "Terminated for ~p, sockerror: ~p", [Reason, SockError]), +terminate(Error, _Req, #state{channel = Channel, stop_reason = Reason}) -> + ?LOG(debug, "Terminated for ~p, error: ~p", [Reason, Error]), emqx_channel:terminate(Reason, Channel). %%-------------------------------------------------------------------- %% Handle call +%%-------------------------------------------------------------------- handle_call(From, info, State) -> gen_server:reply(From, info(State)), - {ok, State}; + return(State); handle_call(From, stats, State) -> gen_server:reply(From, stats(State)), - {ok, State}; + return(State); handle_call(From, Req, State = #state{channel = Channel}) -> case emqx_channel:handle_call(Req, Channel) of {reply, Reply, NChannel} -> - _ = gen_server:reply(From, Reply), - {ok, State#state{channel = NChannel}}; - {stop, Reason, Reply, NChannel} -> - _ = gen_server:reply(From, Reply), - stop(Reason, State#state{channel = NChannel}); - {stop, Reason, Reply, OutPacket, NChannel} -> + gen_server:reply(From, Reply), + return(State#state{channel = NChannel}); + {shutdown, Reason, Reply, NChannel} -> + gen_server:reply(From, Reply), + shutdown(Reason, State#state{channel = NChannel}); + {shutdown, Reason, Reply, Packet, NChannel} -> gen_server:reply(From, Reply), NState = State#state{channel = NChannel}, - stop(Reason, enqueue(OutPacket, NState)) + shutdown(Reason, enqueue(Packet, NState)) end. %%-------------------------------------------------------------------- %% Handle Info +%%-------------------------------------------------------------------- handle_info({connack, ConnAck}, State) -> - reply(enqueue(ConnAck, State)); + return(enqueue(ConnAck, State)); handle_info({close, Reason}, State) -> - stop({shutdown, Reason}, State); + %% TODO: close ws conn? + shutdown(Reason, State); handle_info({event, connected}, State = #state{channel = Channel}) -> ClientId = emqx_channel:info(clientid, Channel), - emqx_cm:register_channel(ClientId, info(State), stats(State)), - reply(State); + ok = emqx_cm:register_channel(ClientId, info(State), stats(State)), + return(State); handle_info({event, disconnected}, State = #state{channel = Channel}) -> ClientId = emqx_channel:info(clientid, Channel), emqx_cm:set_chan_info(ClientId, info(State)), emqx_cm:connection_closed(ClientId), - reply(State); + return(State); handle_info({event, _Other}, State = #state{channel = Channel}) -> ClientId = emqx_channel:info(clientid, Channel), emqx_cm:set_chan_info(ClientId, info(State)), emqx_cm:set_chan_stats(ClientId, stats(State)), - reply(State); + return(State); -handle_info(Info, State = #state{channel = Channel}) -> - Ret = emqx_channel:handle_info(Info, Channel), - handle_chan_return(Ret, State). +handle_info(Info, State) -> + with_channel(handle_info, [Info], State). %%-------------------------------------------------------------------- %% Handle timeout +%%-------------------------------------------------------------------- handle_timeout(TRef, idle_timeout, State = #state{idle_timer = TRef}) -> shutdown(idle_timeout, State); @@ -382,33 +393,24 @@ handle_timeout(TRef, keepalive, State) when is_reference(TRef) -> RecvOct = emqx_pd:get_counter(recv_oct), handle_timeout(TRef, {keepalive, RecvOct}, State); -handle_timeout(TRef, emit_stats, State = - #state{channel = Channel, stats_timer = TRef}) -> +handle_timeout(TRef, emit_stats, State = #state{channel = Channel, + stats_timer = TRef}) -> ClientId = emqx_channel:info(clientid, Channel), emqx_cm:set_chan_stats(ClientId, stats(State)), - reply(State#state{stats_timer = undefined}); + return(State#state{stats_timer = undefined}); -handle_timeout(TRef, TMsg, State = #state{channel = Channel}) -> - Ret = emqx_channel:handle_timeout(TRef, TMsg, Channel), - handle_chan_return(Ret, State). - -%%-------------------------------------------------------------------- -%% Ensure stats timer - --compile({inline, [ensure_stats_timer/1]}). -ensure_stats_timer(State = #state{idle_timeout = Timeout, - stats_timer = undefined}) -> - State#state{stats_timer = start_timer(Timeout, emit_stats)}; -ensure_stats_timer(State) -> State. +handle_timeout(TRef, TMsg, State) -> + with_channel(handle_timeout, [TRef, TMsg], State). %%-------------------------------------------------------------------- %% Ensure rate limit +%%-------------------------------------------------------------------- ensure_rate_limit(Stats, State = #state{limiter = Limiter}) -> case ?ENABLED(Limiter) andalso emqx_limiter:check(Stats, Limiter) of - false -> {ok, State}; + false -> State; {ok, Limiter1} -> - {ok, State#state{limiter = Limiter1}}; + State#state{limiter = Limiter1}; {pause, Time, Limiter1} -> ?LOG(debug, "Pause ~pms due to rate limit", [Time]), TRef = start_timer(Time, limit_timeout), @@ -416,102 +418,108 @@ ensure_rate_limit(Stats, State = #state{limiter = Limiter}) -> limiter = Limiter1, limit_timer = TRef }, - {reply, [{active, false}], NState} + enqueue({active, false}, NState) end. %%-------------------------------------------------------------------- -%% Run GC and Check OOM +%% Run GC, Check OOM +%%-------------------------------------------------------------------- run_gc(Stats, State = #state{gc_state = GcSt}) -> case ?ENABLED(GcSt) andalso emqx_gc:run(Stats, GcSt) of - false -> State; {IsGC, GcSt1} -> - IsGC andalso emqx_metrics:inc('channel.gc.cnt'), - State#state{gc_state = GcSt1} + IsGC andalso emqx_metrics:inc('channel.gc'), + State#state{gc_state = GcSt1}; + false -> State end. check_oom(State = #state{channel = Channel}) -> - #{zone := Zone} = emqx_channel:info(clientinfo, Channel), - OomPolicy = emqx_zone:oom_policy(Zone), + OomPolicy = emqx_zone:oom_policy(emqx_channel:info(zone, Channel)), case ?ENABLED(OomPolicy) andalso emqx_misc:check_oom(OomPolicy) of Shutdown = {shutdown, _Reason} -> - erlang:send(self(), Shutdown); - _Other -> ok - end, - State. + postpone(Shutdown, State); + _Other -> State + end. %%-------------------------------------------------------------------- %% Parse incoming data +%%-------------------------------------------------------------------- parse_incoming(<<>>, State) -> - {ok, State}; + State; parse_incoming(Data, State = #state{parse_state = ParseState}) -> try emqx_frame:parse(Data, ParseState) of {more, NParseState} -> - {ok, State#state{parse_state = NParseState}}; + State#state{parse_state = NParseState}; {ok, Packet, Rest, NParseState} -> - erlang:send(self(), {incoming, Packet}), - parse_incoming(Rest, State#state{parse_state = NParseState}) + NState = State#state{parse_state = NParseState}, + parse_incoming(Rest, postpone({incoming, Packet}, NState)) catch error:Reason:Stk -> - ?LOG(error, "~nParse failed for ~p~nStacktrace: ~p~nFrame data: ~p", + ?LOG(error, "~nParse failed for ~p~n~p~nFrame data: ~p", [Reason, Stk, Data]), - self() ! {incoming, {frame_error, Reason}}, - {ok, State} + FrameError = {frame_error, Reason}, + postpone({incoming, FrameError}, State) end. %%-------------------------------------------------------------------- %% Handle incoming packet +%%-------------------------------------------------------------------- -handle_incoming(Packet, State = #state{active_n = ActiveN, channel = Channel}) +handle_incoming(Packet, State = #state{active_n = ActiveN}) when is_record(Packet, mqtt_packet) -> ?LOG(debug, "RECV ~s", [emqx_packet:format(Packet)]), ok = inc_incoming_stats(Packet), - (emqx_pd:get_counter(incoming_pubs) > ActiveN) - andalso erlang:send(self(), rate_limit), - Ret = emqx_channel:handle_in(Packet, Channel), - handle_chan_return(Ret, State); + NState = case emqx_pd:get_counter(incoming_pubs) > ActiveN of + true -> postpone({cast, rate_limit}, State); + false -> State + end, + with_channel(handle_in, [Packet], NState); -handle_incoming(FrameError, State = #state{channel = Channel}) -> - handle_chan_return(emqx_channel:handle_in(FrameError, Channel), State). +handle_incoming(FrameError, State) -> + with_channel(handle_in, [FrameError], State). %%-------------------------------------------------------------------- -%% Handle channel return +%% With Channel +%%-------------------------------------------------------------------- -handle_chan_return(ok, State) -> - reply(State); -handle_chan_return({ok, NChannel}, State) -> - reply(State#state{channel= NChannel}); -handle_chan_return({ok, Replies, NChannel}, State) -> - reply(Replies, State#state{channel= NChannel}); -handle_chan_return({shutdown, Reason, NChannel}, State) -> - stop(Reason, State#state{channel = NChannel}); -handle_chan_return({shutdown, Reason, OutPacket, NChannel}, State) -> - NState = State#state{channel = NChannel}, - stop(Reason, enqueue(OutPacket, NState)). +with_channel(Fun, Args, State = #state{channel = Channel}) -> + case erlang:apply(emqx_channel, Fun, Args ++ [Channel]) of + ok -> return(State); + {ok, NChannel} -> + return(State#state{channel = NChannel}); + {ok, Replies, NChannel} -> + return(postpone(Replies, State#state{channel= NChannel})); + {shutdown, Reason, NChannel} -> + shutdown(Reason, State#state{channel = NChannel}); + {shutdown, Reason, Packet, NChannel} -> + NState = State#state{channel = NChannel}, + shutdown(Reason, postpone(Packet, NState)) + end. %%-------------------------------------------------------------------- %% Handle outgoing packets +%%-------------------------------------------------------------------- handle_outgoing(Packets, State = #state{active_n = ActiveN}) -> IoData = lists:map(serialize_and_inc_stats_fun(State), Packets), Oct = iolist_size(IoData), ok = inc_sent_stats(length(Packets), Oct), - case emqx_pd:get_counter(outgoing_pubs) > ActiveN of - true -> - OutStats = #{cnt => emqx_pd:reset_counter(outgoing_pubs), - oct => emqx_pd:reset_counter(outgoing_bytes) - }, - erlang:send(self(), {check_gc, OutStats}); - false -> ok - end, - {{binary, IoData}, ensure_stats_timer(State)}. + NState = case emqx_pd:get_counter(outgoing_pubs) > ActiveN of + true -> + Stats = #{cnt => emqx_pd:reset_counter(outgoing_pubs), + oct => emqx_pd:reset_counter(outgoing_bytes) + }, + postpone({check_gc, Stats}, State); + false -> State + end, + {{binary, IoData}, ensure_stats_timer(NState)}. serialize_and_inc_stats_fun(#state{serialize = Serialize}) -> fun(Packet) -> case Serialize(Packet) of - <<>> -> ?LOG(warning, "~s is discarded due to the frame is too large!", + <<>> -> ?LOG(warning, "~s is discarded due to the frame is too large.", [emqx_packet:format(Packet)]), <<>>; Data -> ?LOG(debug, "SEND ~s", [emqx_packet:format(Packet)]), @@ -522,6 +530,7 @@ serialize_and_inc_stats_fun(#state{serialize = Serialize}) -> %%-------------------------------------------------------------------- %% Inc incoming/outgoing stats +%%-------------------------------------------------------------------- -compile({inline, [ inc_recv_stats/2 @@ -561,46 +570,90 @@ inc_sent_stats(Cnt, Oct) -> emqx_metrics:inc('bytes.sent', Oct). %%-------------------------------------------------------------------- -%% Reply or Stop +%% Helper functions +%%-------------------------------------------------------------------- -reply(Packet, State) when is_record(Packet, mqtt_packet) -> - reply(enqueue(Packet, State)); -reply({outgoing, Packets}, State) -> - reply(enqueue(Packets, State)); -reply(Other, State) when is_tuple(Other) -> - self() ! Other, - reply(State); +-compile({inline, [cancel_idle_timer/1, ensure_stats_timer/1]}). -reply([], State) -> - reply(State); -reply([Packet|More], State) when is_record(Packet, mqtt_packet) -> - reply(More, enqueue(Packet, State)); -reply([{outgoing, Packets}|More], State) -> - reply(More, enqueue(Packets, State)); -reply([Other|More], State) -> - self() ! Other, - reply(More, State). +%%-------------------------------------------------------------------- +%% Cancel idle timer --compile({inline, [reply/1, enqueue/2]}). +cancel_idle_timer(State = #state{idle_timer = IdleTimer}) -> + ok = emqx_misc:cancel_timer(IdleTimer), + State#state{idle_timer = undefined}. -reply(State = #state{pendings = []}) -> +%%-------------------------------------------------------------------- +%% Ensure stats timer + +ensure_stats_timer(State = #state{idle_timeout = Timeout, + stats_timer = undefined}) -> + State#state{stats_timer = start_timer(Timeout, emit_stats)}; +ensure_stats_timer(State) -> State. + +-compile({inline, [enqueue/2, return/1, shutdown/2, stop/2]}). + +%%-------------------------------------------------------------------- +%% Postpone the packet, cmd or event + +postpone(Packet, State) when is_record(Packet, mqtt_packet) -> + enqueue(Packet, State); +postpone({outgoing, Packets}, State) -> + enqueue(Packets, State); +postpone(Event, State) when is_tuple(Event) -> + enqueue(Event, State); +postpone(More, State) when is_list(More) -> + lists:foldl(fun postpone/2, State, More). + +enqueue([Packet], State = #state{postponed = Postponed}) -> + State#state{postponed = [Packet|Postponed]}; +enqueue(Packets, State = #state{postponed = Postponed}) + when is_list(Packets) -> + State#state{postponed = lists:reverse(Packets) ++ Postponed}; +enqueue(Other, State = #state{postponed = Postponed}) -> + State#state{postponed = [Other|Postponed]}. + +return(State = #state{postponed = []}) -> {ok, State}; -reply(State = #state{pendings = Pendings}) -> - {Reply, NState} = handle_outgoing(Pendings, State), - {reply, Reply, NState#state{pendings = []}}. +return(State = #state{postponed = Postponed}) -> + {Packets, Cmds, Events} = classify(Postponed, [], [], []), + ok = lists:foreach(fun trigger/1, Events), + State1 = State#state{postponed = []}, + case {Packets, Cmds} of + {[], []} -> {ok, State1}; + {[], Cmds} -> {reply, Cmds, State1}; + {Packets, Cmds} -> + {Reply, State2} = handle_outgoing(Packets, State1), + {reply, [Reply|Cmds], State2} + end. -enqueue(Packet, State) when is_record(Packet, mqtt_packet) -> - enqueue([Packet], State); -enqueue(Packets, State = #state{pendings = Pendings}) -> - State#state{pendings = lists:append(Pendings, Packets)}. +classify([], Packets, Cmds, Events) -> + {Packets, Cmds, Events}; +classify([Packet|More], Packets, Cmds, Events) + when is_record(Packet, mqtt_packet) -> + classify(More, [Packet|Packets], Cmds, Events); +classify([Cmd = {active, _}|More], Packets, Cmds, Events) -> + classify(More, Packets, [Cmd|Cmds], Events); +classify([Cmd = close|More], Packets, Cmds, Events) -> + classify(More, Packets, [Cmd|Cmds], Events); +classify([Event|More], Packets, Cmds, Events) -> + classify(More, Packets, Cmds, [Event|Events]). + +trigger(Event) -> erlang:send(self(), Event). shutdown(Reason, State) -> stop({shutdown, Reason}, State). -stop(Reason, State = #state{pendings = []}) -> +stop(Reason, State = #state{postponed = []}) -> {stop, State#state{stop_reason = Reason}}; -stop(Reason, State = #state{pendings = Pendings}) -> - {Reply, State1} = handle_outgoing(Pendings, State), - State2 = State1#state{pendings = [], stop_reason = Reason}, - {reply, [Reply, close], State2}. +stop(Reason, State = #state{postponed = Postponed}) -> + return(State#state{postponed = [close|Postponed], + stop_reason = Reason}). + +%%-------------------------------------------------------------------- +%% For CT tests +%%-------------------------------------------------------------------- + +set_field(Name, Value, State) -> + Pos = emqx_misc:index_of(Name, record_info(fields, state)), + setelement(Pos+1, State, Value). diff --git a/test/emqx_broker_SUITE.erl b/test/emqx_broker_SUITE.erl index c3fd1d8ca..8d6a2750b 100644 --- a/test/emqx_broker_SUITE.erl +++ b/test/emqx_broker_SUITE.erl @@ -89,6 +89,7 @@ t_subscribers(_) -> t_subscriptions(_) -> emqx_broker:subscribe(<<"topic">>, <<"clientid">>, #{qos => 1}), + ok = timer:sleep(100), ?assertEqual(#{qos => 1, subid => <<"clientid">>}, proplists:get_value(<<"topic">>, emqx_broker:subscriptions(self()))), ?assertEqual(#{qos => 1, subid => <<"clientid">>}, diff --git a/test/emqx_channel_SUITE.erl b/test/emqx_channel_SUITE.erl index c582b3e68..161f7d8e5 100644 --- a/test/emqx_channel_SUITE.erl +++ b/test/emqx_channel_SUITE.erl @@ -23,20 +23,6 @@ -include("emqx_mqtt.hrl"). -include_lib("eunit/include/eunit.hrl"). --define(DEFAULT_CONNINFO, - #{peername => {{127,0,0,1}, 3456}, - sockname => {{127,0,0,1}, 1883}, - conn_mod => emqx_connection, - proto_name => <<"MQTT">>, - proto_ver => ?MQTT_PROTO_V5, - clean_start => true, - keepalive => 30, - clientid => <<"clientid">>, - username => <<"username">>, - conn_props => #{}, - receive_maximum => 100, - expiry_interval => 0 - }). all() -> emqx_ct:all(?MODULE). @@ -45,40 +31,40 @@ all() -> emqx_ct:all(?MODULE). %%-------------------------------------------------------------------- init_per_suite(Config) -> - Config. - -end_per_suite(_Config) -> - ok. - -init_per_testcase(_TestCase, Config) -> %% CM Meck - ok = meck:new(emqx_cm, [passthrough, no_history]), + ok = meck:new(emqx_cm, [passthrough, no_history, no_link]), %% Access Control Meck - ok = meck:new(emqx_access_control, [passthrough, no_history]), + ok = meck:new(emqx_access_control, [passthrough, no_history, no_link]), ok = meck:expect(emqx_access_control, authenticate, fun(_) -> {ok, #{auth_result => success}} end), ok = meck:expect(emqx_access_control, check_acl, fun(_, _, _) -> allow end), %% Broker Meck - ok = meck:new(emqx_broker, [passthrough, no_history]), + ok = meck:new(emqx_broker, [passthrough, no_history, no_link]), %% Hooks Meck - ok = meck:new(emqx_hooks, [passthrough, no_history]), + ok = meck:new(emqx_hooks, [passthrough, no_history, no_link]), ok = meck:expect(emqx_hooks, run, fun(_Hook, _Args) -> ok end), ok = meck:expect(emqx_hooks, run_fold, fun(_Hook, _Args, Acc) -> Acc end), %% Session Meck - ok = meck:new(emqx_session, [passthrough, no_history]), + ok = meck:new(emqx_session, [passthrough, no_history, no_link]), %% Metrics - ok = meck:new(emqx_metrics, [passthrough, no_history]), + ok = meck:new(emqx_metrics, [passthrough, no_history, no_link]), ok = meck:expect(emqx_metrics, inc, fun(_) -> ok end), ok = meck:expect(emqx_metrics, inc, fun(_, _) -> ok end), Config. -end_per_testcase(_TestCase, Config) -> +end_per_suite(_Config) -> ok = meck:unload(emqx_access_control), ok = meck:unload(emqx_metrics), ok = meck:unload(emqx_session), ok = meck:unload(emqx_broker), ok = meck:unload(emqx_hooks), ok = meck:unload(emqx_cm), + ok. + +init_per_testcase(_TestCase, Config) -> + Config. + +end_per_testcase(_TestCase, Config) -> Config. %%-------------------------------------------------------------------- @@ -328,15 +314,6 @@ t_process_unsubscribe(_) -> %%-------------------------------------------------------------------- t_handle_deliver(_) -> - WithPacketId = fun(Msgs) -> - lists:zip(lists:seq(1, length(Msgs)), Msgs) - end, - ok = meck:expect(emqx_session, deliver, - fun(Delivers, Session) -> - Publishes = WithPacketId([Msg || {deliver, _, Msg} <- Delivers]), - {ok, Publishes, Session} - end), - ok = meck:expect(emqx_session, info, fun(retry_interval, _Session) -> 20 end), Msg0 = emqx_message:make(test, ?QOS_1, <<"t1">>, <<"qos1">>), Msg1 = emqx_message:make(test, ?QOS_2, <<"t2">>, <<"qos2">>), Delivers = [{deliver, <<"+">>, Msg0}, {deliver, <<"+">>, Msg1}], @@ -426,7 +403,7 @@ t_handle_call_discard(_) -> emqx_channel:handle_call(discard, channel()). t_handle_call_takeover_begin(_) -> - {reply, undefined, _Chan} = emqx_channel:handle_call({takeover, 'begin'}, channel()). + {reply, _Session, _Chan} = emqx_channel:handle_call({takeover, 'begin'}, channel()). t_handle_call_takeover_end(_) -> ok = meck:expect(emqx_session, takeover, fun(_) -> ok end), @@ -565,14 +542,27 @@ t_terminate(_) -> channel() -> channel(#{}). channel(InitFields) -> + ConnInfo = #{peername => {{127,0,0,1}, 3456}, + sockname => {{127,0,0,1}, 1883}, + conn_mod => emqx_connection, + proto_name => <<"MQTT">>, + proto_ver => ?MQTT_PROTO_V5, + clean_start => true, + keepalive => 30, + clientid => <<"clientid">>, + username => <<"username">>, + conn_props => #{}, + receive_maximum => 100, + expiry_interval => 0 + }, maps:fold(fun(Field, Value, Channel) -> emqx_channel:set_field(Field, Value, Channel) - end, default_channel(), InitFields). - -default_channel() -> - Channel = emqx_channel:init(?DEFAULT_CONNINFO, [{zone, zone}]), - Channel1 = emqx_channel:set_field(conn_state, connected, Channel), - emqx_channel:set_field(clientinfo, clientinfo(), Channel1). + end, + emqx_channel:init(ConnInfo, [{zone, zone}]), + maps:merge(#{clientinfo => clientinfo(), + session => session(), + conn_state => connected + }, InitFields)). clientinfo() -> clientinfo(#{}). clientinfo(InitProps) -> diff --git a/test/emqx_limiter_SUITE.erl b/test/emqx_limiter_SUITE.erl index dc291d6bc..46c0f2070 100644 --- a/test/emqx_limiter_SUITE.erl +++ b/test/emqx_limiter_SUITE.erl @@ -24,15 +24,12 @@ all() -> emqx_ct:all(?MODULE). init_per_testcase(_TestCase, Config) -> - ok = meck:new(emqx_zone, [passthrough, no_history]), Config. end_per_testcase(_TestCase, _Config) -> - ok = meck:unload(emqx_zone). + ok. t_info(_) -> - meck:expect(emqx_zone, publish_limit, fun(_) -> {1, 10} end), - Limiter = emqx_limiter:init([{rate_limit, {100, 1000}}]), #{pub_limit := #{rate := 1, burst := 10, tokens := 10 @@ -41,22 +38,21 @@ t_info(_) -> burst := 1000, tokens := 1000 } - } = emqx_limiter:info(Limiter). + } = emqx_limiter:info(limiter()). t_check(_) -> - meck:expect(emqx_zone, publish_limit, fun(_) -> {1, 10} end), - Limiter = emqx_limiter:init([{rate_limit, {100, 1000}}]), lists:foreach(fun(I) -> - {ok, Limiter1} = emqx_limiter:check(#{cnt => I, oct => I*100}, Limiter), + {ok, Limiter} = emqx_limiter:check(#{cnt => I, oct => I*100}, limiter()), #{pub_limit := #{tokens := Cnt}, rate_limit := #{tokens := Oct} - } = emqx_limiter:info(Limiter1), + } = emqx_limiter:info(Limiter), ?assertEqual({10 - I, 1000 - I*100}, {Cnt, Oct}) end, lists:seq(1, 10)). t_check_pause(_) -> - meck:expect(emqx_zone, publish_limit, fun(_) -> {1, 10} end), - Limiter = emqx_limiter:init([{rate_limit, {100, 1000}}]), - {pause, 1000, _} = emqx_limiter:check(#{cnt => 11, oct => 2000}, Limiter), - {pause, 2000, _} = emqx_limiter:check(#{cnt => 10, oct => 1200}, Limiter). + {pause, 1000, _} = emqx_limiter:check(#{cnt => 11, oct => 2000}, limiter()), + {pause, 2000, _} = emqx_limiter:check(#{cnt => 10, oct => 1200}, limiter()). + +limiter() -> + emqx_limiter:init([{pub_limit, {1, 10}}, {rate_limit, {100, 1000}}]). diff --git a/test/emqx_session_SUITE.erl b/test/emqx_session_SUITE.erl index 1112f21ec..9f9aff825 100644 --- a/test/emqx_session_SUITE.erl +++ b/test/emqx_session_SUITE.erl @@ -29,16 +29,21 @@ all() -> emqx_ct:all(?MODULE). %% CT callbacks %%-------------------------------------------------------------------- -init_per_testcase(_TestCase, Config) -> - %% Meck Broker - ok = meck:new(emqx_broker, [passthrough, no_history]), - ok = meck:new(emqx_hooks, [passthrough, no_history]), +init_per_suite(Config) -> + %% Broker + ok = meck:new(emqx_broker, [passthrough, no_history, no_link]), + ok = meck:new(emqx_hooks, [passthrough, no_history, no_link]), ok = meck:expect(emqx_hooks, run, fun(_Hook, _Args) -> ok end), Config. -end_per_testcase(_TestCase, Config) -> +end_per_suite(_Config) -> ok = meck:unload(emqx_broker), - ok = meck:unload(emqx_hooks), + ok = meck:unload(emqx_hooks). + +init_per_testcase(_TestCase, Config) -> + Config. + +end_per_testcase(_TestCase, Config) -> Config. %%-------------------------------------------------------------------- @@ -330,7 +335,7 @@ t_replay(_) -> {ok, Pubs, Session1} = emqx_session:deliver(Delivers, session()), Msg = emqx_message:make(clientid, ?QOS_1, <<"t1">>, <<"payload">>), Session2 = emqx_session:enqueue(Msg, Session1), - Pubs1 = [{I, emqx_message:set_flag(dup, Msg)} || {I, Msg} <- Pubs], + Pubs1 = [{I, emqx_message:set_flag(dup, M)} || {I, M} <- Pubs], {ok, ReplayPubs, Session3} = emqx_session:replay(Session2), ?assertEqual(Pubs1 ++ [{3, Msg}], ReplayPubs), ?assertEqual(3, emqx_session:info(inflight_cnt, Session3)). diff --git a/test/emqx_ws_connection_SUITE.erl b/test/emqx_ws_connection_SUITE.erl index ae09626ee..f1a1e2347 100644 --- a/test/emqx_ws_connection_SUITE.erl +++ b/test/emqx_ws_connection_SUITE.erl @@ -16,6 +16,7 @@ -module(emqx_ws_connection_SUITE). +-include("emqx.hrl"). -include("emqx_mqtt.hrl"). -include_lib("eunit/include/eunit.hrl"). @@ -25,12 +26,15 @@ -import(emqx_ws_connection, [ websocket_handle/2 , websocket_info/2 + , websocket_close/2 ]). -define(STATS_KEYS, [recv_oct, recv_cnt, send_oct, send_cnt, recv_pkt, recv_msg, send_pkt, send_msg ]). +-define(ws_conn, emqx_ws_connection). + all() -> emqx_ct:all(?MODULE). %%-------------------------------------------------------------------- @@ -38,184 +42,376 @@ all() -> emqx_ct:all(?MODULE). %%-------------------------------------------------------------------- init_per_suite(Config) -> - Config. - -end_per_suite(_Config) -> - ok. - -init_per_testcase(_TestCase, Config) -> - %% Meck CowboyReq - ok = meck:new(cowboy_req, [passthrough, no_history]), + %% Mock cowboy_req + ok = meck:new(cowboy_req, [passthrough, no_history, no_link]), ok = meck:expect(cowboy_req, peer, fun(_) -> {{127,0,0,1}, 3456} end), - ok = meck:expect(cowboy_req, sock, fun(_) -> {{127,0,0,1}, 8883} end), + ok = meck:expect(cowboy_req, sock, fun(_) -> {{127,0,0,1}, 18083} end), ok = meck:expect(cowboy_req, cert, fun(_) -> undefined end), - ok = meck:expect(cowboy_req, parse_cookies, fun(_) -> undefined end), - %% Meck Channel - ok = meck:new(emqx_channel, [passthrough, no_history]), - %% Meck Metrics - ok = meck:new(emqx_metrics, [passthrough, no_history]), + ok = meck:expect(cowboy_req, parse_cookies, fun(_) -> error(badarg) end), + %% Mock emqx_zone + ok = meck:new(emqx_zone, [passthrough, no_history, no_link]), + ok = meck:expect(emqx_zone, oom_policy, + fun(_) -> #{max_heap_size => 838860800, + message_queue_len => 8000 + } + end), + %% Mock emqx_access_control + ok = meck:new(emqx_access_control, [passthrough, no_history, no_link]), + ok = meck:expect(emqx_access_control, check_acl, fun(_, _, _) -> allow end), + %% Mock emqx_hooks + ok = meck:new(emqx_hooks, [passthrough, no_history, no_link]), + ok = meck:expect(emqx_hooks, run, fun(_Hook, _Args) -> ok end), + ok = meck:expect(emqx_hooks, run_fold, fun(_Hook, _Args, Acc) -> Acc end), + %% Mock emqx_broker + ok = meck:new(emqx_broker, [passthrough, no_history, no_link]), + ok = meck:expect(emqx_broker, subscribe, fun(_, _, _) -> ok end), + ok = meck:expect(emqx_broker, publish, fun(#message{topic = Topic}) -> + [{node(), Topic, 1}] + end), + ok = meck:expect(emqx_broker, unsubscribe, fun(_) -> ok end), + %% Mock emqx_metrics + ok = meck:new(emqx_metrics, [passthrough, no_history, no_link]), + ok = meck:expect(emqx_metrics, inc, fun(_) -> ok end), ok = meck:expect(emqx_metrics, inc, fun(_, _) -> ok end), ok = meck:expect(emqx_metrics, inc_recv, fun(_) -> ok end), ok = meck:expect(emqx_metrics, inc_sent, fun(_) -> ok end), Config. +end_per_suite(_Config) -> + lists:foreach(fun meck:unload/1, + [cowboy_req, + emqx_zone, + emqx_access_control, + emqx_broker, + emqx_hooks, + emqx_metrics + ]). + +init_per_testcase(_TestCase, Config) -> + Config. + end_per_testcase(_TestCase, Config) -> - ok = meck:unload(cowboy_req), - ok = meck:unload(emqx_channel), - ok = meck:unload(emqx_metrics), Config. %%-------------------------------------------------------------------- %% Test Cases %%-------------------------------------------------------------------- -%%TODO:... -t_ws_conn_init(_) -> - with_ws_conn(fun(_WsConn) -> ok end). +t_info(_) -> + WsPid = spawn(fun() -> + receive {call, From, info} -> + gen_server:reply(From, ?ws_conn:info(st())) + end + end), + #{sockinfo := SockInfo} = ?ws_conn:call(WsPid, info), + #{socktype := ws, + active_n := 100, + peername := {{127,0,0,1}, 3456}, + sockname := {{127,0,0,1}, 18083}, + sockstate := running + } = SockInfo. -t_ws_conn_info(_) -> - with_ws_conn(fun(WsConn) -> - #{sockinfo := SockInfo} = emqx_ws_connection:info(WsConn), - #{socktype := ws, - peername := {{127,0,0,1}, 3456}, - sockname := {{127,0,0,1}, 8883}, - sockstate := running} = SockInfo - end). +t_info_limiter(_) -> + St = st(#{limiter => emqx_limiter:init([])}), + ?assertEqual(undefined, ?ws_conn:info(limiter, St)). -t_websocket_init(_) -> - with_ws_conn(fun(WsConn) -> - #{sockinfo := SockInfo} = emqx_ws_connection:info(WsConn), - #{socktype := ws, - peername := {{127,0,0,1}, 3456}, - sockname := {{127,0,0,1}, 8883}, - sockstate := running - } = SockInfo - end). +t_info_channel(_) -> + #{conn_state := connected} = ?ws_conn:info(channel, st()). + +t_info_gc_state(_) -> + GcSt = emqx_gc:init(#{count => 10, bytes => 1000}), + GcInfo = ?ws_conn:info(gc_state, st(#{gc_state => GcSt})), + ?assertEqual(#{cnt => {10,10}, oct => {1000,1000}}, GcInfo). + +t_info_postponed(_) -> + ?assertEqual([], ?ws_conn:info(postponed, st())), + St = ?ws_conn:postpone({active, false}, st()), + ?assertEqual([{active, false}], ?ws_conn:info(postponed, St)). + +t_info_stop_reason(_) -> + St = st(#{stop_reason => normal}), + ?assertEqual(normal, ?ws_conn:info(stop_reason, St)). + +t_stats(_) -> + WsPid = spawn(fun() -> + receive {call, From, stats} -> + gen_server:reply(From, ?ws_conn:stats(st())) + end + end), + Stats = ?ws_conn:call(WsPid, stats), + [{recv_oct, 0}, {recv_cnt, 0}, {send_oct, 0}, {send_cnt, 0}, + {recv_pkt, 0}, {recv_msg, 0}, {send_pkt, 0}, {send_msg, 0}|_] = Stats. + +t_call(_) -> + Info = ?ws_conn:info(st()), + WsPid = spawn(fun() -> + receive {call, From, info} -> gen_server:reply(From, Info) end + end), + ?assertEqual(Info, ?ws_conn:call(WsPid, info)). + +t_init(_) -> + Opts = [{idle_timeout, 300000}], + WsOpts = #{compress => false, + deflate_opts => #{}, + max_frame_size => infinity, + idle_timeout => 300000 + }, + ok = meck:expect(cowboy_req, parse_header, fun(_, req) -> undefined end), + {cowboy_websocket, req, [req, Opts], WsOpts} = ?ws_conn:init(req, Opts), + ok = meck:expect(cowboy_req, parse_header, fun(_, req) -> [<<"mqtt">>] end), + ok = meck:expect(cowboy_req, set_resp_header, fun(_, <<"mqtt">>, req) -> resp end), + {cowboy_websocket, resp, [req, Opts], WsOpts} = ?ws_conn:init(req, Opts). t_websocket_handle_binary(_) -> - with_ws_conn(fun(WsConn) -> - {ok, _} = websocket_handle({binary, [<<>>]}, WsConn) - end). + {ok, _} = websocket_handle({binary, <<>>}, st()), + {ok, _} = websocket_handle({binary, [<<>>]}, st()), + {ok, _} = websocket_handle({binary, <<192,0>>}, st()), + receive {incoming, ?PACKET(?PINGREQ)} -> ok + after 0 -> error(expect_incoming_pingreq) + end. -t_websocket_handle_ping_pong(_) -> - with_ws_conn(fun(WsConn) -> - {ok, WsConn} = websocket_handle(ping, WsConn), - {ok, WsConn} = websocket_handle(pong, WsConn), - {ok, WsConn} = websocket_handle({ping, <<>>}, WsConn), - {ok, WsConn} = websocket_handle({pong, <<>>}, WsConn) - end). +t_websocket_handle_ping(_) -> + {ok, St} = websocket_handle(ping, St = st()), + {ok, St} = websocket_handle({ping, <<>>}, St). + +t_websocket_handle_pong(_) -> + {ok, St} = websocket_handle(pong, St = st()), + {ok, St} = websocket_handle({pong, <<>>}, St). t_websocket_handle_bad_frame(_) -> - with_ws_conn(fun(WsConn) -> - {stop, WsConn1} = websocket_handle({badframe, <<>>}, WsConn), - ?assertEqual({shutdown, unexpected_ws_frame}, stop_reason(WsConn1)) - end). + {stop, St} = websocket_handle({badframe, <<>>}, st()), + {shutdown, unexpected_ws_frame} = ?ws_conn:info(stop_reason, St). t_websocket_info_call(_) -> - with_ws_conn(fun(WsConn) -> - From = {make_ref(), self()}, - Call = {call, From, badreq}, - websocket_info(Call, WsConn) - end). + From = {make_ref(), self()}, + Call = {call, From, badreq}, + {ok, _St} = websocket_info(Call, st()). + +t_websocket_info_rate_limit(_) -> + {ok, _} = websocket_info({cast, rate_limit}, st()), + ok = timer:sleep(1), + receive + {check_gc, Stats} -> + ?assertEqual(#{cnt => 0, oct => 0}, Stats) + after 0 -> error(expect_check_gc) + end. t_websocket_info_cast(_) -> - ok = meck:expect(emqx_channel, handle_info, fun(_Msg, Channel) -> {ok, Channel} end), - with_ws_conn(fun(WsConn) -> websocket_info({cast, msg}, WsConn) end). + {ok, _St} = websocket_info({cast, msg}, st()). t_websocket_info_incoming(_) -> - ok = meck:expect(emqx_channel, handle_in, fun(_Packet, Channel) -> {ok, Channel} end), - with_ws_conn(fun(WsConn) -> - Connect = ?CONNECT_PACKET( - #mqtt_packet_connect{proto_ver = ?MQTT_PROTO_V5, - proto_name = <<"MQTT">>, - clientid = <<>>, - clean_start = true, - keepalive = 60}), - {ok, WsConn1} = websocket_info({incoming, Connect}, WsConn), - Publish = ?PUBLISH_PACKET(?QOS_1, <<"t">>, 1, <<"payload">>), - {ok, _WsConn2} = websocket_info({incoming, Publish}, WsConn1) - end). + ConnPkt = #mqtt_packet_connect{ + proto_name = <<"MQTT">>, + proto_ver = ?MQTT_PROTO_V5, + is_bridge = false, + clean_start = true, + keepalive = 60, + properties = undefined, + clientid = <<"clientid">>, + username = <<"username">>, + password = <<"passwd">> + }, + {reply, [{binary, IoData1}], St1} = + websocket_info({incoming, ?CONNECT_PACKET(ConnPkt)}, st()), + ?assertEqual(<<224,2,130,0>>, iolist_to_binary(IoData1)), + %% PINGREQ + {reply, [{binary, IoData2}], St2} = + websocket_info({incoming, ?PACKET(?PINGREQ)}, St1), + ?assertEqual(<<208,0>>, iolist_to_binary(IoData2)), + %% PUBLISH + Publish = ?PUBLISH_PACKET(?QOS_1, <<"t">>, 1, <<"payload">>), + {reply, [{binary, IoData3}], _St3} = + websocket_info({incoming, Publish}, St2), + ?assertEqual(<<64,4,0,1,0,0>>, iolist_to_binary(IoData3)). + +t_websocket_info_check_gc(_) -> + Stats = #{cnt => 10, oct => 1000}, + {ok, _St} = websocket_info({check_gc, Stats}, st()). t_websocket_info_deliver(_) -> - with_ws_conn(fun(WsConn) -> - ok = meck:expect(emqx_channel, handle_deliver, - fun(Delivers, Channel) -> - Packets = [emqx_message:to_packet(1, Msg) || {deliver, _, Msg} <- Delivers], - {ok, {outgoing, Packets}, Channel} - end), - Deliver = {deliver, <<"#">>, emqx_message:make(<<"topic">>, <<"payload">>)}, - {reply, {binary, _Data}, _WsConn1} = websocket_info(Deliver, WsConn) - end). + Msg0 = emqx_message:make(clientid, ?QOS_0, <<"t">>, <<"">>), + Msg1 = emqx_message:make(clientid, ?QOS_1, <<"t">>, <<"">>), + self() ! {deliver, <<"#">>, Msg1}, + {reply, [{binary, IoData}], _St} = + websocket_info({deliver, <<"#">>, Msg0}, st()), + ?assertEqual(<<48,3,0,1,116,50,5,0,1,116,0,1>>, iolist_to_binary(IoData)). -t_websocket_info_timeout(_) -> - with_ws_conn(fun(WsConn) -> - websocket_info({timeout, make_ref(), keepalive}, WsConn), - websocket_info({timeout, make_ref(), emit_stats}, WsConn), - websocket_info({timeout, make_ref(), retry_delivery}, WsConn) - end). +t_websocket_info_timeout_limiter(_) -> + Ref = make_ref(), + Event = {timeout, Ref, limit_timeout}, + {reply, [{active, true}], St} = + websocket_info(Event, st(#{limit_timer => Ref})), + ?assertEqual([], ?ws_conn:info(postponed, St)). + +t_websocket_info_timeout_keepalive(_) -> + {ok, _St} = websocket_info({timeout, make_ref(), keepalive}, st()). + +t_websocket_info_timeout_emit_stats(_) -> + Ref = make_ref(), + St = st(#{stats_timer => Ref}), + {ok, St1} = websocket_info({timeout, Ref, emit_stats}, St), + ?assertEqual(undefined, ?ws_conn:info(stats_timer, St1)). + +t_websocket_info_timeout_retry(_) -> + {ok, _St} = websocket_info({timeout, make_ref(), retry_delivery}, st()). t_websocket_info_close(_) -> - with_ws_conn(fun(WsConn) -> - {stop, WsConn1} = websocket_info({close, sock_error}, WsConn), - ?assertEqual({shutdown, sock_error}, stop_reason(WsConn1)) - end). + {stop, St} = websocket_info({close, sock_error}, st()), + ?assertEqual({shutdown, sock_error}, ?ws_conn:info(stop_reason, St)). t_websocket_info_shutdown(_) -> - with_ws_conn(fun(WsConn) -> - {stop, WsConn1} = websocket_info({shutdown, reason}, WsConn), - ?assertEqual({shutdown, reason}, stop_reason(WsConn1)) - end). - + {stop, St} = websocket_info({shutdown, reason}, st()), + ?assertEqual({shutdown, reason}, ?ws_conn:info(stop_reason, St)). t_websocket_info_stop(_) -> - with_ws_conn(fun(WsConn) -> - {stop, WsConn1} = websocket_info({stop, normal}, WsConn), - ?assertEqual(normal, stop_reason(WsConn1)) - end). + {stop, St} = websocket_info({stop, normal}, st()), + ?assertEqual(normal, ?ws_conn:info(stop_reason, St)). t_websocket_close(_) -> - ok = meck:expect(emqx_channel, handle_info, - fun({sock_closed, badframe}, Channel) -> - {shutdown, sock_closed, Channel} - end), - with_ws_conn(fun(WsConn) -> - {stop, WsConn1} = emqx_ws_connection:websocket_close(badframe, WsConn), - ?assertEqual(sock_closed, stop_reason(WsConn1)) - end). + {stop, St} = websocket_close(badframe, st()), + ?assertEqual({shutdown, badframe}, ?ws_conn:info(stop_reason, St)). -t_handle_call(_) -> - with_ws_conn(fun(WsConn) -> ok end). +t_handle_info_connack(_) -> + ConnAck = ?CONNACK_PACKET(?RC_SUCCESS), + {reply, [{binary, IoData}], _St} = + ?ws_conn:handle_info({connack, ConnAck}, st()), + ?assertEqual(<<32,2,0,0>>, iolist_to_binary(IoData)). -t_handle_info(_) -> - with_ws_conn(fun(WsConn) -> ok end). +t_handle_info_close(_) -> + {stop, St} = ?ws_conn:handle_info({close, protocol_error}, st()), + ?assertEqual({shutdown, protocol_error}, + ?ws_conn:info(stop_reason, St)). -t_handle_timeout(_) -> - with_ws_conn(fun(WsConn) -> ok end). +t_handle_info_event(_) -> + ok = meck:new(emqx_cm, [passthrough, no_history]), + ok = meck:expect(emqx_cm, register_channel, fun(_,_,_) -> ok end), + ok = meck:expect(emqx_cm, connection_closed, fun(_) -> true end), + {ok, _} = ?ws_conn:handle_info({event, connected}, st()), + {ok, _} = ?ws_conn:handle_info({event, disconnected}, st()), + {ok, _} = ?ws_conn:handle_info({event, updated}, st()), + ok = meck:unload(emqx_cm). + +t_handle_timeout_idle_timeout(_) -> + TRef = make_ref(), + {stop, St} = ?ws_conn:handle_timeout( + TRef, idle_timeout, st(#{idle_timer => TRef})), + ?assertEqual({shutdown, idle_timeout}, ?ws_conn:info(stop_reason, St)). + +t_handle_timeout_keepalive(_) -> + {ok, _St} = ?ws_conn:handle_timeout(make_ref(), keepalive, st()). + +t_handle_timeout_emit_stats(_) -> + TRef = make_ref(), + {ok, St} = ?ws_conn:handle_timeout( + TRef, emit_stats, st(#{stats_timer => TRef})), + ?assertEqual(undefined, ?ws_conn:info(stats_timer, St)). + +t_ensure_rate_limit(_) -> + Limiter = emqx_limiter:init([{pub_limit, {1, 10}}, + {rate_limit, {100, 1000}} + ]), + St = st(#{limiter => Limiter}), + St1 = ?ws_conn:ensure_rate_limit(#{cnt => 0, oct => 0}, St), + St2 = ?ws_conn:ensure_rate_limit(#{cnt => 11, oct => 1200}, St1), + ?assertEqual(blocked, ?ws_conn:info(sockstate, St2)), + ?assertEqual([{active, false}], ?ws_conn:info(postponed, St2)). t_parse_incoming(_) -> - with_ws_conn(fun(WsConn) -> ok end). + St = ?ws_conn:parse_incoming(<<48,3>>, st()), + St1 = ?ws_conn:parse_incoming(<<0,1,116>>, St), + Packet = ?PUBLISH_PACKET(?QOS_0, <<"t">>, undefined, <<>>), + [{incoming, Packet}] = ?ws_conn:info(postponed, St1). -t_handle_incoming(_) -> - with_ws_conn(fun(WsConn) -> ok end). +t_parse_incoming_frame_error(_) -> + St = ?ws_conn:parse_incoming(<<3,2,1,0>>, st()), + FrameError = {frame_error, function_clause}, + [{incoming, FrameError}] = ?ws_conn:info(postponed, St). -t_handle_return(_) -> - with_ws_conn(fun(WsConn) -> ok end). +t_handle_incomming_frame_error(_) -> + FrameError = {frame_error, bad_qos}, + Serialize = emqx_frame:serialize_fun(#{version => 5, max_size => 16#FFFF}), + {reply, [{binary, IoData}], _St} = + ?ws_conn:handle_incoming(FrameError, st(#{serialize => Serialize})), + ?assertEqual(<<224,2,129,0>>, iolist_to_binary(IoData)). t_handle_outgoing(_) -> - with_ws_conn(fun(WsConn) -> ok end). + Packets = [?PUBLISH_PACKET(?QOS_1, <<"t1">>, 1, <<"payload">>), + ?PUBLISH_PACKET(?QOS_2, <<"t2">>, 2, <<"payload">>) + ], + {{binary, IoData}, _St} = ?ws_conn:handle_outgoing(Packets, st()), + ?assert(is_binary(iolist_to_binary(IoData))). + +t_run_gc(_) -> + GcSt = emqx_gc:init(#{count => 10, bytes => 100}), + WsSt = st(#{gc_state => GcSt}), + ?ws_conn:run_gc(#{cnt => 100, oct => 10000}, WsSt). + +t_check_oom(_) -> + %%Policy = #{max_heap_size => 10, message_queue_len => 10}, + %%meck:expect(emqx_zone, oom_policy, fun(_) -> Policy end), + _St = ?ws_conn:check_oom(st()), + ok = timer:sleep(10). + %%receive {shutdown, proc_heap_too_large} -> ok + %%after 0 -> error(expect_shutdown) + %%end. + +t_enqueue(_) -> + Packet = ?PUBLISH_PACKET(?QOS_0), + St = ?ws_conn:enqueue(Packet, st()), + [Packet] = ?ws_conn:info(postponed, St). + +t_shutdown(_) -> + {stop, St} = ?ws_conn:shutdown(closed, st()), + {shutdown, closed} = ?ws_conn:info(stop_reason, St). + +t_stop(_) -> + St = st(#{postponed => [{active, false}]}), + {reply, [{active, false}, close], _} = ?ws_conn:stop(closed, St). %%-------------------------------------------------------------------- %% Helper functions %%-------------------------------------------------------------------- -with_ws_conn(TestFun) -> - with_ws_conn(TestFun, []). +st() -> st(#{}). +st(InitFields) when is_map(InitFields) -> + {ok, St, _} = ?ws_conn:websocket_init([req, [{zone, external}]]), + maps:fold(fun(N, V, S) -> ?ws_conn:set_field(N, V, S) end, + ?ws_conn:set_field(channel, channel(), St), + InitFields + ). -with_ws_conn(TestFun, Opts) -> - {ok, WsConn, _} = emqx_ws_connection:websocket_init( - [req, emqx_misc:merge_opts([{zone, external}], Opts)]), - TestFun(WsConn). - -stop_reason(WsConn) -> - emqx_ws_connection:info(stop_reason, WsConn). +channel() -> channel(#{}). +channel(InitFields) -> + ConnInfo = #{peername => {{127,0,0,1}, 3456}, + sockname => {{127,0,0,1}, 18083}, + conn_mod => emqx_ws_connection, + proto_name => <<"MQTT">>, + proto_ver => ?MQTT_PROTO_V5, + clean_start => true, + keepalive => 30, + clientid => <<"clientid">>, + username => <<"username">>, + receive_maximum => 100, + expiry_interval => 0 + }, + ClientInfo = #{zone => zone, + protocol => mqtt, + peerhost => {127,0,0,1}, + clientid => <<"clientid">>, + username => <<"username">>, + is_superuser => false, + peercert => undefined, + mountpoint => undefined + }, + Session = emqx_session:init(#{zone => external}, + #{receive_maximum => 0} + ), + maps:fold(fun(Field, Value, Channel) -> + emqx_channel:set_field(Field, Value, Channel) + end, + emqx_channel:init(ConnInfo, [{zone, zone}]), + maps:merge(#{clientinfo => ClientInfo, + session => Session, + conn_state => connected + }, InitFields)).