import gradio as gr
import onnxruntime as ort
from transformers import RobertaTokenizer, ViTImageProcessor
from PIL import Image
import numpy as np
import torch
import os
import time
import logging
# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Preprocessors for the two modalities (should match the ones used during training)
vit_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

model_path = "./multimodal_model.onnx"
try:
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"ONNX model not found at {model_path}")

    logger.info(f"Loading ONNX model from {model_path}")
    sess_options = ort.SessionOptions()
    sess_options.log_severity_level = 0  # 0 = verbose ONNX Runtime logging
    ort_session = ort.InferenceSession(
        model_path,
        sess_options=sess_options,
        providers=['CPUExecutionProvider']
    )
    logger.info("ONNX model loaded successfully")

    # Inspect the graph so inputs can be matched by name in predict_news below
    input_names = [inp.name for inp in ort_session.get_inputs()]
    input_shapes = {inp.name: inp.shape for inp in ort_session.get_inputs()}
    output_names = [out.name for out in ort_session.get_outputs()]
    logger.info(f"Model inputs: {input_names} with shapes {input_shapes}")
    logger.info(f"Model outputs: {output_names}")
except Exception as e:
    logger.error(f"Error loading ONNX model: {e}")
    raise
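# For reference, a graph with these input names is typically produced by an export along the
# following lines. This is a sketch only: the actual export script is not part of this Space,
# and the tensor names are assumptions that mirror the name lookup in predict_news below.
#
#   torch.onnx.export(
#       fusion_model,                               # hypothetical PyTorch multimodal model
#       (input_ids, attention_mask, pixel_values),  # example tensors with the shapes used below
#       "multimodal_model.onnx",
#       input_names=["text_input_ids", "text_attention_mask", "image_input"],
#       output_names=["logits"],
#       dynamic_axes={"text_input_ids": {0: "batch"},
#                     "text_attention_mask": {0: "batch"},
#                     "image_input": {0: "batch"}},
#   )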
# Class labels, in the order the model's logits are indexed below
labels = ["Real", "Real Text with fake image", "Fake"]


def softmax(x):
    """Compute softmax values for each set of scores in x."""
    e_x = np.exp(x - np.max(x, axis=1, keepdims=True))
    return e_x / e_x.sum(axis=1, keepdims=True)
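# Quick sanity check of the helper above (illustrative values, not model outputs):
#   softmax(np.array([[2.0, 1.0, 0.1]])) -> approximately [[0.659, 0.242, 0.099]]
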
def image_with_prediction(img, label, confidence):
    """Return the original image with an overlay showing the prediction."""
    from PIL import ImageDraw, ImageFont

    img_copy = img.copy()
    draw = ImageDraw.Draw(img_copy)
    width, height = img_copy.size

    # Semi-transparent banner along the bottom edge
    overlay = Image.new('RGBA', (width, 40), (0, 0, 0, 150))
    img_copy.paste(overlay, (0, height - 40), overlay)

    text = f"{label}: {confidence:.1%}"
    try:
        font = ImageFont.truetype("arial.ttf", 20)
    except IOError:
        font = ImageFont.load_default()
    try:
        text_width = draw.textlength(text, font=font)
    except AttributeError:  # older Pillow versions without textlength
        text_width = font.getsize(text)[0] if hasattr(font, 'getsize') else 200

    text_position = ((width - text_width) // 2, height - 35)
    draw.text(text_position, text, fill=(255, 255, 255), font=font)
    return img_copy

def predict_news(text, image):
    if text is None or text.strip() == "":
        return {labels[0]: 0.0, labels[1]: 0.0, labels[2]: 0.0}, None, "Please enter some text to analyze."
    if image is None:
        return {labels[0]: 0.0, labels[1]: 0.0, labels[2]: 0.0}, None, "Please upload an image to analyze."

    try:
        logger.info(f"Processing text: {text[:50]}...")
        logger.info(f"Processing image size: {image.size}")

        # Process text input
        inputs = tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            return_tensors='np',
            max_length=80,
            truncation=True,
            padding='max_length'
        )
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']
        logger.info(f"Input IDs shape: {input_ids.shape}")
        logger.info(f"Attention mask shape: {attention_mask.shape}")

        # Process image input
        image_processed = vit_processor(images=image, return_tensors="np")["pixel_values"]
        logger.info(f"Processed image shape: {image_processed.shape}")
        # Map the preprocessed arrays onto the graph's input names (names vary between exports)
        ort_inputs = {}
        for input_meta in ort_session.get_inputs():
            input_name = input_meta.name
            if 'ids' in input_name.lower() or input_name == 'text_input_ids':
                ort_inputs[input_name] = input_ids
            elif 'mask' in input_name.lower() or input_name == 'text_attention_mask':
                ort_inputs[input_name] = attention_mask
            elif 'image' in input_name.lower() or input_name == 'image_input':
                ort_inputs[input_name] = image_processed
            else:
                logger.warning(f"Unrecognized model input '{input_name}'; it will not be fed")
        logger.info(f"ONNX input keys: {list(ort_inputs.keys())}")
        # Run inference
        start_time = time.time()
        logger.info("Starting inference")
        outputs = ort_session.run(None, ort_inputs)
        inference_time = time.time() - start_time
        logger.info(f"Inference completed in {inference_time:.3f}s")

        # Process model outputs
        logits = outputs[0]
        logger.info(f"Raw output shape: {logits.shape}, values: {logits}")
        probs = softmax(logits)[0]
        logger.info(f"Probabilities: {probs}")
        pred_idx = int(np.argmax(probs))
        confidence = float(probs[pred_idx])
        # Pick a display color (hex, so the "15" alpha suffix in the HTML below stays valid CSS)
        if pred_idx == 1:
            color = "#e67e22"  # orange
            message = f"This content appears to be **REAL TEXT WITH FAKE IMAGE** with {confidence:.1%} confidence."
        elif pred_idx == 2:
            color = "#e74c3c"  # red
            message = f"This content appears to be **FAKE** with {confidence:.1%} confidence."
        else:
            color = "#27ae60"  # green
            message = f"This content appears to be **REAL** with {confidence:.1%} confidence."

        analysis = f"""
        <div style='text-align: center; padding: 10px; background-color: {color}15; border-radius: 5px; margin-top: 10px;'>
            <span style='font-size: 18px; color: {color}; font-weight: bold;'>{message}</span>
            <p>Inference time: {inference_time:.3f} seconds</p>
        </div>
        """
        result = {
            labels[0]: float(probs[0]),
            labels[1]: float(probs[1]),
            labels[2]: float(probs[2])
        }
        interpretation = image_with_prediction(image, labels[pred_idx], confidence)
        return result, interpretation, analysis

    except Exception as e:
        logger.error(f"Error during analysis: {str(e)}", exc_info=True)
        return {labels[0]: 0.0, labels[1]: 0.0, labels[2]: 0.0}, None, f"Error during analysis: {str(e)}"
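# Local smoke test for the function above (a sketch; assumes a sample.jpg next to this file):
#   probs, annotated, html = predict_news("Example headline", Image.open("sample.jpg"))
#   print(probs)
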
examples = [
    ["COVID-19 vaccine causes severe side effects in 80% of recipients", "https://images.unsplash.com/photo-1605289982774-9a6fef564df8?q=80&w=1000&auto=format&fit=crop"],
    ["Scientists discover new species of deep-sea fish", "https://images.unsplash.com/photo-1524704796725-9fc3044a58b2?q=80&w=1000&auto=format&fit=crop"],
]
# Build Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 📰 Fake News Detector (RoBERTa + ViT)

        This multimodal AI system analyzes both text and images to detect potentially fake news content.
        Upload an image and enter a news headline to see if the combination is likely to be real or fake news.
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            text_input = gr.Textbox(
                label="News Headline / Text",
                placeholder="Enter the news headline or text here...",
                lines=3
            )
            image_input = gr.Image(type="pil", label="Associated Image")
            analyze_btn = gr.Button("Analyze Content", variant="primary")
        with gr.Column(scale=1):
            label_output = gr.Label(label="Prediction Probabilities")
            image_output = gr.Image(type="pil", label="Visual Analysis")
            analysis_html = gr.HTML(label="Analysis")
    gr.Examples(
        examples=examples,
        inputs=[text_input, image_input],
        outputs=[label_output, image_output, analysis_html],
        fn=predict_news,
        cache_examples=True,
    )
    gr.Markdown(
        """
        This system combines:

        - **RoBERTa**: Analyzes the textual content
        - **ViT**: Processes the image data
        - **Multimodal Fusion**: Combines both signals to make a prediction

        The model was trained on the Fakeddit dataset containing real and fake news pairs with both text and images.
        """
    )
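    # The fusion head itself lives inside the exported ONNX graph. A plausible shape for it,
    # assuming the usual recipe for this kind of model (a sketch, not the code that produced
    # multimodal_model.onnx):
    #
    #   text_feat  = roberta(input_ids, attention_mask).pooler_output      # (batch, 768)
    #   image_feat = vit(pixel_values).last_hidden_state[:, 0]             # ViT [CLS], (batch, 768)
    #   logits     = classifier(torch.cat([text_feat, image_feat], dim=1)) # (batch, 3)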
    analyze_btn.click(
        predict_news,
        inputs=[text_input, image_input],
        outputs=[label_output, image_output, analysis_html]
    )
if __name__ == "__main__":
    logger.info("Starting Gradio application")
    demo.launch()