# Spaces: Running
import gradio as gr | |
import torch | |
import json | |
import os | |
import cv2 | |
import numpy as np | |
import easyocr | |
import keras_ocr | |
from paddleocr import PaddleOCR | |
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification | |
import torch.nn.functional as F | |
from save_results import save_results_to_repo | |
# Paths
MODEL_PATH = "./distilbert_spam_model"

# Weight files that `save_pretrained` may have written: recent transformers
# releases default to safetensors, older ones wrote a PyTorch pickle.
_WEIGHT_FILENAMES = ("model.safetensors", "pytorch_model.bin")

# Ensure model exists. Looking only for "pytorch_model.bin" misses a
# safetensors checkpoint and would re-download the model on every start.
if not any(
    os.path.exists(os.path.join(MODEL_PATH, filename))
    for filename in _WEIGHT_FILENAMES
):
    print(f"⚠️ Model not found in {MODEL_PATH}. Downloading from Hugging Face Hub...")
    # NOTE(review): "distilbert-base-uncased" is the base checkpoint, so the
    # 2-label classification head created here is presumably untrained.
    # Confirm that a fine-tuned model is meant to be shipped at MODEL_PATH.
    model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
    model.save_pretrained(MODEL_PATH)
    tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
    tokenizer.save_pretrained(MODEL_PATH)
    print(f"✅ Model saved at {MODEL_PATH}.")
else:
    model = DistilBertForSequenceClassification.from_pretrained(MODEL_PATH)
    tokenizer = DistilBertTokenizer.from_pretrained(MODEL_PATH)

# Set model to evaluation mode
model.eval()
# OCR Methods | |
def ocr_with_paddle(img):
    """Run PaddleOCR on an image and return the recognized text.

    Args:
        img: Image passed straight to ``PaddleOCR.ocr`` (the caller in this
            file supplies a NumPy array).

    Returns:
        The recognized text fragments joined by single spaces, or an empty
        string when no text was detected.
    """
    ocr = PaddleOCR(lang='en', use_angle_cls=True)
    result = ocr.ocr(img)
    # A page without any text comes back as ``None`` (or an empty list);
    # iterating it would raise instead of reporting "nothing found".
    if not result or not result[0]:
        return ''
    # Only the text at item[1][0] is kept from each detection.
    return ' '.join(item[1][0] for item in result[0])
def ocr_with_keras(img):
    """Recognize text in ``img`` with a freshly built keras-ocr pipeline.

    Args:
        img: Image handed to ``keras_ocr.tools.read``.

    Returns:
        The recognized words of the image joined by single spaces.
    """
    recognizer = keras_ocr.pipeline.Pipeline()
    loaded_image = keras_ocr.tools.read(img)
    # ``recognize`` works on a batch; this image is its only entry.
    detections = recognizer.recognize([loaded_image])[0]
    words = []
    for word, _box in detections:
        words.append(word)
    return ' '.join(words)
def ocr_with_easy(img):
    """Run EasyOCR on an image and return the recognized text.

    Args:
        img: NumPy image array, either 2-D grayscale or colour in RGB
            channel order (the order produced by PIL / the Gradio image
            input that feeds this function in this file).

    Returns:
        The recognized text fragments joined by single spaces.
    """
    if img.ndim == 2:
        # Already single-channel; cvtColor would reject it.
        gray_image = img
    else:
        # The array is RGB, not OpenCV's native BGR, so use the RGB
        # conversion to weight the red and blue channels correctly.
        gray_image = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    reader = easyocr.Reader(['en'])
    # detail=0 makes readtext return plain strings only.
    results = reader.readtext(gray_image, detail=0)
    return ' '.join(results)
# OCR Extraction Function | |
def extract_text(method, img):
    """Extract text from an uploaded image with the selected OCR engine.

    Args:
        method: Name of the OCR engine: "PaddleOCR" or "EasyOCR"; any other
            value selects the Keras-OCR path.
        img: The uploaded image, or ``None`` when nothing was uploaded.

    Returns:
        A ``(text, label)`` pair. ``text`` is the stripped OCR output or a
        status message; ``label`` is always the empty string.
    """
    # Guard: nothing to process without an image.
    if img is None:
        return "Error: Please upload an image!", ""

    # Work on a NumPy array regardless of how the image arrived.
    pixels = np.array(img)

    # Pick the engine; anything that is not Paddle or Easy means KerasOCR.
    if method == "PaddleOCR":
        run_ocr = ocr_with_paddle
    elif method == "EasyOCR":
        run_ocr = ocr_with_easy
    else:
        run_ocr = ocr_with_keras

    cleaned = run_ocr(pixels).strip()
    if not cleaned:
        return "No text detected!", ""
    return cleaned, ""
# Classification Function | |
def classify_text(text_output):
    """Classify extracted text as spam or not spam and record the result.

    Args:
        text_output: Text to classify. May be ``None``, empty, or one of
            the status messages produced by the extraction step.

    Returns:
        A ``(text, label)`` pair. ``text`` is ``text_output`` unchanged;
        ``label`` is "Spam", "Not Spam", or "Cannot classify" when there is
        no real text to classify.
    """
    # Status messages from the extraction step are not real text.
    status_messages = ("No text detected!", "Error: Please upload an image!")
    stripped = (text_output or "").strip()
    # Missing or blank text is also unclassifiable; skipping it keeps
    # meaningless predictions out of the saved results.
    if not stripped or stripped in status_messages:
        return text_output, "Cannot classify"

    # Tokenize text
    inputs = tokenizer(text_output, return_tensors="pt", truncation=True, padding=True, max_length=512)

    # Model inference (no gradients needed)
    with torch.no_grad():
        outputs = model(**inputs)
        probs = F.softmax(outputs.logits, dim=1)
        prediction = torch.argmax(probs, dim=1).item()

    label_map = {0: "Not Spam", 1: "Spam"}
    label = label_map[prediction]

    # Save results automatically
    save_results_to_repo(text_output, label)
    return text_output, label
# Gradio Interface: pick an OCR engine, upload an image, extract its text,
# then classify the (editable) extracted text.
with gr.Blocks() as demo:
    gr.Markdown("## OCR Spam Classifier")

    # Inputs
    method_input = gr.Radio(
        ["PaddleOCR", "EasyOCR", "KerasOCR"],
        value="PaddleOCR",
        label="Choose OCR Method",
    )
    image_input = gr.Image(label="Upload Image")

    # Actions
    extract_button = gr.Button("Submit")
    classify_button = gr.Button("Classify")

    # Outputs; the extracted text stays editable so it can be corrected
    # before classification.
    output_text = gr.Textbox(label="Extracted Text", interactive=True)
    output_label = gr.Textbox(label="Spam Classification", interactive=False)

    # Button Click Bindings
    extract_button.click(
        fn=extract_text,
        inputs=[method_input, image_input],
        outputs=[output_text, output_label],
    )
    classify_button.click(
        fn=classify_text,
        inputs=[output_text],
        outputs=[output_text, output_label],
    )

# Launch the app only when run as a script
if __name__ == "__main__":
    demo.launch()