AnseMin committed on
Commit
610b772
·
1 Parent(s): 62f9c09

handling zerogpu usage

Browse files
Files changed (4) hide show
  1. app.py +8 -0
  2. requirements.txt +1 -0
  3. setup.sh +5 -0
  4. src/parsers/got_ocr_parser.py +93 -82
app.py CHANGED
@@ -47,6 +47,14 @@ except Exception:
47
  print("WARNING: Hugging Face CLI not found. Installing...")
48
  subprocess.run([sys.executable, "-m", "pip", "install", "-q", "huggingface_hub[cli]"], check=False)
49
 
 
 
 
 
 
 
 
 
50
  # Try to load environment variables from .env file
51
  try:
52
  from dotenv import load_dotenv
 
47
  print("WARNING: Hugging Face CLI not found. Installing...")
48
  subprocess.run([sys.executable, "-m", "pip", "install", "-q", "huggingface_hub[cli]"], check=False)
49
 
50
+ # Check if spaces module is installed (needed for ZeroGPU)
51
+ try:
52
+ import spaces
53
+ print("Spaces module found for ZeroGPU support")
54
+ except ImportError:
55
+ print("WARNING: Spaces module not found. Installing...")
56
+ subprocess.run([sys.executable, "-m", "pip", "install", "-q", "spaces"], check=False)
57
+
58
  # Try to load environment variables from .env file
59
  try:
60
  from dotenv import load_dotenv
requirements.txt CHANGED
@@ -3,6 +3,7 @@ gradio==5.14.0
3
  markdown==3.7
4
  Pillow>=9.0.0,<11.0.0
5
  numpy<2.0.0
 
6
 
7
  # Image processing
8
  opencv-python-headless>=4.5.0 # Headless version for server environments
 
3
  markdown==3.7
4
  Pillow>=9.0.0,<11.0.0
5
  numpy<2.0.0
6
+ spaces # For ZeroGPU support
7
 
8
  # Image processing
9
  opencv-python-headless>=4.5.0 # Headless version for server environments
setup.sh CHANGED
@@ -45,6 +45,11 @@ echo "Installing Hugging Face CLI..."
45
  pip install -q -U "huggingface_hub[cli]"
46
  echo "Hugging Face CLI installed successfully"
47
 
 
 
 
 
 
48
  # Add debug section for GOT-OCR repo
49
  echo "===== GOT-OCR Repository Debugging ====="
50
 
 
45
  pip install -q -U "huggingface_hub[cli]"
46
  echo "Hugging Face CLI installed successfully"
47
 
48
+ # Install spaces module for ZeroGPU support
49
+ echo "Installing spaces module for ZeroGPU support..."
50
+ pip install -q -U spaces
51
+ echo "Spaces module installed successfully"
52
+
53
  # Add debug section for GOT-OCR repo
54
  echo "===== GOT-OCR Repository Debugging ====="
55
 
src/parsers/got_ocr_parser.py CHANGED
@@ -7,7 +7,13 @@ import tempfile
7
  import shutil
8
  from typing import Dict, List, Optional, Any, Union
9
 
10
- import spaces # Import spaces module for ZeroGPU support
 
 
 
 
 
 
11
  from src.parsers.parser_interface import DocumentParser
12
  from src.parsers.parser_registry import ParserRegistry
13
 
@@ -72,8 +78,9 @@ class GotOcrParser(DocumentParser):
72
  import transformers
73
  import tiktoken
74
 
75
- # For ZeroGPU, we don't need to check CUDA availability here
76
- # as the GPU will be allocated when needed
 
77
 
78
  # Check for latex2markdown
79
  try:
@@ -195,13 +202,9 @@ class GotOcrParser(DocumentParser):
195
  logger.error(f"Failed to set up GOT-OCR2.0 repository: {str(e)}")
196
  return False
197
 
198
- @spaces.GPU(duration=120) # Set duration to 120 seconds for OCR processing
199
  def parse(self, file_path: Union[str, Path], ocr_method: Optional[str] = None, **kwargs) -> str:
200
  """Parse a document using GOT-OCR 2.0.
201
 
202
- This method is decorated with @spaces.GPU to enable ZeroGPU support.
203
- When called, it will request a GPU from the ZeroGPU pool.
204
-
205
  Args:
206
  file_path: Path to the image file
207
  ocr_method: OCR method to use ('plain' or 'format')
@@ -284,95 +287,67 @@ class GotOcrParser(DocumentParser):
284
  f.write(f"cd {parent_dir}\n") # Change to parent directory
285
  f.write("export PYTHONPATH=$PYTHONPATH:$(pwd)\n") # Add current directory to PYTHONPATH
286
 
287
- # Add environment variables for ZeroGPU support
288
- f.write("export SPACES_ZERO_GPU=1\n") # Enable ZeroGPU
289
- f.write("export CUDA_VISIBLE_DEVICES=0\n") # Use first available GPU
290
-
291
  # Add a Python script to patch torch.bfloat16
292
  patch_script = os.path.join(tempfile.gettempdir(), "patch_torch.py")
293
  with open(patch_script, 'w') as patch_f:
294
  patch_f.write("""
