File size: 1,791 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
import os
from typing import Union
def _maybestr2bool(value: Union[bool, str], error: str) -> bool:
if isinstance(value, bool):
return value
elif isinstance(value, str):
if value.lower() in ("0", "false"):
return False
elif value.lower() in ("1", "true"):
return True
else:
raise ValueError(error)
else:
raise ValueError(error)
class _JaxtypingConfig:
def __init__(self):
self.update("jaxtyping_disable", os.environ.get("JAXTYPING_DISABLE", "0"))
self.update(
"jaxtyping_remove_typechecker_stack",
os.environ.get("JAXTYPING_REMOVE_TYPECHECKER_STACK", "0"),
)
def update(self, item: str, value):
if item.lower() == "jaxtyping_disable":
msg = (
"Unrecognised value for `JAXTYPING_DISABLE`. Valid values are "
"`JAXTYPING_DISABLE=0` (the default) or `JAXTYPING_DISABLE=1` (to "
"disable runtime type checking)."
)
self.jaxtyping_disable = _maybestr2bool(value, msg)
elif item.lower() == "jaxtyping_remove_typechecker_stack":
msg = (
"Unrecognised value for `JAXTYPING_REMOVE_TYPECHECKER_STACK`. Valid "
"values are `JAXTYPING_REMOVE_TYPECHECKER_STACK=0` (the default) or "
"`JAXTYPING_REMOVE_TYPECHECKER_STACK=1` (to remove the stack frames "
"from the typechecker in `jaxtyped(typechecker=...)`, when it raises a "
"runtime type-checking error)."
)
self.jaxtyping_remove_typechecker_stack = _maybestr2bool(value, msg)
else:
raise ValueError(f"Unrecognised config value {item}")
config = _JaxtypingConfig()
|