Improve the emqttd_parser design

This commit is contained in:
Feng Lee 2017-02-21 20:10:17 +08:00
parent de12c58af0
commit 7e98650233
2 changed files with 26 additions and 30 deletions

View File

@ -52,9 +52,9 @@
-export([prioritise_call/4, prioritise_info/3, handle_pre_hibernate/1]). -export([prioritise_call/4, prioritise_info/3, handle_pre_hibernate/1]).
%% Client State %% Client State
-record(client_state, {connection, connname, peername, peerhost, peerport, -record(client_state, {connection, connname, peername, peerhost, peerport, await_recv,
await_recv, conn_state, rate_limit, parser_fun, conn_state, rate_limit, packet_limit, parse_state, proto_state,
proto_state, packet_opts, keepalive, enable_stats}). keepalive, enable_stats}).
-define(INFO_KEYS, [connname, peername, peerhost, peerport, await_recv, conn_state]). -define(INFO_KEYS, [connname, peername, peerhost, peerport, await_recv, conn_state]).
@ -120,9 +120,10 @@ init([Conn0, Env]) ->
error:Error -> Self ! {shutdown, Error} error:Error -> Self ! {shutdown, Error}
end end
end, end,
ParserFun = emqttd_parser:new(Env),
ProtoState = emqttd_protocol:init(PeerName, SendFun, Env),
RateLimit = get_value(rate_limit, Conn:opts()), RateLimit = get_value(rate_limit, Conn:opts()),
PacketLimit = proplists:get_value(max_packet_size, Env, ?MAX_PACKET_LEN),
ParseState = emqttd_parser:initial_state(PacketLimit),
ProtoState = emqttd_protocol:init(PeerName, SendFun, Env),
EnableStats = get_value(client_enable_stats, Env, false), EnableStats = get_value(client_enable_stats, Env, false),
State = run_socket(#client_state{connection = Conn, State = run_socket(#client_state{connection = Conn,
connname = ConnName, connname = ConnName,
@ -132,9 +133,9 @@ init([Conn0, Env]) ->
await_recv = false, await_recv = false,
conn_state = running, conn_state = running,
rate_limit = RateLimit, rate_limit = RateLimit,
parser_fun = ParserFun, packet_limit = PacketLimit,
parse_state = ParseState,
proto_state = ProtoState, proto_state = ProtoState,
packet_opts = Env,
enable_stats = EnableStats}), enable_stats = EnableStats}),
IdleTimout = get_value(client_idle_timeout, Env, 30000), IdleTimout = get_value(client_idle_timeout, Env, 30000),
gen_server2:enter_loop(?MODULE, [], State, self(), IdleTimout, gen_server2:enter_loop(?MODULE, [], State, self(), IdleTimout,
@ -294,17 +295,17 @@ code_change(_OldVsn, State, _Extra) ->
received(<<>>, State) -> received(<<>>, State) ->
{noreply, State, hibernate}; {noreply, State, hibernate};
received(Bytes, State = #client_state{parser_fun = ParserFun, received(Bytes, State = #client_state{parse_state = ParseState,
packet_opts = PacketOpts, packet_limit = PacketLimit,
proto_state = ProtoState}) -> proto_state = ProtoState}) ->
case catch ParserFun(Bytes) of case catch emqttd_parser:parse(Bytes, ParseState) of
{more, NewParser} -> {more, NewParseState} ->
{noreply, run_socket(State#client_state{parser_fun = NewParser}), hibernate}; {noreply, run_socket(State#client_state{parse_state = NewParseState}), hibernate};
{ok, Packet, Rest} -> {ok, Packet, Rest} ->
emqttd_metrics:received(Packet), emqttd_metrics:received(Packet),
case emqttd_protocol:received(Packet, ProtoState) of case emqttd_protocol:received(Packet, ProtoState) of
{ok, ProtoState1} -> {ok, ProtoState1} ->
received(Rest, State#client_state{parser_fun = emqttd_parser:new(PacketOpts), received(Rest, State#client_state{parse_state = emqttd_parser:initial_state(PacketLimit),
proto_state = ProtoState1}); proto_state = ProtoState1});
{error, Error} -> {error, Error} ->
?LOG(error, "Protocol error - ~p", [Error], State), ?LOG(error, "Protocol error - ~p", [Error], State),