295
  import sys
296
  import torch
297
- import spaces
298
-
299
- @spaces.GPU(duration=120)
300
- def patch_torch():
301
- # Patch torch.bfloat16 to use torch.float16 instead
302
- if hasattr(torch, 'bfloat16'):
303
- # Save reference to original bfloat16
304
- original_bfloat16 = torch.bfloat16
305
- # Replace with float16
306
- torch.bfloat16 = torch.float16
307
- print("Successfully patched torch.bfloat16 to use torch.float16")
308
 
309
- # Also patch torch.autocast context manager for CUDA
310
- original_autocast = torch.autocast
311
- def patched_autocast(*args, **kwargs):
312
- # Force dtype to float16 when CUDA is involved
313
- if args and args[0] == "cuda" and kwargs.get("dtype") == torch.bfloat16:
314
- kwargs["dtype"] = torch.float16
315
- print(f"Autocast: Changed bfloat16 to float16 for {args}")
316
- return original_autocast(*args, **kwargs)
317
 
318
- torch.autocast = patched_autocast
319
- print("Successfully patched torch.autocast to ensure float16 is used instead of bfloat16")
 
 
 
 
 
 
320
 
321
- patch_torch() # Execute the patching
 
322
  """)
323
-
324
- # Build the command with the patch included and ZeroGPU support
325
- py_cmd = [
326
- sys.executable,
327
- "-c",
328
- f"""
329
- import sys
330
- import spaces
331
- sys.path.insert(0, '{parent_dir}')
332
- exec(open('{patch_script}').read())
333
-
334
- @spaces.GPU(duration=120)
335
- def run_got_ocr():
336
- import runpy
337
- runpy.run_path('{script_path}', run_name='__main__')
338
-
339
- run_got_ocr()
340
- """
341
- ]
342
-
343
- # Add the arguments
344
- py_cmd.extend(["--model-name", self._weights_path])
345
- py_cmd.extend(["--image-file", str(file_path)])
346
- py_cmd.extend(["--type", ocr_type])
347
-
348
- # Add render flag if required
349
- if render:
350
- py_cmd.append("--render")
351
-
352
- # Check if box or color is specified in kwargs
353
- if 'box' in kwargs and kwargs['box']:
354
- py_cmd.extend(["--box", str(kwargs['box'])])
355
-
356
- if 'color' in kwargs and kwargs['color']:
357
- py_cmd.extend(["--color", kwargs['color']])
358
-
359
- # Add the command to the script
360
- f.write(" ".join(py_cmd) + "\n")
361
 
362
  # Make the script executable
363
  os.chmod(temp_script, 0o755)
364
 
365
- # Run the script
366
- logger.info(f"Running command through wrapper script: {temp_script}")
367
- process = subprocess.run(
368
- [temp_script],
369
- check=True,
370
- capture_output=True,
371
- text=True
372
- )
373
-
374
- # Process the output
375
- result = process.stdout.strip()
376
 
377
  # If render was requested, find and return the path to the HTML file
378
  if render:
@@ -417,6 +392,42 @@ run_got_ocr()
417
 
418
  # Generic error
419
  raise RuntimeError(f"Error processing document with GOT-OCR: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
420
 
421
  @classmethod
422
  def release_model(cls):
 
7
  import shutil
8
  from typing import Dict, List, Optional, Any, Union
9
 
10
+ # Import spaces module for ZeroGPU support
11
+ try:
12
+ import spaces
13
+ HAS_SPACES = True
14
+ except ImportError:
15
+ HAS_SPACES = False
16
+
17
  from src.parsers.parser_interface import DocumentParser
18
  from src.parsers.parser_registry import ParserRegistry
19
 
 
78
  import transformers
79
  import tiktoken
80
 
81
+ # Check CUDA availability if using torch
82
+ if hasattr(torch, 'cuda') and not torch.cuda.is_available():
83
+ logger.warning("CUDA is not available. GOT-OCR performs best with GPU acceleration.")
84
 
85
  # Check for latex2markdown
86
  try:
 
202
  logger.error(f"Failed to set up GOT-OCR2.0 repository: {str(e)}")
203
  return False
204
 
 
205
  def parse(self, file_path: Union[str, Path], ocr_method: Optional[str] = None, **kwargs) -> str:
206
  """Parse a document using GOT-OCR 2.0.
207
 
 
 
 
208
  Args:
209
  file_path: Path to the image file
210
  ocr_method: OCR method to use ('plain' or 'format')
 
287
  f.write(f"cd {parent_dir}\n") # Change to parent directory
