diff --git a/apps/emqx_bridge_mysql/src/emqx_bridge_mysql_connector.erl b/apps/emqx_bridge_mysql/src/emqx_bridge_mysql_connector.erl index 468f64d1f..5d331790b 100644 --- a/apps/emqx_bridge_mysql/src/emqx_bridge_mysql_connector.erl +++ b/apps/emqx_bridge_mysql/src/emqx_bridge_mysql_connector.erl @@ -35,13 +35,18 @@ on_add_channel( ) -> ChannelConfig1 = emqx_utils_maps:unindent(parameters, ChannelConfig0), QueryTemplates = emqx_mysql:parse_prepare_sql(ChannelId, ChannelConfig1), - ChannelConfig2 = maps:merge(ChannelConfig1, QueryTemplates), - ChannelConfig = set_prepares(ChannelConfig2, ConnectorState), - State = State0#{ - channels => maps:put(ChannelId, ChannelConfig, Channels), - connector_state => ConnectorState - }, - {ok, State}. + case validate_sql_type(ChannelId, ChannelConfig1, QueryTemplates) of + ok -> + ChannelConfig2 = maps:merge(ChannelConfig1, QueryTemplates), + ChannelConfig = set_prepares(ChannelConfig2, ConnectorState), + State = State0#{ + channels => maps:put(ChannelId, ChannelConfig, Channels), + connector_state => ConnectorState + }, + {ok, State}; + {error, Error} -> + {error, Error} + end. on_get_channel_status(_InstanceId, ChannelId, #{channels := Channels}) -> case maps:get(ChannelId, Channels) of @@ -116,11 +121,13 @@ on_batch_query(InstanceId, BatchRequest, _State = #{connector_state := Connector on_remove_channel( _InstanceId, #{channels := Channels, connector_state := ConnectorState} = State, ChannelId -) -> +) when is_map_key(ChannelId, Channels) -> ChannelConfig = maps:get(ChannelId, Channels), emqx_mysql:unprepare_sql(maps:merge(ChannelConfig, ConnectorState)), NewState = State#{channels => maps:remove(ChannelId, Channels)}, - {ok, NewState}. + {ok, NewState}; +on_remove_channel(_InstanceId, State, _ChannelId) -> + {ok, State}. -spec on_start(binary(), hocon:config()) -> {ok, #{connector_state := emqx_mysql:state(), channels := map()}} | {error, _}. @@ -148,3 +155,43 @@ set_prepares(ChannelConfig, ConnectorState) -> #{prepares := Prepares} = emqx_mysql:init_prepare(maps:merge(ConnectorState, ChannelConfig)), ChannelConfig#{prepares => Prepares}. + +validate_sql_type(ChannelId, ChannelConfig, #{query_templates := QueryTemplates}) -> + Batch = + case emqx_utils_maps:deep_get([resource_opts, batch_size], ChannelConfig) of + N when N > 1 -> batch; + _ -> single + end, + BatchKey = {ChannelId, batch}, + SingleKey = {ChannelId, prepstmt}, + case {QueryTemplates, Batch} of + {#{BatchKey := _}, batch} -> + ok; + {#{SingleKey := _}, single} -> + ok; + {_, batch} -> + %% try to provide helpful info + SQL = maps:get(sql, ChannelConfig), + Type = emqx_utils_sql:get_statement_type(SQL), + ErrorContext0 = #{ + reason => failed_to_prepare_statement, + statement_type => Type, + operation_type => Batch + }, + ErrorContext = emqx_utils_maps:put_if( + ErrorContext0, + hint, + <<"UPDATE statements are not supported for batch operations">>, + Type =:= update + ), + {error, ErrorContext}; + _ -> + SQL = maps:get(sql, ChannelConfig), + Type = emqx_utils_sql:get_statement_type(SQL), + ErrorContext = #{ + reason => failed_to_prepare_statement, + statement_type => Type, + operation_type => Batch + }, + {error, ErrorContext} + end. diff --git a/apps/emqx_bridge_mysql/test/emqx_bridge_mysql_SUITE.erl b/apps/emqx_bridge_mysql/test/emqx_bridge_mysql_SUITE.erl index 51cdc573f..96fcf6d24 100644 --- a/apps/emqx_bridge_mysql/test/emqx_bridge_mysql_SUITE.erl +++ b/apps/emqx_bridge_mysql/test/emqx_bridge_mysql_SUITE.erl @@ -31,6 +31,8 @@ -define(WORKER_POOL_SIZE, 4). +-define(ACTION_TYPE, mysql). + -import(emqx_common_test_helpers, [on_exit/1]). %%------------------------------------------------------------------------------ @@ -45,7 +47,14 @@ all() -> groups() -> TCs = emqx_common_test_helpers:all(?MODULE), - NonBatchCases = [t_write_timeout, t_uninitialized_prepared_statement], + NonBatchCases = [ + t_write_timeout, + t_uninitialized_prepared_statement, + t_non_batch_update_is_allowed + ], + OnlyBatchCases = [ + t_batch_update_is_forbidden + ], BatchingGroups = [ {group, with_batch}, {group, without_batch} @@ -57,7 +66,7 @@ groups() -> {async, BatchingGroups}, {sync, BatchingGroups}, {with_batch, TCs -- NonBatchCases}, - {without_batch, TCs} + {without_batch, TCs -- OnlyBatchCases} ]. init_per_group(tcp, Config) -> @@ -103,6 +112,8 @@ end_per_group(_Group, _Config) -> ok. init_per_suite(Config) -> + emqx_common_test_helpers:clear_screen(), + Config. end_per_suite(_Config) -> @@ -151,6 +162,9 @@ common_init(Config0) -> {mysql_config, MysqlConfig}, {mysql_bridge_type, BridgeType}, {mysql_name, Name}, + {bridge_type, BridgeType}, + {bridge_name, Name}, + {bridge_config, MysqlConfig}, {proxy_host, ProxyHost}, {proxy_port, ProxyPort} | Config0 @@ -874,3 +888,91 @@ t_nested_payload_template(Config) -> connect_and_get_payload(Config) ), ok. + +t_batch_update_is_forbidden(Config) -> + ?check_trace( + begin + Overrides = #{ + <<"sql">> => + << + "UPDATE mqtt_test " + "SET arrived = FROM_UNIXTIME(${timestamp}/1000) " + "WHERE payload = ${payload.value}" + >> + }, + ProbeRes = emqx_bridge_testlib:probe_bridge_api(Config, Overrides), + ?assertMatch({error, {{_, 400, _}, _, _Body}}, ProbeRes), + {error, {{_, 400, _}, _, ProbeBodyRaw}} = ProbeRes, + ?assertEqual( + match, + re:run( + ProbeBodyRaw, + <<"UPDATE statements are not supported for batch operations">>, + [global, {capture, none}] + ) + ), + CreateRes = emqx_bridge_testlib:create_bridge_api(Config, Overrides), + ?assertMatch( + {ok, {{_, 201, _}, _, #{<<"status">> := <<"disconnected">>}}}, + CreateRes + ), + {ok, {{_, 201, _}, _, #{<<"status_reason">> := Reason}}} = CreateRes, + ?assertEqual( + match, + re:run( + Reason, + <<"UPDATE statements are not supported for batch operations">>, + [global, {capture, none}] + ) + ), + ok + end, + [] + ), + ok. + +t_non_batch_update_is_allowed(Config) -> + ?check_trace( + begin + BridgeName = ?config(bridge_name, Config), + Overrides = #{ + <<"resource_opts">> => #{<<"metrics_flush_interval">> => <<"500ms">>}, + <<"sql">> => + << + "UPDATE mqtt_test " + "SET arrived = FROM_UNIXTIME(${timestamp}/1000) " + "WHERE payload = ${payload.value}" + >> + }, + ProbeRes = emqx_bridge_testlib:probe_bridge_api(Config, Overrides), + ?assertMatch({ok, {{_, 204, _}, _, _Body}}, ProbeRes), + ?assertMatch( + {ok, {{_, 201, _}, _, #{<<"status">> := <<"connected">>}}}, + emqx_bridge_testlib:create_bridge_api(Config, Overrides) + ), + {ok, #{ + <<"id">> := RuleId, + <<"from">> := [Topic] + }} = create_rule_and_action_http(Config), + Payload = emqx_utils_json:encode(#{value => <<"aaaa">>}), + Message = emqx_message:make(Topic, Payload), + {_, {ok, _}} = + ?wait_async_action( + emqx:publish(Message), + #{?snk_kind := mysql_connector_query_return}, + 10_000 + ), + ActionId = emqx_bridge_v2:id(?ACTION_TYPE, BridgeName), + ?assertEqual(1, emqx_resource_metrics:matched_get(ActionId)), + ?retry( + _Sleep0 = 200, + _Attempts0 = 10, + ?assertEqual(1, emqx_resource_metrics:success_get(ActionId)) + ), + + ?assertEqual(1, emqx_metrics_worker:get(rule_metrics, RuleId, 'actions.success')), + ok + end, + [] + ), + ok. diff --git a/apps/emqx_mysql/src/emqx_mysql.erl b/apps/emqx_mysql/src/emqx_mysql.erl index fcfabd61e..a9b132570 100644 --- a/apps/emqx_mysql/src/emqx_mysql.erl +++ b/apps/emqx_mysql/src/emqx_mysql.erl @@ -436,11 +436,11 @@ parse_batch_sql(Key, Query, Acc) -> end; select -> Acc; - Otherwise -> + Type -> ?SLOG(error, #{ msg => "invalid sql statement type", sql => Query, - type => Otherwise + type => Type }), Acc end. diff --git a/apps/emqx_utils/src/emqx_utils_sql.erl b/apps/emqx_utils/src/emqx_utils_sql.erl index 9ce9e576d..3fe6d67ec 100644 --- a/apps/emqx_utils/src/emqx_utils_sql.erl +++ b/apps/emqx_utils/src/emqx_utils_sql.erl @@ -28,7 +28,7 @@ -export_type([value/0]). --type statement_type() :: select | insert | delete. +-type statement_type() :: select | insert | delete | update. -type value() :: null | binary() | number() | boolean() | [value()]. -dialyzer({no_improper_lists, [escape_mysql/4, escape_prepend/4]}). @@ -38,6 +38,7 @@ get_statement_type(Query) -> KnownTypes = #{ <<"select">> => select, <<"insert">> => insert, + <<"update">> => update, <<"delete">> => delete }, case re:run(Query, <<"^\\s*([a-zA-Z]+)">>, [{capture, all_but_first, binary}]) of