feat(emqx_ws_connection): Prevent EMQX from CSWSH Cross-Site Web-Socket Hijack

This commit is contained in:
ayodele.akingbule 2021-01-11 22:08:32 +01:00 committed by Zaiming Shi
parent 9e03d6fea1
commit 5794a708ed
3 changed files with 114 additions and 6 deletions

View File

@ -681,7 +681,7 @@ mqtt.ignore_loop_deliver = false
mqtt.strict_mode = false
## Specify the response information returned to the client
##
##
## Value: String
## mqtt.response_information = example
@ -917,7 +917,7 @@ zone.external.ignore_loop_deliver = false
zone.external.strict_mode = false
## Specify the response information returned to the client
##
##
## Value: String
## zone.external.response_information = example
@ -1012,7 +1012,7 @@ zone.internal.ignore_loop_deliver = false
zone.internal.strict_mode = false
## Specify the response information returned to the client
##
##
## Value: String
## zone.internal.response_information = example
@ -1704,6 +1704,21 @@ listener.ws.external.nodelay = true
## Value: single | multiple
listener.ws.external.mqtt_piggyback = multiple
## Enable origin check in header for websocket connection
##
## Value: true | false (default false)
listener.ws.external.check_origin_enable = false
## Allow origin to be absent in header in websocket connection when check_origin_enable is true
##
## Value: true | false (default true)
listener.ws.external.allow_origin_absence = true
## Comma separated list of allowed origin in header for websocket connection
##
## Value: http://url eg. local http dashboard url - http://localhost:18083, http://127.0.0.1:18083
listener.ws.external.check_origins = http://localhost:18083, http://127.0.0.1:18083
##--------------------------------------------------------------------
## External WebSocket/SSL listener for MQTT Protocol
@ -1984,6 +1999,18 @@ listener.wss.external.send_timeout_close = on
##
## Value: single | multiple
listener.wss.external.mqtt_piggyback = multiple
## Enable origin check in header for secure websocket connection
##
## Value: true | false (default false)
listener.wss.external.check_origin_enable = false
## Allow origin to be absent in header in secure websocket connection when check_origin_enable is true
##
## Value: true | false (default true)
listener.wss.external.allow_origin_absence = true
## Comma separated list of allowed origin in header for secure websocket connection
##
## Value: http://url eg. https://localhost:8084, https://127.0.0.1:8084
listener.wss.external.check_origins = https://localhost:8084, https://127.0.0.1:8084
##--------------------------------------------------------------------
## Modules
@ -2245,7 +2272,7 @@ alarm.actions = log,publish
## The maximum number of deactivated alarms
##
## Value: Integer
## Value: Integer
##
## Default: 1000
alarm.size_limit = 1000

View File

