diff --git a/src/ecpool.erl b/src/ecpool.erl index abce78bfa..b929253cf 100644 --- a/src/ecpool.erl +++ b/src/ecpool.erl @@ -14,25 +14,20 @@ -module(ecpool). --export([pool_spec/4, - start_pool/3, - start_sup_pool/3, - stop_sup_pool/1, - get_client/1, - get_client/2, - with_client/2, - with_client/3, - workers/1 - ]). +-export([pool_spec/4, start_pool/3, start_sup_pool/3, stop_sup_pool/1, + get_client/1, get_client/2, with_client/2, with_client/3, + set_reconnect_callback/2, + name/1, workers/1]). --export([name/1]). +-type pool_type() :: random | hash | round_robin. --type(pool_name() :: atom()). --type(pool_type() :: random | hash | round_robin). --type(option() :: {pool_size, pos_integer()} +-type reconn_callback() :: {fun((pid()) -> term())}. + +-type option() :: {pool_size, pos_integer()} | {pool_type, pool_type()} | {auto_reconnect, false | pos_integer()} - | tuple()). + | {on_reconnect, reconn_callback()} + | tuple(). -export_type([pool_name/0, pool_type/0, @@ -70,6 +65,12 @@ get_client(Pool) -> get_client(Pool, Key) -> gproc_pool:pick_worker(name(Pool), Key). +-spec(set_reconnect_callback(atom(), reconn_callback()) -> ok). +set_reconnect_callback(Pool, Callback) -> + [ecpool_worker:set_reconnect_callback(Worker, Callback) + || {_WorkerName, Worker} <- ecpool:workers(Pool)], + ok. + %% @doc Call the fun with client/connection -spec(with_client(atom(), fun((Client :: pid()) -> any())) -> any()). with_client(Pool, Fun) when is_atom(Pool) -> diff --git a/src/ecpool_worker.erl b/src/ecpool_worker.erl index a3dc844b6..f107bb1d0 100644 --- a/src/ecpool_worker.erl +++ b/src/ecpool_worker.erl @@ -19,8 +19,7 @@ -export([start_link/4]). %% API Function Exports --export([client/1]). --export([is_connected/1]). +-export([start_link/4, client/1, is_connected/1, set_reconnect_callback/2]). %% gen_server Function Exports -export([init/1, @@ -31,7 +30,7 @@ code_change/3 ]). --record(state, {pool, id, client, mod, opts}). +-record(state, {pool, id, client, mod, on_reconnect, opts}). %%------------------------------------------------------------------------------ %% Callback @@ -73,17 +72,22 @@ client(Pid) -> is_connected(Pid) -> gen_server:call(Pid, is_connected, infinity). -%%------------------------------------------------------------------------------ -%% gen_server callbacks -%%------------------------------------------------------------------------------ +-spec(set_reconnect_callback(pid(), ecpool:reconn_callback()) -> ok). +set_reconnect_callback(Pid, OnReconnect) -> + gen_server:cast(Pid, {set_reconn_callbk, OnReconnect}). + +%%%============================================================================= +%%% gen_server callbacks +%%%============================================================================= init([Pool, Id, Mod, Opts]) -> process_flag(trap_exit, true), - State = #state{pool = Pool, id = Id, mod = Mod, opts = Opts}, + State = #state{pool = Pool, id = Id, mod = Mod, opts = Opts, + on_reconnect = proplists:get_value(on_reconnect, Opts)}, case connect(State) of - {ok, Client} -> - ok = maybe_apply(proplists:get_value(bind, Opts), self()), - true = gproc_pool:connect_worker(ecpool:name(Pool), {Pool, Id}), + {ok, Client} when is_pid(Client) -> + erlang:link(Client), + gproc_pool:connect_worker(ecpool:name(Pool), {Pool, Id}), {ok, State#state{client = Client}}; {error, Error} -> {stop, Error} @@ -103,8 +107,10 @@ handle_call(Req, _From, State) -> logger:error("[PoolWorker] unexpected call: ~p", [Req]), {reply, ignored, State}. -handle_cast(Msg, State) -> - logger:error("[PoolWorker] unexpected cast: ~p", [Msg]), +handle_cast({set_reconn_callbk, OnReconnect}, State) -> + {noreply, State#state{on_reconnect = OnReconnect}}; + +handle_cast(_Msg, State) -> {noreply, State}. handle_info({'EXIT', Pid, Reason}, State = #state{client = Pid, opts = Opts}) -> @@ -115,9 +121,10 @@ handle_info({'EXIT', Pid, Reason}, State = #state{client = Pid, opts = Opts}) -> reconnect_after(Secs, State) end; -handle_info(reconnect, State = #state{pool= Pool, opts = Opts}) -> - try connect(State) of +handle_info(reconnect, State = #state{opts = Opts, on_reconnect = OnReconnect}) -> + case catch connect(State) of {ok, Client} -> + handle_reconnect(Client, OnReconnect), {noreply, State#state{client = Client}}; {error, Reason} -> logger:error("[PoolWorker] ~p reconnect error: ~p", [Pool, Reason]), @@ -165,8 +172,7 @@ reconnect_after(Secs, State) -> _ = erlang:send_after(timer:seconds(Secs), self(), reconnect), {noreply, State#state{client = undefined}}. -maybe_apply(undefined, _Self) -> +handle_reconnect(_, undefined) -> ok; -maybe_apply(Fun, Self) -> - Fun(Self). - +handle_reconnect(Client, OnReconnect) -> + OnReconnect(Client).