# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Callable, List, Optional, Union

import torch


class Bucket:
    """
    Helper class to simplify the handling of buckets, which unify the underlying storage of multiple tensors
    """

    def __init__(self, size: int, dtype: torch.dtype, device: torch.device) -> None:
        self._params: List[torch.Tensor] = []
        self._param_ids: List[int] = []
        self._fill = 0

        # The actual flat tensor
        self.buffer: torch.Tensor = torch.zeros(size, dtype=dtype, device=device)

    def to(  # type: ignore
        self,
        device: Optional[Union[int, torch.device]],
        dtype: Optional[torch.dtype] = None,
        non_blocking: bool = False,
        keep_param_alignment: bool = True,
    ) -> "Bucket":
        """
        Move the underlying buffer
        """
        assert self.buffer is not None, "Cannot move a collapsed bucket, please rebuild it"
        self.buffer = self.buffer.to(device, dtype, non_blocking)
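

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the upstream file): moving/casting a
# bucket's flat buffer with `Bucket.to`. The size and the fp16 target dtype
# below are arbitrary example values; subclasses re-attach their views after
# such a move (see `keep_param_alignment`).
# ---------------------------------------------------------------------------
def _demo_bucket_to() -> None:
    bucket = Bucket(size=8, dtype=torch.float32, device=torch.device("cpu"))
    bucket.to(device=torch.device("cpu"), dtype=torch.float16)
    assert bucket.buffer.dtype == torch.float16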


class ParamBucket(Bucket):
    """
    Helper class to simplify the handling of parameter buckets
    """

    def __init__(self, size: int, dtype: torch.dtype, device: torch.device) -> None:
        super().__init__(size, dtype, device)

    def to(  # type: ignore
        self,
        device: Optional[Union[int, torch.device]],
        dtype: Optional[torch.dtype] = None,
        non_blocking: bool = False,
        keep_param_alignment: bool = True,
    ) -> "ParamBucket":
        """
        Move the underlying buffer
        """
        super().to(device, dtype, non_blocking)

        if keep_param_alignment:
            self._reattach_params()

    @torch.no_grad()
    def add_param(self, param: torch.Tensor) -> None:
        """
        Add a new parameter to the bucket. Param.data becomes a view of this bucket's buffer
        """
        assert id(param) not in self._param_ids, "The same param cannot be checked in twice"

        self._add_param_as_view(param)
        self._params.append(param)
        self._param_ids.append(id(param))

    @torch.no_grad()
    def _add_param_as_view(self, param: torch.Tensor, keep_existing_value: bool = True) -> None:
        assert self.buffer is not None
        assert (
            param.dtype == self.buffer.dtype
        ), f"Different types for the bucket and the param, cannot proceed: {param.dtype} - {self.buffer.dtype}"
        assert (
            param.device == self.buffer.device
        ), f"Different devices for the bucket and the param, cannot proceed: {param.device} - {self.buffer.device}"

        fill_next = self._fill + param.numel()
        assert fill_next <= self.buffer.numel()

        # Copy the current param value
        if keep_existing_value:
            self.buffer[self._fill : fill_next].copy_(param.data.flatten())
        param.data = self.buffer[self._fill : fill_next].view_as(param.data)
        self._fill = fill_next

    @torch.no_grad()
    def _reattach_params(self) -> None:
        """
        Given the parameters which have been registered previously, rebuild the whole bucket
        """
        assert len(self._params) > 0

        self._fill = 0
        for p in self._params:
            if p.dtype != self.buffer.dtype:
                p.data = p.data.to(self.buffer.dtype)
            self._add_param_as_view(p, keep_existing_value=False)
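

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the upstream file): how a ParamBucket
# unifies the storage of several parameters in one flat buffer. The two
# tensors and the bucket size below are arbitrary example values.
# ---------------------------------------------------------------------------
def _demo_param_bucket() -> None:
    params = [torch.rand(4, requires_grad=True), torch.rand(6, requires_grad=True)]
    bucket = ParamBucket(size=16, dtype=params[0].dtype, device=params[0].device)

    for p in params:
        bucket.add_param(p)

    # Each param.data is now a view into the single flat buffer
    assert params[0].data.data_ptr() == bucket.buffer.data_ptr()

    # Writing through the buffer is therefore visible from the parameters
    bucket.buffer.fill_(1.0)
    assert params[0].data.sum().item() == 4.0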


class GradBucket(Bucket):
    """
    Helper class to simplify the handling of gradient buckets
    """

    def __init__(self, size: int, dtype: torch.dtype, device: torch.device, destination: int) -> None:
        super().__init__(size, dtype, device)

        self._max_size = size
        self._is_collapsed = False

        self.params_checked_in = 0
        self.destination = destination
        self.sent = True
        self.callback: Optional[Callable[[Any], None]] = None

    def reset_checked_in(self) -> None:
        """Reset the counter of the parameter grads which have been checked in"""
        self.params_checked_in = 0
        self.sent = False

    @property
    def all_checked_in(self) -> bool:
        """Have all the expected gradient check-ins happened?"""
        return len(self._params) == self.params_checked_in

    def can_add_grad_view(self, param: torch.Tensor) -> bool:
        """Is there enough room in the bucket to add this parameter gradient, and is this param not already checked in?"""
        return self._fill + param.numel() < self._max_size and id(param) not in self._param_ids

    def to(  # type: ignore
        self,
        device: Optional[Union[int, torch.device]],
        dtype: Optional[torch.dtype] = None,
        non_blocking: bool = False,
        keep_param_alignment: bool = True,
    ) -> "GradBucket":
        """
        Move the underlying buffer
        """
        if self._is_collapsed:
            self.rebuild()

        super().to(device, dtype, non_blocking)

        if keep_param_alignment:
            self._reattach_grads()

    def zero(self) -> None:
        """
        Set all the grads to zero
        """
        self.buffer.fill_(0.0)

    @torch.no_grad()
    def add_grad(self, param: torch.Tensor) -> None:
        """
        Add a new parameter gradient to the bucket. Param.grad becomes a view of this bucket's buffer
        """
        assert id(param) not in self._param_ids, "The same gradient cannot be checked in twice"

        if param.grad is None:
            param.grad = torch.zeros_like(param)

        self._add_grad_as_view(param)
        self._params.append(param)
        self._param_ids.append(id(param))

    @torch.no_grad()
    def collapse(self) -> None:
        """
        Release the buffer from memory. The bucket will need to be rebuilt before use
        """
        if not self._is_collapsed:
            for p in self._params:
                assert p.grad is not None
                p.grad.detach_()
                p.grad = None

            self.buffer = torch.zeros(0, dtype=self.buffer.dtype, device=self.buffer.device)
            self._fill = 0
            self.params_checked_in = 0
            self._is_collapsed = True

    @torch.no_grad()
    def rebuild(self) -> None:
        """
        Given the parameter gradients which have been registered previously, rebuild the whole bucket
        """
        assert len(self._params) > 0

        if self._is_collapsed:
            self.buffer = torch.zeros(self._max_size, dtype=self._params[0].dtype, device=self._params[0].device)

            for p in self._params:
                self._add_grad_as_view(p)

            self._is_collapsed = False

    @torch.no_grad()
    def shrink(self) -> None:
        """
        Shrink the buffer to the size of the parameter gradients currently checked in, release the extra memory
        """
        assert self.buffer.numel() > 0, "Cannot shrink a collapsed bucket, please rebuild"

        self.buffer = self.buffer.resize_(self._fill).clone()
        self._fill = 0
        for p in self._params:
            self._add_grad_as_view(p)

        self._max_size = self._fill

    @torch.no_grad()
    def _reattach_grads(self) -> None:
        """
        Given the parameter gradients which have been registered previously, rebuild the whole bucket
        """
        assert len(self._params) > 0

        self._fill = 0
        for p in self._params:
            self._add_grad_as_view(p, keep_existing_value=False)

    @torch.no_grad()
    def _add_grad_as_view(self, param: torch.Tensor, keep_existing_value: bool = True) -> None:
        assert self.buffer.numel() > 0, "Cannot add a gradient to a collapsed bucket, please rebuild"
        assert param.dtype == self.buffer.dtype
        assert param.device == self.buffer.device

        fill_next = self._fill + param.numel()
        assert fill_next <= self.buffer.numel()

        # Copy the current grad value, if any
        if param.grad is not None:
            # keep param.grad in place
            if keep_existing_value:
                self.buffer[self._fill : fill_next].copy_(param.grad.data.flatten())
            param.grad.data = self.buffer[self._fill : fill_next].view_as(param.data)
        else:
            param.grad = self.buffer[self._fill : fill_next].view_as(param.data)
        self._fill = fill_next
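

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the upstream file): checking gradients into
# a GradBucket so that they share one contiguous buffer. The tensors, the
# bucket size, and `destination=0` (the target rank for a later reduce) are
# arbitrary example values; everything runs on a single process here.
# ---------------------------------------------------------------------------
def _demo_grad_bucket() -> None:
    params = [torch.rand(4, requires_grad=True), torch.rand(6, requires_grad=True)]
    loss = sum((p * p).sum() for p in params)
    loss.backward()  # populate p.grad before checking the grads in

    # `+ 1` because can_add_grad_view uses a strict `<` bound on the bucket size
    total = sum(p.numel() for p in params)
    bucket = GradBucket(size=total + 1, dtype=params[0].dtype, device=params[0].device, destination=0)

    for p in params:
        if bucket.can_add_grad_view(p):
            bucket.add_grad(p)

    # The existing gradient values were copied in, and p.grad is now a view of the buffer
    assert torch.equal(bucket.buffer[:4], params[0].grad)

    # Zeroing the buffer zeroes every checked-in gradient at once
    bucket.zero()
    assert all(p.grad is not None and float(p.grad.abs().sum()) == 0.0 for p in params)


if __name__ == "__main__":
    _demo_bucket_to()
    _demo_param_bucket()
    _demo_grad_bucket()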