import json
import os
from typing import Any, Dict, List, Optional, Union

import wandb
import wandb.data_types as data_types
from wandb.data_types import _SavedModel
from wandb.sdk import wandb_setup
from wandb.sdk.artifacts.artifact import Artifact
from wandb.sdk.artifacts.artifact_manifest_entry import ArtifactManifestEntry


def _add_any(
    artifact: Artifact,
    path_or_obj: Union[
        str, ArtifactManifestEntry, data_types.WBValue
    ],  # todo: add dataframe
    name: Optional[str],
) -> Any:
    """Add an object to an artifact.

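    High-level wrapper that calls the matching `Artifact` method (`add_reference`,
    `add`, `add_dir`, `add_file`, or `new_file`) depending on the type of object
    passed in. This will probably be moved to the `Artifact` class in the future.

    Args:
        artifact: `Artifact` - artifact created with `wandb.Artifact(...)`.
        path_or_obj: `Union[str, ArtifactManifestEntry, data_types.WBValue]` - what to
            add to the artifact. An `ArtifactManifestEntry` is added as a reference
            and a `WBValue` is added under `name`. A string that is a directory or
            file path adds that directory or file (`name` is ignored); any other
            string is JSON-encoded and written to a new file called `name`.
        name: `Optional[str]` - the name under which the object is added to the
            artifact.

    Returns:
        Any - whatever the dispatched `Artifact` method returns (for example an
            `ArtifactManifestEntry`), or `None` when a string is written out as JSON.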

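    Example:
        The string branch is the least obvious part of the dispatch. The sketch below
        mirrors only that branch, using a hypothetical `RecordingArtifact` stand-in
        (not part of wandb) so that it runs without a W&B run: a directory goes to
        `add_dir`, a file to `add_file`, and any other string is JSON-encoded into a
        new file named `name`.

        ```python
        import json
        import os
        import tempfile


        class RecordingArtifact:
            # Hypothetical stand-in (not a wandb class) that records which method
            # `_add_any` would call.
            def __init__(self):
                self.calls = []

            def add_dir(self, path):
                self.calls.append(("add_dir", path))

            def add_file(self, path):
                self.calls.append(("add_file", path))

            def new_file_json(self, name, text):
                # Stands in for `with artifact.new_file(name) as f: f.write(text)`.
                self.calls.append(("new_file", (name, text)))


        def dispatch_str(artifact, path_or_obj, name):
            # Hypothetical helper: same order of checks as the `str` branch above.
            if os.path.isdir(path_or_obj):
                artifact.add_dir(path_or_obj)
            elif os.path.isfile(path_or_obj):
                artifact.add_file(path_or_obj)
            else:
                artifact.new_file_json(name, json.dumps(path_or_obj, sort_keys=True))


        with tempfile.TemporaryDirectory() as tmp:
            weights = os.path.join(tmp, "weights.bin")
            with open(weights, "wb") as f:
                f.write(b"0")
            art = RecordingArtifact()
            dispatch_str(art, tmp, None)
            dispatch_str(art, weights, None)
            dispatch_str(art, "just a note", "notes.json")

        print([method for method, _ in art.calls])  # ['add_dir', 'add_file', 'new_file']
        ```
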
    """
    if isinstance(path_or_obj, ArtifactManifestEntry):
        return artifact.add_reference(path_or_obj, name)
    elif isinstance(path_or_obj, data_types.WBValue):
        return artifact.add(path_or_obj, name)
    elif isinstance(path_or_obj, str):
        if os.path.isdir(path_or_obj):
            return artifact.add_dir(path_or_obj)
        elif os.path.isfile(path_or_obj):
            return artifact.add_file(path_or_obj)
        else:
            with artifact.new_file(name) as f:
                f.write(json.dumps(path_or_obj, sort_keys=True))
    else:
        raise TypeError(
            "Expected `path_or_obj` to be instance of `ArtifactManifestEntry`,"
            f" `WBValue`, or `str`, found {type(path_or_obj)}"
        )


def _log_artifact_version(
    name: str,
    type: str,
    entries: Dict[str, Union[str, ArtifactManifestEntry, data_types.WBValue]],
    aliases: Optional[Union[str, List[str]]] = None,
    description: Optional[str] = None,
    metadata: Optional[dict] = None,
    project: Optional[str] = None,
    scope_project: Optional[bool] = None,
    job_type: str = "auto",
) -> Artifact:
    """Create an artifact, populate it, and log it with a run.

    If a run is not present, we create one.

    Args:
        name: `str` - name of the artifact. If not scoped to a project, name will be
            suffixed by "-{run_id}".
        type: `str` - type of the artifact, used in the UI to group artifacts of the
            same type.
        entries: `Dict` - dictionary mapping the name of each entry to the object to
            add to this artifact.
        aliases: `str, List[str]` - optional alias(es) to apply to this artifact
            version. The alias "latest" is always applied.
        description: `str` - text description of artifact.
        metadata: `Dict` - artifact-specific metadata, visible in the UI.
        project: `str` - project under which to place this artifact.
        scope_project: `bool` - if True, `name` is not suffixed with "-{run_id}".
        job_type: `str` - only used if no run is active and one has to be created.
            Identifies runs of a certain job type, e.g. "evaluation".

    Returns:
        Artifact
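
    Example:
        Two naming rules in this function are easy to miss: `name` gets a
        "-{run_id}" suffix unless `scope_project` is truthy, and "latest" is always
        added to the aliases. The sketch below restates both with hypothetical helpers
        (not part of wandb); the alias helper assumes `wandb.util._resolve_aliases`
        normalizes its input to a list and appends "latest" when it is missing.

        ```python
        from typing import List, Optional, Union


        def sketch_artifact_name(
            name: str, run_id: str, scope_project: Optional[bool]
        ) -> str:
            # Hypothetical helper mirroring the `scope_project` check in this function.
            return name if scope_project else f"{name}-{run_id}"


        def sketch_resolve_aliases(aliases: Optional[Union[str, List[str]]]) -> List[str]:
            # Hypothetical helper; assumed behavior: accept None, one alias, or a list,
            # and make sure "latest" is present.
            if aliases is None:
                resolved: List[str] = []
            elif isinstance(aliases, str):
                resolved = [aliases]
            else:
                resolved = list(aliases)
            if "latest" not in resolved:
                resolved.append("latest")
            return resolved


        print(sketch_artifact_name("my-dataset", "abc123", None))  # my-dataset-abc123
        print(sketch_artifact_name("my-dataset", "abc123", True))  # my-dataset
        print(sketch_resolve_aliases("best"))  # ['best', 'latest']
        ```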

    """
    run = wandb_setup.singleton().most_recent_active_run
    if not run:
        run = wandb.init(
            project=project,
            job_type=job_type,
            settings=wandb.Settings(silent=True),
        )

    if not scope_project:
        name = f"{name}-{run.id}"

    if metadata is None:
        metadata = {}

    art = wandb.Artifact(
        name,
        type,
        description=description,
        metadata=metadata,
        incremental=False,
        use_as=None,
    )

    for path, obj in entries.items():
        _add_any(art, obj, path)

    # "latest" should always be present as an alias
    aliases = wandb.util._resolve_aliases(aliases)
    run.log_artifact(art, aliases=aliases)

    return art


def log_model(
    model_obj: Any,
    name: str = "model",
    aliases: Optional[Union[str, List[str]]] = None,
    description: Optional[str] = None,
    metadata: Optional[dict] = None,
    project: Optional[str] = None,
    scope_project: Optional[bool] = None,
    **kwargs: Dict[str, Any],
) -> "_SavedModel":
    """Log a model object to enable model-centric workflows in the UI.

    Supported frameworks include PyTorch, Keras, TensorFlow, scikit-learn, etc. Under
    the hood, we create a model artifact, bind it to the run that produced this model,
    associate it with the latest metrics logged with `run.log(...)` and more.

    Args:
        model_obj: any model object created with one of the following ML frameworks:
            PyTorch, Keras, TensorFlow, scikit-learn.
        name: `str` - name of the model artifact that will be created to house this
            model_obj.
        aliases: `str, List[str]` - optional alias(es) that will be applied to this
            model and allow for unique identification. The alias "latest" will always be
            applied to the latest version of a model.
        description: `str` - text description/notes about the model, visible in the
            Model Card UI.
        metadata: `Dict` - model-specific metadata, visible in the UI.
        project: `str` - project under which to place this artifact.
        scope_project: `bool` - if True, the name of this model artifact will not be
            suffixed with `-{run_id}`.
        **kwargs: extra keyword arguments forwarded to `_SavedModel.init`.

    Returns:
        _SavedModel instance

    Example:
        ```python
        import torch.nn as nn
        import torch.nn.functional as F


        class Net(nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = nn.Linear(10, 10)

            def forward(self, x):
                x = self.fc1(x)
                x = F.relu(x)
                return x


        model = Net()
        sm = log_model(model, "my-simple-model", aliases=["best"])
        ```

    """
    model = data_types._SavedModel.init(model_obj, **kwargs)
    _ = _log_artifact_version(
        name=name,
        type="model",
        entries={
            "index": model,
        },
        aliases=aliases,
        description=description,
        metadata=metadata,
        project=project,
        scope_project=scope_project,
        job_type="log_model",
    )
    # TODO: handle offline mode appropriately.
    return model


def use_model(aliased_path: str, unsafe: bool = False) -> "_SavedModel":
    """Fetch a saved model from an alias.

    Under the hood, we use the alias to fetch the model artifact containing the
    serialized model files and rebuild the model object from these files. We also
    declare the fetched model artifact as an input to the run (with `run.use_artifact`).

    Args:
        aliased_path: `str` - the following forms are valid: "name:version",
            "name:alias". May be prefixed with "entity/project".
        unsafe: `bool` - must be True to indicate the user understands the risks
            associated with loading external models.

    Returns:
        _SavedModel instance

    Example:
        ```python
        # Assuming the model with the name "my-simple-model" is trusted:
        sm = use_model("my-simple-model:latest", unsafe=True)
        model = sm.model_obj()
        ```
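
        `use_model` itself only checks for a ":" separator. The following hypothetical
        parser (not a wandb API) spells out the two shapes this docstring names,
        "name:alias_or_version" and "entity/project/name:alias_or_version". It is a
        sketch; wandb's own path resolution may accept other forms as well.

        ```python
        from typing import NamedTuple, Optional


        class AliasedPath(NamedTuple):
            # Hypothetical container for the parts of an aliased path.
            entity: Optional[str]
            project: Optional[str]
            name: str
            alias_or_version: str


        def sketch_parse_aliased_path(aliased_path: str) -> AliasedPath:
            # Same guard as `use_model`: a ":" separator is required.
            if ":" not in aliased_path:
                raise ValueError(
                    "aliased_path must be of the form 'name:alias' or 'name:version'."
                )
            # Split on the last ":" so the alias or version is the final segment.
            path, alias_or_version = aliased_path.rsplit(":", 1)
            parts = path.split("/")
            if len(parts) == 1:
                return AliasedPath(None, None, parts[0], alias_or_version)
            if len(parts) == 3:
                return AliasedPath(parts[0], parts[1], parts[2], alias_or_version)
            raise ValueError(f"Unsupported prefix in {aliased_path!r}")


        ref = sketch_parse_aliased_path("my-team/my-project/my-simple-model:v3")
        print(ref.entity, ref.project, ref.name, ref.alias_or_version)
        # my-team my-project my-simple-model v3
        ```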
    """
    if not unsafe:
        raise ValueError("The 'unsafe' parameter must be set to True to load a model.")

    if ":" not in aliased_path:
        raise ValueError(
            "aliased_path must be of the form 'name:alias' or 'name:version'."
        )

    # Returns a _SavedModel instance
    if run := wandb_setup.singleton().most_recent_active_run:
        artifact = run.use_artifact(aliased_path)
        sm = artifact.get("index")

        if sm is None or not isinstance(sm, _SavedModel):
            raise ValueError(
                "Deserialization into model object failed: _SavedModel instance could not be initialized properly."
            )

        return sm
    else:
        raise ValueError(
            "use_model can only be called inside a run. Please call wandb.init() before use_model(...)"
        )


def link_model(
    model: "_SavedModel",
    target_path: str,
    aliases: Optional[Union[str, List[str]]] = None,
) -> None:
    """Link the given model to a portfolio.

    A portfolio is a promoted collection which contains (in this case) model artifacts.
    Linking to a portfolio allows for useful model-centric workflows in the UI.

    Args:
        model: `_SavedModel` - an instance of _SavedModel, most likely from the output
            of `log_model` or `use_model`.
        target_path: `str` - the target portfolio. The following forms are valid:
            `{portfolio}`, `{project}/{portfolio}`, `{entity}/{project}/{portfolio}`.
        aliases: `str, List[str]` - optional alias(es) that will only be applied on this
            linked model inside the portfolio. The alias "latest" will always be applied
            to the latest version of a model.

    Returns:
        None

    Example:
        ```python
        sm = use_model("my-simple-model:latest", unsafe=True)
        link_model(sm, "my-portfolio")
        ```

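        Each accepted `target_path` form leaves out a different number of leading
        components. The sketch below shows one way to fill them in; it assumes the
        missing entity and project default to those of the current run, and it uses
        a hypothetical helper that is not part of wandb.

        ```python
        from typing import Tuple


        def sketch_split_target_path(
            target_path: str, default_entity: str, default_project: str
        ) -> Tuple[str, str, str]:
            # Hypothetical helper: accepts "{portfolio}", "{project}/{portfolio}" or
            # "{entity}/{project}/{portfolio}" and returns all three parts, filling
            # missing ones from the (assumed) run defaults.
            parts = target_path.split("/")
            if len(parts) == 1:
                return default_entity, default_project, parts[0]
            if len(parts) == 2:
                return default_entity, parts[0], parts[1]
            if len(parts) == 3:
                return parts[0], parts[1], parts[2]
            raise ValueError(f"Invalid target_path: {target_path!r}")


        print(sketch_split_target_path("my-portfolio", "my-team", "my-project"))
        # ('my-team', 'my-project', 'my-portfolio')
        ```
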
    """
    aliases = wandb.util._resolve_aliases(aliases)

    if run := wandb_setup.singleton().most_recent_active_run:
        # _artifact_source, if it exists, points to a Public Artifact.
        # Its existence means that _SavedModel was deserialized from a logged artifact, most likely from `use_model`.
        if model._artifact_source:
            artifact = model._artifact_source.artifact
        # If the _SavedModel has been added to a Local Artifact (most likely through `.add(WBValue)`), then
        # model._artifact_target will point to that Local Artifact.
        elif model._artifact_target and model._artifact_target.artifact._final:
            artifact = model._artifact_target.artifact
        else:
            raise ValueError(
                "Linking requires that the given _SavedModel belongs to an artifact"
            )

        run.link_artifact(artifact, target_path, aliases)

    else:
        if model._artifact_source is not None:
            model._artifact_source.artifact.link(target_path, aliases)
        else:
            raise ValueError(
                "Linking requires that the given _SavedModel belongs to a logged artifact."
            )