File size: 2,903 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import TYPE_CHECKING, Any, Tuple, List, Dict, Optional

from .utils import check_comet_availability
from .monitor import Monitor

import deepspeed.comm as dist

if TYPE_CHECKING:
    import comet_ml
    from .config import CometConfig

# Type aliases describing a single monitored event.
Name = str  # event name; used as the metric name in CometMonitor.write_events
Value = Any  # event value; passed through as the metric value
GlobalSamples = int  # sample counter; passed as the metric `step`
# CometMonitor.write_events reads an event by index: [0] name, [1] value, [2] global samples.
Event = Tuple[Name, Value, GlobalSamples]


class CometMonitor(Monitor):
    """Monitor that forwards events to a Comet experiment as metrics.

    An experiment is started only when the config enables the monitor and
    ``dist.get_rank()`` is 0; in every other case ``experiment`` is ``None``
    and ``write_events`` does nothing.
    """

    def __init__(self, comet_config: "CometConfig"):
        """Initialise the monitor from ``comet_config``.

        ``check_comet_availability()`` and the ``comet_ml`` import are
        executed unconditionally, before the enabled/rank check.
        """
        super().__init__(comet_config)
        check_comet_availability()
        import comet_ml

        log_interval = comet_config.samples_log_interval

        self.enabled = comet_config.enabled
        self._samples_log_interval = log_interval
        self._events_log_scheduler = EventsLogScheduler(log_interval)
        self._experiment: Optional["comet_ml.ExperimentBase"] = None

        is_logging_rank = self.enabled and dist.get_rank() == 0
        if is_logging_rank:
            start_kwargs = dict(
                api_key=comet_config.api_key,
                project=comet_config.project,
                workspace=comet_config.workspace,
                experiment_key=comet_config.experiment_key,
                mode=comet_config.mode,
                online=comet_config.online,
            )
            self._experiment = comet_ml.start(**start_kwargs)

            # The experiment name is optional; it is only set when configured.
            experiment_name = comet_config.experiment_name
            if experiment_name is not None:
                self._experiment.set_name(experiment_name)

    @property
    def experiment(self) -> Optional["comet_ml.ExperimentBase"]:
        """The started Comet experiment, or ``None`` if none was started."""
        return self._experiment

    @property
    def samples_log_interval(self) -> int:
        """The ``samples_log_interval`` value taken from the config."""
        return self._samples_log_interval

    def write_events(self, event_list: List[Event]) -> None:
        """Log each due event of ``event_list`` as a Comet metric.

        Nothing is logged unless the monitor is enabled and
        ``dist.get_rank()`` is 0. An event is logged only when the events
        log scheduler reports that it needs logging for its name and sample
        count; the sample count is sent as the metric ``step``.
        """
        if not (self.enabled and dist.get_rank() == 0):
            return None

        scheduler = self._events_log_scheduler
        for event in event_list:
            metric_name, metric_value, global_samples = event[0], event[1], event[2]

            if not scheduler.needs_logging(metric_name, global_samples):
                continue

            self._experiment.__internal_api__log_metric__(
                name=metric_name,
                value=metric_value,
                step=global_samples,
            )


class EventsLogScheduler:
    """Decide, per event name, whether an event is due to be logged.

    An event name is due the first time it is seen, and afterwards whenever
    at least ``samples_log_interval`` samples have passed since the sample
    count at which that name was last reported as due.
    """

    def __init__(self, samples_log_interval: int):
        """Store the interval and start with no logging history."""
        self._samples_log_interval = samples_log_interval
        # Maps event name -> sample count at which it was last reported due.
        self._last_logged_events_samples: Dict[str, int] = {}

    def needs_logging(self, name: str, current_sample: int) -> bool:
        """Return ``True`` if the event ``name`` is due at ``current_sample``.

        When the result is ``True``, ``current_sample`` is recorded as the
        new reference point for ``name``.
        """
        history = self._last_logged_events_samples

        try:
            previous_sample = history[name]
        except KeyError:
            # First occurrence of this name: always due.
            is_due = True
        else:
            elapsed = current_sample - previous_sample
            is_due = bool(elapsed >= self._samples_log_interval)

        if is_due:
            history[name] = current_sample
        return is_due