MasteredUltraInstinct committed
Commit 32cd99d · verified · 1 Parent(s): 819a63d

Update train.py

Files changed (1):
  1. train.py +43 -61
train.py CHANGED
@@ -1,69 +1,51 @@
+ import os
  from datasets import load_dataset
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel, Seq2SeqTrainer, Seq2SeqTrainingArguments, default_data_collator
 
- # Load the handwritten math dataset (1000 examples)
- ds = load_dataset("Azu/Handwritten-Mathematical-Expression-Convert-LaTeX", split="train[:100]")
-
- # Load processor and model
- processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
- model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
-
- # Preprocess function
- def preprocess(ex):
-     img = ex["image"].convert("RGB")
-     inputs = processor(images=img, return_tensors="pt")
-
-     # Convert label index to actual LaTeX string
-     label_str = ds.features["label"].int2str(ex["label"])
-     labels = processor.tokenizer(
-         label_str,
-         truncation=True,
-         padding="max_length",
-         max_length=128
-     ).input_ids
-
-     ex["pixel_values"] = inputs.pixel_values[0]
-     ex["labels"] = labels
-     return ex
-
- # Apply preprocessing
- ds = ds.map(
-     preprocess,
-     remove_columns=["image", "label"],
-     num_proc=1,
-     load_from_cache_file=False
- )
-
-
- # Model config
- model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
- model.config.pad_token_id = processor.tokenizer.pad_token_id
-
- # Training arguments
- training_args = Seq2SeqTrainingArguments(
-     output_dir="trained_model",
-     per_device_train_batch_size=2,
-     num_train_epochs=1,
-     learning_rate=5e-5,
-     logging_steps=10,
-     save_steps=500,
-     fp16=False,
-     push_to_hub=False,
- )
-
- # Trainer
- trainer = Seq2SeqTrainer(
-     model=model,
-     args=training_args,
-     train_dataset=ds,
-     tokenizer=processor.tokenizer,
-     data_collator=default_data_collator,
- )
-
- # Train and save
- if __name__ == "__main__":
-     print("🚀 Training started")
+ if os.path.exists("trained_model"):
+     print("✅ Model already exists. Skipping training.")
+ else:
+     print("🚀 Starting training...")
+
+     ds = load_dataset("Azu/Handwritten-Mathematical-Expression-Convert-LaTeX", split="train[:100]")
+     processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
+
+     def preprocess(ex):
+         img = ex["image"].convert("RGB")
+         inputs = processor(images=img, return_tensors="pt")
+         labels = processor.tokenizer(ex["label"], truncation=True, padding="max_length", max_length=128).input_ids
+         ex["pixel_values"] = inputs.pixel_values[0]
+         ex["labels"] = labels
+         return ex
+
+     ds = ds.map(preprocess, remove_columns=["image", "label"])
+
+     model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
+     model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
+     model.config.pad_token_id = processor.tokenizer.pad_token_id
+
+     training_args = Seq2SeqTrainingArguments(
+         output_dir="trained_model",
+         per_device_train_batch_size=2,
+         num_train_epochs=1,
+         learning_rate=5e-5,
+         logging_steps=10,
+         save_steps=500,
+         fp16=False,
+         push_to_hub=False,
+     )
+
+     trainer = Seq2SeqTrainer(
+         model=model,
+         args=training_args,
+         train_dataset=ds,
+         tokenizer=processor.tokenizer,
+         data_collator=default_data_collator,
+     )
+
      trainer.train()
      print("✅ Training completed")
+
      model.save_pretrained("trained_model")
      processor.save_pretrained("trained_model")
+     print("✅ Model saved to trained_model/")