jamtur01's picture
Upload folder using huggingface_hub
9c6594c verified
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from typing import TYPE_CHECKING, Any, Tuple, List, Dict, Optional
from .utils import check_comet_availability
from .monitor import Monitor
import deepspeed.comm as dist
if TYPE_CHECKING:
import comet_ml
from .config import CometConfig
# Type aliases describing one monitor event: a metric name, its value, and the
# engine's global sample count at which it was produced (used as the step when
# the event is forwarded to Comet).
Name, Value, GlobalSamples = str, Any, int
Event = Tuple[Name, Value, GlobalSamples]
class CometMonitor(Monitor):
    """Monitor backend that forwards events to a Comet experiment.

    Only rank 0 talks to Comet: the experiment is started on rank 0 when the
    monitor is enabled, and ``write_events`` does nothing on any other rank
    (or when disabled). Per-metric throttling is delegated to
    ``EventsLogScheduler``.
    """

    def __init__(self, comet_config: "CometConfig"):
        super().__init__(comet_config)
        check_comet_availability()
        import comet_ml

        self.enabled = comet_config.enabled
        self._samples_log_interval = comet_config.samples_log_interval
        self._experiment: Optional["comet_ml.ExperimentBase"] = None

        # Every rank constructs the monitor, but only rank 0 opens an experiment.
        if self.enabled and dist.get_rank() == 0:
            self._experiment = self._start_experiment(comet_ml, comet_config)

        self._events_log_scheduler = EventsLogScheduler(comet_config.samples_log_interval)

    @staticmethod
    def _start_experiment(comet_ml, comet_config: "CometConfig"):
        """Start a Comet experiment from *comet_config* and apply its optional name."""
        experiment = comet_ml.start(
            api_key=comet_config.api_key,
            project=comet_config.project,
            workspace=comet_config.workspace,
            experiment_key=comet_config.experiment_key,
            mode=comet_config.mode,
            online=comet_config.online,
        )
        display_name = comet_config.experiment_name
        if display_name is not None:
            experiment.set_name(display_name)
        return experiment

    @property
    def experiment(self) -> Optional["comet_ml.ExperimentBase"]:
        """The Comet experiment, or ``None`` when disabled or not on rank 0."""
        return self._experiment

    @property
    def samples_log_interval(self) -> int:
        """Minimum number of samples between two logs of the same metric."""
        return self._samples_log_interval

    def write_events(self, event_list: List[Event]) -> None:
        """Send each ``(name, value, global_samples)`` event to Comet.

        A no-op unless the monitor is enabled and this process is rank 0.
        An event is sent only when the scheduler reports it as due; its global
        sample count is used as the Comet step.
        """
        if self.enabled and dist.get_rank() == 0:
            scheduler = self._events_log_scheduler
            for event in event_list:
                metric_name, metric_value, step = event[0], event[1], event[2]
                if scheduler.needs_logging(metric_name, step):
                    self._experiment.__internal_api__log_metric__(
                        name=metric_name,
                        value=metric_value,
                        step=step,
                    )
class EventsLogScheduler:
    """Decide, per event name, whether an event is due to be logged.

    A name is logged the first time it is seen. After that it is logged again
    only once at least ``samples_log_interval`` samples have elapsed since the
    last time that same name was logged.
    """

    def __init__(self, samples_log_interval: int):
        self._samples_log_interval = samples_log_interval
        # Event name -> sample count at which that name was most recently logged.
        self._last_logged_events_samples: Dict[str, int] = {}

    def needs_logging(self, name: str, current_sample: int) -> bool:
        """Return whether ``name`` should be logged at ``current_sample``.

        When the result is ``True``, ``current_sample`` is recorded as the new
        last-logged point for ``name``; a ``False`` result leaves it unchanged.
        """
        history = self._last_logged_events_samples
        due = name not in history or bool(
            current_sample - history[name] >= self._samples_log_interval)
        if due:
            history[name] = current_sample
        return due