github-actions[bot]
Sync from https://github.com/ryanlinjui/menu-text-detection
6bd37dd
from typing import Optional
from datasets import Dataset, DatasetDict
def split_dataset(
    dataset: Dataset,
    train: float,
    validation: float,
    test: float,
    seed: Optional[int] = None,
) -> DatasetDict:
    """
    Split a single-split Hugging Face Dataset into train/validation/test subsets.

    The split is done in two steps with ``Dataset.train_test_split``: first
    the train portion is separated from the rest, then the rest is divided
    into validation and test. The same ``seed`` is passed to both steps.

    Args:
        dataset (Dataset): The input dataset (e.g. load_dataset(...)['train']).
        train (float): Proportion of data for the train split (0 < train < 1).
        validation (float): Proportion of data for the validation split
            (0 < validation < 1).
        test (float): Proportion of data for the test split (0 < test < 1).
            Must satisfy train + validation + test == 1.0.
        seed (Optional[int]): Random seed for reproducibility (default: None).

    Returns:
        DatasetDict: A dictionary with keys "train", "validation", and "test".

    Raises:
        ValueError: If the ratios do not sum to 1.0 (within 1e-8), or if any
            individual ratio is not strictly between 0 and 1.
    """
    # Verify ratios sum to 1.0 (small tolerance absorbs float rounding).
    total = train + validation + test
    if abs(total - 1.0) > 1e-8:
        raise ValueError(f"train + validation + test must equal 1.0 (got {total})")

    # Verify each ratio individually. The sum check alone lets through
    # combinations such as (1.0, 0.0, 0.0) or (1.5, -0.25, -0.25), which would
    # otherwise fail later with a ZeroDivisionError or an unrelated error from
    # the underlying split call. NaN also fails this comparison.
    for name, ratio in (("train", train), ("validation", validation), ("test", test)):
        if not 0.0 < ratio < 1.0:
            raise ValueError(f"{name} must be strictly between 0 and 1 (got {ratio})")

    # First split: extract train vs. temp (validation + test)
    temp_size = validation + test
    split_1 = dataset.train_test_split(test_size=temp_size, seed=seed)
    train_ds = split_1["train"]
    temp_ds = split_1["test"]

    # Second split: divide temp into validation vs. test.
    # The test ratio is rescaled so it is relative to temp, not the full dataset.
    relative_test_size = test / temp_size
    split_2 = temp_ds.train_test_split(test_size=relative_test_size, seed=seed)
    validation_ds = split_2["train"]
    test_ds = split_2["test"]

    # Return a DatasetDict with all three splits
    return DatasetDict({
        "train": train_ds,
        "validation": validation_ds,
        "test": test_ds,
    })