fix(mysql): ensure proper escaping in batch inserts

Also hex-encode non-UTF-8 binaries. This is essentially a heuristic.
We don't know column types at runtime, and there's no simple way
to find them out. Since we're already doing a full binary scan during
escaping, it should be cheap to bail out on non-UTF-8 strings and
hex-encode them instead.

Also introduce a separate function to highlight that this escaping
is MySQL-specific.
This commit is contained in:
Andrew Mayorov 2023-03-08 13:24:09 +03:00
parent fc37d9b3cd
commit 0a7f6c7d03
No known key found for this signature in database
GPG Key ID: 2837C62ACFBFED5D
6 changed files with 128 additions and 30 deletions

View File

@ -391,8 +391,12 @@ proc_sql_params(TypeOrKey, SQLOrData, Params, #{params_tokens := ParamsTokens})
end.
on_batch_insert(InstId, BatchReqs, InsertPart, Tokens, State) ->
SQL = emqx_plugin_libs_rule:proc_batch_sql(BatchReqs, InsertPart, Tokens),
on_sql_query(InstId, query, SQL, no_params, default_timeout, State).
ValuesPart = lists:join($,, [
emqx_placeholder:proc_param_str(Tokens, Msg, fun emqx_placeholder:quote_mysql/1)
|| {_, Msg} <- BatchReqs
]),
Query = [InsertPart, <<" values ">> | ValuesPart],
on_sql_query(InstId, query, Query, no_params, default_timeout, State).
on_sql_query(
InstId,

View File

@ -30,6 +30,7 @@
proc_sql/2,
proc_sql_param_str/2,
proc_cql_param_str/2,
proc_param_str/3,
preproc_tmpl_deep/1,
preproc_tmpl_deep/2,
proc_tmpl_deep/2,
@ -39,6 +40,12 @@
sql_data/1
]).
-export([
quote_sql/1,
quote_cql/1,
quote_mysql/1
]).
-include_lib("emqx/include/emqx_placeholder.hrl").
-define(EX_PLACE_HOLDER, "(\\$\\{[a-zA-Z0-9\\._]+\\})").
@ -83,6 +90,8 @@
| {tmpl, tmpl_token()}
| {value, term()}.
-dialyzer({no_improper_lists, [quote_mysql/1, escape_mysql/4, escape_prepend/4]}).
%%------------------------------------------------------------------------------
%% APIs
%%------------------------------------------------------------------------------
@ -162,12 +171,22 @@ proc_sql(Tokens, Data) ->
%% Renders the template tokens against Data, quoting every substituted
%% variable with `quote_sql/1`, and returns the result as a single binary.
-spec proc_sql_param_str(tmpl_token(), map()) -> binary().
proc_sql_param_str(Tokens, Data) ->
    % NOTE
    % The "sql" in the name is somewhat misleading: the escaping performed by
    % `quote_sql/1` relies on C-style backslash escapes, which pgsql does not
    % support by default, so it likely won't work there.
    % https://www.postgresql.org/docs/14/sql-syntax-lexical.html#SQL-SYNTAX-CONSTANTS
    Quote = fun quote_sql/1,
    proc_param_str(Tokens, Data, Quote).
%% Renders the template tokens against Data, quoting every substituted
%% variable with `quote_cql/1`, and returns the result as a single binary.
-spec proc_cql_param_str(tmpl_token(), map()) -> binary().
proc_cql_param_str(Tokens, Data) ->
    Quote = fun quote_cql/1,
    proc_param_str(Tokens, Data, Quote).
%% Renders the template tokens against Data, passing each substituted variable
%% through the caller-supplied Quote function, and flattens the resulting
%% raw list into one binary.
-spec proc_param_str(tmpl_token(), map(), fun((_Value) -> iodata())) -> binary().
proc_param_str(Tokens, Data, Quote) ->
    Opts = #{return => rawlist, var_trans => Quote},
    Rendered = proc_tmpl(Tokens, Data, Opts),
    iolist_to_binary(Rendered).
