davanstrien HF Staff commited on
Commit
003891a
·
1 Parent(s): 5639776

add second model

Browse files
Files changed (1) hide show
  1. app.py +84 -42
app.py CHANGED
@@ -7,16 +7,35 @@ from transformers import AutoProcessor, AutoModelForImageTextToText, pipeline
7
  import spaces
8
 
9
  # --- Global Model and Processor ---
10
- HF_PROCESSOR = None
11
- HF_MODEL = None
12
- HF_PIPE = None
13
- MODEL_LOAD_ERROR_MSG = None
14
-
15
- HF_PROCESSOR = AutoProcessor.from_pretrained("reducto/RolmOCR")
16
- HF_MODEL = AutoModelForImageTextToText.from_pretrained(
17
- "reducto/RolmOCR", torch_dtype=torch.bfloat16, device_map="auto"
18
- )
19
- HF_PIPE = pipeline("image-text-to-text", model=HF_MODEL, processor=HF_PROCESSOR)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
 
22
  # --- Helper Functions ---
@@ -139,46 +158,63 @@ def parse_xml_for_text(xml_file_path):
139
 
140
 
141
  @spaces.GPU
142
- def predict(pil_image):
143
- """Performs OCR prediction using the Hugging Face model."""
144
- global HF_PIPE, MODEL_LOAD_ERROR_MSG
145
-
146
- if HF_PIPE is None:
147
- error_to_report = (
148
- MODEL_LOAD_ERROR_MSG
149
- if MODEL_LOAD_ERROR_MSG
150
- else "OCR model could not be initialized."
151
  )
152
  raise RuntimeError(error_to_report)
153
 
