From 630b54f6eec97bd89313288d2d83ea03c346ee2c Mon Sep 17 00:00:00 2001 From: Shawn <506895667@qq.com> Date: Wed, 7 Jul 2021 14:43:54 +0800 Subject: [PATCH] feat(acl): make mqtt over websocket work with the new config --- apps/emqx/etc/emqx.conf | 11 ++- apps/emqx/src/emqx_connection.erl | 28 +++----- apps/emqx/src/emqx_listeners.erl | 25 +++---- apps/emqx/src/emqx_map_lib.erl | 11 +-- apps/emqx/src/emqx_schema.erl | 10 +-- apps/emqx/src/emqx_ws_connection.erl | 104 +++++++++++++++------------ 6 files changed, 100 insertions(+), 89 deletions(-) diff --git a/apps/emqx/etc/emqx.conf b/apps/emqx/etc/emqx.conf index b4b1393d6..4e3ff60c3 100644 --- a/apps/emqx/etc/emqx.conf +++ b/apps/emqx/etc/emqx.conf @@ -2218,8 +2218,8 @@ example_common_websocket_options { ## ## @doc listeners..websocket.compress ## ValueType: Boolean - ## Default: true - websocket.compress: true + ## Default: false + websocket.compress: false ## The idle timeout for external WebSocket connections. ## @@ -2244,6 +2244,13 @@ example_common_websocket_options { ## Default: true websocket.fail_if_no_subprotocol: true + ## Supported subprotocols + ## + ## @doc listeners..websocket.supported_subprotocols + ## ValueType: String + ## Default: mqtt, mqtt-v3, mqtt-v3.1.1, mqtt-v5 + websocket.supported_subprotocols: "mqtt, mqtt-v3, mqtt-v3.1.1, mqtt-v5" + ## Enable origin check in header for websocket connection ## ## @doc listeners..websocket.check_origin_enable diff --git a/apps/emqx/src/emqx_connection.erl b/apps/emqx/src/emqx_connection.erl index fb8fe241b..0df7cccf1 100644 --- a/apps/emqx/src/emqx_connection.erl +++ b/apps/emqx/src/emqx_connection.erl @@ -243,7 +243,7 @@ init(Parent, Transport, RawSocket, Options) -> exit_on_sock_error(Reason) end. -init_state(Transport, Socket, Options) -> +init_state(Transport, Socket, #{zone := Zone, listener := Listener} = Opts) -> {ok, Peername} = Transport:ensure_ok_or_exit(peername, [Socket]), {ok, Sockname} = Transport:ensure_ok_or_exit(sockname, [Socket]), Peercert = Transport:ensure_ok_or_exit(peercert, [Socket]), @@ -253,8 +253,6 @@ init_state(Transport, Socket, Options) -> peercert => Peercert, conn_mod => ?MODULE }, - Zone = maps:get(zone, Options), - Listener = maps:get(listener, Options), Limiter = emqx_limiter:init(Zone, undefined, undefined, []), FrameOpts = #{ strict_mode => emqx_config:get_listener_conf(Zone, Listener, [mqtt, strict_mode]), @@ -262,7 +260,7 @@ init_state(Transport, Socket, Options) -> }, ParseState = emqx_frame:initial_parse_state(FrameOpts), Serialize = emqx_frame:serialize_opts(), - Channel = emqx_channel:init(ConnInfo, Options), + Channel = emqx_channel:init(ConnInfo, Opts), GcState = case emqx_config:get_listener_conf(Zone, Listener, [force_gc]) of #{enable := false} -> undefined; GcPolicy -> emqx_gc:init(GcPolicy) @@ -295,11 +293,9 @@ run_loop(Parent, State = #state{transport = Transport, peername = Peername, channel = Channel}) -> emqx_logger:set_metadata_peername(esockd:format(Peername)), - case emqx_config:get_listener_conf(emqx_channel:info(zone, Channel), - emqx_channel:info(listener, Channel), [force_shutdown]) of - #{enable := false} -> ok; - ShutdownPolicy -> emqx_misc:tune_heap_size(ShutdownPolicy) - end, + ShutdownPolicy = emqx_config:get_listener_conf(emqx_channel:info(zone, Channel), + emqx_channel:info(listener, Channel), [force_shutdown]), + emqx_misc:tune_heap_size(ShutdownPolicy), case activate_socket(State) of {ok, NState} -> hibernate(Parent, NState); {error, Reason} -> @@ -793,15 +789,11 @@ check_oom(State = #state{channel = Channel}) -> ShutdownPolicy = emqx_config:get_listener_conf(emqx_channel:info(zone, Channel), emqx_channel:info(listener, Channel), [force_shutdown]), ?tp(debug, check_oom, #{policy => ShutdownPolicy}), - case ShutdownPolicy of - #{enable := false} -> ok; - ShutdownPolicy -> - case emqx_misc:check_oom(ShutdownPolicy) of - {shutdown, Reason} -> - %% triggers terminate/2 callback immediately - erlang:exit({shutdown, Reason}); - _ -> ok - end + case emqx_misc:check_oom(ShutdownPolicy) of + {shutdown, Reason} -> + %% triggers terminate/2 callback immediately + erlang:exit({shutdown, Reason}); + _ -> ok end, State. diff --git a/apps/emqx/src/emqx_listeners.erl b/apps/emqx/src/emqx_listeners.erl index 9daf6c79e..6e3bc69be 100644 --- a/apps/emqx/src/emqx_listeners.erl +++ b/apps/emqx/src/emqx_listeners.erl @@ -74,13 +74,13 @@ do_start_listener(ZoneName, ListenerName, #{type := tcp, bind := ListenOn} = Opt %% Start MQTT/WS listener do_start_listener(ZoneName, ListenerName, #{type := ws, bind := ListenOn} = Opts) -> Id = listener_id(ZoneName, ListenerName), - RanchOpts = ranch_opts(Opts), + RanchOpts = ranch_opts(ListenOn, Opts), WsOpts = ws_opts(ZoneName, ListenerName, Opts), case is_ssl(Opts) of false -> - cowboy:start_clear(Id, with_port(ListenOn, RanchOpts), WsOpts); + cowboy:start_clear(Id, RanchOpts, WsOpts); true -> - cowboy:start_tls(Id, with_port(ListenOn, RanchOpts), WsOpts) + cowboy:start_tls(Id, RanchOpts, WsOpts) end. esockd_opts(Opts0) -> @@ -104,21 +104,22 @@ ws_opts(ZoneName, ListenerName, Opts) -> ProxyProto = maps:get(proxy_protocol, Opts, false), #{env => #{dispatch => Dispatch}, proxy_header => ProxyProto}. -ranch_opts(Opts) -> +ranch_opts(ListenOn, Opts) -> NumAcceptors = maps:get(acceptors, Opts, 4), MaxConnections = maps:get(max_connections, Opts, 1024), + SocketOpts = case is_ssl(Opts) of + true -> tcp_opts(Opts) ++ proplists:delete(handshake_timeout, ssl_opts(Opts)); + false -> tcp_opts(Opts) + end, #{num_acceptors => NumAcceptors, max_connections => MaxConnections, handshake_timeout => maps:get(handshake_timeout, Opts, 15000), - socket_opts => case is_ssl(Opts) of - true -> tcp_opts(Opts) ++ proplists:delete(handshake_timeout, ssl_opts(Opts)); - false -> tcp_opts(Opts) - end}. + socket_opts => ip_port(ListenOn) ++ SocketOpts}. -with_port(Port, Opts = #{socket_opts := SocketOption}) when is_integer(Port) -> - Opts#{socket_opts => [{port, Port}| SocketOption]}; -with_port({Addr, Port}, Opts = #{socket_opts := SocketOption}) -> - Opts#{socket_opts => [{ip, Addr}, {port, Port}| SocketOption]}. +ip_port(Port) when is_integer(Port) -> + [{port, Port}]; +ip_port({Addr, Port}) -> + [{ip, Addr}, {port, Port}]. esockd_access_rules(StrRules) -> Access = fun(S) -> diff --git a/apps/emqx/src/emqx_map_lib.erl b/apps/emqx/src/emqx_map_lib.erl index 477891c51..154a3d24f 100644 --- a/apps/emqx/src/emqx_map_lib.erl +++ b/apps/emqx/src/emqx_map_lib.erl @@ -43,16 +43,17 @@ deep_get(ConfKeyPath, Map, Default) -> {ok, Data} -> Data end. --spec deep_find(config_key_path(), map()) -> {ok, term()} | {not_found, config_key(), term()}. +-spec deep_find(config_key_path(), map()) -> + {ok, term()} | {not_found, config_key_path(), term()}. deep_find([], Map) -> {ok, Map}; -deep_find([Key | KeyPath], Map) when is_map(Map) -> +deep_find([Key | KeyPath] = Path, Map) when is_map(Map) -> case maps:find(Key, Map) of {ok, SubMap} -> deep_find(KeyPath, SubMap); - error -> {not_found, Key, Map} + error -> {not_found, Path, Map} end; -deep_find([Key | _KeyPath], Data) -> - {not_found, Key, Data}. +deep_find(_KeyPath, Data) -> + {not_found, _KeyPath, Data}. -spec deep_put(config_key_path(), map(), term()) -> map(). deep_put([], Map, Config) when is_map(Map) -> diff --git a/apps/emqx/src/emqx_schema.erl b/apps/emqx/src/emqx_schema.erl index 98488edf8..0d6c8f3de 100644 --- a/apps/emqx/src/emqx_schema.erl +++ b/apps/emqx/src/emqx_schema.erl @@ -364,11 +364,11 @@ fields("mqtt_ws_listener") -> fields("ws_opts") -> [ {"mqtt_path", t(string(), undefined, "/mqtt")} , {"mqtt_piggyback", t(union(single, multiple), undefined, multiple)} - , {"compress", t(boolean())} + , {"compress", t(boolean(), undefined, false)} , {"idle_timeout", t(duration(), undefined, "15s")} , {"max_frame_size", maybe_infinity(integer())} , {"fail_if_no_subprotocol", t(boolean(), undefined, true)} - , {"supported_subprotocols", t(string(), undefined, + , {"supported_subprotocols", t(comma_separated_list(), undefined, "mqtt, mqtt-v3, mqtt-v3.1.1, mqtt-v5")} , {"check_origin_enable", t(boolean(), undefined, false)} , {"allow_origin_absence", t(boolean(), undefined, true)} @@ -401,12 +401,12 @@ fields("ssl_opts") -> fields("deflate_opts") -> [ {"level", t(union([none, default, best_compression, best_speed]))} - , {"mem_level", t(range(1, 9))} + , {"mem_level", t(range(1, 9), undefined, 8)} , {"strategy", t(union([default, filtered, huffman_only, rle]))} , {"server_context_takeover", t(union(takeover, no_takeover))} , {"client_context_takeover", t(union(takeover, no_takeover))} - , {"server_max_window_bits", t(integer())} - , {"client_max_window_bits", t(integer())} + , {"server_max_window_bits", t(range(8, 15), undefined, 15)} + , {"client_max_window_bits", t(range(8, 15), undefined, 15)} ]; fields("module") -> diff --git a/apps/emqx/src/emqx_ws_connection.erl b/apps/emqx/src/emqx_ws_connection.erl index b50505bf8..cd432b9fc 100644 --- a/apps/emqx/src/emqx_ws_connection.erl +++ b/apps/emqx/src/emqx_ws_connection.erl @@ -174,21 +174,13 @@ call(WsPid, Req, Timeout) when is_pid(WsPid) -> %% WebSocket callbacks %%-------------------------------------------------------------------- -init(Req, Opts) -> +init(Req, #{zone := Zone, listener := Listener} = Opts) -> %% WS Transport Idle Timeout - IdleTimeout = proplists:get_value(idle_timeout, Opts, 7200000), - DeflateOptions = maps:from_list(proplists:get_value(deflate_options, Opts, [])), - MaxFrameSize = case proplists:get_value(max_frame_size, Opts, 0) of - 0 -> infinity; - I -> I - end, - Compress = proplists:get_bool(compress, Opts), - WsOpts = #{compress => Compress, - deflate_opts => DeflateOptions, - max_frame_size => MaxFrameSize, - idle_timeout => IdleTimeout + WsOpts = #{compress => get_ws_opts(Zone, Listener, compress), + deflate_opts => get_ws_opts(Zone, Listener, deflate_opts), + max_frame_size => get_ws_opts(Zone, Listener, max_frame_size), + idle_timeout => get_ws_opts(Zone, Listener, idle_timeout) }, - case check_origin_header(Req, Opts) of {error, Message} -> ?LOG(error, "Invalid Origin Header ~p~n", [Message]), @@ -196,18 +188,17 @@ init(Req, Opts) -> ok -> parse_sec_websocket_protocol(Req, Opts, WsOpts) end. -parse_sec_websocket_protocol(Req, Opts, WsOpts) -> - FailIfNoSubprotocol = proplists:get_value(fail_if_no_subprotocol, Opts), +parse_sec_websocket_protocol(Req, #{zone := Zone, listener := Listener} = Opts, WsOpts) -> case cowboy_req:parse_header(<<"sec-websocket-protocol">>, Req) of undefined -> - case FailIfNoSubprotocol of + case get_ws_opts(Zone, Listener, fail_if_no_subprotocol) of true -> {ok, cowboy_req:reply(400, Req), WsOpts}; false -> {cowboy_websocket, Req, [Req, Opts], WsOpts} end; Subprotocols -> - SupportedSubprotocols = proplists:get_value(supported_subprotocols, Opts), + SupportedSubprotocols = get_ws_opts(Zone, Listener, supported_subprotocols), NSupportedSubprotocols = [list_to_binary(Subprotocol) || Subprotocol <- SupportedSubprotocols], case pick_subprotocol(Subprotocols, NSupportedSubprotocols) of @@ -231,31 +222,30 @@ pick_subprotocol([Subprotocol | Rest], SupportedSubprotocols) -> pick_subprotocol(Rest, SupportedSubprotocols) end. -parse_header_fun_origin(Req, Opts) -> +parse_header_fun_origin(Req, #{zone := Zone, listener := Listener}) -> case cowboy_req:header(<<"origin">>, Req) of undefined -> - case proplists:get_bool(allow_origin_absence, Opts) of + case get_ws_opts(Zone, Listener, allow_origin_absence) of true -> ok; false -> {error, origin_header_cannot_be_absent} end; Value -> - Origins = proplists:get_value(check_origins, Opts, []), - case lists:member(Value, Origins) of + case lists:member(Value, get_ws_opts(Zone, Listener, check_origins)) of true -> ok; false -> {origin_not_allowed, Value} end end. -check_origin_header(Req, Opts) -> - case proplists:get_bool(check_origin_enable, Opts) of +check_origin_header(Req, #{zone := Zone, listener := Listener} = Opts) -> + case get_ws_opts(Zone, Listener, check_origin_enable) of true -> parse_header_fun_origin(Req, Opts); false -> ok end. -websocket_init([Req, Opts]) -> +websocket_init([Req, #{zone := Zone, listener := Listener} = Opts]) -> {Peername, Peercert} = - case proplists:get_bool(proxy_protocol, Opts) - andalso maps:get(proxy_header, Req) of + case emqx_config:get_listener_conf(Zone, Listener, [proxy_protocol]) andalso + maps:get(proxy_header, Req) of #{src_address := SrcAddr, src_port := SrcPort, ssl := SSL} -> SourceName = {SrcAddr, SrcPort}, %% Notice: Only CN is available in Proxy Protocol V2 additional info @@ -266,7 +256,7 @@ websocket_init([Req, Opts]) -> {SourceName, SourceSSL}; #{src_address := SrcAddr, src_port := SrcPort} -> SourceName = {SrcAddr, SrcPort}, - {SourceName , nossl}; + {SourceName, nossl}; _ -> {get_peer(Req, Opts), cowboy_req:cert(Req)} end, @@ -288,22 +278,31 @@ websocket_init([Req, Opts]) -> ws_cookie => WsCookie, conn_mod => ?MODULE }, - Zone = proplists:get_value(zone, Opts), - PubLimit = emqx_zone:publish_limit(Zone), - BytesIn = proplists:get_value(rate_limit, Opts), - RateLimit = emqx_zone:ratelimit(Zone), - Limiter = emqx_limiter:init(Zone, PubLimit, BytesIn, RateLimit), - MQTTPiggyback = proplists:get_value(mqtt_piggyback, Opts, multiple), - FrameOpts = emqx_zone:mqtt_frame_options(Zone), + Limiter = emqx_limiter:init(Zone, undefined, undefined, []), + MQTTPiggyback = get_ws_opts(Zone, Listener, mqtt_piggyback), + FrameOpts = #{ + strict_mode => emqx_config:get_listener_conf(Zone, Listener, [mqtt, strict_mode]), + max_size => emqx_config:get_listener_conf(Zone, Listener, [mqtt, max_packet_size]) + }, ParseState = emqx_frame:initial_parse_state(FrameOpts), Serialize = emqx_frame:serialize_opts(), Channel = emqx_channel:init(ConnInfo, Opts), - GcState = emqx_zone:init_gc_state(Zone), - StatsTimer = emqx_zone:stats_timer(Zone), + GcState = case emqx_config:get_listener_conf(Zone, Listener, [force_gc]) of + #{enable := false} -> undefined; + GcPolicy -> emqx_gc:init(GcPolicy) + end, + StatsTimer = case emqx_config:get_listener_conf(Zone, Listener, [stats, enable]) of + true -> undefined; + false -> disabled + end, %% MQTT Idle Timeout - IdleTimeout = emqx_zone:idle_timeout(Zone), + IdleTimeout = emqx_channel:get_mqtt_conf(Zone, Listener, idle_timeout), IdleTimer = start_timer(IdleTimeout, idle_timeout), - emqx_misc:tune_heap_size(emqx_zone:oom_policy(Zone)), + case emqx_config:get_listener_conf(emqx_channel:info(zone, Channel), + emqx_channel:info(listener, Channel), [force_shutdown]) of + #{enable := false} -> ok; + ShutdownPolicy -> emqx_misc:tune_heap_size(ShutdownPolicy) + end, emqx_logger:set_metadata_peername(esockd:format(Peername)), {ok, #state{peername = Peername, sockname = Sockname, @@ -317,7 +316,9 @@ websocket_init([Req, Opts]) -> postponed = [], stats_timer = StatsTimer, idle_timeout = IdleTimeout, - idle_timer = IdleTimer + idle_timer = IdleTimer, + zone = Zone, + listener = Listener }, hibernate}. websocket_handle({binary, Data}, State) when is_list(Data) -> @@ -517,11 +518,16 @@ run_gc(Stats, State = #state{gc_state = GcSt}) -> end. check_oom(State = #state{channel = Channel}) -> - OomPolicy = emqx_zone:oom_policy(emqx_channel:info(zone, Channel)), - case ?ENABLED(OomPolicy) andalso emqx_misc:check_oom(OomPolicy) of - Shutdown = {shutdown, _Reason} -> - postpone(Shutdown, State); - _Other -> State + ShutdownPolicy = emqx_config:get_listener_conf(emqx_channel:info(zone, Channel), + emqx_channel:info(listener, Channel), [force_shutdown]), + case ShutdownPolicy of + #{enable := false} -> ok; + #{enable := true} -> + case emqx_misc:check_oom(ShutdownPolicy) of + Shutdown = {shutdown, _Reason} -> + postpone(Shutdown, State); + _Other -> State + end end. %%-------------------------------------------------------------------- @@ -741,9 +747,10 @@ classify([Event|More], Packets, Cmds, Events) -> trigger(Event) -> erlang:send(self(), Event). -get_peer(Req, Opts) -> +get_peer(Req, #{zone := Zone, listener := Listener}) -> {PeerAddr, PeerPort} = cowboy_req:peer(Req), - AddrHeader = cowboy_req:header(proplists:get_value(proxy_address_header, Opts), Req, <<>>), + AddrHeader = cowboy_req:header( + get_ws_opts(Zone, Listener, proxy_address_header), Req, <<>>), ClientAddr = case string:tokens(binary_to_list(AddrHeader), ", ") of [] -> undefined; @@ -756,7 +763,8 @@ get_peer(Req, Opts) -> _ -> PeerAddr end, - PortHeader = cowboy_req:header(proplists:get_value(proxy_port_header, Opts), Req, <<>>), + PortHeader = cowboy_req:header( + get_ws_opts(Zone, Listener, proxy_port_header), Req, <<>>), ClientPort = case string:tokens(binary_to_list(PortHeader), ", ") of [] -> undefined; @@ -777,3 +785,5 @@ set_field(Name, Value, State) -> Pos = emqx_misc:index_of(Name, record_info(fields, state)), setelement(Pos+1, State, Value). +get_ws_opts(Zone, Listener, Key) -> + emqx_config:get_listener_conf(Zone, Listener, [websocket, Key]).