%% Source: emqx/apps/emqx_stomp/src/emqx_stomp_connection.erl
%% (repository viewer metadata: 530 lines, 18 KiB, Erlang)

%%--------------------------------------------------------------------
%% Copyright (c) 2020-2022 EMQ Technologies Co., Ltd. All Rights Reserved.
%%
%% Licensed under the Apache License, Version 2.0 (the "License");
%% you may not use this file except in compliance with the License.
%% You may obtain a copy of the License at
%%
%% http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing, software
%% distributed under the License is distributed on an "AS IS" BASIS,
%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
%% See the License for the specific language governing permissions and
%% limitations under the License.
%%--------------------------------------------------------------------
-module(emqx_stomp_connection).
-behaviour(gen_server).
-include("emqx_stomp.hrl").
-include_lib("emqx/include/logger.hrl").
-include_lib("emqx/include/types.hrl").
-include_lib("snabbkaffe/include/snabbkaffe.hrl").
-logger_header("[Stomp-Conn]").
-import(emqx_misc,
[ start_timer/2
]).
-export([ start_link/3
, info/1
]).
%% gen_server Function Exports
-export([ init/1
, handle_call/3
, handle_cast/2
, handle_info/2
, code_change/3
, terminate/2
]).
%% for protocol
-export([send/4, heartbeat/2, statfun/3]).
%% for mgmt
-export([call/2, call/3]).
%% Per-connection process state of a STOMP client connection.
%% NOTE: records are tuples at runtime, so the field order below is the
%% in-memory layout of #state{}. code_change/3 passes the state through
%% unchanged, so reordering fields would break a hot code upgrade.
-record(state, {
%% TCP/TLS transport module (esockd transport)
transport :: esockd:transport(),
%% TCP/TLS socket handed over by Transport:wait/1 in init/1
socket :: esockd:socket(),
%% Peer address of the connection
peername :: emqx_types:peername(),
%% Local address of the connection
sockname :: emqx_types:peername(),
%% Socket state: idle (initial / after a rate-limit pause),
%% running (active), blocked (rate-limited) or closed
sockstate :: emqx_types:sockstate(),
%% The {active, N} value used when (re-)activating the socket;
%% taken from proto_env `active_n', defaulting to ?ACTIVE_N
active_n :: pos_integer(),
%% Rate limiter; set to undefined by init_state/3, in which case
%% ensure_rate_limit/2 skips the check
limiter :: maybe(emqx_limiter:limiter()),
%% Timer started by ensure_rate_limit/2 when the limiter asks to pause
limit_timer :: maybe(reference()),
%% emqx_gc state used by run_gc/2 (skipped when undefined)
gc_state :: maybe(emqx_gc:gc_state()),
%% emit_stats timer armed by ensure_stats_timer/2 and cleared when it
%% fires; NOTE(review): `disabled' is never assigned in this module
stats_timer :: disabled | maybe(reference()),
%% Frame parser state; reset after every complete frame
parser :: emqx_stomp_frame:parser(),
%% Protocol (channel) state owned by emqx_stomp_protocol
pstate :: emqx_stomp_protocol:pstate(),
%% Listener/protocol environment proplist; used to build the parser
%% and read `active_n'
proto_env :: list()
}).
-type(state() :: #state{}).
%% GC thresholds handed to emqx_gc:init/1 (bytes = 16 MiB = 16777216).
-define(DEFAULT_GC_POLICY, #{bytes => 16 * 1024 * 1024,
                             count => 16000}).
%% OOM thresholds checked by emqx_misc:check_oom/1 in check_oom/1
%% (max_heap_size = 8 * 1024 * 1024 = 8388608).
-define(DEFAULT_OOM_POLICY, #{max_heap_size     => 8 * 1024 * 1024,
                              message_queue_len => 10000}).
%% Default {active, N} when proto_env has no `active_n'.
-define(ACTIVE_N, 100).
%% Interval (ms) used when arming the emit_stats timer.
-define(IDLE_TIMEOUT, 30 * 1000).
%% Keys reported under `sockinfo' by info/1.
-define(INFO_KEYS, [socktype,
                    peername,
                    sockname,
                    sockstate,
                    active_n]).
%% Process-dictionary counters reported by stats/1.
-define(CONN_STATS, [recv_pkt,
                     recv_msg,
                     send_pkt,
                     send_msg]).
%% Socket statistics requested from Transport:getstat/2 by stats/1.
-define(SOCK_STATS, [recv_oct,
                     recv_cnt,
                     send_oct,
                     send_cnt,
                     send_pend]).
%% True when an optional field holds a value.
-define(ENABLED(X), (undefined =/= (X))).
%% Spawns a connection process linked to the caller and returns its pid
%% right away. The new process runs init/1, which never returns: it
%% either enters the gen_server loop or exits.
start_link(Transport, Sock, ProtoEnv) ->
    InitArgs = [Transport, Sock, ProtoEnv],
    Pid = proc_lib:spawn_link(?MODULE, init, [InitArgs]),
    {ok, Pid}.
%% Connection info. For a pid, the connection process is queried via
%% call/2. For a state, the protocol-level info map is returned with
%% the socket-level info (?INFO_KEYS) added under `sockinfo'.
-spec info(pid() | state()) -> emqx_types:infos().
info(CPid) when is_pid(CPid) ->
    call(CPid, info);
info(State = #state{pstate = PState}) ->
    ChanInfo = emqx_stomp_protocol:info(PState),
    SockPairs = info(?INFO_KEYS, State),
    maps:put(sockinfo, maps:from_list(SockPairs), ChanInfo).
%% Socket-level info lookup. Given a list of keys, returns a
%% `[{Key, Value}]' list in the same order; given a single key
%% (socktype, peername, sockname, sockstate or active_n), returns
%% just that value.
info(Keys, State) when is_list(Keys) ->
    lists:map(fun(Key) -> {Key, info(Key, State)} end, Keys);
info(socktype, #state{transport = Transport, socket = Socket}) ->
    Transport:type(Socket);
info(peername, #state{peername = Value}) -> Value;
info(sockname, #state{sockname = Value}) -> Value;
info(sockstate, #state{sockstate = Value}) -> Value;
info(active_n, #state{active_n = Value}) -> Value.
%% Connection statistics. For a pid, the connection process is queried
%% via call/2. For a state, the result is the concatenation of socket
%% stats (empty if getstat fails), process-dictionary counters,
%% protocol stats and process stats, in that order.
-spec stats(pid() | state()) -> emqx_types:stats().
stats(CPid) when is_pid(CPid) ->
    call(CPid, stats);
stats(#state{transport = Transport, socket = Socket, pstate = PState}) ->
    SockStats =
        case Transport:getstat(Socket, ?SOCK_STATS) of
            {error, _Reason} -> [];
            {ok, Collected} -> Collected
        end,
    ConnStats = emqx_pd:get_counters(?CONN_STATS),
    ChanStats = emqx_stomp_protocol:stats(PState),
    ProcStats = emqx_misc:proc_stats(),
    SockStats ++ ConnStats ++ ChanStats ++ ProcStats.
%% Synchronous request to a connection process (management API).
%% call/2 waits indefinitely; call/3 takes an explicit timeout.
call(Pid, Req) ->
    Timeout = infinity,
    call(Pid, Req, Timeout).
call(Server, Request, Timeout) ->
    gen_server:call(Server, Request, Timeout).
%% Entry point of the process spawned by start_link/3. Waits for the
%% transport to hand over the socket (Transport:wait/1). On failure the
%% raw socket is closed and the process exits. On success
%% init_state/3 is called, which does not return.
-spec init([term()]) -> no_return().
init([Transport, RawSocket, ProtoEnv]) ->
    case Transport:wait(RawSocket) of
        {error, Reason} ->
            ok = Transport:fast_close(RawSocket),
            exit_on_sock_error(Reason);
        {ok, Socket} ->
            init_state(Transport, Socket, ProtoEnv)
    end.
%% Builds the initial #state{} and hands the process over to
%% gen_server:enter_loop/4; never returns normally.
%%
%% Peername, sockname and peercert are read with
%% Transport:ensure_ok_or_exit/2 (presumably exits on a socket error;
%% confirm against esockd). The protocol layer gets send, heartbeat and
%% stat callbacks bound to this transport and socket, so it never
%% touches the socket directly.
-spec init_state(module(), port(), [proplists:property()]) -> no_return().
init_state(Transport, Socket, ProtoEnv) ->
    {ok, Peername} = Transport:ensure_ok_or_exit(peername, [Socket]),
    {ok, Sockname} = Transport:ensure_ok_or_exit(sockname, [Socket]),
    Callbacks = #{sendfun  => {fun ?MODULE:send/4, [Transport, Socket, self()]},
                  statfun  => {fun ?MODULE:statfun/3, [Transport, Socket]},
                  heartfun => {fun ?MODULE:heartbeat/2, [Transport, Socket]}},
    Parser = emqx_stomp_frame:init_parer_state(ProtoEnv),
    ActiveN = proplists:get_value(active_n, ProtoEnv, ?ACTIVE_N),
    GcState = emqx_gc:init(?DEFAULT_GC_POLICY),
    Peercert = Transport:ensure_ok_or_exit(peercert, [Socket]),
    ConnInfo = Callbacks#{socktype => Transport:type(Socket),
                          peername => Peername,
                          sockname => Sockname,
                          peercert => Peercert,
                          conn_mod => ?MODULE},
    PState = emqx_stomp_protocol:init(ConnInfo, ProtoEnv),
    InitState = #state{transport = Transport,
                       socket    = Socket,
                       peername  = Peername,
                       sockname  = Sockname,
                       sockstate = idle,
                       active_n  = ActiveN,
                       limiter   = undefined,
                       parser    = Parser,
                       proto_env = ProtoEnv,
                       gc_state  = GcState,
                       pstate    = PState},
    case activate_socket(InitState) of
        {error, Reason} ->
            ok = Transport:fast_close(Socket),
            exit_on_sock_error(Reason);
        {ok, ActiveState} ->
            emqx_logger:set_metadata_peername(esockd:format(Peername)),
            %% Hibernate after 5 s idle. The 20 s gen_server timeout is
            %% turned into an idle_timeout shutdown by handle_info/2.
            gen_server:enter_loop(?MODULE, [{hibernate_after, 5000}],
                                  ActiveState, 20000)
    end.
%% Terminates the calling process after a socket failure.
%% einval, enotconn and closed are treated as ordinary disconnects
%% (exit `normal'). `timeout' maps to {shutdown, ssl_upgrade_timeout};
%% anything else maps to {shutdown, Reason}.
-spec exit_on_sock_error(any()) -> no_return().
exit_on_sock_error(Reason) ->
    ExitReason =
        case Reason of
            einval   -> normal;
            enotconn -> normal;
            closed   -> normal;
            timeout  -> {shutdown, ssl_upgrade_timeout};
            _Other   -> {shutdown, Reason}
        end,
    erlang:exit(ExitReason).
%% Serializes Frame and sends it with Transport:async_send/2. Outgoing
%% stats are bumped before sending. A send failure (error return or
%% raised error) is not propagated to the caller. Instead
%% `{shutdown, Reason}' is posted to ConnPid, and handle_info/2 turns
%% that into a connection shutdown.
send(Frame, Transport, Sock, ConnPid) ->
    ?LOG(info, "SEND Frame: ~s", [emqx_stomp_frame:format(Frame)]),
    ok = inc_outgoing_stats(Frame),
    Payload = emqx_stomp_frame:serialize(Frame),
    ?LOG(debug, "SEND ~p", [Payload]),
    try Transport:async_send(Sock, Payload) of
        {error, Why} -> ConnPid ! {shutdown, Why};
        ok -> ok
    catch
        error:SendError -> ConnPid ! {shutdown, SendError}
    end.
%% Sends a heart-beat (a single newline byte) with Transport:send/2 and
%% returns the transport's result.
heartbeat(Transport, Sock) ->
    ?LOG(debug, "SEND heartbeat: \\n"),
    Transport:send(Sock, <<"\n">>).
%% Reads a single socket statistic through Transport:getstat/2.
%% Returns {ok, Value}, or the transport's {error, Reason} unchanged.
statfun(Stat, Transport, Sock) ->
    case Transport:getstat(Sock, [Stat]) of
        {error, _} = Failure -> Failure;
        {ok, [{Stat, Value}]} -> {ok, Value}
    end.
%% gen_server handle_call/3 callback (management API, see call/2,3).
%%   info    -> reply with info/1 of the current state
%%   stats   -> reply with stats/1 of the current state
%%   discard -> reply ok, then stop with {shutdown, discarded}
%%   kick    -> reply ok, then stop with {shutdown, kicked}
%% Unknown requests are logged and answered with `ignored'.
handle_call(info, _From, State) ->
    {reply, info(State), State};
handle_call(stats, _From, State) ->
    {reply, stats(State), State};
handle_call(discard, _From, State) ->
    %% TODO: send the DISCONNECT packet?
    %% The reason used to be the misspelt atom `discared'. It is passed
    %% on to emqx_stomp_protocol:shutdown/2 by terminate/2, so it should
    %% read correctly, in the same past-tense form as `kicked'.
    shutdown_and_reply(discarded, ok, State);
handle_call(kick, _From, State) ->
    shutdown_and_reply(kicked, ok, State);
handle_call(Req, _From, State) ->
    ?LOG(error, "unexpected request: ~p", [Req]),
    {reply, ignored, State}.
%% gen_server handle_cast/2 callback. This module defines no casts, so
%% any cast is logged as unexpected and the state is kept unchanged.
handle_cast(Msg, State) ->
    ?LOG(error, "unexpected msg: ~p", [Msg]),
    {noreply, State}.
%% gen_server handle_info/2 callback.
%% Clause order matters. The incoming/outgoing, limit_timeout and
%% emit_stats timeout clauses must precede the generic
%% {timeout, TRef, TMsg} clause. {inet_reply, _, ok} must precede
%% {inet_reply, _, {error, _}}. The final catch-all clause forwards
%% anything unrecognised to emqx_stomp_protocol:handle_info/2 via
%% with_proto/3.
%%
%% The protocol layer reports the client as connected/updated: refresh
%% the channel info and stats registered with emqx_cm.
handle_info({event, Name}, State = #state{pstate = PState})
when Name == connected;
Name == updated ->
ClientId = emqx_stomp_protocol:info(clientid, PState),
emqx_cm:insert_channel_info(ClientId, info(State), stats(State)),
noreply(State);
%% gen_server timeout: init_state/3 enters the loop with a 20000 ms
%% timeout and the callbacks here return no timeout, so this only
%% fires if nothing arrives after start-up.
handle_info(timeout, State) ->
shutdown(idle_timeout, State);
%% Shutdown request, e.g. posted by send/4 when a send fails.
handle_info({shutdown, Reason}, State) ->
shutdown(Reason, State);
%% incoming/outgoing timers (presumably heart-beat timers started by
%% emqx_stomp_protocol): sample the matching socket byte counter and let
%% the protocol layer evaluate it.
handle_info({timeout, TRef, TMsg}, State) when TMsg =:= incoming;
TMsg =:= outgoing ->
Stat = case TMsg of
incoming -> recv_oct;
_ -> send_oct
end,
case getstat(Stat, State) of
{ok, Val} ->
with_proto(timeout, [TRef, {TMsg, Val}], State);
{error, Reason} ->
shutdown({sock_error, Reason}, State)
end;
%% Rate-limit pause is over (see ensure_rate_limit/2): unblock and
%% re-activate the socket.
handle_info({timeout, _TRef, limit_timeout}, State) ->
NState = State#state{sockstate = idle,
limit_timer = undefined
},
handle_info(activate_socket, NState);
%% Push stats to emqx_cm. The timer is re-armed by ensure_stats_timer/2
%% on the next incoming data.
handle_info({timeout, _TRef, emit_stats},
State = #state{pstate = PState}) ->
ClientId = emqx_stomp_protocol:info(clientid, PState),
emqx_cm:set_chan_stats(ClientId, stats(State)),
noreply(State#state{stats_timer = undefined});
%% Any other timer is handed to the protocol layer.
handle_info({timeout, TRef, TMsg}, State) ->
with_proto(timeout, [TRef, TMsg], State);
%% (Re-)enable {active, N} delivery. This is a no-op while the socket is
%% blocked or closed (see activate_socket/1).
handle_info(activate_socket, State) ->
case activate_socket(State) of
{ok, NState} ->
noreply(NState);
{error, Reason} ->
handle_info({sock_error, Reason}, State)
end;
%% Successful async send acknowledgement; nothing to do.
handle_info({inet_reply, _Ref, ok}, State) ->
noreply(State);
%% Incoming socket data: account bytes, arm the stats timer, then parse.
handle_info({Inet, _Sock, Data}, State) when Inet == tcp; Inet == ssl ->
?LOG(debug, "RECV ~0p", [Data]),
Oct = iolist_size(Data),
inc_counter(incoming_bytes, Oct),
ok = emqx_metrics:inc('bytes.received', Oct),
received(Data, ensure_stats_timer(?IDLE_TIMEOUT, State));
%% The socket used up its {active, N} budget and went passive. Collect
%% the batch counters, apply rate limiting, GC and OOM checks, then
%% re-activate.
handle_info({Passive, _Sock}, State)
when Passive == tcp_passive; Passive == ssl_passive ->
%% In Stats
Pubs = emqx_pd:reset_counter(incoming_pubs),
Bytes = emqx_pd:reset_counter(incoming_bytes),
InStats = #{cnt => Pubs, oct => Bytes},
%% Ensure Rate Limit
NState = ensure_rate_limit(InStats, State),
%% Run GC and Check OOM
NState1 = check_oom(run_gc(InStats, NState)),
handle_info(activate_socket, NState1);
%% Transport-level socket error.
handle_info({Error, _Sock, Reason}, State)
when Error == tcp_error; Error == ssl_error ->
handle_info({sock_error, Reason}, State);
%% Peer closed the socket: close our side and stop.
handle_info({Closed, _Sock}, State)
when Closed == tcp_closed; Closed == ssl_closed ->
handle_info({sock_closed, Closed}, close_socket(State));
%% Failed async send reported asynchronously.
handle_info({inet_reply, _Sock, {error, Reason}}, State) ->
handle_info({sock_error, Reason}, State);
%% Log socket errors other than the routine closed/einval, then close
%% the socket and stop.
handle_info({sock_error, Reason}, State) ->
case Reason =/= closed andalso Reason =/= einval of
true -> ?LOG(warning, "socket_error: ~p", [Reason]);
false -> ok
end,
handle_info({sock_closed, Reason}, close_socket(State));
handle_info({sock_closed, Reason}, State) ->
shutdown(Reason, State);
%% Message to deliver to the client (presumably routed from a
%% subscription). Both a sent and a dropped message keep the
%% connection running.
handle_info({deliver, _Topic, Msg}, State = #state{pstate = PState}) ->
noreply(State#state{pstate = case emqx_stomp_protocol:send(Msg, PState) of
{ok, PState1} ->
PState1;
{error, dropped, PState1} ->
PState1
end});
%% Anything else goes to emqx_stomp_protocol:handle_info/2.
handle_info(Info, State) ->
with_proto(handle_info, [Info], State).
%% gen_server terminate/2 callback. Closes the socket and, when a
%% protocol state exists, notifies emqx_stomp_protocol:shutdown/2. A
%% {shutdown, Reason} wrapper is removed first, so the protocol layer
%% sees the bare reason.
terminate(Reason, #state{transport = Transport,
                         socket = Sock,
                         pstate = PState}) ->
    ?LOG(info, "terminated for ~p", [Reason]),
    Transport:fast_close(Sock),
    ProtoReason =
        case Reason of
            {shutdown, Inner} -> Inner;
            _ -> Reason
        end,
    case PState of
        undefined -> ok;
        _ -> emqx_stomp_protocol:shutdown(ProtoReason, PState)
    end.
%% gen_server code_change/3 callback: the state needs no migration.
code_change(_Vsn, CurrentState, _UpgradeExtra) -> {ok, CurrentState}.
%%--------------------------------------------------------------------
%% Receive and Parse data
%%--------------------------------------------------------------------
%% Calls emqx_stomp_protocol:Fun(Args ++ [PState]) and folds the result
%% back into the connection state. `ok' keeps the state, {ok, NPState}
%% stores the new protocol state, and stop/error/shutdown results all
%% end the connection with {shutdown, Reason}.
with_proto(Fun, Args, State = #state{pstate = PState}) ->
    CallArgs = Args ++ [PState],
    case apply(emqx_stomp_protocol, Fun, CallArgs) of
        {Tag, Reason, NewPState}
          when Tag =:= stop; Tag =:= error; Tag =:= shutdown ->
            shutdown(Reason, State#state{pstate = NewPState});
        {ok, NewPState} ->
            noreply(State#state{pstate = NewPState});
        ok ->
            noreply(State)
    end.
%% Parses as many complete STOMP frames from Bytes as possible and feeds
%% each one to emqx_stomp_protocol:received/2. A partial frame is kept
%% in the parser state until more data arrives. After each complete
%% frame the parser is reset and parsing continues on the remaining
%% bytes.
%%
%% A framing error ({error, _} from the parser) stops the connection
%% with {shutdown, frame_error}. An exception raised by the parser stops
%% it with {shutdown, parse_error}. Protocol results {error, ...} and
%% {stop, ...} stop the connection with {shutdown, Error} and Reason
%% respectively.
received(<<>>, State) ->
    noreply(State);
received(Bytes, State = #state{parser = Parser,
                               pstate = PState}) ->
    try emqx_stomp_frame:parse(Bytes, Parser) of
        {more, NewParser} ->
            noreply(State#state{parser = NewParser});
        {ok, Frame, Rest} ->
            ?LOG(info, "RECV Frame: ~s", [emqx_stomp_frame:format(Frame)]),
            ok = inc_incoming_stats(Frame),
            case emqx_stomp_protocol:received(Frame, PState) of
                {ok, PState1} ->
                    %% The `of' section is outside the try's protection,
                    %% so this recursion is not wrapped by the catch.
                    received(Rest, reset_parser(State#state{pstate = PState1}));
                {error, Error, PState1} ->
                    shutdown(Error, State#state{pstate = PState1});
                {stop, Reason, PState1} ->
                    stop(Reason, State#state{pstate = PState1})
            end;
        {error, Error} ->
            %% ~p, not ~s: the reason is not guaranteed to be printable
            %% chardata, and ~s on e.g. a tuple yields a format error.
            ?LOG(error, "Framing error - ~p", [Error]),
            ?LOG(error, "Bytes: ~p", [Bytes]),
            shutdown(frame_error, State)
    catch
        Class:Reason:Stacktrace ->
            %% Keep the class and stacktrace so parser crashes can be
            %% diagnosed.
            ?LOG(error, "Parser failed for ~p (~p), stacktrace: ~0p",
                 [Reason, Class, Stacktrace]),
            ?LOG(error, "Error data: ~p", [Bytes]),
            shutdown(parse_error, State)
    end.
%% Replaces the parser with a fresh one built from the stored protocol
%% environment, discarding any parse state from the previous frame.
reset_parser(State = #state{proto_env = ProtoEnv}) ->
    FreshParser = emqx_stomp_frame:init_parer_state(ProtoEnv),
    State#state{parser = FreshParser}.
%% Re-arms {active, N} on the socket and marks it `running'. A closed
%% or rate-limited (`blocked') socket is left alone and {ok, State} is
%% returned unchanged. If Transport:setopts/2 fails, its result is
%% returned as-is.
activate_socket(State = #state{sockstate = SockState})
  when SockState =:= closed; SockState =:= blocked ->
    {ok, State};
activate_socket(State = #state{transport = Transport,
                               socket = Socket,
                               active_n = ActiveN}) ->
    case Transport:setopts(Socket, [{active, ActiveN}]) of
        ok -> {ok, State#state{sockstate = running}};
        SetoptsError -> SetoptsError
    end.
%% Closes the socket and marks the state `closed'. Calling it again on
%% an already closed state is a no-op.
close_socket(State = #state{sockstate = SockState,
                            transport = Transport,
                            socket = Socket}) ->
    case SockState of
        closed ->
            State;
        _ ->
            ok = Transport:fast_close(Socket),
            State#state{sockstate = closed}
    end.
%%--------------------------------------------------------------------
%% Inc incoming/outgoing stats
%% Bumps per-connection counters and global metrics for an incoming
%% frame. A SEND frame also counts as a received message, including the
%% 'messages.qos1.received' metric. Returns the result of the final
%% 'packets.received' increment.
inc_incoming_stats(#stomp_frame{command = Cmd}) ->
    inc_counter(recv_pkt, 1),
    case Cmd =:= <<"SEND">> of
        true ->
            inc_counter(recv_msg, 1),
            inc_counter(incoming_pubs, 1),
            emqx_metrics:inc('messages.received'),
            emqx_metrics:inc('messages.qos1.received');
        false ->
            ok
    end,
    emqx_metrics:inc('packets.received').
%% Bumps per-connection counters and global metrics for an outgoing
%% frame. A MESSAGE frame also counts as a sent message, including the
%% 'messages.qos1.sent' metric. Returns the result of the final
%% 'packets.sent' increment.
inc_outgoing_stats(#stomp_frame{command = Cmd}) ->
    inc_counter(send_pkt, 1),
    case Cmd =:= <<"MESSAGE">> of
        true ->
            inc_counter(send_msg, 1),
            inc_counter(outgoing_pubs, 1),
            emqx_metrics:inc('messages.sent'),
            emqx_metrics:inc('messages.qos1.sent');
        false ->
            ok
    end,
    emqx_metrics:inc('packets.sent').
%%--------------------------------------------------------------------
%% Ensure rate limit
%% Applies the rate limiter, if one is set, to the stats of the batch
%% just received. On {pause, Time, _} the socket is marked `blocked'
%% and a limit_timeout timer is started; handle_info/2 unblocks the
%% socket when that timer fires. With no limiter the state is returned
%% unchanged.
ensure_rate_limit(Stats, State = #state{limiter = Limiter}) ->
    Verdict = ?ENABLED(Limiter) andalso emqx_limiter:check(Stats, Limiter),
    case Verdict of
        {pause, PauseMs, NewLimiter} ->
            ?LOG(notice, "Pause ~pms due to rate limit", [PauseMs]),
            State#state{sockstate   = blocked,
                        limiter     = NewLimiter,
                        limit_timer = start_timer(PauseMs, limit_timeout)};
        {ok, NewLimiter} ->
            State#state{limiter = NewLimiter};
        false ->
            State
    end.
%%--------------------------------------------------------------------
%% Run GC and Check OOM
%% Feeds the batch stats to emqx_gc:run/2 and keeps the updated GC
%% state. A no-op when no GC state is configured.
run_gc(Stats, State = #state{gc_state = GcState}) ->
    case ?ENABLED(GcState) andalso emqx_gc:run(Stats, GcState) of
        {_Collected, NewGcState} -> State#state{gc_state = NewGcState};
        false -> State
    end.
%% Checks this process against ?DEFAULT_OOM_POLICY. If
%% emqx_misc:check_oom/1 returns {shutdown, Reason}, the process exits
%% with that reason right away; gen_server then runs terminate/2.
%% Otherwise State is returned unchanged.
check_oom(State) ->
    Policy = ?DEFAULT_OOM_POLICY,
    ?tp(debug, check_oom, #{policy => Policy}),
    case ?ENABLED(Policy) andalso emqx_misc:check_oom(Policy) of
        {shutdown, Reason} ->
            erlang:exit({shutdown, Reason});
        _NoAction ->
            State
    end.
%%--------------------------------------------------------------------
%% Ensure/cancel stats timer
%% Starts the emit_stats timer (Timeout ms) unless one is already
%% pending; otherwise returns State unchanged.
ensure_stats_timer(Timeout, State) ->
    case State of
        #state{stats_timer = undefined} ->
            State#state{stats_timer = start_timer(Timeout, emit_stats)};
        _AlreadyArmed ->
            State
    end.
%% Reads one socket statistic for this connection. Returns {ok, Value}
%% or {error, Reason}. Delegates to statfun/3, the same helper handed to
%% the protocol layer, instead of duplicating its body, so the two code
%% paths cannot drift apart.
getstat(Stat, #state{transport = Transport, socket = Sock}) ->
    statfun(Stat, Transport, Sock).
%% Wraps State as a gen_server `noreply' callback result.
noreply(NewState) -> {noreply, NewState}.
%% Builds a gen_server `stop' callback result with the given reason.
stop(StopReason, FinalState) -> {stop, StopReason, FinalState}.
%% Builds a gen_server `stop' result whose reason is {shutdown, Reason},
%% which terminate/2 unwraps before notifying the protocol layer.
shutdown(Reason, State) -> {stop, {shutdown, Reason}, State}.
%% Builds a handle_call `stop' result that first replies with Reply and
%% then stops with {shutdown, Reason}.
shutdown_and_reply(Reason, Reply, State) ->
    ExitReason = {shutdown, Reason},
    {stop, ExitReason, Reply, State}.
%% Adds Inc to the emqx_pd counter Key. Always returns ok, discarding
%% whatever emqx_pd:inc_counter/2 returns.
inc_counter(Key, Inc) ->
    _Ignored = emqx_pd:inc_counter(Key, Inc),
    ok.