154
- # Format the message in the expected structure
155
- messages = [
156
- {
157
- "role": "user",
158
- "content": [
159
- {"type": "image", "image": pil_image},
160
- {
161
- "type": "text",
162
- "text": "Return the plain text representation of this document as if you were reading it naturally.\n",
163
- },
164
- ],
165
- }
166
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
 
168
  # Use the pipeline with the properly formatted messages
169
- return HF_PIPE(messages, max_new_tokens=8096)
170
 
171
 
172
- def run_hf_ocr(image_path):
173
  """
174
- Runs OCR on the provided image using the Hugging Face model (via predict function).
175
  """
176
  if image_path is None:
177
  return "No image provided for OCR."
178
 
179
  try:
180
  pil_image = Image.open(image_path).convert("RGB")
181
- ocr_results = predict(pil_image) # predict handles model loading and inference
182
 
183
  # Parse the output based on the user's example structure
184
  if (
@@ -237,10 +273,10 @@ def run_hf_ocr(image_path):
237
  # --- Gradio Interface Function ---
238
 
239
 
240
- def process_files(image_path, xml_path):
241
  """
242
  Main function for the Gradio interface.
243
- Processes the image for display, runs OCR (Hugging Face model),
244
  and parses XML if provided.
245
  """
246
  img_to_display = None
@@ -250,10 +286,10 @@ def process_files(image_path, xml_path):
250
  if image_path:
251
  try:
252
  img_to_display = Image.open(image_path).convert("RGB")
253
- hf_ocr_text_output = run_hf_ocr(image_path)
254
  except Exception as e:
255
  img_to_display = None # Clear image if it failed to load
256
- hf_ocr_text_output = f"Error loading image or running HF OCR: {e}"
257
  else:
258
  hf_ocr_text_output = "Please upload an image to perform OCR."
259
 
@@ -281,6 +317,12 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
281
 
282
  with gr.Row():
283
  with gr.Column(scale=1):
 
 
 
 
 
 
284
  image_input = gr.File(
285
  label="Upload Image (PNG, JPG, etc.)", type="filepath"
286
  )
@@ -296,7 +338,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
296
  )
297
  with gr.Column(scale=1):
298
  hf_ocr_output_textbox = gr.Markdown(
299
- label="OCR Output (Hugging Face Model)",
300
  show_copy_button=True,
301
  )
302
  xml_output_textbox = gr.Textbox(
@@ -308,7 +350,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
308
 
309
  submit_button.click(
310
  fn=process_files,
311
- inputs=[image_input, xml_input],
312
  outputs=[output_image_display, xml_output_textbox, hf_ocr_output_textbox],
313
  )
314
 
 
7
  import spaces
8
 
9
  # --- Global Model and Processor ---
10
+ MODELS = {}
11
+ PROCESSORS = {}
12
+ PIPELINES = {}
13
+ MODEL_LOAD_ERROR_MSG = {}
14
+
15
+ # Available models
16
+ AVAILABLE_MODELS = ["RolmOCR", "Nanonets-OCR-s"]
17
+
18
+ # Load RolmOCR
19
+ try:
20
+ PROCESSORS["RolmOCR"] = AutoProcessor.from_pretrained("reducto/RolmOCR")
21
+ MODELS["RolmOCR"] = AutoModelForImageTextToText.from_pretrained(
22
+ "reducto/RolmOCR", torch_dtype=torch.bfloat16, device_map="auto"
23
+ )
24
+ PIPELINES["RolmOCR"] = pipeline("image-text-to-text", model=MODELS["RolmOCR"], processor=PROCESSORS["RolmOCR"])
25
+ except Exception as e:
26
+ MODEL_LOAD_ERROR_MSG["RolmOCR"] = f"Failed to load RolmOCR: {str(e)}"
27
+ print(f"Error loading RolmOCR: {e}")
28
+
29
+ # Load Nanonets-OCR-s
30
+ try:
31
+ PROCESSORS["Nanonets-OCR-s"] = AutoProcessor.from_pretrained("nanonets/Nanonets-OCR-s")
32
+ MODELS["Nanonets-OCR-s"] = AutoModelForImageTextToText.from_pretrained(
33
+ "nanonets/Nanonets-OCR-s", torch_dtype=torch.bfloat16, device_map="auto"
34
+ )
35
+ PIPELINES["Nanonets-OCR-s"] = pipeline("image-text-to-text", model=MODELS["Nanonets-OCR-s"], processor=PROCESSORS["Nanonets-OCR-s"])
36
+ except Exception as e:
37
+ MODEL_LOAD_ERROR_MSG["Nanonets-OCR-s"] = f"Failed to load Nanonets-OCR-s: {str(e)}"
38
+ print(f"Error loading Nanonets-OCR-s: {e}")
39
 
40
 
41
  # --- Helper Functions ---
 
158
 
159
 
160
  @spaces.GPU
161
+ def predict(pil_image, model_name="RolmOCR"):
162
+ """Performs OCR prediction using the selected Hugging Face model."""
163
+ global PIPELINES, MODEL_LOAD_ERROR_MSG
164
+
165
+ if model_name not in PIPELINES:
166
+ error_to_report = MODEL_LOAD_ERROR_MSG.get(
167
+ model_name,
168
+ f"Model {model_name} could not be initialized or is not available."
 
169
  )
170
  raise RuntimeError(error_to_report)
171
 
172
+ selected_pipe = PIPELINES[model_name]
173
+
174
+ # Format the message based on the model
175
+ if model_name == "RolmOCR":
176
+ messages = [
177
+ {
178
+ "role": "user",
179
+ "content": [
180
+ {"type": "image", "image": pil_image},
181
+ {
182
+ "type": "text",
183
+ "text": "Return the plain text representation of this document as if you were reading it naturally.\n",
184
+ },
185
+ ],
186
+ }
187
+ ]
188
+ max_tokens = 8096
189
+ else: # Nanonets-OCR-s
190
+ messages = [
191
+ {
192
+ "role": "user",
193
+ "content": [
194
+ {"type": "image", "image": pil_image},
195
+ {
196
+ "type": "text",
197
+ "text": "Extract and return all the text from this image. Include all text elements and maintain the reading order. If there are tables, convert them to markdown format. If there are mathematical equations, convert them to LaTeX format.",
198
+ },
199
+ ],
200
+ }
201
+ ]
202
+ max_tokens = 8096
203
 
204
  # Use the pipeline with the properly formatted messages
205
+ return selected_pipe(messages, max_new_tokens=max_tokens)
206
 
207
 
208
+ def run_hf_ocr(image_path, model_name="RolmOCR"):
209
  """
210
+ Runs OCR on the provided image using the selected Hugging Face model (via predict function).
211
  """
212
  if image_path is None:
213
  return "No image provided for OCR."
214
 
215
  try:
216
  pil_image = Image.open(image_path).convert("RGB")
217
+ ocr_results = predict(pil_image, model_name) # predict handles model loading and inference
218
 
219
  # Parse the output based on the user's example structure
220
  if (
 
273
  # --- Gradio Interface Function ---
274
 
275
 
276
+ def process_files(image_path, xml_path, model_name):
277
  """
278
  Main function for the Gradio interface.
279
+ Processes the image for display, runs OCR with selected model,
280
  and parses XML if provided.
281
  """
282
  img_to_display = None
 
286
  if image_path:
287
  try:
288
  img_to_display = Image.open(image_path).convert("RGB")
289
+ hf_ocr_text_output = run_hf_ocr(image_path, model_name)
290
  except Exception as e:
291
  img_to_display = None # Clear image if it failed to load
292
+ hf_ocr_text_output = f"Error loading image or running {model_name} OCR: {e}"
293
  else:
294
  hf_ocr_text_output = "Please upload an image to perform OCR."
295
 
 
317
 
318
  with gr.Row():
319
  with gr.Column(scale=1):
320
+ model_selector = gr.Radio(
321
+ choices=AVAILABLE_MODELS,
322
+ value="RolmOCR",
323
+ label="Select OCR Model",
324
+ info="Choose between RolmOCR (fast, general purpose) or Nanonets-OCR-s (detailed extraction)"
325
+ )
326
  image_input = gr.File(
327
  label="Upload Image (PNG, JPG, etc.)", type="filepath"
328
  )
 
338
  )
339
  with gr.Column(scale=1):
340
  hf_ocr_output_textbox = gr.Markdown(
341
+ label="OCR Output",
342
  show_copy_button=True,
343
  )
344
  xml_output_textbox = gr.Textbox(
 
350
 
351
  submit_button.click(
352
  fn=process_files,
353
+ inputs=[image_input, xml_input, model_selector],
354
  outputs=[output_image_display, xml_output_textbox, hf_ocr_output_textbox],
355
  )
356