%% Convenience arity: preprocesses a deep template with key processing enabled.
-spec preproc_tmpl_deep(term()) -> deep_template().
preproc_tmpl_deep(Data) ->
    DefaultOpts = #{process_keys => true},
    preproc_tmpl_deep(Data, DefaultOpts).
@ -226,15 +245,29 @@ sql_data(Map) when is_map(Map) -> emqx_json:encode(Map).
%% Converts an arbitrary term to a binary by delegating to
%% `emqx_plugin_libs_rule:bin/1`.
-spec bin(term()) -> binary().
bin(Term) ->
    emqx_plugin_libs_rule:bin(Term).
%% Quotes a value for inclusion in an SQL statement, using `escape_sql/1`
%% for anything that `quote_escape/2` treats as string-like.
-spec quote_sql(_Value) -> iolist().
quote_sql(Value) ->
    Escape = fun escape_sql/1,
    quote_escape(Value, Escape).
%% Quotes a value for inclusion in a CQL statement, using `escape_cql/1`
%% for anything that `quote_escape/2` treats as string-like.
-spec quote_cql(_Value) -> iolist().
quote_cql(Value) ->
    Escape = fun escape_cql/1,
    quote_escape(Value, Escape).
%% Quotes a value for inclusion in a MySQL statement.
%% Binaries are escaped as string literals; if escaping bails out with
%% `invalid_utf8`, the binary is emitted as a `0x...` hexadecimal literal
%% instead. Any other value goes through `quote_escape/2`.
-spec quote_mysql(_Value) -> iolist().
quote_mysql(Bin) when is_binary(Bin) ->
    try escape_mysql(Bin) of
        Quoted ->
            Quoted
    catch
        throw:invalid_utf8 ->
            % Not valid UTF-8: fall back to a hex literal.
            [<<"0x">> | binary:encode_hex(Bin)]
    end;
quote_mysql(Value) ->
    quote_escape(Value, fun escape_mysql/1).
