import os

from datasets import load_dataset
from transformers import (
    TrOCRProcessor,
    VisionEncoderDecoderModel,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    default_data_collator,
)
# Skip training if a saved model already exists
if os.path.exists("trained_model"):
    print("✅ Model already exists. Skipping training.")
    raise SystemExit
print("π Starting training...") | |
# Load only 100 samples for faster CPU training
ds = load_dataset("Azu/Handwritten-Mathematical-Expression-Convert-LaTeX", split="train[:100]")
# DEBUG: Inspect a few labels
print("\n🔍 Sample labels from dataset:")
for i in range(5):
    print(f"{i}: {ds[i]['label']} (type: {type(ds[i]['label'])})")
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
# Safely extract label string from possible dict or str
def safe_get_label(example):
    label = example.get("label")
    if isinstance(label, dict) and "latex" in label:
        return label["latex"]
    elif isinstance(label, str):
        return label
    else:
        return None
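# Illustrative label shapes this helper handles (inferred from the DEBUG
# printout above, not guaranteed by the dataset card):
#   {"latex": "\\frac{a}{b}"}  -> "\\frac{a}{b}"
#   "\\frac{a}{b}"             -> returned as-is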
def preprocess(example):
    label_str = safe_get_label(example)
    # Keep the column set identical for every row: returning an empty dict
    # from map() breaks the dataset schema, so mark bad rows with None and
    # drop them in the filter() below.
    if not isinstance(label_str, str) or label_str.strip() == "":
        return {"pixel_values": None, "labels": None}
    # Convert image to RGB before feature extraction
    img = example["image"].convert("RGB")
    inputs = processor(images=img, return_tensors="pt")
    # Tokenize the LaTeX label to a fixed length
    labels = processor.tokenizer(
        label_str,
        truncation=True,
        padding="max_length",
        max_length=128,
    ).input_ids
    # Mask padding with -100 so it is ignored by the cross-entropy loss
    labels = [tok if tok != processor.tokenizer.pad_token_id else -100 for tok in labels]
    return {
        "pixel_values": inputs.pixel_values[0],
        "labels": labels,
    }
# Preprocess, then drop the rows whose labels could not be parsed
ds = ds.map(preprocess, remove_columns=["image", "label"])
ds = ds.filter(lambda ex: ex["labels"] is not None)
# Check number of remaining examples
print(f"✅ Total usable training samples: {len(ds)}")
if len(ds) == 0:
    raise RuntimeError("❌ No usable training samples after preprocessing.")
# Model setup
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
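# VisionEncoderDecoderModel raises an error at train/generation time if
# decoder_start_token_id is unset; TrOCR's decoder uses a RoBERTa-style
# tokenizer, so its CLS token serves as the start token here.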
training_args = Seq2SeqTrainingArguments(
    output_dir="trained_model",
    per_device_train_batch_size=2,  # small batch to fit CPU memory
    num_train_epochs=1,
    learning_rate=5e-5,
    logging_steps=10,
    save_steps=500,
    fp16=False,  # no mixed precision on CPU
    push_to_hub=False,
)
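# With 100 samples, batch size 2, and one epoch, training runs only 50 steps,
# so the save_steps=500 checkpoint never triggers; the explicit
# save_pretrained() call at the end is what writes the model to disk.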
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=ds,
    tokenizer=processor.tokenizer,
    data_collator=default_data_collator,
)
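# Note: recent transformers releases deprecate Seq2SeqTrainer's `tokenizer=`
# argument in favor of `processing_class=`; switch if you see a warning.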
trainer.train()
print("✅ Training completed")
# Save model
model.save_pretrained("trained_model")
processor.save_pretrained("trained_model")
print("✅ Model saved to trained_model/")