Source code for generic_connection_pool.contrib.asyncpg
"""
Postgres asyncpg connection manager implementation.
"""
from typing import Generic, Mapping, Optional, TypeVar
import asyncpg
from generic_connection_pool.asyncio import BaseConnectionManager
DbEndpoint = str
Connection = asyncpg.Connection
DsnParameters = Mapping[DbEndpoint, Mapping[str, str]]
RecordT = TypeVar('RecordT', bound=asyncpg.Record)
[docs]
class DbConnectionManager(BaseConnectionManager[DbEndpoint, 'Connection[RecordT]'], Generic[RecordT]):
"""
Psycopg2 based postgres connection manager.
:param dsn_params: databases dsn parameters
"""
def __init__(self, dsn_params: DsnParameters):
self._dsn_params = dsn_params
[docs]
async def create(
self,
endpoint: DbEndpoint,
timeout: Optional[float] = None,
) -> 'Connection[RecordT]':
return await asyncpg.connect(**self._dsn_params[endpoint]) # type: ignore[call-overload]
[docs]
async def dispose(
self,
endpoint: DbEndpoint,
conn: 'Connection[RecordT]',
timeout: Optional[float] = None,
) -> None:
await conn.close(timeout=timeout)
[docs]
async def check_aliveness(
self,
endpoint: DbEndpoint,
conn: 'Connection[RecordT]',
timeout: Optional[float] = None,
) -> bool:
return conn.is_closed()
[docs]
async def on_acquire(self, endpoint: DbEndpoint, conn: 'Connection[RecordT]') -> None:
await self._rollback_uncommitted(conn)
[docs]
async def on_release(self, endpoint: DbEndpoint, conn: 'Connection[RecordT]') -> None:
await self._rollback_uncommitted(conn)
async def _rollback_uncommitted(self, conn: 'Connection[RecordT]') -> None:
if conn.is_in_transaction():
await conn.execute('ROLLBACK')