MasteredUltraInstinct committed (verified)
Commit e5bfb7c · 1 Parent(s): c92fbb4

Update train.py

Files changed (1)
  1. train.py +41 -30
train.py CHANGED
@@ -1,33 +1,44 @@
-from model import get_model
-from pix2tex.dataset.dataset import Im2LatexDataset
-from pix2tex.trainer import Trainer
-import os
-
-os.makedirs('trained_model', exist_ok=True)
-
-# Training parameters
-config = {
-    "batch_size": 4,
-    "epochs": 1,
-    "max_seq_len": 150,
-    "warmup_steps": 10,
-    "lr": 1e-4,
-    "device": "cpu",
-    "save_dir": "trained_model",
-    "resume": False
-}
-
-# Dataset path
-dataset = Im2LatexDataset(
-    data_root='handwritten_dataset',
-    transform=None,
-    max_length=config["max_seq_len"]
+from datasets import load_dataset
+from transformers import TrOCRProcessor, VisionEncoderDecoderModel, Seq2SeqTrainer, Seq2SeqTrainingArguments, default_data_collator
+
+# Load the handwritten math dataset
+ds = load_dataset("Azu/Handwritten-Mathematical-Expression-Convert-LaTeX", split="train[:1000]")
+
+processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
+
+def preprocess(ex):
+    img = ex["image"].convert("RGB")
+    inputs = processor(images=img, return_tensors="pt")
+    labels = processor.tokenizer(ex["text"], truncation=True, padding="max_length", max_length=128).input_ids
+    ex["pixel_values"] = inputs.pixel_values[0]
+    ex["labels"] = labels
+    return ex
+
+ds = ds.map(preprocess, remove_columns=["image", "text"])
+
+model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
+model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
+model.config.pad_token_id = processor.tokenizer.pad_token_id
+
+training_args = Seq2SeqTrainingArguments(
+    output_dir="trained_model",
+    per_device_train_batch_size=2,
+    num_train_epochs=1,
+    learning_rate=5e-5,
+    logging_steps=10,
+    save_steps=500,
+    fp16=False,
+    push_to_hub=False,
 )
 
-# Initialize model and trainer
-model, tokenizer = get_model()
-trainer = Trainer(model, tokenizer, config)
+trainer = Seq2SeqTrainer(
+    model=model,
+    args=training_args,
+    train_dataset=ds,
+    tokenizer=processor.tokenizer,
+    data_collator=default_data_collator,
+)
 
-print("🧠 Starting training...")
-trainer.train(dataset)
-print("✅ Training complete. Model saved to 'trained_model/'")
+trainer.train()
+model.save_pretrained("trained_model")
+processor.save_pretrained("trained_model")
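
One caveat in the new preprocess(): the labels keep the tokenizer's pad token, so the loss is also computed on padding positions. A common refinement when fine-tuning TrOCR (a sketch, not part of this commit) is to replace pad ids with -100, the ignore index of the cross-entropy loss used by transformers:

# Variant of the commit's preprocess() that masks padding in the labels.
# -100 is the ignore_index of the cross-entropy loss inside transformers,
# so padded positions no longer contribute to training.
def preprocess(ex):
    img = ex["image"].convert("RGB")
    ex["pixel_values"] = processor(images=img, return_tensors="pt").pixel_values[0]
    labels = processor.tokenizer(
        ex["text"], truncation=True, padding="max_length", max_length=128
    ).input_ids
    ex["labels"] = [t if t != processor.tokenizer.pad_token_id else -100 for t in labels]
    return ex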
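
For completeness, a minimal inference sketch against the saved checkpoint ("sample.png" is a hypothetical input path):

from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

processor = TrOCRProcessor.from_pretrained("trained_model")
model = VisionEncoderDecoderModel.from_pretrained("trained_model")

image = Image.open("sample.png").convert("RGB")  # hypothetical input image
pixel_values = processor(images=image, return_tensors="pt").pixel_values
generated_ids = model.generate(pixel_values, max_length=128)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])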