AnseMin commited on
Commit
66d2b1b
·
1 Parent(s): 9b25e42

adding zerogpufor got ocr

Browse files
Files changed (4) hide show
  1. build.sh +0 -34
  2. requirements.txt +4 -1
  3. setup.sh +37 -2
  4. src/parsers/got_ocr_parser.py +44 -5
build.sh DELETED
@@ -1,34 +0,0 @@
1
- #!/bin/bash
2
-
3
- # Exit on error
4
- set -e
5
-
6
- echo "Starting build process..."
7
-
8
- # Install system dependencies
9
- echo "Installing system dependencies..."
10
- apt-get update && apt-get install -y \
11
- wget \
12
- pkg-config
13
-
14
- # Install Google Gemini API client
15
- echo "Installing Google Gemini API client..."
16
- pip install -q -U google-genai
17
- echo "Google Gemini API client installed successfully"
18
-
19
- # Install GOT-OCR dependencies
20
- echo "Installing GOT-OCR dependencies..."
21
- pip install -q -U torch==2.0.1 torchvision==0.15.2 transformers==4.37.2 tiktoken==0.6.0 verovio==4.3.1 accelerate==0.28.0 safetensors==0.4.3
22
- echo "GOT-OCR dependencies installed successfully"
23
-
24
- # Install Python dependencies
25
- echo "Installing Python dependencies..."
26
- pip install -e .
27
-
28
- # Create .env file if it doesn't exist
29
- if [ ! -f .env ]; then
30
- echo "Creating .env file..."
31
- cp .env.example .env || echo "Warning: .env.example not found"
32
- fi
33
-
34
- echo "Build process completed successfully!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -21,4 +21,7 @@ transformers==4.37.2 # Pin to a specific version that works with safetensors 0.
21
  tiktoken==0.6.0
22
  verovio==4.3.1
23
  accelerate==0.28.0
24
- safetensors==0.4.3 # Updated to meet minimum version required by accelerate
 
 
 
 
21
  tiktoken==0.6.0
22
  verovio==4.3.1
23
  accelerate==0.28.0
24
+ safetensors==0.4.3 # Updated to meet minimum version required by accelerate
25
+
26
+ # ZeroGPU support for HuggingFace Spaces
27
+ spaces>=0.19.1
setup.sh CHANGED
@@ -3,7 +3,20 @@
3
  # Exit on error
4
  set -e
5
 
6
- echo "Setting up environment..."
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  # Install Python dependencies
9
  echo "Installing Python dependencies..."
@@ -16,4 +29,26 @@ echo "Installing GOT-OCR dependencies..."
16
  pip install -q -U torch==2.0.1 torchvision==0.15.2 transformers==4.37.2 tiktoken==0.6.0 verovio==4.3.1 accelerate==0.28.0 safetensors==0.4.3
17
  echo "GOT-OCR dependencies installed successfully"
18
 
19
- echo "Setup completed"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  # Exit on error
4
  set -e
5
 
6
+ echo "Starting setup process..."
7
+
8
+ # Check if running with sudo/root permissions for system dependencies
9
+ if [ "$EUID" -eq 0 ]; then
10
+ # Install system dependencies
11
+ echo "Installing system dependencies..."
12
+ apt-get update && apt-get install -y \
13
+ wget \
14
+ pkg-config
15
+ echo "System dependencies installed successfully"
16
+ else
17
+ echo "Not running as root. Skipping system dependencies installation."
18
+ echo "If system dependencies are needed, please run this script with sudo."
19
+ fi
20
 
21
  # Install Python dependencies
22
  echo "Installing Python dependencies..."
 
29
  pip install -q -U torch==2.0.1 torchvision==0.15.2 transformers==4.37.2 tiktoken==0.6.0 verovio==4.3.1 accelerate==0.28.0 safetensors==0.4.3
30
  echo "GOT-OCR dependencies installed successfully"
31
 
32
+ # Install ZeroGPU support
33
+ echo "Installing ZeroGPU support..."
34
+ pip install -q -U spaces>=0.19.1
35
+ echo "ZeroGPU support installed successfully"
36
+
37
+ # Install the project in development mode
38
+ echo "Installing project in development mode..."
39
+ pip install -e .
40
+ echo "Project installed successfully"
41
+
42
+ # Create .env file if it doesn't exist
43
+ if [ ! -f .env ]; then
44
+ echo "Creating .env file..."
45
+ if [ -f .env.example ]; then
46
+ cp .env.example .env
47
+ echo ".env file created from .env.example"
48
+ else
49
+ echo "Warning: .env.example not found. Creating empty .env file."
50
+ touch .env
51
+ fi
52
+ fi
53
+
54
+ echo "Setup process completed successfully!"
src/parsers/got_ocr_parser.py CHANGED
@@ -25,9 +25,19 @@ try:
25
  "Consider downgrading to version <4.48.0"
26
  )
27
 
 
 
 
 
 
 
 
 
 
28
  GOT_AVAILABLE = True
