# File: //home/arjun/.local/lib/python3.10/site-packages/httpcore/_sync/connection.py
from __future__ import annotations
import itertools
import logging
import ssl
import types
import typing
from .._backends.sync import SyncBackend
from .._backends.base import SOCKET_OPTION, NetworkBackend, NetworkStream
from .._exceptions import ConnectError, ConnectTimeout
from .._models import Origin, Request, Response
from .._ssl import default_ssl_context
from .._synchronization import Lock
from .._trace import Trace
from .http11 import HTTP11Connection
from .interfaces import ConnectionInterface
# Factor passed to `exponential_backoff()` when retrying a failed connect.
# The first retry is immediate; after that the delay starts at this value
# (in seconds) and doubles on every further attempt.
RETRIES_BACKOFF_FACTOR = 0.5 # 0s, 0.5s, 1s, 2s, 4s, etc.
# Logger handed to every `Trace` created in this module.
logger = logging.getLogger("httpcore.connection")
def exponential_backoff(factor: float) -> typing.Iterator[float]:
    """
    Yield an endless series of delays: `0` first, then `factor` doubled on
    every subsequent step (`factor`, `2 * factor`, `4 * factor`, ...).

    For example:

    - `factor = 2`: `0, 2, 4, 8, 16, 32, 64, ...`
    - `factor = 3`: `0, 3, 6, 12, 24, 48, 96, ...`
    """
    # The leading zero is a special case that is not part of the doubling run.
    yield 0

    exponent = 0
    while True:
        yield factor * 2**exponent
        exponent += 1
