davanstrien (HF Staff) committed
Commit 58b56ea · Parent(s): 864e5c4

Refactor OCR model loading to use lazy initialization and enhance error handling in predict function

Files changed (1):
  1. app.py +62 -56
app.py CHANGED
@@ -5,21 +5,12 @@ import os
 import torch
 from transformers import AutoProcessor, AutoModelForImageTextToText, pipeline
 import spaces
-# --- Global Model and Processor Initialization ---
-# Load the OCR model and processor once when the app starts
-try:
-    HF_PROCESSOR = AutoProcessor.from_pretrained("reducto/RolmOCR")
-    HF_MODEL = AutoModelForImageTextToText.from_pretrained(
-        "reducto/RolmOCR",
-        torch_dtype=torch.bfloat16,
-        # attn_implementation="flash_attention_2", # User had this commented out
-        device_map="auto"
-    )
-    HF_PIPE = pipeline("image-text-to-text", model=HF_MODEL, processor=HF_PROCESSOR)
-    print("Hugging Face OCR model loaded successfully.")
-except Exception as e:
-    print(f"Error loading Hugging Face model: {e}")
-    HF_PIPE = None
+
+# --- Global Model and Processor (initialize as None for lazy loading) ---
+HF_PROCESSOR = None
+HF_MODEL = None
+HF_PIPE = None
+MODEL_LOAD_ERROR_MSG = None  # To store any error message from loading
 
 # --- Helper Functions ---
 
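The hunk above swaps eager, import-time loading for module-level placeholders. Stripped of the transformers specifics, the lazy-initialization pattern it sets up looks roughly like this (a minimal sketch; load_resource is a hypothetical stand-in for the processor/model/pipeline construction, not a function in app.py):

    # Minimal sketch of lazy initialization with a cached load error.
    _RESOURCE = None
    _LOAD_ERROR = None

    def load_resource():
        # Hypothetical stand-in for the expensive model/pipeline setup.
        return object()

    def get_resource():
        global _RESOURCE, _LOAD_ERROR
        if _RESOURCE is None and _LOAD_ERROR is None:
            try:
                _RESOURCE = load_resource()  # runs once, on first call
            except Exception as e:
                _LOAD_ERROR = str(e)  # cache the failure instead of retrying every call
        if _RESOURCE is None:
            raise RuntimeError(_LOAD_ERROR or "resource could not be initialized")
        return _RESOURCE

Caching the error message means a failed load is reported on every later call without re-attempting the download each time, which is what the MODEL_LOAD_ERROR_MSG global is for.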
@@ -68,72 +59,87 @@ def parse_alto_xml_for_text(xml_file_path):
     except Exception as e:
         return f"An unexpected error occurred during XML parsing: {e}"
 
+@spaces.GPU  # Ensures GPU is available for model loading (on first call) and inference
+def predict(pil_image):
+    """Performs OCR prediction using the Hugging Face model, with lazy loading."""
+    global HF_PROCESSOR, HF_MODEL, HF_PIPE, MODEL_LOAD_ERROR_MSG
+
+    if HF_PIPE is None and MODEL_LOAD_ERROR_MSG is None:
+        try:
+            print("Attempting to load Hugging Face model and processor within @spaces.GPU context...")
+            HF_PROCESSOR = AutoProcessor.from_pretrained("reducto/RolmOCR")
+            HF_MODEL = AutoModelForImageTextToText.from_pretrained(
+                "reducto/RolmOCR",
+                torch_dtype=torch.bfloat16,
+                device_map="auto"  # Should utilize ZeroGPU correctly here
+            )
+            HF_PIPE = pipeline("image-text-to-text", model=HF_MODEL, processor=HF_PROCESSOR)
+            print("Hugging Face OCR model loaded successfully.")
+        except Exception as e:
+            MODEL_LOAD_ERROR_MSG = f"Error loading Hugging Face model: {str(e)}"
+            print(MODEL_LOAD_ERROR_MSG)
+            # HF_PIPE remains None, error message is stored
+
+    if HF_PIPE is None:
+        error_to_report = MODEL_LOAD_ERROR_MSG if MODEL_LOAD_ERROR_MSG else "OCR model could not be initialized."
+        raise RuntimeError(error_to_report)
+
+    # Proceed with inference if pipe is available
+    return HF_PIPE(
+        pil_image,
+        prompt="Return the plain text representation of this document as if you were reading it naturally.\n",
+    )
+
 def run_hf_ocr(image_path):
     """
-    Runs OCR on the provided image using the pre-loaded Hugging Face model.
+    Runs OCR on the provided image using the Hugging Face model (via predict function).
     """
-    if HF_PIPE is None:
-        return "Hugging Face OCR model not available."
     if image_path is None:
         return "No image provided for OCR."
 
     try:
-        # Load the image using PIL, as the pipeline expects an image object or path
         pil_image = Image.open(image_path).convert("RGB")
-
-        # The user's example output for the pipeline call was:
-        # [{'generated_text': [{'role': 'user', ...}, {'role': 'assistant', 'content': "TEXT..."}]}]
-        # This suggests the pipeline is returning a conversational style output.
-        # We will try to call the pipeline with the image and prompt directly.
-        ocr_results = predict(pil_image)
+        ocr_results = predict(pil_image)  # predict handles model loading and inference
 
         # Parse the output based on the user's example structure
        if isinstance(ocr_results, list) and ocr_results and 'generated_text' in ocr_results[0]:
             generated_content = ocr_results[0]['generated_text']
 
-            # Check if generated_content itself is the direct text (some pipelines do this)
             if isinstance(generated_content, str):
                 return generated_content
 
-            # Check for the conversational structure
-            # [{'role': 'user', ...}, {'role': 'assistant', 'content': "TEXT..."}]
             if isinstance(generated_content, list) and generated_content:
-                # The assistant's response is typically the last message in the list
-                # or specifically the one with role 'assistant'.
-                assistant_message = None
-                for msg in reversed(generated_content):  # Check from the end
-                    if isinstance(msg, dict) and msg.get('role') == 'assistant' and 'content' in msg:
-                        assistant_message = msg['content']
-                        break
-                if assistant_message:
+                if assistant_message := next(
+                    (
+                        msg['content']
+                        for msg in reversed(generated_content)
+                        if isinstance(msg, dict)
+                        and msg.get('role') == 'assistant'
+                        and 'content' in msg
+                    ),
+                    None,
+                ):
                     return assistant_message
-
-            # Fallback if parsing the complex structure fails but we got some string
-            if isinstance(generated_content, list) and generated_content and isinstance(generated_content[0], dict) and 'content' in generated_content[0]:
-                # This is a guess if the structure is simpler than expected.
-                # Or if the first part is the user prompt echo and second is assistant.
-                if len(generated_content) > 1 and isinstance(generated_content[1], dict) and 'content' in generated_content[1]:
-                    return generated_content[1]['content']  # Assuming second part is assistant
+
+                # Fallback if the specific assistant message structure isn't found but there's content
+                if isinstance(generated_content[0], dict) and 'content' in generated_content[0]:
+                    if len(generated_content) > 1 and isinstance(generated_content[1], dict) and 'content' in generated_content[1]:
+                        return generated_content[1]['content']  # Assuming second part is assistant
+                    elif 'content' in generated_content[0]:  # Or if first part is already the content
+                        return generated_content[0]['content']
 
             print(f"Unexpected OCR output structure from HF model: {ocr_results}")
-            return "Error: Could not parse OCR model output. Please check console for details."
+            return "Error: Could not parse OCR model output. Check console."
 
         else:
             print(f"Unexpected OCR output structure from HF model: {ocr_results}")
-            return "Error: OCR model did not return expected output. Please check console for details."
+            return "Error: OCR model did not return expected output. Check console."
 
+    except RuntimeError as e:  # Catch model loading/initialization errors from predict
+        return str(e)
     except Exception as e:
-        print(f"Error during Hugging Face OCR: {e}")
+        print(f"Error during Hugging Face OCR processing: {e}")
         return f"Error during Hugging Face OCR: {str(e)}"
-@spaces.GPU
-def predict(pil_image):
-    ocr_results = HF_PIPE(
-        pil_image,
-        prompt="Return the plain text representation of this document as if you were reading it naturally.\n"
-        # The pipeline should handle formatting this into messages if needed by the model.
-    )
-
-    return ocr_results
 
 # --- Gradio Interface Function ---
 
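For reference, the conversational output shape the new next(...) extraction targets is the one documented in the removed comments: [{'generated_text': [{'role': 'user', ...}, {'role': 'assistant', 'content': "TEXT..."}]}]. A standalone check of that extraction, using a fabricated result (the payload text is illustrative, not real model output):

    # Fabricated sample shaped like the pipeline's conversational output.
    ocr_results = [{
        'generated_text': [
            {'role': 'user', 'content': 'Return the plain text representation of this document...'},
            {'role': 'assistant', 'content': 'TEXT...'},
        ]
    }]

    generated_content = ocr_results[0]['generated_text']
    assistant_message = next(
        (
            msg['content']
            for msg in reversed(generated_content)
            if isinstance(msg, dict) and msg.get('role') == 'assistant' and 'content' in msg
        ),
        None,
    )
    assert assistant_message == 'TEXT...'

Scanning from the end via reversed() picks the last assistant turn, which matters if the model ever returns more than one message with that role.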
@@ -241,5 +247,5 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
 if __name__ == "__main__":
     # Removed dummy file creation as it's less relevant for single file focus
     print("Attempting to launch Gradio demo...")
-    print("If the Hugging Face model is large, initial startup might take some time due to model download/loading.")
+    print("If the Hugging Face model is large, initial startup might take some time due to model download/loading (on first OCR attempt).")
     demo.launch()
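The "enhance error handling" half of the commit is a two-level contract: predict raises RuntimeError when the pipeline is unavailable, and run_hf_ocr catches exactly that and returns it as a plain string for the Gradio output. Reduced to a self-contained sketch (generic names, not the app's exact code):

    def predict_sketch(pipe):
        if pipe is None:
            raise RuntimeError('Error loading Hugging Face model: <details>')
        return 'recognized text'

    def run_ocr_sketch(pipe) -> str:
        try:
            return predict_sketch(pipe)
        except RuntimeError as e:  # load/initialization failure becomes UI text
            return str(e)

    print(run_ocr_sketch(None))     # -> Error loading Hugging Face model: <details>
    print(run_ocr_sketch('ready'))  # -> recognized text

This keeps a load failure from crashing the request; the user sees the stored error message instead.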
 
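A note on why loading moved inside predict: on ZeroGPU Spaces, the @spaces.GPU decorator attaches a GPU only for the duration of the decorated call, so running from_pretrained(..., device_map="auto") inside that window is, as the commit's own comments suggest, what lets device placement see the GPU. The decorator usage, reduced to a skeleton (hedged; exact ZeroGPU scheduling behavior is outside this diff):

    import spaces

    @spaces.GPU  # a GPU is allocated only while this function runs
    def gpu_task(x):
        # Heavy setup and inference both happen inside the GPU window,
        # which is why the commit moves model loading into predict().
        return x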