# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Manipulation of micro-batches."""
import typing
from typing import Callable, Iterable, Iterator, List, Tuple, Union, cast

import torch
from torch import Tensor
import torch.cuda.comm

__all__: List[str] = []


Tensors = Tuple[Tensor, ...]
TensorOrTensors = Union[Tensor, Tensors]
Function = Callable[[TensorOrTensors], TensorOrTensors]


class Batch:
    """An abstraction of an atomic tensor or a tuple of tensors.

    This eliminates the boilerplate needed to distinguish an atomic tensor
    from a tuple of tensors.
    ::

        x = generate_tensor_or_tensors()
        x = Batch(x, 0)

        # in-place update
        x[0] = F.apply(x[0])
        x[:] = F.apply(*x)

        # f receives the underlying value as its single argument:
        # a tensor if x is atomic, otherwise the tuple of tensors.
        # y is also a batch (with the same index as x).
        y = x.call(f)

    Args:
        value: a tensor, or a tuple of tensors, to wrap.
        index: an integer tag stored with the batch and exposed read-only
            through :attr:`index`.

    """

    def __init__(self, value: TensorOrTensors, index: int) -> None:
        self.value = value
        # "atomic" means the batch wraps one bare tensor rather than a tuple.
        self.atomic = torch.is_tensor(value)
        self.__index = index

    @property
    def index(self) -> int:
        """The index given at construction time."""
        return self.__index

    @property
    def tensor(self) -> Tensor:
        """Retrieves the underlying tensor.

        Raises:
            AttributeError: the batch is not atomic.

        """
        if not self.atomic:
            raise AttributeError("not atomic batch")
        return cast(Tensor, self.value)

    @property
    def tensors(self) -> Tensors:
        """Retrieves the underlying tensors.

        Raises:
            AttributeError: the batch is atomic.

        """
        if self.atomic:
            raise AttributeError("batch is atomic")
        return cast(Tensors, self.value)

    @property
    def tensor_or_tensors(self) -> TensorOrTensors:
        """Retrieves the underlying tensor or tensors regardless of type."""
        return self.value

    def call(self, function: Function) -> "Batch":
        """Calls a function by the underlying tensor or tensors.

        The function receives the underlying value as one argument. Its
        output is wrapped with :class:`Batch`, keeping this batch's index.
        """
        return Batch(function(self.value), self.index)

    def __repr__(self) -> str:
        return f"Batch[atomic={self.atomic!r}]({self.value!r})"

    def __iter__(self) -> Iterator[Tensor]:
        # An atomic batch iterates as a single-element sequence.
        if self.atomic:
            yield self.tensor
        else:
            yield from self.tensors

    def __len__(self) -> int:
        return 1 if self.atomic else len(self.tensors)

    def __getitem__(self, index: int) -> Tensor:
        """Returns the tensor at ``index``.

        Raises:
            IndexError: the batch is atomic and ``index`` is not 0, or the
                index is out of range for a tuple batch.

        """
        if not self.atomic:
            return self.tensors[index]

        if index != 0:
            raise IndexError("atomic batch allows index 0 only")

        return self.tensor

    # NOTE(sublee): pyflakes can't detect "overload" instead of "typing.overload".
    @typing.overload
    def __setitem__(self, index: int, value: Tensor) -> None:
        ...

    @typing.overload
    def __setitem__(self, index: slice, value: Tensors) -> None:
        ...

    def __setitem__(self, index: Union[int, slice], value: TensorOrTensors) -> None:
        """Replaces one tensor (integer index) or all tensors (``[:]``)."""
        if isinstance(index, int):
            value = cast(Tensor, value)
            self._setitem_by_index(index, value)
        else:
            value = cast(Tensors, value)
            self._setitem_by_slice(index, value)

    def _setitem_by_index(self, index: int, value: Tensor) -> None:
        """Replaces the tensor at ``index`` with ``value``.

        Raises:
            IndexError: the batch is atomic and ``index`` is not 0, or
                ``index`` is out of range for a tuple batch.

        """
        if self.atomic:
            if index != 0:
                raise IndexError("atomic batch allows index 0 only")
            self.value = value
            return

        tensors = self.tensors
        size = len(tensors)
        # Reject out-of-range indices instead of silently growing the tuple.
        if not -size <= index < size:
            raise IndexError("batch index out of range")

        # Normalize negative indices first: with a raw negative index the
        # tail slice ``[index + 1:]`` would start from the wrong position
        # (e.g. ``[-1 + 1:]`` is the whole tuple).
        position = index % size
        self.value = tensors[:position] + (value,) + tensors[position + 1 :]

    def _setitem_by_slice(self, index: slice, value: Tensors) -> None:
        """Replaces the whole content of the batch; only ``[:]`` is accepted.

        Raises:
            NotImplementedError: the slice is anything other than ``[:]``.
            IndexError: the batch is atomic and ``value`` does not hold
                exactly one tensor.

        """
        if not (index.start is index.stop is index.step is None):
            raise NotImplementedError("only slice [:] supported")

        if not self.atomic:
            self.value = value
            return

        # An atomic batch stays atomic: unwrap the single replacement tensor.
        if len(value) != 1:
            raise IndexError("atomic batch cannot be replaced with multiple tensors")

        self.value = value[0]


def check(input: TensorOrTensors) -> None:
    """Checks whether the input is a tensor or tensors.

    Tuples are checked recursively, element by element.

    Raises:
        TypeError: input is not a tensor or tensors.

    """
    if isinstance(input, tuple):
        for x in input:
            check(x)
        return

    if not isinstance(input, Tensor):
        raise TypeError(f"expected Tensor, but got {input.__class__.__name__}")


def scatter(input: TensorOrTensors, chunks: int) -> List[Batch]:
    """Splits an input mini-batch into multiple micro-batches.

    Each tensor is split with ``Tensor.chunk(chunks)``. The resulting
    batches are indexed by their position, starting at 0.
    """
    inputs: Iterable[TensorOrTensors]

    if isinstance(input, Tensor):
        inputs = input.chunk(chunks)
    else:
        # Chunk every tensor, then transpose so that each micro-batch holds
        # the i-th chunk of every tensor.
        rotated: List[Tensors] = [cast(Tensors, tensor.chunk(chunks)) for tensor in input]
        # NOTE: zip() stops at the shortest chunk list, so tensors that
        # split into differing numbers of chunks are truncated.
        inputs = zip(*rotated)

    return [Batch(x, i) for i, x in enumerate(inputs)]


def gather(outputs: List[Batch]) -> TensorOrTensors:
    """Concatenates output micro-batches into a mini-batch.

    The first batch decides the shape of the result: a single tensor if it
    is atomic, otherwise a tuple of tensors.

    Raises:
        IndexError: ``outputs`` is empty.

    """
    output: TensorOrTensors

    if outputs[0].atomic:
        output = torch.cat(tuple(b.tensor for b in outputs))
    else:
        # Transpose back: group the i-th tensor of every micro-batch, then
        # concatenate each group.
        rotated = [b.tensors for b in outputs]
        output = tuple(torch.cat(tensors) for tensors in zip(*rotated))

    return output