diff --git a/src/emqx_broker.erl b/src/emqx_broker.erl index 6e74bcfa4..674c50194 100644 --- a/src/emqx_broker.erl +++ b/src/emqx_broker.erl @@ -376,9 +376,9 @@ set_subopts(Topic, NewOpts) when is_binary(Topic), is_map(NewOpts) -> topics() -> emqx_router:topics(). -%%------------------------------------------------------------------------------ +%%-------------------------------------------------------------------- %% Stats fun -%%------------------------------------------------------------------------------ +%%-------------------------------------------------------------------- stats_fun() -> safe_update_stats(?SUBSCRIBER, 'subscribers.count', 'subscribers.max'), diff --git a/src/emqx_channel.erl b/src/emqx_channel.erl index 9799719b6..a3804298b 100644 --- a/src/emqx_channel.erl +++ b/src/emqx_channel.erl @@ -14,568 +14,598 @@ %% limitations under the License. %%-------------------------------------------------------------------- -%% MQTT Channel +%% MQTT TCP/SSL Channel -module(emqx_channel). +-behaviour(gen_statem). + -include("emqx.hrl"). -include("emqx_mqtt.hrl"). -include("logger.hrl"). -include("types.hrl"). --logger_header("[Channel]"). +-export([start_link/3]). --export([ attrs/1 ]). - --export([ zone/1 - , client_id/1 - , conn_mod/1 - , endpoint/1 - , proto_ver/1 - , keepalive/1 - , session/1 +%% APIs +-export([ info/1 + , stats/1 ]). --export([ init/2 - , handle_in/2 - , handle_out/2 - , handle_timeout/3 - , terminate/2 +%% state callbacks +-export([ idle/3 + , connected/3 + , disconnected/3 ]). --export_type([channel/0]). +%% gen_statem callbacks +-export([ init/1 + , callback_mode/0 + , code_change/4 + , terminate/3 + ]). --record(channel, { - conn_mod :: maybe(module()), - endpoint :: emqx_endpoint:endpoint(), - proto_name :: binary(), - proto_ver :: emqx_mqtt:version(), - keepalive :: non_neg_integer(), - session :: emqx_session:session(), - will_msg :: emqx_types:message(), - enable_acl :: boolean(), - is_bridge :: boolean(), - connected :: boolean(), - topic_aliases :: map(), - alias_maximum :: map(), - connected_at :: erlang:timestamp() +-record(state, { + transport :: esockd:transport(), + socket :: esockd:socket(), + peername :: emqx_types:peername(), + sockname :: emqx_types:peername(), + conn_state :: running | blocked, + active_n :: pos_integer(), + rate_limit :: maybe(esockd_rate_limit:bucket()), + pub_limit :: maybe(esockd_rate_limit:bucket()), + limit_timer :: maybe(reference()), + serialize :: fun((emqx_types:packet()) -> iodata()), + parse_state :: emqx_frame:parse_state(), + proto_state :: emqx_protocol:proto_state(), + gc_state :: emqx_gc:gc_state(), + keepalive :: maybe(emqx_keepalive:keepalive()), + stats_timer :: disabled | maybe(reference()), + idle_timeout :: timeout() }). --opaque(channel() :: #channel{}). +-logger_header("[Channel]"). -attrs(#channel{endpoint = Endpoint, session = Session}) -> - maps:merge(emqx_endpoint:to_map(Endpoint), - emqx_session:attrs(Session)). +-define(ACTIVE_N, 100). +-define(HANDLE(T, C, D), handle((T), (C), (D))). +-define(CHAN_STATS, [recv_pkt, recv_msg, send_pkt, send_msg]). +-define(SOCK_STATS, [recv_oct, recv_cnt, send_oct, send_cnt, send_pend]). -zone(#channel{endpoint = Endpoint}) -> - emqx_endpoint:zone(Endpoint). +-spec(start_link(esockd:transport(), esockd:socket(), proplists:proplist()) + -> {ok, pid()}). +start_link(Transport, Socket, Options) -> + {ok, proc_lib:spawn_link(?MODULE, init, [{Transport, Socket, Options}])}. --spec(client_id(channel()) -> emqx_types:client_id()). -client_id(#channel{endpoint = Endpoint}) -> - emqx_endpoint:client_id(Endpoint). +%%-------------------------------------------------------------------- +%% API +%%-------------------------------------------------------------------- --spec(conn_mod(channel()) -> module()). -conn_mod(#channel{conn_mod = ConnMod}) -> - ConnMod. +%% @doc Get channel's info. +-spec(info(pid() | #state{}) -> proplists:proplist()). +info(CPid) when is_pid(CPid) -> + call(CPid, info); --spec(endpoint(channel()) -> emqx_endpoint:endpoint()). -endpoint(#channel{endpoint = Endpoint}) -> - Endpoint. +info(#state{transport = Transport, + socket = Socket, + peername = Peername, + sockname = Sockname, + conn_state = ConnState, + active_n = ActiveN, + rate_limit = RateLimit, + pub_limit = PubLimit, + proto_state = ProtoState, + gc_state = GCState, + stats_timer = StatsTimer, + idle_timeout = IdleTimeout}) -> + [{socktype, Transport:type(Socket)}, + {peername, Peername}, + {sockname, Sockname}, + {conn_state, ConnState}, + {active_n, ActiveN}, + {rate_limit, rate_limit_info(RateLimit)}, + {pub_limit, rate_limit_info(PubLimit)}, + {gc_state, emqx_gc:info(GCState)}, + {enable_stats, case StatsTimer of + disabled -> false; + _Otherwise -> true + end}, + {idle_timeout, IdleTimeout} | + emqx_protocol:info(ProtoState)]. --spec(proto_ver(channel()) -> emqx_mqtt:version()). -proto_ver(#channel{proto_ver = ProtoVer}) -> - ProtoVer. +rate_limit_info(undefined) -> + undefined; +rate_limit_info(Limit) -> + esockd_rate_limit:info(Limit). -keepalive(#channel{keepalive = Keepalive}) -> - Keepalive. +%% @doc Get channel's stats. +-spec(stats(pid() | #state{}) -> proplists:proplist()). +stats(CPid) when is_pid(CPid) -> + call(CPid, stats); --spec(session(channel()) -> emqx_session:session()). -session(#channel{session = Session}) -> - Session. +stats(#state{transport = Transport, socket = Socket}) -> + SockStats = case Transport:getstat(Socket, ?SOCK_STATS) of + {ok, Ss} -> Ss; + {error, _} -> [] + end, + ChanStats = [{Name, emqx_pd:get_counter(Name)} || Name <- ?CHAN_STATS], + lists:append([SockStats, ChanStats, emqx_misc:proc_stats()]). --spec(init(map(), proplists:proplist()) -> channel()). -init(ConnInfo = #{peername := Peername, - sockname := Sockname, - conn_mod := ConnMod}, Options) -> +%% @private +call(CPid, Req) -> + gen_statem:call(CPid, Req, infinity). + +%%-------------------------------------------------------------------- +%% gen_statem callbacks +%%-------------------------------------------------------------------- + +init({Transport, RawSocket, Options}) -> + {ok, Socket} = Transport:wait(RawSocket), + {ok, Peername} = Transport:ensure_ok_or_exit(peername, [Socket]), + {ok, Sockname} = Transport:ensure_ok_or_exit(sockname, [Socket]), + Peercert = Transport:ensure_ok_or_exit(peercert, [Socket]), + emqx_logger:set_metadata_peername(esockd_net:format(Peername)), Zone = proplists:get_value(zone, Options), - Peercert = maps:get(peercert, ConnInfo, nossl), - Username = peer_cert_as_username(Peercert, Options), - Mountpoint = emqx_zone:get_env(Zone, mountpoint), - WsCookie = maps:get(ws_cookie, ConnInfo, undefined), - Endpoint = emqx_endpoint:new(#{zone => Zone, - peername => Peername, - sockname => Sockname, - username => Username, - peercert => Peercert, - mountpoint => Mountpoint, - ws_cookie => WsCookie - }), - EnableAcl = emqx_zone:get_env(Zone, enable_acl, false), - #channel{conn_mod = ConnMod, - endpoint = Endpoint, - enable_acl = EnableAcl, - is_bridge = false, - connected = false - }. + RateLimit = init_limiter(proplists:get_value(rate_limit, Options)), + PubLimit = init_limiter(emqx_zone:get_env(Zone, publish_limit)), + ActiveN = proplists:get_value(active_n, Options, ?ACTIVE_N), + MaxSize = emqx_zone:get_env(Zone, max_packet_size, ?MAX_PACKET_SIZE), + ParseState = emqx_frame:initial_parse_state(#{max_size => MaxSize}), + ProtoState = emqx_protocol:init(#{peername => Peername, + sockname => Sockname, + peercert => Peercert, + conn_mod => ?MODULE}, Options), + GcPolicy = emqx_zone:get_env(Zone, force_gc_policy, false), + GcState = emqx_gc:init(GcPolicy), + EnableStats = emqx_zone:get_env(Zone, enable_stats, true), + StatsTimer = if EnableStats -> undefined; ?Otherwise -> disabled end, + IdleTimout = emqx_zone:get_env(Zone, idle_timeout, 30000), + ok = emqx_misc:init_proc_mng_policy(Zone), + State = #state{transport = Transport, + socket = Socket, + peername = Peername, + conn_state = running, + active_n = ActiveN, + rate_limit = RateLimit, + pub_limit = PubLimit, + parse_state = ParseState, + proto_state = ProtoState, + gc_state = GcState, + stats_timer = StatsTimer, + idle_timeout = IdleTimout + }, + gen_statem:enter_loop(?MODULE, [{hibernate_after, 2 * IdleTimout}], + idle, State, self(), [IdleTimout]). -peer_cert_as_username(Peercert, Options) -> - case proplists:get_value(peer_cert_as_username, Options) of - cn -> esockd_peercert:common_name(Peercert); - dn -> esockd_peercert:subject(Peercert); - crt -> Peercert; - _ -> undefined +init_limiter(undefined) -> + undefined; +init_limiter({Rate, Burst}) -> + esockd_rate_limit:new(Rate, Burst). + +callback_mode() -> + [state_functions, state_enter]. + +%%-------------------------------------------------------------------- +%% Idle State + +idle(enter, _, State) -> + case activate_socket(State) of + ok -> keep_state_and_data; + {error, Reason} -> + shutdown(Reason, State) + end; + +idle(timeout, _Timeout, State) -> + stop(idle_timeout, State); + +idle(cast, {incoming, Packet = ?CONNECT_PACKET( + #mqtt_packet_connect{ + proto_ver = ProtoVer} + )}, State) -> + State1 = State#state{serialize = serialize_fun(ProtoVer)}, + handle_incoming(Packet, fun(NewSt) -> + {next_state, connected, NewSt} + end, State1); + +idle(cast, {incoming, Packet}, State) -> + ?LOG(warning, "Unexpected incoming: ~p", [Packet]), + shutdown(unexpected_incoming_packet, State); + +idle(EventType, Content, State) -> + ?HANDLE(EventType, Content, State). + +%%-------------------------------------------------------------------- +%% Connected State + +connected(enter, _PrevSt, State = #state{proto_state = ProtoState}) -> + ClientId = emqx_protocol:client_id(ProtoState), + ok = emqx_cm:set_chan_attrs(ClientId, info(State)), + %% Ensure keepalive after connected successfully. + Interval = emqx_protocol:info(keepalive, ProtoState), + case ensure_keepalive(Interval, State) of + ignore -> keep_state_and_data; + {ok, KeepAlive} -> + keep_state(State#state{keepalive = KeepAlive}); + {error, Reason} -> + shutdown(Reason, State) + end; + +connected(cast, {incoming, Packet = ?PACKET(?CONNECT)}, State) -> + ?LOG(warning, "Unexpected connect: ~p", [Packet]), + shutdown(unexpected_incoming_connect, State); + +connected(cast, {incoming, Packet}, State) when is_record(Packet, mqtt_packet) -> + handle_incoming(Packet, fun keep_state/1, State); + +connected(info, Deliver = {deliver, _Topic, _Msg}, + State = #state{proto_state = ProtoState}) -> + Delivers = emqx_misc:drain_deliver([Deliver]), + case emqx_protocol:handle_deliver(Delivers, ProtoState) of + {ok, NProtoState} -> + keep_state(State#state{proto_state = NProtoState}); + {ok, Packets, NProtoState} -> + NState = State#state{proto_state = NProtoState}, + handle_outgoing(Packets, fun keep_state/1, NState); + {error, Reason} -> + shutdown(Reason, State); + {error, Reason, NProtoState} -> + shutdown(Reason, State#state{proto_state = NProtoState}) + end; + +%% Keepalive timer +connected(info, {keepalive, check}, State = #state{keepalive = KeepAlive}) -> + case emqx_keepalive:check(KeepAlive) of + {ok, KeepAlive1} -> + keep_state(State#state{keepalive = KeepAlive1}); + {error, timeout} -> + shutdown(keepalive_timeout, State); + {error, Reason} -> + shutdown(Reason, State) + end; + +connected(EventType, Content, State) -> + ?HANDLE(EventType, Content, State). + +%%-------------------------------------------------------------------- +%% Disconnected State + +disconnected(enter, _, _State) -> + %% TODO: What to do? + %% CleanStart is true + keep_state_and_data; + +disconnected(EventType, Content, State) -> + ?HANDLE(EventType, Content, State). + +%% Handle call +handle({call, From}, info, State) -> + reply(From, info(State), State); + +handle({call, From}, stats, State) -> + reply(From, stats(State), State); + +%%handle({call, From}, kick, State) -> +%% ok = gen_statem:reply(From, ok), +%% shutdown(kicked, State); + +%%handle({call, From}, discard, State) -> +%% ok = gen_statem:reply(From, ok), +%% shutdown(discard, State); + +handle({call, From}, Req, State) -> + ?LOG(error, "Unexpected call: ~p", [Req]), + reply(From, ignored, State); + +%% Handle cast +handle(cast, Msg, State) -> + ?LOG(error, "Unexpected cast: ~p", [Msg]), + keep_state(State); + +%% Handle incoming data +handle(info, {Inet, _Sock, Data}, State) when Inet == tcp; + Inet == ssl -> + Oct = iolist_size(Data), + ?LOG(debug, "RECV ~p", [Data]), + emqx_pd:update_counter(incoming_bytes, Oct), + ok = emqx_metrics:inc('bytes.received', Oct), + NState = maybe_gc(1, Oct, State), + process_incoming(Data, ensure_stats_timer(NState)); + +handle(info, {Error, _Sock, Reason}, State) + when Error == tcp_error; Error == ssl_error -> + shutdown(Reason, State); + +handle(info, {Closed, _Sock}, State) + when Closed == tcp_closed; Closed == ssl_closed -> + shutdown(closed, State); + +handle(info, {Passive, _Sock}, State) when Passive == tcp_passive; + Passive == ssl_passive -> + %% Rate limit here:) + NState = ensure_rate_limit(State), + case activate_socket(NState) of + ok -> keep_state(NState); + {error, Reason} -> + shutdown(Reason, NState) + end; + +handle(info, activate_socket, State) -> + %% Rate limit timer expired. + NState = State#state{conn_state = running}, + case activate_socket(NState) of + ok -> + keep_state(NState#state{limit_timer = undefined}); + {error, Reason} -> + shutdown(Reason, NState) + end; + +handle(info, {inet_reply, _Sock, ok}, State) -> + %% something sent + keep_state(ensure_stats_timer(State)); + +handle(info, {inet_reply, _Sock, {error, Reason}}, State) -> + shutdown(Reason, State); + +handle(info, {timeout, Timer, emit_stats}, + State = #state{stats_timer = Timer, + proto_state = ProtoState, + gc_state = GcState}) -> + ClientId = emqx_protocol:client_id(ProtoState), + ok = emqx_cm:set_chan_stats(ClientId, stats(State)), + NState = State#state{stats_timer = undefined}, + Limits = erlang:get(force_shutdown_policy), + case emqx_misc:conn_proc_mng_policy(Limits) of + continue -> + keep_state(NState); + hibernate -> + %% going to hibernate, reset gc stats + GcState1 = emqx_gc:reset(GcState), + {keep_state, NState#state{gc_state = GcState1}, hibernate}; + {shutdown, Reason} -> + ?LOG(error, "Shutdown exceptionally due to ~p", [Reason]), + shutdown(Reason, NState) + end; + +handle(info, {timeout, Timer, Msg}, + State = #state{proto_state = ProtoState}) -> + case emqx_protocol:handle_timeout(Timer, Msg, ProtoState) of + {ok, NProtoState} -> + keep_state(State#state{proto_state = NProtoState}); + {ok, Packets, NProtoState} -> + handle_outgoing(Packets, fun keep_state/1, + State#state{proto_state = NProtoState}); + {error, Reason} -> + shutdown(Reason, State); + {error, Reason, NProtoState} -> + shutdown(Reason, State#state{proto_state = NProtoState}) + end; + +handle(info, {shutdown, discard, {ClientId, ByPid}}, State) -> + ?LOG(error, "Discarded by ~s:~p", [ClientId, ByPid]), + shutdown(discard, State); + +handle(info, {shutdown, conflict, {ClientId, NewPid}}, State) -> + ?LOG(warning, "Clientid '~s' conflict with ~p", [ClientId, NewPid]), + shutdown(conflict, State); + +handle(info, {shutdown, Reason}, State) -> + shutdown(Reason, State); + +handle(info, Info, State) -> + ?LOG(error, "Unexpected info: ~p", [Info]), + keep_state(State). + +code_change(_Vsn, State, Data, _Extra) -> + {ok, State, Data}. + +terminate(Reason, _StateName, #state{transport = Transport, + socket = Socket, + keepalive = KeepAlive, + proto_state = ProtoState}) -> + ?LOG(debug, "Terminated for ~p", [Reason]), + ok = Transport:fast_close(Socket), + ok = emqx_keepalive:cancel(KeepAlive), + emqx_protocol:terminate(Reason, ProtoState). + +%%-------------------------------------------------------------------- +%% Process incoming data + +-compile({inline, [process_incoming/2]}). +process_incoming(Data, State) -> + process_incoming(Data, [], State). + +process_incoming(<<>>, Packets, State) -> + {keep_state, State, next_incoming_events(Packets)}; + +process_incoming(Data, Packets, State = #state{parse_state = ParseState}) -> + try emqx_frame:parse(Data, ParseState) of + {ok, NParseState} -> + NState = State#state{parse_state = NParseState}, + {keep_state, NState, next_incoming_events(Packets)}; + {ok, Packet, Rest, NParseState} -> + NState = State#state{parse_state = NParseState}, + process_incoming(Rest, [Packet|Packets], NState); + {error, Reason} -> + shutdown(Reason, State) + catch + error:Reason:Stk -> + ?LOG(error, "Parse failed for ~p~n\ + Stacktrace:~p~nError data:~p", [Reason, Stk, Data]), + shutdown(parse_error, State) end. +next_incoming_events(Packets) when is_list(Packets) -> + [next_event(cast, {incoming, Packet}) + || Packet <- lists:reverse(Packets)]. + %%-------------------------------------------------------------------- %% Handle incoming packet -%%-------------------------------------------------------------------- --spec(handle_in(emqx_mqtt:packet(), channel()) - -> {ok, channel()} - | {ok, emqx_mqtt:packet(), channel()} - | {error, Reason :: term(), channel()} - | {stop, Error :: atom(), channel()}). -handle_in(?CONNECT_PACKET( - #mqtt_packet_connect{proto_name = ProtoName, - proto_ver = ProtoVer, - is_bridge = IsBridge, - client_id = ClientId, - username = Username, - password = Password, - keepalive = Keepalive} = ConnPkt), - Channel = #channel{endpoint = Endpoint}) -> - Endpoint1 = emqx_endpoint:update(#{client_id => ClientId, - username => Username, - password => Password - }, Endpoint), - emqx_logger:set_metadata_client_id(ClientId), - WillMsg = emqx_packet:will_msg(ConnPkt), - Channel1 = Channel#channel{endpoint = Endpoint1, - proto_name = ProtoName, - proto_ver = ProtoVer, - is_bridge = IsBridge, - keepalive = Keepalive, - will_msg = WillMsg - }, - %% fun validate_packet/2, - case pipeline([fun check_connect/2, - fun handle_connect/2], ConnPkt, Channel1) of - {ok, SP, Channel2} -> - handle_out({connack, ?RC_SUCCESS, sp(SP)}, Channel2); - {error, ReasonCode} -> - handle_out({connack, ReasonCode}, Channel1); - {error, ReasonCode, Channel2} -> - handle_out({connack, ReasonCode}, Channel2) - end; - -handle_in(Packet = ?PUBLISH_PACKET(QoS, Topic, PacketId), Channel) -> - case pipeline([fun validate_packet/2, - fun check_pub_caps/2, - fun check_pub_acl/2, - fun handle_publish/2], Packet, Channel) of - {error, ReasonCode} -> - ?LOG(warning, "Cannot publish qos~w message to ~s due to ~s", - [QoS, Topic, emqx_reason_codes:text(ReasonCode)]), - handle_out(case QoS of - ?QOS_0 -> {puberr, ReasonCode}; - ?QOS_1 -> {puback, PacketId, ReasonCode}; - ?QOS_2 -> {pubrec, PacketId, ReasonCode} - end, Channel); - Result -> Result - end; - -handle_in(?PUBACK_PACKET(PacketId, ReasonCode), Channel = #channel{session = Session}) -> - case emqx_session:puback(PacketId, ReasonCode, Session) of - {ok, NSession} -> - {ok, Channel#channel{session = NSession}}; - {error, _NotFound} -> - %% TODO: metrics? error msg? - {ok, Channel} - end; - -handle_in(?PUBREC_PACKET(PacketId, ReasonCode), Channel = #channel{session = Session}) -> - case emqx_session:pubrec(PacketId, ReasonCode, Session) of - {ok, NSession} -> - handle_out({pubrel, PacketId}, Channel#channel{session = NSession}); - {error, ReasonCode} -> - handle_out({pubrel, PacketId, ReasonCode}, Channel) - end; - -handle_in(?PUBREL_PACKET(PacketId, ReasonCode), Channel = #channel{session = Session}) -> - case emqx_session:pubrel(PacketId, ReasonCode, Session) of - {ok, NSession} -> - handle_out({pubcomp, PacketId}, Channel#channel{session = NSession}); - {error, ReasonCode} -> - handle_out({pubcomp, PacketId, ReasonCode}, Channel) - end; - -handle_in(?PUBCOMP_PACKET(PacketId, ReasonCode), Channel = #channel{session = Session}) -> - case emqx_session:pubcomp(PacketId, ReasonCode, Session) of - {ok, NSession} -> - {ok, Channel#channel{session = NSession}}; - {error, _ReasonCode} -> - %% TODO: How to handle the reason code? - {ok, Channel} - end; - -handle_in(?SUBSCRIBE_PACKET(PacketId, Properties, RawTopicFilters), - Channel = #channel{endpoint = Endpoint, session = Session}) -> - case check_subscribe(parse_topic_filters(?SUBSCRIBE, RawTopicFilters), Channel) of - {ok, TopicFilters} -> - TopicFilters1 = preprocess_topic_filters(?SUBSCRIBE, Endpoint, - enrich_subopts(TopicFilters, Channel)), - {ok, ReasonCodes, NSession} = emqx_session:subscribe(emqx_endpoint:to_map(Endpoint), - TopicFilters1, Session), - handle_out({suback, PacketId, ReasonCodes}, Channel#channel{session = NSession}); - {error, TopicFilters} -> - {Topics, ReasonCodes} = lists:unzip([{Topic, RC} || {Topic, #{rc := RC}} <- TopicFilters]), - ?LOG(warning, "Cannot subscribe ~p due to ~p", - [Topics, [emqx_reason_codes:text(R) || R <- ReasonCodes]]), - %% Tell the client that all subscriptions has been failed. - ReasonCodes1 = lists:map(fun(?RC_SUCCESS) -> - ?RC_IMPLEMENTATION_SPECIFIC_ERROR; - (RC) -> RC - end, ReasonCodes), - handle_out({suback, PacketId, ReasonCodes1}, Channel) - end; - -handle_in(?UNSUBSCRIBE_PACKET(PacketId, Properties, RawTopicFilters), - Channel = #channel{endpoint = Endpoint, session = Session}) -> - TopicFilters = preprocess_topic_filters( - ?UNSUBSCRIBE, Endpoint, - parse_topic_filters(?UNSUBSCRIBE, RawTopicFilters)), - {ok, ReasonCodes, NSession} = emqx_session:unsubscribe(emqx_endpoint:to_map(Endpoint), TopicFilters, Session), - handle_out({unsuback, PacketId, ReasonCodes}, Channel#channel{session = NSession}); - -handle_in(?PACKET(?PINGREQ), Channel) -> - {ok, ?PACKET(?PINGRESP), Channel}; - -handle_in(?DISCONNECT_PACKET(?RC_SUCCESS), Channel) -> - %% Clear will msg - {stop, normal, Channel#channel{will_msg = undefined}}; - -handle_in(?DISCONNECT_PACKET(RC), Channel = #channel{proto_ver = Ver}) -> - %% TODO: - %% {stop, {shutdown, abnormal_disconnet}, Channel}; - {stop, {shutdown, emqx_reason_codes:name(RC, Ver)}, Channel}; - -handle_in(?AUTH_PACKET(), Channel) -> - %%TODO: implement later. - {ok, Channel}; - -handle_in(Packet, Channel) -> - io:format("In: ~p~n", [Packet]), - {ok, Channel}. +handle_incoming(Packet = ?PACKET(Type), SuccFun, + State = #state{proto_state = ProtoState}) -> + _ = inc_incoming_stats(Type), + ok = emqx_metrics:inc_recv(Packet), + ?LOG(debug, "RECV ~s", [emqx_packet:format(Packet)]), + case emqx_protocol:handle_in(Packet, ProtoState) of + {ok, NProtoState} -> + SuccFun(State#state{proto_state = NProtoState}); + {ok, OutPackets, NProtoState} -> + handle_outgoing(OutPackets, SuccFun, + State#state{proto_state = NProtoState}); + {error, Reason, NProtoState} -> + shutdown(Reason, State#state{proto_state = NProtoState}); + {stop, Error, NProtoState} -> + stop(Error, State#state{proto_state = NProtoState}) + end. %%-------------------------------------------------------------------- -%% Handle outgoing packet -%%-------------------------------------------------------------------- +%% Handle outgoing packets -handle_out({connack, ?RC_SUCCESS, SP}, Channel = #channel{endpoint = Endpoint}) -> - ok = emqx_hooks:run('client.connected', - [emqx_endpoint:to_map(Endpoint), ?RC_SUCCESS, attrs(Channel)]), - Props = #{}, %% TODO: ... - {ok, ?CONNACK_PACKET(?RC_SUCCESS, SP, Props), Channel}; +handle_outgoing(Packets, SuccFun, State = #state{serialize = Serialize}) + when is_list(Packets) -> + send(lists:map(Serialize, Packets), SuccFun, State); -handle_out({connack, ReasonCode}, Channel = #channel{endpoint = Endpoint, - proto_ver = ProtoVer}) -> - ok = emqx_hooks:run('client.connected', - [emqx_endpoint:to_map(Endpoint), ReasonCode, attrs(Channel)]), - ReasonCode1 = if - ProtoVer == ?MQTT_PROTO_V5 -> ReasonCode; - true -> emqx_reason_codes:compat(connack, ReasonCode) - end, - Reason = emqx_reason_codes:name(ReasonCode1, ProtoVer), - {error, Reason, ?CONNACK_PACKET(ReasonCode1), Channel}; - -handle_out(Delivers, Channel = #channel{endpoint = Endpoint, session = Session}) - when is_list(Delivers) -> - case emqx_session:deliver([{Topic, Msg} || {deliver, Topic, Msg} <- Delivers], Session) of - {ok, Publishes, NSession} -> - Credentials = emqx_endpoint:credentials(Endpoint), - Packets = lists:map(fun({publish, PacketId, Msg}) -> - Msg0 = emqx_hooks:run_fold('message.deliver', [Credentials], Msg), - Msg1 = emqx_message:update_expiry(Msg0), - Msg2 = emqx_mountpoint:unmount(emqx_endpoint:mountpoint(Endpoint), Msg1), - emqx_packet:from_message(PacketId, Msg2) - end, Publishes), - {ok, Packets, Channel#channel{session = NSession}}; - {ok, NSession} -> - {ok, Channel#channel{session = NSession}} - end; - -handle_out({publish, PacketId, Msg}, Channel = #channel{endpoint = Endpoint}) -> - Credentials = emqx_endpoint:credentials(Endpoint), - Msg0 = emqx_hooks:run_fold('message.deliver', [Credentials], Msg), - Msg1 = emqx_message:update_expiry(Msg0), - Msg2 = emqx_mountpoint:unmount( - emqx_endpoint:mountpoint(Credentials), Msg1), - {ok, emqx_packet:from_message(PacketId, Msg2), Channel}; - -handle_out({puberr, ReasonCode}, Channel) -> - {ok, Channel}; - -handle_out({puback, PacketId, ReasonCode}, Channel) -> - {ok, ?PUBACK_PACKET(PacketId, ReasonCode), Channel}; - -handle_out({pubrel, PacketId}, Channel) -> - {ok, ?PUBREL_PACKET(PacketId), Channel}; -handle_out({pubrel, PacketId, ReasonCode}, Channel) -> - {ok, ?PUBREL_PACKET(PacketId, ReasonCode), Channel}; - -handle_out({pubrec, PacketId, ReasonCode}, Channel) -> - {ok, ?PUBREC_PACKET(PacketId, ReasonCode), Channel}; - -handle_out({pubcomp, PacketId}, Channel) -> - {ok, ?PUBCOMP_PACKET(PacketId), Channel}; -handle_out({pubcomp, PacketId, ReasonCode}, Channel) -> - {ok, ?PUBCOMP_PACKET(PacketId, ReasonCode), Channel}; - -handle_out({suback, PacketId, ReasonCodes}, Channel = #channel{proto_ver = ?MQTT_PROTO_V5}) -> - %% TODO: ACL Deny - {ok, ?SUBACK_PACKET(PacketId, ReasonCodes), Channel}; -handle_out({suback, PacketId, ReasonCodes}, Channel) -> - %% TODO: ACL Deny - ReasonCodes1 = [emqx_reason_codes:compat(suback, RC) || RC <- ReasonCodes], - {ok, ?SUBACK_PACKET(PacketId, ReasonCodes1), Channel}; - -handle_out({unsuback, PacketId, ReasonCodes}, Channel = #channel{proto_ver = ?MQTT_PROTO_V5}) -> - {ok, ?UNSUBACK_PACKET(PacketId, ReasonCodes), Channel}; -%% Ignore reason codes if not MQTT5 -handle_out({unsuback, PacketId, _ReasonCodes}, Channel) -> - {ok, ?UNSUBACK_PACKET(PacketId), Channel}; - -handle_out(Packet, State) -> - io:format("Out: ~p~n", [Packet]), - {ok, State}. - -handle_deliver(Msg, State) -> - io:format("Msg: ~p~n", [Msg]), - %% Msg -> Pub - {ok, State}. - -handle_timeout(Name, TRef, State) -> - io:format("Timeout: ~s ~p~n", [Name, TRef]), - {ok, State}. - -terminate(Reason, _State) -> - %%io:format("Terminated for ~p~n", [Reason]), - ok. +handle_outgoing(Packet, SuccFun, State = #state{serialize = Serialize}) -> + send(Serialize(Packet), SuccFun, State). %%-------------------------------------------------------------------- -%% Check Connect Packet -%%-------------------------------------------------------------------- +%% Serialize fun -check_connect(_ConnPkt, Channel) -> - {ok, Channel}. +serialize_fun(ProtoVer) -> + fun(Packet = ?PACKET(Type)) -> + ?LOG(debug, "SEND ~s", [emqx_packet:format(Packet)]), + _ = inc_outgoing_stats(Type), + emqx_frame:serialize(Packet, ProtoVer) + end. %%-------------------------------------------------------------------- -%% Handle Connect Packet -%%-------------------------------------------------------------------- +%% Send data -handle_connect(#mqtt_packet_connect{proto_name = ProtoName, - proto_ver = ProtoVer, - is_bridge = IsBridge, - clean_start = CleanStart, - keepalive = Keepalive, - properties = ConnProps, - client_id = ClientId, - username = Username, - password = Password} = ConnPkt, - Channel = #channel{endpoint = Endpoint}) -> - Credentials = emqx_endpoint:credentials(Endpoint), - case emqx_access_control:authenticate( - Credentials#{password => Password}) of - {ok, Credentials1} -> - Endpoint1 = emqx_endpoint:update(Credentials1, Endpoint), - %% Open session - case open_session(ConnPkt, Channel) of - {ok, Session, SP} -> - Channel1 = Channel#channel{endpoint = Endpoint1, - session = Session, - connected = true, - connected_at = os:timestamp() - }, - ok = emqx_cm:register_channel(ClientId), - {ok, SP, Channel1}; - {error, Error} -> - ?LOG(error, "Failed to open session: ~p", [Error]), - {error, ?RC_UNSPECIFIED_ERROR, Channel#channel{endpoint = Endpoint1}} - end; +send(IoData, SuccFun, State = #state{transport = Transport, + socket = Socket}) -> + Oct = iolist_size(IoData), + ok = emqx_metrics:inc('bytes.sent', Oct), + case Transport:async_send(Socket, IoData) of + ok -> SuccFun(maybe_gc(1, Oct, State)); {error, Reason} -> - ?LOG(warning, "Client ~s (Username: '~s') login failed for ~p", - [ClientId, Username, Reason]), - {error, emqx_reason_codes:connack_error(Reason), Channel} - end. - -open_session(#mqtt_packet_connect{clean_start = CleanStart, - %%properties = ConnProps, - client_id = ClientId, - username = Username} = ConnPkt, - Channel = #channel{endpoint = Endpoint}) -> - emqx_cm:open_session(maps:merge(emqx_endpoint:to_map(Endpoint), - #{clean_start => CleanStart, - max_inflight => 0, - expiry_interval => 0})). - -%%-------------------------------------------------------------------- -%% Handle Publish Message: Client -> Broker -%%-------------------------------------------------------------------- - -handle_publish(Packet = ?PUBLISH_PACKET(QoS, Topic, PacketId), - Channel = #channel{endpoint = Endpoint}) -> - Credentials = emqx_endpoint:credentials(Endpoint), - %% TODO: ugly... publish_to_msg(...) - Msg = emqx_packet:to_message(Credentials, Packet), - Msg1 = emqx_mountpoint:mount( - emqx_endpoint:mountpoint(Endpoint), Msg), - Msg2 = emqx_message:set_flag(dup, false, Msg1), - handle_publish(PacketId, Msg2, Channel). - -handle_publish(_PacketId, Msg = #message{qos = ?QOS_0}, Channel) -> - _ = emqx_broker:publish(Msg), - {ok, Channel}; - -handle_publish(PacketId, Msg = #message{qos = ?QOS_1}, Channel) -> - Results = emqx_broker:publish(Msg), - ReasonCode = emqx_reason_codes:puback(Results), - handle_out({puback, PacketId, ReasonCode}, Channel); - -handle_publish(PacketId, Msg = #message{qos = ?QOS_2}, - Channel = #channel{session = Session}) -> - case emqx_session:publish(PacketId, Msg, Session) of - {ok, Results, NSession} -> - ReasonCode = emqx_reason_codes:puback(Results), - handle_out({pubrec, PacketId, ReasonCode}, - Channel#channel{session = NSession}); - {error, ReasonCode} -> - handle_out({pubrec, PacketId, ReasonCode}, Channel) - end. - -%%-------------------------------------------------------------------- -%% Validate Incoming Packet -%%-------------------------------------------------------------------- - --spec(validate_packet(emqx_mqtt:packet(), channel()) -> ok). -validate_packet(Packet, _Channel) -> - try emqx_packet:validate(Packet) of - true -> ok - catch - error:protocol_error -> - {error, ?RC_PROTOCOL_ERROR}; - error:subscription_identifier_invalid -> - {error, ?RC_SUBSCRIPTION_IDENTIFIERS_NOT_SUPPORTED}; - error:topic_alias_invalid -> - {error, ?RC_TOPIC_ALIAS_INVALID}; - error:topic_filters_invalid -> - {error, ?RC_TOPIC_FILTER_INVALID}; - error:topic_name_invalid -> - {error, ?RC_TOPIC_FILTER_INVALID}; - error:_Reason -> - {error, ?RC_MALFORMED_PACKET} + shutdown(Reason, State) end. %%-------------------------------------------------------------------- -%% Preprocess MQTT Properties -%%-------------------------------------------------------------------- +%% Ensure keepalive -%% TODO:... +ensure_keepalive(0, _State) -> + ignore; +ensure_keepalive(Interval, #state{transport = Transport, + socket = Socket, + proto_state = ProtoState}) -> + StatFun = fun() -> + case Transport:getstat(Socket, [recv_oct]) of + {ok, [{recv_oct, RecvOct}]} -> + {ok, RecvOct}; + Error -> Error + end + end, + Backoff = emqx_zone:get_env(emqx_protocol:info(zone, ProtoState), + keepalive_backoff, 0.75), + emqx_keepalive:start(StatFun, round(Interval * Backoff), {keepalive, check}). %%-------------------------------------------------------------------- -%% Check Publish +%% Ensure rate limit + +ensure_rate_limit(State = #state{rate_limit = Rl, pub_limit = Pl}) -> + Limiters = [{Pl, #state.pub_limit, emqx_pd:reset_counter(incoming_pubs)}, + {Rl, #state.rate_limit, emqx_pd:reset_counter(incoming_bytes)}], + ensure_rate_limit(Limiters, State). + +ensure_rate_limit([], State) -> + State; +ensure_rate_limit([{undefined, _Pos, _Cnt}|Limiters], State) -> + ensure_rate_limit(Limiters, State); +ensure_rate_limit([{Rl, Pos, Cnt}|Limiters], State) -> + case esockd_rate_limit:check(Cnt, Rl) of + {0, Rl1} -> + ensure_rate_limit(Limiters, setelement(Pos, State, Rl1)); + {Pause, Rl1} -> + ?LOG(debug, "Rate limit pause connection ~pms", [Pause]), + TRef = erlang:send_after(Pause, self(), activate_socket), + setelement(Pos, State#state{conn_state = blocked, + limit_timer = TRef}, Rl1) + end. + %%-------------------------------------------------------------------- +%% Activate Socket -check_pub_caps(#mqtt_packet{header = #mqtt_packet_header{qos = QoS, - retain = Retain}, - variable = #mqtt_packet_publish{}}, - #channel{endpoint = Endpoint}) -> - emqx_mqtt_caps:check_pub(emqx_endpoint:zone(Endpoint), - #{qos => QoS, retain => Retain}). - -check_pub_acl(_Packet, #channel{enable_acl = false}) -> +activate_socket(#state{conn_state = blocked}) -> ok; -check_pub_acl(#mqtt_packet{variable = #mqtt_packet_publish{topic_name = Topic}}, - #channel{endpoint = Endpoint}) -> - case emqx_endpoint:is_superuser(Endpoint) of - true -> ok; - false -> - do_acl_check(Endpoint, publish, Topic) +activate_socket(#state{transport = Transport, + socket = Socket, + active_n = N}) -> + Transport:setopts(Socket, [{active, N}]). + +%%-------------------------------------------------------------------- +%% Inc incoming/outgoing stats + +-compile({inline, + [ inc_incoming_stats/1 + , inc_outgoing_stats/1 + ]}). + +inc_incoming_stats(Type) -> + emqx_pd:update_counter(recv_pkt, 1), + case Type == ?PUBLISH of + true -> + emqx_pd:update_counter(recv_msg, 1), + emqx_pd:update_counter(incoming_pubs, 1); + false -> ok end. -check_sub_acl(_Packet, #channel{enable_acl = false}) -> - ok. - -do_acl_check(Endpoint, PubSub, Topic) -> - case emqx_access_control:check_acl( - emqx_endpoint:to_map(Endpoint), PubSub, Topic) of - allow -> ok; - deny -> {error, ?RC_NOT_AUTHORIZED} - end. +inc_outgoing_stats(Type) -> + emqx_pd:update_counter(send_pkt, 1), + (Type == ?PUBLISH) + andalso emqx_pd:update_counter(send_msg, 1). %%-------------------------------------------------------------------- -%% Check Subscribe Packet -%%-------------------------------------------------------------------- +%% Ensure stats timer -check_subscribe(TopicFilters, _Channel) -> - {ok, TopicFilters}. +ensure_stats_timer(State = #state{stats_timer = undefined, + idle_timeout = IdleTimeout}) -> + TRef = emqx_misc:start_timer(IdleTimeout, emit_stats), + State#state{stats_timer = TRef}; +%% disabled or timer existed +ensure_stats_timer(State) -> State. %%-------------------------------------------------------------------- -%% Pipeline -%%-------------------------------------------------------------------- +%% Maybe GC -pipeline([Fun], Packet, Channel) -> - Fun(Packet, Channel); -pipeline([Fun|More], Packet, Channel) -> - case Fun(Packet, Channel) of - ok -> pipeline(More, Packet, Channel); - {ok, NChannel} -> - pipeline(More, Packet, NChannel); - {ok, NPacket, NChannel} -> - pipeline(More, NPacket, NChannel); - {error, Reason} -> - {error, Reason} - end. - -%%-------------------------------------------------------------------- -%% Preprocess topic filters -%%-------------------------------------------------------------------- - -preprocess_topic_filters(Type, Endpoint, TopicFilters) -> - TopicFilters1 = emqx_hooks:run_fold(case Type of - ?SUBSCRIBE -> 'client.subscribe'; - ?UNSUBSCRIBE -> 'client.unsubscribe' - end, - [emqx_endpoint:credentials(Endpoint)], - TopicFilters), - emqx_mountpoint:mount(emqx_endpoint:mountpoint(Endpoint), TopicFilters1). - -%%-------------------------------------------------------------------- -%% Enrich subopts -%%-------------------------------------------------------------------- - -enrich_subopts(TopicFilters, #channel{proto_ver = ?MQTT_PROTO_V5}) -> - TopicFilters; -enrich_subopts(TopicFilters, #channel{endpoint = Endpoint, is_bridge = IsBridge}) -> - Rap = flag(IsBridge), - Nl = flag(emqx_zone:get_env(emqx_endpoint:zone(Endpoint), ignore_loop_deliver, false)), - [{Topic, SubOpts#{rap => Rap, nl => Nl}} || {Topic, SubOpts} <- TopicFilters]. - -%%-------------------------------------------------------------------- -%% Parse topic filters -%%-------------------------------------------------------------------- - -parse_topic_filters(?SUBSCRIBE, TopicFilters) -> - [emqx_topic:parse(Topic, SubOpts) || {Topic, SubOpts} <- TopicFilters]; - -parse_topic_filters(?UNSUBSCRIBE, TopicFilters) -> - lists:map(fun emqx_topic:parse/1, TopicFilters). +maybe_gc(_Cnt, _Oct, State = #state{gc_state = undefined}) -> + State; +maybe_gc(Cnt, Oct, State = #state{gc_state = GCSt}) -> + {Ok, GCSt1} = emqx_gc:run(Cnt, Oct, GCSt), + Ok andalso emqx_metrics:inc('channel.gc.cnt'), + State#state{gc_state = GCSt1}. %%-------------------------------------------------------------------- %% Helper functions -%%-------------------------------------------------------------------- -sp(true) -> 1; -sp(false) -> 0. +-compile({inline, + [ reply/3 + , keep_state/1 + , next_event/2 + , shutdown/2 + , stop/2 + ]}). -flag(true) -> 1; -flag(false) -> 0. +reply(From, Reply, State) -> + {keep_state, State, [{reply, From, Reply}]}. + +keep_state(State) -> + {keep_state, State}. + +next_event(Type, Content) -> + {next_event, Type, Content}. + +shutdown(Reason, State) -> + stop({shutdown, Reason}, State). + +stop(Reason, State) -> + {stop, Reason, State}. diff --git a/src/emqx_client.erl b/src/emqx_client.erl index 41ff2f72e..7e4120628 100644 --- a/src/emqx_client.erl +++ b/src/emqx_client.erl @@ -160,7 +160,7 @@ clean_start :: boolean(), username :: maybe(binary()), password :: maybe(binary()), - proto_ver :: emqx_mqtt_types:version(), + proto_ver :: emqx_types:mqtt_ver(), proto_name :: iodata(), keepalive :: non_neg_integer(), keepalive_timer :: maybe(reference()), @@ -192,11 +192,11 @@ -type(payload() :: iodata()). --type(packet_id() :: emqx_mqtt_types:packet_id()). +-type(packet_id() :: emqx_types:packet_id()). --type(properties() :: emqx_mqtt_types:properties()). +-type(properties() :: emqx_types:properties()). --type(qos() :: emqx_mqtt_types:qos_name() | emqx_mqtt_types:qos()). +-type(qos() :: emqx_types:qos_name() | emqx_types:qos()). -type(pubopt() :: {retain, boolean()} | {qos, qos()} | {timeout, timeout()}). diff --git a/src/emqx_cm.erl b/src/emqx_cm.erl index 7f230e841..e756a37e5 100644 --- a/src/emqx_cm.erl +++ b/src/emqx_cm.erl @@ -34,12 +34,12 @@ -export([ get_conn_attrs/1 , get_conn_attrs/2 - , set_conn_attrs/2 + , set_chan_attrs/2 ]). -export([ get_conn_stats/1 , get_conn_stats/2 - , set_conn_stats/2 + , set_chan_stats/2 ]). -export([ open_session/1 @@ -163,8 +163,8 @@ get_conn_attrs(ClientId, ChanPid) -> rpc_call(node(ChanPid), get_conn_attrs, [ClientId, ChanPid]). %% @doc Set conn attrs. --spec(set_conn_attrs(emqx_types:client_id(), attrs()) -> ok). -set_conn_attrs(ClientId, Attrs) when is_binary(ClientId), is_map(Attrs) -> +-spec(set_chan_attrs(emqx_types:client_id(), attrs()) -> ok). +set_chan_attrs(ClientId, Attrs) when is_binary(ClientId) -> Chan = {ClientId, self()}, case ets:update_element(?CONN_TAB, Chan, {2, Attrs}) of true -> ok; @@ -191,12 +191,12 @@ get_conn_stats(ClientId, ChanPid) -> rpc_call(node(ChanPid), get_conn_stats, [ClientId, ChanPid]). %% @doc Set conn stats. --spec(set_conn_stats(emqx_types:client_id(), stats()) -> ok). -set_conn_stats(ClientId, Stats) when is_binary(ClientId) -> - set_conn_stats(ClientId, self(), Stats). +-spec(set_chan_stats(emqx_types:client_id(), stats()) -> ok). +set_chan_stats(ClientId, Stats) when is_binary(ClientId) -> + set_chan_stats(ClientId, self(), Stats). --spec(set_conn_stats(emqx_types:client_id(), chan_pid(), stats()) -> ok). -set_conn_stats(ClientId, ChanPid, Stats) -> +-spec(set_chan_stats(emqx_types:client_id(), chan_pid(), stats()) -> ok). +set_chan_stats(ClientId, ChanPid, Stats) -> Chan = {ClientId, ChanPid}, _ = ets:update_element(?CONN_TAB, Chan, {3, Stats}), ok. @@ -208,7 +208,7 @@ open_session(Attrs = #{clean_start := true, client_id := ClientId}) -> CleanStart = fun(_) -> ok = discard_session(ClientId), - {ok, emqx_session:new(Attrs), false} + {ok, emqx_session:init(Attrs), false} end, emqx_cm_locker:trans(ClientId, CleanStart); @@ -219,7 +219,7 @@ open_session(Attrs = #{clean_start := false, {ok, Session} -> {ok, Session, true}; {error, not_found} -> - {ok, emqx_session:new(Attrs), false} + {ok, emqx_session:init(Attrs), false} end end, emqx_cm_locker:trans(ClientId, ResumeStart). diff --git a/src/emqx_connection.erl b/src/emqx_connection.erl deleted file mode 100644 index b7478e445..000000000 --- a/src/emqx_connection.erl +++ /dev/null @@ -1,586 +0,0 @@ -%%-------------------------------------------------------------------- -%% Copyright (c) 2019 EMQ Technologies Co., Ltd. All Rights Reserved. -%% -%% Licensed under the Apache License, Version 2.0 (the "License"); -%% you may not use this file except in compliance with the License. -%% You may obtain a copy of the License at -%% -%% http://www.apache.org/licenses/LICENSE-2.0 -%% -%% Unless required by applicable law or agreed to in writing, software -%% distributed under the License is distributed on an "AS IS" BASIS, -%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -%% See the License for the specific language governing permissions and -%% limitations under the License. -%%-------------------------------------------------------------------- - -%% MQTT TCP/SSL Connection --module(emqx_connection). - --behaviour(gen_statem). - --include("emqx.hrl"). --include("emqx_mqtt.hrl"). --include("logger.hrl"). --include("types.hrl"). - --logger_header("[Conn]"). - --export([start_link/3]). - -%% APIs --export([ info/1 - , attrs/1 - , stats/1 - ]). - -%% gen_statem callbacks --export([ idle/3 - , connected/3 - , disconnected/3 - ]). - --export([ init/1 - , callback_mode/0 - , code_change/4 - , terminate/3 - ]). - --record(state, { - transport :: esockd:transport(), - socket :: esockd:socket(), - peername :: emqx_types:peername(), - sockname :: emqx_types:peername(), - conn_state :: running | blocked, - active_n :: pos_integer(), - rate_limit :: maybe(esockd_rate_limit:bucket()), - pub_limit :: maybe(esockd_rate_limit:bucket()), - limit_timer :: maybe(reference()), - parse_state :: emqx_frame:parse_state(), - chan_state :: emqx_channel:channel(), - gc_state :: emqx_gc:gc_state(), - keepalive :: maybe(emqx_keepalive:keepalive()), - stats_timer :: disabled | maybe(reference()), - idle_timeout :: timeout() - }). - --define(ACTIVE_N, 100). --define(HANDLE(T, C, D), handle((T), (C), (D))). --define(CHAN_STATS, [recv_pkt, recv_msg, send_pkt, send_msg]). --define(SOCK_STATS, [recv_oct, recv_cnt, send_oct, send_cnt, send_pend]). - --spec(start_link(esockd:transport(), esockd:socket(), proplists:proplist()) - -> {ok, pid()}). -start_link(Transport, Socket, Options) -> - {ok, proc_lib:spawn_link(?MODULE, init, [{Transport, Socket, Options}])}. - -%%-------------------------------------------------------------------- -%% API -%%-------------------------------------------------------------------- - -%% For debug --spec(info(pid() | #state{}) -> map()). -info(CPid) when is_pid(CPid) -> - call(CPid, info); - -info(#state{transport = Transport, - socket = Socket, - peername = Peername, - sockname = Sockname, - conn_state = ConnState, - active_n = ActiveN, - rate_limit = RateLimit, - pub_limit = PubLimit, - chan_state = ChanState}) -> - ConnInfo = #{socktype => Transport:type(Socket), - peername => Peername, - sockname => Sockname, - conn_state => ConnState, - active_n => ActiveN, - rate_limit => rate_limit_info(RateLimit), - pub_limit => rate_limit_info(PubLimit) - }, - ChanInfo = emqx_channel:info(ChanState), - maps:merge(ConnInfo, ChanInfo). - -rate_limit_info(undefined) -> - undefined; -rate_limit_info(Limit) -> - esockd_rate_limit:info(Limit). - -%% For dashboard -attrs(CPid) when is_pid(CPid) -> - call(CPid, attrs); - -attrs(#state{peername = Peername, - sockname = Sockname, - conn_state = ConnState, - chan_state = ChanState}) -> - SockAttrs = #{peername => Peername, - sockname => Sockname, - conn_state => ConnState - }, - ChanAttrs = emqx_channel:attrs(ChanState), - maps:merge(SockAttrs, ChanAttrs). - -%% @doc Get connection stats -stats(CPid) when is_pid(CPid) -> - call(CPid, stats); - -stats(#state{transport = Transport, socket = Socket}) -> - SockStats = case Transport:getstat(Socket, ?SOCK_STATS) of - {ok, Ss} -> Ss; - {error, _} -> [] - end, - ChanStats = [{Name, emqx_pd:get_counter(Name)} || Name <- ?CHAN_STATS], - lists:append([SockStats, ChanStats, emqx_misc:proc_stats()]). - -call(CPid, Req) -> - gen_statem:call(CPid, Req, infinity). - -%%-------------------------------------------------------------------- -%% gen_statem callbacks -%%-------------------------------------------------------------------- - -init({Transport, RawSocket, Options}) -> - {ok, Socket} = Transport:wait(RawSocket), - {ok, Peername} = Transport:ensure_ok_or_exit(peername, [Socket]), - {ok, Sockname} = Transport:ensure_ok_or_exit(sockname, [Socket]), - Peercert = Transport:ensure_ok_or_exit(peercert, [Socket]), - emqx_logger:set_metadata_peername(esockd_net:format(Peername)), - Zone = proplists:get_value(zone, Options), - RateLimit = init_limiter(proplists:get_value(rate_limit, Options)), - PubLimit = init_limiter(emqx_zone:get_env(Zone, publish_limit)), - ActiveN = proplists:get_value(active_n, Options, ?ACTIVE_N), - MaxSize = emqx_zone:get_env(Zone, max_packet_size, ?MAX_PACKET_SIZE), - ParseState = emqx_frame:initial_parse_state(#{max_size => MaxSize}), - ChanState = emqx_channel:init(#{peername => Peername, - sockname => Sockname, - peercert => Peercert, - conn_mod => ?MODULE}, Options), - GcPolicy = emqx_zone:get_env(Zone, force_gc_policy, false), - GcState = emqx_gc:init(GcPolicy), - EnableStats = emqx_zone:get_env(Zone, enable_stats, true), - StatsTimer = if EnableStats -> undefined; ?Otherwise-> disabled end, - IdleTimout = emqx_zone:get_env(Zone, idle_timeout, 30000), - ok = emqx_misc:init_proc_mng_policy(Zone), - State = #state{transport = Transport, - socket = Socket, - peername = Peername, - conn_state = running, - active_n = ActiveN, - rate_limit = RateLimit, - pub_limit = PubLimit, - parse_state = ParseState, - chan_state = ChanState, - gc_state = GcState, - stats_timer = StatsTimer, - idle_timeout = IdleTimout - }, - gen_statem:enter_loop(?MODULE, [{hibernate_after, 2 * IdleTimout}], - idle, State, self(), [IdleTimout]). - -init_limiter(undefined) -> - undefined; -init_limiter({Rate, Burst}) -> - esockd_rate_limit:new(Rate, Burst). - -callback_mode() -> - [state_functions, state_enter]. - -%%-------------------------------------------------------------------- -%% Idle State - -idle(enter, _, State) -> - case activate_socket(State) of - ok -> keep_state_and_data; - {error, Reason} -> - shutdown(Reason, State) - end; - -idle(timeout, _Timeout, State) -> - stop(idle_timeout, State); - -idle(cast, {incoming, Packet = ?CONNECT_PACKET(ConnVar)}, State) -> - handle_incoming(Packet, - fun(St = #state{chan_state = ChanState}) -> - %% Ensure keepalive after connected successfully. - Interval = emqx_channel:keepalive(ChanState), - NextEvent = {next_event, info, {keepalive, start, Interval}}, - {next_state, connected, St, NextEvent} - end, State); - -idle(cast, {incoming, Packet}, State) -> - ?LOG(warning, "Unexpected incoming: ~p", [Packet]), - shutdown(unexpected_incoming_packet, State); - -idle(EventType, Content, State) -> - ?HANDLE(EventType, Content, State). - -%%-------------------------------------------------------------------- -%% Connected State - -connected(enter, _, _State) -> - %% What to do? - keep_state_and_data; - -connected(cast, {incoming, Packet = ?PACKET(?CONNECT)}, State) -> - ?LOG(warning, "Unexpected connect: ~p", [Packet]), - shutdown(unexpected_incoming_connect, State); - -connected(cast, {incoming, Packet = ?PACKET(Type)}, State) -> - handle_incoming(Packet, fun keep_state/1, State); - -connected(info, Deliver = {deliver, _Topic, _Msg}, - State = #state{chan_state = ChanState}) -> - Delivers = emqx_misc:drain_deliver([Deliver]), - case emqx_channel:handle_out(Delivers, ChanState) of - {ok, NChanState} -> - keep_state(State#state{chan_state = NChanState}); - {ok, Packets, NChanState} -> - NState = State#state{chan_state = NChanState}, - handle_outgoing(Packets, fun keep_state/1, NState); - {error, Reason} -> - shutdown(Reason, State) - end; - -%% Start Keepalive -connected(info, {keepalive, start, Interval}, State) -> - case ensure_keepalive(Interval, State) of - ignore -> keep_state(State); - {ok, KeepAlive} -> - keep_state(State#state{keepalive = KeepAlive}); - {error, Reason} -> - shutdown(Reason, State) - end; - -%% Keepalive timer -connected(info, {keepalive, check}, State = #state{keepalive = KeepAlive}) -> - case emqx_keepalive:check(KeepAlive) of - {ok, KeepAlive1} -> - keep_state(State#state{keepalive = KeepAlive1}); - {error, timeout} -> - shutdown(keepalive_timeout, State); - {error, Reason} -> - shutdown(Reason, State) - end; - -connected(EventType, Content, State) -> - ?HANDLE(EventType, Content, State). - -%%-------------------------------------------------------------------- -%% Disconnected State - -disconnected(enter, _, _State) -> - %% TODO: What to do? - keep_state_and_data; - -disconnected(EventType, Content, State) -> - ?HANDLE(EventType, Content, State). - -%% Handle call -handle({call, From}, info, State) -> - reply(From, info(State), State); - -handle({call, From}, attrs, State) -> - reply(From, attrs(State), State); - -handle({call, From}, stats, State) -> - reply(From, stats(State), State); - -%%handle({call, From}, kick, State) -> -%% ok = gen_statem:reply(From, ok), -%% shutdown(kicked, State); - -%%handle({call, From}, discard, State) -> -%% ok = gen_statem:reply(From, ok), -%% shutdown(discard, State); - -handle({call, From}, Req, State) -> - ?LOG(error, "Unexpected call: ~p", [Req]), - reply(From, ignored, State); - -%% Handle cast -handle(cast, Msg, State) -> - ?LOG(error, "Unexpected cast: ~p", [Msg]), - keep_state(State); - -%% Handle Incoming -handle(info, {Inet, _Sock, Data}, State) when Inet == tcp; - Inet == ssl -> - ?LOG(debug, "RECV ~p", [Data]), - Oct = iolist_size(Data), - emqx_pd:update_counter(incoming_bytes, Oct), - ok = emqx_metrics:inc('bytes.received', Oct), - NState = ensure_stats_timer(maybe_gc(1, Oct, State)), - process_incoming(Data, [], NState); - -handle(info, {Error, _Sock, Reason}, State) - when Error == tcp_error; Error == ssl_error -> - shutdown(Reason, State); - -handle(info, {Closed, _Sock}, State) - when Closed == tcp_closed; Closed == ssl_closed -> - shutdown(closed, State); - -handle(info, {Passive, _Sock}, State) when Passive == tcp_passive; - Passive == ssl_passive -> - %% Rate limit here:) - NState = ensure_rate_limit(State), - case activate_socket(NState) of - ok -> keep_state(NState); - {error, Reason} -> - shutdown(Reason, NState) - end; - -handle(info, activate_socket, State) -> - %% Rate limit timer expired. - NState = State#state{conn_state = running}, - case activate_socket(NState) of - ok -> - keep_state(NState#state{limit_timer = undefined}); - {error, Reason} -> - shutdown(Reason, NState) - end; - -handle(info, {inet_reply, _Sock, ok}, State) -> - %% something sent - keep_state(ensure_stats_timer(State)); - -handle(info, {inet_reply, _Sock, {error, Reason}}, State) -> - shutdown(Reason, State); - -handle(info, {timeout, Timer, emit_stats}, - State = #state{stats_timer = Timer, - chan_state = ChanState, - gc_state = GcState}) -> - ClientId = emqx_channel:client_id(ChanState), - ok = emqx_cm:set_conn_stats(ClientId, stats(State)), - NState = State#state{stats_timer = undefined}, - Limits = erlang:get(force_shutdown_policy), - case emqx_misc:conn_proc_mng_policy(Limits) of - continue -> - keep_state(NState); - hibernate -> - %% going to hibernate, reset gc stats - GcState1 = emqx_gc:reset(GcState), - {keep_state, NState#state{gc_state = GcState1}, hibernate}; - {shutdown, Reason} -> - ?LOG(error, "Shutdown exceptionally due to ~p", [Reason]), - shutdown(Reason, NState) - end; - -handle(info, {shutdown, discard, {ClientId, ByPid}}, State) -> - ?LOG(error, "Discarded by ~s:~p", [ClientId, ByPid]), - shutdown(discard, State); - -handle(info, {shutdown, conflict, {ClientId, NewPid}}, State) -> - ?LOG(warning, "Clientid '~s' conflict with ~p", [ClientId, NewPid]), - shutdown(conflict, State); - -handle(info, {shutdown, Reason}, State) -> - shutdown(Reason, State); - -handle(info, Info, State) -> - ?LOG(error, "Unexpected info: ~p", [Info]), - keep_state(State). - -code_change(_Vsn, State, Data, _Extra) -> - {ok, State, Data}. - -terminate(Reason, _StateName, #state{transport = Transport, - socket = Socket, - keepalive = KeepAlive, - chan_state = ChanState}) -> - ?LOG(debug, "Terminated for ~p", [Reason]), - ok = Transport:fast_close(Socket), - ok = emqx_keepalive:cancel(KeepAlive), - emqx_channel:terminate(Reason, ChanState). - -%%-------------------------------------------------------------------- -%% Process incoming data - -process_incoming(<<>>, Packets, State) -> - {keep_state, State, next_events(Packets)}; - -process_incoming(Data, Packets, State = #state{parse_state = ParseState}) -> - try emqx_frame:parse(Data, ParseState) of - {ok, NParseState} -> - NState = State#state{parse_state = NParseState}, - {keep_state, NState, next_events(Packets)}; - {ok, Packet, Rest, NParseState} -> - NState = State#state{parse_state = NParseState}, - process_incoming(Rest, [Packet|Packets], NState); - {error, Reason} -> - shutdown(Reason, State) - catch - error:Reason:Stk -> - ?LOG(error, "Parse failed for ~p~n\ - Stacktrace:~p~nError data:~p", [Reason, Stk, Data]), - shutdown(parse_error, State) - end. - -next_events(Packets) when is_list(Packets) -> - [next_events(Packet) || Packet <- lists:reverse(Packets)]; -next_events(Packet) -> - {next_event, cast, {incoming, Packet}}. - -%%-------------------------------------------------------------------- -%% Handle incoming packet - -handle_incoming(Packet = ?PACKET(Type), SuccFun, - State = #state{chan_state = ChanState}) -> - _ = inc_incoming_stats(Type), - ok = emqx_metrics:inc_recv(Packet), - ?LOG(debug, "RECV ~s", [emqx_packet:format(Packet)]), - case emqx_channel:handle_in(Packet, ChanState) of - {ok, NChanState} -> - SuccFun(State#state{chan_state = NChanState}); - {ok, OutPacket, NChanState} -> - handle_outgoing(OutPacket, SuccFun, - State#state{chan_state = NChanState}); - {error, Reason, NChanState} -> - shutdown(Reason, State#state{chan_state = NChanState}); - {stop, Error, NChanState} -> - stop(Error, State#state{chan_state = NChanState}) - end. - -%%-------------------------------------------------------------------- -%% Handle outgoing packets - -handle_outgoing(Packets, SuccFun, State = #state{chan_state = ChanState}) - when is_list(Packets) -> - ProtoVer = emqx_channel:proto_ver(ChanState), - IoData = lists:foldl( - fun(Packet = ?PACKET(Type), Acc) -> - ?LOG(debug, "SEND ~s", [emqx_packet:format(Packet)]), - _ = inc_outgoing_stats(Type), - [emqx_frame:serialize(Packet, ProtoVer)|Acc] - end, [], Packets), - send(lists:reverse(IoData), SuccFun, State); - -handle_outgoing(Packet = ?PACKET(Type), SuccFun, State = #state{chan_state = ChanState}) -> - ?LOG(debug, "SEND ~s", [emqx_packet:format(Packet)]), - _ = inc_outgoing_stats(Type), - ProtoVer = emqx_channel:proto_ver(ChanState), - IoData = emqx_frame:serialize(Packet, ProtoVer), - send(IoData, SuccFun, State). - -%%-------------------------------------------------------------------- -%% Send data - -send(IoData, SuccFun, State = #state{transport = Transport, socket = Socket}) -> - Oct = iolist_size(IoData), - ok = emqx_metrics:inc('bytes.sent', Oct), - case Transport:async_send(Socket, IoData) of - ok -> SuccFun(maybe_gc(1, Oct, State)); - {error, Reason} -> - shutdown(Reason, State) - end. - -%%-------------------------------------------------------------------- -%% Ensure keepalive - -ensure_keepalive(0, State) -> - ignore; -ensure_keepalive(Interval, State = #state{transport = Transport, - socket = Socket, - chan_state = ChanState}) -> - StatFun = fun() -> - case Transport:getstat(Socket, [recv_oct]) of - {ok, [{recv_oct, RecvOct}]} -> - {ok, RecvOct}; - Error -> Error - end - end, - Backoff = emqx_zone:get_env(emqx_channel:zone(ChanState), - keepalive_backoff, 0.75), - emqx_keepalive:start(StatFun, round(Interval * Backoff), {keepalive, check}). - -%%-------------------------------------------------------------------- -%% Ensure rate limit - -ensure_rate_limit(State = #state{rate_limit = Rl, pub_limit = Pl}) -> - Limiters = [{Pl, #state.pub_limit, emqx_pd:reset_counter(incoming_pubs)}, - {Rl, #state.rate_limit, emqx_pd:reset_counter(incoming_bytes)}], - ensure_rate_limit(Limiters, State). - -ensure_rate_limit([], State) -> - State; -ensure_rate_limit([{undefined, _Pos, _Cnt}|Limiters], State) -> - ensure_rate_limit(Limiters, State); -ensure_rate_limit([{Rl, Pos, Cnt}|Limiters], State) -> - case esockd_rate_limit:check(Cnt, Rl) of - {0, Rl1} -> - ensure_rate_limit(Limiters, setelement(Pos, State, Rl1)); - {Pause, Rl1} -> - ?LOG(debug, "Rate limit pause connection ~pms", [Pause]), - TRef = erlang:send_after(Pause, self(), activate_socket), - setelement(Pos, State#state{conn_state = blocked, limit_timer = TRef}, Rl1) - end. - -%%-------------------------------------------------------------------- -%% Activate Socket - -activate_socket(#state{conn_state = blocked}) -> - ok; -activate_socket(#state{transport = Transport, - socket = Socket, - active_n = N}) -> - Transport:setopts(Socket, [{active, N}]). - -%%-------------------------------------------------------------------- -%% Inc incoming/outgoing stats - -inc_incoming_stats(Type) -> - emqx_pd:update_counter(recv_pkt, 1), - case Type == ?PUBLISH of - true -> - emqx_pd:update_counter(recv_msg, 1), - emqx_pd:update_counter(incoming_pubs, 1); - false -> ok - end. - -inc_outgoing_stats(Type) -> - emqx_pd:update_counter(send_pkt, 1), - (Type == ?PUBLISH) - andalso emqx_pd:update_counter(send_msg, 1). - -%%-------------------------------------------------------------------- -%% Ensure stats timer - -ensure_stats_timer(State = #state{stats_timer = undefined, - idle_timeout = IdleTimeout}) -> - State#state{stats_timer = emqx_misc:start_timer(IdleTimeout, emit_stats)}; -%% disabled or timer existed -ensure_stats_timer(State) -> State. - -%%-------------------------------------------------------------------- -%% Maybe GC - -maybe_gc(_Cnt, _Oct, State = #state{gc_state = undefined}) -> - State; -maybe_gc(Cnt, Oct, State = #state{gc_state = GCSt}) -> - {_, GCSt1} = emqx_gc:run(Cnt, Oct, GCSt), - %% TODO: gc metric? - State#state{gc_state = GCSt1}. - -%%-------------------------------------------------------------------- -%% Helper functions - --compile({inline, [reply/3]}). -reply(From, Reply, State) -> - {keep_state, State, [{reply, From, Reply}]}. - --compile({inline, [keep_state/1]}). -keep_state(State) -> - {keep_state, State}. - --compile({inline, [shutdown/2]}). -shutdown(Reason, State) -> - stop({shutdown, Reason}, State). - --compile({inline, [stop/2]}). -stop(Reason, State) -> - {stop, Reason, State}. - diff --git a/src/emqx_endpoint.erl b/src/emqx_endpoint.erl index 1529698aa..dee80a099 100644 --- a/src/emqx_endpoint.erl +++ b/src/emqx_endpoint.erl @@ -21,6 +21,7 @@ %% APIs -export([ new/0 , new/1 + , info/1 ]). -export([ zone/1 @@ -36,25 +37,25 @@ -export_type([endpoint/0]). --opaque(endpoint() :: - {endpoint, - #{zone := emqx_types:zone(), - peername := emqx_types:peername(), - sockname => emqx_types:peername(), - client_id := emqx_types:client_id(), - username := emqx_types:username(), - peercert := esockd_peercert:peercert(), - is_superuser := boolean(), - mountpoint := maybe(binary()), - ws_cookie := maybe(list()), - password => binary(), - auth_result => emqx_types:auth_result(), - anonymous => boolean(), - atom() => term() - } - }). +-type(st() :: #{zone := emqx_types:zone(), + conn_mod := maybe(module()), + peername := emqx_types:peername(), + sockname := emqx_types:peername(), + client_id := emqx_types:client_id(), + username := emqx_types:username(), + peercert := esockd_peercert:peercert(), + is_superuser := boolean(), + mountpoint := maybe(binary()), + ws_cookie := maybe(list()), + password => binary(), + auth_result => emqx_types:auth_result(), + anonymous => boolean(), + atom() => term() + }). --define(Endpoint(M), {endpoint, M}). +-opaque(endpoint() :: {endpoint, st()}). + +-define(Endpoint(St), {endpoint, St}). -define(Default, #{is_superuser => false, anonymous => false @@ -68,6 +69,9 @@ new() -> new(M) when is_map(M) -> ?Endpoint(maps:merge(?Default, M)). +info(?Endpoint(M)) -> + maps:to_list(M). + -spec(zone(endpoint()) -> emqx_zone:zone()). zone(?Endpoint(#{zone := Zone})) -> Zone. diff --git a/src/emqx_frame.erl b/src/emqx_frame.erl index 3a18ab297..c0115cea8 100644 --- a/src/emqx_frame.erl +++ b/src/emqx_frame.erl @@ -35,13 +35,13 @@ ]). -type(options() :: #{max_size => 1..?MAX_PACKET_SIZE, - version => emqx_mqtt:version() + version => emqx_types:version() }). -opaque(parse_state() :: {none, options()} | {more, cont_fun()}). -opaque(parse_result() :: {ok, parse_state()} - | {ok, emqx_mqtt:packet(), binary(), parse_state()}). + | {ok, emqx_types:packet(), binary(), parse_state()}). -type(cont_fun() :: fun((binary()) -> parse_result())). @@ -385,11 +385,11 @@ parse_binary_data(<>) -> %% Serialize MQTT Packet %%-------------------------------------------------------------------- --spec(serialize(emqx_mqtt:packet()) -> iodata()). +-spec(serialize(emqx_types:packet()) -> iodata()). serialize(Packet) -> serialize(Packet, ?MQTT_PROTO_V4). --spec(serialize(emqx_mqtt:packet(), emqx_mqtt:version()) -> iodata()). +-spec(serialize(emqx_types:packet(), emqx_types:version()) -> iodata()). serialize(#mqtt_packet{header = Header, variable = Variable, payload = Payload}, Ver) -> diff --git a/src/emqx_gc.erl b/src/emqx_gc.erl index 83264ecee..974d5b48e 100644 --- a/src/emqx_gc.erl +++ b/src/emqx_gc.erl @@ -15,8 +15,9 @@ %%-------------------------------------------------------------------- %%-------------------------------------------------------------------- -%% @doc This module manages an opaque collection of statistics data used -%% to force garbage collection on `self()' process when hitting thresholds. +%% @doc +%% This module manages an opaque collection of statistics data used to +%% force garbage collection on `self()' process when hitting thresholds. %% Namely: %% (1) Total number of messages passed through %% (2) Total data volume passed through @@ -41,9 +42,9 @@ -type(st() :: #{cnt => {integer(), integer()}, oct => {integer(), integer()}}). --opaque(gc_state() :: {?MODULE, st()}). +-opaque(gc_state() :: {gc_state, st()}). --define(GCS(St), {?MODULE, St}). +-define(GCS(St), {gc_state, St}). -define(disabled, disabled). -define(ENABLED(X), (is_integer(X) andalso X > 0)). diff --git a/src/emqx_inflight.erl b/src/emqx_inflight.erl index 88c4796c4..1bf70f4c8 100644 --- a/src/emqx_inflight.erl +++ b/src/emqx_inflight.erl @@ -22,7 +22,7 @@ , lookup/2 , insert/3 , update/3 - , update_size/2 + , resize/2 , delete/2 , values/1 , to_list/1 @@ -39,11 +39,11 @@ -type(max_size() :: pos_integer()). --opaque(inflight() :: {?MODULE, max_size(), gb_trees:tree()}). +-opaque(inflight() :: {inflight, max_size(), gb_trees:tree()}). --define(Inflight(Tree), {?MODULE, _MaxSize, Tree}). +-define(Inflight(Tree), {inflight, _MaxSize, Tree}). --define(Inflight(MaxSize, Tree), {?MODULE, MaxSize, (Tree)}). +-define(Inflight(MaxSize, Tree), {inflight, MaxSize, (Tree)}). %%-------------------------------------------------------------------- %% APIs @@ -73,8 +73,8 @@ delete(Key, ?Inflight(MaxSize, Tree)) -> update(Key, Val, ?Inflight(MaxSize, Tree)) -> ?Inflight(MaxSize, gb_trees:update(Key, Val, Tree)). --spec(update_size(integer(), inflight()) -> inflight()). -update_size(MaxSize, ?Inflight(Tree)) -> +-spec(resize(integer(), inflight()) -> inflight()). +resize(MaxSize, ?Inflight(Tree)) -> ?Inflight(MaxSize, Tree). -spec(is_full(inflight()) -> boolean()). diff --git a/src/emqx_listeners.erl b/src/emqx_listeners.erl index 94babe7fd..b39873879 100644 --- a/src/emqx_listeners.erl +++ b/src/emqx_listeners.erl @@ -46,13 +46,15 @@ start() -> -spec(start_listener(listener()) -> {ok, pid()} | {error, term()}). start_listener({Proto, ListenOn, Options}) -> - case start_listener(Proto, ListenOn, Options) of - {ok, _} -> - io:format("Start mqtt:~s listener on ~s successfully.~n", [Proto, format(ListenOn)]); + StartRet = start_listener(Proto, ListenOn, Options), + case StartRet of + {ok, _} -> io:format("Start mqtt:~s listener on ~s successfully.~n", + [Proto, format(ListenOn)]); {error, Reason} -> io:format(standard_error, "Failed to start mqtt:~s listener on ~s - ~p~n!", [Proto, format(ListenOn), Reason]) - end. + end, + StartRet. %% Start MQTT/TCP listener -spec(start_listener(esockd:proto(), esockd:listen_on(), [esockd:option()]) @@ -66,16 +68,18 @@ start_listener(Proto, ListenOn, Options) when Proto == ssl; Proto == tls -> %% Start MQTT/WS listener start_listener(Proto, ListenOn, Options) when Proto == http; Proto == ws -> - start_http_listener(fun cowboy:start_clear/3, 'mqtt:ws', ListenOn, ranch_opts(Options), ws_opts(Options)); + start_http_listener(fun cowboy:start_clear/3, 'mqtt:ws', ListenOn, + ranch_opts(Options), ws_opts(Options)); %% Start MQTT/WSS listener start_listener(Proto, ListenOn, Options) when Proto == https; Proto == wss -> - start_http_listener(fun cowboy:start_tls/3, 'mqtt:wss', ListenOn, ranch_opts(Options), ws_opts(Options)). + start_http_listener(fun cowboy:start_tls/3, 'mqtt:wss', ListenOn, + ranch_opts(Options), ws_opts(Options)). start_mqtt_listener(Name, ListenOn, Options) -> SockOpts = esockd:parse_opt(Options), esockd:open(Name, ListenOn, merge_default(SockOpts), - {emqx_connection, start_link, [Options -- SockOpts]}). + {emqx_channel, start_link, [Options -- SockOpts]}). start_http_listener(Start, Name, ListenOn, RanchOpts, ProtoOpts) -> Start(Name, with_port(ListenOn, RanchOpts), ProtoOpts). @@ -84,8 +88,10 @@ mqtt_path(Options) -> proplists:get_value(mqtt_path, Options, "/mqtt"). ws_opts(Options) -> - Dispatch = cowboy_router:compile([{'_', [{mqtt_path(Options), emqx_ws_connection, Options}]}]), - #{env => #{dispatch => Dispatch}, proxy_header => proplists:get_value(proxy_protocol, Options, false)}. + WsPaths = [{mqtt_path(Options), emqx_ws_channel, Options}], + Dispatch = cowboy_router:compile([{'_', WsPaths}]), + ProxyProto = proplists:get_value(proxy_protocol, Options, false), + #{env => #{dispatch => Dispatch}, proxy_header => ProxyProto}. ranch_opts(Options) -> NumAcceptors = proplists:get_value(acceptors, Options, 4), @@ -134,13 +140,15 @@ stop() -> -spec(stop_listener(listener()) -> ok | {error, term()}). stop_listener({Proto, ListenOn, Opts}) -> - case stop_listener(Proto, ListenOn, Opts) of - ok -> - io:format("Stop mqtt:~s listener on ~s successfully.~n", [Proto, format(ListenOn)]); + StopRet = stop_listener(Proto, ListenOn, Opts), + case StopRet of + ok -> io:format("Stop mqtt:~s listener on ~s successfully.~n", + [Proto, format(ListenOn)]); {error, Reason} -> io:format(standard_error, "Failed to stop mqtt:~s listener on ~s - ~p~n.", [Proto, format(ListenOn), Reason]) - end. + end, + StopRet. -spec(stop_listener(esockd:proto(), esockd:listen_on(), [esockd:option()]) -> ok | {error, term()}). diff --git a/src/emqx_message.erl b/src/emqx_message.erl index a00928af8..8617c8608 100644 --- a/src/emqx_message.erl +++ b/src/emqx_message.erl @@ -76,7 +76,7 @@ make(From, Topic, Payload) -> make(From, ?QOS_0, Topic, Payload). -spec(make(atom() | emqx_types:client_id(), - emqx_mqtt_types:qos(), + emqx_types:qos(), emqx_topic:topic(), emqx_types:payload()) -> emqx_types:message()). make(From, QoS, Topic, Payload) when ?QOS_0 =< QoS, QoS =< ?QOS_2 -> @@ -91,7 +91,7 @@ make(From, QoS, Topic, Payload) when ?QOS_0 =< QoS, QoS =< ?QOS_2 -> -spec(id(emqx_types:message()) -> maybe(binary())). id(#message{id = Id}) -> Id. --spec(qos(emqx_types:message()) -> emqx_mqtt_types:qos()). +-spec(qos(emqx_types:message()) -> emqx_types:qos()). qos(#message{qos = QoS}) -> QoS. -spec(from(emqx_types:message()) -> atom() | binary()). diff --git a/src/emqx_misc.erl b/src/emqx_misc.erl index e04fd606d..fdc7c5dc3 100644 --- a/src/emqx_misc.erl +++ b/src/emqx_misc.erl @@ -122,19 +122,20 @@ check([{Pred, Result} | Rest]) -> is_message_queue_too_long(Qlength, Max) -> is_enabled(Max) andalso Qlength > Max. -is_enabled(Max) -> is_integer(Max) andalso Max > ?DISABLED. +is_enabled(Max) -> + is_integer(Max) andalso Max > ?DISABLED. proc_info(Key) -> {Key, Value} = erlang:process_info(self(), Key), Value. -%% @doc Drain delivers from channel's mailbox. +%% @doc Drain delivers from the channel's mailbox. drain_deliver(Acc) -> receive Deliver = {deliver, _Topic, _Msg} -> drain_deliver([Deliver|Acc]) after 0 -> - lists:reverse(Acc) + lists:reverse(Acc) end. %% @doc Drain process down events. @@ -150,6 +151,6 @@ drain_down(Cnt, Acc) -> {'DOWN', _MRef, process, Pid, _Reason} -> drain_down(Cnt - 1, [Pid|Acc]) after 0 -> - lists:reverse(Acc) + drain_down(0, Acc) end. diff --git a/src/emqx_mod_presence.erl b/src/emqx_mod_presence.erl index 151ffbb7f..9c96bd774 100644 --- a/src/emqx_mod_presence.erl +++ b/src/emqx_mod_presence.erl @@ -46,9 +46,9 @@ load(Env) -> on_client_connected(#{client_id := ClientId, username := Username, peername := {IpAddr, _}}, ConnAck, ConnAttrs, Env) -> - Attrs = maps:filter(fun(K, _) -> - lists:member(K, ?ATTR_KEYS) - end, ConnAttrs), + Attrs = #{},%maps:filter(fun(K, _) -> + % lists:member(K, ?ATTR_KEYS) + % end, ConnAttrs), case emqx_json:safe_encode(Attrs#{clientid => ClientId, username => Username, ipaddress => iolist_to_binary(esockd_net:ntoa(IpAddr)), diff --git a/src/emqx_mountpoint.erl b/src/emqx_mountpoint.erl index 80a65b743..4b820fdc7 100644 --- a/src/emqx_mountpoint.erl +++ b/src/emqx_mountpoint.erl @@ -35,15 +35,23 @@ mount(undefined, Any) -> Any; +mount(MountPoint, Topic) when is_binary(Topic) -> + <>; mount(MountPoint, Msg = #message{topic = Topic}) -> Msg#message{topic = <>}; - mount(MountPoint, TopicFilters) when is_list(TopicFilters) -> [{<>, SubOpts} || {Topic, SubOpts} <- TopicFilters]. unmount(undefined, Msg) -> Msg; +%% TODO: Fixme later +unmount(MountPoint, Topic) when is_binary(Topic) -> + try split_binary(Topic, byte_size(MountPoint)) of + {MountPoint, Topic1} -> Topic1 + catch + error:badarg-> Topic + end; unmount(MountPoint, Msg = #message{topic = Topic}) -> try split_binary(Topic, byte_size(MountPoint)) of {MountPoint, Topic1} -> Msg#message{topic = Topic1} diff --git a/src/emqx_mqtt.erl b/src/emqx_mqtt.erl deleted file mode 100644 index bd10bce40..000000000 --- a/src/emqx_mqtt.erl +++ /dev/null @@ -1,60 +0,0 @@ -%%-------------------------------------------------------------------- -%% Copyright (c) 2019 EMQ Technologies Co., Ltd. All Rights Reserved. -%% -%% Licensed under the Apache License, Version 2.0 (the "License"); -%% you may not use this file except in compliance with the License. -%% You may obtain a copy of the License at -%% -%% http://www.apache.org/licenses/LICENSE-2.0 -%% -%% Unless required by applicable law or agreed to in writing, software -%% distributed under the License is distributed on an "AS IS" BASIS, -%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -%% See the License for the specific language governing permissions and -%% limitations under the License. -%%-------------------------------------------------------------------- - -%% MQTT Types --module(emqx_mqtt). - --include("emqx_mqtt.hrl"). - --export_type([ version/0 - , qos/0 - , qos_name/0 - ]). - --export_type([ connack/0 - , reason_code/0 - ]). - --export_type([ properties/0 - , subopts/0 - ]). - --export_type([topic_filters/0]). - --export_type([ packet_id/0 - , packet_type/0 - , packet/0 - ]). - --type(qos() :: ?QOS_0 | ?QOS_1 | ?QOS_2). --type(version() :: ?MQTT_PROTO_V3 | ?MQTT_PROTO_V4 | ?MQTT_PROTO_V5). --type(qos_name() :: qos0 | at_most_once | - qos1 | at_least_once | - qos2 | exactly_once). --type(packet_type() :: ?RESERVED..?AUTH). --type(connack() :: ?CONNACK_ACCEPT..?CONNACK_AUTH). --type(reason_code() :: 0..16#FF). --type(packet_id() :: 1..16#FFFF). --type(properties() :: #{atom() => term()}). --type(subopts() :: #{rh := 0 | 1 | 2, - rap := 0 | 1, - nl := 0 | 1, - qos := qos(), - rc => reason_code() - }). --type(topic_filters() :: list({emqx_topic:topic(), subopts()})). --type(packet() :: #mqtt_packet{}). - diff --git a/src/emqx_mqtt_caps.erl b/src/emqx_mqtt_caps.erl index f941b1e24..a95641723 100644 --- a/src/emqx_mqtt_caps.erl +++ b/src/emqx_mqtt_caps.erl @@ -32,7 +32,7 @@ max_clientid_len => integer(), max_topic_alias => integer(), max_topic_levels => integer(), - max_qos_allowed => emqx_mqtt_types:qos(), + max_qos_allowed => emqx_types:qos(), mqtt_retain_available => boolean(), mqtt_shared_subscription => boolean(), mqtt_wildcard_subscription => boolean()}). @@ -57,7 +57,7 @@ mqtt_shared_subscription, mqtt_wildcard_subscription]). --spec(check_pub(emqx_types:zone(), map()) -> ok | {error, emqx_mqtt_types:reason_code()}). +-spec(check_pub(emqx_types:zone(), map()) -> ok | {error, emqx_types:reason_code()}). check_pub(Zone, Props) when is_map(Props) -> do_check_pub(Props, maps:to_list(get_caps(Zone, publish))). @@ -80,8 +80,8 @@ do_check_pub(Props, [{max_topic_alias, _} | Caps]) -> do_check_pub(Props, [{mqtt_retain_available, _}|Caps]) -> do_check_pub(Props, Caps). --spec(check_sub(emqx_types:zone(), emqx_mqtt_types:topic_filters()) - -> {ok | error, emqx_mqtt_types:topic_filters()}). +-spec(check_sub(emqx_types:zone(), emqx_types:topic_filters()) + -> {ok | error, emqx_types:topic_filters()}). check_sub(Zone, TopicFilters) -> Caps = maps:to_list(get_caps(Zone, subscribe)), lists:foldr(fun({Topic, Opts}, {Ok, Result}) -> @@ -154,3 +154,4 @@ with_env(Zone, Key, InitFun) -> Caps; ZoneCaps -> ZoneCaps end. + diff --git a/src/emqx_packet.erl b/src/emqx_packet.erl index ea917c8bf..71d1fb116 100644 --- a/src/emqx_packet.erl +++ b/src/emqx_packet.erl @@ -29,7 +29,7 @@ ]). %% @doc Protocol name of version --spec(protocol_name(emqx_mqtt_types:version()) -> binary()). +-spec(protocol_name(emqx_types:version()) -> binary()). protocol_name(?MQTT_PROTO_V3) -> <<"MQIsdp">>; protocol_name(?MQTT_PROTO_V4) -> @@ -38,7 +38,7 @@ protocol_name(?MQTT_PROTO_V5) -> <<"MQTT">>. %% @doc Name of MQTT packet type --spec(type_name(emqx_mqtt_types:packet_type()) -> atom()). +-spec(type_name(emqx_types:packet_type()) -> atom()). type_name(Type) when Type > ?RESERVED andalso Type =< ?AUTH -> lists:nth(Type, ?TYPE_NAMES). @@ -46,7 +46,7 @@ type_name(Type) when Type > ?RESERVED andalso Type =< ?AUTH -> %% Validate MQTT Packet %%-------------------------------------------------------------------- --spec(validate(emqx_mqtt_types:packet()) -> true). +-spec(validate(emqx_types:packet()) -> true). validate(?SUBSCRIBE_PACKET(_PacketId, _Properties, [])) -> error(topic_filters_invalid); validate(?SUBSCRIBE_PACKET(PacketId, Properties, TopicFilters)) -> @@ -113,8 +113,8 @@ validate_qos(QoS) when ?QOS_0 =< QoS, QoS =< ?QOS_2 -> validate_qos(_) -> error(bad_qos). %% @doc From message to packet --spec(from_message(emqx_mqtt_types:packet_id(), emqx_types:message()) - -> emqx_mqtt_types:packet()). +-spec(from_message(emqx_types:packet_id(), emqx_types:message()) + -> emqx_types:packet()). from_message(PacketId, #message{qos = QoS, flags = Flags, headers = Headers, topic = Topic, payload = Payload}) -> Flags1 = if Flags =:= undefined -> @@ -142,7 +142,7 @@ publish_props(Headers) -> 'Message-Expiry-Interval'], Headers). %% @doc Message from Packet --spec(to_message(emqx_types:credentials(), emqx_mqtt_types:packet()) +-spec(to_message(emqx_types:credentials(), emqx_ypes:packet()) -> emqx_types:message()). to_message(#{client_id := ClientId, username := Username, peername := Peername}, #mqtt_packet{header = #mqtt_packet_header{type = ?PUBLISH, @@ -177,7 +177,7 @@ merge_props(Headers, Props) -> maps:merge(Headers, Props). %% @doc Format packet --spec(format(emqx_mqtt_types:packet()) -> iolist()). +-spec(format(emqx_types:packet()) -> iolist()). format(#mqtt_packet{header = Header, variable = Variable, payload = Payload}) -> format_header(Header, format_variable(Variable, Payload)). diff --git a/src/emqx_protocol.erl b/src/emqx_protocol.erl new file mode 100644 index 000000000..9bbb63bf4 --- /dev/null +++ b/src/emqx_protocol.erl @@ -0,0 +1,594 @@ +%%-------------------------------------------------------------------- +%% Copyright (c) 2019 EMQ Technologies Co., Ltd. All Rights Reserved. +%% +%% Licensed under the Apache License, Version 2.0 (the "License"); +%% you may not use this file except in compliance with the License. +%% You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, software +%% distributed under the License is distributed on an "AS IS" BASIS, +%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +%% See the License for the specific language governing permissions and +%% limitations under the License. +%%-------------------------------------------------------------------- + +%% MQTT Protocol +-module(emqx_protocol). + +-include("emqx.hrl"). +-include("emqx_mqtt.hrl"). +-include("logger.hrl"). +-include("types.hrl"). + +-logger_header("[Protocol]"). + +-export([ info/1 + , info/2 + , attrs/1 + , stats/1 + ]). + +-export([ client_id/1 + , session/1 + ]). + +-export([ init/2 + , handle_in/2 + , handle_deliver/2 + , handle_out/2 + , handle_timeout/3 + , terminate/2 + ]). + +-export_type([proto_state/0]). + +-record(protocol, { + proto_name :: binary(), + proto_ver :: emqx_types:version(), + client :: emqx_types:client(), + session :: emqx_session:session(), + mountfun :: fun((emqx_topic:topic()) -> emqx_topic:topic()), + keepalive :: non_neg_integer(), + will_msg :: emqx_types:message(), + enable_acl :: boolean(), + is_bridge :: boolean(), + connected :: boolean(), + connected_at :: erlang:timestamp(), + topic_aliases :: map(), + alias_maximum :: map() + }). + +-opaque(proto_state() :: #protocol{}). + +info(#protocol{client = Client, session = Session}) -> + lists:append([maps:to_list(Client), emqx_session:info(Session)]). + +info(zone, #protocol{client = #{zone := Zone}}) -> + Zone; +info(proto_name, #protocol{proto_name = ProtoName}) -> + ProtoName; +info(proto_ver, #protocol{proto_ver = ProtoVer}) -> + ProtoVer; +info(keepalive, #protocol{keepalive = Keepalive}) -> + Keepalive. + +attrs(#protocol{}) -> + #{}. + +stats(#protocol{}) -> + []. + +-spec(client_id(proto_state()) -> emqx_types:client_id()). +client_id(#protocol{client = #{client_id := ClientId}}) -> + ClientId. + +-spec(session(proto_state()) -> emqx_session:session()). +session(#protocol{session = Session}) -> + Session. + +-spec(init(map(), proplists:proplist()) -> proto_state()). +init(ConnInfo, Options) -> + Zone = proplists:get_value(zone, Options), + Peercert = maps:get(peercert, ConnInfo, nossl), + Username = peer_cert_as_username(Peercert, Options), + Mountpoint = emqx_zone:get_env(Zone, mountpoint), + Client = maps:merge(#{zone => Zone, + username => Username, + mountpoint => Mountpoint + }, ConnInfo), + EnableAcl = emqx_zone:get_env(Zone, enable_acl, false), + MountFun = fun(Topic) -> + emqx_mountpoint:mount(Mountpoint, Topic) + end, + #protocol{client = Client, + mountfun = MountFun, + enable_acl = EnableAcl, + is_bridge = false, + connected = false + }. + +peer_cert_as_username(Peercert, Options) -> + case proplists:get_value(peer_cert_as_username, Options) of + cn -> esockd_peercert:common_name(Peercert); + dn -> esockd_peercert:subject(Peercert); + crt -> Peercert; + _ -> undefined + end. + +%%-------------------------------------------------------------------- +%% Handle incoming packet +%%-------------------------------------------------------------------- + +-spec(handle_in(emqx_types:packet(), proto_state()) + -> {ok, proto_state()} + | {ok, emqx_types:packet(), proto_state()} + | {error, Reason :: term(), proto_state()} + | {stop, Error :: atom(), proto_state()}). +handle_in(?CONNECT_PACKET( + #mqtt_packet_connect{proto_name = ProtoName, + proto_ver = ProtoVer, + is_bridge = IsBridge, + client_id = ClientId, + username = Username, + password = Password, + keepalive = Keepalive} = ConnPkt), + PState = #protocol{client = Client}) -> + Client1 = maps:merge(Client, #{client_id => ClientId, + username => Username, + password => Password + }), + emqx_logger:set_metadata_client_id(ClientId), + WillMsg = emqx_packet:will_msg(ConnPkt), + PState1 = PState#protocol{client = Client1, + proto_name = ProtoName, + proto_ver = ProtoVer, + is_bridge = IsBridge, + keepalive = Keepalive, + will_msg = WillMsg + }, + %% fun validate_packet/2, + case pipeline([fun check_connect/2, + fun handle_connect/2], ConnPkt, PState1) of + {ok, SP, PState2} -> + handle_out({connack, ?RC_SUCCESS, sp(SP)}, PState2); + {error, ReasonCode} -> + handle_out({connack, ReasonCode}, PState1); + {error, ReasonCode, PState2} -> + handle_out({connack, ReasonCode}, PState2) + end; + +handle_in(Packet = ?PUBLISH_PACKET(QoS, Topic, PacketId), PState) -> + case pipeline([fun validate_packet/2, + fun check_pub_caps/2, + fun check_pub_acl/2, + fun handle_publish/2], Packet, PState) of + {error, ReasonCode} -> + ?LOG(warning, "Cannot publish qos~w message to ~s due to ~s", + [QoS, Topic, emqx_reason_codes:text(ReasonCode)]), + handle_out(case QoS of + ?QOS_0 -> {puberr, ReasonCode}; + ?QOS_1 -> {puback, PacketId, ReasonCode}; + ?QOS_2 -> {pubrec, PacketId, ReasonCode} + end, PState); + Result -> Result + end; + +handle_in(?PUBACK_PACKET(PacketId, ReasonCode), PState = #protocol{session = Session}) -> + case emqx_session:puback(PacketId, ReasonCode, Session) of + {ok, NSession} -> + {ok, PState#protocol{session = NSession}}; + {error, _NotFound} -> + %% TODO: metrics? error msg? + {ok, PState} + end; + +handle_in(?PUBREC_PACKET(PacketId, ReasonCode), PState = #protocol{session = Session}) -> + case emqx_session:pubrec(PacketId, ReasonCode, Session) of + {ok, NSession} -> + handle_out({pubrel, PacketId}, PState#protocol{session = NSession}); + {error, ReasonCode} -> + handle_out({pubrel, PacketId, ReasonCode}, PState) + end; + +handle_in(?PUBREL_PACKET(PacketId, ReasonCode), PState = #protocol{session = Session}) -> + case emqx_session:pubrel(PacketId, ReasonCode, Session) of + {ok, NSession} -> + handle_out({pubcomp, PacketId}, PState#protocol{session = NSession}); + {error, ReasonCode} -> + handle_out({pubcomp, PacketId, ReasonCode}, PState) + end; + +handle_in(?PUBCOMP_PACKET(PacketId, ReasonCode), + PState = #protocol{session = Session}) -> + case emqx_session:pubcomp(PacketId, ReasonCode, Session) of + {ok, NSession} -> + {ok, PState#protocol{session = NSession}}; + {error, _ReasonCode} -> + %% TODO: How to handle the reason code? + {ok, PState} + end; + +handle_in(Packet = ?SUBSCRIBE_PACKET(PacketId, Properties, TopicFilters), + PState = #protocol{client = Client}) -> + case validate(Packet) of + ok -> ok = emqx_hooks:run('client.subscribe', + [Client, Properties, TopicFilters]), + TopicFilters1 = enrich_subid(Properties, TopicFilters), + {ReasonCodes, PState1} = handle_subscribe(TopicFilters1, PState), + handle_out({suback, PacketId, ReasonCodes}, PState1); + {error, ReasonCode} -> + handle_out({disconnect, ReasonCode}, PState) + end; + +handle_in(Packet = ?UNSUBSCRIBE_PACKET(PacketId, Properties, TopicFilters), + PState = #protocol{client = Client}) -> + case validate(Packet) of + ok -> ok = emqx_hooks:run('client.unsubscribe', + [Client, Properties, TopicFilters]), + {ReasonCodes, PState1} = handle_unsubscribe(TopicFilters, PState), + handle_out({unsuback, PacketId, ReasonCodes}, PState1); + {error, ReasonCode} -> + handle_out({disconnect, ReasonCode}, PState) + end; + +handle_in(?PACKET(?PINGREQ), PState) -> + {ok, ?PACKET(?PINGRESP), PState}; + +handle_in(?DISCONNECT_PACKET(?RC_SUCCESS), PState) -> + %% Clear will msg + {stop, normal, PState#protocol{will_msg = undefined}}; + +handle_in(?DISCONNECT_PACKET(RC), PState = #protocol{proto_ver = Ver}) -> + %% TODO: + %% {stop, {shutdown, abnormal_disconnet}, PState}; + {sto, {shutdown, emqx_reason_codes:name(RC, Ver)}, PState}; + +handle_in(?AUTH_PACKET(), PState) -> + %%TODO: implement later. + {ok, PState}; + +handle_in(Packet, PState) -> + io:format("In: ~p~n", [Packet]), + {ok, PState}. + +%%-------------------------------------------------------------------- +%% Handle delivers +%%-------------------------------------------------------------------- + +handle_deliver(Delivers, PState = #protocol{client = Client, session = Session}) + when is_list(Delivers) -> + case emqx_session:handle(Delivers, Session) of + {ok, Publishes, NSession} -> + Packets = lists:map(fun({publish, PacketId, Msg}) -> + Msg0 = emqx_hooks:run_fold('message.deliver', [Client], Msg), + Msg1 = emqx_message:update_expiry(Msg0), + Msg2 = emqx_mountpoint:unmount(maps:get(mountpoint, Client), Msg1), + emqx_packet:from_message(PacketId, Msg2) + end, Publishes), + {ok, Packets, PState#protocol{session = NSession}}; + {ok, NSession} -> + {ok, PState#protocol{session = NSession}} + end. + +%%-------------------------------------------------------------------- +%% Handle outgoing packet +%%-------------------------------------------------------------------- + +handle_out({connack, ?RC_SUCCESS, SP}, PState = #protocol{client = Client}) -> + ok = emqx_hooks:run('client.connected', [Client, ?RC_SUCCESS, info(PState)]), + Props = #{}, %% TODO: ... + {ok, ?CONNACK_PACKET(?RC_SUCCESS, SP, Props), PState}; + +handle_out({connack, ReasonCode}, PState = #protocol{client = Client, + proto_ver = ProtoVer}) -> + ok = emqx_hooks:run('client.connected', [Client, ReasonCode, info(PState)]), + ReasonCode1 = if + ProtoVer == ?MQTT_PROTO_V5 -> ReasonCode; + true -> emqx_reason_codes:compat(connack, ReasonCode) + end, + Reason = emqx_reason_codes:name(ReasonCode1, ProtoVer), + {error, Reason, ?CONNACK_PACKET(ReasonCode1), PState}; + +handle_out({publish, PacketId, Msg}, PState = #protocol{client = Client}) -> + Msg0 = emqx_hooks:run_fold('message.deliver', [Client], Msg), + Msg1 = emqx_message:update_expiry(Msg0), + Msg2 = emqx_mountpoint:unmount(maps:get(mountpoint, Client), Msg1), + {ok, emqx_packet:from_message(PacketId, Msg2), PState}; + +handle_out({puberr, ReasonCode}, PState) -> + {ok, PState}; + +handle_out({puback, PacketId, ReasonCode}, PState) -> + {ok, ?PUBACK_PACKET(PacketId, ReasonCode), PState}; + +handle_out({pubrel, PacketId}, PState) -> + {ok, ?PUBREL_PACKET(PacketId), PState}; +handle_out({pubrel, PacketId, ReasonCode}, PState) -> + {ok, ?PUBREL_PACKET(PacketId, ReasonCode), PState}; + +handle_out({pubrec, PacketId, ReasonCode}, PState) -> + {ok, ?PUBREC_PACKET(PacketId, ReasonCode), PState}; + +handle_out({pubcomp, PacketId}, PState) -> + {ok, ?PUBCOMP_PACKET(PacketId), PState}; +handle_out({pubcomp, PacketId, ReasonCode}, PState) -> + {ok, ?PUBCOMP_PACKET(PacketId, ReasonCode), PState}; + +handle_out({suback, PacketId, ReasonCodes}, PState = #protocol{proto_ver = ?MQTT_PROTO_V5}) -> + %% TODO: ACL Deny + {ok, ?SUBACK_PACKET(PacketId, ReasonCodes), PState}; +handle_out({suback, PacketId, ReasonCodes}, PState) -> + %% TODO: ACL Deny + ReasonCodes1 = [emqx_reason_codes:compat(suback, RC) || RC <- ReasonCodes], + {ok, ?SUBACK_PACKET(PacketId, ReasonCodes1), PState}; + +handle_out({unsuback, PacketId, ReasonCodes}, PState = #protocol{proto_ver = ?MQTT_PROTO_V5}) -> + {ok, ?UNSUBACK_PACKET(PacketId, ReasonCodes), PState}; +%% Ignore reason codes if not MQTT5 +handle_out({unsuback, PacketId, _ReasonCodes}, PState) -> + {ok, ?UNSUBACK_PACKET(PacketId), PState}; + +handle_out(Packet, State) -> + io:format("Out: ~p~n", [Packet]), + {ok, State}. + +%%-------------------------------------------------------------------- +%% Handle timeout +%%-------------------------------------------------------------------- + +handle_timeout(TRef, Msg, PState = #protocol{session = Session}) -> + case emqx_session:timeout(TRef, Msg, Session) of + {ok, NSession} -> + {ok, PState#protocol{session = NSession}}; + {ok, Publishes, NSession} -> + %% TODO: handle out... + io:format("Timeout publishes: ~p~n", [Publishes]), + {ok, PState#protocol{session = NSession}} + end. + +terminate(Reason, _State) -> + io:format("Terminated for ~p~n", [Reason]), + ok. + +%%-------------------------------------------------------------------- +%% Check Connect Packet +%%-------------------------------------------------------------------- + +check_connect(_ConnPkt, PState) -> + {ok, PState}. + +%%-------------------------------------------------------------------- +%% Handle Connect Packet +%%-------------------------------------------------------------------- + +handle_connect(#mqtt_packet_connect{proto_name = ProtoName, + proto_ver = ProtoVer, + is_bridge = IsBridge, + clean_start = CleanStart, + keepalive = Keepalive, + properties = ConnProps, + client_id = ClientId, + username = Username, + password = Password + } = ConnPkt, + PState = #protocol{client = Client}) -> + case emqx_access_control:authenticate( + Client#{password => Password}) of + {ok, AuthResult} -> + Client1 = maps:merge(Client, AuthResult), + %% Open session + case open_session(ConnPkt, PState) of + {ok, Session, SP} -> + PState1 = PState#protocol{client = Client1, + session = Session, + connected = true, + connected_at = os:timestamp() + }, + ok = emqx_cm:register_channel(ClientId), + {ok, SP, PState1}; + {error, Error} -> + ?LOG(error, "Failed to open session: ~p", [Error]), + {error, ?RC_UNSPECIFIED_ERROR, PState#protocol{client = Client1}} + end; + {error, Reason} -> + ?LOG(warning, "Client ~s (Username: '~s') login failed for ~p", + [ClientId, Username, Reason]), + {error, emqx_reason_codes:connack_error(Reason), PState} + end. + +open_session(#mqtt_packet_connect{clean_start = CleanStart, + %%properties = ConnProps, + client_id = ClientId, + username = Username} = ConnPkt, + PState = #protocol{client = Client}) -> + emqx_cm:open_session(maps:merge(Client, #{clean_start => CleanStart, + max_inflight => 0, + expiry_interval => 0})). + +%%-------------------------------------------------------------------- +%% Handle Publish Message: Client -> Broker +%%-------------------------------------------------------------------- + +handle_publish(Packet = ?PUBLISH_PACKET(_QoS, Topic, PacketId), + PState = #protocol{client = Client = #{mountpoint := Mountpoint}}) -> + %% TODO: ugly... publish_to_msg(...) + Msg = emqx_packet:to_message(Client, Packet), + Msg1 = emqx_mountpoint:mount(Mountpoint, Msg), + Msg2 = emqx_message:set_flag(dup, false, Msg1), + handle_publish(PacketId, Msg2, PState). + +handle_publish(_PacketId, Msg = #message{qos = ?QOS_0}, PState) -> + _ = emqx_broker:publish(Msg), + {ok, PState}; + +handle_publish(PacketId, Msg = #message{qos = ?QOS_1}, PState) -> + Results = emqx_broker:publish(Msg), + ReasonCode = emqx_reason_codes:puback(Results), + handle_out({puback, PacketId, ReasonCode}, PState); + +handle_publish(PacketId, Msg = #message{qos = ?QOS_2}, + PState = #protocol{session = Session}) -> + case emqx_session:publish(PacketId, Msg, Session) of + {ok, Results, NSession} -> + ReasonCode = emqx_reason_codes:puback(Results), + handle_out({pubrec, PacketId, ReasonCode}, + PState#protocol{session = NSession}); + {error, ReasonCode} -> + handle_out({pubrec, PacketId, ReasonCode}, PState) + end. + +%%-------------------------------------------------------------------- +%% Handle Subscribe Request +%%-------------------------------------------------------------------- + +handle_subscribe(TopicFilters, PState) -> + handle_subscribe(TopicFilters, [], PState). + +handle_subscribe([], Acc, PState) -> + {lists:reverse(Acc), PState}; + +handle_subscribe([{TopicFilter, SubOpts}|More], Acc, PState) -> + {RC, PState1} = do_subscribe(TopicFilter, SubOpts, PState), + handle_subscribe(More, [RC|Acc], PState1). + +do_subscribe(TopicFilter, SubOpts = #{qos := QoS}, + PState = #protocol{client = Client, + session = Session, + mountfun = Mount}) -> + %% 1. Parse 2. Check 3. Enrich 5. MountPoint 6. Session + SubOpts1 = maps:merge(?DEFAULT_SUBOPTS, SubOpts), + {TopicFilter1, SubOpts2} = emqx_topic:parse(TopicFilter, SubOpts1), + SubOpts3 = enrich_subopts(SubOpts2, PState), + case check_subscribe(TopicFilter1, PState) of + ok -> + TopicFilter2 = Mount(TopicFilter1), + case emqx_session:subscribe(Client, TopicFilter2, SubOpts3, Session) of + {ok, NSession} -> + {QoS, PState#protocol{session = NSession}}; + {error, RC} -> {RC, PState} + end; + {error, RC} -> {RC, PState} + end. + +enrich_subid(#{'Subscription-Identifier' := SubId}, TopicFilters) -> + [{Topic, SubOpts#{subid => SubId}} || {Topic, SubOpts} <- TopicFilters]; +enrich_subid(_Properties, TopicFilters) -> + TopicFilters. + +enrich_subopts(SubOpts, #protocol{proto_ver = ?MQTT_PROTO_V5}) -> + SubOpts; +enrich_subopts(SubOpts, #protocol{client = #{zone := Zone}, + is_bridge = IsBridge}) -> + Rap = flag(IsBridge), + Nl = flag(emqx_zone:get_env(Zone, ignore_loop_deliver, false)), + SubOpts#{rap => Rap, nl => Nl}. + +check_subscribe(_TopicFilter, _PState) -> + ok. + +%%-------------------------------------------------------------------- +%% Handle Unsubscribe Request +%%-------------------------------------------------------------------- + +handle_unsubscribe(TopicFilters, PState) -> + handle_unsubscribe(TopicFilters, [], PState). + +handle_unsubscribe([], Acc, PState) -> + {lists:reverse(Acc), PState}; + +handle_unsubscribe([TopicFilter|More], Acc, PState) -> + {RC, PState1} = do_unsubscribe(TopicFilter, PState), + handle_unsubscribe(More, [RC|Acc], PState1). + +do_unsubscribe(TopicFilter, PState = #protocol{client = Client, + session = Session, + mountfun = Mount}) -> + TopicFilter1 = Mount(element(1, emqx_topic:parse(TopicFilter))), + case emqx_session:unsubscribe(Client, TopicFilter1, Session) of + {ok, NSession} -> + {?RC_SUCCESS, PState#protocol{session = NSession}}; + {error, RC} -> {RC, PState} + end. + +%%-------------------------------------------------------------------- +%% Validate Incoming Packet +%%-------------------------------------------------------------------- + +validate_packet(Packet, _PState) -> + validate(Packet). + +-spec(validate(emqx_types:packet()) -> ok | {error, emqx_types:reason_code()}). +validate(Packet) -> + try emqx_packet:validate(Packet) of + true -> ok + catch + error:protocol_error -> + {error, ?RC_PROTOCOL_ERROR}; + error:subscription_identifier_invalid -> + {error, ?RC_SUBSCRIPTION_IDENTIFIERS_NOT_SUPPORTED}; + error:topic_alias_invalid -> + {error, ?RC_TOPIC_ALIAS_INVALID}; + error:topic_filters_invalid -> + {error, ?RC_TOPIC_FILTER_INVALID}; + error:topic_name_invalid -> + {error, ?RC_TOPIC_FILTER_INVALID}; + error:_Reason -> + {error, ?RC_MALFORMED_PACKET} + end. + +%%-------------------------------------------------------------------- +%% Check Publish +%%-------------------------------------------------------------------- + +check_pub_caps(#mqtt_packet{header = #mqtt_packet_header{qos = QoS, + retain = Retain}, + variable = #mqtt_packet_publish{}}, + #protocol{client = #{zone := Zone}}) -> + emqx_mqtt_caps:check_pub(Zone, #{qos => QoS, retain => Retain}). + +check_pub_acl(_Packet, #protocol{enable_acl = false}) -> + ok; +check_pub_acl(_Packet, #protocol{client = #{is_superuser := true}}) -> + ok; +check_pub_acl(#mqtt_packet{variable = #mqtt_packet_publish{topic_name = Topic}}, + #protocol{client = Client}) -> + do_acl_check(Client, publish, Topic). + +check_sub_acl(_Packet, #protocol{enable_acl = false}) -> + ok. + +do_acl_check(Client, PubSub, Topic) -> + case emqx_access_control:check_acl(Client, PubSub, Topic) of + allow -> ok; + deny -> {error, ?RC_NOT_AUTHORIZED} + end. + +%%-------------------------------------------------------------------- +%% Pipeline +%%-------------------------------------------------------------------- + +pipeline([Fun], Packet, PState) -> + Fun(Packet, PState); +pipeline([Fun|More], Packet, PState) -> + case Fun(Packet, PState) of + ok -> pipeline(More, Packet, PState); + {ok, NPState} -> + pipeline(More, Packet, NPState); + {ok, NPacket, NPState} -> + pipeline(More, NPacket, NPState); + {error, Reason} -> + {error, Reason} + end. + +%%-------------------------------------------------------------------- +%% Helper functions +%%-------------------------------------------------------------------- + +sp(true) -> 1; +sp(false) -> 0. + +flag(true) -> 1; +flag(false) -> 0. + diff --git a/src/emqx_session.erl b/src/emqx_session.erl index e699a7252..e5978142b 100644 --- a/src/emqx_session.erl +++ b/src/emqx_session.erl @@ -48,35 +48,28 @@ -include("logger.hrl"). -include("types.hrl"). --logger_header("[Session]"). - --export([new/1]). +-export([init/1]). -export([ info/1 - , attrs/1 , stats/1 ]). --export([ subscribe/3 +-export([ subscribe/4 , unsubscribe/3 ]). --export([publish/3]). - --export([ puback/3 +-export([ publish/3 + , puback/3 , pubrec/3 , pubrel/3 , pubcomp/3 ]). --export([ deliver/2 - , await/3 - , enqueue/2 - ]). +-export([handle/2]). --export_type([ session/0 - , puback_ret/0 - ]). +-export([timeout/3]). + +-export_type([session/0]). -import(emqx_zone, [ get_env/2 @@ -107,7 +100,7 @@ mqueue :: emqx_mqueue:mqueue(), %% Next packet id of the session - next_pkt_id = 1 :: emqx_mqtt:packet_id(), + next_pkt_id = 1 :: emqx_types:packet_id(), %% Retry interval for redelivering QoS1/2 messages retry_interval :: timeout(), @@ -140,17 +133,20 @@ -opaque(session() :: #session{}). --type(puback_ret() :: {ok, session()} - | {ok, emqx_types:message(), session()} - | {error, emqx_mqtt:reason_code()}). +-logger_header("[Session]"). -%% @doc Create a session. --spec(new(Attrs :: map()) -> session()). -new(#{zone := Zone, - clean_start := CleanStart, - max_inflight := MaxInflight, - expiry_interval := ExpiryInterval}) -> - %% emqx_logger:set_metadata_client_id(ClientId), +-define(DEFAULT_BATCH_N, 1000). + +%%-------------------------------------------------------------------- +%% Init a session +%%-------------------------------------------------------------------- + +%% @doc Init a session. +-spec(init(Attrs :: map()) -> session()). +init(#{zone := Zone, + clean_start := CleanStart, + max_inflight := MaxInflight, + expiry_interval := ExpiryInterval}) -> #session{clean_start = CleanStart, max_subscriptions = get_env(Zone, max_subscriptions, 0), subscriptions = #{}, @@ -173,12 +169,11 @@ init_mqueue(Zone) -> default_priority => get_env(Zone, mqueue_default_priority) }). -%%------------------------------------------------------------------------------ -%% Info, Attrs, Stats -%%------------------------------------------------------------------------------ +%%-------------------------------------------------------------------- +%% Info, Stats of Session +%%-------------------------------------------------------------------- -%% @doc Get session info --spec(info(session()) -> map()). +-spec(info(session()) -> proplists:proplist()). info(#session{clean_start = CleanStart, max_subscriptions = MaxSubscriptions, subscriptions = Subscriptions, @@ -186,174 +181,163 @@ info(#session{clean_start = CleanStart, inflight = Inflight, retry_interval = RetryInterval, mqueue = MQueue, - next_pkt_id = PktId, + next_pkt_id = PacketId, max_awaiting_rel = MaxAwaitingRel, awaiting_rel = AwaitingRel, await_rel_timeout = AwaitRelTimeout, expiry_interval = ExpiryInterval, created_at = CreatedAt}) -> - #{clean_start => CleanStart, - max_subscriptions => MaxSubscriptions, - subscriptions => Subscriptions, - upgrade_qos => UpgradeQoS, - inflight => Inflight, - retry_interval => RetryInterval, - mqueue_len => emqx_mqueue:len(MQueue), - next_pkt_id => PktId, - awaiting_rel => AwaitingRel, - max_awaiting_rel => MaxAwaitingRel, - await_rel_timeout => AwaitRelTimeout, - expiry_interval => ExpiryInterval div 1000, - created_at => CreatedAt - }. - -%% @doc Get session attrs. --spec(attrs(session()) -> map()). -attrs(#session{clean_start = CleanStart, - expiry_interval = ExpiryInterval, - created_at = CreatedAt}) -> - #{clean_start => CleanStart, - expiry_interval => ExpiryInterval div 1000, - created_at => CreatedAt - }. + [{clean_start, CleanStart}, + {max_subscriptions, MaxSubscriptions}, + {subscriptions, Subscriptions}, + {upgrade_qos, UpgradeQoS}, + {inflight, Inflight}, + {retry_interval, RetryInterval}, + {mqueue_len, emqx_mqueue:len(MQueue)}, + {next_pkt_id, PacketId}, + {awaiting_rel, AwaitingRel}, + {max_awaiting_rel, MaxAwaitingRel}, + {await_rel_timeout, AwaitRelTimeout}, + {expiry_interval, ExpiryInterval div 1000}, + {created_at, CreatedAt}]. %% @doc Get session stats. --spec(stats(session()) -> #{atom() => non_neg_integer()}). +-spec(stats(session()) -> list({atom(), non_neg_integer()})). stats(#session{max_subscriptions = MaxSubscriptions, subscriptions = Subscriptions, inflight = Inflight, mqueue = MQueue, max_awaiting_rel = MaxAwaitingRel, awaiting_rel = AwaitingRel}) -> - #{max_subscriptions => MaxSubscriptions, - subscriptions_count => maps:size(Subscriptions), - max_inflight => emqx_inflight:max_size(Inflight), - inflight_len => emqx_inflight:size(Inflight), - max_mqueue => emqx_mqueue:max_len(MQueue), - mqueue_len => emqx_mqueue:len(MQueue), - mqueue_dropped => emqx_mqueue:dropped(MQueue), - max_awaiting_rel => MaxAwaitingRel, - awaiting_rel_len => maps:size(AwaitingRel) - }. + [{max_subscriptions, MaxSubscriptions}, + {subscriptions_count, maps:size(Subscriptions)}, + {max_inflight, emqx_inflight:max_size(Inflight)}, + {inflight_len, emqx_inflight:size(Inflight)}, + {max_mqueue, emqx_mqueue:max_len(MQueue)}, + {mqueue_len, emqx_mqueue:len(MQueue)}, + {mqueue_dropped, emqx_mqueue:dropped(MQueue)}, + {max_awaiting_rel, MaxAwaitingRel}, + {awaiting_rel_len, maps:size(AwaitingRel)}]. %%-------------------------------------------------------------------- -%% PubSub API -%%-------------------------------------------------------------------- - %% Client -> Broker: SUBSCRIBE --spec(subscribe(emqx_types:credentials(), emqx_mqtt:topic_filters(), session()) - -> {ok, list(emqx_mqtt:reason_code()), session()}). -subscribe(Credentials, RawTopicFilters, Session = #session{subscriptions = Subscriptions}) - when is_list(RawTopicFilters) -> - TopicFilters = [emqx_topic:parse(RawTopic, maps:merge(?DEFAULT_SUBOPTS, SubOpts)) - || {RawTopic, SubOpts} <- RawTopicFilters], - {ReasonCodes, Subscriptions1} = - lists:foldr( - fun({Topic, SubOpts = #{qos := QoS, rc := RC}}, {RcAcc, SubMap}) - when RC == ?QOS_0; RC == ?QOS_1; RC == ?QOS_2 -> - {[QoS|RcAcc], do_subscribe(Credentials, Topic, SubOpts, SubMap)}; - ({_Topic, #{rc := RC}}, {RcAcc, SubMap}) -> - {[RC|RcAcc], SubMap} - end, {[], Subscriptions}, TopicFilters), - {ok, ReasonCodes, Session#session{subscriptions = Subscriptions1}}. +%%-------------------------------------------------------------------- -do_subscribe(Credentials = #{client_id := ClientId}, Topic, SubOpts, SubMap) -> - case maps:find(Topic, SubMap) of - {ok, SubOpts} -> - ok = emqx_hooks:run('session.subscribed', [Credentials, Topic, SubOpts#{first => false}]), - SubMap; - {ok, _SubOpts} -> - emqx_broker:set_subopts(Topic, SubOpts), - %% Why??? - ok = emqx_hooks:run('session.subscribed', [Credentials, Topic, SubOpts#{first => false}]), - maps:put(Topic, SubOpts, SubMap); - error -> - ok = emqx_broker:subscribe(Topic, ClientId, SubOpts), - ok = emqx_hooks:run('session.subscribed', [Credentials, Topic, SubOpts#{first => true}]), - maps:put(Topic, SubOpts, SubMap) +-spec(subscribe(emqx_types:client(), emqx_types:topic(), emqx_types:subopts(), + session()) -> {ok, session()} | {error, emqx_types:reason_code()}). +subscribe(Client, TopicFilter, SubOpts, Session = #session{subscriptions = Subs}) -> + case is_subscriptions_full(Session) + andalso (not maps:is_key(TopicFilter, Subs)) of + true -> {error, ?RC_QUOTA_EXCEEDED}; + false -> + do_subscribe(Client, TopicFilter, SubOpts, Session) end. -%% Client -> Broker: UNSUBSCRIBE --spec(unsubscribe(emqx_types:credentials(), emqx_mqtt:topic_filters(), session()) - -> {ok, list(emqx_mqtt:reason_code()), session()}). -unsubscribe(Credentials, RawTopicFilters, Session = #session{subscriptions = Subscriptions}) - when is_list(RawTopicFilters) -> - TopicFilters = lists:map(fun({RawTopic, Opts}) -> - emqx_topic:parse(RawTopic, Opts); - (RawTopic) when is_binary(RawTopic) -> - emqx_topic:parse(RawTopic) - end, RawTopicFilters), - {ReasonCodes, Subscriptions1} = - lists:foldr(fun({Topic, _SubOpts}, {Acc, SubMap}) -> - case maps:find(Topic, SubMap) of - {ok, SubOpts} -> - ok = emqx_broker:unsubscribe(Topic), - ok = emqx_hooks:run('session.unsubscribed', [Credentials, Topic, SubOpts]), - {[?RC_SUCCESS|Acc], maps:remove(Topic, SubMap)}; - error -> - {[?RC_NO_SUBSCRIPTION_EXISTED|Acc], SubMap} - end - end, {[], Subscriptions}, TopicFilters), - {ok, ReasonCodes, Session#session{subscriptions = Subscriptions1}}. +is_subscriptions_full(#session{max_subscriptions = 0}) -> + false; +is_subscriptions_full(#session{max_subscriptions = MaxLimit, + subscriptions = Subs}) -> + maps:size(Subs) >= MaxLimit. -%% Client -> Broker: QoS2 PUBLISH --spec(publish(emqx_mqtt:packet_id(), emqx_types:message(), session()) - -> {ok, emqx_types:deliver_results(), session()} | {error, emqx_mqtt:reason_code()}). -publish(PacketId, Msg = #message{qos = ?QOS_2, timestamp = Ts}, - Session = #session{awaiting_rel = AwaitingRel, - max_awaiting_rel = MaxAwaitingRel}) -> - case is_awaiting_full(MaxAwaitingRel, AwaitingRel) of +do_subscribe(Client = #{client_id := ClientId}, + TopicFilter, SubOpts, Session = #session{subscriptions = Subs}) -> + case IsNew = (not maps:is_key(TopicFilter, Subs)) of + true -> + ok = emqx_broker:subscribe(TopicFilter, ClientId, SubOpts); false -> - case maps:is_key(PacketId, AwaitingRel) of - false -> - DeliverResults = emqx_broker:publish(Msg), - AwaitingRel1 = maps:put(PacketId, Ts, AwaitingRel), - NSession = Session#session{awaiting_rel = AwaitingRel1}, - {ok, DeliverResults, ensure_await_rel_timer(NSession)}; - true -> - {error, ?RC_PACKET_IDENTIFIER_IN_USE} - end; + _ = emqx_broker:set_subopts(TopicFilter, SubOpts) + end, + ok = emqx_hooks:run('session.subscribed', + [Client, TopicFilter, SubOpts#{new => IsNew}]), + Subs1 = maps:put(TopicFilter, SubOpts, Subs), + {ok, Session#session{subscriptions = Subs1}}. + +%%-------------------------------------------------------------------- +%% Client -> Broker: UNSUBSCRIBE +%%-------------------------------------------------------------------- + +-spec(unsubscribe(emqx_types:client(), emqx_types:topic(), session()) + -> {ok, session()} | {error, emqx_types:reason_code()}). +unsubscribe(Client, TopicFilter, Session = #session{subscriptions = Subs}) -> + case maps:find(TopicFilter, Subs) of + {ok, SubOpts} -> + ok = emqx_broker:unsubscribe(TopicFilter), + ok = emqx_hooks:run('session.unsubscribed', [Client, TopicFilter, SubOpts]), + {ok, Session#session{subscriptions = maps:remove(TopicFilter, Subs)}}; + error -> + {error, ?RC_NO_SUBSCRIPTION_EXISTED} + end. + +%%-------------------------------------------------------------------- +%% Client -> Broker: PUBLISH +%%-------------------------------------------------------------------- + +-spec(publish(emqx_types:packet_id(), emqx_types:message(), session()) + -> {ok, emqx_types:deliver_results()} | + {ok, emqx_types:deliver_results(), session()} | + {error, emqx_types:reason_code()}). +publish(PacketId, Msg = #message{qos = ?QOS_2}, Session) -> + case is_awaiting_full(Session) of + false -> + do_publish(PacketId, Msg, Session); true -> ?LOG(warning, "Dropped qos2 packet ~w for too many awaiting_rel", [PacketId]), ok = emqx_metrics:inc('messages.qos2.dropped'), {error, ?RC_RECEIVE_MAXIMUM_EXCEEDED} end; -%% QoS0/1 -publish(_PacketId, Msg, Session) -> +%% Publish QoS0/1 directly +publish(_PacketId, Msg, _Session) -> {ok, emqx_broker:publish(Msg)}. +is_awaiting_full(#session{max_awaiting_rel = 0}) -> + false; +is_awaiting_full(#session{awaiting_rel = AwaitingRel, + max_awaiting_rel = MaxLimit}) -> + maps:size(AwaitingRel) >= MaxLimit. + +-compile({inline, [do_publish/3]}). +do_publish(PacketId, Msg = #message{timestamp = Ts}, + Session = #session{awaiting_rel = AwaitingRel}) -> + case maps:is_key(PacketId, AwaitingRel) of + false -> + DeliverResults = emqx_broker:publish(Msg), + AwaitingRel1 = maps:put(PacketId, Ts, AwaitingRel), + Session1 = Session#session{awaiting_rel = AwaitingRel1}, + {ok, DeliverResults, ensure_await_rel_timer(Session1)}; + true -> + {error, ?RC_PACKET_IDENTIFIER_IN_USE} + end. + +%%-------------------------------------------------------------------- %% Client -> Broker: PUBACK --spec(puback(emqx_mqtt:packet_id(), emqx_mqtt:reason_code(), session()) - -> puback_ret()). -puback(PacketId, _ReasonCode, Session = #session{inflight = Inflight, mqueue = Q}) -> +%%-------------------------------------------------------------------- + +-spec(puback(emqx_types:packet_id(), emqx_types:reason_code(), session()) + -> {ok, session()} | {error, emqx_types:reason_code()}). +puback(PacketId, _ReasonCode, Session = #session{inflight = Inflight}) -> case emqx_inflight:lookup(PacketId, Inflight) of - {value, {publish, {_, Msg}, _Ts}} -> - %% #{client_id => ClientId, username => Username} - %% ok = emqx_hooks:run('message.acked', [], Msg]), + {value, {Msg, _Ts}} when is_record(Msg, message) -> Inflight1 = emqx_inflight:delete(PacketId, Inflight), - Session1 = Session#session{inflight = Inflight1}, - case (emqx_mqueue:is_empty(Q) orelse emqx_mqueue:out(Q)) of - true -> {ok, Session1}; - {{value, Msg}, Q1} -> - {ok, Msg, Session1#session{mqueue = Q1}} - end; + dequeue(Session#session{inflight = Inflight1}); false -> ?LOG(warning, "The PUBACK PacketId ~w is not found", [PacketId]), ok = emqx_metrics:inc('packets.puback.missed'), {error, ?RC_PACKET_IDENTIFIER_NOT_FOUND} end. +%%-------------------------------------------------------------------- %% Client -> Broker: PUBREC --spec(pubrec(emqx_mqtt:packet_id(), emqx_mqtt:reason_code(), session()) - -> {ok, session()} | {error, emqx_mqtt:reason_code()}). +%%-------------------------------------------------------------------- + +-spec(pubrec(emqx_types:packet_id(), emqx_types:reason_code(), session()) + -> {ok, session()} | {error, emqx_types:reason_code()}). pubrec(PacketId, _ReasonCode, Session = #session{inflight = Inflight}) -> case emqx_inflight:lookup(PacketId, Inflight) of - {value, {publish, {_, Msg}, _Ts}} -> - %% ok = emqx_hooks:run('message.acked', [#{client_id => ClientId, username => Username}, Msg]), - Inflight1 = emqx_inflight:update(PacketId, {pubrel, PacketId, os:timestamp()}, Inflight), + {value, {Msg, _Ts}} when is_record(Msg, message) -> + Inflight1 = emqx_inflight:update(PacketId, {pubrel, os:timestamp()}, Inflight), {ok, Session#session{inflight = Inflight1}}; - {value, {pubrel, PacketId, _Ts}} -> + {value, {pubrel, _Ts}} -> ?LOG(warning, "The PUBREC ~w is duplicated", [PacketId]), {error, ?RC_PACKET_IDENTIFIER_IN_USE}; none -> @@ -362,10 +346,13 @@ pubrec(PacketId, _ReasonCode, Session = #session{inflight = Inflight}) -> {error, ?RC_PACKET_IDENTIFIER_NOT_FOUND} end. +%%-------------------------------------------------------------------- %% Client -> Broker: PUBREL --spec(pubrel(emqx_mqtt:packet_id(), emqx_mqtt:reason_code(), session()) - -> {ok, session()} | {error, emqx_mqtt:reason_code()}). -pubrel(PacketId, ReasonCode, Session = #session{awaiting_rel = AwaitingRel}) -> +%%-------------------------------------------------------------------- + +-spec(pubrel(emqx_types:packet_id(), emqx_types:reason_code(), session()) + -> {ok, session()} | {error, emqx_types:reason_code()}). +pubrel(PacketId, _ReasonCode, Session = #session{awaiting_rel = AwaitingRel}) -> case maps:take(PacketId, AwaitingRel) of {_Ts, AwaitingRel1} -> {ok, Session#session{awaiting_rel = AwaitingRel1}}; @@ -375,18 +362,17 @@ pubrel(PacketId, ReasonCode, Session = #session{awaiting_rel = AwaitingRel}) -> {error, ?RC_PACKET_IDENTIFIER_NOT_FOUND} end. +%%-------------------------------------------------------------------- %% Client -> Broker: PUBCOMP --spec(pubcomp(emqx_mqtt:packet_id(), emqx_mqtt:reason_code(), session()) -> puback_ret()). -pubcomp(PacketId, ReasonCode, Session = #session{inflight = Inflight, mqueue = Q}) -> +%%-------------------------------------------------------------------- + +-spec(pubcomp(emqx_types:packet_id(), emqx_types:reason_code(), session()) + -> {ok, session()} | {error, emqx_types:reason_code()}). +pubcomp(PacketId, _ReasonCode, Session = #session{inflight = Inflight}) -> case emqx_inflight:contain(PacketId, Inflight) of true -> Inflight1 = emqx_inflight:delete(PacketId, Inflight), - Session1 = Session#session{inflight = Inflight1}, - case (emqx_mqueue:is_empty(Q) orelse emqx_mqueue:out(Q)) of - true -> {ok, Session1}; - {{value, Msg}, Q1} -> - {ok, Msg, Session1#session{mqueue = Q1}} - end; + dequeue(Session#session{inflight = Inflight1}); false -> ?LOG(warning, "The PUBCOMP PacketId ~w is not found", [PacketId]), ok = emqx_metrics:inc('packets.pubcomp.missed'), @@ -394,32 +380,59 @@ pubcomp(PacketId, ReasonCode, Session = #session{inflight = Inflight, mqueue = Q end. %%-------------------------------------------------------------------- -%% Handle delivery +%% Dequeue Msgs %%-------------------------------------------------------------------- -deliver(Delivers, Session = #session{subscriptions = SubMap}) +dequeue(Session = #session{inflight = Inflight, mqueue = Q}) -> + case emqx_mqueue:is_empty(Q) of + true -> {ok, Session}; + false -> + {Msgs, Q1} = dequeue(batch_n(Inflight), [], Q), + handle(lists:reverse(Msgs), [], Session#session{mqueue = Q1}) + end. + +dequeue(Cnt, Msgs, Q) when Cnt =< 0 -> + {Msgs, Q}; + +dequeue(Cnt, Msgs, Q) -> + case emqx_mqueue:out(Q) of + {empty, _Q} -> {Msgs, Q}; + {{value, Msg}, Q1} -> + dequeue(Cnt-1, [Msg|Msgs], Q1) + end. + +batch_n(Inflight) -> + case emqx_inflight:max_size(Inflight) of + 0 -> ?DEFAULT_BATCH_N; + Sz -> Sz - emqx_inflight:size(Inflight) + end. + +%%-------------------------------------------------------------------- +%% Broker -> Client: Publish | Msg +%%-------------------------------------------------------------------- + +handle(Delivers, Session = #session{subscriptions = Subs}) when is_list(Delivers) -> - Msgs = [enrich(get_subopts(Topic, SubMap), Msg, Session) - || {Topic, Msg} <- Delivers], - deliver(Msgs, [], Session). + Msgs = [enrich(get_subopts(Topic, Subs), Msg, Session) + || {deliver, Topic, Msg} <- Delivers], + handle(Msgs, [], Session). - -deliver([], Publishes, Session) -> +handle([], Publishes, Session) -> {ok, lists:reverse(Publishes), Session}; -deliver([Msg = #message{qos = ?QOS_0}|More], Acc, Session) -> - deliver(More, [{publish, undefined, Msg}|Acc], Session); +handle([Msg = #message{qos = ?QOS_0}|More], Acc, Session) -> + handle(More, [{publish, undefined, Msg}|Acc], Session); -deliver([Msg = #message{qos = QoS}|More], Acc, - Session = #session{next_pkt_id = PacketId, inflight = Inflight}) +handle([Msg = #message{qos = QoS}|More], Acc, + Session = #session{next_pkt_id = PacketId, inflight = Inflight}) when QoS =:= ?QOS_1 orelse QoS =:= ?QOS_2 -> case emqx_inflight:is_full(Inflight) of true -> - deliver(More, Acc, enqueue(Msg, Session)); + handle(More, Acc, enqueue(Msg, Session)); false -> Publish = {publish, PacketId, Msg}, - NSession = await(PacketId, Msg, Session), - deliver(More, [Publish|Acc], next_pkt_id(NSession)) + Session1 = await(PacketId, Msg, Session), + handle(More, [Publish|Acc], next_pkt_id(Session1)) end. enqueue(Msg, Session = #session{mqueue = Q}) -> @@ -427,19 +440,20 @@ enqueue(Msg, Session = #session{mqueue = Q}) -> {Dropped, NewQ} = emqx_mqueue:in(Msg, Q), if Dropped =/= undefined -> + %% TODO:... %% SessProps = #{client_id => ClientId, username => Username}, ok; %% = emqx_hooks:run('message.dropped', [SessProps, Dropped]); true -> ok end, Session#session{mqueue = NewQ}. -%%------------------------------------------------------------------------------ +%%-------------------------------------------------------------------- %% Awaiting ACK for QoS1/QoS2 Messages -%%------------------------------------------------------------------------------ +%%-------------------------------------------------------------------- await(PacketId, Msg, Session = #session{inflight = Inflight}) -> - Publish = {publish, {PacketId, Msg}, os:timestamp()}, - Inflight1 = emqx_inflight:insert(PacketId, Publish, Inflight), + Inflight1 = emqx_inflight:insert( + PacketId, {Msg, os:timestamp()}, Inflight), ensure_retry_timer(Session#session{inflight = Inflight1}). get_subopts(Topic, SubMap) -> @@ -470,11 +484,28 @@ enrich([{rap, _}|Opts], Msg, Session) -> enrich([{subid, SubId}|Opts], Msg, Session) -> enrich(Opts, emqx_message:set_header('Subscription-Identifier', SubId, Msg), Session). +%%-------------------------------------------------------------------- +%% Handle timeout +%%-------------------------------------------------------------------- + +-spec(timeout(reference(), atom(), session()) + -> {ok, session()} | {ok, list(), session()}). +timeout(TRef, retry_delivery, Session = #session{retry_timer = TRef}) -> + retry_delivery(Session#session{retry_timer = undefined}); + +timeout(TRef, check_awaiting_rel, Session = #session{await_rel_timer = TRef}) -> + expire_awaiting_rel(Session); + +timeout(TRef, Msg, Session) -> + ?LOG(error, "unexpected timeout - ~p: ~p", [TRef, Msg]), + {ok, Session}. + %%-------------------------------------------------------------------- %% Ensure retry timer %%-------------------------------------------------------------------- -ensure_retry_timer(Session = #session{retry_interval = Interval, retry_timer = undefined}) -> +ensure_retry_timer(Session = #session{retry_interval = Interval, + retry_timer = undefined}) -> ensure_retry_timer(Interval, Session); ensure_retry_timer(Session) -> Session. @@ -486,13 +517,48 @@ ensure_retry_timer(_Interval, Session) -> Session. %%-------------------------------------------------------------------- -%% Check awaiting rel +%% Retry Delivery %%-------------------------------------------------------------------- -is_awaiting_full(_MaxAwaitingRel = 0, _AwaitingRel) -> - false; -is_awaiting_full(MaxAwaitingRel, AwaitingRel) -> - maps:size(AwaitingRel) >= MaxAwaitingRel. +%% Redeliver at once if force is true +retry_delivery(Session = #session{inflight = Inflight}) -> + case emqx_inflight:is_empty(Inflight) of + true -> {ok, Session}; + false -> + SortFun = fun({_, {_, Ts1}}, {_, {_, Ts2}}) -> Ts1 < Ts2 end, + Msgs = lists:sort(SortFun, emqx_inflight:to_list(Inflight)), + retry_delivery(Msgs, os:timestamp(), [], Session) + end. + +retry_delivery([], _Now, Acc, Session) -> + %% Retry again... + {ok, lists:reverse(Acc), ensure_retry_timer(Session)}; + +retry_delivery([{PacketId, {Val, Ts}}|More], Now, Acc, + Session = #session{retry_interval = Interval, inflight = Inflight}) -> + %% Microseconds -> MilliSeconds + Age = timer:now_diff(Now, Ts) div 1000, + if + Age >= Interval -> + {Acc1, Inflight1} = retry_delivery(PacketId, Val, Now, Acc, Inflight), + retry_delivery(More, Now, Acc1, Session#session{inflight = Inflight1}); + true -> + {ok, lists:reverse(Acc), ensure_retry_timer(Interval - max(0, Age), Session)} + end. + +retry_delivery(PacketId, Msg, Now, Acc, Inflight) when is_record(Msg, message) -> + case emqx_message:is_expired(Msg) of + true -> + ok = emqx_metrics:inc('messages.expired'), + {Acc, emqx_inflight:delete(PacketId, Inflight)}; + false -> + {[{publish, PacketId, Msg}|Acc], + emqx_inflight:update(PacketId, {Msg, Now}, Inflight)} + end; + +retry_delivery(PacketId, pubrel, Now, Acc, Inflight) -> + Inflight1 = emqx_inflight:update(PacketId, {pubrel, Now}, Inflight), + {[{pubrel, PacketId}|Acc], Inflight1}. %%-------------------------------------------------------------------- %% Ensure await_rel timer @@ -516,22 +582,21 @@ ensure_await_rel_timer(_Timeout, Session) -> expire_awaiting_rel(Session = #session{awaiting_rel = AwaitingRel}) -> case maps:size(AwaitingRel) of - 0 -> Session; + 0 -> {ok, Session}; _ -> expire_awaiting_rel(lists:keysort(2, maps:to_list(AwaitingRel)), os:timestamp(), Session) end. expire_awaiting_rel([], _Now, Session) -> - Session#session{await_rel_timer = undefined}; + {ok, Session#session{await_rel_timer = undefined}}; expire_awaiting_rel([{PacketId, Ts} | More], Now, - Session = #session{awaiting_rel = AwaitingRel, - await_rel_timeout = Timeout}) -> + Session = #session{awaiting_rel = AwaitingRel, await_rel_timeout = Timeout}) -> case (timer:now_diff(Now, Ts) div 1000) of Age when Age >= Timeout -> ok = emqx_metrics:inc('messages.qos2.expired'), ?LOG(warning, "Dropped qos2 packet ~s for await_rel_timeout", [PacketId]), - NSession = Session#session{awaiting_rel = maps:remove(PacketId, AwaitingRel)}, - expire_awaiting_rel(More, Now, NSession); + Session1 = Session#session{awaiting_rel = maps:remove(PacketId, AwaitingRel)}, + expire_awaiting_rel(More, Now, Session1); Age -> ensure_await_rel_timer(Timeout - max(0, Age), Session) end. diff --git a/src/emqx_topic.erl b/src/emqx_topic.erl index 86cbcf547..dd187ddf6 100644 --- a/src/emqx_topic.erl +++ b/src/emqx_topic.erl @@ -212,28 +212,27 @@ join(Words) -> end, {true, <<>>}, [bin(W) || W <- Words]), Bin. --spec(parse(topic()) -> {topic(), #{}}). -parse(Topic) when is_binary(Topic) -> - parse(Topic, #{}). +-spec(parse(topic() | {topic(), map()}) -> {topic(), #{share => binary()}}). +parse(TopicFilter) when is_binary(TopicFilter) -> + parse(TopicFilter, #{}); +parse({TopicFilter, Options}) when is_binary(TopicFilter) -> + parse(TopicFilter, Options). -parse(Topic = <<"$queue/", _/binary>>, #{share := _Group}) -> - error({invalid_topic, Topic}); -parse(Topic = <>, #{share := _Group}) -> - error({invalid_topic, Topic}); -parse(<<"$queue/", Topic1/binary>>, Options) -> - parse(Topic1, maps:put(share, <<"$queue">>, Options)); -parse(Topic = <>, Options) -> - case binary:split(Topic1, <<"/">>) of - [<<>>] -> error({invalid_topic, Topic}); - [_] -> error({invalid_topic, Topic}); - [Group, Topic2] -> - case binary:match(Group, [<<"/">>, <<"+">>, <<"#">>]) of - nomatch -> {Topic2, maps:put(share, Group, Options)}; - _ -> error({invalid_topic, Topic}) +parse(TopicFilter = <<"$queue/", _/binary>>, #{share := _Group}) -> + error({invalid_topic_filter, TopicFilter}); +parse(TopicFilter = <>, #{share := _Group}) -> + error({invalid_topic_filter, TopicFilter}); +parse(<<"$queue/", TopicFilter/binary>>, Options) -> + parse(TopicFilter, Options#{share => <<"$queue">>}); +parse(TopicFilter = <>, Options) -> + case binary:split(Rest, <<"/">>) of + [_Any] -> error({invalid_topic_filter, TopicFilter}); + [ShareName, Filter] -> + case binary:match(ShareName, [<<"+">>, <<"#">>]) of + nomatch -> parse(Filter, Options#{share => ShareName}); + _ -> error({invalid_topic_filter, TopicFilter}) end end; -parse(Topic, Options = #{qos := QoS}) -> - {Topic, Options#{rc => QoS}}; -parse(Topic, Options) -> - {Topic, Options}. +parse(TopicFilter, Options) -> + {TopicFilter, Options}. diff --git a/src/emqx_types.erl b/src/emqx_types.erl index 57dfe7e8c..e4ec9ec75 100644 --- a/src/emqx_types.erl +++ b/src/emqx_types.erl @@ -18,23 +18,39 @@ -include("emqx.hrl"). -include("types.hrl"). +-include("emqx_mqtt.hrl"). -export_type([zone/0]). +-export_type([ ver/0 + , qos/0 + , qos_name/0 + ]). + -export_type([ pubsub/0 , topic/0 , subid/0 , subopts/0 ]). --export_type([ client_id/0 +-export_type([ client/0 + , client_id/0 , username/0 , password/0 , peername/0 , protocol/0 ]). --export_type([credentials/0]). +-export_type([ connack/0 + , reason_code/0 + , properties/0 + , topic_filters/0 + ]). + +-export_type([ packet_id/0 + , packet_type/0 + , packet/0 + ]). -export_type([ subscription/0 , subscriber/0 @@ -49,22 +65,56 @@ , deliver_results/0 ]). --export_type([route/0]). - --export_type([ alarm/0 +-export_type([ route/0 + , alarm/0 , plugin/0 , banned/0 , command/0 ]). --type(zone() :: atom()). +-type(zone() :: emqx_zone:zone()). +-type(ver() :: ?MQTT_PROTO_V3 | ?MQTT_PROTO_V4 | ?MQTT_PROTO_V5). +-type(qos() :: ?QOS_0 | ?QOS_1 | ?QOS_2). +-type(qos_name() :: qos0 | at_most_once | + qos1 | at_least_once | + qos2 | exactly_once). + -type(pubsub() :: publish | subscribe). --type(topic() :: binary()). +-type(topic() :: emqx_topic:topic()). -type(subid() :: binary() | atom()). --type(subopts() :: #{qos := emqx_mqtt_types:qos(), + +-type(subopts() :: #{rh := 0 | 1 | 2, + rap := 0 | 1, + nl := 0 | 1, + qos := qos(), + rc => reason_code(), share => binary(), atom() => term() }). + +-type(packet_type() :: ?RESERVED..?AUTH). +-type(connack() :: ?CONNACK_ACCEPT..?CONNACK_AUTH). +-type(reason_code() :: 0..16#FF). +-type(packet_id() :: 1..16#FFFF). +-type(properties() :: #{atom() => term()}). +-type(topic_filters() :: list({emqx_topic:topic(), subopts()})). +-type(packet() :: #mqtt_packet{}). + +-type(client() :: #{zone := zone(), + conn_mod := maybe(module()), + peername := peername(), + sockname := peername(), + client_id := client_id(), + username := username(), + peercert := esockd_peercert:peercert(), + is_superuser := boolean(), + mountpoint := maybe(binary()), + ws_cookie := maybe(list()), + password => maybe(binary()), + auth_result => auth_result(), + anonymous => boolean(), + atom() => term() + }). -type(client_id() :: binary() | atom()). -type(username() :: maybe(binary())). -type(password() :: maybe(binary())). @@ -79,18 +129,6 @@ | banned | bad_authentication_method). -type(protocol() :: mqtt | 'mqtt-sn' | coap | stomp | none | atom()). --type(credentials() :: #{zone := zone(), - client_id := client_id(), - username := username(), - sockname := peername(), - peername := peername(), - ws_cookie := undefined | list(), - mountpoint := binary(), - password => binary(), - auth_result => auth_result(), - anonymous => boolean(), - atom() => term() - }). -type(subscription() :: #subscription{}). -type(subscriber() :: {pid(), subid()}). -type(topic_table() :: [{topic(), subopts()}]). diff --git a/src/emqx_ws_connection.erl b/src/emqx_ws_channel.erl similarity index 51% rename from src/emqx_ws_connection.erl rename to src/emqx_ws_channel.erl index b21897b5e..4da81a2d7 100644 --- a/src/emqx_ws_connection.erl +++ b/src/emqx_ws_channel.erl @@ -14,22 +14,22 @@ %% limitations under the License. %%-------------------------------------------------------------------- -%% MQTT WebSocket Connection --module(emqx_ws_connection). +%% MQTT WebSocket Channel +-module(emqx_ws_channel). -include("emqx.hrl"). -include("emqx_mqtt.hrl"). -include("logger.hrl"). -include("types.hrl"). --logger_header("[WS Conn]"). +-logger_header("[WsChannel]"). -export([ info/1 , attrs/1 , stats/1 ]). -%% websocket callbacks +%% WebSocket callbacks -export([ init/2 , websocket_init/1 , websocket_handle/2 @@ -38,13 +38,15 @@ ]). -record(state, { - request, - options, - peername :: {inet:ip_address(), inet:port_number()}, - sockname :: {inet:ip_address(), inet:port_number()}, + request :: cowboy_req:req(), + options :: proplists:proplist(), + peername :: emqx_types:peername(), + sockname :: emqx_types:peername(), + fsm_state :: idle | connected | disconnected, + serialize :: fun((emqx_types:packet()) -> iodata()), parse_state :: emqx_frame:parse_state(), - packets :: list(emqx_mqtt:packet()), - chan_state :: emqx_channel:channel(), + proto_state :: emqx_protocol:proto_state(), + gc_state :: emqx_gc:gc_state(), keepalive :: maybe(emqx_keepalive:keepalive()), stats_timer :: disabled | maybe(reference()), idle_timeout :: timeout(), @@ -64,14 +66,12 @@ info(WSPid) when is_pid(WSPid) -> info(#state{peername = Peername, sockname = Sockname, - chan_state = ChanState}) -> - ConnInfo = #{socktype => websocket, - conn_state => running, - peername => Peername, - sockname => Sockname - }, - ChanInfo = emqx_channel:info(ChanState), - maps:merge(ConnInfo, ChanInfo). + proto_state = ProtoState}) -> + [{socktype, websocket}, + {conn_state, running}, + {peername, Peername}, + {sockname, Sockname} | + emqx_protocol:info(ProtoState)]. %% for dashboard attrs(WSPid) when is_pid(WSPid) -> @@ -79,11 +79,10 @@ attrs(WSPid) when is_pid(WSPid) -> attrs(#state{peername = Peername, sockname = Sockname, - chan_state = ChanState}) -> - SockAttrs = #{peername => Peername, - sockname => Sockname}, - ChanAttrs = emqx_channel:attrs(ChanState), - maps:merge(SockAttrs, ChanAttrs). + proto_state = ProtoState}) -> + [{peername, Peername}, + {sockname, Sockname} | + emqx_protocol:attrs(ProtoState)]. stats(WSPid) when is_pid(WSPid) -> call(WSPid, stats); @@ -91,12 +90,6 @@ stats(WSPid) when is_pid(WSPid) -> stats(#state{}) -> lists:append([chan_stats(), wsock_stats(), emqx_misc:proc_stats()]). -%%kick(WSPid) when is_pid(WSPid) -> -%% call(WSPid, kick). - -%%session(WSPid) when is_pid(WSPid) -> -%% call(WSPid, session). - call(WSPid, Req) when is_pid(WSPid) -> Mref = erlang:monitor(process, WSPid), WSPid ! {call, {self(), Mref}, Req}, @@ -120,7 +113,7 @@ init(Req, Opts) -> DeflateOptions = maps:from_list(proplists:get_value(deflate_options, Opts, [])), MaxFrameSize = case proplists:get_value(max_frame_size, Opts, 0) of 0 -> infinity; - MFS -> MFS + I -> I end, Compress = proplists:get_value(compress, Opts, false), Options = #{compress => Compress, @@ -151,80 +144,59 @@ websocket_init(#state{request = Req, options = Options}) -> [Error, Reason]), undefined end, - ChanState = emqx_channel:init(#{peername => Peername, - sockname => Sockname, - peercert => Peercert, - ws_cookie => WsCookie, - conn_mod => ?MODULE}, Options), + ProtoState = emqx_protocol:init(#{peername => Peername, + sockname => Sockname, + peercert => Peercert, + ws_cookie => WsCookie, + conn_mod => ?MODULE}, Options), Zone = proplists:get_value(zone, Options), MaxSize = emqx_zone:get_env(Zone, max_packet_size, ?MAX_PACKET_SIZE), ParseState = emqx_frame:initial_parse_state(#{max_size => MaxSize}), + GcPolicy = emqx_zone:get_env(Zone, force_gc_policy, false), + GcState = emqx_gc:init(GcPolicy), EnableStats = emqx_zone:get_env(Zone, enable_stats, true), StatsTimer = if EnableStats -> undefined; ?Otherwise-> disabled end, IdleTimout = emqx_zone:get_env(Zone, idle_timeout, 30000), emqx_logger:set_metadata_peername(esockd_net:format(Peername)), ok = emqx_misc:init_proc_mng_policy(Zone), + %% TODO: Idle timeout? {ok, #state{peername = Peername, sockname = Sockname, + fsm_state = idle, parse_state = ParseState, - chan_state = ChanState, + proto_state = ProtoState, + gc_state = GcState, stats_timer = StatsTimer, idle_timeout = IdleTimout }}. -send_fun(WsPid) -> - fun(Packet, Options) -> - Data = emqx_frame:serialize(Packet, Options), - BinSize = iolist_size(Data), - emqx_pd:update_counter(send_cnt, 1), - emqx_pd:update_counter(send_oct, BinSize), - WsPid ! {binary, iolist_to_binary(Data)}, - {ok, Data} - end. - stat_fun() -> fun() -> {ok, emqx_pd:get_counter(recv_oct)} end. -websocket_handle({binary, <<>>}, State) -> - {ok, ensure_stats_timer(State)}; -websocket_handle({binary, [<<>>]}, State) -> - {ok, ensure_stats_timer(State)}; -websocket_handle({binary, Data}, State = #state{parse_state = ParseState}) -> +websocket_handle({binary, Data}, State) when is_list(Data) -> + websocket_handle({binary, iolist_to_binary(Data)}, State); + +websocket_handle({binary, Data}, State) when is_binary(Data) -> ?LOG(debug, "RECV ~p", [Data]), - BinSize = iolist_size(Data), - emqx_pd:update_counter(recv_oct, BinSize), - ok = emqx_metrics:inc('bytes.received', BinSize), - try emqx_frame:parse(iolist_to_binary(Data), ParseState) of - {ok, NParseState} -> - {ok, State#state{parse_state = NParseState}}; - {ok, Packet, Rest, NParseState} -> - ok = emqx_metrics:inc_recv(Packet), - emqx_pd:update_counter(recv_cnt, 1), - handle_incoming(Packet, fun(NState) -> - websocket_handle({binary, Rest}, NState) - end, - State#state{parse_state = NParseState}); - {error, Reason} -> - ?LOG(error, "Frame error: ~p", [Reason]), - shutdown(Reason, State) - catch - error:Reason:Stk -> - ?LOG(error, "Parse failed for ~p~n\ - Stacktrace:~p~nFrame data: ~p", [Reason, Stk, Data]), - shutdown(parse_error, State) - end; + Oct = iolist_size(Data), + emqx_pd:update_counter(recv_cnt, 1), + emqx_pd:update_counter(recv_oct, Oct), + ok = emqx_metrics:inc('bytes.received', Oct), + NState = maybe_gc(1, Oct, State), + process_incoming(Data, ensure_stats_timer(NState)); + %% Pings should be replied with pongs, cowboy does it automatically %% Pongs can be safely ignored. Clause here simply prevents crash. websocket_handle(Frame, State) when Frame =:= ping; Frame =:= pong -> - {ok, ensure_stats_timer(State)}; + {ok, State}; websocket_handle({FrameType, _}, State) when FrameType =:= ping; FrameType =:= pong -> - {ok, ensure_stats_timer(State)}; + {ok, State}; %% According to mqtt spec[https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901285] -websocket_handle({_OtherFrameType, _}, State) -> - ?LOG(error, "Frame error: Other type of data frame"), - shutdown(other_frame_type, State). +websocket_handle({FrameType, _}, State) -> + ?LOG(error, "Frame error: unexpected frame - ~p", [FrameType]), + shutdown(unexpected_ws_frame, State). websocket_info({call, From, info}, State) -> gen_server:reply(From, info(State)), @@ -242,31 +214,60 @@ websocket_info({call, From, kick}, State) -> gen_server:reply(From, ok), shutdown(kick, State); -websocket_info(Delivery, State = #state{chan_state = ChanState}) - when element(1, Delivery) =:= deliver -> - case emqx_channel:handle_out(Delivery, ChanState) of - {ok, NChanState} -> - {ok, State#state{chan_state = NChanState}}; - {ok, Packet, NChanState} -> - handle_outgoing(Packet, State#state{chan_state = NChanState}); - {error, Reason} -> - shutdown(Reason, State) +websocket_info({incoming, Packet = ?CONNECT_PACKET( + #mqtt_packet_connect{ + proto_ver = ProtoVer} + )}, + State = #state{fsm_state = idle}) -> + State1 = State#state{serialize = serialize_fun(ProtoVer)}, + %% TODO: Fixme later + case handle_incoming(Packet, State1) of + Rep = {reply, _Data, _State} -> + self() ! {enter, connected}, + Rep; + Other -> Other end; -websocket_info({timeout, Timer, emit_stats}, - State = #state{stats_timer = Timer, chan_state = ChanState}) -> - ClientId = emqx_channel:client_id(ChanState), - ok = emqx_cm:set_conn_stats(ClientId, stats(State)), - {ok, State#state{stats_timer = undefined}, hibernate}; +websocket_info({incoming, Packet}, State = #state{fsm_state = idle}) -> + ?LOG(warning, "Unexpected incoming: ~p", [Packet]), + shutdown(unexpected_incoming_packet, State); -websocket_info({keepalive, start, Interval}, State) -> - ?LOG(debug, "Keepalive at the interval of ~p", [Interval]), - case emqx_keepalive:start(stat_fun(), Interval, {keepalive, check}) of +websocket_info({enter, connected}, State = #state{proto_state = ProtoState}) -> + ClientId = emqx_protocol:client_id(ProtoState), + ok = emqx_cm:set_chan_attrs(ClientId, info(State)), + %% Ensure keepalive after connected successfully. + Interval = emqx_protocol:info(keepalive, ProtoState), + State1 = State#state{fsm_state = connected}, + case ensure_keepalive(Interval, State1) of + ignore -> {ok, State1}; {ok, KeepAlive} -> - {ok, State#state{keepalive = KeepAlive}}; - {error, Error} -> - ?LOG(warning, "Keepalive error: ~p", [Error]), - shutdown(Error, State) + {ok, State1#state{keepalive = KeepAlive}}; + {error, Reason} -> + shutdown(Reason, State1) + end; + +websocket_info({incoming, Packet = ?PACKET(?CONNECT)}, + State = #state{fsm_state = connected}) -> + ?LOG(warning, "Unexpected connect: ~p", [Packet]), + shutdown(unexpected_incoming_connect, State); + +websocket_info({incoming, Packet}, State = #state{fsm_state = connected}) + when is_record(Packet, mqtt_packet) -> + handle_incoming(Packet, State); + +websocket_info(Deliver = {deliver, _Topic, _Msg}, + State = #state{proto_state = ProtoState}) -> + Delivers = emqx_misc:drain_deliver([Deliver]), + case emqx_protocol:handle_deliver(Delivers, ProtoState) of + {ok, NProtoState} -> + {ok, State#state{proto_state = NProtoState}}; + {ok, Packets, NProtoState} -> + NState = State#state{proto_state = NProtoState}, + handle_outgoing(Packets, NState); + {error, Reason} -> + shutdown(Reason, State); + {error, Reason, NProtoState} -> + shutdown(Reason, State#state{proto_state = NProtoState}) end; websocket_info({keepalive, check}, State = #state{keepalive = KeepAlive}) -> @@ -281,6 +282,39 @@ websocket_info({keepalive, check}, State = #state{keepalive = KeepAlive}) -> shutdown(keepalive_error, State) end; +websocket_info({timeout, Timer, emit_stats}, + State = #state{stats_timer = Timer, + proto_state = ProtoState, + gc_state = GcState}) -> + ClientId = emqx_protocol:client_id(ProtoState), + ok = emqx_cm:set_chan_stats(ClientId, stats(State)), + NState = State#state{stats_timer = undefined}, + Limits = erlang:get(force_shutdown_policy), + case emqx_misc:conn_proc_mng_policy(Limits) of + continue -> + {ok, NState}; + hibernate -> + %% going to hibernate, reset gc stats + GcState1 = emqx_gc:reset(GcState), + {ok, NState#state{gc_state = GcState1}, hibernate}; + {shutdown, Reason} -> + ?LOG(error, "Shutdown exceptionally due to ~p", [Reason]), + shutdown(Reason, NState) + end; + +websocket_info({timeout, Timer, Msg}, + State = #state{proto_state = ProtoState}) -> + case emqx_protocol:handle_timeout(Timer, Msg, ProtoState) of + {ok, NProtoState} -> + {ok, State#state{proto_state = NProtoState}}; + {ok, Packets, NProtoState} -> + handle_outgoing(Packets, State#state{proto_state = NProtoState}); + {error, Reason} -> + shutdown(Reason, State); + {error, Reason, NProtoState} -> + shutdown(Reason, State#state{proto_state = NProtoState}) + end; + websocket_info({shutdown, discard, {ClientId, ByPid}}, State) -> ?LOG(warning, "Discarded by ~s:~p", [ClientId, ByPid]), shutdown(discard, State); @@ -302,58 +336,123 @@ websocket_info(Info, State) -> ?LOG(error, "Unexpected info: ~p", [Info]), {ok, State}. -terminate(SockError, _Req, #state{keepalive = Keepalive, - chan_state = ChanState, - shutdown = Shutdown}) -> +terminate(SockError, _Req, #state{keepalive = Keepalive, + proto_state = ProtoState, + shutdown = Shutdown}) -> ?LOG(debug, "Terminated for ~p, sockerror: ~p", [Shutdown, SockError]), emqx_keepalive:cancel(Keepalive), - case {ChanState, Shutdown} of + case {ProtoState, Shutdown} of {undefined, _} -> ok; {_, {shutdown, Reason}} -> - emqx_channel:terminate(Reason, ChanState); + emqx_protocol:terminate(Reason, ProtoState); {_, Error} -> - emqx_channel:terminate(Error, ChanState) + emqx_protocol:terminate(Error, ProtoState) end. %%-------------------------------------------------------------------- -%% Internal functions -%%-------------------------------------------------------------------- +%% Ensure keepalive -handle_incoming(Packet = ?PACKET(Type), SuccFun, - State = #state{chan_state = ChanState}) -> +ensure_keepalive(0, _State) -> + ignore; +ensure_keepalive(Interval, State = #state{proto_state = ProtoState}) -> + Backoff = emqx_zone:get_env(emqx_protocol:info(zone, ProtoState), + keepalive_backoff, 0.75), + case emqx_keepalive:start(stat_fun(), round(Interval * Backoff), {keepalive, check}) of + {ok, KeepAlive} -> + {ok, State#state{keepalive = KeepAlive}}; + {error, Error} -> + ?LOG(warning, "Keepalive error: ~p", [Error]), + shutdown(Error, State) + end. + +%%-------------------------------------------------------------------- +%% Process incoming data + +process_incoming(<<>>, State) -> + {ok, State}; + +process_incoming(Data, State = #state{parse_state = ParseState}) -> + try emqx_frame:parse(Data, ParseState) of + {ok, NParseState} -> + {ok, State#state{parse_state = NParseState}}; + {ok, Packet, Rest, NParseState} -> + self() ! {incoming, Packet}, + process_incoming(Rest, State#state{parse_state = NParseState}); + {error, Reason} -> + ?LOG(error, "Frame error: ~p", [Reason]), + shutdown(Reason, State) + catch + error:Reason:Stk -> + ?LOG(error, "Parse failed for ~p~n\ + Stacktrace:~p~nFrame data: ~p", [Reason, Stk, Data]), + shutdown(parse_error, State) + end. + +%%-------------------------------------------------------------------- +%% Handle incoming packets + +handle_incoming(Packet = ?PACKET(Type), State = #state{proto_state = ProtoState}) -> _ = inc_incoming_stats(Type), - case emqx_channel:handle_in(Packet, ChanState) of - {ok, NChanState} -> - SuccFun(State#state{chan_state = NChanState}); - {ok, OutPacket, NChanState} -> - %% TODO: SuccFun, - handle_outgoing(OutPacket, State#state{chan_state = NChanState}); - {error, Reason, NChanState} -> - shutdown(Reason, State#state{chan_state = NChanState}); - {stop, Error, NChanState} -> - shutdown(Error, State#state{chan_state = NChanState}) + ok = emqx_metrics:inc_recv(Packet), + ?LOG(debug, "RECV ~s", [emqx_packet:format(Packet)]), + case emqx_protocol:handle_in(Packet, ProtoState) of + {ok, NProtoState} -> + {ok, State#state{proto_state = NProtoState}}; + {ok, OutPackets, NProtoState} -> + handle_outgoing(OutPackets, State#state{proto_state = NProtoState}); + {error, Reason, NProtoState} -> + shutdown(Reason, State#state{proto_state = NProtoState}); + {stop, Error, NProtoState} -> + shutdown(Error, State#state{proto_state = NProtoState}) end. -handle_outgoing(Packet = ?PACKET(Type), State = #state{chan_state = ChanState}) -> - ProtoVer = emqx_channel:info(proto_ver, ChanState), - Data = emqx_frame:serialize(Packet, ProtoVer), - BinSize = iolist_size(Data), - _ = inc_outgoing_stats(Type, BinSize), - {reply, {binary, Data}, ensure_stats_timer(State)}. +%%-------------------------------------------------------------------- +%% Handle outgoing packets + +handle_outgoing(Packets, State = #state{serialize = Serialize}) + when is_list(Packets) -> + reply(lists:map(Serialize, Packets), State); + +handle_outgoing(Packet, State = #state{serialize = Serialize}) -> + reply(Serialize(Packet), State). + +%%-------------------------------------------------------------------- +%% Serialize fun + +serialize_fun(ProtoVer) -> + fun(Packet = ?PACKET(Type)) -> + ?LOG(debug, "SEND ~s", [emqx_packet:format(Packet)]), + _ = inc_outgoing_stats(Type), + emqx_frame:serialize(Packet, ProtoVer) + end. + +%%-------------------------------------------------------------------- +%% Inc incoming/outgoing stats inc_incoming_stats(Type) -> emqx_pd:update_counter(recv_pkt, 1), (Type == ?PUBLISH) andalso emqx_pd:update_counter(recv_msg, 1). -inc_outgoing_stats(Type, BinSize) -> +inc_outgoing_stats(Type) -> emqx_pd:update_counter(send_cnt, 1), - emqx_pd:update_counter(send_oct, BinSize), emqx_pd:update_counter(send_pkt, 1), (Type == ?PUBLISH) andalso emqx_pd:update_counter(send_msg, 1). +%%-------------------------------------------------------------------- +%% Reply data + +-compile({inline, [reply/2]}). +reply(Data, State) -> + BinSize = iolist_size(Data), + emqx_pd:update_counter(send_oct, BinSize), + {reply, {binary, Data}, State}. + +%%-------------------------------------------------------------------- +%% Ensure stats timer + ensure_stats_timer(State = #state{stats_timer = undefined, idle_timeout = IdleTimeout}) -> TRef = emqx_misc:start_timer(IdleTimeout, emit_stats), @@ -361,6 +460,16 @@ ensure_stats_timer(State = #state{stats_timer = undefined, %% disabled or timer existed ensure_stats_timer(State) -> State. +%%-------------------------------------------------------------------- +%% Maybe GC + +maybe_gc(_Cnt, _Oct, State = #state{gc_state = undefined}) -> + State; +maybe_gc(Cnt, Oct, State = #state{gc_state = GCSt}) -> + {Ok, GCSt1} = emqx_gc:run(Cnt, Oct, GCSt), + Ok andalso emqx_metrics:inc('channel.gc.cnt'), + State#state{gc_state = GCSt1}. + -compile({inline, [shutdown/2]}). shutdown(Reason, State) -> %% Fix the issue#2591(https://github.com/emqx/emqx/issues/2591#issuecomment-500278696) diff --git a/test/emqx_inflight_SUITE.erl b/test/emqx_inflight_SUITE.erl index d0373e11b..533194f1e 100644 --- a/test/emqx_inflight_SUITE.erl +++ b/test/emqx_inflight_SUITE.erl @@ -1,4 +1,5 @@ -%% Copyright (c) 2013-2019 EMQ Technologies Co., Ltd. All Rights Reserved. +%%-------------------------------------------------------------------- +%% Copyright (c) 2019 EMQ Technologies Co., Ltd. All Rights Reserved. %% %% Licensed under the Apache License, Version 2.0 (the "License"); %% you may not use this file except in compliance with the License. @@ -11,6 +12,7 @@ %% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. %% See the License for the specific language governing permissions and %% limitations under the License. +%%-------------------------------------------------------------------- -module(emqx_inflight_SUITE). @@ -39,3 +41,4 @@ t_inflight_all(_) -> [1, 2] = emqx_inflight:values(Inflight2), [{a, 1}, {b ,2}] = emqx_inflight:to_list(Inflight2), [a, b] = emqx_inflight:window(Inflight2). + diff --git a/test/emqx_topic_SUITE.erl b/test/emqx_topic_SUITE.erl index 0a51598ee..202f8beb8 100644 --- a/test/emqx_topic_SUITE.erl +++ b/test/emqx_topic_SUITE.erl @@ -1,4 +1,5 @@ -%% Copyright (c) 2013-2019 EMQ Technologies Co., Ltd. All Rights Reserved. +%%-------------------------------------------------------------------- +%% Copyright (c) 2019 EMQ Technologies Co., Ltd. All Rights Reserved. %% %% Licensed under the Apache License, Version 2.0 (the "License"); %% you may not use this file except in compliance with the License. @@ -11,6 +12,7 @@ %% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. %% See the License for the specific language governing permissions and %% limitations under the License. +%%-------------------------------------------------------------------- -module(emqx_topic_SUITE). @@ -21,15 +23,15 @@ -compile(nowarn_export_all). -import(emqx_topic, - [wildcard/1, - match/2, - validate/1, - triples/1, - join/1, - words/1, - systop/1, - feed_var/3, - parse/1 + [ wildcard/1 + , match/2 + , validate/1 + , triples/1 + , join/1 + , words/1 + , systop/1 + , feed_var/3 + , parse/1 ]). -define(N, 10000). @@ -218,6 +220,7 @@ long_topic() -> t_parse(_) -> ?assertEqual({<<"a/b/+/#">>, #{}}, parse(<<"a/b/+/#">>)), + ?assertEqual({<<"a/b/+/#">>, #{qos => 1}}, parse({<<"a/b/+/#">>, #{qos => 1}})), ?assertEqual({<<"topic">>, #{ share => <<"$queue">> }}, parse(<<"$queue/topic">>)), ?assertEqual({<<"topic">>, #{ share => <<"group">>}}, parse(<<"$share/group/topic">>)), ?assertEqual({<<"$local/topic">>, #{}}, parse(<<"$local/topic">>)),