from typing import Any, Callable, List, Optional, Union

import torch


class Bucket:
    """
    Helper class to simplify the handling of buckets, which unify the underlying storage of multiple tensors
    """

    def __init__(self, size: int, dtype: torch.dtype, device: torch.device) -> None:
        self._params: List[torch.Tensor] = []
        self._param_ids: List[int] = []
        self._fill = 0

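        # Flat tensor providing the unified storage; tensors checked into the bucket become views into it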
        self.buffer: torch.Tensor = torch.zeros(size, dtype=dtype, device=device)

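    # `keep_param_alignment` is not used here; the subclasses use it to re-attach their tensor views after a move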
    def to(
        self,
        device: Optional[Union[int, torch.device]],
        dtype: Optional[torch.dtype] = None,
        non_blocking: bool = False,
        keep_param_alignment: bool = True,
    ) -> "Bucket":
        """
        Move the underlying buffer
        """
        assert self.buffer is not None, "Cannot move a collapsed bucket, please rebuild it"
        self.buffer = self.buffer.to(device, dtype, non_blocking)
        return self


class ParamBucket(Bucket):
    """
    Helper class to simplify the handling of parameter buckets
    """

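    # Illustrative usage (a sketch): pack a list of same-dtype, same-device parameters, assumed
    # here to be named `params`, into one shared buffer so that they alias contiguous storage:
    #
    #     bucket = ParamBucket(size=sum(p.numel() for p in params), dtype=params[0].dtype, device=params[0].device)
    #     for p in params:
    #         bucket.add_param(p)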
    def __init__(self, size: int, dtype: torch.dtype, device: torch.device) -> None:
        super().__init__(size, dtype, device)

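    # Moving the bucket also re-points every registered param at the relocated buffer,
    # unless `keep_param_alignment=False` is requested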
    def to(
        self,
        device: Optional[Union[int, torch.device]],
        dtype: Optional[torch.dtype] = None,
        non_blocking: bool = False,
        keep_param_alignment: bool = True,
    ) -> "ParamBucket":
        """
        Move the underlying buffer
        """
        super().to(device, dtype, non_blocking)

        if keep_param_alignment:
            self._reattach_params()
        return self

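    # Checking a param in replaces its storage with a view into the bucket, so the original
    # per-tensor storage can be released by the allocator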
    @torch.no_grad()
    def add_param(self, param: torch.Tensor) -> None:
        """
        Add a new parameter to the bucket. Param.data becomes a view into this bucket's buffer
        """
        assert id(param) not in self._param_ids, "The same param cannot be checked in twice"

        self._add_param_as_view(param)
        self._params.append(param)
        self._param_ids.append(id(param))

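    # Copy the param's current value into the next free slice of the buffer (unless
    # `keep_existing_value=False`), then turn param.data into a view of that slice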
    @torch.no_grad()
    def _add_param_as_view(self, param: torch.Tensor, keep_existing_value: bool = True) -> None:
        assert self.buffer is not None
        assert (
            param.dtype == self.buffer.dtype
        ), f"Different types for the bucket and the param, cannot proceed: {param.dtype} - {self.buffer.dtype}"
        assert (
            param.device == self.buffer.device
        ), f"Different devices for the bucket and the param, cannot proceed: {param.device} - {self.buffer.device}"

        fill_next = self._fill + param.numel()
        assert fill_next <= self.buffer.numel()

        if keep_existing_value:
            self.buffer[self._fill : fill_next].copy_(param.data.flatten())
        param.data = self.buffer[self._fill : fill_next].view_as(param.data)
        self._fill = fill_next

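    # Called after .to(): the moved (and possibly recast) buffer already holds the values, so the
    # params are re-attached as views without copying (keep_existing_value=False)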
    @torch.no_grad()
    def _reattach_params(self) -> None:
        """
        Given the parameters which have been registered previously, rebuild the whole bucket
        """
        assert len(self._params) > 0

        self._fill = 0
        for p in self._params:
            if p.dtype != self.buffer.dtype:
                p.data = p.data.to(self.buffer.dtype)
            self._add_param_as_view(p, keep_existing_value=False)


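# Gradient bucket: on top of the shared flat buffer it keeps reduction bookkeeping, i.e. how many
# grads have been checked in, the destination (presumably the rank the reduced grads are sent to),
# a `sent` flag and an optional callback slot. The buffer can be collapsed to free memory and
# rebuilt later.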
class GradBucket(Bucket):
    """
    Helper class to simplify the handling of gradient buckets
    """

    def __init__(self, size: int, dtype: torch.dtype, device: torch.device, destination: int) -> None:
        super().__init__(size, dtype, device)

        self._max_size = size
        self._is_collapsed = False

        self.params_checked_in = 0
        self.destination = destination
        self.sent = True
        self.callback: Optional[Callable[[Any], None]] = None

    def reset_checked_in(self) -> None:
        """Reset the counter of the parameter grads which have been checked in"""
        self.params_checked_in = 0
        self.sent = False

    @property
    def all_checked_in(self) -> bool:
        """Have all the expected gradient check-ins happened?"""
        return len(self._params) == self.params_checked_in

    def can_add_grad_view(self, param: torch.Tensor) -> bool:
        """Is there enough room in the bucket to add this parameter gradient, and has this param not already been checked in?"""
        return self._fill + param.numel() < self._max_size and id(param) not in self._param_ids

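    # A collapsed bucket is transparently rebuilt before being moved; the grad views are
    # re-attached afterwards unless `keep_param_alignment=False`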
    def to(
        self,
        device: Optional[Union[int, torch.device]],
        dtype: Optional[torch.dtype] = None,
        non_blocking: bool = False,
        keep_param_alignment: bool = True,
    ) -> "GradBucket":
        """
        Move the underlying buffer
        """
        if self._is_collapsed:
            self.rebuild()

        super().to(device, dtype, non_blocking)

        if keep_param_alignment:
            self._reattach_grads()
        return self

    def zero(self) -> None:
        """
        Set all the grads to zero
        """
        self.buffer.fill_(0.0)

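    # If the param has no .grad yet, a zero gradient is materialized first so that it can be
    # aliased into the bucket buffer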
    @torch.no_grad()
    def add_grad(self, param: torch.Tensor) -> None:
        """
        Add a new parameter gradient to the bucket. Param.grad becomes a view into this bucket's buffer
        """
        assert id(param) not in self._param_ids, "The same gradient cannot be checked in twice"

        if param.grad is None:
            param.grad = torch.zeros_like(param)

        self._add_grad_as_view(param)
        self._params.append(param)
        self._param_ids.append(id(param))

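    # Free the grad memory: drop every param.grad and swap the flat buffer for an empty tensor.
    # rebuild() restores the buffer and the grad views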
    @torch.no_grad()
    def collapse(self) -> None:
        """
        Release the buffer from memory. The bucket will need to be rebuilt before use
        """
        if not self._is_collapsed:
            for p in self._params:
                assert p.grad is not None
                p.grad.detach_()
                p.grad = None

            self.buffer = torch.zeros(0, dtype=self.buffer.dtype, device=self.buffer.device)
            self._fill = 0
            self.params_checked_in = 0
            self._is_collapsed = True

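    # No-op if the bucket is not collapsed; otherwise reallocate a full-size buffer and check the
    # registered grads back in as views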
    @torch.no_grad()
    def rebuild(self) -> None:
        """
        Given the parameter gradients which have been registered previously, rebuild the whole bucket
        """
        assert len(self._params) > 0

        if self._is_collapsed:
            self.buffer = torch.zeros(self._max_size, dtype=self._params[0].dtype, device=self._params[0].device)

            for p in self._params:
                self._add_grad_as_view(p)

            self._is_collapsed = False

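    # resize_() truncates the flat buffer to the filled portion and clone() compacts it into fresh
    # storage; the grad views are then re-created against the shrunk buffer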
    @torch.no_grad()
    def shrink(self) -> None:
        """
        Shrink the buffer to the size of the parameter gradients currently checked in, release the extra memory
        """
        assert self.buffer.numel() > 0, "Cannot shrink a collapsed bucket, please rebuild"

        self.buffer = self.buffer.resize_(self._fill).clone()
        self._fill = 0
        for p in self._params:
            self._add_grad_as_view(p)

        self._max_size = self._fill

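    # Called after .to(): the buffer already holds the gradient values, so the views are rebuilt
    # without copying (keep_existing_value=False)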
    @torch.no_grad()
    def _reattach_grads(self) -> None:
        """
        Given the parameter gradients which have been registered previously, rebuild the whole bucket
        """
        assert len(self._params) > 0

        self._fill = 0
        for p in self._params:
            self._add_grad_as_view(p, keep_existing_value=False)

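    # Point param.grad at the next free slice of the buffer. With keep_existing_value=True the
    # current gradient is first copied in; with False the slice content is taken as-is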
    @torch.no_grad()
    def _add_grad_as_view(self, param: torch.Tensor, keep_existing_value: bool = True) -> None:
        assert self.buffer.numel() > 0, "Cannot add a gradient to a collapsed bucket, please rebuild"
        assert param.dtype == self.buffer.dtype
        assert param.device == self.buffer.device

        fill_next = self._fill + param.numel()
        assert fill_next <= self.buffer.numel()

        if param.grad is not None:
            if keep_existing_value:
                self.buffer[self._fill : fill_next].copy_(param.grad.data.flatten())
            param.grad.data = self.buffer[self._fill : fill_next].view_as(param.data)
        else:
            param.grad = self.buffer[self._fill : fill_next].view_as(param.data)
        self._fill = fill_next
|