diff --git a/apps/emqx_gateway/src/bhvrs/emqx_gateway_conn.erl b/apps/emqx_gateway/src/bhvrs/emqx_gateway_conn.erl index 4145a92a7..52f96bcd2 100644 --- a/apps/emqx_gateway/src/bhvrs/emqx_gateway_conn.erl +++ b/apps/emqx_gateway/src/bhvrs/emqx_gateway_conn.erl @@ -970,6 +970,12 @@ close_socket(State = #state{socket = Socket}) -> %% Inc incoming/outgoing stats inc_incoming_stats(Ctx, FrameMod, Packet) -> + do_inc_incoming_stats(FrameMod:type(Packet), Ctx, FrameMod, Packet). + +%% If a mailformed packet is received, the type of the packet is undefined. +do_inc_incoming_stats(undefined, _Ctx, _FrameMod, _Packet) -> + ok; +do_inc_incoming_stats(Type, Ctx, FrameMod, Packet) -> inc_counter(recv_pkt, 1), case FrameMod:is_message(Packet) of true -> @@ -978,9 +984,7 @@ inc_incoming_stats(Ctx, FrameMod, Packet) -> false -> ok end, - Name = list_to_atom( - lists:concat(["packets.", FrameMod:type(Packet), ".received"]) - ), + Name = list_to_atom(lists:concat(["packets.", Type, ".received"])), emqx_gateway_ctx:metrics_inc(Ctx, Name). inc_outgoing_stats(Ctx, FrameMod, Packet) -> diff --git a/apps/emqx_gateway_stomp/src/emqx_gateway_stomp.app.src b/apps/emqx_gateway_stomp/src/emqx_gateway_stomp.app.src index 38da1e18b..1fda99700 100644 --- a/apps/emqx_gateway_stomp/src/emqx_gateway_stomp.app.src +++ b/apps/emqx_gateway_stomp/src/emqx_gateway_stomp.app.src @@ -1,6 +1,6 @@ {application, emqx_gateway_stomp, [ {description, "Stomp Gateway"}, - {vsn, "0.1.0"}, + {vsn, "0.1.1"}, {registered, []}, {applications, [kernel, stdlib, emqx, emqx_gateway]}, {env, []}, diff --git a/apps/emqx_gateway_stomp/src/emqx_stomp_channel.erl b/apps/emqx_gateway_stomp/src/emqx_stomp_channel.erl index 7a16792a0..07dfd5f46 100644 --- a/apps/emqx_gateway_stomp/src/emqx_stomp_channel.erl +++ b/apps/emqx_gateway_stomp/src/emqx_stomp_channel.erl @@ -638,12 +638,12 @@ handle_in( ] end, {ok, Outgoings, Channel}; +handle_in({frame_error, Reason}, Channel = #channel{conn_state = idle}) -> + shutdown(Reason, Channel); handle_in({frame_error, Reason}, Channel = #channel{conn_state = _ConnState}) -> - ?SLOG(error, #{ - msg => "unexpected_frame_error", - reason => Reason - }), - shutdown(Reason, Channel). + ErrMsg = io_lib:format("Frame error: ~0p", [Reason]), + Frame = error_frame(undefined, ErrMsg), + shutdown(Reason, Frame, Channel). with_transaction(Headers, Channel = #channel{transaction = Trans}, Fun) -> Id = header(<<"transaction">>, Headers), diff --git a/apps/emqx_gateway_stomp/src/emqx_stomp_frame.erl b/apps/emqx_gateway_stomp/src/emqx_stomp_frame.erl index 4913d6b2a..561f9e229 100644 --- a/apps/emqx_gateway_stomp/src/emqx_stomp_frame.erl +++ b/apps/emqx_gateway_stomp/src/emqx_stomp_frame.erl @@ -129,8 +129,8 @@ initial_parse_state(Opts) -> limit(Opts) -> #frame_limit{ - max_header_num = g(max_header_num, Opts, ?MAX_HEADER_NUM), - max_header_length = g(max_header_length, Opts, ?MAX_HEADER_LENGTH), + max_header_num = g(max_headers, Opts, ?MAX_HEADER_NUM), + max_header_length = g(max_headers_length, Opts, ?MAX_HEADER_LENGTH), max_body_length = g(max_body_length, Opts, ?MAX_BODY_LENGTH) }. @@ -243,7 +243,9 @@ content_len(#parser_state{headers = Headers}) -> false -> none end. -new_frame(#parser_state{cmd = Cmd, headers = Headers, acc = Acc}) -> +new_frame(#parser_state{cmd = Cmd, headers = Headers, acc = Acc, limit = Limit}) -> + ok = check_max_headers(Headers, Limit), + ok = check_max_body(Acc, Limit), #stomp_frame{command = Cmd, headers = Headers, body = Acc}. acc(Chunk, State = #parser_state{acc = Acc}) when is_binary(Chunk) -> @@ -261,6 +263,57 @@ unescape($c) -> ?COLON; unescape($\\) -> ?BSL; unescape(_Ch) -> error(cannnot_unescape). +check_max_headers( + Headers, + #frame_limit{ + max_header_num = MaxNum, + max_header_length = MaxLen + } +) -> + HeadersLen = length(Headers), + case HeadersLen > MaxNum of + true -> + error( + {too_many_headers, #{ + max_header_num => MaxNum, + received_headers_num => length(Headers) + }} + ); + false -> + ok + end, + lists:foreach( + fun({Name, Val}) -> + Len = byte_size(Name) + byte_size(Val), + case Len > MaxLen of + true -> + error( + {too_long_header, #{ + max_header_length => MaxLen, + found_header_length => Len + }} + ); + false -> + ok + end + end, + Headers + ). + +check_max_body(Acc, #frame_limit{max_body_length = MaxLen}) -> + Len = byte_size(Acc), + case Len > MaxLen of + true -> + error( + {too_long_body, #{ + max_body_length => MaxLen, + received_body_length => Len + }} + ); + false -> + ok + end. + %%-------------------------------------------------------------------- %% Serialize funcs %%-------------------------------------------------------------------- @@ -330,7 +383,10 @@ make(Command, Headers, Body) -> #stomp_frame{command = Command, headers = Headers, body = Body}. %% @doc Format a frame -format(Frame) -> serialize_pkt(Frame, #{}). +format({frame_error, _Reason} = Error) -> + Error; +format(Frame) -> + serialize_pkt(Frame, #{}). is_message(#stomp_frame{command = CMD}) when CMD == ?CMD_SEND; @@ -373,4 +429,6 @@ type(?CMD_RECEIPT) -> type(?CMD_ERROR) -> error; type(?CMD_HEARTBEAT) -> - heartbeat. + heartbeat; +type(_) -> + undefined. diff --git a/apps/emqx_gateway_stomp/test/emqx_stomp_SUITE.erl b/apps/emqx_gateway_stomp/test/emqx_stomp_SUITE.erl index b4a8fe139..196ed703c 100644 --- a/apps/emqx_gateway_stomp/test/emqx_stomp_SUITE.erl +++ b/apps/emqx_gateway_stomp/test/emqx_stomp_SUITE.erl @@ -40,7 +40,12 @@ " username = \"${Packet.headers.login}\"\n" " password = \"${Packet.headers.passcode}\"\n" " }\n" - " listeners.tcp.default {\n" + " frame {\n" + " max_headers = 10\n" + " max_headers_length = 100\n" + " max_body_length = 1024\n" + " }\n" + " listeners.tcp.default {\n" " bind = 61613\n" " }\n" "}\n" @@ -705,6 +710,129 @@ t_sticky_packets_truncate_after_headers(_) -> ?assert(false, "waiting message timeout") end end). + +t_frame_error_in_connect(_) -> + with_connection(fun(Sock) -> + gen_tcp:send( + Sock, + serialize( + <<"CONNECT">>, + [ + {<<"accept-version">>, ?STOMP_VER}, + {<<"host">>, <<"127.0.0.1:61613">>}, + {<<"login">>, <<"guest">>}, + {<<"passcode">>, <<"guest">>}, + {<<"heart-beat">>, <<"0,0">>}, + {<<"custome_header1">>, <<"val">>}, + {<<"custome_header2">>, <<"val">>}, + {<<"custome_header3">>, <<"val">>}, + {<<"custome_header4">>, <<"val">>}, + {<<"custome_header5">>, <<"val">>}, + {<<"custome_header6">>, <<"val">>} + ] + ) + ), + ?assertMatch({error, closed}, gen_tcp:recv(Sock, 0)) + end). + +t_frame_error_too_many_headers(_) -> + Frame = serialize( + <<"SEND">>, + [ + {<<"destination">>, <<"/queue/foo">>}, + {<<"custome_header1">>, <<"val">>}, + {<<"custome_header2">>, <<"val">>}, + {<<"custome_header3">>, <<"val">>}, + {<<"custome_header4">>, <<"val">>}, + {<<"custome_header5">>, <<"val">>}, + {<<"custome_header6">>, <<"val">>}, + {<<"custome_header7">>, <<"val">>}, + {<<"custome_header8">>, <<"val">>}, + {<<"custome_header9">>, <<"val">>}, + {<<"custome_header10">>, <<"val">>} + ], + <<"test">> + ), + Assert = + fun(Sock) -> + {ok, Data} = gen_tcp:recv(Sock, 0), + {ok, ErrorFrame, _, _} = parse(Data), + ?assertMatch(#stomp_frame{command = <<"ERROR">>}, ErrorFrame), + ?assertMatch( + match, re:run(ErrorFrame#stomp_frame.body, "too_many_headers", [{capture, none}]) + ), + ?assertMatch({error, closed}, gen_tcp:recv(Sock, 0)) + end, + test_frame_error(Frame, Assert). + +t_frame_error_too_long_header(_) -> + LongHeaderVal = emqx_utils:bin_to_hexstr(crypto:strong_rand_bytes(50), upper), + Frame = serialize( + <<"SEND">>, + [ + {<<"destination">>, <<"/queue/foo">>}, + {<<"custome_header10">>, LongHeaderVal} + ], + <<"test">> + ), + Assert = + fun(Sock) -> + {ok, Data} = gen_tcp:recv(Sock, 0), + {ok, ErrorFrame, _, _} = parse(Data), + ?assertMatch(#stomp_frame{command = <<"ERROR">>}, ErrorFrame), + ?assertMatch( + match, re:run(ErrorFrame#stomp_frame.body, "too_long_header", [{capture, none}]) + ), + ?assertMatch({error, closed}, gen_tcp:recv(Sock, 0)) + end, + test_frame_error(Frame, Assert). + +t_frame_error_too_long_body(_) -> + LongBody = emqx_utils:bin_to_hexstr(crypto:strong_rand_bytes(513), upper), + Frame = serialize( + <<"SEND">>, + [{<<"destination">>, <<"/queue/foo">>}], + LongBody + ), + Assert = + fun(Sock) -> + {ok, Data} = gen_tcp:recv(Sock, 0), + {ok, ErrorFrame, _, _} = parse(Data), + ?assertMatch(#stomp_frame{command = <<"ERROR">>}, ErrorFrame), + ?assertMatch( + match, re:run(ErrorFrame#stomp_frame.body, "too_long_body", [{capture, none}]) + ), + ?assertMatch({error, closed}, gen_tcp:recv(Sock, 0)) + end, + test_frame_error(Frame, Assert). + +test_frame_error(Frame, AssertFun) -> + with_connection(fun(Sock) -> + gen_tcp:send( + Sock, + serialize( + <<"CONNECT">>, + [ + {<<"accept-version">>, ?STOMP_VER}, + {<<"host">>, <<"127.0.0.1:61613">>}, + {<<"login">>, <<"guest">>}, + {<<"passcode">>, <<"guest">>}, + {<<"heart-beat">>, <<"0,0">>} + ] + ) + ), + {ok, Data} = gen_tcp:recv(Sock, 0), + {ok, + #stomp_frame{ + command = <<"CONNECTED">>, + headers = _, + body = _ + }, + _, _} = parse(Data), + gen_tcp:send(Sock, Frame), + AssertFun(Sock) + end). + t_rest_clienit_info(_) -> with_connection(fun(Sock) -> gen_tcp:send( @@ -856,9 +984,9 @@ serialize(Command, Headers, Body) -> parse(Data) -> ProtoEnv = #{ - max_headers => 10, - max_header_length => 1024, - max_body_length => 8192 + max_headers => 1024, + max_header_length => 10240, + max_body_length => 81920 }, Parser = emqx_stomp_frame:initial_parse_state(ProtoEnv), emqx_stomp_frame:parse(Data, Parser).