lvwerra (HF Staff) committed
Commit 363da21 · 1 Parent(s): 69a883d

Upload train.py

Files changed (1)
  train.py +237 -0
train.py ADDED
@@ -0,0 +1,237 @@
"""
Fine-tune a causal language model (e.g. gpt2-xl or SantaCoder) on a question/answer
text dataset such as lvwerra/stack-exchange-paired.
"""

import argparse
import os

import torch
from datasets import load_dataset
from torch.utils.data import IterableDataset
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    logging,
    set_seed,
)


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default="gpt2-xl")
    parser.add_argument("--dataset_name", type=str, default="lvwerra/stack-exchange-paired")
    parser.add_argument("--subset", type=str, default="data/finetune")
    parser.add_argument("--split", type=str, default="train")
    parser.add_argument("--size_valid_set", type=int, default=4000)
    parser.add_argument("--streaming", action="store_true")
    parser.add_argument("--shuffle_buffer", type=int, default=5000)

    parser.add_argument("--seq_length", type=int, default=1024)
    parser.add_argument("--max_steps", type=int, default=10000)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--eos_token_id", type=int, default=49152)

    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--lr_scheduler_type", type=str, default="cosine")
    parser.add_argument("--num_warmup_steps", type=int, default=100)
    parser.add_argument("--weight_decay", type=float, default=0.05)

    parser.add_argument("--local_rank", type=int, default=0)
    # store_false flags: fp16 and gradient checkpointing are ON by default;
    # pass --no_fp16 or --no_gradient_checkpointing to turn them off.
    parser.add_argument("--no_fp16", action="store_false")
    parser.add_argument("--bf16", action="store_true")
    parser.add_argument("--no_gradient_checkpointing", action="store_false")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--num_workers", type=int, default=None)
    parser.add_argument("--output_dir", type=str, default="./checkpoints")
    parser.add_argument("--log_freq", default=1, type=int)
    parser.add_argument("--eval_freq", default=1000, type=int)
    parser.add_argument("--save_freq", default=1000, type=int)

    return parser.parse_args()
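

# Example invocation (illustrative only; all flags are defined in get_args() above,
# and paths/sizes should be adapted to your setup and launcher):
#
#   python train.py \
#       --model_path=gpt2-xl \
#       --dataset_name=lvwerra/stack-exchange-paired \
#       --streaming \
#       --seq_length=1024 \
#       --batch_size=16 \
#       --output_dir=./checkpoints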


def chars_token_ratio(dataset, tokenizer, nb_examples=400):
    """
    Estimate the average number of characters per token in the dataset.
    """
    total_characters, total_tokens = 0, 0
    for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples):
        text = prepare_sample_text(example)
        total_characters += len(text)
        total_tokens += len(tokenizer(text).tokens())

    return total_characters / total_tokens


def prepare_sample_text(example):
    """Prepare the text from a sample of the dataset."""
    text = f"Question: {example['question']}\n\nAnswer: {example['response_j']}"
    return text
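

# For reference, prepare_sample_text formats each example as:
#
#   Question: <contents of example["question"]>
#
#   Answer: <contents of example["response_j"], the preferred answer in this dataset>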


class ConstantLengthDataset(IterableDataset):
    """
    Iterable dataset that yields constant-length chunks of tokens from a stream of text examples.

    Args:
        tokenizer (Tokenizer): The tokenizer used for processing the data.
        dataset (dataset.Dataset): Dataset with text examples.
        eos_token_id (int): Token id used to separate concatenated examples when the
            tokenizer does not define an EOS token.
        infinite (bool): If True, the iterator restarts from the beginning of the dataset
            once it is exhausted; otherwise iteration stops.
        seq_length (int): Length of the token sequences to return.
        num_of_sequences (int): Number of token sequences to keep in the buffer.
        chars_per_token (float): Estimated number of characters per token, used to size the text buffer.
    """

    def __init__(
        self,
        tokenizer,
        dataset,
        eos_token_id=None,
        infinite=False,
        seq_length=1024,
        num_of_sequences=1024,
        chars_per_token=3.6,
    ):
        self.tokenizer = tokenizer
        # Use the tokenizer's EOS token to separate concatenated examples,
        # falling back to the explicitly provided id (e.g. --eos_token_id) otherwise.
        self.concat_token_id = (
            tokenizer.eos_token_id if tokenizer.eos_token_id else eos_token_id
        )
        self.dataset = dataset
        self.seq_length = seq_length
        self.infinite = infinite
        self.current_size = 0
        self.max_buffer_size = seq_length * chars_per_token * num_of_sequences

    def __iter__(self):
        iterator = iter(self.dataset)
        more_examples = True
        while more_examples:
            # Fill a text buffer with roughly enough characters for num_of_sequences sequences.
            buffer, buffer_len = [], 0
            while True:
                if buffer_len >= self.max_buffer_size:
                    break
                try:
                    buffer.append(prepare_sample_text(next(iterator)))
                    buffer_len += len(buffer[-1])
                except StopIteration:
                    if self.infinite:
                        iterator = iter(self.dataset)
                    else:
                        more_examples = False
                        break
            # Tokenize the buffer, concatenate everything with the separator token,
            # and slice the result into fixed-length chunks.
            tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
            all_token_ids = []
            for tokenized_input in tokenized_inputs:
                all_token_ids.extend(tokenized_input + [self.concat_token_id])
            for i in range(0, len(all_token_ids), self.seq_length):
                input_ids = all_token_ids[i : i + self.seq_length]
                if len(input_ids) == self.seq_length:
                    self.current_size += 1
                    yield {
                        "input_ids": torch.LongTensor(input_ids),
                        "labels": torch.LongTensor(input_ids),
                    }
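

# Rough buffer sizing with the defaults above: max_buffer_size = seq_length
# * chars_per_token * num_of_sequences = 1024 * 3.6 * 1024 ≈ 3.8M characters per refill,
# i.e. approximately enough raw text to cut num_of_sequences packed sequences of
# seq_length tokens each.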


def create_datasets(tokenizer, args):
    dataset = load_dataset(
        args.dataset_name,
        data_dir=args.subset,
        split=args.split,
        use_auth_token=True,
        num_proc=args.num_workers if not args.streaming else None,
        streaming=args.streaming,
    )
    if args.streaming:
        print("Loading the dataset in streaming mode")
        valid_data = dataset.take(args.size_valid_set)
        train_data = dataset.skip(args.size_valid_set)
        train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=args.seed)
    else:
        dataset = dataset.train_test_split(test_size=0.005, seed=args.seed)
        train_data = dataset["train"]
        valid_data = dataset["test"]
        print(
            f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}"
        )
    chars_per_token = chars_token_ratio(train_data, tokenizer)
    print(f"The character to token ratio of the dataset is: {chars_per_token:.2f}")
    train_dataset = ConstantLengthDataset(
        tokenizer,
        train_data,
        eos_token_id=args.eos_token_id,
        infinite=True,
        seq_length=args.seq_length,
        chars_per_token=chars_per_token,
    )
    valid_dataset = ConstantLengthDataset(
        tokenizer,
        valid_data,
        eos_token_id=args.eos_token_id,
        infinite=False,
        seq_length=args.seq_length,
        chars_per_token=chars_per_token,
    )
    return train_dataset, valid_dataset


def run_training(args, train_data, val_data):
    print("Loading the model")
    # Disable the caching mechanism when using gradient checkpointing (the two are incompatible).
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        trust_remote_code=True,
        use_cache=not args.no_gradient_checkpointing,
    )
    train_data.start_iteration = 0

    print("Starting main loop")

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        dataloader_drop_last=True,
        evaluation_strategy="steps",
        max_steps=args.max_steps,
        eval_steps=args.eval_freq,
        save_steps=args.save_freq,
        logging_steps=args.log_freq,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        lr_scheduler_type=args.lr_scheduler_type,
        warmup_steps=args.num_warmup_steps,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        # args.no_gradient_checkpointing and args.no_fp16 default to True (store_false flags),
        # so gradient checkpointing and fp16 are enabled unless the --no_* flags are passed.
        gradient_checkpointing=args.no_gradient_checkpointing,
        fp16=args.no_fp16,
        bf16=args.bf16,
        weight_decay=args.weight_decay,
        run_name="gpt2-finetuned",
        report_to="wandb",
    )
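
    # The effective optimization batch size is per_device_train_batch_size
    # * gradient_accumulation_steps * number of processes; with the defaults above and
    # an assumed 8-GPU run, that is 16 * 1 * 8 = 128 packed sequences per optimizer step.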

    trainer = Trainer(
        model=model, args=training_args, train_dataset=train_data, eval_dataset=val_data
    )

    print("Training...")
    trainer.train()

    print("Saving last checkpoint of the model")
    model.save_pretrained(os.path.join(args.output_dir, "final_checkpoint/"))


def main(args):
    tokenizer = AutoTokenizer.from_pretrained(args.model_path, use_auth_token=True)
    train_dataset, eval_dataset = create_datasets(tokenizer, args)
    run_training(args, train_dataset, eval_dataset)


if __name__ == "__main__":
    args = get_args()
    set_seed(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)

    logging.set_verbosity_error()

    main(args)