import abc
import contextlib
import logging
import math
import threading
from collections import defaultdict
from typing import Any, Callable, DefaultDict, Generator, Generic, Hashable, List, Optional, Tuple, TypeVar
from generic_connection_pool import exceptions
from generic_connection_pool.common import BaseConnectionPool, BaseEndpointPool, BaseEventQueue, ConnectionInfo
from generic_connection_pool.common import EventType, Timer
from .locks import SharedLock
# Logger named after the containing package (``__package__``) rather than after this module.
logger = logging.getLogger(__package__)
# Type of the value that identifies a connection target; it is used as a mapping key, hence hashable.
EndpointT = TypeVar('EndpointT', bound=Hashable)
# Type of the pooled connection object; it is part of the event queue keys, hence hashable.
ConnectionT = TypeVar('ConnectionT', bound=Hashable)


class BaseConnectionManager(Generic[EndpointT, ConnectionT], abc.ABC):
    """
    Abstract synchronous connection factory.

    Subclasses must implement :py:meth:`create` and :py:meth:`dispose`.
    The remaining methods are optional hooks with do-nothing defaults
    (:py:meth:`check_aliveness` reports every connection as alive).
    """

    @abc.abstractmethod
    def create(self, endpoint: EndpointT, timeout: Optional[float] = None) -> ConnectionT:
        """
        Creates a new connection.

        :param endpoint: endpoint to connect to
        :param timeout: operation timeout
        :return: new connection
        """

    @abc.abstractmethod
    def dispose(self, endpoint: EndpointT, conn: ConnectionT, timeout: Optional[float] = None) -> None:
        """
        Disposes the connection.

        :param endpoint: endpoint the connection belongs to
        :param conn: connection to be disposed
        :param timeout: operation timeout
        """

    def check_aliveness(self, endpoint: EndpointT, conn: ConnectionT, timeout: Optional[float] = None) -> bool:
        """
        Checks that the connection is alive.

        The default implementation performs no check and always returns ``True``.

        :param endpoint: endpoint the connection belongs to
        :param conn: connection to be checked
        :param timeout: operation timeout
        :return: ``True`` if connection is alive otherwise ``False``
        """

        return True

    def on_acquire(self, endpoint: EndpointT, conn: ConnectionT) -> None:
        """
        Callback invoked on connection acquire. Does nothing by default.

        :param endpoint: endpoint the connection belongs to
        :param conn: connection to be acquired
        """

    def on_release(self, endpoint: EndpointT, conn: ConnectionT) -> None:
        """
        Callback invoked on connection release. Does nothing by default.

        :param endpoint: endpoint the connection belongs to
        :param conn: connection to be released
        """

    def on_connection_dead(self, endpoint: EndpointT, conn: ConnectionT) -> None:
        """
        Callback invoked when the connection aliveness check failed. Does nothing by default.

        :param endpoint: endpoint the connection belongs to
        :param conn: dead connection
        """
