support to configure max packet size, fix issue#57

This commit is contained in:
Ery Lee 2015-03-12 15:53:52 +08:00
parent 320fd31ca9
commit 43e1b7e966
5 changed files with 85 additions and 62 deletions

View File

@ -56,19 +56,22 @@ open(Listeners) when is_list(Listeners) ->
%% open mqtt port %% open mqtt port
open({mqtt, Port, Options}) -> open({mqtt, Port, Options}) ->
MFArgs = {emqttd_client, start_link, []}, open(mqtt, Port, Options);
esockd:open(mqtt, Port, emqttd_opts:merge(?MQTT_SOCKOPTS, Options) , MFArgs);
%% open mqtt(SSL) port %% open mqtt(SSL) port
open({mqtts, Port, Options}) -> open({mqtts, Port, Options}) ->
MFArgs = {emqttd_client, start_link, []}, open(mqtts, Port, Options);
esockd:open(mqtts, Port, emqttd_opts:merge(?MQTT_SOCKOPTS, Options) , MFArgs);
%% open http port %% open http port
open({http, Port, Options}) -> open({http, Port, Options}) ->
MFArgs = {emqttd_http, handle, []}, MFArgs = {emqttd_http, handle, []},
mochiweb:start_http(Port, Options, MFArgs). mochiweb:start_http(Port, Options, MFArgs).
open(Protocol, Port, Options) ->
{ok, PktOpts} = application:get_env(emqttd, packet),
MFArgs = {emqttd_client, start_link, [PktOpts]},
esockd:open(Protocol, Port, emqttd_opts:merge(?MQTT_SOCKOPTS, Options) , MFArgs).
is_running(Node) -> is_running(Node) ->
case rpc:call(Node, erlang, whereis, [emqttd]) of case rpc:call(Node, erlang, whereis, [emqttd]) of
{badrpc, _} -> false; {badrpc, _} -> false;

View File

@ -30,7 +30,7 @@
-behaviour(gen_server). -behaviour(gen_server).
-export([start_link/1, info/1]). -export([start_link/2, info/1]).
-export([init/1, -export([init/1,
handle_call/3, handle_call/3,
@ -53,30 +53,34 @@
conserve, conserve,
parse_state, parse_state,
proto_state, proto_state,
packet_opts,
keepalive}). keepalive}).
start_link(SockArgs) -> start_link(SockArgs, PktOpts) ->
{ok, proc_lib:spawn_link(?MODULE, init, [SockArgs])}. {ok, proc_lib:spawn_link(?MODULE, init, [[SockArgs, PktOpts]])}.
%%TODO: rename? %%TODO: rename?
info(Pid) -> info(Pid) ->
gen_server:call(Pid, info). gen_server:call(Pid, info).
init(SockArgs = {Transport, Sock, _SockFun}) -> init([SockArgs = {Transport, Sock, _SockFun}, PacketOpts]) ->
%transform if ssl. %transform if ssl.
{ok, NewSock} = esockd_connection:accept(SockArgs), {ok, NewSock} = esockd_connection:accept(SockArgs),
{ok, Peername} = emqttd_net:peer_string(Sock), {ok, Peername} = emqttd_net:peer_string(Sock),
{ok, ConnStr} = emqttd_net:connection_string(Sock, inbound), {ok, ConnStr} = emqttd_net:connection_string(Sock, inbound),
lager:info("Connect from ~s", [ConnStr]), lager:info("Connect from ~s", [ConnStr]),
ParserState = emqttd_parser:init(PacketOpts),
ProtoState = emqttd_protocol:init({Transport, NewSock, Peername}, PacketOpts),
State = control_throttle(#state{transport = Transport, State = control_throttle(#state{transport = Transport,
socket = NewSock, socket = NewSock,
peer_name = Peername, peer_name = Peername,
conn_name = ConnStr, conn_name = ConnStr,
await_recv = false, await_recv = false,
conn_state = running, conn_state = running,
conserve = false, conserve = false,
parse_state = emqttd_parser:init(), packet_opts = PacketOpts,
proto_state = emqttd_protocol:init(Transport, NewSock, Peername)}), parse_state = ParserState,
proto_state = ProtoState}),
gen_server:enter_loop(?MODULE, [], State, 10000). gen_server:enter_loop(?MODULE, [], State, 10000).
%%TODO: Not enough... %%TODO: Not enough...
@ -168,7 +172,8 @@ code_change(_OldVsn, State, _Extra) ->
process_received_bytes(<<>>, State) -> process_received_bytes(<<>>, State) ->
{noreply, State, hibernate}; {noreply, State, hibernate};
process_received_bytes(Bytes, State = #state{parse_state = ParseState, process_received_bytes(Bytes, State = #state{packet_opts = PacketOpts,
parse_state = ParseState,
proto_state = ProtoState, proto_state = ProtoState,
conn_name = ConnStr}) -> conn_name = ConnStr}) ->
case emqttd_parser:parse(Bytes, ParseState) of case emqttd_parser:parse(Bytes, ParseState) of
@ -180,7 +185,7 @@ process_received_bytes(Bytes, State = #state{parse_state = ParseState,
received_stats(Packet), received_stats(Packet),
case emqttd_protocol:received(Packet, ProtoState) of case emqttd_protocol:received(Packet, ProtoState) of
{ok, ProtoState1} -> {ok, ProtoState1} ->
process_received_bytes(Rest, State#state{parse_state = emqttd_parser:init(), process_received_bytes(Rest, State#state{parse_state = emqttd_parser:init(PacketOpts),
proto_state = ProtoState1}); proto_state = ProtoState1});
{error, Error} -> {error, Error} ->
lager:error("MQTT protocol error ~p for connection ~p~n", [Error, ConnStr]), lager:error("MQTT protocol error ~p for connection ~p~n", [Error, ConnStr]),

