Spaces:
Running
Running
from typing import Optional | |
from datasets import Dataset, DatasetDict | |
def split_dataset( | |
dataset: Dataset, | |
train: float, | |
validation: float, | |
test: float, | |
seed: Optional[int] = None | |
) -> DatasetDict: | |
""" | |
Split a single-split Hugging Face Dataset into train/validation/test subsets. | |
Args: | |
dataset (Dataset): The input dataset (e.g. load_dataset(...)['train']). | |
train (float): Proportion of data for the train split (0 < train < 1). | |
val (float): Proportion of data for the validation split (0 < val < 1). | |
test (float): Proportion of data for the test split (0 < test < 1). | |
Must satisfy train + val + test == 1.0. | |
seed (int): Random seed for reproducibility (default: None). | |
Returns: | |
DatasetDict: A dictionary with keys "train", "validation", and "test". | |
""" | |
# Verify ratios sum to 1.0 | |
total = train + validation + test | |
if abs(total - 1.0) > 1e-8: | |
raise ValueError(f"train + validation + test must equal 1.0 (got {total})") | |
# First split: extract train vs. temp (validation + test) | |
temp_size = validation + test | |
split_1 = dataset.train_test_split(test_size=temp_size, seed=seed) | |
train_ds = split_1["train"] | |
temp_ds = split_1["test"] | |
# Second split: divide temp into validation vs. test | |
relative_test_size = test / temp_size | |
split_2 = temp_ds.train_test_split(test_size=relative_test_size, seed=seed) | |
validation_ds = split_2["train"] | |
test_ds = split_2["test"] | |
# Return a DatasetDict with all three splits | |
return DatasetDict({ | |
"train": train_ds, | |
"validation": validation_ds, | |
"test": test_ds, | |
}) |