import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoModelForImageClassification, AutoImageProcessor
from PIL import Image
import numpy as np
from captum.attr import LayerGradCam
from captum.attr import visualization as viz
from matplotlib.backends.backend_agg import FigureCanvasAgg
import requests
from io import BytesIO
# --- 1. Load Model and Processor ---
print("Loading model and processor...")
model_id = "Organika/sdxl-detector"
processor = AutoImageProcessor.from_pretrained(model_id)
model = AutoModelForImageClassification.from_pretrained(model_id, torch_dtype=torch.float32)
model.eval()
print("Model and processor loaded successfully.")
# --- 2. Explainability (Grad-CAM) Function ---
def generate_heatmap(image_tensor, original_image, target_class_index):
    # Captum expects a plain callable that returns raw logits, so wrap the
    # Hugging Face model's keyword-argument interface.
    def model_forward_wrapper(input_tensor):
        outputs = model(pixel_values=input_tensor)
        return outputs.logits

    # Target the final layer norm of the Swin backbone.
    target_layer = model.swin.layernorm

    # Initialize LayerGradCam with the wrapper and the target layer.
    lgc = LayerGradCam(model_forward_wrapper, target_layer)
    # Swin's layernorm output is a patch sequence of shape
    # (batch, num_patches, hidden_size), not the (batch, channels, H, W) grid
    # that Grad-CAM assumes. Disable Captum's built-in channel summation and
    # collapse the hidden dimension ourselves so the result is one score per
    # patch rather than one score per hidden channel.
    attributions = lgc.attribute(
        image_tensor,
        target=target_class_index,
        relu_attributions=True,
        attr_dim_summation=False,
    )  # shape: (batch, num_patches, hidden_size)
    attributions = attributions.sum(dim=-1)  # (batch, num_patches)

    # 1. Determine the grid size (e.g., 49 patches form a 7x7 grid).
    num_patches = attributions.shape[-1]
    grid_size = int(np.sqrt(num_patches))

    # 2. Reshape the 1D patch sequence into a 2D grid.
    heatmap = attributions.squeeze(0).reshape(grid_size, grid_size)
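    # For reference, classic Grad-CAM computes, for class c and feature maps A^k:
    #   alpha_k = mean over positions (i, j) of dy_c / dA^k_ij   (pooled weights)
    #   L^c     = ReLU( sum over k of alpha_k * A^k )            (class heatmap)
    # Here the patch sequence plays the role of the spatial positions.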
    # 3. Upsample the small heatmap to the original image size for the overlay.
    # interpolate expects an (N, C, H, W) tensor, so add batch and channel
    # dimensions back.
    heatmap = heatmap.unsqueeze(0).unsqueeze(0)
    # Note: original_image.size is (W, H); interpolate needs size as (H, W).
    upsampled_heatmap = F.interpolate(heatmap, size=original_image.size[::-1], mode='bilinear', align_corners=False)

    # 4. Prepare the final heatmap for visualization.
    heatmap_for_viz = upsampled_heatmap.squeeze().cpu().detach().numpy()
    # visualize_image_attr expects an (H, W, C) numpy array, so add a channel
    # dimension to the (H, W) heatmap. use_pyplot=False returns the figure
    # instead of trying to display it.
    figure, _ = viz.visualize_image_attr(
        np.expand_dims(heatmap_for_viz, axis=-1),
        np.array(original_image),
        method="blended_heat_map",
        sign="all",
        show_colorbar=True,
        title="Model Attention Heatmap",
        use_pyplot=False,
    )
    # gr.Image cannot display a matplotlib figure directly, so rasterize it
    # into an RGBA numpy array.
    canvas = FigureCanvasAgg(figure)
    canvas.draw()
    return np.asarray(canvas.buffer_rgba())
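
# Example sketch of calling generate_heatmap outside Gradio, assuming a local
# test file "example.jpg" (hypothetical path):
#   img = Image.open("example.jpg").convert("RGB")
#   pixel_values = processor(images=img, return_tensors="pt")["pixel_values"]
#   heatmap_array = generate_heatmap(pixel_values, img, target_class_index=0)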
# --- 3. Main Prediction Function ---
def predict(image_upload: Image.Image, image_url: str):
    if image_upload is not None:
        input_image = image_upload
        print(f"Processing uploaded image of size: {input_image.size}")
    elif image_url:
        try:
            response = requests.get(image_url, timeout=10)
            response.raise_for_status()
            input_image = Image.open(BytesIO(response.content))
            print(f"Processing image from URL: {image_url}")
        except Exception as e:
            raise gr.Error(f"Could not load image from URL. Please check the link. Error: {e}")
    else:
        raise gr.Error("Please upload an image or provide a URL to analyze.")

    # The model expects 3-channel RGB input; convert RGBA, palette, or
    # grayscale images accordingly.
    if input_image.mode != 'RGB':
        input_image = input_image.convert('RGB')

    inputs = processor(images=input_image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits

    probabilities = F.softmax(logits, dim=-1)
    predicted_class_idx = logits.argmax(-1).item()
    confidence_score = probabilities[0][predicted_class_idx].item()
    predicted_label = model.config.id2label[predicted_class_idx]
    # Match the "AI" class loosely, since checkpoints name it differently
    # (e.g. "ai" or "artificial").
    if predicted_label.lower() in ('ai', 'artificial'):
        explanation = (
            f"The model is {confidence_score:.2%} confident that this image is AI-GENERATED.\n\n"
            "The heatmap on the right highlights the areas that most influenced this decision. "
            "Pay close attention if these hotspots are on unnatural-looking features like hair, eyes, skin texture, or strange background details."
        )
    else:
        explanation = (
            f"The model is {confidence_score:.2%} confident that this image is HUMAN-MADE.\n\n"
            "The heatmap shows which areas the model found to be most 'natural'. "
            "These are likely well-formed, realistic features that AI models often struggle to replicate perfectly."
        )

    print("Generating heatmap...")
    heatmap_image = generate_heatmap(inputs['pixel_values'], input_image, predicted_class_idx)
    print("Heatmap generated.")

    labels_dict = {model.config.id2label[i]: float(probabilities[0][i]) for i in range(len(model.config.id2label))}
    return labels_dict, explanation, heatmap_image
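
# Example sketch of calling predict directly (no UI); the URL below is a
# placeholder, not a real test image:
#   labels, explanation, heatmap = predict(None, "https://example.com/test.jpg")
#   print(labels)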
# --- 4. Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # AI Image Detector with Explainability
        Determine if an image is AI-generated or human-made. Upload a file or paste a URL.
        This tool uses the [Organika/sdxl-detector](https://huggingface.co/Organika/sdxl-detector) model and provides a **heatmap** to show *why* the model made its decision.
        """
    )
    with gr.Row():
        with gr.Column():
            with gr.Tabs():
                with gr.TabItem("Upload File"):
                    input_image_upload = gr.Image(type="pil", label="Upload Your Image")
                with gr.TabItem("Use Image URL"):
                    input_image_url = gr.Textbox(label="Paste Image URL here")
            submit_btn = gr.Button("Analyze Image", variant="primary")
        with gr.Column():
            output_label = gr.Label(label="Prediction")
            output_text = gr.Textbox(label="Explanation", lines=6, interactive=False)
            output_heatmap = gr.Image(label="Model Attention Heatmap")

    submit_btn.click(
        fn=predict,
        inputs=[input_image_upload, input_image_url],
        outputs=[output_label, output_text, output_heatmap]
    )
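    # Optional sketch: also run the analysis when the user presses Enter in
    # the URL box, reusing the same handler and outputs:
    # input_image_url.submit(
    #     fn=predict,
    #     inputs=[input_image_upload, input_image_url],
    #     outputs=[output_label, output_text, output_heatmap],
    # )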
# --- 5. Launch the App ---
if __name__ == "__main__":
    demo.launch(debug=True)
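# Optional sketch: for multi-user deployments you may want request queuing,
# e.g. demo.queue().launch(debug=True); queue() is a standard Gradio method,
# though its defaults vary across Gradio versions.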