MasteredUltraInstinct committed on
Commit
5560a92
Β·
verified Β·
1 Parent(s): 86d3183

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -83
app.py CHANGED
@@ -1,89 +1,93 @@
1
  import os
2
  from datasets import load_dataset
3
- from transformers import (
4
- TrOCRProcessor,
5
- VisionEncoderDecoderModel,
6
- Seq2SeqTrainer,
7
- Seq2SeqTrainingArguments,
8
- default_data_collator,
9
- )
10
 
11
  # Check if model already exists
12
  if os.path.exists("trained_model"):
13
  print("βœ… Model already exists. Skipping training.")
14
- else:
15
- print("πŸš€ Starting training...")
16
-
17
- # Load dataset (only 100 samples for speed)
18
- ds = load_dataset("Azu/Handwritten-Mathematical-Expression-Convert-LaTeX", split="train[:100]")
19
-
20
- processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
21
-
22
- # First, filter out examples without valid labels
23
- def is_valid(example):
24
- label = example.get("label")
25
- return isinstance(label, str) or (isinstance(label, dict) and "latex" in label)
26
-
27
- ds = ds.filter(is_valid)
28
-
29
- # Define preprocessing
30
- def preprocess(example):
31
- img = example["image"].convert("RGB")
32
- inputs = processor(images=img, return_tensors="pt")
33
-
34
- label_data = example["label"]
35
- label_str = label_data["latex"] if isinstance(label_data, dict) else label_data
36
-
37
- labels = processor.tokenizer(
38
- label_str,
39
- truncation=True,
40
- padding="max_length",
41
- max_length=128
42
- ).input_ids
43
-
44
- return {
45
- "pixel_values": inputs.pixel_values[0],
46
- "labels": labels
47
- }
48
-
49
- # Apply preprocessing
50
- ds = ds.map(preprocess, remove_columns=["image", "label"])
51
-
52
- # Final safety check
53
- if len(ds) == 0:
54
- raise RuntimeError("❌ No usable training samples after preprocessing.")
55
-
56
- # Load model
57
- model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
58
- model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
59
- model.config.pad_token_id = processor.tokenizer.pad_token_id
60
-
61
- # Training arguments
62
- training_args = Seq2SeqTrainingArguments(
63
- output_dir="trained_model",
64
- per_device_train_batch_size=2,
65
- num_train_epochs=1,
66
- learning_rate=5e-5,
67
- logging_steps=10,
68
- save_steps=500,
69
- fp16=False, # Required False for CPU-only env
70
- push_to_hub=False,
71
- )
72
-
73
- # Trainer
74
- trainer = Seq2SeqTrainer(
75
- model=model,
76
- args=training_args,
77
- train_dataset=ds,
78
- tokenizer=processor.tokenizer,
79
- data_collator=default_data_collator,
80
- )
81
-
82
- # Start training
83
- trainer.train()
84
- print("βœ… Training completed")
85
-
86
- # Save model + processor
87
- model.save_pretrained("trained_model")
88
- processor.save_pretrained("trained_model")
89
- print("βœ… Model saved to trained_model/")
 
 
 
 
 
 
 
 
 
 
 
import os
import sys

from datasets import load_dataset
from transformers import (
    TrOCRProcessor,
    VisionEncoderDecoderModel,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    default_data_collator,
)
 
 
 
 
 
 
4
 
5
# Skip training entirely when a trained model is already on disk.
if os.path.exists("trained_model"):
    print("✅ Model already exists. Skipping training.")
    # sys.exit is the documented way to stop a script; the bare exit()
    # builtin is injected by the site module and is not guaranteed to
    # exist when the interpreter starts without site initialisation.
    sys.exit(0)

print("🚀 Starting training...")

# Load only 100 samples for faster CPU training.
ds = load_dataset("Azu/Handwritten-Mathematical-Expression-Convert-LaTeX", split="train[:100]")

# DEBUG: inspect a few labels so the label schema (str vs dict) is visible in logs.
print("\n🔍 Sample labels from dataset:")
for i in range(min(5, len(ds))):  # guard against a split smaller than 5 rows
    print(f"{i}: {ds[i]['label']} (type: {type(ds[i]['label'])})")

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
22
+ # Safely extract label string from possible dict or str
23
def safe_get_label(example):
    """Return the LaTeX label string for *example*, or None if unusable.

    Rows in this dataset carry their label either as a plain string or as
    a dict with a "latex" key; any other shape yields None so the caller
    can skip the row.
    """
    label = example.get("label")
    if isinstance(label, str):
        return label
    if isinstance(label, dict):
        # dict.get mirrors the original 'if "latex" in label' check:
        # missing key falls through to None.
        return label.get("latex")
    return None
31
+
32
def preprocess(example):
    """Convert one dataset row into TrOCR training features.

    Returns a dict with "pixel_values" (the processed image tensor) and
    "labels" (fixed-length token ids for the LaTeX string).

    Rows with a missing/empty label return the SAME keys with None values
    instead of an empty dict: Dataset.map needs every example to expose a
    consistent column set (returning {} for some rows breaks the arrow
    schema), and the downstream filter expects a "labels" key to test.
    """
    label_str = safe_get_label(example)
    if not isinstance(label_str, str) or not label_str.strip():
        # Invalid label: emit null columns so the later filter drops this row.
        return {"pixel_values": None, "labels": None}

    # Model expects 3-channel input; dataset images may be grayscale.
    img = example["image"].convert("RGB")
    inputs = processor(images=img, return_tensors="pt")

    # Tokenize the LaTeX string to fixed-length decoder ids.
    labels = processor.tokenizer(
        label_str,
        truncation=True,
        padding="max_length",
        max_length=128,
    ).input_ids

    return {
        "pixel_values": inputs.pixel_values[0],
        "labels": labels,
    }
53
+
54
# Featurize every row, then keep only rows that produced usable labels.
ds = ds.map(preprocess, remove_columns=["image", "label"])
ds = ds.filter(lambda ex: "labels" in ex and ex["labels"] is not None)

print(f"✅ Total usable training samples: {len(ds)}")
if not len(ds):
    raise RuntimeError("❌ No usable training samples after preprocessing.")

# Load the pretrained encoder/decoder and wire up the special token ids
# the decoder needs for seq2seq training.
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id

# One short epoch with a tiny batch; fp16 stays off for the CPU-only env.
train_config = Seq2SeqTrainingArguments(
    output_dir="trained_model",
    per_device_train_batch_size=2,
    num_train_epochs=1,
    learning_rate=5e-5,
    logging_steps=10,
    save_steps=500,
    fp16=False,
    push_to_hub=False,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=train_config,
    train_dataset=ds,
    tokenizer=processor.tokenizer,
    data_collator=default_data_collator,
)

trainer.train()
print("✅ Training completed")

# Persist weights and processor together so inference can reload both.
model.save_pretrained("trained_model")
processor.save_pretrained("trained_model")
print("✅ Model saved to trained_model/")