"""W&B Integration for Metaflow. This integration lets users apply decorators to Metaflow flows and steps to automatically log parameters and artifacts to W&B by type dispatch. - Decorating a step will enable or disable logging for certain types within that step - Decorating the flow is equivalent to decorating all steps with a default - Decorating a step after decorating the flow will overwrite the flow decoration Examples can be found at wandb/wandb/functional_tests/metaflow """ import inspect import pickle from functools import wraps from pathlib import Path from typing import Union import wandb from wandb.sdk.lib import telemetry as wb_telemetry try: from metaflow import current except ImportError as e: raise Exception( "Error: `metaflow` not installed >> This integration requires metaflow! To fix, please `pip install -Uqq metaflow`" ) from e try: from plum import dispatch except ImportError as e: raise Exception( "Error: `plum-dispatch` not installed >> " "This integration requires plum-dispatch! To fix, please `pip install -Uqq plum-dispatch`" ) from e try: import pandas as pd @dispatch def _wandb_use( name: str, data: pd.DataFrame, datasets=False, run=None, testing=False, *args, **kwargs, ): # type: ignore if testing: return "datasets" if datasets else None if datasets: run.use_artifact(f"{name}:latest") wandb.termlog(f"Using artifact: {name} ({type(data)})") @dispatch def wandb_track( name: str, data: pd.DataFrame, datasets=False, run=None, testing=False, *args, **kwargs, ): if testing: return "pd.DataFrame" if datasets else None if datasets: artifact = wandb.Artifact(name, type="dataset") with artifact.new_file(f"{name}.parquet", "wb") as f: data.to_parquet(f, engine="pyarrow") run.log_artifact(artifact) wandb.termlog(f"Logging artifact: {name} ({type(data)})") except ImportError: wandb.termwarn( "`pandas` not installed >> @wandb_log(datasets=True) may not auto log your dataset!" ) try: import torch import torch.nn as nn @dispatch def _wandb_use( name: str, data: nn.Module, models=False, run=None, testing=False, *args, **kwargs, ): # type: ignore if testing: return "models" if models else None if models: run.use_artifact(f"{name}:latest") wandb.termlog(f"Using artifact: {name} ({type(data)})") @dispatch def wandb_track( name: str, data: nn.Module, models=False, run=None, testing=False, *args, **kwargs, ): if testing: return "nn.Module" if models else None if models: artifact = wandb.Artifact(name, type="model") with artifact.new_file(f"{name}.pkl", "wb") as f: torch.save(data, f) run.log_artifact(artifact) wandb.termlog(f"Logging artifact: {name} ({type(data)})") except ImportError: wandb.termwarn( "`pytorch` not installed >> @wandb_log(models=True) may not auto log your model!" 
try:
    from sklearn.base import BaseEstimator

    @dispatch
    def _wandb_use(
        name: str,
        data: BaseEstimator,
        models=False,
        run=None,
        testing=False,
        *args,
        **kwargs,
    ):  # type: ignore
        if testing:
            return "models" if models else None

        if models:
            run.use_artifact(f"{name}:latest")
            wandb.termlog(f"Using artifact: {name} ({type(data)})")

    @dispatch
    def wandb_track(
        name: str,
        data: BaseEstimator,
        models=False,
        run=None,
        testing=False,
        *args,
        **kwargs,
    ):
        if testing:
            return "BaseEstimator" if models else None

        if models:
            artifact = wandb.Artifact(name, type="model")
            with artifact.new_file(f"{name}.pkl", "wb") as f:
                pickle.dump(data, f)
            run.log_artifact(artifact)
            wandb.termlog(f"Logging artifact: {name} ({type(data)})")

except ImportError:
    wandb.termwarn(
        "`sklearn` not installed >> @wandb_log(models=True) may not auto log your model!"
    )


class ArtifactProxy:
    """Record attribute reads (inputs) and writes (outputs) on a flow object."""

    def __init__(self, flow):
        # do this to avoid recursion problem with __setattr__
        self.__dict__.update(
            {
                "flow": flow,
                "inputs": {},
                "outputs": {},
                "base": set(dir(flow)),
                "params": {p: getattr(flow, p) for p in current.parameter_names},
            }
        )

    def __setattr__(self, key, val):
        self.outputs[key] = val
        return setattr(self.flow, key, val)

    def __getattr__(self, key):
        if key not in self.base and key not in self.outputs:
            self.inputs[key] = getattr(self.flow, key)
        return getattr(self.flow, key)


@dispatch
def wandb_track(
    name: str,
    data: Union[dict, list, set, str, int, float, bool],
    run=None,
    testing=False,
    *args,
    **kwargs,
):  # type: ignore
    if testing:
        return "scalar"

    run.log({name: data})


@dispatch
def wandb_track(
    name: str, data: Path, datasets=False, run=None, testing=False, *args, **kwargs
):
    if testing:
        return "Path" if datasets else None

    if datasets:
        artifact = wandb.Artifact(name, type="dataset")
        if data.is_dir():
            artifact.add_dir(data)
        elif data.is_file():
            artifact.add_file(data)
        run.log_artifact(artifact)
        wandb.termlog(f"Logging artifact: {name} ({type(data)})")


# this is the base case
@dispatch
def wandb_track(
    name: str, data, others=False, run=None, testing=False, *args, **kwargs
):
    if testing:
        return "generic" if others else None

    if others:
        artifact = wandb.Artifact(name, type="other")
        with artifact.new_file(f"{name}.pkl", "wb") as f:
            pickle.dump(data, f)
        run.log_artifact(artifact)
        wandb.termlog(f"Logging artifact: {name} ({type(data)})")


@dispatch
def wandb_use(name: str, data, *args, **kwargs):
    try:
        return _wandb_use(name, data, *args, **kwargs)
    except wandb.CommError:
        wandb.termwarn(
            f"This artifact ({name}, {type(data)}) does not exist in the wandb datastore! "
            "If you created an instance inline (e.g. sklearn.ensemble.RandomForestClassifier), "
            "then you can safely ignore this. Otherwise you may want to check your internet connection!"
        )
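# Example (illustrative sketch): ArtifactProxy splits attribute reads from
# writes so the decorator below can call `wandb_use` on a step's inputs and
# `wandb_track` on its outputs. Here `flow` stands in for a running step's
# `self`; `train` and the attribute names are hypothetical.
#
#     proxy = ArtifactProxy(flow)
#     model = train(proxy.raw_df)  # read  -> recorded in proxy.inputs
#     proxy.model = model          # write -> recorded in proxy.outputs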
@dispatch
def wandb_use(
    name: str, data: Union[dict, list, set, str, int, float, bool], *args, **kwargs
):  # type: ignore
    pass  # do nothing for these types


@dispatch
def _wandb_use(
    name: str, data: Path, datasets=False, run=None, testing=False, *args, **kwargs
):  # type: ignore
    if testing:
        return "datasets" if datasets else None

    if datasets:
        run.use_artifact(f"{name}:latest")
        wandb.termlog(f"Using artifact: {name} ({type(data)})")


@dispatch
def _wandb_use(
    name: str, data, others=False, run=None, testing=False, *args, **kwargs
):  # type: ignore
    if testing:
        return "others" if others else None

    if others:
        run.use_artifact(f"{name}:latest")
        wandb.termlog(f"Using artifact: {name} ({type(data)})")


def coalesce(*args):
    """Return the first non-None argument, or None if all are None."""
    return next((a for a in args if a is not None), None)


def wandb_log(
    func=None,
    # /,  # py38 only
    datasets=False,
    models=False,
    others=False,
    settings=None,
):
    """Automatically log parameters and artifacts to W&B by type dispatch.

    This decorator can be applied to a flow, step, or both.

    - Decorating a step will enable or disable logging for certain types within that step
    - Decorating the flow is equivalent to decorating all steps with a default
    - Decorating a step after decorating the flow will overwrite the flow decoration

    Args:
        func: (`Callable`). The method or class being decorated (if decorating
            a step or flow, respectively).
        datasets: (`bool`). If `True`, log datasets. Datasets can be a
            `pd.DataFrame` or `pathlib.Path`. The default value is `False`, so
            datasets are not logged.
        models: (`bool`). If `True`, log models. Models can be a `nn.Module`
            or `sklearn.base.BaseEstimator`. The default value is `False`, so
            models are not logged.
        others: (`bool`). If `True`, log anything pickle-able. The default
            value is `False`, so other objects are not logged.
        settings: (`wandb.sdk.wandb_settings.Settings`). Custom settings
            passed to `wandb.init`. The default value is `None`, which is the
            same as passing `wandb.Settings()`. If `settings.run_group` is
            `None`, it will be set to `{flow_name}/{run_id}`. If
            `settings.run_job_type` is `None`, it will be set to the step name.
    """
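    # Usage (illustrative sketch; flow and step names are hypothetical):
    #
    #     class ExampleFlow(FlowSpec):
    #         @wandb_log(models=True)  # step decoration overrides the flow default
    #         @step
    #         def train(self):
    #             ...
    #
    # or, applied to the whole flow as a default for every step:
    #
    #     @wandb_log(datasets=True)
    #     class ExampleFlow(FlowSpec):
    #         ...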
    @wraps(func)
    def decorator(func):
        # If you decorate a class, apply the decoration to all methods in that class
        if inspect.isclass(func):
            cls = func
            for attr in cls.__dict__:
                if callable(getattr(cls, attr)):
                    # check the method itself for the marker, not the attribute name string
                    if not hasattr(getattr(cls, attr), "_base_func"):
                        setattr(cls, attr, decorator(getattr(cls, attr)))
            return cls

        # prefer the earliest decoration (i.e. method decoration overrides class decoration)
        if hasattr(func, "_base_func"):
            return func

        @wraps(func)
        def wrapper(self, *args, settings=settings, **kwargs):
            if not isinstance(settings, wandb.sdk.wandb_settings.Settings):
                settings = wandb.Settings()

            settings.update_from_dict(
                {
                    "run_group": coalesce(
                        settings.run_group, f"{current.flow_name}/{current.run_id}"
                    ),
                    "run_job_type": coalesce(settings.run_job_type, current.step_name),
                }
            )

            with wandb.init(settings=settings) as run:
                with wb_telemetry.context(run=run) as tel:
                    tel.feature.metaflow = True
                proxy = ArtifactProxy(self)
                run.config.update(proxy.params)
                func(proxy, *args, **kwargs)

                for name, data in proxy.inputs.items():
                    wandb_use(
                        name,
                        data,
                        datasets=datasets,
                        models=models,
                        others=others,
                        run=run,
                    )

                for name, data in proxy.outputs.items():
                    wandb_track(
                        name,
                        data,
                        datasets=datasets,
                        models=models,
                        others=others,
                        run=run,
                    )

        wrapper._base_func = func

        # Add for testing visibility
        wrapper._kwargs = {
            "datasets": datasets,
            "models": models,
            "others": others,
            "settings": settings,
        }
        return wrapper

    if func is None:
        return decorator
    else:
        return decorator(func)
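# Example (illustrative sketch): passing custom settings through to wandb.init.
# The project name is hypothetical; per the wrapper above, run_group falls back
# to "{flow_name}/{run_id}" and run_job_type to the step name when unset.
#
#     @wandb_log(
#         datasets=True,
#         models=True,
#         settings=wandb.Settings(project="example-project"),
#     )
#     class ExampleFlow(FlowSpec):
#         ...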