ocr-llm-test / app.py
winamnd's picture
Update app.py
dad8a00 verified
raw
history blame
3.71 kB
import functools
import json
import os

import gradio as gr
import torch
import cv2
import numpy as np
import easyocr
import keras_ocr
from paddleocr import PaddleOCR
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import torch.nn.functional as F

from save_results import save_results_to_repo
# Paths
MODEL_PATH = "./distilbert_spam_model"

# Weight file names transformers may write via save_pretrained: older releases
# default to pytorch_model.bin, newer ones to model.safetensors; large models
# are sharded and get an *.index.json file instead.
_WEIGHT_FILES = (
    "model.safetensors",
    "model.safetensors.index.json",
    "pytorch_model.bin",
    "pytorch_model.bin.index.json",
)


def _model_weights_present(path):
    """Return True if *path* contains any recognised saved-weights file."""
    return any(os.path.exists(os.path.join(path, name)) for name in _WEIGHT_FILES)


# Ensure model exists. Checking only for pytorch_model.bin would miss a model
# saved as safetensors, re-download the base checkpoint on every start, and
# overwrite a fine-tuned model stored in MODEL_PATH.
if not _model_weights_present(MODEL_PATH):
    print(f"⚠️ Model not found in {MODEL_PATH}. Downloading from Hugging Face Hub...")
    # NOTE: distilbert-base-uncased has no classification weights, so this
    # 2-label head is randomly initialised; predictions are not meaningful
    # until a fine-tuned model is placed in MODEL_PATH.
    model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
    model.save_pretrained(MODEL_PATH)
    tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
    tokenizer.save_pretrained(MODEL_PATH)
    print(f"✅ Model saved at {MODEL_PATH}.")
else:
    model = DistilBertForSequenceClassification.from_pretrained(MODEL_PATH)
    tokenizer = DistilBertTokenizer.from_pretrained(MODEL_PATH)

# Set model to evaluation mode (disables dropout for deterministic inference).
model.eval()
# OCR Methods
@functools.lru_cache(maxsize=1)
def _get_paddle_engine():
    """Build the PaddleOCR engine once and reuse it for every request.

    Constructing PaddleOCR loads its detection/recognition models, so doing it
    per call makes every extraction pay the full model-loading cost.
    """
    return PaddleOCR(lang='en', use_angle_cls=True)


def ocr_with_paddle(img):
    """Run PaddleOCR on *img* and return the recognised text.

    Args:
        img: Image array passed straight to ``PaddleOCR.ocr``.

    Returns:
        The recognised strings joined by single spaces, or ``""`` when no
        text region was detected.
    """
    result = _get_paddle_engine().ocr(img)
    # One entry per input image; that entry may be None (or empty) when
    # nothing is detected, which previously raised TypeError on iteration.
    if not result or not result[0]:
        return ""
    # Each detected line is [box, (text, confidence)]; keep only the text.
    return ' '.join(line[1][0] for line in result[0])
@functools.lru_cache(maxsize=1)
def _get_keras_pipeline():
    """Build the keras-ocr detector+recognizer pipeline once and reuse it."""
    return keras_ocr.pipeline.Pipeline()


def ocr_with_keras(img):
    """Run keras-ocr on *img* and return the recognised words.

    Args:
        img: Image accepted by ``keras_ocr.tools.read`` (here an ndarray).

    Returns:
        The recognised words joined by single spaces (``""`` if none).
    """
    images = [keras_ocr.tools.read(img)]
    predictions = _get_keras_pipeline().recognize(images)
    # predictions[0] holds (word, box) pairs for the single input image.
    return ' '.join(text for text, _ in predictions[0])
@functools.lru_cache(maxsize=1)
def _get_easyocr_reader():
    """Build the English EasyOCR reader once and reuse it for every request."""
    return easyocr.Reader(['en'])


