From 81602c973c3dd9fd7d17d8353f83913ac3f79330 Mon Sep 17 00:00:00 2001 From: Shawn <506895667@qq.com> Date: Fri, 19 Mar 2021 13:17:47 +0800 Subject: [PATCH] fix(emqx): deny pingreq when mqtt not connected #4370 --- src/emqx_connection.erl | 3 --- src/emqx_ws_connection.erl | 3 --- test/emqx_connection_SUITE.erl | 21 ++++++++++------ test/emqx_ws_connection_SUITE.erl | 42 +++++++++++++++++++++++++++++-- 4 files changed, 54 insertions(+), 15 deletions(-) diff --git a/src/emqx_connection.erl b/src/emqx_connection.erl index 11a3a9418..88a1ba4d7 100644 --- a/src/emqx_connection.erl +++ b/src/emqx_connection.erl @@ -359,9 +359,6 @@ handle_msg({incoming, Packet = ?CONNECT_PACKET(ConnPkt)}, }, handle_incoming(Packet, NState); -handle_msg({incoming, ?PACKET(?PINGREQ)}, State) -> - handle_outgoing(?PACKET(?PINGRESP), State); - handle_msg({incoming, Packet}, State) -> handle_incoming(Packet, State); diff --git a/src/emqx_ws_connection.erl b/src/emqx_ws_connection.erl index 98cc2b52f..01f7b5e2b 100644 --- a/src/emqx_ws_connection.erl +++ b/src/emqx_ws_connection.erl @@ -348,9 +348,6 @@ websocket_info({incoming, Packet = ?CONNECT_PACKET(ConnPkt)}, State) -> NState = State#state{serialize = Serialize}, handle_incoming(Packet, cancel_idle_timer(NState)); -websocket_info({incoming, ?PACKET(?PINGREQ)}, State) -> - return(enqueue(?PACKET(?PINGRESP), State)); - websocket_info({incoming, Packet}, State) -> handle_incoming(Packet, State); diff --git a/test/emqx_connection_SUITE.erl b/test/emqx_connection_SUITE.erl index b011c6978..22abd2ca8 100644 --- a/test/emqx_connection_SUITE.erl +++ b/test/emqx_connection_SUITE.erl @@ -60,7 +60,7 @@ init_per_suite(Config) -> end_per_suite(_Config) -> ok = meck:unload(emqx_transport), - ok = meck:unload(emqx_channel), + catch meck:unload(emqx_channel), ok = meck:unload(emqx_cm), ok = meck:unload(emqx_limiter), ok = meck:unload(emqx_pd), @@ -69,7 +69,8 @@ end_per_suite(_Config) -> ok = meck:unload(emqx_alarm), ok. -init_per_testcase(_TestCase, Config) -> +init_per_testcase(TestCase, Config) when + TestCase =/= t_ws_pingreq_before_connected -> ok = meck:expect(emqx_transport, wait, fun(Sock) -> {ok, Sock} end), ok = meck:expect(emqx_transport, type, fun(_Sock) -> tcp end), ok = meck:expect(emqx_transport, ensure_ok_or_exit, @@ -87,6 +88,8 @@ init_per_testcase(_TestCase, Config) -> ok = meck:expect(emqx_transport, async_send, fun(_Sock, _Data) -> ok end), ok = meck:expect(emqx_transport, async_send, fun(_Sock, _Data, _Opts) -> ok end), ok = meck:expect(emqx_transport, fast_close, fun(_Sock) -> ok end), + Config; +init_per_testcase(_, Config) -> Config. end_per_testcase(_TestCase, Config) -> @@ -95,6 +98,9 @@ end_per_testcase(_TestCase, Config) -> %%-------------------------------------------------------------------- %% Test cases %%-------------------------------------------------------------------- +t_ws_pingreq_before_connected(_) -> + ?assertMatch({ok, [_, {close,protocol_error}], _}, + handle_msg({incoming, ?PACKET(?PINGREQ)}, st(#{}, #{conn_state => disconnected}))). t_info(_) -> CPid = spawn(fun() -> @@ -175,7 +181,6 @@ t_handle_msg(_) -> t_handle_msg_incoming(_) -> ?assertMatch({ok, _Out, _St}, handle_msg({incoming, ?CONNECT_PACKET(#mqtt_packet_connect{})}, st())), - ?assertEqual(ok, handle_msg({incoming, ?PACKET(?PINGREQ)}, st())), ok = meck:expect(emqx_channel, handle_in, fun(_Packet, Channel) -> {ok, Channel} end), ?assertMatch({ok, _St}, handle_msg({incoming, ?PUBLISH_PACKET(?QOS_1, <<"t">>, 1, <<"payload">>)}, st())), @@ -277,7 +282,6 @@ t_handle_incoming(_) -> t_with_channel(_) -> State = st(), - ok = meck:expect(emqx_channel, handle_in, fun(_, _) -> ok end), ?assertEqual({ok, State}, emqx_connection:with_channel(handle_in, [for_testing], State)), @@ -300,7 +304,8 @@ t_with_channel(_) -> {shutdown, [for_testing], ?DISCONNECT_PACKET(), Channel} end), ?assertMatch({stop, {shutdown,[for_testing]}, _NState}, - emqx_connection:with_channel(handle_in, [for_testing], State)). + emqx_connection:with_channel(handle_in, [for_testing], State)), + meck:unload(emqx_channel). t_handle_outgoing(_) -> ?assertEqual(ok, emqx_connection:handle_outgoing(?PACKET(?PINGRESP), st())), @@ -432,11 +437,13 @@ make_frame(Packet) -> payload(Len) -> iolist_to_binary(lists:duplicate(Len, 1)). -st() -> st(#{}). +st() -> st(#{}, #{}). st(InitFields) when is_map(InitFields) -> + st(InitFields, #{}). +st(InitFields, ChannelFields) when is_map(InitFields) -> St = emqx_connection:init_state(emqx_transport, sock, [#{zone => external}]), maps:fold(fun(N, V, S) -> emqx_connection:set_field(N, V, S) end, - emqx_connection:set_field(channel, channel(), St), + emqx_connection:set_field(channel, channel(ChannelFields), St), InitFields ). diff --git a/test/emqx_ws_connection_SUITE.erl b/test/emqx_ws_connection_SUITE.erl index d83d7d859..4cf803024 100644 --- a/test/emqx_ws_connection_SUITE.erl +++ b/test/emqx_ws_connection_SUITE.erl @@ -45,7 +45,9 @@ init_per_testcase(TestCase, Config) when TestCase =/= t_ws_sub_protocols_mqtt_equivalents, TestCase =/= t_ws_sub_protocols_mqtt, TestCase =/= t_ws_check_origin, - TestCase =/= t_ws_non_check_origin -> + TestCase =/= t_ws_pingreq_before_connected, + TestCase =/= t_ws_non_check_origin + -> %% Mock cowboy_req ok = meck:new(cowboy_req, [passthrough, no_history, no_link]), ok = meck:expect(cowboy_req, peer, fun(_) -> {{127,0,0,1}, 3456} end), @@ -89,7 +91,9 @@ end_per_testcase(TestCase, Config) when TestCase =/= t_ws_sub_protocols_mqtt_equivalents, TestCase =/= t_ws_sub_protocols_mqtt, TestCase =/= t_ws_check_origin, - TestCase =/= t_ws_non_check_origin -> + TestCase =/= t_ws_non_check_origin, + TestCase =/= t_ws_pingreq_before_connected + -> lists:foreach(fun meck:unload/1, [cowboy_req, emqx_zone, @@ -154,6 +158,40 @@ t_call(_) -> end), ?assertEqual(Info, ?ws_conn:call(WsPid, info)). +t_ws_pingreq_before_connected(_) -> + ok = emqx_ct_helpers:start_apps([]), + {ok, _} = application:ensure_all_started(gun), + {ok, WPID} = gun:open("127.0.0.1", 8083), + ws_pingreq(#{}), + gun:close(WPID), + emqx_ct_helpers:stop_apps([]). + +ws_pingreq(State) -> + receive + {gun_up, WPID, _Proto} -> + StreamRef = gun:ws_upgrade(WPID, "/mqtt", [], #{ + protocols => [{<<"mqtt">>, gun_ws_h}]}), + ws_pingreq(State#{wref => StreamRef}); + {gun_down, _WPID, _, Reason, _, _} -> + State#{result => {gun_down, Reason}}; + {gun_upgrade, WPID, _Ref, _Proto, _Data} -> + ct:pal("-- gun_upgrade, send ping-req"), + PingReq = {binary, <<192,0>>}, + ok = gun:ws_send(WPID, PingReq), + gun:flush(WPID), + ws_pingreq(State); + {gun_ws, _WPID, _Ref, {binary, <<208,0>>}} -> + ct:fail(unexpected_pingresp); + {gun_ws, _WPID, _Ref, Frame} -> + ct:pal("gun received frame: ~p", [Frame]), + ws_pingreq(State); + Message -> + ct:pal("Received Unknown Message on Gun: ~p~n",[Message]), + ws_pingreq(State) + after 1000 -> + ct:fail(ws_timeout) + end. + t_ws_sub_protocols_mqtt(_) -> ok = emqx_ct_helpers:start_apps([]), {ok, _} = application:ensure_all_started(gun),