import gradio as gr import onnxruntime as ort from transformers import RobertaTokenizer, ViTImageProcessor from PIL import Image import numpy as np import torch import os import time import logging # Setup logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) vit_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k") tokenizer = RobertaTokenizer.from_pretrained("roberta-base") model_path = "./multimodal_model.onnx" try: if not os.path.exists(model_path): raise FileNotFoundError(f"ONNX model not found at {model_path}") logger.info(f"Loading ONNX model from {model_path}") sess_options = ort.SessionOptions() sess_options.log_severity_level = 0 ort_session = ort.InferenceSession( model_path, sess_options=sess_options, providers=['CPUExecutionProvider'] ) logger.info("ONNX model loaded successfully") input_names = [input.name for input in ort_session.get_inputs()] input_shapes = {input.name: input.shape for input in ort_session.get_inputs()} output_names = [output.name for output in ort_session.get_outputs()] logger.info(f"Model inputs: {input_names} with shapes {input_shapes}") logger.info(f"Model outputs: {output_names}") except Exception as e: logger.error(f"Error loading ONNX model: {e}") raise labels = ["Real", "Real Text with fake image", "Fake"] def softmax(x): """Compute softmax values for each sets of scores in x.""" e_x = np.exp(x - np.max(x, axis=1, keepdims=True)) return e_x / e_x.sum(axis=1, keepdims=True) def image_with_prediction(img, label, confidence): """Return the original image with an overlay showing the prediction""" from PIL import Image, ImageDraw, ImageFont img_copy = img.copy() draw = ImageDraw.Draw(img_copy) width, height = img_copy.size overlay = Image.new('RGBA', (width, 40), (0, 0, 0, 150)) img_copy.paste(overlay, (0, height-40), overlay) text = f"{label}: {confidence:.1%}" try: font = ImageFont.truetype("arial.ttf", 20) except IOError: font = ImageFont.load_default() try: text_width = draw.textlength(text, font=font) except AttributeError: text_width = font.getsize(text)[0] if hasattr(font, 'getsize') else 200 text_position = ((width - text_width) // 2, height - 35) draw.text(text_position, text, fill=(255, 255, 255), font=font) return img_copy def predict_news(text, image): if text is None or text.strip() == "": return {labels[0]: 0.0, labels[1]: 0.0, labels[2]: 0.0}, None, "Please enter some text to analyze." if image is None: return {labels[0]: 0.0, labels[1]: 0.0, labels[2]: 0.0}, None, "Please upload an image to analyze." try: logger.info(f"Processing text: {text[:50]}...") logger.info(f"Processing image size: {image.size}") # Process text input inputs = tokenizer.encode_plus(text, add_special_tokens = True, return_tensors='np', max_length=80, truncation=True, padding='max_length') input_ids = inputs['input_ids'] attention_mask = inputs['attention_mask'] logger.info(f"Input IDs shape: {input_ids.shape}") logger.info(f"Attention mask shape: {attention_mask.shape}") # Process image input image_processed = vit_processor(images=image, return_tensors="np")["pixel_values"] logger.info(f"Processed image shape: {image_processed.shape}") ort_inputs = {} for input_meta in ort_session.get_inputs(): input_name = input_meta.name if 'ids' in input_name.lower() or input_name == 'text_input_ids': ort_inputs[input_name] = input_ids elif 'mask' in input_name.lower() or input_name == 'text_attention_mask': ort_inputs[input_name] = attention_mask elif 'image' in input_name.lower() or input_name == 'image_input': ort_inputs[input_name] = image_processed logger.info(f"ONNX input keys: {list(ort_inputs.keys())}") # Run inference start_time = time.time() logger.info("Starting inference") outputs = ort_session.run(None, ort_inputs) inference_time = time.time() - start_time logger.info(f"Inference completed in {inference_time:.3f}s") # Process model outputs logits = outputs[0] logger.info(f"Raw output shape: {logits.shape}, values: {logits}") probs = softmax(logits)[0] logger.info(f"Probabilities: {probs}") pred_idx = int(np.argmax(probs)) confidence = float(probs[pred_idx]) if pred_idx == 1: color = "orange" message = f"This content appears to be **REAL TEXT WITH FAKE IMAGE** with {confidence:.1%} confidence." elif pred_idx == 2: color = "red" message = f"This content appears to contain **FAKE** with {confidence:.1%} confidence." else: color = "green" message = f"This content appears to be **REAL** with {confidence:.1%} confidence." analysis = f"""
Inference time: {inference_time:.3f} seconds