github-actions[bot] committed
Commit 6bd37dd · 0 Parent(s):

Sync from https://github.com/ryanlinjui/menu-text-detection
.checkpoints/.gitkeep ADDED
File without changes
.env.example ADDED
@@ -0,0 +1,3 @@
+ HUGGINGFACE_TOKEN="HUGGINGFACE_TOKEN"
+ GEMINI_API_TOKEN="GEMINI_API_TOKEN"
+ OPENAI_API_TOKEN="OPENAI_API_TOKEN"
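For context, a minimal sketch of how these placeholders are read at runtime; it mirrors the `load_dotenv` / `os.getenv` calls in `app.py`, and the step of copying `.env.example` to `.env` is an assumption about the intended workflow:

```python
# Sketch only: copy .env.example to .env and fill in real tokens first (assumed workflow).
import os
from dotenv import load_dotenv

load_dotenv(override=True)                        # same call app.py uses to read .env
gemini_token = os.getenv("GEMINI_API_TOKEN", "")  # falls back to "" when unset
openai_token = os.getenv("OPENAI_API_TOKEN", "")
hf_token = os.getenv("HUGGINGFACE_TOKEN", "")
```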
.github/workflows/sync.yml ADDED
@@ -0,0 +1,25 @@
+ name: Sync to Hugging Face Spaces
+
+ on:
+   push:
+     branches:
+       - main
+ jobs:
+   sync:
+     name: Sync
+     runs-on: ubuntu-latest
+     steps:
+       - name: Checkout Repository
+         uses: actions/checkout@v4
+
+       - name: Remove bad files
+         run: rm -rf examples assets
+
+       - name: Sync to Hugging Face Spaces
+         uses: JacobLinCool/huggingface-sync@v1
+         with:
+           github: ${{ secrets.GITHUB_TOKEN }}
+           user: ryanlinjui # Hugging Face username or organization name
+           space: menu-text-detection # Hugging Face space name
+           token: ${{ secrets.HF_TOKEN }} # Hugging Face token
+           python_version: 3.11 # Python version
.gitignore ADDED
@@ -0,0 +1,24 @@
+ # mac
+ .DS_Store
+
+ # cache
+ __pycache__
+
+ # datasets
+ datasets
+
+ # papers
+ docs/papers
+
+ # uv
+ .venv
+
+ # gradio
+ .gradio
+
+ # env
+ .env
+
+ # checkpoint
+ .checkpoints/*
+ !.checkpoints/.gitkeep
.python-version ADDED
@@ -0,0 +1 @@
+ 3.11
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 RyanLin
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,65 @@
+ ---
+ title: menu text detection
+ emoji: 🦄
+ colorFrom: indigo
+ colorTo: pink
+ sdk: gradio
+ python_version: 3.11
+ short_description: Extract structured menu information from images into JSON...
+ tags: [ "donut","fine-tuning","image-to-text","transformer" ]
+ ---
+
+ # Menu Text Detection System
+
+ Extract structured menu information from images into JSON using a fine-tuned E2E model or an LLM.
+
+ [![Gradio Space Demo](https://img.shields.io/badge/GradioSpace-Demo-important?logo=huggingface)](https://huggingface.co/spaces/ryanlinjui/menu-text-detection)
+ [![Hugging Face Models & Datasets](https://img.shields.io/badge/HuggingFace-Models_&_Datasets-important?logo=huggingface)](https://huggingface.co/collections/ryanlinjui/menu-text-detection-670ccf527626bb004bbfb39b)
+
+ https://github.com/user-attachments/assets/80e5d54c-f2c8-4593-ad9b-499e5b71d8f6
+
+ ## 🚀 Features
+ ### Overview
+ Currently supports extracting the following information from menu images:
+
+ - **Restaurant Name**
+ - **Business Hours**
+ - **Address**
+ - **Phone Number**
+ - **Dish Information**
+   - Name
+   - Price
+
+ > For the JSON schema, see the [tools directory](./tools).
+
+ ### Supported Methods to Extract Menu Information
+ #### Fine-tuned E2E Model and Training Metrics
+ - [**Donut (Document Parsing Task)**](https://huggingface.co/ryanlinjui/donut-base-finetuned-menu) - Base model by [Clova AI (ECCV ’22)](https://github.com/clovaai/donut)
+
+ #### LLM Function Calling
+ - Google Gemini API
+ - OpenAI GPT API
+
+ ## 💻 Training / Fine-Tuning
+ ### Setup
+ Use [uv](https://github.com/astral-sh/uv) to set up the development environment:
+
+ ```bash
+ uv sync
+ ```
+
+ > Or fall back to `pip install -r requirements.txt` if `uv` gives you any trouble.
+
+ ### Training Script (Dataset Collection, Fine-Tuning)
+ Please refer to [`train.ipynb`](./train.ipynb). Use Jupyter Notebook for training:
+
+ ```bash
+ uv run jupyter-notebook
+ ```
+
+ > For VSCode users, please install the Jupyter extension, then select `.venv/bin/python` as your kernel.
+
+ ### Run Demo Locally
+ ```bash
+ uv run python app.py
+ ```
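The fine-tuned model can also be called directly from Python without the Gradio UI; this mirrors the "Test Donut Model" cell in `train.ipynb` (the example image ships in the GitHub repository's `examples/` directory, which the sync workflow strips from the Space):

```python
from PIL import Image
from menu.donut import DonutFinetuned

image = Image.open("./examples/menu-hd.jpg")  # example menu image from the GitHub repo

donut_finetuned = DonutFinetuned(pretrained_model_repo_id="ryanlinjui/donut-base-finetuned-menu")
outputs = donut_finetuned.predict(image=image)  # returns a dict following the menu JSON schema
print(outputs)
```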
app.py ADDED
@@ -0,0 +1,157 @@
1
+ import os
2
+ import json
3
+
4
+ import gradio as gr
5
+ from PIL import Image
6
+ from dotenv import load_dotenv
7
+ from pillow_heif import register_heif_opener
8
+
9
+ from menu.llm import (
10
+ GeminiAPI,
11
+ OpenAIAPI
12
+ )
13
+ from menu.donut import DonutFinetuned
14
+
15
+ register_heif_opener()
16
+ load_dotenv(override=True)
17
+ GEMINI_API_TOKEN = os.getenv("GEMINI_API_TOKEN", "")
18
+ OPENAI_API_TOKEN = os.getenv("OPENAI_API_TOKEN", "")
19
+
20
+ SOURCE_CODE_GH_URL = "https://github.com/ryanlinjui/menu-text-detection"
21
+ BADGE_URL = "https://img.shields.io/badge/GitHub_Code-Click_Here!!-default?logo=github"
22
+
23
+ GITHUB_RAW_URL = "https://raw.githubusercontent.com/ryanlinjui/menu-text-detection/main"
24
+ EXAMPLE_IMAGE_LIST = [
25
+ f"{GITHUB_RAW_URL}/examples/menu-hd.jpg",
26
+ f"{GITHUB_RAW_URL}/examples/menu-vs.jpg",
27
+ f"{GITHUB_RAW_URL}/examples/menu-si.jpg"
28
+ ]
29
+ FINETUNED_MODEL_LIST = [
30
+ "Donut (Document Parsing Task) Fine-tuned Model"
31
+ ]
32
+ LLM_MODEL_LIST = [
33
+ "gemini-2.5-pro",
34
+ "gemini-2.5-flash",
35
+ "gemini-2.0-flash",
36
+ "gpt-4.1",
37
+ "gpt-4o",
38
+ "o4-mini"
39
+ ]
40
+
41
+ donut_finetuned = DonutFinetuned("ryanlinjui/donut-base-finetuned-menu")
42
+
43
+ def handle(image: Image.Image, model: str, api_token: str) -> str:
44
+ if image is None:
45
+ raise gr.Error("Please upload an image first.")
46
+
47
+ if model == FINETUNED_MODEL_LIST[0]:
48
+ result = donut_finetuned.predict(image)
49
+
50
+ elif model in LLM_MODEL_LIST:
51
+ if len(api_token) < 10:
52
+ raise gr.Error(f"Please provide a valid token for {model}.")
53
+ try:
54
+ if model in LLM_MODEL_LIST[:3]:
55
+ result = GeminiAPI.call(image, model, api_token)
56
+ else:
57
+ result = OpenAIAPI.call(image, model, api_token)
58
+ except Exception as e:
59
+ raise gr.Error(f"Failed to process with API model {model}: {str(e)}")
60
+ else:
61
+ raise gr.Error("Invalid model selection. Please choose a valid model.")
62
+
63
+ return json.dumps(result, indent=4, ensure_ascii=False, sort_keys=True)
64
+
65
+ def UserInterface() -> gr.Blocks:
66
+ with gr.Blocks(
67
+ delete_cache=(86400, 86400),
68
+ css="""
69
+ .image-panel {
70
+ display: flex;
71
+ flex-direction: column;
72
+ height: 600px;
73
+ }
74
+ .image-panel img {
75
+ object-fit: contain;
76
+ max-height: 600px;
77
+ max-width: 600px;
78
+ width: 100%;
79
+ }
80
+ .large-text textarea {
81
+ font-size: 20px !important;
82
+ height: 600px !important;
83
+ width: 100% !important;
84
+ }
85
+ """
86
+ ) as gradio_interface:
87
+ gr.HTML(f'<a href="{SOURCE_CODE_GH_URL}"><img src="{BADGE_URL}" alt="GitHub Code"/></a>')
88
+ gr.Markdown("# Menu Text Detection")
89
+
90
+ with gr.Row():
91
+ with gr.Column(scale=1, min_width=500):
92
+ gr.Markdown("## 📷 Menu Image")
93
+ menu_image = gr.Image(
94
+ type="pil",
95
+ label="Input menu image",
96
+ elem_classes="image-panel"
97
+ )
98
+
99
+ gr.Markdown("## 🤖 Model Selection")
100
+ model_choice_dropdown = gr.Dropdown(
101
+ choices=FINETUNED_MODEL_LIST + LLM_MODEL_LIST,
102
+ value=FINETUNED_MODEL_LIST[0],
103
+ label="Select Text Detection Model"
104
+ )
105
+
106
+ api_token_textbox = gr.Textbox(
107
+ label="API Token",
108
+ placeholder="Enter your API token here...",
109
+ type="password",
110
+ visible=False
111
+ )
112
+
113
+ generate_button = gr.Button("Generate Menu Information", variant="primary")
114
+
115
+ gr.Examples(
116
+ examples=EXAMPLE_IMAGE_LIST,
117
+ inputs=menu_image,
118
+ label="Example Menu Images"
119
+ )
120
+
121
+ with gr.Column(scale=1):
122
+ gr.Markdown("## 🍽️ Menu Info")
123
+ menu_json_textbox = gr.Textbox(
124
+ label="Ouput JSON",
125
+ interactive=True,
126
+ text_align="left",
127
+ elem_classes="large-text"
128
+ )
129
+
130
+ def update_token_visibility(choice):
131
+ if choice in LLM_MODEL_LIST:
132
+ current_token = ""
133
+ if choice in LLM_MODEL_LIST[:3]:
134
+ current_token = GEMINI_API_TOKEN
135
+ else:
136
+ current_token = OPENAI_API_TOKEN
137
+ return gr.update(visible=True, value=current_token)
138
+ else:
139
+ return gr.update(visible=False)
140
+
141
+ model_choice_dropdown.change(
142
+ fn=update_token_visibility,
143
+ inputs=model_choice_dropdown,
144
+ outputs=api_token_textbox
145
+ )
146
+
147
+ generate_button.click(
148
+ fn=handle,
149
+ inputs=[menu_image, model_choice_dropdown, api_token_textbox],
150
+ outputs=menu_json_textbox
151
+ )
152
+
153
+ return gradio_interface
154
+
155
+ if __name__ == "__main__":
156
+ demo = UserInterface()
157
+ demo.launch()
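For a quick smoke test, `handle()` can be driven without launching the UI. A minimal sketch, assuming the fine-tuned checkpoint can be downloaded when `app` is imported and that a local menu image is available:

```python
# Hypothetical smoke test; run from the repository root.
from PIL import Image
from app import handle, FINETUNED_MODEL_LIST

image = Image.open("examples/menu-hd.jpg")              # any local menu photo works here
json_text = handle(image, FINETUNED_MODEL_LIST[0], "")  # the Donut branch ignores the API token
print(json_text)
```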
menu/donut.py ADDED
@@ -0,0 +1,472 @@
1
+ """
2
+ This file is modified from the HuggingFace transformers tutorial script for fine-tuning Donut on a custom dataset.
3
+ It has been adapted from the original `.ipynb` notebook into a module for better reusability and maintainability.
4
+ Reference: https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Donut/CORD/Fine_tune_Donut_on_a_custom_dataset_(CORD)_with_PyTorch_Lightning.ipynb
5
+ """
6
+
7
+ import re
8
+ import random
9
+ from typing import Any, List, Tuple, Dict
10
+
11
+ import torch
12
+ import numpy as np
13
+ from PIL import Image
14
+ from tqdm.auto import tqdm
15
+ from nltk import edit_distance
16
+ import pytorch_lightning as pl
17
+ from datasets import DatasetDict
18
+ from donut import JSONParseEvaluator
19
+ from huggingface_hub import upload_folder
20
+ from pillow_heif import register_heif_opener
21
+ from pytorch_lightning.callbacks import Callback
22
+ from pytorch_lightning.loggers import TensorBoardLogger
23
+ from torch.utils.data import (
24
+ Dataset,
25
+ DataLoader
26
+ )
27
+ from transformers import (
28
+ DonutProcessor,
29
+ VisionEncoderDecoderModel,
30
+ VisionEncoderDecoderConfig
31
+ )
32
+
33
+ TASK_PROMPT_NAME = "<s_menu-text-detection>"
34
+ register_heif_opener()
35
+
36
+ class DonutFinetuned:
37
+ def __init__(self, pretrained_model_repo_id: str = "ryanlinjui/donut-test"):
38
+ self.device = (
39
+ "cuda"
40
+ if torch.cuda.is_available()
41
+ else "mps" if torch.backends.mps.is_available() else "cpu"
42
+ )
43
+ self.processor = DonutProcessor.from_pretrained(pretrained_model_repo_id)
44
+ self.model = VisionEncoderDecoderModel.from_pretrained(pretrained_model_repo_id)
45
+ self.model.eval()
46
+ self.model.to(self.device)
47
+ print(f"Using {self.device} device")
48
+
49
+ def predict(self, image: Image.Image) -> Dict[str, Any]:
50
+ # prepare encoder inputs
51
+ pixel_values = self.processor(image.convert("RGB"), return_tensors="pt").pixel_values
52
+ pixel_values = pixel_values.to(self.device)
53
+
54
+ # prepare decoder inputs
55
+ decoder_input_ids = self.processor.tokenizer(TASK_PROMPT_NAME, add_special_tokens=False, return_tensors="pt").input_ids
56
+ decoder_input_ids = decoder_input_ids.to(self.device)
57
+
58
+ # autoregressively generate sequence
59
+ outputs = self.model.generate(
60
+ pixel_values,
61
+ decoder_input_ids=decoder_input_ids,
62
+ max_length=self.model.decoder.config.max_position_embeddings,
63
+ early_stopping=True,
64
+ pad_token_id=self.processor.tokenizer.pad_token_id,
65
+ eos_token_id=self.processor.tokenizer.eos_token_id,
66
+ use_cache=True,
67
+ num_beams=1,
68
+ bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
69
+ return_dict_in_generate=True
70
+ )
71
+
72
+ # turn into JSON
73
+ seq = self.processor.batch_decode(outputs.sequences)[0]
74
+ seq = seq.replace(self.processor.tokenizer.eos_token, "").replace(self.processor.tokenizer.pad_token, "")
75
+ seq = re.sub(r"<.*?>", "", seq, count=1).strip() # remove first task start token
76
+ seq = self.processor.token2json(seq)
77
+ return seq
78
+
79
+ def evaluate(self, dataset: Dataset, ground_truth_key: str = "ground_truth") -> Tuple[Dict[str, Any], List[Any]]:
80
+ output_list = []
81
+ accs = []
82
+ ted_accs = []
83
+ f1_accs = []
84
+
85
+ for idx, sample in tqdm(enumerate(dataset), total=len(dataset)):
86
+ seq = self.predict(sample["image"])
87
+ ground_truth = sample[ground_truth_key]
88
+
89
+ # Original JSON accuracy
90
+ evaluator = JSONParseEvaluator()
91
+ score = evaluator.cal_acc(seq, ground_truth)
92
+ accs.append(score)
93
+ output_list.append(seq)
94
+
95
+ # TED (Tree Edit Distance) Accuracy
96
+ # Convert predictions and ground truth to string format for comparison
97
+ pred_str = str(seq) if seq else ""
98
+ gt_str = str(ground_truth) if ground_truth else ""
99
+
100
+ # Calculate normalized edit distance (1 - normalized_edit_distance = accuracy)
101
+ if len(pred_str) == 0 and len(gt_str) == 0:
102
+ ted_acc = 1.0
103
+ elif len(pred_str) == 0 or len(gt_str) == 0:
104
+ ted_acc = 0.0
105
+ else:
106
+ edit_dist = edit_distance(pred_str, gt_str)
107
+ max_len = max(len(pred_str), len(gt_str))
108
+ ted_acc = 1 - (edit_dist / max_len)
109
+ ted_accs.append(ted_acc)
110
+
111
+ # F1 Score Accuracy (character-level)
112
+ if len(pred_str) == 0 and len(gt_str) == 0:
113
+ f1_acc = 1.0
114
+ elif len(pred_str) == 0 or len(gt_str) == 0:
115
+ f1_acc = 0.0
116
+ else:
117
+ # Character-level precision and recall
118
+ pred_chars = set(pred_str)
119
+ gt_chars = set(gt_str)
120
+
121
+ if len(pred_chars) == 0:
122
+ precision = 0.0
123
+ else:
124
+ precision = len(pred_chars.intersection(gt_chars)) / len(pred_chars)
125
+
126
+ if len(gt_chars) == 0:
127
+ recall = 0.0
128
+ else:
129
+ recall = len(pred_chars.intersection(gt_chars)) / len(gt_chars)
130
+
131
+ if precision + recall == 0:
132
+ f1_acc = 0.0
133
+ else:
134
+ f1_acc = 2 * (precision * recall) / (precision + recall)
135
+ f1_accs.append(f1_acc)
136
+
137
+ scores = {
138
+ "accuracies": accs,
139
+ "mean_accuracy": np.mean(accs),
140
+ "ted_accuracies": ted_accs,
141
+ "mean_ted_accuracy": np.mean(ted_accs),
142
+ "f1_accuracies": f1_accs,
143
+ "mean_f1_accuracy": np.mean(f1_accs),
144
+ "length": len(accs)
145
+ }
146
+ return scores, output_list
147
+
148
+ class DonutTrainer:
149
+ processor = None
150
+ max_length = 768
151
+ image_size = [1280, 960]
152
+ added_tokens = []
153
+ train_dataloader = None
154
+ val_dataloader = None
155
+ huggingface_model_id = None
156
+
157
+ class DonutDataset(Dataset):
158
+ """
159
+ PyTorch Dataset for Donut. This class takes a HuggingFace Dataset as input.
160
+
161
+ Each row consists of an image (png/jpg/jpeg) and its ground-truth data (json/jsonl/txt),
+ which are converted into pixel_values (the vectorized image) and labels (input_ids of the tokenized ground-truth string).
163
+
164
+ Args:
165
+ dataset: HuggingFace DatasetDict containing the dataset to be used
166
+ max_length: the max number of tokens for the target sequences
167
+ split: whether to load "train", "validation" or "test" split
168
+ ignore_id: ignore_index for torch.nn.CrossEntropyLoss
169
+ task_start_token: the special token to be fed to the decoder to conduct the target task
170
+ prompt_end_token: the special token at the end of the sequences
171
+ sort_json_key: whether or not to sort the JSON keys
172
+ """
173
+
174
+ def __init__(
175
+ self,
176
+ dataset: DatasetDict,
177
+ ground_truth_key: str,
178
+ max_length: int,
179
+ split: str = "train",
180
+ ignore_id: int = -100,
181
+ task_start_token: str = "<s>",
182
+ prompt_end_token: str = None,
183
+ sort_json_key: bool = True,
184
+ ):
185
+ super().__init__()
186
+
187
+ self.dataset = dataset[split]
188
+ self.ground_truth_key = ground_truth_key
189
+ self.max_length = max_length
190
+ self.split = split
191
+ self.ignore_id = ignore_id
192
+ self.task_start_token = task_start_token
193
+ self.prompt_end_token = prompt_end_token if prompt_end_token else task_start_token
194
+ self.sort_json_key = sort_json_key
195
+
196
+ self.dataset_length = len(self.dataset)
197
+
198
+ self.gt_token_sequences = []
199
+ for sample in self.dataset:
200
+ ground_truth = sample[self.ground_truth_key]
201
+ self.gt_token_sequences.append(
202
+ [
203
+ self.json2token(
204
+ gt_json,
205
+ update_special_tokens_for_json_key=self.split == "train",
206
+ sort_json_key=self.sort_json_key,
207
+ )
208
+ + DonutTrainer.processor.tokenizer.eos_token
209
+ for gt_json in [ground_truth] # load json from list of json
210
+ ]
211
+ )
212
+
213
+ self.add_tokens([self.task_start_token, self.prompt_end_token])
214
+ self.prompt_end_token_id = DonutTrainer.processor.tokenizer.convert_tokens_to_ids(self.prompt_end_token)
215
+
216
+ def json2token(self, obj: Any, update_special_tokens_for_json_key: bool = True, sort_json_key: bool = True):
217
+ """
218
+ Convert an ordered JSON object into a token sequence
219
+ """
220
+ if type(obj) == dict:
221
+ if len(obj) == 1 and "text_sequence" in obj:
222
+ return obj["text_sequence"]
223
+ else:
224
+ output = ""
225
+ if sort_json_key:
226
+ keys = sorted(obj.keys(), reverse=True)
227
+ else:
228
+ keys = obj.keys()
229
+ for k in keys:
230
+ if update_special_tokens_for_json_key:
231
+ self.add_tokens([fr"<s_{k}>", fr"</s_{k}>"])
232
+ output += (
233
+ fr"<s_{k}>"
234
+ + self.json2token(obj[k], update_special_tokens_for_json_key, sort_json_key)
235
+ + fr"</s_{k}>"
236
+ )
237
+ return output
238
+ elif type(obj) == list:
239
+ return r"<sep/>".join(
240
+ [self.json2token(item, update_special_tokens_for_json_key, sort_json_key) for item in obj]
241
+ )
242
+ else:
243
+ obj = str(obj)
244
+ if f"<{obj}/>" in DonutTrainer.added_tokens:
245
+ obj = f"<{obj}/>" # for categorical special tokens
246
+ return obj
247
+
248
+ def add_tokens(self, list_of_tokens: List[str]):
249
+ """
250
+ Add special tokens to tokenizer and resize the token embeddings of the decoder
251
+ """
252
+ newly_added_num = DonutTrainer.processor.tokenizer.add_tokens(list_of_tokens)
253
+ if newly_added_num > 0:
254
+ DonutTrainer.model.decoder.resize_token_embeddings(len(DonutTrainer.processor.tokenizer))
255
+ DonutTrainer.added_tokens.extend(list_of_tokens)
256
+
257
+ def __len__(self) -> int:
258
+ return self.dataset_length
259
+
260
+ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
261
+ """
262
+ Load image from image_path of given dataset_path and convert into input_tensor and labels
263
+ Convert gt data into input_ids (tokenized string)
264
+ Returns:
265
+ input_tensor : preprocessed image
266
+ input_ids : tokenized gt_data
267
+ labels : masked labels (model doesn't need to predict prompt and pad token)
268
+ """
269
+ sample = self.dataset[idx]
270
+
271
+ # inputs
272
+ pixel_values = DonutTrainer.processor(sample["image"], random_padding=self.split == "train", return_tensors="pt").pixel_values
273
+ pixel_values = pixel_values.squeeze()
274
+
275
+ # targets
276
+ target_sequence = random.choice(self.gt_token_sequences[idx]) # can be more than one, e.g., DocVQA Task 1
277
+ input_ids = DonutTrainer.processor.tokenizer(
278
+ target_sequence,
279
+ add_special_tokens=False,
280
+ max_length=self.max_length,
281
+ padding="max_length",
282
+ truncation=True,
283
+ return_tensors="pt",
284
+ )["input_ids"].squeeze(0)
285
+
286
+ labels = input_ids.clone()
287
+ labels[labels == DonutTrainer.processor.tokenizer.pad_token_id] = self.ignore_id # model doesn't need to predict pad token
288
+ # labels[: torch.nonzero(labels == self.prompt_end_token_id).sum() + 1] = self.ignore_id # model doesn't need to predict prompt (for VQA)
289
+ return pixel_values, labels, target_sequence
290
+
291
+
292
+ class DonutModelPLModule(pl.LightningModule):
293
+ def __init__(self, config, processor, model):
294
+ super().__init__()
295
+ self.config = config
296
+ self.processor = processor
297
+ self.model = model
298
+
299
+ def training_step(self, batch, batch_idx):
300
+ pixel_values, labels, _ = batch
301
+
302
+ outputs = self.model(pixel_values, labels=labels)
303
+ loss = outputs.loss
304
+ self.log("train_loss", loss)
305
+ return loss
306
+
307
+ def validation_step(self, batch, batch_idx, dataset_idx=0):
308
+ pixel_values, labels, answers = batch
309
+ batch_size = pixel_values.shape[0]
310
+ # we feed the prompt to the model
311
+ decoder_input_ids = torch.full((batch_size, 1), self.model.config.decoder_start_token_id, device=self.device)
312
+
313
+ outputs = self.model.generate(pixel_values,
314
+ decoder_input_ids=decoder_input_ids,
315
+ max_length=DonutTrainer.max_length,
316
+ early_stopping=True,
317
+ pad_token_id=self.processor.tokenizer.pad_token_id,
318
+ eos_token_id=self.processor.tokenizer.eos_token_id,
319
+ use_cache=True,
320
+ num_beams=1,
321
+ bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
322
+ return_dict_in_generate=True,)
323
+
324
+ predictions = []
325
+ for seq in self.processor.tokenizer.batch_decode(outputs.sequences):
326
+ seq = seq.replace(self.processor.tokenizer.eos_token, "").replace(self.processor.tokenizer.pad_token, "")
327
+ seq = re.sub(r"<.*?>", "", seq, count=1).strip() # remove first task start token
328
+ predictions.append(seq)
329
+
330
+ scores = []
331
+ for pred, answer in zip(predictions, answers):
332
+ pred = re.sub(r"(?:(?<=>) | (?=</s_))", "", pred)
333
+ # NOT NEEDED ANYMORE
334
+ # answer = re.sub(r"<.*?>", "", answer, count=1)
335
+ answer = answer.replace(self.processor.tokenizer.eos_token, "")
336
+ scores.append(edit_distance(pred, answer) / max(len(pred), len(answer)))
337
+
338
+ if self.config.get("verbose", False) and len(scores) == 1:
339
+ print(f"Prediction: {pred}")
340
+ print(f" Answer: {answer}")
341
+ print(f" Normed ED: {scores[0]}")
342
+
343
+ val_edit_distance = np.mean(scores)
344
+ self.log("val_edit_distance", val_edit_distance)
345
+ print(f"Validation Edit Distance: {val_edit_distance}")
346
+
347
+ return scores
348
+
349
+ def configure_optimizers(self):
350
+ # you could also add a learning rate scheduler if you want
351
+ optimizer = torch.optim.Adam(self.parameters(), lr=self.config.get("lr"))
352
+
353
+ return optimizer
354
+
355
+ def train_dataloader(self):
356
+ return DonutTrainer.train_dataloader
357
+
358
+ def val_dataloader(self):
359
+ return DonutTrainer.val_dataloader
360
+
361
+ class PushToHubCallback(Callback):
362
+ def on_train_epoch_end(self, trainer, pl_module):
363
+ print(f"Pushing model to the hub, epoch {trainer.current_epoch}")
364
+ pl_module.model.push_to_hub(DonutTrainer.huggingface_model_id, commit_message=f"Training in progress, epoch {trainer.current_epoch}")
365
+ self._upload_logs(trainer.logger.log_dir, trainer.current_epoch)
366
+
367
+ def on_train_end(self, trainer, pl_module):
368
+ print(f"Pushing model to the hub after training")
369
+ pl_module.processor.push_to_hub(DonutTrainer.huggingface_model_id, commit_message="Training done")
370
+ pl_module.model.push_to_hub(DonutTrainer.huggingface_model_id, commit_message=f"Training done")
371
+ self._upload_logs(trainer.logger.log_dir, "final")
372
+
373
+ def _upload_logs(self, log_dir: str, epoch_info):
374
+ try:
375
+ print(f"Attempting to upload logs from: {log_dir}")
376
+ upload_folder(folder_path=log_dir, repo_id=DonutTrainer.huggingface_model_id,
377
+ path_in_repo="tensorboard_logs",
378
+ commit_message=f"Upload logs - epoch {epoch_info}", ignore_patterns=["*.tmp", "*.lock"])
379
+ print(f"Successfully uploaded logs for epoch {epoch_info}")
380
+ except Exception as e:
381
+ print(f"Failed to upload logs: {e}")
382
+ pass
383
+
384
+ @classmethod
385
+ def train(
386
+ cls,
387
+ dataset: DatasetDict,
388
+ pretrained_model_repo_id: str,
389
+ huggingface_model_id: str,
390
+ epochs: int,
391
+ train_batch_size: int,
392
+ val_batch_size: int,
393
+ learning_rate: float,
394
+ val_check_interval: float,
395
+ check_val_every_n_epoch: int,
396
+ gradient_clip_val: float,
397
+ num_training_samples_per_epoch: int,
398
+ num_nodes: int,
399
+ warmup_steps: int,
400
+ ground_truth_key: str = "ground_truth"
401
+ ):
402
+ cls.huggingface_model_id = huggingface_model_id
403
+ config = VisionEncoderDecoderConfig.from_pretrained(pretrained_model_repo_id)
404
+ config.encoder.image_size = cls.image_size
405
+ config.decoder.max_length = cls.max_length
406
+
407
+ cls.processor = DonutProcessor.from_pretrained(pretrained_model_repo_id)
408
+ cls.model = VisionEncoderDecoderModel.from_pretrained(pretrained_model_repo_id, config=config)
409
+ cls.processor.image_processor.size = cls.image_size[::-1]
410
+ cls.processor.image_processor.do_align_long_axis = False
411
+
412
+ train_dataset = cls.DonutDataset(
413
+ dataset=dataset,
414
+ ground_truth_key=ground_truth_key,
415
+ max_length=cls.max_length,
416
+ split="train",
417
+ task_start_token=TASK_PROMPT_NAME,
418
+ prompt_end_token=TASK_PROMPT_NAME,
419
+ sort_json_key=True
420
+ )
421
+ val_dataset = cls.DonutDataset(
422
+ dataset=dataset,
423
+ ground_truth_key=ground_truth_key,
424
+ max_length=cls.max_length,
425
+ split="validation",
426
+ task_start_token=TASK_PROMPT_NAME,
427
+ prompt_end_token=TASK_PROMPT_NAME,
428
+ sort_json_key=True
429
+ )
430
+
431
+ cls.model.config.pad_token_id = cls.processor.tokenizer.pad_token_id
432
+ cls.model.config.decoder_start_token_id = cls.processor.tokenizer.convert_tokens_to_ids([TASK_PROMPT_NAME])[0]
433
+
434
+ cls.train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=4)
435
+ cls.val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=4)
436
+
437
+ config = {
438
+ "max_epochs": epochs,
439
+ "val_check_interval": val_check_interval, # how many times we want to validate during an epoch
440
+ "check_val_every_n_epoch": check_val_every_n_epoch,
441
+ "gradient_clip_val": gradient_clip_val,
442
+ "num_training_samples_per_epoch": num_training_samples_per_epoch,
443
+ "lr": learning_rate,
444
+ "train_batch_sizes": [train_batch_size],
445
+ "val_batch_sizes": [val_batch_size],
446
+ # "seed":2022,
447
+ "num_nodes": num_nodes,
448
+ "warmup_steps": warmup_steps, # 10%
449
+ "result_path": "./.checkpoints",
450
+ "verbose": True,
451
+ }
452
+ model_module = cls.DonutModelPLModule(config, cls.processor, cls.model)
453
+
454
+ device = (
455
+ "cuda"
456
+ if torch.cuda.is_available()
457
+ else "mps" if torch.backends.mps.is_available() else "cpu"
458
+ )
459
+ print(f"Using {device} device")
460
+ trainer = pl.Trainer(
461
+ accelerator="gpu" if device == "cuda" else "mps" if device == "mps" else "cpu",
462
+ devices=1 if device == "cuda" else 0,
463
+ max_epochs=config.get("max_epochs"),
464
+ val_check_interval=config.get("val_check_interval"),
465
+ check_val_every_n_epoch=config.get("check_val_every_n_epoch"),
466
+ gradient_clip_val=config.get("gradient_clip_val"),
467
+ precision=16 if device == "cuda" else 32, # we'll use mixed precision if device == "cuda"
468
+ num_sanity_val_steps=0,
469
+ logger=TensorBoardLogger(save_dir="./.checkpoints", name="donut_training", version=None),
470
+ callbacks=[cls.PushToHubCallback()]
471
+ )
472
+ trainer.fit(model_module)
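The validation metric logged above is a normalized edit distance, `edit_distance(pred, answer) / max(len(pred), len(answer))`, and `evaluate()` reports one minus that value as its TED-style accuracy. A small worked example with made-up strings:

```python
from nltk import edit_distance

pred = "<s_restaurant>Cafe A</s_restaurant>"
answer = "<s_restaurant>Cafe B</s_restaurant>"

dist = edit_distance(pred, answer)           # one substitution -> 1
normed = dist / max(len(pred), len(answer))  # 1 / 35 ~= 0.029
print(f"normed ED = {normed:.3f}, accuracy ~= {1 - normed:.3f}")
```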
menu/llm/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .gemini import GeminiAPI
+ from .openai import OpenAIAPI
menu/llm/base.py ADDED
@@ -0,0 +1,9 @@
+ from abc import ABC, abstractmethod
+
+ import numpy as np
+
+ class LLMBase(ABC):
+     @classmethod
+     @abstractmethod
+     def call(cls, image: np.ndarray, model: str, token: str) -> dict:
+         raise NotImplementedError
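Any additional provider can be added by subclassing `LLMBase`, exactly as `GeminiAPI` and `OpenAIAPI` do below. A minimal sketch with a hypothetical provider:

```python
import numpy as np
from menu.llm.base import LLMBase

class MyProviderAPI(LLMBase):  # hypothetical example, not part of the repository
    @classmethod
    def call(cls, image: np.ndarray, model: str, token: str) -> dict:
        # 1) authenticate with `token`, 2) send `image` together with the
        #    function-calling schema from tools/, 3) return the parsed arguments.
        raise NotImplementedError("wire up a real client here")
```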
menu/llm/gemini.py ADDED
@@ -0,0 +1,36 @@
+ import json
+
+ import numpy as np
+ from PIL import Image
+ from google import genai
+ from google.genai import types
+
+ from .base import LLMBase
+
+ FUNCTION_CALL = json.load(open("tools/schema_gemini.json", "r"))
+
+ class GeminiAPI(LLMBase):
+     @classmethod
+     def call(cls, image: np.ndarray, model: str, token: str) -> dict:
+         client = genai.Client(api_key=token)  # Initialize the client with the API key
+         encode_img = Image.fromarray(image)   # Convert the image for the API
+
+         config = types.GenerateContentConfig(
+             tools=[types.Tool(function_declarations=[FUNCTION_CALL])],
+             tool_config={
+                 "function_calling_config": {
+                     "mode": "ANY",
+                     "allowed_function_names": [FUNCTION_CALL["name"]]
+                 }
+             }
+         )
+         response = client.models.generate_content(
+             model=model,
+             contents=[encode_img],
+             config=config
+         )
+         if response.candidates[0].content.parts[0].function_call:
+             function_call = response.candidates[0].content.parts[0].function_call
+             return function_call.args
+
+         return {}
menu/llm/openai.py ADDED
@@ -0,0 +1,39 @@
+ import json
+ import base64
+ from io import BytesIO
+
+ import numpy as np
+ from PIL import Image
+ from openai import OpenAI
+
+ from .base import LLMBase
+
+ FUNCTION_CALL = json.load(open("tools/schema_openai.json", "r"))
+
+ class OpenAIAPI(LLMBase):
+     @classmethod
+     def call(cls, image: np.ndarray, model: str, token: str) -> dict:
+         client = OpenAI(api_key=token)  # Initialize the client with the API key
+         buffer = BytesIO()
+         Image.fromarray(image).save(buffer, format="JPEG")
+         encode_img = base64.b64encode(buffer.getvalue()).decode("utf-8")  # Convert the image for the API
+
+         response = client.responses.create(
+             model=model,
+             input=[
+                 {
+                     "role": "user",
+                     "content": [
+                         {
+                             "type": "input_image",
+                             "image_url": f"data:image/jpeg;base64,{encode_img}",
+                         },
+                     ],
+                 }
+             ],
+             tools=[FUNCTION_CALL],
+         )
+         if response and response.output:
+             if hasattr(response.output[0], "arguments"):
+                 return json.loads(response.output[0].arguments)
+         return {}
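Both wrappers share the `LLMBase.call` interface and expect the image as a NumPy array, as in the dataset-labeling cell of `train.ipynb`; the image path and token below are placeholders. Note that they load `tools/schema_*.json` via relative paths, so they should be called from the repository root:

```python
import numpy as np
from PIL import Image
from menu.llm import GeminiAPI, OpenAIAPI

image = np.array(Image.open("datasets/images/menu-example.jpg"))    # placeholder path
menu = GeminiAPI.call(image, "gemini-2.5-flash", "YOUR_API_TOKEN")  # or OpenAIAPI.call(image, "gpt-4o", ...)
print(menu)
```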
menu/utils.py ADDED
@@ -0,0 +1,48 @@
+ from typing import Optional
+
+ from datasets import Dataset, DatasetDict
+
+ def split_dataset(
+     dataset: Dataset,
+     train: float,
+     validation: float,
+     test: float,
+     seed: Optional[int] = None
+ ) -> DatasetDict:
+     """
+     Split a single-split Hugging Face Dataset into train/validation/test subsets.
+
+     Args:
+         dataset (Dataset): The input dataset (e.g. load_dataset(...)['train']).
+         train (float): Proportion of data for the train split (0 < train < 1).
+         validation (float): Proportion of data for the validation split (0 < validation < 1).
+         test (float): Proportion of data for the test split (0 < test < 1).
+             Must satisfy train + validation + test == 1.0.
+         seed (int): Random seed for reproducibility (default: None).
+
+     Returns:
+         DatasetDict: A dictionary with keys "train", "validation", and "test".
+     """
+     # Verify ratios sum to 1.0
+     total = train + validation + test
+     if abs(total - 1.0) > 1e-8:
+         raise ValueError(f"train + validation + test must equal 1.0 (got {total})")
+
+     # First split: extract train vs. temp (validation + test)
+     temp_size = validation + test
+     split_1 = dataset.train_test_split(test_size=temp_size, seed=seed)
+     train_ds = split_1["train"]
+     temp_ds = split_1["test"]
+
+     # Second split: divide temp into validation vs. test
+     relative_test_size = test / temp_size
+     split_2 = temp_ds.train_test_split(test_size=relative_test_size, seed=seed)
+     validation_ds = split_2["train"]
+     test_ds = split_2["test"]
+
+     # Return a DatasetDict with all three splits
+     return DatasetDict({
+         "train": train_ds,
+         "validation": validation_ds,
+         "test": test_ds,
+     })
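Usage mirrors the dataset-preparation cell in `train.ipynb`:

```python
from datasets import load_dataset

from menu.utils import split_dataset

dataset = load_dataset("ryanlinjui/menu-zh-TW")
dataset = split_dataset(dataset["train"], train=0.8, validation=0.1, test=0.1, seed=42)
print({split: len(ds) for split, ds in dataset.items()})  # sizes of train/validation/test
```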
pyproject.toml ADDED
@@ -0,0 +1,27 @@
+ [project]
+ authors = [{name = "ryanlinjui", email = "ryanlinjui@gmail.com"}]
+ name = "menu-text-detection"
+ version = "0.1.0"
+ description = "Extract structured menu information from images into JSON using a fine-tuned Donut E2E model."
+ readme = "README.md"
+ requires-python = "==3.11.*"
+ dependencies = [
+     "datasets>=3.6.0",
+     "dotenv>=0.9.9",
+     "google-genai>=1.14.0",
+     "gradio>=5.29.0",
+     "huggingface-hub>=0.31.1",
+     "matplotlib>=3.10.1",
+     "nltk>=3.9.1",
+     "notebook>=7.4.2",
+     "openai>=1.77.0",
+     "pillow>=11.2.1",
+     "pillow-heif>=0.22.0",
+     "protobuf>=6.30.2",
+     "pytorch-lightning>=2.5.2",
+     "sentencepiece>=0.2.0",
+     "tensorboardx>=2.6.2.2",
+     "transformers==4.49",
+     "torch==2.4.1",
+     "donut-python>=1.0.9",
+ ]
requirements.txt ADDED
@@ -0,0 +1,183 @@
1
+ aiofiles==24.1.0
2
+ aiohappyeyeballs==2.6.1
3
+ aiohttp==3.11.18
4
+ aiosignal==1.3.2
5
+ annotated-types==0.7.0
6
+ anyio==4.9.0
7
+ appnope==0.1.4
8
+ argon2-cffi==23.1.0
9
+ argon2-cffi-bindings==21.2.0
10
+ arrow==1.3.0
11
+ asttokens==3.0.0
12
+ async-lru==2.0.5
13
+ attrs==25.3.0
14
+ babel==2.17.0
15
+ beautifulsoup4==4.13.4
16
+ bleach==6.2.0
17
+ cachetools==5.5.2
18
+ certifi==2025.4.26
19
+ cffi==1.17.1
20
+ charset-normalizer==3.4.2
21
+ click==8.1.8
22
+ comm==0.2.2
23
+ contourpy==1.3.2
24
+ cycler==0.12.1
25
+ datasets==3.6.0
26
+ debugpy==1.8.14
27
+ decorator==5.2.1
28
+ defusedxml==0.7.1
29
+ dill==0.3.8
30
+ distro==1.9.0
31
+ donut-python==1.0.9
32
+ dotenv==0.9.9
33
+ executing==2.2.0
34
+ fastapi==0.115.12
35
+ fastjsonschema==2.21.1
36
+ ffmpy==0.5.0
37
+ filelock==3.18.0
38
+ fonttools==4.57.0
39
+ fqdn==1.5.1
40
+ frozenlist==1.6.0
41
+ fsspec==2025.3.0
42
+ google-auth==2.40.1
43
+ google-genai==1.14.0
44
+ gradio==5.29.0
45
+ gradio-client==1.10.0
46
+ groovy==0.1.2
47
+ h11==0.16.0
48
+ hf-xet==1.1.0
49
+ httpcore==1.0.9
50
+ httpx==0.28.1
51
+ huggingface-hub==0.31.1
52
+ idna==3.10
53
+ ipykernel==6.29.5
54
+ ipython==9.2.0
55
+ ipython-pygments-lexers==1.1.1
56
+ isoduration==20.11.0
57
+ jedi==0.19.2
58
+ jinja2==3.1.6
59
+ jiter==0.9.0
60
+ joblib==1.5.0
61
+ json5==0.12.0
62
+ jsonpointer==3.0.0
63
+ jsonschema==4.23.0
64
+ jsonschema-specifications==2025.4.1
65
+ jupyter-client==8.6.3
66
+ jupyter-core==5.7.2
67
+ jupyter-events==0.12.0
68
+ jupyter-lsp==2.2.5
69
+ jupyter-server==2.15.0
70
+ jupyter-server-terminals==0.5.3
71
+ jupyterlab==4.4.2
72
+ jupyterlab-pygments==0.3.0
73
+ jupyterlab-server==2.27.3
74
+ kiwisolver==1.4.8
75
+ lightning-utilities==0.14.3
76
+ markdown-it-py==3.0.0
77
+ markupsafe==3.0.2
78
+ matplotlib==3.10.1
79
+ matplotlib-inline==0.1.7
80
+ mdurl==0.1.2
81
+ mistune==3.1.3
82
+ mpmath==1.3.0
83
+ multidict==6.4.3
84
+ multiprocess==0.70.16
85
+ munch==4.0.0
86
+ nbclient==0.10.2
87
+ nbconvert==7.16.6
88
+ nbformat==5.10.4
89
+ nest-asyncio==1.6.0
90
+ networkx==3.4.2
91
+ nltk==3.9.1
92
+ notebook==7.4.2
93
+ notebook-shim==0.2.4
94
+ numpy==2.2.5
95
+ openai==1.77.0
96
+ orjson==3.10.18
97
+ overrides==7.7.0
98
+ packaging==25.0
99
+ pandas==2.2.3
100
+ pandocfilters==1.5.1
101
+ parso==0.8.4
102
+ pexpect==4.9.0
103
+ pillow==11.2.1
104
+ pillow-heif==0.22.0
105
+ platformdirs==4.3.8
106
+ prometheus-client==0.21.1
107
+ prompt-toolkit==3.0.51
108
+ propcache==0.3.1
109
+ protobuf==6.30.2
110
+ psutil==7.0.0
111
+ ptyprocess==0.7.0
112
+ pure-eval==0.2.3
113
+ pyarrow==20.0.0
114
+ pyasn1==0.6.1
115
+ pyasn1-modules==0.4.2
116
+ pycparser==2.22
117
+ pydantic==2.11.4
118
+ pydantic-core==2.33.2
119
+ pydub==0.25.1
120
+ pygments==2.19.1
121
+ pyparsing==3.2.3
122
+ python-dateutil==2.9.0.post0
123
+ python-dotenv==1.1.0
124
+ python-json-logger==3.3.0
125
+ python-multipart==0.0.20
126
+ pytorch-lightning==2.5.2
127
+ pytz==2025.2
128
+ pyyaml==6.0.2
129
+ pyzmq==26.4.0
130
+ referencing==0.36.2
131
+ regex==2024.11.6
132
+ requests==2.32.3
133
+ rfc3339-validator==0.1.4
134
+ rfc3986-validator==0.1.1
135
+ rich==14.0.0
136
+ rpds-py==0.24.0
137
+ rsa==4.9.1
138
+ ruamel-yaml==0.18.14
139
+ ruamel-yaml-clib==0.2.12
140
+ ruff==0.11.8
141
+ safehttpx==0.1.6
142
+ safetensors==0.5.3
143
+ sconf==0.2.5
144
+ semantic-version==2.10.0
145
+ send2trash==1.8.3
146
+ sentencepiece==0.2.0
147
+ setuptools==80.3.1
148
+ shellingham==1.5.4
149
+ six==1.17.0
150
+ sniffio==1.3.1
151
+ soupsieve==2.7
152
+ stack-data==0.6.3
153
+ starlette==0.46.2
154
+ sympy==1.14.0
155
+ tensorboardx==2.6.2.2
156
+ terminado==0.18.1
157
+ timm==1.0.16
158
+ tinycss2==1.4.0
159
+ tokenizers==0.21.1
160
+ tomlkit==0.13.2
161
+ torch==2.4.1
162
+ torchmetrics==1.7.3
163
+ torchvision==0.19.1
164
+ tornado==6.4.2
165
+ tqdm==4.67.1
166
+ traitlets==5.14.3
167
+ transformers==4.49.0
168
+ typer==0.15.3
169
+ types-python-dateutil==2.9.0.20241206
170
+ typing-extensions==4.13.2
171
+ typing-inspection==0.4.0
172
+ tzdata==2025.2
173
+ uri-template==1.3.0
174
+ urllib3==2.4.0
175
+ uvicorn==0.34.2
176
+ wcwidth==0.2.13
177
+ webcolors==24.11.1
178
+ webencodings==0.5.1
179
+ websocket-client==1.8.0
180
+ websockets==15.0.1
181
+ xxhash==3.5.0
182
+ yarl==1.20.0
183
+ zss==1.2.0
tools/schema_gemini.json ADDED
@@ -0,0 +1,44 @@
+ {
+     "name": "extract_menu_data",
+     "description": "Extract structured menu information from images.",
+     "parameters": {
+         "type": "object",
+         "properties": {
+             "restaurant": {
+                 "type": "string",
+                 "description": "Name of the restaurant. If the name is not available, it should be ''."
+             },
+             "address": {
+                 "type": "string",
+                 "description": "Address of the restaurant. If the address is not available, it should be ''."
+             },
+             "phone": {
+                 "type": "string",
+                 "description": "Phone number of the restaurant. If the phone number is not available, it should be ''."
+             },
+             "business_hours": {
+                 "type": "string",
+                 "description": "Business hours of the restaurant. If the business hours are not available, it should be ''."
+             },
+             "dishes": {
+                 "type": "array",
+                 "items": {
+                     "type": "object",
+                     "properties": {
+                         "name": {
+                             "type": "string",
+                             "description": "Name of the menu item."
+                         },
+                         "price": {
+                             "type": "string",
+                             "description": "Price of the menu item. If the price is not available, it should be -1."
+                         }
+                     },
+                     "required": ["name", "price"]
+                 },
+                 "description": "List of menu dish items."
+             }
+         },
+         "required": ["restaurant", "address", "phone", "business_hours", "dishes"]
+     }
+ }
tools/schema_openai.json ADDED
@@ -0,0 +1,47 @@
+ {
+     "type": "function",
+     "name": "extract_menu_data",
+     "description": "Extract structured menu information from images.",
+     "parameters": {
+         "type": "object",
+         "properties": {
+             "restaurant": {
+                 "type": "string",
+                 "description": "Name of the restaurant. If the name is not available, it should be ''."
+             },
+             "address": {
+                 "type": "string",
+                 "description": "Address of the restaurant. If the address is not available, it should be ''."
+             },
+             "phone": {
+                 "type": "string",
+                 "description": "Phone number of the restaurant. If the phone number is not available, it should be ''."
+             },
+             "business_hours": {
+                 "type": "string",
+                 "description": "Business hours of the restaurant. If the business hours are not available, it should be ''."
+             },
+             "dishes": {
+                 "type": "array",
+                 "items": {
+                     "type": "object",
+                     "properties": {
+                         "name": {
+                             "type": "string",
+                             "description": "Name of the menu item."
+                         },
+                         "price": {
+                             "type": "string",
+                             "description": "Price of the menu item. If the price is not available, it should be -1."
+                         }
+                     },
+                     "required": ["name", "price"],
+                     "additionalProperties": false
+                 },
+                 "description": "List of menu dish items."
+             }
+         },
+         "required": ["restaurant", "address", "phone", "business_hours", "dishes"],
+         "additionalProperties": false
+     }
+ }
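Both schemas describe the same payload. For illustration, a conforming extraction result might look like the following (the restaurant details are made up); this is the structure that `handle()` in `app.py` serializes to JSON:

```python
example_menu = {
    "restaurant": "Example Noodle House",
    "address": "",                    # '' when the menu does not show an address
    "phone": "02-1234-5678",
    "business_hours": "11:00-21:00",
    "dishes": [
        {"name": "Beef noodle soup", "price": "180"},
        {"name": "Iced tea", "price": "-1"},  # -1 when no price is printed
    ],
}
```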
train.ipynb ADDED
@@ -0,0 +1,235 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Login to HuggingFace (just login once)"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "code",
12
+ "execution_count": null,
13
+ "metadata": {},
14
+ "outputs": [],
15
+ "source": [
16
+ "from huggingface_hub import interpreter_login\n",
17
+ "interpreter_login()"
18
+ ]
19
+ },
20
+ {
21
+ "cell_type": "markdown",
22
+ "metadata": {},
23
+ "source": [
24
+ "# Collect Menu Image Datasets\n",
25
+ "- Use `metadata.jsonl` to label the images's ground truth. You can visit [here](https://github.com/ryanlinjui/menu-text-detection/tree/main/examples) to see the examples.\n",
26
+ "- After finishing, push to HuggingFace Datasets.\n",
27
+ "- For labeling:\n",
28
+ " - [Google AI Studio](https://aistudio.google.com) or [OpenAI ChatGPT](https://chatgpt.com).\n",
29
+ " - Use function calling by API. Start the gradio app locally or visit [here](https://huggingface.co/spaces/ryanlinjui/menu-text-detection).\n",
30
+ "\n",
31
+ "### Menu Type\n",
32
+ "- **h**: horizontal menu\n",
33
+ "- **v**: vertical menu\n",
34
+ "- **d**: document-style menu\n",
35
+ "- **s**: in-scene menu (non-document style)\n",
36
+ "- **i**: irregular menu (menu with irregular text layout)\n",
37
+ "\n",
38
+ "> Please see the [examples](https://github.com/ryanlinjui/menu-text-detection/tree/main/examples) for more details."
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "code",
43
+ "execution_count": null,
44
+ "metadata": {},
45
+ "outputs": [],
46
+ "source": [
47
+ "import os\n",
48
+ "import json\n",
49
+ "\n",
50
+ "import numpy as np\n",
51
+ "from PIL import Image\n",
52
+ "from pillow_heif import register_heif_opener\n",
53
+ "\n",
54
+ "from menu.llm import (\n",
55
+ " GeminiAPI,\n",
56
+ " OpenAIAPI\n",
57
+ ")\n",
58
+ "\n",
59
+ "IMAGE_DIR = \"datasets/images\" # set your image directory here\n",
60
+ "SELECTED_MODEL = \"gemini-2.5-flash\" # set model name here, refer MODEL_LIST from app.py for more\n",
61
+ "API_TOKEN = \"\" # set your API token here\n",
62
+ "SELECTED_FUNCTION = GeminiAPI # set \"GeminiAPI\" or \"OpenAIAPI\"\n",
63
+ "\n",
64
+ "register_heif_opener()\n",
65
+ "\n",
66
+ "for file in os.listdir(IMAGE_DIR):\n",
67
+ " print(f\"Processing image: {file}\")\n",
68
+ " try:\n",
69
+ " image = np.array(Image.open(os.path.join(IMAGE_DIR, file)))\n",
70
+ " data = {\n",
71
+ " \"file_name\": file,\n",
72
+ " \"menu\": SELECTED_FUNCTION.call(image, SELECTED_MODEL, API_TOKEN)\n",
73
+ " }\n",
74
+ " with open(os.path.join(IMAGE_DIR, \"metadata.jsonl\"), \"a\", encoding=\"utf-8\") as metaf:\n",
75
+ " metaf.write(json.dumps(data, ensure_ascii=False, sort_keys=True) + \"\\n\")\n",
76
+ " except Exception as e:\n",
77
+ " print(f\"Skipping invalid image '{file}': {e}\")\n",
78
+ " continue"
79
+ ]
80
+ },
81
+ {
82
+ "cell_type": "markdown",
83
+ "metadata": {},
84
+ "source": [
85
+ "# Push Datasets to HuggingFace"
86
+ ]
87
+ },
88
+ {
89
+ "cell_type": "code",
90
+ "execution_count": null,
91
+ "metadata": {},
92
+ "outputs": [],
93
+ "source": [
94
+ "from datasets import load_dataset\n",
95
+ "\n",
96
+ "dataset = load_dataset(path=\"datasets/menu-zh-TW\") # load dataset from the local directory including the metadata.jsonl, images files.\n",
97
+ "dataset.push_to_hub(repo_id=\"ryanlinjui/menu-zh-TW\") # push to the huggingface dataset hub"
98
+ ]
99
+ },
100
+ {
101
+ "cell_type": "markdown",
102
+ "metadata": {},
103
+ "source": [
104
+ "# Prepare the dataset for training"
105
+ ]
106
+ },
107
+ {
108
+ "cell_type": "code",
109
+ "execution_count": null,
110
+ "metadata": {},
111
+ "outputs": [],
112
+ "source": [
113
+ "from menu.utils import split_dataset\n",
114
+ "from datasets import load_dataset\n",
115
+ "\n",
116
+ "dataset = load_dataset(path=\"ryanlinjui/menu-zh-TW\") # set your dataset repo id for training\n",
117
+ "dataset = split_dataset(dataset[\"train\"], train=0.8, validation=0.1, test=0.1, seed=42) # (optional) use it if your dataset is not split into train/validation/test\n",
118
+ "print(f\"Dataset split: {len(dataset['train'])} train, {len(dataset['validation'])} validation, {len(dataset['test'])} test\")"
119
+ ]
120
+ },
121
+ {
122
+ "cell_type": "markdown",
123
+ "metadata": {},
124
+ "source": [
125
+ "# Fine-tune Donut Model"
126
+ ]
127
+ },
128
+ {
129
+ "cell_type": "code",
130
+ "execution_count": null,
131
+ "metadata": {},
132
+ "outputs": [],
133
+ "source": [
134
+ "import logging\n",
135
+ "from menu.donut import DonutTrainer\n",
136
+ "\n",
137
+ "logging.getLogger(\"transformers\").setLevel(logging.ERROR) # filter output message from transformers\n",
138
+ "\n",
139
+ "DonutTrainer.train(\n",
140
+ " dataset=dataset,\n",
141
+ " pretrained_model_repo_id=\"naver-clova-ix/donut-base\", # set your pretrained model repo id for fine-tuning\n",
142
+ " ground_truth_key=\"menu\", # set your ground truth key for training\n",
143
+ " huggingface_model_id=\"ryanlinjui/donut-base-finetuned-menu\", # set your huggingface model repo id for saving / pushing to the hub\n",
144
+ " epochs=15, # set your training epochs\n",
145
+ " train_batch_size=8, # set your training batch size\n",
146
+ " val_batch_size=1, # set your validation batch size\n",
147
+ " learning_rate=3e-5, # set your learning rate\n",
148
+ " val_check_interval=0.5, # how many times we want to validate during an epoch\n",
149
+ " check_val_every_n_epoch=1, # how many epochs we want to validate\n",
150
+ " gradient_clip_val=1.0, # gradient clipping value for training stability\n",
151
+ " num_training_samples_per_epoch=198, # set num_training_samples_per_epoch = training set size\n",
152
+ " num_nodes=1, # number of nodes for distributed training\n",
153
+ " warmup_steps=75 # number of warmup steps for learning rate scheduler, 198/8*30/10, 10%\n",
154
+ ")"
155
+ ]
156
+ },
157
+ {
158
+ "cell_type": "markdown",
159
+ "metadata": {},
160
+ "source": [
161
+ "# Evaluate Donut Model"
162
+ ]
163
+ },
164
+ {
165
+ "cell_type": "code",
166
+ "execution_count": null,
167
+ "metadata": {},
168
+ "outputs": [],
169
+ "source": [
170
+ "import json\n",
171
+ "from datasets import load_dataset\n",
172
+ "\n",
173
+ "from menu.utils import split_dataset\n",
174
+ "from menu.donut import DonutFinetuned\n",
175
+ "\n",
176
+ "dataset = load_dataset(\"ryanlinjui/menu-zh-TW\")\n",
177
+ "dataset = split_dataset(dataset[\"train\"], train=0.8, validation=0.1, test=0.1, seed=42) # (optional) use it if your dataset is not split into train/validation/test\n",
178
+ "donut_finetuned = DonutFinetuned(pretrained_model_repo_id=\"ryanlinjui/donut-base-finetuned-menu\")\n",
179
+ "scores, output_list = donut_finetuned.evaluate(dataset=dataset[\"test\"], ground_truth_key=\"menu\")\n",
180
+ "\n",
181
+ "print(\"Evaluation scores:\")\n",
182
+ "for key, value in scores.items():\n",
183
+ " print(f\"{key}: {value}\")\n",
184
+ "\n",
185
+ "print(\"\\nSample outputs:\")\n",
186
+ "for output in output_list[:5]:\n",
187
+ " print(json.dumps(output, ensure_ascii=False, indent=4))"
188
+ ]
189
+ },
190
+ {
191
+ "cell_type": "markdown",
192
+ "metadata": {},
193
+ "source": [
194
+ "# Test Donut Model"
195
+ ]
196
+ },
197
+ {
198
+ "cell_type": "code",
199
+ "execution_count": null,
200
+ "metadata": {},
201
+ "outputs": [],
202
+ "source": [
203
+ "from PIL import Image\n",
204
+ "from menu.donut import DonutFinetuned\n",
205
+ "\n",
206
+ "image = Image.open(\"./examples/menu-hd.jpg\")\n",
207
+ "\n",
208
+ "donut_finetuned = DonutFinetuned(pretrained_model_repo_id=\"ryanlinjui/donut-base-finetuned-menu\")\n",
209
+ "outputs = donut_finetuned.predict(image=image)\n",
210
+ "print(outputs)"
211
+ ]
212
+ }
213
+ ],
214
+ "metadata": {
215
+ "kernelspec": {
216
+ "display_name": "menu-text-detection",
217
+ "language": "python",
218
+ "name": "python3"
219
+ },
220
+ "language_info": {
221
+ "codemirror_mode": {
222
+ "name": "ipython",
223
+ "version": 3
224
+ },
225
+ "file_extension": ".py",
226
+ "mimetype": "text/x-python",
227
+ "name": "python",
228
+ "nbconvert_exporter": "python",
229
+ "pygments_lexer": "ipython3",
230
+ "version": "3.11.12"
231
+ }
232
+ },
233
+ "nbformat": 4,
234
+ "nbformat_minor": 2
235
+ }
uv.lock ADDED
The diff for this file is too large to render. See raw diff