"""Gradio demo for a multimodal (RoBERTa + ViT) fake-news classifier.

A headline is tokenized with the ``roberta-base`` tokenizer and an image is
preprocessed with the ``google/vit-base-patch16-224-in21k`` image processor.
Both are fed to an ONNX graph loaded from ``./multimodal_model.onnx``, whose
first output is treated as per-class scores over ``labels``.
"""
import html
import logging
import os
import time

import gradio as gr
import numpy as np
import onnxruntime as ort
import torch  # noqa: F401  NOTE(review): not used in this file; kept in case the environment relies on it
from PIL import Image, ImageDraw, ImageFont
from transformers import RobertaTokenizer, ViTImageProcessor

# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Preprocessors; presumably the same ones used when the ONNX model was trained/exported — confirm.
vit_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

model_path = "./multimodal_model.onnx"

# Every text input is padded/truncated to exactly this many tokens.
MAX_TEXT_LENGTH = 80
# Height in pixels of the translucent banner drawn by image_with_prediction.
BANNER_HEIGHT = 40

try:
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"ONNX model not found at {model_path}")

    logger.info("Loading ONNX model from %s", model_path)
    sess_options = ort.SessionOptions()
    # 0 is onnxruntime's most verbose severity level; raise it to quiet ORT's own logs.
    sess_options.log_severity_level = 0
    ort_session = ort.InferenceSession(
        model_path,
        sess_options=sess_options,
        providers=["CPUExecutionProvider"],
    )
    logger.info("ONNX model loaded successfully")

    input_names = [meta.name for meta in ort_session.get_inputs()]
    input_shapes = {meta.name: meta.shape for meta in ort_session.get_inputs()}
    output_names = [meta.name for meta in ort_session.get_outputs()]
    logger.info("Model inputs: %s with shapes %s", input_names, input_shapes)
    logger.info("Model outputs: %s", output_names)
except Exception as e:
    logger.error("Error loading ONNX model: %s", e)
    raise

# Class names, assumed to follow the order of the model's output scores — confirm against training.
labels = ["Real", "Real Text with fake image", "Fake"]

# Per-class (color, verb, headline) for the analysis panel; keys index into ``labels``.
_VERDICTS = {
    0: ("green", "be", "REAL"),
    1: ("orange", "be", "REAL TEXT WITH FAKE IMAGE"),
    2: ("red", "contain", "FAKE"),
}


def softmax(x, axis=1):
    """Compute a numerically stable softmax of ``x`` along ``axis``.

    The maximum along ``axis`` is subtracted before exponentiating so large
    scores cannot overflow; this does not change the result.

    Args:
        x: Array-like of scores. With the default ``axis=1`` it must be at
            least 2-D, e.g. ``(batch, num_classes)``.
        axis: Axis to normalize over (default 1, i.e. per row).

    Returns:
        ``np.ndarray`` shaped like ``x`` whose slices along ``axis`` sum to 1.
    """
    x = np.asarray(x)
    exps = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return exps / exps.sum(axis=axis, keepdims=True)


