from datasets import load_dataset
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, Seq2SeqTrainer, Seq2SeqTrainingArguments, default_data_collator
# Load the first 1,000 training examples of the handwritten math expression dataset
ds = load_dataset("Azu/Handwritten-Mathematical-Expression-Convert-LaTeX", split="train[:1000]")
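# TrOCRProcessor bundles the image processor and the tokenizer used below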
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
def preprocess(ex):
    # Convert the image to RGB and tokenize the LaTeX transcription
    img = ex["image"].convert("RGB")
    inputs = processor(images=img, return_tensors="pt")
    labels = processor.tokenizer(
        ex["text"], truncation=True, padding="max_length", max_length=128
    ).input_ids
    # Replace padding token ids with -100 so padding is ignored by the loss
    labels = [tok if tok != processor.tokenizer.pad_token_id else -100 for tok in labels]
    ex["pixel_values"] = inputs.pixel_values[0]
    ex["labels"] = labels
    return ex
ds = ds.map(preprocess, remove_columns=["image", "text"])
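# Load the pretrained TrOCR model and tell it which tokens to use for decoding and padding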
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
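# Minimal training setup: one epoch, small batches, no mixed precision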
training_args = Seq2SeqTrainingArguments(
    output_dir="trained_model",
    per_device_train_batch_size=2,
    num_train_epochs=1,
    learning_rate=5e-5,
    logging_steps=10,
    save_steps=500,
    fp16=False,
    push_to_hub=False,
)
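# Seq2SeqTrainer handles batching, the training loop, and checkpoint saving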
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=ds,
    tokenizer=processor.tokenizer,
    data_collator=default_data_collator,
)
trainer.train()
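# Save the fine-tuned model and processor so they can be reloaded with from_pretrained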
model.save_pretrained("trained_model")
processor.save_pretrained("trained_model")