class HTTPConnection(ConnectionInterface):
    """
    A connection to a single origin that is established lazily.

    Constructing an instance performs no network I/O. The first call to
    `handle_request()` opens the network stream (TCP or Unix domain socket,
    upgraded to TLS for `https`/`wss` origins) and wraps it in either an
    `HTTP11Connection` or an `HTTP2Connection`. From then on, requests and
    state queries are delegated to that wrapped connection.
    """

    def __init__(
        self,
        origin: Origin,
        ssl_context: ssl.SSLContext | None = None,
        keepalive_expiry: float | None = None,
        http1: bool = True,
        http2: bool = False,
        retries: int = 0,
        local_address: str | None = None,
        uds: str | None = None,
        network_backend: NetworkBackend | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> None:
        """
        Store the connection settings; nothing is connected yet.

        Parameters:
            origin: The origin this connection serves. Requests for any
                other origin are rejected by `handle_request()`.
            ssl_context: SSL context used for `https`/`wss` origins. When
                `None`, `default_ssl_context()` is used at connect time.
            keepalive_expiry: Passed through unchanged to the underlying
                HTTP/1.1 or HTTP/2 connection.
            http1: Whether HTTP/1.1 is enabled.
            http2: Whether HTTP/2 is enabled (adds `h2` to the ALPN offer).
            retries: How many times to retry after a `ConnectError` or
                `ConnectTimeout` while connecting.
            local_address: Passed to the backend's `connect_tcp()`.
            uds: Path of a Unix domain socket. When set, it is used instead
                of a TCP connection.
            network_backend: Backend used to open streams and to sleep
                between retries. Defaults to a new `SyncBackend()`.
            socket_options: Passed to the backend's connect call.
        """
        self._origin = origin
        self._ssl_context = ssl_context
        self._keepalive_expiry = keepalive_expiry
        self._http1 = http1
        self._http2 = http2
        self._retries = retries
        self._local_address = local_address
        self._uds = uds

        self._network_backend: NetworkBackend = (
            SyncBackend() if network_backend is None else network_backend
        )
        # Set on the first successful connect in `handle_request()`.
        self._connection: ConnectionInterface | None = None
        # Set when establishing the connection raised; never reset here.
        self._connect_failed: bool = False
        # Held while checking for, and creating, the underlying connection.
        self._request_lock = Lock()
        self._socket_options = socket_options

    def handle_request(self, request: Request) -> Response:
        """
        Send `request` over this connection, connecting first if required.

        Raises:
            RuntimeError: If the request's origin is not this connection's
                origin.

        Any exception raised while establishing the connection marks this
        instance as failed (see `has_expired()`, `is_closed()`, `info()`)
        and is then re-raised unchanged.
        """
        if not self.can_handle_request(request.url.origin):
            raise RuntimeError(
                f"Attempted to send request to {request.url.origin} on connection to {self._origin}"
            )

        try:
            with self._request_lock:
                if self._connection is None:
                    stream = self._connect(request)

                    # `ssl_object` is `None` when the stream did not
                    # provide one, in which case ALPN cannot have picked h2.
                    ssl_object = stream.get_extra_info("ssl_object")
                    http2_negotiated = (
                        ssl_object is not None
                        and ssl_object.selected_alpn_protocol() == "h2"
                    )
                    # Use HTTP/2 if it was negotiated via ALPN, or if it is
                    # the only protocol that has been enabled.
                    if http2_negotiated or (self._http2 and not self._http1):
                        # Imported only when an HTTP/2 connection is
                        # actually needed.
                        from .http2 import HTTP2Connection

                        self._connection = HTTP2Connection(
                            origin=self._origin,
                            stream=stream,
                            keepalive_expiry=self._keepalive_expiry,
                        )
                    else:
                        self._connection = HTTP11Connection(
                            origin=self._origin,
                            stream=stream,
                            keepalive_expiry=self._keepalive_expiry,
                        )
        except BaseException:
            # `BaseException` so that non-`Exception` errors raised during
            # connection setup (e.g. `KeyboardInterrupt`) also flag the
            # failure. Bare `raise` re-raises the active exception as-is.
            self._connect_failed = True
            raise

        return self._connection.handle_request(request)

    def _connect(self, request: Request) -> NetworkStream:
        """
        Open the network stream for this connection and return it.

        The connect timeout and optional SNI hostname are read from
        `request.extensions` (`"timeout"` -> `"connect"`, and
        `"sni_hostname"`). A `ConnectError` or `ConnectTimeout` is retried up
        to `self._retries` times, sleeping between attempts for the delays
        produced by `exponential_backoff()`; once the retries are used up
        the exception propagates.
        """
        timeouts = request.extensions.get("timeout", {})
        sni_hostname = request.extensions.get("sni_hostname", None)
        timeout = timeouts.get("connect", None)

        retries_left = self._retries
        delays = exponential_backoff(factor=RETRIES_BACKOFF_FACTOR)

        while True:
            try:
                if self._uds is None:
                    kwargs = {
                        "host": self._origin.host.decode("ascii"),
                        "port": self._origin.port,
                        "local_address": self._local_address,
                        "timeout": timeout,
                        "socket_options": self._socket_options,
                    }
                    with Trace("connect_tcp", logger, request, kwargs) as trace:
                        stream = self._network_backend.connect_tcp(**kwargs)
                        trace.return_value = stream
                else:
                    kwargs = {
                        "path": self._uds,
                        "timeout": timeout,
                        "socket_options": self._socket_options,
                    }
                    with Trace(
                        "connect_unix_socket", logger, request, kwargs
                    ) as trace:
                        stream = self._network_backend.connect_unix_socket(
                            **kwargs
                        )
                        trace.return_value = stream

                if self._origin.scheme in (b"https", b"wss"):
                    ssl_context = (
                        default_ssl_context()
                        if self._ssl_context is None
                        else self._ssl_context
                    )
                    alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"]
                    # NOTE: this configures ALPN on the context object
                    # itself, which is the caller's object when an
                    # `ssl_context` was supplied.
                    ssl_context.set_alpn_protocols(alpn_protocols)

                    kwargs = {
                        "ssl_context": ssl_context,
                        # An explicit "sni_hostname" extension overrides
                        # the origin's host name.
                        "server_hostname": sni_hostname
                        or self._origin.host.decode("ascii"),
                        "timeout": timeout,
                    }
                    with Trace("start_tls", logger, request, kwargs) as trace:
                        stream = stream.start_tls(**kwargs)
                        trace.return_value = stream
                return stream
            except (ConnectError, ConnectTimeout):
                if retries_left <= 0:
                    raise
                retries_left -= 1
                delay = next(delays)
                # `kwargs` here are those of the step that just failed.
                with Trace("retry", logger, request, kwargs):
                    self._network_backend.sleep(delay)

    def can_handle_request(self, origin: Origin) -> bool:
        """Return `True` if `origin` equals this connection's origin."""
        return origin == self._origin

    def close(self) -> None:
        """
        Close the underlying connection, if one has been established.

        Does nothing when no connection was ever made.
        """
        if self._connection is not None:
            with Trace("close", logger, None, {}):
                self._connection.close()

    def is_available(self) -> bool:
        """
        Return whether this connection can accept another request.

        Once connected, the answer comes from the underlying connection.
        """
        if self._connection is None:
            # If HTTP/2 support is enabled, and the resulting connection could
            # end up as HTTP/2 then we should indicate the connection as being
            # available to service multiple requests.
            return (
                self._http2
                and (self._origin.scheme == b"https" or not self._http1)
                and not self._connect_failed
            )
        return self._connection.is_available()

    def has_expired(self) -> bool:
        """
        Return whether the connection has expired.

        Before a connection exists this is `True` only if connecting failed.
        """
        if self._connection is None:
            return self._connect_failed
        return self._connection.has_expired()

    def is_idle(self) -> bool:
        """
        Return whether the connection is idle.

        Before a connection exists this is `True` only if connecting failed.
        """
        if self._connection is None:
            return self._connect_failed
        return self._connection.is_idle()

    def is_closed(self) -> bool:
        """
        Return whether the connection is closed.

        Before a connection exists this is `True` only if connecting failed.
        """
        if self._connection is None:
            return self._connect_failed
        return self._connection.is_closed()

    def info(self) -> str:
        """
        Return a short human-readable status string.

        Without an underlying connection this is `"CONNECTION FAILED"` or
        `"CONNECTING"`; otherwise the underlying connection's `info()`.
        """
        if self._connection is None:
            return "CONNECTION FAILED" if self._connect_failed else "CONNECTING"
        return self._connection.info()

    def __repr__(self) -> str:
        """Return `<ClassName [info]>` using the current `info()` string."""
        return f"<{self.__class__.__name__} [{self.info()}]>"

    # These context managers are not used in the standard flow, but are
    # useful for testing or working with connection instances directly.

    def __enter__(self) -> HTTPConnection:
        """Return the connection itself; no connecting happens here."""
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        traceback: types.TracebackType | None = None,
    ) -> None:
        """Close the connection on leaving the `with` block."""
        self.close()