File size: 1,719 Bytes
6bd37dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from typing import Optional

from datasets import Dataset, DatasetDict

def split_dataset(
    dataset: Dataset,
    train: float,
    validation: float,
    test: float,
    seed: Optional[int] = None
) -> DatasetDict:
    """
    Split a single-split Hugging Face Dataset into train/validation/test subsets.

    The split is done in two passes of ``Dataset.train_test_split``: first the
    train portion is separated from the combined validation + test portion,
    then that remainder is divided into validation and test.

    Args:
        dataset (Dataset): The input dataset (e.g. load_dataset(...)['train']).
        train (float): Proportion of data for the train split (0 < train < 1).
        validation (float): Proportion of data for the validation split
            (0 < validation < 1).
        test (float): Proportion of data for the test split (0 < test < 1).
            The three proportions must satisfy
            train + validation + test == 1.0 (within an absolute tolerance
            of 1e-8).
        seed (Optional[int]): Random seed passed to both ``train_test_split``
            calls for reproducibility. ``None`` (the default) leaves the
            shuffling unseeded.

    Returns:
        DatasetDict: A dictionary with keys "train", "validation", and "test".

    Raises:
        ValueError: If any proportion is not strictly between 0 and 1, or if
            the proportions do not sum to 1.0.
    """
    # Validate each ratio up front. Negative or zero ratios can still sum to
    # 1.0 (e.g. 1.5 / -0.2 / -0.3) and would otherwise only fail later inside
    # train_test_split with a message that does not name the bad argument.
    # NaN fails the chained comparison too, so it is rejected here instead of
    # slipping past the sum check below.
    ratios = (("train", train), ("validation", validation), ("test", test))
    for name, ratio in ratios:
        if not 0.0 < ratio < 1.0:
            raise ValueError(
                f"{name} must be strictly between 0 and 1 (got {ratio})"
            )

    # Verify ratios sum to 1.0. Written as `not (... <= tol)` so a NaN total
    # is treated as a failure rather than silently passing.
    total = train + validation + test
    if not abs(total - 1.0) <= 1e-8:
        raise ValueError(f"train + validation + test must equal 1.0 (got {total})")

    # First split: extract train vs. temp (validation + test)
    temp_size = validation + test
    split_1 = dataset.train_test_split(test_size=temp_size, seed=seed)
    train_ds = split_1["train"]
    temp_ds = split_1["test"]

    # Second split: divide temp into validation vs. test. The test fraction is
    # re-expressed relative to the temp subset, e.g. 0.1 / (0.1 + 0.1) = 0.5.
    # temp_size > 0 is guaranteed by the per-ratio checks above.
    relative_test_size = test / temp_size
    split_2 = temp_ds.train_test_split(test_size=relative_test_size, seed=seed)
    # The "train" side of this second split is the complement of the test
    # fraction within temp, i.e. the validation portion.
    validation_ds = split_2["train"]
    test_ds = split_2["test"]

    # Return a DatasetDict with all three splits
    return DatasetDict({
        "train": train_ds,
        "validation": validation_ds,
        "test": test_ds,
    })