# Type of the keys stored in the event queue; keys must be hashable.
KeyType = TypeVar('KeyType', bound=Hashable)
class EventQueue(BaseEventQueue[KeyType], Generic[KeyType]):
    """
    Thread-safe event queue wrapper.

    Every operation of the underlying queue is executed while holding one condition variable,
    which is also used to wake up the threads blocked in :py:meth:`wait`.
    """

    def __init__(self) -> None:
        super().__init__()

        self._lock = threading.Condition(threading.Lock())
        # once set, every current and future `wait` call fails with `ConnectionPoolClosedError`
        self._stopped = False

    def get_size(self) -> int:
        """
        Returns the number of queued events.
        """

        with self._lock:
            return len(self._queue)

    def insert(self, timestamp: float, key: KeyType) -> None:
        """
        Inserts an event and wakes up all waiters so that they re-evaluate the queue.

        :param timestamp: event timestamp
        :param key: event key
        """

        with self._lock:
            self._insert(timestamp, key)
            self._lock.notify_all()

    def remove(self, key: KeyType) -> None:
        """
        Removes the event from the queue.

        :param key: event key
        """

        with self._lock:
            self._remove(key)

    def clear(self) -> None:
        """
        Removes all events from the queue.
        """

        with self._lock:
            self._clear()

    def wait(self, timeout: Optional[float] = None) -> KeyType:
        """
        Waits for the next event. The event is not removed from the queue.

        :param timeout: number of seconds to wait, ``None`` means no limit
        :return: key of the next event
        :raises ConnectionPoolClosedError: if the queue has been stopped
        :raises TimeoutError: if no event became available in time
        """

        timer = Timer(timeout)

        with self._lock:
            while True:
                if self._stopped:
                    raise exceptions.ConnectionPoolClosedError

                # presumably `backoff` is the number of seconds until the top event is due
                # (`None` if there is nothing to wait for) - TODO confirm against the base class
                key, backoff = self._try_get_next_event()
                if key is not None:
                    return key
                if timer.timedout:
                    raise TimeoutError

                # Sleep for the shortest of the two limits. `None` means "no limit" and is told
                # apart from `0.0` explicitly, so a zero backoff can never become an unbounded wait.
                remains = timer.remains
                if backoff is None:
                    wait_timeout = remains
                elif remains is None:
                    wait_timeout = backoff
                else:
                    wait_timeout = min(backoff, remains)

                self._lock.wait(timeout=wait_timeout)

    def top(self) -> Optional[KeyType]:
        """
        Returns top event.
        """

        with self._lock:
            return self._top()

    def stop(self) -> None:
        """
        Notifies the subscribers that the process is stopped.
        """

        with self._lock:
            self._stopped = True
            self._lock.notify_all()
