|
|
|
|
|
|
|
|
|
|
|
from typing import TYPE_CHECKING, Any, Tuple, List, Dict, Optional |
|
|
|
from .utils import check_comet_availability |
|
from .monitor import Monitor |
|
|
|
import deepspeed.comm as dist |
|
|
|
if TYPE_CHECKING: |
|
import comet_ml |
|
from .config import CometConfig |
|
|
|
Name = str |
|
Value = Any |
|
GlobalSamples = int |
|
Event = Tuple[Name, Value, GlobalSamples] |
|
|
|
|
|
class CometMonitor(Monitor): |
|
|
|
def __init__(self, comet_config: "CometConfig"): |
|
super().__init__(comet_config) |
|
check_comet_availability() |
|
import comet_ml |
|
|
|
self.enabled = comet_config.enabled |
|
self._samples_log_interval = comet_config.samples_log_interval |
|
self._experiment: Optional["comet_ml.ExperimentBase"] = None |
|
|
|
if self.enabled and dist.get_rank() == 0: |
|
self._experiment = comet_ml.start( |
|
api_key=comet_config.api_key, |
|
project=comet_config.project, |
|
workspace=comet_config.workspace, |
|
experiment_key=comet_config.experiment_key, |
|
mode=comet_config.mode, |
|
online=comet_config.online, |
|
) |
|
|
|
if comet_config.experiment_name is not None: |
|
self._experiment.set_name(comet_config.experiment_name) |
|
|
|
self._events_log_scheduler = EventsLogScheduler(comet_config.samples_log_interval) |
|
|
|
@property |
|
def experiment(self) -> Optional["comet_ml.ExperimentBase"]: |
|
return self._experiment |
|
|
|
@property |
|
def samples_log_interval(self) -> int: |
|
return self._samples_log_interval |
|
|
|
def write_events(self, event_list: List[Event]) -> None: |
|
if not self.enabled or dist.get_rank() != 0: |
|
return None |
|
|
|
for event in event_list: |
|
name = event[0] |
|
value = event[1] |
|
engine_global_samples = event[2] |
|
|
|
if self._events_log_scheduler.needs_logging(name, engine_global_samples): |
|
self._experiment.__internal_api__log_metric__( |
|
name=name, |
|
value=value, |
|
step=engine_global_samples, |
|
) |
|
|
|
|
|
class EventsLogScheduler: |
|
|
|
def __init__(self, samples_log_interval: int): |
|
self._samples_log_interval = samples_log_interval |
|
self._last_logged_events_samples: Dict[str, int] = {} |
|
|
|
def needs_logging(self, name: str, current_sample: int) -> bool: |
|
if name not in self._last_logged_events_samples: |
|
self._last_logged_events_samples[name] = current_sample |
|
return True |
|
|
|
last_logged_sample = self._last_logged_events_samples[name] |
|
samples_delta = current_sample - last_logged_sample |
|
|
|
if samples_delta >= self._samples_log_interval: |
|
self._last_logged_events_samples[name] = current_sample |
|
return True |
|
|
|
return False |
|
|