clockclock committed
Commit 304b7e6 · verified · 1 parent: 03f09e4

Update app.py

Files changed (1):
  1. app.py  +41 -34

app.py CHANGED
@@ -1,12 +1,13 @@
  import gradio as gr
  import torch
  from transformers import AutoModelForImageClassification, AutoImageProcessor
  from PIL import Image
  import numpy as np
  from captum.attr import LayerGradCam
  from captum.attr import visualization as viz
- import requests  # <-- Import requests
- from io import BytesIO  # <-- Import BytesIO

  # --- 1. Load Model and Processor ---
  print("Loading model and processor...")
@@ -16,32 +17,48 @@ model = AutoModelForImageClassification.from_pretrained(model_id, torch_dtype=to
  model.eval()
  print("Model and processor loaded successfully.")

- # --- 2. Define the Explainability (Grad-CAM) Function ---
- def generate_heatmap(image_tensor, original_image, target_class_index):
-
-     # --- THIS IS THE FIX ---
-     # We define a wrapper function that ensures our model returns a simple tensor,
-     # which is what Captum expects. It takes the model's output object and
-     # extracts the 'logits' tensor from it.
      def model_forward_wrapper(input_tensor):
          outputs = model(pixel_values=input_tensor)
          return outputs.logits

-     # This part is correct from our last fix.
      target_layer = model.swin.layernorm
-
-     # Initialize LayerGradCam, but pass our new wrapper function instead of the raw model.
-     # Captum will now use this wrapper to get the model's output.
      lgc = LayerGradCam(model_forward_wrapper, target_layer)

-     # This call now works because `lgc` gets a proper tensor from our wrapper.
      attributions = lgc.attribute(image_tensor, target=target_class_index, relu_attributions=True)

-     # The rest of the function remains the same.
-     heatmap = np.transpose(attributions.squeeze(0).cpu().detach().numpy(), (1, 2, 0))
-
      visualized_image, _ = viz.visualize_image_attr(
-         heatmap,
          np.array(original_image),
          method="blended_heat_map",
          sign="all",
@@ -50,24 +67,21 @@ def generate_heatmap(image_tensor, original_image, target_class_index):
      )
      return visualized_image

- # --- 3. MODIFIED Main Prediction Function ---
- # Now it accepts two inputs: an uploaded image and a URL string.
  def predict(image_upload: Image.Image, image_url: str):
-
-     # --- Logic to decide which input to use ---
      if image_upload is not None:
          input_image = image_upload
          print(f"Processing uploaded image of size: {input_image.size}")
      elif image_url:
          try:
              response = requests.get(image_url)
-             response.raise_for_status()  # Raise an exception for bad status codes
              input_image = Image.open(BytesIO(response.content))
              print(f"Processing image from URL: {image_url}")
          except Exception as e:
              raise gr.Error(f"Could not load image from URL. Please check the link. Error: {e}")
      else:
-         # If no input is provided, raise an error
          raise gr.Error("Please upload an image or provide a URL to analyze.")

      if input_image.mode == 'RGBA':
@@ -105,8 +119,8 @@ def predict(image_upload: Image.Image, image_url: str):

      return labels_dict, explanation, heatmap_image

- # --- 4. MODIFIED Gradio Interface ---
- # We use gr.Tabs to create separate input sections.
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
      gr.Markdown(
          """
@@ -117,30 +131,23 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
      )
      with gr.Row():
          with gr.Column():
-             # --- TABS for different input methods ---
              with gr.Tabs():
                  with gr.TabItem("Upload File"):
                      input_image_upload = gr.Image(type="pil", label="Upload Your Image")
                  with gr.TabItem("Use Image URL"):
                      input_image_url = gr.Textbox(label="Paste Image URL here")
-
              submit_btn = gr.Button("Analyze Image", variant="primary")
-
          with gr.Column():
              output_label = gr.Label(label="Prediction")
              output_text = gr.Textbox(label="Explanation", lines=6, interactive=False)
              output_heatmap = gr.Image(label="Model Attention Heatmap")

-     # The click event now passes both possible inputs to the predict function
      submit_btn.click(
          fn=predict,
          inputs=[input_image_upload, input_image_url],
          outputs=[output_label, output_text, output_heatmap]
      )
-
-     # We remove the examples for now to simplify, as they don't work well with a tabbed interface by default.
-     # If you want them back, you would need a more complex setup to handle which tab the example populates.
-
-     # --- 5. Launch the App ---
  if __name__ == "__main__":
      demo.launch(debug=True)
 
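The comments removed above capture the key Captum constraint: a Hugging Face model's forward pass returns a ModelOutput object, while Captum's attribution methods expect a plain tensor, hence the logits-extracting wrapper. A minimal self-contained sketch of the pattern; the checkpoint name and the dummy input are illustrative stand-ins, not taken from this commit:

import torch
from captum.attr import LayerGradCam
from transformers import AutoModelForImageClassification

# Stand-in checkpoint for illustration; app.py loads its own model_id.
model = AutoModelForImageClassification.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
model.eval()

def model_forward_wrapper(input_tensor):
    # HF models return a ModelOutput object; Captum needs the raw logits tensor.
    return model(pixel_values=input_tensor).logits

lgc = LayerGradCam(model_forward_wrapper, model.swin.layernorm)
attributions = lgc.attribute(torch.rand(1, 3, 224, 224), target=0, relu_attributions=True)

The updated file, as rendered in the diff, follows.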
  import gradio as gr
  import torch
+ import torch.nn.functional as F  # <-- ADD THIS IMPORT
  from transformers import AutoModelForImageClassification, AutoImageProcessor
  from PIL import Image
  import numpy as np
  from captum.attr import LayerGradCam
  from captum.attr import visualization as viz
+ import requests
+ from io import BytesIO

  # --- 1. Load Model and Processor ---
  print("Loading model and processor...")

  model.eval()
  print("Model and processor loaded successfully.")

+ # --- 2. FINAL, CORRECTED Explainability (Grad-CAM) Function ---
+ def generate_heatmap(image_tensor, original_image, target_class_index):
+     # This wrapper is correct and necessary for Captum to work with Hugging Face models.
      def model_forward_wrapper(input_tensor):
          outputs = model(pixel_values=input_tensor)
          return outputs.logits

+     # The target layer is also correct for the Swin Transformer.
      target_layer = model.swin.layernorm
+
+     # Initialize LayerGradCam with the wrapper and the target layer.
      lgc = LayerGradCam(model_forward_wrapper, target_layer)

+     # This call now works and returns the attributions.
      attributions = lgc.attribute(image_tensor, target=target_class_index, relu_attributions=True)

+     # --- THIS IS THE FIX for the Transformer Architecture ---
+     # Transformer models output a sequence of patch attributions, not a 2D grid.
+     # We must reshape this sequence into a grid and then upsample it.
+
+     # 1. Determine the grid size (e.g., for 49 patches, it's 7x7).
+     #    We remove the batch dimension and get the number of patches (sequence length).
+     num_patches = attributions.shape[-1]
+     grid_size = int(np.sqrt(num_patches))
+
+     # 2. Reshape the 1D attributions into a 2D grid.
+     heatmap = attributions.squeeze(0).squeeze(0).reshape(grid_size, grid_size)
+
+     # 3. Upsample the small heatmap to match the original image size for overlay.
+     #    We need to add batch and channel dimensions back for the interpolate function.
+     heatmap = heatmap.unsqueeze(0).unsqueeze(0)
+     # Note: original_image.size is (W, H); interpolate needs size as (H, W).
+     upsampled_heatmap = F.interpolate(heatmap, size=original_image.size[::-1], mode='bilinear', align_corners=False)
+
+     # 4. Prepare the final heatmap for visualization.
+     heatmap_for_viz = upsampled_heatmap.squeeze().cpu().detach().numpy()
+
+     # The visualization function expects an (H, W, C) shaped numpy array.
+     # Our heatmap is (H, W), so we add a channel dimension.
      visualized_image, _ = viz.visualize_image_attr(
+         np.expand_dims(heatmap_for_viz, axis=-1),
          np.array(original_image),
          method="blended_heat_map",
          sign="all",

      )
      return visualized_image
 
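The sequence-to-grid fix above is easy to sanity-check in isolation. A minimal sketch with a dummy attribution tensor; the 49-patch (7x7) grid and the 224x224 target size are assumptions chosen to match the commit's own example:

import torch
import torch.nn.functional as F
import numpy as np

# Dummy Grad-CAM output shaped (batch, 1, num_patches): 49 patches -> 7x7 grid.
attributions = torch.rand(1, 1, 49)

grid_size = int(np.sqrt(attributions.shape[-1]))                            # 7
heatmap = attributions.squeeze(0).squeeze(0).reshape(grid_size, grid_size)  # (7, 7)
heatmap = heatmap.unsqueeze(0).unsqueeze(0)                                 # (1, 1, 7, 7)

# PIL reports size as (W, H); F.interpolate wants (H, W), hence the [::-1] in app.py.
upsampled = F.interpolate(heatmap, size=(224, 224), mode="bilinear", align_corners=False)
print(upsampled.shape)  # torch.Size([1, 1, 224, 224])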
+
+ # --- 3. Main Prediction Function (Unchanged) ---
  def predict(image_upload: Image.Image, image_url: str):
      if image_upload is not None:
          input_image = image_upload
          print(f"Processing uploaded image of size: {input_image.size}")
      elif image_url:
          try:
              response = requests.get(image_url)
+             response.raise_for_status()
              input_image = Image.open(BytesIO(response.content))
              print(f"Processing image from URL: {image_url}")
          except Exception as e:
              raise gr.Error(f"Could not load image from URL. Please check the link. Error: {e}")
      else:
          raise gr.Error("Please upload an image or provide a URL to analyze.")

      if input_image.mode == 'RGBA':
 
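The URL branch of predict (unchanged in this commit) relies on raise_for_status() to turn 4xx/5xx responses into exceptions before PIL ever touches the bytes. The same pattern as a standalone helper; the function name and the timeout argument are illustrative additions, not part of the commit:

import requests
from io import BytesIO
from PIL import Image

def load_image_from_url(url: str) -> Image.Image:
    response = requests.get(url, timeout=10)  # timeout is an extra safeguard
    response.raise_for_status()               # raises requests.HTTPError on 4xx/5xx
    return Image.open(BytesIO(response.content))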

      return labels_dict, explanation, heatmap_image

+
+ # --- 4. Gradio Interface (Unchanged) ---
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
      gr.Markdown(
          """

      )
      with gr.Row():
          with gr.Column():
              with gr.Tabs():
                  with gr.TabItem("Upload File"):
                      input_image_upload = gr.Image(type="pil", label="Upload Your Image")
                  with gr.TabItem("Use Image URL"):
                      input_image_url = gr.Textbox(label="Paste Image URL here")
              submit_btn = gr.Button("Analyze Image", variant="primary")
          with gr.Column():
              output_label = gr.Label(label="Prediction")
              output_text = gr.Textbox(label="Explanation", lines=6, interactive=False)
              output_heatmap = gr.Image(label="Model Attention Heatmap")

      submit_btn.click(
          fn=predict,
          inputs=[input_image_upload, input_image_url],
          outputs=[output_label, output_text, output_heatmap]
      )
+
+ # --- 5. Launch the App (Unchanged) ---
  if __name__ == "__main__":
      demo.launch(debug=True)
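The old version's closing comments note that gr.Examples was dropped because examples do not pair cleanly with a tabbed input by default. One minimal way to restore them, assuming the examples should populate only the upload tab; the file paths are placeholders:

import gradio as gr

with gr.Blocks() as demo:
    input_image_upload = gr.Image(type="pil", label="Upload Your Image")
    # Placeholder paths; clicking an example fills only the upload component.
    gr.Examples(
        examples=["examples/sample1.jpg", "examples/sample2.jpg"],
        inputs=input_image_upload,
    )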