# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Callable, List, Optional, Union

import torch


class Bucket:
    """
    Helper class to simplify the handling of buckets, which unify the underlying storage of multiple tensors
    """

    def __init__(self, size: int, dtype: torch.dtype, device: torch.device) -> None:
        self._params: List[torch.Tensor] = []
        self._param_ids: List[int] = []
        self._fill = 0

        # The actual flat tensor
        self.buffer: torch.Tensor = torch.zeros(size, dtype=dtype, device=device)

    def to(  # type: ignore
        self,
        device: Optional[Union[int, torch.device]],
        dtype: Optional[torch.dtype] = None,
        non_blocking: bool = False,
        keep_param_alignment: bool = True,
    ) -> "Bucket":
        """
        Move the underlying buffer
        """
        assert self.buffer is not None, "Cannot move a collapsed bucket, please rebuild it"

        self.buffer = self.buffer.to(device, dtype, non_blocking)
        return self


class ParamBucket(Bucket):
    """
    Helper class to simplify the handling of parameter buckets
    """

    def __init__(self, size: int, dtype: torch.dtype, device: torch.device) -> None:
        super().__init__(size, dtype, device)

    def to(  # type: ignore
        self,
        device: Optional[Union[int, torch.device]],
        dtype: Optional[torch.dtype] = None,
        non_blocking: bool = False,
        keep_param_alignment: bool = True,
    ) -> "ParamBucket":
        """
        Move the underlying buffer
        """
        super().to(device, dtype, non_blocking)

        if keep_param_alignment:
            self._reattach_params()
        return self

    @torch.no_grad()
    def add_param(self, param: torch.Tensor) -> None:
        """
        Add a new parameter to the bucket. param.data becomes a view into this bucket's buffer
        """
        assert id(param) not in self._param_ids, "The same param cannot be checked in twice"

        self._add_param_as_view(param)
        self._params.append(param)
        self._param_ids.append(id(param))

    @torch.no_grad()
    def _add_param_as_view(self, param: torch.Tensor, keep_existing_value: bool = True) -> None:
        assert self.buffer is not None
        assert (
            param.dtype == self.buffer.dtype
        ), f"Different types for the bucket and the param, cannot proceed: {param.dtype} - {self.buffer.dtype}"
        assert (
            param.device == self.buffer.device
        ), f"Different devices for the bucket and the param, cannot proceed: {param.device} - {self.buffer.device}"

        fill_next = self._fill + param.numel()
        assert fill_next <= self.buffer.numel()

        # Copy the current param value into the buffer, then turn the param into a view of it
        if keep_existing_value:
            self.buffer[self._fill : fill_next].copy_(param.data.flatten())
        param.data = self.buffer[self._fill : fill_next].view_as(param.data)
        self._fill = fill_next

    @torch.no_grad()
    def _reattach_params(self) -> None:
        """
        Given the parameters which have been registered previously, rebuild the whole bucket
        """
        assert len(self._params) > 0

        self._fill = 0
        for p in self._params:
            if p.dtype != self.buffer.dtype:
                p.data = p.data.to(self.buffer.dtype)
            self._add_param_as_view(p, keep_existing_value=False)


class GradBucket(Bucket):
    """
    Helper class to simplify the handling of gradient buckets
    """

    def __init__(self, size: int, dtype: torch.dtype, device: torch.device, destination: int) -> None:
        super().__init__(size, dtype, device)

        self._max_size = size
        self._is_collapsed = False

        self.params_checked_in = 0
        self.destination = destination
        self.sent = True
        self.callback: Optional[Callable[[Any], None]] = None

    def reset_checked_in(self) -> None:
        """Reset the counter of the parameter grads which have been checked in"""
        self.params_checked_in = 0
        self.sent = False

    @property
    def all_checked_in(self) -> bool:
"""Have all the expected gradient check-in happened ?""" return len(self._params) == self.params_checked_in def can_add_grad_view(self, param: torch.Tensor) -> bool: """Is there enough room in the bucket to add this parameter gradient, and is this param not already checked in ?""" return self._fill + param.numel() < self._max_size and id(param) not in self._param_ids def to( # type: ignore self, device: Optional[Union[int, torch.device]], dtype: Optional[torch.dtype] = None, non_blocking: bool = False, keep_param_alignment: bool = True, ) -> "GradBucket": """ Move the underlying buffer """ if self._is_collapsed: self.rebuild() super().to(device, dtype, non_blocking) if keep_param_alignment: self._reattach_grads() def zero(self) -> None: """ Set all the grads to zero """ self.buffer.fill_(0.0) @torch.no_grad() def add_grad(self, param: torch.Tensor) -> None: """ Add a new parameter gradient to the bucket. Param.grad becomes a view of this bucket buffer """ assert id(param) not in self._param_ids, "The same gradients cannot be checked in twice" if param.grad is None: param.grad = torch.zeros_like(param) self._add_grad_as_view(param) self._params.append(param) self._param_ids.append(id(param)) @torch.no_grad() def collapse(self) -> None: """ Release the buffer from memory. The bucket will need to be rebuilt before use """ if not self._is_collapsed: for p in self._params: assert p.grad is not None p.grad.detach_() p.grad = None self.buffer = torch.zeros(0, dtype=self.buffer.dtype, device=self.buffer.device) self._fill = 0 self.params_checked_in = 0 self._is_collapsed = True @torch.no_grad() def rebuild(self) -> None: """ Given the parameter gradients which have been registered previously, rebuild the whole bucket """ assert len(self._params) > 0 if self._is_collapsed: self.buffer = torch.zeros(self._max_size, dtype=self._params[0].dtype, device=self._params[0].device) for p in self._params: self._add_grad_as_view(p) self._is_collapsed = False @torch.no_grad() def shrink(self) -> None: """ Shrink the buffer to the size of the parameter gradients currently checked in, release the extra memory """ assert self.buffer.numel() > 0, "Cannot shrink a collapsed bucket, please rebuild" self.buffer = self.buffer.resize_(self._fill).clone() self._fill = 0 for p in self._params: self._add_grad_as_view(p) self._max_size = self._fill @torch.no_grad() def _reattach_grads(self) -> None: """ Given the parameters gradients which have been registered previously, rebuild the whole bucket """ assert len(self._params) > 0 self._fill = 0 for p in self._params: self._add_grad_as_view(p, keep_existing_value=False) @torch.no_grad() def _add_grad_as_view(self, param: torch.Tensor, keep_existing_value: bool = True) -> None: assert self.buffer.numel() > 0, "Cannot add a gradient to a collapsed bucket, please rebuild" assert param.dtype == self.buffer.dtype assert param.device == self.buffer.device fill_next = self._fill + param.numel() assert fill_next <= self.buffer.numel() # Copy the current grad value, if any if param.grad is not None: # keep param.grad in place if keep_existing_value: self.buffer[self._fill : fill_next].copy_(param.grad.data.flatten()) param.grad.data = self.buffer[self._fill : fill_next].view_as(param.data) else: param.grad = self.buffer[self._fill : fill_next].view_as(param.data) self._fill = fill_next