from typing import Optional, TypeVar
from .arrow_dataset import Dataset, _concatenate_map_style_datasets, _interleave_map_style_datasets
from .dataset_dict import DatasetDict, IterableDatasetDict
from .info import DatasetInfo
from .iterable_dataset import IterableDataset, _concatenate_iterable_datasets, _interleave_iterable_datasets
from .splits import NamedSplit
from .utils import logging
from .utils.py_utils import Literal
# Module-level logger, named after this module.
logger = logging.get_logger(__name__)
# Constrained type variable: it resolves to either `Dataset` or `IterableDataset`, so a function
# annotated with it is typed as returning the same kind of dataset it was given.
DatasetType = TypeVar("DatasetType", Dataset, IterableDataset)
def interleave_datasets(
    datasets: list[DatasetType],
    probabilities: Optional[list[float]] = None,
    seed: Optional[int] = None,
    info: Optional[DatasetInfo] = None,
    split: Optional[NamedSplit] = None,
    stopping_strategy: Literal["first_exhausted", "all_exhausted"] = "first_exhausted",
) -> DatasetType:
    """
    Interleave several datasets (sources) into a single dataset.
    The new dataset is constructed by alternating between the sources to get the examples.

    You can use this function on a list of [`Dataset`] objects, or on a list of [`IterableDataset`] objects.

        - If `probabilities` is `None` (default) the new dataset is constructed by cycling between each source to get the examples.
        - If `probabilities` is not `None`, the new dataset is constructed by getting examples from a random source at a time according to the provided probabilities.

    The resulting dataset ends when one of the source datasets runs out of examples, except when
    `stopping_strategy` is `"all_exhausted"`, in which case the resulting dataset ends when all datasets
    have run out of examples at least one time.

    Note for iterable datasets:

    In a distributed setup or in PyTorch DataLoader workers, the stopping strategy is applied per process.
    Therefore the "first_exhausted" strategy on a sharded iterable dataset can generate fewer samples in total (up to 1 missing sample per subdataset per worker).

    Args:
        datasets (`List[Dataset]` or `List[IterableDataset]`):
            List of datasets to interleave.
        probabilities (`List[float]`, *optional*, defaults to `None`):
            If specified, the new dataset is constructed by sampling
            examples from one source at a time according to these probabilities.
        seed (`int`, *optional*, defaults to `None`):
            The random seed used to choose a source for each example.
        info ([`DatasetInfo`], *optional*):
            Dataset information, like description, citation, etc.
        split ([`NamedSplit`], *optional*):
            Name of the dataset split.
        stopping_strategy (`str`, defaults to `first_exhausted`):
            Two strategies are proposed right now, `first_exhausted` and `all_exhausted`.
            By default, `first_exhausted` is an undersampling strategy, i.e. the dataset construction is stopped as soon as one dataset has run out of samples.
            If the strategy is `all_exhausted`, we use an oversampling strategy, i.e. the dataset construction is stopped as soon as every sample of every dataset has been added at least once.
            Note that if the strategy is `all_exhausted`, the interleaved dataset size can get enormous:
            - with no probabilities, the resulting dataset will have `max_length_datasets*nb_dataset` samples.
            - with given probabilities, the resulting dataset will have more samples if some datasets have really low probability of visiting.

    Returns:
        [`Dataset`] or [`IterableDataset`]: Return type depends on the input `datasets`
        parameter. `Dataset` if the input is a list of `Dataset`, `IterableDataset` if the input is a list of
        `IterableDataset`.

    Raises:
        ValueError: If `datasets` is empty, if an element is neither a [`Dataset`] nor an
            [`IterableDataset`] (including dataset dictionaries, for which a split must be picked first),
            if map-style and iterable datasets are mixed, or if `stopping_strategy` is not one of the
            supported values.

    Example:

        For regular datasets (map-style):

        ```python
        >>> from datasets import Dataset, interleave_datasets
        >>> d1 = Dataset.from_dict({"a": [0, 1, 2]})
        >>> d2 = Dataset.from_dict({"a": [10, 11, 12]})
        >>> d3 = Dataset.from_dict({"a": [20, 21, 22]})
        >>> dataset = interleave_datasets([d1, d2, d3], probabilities=[0.7, 0.2, 0.1], seed=42, stopping_strategy="all_exhausted")
        >>> dataset["a"]
        [10, 0, 11, 1, 2, 20, 12, 10, 0, 1, 2, 21, 0, 11, 1, 2, 0, 1, 12, 2, 10, 0, 22]
        >>> dataset = interleave_datasets([d1, d2, d3], probabilities=[0.7, 0.2, 0.1], seed=42)
        >>> dataset["a"]
        [10, 0, 11, 1, 2]
        >>> dataset = interleave_datasets([d1, d2, d3])
        >>> dataset["a"]
        [0, 10, 20, 1, 11, 21, 2, 12, 22]
        >>> dataset = interleave_datasets([d1, d2, d3], stopping_strategy="all_exhausted")
        >>> dataset["a"]
        [0, 10, 20, 1, 11, 21, 2, 12, 22]
        >>> d1 = Dataset.from_dict({"a": [0, 1, 2]})
        >>> d2 = Dataset.from_dict({"a": [10, 11, 12, 13]})
        >>> d3 = Dataset.from_dict({"a": [20, 21, 22, 23, 24]})
        >>> dataset = interleave_datasets([d1, d2, d3])
        >>> dataset["a"]
        [0, 10, 20, 1, 11, 21, 2, 12, 22]
        >>> dataset = interleave_datasets([d1, d2, d3], stopping_strategy="all_exhausted")
        >>> dataset["a"]
        [0, 10, 20, 1, 11, 21, 2, 12, 22, 0, 13, 23, 1, 10, 24]
        >>> dataset = interleave_datasets([d1, d2, d3], probabilities=[0.7, 0.2, 0.1], seed=42)
        >>> dataset["a"]
        [10, 0, 11, 1, 2]
        >>> dataset = interleave_datasets([d1, d2, d3], probabilities=[0.7, 0.2, 0.1], seed=42, stopping_strategy="all_exhausted")
        >>> dataset["a"]
        [10, 0, 11, 1, 2, 20, 12, 13, ..., 0, 1, 2, 0, 24]
        ```

        For datasets in streaming mode (iterable):

        ```python
        >>> from datasets import interleave_datasets, load_dataset
        >>> d1 = load_dataset('allenai/c4', 'es', split='train', streaming=True)
        >>> d2 = load_dataset('allenai/c4', 'fr', split='train', streaming=True)
        >>> dataset = interleave_datasets([d1, d2])
        >>> iterator = iter(dataset)
        >>> next(iterator)
        {'text': 'Comprar Zapatillas para niña en chancla con goma por...'}
        >>> next(iterator)
        {'text': 'Le sacre de philippe ier, 23 mai 1059 - Compte Rendu...'}
        ```
    """
    if not datasets:
        raise ValueError("Unable to interleave an empty list of datasets.")

    for i, dataset in enumerate(datasets):
        if not isinstance(dataset, (Dataset, IterableDataset)):
            # Dataset dictionaries get a dedicated message: the caller most likely forgot to select a split.
            if isinstance(dataset, (DatasetDict, IterableDatasetDict)):
                if not dataset:
                    raise ValueError(
                        f"Expected a list of Dataset objects or a list of IterableDataset objects, but element at position {i} "
                        "is an empty dataset dictionary."
                    )
                raise ValueError(
                    f"Dataset at position {i} has at least one split: {list(dataset)}\n"
                    f"Please pick one to interleave with the other datasets, for example: dataset['{next(iter(dataset))}']"
                )
            raise ValueError(
                f"Expected a list of Dataset objects or a list of IterableDataset objects, but element at position {i} is a {type(dataset).__name__}."
            )
        if i == 0:
            # The first element fixes the expected type; every later element must match it.
            dataset_type, other_type = (
                (Dataset, IterableDataset) if isinstance(dataset, Dataset) else (IterableDataset, Dataset)
            )
        elif not isinstance(dataset, dataset_type):
            raise ValueError(
                f"Unable to interleave a {dataset_type.__name__} (at position 0) with a {other_type.__name__} (at position {i}). Expected a list of Dataset objects or a list of IterableDataset objects."
            )

    if stopping_strategy not in ("first_exhausted", "all_exhausted"):
        raise ValueError(f"{stopping_strategy} is not supported. Please enter a valid stopping_strategy.")

    # Dispatch to the implementation matching the (now homogeneous) dataset type.
    if dataset_type is Dataset:
        return _interleave_map_style_datasets(
            datasets, probabilities, seed, info=info, split=split, stopping_strategy=stopping_strategy
        )
    return _interleave_iterable_datasets(
        datasets, probabilities, seed, info=info, split=split, stopping_strategy=stopping_strategy
    )
