refactor(ocsp): use listener id instead of name for sni fun matching

This commit is contained in:
Thales Macedo Garitezi 2022-11-08 16:12:40 -03:00
parent 2dcecafce6
commit ce282797be
2 changed files with 16 additions and 18 deletions

View File

@ -42,6 +42,7 @@
]). ]).
-type(listener_name() :: binary()). -type(listener_name() :: binary()).
-type(listener_id() :: binary()).
-type(listener() :: #{ name := listener_name() -type(listener() :: #{ name := listener_name()
, proto := esockd:proto() , proto := esockd:proto()
, listen_on := esockd:listen_on() , listen_on := esockd:listen_on()
@ -71,7 +72,7 @@ find_by_id(Id) ->
find_by_id(iolist_to_binary(Id), emqx:get_env(listeners, [])). find_by_id(iolist_to_binary(Id), emqx:get_env(listeners, [])).
%% @doc Return the ID of the given listener. %% @doc Return the ID of the given listener.
-spec identifier(listener()) -> binary(). -spec identifier(listener()) -> listener_id().
identifier(#{proto := Proto, name := Name}) -> identifier(#{proto := Proto, name := Name}) ->
identifier(Proto, Name). identifier(Proto, Name).
@ -108,7 +109,7 @@ format_listen_on(ListenOn) -> format(ListenOn).
-spec(start_listener(listener()) -> ok). -spec(start_listener(listener()) -> ok).
start_listener(#{proto := Proto, name := Name, listen_on := ListenOn, opts := Opts0}) -> start_listener(#{proto := Proto, name := Name, listen_on := ListenOn, opts := Opts0}) ->
ID = identifier(Proto, Name), ID = identifier(Proto, Name),
Opts = [{listener_name, Name} | Opts0], Opts = [{listener_id, ID} | Opts0],
case start_listener(Proto, ListenOn, Opts) of case start_listener(Proto, ListenOn, Opts) of
{ok, _} -> {ok, _} ->
console_print("Start ~s listener on ~s successfully.~n", [ID, format(ListenOn)]); console_print("Start ~s listener on ~s successfully.~n", [ID, format(ListenOn)]);
@ -127,21 +128,16 @@ console_print(_Fmt, _Args) -> ok.
%% Start MQTT/TCP listener %% Start MQTT/TCP listener
-spec(start_listener(esockd:proto(), esockd:listen_on(), [ esockd:option() -spec(start_listener(esockd:proto(), esockd:listen_on(), [ esockd:option()
| {listener_name, listener_name()}]) | {listener_id, binary()}])
-> {ok, pid()} | {error, term()}). -> {ok, pid()} | {error, term()}).
start_listener(tcp, ListenOn, Options) -> start_listener(tcp, ListenOn, Options) ->
start_mqtt_listener('mqtt:tcp', ListenOn, Options); start_mqtt_listener('mqtt:tcp', ListenOn, Options);
%% Start MQTT/TLS listener %% Start MQTT/TLS listener
start_listener(Proto, ListenOn, Options0) when Proto == ssl; Proto == tls -> start_listener(Proto, ListenOn, Options0) when Proto == ssl; Proto == tls ->
Name = proplists:get_value(listener_name, Options0, <<"mqtt:ssl:external">>), ListenerID = proplists:get_value(listener_id, Options0, <<"mqtt:ssl:external">>),
Options1 = proplists:delete(listener_name, Options0), Options1 = proplists:delete(listener_id, Options0),
Listener = #{ name => Name Options = emqx_ocsp_cache:inject_sni_fun(ListenerID, Options1),
, proto => Proto
, listen_on => ListenOn
, opts => Options1
},
Options = emqx_ocsp_cache:inject_sni_fun(Listener),
start_mqtt_listener('mqtt:ssl', ListenOn, Options); start_mqtt_listener('mqtt:ssl', ListenOn, Options);
%% Start MQTT/WS listener %% Start MQTT/WS listener

View File

@ -28,7 +28,7 @@
, sni_fun/2 , sni_fun/2
, fetch_response/1 , fetch_response/1
, register_listener/1 , register_listener/1
, inject_sni_fun/1 , inject_sni_fun/2
]). ]).
%% gen_server API %% gen_server API
@ -92,12 +92,11 @@ fetch_response(ListenerID) ->
register_listener(ListenerID) -> register_listener(ListenerID) ->
gen_server:call(?MODULE, {register_listener, ListenerID}, ?CALL_TIMEOUT). gen_server:call(?MODULE, {register_listener, ListenerID}, ?CALL_TIMEOUT).
-spec inject_sni_fun(emqx_listeners:listener()) -> [esockd:option()]. -spec inject_sni_fun(emqx_listeners:listener_id(), [esockd:option()]) -> [esockd:option()].
inject_sni_fun(Listener = #{proto := Proto, name := Name, opts := Options0}) -> inject_sni_fun(ListenerID, Options0) ->
%% We need to patch `sni_fun' here and not in `emqx.schema' %% We need to patch `sni_fun' here and not in `emqx.schema'
%% because otherwise an anonymous function will end up in %% because otherwise an anonymous function will end up in
%% `app.*.config'... %% `app.*.config'...
ListenerID = emqx_listeners:identifier(Listener),
OCSPOpts = proplists:get_value(ocsp_options, Options0, []), OCSPOpts = proplists:get_value(ocsp_options, Options0, []),
case proplists:get_bool(ocsp_stapling_enabled, OCSPOpts) of case proplists:get_bool(ocsp_stapling_enabled, OCSPOpts) of
false -> false ->
@ -110,8 +109,8 @@ inject_sni_fun(Listener = #{proto := Proto, name := Name, opts := Options0}) ->
%% save to env %% save to env
{[ThisListener0], Listeners} = {[ThisListener0], Listeners} =
lists:partition( lists:partition(
fun(#{name := N, proto := P}) -> fun(L) ->
N =:= Name andalso P =:= Proto emqx_listeners:identifier(L) =:= ListenerID
end, end,
emqx:get_env(listeners)), emqx:get_env(listeners)),
ThisListener = ThisListener0#{opts => Options}, ThisListener = ThisListener0#{opts => Options},
@ -188,7 +187,10 @@ code_change(_Vsn, State, _Extra) ->
false =/= proplists:get_bool(ocsp_stapling_enabled, OCSPOpts) false =/= proplists:get_bool(ocsp_stapling_enabled, OCSPOpts)
end, end,
emqx:get_env(listeners, [])), emqx:get_env(listeners, [])),
PatchedListeners = [L#{opts => ?MODULE:inject_sni_fun(L)} || L <- ListenersToPatch], PatchedListeners = [L#{opts => ?MODULE:inject_sni_fun(
emqx_listeners:identifier(L),
Opts)}
|| L = #{opts := Opts} <- ListenersToPatch],
lists:foreach( lists:foreach(
fun(L) -> fun(L) ->
emqx_listeners:update_listeners_env(update, L) emqx_listeners:update_listeners_env(update, L)