diff --git a/apps/emqx_rule_engine/include/rule_engine.hrl b/apps/emqx_rule_engine/include/rule_engine.hrl index c69a24244..51ad6ab85 100644 --- a/apps/emqx_rule_engine/include/rule_engine.hrl +++ b/apps/emqx_rule_engine/include/rule_engine.hrl @@ -92,6 +92,8 @@ ?RAISE(EXP, _ = do_nothing, ERROR) ). +-define(RAISE_BAD_SQL(Detail), throw(Detail)). + -define(RAISE(EXP, EXP_ON_FAIL, ERROR), fun() -> try @@ -106,3 +108,16 @@ %% Tables -define(RULE_TAB, emqx_rule_engine). + +%% Allowed sql function provider modules +-define(DEFAULT_SQL_FUNC_PROVIDER, emqx_rule_funcs). +-define(IS_VALID_SQL_FUNC_PROVIDER_MODULE_NAME(Name), + (case Name of + <<"emqx_rule_funcs", _/binary>> -> + true; + <<"EmqxRuleFuncs", _/binary>> -> + true; + _ -> + false + end) +). diff --git a/apps/emqx_rule_engine/src/emqx_rule_runtime.erl b/apps/emqx_rule_engine/src/emqx_rule_runtime.erl index d7412d03c..9bc10f4d2 100644 --- a/apps/emqx_rule_engine/src/emqx_rule_runtime.erl +++ b/apps/emqx_rule_engine/src/emqx_rule_runtime.erl @@ -452,19 +452,23 @@ eval_switch_clauses(CaseOn, [{Cond, Clause} | CaseClauses], ElseClauses, Columns eval_switch_clauses(CaseOn, CaseClauses, ElseClauses, Columns) end. -apply_func(Name, Args, Columns) when is_atom(Name) -> - do_apply_func(Name, Args, Columns); apply_func(Name, Args, Columns) when is_binary(Name) -> - FunName = - try - binary_to_existing_atom(Name, utf8) - catch - error:badarg -> error({sql_function_not_supported, Name}) - end, - do_apply_func(FunName, Args, Columns). + FuncName = parse_function_name(?DEFAULT_SQL_FUNC_PROVIDER, Name), + apply_func(FuncName, Args, Columns); +apply_func([{key, ModuleName0}, {key, FuncName0}], Args, Columns) -> + ModuleName = parse_module_name(ModuleName0), + FuncName = parse_function_name(ModuleName, FuncName0), + do_apply_func(ModuleName, FuncName, Args, Columns); +apply_func(Name, Args, Columns) when is_atom(Name) -> + do_apply_func(?DEFAULT_SQL_FUNC_PROVIDER, Name, Args, Columns); +apply_func(Other, _, _) -> + ?RAISE_BAD_SQL(#{ + reason => bad_sql_function_reference, + reference => Other + }). -do_apply_func(Name, Args, Columns) -> - case erlang:apply(emqx_rule_funcs, Name, Args) of +do_apply_func(Module, Name, Args, Columns) -> + case erlang:apply(Module, Name, Args) of Func when is_function(Func) -> erlang:apply(Func, [Columns]); Result -> @@ -531,3 +535,39 @@ is_ok_result(R) when is_tuple(R) -> ok == erlang:element(1, R); is_ok_result(_) -> false. + +parse_module_name(Name) when is_binary(Name) -> + case ?IS_VALID_SQL_FUNC_PROVIDER_MODULE_NAME(Name) of + true -> + ok; + false -> + ?RAISE_BAD_SQL(#{ + reason => sql_function_provider_module_not_allowed, + module => Name + }) + end, + try + parse_module_name(binary_to_existing_atom(Name, utf8)) + catch + error:badarg -> + ?RAISE_BAD_SQL(#{ + reason => sql_function_provider_module_not_loaded, + module => Name + }) + end; +parse_module_name(Name) when is_atom(Name) -> + Name. + +parse_function_name(Module, Name) when is_binary(Name) -> + try + parse_function_name(Module, binary_to_existing_atom(Name, utf8)) + catch + error:badarg -> + ?RAISE_BAD_SQL(#{ + reason => sql_function_not_supported, + module => Module, + function => Name + }) + end; +parse_function_name(_Module, Name) when is_atom(Name) -> + Name. diff --git a/apps/emqx_rule_engine/src/emqx_rule_sqlparser.erl b/apps/emqx_rule_engine/src/emqx_rule_sqlparser.erl index 9b6ed7eae..b6661684f 100644 --- a/apps/emqx_rule_engine/src/emqx_rule_sqlparser.erl +++ b/apps/emqx_rule_engine/src/emqx_rule_sqlparser.erl @@ -44,11 +44,23 @@ -type alias() :: binary() | list(binary()). --type field() :: - const() - | variable() - | {as, field(), alias()} - | {'fun', atom(), list(field())}. +%% TODO: So far the SQL function module names and function names are as binary(), +%% binary_to_atom is called to convert to module and function name. +%% For better performance, the function references +%% can be converted to a fun Module:Function/N When compiling the SQL. +-type ext_module_name() :: atom() | binary(). +-type func_name() :: atom() | binary(). +-type func_args() :: [field()]. +%% Functions defiend in emqx_rule_funcs +-type builtin_func_ref() :: {var, func_name()}. +%% Functions defined in other modules, reference syntax: Module.Function(Arg1, Arg2, ...) +%% NOTE: it's '.' (Elixir style), but not ':' (Erlang style). +%% Parsed as a two element path-list: [{key, Module}, {key, Func}]. +-type external_func_ref() :: {path, [{key, ext_module_name() | func_name()}]}. +-type func_ref() :: builtin_func_ref() | external_func_ref(). +-type sql_func() :: {'fun', func_ref(), func_args()}. + +-type field() :: const() | variable() | {as, field(), alias()} | sql_func(). -export_type([select/0]). diff --git a/apps/emqx_rule_engine/test/emqx_rule_engine_SUITE.erl b/apps/emqx_rule_engine/test/emqx_rule_engine_SUITE.erl index eb253e516..f2218fa87 100644 --- a/apps/emqx_rule_engine/test/emqx_rule_engine_SUITE.erl +++ b/apps/emqx_rule_engine/test/emqx_rule_engine_SUITE.erl @@ -62,6 +62,9 @@ groups() -> t_match_atom_and_binary, t_sqlselect_0, t_sqlselect_00, + t_sqlselect_with_3rd_party_impl, + t_sqlselect_with_3rd_party_impl2, + t_sqlselect_with_3rd_party_funcs_unknown, t_sqlselect_001, t_sqlselect_inject_props, t_sqlselect_01, @@ -120,6 +123,8 @@ groups() -> %%------------------------------------------------------------------------------ init_per_suite(Config) -> + %% ensure module loaded + emqx_rule_funcs_demo:module_info(), application:load(emqx_conf), ok = emqx_common_test_helpers:start_apps( [emqx_conf, emqx_rule_engine, emqx_authz], @@ -1012,6 +1017,60 @@ t_sqlselect_00(_Config) -> ) ). +t_sqlselect_with_3rd_party_impl(_Config) -> + Sql = + "select * from \"t/#\" where emqx_rule_funcs_demo.is_my_topic(topic)", + T = fun(Topic) -> + emqx_rule_sqltester:test( + #{ + sql => Sql, + context => + #{ + payload => #{<<"what">> => 0}, + topic => Topic + } + } + ) + end, + ?assertMatch({ok, _}, T(<<"t/2/3/4/5">>)), + ?assertMatch({error, nomatch}, T(<<"t/1">>)). + +t_sqlselect_with_3rd_party_impl2(_Config) -> + Sql = fun(N) -> + "select emqx_rule_funcs_demo.duplicate_payload(payload," ++ integer_to_list(N) ++ + ") as payload_list from \"t/#\"" + end, + T = fun(Payload, N) -> + emqx_rule_sqltester:test( + #{ + sql => Sql(N), + context => + #{ + payload => Payload, + topic => <<"t/a">> + } + } + ) + end, + ?assertMatch({ok, #{<<"payload_list">> := [_, _]}}, T(<<"payload1">>, 2)), + ?assertMatch({ok, #{<<"payload_list">> := [_, _, _]}}, T(<<"payload1">>, 3)), + %% crash + ?assertMatch({error, {select_and_transform_error, _}}, T(<<"payload1">>, 4)). + +t_sqlselect_with_3rd_party_funcs_unknown(_Config) -> + Sql = "select emqx_rule_funcs_demo_no_such_module.foo(payload) from \"t/#\"", + ?assertMatch( + {error, + {select_and_transform_error, + {throw, #{reason := sql_function_provider_module_not_loaded}, _}}}, + emqx_rule_sqltester:test( + #{ + sql => Sql, + context => #{payload => <<"a">>, topic => <<"t/a">>} + } + ) + ). + t_sqlselect_001(_Config) -> %% Verify that the jq function can be called from SQL Sql = diff --git a/apps/emqx_rule_engine/test/emqx_rule_funcs_demo.erl b/apps/emqx_rule_engine/test/emqx_rule_funcs_demo.erl new file mode 100644 index 000000000..b0d42b10e --- /dev/null +++ b/apps/emqx_rule_engine/test/emqx_rule_funcs_demo.erl @@ -0,0 +1,32 @@ +%%-------------------------------------------------------------------- +%% Copyright (c) 2023 EMQ Technologies Co., Ltd. All Rights Reserved. +%% +%% Licensed under the Apache License, Version 2.0 (the "License"); +%% you may not use this file except in compliance with the License. +%% You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, software +%% distributed under the License is distributed on an "AS IS" BASIS, +%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +%% See the License for the specific language governing permissions and +%% limitations under the License. +%%-------------------------------------------------------------------- + +-module(emqx_rule_funcs_demo). + +-export([ + is_my_topic/1, + duplicate_payload/2 +]). + +%% check if the topic is of 5 levels. +is_my_topic(Topic) -> + emqx_topic:levels(Topic) =:= 5. + +%% duplicate the payload, but only supports 2 or 3 copies. +duplicate_payload(Payload, 2) -> + [Payload, Payload]; +duplicate_payload(Payload, 3) -> + [Payload, Payload, Payload].