Improve the channel design

This commit is contained in:
Feng Lee 2019-06-25 14:53:45 +08:00
parent c4eb283517
commit 4974eab20e
8 changed files with 757 additions and 1044 deletions

View File

@ -286,7 +286,7 @@ dispatch(Topic, Delivery = #delivery{message = Msg, results = Results}) ->
dispatch(SubPid, Topic, Msg) when is_pid(SubPid) -> dispatch(SubPid, Topic, Msg) when is_pid(SubPid) ->
case erlang:is_process_alive(SubPid) of case erlang:is_process_alive(SubPid) of
true -> true ->
SubPid ! {dispatch, Topic, Msg}, SubPid ! {deliver, Topic, Msg},
1; 1;
false -> 0 false -> 0
end; end;

View File

@ -14,6 +14,7 @@
%% limitations under the License. %% limitations under the License.
%%-------------------------------------------------------------------- %%--------------------------------------------------------------------
%% MQTT TCP/SSL Channel
-module(emqx_channel). -module(emqx_channel).
-behaviour(gen_statem). -behaviour(gen_statem).
@ -21,6 +22,7 @@
-include("emqx.hrl"). -include("emqx.hrl").
-include("emqx_mqtt.hrl"). -include("emqx_mqtt.hrl").
-include("logger.hrl"). -include("logger.hrl").
-include("types.hrl").
-logger_header("[Channel]"). -logger_header("[Channel]").
@ -32,16 +34,10 @@
, stats/1 , stats/1
]). ]).
-export([ kick/1
, discard/1
, takeover/1
]).
-export([session/1]).
%% gen_statem callbacks %% gen_statem callbacks
-export([ idle/3 -export([ idle/3
, connected/3 , connected/3
, disconnected/3
]). ]).
-export([ init/1 -export([ init/1
@ -51,28 +47,32 @@
]). ]).
-record(state, { -record(state, {
transport, transport :: esockd:transport(),
socket, socket :: esockd:sock(),
peername, peername :: {inet:ip_address(), inet:port_number()},
sockname, sockname :: {inet:ip_address(), inet:port_number()},
conn_state, conn_state :: running | blocked,
active_n, active_n :: pos_integer(),
proto_state, rate_limit :: maybe(esockd_rate_limit:bucket()),
parse_state, pub_limit :: maybe(esockd_rate_limit:bucket()),
gc_state, limit_timer :: maybe(reference()),
keepalive, serializer :: emqx_frame:serializer(), %% TODO: remove it later.
rate_limit, parse_state :: emqx_frame:parse_state(),
pub_limit, proto_state :: emqx_protocol:protocol(),
limit_timer, gc_state :: emqx_gc:gc_state(),
enable_stats, keepalive :: maybe(reference()),
stats_timer, enable_stats :: boolean(),
idle_timeout stats_timer :: maybe(reference()),
idle_timeout :: timeout()
}). }).
-define(ACTIVE_N, 100). -define(ACTIVE_N, 100).
-define(HANDLE(T, C, D), handle((T), (C), (D))). -define(HANDLE(T, C, D), handle((T), (C), (D))).
-define(CHAN_STATS, [recv_pkt, recv_msg, send_pkt, send_msg]).
-define(SOCK_STATS, [recv_oct, recv_cnt, send_oct, send_cnt, send_pend]). -define(SOCK_STATS, [recv_oct, recv_cnt, send_oct, send_cnt, send_pend]).
-spec(start_link(esockd:transport(), esockd:sock(), proplists:proplist())
-> {ok, pid()}).
start_link(Transport, Socket, Options) -> start_link(Transport, Socket, Options) ->
{ok, proc_lib:spawn_link(?MODULE, init, [{Transport, Socket, Options}])}. {ok, proc_lib:spawn_link(?MODULE, init, [{Transport, Socket, Options}])}.
@ -126,28 +126,13 @@ attrs(#state{peername = Peername,
stats(CPid) when is_pid(CPid) -> stats(CPid) when is_pid(CPid) ->
call(CPid, stats); call(CPid, stats);
stats(#state{transport = Transport, stats(#state{transport = Transport, socket = Socket}) ->
socket = Socket,
proto_state = ProtoState}) ->
SockStats = case Transport:getstat(Socket, ?SOCK_STATS) of SockStats = case Transport:getstat(Socket, ?SOCK_STATS) of
{ok, Ss} -> Ss; {ok, Ss} -> Ss;
{error, _} -> [] {error, _} -> []
end, end,
lists:append([SockStats, ChanStats = [{Name, emqx_pd:get_counter(Name)} || Name <- ?CHAN_STATS],
emqx_misc:proc_stats(), lists:append([SockStats, ChanStats, emqx_misc:proc_stats()]).
emqx_protocol:stats(ProtoState)]).
kick(CPid) ->
call(CPid, kick).
discard(CPid) ->
call(CPid, discard).
takeover(CPid) ->
call(CPid, takeover).
session(CPid) ->
call(CPid, session).
call(CPid, Req) -> call(CPid, Req) ->
gen_statem:call(CPid, Req, infinity). gen_statem:call(CPid, Req, infinity).
@ -166,23 +151,15 @@ init({Transport, RawSocket, Options}) ->
RateLimit = init_limiter(proplists:get_value(rate_limit, Options)), RateLimit = init_limiter(proplists:get_value(rate_limit, Options)),
PubLimit = init_limiter(emqx_zone:get_env(Zone, publish_limit)), PubLimit = init_limiter(emqx_zone:get_env(Zone, publish_limit)),
ActiveN = proplists:get_value(active_n, Options, ?ACTIVE_N), ActiveN = proplists:get_value(active_n, Options, ?ACTIVE_N),
SendFun = fun(Packet, Opts) -> MaxSize = emqx_zone:get_env(Zone, max_packet_size, ?MAX_PACKET_SIZE),
Data = emqx_frame:serialize(Packet, Opts), ParseState = emqx_frame:initial_parse_state(#{max_size => MaxSize}),
case Transport:async_send(Socket, Data) of
ok -> {ok, Data};
{error, Reason} ->
{error, Reason}
end
end,
ProtoState = emqx_protocol:init(#{peername => Peername, ProtoState = emqx_protocol:init(#{peername => Peername,
sockname => Sockname, sockname => Sockname,
peercert => Peercert, peercert => Peercert,
sendfun => SendFun,
conn_mod => ?MODULE}, Options), conn_mod => ?MODULE}, Options),
MaxSize = emqx_zone:get_env(Zone, max_packet_size, ?MAX_PACKET_SIZE),
ParseState = emqx_frame:initial_parse_state(#{max_size => MaxSize}),
GcPolicy = emqx_zone:get_env(Zone, force_gc_policy, false), GcPolicy = emqx_zone:get_env(Zone, force_gc_policy, false),
GcState = emqx_gc:init(GcPolicy), GcState = emqx_gc:init(GcPolicy),
ok = emqx_misc:init_proc_mng_policy(Zone),
EnableStats = emqx_zone:get_env(Zone, enable_stats, true), EnableStats = emqx_zone:get_env(Zone, enable_stats, true),
IdleTimout = emqx_zone:get_env(Zone, idle_timeout, 30000), IdleTimout = emqx_zone:get_env(Zone, idle_timeout, 30000),
State = #state{transport = Transport, State = #state{transport = Transport,
@ -192,13 +169,12 @@ init({Transport, RawSocket, Options}) ->
active_n = ActiveN, active_n = ActiveN,
rate_limit = RateLimit, rate_limit = RateLimit,
pub_limit = PubLimit, pub_limit = PubLimit,
proto_state = ProtoState,
parse_state = ParseState, parse_state = ParseState,
proto_state = ProtoState,
gc_state = GcState, gc_state = GcState,
enable_stats = EnableStats, enable_stats = EnableStats,
idle_timeout = IdleTimout idle_timeout = IdleTimout
}, },
ok = emqx_misc:init_proc_mng_policy(Zone),
gen_statem:enter_loop(?MODULE, [{hibernate_after, 2 * IdleTimout}], gen_statem:enter_loop(?MODULE, [{hibernate_after, 2 * IdleTimout}],
idle, State, self(), [IdleTimout]). idle, State, self(), [IdleTimout]).
@ -218,12 +194,17 @@ idle(enter, _, State) ->
keep_state_and_data; keep_state_and_data;
idle(timeout, _Timeout, State) -> idle(timeout, _Timeout, State) ->
{stop, idle_timeout, State}; stop(idle_timeout, State);
idle(cast, {incoming, Packet = ?CONNECT_PACKET(ConnVar)}, State) ->
#mqtt_packet_connect{proto_ver = ProtoVer} = ConnVar,
Serializer = emqx_frame:init_serializer(#{version => ProtoVer}),
NState = State#state{serializer = Serializer},
handle_incoming(Packet, fun(St) -> {next_state, connected, St} end, NState);
idle(cast, {incoming, Packet}, State) -> idle(cast, {incoming, Packet}, State) ->
handle_incoming(Packet, fun(NState) -> ?LOG(warning, "Unexpected incoming: ~p", [Packet]),
{next_state, connected, NState} shutdown(unexpected_incoming_packet, State);
end, State);
idle(EventType, Content, State) -> idle(EventType, Content, State) ->
?HANDLE(EventType, Content, State). ?HANDLE(EventType, Content, State).
@ -235,18 +216,23 @@ connected(enter, _, _State) ->
%% What to do? %% What to do?
keep_state_and_data; keep_state_and_data;
%% Handle Input connected(cast, {incoming, Packet = ?PACKET(?CONNECT)}, State) ->
?LOG(warning, "Unexpected connect: ~p", [Packet]),
shutdown(unexpected_incoming_connect, State);
connected(cast, {incoming, Packet = ?PACKET(Type)}, State) -> connected(cast, {incoming, Packet = ?PACKET(Type)}, State) ->
ok = emqx_metrics:inc_recv(Packet), ok = emqx_metrics:inc_recv(Packet),
(Type == ?PUBLISH) andalso emqx_pd:update_counter(incoming_pubs, 1), (Type == ?PUBLISH) andalso emqx_pd:update_counter(incoming_pubs, 1),
handle_incoming(Packet, fun(NState) -> {keep_state, NState} end, State); handle_incoming(Packet, fun(St) -> {keep_state, St} end, State);
%% Handle Output %% Handle delivery
connected(info, {deliver, PubOrAck}, State = #state{proto_state = ProtoState}) -> connected(info, Devliery = {deliver, _Topic, Msg}, State = #state{proto_state = ProtoState}) ->
case emqx_protocol:deliver(PubOrAck, ProtoState) of case emqx_protocol:handle_out(Devliery, ProtoState) of
{ok, NProtoState} -> {ok, NProtoState} ->
{keep_state, State#state{proto_state = NProtoState}};
{ok, Packet, NProtoState} ->
NState = State#state{proto_state = NProtoState}, NState = State#state{proto_state = NProtoState},
{keep_state, maybe_gc(PubOrAck, NState)}; handle_outgoing(Packet, fun(St) -> {keep_state, St} end, NState);
{error, Reason} -> {error, Reason} ->
shutdown(Reason, State) shutdown(Reason, State)
end; end;
@ -281,6 +267,16 @@ connected(info, {keepalive, check}, State = #state{keepalive = KeepAlive}) ->
connected(EventType, Content, State) -> connected(EventType, Content, State) ->
?HANDLE(EventType, Content, State). ?HANDLE(EventType, Content, State).
%%--------------------------------------------------------------------
%% Disconnected State
disconnected(enter, _, _State) ->
%% TODO: What to do?
keep_state_and_data;
disconnected(EventType, Content, State) ->
?HANDLE(EventType, Content, State).
%% Handle call %% Handle call
handle({call, From}, info, State) -> handle({call, From}, info, State) ->
reply(From, info(State), State); reply(From, info(State), State);
@ -299,9 +295,6 @@ handle({call, From}, discard, State) ->
ok = gen_statem:reply(From, ok), ok = gen_statem:reply(From, ok),
shutdown(discard, State); shutdown(discard, State);
handle({call, From}, session, State = #state{proto_state = ProtoState}) ->
reply(From, emqx_protocol:session(ProtoState), State);
handle({call, From}, Req, State) -> handle({call, From}, Req, State) ->
?LOG(error, "Unexpected call: ~p", [Req]), ?LOG(error, "Unexpected call: ~p", [Req]),
reply(From, ignored, State); reply(From, ignored, State);
@ -312,7 +305,8 @@ handle(cast, Msg, State) ->
{keep_state, State}; {keep_state, State};
%% Handle Incoming %% Handle Incoming
handle(info, {Inet, _Sock, Data}, State) when Inet == tcp; Inet == ssl -> handle(info, {Inet, _Sock, Data}, State) when Inet == tcp;
Inet == ssl ->
Oct = iolist_size(Data), Oct = iolist_size(Data),
?LOG(debug, "RECV ~p", [Data]), ?LOG(debug, "RECV ~p", [Data]),
emqx_pd:update_counter(incoming_bytes, Oct), emqx_pd:update_counter(incoming_bytes, Oct),
@ -390,15 +384,9 @@ terminate(Reason, _StateName, #state{transport = Transport,
keepalive = KeepAlive, keepalive = KeepAlive,
proto_state = ProtoState}) -> proto_state = ProtoState}) ->
?LOG(debug, "Terminated for ~p", [Reason]), ?LOG(debug, "Terminated for ~p", [Reason]),
Transport:fast_close(Socket), ok = Transport:fast_close(Socket),
emqx_keepalive:cancel(KeepAlive), ok = emqx_keepalive:cancel(KeepAlive),
case {ProtoState, Reason} of emqx_protocol:terminate(Reason, ProtoState).
{undefined, _} -> ok;
{_, {shutdown, Error}} ->
emqx_protocol:terminate(Error, ProtoState);
{_, Reason} ->
emqx_protocol:terminate(Reason, ProtoState)
end.
%%-------------------------------------------------------------------- %%--------------------------------------------------------------------
%% Process incoming data %% Process incoming data
@ -431,10 +419,16 @@ next_events(Packet) ->
%%-------------------------------------------------------------------- %%--------------------------------------------------------------------
%% Handle incoming packet %% Handle incoming packet
handle_incoming(Packet, SuccFun, State = #state{proto_state = ProtoState}) -> handle_incoming(Packet = ?PACKET(Type), SuccFun,
case emqx_protocol:received(Packet, ProtoState) of State = #state{proto_state = ProtoState}) ->
_ = inc_incoming_stats(Type),
?LOG(debug, "RECV ~s", [emqx_packet:format(Packet)]),
case emqx_protocol:handle_in(Packet, ProtoState) of
{ok, NProtoState} -> {ok, NProtoState} ->
SuccFun(State#state{proto_state = NProtoState}); SuccFun(State#state{proto_state = NProtoState});
{ok, OutPacket, NProtoState} ->
handle_outgoing(OutPacket, SuccFun,
State#state{proto_state = NProtoState});
{error, Reason} -> {error, Reason} ->
shutdown(Reason, State); shutdown(Reason, State);
{error, Reason, NProtoState} -> {error, Reason, NProtoState} ->
@ -443,6 +437,22 @@ handle_incoming(Packet, SuccFun, State = #state{proto_state = ProtoState}) ->
stop(Error, State#state{proto_state = NProtoState}) stop(Error, State#state{proto_state = NProtoState})
end. end.
%%--------------------------------------------------------------------
%% Handle outgoing packet
handle_outgoing(Packet = ?PACKET(Type), SuccFun,
State = #state{transport = Transport,
socket = Socket,
serializer = Serializer}) ->
_ = inc_outgoing_stats(Type),
?LOG(debug, "SEND ~s", [emqx_packet:format(Packet)]),
Data = Serializer(Packet),
case Transport:async_send(Socket, Data) of
ok -> SuccFun(State);
{error, Reason} ->
shutdown(Reason, State)
end.
%%-------------------------------------------------------------------- %%--------------------------------------------------------------------
%% Ensure rate limit %% Ensure rate limit
@ -465,6 +475,12 @@ ensure_rate_limit([{Rl, Pos, Cnt}|Limiters], State) ->
setelement(Pos, State#state{conn_state = blocked, limit_timer = TRef}, Rl1) setelement(Pos, State#state{conn_state = blocked, limit_timer = TRef}, Rl1)
end. end.
%% start_keepalive(0, _PState) ->
%% ignore;
%% start_keepalive(Secs, #pstate{zone = Zone}) when Secs > 0 ->
%% Backoff = emqx_zone:get_env(Zone, keepalive_backoff, 0.75),
%% self() ! {keepalive, start, round(Secs * Backoff)}.
%%-------------------------------------------------------------------- %%--------------------------------------------------------------------
%% Activate socket %% Activate socket
@ -479,6 +495,17 @@ activate_socket(#state{transport = Transport, socket = Socket, active_n = N}) ->
ok ok
end. end.
%%--------------------------------------------------------------------
%% Inc incoming/outgoing stats
inc_incoming_stats(Type) ->
emqx_pd:update_counter(recv_pkt, 1),
Type =:= ?PUBLISH andalso emqx_pd:update_counter(recv_msg, 1).
inc_outgoing_stats(Type) ->
emqx_pd:update_counter(send_pkt, 1),
Type =:= ?PUBLISH andalso emqx_pd:update_counter(send_msg, 1).
%%-------------------------------------------------------------------- %%--------------------------------------------------------------------
%% Ensure stats timer %% Ensure stats timer

View File

@ -21,6 +21,7 @@
-export([ initial_parse_state/0 -export([ initial_parse_state/0
, initial_parse_state/1 , initial_parse_state/1
, init_serializer/1
]). ]).
-export([ parse/1 -export([ parse/1
@ -29,22 +30,22 @@
, serialize/2 , serialize/2
]). ]).
-export_type([ options/0
, parse_state/0
, parse_result/0
]).
-type(options() :: #{max_size => 1..?MAX_PACKET_SIZE, -type(options() :: #{max_size => 1..?MAX_PACKET_SIZE,
version => emqx_mqtt_types:version() version => emqx_mqtt:version()
}). }).
-opaque(parse_state() :: {none, options()} | {more, cont_fun()}). -opaque(parse_state() :: {none, options()} | {more, cont_fun()}).
-opaque(parse_result() :: {ok, parse_state()} -opaque(parse_result() :: {ok, parse_state()}
| {ok, emqx_mqtt_types:packet(), binary(), parse_state()}). | {ok, emqx_mqtt:packet(), binary(), parse_state()}).
-type(cont_fun() :: fun((binary()) -> parse_result())). -type(cont_fun() :: fun((binary()) -> parse_result())).
-export_type([ options/0
, parse_state/0
, parse_result/0
]).
-define(none(Opts), {none, Opts}). -define(none(Opts), {none, Opts}).
-define(more(Cont), {more, Cont}). -define(more(Cont), {more, Cont}).
-define(DEFAULT_OPTIONS, -define(DEFAULT_OPTIONS,
@ -385,11 +386,14 @@ parse_binary_data(<<Len:16/big, Data:Len/binary, Rest/binary>>) ->
%% Serialize MQTT Packet %% Serialize MQTT Packet
%%-------------------------------------------------------------------- %%--------------------------------------------------------------------
-spec(serialize(emqx_mqtt_types:packet()) -> iodata()). init_serializer(Options) ->
fun(Packet) -> serialize(Packet, Options) end.
-spec(serialize(emqx_mqtt:packet()) -> iodata()).
serialize(Packet) -> serialize(Packet) ->
serialize(Packet, ?DEFAULT_OPTIONS). serialize(Packet, ?DEFAULT_OPTIONS).
-spec(serialize(emqx_mqtt_types:packet(), options()) -> iodata()). -spec(serialize(emqx_mqtt:packet(), options()) -> iodata()).
serialize(#mqtt_packet{header = Header, serialize(#mqtt_packet{header = Header,
variable = Variable, variable = Variable,
payload = Payload}, Options) when is_map(Options) -> payload = Payload}, Options) when is_map(Options) ->

View File

@ -33,6 +33,8 @@
, window/1 , window/1
]). ]).
-export_type([inflight/0]).
-type(key() :: term()). -type(key() :: term()).
-type(max_size() :: pos_integer()). -type(max_size() :: pos_integer()).
@ -43,8 +45,6 @@
-define(Inflight(MaxSize, Tree), {?MODULE, MaxSize, (Tree)}). -define(Inflight(MaxSize, Tree), {?MODULE, MaxSize, (Tree)}).
-export_type([inflight/0]).
%%-------------------------------------------------------------------- %%--------------------------------------------------------------------
%% APIs %% APIs
%%-------------------------------------------------------------------- %%--------------------------------------------------------------------

View File

@ -14,11 +14,13 @@
%% limitations under the License. %% limitations under the License.
%%-------------------------------------------------------------------- %%--------------------------------------------------------------------
%% MQTT Protocol
-module(emqx_protocol). -module(emqx_protocol).
-include("emqx.hrl"). -include("emqx.hrl").
-include("emqx_mqtt.hrl"). -include("emqx_mqtt.hrl").
-include("logger.hrl"). -include("logger.hrl").
-include("types.hrl").
-logger_header("[Protocol]"). -logger_header("[Protocol]").
@ -27,53 +29,49 @@
, attr/2 , attr/2
, caps/1 , caps/1
, caps/2 , caps/2
, stats/1
, client_id/1 , client_id/1
, credentials/1 , credentials/1
, session/1 , session/1
]). ]).
-export([ init/2 -export([ init/2
, received/2 , handle_in/2
, process/2 , handle_out/2
, deliver/2 , handle_timeout/3
, send/2
, terminate/2 , terminate/2
]). ]).
-record(pstate, { -export_type([protocol/0]).
zone,
-record(protocol, {
zone :: emqx_zone:zone(),
conn_mod :: module(),
sendfun, sendfun,
sockname, sockname,
peername, peername,
peercert, peercert,
proto_ver, proto_ver :: emqx_mqtt:version(),
proto_name, proto_name,
client_id, client_id :: maybe(emqx_types:client_id()),
is_assigned, is_assigned,
username :: maybe(emqx_types:username()),
conn_props, conn_props,
ack_props, ack_props,
username, credentials :: map(),
session, session :: maybe(emqx_session:session()),
clean_start, clean_start,
topic_aliases, topic_aliases,
will_topic, will_topic,
will_msg, will_msg,
keepalive, keepalive,
is_bridge, is_bridge :: boolean(),
recv_stats, connected :: boolean(),
send_stats, connected_at :: erlang:timestamp(),
connected,
connected_at,
topic_alias_maximum, topic_alias_maximum,
conn_mod,
credentials,
ws_cookie ws_cookie
}). }).
-opaque(state() :: #pstate{}). -opaque(protocol() :: #protocol{}).
-export_type([state/0]).
-ifdef(TEST). -ifdef(TEST).
-compile(export_all). -compile(export_all).
@ -86,14 +84,13 @@
%% Init %% Init
%%-------------------------------------------------------------------- %%--------------------------------------------------------------------
-spec(init(map(), list()) -> state()). -spec(init(map(), list()) -> protocol()).
init(SocketOpts = #{sockname := Sockname, init(SocketOpts = #{sockname := Sockname,
peername := Peername, peername := Peername,
peercert := Peercert, peercert := Peercert}, Options) ->
sendfun := SendFun}, Options) ->
Zone = proplists:get_value(zone, Options), Zone = proplists:get_value(zone, Options),
#pstate{zone = Zone, #protocol{zone = Zone,
sendfun = SendFun, %%sendfun = SendFun,
sockname = Sockname, sockname = Sockname,
peername = Peername, peername = Peername,
peercert = Peercert, peercert = Peercert,
@ -101,18 +98,17 @@ init(SocketOpts = #{sockname := Sockname,
proto_name = <<"MQTT">>, proto_name = <<"MQTT">>,
client_id = <<>>, client_id = <<>>,
is_assigned = false, is_assigned = false,
%%conn_pid = self(),
username = init_username(Peercert, Options), username = init_username(Peercert, Options),
clean_start = false, clean_start = false,
topic_aliases = #{}, topic_aliases = #{},
is_bridge = false, is_bridge = false,
recv_stats = #{msg => 0, pkt => 0},
send_stats = #{msg => 0, pkt => 0},
connected = false, connected = false,
%% TODO: ...?
topic_alias_maximum = #{to_client => 0, from_client => 0}, topic_alias_maximum = #{to_client => 0, from_client => 0},
conn_mod = maps:get(conn_mod, SocketOpts, undefined), conn_mod = maps:get(conn_mod, SocketOpts, undefined),
credentials = #{}, credentials = #{},
ws_cookie = maps:get(ws_cookie, SocketOpts, undefined)}. ws_cookie = maps:get(ws_cookie, SocketOpts, undefined)
}.
init_username(Peercert, Options) -> init_username(Peercert, Options) ->
case proplists:get_value(peer_cert_as_username, Options) of case proplists:get_value(peer_cert_as_username, Options) of
@ -122,8 +118,8 @@ init_username(Peercert, Options) ->
_ -> undefined _ -> undefined
end. end.
set_username(Username, PState = #pstate{username = undefined}) -> set_username(Username, PState = #protocol{username = undefined}) ->
PState#pstate{username = Username}; PState#protocol{username = Username};
set_username(_Username, PState) -> set_username(_Username, PState) ->
PState. PState.
@ -131,7 +127,7 @@ set_username(_Username, PState) ->
%% API %% API
%%-------------------------------------------------------------------- %%--------------------------------------------------------------------
info(PState = #pstate{zone = Zone, info(PState = #protocol{zone = Zone,
conn_props = ConnProps, conn_props = ConnProps,
ack_props = AckProps, ack_props = AckProps,
session = Session, session = Session,
@ -145,7 +141,7 @@ info(PState = #pstate{zone = Zone,
enable_acl => emqx_zone:get_env(Zone, enable_acl, false) enable_acl => emqx_zone:get_env(Zone, enable_acl, false)
}). }).
attrs(#pstate{zone = Zone, attrs(#protocol{zone = Zone,
client_id = ClientId, client_id = ClientId,
username = Username, username = Username,
peername = Peername, peername = Peername,
@ -173,25 +169,25 @@ attrs(#pstate{zone = Zone,
credentials => Credentials credentials => Credentials
}. }.
attr(proto_ver, #pstate{proto_ver = ProtoVer}) -> attr(proto_ver, #protocol{proto_ver = ProtoVer}) ->
ProtoVer; ProtoVer;
attr(max_inflight, #pstate{proto_ver = ?MQTT_PROTO_V5, conn_props = ConnProps}) -> attr(max_inflight, #protocol{proto_ver = ?MQTT_PROTO_V5, conn_props = ConnProps}) ->
get_property('Receive-Maximum', ConnProps, 65535); get_property('Receive-Maximum', ConnProps, 65535);
attr(max_inflight, #pstate{zone = Zone}) -> attr(max_inflight, #protocol{zone = Zone}) ->
emqx_zone:get_env(Zone, max_inflight, 65535); emqx_zone:get_env(Zone, max_inflight, 65535);
attr(expiry_interval, #pstate{proto_ver = ?MQTT_PROTO_V5, conn_props = ConnProps}) -> attr(expiry_interval, #protocol{proto_ver = ?MQTT_PROTO_V5, conn_props = ConnProps}) ->
get_property('Session-Expiry-Interval', ConnProps, 0); get_property('Session-Expiry-Interval', ConnProps, 0);
attr(expiry_interval, #pstate{zone = Zone, clean_start = CleanStart}) -> attr(expiry_interval, #protocol{zone = Zone, clean_start = CleanStart}) ->
case CleanStart of case CleanStart of
true -> 0; true -> 0;
false -> emqx_zone:get_env(Zone, session_expiry_interval, 16#ffffffff) false -> emqx_zone:get_env(Zone, session_expiry_interval, 16#ffffffff)
end; end;
attr(topic_alias_maximum, #pstate{proto_ver = ?MQTT_PROTO_V5, conn_props = ConnProps}) -> attr(topic_alias_maximum, #protocol{proto_ver = ?MQTT_PROTO_V5, conn_props = ConnProps}) ->
get_property('Topic-Alias-Maximum', ConnProps, 0); get_property('Topic-Alias-Maximum', ConnProps, 0);
attr(topic_alias_maximum, #pstate{zone = Zone}) -> attr(topic_alias_maximum, #protocol{zone = Zone}) ->
emqx_zone:get_env(Zone, max_topic_alias, 0); emqx_zone:get_env(Zone, max_topic_alias, 0);
attr(Name, PState) -> attr(Name, PState) ->
Attrs = lists:zip(record_info(fields, pstate), tl(tuple_to_list(PState))), Attrs = lists:zip(record_info(fields, protocol), tl(tuple_to_list(PState))),
case lists:keyfind(Name, 1, Attrs) of case lists:keyfind(Name, 1, Attrs) of
{_, Value} -> Value; {_, Value} -> Value;
false -> undefined false -> undefined
@ -200,13 +196,13 @@ attr(Name, PState) ->
caps(Name, PState) -> caps(Name, PState) ->
maps:get(Name, caps(PState)). maps:get(Name, caps(PState)).
caps(#pstate{zone = Zone}) -> caps(#protocol{zone = Zone}) ->
emqx_mqtt_caps:get_caps(Zone). emqx_mqtt_caps:get_caps(Zone).
client_id(#pstate{client_id = ClientId}) -> client_id(#protocol{client_id = ClientId}) ->
ClientId. ClientId.
credentials(#pstate{zone = Zone, credentials(#protocol{zone = Zone,
client_id = ClientId, client_id = ClientId,
username = Username, username = Username,
sockname = Sockname, sockname = Sockname,
@ -232,79 +228,61 @@ keepsafety(Credentials) ->
(cn, _) -> false; (cn, _) -> false;
(_, _) -> true end, Credentials). (_, _) -> true end, Credentials).
stats(#pstate{recv_stats = #{pkt := RecvPkt, msg := RecvMsg}, session(#protocol{session = Session}) ->
send_stats = #{pkt := SendPkt, msg := SendMsg}}) ->
[{recv_pkt, RecvPkt},
{recv_msg, RecvMsg},
{send_pkt, SendPkt},
{send_msg, SendMsg}].
session(#pstate{session = Session}) ->
Session. Session.
%%------------------------------------------------------------------------------ %%--------------------------------------------------------------------
%% Packet Received %% Packet Received
%%------------------------------------------------------------------------------ %%--------------------------------------------------------------------
set_protover(?CONNECT_PACKET(#mqtt_packet_connect{proto_ver = ProtoVer}), PState) -> set_protover(?CONNECT_PACKET(#mqtt_packet_connect{proto_ver = ProtoVer}), PState) ->
PState#pstate{proto_ver = ProtoVer}; PState#protocol{proto_ver = ProtoVer};
set_protover(_Packet, PState) -> set_protover(_Packet, PState) ->
PState. PState.
-spec(received(emqx_mqtt_types:packet(), state()) handle_in(?PACKET(Type), PState = #protocol{connected = false}) when Type =/= ?CONNECT ->
-> {ok, state()}
| {error, term()}
| {error, term(), state()}
| {stop, term(), state()}).
received(?PACKET(Type), PState = #pstate{connected = false}) when Type =/= ?CONNECT ->
{error, proto_not_connected, PState}; {error, proto_not_connected, PState};
received(?PACKET(?CONNECT), PState = #pstate{connected = true}) -> handle_in(?PACKET(?CONNECT), PState = #protocol{connected = true}) ->
{error, proto_unexpected_connect, PState}; {error, proto_unexpected_connect, PState};
received(Packet = ?PACKET(Type), PState) -> handle_in(Packet = ?PACKET(_Type), PState) ->
trace(recv, Packet),
PState1 = set_protover(Packet, PState), PState1 = set_protover(Packet, PState),
try emqx_packet:validate(Packet) of try emqx_packet:validate(Packet) of
true -> true ->
case preprocess_properties(Packet, PState1) of case preprocess_properties(Packet, PState1) of
{ok, Packet1, PState2} -> {ok, Packet1, PState2} ->
process(Packet1, inc_stats(recv, Type, PState2)); process(Packet1, PState2);
{error, ReasonCode} -> {error, ReasonCode} ->
{error, ReasonCode, PState1} handle_out({disconnect, ReasonCode}, PState1)
end end
catch catch
error:protocol_error -> error:protocol_error ->
deliver({disconnect, ?RC_PROTOCOL_ERROR}, PState1), handle_out({disconnect, ?RC_PROTOCOL_ERROR}, PState1);
{error, protocol_error, PState};
error:subscription_identifier_invalid -> error:subscription_identifier_invalid ->
deliver({disconnect, ?RC_SUBSCRIPTION_IDENTIFIERS_NOT_SUPPORTED}, PState1), handle_out({disconnect, ?RC_SUBSCRIPTION_IDENTIFIERS_NOT_SUPPORTED}, PState1);
{error, subscription_identifier_invalid, PState1};
error:topic_alias_invalid -> error:topic_alias_invalid ->
deliver({disconnect, ?RC_TOPIC_ALIAS_INVALID}, PState1), handle_out({disconnect, ?RC_TOPIC_ALIAS_INVALID}, PState1);
{error, topic_alias_invalid, PState1};
error:topic_filters_invalid -> error:topic_filters_invalid ->
deliver({disconnect, ?RC_TOPIC_FILTER_INVALID}, PState1), handle_out({disconnect, ?RC_TOPIC_FILTER_INVALID}, PState1);
{error, topic_filters_invalid, PState1};
error:topic_name_invalid -> error:topic_name_invalid ->
deliver({disconnect, ?RC_TOPIC_FILTER_INVALID}, PState1), handle_out({disconnect, ?RC_TOPIC_FILTER_INVALID}, PState1);
{error, topic_filters_invalid, PState1}; error:_Reason ->
error:Reason -> %% TODO: {error, Reason, PState1}
deliver({disconnect, ?RC_MALFORMED_PACKET}, PState1), handle_out({disconnect, ?RC_MALFORMED_PACKET}, PState1)
{error, Reason, PState1}
end. end.
%%------------------------------------------------------------------------------ %%--------------------------------------------------------------------
%% Preprocess MQTT Properties %% Preprocess MQTT Properties
%%------------------------------------------------------------------------------ %%--------------------------------------------------------------------
preprocess_properties(Packet = #mqtt_packet{ preprocess_properties(Packet = #mqtt_packet{
variable = #mqtt_packet_connect{ variable = #mqtt_packet_connect{
properties = #{'Topic-Alias-Maximum' := ToClient} properties = #{'Topic-Alias-Maximum' := ToClient}
} }
}, },
PState = #pstate{topic_alias_maximum = TopicAliasMaximum}) -> PState = #protocol{topic_alias_maximum = TopicAliasMaximum}) ->
{ok, Packet, PState#pstate{topic_alias_maximum = TopicAliasMaximum#{to_client => ToClient}}}; {ok, Packet, PState#protocol{topic_alias_maximum = TopicAliasMaximum#{to_client => ToClient}}};
%% Subscription Identifier %% Subscription Identifier
preprocess_properties(Packet = #mqtt_packet{ preprocess_properties(Packet = #mqtt_packet{
@ -313,7 +291,7 @@ preprocess_properties(Packet = #mqtt_packet{
topic_filters = TopicFilters topic_filters = TopicFilters
} }
}, },
PState = #pstate{proto_ver = ?MQTT_PROTO_V5}) -> PState = #protocol{proto_ver = ?MQTT_PROTO_V5}) ->
TopicFilters1 = [{Topic, SubOpts#{subid => SubId}} || {Topic, SubOpts} <- TopicFilters], TopicFilters1 = [{Topic, SubOpts#{subid => SubId}} || {Topic, SubOpts} <- TopicFilters],
{ok, Packet#mqtt_packet{variable = Subscribe#mqtt_packet_subscribe{topic_filters = TopicFilters1}}, PState}; {ok, Packet#mqtt_packet{variable = Subscribe#mqtt_packet_subscribe{topic_filters = TopicFilters1}}, PState};
@ -323,7 +301,6 @@ preprocess_properties(#mqtt_packet{
properties = #{'Topic-Alias' := 0}} properties = #{'Topic-Alias' := 0}}
}, },
PState) -> PState) ->
deliver({disconnect, ?RC_TOPIC_ALIAS_INVALID}, PState),
{error, ?RC_TOPIC_ALIAS_INVALID}; {error, ?RC_TOPIC_ALIAS_INVALID};
preprocess_properties(Packet = #mqtt_packet{ preprocess_properties(Packet = #mqtt_packet{
@ -331,7 +308,7 @@ preprocess_properties(Packet = #mqtt_packet{
topic_name = <<>>, topic_name = <<>>,
properties = #{'Topic-Alias' := AliasId}} properties = #{'Topic-Alias' := AliasId}}
}, },
PState = #pstate{proto_ver = ?MQTT_PROTO_V5, PState = #protocol{proto_ver = ?MQTT_PROTO_V5,
topic_aliases = Aliases, topic_aliases = Aliases,
topic_alias_maximum = #{from_client := TopicAliasMaximum}}) -> topic_alias_maximum = #{from_client := TopicAliasMaximum}}) ->
case AliasId =< TopicAliasMaximum of case AliasId =< TopicAliasMaximum of
@ -339,7 +316,6 @@ preprocess_properties(Packet = #mqtt_packet{
{ok, Packet#mqtt_packet{variable = Publish#mqtt_packet_publish{ {ok, Packet#mqtt_packet{variable = Publish#mqtt_packet_publish{
topic_name = maps:get(AliasId, Aliases, <<>>)}}, PState}; topic_name = maps:get(AliasId, Aliases, <<>>)}}, PState};
false -> false ->
deliver({disconnect, ?RC_TOPIC_ALIAS_INVALID}, PState),
{error, ?RC_TOPIC_ALIAS_INVALID} {error, ?RC_TOPIC_ALIAS_INVALID}
end; end;
@ -348,23 +324,22 @@ preprocess_properties(Packet = #mqtt_packet{
topic_name = Topic, topic_name = Topic,
properties = #{'Topic-Alias' := AliasId}} properties = #{'Topic-Alias' := AliasId}}
}, },
PState = #pstate{proto_ver = ?MQTT_PROTO_V5, PState = #protocol{proto_ver = ?MQTT_PROTO_V5,
topic_aliases = Aliases, topic_aliases = Aliases,
topic_alias_maximum = #{from_client := TopicAliasMaximum}}) -> topic_alias_maximum = #{from_client := TopicAliasMaximum}}) ->
case AliasId =< TopicAliasMaximum of case AliasId =< TopicAliasMaximum of
true -> true ->
{ok, Packet, PState#pstate{topic_aliases = maps:put(AliasId, Topic, Aliases)}}; {ok, Packet, PState#protocol{topic_aliases = maps:put(AliasId, Topic, Aliases)}};
false -> false ->
deliver({disconnect, ?RC_TOPIC_ALIAS_INVALID}, PState),
{error, ?RC_TOPIC_ALIAS_INVALID} {error, ?RC_TOPIC_ALIAS_INVALID}
end; end;
preprocess_properties(Packet, PState) -> preprocess_properties(Packet, PState) ->
{ok, Packet, PState}. {ok, Packet, PState}.
%%------------------------------------------------------------------------------ %%--------------------------------------------------------------------
%% Process MQTT Packet %% Process MQTT Packet
%%------------------------------------------------------------------------------ %%--------------------------------------------------------------------
process(?CONNECT_PACKET( process(?CONNECT_PACKET(
#mqtt_packet_connect{proto_name = ProtoName, #mqtt_packet_connect{proto_name = ProtoName,
@ -381,7 +356,7 @@ process(?CONNECT_PACKET(
%% Msg -> emqx_mountpoint:mount(MountPoint, Msg) %% Msg -> emqx_mountpoint:mount(MountPoint, Msg)
PState0 = maybe_use_username_as_clientid(ClientId, PState0 = maybe_use_username_as_clientid(ClientId,
set_username(Username, set_username(Username,
PState#pstate{proto_ver = ProtoVer, PState#protocol{proto_ver = ProtoVer,
proto_name = ProtoName, proto_name = ProtoName,
clean_start = CleanStart, clean_start = CleanStart,
keepalive = Keepalive, keepalive = Keepalive,
@ -389,120 +364,115 @@ process(?CONNECT_PACKET(
is_bridge = IsBridge, is_bridge = IsBridge,
connected_at = os:timestamp()})), connected_at = os:timestamp()})),
NewClientId = PState0#pstate.client_id, NewClientId = PState0#protocol.client_id,
emqx_logger:set_metadata_client_id(NewClientId), emqx_logger:set_metadata_client_id(NewClientId),
Credentials = credentials(PState0), Credentials = credentials(PState0),
PState1 = PState0#pstate{credentials = Credentials}, PState1 = PState0#protocol{credentials = Credentials},
connack( connack(
case check_connect(ConnPkt, PState1) of case check_connect(ConnPkt, PState1) of
ok -> ok ->
case emqx_access_control:authenticate(Credentials#{password => Password}) of case emqx_access_control:authenticate(Credentials#{password => Password}) of
{ok, Credentials0} -> {ok, Credentials0} ->
PState3 = maybe_assign_client_id(PState1), PState3 = maybe_assign_client_id(PState1),
emqx_logger:set_metadata_client_id(PState3#pstate.client_id), emqx_logger:set_metadata_client_id(PState3#protocol.client_id),
%% Open session %% Open session
SessAttrs = #{will_msg => make_will_msg(ConnPkt)}, SessAttrs = #{will_msg => make_will_msg(ConnPkt)},
case try_open_session(SessAttrs, PState3) of case try_open_session(SessAttrs, PState3) of
{ok, Session, SP} -> {ok, Session, SP} ->
PState4 = PState3#pstate{session = Session, connected = true, PState4 = PState3#protocol{session = Session, connected = true,
credentials = keepsafety(Credentials0)}, credentials = keepsafety(Credentials0)},
ok = emqx_cm:register_channel(client_id(PState4)), ok = emqx_cm:register_channel(client_id(PState4)),
ok = emqx_cm:set_conn_attrs(client_id(PState4), attrs(PState4)), true = emqx_cm:set_conn_attrs(client_id(PState4), attrs(PState4)),
%% Start keepalive %% Start keepalive
start_keepalive(Keepalive, PState4), start_keepalive(Keepalive, PState4),
%% Success %% Success
{?RC_SUCCESS, SP, PState4}; {?RC_SUCCESS, SP, PState4};
{error, Error} -> {error, Error} ->
?LOG(error, "Failed to open session: ~p", [Error]), ?LOG(error, "Failed to open session: ~p", [Error]),
{?RC_UNSPECIFIED_ERROR, PState1#pstate{credentials = Credentials0}} {?RC_UNSPECIFIED_ERROR, PState1#protocol{credentials = Credentials0}}
end; end;
{error, Reason} -> {error, Reason} ->
?LOG(warning, "Client ~s (Username: '~s') login failed for ~p", [NewClientId, Username, Reason]), ?LOG(warning, "Client ~s (Username: '~s') login failed for ~p", [NewClientId, Username, Reason]),
{emqx_reason_codes:connack_error(Reason), PState1#pstate{credentials = Credentials}} {emqx_reason_codes:connack_error(Reason), PState1#protocol{credentials = Credentials}}
end; end;
{error, ReasonCode} -> {error, ReasonCode} ->
{ReasonCode, PState1} {ReasonCode, PState1}
end); end);
process(Packet = ?PUBLISH_PACKET(?QOS_0, Topic, _PacketId, _Payload), PState = #pstate{zone = Zone}) -> process(Packet = ?PUBLISH_PACKET(?QOS_0, Topic, _PacketId, _Payload), PState = #protocol{zone = Zone}) ->
case check_publish(Packet, PState) of case check_publish(Packet, PState) of
ok -> ok ->
do_publish(Packet, PState); do_publish(Packet, PState);
{error, ReasonCode} -> {error, ReasonCode} ->
?LOG(warning, "Cannot publish qos0 message to ~s for ~s", ?LOG(warning, "Cannot publish qos0 message to ~s for ~s",
[Topic, emqx_reason_codes:text(ReasonCode)]), [Topic, emqx_reason_codes:text(ReasonCode)]),
%% TODO: ...
AclDenyAction = emqx_zone:get_env(Zone, acl_deny_action, ignore), AclDenyAction = emqx_zone:get_env(Zone, acl_deny_action, ignore),
do_acl_deny_action(AclDenyAction, Packet, ReasonCode, PState) do_acl_deny_action(AclDenyAction, Packet, ReasonCode, PState)
end; end;
process(Packet = ?PUBLISH_PACKET(?QOS_1, Topic, PacketId, _Payload), PState = #pstate{zone = Zone}) -> process(Packet = ?PUBLISH_PACKET(?QOS_1, Topic, PacketId, _Payload), PState = #protocol{zone = Zone}) ->
case check_publish(Packet, PState) of case check_publish(Packet, PState) of
ok -> ok ->
do_publish(Packet, PState); do_publish(Packet, PState);
{error, ReasonCode} -> {error, ReasonCode} ->
?LOG(warning, "Cannot publish qos1 message to ~s for ~s", [Topic, emqx_reason_codes:text(ReasonCode)]), ?LOG(warning, "Cannot publish qos1 message to ~s for ~s",
case deliver({puback, PacketId, ReasonCode}, PState) of [Topic, emqx_reason_codes:text(ReasonCode)]),
{ok, PState1} -> handle_out({puback, PacketId, ReasonCode}, PState)
AclDenyAction = emqx_zone:get_env(Zone, acl_deny_action, ignore),
do_acl_deny_action(AclDenyAction, Packet, ReasonCode, PState1);
Error -> Error
end
end; end;
process(Packet = ?PUBLISH_PACKET(?QOS_2, Topic, PacketId, _Payload), PState = #pstate{zone = Zone}) -> process(Packet = ?PUBLISH_PACKET(?QOS_2, Topic, PacketId, _Payload), PState = #protocol{zone = Zone}) ->
case check_publish(Packet, PState) of case check_publish(Packet, PState) of
ok -> ok ->
do_publish(Packet, PState); do_publish(Packet, PState);
{error, ReasonCode} -> {error, ReasonCode} ->
?LOG(warning, "Cannot publish qos2 message to ~s for ~s", ?LOG(warning, "Cannot publish qos2 message to ~s for ~s",
[Topic, emqx_reason_codes:text(ReasonCode)]), [Topic, emqx_reason_codes:text(ReasonCode)]),
case deliver({pubrec, PacketId, ReasonCode}, PState) of handle_out({pubrec, PacketId, ReasonCode}, PState)
{ok, PState1} ->
AclDenyAction = emqx_zone:get_env(Zone, acl_deny_action, ignore),
do_acl_deny_action(AclDenyAction, Packet, ReasonCode, PState1);
Error -> Error
end
end; end;
process(?PUBACK_PACKET(PacketId, ReasonCode), PState = #pstate{session = Session}) -> process(?PUBACK_PACKET(PacketId, ReasonCode), PState = #protocol{session = Session}) ->
NSession = emqx_session:puback(PacketId, ReasonCode, Session), case emqx_session:puback(PacketId, ReasonCode, Session) of
{ok, PState#pstate{session = NSession}}; {ok, NSession} ->
{ok, PState#protocol{session = NSession}};
{error, _NotFound} ->
{ok, PState} %% TODO: Fixme later
end;
process(?PUBREC_PACKET(PacketId, ReasonCode), PState = #pstate{session = Session}) -> process(?PUBREC_PACKET(PacketId, ReasonCode), PState = #protocol{session = Session}) ->
case emqx_session:pubrec(PacketId, ReasonCode, Session) of case emqx_session:pubrec(PacketId, ReasonCode, Session) of
{ok, NSession} -> {ok, NSession} ->
send(?PUBREL_PACKET(PacketId), PState#pstate{session = NSession}); {ok, ?PUBREL_PACKET(PacketId), PState#protocol{session = NSession}};
{error, NotFound} -> {error, NotFound} ->
send(?PUBREL_PACKET(PacketId, NotFound), PState) {ok, ?PUBREL_PACKET(PacketId, NotFound), PState}
end; end;
process(?PUBREL_PACKET(PacketId, ReasonCode), PState = #pstate{session = Session}) -> process(?PUBREL_PACKET(PacketId, ReasonCode), PState = #protocol{session = Session}) ->
case emqx_session:pubrel(PacketId, ReasonCode, Session) of case emqx_session:pubrel(PacketId, ReasonCode, Session) of
{ok, NSession} -> {ok, NSession} ->
send(?PUBCOMP_PACKET(PacketId), PState#pstate{session = NSession}); {ok, ?PUBCOMP_PACKET(PacketId), PState#protocol{session = NSession}};
{error, NotFound} -> {error, NotFound} ->
send(?PUBCOMP_PACKET(PacketId, NotFound), PState) {ok, ?PUBCOMP_PACKET(PacketId, NotFound), PState}
end; end;
process(?PUBCOMP_PACKET(PacketId, ReasonCode), PState = #pstate{session = Session}) -> process(?PUBCOMP_PACKET(PacketId, ReasonCode), PState = #protocol{session = Session}) ->
case emqx_session:pubcomp(PacketId, ReasonCode, Session) of case emqx_session:pubcomp(PacketId, ReasonCode, Session) of
{ok, NSession} -> {ok, NSession} ->
{ok, PState#pstate{session = NSession}}; {ok, PState#protocol{session = NSession}};
{error, _NotFound} -> {error, _NotFound} -> ok
%% TODO: How to handle NotFound? %% TODO: How to handle NotFound?
{ok, PState}
end; end;
process(Packet = ?SUBSCRIBE_PACKET(PacketId, Properties, RawTopicFilters), process(Packet = ?SUBSCRIBE_PACKET(PacketId, Properties, RawTopicFilters),
PState = #pstate{zone = Zone, proto_ver = ProtoVer, session = Session, credentials = Credentials}) -> PState = #protocol{zone = Zone, session = Session, credentials = Credentials}) ->
case check_subscribe(parse_topic_filters(?SUBSCRIBE, raw_topic_filters(PState, RawTopicFilters)), PState) of case check_subscribe(parse_topic_filters(?SUBSCRIBE, raw_topic_filters(PState, RawTopicFilters)), PState) of
{ok, TopicFilters} -> {ok, TopicFilters} ->
TopicFilters0 = emqx_hooks:run_fold('client.subscribe', [Credentials], TopicFilters), TopicFilters0 = emqx_hooks:run_fold('client.subscribe', [Credentials], TopicFilters),
TopicFilters1 = emqx_mountpoint:mount(mountpoint(Credentials), TopicFilters0), TopicFilters1 = emqx_mountpoint:mount(mountpoint(Credentials), TopicFilters0),
{ok, ReasonCodes, NSession} = emqx_session:subscribe(TopicFilters1, Session), {ok, ReasonCodes, NSession} = emqx_session:subscribe(TopicFilters1, Session),
deliver({suback, PacketId, ReasonCodes}, PState#pstate{session = NSession}); handle_out({suback, PacketId, ReasonCodes}, PState#protocol{session = NSession});
{error, TopicFilters} -> {error, TopicFilters} ->
{SubTopics, ReasonCodes} = {SubTopics, ReasonCodes} =
lists:foldr(fun({Topic, #{rc := ?RC_SUCCESS}}, {Topics, Codes}) -> lists:foldr(fun({Topic, #{rc := ?RC_SUCCESS}}, {Topics, Codes}) ->
@ -512,115 +482,98 @@ process(Packet = ?SUBSCRIBE_PACKET(PacketId, Properties, RawTopicFilters),
end, {[], []}, TopicFilters), end, {[], []}, TopicFilters),
?LOG(warning, "Cannot subscribe ~p for ~p", ?LOG(warning, "Cannot subscribe ~p for ~p",
[SubTopics, [emqx_reason_codes:text(R) || R <- ReasonCodes]]), [SubTopics, [emqx_reason_codes:text(R) || R <- ReasonCodes]]),
case deliver({suback, PacketId, ReasonCodes}, PState) of handle_out({suback, PacketId, ReasonCodes}, PState)
{ok, PState1} ->
AclDenyAction = emqx_zone:get_env(Zone, acl_deny_action, ignore),
do_acl_deny_action(AclDenyAction, Packet, ReasonCodes, PState1);
Error ->
Error
end
end; end;
process(?UNSUBSCRIBE_PACKET(PacketId, Properties, RawTopicFilters), process(?UNSUBSCRIBE_PACKET(PacketId, Properties, RawTopicFilters),
PState = #pstate{session = Session, credentials = Credentials}) -> PState = #protocol{session = Session, credentials = Credentials}) ->
TopicFilters = emqx_hooks:run_fold('client.unsubscribe', [Credentials], TopicFilters = emqx_hooks:run_fold('client.unsubscribe', [Credentials],
parse_topic_filters(?UNSUBSCRIBE, RawTopicFilters)), parse_topic_filters(?UNSUBSCRIBE, RawTopicFilters)),
TopicFilters1 = emqx_mountpoint:mount(mountpoint(Credentials), TopicFilters), TopicFilters1 = emqx_mountpoint:mount(mountpoint(Credentials), TopicFilters),
{ok, ReasonCodes, NSession} = emqx_session:unsubscribe(TopicFilters1, Session), {ok, ReasonCodes, NSession} = emqx_session:unsubscribe(TopicFilters1, Session),
deliver({unsuback, PacketId, ReasonCodes}, PState#pstate{session = NSession}); handle_out({unsuback, PacketId, ReasonCodes}, PState#protocol{session = NSession});
process(?PACKET(?PINGREQ), PState) -> process(?PACKET(?PINGREQ), PState) ->
send(?PACKET(?PINGRESP), PState); {ok, ?PACKET(?PINGRESP), PState};
process(?DISCONNECT_PACKET(?RC_SUCCESS, #{'Session-Expiry-Interval' := Interval}), process(?DISCONNECT_PACKET(?RC_SUCCESS, #{'Session-Expiry-Interval' := Interval}),
PState = #pstate{session = Session, conn_props = #{'Session-Expiry-Interval' := OldInterval}}) -> PState = #protocol{session = Session, conn_props = #{'Session-Expiry-Interval' := OldInterval}}) ->
case Interval =/= 0 andalso OldInterval =:= 0 of case Interval =/= 0 andalso OldInterval =:= 0 of
true -> true ->
deliver({disconnect, ?RC_PROTOCOL_ERROR}, PState), handle_out({disconnect, ?RC_PROTOCOL_ERROR}, PState#protocol{will_msg = undefined});
{error, protocol_error, PState#pstate{will_msg = undefined}};
false -> false ->
NSession = emqx_session:update_expiry_interval(Interval, Session), %% TODO:
%% emqx_session:update_expiry_interval(SPid, Interval),
%% Clean willmsg %% Clean willmsg
{stop, normal, PState#pstate{will_msg = undefined, session = NSession}} {stop, normal, PState#protocol{will_msg = undefined}}
end; end;
process(?DISCONNECT_PACKET(?RC_SUCCESS), PState) -> process(?DISCONNECT_PACKET(?RC_SUCCESS), PState) ->
{stop, normal, PState#pstate{will_msg = undefined}}; {stop, normal, PState#protocol{will_msg = undefined}};
process(?DISCONNECT_PACKET(_), PState) -> process(?DISCONNECT_PACKET(_), PState) ->
{stop, {shutdown, abnormal_disconnet}, PState}. {stop, {shutdown, abnormal_disconnet}, PState};
%%------------------------------------------------------------------------------ process(?AUTH_PACKET(), State) ->
%%TODO: implement later.
{ok, State}.
%%--------------------------------------------------------------------
%% ConnAck --> Client %% ConnAck --> Client
%%------------------------------------------------------------------------------ %%--------------------------------------------------------------------
connack({?RC_SUCCESS, SP, PState = #pstate{credentials = Credentials}}) -> connack({?RC_SUCCESS, SP, PState = #protocol{credentials = Credentials}}) ->
ok = emqx_hooks:run('client.connected', [Credentials, ?RC_SUCCESS, attrs(PState)]), ok = emqx_hooks:run('client.connected', [Credentials, ?RC_SUCCESS, attrs(PState)]),
deliver({connack, ?RC_SUCCESS, sp(SP)}, PState); handle_out({connack, ?RC_SUCCESS, sp(SP)}, PState);
connack({ReasonCode, PState = #pstate{proto_ver = ProtoVer, credentials = Credentials}}) -> connack({ReasonCode, PState = #protocol{proto_ver = ProtoVer, credentials = Credentials}}) ->
ok = emqx_hooks:run('client.connected', [Credentials, ReasonCode, attrs(PState)]), ok = emqx_hooks:run('client.connected', [Credentials, ReasonCode, attrs(PState)]),
[ReasonCode1] = reason_codes_compat(connack, [ReasonCode], ProtoVer), [ReasonCode1] = reason_codes_compat(connack, [ReasonCode], ProtoVer),
_ = deliver({connack, ReasonCode1}, PState), handle_out({connack, ReasonCode1}, PState).
{error, emqx_reason_codes:name(ReasonCode1, ProtoVer), PState}.
%%-------------------------------------------------------------------- %%------------------------------------------------------------------------------
%% Publish Message -> Broker %% Publish Message -> Broker
%%-------------------------------------------------------------------- %%------------------------------------------------------------------------------
do_publish(Packet = ?PUBLISH_PACKET(QoS, PacketId), do_publish(Packet = ?PUBLISH_PACKET(QoS, PacketId),
PState = #pstate{session = Session, credentials = Credentials}) -> PState = #protocol{session = Session, credentials = Credentials}) ->
Msg = emqx_mountpoint:mount(mountpoint(Credentials), Msg = emqx_mountpoint:mount(mountpoint(Credentials),
emqx_packet:to_message(Credentials, Packet)), emqx_packet:to_message(Credentials, Packet)),
Msg1 = emqx_message:set_flag(dup, false, Msg), Msg1 = emqx_message:set_flag(dup, false, Msg),
case emqx_session:publish(PacketId, Msg1, Session) of case emqx_session:publish(PacketId, Msg1, Session) of
{ok, Result} -> {ok, Results} ->
puback(QoS, PacketId, {ok, Result}, PState); puback(QoS, PacketId, Results, PState);
{ok, Result, NSession} -> {ok, Results, NSession} ->
puback(QoS, PacketId, {ok, Result}, PState#pstate{session = NSession}); puback(QoS, PacketId, Results, PState#protocol{session = NSession});
{error, ReasonCode} -> {error, Reason} ->
puback(QoS, PacketId, {error, ReasonCode}, PState) puback(QoS, PacketId, {error, Reason}, PState)
end. end.
%%-------------------------------------------------------------------- %%------------------------------------------------------------------------------
%% Puback -> Client %% Puback -> Client
%%-------------------------------------------------------------------- %%------------------------------------------------------------------------------
puback(?QOS_0, _PacketId, _Result, PState) -> puback(?QOS_0, _PacketId, _Result, PState) ->
{ok, PState}; {ok, PState};
puback(?QOS_1, PacketId, {ok, []}, PState) -> puback(?QOS_1, PacketId, {ok, []}, PState) ->
deliver({puback, PacketId, ?RC_NO_MATCHING_SUBSCRIBERS}, PState); handle_out({puback, PacketId, ?RC_NO_MATCHING_SUBSCRIBERS}, PState);
%%TODO: calc the deliver count? %%TODO: calc the deliver count?
puback(?QOS_1, PacketId, {ok, _Result}, PState) -> puback(?QOS_1, PacketId, {ok, _Result}, PState) ->
deliver({puback, PacketId, ?RC_SUCCESS}, PState); handle_out({puback, PacketId, ?RC_SUCCESS}, PState);
puback(?QOS_1, PacketId, {error, ReasonCode}, PState) -> puback(?QOS_1, PacketId, {error, ReasonCode}, PState) ->
deliver({puback, PacketId, ReasonCode}, PState); handle_out({puback, PacketId, ReasonCode}, PState);
puback(?QOS_2, PacketId, {ok, []}, PState) -> puback(?QOS_2, PacketId, {ok, []}, PState) ->
deliver({pubrec, PacketId, ?RC_NO_MATCHING_SUBSCRIBERS}, PState); handle_out({pubrec, PacketId, ?RC_NO_MATCHING_SUBSCRIBERS}, PState);
puback(?QOS_2, PacketId, {ok, _Result}, PState) -> puback(?QOS_2, PacketId, {ok, _Result}, PState) ->
deliver({pubrec, PacketId, ?RC_SUCCESS}, PState); handle_out({pubrec, PacketId, ?RC_SUCCESS}, PState);
puback(?QOS_2, PacketId, {error, ReasonCode}, PState) -> puback(?QOS_2, PacketId, {error, ReasonCode}, PState) ->
deliver({pubrec, PacketId, ReasonCode}, PState). handle_out({pubrec, PacketId, ReasonCode}, PState).
%%------------------------------------------------------------------------------ %%--------------------------------------------------------------------
%% Deliver Packet -> Client %% Handle outgoing
%%------------------------------------------------------------------------------ %%--------------------------------------------------------------------
-spec(deliver(list(tuple()) | tuple(), state()) -> {ok, state()} | {error, term()}). handle_out({connack, ?RC_SUCCESS, SP}, PState = #protocol{zone = Zone,
deliver([], PState) ->
{ok, PState};
deliver([Pub|More], PState) ->
case deliver(Pub, PState) of
{ok, PState1} ->
deliver(More, PState1);
{error, _} = Error ->
Error
end;
deliver({connack, ReasonCode}, PState) ->
send(?CONNACK_PACKET(ReasonCode), PState);
deliver({connack, ?RC_SUCCESS, SP}, PState = #pstate{zone = Zone,
proto_ver = ?MQTT_PROTO_V5, proto_ver = ?MQTT_PROTO_V5,
client_id = ClientId, client_id = ClientId,
is_assigned = IsAssigned, is_assigned = IsAssigned,
@ -668,81 +621,76 @@ deliver({connack, ?RC_SUCCESS, SP}, PState = #pstate{zone = Zone,
Keepalive -> Props2#{'Server-Keep-Alive' => Keepalive} Keepalive -> Props2#{'Server-Keep-Alive' => Keepalive}
end, end,
PState1 = PState#pstate{topic_alias_maximum = TopicAliasMaximum#{from_client => MaxAlias}}, PState1 = PState#protocol{topic_alias_maximum = TopicAliasMaximum#{from_client => MaxAlias}},
send(?CONNACK_PACKET(?RC_SUCCESS, SP, Props3), PState1); {ok, ?CONNACK_PACKET(?RC_SUCCESS, SP, Props3), PState1};
deliver({connack, ReasonCode, SP}, PState) -> handle_out({connack, ?RC_SUCCESS, SP}, PState) ->
send(?CONNACK_PACKET(ReasonCode, SP), PState); {ok, ?CONNACK_PACKET(?RC_SUCCESS, SP), PState};
deliver({publish, PacketId, Msg}, PState = #pstate{credentials = Credentials}) -> handle_out({connack, ReasonCode}, PState = #protocol{proto_ver = ProtoVer}) ->
Msg0 = emqx_hooks:run_fold('message.deliver', [Credentials], Msg), Reason = emqx_reason_codes:name(ReasonCode, ProtoVer),
Msg1 = emqx_message:update_expiry(Msg0), {error, Reason, ?CONNACK_PACKET(ReasonCode), PState};
Msg2 = emqx_mountpoint:unmount(mountpoint(Credentials), Msg1),
send(emqx_packet:from_message(PacketId, Msg2), PState);
deliver({puback, PacketId, ReasonCode}, PState) -> handle_out({puback, PacketId, ReasonCode}, PState) ->
send(?PUBACK_PACKET(PacketId, ReasonCode), PState); {ok, ?PUBACK_PACKET(PacketId, ReasonCode), PState};
%% TODO:
%% AclDenyAction = emqx_zone:get_env(Zone, acl_deny_action, ignore),
%% do_acl_deny_action(AclDenyAction, Packet, ReasonCode, PState1);
deliver({pubrel, PacketId}, PState) -> handle_out({pubrel, PacketId}, PState) ->
send(?PUBREL_PACKET(PacketId), PState); {ok, ?PUBREL_PACKET(PacketId), PState};
deliver({pubrec, PacketId, ReasonCode}, PState) -> handle_out({pubrec, PacketId, ReasonCode}, PState) ->
send(?PUBREC_PACKET(PacketId, ReasonCode), PState); %% TODO:
%% AclDenyAction = emqx_zone:get_env(Zone, acl_deny_action, ignore),
%% do_acl_deny_action(AclDenyAction, Packet, ReasonCode, PState1);
{ok, ?PUBREC_PACKET(PacketId, ReasonCode), PState};
deliver({suback, PacketId, ReasonCodes}, PState = #pstate{proto_ver = ProtoVer}) -> %%handle_out({pubrec, PacketId, ReasonCode}, PState) ->
send(?SUBACK_PACKET(PacketId, reason_codes_compat(suback, ReasonCodes, ProtoVer)), PState); %% {ok, ?PUBREC_PACKET(PacketId, ReasonCode), PState};
deliver({unsuback, PacketId, ReasonCodes}, PState = #pstate{proto_ver = ProtoVer}) -> handle_out({suback, PacketId, ReasonCodes}, PState = #protocol{proto_ver = ProtoVer}) ->
send(?UNSUBACK_PACKET(PacketId, reason_codes_compat(unsuback, ReasonCodes, ProtoVer)), PState); %% TODO: ACL Deny
{ok, ?SUBACK_PACKET(PacketId, reason_codes_compat(suback, ReasonCodes, ProtoVer)), PState};
handle_out({unsuback, PacketId, ReasonCodes}, PState = #protocol{proto_ver = ProtoVer}) ->
{ok, ?UNSUBACK_PACKET(PacketId, reason_codes_compat(unsuback, ReasonCodes, ProtoVer)), PState};
%% Deliver a disconnect for mqtt 5.0 %% Deliver a disconnect for mqtt 5.0
deliver({disconnect, ReasonCode}, PState = #pstate{proto_ver = ?MQTT_PROTO_V5}) -> handle_out({disconnect, RC}, PState = #protocol{proto_ver = ?MQTT_PROTO_V5}) ->
send(?DISCONNECT_PACKET(ReasonCode), PState); {error, emqx_reason_codes:name(RC), ?DISCONNECT_PACKET(RC), PState};
deliver({disconnect, _ReasonCode}, PState) -> handle_out({disconnect, RC}, PState) ->
{error, emqx_reason_codes:name(RC), PState}.
handle_timeout(Timer, Name, PState) ->
{ok, PState}. {ok, PState}.
%%------------------------------------------------------------------------------
%% Send Packet to Client
-spec(send(emqx_mqtt_types:packet(), state()) -> {ok, state()} | {error, term()}).
send(Packet = ?PACKET(Type), PState = #pstate{proto_ver = Ver, sendfun = Send}) ->
case Send(Packet, #{version => Ver}) of
ok ->
trace(send, Packet),
{ok, PState};
{ok, Data} ->
trace(send, Packet),
emqx_metrics:inc_sent(Packet),
ok = emqx_metrics:inc('bytes.sent', iolist_size(Data)),
{ok, inc_stats(send, Type, PState)};
{error, Reason} ->
{error, Reason}
end.
%%------------------------------------------------------------------------------ %%------------------------------------------------------------------------------
%% Maybe use username replace client id %% Maybe use username replace client id
maybe_use_username_as_clientid(ClientId, PState = #pstate{username = undefined}) -> maybe_use_username_as_clientid(ClientId, PState = #protocol{username = undefined}) ->
PState#pstate{client_id = ClientId}; PState#protocol{client_id = ClientId};
maybe_use_username_as_clientid(ClientId, PState = #pstate{username = Username, zone = Zone}) -> maybe_use_username_as_clientid(ClientId, PState = #protocol{username = Username, zone = Zone}) ->
case emqx_zone:get_env(Zone, use_username_as_clientid, false) of case emqx_zone:get_env(Zone, use_username_as_clientid, false) of
true -> PState#pstate{client_id = Username}; true ->
false -> PState#pstate{client_id = ClientId} PState#protocol{client_id = Username};
false ->
PState#protocol{client_id = ClientId}
end. end.
%%------------------------------------------------------------------------------ %%------------------------------------------------------------------------------
%% Assign a clientId %% Assign a clientId
maybe_assign_client_id(PState = #pstate{client_id = <<>>, ack_props = AckProps}) -> maybe_assign_client_id(PState = #protocol{client_id = <<>>, ack_props = AckProps}) ->
ClientId = emqx_guid:to_base62(emqx_guid:gen()), ClientId = emqx_guid:to_base62(emqx_guid:gen()),
AckProps1 = set_property('Assigned-Client-Identifier', ClientId, AckProps), AckProps1 = set_property('Assigned-Client-Identifier', ClientId, AckProps),
PState#pstate{client_id = ClientId, is_assigned = true, ack_props = AckProps1}; PState#protocol{client_id = ClientId, is_assigned = true, ack_props = AckProps1};
maybe_assign_client_id(PState) -> maybe_assign_client_id(PState) ->
PState. PState.
try_open_session(SessAttrs, PState = #pstate{zone = Zone, try_open_session(SessAttrs, PState = #protocol{zone = Zone,
client_id = ClientId, client_id = ClientId,
username = Username, username = Username,
clean_start = CleanStart}) -> clean_start = CleanStart}) ->
@ -782,9 +730,9 @@ make_will_msg(#mqtt_packet_connect{proto_ver = ProtoVer,
ConnPkt ConnPkt
end). end).
%%------------------------------------------------------------------------------ %%--------------------------------------------------------------------
%% Check Packet %% Check Packet
%%------------------------------------------------------------------------------ %%--------------------------------------------------------------------
check_connect(Packet, PState) -> check_connect(Packet, PState) ->
run_check_steps([fun check_proto_ver/2, run_check_steps([fun check_proto_ver/2,
@ -815,7 +763,7 @@ check_client_id(#mqtt_packet_connect{client_id = <<>>,
clean_start = true}, _PState) -> clean_start = true}, _PState) ->
ok; ok;
check_client_id(#mqtt_packet_connect{client_id = ClientId}, #pstate{zone = Zone}) -> check_client_id(#mqtt_packet_connect{client_id = ClientId}, #protocol{zone = Zone}) ->
Len = byte_size(ClientId), Len = byte_size(ClientId),
MaxLen = emqx_zone:get_env(Zone, max_clientid_len), MaxLen = emqx_zone:get_env(Zone, max_clientid_len),
case (1 =< Len) andalso (Len =< MaxLen) of case (1 =< Len) andalso (Len =< MaxLen) of
@ -827,7 +775,7 @@ check_flapping(#mqtt_packet_connect{}, PState) ->
do_flapping_detect(connect, PState). do_flapping_detect(connect, PState).
check_banned(#mqtt_packet_connect{client_id = ClientId, username = Username}, check_banned(#mqtt_packet_connect{client_id = ClientId, username = Username},
#pstate{zone = Zone, peername = Peername}) -> #protocol{zone = Zone, peername = Peername}) ->
Credentials = #{client_id => ClientId, Credentials = #{client_id => ClientId,
username => Username, username => Username,
peername => Peername}, peername => Peername},
@ -845,7 +793,7 @@ check_will_topic(#mqtt_packet_connect{will_topic = WillTopic} = ConnPkt, PState)
check_will_retain(#mqtt_packet_connect{will_retain = false, proto_ver = ?MQTT_PROTO_V5}, _PState) -> check_will_retain(#mqtt_packet_connect{will_retain = false, proto_ver = ?MQTT_PROTO_V5}, _PState) ->
ok; ok;
check_will_retain(#mqtt_packet_connect{will_retain = true, proto_ver = ?MQTT_PROTO_V5}, #pstate{zone = Zone}) -> check_will_retain(#mqtt_packet_connect{will_retain = true, proto_ver = ?MQTT_PROTO_V5}, #protocol{zone = Zone}) ->
case emqx_zone:get_env(Zone, mqtt_retain_available, true) of case emqx_zone:get_env(Zone, mqtt_retain_available, true) of
true -> {error, ?RC_RETAIN_NOT_SUPPORTED}; true -> {error, ?RC_RETAIN_NOT_SUPPORTED};
false -> ok false -> ok
@ -854,7 +802,7 @@ check_will_retain(_Packet, _PState) ->
ok. ok.
check_will_acl(#mqtt_packet_connect{will_topic = WillTopic}, check_will_acl(#mqtt_packet_connect{will_topic = WillTopic},
#pstate{zone = Zone, credentials = Credentials}) -> #protocol{zone = Zone, credentials = Credentials}) ->
EnableAcl = emqx_zone:get_env(Zone, enable_acl, false), EnableAcl = emqx_zone:get_env(Zone, enable_acl, false),
case do_acl_check(EnableAcl, publish, Credentials, WillTopic) of case do_acl_check(EnableAcl, publish, Credentials, WillTopic) of
ok -> ok; ok -> ok;
@ -869,14 +817,14 @@ check_publish(Packet, PState) ->
check_pub_caps(#mqtt_packet{header = #mqtt_packet_header{qos = QoS, retain = Retain}, check_pub_caps(#mqtt_packet{header = #mqtt_packet_header{qos = QoS, retain = Retain},
variable = #mqtt_packet_publish{properties = _Properties}}, variable = #mqtt_packet_publish{properties = _Properties}},
#pstate{zone = Zone}) -> #protocol{zone = Zone}) ->
emqx_mqtt_caps:check_pub(Zone, #{qos => QoS, retain => Retain}). emqx_mqtt_caps:check_pub(Zone, #{qos => QoS, retain => Retain}).
check_pub_acl(_Packet, #pstate{credentials = #{is_superuser := IsSuper}}) check_pub_acl(_Packet, #protocol{credentials = #{is_superuser := IsSuper}})
when IsSuper -> when IsSuper ->
ok; ok;
check_pub_acl(#mqtt_packet{variable = #mqtt_packet_publish{topic_name = Topic}}, check_pub_acl(#mqtt_packet{variable = #mqtt_packet_publish{topic_name = Topic}},
#pstate{zone = Zone, credentials = Credentials}) -> #protocol{zone = Zone, credentials = Credentials}) ->
EnableAcl = emqx_zone:get_env(Zone, enable_acl, false), EnableAcl = emqx_zone:get_env(Zone, enable_acl, false),
do_acl_check(EnableAcl, publish, Credentials, Topic). do_acl_check(EnableAcl, publish, Credentials, Topic).
@ -890,7 +838,7 @@ run_check_steps([Check|Steps], Packet, PState) ->
Error Error
end. end.
check_subscribe(TopicFilters, PState = #pstate{zone = Zone}) -> check_subscribe(TopicFilters, PState = #protocol{zone = Zone}) ->
case emqx_mqtt_caps:check_sub(Zone, TopicFilters) of case emqx_mqtt_caps:check_sub(Zone, TopicFilters) of
{ok, TopicFilter1} -> {ok, TopicFilter1} ->
check_sub_acl(TopicFilter1, PState); check_sub_acl(TopicFilter1, PState);
@ -898,10 +846,10 @@ check_subscribe(TopicFilters, PState = #pstate{zone = Zone}) ->
{error, TopicFilter1} {error, TopicFilter1}
end. end.
check_sub_acl(TopicFilters, #pstate{credentials = #{is_superuser := IsSuper}}) check_sub_acl(TopicFilters, #protocol{credentials = #{is_superuser := IsSuper}})
when IsSuper -> when IsSuper ->
{ok, TopicFilters}; {ok, TopicFilters};
check_sub_acl(TopicFilters, #pstate{zone = Zone, credentials = Credentials}) -> check_sub_acl(TopicFilters, #protocol{zone = Zone, credentials = Credentials}) ->
EnableAcl = emqx_zone:get_env(Zone, enable_acl, false), EnableAcl = emqx_zone:get_env(Zone, enable_acl, false),
lists:foldr( lists:foldr(
fun({Topic, SubOpts}, {Ok, Acc}) when EnableAcl -> fun({Topic, SubOpts}, {Ok, Acc}) when EnableAcl ->
@ -912,26 +860,9 @@ check_sub_acl(TopicFilters, #pstate{zone = Zone, credentials = Credentials}) ->
{ok, [TopicFilter | Acc]} {ok, [TopicFilter | Acc]}
end, {ok, []}, TopicFilters). end, {ok, []}, TopicFilters).
trace(recv, Packet) -> terminate(_Reason, #protocol{client_id = undefined}) ->
?LOG(debug, "RECV ~s", [emqx_packet:format(Packet)]);
trace(send, Packet) ->
?LOG(debug, "SEND ~s", [emqx_packet:format(Packet)]).
inc_stats(recv, Type, PState = #pstate{recv_stats = Stats}) ->
PState#pstate{recv_stats = inc_stats(Type, Stats)};
inc_stats(send, Type, PState = #pstate{send_stats = Stats}) ->
PState#pstate{send_stats = inc_stats(Type, Stats)}.
inc_stats(Type, Stats = #{pkt := PktCnt, msg := MsgCnt}) ->
Stats#{pkt := PktCnt + 1, msg := case Type =:= ?PUBLISH of
true -> MsgCnt + 1;
false -> MsgCnt
end}.
terminate(_Reason, #pstate{client_id = undefined}) ->
ok; ok;
terminate(_Reason, PState = #pstate{connected = false}) -> terminate(_Reason, PState = #protocol{connected = false}) ->
do_flapping_detect(disconnect, PState), do_flapping_detect(disconnect, PState),
ok; ok;
terminate(Reason, PState) when Reason =:= conflict; terminate(Reason, PState) when Reason =:= conflict;
@ -939,20 +870,20 @@ terminate(Reason, PState) when Reason =:= conflict;
do_flapping_detect(disconnect, PState), do_flapping_detect(disconnect, PState),
ok; ok;
terminate(Reason, PState = #pstate{credentials = Credentials}) -> terminate(Reason, PState = #protocol{credentials = Credentials}) ->
do_flapping_detect(disconnect, PState), do_flapping_detect(disconnect, PState),
?LOG(info, "Shutdown for ~p", [Reason]), ?LOG(info, "Shutdown for ~p", [Reason]),
ok = emqx_hooks:run('client.disconnected', [Credentials, Reason]). ok = emqx_hooks:run('client.disconnected', [Credentials, Reason]).
start_keepalive(0, _PState) -> start_keepalive(0, _PState) ->
ignore; ignore;
start_keepalive(Secs, #pstate{zone = Zone}) when Secs > 0 -> start_keepalive(Secs, #protocol{zone = Zone}) when Secs > 0 ->
Backoff = emqx_zone:get_env(Zone, keepalive_backoff, 0.75), Backoff = emqx_zone:get_env(Zone, keepalive_backoff, 0.75),
self() ! {keepalive, start, round(Secs * Backoff)}. self() ! {keepalive, start, round(Secs * Backoff)}.
%%----------------------------------------------------------------------------- %%--------------------------------------------------------------------
%% Parse topic filters %% Parse topic filters
%%----------------------------------------------------------------------------- %%--------------------------------------------------------------------
parse_topic_filters(?SUBSCRIBE, RawTopicFilters) -> parse_topic_filters(?SUBSCRIBE, RawTopicFilters) ->
[emqx_topic:parse(RawTopic, SubOpts) || {RawTopic, SubOpts} <- RawTopicFilters]; [emqx_topic:parse(RawTopic, SubOpts) || {RawTopic, SubOpts} <- RawTopicFilters];
@ -966,17 +897,17 @@ sp(false) -> 0.
flag(false) -> 0; flag(false) -> 0;
flag(true) -> 1. flag(true) -> 1.
%%------------------------------------------------------------------------------ %%--------------------------------------------------------------------
%% Execute actions in case acl deny %% Execute actions in case acl deny
do_flapping_detect(Action, #pstate{zone = Zone, do_flapping_detect(Action, #protocol{zone = Zone,
client_id = ClientId}) -> client_id = ClientId}) ->
ok = case emqx_zone:get_env(Zone, enable_flapping_detect, false) of ok = case emqx_zone:get_env(Zone, enable_flapping_detect, false) of
true -> true ->
Threshold = emqx_zone:get_env(Zone, flapping_threshold, {10, 60}), Threshold = emqx_zone:get_env(Zone, flapping_threshold, {10, 60}),
case emqx_flapping:check(Action, ClientId, Threshold) of case emqx_flapping:check(Action, ClientId, Threshold) of
flapping -> flapping ->
BanExpiryInterval = emqx_zone:get_env(Zone, flapping_ban_expiry_interval, 3600000), BanExpiryInterval = emqx_zone:get_env(Zone, flapping_banned_expiry_interval, 3600000),
Until = erlang:system_time(second) + BanExpiryInterval, Until = erlang:system_time(second) + BanExpiryInterval,
emqx_banned:add(#banned{who = {client_id, ClientId}, emqx_banned:add(#banned{who = {client_id, ClientId},
reason = <<"flapping">>, reason = <<"flapping">>,
@ -990,13 +921,14 @@ do_flapping_detect(Action, #pstate{zone = Zone,
end. end.
do_acl_deny_action(disconnect, ?PUBLISH_PACKET(?QOS_0, _Topic, _PacketId, _Payload), do_acl_deny_action(disconnect, ?PUBLISH_PACKET(?QOS_0, _Topic, _PacketId, _Payload),
?RC_NOT_AUTHORIZED, PState = #pstate{proto_ver = ProtoVer}) -> ?RC_NOT_AUTHORIZED, PState = #protocol{proto_ver = ProtoVer}) ->
{error, emqx_reason_codes:name(?RC_NOT_AUTHORIZED, ProtoVer), PState}; {error, emqx_reason_codes:name(?RC_NOT_AUTHORIZED, ProtoVer), PState};
do_acl_deny_action(disconnect, ?PUBLISH_PACKET(QoS, _Topic, _PacketId, _Payload), do_acl_deny_action(disconnect, ?PUBLISH_PACKET(QoS, _Topic, _PacketId, _Payload),
?RC_NOT_AUTHORIZED, PState = #pstate{proto_ver = ProtoVer}) ?RC_NOT_AUTHORIZED, PState = #protocol{proto_ver = ProtoVer})
when QoS =:= ?QOS_1; QoS =:= ?QOS_2 -> when QoS =:= ?QOS_1; QoS =:= ?QOS_2 ->
deliver({disconnect, ?RC_NOT_AUTHORIZED}, PState), %% TODO:...
%% deliver({disconnect, ?RC_NOT_AUTHORIZED}, PState),
{error, emqx_reason_codes:name(?RC_NOT_AUTHORIZED, ProtoVer), PState}; {error, emqx_reason_codes:name(?RC_NOT_AUTHORIZED, ProtoVer), PState};
do_acl_deny_action(Action, ?SUBSCRIBE_PACKET(_PacketId, _Properties, _RawTopicFilters), ReasonCodes, PState) do_acl_deny_action(Action, ?SUBSCRIBE_PACKET(_PacketId, _Properties, _RawTopicFilters), ReasonCodes, PState)
@ -1004,18 +936,18 @@ do_acl_deny_action(Action, ?SUBSCRIBE_PACKET(_PacketId, _Properties, _RawTopicFi
traverse_reason_codes(ReasonCodes, Action, PState); traverse_reason_codes(ReasonCodes, Action, PState);
do_acl_deny_action(_OtherAction, _PubSubPacket, ?RC_NOT_AUTHORIZED, PState) -> do_acl_deny_action(_OtherAction, _PubSubPacket, ?RC_NOT_AUTHORIZED, PState) ->
{ok, PState}; {ok, PState};
do_acl_deny_action(_OtherAction, _PubSubPacket, ReasonCode, PState = #pstate{proto_ver = ProtoVer}) -> do_acl_deny_action(_OtherAction, _PubSubPacket, ReasonCode, PState = #protocol{proto_ver = ProtoVer}) ->
{error, emqx_reason_codes:name(ReasonCode, ProtoVer), PState}. {error, emqx_reason_codes:name(ReasonCode, ProtoVer), PState}.
traverse_reason_codes([], _Action, PState) -> traverse_reason_codes([], _Action, PState) ->
{ok, PState}; {ok, PState};
traverse_reason_codes([?RC_SUCCESS | LeftReasonCodes], Action, PState) -> traverse_reason_codes([?RC_SUCCESS | LeftReasonCodes], Action, PState) ->
traverse_reason_codes(LeftReasonCodes, Action, PState); traverse_reason_codes(LeftReasonCodes, Action, PState);
traverse_reason_codes([?RC_NOT_AUTHORIZED | _LeftReasonCodes], disconnect, PState = #pstate{proto_ver = ProtoVer}) -> traverse_reason_codes([?RC_NOT_AUTHORIZED | _LeftReasonCodes], disconnect, PState = #protocol{proto_ver = ProtoVer}) ->
{error, emqx_reason_codes:name(?RC_NOT_AUTHORIZED, ProtoVer), PState}; {error, emqx_reason_codes:name(?RC_NOT_AUTHORIZED, ProtoVer), PState};
traverse_reason_codes([?RC_NOT_AUTHORIZED | LeftReasonCodes], Action, PState) -> traverse_reason_codes([?RC_NOT_AUTHORIZED | LeftReasonCodes], Action, PState) ->
traverse_reason_codes(LeftReasonCodes, Action, PState); traverse_reason_codes(LeftReasonCodes, Action, PState);
traverse_reason_codes([OtherCode | _LeftReasonCodes], _Action, PState = #pstate{proto_ver = ProtoVer}) -> traverse_reason_codes([OtherCode | _LeftReasonCodes], _Action, PState = #protocol{proto_ver = ProtoVer}) ->
{error, emqx_reason_codes:name(OtherCode, ProtoVer), PState}. {error, emqx_reason_codes:name(OtherCode, ProtoVer), PState}.
%% Reason code compat %% Reason code compat
@ -1026,7 +958,7 @@ reason_codes_compat(unsuback, _ReasonCodes, _ProtoVer) ->
reason_codes_compat(PktType, ReasonCodes, _ProtoVer) -> reason_codes_compat(PktType, ReasonCodes, _ProtoVer) ->
[emqx_reason_codes:compat(PktType, RC) || RC <- ReasonCodes]. [emqx_reason_codes:compat(PktType, RC) || RC <- ReasonCodes].
raw_topic_filters(#pstate{zone = Zone, proto_ver = ProtoVer, is_bridge = IsBridge}, RawTopicFilters) -> raw_topic_filters(#protocol{zone = Zone, proto_ver = ProtoVer, is_bridge = IsBridge}, RawTopicFilters) ->
IgnoreLoop = emqx_zone:get_env(Zone, ignore_loop_deliver, false), IgnoreLoop = emqx_zone:get_env(Zone, ignore_loop_deliver, false),
case ProtoVer < ?MQTT_PROTO_V5 of case ProtoVer < ?MQTT_PROTO_V5 of
true -> true ->

File diff suppressed because it is too large Load Diff

View File

@ -135,7 +135,7 @@ ack_enabled() ->
do_dispatch(SubPid, Topic, Msg, _Type) when SubPid =:= self() -> do_dispatch(SubPid, Topic, Msg, _Type) when SubPid =:= self() ->
%% Deadlock otherwise %% Deadlock otherwise
_ = erlang:send(SubPid, {dispatch, Topic, Msg}), _ = erlang:send(SubPid, {deliver, Topic, Msg}),
ok; ok;
do_dispatch(SubPid, Topic, Msg, Type) -> do_dispatch(SubPid, Topic, Msg, Type) ->
dispatch_per_qos(SubPid, Topic, Msg, Type). dispatch_per_qos(SubPid, Topic, Msg, Type).
@ -143,18 +143,18 @@ do_dispatch(SubPid, Topic, Msg, Type) ->
%% return either 'ok' (when everything is fine) or 'error' %% return either 'ok' (when everything is fine) or 'error'
dispatch_per_qos(SubPid, Topic, #message{qos = ?QOS_0} = Msg, _Type) -> dispatch_per_qos(SubPid, Topic, #message{qos = ?QOS_0} = Msg, _Type) ->
%% For QoS 0 message, send it as regular dispatch %% For QoS 0 message, send it as regular dispatch
_ = erlang:send(SubPid, {dispatch, Topic, Msg}), _ = erlang:send(SubPid, {deliver, Topic, Msg}),
ok; ok;
dispatch_per_qos(SubPid, Topic, Msg, retry) -> dispatch_per_qos(SubPid, Topic, Msg, retry) ->
%% Retry implies all subscribers nack:ed, send again without ack %% Retry implies all subscribers nack:ed, send again without ack
_ = erlang:send(SubPid, {dispatch, Topic, Msg}), _ = erlang:send(SubPid, {deliver, Topic, Msg}),
ok; ok;
dispatch_per_qos(SubPid, Topic, Msg, fresh) -> dispatch_per_qos(SubPid, Topic, Msg, fresh) ->
case ack_enabled() of case ack_enabled() of
true -> true ->
dispatch_with_ack(SubPid, Topic, Msg); dispatch_with_ack(SubPid, Topic, Msg);
false -> false ->
_ = erlang:send(SubPid, {dispatch, Topic, Msg}), _ = erlang:send(SubPid, {deliver, Topic, Msg}),
ok ok
end. end.
@ -162,7 +162,7 @@ dispatch_with_ack(SubPid, Topic, Msg) ->
%% For QoS 1/2 message, expect an ack %% For QoS 1/2 message, expect an ack
Ref = erlang:monitor(process, SubPid), Ref = erlang:monitor(process, SubPid),
Sender = self(), Sender = self(),
_ = erlang:send(SubPid, {dispatch, Topic, with_ack_ref(Msg, {Sender, Ref})}), _ = erlang:send(SubPid, {deliver, Topic, with_ack_ref(Msg, {Sender, Ref})}),
Timeout = case Msg#message.qos of Timeout = case Msg#message.qos of
?QOS_1 -> timer:seconds(?SHARED_SUB_QOS1_DISPATCH_TIMEOUT_SECONDS); ?QOS_1 -> timer:seconds(?SHARED_SUB_QOS1_DISPATCH_TIMEOUT_SECONDS);
?QOS_2 -> infinity ?QOS_2 -> infinity

View File

@ -14,6 +14,7 @@
%% limitations under the License. %% limitations under the License.
%%-------------------------------------------------------------------- %%--------------------------------------------------------------------
%% MQTT WebSocket Channel
-module(emqx_ws_channel). -module(emqx_ws_channel).
-include("emqx.hrl"). -include("emqx.hrl").
@ -170,7 +171,8 @@ websocket_init(#state{request = Req, options = Options}) ->
parse_state = ParseState, parse_state = ParseState,
proto_state = ProtoState, proto_state = ProtoState,
enable_stats = EnableStats, enable_stats = EnableStats,
idle_timeout = IdleTimout}}. idle_timeout = IdleTimout
}}.
send_fun(WsPid) -> send_fun(WsPid) ->
fun(Packet, Options) -> fun(Packet, Options) ->
@ -242,10 +244,13 @@ websocket_info({call, From, session}, State = #state{proto_state = ProtoState})
gen_server:reply(From, emqx_protocol:session(ProtoState)), gen_server:reply(From, emqx_protocol:session(ProtoState)),
{ok, State}; {ok, State};
websocket_info({deliver, PubOrAck}, State = #state{proto_state = ProtoState}) -> websocket_info(Delivery, State = #state{proto_state = ProtoState})
case emqx_protocol:deliver(PubOrAck, ProtoState) of when element(1, Delivery) =:= deliver ->
{ok, ProtoState1} -> case emqx_protocol:handle_out(Delivery, ProtoState) of
{ok, ensure_stats_timer(State#state{proto_state = ProtoState1})}; {ok, NProtoState} ->
{ok, State#state{proto_state = NProtoState}};
{ok, Packet, NProtoState} ->
handle_outgoing(Packet, State#state{proto_state = NProtoState});
{error, Reason} -> {error, Reason} ->
shutdown(Reason, State) shutdown(Reason, State)
end; end;
@ -285,8 +290,8 @@ websocket_info({shutdown, conflict, {ClientId, NewPid}}, State) ->
?LOG(warning, "Clientid '~s' conflict with ~p", [ClientId, NewPid]), ?LOG(warning, "Clientid '~s' conflict with ~p", [ClientId, NewPid]),
shutdown(conflict, State); shutdown(conflict, State);
websocket_info({binary, Data}, State) -> %% websocket_info({binary, Data}, State) ->
{reply, {binary, Data}, State}; %% {reply, {binary, Data}, State};
websocket_info({shutdown, Reason}, State) -> websocket_info({shutdown, Reason}, State) ->
shutdown(Reason, State); shutdown(Reason, State);
@ -317,9 +322,12 @@ terminate(SockError, _Req, #state{keepalive = Keepalive,
%%-------------------------------------------------------------------- %%--------------------------------------------------------------------
handle_incoming(Packet, SuccFun, State = #state{proto_state = ProtoState}) -> handle_incoming(Packet, SuccFun, State = #state{proto_state = ProtoState}) ->
case emqx_protocol:received(Packet, ProtoState) of case emqx_protocol:handle_in(Packet, ProtoState) of
{ok, NProtoState} -> {ok, NProtoState} ->
SuccFun(State#state{proto_state = NProtoState}); SuccFun(State#state{proto_state = NProtoState});
{ok, OutPacket, NProtoState} ->
%% TODO: How to call SuccFun???
handle_outgoing(OutPacket, State#state{proto_state = NProtoState});
{error, Reason} -> {error, Reason} ->
?LOG(error, "Protocol error: ~p", [Reason]), ?LOG(error, "Protocol error: ~p", [Reason]),
shutdown(Reason, State); shutdown(Reason, State);
@ -329,7 +337,12 @@ handle_incoming(Packet, SuccFun, State = #state{proto_state = ProtoState}) ->
shutdown(Error, State#state{proto_state = NProtoState}) shutdown(Error, State#state{proto_state = NProtoState})
end. end.
handle_outgoing(Packet, State = #state{proto_state = _NProtoState}) ->
Data = emqx_frame:serialize(Packet), %% TODO:, Options),
BinSize = iolist_size(Data),
emqx_pd:update_counter(send_cnt, 1),
emqx_pd:update_counter(send_oct, BinSize),
{reply, {binary, Data}, ensure_stats_timer(State)}.
ensure_stats_timer(State = #state{enable_stats = true, ensure_stats_timer(State = #state{enable_stats = true,
stats_timer = undefined, stats_timer = undefined,