import os
from typing import Union


def _maybestr2bool(value: Union[bool, str], error: str) -> bool:
    """Coerce a boolean-like configuration value to a real ``bool``.

    Args:
        value: Either a ``bool`` (returned unchanged) or a string. The strings
            ``"0"``/``"false"`` map to ``False`` and ``"1"``/``"true"`` map to
            ``True``; matching is case-insensitive and ignores surrounding
            whitespace.
        error: Message used for the ``ValueError`` raised on bad input.

    Returns:
        The boolean that ``value`` represents.

    Raises:
        ValueError: If ``value`` is neither a ``bool`` nor one of the
            recognised strings.
    """
    if isinstance(value, bool):
        return value
    if not isinstance(value, str):
        raise ValueError(error)

    # Values frequently arrive from environment variables, where a stray
    # trailing space or newline is easy to introduce when the variable is set.
    # Normalise once so such input is not rejected.
    normalised = value.strip().lower()
    if normalised in ("0", "false"):
        return False
    if normalised in ("1", "true"):
        return True
    raise ValueError(error)


class _JaxtypingConfig:
    """Holder for the runtime configuration flags.

    Attributes set by ``update``:
        jaxtyping_disable: ``True`` to disable runtime type checking.
        jaxtyping_remove_typechecker_stack: ``True`` to remove the stack frames
            from the typechecker in ``jaxtyped(typechecker=...)`` when it
            raises a runtime type-checking error.
    """

    def __init__(self):
        """Initialise each flag from its environment variable (default "0")."""
        self.update("jaxtyping_disable", os.environ.get("JAXTYPING_DISABLE", "0"))
        self.update(
            "jaxtyping_remove_typechecker_stack",
            os.environ.get("JAXTYPING_REMOVE_TYPECHECKER_STACK", "0"),
        )

    def update(self, item: str, value: Union[bool, str]) -> None:
        """Set the configuration flag named ``item`` to ``value``.

        Args:
            item: Flag name, matched case-insensitively. Either
                ``"jaxtyping_disable"`` or
                ``"jaxtyping_remove_typechecker_stack"``.
            value: A ``bool``, or one of the strings accepted by
                ``_maybestr2bool``.

        Raises:
            ValueError: If ``item`` is not a recognised flag name, or if
                ``value`` cannot be interpreted as a boolean.
        """
        key = item.lower()
        if key == "jaxtyping_disable":
            msg = (
                "Unrecognised value for `JAXTYPING_DISABLE`. Valid values are "
                "`JAXTYPING_DISABLE=0` (the default) or `JAXTYPING_DISABLE=1` (to "
                "disable runtime type checking)."
            )
            self.jaxtyping_disable = _maybestr2bool(value, msg)
        elif key == "jaxtyping_remove_typechecker_stack":
            msg = (
                "Unrecognised value for `JAXTYPING_REMOVE_TYPECHECKER_STACK`. Valid "
                "values are `JAXTYPING_REMOVE_TYPECHECKER_STACK=0` (the default) or "
                "`JAXTYPING_REMOVE_TYPECHECKER_STACK=1` (to remove the stack frames "
                "from the typechecker in `jaxtyped(typechecker=...)`, when it raises a "
                "runtime type-checking error)."
            )
            self.jaxtyping_remove_typechecker_stack = _maybestr2bool(value, msg)
        else:
            # The original spelling of `item` is echoed back, not the lowered key.
            raise ValueError(f"Unrecognised config value {item}")


# Module-level singleton, populated from the environment at import time.
config = _JaxtypingConfig()