|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Utilities for Argument Parsing within Lightning Components.""" |
|
|
|
import inspect |
|
import os |
|
from argparse import Namespace |
|
from ast import literal_eval |
|
from contextlib import suppress |
|
from functools import wraps |
|
from typing import Any, Callable, TypeVar, cast |
|
|
|
# Type variable for any callable. `_defaults_from_env_vars` is annotated `(fn: _T) -> _T`, so static type checkers
# see the decorated function as having the same type as the function that was passed in.
_T = TypeVar("_T", bound=Callable[..., Any])
|
|
|
|
|
def _parse_env_variables(cls: type, template: str = "PL_%(cls_name)s_%(cls_argument)s") -> Namespace: |
|
"""Parse environment arguments if they are defined. |
|
|
|
Examples: |
|
|
|
>>> from pytorch_lightning import Trainer |
|
>>> _parse_env_variables(Trainer) |
|
Namespace() |
|
>>> import os |
|
>>> os.environ["PL_TRAINER_DEVICES"] = '42' |
|
>>> os.environ["PL_TRAINER_BLABLABLA"] = '1.23' |
|
>>> _parse_env_variables(Trainer) |
|
Namespace(devices=42) |
|
>>> del os.environ["PL_TRAINER_DEVICES"] |
|
|
|
""" |
|
env_args = {} |
|
for arg_name in inspect.signature(cls).parameters: |
|
env = template % {"cls_name": cls.__name__.upper(), "cls_argument": arg_name.upper()} |
|
val = os.environ.get(env) |
|
if not (val is None or val == ""): |
|
|
|
with suppress(Exception): |
|
|
|
val = literal_eval(val) |
|
env_args[arg_name] = val |
|
return Namespace(**env_args) |
|
|
|
|
|
def _defaults_from_env_vars(fn: _T) -> _T:
    """Decorate an ``__init__``-like method so that environment variables supply argument defaults.

    The wrapped method is always invoked with keyword arguments only. For each parameter, the value it receives is
    resolved from lowest to highest priority:

    1. The method's own parameter default (nothing is passed for that name).
    2. The value parsed by ``_parse_env_variables`` for the instance's class.
    3. The argument passed explicitly by the caller, positionally or by keyword.

    Positional arguments are mapped to parameter names using ``inspect.signature(type(self))``.

    Raises:
        TypeError: If more positional arguments are given than the class signature has parameters, or if an argument
            is given both positionally and by keyword. This mirrors what Python raises for an undecorated call,
            instead of silently dropping or overwriting values.
    """

    @wraps(fn)
    def insert_env_defaults(self: Any, *args: Any, **kwargs: Any) -> Any:
        cls = self.__class__
        if args:
            # Move positional arguments into ``kwargs`` so the env-var merge below can treat every argument by name.
            # NOTE(review): names come from the *instance's* class signature, which equals ``fn``'s signature only
            # when ``cls.__init__`` is ``fn`` itself; subclasses with a different ``__init__`` may map differently.
            cls_arg_names = list(inspect.signature(cls).parameters)
            # ``zip`` would silently discard surplus positionals; fail loudly like a normal call would.
            if len(args) > len(cls_arg_names):
                raise TypeError(
                    f"{cls.__name__}() got {len(args)} positional arguments but its signature only has"
                    f" {len(cls_arg_names)} parameters"
                )
            for name, value in zip(cls_arg_names, args):
                if name in kwargs:
                    raise TypeError(f"{cls.__name__}() got multiple values for argument {name!r}")
                kwargs[name] = value

        env_variables = vars(_parse_env_variables(cls))
        # Explicitly passed arguments override the environment-derived defaults.
        kwargs = {**env_variables, **kwargs}

        # All positional arguments have already been moved into ``kwargs``.
        return fn(self, **kwargs)

    return cast(_T, insert_env_defaults)
|
|