From 6b8ffc386a604a6154e41f679376cfa9ef538247 Mon Sep 17 00:00:00 2001 From: turtleDeng Date: Sat, 4 Jan 2020 10:06:50 +0800 Subject: [PATCH] Fix WS reason code (#3149) --- .travis.yml | 5 +++++ src/emqx_channel.erl | 3 ++- src/emqx_ws_connection.erl | 6 +++++- test/emqx_connection_SUITE.erl | 6 +++++- test/emqx_ws_connection_SUITE.erl | 11 +++++------ 5 files changed, 22 insertions(+), 9 deletions(-) diff --git a/.travis.yml b/.travis.yml index 90b15d560..95074d084 100644 --- a/.travis.yml +++ b/.travis.yml @@ -8,10 +8,15 @@ before_install: script: - make compile + - rm -rf rebar.lock - make xref + - rm -rf rebar.lock - make eunit + - rm -rf rebar.lock - make ct + - rm -rf rebar.lock - make cover + - rm -rf rebar.lock after_success: - make coveralls diff --git a/src/emqx_channel.erl b/src/emqx_channel.erl index 7cb25848d..23dee6ce9 100644 --- a/src/emqx_channel.erl +++ b/src/emqx_channel.erl @@ -699,7 +699,8 @@ return_unsuback(Packet, Channel) -> -> {reply, Reply :: term(), channel()} | {shutdown, Reason :: term(), Reply :: term(), channel()}). handle_call(kick, Channel) -> - shutdown(kicked, ok, Channel); + Channel1 = ensure_disconnected(kicked, Channel), + shutdown(kicked, ok, Channel1); handle_call(discard, Channel = #channel{conn_state = connected}) -> Packet = ?DISCONNECT_PACKET(?RC_SESSION_TAKEN_OVER), diff --git a/src/emqx_ws_connection.erl b/src/emqx_ws_connection.erl index 7409d0318..ceae32823 100644 --- a/src/emqx_ws_connection.erl +++ b/src/emqx_ws_connection.erl @@ -317,6 +317,8 @@ websocket_info({stop, Reason}, State) -> websocket_info(Info, State) -> handle_info(Info, State). +websocket_close({_, ReasonCode, _Payload}, State) when is_integer(ReasonCode) -> + websocket_close(ReasonCode, State); websocket_close(Reason, State) -> ?LOG(debug, "Websocket closed due to ~p~n", [Reason]), handle_info({sock_closed, Reason}, State). @@ -360,7 +362,7 @@ handle_info({connack, ConnAck}, State) -> handle_info({close, Reason}, State) -> ?LOG(debug, "Force to close the socket due to ~p", [Reason]), - return(enqueue(close, State)); + return(enqueue({close, Reason}, State)); handle_info({event, connected}, State = #state{channel = Channel}) -> ClientId = emqx_channel:info(clientid, Channel), @@ -639,6 +641,8 @@ classify([Cmd = {shutdown, _Reason}|More], Packets, Cmds, Events) -> classify(More, Packets, [Cmd|Cmds], Events); classify([Cmd = close|More], Packets, Cmds, Events) -> classify(More, Packets, [Cmd|Cmds], Events); +classify([Cmd = {close, _Reason}|More], Packets, Cmds, Events) -> + classify(More, Packets, [Cmd|Cmds], Events); classify([Event|More], Packets, Cmds, Events) -> classify(More, Packets, Cmds, [Event|Events]). diff --git a/test/emqx_connection_SUITE.erl b/test/emqx_connection_SUITE.erl index 12e85535f..822194443 100644 --- a/test/emqx_connection_SUITE.erl +++ b/test/emqx_connection_SUITE.erl @@ -41,6 +41,7 @@ init_per_suite(Config) -> ok = meck:new(emqx_pd, [passthrough, no_history, no_link]), %% Meck Metrics ok = meck:new(emqx_metrics, [passthrough, no_history, no_link]), + ok = meck:expect(emqx_metrics, inc, fun(_) -> ok end), ok = meck:expect(emqx_metrics, inc, fun(_, _) -> ok end), ok = meck:expect(emqx_metrics, inc_recv, fun(_) -> ok end), ok = meck:expect(emqx_metrics, inc_sent, fun(_) -> ok end), @@ -48,6 +49,9 @@ init_per_suite(Config) -> ok = meck:new(emqx_hooks, [passthrough, no_history, no_link]), ok = meck:expect(emqx_hooks, run, fun(_Hook, _Args) -> ok end), ok = meck:expect(emqx_hooks, run_fold, fun(_Hook, _Args, Acc) -> {ok, Acc} end), + + ok = meck:expect(emqx_channel, ensure_disconnected, fun(_, Channel) -> Channel end), + Config. end_per_suite(_Config) -> @@ -218,7 +222,7 @@ t_handle_call(_) -> ?assertMatch({reply, _Info, _NSt}, emqx_connection:handle_call(self(), info, St)), ?assertMatch({reply, _Stats, _NSt }, emqx_connection:handle_call(self(), stats, St)), ?assertEqual({reply, ignored, St}, emqx_connection:handle_call(self(), for_testing, St)), - ?assertEqual({stop, {shutdown,kicked}, ok, St}, emqx_connection:handle_call(self(), kick, St)). + ?assertMatch({stop, {shutdown,kicked}, ok, _NSt}, emqx_connection:handle_call(self(), kick, St)). t_handle_timeout(_) -> TRef = make_ref(), diff --git a/test/emqx_ws_connection_SUITE.erl b/test/emqx_ws_connection_SUITE.erl index 9eaa9b99c..09c9d5dd7 100644 --- a/test/emqx_ws_connection_SUITE.erl +++ b/test/emqx_ws_connection_SUITE.erl @@ -206,7 +206,7 @@ t_websocket_info_incoming(_) -> username = <<"username">>, password = <<"passwd">> }, - {ok, St1} = websocket_info({incoming, ?CONNECT_PACKET(ConnPkt)}, st()), + {[{close,protocol_error}], St1} = websocket_info({incoming, ?CONNECT_PACKET(ConnPkt)}, st()), % ?assertEqual(<<224,2,130,0>>, iolist_to_binary(IoData1)), %% PINGREQ {[{binary, IoData2}], St2} = @@ -214,8 +214,7 @@ t_websocket_info_incoming(_) -> ?assertEqual(<<208,0>>, iolist_to_binary(IoData2)), %% PUBLISH Publish = ?PUBLISH_PACKET(?QOS_1, <<"t">>, 1, <<"payload">>), - {[{binary, IoData3}], _St3} = - websocket_info({incoming, Publish}, St2), + {[{binary, IoData3}], _St3} = websocket_info({incoming, Publish}, St2), ?assertEqual(<<64,4,0,1,0,0>>, iolist_to_binary(IoData3)). t_websocket_info_check_gc(_) -> @@ -248,7 +247,7 @@ t_websocket_info_timeout_retry(_) -> {ok, _St} = websocket_info({timeout, make_ref(), retry_delivery}, st()). t_websocket_info_close(_) -> - {[close], _St} = websocket_info({close, sock_error}, st()). + {[{close, _}], _St} = websocket_info({close, sock_error}, st()). t_websocket_info_shutdown(_) -> {[{shutdown, reason}], _St} = websocket_info({shutdown, reason}, st()). @@ -266,7 +265,7 @@ t_handle_info_connack(_) -> ?assertEqual(<<32,2,0,0>>, iolist_to_binary(IoData)). t_handle_info_close(_) -> - {[close], _St} = ?ws_conn:handle_info({close, protocol_error}, st()). + {[{close, _}], _St} = ?ws_conn:handle_info({close, protocol_error}, st()). t_handle_info_event(_) -> ok = meck:new(emqx_cm, [passthrough, no_history]), @@ -315,7 +314,7 @@ t_parse_incoming_frame_error(_) -> t_handle_incomming_frame_error(_) -> FrameError = {frame_error, bad_qos}, Serialize = emqx_frame:serialize_fun(#{version => 5, max_size => 16#FFFF}), - {ok, _St} = ?ws_conn:handle_incoming(FrameError, st(#{serialize => Serialize})). + {[{close, bad_qos}], _St} = ?ws_conn:handle_incoming(FrameError, st(#{serialize => Serialize})). % ?assertEqual(<<224,2,129,0>>, iolist_to_binary(IoData)). t_handle_outgoing(_) ->