# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Callable, List, Optional, Union

import torch


class Bucket:
"""
Helper class to simplify the handling of buckets, which unify the underlying storage of multiple tensors
"""
def __init__(self, size: int, dtype: torch.dtype, device: torch.device) -> None:
self._params: List[torch.Tensor] = []
self._param_ids: List[int] = []
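        # Offset of the next free element in the flat buffer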
self._fill = 0
# The actual flat tensor
self.buffer: torch.Tensor = torch.zeros(size, dtype=dtype, device=device)
def to( # type: ignore
self,
device: Optional[Union[int, torch.device]],
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
keep_param_alignment: bool = True,
) -> "ParamBucket":
"""
Move the underlying buffer
"""
        assert self.buffer is not None, "Cannot move a collapsed bucket, please rebuild it"
        self.buffer = self.buffer.to(device, dtype, non_blocking)
        return self


class ParamBucket(Bucket):
"""
Helper class to simplify the handling of parameter buckets
"""
def __init__(self, size: int, dtype: torch.dtype, device: torch.device) -> None:
super().__init__(size, dtype, device)
def to( # type: ignore
self,
device: Optional[Union[int, torch.device]],
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
keep_param_alignment: bool = True,
) -> "ParamBucket":
"""
Move the underlying buffer
"""
super().to(device, dtype, non_blocking)
        if keep_param_alignment:
            self._reattach_params()
        return self
@torch.no_grad()
def add_param(self, param: torch.Tensor) -> None:
"""
Add a new parameter gradient to the bucket. Param.grad becomes a view of this bucket buffer
"""
assert id(param) not in self._param_ids, "The same param cannot be checked in twice"
self._add_param_as_view(param)
self._params.append(param)
self._param_ids.append(id(param))
@torch.no_grad()
def _add_param_as_view(self, param: torch.Tensor, keep_existing_value: bool = True) -> None:
assert self.buffer is not None
assert (
param.dtype == self.buffer.dtype
), f"Different types for the bucket and the param, cannot proceed: {param.dtype} - {self.buffer.dtype}"
assert (
param.device == self.buffer.device
), f"Different devices for the bucket and the param, cannot proceed: {param.device} - {self.buffer.device}"
fill_next = self._fill + param.numel()
assert fill_next <= self.buffer.numel()
# Copy the current param value
if keep_existing_value:
self.buffer[self._fill : fill_next].copy_(param.data.flatten())
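        # Re-point the param's storage at the matching slice of the bucket, so param and buffer alias the same memory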
param.data = self.buffer[self._fill : fill_next].view_as(param.data)
self._fill = fill_next
@torch.no_grad()
def _reattach_params(self) -> None:
"""
Given the parameters which have been registered previously, rebuild the whole bucket
"""
assert len(self._params) > 0
self._fill = 0
for p in self._params:
if p.dtype != self.buffer.dtype:
p.data = p.data.to(self.buffer.dtype)
self._add_param_as_view(p, keep_existing_value=False)
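

# Illustrative sketch (not part of the original API): flattening a small module's
# parameters into one contiguous buffer with ParamBucket. The layer, size and dtype
# below are assumptions chosen for demonstration only.
def _example_param_bucket() -> ParamBucket:
    layer = torch.nn.Linear(4, 2)
    numel = sum(p.numel() for p in layer.parameters())
    bucket = ParamBucket(size=numel, dtype=torch.float32, device=torch.device("cpu"))
    for p in layer.parameters():
        bucket.add_param(p)  # p.data is now a view into bucket.buffer
    return bucket
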
class GradBucket(Bucket):
"""
Helper class to simplify the handling of gradient buckets
"""
def __init__(self, size: int, dtype: torch.dtype, device: torch.device, destination: int) -> None:
super().__init__(size, dtype, device)
self._max_size = size
self._is_collapsed = False
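        # Book-keeping for reducing this bucket's gradients: how many grads have been checked in,
        # the rank the bucket is destined for, whether it has already been sent, and an optional
        # callback to fire (interpretation inferred from how these fields are used)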
self.params_checked_in = 0
self.destination = destination
self.sent = True
self.callback: Optional[Callable[[Any], None]] = None
def reset_checked_in(self) -> None:
"""Reset the counter of the parameter grads which have been checked in"""
self.params_checked_in = 0
self.sent = False
@property
def all_checked_in(self) -> bool:
"""Have all the expected gradient check-in happened ?"""
return len(self._params) == self.params_checked_in
def can_add_grad_view(self, param: torch.Tensor) -> bool:
"""Is there enough room in the bucket to add this parameter gradient, and is this param not already checked in ?"""
return self._fill + param.numel() < self._max_size and id(param) not in self._param_ids
def to( # type: ignore
self,
device: Optional[Union[int, torch.device]],
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
keep_param_alignment: bool = True,
) -> "GradBucket":
"""
Move the underlying buffer
"""
if self._is_collapsed:
self.rebuild()
super().to(device, dtype, non_blocking)
        if keep_param_alignment:
            self._reattach_grads()
        return self
def zero(self) -> None:
"""
Set all the grads to zero
"""
self.buffer.fill_(0.0)
@torch.no_grad()
def add_grad(self, param: torch.Tensor) -> None:
"""
Add a new parameter gradient to the bucket. Param.grad becomes a view of this bucket buffer
"""
assert id(param) not in self._param_ids, "The same gradients cannot be checked in twice"
if param.grad is None:
param.grad = torch.zeros_like(param)
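        # The gradient becomes a view of the shared flat buffer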
self._add_grad_as_view(param)
self._params.append(param)
self._param_ids.append(id(param))
@torch.no_grad()
def collapse(self) -> None:
"""
Release the buffer from memory. The bucket will need to be rebuilt before use
"""
if not self._is_collapsed:
for p in self._params:
assert p.grad is not None
p.grad.detach_()
p.grad = None
self.buffer = torch.zeros(0, dtype=self.buffer.dtype, device=self.buffer.device)
self._fill = 0
self.params_checked_in = 0
self._is_collapsed = True
@torch.no_grad()
def rebuild(self) -> None:
"""
Given the parameter gradients which have been registered previously, rebuild the whole bucket
"""
assert len(self._params) > 0
if self._is_collapsed:
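            # Reallocate the flat buffer at its original maximum size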
self.buffer = torch.zeros(self._max_size, dtype=self._params[0].dtype, device=self._params[0].device)
for p in self._params:
self._add_grad_as_view(p)
self._is_collapsed = False
@torch.no_grad()
def shrink(self) -> None:
"""
Shrink the buffer to the size of the parameter gradients currently checked in, release the extra memory
"""
assert self.buffer.numel() > 0, "Cannot shrink a collapsed bucket, please rebuild"
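        # resize_ truncates the flat tensor in place; clone() then reallocates exactly self._fill elements, freeing the spare capacity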
self.buffer = self.buffer.resize_(self._fill).clone()
self._fill = 0
for p in self._params:
self._add_grad_as_view(p)
self._max_size = self._fill
@torch.no_grad()
def _reattach_grads(self) -> None:
"""
Given the parameters gradients which have been registered previously, rebuild the whole bucket
"""
assert len(self._params) > 0
self._fill = 0
for p in self._params:
self._add_grad_as_view(p, keep_existing_value=False)
@torch.no_grad()
def _add_grad_as_view(self, param: torch.Tensor, keep_existing_value: bool = True) -> None:
assert self.buffer.numel() > 0, "Cannot add a gradient to a collapsed bucket, please rebuild"
assert param.dtype == self.buffer.dtype
assert param.device == self.buffer.device
fill_next = self._fill + param.numel()
assert fill_next <= self.buffer.numel()
# Copy the current grad value, if any
if param.grad is not None:
# keep param.grad in place
if keep_existing_value:
self.buffer[self._fill : fill_next].copy_(param.grad.data.flatten())
param.grad.data = self.buffer[self._fill : fill_next].view_as(param.data)
else:
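            # No existing gradient to preserve: hand the param a fresh view into the bucket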
param.grad = self.buffer[self._fill : fill_next].view_as(param.data)
self._fill = fill_next
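

# Illustrative sketch (not part of the original API): gathering a small module's
# gradients into a GradBucket. The layer, sizes and destination rank below are
# assumptions chosen for demonstration only.
if __name__ == "__main__":
    layer = torch.nn.Linear(4, 2)
    numel = sum(p.numel() for p in layer.parameters())
    grad_bucket = GradBucket(size=numel, dtype=torch.float32, device=torch.device("cpu"), destination=0)
    for p in layer.parameters():
        grad_bucket.add_grad(p)  # p.grad is now a view into grad_bucket.buffer
    grad_bucket.zero()      # zero every checked-in gradient at once
    grad_bucket.collapse()  # drop the gradient views and release the flat buffer
    grad_bucket.rebuild()   # reallocate the flat buffer before the bucket is reused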