|
from __future__ import annotations |
|
|
|
import asyncio |
|
import math |
|
import threading |
|
from typing import TYPE_CHECKING |
|
|
|
from typing_extensions import override |
|
|
|
from wandb.proto import wandb_server_pb2 as spb |
|
|
|
from .mailbox_handle import HandleAbandonedError, MailboxHandle |
|
|
|
|
|
if TYPE_CHECKING: |
|
from wandb.sdk.interface import interface |
|
|
|
|
|
class MailboxResponseHandle(MailboxHandle[spb.ServerResponse]):
    """A general handle for any ServerResponse.

    A handle is "done" once a response has been delivered to it or it has
    been abandoned. Waiters (blocking via `wait_or` or asynchronous via
    `wait_async`) are woken when that happens. A delivered response always
    takes precedence over abandonment when deciding what a waiter observes.
    """

    def __init__(self, address: str) -> None:
        """Create a pending handle.

        Args:
            address: The mailbox address of the request. It is passed to
                `publish_cancel` on cancellation and included in error
                messages.
        """
        self._address = address

        # Guards `_abandoned`, `_response` and `_asyncio_events`.
        self._lock = threading.Lock()

        # Set once the handle is done; wakes blocking waiters in `wait_or`.
        self._event = threading.Event()

        self._abandoned = False
        self._response: spb.ServerResponse | None = None

        # Events of in-flight `wait_async` calls, keyed by the raw asyncio
        # event so that a waiter can unregister itself.
        self._asyncio_events: dict[asyncio.Event, _AsyncioEvent] = {}

    def deliver(self, response: spb.ServerResponse) -> None:
        """Deliver the response.

        This may only be called once. It is an error to respond to the same
        request more than once. It is a no-op if the handle has been abandoned.

        Args:
            response: The response to hand to waiters.

        Raises:
            ValueError: If a response has already been delivered.
        """
        with self._lock:
            if self._abandoned:
                return

            # Compare against None explicitly: "has a response been stored"
            # must not depend on the truthiness of the response object.
            if self._response is not None:
                raise ValueError(
                    f"A response has already been delivered to {self._address}."
                )

            self._response = response
            self._signal_done()

    @override
    def cancel(self, iface: interface.InterfaceBase) -> None:
        """Publish a cancel for this handle's address, then abandon it.

        Args:
            iface: The interface on which `publish_cancel` is called.
        """
        iface.publish_cancel(self._address)
        self.abandon()

    @override
    def abandon(self) -> None:
        """Mark the handle as abandoned and wake all waiters.

        Later calls to `deliver` become no-ops. A response that was already
        delivered is kept and is still returned to waiters.
        """
        with self._lock:
            self._abandoned = True
            self._signal_done()

    def _signal_done(self) -> None:
        """Indicate that the handle either got a response or became abandoned.

        The lock must be held.
        """
        # Wake threads blocked in `wait_or`.
        self._event.set()

        # Wake coroutines blocked in `wait_async`; each event is set on its
        # own loop. They are dropped afterwards since they never need to be
        # signalled again.
        for asyncio_event in self._asyncio_events.values():
            asyncio_event.set_threadsafe()
        self._asyncio_events.clear()

    @override
    def check(self) -> spb.ServerResponse | None:
        """Return the delivered response, or None if there is none yet."""
        with self._lock:
            return self._response

    @override
    def wait_or(self, *, timeout: float | None) -> spb.ServerResponse:
        """Block until the handle is done and return the response.

        Args:
            timeout: Maximum number of seconds to wait, or None to wait
                indefinitely. Must be finite if given.

        Returns:
            The delivered response.

        Raises:
            ValueError: If `timeout` is not None and not finite.
            TimeoutError: If the handle is not done within `timeout`.
            HandleAbandonedError: If the handle was abandoned without a
                response having been delivered.
        """
        if timeout is not None and not math.isfinite(timeout):
            raise ValueError("Timeout must be finite or None.")

        if not self._event.wait(timeout=timeout):
            raise TimeoutError(
                f"Timed out waiting for response on {self._address}",
            )

        with self._lock:
            if self._response is not None:
                return self._response

            # The event is only set by `_signal_done`, so without a response
            # the handle must have been abandoned.
            assert self._abandoned
            raise HandleAbandonedError()

    @override
    async def wait_async(self, *, timeout: float | None) -> spb.ServerResponse:
        """Wait asynchronously until the handle is done and return the response.

        Args:
            timeout: Maximum number of seconds to wait, or None to wait
                indefinitely. Must be finite if given.

        Returns:
            The delivered response.

        Raises:
            ValueError: If `timeout` is not None and not finite.
            TimeoutError: If the handle is not done within `timeout`.
            HandleAbandonedError: If the handle was abandoned without a
                response having been delivered.
        """
        if timeout is not None and not math.isfinite(timeout):
            raise ValueError("Timeout must be finite or None.")

        evt = asyncio.Event()

        # This is a coroutine, so a loop is necessarily running;
        # get_running_loop() is the documented way to obtain it here.
        self._add_asyncio_event(asyncio.get_running_loop(), evt)

        try:
            await asyncio.wait_for(evt.wait(), timeout=timeout)

        except (asyncio.TimeoutError, TimeoutError) as e:
            # The handle may have become done after the timeout fired but
            # before this handler ran, so re-check the state before
            # reporting a timeout.
            with self._lock:
                if self._response is not None:
                    return self._response
                elif self._abandoned:
                    raise HandleAbandonedError()
                else:
                    raise TimeoutError(
                        f"Timed out waiting for response on {self._address}"
                    ) from e

        else:
            with self._lock:
                if self._response is not None:
                    return self._response

                # The event is only set once the handle is done, so without
                # a response the handle must have been abandoned.
                assert self._abandoned
                raise HandleAbandonedError()

        finally:
            # Always unregister, including on timeout or task cancellation.
            self._forget_asyncio_event(evt)

    def _add_asyncio_event(
        self,
        loop: asyncio.AbstractEventLoop,
        event: asyncio.Event,
    ) -> None:
        """Add an event to signal when a response is received.

        If a response already exists, this notifies the event loop immediately.
        The same applies if the handle has already been abandoned.

        Args:
            loop: The event loop on which `event` must be set.
            event: The event to set once the handle is done.
        """
        asyncio_event = _AsyncioEvent(loop, event)

        with self._lock:
            if self._response is not None or self._abandoned:
                asyncio_event.set_threadsafe()
            else:
                self._asyncio_events[event] = asyncio_event

    def _forget_asyncio_event(self, event: asyncio.Event) -> None:
        """Cancel signalling an event when a response is received.

        It is not an error if the event is not (or no longer) registered.
        """
        with self._lock:
            self._asyncio_events.pop(event, None)
|
|
|
|
|
class _AsyncioEvent: |
|
def __init__( |
|
self, |
|
loop: asyncio.AbstractEventLoop, |
|
event: asyncio.Event, |
|
): |
|
self._loop = loop |
|
self._event = event |
|
|
|
def set_threadsafe(self) -> None: |
|
"""Set the asyncio event in its own loop.""" |
|
self._loop.call_soon_threadsafe(self._event.set) |
|
|