davanstrien (HF Staff) committed on
Commit 5639776 · Parent(s): 2c499db

Refactor XML parsing functions for improved readability and consistency

Files changed (1)
  1. app.py +100 -68
app.py CHANGED
@@ -14,15 +14,14 @@ MODEL_LOAD_ERROR_MSG = None
 
 HF_PROCESSOR = AutoProcessor.from_pretrained("reducto/RolmOCR")
 HF_MODEL = AutoModelForImageTextToText.from_pretrained(
-    "reducto/RolmOCR",
-    torch_dtype=torch.bfloat16,
-    device_map="auto"
+    "reducto/RolmOCR", torch_dtype=torch.bfloat16, device_map="auto"
 )
 HF_PIPE = pipeline("image-text-to-text", model=HF_MODEL, processor=HF_PROCESSOR)
 
 
 # --- Helper Functions ---
 
+
 def get_xml_namespace(xml_file_path):
     """
     Dynamically gets the namespace from the XML file.
@@ -31,16 +30,17 @@ def get_xml_namespace(xml_file_path):
     try:
         tree = ET.parse(xml_file_path)
         root = tree.getroot()
-        if '}' in root.tag:
-            ns = root.tag.split('}')[0] + '}'
+        if "}" in root.tag:
+            ns = root.tag.split("}")[0] + "}"
             # Determine format based on root element
-            if 'PcGts' in root.tag:
-                return ns, 'PAGE'
-            elif 'alto' in root.tag.lower():
-                return ns, 'ALTO'
+            if "PcGts" in root.tag:
+                return ns, "PAGE"
+            elif "alto" in root.tag.lower():
+                return ns, "ALTO"
     except ET.ParseError:
         print(f"Error parsing XML to find namespace: {xml_file_path}")
-    return '', 'UNKNOWN'
+    return "", "UNKNOWN"
+
 
 def parse_page_xml_for_text(xml_file_path):
     """
@@ -49,7 +49,7 @@ def parse_page_xml_for_text(xml_file_path):
     - full_text (str): All extracted text concatenated.
     """
     full_text_lines = []
-
+
     if not xml_file_path or not os.path.exists(xml_file_path):
         return "Error: XML file not provided or does not exist."
 
@@ -59,23 +59,23 @@ def parse_page_xml_for_text(xml_file_path):
         root = tree.getroot()
 
         # Find all TextLine elements
-        for text_line in root.findall(f'.//{ns_prefix}TextLine'):
+        for text_line in root.findall(f".//{ns_prefix}TextLine"):
             # First try to get text from TextEquiv/Unicode
-            text_equiv = text_line.find(f'{ns_prefix}TextEquiv/{ns_prefix}Unicode')
+            text_equiv = text_line.find(f"{ns_prefix}TextEquiv/{ns_prefix}Unicode")
             if text_equiv is not None and text_equiv.text:
                 full_text_lines.append(text_equiv.text)
                 continue
 
             # If no TextEquiv, try to get text from Word elements
             line_text_parts = []
-            for word in text_line.findall(f'{ns_prefix}Word'):
-                word_text = word.find(f'{ns_prefix}TextEquiv/{ns_prefix}Unicode')
+            for word in text_line.findall(f"{ns_prefix}Word"):
+                word_text = word.find(f"{ns_prefix}TextEquiv/{ns_prefix}Unicode")
                 if word_text is not None and word_text.text:
                     line_text_parts.append(word_text.text)
-
+
             if line_text_parts:
                 full_text_lines.append(" ".join(line_text_parts))
-
+
         return "\n".join(full_text_lines)
 
     except ET.ParseError as e:
@@ -83,6 +83,7 @@ def parse_page_xml_for_text(xml_file_path):
     except Exception as e:
         return f"An unexpected error occurred during XML parsing: {e}"
 
+
 def parse_alto_xml_for_text(xml_file_path):
     """
     Parses an ALTO XML file to extract text content.
@@ -90,7 +91,7 @@ def parse_alto_xml_for_text(xml_file_path):
     - full_text (str): All extracted text concatenated.
     """
     full_text_lines = []
-
+
     if not xml_file_path or not os.path.exists(xml_file_path):
         return "Error: XML file not provided or does not exist."
 
@@ -99,15 +100,15 @@ def parse_alto_xml_for_text(xml_file_path):
         tree = ET.parse(xml_file_path)
         root = tree.getroot()
 
-        for text_line in root.findall(f'.//{ns_prefix}TextLine'):
+        for text_line in root.findall(f".//{ns_prefix}TextLine"):
             line_text_parts = []
-            for string_element in text_line.findall(f'{ns_prefix}String'):
-                text = string_element.get('CONTENT')
+            for string_element in text_line.findall(f"{ns_prefix}String"):
+                text = string_element.get("CONTENT")
                 if text:
                     line_text_parts.append(text)
             if line_text_parts:
                 full_text_lines.append(" ".join(line_text_parts))
-
+
         return "\n".join(full_text_lines)
 
     except ET.ParseError as e:
@@ -115,6 +116,7 @@ def parse_alto_xml_for_text(xml_file_path):
     except Exception as e:
         return f"An unexpected error occurred during XML parsing: {e}"
 
+
 def parse_xml_for_text(xml_file_path):
     """
     Main function to parse XML files, automatically detecting the format.
@@ -124,24 +126,29 @@ def parse_xml_for_text(xml_file_path):
 
     try:
         _, xml_format = get_xml_namespace(xml_file_path)
-
-        if xml_format == 'PAGE':
+
+        if xml_format == "PAGE":
             return parse_page_xml_for_text(xml_file_path)
-        elif xml_format == 'ALTO':
+        elif xml_format == "ALTO":
            return parse_alto_xml_for_text(xml_file_path)
         else:
             return f"Error: Unsupported XML format. Expected ALTO or PAGE XML."
-
+
     except Exception as e:
         return f"Error determining XML format: {str(e)}"
 
+
 @spaces.GPU
 def predict(pil_image):
     """Performs OCR prediction using the Hugging Face model."""
     global HF_PIPE, MODEL_LOAD_ERROR_MSG
 
     if HF_PIPE is None:
-        error_to_report = MODEL_LOAD_ERROR_MSG if MODEL_LOAD_ERROR_MSG else "OCR model could not be initialized."
+        error_to_report = (
+            MODEL_LOAD_ERROR_MSG
+            if MODEL_LOAD_ERROR_MSG
+            else "OCR model could not be initialized."
+        )
         raise RuntimeError(error_to_report)
 
     # Format the message in the expected structure
@@ -150,13 +157,17 @@ def predict(pil_image):
             "role": "user",
             "content": [
                 {"type": "image", "image": pil_image},
-                {"type": "text", "text": "Return the plain text representation of this document as if you were reading it naturally.\n"}
-            ]
+                {
+                    "type": "text",
+                    "text": "Return the plain text representation of this document as if you were reading it naturally.\n",
+                },
+            ],
         }
     ]
 
     # Use the pipeline with the properly formatted messages
-    return HF_PIPE(messages,max_new_tokens=8096)
+    return HF_PIPE(messages, max_new_tokens=8096)
+
 
 def run_hf_ocr(image_path):
     """
@@ -164,53 +175,68 @@ def run_hf_ocr(image_path):
     """
     if image_path is None:
         return "No image provided for OCR."
-
+
     try:
         pil_image = Image.open(image_path).convert("RGB")
-        ocr_results = predict(pil_image) # predict handles model loading and inference
-
+        ocr_results = predict(pil_image)  # predict handles model loading and inference
+
         # Parse the output based on the user's example structure
-        if isinstance(ocr_results, list) and ocr_results and 'generated_text' in ocr_results[0]:
-            generated_content = ocr_results[0]['generated_text']
-
+        if (
+            isinstance(ocr_results, list)
+            and ocr_results
+            and "generated_text" in ocr_results[0]
+        ):
+            generated_content = ocr_results[0]["generated_text"]
+
             if isinstance(generated_content, str):
                 return generated_content
 
             if isinstance(generated_content, list) and generated_content:
                 if assistant_message := next(
                     (
-                        msg['content']
+                        msg["content"]
                         for msg in reversed(generated_content)
                         if isinstance(msg, dict)
-                        and msg.get('role') == 'assistant'
-                        and 'content' in msg
+                        and msg.get("role") == "assistant"
+                        and "content" in msg
                     ),
                     None,
                 ):
                     return assistant_message
-
+
                 # Fallback if the specific assistant message structure isn't found but there's content
-                if isinstance(generated_content[0], dict) and 'content' in generated_content[0]:
-                    if len(generated_content) > 1 and isinstance(generated_content[1], dict) and 'content' in generated_content[1]:
-                        return generated_content[1]['content'] # Assuming second part is assistant
-                    elif 'content' in generated_content[0]: # Or if first part is already the content
-                        return generated_content[0]['content']
+                if (
+                    isinstance(generated_content[0], dict)
+                    and "content" in generated_content[0]
+                ):
+                    if (
+                        len(generated_content) > 1
+                        and isinstance(generated_content[1], dict)
+                        and "content" in generated_content[1]
+                    ):
+                        return generated_content[1][
+                            "content"
+                        ]  # Assuming second part is assistant
+                    else:
+                        return generated_content[0]["content"]
 
             print(f"Unexpected OCR output structure from HF model: {ocr_results}")
             return "Error: Could not parse OCR model output. Check console."
-
+
         else:
             print(f"Unexpected OCR output structure from HF model: {ocr_results}")
            return "Error: OCR model did not return expected output. Check console."
 
-    except RuntimeError as e: # Catch model loading/initialization errors from predict
+    except RuntimeError as e:  # Catch model loading/initialization errors from predict
         return str(e)
     except Exception as e:
         print(f"Error during Hugging Face OCR processing: {e}")
         return f"Error during Hugging Face OCR: {str(e)}"
 
+
 # --- Gradio Interface Function ---
 
+
 def process_files(image_path, xml_path):
     """
     Main function for the Gradio interface.
@@ -226,7 +252,7 @@ def process_files(image_path, xml_path):
             img_to_display = Image.open(image_path).convert("RGB")
             hf_ocr_text_output = run_hf_ocr(image_path)
         except Exception as e:
-            img_to_display = None # Clear image if it failed to load
+            img_to_display = None  # Clear image if it failed to load
             hf_ocr_text_output = f"Error loading image or running HF OCR: {e}"
     else:
         hf_ocr_text_output = "Please upload an image to perform OCR."
@@ -235,10 +261,10 @@ def process_files(image_path, xml_path):
         xml_text_output = parse_xml_for_text(xml_path)
     else:
         xml_text_output = "No XML file uploaded."
-
+
     # If only XML is provided without an image
     if not image_path and xml_path:
-        img_to_display = None # No image to display
+        img_to_display = None  # No image to display
         hf_ocr_text_output = "Upload an image to perform OCR."
 
     return img_to_display, xml_text_output, hf_ocr_text_output
@@ -255,38 +281,42 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
 
     with gr.Row():
         with gr.Column(scale=1):
-            image_input = gr.File(label="Upload Image (PNG, JPG, etc.)", type="filepath")
-            xml_input = gr.File(label="Upload XML File (Optional, ALTO or PAGE format)", type="filepath")
+            image_input = gr.File(
+                label="Upload Image (PNG, JPG, etc.)", type="filepath"
+            )
+            xml_input = gr.File(
+                label="Upload XML File (Optional, ALTO or PAGE format)", type="filepath"
+            )
             submit_button = gr.Button("Process Image and XML", variant="primary")
 
     with gr.Row():
         with gr.Column(scale=1):
-            output_image_display = gr.Image(label="Uploaded Image", type="pil", interactive=False)
+            output_image_display = gr.Image(
+                label="Uploaded Image", type="pil", interactive=False
+            )
         with gr.Column(scale=1):
-            hf_ocr_output_textbox = gr.Textbox(
-                label="OCR Output (Hugging Face Model)",
-                lines=15,
-                interactive=False,
-                show_copy_button=True
+            hf_ocr_output_textbox = gr.Markdown(
+                label="OCR Output (Hugging Face Model)",
+                show_copy_button=True,
             )
             xml_output_textbox = gr.Textbox(
-                label="Text from XML",
-                lines=15,
+                label="Text from XML",
+                lines=15,
                 interactive=False,
-                show_copy_button=True
+                show_copy_button=True,
             )
-
+
     submit_button.click(
         fn=process_files,
         inputs=[image_input, xml_input],
-        outputs=[output_image_display, xml_output_textbox, hf_ocr_output_textbox]
+        outputs=[output_image_display, xml_output_textbox, hf_ocr_output_textbox],
     )
-
+
     gr.Markdown("---")
     gr.Markdown("### Example ALTO XML Snippet (for `String` element extraction):")
     gr.Code(
         value=(
-            """<alto xmlns="http://www.loc.gov/standards/alto/v3/alto.xsd">
+            """<alto xmlns="http://www.loc.gov/standards/alto/v3/alto.xsd">
   <Description>...</Description>
   <Styles>...</Styles>
   <Layout>
@@ -307,11 +337,13 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
   </Layout>
 </alto>"""
         ),
-        interactive=False
+        interactive=False,
     )
 
 if __name__ == "__main__":
     # Removed dummy file creation as it's less relevant for single file focus
     print("Attempting to launch Gradio demo...")
-    print("If the Hugging Face model is large, initial startup might take some time due to model download/loading (on first OCR attempt).")
-    demo.launch()
+    print(
+        "If the Hugging Face model is large, initial startup might take some time due to model download/loading (on first OCR attempt)."
+    )
+    demo.launch()
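
A quick way to sanity-check the parsing path this commit reformats, without loading the OCR model, is to run the same ElementTree logic against a tiny ALTO document. The sketch below is illustrative only and not part of the commit: it mirrors the namespace detection in get_xml_namespace and the String/CONTENT extraction in parse_alto_xml_for_text; the sample XML and temporary file handling are assumptions made for the example.

# Illustrative sketch only -- mirrors the refactored parsing logic rather than importing app.py
# (importing app.py would trigger the module-level RolmOCR model load).
import tempfile
import xml.etree.ElementTree as ET

ALTO_SAMPLE = """<alto xmlns="http://www.loc.gov/standards/alto/v3/alto.xsd">
  <Layout>
    <Page>
      <PrintSpace>
        <TextBlock>
          <TextLine>
            <String CONTENT="Hello"/>
            <String CONTENT="world"/>
          </TextLine>
        </TextBlock>
      </PrintSpace>
    </Page>
  </Layout>
</alto>"""

with tempfile.NamedTemporaryFile("w", suffix=".xml", delete=False) as f:
    f.write(ALTO_SAMPLE)
    xml_path = f.name

root = ET.parse(xml_path).getroot()
# Same namespace trick as get_xml_namespace: the "{...}" prefix of the root tag.
ns = root.tag.split("}")[0] + "}" if "}" in root.tag else ""
xml_format = "PAGE" if "PcGts" in root.tag else "ALTO" if "alto" in root.tag.lower() else "UNKNOWN"

# Same extraction as parse_alto_xml_for_text: join String/@CONTENT values per TextLine.
lines = []
for text_line in root.findall(f".//{ns}TextLine"):
    parts = [s.get("CONTENT") for s in text_line.findall(f"{ns}String") if s.get("CONTENT")]
    if parts:
        lines.append(" ".join(parts))

print(xml_format)        # ALTO
print("\n".join(lines))  # Hello world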