From db3a54e31a654235208eace9f53531b86dd96462 Mon Sep 17 00:00:00 2001 From: Feng Lee Date: Mon, 20 Nov 2017 14:15:51 +0800 Subject: [PATCH] Fast close the invalid websocket in init/1 function --- src/emqttd_ws_client.erl | 43 +++++++++++++++++++++++++--------------- 1 file changed, 27 insertions(+), 16 deletions(-) diff --git a/src/emqttd_ws_client.erl b/src/emqttd_ws_client.erl index a583611a4..b9d25ad3e 100644 --- a/src/emqttd_ws_client.erl +++ b/src/emqttd_ws_client.erl @@ -91,23 +91,31 @@ clean_acl_cache(CPid, Topic) -> init([Env, WsPid, Req, ReplyChannel]) -> process_flag(trap_exit, true), - true = link(WsPid), - {ok, Peername} = Req:get(peername), - Headers = mochiweb_headers:to_list( - mochiweb_request:get(headers, Req)), Conn = Req:get(connection), - ProtoState = emqttd_protocol:init(Conn, Peername, send_fun(ReplyChannel), - [{ws_initial_headers, Headers} | Env]), - IdleTimeout = get_value(client_idle_timeout, Env, 30000), - EnableStats = get_value(client_enable_stats, Env, false), - ForceGcCount = emqttd_gc:conn_max_gc_count(), - {ok, #wsclient_state{connection = Conn, - ws_pid = WsPid, - peername = Peername, - proto_state = ProtoState, - enable_stats = EnableStats, - force_gc_count = ForceGcCount}, - IdleTimeout, {backoff, 2000, 2000, 20000}, ?MODULE}. + true = link(WsPid), + case Req:get(peername) of + {ok, Peername} -> + Headers = mochiweb_headers:to_list( + mochiweb_request:get(headers, Req)), + ProtoState = emqttd_protocol:init(Conn, Peername, send_fun(ReplyChannel), + [{ws_initial_headers, Headers} | Env]), + IdleTimeout = get_value(client_idle_timeout, Env, 30000), + EnableStats = get_value(client_enable_stats, Env, false), + ForceGcCount = emqttd_gc:conn_max_gc_count(), + {ok, #wsclient_state{connection = Conn, + ws_pid = WsPid, + peername = Peername, + proto_state = ProtoState, + enable_stats = EnableStats, + force_gc_count = ForceGcCount}, + IdleTimeout, {backoff, 2000, 2000, 20000}, ?MODULE}; + {error, enotconn} -> Conn:fast_close(), + exit(WsPid, normal), + exit(normal); + {error, Reason} -> Conn:fast_close(), + exit(WsPid, normal), + exit({shutdown, Reason}) + end. prioritise_call(Msg, _From, _Len, _State) -> case Msg of info -> 10; stats -> 10; state -> 10; _ -> 5 end. @@ -203,6 +211,9 @@ handle_info({shutdown, conflict, {ClientId, NewPid}}, State) -> ?WSLOG(warning, "clientid '~s' conflict with ~p", [ClientId, NewPid], State), shutdown(conflict, State); +handle_info({shutdown, Reason}, State) -> + shutdown(Reason, State); + handle_info({keepalive, start, Interval}, State = #wsclient_state{connection = Conn}) -> ?WSLOG(debug, "Keepalive at the interval of ~p", [Interval], State), case emqttd_keepalive:start(stat_fun(Conn), Interval, {keepalive, check}) of