# jamtur01's picture
# Upload folder using huggingface_hub
# 9c6594c verified
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for Argument Parsing within Lightning Components."""
import inspect
import os
from argparse import Namespace
from ast import literal_eval
from contextlib import suppress
from functools import wraps
from typing import Any, Callable, TypeVar, cast
# Type variable bound to callables; lets `_defaults_from_env_vars` return the same type it receives.
_T = TypeVar("_T", bound=Callable[..., Any])
def _parse_env_variables(cls: type, template: str = "PL_%(cls_name)s_%(cls_argument)s") -> Namespace:
    """Parse environment arguments if they are defined.

    For every parameter in ``inspect.signature(cls)``, the environment variable named by ``template``
    (filled with the upper-cased class name and upper-cased parameter name) is looked up. Non-empty values
    are converted to native Python types with :func:`ast.literal_eval` when possible (e.g. ``'42'`` -> ``42``,
    ``'True'`` -> ``True``); values that are not valid Python literals are kept as the raw string.

    Args:
        cls: The class whose signature parameters are looked up in the environment.
        template: ``%``-style format string with ``cls_name`` and ``cls_argument`` keys, used to build the
            environment variable name for each parameter.

    Returns:
        A :class:`~argparse.Namespace` containing only the parameters whose environment variable is set
        to a non-empty value.

    Examples:
        >>> from pytorch_lightning import Trainer
        >>> _parse_env_variables(Trainer)
        Namespace()
        >>> import os
        >>> os.environ["PL_TRAINER_DEVICES"] = '42'
        >>> os.environ["PL_TRAINER_BLABLABLA"] = '1.23'
        >>> _parse_env_variables(Trainer)
        Namespace(devices=42)
        >>> del os.environ["PL_TRAINER_DEVICES"]
        >>> del os.environ["PL_TRAINER_BLABLABLA"]

    """
    cls_name = cls.__name__.upper()
    env_args = {}
    for arg_name in inspect.signature(cls).parameters:
        env = template % {"cls_name": cls_name, "cls_argument": arg_name.upper()}
        val = os.environ.get(env)
        # unset and empty variables are both treated as "not provided"
        if not val:
            continue
        # These are exactly the exceptions `literal_eval` documents for malformed input. Anything that is
        # not a valid Python literal is deliberately kept as the raw string; other errors propagate.
        with suppress(ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            val = literal_eval(val)
        env_args[arg_name] = val
    return Namespace(**env_args)
def _defaults_from_env_vars(fn: _T) -> _T:
    """Wrap an ``__init__``-like method so environment variables supply default argument values.

    Positional arguments are first converted to keyword arguments, using the parameter names from
    ``inspect.signature`` of the instance's class. Values parsed from the environment by
    ``_parse_env_variables`` are then merged underneath the caller's arguments, so anything passed
    explicitly always takes precedence over the environment.

    Args:
        fn: The method to wrap. It is always invoked with keyword arguments only.

    Returns:
        The wrapped method, carrying ``fn``'s metadata via :func:`functools.wraps`.
    """

    @wraps(fn)
    def insert_env_defaults(self: Any, *args: Any, **kwargs: Any) -> Any:
        # use the runtime class, so a subclass contributes its own name and signature
        cls = self.__class__
        if args:
            # Map positional args onto the class signature's parameter names, in order.
            # NOTE(review): `zip` silently drops positional args beyond the number of parameters, and a
            # positional value overrides a keyword argument of the same name instead of raising.
            cls_arg_names = inspect.signature(cls).parameters
            kwargs.update(zip(cls_arg_names, args))
        env_variables = vars(_parse_env_variables(cls))
        # explicitly passed arguments override values coming from the environment
        kwargs = {**env_variables, **kwargs}
        # all positional args were already moved to kwargs
        return fn(self, **kwargs)

    return cast(_T, insert_env_defaults)