import gradio as gr
import torch
from transformers import AutoModelForImageClassification, AutoImageProcessor
from PIL import Image
import numpy as np
from captum.attr import LayerGradCam, LayerAttribution
from captum.attr import visualization as viz
from matplotlib.backends.backend_agg import FigureCanvasAgg
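# Note (assumption): when deploying this as a Hugging Face Space, requirements.txt should list
# the libraries imported above, roughly: torch, transformers, captum, matplotlib, numpy, Pillow.
# gradio itself is typically provided by the Space's SDK setting; pin versions as needed.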
# --- 1. Load Model and Processor ---
# Load the pre-trained model and the image processor from Hugging Face.
# We explicitly set torch_dtype to float32 to ensure CPU compatibility.
print("Loading model and processor...")
model_id = "Organika/sdxl-detector"
processor = AutoImageProcessor.from_pretrained(model_id)
model = AutoModelForImageClassification.from_pretrained(model_id, torch_dtype=torch.float32)
model.eval()  # Set the model to evaluation mode
print("Model and processor loaded successfully.")
# --- 2. Define the Explainability (Grad-CAM) Function ---
# This function generates the heatmap showing which parts of the image the model focused on.
def generate_heatmap(image_tensor, original_image, target_class_index):
    # Captum expects a forward function that returns a plain tensor, while Hugging Face models
    # return an ImageClassifierOutput object, so we wrap the call and pull out the logits.
    def forward_func(pixel_values):
        return model(pixel_values).logits

    # LayerGradCam requires a specific layer to hook into. For ConvNeXT models (like this one),
    # a good choice is the final depthwise convolution of the last encoder stage.
    target_layer = model.convnext.encoder.stages[-1].layers[-1].dwconv

    # Initialize LayerGradCam.
    lgc = LayerGradCam(forward_func, target_layer)

    # Generate attributions (the "importance" of each spatial location in the target layer).
    # relu_attributions keeps only the evidence in favor of the target class.
    attributions = lgc.attribute(image_tensor, target=target_class_index, relu_attributions=True)

    # The raw Grad-CAM map has the coarse spatial resolution of the target layer, so we
    # upsample it to the original image size (height, width) before overlaying.
    upsampled = LayerAttribution.interpolate(
        attributions, (original_image.size[1], original_image.size[0])
    )
    heatmap = np.transpose(upsampled.squeeze(0).cpu().detach().numpy(), (1, 2, 0))

    # Use Captum's visualization tool to overlay the heatmap on the original image.
    fig, _ = viz.visualize_image_attr(
        heatmap,
        np.array(original_image),
        method="blended_heat_map",
        sign="all",
        show_colorbar=True,
        title="Model Attention Heatmap",
        use_pyplot=False,
    )

    # Render the matplotlib figure to an RGB array that Gradio's Image component can display.
    canvas = FigureCanvasAgg(fig)
    canvas.draw()
    visualized_image = np.asarray(canvas.buffer_rgba())[:, :, :3]
    return visualized_image
# --- 3. Define the Main Prediction Function ---
# This function will be called by Gradio every time a user uploads an image.
def predict(input_image: Image.Image):
    print(f"Received image of size: {input_image.size}")

    # Convert the image to RGB if it has an alpha channel.
    if input_image.mode == "RGBA":
        input_image = input_image.convert("RGB")

    # Preprocess the image for the model.
    inputs = processor(images=input_image, return_tensors="pt")

    # Make a prediction.
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits

    # Convert logits to probabilities.
    probabilities = torch.nn.functional.softmax(logits, dim=-1)

    # Get the predicted class index and the confidence score.
    predicted_class_idx = logits.argmax(-1).item()
    confidence_score = probabilities[0][predicted_class_idx].item()

    # Get the label name (e.g., 'ai' or 'human').
    predicted_label = model.config.id2label[predicted_class_idx]

    # --- Generate a human-readable explanation ---
    if predicted_label.lower() == "ai":
        explanation = (
            f"The model is {confidence_score:.2%} confident that this image is AI-GENERATED.\n\n"
            "The heatmap on the right highlights the areas that most influenced this decision. "
            "Pay close attention to whether these hotspots fall on unnatural-looking features "
            "such as hair, eyes, skin texture, or strange background details."
        )
    else:
        explanation = (
            f"The model is {confidence_score:.2%} confident that this image is HUMAN-MADE.\n\n"
            "The heatmap shows which areas the model found to be most 'natural'. "
            "These are likely well-formed, realistic features that AI models often struggle to replicate perfectly."
        )

    # --- Generate the heatmap ---
    # We call our Grad-CAM function to create the visualization.
    print("Generating heatmap...")
    heatmap_image = generate_heatmap(inputs["pixel_values"], input_image, predicted_class_idx)
    print("Heatmap generated.")

    # Return the class probabilities, the text explanation, and the heatmap image.
    # The labels dictionary feeds the gr.Label component.
    labels_dict = {
        model.config.id2label[i]: float(probabilities[0][i])
        for i in range(len(model.config.id2label))
    }
    return labels_dict, explanation, heatmap_image
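# Note: generate_heatmap runs an extra forward/backward pass (outside torch.no_grad) to obtain
# the Grad-CAM gradients, so on a CPU-only Space each analysis may take a few extra seconds.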
# --- Create example files for the demo ---
# The example images must already exist on disk before gr.Examples (with cache_examples=True)
# is built below, so we download them first.
import os
from urllib.request import urlretrieve

print("Creating examples directory and downloading example images...")
os.makedirs("examples", exist_ok=True)

# These URLs are stable Hugging Face documentation assets.
try:
    # AI example 1: a classic AI-generated image (astronaut on a horse)
    urlretrieve(
        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/horse.png",
        "examples/ai_example_1.png",
    )
    # Human example 1: a real photograph
    urlretrieve(
        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/gradio-guide/zookeeper.png",
        "examples/human_example_1.jpg",
    )
    # AI example 2: an AI-generated portrait, good for testing face/hair detection
    urlretrieve(
        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/stable-diffusion-sdxl/sdxl-gfpgan-output.png",
        "examples/ai_example_2.png",
    )
    print("Example images downloaded successfully.")
except Exception as e:
    print(f"Failed to download example images: {e}")

# --- 4. Create the Gradio Interface ---
# This sets up the web UI with inputs and outputs.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # AI Image Detector with Explainability
        Upload an image to determine whether it was generated by AI or created by a human.
        This tool uses the [Organika/sdxl-detector](https://huggingface.co/Organika/sdxl-detector) model.
        In addition to the prediction, it provides a **heatmap** showing *why* the model made its decision, highlighting the areas it found most suspicious or authentic.
        """
    )
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Upload Image")
            submit_btn = gr.Button("Analyze Image", variant="primary")
        with gr.Column():
            output_label = gr.Label(label="Prediction")
            output_text = gr.Textbox(label="Explanation", lines=6)
            output_heatmap = gr.Image(label="Model Attention Heatmap")

    submit_btn.click(
        fn=predict,
        inputs=input_image,
        outputs=[output_label, output_text, output_heatmap],
    )

    gr.Examples(
        examples=[
            ["examples/ai_example_1.png"],
            ["examples/human_example_1.jpg"],
            ["examples/ai_example_2.png"],
        ],
        inputs=input_image,
        outputs=[output_label, output_text, output_heatmap],
        fn=predict,
        cache_examples=True,  # Precompute example outputs so clicking an example is instant
    )
# --- 5. Launch the App ---
if __name__ == "__main__":
    demo.launch(debug=True)