import os

from datasets import load_dataset
from transformers import (
    TrOCRProcessor,
    VisionEncoderDecoderModel,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    default_data_collator,
)

# Check if the model already exists
if os.path.exists("trained_model"):
    print("āœ… Model already exists. Skipping training.")
    exit()

print("šŸš€ Starting training...")

# Load only 100 samples for faster CPU training
ds = load_dataset(
    "Azu/Handwritten-Mathematical-Expression-Convert-LaTeX",
    split="train[:100]",
)

# DEBUG: Inspect a few labels
print("\nšŸ” Sample labels from dataset:")
for i in range(5):
    print(f"{i}: {ds[i]['label']} (type: {type(ds[i]['label'])})")

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")

# Safely extract the label string, which may be stored as a dict or a plain str
def safe_get_label(example):
    label = example.get("label")
    if isinstance(label, dict) and "latex" in label:
        return label["latex"]
    elif isinstance(label, str):
        return label
    else:
        return None

def preprocess(example):
    label_str = safe_get_label(example)
    if not isinstance(label_str, str) or label_str.strip() == "":
        # Return None columns (instead of an empty dict) so every row keeps the
        # same schema and invalid examples can be filtered out afterwards
        return {"pixel_values": None, "labels": None}

    # Convert image to RGB
    img = example["image"].convert("RGB")
    inputs = processor(images=img, return_tensors="pt")

    # Tokenize the label
    labels = processor.tokenizer(
        label_str,
        truncation=True,
        padding="max_length",
        max_length=128,
    ).input_ids

    # Replace padding token ids with -100 so they are ignored by the loss
    labels = [
        token if token != processor.tokenizer.pad_token_id else -100
        for token in labels
    ]

    return {
        "pixel_values": inputs.pixel_values[0],
        "labels": labels,
    }

# Preprocess and filter out examples with unusable labels
ds = ds.map(preprocess, remove_columns=["image", "label"])
ds = ds.filter(lambda ex: ex["labels"] is not None)

# Check the number of remaining examples
print(f"āœ… Total usable training samples: {len(ds)}")
if len(ds) == 0:
    raise RuntimeError("āŒ No usable training samples after preprocessing.")

# Model setup
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id

training_args = Seq2SeqTrainingArguments(
    output_dir="trained_model",
    per_device_train_batch_size=2,
    num_train_epochs=1,
    learning_rate=5e-5,
    logging_steps=10,
    save_steps=500,
    fp16=False,
    push_to_hub=False,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=ds,
    tokenizer=processor.tokenizer,
    data_collator=default_data_collator,
)

trainer.train()
print("āœ… Training completed")

# Save the model and processor
model.save_pretrained("trained_model")
processor.save_pretrained("trained_model")
print("āœ… Model saved to trained_model/")
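
# --- Optional sanity check (a minimal sketch, not part of the training steps above):
# reload the model and processor just saved to "trained_model" and decode one
# preprocessed sample. Assumes the in-memory `ds` with `pixel_values` and uses
# default greedy generate() settings; the `check_*` names are illustrative only.
import torch

check_processor = TrOCRProcessor.from_pretrained("trained_model")
check_model = VisionEncoderDecoderModel.from_pretrained("trained_model")

# Re-tensorize one preprocessed example and add a batch dimension
sample = torch.tensor(ds[0]["pixel_values"]).unsqueeze(0)
generated_ids = check_model.generate(sample, max_length=128)
prediction = check_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(f"šŸ”Ž Sample prediction: {prediction}")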