# Spaces: Runtime error
from datasets import load_dataset
from transformers import (
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    TrOCRProcessor,
    VisionEncoderDecoderModel,
    default_data_collator,
)
# Load the first 1000 rows of the train split of the handwritten-math
# dataset.
ds = load_dataset(
    path="Azu/Handwritten-Mathematical-Expression-Convert-LaTeX",
    split="train[:1000]",
)

# Processor published with the TrOCR handwritten base checkpoint; it is
# called on images and its `.tokenizer` is used on the target text in
# `preprocess`.
processor = TrOCRProcessor.from_pretrained(
    "microsoft/trocr-base-handwritten"
)
def preprocess(ex):
    """Turn one raw dataset example into TrOCR training features.

    Reads ``ex["image"]`` and ``ex["text"]``, adds two keys to the same
    dict and returns it:

    * ``pixel_values`` -- the first item of the batch the processor
      returns for the RGB-converted image.
    * ``labels`` -- token ids of the text, truncated/padded to 128
      entries, with every padding position replaced by ``-100``.

    Relies on the module-level ``processor``.
    NOTE(review): assumes ``ex["image"]`` is a PIL image (it must provide
    ``.convert``) and ``ex["text"]`` is the target string -- confirm
    against the dataset's actual column names.
    """
    tokenizer = processor.tokenizer

    img = ex["image"].convert("RGB")
    inputs = processor(images=img, return_tensors="pt")

    token_ids = tokenizer(
        ex["text"],
        truncation=True,
        padding="max_length",
        max_length=128,
    ).input_ids

    # Padding must not contribute to the loss: -100 is the index the
    # cross-entropy loss ignores, so pad positions are masked out here
    # instead of being trained on as if they were real targets.
    pad_id = tokenizer.pad_token_id
    labels = [tok if tok != pad_id else -100 for tok in token_ids]

    ex["pixel_values"] = inputs.pixel_values[0]
    ex["labels"] = labels
    return ex
# Run `preprocess` over every example and drop the raw columns it consumed.
ds = ds.map(function=preprocess, remove_columns=["image", "text"])

# Start from the pretrained TrOCR handwritten checkpoint.
model = VisionEncoderDecoderModel.from_pretrained(
    "microsoft/trocr-base-handwritten"
)

# Take the special-token ids in the model config from the tokenizer.
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
# Training configuration: one epoch, batch size 2 per device, full
# precision, nothing pushed to the Hub.
training_args = Seq2SeqTrainingArguments(
    # Where outputs go and how often progress is recorded.
    output_dir="trained_model",
    save_steps=500,
    logging_steps=10,
    # Optimisation settings.
    num_train_epochs=1,
    per_device_train_batch_size=2,
    learning_rate=5e-5,
    # Features that stay switched off.
    fp16=False,
    push_to_hub=False,
)
# Wire the model, the preprocessed dataset and the arguments into a trainer.
# NOTE(review): newer transformers releases may prefer `processing_class`
# over `tokenizer` here -- confirm against the installed version.
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    data_collator=default_data_collator,
    tokenizer=processor.tokenizer,
    train_dataset=ds,
)

# Fine-tune.
trainer.train()

# Save the model first, then the processor, into the same directory.
model.save_pretrained("trained_model")
processor.save_pretrained("trained_model")