def ocr_with_easy(img):
    """Run EasyOCR on a grayscale version of *img* and return the text.

    Args:
        img: Image array as produced by ``np.array`` on the upload. For a PIL
            image this is RGB (not OpenCV's BGR), optionally RGBA, or 2-D
            when already grayscale.

    Returns:
        The recognised strings joined by single spaces (``""`` if none).
    """
    # np.array(PIL image) keeps RGB channel order, so convert from RGB(A);
    # using the BGR code would swap the red/blue luminance weights.
    if img.ndim == 2:
        gray_image = img
    elif img.shape[2] == 4:
        gray_image = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
    else:
        gray_image = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    # detail=0 makes readtext return plain strings instead of (box, text, conf).
    results = _get_easyocr_reader().readtext(gray_image, detail=0)
    return ' '.join(results)
# OCR Extraction Function
def extract_text(method, img):
    """Run the selected OCR engine on the uploaded image.

    Args:
        method: OCR engine name from the radio button: "PaddleOCR",
            "EasyOCR" or "KerasOCR". Any unrecognised value uses KerasOCR.
        img: The uploaded image, or None when nothing was uploaded.

    Returns:
        A ``(text, label)`` pair for the two output textboxes. ``label`` is
        always ``""`` so any previous classification is cleared.
    """
    if img is None:
        return "Error: Please upload an image!", ""

    engines = {
        "PaddleOCR": ocr_with_paddle,
        "EasyOCR": ocr_with_easy,
    }
    run_ocr = engines.get(method, ocr_with_keras)

    # np.array turns the upload into an ndarray; channel order is not changed
    # (a PIL image stays RGB).
    extracted = run_ocr(np.array(img)).strip()
    if not extracted:
        return "No text detected!", ""
    return extracted, ""
# Classification Function
def classify_text(text_output):
    """Classify the extracted text as spam / not spam and save the result.

    Args:
        text_output: Contents of the "Extracted Text" box. The user may have
            edited it, and it may be empty.

    Returns:
        ``(text_output, label)``. ``label`` is "Spam" or "Not Spam", or
        "Cannot classify" when there is no real text to classify.
    """
    cleaned = (text_output or "").strip()
    # Skip inference for empty input and for the status messages that
    # extract_text writes into the textbox. Classifying those is meaningless,
    # and the result would also be persisted by save_results_to_repo.
    if not cleaned or cleaned in ("No text detected!", "Error: Please upload an image!"):
        return text_output, "Cannot classify"

    # Tokenize text (DistilBERT accepts at most 512 tokens).
    inputs = tokenizer(text_output, return_tensors="pt", truncation=True, padding=True, max_length=512)

    # Model inference
    with torch.no_grad():
        logits = model(**inputs).logits
    # softmax is monotonic, so argmax over raw logits picks the same class.
    prediction = torch.argmax(logits, dim=1).item()

    label_map = {0: "Not Spam", 1: "Spam"}
    label = label_map[prediction]

    # Save results automatically
    save_results_to_repo(text_output, label)
    return text_output, label
# Gradio Interface
# Components render in the order they are created inside this Blocks context,
# so the statement order below defines the page layout.
with gr.Blocks() as demo:
    gr.Markdown("## OCR Spam Classifier")
    # The selected string is passed to extract_text as `method`.
    method_input = gr.Radio(["PaddleOCR", "EasyOCR", "KerasOCR"], value="PaddleOCR", label="Choose OCR Method")
    image_input = gr.Image(label="Upload Image")
    # "Submit" runs OCR only; "Classify" runs the spam model on the textbox.
    extract_button = gr.Button("Submit")
    classify_button = gr.Button("Classify")
    # interactive=True lets the user edit the OCR text before pressing
    # Classify; classify_text reads whatever is in this box.
    output_text = gr.Textbox(label="Extracted Text", interactive=True)
    output_label = gr.Textbox(label="Spam Classification", interactive=False)
    # Button Click Bindings
    # extract_text returns "" as its second value, clearing any stale label.
    extract_button.click(fn=extract_text, inputs=[method_input, image_input], outputs=[output_text, output_label])
    classify_button.click(fn=classify_text, inputs=[output_text], outputs=[output_text, output_label])
# Launch App
def main():
    """Start the Gradio server for the `demo` Blocks app."""
    demo.launch()


if __name__ == "__main__":
    main()