File size: 7,571 Bytes
1aabfb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b021a8
 
 
1aabfb0
 
62764e2
1aabfb0
 
62764e2
 
1aabfb0
62764e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1aabfb0
 
 
62764e2
1aabfb0
62764e2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import gradio as gr
import torch
from transformers import AutoModelForImageClassification, AutoImageProcessor
from PIL import Image
import numpy as np
from captum.attr import LayerGradCam
from captum.attr import visualization as viz

# --- 1. Load Model and Processor ---
# Pull the classifier checkpoint and its matching image preprocessor from the
# Hugging Face Hub. Weights are requested as float32 so the model also runs on CPU.
model_id = "Organika/sdxl-detector"

print("Loading model and processor...")
processor = AutoImageProcessor.from_pretrained(model_id)
# nn.Module.eval() returns the module itself, so the call can be chained;
# eval mode switches off training-only behaviour such as dropout.
model = AutoModelForImageClassification.from_pretrained(
    model_id,
    torch_dtype=torch.float32,
).eval()
print("Model and processor loaded successfully.")

# --- 2. Define the Explainability (Grad-CAM) Function ---
# Produces an image showing which regions of the input drove the model's decision.
def generate_heatmap(image_tensor, original_image, target_class_index):
    """Render a Grad-CAM heatmap for one class, blended over the original image.

    Args:
        image_tensor: Preprocessed model input (the processor's ``pixel_values``);
            assumed to hold a single image shaped (1, C, H, W) — TODO confirm.
        original_image: PIL image shown to the user. The heatmap is upsampled to
            this image's resolution before blending.
        target_class_index: Index of the output class whose score is explained.

    Returns:
        A ``uint8`` RGB numpy array of shape (height, width, 3) containing the
        rendered Captum figure, which ``gr.Image`` can display directly. If
        Grad-CAM yields an all-zero map, the original image is returned as an
        array instead.
    """
    # Local import: only this function rasterises matplotlib figures, and
    # matplotlib is already required by captum's visualization module.
    from matplotlib.backends.backend_agg import FigureCanvasAgg

    def _forward_logits(pixel_values):
        # Hugging Face models return a ModelOutput object, not a tensor; Captum
        # needs the logits tensor so it can select the `target` class score.
        return model(pixel_values=pixel_values).logits

    # NOTE(review): this path assumes a ConvNeXt backbone exposed as
    # `model.convnext`; another architecture will raise AttributeError here.
    # Confirm against the checkpoint's config.
    target_layer = model.convnext.encoder.stages[-1].layers[-1].dwconv

    grad_cam = LayerGradCam(_forward_logits, target_layer)
    # LayerGradCam.attribute takes no baselines. relu_attributions drops
    # negative attributions, so the resulting map is non-negative.
    attributions = grad_cam.attribute(
        image_tensor, target=target_class_index, relu_attributions=True
    )

    rgb_image = original_image.convert("RGB")
    width, height = rgb_image.size  # PIL reports (width, height)
    # Grad-CAM lives at the target layer's coarse spatial resolution; upsample it
    # to the displayed image so Captum can blend the two arrays pixel-for-pixel.
    # NOTE(review): the processor may resize/crop, so alignment is approximate.
    upsampled = torch.nn.functional.interpolate(
        attributions, size=(height, width), mode="bilinear", align_corners=False
    )
    # (1, 1, H, W) -> (H, W, 1): Captum's visualiser expects channels-last data.
    heatmap = upsampled[0].detach().cpu().numpy().transpose(1, 2, 0)

    base_pixels = np.array(rgb_image)
    if not np.any(heatmap):
        # Captum normalises the map by its scale, which is zero here, so there is
        # nothing meaningful to overlay; show the plain image instead.
        return base_pixels

    figure, _ = viz.visualize_image_attr(
        heatmap,
        base_pixels,
        method="blended_heat_map",
        # The ReLU above guarantees non-negative values: use the one-sided scale.
        sign="positive",
        show_colorbar=True,
        title="Model Attention Heatmap",
        # Build a standalone Figure: no global pyplot state and no plt.show().
        use_pyplot=False,
    )
    # visualize_image_attr returns a matplotlib Figure; gr.Image needs pixel
    # data, so rasterise the figure and drop the alpha channel.
    canvas = FigureCanvasAgg(figure)
    canvas.draw()
    return np.asarray(canvas.buffer_rgba())[..., :3].copy()

# --- 3. Define the Main Prediction Function ---
# Lower-cased label names treated as "AI-generated" when wording the explanation.
# NOTE(review): the real names come from the checkpoint's config.id2label;
# confirm them against the model card and trim this set accordingly.
_AI_LABELS = frozenset({"ai", "artificial", "ai-generated", "fake", "generated", "synthetic"})


