Improve the emqttd_parser design

This commit is contained in:
Feng Lee 2017-02-21 20:10:17 +08:00
parent de12c58af0
commit 7e98650233
2 changed files with 26 additions and 30 deletions

View File

@ -52,9 +52,9 @@
-export([prioritise_call/4, prioritise_info/3, handle_pre_hibernate/1]).
%% Client State
-record(client_state, {connection, connname, peername, peerhost, peerport,
await_recv, conn_state, rate_limit, parser_fun,
proto_state, packet_opts, keepalive, enable_stats}).
-record(client_state, {connection, connname, peername, peerhost, peerport, await_recv,
conn_state, rate_limit, packet_limit, parse_state, proto_state,
keepalive, enable_stats}).
-define(INFO_KEYS, [connname, peername, peerhost, peerport, await_recv, conn_state]).
@ -120,9 +120,10 @@ init([Conn0, Env]) ->
error:Error -> Self ! {shutdown, Error}
end
end,
ParserFun = emqttd_parser:new(Env),
ProtoState = emqttd_protocol:init(PeerName, SendFun, Env),
RateLimit = get_value(rate_limit, Conn:opts()),
PacketLimit = proplists:get_value(max_packet_size, Env, ?MAX_PACKET_LEN),
ParseState = emqttd_parser:initial_state(PacketLimit),
ProtoState = emqttd_protocol:init(PeerName, SendFun, Env),
EnableStats = get_value(client_enable_stats, Env, false),
State = run_socket(#client_state{connection = Conn,
connname = ConnName,
@ -132,9 +133,9 @@ init([Conn0, Env]) ->
await_recv = false,
conn_state = running,
rate_limit = RateLimit,
parser_fun = ParserFun,
packet_limit = PacketLimit,
parse_state = ParseState,
proto_state = ProtoState,
packet_opts = Env,
enable_stats = EnableStats}),
IdleTimout = get_value(client_idle_timeout, Env, 30000),
gen_server2:enter_loop(?MODULE, [], State, self(), IdleTimout,
@ -294,17 +295,17 @@ code_change(_OldVsn, State, _Extra) ->
received(<<>>, State) ->
{noreply, State, hibernate};
received(Bytes, State = #client_state{parser_fun = ParserFun,
packet_opts = PacketOpts,
received(Bytes, State = #client_state{parse_state = ParseState,
packet_limit = PacketLimit,
proto_state = ProtoState}) ->
case catch ParserFun(Bytes) of
{more, NewParser} ->
{noreply, run_socket(State#client_state{parser_fun = NewParser}), hibernate};
case catch emqttd_parser:parse(Bytes, ParseState) of
{more, NewParseState} ->
{noreply, run_socket(State#client_state{parse_state = NewParseState}), hibernate};
{ok, Packet, Rest} ->
emqttd_metrics:received(Packet),
case emqttd_protocol:received(Packet, ProtoState) of
{ok, ProtoState1} ->
received(Rest, State#client_state{parser_fun = emqttd_parser:new(PacketOpts),
received(Rest, State#client_state{parse_state = emqttd_parser:initial_state(PacketLimit),
proto_state = ProtoState1});
{error, Error} ->
?LOG(error, "Protocol error - ~p", [Error], State),

View File

@ -24,27 +24,22 @@
-include("emqttd_protocol.hrl").
%% API
-export([new/1, parse/2]).
-export([initial_state/0, initial_state/1, parse/2]).
-record(mqtt_packet_limit, {max_packet_size}).
-type(option() :: {atom(), any()}).
-type(parser() :: fun( (binary()) -> any() )).
-spec(initial_state() -> {none, pos_integer()}).
initial_state() ->
initial_state(?MAX_PACKET_LEN).
%% @doc Initialize a parser
-spec(new(Opts :: [option()]) -> parser()).
new(Opts) ->
fun(Bin) -> parse(Bin, {none, limit(Opts)}) end.
limit(Opts) ->
#mqtt_packet_limit{max_packet_size = proplists:get_value(max_packet_size, Opts, ?MAX_LEN)}.
-spec(initial_state(pos_integer()) -> {none, pos_integer()}).
initial_state(MaxLen) ->
{none, MaxLen}.
%% @doc Parse MQTT Packet
-spec(parse(binary(), {none, [option()]} | fun())
-spec(parse(binary(), {none, pos_integer()} | fun())
-> {ok, mqtt_packet()} | {error, any()} | {more, fun()}).
parse(<<>>, {none, Limit}) ->
{more, fun(Bin) -> parse(Bin, {none, Limit}) end};
parse(<<>>, {none, MaxLen}) ->
{more, fun(Bin) -> parse(Bin, {none, MaxLen}) end};
parse(<<Type:4, Dup:1, QoS:2, Retain:1, Rest/binary>>, {none, Limit}) ->
parse_remaining_len(Rest, #mqtt_packet_header{type = Type,
dup = bool(Dup),
@ -57,7 +52,7 @@ parse_remaining_len(<<>>, Header, Limit) ->
parse_remaining_len(Rest, Header, Limit) ->
parse_remaining_len(Rest, Header, 1, 0, Limit).
parse_remaining_len(_Bin, _Header, _Multiplier, Length, #mqtt_packet_limit{max_packet_size = MaxLen})
parse_remaining_len(_Bin, _Header, _Multiplier, Length, MaxLen)
when Length > MaxLen ->
{error, invalid_mqtt_frame_len};
parse_remaining_len(<<>>, Header, Multiplier, Length, Limit) ->
@ -70,7 +65,7 @@ parse_remaining_len(<<0:8, Rest/binary>>, Header, 1, 0, _Limit) ->
parse_frame(Rest, Header, 0);
parse_remaining_len(<<1:1, Len:7, Rest/binary>>, Header, Multiplier, Value, Limit) ->
parse_remaining_len(Rest, Header, Multiplier * ?HIGHBIT, Value + Len * Multiplier, Limit);
parse_remaining_len(<<0:1, Len:7, Rest/binary>>, Header, Multiplier, Value, #mqtt_packet_limit{max_packet_size = MaxLen}) ->
parse_remaining_len(<<0:1, Len:7, Rest/binary>>, Header, Multiplier, Value, MaxLen) ->
FrameLen = Value + Len * Multiplier,
if
FrameLen > MaxLen -> {error, invalid_mqtt_frame_len};