"""Fine-tune the TrOCR handwritten checkpoint on a handwritten-math-to-LaTeX dataset.

Running this script downloads the dataset and the pretrained checkpoint,
trains for one epoch on the first 1000 training examples, and writes the
resulting model and processor to ``OUTPUT_DIR``.
"""

from datasets import load_dataset
from transformers import (
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    TrOCRProcessor,
    VisionEncoderDecoderModel,
    default_data_collator,
)

# Checkpoint used for both the processor and the model weights.
MODEL_NAME = "microsoft/trocr-base-handwritten"
# Directory that receives trainer checkpoints and the final model/processor.
OUTPUT_DIR = "trained_model"
# Target sequences are truncated/padded to exactly this many tokens.
MAX_LABEL_LENGTH = 128
# Label value that the cross-entropy loss skips (PyTorch's default ignore_index).
IGNORE_INDEX = -100

# Load the handwritten math dataset (only the first 1000 training examples).
ds = load_dataset(
    "Azu/Handwritten-Mathematical-Expression-Convert-LaTeX",
    split="train[:1000]",
)
processor = TrOCRProcessor.from_pretrained(MODEL_NAME)


def preprocess(ex):
    """Turn one raw dataset example into model inputs.

    Args:
        ex: A single dataset row. NOTE(review): assumes the row has an
            "image" column holding an object with ``.convert("RGB")``
            (presumably a PIL image) and a "text" column holding the target
            string -- confirm against the dataset card.

    Returns:
        The same row with two fields added: "pixel_values" (the processed
        image for this single example) and "labels" (a list of
        ``MAX_LABEL_LENGTH`` token ids in which padding positions are
        replaced by ``IGNORE_INDEX``).
    """
    img = ex["image"].convert("RGB")
    inputs = processor(images=img, return_tensors="pt")
    token_ids = processor.tokenizer(
        ex["text"],
        truncation=True,
        padding="max_length",
        max_length=MAX_LABEL_LENGTH,
    ).input_ids

    # Mask padding so the loss is computed on real target tokens only;
    # otherwise the model is also trained to predict the padding itself.
    pad_id = processor.tokenizer.pad_token_id
    labels = [tok if tok != pad_id else IGNORE_INDEX for tok in token_ids]

    # The processor returns a batch of one; keep just that single example.
    ex["pixel_values"] = inputs.pixel_values[0]
    ex["labels"] = labels
    return ex


# Drop the raw columns so only the model-ready fields reach the collator.
ds = ds.map(preprocess, remove_columns=["image", "text"])

model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME)
# Both ids are needed so the model can build decoder inputs from the labels.
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id

training_args = Seq2SeqTrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=2,
    num_train_epochs=1,
    learning_rate=5e-5,
    logging_steps=10,
    save_steps=500,
    fp16=False,
    push_to_hub=False,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=ds,
    tokenizer=processor.tokenizer,
    data_collator=default_data_collator,
)

trainer.train()

# Save the processor next to the model so the directory can be reloaded as a unit.
model.save_pretrained(OUTPUT_DIR)
processor.save_pretrained(OUTPUT_DIR)