from typing import TypeVar

from .arrow_dataset import Dataset, _split_by_node_map_style_dataset
from .iterable_dataset import IterableDataset, _split_by_node_iterable_dataset


DatasetType = TypeVar("DatasetType", Dataset, IterableDataset)


def split_dataset_by_node(dataset: DatasetType, rank: int, world_size: int) -> DatasetType:
    """
    Split a dataset for the node at rank `rank` in a pool of nodes of size `world_size`.

    For map-style datasets:

    Each node is assigned a chunk of data, e.g. rank 0 is given the first chunk of the dataset.
    To maximize data loading throughput, chunks are made of contiguous data on disk if possible.

    For iterable datasets:

    If the dataset has a number of shards that is a multiple of `world_size` (i.e. if `dataset.num_shards % world_size == 0`),
    then the shards are evenly assigned across the nodes, which is the most efficient approach.
    Otherwise, each node keeps 1 example out of every `world_size` examples, skipping the others.

    Args:
        dataset ([`Dataset`] or [`IterableDataset`]):
            The dataset to split by node.
        rank (`int`):
            Rank of the current node.
        world_size (`int`):
            Total number of nodes.

    Returns:
        [`Dataset`] or [`IterableDataset`]: The dataset to be used on the node at rank `rank`.
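
    Example (a minimal sketch; the toy dataset below is illustrative):

    ```python
    >>> from datasets import Dataset
    >>> from datasets.distributed import split_dataset_by_node
    >>> ds = Dataset.from_dict({"i": list(range(10))})
    >>> split_dataset_by_node(ds, rank=0, world_size=2)["i"]  # rank 0 gets the first contiguous chunk
    [0, 1, 2, 3, 4]
    >>> ids = ds.to_iterable_dataset(num_shards=2)  # num_shards is a multiple of world_size
    >>> node_ids = split_dataset_by_node(ids, rank=0, world_size=2)  # rank 0 keeps one full shard
    ```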
    """
    if isinstance(dataset, Dataset):
        return _split_by_node_map_style_dataset(dataset, rank=rank, world_size=world_size)
    else:
        return _split_by_node_iterable_dataset(dataset, rank=rank, world_size=world_size)