%%------------------------------------------------------------------------------
%% Internal functions
%%------------------------------------------------------------------------------
proc_param_str(Tokens, Data, Quote) ->
iolist_to_binary(
proc_tmpl(Tokens, Data, #{return => rawlist, var_trans => Quote})
).
%% Looks up the placeholder variable in Data via
%% `emqx_rule_maps:nested_get/2`.
get_phld_var(Placeholder, Data) ->
    emqx_rule_maps:nested_get(Placeholder, Data).
@ -312,21 +345,56 @@ unwrap(<<"\"${", Val/binary>>, _StripDoubleQuote = true) ->
unwrap(<<"${", Val/binary>>, _StripDoubleQuote) ->
binary:part(Val, {0, byte_size(Val) - 1}).
quote_sql(Str) ->
quote(Str, <<"\\\\'">>).
quote_cql(Str) ->
quote(Str, <<"''">>).
quote(Str, ReplaceWith) when
is_list(Str);
is_binary(Str);
is_atom(Str);
is_map(Str)
->
[$', escape_apo(bin(Str), ReplaceWith), $'];
quote(Val, _) ->
%% Applies Escape to the binary form of string-like values:
%%   * binaries are passed to Escape as is;
%%   * lists are converted with `unicode:characters_to_binary/1` first, and the
%%     call fails with the conversion result as the error reason if that does
%%     not yield a binary;
%%   * atoms and maps are converted with `bin/1` first.
%% Any other value is converted with `bin/1` and returned unescaped.
-spec quote_escape(_Value, fun((binary()) -> iodata())) -> iodata().
quote_escape(Bin, Escape) when is_binary(Bin) ->
    Escape(Bin);
quote_escape(Chars, Escape) when is_list(Chars) ->
    Converted = unicode:characters_to_binary(Chars),
    if
        is_binary(Converted) ->
            Escape(Converted);
        true ->
            error(Converted)
    end;
quote_escape(Term, Escape) when is_atom(Term); is_map(Term) ->
    Escape(bin(Term));
quote_escape(Other, _Escape) ->
    bin(Other).
escape_apo(Str, ReplaceWith) ->
re:replace(Str, <<"'">>, ReplaceWith, [{return, binary}, global]).
%% Wraps the binary in single quotes, prefixing every backslash and every
%% apostrophe inside it with a backslash.
-spec escape_sql(binary()) -> iolist().
escape_sql(Bin) ->
    Specials = [<<"\\">>, <<"'">>],
    % `insert_replaced` re-inserts the matched character right after the
    % backslash of the replacement.
    Opts = [global, {insert_replaced, 1}],
    Escaped = binary:replace(Bin, Specials, <<"\\">>, Opts),
    [$', Escaped, $'].
%% Wraps the binary in single quotes, doubling every apostrophe inside it.
-spec escape_cql(binary()) -> iolist().
escape_cql(Bin) ->
    % `insert_replaced` re-inserts the matched apostrophe after the
    % replacement one, which yields two apostrophes.
    Opts = [global, {insert_replaced, 1}],
    Escaped = binary:replace(Bin, <<"'">>, <<"'">>, Opts),
    [$', Escaped, $'].
%% Wraps the binary in single quotes, escaping its contents with
%% `escape_mysql/4`.
%% See https://dev.mysql.com/doc/refman/8.0/en/string-literals.html
-spec escape_mysql(binary()) -> iolist().
escape_mysql(Src) ->
    Escaped = escape_mysql(Src, 0, 0, Src),
    [$', Escaped, $'].
%% NOTE
%% This thing looks more complicated than needed because it's optimized for as few
%% intermediate memory (re)allocations as possible.
%%
%% Arguments:
%%   1st  - the part of the input that is still to be scanned;
%%   I    - byte offset in Src where the current run of bytes needing no
%%          escaping starts;
%%   Run  - length in bytes of that run so far;
%%   Src  - the complete original binary, from which runs are sliced.
%%
%% The result is either Src itself, or a (possibly improper) list of binaries
%% made of slices of Src interleaved with escape sequences.
%% Throws `invalid_utf8` if the input is not valid UTF-8.

%% Apostrophe: emit the pending run followed by `\'`, then restart the run
%% right after the apostrophe.
escape_mysql(<<$', Rest/binary>>, I, Run, Src) ->
escape_prepend(I, Run, Src, [<<"\\'">> | escape_mysql(Rest, I + Run + 1, 0, Src)]);
%% Backslash: same as above, but emit `\\`.
escape_mysql(<<$\\, Rest/binary>>, I, Run, Src) ->
escape_prepend(I, Run, Src, [<<"\\\\">> | escape_mysql(Rest, I + Run + 1, 0, Src)]);
%% NUL byte: same as above, but emit `\0`.
escape_mysql(<<0, Rest/binary>>, I, Run, Src) ->
escape_prepend(I, Run, Src, [<<"\\0">> | escape_mysql(Rest, I + Run + 1, 0, Src)]);
%% Any other valid UTF-8 codepoint: extend the current run by the codepoint's
%% encoded width in bytes.
escape_mysql(<<_/utf8, Rest/binary>> = S, I, Run, Src) ->
CWidth = byte_size(S) - byte_size(Rest),
escape_mysql(Rest, I, Run + CWidth, Src);
%% End of input with I still at 0: every escape advances I by at least one, so
%% nothing was escaped and the original binary can be returned untouched.
escape_mysql(<<>>, 0, _, Src) ->
Src;
%% End of input after at least one escape: return the trailing run.
escape_mysql(<<>>, I, Run, Src) ->
binary:part(Src, I, Run);
%% The remaining input does not start with a valid UTF-8 codepoint.
escape_mysql(_, _I, _Run, _Src) ->
throw(invalid_utf8).
%% Prepends the slice of Src starting at Offset and spanning RunLen bytes to
%% Tail. An empty run (RunLen of 0) leaves Tail as is.
escape_prepend(Offset, RunLen, Src, Tail) ->
    case RunLen of
        0 ->
            Tail;
        _ ->
            [binary:part(Src, Offset, RunLen) | Tail]
    end.

View File

@ -172,8 +172,8 @@ detect_sql_type(SQL) ->
) -> InsertSQL :: binary().
proc_batch_sql(BatchReqs, InsertPart, Tokens) ->
ValuesPart = erlang:iolist_to_binary(
lists:join(", ", [
emqx_plugin_libs_rule:proc_sql_param_str(Tokens, Msg)
lists:join($,, [
proc_sql_param_str(Tokens, Msg)
|| {_, Msg} <- BatchReqs
])
),

View File

@ -105,19 +105,27 @@ t_preproc_sql3(_) ->
emqx_placeholder:proc_sql_param_str(ParamsTokens, Selected)
).
t_preproc_sql4(_) ->
t_preproc_mysql1(_) ->
%% with apostrophes
%% https://github.com/emqx/emqx/issues/4135
Selected = #{
a => <<"1''2">>,
b => 1,
c => 1.0,
d => #{d1 => <<"someone's phone">>}
d => #{d1 => <<"someone's phone">>},
e => <<$\\, 0, "💩"/utf8>>,
f => <<"non-utf8", 16#DCC900:24>>,
g => "utf8's cool 🐸"
},
ParamsTokens = emqx_placeholder:preproc_tmpl(<<"a:${a},b:${b},c:${c},d:${d}">>),
ParamsTokens = emqx_placeholder:preproc_tmpl(
<<"a:${a},b:${b},c:${c},d:${d},e:${e},f:${f},g:${g}">>
),
?assertEqual(
<<"a:'1\\'\\'2',b:1,c:1.0,d:'{\"d1\":\"someone\\'s phone\"}'">>,
emqx_placeholder:proc_sql_param_str(ParamsTokens, Selected)
<<
"a:'1\\'\\'2',b:1,c:1.0,d:'{\"d1\":\"someone\\'s phone\"}',"
"e:'\\\\\\0💩',f:0x6E6F6E2D75746638DCC900,g:'utf8\\'s cool 🐸'"/utf8
>>,
emqx_placeholder:proc_param_str(ParamsTokens, Selected, fun emqx_placeholder:quote_mysql/1)
).
t_preproc_sql5(_) ->

View File

@ -511,6 +511,17 @@ t_bad_sql_parameter(Config) ->
end,
ok.
%% Sends a payload made of every byte value 0..255 through the bridge and
%% checks that the very same payload is read back.
t_nasty_sql_string(Config) ->
    ?assertMatch({ok, _}, create_bridge(Config)),
    AllBytes = lists:seq(0, 255),
    Payload = list_to_binary(AllBytes),
    Timestamp = erlang:system_time(millisecond),
    Message = #{payload => Payload, timestamp => Timestamp},
    Result = send_message(Config, Message),
    ?assertEqual(ok, Result),
    ?assertMatch(
        {ok, [<<"payload">>], [[Payload]]},
        connect_and_get_payload(Config)
    ).
t_workload_fits_prepared_statement_limit(Config) ->
N = 50,
?assertMatch(

View File

@ -510,3 +510,10 @@ t_bad_sql_parameter(Config) ->
)
end,
ok.
%% Sends a payload made of every byte value 1..127 through the bridge and
%% checks that the very same payload is read back.
t_nasty_sql_string(Config) ->
    ?assertMatch({ok, _}, create_bridge(Config)),
    Bytes = lists:seq(1, 127),
    Payload = list_to_binary(Bytes),
    Timestamp = erlang:system_time(millisecond),
    Message = #{payload => Payload, timestamp => Timestamp},
    ?assertEqual({ok, 1}, send_message(Config, Message)),
    ?assertEqual(Payload, connect_and_get_payload(Config)).