File size: 4,946 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
from __future__ import annotations
import asyncio
import math
import threading
from typing import TYPE_CHECKING
from typing_extensions import override
from wandb.proto import wandb_server_pb2 as spb
from .mailbox_handle import HandleAbandonedError, MailboxHandle
# Necessary to break an import loop.
if TYPE_CHECKING:
from wandb.sdk.interface import interface
class MailboxResponseHandle(MailboxHandle[spb.ServerResponse]):
"""A general handle for any ServerResponse."""
def __init__(self, address: str) -> None:
self._address = address
self._lock = threading.Lock()
self._event = threading.Event()
self._abandoned = False
self._response: spb.ServerResponse | None = None
self._asyncio_events: dict[asyncio.Event, _AsyncioEvent] = dict()
def deliver(self, response: spb.ServerResponse) -> None:
"""Deliver the response.
This may only be called once. It is an error to respond to the same
request more than once. It is a no-op if the handle has been abandoned.
"""
with self._lock:
if self._abandoned:
return
if self._response:
raise ValueError(
f"A response has already been delivered to {self._address}."
)
self._response = response
self._signal_done()
@override
def cancel(self, iface: interface.InterfaceBase) -> None:
iface.publish_cancel(self._address)
self.abandon()
@override
def abandon(self) -> None:
with self._lock:
self._abandoned = True
self._signal_done()
def _signal_done(self) -> None:
"""Indicate that the handle either got a response or became abandoned.
The lock must be held.
"""
# Unblock threads blocked on `wait_or`.
self._event.set()
# Unblock asyncio loops blocked on `wait_async`.
for asyncio_event in self._asyncio_events.values():
asyncio_event.set_threadsafe()
self._asyncio_events.clear()
@override
def check(self) -> spb.ServerResponse | None:
with self._lock:
return self._response
@override
def wait_or(self, *, timeout: float | None) -> spb.ServerResponse:
if timeout is not None and not math.isfinite(timeout):
raise ValueError("Timeout must be finite or None.")
if not self._event.wait(timeout=timeout):
raise TimeoutError(
f"Timed out waiting for response on {self._address}",
)
with self._lock:
if self._response:
return self._response
assert self._abandoned
raise HandleAbandonedError()
@override
async def wait_async(self, *, timeout: float | None) -> spb.ServerResponse:
if timeout is not None and not math.isfinite(timeout):
raise ValueError("Timeout must be finite or None.")
evt = asyncio.Event()
self._add_asyncio_event(asyncio.get_event_loop(), evt)
try:
await asyncio.wait_for(evt.wait(), timeout=timeout)
except (asyncio.TimeoutError, TimeoutError) as e:
with self._lock:
if self._response:
return self._response
elif self._abandoned:
raise HandleAbandonedError()
else:
raise TimeoutError(
f"Timed out waiting for response on {self._address}"
) from e
else:
with self._lock:
if self._response:
return self._response
assert self._abandoned
raise HandleAbandonedError()
finally:
self._forget_asyncio_event(evt)
def _add_asyncio_event(
self,
loop: asyncio.AbstractEventLoop,
event: asyncio.Event,
) -> None:
"""Add an event to signal when a response is received.
If a response already exists, this notifies the event loop immediately.
"""
asyncio_event = _AsyncioEvent(loop, event)
with self._lock:
if self._response or self._abandoned:
asyncio_event.set_threadsafe()
else:
self._asyncio_events[event] = asyncio_event
def _forget_asyncio_event(self, event: asyncio.Event) -> None:
"""Cancel signalling an event when a response is received."""
with self._lock:
self._asyncio_events.pop(event, None)
class _AsyncioEvent:
def __init__(
self,
loop: asyncio.AbstractEventLoop,
event: asyncio.Event,
):
self._loop = loop
self._event = event
def set_threadsafe(self) -> None:
"""Set the asyncio event in its own loop."""
self._loop.call_soon_threadsafe(self._event.set)
|