bikram committed on
Commit
de7714f
·
1 Parent(s): 16ea19c

Roman, Devanagari classifier added

Browse files
__pycache__/main.cpython-310.pyc ADDED
Binary file (4.33 kB). View file
 
__pycache__/utils.cpython-310.pyc ADDED
Binary file (7.7 kB). View file
 
main.py CHANGED
@@ -71,7 +71,7 @@ from pydantic import BaseModel
71
  import shutil
72
 
73
  # Import from optimized utils
74
- from utils import dev_number, roman_number, dev_letter, roman_letter
75
 
76
  app = FastAPI(
77
  title="OCR API",
@@ -150,7 +150,8 @@ async def extract_text(
150
  "dev_number": dev_number,
151
  "roman_number": roman_number,
152
  "dev_letter": dev_letter,
153
- "roman_letter": roman_letter
 
154
  }
155
 
156
  if model_type not in ocr_functions:
@@ -179,6 +180,21 @@ async def extract_roman_letter(image: UploadFile = File(...)):
179
  """Extract Roman letters from an image"""
180
  return await process_ocr_request(image, roman_letter)
181
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  # Health check endpoint
183
  @app.get("/health")
184
  async def health_check():
 
71
  import shutil
72
 
73
  # Import from optimized utils
74
+ from utils import dev_number, roman_number, dev_letter, roman_letter, predict_ne
75
 
76
  app = FastAPI(
77
  title="OCR API",
 
150
  "dev_number": dev_number,
151
  "roman_number": roman_number,
152
  "dev_letter": dev_letter,
153
+ "roman_letter": roman_letter,
154
+
155
  }
156
 
157
  if model_type not in ocr_functions:
 
180
  """Extract Roman letters from an image"""
181
  return await process_ocr_request(image, roman_letter)
182
 
183
@app.post("/predict_ne")
async def classify_ne(image: UploadFile = File(...)):
    """Classify an uploaded image with the ``predict_ne`` classifier.

    The upload is written to a temporary file with ``save_upload_file_tmp``
    and the resulting path is passed to ``predict_ne``.

    Args:
        image: The uploaded image file.

    Returns:
        JSONResponse: ``{"predicted": <int>}`` where the integer is the class
        index chosen by the classifier.
    """
    image_path = await save_upload_file_tmp(image)
    prediction = predict_ne(image_path=image_path, device="cpu")

    # predict_ne returns a one-element torch tensor, which the JSON encoder
    # cannot serialize; collapse it to a plain int before building the
    # response.
    # NOTE(review): the index -> label mapping is not visible here; confirm
    # which index stands for Nepali and which for English.
    # NOTE(review): the temporary file is not removed in this handler;
    # confirm whether cleanup happens elsewhere.
    return JSONResponse(content={"predicted": int(prediction)})
198
  # Health check endpoint
199
  @app.get("/health")
200
  async def health_check():
models/nepali_english_classifier.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:baaaad1b2696999492a6d7cad825a51319234838a7230e4a6833705613450170
3
+ size 95411738
utils.py CHANGED
@@ -164,9 +164,11 @@ from PIL import Image
164
  import numpy as np
165
  import torchvision.transforms as transforms
166
  from doctr.io import DocumentFile
 
167
  from doctr.models import recognition_predictor
168
  import os
169
  from functools import lru_cache
 
170
 
171
  # Character sets
172
  CHARACTER_NUM = "0123456789-"
@@ -176,11 +178,28 @@ CHARACTER_LETTER = ''' "()-./0123456789:?ABCDEFGHIKLMNOPQRSTUWYabcdefghijklmnopr
176
  MODEL_PATHS = {
177
  'dev_digits': "models/devnagri_digits_20k_v2.pth",
178
  'roman_digits': "models/roman_digits_20k_v5.pth",
179
- 'dev_letter': "models/small_devnagari_letter.pth"
 
180
  }
181
 
 
182
  # Use GPU if available
183
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
 
185
  # Define the CRNN model
186
  class CRNN(nn.Module):
@@ -326,6 +345,8 @@ class OCRModelManager:
326
  result = self.roman_letter_model(img)
327
  # print(result)
328
  return result[0][0]
 
 
329
 
330
 
331
  # Initialize the model manager as a singleton
@@ -346,4 +367,25 @@ def dev_letter(image_path):
346
 
347
  def roman_letter(image_path):
348
  """Recognize Roman letters in an image"""
349
- return ocr_manager.predict_roman_letter(image_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  import numpy as np
165
  import torchvision.transforms as transforms
166
  from doctr.io import DocumentFile
167
+ from torchvision import models
168
  from doctr.models import recognition_predictor
169
  import os
170
  from functools import lru_cache
171
+ import pickle
172
 
173
  # Character sets
174
  CHARACTER_NUM = "0123456789-"
 
178
  MODEL_PATHS = {
179
  'dev_digits': "models/devnagri_digits_20k_v2.pth",
180
  'roman_digits': "models/roman_digits_20k_v5.pth",
181
+ 'dev_letter': "models/small_devnagari_letter.pth",
182
+ 'classify_ne': "models/nepali_english_classifier.pth"
183
  }
184
 
185
+
186
  # Use GPU if available
187
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
188
class ResNetClassifier(nn.Module):
    """ResNet-50 backbone with a small fully connected classification head.

    All parameters of the backbone are frozen when the model is built; the
    original ``fc`` layer is then replaced by a new
    ``Linear -> ReLU -> Linear`` head, whose parameters stay trainable
    because they are created after the freeze.

    Args:
        num_classes: Number of output logits produced by the head.
        weights: Forwarded to ``torchvision.models.resnet50``. The default
            keeps the original behavior (ImageNet ``IMAGENET1K_V2`` weights).
            Pass ``None`` to skip loading pretrained weights, e.g. when a
            full checkpoint is loaded into the model right afterwards.
    """

    def __init__(self, num_classes=2, weights='IMAGENET1K_V2'):
        super().__init__()
        self.base_model = models.resnet50(weights=weights)

        # Freeze the backbone before swapping the head so that only the new
        # head's parameters keep requires_grad=True.
        for param in self.base_model.parameters():
            param.requires_grad = False

        num_ftrs = self.base_model.fc.in_features
        self.base_model.fc = nn.Sequential(
            nn.Linear(num_ftrs, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Run ``x`` through the backbone and head and return the logits."""
        return self.base_model(x)
203
 
204
  # Define the CRNN model
205
  class CRNN(nn.Module):
 
345
  result = self.roman_letter_model(img)
346
  # print(result)
347
  return result[0][0]
348
+
349
+
350
 
351
 
352
  # Initialize the model manager as a singleton
 
367
 
368
  def roman_letter(image_path):
369
  """Recognize Roman letters in an image"""
370
+ return ocr_manager.predict_roman_letter(image_path)
371
+
372
+
373
@lru_cache(maxsize=2)
def _load_ne_classifier(device_name):
    """Build the classifier on ``device_name``, load its checkpoint, cache it."""
    target = torch.device(device_name)
    model = ResNetClassifier(num_classes=2).to(target)
    state_dict = torch.load(MODEL_PATHS['classify_ne'], map_location=target)
    model.load_state_dict(state_dict)
    model.eval()
    return model


def predict_ne(image_path, device="cpu"):
    """Classify the image at ``image_path`` with the two-class ResNet model.

    Args:
        image_path: Path of the image file to classify.
        device: Device to run inference on (e.g. ``"cpu"`` or ``"cuda"``).
            If a CUDA device is requested but CUDA is unavailable, inference
            falls back to the CPU.

    Returns:
        torch.Tensor: One-element tensor holding the index of the
        highest-scoring class.
    """
    # Honor the caller's device instead of silently overriding it.
    target = torch.device(device)
    if target.type == "cuda" and not torch.cuda.is_available():
        target = torch.device("cpu")

    # The checkpoint is large; load it once per device and reuse it.
    model = _load_ne_classifier(str(target))

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Close the file handle as soon as the pixel data has been converted.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')
    image_tensor = transform(image).unsqueeze(0).to(target)

    with torch.no_grad():
        output = model(image_tensor)
    _, predicted = torch.max(output, 1)
    return predicted