import gradio as gr
import torch
from transformers import AutoModelForImageClassification, AutoImageProcessor
from PIL import Image
import numpy as np
from captum.attr import LayerGradCam
from captum.attr import visualization as viz
import requests  # used to download images from URLs
from io import BytesIO  # wraps downloaded bytes so PIL can open them

# --- 1. Load Model and Processor ---
print("Loading model and processor...")
model_id = "Organika/sdxl-detector"
processor = AutoImageProcessor.from_pretrained(model_id)
model = AutoModelForImageClassification.from_pretrained(model_id, torch_dtype=torch.float32)
model.eval()
print("Model and processor loaded successfully.")

# --- 2. Define the Explainability (Grad-CAM) Function ---
def generate_heatmap(image_tensor, original_image, target_class_index):
    """Generate a Grad-CAM heatmap showing which image regions drove the prediction."""

    # Captum expects the forward function to return a plain tensor, but Hugging Face
    # models return a ModelOutput object. This wrapper calls the model and extracts
    # the 'logits' tensor so LayerGradCam can work with it directly.
    def model_forward_wrapper(input_tensor):
        outputs = model(pixel_values=input_tensor)
        return outputs.logits

    # Use the final layer norm of the Swin backbone as the Grad-CAM target layer.
    target_layer = model.swin.layernorm

    # Initialize LayerGradCam, but pass our new wrapper function instead of the raw model.
    # Captum will now use this wrapper to get the model's output.
    lgc = LayerGradCam(model_forward_wrapper, target_layer)

    # Compute Grad-CAM attributions for the target class, keeping only positive contributions.
    attributions = lgc.attribute(image_tensor, target=target_class_index, relu_attributions=True)

    # Reshape the attributions into an (H, W, C) array for visualization.
    heatmap = np.transpose(attributions.squeeze(0).cpu().detach().numpy(), (1, 2, 0))

    visualized_image, _ = viz.visualize_image_attr(
        heatmap,
        np.array(original_image),
        method="blended_heat_map",
        sign="all",
        show_colorbar=True,
        title="Model Attention Heatmap",
    )
    return visualized_image

# --- 3. Main Prediction Function ---
# Accepts two inputs: an uploaded image and a URL string, and uses whichever is provided.
def predict(image_upload: Image.Image, image_url: str):
    
    # --- Logic to decide which input to use ---
    if image_upload is not None:
        input_image = image_upload
        print(f"Processing uploaded image of size: {input_image.size}")
    elif image_url:
        try:
            response = requests.get(image_url, timeout=10)  # timeout avoids hanging on unresponsive hosts
            response.raise_for_status()  # raise an exception for bad status codes
            input_image = Image.open(BytesIO(response.content))
            print(f"Processing image from URL: {image_url}")
        except Exception as e:
            raise gr.Error(f"Could not load image from URL. Please check the link. Error: {e}")
    else:
        # If no input is provided, raise an error
        raise gr.Error("Please upload an image or provide a URL to analyze.")

    # Normalize any non-RGB mode (RGBA, palette, grayscale) before preprocessing.
    if input_image.mode != 'RGB':
        input_image = input_image.convert('RGB')

    inputs = processor(images=input_image, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits

    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    predicted_class_idx = logits.argmax(-1).item()
    confidence_score = probabilities[0][predicted_class_idx].item()
    predicted_label = model.config.id2label[predicted_class_idx]

    if predicted_label.lower() == 'ai':
        explanation = (
            f"The model is {confidence_score:.2%} confident that this image is AI-GENERATED.\n\n"
            "The heatmap on the right highlights the areas that most influenced this decision. "
            "Pay close attention if these hotspots are on unnatural-looking features like hair, eyes, skin texture, or strange background details."
        )
    else:
        explanation = (
            f"The model is {confidence_score:.2%} confident that this image is HUMAN-MADE.\n\n"
            "The heatmap shows which areas the model found to be most 'natural'. "
            "These are likely well-formed, realistic features that AI models often struggle to replicate perfectly."
        )

    print("Generating heatmap...")
    heatmap_image = generate_heatmap(inputs['pixel_values'], input_image, predicted_class_idx)
    print("Heatmap generated.")

    labels_dict = {model.config.id2label[i]: float(probabilities[0][i]) for i in range(len(model.config.id2label))}
    
    return labels_dict, explanation, heatmap_image

# --- 4. Gradio Interface ---
# gr.Tabs creates separate input sections for uploading a file and pasting a URL.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # AI Image Detector with Explainability
        Determine if an image is AI-generated or human-made. Upload a file or paste a URL.
        This tool uses the [Organika/sdxl-detector](https://huggingface.co/Organika/sdxl-detector) model and provides a **heatmap** to show *why* the model made its decision.
        """
    )
    with gr.Row():
        with gr.Column():
            # --- TABS for different input methods ---
            with gr.Tabs():
                with gr.TabItem("Upload File"):
                    input_image_upload = gr.Image(type="pil", label="Upload Your Image")
                with gr.TabItem("Use Image URL"):
                    input_image_url = gr.Textbox(label="Paste Image URL here")

            submit_btn = gr.Button("Analyze Image", variant="primary")
        
        with gr.Column():
            output_label = gr.Label(label="Prediction")
            output_text = gr.Textbox(label="Explanation", lines=6, interactive=False)
            output_heatmap = gr.Image(label="Model Attention Heatmap")

    # The click event passes both possible inputs to the predict function.
    submit_btn.click(
        fn=predict,
        inputs=[input_image_upload, input_image_url],
        outputs=[output_label, output_text, output_heatmap]
    )
    
    # Examples are omitted here because they do not map cleanly onto a tabbed interface
    # by default: gr.Examples needs to know which input component (and therefore which tab)
    # each example should populate. One possible way to re-add them is sketched below.
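    # A minimal sketch of how examples could be re-added, assuming local example files
    # (hypothetical paths "examples/real_photo.jpg" and "examples/ai_render.png") exist
    # and should populate only the upload tab:
    #
    # gr.Examples(
    #     examples=["examples/real_photo.jpg", "examples/ai_render.png"],
    #     inputs=input_image_upload,
    #     label="Example Images",
    # )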
    
# --- 5. Launch the App ---
if __name__ == "__main__":
    demo.launch(debug=True)