# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import Optional

from datasets import load_dataset
from huggingface_hub import ModelCard
from transformers import HfArgumentParser


@dataclass
class ScriptArguments:
    r"""
    Arguments for the script.

    Args:
        push_to_hub (`bool`, *optional*, defaults to `False`):
            Whether to push the dataset to the Hugging Face Hub.
        repo_id (`str`, *optional*, defaults to `"trl-lib/prm800k"`):
            Hugging Face repository ID to push the dataset to.
        dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
            Number of workers to use for dataset processing.
    """

    push_to_hub: bool = field(
        default=False,
        metadata={"help": "Whether to push the dataset to the Hugging Face Hub."},
    )
    repo_id: str = field(
        default="trl-lib/prm800k",
        metadata={"help": "Hugging Face repository ID to push the dataset to."},
    )
    dataset_num_proc: Optional[int] = field(
        default=None,
        metadata={"help": "Number of workers to use for dataset processing."},
    )


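# Flatten one PRM800K record into multiple stepwise-supervision rows: every
# completion that is not chosen terminates a trajectory, while the chosen (or
# human) completion extends the running prefix of steps.
#
# Illustration on a hypothetical, abbreviated record with a single step whose
# completion 0 is chosen:
#     {"question": {"problem": "What is 1+1?"},
#      "label": {"steps": [{"completions": [{"text": "1+1=2", "rating": 1},
#                                           {"text": "1+1=3", "rating": -1}],
#                           "human_completion": None,
#                           "chosen_completion": 0}]}}
# yields two rows: the rejected branch as a terminal state,
#     {"prompt": "What is 1+1?", "completions": ["1+1=3"], "labels": [False]},
# and the chosen trajectory,
#     {"prompt": "What is 1+1?", "completions": ["1+1=2"], "labels": [True]}.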
def process_example(example):
    outputs = []
    prompt = example["question"]["problem"]

    # Iterate through each step
    previous_completions = []
    previous_labels = []
    for step in example["label"]["steps"]:
        if step["completions"] is None and step["human_completion"] is None and step["chosen_completion"] is None:
            # Occasionally a step in the raw data carries no completions at all; stop processing this example
            break
        # Loop through completions
        for completion_idx, completion in enumerate(step["completions"]):
            # Every completion that is not chosen is a terminal state, so we can add it to the list of outputs.
            if completion_idx != step["chosen_completion"]:
                content = completion["text"]
                completions = previous_completions[:] + [content]
                label = completion["rating"] == 1
                labels = previous_labels[:] + [label]
                outputs.append({"prompt": prompt, "completions": completions, "labels": labels})

        # Now, expand the previous completions and labels with the chosen step
        if step["chosen_completion"] is not None:
            chosen_completion = step["completions"][step["chosen_completion"]]
            label = chosen_completion["rating"] == 1
        elif step["human_completion"] is not None:
            chosen_completion = step["human_completion"]
            label = True
        else:
            break
        content = chosen_completion["text"]
        previous_completions.append(content)
        previous_labels.append(label)

    # After the last step, the chosen trajectory is itself a terminal state, so add it to the list of outputs
    outputs.append({"prompt": prompt, "completions": previous_completions, "labels": previous_labels})
    return outputs


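# Adapter for `datasets.map(batched=True)`: rebuild individual row dicts from
# the column batch, flatten each with `process_example`, then transpose the
# resulting list of dicts back into a dict of lists.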
def process_batch(examples):
    outputs = []
    batch_size = len(examples["label"])
    for idx in range(batch_size):
        example = {k: v[idx] for k, v in examples.items()}
        outputs.extend(process_example(example))
    # Convert the list of dicts into a dict of lists (the format `map(batched=True)` expects)
    outputs = {k: [v[k] for v in outputs] for k in outputs[0]}
    return outputs


model_card = ModelCard("""
---
tags: [trl]
---

# PRM800K Dataset

## Summary

The PRM800K dataset is a processed version of [OpenAI's PRM800K](https://github.com/openai/prm800k), designed to train models using the [TRL library](https://github.com/huggingface/trl) for stepwise supervision tasks. It contains 800,000 step-level correctness labels for model-generated solutions to problems from the MATH dataset. This dataset enables models to learn and verify each step of a solution, enhancing their reasoning capabilities.

## Data Structure

- **Format**: [Standard](https://huggingface.co/docs/trl/main/dataset_formats#standard)
- **Type**: [Stepwise supervision](https://huggingface.co/docs/trl/main/dataset_formats#stepwise-supervision)

Columns:
- `"prompt"`: The problem statement.
- `"completions"`: A list of reasoning steps generated to solve the problem.
- `"labels"`: A list of booleans or floats indicating the correctness of each corresponding reasoning step.

This structure allows models to learn the correctness of each step in a solution, facilitating improved reasoning and problem-solving abilities.
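
For illustration, a single row might look like this (values invented for clarity):

```python
{
    "prompt": "What is the greatest common divisor of 12 and 18?",
    "completions": [
        "Factor 12 as 2^2 * 3.",
        "Factor 18 as 2 * 3^2.",
        "The GCD is 2 * 3 = 6.",
    ],
    "labels": [True, True, True],
}
```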

## Generation script

The script used to generate this dataset can be found [here](https://github.com/huggingface/trl/blob/main/examples/datasets/prm800k.py).
""")

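# Example invocation (the output repo ID is illustrative):
#   python examples/datasets/prm800k.py --push_to_hub --repo_id <your-username>/prm800k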
if __name__ == "__main__":
    parser = HfArgumentParser(ScriptArguments)
    script_args = parser.parse_args_into_dataclasses()[0]

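    # Raw phase-1 train/test splits, fetched directly from OpenAI's prm800k repository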
    data_files = {
        "train": "https://github.com/openai/prm800k/raw/refs/heads/main/prm800k/data/phase1_train.jsonl",
        "test": "https://github.com/openai/prm800k/raw/refs/heads/main/prm800k/data/phase1_test.jsonl",
    }
    dataset = load_dataset("json", data_files=data_files)

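    # Flatten every raw record into stepwise-supervision rows; the original raw
    # columns are removed because process_batch returns an entirely new schema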
    dataset = dataset.map(
        process_batch,
        batched=True,
        batch_size=10,
        remove_columns=[
            "labeler",
            "timestamp",
            "generation",
            "is_quality_control_question",
            "is_initial_screening_question",
            "question",
            "label",
        ],
        num_proc=script_args.dataset_num_proc,
    )

    if script_args.push_to_hub:
        dataset.push_to_hub(script_args.repo_id)
        model_card.push_to_hub(script_args.repo_id, repo_type="dataset")