def concatenate_datasets(
    dsets: list[DatasetType],
    info: Optional[DatasetInfo] = None,
    split: Optional[NamedSplit] = None,
    axis: int = 0,
) -> DatasetType:
    """
    Concatenate a list of [`Dataset`] objects, or a list of [`IterableDataset`] objects, with the same
    schema into a single dataset.

    Args:
        dsets (`List[Dataset]` or `List[IterableDataset]`):
            List of datasets to concatenate. All elements must be of the same kind
            (all [`Dataset`] or all [`IterableDataset`]).
        info (`DatasetInfo`, *optional*):
            Dataset information, like description, citation, etc.
        split (`NamedSplit`, *optional*):
            Name of the dataset split.
        axis (`{0, 1}`, defaults to `0`):
            Axis to concatenate over, where `0` means over rows (vertically) and `1` means over columns
            (horizontally).

    Returns:
        [`Dataset`] or [`IterableDataset`]: Return type depends on the input `dsets`
        parameter. `Dataset` if the input is a list of `Dataset`, `IterableDataset` if the input is a list of
        `IterableDataset`.

    Raises:
        ValueError: If `dsets` is empty, if an element is neither a [`Dataset`] nor an
            [`IterableDataset`] (including dataset dictionaries, for which a split must be picked first),
            or if map-style and iterable datasets are mixed.

    Example:

    ```py
    >>> ds3 = concatenate_datasets([ds1, ds2])
    ```
    """
    if not dsets:
        raise ValueError("Unable to concatenate an empty list of datasets.")

    for i, dataset in enumerate(dsets):
        if not isinstance(dataset, (Dataset, IterableDataset)):
            # Dataset dictionaries get a dedicated message: the caller most likely forgot to select a split.
            if isinstance(dataset, (DatasetDict, IterableDatasetDict)):
                if not dataset:
                    raise ValueError(
                        f"Expected a list of Dataset objects or a list of IterableDataset objects, but element at position {i} "
                        "is an empty dataset dictionary."
                    )
                raise ValueError(
                    f"Dataset at position {i} has at least one split: {list(dataset)}\n"
                    f"Please pick one to concatenate with the other datasets, for example: dataset['{next(iter(dataset))}']"
                )
            raise ValueError(
                f"Expected a list of Dataset objects or a list of IterableDataset objects, but element at position {i} is a {type(dataset).__name__}."
            )
        if i == 0:
            # The first element fixes the expected type; every later element must match it.
            dataset_type, other_type = (
                (Dataset, IterableDataset) if isinstance(dataset, Dataset) else (IterableDataset, Dataset)
            )
        elif not isinstance(dataset, dataset_type):
            raise ValueError(
                f"Unable to concatenate a {dataset_type.__name__} (at position 0) with a {other_type.__name__} (at position {i}). Expected a list of Dataset objects or a list of IterableDataset objects."
            )

    # Dispatch to the implementation matching the (now homogeneous) dataset type.
    if dataset_type is Dataset:
        return _concatenate_map_style_datasets(dsets, info=info, split=split, axis=axis)
    return _concatenate_iterable_datasets(dsets, info=info, split=split, axis=axis)