%%--------------------------------------------------------------------
%% Copyright (c) 2020-2023 EMQ Technologies Co., Ltd. All Rights Reserved.
%%
%% Licensed under the Apache License, Version 2.0 (the "License");
%% you may not use this file except in compliance with the License.
%% You may obtain a copy of the License at
%%
%%     http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing, software
%% distributed under the License is distributed on an "AS IS" BASIS,
%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
%% See the License for the specific language governing permissions and
%% limitations under the License.
%%--------------------------------------------------------------------

-module(emqx_auth_mysql_cli).

-behaviour(ecpool_worker).

-include("emqx_auth_mysql.hrl").

-include_lib("emqx/include/emqx.hrl").
-include_lib("emqx/include/logger.hrl").

-export([ parse_query/1
        , connect/1
        , query/4
        ]).

%%--------------------------------------------------------------------
%% Avoid SQL Injection: Parse SQL to Parameter Query.
%%--------------------------------------------------------------------

%% @doc Turn a SQL template into a parameterized query.
%%
%% Every quoted placeholder `'%u'', `'%c'', `'%a'', `'%C'' or `'%d'' is
%% replaced by `?'. The placeholders are returned, in order of appearance,
%% so the caller can bind real values to them later.
%%
%% Returns `undefined' for `undefined'. Returns `{Sql, []}' (the input
%% untouched) when no placeholder is present. Otherwise returns
%% `{ParamSql, Placeholders}', where both elements are strings/lists.
parse_query(undefined) ->
    undefined;
parse_query(Sql) ->
    %% Single source of truth for the placeholder pattern, so that the
    %% capture and the substitution below cannot drift apart.
    Pattern = "'%[ucCad]'",
    case re:run(Sql, Pattern, [global, {capture, all, list}]) of
        nomatch ->
            {Sql, []};
        {match, Matches} ->
            %% The pattern has no subgroups, so each match is a one-element
            %% list holding the whole placeholder text.
            Placeholders = lists:map(fun([Whole]) -> Whole end, Matches),
            ParamSql = re:replace(Sql, Pattern, "?", [global, {return, list}]),
            {ParamSql, Placeholders}
    end.
%%--------------------------------------------------------------------
%% MySQL Connect/Query
%%--------------------------------------------------------------------

%% @doc Start a linked MySQL connection with `Options'.
%%
%% Presumably this is the `ecpool_worker' connect callback, since this
%% module declares that behaviour.
%%
%% Returns `{ok, Pid}' on success. `ignore' from mysql:start_link/1 is mapped
%% to `{error, ignore}'. Every `{error, Reason}' is logged and then returned
%% unchanged.
connect(Options) ->
    case mysql:start_link(Options) of
        {ok, Pid} ->
            {ok, Pid};
        ignore ->
            {error, ignore};
        {error, Reason = {{_, {error, econnrefused}}, _}} ->
            ?LOG(error, "[MySQL] Can't connect to MySQL server: Connection refused."),
            {error, Reason};
        {error, Reason = {ErrorCode, _, Error}} ->
            %% NOTE(review): ?LOG_SENSITIVE is presumably used because these
            %% terms may carry connection details — confirm in logger.hrl.
            ?LOG_SENSITIVE(error, "[MySQL] Can't connect to MySQL server: ~p - ~p",
                           [ErrorCode, Error]),
            {error, Reason};
        {error, Reason} ->
            ?LOG_SENSITIVE(error, "[MySQL] Can't connect to MySQL server: ~p", [Reason]),
            {error, Reason}
    end.

%% @doc Run `Sql' as a parameterized query on a connection from `Pool'.
%%
%% `Params' is the placeholder list produced by parse_query/1. Each entry is
%% replaced by the matching value from `ClientInfo' before it is bound.
query(Pool, Sql, Params, ClientInfo) ->
    ecpool:with_client(
      Pool,
      fun(C) -> mysql:query(C, Sql, replvar(Params, ClientInfo)) end).

%% Replace placeholder strings with the values to bind, preserving order:
%%   '%u' -> username     '%c' -> clientid     '%a' -> peerhost (as text)
%%   '%C' -> cn           '%d' -> dn
%% A missing value becomes "undefined" for every placeholder. Any entry that
%% is not a known placeholder is passed through unchanged.
replvar(Params, ClientInfo) ->
    replvar(Params, ClientInfo, []).

replvar([], _ClientInfo, Acc) ->
    lists:reverse(Acc);
replvar(["'%u'" | Params], ClientInfo, Acc) ->
    replvar(Params, ClientInfo, [safe_get(username, ClientInfo) | Acc]);
replvar(["'%c'" | Params], ClientInfo, Acc) ->
    %% Resolve through safe_get/2 like the other placeholders. A map-pattern
    %% match here would let a missing clientid fall through to the catch-all
    %% clause below, binding the literal text "'%c'" as the query parameter.
    replvar(Params, ClientInfo, [safe_get(clientid, ClientInfo) | Acc]);
replvar(["'%a'" | Params], ClientInfo, Acc) ->
    replvar(Params, ClientInfo, [peerhost(ClientInfo) | Acc]);
replvar(["'%C'" | Params], ClientInfo, Acc) ->
    replvar(Params, ClientInfo, [safe_get(cn, ClientInfo) | Acc]);
replvar(["'%d'" | Params], ClientInfo, Acc) ->
    replvar(Params, ClientInfo, [safe_get(dn, ClientInfo) | Acc]);
replvar([Param | Params], ClientInfo, Acc) ->
    replvar(Params, ClientInfo, [Param | Acc]).

%% Peer address rendered as text with inet_parse:ntoa/1. Falls back to
%% "undefined" (the same default as safe_get/2) when the key is absent or
%% holds something other than an address tuple (e.g. the atom undefined).
peerhost(#{peerhost := IpAddr}) when is_tuple(IpAddr) ->
    inet_parse:ntoa(IpAddr);
peerhost(_ClientInfo) ->
    "undefined".

%% Fetch `K' from ClientInfo, defaulting to the string "undefined". Atom
%% values (such as an unset `undefined') are converted to binaries by bin/1.
safe_get(K, ClientInfo) ->
    bin(maps:get(K, ClientInfo, "undefined")).

%% Atoms become UTF-8 binaries; binaries and anything else pass through.
bin(A) when is_atom(A) -> atom_to_binary(A, utf8);
bin(B) when is_binary(B) -> B;
bin(X) -> X.