|
|
|
|
|
|
|
|
|
|
|
import torch |
|
from typing import Callable |
|
from torch import Tensor |
|
from packaging import version as pkg_version |
|
|
|
|
|
class OnDevice(object):
    """
    Create modules/tensors w. specific devices and dtypes. Examples:

    Create MyModule which consists of many different sub-modules and parameters. In this case we can create
    MyModule as a collection of 'meta' tensors by passing `device='meta'` or we can create the module _directly_
    on a CUDA device by passing `device=f'cuda:{local_rank}'` (where `local_rank` is the local GPU id).

        with OnDevice(dtype=torch.float16, device='meta'):
            model = MyModel()

        with OnDevice(dtype=torch.float16, device=f'cuda:{local_rank}'):
            model = MyModel()

    While the context is active, ``torch.empty``, ``torch.zeros``, ``torch.ones``, ``torch.full`` and
    ``torch.Tensor.__new__`` are replaced process-wide by wrappers that place tensors on ``device`` (the
    factory wrappers only when the caller did not pass a ``device`` of their own) and cast floating-point
    results to ``dtype``. On exit, whatever was installed when the context was entered is put back, so
    contexts may be nested.

    Args:
        dtype: dtype that floating-point tensors created inside the context are cast to.
        device: device used for tensors created inside the context. Defaults to ``"meta"``.
        enabled: when false, entering and exiting the context patches nothing.

    Raises:
        NotImplementedError: if ``device == "meta"`` and the installed torch is older than 1.10.
    """

    # Unpatched factory functions, captured once when the class body executes. The wrappers always
    # delegate to these, never to whatever is currently bound on the ``torch`` module.
    _orig_torch_empty = torch.empty
    _orig_torch_zeros = torch.zeros
    _orig_torch_ones = torch.ones
    _orig_torch_full = torch.full

    def __init__(self, dtype, device="meta", enabled=True):
        self.dtype = dtype
        self.enabled = enabled
        self.device = device
        # One entry per active ``__enter__`` on this instance: the callables to restore on ``__exit__``.
        self._saved_states = []

        if device == "meta":
            if pkg_version.parse('1.10') > pkg_version.parse(torch.__version__):
                raise NotImplementedError("Meta tensor support is not available, please upgrade to torch 1.10+")

    def fp_tensor_constructor(self, fn: Callable, target_fp_dtype: torch.dtype) -> Callable:
        """Wrap the tensor factory ``fn`` so it defaults to ``self.device`` and casts floats.

        Args:
            fn: factory to wrap; it is called with the caller's arguments plus a ``device`` keyword.
            target_fp_dtype: dtype that floating-point results are converted to.

        Returns:
            A callable with the same call convention as ``fn``.
        """

        def wrapped_fn(*args, **kwargs) -> Tensor:
            # Only fill in the device when the caller did not choose one (missing or explicitly None).
            if kwargs.get("device", None) is None:
                kwargs['device'] = self.device
            tensor: Tensor = fn(*args, **kwargs)
            # Integer/bool tensors keep their dtype; only floating-point results are cast.
            if tensor.is_floating_point():
                tensor = tensor.to(target_fp_dtype)
            return tensor

        return wrapped_fn

    def get_new_tensor_fn_for_dtype(self, dtype: torch.dtype) -> Callable:
        """Build the replacement for ``torch.Tensor.__new__`` used while the context is active.

        Args:
            dtype: dtype that floating-point results are converted to.

        Returns:
            A function ``new_tensor(cls, *args)`` that allocates via ``new_empty(*args)`` on
            ``self.device``; ``cls`` is accepted but not used.
        """

        def new_tensor(cls, *args) -> Tensor:
            # Go through the unpatched ``torch.empty`` so this does not re-enter the factory wrappers.
            tensor = OnDevice._orig_torch_empty(0, device=self.device).new_empty(*args)
            if tensor.is_floating_point():
                tensor = tensor.to(dtype)
            return tensor

        return new_tensor

    def __enter__(self):
        """Install the patched constructors and return this context object."""
        if not self.enabled:
            return self

        # Remember what is live right now (possibly an enclosing context's wrappers) so that
        # ``__exit__`` unwinds to exactly this state rather than to the global originals.
        self._saved_states.append({
            "tensor_new": torch.Tensor.__new__,
            "empty": torch.empty,
            "zeros": torch.zeros,
            "ones": torch.ones,
            "full": torch.full,
        })

        torch.Tensor.__new__ = self.get_new_tensor_fn_for_dtype(self.dtype)
        torch.empty = self.fp_tensor_constructor(self._orig_torch_empty, self.dtype)
        torch.zeros = self.fp_tensor_constructor(self._orig_torch_zeros, self.dtype)
        torch.ones = self.fp_tensor_constructor(self._orig_torch_ones, self.dtype)
        torch.full = self.fp_tensor_constructor(self._orig_torch_full, self.dtype)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Restore the constructors that were active when the context was entered.

        Returns ``None``, so an exception raised inside the ``with`` body propagates.
        """
        if not self.enabled or not self._saved_states:
            return

        saved = self._saved_states.pop()
        torch.Tensor.__new__ = saved["tensor_new"]
        torch.empty = saved["empty"]
        torch.zeros = saved["zeros"]
        torch.ones = saved["ones"]
        torch.full = saved["full"]
|
|