File: //home/arjun/.local/lib/python3.10/site-packages/httpcore/_backends/trio.py
from __future__ import annotations
import ssl
import typing
import trio
from .._exceptions import (
ConnectError,
ConnectTimeout,
ExceptionMapping,
ReadError,
ReadTimeout,
WriteError,
WriteTimeout,
map_exceptions,
)
from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream
class TrioStream(AsyncNetworkStream):
    """`AsyncNetworkStream` implementation backed by a trio stream.

    Every I/O call is run under `trio.fail_after()` and inside
    `map_exceptions()`, using a per-method mapping from trio exception
    types to the corresponding httpcore exception types.
    """

    def __init__(self, stream: trio.abc.Stream) -> None:
        """Store *stream* as the wrapped trio stream."""
        self._stream = stream

    async def read(self, max_bytes: int, timeout: float | None = None) -> bytes:
        """Return the result of `receive_some(max_bytes=...)` on the wrapped stream.

        A *timeout* of `None` is passed to `trio.fail_after()` as infinity.
        The exception mapping pairs `trio.TooSlowError` with `ReadTimeout`,
        and `trio.BrokenResourceError` / `trio.ClosedResourceError` with
        `ReadError`.
        """
        timeout_or_inf = float("inf") if timeout is None else timeout
        exc_map: ExceptionMapping = {
            trio.TooSlowError: ReadTimeout,
            trio.BrokenResourceError: ReadError,
            trio.ClosedResourceError: ReadError,
        }
        with map_exceptions(exc_map):
            with trio.fail_after(timeout_or_inf):
                data: bytes = await self._stream.receive_some(max_bytes=max_bytes)
                return data

    async def write(self, buffer: bytes, timeout: float | None = None) -> None:
        """Send *buffer* with `send_all()` on the wrapped stream.

        An empty *buffer* returns immediately without touching the stream.
        A *timeout* of `None` is passed to `trio.fail_after()` as infinity.
        The exception mapping pairs `trio.TooSlowError` with `WriteTimeout`,
        and `trio.BrokenResourceError` / `trio.ClosedResourceError` with
        `WriteError`.
        """
        if not buffer:
            return

        timeout_or_inf = float("inf") if timeout is None else timeout
        exc_map: ExceptionMapping = {
            trio.TooSlowError: WriteTimeout,
            trio.BrokenResourceError: WriteError,
            trio.ClosedResourceError: WriteError,
        }
        with map_exceptions(exc_map):
            with trio.fail_after(timeout_or_inf):
                await self._stream.send_all(data=buffer)

    async def aclose(self) -> None:
        """Close the wrapped stream."""
        await self._stream.aclose()

    async def start_tls(
        self,
        ssl_context: ssl.SSLContext,
        server_hostname: str | None = None,
        timeout: float | None = None,
    ) -> AsyncNetworkStream:
        """Wrap the current stream in a client-side `trio.SSLStream` and handshake.

        Returns a new `TrioStream` around the SSL stream. If the handshake
        raises, this stream is closed before the exception propagates.
        A *timeout* of `None` is passed to `trio.fail_after()` as infinity.
        The exception mapping pairs `trio.TooSlowError` with `ConnectTimeout`
        and `trio.BrokenResourceError` with `ConnectError`.
        """
        timeout_or_inf = float("inf") if timeout is None else timeout
        exc_map: ExceptionMapping = {
            trio.TooSlowError: ConnectTimeout,
            trio.BrokenResourceError: ConnectError,
        }
        ssl_stream = trio.SSLStream(
            self._stream,
            ssl_context=ssl_context,
            server_hostname=server_hostname,
            https_compatible=True,
            server_side=False,
        )
        with map_exceptions(exc_map):
            try:
                with trio.fail_after(timeout_or_inf):
                    await ssl_stream.do_handshake()
            except Exception:  # pragma: nocover
                # Release the underlying connection, then re-raise the
                # handshake failure so `map_exceptions` can translate it.
                await self.aclose()
                raise
        return TrioStream(ssl_stream)

    def get_extra_info(self, info: str) -> typing.Any:
        """Look up connection metadata by key.

        Recognised keys:

        * ``"ssl_object"`` - the `_ssl_object` of the wrapped stream, only
          when the wrapped stream is a `trio.SSLStream`.
        * ``"client_addr"`` - `getsockname()` of the underlying socket.
        * ``"server_addr"`` - `getpeername()` of the underlying socket.
        * ``"socket"`` - the underlying socket object.
        * ``"is_readable"`` - `is_readable()` of the underlying socket.

        Any other key (or ``"ssl_object"`` on a non-SSL stream) yields `None`.
        """
        if info == "ssl_object" and isinstance(self._stream, trio.SSLStream):
            # Type checkers cannot see `_ssl_object` attribute because trio._ssl.SSLStream uses __getattr__/__setattr__.
            # Tracked at https://github.com/python-trio/trio/issues/542
            return self._stream._ssl_object  # type: ignore[attr-defined]
        if info == "client_addr":
            return self._get_socket_stream().socket.getsockname()
        if info == "server_addr":
            return self._get_socket_stream().socket.getpeername()
        if info == "socket":
            # Same SSL-unwrapping as the address lookups above.
            return self._get_socket_stream().socket
        if info == "is_readable":
            socket = self.get_extra_info("socket")
            return socket.is_readable()
        return None

    def _get_socket_stream(self) -> trio.SocketStream:
        """Return the `trio.SocketStream` beneath any `trio.SSLStream` layers."""
        stream = self._stream
        # Follow `transport_stream` until the stream is no longer an SSLStream.
        while isinstance(stream, trio.SSLStream):
            stream = stream.transport_stream
        assert isinstance(stream, trio.SocketStream)
        return stream
