diff --git a/apps/emqx_stomp/src/emqx_stomp_connection.erl b/apps/emqx_stomp/src/emqx_stomp_connection.erl index d4e7f6475..d926e8eef 100644 --- a/apps/emqx_stomp/src/emqx_stomp_connection.erl +++ b/apps/emqx_stomp/src/emqx_stomp_connection.erl @@ -20,9 +20,16 @@ -include("emqx_stomp.hrl"). -include_lib("emqx/include/logger.hrl"). +-include_lib("emqx/include/types.hrl"). +-include_lib("snabbkaffe/include/snabbkaffe.hrl"). -logger_header("[Stomp-Conn]"). +-import(emqx_misc, + [ maybe_apply/2 + , start_timer/2 + ]). + -export([ start_link/3 , info/1 ]). @@ -39,52 +46,160 @@ %% for protocol -export([send/4, heartbeat/2]). --record(state, {transport, socket, peername, conn_name, conn_state, - await_recv, rate_limit, parser, pstate, - proto_env, heartbeat}). +%% for mgmt +-export([call/2, call/3]). --define(INFO_KEYS, [peername, await_recv, conn_state]). --define(SOCK_STATS, [recv_oct, recv_cnt, send_oct, send_cnt]). +-record(state, { + %% TCP/TLS Transport + transport :: esockd:transport(), + %% TCP/TLS Socket + socket :: esockd:socket(), + %% Peername of the connection + peername :: emqx_types:peername(), + %% Sockname of the connection + sockname :: emqx_types:peername(), + %% Sock State + sockstate :: emqx_types:sockstate(), + %% The {active, N} option + active_n :: pos_integer(), + %% Limiter + limiter :: maybe(emqx_limiter:limiter()), + %% Limit Timer + limit_timer :: maybe(reference()), + %% GC State + gc_state :: maybe(emqx_gc:gc_state()), + %% Stats Timer + stats_timer :: disabled | maybe(reference()), + + await_recv, parser, pstate, + proto_env, heartbeat}). + +-type(state() :: #state{}). + +-define(DEFAULT_GC_POLICY, #{bytes => 16777216, count => 16000}). +-define(DEFAULT_OOM_POLICY, #{ max_heap_size => 8388608, + message_queue_len => 10000}). + +-define(ACTIVE_N, 100). +-define(IDLE_TIMEOUT, 30000). +-define(INFO_KEYS, [socktype, peername, sockname, sockstate, active_n]). +-define(CONN_STATS, [recv_pkt, recv_msg, send_pkt, send_msg]). +-define(SOCK_STATS, [recv_oct, recv_cnt, send_oct, send_cnt, send_pend]). + +-define(ENABLED(X), (X =/= undefined)). start_link(Transport, Sock, ProtoEnv) -> {ok, proc_lib:spawn_link(?MODULE, init, [[Transport, Sock, ProtoEnv]])}. -info(CPid) -> - gen_server:call(CPid, info, infinity). +-spec info(pid()|state()) -> emqx_types:infos(). +info(CPid) when is_pid(CPid) -> + call(CPid, info); +info(State = #state{pstate = PState}) -> + ChanInfo = emqx_stomp_protocol:info(PState), + SockInfo = maps:from_list( + info(?INFO_KEYS, State)), + ChanInfo#{sockinfo => SockInfo}. -init([Transport, Sock, ProtoEnv]) -> - process_flag(trap_exit, true), - case Transport:wait(Sock) of - {ok, NewSock} -> - {ok, Peername} = Transport:ensure_ok_or_exit(peername, [NewSock]), - ConnName = esockd:format(Peername), - SendFun = {fun ?MODULE:send/4, [Transport, Sock, self()]}, - HrtBtFun = {fun ?MODULE:heartbeat/2, [Transport, Sock]}, - Parser = emqx_stomp_frame:init_parer_state(ProtoEnv), - PState = emqx_stomp_protocol:init(#{peername => Peername, - sendfun => SendFun, - heartfun => HrtBtFun}, ProtoEnv), - RateLimit = init_rate_limit(proplists:get_value(rate_limit, ProtoEnv)), - State = run_socket(#state{transport = Transport, - socket = NewSock, - peername = Peername, - conn_name = ConnName, - conn_state = running, - await_recv = false, - rate_limit = RateLimit, - parser = Parser, - proto_env = ProtoEnv, - pstate = PState}), - emqx_logger:set_metadata_peername(esockd:format(Peername)), - gen_server:enter_loop(?MODULE, [{hibernate_after, 5000}], State, 20000); +info(Keys, State) when is_list(Keys) -> + [{Key, info(Key, State)} || Key <- Keys]; +info(socktype, #state{transport = Transport, socket = Socket}) -> + Transport:type(Socket); +info(peername, #state{peername = Peername}) -> + Peername; +info(sockname, #state{sockname = Sockname}) -> + Sockname; +info(sockstate, #state{sockstate = SockSt}) -> + SockSt; +info(active_n, #state{active_n = ActiveN}) -> + ActiveN; +info(stats_timer, #state{stats_timer = StatsTimer}) -> + StatsTimer; +info(limit_timer, #state{limit_timer = LimitTimer}) -> + LimitTimer; +info(limiter, #state{limiter = Limiter}) -> + maybe_apply(fun emqx_limiter:info/1, Limiter). + +-spec stats(pid()|state()) -> emqx_types:stats(). +stats(CPid) when is_pid(CPid) -> + call(CPid, stats); +stats(#state{transport = Transport, + socket = Socket, + pstate = PState}) -> + SockStats = case Transport:getstat(Socket, ?SOCK_STATS) of + {ok, Ss} -> Ss; + {error, _} -> [] + end, + ConnStats = emqx_pd:get_counters(?CONN_STATS), + ChanStats = emqx_stomp_protocol:stats(PState), + ProcStats = emqx_misc:proc_stats(), + lists:append([SockStats, ConnStats, ChanStats, ProcStats]). + +call(Pid, Req) -> + call(Pid, Req, infinity). +call(Pid, Req, Timeout) -> + gen_server:call(Pid, Req, Timeout). + +init([Transport, RawSocket, ProtoEnv]) -> + case Transport:wait(RawSocket) of + {ok, Socket} -> + init_state(Transport, Socket, ProtoEnv); {error, Reason} -> - {stop, Reason} + ok = Transport:fast_close(RawSocket), + exit_on_sock_error(Reason) end. -init_rate_limit(undefined) -> - undefined; -init_rate_limit({Rate, Burst}) -> - esockd_rate_limit:new(Rate, Burst). +init_state(Transport, Socket, ProtoEnv) -> + {ok, Peername} = Transport:ensure_ok_or_exit(peername, [Socket]), + {ok, Sockname} = Transport:ensure_ok_or_exit(sockname, [Socket]), + + SendFun = {fun ?MODULE:send/4, [Transport, Socket, self()]}, + HrtBtFun = {fun ?MODULE:heartbeat/2, [Transport, Socket]}, + Parser = emqx_stomp_frame:init_parer_state(ProtoEnv), + + ActiveN = proplists:get_value(active_n, ProtoEnv, ?ACTIVE_N), + GcState = emqx_gc:init(?DEFAULT_GC_POLICY), + + Peercert = Transport:ensure_ok_or_exit(peercert, [Socket]), + ConnInfo = #{socktype => Transport:type(Socket), + peername => Peername, + sockname => Sockname, + peercert => Peercert, + sendfun => SendFun, + heartfun => HrtBtFun, + conn_mod => ?MODULE + }, + PState = emqx_stomp_protocol:init(ConnInfo, ProtoEnv), + State = #state{transport = Transport, + socket = Socket, + peername = Peername, + sockname = Sockname, + sockstate = idle, + active_n = ActiveN, + limiter = undefined, + parser = Parser, + proto_env = ProtoEnv, + gc_state = GcState, + pstate = PState}, + case activate_socket(State) of + {ok, NState} -> + emqx_logger:set_metadata_peername( + esockd:format(Peername)), + gen_server:enter_loop( + ?MODULE, [{hibernate_after, 5000}], NState, 20000); + {error, Reason} -> + ok = Transport:fast_close(Socket), + exit_on_sock_error(Reason) + end. + +-spec exit_on_sock_error(any()) -> no_return(). +exit_on_sock_error(Reason) when Reason =:= einval; + Reason =:= enotconn; + Reason =:= closed -> + erlang:exit(normal); +exit_on_sock_error(timeout) -> + erlang:exit({shutdown, ssl_upgrade_timeout}); +exit_on_sock_error(Reason) -> + erlang:exit({shutdown, Reason}). send(Data, Transport, Sock, ConnPid) -> try Transport:async_send(Sock, Data) of @@ -101,10 +216,10 @@ handle_call(info, _From, State = #state{transport = Transport, socket = Sock, peername = Peername, await_recv = AwaitRecv, - conn_state = ConnState, + sockstate = ConnState, pstate = PState}) -> ClientInfo = [{peername, Peername}, {await_recv, AwaitRecv}, - {conn_state, ConnState}], + {sockstate, ConnState}], ProtoInfo = emqx_stomp_protocol:info(PState), case Transport:getstat(Sock, ?SOCK_STATS) of {ok, SockStats} -> @@ -113,6 +228,12 @@ handle_call(info, _From, State = #state{transport = Transport, {stop, Reason, lists:append([ClientInfo, ProtoInfo]), State} end; +handle_call(discard, _From, State) -> + shutdown(discared, State); + +handle_call(kick, _From, State) -> + shutdown(kicked, State); + handle_call(Req, _From, State) -> ?LOG(error, "unexpected request: ~p", [Req]), {reply, ignored, State}. @@ -121,6 +242,11 @@ handle_cast(Msg, State) -> ?LOG(error, "unexpected msg: ~p", [Msg]), noreply(State). +handle_info({event, connected}, State = #state{pstate = PState}) -> + ClientId = emqx_stomp_protocol:info(clientid, PState), + emqx_cm:insert_channel_info(ClientId, info(State), stats(State)), + noreply(State); + handle_info(timeout, State) -> shutdown(idle_timeout, State); @@ -141,26 +267,76 @@ handle_info({timeout, TRef, TMsg}, State) when TMsg =:= incoming; shutdown({sock_error, Reason}, State) end; +handle_info({timeout, _TRef, limit_timeout}, State) -> + NState = State#state{sockstate = idle, + limit_timer = undefined + }, + handle_info(activate_socket, NState); + +handle_info({timeout, _TRef, emit_stats}, + State = #state{pstate = PState}) -> + ClientId = emqx_stomp_protocol:info(clientid, PState), + emqx_cm:set_chan_stats(ClientId, stats(State)), + {ok, State#state{stats_timer = undefined}}; + handle_info({timeout, TRef, TMsg}, State) -> with_proto(timeout, [TRef, TMsg], State); handle_info({'EXIT', HbProc, Error}, State = #state{heartbeat = HbProc}) -> stop(Error, State); -handle_info(activate_sock, State) -> - noreply(run_socket(State#state{conn_state = running})); - -handle_info({inet_async, _Sock, _Ref, {ok, Bytes}}, State) -> - ?LOG(debug, "RECV ~p", [Bytes]), - received(Bytes, rate_limit(size(Bytes), State#state{await_recv = false})); - -handle_info({inet_async, _Sock, _Ref, {error, Reason}}, State) -> - shutdown(Reason, State); +handle_info(activate_socket, State = #state{sockstate = OldSst}) -> + case activate_socket(State) of + {ok, NState = #state{sockstate = NewSst}} -> + case OldSst =/= NewSst of + true -> {ok, {event, NewSst}, NState}; + false -> {ok, NState} + end; + {error, Reason} -> + handle_info({sock_error, Reason}, State) + end; handle_info({inet_reply, _Ref, ok}, State) -> noreply(State); +handle_info({Inet, _Sock, Data}, State) when Inet == tcp; Inet == ssl -> + ?LOG(debug, "RECV ~0p", [Data]), + Oct = iolist_size(Data), + inc_counter(incoming_bytes, Oct), + ok = emqx_metrics:inc('bytes.received', Oct), + ensure_stats_timer(?IDLE_TIMEOUT, received(Data, State)); + +handle_info({Passive, _Sock}, State) + when Passive == tcp_passive; Passive == ssl_passive -> + %% In Stats + Pubs = emqx_pd:reset_counter(incoming_pubs), + Bytes = emqx_pd:reset_counter(incoming_bytes), + InStats = #{cnt => Pubs, oct => Bytes}, + %% Ensure Rate Limit + NState = ensure_rate_limit(InStats, State), + %% Run GC and Check OOM + NState1 = check_oom(run_gc(InStats, NState)), + handle_info(activate_socket, NState1); + +handle_info({Error, _Sock, Reason}, State) + when Error == tcp_error; Error == ssl_error -> + handle_info({sock_error, Reason}, State); + +handle_info({Closed, _Sock}, State) + when Closed == tcp_closed; Closed == ssl_closed -> + handle_info({sock_closed, Closed}, close_socket(State)); + handle_info({inet_reply, _Sock, {error, Reason}}, State) -> + handle_info({sock_error, Reason}, State); + +handle_info({sock_error, Reason}, State) -> + case Reason =/= closed andalso Reason =/= einval of + true -> ?LOG(warning, "socket_error: ~p", [Reason]); + false -> ok + end, + handle_info({sock_closed, Reason}, close_socket(State)); + +handle_info({sock_closed, Reason}, State) -> shutdown(Reason, State); handle_info({deliver, _Topic, Msg}, State = #state{pstate = PState}) -> @@ -208,7 +384,7 @@ with_proto(Fun, Args, State = #state{pstate = PState}) -> received(<<>>, State) -> noreply(State); -received(Bytes, State = #state{parser = Parser, +received(Bytes, State = #state{parser = Parser, pstate = PState}) -> try emqx_stomp_frame:parse(Bytes, Parser) of {more, NewParser} -> @@ -237,25 +413,68 @@ received(Bytes, State = #state{parser = Parser, reset_parser(State = #state{proto_env = ProtoEnv}) -> State#state{parser = emqx_stomp_frame:init_parer_state(ProtoEnv)}. -rate_limit(_Size, State = #state{rate_limit = undefined}) -> - run_socket(State); -rate_limit(Size, State = #state{rate_limit = Rl}) -> - case esockd_rate_limit:check(Size, Rl) of - {0, Rl1} -> - run_socket(State#state{conn_state = running, rate_limit = Rl1}); - {Pause, Rl1} -> - ?LOG(error, "Rate limiter pause for ~p", [Pause]), - erlang:send_after(Pause, self(), activate_sock), - State#state{conn_state = blocked, rate_limit = Rl1} +activate_socket(State = #state{sockstate = closed}) -> + {ok, State}; +activate_socket(State = #state{sockstate = blocked}) -> + {ok, State}; +activate_socket(State = #state{transport = Transport, + socket = Socket, + active_n = N}) -> + case Transport:setopts(Socket, [{active, N}]) of + ok -> {ok, State#state{sockstate = running}}; + Error -> Error end. -run_socket(State = #state{conn_state = blocked}) -> - State; -run_socket(State = #state{await_recv = true}) -> - State; -run_socket(State = #state{transport = Transport, socket = Sock}) -> - Transport:async_recv(Sock, 0, infinity), - State#state{await_recv = true}. +close_socket(State = #state{sockstate = closed}) -> State; +close_socket(State = #state{transport = Transport, socket = Socket}) -> + ok = Transport:fast_close(Socket), + State#state{sockstate = closed}. + +%%-------------------------------------------------------------------- +%% Ensure rate limit + +ensure_rate_limit(Stats, State = #state{limiter = Limiter}) -> + case ?ENABLED(Limiter) andalso emqx_limiter:check(Stats, Limiter) of + false -> State; + {ok, Limiter1} -> + State#state{limiter = Limiter1}; + {pause, Time, Limiter1} -> + ?LOG(warning, "Pause ~pms due to rate limit", [Time]), + TRef = start_timer(Time, limit_timeout), + State#state{sockstate = blocked, + limiter = Limiter1, + limit_timer = TRef + } + end. + +%%-------------------------------------------------------------------- +%% Run GC and Check OOM + +run_gc(Stats, State = #state{gc_state = GcSt}) -> + case ?ENABLED(GcSt) andalso emqx_gc:run(Stats, GcSt) of + false -> State; + {_IsGC, GcSt1} -> + State#state{gc_state = GcSt1} + end. + +check_oom(State) -> + OomPolicy = ?DEFAULT_OOM_POLICY, + ?tp(debug, check_oom, #{policy => OomPolicy}), + case ?ENABLED(OomPolicy) andalso emqx_misc:check_oom(OomPolicy) of + {shutdown, Reason} -> + %% triggers terminate/2 callback immediately + erlang:exit({shutdown, Reason}); + _Other -> + ok + end, + State. + +%%-------------------------------------------------------------------- +%% Ensure/cancel stats timer + +ensure_stats_timer(Timeout, State = #state{stats_timer = undefined}) -> + State#state{stats_timer = start_timer(Timeout, emit_stats)}; +ensure_stats_timer(_Timeout, State) -> State. getstat(Stat, #state{transport = Transport, socket = Sock}) -> case Transport:getstat(Sock, [Stat]) of @@ -272,3 +491,6 @@ stop(Reason, State) -> shutdown(Reason, State) -> stop({shutdown, Reason}, State). +inc_counter(Key, Inc) -> + _ = emqx_pd:inc_counter(Key, Inc), + ok. diff --git a/apps/emqx_stomp/src/emqx_stomp_heartbeat.erl b/apps/emqx_stomp/src/emqx_stomp_heartbeat.erl index 145359e53..2a221ad68 100644 --- a/apps/emqx_stomp/src/emqx_stomp_heartbeat.erl +++ b/apps/emqx_stomp/src/emqx_stomp_heartbeat.erl @@ -33,7 +33,6 @@ outgoing => #heartbeater{} }. - %%-------------------------------------------------------------------- %% APIs %%-------------------------------------------------------------------- diff --git a/apps/emqx_stomp/src/emqx_stomp_protocol.erl b/apps/emqx_stomp/src/emqx_stomp_protocol.erl index 0bd80d628..dc02ac232 100644 --- a/apps/emqx_stomp/src/emqx_stomp_protocol.erl +++ b/apps/emqx_stomp/src/emqx_stomp_protocol.erl @@ -30,6 +30,8 @@ %% API -export([ init/2 , info/1 + , info/2 + , stats/1 ]). -export([ received/2 @@ -45,19 +47,28 @@ ]). -record(pstate, { - peername, - heartfun, - sendfun, + %% Stomp ConnInfo + conninfo :: emqx_types:conninfo(), + %% Stomp ClientInfo + clientinfo :: emqx_types:clientinfo(), + %% Stomp Heartbeats + heart_beats :: emqx_stomp_hearbeat:heartbeat(), + %% Stomp Connection State connected = false, - proto_ver, - proto_name, - heart_beats, - login, - allow_anonymous, - default_user, - subscriptions = [], + %% Timers timers :: #{atom() => disable | undefined | reference()}, - transaction :: #{binary() => list()} + %% Transaction + transaction :: #{binary() => list()}, + %% Subscriptions + subscriptions = [], + %% Send function + sendfun :: function(), + %% Heartbeat function + heartfun :: function(), + %% The confs for the connection + %% TODO: put these configs into a public mem? + allow_anonymous, + default_user }). -define(TIMER_TABLE, #{ @@ -68,34 +79,132 @@ -define(TRANS_TIMEOUT, 60000). +-define(INFO_KEYS, [conninfo, conn_state, clientinfo, session, will_msg]). + +-define(STATS_KEYS, [subscriptions_cnt, + subscriptions_max, + inflight_cnt, + inflight_max, + mqueue_len, + mqueue_max, + mqueue_dropped, + next_pkt_id, + awaiting_rel_cnt, + awaiting_rel_max + ]). + -type(pstate() :: #pstate{}). %% @doc Init protocol -init(#{peername := Peername, - sendfun := SendFun, - heartfun := HeartFun}, Env) -> - AllowAnonymous = get_value(allow_anonymous, Env, false), - DefaultUser = get_value(default_user, Env), - #pstate{peername = Peername, - heartfun = HeartFun, - sendfun = SendFun, - timers = #{}, - transaction = #{}, - allow_anonymous = AllowAnonymous, - default_user = DefaultUser}. +init(ConnInfo = #{peername := {PeerHost, _Port}, + sockname := {_Host, SockPort}, + sendfun := SendFun, + heartfun := HeartFun}, Opts) -> -info(#pstate{connected = Connected, - proto_ver = ProtoVer, - proto_name = ProtoName, - heart_beats = Heartbeats, - login = Login, - subscriptions = Subscriptions}) -> - [{connected, Connected}, - {proto_ver, ProtoVer}, - {proto_name, ProtoName}, - {heart_beats, Heartbeats}, - {login, Login}, - {subscriptions, Subscriptions}]. + NConnInfo = default_conninfo(ConnInfo), + + ClientInfo = #{zone => undefined, + protocol => stomp, + peerhost => PeerHost, + sockport => SockPort, + clientid => undefined, + username => undefined, + mountpoint => undefined, + is_bridge => false, + is_superuser => false + }, + + AllowAnonymous = get_value(allow_anonymous, Opts, false), + DefaultUser = get_value(default_user, Opts), + + #pstate{ + conninfo = NConnInfo, + clientinfo = ClientInfo, + heartfun = HeartFun, + sendfun = SendFun, + timers = #{}, + transaction = #{}, + allow_anonymous = AllowAnonymous, + default_user = DefaultUser + }. + +default_conninfo(ConnInfo) -> + NConnInfo = maps:without([sendfun, heartfun], ConnInfo), + NConnInfo#{ + proto_name => <<"STOMP">>, + proto_ver => <<"1.2">>, + clean_start => true, + clientid => undefined, + username => undefined, + conn_props => [], + connected => false, + connected_at => undefined, + keepalive => undefined, + receive_maximum => 0, + expiry_interval => 0 + }. + +-spec info(pstate()) -> emqx_types:infos(). +info(State) -> + maps:from_list(info(?INFO_KEYS, State)). + +-spec info(list(atom())|atom(), pstate()) -> term(). +info(Keys, State) when is_list(Keys) -> + [{Key, info(Key, State)} || Key <- Keys]; +info(conninfo, #pstate{conninfo = ConnInfo}) -> + ConnInfo; +info(socktype, #pstate{conninfo = ConnInfo}) -> + maps:get(socktype, ConnInfo, undefined); +info(peername, #pstate{conninfo = ConnInfo}) -> + maps:get(peername, ConnInfo, undefined); +info(sockname, #pstate{conninfo = ConnInfo}) -> + maps:get(sockname, ConnInfo, undefined); +info(proto_name, #pstate{conninfo = ConnInfo}) -> + maps:get(proto_name, ConnInfo, undefined); +info(proto_ver, #pstate{conninfo = ConnInfo}) -> + maps:get(proto_ver, ConnInfo, undefined); +info(connected_at, #pstate{conninfo = ConnInfo}) -> + maps:get(connected_at, ConnInfo, undefined); +info(clientinfo, #pstate{clientinfo = ClientInfo}) -> + ClientInfo; +info(zone, _) -> + undefined; +info(clientid, #pstate{clientinfo = ClientInfo}) -> + maps:get(clientid, ClientInfo, undefined); +info(username, #pstate{clientinfo = ClientInfo}) -> + maps:get(username, ClientInfo, undefined); +info(session, State) -> + session_info(State); +info(conn_state, #pstate{connected = true}) -> + connected; +info(conn_state, _) -> + disconnected; +info(will_msg, _) -> + undefined. + +session_info(#pstate{conninfo = ConnInfo, subscriptions = Subs}) -> + NSubs = lists:foldl(fun({_Id, Topic, _Ack}, Acc) -> + Acc#{Topic => ?DEFAULT_SUBOPTS} + end, #{}, Subs), + #{subscriptions => NSubs, + upgrade_qos => false, + retry_interval => 0, + await_rel_timeout => 0, + created_at => maps:get(connected_at, ConnInfo, 0) + }. + +-spec stats(pstate()) -> emqx_types:stats(). +stats(#pstate{subscriptions = Subs}) -> + [{subscriptions_cnt, length(Subs)}, + {subscriptions_max, 0}, + {inflight_cnt, 0}, + {inflight_max, 0}, + {mqueue_len, 0}, + {mqueue_max, 0}, + {mqueue_dropped, 0}, + {next_pkt_id, 0}, + {awaiting_rel_cnt, 0}, + {awaiting_rel_max, 0}]. -spec(received(stomp_frame(), pstate()) -> {ok, pstate()} @@ -105,20 +214,50 @@ received(Frame = #stomp_frame{command = <<"STOMP">>}, State) -> received(Frame#stomp_frame{command = <<"CONNECT">>}, State); received(#stomp_frame{command = <<"CONNECT">>, headers = Headers}, - State = #pstate{connected = false, allow_anonymous = AllowAnonymous, default_user = DefaultUser}) -> + State = #pstate{connected = false}) -> case negotiate_version(header(<<"accept-version">>, Headers)) of {ok, Version} -> Login = header(<<"login">>, Headers), Passc = header(<<"passcode">>, Headers), - case check_login(Login, Passc, AllowAnonymous, DefaultUser) of + case check_login(Login, Passc, + allow_anonymous(State), + default_user(State) + ) of true -> - emqx_logger:set_metadata_clientid(Login), + NLogin = case Login == undefined orelse Login == <<>> of + false -> Login; + true -> emqx_guid:to_base62(emqx_guid:gen()) + end, + emqx_logger:set_metadata_clientid(NLogin), + ConnInfo = State#pstate.conninfo, + ClitInfo = State#pstate.clientinfo, + NConnInfo = ConnInfo#{ + proto_ver => Version, + clientid => NLogin, + username => NLogin + }, + NClitInfo = ClitInfo#{ + clientid => NLogin, + username => NLogin + }, + ConnPid = self(), + _ = emqx_cm_locker:trans(NLogin, fun(_) -> + emqx_cm:discard_session(NLogin), + emqx_cm:register_channel(NLogin, ConnPid, NConnInfo) + end), Heartbeats = parse_heartbeats(header(<<"heart-beat">>, Headers, <<"0,0">>)), - NState = start_heartbeart_timer(Heartbeats, State#pstate{connected = true, - proto_ver = Version, login = Login}), - send(connected_frame([{<<"version">>, Version}, - {<<"heart-beat">>, reverse_heartbeats(Heartbeats)}]), NState); + NState = start_heartbeart_timer( + Heartbeats, + State#pstate{ + conninfo = NConnInfo, + clientinfo = NClitInfo} + ), + ConnectedFrame = connected_frame( + [{<<"version">>, Version}, + {<<"heart-beat">>, reverse_heartbeats(Heartbeats)} + ]), + send(ConnectedFrame, ensure_connected(NState)); false -> _ = send(error_frame(undefined, <<"Login or passcode error!">>), State), {error, login_or_passcode_error, State} @@ -130,6 +269,7 @@ received(#stomp_frame{command = <<"CONNECT">>, headers = Headers}, end; received(#stomp_frame{command = <<"CONNECT">>}, State = #pstate{connected = true}) -> + ?LOG(error, "Received CONNECT frame on connected=true state"), {error, unexpected_connect, State}; received(Frame = #stomp_frame{command = <<"SEND">>, headers = Headers}, State) -> @@ -139,30 +279,56 @@ received(Frame = #stomp_frame{command = <<"SEND">>, headers = Headers}, State) - end; received(#stomp_frame{command = <<"SUBSCRIBE">>, headers = Headers}, - State = #pstate{subscriptions = Subscriptions, login = Login}) -> + State = #pstate{subscriptions = Subs}) -> Id = header(<<"id">>, Headers), Topic = header(<<"destination">>, Headers), Ack = header(<<"ack">>, Headers, <<"auto">>), - {ok, State1} = case lists:keyfind(Id, 1, Subscriptions) of - {Id, Topic, Ack} -> - {ok, State}; - false -> - emqx_broker:subscribe(Topic, Login), - {ok, State#pstate{subscriptions = [{Id, Topic, Ack}|Subscriptions]}} - end, - maybe_send_receipt(receipt_id(Headers), State1); + + case lists:keyfind(Id, 1, Subs) of + {Id, Topic, Ack} -> + ?LOG(info, "Subscription has established: ~s", [Topic]), + maybe_send_receipt(receipt_id(Headers), State); + false -> + case check_acl(subscribe, Topic, State) of + allow -> + ClientInfo = State#pstate.clientinfo, + ClientId = maps:get(clientid, ClientInfo), + %% XXX: We don't parse the request topic name or filter + %% which meaning stomp does not support shared + %% subscription + _ = run_hooks('client.subscribe', + [ClientInfo, _SubProps = #{}], + [{Topic, _TopicOpts = #{}}]), + + emqx_broker:subscribe(Topic, ClientId), + + SubOpts = ?DEFAULT_SUBOPTS#{is_new => true}, + _ = run_hooks('session.subscribed', + [ClientInfo, Topic, SubOpts]), + + NState = put_subs({Id, Topic, Ack}, State), + maybe_send_receipt(receipt_id(Headers), NState) + end + end; received(#stomp_frame{command = <<"UNSUBSCRIBE">>, headers = Headers}, - State = #pstate{subscriptions = Subscriptions}) -> + State = #pstate{subscriptions = Subs, clientinfo = ClientInfo}) -> Id = header(<<"id">>, Headers), + {ok, State1} = case lists:keyfind(Id, 1, Subs) of + {Id, Topic, _Ack} -> + _ = run_hooks('client.unsubscribe', + [ClientInfo, #{}], + [{Topic, #{}}]), - {ok, State1} = case lists:keyfind(Id, 1, Subscriptions) of - {Id, Topic, _Ack} -> - ok = emqx_broker:unsubscribe(Topic), - {ok, State#pstate{subscriptions = lists:keydelete(Id, 1, Subscriptions)}}; - false -> - {ok, State} - end, + ok = emqx_broker:unsubscribe(Topic), + + _ = run_hooks('session.unsubscribe', + [ClientInfo, Topic, ?DEFAULT_SUBOPTS]), + + {ok, remove_subs(Id, State)}; + false -> + {ok, State} + end, maybe_send_receipt(receipt_id(Headers), State1); %% ACK @@ -240,8 +406,8 @@ received(#stomp_frame{command = <<"DISCONNECT">>, headers = Headers}, State) -> {stop, normal, State}. send(Msg = #message{topic = Topic, headers = Headers, payload = Payload}, - State = #pstate{subscriptions = Subscriptions}) -> - case lists:keyfind(Topic, 2, Subscriptions) of + State = #pstate{subscriptions = Subs}) -> + case lists:keyfind(Topic, 2, Subs) of {Id, Topic, Ack} -> Headers0 = [{<<"subscription">>, Id}, {<<"message-id">>, next_msgid()}, @@ -269,6 +435,9 @@ send(Frame, State = #pstate{sendfun = {Fun, Args}}) -> erlang:apply(Fun, [Data] ++ Args), {ok, State}. +shutdown(Reason, State = #pstate{connected = true}) -> + _ = ensure_disconnected(Reason, State), + ok; shutdown(_Reason, _State) -> ok. @@ -398,11 +567,18 @@ receipt_id(Headers) -> handle_recv_send_frame(#stomp_frame{command = <<"SEND">>, headers = Headers, body = Body}, State) -> Topic = header(<<"destination">>, Headers), - _ = maybe_send_receipt(receipt_id(Headers), State), - _ = emqx_broker:publish( - make_mqtt_message(Topic, Headers, iolist_to_binary(Body)) - ), - State. + case check_acl(publish, Topic, State) of + allow -> + _ = maybe_send_receipt(receipt_id(Headers), State), + _ = emqx_broker:publish( + make_mqtt_message(Topic, Headers, iolist_to_binary(Body)) + ), + State; + deny -> + ErrFrame = error_frame(receipt_id(Headers), <<"Not Authorized">>), + {ok, NState} = send(ErrFrame, State), + NState + end. handle_recv_ack_frame(#stomp_frame{command = <<"ACK">>, headers = Headers}, State) -> Id = header(<<"id">>, Headers), @@ -435,6 +611,58 @@ start_heartbeart_timer(Heartbeats, State) -> [incoming_timer, outgoing_timer], State#pstate{heart_beats = emqx_stomp_heartbeat:init(Heartbeats)}). +%%-------------------------------------------------------------------- +%% ... + +check_acl(PubSub, Topic, State = #pstate{clientinfo = ClientInfo}) -> + case is_acl_enabled(State) andalso + emqx_access_control:check_acl(ClientInfo, PubSub, Topic) of + false -> allow; + Res -> Res + end. + +put_subs({Id, Topic, Ack}, State = #pstate{subscriptions = Subs}) -> + State#pstate{subscriptions = lists:keystore(Id, 1, Subs, {Id, Topic, Ack})}. + +remove_subs(Id, State = #pstate{subscriptions = Subs}) -> + State#pstate{subscriptions = lists:keydelete(Id, 1, Subs)}. + +%%-------------------------------------------------------------------- +%% ... + +is_acl_enabled(_) -> + %% TODO: configs from somewhere + true. + +default_user(#pstate{default_user = DefaultUser}) -> + DefaultUser. +allow_anonymous(#pstate{allow_anonymous = AllowAnonymous}) -> + AllowAnonymous. + +ensure_connected(State = #pstate{conninfo = ConnInfo, + clientinfo = ClientInfo}) -> + NConnInfo = ConnInfo#{ + connected => true, + connected_at => erlang:system_time(millisecond) + }, + %% send connected event + self() ! {event, connected}, + ok = run_hooks('client.connected', [ClientInfo, NConnInfo]), + State#pstate{conninfo = NConnInfo, + connected = true + }. + +ensure_disconnected(Reason, State = #pstate{conninfo = ConnInfo, clientinfo = ClientInfo}) -> + NConnInfo = ConnInfo#{disconnected_at => erlang:system_time(millisecond)}, + ok = run_hooks('client.disconnected', [ClientInfo, Reason, NConnInfo]), + State#pstate{conninfo = NConnInfo, connected = false}. + +run_hooks(Name, Args) -> + emqx_hooks:run(Name, Args). + +run_hooks(Name, Args, Acc) -> + emqx_hooks:run_fold(Name, Args, Acc). + %%-------------------------------------------------------------------- %% Timer