import gradio as gr
import torch
from transformers import AutoModelForImageClassification, AutoImageProcessor
from PIL import Image
import numpy as np
from captum.attr import LayerGradCam, LayerAttribution
from captum.attr import visualization as viz
from matplotlib.backends.backend_agg import FigureCanvasAgg
# --- 1. Load Model and Processor ---
# Load the pre-trained model and the image processor from Hugging Face.
# We explicitly set torch_dtype to float32 to ensure CPU compatibility.
print("Loading model and processor...")
model_id = "Organika/sdxl-detector"
processor = AutoImageProcessor.from_pretrained(model_id)
model = AutoModelForImageClassification.from_pretrained(model_id, torch_dtype=torch.float32)
model.eval() # Set the model to evaluation mode
print("Model and processor loaded successfully.")
# --- 2. Define the Explainability (Grad-CAM) Function ---
# This function generates the heatmap showing which parts of the image the model focused on.
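# Grad-CAM works by weighting the chosen layer's activation maps with the gradients of the target
# class score and summing them, which yields a coarse spatial map of the regions that pushed the
# prediction towards that class.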
def generate_heatmap(image_tensor, original_image, target_class_index):
    # LayerGradCam needs a specific layer to hook into. For ConvNeXT models (like this one),
    # a good choice is the depthwise convolution in the last layer of the final encoder stage.
    target_layer = model.convnext.encoder.stages[-1].layers[-1].dwconv

    # Captum expects the forward function to return a plain tensor, so wrap the model call
    # to return the logits rather than a Hugging Face ModelOutput object.
    def forward_logits(pixel_values):
        return model(pixel_values=pixel_values).logits

    # Initialize LayerGradCam with the wrapped forward function and the target layer.
    lgc = LayerGradCam(forward_logits, target_layer)

    # Generate attributions (the "importance" of each spatial location) for the target class.
    attributions = lgc.attribute(image_tensor, target=target_class_index, relu_attributions=True)

    # The Grad-CAM map is low-resolution (one value per spatial position of the target layer),
    # so upsample it to the original image size before overlaying it.
    attributions = LayerAttribution.interpolate(
        attributions, (original_image.size[1], original_image.size[0])
    )

    # Reformat the heatmap to (height, width, channels) for Captum's visualization tool.
    heatmap = np.transpose(attributions.squeeze(0).cpu().detach().numpy(), (1, 2, 0))

    # Use Captum's visualization tool to overlay the heatmap on the original image.
    fig, _ = viz.visualize_image_attr(
        heatmap,
        np.array(original_image),
        method="blended_heat_map",
        sign="all",
        show_colorbar=True,
        title="Model Attention Heatmap",
        use_pyplot=False,  # return a Figure instead of opening a pyplot window
    )

    # Render the Matplotlib figure to an RGBA array so gr.Image can display it.
    canvas = FigureCanvasAgg(fig)
    canvas.draw()
    return np.asarray(canvas.buffer_rgba())
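# Note: the target layer above is specific to the ConvNeXT backbone this checkpoint is assumed to use.
# If the model is swapped for a different architecture, print(model) is a quick way to locate another
# convolutional layer near the end of the encoder to hook instead.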
# --- 3. Define the Main Prediction Function ---
# Gradio calls this function whenever the user clicks "Analyze Image" (and when caching the examples).
def predict(input_image: Image.Image):
print(f"Received image of size: {input_image.size}")
# Convert image to RGB if it has an alpha channel
if input_image.mode == 'RGBA':
input_image = input_image.convert('RGB')
# Preprocess the image for the model
inputs = processor(images=input_image, return_tensors="pt")
# Make a prediction
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
# Convert logits to probabilities
probabilities = torch.nn.functional.softmax(logits, dim=-1)
# Get the predicted class index and the confidence score
predicted_class_idx = logits.argmax(-1).item()
confidence_score = probabilities[0][predicted_class_idx].item()
# Get the label name (e.g., 'ai' or 'human')
predicted_label = model.config.id2label[predicted_class_idx]
# --- Generate Human-Readable Explanation ---
# This directly answers your requirement to "say out which one is less human".
if predicted_label.lower() == 'ai':
explanation = (
f"The model is {confidence_score:.2%} confident that this image is AI-GENERATED.\n\n"
"The heatmap on the right highlights the areas that most influenced this decision. "
"According to your research, pay close attention if these hotspots are on "
"unnatural-looking features like hair, eyes, skin texture, or strange background details."
)
else:
explanation = (
f"The model is {confidence_score:.2%} confident that this image is HUMAN-MADE.\n\n"
"The heatmap shows which areas the model found to be most 'natural'. "
"These are likely well-formed, realistic features that AI models often struggle to replicate perfectly."
)
# --- Generate the Heatmap ---
# We call our Grad-CAM function to create the visualization.
print("Generating heatmap...")
heatmap_image = generate_heatmap(inputs['pixel_values'], input_image, predicted_class_idx)
print("Heatmap generated.")
# Return the classification labels, the text explanation, and the heatmap image
# The labels dictionary is for the gr.Label component.
labels_dict = {model.config.id2label[i]: float(probabilities[0][i]) for i in range(len(model.config.id2label))}
return labels_dict, explanation, heatmap_image
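# A quick local sanity check outside the UI (assumes a file named 'test.jpg' exists next to this script):
#   labels, text, heatmap = predict(Image.open("test.jpg"))
#   print(labels)
#   print(text)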
# --- Create example files for the demo ---
# Download the example images up front so they already exist when gr.Examples references them below.
import os
from urllib.request import urlretrieve

print("Creating examples directory and downloading example images...")
os.makedirs("examples", exist_ok=True)

# These URLs are stable Hugging Face documentation assets.
try:
    # AI Example 1: A classic AI-generated image (astronaut on a horse)
    urlretrieve(
        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/horse.png",
        "examples/ai_example_1.png",
    )
    # Human Example 1: A real photograph
    urlretrieve(
        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/gradio-guide/zookeeper.png",
        "examples/human_example_1.jpg",
    )
    # AI Example 2: An AI-generated portrait, good for testing face/hair detection
    urlretrieve(
        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/stable-diffusion-sdxl/sdxl-gfpgan-output.png",
        "examples/ai_example_2.png",
    )
    print("Example images downloaded successfully.")
except Exception as e:
    print(f"Failed to download example images: {e}")

# --- 4. Create the Gradio Interface ---
# This sets up the web UI with inputs and outputs.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # AI Image Detector with Explainability
        Upload an image to determine if it was generated by AI or created by a human.
        This tool uses the [Organika/sdxl-detector](https://huggingface.co/Organika/sdxl-detector) model.
        In addition to the prediction, it provides a **heatmap** to show *why* the model made its decision, highlighting the areas it found most suspicious or authentic.
        """
    )
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Upload Image")
            submit_btn = gr.Button("Analyze Image", variant="primary")
        with gr.Column():
            output_label = gr.Label(label="Prediction")
            output_text = gr.Textbox(label="Explanation", lines=6)
            output_heatmap = gr.Image(label="Model Attention Heatmap")

    submit_btn.click(
        fn=predict,
        inputs=input_image,
        outputs=[output_label, output_text, output_heatmap]
    )
    gr.Examples(
        examples=[
            ["examples/ai_example_1.png"],
            ["examples/human_example_1.jpg"],
            ["examples/ai_example_2.png"],
        ],
        inputs=input_image,
        outputs=[output_label, output_text, output_heatmap],
        fn=predict,
        cache_examples=True,  # Pre-computes outputs for the examples so clicking them is instant
    )
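    # Note: with cache_examples=True, Gradio runs predict() on each example ahead of time and stores
    # the results, so clicking an example is instant; the trade-off is a somewhat slower first startup.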
# --- 5. Launch the App ---
# Keep this as the last part of the script.
if __name__ == "__main__":
    demo.launch(debug=True)