File size: 2,903 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from typing import TYPE_CHECKING, Any, Tuple, List, Dict, Optional
from .utils import check_comet_availability
from .monitor import Monitor
import deepspeed.comm as dist
if TYPE_CHECKING:
import comet_ml
from .config import CometConfig
# Type aliases describing the events passed to CometMonitor.write_events.
# Name of the metric carried by an event.
Name = str
# Value of the metric; deliberately typed as Any.
Value = Any
# Global-sample counter attached to an event. CometMonitor uses it as the
# metric step, and EventsLogScheduler uses it for per-metric interval throttling.
GlobalSamples = int
# A single event: a (name, value, global_samples) triple.
Event = Tuple[Name, Value, GlobalSamples]
class CometMonitor(Monitor):
    """Monitor backend that reports events to a Comet experiment.

    Only the process for which ``dist.get_rank() == 0`` creates an experiment
    and logs metrics. When the monitor is disabled, or on any other rank, no
    experiment is created and ``write_events`` does nothing.
    """

    def __init__(self, comet_config: "CometConfig"):
        super().__init__(comet_config)

        # comet_ml is imported here rather than at module level. The
        # availability check runs first (presumably raising a clearer error
        # when comet_ml is missing).
        check_comet_availability()
        import comet_ml

        self.enabled = comet_config.enabled
        self._samples_log_interval = comet_config.samples_log_interval
        self._experiment: Optional["comet_ml.ExperimentBase"] = None

        # Short-circuits, so the rank is only queried when the monitor is enabled.
        owns_experiment = self.enabled and dist.get_rank() == 0
        if owns_experiment:
            self._experiment = comet_ml.start(
                api_key=comet_config.api_key,
                project=comet_config.project,
                workspace=comet_config.workspace,
                experiment_key=comet_config.experiment_key,
                mode=comet_config.mode,
                online=comet_config.online,
            )
            requested_name = comet_config.experiment_name
            if requested_name is not None:
                self._experiment.set_name(requested_name)

        # Throttles how often each metric name is forwarded to Comet.
        self._events_log_scheduler = EventsLogScheduler(comet_config.samples_log_interval)

    @property
    def experiment(self) -> Optional["comet_ml.ExperimentBase"]:
        """Comet experiment owned by this process, or ``None`` when none was started."""
        return self._experiment

    @property
    def samples_log_interval(self) -> int:
        """Configured sample interval used to throttle repeated logging of a metric."""
        return self._samples_log_interval

    def write_events(self, event_list: List[Event]) -> None:
        """Forward due events to Comet as metrics.

        Args:
            event_list: ``(name, value, global_samples)`` tuples. The third
                element is used as the Comet step. Each name is logged only
                when the events log scheduler reports it as due.

        Does nothing when the monitor is disabled or this is not rank 0.
        """
        if self.enabled and dist.get_rank() == 0:
            for event in event_list:
                name, value, step = event[0], event[1], event[2]
                if not self._events_log_scheduler.needs_logging(name, step):
                    continue
                # NOTE(review): this is comet_ml's internal logging entry point
                # (per its name); it may change between comet_ml versions.
                self._experiment.__internal_api__log_metric__(
                    name=name,
                    value=value,
                    step=step,
                )
class EventsLogScheduler:
    """Per-metric throttle deciding whether an event should be logged.

    The first time a metric name is seen it is always logged. After that, it
    is logged again only once ``current_sample`` has advanced by at least
    ``samples_log_interval`` past the sample at which it was last logged.
    """

    def __init__(self, samples_log_interval: int):
        self._samples_log_interval = samples_log_interval
        # Metric name -> sample at which that metric was most recently logged.
        self._last_logged_events_samples: Dict[str, int] = {}

    def needs_logging(self, name: str, current_sample: int) -> bool:
        """Report whether ``name`` is due at ``current_sample``.

        When the answer is True, ``current_sample`` is recorded as the new
        last-logged sample for ``name``. When it is False, the state is unchanged.
        """
        history = self._last_logged_events_samples
        is_due = bool(
            name not in history
            or current_sample - history[name] >= self._samples_log_interval
        )
        if is_due:
            history[name] = current_sample
        return is_due
|