diff --git a/src/ecpool.erl b/src/ecpool.erl index d1627a4f6..d62b84f3f 100644 --- a/src/ecpool.erl +++ b/src/ecpool.erl @@ -28,13 +28,17 @@ -export([pool_spec/4, start_pool/3, start_sup_pool/3, stop_sup_pool/1, get_client/1, get_client/2, with_client/2, with_client/3, + set_reconnect_callback/2, name/1, workers/1]). -type pool_type() :: random | hash | round_robin. +-type reconn_callback() :: {fun((pid()) -> term())}. + -type option() :: {pool_size, pos_integer()} | {pool_type, pool_type()} | {auto_reconnect, false | pos_integer()} + | {on_reconnect, reconn_callback()} | tuple(). pool_spec(ChildId, Pool, Mod, Opts) -> @@ -64,6 +68,12 @@ get_client(Pool) -> get_client(Pool, Key) -> gproc_pool:pick_worker(name(Pool), Key). +-spec(set_reconnect_callback(atom(), reconn_callback()) -> ok). +set_reconnect_callback(Pool, Callback) -> + [ecpool_worker:set_reconnect_callback(Worker, Callback) + || {_WorkerName, Worker} <- ecpool:workers(Pool)], + ok. + %% @doc Call the fun with client/connection -spec(with_client(atom(), fun((Client :: pid()) -> any())) -> any()). with_client(Pool, Fun) when is_atom(Pool) -> diff --git a/src/ecpool_worker.erl b/src/ecpool_worker.erl index feffc1332..d2294383f 100644 --- a/src/ecpool_worker.erl +++ b/src/ecpool_worker.erl @@ -29,13 +29,13 @@ -behaviour(gen_server). %% API Function Exports --export([start_link/4, client/1, is_connected/1]). +-export([start_link/4, client/1, is_connected/1, set_reconnect_callback/2]). %% gen_server Function Exports -export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]). --record(state, {pool, id, client, mod, opts}). +-record(state, {pool, id, client, mod, on_reconnect, opts}). %%%============================================================================= %%% Callback @@ -77,13 +77,18 @@ client(Pid) -> is_connected(Pid) -> gen_server:call(Pid, is_connected). +-spec(set_reconnect_callback(pid(), ecpool:reconn_callback()) -> ok). +set_reconnect_callback(Pid, OnReconnect) -> + gen_server:cast(Pid, {set_reconn_callbk, OnReconnect}). + %%%============================================================================= %%% gen_server callbacks %%%============================================================================= init([Pool, Id, Mod, Opts]) -> process_flag(trap_exit, true), - State = #state{pool = Pool, id = Id, mod = Mod, opts = Opts}, + State = #state{pool = Pool, id = Id, mod = Mod, opts = Opts, + on_reconnect = proplists:get_value(on_reconnect, Opts)}, case connect(State) of {ok, Client} when is_pid(Client) -> erlang:link(Client), @@ -102,6 +107,9 @@ handle_call(client, _From, State = #state{client = undefined}) -> handle_call(client, _From, State = #state{client = Client}) -> {reply, {ok, Client}, State}. +handle_cast({set_reconn_callbk, OnReconnect}, State) -> + {noreply, State#state{on_reconnect = OnReconnect}}; + handle_cast(_Msg, State) -> {noreply, State}. @@ -113,9 +121,10 @@ handle_info({'EXIT', Pid, Reason}, State = #state{client = Pid, opts = Opts}) -> reconnect(Secs, State) end; -handle_info(reconnect, State = #state{opts = Opts}) -> +handle_info(reconnect, State = #state{opts = Opts, on_reconnect = OnReconnect}) -> case catch connect(State) of {ok, Client} -> + handle_reconnect(Client, OnReconnect), {noreply, State#state{client = Client}}; {Err, _Reason} when Err =:= error orelse Err =:= 'EXIT' -> reconnect(proplists:get_value(auto_reconnect, Opts), State) @@ -152,3 +161,7 @@ reconnect(Secs, State) -> erlang:send_after(timer:seconds(Secs), self(), reconnect), {noreply, State#state{client = undefined}}. +handle_reconnect(_, undefined) -> + ok; +handle_reconnect(Client, OnReconnect) -> + OnReconnect(Client).