clockclock committed on
Commit
3e5a622
·
verified ·
1 Parent(s): 663da4e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -3
app.py CHANGED
@@ -16,15 +16,35 @@ model = AutoModelForImageClassification.from_pretrained(model_id, torch_dtype=to
16
  model.eval()
17
  print("Model and processor loaded successfully.")
18
 
19
- # --- 2. Define the Explainability (Grad-CAM) Function ---
 
20
def generate_heatmap(image_tensor, original_image, target_class_index):
    """Generate a Grad-CAM heatmap and blend it over the original image.

    Args:
        image_tensor: preprocessed input tensor for the classifier
            (batch dim included — it is squeezed out of the attributions).
        original_image: the unprocessed PIL image to overlay the heatmap on.
        target_class_index: class index whose activation is explained.

    Returns:
        The matplotlib figure produced by Captum's ``visualize_image_attr``
        showing the blended heatmap.
    """
    # BUG FIX: the original targeted model.convnext.encoder.stages[-1].layers[-1].dwconv,
    # but the loaded model is a Swin Transformer, so that attribute path raises
    # AttributeError. Target the Swin backbone's final layer normalization instead.
    target_layer = model.swin.layernorm
    lgc = LayerGradCam(model, target_layer)
    # Baseline reference point: an all-zero (black) image of the same shape.
    baselines = torch.zeros_like(image_tensor)
    attributions = lgc.attribute(image_tensor, target=target_class_index, baselines=baselines, relu_attributions=True)
    # Drop the batch dimension and move channels last: (C, H, W) -> (H, W, C),
    # the layout visualize_image_attr expects.
    heatmap = np.transpose(attributions.squeeze(0).cpu().detach().numpy(), (1, 2, 0))
    visualized_image, _ = viz.visualize_image_attr(
        heatmap,
        np.array(original_image),
        method="blended_heat_map",
        sign="all",
        show_colorbar=True,
        title="Model Attention Heatmap",
    )
    return visualized_image
30
 
 
16
  model.eval()
17
  print("Model and processor loaded successfully.")
18
 
19
+ # --- 2. MODIFIED Define the Explainability (Grad-CAM) Function ---
20
+ # This function generates the heatmap showing which parts of the image the model focused on.
21
def generate_heatmap(image_tensor, original_image, target_class_index):
    """Overlay a Grad-CAM attention heatmap on *original_image*.

    Args:
        image_tensor: preprocessed model input (with batch dimension).
        original_image: the raw image the heatmap is blended onto.
        target_class_index: class index whose prediction is being explained.

    Returns:
        The figure returned by Captum's ``visualize_image_attr`` with the
        heatmap blended over the original image.
    """
    # Attribute against the final layer normalization of the Swin Transformer
    # backbone (this model is Swin, not ConvNeXT).
    # NOTE(review): Swin's layernorm activations are typically (batch, tokens,
    # channels) rather than (batch, C, H, W) — confirm LayerGradCam's output
    # shape is compatible with the (1, 2, 0) transpose below.
    grad_cam = LayerGradCam(model, model.swin.layernorm)

    # A black image of the input's shape serves as the attribution baseline.
    attributions = grad_cam.attribute(
        image_tensor,
        target=target_class_index,
        baselines=torch.zeros_like(image_tensor),
        relu_attributions=True,
    )

    # Drop the batch dimension and move channels last for visualization:
    # (C, H, W) -> (H, W, C). (No averaging happens here — just a transpose.)
    cam = attributions.squeeze(0).cpu().detach().numpy()
    heatmap = np.transpose(cam, (1, 2, 0))

    # Blend the heatmap onto the original image with Captum's visualizer.
    blended, _ = viz.visualize_image_attr(
        heatmap,
        np.array(original_image),
        method="blended_heat_map",
        sign="all",
        show_colorbar=True,
        title="Model Attention Heatmap",
    )
    return blended
50