File size: 3,104 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 |
"""Feature Flags Module.

This module implements a feature flag system for the wandb library that lets
scripts require specific features and warns the user when a requested feature
is redundant (deprecated) or unsupported. Any `@version` or `:qualifier`
suffix on a feature name is ignored when resolving it.

Example:
    import wandb
    wandb.require("legacy-service")
"""
from __future__ import annotations
import os
from typing import Iterable
import wandb
from wandb.env import _REQUIRE_LEGACY_SERVICE
from wandb.errors import UnsupportedError
from wandb.sdk import wandb_run
class _Requires:
"""Internal feature class."""
_features: tuple[str, ...]
def __init__(self, features: str | Iterable[str]) -> None:
self._features = (
tuple([features]) if isinstance(features, str) else tuple(features)
)
def require_require(self) -> None:
pass
def _require_service(self) -> None:
wandb.teardown = wandb._teardown # type: ignore
wandb.attach = wandb._attach # type: ignore
wandb_run.Run.detach = wandb_run.Run._detach # type: ignore
def require_service(self) -> None:
self._require_service()
def require_core(self) -> None:
wandb.termwarn(
"`wandb.require('core')` is redundant as it is now the default behavior."
)
def require_legacy_service(self) -> None:
os.environ[_REQUIRE_LEGACY_SERVICE] = "true"
def apply(self) -> None:
"""Call require_* method for supported features."""
last_message: str = ""
for feature_item in self._features:
full_feature = feature_item.split("@", 2)[0]
feature = full_feature.split(":", 2)[0]
func_str = "require_{}".format(feature.replace("-", "_"))
func = getattr(self, func_str, None)
if not func:
last_message = f"require() unsupported requirement: {feature}"
wandb.termwarn(last_message)
continue
func()
if last_message:
wandb.termwarn("Supported requirements are: `legacy-service`, `service`.")
raise UnsupportedError(last_message)
def require(
    requirement: str | Iterable[str] | None = None,
    experiment: str | Iterable[str] | None = None,
) -> None:
    """Declare which wandb features this script depends on.

    Call this before any other `wandb` function, ideally immediately after
    `import wandb`. If both arguments are falsy (None, "" or empty), it does
    nothing.

    Args:
        requirement: A single feature name, or an iterable of feature names.
        experiment: Alias for `requirement`. It is used only when
            `requirement` is falsy.

    Raises:
        wandb.errors.UnsupportedError: If a feature name is unknown.
    """
    features = requirement if requirement else experiment
    if features:
        _Requires(features=features).apply()
def _import_module_hook() -> None:
    """Set up anything needed when wandb is imported.

    The intent is to mirror `require` calls made by a parent process. At
    present it unconditionally requires the ``service`` feature.
    """
    # TODO: cache which pids this has already run for, or use real import hooks.
    # TODO: generalize beyond this single hard-coded requirement.
    require(requirement="service")
|