refactor protocol

This commit is contained in:
Feng Lee 2015-10-30 18:00:23 +08:00
parent 1ef861715e
commit 19930b6382
2 changed files with 95 additions and 124 deletions

View File

@ -120,11 +120,11 @@ handle_call(session, _From, State = #client_state{proto_state = ProtoState}) ->
handle_call(info, _From, State = #client_state{connection = Connection,
proto_state = ProtoState}) ->
ClientInfo = [{Key, Val} || {Key, Val} <- ?record_to_proplist(client_state, State), lists:member(Key, ?INFO_KEYS)],
ClientInfo = ?record_to_proplist(client_state, State, ?INFO_KEYS),
ProtoInfo = emqttd_protocol:info(ProtoState),
{ok, SockStats} = Connection:getstat(?SOCK_STATS),
Info = lists:append([ClientInfo, [{proto_info, ProtoInfo}, {sock_stats, SockStats}]]),
{reply, Info, State};
{noreply, lists:append([ClientInfo, [{proto_info, ProtoInfo},
{sock_stats, SockStats}]]), State};
handle_call(kick, _From, State) ->
{stop, {shutdown, kick}, ok, State};
@ -176,7 +176,7 @@ handle_info(activate_sock, State) ->
handle_info({inet_async, _Sock, _Ref, {ok, Data}}, State) ->
Size = size(Data),
?LOG(debug, "RECV: ~p", [Data], State),
?LOG(debug, "RECV <- ~p", [Data], State),
emqttd_metrics:inc('bytes/received', Size),
received(Data, rate_limit(Size, State#client_state{await_recv = false}));
@ -203,7 +203,7 @@ handle_info({keepalive, start, Interval}, State = #client_state{connection = Con
handle_info({keepalive, check}, State = #client_state{keepalive = KeepAlive}) ->
case emqttd_keepalive:check(KeepAlive) of
{ok, KeepAlive1} ->
noreply(State#state{keepalive = KeepAlive1});
noreply(State#client_state{keepalive = KeepAlive1});
{error, timeout} ->
?LOG(debug, "Keepalive timeout", [], State),
shutdown(keepalive_timeout, State);

View File

@ -24,7 +24,6 @@
%%%
%%% @end
%%%-----------------------------------------------------------------------------
-module(emqttd_protocol).
-author("Feng Lee <feng@emqtt.io>").
@ -33,6 +32,8 @@
-include("emqttd_protocol.hrl").
-include("emqttd_internal.hrl").
%% API
-export([init/3, info/1, clientid/1, client/1, session/1]).
@ -41,29 +42,26 @@
-export([process/2]).
%% Protocol State
-record(proto_state, {peername,
sendfun,
connected = false, %received CONNECT action?
proto_ver,
proto_name,
username,
client_id,
clean_sess,
session,
will_msg,
keepalive,
max_clientid_len = ?MAX_CLIENTID_LEN,
client_pid,
ws_initial_headers, %% Headers from first HTTP request for websocket client
-record(proto_state, {peername, sendfun, connected = false,
client_id, client_pid, clean_sess,
proto_ver, proto_name, username,
will_msg, keepalive, max_clientid_len = ?MAX_CLIENTID_LEN,
session, ws_initial_headers, %% Headers from first HTTP request for websocket client
connected_at}).
-type proto_state() :: #proto_state{}.
-define(INFO_KEYS, [client_id, username, clean_sess, proto_ver, proto_name,
keepalive, will_msg, ws_initial_headers, connected_at]).
-define(LOG(Level, Format, Args, State),
lager:Level([{client, State#proto_state.client_id}], "Client(~s@~s): " ++ Format,
[State#proto_state.client_id, State#proto_state.peername | Args])).
%%------------------------------------------------------------------------------
%% @doc Init protocol
%% @end
%%------------------------------------------------------------------------------
init(Peername, SendFun, Opts) ->
MaxLen = emqttd_opts:g(max_clientid_len, Opts, ?MAX_CLIENTID_LEN),
WsInitialHeaders = emqttd_opts:g(ws_initial_headers, Opts),
@ -73,38 +71,20 @@ init(Peername, SendFun, Opts) ->
client_pid = self(),
ws_initial_headers = WsInitialHeaders}.
info(#proto_state{client_id = ClientId,
username = Username,
peername = Peername,
proto_ver = ProtoVer,
proto_name = ProtoName,
keepalive = KeepAlive,
clean_sess = CleanSess,
ws_initial_headers = WsInitialHeaders,
will_msg = WillMsg,
connected_at = ConnectedAt}) ->
[{client_id, ClientId},
{username, Username},
{peername, Peername},
{proto_ver, ProtoVer},
{proto_name, ProtoName},
{keepalive, KeepAlive},
{clean_sess, CleanSess},
{ws_initial_headers, WsInitialHeaders},
{will_msg, WillMsg},
{connected_at, ConnectedAt}].
info(ProtoState) ->
?record_to_proplist(proto_state, ProtoState, ?INFO_KEYS).
clientid(#proto_state{client_id = ClientId}) ->
ClientId.
client(#proto_state{client_id = ClientId,
client_pid = ClientPid,
peername = Peername,
username = Username,
clean_sess = CleanSess,
proto_ver = ProtoVer,
keepalive = Keepalive,
will_msg = WillMsg,
client_pid = Pid,
ws_initial_headers = WsInitialHeaders,
connected_at = Time}) ->
WillTopic = if
@ -112,7 +92,7 @@ client(#proto_state{client_id = ClientId,
true -> WillMsg#mqtt_message.topic
end,
#mqtt_client{client_id = ClientId,
client_pid = Pid,
client_pid = ClientPid,
username = Username,
peername = Peername,
clean_sess = CleanSess,
@ -148,7 +128,7 @@ received(Packet = ?PACKET(_Type), State) ->
{error, Reason, State}
end.
process(Packet = ?CONNECT_PACKET(Var), State0 = #proto_state{peername = Peername}) ->
process(Packet = ?CONNECT_PACKET(Var), State0) ->
#mqtt_packet_connect{proto_ver = ProtoVer,
proto_name = ProtoName,
@ -190,10 +170,8 @@ process(Packet = ?CONNECT_PACKET(Var), State0 = #proto_state{peername = Peername
exit({shutdown, Error})
end;
{error, Reason}->
lager:error("~s@~s: username '~s' login failed for ~s",
[ClientId, emqttd_net:format(Peername), Username, Reason]),
?LOG(error, "Username '~s' login failed for ~s", [Username, Reason], State1),
{?CONNACK_CREDENTIALS, State1}
end;
ReturnCode ->
{ReturnCode, State1}
@ -203,19 +181,18 @@ process(Packet = ?CONNECT_PACKET(Var), State0 = #proto_state{peername = Peername
%% Send connack
send(?CONNACK_PACKET(ReturnCode1), State3);
process(Packet = ?PUBLISH_PACKET(_Qos, Topic, _PacketId, _Payload),
State = #proto_state{client_id = ClientId}) ->
process(Packet = ?PUBLISH_PACKET(_Qos, Topic, _PacketId, _Payload), State) ->
case check_acl(publish, Topic, State) of
allow ->
publish(Packet, State);
deny ->
lager:error("ACL Deny: ~s cannot publish to ~s", [ClientId, Topic])
?LOG(error, "Cannot publish to ~s for ACL Deny", [Topic], State)
end,
{ok, State};
process(?PUBACK_PACKET(?PUBACK, PacketId), State = #proto_state{session = Session}) ->
emqttd_session:puback(Session, PacketId), {ok, State};
emqttd_session:puback(Session, PacketId),
{ok, State};
process(?PUBACK_PACKET(?PUBREC, PacketId), State = #proto_state{session = Session}) ->
emqttd_session:pubrec(Session, PacketId),
@ -228,22 +205,21 @@ process(?PUBACK_PACKET(?PUBREL, PacketId), State = #proto_state{session = Sessio
process(?PUBACK_PACKET(?PUBCOMP, PacketId), State = #proto_state{session = Session})->
emqttd_session:pubcomp(Session, PacketId), {ok, State};
%% protect from empty topic list
%% Protect from empty topic table
process(?SUBSCRIBE_PACKET(PacketId, []), State) ->
send(?SUBACK_PACKET(PacketId, []), State);
process(?SUBSCRIBE_PACKET(PacketId, TopicTable),
State = #proto_state{client_id = ClientId, session = Session}) ->
process(?SUBSCRIBE_PACKET(PacketId, TopicTable), State = #proto_state{session = Session}) ->
AllowDenies = [check_acl(subscribe, Topic, State) || {Topic, _Qos} <- TopicTable],
case lists:member(deny, AllowDenies) of
true ->
lager:error("SUBSCRIBE from '~s' Denied: ~p", [ClientId, TopicTable]),
?LOG(error, "Cannot SUBSCRIBE ~p for ACL Deny", [TopicTable], State),
send(?SUBACK_PACKET(PacketId, [16#80 || _ <- TopicTable]), State);
false ->
emqttd_session:subscribe(Session, PacketId, TopicTable), {ok, State}
end;
%% protect from empty topic list
%% Protect from empty topic list
process(?UNSUBSCRIBE_PACKET(PacketId, []), State) ->
send(?UNSUBACK_PACKET(PacketId), State);
@ -255,72 +231,65 @@ process(?PACKET(?PINGREQ), State) ->
send(?PACKET(?PINGRESP), State);
process(?PACKET(?DISCONNECT), State) ->
% clean willmsg
% Clean willmsg
{stop, normal, State#proto_state{will_msg = undefined}}.
publish(Packet = ?PUBLISH_PACKET(?QOS_0, _PacketId),
#proto_state{client_id = ClientId, session = Session}) ->
Msg = emqttd_message:from_packet(ClientId, Packet),
emqttd_session:publish(Session, Msg);
emqttd_session:publish(Session, emqttd_message:from_packet(ClientId, Packet));
publish(Packet = ?PUBLISH_PACKET(?QOS_1, PacketId),
publish(Packet = ?PUBLISH_PACKET(?QOS_1, _PacketId), State) ->
with_puback(?PUBACK, Packet, State);
publish(Packet = ?PUBLISH_PACKET(?QOS_2, _PacketId), State) ->
with_puback(?PUBREC, Packet, State).
with_puback(Type, Packet = ?PUBLISH_PACKET(_Qos, PacketId),
State = #proto_state{client_id = ClientId, session = Session}) ->
Msg = emqttd_message:from_packet(ClientId, Packet),
case emqttd_session:publish(Session, Msg) of
ok ->
send(?PUBACK_PACKET(?PUBACK, PacketId), State);
send(?PUBACK_PACKET(Type, PacketId), State);
{error, Error} ->
lager:error("Client(~s): publish qos1 error - ~p", [ClientId, Error])
end;
publish(Packet = ?PUBLISH_PACKET(?QOS_2, PacketId),
State = #proto_state{client_id = ClientId, session = Session}) ->
Msg = emqttd_message:from_packet(ClientId, Packet),
case emqttd_session:publish(Session, Msg) of
ok ->
send(?PUBACK_PACKET(?PUBREC, PacketId), State);
{error, Error} ->
lager:error("Client(~s): publish qos2 error - ~p", [ClientId, Error])
?LOG(error, "PUBLISH ~p error: ~p", [PacketId, Error], State)
end.
-spec send(mqtt_message() | mqtt_packet(), proto_state()) -> {ok, proto_state()}.
send(Msg, State) when is_record(Msg, mqtt_message) ->
send(emqttd_message:to_packet(Msg), State);
send(Packet, State = #proto_state{sendfun = SendFun, peername = Peername})
send(Packet, State = #proto_state{sendfun = SendFun})
when is_record(Packet, mqtt_packet) ->
trace(send, Packet, State),
emqttd_metrics:sent(Packet),
Data = emqttd_serialiser:serialise(Packet),
lager:debug("SENT to ~s: ~p", [emqttd_net:format(Peername), Data]),
?LOG(debug, "SENT -> ~p", [Data], State),
emqttd_metrics:inc('bytes/sent', size(Data)),
SendFun(Data),
{ok, State}.
trace(recv, Packet, #proto_state{peername = Peername, client_id = ClientId}) ->
lager:info([{client, ClientId}], "RECV from ~s@~s: ~s",
[ClientId, emqttd_net:format(Peername), emqttd_packet:format(Packet)]);
trace(recv, Packet, ProtoState) ->
trace2("RECV <-", Packet, ProtoState);
trace(send, Packet, #proto_state{peername = Peername, client_id = ClientId}) ->
lager:info([{client, ClientId}], "SEND to ~s@~s: ~s",
[ClientId, emqttd_net:format(Peername), emqttd_packet:format(Packet)]).
trace(send, Packet, ProtoState) ->
trace2("SEND ->", Packet, ProtoState).
trace2(Tag, Packet, #proto_state{peername = Peername, client_id = ClientId}) ->
lager:info([{client, ClientId}], "Client(~s@~s): ~s ~s",
[ClientId, Peername, Tag, emqttd_packet:format(Packet)]).
%% @doc redeliver PUBREL PacketId
redeliver({?PUBREL, PacketId}, State) ->
send(?PUBREL_PACKET(PacketId), State).
shutdown(Error, #proto_state{client_id = undefined}) ->
lager:info("Protocol shutdown ~p", [Error]),
shutdown(_Error, #proto_state{client_id = undefined}) ->
ignore;
shutdown(duplicate_id, #proto_state{client_id = ClientId}) ->
%% unregister the device
shutdown(confict, #proto_state{client_id = ClientId}) ->
emqttd_cm:unregister(ClientId);
%% TODO: ClientId??
shutdown(Error, #proto_state{peername = Peername, client_id = ClientId, will_msg = WillMsg}) ->
lager:info([{client, ClientId}], "Client ~s@~s: shutdown ~p",
[ClientId, emqttd_net:format(Peername), Error]),
shutdown(Error, State = #proto_state{client_id = ClientId, will_msg = WillMsg}) ->
?LOG(info, "shutdown for ~p", [Error], State),
send_willmsg(ClientId, WillMsg),
emqttd_broker:foreach_hooks('client.disconnected', [Error, ClientId]),
emqttd_cm:unregister(ClientId).
@ -341,7 +310,6 @@ maybe_set_clientid(State) ->
send_willmsg(_ClientId, undefined) ->
ignore;
send_willmsg(ClientId, WillMsg) ->
lager:info("Client ~s send willmsg: ~p", [ClientId, WillMsg]),
emqttd_pubsub:publish(WillMsg#mqtt_message{from = ClientId}).
start_keepalive(0) -> ignore;
@ -368,7 +336,8 @@ validate_connect(Connect = #mqtt_packet_connect{}, ProtoState) ->
validate_protocol(#mqtt_packet_connect{proto_ver = Ver, proto_name = Name}) ->
lists:member({Ver, Name}, ?PROTOCOL_NAMES).
validate_clientid(#mqtt_packet_connect{client_id = ClientId}, #proto_state{max_clientid_len = MaxLen})
validate_clientid(#mqtt_packet_connect{client_id = ClientId},
#proto_state{max_clientid_len = MaxLen})
when (size(ClientId) >= 1) andalso (size(ClientId) =< MaxLen) ->
true;
@ -378,42 +347,44 @@ validate_clientid(#mqtt_packet_connect{proto_ver =?MQTT_PROTO_V311,
when size(ClientId) =:= 0 ->
true;
validate_clientid(#mqtt_packet_connect{proto_ver = Ver,
clean_sess = CleanSess,
client_id = ClientId}, _ProtoState) ->
lager:warning("Invalid ClientId: ~s, ProtoVer: ~p, CleanSess: ~s", [ClientId, Ver, CleanSess]),
validate_clientid(#mqtt_packet_connect{proto_ver = ProtoVer,
clean_sess = CleanSess}, ProtoState) ->
?LOG(warning, "Invalid clientId. ProtoVer: ~p, CleanSess: ~s",
[ProtoVer, CleanSess], ProtoState),
false.
validate_packet(#mqtt_packet{header = #mqtt_packet_header{type = ?PUBLISH},
variable = #mqtt_packet_publish{topic_name = Topic}}) ->
validate_packet(?PUBLISH_PACKET(_Qos, Topic, _PacketId, _Payload)) ->
case emqttd_topic:validate({name, Topic}) of
true -> ok;
false -> lager:warning("Error publish topic: ~p", [Topic]), {error, badtopic}
false -> {error, badtopic}
end;
validate_packet(#mqtt_packet{header = #mqtt_packet_header{type = ?SUBSCRIBE},
variable = #mqtt_packet_subscribe{topic_table = Topics}}) ->
validate_topics(filter, Topics);
validate_packet(#mqtt_packet{header = #mqtt_packet_header{type = ?UNSUBSCRIBE},
variable = #mqtt_packet_subscribe{topic_table = Topics}}) ->
validate_packet(?SUBSCRIBE_PACKET(_PacketId, TopicTable)) ->
validate_topics(filter, TopicTable);
validate_packet(?UNSUBSCRIBE_PACKET(_PacketId, Topics)) ->
validate_topics(filter, Topics);
validate_packet(_Packet) ->
ok.
validate_topics(Type, []) when Type =:= name orelse Type =:= filter ->
lager:error("Empty Topics!"),
validate_topics(_Type, []) ->
{error, empty_topics};
validate_topics(Type, Topics) when Type =:= name orelse Type =:= filter ->
ErrTopics = [Topic || {Topic, Qos} <- Topics,
not (emqttd_topic:validate({Type, Topic}) and validate_qos(Qos))],
case ErrTopics of
validate_topics(Type, TopicTable = [{_Topic, _Qos}|_])
when Type =:= name orelse Type =:= filter ->
Valid = fun(Topic, Qos) ->
emqttd_topic:validate({Type, Topic}) and validate_qos(Qos)
end,
case [Topic || {Topic, Qos} <- TopicTable, not Valid(Topic, Qos)] of
[] -> ok;
_ -> lager:error("Error Topics: ~p", [ErrTopics]), {error, badtopic}
_ -> {error, badtopic}
end;
validate_topics(Type, Topics = [Topic0|_]) when is_binary(Topic0) ->
case [Topic || Topic <- Topics, not emqttd_topic:validate({Type, Topic})] of
[] -> ok;
_ -> {error, badtopic}
end.
validate_qos(undefined) ->
@ -423,7 +394,7 @@ validate_qos(Qos) when ?IS_QOS(Qos) ->
validate_qos(_) ->
false.
%% publish ACL is cached in process dictionary.
%% PUBLISH ACL is cached in process dictionary.
check_acl(publish, Topic, State) ->
case get({acl, publish, Topic}) of
undefined ->