AnseMin committed
Commit 2184c47 · Parent: f89451e

restore to checkpoint

Files changed (1)
src/parsers/got_ocr_parser.py (+47 −64)
src/parsers/got_ocr_parser.py CHANGED
@@ -21,8 +21,7 @@ class GotOcrParser(DocumentParser):
     """
 
     _model = None
-    _processor = None
-    _stop_str = "<|im_end|>"
+    _tokenizer = None
 
     @classmethod
     def get_name(cls) -> str:
@@ -68,13 +67,19 @@ class GotOcrParser(DocumentParser):
     @classmethod
     def _load_model(cls):
         """Load the GOT-OCR model and tokenizer if not already loaded."""
-        if cls._model is None or cls._processor is None:
+        if cls._model is None or cls._tokenizer is None:
             try:
                 # Import dependencies inside the method to avoid global import errors
                 import torch
-                from transformers import AutoModel, AutoProcessor
+                from transformers import AutoModel, AutoTokenizer
 
-                logger.info("Loading GOT-OCR model and processor...")
+                logger.info("Loading GOT-OCR model and tokenizer...")
+
+                # Load tokenizer
+                cls._tokenizer = AutoTokenizer.from_pretrained(
+                    'stepfun-ai/GOT-OCR2_0',
+                    trust_remote_code=True
+                )
 
                 # Determine device
                 device_map = 'cuda' if torch.cuda.is_available() else 'auto'
@@ -83,18 +88,15 @@ class GotOcrParser(DocumentParser):
                 else:
                     logger.warning("Using CPU for model inference (not recommended)")
 
-                # Load the processor (includes tokenizer)
-                cls._processor = AutoProcessor.from_pretrained(
-                    'stepfun-ai/GOT-OCR2_0-hf'
-                )
-
                 # Load model with explicit float16 for T4 compatibility
                 cls._model = AutoModel.from_pretrained(
-                    'stepfun-ai/GOT-OCR2_0-hf',
+                    'stepfun-ai/GOT-OCR2_0',
+                    trust_remote_code=True,
                     low_cpu_mem_usage=True,
                     device_map=device_map,
+                    use_safetensors=True,
                     torch_dtype=torch.float16,  # Force float16 for T4 compatibility
-                    trust_remote_code=True
+                    pad_token_id=cls._tokenizer.eos_token_id
                 )
 
                 # Explicitly convert model to half precision (float16)
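Taken together, the two hunks above restore the trust_remote_code load path. A minimal, self-contained sketch of that path (arguments copied from the diff; assumes a CUDA-capable machine such as a T4):

import torch
from transformers import AutoModel, AutoTokenizer

# Sketch of the restored load path; arguments are copied from the diff.
tokenizer = AutoTokenizer.from_pretrained(
    'stepfun-ai/GOT-OCR2_0',
    trust_remote_code=True
)
model = AutoModel.from_pretrained(
    'stepfun-ai/GOT-OCR2_0',
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    device_map='cuda' if torch.cuda.is_available() else 'auto',
    use_safetensors=True,
    torch_dtype=torch.float16,  # float16 (not bfloat16) for T4 compatibility
    pad_token_id=tokenizer.eos_token_id,
)
model = model.eval()
if torch.cuda.is_available():
    model = model.half().cuda()  # mirror the parser's explicit half-precision cast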
@@ -105,6 +107,7 @@ class GotOcrParser(DocumentParser):
                     cls._model = cls._model.cuda()
 
                 # Patch torch.autocast to force float16 instead of bfloat16
+                # This fixes the issue in the model's chat method (line 581)
                 original_autocast = torch.autocast
                 def patched_autocast(*args, **kwargs):
                     # Force dtype to float16 when CUDA is involved
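The hunk above only shows the start of the autocast patch; its body lies outside the diff. A sketch of the idea, assuming the elided body simply rewrites the requested dtype:

import torch

original_autocast = torch.autocast

def patched_autocast(*args, **kwargs):
    # Swap bfloat16 for float16, which T4 GPUs support natively
    if kwargs.get('dtype') == torch.bfloat16:
        kwargs['dtype'] = torch.float16
    return original_autocast(*args, **kwargs)

torch.autocast = patched_autocast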
@@ -120,7 +123,7 @@ class GotOcrParser(DocumentParser):
                 return True
             except Exception as e:
                 cls._model = None
-                cls._processor = None
+                cls._tokenizer = None
                 logger.error(f"Failed to load GOT-OCR model: {str(e)}")
                 return False
         return True
@@ -135,9 +138,9 @@ class GotOcrParser(DocumentParser):
             del cls._model
             cls._model = None
 
-            if cls._processor is not None:
-                del cls._processor
-                cls._processor = None
+            if cls._tokenizer is not None:
+                del cls._tokenizer
+                cls._tokenizer = None
 
             # Clear CUDA cache if available
             if torch.cuda.is_available():
@@ -172,7 +175,6 @@ class GotOcrParser(DocumentParser):
 
         # Import torch here to ensure it's available
         import torch
-        from transformers.image_utils import load_image
 
         # Validate file path and extension
        file_path = Path(file_path)
@@ -185,43 +187,31 @@ class GotOcrParser(DocumentParser):
                 f"Received file with extension: {file_path.suffix}"
             )
 
-        # Determine format flag based on OCR method
-        format_flag = ocr_method == "format"
-        logger.info(f"Using OCR method: {'format' if format_flag else 'plain'}")
+        # Determine OCR type based on method
+        ocr_type = "format" if ocr_method == "format" else "ocr"
+        logger.info(f"Using OCR method: {ocr_type}")
 
         # Process the image
         try:
            logger.info(f"Processing image with GOT-OCR: {file_path}")
 
-            # Load image with transformers utils
-            image = load_image(str(file_path))
-
            # First attempt: Normal processing with autocast
            try:
                with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
-                    # Process image with format flag if needed
-                    if format_flag:
-                        inputs = self._processor(image, return_tensors="pt", format=True).to("cuda")
+                    # Use ocr_type='format' when formatted output is requested
+                    if ocr_type == "format":
+                        result = self._model.chat(
+                            self._tokenizer,
+                            str(file_path),
+                            ocr_type='format'
+                        )
                    else:
-                        inputs = self._processor(image, return_tensors="pt").to("cuda")
-
-                    # Generate text
-                    generate_ids = self._model.generate(
-                        **inputs,
-                        do_sample=False,
-                        tokenizer=self._processor.tokenizer,
-                        stop_strings=self._stop_str,
-                        max_new_tokens=4096,
-                    )
-
-                    # Decode the generated text
-                    result = self._processor.decode(
-                        generate_ids[0, inputs["input_ids"].shape[1]:],
-                        skip_special_tokens=True,
-                    )
-
-                    return result
-
+                        result = self._model.chat(
+                            self._tokenizer,
+                            str(file_path),
+                            ocr_type='ocr'
+                        )
+                    return result
            except RuntimeError as e:
                # Check if it's a bfloat16 error
                if "bfloat16" in str(e) or "BFloat16" in str(e):
@@ -237,26 +227,19 @@ class GotOcrParser(DocumentParser):
                    torch.set_default_dtype(torch.float16)
 
                    with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
-                        # Process image with format flag if needed
-                        if format_flag:
-                            inputs = self._processor(image, return_tensors="pt", format=True).to("cuda")
+                        # Use ocr_type='format' when formatted output is requested
+                        if ocr_type == "format":
+                            result = self._model.chat(
+                                self._tokenizer,
+                                str(file_path),
+                                ocr_type='format'
+                            )
                        else:
-                            inputs = self._processor(image, return_tensors="pt").to("cuda")
-
-                        # Generate text
-                        generate_ids = self._model.generate(
-                            **inputs,
-                            do_sample=False,
-                            tokenizer=self._processor.tokenizer,
-                            stop_strings=self._stop_str,
-                            max_new_tokens=4096,
-                        )
-
-                        # Decode the generated text
-                        result = self._processor.decode(
-                            generate_ids[0, inputs["input_ids"].shape[1]:],
-                            skip_special_tokens=True,
-                        )
+                            result = self._model.chat(
+                                self._tokenizer,
+                                str(file_path),
+                                ocr_type='ocr'
+                            )
 
                    # Restore default dtype
                    torch.set_default_dtype(original_dtype)
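The fallback hunk retries the same call with float16 pinned as the global default dtype. A condensed sketch of that flow (the try/finally is an editorial assumption; the diff restores the dtype in straight-line code):

import torch

original_dtype = torch.get_default_dtype()
torch.set_default_dtype(torch.float16)
try:
    with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
        result = model.chat(tokenizer, 'page.png', ocr_type='ocr')
finally:
    # Restore the default dtype so unrelated code keeps using float32
    torch.set_default_dtype(original_dtype)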
 