Source code for generic_connection_pool.contrib.socket_async
"""
Asynchronous socket connection manager implementation.
"""
import asyncio
import errno
import socket
from ipaddress import IPv4Address, IPv6Address
from ssl import SSLContext
from typing import Generic, Tuple, Union
from generic_connection_pool.asyncio import BaseConnectionManager, EndpointT
# Port number component of an endpoint.
Port = int
# Either an IPv4 or an IPv6 address object from the ``ipaddress`` module.
IpAddress = Union[IPv4Address, IPv6Address]
# ``(address, port)`` pair consumed by ``TcpSocketConnectionManager``.
TcpEndpoint = Tuple[IpAddress, Port]
[docs]
class SocketAlivenessCheckingMixin(Generic[EndpointT]):
    """
    Nonblocking socket aliveness checking mix-in.

    Probes the socket with a one-byte ``MSG_PEEK`` receive, which does not
    consume data. The socket is expected to be in non-blocking mode; on a
    blocking socket the peek would wait until data arrives.
    """

    async def check_aliveness(self, endpoint: EndpointT, conn: socket.socket) -> bool:
        """
        Check whether ``conn`` still looks connected.

        :param endpoint: connection endpoint (not used by this check)
        :param conn: non-blocking socket to probe
        :return: ``False`` if the peer closed the connection (the peek returns
                 zero bytes) or the socket reports an ``OSError``;
                 ``True`` if data is pending or the socket is idle
        """
        try:
            # MSG_PEEK leaves any pending byte in the receive buffer.
            data = conn.recv(1, socket.MSG_PEEK)
        except BlockingIOError:
            # Nothing to read yet: the connection is open but idle. Every
            # BlockingIOError means "operation would block"; errno is not
            # compared to EAGAIN because EWOULDBLOCK differs from it on some
            # platforms.
            return True
        except OSError:
            return False
        # A zero-byte read means the peer shut down its sending side.
        return data != b''
[docs]
class TcpSocketConnectionManager(
    SocketAlivenessCheckingMixin[TcpEndpoint],
    BaseConnectionManager[TcpEndpoint, socket.socket],
):
    """
    TCP socket connection manager.

    Creates non-blocking ``SOCK_STREAM`` sockets connected to an
    ``(ip_address, port)`` endpoint and shuts them down on disposal.
    """

    async def create(self, endpoint: TcpEndpoint) -> socket.socket:
        """
        Open a non-blocking TCP connection to ``endpoint``.

        :param endpoint: ``(address, port)`` pair; ``address`` must expose a
                         ``version`` of 4 or 6 (IPv4Address / IPv6Address)
        :return: a connected socket in non-blocking mode
        :raises RuntimeError: if the address version is neither 4 nor 6
        """
        loop = asyncio.get_running_loop()
        addr, port = endpoint
        if addr.version == 4:
            family = socket.AF_INET
        elif addr.version == 6:
            family = socket.AF_INET6
        else:
            # Format the message explicitly; RuntimeError does not apply
            # %-style arguments itself.
            raise RuntimeError("unsupported address version type: %s" % addr.version)

        sock = socket.socket(family=family, type=socket.SOCK_STREAM)
        try:
            sock.setblocking(False)
            await loop.sock_connect(sock, address=(str(addr), port))
        except BaseException:
            # BaseException also covers task cancellation, so the socket is
            # closed on every failure path before the error propagates.
            sock.close()
            raise
        return sock

    async def dispose(self, endpoint: TcpEndpoint, conn: socket.socket) -> None:
        """
        Shut down both directions of ``conn`` and close it.

        An ``OSError`` from ``shutdown`` (e.g. when the socket is no longer
        connected) is ignored so that ``close`` always runs.
        """
        try:
            conn.shutdown(socket.SHUT_RDWR)
        except OSError:
            pass
        conn.close()
# The (reader, writer) pair returned by ``asyncio.open_connection``.
Stream = Tuple[asyncio.StreamReader, asyncio.StreamWriter]
# Host name given as a plain string.
Hostname = str
# ``(hostname, port)`` pair consumed by ``TcpStreamConnectionManager``.
TcpStreamEndpoint = Tuple[Hostname, Port]
[docs]
class StreamAlivenessCheckingMixin(Generic[EndpointT]):
    """
    Asynchronous stream aliveness checking mix-in.

    Works on a ``(StreamReader, StreamWriter)`` pair.
    """

    async def check_aliveness(self, endpoint: EndpointT, conn: Stream) -> bool:
        """
        Report whether the ``(reader, writer)`` stream still appears usable.

        :param endpoint: connection endpoint (not used by this check)
        :param conn: ``(reader, writer)`` pair to inspect
        :return: ``False`` if the zero-length read raises ``OSError``, the
                 writer is closing, or the reader is at EOF; ``True`` otherwise
        """
        stream_reader, stream_writer = conn
        try:
            # A zero-length read consumes nothing; presumably its purpose is
            # to surface an error already recorded on the reader (e.g. after a
            # lost connection) -- TODO confirm against the asyncio version in use.
            await stream_reader.read(0)
        except OSError:
            return False

        if stream_writer.is_closing():
            return False
        return not stream_reader.at_eof()
[docs]
class TcpStreamConnectionManager(
    StreamAlivenessCheckingMixin[TcpStreamEndpoint],
    BaseConnectionManager[TcpStreamEndpoint, Stream],
):
    """
    TCP stream connection manager.

    :param ssl: forwarded as the ``ssl`` argument of
                :func:`asyncio.open_connection`; ``None`` or ``False`` for a
                plain TCP connection, ``True`` or an :class:`ssl.SSLContext`
                to enable TLS
    """

    def __init__(self, ssl: Union[None, bool, SSLContext] = None):
        self._ssl = ssl

    async def create(self, endpoint: TcpStreamEndpoint) -> Stream:
        """
        Open a stream connection to ``endpoint``.

        :param endpoint: ``(hostname, port)`` pair
        :return: the ``(reader, writer)`` pair from :func:`asyncio.open_connection`
        """
        hostname, port = endpoint
        # asyncio only accepts server_hostname together with a truthy ``ssl``,
        # so test truthiness rather than ``is not None``. Otherwise an explicit
        # ``ssl=False`` would make open_connection raise ValueError.
        server_hostname = hostname if self._ssl else None
        reader, writer = await asyncio.open_connection(
            hostname,
            port,
            server_hostname=server_hostname,
            ssl=self._ssl,
        )
        return reader, writer

    async def dispose(self, endpoint: TcpStreamEndpoint, conn: Stream) -> None:
        """
        Half-close (when supported) and then close the stream.

        An ``OSError`` while sending EOF, e.g. on an already-broken
        connection, is ignored so that the writer is always closed.
        Errors reported by ``wait_closed`` still propagate.
        """
        _, writer = conn
        try:
            if writer.can_write_eof():
                writer.write_eof()
        except OSError:
            # Best-effort half-close; fall through to close() so the
            # transport is not leaked.
            pass
        writer.close()
        await writer.wait_closed()