File size: 1,331 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import os
from functools import wraps
from typing import Any, Callable, Dict, TypeVar, cast

# Generic callable type used so decorators preserve the wrapped signature.
FuncT = TypeVar("FuncT", bound=Callable[..., Any])

# Feature requirement name (as shown in the ``wandb.require(...)`` error
# messages below) -> environment variable that must hold a non-empty value
# for the gated feature to be usable.
requirement_env_var_mapping: Dict[str, str] = {
    "report-editing:v0": "WANDB_REQUIRE_REPORT_EDITING_V0",
}


def requires(requirement: str) -> Callable[[FuncT], FuncT]:
    """Decorate functions to gate features with wandb.require.

    The returned decorator wraps ``func`` so that every call first checks the
    environment variable mapped to ``requirement`` in
    ``requirement_env_var_mapping``. If that variable is unset or empty, the
    call raises instead of running ``func``.

    Args:
        requirement: Feature name, as it would be passed to ``wandb.require``.

    Returns:
        A decorator that preserves the decorated function's type.

    Raises:
        KeyError: When ``requires`` is called with a ``requirement`` that has
            no mapped environment variable.
        Exception: When the decorated function is called while the feature is
            not enabled.
    """
    # Resolved eagerly so an unknown requirement fails when the decorator is
    # applied rather than on the first call of the decorated function.
    env_var = requirement_env_var_mapping[requirement]

    def deco(func: FuncT) -> FuncT:
        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # Checked on every call (not once at decoration time), so setting
            # the environment variable later in the process takes effect.
            if not os.getenv(env_var):
                raise Exception(
                    f"You need to enable this feature with `wandb.require({requirement!r})`"
                )
            return func(*args, **kwargs)

        return cast(FuncT, wrapper)

    return deco


class RequiresMixin:
    """Mixin that refuses to construct an instance unless its feature is enabled.

    Subclasses set ``requirement`` to a key of ``requirement_env_var_mapping``.
    The check runs from both ``__init__`` and ``__post_init__``. The latter is
    the hook a ``@dataclass``-generated ``__init__`` calls, so the check
    presumably also targets dataclass subclasses, whose generated ``__init__``
    replaces this one.
    """

    # Feature name; subclasses must override this with a key of
    # requirement_env_var_mapping (the empty default is not a valid key).
    requirement: str = ""

    def __init__(self) -> None:
        self._check_if_requirements_met()

    def __post_init__(self) -> None:
        self._check_if_requirements_met()

    def _check_if_requirements_met(self) -> None:
        """Raise unless the environment variable for ``self.requirement`` is non-empty.

        Raises:
            KeyError: If ``self.requirement`` is not a key of
                ``requirement_env_var_mapping`` (including the default ``""``).
            Exception: If the feature has not been enabled.
        """
        env_var = requirement_env_var_mapping[self.requirement]
        if not os.getenv(env_var):
            raise Exception(
                f'You must explicitly enable this feature with `wandb.require("{self.requirement}")`'
            )