Spaces:
Running
Running
github-actions[bot]
commited on
Commit
·
6bd37dd
0
Parent(s):
Sync from https://github.com/ryanlinjui/menu-text-detection
Browse files- .checkpoints/.gitkeep +0 -0
- .env.example +3 -0
- .github/workflows/sync.yml +25 -0
- .gitignore +24 -0
- .python-version +1 -0
- LICENSE +21 -0
- README.md +65 -0
- app.py +157 -0
- menu/donut.py +472 -0
- menu/llm/__init__.py +2 -0
- menu/llm/base.py +9 -0
- menu/llm/gemini.py +36 -0
- menu/llm/openai.py +39 -0
- menu/utils.py +48 -0
- pyproject.toml +27 -0
- requirements.txt +183 -0
- tools/schema_gemini.json +44 -0
- tools/schema_openai.json +47 -0
- train.ipynb +235 -0
- uv.lock +0 -0
.checkpoints/.gitkeep
ADDED
File without changes
|
.env.example
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
HUGGINGFACE_TOKEN="HUGGINGFACE_TOKEN"
|
2 |
+
GEMINI_API_TOKEN="GEMINI_API_TOKEN"
|
3 |
+
OPENAI_API_TOKEN="OPENAI_API_TOKEN"
|
.github/workflows/sync.yml
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Sync to Hugging Face Spaces
|
2 |
+
|
3 |
+
on:
|
4 |
+
push:
|
5 |
+
branches:
|
6 |
+
- main
|
7 |
+
jobs:
|
8 |
+
sync:
|
9 |
+
name: Sync
|
10 |
+
runs-on: ubuntu-latest
|
11 |
+
steps:
|
12 |
+
- name: Checkout Repository
|
13 |
+
uses: actions/checkout@v4
|
14 |
+
|
15 |
+
- name: Remove bad files
|
16 |
+
run: rm -rf examples assets
|
17 |
+
|
18 |
+
- name: Sync to Hugging Face Spaces
|
19 |
+
uses: JacobLinCool/huggingface-sync@v1
|
20 |
+
with:
|
21 |
+
github: ${{ secrets.GITHUB_TOKEN }}
|
22 |
+
user: ryanlinjui # Hugging Face username or organization name
|
23 |
+
space: menu-text-detection # Hugging Face space name
|
24 |
+
token: ${{ secrets.HF_TOKEN }} # Hugging Face token
|
25 |
+
python_version: 3.11 # Python version
|
.gitignore
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# mac
|
2 |
+
.DS_Store
|
3 |
+
|
4 |
+
# cache
|
5 |
+
__pycache__
|
6 |
+
|
7 |
+
# datasets
|
8 |
+
datasets
|
9 |
+
|
10 |
+
# papers
|
11 |
+
docs/papers
|
12 |
+
|
13 |
+
# uv
|
14 |
+
.venv
|
15 |
+
|
16 |
+
# gradio
|
17 |
+
.gradio
|
18 |
+
|
19 |
+
# env
|
20 |
+
.env
|
21 |
+
|
22 |
+
# checkpoint
|
23 |
+
.checkpoints/*
|
24 |
+
!.checkpoints/.gitkeep
|
.python-version
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
3.11
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2025 RyanLin
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: menu text detection
|
3 |
+
emoji: 🦄
|
4 |
+
colorFrom: indigo
|
5 |
+
colorTo: pink
|
6 |
+
sdk: gradio
|
7 |
+
python_version: 3.11
|
8 |
+
short_description: Extract structured menu information from images into JSON...
|
9 |
+
tags: [ "donut","fine-tuning","image-to-text","transformer" ]
|
10 |
+
---
|
11 |
+
|
12 |
+
# Menu Text Detection System
|
13 |
+
|
14 |
+
Extract structured menu information from images into JSON using a fine-tuned E2E model or LLM.
|
15 |
+
|
16 |
+
[](https://huggingface.co/spaces/ryanlinjui/menu-text-detection)
|
17 |
+
[](https://huggingface.co/collections/ryanlinjui/menu-text-detection-670ccf527626bb004bbfb39b)
|
18 |
+
|
19 |
+
https://github.com/user-attachments/assets/80e5d54c-f2c8-4593-ad9b-499e5b71d8f6
|
20 |
+
|
21 |
+
## 🚀 Features
|
22 |
+
### Overview
|
23 |
+
Currently supports the following information from menu images:
|
24 |
+
|
25 |
+
- **Restaurant Name**
|
26 |
+
- **Business Hours**
|
27 |
+
- **Address**
|
28 |
+
- **Phone Number**
|
29 |
+
- **Dish Information**
|
30 |
+
- Name
|
31 |
+
- Price
|
32 |
+
|
33 |
+
> For the JSON schema, see [tools directory](./tools).
|
34 |
+
|
35 |
+
### Supported Methods to Extract Menu Information
|
36 |
+
#### Fine-tuned E2E model and Training metrics
|
37 |
+
- [**Donut (Document Parsing Task)**](https://huggingface.co/ryanlinjui/donut-base-finetuned-menu) - Base model by [Clova AI (ECCV ’22)](https://github.com/clovaai/donut)
|
38 |
+
|
39 |
+
#### LLM Function Calling
|
40 |
+
- Google Gemini API
|
41 |
+
- OpenAI GPT API
|
42 |
+
|
43 |
+
## 💻 Training / Fine-Tuning
|
44 |
+
### Setup
|
45 |
+
Use [uv](https://github.com/astral-sh/uv) to set up the development environment:
|
46 |
+
|
47 |
+
```bash
|
48 |
+
uv sync
|
49 |
+
```
|
50 |
+
|
51 |
+
> or use `pip install -r requirements.txt` if it has any problems
|
52 |
+
|
53 |
+
### Training Script (Datasets collecting, Fine-Tuning)
|
54 |
+
Please refer [`train.ipynb`](./train.ipynb). Use Jupyter Notebook for training:
|
55 |
+
|
56 |
+
```bash
|
57 |
+
uv run jupyter-notebook
|
58 |
+
```
|
59 |
+
|
60 |
+
> For VSCode users, please install Jupyter extension, then select `.venv/bin/python` as your kernel.
|
61 |
+
|
62 |
+
### Run Demo Locally
|
63 |
+
```bash
|
64 |
+
uv run python app.py
|
65 |
+
```
|
app.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
|
4 |
+
import gradio as gr
|
5 |
+
from PIL import Image
|
6 |
+
from dotenv import load_dotenv
|
7 |
+
from pillow_heif import register_heif_opener
|
8 |
+
|
9 |
+
from menu.llm import (
|
10 |
+
GeminiAPI,
|
11 |
+
OpenAIAPI
|
12 |
+
)
|
13 |
+
from menu.donut import DonutFinetuned
|
14 |
+
|
15 |
+
register_heif_opener()
|
16 |
+
load_dotenv(override=True)
|
17 |
+
GEMINI_API_TOKEN = os.getenv("GEMINI_API_TOKEN", "")
|
18 |
+
OPENAI_API_TOKEN = os.getenv("OPENAI_API_TOKEN", "")
|
19 |
+
|
20 |
+
SOURCE_CODE_GH_URL = "https://github.com/ryanlinjui/menu-text-detection"
|
21 |
+
BADGE_URL = "https://img.shields.io/badge/GitHub_Code-Click_Here!!-default?logo=github"
|
22 |
+
|
23 |
+
GITHUB_RAW_URL = "https://raw.githubusercontent.com/ryanlinjui/menu-text-detection/main"
|
24 |
+
EXAMPLE_IMAGE_LIST = [
|
25 |
+
f"{GITHUB_RAW_URL}/examples/menu-hd.jpg",
|
26 |
+
f"{GITHUB_RAW_URL}/examples/menu-vs.jpg",
|
27 |
+
f"{GITHUB_RAW_URL}/examples/menu-si.jpg"
|
28 |
+
]
|
29 |
+
FINETUNED_MODEL_LIST = [
|
30 |
+
"Donut (Document Parsing Task) Fine-tuned Model"
|
31 |
+
]
|
32 |
+
LLM_MODEL_LIST = [
|
33 |
+
"gemini-2.5-pro",
|
34 |
+
"gemini-2.5-flash",
|
35 |
+
"gemini-2.0-flash",
|
36 |
+
"gpt-4.1",
|
37 |
+
"gpt-4o",
|
38 |
+
"o4-mini"
|
39 |
+
]
|
40 |
+
|
41 |
+
donut_finetuned = DonutFinetuned("ryanlinjui/donut-base-finetuned-menu")
|
42 |
+
|
43 |
+
def handle(image: Image.Image, model: str, api_token: str) -> str:
|
44 |
+
if image is None:
|
45 |
+
raise gr.Error("Please upload an image first.")
|
46 |
+
|
47 |
+
if model == FINETUNED_MODEL_LIST[0]:
|
48 |
+
result = donut_finetuned.predict(image)
|
49 |
+
|
50 |
+
elif model in LLM_MODEL_LIST:
|
51 |
+
if len(api_token) < 10:
|
52 |
+
raise gr.Error(f"Please provide a valid token for {model}.")
|
53 |
+
try:
|
54 |
+
if model in LLM_MODEL_LIST[:3]:
|
55 |
+
result = GeminiAPI.call(image, model, api_token)
|
56 |
+
else:
|
57 |
+
result = OpenAIAPI.call(image, model, api_token)
|
58 |
+
except Exception as e:
|
59 |
+
raise gr.Error(f"Failed to process with API model {model}: {str(e)}")
|
60 |
+
else:
|
61 |
+
raise gr.Error("Invalid model selection. Please choose a valid model.")
|
62 |
+
|
63 |
+
return json.dumps(result, indent=4, ensure_ascii=False, sort_keys=True)
|
64 |
+
|
65 |
+
def UserInterface() -> gr.Interface:
|
66 |
+
with gr.Blocks(
|
67 |
+
delete_cache=(86400, 86400),
|
68 |
+
css="""
|
69 |
+
.image-panel {
|
70 |
+
display: flex;
|
71 |
+
flex-direction: column;
|
72 |
+
height: 600px;
|
73 |
+
}
|
74 |
+
.image-panel img {
|
75 |
+
object-fit: contain;
|
76 |
+
max-height: 600px;
|
77 |
+
max-width: 600px;
|
78 |
+
width: 100%;
|
79 |
+
}
|
80 |
+
.large-text textarea {
|
81 |
+
font-size: 20px !important;
|
82 |
+
height: 600px !important;
|
83 |
+
width: 100% !important;
|
84 |
+
}
|
85 |
+
"""
|
86 |
+
) as gradio_interface:
|
87 |
+
gr.HTML(f'<a href="{SOURCE_CODE_GH_URL}"><img src="{BADGE_URL}" alt="GitHub Code"/></a>')
|
88 |
+
gr.Markdown("# Menu Text Detection")
|
89 |
+
|
90 |
+
with gr.Row():
|
91 |
+
with gr.Column(scale=1, min_width=500):
|
92 |
+
gr.Markdown("## 📷 Menu Image")
|
93 |
+
menu_image = gr.Image(
|
94 |
+
type="pil",
|
95 |
+
label="Input menu image",
|
96 |
+
elem_classes="image-panel"
|
97 |
+
)
|
98 |
+
|
99 |
+
gr.Markdown("## 🤖 Model Selection")
|
100 |
+
model_choice_dropdown = gr.Dropdown(
|
101 |
+
choices=FINETUNED_MODEL_LIST + LLM_MODEL_LIST,
|
102 |
+
value=FINETUNED_MODEL_LIST[0],
|
103 |
+
label="Select Text Detection Model"
|
104 |
+
)
|
105 |
+
|
106 |
+
api_token_textbox = gr.Textbox(
|
107 |
+
label="API Token",
|
108 |
+
placeholder="Enter your API token here...",
|
109 |
+
type="password",
|
110 |
+
visible=False
|
111 |
+
)
|
112 |
+
|
113 |
+
generate_button = gr.Button("Generate Menu Information", variant="primary")
|
114 |
+
|
115 |
+
gr.Examples(
|
116 |
+
examples=EXAMPLE_IMAGE_LIST,
|
117 |
+
inputs=menu_image,
|
118 |
+
label="Example Menu Images"
|
119 |
+
)
|
120 |
+
|
121 |
+
with gr.Column(scale=1):
|
122 |
+
gr.Markdown("## 🍽️ Menu Info")
|
123 |
+
menu_json_textbox = gr.Textbox(
|
124 |
+
label="Ouput JSON",
|
125 |
+
interactive=True,
|
126 |
+
text_align="left",
|
127 |
+
elem_classes="large-text"
|
128 |
+
)
|
129 |
+
|
130 |
+
def update_token_visibility(choice):
|
131 |
+
if choice in LLM_MODEL_LIST:
|
132 |
+
current_token = ""
|
133 |
+
if choice in LLM_MODEL_LIST[:3]:
|
134 |
+
current_token = GEMINI_API_TOKEN
|
135 |
+
else:
|
136 |
+
current_token = OPENAI_API_TOKEN
|
137 |
+
return gr.update(visible=True, value=current_token)
|
138 |
+
else:
|
139 |
+
return gr.update(visible=False)
|
140 |
+
|
141 |
+
model_choice_dropdown.change(
|
142 |
+
fn=update_token_visibility,
|
143 |
+
inputs=model_choice_dropdown,
|
144 |
+
outputs=api_token_textbox
|
145 |
+
)
|
146 |
+
|
147 |
+
generate_button.click(
|
148 |
+
fn=handle,
|
149 |
+
inputs=[menu_image, model_choice_dropdown, api_token_textbox],
|
150 |
+
outputs=menu_json_textbox
|
151 |
+
)
|
152 |
+
|
153 |
+
return gradio_interface
|
154 |
+
|
155 |
+
if __name__ == "__main__":
|
156 |
+
demo = UserInterface()
|
157 |
+
demo.launch()
|
menu/donut.py
ADDED
@@ -0,0 +1,472 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This file is modified from the HuggingFace transformers tutorial script for fine-tuning Donut on a custom dataset.
|
3 |
+
It's defined from `.ipynb` to the module implementation for better reusability and maintainability.
|
4 |
+
Reference: https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Donut/CORD/Fine_tune_Donut_on_a_custom_dataset_(CORD)_with_PyTorch_Lightning.ipynb
|
5 |
+
"""
|
6 |
+
|
7 |
+
import re
|
8 |
+
import random
|
9 |
+
from typing import Any, List, Tuple, Dict
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import numpy as np
|
13 |
+
from PIL import Image
|
14 |
+
from tqdm.auto import tqdm
|
15 |
+
from nltk import edit_distance
|
16 |
+
import pytorch_lightning as pl
|
17 |
+
from datasets import DatasetDict
|
18 |
+
from donut import JSONParseEvaluator
|
19 |
+
from huggingface_hub import upload_folder
|
20 |
+
from pillow_heif import register_heif_opener
|
21 |
+
from pytorch_lightning.callbacks import Callback
|
22 |
+
from pytorch_lightning.loggers import TensorBoardLogger
|
23 |
+
from torch.utils.data import (
|
24 |
+
Dataset,
|
25 |
+
DataLoader
|
26 |
+
)
|
27 |
+
from transformers import (
|
28 |
+
DonutProcessor,
|
29 |
+
VisionEncoderDecoderModel,
|
30 |
+
VisionEncoderDecoderConfig
|
31 |
+
)
|
32 |
+
|
33 |
+
TASK_PROMPT_NAME = "<s_menu-text-detection>"
|
34 |
+
register_heif_opener()
|
35 |
+
|
36 |
+
class DonutFinetuned:
|
37 |
+
def __init__(self, pretrained_model_repo_id: str = "ryanlinjui/donut-test"):
|
38 |
+
self.device = (
|
39 |
+
"cuda"
|
40 |
+
if torch.cuda.is_available()
|
41 |
+
else "mps" if torch.backends.mps.is_available() else "cpu"
|
42 |
+
)
|
43 |
+
self.processor = DonutProcessor.from_pretrained(pretrained_model_repo_id)
|
44 |
+
self.model = VisionEncoderDecoderModel.from_pretrained(pretrained_model_repo_id)
|
45 |
+
self.model.eval()
|
46 |
+
self.model.to(self.device)
|
47 |
+
print(f"Using {self.device} device")
|
48 |
+
|
49 |
+
def predict(self, image: Image.Image) -> Dict[str, Any]:
|
50 |
+
# prepare encoder inputs
|
51 |
+
pixel_values = self.processor(image.convert("RGB"), return_tensors="pt").pixel_values
|
52 |
+
pixel_values = pixel_values.to(self.device)
|
53 |
+
|
54 |
+
# prepare decoder inputs
|
55 |
+
decoder_input_ids = self.processor.tokenizer(TASK_PROMPT_NAME, add_special_tokens=False, return_tensors="pt").input_ids
|
56 |
+
decoder_input_ids = decoder_input_ids.to(self.device)
|
57 |
+
|
58 |
+
# autoregressively generate sequence
|
59 |
+
outputs = self.model.generate(
|
60 |
+
pixel_values,
|
61 |
+
decoder_input_ids=decoder_input_ids,
|
62 |
+
max_length=self.model.decoder.config.max_position_embeddings,
|
63 |
+
early_stopping=True,
|
64 |
+
pad_token_id=self.processor.tokenizer.pad_token_id,
|
65 |
+
eos_token_id=self.processor.tokenizer.eos_token_id,
|
66 |
+
use_cache=True,
|
67 |
+
num_beams=1,
|
68 |
+
bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
|
69 |
+
return_dict_in_generate=True
|
70 |
+
)
|
71 |
+
|
72 |
+
# turn into JSON
|
73 |
+
seq = self.processor.batch_decode(outputs.sequences)[0]
|
74 |
+
seq = seq.replace(self.processor.tokenizer.eos_token, "").replace(self.processor.tokenizer.pad_token, "")
|
75 |
+
seq = re.sub(r"<.*?>", "", seq, count=1).strip() # remove first task start token
|
76 |
+
seq = self.processor.token2json(seq)
|
77 |
+
return seq
|
78 |
+
|
79 |
+
def evaluate(self, dataset: Dataset, ground_truth_key: str = "ground_truth") -> Tuple[Dict[str, Any], List[Any]]:
|
80 |
+
output_list = []
|
81 |
+
accs = []
|
82 |
+
ted_accs = []
|
83 |
+
f1_accs = []
|
84 |
+
|
85 |
+
for idx, sample in tqdm(enumerate(dataset), total=len(dataset)):
|
86 |
+
seq = self.predict(sample["image"])
|
87 |
+
ground_truth = sample[ground_truth_key]
|
88 |
+
|
89 |
+
# Original JSON accuracy
|
90 |
+
evaluator = JSONParseEvaluator()
|
91 |
+
score = evaluator.cal_acc(seq, ground_truth)
|
92 |
+
accs.append(score)
|
93 |
+
output_list.append(seq)
|
94 |
+
|
95 |
+
# TED (Tree Edit Distance) Accuracy
|
96 |
+
# Convert predictions and ground truth to string format for comparison
|
97 |
+
pred_str = str(seq) if seq else ""
|
98 |
+
gt_str = str(ground_truth) if ground_truth else ""
|
99 |
+
|
100 |
+
# Calculate normalized edit distance (1 - normalized_edit_distance = accuracy)
|
101 |
+
if len(pred_str) == 0 and len(gt_str) == 0:
|
102 |
+
ted_acc = 1.0
|
103 |
+
elif len(pred_str) == 0 or len(gt_str) == 0:
|
104 |
+
ted_acc = 0.0
|
105 |
+
else:
|
106 |
+
edit_dist = edit_distance(pred_str, gt_str)
|
107 |
+
max_len = max(len(pred_str), len(gt_str))
|
108 |
+
ted_acc = 1 - (edit_dist / max_len)
|
109 |
+
ted_accs.append(ted_acc)
|
110 |
+
|
111 |
+
# F1 Score Accuracy (character-level)
|
112 |
+
if len(pred_str) == 0 and len(gt_str) == 0:
|
113 |
+
f1_acc = 1.0
|
114 |
+
elif len(pred_str) == 0 or len(gt_str) == 0:
|
115 |
+
f1_acc = 0.0
|
116 |
+
else:
|
117 |
+
# Character-level precision and recall
|
118 |
+
pred_chars = set(pred_str)
|
119 |
+
gt_chars = set(gt_str)
|
120 |
+
|
121 |
+
if len(pred_chars) == 0:
|
122 |
+
precision = 0.0
|
123 |
+
else:
|
124 |
+
precision = len(pred_chars.intersection(gt_chars)) / len(pred_chars)
|
125 |
+
|
126 |
+
if len(gt_chars) == 0:
|
127 |
+
recall = 0.0
|
128 |
+
else:
|
129 |
+
recall = len(pred_chars.intersection(gt_chars)) / len(gt_chars)
|
130 |
+
|
131 |
+
if precision + recall == 0:
|
132 |
+
f1_acc = 0.0
|
133 |
+
else:
|
134 |
+
f1_acc = 2 * (precision * recall) / (precision + recall)
|
135 |
+
f1_accs.append(f1_acc)
|
136 |
+
|
137 |
+
scores = {
|
138 |
+
"accuracies": accs,
|
139 |
+
"mean_accuracy": np.mean(accs),
|
140 |
+
"ted_accuracies": ted_accs,
|
141 |
+
"mean_ted_accuracy": np.mean(ted_accs),
|
142 |
+
"f1_accuracies": f1_accs,
|
143 |
+
"mean_f1_accuracy": np.mean(f1_accs),
|
144 |
+
"length": len(accs)
|
145 |
+
}
|
146 |
+
return scores, output_list
|
147 |
+
|
148 |
+
class DonutTrainer:
|
149 |
+
processor = None
|
150 |
+
max_length = 768
|
151 |
+
image_size = [1280, 960]
|
152 |
+
added_tokens = []
|
153 |
+
train_dataloader = None
|
154 |
+
val_dataloader = None
|
155 |
+
huggingface_model_id = None
|
156 |
+
|
157 |
+
class DonutDataset(Dataset):
|
158 |
+
"""
|
159 |
+
PyTorch Dataset for Donut. This class takes a HuggingFace Dataset as input.
|
160 |
+
|
161 |
+
Each row, consists of image path(png/jpg/jpeg) and gt data (json/jsonl/txt),
|
162 |
+
and it will be converted into pixel_values (vectorized image) and labels (input_ids of the tokenized string).
|
163 |
+
|
164 |
+
Args:
|
165 |
+
dataset: HuggingFace DatasetDict containing the dataset to be used
|
166 |
+
max_length: the max number of tokens for the target sequences
|
167 |
+
split: whether to load "train", "validation" or "test" split
|
168 |
+
ignore_id: ignore_index for torch.nn.CrossEntropyLoss
|
169 |
+
task_start_token: the special token to be fed to the decoder to conduct the target task
|
170 |
+
prompt_end_token: the special token at the end of the sequences
|
171 |
+
sort_json_key: whether or not to sort the JSON keys
|
172 |
+
"""
|
173 |
+
|
174 |
+
def __init__(
|
175 |
+
self,
|
176 |
+
dataset: DatasetDict,
|
177 |
+
ground_truth_key: str,
|
178 |
+
max_length: int,
|
179 |
+
split: str = "train",
|
180 |
+
ignore_id: int = -100,
|
181 |
+
task_start_token: str = "<s>",
|
182 |
+
prompt_end_token: str = None,
|
183 |
+
sort_json_key: bool = True,
|
184 |
+
):
|
185 |
+
super().__init__()
|
186 |
+
|
187 |
+
self.dataset = dataset[split]
|
188 |
+
self.ground_truth_key = ground_truth_key
|
189 |
+
self.max_length = max_length
|
190 |
+
self.split = split
|
191 |
+
self.ignore_id = ignore_id
|
192 |
+
self.task_start_token = task_start_token
|
193 |
+
self.prompt_end_token = prompt_end_token if prompt_end_token else task_start_token
|
194 |
+
self.sort_json_key = sort_json_key
|
195 |
+
|
196 |
+
self.dataset_length = len(self.dataset)
|
197 |
+
|
198 |
+
self.gt_token_sequences = []
|
199 |
+
for sample in self.dataset:
|
200 |
+
ground_truth = sample[self.ground_truth_key]
|
201 |
+
self.gt_token_sequences.append(
|
202 |
+
[
|
203 |
+
self.json2token(
|
204 |
+
gt_json,
|
205 |
+
update_special_tokens_for_json_key=self.split == "train",
|
206 |
+
sort_json_key=self.sort_json_key,
|
207 |
+
)
|
208 |
+
+ DonutTrainer.processor.tokenizer.eos_token
|
209 |
+
for gt_json in [ground_truth] # load json from list of json
|
210 |
+
]
|
211 |
+
)
|
212 |
+
|
213 |
+
self.add_tokens([self.task_start_token, self.prompt_end_token])
|
214 |
+
self.prompt_end_token_id = DonutTrainer.processor.tokenizer.convert_tokens_to_ids(self.prompt_end_token)
|
215 |
+
|
216 |
+
def json2token(self, obj: Any, update_special_tokens_for_json_key: bool = True, sort_json_key: bool = True):
|
217 |
+
"""
|
218 |
+
Convert an ordered JSON object into a token sequence
|
219 |
+
"""
|
220 |
+
if type(obj) == dict:
|
221 |
+
if len(obj) == 1 and "text_sequence" in obj:
|
222 |
+
return obj["text_sequence"]
|
223 |
+
else:
|
224 |
+
output = ""
|
225 |
+
if sort_json_key:
|
226 |
+
keys = sorted(obj.keys(), reverse=True)
|
227 |
+
else:
|
228 |
+
keys = obj.keys()
|
229 |
+
for k in keys:
|
230 |
+
if update_special_tokens_for_json_key:
|
231 |
+
self.add_tokens([fr"<s_{k}>", fr"</s_{k}>"])
|
232 |
+
output += (
|
233 |
+
fr"<s_{k}>"
|
234 |
+
+ self.json2token(obj[k], update_special_tokens_for_json_key, sort_json_key)
|
235 |
+
+ fr"</s_{k}>"
|
236 |
+
)
|
237 |
+
return output
|
238 |
+
elif type(obj) == list:
|
239 |
+
return r"<sep/>".join(
|
240 |
+
[self.json2token(item, update_special_tokens_for_json_key, sort_json_key) for item in obj]
|
241 |
+
)
|
242 |
+
else:
|
243 |
+
obj = str(obj)
|
244 |
+
if f"<{obj}/>" in DonutTrainer.added_tokens:
|
245 |
+
obj = f"<{obj}/>" # for categorical special tokens
|
246 |
+
return obj
|
247 |
+
|
248 |
+
def add_tokens(self, list_of_tokens: List[str]):
|
249 |
+
"""
|
250 |
+
Add special tokens to tokenizer and resize the token embeddings of the decoder
|
251 |
+
"""
|
252 |
+
newly_added_num = DonutTrainer.processor.tokenizer.add_tokens(list_of_tokens)
|
253 |
+
if newly_added_num > 0:
|
254 |
+
DonutTrainer.model.decoder.resize_token_embeddings(len(DonutTrainer.processor.tokenizer))
|
255 |
+
DonutTrainer.added_tokens.extend(list_of_tokens)
|
256 |
+
|
257 |
+
def __len__(self) -> int:
|
258 |
+
return self.dataset_length
|
259 |
+
|
260 |
+
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
261 |
+
"""
|
262 |
+
Load image from image_path of given dataset_path and convert into input_tensor and labels
|
263 |
+
Convert gt data into input_ids (tokenized string)
|
264 |
+
Returns:
|
265 |
+
input_tensor : preprocessed image
|
266 |
+
input_ids : tokenized gt_data
|
267 |
+
labels : masked labels (model doesn't need to predict prompt and pad token)
|
268 |
+
"""
|
269 |
+
sample = self.dataset[idx]
|
270 |
+
|
271 |
+
# inputs
|
272 |
+
pixel_values = DonutTrainer.processor(sample["image"], random_padding=self.split == "train", return_tensors="pt").pixel_values
|
273 |
+
pixel_values = pixel_values.squeeze()
|
274 |
+
|
275 |
+
# targets
|
276 |
+
target_sequence = random.choice(self.gt_token_sequences[idx]) # can be more than one, e.g., DocVQA Task 1
|
277 |
+
input_ids = DonutTrainer.processor.tokenizer(
|
278 |
+
target_sequence,
|
279 |
+
add_special_tokens=False,
|
280 |
+
max_length=self.max_length,
|
281 |
+
padding="max_length",
|
282 |
+
truncation=True,
|
283 |
+
return_tensors="pt",
|
284 |
+
)["input_ids"].squeeze(0)
|
285 |
+
|
286 |
+
labels = input_ids.clone()
|
287 |
+
labels[labels == DonutTrainer.processor.tokenizer.pad_token_id] = self.ignore_id # model doesn't need to predict pad token
|
288 |
+
# labels[: torch.nonzero(labels == self.prompt_end_token_id).sum() + 1] = self.ignore_id # model doesn't need to predict prompt (for VQA)
|
289 |
+
return pixel_values, labels, target_sequence
|
290 |
+
|
291 |
+
|
292 |
+
class DonutModelPLModule(pl.LightningModule):
|
293 |
+
def __init__(self, config, processor, model):
|
294 |
+
super().__init__()
|
295 |
+
self.config = config
|
296 |
+
self.processor = processor
|
297 |
+
self.model = model
|
298 |
+
|
299 |
+
def training_step(self, batch, batch_idx):
|
300 |
+
pixel_values, labels, _ = batch
|
301 |
+
|
302 |
+
outputs = self.model(pixel_values, labels=labels)
|
303 |
+
loss = outputs.loss
|
304 |
+
self.log("train_loss", loss)
|
305 |
+
return loss
|
306 |
+
|
307 |
+
def validation_step(self, batch, batch_idx, dataset_idx=0):
|
308 |
+
pixel_values, labels, answers = batch
|
309 |
+
batch_size = pixel_values.shape[0]
|
310 |
+
# we feed the prompt to the model
|
311 |
+
decoder_input_ids = torch.full((batch_size, 1), self.model.config.decoder_start_token_id, device=self.device)
|
312 |
+
|
313 |
+
outputs = self.model.generate(pixel_values,
|
314 |
+
decoder_input_ids=decoder_input_ids,
|
315 |
+
max_length=DonutTrainer.max_length,
|
316 |
+
early_stopping=True,
|
317 |
+
pad_token_id=self.processor.tokenizer.pad_token_id,
|
318 |
+
eos_token_id=self.processor.tokenizer.eos_token_id,
|
319 |
+
use_cache=True,
|
320 |
+
num_beams=1,
|
321 |
+
bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
|
322 |
+
return_dict_in_generate=True,)
|
323 |
+
|
324 |
+
predictions = []
|
325 |
+
for seq in self.processor.tokenizer.batch_decode(outputs.sequences):
|
326 |
+
seq = seq.replace(self.processor.tokenizer.eos_token, "").replace(self.processor.tokenizer.pad_token, "")
|
327 |
+
seq = re.sub(r"<.*?>", "", seq, count=1).strip() # remove first task start token
|
328 |
+
predictions.append(seq)
|
329 |
+
|
330 |
+
scores = []
|
331 |
+
for pred, answer in zip(predictions, answers):
|
332 |
+
pred = re.sub(r"(?:(?<=>) | (?=</s_))", "", pred)
|
333 |
+
# NOT NEEDED ANYMORE
|
334 |
+
# answer = re.sub(r"<.*?>", "", answer, count=1)
|
335 |
+
answer = answer.replace(self.processor.tokenizer.eos_token, "")
|
336 |
+
scores.append(edit_distance(pred, answer) / max(len(pred), len(answer)))
|
337 |
+
|
338 |
+
if self.config.get("verbose", False) and len(scores) == 1:
|
339 |
+
print(f"Prediction: {pred}")
|
340 |
+
print(f" Answer: {answer}")
|
341 |
+
print(f" Normed ED: {scores[0]}")
|
342 |
+
|
343 |
+
val_edit_distance = np.mean(scores)
|
344 |
+
self.log("val_edit_distance", val_edit_distance)
|
345 |
+
print(f"Validation Edit Distance: {val_edit_distance}")
|
346 |
+
|
347 |
+
return scores
|
348 |
+
|
349 |
+
def configure_optimizers(self):
|
350 |
+
# you could also add a learning rate scheduler if you want
|
351 |
+
optimizer = torch.optim.Adam(self.parameters(), lr=self.config.get("lr"))
|
352 |
+
|
353 |
+
return optimizer
|
354 |
+
|
355 |
+
def train_dataloader(self):
|
356 |
+
return DonutTrainer.train_dataloader
|
357 |
+
|
358 |
+
def val_dataloader(self):
|
359 |
+
return DonutTrainer.val_dataloader
|
360 |
+
|
361 |
+
class PushToHubCallback(Callback):
|
362 |
+
def on_train_epoch_end(self, trainer, pl_module):
|
363 |
+
print(f"Pushing model to the hub, epoch {trainer.current_epoch}")
|
364 |
+
pl_module.model.push_to_hub(DonutTrainer.huggingface_model_id, commit_message=f"Training in progress, epoch {trainer.current_epoch}")
|
365 |
+
self._upload_logs(trainer.logger.log_dir, trainer.current_epoch)
|
366 |
+
|
367 |
+
def on_train_end(self, trainer, pl_module):
|
368 |
+
print(f"Pushing model to the hub after training")
|
369 |
+
pl_module.processor.push_to_hub(DonutTrainer.huggingface_model_id,commit_message=f"Training done")
|
370 |
+
pl_module.model.push_to_hub(DonutTrainer.huggingface_model_id, commit_message=f"Training done")
|
371 |
+
self._upload_logs(trainer.logger.log_dir, "final")
|
372 |
+
|
373 |
+
def _upload_logs(self, log_dir: str, epoch_info):
|
374 |
+
try:
|
375 |
+
print(f"Attempting to upload logs from: {log_dir}")
|
376 |
+
upload_folder(folder_path=log_dir, repo_id=DonutTrainer.huggingface_model_id,
|
377 |
+
path_in_repo="tensorboard_logs",
|
378 |
+
commit_message=f"Upload logs - epoch {epoch_info}", ignore_patterns=["*.tmp", "*.lock"])
|
379 |
+
print(f"Successfully uploaded logs for epoch {epoch_info}")
|
380 |
+
except Exception as e:
|
381 |
+
print(f"Failed to upload logs: {e}")
|
382 |
+
pass
|
383 |
+
|
384 |
+
@classmethod
|
385 |
+
def train(
|
386 |
+
cls,
|
387 |
+
dataset: DatasetDict,
|
388 |
+
pretrained_model_repo_id: str,
|
389 |
+
huggingface_model_id: str,
|
390 |
+
epochs: int,
|
391 |
+
train_batch_size: int,
|
392 |
+
val_batch_size: int,
|
393 |
+
learning_rate: float,
|
394 |
+
val_check_interval: float,
|
395 |
+
check_val_every_n_epoch: int,
|
396 |
+
gradient_clip_val: float,
|
397 |
+
num_training_samples_per_epoch: int,
|
398 |
+
num_nodes: int,
|
399 |
+
warmup_steps: int,
|
400 |
+
ground_truth_key: str = "ground_truth"
|
401 |
+
):
|
402 |
+
cls.huggingface_model_id = huggingface_model_id
|
403 |
+
config = VisionEncoderDecoderConfig.from_pretrained(pretrained_model_repo_id)
|
404 |
+
config.encoder.image_size = cls.image_size
|
405 |
+
config.decoder.max_length = cls.max_length
|
406 |
+
|
407 |
+
cls.processor = DonutProcessor.from_pretrained(pretrained_model_repo_id)
|
408 |
+
cls.model = VisionEncoderDecoderModel.from_pretrained(pretrained_model_repo_id, config=config)
|
409 |
+
cls.processor.image_processor.size = cls.image_size[::-1]
|
410 |
+
cls.processor.image_processor.do_align_long_axis = False
|
411 |
+
|
412 |
+
train_dataset = cls.DonutDataset(
|
413 |
+
dataset=dataset,
|
414 |
+
ground_truth_key=ground_truth_key,
|
415 |
+
max_length=cls.max_length,
|
416 |
+
split="train",
|
417 |
+
task_start_token=TASK_PROMPT_NAME,
|
418 |
+
prompt_end_token=TASK_PROMPT_NAME,
|
419 |
+
sort_json_key=True
|
420 |
+
)
|
421 |
+
val_dataset = cls.DonutDataset(
|
422 |
+
dataset=dataset,
|
423 |
+
ground_truth_key=ground_truth_key,
|
424 |
+
max_length=cls.max_length,
|
425 |
+
split="validation",
|
426 |
+
task_start_token=TASK_PROMPT_NAME,
|
427 |
+
prompt_end_token=TASK_PROMPT_NAME,
|
428 |
+
sort_json_key=True
|
429 |
+
)
|
430 |
+
|
431 |
+
cls.model.config.pad_token_id = cls.processor.tokenizer.pad_token_id
|
432 |
+
cls.model.config.decoder_start_token_id = cls.processor.tokenizer.convert_tokens_to_ids([TASK_PROMPT_NAME])[0]
|
433 |
+
|
434 |
+
cls.train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=4)
|
435 |
+
cls.val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=4)
|
436 |
+
|
437 |
+
config = {
|
438 |
+
"max_epochs": epochs,
|
439 |
+
"val_check_interval": val_check_interval, # how many times we want to validate during an epoch
|
440 |
+
"check_val_every_n_epoch": check_val_every_n_epoch,
|
441 |
+
"gradient_clip_val": gradient_clip_val,
|
442 |
+
"num_training_samples_per_epoch": num_training_samples_per_epoch,
|
443 |
+
"lr": learning_rate,
|
444 |
+
"train_batch_sizes": [train_batch_size],
|
445 |
+
"val_batch_sizes": [val_batch_size],
|
446 |
+
# "seed":2022,
|
447 |
+
"num_nodes": num_nodes,
|
448 |
+
"warmup_steps": warmup_steps, # 10%
|
449 |
+
"result_path": "./.checkpoints",
|
450 |
+
"verbose": True,
|
451 |
+
}
|
452 |
+
model_module = cls.DonutModelPLModule(config, cls.processor, cls.model)
|
453 |
+
|
454 |
+
device = (
|
455 |
+
"cuda"
|
456 |
+
if torch.cuda.is_available()
|
457 |
+
else "mps" if torch.backends.mps.is_available() else "cpu"
|
458 |
+
)
|
459 |
+
print(f"Using {device} device")
|
460 |
+
trainer = pl.Trainer(
|
461 |
+
accelerator="gpu" if device == "cuda" else "mps" if device == "mps" else "cpu",
|
462 |
+
devices=1 if device == "cuda" else 0,
|
463 |
+
max_epochs=config.get("max_epochs"),
|
464 |
+
val_check_interval=config.get("val_check_interval"),
|
465 |
+
check_val_every_n_epoch=config.get("check_val_every_n_epoch"),
|
466 |
+
gradient_clip_val=config.get("gradient_clip_val"),
|
467 |
+
precision=16 if device == "cuda" else 32, # we'll use mixed precision if device == "cuda"
|
468 |
+
num_sanity_val_steps=0,
|
469 |
+
logger=TensorBoardLogger(save_dir="./.checkpoints", name="donut_training", version=None),
|
470 |
+
callbacks=[cls.PushToHubCallback()]
|
471 |
+
)
|
472 |
+
trainer.fit(model_module)
|
menu/llm/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .gemini import GeminiAPI
|
2 |
+
from .openai import OpenAIAPI
|
menu/llm/base.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
class LLMBase(ABC):
|
6 |
+
@classmethod
|
7 |
+
@abstractmethod
|
8 |
+
def call(image: np.ndarray, model: str, token: str) -> dict:
|
9 |
+
raise NotImplementedError
|
menu/llm/gemini.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
from PIL import Image
|
5 |
+
from google import genai
|
6 |
+
from google.genai import types
|
7 |
+
|
8 |
+
from .base import LLMBase
|
9 |
+
|
10 |
+
FUNCTION_CALL = json.load(open("tools/schema_gemini.json", "r"))
|
11 |
+
|
12 |
+
class GeminiAPI(LLMBase):
|
13 |
+
@classmethod
|
14 |
+
def call(cls, image: np.ndarray, model: str, token: str) -> dict:
|
15 |
+
client = genai.Client(api_key=token) # Initialize the client with the API key
|
16 |
+
encode_img = Image.fromarray(image) # Convert the image for the API
|
17 |
+
|
18 |
+
config = types.GenerateContentConfig(
|
19 |
+
tools=[types.Tool(function_declarations=[FUNCTION_CALL])],
|
20 |
+
tool_config={
|
21 |
+
"function_calling_config": {
|
22 |
+
"mode": "ANY",
|
23 |
+
"allowed_function_names": [FUNCTION_CALL["name"]]
|
24 |
+
}
|
25 |
+
}
|
26 |
+
)
|
27 |
+
response = client.models.generate_content(
|
28 |
+
model=model,
|
29 |
+
contents=[encode_img],
|
30 |
+
config=config
|
31 |
+
)
|
32 |
+
if response.candidates[0].content.parts[0].function_call:
|
33 |
+
function_call = response.candidates[0].content.parts[0].function_call
|
34 |
+
return function_call.args
|
35 |
+
|
36 |
+
return {}
|
menu/llm/openai.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import base64
|
3 |
+
from io import BytesIO
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
from PIL import Image
|
7 |
+
from openai import OpenAI
|
8 |
+
|
9 |
+
from .base import LLMBase
|
10 |
+
|
11 |
+
FUNCTION_CALL = json.load(open("tools/schema_openai.json", "r"))
|
12 |
+
|
13 |
+
class OpenAIAPI(LLMBase):
|
14 |
+
@classmethod
|
15 |
+
def call(cls, image: np.ndarray, model: str, token: str) -> dict:
|
16 |
+
client = OpenAI(api_key=token) # Initialize the client with the API key
|
17 |
+
buffer = BytesIO()
|
18 |
+
Image.fromarray(image).save(buffer, format="JPEG")
|
19 |
+
encode_img = base64.b64encode(buffer.getvalue()).decode("utf-8") # Convert the image for the API
|
20 |
+
|
21 |
+
response = client.responses.create(
|
22 |
+
model=model,
|
23 |
+
input=[
|
24 |
+
{
|
25 |
+
"role": "user",
|
26 |
+
"content": [
|
27 |
+
{
|
28 |
+
"type": "input_image",
|
29 |
+
"image_url": f"data:image/jpeg;base64,{encode_img}",
|
30 |
+
},
|
31 |
+
],
|
32 |
+
}
|
33 |
+
],
|
34 |
+
tools=[FUNCTION_CALL],
|
35 |
+
)
|
36 |
+
if response and response.output:
|
37 |
+
if hasattr(response.output[0], "arguments"):
|
38 |
+
return json.loads(response.output[0].arguments)
|
39 |
+
return {}
|
menu/utils.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
from datasets import Dataset, DatasetDict
|
4 |
+
|
5 |
+
def split_dataset(
|
6 |
+
dataset: Dataset,
|
7 |
+
train: float,
|
8 |
+
validation: float,
|
9 |
+
test: float,
|
10 |
+
seed: Optional[int] = None
|
11 |
+
) -> DatasetDict:
|
12 |
+
"""
|
13 |
+
Split a single-split Hugging Face Dataset into train/validation/test subsets.
|
14 |
+
|
15 |
+
Args:
|
16 |
+
dataset (Dataset): The input dataset (e.g. load_dataset(...)['train']).
|
17 |
+
train (float): Proportion of data for the train split (0 < train < 1).
|
18 |
+
val (float): Proportion of data for the validation split (0 < val < 1).
|
19 |
+
test (float): Proportion of data for the test split (0 < test < 1).
|
20 |
+
Must satisfy train + val + test == 1.0.
|
21 |
+
seed (int): Random seed for reproducibility (default: None).
|
22 |
+
|
23 |
+
Returns:
|
24 |
+
DatasetDict: A dictionary with keys "train", "validation", and "test".
|
25 |
+
"""
|
26 |
+
# Verify ratios sum to 1.0
|
27 |
+
total = train + validation + test
|
28 |
+
if abs(total - 1.0) > 1e-8:
|
29 |
+
raise ValueError(f"train + validation + test must equal 1.0 (got {total})")
|
30 |
+
|
31 |
+
# First split: extract train vs. temp (validation + test)
|
32 |
+
temp_size = validation + test
|
33 |
+
split_1 = dataset.train_test_split(test_size=temp_size, seed=seed)
|
34 |
+
train_ds = split_1["train"]
|
35 |
+
temp_ds = split_1["test"]
|
36 |
+
|
37 |
+
# Second split: divide temp into validation vs. test
|
38 |
+
relative_test_size = test / temp_size
|
39 |
+
split_2 = temp_ds.train_test_split(test_size=relative_test_size, seed=seed)
|
40 |
+
validation_ds = split_2["train"]
|
41 |
+
test_ds = split_2["test"]
|
42 |
+
|
43 |
+
# Return a DatasetDict with all three splits
|
44 |
+
return DatasetDict({
|
45 |
+
"train": train_ds,
|
46 |
+
"validation": validation_ds,
|
47 |
+
"test": test_ds,
|
48 |
+
})
|
pyproject.toml
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[project]
|
2 |
+
authors = [{name = "ryanlinjui", email = "ryanlinjui@gmail.com"}]
|
3 |
+
name = "menu-text-detection"
|
4 |
+
version = "0.1.0"
|
5 |
+
description = "Extract structured menu information from images into JSON using a fine-tuned Donut E2E model."
|
6 |
+
readme = "README.md"
|
7 |
+
requires-python = "==3.11.*"
|
8 |
+
dependencies = [
|
9 |
+
"datasets>=3.6.0",
|
10 |
+
"dotenv>=0.9.9",
|
11 |
+
"google-genai>=1.14.0",
|
12 |
+
"gradio>=5.29.0",
|
13 |
+
"huggingface-hub>=0.31.1",
|
14 |
+
"matplotlib>=3.10.1",
|
15 |
+
"nltk>=3.9.1",
|
16 |
+
"notebook>=7.4.2",
|
17 |
+
"openai>=1.77.0",
|
18 |
+
"pillow>=11.2.1",
|
19 |
+
"pillow-heif>=0.22.0",
|
20 |
+
"protobuf>=6.30.2",
|
21 |
+
"pytorch-lightning>=2.5.2",
|
22 |
+
"sentencepiece>=0.2.0",
|
23 |
+
"tensorboardx>=2.6.2.2",
|
24 |
+
"transformers==4.49",
|
25 |
+
"torch==2.4.1",
|
26 |
+
"donut-python>=1.0.9",
|
27 |
+
]
|
requirements.txt
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
aiofiles==24.1.0
|
2 |
+
aiohappyeyeballs==2.6.1
|
3 |
+
aiohttp==3.11.18
|
4 |
+
aiosignal==1.3.2
|
5 |
+
annotated-types==0.7.0
|
6 |
+
anyio==4.9.0
|
7 |
+
appnope==0.1.4
|
8 |
+
argon2-cffi==23.1.0
|
9 |
+
argon2-cffi-bindings==21.2.0
|
10 |
+
arrow==1.3.0
|
11 |
+
asttokens==3.0.0
|
12 |
+
async-lru==2.0.5
|
13 |
+
attrs==25.3.0
|
14 |
+
babel==2.17.0
|
15 |
+
beautifulsoup4==4.13.4
|
16 |
+
bleach==6.2.0
|
17 |
+
cachetools==5.5.2
|
18 |
+
certifi==2025.4.26
|
19 |
+
cffi==1.17.1
|
20 |
+
charset-normalizer==3.4.2
|
21 |
+
click==8.1.8
|
22 |
+
comm==0.2.2
|
23 |
+
contourpy==1.3.2
|
24 |
+
cycler==0.12.1
|
25 |
+
datasets==3.6.0
|
26 |
+
debugpy==1.8.14
|
27 |
+
decorator==5.2.1
|
28 |
+
defusedxml==0.7.1
|
29 |
+
dill==0.3.8
|
30 |
+
distro==1.9.0
|
31 |
+
donut-python==1.0.9
|
32 |
+
dotenv==0.9.9
|
33 |
+
executing==2.2.0
|
34 |
+
fastapi==0.115.12
|
35 |
+
fastjsonschema==2.21.1
|
36 |
+
ffmpy==0.5.0
|
37 |
+
filelock==3.18.0
|
38 |
+
fonttools==4.57.0
|
39 |
+
fqdn==1.5.1
|
40 |
+
frozenlist==1.6.0
|
41 |
+
fsspec==2025.3.0
|
42 |
+
google-auth==2.40.1
|
43 |
+
google-genai==1.14.0
|
44 |
+
gradio==5.29.0
|
45 |
+
gradio-client==1.10.0
|
46 |
+
groovy==0.1.2
|
47 |
+
h11==0.16.0
|
48 |
+
hf-xet==1.1.0
|
49 |
+
httpcore==1.0.9
|
50 |
+
httpx==0.28.1
|
51 |
+
huggingface-hub==0.31.1
|
52 |
+
idna==3.10
|
53 |
+
ipykernel==6.29.5
|
54 |
+
ipython==9.2.0
|
55 |
+
ipython-pygments-lexers==1.1.1
|
56 |
+
isoduration==20.11.0
|
57 |
+
jedi==0.19.2
|
58 |
+
jinja2==3.1.6
|
59 |
+
jiter==0.9.0
|
60 |
+
joblib==1.5.0
|
61 |
+
json5==0.12.0
|
62 |
+
jsonpointer==3.0.0
|
63 |
+
jsonschema==4.23.0
|
64 |
+
jsonschema-specifications==2025.4.1
|
65 |
+
jupyter-client==8.6.3
|
66 |
+
jupyter-core==5.7.2
|
67 |
+
jupyter-events==0.12.0
|
68 |
+
jupyter-lsp==2.2.5
|
69 |
+
jupyter-server==2.15.0
|
70 |
+
jupyter-server-terminals==0.5.3
|
71 |
+
jupyterlab==4.4.2
|
72 |
+
jupyterlab-pygments==0.3.0
|
73 |
+
jupyterlab-server==2.27.3
|
74 |
+
kiwisolver==1.4.8
|
75 |
+
lightning-utilities==0.14.3
|
76 |
+
markdown-it-py==3.0.0
|
77 |
+
markupsafe==3.0.2
|
78 |
+
matplotlib==3.10.1
|
79 |
+
matplotlib-inline==0.1.7
|
80 |
+
mdurl==0.1.2
|
81 |
+
mistune==3.1.3
|
82 |
+
mpmath==1.3.0
|
83 |
+
multidict==6.4.3
|
84 |
+
multiprocess==0.70.16
|
85 |
+
munch==4.0.0
|
86 |
+
nbclient==0.10.2
|
87 |
+
nbconvert==7.16.6
|
88 |
+
nbformat==5.10.4
|
89 |
+
nest-asyncio==1.6.0
|
90 |
+
networkx==3.4.2
|
91 |
+
nltk==3.9.1
|
92 |
+
notebook==7.4.2
|
93 |
+
notebook-shim==0.2.4
|
94 |
+
numpy==2.2.5
|
95 |
+
openai==1.77.0
|
96 |
+
orjson==3.10.18
|
97 |
+
overrides==7.7.0
|
98 |
+
packaging==25.0
|
99 |
+
pandas==2.2.3
|
100 |
+
pandocfilters==1.5.1
|
101 |
+
parso==0.8.4
|
102 |
+
pexpect==4.9.0
|
103 |
+
pillow==11.2.1
|
104 |
+
pillow-heif==0.22.0
|
105 |
+
platformdirs==4.3.8
|
106 |
+
prometheus-client==0.21.1
|
107 |
+
prompt-toolkit==3.0.51
|
108 |
+
propcache==0.3.1
|
109 |
+
protobuf==6.30.2
|
110 |
+
psutil==7.0.0
|
111 |
+
ptyprocess==0.7.0
|
112 |
+
pure-eval==0.2.3
|
113 |
+
pyarrow==20.0.0
|
114 |
+
pyasn1==0.6.1
|
115 |
+
pyasn1-modules==0.4.2
|
116 |
+
pycparser==2.22
|
117 |
+
pydantic==2.11.4
|
118 |
+
pydantic-core==2.33.2
|
119 |
+
pydub==0.25.1
|
120 |
+
pygments==2.19.1
|
121 |
+
pyparsing==3.2.3
|
122 |
+
python-dateutil==2.9.0.post0
|
123 |
+
python-dotenv==1.1.0
|
124 |
+
python-json-logger==3.3.0
|
125 |
+
python-multipart==0.0.20
|
126 |
+
pytorch-lightning==2.5.2
|
127 |
+
pytz==2025.2
|
128 |
+
pyyaml==6.0.2
|
129 |
+
pyzmq==26.4.0
|
130 |
+
referencing==0.36.2
|
131 |
+
regex==2024.11.6
|
132 |
+
requests==2.32.3
|
133 |
+
rfc3339-validator==0.1.4
|
134 |
+
rfc3986-validator==0.1.1
|
135 |
+
rich==14.0.0
|
136 |
+
rpds-py==0.24.0
|
137 |
+
rsa==4.9.1
|
138 |
+
ruamel-yaml==0.18.14
|
139 |
+
ruamel-yaml-clib==0.2.12
|
140 |
+
ruff==0.11.8
|
141 |
+
safehttpx==0.1.6
|
142 |
+
safetensors==0.5.3
|
143 |
+
sconf==0.2.5
|
144 |
+
semantic-version==2.10.0
|
145 |
+
send2trash==1.8.3
|
146 |
+
sentencepiece==0.2.0
|
147 |
+
setuptools==80.3.1
|
148 |
+
shellingham==1.5.4
|
149 |
+
six==1.17.0
|
150 |
+
sniffio==1.3.1
|
151 |
+
soupsieve==2.7
|
152 |
+
stack-data==0.6.3
|
153 |
+
starlette==0.46.2
|
154 |
+
sympy==1.14.0
|
155 |
+
tensorboardx==2.6.2.2
|
156 |
+
terminado==0.18.1
|
157 |
+
timm==1.0.16
|
158 |
+
tinycss2==1.4.0
|
159 |
+
tokenizers==0.21.1
|
160 |
+
tomlkit==0.13.2
|
161 |
+
torch==2.4.1
|
162 |
+
torchmetrics==1.7.3
|
163 |
+
torchvision==0.19.1
|
164 |
+
tornado==6.4.2
|
165 |
+
tqdm==4.67.1
|
166 |
+
traitlets==5.14.3
|
167 |
+
transformers==4.49.0
|
168 |
+
typer==0.15.3
|
169 |
+
types-python-dateutil==2.9.0.20241206
|
170 |
+
typing-extensions==4.13.2
|
171 |
+
typing-inspection==0.4.0
|
172 |
+
tzdata==2025.2
|
173 |
+
uri-template==1.3.0
|
174 |
+
urllib3==2.4.0
|
175 |
+
uvicorn==0.34.2
|
176 |
+
wcwidth==0.2.13
|
177 |
+
webcolors==24.11.1
|
178 |
+
webencodings==0.5.1
|
179 |
+
websocket-client==1.8.0
|
180 |
+
websockets==15.0.1
|
181 |
+
xxhash==3.5.0
|
182 |
+
yarl==1.20.0
|
183 |
+
zss==1.2.0
|
tools/schema_gemini.json
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"name": "extract_menu_data",
|
3 |
+
"description": "Extract structured menu information from images.",
|
4 |
+
"parameters": {
|
5 |
+
"type": "object",
|
6 |
+
"properties": {
|
7 |
+
"restaurant": {
|
8 |
+
"type": "string",
|
9 |
+
"description": "Name of the restaurant. If the name is not available, it should be ''."
|
10 |
+
},
|
11 |
+
"address": {
|
12 |
+
"type": "string",
|
13 |
+
"description": "Address of the restaurant. If the address is not available, it should be ''."
|
14 |
+
},
|
15 |
+
"phone": {
|
16 |
+
"type": "string",
|
17 |
+
"description": "Phone number of the restaurant. If the phone number is not available, it should be ''."
|
18 |
+
},
|
19 |
+
"business_hours": {
|
20 |
+
"type": "string",
|
21 |
+
"description": "Business hours of the restaurant. If the business hours are not available, it should be ''."
|
22 |
+
},
|
23 |
+
"dishes": {
|
24 |
+
"type": "array",
|
25 |
+
"items": {
|
26 |
+
"type": "object",
|
27 |
+
"properties": {
|
28 |
+
"name": {
|
29 |
+
"type": "string",
|
30 |
+
"description": "Name of the menu item."
|
31 |
+
},
|
32 |
+
"price": {
|
33 |
+
"type": "string",
|
34 |
+
"description": "Price of the menu item. If the price is not available, it should be -1."
|
35 |
+
}
|
36 |
+
},
|
37 |
+
"required": ["name", "price"]
|
38 |
+
},
|
39 |
+
"description": "List of menu dishes item."
|
40 |
+
}
|
41 |
+
},
|
42 |
+
"required": ["restaurant", "address", "phone", "business_hours", "dishes"]
|
43 |
+
}
|
44 |
+
}
|
tools/schema_openai.json
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"type": "function",
|
3 |
+
"name": "extract_menu_data",
|
4 |
+
"description": "Extract structured menu information from images.",
|
5 |
+
"parameters": {
|
6 |
+
"type": "object",
|
7 |
+
"properties": {
|
8 |
+
"restaurant": {
|
9 |
+
"type": "string",
|
10 |
+
"description": "Name of the restaurant. If the name is not available, it should be ''."
|
11 |
+
},
|
12 |
+
"address": {
|
13 |
+
"type": "string",
|
14 |
+
"description": "Address of the restaurant. If the address is not available, it should be ''."
|
15 |
+
},
|
16 |
+
"phone": {
|
17 |
+
"type": "string",
|
18 |
+
"description": "Phone number of the restaurant. If the phone number is not available, it should be ''."
|
19 |
+
},
|
20 |
+
"business_hours": {
|
21 |
+
"type": "string",
|
22 |
+
"description": "Business hours of the restaurant. If the business hours are not available, it should be ''."
|
23 |
+
},
|
24 |
+
"dishes": {
|
25 |
+
"type": "array",
|
26 |
+
"items": {
|
27 |
+
"type": "object",
|
28 |
+
"properties": {
|
29 |
+
"name": {
|
30 |
+
"type": "string",
|
31 |
+
"description": "Name of the menu item."
|
32 |
+
},
|
33 |
+
"price": {
|
34 |
+
"type": "string",
|
35 |
+
"description": "Price of the menu item. If the price is not available, it should be -1."
|
36 |
+
}
|
37 |
+
},
|
38 |
+
"required": ["name", "price"],
|
39 |
+
"additionalProperties": false
|
40 |
+
},
|
41 |
+
"description": "List of menu dishes item."
|
42 |
+
}
|
43 |
+
},
|
44 |
+
"required": ["restaurant", "address", "phone", "business_hours", "dishes"],
|
45 |
+
"additionalProperties": false
|
46 |
+
}
|
47 |
+
}
|
train.ipynb
ADDED
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {},
|
6 |
+
"source": [
|
7 |
+
"# Login to HuggingFace (just login once)"
|
8 |
+
]
|
9 |
+
},
|
10 |
+
{
|
11 |
+
"cell_type": "code",
|
12 |
+
"execution_count": null,
|
13 |
+
"metadata": {},
|
14 |
+
"outputs": [],
|
15 |
+
"source": [
|
16 |
+
"from huggingface_hub import interpreter_login\n",
|
17 |
+
"interpreter_login()"
|
18 |
+
]
|
19 |
+
},
|
20 |
+
{
|
21 |
+
"cell_type": "markdown",
|
22 |
+
"metadata": {},
|
23 |
+
"source": [
|
24 |
+
"# Collect Menu Image Datasets\n",
|
25 |
+
"- Use `metadata.jsonl` to label the images's ground truth. You can visit [here](https://github.com/ryanlinjui/menu-text-detection/tree/main/examples) to see the examples.\n",
|
26 |
+
"- After finishing, push to HuggingFace Datasets.\n",
|
27 |
+
"- For labeling:\n",
|
28 |
+
" - [Google AI Studio](https://aistudio.google.com) or [OpenAI ChatGPT](https://chatgpt.com).\n",
|
29 |
+
" - Use function calling by API. Start the gradio app locally or visit [here](https://huggingface.co/spaces/ryanlinjui/menu-text-detection).\n",
|
30 |
+
"\n",
|
31 |
+
"### Menu Type\n",
|
32 |
+
"- **h**: horizontal menu\n",
|
33 |
+
"- **v**: vertical menu\n",
|
34 |
+
"- **d**: document-style menu\n",
|
35 |
+
"- **s**: in-scene menu (non-document style)\n",
|
36 |
+
"- **i**: irregular menu (menu with irregular text layout)\n",
|
37 |
+
"\n",
|
38 |
+
"> Please see the [examples](https://github.com/ryanlinjui/menu-text-detection/tree/main/examples) for more details."
|
39 |
+
]
|
40 |
+
},
|
41 |
+
{
|
42 |
+
"cell_type": "code",
|
43 |
+
"execution_count": null,
|
44 |
+
"metadata": {},
|
45 |
+
"outputs": [],
|
46 |
+
"source": [
|
47 |
+
"import os\n",
|
48 |
+
"import json\n",
|
49 |
+
"\n",
|
50 |
+
"import numpy as np\n",
|
51 |
+
"from PIL import Image\n",
|
52 |
+
"from pillow_heif import register_heif_opener\n",
|
53 |
+
"\n",
|
54 |
+
"from menu.llm import (\n",
|
55 |
+
" GeminiAPI,\n",
|
56 |
+
" OpenAIAPI\n",
|
57 |
+
")\n",
|
58 |
+
"\n",
|
59 |
+
"IMAGE_DIR = \"datasets/images\" # set your image directory here\n",
|
60 |
+
"SELECTED_MODEL = \"gemini-2.5-flash\" # set model name here, refer MODEL_LIST from app.py for more\n",
|
61 |
+
"API_TOKEN = \"\" # set your API token here\n",
|
62 |
+
"SELECTED_FUNCTION = GeminiAPI # set \"GeminiAPI\" or \"OpenAIAPI\"\n",
|
63 |
+
"\n",
|
64 |
+
"register_heif_opener()\n",
|
65 |
+
"\n",
|
66 |
+
"for file in os.listdir(IMAGE_DIR):\n",
|
67 |
+
" print(f\"Processing image: {file}\")\n",
|
68 |
+
" try:\n",
|
69 |
+
" image = np.array(Image.open(os.path.join(IMAGE_DIR, file)))\n",
|
70 |
+
" data = {\n",
|
71 |
+
" \"file_name\": file,\n",
|
72 |
+
" \"menu\": SELECTED_FUNCTION.call(image, SELECTED_MODEL, API_TOKEN)\n",
|
73 |
+
" }\n",
|
74 |
+
" with open(os.path.join(IMAGE_DIR, \"metadata.jsonl\"), \"a\", encoding=\"utf-8\") as metaf:\n",
|
75 |
+
" metaf.write(json.dumps(data, ensure_ascii=False, sort_keys=True) + \"\\n\")\n",
|
76 |
+
" except Exception as e:\n",
|
77 |
+
" print(f\"Skipping invalid image '{file}': {e}\")\n",
|
78 |
+
" continue"
|
79 |
+
]
|
80 |
+
},
|
81 |
+
{
|
82 |
+
"cell_type": "markdown",
|
83 |
+
"metadata": {},
|
84 |
+
"source": [
|
85 |
+
"# Push Datasets to HuggingFace"
|
86 |
+
]
|
87 |
+
},
|
88 |
+
{
|
89 |
+
"cell_type": "code",
|
90 |
+
"execution_count": null,
|
91 |
+
"metadata": {},
|
92 |
+
"outputs": [],
|
93 |
+
"source": [
|
94 |
+
"from datasets import load_dataset\n",
|
95 |
+
"\n",
|
96 |
+
"dataset = load_dataset(path=\"datasets/menu-zh-TW\") # load dataset from the local directory including the metadata.jsonl, images files.\n",
|
97 |
+
"dataset.push_to_hub(repo_id=\"ryanlinjui/menu-zh-TW\") # push to the huggingface dataset hub"
|
98 |
+
]
|
99 |
+
},
|
100 |
+
{
|
101 |
+
"cell_type": "markdown",
|
102 |
+
"metadata": {},
|
103 |
+
"source": [
|
104 |
+
"# Prepare the dataset for training"
|
105 |
+
]
|
106 |
+
},
|
107 |
+
{
|
108 |
+
"cell_type": "code",
|
109 |
+
"execution_count": null,
|
110 |
+
"metadata": {},
|
111 |
+
"outputs": [],
|
112 |
+
"source": [
|
113 |
+
"from menu.utils import split_dataset\n",
|
114 |
+
"from datasets import load_dataset\n",
|
115 |
+
"\n",
|
116 |
+
"dataset = load_dataset(path=\"ryanlinjui/menu-zh-TW\") # set your dataset repo id for training\n",
|
117 |
+
"dataset = split_dataset(dataset[\"train\"], train=0.8, validation=0.1, test=0.1, seed=42) # (optional) use it if your dataset is not split into train/validation/test\n",
|
118 |
+
"print(f\"Dataset split: {len(dataset['train'])} train, {len(dataset['validation'])} validation, {len(dataset['test'])} test\")"
|
119 |
+
]
|
120 |
+
},
|
121 |
+
{
|
122 |
+
"cell_type": "markdown",
|
123 |
+
"metadata": {},
|
124 |
+
"source": [
|
125 |
+
"# Fine-tune Donut Model"
|
126 |
+
]
|
127 |
+
},
|
128 |
+
{
|
129 |
+
"cell_type": "code",
|
130 |
+
"execution_count": null,
|
131 |
+
"metadata": {},
|
132 |
+
"outputs": [],
|
133 |
+
"source": [
|
134 |
+
"import logging\n",
|
135 |
+
"from menu.donut import DonutTrainer\n",
|
136 |
+
"\n",
|
137 |
+
"logging.getLogger(\"transformers\").setLevel(logging.ERROR) # filter output message from transformers\n",
|
138 |
+
"\n",
|
139 |
+
"DonutTrainer.train(\n",
|
140 |
+
" dataset=dataset,\n",
|
141 |
+
" pretrained_model_repo_id=\"naver-clova-ix/donut-base\", # set your pretrained model repo id for fine-tuning\n",
|
142 |
+
" ground_truth_key=\"menu\", # set your ground truth key for training\n",
|
143 |
+
" huggingface_model_id=\"ryanlinjui/donut-base-finetuned-menu\", # set your huggingface model repo id for saving / pushing to the hub\n",
|
144 |
+
" epochs=15, # set your training epochs\n",
|
145 |
+
" train_batch_size=8, # set your training batch size\n",
|
146 |
+
" val_batch_size=1, # set your validation batch size\n",
|
147 |
+
" learning_rate=3e-5, # set your learning rate\n",
|
148 |
+
" val_check_interval=0.5, # how many times we want to validate during an epoch\n",
|
149 |
+
" check_val_every_n_epoch=1, # how many epochs we want to validate\n",
|
150 |
+
" gradient_clip_val=1.0, # gradient clipping value for training stability\n",
|
151 |
+
" num_training_samples_per_epoch=198, # set num_training_samples_per_epoch = training set size\n",
|
152 |
+
" num_nodes=1, # number of nodes for distributed training\n",
|
153 |
+
" warmup_steps=75 # number of warmup steps for learning rate scheduler, 198/8*30/10, 10%\n",
|
154 |
+
")"
|
155 |
+
]
|
156 |
+
},
|
157 |
+
{
|
158 |
+
"cell_type": "markdown",
|
159 |
+
"metadata": {},
|
160 |
+
"source": [
|
161 |
+
"# Evaluate Donut Model"
|
162 |
+
]
|
163 |
+
},
|
164 |
+
{
|
165 |
+
"cell_type": "code",
|
166 |
+
"execution_count": null,
|
167 |
+
"metadata": {},
|
168 |
+
"outputs": [],
|
169 |
+
"source": [
|
170 |
+
"import json\n",
|
171 |
+
"from datasets import load_dataset\n",
|
172 |
+
"\n",
|
173 |
+
"from menu.utils import split_dataset\n",
|
174 |
+
"from menu.donut import DonutFinetuned\n",
|
175 |
+
"\n",
|
176 |
+
"dataset = load_dataset(\"ryanlinjui/menu-zh-TW\")\n",
|
177 |
+
"dataset = split_dataset(dataset[\"train\"], train=0.8, validation=0.1, test=0.1, seed=42) # (optional) use it if your dataset is not split into train/validation/test\n",
|
178 |
+
"donut_finetuned = DonutFinetuned(pretrained_model_repo_id=\"ryanlinjui/donut-base-finetuned-menu\")\n",
|
179 |
+
"scores, output_list = donut_finetuned.evaluate(dataset=dataset[\"test\"], ground_truth_key=\"menu\")\n",
|
180 |
+
"\n",
|
181 |
+
"print(\"Evaluation scores:\")\n",
|
182 |
+
"for key, value in scores.items():\n",
|
183 |
+
" print(f\"{key}: {value}\")\n",
|
184 |
+
"\n",
|
185 |
+
"print(\"\\nSample outputs:\")\n",
|
186 |
+
"for output in output_list[:5]:\n",
|
187 |
+
" print(json.dumps(output, ensure_ascii=False, indent=4))"
|
188 |
+
]
|
189 |
+
},
|
190 |
+
{
|
191 |
+
"cell_type": "markdown",
|
192 |
+
"metadata": {},
|
193 |
+
"source": [
|
194 |
+
"# Test Donut Model"
|
195 |
+
]
|
196 |
+
},
|
197 |
+
{
|
198 |
+
"cell_type": "code",
|
199 |
+
"execution_count": null,
|
200 |
+
"metadata": {},
|
201 |
+
"outputs": [],
|
202 |
+
"source": [
|
203 |
+
"from PIL import Image\n",
|
204 |
+
"from menu.donut import DonutFinetuned\n",
|
205 |
+
"\n",
|
206 |
+
"image = Image.open(\"./examples/menu-hd.jpg\")\n",
|
207 |
+
"\n",
|
208 |
+
"donut_finetuned = DonutFinetuned(pretrained_model_repo_id=\"ryanlinjui/donut-base-finetuned-menu\")\n",
|
209 |
+
"outputs = donut_finetuned.predict(image=image)\n",
|
210 |
+
"print(outputs)"
|
211 |
+
]
|
212 |
+
}
|
213 |
+
],
|
214 |
+
"metadata": {
|
215 |
+
"kernelspec": {
|
216 |
+
"display_name": "menu-text-detection",
|
217 |
+
"language": "python",
|
218 |
+
"name": "python3"
|
219 |
+
},
|
220 |
+
"language_info": {
|
221 |
+
"codemirror_mode": {
|
222 |
+
"name": "ipython",
|
223 |
+
"version": 3
|
224 |
+
},
|
225 |
+
"file_extension": ".py",
|
226 |
+
"mimetype": "text/x-python",
|
227 |
+
"name": "python",
|
228 |
+
"nbconvert_exporter": "python",
|
229 |
+
"pygments_lexer": "ipython3",
|
230 |
+
"version": "3.11.12"
|
231 |
+
}
|
232 |
+
},
|
233 |
+
"nbformat": 4,
|
234 |
+
"nbformat_minor": 2
|
235 |
+
}
|
uv.lock
ADDED
The diff for this file is too large to render.
See raw diff
|
|