File size: 3,959 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#
"""Classes for mixing samples from multiple sources."""
import random
import numpy as np
from .pytorch import IterableDataset
def round_robin_shortest(*sources):
    """Yield samples from the sources in turn, stopping at the first exhausted one.

    Sources are visited cyclically in the order given. Iteration ends the
    moment the source whose turn it is has nothing left, so samples already
    drawn from earlier sources in that final round are still yielded.

    Args:
        *sources: Iterators to draw samples from; each is advanced with ``next()``.

    Yields:
        Sample from one of the sources.
    """
    if not sources:
        # Nothing to draw from; without this guard the modulo below would
        # divide by zero.
        return

    count = len(sources)
    i = 0
    while True:
        # Only the next() call is guarded, so a StopIteration can only mean
        # that the current source ran out.
        try:
            sample = next(sources[i % count])
        except StopIteration:
            return
        yield sample
        i += 1
def round_robin_longest(*sources):
    """Cycle through the sources, yielding one sample per turn, until all run dry.

    A source that is exhausted is removed from the rotation; the remaining
    sources keep taking turns in their original relative order.

    Args:
        *sources: Iterators to draw samples from; each is advanced with ``next()``.

    Yields:
        Sample from one of the sources.
    """
    pending = list(sources)
    cursor = 0
    while pending:
        # Wrap around the (possibly shrunken) pool before every draw.
        cursor = cursor % len(pending)
        try:
            item = next(pending[cursor])
        except StopIteration:
            # Removing the entry shifts the following source into this slot,
            # so the cursor is deliberately left where it is.
            pending.pop(cursor)
        else:
            cursor += 1
            yield item
class RoundRobin(IterableDataset):
    """Dataset that interleaves several datasets one sample at a time."""

    def __init__(self, datasets, longest=False):
        """Store the datasets and the stopping policy.

        Args:
            datasets (list): Datasets to interleave; each must support ``iter()``.
            longest (bool): If True, keep going until every dataset is exhausted;
                otherwise stop at the first dataset that runs out.
        """
        self.datasets = datasets
        self.longest = longest

    def __iter__(self):
        """Start a new round-robin pass over the datasets.

        Returns:
            iterator: A generator yielding samples from the datasets in turn,
                built on a fresh iterator for each dataset.
        """
        # Pick the generator that matches the configured stopping policy.
        interleave = round_robin_longest if self.longest else round_robin_shortest
        return interleave(*map(iter, self.datasets))
def random_samples(sources, probs=None, longest=False):
    """Yield samples drawn at random from several sources.

    Args:
        sources (list): Iterators to draw samples from; each is advanced with
            ``next()``. The sequence passed in is copied and never modified.
        probs (list, optional): Relative weight of each source; the weights are
            normalised by their sum. Defaults to None, meaning equal weights.
        longest (bool): If True, drop a source once it is exhausted and keep
            sampling the others until all are exhausted. If False, stop as soon
            as the selected source is exhausted. Defaults to False.

    Yields:
        Sample from the randomly selected source.
    """
    # Work on private copies: exhausted entries are deleted below, and that
    # must neither alter the caller's list nor fail on an immutable sequence.
    sources = list(sources)
    probs = [1] * len(sources) if probs is None else list(probs)

    cum = None  # cumulative distribution; rebuilt only when the pool changes
    while sources:
        if cum is None:
            cum = (np.array(probs) / np.sum(probs)).cumsum()
        r = random.random()
        i = int(np.searchsorted(cum, r))
        if len(cum) > 0 and i == len(cum):
            # Floating-point rounding can leave cum[-1] marginally below 1.0,
            # in which case a large r lands one past the end; use the last bin.
            i -= 1
        try:
            sample = next(sources[i])
        except StopIteration:
            if not longest:
                break
            del sources[i]
            del probs[i]
            cum = None  # weights changed, so the distribution must be rebuilt
            continue
        yield sample
class RandomMix(IterableDataset):
    """Dataset that mixes several datasets by picking the next source at random."""

    def __init__(self, datasets, probs=None, longest=False):
        """Store the datasets, their sampling weights and the stopping policy.

        Args:
            datasets (list): Datasets to mix; each must support ``iter()``.
            probs (list, optional): Weight for each dataset, passed through to
                ``random_samples``. Defaults to None.
            longest (bool): If True, continue until all datasets are exhausted.
                Defaults to False.
        """
        self.datasets = datasets
        self.probs = probs
        self.longest = longest

    def __iter__(self):
        """Start a new randomised pass over the datasets.

        Returns:
            iterator: A generator yielding samples randomly from the datasets,
                built on a fresh iterator for each dataset.
        """
        iterators = list(map(iter, self.datasets))
        return random_samples(iterators, self.probs, longest=self.longest)
|