|
import os

import cv2
import numpy as np
from PIL import Image
import torch
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

# Load the TrOCR processor and model once at import time so repeated calls reuse them.
try:
    processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
    model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
    model.eval()
    print("TrOCR model loaded.")
    trocr_available = True
except Exception as e:
    print(f"Failed to load TrOCR: {e}")
    trocr_available = False
|
|
|
def preprocess_image(image):
    """
    Preprocess a BGR image for OCR: convert to grayscale, denoise, and
    binarize with adaptive thresholding to enhance contrast.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    denoised = cv2.fastNlMeansDenoising(gray, h=10)
    processed = cv2.adaptiveThreshold(
        denoised, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY, 11, 2
    )
    return processed
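# A minimal usage sketch (added for illustration, not part of the original script).
# preprocess_image expects a BGR array such as the one returned by cv2.imread;
# the file name below is a placeholder assumption:
#
#     img = cv2.imread("handwritten_note.jpg")
#     if img is not None:
#         binary = preprocess_image(img)
#         cv2.imwrite("handwritten_note_binarized.png", binary)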
|
|
|
def extract_text_from_image(image_path):
    """
    Extract handwritten text from an image using TrOCR.
    """
    # Bail out early if the model could not be loaded at import time.
    if not trocr_available:
        return "Text extraction failed."

    try:
        print(f"Reading image from: {image_path}")

        image = Image.open(image_path).convert("RGB")
        pixel_values = processor(images=image, return_tensors="pt").pixel_values

        # Run TrOCR generation and decode the token ids back to text.
        with torch.no_grad():
            generated_ids = model.generate(pixel_values)
        text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()

        print(f"Extracted text from {os.path.basename(image_path)}: {text}")

        return text or "Text extraction failed."

    except Exception as e:
        print(f"OCR failed on {image_path}: {e}")
        return "Text extraction failed."
|
|
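# Minimal entry-point sketch (an addition, not part of the original script).
# The sample path is a placeholder assumption; the guard skips execution when
# TrOCR failed to load or the file is absent.
if __name__ == "__main__":
    sample_path = "samples/handwritten_note.png"
    if trocr_available and os.path.exists(sample_path):
        extract_text_from_image(sample_path)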