View File

@ -31,7 +31,11 @@
-include("emqttd_packet.hrl"). -include("emqttd_packet.hrl").
%% API %% API
-export([init/0, parse/2]). -export([init/1, parse/2]).
-record(mqtt_packet_limit, {max_packet_size}).
-type option() :: {atom(), any()}.
%%%----------------------------------------------------------------------------- %%%-----------------------------------------------------------------------------
%% @doc %% @doc
@ -39,8 +43,11 @@
%% %%
%% @end %% @end
%%%----------------------------------------------------------------------------- %%%-----------------------------------------------------------------------------
-spec init() -> none. -spec init(Opts :: [option()]) -> {none, #mqtt_packet_limit{}}.
init() -> none. init(Opts) -> {none, limit(Opts)}.
limit(Opts) ->
#mqtt_packet_limit{max_packet_size = proplists:get_value(max_packet_size, Opts, ?MAX_LEN)}.
%%%----------------------------------------------------------------------------- %%%-----------------------------------------------------------------------------
%% @doc %% @doc
@ -48,33 +55,36 @@ init() -> none.
%% %%
%% @end %% @end
%%%----------------------------------------------------------------------------- %%%-----------------------------------------------------------------------------
-spec parse(binary(), none | fun()) -> {ok, mqtt_packet()} | {error, any()} | {more, fun()}. -spec parse(binary(), {none, [option()]} | fun()) -> {ok, mqtt_packet()} | {error, any()} | {more, fun()}.
parse(<<>>, none) -> parse(<<>>, {none, Limit}) ->
{more, fun(Bin) -> parse(Bin, none) end}; {more, fun(Bin) -> parse(Bin, {none, Limit}) end};
parse(<<PacketType:4, Dup:1, QoS:2, Retain:1, Rest/binary>>, none) -> parse(<<PacketType:4, Dup:1, QoS:2, Retain:1, Rest/binary>>, {none, Limit}) ->
parse_remaining_len(Rest, #mqtt_packet_header{type = PacketType, parse_remaining_len(Rest, #mqtt_packet_header{type = PacketType,
dup = bool(Dup), dup = bool(Dup),
qos = QoS, qos = QoS,
retain = bool(Retain)}); retain = bool(Retain)}, Limit);
parse(Bin, Cont) -> Cont(Bin). parse(Bin, Cont) -> Cont(Bin).
parse_remaining_len(<<>>, Header) -> parse_remaining_len(<<>>, Header, Limit) ->
{more, fun(Bin) -> parse_remaining_len(Bin, Header) end}; {more, fun(Bin) -> parse_remaining_len(Bin, Header, Limit) end};
parse_remaining_len(Rest, Header) -> parse_remaining_len(Rest, Header, Limit) ->
parse_remaining_len(Rest, Header, 1, 0). parse_remaining_len(Rest, Header, 1, 0, Limit).
parse_remaining_len(_Bin, _Header, _Multiplier, Length) parse_remaining_len(_Bin, _Header, _Multiplier, Length, #mqtt_packet_limit{max_packet_size = MaxLen})
when Length > ?MAX_LEN -> when Length > MaxLen ->
{error, invalid_mqtt_frame_len}; {error, invalid_mqtt_frame_len};
parse_remaining_len(<<>>, Header, Multiplier, Length) -> parse_remaining_len(<<>>, Header, Multiplier, Length, Limit) ->
{more, fun(Bin) -> parse_remaining_len(Bin, Header, Multiplier, Length) end}; {more, fun(Bin) -> parse_remaining_len(Bin, Header, Multiplier, Length, Limit) end};
parse_remaining_len(<<1:1, Len:7, Rest/binary>>, Header, Multiplier, Value) -> parse_remaining_len(<<1:1, Len:7, Rest/binary>>, Header, Multiplier, Value, Limit) ->
parse_remaining_len(Rest, Header, Multiplier * ?HIGHBIT, Value + Len * Multiplier); parse_remaining_len(Rest, Header, Multiplier * ?HIGHBIT, Value + Len * Multiplier, Limit);
parse_remaining_len(<<0:1, Len:7, Rest/binary>>, Header, Multiplier, Value) -> parse_remaining_len(<<0:1, Len:7, Rest/binary>>, Header, Multiplier, Value, #mqtt_packet_limit{max_packet_size = MaxLen}) ->
parse_frame(Rest, Header, Value + Len * Multiplier). FrameLen = Value + Len * Multiplier,
if
FrameLen > MaxLen -> {error, invalid_mqtt_frame_len};
true -> parse_frame(Rest, Header, FrameLen)
end.
parse_frame(Bin, #mqtt_packet_header{type = Type, parse_frame(Bin, #mqtt_packet_header{type = Type, qos = Qos} = Header, Length) ->
qos = Qos} = Header, Length) ->
case {Type, Bin} of case {Type, Bin} of
{?CONNECT, <<FrameBin:Length/binary, Rest/binary>>} -> {?CONNECT, <<FrameBin:Length/binary, Rest/binary>>} ->
{ProtoName, Rest1} = parse_utf(FrameBin), {ProtoName, Rest1} = parse_utf(FrameBin),

View File

@ -31,7 +31,7 @@
-include("emqttd_packet.hrl"). -include("emqttd_packet.hrl").
%% API %% API
-export([init/3, client_id/1]). -export([init/2, client_id/1]).
-export([received/2, send/2, redeliver/2, shutdown/2]). -export([received/2, send/2, redeliver/2, shutdown/2]).
@ -43,32 +43,34 @@
socket, socket,
peer_name, peer_name,
connected = false, %received CONNECT action? connected = false, %received CONNECT action?
proto_vsn, proto_ver,
proto_name, proto_name,
%packet_id, %packet_id,
client_id, client_id,
clean_sess, clean_sess,
session, %% session state or session pid session, %% session state or session pid
will_msg will_msg,
max_clientid_len = ?MAX_CLIENTID_LEN
}). }).
-type proto_state() :: #proto_state{}. -type proto_state() :: #proto_state{}.
init(Transport, Socket, Peername) -> init({Transport, Socket, Peername}, Opts) ->
#proto_state{ #proto_state{
transport = Transport, transport = Transport,
socket = Socket, socket = Socket,
peer_name = Peername}. peer_name = Peername,
max_clientid_len = proplists:get_value(max_clientid_len, Opts, ?MAX_CLIENTID_LEN)}.
client_id(#proto_state{client_id = ClientId}) -> ClientId. client_id(#proto_state{client_id = ClientId}) -> ClientId.
%%SHOULD be registered in emqttd_cm %%SHOULD be registered in emqttd_cm
info(#proto_state{proto_vsn = ProtoVsn, info(#proto_state{proto_ver = ProtoVer,
proto_name = ProtoName, proto_name = ProtoName,
client_id = ClientId, client_id = ClientId,
clean_sess = CleanSess, clean_sess = CleanSess,
will_msg = WillMsg}) -> will_msg = WillMsg}) ->
[{proto_vsn, ProtoVsn}, [{proto_ver, ProtoVer},
{proto_name, ProtoName}, {proto_name, ProtoName},
{client_id, ClientId}, {client_id, ClientId},
{clean_sess, CleanSess}, {clean_sess, CleanSess},
@ -100,7 +102,8 @@ received(Packet = ?PACKET(_Type), State = #proto_state{peer_name = PeerName,
handle(Packet = ?CONNECT_PACKET(Var), State = #proto_state{peer_name = PeerName}) -> handle(Packet = ?CONNECT_PACKET(Var), State = #proto_state{peer_name = PeerName}) ->
#mqtt_packet_connect{username = Username, #mqtt_packet_connect{proto_ver = ProtoVer,
username = Username,
password = Password, password = Password,
clean_sess = CleanSess, clean_sess = CleanSess,
keep_alive = KeepAlive, keep_alive = KeepAlive,
@ -109,22 +112,24 @@ handle(Packet = ?CONNECT_PACKET(Var), State = #proto_state{peer_name = PeerName}
lager:info("RECV from ~s@~s: ~s", [ClientId, PeerName, emqttd_packet:dump(Packet)]), lager:info("RECV from ~s@~s: ~s", [ClientId, PeerName, emqttd_packet:dump(Packet)]),
{ReturnCode1, State1} = {ReturnCode1, State1} =
case validate_connect(Var) of case validate_connect(Var, State) of
?CONNACK_ACCEPT -> ?CONNACK_ACCEPT ->
case emqttd_auth:check(Username, Password) of case emqttd_auth:check(Username, Password) of
true -> true ->
ClientId1 = clientid(ClientId, State), ClientId1 = clientid(ClientId, State),
start_keepalive(KeepAlive), start_keepalive(KeepAlive),
emqttd_cm:register(ClientId1, self()), emqttd_cm:register(ClientId1, self()),
{?CONNACK_ACCEPT, State#proto_state{will_msg = willmsg(Var), {?CONNACK_ACCEPT, State#proto_state{proto_ver = ProtoVer,
client_id = ClientId1,
clean_sess = CleanSess, clean_sess = CleanSess,
client_id = ClientId1}}; will_msg = willmsg(Var)}};
false -> false ->
lager:error("~s@~s: username '~s' login failed - no credentials", [ClientId, PeerName, Username]), lager:error("~s@~s: username '~s' login failed - no credentials", [ClientId, PeerName, Username]),
{?CONNACK_CREDENTIALS, State#proto_state{client_id = ClientId}} {?CONNACK_CREDENTIALS, State#proto_state{client_id = ClientId}}
end; end;
ReturnCode -> ReturnCode ->
{ReturnCode, State#proto_state{client_id = ClientId}} {ReturnCode, State#proto_state{client_id = ClientId,
clean_sess = CleanSess}}
end, end,
notify(connected, ReturnCode1, State1), notify(connected, ReturnCode1, State1),
send(?CONNACK_PACKET(ReturnCode1), State1), send(?CONNACK_PACKET(ReturnCode1), State1),
@ -234,10 +239,10 @@ start_keepalive(Sec) when Sec > 0 ->
%%---------------------------------------------------------------------------- %%----------------------------------------------------------------------------
%% Validate Packets %% Validate Packets
%%---------------------------------------------------------------------------- %%----------------------------------------------------------------------------
validate_connect(Connect = #mqtt_packet_connect{}) -> validate_connect(Connect = #mqtt_packet_connect{}, ProtoState) ->
case validate_protocol(Connect) of case validate_protocol(Connect) of
true -> true ->
case validate_clientid(Connect) of case validate_clientid(Connect, ProtoState) of
true -> true ->
?CONNACK_ACCEPT; ?CONNACK_ACCEPT;
false -> false ->
@ -250,16 +255,16 @@ validate_connect(Connect = #mqtt_packet_connect{}) ->
validate_protocol(#mqtt_packet_connect{proto_ver = Ver, proto_name = Name}) -> validate_protocol(#mqtt_packet_connect{proto_ver = Ver, proto_name = Name}) ->
lists:member({Ver, Name}, ?PROTOCOL_NAMES). lists:member({Ver, Name}, ?PROTOCOL_NAMES).
validate_clientid(#mqtt_packet_connect{client_id = ClientId}) validate_clientid(#mqtt_packet_connect{client_id = ClientId}, #proto_state{max_clientid_len = MaxLen})
when ( size(ClientId) >= 1 ) andalso ( size(ClientId) =< ?MAX_CLIENTID_LEN ) -> when ( size(ClientId) >= 1 ) andalso ( size(ClientId) =< MaxLen ) ->
true; true;
%% MQTT3.1.1 allow null clientId. %% MQTT3.1.1 allow null clientId.
validate_clientid(#mqtt_packet_connect{proto_ver =?MQTT_PROTO_V311, client_id = ClientId}) validate_clientid(#mqtt_packet_connect{proto_ver =?MQTT_PROTO_V311, client_id = ClientId}, _ProtoState)
when size(ClientId) =:= 0 -> when size(ClientId) =:= 0 ->
true; true;
validate_clientid(#mqtt_packet_connect {proto_ver = Ver, clean_sess = CleanSess, client_id = ClientId}) -> validate_clientid(#mqtt_packet_connect {proto_ver = Ver, clean_sess = CleanSess, client_id = ClientId}, _ProtoState) ->
lager:warning("Invalid ClientId: ~s, ProtoVer: ~p, CleanSess: ~s", [ClientId, Ver, CleanSess]), lager:warning("Invalid ClientId: ~s, ProtoVer: ~p, CleanSess: ~s", [ClientId, Ver, CleanSess]),
false. false.
@ -320,7 +325,7 @@ inc(_) ->
ingore. ingore.
notify(connected, ReturnCode, #proto_state{peer_name = PeerName, notify(connected, ReturnCode, #proto_state{peer_name = PeerName,
proto_vsn = ProtoVsn, proto_ver = ProtoVer,
client_id = ClientId, client_id = ClientId,
clean_sess = CleanSess}) -> clean_sess = CleanSess}) ->
Sess = case CleanSess of Sess = case CleanSess of
@ -328,7 +333,7 @@ notify(connected, ReturnCode, #proto_state{peer_name = PeerName,
false -> true false -> true
end, end,
Params = [{from, PeerName}, Params = [{from, PeerName},
{protocol, ProtoVsn}, {protocol, ProtoVer},
{session, Sess}, {session, Sess},
{connack, ReturnCode}], {connack, ReturnCode}],
emqttd_event:notify({connected, ClientId, Params}). emqttd_event:notify({connected, ClientId, Params}).

View File

@ -43,7 +43,7 @@
{access, []}, {access, []},
{packet, [ {packet, [
{max_clientid_len, 1024}, {max_clientid_len, 1024},
{max_packet_size, 64k}, {max_packet_size, 16#ffff}
]}, ]},
{session, [ {session, [
{expires, 1}, {expires, 1},