class TrioBackend(AsyncNetworkBackend):
    """`AsyncNetworkBackend` implementation that opens connections with trio."""

    @staticmethod
    async def _apply_socket_options(
        stream: trio.abc.Stream,
        socket_options: typing.Iterable[SOCKET_OPTION],
    ) -> None:  # pragma: no cover
        """Call `setsockopt(*option)` on *stream* for each option.

        If any call raises, the already-open stream is force-closed before
        the exception is re-raised, so a rejected option does not leak the
        connection. Nothing is awaited when every option is applied
        successfully.
        """
        try:
            for option in socket_options:
                stream.setsockopt(*option)  # type: ignore[attr-defined]
        except Exception:
            # The caller never receives this stream, so it must be closed here.
            await trio.aclose_forcefully(stream)
            raise

    async def connect_tcp(
        self,
        host: str,
        port: int,
        timeout: float | None = None,
        local_address: str | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> AsyncNetworkStream:
        """Open a TCP connection and return it wrapped in a `TrioStream`.

        *socket_options* (if any) are applied to the new stream with
        `setsockopt`; if applying one fails, the stream is closed before the
        error propagates. A *timeout* of `None` is passed to
        `trio.fail_after()` as infinity. The exception mapping pairs
        `trio.TooSlowError` with `ConnectTimeout`, and
        `trio.BrokenResourceError` / `OSError` with `ConnectError`.
        """
        # By default for TCP sockets, trio enables TCP_NODELAY.
        # https://trio.readthedocs.io/en/stable/reference-io.html#trio.SocketStream
        if socket_options is None:
            socket_options = []  # pragma: no cover
        timeout_or_inf = float("inf") if timeout is None else timeout
        exc_map: ExceptionMapping = {
            trio.TooSlowError: ConnectTimeout,
            trio.BrokenResourceError: ConnectError,
            OSError: ConnectError,
        }
        with map_exceptions(exc_map):
            with trio.fail_after(timeout_or_inf):
                stream: trio.abc.Stream = await trio.open_tcp_stream(
                    host=host, port=port, local_address=local_address
                )
                await self._apply_socket_options(stream, socket_options)
        return TrioStream(stream)

    async def connect_unix_socket(
        self,
        path: str,
        timeout: float | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> AsyncNetworkStream:  # pragma: nocover
        """Open a Unix-domain socket at *path* and return it as a `TrioStream`.

        *socket_options* (if any) are applied to the new stream with
        `setsockopt`; if applying one fails, the stream is closed before the
        error propagates. A *timeout* of `None` is passed to
        `trio.fail_after()` as infinity. The exception mapping pairs
        `trio.TooSlowError` with `ConnectTimeout`, and
        `trio.BrokenResourceError` / `OSError` with `ConnectError`.
        """
        if socket_options is None:
            socket_options = []
        timeout_or_inf = float("inf") if timeout is None else timeout
        exc_map: ExceptionMapping = {
            trio.TooSlowError: ConnectTimeout,
            trio.BrokenResourceError: ConnectError,
            OSError: ConnectError,
        }
        with map_exceptions(exc_map):
            with trio.fail_after(timeout_or_inf):
                stream: trio.abc.Stream = await trio.open_unix_socket(path)
                await self._apply_socket_options(stream, socket_options)
        return TrioStream(stream)

    async def sleep(self, seconds: float) -> None:
        """Suspend the current task for *seconds* using `trio.sleep()`."""
        await trio.sleep(seconds)  # pragma: nocover