# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
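#
# Example invocation (illustrative; <your-username>/math_shepherd is a placeholder repo ID):
#   python examples/datasets/math_shepherd.py --push_to_hub --repo_id <your-username>/math_shepherd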

import re
from dataclasses import dataclass, field
from itertools import chain
from typing import Optional

from datasets import load_dataset
from huggingface_hub import ModelCard
from transformers import HfArgumentParser


@dataclass
class ScriptArguments:
    r"""
    Arguments for the script.

    Args:
        push_to_hub (`bool`, *optional*, defaults to `False`):
            Whether to push the dataset to the Hugging Face Hub.
        repo_id (`str`, *optional*, defaults to `"trl-lib/math_shepherd"`):
            Hugging Face repository ID to push the dataset to.
        dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
            Number of workers to use for dataset processing.
    """

    push_to_hub: bool = field(
        default=False,
        metadata={"help": "Whether to push the dataset to the Hugging Face Hub."},
    )
    repo_id: str = field(
        default="trl-lib/math_shepherd",
        metadata={"help": "Hugging Face repository ID to push the dataset to."},
    )
    dataset_num_proc: Optional[int] = field(
        default=None,
        metadata={"help": "Number of workers to use for dataset processing."},
    )


def process_example(example):
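    # Sketch of the mapping performed below (made-up problem text, not an actual dataset row):
    #   example["input"]: "Janet has ... How many? Step 1: 2 + 3 = 5 ки Step 2: 5 * 2 = 11 ки"
    #   example["label"]: the same text, with each "ки" marker replaced by "+" (correct step) or
    #                     "-" (incorrect step); here "+" for step 1 and "-" for step 2
    #   returns: {"prompt": "Janet has ... How many?",
    #             "completions": ["Step 1: 2 + 3 = 5", "Step 2: 5 * 2 = 11"],
    #             "labels": [True, False]}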
    # Replace "ки" with "ⶻ" so that the size of the "input" matches the size of the "label"
    inputs = example["input"].replace("ки", "ⶻ")

    # Find the indexes of the "ⶻ" characters (they should match the indexes of the "+" or "-" in the label)
    indexes = [m.start() for m in re.finditer("ⶻ", inputs)]

    # Sanity check that the label characters at these indexes are all either "+" or "-"
    assert all(example["label"][idx] in ["+", "-"] for idx in indexes)

    # Get the labels
    labels = [example["label"][idx] == "+" for idx in indexes]

    # Split the inputs into steps (note: the first element is the prompt, which still contains the first step)
    steps = [inputs[i:j] for i, j in zip(chain([0], indexes), chain(indexes, [None]))]

    # Remove the last element (just the trailing "ⶻ" marker)
    steps = steps[:-1]

    # Get the prompt (first part) and completions (rest)
    prompt = steps[0]
    completions = steps[1:]

    # Remove the leading "ⶻ" and the surrounding whitespace from the completions
    assert all(completion.startswith("ⶻ") for completion in completions)
    completions = [completion[1:].strip() for completion in completions]

    # At this point, we need to retrieve the first step from the prompt.
    # First, we handle particular cases (annotation errors) where a first label appears before the end of the prompt.
    if prompt.startswith(
        (
            "Mr. Rocky",
            "Parker",
            "What is the smallest positive",
            " The Myth",
            "Let $\\mathbf{a}$",
            "Find the arithmetic",
            "Determine an ordered pair",
            "Determine the ordered pair",
            "At the Quill and Scroll stationery",
            "Round to the nearest",
            r"Calculate $\sqrt{10p}",
            r"Simplify $\sqrt{28x}",
        )
    ):
        # Spotted dataset errors where an annotation appears inside the prompt: we drop that first label
        labels = labels[1:]

    # Then we handle the general case: we get the first step from the prompt by looking for "Step 1:" or "step 1:" or
    # (less common) "?".
    elif "Step 1:" in prompt:
        prompt, first_step = prompt.split("Step 1:")
        first_step = "Step 1:" + first_step
        completions = [first_step.strip()] + completions
    elif "step 1:" in prompt:
        prompt, first_step = prompt.split("step 1:")
        first_step = "step 1:" + first_step
        completions = [first_step.strip()] + completions
    elif "?" in prompt:
        prompt, first_step = prompt.split("?")
        prompt = prompt + "?"
        completions = [first_step.strip()] + completions
    else:
        raise ValueError(f"Prompt can't be processed: {prompt}")

    # Strip the prompt
    prompt = prompt.strip()

    # Sanity check that the length of the completions is the same as the length of the labels
    assert len(completions) == len(labels)

    return {"prompt": prompt, "completions": completions, "labels": labels}


model_card = ModelCard("""
---
tags: [trl]
---

# Math-Shepherd Dataset

## Summary

The Math-Shepherd dataset is a processed version of the original [Math-Shepherd dataset](https://huggingface.co/datasets/peiyi9979/Math-Shepherd), designed to train models using the [TRL library](https://github.com/huggingface/trl) for stepwise supervision tasks. It provides step-by-step solutions to mathematical problems, enabling models to learn and verify each step of a solution, thereby enhancing their reasoning capabilities.

## Data Structure

- **Format**: [Standard](https://huggingface.co/docs/trl/main/dataset_formats#standard)
- **Type**: [Stepwise supervision](https://huggingface.co/docs/trl/main/dataset_formats#stepwise-supervision)

Columns:
- `"prompt"`: The problem statement.
- `"completions"`: A list of reasoning steps generated to solve the problem.
- `"labels"`: A list of booleans or floats indicating the correctness of each corresponding reasoning step.

This structure allows models to learn the correctness of each step in a solution, facilitating improved reasoning and problem-solving abilities.

## Generation script

The script used to generate this dataset can be found [here](https://github.com/huggingface/trl/blob/main/examples/datasets/math_shepherd.py).
""")

if __name__ == "__main__":
    parser = HfArgumentParser(ScriptArguments)
    script_args = parser.parse_args_into_dataclasses()[0]

    dataset = load_dataset("peiyi9979/Math-Shepherd", split="train")

    dataset = dataset.map(
        process_example,
        remove_columns=["input", "label", "task"],
        num_proc=script_args.dataset_num_proc,
    )
    dataset = dataset.train_test_split(test_size=0.05, seed=42)

    if script_args.push_to_hub:
        dataset.push_to_hub(script_args.repo_id)
        model_card.push_to_hub(script_args.repo_id, repo_type="dataset")
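
# Once pushed, the processed dataset can be loaded back for training, e.g. (assuming the default
# repo ID above; swap in your own value of --repo_id if you changed it):
#
#     from datasets import load_dataset
#     dataset = load_dataset("trl-lib/math_shepherd")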