AnseMin committed · commit fa54d05 · 1 parent: c4c3253

Commit message: "Error: Too many output"

Files changed (1): src/parsers/got_ocr_parser.py (+65, −49)
src/parsers/got_ocr_parser.py CHANGED
@@ -21,7 +21,8 @@ class GotOcrParser(DocumentParser):
     """
 
     _model = None
-    _tokenizer = None
+    _processor = None
+    _stop_str = "<|im_end|>"
 
     @classmethod
     def get_name(cls) -> str:
@@ -67,19 +68,13 @@ class GotOcrParser(DocumentParser):
     @classmethod
     def _load_model(cls):
         """Load the GOT-OCR model and tokenizer if not already loaded."""
-        if cls._model is None or cls._tokenizer is None:
+        if cls._model is None or cls._processor is None:
             try:
                 # Import dependencies inside the method to avoid global import errors
                 import torch
-                from transformers import AutoModel, AutoTokenizer
+                from transformers import AutoModelForImageTextToText, AutoProcessor
 
-                logger.info("Loading GOT-OCR model and tokenizer...")
-
-                # Load tokenizer
-                cls._tokenizer = AutoTokenizer.from_pretrained(
-                    'stepfun-ai/GOT-OCR2_0',
-                    trust_remote_code=True
-                )
+                logger.info("Loading GOT-OCR model and processor...")
 
                 # Determine device
                 device_map = 'cuda' if torch.cuda.is_available() else 'auto'
@@ -88,15 +83,17 @@ class GotOcrParser(DocumentParser):
                 else:
                     logger.warning("Using CPU for model inference (not recommended)")
 
+                # Load the processor (includes tokenizer)
+                cls._processor = AutoProcessor.from_pretrained(
+                    'stepfun-ai/GOT-OCR2_0-hf'
+                )
+
                 # Load model with explicit float16 for T4 compatibility
-                cls._model = AutoModel.from_pretrained(
-                    'stepfun-ai/GOT-OCR2_0',
-                    trust_remote_code=True,
+                cls._model = AutoModelForImageTextToText.from_pretrained(
+                    'stepfun-ai/GOT-OCR2_0-hf',
                     low_cpu_mem_usage=True,
                     device_map=device_map,
-                    use_safetensors=True,
-                    torch_dtype=torch.float16,  # Force float16 for T4 compatibility
-                    pad_token_id=cls._tokenizer.eos_token_id
+                    torch_dtype=torch.float16  # Force float16 for T4 compatibility
                 )
 
                 # Explicitly convert model to half precision (float16)
@@ -107,7 +104,6 @@ class GotOcrParser(DocumentParser):
                     cls._model = cls._model.cuda()
 
                 # Patch torch.autocast to force float16 instead of bfloat16
-                # This fixes the issue in the model's chat method (line 581)
                 original_autocast = torch.autocast
                 def patched_autocast(*args, **kwargs):
                     # Force dtype to float16 when CUDA is involved
@@ -123,7 +119,7 @@ class GotOcrParser(DocumentParser):
                 return True
             except Exception as e:
                 cls._model = None
-                cls._tokenizer = None
+                cls._processor = None
                 logger.error(f"Failed to load GOT-OCR model: {str(e)}")
                 return False
         return True
@@ -138,9 +134,9 @@ class GotOcrParser(DocumentParser):
             del cls._model
             cls._model = None
 
-        if cls._tokenizer is not None:
-            del cls._tokenizer
-            cls._tokenizer = None
+        if cls._processor is not None:
+            del cls._processor
+            cls._processor = None
 
         # Clear CUDA cache if available
         if torch.cuda.is_available():
@@ -175,6 +171,7 @@ class GotOcrParser(DocumentParser):
 
         # Import torch here to ensure it's available
        import torch
+        from transformers.image_utils import load_image
 
         # Validate file path and extension
         file_path = Path(file_path)
@@ -187,31 +184,43 @@ class GotOcrParser(DocumentParser):
                 f"Received file with extension: {file_path.suffix}"
             )
 
-        # Determine OCR type based on method
-        ocr_type = "format" if ocr_method == "format" else "ocr"
-        logger.info(f"Using OCR method: {ocr_type}")
+        # Determine format flag based on OCR method
+        format_flag = ocr_method == "format"
+        logger.info(f"Using OCR method: {'format' if format_flag else 'plain'}")
 
         # Process the image
         try:
             logger.info(f"Processing image with GOT-OCR: {file_path}")
 
+            # Load image with transformers utils
+            image = load_image(str(file_path))
+
             # First attempt: Normal processing with autocast
             try:
                 with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
-                    # Use format=True parameter when ocr_type is "format"
-                    if ocr_type == "format":
-                        result = self._model.chat(
-                            self._tokenizer,
-                            str(file_path),
-                            ocr_type='format'
-                        )
+                    # Process image with format flag if needed
+                    if format_flag:
+                        inputs = self._processor(image, return_tensors="pt", format=True).to("cuda")
                     else:
-                        result = self._model.chat(
-                            self._tokenizer,
-                            str(file_path),
-                            ocr_type='ocr'
-                        )
-                    return result
+                        inputs = self._processor(image, return_tensors="pt").to("cuda")
+
+                    # Generate text
+                    generate_ids = self._model.generate(
+                        **inputs,
+                        do_sample=False,
+                        tokenizer=self._processor.tokenizer,
+                        stop_strings=self._stop_str,
+                        max_new_tokens=4096,
+                    )
+
+                    # Decode the generated text
+                    result = self._processor.decode(
+                        generate_ids[0, inputs["input_ids"].shape[1]:],
+                        skip_special_tokens=True,
+                    )
+
+                    return result
+
            except RuntimeError as e:
                 # Check if it's a bfloat16 error
                 if "bfloat16" in str(e) or "BFloat16" in str(e):
@@ -227,19 +236,26 @@ class GotOcrParser(DocumentParser):
                     torch.set_default_dtype(torch.float16)
 
                     with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
-                        # Use format=True parameter when ocr_type is "format"
-                        if ocr_type == "format":
-                            result = self._model.chat(
-                                self._tokenizer,
-                                str(file_path),
-                                ocr_type='format'
-                            )
+                        # Process image with format flag if needed
+                        if format_flag:
+                            inputs = self._processor(image, return_tensors="pt", format=True).to("cuda")
                         else:
-                            result = self._model.chat(
-                                self._tokenizer,
-                                str(file_path),
-                                ocr_type='ocr'
-                            )
+                            inputs = self._processor(image, return_tensors="pt").to("cuda")
+
+                        # Generate text
+                        generate_ids = self._model.generate(
+                            **inputs,
+                            do_sample=False,
+                            tokenizer=self._processor.tokenizer,
+                            stop_strings=self._stop_str,
+                            max_new_tokens=4096,
+                        )
+
+                        # Decode the generated text
+                        result = self._processor.decode(
+                            generate_ids[0, inputs["input_ids"].shape[1]:],
+                            skip_special_tokens=True,
+                        )
 
                     # Restore default dtype
                     torch.set_default_dtype(original_dtype)
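
For reference, the HF-native inference path this commit adopts can be exercised outside the parser. The following is a minimal sketch, assuming transformers >= 4.48 (which ships AutoModelForImageTextToText and the GOT-OCR2 processor) and a CUDA device; the checkpoint id, generation arguments, and stop string are taken from the diff above, while 'sample.png' is a hypothetical input file:

# Minimal standalone sketch of the new GOT-OCR2 inference path.
# Assumptions: transformers >= 4.48, a CUDA-capable GPU, and a
# hypothetical input file 'sample.png'.
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor
from transformers.image_utils import load_image

MODEL_ID = 'stepfun-ai/GOT-OCR2_0-hf'

# The processor bundles the image preprocessor and the tokenizer
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID,
    low_cpu_mem_usage=True,
    device_map='cuda',
    torch_dtype=torch.float16,  # float16 rather than bfloat16 for T4-class GPUs
)

image = load_image('sample.png')

# format=True requests formatted output; drop it for plain OCR,
# matching the parser's format_flag branch above.
inputs = processor(image, return_tensors="pt", format=True).to('cuda')

generate_ids = model.generate(
    **inputs,
    do_sample=False,
    tokenizer=processor.tokenizer,  # required for stop_strings to apply
    stop_strings="<|im_end|>",      # the parser's _stop_str
    max_new_tokens=4096,
)

# Slice off the echoed prompt so only newly generated text is decoded
result = processor.decode(
    generate_ids[0, inputs["input_ids"].shape[1]:],
    skip_special_tokens=True,
)
print(result)

Decoding from inputs["input_ids"].shape[1] onward strips the prompt tokens that generate() echoes back, which is why the parser returns only the OCR text rather than the full sequence.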