Spaces: Running on Zero

Commit 58b56ea
Parent(s): 864e5c4

Refactor OCR model loading to use lazy initialization and enhance error handling in predict function

app.py CHANGED
@@ -5,21 +5,12 @@ import os
 import torch
 from transformers import AutoProcessor, AutoModelForImageTextToText, pipeline
 import spaces
-
-#
-try:
-    HF_PROCESSOR = AutoProcessor.from_pretrained("reducto/RolmOCR")
-    HF_MODEL = AutoModelForImageTextToText.from_pretrained(
-        "reducto/RolmOCR",
-        torch_dtype=torch.bfloat16,
-        # attn_implementation="flash_attention_2", # User had this commented out
-        device_map="auto"
-    )
-    HF_PIPE = pipeline("image-text-to-text", model=HF_MODEL, processor=HF_PROCESSOR)
-    print("Hugging Face OCR model loaded successfully.")
-except Exception as e:
-    print(f"Error loading Hugging Face model: {e}")
-    HF_PIPE = None
+
+# --- Global Model and Processor (initialize as None for lazy loading) ---
+HF_PROCESSOR = None
+HF_MODEL = None
+HF_PIPE = None
+MODEL_LOAD_ERROR_MSG = None  # To store any error message from loading
 
 # --- Helper Functions ---
 
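The hunk above replaces eager, import-time model loading with module-level placeholders; the actual loading now happens lazily inside predict (next hunk), so importing app.py stays cheap and a failed load no longer breaks the whole Space at startup. A minimal, self-contained sketch of that lazy-initialization pattern, assuming nothing from app.py (get_resource, _build_expensive, and the other names here are illustrative only):

# Lazy initialization: build an expensive object on first use, cache it,
# and remember a load failure so it is reported rather than retried blindly.
_RESOURCE = None      # plays the role of HF_PIPE above
_LOAD_ERROR = None    # plays the role of MODEL_LOAD_ERROR_MSG above

def _build_expensive():
    # Stand-in for the real work (processor/model/pipeline construction).
    return {"ready": True}

def get_resource():
    global _RESOURCE, _LOAD_ERROR
    if _RESOURCE is None and _LOAD_ERROR is None:
        try:
            _RESOURCE = _build_expensive()  # runs only on the first call
        except Exception as e:
            _LOAD_ERROR = str(e)
    if _RESOURCE is None:
        raise RuntimeError(_LOAD_ERROR or "resource could not be initialized")
    return _RESOURCE

print(get_resource())  # first call builds; later calls reuse the cached object

Caching the error alongside the resource means a broken load surfaces the same message on every later call instead of triggering repeated load attempts.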
@@ -68,72 +59,87 @@ def parse_alto_xml_for_text(xml_file_path):
     except Exception as e:
         return f"An unexpected error occurred during XML parsing: {e}"
 
+@spaces.GPU  # Ensures GPU is available for model loading (on first call) and inference
+def predict(pil_image):
+    """Performs OCR prediction using the Hugging Face model, with lazy loading."""
+    global HF_PROCESSOR, HF_MODEL, HF_PIPE, MODEL_LOAD_ERROR_MSG
+
+    if HF_PIPE is None and MODEL_LOAD_ERROR_MSG is None:
+        try:
+            print("Attempting to load Hugging Face model and processor within @spaces.GPU context...")
+            HF_PROCESSOR = AutoProcessor.from_pretrained("reducto/RolmOCR")
+            HF_MODEL = AutoModelForImageTextToText.from_pretrained(
+                "reducto/RolmOCR",
+                torch_dtype=torch.bfloat16,
+                device_map="auto"  # Should utilize ZeroGPU correctly here
+            )
+            HF_PIPE = pipeline("image-text-to-text", model=HF_MODEL, processor=HF_PROCESSOR)
+            print("Hugging Face OCR model loaded successfully.")
+        except Exception as e:
+            MODEL_LOAD_ERROR_MSG = f"Error loading Hugging Face model: {str(e)}"
+            print(MODEL_LOAD_ERROR_MSG)
+            # HF_PIPE remains None, error message is stored
+
+    if HF_PIPE is None:
+        error_to_report = MODEL_LOAD_ERROR_MSG if MODEL_LOAD_ERROR_MSG else "OCR model could not be initialized."
+        raise RuntimeError(error_to_report)
+
+    # Proceed with inference if pipe is available
+    return HF_PIPE(
+        pil_image,
+        prompt="Return the plain text representation of this document as if you were reading it naturally.\n",
+    )
+
 def run_hf_ocr(image_path):
     """
-    Runs OCR on the provided image using the
+    Runs OCR on the provided image using the Hugging Face model (via predict function).
     """
-    if HF_PIPE is None:
-        return "Hugging Face OCR model not available."
     if image_path is None:
         return "No image provided for OCR."
 
     try:
-        # Load the image using PIL, as the pipeline expects an image object or path
         pil_image = Image.open(image_path).convert("RGB")
-
-        # The user's example output for the pipeline call was:
-        # [{'generated_text': [{'role': 'user', ...}, {'role': 'assistant', 'content': "TEXT..."}]}]
-        # This suggests the pipeline is returning a conversational style output.
-        # We will try to call the pipeline with the image and prompt directly.
-        ocr_results = predict(pil_image)
+        ocr_results = predict(pil_image)  # predict handles model loading and inference
 
         # Parse the output based on the user's example structure
         if isinstance(ocr_results, list) and ocr_results and 'generated_text' in ocr_results[0]:
             generated_content = ocr_results[0]['generated_text']
 
-            # Check if generated_content itself is the direct text (some pipelines do this)
            if isinstance(generated_content, str):
                 return generated_content
 
-            # Check for the conversational structure
-            # [{'role': 'user', ...}, {'role': 'assistant', 'content': "TEXT..."}]
             if isinstance(generated_content, list) and generated_content:
-
-
-
-
-
-
-
-
+                if assistant_message := next(
+                    (
+                        msg['content']
+                        for msg in reversed(generated_content)
+                        if isinstance(msg, dict)
+                        and msg.get('role') == 'assistant'
+                        and 'content' in msg
+                    ),
+                    None,
+                ):
                     return assistant_message
-
-
-
-
-
-
-
+
+                # Fallback if the specific assistant message structure isn't found but there's content
+                if isinstance(generated_content[0], dict) and 'content' in generated_content[0]:
+                    if len(generated_content) > 1 and isinstance(generated_content[1], dict) and 'content' in generated_content[1]:
+                        return generated_content[1]['content']  # Assuming second part is assistant
+                    elif 'content' in generated_content[0]:  # Or if first part is already the content
+                        return generated_content[0]['content']
 
             print(f"Unexpected OCR output structure from HF model: {ocr_results}")
-            return "Error: Could not parse OCR model output."
+            return "Error: Could not parse OCR model output. Check console."
 
         else:
             print(f"Unexpected OCR output structure from HF model: {ocr_results}")
-            return "Error: OCR model did not return expected output."
+            return "Error: OCR model did not return expected output. Check console."
 
+    except RuntimeError as e:  # Catch model loading/initialization errors from predict
+        return str(e)
     except Exception as e:
-        print(f"Error during Hugging Face OCR: {e}")
+        print(f"Error during Hugging Face OCR processing: {e}")
        return f"Error during Hugging Face OCR: {str(e)}"
-@spaces.GPU
-def predict(pil_image):
-    ocr_results = HF_PIPE(
-        pil_image,
-        prompt="Return the plain text representation of this document as if you were reading it naturally.\n"
-        # The pipeline should handle formatting this into messages if needed by the model.
-    )
-
-    return ocr_results
 
 # --- Gradio Interface Function ---
 
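The parsing changes above deal with the conversational shape an image-text-to-text pipeline can return, roughly [{'generated_text': [{'role': 'user', ...}, {'role': 'assistant', 'content': "TEXT..."}]}] (the shape the removed comments described). A small sketch of how the new next(...) expression pulls the assistant text out of such a result; the sample data below is hard-coded and only illustrative of that shape, not actual RolmOCR output:

# Illustrative stand-in for a conversational pipeline result.
ocr_results = [{
    "generated_text": [
        {"role": "user", "content": "Return the plain text representation of this document..."},
        {"role": "assistant", "content": "TEXT OF THE PAGE"},
    ]
}]

generated_content = ocr_results[0]["generated_text"]

# Same idea as in run_hf_ocr: scan messages from the end and take the first
# assistant message that actually carries a 'content' field.
assistant_message = next(
    (
        msg["content"]
        for msg in reversed(generated_content)
        if isinstance(msg, dict) and msg.get("role") == "assistant" and "content" in msg
    ),
    None,
)

print(assistant_message)  # -> TEXT OF THE PAGE

Scanning with reversed() means that if the model ever returns several assistant turns, the most recent one wins, and the explicit None default keeps the fallback branch in run_hf_ocr reachable.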
@@ -241,5 +247,5 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
 if __name__ == "__main__":
     # Removed dummy file creation as it's less relevant for single file focus
     print("Attempting to launch Gradio demo...")
-    print("If the Hugging Face model is large, initial startup might take some time due to model download/loading.")
+    print("If the Hugging Face model is large, initial startup might take some time due to model download/loading (on first OCR attempt).")
     demo.launch()
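Taken together, the hunks set up a simple contract between the two functions: predict raises RuntimeError when the model cannot be initialized, and run_hf_ocr converts that into a user-facing string instead of letting the Gradio handler crash. A reduced sketch of that contract, assuming nothing from app.py (load_or_fail and handle_request are illustrative names only):

# The worker raises when its resource is unavailable; the caller turns the
# exception into a readable message for the UI.
def load_or_fail(ok: bool) -> str:
    if not ok:
        raise RuntimeError("OCR model could not be initialized.")
    return "recognized text"

def handle_request(ok: bool) -> str:
    try:
        return load_or_fail(ok)
    except RuntimeError as e:   # initialization problems become a clear message
        return str(e)
    except Exception as e:      # anything else is reported generically
        return f"Error during OCR: {e}"

print(handle_request(True))   # -> recognized text
print(handle_request(False))  # -> OCR model could not be initialized.

Keeping the RuntimeError branch before the generic except mirrors the ordering in run_hf_ocr, so initialization failures are not swallowed by the broader handler.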