File: //home/arjun/.local/lib/python3.10/site-packages/aiohappyeyeballs/_staggered.py
import asyncio
import contextlib
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Iterable,
List,
Optional,
Set,
Tuple,
TypeVar,
Union,
)
# Generic type of the value produced by the coroutine functions that are
# raced in this module.
_T = TypeVar("_T")
def _set_result(wait_next: "asyncio.Future[None]") -> None:
    """Resolve *wait_next* with ``None`` unless it has already completed."""
    if wait_next.done():
        # Already resolved or cancelled: leave the future untouched.
        return
    wait_next.set_result(None)
async def _wait_one(
    futures: "Iterable[asyncio.Future[Any]]",
    loop: asyncio.AbstractEventLoop,
) -> "asyncio.Future[Any]":
    """
    Wait for the first future to complete.

    Args:
    ----
        futures: the futures (or tasks) to watch. Any iterable is accepted,
            including one-shot iterators and generators.
        loop: the event loop used to create the internal waiter future.

    Returns:
    -------
        The future that completed first. The future object itself is
        returned, not its result, so the caller can compare it by identity
        and decide how to consume it.

    If *futures* is empty nothing can ever resolve the internal waiter, so
    the call only ends when it is cancelled.

    """
    # The futures are walked twice: once to register the callback and once
    # to unregister it. Snapshot them so a one-shot iterator is not already
    # exhausted when the cleanup pass runs, which would leave the callback
    # attached to every future that has not completed yet.
    watched = tuple(futures)
    wait_next = loop.create_future()

    def _on_completion(fut: "asyncio.Future[Any]") -> None:
        # Only the first completion wins; later ones are ignored.
        if not wait_next.done():
            wait_next.set_result(fut)

    for f in watched:
        f.add_done_callback(_on_completion)

    try:
        return await wait_next
    finally:
        # Always detach the callback, including when this wait is cancelled,
        # so the watched futures do not keep a reference to the waiter.
        for f in watched:
            f.remove_done_callback(_on_completion)
async def staggered_race(
    coro_fns: Iterable[Callable[[], Awaitable[_T]]],
    delay: Optional[float],
    *,
    loop: Optional[asyncio.AbstractEventLoop] = None,
) -> Tuple[Optional[_T], Optional[int], List[Optional[BaseException]]]:
    """
    Run coroutines with staggered start times and take the first to finish.

    This method takes an iterable of coroutine functions. The first one is
    started immediately. From then on, whenever the immediately preceding one
    fails (raises an exception), or when *delay* seconds have passed, the next
    coroutine is started. This continues until one of the coroutines completes
    successfully, in which case all others are cancelled, or until all
    coroutines fail.

    The coroutines provided should be well-behaved in the following way:

    * They should only ``return`` if completed successfully.

    * They should always raise an exception if they did not complete
      successfully. In particular, if they handle cancellation, they should
      probably reraise, like this::

        try:
            # do work
        except asyncio.CancelledError:
            # undo partially completed work
            raise

    Args:
    ----
        coro_fns: an iterable of coroutine functions, i.e. callables that
            return a coroutine object when called. Use ``functools.partial`` or
            lambdas to pass arguments.

        delay: amount of time, in seconds, between starting coroutines. If
            ``None``, the coroutines will run sequentially. A delay of ``0``
            is a real delay: the next coroutine is scheduled to start right
            away instead of waiting for the previous one to fail.

        loop: the event loop to use. If ``None``, the running loop is used.

    Returns:
    -------
        tuple *(winner_result, winner_index, exceptions)* where

        - *winner_result*: the result of the winning coroutine, or ``None``
          if no coroutines won.

        - *winner_index*: the index of the winning coroutine in
          ``coro_fns``, or ``None`` if no coroutines won. If the winning
          coroutine may return None on success, *winner_index* can be used
          to definitively determine whether any coroutine won.

        - *exceptions*: list of exceptions raised by the coroutines.
          ``len(exceptions)`` is equal to the number of coroutines actually
          started, and the order is the same as in ``coro_fns``. The winning
          coroutine's entry is ``None``.

    """
    loop = loop or asyncio.get_running_loop()
    # One slot per started coroutine, filled in by run_one_coro on failure.
    exceptions: List[Optional[BaseException]] = []
    # Tasks that have been started and not yet consumed by the wait loop.
    tasks: Set[asyncio.Task[Optional[Tuple[_T, int]]]] = set()

    async def run_one_coro(
        coro_fn: Callable[[], Awaitable[_T]],
        this_index: int,
        start_next: "asyncio.Future[None]",
    ) -> Optional[Tuple[_T, int]]:
        """
        Run a single coroutine.

        If the coroutine fails, set the exception in the exceptions list and
        start the next coroutine by setting the result of the start_next.

        If the coroutine succeeds, return the result and the index of the
        coroutine in the coro_fns list.

        If SystemExit or KeyboardInterrupt is raised, re-raise it.
        """
        try:
            result = await coro_fn()
        except (SystemExit, KeyboardInterrupt):
            raise
        except BaseException as e:
            # BaseException also covers cancellation, so a cancelled attempt
            # is recorded in its slot like any other failure.
            exceptions[this_index] = e
            _set_result(start_next)  # Kickstart the next coroutine
            return None

        return result, this_index

    start_next_timer: Optional[asyncio.TimerHandle] = None
    start_next: Optional[asyncio.Future[None]]
    task: asyncio.Task[Optional[Tuple[_T, int]]]
    done: Union[asyncio.Future[None], asyncio.Task[Optional[Tuple[_T, int]]]]
    coro_iter = iter(coro_fns)
    this_index = -1
    try:
        while True:
            # NOTE: a falsy entry in coro_fns is indistinguishable from
            # exhaustion here and stops further coroutines from starting.
            if coro_fn := next(coro_iter, None):
                this_index += 1
                exceptions.append(None)
                # Resolved either by the timer below or by run_one_coro when
                # this attempt fails; both mean "start the next attempt".
                start_next = loop.create_future()
                task = loop.create_task(run_one_coro(coro_fn, this_index, start_next))
                tasks.add(task)
                # Only ``None`` means "no timer, run sequentially". Testing
                # truthiness would also disable the timer for a delay of 0.
                start_next_timer = (
                    loop.call_later(delay, _set_result, start_next)
                    if delay is not None
                    else None
                )
            elif not tasks:
                # We exhausted the coro_fns list and no tasks are running
                # so we have no winner and all coroutines failed.
                break

            while tasks or start_next:
                done = await _wait_one(
                    (*tasks, start_next) if start_next else tasks, loop
                )
                if done is start_next:
                    # The current task has failed or the timer has expired
                    # so we need to start the next task.
                    start_next = None
                    if start_next_timer:
                        start_next_timer.cancel()
                        start_next_timer = None

                    # Break out of the task waiting loop to start the next
                    # task.
                    break

                if TYPE_CHECKING:
                    assert isinstance(done, asyncio.Task)

                tasks.remove(done)
                # A failed attempt returns None; a winner returns the
                # (result, index) pair. result() also re-raises a
                # SystemExit/KeyboardInterrupt that ended the task.
                if winner := done.result():
                    return *winner, exceptions
    finally:
        # We either have:
        #  - a winner
        #  - all tasks failed
        #  - a KeyboardInterrupt or SystemExit.

        #
        # If the timer is still running, cancel it.
        #
        if start_next_timer:
            start_next_timer.cancel()

        #
        # If there are any tasks left, cancel them and then
        # wait for them so they fill the exceptions list.
        #
        for task in tasks:
            task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await task

    return None, None, exceptions