clockclock committed on
Commit 37ffee9 · verified · 1 Parent(s): 2b021a8

Update app.py

Files changed (1): app.py +41 -95
app.py CHANGED
@@ -5,82 +5,68 @@ from PIL import Image
 import numpy as np
 from captum.attr import LayerGradCam
 from captum.attr import visualization as viz
+import requests # <-- Import requests
+from io import BytesIO # <-- Import BytesIO
 
 # --- 1. Load Model and Processor ---
-# Load the pre-trained model and the image processor from Hugging Face.
-# We explicitly set torch_dtype to float32 to ensure CPU compatibility.
 print("Loading model and processor...")
 model_id = "Organika/sdxl-detector"
 processor = AutoImageProcessor.from_pretrained(model_id)
 model = AutoModelForImageClassification.from_pretrained(model_id, torch_dtype=torch.float32)
-model.eval() # Set the model to evaluation mode
+model.eval()
 print("Model and processor loaded successfully.")
 
 # --- 2. Define the Explainability (Grad-CAM) Function ---
-# This function generates the heatmap showing which parts of the image the model focused on.
 def generate_heatmap(image_tensor, original_image, target_class_index):
-    # LayerGradCam requires a specific layer to hook into. For ConvNeXT models (like this one),
-    # a good choice is the final layer of the last stage of the encoder.
     target_layer = model.convnext.encoder.stages[-1].layers[-1].dwconv
-
-    # Initialize LayerGradCam
     lgc = LayerGradCam(model, target_layer)
-
-    # Generate attributions (the "importance" of each pixel)
-    # The baselines are a reference point, typically a black image.
     baselines = torch.zeros_like(image_tensor)
     attributions = lgc.attribute(image_tensor, target=target_class_index, baselines=baselines, relu_attributions=True)
-
-    # The output of LayerGradCam is a heatmap. We process it for visualization.
-    # We take the mean across the color channels and format it correctly.
     heatmap = np.transpose(attributions.squeeze(0).cpu().detach().numpy(), (1, 2, 0))
-
-    # Use Captum's visualization tool to overlay the heatmap on the original image.
     visualized_image, _ = viz.visualize_image_attr(
-        heatmap,
-        np.array(original_image),
-        method="blended_heat_map",
-        sign="all",
-        show_colorbar=True,
-        title="Model Attention Heatmap",
+        heatmap, np.array(original_image), method="blended_heat_map", sign="all", show_colorbar=True, title="Model Attention Heatmap",
     )
     return visualized_image
 
-# --- 3. Define the Main Prediction Function ---
-# This function will be called by Gradio every time a user uploads an image.
-def predict(input_image: Image.Image):
-    print(f"Received image of size: {input_image.size}")
+# --- 3. MODIFIED Main Prediction Function ---
+# Now it accepts two inputs: an uploaded image and a URL string.
+def predict(image_upload: Image.Image, image_url: str):
+
+    # --- Logic to decide which input to use ---
+    if image_upload is not None:
+        input_image = image_upload
+        print(f"Processing uploaded image of size: {input_image.size}")
+    elif image_url:
+        try:
+            response = requests.get(image_url)
+            response.raise_for_status() # Raise an exception for bad status codes
+            input_image = Image.open(BytesIO(response.content))
+            print(f"Processing image from URL: {image_url}")
+        except Exception as e:
+            raise gr.Error(f"Could not load image from URL. Please check the link. Error: {e}")
+    else:
+        # If no input is provided, raise an error
+        raise gr.Error("Please upload an image or provide a URL to analyze.")
 
-    # Convert image to RGB if it has an alpha channel
     if input_image.mode == 'RGBA':
         input_image = input_image.convert('RGB')
 
-    # Preprocess the image for the model
     inputs = processor(images=input_image, return_tensors="pt")
 
-    # Make a prediction
     with torch.no_grad():
         outputs = model(**inputs)
         logits = outputs.logits
 
-    # Convert logits to probabilities
     probabilities = torch.nn.functional.softmax(logits, dim=-1)
-
-    # Get the predicted class index and the confidence score
     predicted_class_idx = logits.argmax(-1).item()
     confidence_score = probabilities[0][predicted_class_idx].item()
-
-    # Get the label name (e.g., 'ai' or 'human')
     predicted_label = model.config.id2label[predicted_class_idx]
 
-    # --- Generate Human-Readable Explanation ---
-    # This directly answers your requirement to "say out which one is less human".
     if predicted_label.lower() == 'ai':
         explanation = (
             f"The model is {confidence_score:.2%} confident that this image is AI-GENERATED.\n\n"
             "The heatmap on the right highlights the areas that most influenced this decision. "
-            "According to your research, pay close attention if these hotspots are on "
-            "unnatural-looking features like hair, eyes, skin texture, or strange background details."
+            "Pay close attention if these hotspots are on unnatural-looking features like hair, eyes, skin texture, or strange background details."
         )
     else:
         explanation = (
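
Note (not part of the commit): Captum's attribution classes call the forward function directly and expect a plain logits tensor, while Hugging Face classifiers return a ModelOutput object by default; LayerGradCam attributions also come back at the feature-map resolution rather than the input resolution, and, unlike Captum's path-attribution methods, LayerGradCam takes no baselines argument. A minimal sketch of the usual adjustments, assuming the same model and target layer as in the hunk above (the forward_logits wrapper name is illustrative):

from captum.attr import LayerAttribution, LayerGradCam

# Wrap the model so Captum receives a plain logits tensor, not a ModelOutput.
def forward_logits(pixel_values):
    return model(pixel_values).logits

lgc = LayerGradCam(forward_logits, model.convnext.encoder.stages[-1].layers[-1].dwconv)
attributions = lgc.attribute(image_tensor, target=target_class_index, relu_attributions=True)

# Upsample the coarse feature-map attribution to the input size so the
# blended overlay aligns with the original image.
attributions = LayerAttribution.interpolate(attributions, image_tensor.shape[-2:], interpolate_mode="bilinear")
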
@@ -89,90 +75,50 @@ def predict(input_image: Image.Image):
             "These are likely well-formed, realistic features that AI models often struggle to replicate perfectly."
         )
 
-    # --- Generate the Heatmap ---
-    # We call our Grad-CAM function to create the visualization.
     print("Generating heatmap...")
     heatmap_image = generate_heatmap(inputs['pixel_values'], input_image, predicted_class_idx)
     print("Heatmap generated.")
 
-    # Return the classification labels, the text explanation, and the heatmap image
-    # The labels dictionary is for the gr.Label component.
     labels_dict = {model.config.id2label[i]: float(probabilities[0][i]) for i in range(len(model.config.id2label))}
 
     return labels_dict, explanation, heatmap_image
 
-# --- 4. Create the Gradio Interface ---
-# This sets up the web UI with inputs and outputs.
+# --- 4. MODIFIED Gradio Interface ---
+# We use gr.Tabs to create separate input sections.
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
     gr.Markdown(
         """
         # AI Image Detector with Explainability
-        Upload an image to determine if it was generated by AI or created by a human.
-        This tool uses the [Organika/sdxl-detector](https://huggingface.co/Organika/sdxl-detector) model.
-        In addition to the prediction, it provides a **heatmap** to show *why* the model made its decision, highlighting the areas it found most suspicious or authentic.
+        Determine if an image is AI-generated or human-made. Upload a file or paste a URL.
+        This tool uses the [Organika/sdxl-detector](https://huggingface.co/Organika/sdxl-detector) model and provides a **heatmap** to show *why* the model made its decision.
         """
     )
     with gr.Row():
         with gr.Column():
-            input_image = gr.Image(type="pil", label="Upload Image")
+            # --- TABS for different input methods ---
+            with gr.Tabs():
+                with gr.TabItem("Upload File"):
+                    input_image_upload = gr.Image(type="pil", label="Upload Your Image")
+                with gr.TabItem("Use Image URL"):
+                    input_image_url = gr.Textbox(label="Paste Image URL here")
+
             submit_btn = gr.Button("Analyze Image", variant="primary")
+
         with gr.Column():
             output_label = gr.Label(label="Prediction")
-            output_text = gr.Textbox(label="Explanation", lines=6)
+            output_text = gr.Textbox(label="Explanation", lines=6, interactive=False)
             output_heatmap = gr.Image(label="Model Attention Heatmap")
 
+    # The click event now passes both possible inputs to the predict function
     submit_btn.click(
         fn=predict,
-        inputs=input_image,
+        inputs=[input_image_upload, input_image_url],
         outputs=[output_label, output_text, output_heatmap]
     )
 
-    gr.Examples(
-        examples=[
-            ["examples/ai_example_1.png"],
-            ["examples/human_example_1.jpg"],
-            ["examples/ai_example_2.png"],
-        ],
-        inputs=input_image,
-        outputs=[output_label, output_text, output_heatmap],
-        fn=predict,
-        cache_examples=True, # Speeds up demo loading
-        # Add this line to grant permission for local files
-        allow_file_access=True
-    )
-
-# --- Create example files for the demo ---
-import os
-from urllib.request import urlretrieve
-
-print("Creating examples directory and downloading example images...")
-os.makedirs("examples", exist_ok=True)
-
-# These URLs are from the stable Hugging Face documentation assets
-try:
-    # AI Example 1: A classic AI-generated image (astronaut on a horse)
-    urlretrieve(
-        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/horse.png",
-        "examples/ai_example_1.png"
-    )
-
-    # Human Example 1: A real photograph
-    urlretrieve(
-        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/gradio-guide/zookeeper.png",
-        "examples/human_example_1.jpg"
-    )
-
-    # AI Example 2: An AI-generated portrait, good for testing face/hair detection
-    urlretrieve(
-        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/stable-diffusion-sdxl/sdxl-gfpgan-output.png",
-        "examples/ai_example_2.png"
-    )
-    print("Example images downloaded successfully.")
-except Exception as e:
-    print(f"Failed to download example images: {e}")
-
+    # We remove the examples for now to simplify, as they don't work well with a tabbed interface by default.
+    # If you want them back, you would need a more complex setup to handle which tab the example populates.
 
 # --- 5. Launch the App ---
-# This line was already there, just make sure it's the last part of your script
 if __name__ == "__main__":
     demo.launch(debug=True)
 
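Note (not part of the commit): if the examples are restored later, one way to make gr.Examples coexist with the tabbed inputs is to scope it to the upload tab and let it only populate the image component; the user then clicks "Analyze Image", so predict() still receives both inputs. A sketch under that assumption, reusing the old placeholder file paths:

with gr.TabItem("Upload File"):
    input_image_upload = gr.Image(type="pil", label="Upload Your Image")
    # With no fn/outputs wired up, clicking an example only fills the
    # image input; it never calls predict() with a missing URL argument.
    gr.Examples(
        examples=[["examples/ai_example_1.png"], ["examples/human_example_1.jpg"], ["examples/ai_example_2.png"]],
        inputs=input_image_upload,
    )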