clockclock committed on
Commit
ed558ff
·
verified ·
1 Parent(s): 3e5a622

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -12
app.py CHANGED
@@ -17,27 +17,21 @@ model.eval()
17
  print("Model and processor loaded successfully.")
18
 
19
  # --- 2. MODIFIED Define the Explainability (Grad-CAM) Function ---
20
- # This function generates the heatmap showing which parts of the image the model focused on.
21
  def generate_heatmap(image_tensor, original_image, target_class_index):
22
- # --- THIS IS THE FIX ---
23
- # The original code assumed a ConvNeXT model. This model is a Swin Transformer.
24
- # We now target the final layer normalization of the Swin Transformer's main body,
25
- # which is a standard and effective layer for Grad-CAM on this architecture.
26
  target_layer = model.swin.layernorm
27
 
28
  # Initialize LayerGradCam
29
  lgc = LayerGradCam(model, target_layer)
30
 
31
- # Generate attributions (the "importance" of each pixel)
32
- # The baselines are a reference point, typically a black image.
33
- baselines = torch.zeros_like(image_tensor)
34
- attributions = lgc.attribute(image_tensor, target=target_class_index, baselines=baselines, relu_attributions=True)
35
 
36
- # The output of LayerGradCam is a heatmap. We process it for visualization.
37
- # We take the mean across the color channels and format it correctly.
38
  heatmap = np.transpose(attributions.squeeze(0).cpu().detach().numpy(), (1, 2, 0))
39
 
40
- # Use Captum's visualization tool to overlay the heatmap on the original image.
41
  visualized_image, _ = viz.visualize_image_attr(
42
  heatmap,
43
  np.array(original_image),
 
17
  print("Model and processor loaded successfully.")
18
 
19
  # --- 2. MODIFIED Define the Explainability (Grad-CAM) Function ---
 
20
  def generate_heatmap(image_tensor, original_image, target_class_index):
21
+ # This part is correct from our last fix.
 
 
 
22
  target_layer = model.swin.layernorm
23
 
24
  # Initialize LayerGradCam
25
  lgc = LayerGradCam(model, target_layer)
26
 
27
+ # --- THIS IS THE FIX ---
28
+ # The 'baselines' argument is not used by LayerGradCam, so we remove it.
29
+ # The call is now simpler and correct for this specific method.
30
+ attributions = lgc.attribute(image_tensor, target=target_class_index, relu_attributions=True)
31
 
32
+ # The rest of the function remains the same.
 
33
  heatmap = np.transpose(attributions.squeeze(0).cpu().detach().numpy(), (1, 2, 0))
34
 
 
35
  visualized_image, _ = viz.visualize_image_attr(
36
  heatmap,
37
  np.array(original_image),