feat(acl): make mqtt over websocket work with the new config

This commit is contained in:
Shawn 2021-07-07 14:43:54 +08:00
parent 707851c36f
commit 630b54f6ee
6 changed files with 100 additions and 89 deletions

View File

@ -2218,8 +2218,8 @@ example_common_websocket_options {
##
## @doc listeners.<name>.websocket.compress
## ValueType: Boolean
## Default: true
websocket.compress: true
## Default: false
websocket.compress: false
## The idle timeout for external WebSocket connections.
##
@ -2244,6 +2244,13 @@ example_common_websocket_options {
## Default: true
websocket.fail_if_no_subprotocol: true
## Supported subprotocols
##
## @doc listeners.<name>.websocket.supported_subprotocols
## ValueType: String
## Default: mqtt, mqtt-v3, mqtt-v3.1.1, mqtt-v5
websocket.supported_subprotocols: "mqtt, mqtt-v3, mqtt-v3.1.1, mqtt-v5"
## Enable origin check in header for websocket connection
##
## @doc listeners.<name>.websocket.check_origin_enable

View File

@ -243,7 +243,7 @@ init(Parent, Transport, RawSocket, Options) ->
exit_on_sock_error(Reason)
end.
init_state(Transport, Socket, Options) ->
init_state(Transport, Socket, #{zone := Zone, listener := Listener} = Opts) ->
{ok, Peername} = Transport:ensure_ok_or_exit(peername, [Socket]),
{ok, Sockname} = Transport:ensure_ok_or_exit(sockname, [Socket]),
Peercert = Transport:ensure_ok_or_exit(peercert, [Socket]),
@ -253,8 +253,6 @@ init_state(Transport, Socket, Options) ->
peercert => Peercert,
conn_mod => ?MODULE
},
Zone = maps:get(zone, Options),
Listener = maps:get(listener, Options),
Limiter = emqx_limiter:init(Zone, undefined, undefined, []),
FrameOpts = #{
strict_mode => emqx_config:get_listener_conf(Zone, Listener, [mqtt, strict_mode]),
@ -262,7 +260,7 @@ init_state(Transport, Socket, Options) ->
},
ParseState = emqx_frame:initial_parse_state(FrameOpts),
Serialize = emqx_frame:serialize_opts(),
Channel = emqx_channel:init(ConnInfo, Options),
Channel = emqx_channel:init(ConnInfo, Opts),
GcState = case emqx_config:get_listener_conf(Zone, Listener, [force_gc]) of
#{enable := false} -> undefined;
GcPolicy -> emqx_gc:init(GcPolicy)
@ -295,11 +293,9 @@ run_loop(Parent, State = #state{transport = Transport,
peername = Peername,
channel = Channel}) ->
emqx_logger:set_metadata_peername(esockd:format(Peername)),
case emqx_config:get_listener_conf(emqx_channel:info(zone, Channel),
emqx_channel:info(listener, Channel), [force_shutdown]) of
#{enable := false} -> ok;
ShutdownPolicy -> emqx_misc:tune_heap_size(ShutdownPolicy)
end,
ShutdownPolicy = emqx_config:get_listener_conf(emqx_channel:info(zone, Channel),
emqx_channel:info(listener, Channel), [force_shutdown]),
emqx_misc:tune_heap_size(ShutdownPolicy),
case activate_socket(State) of
{ok, NState} -> hibernate(Parent, NState);
{error, Reason} ->
@ -793,15 +789,11 @@ check_oom(State = #state{channel = Channel}) ->
ShutdownPolicy = emqx_config:get_listener_conf(emqx_channel:info(zone, Channel),
emqx_channel:info(listener, Channel), [force_shutdown]),
?tp(debug, check_oom, #{policy => ShutdownPolicy}),
case ShutdownPolicy of
#{enable := false} -> ok;
ShutdownPolicy ->
case emqx_misc:check_oom(ShutdownPolicy) of
{shutdown, Reason} ->
%% triggers terminate/2 callback immediately
erlang:exit({shutdown, Reason});
_ -> ok
end
case emqx_misc:check_oom(ShutdownPolicy) of
{shutdown, Reason} ->
%% triggers terminate/2 callback immediately
erlang:exit({shutdown, Reason});
_ -> ok
end,
State.

View File

@ -74,13 +74,13 @@ do_start_listener(ZoneName, ListenerName, #{type := tcp, bind := ListenOn} = Opt
%% Start MQTT/WS listener
do_start_listener(ZoneName, ListenerName, #{type := ws, bind := ListenOn} = Opts) ->
Id = listener_id(ZoneName, ListenerName),
RanchOpts = ranch_opts(Opts),
RanchOpts = ranch_opts(ListenOn, Opts),
WsOpts = ws_opts(ZoneName, ListenerName, Opts),
case is_ssl(Opts) of
false ->
cowboy:start_clear(Id, with_port(ListenOn, RanchOpts), WsOpts);
cowboy:start_clear(Id, RanchOpts, WsOpts);
true ->
cowboy:start_tls(Id, with_port(ListenOn, RanchOpts), WsOpts)
cowboy:start_tls(Id, RanchOpts, WsOpts)
end.
esockd_opts(Opts0) ->
@ -104,21 +104,22 @@ ws_opts(ZoneName, ListenerName, Opts) ->
ProxyProto = maps:get(proxy_protocol, Opts, false),
#{env => #{dispatch => Dispatch}, proxy_header => ProxyProto}.
ranch_opts(Opts) ->
ranch_opts(ListenOn, Opts) ->
NumAcceptors = maps:get(acceptors, Opts, 4),
MaxConnections = maps:get(max_connections, Opts, 1024),
SocketOpts = case is_ssl(Opts) of
true -> tcp_opts(Opts) ++ proplists:delete(handshake_timeout, ssl_opts(Opts));
false -> tcp_opts(Opts)
end,
#{num_acceptors => NumAcceptors,
max_connections => MaxConnections,
handshake_timeout => maps:get(handshake_timeout, Opts, 15000),
socket_opts => case is_ssl(Opts) of
true -> tcp_opts(Opts) ++ proplists:delete(handshake_timeout, ssl_opts(Opts));
false -> tcp_opts(Opts)
end}.
socket_opts => ip_port(ListenOn) ++ SocketOpts}.
with_port(Port, Opts = #{socket_opts := SocketOption}) when is_integer(Port) ->
Opts#{socket_opts => [{port, Port}| SocketOption]};
with_port({Addr, Port}, Opts = #{socket_opts := SocketOption}) ->
Opts#{socket_opts => [{ip, Addr}, {port, Port}| SocketOption]}.
ip_port(Port) when is_integer(Port) ->
[{port, Port}];
ip_port({Addr, Port}) ->
[{ip, Addr}, {port, Port}].
esockd_access_rules(StrRules) ->
Access = fun(S) ->

View File

@ -43,16 +43,17 @@ deep_get(ConfKeyPath, Map, Default) ->
{ok, Data} -> Data
end.
-spec deep_find(config_key_path(), map()) -> {ok, term()} | {not_found, config_key(), term()}.
-spec deep_find(config_key_path(), map()) ->
{ok, term()} | {not_found, config_key_path(), term()}.
deep_find([], Map) ->
{ok, Map};
deep_find([Key | KeyPath], Map) when is_map(Map) ->
deep_find([Key | KeyPath] = Path, Map) when is_map(Map) ->
case maps:find(Key, Map) of
{ok, SubMap} -> deep_find(KeyPath, SubMap);
error -> {not_found, Key, Map}
error -> {not_found, Path, Map}
end;
deep_find([Key | _KeyPath], Data) ->
{not_found, Key, Data}.
deep_find(_KeyPath, Data) ->
{not_found, _KeyPath, Data}.
-spec deep_put(config_key_path(), map(), term()) -> map().
deep_put([], Map, Config) when is_map(Map) ->

View File

@ -364,11 +364,11 @@ fields("mqtt_ws_listener") ->
fields("ws_opts") ->
[ {"mqtt_path", t(string(), undefined, "/mqtt")}
, {"mqtt_piggyback", t(union(single, multiple), undefined, multiple)}
, {"compress", t(boolean())}
, {"compress", t(boolean(), undefined, false)}
, {"idle_timeout", t(duration(), undefined, "15s")}
, {"max_frame_size", maybe_infinity(integer())}
, {"fail_if_no_subprotocol", t(boolean(), undefined, true)}
, {"supported_subprotocols", t(string(), undefined,
, {"supported_subprotocols", t(comma_separated_list(), undefined,
"mqtt, mqtt-v3, mqtt-v3.1.1, mqtt-v5")}
, {"check_origin_enable", t(boolean(), undefined, false)}
, {"allow_origin_absence", t(boolean(), undefined, true)}
@ -401,12 +401,12 @@ fields("ssl_opts") ->
fields("deflate_opts") ->
[ {"level", t(union([none, default, best_compression, best_speed]))}
, {"mem_level", t(range(1, 9))}
, {"mem_level", t(range(1, 9), undefined, 8)}
, {"strategy", t(union([default, filtered, huffman_only, rle]))}
, {"server_context_takeover", t(union(takeover, no_takeover))}
, {"client_context_takeover", t(union(takeover, no_takeover))}
, {"server_max_window_bits", t(integer())}
, {"client_max_window_bits", t(integer())}
, {"server_max_window_bits", t(range(8, 15), undefined, 15)}
, {"client_max_window_bits", t(range(8, 15), undefined, 15)}
];
fields("module") ->

View File

@ -174,21 +174,13 @@ call(WsPid, Req, Timeout) when is_pid(WsPid) ->
%% WebSocket callbacks
%%--------------------------------------------------------------------
init(Req, Opts) ->
init(Req, #{zone := Zone, listener := Listener} = Opts) ->
%% WS Transport Idle Timeout
IdleTimeout = proplists:get_value(idle_timeout, Opts, 7200000),
DeflateOptions = maps:from_list(proplists:get_value(deflate_options, Opts, [])),
MaxFrameSize = case proplists:get_value(max_frame_size, Opts, 0) of
0 -> infinity;
I -> I
end,
Compress = proplists:get_bool(compress, Opts),
WsOpts = #{compress => Compress,
deflate_opts => DeflateOptions,
max_frame_size => MaxFrameSize,
idle_timeout => IdleTimeout
WsOpts = #{compress => get_ws_opts(Zone, Listener, compress),
deflate_opts => get_ws_opts(Zone, Listener, deflate_opts),
max_frame_size => get_ws_opts(Zone, Listener, max_frame_size),
idle_timeout => get_ws_opts(Zone, Listener, idle_timeout)
},
case check_origin_header(Req, Opts) of
{error, Message} ->
?LOG(error, "Invalid Origin Header ~p~n", [Message]),
@ -196,18 +188,17 @@ init(Req, Opts) ->
ok -> parse_sec_websocket_protocol(Req, Opts, WsOpts)
end.
parse_sec_websocket_protocol(Req, Opts, WsOpts) ->
FailIfNoSubprotocol = proplists:get_value(fail_if_no_subprotocol, Opts),
parse_sec_websocket_protocol(Req, #{zone := Zone, listener := Listener} = Opts, WsOpts) ->
case cowboy_req:parse_header(<<"sec-websocket-protocol">>, Req) of
undefined ->
case FailIfNoSubprotocol of
case get_ws_opts(Zone, Listener, fail_if_no_subprotocol) of
true ->
{ok, cowboy_req:reply(400, Req), WsOpts};
false ->
{cowboy_websocket, Req, [Req, Opts], WsOpts}
end;
Subprotocols ->
SupportedSubprotocols = proplists:get_value(supported_subprotocols, Opts),
SupportedSubprotocols = get_ws_opts(Zone, Listener, supported_subprotocols),
NSupportedSubprotocols = [list_to_binary(Subprotocol)
|| Subprotocol <- SupportedSubprotocols],
case pick_subprotocol(Subprotocols, NSupportedSubprotocols) of
@ -231,31 +222,30 @@ pick_subprotocol([Subprotocol | Rest], SupportedSubprotocols) ->
pick_subprotocol(Rest, SupportedSubprotocols)
end.
parse_header_fun_origin(Req, Opts) ->
parse_header_fun_origin(Req, #{zone := Zone, listener := Listener}) ->
case cowboy_req:header(<<"origin">>, Req) of
undefined ->
case proplists:get_bool(allow_origin_absence, Opts) of
case get_ws_opts(Zone, Listener, allow_origin_absence) of
true -> ok;
false -> {error, origin_header_cannot_be_absent}
end;
Value ->
Origins = proplists:get_value(check_origins, Opts, []),
case lists:member(Value, Origins) of
case lists:member(Value, get_ws_opts(Zone, Listener, check_origins)) of
true -> ok;
false -> {origin_not_allowed, Value}
end
end.
check_origin_header(Req, Opts) ->
case proplists:get_bool(check_origin_enable, Opts) of
check_origin_header(Req, #{zone := Zone, listener := Listener} = Opts) ->
case get_ws_opts(Zone, Listener, check_origin_enable) of
true -> parse_header_fun_origin(Req, Opts);
false -> ok
end.
websocket_init([Req, Opts]) ->
websocket_init([Req, #{zone := Zone, listener := Listener} = Opts]) ->
{Peername, Peercert} =
case proplists:get_bool(proxy_protocol, Opts)
andalso maps:get(proxy_header, Req) of
case emqx_config:get_listener_conf(Zone, Listener, [proxy_protocol]) andalso
maps:get(proxy_header, Req) of
#{src_address := SrcAddr, src_port := SrcPort, ssl := SSL} ->
SourceName = {SrcAddr, SrcPort},
%% Notice: Only CN is available in Proxy Protocol V2 additional info
@ -266,7 +256,7 @@ websocket_init([Req, Opts]) ->
{SourceName, SourceSSL};
#{src_address := SrcAddr, src_port := SrcPort} ->
SourceName = {SrcAddr, SrcPort},
{SourceName , nossl};
{SourceName, nossl};
_ ->
{get_peer(Req, Opts), cowboy_req:cert(Req)}
end,
@ -288,22 +278,31 @@ websocket_init([Req, Opts]) ->
ws_cookie => WsCookie,
conn_mod => ?MODULE
},
Zone = proplists:get_value(zone, Opts),
PubLimit = emqx_zone:publish_limit(Zone),
BytesIn = proplists:get_value(rate_limit, Opts),
RateLimit = emqx_zone:ratelimit(Zone),
Limiter = emqx_limiter:init(Zone, PubLimit, BytesIn, RateLimit),
MQTTPiggyback = proplists:get_value(mqtt_piggyback, Opts, multiple),
FrameOpts = emqx_zone:mqtt_frame_options(Zone),
Limiter = emqx_limiter:init(Zone, undefined, undefined, []),
MQTTPiggyback = get_ws_opts(Zone, Listener, mqtt_piggyback),
FrameOpts = #{
strict_mode => emqx_config:get_listener_conf(Zone, Listener, [mqtt, strict_mode]),
max_size => emqx_config:get_listener_conf(Zone, Listener, [mqtt, max_packet_size])
},
ParseState = emqx_frame:initial_parse_state(FrameOpts),
Serialize = emqx_frame:serialize_opts(),
Channel = emqx_channel:init(ConnInfo, Opts),
GcState = emqx_zone:init_gc_state(Zone),
StatsTimer = emqx_zone:stats_timer(Zone),
GcState = case emqx_config:get_listener_conf(Zone, Listener, [force_gc]) of
#{enable := false} -> undefined;
GcPolicy -> emqx_gc:init(GcPolicy)
end,
StatsTimer = case emqx_config:get_listener_conf(Zone, Listener, [stats, enable]) of
true -> undefined;
false -> disabled
end,
%% MQTT Idle Timeout
IdleTimeout = emqx_zone:idle_timeout(Zone),
IdleTimeout = emqx_channel:get_mqtt_conf(Zone, Listener, idle_timeout),
IdleTimer = start_timer(IdleTimeout, idle_timeout),
emqx_misc:tune_heap_size(emqx_zone:oom_policy(Zone)),
case emqx_config:get_listener_conf(emqx_channel:info(zone, Channel),
emqx_channel:info(listener, Channel), [force_shutdown]) of
#{enable := false} -> ok;
ShutdownPolicy -> emqx_misc:tune_heap_size(ShutdownPolicy)
end,
emqx_logger:set_metadata_peername(esockd:format(Peername)),
{ok, #state{peername = Peername,
sockname = Sockname,
@ -317,7 +316,9 @@ websocket_init([Req, Opts]) ->
postponed = [],
stats_timer = StatsTimer,
idle_timeout = IdleTimeout,
idle_timer = IdleTimer
idle_timer = IdleTimer,
zone = Zone,
listener = Listener
}, hibernate}.
websocket_handle({binary, Data}, State) when is_list(Data) ->
@ -517,11 +518,16 @@ run_gc(Stats, State = #state{gc_state = GcSt}) ->
end.
check_oom(State = #state{channel = Channel}) ->
OomPolicy = emqx_zone:oom_policy(emqx_channel:info(zone, Channel)),
case ?ENABLED(OomPolicy) andalso emqx_misc:check_oom(OomPolicy) of
Shutdown = {shutdown, _Reason} ->
postpone(Shutdown, State);
_Other -> State
ShutdownPolicy = emqx_config:get_listener_conf(emqx_channel:info(zone, Channel),
emqx_channel:info(listener, Channel), [force_shutdown]),
case ShutdownPolicy of
#{enable := false} -> ok;
#{enable := true} ->
case emqx_misc:check_oom(ShutdownPolicy) of
Shutdown = {shutdown, _Reason} ->
postpone(Shutdown, State);
_Other -> State
end
end.
%%--------------------------------------------------------------------
@ -741,9 +747,10 @@ classify([Event|More], Packets, Cmds, Events) ->
trigger(Event) -> erlang:send(self(), Event).
get_peer(Req, Opts) ->
get_peer(Req, #{zone := Zone, listener := Listener}) ->
{PeerAddr, PeerPort} = cowboy_req:peer(Req),
AddrHeader = cowboy_req:header(proplists:get_value(proxy_address_header, Opts), Req, <<>>),
AddrHeader = cowboy_req:header(
get_ws_opts(Zone, Listener, proxy_address_header), Req, <<>>),
ClientAddr = case string:tokens(binary_to_list(AddrHeader), ", ") of
[] ->
undefined;
@ -756,7 +763,8 @@ get_peer(Req, Opts) ->
_ ->
PeerAddr
end,
PortHeader = cowboy_req:header(proplists:get_value(proxy_port_header, Opts), Req, <<>>),
PortHeader = cowboy_req:header(
get_ws_opts(Zone, Listener, proxy_port_header), Req, <<>>),
ClientPort = case string:tokens(binary_to_list(PortHeader), ", ") of
[] ->
undefined;
@ -777,3 +785,5 @@ set_field(Name, Value, State) ->
Pos = emqx_misc:index_of(Name, record_info(fields, state)),
setelement(Pos+1, State, Value).
get_ws_opts(Zone, Listener, Key) ->
emqx_config:get_listener_conf(Zone, Listener, [websocket, Key]).