def image_with_prediction(img, label, confidence):
    """Return a copy of ``img`` with a banner showing the predicted label.

    A translucent black strip is composited over the bottom of the image and
    ``"<label>: <confidence as a percentage>"`` is drawn centred on it in
    white. The input image is not modified.

    Args:
        img: PIL image to annotate.
        label: Class name to display.
        confidence: Probability in [0, 1], rendered with ``{:.1%}``.

    Returns:
        A new RGB ``PIL.Image.Image``.
    """
    # convert() always returns a new image; forcing RGB keeps the RGBA paste and
    # the (r, g, b) text fill valid for any input mode (e.g. "L" or "P").
    annotated = img.convert("RGB")
    width, height = annotated.size

    # Composite a translucent black strip, using its own alpha band as the mask.
    banner_height = min(BANNER_HEIGHT, height)
    banner = Image.new("RGBA", (width, banner_height), (0, 0, 0, 150))
    annotated.paste(banner, (0, height - banner_height), banner)

    draw = ImageDraw.Draw(annotated)
    text = f"{label}: {confidence:.1%}"
    try:
        font = ImageFont.truetype("arial.ttf", 20)
    except OSError:
        # arial.ttf could not be loaded; use Pillow's built-in font instead.
        font = ImageFont.load_default()

    try:
        text_width = draw.textlength(text, font=font)
    except AttributeError:
        # Older Pillow without ImageDraw.textlength.
        text_width = font.getsize(text)[0] if hasattr(font, "getsize") else 200

    # Centre horizontally (clamped to the left edge for wide text), 5 px below the banner top.
    text_x = max(0, (width - text_width) // 2)
    draw.text((text_x, height - banner_height + 5), text, fill=(255, 255, 255), font=font)
    return annotated


def _empty_result():
    """Return an all-zero probability dict, shown when no prediction is made."""
    return {name: 0.0 for name in labels}


def _build_onnx_inputs(input_ids, attention_mask, pixel_values):
    """Map each ONNX graph input to the matching preprocessed array by name.

    Names are matched case-insensitively by substring, checked in order:
    ``"ids"`` -> token ids, ``"mask"`` -> attention mask,
    ``"image"`` or ``"pixel"`` -> image pixel values.

    Raises:
        ValueError: if a graph input matches none of the patterns. This gives a
            precise message instead of onnxruntime's generic missing-input error.
    """
    feeds = {}
    for meta in ort_session.get_inputs():
        lowered = meta.name.lower()
        if "ids" in lowered:
            feeds[meta.name] = input_ids
        elif "mask" in lowered:
            feeds[meta.name] = attention_mask
        elif "image" in lowered or "pixel" in lowered:
            feeds[meta.name] = pixel_values
        else:
            raise ValueError(f"Unrecognized ONNX model input: {meta.name!r}")
    return feeds


def _format_analysis(pred_idx, confidence, inference_time):
    """Build the HTML verdict shown in the ``gr.HTML`` analysis panel."""
    color, verb, verdict = _VERDICTS.get(pred_idx, _VERDICTS[0])
    return (
        f"<p>This content appears to {verb} "
        f'<strong style="color: {color};">{verdict}</strong> '
        f"with {confidence:.1%} confidence.</p>"
        f"<p>Inference time: {inference_time:.3f} seconds</p>"
    )


def predict_news(text, image):
    """Classify a (headline, image) pair with the multimodal ONNX model.

    Args:
        text: News headline or body text from the textbox.
        image: PIL image from the ``gr.Image(type="pil")`` input.

    Returns:
        A 3-tuple matching the outputs ``(label_output, image_output, analysis_html)``:
        a dict mapping each entry of ``labels`` to its probability, the
        annotated image (``None`` when no prediction was made), and an
        HTML/text message describing the verdict or what went wrong.

    Any exception raised after input validation is logged and reported through
    the third return value instead of propagating.
    """
    if not text or not text.strip():
        return _empty_result(), None, "Please enter some text to analyze."
    if image is None:
        return _empty_result(), None, "Please upload an image to analyze."

    try:
        logger.info("Processing text: %s...", text[:50])
        logger.info("Processing image size: %s", image.size)

        # Text: fixed-length encoding, padded/truncated to MAX_TEXT_LENGTH tokens.
        inputs = tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            return_tensors="np",
            max_length=MAX_TEXT_LENGTH,
            truncation=True,
            padding="max_length",
        )
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        logger.info("Input IDs shape: %s", input_ids.shape)
        logger.info("Attention mask shape: %s", attention_mask.shape)

        # Image: ViT preprocessing into a pixel_values array.
        image_processed = vit_processor(images=image, return_tensors="np")["pixel_values"]
        logger.info("Processed image shape: %s", image_processed.shape)

        ort_inputs = _build_onnx_inputs(input_ids, attention_mask, image_processed)
        logger.info("ONNX input keys: %s", list(ort_inputs))

        logger.info("Starting inference")
        start_time = time.perf_counter()
        outputs = ort_session.run(None, ort_inputs)
        inference_time = time.perf_counter() - start_time
        logger.info("Inference completed in %.3fs", inference_time)

        # First output = class scores; promote a 1-D result to a batch of one.
        logits = np.atleast_2d(outputs[0])
        logger.info("Raw output shape: %s, values: %s", logits.shape, logits)
        probs = softmax(logits)[0]
        logger.info("Probabilities: %s", probs)
        if probs.shape[0] != len(labels):
            raise ValueError(
                f"Model returned {probs.shape[0]} class scores, expected {len(labels)}"
            )

        pred_idx = int(np.argmax(probs))
        confidence = float(probs[pred_idx])

        result = {name: float(p) for name, p in zip(labels, probs)}
        interpretation = image_with_prediction(image, labels[pred_idx], confidence)
        analysis = _format_analysis(pred_idx, confidence, inference_time)
        return result, interpretation, analysis
    except Exception as e:
        logger.error("Error during analysis: %s", e, exc_info=True)
        # Escape: the message may echo arbitrary text into the HTML panel.
        return _empty_result(), None, f"<p>Error during analysis: {html.escape(str(e))}</p>"


examples = [
    [
        "COVID-19 vaccine causes severe side effects in 80% of recipients",
        "https://images.unsplash.com/photo-1605289982774-9a6fef564df8?q=80&w=1000&auto=format&fit=crop",
    ],
    [
        "Scientists discover new species of deep-sea fish",
        "https://images.unsplash.com/photo-1524704796725-9fc3044a58b2?q=80&w=1000&auto=format&fit=crop",
    ],
]

# Build Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 📰 Fake News Detector (RoBERTa + ViT)

        This multimodal AI system analyzes both text and images to detect potentially fake news content.
        Upload an image and enter a news headline to see if the combination is likely to be real or fake news.
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            text_input = gr.Textbox(
                label="News Headline / Text",
                placeholder="Enter the news headline or text here...",
                lines=3,
            )
            image_input = gr.Image(type="pil", label="Associated Image")
            analyze_btn = gr.Button("Analyze Content", variant="primary")

        with gr.Column(scale=1):
            label_output = gr.Label(label="Prediction Probabilities")
            image_output = gr.Image(type="pil", label="Visual Analysis")
            # predict_news returns HTML (or plain-text status messages) for this panel.
            analysis_html = gr.HTML(label="Analysis")

    # cache_examples=True makes Gradio precompute predict_news outputs for these
    # examples, which requires downloading the example image URLs.
    gr.Examples(
        examples=examples,
        inputs=[text_input, image_input],
        outputs=[label_output, image_output, analysis_html],
        fn=predict_news,
        cache_examples=True,
    )

    gr.Markdown(
        """
        This system combines:
        - **RoBERTa**: Analyzes the textual content
        - **ViT**: Processes the image data
        - **Multimodal Fusion**: Combines both signals to make a prediction

        The model was trained on the Fakeddit dataset containing real and fake news pairs with both text and images.
        """
    )

    analyze_btn.click(
        predict_news,
        inputs=[text_input, image_input],
        outputs=[label_output, image_output, analysis_html],
    )

if __name__ == "__main__":
    logger.info("Starting Gradio application")
    demo.launch()