Source code for generic_connection_pool.contrib.socket_async

"""
Asynchronous socket connection manager implementation.
"""

import asyncio
import errno
import socket
from ipaddress import IPv4Address, IPv6Address
from ssl import SSLContext
from typing import Generic, Tuple, Union

from generic_connection_pool.asyncio import BaseConnectionManager, EndpointT

IpAddress = Union[IPv4Address, IPv6Address]
Port = int
TcpEndpoint = Tuple[IpAddress, Port]


[docs]class SocketAlivenessCheckingMixin(Generic[EndpointT]): """ Nonblocking socket aliveness checking mix-in. """ async def check_aliveness(self, endpoint: EndpointT, conn: socket.socket) -> bool: try: if conn.recv(1, socket.MSG_PEEK) == b'': return False except BlockingIOError as exc: if exc.errno != errno.EAGAIN: raise except OSError: return False return True
[docs]class TcpSocketConnectionManager( SocketAlivenessCheckingMixin[TcpEndpoint], BaseConnectionManager[TcpEndpoint, socket.socket], ): """ TCP socket connection manager. """
[docs] async def create(self, endpoint: TcpEndpoint) -> socket.socket: loop = asyncio.get_running_loop() addr, port = endpoint if addr.version == 4: family = socket.AF_INET elif addr.version == 6: family = socket.AF_INET6 else: raise RuntimeError("unsupported address version type: %s", addr.version) sock = socket.socket(family=family, type=socket.SOCK_STREAM) sock.setblocking(False) await loop.sock_connect(sock, address=(str(addr), port)) return sock
[docs] async def dispose(self, endpoint: TcpEndpoint, conn: socket.socket) -> None: try: conn.shutdown(socket.SHUT_RDWR) except OSError: pass conn.close()
Hostname = str TcpStreamEndpoint = Tuple[Hostname, Port] Stream = Tuple[asyncio.StreamReader, asyncio.StreamWriter]
[docs]class StreamAlivenessCheckingMixin(Generic[EndpointT]): """ Asynchronous stream aliveness checking mix-in. """ async def check_aliveness(self, endpoint: EndpointT, conn: Stream) -> bool: reader, writer = conn try: await reader.read(0) except OSError: return False return not writer.is_closing() and not reader.at_eof()
[docs]class TcpStreamConnectionManager( StreamAlivenessCheckingMixin[TcpStreamEndpoint], BaseConnectionManager[TcpStreamEndpoint, Stream], ): """ TCP stream connection manager. """ def __init__(self, ssl: Union[None, bool, SSLContext] = None): self._ssl = ssl
[docs] async def create(self, endpoint: TcpStreamEndpoint) -> Stream: hostname, port = endpoint server_hostname = hostname if self._ssl is not None else None reader, writer = await asyncio.open_connection( hostname, port, server_hostname=server_hostname, ssl=self._ssl, ) return reader, writer
[docs] async def dispose(self, endpoint: TcpStreamEndpoint, conn: Stream) -> None: reader, writer = conn if writer.can_write_eof(): writer.write_eof() writer.close() await writer.wait_closed()