|
import json |
|
import os |
|
from typing import Any, Dict, List, Optional, Union |
|
|
|
import wandb |
|
import wandb.data_types as data_types |
|
from wandb.data_types import _SavedModel |
|
from wandb.sdk import wandb_setup |
|
from wandb.sdk.artifacts.artifact import Artifact |
|
from wandb.sdk.artifacts.artifact_manifest_entry import ArtifactManifestEntry |
|
|
|
|
|
def _add_any(
    artifact: Artifact,
    path_or_obj: Union[str, ArtifactManifestEntry, data_types.WBValue],
    name: Optional[str],
) -> Any:
    """Add an object to an artifact.

    High-level wrapper that dispatches to one of the `.add*` methods of
    `Artifact` depending on the type of `path_or_obj`. This will probably be
    moved to the Artifact class in the future.

    Args:
        artifact: `Artifact` - artifact created with `wandb.Artifact(...)`.
        path_or_obj: `Union[str, ArtifactManifestEntry, data_types.WBValue]` -
            what to add to the artifact:

            - `ArtifactManifestEntry`: added with `artifact.add_reference`.
            - `WBValue`: added with `artifact.add`.
            - `str` naming an existing directory or file: added with
              `artifact.add_dir` / `artifact.add_file` (`name` is not used).
            - any other `str`: JSON-encoded and written to a new artifact
              file called `name`.
        name: `str` - the name of the object which is added to the artifact.

    Returns:
        Whatever the selected `Artifact` method returns, or `None` when a plain
        string was written to a new file.

    Raises:
        TypeError: if `path_or_obj` is not one of the supported types.
    """
    if isinstance(path_or_obj, ArtifactManifestEntry):
        return artifact.add_reference(path_or_obj, name)

    if isinstance(path_or_obj, data_types.WBValue):
        return artifact.add(path_or_obj, name)

    if not isinstance(path_or_obj, str):
        raise TypeError(
            "Expected `path_or_obj` to be instance of `ArtifactManifestEntry`,"
            f" `WBValue`, or `str`, found {type(path_or_obj)}"
        )

    # Strings are treated as filesystem paths first; the directory check comes
    # before the file check, exactly as in the original dispatch order.
    if os.path.isdir(path_or_obj):
        return artifact.add_dir(path_or_obj)
    if os.path.isfile(path_or_obj):
        return artifact.add_file(path_or_obj)

    # Not an existing path: store the string itself as JSON under `name`.
    with artifact.new_file(name) as f:
        f.write(json.dumps(path_or_obj, sort_keys=True))
    return None
|
|
|
|
|
def _log_artifact_version(
    name: str,
    type: str,
    entries: Dict[str, Union[str, ArtifactManifestEntry, data_types.WBValue]],
    aliases: Optional[Union[str, List[str]]] = None,
    description: Optional[str] = None,
    metadata: Optional[dict] = None,
    project: Optional[str] = None,
    scope_project: Optional[bool] = None,
    job_type: str = "auto",
) -> Artifact:
    """Build an artifact from `entries` and log it with a run.

    The most recent active run is used; when there is none, a new run is
    started with `wandb.init` (silent settings) using `project` and `job_type`.

    Args:
        name: `str` - name of the artifact. Unless `scope_project` is truthy,
            the name is suffixed with "-{run_id}".
        type: `str` - type of the artifact, used in the UI to group artifacts
            of the same type.
        entries: `Dict` - maps an entry name to the object to add under that
            name; each pair is handed to `_add_any`.
        aliases: `str, List[str]` - optional alias(es) for the logged artifact;
            resolved with `wandb.util._resolve_aliases` before logging.
        description: `str` - text description of the artifact.
        metadata: `Dict` - artifact-specific metadata, visible in the UI. An
            empty dict is used when omitted.
        project: `str` - project for the run that is created when no run is
            active.
        scope_project: `bool` - if True, `name` is not suffixed with
            "-{run_id}".
        job_type: `str` - only applied when no run is present and one is
            created. Identifies runs of a certain job type, e.g. "evaluation".

    Returns:
        The `Artifact` that was created and logged.
    """
    active_run = wandb_setup.singleton().most_recent_active_run
    if not active_run:
        quiet_settings = wandb.Settings(silent=True)
        active_run = wandb.init(
            project=project, job_type=job_type, settings=quiet_settings
        )

    artifact_name = name if scope_project else f"{name}-{active_run.id}"
    artifact_metadata = {} if metadata is None else metadata

    logged = wandb.Artifact(
        artifact_name, type, description, artifact_metadata, False, None
    )
    for entry_name, entry in entries.items():
        _add_any(logged, entry, entry_name)

    resolved_aliases = wandb.util._resolve_aliases(aliases)
    active_run.log_artifact(logged, aliases=resolved_aliases)
    return logged
