clockclock committed on
Commit 1aabfb0 · verified · 1 Parent(s): 7b2e222

Create app.py

Files changed (1)
  1. app.py +153 -0
app.py ADDED
@@ -0,0 +1,153 @@
import gradio as gr
import torch
from transformers import AutoModelForImageClassification, AutoImageProcessor
from PIL import Image
import numpy as np
import os
from urllib.request import urlretrieve
# LayerAttribution provides the upsampling helper used to resize the Grad-CAM map.
from captum.attr import LayerGradCam, LayerAttribution
from captum.attr import visualization as viz
# The Agg canvas is used to render Captum's Matplotlib figure to a numpy array for Gradio.
from matplotlib.backends.backend_agg import FigureCanvasAgg

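# Assumed dependencies for this Space (not pinned in this commit): gradio, torch,
# transformers, captum, pillow, numpy, and matplotlib - list them in requirements.txt.
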
# --- 1. Load Model and Processor ---
# Load the pre-trained model and the image processor from Hugging Face.
# We explicitly set torch_dtype to float32 to ensure CPU compatibility.
print("Loading model and processor...")
model_id = "Organika/sdxl-detector"
processor = AutoImageProcessor.from_pretrained(model_id)
model = AutoModelForImageClassification.from_pretrained(model_id, torch_dtype=torch.float32)
model.eval()  # Set the model to evaluation mode
print("Model and processor loaded successfully.")

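# Sanity check (assumption: this checkpoint maps class ids to labels like 'ai' and
# 'human'; the explanation logic below relies on those names):
# print(model.config.id2label)
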
# --- 2. Define the Explainability (Grad-CAM) Function ---
# This function generates the heatmap showing which parts of the image the model focused on.
def generate_heatmap(image_tensor, original_image, target_class_index):
    # Captum expects the forward function to return a plain tensor, but Hugging Face
    # models return a ModelOutput dataclass, so we unwrap the logits here.
    def forward_func(pixel_values):
        return model(pixel_values=pixel_values).logits

    # LayerGradCam requires a specific layer to hook into. For ConvNeXT models (like this one),
    # a good choice is the final layer of the last stage of the encoder.
    target_layer = model.convnext.encoder.stages[-1].layers[-1].dwconv

    # Initialize LayerGradCam
    lgc = LayerGradCam(forward_func, target_layer)

    # Generate attributions (the "importance" of each spatial location).
    # relu_attributions=True keeps only the regions that contributed positively.
    attributions = lgc.attribute(image_tensor, target=target_class_index, relu_attributions=True)

    # The Grad-CAM map has the spatial resolution of the target layer, so we
    # upsample it to the original image size before overlaying it.
    attributions = LayerAttribution.interpolate(
        attributions, (original_image.height, original_image.width)
    )

    # Format the heatmap as (H, W, C) for visualization.
    heatmap = np.transpose(attributions.squeeze(0).cpu().detach().numpy(), (1, 2, 0))

    # Use Captum's visualization tool to overlay the heatmap on the original image.
    fig, _ = viz.visualize_image_attr(
        heatmap,
        np.array(original_image),
        method="blended_heat_map",
        sign="positive",
        show_colorbar=True,
        title="Model Attention Heatmap",
        use_pyplot=False,
    )

    # Render the Matplotlib figure to a numpy array so gr.Image can display it.
    canvas = FigureCanvasAgg(fig)
    canvas.draw()
    return np.asarray(canvas.buffer_rgba())

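# If a different (non-ConvNeXT) checkpoint is swapped in, the target_layer lookup
# above will raise an AttributeError. Inspecting the module tree, e.g. with
# print(model), helps pick an equivalent final convolutional layer to hook.
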
# --- 3. Define the Main Prediction Function ---
# This function will be called by Gradio every time a user uploads an image.
def predict(input_image: Image.Image):
    print(f"Received image of size: {input_image.size}")

    # Convert the image to RGB if needed (handles RGBA, grayscale, palette, etc.)
    if input_image.mode != 'RGB':
        input_image = input_image.convert('RGB')

    # Preprocess the image for the model
    inputs = processor(images=input_image, return_tensors="pt")

    # Make a prediction
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits

    # Convert logits to probabilities
    probabilities = torch.nn.functional.softmax(logits, dim=-1)

    # Get the predicted class index and the confidence score
    predicted_class_idx = logits.argmax(-1).item()
    confidence_score = probabilities[0][predicted_class_idx].item()

    # Get the label name (e.g., 'ai' or 'human')
    predicted_label = model.config.id2label[predicted_class_idx]

    # --- Generate a Human-Readable Explanation ---
    # The goal is to point out which parts of the image look less human.
    if predicted_label.lower() == 'ai':
        explanation = (
            f"The model is {confidence_score:.2%} confident that this image is AI-GENERATED.\n\n"
            "The heatmap on the right highlights the areas that most influenced this decision. "
            "Pay close attention if these hotspots fall on unnatural-looking features such as "
            "hair, eyes, skin texture, or strange background details."
        )
    else:
        explanation = (
            f"The model is {confidence_score:.2%} confident that this image is HUMAN-MADE.\n\n"
            "The heatmap shows which areas the model found to be most 'natural'. "
            "These are likely well-formed, realistic features that AI models often struggle to replicate perfectly."
        )

    # --- Generate the Heatmap ---
    # We call our Grad-CAM function to create the visualization.
    print("Generating heatmap...")
    heatmap_image = generate_heatmap(inputs['pixel_values'], input_image, predicted_class_idx)
    print("Heatmap generated.")

    # Return the classification labels, the text explanation, and the heatmap image.
    # The labels dictionary is for the gr.Label component.
    labels_dict = {model.config.id2label[i]: float(probabilities[0][i]) for i in range(len(model.config.id2label))}

    return labels_dict, explanation, heatmap_image

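# Quick local smoke test (assumes one of the example files downloaded below exists):
# labels, text, heat = predict(Image.open("examples/ai_example_1.png"))
# print(labels, text.splitlines()[0])
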
# --- 4. Create the Gradio Interface ---
# Download the example files first, so that gr.Examples can find (and cache) them
# when the interface is built below.
os.makedirs("examples", exist_ok=True)
urlretrieve("https://huggingface.co/Organika/sdxl-detector/resolve/main/ai.png", "examples/ai_example_1.png")
urlretrieve("https://huggingface.co/Organika/sdxl-detector/resolve/main/human.png", "examples/human_example_1.jpg")
urlretrieve("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/stable-diffusion-sdxl/sdxl-gfpgan-output.png", "examples/ai_example_2.png")

# This sets up the web UI with inputs and outputs.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # AI Image Detector with Explainability
        Upload an image to determine if it was generated by AI or created by a human.
        This tool uses the [Organika/sdxl-detector](https://huggingface.co/Organika/sdxl-detector) model.
        In addition to the prediction, it provides a **heatmap** to show *why* the model made its decision, highlighting the areas it found most suspicious or authentic.
        """
    )
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Upload Image")
            submit_btn = gr.Button("Analyze Image", variant="primary")
        with gr.Column():
            output_label = gr.Label(label="Prediction")
            output_text = gr.Textbox(label="Explanation", lines=6)
            output_heatmap = gr.Image(label="Model Attention Heatmap")

    submit_btn.click(
        fn=predict,
        inputs=input_image,
        outputs=[output_label, output_text, output_heatmap]
    )

    gr.Examples(
        examples=[
            ["examples/ai_example_1.png"],
            ["examples/human_example_1.jpg"],
            ["examples/ai_example_2.png"],
        ],
        inputs=input_image,
        outputs=[output_label, output_text, output_heatmap],
        fn=predict,
        cache_examples=True  # Speeds up demo loading
    )

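# Note: with cache_examples=True, Gradio runs predict on every example at startup,
# which can be slow on CPU (assumption: acceptable here; set False to skip caching).
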
# --- 5. Launch the App ---
if __name__ == "__main__":
    demo.launch(debug=True)  # debug=True lets you see print statements in the logs
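# Optional: on a shared CPU Space, demo.queue().launch(debug=True) queues
# concurrent requests instead of running them all at once.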