File size: 10,360 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 |
import json
import os
from typing import Any, Dict, List, Optional, Union
import wandb
import wandb.data_types as data_types
from wandb.data_types import _SavedModel
from wandb.sdk import wandb_setup
from wandb.sdk.artifacts.artifact import Artifact
from wandb.sdk.artifacts.artifact_manifest_entry import ArtifactManifestEntry
def _add_any(
    artifact: Artifact,
    path_or_obj: Union[
        str, ArtifactManifestEntry, data_types.WBValue
    ],  # todo: add dataframe
    name: Optional[str],
) -> Any:
    """Add an object to an artifact.

    High-level wrapper to add object(s) to an artifact - calls one of the `.add*`
    methods (or `.new_file`) on `Artifact` depending on the type of object that's
    passed in. This will probably be moved to the Artifact class in the future.

    Dispatch rules, in order:
        * `ArtifactManifestEntry` -> `artifact.add_reference(path_or_obj, name)`
        * `data_types.WBValue`    -> `artifact.add(path_or_obj, name)`
        * `str` naming an existing directory -> `artifact.add_dir(path_or_obj)`
        * `str` naming an existing file      -> `artifact.add_file(path_or_obj)`
        * any other `str` -> the string is JSON-encoded and written to a new file
          called `name` inside the artifact.

    Args:
        artifact: `Artifact` - artifact created with `wandb.Artifact(...)`
        path_or_obj: `Union[str, ArtifactManifestEntry, data_types.WBValue]` - either a
            str or valid object which indicates what to add to an artifact.
        name: `str` - the name of the object which is added to an artifact. Note
            that it is not forwarded when `path_or_obj` is an existing file or
            directory path.

    Returns:
        Whatever the selected `Artifact` method returns, or None when the string
        was written out as JSON.

    Raises:
        TypeError: if `path_or_obj` is not an `ArtifactManifestEntry`, a `WBValue`
            or a `str`.
    """
    if isinstance(path_or_obj, ArtifactManifestEntry):
        return artifact.add_reference(path_or_obj, name)
    if isinstance(path_or_obj, data_types.WBValue):
        return artifact.add(path_or_obj, name)
    if isinstance(path_or_obj, str):
        if os.path.isdir(path_or_obj):
            return artifact.add_dir(path_or_obj)
        if os.path.isfile(path_or_obj):
            return artifact.add_file(path_or_obj)
        # Not a path on disk: treat the string itself as the payload.
        with artifact.new_file(name) as f:
            f.write(json.dumps(path_or_obj, sort_keys=True))
        return None
    raise TypeError(
        "Expected `path_or_obj` to be instance of `ArtifactManifestEntry`,"
        f" `WBValue`, or `str`, found {type(path_or_obj)}"
    )
def _log_artifact_version(
    name: str,
    type: str,
    entries: Dict[str, Union[str, ArtifactManifestEntry, data_types.WBValue]],
    aliases: Optional[Union[str, List[str]]] = None,
    description: Optional[str] = None,
    metadata: Optional[dict] = None,
    project: Optional[str] = None,
    scope_project: Optional[bool] = None,
    job_type: str = "auto",
) -> Artifact:
    """Build an artifact from `entries` and log it against a run.

    The most recent active run is reused; when there is none, a new silent run
    is started with `wandb.init`.

    Args:
        name: `str` - name of the artifact. Unless `scope_project` is truthy, the
            name is suffixed with "-{run_id}".
        type: `str` - type of the artifact, used in the UI to group artifacts of the
            same type.
        entries: `Dict` - mapping of entry name to the object that should be added
            to the artifact under that name.
        aliases: `str, List[str]` - optional alias(es) passed through
            `wandb.util._resolve_aliases` before logging.
        description: `str` - text description of artifact.
        metadata: `Dict` - artifact-specific metadata, visible in the UI. Defaults
            to an empty dict.
        project: `str` - project under which to place this artifact. Only used
            when a run has to be created.
        scope_project: `bool` - if True, `name` is not suffixed with "-{run_id}".
        job_type: `str` - only applied if no run is present and one is created.
            Used to identify runs of a certain job type, i.e "evaluation".

    Returns:
        The `Artifact` that was logged.
    """
    active_run = wandb_setup.singleton().most_recent_active_run
    run = active_run or wandb.init(
        project=project,
        job_type=job_type,
        settings=wandb.Settings(silent=True),
    )

    artifact_name = name if scope_project else f"{name}-{run.id}"
    artifact_metadata = {} if metadata is None else metadata

    # The trailing `False, None` are positional constructor arguments; they are
    # presumably `incremental` and `use_as` - confirm against `wandb.Artifact`.
    artifact = wandb.Artifact(
        artifact_name, type, description, artifact_metadata, False, None
    )
    for entry_name, entry in entries.items():
        _add_any(artifact, entry, entry_name)

    # "latest" should always be present as an alias
    resolved_aliases = wandb.util._resolve_aliases(aliases)
    run.log_artifact(artifact, aliases=resolved_aliases)
    return artifact
