diff --git a/apps/emqx/include/emqx_placeholder.hrl b/apps/emqx/include/emqx_placeholder.hrl index 7b2ce6c6b..1db80c72d 100644 --- a/apps/emqx/include/emqx_placeholder.hrl +++ b/apps/emqx/include/emqx_placeholder.hrl @@ -19,67 +19,79 @@ -define(PH_VAR_THIS, <<"$_THIS_">>). --define(PH(Type), <<"${", Type/binary, "}">>). +-define(PH(Var), <<"${" Var "}">>). %% action: publish/subscribe --define(PH_ACTION, <<"${action}">>). +-define(VAR_ACTION, "action"). +-define(PH_ACTION, ?PH(?VAR_ACTION)). %% cert --define(PH_CERT_SUBJECT, <<"${cert_subject}">>). --define(PH_CERT_CN_NAME, <<"${cert_common_name}">>). +-define(VAR_CERT_SUBJECT, "cert_subject"). +-define(VAR_CERT_CN_NAME, "cert_common_name"). +-define(PH_CERT_SUBJECT, ?PH(?VAR_CERT_SUBJECT)). +-define(PH_CERT_CN_NAME, ?PH(?VAR_CERT_CN_NAME)). %% MQTT --define(PH_PASSWORD, <<"${password}">>). --define(PH_CLIENTID, <<"${clientid}">>). --define(PH_FROM_CLIENTID, <<"${from_clientid}">>). --define(PH_USERNAME, <<"${username}">>). --define(PH_FROM_USERNAME, <<"${from_username}">>). --define(PH_TOPIC, <<"${topic}">>). +-define(VAR_PASSWORD, "password"). +-define(VAR_CLIENTID, "clientid"). +-define(VAR_USERNAME, "username"). +-define(VAR_TOPIC, "topic"). +-define(PH_PASSWORD, ?PH(?VAR_PASSWORD)). +-define(PH_CLIENTID, ?PH(?VAR_CLIENTID)). +-define(PH_FROM_CLIENTID, ?PH("from_clientid")). +-define(PH_USERNAME, ?PH(?VAR_USERNAME)). +-define(PH_FROM_USERNAME, ?PH("from_username")). +-define(PH_TOPIC, ?PH(?VAR_TOPIC)). %% MQTT payload --define(PH_PAYLOAD, <<"${payload}">>). +-define(PH_PAYLOAD, ?PH("payload")). %% client IPAddress --define(PH_PEERHOST, <<"${peerhost}">>). +-define(VAR_PEERHOST, "peerhost"). +-define(PH_PEERHOST, ?PH(?VAR_PEERHOST)). %% ip & port --define(PH_HOST, <<"${host}">>). --define(PH_PORT, <<"${port}">>). +-define(PH_HOST, ?PH("host")). +-define(PH_PORT, ?PH("port")). %% Enumeration of message QoS 0,1,2 --define(PH_QOS, <<"${qos}">>). --define(PH_FLAGS, <<"${flags}">>). +-define(VAR_QOS, "qos"). +-define(PH_QOS, ?PH(?VAR_QOS)). +-define(PH_FLAGS, ?PH("flags")). %% Additional data related to process within the MQTT message --define(PH_HEADERS, <<"${headers}">>). +-define(PH_HEADERS, ?PH("headers")). %% protocol name --define(PH_PROTONAME, <<"${proto_name}">>). +-define(VAR_PROTONAME, "proto_name"). +-define(PH_PROTONAME, ?PH(?VAR_PROTONAME)). %% protocol version --define(PH_PROTOVER, <<"${proto_ver}">>). +-define(PH_PROTOVER, ?PH("proto_ver")). %% MQTT keepalive interval --define(PH_KEEPALIVE, <<"${keepalive}">>). +-define(PH_KEEPALIVE, ?PH("keepalive")). %% MQTT clean_start --define(PH_CLEAR_START, <<"${clean_start}">>). +-define(PH_CLEAR_START, ?PH("clean_start")). %% MQTT Session Expiration time --define(PH_EXPIRY_INTERVAL, <<"${expiry_interval}">>). +-define(PH_EXPIRY_INTERVAL, ?PH("expiry_interval")). %% Time when PUBLISH message reaches Broker (ms) --define(PH_PUBLISH_RECEIVED_AT, <<"${publish_received_at}">>). +-define(PH_PUBLISH_RECEIVED_AT, ?PH("publish_received_at")). %% Mountpoint for bridging messages --define(PH_MOUNTPOINT, <<"${mountpoint}">>). +-define(VAR_MOUNTPOINT, "mountpoint"). +-define(PH_MOUNTPOINT, ?PH(?VAR_MOUNTPOINT)). %% IPAddress and Port of terminal --define(PH_PEERNAME, <<"${peername}">>). +-define(PH_PEERNAME, ?PH("peername")). %% IPAddress and Port listened by emqx --define(PH_SOCKNAME, <<"${sockname}">>). +-define(PH_SOCKNAME, ?PH("sockname")). %% whether it is MQTT bridge connection --define(PH_IS_BRIDGE, <<"${is_bridge}">>). +-define(PH_IS_BRIDGE, ?PH("is_bridge")). %% Terminal connection completion time (s) --define(PH_CONNECTED_AT, <<"${connected_at}">>). +-define(PH_CONNECTED_AT, ?PH("connected_at")). %% Event trigger time(millisecond) --define(PH_TIMESTAMP, <<"${timestamp}">>). +-define(PH_TIMESTAMP, ?PH("timestamp")). %% Terminal disconnection completion time (s) --define(PH_DISCONNECTED_AT, <<"${disconnected_at}">>). +-define(PH_DISCONNECTED_AT, ?PH("disconnected_at")). --define(PH_NODE, <<"${node}">>). --define(PH_REASON, <<"${reason}">>). +-define(PH_NODE, ?PH("node")). +-define(PH_REASON, ?PH("reason")). --define(PH_ENDPOINT_NAME, <<"${endpoint_name}">>). --define(PH_RETAIN, <<"${retain}">>). +-define(PH_ENDPOINT_NAME, ?PH("endpoint_name")). +-define(VAR_RETAIN, "retain"). +-define(PH_RETAIN, ?PH(?VAR_RETAIN)). %% sync change these place holder with binary def. -define(PH_S_ACTION, "${action}"). diff --git a/apps/emqx_auth/src/emqx_authn/emqx_authn_utils.erl b/apps/emqx_auth/src/emqx_authn/emqx_authn_utils.erl index a9d672922..f782e0e6c 100644 --- a/apps/emqx_auth/src/emqx_authn/emqx_authn_utils.erl +++ b/apps/emqx_auth/src/emqx_authn/emqx_authn_utils.erl @@ -18,6 +18,7 @@ -include_lib("emqx/include/emqx_placeholder.hrl"). -include_lib("emqx_authn.hrl"). +-include_lib("snabbkaffe/include/trace.hrl"). -export([ create_resource/3, @@ -44,13 +45,13 @@ default_headers_no_content_type/0 ]). --define(AUTHN_PLACEHOLDERS, [ - ?PH_USERNAME, - ?PH_CLIENTID, - ?PH_PASSWORD, - ?PH_PEERHOST, - ?PH_CERT_SUBJECT, - ?PH_CERT_CN_NAME +-define(ALLOWED_VARS, [ + ?VAR_USERNAME, + ?VAR_CLIENTID, + ?VAR_PASSWORD, + ?VAR_PEERHOST, + ?VAR_CERT_SUBJECT, + ?VAR_CERT_CN_NAME ]). -define(DEFAULT_RESOURCE_OPTS, #{ @@ -107,48 +108,96 @@ check_password_from_selected_map(Algorithm, Selected, Password) -> end. parse_deep(Template) -> - emqx_placeholder:preproc_tmpl_deep(Template, #{placeholders => ?AUTHN_PLACEHOLDERS}). + Result = emqx_template:parse_deep(Template), + handle_disallowed_placeholders(Result, {deep, Template}). parse_str(Template) -> - emqx_placeholder:preproc_tmpl(Template, #{placeholders => ?AUTHN_PLACEHOLDERS}). + Result = emqx_template:parse(Template), + handle_disallowed_placeholders(Result, {string, Template}). parse_sql(Template, ReplaceWith) -> - emqx_placeholder:preproc_sql( + {Statement, Result} = emqx_template_sql:parse_prepstmt( Template, - #{ - replace_with => ReplaceWith, - placeholders => ?AUTHN_PLACEHOLDERS, - strip_double_quote => true - } - ). + #{parameters => ReplaceWith, strip_double_quote => true} + ), + {Statement, handle_disallowed_placeholders(Result, {string, Template})}. + +handle_disallowed_placeholders(Template, Source) -> + case emqx_template:validate(?ALLOWED_VARS, Template) of + ok -> + Template; + {error, Disallowed} -> + ?tp(warning, "authn_template_invalid", #{ + template => Source, + reason => Disallowed, + allowed => #{placeholders => ?ALLOWED_VARS}, + notice => + "Disallowed placeholders will be rendered as is." + " However, consider using `${$}` escaping for literal `$` where" + " needed to avoid unexpected results." + }), + Result = prerender_disallowed_placeholders(Template), + case Source of + {string, _} -> + emqx_template:parse(Result); + {deep, _} -> + emqx_template:parse_deep(Result) + end + end. + +prerender_disallowed_placeholders(Template) -> + {Result, _} = emqx_template:render(Template, #{}, #{ + var_trans => fun(Name, _) -> + % NOTE + % Rendering disallowed placeholders in escaped form, which will then + % parse as a literal string. + case lists:member(Name, ?ALLOWED_VARS) of + true -> "${" ++ Name ++ "}"; + false -> "${$}{" ++ Name ++ "}" + end + end + }), + Result. render_deep(Template, Credential) -> - emqx_placeholder:proc_tmpl_deep( + % NOTE + % Ignoring errors here, undefined bindings will be replaced with empty string. + {Term, _Errors} = emqx_template:render( Template, mapping_credential(Credential), - #{return => full_binary, var_trans => fun handle_var/2} - ). + #{var_trans => fun to_string/2} + ), + Term. render_str(Template, Credential) -> - emqx_placeholder:proc_tmpl( + % NOTE + % Ignoring errors here, undefined bindings will be replaced with empty string. + {String, _Errors} = emqx_template:render( Template, mapping_credential(Credential), - #{return => full_binary, var_trans => fun handle_var/2} - ). + #{var_trans => fun to_string/2} + ), + unicode:characters_to_binary(String). render_urlencoded_str(Template, Credential) -> - emqx_placeholder:proc_tmpl( + % NOTE + % Ignoring errors here, undefined bindings will be replaced with empty string. + {String, _Errors} = emqx_template:render( Template, mapping_credential(Credential), - #{return => full_binary, var_trans => fun urlencode_var/2} - ). + #{var_trans => fun to_urlencoded_string/2} + ), + unicode:characters_to_binary(String). render_sql_params(ParamList, Credential) -> - emqx_placeholder:proc_tmpl( + % NOTE + % Ignoring errors here, undefined bindings will be replaced with empty string. + {Row, _Errors} = emqx_template:render( ParamList, mapping_credential(Credential), - #{return => rawlist, var_trans => fun handle_sql_var/2} - ). + #{var_trans => fun to_sql_valaue/2} + ), + Row. is_superuser(#{<<"is_superuser">> := Value}) -> #{is_superuser => to_bool(Value)}; @@ -269,22 +318,24 @@ without_password(Credential, [Name | Rest]) -> without_password(Credential, Rest) end. -urlencode_var(Var, Value) -> - emqx_http_lib:uri_encode(handle_var(Var, Value)). +to_urlencoded_string(Name, Value) -> + emqx_http_lib:uri_encode(to_string(Name, Value)). -handle_var(_Name, undefined) -> - <<>>; -handle_var([<<"peerhost">>], PeerHost) -> - emqx_placeholder:bin(inet:ntoa(PeerHost)); -handle_var(_, Value) -> - emqx_placeholder:bin(Value). +to_string(Name, Value) -> + emqx_template:to_string(render_var(Name, Value)). -handle_sql_var(_Name, undefined) -> +to_sql_valaue(Name, Value) -> + emqx_utils_sql:to_sql_value(render_var(Name, Value)). + +render_var(_, undefined) -> + % NOTE + % Any allowed but undefined binding will be replaced with empty string, even when + % rendering SQL values. <<>>; -handle_sql_var([<<"peerhost">>], PeerHost) -> - emqx_placeholder:bin(inet:ntoa(PeerHost)); -handle_sql_var(_, Value) -> - emqx_placeholder:sql_data(Value). +render_var(?VAR_PEERHOST, Value) -> + inet:ntoa(Value); +render_var(_Name, Value) -> + Value. mapping_credential(C = #{cn := CN, dn := DN}) -> C#{cert_common_name => CN, cert_subject => DN}; diff --git a/apps/emqx_auth/src/emqx_authz/emqx_authz_rule.erl b/apps/emqx_auth/src/emqx_authz/emqx_authz_rule.erl index 6e13cac91..ad6dec56b 100644 --- a/apps/emqx_auth/src/emqx_authz/emqx_authz_rule.erl +++ b/apps/emqx_auth/src/emqx_authz/emqx_authz_rule.erl @@ -183,19 +183,14 @@ compile_topic(<<"eq ", Topic/binary>>) -> compile_topic({eq, Topic}) -> {eq, emqx_topic:words(bin(Topic))}; compile_topic(Topic) -> - TopicBin = bin(Topic), - case - emqx_placeholder:preproc_tmpl( - TopicBin, - #{placeholders => [?PH_USERNAME, ?PH_CLIENTID]} - ) - of - [{str, _}] -> emqx_topic:words(TopicBin); - Tokens -> {pattern, Tokens} + Template = emqx_authz_utils:parse_str(Topic, [?VAR_USERNAME, ?VAR_CLIENTID]), + case emqx_template:is_const(Template) of + true -> emqx_topic:words(bin(Topic)); + false -> {pattern, Template} end. bin(L) when is_list(L) -> - list_to_binary(L); + unicode:characters_to_binary(L); bin(B) when is_binary(B) -> B. @@ -307,7 +302,7 @@ match_who(_, _) -> match_topics(_ClientInfo, _Topic, []) -> false; match_topics(ClientInfo, Topic, [{pattern, PatternFilter} | Filters]) -> - TopicFilter = emqx_placeholder:proc_tmpl(PatternFilter, ClientInfo), + TopicFilter = bin(emqx_template:render_strict(PatternFilter, ClientInfo)), match_topic(emqx_topic:words(Topic), emqx_topic:words(TopicFilter)) orelse match_topics(ClientInfo, Topic, Filters); match_topics(ClientInfo, Topic, [TopicFilter | Filters]) -> diff --git a/apps/emqx_auth/src/emqx_authz/emqx_authz_utils.erl b/apps/emqx_auth/src/emqx_authz/emqx_authz_utils.erl index 3a0d4f1a1..a17a563ae 100644 --- a/apps/emqx_auth/src/emqx_authz/emqx_authz_utils.erl +++ b/apps/emqx_auth/src/emqx_authz/emqx_authz_utils.erl @@ -16,7 +16,9 @@ -module(emqx_authz_utils). +-include_lib("emqx/include/emqx_placeholder.hrl"). -include_lib("emqx_authz.hrl"). +-include_lib("snabbkaffe/include/trace.hrl"). -export([ cleanup_resources/0, @@ -108,48 +110,97 @@ update_config(Path, ConfigRequest) -> }). parse_deep(Template, PlaceHolders) -> - emqx_placeholder:preproc_tmpl_deep(Template, #{placeholders => PlaceHolders}). + Result = emqx_template:parse_deep(Template), + handle_disallowed_placeholders(Result, {deep, Template}, PlaceHolders). parse_str(Template, PlaceHolders) -> - emqx_placeholder:preproc_tmpl(Template, #{placeholders => PlaceHolders}). + Result = emqx_template:parse(Template), + handle_disallowed_placeholders(Result, {string, Template}, PlaceHolders). parse_sql(Template, ReplaceWith, PlaceHolders) -> - emqx_placeholder:preproc_sql( + {Statement, Result} = emqx_template_sql:parse_prepstmt( Template, - #{ - replace_with => ReplaceWith, - placeholders => PlaceHolders, - strip_double_quote => true - } - ). + #{parameters => ReplaceWith, strip_double_quote => true} + ), + FResult = handle_disallowed_placeholders(Result, {string, Template}, PlaceHolders), + {Statement, FResult}. + +handle_disallowed_placeholders(Template, Source, Allowed) -> + case emqx_template:validate(Allowed, Template) of + ok -> + Template; + {error, Disallowed} -> + ?tp(warning, "authz_template_invalid", #{ + template => Source, + reason => Disallowed, + allowed => #{placeholders => Allowed}, + notice => + "Disallowed placeholders will be rendered as is." + " However, consider using `${$}` escaping for literal `$` where" + " needed to avoid unexpected results." + }), + Result = prerender_disallowed_placeholders(Template, Allowed), + case Source of + {string, _} -> + emqx_template:parse(Result); + {deep, _} -> + emqx_template:parse_deep(Result) + end + end. + +prerender_disallowed_placeholders(Template, Allowed) -> + {Result, _} = emqx_template:render(Template, #{}, #{ + var_trans => fun(Name, _) -> + % NOTE + % Rendering disallowed placeholders in escaped form, which will then + % parse as a literal string. + case lists:member(Name, Allowed) of + true -> "${" ++ Name ++ "}"; + false -> "${$}{" ++ Name ++ "}" + end + end + }), + Result. render_deep(Template, Values) -> - emqx_placeholder:proc_tmpl_deep( + % NOTE + % Ignoring errors here, undefined bindings will be replaced with empty string. + {Term, _Errors} = emqx_template:render( Template, client_vars(Values), - #{return => full_binary, var_trans => fun handle_var/2} - ). + #{var_trans => fun to_string/2} + ), + Term. render_str(Template, Values) -> - emqx_placeholder:proc_tmpl( + % NOTE + % Ignoring errors here, undefined bindings will be replaced with empty string. + {String, _Errors} = emqx_template:render( Template, client_vars(Values), - #{return => full_binary, var_trans => fun handle_var/2} - ). + #{var_trans => fun to_string/2} + ), + unicode:characters_to_binary(String). render_urlencoded_str(Template, Values) -> - emqx_placeholder:proc_tmpl( + % NOTE + % Ignoring errors here, undefined bindings will be replaced with empty string. + {String, _Errors} = emqx_template:render( Template, client_vars(Values), - #{return => full_binary, var_trans => fun urlencode_var/2} - ). + #{var_trans => fun to_urlencoded_string/2} + ), + unicode:characters_to_binary(String). render_sql_params(ParamList, Values) -> - emqx_placeholder:proc_tmpl( + % NOTE + % Ignoring errors here, undefined bindings will be replaced with empty string. + {Row, _Errors} = emqx_template:render( ParamList, client_vars(Values), - #{return => rawlist, var_trans => fun handle_sql_var/2} - ). + #{var_trans => fun to_sql_value/2} + ), + Row. -spec parse_http_resp_body(binary(), binary()) -> allow | deny | ignore | error. parse_http_resp_body(<<"application/x-www-form-urlencoded", _/binary>>, Body) -> @@ -215,22 +266,24 @@ convert_client_var({dn, DN}) -> {cert_subject, DN}; convert_client_var({protocol, Proto}) -> {proto_name, Proto}; convert_client_var(Other) -> Other. -urlencode_var(Var, Value) -> - emqx_http_lib:uri_encode(handle_var(Var, Value)). +to_urlencoded_string(Name, Value) -> + emqx_http_lib:uri_encode(to_string(Name, Value)). -handle_var(_Name, undefined) -> - <<>>; -handle_var([<<"peerhost">>], IpAddr) -> - inet_parse:ntoa(IpAddr); -handle_var(_Name, Value) -> - emqx_placeholder:bin(Value). +to_string(Name, Value) -> + emqx_template:to_string(render_var(Name, Value)). -handle_sql_var(_Name, undefined) -> +to_sql_value(Name, Value) -> + emqx_utils_sql:to_sql_value(render_var(Name, Value)). + +render_var(_, undefined) -> + % NOTE + % Any allowed but undefined binding will be replaced with empty string, even when + % rendering SQL values. <<>>; -handle_sql_var([<<"peerhost">>], IpAddr) -> - inet_parse:ntoa(IpAddr); -handle_sql_var(_Name, Value) -> - emqx_placeholder:sql_data(Value). +render_var(?VAR_PEERHOST, Value) -> + inet:ntoa(Value); +render_var(_Name, Value) -> + Value. bin(A) when is_atom(A) -> atom_to_binary(A, utf8); bin(L) when is_list(L) -> list_to_binary(L); diff --git a/apps/emqx_auth/test/emqx_authz/emqx_authz_rule_SUITE.erl b/apps/emqx_auth/test/emqx_authz/emqx_authz_rule_SUITE.erl index b34e4fb00..d81a93038 100644 --- a/apps/emqx_auth/test/emqx_authz/emqx_authz_rule_SUITE.erl +++ b/apps/emqx_auth/test/emqx_authz/emqx_authz_rule_SUITE.erl @@ -67,6 +67,10 @@ set_special_configs(_App) -> ok. t_compile(_) -> + % NOTE + % Some of the following testcase are relying on the internal representation of + % `emqx_template:t()`. If the internal representation is changed, these testcases + % may fail. ?assertEqual({deny, all, all, [['#']]}, emqx_authz_rule:compile({deny, all})), ?assertEqual( @@ -74,13 +78,13 @@ t_compile(_) -> emqx_authz_rule:compile({allow, {ipaddr, "127.0.0.1"}, all, [{eq, "#"}, {eq, "+"}]}) ), - ?assertEqual( + ?assertMatch( {allow, {ipaddrs, [ {{127, 0, 0, 1}, {127, 0, 0, 1}, 32}, {{192, 168, 1, 0}, {192, 168, 1, 255}, 24} ]}, - subscribe, [{pattern, [{var, [<<"clientid">>]}]}]}, + subscribe, [{pattern, [{var, "clientid", [_]}]}]}, emqx_authz_rule:compile( {allow, {ipaddrs, ["127.0.0.1", "192.168.1.0/24"]}, subscribe, [?PH_S_CLIENTID]} ) @@ -102,7 +106,7 @@ t_compile(_) -> {clientid, {re_pattern, _, _, _, _}} ]}, publish, [ - {pattern, [{var, [<<"username">>]}]}, {pattern, [{var, [<<"clientid">>]}]} + {pattern, [{var, "username", [_]}]}, {pattern, [{var, "clientid", [_]}]} ]}, emqx_authz_rule:compile( {allow, @@ -114,9 +118,9 @@ t_compile(_) -> ) ), - ?assertEqual( + ?assertMatch( {allow, {username, {eq, <<"test">>}}, publish, [ - {pattern, [{str, <<"t/foo">>}, {var, [<<"username">>]}, {str, <<"boo">>}]} + {pattern, [<<"t/foo">>, {var, "username", [_]}, <<"boo">>]} ]}, emqx_authz_rule:compile({allow, {username, "test"}, publish, ["t/foo${username}boo"]}) ), diff --git a/apps/emqx_auth_http/src/emqx_authz_http.erl b/apps/emqx_auth_http/src/emqx_authz_http.erl index ed7051bb6..04f76b4c9 100644 --- a/apps/emqx_auth_http/src/emqx_authz_http.erl +++ b/apps/emqx_auth_http/src/emqx_authz_http.erl @@ -38,21 +38,21 @@ -compile(nowarn_export_all). -endif. --define(PLACEHOLDERS, [ - ?PH_USERNAME, - ?PH_CLIENTID, - ?PH_PEERHOST, - ?PH_PROTONAME, - ?PH_MOUNTPOINT, - ?PH_TOPIC, - ?PH_ACTION, - ?PH_CERT_SUBJECT, - ?PH_CERT_CN_NAME +-define(ALLOWED_VARS, [ + ?VAR_USERNAME, + ?VAR_CLIENTID, + ?VAR_PEERHOST, + ?VAR_PROTONAME, + ?VAR_MOUNTPOINT, + ?VAR_TOPIC, + ?VAR_ACTION, + ?VAR_CERT_SUBJECT, + ?VAR_CERT_CN_NAME ]). --define(PLACEHOLDERS_FOR_RICH_ACTIONS, [ - ?PH_QOS, - ?PH_RETAIN +-define(ALLOWED_VARS_RICH_ACTIONS, [ + ?VAR_QOS, + ?VAR_RETAIN ]). description() -> @@ -157,14 +157,14 @@ parse_config( method => Method, base_url => BaseUrl, headers => Headers, - base_path_templete => emqx_authz_utils:parse_str(Path, placeholders()), + base_path_templete => emqx_authz_utils:parse_str(Path, allowed_vars()), base_query_template => emqx_authz_utils:parse_deep( cow_qs:parse_qs(to_bin(Query)), - placeholders() + allowed_vars() ), body_template => emqx_authz_utils:parse_deep( maps:to_list(maps:get(body, Conf, #{})), - placeholders() + allowed_vars() ), request_timeout => ReqTimeout, %% pool_type default value `random` @@ -260,10 +260,10 @@ to_bin(B) when is_binary(B) -> B; to_bin(L) when is_list(L) -> list_to_binary(L); to_bin(X) -> X. -placeholders() -> - placeholders(emqx_authz:feature_available(rich_actions)). +allowed_vars() -> + allowed_vars(emqx_authz:feature_available(rich_actions)). -placeholders(true) -> - ?PLACEHOLDERS ++ ?PLACEHOLDERS_FOR_RICH_ACTIONS; -placeholders(false) -> - ?PLACEHOLDERS. +allowed_vars(true) -> + ?ALLOWED_VARS ++ ?ALLOWED_VARS_RICH_ACTIONS; +allowed_vars(false) -> + ?ALLOWED_VARS. diff --git a/apps/emqx_auth_http/test/emqx_authn_http_SUITE.erl b/apps/emqx_auth_http/test/emqx_authn_http_SUITE.erl index 577b3b638..e307b5bbf 100644 --- a/apps/emqx_auth_http/test/emqx_authn_http_SUITE.erl +++ b/apps/emqx_auth_http/test/emqx_authn_http_SUITE.erl @@ -27,7 +27,7 @@ -define(PATH, [?CONF_NS_ATOM]). -define(HTTP_PORT, 32333). --define(HTTP_PATH, "/auth"). +-define(HTTP_PATH, "/auth/[...]"). -define(CREDENTIALS, #{ clientid => <<"clienta">>, username => <<"plain">>, @@ -146,8 +146,12 @@ t_authenticate(_Config) -> test_user_auth(#{ handler := Handler, config_params := SpecificConfgParams, - result := Result + result := Expect }) -> + Result = perform_user_auth(SpecificConfgParams, Handler, ?CREDENTIALS), + ?assertEqual(Expect, Result). + +perform_user_auth(SpecificConfgParams, Handler, Credentials) -> AuthConfig = maps:merge(raw_http_auth_config(), SpecificConfgParams), {ok, _} = emqx:update_config( @@ -157,21 +161,21 @@ test_user_auth(#{ ok = emqx_authn_http_test_server:set_handler(Handler), - ?assertEqual(Result, emqx_access_control:authenticate(?CREDENTIALS)), + Result = emqx_access_control:authenticate(Credentials), emqx_authn_test_lib:delete_authenticators( [authentication], ?GLOBAL - ). + ), + + Result. t_authenticate_path_placeholders(_Config) -> - ok = emqx_authn_http_test_server:stop(), - {ok, _} = emqx_authn_http_test_server:start_link(?HTTP_PORT, <<"/[...]">>), ok = emqx_authn_http_test_server:set_handler( fun(Req0, State) -> Req = case cowboy_req:path(Req0) of - <<"/my/p%20ath//us%20er/auth//">> -> + <<"/auth/p%20ath//us%20er/auth//">> -> cowboy_req:reply( 200, #{<<"content-type">> => <<"application/json">>}, @@ -193,7 +197,7 @@ t_authenticate_path_placeholders(_Config) -> AuthConfig = maps:merge( raw_http_auth_config(), #{ - <<"url">> => <<"http://127.0.0.1:32333/my/p%20ath//${username}/auth//">>, + <<"url">> => <<"http://127.0.0.1:32333/auth/p%20ath//${username}/auth//">>, <<"body">> => #{} } ), @@ -255,6 +259,39 @@ t_no_value_for_placeholder(_Config) -> ?GLOBAL ). +t_disallowed_placeholders_preserved(_Config) -> + Config = #{ + <<"method">> => <<"post">>, + <<"headers">> => #{<<"content-type">> => <<"application/json">>}, + <<"body">> => #{ + <<"username">> => ?PH_USERNAME, + <<"password">> => ?PH_PASSWORD, + <<"this">> => <<"${whatisthis}">> + } + }, + Handler = fun(Req0, State) -> + {ok, Body, Req1} = cowboy_req:read_body(Req0), + #{ + <<"username">> := <<"plain">>, + <<"password">> := <<"plain">>, + <<"this">> := <<"${whatisthis}">> + } = emqx_utils_json:decode(Body), + Req = cowboy_req:reply( + 200, + #{<<"content-type">> => <<"application/json">>}, + emqx_utils_json:encode(#{result => allow, is_superuser => false}), + Req1 + ), + {ok, Req, State} + end, + ?assertMatch({ok, _}, perform_user_auth(Config, Handler, ?CREDENTIALS)), + + % NOTE: disallowed placeholder left intact, which makes the URL invalid + ConfigUrl = Config#{ + <<"url">> => <<"http://127.0.0.1:32333/auth/${whatisthis}">> + }, + ?assertMatch({error, _}, perform_user_auth(ConfigUrl, Handler, ?CREDENTIALS)). + t_destroy(_Config) -> AuthConfig = raw_http_auth_config(), diff --git a/apps/emqx_auth_http/test/emqx_authz_http_SUITE.erl b/apps/emqx_auth_http/test/emqx_authz_http_SUITE.erl index e56e25f5f..845259e78 100644 --- a/apps/emqx_auth_http/test/emqx_authz_http_SUITE.erl +++ b/apps/emqx_auth_http/test/emqx_authz_http_SUITE.erl @@ -494,6 +494,67 @@ t_no_value_for_placeholder(_Config) -> emqx_access_control:authorize(ClientInfo, ?AUTHZ_PUBLISH, <<"t">>) ). +t_disallowed_placeholders_preserved(_Config) -> + ok = setup_handler_and_config( + fun(Req0, State) -> + {ok, Body, Req1} = cowboy_req:read_body(Req0), + ?assertMatch( + #{ + <<"cname">> := <<>>, + <<"usertypo">> := <<"${usertypo}">> + }, + emqx_utils_json:decode(Body) + ), + {ok, ?AUTHZ_HTTP_RESP(allow, Req1), State} + end, + #{ + <<"method">> => <<"post">>, + <<"body">> => #{ + <<"cname">> => ?PH_CERT_CN_NAME, + <<"usertypo">> => <<"${usertypo}">> + } + } + ), + + ClientInfo = #{ + clientid => <<"client id">>, + username => <<"user name">>, + peerhost => {127, 0, 0, 1}, + protocol => <<"MQTT">>, + zone => default, + listener => {tcp, default} + }, + + ?assertEqual( + allow, + emqx_access_control:authorize(ClientInfo, ?AUTHZ_PUBLISH, <<"t">>) + ). + +t_disallowed_placeholders_path(_Config) -> + ok = setup_handler_and_config( + fun(Req, State) -> + {ok, ?AUTHZ_HTTP_RESP(allow, Req), State} + end, + #{ + <<"url">> => <<"http://127.0.0.1:33333/authz/use%20rs/${typo}">> + } + ), + + ClientInfo = #{ + clientid => <<"client id">>, + username => <<"user name">>, + peerhost => {127, 0, 0, 1}, + protocol => <<"MQTT">>, + zone => default, + listener => {tcp, default} + }, + + % % NOTE: disallowed placeholder left intact, which makes the URL invalid + ?assertEqual( + deny, + emqx_access_control:authorize(ClientInfo, ?AUTHZ_PUBLISH, <<"t">>) + ). + t_create_replace(_Config) -> ClientInfo = #{ clientid => <<"clientid">>, diff --git a/apps/emqx_auth_mongodb/src/emqx_authz_mongodb.erl b/apps/emqx_auth_mongodb/src/emqx_authz_mongodb.erl index 3b235ad2c..fdeb9d542 100644 --- a/apps/emqx_auth_mongodb/src/emqx_authz_mongodb.erl +++ b/apps/emqx_auth_mongodb/src/emqx_authz_mongodb.erl @@ -35,12 +35,12 @@ -compile(nowarn_export_all). -endif. --define(PLACEHOLDERS, [ - ?PH_USERNAME, - ?PH_CLIENTID, - ?PH_PEERHOST, - ?PH_CERT_CN_NAME, - ?PH_CERT_SUBJECT +-define(ALLOWED_VARS, [ + ?VAR_USERNAME, + ?VAR_CLIENTID, + ?VAR_PEERHOST, + ?VAR_CERT_CN_NAME, + ?VAR_CERT_SUBJECT ]). description() -> @@ -49,11 +49,11 @@ description() -> create(#{filter := Filter} = Source) -> ResourceId = emqx_authz_utils:make_resource_id(?MODULE), {ok, _Data} = emqx_authz_utils:create_resource(ResourceId, emqx_mongodb, Source), - FilterTemp = emqx_authz_utils:parse_deep(Filter, ?PLACEHOLDERS), + FilterTemp = emqx_authz_utils:parse_deep(Filter, ?ALLOWED_VARS), Source#{annotations => #{id => ResourceId}, filter_template => FilterTemp}. update(#{filter := Filter} = Source) -> - FilterTemp = emqx_authz_utils:parse_deep(Filter, ?PLACEHOLDERS), + FilterTemp = emqx_authz_utils:parse_deep(Filter, ?ALLOWED_VARS), case emqx_authz_utils:update_resource(emqx_mongodb, Source) of {error, Reason} -> error({load_config_error, Reason}); diff --git a/apps/emqx_auth_mysql/src/emqx_authz_mysql.erl b/apps/emqx_auth_mysql/src/emqx_authz_mysql.erl index 4ca71e332..8c9e54ee1 100644 --- a/apps/emqx_auth_mysql/src/emqx_authz_mysql.erl +++ b/apps/emqx_auth_mysql/src/emqx_authz_mysql.erl @@ -37,26 +37,26 @@ -compile(nowarn_export_all). -endif. --define(PLACEHOLDERS, [ - ?PH_USERNAME, - ?PH_CLIENTID, - ?PH_PEERHOST, - ?PH_CERT_CN_NAME, - ?PH_CERT_SUBJECT +-define(ALLOWED_VARS, [ + ?VAR_USERNAME, + ?VAR_CLIENTID, + ?VAR_PEERHOST, + ?VAR_CERT_CN_NAME, + ?VAR_CERT_SUBJECT ]). description() -> "AuthZ with Mysql". create(#{query := SQL} = Source0) -> - {PrepareSQL, TmplToken} = emqx_authz_utils:parse_sql(SQL, '?', ?PLACEHOLDERS), + {PrepareSQL, TmplToken} = emqx_authz_utils:parse_sql(SQL, '?', ?ALLOWED_VARS), ResourceId = emqx_authz_utils:make_resource_id(?MODULE), Source = Source0#{prepare_statement => #{?PREPARE_KEY => PrepareSQL}}, {ok, _Data} = emqx_authz_utils:create_resource(ResourceId, emqx_mysql, Source), Source#{annotations => #{id => ResourceId, tmpl_token => TmplToken}}. update(#{query := SQL} = Source0) -> - {PrepareSQL, TmplToken} = emqx_authz_utils:parse_sql(SQL, '?', ?PLACEHOLDERS), + {PrepareSQL, TmplToken} = emqx_authz_utils:parse_sql(SQL, '?', ?ALLOWED_VARS), Source = Source0#{prepare_statement => #{?PREPARE_KEY => PrepareSQL}}, case emqx_authz_utils:update_resource(emqx_mysql, Source) of {error, Reason} -> diff --git a/apps/emqx_auth_postgresql/src/emqx_authz_postgresql.erl b/apps/emqx_auth_postgresql/src/emqx_authz_postgresql.erl index b930f77e4..14b7598a6 100644 --- a/apps/emqx_auth_postgresql/src/emqx_authz_postgresql.erl +++ b/apps/emqx_auth_postgresql/src/emqx_authz_postgresql.erl @@ -37,19 +37,19 @@ -compile(nowarn_export_all). -endif. --define(PLACEHOLDERS, [ - ?PH_USERNAME, - ?PH_CLIENTID, - ?PH_PEERHOST, - ?PH_CERT_CN_NAME, - ?PH_CERT_SUBJECT +-define(ALLOWED_VARS, [ + ?VAR_USERNAME, + ?VAR_CLIENTID, + ?VAR_PEERHOST, + ?VAR_CERT_CN_NAME, + ?VAR_CERT_SUBJECT ]). description() -> "AuthZ with PostgreSQL". create(#{query := SQL0} = Source) -> - {SQL, PlaceHolders} = emqx_authz_utils:parse_sql(SQL0, '$n', ?PLACEHOLDERS), + {SQL, PlaceHolders} = emqx_authz_utils:parse_sql(SQL0, '$n', ?ALLOWED_VARS), ResourceID = emqx_authz_utils:make_resource_id(emqx_postgresql), {ok, _Data} = emqx_authz_utils:create_resource( ResourceID, @@ -59,7 +59,7 @@ create(#{query := SQL0} = Source) -> Source#{annotations => #{id => ResourceID, placeholders => PlaceHolders}}. update(#{query := SQL0, annotations := #{id := ResourceID}} = Source) -> - {SQL, PlaceHolders} = emqx_authz_utils:parse_sql(SQL0, '$n', ?PLACEHOLDERS), + {SQL, PlaceHolders} = emqx_authz_utils:parse_sql(SQL0, '$n', ?ALLOWED_VARS), case emqx_authz_utils:update_resource( emqx_postgresql, diff --git a/apps/emqx_auth_redis/src/emqx_authz_redis.erl b/apps/emqx_auth_redis/src/emqx_authz_redis.erl index 9b69f508a..ca4a11742 100644 --- a/apps/emqx_auth_redis/src/emqx_authz_redis.erl +++ b/apps/emqx_auth_redis/src/emqx_authz_redis.erl @@ -35,12 +35,12 @@ -compile(nowarn_export_all). -endif. --define(PLACEHOLDERS, [ - ?PH_CERT_CN_NAME, - ?PH_CERT_SUBJECT, - ?PH_PEERHOST, - ?PH_CLIENTID, - ?PH_USERNAME +-define(ALLOWED_VARS, [ + ?VAR_CERT_CN_NAME, + ?VAR_CERT_SUBJECT, + ?VAR_PEERHOST, + ?VAR_CLIENTID, + ?VAR_USERNAME ]). description() -> @@ -133,7 +133,7 @@ parse_cmd(Query) -> case emqx_redis_command:split(Query) of {ok, Cmd} -> ok = validate_cmd(Cmd), - emqx_authz_utils:parse_deep(Cmd, ?PLACEHOLDERS); + emqx_authz_utils:parse_deep(Cmd, ?ALLOWED_VARS); {error, Reason} -> error({invalid_redis_cmd, Reason, Query}) end. diff --git a/apps/emqx_bridge_http/src/emqx_bridge_http_connector.erl b/apps/emqx_bridge_http/src/emqx_bridge_http_connector.erl index 5d1b1947c..b2f876d21 100644 --- a/apps/emqx_bridge_http/src/emqx_bridge_http_connector.erl +++ b/apps/emqx_bridge_http/src/emqx_bridge_http_connector.erl @@ -479,61 +479,47 @@ preprocess_request( } = Req ) -> #{ - method => emqx_placeholder:preproc_tmpl(to_bin(Method)), - path => emqx_placeholder:preproc_tmpl(Path), - body => maybe_preproc_tmpl(body, Req), - headers => wrap_auth_header(preproc_headers(Headers)), + method => parse_template(to_bin(Method)), + path => parse_template(Path), + body => maybe_parse_template(body, Req), + headers => parse_headers(Headers), request_timeout => maps:get(request_timeout, Req, ?DEFAULT_REQUEST_TIMEOUT_MS), max_retries => maps:get(max_retries, Req, 2) }. -preproc_headers(Headers) when is_map(Headers) -> +parse_headers(Headers) when is_map(Headers) -> maps:fold( - fun(K, V, Acc) -> - [ - { - emqx_placeholder:preproc_tmpl(to_bin(K)), - emqx_placeholder:preproc_tmpl(to_bin(V)) - } - | Acc - ] - end, + fun(K, V, Acc) -> [parse_header(K, V) | Acc] end, [], Headers ); -preproc_headers(Headers) when is_list(Headers) -> +parse_headers(Headers) when is_list(Headers) -> lists:map( - fun({K, V}) -> - { - emqx_placeholder:preproc_tmpl(to_bin(K)), - emqx_placeholder:preproc_tmpl(to_bin(V)) - } - end, + fun({K, V}) -> parse_header(K, V) end, Headers ). -wrap_auth_header(Headers) -> - lists:map(fun maybe_wrap_auth_header/1, Headers). +parse_header(K, V) -> + KStr = to_bin(K), + VTpl = parse_template(to_bin(V)), + {parse_template(KStr), maybe_wrap_auth_header(KStr, VTpl)}. -maybe_wrap_auth_header({[{str, Key}] = StrKey, Val}) -> - {_, MaybeWrapped} = maybe_wrap_auth_header({Key, Val}), - {StrKey, MaybeWrapped}; -maybe_wrap_auth_header({Key, Val} = Header) when - is_binary(Key), (size(Key) =:= 19 orelse size(Key) =:= 13) +maybe_wrap_auth_header(Key, VTpl) when + (byte_size(Key) =:= 19 orelse byte_size(Key) =:= 13) -> %% We check the size of potential keys in the guard above and consider only %% those that match the number of characters of either "Authorization" or %% "Proxy-Authorization". case try_bin_to_lower(Key) of <<"authorization">> -> - {Key, emqx_secret:wrap(Val)}; + emqx_secret:wrap(VTpl); <<"proxy-authorization">> -> - {Key, emqx_secret:wrap(Val)}; + emqx_secret:wrap(VTpl); _Other -> - Header + VTpl end; -maybe_wrap_auth_header(Header) -> - Header. +maybe_wrap_auth_header(_Key, VTpl) -> + VTpl. try_bin_to_lower(Bin) -> try iolist_to_binary(string:lowercase(Bin)) of @@ -542,46 +528,57 @@ try_bin_to_lower(Bin) -> _:_ -> Bin end. -maybe_preproc_tmpl(Key, Conf) -> +maybe_parse_template(Key, Conf) -> case maps:get(Key, Conf, undefined) of undefined -> undefined; - Val -> emqx_placeholder:preproc_tmpl(Val) + Val -> parse_template(Val) end. +parse_template(String) -> + emqx_template:parse(String). + process_request( #{ - method := MethodTks, - path := PathTks, - body := BodyTks, - headers := HeadersTks, + method := MethodTemplate, + path := PathTemplate, + body := BodyTemplate, + headers := HeadersTemplate, request_timeout := ReqTimeout } = Conf, Msg ) -> Conf#{ - method => make_method(emqx_placeholder:proc_tmpl(MethodTks, Msg)), - path => emqx_placeholder:proc_tmpl(PathTks, Msg), - body => process_request_body(BodyTks, Msg), - headers => proc_headers(HeadersTks, Msg), + method => make_method(render_template_string(MethodTemplate, Msg)), + path => unicode:characters_to_list(render_template(PathTemplate, Msg)), + body => render_request_body(BodyTemplate, Msg), + headers => render_headers(HeadersTemplate, Msg), request_timeout => ReqTimeout }. -process_request_body(undefined, Msg) -> +render_request_body(undefined, Msg) -> emqx_utils_json:encode(Msg); -process_request_body(BodyTks, Msg) -> - emqx_placeholder:proc_tmpl(BodyTks, Msg). +render_request_body(BodyTks, Msg) -> + render_template(BodyTks, Msg). -proc_headers(HeaderTks, Msg) -> +render_headers(HeaderTks, Msg) -> lists:map( fun({K, V}) -> { - emqx_placeholder:proc_tmpl(K, Msg), - emqx_placeholder:proc_tmpl(emqx_secret:unwrap(V), Msg) + render_template_string(K, Msg), + render_template_string(emqx_secret:unwrap(V), Msg) } end, HeaderTks ). +render_template(Template, Msg) -> + % NOTE: ignoring errors here, missing variables will be rendered as `"undefined"`. + {String, _Errors} = emqx_template:render(Template, {emqx_jsonish, Msg}), + String. + +render_template_string(Template, Msg) -> + unicode:characters_to_binary(render_template(Template, Msg)). + make_method(M) when M == <<"POST">>; M == <<"post">> -> post; make_method(M) when M == <<"PUT">>; M == <<"put">> -> put; make_method(M) when M == <<"GET">>; M == <<"get">> -> get; @@ -716,8 +713,6 @@ maybe_retry(Result, _Context, ReplyFunAndArgs) -> emqx_resource:apply_reply_fun(ReplyFunAndArgs, Result). %% The HOCON schema system may generate sensitive keys with this format -is_sensitive_key([{str, StringKey}]) -> - is_sensitive_key(StringKey); is_sensitive_key(Atom) when is_atom(Atom) -> is_sensitive_key(erlang:atom_to_binary(Atom)); is_sensitive_key(Bin) when is_binary(Bin), (size(Bin) =:= 19 orelse size(Bin) =:= 13) -> @@ -742,25 +737,19 @@ redact(Data) -> %% and we also can't know the body format and where the sensitive data will be %% so the easy way to keep data security is redacted the whole body redact_request({Path, Headers}) -> - {Path, redact(Headers)}; + {Path, Headers}; redact_request({Path, Headers, _Body}) -> - {Path, redact(Headers), <<"******">>}. + {Path, Headers, <<"******">>}. -ifdef(TEST). -include_lib("eunit/include/eunit.hrl"). redact_test_() -> - TestData1 = [ - {<<"content-type">>, <<"application/json">>}, - {<<"Authorization">>, <<"Basic YWxhZGRpbjpvcGVuc2VzYW1l">>} - ], - - TestData2 = #{ - headers => - [ - {[{str, <<"content-type">>}], [{str, <<"application/json">>}]}, - {[{str, <<"Authorization">>}], [{str, <<"Basic YWxhZGRpbjpvcGVuc2VzYW1l">>}]} - ] + TestData = #{ + headers => [ + {<<"content-type">>, <<"application/json">>}, + {<<"Authorization">>, <<"Basic YWxhZGRpbjpvcGVuc2VzYW1l">>} + ] }, [ ?_assert(is_sensitive_key(<<"Authorization">>)), @@ -770,8 +759,7 @@ redact_test_() -> ?_assert(is_sensitive_key('PrOxy-authoRizaTion')), ?_assertNot(is_sensitive_key(<<"Something">>)), ?_assertNot(is_sensitive_key(89)), - ?_assertNotEqual(TestData1, redact(TestData1)), - ?_assertNotEqual(TestData2, redact(TestData2)) + ?_assertNotEqual(TestData, redact(TestData)) ]. join_paths_test_() -> diff --git a/apps/emqx_bridge_http/test/emqx_bridge_http_connector_tests.erl b/apps/emqx_bridge_http/test/emqx_bridge_http_connector_tests.erl index 6b5c2b0cd..4f5e2929c 100644 --- a/apps/emqx_bridge_http/test/emqx_bridge_http_connector_tests.erl +++ b/apps/emqx_bridge_http/test/emqx_bridge_http_connector_tests.erl @@ -83,7 +83,8 @@ is_wrapped(Secret) when is_function(Secret) -> is_wrapped(_Other) -> false. -untmpl([{_, V} | _]) -> V. +untmpl(Tpl) -> + iolist_to_binary(emqx_template:render_strict(Tpl, #{})). is_unwrapped_headers(Headers) -> lists:all(fun is_unwrapped_header/1, Headers). diff --git a/apps/emqx_bridge_mysql/test/emqx_bridge_mysql_SUITE.erl b/apps/emqx_bridge_mysql/test/emqx_bridge_mysql_SUITE.erl index 3ed40e903..a34b65ede 100644 --- a/apps/emqx_bridge_mysql/test/emqx_bridge_mysql_SUITE.erl +++ b/apps/emqx_bridge_mysql/test/emqx_bridge_mysql_SUITE.erl @@ -566,7 +566,6 @@ t_simple_sql_query(Config) -> t_missing_data(Config) -> BatchSize = ?config(batch_size, Config), - IsBatch = BatchSize > 1, ?assertMatch( {ok, _}, create_bridge(Config) @@ -577,8 +576,8 @@ t_missing_data(Config) -> ), send_message(Config, #{}), {ok, [Event]} = snabbkaffe:receive_events(SRef), - case IsBatch of - true -> + case BatchSize of + N when N > 1 -> ?assertMatch( #{ result := @@ -588,7 +587,7 @@ t_missing_data(Config) -> }, Event ); - false -> + 1 -> ?assertMatch( #{ result := diff --git a/apps/emqx_bridge_pgsql/test/emqx_bridge_pgsql_SUITE.erl b/apps/emqx_bridge_pgsql/test/emqx_bridge_pgsql_SUITE.erl index cd79db43d..156d4bd16 100644 --- a/apps/emqx_bridge_pgsql/test/emqx_bridge_pgsql_SUITE.erl +++ b/apps/emqx_bridge_pgsql/test/emqx_bridge_pgsql_SUITE.erl @@ -324,6 +324,7 @@ connect_and_drop_table(Config) -> connect_and_clear_table(Config) -> Con = connect_direct_pgsql(Config), + _ = epgsql:squery(Con, ?SQL_CREATE_TABLE), {ok, _} = epgsql:squery(Con, ?SQL_DELETE), ok = epgsql:close(Con). @@ -668,7 +669,7 @@ t_missing_table(Config) -> ok end, fun(Trace) -> - ?assertMatch([_, _, _], ?of_kind(pgsql_undefined_table, Trace)), + ?assertMatch([_], ?of_kind(pgsql_undefined_table, Trace)), ok end ), diff --git a/apps/emqx_connector/src/emqx_connector_utils.erl b/apps/emqx_connector/src/emqx_connector_utils.erl deleted file mode 100644 index 6000f6be5..000000000 --- a/apps/emqx_connector/src/emqx_connector_utils.erl +++ /dev/null @@ -1,35 +0,0 @@ -%%-------------------------------------------------------------------- -%% Copyright (c) 2022-2023 EMQ Technologies Co., Ltd. All Rights Reserved. -%% -%% Licensed under the Apache License, Version 2.0 (the "License"); -%% you may not use this file except in compliance with the License. -%% You may obtain a copy of the License at -%% -%% http://www.apache.org/licenses/LICENSE-2.0 -%% -%% Unless required by applicable law or agreed to in writing, software -%% distributed under the License is distributed on an "AS IS" BASIS, -%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -%% See the License for the specific language governing permissions and -%% limitations under the License. -%%-------------------------------------------------------------------- - --module(emqx_connector_utils). - --export([split_insert_sql/1]). - -%% SQL = <<"INSERT INTO \"abc\" (c1,c2,c3) VALUES (${1}, ${1}, ${1})">> -split_insert_sql(SQL) -> - case re:split(SQL, "((?i)values)", [{return, binary}]) of - [Part1, _, Part3] -> - case string:trim(Part1, leading) of - <<"insert", _/binary>> = InsertSQL -> - {ok, {InsertSQL, Part3}}; - <<"INSERT", _/binary>> = InsertSQL -> - {ok, {InsertSQL, Part3}}; - _ -> - {error, not_insert_sql} - end; - _ -> - {error, not_insert_sql} - end. diff --git a/apps/emqx_mysql/src/emqx_mysql.erl b/apps/emqx_mysql/src/emqx_mysql.erl index 4440bcfbb..d8b7994ab 100644 --- a/apps/emqx_mysql/src/emqx_mysql.erl +++ b/apps/emqx_mysql/src/emqx_mysql.erl @@ -46,16 +46,12 @@ default_port => ?MYSQL_DEFAULT_PORT }). --type prepares() :: #{atom() => binary()}. --type params_tokens() :: #{atom() => list()}. --type sqls() :: #{atom() => binary()}. +-type template() :: {unicode:chardata(), emqx_template:str()}. -type state() :: #{ pool_name := binary(), - prepare_statement := prepares(), - params_tokens := params_tokens(), - batch_inserts := sqls(), - batch_params_tokens := params_tokens() + prepares := ok | {error, _}, + templates := #{{atom(), batch | prepstmt} => template()} }. %%===================================================================== @@ -154,13 +150,13 @@ on_query(InstId, {TypeOrKey, SQLOrKey, Params}, State) -> on_query( InstId, {TypeOrKey, SQLOrKey, Params, Timeout}, - #{pool_name := PoolName, prepare_statement := Prepares} = State + State ) -> MySqlFunction = mysql_function(TypeOrKey), {SQLOrKey2, Data} = proc_sql_params(TypeOrKey, SQLOrKey, Params, State), case on_sql_query(InstId, MySqlFunction, SQLOrKey2, Data, Timeout, State) of {error, not_prepared} -> - case maybe_prepare_sql(SQLOrKey2, Prepares, PoolName) of + case maybe_prepare_sql(SQLOrKey2, State) of ok -> ?tp( mysql_connector_on_query_prepared_sql, @@ -187,23 +183,27 @@ on_query( on_batch_query( InstId, - BatchReq, - #{batch_inserts := Inserts, batch_params_tokens := ParamsTokens} = State + BatchReq = [{Key, _} | _], + #{query_templates := Templates} = State ) -> - case hd(BatchReq) of - {Key, _} -> - case maps:get(Key, Inserts, undefined) of - undefined -> - {error, {unrecoverable_error, batch_select_not_implemented}}; - InsertSQL -> - Tokens = maps:get(Key, ParamsTokens), - on_batch_insert(InstId, BatchReq, InsertSQL, Tokens, State) - end; - Request -> - LogMeta = #{connector => InstId, first_request => Request, state => State}, - ?SLOG(error, LogMeta#{msg => "invalid request"}), - {error, {unrecoverable_error, invalid_request}} - end. + case maps:get({Key, batch}, Templates, undefined) of + undefined -> + {error, {unrecoverable_error, batch_select_not_implemented}}; + Template -> + on_batch_insert(InstId, BatchReq, Template, State) + end; +on_batch_query( + InstId, + BatchReq, + State +) -> + ?SLOG(error, #{ + msg => "invalid request", + connector => InstId, + request => BatchReq, + state => State + }), + {error, {unrecoverable_error, invalid_request}}. mysql_function(sql) -> query; @@ -222,8 +222,8 @@ on_get_status(_InstId, #{pool_name := PoolName} = State) -> {ok, NState} -> %% return new state with prepared statements {connected, NState}; - {error, {undefined_table, NState}} -> - {disconnected, NState, unhealthy_target}; + {error, undefined_table} -> + {disconnected, State, unhealthy_target}; {error, _Reason} -> %% do not log error, it is logged in prepare_sql_to_conn connecting @@ -238,8 +238,8 @@ do_get_status(Conn) -> do_check_prepares( #{ pool_name := PoolName, - prepare_statement := #{send_message := SQL} - } = State + templates := #{{send_message, prepstmt} := SQL} + } ) -> % it's already connected. Verify if target table still exists Workers = [Worker || {_WorkerName, Worker} <- ecpool:workers(PoolName)], @@ -250,7 +250,7 @@ do_check_prepares( {ok, Conn} -> case mysql:prepare(Conn, get_status, SQL) of {error, {1146, _, _}} -> - {error, {undefined_table, State}}; + {error, undefined_table}; {ok, Statement} -> mysql:unprepare(Conn, Statement); _ -> @@ -265,17 +265,14 @@ do_check_prepares( ok, Workers ); -do_check_prepares(#{prepare_statement := Statement}) when is_map(Statement) -> +do_check_prepares(#{prepares := ok}) -> ok; -do_check_prepares(State = #{pool_name := PoolName, prepare_statement := {error, Prepares}}) -> +do_check_prepares(#{prepares := {error, _}} = State) -> %% retry to prepare - case prepare_sql(Prepares, PoolName) of + case prepare_sql(State) of ok -> %% remove the error - {ok, State#{prepare_statement => Prepares}}; - {error, undefined_table} -> - %% indicate the error - {error, {undefined_table, State#{prepare_statement => {error, Prepares}}}}; + {ok, State#{prepares => ok}}; {error, Reason} -> {error, Reason} end. @@ -285,41 +282,44 @@ do_check_prepares(State = #{pool_name := PoolName, prepare_statement := {error, connect(Options) -> mysql:start_link(Options). -init_prepare(State = #{prepare_statement := Prepares, pool_name := PoolName}) -> - case maps:size(Prepares) of +init_prepare(State = #{query_templates := Templates}) -> + case maps:size(Templates) of 0 -> - State; + State#{prepares => ok}; _ -> - case prepare_sql(Prepares, PoolName) of + case prepare_sql(State) of ok -> - State; + State#{prepares => ok}; {error, Reason} -> - LogMeta = #{msg => <<"mysql_init_prepare_statement_failed">>, reason => Reason}, - ?SLOG(error, LogMeta), + ?SLOG(error, #{ + msg => <<"MySQL init prepare statement failed">>, + reason => Reason + }), %% mark the prepare_statement as failed - State#{prepare_statement => {error, Prepares}} + State#{prepares => {error, Reason}} end end. -maybe_prepare_sql(SQLOrKey, Prepares, PoolName) -> - case maps:is_key(SQLOrKey, Prepares) of - true -> prepare_sql(Prepares, PoolName); +maybe_prepare_sql(SQLOrKey, State = #{query_templates := Templates}) -> + case maps:is_key({SQLOrKey, prepstmt}, Templates) of + true -> prepare_sql(State); false -> {error, {unrecoverable_error, prepared_statement_invalid}} end. -prepare_sql(Prepares, PoolName) when is_map(Prepares) -> - prepare_sql(maps:to_list(Prepares), PoolName); -prepare_sql(Prepares, PoolName) -> - case do_prepare_sql(Prepares, PoolName) of +prepare_sql(#{query_templates := Templates, pool_name := PoolName}) -> + prepare_sql(maps:to_list(Templates), PoolName). + +prepare_sql(Templates, PoolName) -> + case do_prepare_sql(Templates, PoolName) of ok -> %% prepare for reconnect - ecpool:add_reconnect_callback(PoolName, {?MODULE, prepare_sql_to_conn, [Prepares]}), + ecpool:add_reconnect_callback(PoolName, {?MODULE, prepare_sql_to_conn, [Templates]}), ok; {error, R} -> {error, R} end. -do_prepare_sql(Prepares, PoolName) -> +do_prepare_sql(Templates, PoolName) -> Conns = [ begin @@ -328,33 +328,30 @@ do_prepare_sql(Prepares, PoolName) -> end || {_Name, Worker} <- ecpool:workers(PoolName) ], - prepare_sql_to_conn_list(Conns, Prepares). + prepare_sql_to_conn_list(Conns, Templates). -prepare_sql_to_conn_list([], _PrepareList) -> +prepare_sql_to_conn_list([], _Templates) -> ok; -prepare_sql_to_conn_list([Conn | ConnList], PrepareList) -> - case prepare_sql_to_conn(Conn, PrepareList) of +prepare_sql_to_conn_list([Conn | ConnList], Templates) -> + case prepare_sql_to_conn(Conn, Templates) of ok -> - prepare_sql_to_conn_list(ConnList, PrepareList); + prepare_sql_to_conn_list(ConnList, Templates); {error, R} -> %% rollback - Fun = fun({Key, _}) -> - _ = unprepare_sql_to_conn(Conn, Key), - ok - end, - lists:foreach(Fun, PrepareList), + _ = [unprepare_sql_to_conn(Conn, Template) || Template <- Templates], {error, R} end. -prepare_sql_to_conn(Conn, []) when is_pid(Conn) -> ok; -prepare_sql_to_conn(Conn, [{Key, SQL} | PrepareList]) when is_pid(Conn) -> - LogMeta = #{msg => "mysql_prepare_statement", name => Key, prepare_sql => SQL}, +prepare_sql_to_conn(_Conn, []) -> + ok; +prepare_sql_to_conn(Conn, [{{Key, prepstmt}, {SQL, _RowTemplate}} | Rest]) -> + LogMeta = #{msg => "MySQL Prepare Statement", name => Key, prepare_sql => SQL}, ?SLOG(info, LogMeta), _ = unprepare_sql_to_conn(Conn, Key), case mysql:prepare(Conn, Key, SQL) of {ok, _Key} -> ?SLOG(info, LogMeta#{result => success}), - prepare_sql_to_conn(Conn, PrepareList); + prepare_sql_to_conn(Conn, Rest); {error, {1146, _, _} = Reason} -> %% Target table is not created ?tp(mysql_undefined_table, #{}), @@ -365,84 +362,92 @@ prepare_sql_to_conn(Conn, [{Key, SQL} | PrepareList]) when is_pid(Conn) -> % syntax failures. Retrying syntax failures is not very productive. ?SLOG(error, LogMeta#{result => failed, reason => Reason}), {error, Reason} - end. + end; +prepare_sql_to_conn(Conn, [{_Key, _Template} | Rest]) -> + prepare_sql_to_conn(Conn, Rest). -unprepare_sql_to_conn(Conn, PrepareSqlKey) -> - mysql:unprepare(Conn, PrepareSqlKey). +unprepare_sql_to_conn(Conn, {{Key, prepstmt}, _}) -> + mysql:unprepare(Conn, Key); +unprepare_sql_to_conn(Conn, Key) when is_atom(Key) -> + mysql:unprepare(Conn, Key); +unprepare_sql_to_conn(_Conn, _) -> + ok. parse_prepare_sql(Config) -> - SQL = - case maps:get(prepare_statement, Config, undefined) of - undefined -> - case maps:get(sql, Config, undefined) of - undefined -> #{}; - Template -> #{send_message => Template} - end; - Any -> - Any + Queries = + case Config of + #{prepare_statement := Qs} -> + Qs; + #{sql := Query} -> + #{send_message => Query}; + _ -> + #{} end, - parse_prepare_sql(maps:to_list(SQL), #{}, #{}, #{}, #{}). + Templates = maps:fold(fun parse_prepare_sql/3, #{}, Queries), + #{query_templates => Templates}. -parse_prepare_sql([{Key, H} | _] = L, Prepares, Tokens, BatchInserts, BatchTks) -> - {PrepareSQL, ParamsTokens} = emqx_placeholder:preproc_sql(H), - parse_batch_prepare_sql( - L, Prepares#{Key => PrepareSQL}, Tokens#{Key => ParamsTokens}, BatchInserts, BatchTks - ); -parse_prepare_sql([], Prepares, Tokens, BatchInserts, BatchTks) -> - #{ - prepare_statement => Prepares, - params_tokens => Tokens, - batch_inserts => BatchInserts, - batch_params_tokens => BatchTks - }. +parse_prepare_sql(Key, Query, Acc) -> + Template = emqx_template_sql:parse_prepstmt(Query, #{parameters => '?'}), + AccNext = Acc#{{Key, prepstmt} => Template}, + parse_batch_sql(Key, Query, AccNext). -parse_batch_prepare_sql([{Key, H} | T], Prepares, Tokens, BatchInserts, BatchTks) -> - case emqx_utils_sql:get_statement_type(H) of - select -> - parse_prepare_sql(T, Prepares, Tokens, BatchInserts, BatchTks); +parse_batch_sql(Key, Query, Acc) -> + case emqx_utils_sql:get_statement_type(Query) of insert -> - case emqx_utils_sql:parse_insert(H) of - {ok, {InsertSQL, Params}} -> - ParamsTks = emqx_placeholder:preproc_tmpl(Params), - parse_prepare_sql( - T, - Prepares, - Tokens, - BatchInserts#{Key => InsertSQL}, - BatchTks#{Key => ParamsTks} - ); + case emqx_utils_sql:parse_insert(Query) of + {ok, {Insert, Params}} -> + RowTemplate = emqx_template_sql:parse(Params), + Acc#{{Key, batch} => {Insert, RowTemplate}}; {error, Reason} -> - ?SLOG(error, #{msg => "split_sql_failed", sql => H, reason => Reason}), - parse_prepare_sql(T, Prepares, Tokens, BatchInserts, BatchTks) + ?SLOG(error, #{ + msg => "parse insert sql statement failed", + sql => Query, + reason => Reason + }), + Acc end; - Type when is_atom(Type) -> - ?SLOG(error, #{msg => "detect_sql_type_unsupported", sql => H, type => Type}), - parse_prepare_sql(T, Prepares, Tokens, BatchInserts, BatchTks); - {error, Reason} -> - ?SLOG(error, #{msg => "detect_sql_type_failed", sql => H, reason => Reason}), - parse_prepare_sql(T, Prepares, Tokens, BatchInserts, BatchTks) + select -> + Acc; + Otherwise -> + ?SLOG(error, #{ + msg => "invalid sql statement type", + sql => Query, + type => Otherwise + }), + Acc end. proc_sql_params(query, SQLOrKey, Params, _State) -> {SQLOrKey, Params}; proc_sql_params(prepared_query, SQLOrKey, Params, _State) -> {SQLOrKey, Params}; -proc_sql_params(TypeOrKey, SQLOrData, Params, #{params_tokens := ParamsTokens}) -> - case maps:get(TypeOrKey, ParamsTokens, undefined) of +proc_sql_params(TypeOrKey, SQLOrData, Params, #{query_templates := Templates}) -> + case maps:get({TypeOrKey, prepstmt}, Templates, undefined) of undefined -> {SQLOrData, Params}; - Tokens -> - {TypeOrKey, emqx_placeholder:proc_sql(Tokens, SQLOrData)} + {_InsertPart, RowTemplate} -> + % NOTE + % Ignoring errors here, missing variables are set to `null`. + {Row, _Errors} = emqx_template_sql:render_prepstmt( + RowTemplate, + {emqx_jsonish, SQLOrData} + ), + {TypeOrKey, Row} end. -on_batch_insert(InstId, BatchReqs, InsertPart, Tokens, State) -> - ValuesPart = lists:join($,, [ - emqx_placeholder:proc_param_str(Tokens, Msg, fun emqx_placeholder:quote_mysql/1) - || {_, Msg} <- BatchReqs - ]), - Query = [InsertPart, <<" values ">> | ValuesPart], +on_batch_insert(InstId, BatchReqs, {InsertPart, RowTemplate}, State) -> + Rows = [render_row(RowTemplate, Msg) || {_, Msg} <- BatchReqs], + Query = [InsertPart, <<" values ">> | lists:join($,, Rows)], on_sql_query(InstId, query, Query, no_params, default_timeout, State). +render_row(RowTemplate, Data) -> + % NOTE + % Ignoring errors here, missing variables are set to "'undefined'" due to backward + % compatibility requirements. + RenderOpts = #{escaping => mysql, undefined => <<"undefined">>}, + {Row, _Errors} = emqx_template_sql:render(RowTemplate, {emqx_jsonish, Data}, RenderOpts), + Row. + on_sql_query( InstId, SQLFunc, diff --git a/apps/emqx_postgresql/src/emqx_postgresql.erl b/apps/emqx_postgresql/src/emqx_postgresql.erl index dc6447536..814d8a074 100644 --- a/apps/emqx_postgresql/src/emqx_postgresql.erl +++ b/apps/emqx_postgresql/src/emqx_postgresql.erl @@ -52,15 +52,12 @@ default_port => ?PGSQL_DEFAULT_PORT }). --type prepares() :: #{atom() => binary()}. --type params_tokens() :: #{atom() => list()}. - +-type template() :: {unicode:chardata(), emqx_template_sql:row_template()}. -type state() :: #{ pool_name := binary(), - prepare_sql := prepares(), - params_tokens := params_tokens(), - prepare_statement := epgsql:statement() + query_templates := #{binary() => template()}, + prepares := #{binary() => epgsql:statement()} | {error, _} }. %% FIXME: add `{error, sync_required}' to `epgsql:execute_batch' @@ -142,7 +139,7 @@ on_start( State = parse_prepare_sql(Config), case emqx_resource_pool:start(InstId, ?MODULE, Options ++ SslOpts) of ok -> - {ok, init_prepare(State#{pool_name => InstId, prepare_statement => #{}})}; + {ok, init_prepare(State#{pool_name => InstId, prepares => #{}})}; {error, Reason} -> ?tp( pgsql_connector_start_failed, @@ -189,55 +186,50 @@ pgsql_query_type(_) -> on_batch_query( InstId, - BatchReq, - #{pool_name := PoolName, params_tokens := Tokens, prepare_statement := Sts} = State + [{Key, _} = Request | _] = BatchReq, + #{pool_name := PoolName, query_templates := Templates, prepares := PrepStatements} = State ) -> - case BatchReq of - [{Key, _} = Request | _] -> - BinKey = to_bin(Key), - case maps:get(BinKey, Tokens, undefined) of - undefined -> - Log = #{ - connector => InstId, - first_request => Request, - state => State, - msg => "batch_prepare_not_implemented" - }, - ?SLOG(error, Log), - {error, {unrecoverable_error, batch_prepare_not_implemented}}; - TokenList -> - {_, Datas} = lists:unzip(BatchReq), - Datas2 = [emqx_placeholder:proc_sql(TokenList, Data) || Data <- Datas], - St = maps:get(BinKey, Sts), - case on_sql_query(InstId, PoolName, execute_batch, St, Datas2) of - {error, _Error} = Result -> - handle_result(Result); - {_Column, Results} -> - handle_batch_result(Results, 0) - end - end; - _ -> + BinKey = to_bin(Key), + case maps:get(BinKey, Templates, undefined) of + undefined -> Log = #{ connector => InstId, - request => BatchReq, + first_request => Request, state => State, - msg => "invalid_request" + msg => "batch prepare not implemented" }, ?SLOG(error, Log), - {error, {unrecoverable_error, invalid_request}} - end. + {error, {unrecoverable_error, batch_prepare_not_implemented}}; + {_Statement, RowTemplate} -> + PrepStatement = maps:get(BinKey, PrepStatements), + Rows = [render_prepare_sql_row(RowTemplate, Data) || {_Key, Data} <- BatchReq], + case on_sql_query(InstId, PoolName, execute_batch, PrepStatement, Rows) of + {error, _Error} = Result -> + handle_result(Result); + {_Column, Results} -> + handle_batch_result(Results, 0) + end + end; +on_batch_query(InstId, BatchReq, State) -> + ?SLOG(error, #{ + connector => InstId, + request => BatchReq, + state => State, + msg => "invalid request" + }), + {error, {unrecoverable_error, invalid_request}}. proc_sql_params(query, SQLOrKey, Params, _State) -> {SQLOrKey, Params}; proc_sql_params(prepared_query, SQLOrKey, Params, _State) -> {SQLOrKey, Params}; -proc_sql_params(TypeOrKey, SQLOrData, Params, #{params_tokens := ParamsTokens}) -> +proc_sql_params(TypeOrKey, SQLOrData, Params, #{query_templates := Templates}) -> Key = to_bin(TypeOrKey), - case maps:get(Key, ParamsTokens, undefined) of + case maps:get(Key, Templates, undefined) of undefined -> {SQLOrData, Params}; - Tokens -> - {Key, emqx_placeholder:proc_sql(Tokens, SQLOrData)} + {_Statement, RowTemplate} -> + {Key, render_prepare_sql_row(RowTemplate, SQLOrData)} end. on_sql_query(InstId, PoolName, Type, NameOrSQL, Data) -> @@ -297,9 +289,9 @@ on_get_status(_InstId, #{pool_name := PoolName} = State) -> {ok, NState} -> %% return new state with prepared statements {connected, NState}; - {error, {undefined_table, NState}} -> + {error, undefined_table} -> %% return new state indicating that we are connected but the target table is not created - {disconnected, NState, unhealthy_target}; + {disconnected, State, unhealthy_target}; {error, _Reason} -> %% do not log error, it is logged in prepare_sql_to_conn connecting @@ -314,29 +306,26 @@ do_get_status(Conn) -> do_check_prepares( #{ pool_name := PoolName, - prepare_sql := #{<<"send_message">> := SQL} - } = State + query_templates := #{<<"send_message">> := {SQL, _RowTemplate}} + } ) -> WorkerPids = [Worker || {_WorkerName, Worker} <- ecpool:workers(PoolName)], case validate_table_existence(WorkerPids, SQL) of ok -> ok; - {error, undefined_table} -> - {error, {undefined_table, State}} + {error, Reason} -> + {error, Reason} end; -do_check_prepares(#{prepare_sql := Prepares}) when is_map(Prepares) -> +do_check_prepares(#{prepares := Prepares}) when is_map(Prepares) -> ok; -do_check_prepares(State = #{pool_name := PoolName, prepare_sql := {error, Prepares}}) -> +do_check_prepares(#{prepares := {error, _}} = State) -> %% retry to prepare - case prepare_sql(Prepares, PoolName) of - {ok, Sts} -> + case prepare_sql(State) of + {ok, PrepStatements} -> %% remove the error - {ok, State#{prepare_sql => Prepares, prepare_statement := Sts}}; - {error, undefined_table} -> - %% indicate the error - {error, {undefined_table, State#{prepare_sql => {error, Prepares}}}}; - Error -> - {error, Error} + {ok, State#{prepares := PrepStatements}}; + {error, Reason} -> + {error, Reason} end. -spec validate_table_existence([pid()], binary()) -> ok | {error, undefined_table}. @@ -426,69 +415,66 @@ conn_opts([_Opt | Opts], Acc) -> conn_opts(Opts, Acc). parse_prepare_sql(Config) -> - SQL = - case maps:get(prepare_statement, Config, undefined) of - undefined -> - case maps:get(sql, Config, undefined) of - undefined -> #{}; - Template -> #{<<"send_message">> => Template} - end; - Any -> - Any + Queries = + case Config of + #{prepare_statement := Qs} -> + Qs; + #{sql := Query} -> + #{<<"send_message">> => Query}; + #{} -> + #{} end, - parse_prepare_sql(maps:to_list(SQL), #{}, #{}). + Templates = maps:fold(fun parse_prepare_sql/3, #{}, Queries), + #{query_templates => Templates}. -parse_prepare_sql([{Key, H} | T], Prepares, Tokens) -> - {PrepareSQL, ParamsTokens} = emqx_placeholder:preproc_sql(H, '$n'), - parse_prepare_sql( - T, Prepares#{Key => PrepareSQL}, Tokens#{Key => ParamsTokens} - ); -parse_prepare_sql([], Prepares, Tokens) -> - #{ - prepare_sql => Prepares, - params_tokens => Tokens - }. +parse_prepare_sql(Key, Query, Acc) -> + Template = emqx_template_sql:parse_prepstmt(Query, #{parameters => '$n'}), + Acc#{Key => Template}. -init_prepare(State = #{prepare_sql := Prepares, pool_name := PoolName}) -> - case maps:size(Prepares) of - 0 -> - State; - _ -> - case prepare_sql(Prepares, PoolName) of - {ok, Sts} -> - State#{prepare_statement := Sts}; - Error -> - LogMsg = - maps:merge( - #{msg => <<"postgresql_init_prepare_statement_failed">>}, - translate_to_log_context(Error) - ), - ?SLOG(error, LogMsg), - %% mark the prepare_sql as failed - State#{prepare_sql => {error, Prepares}} - end +render_prepare_sql_row(RowTemplate, Data) -> + % NOTE: ignoring errors here, missing variables will be replaced with `null`. + {Row, _Errors} = emqx_template_sql:render_prepstmt(RowTemplate, {emqx_jsonish, Data}), + Row. + +init_prepare(State = #{query_templates := Templates}) when map_size(Templates) == 0 -> + State; +init_prepare(State = #{}) -> + case prepare_sql(State) of + {ok, PrepStatements} -> + State#{prepares => PrepStatements}; + Error -> + ?SLOG( + error, + maps:merge( + #{msg => <<"postgresql_init_prepare_statement_failed">>}, + translate_to_log_context(Error) + ) + ), + %% mark the prepares failed + State#{prepares => Error} end. -prepare_sql(Prepares, PoolName) when is_map(Prepares) -> - prepare_sql(maps:to_list(Prepares), PoolName); -prepare_sql(Prepares, PoolName) -> - case do_prepare_sql(Prepares, PoolName) of +prepare_sql(#{query_templates := Templates, pool_name := PoolName}) -> + prepare_sql(maps:to_list(Templates), PoolName). + +prepare_sql(Templates, PoolName) -> + case do_prepare_sql(Templates, PoolName) of {ok, _Sts} = Ok -> %% prepare for reconnect - ecpool:add_reconnect_callback(PoolName, {?MODULE, prepare_sql_to_conn, [Prepares]}), + ecpool:add_reconnect_callback(PoolName, {?MODULE, prepare_sql_to_conn, [Templates]}), Ok; Error -> Error end. -do_prepare_sql(Prepares, PoolName) -> - do_prepare_sql(ecpool:workers(PoolName), Prepares, #{}). +do_prepare_sql(Templates, PoolName) -> + do_prepare_sql(ecpool:workers(PoolName), Templates, #{}). -do_prepare_sql([{_Name, Worker} | T], Prepares, _LastSts) -> +do_prepare_sql([{_Name, Worker} | Rest], Templates, _LastSts) -> {ok, Conn} = ecpool_worker:client(Worker), - case prepare_sql_to_conn(Conn, Prepares) of + case prepare_sql_to_conn(Conn, Templates) of {ok, Sts} -> - do_prepare_sql(T, Prepares, Sts); + do_prepare_sql(Rest, Templates, Sts); Error -> Error end; @@ -498,13 +484,14 @@ do_prepare_sql([], _Prepares, LastSts) -> prepare_sql_to_conn(Conn, Prepares) -> prepare_sql_to_conn(Conn, Prepares, #{}). -prepare_sql_to_conn(Conn, [], Statements) when is_pid(Conn) -> {ok, Statements}; -prepare_sql_to_conn(Conn, [{Key, SQL} | PrepareList], Statements) when is_pid(Conn) -> - LogMeta = #{msg => "postgresql_prepare_statement", name => Key, prepare_sql => SQL}, +prepare_sql_to_conn(Conn, [], Statements) when is_pid(Conn) -> + {ok, Statements}; +prepare_sql_to_conn(Conn, [{Key, {SQL, _RowTemplate}} | Rest], Statements) when is_pid(Conn) -> + LogMeta = #{msg => "postgresql_prepare_statement", name => Key, sql => SQL}, ?SLOG(info, LogMeta), case epgsql:parse2(Conn, Key, SQL, []) of {ok, Statement} -> - prepare_sql_to_conn(Conn, PrepareList, Statements#{Key => Statement}); + prepare_sql_to_conn(Conn, Rest, Statements#{Key => Statement}); {error, {error, error, _, undefined_table, _, _} = Error} -> %% Target table is not created ?tp(pgsql_undefined_table, #{}), diff --git a/apps/emqx_prometheus/src/emqx_prometheus.app.src b/apps/emqx_prometheus/src/emqx_prometheus.app.src index c4abbec27..4631fec8b 100644 --- a/apps/emqx_prometheus/src/emqx_prometheus.app.src +++ b/apps/emqx_prometheus/src/emqx_prometheus.app.src @@ -2,7 +2,7 @@ {application, emqx_prometheus, [ {description, "Prometheus for EMQX"}, % strict semver, bump manually! - {vsn, "5.0.16"}, + {vsn, "5.0.17"}, {modules, []}, {registered, [emqx_prometheus_sup]}, {applications, [kernel, stdlib, prometheus, emqx, emqx_management]}, diff --git a/apps/emqx_prometheus/src/emqx_prometheus.erl b/apps/emqx_prometheus/src/emqx_prometheus.erl index e9030d3ed..a242931c4 100644 --- a/apps/emqx_prometheus/src/emqx_prometheus.erl +++ b/apps/emqx_prometheus/src/emqx_prometheus.erl @@ -24,7 +24,6 @@ -include("emqx_prometheus.hrl"). --include_lib("prometheus/include/prometheus.hrl"). -include_lib("prometheus/include/prometheus_model.hrl"). -include_lib("emqx/include/logger.hrl"). @@ -114,16 +113,20 @@ handle_info(_Msg, State) -> push_to_push_gateway(Uri, Headers, JobName) when is_list(Headers) -> [Name, Ip] = string:tokens(atom_to_list(node()), "@"), - JobName1 = emqx_placeholder:preproc_tmpl(JobName), - JobName2 = binary_to_list( - emqx_placeholder:proc_tmpl( - JobName1, - #{<<"name">> => Name, <<"host">> => Ip} - ) + % NOTE: allowing errors here to keep rough backward compatibility + {JobName1, Errors} = emqx_template:render( + emqx_template:parse(JobName), + #{<<"name">> => Name, <<"host">> => Ip} ), - - Url = lists:concat([Uri, "/metrics/job/", JobName2]), + _ = + Errors == [] orelse + ?SLOG(warning, #{ + msg => "prometheus_job_name_template_invalid", + errors => Errors, + template => JobName + }), Data = prometheus_text_format:format(), + Url = lists:concat([Uri, "/metrics/job/", unicode:characters_to_list(JobName1)]), case httpc:request(post, {Url, Headers, "text/plain", Data}, ?HTTP_OPTIONS, []) of {ok, {{"HTTP/1.1", 200, _}, _RespHeaders, _RespBody}} -> ok; diff --git a/apps/emqx_rule_engine/src/emqx_rule_actions.erl b/apps/emqx_rule_engine/src/emqx_rule_actions.erl index f136cd5df..cd8d597de 100644 --- a/apps/emqx_rule_engine/src/emqx_rule_actions.erl +++ b/apps/emqx_rule_engine/src/emqx_rule_actions.erl @@ -82,23 +82,18 @@ pre_process_action_args( qos := QoS, retain := Retain, payload := Payload, - mqtt_properties := MQTTPropertiesTemplate0, - user_properties := UserPropertiesTemplate + mqtt_properties := MQTTProperties, + user_properties := UserProperties } = Args ) -> - MQTTPropertiesTemplate = - maps:map( - fun(_Key, V) -> emqx_placeholder:preproc_tmpl(V) end, - MQTTPropertiesTemplate0 - ), Args#{ preprocessed_tmpl => #{ - topic => emqx_placeholder:preproc_tmpl(Topic), - qos => preproc_vars(QoS), - retain => preproc_vars(Retain), - payload => emqx_placeholder:preproc_tmpl(Payload), - mqtt_properties => MQTTPropertiesTemplate, - user_properties => preproc_user_properties(UserPropertiesTemplate) + topic => emqx_template:parse(Topic), + qos => parse_simple_var(QoS), + retain => parse_simple_var(Retain), + payload => parse_payload(Payload), + mqtt_properties => parse_mqtt_properties(MQTTProperties), + user_properties => parse_user_properties(UserProperties) } }; pre_process_action_args(_, Args) -> @@ -131,25 +126,28 @@ republish( #{metadata := #{rule_id := RuleId}} = Env, #{ preprocessed_tmpl := #{ - qos := QoSTks, - retain := RetainTks, - topic := TopicTks, - payload := PayloadTks, + qos := QoSTemplate, + retain := RetainTemplate, + topic := TopicTemplate, + payload := PayloadTemplate, mqtt_properties := MQTTPropertiesTemplate, - user_properties := UserPropertiesTks + user_properties := UserPropertiesTemplate } } ) -> - Topic = emqx_placeholder:proc_tmpl(TopicTks, Selected), - Payload = format_msg(PayloadTks, Selected), - QoS = replace_simple_var(QoSTks, Selected, 0), - Retain = replace_simple_var(RetainTks, Selected, false), + % NOTE: rendering missing bindings as string "undefined" + {TopicString, _Errors1} = render_template(TopicTemplate, Selected), + {PayloadString, _Errors2} = render_template(PayloadTemplate, Selected), + Topic = iolist_to_binary(TopicString), + Payload = iolist_to_binary(PayloadString), + QoS = render_simple_var(QoSTemplate, Selected, 0), + Retain = render_simple_var(RetainTemplate, Selected, false), %% 'flags' is set for message re-publishes or message related %% events such as message.acked and message.dropped Flags0 = maps:get(flags, Env, #{}), Flags = Flags0#{retain => Retain}, - PubProps0 = format_pub_props(UserPropertiesTks, Selected, Env), - MQTTProps = format_mqtt_properties(MQTTPropertiesTemplate, Selected, Env), + PubProps0 = render_pub_props(UserPropertiesTemplate, Selected, Env), + MQTTProps = render_mqtt_properties(MQTTPropertiesTemplate, Selected, Env), PubProps = maps:merge(PubProps0, MQTTProps), ?TRACE( "RULE", @@ -220,79 +218,90 @@ safe_publish(RuleId, Topic, QoS, Flags, Payload, PubProps) -> _ = emqx_broker:safe_publish(Msg), emqx_metrics:inc_msg(Msg). -preproc_vars(Data) when is_binary(Data) -> - emqx_placeholder:preproc_tmpl(Data); -preproc_vars(Data) -> - Data. +parse_simple_var(Data) when is_binary(Data) -> + emqx_template:parse(Data); +parse_simple_var(Data) -> + {const, Data}. -preproc_user_properties(<<"${pub_props.'User-Property'}">>) -> +parse_payload(Payload) -> + case string:is_empty(Payload) of + false -> emqx_template:parse(Payload); + true -> emqx_template:parse("${.}") + end. + +parse_mqtt_properties(MQTTPropertiesTemplate) -> + maps:map( + fun(_Key, V) -> emqx_template:parse(V) end, + MQTTPropertiesTemplate + ). + +parse_user_properties(<<"${pub_props.'User-Property'}">>) -> %% keep the original %% avoid processing this special variable because %% we do not want to force users to select the value %% the value will be taken from Env.pub_props directly ?ORIGINAL_USER_PROPERTIES; -preproc_user_properties(<<"${", _/binary>> = V) -> +parse_user_properties(<<"${", _/binary>> = V) -> %% use a variable - emqx_placeholder:preproc_tmpl(V); -preproc_user_properties(_) -> + emqx_template:parse(V); +parse_user_properties(_) -> %% invalid, discard undefined. -replace_simple_var(Tokens, Data, Default) when is_list(Tokens) -> - [Var] = emqx_placeholder:proc_tmpl(Tokens, Data, #{return => rawlist}), - case Var of +render_template(Template, Bindings) -> + emqx_template:render(Template, {emqx_jsonish, Bindings}). + +render_simple_var([{var, _Name, Accessor}], Data, Default) -> + case emqx_jsonish:lookup(Accessor, Data) of + {ok, Var} -> Var; %% cannot find the variable from Data - undefined -> Default; - _ -> Var + {error, _} -> Default end; -replace_simple_var(Val, _Data, _Default) -> +render_simple_var({const, Val}, _Data, _Default) -> Val. -format_msg([], Selected) -> - emqx_utils_json:encode(Selected); -format_msg(Tokens, Selected) -> - emqx_placeholder:proc_tmpl(Tokens, Selected). - -format_pub_props(UserPropertiesTks, Selected, Env) -> +render_pub_props(UserPropertiesTemplate, Selected, Env) -> UserProperties = - case UserPropertiesTks of + case UserPropertiesTemplate of ?ORIGINAL_USER_PROPERTIES -> maps:get('User-Property', maps:get(pub_props, Env, #{}), #{}); undefined -> #{}; _ -> - replace_simple_var(UserPropertiesTks, Selected, #{}) + render_simple_var(UserPropertiesTemplate, Selected, #{}) end, #{'User-Property' => UserProperties}. -format_mqtt_properties(MQTTPropertiesTemplate, Selected, Env) -> - #{metadata := #{rule_id := RuleId}} = Env, - MQTTProperties0 = - maps:fold( - fun(K, Template, Acc) -> - try - V = emqx_placeholder:proc_tmpl(Template, Selected), - Acc#{K => V} - catch - Kind:Error -> - ?SLOG( - debug, - #{ - msg => "bad_mqtt_property_value_ignored", - rule_id => RuleId, - exception => Kind, - reason => Error, - property => K, - selected => Selected - } - ), - Acc - end +%% + +-define(BADPROP(K, REASON, ENV, DATA), + ?SLOG( + debug, + DATA#{ + msg => "bad_mqtt_property_value_ignored", + rule_id => emqx_utils_maps:deep_get([metadata, rule_id], ENV, undefined), + reason => REASON, + property => K + } + ) +). + +render_mqtt_properties(MQTTPropertiesTemplate, Selected, Env) -> + MQTTProperties = + maps:map( + fun(K, Template) -> + {V, Errors} = render_template(Template, Selected), + case Errors of + [] -> + ok; + Errors -> + ?BADPROP(K, Errors, Env, #{selected => Selected}) + end, + iolist_to_binary(V) end, - #{}, MQTTPropertiesTemplate ), - coerce_properties_values(MQTTProperties0, Env). + coerce_properties_values(MQTTProperties, Env). ensure_int(B) when is_binary(B) -> try @@ -304,42 +313,24 @@ ensure_int(B) when is_binary(B) -> ensure_int(I) when is_integer(I) -> I. -coerce_properties_values(MQTTProperties, #{metadata := #{rule_id := RuleId}}) -> - maps:fold( - fun(K, V0, Acc) -> +coerce_properties_values(MQTTProperties, Env) -> + maps:filtermap( + fun(K, V) -> try - V = encode_mqtt_property(K, V0), - Acc#{K => V} + {true, encode_mqtt_property(K, V)} catch - throw:bad_integer -> - ?SLOG( - debug, - #{ - msg => "bad_mqtt_property_value_ignored", - rule_id => RuleId, - reason => bad_integer, - property => K, - value => V0 - } - ), - Acc; + throw:Reason -> + ?BADPROP(K, Reason, Env, #{value => V}), + false; Kind:Reason:Stacktrace -> - ?SLOG( - debug, - #{ - msg => "bad_mqtt_property_value_ignored", - rule_id => RuleId, - exception => Kind, - reason => Reason, - property => K, - value => V0, - stacktrace => Stacktrace - } - ), - Acc + ?BADPROP(K, Reason, Env, #{ + value => V, + exception => Kind, + stacktrace => Stacktrace + }), + false end end, - #{}, MQTTProperties ). diff --git a/apps/emqx_rule_engine/src/emqx_rule_engine.app.src b/apps/emqx_rule_engine/src/emqx_rule_engine.app.src index c353742ae..cad752886 100644 --- a/apps/emqx_rule_engine/src/emqx_rule_engine.app.src +++ b/apps/emqx_rule_engine/src/emqx_rule_engine.app.src @@ -5,7 +5,16 @@ {vsn, "5.0.28"}, {modules, []}, {registered, [emqx_rule_engine_sup, emqx_rule_engine]}, - {applications, [kernel, stdlib, rulesql, getopt, emqx_ctl, uuid]}, + {applications, [ + kernel, + stdlib, + rulesql, + getopt, + uuid, + emqx, + emqx_utils, + emqx_ctl + ]}, {mod, {emqx_rule_engine_app, []}}, {env, []}, {licenses, ["Apache-2.0"]}, diff --git a/apps/emqx_rule_engine/test/emqx_rule_engine_SUITE.erl b/apps/emqx_rule_engine/test/emqx_rule_engine_SUITE.erl index 00ca68264..41fec48ee 100644 --- a/apps/emqx_rule_engine/test/emqx_rule_engine_SUITE.erl +++ b/apps/emqx_rule_engine/test/emqx_rule_engine_SUITE.erl @@ -81,6 +81,7 @@ groups() -> t_sqlselect_3, t_sqlselect_message_publish_event_keep_original_props_1, t_sqlselect_message_publish_event_keep_original_props_2, + t_sqlselect_missing_template_vars_render_as_undefined, t_sqlparse_event_1, t_sqlparse_event_2, t_sqlparse_event_3, @@ -1364,14 +1365,13 @@ t_sqlselect_inject_props(_Config) -> actions => [Repub] } ), - Props = user_properties(#{<<"inject_key">> => <<"inject_val">>}), {ok, Client} = emqtt:start_link([{username, <<"emqx">>}, {proto_ver, v5}]), {ok, _} = emqtt:connect(Client), {ok, _, _} = emqtt:subscribe(Client, <<"t2">>, 0), emqtt:publish(Client, <<"t1">>, #{}, <<"{\"x\":1}">>, [{qos, 0}]), receive - {publish, #{topic := T, payload := Payload, properties := Props2}} -> - ?assertEqual(Props, Props2), + {publish, #{topic := T, payload := Payload, properties := Props}} -> + ?assertEqual(user_properties(#{<<"inject_key">> => <<"inject_val">>}), Props), ?assertEqual(<<"t2">>, T), ?assertEqual(<<"{\"x\":1}">>, Payload) after 2000 -> @@ -1947,6 +1947,32 @@ t_sqlselect_as_put(_Config) -> PayloadMap2 ). +t_sqlselect_missing_template_vars_render_as_undefined(_Config) -> + SQL = <<"SELECT * FROM \"$events/client_connected\"">>, + Repub = republish_action(<<"t2">>, <<"${clientid}:${missing.var}">>), + {ok, TopicRule} = emqx_rule_engine:create_rule( + #{ + sql => SQL, + id => ?TMP_RULEID, + actions => [Repub] + } + ), + {ok, Client1} = emqtt:start_link([{clientid, <<"sub-01">>}]), + {ok, _} = emqtt:connect(Client1), + {ok, _, _} = emqtt:subscribe(Client1, <<"t2">>), + {ok, Client2} = emqtt:start_link([{clientid, <<"pub-02">>}]), + {ok, _} = emqtt:connect(Client2), + emqtt:publish(Client2, <<"foo/bar/1">>, <<>>), + receive + {publish, Msg} -> + ?assertMatch(#{topic := <<"t2">>, payload := <<"pub-02:undefined">>}, Msg) + after 2000 -> + ct:fail(wait_for_t2) + end, + emqtt:stop(Client2), + emqtt:stop(Client1), + delete_rule(TopicRule). + t_sqlparse_event_1(_Config) -> Sql = "select topic as tp " diff --git a/apps/emqx_utils/src/emqx_jsonish.erl b/apps/emqx_utils/src/emqx_jsonish.erl new file mode 100644 index 000000000..ef26da1d8 --- /dev/null +++ b/apps/emqx_utils/src/emqx_jsonish.erl @@ -0,0 +1,72 @@ +%%-------------------------------------------------------------------- +%% Copyright (c) 2020-2022 EMQ Technologies Co., Ltd. All Rights Reserved. +%% +%% Licensed under the Apache License, Version 2.0 (the "License"); +%% you may not use this file except in compliance with the License. +%% You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, software +%% distributed under the License is distributed on an "AS IS" BASIS, +%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +%% See the License for the specific language governing permissions and +%% limitations under the License. +%%-------------------------------------------------------------------- + +-module(emqx_jsonish). + +-behaviour(emqx_template). +-export([lookup/2]). + +-export_type([t/0]). + +%% @doc Either a map or a JSON serial. +%% Think of it as a kind of lazily parsed and/or constructed JSON. +-type t() :: propmap() | serial(). + +%% @doc JSON in serialized form. +-type serial() :: binary(). + +-type propmap() :: #{prop() => value()}. +-type prop() :: atom() | binary(). +-type value() :: scalar() | [scalar() | propmap()] | t(). +-type scalar() :: atom() | unicode:chardata() | number(). + +%% + +%% @doc Lookup a value in the JSON-ish map accessible through the given accessor. +%% If accessor implies drilling down into a binary, it will be treated as JSON serial. +%% Failure to parse the binary as JSON will result in an _invalid type_ error. +%% Nested JSON is NOT parsed recursively. +-spec lookup(emqx_template:accessor(), t()) -> + {ok, value()} + | {error, undefined | {_Location :: non_neg_integer(), _InvalidType :: atom()}}. +lookup(Var, Jsonish) -> + lookup(0, _Decoded = false, Var, Jsonish). + +lookup(_, _, [], Value) -> + {ok, Value}; +lookup(Loc, Decoded, [Prop | Rest], Jsonish) when is_map(Jsonish) -> + case emqx_template:lookup(Prop, Jsonish) of + {ok, Value} -> + lookup(Loc + 1, Decoded, Rest, Value); + {error, Reason} -> + {error, Reason} + end; +lookup(Loc, _Decoded = false, Props, Json) when is_binary(Json) -> + try emqx_utils_json:decode(Json) of + Value -> + % NOTE: This is intentional, we don't want to parse nested JSON. + lookup(Loc, true, Props, Value) + catch + error:_ -> + {error, {Loc, binary}} + end; +lookup(Loc, _, _, Invalid) -> + {error, {Loc, type_name(Invalid)}}. + +type_name(Term) when is_atom(Term) -> atom; +type_name(Term) when is_number(Term) -> number; +type_name(Term) when is_binary(Term) -> binary; +type_name(Term) when is_list(Term) -> list. diff --git a/apps/emqx_utils/src/emqx_placeholder.erl b/apps/emqx_utils/src/emqx_placeholder.erl index 4d386840f..90df6003b 100644 --- a/apps/emqx_utils/src/emqx_placeholder.erl +++ b/apps/emqx_utils/src/emqx_placeholder.erl @@ -249,15 +249,15 @@ bin(Val) -> emqx_utils_conv:bin(Val). -spec quote_sql(_Value) -> iolist(). quote_sql(Str) -> - emqx_utils_sql:to_sql_string(Str, #{escaping => sql}). + emqx_utils_sql:to_sql_string(Str, #{escaping => sql, undefined => <<"undefined">>}). -spec quote_cql(_Value) -> iolist(). quote_cql(Str) -> - emqx_utils_sql:to_sql_string(Str, #{escaping => cql}). + emqx_utils_sql:to_sql_string(Str, #{escaping => cql, undefined => <<"undefined">>}). -spec quote_mysql(_Value) -> iolist(). quote_mysql(Str) -> - emqx_utils_sql:to_sql_string(Str, #{escaping => mysql}). + emqx_utils_sql:to_sql_string(Str, #{escaping => mysql, undefined => <<"undefined">>}). lookup_var(Var, Value) when Var == ?PH_VAR_THIS orelse Var == [] -> Value; diff --git a/apps/emqx_utils/src/emqx_template.erl b/apps/emqx_utils/src/emqx_template.erl new file mode 100644 index 000000000..ac330becf --- /dev/null +++ b/apps/emqx_utils/src/emqx_template.erl @@ -0,0 +1,386 @@ +%%-------------------------------------------------------------------- +%% Copyright (c) 2020-2022 EMQ Technologies Co., Ltd. All Rights Reserved. +%% +%% Licensed under the Apache License, Version 2.0 (the "License"); +%% you may not use this file except in compliance with the License. +%% You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, software +%% distributed under the License is distributed on an "AS IS" BASIS, +%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +%% See the License for the specific language governing permissions and +%% limitations under the License. +%%-------------------------------------------------------------------- + +-module(emqx_template). + +-export([parse/1]). +-export([parse/2]). +-export([parse_deep/1]). +-export([parse_deep/2]). +-export([validate/2]). +-export([is_const/1]). +-export([unparse/1]). +-export([render/2]). +-export([render/3]). +-export([render_strict/2]). +-export([render_strict/3]). + +-export([lookup_var/2]). +-export([lookup/2]). + +-export([to_string/1]). + +-export_type([t/0]). +-export_type([str/0]). +-export_type([deep/0]). +-export_type([placeholder/0]). +-export_type([varname/0]). +-export_type([bindings/0]). +-export_type([accessor/0]). + +-export_type([context/0]). +-export_type([render_opts/0]). + +-type t() :: str() | {'$tpl', deeptpl()}. + +-type str() :: [iodata() | byte() | placeholder()]. +-type deep() :: {'$tpl', deeptpl()}. + +-type deeptpl() :: + t() + | #{deeptpl() => deeptpl()} + | {list, [deeptpl()]} + | {tuple, [deeptpl()]} + | scalar() + | function() + | pid() + | port() + | reference(). + +-type placeholder() :: {var, varname(), accessor()}. +-type accessor() :: [binary()]. +-type varname() :: string(). + +-type scalar() :: atom() | unicode:chardata() | number(). +-type binding() :: scalar() | list(scalar()) | bindings(). +-type bindings() :: #{atom() | binary() => binding()}. + +-type reason() :: undefined | {location(), _InvalidType :: atom()}. +-type location() :: non_neg_integer(). + +-type var_trans() :: + fun((Value :: term()) -> unicode:chardata()) + | fun((varname(), Value :: term()) -> unicode:chardata()). + +-type parse_opts() :: #{ + strip_double_quote => boolean() +}. + +-type render_opts() :: #{ + var_trans => var_trans() +}. + +-type context() :: + %% Map with (potentially nested) bindings. + bindings() + %% Arbitrary term accessible via an access module with `lookup/2` function. + | {_AccessModule :: module(), _Bindings}. + +%% Access module API +-callback lookup(accessor(), _Bindings) -> {ok, _Value} | {error, reason()}. + +-define(RE_PLACEHOLDER, "\\$\\{[.]?([a-zA-Z0-9._]*)\\}"). +-define(RE_ESCAPE, "\\$\\{(\\$)\\}"). + +%% @doc Parse a unicode string into a template. +%% String might contain zero or more of placeholders in the form of `${var}`, +%% where `var` is a _location_ (possibly deeply nested) of some value in the +%% bindings map. +%% String might contain special escaped form `$${...}` which interpreted as a +%% literal `${...}`. +-spec parse(String :: unicode:chardata()) -> + t(). +parse(String) -> + parse(String, #{}). + +-spec parse(String :: unicode:chardata(), parse_opts()) -> + t(). +parse(String, Opts) -> + RE = + case Opts of + #{strip_double_quote := true} -> + <<"((?|" ?RE_PLACEHOLDER "|\"" ?RE_PLACEHOLDER "\")|" ?RE_ESCAPE ")">>; + #{} -> + <<"(" ?RE_PLACEHOLDER "|" ?RE_ESCAPE ")">> + end, + Splits = re:split(String, RE, [{return, binary}, group, trim, unicode]), + lists:flatmap(fun parse_split/1, Splits). + +parse_split([Part, _PH, Var, <<>>]) -> + % Regular placeholder + prepend(Part, [{var, unicode:characters_to_list(Var), parse_accessor(Var)}]); +parse_split([Part, _Escape, <<>>, <<"$">>]) -> + % Escaped literal `$`. + % Use single char as token so the `unparse/1` function can distinguish escaped `$`. + prepend(Part, [$$]); +parse_split([Tail]) -> + [Tail]. + +prepend(<<>>, To) -> + To; +prepend(Head, To) -> + [Head | To]. + +parse_accessor(Var) -> + case string:split(Var, <<".">>, all) of + [<<>>] -> + []; + Name -> + Name + end. + +%% @doc Validate a template against a set of allowed variables. +%% If the given template contains any variable not in the allowed set, an error +%% is returned. +-spec validate([varname()], t()) -> + ok | {error, [_Error :: {varname(), disallowed}]}. +validate(Allowed, Template) -> + {_, Errors} = render(Template, #{}), + {Used, _} = lists:unzip(Errors), + case lists:usort(Used) -- Allowed of + [] -> + ok; + Disallowed -> + {error, [{Var, disallowed} || Var <- Disallowed]} + end. + +%% @doc Check if a template is constant with respect to rendering, i.e. does not +%% contain any placeholders. +-spec is_const(t()) -> + boolean(). +is_const(Template) -> + validate([], Template) == ok. + +%% @doc Restore original term from a parsed template. +-spec unparse(t()) -> + term(). +unparse({'$tpl', Template}) -> + unparse_deep(Template); +unparse(Template) -> + unicode:characters_to_list(lists:map(fun unparse_part/1, Template)). + +unparse_part({var, Name, _Accessor}) -> + render_placeholder(Name); +unparse_part($$) -> + <<"${$}">>; +unparse_part(Part) -> + Part. + +render_placeholder(Name) -> + "${" ++ Name ++ "}". + +%% @doc Render a template with given bindings. +%% Returns a term with all placeholders replaced with values from bindings. +%% If one or more placeholders are not found in bindings, an error is returned. +%% By default, all binding values are converted to strings using `to_string/1` +%% function. Option `var_trans` can be used to override this behaviour. +-spec render(t(), context()) -> + {term(), [_Error :: {varname(), reason()}]}. +render(Template, Context) -> + render(Template, Context, #{}). + +-spec render(t(), context(), render_opts()) -> + {term(), [_Error :: {varname(), undefined}]}. +render(Template, Context, Opts) when is_list(Template) -> + lists:mapfoldl( + fun + ({var, Name, Accessor}, EAcc) -> + {String, Errors} = render_binding(Name, Accessor, Context, Opts), + {String, Errors ++ EAcc}; + (String, EAcc) -> + {String, EAcc} + end, + [], + Template + ); +render({'$tpl', Template}, Context, Opts) -> + render_deep(Template, Context, Opts). + +render_binding(Name, Accessor, Context, Opts) -> + case lookup_value(Accessor, Context) of + {ok, Value} -> + {render_value(Name, Value, Opts), []}; + {error, Reason} -> + % TODO + % Currently, it's not possible to distinguish between a missing value + % and an atom `undefined` in `TransFun`. + {render_value(Name, undefined, Opts), [{Name, Reason}]} + end. + +lookup_value(Accessor, {AccessMod, Bindings}) -> + AccessMod:lookup(Accessor, Bindings); +lookup_value(Accessor, Bindings) -> + lookup_var(Accessor, Bindings). + +render_value(_Name, Value, #{var_trans := TransFun}) when is_function(TransFun, 1) -> + TransFun(Value); +render_value(Name, Value, #{var_trans := TransFun}) when is_function(TransFun, 2) -> + TransFun(Name, Value); +render_value(_Name, Value, #{}) -> + to_string(Value). + +%% @doc Render a template with given bindings. +%% Behaves like `render/2`, but raises an error exception if one or more placeholders +%% are not found in the bindings. +-spec render_strict(t(), context()) -> + term(). +render_strict(Template, Context) -> + render_strict(Template, Context, #{}). + +-spec render_strict(t(), context(), render_opts()) -> + term(). +render_strict(Template, Context, Opts) -> + case render(Template, Context, Opts) of + {Render, []} -> + Render; + {_, Errors = [_ | _]} -> + error(Errors, [unparse(Template), Context]) + end. + +%% @doc Parse an arbitrary Erlang term into a "deep" template. +%% Any binaries nested in the term are treated as string templates, while +%% lists are not analyzed for "printability" and are treated as nested terms. +%% The result is a usual template, and can be fed to other functions in this +%% module. +-spec parse_deep(term()) -> + t(). +parse_deep(Term) -> + parse_deep(Term, #{}). + +-spec parse_deep(term(), parse_opts()) -> + t(). +parse_deep(Term, Opts) -> + {'$tpl', parse_deep_term(Term, Opts)}. + +parse_deep_term(Term, Opts) when is_map(Term) -> + maps:fold( + fun(K, V, Acc) -> + Acc#{parse_deep_term(K, Opts) => parse_deep_term(V, Opts)} + end, + #{}, + Term + ); +parse_deep_term(Term, Opts) when is_list(Term) -> + {list, [parse_deep_term(E, Opts) || E <- Term]}; +parse_deep_term(Term, Opts) when is_tuple(Term) -> + {tuple, [parse_deep_term(E, Opts) || E <- tuple_to_list(Term)]}; +parse_deep_term(Term, Opts) when is_binary(Term) -> + parse(Term, Opts); +parse_deep_term(Term, _Opts) -> + Term. + +render_deep(Template, Context, Opts) when is_map(Template) -> + maps:fold( + fun(KT, VT, {Acc, Errors}) -> + {K, KErrors} = render_deep(KT, Context, Opts), + {V, VErrors} = render_deep(VT, Context, Opts), + {Acc#{K => V}, KErrors ++ VErrors ++ Errors} + end, + {#{}, []}, + Template + ); +render_deep({list, Template}, Context, Opts) when is_list(Template) -> + lists:mapfoldr( + fun(T, Errors) -> + {E, VErrors} = render_deep(T, Context, Opts), + {E, VErrors ++ Errors} + end, + [], + Template + ); +render_deep({tuple, Template}, Context, Opts) when is_list(Template) -> + {Term, Errors} = render_deep({list, Template}, Context, Opts), + {list_to_tuple(Term), Errors}; +render_deep(Template, Context, Opts) when is_list(Template) -> + {String, Errors} = render(Template, Context, Opts), + {unicode:characters_to_binary(String), Errors}; +render_deep(Term, _Bindings, _Opts) -> + {Term, []}. + +unparse_deep(Template) when is_map(Template) -> + maps:fold( + fun(K, V, Acc) -> + Acc#{unparse_deep(K) => unparse_deep(V)} + end, + #{}, + Template + ); +unparse_deep({list, Template}) when is_list(Template) -> + [unparse_deep(E) || E <- Template]; +unparse_deep({tuple, Template}) when is_list(Template) -> + list_to_tuple(unparse_deep({list, Template})); +unparse_deep(Template) when is_list(Template) -> + unicode:characters_to_binary(unparse(Template)); +unparse_deep(Term) -> + Term. + +%% + +%% @doc Lookup a variable in the bindings accessible through the accessor. +%% Lookup is "loose" in the sense that atom and binary keys in the bindings are +%% treated equally. This is useful for both hand-crafted and JSON-like bindings. +%% This is the default lookup function used by rendering functions. +-spec lookup_var(accessor(), bindings()) -> + {ok, binding()} | {error, reason()}. +lookup_var(Var, Bindings) -> + lookup_var(0, Var, Bindings). + +lookup_var(_, [], Value) -> + {ok, Value}; +lookup_var(Loc, [Prop | Rest], Bindings) when is_map(Bindings) -> + case lookup(Prop, Bindings) of + {ok, Value} -> + lookup_var(Loc + 1, Rest, Value); + {error, Reason} -> + {error, Reason} + end; +lookup_var(Loc, _, Invalid) -> + {error, {Loc, type_name(Invalid)}}. + +type_name(Term) when is_atom(Term) -> atom; +type_name(Term) when is_number(Term) -> number; +type_name(Term) when is_binary(Term) -> binary; +type_name(Term) when is_list(Term) -> list. + +-spec lookup(Prop :: binary(), bindings()) -> + {ok, binding()} | {error, undefined}. +lookup(Prop, Bindings) when is_binary(Prop) -> + case maps:get(Prop, Bindings, undefined) of + undefined -> + try + {ok, maps:get(binary_to_existing_atom(Prop, utf8), Bindings)} + catch + error:{badkey, _} -> + {error, undefined}; + error:badarg -> + {error, undefined} + end; + Value -> + {ok, Value} + end. + +-spec to_string(binding()) -> + unicode:chardata(). +to_string(Bin) when is_binary(Bin) -> Bin; +to_string(Num) when is_integer(Num) -> integer_to_binary(Num); +to_string(Num) when is_float(Num) -> float_to_binary(Num, [{decimals, 10}, compact]); +to_string(Atom) when is_atom(Atom) -> atom_to_binary(Atom, utf8); +to_string(Map) when is_map(Map) -> emqx_utils_json:encode(Map); +to_string(List) when is_list(List) -> + case io_lib:printable_unicode_list(List) of + true -> List; + false -> emqx_utils_json:encode(List) + end. diff --git a/apps/emqx_utils/src/emqx_template_sql.erl b/apps/emqx_utils/src/emqx_template_sql.erl new file mode 100644 index 000000000..9b2c1d55c --- /dev/null +++ b/apps/emqx_utils/src/emqx_template_sql.erl @@ -0,0 +1,142 @@ +%%-------------------------------------------------------------------- +%% Copyright (c) 2020-2022 EMQ Technologies Co., Ltd. All Rights Reserved. +%% +%% Licensed under the Apache License, Version 2.0 (the "License"); +%% you may not use this file except in compliance with the License. +%% You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, software +%% distributed under the License is distributed on an "AS IS" BASIS, +%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +%% See the License for the specific language governing permissions and +%% limitations under the License. +%%-------------------------------------------------------------------- + +-module(emqx_template_sql). + +-export([parse/1]). +-export([parse/2]). +-export([render/3]). +-export([render_strict/3]). + +-export([parse_prepstmt/2]). +-export([render_prepstmt/2]). +-export([render_prepstmt_strict/2]). + +-export_type([row_template/0]). + +-type template() :: emqx_template:str(). +-type row_template() :: [emqx_template:placeholder()]. +-type context() :: emqx_template:context(). + +-type values() :: [emqx_utils_sql:value()]. + +-type parse_opts() :: #{ + parameters => '$n' | ':n' | '?', + % Inherited from `emqx_template:parse_opts()` + strip_double_quote => boolean() +}. + +-type render_opts() :: #{ + %% String escaping rules to use. + %% Default: `sql` (generic) + escaping => sql | mysql | cql, + %% Value to map `undefined` to, either to NULLs or to arbitrary strings. + %% Default: `null` + undefined => null | unicode:chardata() +}. + +-define(TEMPLATE_PARSE_OPTS, [strip_double_quote]). + +%% + +%% @doc Parse an SQL statement string with zero or more placeholders into a template. +-spec parse(unicode:chardata()) -> + template(). +parse(String) -> + parse(String, #{}). + +%% @doc Parse an SQL statement string with zero or more placeholders into a template. +-spec parse(unicode:chardata(), parse_opts()) -> + template(). +parse(String, Opts) -> + emqx_template:parse(String, Opts). + +%% @doc Render an SQL statement template given a set of bindings. +%% Interpolation generally follows the SQL syntax, strings are escaped according to the +%% `escaping` option. +-spec render(template(), context(), render_opts()) -> + {unicode:chardata(), [_Error]}. +render(Template, Context, Opts) -> + emqx_template:render(Template, Context, #{ + var_trans => fun(Value) -> emqx_utils_sql:to_sql_string(Value, Opts) end + }). + +%% @doc Render an SQL statement template given a set of bindings. +%% Errors are raised if any placeholders are not bound. +-spec render_strict(template(), context(), render_opts()) -> + unicode:chardata(). +render_strict(Template, Context, Opts) -> + emqx_template:render_strict(Template, Context, #{ + var_trans => fun(Value) -> emqx_utils_sql:to_sql_string(Value, Opts) end + }). + +%% @doc Parse an SQL statement string into a prepared statement and a row template. +%% The row template is a template for a row of SQL values to be inserted to a database +%% during the execution of the prepared statement. +%% Example: +%% ``` +%% {Statement, RowTemplate} = emqx_template_sql:parse_prepstmt( +%% "INSERT INTO table (id, name, age) VALUES (${id}, ${name}, 42)", +%% #{parameters => '$n'} +%% ), +%% Statement = <<"INSERT INTO table (id, name, age) VALUES ($1, $2, 42)">>, +%% RowTemplate = [{var, "...", [...]}, ...] +%% ``` +-spec parse_prepstmt(unicode:chardata(), parse_opts()) -> + {unicode:chardata(), row_template()}. +parse_prepstmt(String, Opts) -> + Template = emqx_template:parse(String, maps:with(?TEMPLATE_PARSE_OPTS, Opts)), + Statement = mk_prepared_statement(Template, Opts), + Placeholders = [Placeholder || Placeholder <- Template, element(1, Placeholder) == var], + {Statement, Placeholders}. + +mk_prepared_statement(Template, Opts) -> + ParameterFormat = maps:get(parameters, Opts, '?'), + {Statement, _} = + lists:mapfoldl( + fun + (Var, Acc) when element(1, Var) == var -> + mk_replace(ParameterFormat, Acc); + (String, Acc) -> + {String, Acc} + end, + 1, + Template + ), + Statement. + +mk_replace('?', Acc) -> + {"?", Acc}; +mk_replace('$n', N) -> + {"$" ++ integer_to_list(N), N + 1}; +mk_replace(':n', N) -> + {":" ++ integer_to_list(N), N + 1}. + +%% @doc Render a row template into a list of SQL values. +%% An _SQL value_ is a vaguely defined concept here, it is something that's considered +%% compatible with the protocol of the database being used. See the definition of +%% `emqx_utils_sql:value()` for more details. +-spec render_prepstmt(template(), context()) -> + {values(), [_Error]}. +render_prepstmt(Template, Context) -> + Opts = #{var_trans => fun emqx_utils_sql:to_sql_value/1}, + emqx_template:render(Template, Context, Opts). + +-spec render_prepstmt_strict(template(), context()) -> + values(). +render_prepstmt_strict(Template, Context) -> + Opts = #{var_trans => fun emqx_utils_sql:to_sql_value/1}, + emqx_template:render_strict(Template, Context, Opts). diff --git a/apps/emqx_utils/src/emqx_utils_sql.erl b/apps/emqx_utils/src/emqx_utils_sql.erl index 3caed6b62..9ce9e576d 100644 --- a/apps/emqx_utils/src/emqx_utils_sql.erl +++ b/apps/emqx_utils/src/emqx_utils_sql.erl @@ -80,10 +80,15 @@ to_sql_value(Map) when is_map(Map) -> emqx_utils_json:encode(Map). %% @doc Convert an Erlang term to a string that can be interpolated in literal %% SQL statements. The value is escaped if necessary. --spec to_sql_string(term(), Options) -> iodata() when +-spec to_sql_string(term(), Options) -> unicode:chardata() when Options :: #{ - escaping => cql | mysql | sql + escaping => mysql | sql | cql, + undefined => null | unicode:chardata() }. +to_sql_string(undefined, #{undefined := Str} = Opts) when Str =/= null -> + to_sql_string(Str, Opts); +to_sql_string(undefined, #{}) -> + <<"NULL">>; to_sql_string(String, #{escaping := mysql}) when is_binary(String) -> try escape_mysql(String) @@ -98,7 +103,7 @@ to_sql_string(Term, #{escaping := cql}) -> to_sql_string(Term, #{}) -> maybe_escape(Term, fun escape_sql/1). --spec maybe_escape(_Value, fun((binary()) -> iodata())) -> iodata(). +-spec maybe_escape(_Value, fun((binary()) -> iodata())) -> unicode:chardata(). maybe_escape(Str, EscapeFun) when is_binary(Str) -> EscapeFun(Str); maybe_escape(Str, EscapeFun) when is_list(Str) -> @@ -109,9 +114,9 @@ maybe_escape(Str, EscapeFun) when is_list(Str) -> error(Otherwise) end; maybe_escape(Val, EscapeFun) when is_atom(Val) orelse is_map(Val) -> - EscapeFun(emqx_utils_conv:bin(Val)); + EscapeFun(emqx_template:to_string(Val)); maybe_escape(Val, _EscapeFun) -> - emqx_utils_conv:bin(Val). + emqx_template:to_string(Val). -spec escape_sql(binary()) -> iodata(). escape_sql(S) -> diff --git a/apps/emqx_utils/test/emqx_jsonish_tests.erl b/apps/emqx_utils/test/emqx_jsonish_tests.erl new file mode 100644 index 000000000..c776615a1 --- /dev/null +++ b/apps/emqx_utils/test/emqx_jsonish_tests.erl @@ -0,0 +1,97 @@ +%%-------------------------------------------------------------------- +%% Copyright (c) 2020-2022 EMQ Technologies Co., Ltd. All Rights Reserved. +%% +%% Licensed under the Apache License, Version 2.0 (the "License"); +%% you may not use this file except in compliance with the License. +%% You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, software +%% distributed under the License is distributed on an "AS IS" BASIS, +%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +%% See the License for the specific language governing permissions and +%% limitations under the License. +%%-------------------------------------------------------------------- + +-module(emqx_jsonish_tests). + +-include_lib("eunit/include/eunit.hrl"). + +prop_prio_test_() -> + [ + ?_assertEqual( + {ok, 42}, + emqx_jsonish:lookup([<<"foo">>], #{<<"foo">> => 42, foo => 1337}) + ), + ?_assertEqual( + {ok, 1337}, + emqx_jsonish:lookup([<<"foo">>], #{foo => 1337}) + ) + ]. + +undefined_test() -> + ?assertEqual( + {error, undefined}, + emqx_jsonish:lookup([<<"foo">>], #{}) + ). + +undefined_deep_test() -> + ?assertEqual( + {error, undefined}, + emqx_jsonish:lookup([<<"foo">>, <<"bar">>], #{}) + ). + +undefined_deep_json_test() -> + ?assertEqual( + {error, undefined}, + emqx_jsonish:lookup( + [<<"foo">>, <<"bar">>, <<"baz">>], + <<"{\"foo\":{\"bar\":{\"no\":{}}}}">> + ) + ). + +invalid_type_test() -> + ?assertEqual( + {error, {0, number}}, + emqx_jsonish:lookup([<<"foo">>], <<"42">>) + ). + +invalid_type_deep_test() -> + ?assertEqual( + {error, {2, atom}}, + emqx_jsonish:lookup([<<"foo">>, <<"bar">>, <<"tuple">>], #{foo => #{bar => baz}}) + ). + +decode_json_test() -> + ?assertEqual( + {ok, 42}, + emqx_jsonish:lookup([<<"foo">>, <<"bar">>], <<"{\"foo\":{\"bar\":42}}">>) + ). + +decode_json_deep_test() -> + ?assertEqual( + {ok, 42}, + emqx_jsonish:lookup([<<"foo">>, <<"bar">>], #{<<"foo">> => <<"{\"bar\": 42}">>}) + ). + +decode_json_invalid_type_test() -> + ?assertEqual( + {error, {1, list}}, + emqx_jsonish:lookup([<<"foo">>, <<"bar">>], #{<<"foo">> => <<"[1,2,3]">>}) + ). + +decode_no_json_test() -> + ?assertEqual( + {error, {1, binary}}, + emqx_jsonish:lookup([<<"foo">>, <<"bar">>], #{<<"foo">> => <<0, 1, 2, 3>>}) + ). + +decode_json_no_nested_test() -> + ?assertEqual( + {error, {2, binary}}, + emqx_jsonish:lookup( + [<<"foo">>, <<"bar">>, <<"baz">>], + #{<<"foo">> => <<"{\"bar\":\"{\\\"baz\\\":42}\"}">>} + ) + ). diff --git a/apps/emqx_utils/test/emqx_template_SUITE.erl b/apps/emqx_utils/test/emqx_template_SUITE.erl new file mode 100644 index 000000000..4dfe5de2e --- /dev/null +++ b/apps/emqx_utils/test/emqx_template_SUITE.erl @@ -0,0 +1,360 @@ +%%-------------------------------------------------------------------- +%% Copyright (c) 2020-2023 EMQ Technologies Co., Ltd. All Rights Reserved. +%% +%% Licensed under the Apache License, Version 2.0 (the "License"); +%% you may not use this file except in compliance with the License. +%% You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, software +%% distributed under the License is distributed on an "AS IS" BASIS, +%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +%% See the License for the specific language governing permissions and +%% limitations under the License. +%%-------------------------------------------------------------------- + +-module(emqx_template_SUITE). + +-compile(export_all). +-compile(nowarn_export_all). + +-include_lib("emqx/include/emqx_placeholder.hrl"). +-include_lib("eunit/include/eunit.hrl"). + +all() -> emqx_common_test_helpers:all(?MODULE). + +t_render(_) -> + Context = #{ + a => <<"1">>, + b => 1, + c => 1.0, + d => #{<<"d1">> => <<"hi">>}, + l => [0, 1, 1000], + u => "utf-8 is ǝɹǝɥ" + }, + Template = emqx_template:parse( + <<"a:${a},b:${b},c:${c},d:${d},d1:${d.d1},l:${l},u:${u}">> + ), + ?assertEqual( + {<<"a:1,b:1,c:1.0,d:{\"d1\":\"hi\"},d1:hi,l:[0,1,1000],u:utf-8 is ǝɹǝɥ"/utf8>>, []}, + render_string(Template, Context) + ). + +t_render_var_trans(_) -> + Context = #{a => <<"1">>, b => 1, c => #{prop => 1.0}}, + Template = emqx_template:parse(<<"a:${a},b:${b},c:${c.prop}">>), + {String, Errors} = emqx_template:render( + Template, + Context, + #{var_trans => fun(Name, _) -> "<" ++ Name ++ ">" end} + ), + ?assertEqual( + {<<"a:,b:,c:">>, []}, + {bin(String), Errors} + ). + +t_render_path(_) -> + Context = #{d => #{d1 => <<"hi">>}}, + Template = emqx_template:parse(<<"d.d1:${d.d1}">>), + ?assertEqual( + ok, + emqx_template:validate(["d.d1"], Template) + ), + ?assertEqual( + {<<"d.d1:hi">>, []}, + render_string(Template, Context) + ). + +t_render_custom_ph(_) -> + Context = #{a => <<"a">>, b => <<"b">>}, + Template = emqx_template:parse(<<"a:${a},b:${b}">>), + ?assertEqual( + {error, [{"b", disallowed}]}, + emqx_template:validate(["a"], Template) + ), + ?assertEqual( + <<"a:a,b:b">>, + render_strict_string(Template, Context) + ). + +t_render_this(_) -> + Context = #{a => <<"a">>, b => [1, 2, 3]}, + Template = emqx_template:parse(<<"this:${} / also:${.}">>), + ?assertEqual(ok, emqx_template:validate(["."], Template)), + ?assertEqual( + % NOTE: order of the keys in the JSON object depends on the JSON encoder + <<"this:{\"b\":[1,2,3],\"a\":\"a\"} / also:{\"b\":[1,2,3],\"a\":\"a\"}">>, + render_strict_string(Template, Context) + ). + +t_render_missing_bindings(_) -> + Context = #{no => #{}, c => #{<<"c1">> => 42}}, + Template = emqx_template:parse( + <<"a:${a},b:${b},c:${c.c1.c2},d:${d.d1},e:${no.such_atom_i_swear}">> + ), + ?assertEqual( + {<<"a:undefined,b:undefined,c:undefined,d:undefined,e:undefined">>, [ + {"no.such_atom_i_swear", undefined}, + {"d.d1", undefined}, + {"c.c1.c2", {2, number}}, + {"b", undefined}, + {"a", undefined} + ]}, + render_string(Template, Context) + ), + ?assertError( + [ + {"no.such_atom_i_swear", undefined}, + {"d.d1", undefined}, + {"c.c1.c2", {2, number}}, + {"b", undefined}, + {"a", undefined} + ], + render_strict_string(Template, Context) + ). + +t_render_custom_bindings(_) -> + _ = erlang:put(a, <<"foo">>), + _ = erlang:put(b, #{<<"bar">> => #{atom => 42}}), + Template = emqx_template:parse( + <<"a:${a},b:${b.bar.atom},c:${c},oops:${b.bar.atom.oops}">> + ), + ?assertEqual( + {<<"a:foo,b:42,c:undefined,oops:undefined">>, [ + {"b.bar.atom.oops", {2, number}}, + {"c", undefined} + ]}, + render_string(Template, {?MODULE, []}) + ). + +t_unparse(_) -> + TString = <<"a:${a},b:${b},c:$${c},d:{${d.d1}},e:${$}{e},lit:${$}{$}">>, + Template = emqx_template:parse(TString), + ?assertEqual( + TString, + unicode:characters_to_binary(emqx_template:unparse(Template)) + ). + +t_const(_) -> + ?assertEqual( + true, + emqx_template:is_const(emqx_template:parse(<<"">>)) + ), + ?assertEqual( + false, + emqx_template:is_const( + emqx_template:parse(<<"a:${a},b:${b},c:${$}{c}">>) + ) + ), + ?assertEqual( + true, + emqx_template:is_const( + emqx_template:parse(<<"a:${$}{a},b:${$}{b}">>) + ) + ). + +t_render_partial_ph(_) -> + Context = #{a => <<"1">>, b => 1, c => 1.0, d => #{d1 => <<"hi">>}}, + Template = emqx_template:parse(<<"a:$a,b:b},c:{c},d:${d">>), + ?assertEqual( + <<"a:$a,b:b},c:{c},d:${d">>, + render_strict_string(Template, Context) + ). + +t_parse_escaped(_) -> + Context = #{a => <<"1">>, b => 1, c => "VAR"}, + Template = emqx_template:parse(<<"a:${a},b:${$}{b},c:${$}{${c}},lit:${$}{$}">>), + ?assertEqual( + <<"a:1,b:${b},c:${VAR},lit:${$}">>, + render_strict_string(Template, Context) + ). + +t_parse_escaped_dquote(_) -> + Context = #{a => <<"1">>, b => 1}, + Template = emqx_template:parse(<<"a:\"${a}\",b:\"${$}{b}\"">>, #{ + strip_double_quote => true + }), + ?assertEqual( + <<"a:1,b:\"${b}\"">>, + render_strict_string(Template, Context) + ). + +t_parse_sql_prepstmt(_) -> + Context = #{a => <<"1">>, b => 1, c => 1.0, d => #{d1 => <<"hi">>}}, + {PrepareStatement, RowTemplate} = + emqx_template_sql:parse_prepstmt(<<"a:${a},b:${b},c:${c},d:${d}">>, #{ + parameters => '?' + }), + ?assertEqual(<<"a:?,b:?,c:?,d:?">>, bin(PrepareStatement)), + ?assertEqual( + {[<<"1">>, 1, 1.0, <<"{\"d1\":\"hi\"}">>], _Errors = []}, + emqx_template_sql:render_prepstmt(RowTemplate, Context) + ). + +t_parse_sql_prepstmt_n(_) -> + Context = #{a => undefined, b => true, c => atom, d => #{d1 => 42.1337}}, + {PrepareStatement, RowTemplate} = + emqx_template_sql:parse_prepstmt(<<"a:${a},b:${b},c:${c},d:${d}">>, #{ + parameters => '$n' + }), + ?assertEqual(<<"a:$1,b:$2,c:$3,d:$4">>, bin(PrepareStatement)), + ?assertEqual( + [null, true, <<"atom">>, <<"{\"d1\":42.1337}">>], + emqx_template_sql:render_prepstmt_strict(RowTemplate, Context) + ). + +t_parse_sql_prepstmt_colon(_) -> + {PrepareStatement, _RowTemplate} = + emqx_template_sql:parse_prepstmt(<<"a=${a},b=${b},c=${c},d=${d}">>, #{ + parameters => ':n' + }), + ?assertEqual(<<"a=:1,b=:2,c=:3,d=:4">>, bin(PrepareStatement)). + +t_parse_sql_prepstmt_partial_ph(_) -> + Context = #{a => <<"1">>, b => 1, c => 1.0, d => #{d1 => <<"hi">>}}, + {PrepareStatement, RowTemplate} = + emqx_template_sql:parse_prepstmt(<<"a:$a,b:b},c:{c},d:${d">>, #{parameters => '?'}), + ?assertEqual(<<"a:$a,b:b},c:{c},d:${d">>, bin(PrepareStatement)), + ?assertEqual([], emqx_template_sql:render_prepstmt_strict(RowTemplate, Context)). + +t_render_sql(_) -> + Context = #{ + a => <<"1">>, + b => 1, + c => 1.0, + d => #{d1 => <<"hi">>}, + n => undefined, + u => "utf8's cool 🐸" + }, + Template = emqx_template:parse(<<"a:${a},b:${b},c:${c},d:${d},n:${n},u:${u}">>), + ?assertMatch( + {_String, _Errors = []}, + emqx_template_sql:render(Template, Context, #{}) + ), + ?assertEqual( + <<"a:'1',b:1,c:1.0,d:'{\"d1\":\"hi\"}',n:NULL,u:'utf8\\'s cool 🐸'"/utf8>>, + bin(emqx_template_sql:render_strict(Template, Context, #{})) + ), + ?assertEqual( + <<"a:'1',b:1,c:1.0,d:'{\"d1\":\"hi\"}',n:'undefined',u:'utf8\\'s cool 🐸'"/utf8>>, + bin(emqx_template_sql:render_strict(Template, Context, #{undefined => "undefined"})) + ). + +t_render_mysql(_) -> + %% with apostrophes + %% https://github.com/emqx/emqx/issues/4135 + Context = #{ + a => <<"1''2">>, + b => 1, + c => 1.0, + d => #{d1 => <<"someone's phone">>}, + e => <<$\\, 0, "💩"/utf8>>, + f => <<"non-utf8", 16#DCC900:24>>, + g => "utf8's cool 🐸", + h => imgood + }, + Template = emqx_template_sql:parse( + <<"a:${a},b:${b},c:${c},d:${d},e:${e},f:${f},g:${g},h:${h}">> + ), + ?assertEqual( + << + "a:'1\\'\\'2',b:1,c:1.0,d:'{\"d1\":\"someone\\'s phone\"}'," + "e:'\\\\\\0💩',f:0x6E6F6E2D75746638DCC900,g:'utf8\\'s cool 🐸',"/utf8, + "h:'imgood'" + >>, + bin(emqx_template_sql:render_strict(Template, Context, #{escaping => mysql})) + ). + +t_render_cql(_) -> + %% with apostrophes for cassandra + %% https://github.com/emqx/emqx/issues/4148 + Context = #{ + a => <<"1''2">>, + b => 1, + c => 1.0, + d => #{d1 => <<"someone's phone">>} + }, + Template = emqx_template:parse(<<"a:${a},b:${b},c:${c},d:${d}">>), + ?assertEqual( + <<"a:'1''''2',b:1,c:1.0,d:'{\"d1\":\"someone''s phone\"}'">>, + bin(emqx_template_sql:render_strict(Template, Context, #{escaping => cql})) + ). + +t_render_sql_custom_ph(_) -> + {PrepareStatement, RowTemplate} = + emqx_template_sql:parse_prepstmt(<<"a:${a},b:${b.c}">>, #{parameters => '$n'}), + ?assertEqual( + {error, [{"b.c", disallowed}]}, + emqx_template:validate(["a"], RowTemplate) + ), + ?assertEqual(<<"a:$1,b:$2">>, bin(PrepareStatement)). + +t_render_sql_strip_double_quote(_) -> + Context = #{a => <<"a">>, b => <<"b">>}, + + %% no strip_double_quote option: "${key}" -> "value" + {PrepareStatement1, RowTemplate1} = emqx_template_sql:parse_prepstmt( + <<"a:\"${a}\",b:\"${b}\"">>, + #{parameters => '$n'} + ), + ?assertEqual(<<"a:\"$1\",b:\"$2\"">>, bin(PrepareStatement1)), + ?assertEqual( + [<<"a">>, <<"b">>], + emqx_template_sql:render_prepstmt_strict(RowTemplate1, Context) + ), + + %% strip_double_quote = true: "${key}" -> value + {PrepareStatement2, RowTemplate2} = emqx_template_sql:parse_prepstmt( + <<"a:\"${a}\",b:\"${b}\"">>, + #{parameters => '$n', strip_double_quote => true} + ), + ?assertEqual(<<"a:$1,b:$2">>, bin(PrepareStatement2)), + ?assertEqual( + [<<"a">>, <<"b">>], + emqx_template_sql:render_prepstmt_strict(RowTemplate2, Context) + ). + +t_render_tmpl_deep(_) -> + Context = #{a => <<"1">>, b => 1, c => 1.0, d => #{d1 => <<"hi">>}}, + + Template = emqx_template:parse_deep( + #{<<"${a}">> => [<<"$${b}">>, "c", 2, 3.0, '${d}', {[<<"${c}">>, <<"${$}{d}">>], 0}]} + ), + + ?assertEqual( + {error, [{V, disallowed} || V <- ["b", "c"]]}, + emqx_template:validate(["a"], Template) + ), + + ?assertEqual( + #{<<"1">> => [<<"$1">>, "c", 2, 3.0, '${d}', {[<<"1.0">>, <<"${d}">>], 0}]}, + emqx_template:render_strict(Template, Context) + ). + +t_unparse_tmpl_deep(_) -> + Term = #{<<"${a}">> => [<<"$${b}">>, "c", 2, 3.0, '${d}', {[<<"${c}">>], <<"${$}{d}">>, 0}]}, + Template = emqx_template:parse_deep(Term), + ?assertEqual(Term, emqx_template:unparse(Template)). + +%% + +render_string(Template, Context) -> + {String, Errors} = emqx_template:render(Template, Context), + {bin(String), Errors}. + +render_strict_string(Template, Context) -> + bin(emqx_template:render_strict(Template, Context)). + +bin(String) -> + unicode:characters_to_binary(String). + +%% Access module API + +lookup([], _) -> + {error, undefined}; +lookup([Prop | Rest], _) -> + case erlang:get(binary_to_atom(Prop)) of + undefined -> {error, undefined}; + Value -> emqx_template:lookup_var(Rest, Value) + end. diff --git a/mix.exs b/mix.exs index 7f757b716..193f3371b 100644 --- a/mix.exs +++ b/mix.exs @@ -338,6 +338,7 @@ defmodule EMQXUmbrella.MixProject do :emqx_management, :emqx_retainer, :emqx_prometheus, + :emqx_rule_engine, :emqx_auto_subscribe, :emqx_slow_subs, :emqx_plugins,