File size: 4,536 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import abc
import typing as t

from .interface.summary_record import SummaryItem, SummaryRecord


def _get_dict(d):
    """Return ``d`` as a plain dict.

    A dict is handed back unchanged (the very same object). Any other value
    is assumed to be an ``argparse.Namespace``-like object and is converted
    with ``vars``, which yields that object's attribute dictionary.
    """
    return d if isinstance(d, dict) else vars(d)


class _MissingSummaryKeyError(AttributeError, KeyError):
    """Raised when attribute-style access names a key that is not present.

    It is an ``AttributeError`` so that ``hasattr`` and three-argument
    ``getattr`` treat the key as a missing attribute, and it is also a
    ``KeyError`` so that callers which already catch ``KeyError`` around
    attribute access keep working.
    """


class SummaryDict(metaclass=abc.ABCMeta):
    """Dict-like view over one level of the summary.

    Subclasses provide two hooks: ``_as_dict`` returns the mapping that backs
    this level, and ``_update`` receives a ``SummaryRecord`` describing a
    change. Reads are served from ``_as_dict()``; writes and deletes never
    touch that mapping directly, they are only turned into a record and
    passed to ``_update``.

    Keys can be used with either item syntax (``s["k"]``) or attribute syntax
    (``s.k``). Because ``__setattr__`` is aliased to ``__setitem__``, internal
    instance state has to be stored with ``object.__setattr__``.
    """

    @abc.abstractmethod
    def _as_dict(self):
        """Return the mapping that backs this level of the summary."""
        raise NotImplementedError

    @abc.abstractmethod
    def _update(self, record: SummaryRecord):
        """Handle a record listing the keys to update or remove."""
        raise NotImplementedError

    def keys(self):
        """Return a list of the keys at this level, omitting ``"_wandb"``."""
        return [k for k in self._as_dict() if k != "_wandb"]

    def get(self, key, default=None):
        """Return the stored value for ``key``, or ``default`` if it is absent.

        Unlike ``self[key]``, a nested dict is returned as stored rather than
        wrapped in a ``SummarySubDict``.
        """
        return self._as_dict().get(key, default)

    def __contains__(self, key):
        """Return True if ``key`` is present in the backing mapping.

        Without this method ``key in summary`` falls back to calling
        ``__getitem__`` with 0, 1, 2, ... and fails with ``KeyError: 0``.
        """
        return key in self._as_dict()

    def __getitem__(self, key):
        """Return the value stored under ``key``.

        A nested dict is wrapped in a ``SummarySubDict`` that remembers this
        object as its parent and ``key`` as its position in it; any other
        value is returned unchanged.

        Raises:
            KeyError: if ``key`` is not in the backing mapping.
        """
        item = self._as_dict()[key]

        if isinstance(item, dict):
            # this nested dict needs to be wrapped:
            wrapped_item = SummarySubDict()
            # bypass our own __setattr__, which would emit an update record
            object.__setattr__(wrapped_item, "_items", item)
            object.__setattr__(wrapped_item, "_parent", self)
            object.__setattr__(wrapped_item, "_parent_key", key)

            return wrapped_item

        # this item isn't a nested dict
        return item

    def __getattr__(self, key):
        """Attribute-style read, equivalent to ``self[key]``.

        Python only calls this after normal attribute lookup has failed.

        Raises:
            AttributeError: if ``key`` is not present. The raised error is
                also a ``KeyError``.
        """
        try:
            return self[key]
        except KeyError as err:
            raise _MissingSummaryKeyError(key) from err

    def __setitem__(self, key, val):
        """Set a single key; shorthand for ``self.update({key: val})``."""
        self.update({key: val})

    # ``summary.x = v`` goes through the same path as ``summary["x"] = v``.
    __setattr__ = __setitem__

    def __delattr__(self, key):
        """Send a record asking for ``key`` to be removed."""
        record = SummaryRecord()
        item = SummaryItem()
        item.key = (key,)
        record.remove = (item,)
        self._update(record)

    # ``del summary["x"]`` goes through the same path as ``del summary.x``.
    __delitem__ = __delattr__

    def update(self, d: t.Dict):
        """Send one record containing an update item for every pair in ``d``.

        Each item carries the key as a one-element tuple plus the new value.
        The record is passed to ``_update`` even when ``d`` is empty.
        """
        record = SummaryRecord()
        for key, value in d.items():
            item = SummaryItem()
            item.key = (key,)
            item.value = value
            record.update.append(item)

        self._update(record)


