File size: 6,167 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 |
def wandb_log(  # noqa: C901
    func=None,
    # /,  # py38 only
    log_component_file=True,
):
    """Wrap a standard python function and log to W&B.

    Usable both bare (``@wandb_log``) and with arguments
    (``@wandb_log(log_component_file=False)``).

    The returned wrapper exposes the wrapped function's signature extended
    with an extra ``mlpipeline_ui_metadata_path`` parameter annotated as
    ``OutputPath()``. On every call the wrapper:

    1. opens a W&B run via ``wandb.init`` with ``job_type`` set to the
       function name;
    2. writes a JSON document embedding an iframe of the run to
       ``mlpipeline_ui_metadata_path``;
    3. optionally logs the generated component file as an artifact;
    4. records annotated inputs (scalars go to ``run.config``, path-like
       inputs are registered with ``run.use_artifact``);
    5. calls the wrapped function;
    6. logs the return value and any annotated output paths.

    Args:
        func: The function to wrap, or ``None`` when the decorator is applied
            with arguments.
        log_component_file: When true, write ``<func name>.yml`` with
            ``func_to_component_file`` and log it as an artifact of type
            ``kubeflow_component_file``.

    Returns:
        The wrapped function when ``func`` is given, otherwise a decorator
        that accepts the function to wrap.
    """
    # NOTE(review): imports and helpers are kept function-local, presumably so
    # that the whole decorator stays self-contained when kfp serialises the
    # function source into a component -- confirm before hoisting anything.
    import json
    import os
    from functools import wraps
    from inspect import Parameter, signature

    from kfp import components
    from kfp.components import (
        InputArtifact,
        InputBinaryFile,
        InputPath,
        InputTextFile,
        OutputArtifact,
        OutputBinaryFile,
        OutputPath,
        OutputTextFile,
    )

    import wandb
    from wandb.sdk.lib import telemetry as wb_telemetry

    output_types = (OutputArtifact, OutputBinaryFile, OutputPath, OutputTextFile)
    input_types = (InputArtifact, InputBinaryFile, InputPath, InputTextFile)

    def isinstance_namedtuple(x):
        """Return True if ``x`` looks like an instance of a namedtuple class."""
        t = type(x)
        b = t.__bases__
        # A namedtuple class derives directly (and only) from ``tuple``.
        if len(b) != 1 or b[0] is not tuple:
            return False
        f = getattr(t, "_fields", None)
        if not isinstance(f, tuple):
            return False
        return all(isinstance(n, str) for n in f)

    def get_iframe_html(run):
        """Return the iframe markup that embeds the given run's page."""
        return f'<iframe src="{run.url}?kfp=true" style="border:none;width:100%;height:100%;min-width:900px;min-height:600px;"></iframe>'

    def get_link_back_to_kubeflow():
        """Return the URL of the current pipeline run in the Kubeflow UI.

        If ``WANDB_KUBEFLOW_URL`` is unset, ``os.getenv`` yields ``None`` and
        the returned link starts with the literal text ``None``.
        """
        wandb_kubeflow_url = os.getenv("WANDB_KUBEFLOW_URL")
        # NOTE(review): the doubled braces are intentional source text;
        # presumably the workflow engine substitutes the placeholder in the
        # serialised source before this code runs -- confirm before touching
        # the brace escaping.
        return f"{wandb_kubeflow_url}/#/runs/details/{{workflow.uid}}"

    def log_input_scalar(name, data, run=None):
        """Store an input value in the run config under ``name``."""
        run.config[name] = data
        wandb.termlog(f"Setting config: {name} to {data}")

    def log_input_artifact(name, data, artifact_type, run=None):
        """Register the file at path ``data`` as an artifact used by the run."""
        artifact = wandb.Artifact(name, type=artifact_type)
        artifact.add_file(data)
        run.use_artifact(artifact)
        wandb.termlog(f"Using artifact: {name}")

    def log_output_scalar(name, data, func_name, run=None):
        """Log a return value; namedtuples are logged field by field.

        Namedtuple fields are logged under ``<func_name>.<field>``; any other
        value is logged under ``name``.
        """
        if isinstance_namedtuple(data):
            for k, v in zip(data._fields, data):
                run.log({f"{func_name}.{k}": v})
        else:
            run.log({name: data})

    def log_output_artifact(name, data, artifact_type, run=None):
        """Log the file at path ``data`` as an output artifact of the run."""
        artifact = wandb.Artifact(name, type=artifact_type)
        artifact.add_file(data)
        run.log_artifact(artifact)
        wandb.termlog(f"Logging artifact: {name}")

    def _log_component_file(func, run=None):
        """Write ``<func name>.yml`` for ``func`` and log it as an artifact."""
        name = func.__name__
        output_component_file = f"{name}.yml"
        components._python_op.func_to_component_file(func, output_component_file)
        artifact = wandb.Artifact(name, type="kubeflow_component_file")
        artifact.add_file(output_component_file)
        run.log_artifact(artifact)
        wandb.termlog(f"Logging component file: {output_component_file}")

    def _extend_signature(func):
        """Return ``(signature, annotations)`` of ``func`` plus the UI metadata path.

        Add `mlpipeline_ui_metadata_path` to signature to show W&B run in
        "ML Visualizations tab". The new parameter has no default, so it is
        placed after the parameters without defaults and before those with
        defaults.
        """
        sig = signature(func)
        no_default = []
        has_default = []
        for param in sig.parameters.values():
            if param.default is param.empty:
                no_default.append(param)
            else:
                has_default.append(param)

        new_params = (
            *no_default,
            Parameter(
                "mlpipeline_ui_metadata_path",
                annotation=OutputPath(),
                kind=Parameter.POSITIONAL_OR_KEYWORD,
            ),
            *has_default,
        )
        new_sig = sig.replace(parameters=new_params)
        new_anns = {param.name: param.annotation for param in new_params}
        if "return" in func.__annotations__:
            new_anns["return"] = func.__annotations__["return"]
        return new_sig, new_anns

    def decorator(func):
        """Build the logging wrapper for ``func``."""
        # Computed here rather than in the enclosing scope: when the decorator
        # is applied with arguments the outer ``func`` is None and cannot be
        # inspected.
        new_sig, new_anns = _extend_signature(func)

        # Classify annotated names once, at decoration time.
        input_scalars = {}
        input_artifacts = {}
        output_scalars = {}
        output_artifacts = {}

        for name, ann in func.__annotations__.items():
            if name == "return":
                output_scalars[name] = ann
            elif isinstance(ann, output_types):
                output_artifacts[name] = ann
            elif isinstance(ann, input_types):
                input_artifacts[name] = ann
            else:
                input_scalars[name] = ann

        @wraps(func)
        def wrapper(*args, **kwargs):
            bound = new_sig.bind(*args, **kwargs)
            bound.apply_defaults()

            # The metadata path belongs to the wrapper only; remove it so it
            # is not forwarded to the wrapped function.
            mlpipeline_ui_metadata_path = bound.arguments.pop(
                "mlpipeline_ui_metadata_path"
            )

            # NOTE(review): ``group`` is a literal template placeholder,
            # presumably resolved by the workflow engine -- confirm.
            with wandb.init(
                job_type=func.__name__,
                group="{{workflow.annotations.pipelines.kubeflow.org/run_name}}",
            ) as run:
                # Link back to the kfp UI
                kubeflow_url = get_link_back_to_kubeflow()
                run.notes = kubeflow_url
                run.config["LINK_TO_KUBEFLOW_RUN"] = kubeflow_url

                iframe_html = get_iframe_html(run)
                metadata = {
                    "outputs": [
                        {
                            "type": "markdown",
                            "storage": "inline",
                            "source": iframe_html,
                        }
                    ]
                }

                with open(mlpipeline_ui_metadata_path, "w") as metadata_file:
                    json.dump(metadata, metadata_file)

                if log_component_file:
                    _log_component_file(func, run=run)

                # Read values from the bound arguments so positional and
                # defaulted arguments are logged too, not only keywords.
                for name in input_scalars:
                    log_input_scalar(name, bound.arguments[name], run)

                for name, ann in input_artifacts.items():
                    log_input_artifact(name, bound.arguments[name], ann.type, run)

                with wb_telemetry.context(run=run) as tel:
                    tel.feature.kfp_wandb_log = True

                result = func(*bound.args, **bound.kwargs)

                for name in output_scalars:
                    log_output_scalar(name, result, func.__name__, run)

                for name, ann in output_artifacts.items():
                    log_output_artifact(name, bound.arguments[name], ann.type, run)

            return result

        wrapper.__signature__ = new_sig
        wrapper.__annotations__ = new_anns
        return wrapper

    if func is None:
        return decorator
    return decorator(func)
|