|
|
|
|
|
def log_model(
    model_obj: Any,
    name: str = "model",
    aliases: Optional[Union[str, List[str]]] = None,
    description: Optional[str] = None,
    metadata: Optional[dict] = None,
    project: Optional[str] = None,
    scope_project: Optional[bool] = None,
    **kwargs: Any,
) -> "_SavedModel":
    """Log a model object to enable model-centric workflows in the UI.

    Supported frameworks include PyTorch, Keras, Tensorflow, Scikit-learn, etc. Under
    the hood, we create a model artifact, bind it to the run that produced this model,
    associate it with the latest metrics logged with `run.log(...)` and more.

    Args:
        model_obj: any model object created with the following ML frameworks: PyTorch,
            Keras, Tensorflow, Scikit-learn.
        name: `str` - name of the model artifact that will be created to house
            this model_obj.
        aliases: `str, List[str]` - optional alias(es) that will be applied on this
            model and allow for unique identification. The alias "latest" will always be
            applied to the latest version of a model.
        description: `str` - text description/notes about the model - will be visible in
            the Model Card UI.
        metadata: `Dict` - model-specific metadata goes here - will be visible in
            the UI.
        project: `str` - project under which to place this artifact.
        scope_project: `bool` - If true, name of this model artifact will not be
            suffixed by `-{run_id}`.
        **kwargs: extra keyword arguments forwarded unchanged to
            `_SavedModel.init`.

    Returns:
        _SavedModel instance

    Example:
        ```python
        import torch.nn as nn
        import torch.nn.functional as F


        class Net(nn.Module):
            def __init__(self):
                super(Net, self).__init__()
                self.fc1 = nn.Linear(10, 10)

            def forward(self, x):
                x = self.fc1(x)
                x = F.relu(x)
                return x


        model = Net()
        sm = log_model(model, "my-simple-model", aliases=["best"])
        ```

    """
    model = data_types._SavedModel.init(model_obj, **kwargs)

    # The saved model is stored under the "index" entry, which is the key
    # `use_model` reads back. The returned artifact itself is not needed here.
    _log_artifact_version(
        name=name,
        type="model",
        entries={"index": model},
        aliases=aliases,
        description=description,
        metadata=metadata,
        project=project,
        scope_project=scope_project,
        job_type="log_model",
    )

    return model
|
|
|
|
|
def use_model(aliased_path: str, unsafe: bool = False) -> "_SavedModel":
    """Fetch a saved model by its aliased artifact path.

    The path is passed to `run.use_artifact` on the most recent active run,
    which also declares the artifact as an input to that run. The model is
    then read from the artifact's "index" entry.

    Args:
        aliased_path: `str` - the following forms are valid: "name:version",
            "name:alias". May be prefixed with "entity/project".
        unsafe: `bool` - must be True to indicate the user understands the risks
            associated with loading external models.

    Returns:
        _SavedModel instance

    Raises:
        ValueError: if `unsafe` is falsy, if `aliased_path` contains no ":",
            if there is no active run, or if the "index" entry is not a
            `_SavedModel`.

    Example:
        ```python
        # Assuming the model with the name "my-simple-model" is trusted:
        sm = use_model("my-simple-model:latest", unsafe=True)
        model = sm.model_obj()
        ```
    """
    if not unsafe:
        raise ValueError("The 'unsafe' parameter must be set to True to load a model.")

    if ":" not in aliased_path:
        raise ValueError(
            "aliased_path must be of the form 'name:alias' or 'name:version'."
        )

    active_run = wandb_setup.singleton().most_recent_active_run
    if not active_run:
        raise ValueError(
            "use_model can only be called inside a run. "
            "Please call wandb.init() before use_model(...)"
        )

    model_artifact = active_run.use_artifact(aliased_path)
    saved_model = model_artifact.get("index")

    # isinstance() is False for None, so a missing entry is rejected here too.
    if not isinstance(saved_model, _SavedModel):
        raise ValueError(
            "Deserialization into model object failed: "
            "_SavedModel instance could not be initialized properly."
        )

    return saved_model
|
|
|
|
|
def link_model(
    model: "_SavedModel",
    target_path: str,
    aliases: Optional[Union[str, List[str]]] = None,
) -> None:
    """Link the given model to a portfolio.

    A portfolio is a promoted collection which contains (in this case) model artifacts.
    Linking to a portfolio allows for useful model-centric workflows in the UI.

    With an active run, the link goes through `run.link_artifact`, using the
    model's source artifact if it has one, otherwise its target artifact when
    that artifact's `_final` flag is set. Without an active run, the model's
    source artifact is linked directly with `artifact.link`.

    Args:
        model: `_SavedModel` - an instance of _SavedModel, most likely from the output
            of `log_model` or `use_model`.
        target_path: `str` - the target portfolio. The following forms are valid for the
            string: {portfolio}, {project}/{portfolio}, {entity}/{project}/{portfolio}.
        aliases: `str, List[str]` - optional alias(es) that will only be applied on this
            linked model inside the portfolio. The alias "latest" will always be applied
            to the latest version of a model.

    Returns:
        None

    Raises:
        ValueError: if the model has no artifact that can be linked.

    Example:
        ```python
        sm = use_model("my-simple-model:latest", unsafe=True)
        link_model(sm, "my-portfolio")
        ```

    """
    resolved_aliases = wandb.util._resolve_aliases(aliases)
    active_run = wandb_setup.singleton().most_recent_active_run

    if not active_run:
        # No run to link through: only a source artifact can be linked.
        if model._artifact_source is None:
            raise ValueError(
                "Linking requires that the given _SavedModel belongs to a logged artifact."
            )
        model._artifact_source.artifact.link(target_path, resolved_aliases)
        return

    if model._artifact_source:
        linked_artifact = model._artifact_source.artifact
    elif model._artifact_target and model._artifact_target.artifact._final:
        linked_artifact = model._artifact_target.artifact
    else:
        raise ValueError(
            "Linking requires that the given _SavedModel belongs to an artifact"
        )

    active_run.link_artifact(linked_artifact, target_path, resolved_aliases)
|
|