|
|
|
import itertools |
|
from collections.abc import Iterable, Iterator, Sequence, Sized |
|
from typing import Generic, Optional, TypeVar, Union |
|
|
|
import torch |
|
|
|
|
|
# Names exported by ``from <this module> import *``; every entry is a class
# defined in this file.
__all__ = [
    "BatchSampler",
    "RandomSampler",
    "Sampler",
    "SequentialSampler",
    "SubsetRandomSampler",
    "WeightedRandomSampler",
]
|
|
|
|
|
# Covariant element type produced by a sampler's iterator (an index, or a
# list of indices for batch samplers).
_T_co = TypeVar("_T_co", covariant=True)


class Sampler(Generic[_T_co]):
    r"""Base class for all Samplers.

    Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
    way to iterate over indices or lists of indices (batches) of dataset elements,
    and may provide a :meth:`__len__` method that returns the length of the returned iterators.

    Args:
        data_source (Dataset): This argument is not used and will be removed in 2.2.0.
            You may still have custom implementation that utilizes it.

    Example:
        >>> # xdoctest: +SKIP
        >>> class AscendingSequenceLengthSampler(Sampler[int]):
        >>>     def __init__(self, data: list[str]) -> None:
        >>>         self.data = data
        >>>
        >>>     def __len__(self) -> int:
        >>>         return len(self.data)
        >>>
        >>>     def __iter__(self) -> Iterator[int]:
        >>>         sizes = torch.tensor([len(x) for x in self.data])
        >>>         yield from torch.argsort(sizes).tolist()
        >>>
        >>> class AscendingSequenceLengthBatchSampler(Sampler[list[int]]):
        >>>     def __init__(self, data: list[str], batch_size: int) -> None:
        >>>         self.data = data
        >>>         self.batch_size = batch_size
        >>>
        >>>     def __len__(self) -> int:
        >>>         return (len(self.data) + self.batch_size - 1) // self.batch_size
        >>>
        >>>     def __iter__(self) -> Iterator[list[int]]:
        >>>         sizes = torch.tensor([len(x) for x in self.data])
        >>>         for batch in torch.chunk(torch.argsort(sizes), len(self)):
        >>>             yield batch.tolist()

    .. note:: The :meth:`__len__` method isn't strictly required by
              :class:`~torch.utils.data.DataLoader`, but is expected in any
              calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
    """

    def __init__(self, data_source: Optional[Sized] = None) -> None:
        """Emit a warning when the unused ``data_source`` argument is supplied.

        Args:
            data_source: Ignored. Passing anything other than ``None`` only
                triggers a warning; the value is not stored.
        """
        if data_source is not None:
            import warnings

            # The two literals are concatenated, so the first must end with a
            # space to keep the sentences apart. ``stacklevel=2`` attributes
            # the warning to the code that constructed the sampler.
            warnings.warn(
                "`data_source` argument is not used and will be removed in 2.2.0. "
                "You may still have custom implementation that utilizes it.",
                stacklevel=2,
            )

    def __iter__(self) -> Iterator[_T_co]:
        """Iterate over sampled elements; subclasses must override this.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SequentialSampler(Sampler[int]):
    r"""Yields the indices ``0, 1, ..., len(data_source) - 1`` in that order.

    Args:
        data_source (Dataset): dataset to sample from
    """

    data_source: Sized

    def __init__(self, data_source: Sized) -> None:
        """Remember the sized object whose length defines the index range."""
        self.data_source = data_source

    def __len__(self) -> int:
        """Return the number of indices one pass produces."""
        return len(self.data_source)

    def __iter__(self) -> Iterator[int]:
        """Return an iterator over ``range(len(self.data_source))``."""
        index_range = range(len(self.data_source))
        return iter(index_range)
|
|
|
|
|
class RandomSampler(Sampler[int]):
    r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.

    If with replacement, then user can specify :attr:`num_samples` to draw.

    Args:
        data_source (Dataset): dataset to sample from
        replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
        num_samples (int): number of samples to draw, default=`len(dataset)`.
        generator (Generator): Generator used in sampling.

    Raises:
        TypeError: if ``replacement`` is not a ``bool``.
        ValueError: if the effective ``num_samples`` is not a positive ``int``
            (``bool`` values are rejected too).
    """

    data_source: Sized
    replacement: bool

    def __init__(
        self,
        data_source: Sized,
        replacement: bool = False,
        num_samples: Optional[int] = None,
        generator=None,
    ) -> None:
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples
        self.generator = generator

        if not isinstance(self.replacement, bool):
            raise TypeError(
                f"replacement should be a boolean value, but got replacement={self.replacement}"
            )

        # Validate the effective count (explicit value, or len(data_source)
        # when none was given). ``bool`` is a subclass of ``int``, so it is
        # rejected explicitly, as the other samplers in this module do.
        effective = self.num_samples
        if (
            not isinstance(effective, int)
            or isinstance(effective, bool)
            or effective <= 0
        ):
            raise ValueError(
                f"num_samples should be a positive integer value, but got num_samples={effective}"
            )

    @property
    def num_samples(self) -> int:
        """Number of indices one pass yields.

        Evaluated on every access: falls back to the current
        ``len(self.data_source)`` when no explicit count was supplied.
        """
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self) -> Iterator[int]:
        """Yield ``num_samples`` random indices in ``[0, len(data_source))``."""
        n = len(self.data_source)
        if self.generator is None:
            # No generator supplied: seed a private one from torch's global RNG.
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
            generator = torch.Generator()
            generator.manual_seed(seed)
        else:
            generator = self.generator

        if self.replacement:
            # Draw in chunks of 32 indices, then one final chunk holding the
            # remainder (which may be empty).
            for _ in range(self.num_samples // 32):
                yield from torch.randint(
                    high=n, size=(32,), dtype=torch.int64, generator=generator
                ).tolist()
            yield from torch.randint(
                high=n,
                size=(self.num_samples % 32,),
                dtype=torch.int64,
                generator=generator,
            ).tolist()
        else:
            # One full permutation per complete pass over the data ...
            for _ in range(self.num_samples // n):
                yield from torch.randperm(n, generator=generator).tolist()
            # ... then a further permutation is always drawn and truncated to
            # the remainder (possibly to zero elements).
            yield from torch.randperm(n, generator=generator).tolist()[
                : self.num_samples % n
            ]

    def __len__(self) -> int:
        """Return :attr:`num_samples`."""
        return self.num_samples
|
|
|
|
|
class SubsetRandomSampler(Sampler[int]):
    r"""Samples elements randomly from a given list of indices, without replacement.

    Args:
        indices (sequence): a sequence of indices
        generator (Generator): Generator used in sampling.
    """

    indices: Sequence[int]

    def __init__(self, indices: Sequence[int], generator=None) -> None:
        """Store the candidate indices and the optional random generator."""
        self.indices = indices
        self.generator = generator

    def __iter__(self) -> Iterator[int]:
        """Yield every element of ``indices`` exactly once, in random order."""
        # Convert the permutation to plain Python ints once, instead of
        # iterating the tensor and indexing with a 0-d tensor per element.
        order = torch.randperm(len(self.indices), generator=self.generator).tolist()
        for position in order:
            yield self.indices[position]

    def __len__(self) -> int:
        """Return the number of candidate indices."""
        return len(self.indices)
|
|
|
|
|
class WeightedRandomSampler(Sampler[int]):
    r"""Draws indices from ``[0,..,len(weights)-1]`` according to the given weights.

    Args:
        weights (sequence)   : a sequence of weights, not necessarily summing up to one
        num_samples (int): number of samples to draw
        replacement (bool): if ``True``, samples are drawn with replacement.
            If not, they are drawn without replacement, which means that when a
            sample index is drawn for a row, it cannot be drawn again for that row.
        generator (Generator): Generator used in sampling.

    Example:
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
        [4, 4, 1, 4, 5]
        >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
        [0, 1, 4, 3, 2]
    """

    weights: torch.Tensor
    num_samples: int
    replacement: bool

    def __init__(
        self,
        weights: Sequence[float],
        num_samples: int,
        replacement: bool = True,
        generator=None,
    ) -> None:
        """Validate the arguments and store the weights as a double tensor.

        Raises:
            ValueError: if ``num_samples`` is not a positive non-bool ``int``,
                if ``replacement`` is not a ``bool``, or if ``weights`` is not
                one-dimensional.
        """
        # bool is an int subclass, so it has to be excluded explicitly.
        count_is_valid = (
            isinstance(num_samples, int)
            and not isinstance(num_samples, bool)
            and num_samples > 0
        )
        if not count_is_valid:
            raise ValueError(
                f"num_samples should be a positive integer value, but got num_samples={num_samples}"
            )

        if not isinstance(replacement, bool):
            raise ValueError(
                f"replacement should be a boolean value, but got replacement={replacement}"
            )

        as_double = torch.as_tensor(weights, dtype=torch.double)
        if as_double.dim() != 1:
            raise ValueError(
                "weights should be a 1d sequence but given "
                f"weights have shape {tuple(as_double.shape)}"
            )

        self.generator = generator
        self.replacement = replacement
        self.num_samples = num_samples
        self.weights = as_double

    def __iter__(self) -> Iterator[int]:
        """Yield ``num_samples`` indices drawn by a single ``torch.multinomial`` call."""
        drawn = torch.multinomial(
            self.weights,
            self.num_samples,
            self.replacement,
            generator=self.generator,
        )
        yield from drawn.tolist()

    def __len__(self) -> int:
        """Return the configured number of samples."""
        return self.num_samples
|
|
|
|
|
class BatchSampler(Sampler[list[int]]):
    r"""Groups the indices produced by another sampler into mini-batches.

    Args:
        sampler (Sampler or Iterable): Base sampler. Can be any iterable object
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``

    Example:
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(
        self,
        sampler: Union[Sampler[int], Iterable[int]],
        batch_size: int,
        drop_last: bool,
    ) -> None:
        """Validate ``batch_size`` and ``drop_last``, then store all arguments.

        Raises:
            ValueError: if ``batch_size`` is not a positive non-bool ``int`` or
                ``drop_last`` is not a ``bool``.
        """
        # bool is an int subclass, so it has to be excluded explicitly.
        size_is_valid = (
            isinstance(batch_size, int)
            and not isinstance(batch_size, bool)
            and batch_size > 0
        )
        if not size_is_valid:
            raise ValueError(
                f"batch_size should be a positive integer value, but got batch_size={batch_size}"
            )

        if not isinstance(drop_last, bool):
            raise ValueError(
                f"drop_last should be a boolean value, but got drop_last={drop_last}"
            )

        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self) -> Iterator[list[int]]:
        """Yield lists of up to ``batch_size`` consecutive indices from the sampler."""
        source = iter(self.sampler)

        if not self.drop_last:
            # Keep slicing off batches until the source is exhausted; only the
            # final batch can be shorter than batch_size.
            while chunk := [*itertools.islice(source, self.batch_size)]:
                yield chunk
            return

        # zip over batch_size references to the *same* iterator pulls
        # batch_size consecutive items per step and stops, discarding any
        # partial group, once the source runs dry.
        shared_refs = (source,) * self.batch_size
        for full_group in zip(*shared_refs):
            yield list(full_group)

    def __len__(self) -> int:
        """Return the number of batches; requires the sampler to support ``len()``."""
        full_batches, leftover = divmod(len(self.sampler), self.batch_size)
        if self.drop_last or leftover == 0:
            return full_batches
        return full_batches + 1
|
|