File size: 1,331 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
import os
from functools import wraps
from typing import Any, Callable, Dict, TypeVar, cast
FuncT = TypeVar("FuncT", bound=Callable[..., Any])
requirement_env_var_mapping: Dict[str, str] = {
"report-editing:v0": "WANDB_REQUIRE_REPORT_EDITING_V0"
}
def requires(requirement: str) -> FuncT: # type: ignore
"""Decorate functions to gate features with wandb.require."""
env_var = requirement_env_var_mapping[requirement]
def deco(func: FuncT) -> FuncT:
@wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
if not os.getenv(env_var):
raise Exception(
f"You need to enable this feature with `wandb.require({requirement!r})`"
)
return func(*args, **kwargs)
return cast(FuncT, wrapper)
return cast(FuncT, deco)
class RequiresMixin:
requirement = ""
def __init__(self) -> None:
self._check_if_requirements_met()
def __post_init__(self) -> None:
self._check_if_requirements_met()
def _check_if_requirements_met(self) -> None:
env_var = requirement_env_var_mapping[self.requirement]
if not os.getenv(env_var):
raise Exception(
f'You must explicitly enable this feature with `wandb.require("{self.requirement})"'
)
|