File size: 2,555 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 |
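"""Internal utility routines.

Base thread classes used by the internal process: ``ExceptionThread``
captures exceptions raised by a worker thread's body, and
``RecordLoopThread`` repeatedly reads records from an input queue and
hands them to overridable hooks until a stop event is set.
"""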
import logging
import queue
import sys
import threading
import time
from typing import TYPE_CHECKING, Optional, Tuple, Type, Union
if TYPE_CHECKING:
    # Typing-only imports: this branch never executes at runtime, so importing
    # this module does not pull in the wandb protobuf definitions.
    from queue import Queue
    from threading import Event
    from types import TracebackType

    from wandb.proto.wandb_internal_pb2 import Record, Result

    # The two shapes ``sys.exc_info()`` can return: a populated triple while
    # an exception is being handled, or a triple of ``None`` otherwise.
    _ExcInfoTriple = Tuple[Type[BaseException], BaseException, TracebackType]
    _EmptyExcInfo = Tuple[None, None, None]
    ExceptionType = Union[_ExcInfoTriple, _EmptyExcInfo]

logger = logging.getLogger(name=__name__)
class ExceptionThread(threading.Thread):
    """Thread base class that captures exceptions raised by its work.

    Subclasses put their work in ``_run``. If it raises an ``Exception``,
    ``run`` catches it, stores the ``sys.exc_info()`` triple (retrievable via
    ``get_exception``) and sets the ``stopped`` event given at construction.
    """

    # Event that is set when ``_run`` fails.
    __stopped: "Event"
    # ``sys.exc_info()`` captured from a failed ``_run``; ``None`` otherwise.
    __exception: Optional["ExceptionType"]

    def __init__(self, stopped: "Event") -> None:
        """Create the (not yet started) thread.

        Args:
            stopped: Event to set if ``_run`` raises an ``Exception``.
        """
        threading.Thread.__init__(self)
        self.__exception = None
        self.__stopped = stopped

    def _run(self) -> None:
        """Perform the thread's work. Must be overridden by subclasses."""
        raise NotImplementedError

    def run(self) -> None:
        """Thread entry point: call ``_run`` and record any ``Exception``."""
        try:
            self._run()
        except Exception:
            self.__exception = sys.exc_info()
        finally:
            failure = self.__exception
            event = self.__stopped
            # Signal the shared event only when a failure was captured; a
            # falsy (e.g. missing) event is skipped rather than dereferenced.
            if failure and event:
                event.set()

    def get_exception(self) -> Optional["ExceptionType"]:
        """Return the captured ``sys.exc_info()`` triple, or ``None`` if none was captured."""
        return self.__exception
class RecordLoopThread(ExceptionThread):
    """Class to manage reading from queues safely.

    ``_run`` calls ``_setup`` once, then repeatedly takes records from the
    input queue and passes each to ``_process`` until the ``stopped`` event is
    set, after which it calls ``_finish``. ``_debounce`` is called whenever at
    least ``debounce_interval_ms`` milliseconds have passed since the previous
    debounce (or since the loop started). Subclasses must implement
    ``_setup``, ``_process``, ``_finish`` and ``_debounce``.
    """

    def __init__(
        self,
        input_record_q: "Queue[Record]",
        result_q: "Queue[Result]",
        stopped: "Event",
        debounce_interval_ms: "float" = 1000,
    ) -> None:
        """Store the queues, the stop event and the debounce interval.

        Args:
            input_record_q: Queue the loop reads records from.
            result_q: Queue kept for subclasses; not read by this class.
            stopped: Event that ends the loop when set. It is also passed to
                ``ExceptionThread``, which sets it if ``_run`` raises.
            debounce_interval_ms: Minimum spacing between ``_debounce`` calls,
                in milliseconds.
        """
        ExceptionThread.__init__(self, stopped=stopped)
        self._input_record_q = input_record_q
        self._result_q = result_q
        self._stopped = stopped
        self._debounce_interval_ms = debounce_interval_ms

    def _setup(self) -> None:
        """Prepare state before the loop starts. Must be overridden."""
        raise NotImplementedError

    def _process(self, record: "Record") -> None:
        """Handle one record taken from the input queue. Must be overridden."""
        raise NotImplementedError

    def _finish(self) -> None:
        """Run after the loop exits because the stop event was set. Must be overridden."""
        raise NotImplementedError

    def _debounce(self) -> None:
        """Periodic hook, called at least one debounce interval apart. Must be overridden."""
        raise NotImplementedError

    def _run(self) -> None:
        """Run the read loop: setup, process records until stopped, then finish.

        An exception raised by any hook propagates out of this method, in
        which case ``_finish`` is not called.
        """
        self._setup()
        # Use a monotonic clock for interval measurement: time.time() follows
        # the wall clock, and an NTP correction or manual clock change could
        # stall (backward jump) or prematurely trigger (forward jump) debounce.
        start = time.monotonic()
        while not self._stopped.is_set():
            if time.monotonic() - start >= self._debounce_interval_ms / 1000.0:
                self._debounce()
                start = time.monotonic()
            try:
                # Wait at most 1s so the stop event is re-checked even when no
                # records arrive. Consequently, while the queue is idle,
                # _debounce may run up to ~1s later than its interval.
                record = self._input_record_q.get(timeout=1)
            except queue.Empty:
                continue
            self._process(record)
        self._finish()
|