|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from collections.abc import Mapping |
|
from typing import Any, Union |
|
|
|
import torch |
|
from torch import Tensor |
|
|
|
|
|
def _convert_fp_tensor(tensor: Tensor, dst_type: Union[str, torch.dtype]) -> Tensor: |
|
return tensor.to(dst_type) if torch.is_floating_point(tensor) else tensor |
|
|
|
|
|
|
|
class _DtypeContextManager: |
|
"""A context manager to change the default tensor type when tensors get created. |
|
|
|
See: :func:`torch.set_default_dtype` |
|
|
|
""" |
|
|
|
def __init__(self, dtype: torch.dtype) -> None: |
|
self._previous_dtype: torch.dtype = torch.get_default_dtype() |
|
self._new_dtype = dtype |
|
|
|
def __enter__(self) -> None: |
|
torch.set_default_dtype(self._new_dtype) |
|
|
|
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: |
|
torch.set_default_dtype(self._previous_dtype) |
|
|
|
|
|
class _ClassReplacementContextManager: |
|
"""A context manager to monkeypatch classes.""" |
|
|
|
def __init__(self, mapping: Mapping[str, type]) -> None: |
|
self._mapping = mapping |
|
self._originals = {} |
|
self._modules = {} |
|
for class_string in mapping: |
|
module_name, class_name = class_string.rsplit(".", 1) |
|
module = __import__(module_name, fromlist=[class_name]) |
|
self._modules[class_string] = module |
|
self._originals[class_string] = getattr(module, class_name) |
|
|
|
def __enter__(self) -> None: |
|
for class_string, replacement in self._mapping.items(): |
|
_, class_name = class_string.rsplit(".", 1) |
|
setattr(self._modules[class_string], class_name, replacement) |
|
|
|
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: |
|
for class_string, replacement in self._mapping.items(): |
|
_, class_name = class_string.rsplit(".", 1) |
|
setattr(self._modules[class_string], class_name, self._originals[class_string]) |
|
|