diff --git a/apps/emqx_utils/src/emqx_variform.erl b/apps/emqx_utils/src/emqx_variform.erl index 95ea1e1ce..25825ea9f 100644 --- a/apps/emqx_utils/src/emqx_variform.erl +++ b/apps/emqx_utils/src/emqx_variform.erl @@ -1,5 +1,5 @@ %%-------------------------------------------------------------------- -%% Copyright (c) 2020-2024 EMQ Technologies Co., Ltd. All Rights Reserved. +%% Copyright (c) 2024 EMQ Technologies Co., Ltd. All Rights Reserved. %% %% Licensed under the Apache License, Version 2.0 (the "License"); %% you may not use this file except in compliance with the License. @@ -22,7 +22,12 @@ %% or used to choose the first non-empty value from a list of variables. -module(emqx_variform). --export([inject_allowed_modules/1]). +-export([ + inject_allowed_module/1, + inject_allowed_modules/1, + erase_allowed_module/1, + erase_allowed_modules/1 +]). -export([render/2, render/3]). %% @doc Render a variform expression with bindings. @@ -48,6 +53,8 @@ render(Expression, Bindings) -> render(Expression, Bindings, #{}). +render(Expression, Bindings, Opts) when is_binary(Expression) -> + render(unicode:characters_to_list(Expression), Bindings, Opts); render(Expression, Bindings, Opts) -> case emqx_variform_scan:string(Expression) of {ok, Tokens, _Line} -> @@ -66,7 +73,7 @@ render(Expression, Bindings, Opts) -> eval_as_string(Expr, Bindings, _Opts) -> try - {ok, iolist_to_binary(eval(Expr, Bindings))} + {ok, str(eval(Expr, Bindings))} catch throw:Reason -> {error, Reason}; @@ -97,7 +104,7 @@ call(emqx_variform_str, concat, Args) -> call(emqx_variform_str, coalesce, Args) -> str(emqx_variform_str:coalesce(Args)); call(Mod, Fun, Args) -> - str(erlang:apply(Mod, Fun, Args)). + erlang:apply(Mod, Fun, Args). resolve_func_name(FuncNameStr) -> case string:tokens(FuncNameStr, ".") of @@ -107,7 +114,10 @@ resolve_func_name(FuncNameStr) -> list_to_existing_atom(Mod0) catch error:badarg -> - throw(#{unknown_module => Mod0}) + throw(#{ + reason => unknown_variform_module, + module => Mod0 + }) end, ok = assert_module_allowed(Mod), Fun = @@ -115,7 +125,10 @@ resolve_func_name(FuncNameStr) -> list_to_existing_atom(Fun0) catch error:badarg -> - throw(#{unknown_function => Fun0}) + throw(#{ + reason => unknown_variform_function, + function => Fun0 + }) end, {Mod, Fun}; [Fun] -> @@ -125,11 +138,13 @@ resolve_func_name(FuncNameStr) -> catch error:badarg -> throw(#{ - reason => "unknown_variform_function", + reason => unknown_variform_function, function => Fun }) end, - {emqx_variform_str, FuncName} + {emqx_variform_str, FuncName}; + _ -> + throw(#{reason => invalid_function_reference, function => FuncNameStr}) end. resolve_var_value(VarName, Bindings) -> @@ -145,13 +160,14 @@ assert_func_exported(emqx_variform_str, concat, _Arity) -> assert_func_exported(emqx_variform_str, coalesce, _Arity) -> ok; assert_func_exported(Mod, Fun, Arity) -> + %% ensure beam loaded _ = Mod:module_info(md5), case erlang:function_exported(Mod, Fun, Arity) of true -> ok; false -> throw(#{ - reason => "unknown_variform_function", + reason => unknown_variform_function, module => Mod, function => Fun, arity => Arity @@ -167,16 +183,27 @@ assert_module_allowed(Mod) -> ok; false -> throw(#{ - reason => "unallowed_veriform_module", + reason => unallowed_veriform_module, module => Mod }) end. -inject_allowed_modules(Modules) -> +inject_allowed_module(Module) when is_atom(Module) -> + inject_allowed_modules([Module]). + +inject_allowed_modules(Modules) when is_list(Modules) -> Allowed0 = get_allowed_modules(), Allowed = lists:usort(Allowed0 ++ Modules), persistent_term:put({emqx_variform, allowed_modules}, Allowed). +erase_allowed_module(Module) when is_atom(Module) -> + erase_allowed_modules([Module]). + +erase_allowed_modules(Modules) when is_list(Modules) -> + Allowed0 = get_allowed_modules(), + Allowed = Allowed0 -- Modules, + persistent_term:put({emqx_variform, allowed_modules}, Allowed). + get_allowed_modules() -> persistent_term:get({emqx_variform, allowed_modules}, []). diff --git a/apps/emqx_utils/src/emqx_variform_str.erl b/apps/emqx_utils/src/emqx_variform_str.erl index 7b8e2e742..a53e1e216 100644 --- a/apps/emqx_utils/src/emqx_variform_str.erl +++ b/apps/emqx_utils/src/emqx_variform_str.erl @@ -52,7 +52,8 @@ find/3, join_to_string/1, join_to_string/2, - unescape/1 + unescape/1, + nth/2 ]). -define(IS_EMPTY(X), (X =:= <<>> orelse X =:= "" orelse X =:= undefined)). @@ -224,6 +225,20 @@ unescape(Bin) when is_binary(Bin) -> throw({invalid_unicode_character, Error}) end. +nth(N, List) when (is_list(N) orelse is_binary(N)) andalso is_list(List) -> + try binary_to_integer(iolist_to_binary(N)) of + N1 -> + nth(N1, List) + catch + _:_ -> + throw(#{reason => invalid_argument, func => nth, index => N}) + end; +nth(N, List) when is_integer(N) andalso is_list(List) -> + case length(List) of + L when L < N -> <<>>; + _ -> lists:nth(N, List) + end. + unescape_string(Input) -> unescape_string(Input, []). unescape_string([], Acc) -> diff --git a/apps/emqx_utils/test/emqx_variform_tests.erl b/apps/emqx_utils/test/emqx_variform_tests.erl new file mode 100644 index 000000000..da26a383d --- /dev/null +++ b/apps/emqx_utils/test/emqx_variform_tests.erl @@ -0,0 +1,129 @@ +%%-------------------------------------------------------------------- +%% Copyright (c) 2024 EMQ Technologies Co., Ltd. All Rights Reserved. +%% +%% Licensed under the Apache License, Version 2.0 (the "License"); +%% you may not use this file except in compliance with the License. +%% You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, software +%% distributed under the License is distributed on an "AS IS" BASIS, +%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +%% See the License for the specific language governing permissions and +%% limitations under the License. +%%-------------------------------------------------------------------- + +-module(emqx_variform_tests). + +-compile(export_all). +-compile(nowarn_export_all). + +-include_lib("eunit/include/eunit.hrl"). + +-define(SYNTAX_ERROR, {error, "syntax error before:" ++ _}). + +redner_test_() -> + [ + {"direct var reference", fun() -> ?assertEqual({ok, <<"1">>}, render("a", #{a => 1})) end}, + {"concat strings", fun() -> + ?assertEqual({ok, <<"a,b">>}, render("concat('a',',','b')", #{})) + end}, + {"concat empty string", fun() -> ?assertEqual({ok, <<"">>}, render("concat('')", #{})) end}, + {"tokens 1st", fun() -> + ?assertEqual({ok, <<"a">>}, render("nth(1,tokens(var, ','))", #{var => <<"a,b">>})) + end}, + {"unknown var as empty str", fun() -> + ?assertEqual({ok, <<>>}, render("var", #{})) + end}, + {"out of range nth index", fun() -> + ?assertEqual({ok, <<>>}, render("nth(2, tokens(var, ','))", #{var => <<"a">>})) + end}, + {"not a index number for nth", fun() -> + ?assertMatch( + {error, #{reason := invalid_argument, func := nth, index := <<"notnum">>}}, + render("nth('notnum', tokens(var, ','))", #{var => <<"a">>}) + ) + end} + ]. + +unknown_func_test_() -> + [ + {"unknown function", fun() -> + ?assertMatch( + {error, #{reason := unknown_variform_function}}, + render("nonexistingatom__(a)", #{}) + ) + end}, + {"unknown module", fun() -> + ?assertMatch( + {error, #{reason := unknown_variform_module}}, + render("nonexistingatom__.nonexistingatom__(a)", #{}) + ) + end}, + {"unknown function in a known module", fun() -> + ?assertMatch( + {error, #{reason := unknown_variform_function}}, + render("emqx_variform_str.nonexistingatom__(a)", #{}) + ) + end}, + {"invalid func reference", fun() -> + ?assertMatch( + {error, #{reason := invalid_function_reference, function := "a.b.c"}}, + render("a.b.c(var)", #{}) + ) + end} + ]. + +concat(L) -> iolist_to_binary(L). + +inject_allowed_module_test() -> + try + emqx_variform:inject_allowed_module(?MODULE), + ?assertEqual({ok, <<"ab">>}, render(atom_to_list(?MODULE) ++ ".concat(['a','b'])", #{})), + ?assertMatch( + {error, #{ + reason := unknown_variform_function, + module := ?MODULE, + function := concat, + arity := 2 + }}, + render(atom_to_list(?MODULE) ++ ".concat('a','b')", #{}) + ), + ?assertMatch( + {error, #{reason := unallowed_veriform_module, module := emqx}}, + render("emqx.concat('a','b')", #{}) + ) + after + emqx_variform:erase_allowed_module(?MODULE) + end. + +coalesce_test_() -> + [ + {"coalesce first", fun() -> + ?assertEqual({ok, <<"a">>}, render("coalesce('a','b')", #{})) + end}, + {"coalesce second", fun() -> + ?assertEqual({ok, <<"b">>}, render("coalesce('', 'b')", #{})) + end}, + {"coalesce first var", fun() -> + ?assertEqual({ok, <<"a">>}, render("coalesce(a,b)", #{a => <<"a">>, b => <<"b">>})) + end}, + {"coalesce second var", fun() -> + ?assertEqual({ok, <<"b">>}, render("coalesce(a,b)", #{b => <<"b">>})) + end}, + {"coalesce empty", fun() -> ?assertEqual({ok, <<>>}, render("coalesce(a,b)", #{})) end} + ]. + +syntax_error_test_() -> + [ + {"empty expression", fun() -> ?assertMatch(?SYNTAX_ERROR, render("", #{})) end}, + {"const string single quote", fun() -> ?assertMatch(?SYNTAX_ERROR, render("'a'", #{})) end}, + {"const string double quote", fun() -> + ?assertMatch(?SYNTAX_ERROR, render(<<"\"a\"">>, #{})) + end}, + {"no arity", fun() -> ?assertMatch(?SYNTAX_ERROR, render("concat()", #{})) end} + ]. + +render(Expression, Bindings) -> + emqx_variform:render(Expression, Bindings).