# Gradio calls this function every time the user submits an image.
def predict(input_image: Image.Image):
    """Classify an uploaded image and explain the decision.

    Args:
        input_image: Image from the ``gr.Image(type="pil")`` input; ``None`` when
            the user submits without uploading anything.

    Returns:
        A ``(labels, explanation, heatmap)`` tuple for the Label, Textbox and
        Image outputs: a ``{label: probability}`` dict, a human-readable summary,
        and the Grad-CAM overlay (``None`` if the heatmap could not be produced).

    Raises:
        gr.Error: If no image was provided, so Gradio shows a readable message.
    """
    if input_image is None:
        raise gr.Error("Please upload an image before clicking 'Analyze Image'.")
    print(f"Received image of size: {input_image.size}")

    # Normalise every mode (RGBA, LA, P, L, CMYK, ...) to 3-channel RGB.
    # Converting only RGBA would let grayscale or palette images through as-is.
    if input_image.mode != "RGB":
        input_image = input_image.convert("RGB")

    # Preprocess and classify; no gradients are needed for the prediction itself.
    inputs = processor(images=input_image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits

    # A single image was submitted, so keep the one row of class probabilities.
    probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
    predicted_class_idx = logits.argmax(-1).item()
    confidence_score = probabilities[predicted_class_idx].item()

    id2label = model.config.id2label
    predicted_label = id2label[predicted_class_idx]

    # --- Human-readable explanation of the verdict ---
    if predicted_label.lower() in _AI_LABELS:
        explanation = (
            f"The model is {confidence_score:.2%} confident that this image is AI-GENERATED.\n\n"
            "The heatmap on the right highlights the areas that most influenced this decision. "
            "According to your research, pay close attention if these hotspots are on "
            "unnatural-looking features like hair, eyes, skin texture, or strange background details."
        )
    else:
        explanation = (
            f"The model is {confidence_score:.2%} confident that this image is HUMAN-MADE.\n\n"
            "The heatmap shows which areas the model found to be most 'natural'. "
            "These are likely well-formed, realistic features that AI models often struggle to replicate perfectly."
        )

    # --- Grad-CAM heatmap (best effort) ---
    # The heatmap is supplementary: if it fails, still return the prediction and
    # log the full traceback instead of discarding the whole result.
    print("Generating heatmap...")
    try:
        heatmap_image = generate_heatmap(inputs["pixel_values"], input_image, predicted_class_idx)
    except Exception:
        import traceback

        traceback.print_exc()
        heatmap_image = None
        explanation += "\n\n(The attention heatmap could not be generated for this image.)"
    else:
        print("Heatmap generated.")

    # gr.Label expects a {label: confidence} mapping covering every class.
    labels_dict = {label: float(probabilities[idx]) for idx, label in id2label.items()}

    return labels_dict, explanation, heatmap_image

# --- 4. Prepare example files, then create the Gradio Interface ---
# The example images are fetched BEFORE the UI is built: gr.Examples with
# cache_examples=True runs `predict` on every example, so the files must exist.
import os
import shutil
from http.client import HTTPException
from urllib.request import urlopen

# Network timeout per download, so a slow host cannot stall app start-up forever.
_DOWNLOAD_TIMEOUT_SECONDS = 30

# (url, local path) pairs; the URLs point at Hugging Face documentation assets.
_EXAMPLE_DOWNLOADS = (
    # AI example 1: a classic AI-generated image (astronaut on a horse)
    (
        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/horse.png",
        "examples/ai_example_1.png",
    ),
    # Human example 1: a real photograph (the source is a PNG, so keep .png)
    (
        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/gradio-guide/zookeeper.png",
        "examples/human_example_1.png",
    ),
    # AI example 2: an AI-generated portrait, good for testing face/hair detection
    (
        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/stable-diffusion-sdxl/sdxl-gfpgan-output.png",
        "examples/ai_example_2.png",
    ),
)


def _download_example(url, destination, timeout=_DOWNLOAD_TIMEOUT_SECONDS):
    """Download ``url`` to ``destination`` unless the file already exists.

    Returns True when the file is available afterwards. On failure the error is
    printed, any partially written file is removed, and False is returned so
    one bad URL does not prevent the other examples from loading.
    """
    if os.path.exists(destination):
        return True
    try:
        with urlopen(url, timeout=timeout) as response, open(destination, "wb") as out_file:
            shutil.copyfileobj(response, out_file)
    except (OSError, ValueError, HTTPException) as exc:
        print(f"Failed to download example image {url}: {exc}")
        if os.path.exists(destination):
            os.remove(destination)
        return False
    return True


print("Creating examples directory and downloading example images...")
os.makedirs("examples", exist_ok=True)
available_examples = []
for example_url, example_path in _EXAMPLE_DOWNLOADS:
    if _download_example(example_url, example_path):
        available_examples.append([example_path])
if len(available_examples) == len(_EXAMPLE_DOWNLOADS):
    print("Example images downloaded successfully.")
else:
    print(f"{len(available_examples)} of {len(_EXAMPLE_DOWNLOADS)} example images are available.")

# This sets up the web UI with inputs and outputs.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # AI Image Detector with Explainability
        Upload an image to determine if it was generated by AI or created by a human.
        This tool uses the [Organika/sdxl-detector](https://huggingface.co/Organika/sdxl-detector) model.
        In addition to the prediction, it provides a **heatmap** to show *why* the model made its decision, highlighting the areas it found most suspicious or authentic.
        """
    )
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Upload Image")
            submit_btn = gr.Button("Analyze Image", variant="primary")
        with gr.Column():
            output_label = gr.Label(label="Prediction")
            output_text = gr.Textbox(label="Explanation", lines=6)
            output_heatmap = gr.Image(label="Model Attention Heatmap")

    submit_btn.click(
        fn=predict,
        inputs=input_image,
        outputs=[output_label, output_text, output_heatmap]
    )

    # Only offer examples that actually reached disk; a missing file would break
    # example caching. gr.Examples has no `allow_file_access` parameter — if
    # Gradio refuses to serve these files, whitelist the directory via
    # demo.launch(allowed_paths=[...]) instead.
    if available_examples:
        gr.Examples(
            examples=available_examples,
            inputs=input_image,
            outputs=[output_label, output_text, output_heatmap],
            fn=predict,
            cache_examples=True,  # run predict once up front so examples load instantly
        )


# --- 5. Launch the App ---
def main():
    """Start the Gradio server for ``demo`` with Gradio's debug mode enabled."""
    demo.launch(debug=True)


# Start the server only when executed as a script, not when imported.
if __name__ == "__main__":
    main()