jamtur01's picture
Upload folder using huggingface_hub
9c6594c verified
"""Internal utility routines.
Collection of classes to support the internal process.
"""
import logging
import queue
import sys
import threading
import time
from typing import TYPE_CHECKING, Optional, Tuple, Type, Union
if TYPE_CHECKING:
    # These names are only needed for the string annotations in this module;
    # they are imported solely for static type checkers, not at runtime.
    from queue import Queue
    from threading import Event
    from types import TracebackType

    from wandb.proto.wandb_internal_pb2 import Record, Result

    # Type of a ``sys.exc_info()`` result: a (type, value, traceback) triple
    # while an exception is being handled, otherwise (None, None, None).
    # Defined inside this block because TracebackType is only imported here.
    ExceptionType = Union[
        Tuple[Type[BaseException], BaseException, TracebackType],
        Tuple[None, None, None],
    ]

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class ExceptionThread(threading.Thread):
    """Thread base class that captures exceptions raised by its work.

    Subclasses implement ``_run``. When an ``Exception`` escapes ``_run``,
    ``run`` stores the ``sys.exc_info()`` triple (available through
    ``get_exception``) and sets the shared ``stopped`` event, so other threads
    holding that event can observe the failure. ``BaseException`` subclasses
    that are not ``Exception`` (e.g. ``KeyboardInterrupt``) are not caught.
    """

    # Event shared with other threads; set by ``run`` when ``_run`` fails.
    __stopped: "Event"
    # Captured ``sys.exc_info()`` triple, or None if nothing was captured.
    __exception: Optional["ExceptionType"]

    def __init__(self, stopped: "Event") -> None:
        threading.Thread.__init__(self)
        self.__exception = None
        self.__stopped = stopped

    def _run(self) -> None:
        """Perform the thread's work. Subclasses must override this."""
        raise NotImplementedError

    def run(self) -> None:
        """Thread entry point: call ``_run`` and record any ``Exception``."""
        try:
            self._run()
        except Exception:
            # Keep type, value and traceback so the owner can inspect the
            # failure after the thread has finished.
            self.__exception = sys.exc_info()
        finally:
            # Only a captured failure signals the shared event; a clean
            # return leaves it untouched.
            failed = bool(self.__exception)
            if failed and self.__stopped:
                self.__stopped.set()

    def get_exception(self) -> Optional["ExceptionType"]:
        """Return the captured exc_info triple, or None if none was captured."""
        return self.__exception
class RecordLoopThread(ExceptionThread):
    """Thread base class that drains a record queue until told to stop.

    Subclasses implement the hooks ``_setup``, ``_process``, ``_debounce`` and
    ``_finish``. ``_run`` calls ``_setup`` once, then hands each record taken
    from ``input_record_q`` to ``_process`` until the ``stopped`` event is set,
    calls ``_debounce`` whenever at least ``debounce_interval_ms`` has elapsed
    since the previous debounce, and finally calls ``_finish``.

    Args:
        input_record_q: Queue the loop reads records from.
        result_q: Queue stored as ``self._result_q``; not used by this base
            class itself.
        stopped: Event that ends the loop when set. ``ExceptionThread.run``
            also sets it if ``_run`` raises.
        debounce_interval_ms: Minimum time between ``_debounce`` calls, in
            milliseconds. While the queue is idle the loop only wakes about
            once per second, so shorter intervals are not honored precisely.
    """

    def __init__(
        self,
        input_record_q: "Queue[Record]",
        result_q: "Queue[Result]",
        stopped: "Event",
        debounce_interval_ms: "float" = 1000,
    ) -> None:
        ExceptionThread.__init__(self, stopped=stopped)
        self._input_record_q = input_record_q
        self._result_q = result_q
        # Kept as a subclass-visible attribute; the base class's own copy is
        # name-mangled and not reachable from here.
        self._stopped = stopped
        self._debounce_interval_ms = debounce_interval_ms

    def _setup(self) -> None:
        """Run once before the loop starts. Subclasses must override this."""
        raise NotImplementedError

    def _process(self, record: "Record") -> None:
        """Handle one record from the input queue. Subclasses must override this."""
        raise NotImplementedError

    def _finish(self) -> None:
        """Run once after the loop exits normally. Subclasses must override this."""
        raise NotImplementedError

    def _debounce(self) -> None:
        """Run periodically from the loop. Subclasses must override this."""
        raise NotImplementedError

    def _run(self) -> None:
        """Set up, process records until stopped, then finish.

        If any hook raises, the exception propagates out of ``_run`` and
        ``_finish`` is not called.
        """
        self._setup()
        # Measure elapsed time with a monotonic clock: wall-clock time can jump
        # when the system clock is adjusted, which would make debouncing fire
        # spuriously or stall for as long as the clock was set back.
        start = time.monotonic()
        while not self._stopped.is_set():
            # The interval is re-read every iteration so a subclass may change
            # it at runtime.
            if time.monotonic() - start >= self._debounce_interval_ms / 1000.0:
                self._debounce()
                start = time.monotonic()
            try:
                # Bounded wait so the stop event and debounce timer are
                # rechecked roughly once per second while the queue is idle.
                record = self._input_record_q.get(timeout=1)
            except queue.Empty:
                continue
            self._process(record)
        self._finish()