Spaces:
Runtime error
Runtime error
File size: 2,761 Bytes
32cd99d e5bfb7c 5560a92 86d3183 32cd99d 5560a92 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
import os
from datasets import load_dataset
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, Seq2SeqTrainer, Seq2SeqTrainingArguments, default_data_collator
# Skip training when a finished model is already on disk.
# Look for the config.json that model.save_pretrained() writes, not just the
# directory. "trained_model" is also the Trainer's output_dir, and the Trainer
# may create it (logs/checkpoints) before any model is saved. A run that
# crashed mid-training would then wrongly be treated as complete.
if os.path.exists(os.path.join("trained_model", "config.json")):
    print("✅ Model already exists. Skipping training.")
    # exit() is a convenience added by the site module for interactive use.
    # Raising SystemExit is the portable way to stop a script (status 0).
    raise SystemExit(0)
print("🚀 Starting training...")
# Load only 100 samples for faster CPU training
ds = load_dataset("Azu/Handwritten-Mathematical-Expression-Convert-LaTeX", split="train[:100]")

# DEBUG: inspect a few labels.
# Read only the "label" column, so no image is decoded just to print a label.
# Cap the preview at the dataset size, so a short split cannot raise IndexError.
print("\n🔍 Sample labels from dataset:")
preview_labels = ds.select(range(min(5, len(ds))))["label"]
for i, label in enumerate(preview_labels):
    print(f"{i}: {label} (type: {type(label)})")

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
def safe_get_label(example):
    """Extract the LaTeX target from one dataset row.

    The row's ``"label"`` entry may be a plain string (returned as-is, even if
    empty) or a mapping holding the target under ``"latex"`` (that value is
    returned). Any other shape, including a missing ``"label"``, gives ``None``.
    """
    raw = example.get("label")
    if isinstance(raw, str):
        return raw
    if isinstance(raw, dict) and "latex" in raw:
        return raw["latex"]
    return None
def preprocess(example):
    """Convert one dataset row into TrOCR training features.

    Returns a dict with:
      - ``pixel_values``: the processor's image tensor for this row, with the
        leading batch dimension removed.
      - ``labels``: token ids of the LaTeX string, truncated/padded to 128
        tokens. Padding positions are replaced by -100 so the loss ignores them.

    Rows without a usable (non-blank string) label get ``None`` in both
    columns instead of an empty dict. Every row written by ``Dataset.map``
    then has the same columns, and a later ``labels is not None`` filter can
    drop these rows.
    """
    label_str = safe_get_label(example)
    if not isinstance(label_str, str) or label_str.strip() == "":
        # Keep the same column set as the valid-row case. Returning {} yields
        # rows with differing keys, which Dataset.map's writer may reject or
        # mishandle.
        return {"pixel_values": None, "labels": None}

    # Normalize to 3-channel RGB; the source image mode is not guaranteed here.
    img = example["image"].convert("RGB")
    inputs = processor(images=img, return_tensors="pt")

    tokenizer = processor.tokenizer
    token_ids = tokenizer(
        label_str,
        truncation=True,
        padding="max_length",
        max_length=128,
    ).input_ids
    # Padding must not count toward the cross-entropy loss. -100 is the
    # ignore index used by Transformers' seq2seq losses.
    pad_id = tokenizer.pad_token_id
    labels = [tok if tok != pad_id else -100 for tok in token_ids]

    return {
        "pixel_values": inputs.pixel_values[0],
        "labels": labels,
    }
# Preprocess and filter.
# Drop rows without a usable label *before* mapping. preprocess() then never
# sees them, and no image is processed for a row that would be discarded.
# The predicate is the same validity test preprocess() applies.
def _has_usable_label(label):
    """Return True when ``label`` yields a non-blank string via safe_get_label."""
    label_str = safe_get_label({"label": label})
    return isinstance(label_str, str) and label_str.strip() != ""


ds = ds.filter(_has_usable_label, input_columns=["label"])
ds = ds.map(preprocess, remove_columns=["image", "label"])

# Check number of remaining examples
print(f"✅ Total usable training samples: {len(ds)}")
if len(ds) == 0:
    raise RuntimeError("❌ No usable training samples after preprocessing.")
# Model setup: load the same pretrained TrOCR checkpoint the processor came
# from, then record the tokenizer's start (CLS) and padding ids on the model
# config. These are presumably what the model uses to build decoder inputs
# from `labels`; see the VisionEncoderDecoderModel docs.
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
tokenizer = processor.tokenizer
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id
# Single-epoch, fp32 fine-tuning run writing into trained_model/.
# The split holds at most 100 rows at batch size 2, i.e. at most 50 steps,
# so save_steps=500 never produces an intermediate checkpoint.
training_config = dict(
    output_dir="trained_model",
    per_device_train_batch_size=2,
    num_train_epochs=1,
    learning_rate=5e-5,
    logging_steps=10,
    save_steps=500,
    fp16=False,
    push_to_hub=False,
)
training_args = Seq2SeqTrainingArguments(**training_config)

# NOTE(review): newer transformers releases rename `tokenizer=` to
# `processing_class=`; confirm against the installed version.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=ds,
    tokenizer=processor.tokenizer,
    data_collator=default_data_collator,
)
trainer.train()
print("✅ Training completed")
# Save the fine-tuned model and its processor into the same directory, so
# both can be reloaded from it with from_pretrained().
model.save_pretrained("trained_model")
processor.save_pretrained("trained_model")
print("✅ Model saved to trained_model/")
|