Spaces: Running on Zero
Commit · 003891a
Parent(s): 5639776
add second model
app.py CHANGED
@@ -7,16 +7,35 @@ from transformers import AutoProcessor, AutoModelForImageTextToText, pipeline
 import spaces
 
 # --- Global Model and Processor ---
-
-
-
-MODEL_LOAD_ERROR_MSG =
-
-
-
-
-
-
+MODELS = {}
+PROCESSORS = {}
+PIPELINES = {}
+MODEL_LOAD_ERROR_MSG = {}
+
+# Available models
+AVAILABLE_MODELS = ["RolmOCR", "Nanonets-OCR-s"]
+
+# Load RolmOCR
+try:
+    PROCESSORS["RolmOCR"] = AutoProcessor.from_pretrained("reducto/RolmOCR")
+    MODELS["RolmOCR"] = AutoModelForImageTextToText.from_pretrained(
+        "reducto/RolmOCR", torch_dtype=torch.bfloat16, device_map="auto"
+    )
+    PIPELINES["RolmOCR"] = pipeline("image-text-to-text", model=MODELS["RolmOCR"], processor=PROCESSORS["RolmOCR"])
+except Exception as e:
+    MODEL_LOAD_ERROR_MSG["RolmOCR"] = f"Failed to load RolmOCR: {str(e)}"
+    print(f"Error loading RolmOCR: {e}")
+
+# Load Nanonets-OCR-s
+try:
+    PROCESSORS["Nanonets-OCR-s"] = AutoProcessor.from_pretrained("nanonets/Nanonets-OCR-s")
+    MODELS["Nanonets-OCR-s"] = AutoModelForImageTextToText.from_pretrained(
+        "nanonets/Nanonets-OCR-s", torch_dtype=torch.bfloat16, device_map="auto"
+    )
+    PIPELINES["Nanonets-OCR-s"] = pipeline("image-text-to-text", model=MODELS["Nanonets-OCR-s"], processor=PROCESSORS["Nanonets-OCR-s"])
+except Exception as e:
+    MODEL_LOAD_ERROR_MSG["Nanonets-OCR-s"] = f"Failed to load Nanonets-OCR-s: {str(e)}"
+    print(f"Error loading Nanonets-OCR-s: {e}")
 
 
 # --- Helper Functions ---

@@ -139,46 +158,63 @@ def parse_xml_for_text(xml_file_path):
 
 
 @spaces.GPU
-def predict(pil_image):
-    """Performs OCR prediction using the Hugging Face model."""
-    global
-
-    if
-        error_to_report = (
-
-
-            else "OCR model could not be initialized."
+def predict(pil_image, model_name="RolmOCR"):
+    """Performs OCR prediction using the selected Hugging Face model."""
+    global PIPELINES, MODEL_LOAD_ERROR_MSG
+
+    if model_name not in PIPELINES:
+        error_to_report = MODEL_LOAD_ERROR_MSG.get(
+            model_name,
+            f"Model {model_name} could not be initialized or is not available."
         )
         raise RuntimeError(error_to_report)
 
-
-
-
-
-
-
-
-    "
-
-
-
-
+    selected_pipe = PIPELINES[model_name]
+
+    # Format the message based on the model
+    if model_name == "RolmOCR":
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image", "image": pil_image},
+                    {
+                        "type": "text",
+                        "text": "Return the plain text representation of this document as if you were reading it naturally.\n",
+                    },
+                ],
+            }
+        ]
+        max_tokens = 8096
+    else:  # Nanonets-OCR-s
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image", "image": pil_image},
+                    {
+                        "type": "text",
+                        "text": "Extract and return all the text from this image. Include all text elements and maintain the reading order. If there are tables, convert them to markdown format. If there are mathematical equations, convert them to LaTeX format.",
+                    },
+                ],
+            }
+        ]
+        max_tokens = 8096
 
     # Use the pipeline with the properly formatted messages
-    return
+    return selected_pipe(messages, max_new_tokens=max_tokens)
 
 
-def run_hf_ocr(image_path):
+def run_hf_ocr(image_path, model_name="RolmOCR"):
     """
-    Runs OCR on the provided image using the Hugging Face model (via predict function).
+    Runs OCR on the provided image using the selected Hugging Face model (via predict function).
     """
     if image_path is None:
         return "No image provided for OCR."
 
     try:
         pil_image = Image.open(image_path).convert("RGB")
-        ocr_results = predict(pil_image)  # predict handles model loading and inference
+        ocr_results = predict(pil_image, model_name)  # predict handles model loading and inference
 
         # Parse the output based on the user's example structure
         if (

@@ -237,10 +273,10 @@ def run_hf_ocr(image_path):
 # --- Gradio Interface Function ---
 
 
-def process_files(image_path, xml_path):
+def process_files(image_path, xml_path, model_name):
     """
     Main function for the Gradio interface.
-    Processes the image for display, runs OCR
+    Processes the image for display, runs OCR with selected model,
     and parses XML if provided.
     """
     img_to_display = None

@@ -250,10 +286,10 @@ def process_files(image_path, xml_path):
     if image_path:
         try:
             img_to_display = Image.open(image_path).convert("RGB")
-            hf_ocr_text_output = run_hf_ocr(image_path)
+            hf_ocr_text_output = run_hf_ocr(image_path, model_name)
         except Exception as e:
             img_to_display = None  # Clear image if it failed to load
-            hf_ocr_text_output = f"Error loading image or running
+            hf_ocr_text_output = f"Error loading image or running {model_name} OCR: {e}"
     else:
         hf_ocr_text_output = "Please upload an image to perform OCR."
 

@@ -281,6 +317,12 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
 
     with gr.Row():
         with gr.Column(scale=1):
+            model_selector = gr.Radio(
+                choices=AVAILABLE_MODELS,
+                value="RolmOCR",
+                label="Select OCR Model",
+                info="Choose between RolmOCR (fast, general purpose) or Nanonets-OCR-s (detailed extraction)"
+            )
             image_input = gr.File(
                 label="Upload Image (PNG, JPG, etc.)", type="filepath"
             )

@@ -296,7 +338,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
             )
         with gr.Column(scale=1):
             hf_ocr_output_textbox = gr.Markdown(
-                label="OCR Output
+                label="OCR Output",
                 show_copy_button=True,
             )
             xml_output_textbox = gr.Textbox(

@@ -308,7 +350,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
 
     submit_button.click(
         fn=process_files,
-        inputs=[image_input, xml_input],
+        inputs=[image_input, xml_input, model_selector],
         outputs=[output_image_display, xml_output_textbox, hf_ocr_output_textbox],
     )
 
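
For quick local verification, here is a minimal sketch of how the updated two-model registry might be exercised outside Gradio. The `app` module alias and the test image path are assumptions, and the exact nesting of the image-text-to-text pipeline output depends on the transformers version, so the sketch prints the raw result instead of indexing into it.

# Minimal sketch, assuming app.py above is importable as `app`
# (module alias and test image path are hypothetical).
from PIL import Image

import app

image = Image.open("sample_page.png").convert("RGB")  # hypothetical test scan

for name in app.AVAILABLE_MODELS:
    try:
        # predict() picks the prompt and pipeline for the chosen model
        results = app.predict(image, model_name=name)
        print(f"--- {name} ---")
        print(results)  # inspect before indexing; nesting varies by version
    except RuntimeError as err:
        # predict() raises RuntimeError when a model failed to load at startup
        print(f"--- {name} --- unavailable: {err}")

Note that both prompt branches in predict() currently set max_tokens = 8096, so the branch only varies the prompt text; the two near-identical try/except loader blocks could likewise be folded into a loop over AVAILABLE_MODELS, at the cost of some explicitness.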