Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -16,17 +16,25 @@ model = AutoModelForImageClassification.from_pretrained(model_id, torch_dtype=to
|
|
16 |
model.eval()
|
17 |
print("Model and processor loaded successfully.")
|
18 |
|
19 |
-
# --- 2.
|
20 |
def generate_heatmap(image_tensor, original_image, target_class_index):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
# This part is correct from our last fix.
|
22 |
target_layer = model.swin.layernorm
|
23 |
|
24 |
-
# Initialize LayerGradCam
|
25 |
-
|
|
|
26 |
|
27 |
-
#
|
28 |
-
# The 'baselines' argument is not used by LayerGradCam, so we remove it.
|
29 |
-
# The call is now simpler and correct for this specific method.
|
30 |
attributions = lgc.attribute(image_tensor, target=target_class_index, relu_attributions=True)
|
31 |
|
32 |
# The rest of the function remains the same.
|
|
|
16 |
model.eval()
|
17 |
print("Model and processor loaded successfully.")
|
18 |
|
19 |
+
# --- 2. Define the Explainability (Grad-CAM) Function ---
|
20 |
def generate_heatmap(image_tensor, original_image, target_class_index):
|
21 |
+
|
22 |
+
# --- THIS IS THE FIX ---
|
23 |
+
# We define a wrapper function that ensures our model returns a simple tensor,
|
24 |
+
# which is what Captum expects. It takes the model's output object and
|
25 |
+
# extracts the 'logits' tensor from it.
|
26 |
+
def model_forward_wrapper(input_tensor):
|
27 |
+
outputs = model(pixel_values=input_tensor)
|
28 |
+
return outputs.logits
|
29 |
+
|
30 |
# This part is correct from our last fix.
|
31 |
target_layer = model.swin.layernorm
|
32 |
|
33 |
+
# Initialize LayerGradCam, but pass our new wrapper function instead of the raw model.
|
34 |
+
# Captum will now use this wrapper to get the model's output.
|
35 |
+
lgc = LayerGradCam(model_forward_wrapper, target_layer)
|
36 |
|
37 |
+
# This call now works because `lgc` gets a proper tensor from our wrapper.
|
|
|
|
|
38 |
attributions = lgc.attribute(image_tensor, target=target_class_index, relu_attributions=True)
|
39 |
|
40 |
# The rest of the function remains the same.
|