clockclock commited on
Commit
03f09e4
·
verified ·
1 Parent(s): ed558ff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -6
app.py CHANGED
@@ -16,17 +16,25 @@ model = AutoModelForImageClassification.from_pretrained(model_id, torch_dtype=to
16
  model.eval()
17
  print("Model and processor loaded successfully.")
18
 
19
- # --- 2. MODIFIED Define the Explainability (Grad-CAM) Function ---
20
  def generate_heatmap(image_tensor, original_image, target_class_index):
 
 
 
 
 
 
 
 
 
21
  # This part is correct from our last fix.
22
  target_layer = model.swin.layernorm
23
 
24
- # Initialize LayerGradCam
25
- lgc = LayerGradCam(model, target_layer)
 
26
 
27
- # --- THIS IS THE FIX ---
28
- # The 'baselines' argument is not used by LayerGradCam, so we remove it.
29
- # The call is now simpler and correct for this specific method.
30
  attributions = lgc.attribute(image_tensor, target=target_class_index, relu_attributions=True)
31
 
32
  # The rest of the function remains the same.
 
16
  model.eval()
17
  print("Model and processor loaded successfully.")
18
 
19
+ # --- 2. Define the Explainability (Grad-CAM) Function ---
20
  def generate_heatmap(image_tensor, original_image, target_class_index):
21
+
22
+ # --- THIS IS THE FIX ---
23
+ # We define a wrapper function that ensures our model returns a simple tensor,
24
+ # which is what Captum expects. It takes the model's output object and
25
+ # extracts the 'logits' tensor from it.
26
+ def model_forward_wrapper(input_tensor):
27
+ outputs = model(pixel_values=input_tensor)
28
+ return outputs.logits
29
+
30
  # This part is correct from our last fix.
31
  target_layer = model.swin.layernorm
32
 
33
+ # Initialize LayerGradCam, but pass our new wrapper function instead of the raw model.
34
+ # Captum will now use this wrapper to get the model's output.
35
+ lgc = LayerGradCam(model_forward_wrapper, target_layer)
36
 
37
+ # This call now works because `lgc` gets a proper tensor from our wrapper.
 
 
38
  attributions = lgc.attribute(image_tensor, target=target_class_index, relu_attributions=True)
39
 
40
  # The rest of the function remains the same.