fix(stomp): fix frame limitation is not working

This commit is contained in:
JianBo He 2023-06-08 16:57:12 +08:00
parent bad0c35bb9
commit 4065f08083
5 changed files with 208 additions and 18 deletions

View File

@ -970,6 +970,12 @@ close_socket(State = #state{socket = Socket}) ->
%% Inc incoming/outgoing stats %% Inc incoming/outgoing stats
inc_incoming_stats(Ctx, FrameMod, Packet) -> inc_incoming_stats(Ctx, FrameMod, Packet) ->
do_inc_incoming_stats(FrameMod:type(Packet), Ctx, FrameMod, Packet).
%% If a mailformed packet is received, the type of the packet is undefined.
do_inc_incoming_stats(undefined, _Ctx, _FrameMod, _Packet) ->
ok;
do_inc_incoming_stats(Type, Ctx, FrameMod, Packet) ->
inc_counter(recv_pkt, 1), inc_counter(recv_pkt, 1),
case FrameMod:is_message(Packet) of case FrameMod:is_message(Packet) of
true -> true ->
@ -978,9 +984,7 @@ inc_incoming_stats(Ctx, FrameMod, Packet) ->
false -> false ->
ok ok
end, end,
Name = list_to_atom( Name = list_to_atom(lists:concat(["packets.", Type, ".received"])),
lists:concat(["packets.", FrameMod:type(Packet), ".received"])
),
emqx_gateway_ctx:metrics_inc(Ctx, Name). emqx_gateway_ctx:metrics_inc(Ctx, Name).
inc_outgoing_stats(Ctx, FrameMod, Packet) -> inc_outgoing_stats(Ctx, FrameMod, Packet) ->

View File

@ -1,6 +1,6 @@
{application, emqx_gateway_stomp, [ {application, emqx_gateway_stomp, [
{description, "Stomp Gateway"}, {description, "Stomp Gateway"},
{vsn, "0.1.0"}, {vsn, "0.1.1"},
{registered, []}, {registered, []},
{applications, [kernel, stdlib, emqx, emqx_gateway]}, {applications, [kernel, stdlib, emqx, emqx_gateway]},
{env, []}, {env, []},

View File

@ -638,12 +638,12 @@ handle_in(
] ]
end, end,
{ok, Outgoings, Channel}; {ok, Outgoings, Channel};
handle_in({frame_error, Reason}, Channel = #channel{conn_state = idle}) ->
shutdown(Reason, Channel);
handle_in({frame_error, Reason}, Channel = #channel{conn_state = _ConnState}) -> handle_in({frame_error, Reason}, Channel = #channel{conn_state = _ConnState}) ->
?SLOG(error, #{ ErrMsg = io_lib:format("Frame error: ~0p", [Reason]),
msg => "unexpected_frame_error", Frame = error_frame(undefined, ErrMsg),
reason => Reason shutdown(Reason, Frame, Channel).
}),
shutdown(Reason, Channel).
with_transaction(Headers, Channel = #channel{transaction = Trans}, Fun) -> with_transaction(Headers, Channel = #channel{transaction = Trans}, Fun) ->
Id = header(<<"transaction">>, Headers), Id = header(<<"transaction">>, Headers),

View File

@ -129,8 +129,8 @@ initial_parse_state(Opts) ->
limit(Opts) -> limit(Opts) ->
#frame_limit{ #frame_limit{
max_header_num = g(max_header_num, Opts, ?MAX_HEADER_NUM), max_header_num = g(max_headers, Opts, ?MAX_HEADER_NUM),
max_header_length = g(max_header_length, Opts, ?MAX_HEADER_LENGTH), max_header_length = g(max_headers_length, Opts, ?MAX_HEADER_LENGTH),
max_body_length = g(max_body_length, Opts, ?MAX_BODY_LENGTH) max_body_length = g(max_body_length, Opts, ?MAX_BODY_LENGTH)
}. }.
@ -243,7 +243,9 @@ content_len(#parser_state{headers = Headers}) ->
false -> none false -> none
end. end.
new_frame(#parser_state{cmd = Cmd, headers = Headers, acc = Acc}) -> new_frame(#parser_state{cmd = Cmd, headers = Headers, acc = Acc, limit = Limit}) ->
ok = check_max_headers(Headers, Limit),
ok = check_max_body(Acc, Limit),
#stomp_frame{command = Cmd, headers = Headers, body = Acc}. #stomp_frame{command = Cmd, headers = Headers, body = Acc}.
acc(Chunk, State = #parser_state{acc = Acc}) when is_binary(Chunk) -> acc(Chunk, State = #parser_state{acc = Acc}) when is_binary(Chunk) ->
@ -261,6 +263,57 @@ unescape($c) -> ?COLON;
unescape($\\) -> ?BSL; unescape($\\) -> ?BSL;
unescape(_Ch) -> error(cannnot_unescape). unescape(_Ch) -> error(cannnot_unescape).
check_max_headers(
Headers,
#frame_limit{
max_header_num = MaxNum,
max_header_length = MaxLen
}
) ->
HeadersLen = length(Headers),
case HeadersLen > MaxNum of
true ->
error(
{too_many_headers, #{
max_header_num => MaxNum,
received_headers_num => length(Headers)
}}
);
false ->
ok
end,
lists:foreach(
fun({Name, Val}) ->
Len = byte_size(Name) + byte_size(Val),
case Len > MaxLen of
true ->
error(
{too_long_header, #{
max_header_length => MaxLen,
found_header_length => Len
}}
);
false ->
ok
end
end,
Headers
).
check_max_body(Acc, #frame_limit{max_body_length = MaxLen}) ->
Len = byte_size(Acc),
case Len > MaxLen of
true ->
error(
{too_long_body, #{
max_body_length => MaxLen,
received_body_length => Len
}}
);
false ->
ok
end.
%%-------------------------------------------------------------------- %%--------------------------------------------------------------------
%% Serialize funcs %% Serialize funcs
%%-------------------------------------------------------------------- %%--------------------------------------------------------------------
@ -330,7 +383,10 @@ make(Command, Headers, Body) ->
#stomp_frame{command = Command, headers = Headers, body = Body}. #stomp_frame{command = Command, headers = Headers, body = Body}.
%% @doc Format a frame %% @doc Format a frame
format(Frame) -> serialize_pkt(Frame, #{}). format({frame_error, _Reason} = Error) ->
Error;
format(Frame) ->
serialize_pkt(Frame, #{}).
is_message(#stomp_frame{command = CMD}) when is_message(#stomp_frame{command = CMD}) when
CMD == ?CMD_SEND; CMD == ?CMD_SEND;
@ -373,4 +429,6 @@ type(?CMD_RECEIPT) ->
type(?CMD_ERROR) -> type(?CMD_ERROR) ->
error; error;
type(?CMD_HEARTBEAT) -> type(?CMD_HEARTBEAT) ->
heartbeat. heartbeat;
type(_) ->
undefined.

View File

@ -40,7 +40,12 @@
" username = \"${Packet.headers.login}\"\n" " username = \"${Packet.headers.login}\"\n"
" password = \"${Packet.headers.passcode}\"\n" " password = \"${Packet.headers.passcode}\"\n"
" }\n" " }\n"
" listeners.tcp.default {\n" " frame {\n"
" max_headers = 10\n"
" max_headers_length = 100\n"
" max_body_length = 1024\n"
" }\n"
" listeners.tcp.default {\n"
" bind = 61613\n" " bind = 61613\n"
" }\n" " }\n"
"}\n" "}\n"
@ -705,6 +710,129 @@ t_sticky_packets_truncate_after_headers(_) ->
?assert(false, "waiting message timeout") ?assert(false, "waiting message timeout")
end end
end). end).
t_frame_error_in_connect(_) ->
with_connection(fun(Sock) ->
gen_tcp:send(
Sock,
serialize(
<<"CONNECT">>,
[
{<<"accept-version">>, ?STOMP_VER},
{<<"host">>, <<"127.0.0.1:61613">>},
{<<"login">>, <<"guest">>},
{<<"passcode">>, <<"guest">>},
{<<"heart-beat">>, <<"0,0">>},
{<<"custome_header1">>, <<"val">>},
{<<"custome_header2">>, <<"val">>},
{<<"custome_header3">>, <<"val">>},
{<<"custome_header4">>, <<"val">>},
{<<"custome_header5">>, <<"val">>},
{<<"custome_header6">>, <<"val">>}
]
)
),
?assertMatch({error, closed}, gen_tcp:recv(Sock, 0))
end).
t_frame_error_too_many_headers(_) ->
Frame = serialize(
<<"SEND">>,
[
{<<"destination">>, <<"/queue/foo">>},
{<<"custome_header1">>, <<"val">>},
{<<"custome_header2">>, <<"val">>},
{<<"custome_header3">>, <<"val">>},
{<<"custome_header4">>, <<"val">>},
{<<"custome_header5">>, <<"val">>},
{<<"custome_header6">>, <<"val">>},
{<<"custome_header7">>, <<"val">>},
{<<"custome_header8">>, <<"val">>},
{<<"custome_header9">>, <<"val">>},
{<<"custome_header10">>, <<"val">>}
],
<<"test">>
),
Assert =
fun(Sock) ->
{ok, Data} = gen_tcp:recv(Sock, 0),
{ok, ErrorFrame, _, _} = parse(Data),
?assertMatch(#stomp_frame{command = <<"ERROR">>}, ErrorFrame),
?assertMatch(
match, re:run(ErrorFrame#stomp_frame.body, "too_many_headers", [{capture, none}])
),
?assertMatch({error, closed}, gen_tcp:recv(Sock, 0))
end,
test_frame_error(Frame, Assert).
t_frame_error_too_long_header(_) ->
LongHeaderVal = emqx_utils:bin_to_hexstr(crypto:strong_rand_bytes(50), upper),
Frame = serialize(
<<"SEND">>,
[
{<<"destination">>, <<"/queue/foo">>},
{<<"custome_header10">>, LongHeaderVal}
],
<<"test">>
),
Assert =
fun(Sock) ->
{ok, Data} = gen_tcp:recv(Sock, 0),
{ok, ErrorFrame, _, _} = parse(Data),
?assertMatch(#stomp_frame{command = <<"ERROR">>}, ErrorFrame),
?assertMatch(
match, re:run(ErrorFrame#stomp_frame.body, "too_long_header", [{capture, none}])
),
?assertMatch({error, closed}, gen_tcp:recv(Sock, 0))
end,
test_frame_error(Frame, Assert).
t_frame_error_too_long_body(_) ->
LongBody = emqx_utils:bin_to_hexstr(crypto:strong_rand_bytes(513), upper),
Frame = serialize(
<<"SEND">>,
[{<<"destination">>, <<"/queue/foo">>}],
LongBody
),
Assert =
fun(Sock) ->
{ok, Data} = gen_tcp:recv(Sock, 0),
{ok, ErrorFrame, _, _} = parse(Data),
?assertMatch(#stomp_frame{command = <<"ERROR">>}, ErrorFrame),
?assertMatch(
match, re:run(ErrorFrame#stomp_frame.body, "too_long_body", [{capture, none}])
),
?assertMatch({error, closed}, gen_tcp:recv(Sock, 0))
end,
test_frame_error(Frame, Assert).
test_frame_error(Frame, AssertFun) ->
with_connection(fun(Sock) ->
gen_tcp:send(
Sock,
serialize(
<<"CONNECT">>,
[
{<<"accept-version">>, ?STOMP_VER},
{<<"host">>, <<"127.0.0.1:61613">>},
{<<"login">>, <<"guest">>},
{<<"passcode">>, <<"guest">>},
{<<"heart-beat">>, <<"0,0">>}
]
)
),
{ok, Data} = gen_tcp:recv(Sock, 0),
{ok,
#stomp_frame{
command = <<"CONNECTED">>,
headers = _,
body = _
},
_, _} = parse(Data),
gen_tcp:send(Sock, Frame),
AssertFun(Sock)
end).
t_rest_clienit_info(_) -> t_rest_clienit_info(_) ->
with_connection(fun(Sock) -> with_connection(fun(Sock) ->
gen_tcp:send( gen_tcp:send(
@ -856,9 +984,9 @@ serialize(Command, Headers, Body) ->
parse(Data) -> parse(Data) ->
ProtoEnv = #{ ProtoEnv = #{
max_headers => 10, max_headers => 1024,
max_header_length => 1024, max_header_length => 10240,
max_body_length => 8192 max_body_length => 81920
}, },
Parser = emqx_stomp_frame:initial_parse_state(ProtoEnv), Parser = emqx_stomp_frame:initial_parse_state(ProtoEnv),
emqx_stomp_frame:parse(Data, Parser). emqx_stomp_frame:parse(Data, Parser).