{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Login to HuggingFace (just login once)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from huggingface_hub import interpreter_login\n", "interpreter_login()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Collect Menu Image Datasets\n", "- Use `metadata.jsonl` to label the images's ground truth. You can visit [here](https://github.com/ryanlinjui/menu-text-detection/tree/main/examples) to see the examples.\n", "- After finishing, push to HuggingFace Datasets.\n", "- For labeling:\n", " - [Google AI Studio](https://aistudio.google.com) or [OpenAI ChatGPT](https://chatgpt.com).\n", " - Use function calling by API. Start the gradio app locally or visit [here](https://huggingface.co/spaces/ryanlinjui/menu-text-detection).\n", "\n", "### Menu Type\n", "- **h**: horizontal menu\n", "- **v**: vertical menu\n", "- **d**: document-style menu\n", "- **s**: in-scene menu (non-document style)\n", "- **i**: irregular menu (menu with irregular text layout)\n", "\n", "> Please see the [examples](https://github.com/ryanlinjui/menu-text-detection/tree/main/examples) for more details." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "import json\n", "\n", "import numpy as np\n", "from PIL import Image\n", "from pillow_heif import register_heif_opener\n", "\n", "from menu.llm import (\n", " GeminiAPI,\n", " OpenAIAPI\n", ")\n", "\n", "IMAGE_DIR = \"datasets/images\" # set your image directory here\n", "SELECTED_MODEL = \"gemini-2.5-flash\" # set model name here, refer MODEL_LIST from app.py for more\n", "API_TOKEN = \"\" # set your API token here\n", "SELECTED_FUNCTION = GeminiAPI # set \"GeminiAPI\" or \"OpenAIAPI\"\n", "\n", "register_heif_opener()\n", "\n", "for file in os.listdir(IMAGE_DIR):\n", " print(f\"Processing image: {file}\")\n", " try:\n", " image = np.array(Image.open(os.path.join(IMAGE_DIR, file)))\n", " data = {\n", " \"file_name\": file,\n", " \"menu\": SELECTED_FUNCTION.call(image, SELECTED_MODEL, API_TOKEN)\n", " }\n", " with open(os.path.join(IMAGE_DIR, \"metadata.jsonl\"), \"a\", encoding=\"utf-8\") as metaf:\n", " metaf.write(json.dumps(data, ensure_ascii=False, sort_keys=True) + \"\\n\")\n", " except Exception as e:\n", " print(f\"Skipping invalid image '{file}': {e}\")\n", " continue" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Push Datasets to HuggingFace" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from datasets import load_dataset\n", "\n", "dataset = load_dataset(path=\"datasets/menu-zh-TW\") # load dataset from the local directory including the metadata.jsonl, images files.\n", "dataset.push_to_hub(repo_id=\"ryanlinjui/menu-zh-TW\") # push to the huggingface dataset hub" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Prepare the dataset for training" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from menu.utils import split_dataset\n", "from datasets import load_dataset\n", "\n", "dataset = load_dataset(path=\"ryanlinjui/menu-zh-TW\") # set your dataset repo id for training\n", "dataset = split_dataset(dataset[\"train\"], train=0.8, validation=0.1, test=0.1, seed=42) # (optional) use it if your dataset is not split into train/validation/test\n", "print(f\"Dataset split: {len(dataset['train'])} train, {len(dataset['validation'])} validation, {len(dataset['test'])} test\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Fine-tune Donut Model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import logging\n", "from menu.donut import DonutTrainer\n", "\n", "logging.getLogger(\"transformers\").setLevel(logging.ERROR) # filter output message from transformers\n", "\n", "DonutTrainer.train(\n", " dataset=dataset,\n", " pretrained_model_repo_id=\"naver-clova-ix/donut-base\", # set your pretrained model repo id for fine-tuning\n", " ground_truth_key=\"menu\", # set your ground truth key for training\n", " huggingface_model_id=\"ryanlinjui/donut-base-finetuned-menu\", # set your huggingface model repo id for saving / pushing to the hub\n", " epochs=15, # set your training epochs\n", " train_batch_size=8, # set your training batch size\n", " val_batch_size=1, # set your validation batch size\n", " learning_rate=3e-5, # set your learning rate\n", " val_check_interval=0.5, # how many times we want to validate during an epoch\n", " check_val_every_n_epoch=1, # how many epochs we want to validate\n", " gradient_clip_val=1.0, # gradient clipping value for training stability\n", " num_training_samples_per_epoch=198, # set num_training_samples_per_epoch = training set size\n", " num_nodes=1, # number of nodes for distributed training\n", " warmup_steps=75 # number of warmup steps for learning rate scheduler, 198/8*30/10, 10%\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Evaluate Donut Model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import json\n", "from datasets import load_dataset\n", "\n", "from menu.utils import split_dataset\n", "from menu.donut import DonutFinetuned\n", "\n", "dataset = load_dataset(\"ryanlinjui/menu-zh-TW\")\n", "dataset = split_dataset(dataset[\"train\"], train=0.8, validation=0.1, test=0.1, seed=42) # (optional) use it if your dataset is not split into train/validation/test\n", "donut_finetuned = DonutFinetuned(pretrained_model_repo_id=\"ryanlinjui/donut-base-finetuned-menu\")\n", "scores, output_list = donut_finetuned.evaluate(dataset=dataset[\"test\"], ground_truth_key=\"menu\")\n", "\n", "print(\"Evaluation scores:\")\n", "for key, value in scores.items():\n", " print(f\"{key}: {value}\")\n", "\n", "print(\"\\nSample outputs:\")\n", "for output in output_list[:5]:\n", " print(json.dumps(output, ensure_ascii=False, indent=4))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Test Donut Model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from PIL import Image\n", "from menu.donut import DonutFinetuned\n", "\n", "image = Image.open(\"./examples/menu-hd.jpg\")\n", "\n", "donut_finetuned = DonutFinetuned(pretrained_model_repo_id=\"ryanlinjui/donut-base-finetuned-menu\")\n", "outputs = donut_finetuned.predict(image=image)\n", "print(outputs)" ] } ], "metadata": { "kernelspec": { "display_name": "menu-text-detection", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.12" } }, "nbformat": 4, "nbformat_minor": 2 }