View File

@ -24,27 +24,22 @@
-include("emqttd_protocol.hrl"). -include("emqttd_protocol.hrl").
%% API %% API
-export([new/1, parse/2]). -export([initial_state/0, initial_state/1, parse/2]).
-record(mqtt_packet_limit, {max_packet_size}). -spec(initial_state() -> {none, pos_integer()}).
initial_state() ->
-type(option() :: {atom(), any()}). initial_state(?MAX_PACKET_LEN).
-type(parser() :: fun( (binary()) -> any() )).
%% @doc Initialize a parser %% @doc Initialize a parser
-spec(new(Opts :: [option()]) -> parser()). -spec(initial_state(pos_integer()) -> {none, pos_integer()}).
new(Opts) -> initial_state(MaxLen) ->
fun(Bin) -> parse(Bin, {none, limit(Opts)}) end. {none, MaxLen}.
limit(Opts) ->
#mqtt_packet_limit{max_packet_size = proplists:get_value(max_packet_size, Opts, ?MAX_LEN)}.
%% @doc Parse MQTT Packet %% @doc Parse MQTT Packet
-spec(parse(binary(), {none, [option()]} | fun()) -spec(parse(binary(), {none, pos_integer()} | fun())
-> {ok, mqtt_packet()} | {error, any()} | {more, fun()}). -> {ok, mqtt_packet()} | {error, any()} | {more, fun()}).
parse(<<>>, {none, Limit}) -> parse(<<>>, {none, MaxLen}) ->
{more, fun(Bin) -> parse(Bin, {none, Limit}) end}; {more, fun(Bin) -> parse(Bin, {none, MaxLen}) end};
parse(<<Type:4, Dup:1, QoS:2, Retain:1, Rest/binary>>, {none, Limit}) -> parse(<<Type:4, Dup:1, QoS:2, Retain:1, Rest/binary>>, {none, Limit}) ->
parse_remaining_len(Rest, #mqtt_packet_header{type = Type, parse_remaining_len(Rest, #mqtt_packet_header{type = Type,
dup = bool(Dup), dup = bool(Dup),
@ -57,7 +52,7 @@ parse_remaining_len(<<>>, Header, Limit) ->
parse_remaining_len(Rest, Header, Limit) -> parse_remaining_len(Rest, Header, Limit) ->
parse_remaining_len(Rest, Header, 1, 0, Limit). parse_remaining_len(Rest, Header, 1, 0, Limit).
parse_remaining_len(_Bin, _Header, _Multiplier, Length, #mqtt_packet_limit{max_packet_size = MaxLen}) parse_remaining_len(_Bin, _Header, _Multiplier, Length, MaxLen)
when Length > MaxLen -> when Length > MaxLen ->
{error, invalid_mqtt_frame_len}; {error, invalid_mqtt_frame_len};
parse_remaining_len(<<>>, Header, Multiplier, Length, Limit) -> parse_remaining_len(<<>>, Header, Multiplier, Length, Limit) ->
@ -70,7 +65,7 @@ parse_remaining_len(<<0:8, Rest/binary>>, Header, 1, 0, _Limit) ->
parse_frame(Rest, Header, 0); parse_frame(Rest, Header, 0);
parse_remaining_len(<<1:1, Len:7, Rest/binary>>, Header, Multiplier, Value, Limit) -> parse_remaining_len(<<1:1, Len:7, Rest/binary>>, Header, Multiplier, Value, Limit) ->
parse_remaining_len(Rest, Header, Multiplier * ?HIGHBIT, Value + Len * Multiplier, Limit); parse_remaining_len(Rest, Header, Multiplier * ?HIGHBIT, Value + Len * Multiplier, Limit);
parse_remaining_len(<<0:1, Len:7, Rest/binary>>, Header, Multiplier, Value, #mqtt_packet_limit{max_packet_size = MaxLen}) -> parse_remaining_len(<<0:1, Len:7, Rest/binary>>, Header, Multiplier, Value, MaxLen) ->
FrameLen = Value + Len * Multiplier, FrameLen = Value + Len * Multiplier,
if if
FrameLen > MaxLen -> {error, invalid_mqtt_frame_len}; FrameLen > MaxLen -> {error, invalid_mqtt_frame_len};