File size: 4,284 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from __future__ import annotations

from typing import Any, Iterable

from wandb.proto import wandb_settings_pb2
from wandb.sdk.lib import RunMoment
from wandb.sdk.wandb_settings import Settings


class SettingsStatic(Settings):
    """A readonly object that wraps a protobuf Settings message.

    Implements the mapping protocol, so you can access settings as
    attributes or items.
    """

    def __init__(self, proto: wandb_settings_pb2.Settings) -> None:
        data = self._proto_to_dict(proto)
        super().__init__(**data)

    def _proto_to_dict(self, proto: wandb_settings_pb2.Settings) -> dict:
        data = {}

        exclude_fields = {
            "model_config",
            "model_fields",
            "model_fields_set",
            "__fields__",
            "__model_fields_set",
            "__pydantic_self__",
            "__pydantic_initialised__",
        }

        fields = (
            Settings.model_fields
            if hasattr(Settings, "model_fields")
            else Settings.__fields__
        )  # type: ignore [attr-defined]

        fields = {k: v for k, v in fields.items() if k not in exclude_fields}  # type: ignore [union-attr]

        forks_specified: list[str] = []
        for key in fields:
            # Skip Python-only keys that do not exist on the proto.
            if key in ("reinit",):
                continue

            value: Any = None

            field_info = fields[key]
            annotation = str(field_info.annotation)

            if key == "_stats_open_metrics_filters":
                # todo: it's an underscored field, refactor into
                #  something more elegant?
                # I'm really about this. It's ugly, but it works.
                # Do not try to repeat this at home.
                value_type = getattr(proto, key).WhichOneof("value")
                if value_type == "sequence":
                    value = list(getattr(proto, key).sequence.value)
                elif value_type == "mapping":
                    unpacked_mapping = {}
                    for outer_key, outer_value in getattr(
                        proto, key
                    ).mapping.value.items():
                        unpacked_inner = {}
                        for inner_key, inner_value in outer_value.value.items():
                            unpacked_inner[inner_key] = inner_value
                        unpacked_mapping[outer_key] = unpacked_inner
                    value = unpacked_mapping
            elif key == "fork_from" or key == "resume_from":
                value = getattr(proto, key)
                if value.run:
                    value = RunMoment(
                        run=value.run, value=value.value, metric=value.metric
                    )
                    forks_specified.append(key)
                else:
                    value = None
            else:
                if proto.HasField(key):  # type: ignore [arg-type]
                    value = getattr(proto, key).value
                    # Convert to list if the field is a sequence
                    if any(t in annotation for t in ("tuple", "Sequence", "list")):
                        value = list(value)
                else:
                    value = None

            if value is not None:
                data[key] = value

        if len(forks_specified) > 1:
            raise ValueError(
                "Only one of fork_from or resume_from can be specified, not both"
            )

        return data

    def __setattr__(self, name: str, value: object) -> None:
        raise AttributeError("Error: SettingsStatic is a readonly object")

    def __setitem__(self, key: str, val: object) -> None:
        raise AttributeError("Error: SettingsStatic is a readonly object")

    def keys(self) -> Iterable[str]:
        return self.__dict__.keys()

    def __getitem__(self, key: str) -> Any:
        return self.__dict__[key]

    def __getattr__(self, name: str) -> Any:
        try:
            return self.__dict__[name]
        except KeyError:
            raise AttributeError(f"SettingsStatic has no attribute {name}")

    def __str__(self) -> str:
        return str(self.__dict__)

    def __contains__(self, key: str) -> bool:
        return key in self.__dict__