@ -1582,6 +1582,23 @@ end}.
hidden
]}.
{mapping, "listener.ws.$name.check_origin_enable", "emqx.listeners", [
{datatype, {enum, [true, false]}},
{default, false},
hidden
]}.
{mapping, "listener.ws.$name.allow_origin_absence", "emqx.listeners", [
{datatype, {enum, [true, false]}},
{default, true},
hidden
]}.
{mapping, "listener.ws.$name.check_origins", "emqx.listeners", [
{datatype, string},
hidden
]}.
%%--------------------------------------------------------------------
%% MQTT/WebSocket/SSL Listeners
@ -1800,6 +1817,23 @@ end}.
hidden
]}.
{mapping, "listener.wss.$name.check_origin_enable", "emqx.listeners", [
{datatype, {enum, [true, false]}},
{default, false},
hidden
]}.
{mapping, "listener.wss.$name.allow_origin_absence", "emqx.listeners", [
{datatype, {enum, [true, false]}},
{default, true},
hidden
]}.
{mapping, "listener.wss.$name.check_origins", "emqx.listeners", [
{datatype, string},
hidden
]}.
{translation, "emqx.listeners", fun(Conf) ->
Filter = fun(Opts) -> [{K, V} || {K, V} <- Opts, V =/= undefined] end,
@ -1833,6 +1867,20 @@ end}.
{Limit, Duration}
end,
CheckOrigin = fun(S) ->
Origins = string:tokens(S, ","),
[ list_to_binary(string:trim(O)) || O <- Origins]
end,
WsOpts = fun(Prefix) ->
case cuttlefish_variable:filter_by_prefix(Prefix ++ ".check_origins", Conf) of
[] -> undefined;
Rules ->
OriginList = [CheckOrigin(Rule) || {_, Rule} <- Rules],
lists:flatten(OriginList)
end
end,
LisOpts = fun(Prefix) ->
Filter([{acceptors, cuttlefish:conf_get(Prefix ++ ".acceptors", Conf)},
{mqtt_path, cuttlefish:conf_get(Prefix ++ ".mqtt_path", Conf, undefined)},
@ -1849,7 +1897,10 @@ end}.
{compress, cuttlefish:conf_get(Prefix ++ ".compress", Conf, undefined)},
{idle_timeout, cuttlefish:conf_get(Prefix ++ ".idle_timeout", Conf, undefined)},
{max_frame_size, cuttlefish:conf_get(Prefix ++ ".max_frame_size", Conf, undefined)},
{mqtt_piggyback, cuttlefish:conf_get(Prefix ++ ".mqtt_piggyback", Conf, undefined)} | AccOpts(Prefix)])
{mqtt_piggyback, cuttlefish:conf_get(Prefix ++ ".mqtt_piggyback", Conf, undefined)},
{check_origin_enable, cuttlefish:conf_get(Prefix ++ ".check_origin_enable", Conf, undefined)},
{allow_origin_absence, cuttlefish:conf_get(Prefix ++ ".allow_origin_absence", Conf, undefined)},
{check_origins, WsOpts(Prefix)} | AccOpts(Prefix)])
end,
DeflateOpts = fun(Prefix) ->
Filter([{level, cuttlefish:conf_get(Prefix ++ ".deflate_opts.level", Conf, undefined)},

View File

@ -183,18 +183,48 @@ init(Req, Opts) ->
max_frame_size => MaxFrameSize,
idle_timeout => IdleTimeout
},
case check_origin_header(Req, Opts) of
{error, Message} ->
?LOG(error, "Invalid Origin Header ~p~n", [Message]),
{ok, cowboy_req:reply(403, Req), WsOpts};
ok -> parse_sec_websocket_protocol(Req, Opts, WsOpts)
end.
parse_sec_websocket_protocol(Req, Opts, WsOpts) ->
case cowboy_req:parse_header(<<"sec-websocket-protocol">>, Req) of
undefined ->
%% TODO: why not reply 500???
{cowboy_websocket, Req, [Req, Opts], WsOpts};
[<<"mqtt", Vsn/binary>>] ->
Resp = cowboy_req:set_resp_header(
<<"sec-websocket-protocol">>, <<"mqtt", Vsn/binary>>, Req),
<<"sec-websocket-protocol">>, <<"mqtt", Vsn/binary>>, Req),
{cowboy_websocket, Resp, [Req, Opts], WsOpts};
_ ->
{ok, cowboy_req:reply(400, Req), WsOpts}
end.
parse_header_fun_origin(Req, Opts) ->
case cowboy_req:header(<<"origin">>, Req) of
undefined ->
case proplists:get_value(allow_origin_absence, Opts, true) of
true -> ok;
false -> {error, origin_header_cannot_be_absent}
end;
Value ->
Origins = proplists:get_value(check_origins, Opts, []),
case lists:member(Value, Origins) of
true -> ok;
false -> {origin_not_allowed, Value}
end
end.
check_origin_header(Req, Opts) ->
case proplists:get_value(check_origin_enable, Opts) of
true -> parse_header_fun_origin(Req, Opts);
false -> ok
end.
websocket_init([Req, Opts]) ->
Peername = case proplists:get_bool(proxy_protocol, Opts)
andalso maps:get(proxy_header, Req) of