def log_model(
    model_obj: Any,
    name: str = "model",
    aliases: Optional[Union[str, List[str]]] = None,
    description: Optional[str] = None,
    metadata: Optional[dict] = None,
    project: Optional[str] = None,
    scope_project: Optional[bool] = None,
    **kwargs: Any,
) -> "_SavedModel":
    """Log a model object to enable model-centric workflows in the UI.

    Supported frameworks include PyTorch, Keras, Tensorflow, Scikit-learn, etc. Under
    the hood, we create a model artifact, bind it to the run that produced this model,
    associate it with the latest metrics logged with `run.log(...)` and more.

    Args:
        model_obj: any model object created with the following ML frameworks: PyTorch,
            Keras, Tensorflow, Scikit-learn.
        name: `str` - name of the model artifact that will be created to house this
            model_obj.
        aliases: `str, List[str]` - optional alias(es) that will be applied on this
            model and allow for unique identification. The alias "latest" will always be
            applied to the latest version of a model.
        description: `str` - text description/notes about the model - will be visible in
            the Model Card UI.
        metadata: `Dict` - model-specific metadata goes here - will be visible the UI.
        project: `str` - project under which to place this artifact.
        scope_project: `bool` - If true, name of this model artifact will not be
            suffixed by `-{run_id}`.
        **kwargs: extra keyword arguments forwarded unchanged to
            `_SavedModel.init(model_obj, **kwargs)`.

    Returns:
        _SavedModel instance

    Example:
        ```python
        import torch.nn as nn
        import torch.nn.functional as F


        class Net(nn.Module):
            def __init__(self):
                super(Net, self).__init__()
                self.fc1 = nn.Linear(10, 10)

            def forward(self, x):
                x = self.fc1(x)
                x = F.relu(x)
                return x


        model = Net()
        sm = log_model(model, "my-simple-model", aliases=["best"])
        ```
    """
    model = data_types._SavedModel.init(model_obj, **kwargs)
    # The model is stored under the "index" entry, which is the key `use_model`
    # reads back. The returned Artifact is not needed here.
    _log_artifact_version(
        name=name,
        type="model",
        entries={
            "index": model,
        },
        aliases=aliases,
        description=description,
        metadata=metadata,
        project=project,
        scope_project=scope_project,
        job_type="log_model",
    )
    # TODO: handle offline mode appropriately.
    return model
def use_model(aliased_path: str, unsafe: bool = False) -> "_SavedModel":
    """Fetch a saved model by its aliased artifact path.

    The alias is used to fetch the model artifact through `run.use_artifact`, which
    also declares that artifact as an input to the active run. The model object is
    then read from the artifact's "index" entry.

    Args:
        aliased_path: `str` - the following forms are valid: "name:version",
            "name:alias". May be prefixed with "entity/project".
        unsafe: `bool` - must be True to indicate the user understands the risks
            associated with loading external models.

    Returns:
        _SavedModel instance

    Raises:
        ValueError: if `unsafe` is not truthy, if `aliased_path` contains no ":",
            if there is no active run, or if the "index" entry of the artifact is
            not a `_SavedModel`.

    Example:
        ```python
        # Assuming the model with the name "my-simple-model" is trusted:
        sm = use_model("my-simple-model:latest", unsafe=True)
        model = sm.model_obj()
        ```
    """
    # Argument validation happens before any run lookup.
    if not unsafe:
        raise ValueError("The 'unsafe' parameter must be set to True to load a model.")
    if ":" not in aliased_path:
        raise ValueError(
            "aliased_path must be of the form 'name:alias' or 'name:version'."
        )

    active_run = wandb_setup.singleton().most_recent_active_run
    if not active_run:
        raise ValueError(
            "use_model can only be called inside a run. Please call wandb.init() before use_model(...)"
        )

    model_artifact = active_run.use_artifact(aliased_path)
    saved_model = model_artifact.get("index")
    # A missing entry (None) fails this isinstance check as well.
    if isinstance(saved_model, _SavedModel):
        return saved_model
    raise ValueError(
        "Deserialization into model object failed: _SavedModel instance could not be initialized properly."
    )
def link_model(
    model: "_SavedModel",
    target_path: str,
    aliases: Optional[Union[str, List[str]]] = None,
) -> None:
    """Link the given model to a portfolio.

    A portfolio is a promoted collection which contains (in this case) model artifacts.
    Linking to a portfolio allows for useful model-centric workflows in the UI.

    Args:
        model: `_SavedModel` - an instance of _SavedModel, most likely from the output
            of `log_model` or `use_model`.
        target_path: `str` - the target portfolio. The following forms are valid for the
            string: {portfolio}, {project/portfolio},{entity}/{project}/{portfolio}.
        aliases: `str, List[str]` - optional alias(es) that will only be applied on this
            linked model inside the portfolio. The alias "latest" will always be applied
            to the latest version of a model.

    Returns:
        None

    Raises:
        ValueError: if the model does not belong to an artifact that can be linked.

    Example:
        ```python
        sm = use_model("my-simple-model:latest", unsafe=True)
        link_model(sm, "my-portfolio")
        ```
    """
    aliases = wandb.util._resolve_aliases(aliases)
    run = wandb_setup.singleton().most_recent_active_run

    if not run:
        # No active run: only a model that came from a logged artifact can be
        # linked, and the link is made directly on that artifact.
        if model._artifact_source is None:
            raise ValueError(
                "Linking requires that the given _SavedModel belongs to a logged artifact."
            )
        model._artifact_source.artifact.link(target_path, aliases)
        return

    source = model._artifact_source
    target = model._artifact_target
    if source:
        # _artifact_source, if it exists, points to a Public Artifact. Its
        # existence means that _SavedModel was deserialized from a logged
        # artifact, most likely from `use_model`.
        artifact = source.artifact
    elif target and target.artifact._final:
        # If the _SavedModel has been added to a Local Artifact (most likely
        # through `.add(WBValue)`), then _artifact_target points to that Local
        # Artifact.
        artifact = target.artifact
    else:
        raise ValueError(
            "Linking requires that the given _SavedModel belongs to an artifact"
        )
    run.link_artifact(artifact, target_path, aliases)
|