class EndpointPool(BaseEndpointPool[EndpointT, ConnectionT], Generic[EndpointT, ConnectionT]):
    """
    Thread-safe endpoint pool wrapper.

    Each public method runs the matching underscore-prefixed operation of the base class while
    holding the pool condition variable. Operations that change the pool content also notify
    one thread blocked in :py:meth:`try_acquire_or_reserve`.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

        self._condvar = threading.Condition(threading.Lock())

    @property
    def empty(self) -> bool:
        """
        ``True`` if the pool size is zero. The size is read without taking the pool lock.
        """

        return self._size() == 0

    def size(self) -> int:
        """
        Returns the pool size.
        """

        with self._condvar:
            return self._size()

    def has_available_slot(self) -> bool:
        """
        Locked version of ``_has_available_slot``.
        """

        with self._condvar:
            return self._has_available_slot()

    def is_overflowed(self) -> bool:
        """
        Locked version of ``_is_overflowed``.
        """

        with self._condvar:
            return self._is_overflowed()

    def get_size(self, acquired: Optional[bool] = None) -> int:
        """
        Locked version of ``_get_size``.

        :param acquired: forwarded to ``_get_size``
        """

        with self._condvar:
            return self._get_size(acquired)

    def reserve(self) -> bool:
        """
        Locked version of ``_reserve``.
        """

        with self._condvar:
            return self._reserve()

    def acquire(self) -> Tuple[Optional[ConnectionInfo[EndpointT, ConnectionT]], bool]:
        """
        Locked version of ``_acquire``. Never blocks waiting for a connection.
        """

        with self._condvar:
            return self._acquire()

    def release(self, conn: ConnectionT) -> Tuple[ConnectionInfo[EndpointT, ConnectionT], bool]:
        """
        Releases the connection and wakes up one waiter.

        :param conn: connection to be released
        :return: result of ``_release``
        """

        with self._condvar:
            released = self._release(conn)
            self._condvar.notify()

        return released

    def detach(self, conn: ConnectionT, acquired: bool = False) -> ConnectionInfo[EndpointT, ConnectionT]:
        """
        Detaches the connection from the pool and wakes up one waiter.

        :param conn: connection to be detached
        :param acquired: forwarded to ``_detach``
        :return: result of ``_detach``
        """

        with self._condvar:
            detached = self._detach(conn, acquired)
            self._condvar.notify()

        return detached

    def attach(self, conn_info: ConnectionInfo[EndpointT, ConnectionT], acquired: bool = False) -> None:
        """
        Attaches the connection to the pool and wakes up one waiter.

        :param conn_info: connection to be attached
        :param acquired: forwarded to ``_attach``
        """

        with self._condvar:
            self._attach(conn_info, acquired)
            self._condvar.notify()

    def acquire_and_detach(self) -> Optional[ConnectionInfo[EndpointT, ConnectionT]]:
        """
        Acquires a connection and detaches it from the pool as one atomic step.

        :return: detached connection or ``None`` if ``_acquire`` provided no connection
                 (no waiter is notified in that case)
        """

        with self._condvar:
            candidate, _ = self._acquire()
            if candidate is None:
                return None

            detached = self._detach(candidate.conn, acquired=True)
            self._condvar.notify()

        return detached

    def try_acquire_or_reserve(
        self,
        timeout: Optional[float] = None,
    ) -> Tuple[Optional[ConnectionInfo[EndpointT, ConnectionT]], bool]:
        """
        Blocks until a connection is acquired or a slot is reserved.

        :param timeout: number of seconds to wait, ``None`` means no limit
        :return: what ``_acquire`` returned if it provided a connection,
                 otherwise ``(None, False)`` once ``_reserve`` succeeded
        :raises TimeoutError: if the wait on the condition variable timed out
        """

        deadline = Timer(timeout)

        with self._condvar:
            while True:
                conn_info, extra = self._acquire()
                if conn_info is not None:
                    return conn_info, extra

                if self._reserve():
                    return None, False

                # neither worked: sleep until the pool content changes or the time is over
                notified = self._condvar.wait(deadline.remains)
                if not notified:
                    raise TimeoutError

    def attach_and_unreserve(self, conn_info: ConnectionInfo[EndpointT, ConnectionT], acquired: bool = False) -> None:
        """
        Cancels a reservation, attaches the connection and wakes up one waiter.

        :param conn_info: connection to be attached
        :param acquired: forwarded to ``_attach``
        """

        with self._condvar:
            self._unreserve()
            self._attach(conn_info, acquired)
            self._condvar.notify()

    def unreserve(self) -> None:
        """
        Cancels a reservation and wakes up one waiter.
        """

        with self._condvar:
            self._unreserve()
            self._condvar.notify()
class PoolManager(Generic[EndpointT, ConnectionT]):
    """
    Connection pool manager.

    Keeps one endpoint pool per endpoint, each paired with its own shared/exclusive lock,
    and provides an api to work with the pools safely.
    """

    def __init__(self, pool_factory: Callable[[], EndpointPool[EndpointT, ConnectionT]]) -> None:
        def new_entry() -> Tuple[SharedLock, EndpointPool[EndpointT, ConnectionT]]:
            return SharedLock(), pool_factory()

        # missing endpoints are created lazily on item access (see `acquired(setdefault=True)`)
        self._pools: DefaultDict[EndpointT, Tuple[SharedLock, EndpointPool[EndpointT, ConnectionT]]] = defaultdict(
            new_entry,
        )
        self._condvar = threading.Condition(lock=threading.Lock())

    def get_size(self) -> int:
        """
        Returns the sum of all endpoint pool sizes.
        """

        with self._condvar:
            return sum(pool.size() for _, pool in self._pools.values())

    def endpoints(self) -> List[EndpointT]:
        """
        Returns available endpoints.
        """

        with self._condvar:
            return list(self._pools.keys())

    def wait_for(self, predicate: Callable[[], bool], timeout: Optional[float] = None) -> bool:
        """
        Waits for the pool manager state change.

        :param predicate: condition to wait for, evaluated while the manager lock is held
        :param timeout: number of seconds to wait, ``None`` means no limit
        :return: the last value returned by the predicate
        """

        with self._condvar:
            return self._condvar.wait_for(predicate, timeout=timeout)

    @contextlib.contextmanager
    def acquired(
        self,
        endpoint: EndpointT,
        exclusive: bool = False,
        blocking: bool = True,
        timeout: Optional[float] = None,
        setdefault: bool = False,
    ) -> Generator[EndpointPool[EndpointT, ConnectionT], None, None]:
        """
        Opens the endpoint pool acquiring context.

        :param endpoint: pool endpoint
        :param exclusive: pool access mode (shared or exclusive)
        :param blocking: acquiring mode
        :param timeout: pool acquiring timeout
        :param setdefault: create a new pool if it does not exist
        :return: acquired pool
        :raises ConnectionPoolNotFound: if the pool does not exist and ``setdefault`` is ``False``
        :raises TimeoutError: if the pool lock has not been acquired
        """

        with self._condvar:
            if setdefault:
                entry = self._pools[endpoint]
            else:
                entry = self._pools.get(endpoint)

            if entry is None:
                raise exceptions.ConnectionPoolNotFound

            pool_lock, pool = entry
            # the pool lock is taken while the manager lock is still held, so `try_delete`
            # can't drop the pool between the lookup and the acquisition
            if not pool_lock.acquire(exclusive, blocking=blocking, timeout=timeout):
                raise TimeoutError

        try:
            yield pool
        finally:
            pool_lock.release(exclusive)
            # let one `wait_for` caller re-check its predicate
            with self._condvar:
                self._condvar.notify()

    def try_delete(self, endpoint: EndpointT) -> bool:
        """
        Tries to delete the endpoint pool.
        Acquires the pool in exclusive mode and checks that pool is empty.

        :param endpoint: pool endpoint
        :return: `True` if the pool has been deleted (or does not exist) otherwise `False`
        """

        with self._condvar:
            entry = self._pools.get(endpoint)
            if entry is None:
                return True

            pool_lock, pool = entry
            try:
                # non-blocking: a pool that is in use is simply left alone
                with pool_lock.acquired(exclusive=True, blocking=False):
                    if pool.empty:
                        self._pools.pop(endpoint)
                        return True

                    return False
            except TimeoutError:
                return False
class ConnectionPool(Generic[EndpointT, ConnectionT], BaseConnectionPool[EndpointT, ConnectionT]):
    """
    Synchronous connection pool.

    :param connection_manager: connection manager instance. Used to create, dispose or check connection aliveness.
    :param acquire_timeout: connection acquiring default timeout.
    :param background_collector: if ``True`` starts a background worker that disposes expired and idle connections
                                 maintaining requested pool state. If ``False`` the connections will be disposed
                                 on each connection release.
    :param dispose_batch_size: maximum number of expired and idle connections to be disposed on connection release
                               (if background collector is started the parameter is ignored).
                               ``0`` derives the number from the current pool size.
    :param dispose_timeout: connection disposal timeout.
    :param min_idle: minimum number of connections in each endpoint the pool tries to hold. Connections that exceed
                     that number will be considered as extra and disposed after ``idle_timeout`` seconds of inactivity.
    :param max_size: maximum number of endpoint connections.
    :param kwargs: see :py:class:`generic_connection_pool.common.BaseConnectionPool`
    """

    def __init__(
        self,
        connection_manager: BaseConnectionManager[EndpointT, ConnectionT],
        *,
        acquire_timeout: Optional[float] = None,
        background_collector: bool = False,
        dispose_batch_size: int = 0,
        dispose_timeout: Optional[float] = None,
        min_idle: int = 1,
        max_size: int = 10,
        **kwargs: Any,
    ) -> None:
        super().__init__(min_idle=min_idle, max_size=max_size, **kwargs)

        self._stopped = False
        self._acquire_timeout = acquire_timeout
        self._dispose_batch_size = dispose_batch_size
        self._dispose_timeout = dispose_timeout
        self._connection_manager = connection_manager

        # guards the total connection counter (`self._pool_size`)
        self._lock = threading.Lock()
        self._pools = PoolManager(
            pool_factory=lambda: EndpointPool[EndpointT, ConnectionT](
                max_pool_size=min_idle,
                max_extra_size=max_size - min_idle,
            ),
        )
        # disposal events of the connections that are currently not acquired
        self._event_queue = EventQueue[Tuple[EventType, EndpointT, ConnectionT]]()

        self._collector: Optional[threading.Thread] = None
        if background_collector:
            self._collector = threading.Thread(
                target=self._start_collector,
                name='gcp-collector',
            )
            self._collector.start()

    def get_size(self) -> int:
        """
        Returns the pool size as reported by the base class, read under the pool size lock.
        """

        with self._lock:
            return super().get_size()

    def get_endpoint_pool_size(self, endpoint: EndpointT, acquired: Optional[bool] = None) -> int:
        """
        Returns endpoint pool size.

        :param endpoint: pool endpoint
        :param acquired: if `True` returns the number of acquired connections,
                         if `False` returns the number of free connections
                         otherwise returns total size
        :return: endpoint pool size or ``0`` if the endpoint has no pool
        """

        try:
            with self._pools.acquired(endpoint) as pool:
                return pool.get_size(acquired)
        except exceptions.ConnectionPoolNotFound:
            return 0

    @contextlib.contextmanager
    def connection(self, endpoint: EndpointT, timeout: Optional[float] = None) -> Generator[ConnectionT, None, None]:
        """
        Acquires a connection from the pool and releases it on context exit.

        :param endpoint: connection endpoint
        :param timeout: number of seconds to wait. If timeout is reached :py:class:`TimeoutError` is raised.
        :return: acquired connection
        """

        conn = self.acquire(endpoint, timeout=timeout)
        try:
            yield conn
        finally:
            self.release(conn, endpoint)

    def acquire(self, endpoint: EndpointT, timeout: Optional[float] = None) -> ConnectionT:
        """
        Acquires a connection from the pool.

        :param endpoint: connection endpoint
        :param timeout: number of seconds to wait. If timeout is reached :py:class:`TimeoutError` is raised.
                        If ``None`` the pool ``acquire_timeout`` is used.
        :return: acquired connection
        """

        timeout = self._acquire_timeout if timeout is None else timeout

        conn = self._acquire_connection(endpoint, timeout=timeout)

        try:
            self._connection_manager.on_acquire(endpoint, conn)
        except Exception:
            # the caller never gets the connection, so give it back to the pool
            self._release_connection(endpoint, conn)
            raise

        return conn

    def release(self, conn: ConnectionT, endpoint: EndpointT) -> None:
        """
        Releases a connection.

        The connection is returned to the pool even if the ``on_release`` callback fails.
        If no background collector is running, a batch of disposable connections is collected afterwards.

        :param conn: connection to be released
        :param endpoint: connection endpoint
        """

        try:
            self._connection_manager.on_release(endpoint, conn)
        finally:
            self._release_connection(endpoint, conn)

        if self._collector is None:
            # no explicit batch size: let the batch grow logarithmically with the pool size
            dispose_batch_size = self._dispose_batch_size or int(math.log2(self._pool_size + 1)) + 1
            self._collect_disposable_connections(dispose_batch_size)

    def close(self, timeout: Optional[float] = None) -> None:
        """
        Closes the connection pool.

        :param timeout: time budget for closing the pool; it is shared between joining the collector thread
                        and closing the connections
        """

        timer = Timer(timeout)

        self._stopped = True
        # wakes up the collector (if any), its `wait` call fails with `ConnectionPoolClosedError`
        self._event_queue.stop()
        if self._collector is not None:
            self._collector.join(timeout=timer.remains)

        self._close_connections(timeout=timer.remains)
        self._event_queue.clear()

    def _acquire_connection(self, endpoint: EndpointT, timeout: Optional[float]) -> ConnectionT:
        """
        Takes a free connection of the endpoint or creates a new one.

        Connections that fail the aliveness check are dropped from the pool and the procedure is repeated.

        :param endpoint: connection endpoint
        :param timeout: number of seconds to wait, ``None`` means no limit
        :return: acquired connection
        :raises ConnectionPoolClosedError: if the pool has been closed
        """

        timer = Timer(timeout)

        while True:
            if self._stopped:
                raise exceptions.ConnectionPoolClosedError

            with self._pools.acquired(endpoint, setdefault=True) as pool:
                conn_info, extra = pool.try_acquire_or_reserve(timeout=timer.remains)
                if conn_info is not None:
                    # unsubscribe the connection since acquired connection can't be disposed
                    self._event_queue.remove((EventType.LIFETIME, conn_info.endpoint, conn_info.conn))
                    if extra:
                        self._event_queue.remove((EventType.IDLETIME, conn_info.endpoint, conn_info.conn))

                    try:
                        is_alive = self._connection_manager.check_aliveness(endpoint, conn_info.conn, timer.remains)
                    except Exception:
                        self._release_connection(endpoint, conn_info.conn)
                        raise

                    if not is_alive:
                        # drop the dead connection and try again
                        pool.detach(conn_info.conn, acquired=True)
                        self._decrease_pool_size()
                        self._connection_manager.on_connection_dead(endpoint, conn_info.conn)
                        continue
                else:
                    # no free connection, but a slot has been reserved for a new one
                    try:
                        if conn_info := self._create_connection(endpoint, timer.remains):
                            pool.attach_and_unreserve(conn_info, acquired=True)
                        else:
                            pool.unreserve()
                            continue
                    except Exception:
                        pool.unreserve()
                        raise

                    logger.debug("connection created: %s", endpoint)

                return conn_info.conn

    def _create_connection(
        self,
        endpoint: EndpointT,
        timeout: Optional[float] = None,
    ) -> ConnectionInfo[EndpointT, ConnectionT]:
        """
        Creates a new connection, freeing a slot first if the whole pool is full.

        :param endpoint: endpoint to connect to
        :param timeout: number of seconds to wait, ``None`` means no limit
        :return: information about the created connection
        """

        timer = Timer(timeout)

        while True:
            if self._increase_pool_size():
                try:
                    conn = self._connection_manager.create(endpoint, timeout=timer.remains)
                except Exception:
                    # give the slot taken by `_increase_pool_size` back
                    self._decrease_pool_size()
                    raise

                return ConnectionInfo(endpoint, conn)
            else:
                self._try_free_slot(timeout=timer.remains)

    def _try_free_slot(self, timeout: Optional[float] = None) -> bool:
        """
        Tries to make room for one more connection.

        If the event queue has a top event, the connection it refers to is detached and disposed.
        Otherwise the method waits until the pool is not full any more.

        :param timeout: number of seconds to wait, ``None`` means no limit
        :return: result of the detachment or ``True`` if the wait succeeded
        :raises TimeoutError: if the pool stayed full until the timeout
        """

        timer = Timer(timeout)

        if event := self._event_queue.top():
            _, endpoint, conn = event
            return self._try_detach_connection(endpoint, conn, timeout=timer.remains)
        elif not self._pools.wait_for(predicate=lambda: not self.is_full, timeout=timer.remains):
            raise TimeoutError

        return True

    def _collect_disposable_connections(self, max_disposals: int) -> None:
        """
        Disposes the connections whose events are ready, without blocking.

        :param max_disposals: maximum number of connections to dispose
        """

        disposals = 0

        while disposals < max_disposals:
            try:
                _, endpoint, conn = self._event_queue.wait(timeout=0.0)
            except TimeoutError:
                # no connections to dispose
                break

            if self._try_detach_connection(endpoint, conn):
                disposals += 1

        if disposals > 0:
            logger.debug("disposed %d connections", disposals)

    def _start_collector(self) -> None:
        """
        Background collector loop: disposes connections as their events arrive until the pool is stopped.
        """

        logger.debug("collector started")

        while not self._stopped:
            try:
                _, endpoint, conn = self._event_queue.wait()
                self._try_detach_connection(endpoint, conn)
            except exceptions.ConnectionPoolClosedError:
                break

    def _try_detach_connection(self, endpoint: EndpointT, conn: ConnectionT, timeout: Optional[float] = None) -> bool:
        """
        Detaches the connection from its endpoint pool and disposes it.

        The connection events are removed from the queue whether the detachment succeeded or not.
        An endpoint pool that became empty is deleted.

        :param endpoint: connection endpoint
        :param conn: connection to be detached
        :param timeout: endpoint pool acquiring timeout
        :return: ``False`` if the pool rejected the detachment with :py:class:`KeyError` otherwise ``True``
                 (including the case when the endpoint pool does not exist)
        """

        with contextlib.suppress(exceptions.ConnectionPoolNotFound):
            with self._pools.acquired(endpoint, timeout=timeout) as pool:
                try:
                    conn_info = pool.detach(conn)
                except KeyError:
                    return False
                finally:
                    self._event_queue.remove((EventType.LIFETIME, endpoint, conn))
                    self._event_queue.remove((EventType.IDLETIME, endpoint, conn))

                self._dispose_connection(conn_info, timeout=self._dispose_timeout)
                self._decrease_pool_size()

                is_pool_empty = pool.empty

            # the deletion needs exclusive access, so it is attempted once the pool is not held any more
            if is_pool_empty:
                self._pools.try_delete(endpoint)

        return True

    def _dispose_connection(self, conn_info: ConnectionInfo[EndpointT, ConnectionT], timeout: Optional[float]) -> bool:
        """
        Disposes the connection. Failures are logged, never raised.

        :param conn_info: connection to be disposed
        :param timeout: disposal timeout
        :return: ``True`` if the connection manager disposed the connection otherwise ``False``
        """

        try:
            self._connection_manager.dispose(conn_info.endpoint, conn_info.conn, timeout=timeout)
        except TimeoutError:
            logger.error("connection disposal timed-out: %s", conn_info.endpoint)
            return False
        except Exception as e:
            logger.error("connection disposal failed: %s", e)
            return False

        logger.debug("connection disposed: %s", conn_info.endpoint)

        return True

    def _release_connection(self, endpoint: EndpointT, conn: ConnectionT) -> None:
        """
        Returns the connection to its endpoint pool and schedules its disposal events.

        :param endpoint: connection endpoint
        :param conn: connection to be released
        :raises RuntimeError: if the endpoint pool rejected the release with :py:class:`KeyError`
        """

        with self._pools.acquired(endpoint) as pool:
            try:
                conn_info, extra = pool.release(conn)
            except KeyError as err:
                raise RuntimeError("connection not acquired") from err

            # subscribe the connection, it is disposable again
            self._event_queue.insert(
                conn_info.created_at + self.max_lifetime,
                (EventType.LIFETIME, conn_info.endpoint, conn_info.conn),
            )
            if extra:
                self._event_queue.insert(
                    conn_info.accessed_at + self._idle_timeout,
                    (EventType.IDLETIME, conn_info.endpoint, conn_info.conn),
                )

    def _close_connections(self, timeout: Optional[float]) -> None:
        """
        Disposes connections endpoint by endpoint until the pool manager reports zero size.

        :param timeout: time budget used for acquiring the endpoint pools and disposing the connections
        """

        timer = Timer(timeout)

        while self._pools.get_size() != 0:
            for endpoint in self._pools.endpoints():
                with contextlib.suppress(exceptions.ConnectionPoolNotFound):
                    with self._pools.acquired(endpoint, timeout=timer.remains) as pool:
                        conn_info = pool.acquire_and_detach()
                        if conn_info is None:
                            # nothing could be taken from this pool right now
                            continue

                        self._dispose_connection(conn_info, timeout=timer.remains)
                        self._decrease_pool_size()

                        is_pool_empty = pool.empty

                    # the deletion needs exclusive access, so it is attempted once the pool is not held any more
                    if is_pool_empty:
                        self._pools.try_delete(endpoint)

    def _increase_pool_size(self) -> bool:
        """
        Takes one slot of the whole pool.

        :return: ``False`` if the pool is full otherwise ``True``
        """

        with self._lock:
            if self.is_full:
                return False
            else:
                self._pool_size += 1
                return True

    def _decrease_pool_size(self) -> None:
        """
        Gives one slot of the whole pool back.
        """

        with self._lock:
            self._pool_size -= 1