From 7e9865023375b3546873c144b4f597fb69b2b8c9 Mon Sep 17 00:00:00 2001 From: Feng Lee Date: Tue, 21 Feb 2017 20:10:17 +0800 Subject: [PATCH] Improve the emqttd_parser design --- src/emqttd_client.erl | 27 ++++++++++++++------------- src/emqttd_parser.erl | 29 ++++++++++++----------------- 2 files changed, 26 insertions(+), 30 deletions(-) diff --git a/src/emqttd_client.erl b/src/emqttd_client.erl index 765c9d837..e58f55738 100644 --- a/src/emqttd_client.erl +++ b/src/emqttd_client.erl @@ -52,9 +52,9 @@ -export([prioritise_call/4, prioritise_info/3, handle_pre_hibernate/1]). %% Client State --record(client_state, {connection, connname, peername, peerhost, peerport, - await_recv, conn_state, rate_limit, parser_fun, - proto_state, packet_opts, keepalive, enable_stats}). +-record(client_state, {connection, connname, peername, peerhost, peerport, await_recv, + conn_state, rate_limit, packet_limit, parse_state, proto_state, + keepalive, enable_stats}). -define(INFO_KEYS, [connname, peername, peerhost, peerport, await_recv, conn_state]). @@ -120,9 +120,10 @@ init([Conn0, Env]) -> error:Error -> Self ! {shutdown, Error} end end, - ParserFun = emqttd_parser:new(Env), - ProtoState = emqttd_protocol:init(PeerName, SendFun, Env), RateLimit = get_value(rate_limit, Conn:opts()), + PacketLimit = proplists:get_value(max_packet_size, Env, ?MAX_PACKET_LEN), + ParseState = emqttd_parser:initial_state(PacketLimit), + ProtoState = emqttd_protocol:init(PeerName, SendFun, Env), EnableStats = get_value(client_enable_stats, Env, false), State = run_socket(#client_state{connection = Conn, connname = ConnName, @@ -132,9 +133,9 @@ init([Conn0, Env]) -> await_recv = false, conn_state = running, rate_limit = RateLimit, - parser_fun = ParserFun, + packet_limit = PacketLimit, + parse_state = ParseState, proto_state = ProtoState, - packet_opts = Env, enable_stats = EnableStats}), IdleTimout = get_value(client_idle_timeout, Env, 30000), gen_server2:enter_loop(?MODULE, [], State, self(), IdleTimout, @@ -294,17 +295,17 @@ code_change(_OldVsn, State, _Extra) -> received(<<>>, State) -> {noreply, State, hibernate}; -received(Bytes, State = #client_state{parser_fun = ParserFun, - packet_opts = PacketOpts, +received(Bytes, State = #client_state{parse_state = ParseState, + packet_limit = PacketLimit, proto_state = ProtoState}) -> - case catch ParserFun(Bytes) of - {more, NewParser} -> - {noreply, run_socket(State#client_state{parser_fun = NewParser}), hibernate}; + case catch emqttd_parser:parse(Bytes, ParseState) of + {more, NewParseState} -> + {noreply, run_socket(State#client_state{parse_state = NewParseState}), hibernate}; {ok, Packet, Rest} -> emqttd_metrics:received(Packet), case emqttd_protocol:received(Packet, ProtoState) of {ok, ProtoState1} -> - received(Rest, State#client_state{parser_fun = emqttd_parser:new(PacketOpts), + received(Rest, State#client_state{parse_state = emqttd_parser:initial_state(PacketLimit), proto_state = ProtoState1}); {error, Error} -> ?LOG(error, "Protocol error - ~p", [Error], State), diff --git a/src/emqttd_parser.erl b/src/emqttd_parser.erl index 669f4aab2..bc6e0037b 100644 --- a/src/emqttd_parser.erl +++ b/src/emqttd_parser.erl @@ -24,27 +24,22 @@ -include("emqttd_protocol.hrl"). %% API --export([new/1, parse/2]). +-export([initial_state/0, initial_state/1, parse/2]). --record(mqtt_packet_limit, {max_packet_size}). - --type(option() :: {atom(), any()}). - --type(parser() :: fun( (binary()) -> any() )). +-spec(initial_state() -> {none, pos_integer()}). +initial_state() -> + initial_state(?MAX_PACKET_LEN). %% @doc Initialize a parser --spec(new(Opts :: [option()]) -> parser()). -new(Opts) -> - fun(Bin) -> parse(Bin, {none, limit(Opts)}) end. - -limit(Opts) -> - #mqtt_packet_limit{max_packet_size = proplists:get_value(max_packet_size, Opts, ?MAX_LEN)}. +-spec(initial_state(pos_integer()) -> {none, pos_integer()}). +initial_state(MaxLen) -> + {none, MaxLen}. %% @doc Parse MQTT Packet --spec(parse(binary(), {none, [option()]} | fun()) +-spec(parse(binary(), {none, pos_integer()} | fun()) -> {ok, mqtt_packet()} | {error, any()} | {more, fun()}). -parse(<<>>, {none, Limit}) -> - {more, fun(Bin) -> parse(Bin, {none, Limit}) end}; +parse(<<>>, {none, MaxLen}) -> + {more, fun(Bin) -> parse(Bin, {none, MaxLen}) end}; parse(<>, {none, Limit}) -> parse_remaining_len(Rest, #mqtt_packet_header{type = Type, dup = bool(Dup), @@ -57,7 +52,7 @@ parse_remaining_len(<<>>, Header, Limit) -> parse_remaining_len(Rest, Header, Limit) -> parse_remaining_len(Rest, Header, 1, 0, Limit). -parse_remaining_len(_Bin, _Header, _Multiplier, Length, #mqtt_packet_limit{max_packet_size = MaxLen}) +parse_remaining_len(_Bin, _Header, _Multiplier, Length, MaxLen) when Length > MaxLen -> {error, invalid_mqtt_frame_len}; parse_remaining_len(<<>>, Header, Multiplier, Length, Limit) -> @@ -70,7 +65,7 @@ parse_remaining_len(<<0:8, Rest/binary>>, Header, 1, 0, _Limit) -> parse_frame(Rest, Header, 0); parse_remaining_len(<<1:1, Len:7, Rest/binary>>, Header, Multiplier, Value, Limit) -> parse_remaining_len(Rest, Header, Multiplier * ?HIGHBIT, Value + Len * Multiplier, Limit); -parse_remaining_len(<<0:1, Len:7, Rest/binary>>, Header, Multiplier, Value, #mqtt_packet_limit{max_packet_size = MaxLen}) -> +parse_remaining_len(<<0:1, Len:7, Rest/binary>>, Header, Multiplier, Value, MaxLen) -> FrameLen = Value + Len * Multiplier, if FrameLen > MaxLen -> {error, invalid_mqtt_frame_len};