Improve 'emqx_ws_connection' module and add more test cases

This commit is contained in:
Feng Lee 2019-12-13 17:09:59 +08:00
parent 2a3baed7e0
commit c5c99b7c4e
10 changed files with 617 additions and 361 deletions

View File

@ -118,6 +118,8 @@ info(Keys, Channel) when is_list(Keys) ->
[{Key, info(Key, Channel)} || Key <- Keys];
info(conninfo, #channel{conninfo = ConnInfo}) ->
ConnInfo;
info(zone, #channel{clientinfo = #{zone := Zone}}) ->
Zone;
info(clientid, #channel{clientinfo = #{clientid := ClientId}}) ->
ClientId;
info(clientinfo, #channel{clientinfo = ClientInfo}) ->
@ -147,11 +149,6 @@ stats(#channel{session = Session})->
caps(#channel{clientinfo = #{zone := Zone}}) ->
emqx_mqtt_caps:get_caps(Zone).
%% For tests
set_field(Name, Val, Channel) ->
Fields = record_info(fields, channel),
Pos = emqx_misc:index_of(Name, Fields),
setelement(Pos+1, Channel, Val).
%%--------------------------------------------------------------------
%% Init the channel
@ -1309,3 +1306,11 @@ sp(false) -> 0.
flag(true) -> 1;
flag(false) -> 0.
%%--------------------------------------------------------------------
%% For CT tests
%%--------------------------------------------------------------------
set_field(Name, Value, Channel) ->
Pos = emqx_misc:index_of(Name, record_info(fields, channel)),
setelement(Pos+1, Channel, Value).

View File

@ -53,6 +53,9 @@
%% Internal callback
-export([wakeup_from_hib/3]).
%% Export for CT
-export([set_field/3]).
-import(emqx_misc,
[ maybe_apply/2
, start_timer/2
@ -181,7 +184,8 @@ do_init(Parent, Transport, Socket, Options) ->
},
Zone = proplists:get_value(zone, Options),
ActiveN = proplists:get_value(active_n, Options, ?ACTIVE_N),
Limiter = emqx_limiter:init(Options),
PubLimit = emqx_zone:publish_limit(Zone),
Limiter = emqx_limiter:init([{pub_limit, PubLimit}|Options]),
FrameOpts = emqx_zone:mqtt_frame_options(Zone),
ParseState = emqx_frame:initial_parse_state(FrameOpts),
Serialize = emqx_frame:serialize_fun(),
@ -491,7 +495,7 @@ parse_incoming(Data, Packets, State = #state{parse_state = ParseState}) ->
parse_incoming(Rest, [Packet|Packets], NState)
catch
error:Reason:Stk ->
?LOG(error, "~nParse failed for ~p~nStacktrace: ~p~nFrame data:~p",
?LOG(error, "~nParse failed for ~p~n~p~nFrame data:~p",
[Reason, Stk, Data]),
{[{frame_error, Reason}|Packets], State}
end.
@ -613,12 +617,12 @@ run_gc(Stats, State = #state{gc_state = GcSt}) ->
case ?ENABLED(GcSt) andalso emqx_gc:run(Stats, GcSt) of
false -> State;
{IsGC, GcSt1} ->
IsGC andalso emqx_metrics:inc('channel.gc.cnt'),
IsGC andalso emqx_metrics:inc('channel.gc'),
State#state{gc_state = GcSt1}
end.
check_oom(State = #state{channel = Channel}) ->
#{zone := Zone} = emqx_channel:info(clientinfo, Channel),
Zone = emqx_channel:info(zone, Channel),
OomPolicy = emqx_zone:oom_policy(Zone),
case ?ENABLED(OomPolicy) andalso emqx_misc:check_oom(OomPolicy) of
Shutdown = {shutdown, _Reason} ->
@ -706,3 +710,11 @@ stop(Reason, State) ->
stop(Reason, Reply, State) ->
{stop, Reason, Reply, State}.
%%--------------------------------------------------------------------
%% For CT tests
%%--------------------------------------------------------------------
set_field(Name, Value, State) ->
Pos = emqx_misc:index_of(Name, record_info(fields, state)),
setelement(Pos+1, State, Value).

View File

@ -37,8 +37,7 @@
-spec(init(proplists:proplist()) -> maybe(limiter())).
init(Options) ->
Zone = proplists:get_value(zone, Options),
Pl = emqx_zone:publish_limit(Zone),
Pl = proplists:get_value(pub_limit, Options),
Rl = proplists:get_value(rate_limit, Options),
case ?ENABLED(Pl) or ?ENABLED(Rl) of
true -> #limiter{pub_limit = init_limit(Pl),

View File

@ -651,8 +651,7 @@ age(Now, Ts) -> Now - Ts.
%% For CT tests
%%--------------------------------------------------------------------
set_field(Name, Val, Channel) ->
Fields = record_info(fields, session),
Pos = emqx_misc:index_of(Name, Fields),
setelement(Pos+1, Channel, Val).
set_field(Name, Value, Session) ->
Pos = emqx_misc:index_of(Name, record_info(fields, session)),
setelement(Pos+1, Session, Value).

View File

@ -45,13 +45,16 @@
, terminate/3
]).
%% Export for CT
-export([set_field/3]).
-import(emqx_misc,
[ maybe_apply/2
, start_timer/2
]).
-record(state, {
%% Peername of the ws connection.
%% Peername of the ws connection
peername :: emqx_types:peername(),
%% Sockname of the ws connection
sockname :: emqx_types:peername(),
@ -65,26 +68,28 @@
limit_timer :: maybe(reference()),
%% Parse State
parse_state :: emqx_frame:parse_state(),
%% Serialize function
%% Serialize Fun
serialize :: emqx_frame:serialize_fun(),
%% Channel
channel :: emqx_channel:channel(),
%% GC State
gc_state :: maybe(emqx_gc:gc_state()),
%% Out Pending Packets
pendings :: list(emqx_types:packet()),
%% Postponed Packets|Cmds|Events
postponed :: list(emqx_types:packet()|ws_cmd()|tuple()),
%% Stats Timer
stats_timer :: disabled | maybe(reference()),
%% Idle Timeout
idle_timeout :: timeout(),
%% Idle Timer
idle_timer :: reference(),
%% The stop reason
%% The Stop Reason
stop_reason :: term()
}).
-type(state() :: #state{}).
-type(ws_cmd() :: {active, boolean()}|close).
-define(ACTIVE_N, 100).
-define(INFO_KEYS, [socktype, peername, sockname, sockstate, active_n]).
-define(SOCK_STATS, [recv_oct, recv_cnt, send_oct, send_cnt]).
@ -93,7 +98,7 @@
-define(ENABLED(X), (X =/= undefined)).
%%--------------------------------------------------------------------
%% API
%% Info, Stats
%%--------------------------------------------------------------------
-spec(info(pid()|state()) -> emqx_types:infos()).
@ -120,6 +125,12 @@ info(limiter, #state{limiter = Limiter}) ->
maybe_apply(fun emqx_limiter:info/1, Limiter);
info(channel, #state{channel = Channel}) ->
emqx_channel:info(Channel);
info(gc_state, #state{gc_state = GcSt}) ->
maybe_apply(fun emqx_gc:info/1, GcSt);
info(postponed, #state{postponed = Postponed}) ->
Postponed;
info(stats_timer, #state{stats_timer = TRef}) ->
TRef;
info(stop_reason, #state{stop_reason = Reason}) ->
Reason.
@ -201,8 +212,9 @@ websocket_init([Req, Opts]) ->
conn_mod => ?MODULE
},
Zone = proplists:get_value(zone, Opts),
PubLimit = emqx_zone:publish_limit(Zone),
Limiter = emqx_limiter:init([{pub_limit, PubLimit}|Opts]),
ActiveN = proplists:get_value(active_n, Opts, ?ACTIVE_N),
Limiter = emqx_limiter:init(Opts),
FrameOpts = emqx_zone:mqtt_frame_options(Zone),
ParseState = emqx_frame:initial_parse_state(FrameOpts),
Serialize = emqx_frame:serialize_fun(),
@ -223,7 +235,7 @@ websocket_init([Req, Opts]) ->
serialize = Serialize,
channel = Channel,
gc_state = GcState,
pendings = [],
postponed = [],
stats_timer = StatsTimer,
idle_timeout = IdleTimeout,
idle_timer = IdleTimer
@ -235,63 +247,60 @@ websocket_handle({binary, Data}, State) when is_list(Data) ->
websocket_handle({binary, Data}, State) ->
?LOG(debug, "RECV ~p", [Data]),
ok = inc_recv_stats(1, iolist_size(Data)),
parse_incoming(Data, ensure_stats_timer(State));
NState = ensure_stats_timer(State),
return(parse_incoming(Data, NState));
%% Pings should be replied with pongs, cowboy does it automatically
%% Pongs can be safely ignored. Clause here simply prevents crash.
websocket_handle(Frame, State) when Frame =:= ping; Frame =:= pong ->
{ok, State};
return(State);
websocket_handle({FrameType, _}, State) when FrameType =:= ping;
FrameType =:= pong ->
{ok, State};
websocket_handle({Frame, _}, State) when Frame =:= ping; Frame =:= pong ->
return(State);
websocket_handle({FrameType, _}, State) ->
?LOG(error, "Unexpected frame - ~p", [FrameType]),
stop({shutdown, unexpected_ws_frame}, State).
websocket_handle({Frame, _}, State) ->
%% TODO: should not close the ws connection
?LOG(error, "Unexpected frame - ~p", [Frame]),
shutdown(unexpected_ws_frame, State).
websocket_info({call, From, Req}, State) ->
handle_call(From, Req, State);
websocket_info({cast, Msg}, State = #state{channel = Channel}) ->
handle_chan_return(emqx_channel:handle_info(Msg, Channel), State);
websocket_info({incoming, Packet = ?CONNECT_PACKET(ConnPkt)},
State = #state{idle_timer = IdleTimer}) ->
ok = emqx_misc:cancel_timer(IdleTimer),
Serialize = emqx_frame:serialize_fun(ConnPkt),
NState = State#state{serialize = Serialize,
idle_timer = undefined
websocket_info({cast, rate_limit}, State) ->
Stats = #{cnt => emqx_pd:reset_counter(incoming_pubs),
oct => emqx_pd:reset_counter(incoming_bytes)
},
handle_incoming(Packet, NState);
NState = postpone({check_gc, Stats}, State),
return(ensure_rate_limit(Stats, NState));
websocket_info({cast, Msg}, State) ->
handle_info(Msg, State);
websocket_info({incoming, Packet = ?CONNECT_PACKET(ConnPkt)}, State) ->
Serialize = emqx_frame:serialize_fun(ConnPkt),
NState = State#state{serialize = Serialize},
handle_incoming(Packet, cancel_idle_timer(NState));
websocket_info({incoming, ?PACKET(?PINGREQ)}, State) ->
reply(?PACKET(?PINGRESP), State);
return(enqueue(?PACKET(?PINGRESP), State));
websocket_info({incoming, Packet}, State) ->
handle_incoming(Packet, State);
websocket_info(rate_limit, State) ->
InStats = #{cnt => emqx_pd:reset_counter(incoming_pubs),
oct => emqx_pd:reset_counter(incoming_bytes)
},
erlang:send(self(), {check_gc, InStats}),
ensure_rate_limit(InStats, State);
websocket_info({check_gc, Stats}, State) ->
{ok, check_oom(run_gc(Stats, State))};
return(check_oom(run_gc(Stats, State)));
websocket_info(Deliver = {deliver, _Topic, _Msg},
State = #state{active_n = ActiveN, channel = Channel}) ->
State = #state{active_n = ActiveN}) ->
Delivers = [Deliver|emqx_misc:drain_deliver(ActiveN)],
Ret = emqx_channel:handle_deliver(Delivers, Channel),
handle_chan_return(Ret, State);
with_channel(handle_deliver, [Delivers], State);
websocket_info({timeout, TRef, limit_timeout}, State = #state{limit_timer = TRef}) ->
websocket_info({timeout, TRef, limit_timeout},
State = #state{limit_timer = TRef}) ->
NState = State#state{sockstate = running,
limit_timer = undefined
},
{reply, [{active, true}], NState};
return(enqueue({active, true}, NState));
websocket_info({timeout, TRef, Msg}, State) when is_reference(TRef) ->
handle_timeout(TRef, Msg, State);
@ -300,7 +309,7 @@ websocket_info(Close = {close, _Reason}, State) ->
handle_info(Close, State);
websocket_info({shutdown, Reason}, State) ->
stop({shutdown, Reason}, State);
shutdown(Reason, State);
websocket_info({stop, Reason}, State) ->
stop(Reason, State);
@ -312,68 +321,70 @@ websocket_close(Reason, State) ->
?LOG(debug, "WebSocket closed due to ~p~n", [Reason]),
handle_info({sock_closed, Reason}, State).
terminate(SockError, _Req, #state{channel = Channel,
stop_reason = Reason}) ->
?LOG(debug, "Terminated for ~p, sockerror: ~p", [Reason, SockError]),
terminate(Error, _Req, #state{channel = Channel, stop_reason = Reason}) ->
?LOG(debug, "Terminated for ~p, error: ~p", [Reason, Error]),
emqx_channel:terminate(Reason, Channel).
%%--------------------------------------------------------------------
%% Handle call
%%--------------------------------------------------------------------
handle_call(From, info, State) ->
gen_server:reply(From, info(State)),
{ok, State};
return(State);
handle_call(From, stats, State) ->
gen_server:reply(From, stats(State)),
{ok, State};
return(State);
handle_call(From, Req, State = #state{channel = Channel}) ->
case emqx_channel:handle_call(Req, Channel) of
{reply, Reply, NChannel} ->
_ = gen_server:reply(From, Reply),
{ok, State#state{channel = NChannel}};
{stop, Reason, Reply, NChannel} ->
_ = gen_server:reply(From, Reply),
stop(Reason, State#state{channel = NChannel});
{stop, Reason, Reply, OutPacket, NChannel} ->
gen_server:reply(From, Reply),
return(State#state{channel = NChannel});
{shutdown, Reason, Reply, NChannel} ->
gen_server:reply(From, Reply),
shutdown(Reason, State#state{channel = NChannel});
{shutdown, Reason, Reply, Packet, NChannel} ->
gen_server:reply(From, Reply),
NState = State#state{channel = NChannel},
stop(Reason, enqueue(OutPacket, NState))
shutdown(Reason, enqueue(Packet, NState))
end.
%%--------------------------------------------------------------------
%% Handle Info
%%--------------------------------------------------------------------
handle_info({connack, ConnAck}, State) ->
reply(enqueue(ConnAck, State));
return(enqueue(ConnAck, State));
handle_info({close, Reason}, State) ->
stop({shutdown, Reason}, State);
%% TODO: close ws conn?
shutdown(Reason, State);
handle_info({event, connected}, State = #state{channel = Channel}) ->
ClientId = emqx_channel:info(clientid, Channel),
emqx_cm:register_channel(ClientId, info(State), stats(State)),
reply(State);
ok = emqx_cm:register_channel(ClientId, info(State), stats(State)),
return(State);
handle_info({event, disconnected}, State = #state{channel = Channel}) ->
ClientId = emqx_channel:info(clientid, Channel),
emqx_cm:set_chan_info(ClientId, info(State)),
emqx_cm:connection_closed(ClientId),
reply(State);
return(State);
handle_info({event, _Other}, State = #state{channel = Channel}) ->
ClientId = emqx_channel:info(clientid, Channel),
emqx_cm:set_chan_info(ClientId, info(State)),
emqx_cm:set_chan_stats(ClientId, stats(State)),
reply(State);
return(State);
handle_info(Info, State = #state{channel = Channel}) ->
Ret = emqx_channel:handle_info(Info, Channel),
handle_chan_return(Ret, State).
handle_info(Info, State) ->
with_channel(handle_info, [Info], State).
%%--------------------------------------------------------------------
%% Handle timeout
%%--------------------------------------------------------------------
handle_timeout(TRef, idle_timeout, State = #state{idle_timer = TRef}) ->
shutdown(idle_timeout, State);
@ -382,33 +393,24 @@ handle_timeout(TRef, keepalive, State) when is_reference(TRef) ->
RecvOct = emqx_pd:get_counter(recv_oct),
handle_timeout(TRef, {keepalive, RecvOct}, State);
handle_timeout(TRef, emit_stats, State =
#state{channel = Channel, stats_timer = TRef}) ->
handle_timeout(TRef, emit_stats, State = #state{channel = Channel,
stats_timer = TRef}) ->
ClientId = emqx_channel:info(clientid, Channel),
emqx_cm:set_chan_stats(ClientId, stats(State)),
reply(State#state{stats_timer = undefined});
return(State#state{stats_timer = undefined});
handle_timeout(TRef, TMsg, State = #state{channel = Channel}) ->
Ret = emqx_channel:handle_timeout(TRef, TMsg, Channel),
handle_chan_return(Ret, State).
%%--------------------------------------------------------------------
%% Ensure stats timer
-compile({inline, [ensure_stats_timer/1]}).
ensure_stats_timer(State = #state{idle_timeout = Timeout,
stats_timer = undefined}) ->
State#state{stats_timer = start_timer(Timeout, emit_stats)};
ensure_stats_timer(State) -> State.
handle_timeout(TRef, TMsg, State) ->
with_channel(handle_timeout, [TRef, TMsg], State).
%%--------------------------------------------------------------------
%% Ensure rate limit
%%--------------------------------------------------------------------
ensure_rate_limit(Stats, State = #state{limiter = Limiter}) ->
case ?ENABLED(Limiter) andalso emqx_limiter:check(Stats, Limiter) of
false -> {ok, State};
false -> State;
{ok, Limiter1} ->
{ok, State#state{limiter = Limiter1}};
State#state{limiter = Limiter1};
{pause, Time, Limiter1} ->
?LOG(debug, "Pause ~pms due to rate limit", [Time]),
TRef = start_timer(Time, limit_timeout),
@ -416,102 +418,108 @@ ensure_rate_limit(Stats, State = #state{limiter = Limiter}) ->
limiter = Limiter1,
limit_timer = TRef
},
{reply, [{active, false}], NState}
enqueue({active, false}, NState)
end.
%%--------------------------------------------------------------------
%% Run GC and Check OOM
%% Run GC, Check OOM
%%--------------------------------------------------------------------
run_gc(Stats, State = #state{gc_state = GcSt}) ->
case ?ENABLED(GcSt) andalso emqx_gc:run(Stats, GcSt) of
false -> State;
{IsGC, GcSt1} ->
IsGC andalso emqx_metrics:inc('channel.gc.cnt'),
State#state{gc_state = GcSt1}
IsGC andalso emqx_metrics:inc('channel.gc'),
State#state{gc_state = GcSt1};
false -> State
end.
check_oom(State = #state{channel = Channel}) ->
#{zone := Zone} = emqx_channel:info(clientinfo, Channel),
OomPolicy = emqx_zone:oom_policy(Zone),
OomPolicy = emqx_zone:oom_policy(emqx_channel:info(zone, Channel)),
case ?ENABLED(OomPolicy) andalso emqx_misc:check_oom(OomPolicy) of
Shutdown = {shutdown, _Reason} ->
erlang:send(self(), Shutdown);
_Other -> ok
end,
State.
postpone(Shutdown, State);
_Other -> State
end.
%%--------------------------------------------------------------------
%% Parse incoming data
%%--------------------------------------------------------------------
parse_incoming(<<>>, State) ->
{ok, State};
State;
parse_incoming(Data, State = #state{parse_state = ParseState}) ->
try emqx_frame:parse(Data, ParseState) of
{more, NParseState} ->
{ok, State#state{parse_state = NParseState}};
State#state{parse_state = NParseState};
{ok, Packet, Rest, NParseState} ->
erlang:send(self(), {incoming, Packet}),
parse_incoming(Rest, State#state{parse_state = NParseState})
NState = State#state{parse_state = NParseState},
parse_incoming(Rest, postpone({incoming, Packet}, NState))
catch
error:Reason:Stk ->
?LOG(error, "~nParse failed for ~p~nStacktrace: ~p~nFrame data: ~p",
?LOG(error, "~nParse failed for ~p~n~p~nFrame data: ~p",
[Reason, Stk, Data]),
self() ! {incoming, {frame_error, Reason}},
{ok, State}
FrameError = {frame_error, Reason},
postpone({incoming, FrameError}, State)
end.
%%--------------------------------------------------------------------
%% Handle incoming packet
%%--------------------------------------------------------------------
handle_incoming(Packet, State = #state{active_n = ActiveN, channel = Channel})
handle_incoming(Packet, State = #state{active_n = ActiveN})
when is_record(Packet, mqtt_packet) ->
?LOG(debug, "RECV ~s", [emqx_packet:format(Packet)]),
ok = inc_incoming_stats(Packet),
(emqx_pd:get_counter(incoming_pubs) > ActiveN)
andalso erlang:send(self(), rate_limit),
Ret = emqx_channel:handle_in(Packet, Channel),
handle_chan_return(Ret, State);
NState = case emqx_pd:get_counter(incoming_pubs) > ActiveN of
true -> postpone({cast, rate_limit}, State);
false -> State
end,
with_channel(handle_in, [Packet], NState);
handle_incoming(FrameError, State = #state{channel = Channel}) ->
handle_chan_return(emqx_channel:handle_in(FrameError, Channel), State).
handle_incoming(FrameError, State) ->
with_channel(handle_in, [FrameError], State).
%%--------------------------------------------------------------------
%% Handle channel return
%% With Channel
%%--------------------------------------------------------------------
handle_chan_return(ok, State) ->
reply(State);
handle_chan_return({ok, NChannel}, State) ->
reply(State#state{channel= NChannel});
handle_chan_return({ok, Replies, NChannel}, State) ->
reply(Replies, State#state{channel= NChannel});
handle_chan_return({shutdown, Reason, NChannel}, State) ->
stop(Reason, State#state{channel = NChannel});
handle_chan_return({shutdown, Reason, OutPacket, NChannel}, State) ->
with_channel(Fun, Args, State = #state{channel = Channel}) ->
case erlang:apply(emqx_channel, Fun, Args ++ [Channel]) of
ok -> return(State);
{ok, NChannel} ->
return(State#state{channel = NChannel});
{ok, Replies, NChannel} ->
return(postpone(Replies, State#state{channel= NChannel}));
{shutdown, Reason, NChannel} ->
shutdown(Reason, State#state{channel = NChannel});
{shutdown, Reason, Packet, NChannel} ->
NState = State#state{channel = NChannel},
stop(Reason, enqueue(OutPacket, NState)).
shutdown(Reason, postpone(Packet, NState))
end.
%%--------------------------------------------------------------------
%% Handle outgoing packets
%%--------------------------------------------------------------------
handle_outgoing(Packets, State = #state{active_n = ActiveN}) ->
IoData = lists:map(serialize_and_inc_stats_fun(State), Packets),
Oct = iolist_size(IoData),
ok = inc_sent_stats(length(Packets), Oct),
case emqx_pd:get_counter(outgoing_pubs) > ActiveN of
NState = case emqx_pd:get_counter(outgoing_pubs) > ActiveN of
true ->
OutStats = #{cnt => emqx_pd:reset_counter(outgoing_pubs),
Stats = #{cnt => emqx_pd:reset_counter(outgoing_pubs),
oct => emqx_pd:reset_counter(outgoing_bytes)
},
erlang:send(self(), {check_gc, OutStats});
false -> ok
postpone({check_gc, Stats}, State);
false -> State
end,
{{binary, IoData}, ensure_stats_timer(State)}.
{{binary, IoData}, ensure_stats_timer(NState)}.
serialize_and_inc_stats_fun(#state{serialize = Serialize}) ->
fun(Packet) ->
case Serialize(Packet) of
<<>> -> ?LOG(warning, "~s is discarded due to the frame is too large!",
<<>> -> ?LOG(warning, "~s is discarded due to the frame is too large.",
[emqx_packet:format(Packet)]),
<<>>;
Data -> ?LOG(debug, "SEND ~s", [emqx_packet:format(Packet)]),
@ -522,6 +530,7 @@ serialize_and_inc_stats_fun(#state{serialize = Serialize}) ->
%%--------------------------------------------------------------------
%% Inc incoming/outgoing stats
%%--------------------------------------------------------------------
-compile({inline,
[ inc_recv_stats/2
@ -561,46 +570,90 @@ inc_sent_stats(Cnt, Oct) ->
emqx_metrics:inc('bytes.sent', Oct).
%%--------------------------------------------------------------------
%% Reply or Stop
%% Helper functions
%%--------------------------------------------------------------------
reply(Packet, State) when is_record(Packet, mqtt_packet) ->
reply(enqueue(Packet, State));
reply({outgoing, Packets}, State) ->
reply(enqueue(Packets, State));
reply(Other, State) when is_tuple(Other) ->
self() ! Other,
reply(State);
-compile({inline, [cancel_idle_timer/1, ensure_stats_timer/1]}).
reply([], State) ->
reply(State);
reply([Packet|More], State) when is_record(Packet, mqtt_packet) ->
reply(More, enqueue(Packet, State));
reply([{outgoing, Packets}|More], State) ->
reply(More, enqueue(Packets, State));
reply([Other|More], State) ->
self() ! Other,
reply(More, State).
%%--------------------------------------------------------------------
%% Cancel idle timer
-compile({inline, [reply/1, enqueue/2]}).
cancel_idle_timer(State = #state{idle_timer = IdleTimer}) ->
ok = emqx_misc:cancel_timer(IdleTimer),
State#state{idle_timer = undefined}.
reply(State = #state{pendings = []}) ->
%%--------------------------------------------------------------------
%% Ensure stats timer
ensure_stats_timer(State = #state{idle_timeout = Timeout,
stats_timer = undefined}) ->
State#state{stats_timer = start_timer(Timeout, emit_stats)};
ensure_stats_timer(State) -> State.
-compile({inline, [enqueue/2, return/1, shutdown/2, stop/2]}).
%%--------------------------------------------------------------------
%% Postpone the packet, cmd or event
postpone(Packet, State) when is_record(Packet, mqtt_packet) ->
enqueue(Packet, State);
postpone({outgoing, Packets}, State) ->
enqueue(Packets, State);
postpone(Event, State) when is_tuple(Event) ->
enqueue(Event, State);
postpone(More, State) when is_list(More) ->
lists:foldl(fun postpone/2, State, More).
enqueue([Packet], State = #state{postponed = Postponed}) ->
State#state{postponed = [Packet|Postponed]};
enqueue(Packets, State = #state{postponed = Postponed})
when is_list(Packets) ->
State#state{postponed = lists:reverse(Packets) ++ Postponed};
enqueue(Other, State = #state{postponed = Postponed}) ->
State#state{postponed = [Other|Postponed]}.
return(State = #state{postponed = []}) ->
{ok, State};
reply(State = #state{pendings = Pendings}) ->
{Reply, NState} = handle_outgoing(Pendings, State),
{reply, Reply, NState#state{pendings = []}}.
return(State = #state{postponed = Postponed}) ->
{Packets, Cmds, Events} = classify(Postponed, [], [], []),
ok = lists:foreach(fun trigger/1, Events),
State1 = State#state{postponed = []},
case {Packets, Cmds} of
{[], []} -> {ok, State1};
{[], Cmds} -> {reply, Cmds, State1};
{Packets, Cmds} ->
{Reply, State2} = handle_outgoing(Packets, State1),
{reply, [Reply|Cmds], State2}
end.
enqueue(Packet, State) when is_record(Packet, mqtt_packet) ->
enqueue([Packet], State);
enqueue(Packets, State = #state{pendings = Pendings}) ->
State#state{pendings = lists:append(Pendings, Packets)}.
classify([], Packets, Cmds, Events) ->
{Packets, Cmds, Events};
classify([Packet|More], Packets, Cmds, Events)
when is_record(Packet, mqtt_packet) ->
classify(More, [Packet|Packets], Cmds, Events);
classify([Cmd = {active, _}|More], Packets, Cmds, Events) ->
classify(More, Packets, [Cmd|Cmds], Events);
classify([Cmd = close|More], Packets, Cmds, Events) ->
classify(More, Packets, [Cmd|Cmds], Events);
classify([Event|More], Packets, Cmds, Events) ->
classify(More, Packets, Cmds, [Event|Events]).
trigger(Event) -> erlang:send(self(), Event).
shutdown(Reason, State) ->
stop({shutdown, Reason}, State).
stop(Reason, State = #state{pendings = []}) ->
stop(Reason, State = #state{postponed = []}) ->
{stop, State#state{stop_reason = Reason}};
stop(Reason, State = #state{pendings = Pendings}) ->
{Reply, State1} = handle_outgoing(Pendings, State),
State2 = State1#state{pendings = [], stop_reason = Reason},
{reply, [Reply, close], State2}.
stop(Reason, State = #state{postponed = Postponed}) ->
return(State#state{postponed = [close|Postponed],
stop_reason = Reason}).
%%--------------------------------------------------------------------
%% For CT tests
%%--------------------------------------------------------------------
set_field(Name, Value, State) ->
Pos = emqx_misc:index_of(Name, record_info(fields, state)),
setelement(Pos+1, State, Value).

View File

@ -89,6 +89,7 @@ t_subscribers(_) ->
t_subscriptions(_) ->
emqx_broker:subscribe(<<"topic">>, <<"clientid">>, #{qos => 1}),
ok = timer:sleep(100),
?assertEqual(#{qos => 1, subid => <<"clientid">>},
proplists:get_value(<<"topic">>, emqx_broker:subscriptions(self()))),
?assertEqual(#{qos => 1, subid => <<"clientid">>},

View File

@ -23,20 +23,6 @@
-include("emqx_mqtt.hrl").
-include_lib("eunit/include/eunit.hrl").
-define(DEFAULT_CONNINFO,
#{peername => {{127,0,0,1}, 3456},
sockname => {{127,0,0,1}, 1883},
conn_mod => emqx_connection,
proto_name => <<"MQTT">>,
proto_ver => ?MQTT_PROTO_V5,
clean_start => true,
keepalive => 30,
clientid => <<"clientid">>,
username => <<"username">>,
conn_props => #{},
receive_maximum => 100,
expiry_interval => 0
}).
all() -> emqx_ct:all(?MODULE).
@ -45,40 +31,40 @@ all() -> emqx_ct:all(?MODULE).
%%--------------------------------------------------------------------
init_per_suite(Config) ->
Config.
end_per_suite(_Config) ->
ok.
init_per_testcase(_TestCase, Config) ->
%% CM Meck
ok = meck:new(emqx_cm, [passthrough, no_history]),
ok = meck:new(emqx_cm, [passthrough, no_history, no_link]),
%% Access Control Meck
ok = meck:new(emqx_access_control, [passthrough, no_history]),
ok = meck:new(emqx_access_control, [passthrough, no_history, no_link]),
ok = meck:expect(emqx_access_control, authenticate,
fun(_) -> {ok, #{auth_result => success}} end),
ok = meck:expect(emqx_access_control, check_acl, fun(_, _, _) -> allow end),
%% Broker Meck
ok = meck:new(emqx_broker, [passthrough, no_history]),
ok = meck:new(emqx_broker, [passthrough, no_history, no_link]),
%% Hooks Meck
ok = meck:new(emqx_hooks, [passthrough, no_history]),
ok = meck:new(emqx_hooks, [passthrough, no_history, no_link]),
ok = meck:expect(emqx_hooks, run, fun(_Hook, _Args) -> ok end),
ok = meck:expect(emqx_hooks, run_fold, fun(_Hook, _Args, Acc) -> Acc end),
%% Session Meck
ok = meck:new(emqx_session, [passthrough, no_history]),
ok = meck:new(emqx_session, [passthrough, no_history, no_link]),
%% Metrics
ok = meck:new(emqx_metrics, [passthrough, no_history]),
ok = meck:new(emqx_metrics, [passthrough, no_history, no_link]),
ok = meck:expect(emqx_metrics, inc, fun(_) -> ok end),
ok = meck:expect(emqx_metrics, inc, fun(_, _) -> ok end),
Config.
end_per_testcase(_TestCase, Config) ->
end_per_suite(_Config) ->
ok = meck:unload(emqx_access_control),
ok = meck:unload(emqx_metrics),
ok = meck:unload(emqx_session),
ok = meck:unload(emqx_broker),
ok = meck:unload(emqx_hooks),
ok = meck:unload(emqx_cm),
ok.
init_per_testcase(_TestCase, Config) ->
Config.
end_per_testcase(_TestCase, Config) ->
Config.
%%--------------------------------------------------------------------
@ -328,15 +314,6 @@ t_process_unsubscribe(_) ->
%%--------------------------------------------------------------------
t_handle_deliver(_) ->
WithPacketId = fun(Msgs) ->
lists:zip(lists:seq(1, length(Msgs)), Msgs)
end,
ok = meck:expect(emqx_session, deliver,
fun(Delivers, Session) ->
Publishes = WithPacketId([Msg || {deliver, _, Msg} <- Delivers]),
{ok, Publishes, Session}
end),
ok = meck:expect(emqx_session, info, fun(retry_interval, _Session) -> 20 end),
Msg0 = emqx_message:make(test, ?QOS_1, <<"t1">>, <<"qos1">>),
Msg1 = emqx_message:make(test, ?QOS_2, <<"t2">>, <<"qos2">>),
Delivers = [{deliver, <<"+">>, Msg0}, {deliver, <<"+">>, Msg1}],
@ -426,7 +403,7 @@ t_handle_call_discard(_) ->
emqx_channel:handle_call(discard, channel()).
t_handle_call_takeover_begin(_) ->
{reply, undefined, _Chan} = emqx_channel:handle_call({takeover, 'begin'}, channel()).
{reply, _Session, _Chan} = emqx_channel:handle_call({takeover, 'begin'}, channel()).
t_handle_call_takeover_end(_) ->
ok = meck:expect(emqx_session, takeover, fun(_) -> ok end),
@ -565,14 +542,27 @@ t_terminate(_) ->
channel() -> channel(#{}).
channel(InitFields) ->
ConnInfo = #{peername => {{127,0,0,1}, 3456},
sockname => {{127,0,0,1}, 1883},
conn_mod => emqx_connection,
proto_name => <<"MQTT">>,
proto_ver => ?MQTT_PROTO_V5,
clean_start => true,
keepalive => 30,
clientid => <<"clientid">>,
username => <<"username">>,
conn_props => #{},
receive_maximum => 100,
expiry_interval => 0
},
maps:fold(fun(Field, Value, Channel) ->
emqx_channel:set_field(Field, Value, Channel)
end, default_channel(), InitFields).
default_channel() ->
Channel = emqx_channel:init(?DEFAULT_CONNINFO, [{zone, zone}]),
Channel1 = emqx_channel:set_field(conn_state, connected, Channel),
emqx_channel:set_field(clientinfo, clientinfo(), Channel1).
end,
emqx_channel:init(ConnInfo, [{zone, zone}]),
maps:merge(#{clientinfo => clientinfo(),
session => session(),
conn_state => connected
}, InitFields)).
clientinfo() -> clientinfo(#{}).
clientinfo(InitProps) ->

View File

@ -24,15 +24,12 @@
all() -> emqx_ct:all(?MODULE).
init_per_testcase(_TestCase, Config) ->
ok = meck:new(emqx_zone, [passthrough, no_history]),
Config.
end_per_testcase(_TestCase, _Config) ->
ok = meck:unload(emqx_zone).
ok.
t_info(_) ->
meck:expect(emqx_zone, publish_limit, fun(_) -> {1, 10} end),
Limiter = emqx_limiter:init([{rate_limit, {100, 1000}}]),
#{pub_limit := #{rate := 1,
burst := 10,
tokens := 10
@ -41,22 +38,21 @@ t_info(_) ->
burst := 1000,
tokens := 1000
}
} = emqx_limiter:info(Limiter).
} = emqx_limiter:info(limiter()).
t_check(_) ->
meck:expect(emqx_zone, publish_limit, fun(_) -> {1, 10} end),
Limiter = emqx_limiter:init([{rate_limit, {100, 1000}}]),
lists:foreach(fun(I) ->
{ok, Limiter1} = emqx_limiter:check(#{cnt => I, oct => I*100}, Limiter),
{ok, Limiter} = emqx_limiter:check(#{cnt => I, oct => I*100}, limiter()),
#{pub_limit := #{tokens := Cnt},
rate_limit := #{tokens := Oct}
} = emqx_limiter:info(Limiter1),
} = emqx_limiter:info(Limiter),
?assertEqual({10 - I, 1000 - I*100}, {Cnt, Oct})
end, lists:seq(1, 10)).
t_check_pause(_) ->
meck:expect(emqx_zone, publish_limit, fun(_) -> {1, 10} end),
Limiter = emqx_limiter:init([{rate_limit, {100, 1000}}]),
{pause, 1000, _} = emqx_limiter:check(#{cnt => 11, oct => 2000}, Limiter),
{pause, 2000, _} = emqx_limiter:check(#{cnt => 10, oct => 1200}, Limiter).
{pause, 1000, _} = emqx_limiter:check(#{cnt => 11, oct => 2000}, limiter()),
{pause, 2000, _} = emqx_limiter:check(#{cnt => 10, oct => 1200}, limiter()).
limiter() ->
emqx_limiter:init([{pub_limit, {1, 10}}, {rate_limit, {100, 1000}}]).

View File

@ -29,16 +29,21 @@ all() -> emqx_ct:all(?MODULE).
%% CT callbacks
%%--------------------------------------------------------------------
init_per_testcase(_TestCase, Config) ->
%% Meck Broker
ok = meck:new(emqx_broker, [passthrough, no_history]),
ok = meck:new(emqx_hooks, [passthrough, no_history]),
init_per_suite(Config) ->
%% Broker
ok = meck:new(emqx_broker, [passthrough, no_history, no_link]),
ok = meck:new(emqx_hooks, [passthrough, no_history, no_link]),
ok = meck:expect(emqx_hooks, run, fun(_Hook, _Args) -> ok end),
Config.
end_per_testcase(_TestCase, Config) ->
end_per_suite(_Config) ->
ok = meck:unload(emqx_broker),
ok = meck:unload(emqx_hooks),
ok = meck:unload(emqx_hooks).
init_per_testcase(_TestCase, Config) ->
Config.
end_per_testcase(_TestCase, Config) ->
Config.
%%--------------------------------------------------------------------
@ -330,7 +335,7 @@ t_replay(_) ->
{ok, Pubs, Session1} = emqx_session:deliver(Delivers, session()),
Msg = emqx_message:make(clientid, ?QOS_1, <<"t1">>, <<"payload">>),
Session2 = emqx_session:enqueue(Msg, Session1),
Pubs1 = [{I, emqx_message:set_flag(dup, Msg)} || {I, Msg} <- Pubs],
Pubs1 = [{I, emqx_message:set_flag(dup, M)} || {I, M} <- Pubs],
{ok, ReplayPubs, Session3} = emqx_session:replay(Session2),
?assertEqual(Pubs1 ++ [{3, Msg}], ReplayPubs),
?assertEqual(3, emqx_session:info(inflight_cnt, Session3)).

View File

@ -16,6 +16,7 @@
-module(emqx_ws_connection_SUITE).
-include("emqx.hrl").
-include("emqx_mqtt.hrl").
-include_lib("eunit/include/eunit.hrl").
@ -25,12 +26,15 @@
-import(emqx_ws_connection,
[ websocket_handle/2
, websocket_info/2
, websocket_close/2
]).
-define(STATS_KEYS, [recv_oct, recv_cnt, send_oct, send_cnt,
recv_pkt, recv_msg, send_pkt, send_msg
]).
-define(ws_conn, emqx_ws_connection).
all() -> emqx_ct:all(?MODULE).
%%--------------------------------------------------------------------
@ -38,184 +42,376 @@ all() -> emqx_ct:all(?MODULE).
%%--------------------------------------------------------------------
init_per_suite(Config) ->
Config.
end_per_suite(_Config) ->
ok.
init_per_testcase(_TestCase, Config) ->
%% Meck CowboyReq
ok = meck:new(cowboy_req, [passthrough, no_history]),
%% Mock cowboy_req
ok = meck:new(cowboy_req, [passthrough, no_history, no_link]),
ok = meck:expect(cowboy_req, peer, fun(_) -> {{127,0,0,1}, 3456} end),
ok = meck:expect(cowboy_req, sock, fun(_) -> {{127,0,0,1}, 8883} end),
ok = meck:expect(cowboy_req, sock, fun(_) -> {{127,0,0,1}, 18083} end),
ok = meck:expect(cowboy_req, cert, fun(_) -> undefined end),
ok = meck:expect(cowboy_req, parse_cookies, fun(_) -> undefined end),
%% Meck Channel
ok = meck:new(emqx_channel, [passthrough, no_history]),
%% Meck Metrics
ok = meck:new(emqx_metrics, [passthrough, no_history]),
ok = meck:expect(cowboy_req, parse_cookies, fun(_) -> error(badarg) end),
%% Mock emqx_zone
ok = meck:new(emqx_zone, [passthrough, no_history, no_link]),
ok = meck:expect(emqx_zone, oom_policy,
fun(_) -> #{max_heap_size => 838860800,
message_queue_len => 8000
}
end),
%% Mock emqx_access_control
ok = meck:new(emqx_access_control, [passthrough, no_history, no_link]),
ok = meck:expect(emqx_access_control, check_acl, fun(_, _, _) -> allow end),
%% Mock emqx_hooks
ok = meck:new(emqx_hooks, [passthrough, no_history, no_link]),
ok = meck:expect(emqx_hooks, run, fun(_Hook, _Args) -> ok end),
ok = meck:expect(emqx_hooks, run_fold, fun(_Hook, _Args, Acc) -> Acc end),
%% Mock emqx_broker
ok = meck:new(emqx_broker, [passthrough, no_history, no_link]),
ok = meck:expect(emqx_broker, subscribe, fun(_, _, _) -> ok end),
ok = meck:expect(emqx_broker, publish, fun(#message{topic = Topic}) ->
[{node(), Topic, 1}]
end),
ok = meck:expect(emqx_broker, unsubscribe, fun(_) -> ok end),
%% Mock emqx_metrics
ok = meck:new(emqx_metrics, [passthrough, no_history, no_link]),
ok = meck:expect(emqx_metrics, inc, fun(_) -> ok end),
ok = meck:expect(emqx_metrics, inc, fun(_, _) -> ok end),
ok = meck:expect(emqx_metrics, inc_recv, fun(_) -> ok end),
ok = meck:expect(emqx_metrics, inc_sent, fun(_) -> ok end),
Config.
end_per_suite(_Config) ->
lists:foreach(fun meck:unload/1,
[cowboy_req,
emqx_zone,
emqx_access_control,
emqx_broker,
emqx_hooks,
emqx_metrics
]).
init_per_testcase(_TestCase, Config) ->
Config.
end_per_testcase(_TestCase, Config) ->
ok = meck:unload(cowboy_req),
ok = meck:unload(emqx_channel),
ok = meck:unload(emqx_metrics),
Config.
%%--------------------------------------------------------------------
%% Test Cases
%%--------------------------------------------------------------------
%%TODO:...
t_ws_conn_init(_) ->
with_ws_conn(fun(_WsConn) -> ok end).
t_ws_conn_info(_) ->
with_ws_conn(fun(WsConn) ->
#{sockinfo := SockInfo} = emqx_ws_connection:info(WsConn),
t_info(_) ->
WsPid = spawn(fun() ->
receive {call, From, info} ->
gen_server:reply(From, ?ws_conn:info(st()))
end
end),
#{sockinfo := SockInfo} = ?ws_conn:call(WsPid, info),
#{socktype := ws,
active_n := 100,
peername := {{127,0,0,1}, 3456},
sockname := {{127,0,0,1}, 8883},
sockstate := running} = SockInfo
end).
t_websocket_init(_) ->
with_ws_conn(fun(WsConn) ->
#{sockinfo := SockInfo} = emqx_ws_connection:info(WsConn),
#{socktype := ws,
peername := {{127,0,0,1}, 3456},
sockname := {{127,0,0,1}, 8883},
sockname := {{127,0,0,1}, 18083},
sockstate := running
} = SockInfo
end).
} = SockInfo.
t_info_limiter(_) ->
St = st(#{limiter => emqx_limiter:init([])}),
?assertEqual(undefined, ?ws_conn:info(limiter, St)).
t_info_channel(_) ->
#{conn_state := connected} = ?ws_conn:info(channel, st()).
t_info_gc_state(_) ->
GcSt = emqx_gc:init(#{count => 10, bytes => 1000}),
GcInfo = ?ws_conn:info(gc_state, st(#{gc_state => GcSt})),
?assertEqual(#{cnt => {10,10}, oct => {1000,1000}}, GcInfo).
t_info_postponed(_) ->
?assertEqual([], ?ws_conn:info(postponed, st())),
St = ?ws_conn:postpone({active, false}, st()),
?assertEqual([{active, false}], ?ws_conn:info(postponed, St)).
t_info_stop_reason(_) ->
St = st(#{stop_reason => normal}),
?assertEqual(normal, ?ws_conn:info(stop_reason, St)).
t_stats(_) ->
WsPid = spawn(fun() ->
receive {call, From, stats} ->
gen_server:reply(From, ?ws_conn:stats(st()))
end
end),
Stats = ?ws_conn:call(WsPid, stats),
[{recv_oct, 0}, {recv_cnt, 0}, {send_oct, 0}, {send_cnt, 0},
{recv_pkt, 0}, {recv_msg, 0}, {send_pkt, 0}, {send_msg, 0}|_] = Stats.
t_call(_) ->
Info = ?ws_conn:info(st()),
WsPid = spawn(fun() ->
receive {call, From, info} -> gen_server:reply(From, Info) end
end),
?assertEqual(Info, ?ws_conn:call(WsPid, info)).
t_init(_) ->
Opts = [{idle_timeout, 300000}],
WsOpts = #{compress => false,
deflate_opts => #{},
max_frame_size => infinity,
idle_timeout => 300000
},
ok = meck:expect(cowboy_req, parse_header, fun(_, req) -> undefined end),
{cowboy_websocket, req, [req, Opts], WsOpts} = ?ws_conn:init(req, Opts),
ok = meck:expect(cowboy_req, parse_header, fun(_, req) -> [<<"mqtt">>] end),
ok = meck:expect(cowboy_req, set_resp_header, fun(_, <<"mqtt">>, req) -> resp end),
{cowboy_websocket, resp, [req, Opts], WsOpts} = ?ws_conn:init(req, Opts).
t_websocket_handle_binary(_) ->
with_ws_conn(fun(WsConn) ->
{ok, _} = websocket_handle({binary, [<<>>]}, WsConn)
end).
{ok, _} = websocket_handle({binary, <<>>}, st()),
{ok, _} = websocket_handle({binary, [<<>>]}, st()),
{ok, _} = websocket_handle({binary, <<192,0>>}, st()),
receive {incoming, ?PACKET(?PINGREQ)} -> ok
after 0 -> error(expect_incoming_pingreq)
end.
t_websocket_handle_ping_pong(_) ->
with_ws_conn(fun(WsConn) ->
{ok, WsConn} = websocket_handle(ping, WsConn),
{ok, WsConn} = websocket_handle(pong, WsConn),
{ok, WsConn} = websocket_handle({ping, <<>>}, WsConn),
{ok, WsConn} = websocket_handle({pong, <<>>}, WsConn)
end).
t_websocket_handle_ping(_) ->
{ok, St} = websocket_handle(ping, St = st()),
{ok, St} = websocket_handle({ping, <<>>}, St).
t_websocket_handle_pong(_) ->
{ok, St} = websocket_handle(pong, St = st()),
{ok, St} = websocket_handle({pong, <<>>}, St).
t_websocket_handle_bad_frame(_) ->
with_ws_conn(fun(WsConn) ->
{stop, WsConn1} = websocket_handle({badframe, <<>>}, WsConn),
?assertEqual({shutdown, unexpected_ws_frame}, stop_reason(WsConn1))
end).
{stop, St} = websocket_handle({badframe, <<>>}, st()),
{shutdown, unexpected_ws_frame} = ?ws_conn:info(stop_reason, St).
t_websocket_info_call(_) ->
with_ws_conn(fun(WsConn) ->
From = {make_ref(), self()},
Call = {call, From, badreq},
websocket_info(Call, WsConn)
end).
{ok, _St} = websocket_info(Call, st()).
t_websocket_info_rate_limit(_) ->
{ok, _} = websocket_info({cast, rate_limit}, st()),
ok = timer:sleep(1),
receive
{check_gc, Stats} ->
?assertEqual(#{cnt => 0, oct => 0}, Stats)
after 0 -> error(expect_check_gc)
end.
t_websocket_info_cast(_) ->
ok = meck:expect(emqx_channel, handle_info, fun(_Msg, Channel) -> {ok, Channel} end),
with_ws_conn(fun(WsConn) -> websocket_info({cast, msg}, WsConn) end).
{ok, _St} = websocket_info({cast, msg}, st()).
t_websocket_info_incoming(_) ->
ok = meck:expect(emqx_channel, handle_in, fun(_Packet, Channel) -> {ok, Channel} end),
with_ws_conn(fun(WsConn) ->
Connect = ?CONNECT_PACKET(
#mqtt_packet_connect{proto_ver = ?MQTT_PROTO_V5,
ConnPkt = #mqtt_packet_connect{
proto_name = <<"MQTT">>,
clientid = <<>>,
proto_ver = ?MQTT_PROTO_V5,
is_bridge = false,
clean_start = true,
keepalive = 60}),
{ok, WsConn1} = websocket_info({incoming, Connect}, WsConn),
keepalive = 60,
properties = undefined,
clientid = <<"clientid">>,
username = <<"username">>,
password = <<"passwd">>
},
{reply, [{binary, IoData1}], St1} =
websocket_info({incoming, ?CONNECT_PACKET(ConnPkt)}, st()),
?assertEqual(<<224,2,130,0>>, iolist_to_binary(IoData1)),
%% PINGREQ
{reply, [{binary, IoData2}], St2} =
websocket_info({incoming, ?PACKET(?PINGREQ)}, St1),
?assertEqual(<<208,0>>, iolist_to_binary(IoData2)),
%% PUBLISH
Publish = ?PUBLISH_PACKET(?QOS_1, <<"t">>, 1, <<"payload">>),
{ok, _WsConn2} = websocket_info({incoming, Publish}, WsConn1)
end).
{reply, [{binary, IoData3}], _St3} =
websocket_info({incoming, Publish}, St2),
?assertEqual(<<64,4,0,1,0,0>>, iolist_to_binary(IoData3)).
t_websocket_info_check_gc(_) ->
Stats = #{cnt => 10, oct => 1000},
{ok, _St} = websocket_info({check_gc, Stats}, st()).
t_websocket_info_deliver(_) ->
with_ws_conn(fun(WsConn) ->
ok = meck:expect(emqx_channel, handle_deliver,
fun(Delivers, Channel) ->
Packets = [emqx_message:to_packet(1, Msg) || {deliver, _, Msg} <- Delivers],
{ok, {outgoing, Packets}, Channel}
end),
Deliver = {deliver, <<"#">>, emqx_message:make(<<"topic">>, <<"payload">>)},
{reply, {binary, _Data}, _WsConn1} = websocket_info(Deliver, WsConn)
end).
Msg0 = emqx_message:make(clientid, ?QOS_0, <<"t">>, <<"">>),
Msg1 = emqx_message:make(clientid, ?QOS_1, <<"t">>, <<"">>),
self() ! {deliver, <<"#">>, Msg1},
{reply, [{binary, IoData}], _St} =
websocket_info({deliver, <<"#">>, Msg0}, st()),
?assertEqual(<<48,3,0,1,116,50,5,0,1,116,0,1>>, iolist_to_binary(IoData)).
t_websocket_info_timeout(_) ->
with_ws_conn(fun(WsConn) ->
websocket_info({timeout, make_ref(), keepalive}, WsConn),
websocket_info({timeout, make_ref(), emit_stats}, WsConn),
websocket_info({timeout, make_ref(), retry_delivery}, WsConn)
end).
t_websocket_info_timeout_limiter(_) ->
Ref = make_ref(),
Event = {timeout, Ref, limit_timeout},
{reply, [{active, true}], St} =
websocket_info(Event, st(#{limit_timer => Ref})),
?assertEqual([], ?ws_conn:info(postponed, St)).
t_websocket_info_timeout_keepalive(_) ->
{ok, _St} = websocket_info({timeout, make_ref(), keepalive}, st()).
t_websocket_info_timeout_emit_stats(_) ->
Ref = make_ref(),
St = st(#{stats_timer => Ref}),
{ok, St1} = websocket_info({timeout, Ref, emit_stats}, St),
?assertEqual(undefined, ?ws_conn:info(stats_timer, St1)).
t_websocket_info_timeout_retry(_) ->
{ok, _St} = websocket_info({timeout, make_ref(), retry_delivery}, st()).
t_websocket_info_close(_) ->
with_ws_conn(fun(WsConn) ->
{stop, WsConn1} = websocket_info({close, sock_error}, WsConn),
?assertEqual({shutdown, sock_error}, stop_reason(WsConn1))
end).
{stop, St} = websocket_info({close, sock_error}, st()),
?assertEqual({shutdown, sock_error}, ?ws_conn:info(stop_reason, St)).
t_websocket_info_shutdown(_) ->
with_ws_conn(fun(WsConn) ->
{stop, WsConn1} = websocket_info({shutdown, reason}, WsConn),
?assertEqual({shutdown, reason}, stop_reason(WsConn1))
end).
{stop, St} = websocket_info({shutdown, reason}, st()),
?assertEqual({shutdown, reason}, ?ws_conn:info(stop_reason, St)).
t_websocket_info_stop(_) ->
with_ws_conn(fun(WsConn) ->
{stop, WsConn1} = websocket_info({stop, normal}, WsConn),
?assertEqual(normal, stop_reason(WsConn1))
end).
{stop, St} = websocket_info({stop, normal}, st()),
?assertEqual(normal, ?ws_conn:info(stop_reason, St)).
t_websocket_close(_) ->
ok = meck:expect(emqx_channel, handle_info,
fun({sock_closed, badframe}, Channel) ->
{shutdown, sock_closed, Channel}
end),
with_ws_conn(fun(WsConn) ->
{stop, WsConn1} = emqx_ws_connection:websocket_close(badframe, WsConn),
?assertEqual(sock_closed, stop_reason(WsConn1))
end).
{stop, St} = websocket_close(badframe, st()),
?assertEqual({shutdown, badframe}, ?ws_conn:info(stop_reason, St)).
t_handle_call(_) ->
with_ws_conn(fun(WsConn) -> ok end).
t_handle_info_connack(_) ->
ConnAck = ?CONNACK_PACKET(?RC_SUCCESS),
{reply, [{binary, IoData}], _St} =
?ws_conn:handle_info({connack, ConnAck}, st()),
?assertEqual(<<32,2,0,0>>, iolist_to_binary(IoData)).
t_handle_info(_) ->
with_ws_conn(fun(WsConn) -> ok end).
t_handle_info_close(_) ->
{stop, St} = ?ws_conn:handle_info({close, protocol_error}, st()),
?assertEqual({shutdown, protocol_error},
?ws_conn:info(stop_reason, St)).
t_handle_timeout(_) ->
with_ws_conn(fun(WsConn) -> ok end).
t_handle_info_event(_) ->
ok = meck:new(emqx_cm, [passthrough, no_history]),
ok = meck:expect(emqx_cm, register_channel, fun(_,_,_) -> ok end),
ok = meck:expect(emqx_cm, connection_closed, fun(_) -> true end),
{ok, _} = ?ws_conn:handle_info({event, connected}, st()),
{ok, _} = ?ws_conn:handle_info({event, disconnected}, st()),
{ok, _} = ?ws_conn:handle_info({event, updated}, st()),
ok = meck:unload(emqx_cm).
t_handle_timeout_idle_timeout(_) ->
TRef = make_ref(),
{stop, St} = ?ws_conn:handle_timeout(
TRef, idle_timeout, st(#{idle_timer => TRef})),
?assertEqual({shutdown, idle_timeout}, ?ws_conn:info(stop_reason, St)).
t_handle_timeout_keepalive(_) ->
{ok, _St} = ?ws_conn:handle_timeout(make_ref(), keepalive, st()).
t_handle_timeout_emit_stats(_) ->
TRef = make_ref(),
{ok, St} = ?ws_conn:handle_timeout(
TRef, emit_stats, st(#{stats_timer => TRef})),
?assertEqual(undefined, ?ws_conn:info(stats_timer, St)).
t_ensure_rate_limit(_) ->
Limiter = emqx_limiter:init([{pub_limit, {1, 10}},
{rate_limit, {100, 1000}}
]),
St = st(#{limiter => Limiter}),
St1 = ?ws_conn:ensure_rate_limit(#{cnt => 0, oct => 0}, St),
St2 = ?ws_conn:ensure_rate_limit(#{cnt => 11, oct => 1200}, St1),
?assertEqual(blocked, ?ws_conn:info(sockstate, St2)),
?assertEqual([{active, false}], ?ws_conn:info(postponed, St2)).
t_parse_incoming(_) ->
with_ws_conn(fun(WsConn) -> ok end).
St = ?ws_conn:parse_incoming(<<48,3>>, st()),
St1 = ?ws_conn:parse_incoming(<<0,1,116>>, St),
Packet = ?PUBLISH_PACKET(?QOS_0, <<"t">>, undefined, <<>>),
[{incoming, Packet}] = ?ws_conn:info(postponed, St1).
t_handle_incoming(_) ->
with_ws_conn(fun(WsConn) -> ok end).
t_parse_incoming_frame_error(_) ->
St = ?ws_conn:parse_incoming(<<3,2,1,0>>, st()),
FrameError = {frame_error, function_clause},
[{incoming, FrameError}] = ?ws_conn:info(postponed, St).
t_handle_return(_) ->
with_ws_conn(fun(WsConn) -> ok end).
t_handle_incomming_frame_error(_) ->
FrameError = {frame_error, bad_qos},
Serialize = emqx_frame:serialize_fun(#{version => 5, max_size => 16#FFFF}),
{reply, [{binary, IoData}], _St} =
?ws_conn:handle_incoming(FrameError, st(#{serialize => Serialize})),
?assertEqual(<<224,2,129,0>>, iolist_to_binary(IoData)).
t_handle_outgoing(_) ->
with_ws_conn(fun(WsConn) -> ok end).
Packets = [?PUBLISH_PACKET(?QOS_1, <<"t1">>, 1, <<"payload">>),
?PUBLISH_PACKET(?QOS_2, <<"t2">>, 2, <<"payload">>)
],
{{binary, IoData}, _St} = ?ws_conn:handle_outgoing(Packets, st()),
?assert(is_binary(iolist_to_binary(IoData))).
t_run_gc(_) ->
GcSt = emqx_gc:init(#{count => 10, bytes => 100}),
WsSt = st(#{gc_state => GcSt}),
?ws_conn:run_gc(#{cnt => 100, oct => 10000}, WsSt).
t_check_oom(_) ->
%%Policy = #{max_heap_size => 10, message_queue_len => 10},
%%meck:expect(emqx_zone, oom_policy, fun(_) -> Policy end),
_St = ?ws_conn:check_oom(st()),
ok = timer:sleep(10).
%%receive {shutdown, proc_heap_too_large} -> ok
%%after 0 -> error(expect_shutdown)
%%end.
t_enqueue(_) ->
Packet = ?PUBLISH_PACKET(?QOS_0),
St = ?ws_conn:enqueue(Packet, st()),
[Packet] = ?ws_conn:info(postponed, St).
t_shutdown(_) ->
{stop, St} = ?ws_conn:shutdown(closed, st()),
{shutdown, closed} = ?ws_conn:info(stop_reason, St).
t_stop(_) ->
St = st(#{postponed => [{active, false}]}),
{reply, [{active, false}, close], _} = ?ws_conn:stop(closed, St).
%%--------------------------------------------------------------------
%% Helper functions
%%--------------------------------------------------------------------
with_ws_conn(TestFun) ->
with_ws_conn(TestFun, []).
st() -> st(#{}).
st(InitFields) when is_map(InitFields) ->
{ok, St, _} = ?ws_conn:websocket_init([req, [{zone, external}]]),
maps:fold(fun(N, V, S) -> ?ws_conn:set_field(N, V, S) end,
?ws_conn:set_field(channel, channel(), St),
InitFields
).
with_ws_conn(TestFun, Opts) ->
{ok, WsConn, _} = emqx_ws_connection:websocket_init(
[req, emqx_misc:merge_opts([{zone, external}], Opts)]),
TestFun(WsConn).
stop_reason(WsConn) ->
emqx_ws_connection:info(stop_reason, WsConn).
channel() -> channel(#{}).
channel(InitFields) ->
ConnInfo = #{peername => {{127,0,0,1}, 3456},
sockname => {{127,0,0,1}, 18083},
conn_mod => emqx_ws_connection,
proto_name => <<"MQTT">>,
proto_ver => ?MQTT_PROTO_V5,
clean_start => true,
keepalive => 30,
clientid => <<"clientid">>,
username => <<"username">>,
receive_maximum => 100,
expiry_interval => 0
},
ClientInfo = #{zone => zone,
protocol => mqtt,
peerhost => {127,0,0,1},
clientid => <<"clientid">>,
username => <<"username">>,
is_superuser => false,
peercert => undefined,
mountpoint => undefined
},
Session = emqx_session:init(#{zone => external},
#{receive_maximum => 0}
),
maps:fold(fun(Field, Value, Channel) ->
emqx_channel:set_field(Field, Value, Channel)
end,
emqx_channel:init(ConnInfo, [{zone, zone}]),
maps:merge(#{clientinfo => ClientInfo,
session => Session,
conn_state => connected
}, InitFields)).