class Summary(SummaryDict):
    """Hold a single value per metric for a run.

    Unless you override it, a metric's summary is the last value in its
    History: `wandb.log({'accuracy': 0.9})` adds a new History step and sets
    the summary to that latest value. When the maximum or minimum of a metric
    is more useful than the final value, set the summary entry yourself, e.g.
    `wandb.summary['accuracy'] = best_acc`.

    Summary metrics are shown in the UI table used to compare runs, and they
    also feed visualizations such as the scatter plot and the parallel
    coordinates chart.

    Evaluation metrics can be saved to a run once training has finished.
    Numpy arrays and PyTorch/TensorFlow tensors are supported: when one of
    these is saved to Summary, the whole tensor is persisted in a binary file
    and high level metrics (min, mean, variance, 95th percentile) are stored
    in the summary object.

    Examples:
        ```python
        wandb.init(config=args)

        best_accuracy = 0
        for epoch in range(1, args.epochs + 1):
            test_loss, test_accuracy = test()
            if test_accuracy > best_accuracy:
                wandb.run.summary["best_accuracy"] = test_accuracy
                best_accuracy = test_accuracy
        ```
    """

    _update_callback: t.Callable
    _get_current_summary_callback: t.Callable

    def __init__(self, get_current_summary_callback: t.Callable):
        """Create a summary that reads its contents from the given callback.

        No update callback is installed yet; until `_set_update_callback` is
        called, records passed to `_update` are dropped.
        """
        super().__init__()
        # Regular attribute assignment is routed to SummaryDict.__setitem__,
        # so internal state is written through object.__setattr__ instead.
        initial_state = (
            ("_update_callback", None),
            ("_get_current_summary_callback", get_current_summary_callback),
        )
        for attr_name, attr_value in initial_state:
            object.__setattr__(self, attr_name, attr_value)

    def _set_update_callback(self, update_callback: t.Callable):
        """Install the callable that will receive every update record."""
        object.__setattr__(self, "_update_callback", update_callback)

    def _as_dict(self):
        """Return whatever the current-summary callback returns right now."""
        fetch_current = self._get_current_summary_callback
        return fetch_current()

    def _update(self, record: SummaryRecord):
        """Pass `record` to the update callback, if a truthy one is set."""
        callback = self._update_callback  # type: ignore
        if not callback:
            return
        callback(record)


class SummarySubDict(SummaryDict):
    """A nested (non-root) level of the summary.

    Keeps the dict backing this level together with a reference to the parent
    node and the key under which this level is stored in that parent. Update
    records are not handled here; they are handed up to the parent.
    """

    _items: t.Dict
    _parent: SummaryDict
    _parent_key: str

    def __init__(self):
        """Start empty and detached: no items, no parent, no parent key."""
        # Regular attribute assignment is routed to SummaryDict.__setitem__,
        # so internal state is written through object.__setattr__ instead.
        defaults = (
            ("_items", {}),
            ("_parent", None),
            ("_parent_key", None),
        )
        for attr_name, initial in defaults:
            object.__setattr__(self, attr_name, initial)

    def _as_dict(self):
        """Return the dict backing this level."""
        return self._items

    def _update(self, record: SummaryRecord):
        """Forward `record` to the parent and return the parent's result.

        The record is first passed through `record._add_next_parent` with this
        node's parent key (presumably to prefix the item paths; confirm in
        SummaryRecord).
        """
        forward_to_parent = self._parent._update
        return forward_to_parent(record._add_next_parent(self._parent_key))