emqx/apps/emqx_auth_mysql/src/emqx_auth_mysql_cli.erl

92 lines
3.4 KiB
Erlang

%%--------------------------------------------------------------------
%% Copyright (c) 2020-2023 EMQ Technologies Co., Ltd. All Rights Reserved.
%%
%% Licensed under the Apache License, Version 2.0 (the "License");
%% you may not use this file except in compliance with the License.
%% You may obtain a copy of the License at
%%
%% http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing, software
%% distributed under the License is distributed on an "AS IS" BASIS,
%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
%% See the License for the specific language governing permissions and
%% limitations under the License.
%%--------------------------------------------------------------------
-module(emqx_auth_mysql_cli).
-behaviour(ecpool_worker).
-include("emqx_auth_mysql.hrl").
-include_lib("emqx/include/emqx.hrl").
-include_lib("emqx/include/logger.hrl").
-export([ parse_query/1
, connect/1
, query/4
]).
%%--------------------------------------------------------------------
%% Avoid SQL Injection: Parse SQL to Parameter Query.
%%--------------------------------------------------------------------
%% Rewrite a configured SQL template into a parameterized query.
%% Every quoted placeholder '%u', '%c', '%C', '%a' or '%d' is replaced by `?'
%% and collected, in order of appearance, into the returned parameter list.
%% A template with no placeholders is returned untouched with an empty list;
%% `undefined' (no query configured) passes straight through.
parse_query(undefined) ->
    undefined;
parse_query(Sql) ->
    Placeholder = "'%[ucCad]'",
    case re:run(Sql, Placeholder, [global, {capture, all, list}]) of
        nomatch ->
            {Sql, []};
        {match, Captures} ->
            %% The pattern has no sub-groups, so each capture is [WholeMatch].
            Params = lists:map(fun([Whole]) -> Whole end, Captures),
            Rewritten = re:replace(Sql, Placeholder, "?", [global, {return, list}]),
            {Rewritten, Params}
    end.
%%--------------------------------------------------------------------
%% MySQL Connect/Query
%%--------------------------------------------------------------------
%% ecpool_worker callback: start a MySQL connection process with Options.
%% Returns {ok, Pid} on success. An `ignore' from mysql:start_link/1 is
%% reported as {error, ignore}. Any other failure is logged and the original
%% {error, Reason} is handed back unchanged.
connect(Options) ->
    Result = mysql:start_link(Options),
    case Result of
        {ok, _Pid} ->
            Result;
        ignore ->
            {error, ignore};
        {error, {{_, {error, econnrefused}}, _}} ->
            ?LOG(error, "[MySQL] Can't connect to MySQL server: Connection refused."),
            Result;
        {error, {ErrorCode, _, Error}} ->
            %% Presumably a server-side error: code plus detail — TODO confirm.
            ?LOG_SENSITIVE(error, "[MySQL] Can't connect to MySQL server: ~p - ~p", [ErrorCode, Error]),
            Result;
        {error, Reason} ->
            ?LOG_SENSITIVE(error, "[MySQL] Can't connect to MySQL server: ~p", [Reason]),
            Result
    end.
%% Execute Sql on a connection checked out from Pool. Params (as returned by
%% parse_query/1) are bound against ClientInfo just before the query runs.
query(Pool, Sql, Params, ClientInfo) ->
    Run = fun(Conn) ->
              Bound = replvar(Params, ClientInfo),
              mysql:query(Conn, Sql, Bound)
          end,
    ecpool:with_client(Pool, Run).
%% Turn the placeholder list from parse_query/1 into concrete parameter
%% values taken from ClientInfo (delegates to replvar/3 with an empty acc).
replvar(Placeholders, Info) ->
    replvar(Placeholders, Info, []).
%% Replace each placeholder collected by parse_query/1 with the matching
%% value from ClientInfo. Order is preserved, so the result lines up with
%% the `?' markers in the rewritten SQL:
%%   '%u' -> username   '%c' -> clientid   '%a' -> peerhost (as a string)
%%   '%C' -> cn         '%d' -> dn
%% A missing field binds the text "undefined" for every placeholder. The raw
%% placeholder text (e.g. "'%c'") is never sent to MySQL as a parameter
%% value. Any entry that is not a known placeholder is passed through as-is.
replvar([], _ClientInfo, Acc) ->
    lists:reverse(Acc);
replvar(["'%u'" | Params], ClientInfo, Acc) ->
    replvar(Params, ClientInfo, [safe_get(username, ClientInfo) | Acc]);
replvar(["'%c'" | Params], ClientInfo, Acc) ->
    replvar(Params, ClientInfo, [safe_get(clientid, ClientInfo) | Acc]);
replvar(["'%a'" | Params], ClientInfo, Acc) ->
    replvar(Params, ClientInfo, [peerhost_str(ClientInfo) | Acc]);
replvar(["'%C'" | Params], ClientInfo, Acc) ->
    replvar(Params, ClientInfo, [safe_get(cn, ClientInfo) | Acc]);
replvar(["'%d'" | Params], ClientInfo, Acc) ->
    replvar(Params, ClientInfo, [safe_get(dn, ClientInfo) | Acc]);
replvar([Param | Params], ClientInfo, Acc) ->
    replvar(Params, ClientInfo, [Param | Acc]).

%% Render ClientInfo's peer IP address as a string. Returns "undefined" when
%% the key is absent or holds a non-tuple value such as `undefined', which
%% matches safe_get/2's fallback and avoids passing an error tuple to MySQL.
peerhost_str(ClientInfo) ->
    case maps:get(peerhost, ClientInfo, undefined) of
        IpAddr when is_tuple(IpAddr) -> inet_parse:ntoa(IpAddr);
        _Other -> "undefined"
    end.
%% Look up Key in ClientInfo and normalise it with bin/1. An absent key
%% yields the string "undefined" (run through bin/1 like any other value).
safe_get(Key, Info) ->
    case maps:find(Key, Info) of
        {ok, Value} -> bin(Value);
        error -> bin("undefined")
    end.
%% Normalise a value for use as a query parameter: atoms become UTF-8
%% binaries; binaries and every other term are returned unchanged.
bin(Value) when is_atom(Value) -> atom_to_binary(Value, utf8);
bin(Value) -> Value.