File size: 3,959 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#

"""Classes for mixing samples from multiple sources."""

import random

import numpy as np

from .pytorch import IterableDataset


def round_robin_shortest(*sources):
    """Yield samples from multiple sources in round-robin order until one runs out.

    Sources are polled in the order given, one sample per turn. Iteration
    stops the first time a source is found exhausted on its turn, so sources
    earlier in the rotation may have contributed one more sample than the
    source that ended the iteration.

    Args:
        *sources: Iterators to draw samples from; each is advanced with ``next()``.

    Yields:
        The next sample from whichever source's turn it is.
    """
    count = len(sources)
    # Nothing to rotate through; without this guard the modulo below would
    # raise ZeroDivisionError instead of producing an empty stream.
    if count == 0:
        return
    i = 0
    while True:
        # Keep the try body minimal so only exhaustion of a source ends the
        # loop, not a StopIteration raised at the yield point.
        try:
            sample = next(sources[i % count])
        except StopIteration:
            break
        yield sample
        i += 1


def round_robin_longest(*sources):
    """Interleave samples from several sources until every source is exhausted.

    Sources take turns in the order given. A source that runs out is removed
    from the rotation and the remaining sources keep alternating.

    Args:
        *sources: Iterators to draw samples from; each is advanced with ``next()``.

    Yields:
        The next sample from whichever source's turn it is.
    """
    active = list(sources)
    turn = 0
    while active:
        # Wrap around; also re-validates the index after a removal.
        turn %= len(active)
        try:
            item = next(active[turn])
            turn += 1
            yield item
        except StopIteration:
            # Drop the exhausted source; `turn` now refers to its successor.
            active.pop(turn)


class RoundRobin(IterableDataset):
    """Dataset that interleaves samples from several datasets in round-robin order."""

    def __init__(self, datasets, longest=False):
        """Store the datasets and the stopping policy.

        Args:
            datasets (list): Datasets to interleave; each must be iterable.
            longest (bool): If True, keep going until every dataset is exhausted;
                otherwise stop as soon as one dataset runs out.
        """
        self.longest = longest
        self.datasets = datasets

    def __iter__(self):
        """Create fresh iterators over the datasets and interleave them.

        Returns:
            iterator: A generator yielding samples from the datasets in turn.
        """
        interleave = round_robin_longest if self.longest else round_robin_shortest
        return interleave(*map(iter, self.datasets))


def random_samples(sources, probs=None, longest=False):
    """Yield samples drawn at random from multiple sources.

    On every step one source is selected with probability proportional to its
    weight and its next sample is yielded.

    Args:
        sources (list): Iterators to draw samples from; each is advanced with
            ``next()``. The sequence passed in is not modified.
        probs (list, optional): Relative weight of each source (they need not
            sum to 1). Defaults to None, meaning equal weights.
        longest (bool): If True, an exhausted source is dropped and sampling
            continues over the remaining ones until all are exhausted. If
            False, sampling stops the first time a selected source is found
            exhausted. Defaults to False.

    Yields:
        Sample randomly selected from one of the sources.
    """
    # Work on private copies: dropping exhausted sources must not mutate the
    # caller's sequences (and must also work when a tuple is passed).
    sources = list(sources)
    probs = [1] * len(sources) if probs is None else list(probs)
    cum = None  # cumulative distribution; rebuilt only when a source is dropped
    while sources:
        if cum is None:
            weights = np.asarray(probs, dtype=float)
            cum = (weights / weights.sum()).cumsum()
        r = random.random()
        # side="right" maps r in [0, 1) to the bucket with cum[i-1] <= r < cum[i],
        # so a draw of exactly 0.0 cannot pick a leading zero-weight source.
        # The clamp guards against rounding that leaves cum[-1] slightly below
        # 1.0, which would otherwise produce an out-of-range index.
        i = min(int(np.searchsorted(cum, r, side="right")), len(sources) - 1)
        try:
            sample = next(sources[i])
        except StopIteration:
            if not longest:
                break
            del sources[i]
            del probs[i]
            cum = None
            continue
        yield sample


class RandomMix(IterableDataset):
    """Dataset that draws each sample from a randomly chosen member dataset."""

    def __init__(self, datasets, probs=None, longest=False):
        """Store the datasets, their sampling weights and the stopping policy.

        Args:
            datasets (list): Datasets to mix; each must be iterable.
            probs (list, optional): Probabilities for each dataset. Defaults to None.
            longest (bool): If True, continue until all datasets are exhausted.
                Defaults to False.
        """
        self.longest = longest
        self.probs = probs
        self.datasets = datasets

    def __iter__(self):
        """Create fresh iterators over the datasets and mix them randomly.

        Returns:
            iterator: A generator yielding samples picked at random from the datasets.
        """
        iterators = list(map(iter, self.datasets))
        return random_samples(iterators, probs=self.probs, longest=self.longest)