29
  except ImportError:
30
  GOT_AVAILABLE = False
 
31
  logger.warning("GOT-OCR dependencies not installed. The parser will not be available.")
32
 
33
  class GotOcrParser(DocumentParser):
@@ -65,15 +75,35 @@ class GotOcrParser(DocumentParser):
65
  'stepfun-ai/GOT-OCR2_0',
66
  trust_remote_code=True
67
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  cls._model = AutoModel.from_pretrained(
69
  'stepfun-ai/GOT-OCR2_0',
70
  trust_remote_code=True,
71
  low_cpu_mem_usage=True,
72
- device_map='cuda',
73
  use_safetensors=True,
74
  pad_token_id=cls._tokenizer.eos_token_id
75
  )
76
- cls._model = cls._model.eval().cuda()
 
 
 
 
 
 
77
  logger.info("GOT-OCR model loaded successfully")
78
  except Exception as e:
79
  cls._model = None
@@ -92,6 +122,15 @@ class GotOcrParser(DocumentParser):
92
  cls._tokenizer = None
93
  if torch.cuda.is_available():
94
  torch.cuda.empty_cache()
 
 
 
 
 
 
 
 
 
95
  logger.info("GOT-OCR model released from memory")
96
 
97
  def parse(self, file_path: Union[str, Path], ocr_method: Optional[str] = None, **kwargs) -> str:
@@ -102,9 +141,9 @@ class GotOcrParser(DocumentParser):
102
  "torch, transformers, tiktoken, verovio, accelerate"
103
  )
104
 
105
- # Check if CUDA is available
106
- if not torch.cuda.is_available():
107
- raise RuntimeError("GOT-OCR requires CUDA. CPU-only mode is not supported.")
108
 
109
  # Check file extension
110
  file_path = Path(file_path)
 
25
  "Consider downgrading to version <4.48.0"
26
  )
27
 
28
+ # Import spaces for ZeroGPU support
29
+ try:
30
+ import spaces
31
+ ZEROGPU_AVAILABLE = True
32
+ logger.info("ZeroGPU support is available")
33
+ except ImportError:
34
+ ZEROGPU_AVAILABLE = False
35
+ logger.info("ZeroGPU not available, will use standard GPU if available")
36
+
37
  GOT_AVAILABLE = True
38
  except ImportError:
39
  GOT_AVAILABLE = False
40
+ ZEROGPU_AVAILABLE = False
41
  logger.warning("GOT-OCR dependencies not installed. The parser will not be available.")
42
 
43
  class GotOcrParser(DocumentParser):
 
75
  'stepfun-ai/GOT-OCR2_0',
76
  trust_remote_code=True
77
  )
78
+
79
+ # Determine device mapping based on ZeroGPU availability
80
+ if ZEROGPU_AVAILABLE:
81
+ logger.info("Using ZeroGPU for model loading")
82
+ # Request GPU resources through ZeroGPU
83
+ spaces.enable_gpu()
84
+ device_map = 'cuda'
85
+ elif torch.cuda.is_available():
86
+ logger.info("Using local CUDA device for model loading")
87
+ device_map = 'cuda'
88
+ else:
89
+ logger.warning("No GPU available, falling back to CPU (not recommended)")
90
+ device_map = 'auto'
91
+
92
  cls._model = AutoModel.from_pretrained(
93
  'stepfun-ai/GOT-OCR2_0',
94
  trust_remote_code=True,
95
  low_cpu_mem_usage=True,
96
+ device_map=device_map,
97
  use_safetensors=True,
98
  pad_token_id=cls._tokenizer.eos_token_id
99
  )
100
+
101
+ # Set model to evaluation mode
102
+ if device_map == 'cuda':
103
+ cls._model = cls._model.eval().cuda()
104
+ else:
105
+ cls._model = cls._model.eval()
106
+
107
  logger.info("GOT-OCR model loaded successfully")
108
  except Exception as e:
109
  cls._model = None
 
122
  cls._tokenizer = None
123
  if torch.cuda.is_available():
124
  torch.cuda.empty_cache()
125
+
126
+ # Release ZeroGPU resources if available
127
+ if ZEROGPU_AVAILABLE:
128
+ try:
129
+ spaces.disable_gpu()
130
+ logger.info("ZeroGPU resources released")
131
+ except Exception as e:
132
+ logger.warning(f"Error releasing ZeroGPU resources: {str(e)}")
133
+
134
  logger.info("GOT-OCR model released from memory")
135
 
136
  def parse(self, file_path: Union[str, Path], ocr_method: Optional[str] = None, **kwargs) -> str:
 
141
  "torch, transformers, tiktoken, verovio, accelerate"
142
  )
143
 
144
+ # Check if CUDA is available (either directly or through ZeroGPU)
145
+ if not torch.cuda.is_available() and not ZEROGPU_AVAILABLE:
146
+ logger.warning("No GPU available. GOT-OCR performance may be severely degraded.")
147
 
148
  # Check file extension
149
  file_path = Path(file_path)