288
  f.write("export PYTHONPATH=$PYTHONPATH:$(pwd)\n") # Add current directory to PYTHONPATH
289
 
 
 
 
 
290
  # Add a Python script to patch torch.bfloat16
291
  patch_script = os.path.join(tempfile.gettempdir(), "patch_torch.py")
292
  with open(patch_script, 'w') as patch_f:
293
  patch_f.write("""
294
  import sys
295
  import torch
 
 
 
 
 
 
 
 
 
 
 
296
 
297
+ # Patch torch.bfloat16 to use torch.float16 instead
298
+ if hasattr(torch, 'bfloat16'):
299
+ # Save reference to original bfloat16
300
+ original_bfloat16 = torch.bfloat16
301
+ # Replace with float16
302
+ torch.bfloat16 = torch.float16
303
+ print("Successfully patched torch.bfloat16 to use torch.float16")
 
304
 
305
+ # Also patch torch.autocast context manager for CUDA
306
+ original_autocast = torch.autocast
307
+ def patched_autocast(*args, **kwargs):
308
+ # Force dtype to float16 when CUDA is involved
309
+ if args and args[0] == "cuda" and kwargs.get("dtype") == torch.bfloat16:
310
+ kwargs["dtype"] = torch.float16
311
+ print(f"Autocast: Changed bfloat16 to float16 for {args}")
312
+ return original_autocast(*args, **kwargs)
313
 
314
+ torch.autocast = patched_autocast
315
+ print("Successfully patched torch.autocast to ensure float16 is used instead of bfloat16")
316
  """)
317
+
318
+ # Build the command with the patch included
319
+ py_cmd = [
320
+ sys.executable,
321
+ "-c",
322
+ f"import sys; sys.path.insert(0, '{parent_dir}'); "
323
+ f"exec(open('{patch_script}').read()); "
324
+ f"import runpy; runpy.run_path('{script_path}', run_name='__main__')"
325
+ ]
326
+
327
+ # Add the arguments
328
+ py_cmd.extend(["--model-name", self._weights_path])
329
+ py_cmd.extend(["--image-file", str(file_path)])
330
+ py_cmd.extend(["--type", ocr_type])
331
+
332
+ # Add render flag if required
333
+ if render:
334
+ py_cmd.append("--render")
335
+
336
+ # Check if box or color is specified in kwargs
337
+ if 'box' in kwargs and kwargs['box']:
338
+ py_cmd.extend(["--box", str(kwargs['box'])])
339
+
340
+ if 'color' in kwargs and kwargs['color']:
341
+ py_cmd.extend(["--color", kwargs['color']])
342
+
343
+ # Add the command to the script
344
+ f.write(" ".join(py_cmd) + "\n")
 
 
 
 
 
 
 
 
 
 
345
 
346
  # Make the script executable
347
  os.chmod(temp_script, 0o755)
348
 
349
+ # Run the script with GPU access if available
350
+ result = self._run_with_gpu(temp_script)
 
 
 
 
 
 
 
 
 
351
 
352
  # If render was requested, find and return the path to the HTML file
353
  if render:
 
392
 
393
  # Generic error
394
  raise RuntimeError(f"Error processing document with GOT-OCR: {str(e)}")
395
+
396
+ # Define a method that will be decorated with spaces.GPU to ensure GPU access
397
+ def _run_with_gpu(self, script_path):
398
+ """Run a script with GPU access using the spaces.GPU decorator if available."""
399
+ if HAS_SPACES:
400
+ # Use the spaces.GPU decorator to ensure GPU access
401
+ return self._run_script_with_gpu_allocation(script_path)
402
+ else:
403
+ # Fall back to regular execution if spaces module is not available
404
+ logger.info(f"Running command through wrapper script without ZeroGPU: {script_path}")
405
+ process = subprocess.run(
406
+ [script_path],
407
+ check=True,
408
+ capture_output=True,
409
+ text=True
410
+ )
411
+ return process.stdout.strip()
412
+
413
+ # This method will be decorated with spaces.GPU
414
+ if HAS_SPACES:
415
+ @spaces.GPU(duration=180) # Allocate up to 3 minutes for OCR processing
416
+ def _run_script_with_gpu_allocation(self, script_path):
417
+ """Run a script with GPU access using the spaces.GPU decorator."""
418
+ logger.info(f"Running command through wrapper script with ZeroGPU allocation: {script_path}")
419
+ process = subprocess.run(
420
+ [script_path],
421
+ check=True,
422
+ capture_output=True,
423
+ text=True
424
+ )
425
+ return process.stdout.strip()
426
+ else:
427
+ # Define a dummy method if spaces is not available
428
+ def _run_script_with_gpu_allocation(self, script_path):
429
+ # This should never be called if HAS_SPACES is False
430
+ raise NotImplementedError("spaces module is not available")
431
 
432
  @classmethod
433
  def release_model(cls):