Update app.py
app.py (CHANGED)
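This commit hardens the Xception detection path: `preprocess_image_xception()` and the detection block in `main()` are wrapped in try/except with `st.error`/`traceback` reporting, `st.write` tracing is added at each step, the model is explicitly put in eval mode, and the GradCAM overlay, face box, predicted label, and confidence are stored in `st.session_state` for the downstream LLM analysis.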
```diff
@@ -419,37 +419,52 @@ def analyze_image_with_llm(image, gradcam_overlay, face_box, pred_label, confidence
 # Preprocess image for Xception
 def preprocess_image_xception(image):
     """Preprocesses image for Xception model input and face detection."""
+    try:
+        st.write("Starting image preprocessing...")
+        face_detector = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
+        image_np = np.array(image.convert('RGB'))  # Ensure RGB
+        st.write(f"Image shape: {image_np.shape}")
+
+        gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
+        faces = face_detector.detectMultiScale(gray, 1.1, 5)
+
+        face_img_for_transform = image  # Default to whole image
+        face_box_display = None  # For drawing on original image
+
+        if len(faces) == 0:
+            st.warning("No face detected, using whole image for prediction/CAM.")
+        else:
+            areas = [w * h for (x, y, w, h) in faces]
+            largest_idx = np.argmax(areas)
+            x, y, w, h = faces[largest_idx]
+            st.write(f"Face detected at: x={x}, y={y}, w={w}, h={h}")
+
+            padding_x = int(w * 0.05)  # Use percentages as in gradcam_xception
+            padding_y = int(h * 0.05)
+            x1, y1 = max(0, x - padding_x), max(0, y - padding_y)
+            x2, y2 = min(image_np.shape[1], x + w + padding_x), min(image_np.shape[0], y + h + padding_y)
+
+            # Use the padded face region for the model transform
+            face_img_for_transform = Image.fromarray(image_np[y1:y2, x1:x2])
+            # Use the original detected box (without padding) for display rectangle
+            face_box_display = (x, y, w, h)
+
+        # Xception specific transform
+        transform = get_xception_transform()
+        # Apply transform to the selected region (face or whole image)
+        input_tensor = transform(face_img_for_transform).unsqueeze(0)
+        st.write(f"Tensor shape: {input_tensor.shape}")
+
+        # Return tensor, original full image, and the display face box
+        return input_tensor, image, face_box_display
+    except Exception as e:
+        st.error(f"Error in preprocessing image: {str(e)}")
+        import traceback
+        st.error(traceback.format_exc())
+        # Return defaults that won't break the pipeline
+        transform = get_xception_transform()
+        input_tensor = transform(image).unsqueeze(0)
+        return input_tensor, image, None
 
 # Main app
 def main():
```
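The new preprocessing path leans on `get_xception_transform()`, which sits outside this hunk. For readers following along, a plausible sketch of such a helper, assuming torchvision and the conventional Xception input contract (299×299 RGB, normalized to [-1, 1]); the actual definition in app.py may differ:

```python
# Hypothetical reconstruction -- the real get_xception_transform() is not in this diff.
# Xception models are conventionally fed 299x299 RGB tensors normalized to [-1, 1]
# (mean=0.5, std=0.5 per channel).
from torchvision import transforms

def get_xception_transform():
    return transforms.Compose([
        transforms.Resize((299, 299)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
```

Note also the Haar-cascade call: in `detectMultiScale(gray, 1.1, 5)` the positional arguments are `scaleFactor` (the image-pyramid step) and `minNeighbors` (how many overlapping detections a candidate needs to survive), so raising the 5 trades recall for fewer false-positive faces.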
```diff
@@ -555,69 +570,103 @@ def main():
 
         # Detect with Xception model if loaded
         if st.session_state.xception_model_loaded:
-            device = st.session_state.device
-            model = st.session_state.xception_model
-
-            # Move tensor to device
-            input_tensor = input_tensor.to(device)
-
-            # Forward pass
-            with torch.no_grad():
-                logits = model(input_tensor)
-                probabilities = torch.softmax(logits, dim=1)[0]
-                pred_class = torch.argmax(probabilities).item()
-                confidence = probabilities[pred_class].item()
-                pred_label = "Fake" if pred_class == 0 else "Real" # Check class mapping
-
-            # Display results
-            with col2:
-                st.markdown("### Detection Result")
-                st.markdown(f"**Classification:** {pred_label} (Confidence: {confidence:.2%})")
-
-                st.image(Image.fromarray(img_draw), caption="Detected Face", width=300)
-
-                # GradCAM visualization
-                st.subheader("GradCAM Visualization")
-                cam, overlay, comparison, detected_face_box = process_image_with_xception_gradcam(
-                    image, model, device, pred_class
-                )
-
-                if comparison:
-                    # Display GradCAM results (controlled size)
-                    st.image(comparison, caption="Original | CAM | Overlay", width=700)
+            try:
+                with st.spinner("Analyzing image with Xception model..."):
+                    # Preprocess image for Xception
+                    st.write("Starting Xception processing...")
+                    input_tensor, original_image, face_box = preprocess_image_xception(image)
+
+                    # Get device and model
+                    device = st.session_state.device
+                    model = st.session_state.xception_model
+
+                    # Ensure model is in eval mode
+                    model.eval()
+
+                    # Move tensor to device
+                    input_tensor = input_tensor.to(device)
+                    st.write(f"Input tensor on device: {device}")
+
+                    # Forward pass with proper error handling
+                    try:
+                        with torch.no_grad():
+                            st.write("Running model inference...")
+                            logits = model(input_tensor)
+                            st.write(f"Raw logits: {logits}")
+                            probabilities = torch.softmax(logits, dim=1)[0]
+                            st.write(f"Probabilities: {probabilities}")
+                            pred_class = torch.argmax(probabilities).item()
+                            confidence = probabilities[pred_class].item()
+                            st.write(f"Predicted class: {pred_class}, Confidence: {confidence:.4f}")
+
+                            # Explicit class mapping - adjust if needed based on your model
+                            pred_label = "Fake" if pred_class == 0 else "Real"
+                            st.write(f"Mapped to label: {pred_label}")
+                    except Exception as e:
+                        st.error(f"Error in model inference: {str(e)}")
+                        import traceback
+                        st.error(traceback.format_exc())
+                        # Set default values
+                        pred_class = 0
+                        confidence = 0.5
+                        pred_label = "Error in prediction"
+
+                    # Display results
+                    with col2:
+                        st.markdown("### Detection Result")
+                        st.markdown(f"**Classification:** {pred_label} (Confidence: {confidence:.2%})")
+
+                        # Display face box on image if detected
+                        if face_box:
+                            img_to_show = original_image.copy()
+                            img_draw = np.array(img_to_show)
+                            x, y, w, h = face_box
+                            cv2.rectangle(img_draw, (x, y), (x + w, y + h), (0, 255, 0), 2)
+                            st.image(Image.fromarray(img_draw), caption="Detected Face", width=300)
+
+                        # GradCAM visualization with error handling
+                        st.subheader("GradCAM Visualization")
+                        try:
+                            st.write("Generating GradCAM visualization...")
+                            cam, overlay, comparison, detected_face_box = process_image_with_xception_gradcam(
+                                image, model, device, pred_class
+                            )
+
+                            if comparison:
+                                # Display GradCAM results (controlled size)
+                                st.image(comparison, caption="Original | CAM | Overlay", width=700)
+
+                                # Save for later use
+                                st.session_state.comparison_image = comparison
+                            else:
+                                st.error("GradCAM visualization failed - comparison image not generated")
+
+                            # Generate caption for GradCAM overlay image if BLIP model is loaded
+                            if st.session_state.blip_model_loaded and overlay:
+                                with st.spinner("Analyzing GradCAM visualization..."):
+                                    gradcam_caption = generate_gradcam_caption(
+                                        overlay,
+                                        st.session_state.finetuned_processor,
+                                        st.session_state.finetuned_model
+                                    )
+                                    st.session_state.gradcam_caption = gradcam_caption
+                        except Exception as e:
+                            st.error(f"Error generating GradCAM: {str(e)}")
+                            import traceback
+                            st.error(traceback.format_exc())
+
+                        # Save results in session state for LLM analysis
+                        st.session_state.current_image = image
+                        st.session_state.current_overlay = overlay if 'overlay' in locals() else None
+                        st.session_state.current_face_box = detected_face_box if 'detected_face_box' in locals() else None
+                        st.session_state.current_pred_label = pred_label
+                        st.session_state.current_confidence = confidence
+
+                        st.success("✅ Initial detection and GradCAM visualization complete!")
+            except Exception as e:
+                st.error(f"Overall error in Xception processing: {str(e)}")
+                import traceback
+                st.error(traceback.format_exc())
         else:
             st.warning("⚠️ Please load the Xception model first to perform initial detection.")
     except Exception as e:
```
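Both versions carry the caveat that `pred_class == 0 → "Fake"` needs checking against the trained model. A quick way to validate the mapping is to run the model on a handful of labeled samples outside the app; a minimal sketch, with hypothetical file paths:

```python
# Hypothetical sanity check for the assumed class mapping (0 -> "Fake", 1 -> "Real").
# `model`, `device`, and get_xception_transform() come from the app; paths are placeholders.
import torch
from PIL import Image

def predicted_index(model, transform, path, device):
    x = transform(Image.open(path).convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        return model(x).argmax(dim=1).item()

# A known-fake sample should come back as 0 under the mapping above;
# if it comes back as 1, flip the label strings in the app.
# predicted_index(model, get_xception_transform(), "samples/known_fake.jpg", device)
```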
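Because the new code interleaves `st.write` tracing with the actual work, it is awkward to exercise from the command line. A standalone sketch of the same face-crop step, with the Streamlit calls stripped out, can serve as a smoke test (same cascade, padding, and whole-image fallback as the committed version):

```python
# Hypothetical standalone version of the face-crop step for testing outside Streamlit.
import cv2
import numpy as np
from PIL import Image

def crop_largest_face(image: Image.Image) -> Image.Image:
    detector = cv2.CascadeClassifier(
        cv2.data.haarcascades + "haarcascade_frontalface_default.xml")
    rgb = np.array(image.convert("RGB"))
    gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
    faces = detector.detectMultiScale(gray, 1.1, 5)
    if len(faces) == 0:
        return image  # same fallback as the app: use the whole frame
    x, y, w, h = max(faces, key=lambda f: f[2] * f[3])  # largest detection
    px, py = int(w * 0.05), int(h * 0.05)  # 5% padding, as in the diff
    x1, y1 = max(0, x - px), max(0, y - py)
    x2 = min(rgb.shape[1], x + w + px)
    y2 = min(rgb.shape[0], y + h + py)
    return Image.fromarray(rgb[y1:y2, x1:x2])

# usage: crop_largest_face(Image.open("face.jpg")).save("face_crop.jpg")
```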
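One pattern worth flagging in the new code: `overlay if 'overlay' in locals() else None` works, but it papers over the fact that the name is only bound if the GradCAM try block got past the unpacking line. Pre-binding the names before the block is the more conventional guard; a sketch of that alternative, using the same names as the diff:

```python
# Sketch: pre-bind the GradCAM outputs so the locals() checks become unnecessary.
overlay, detected_face_box, comparison = None, None, None
try:
    cam, overlay, comparison, detected_face_box = process_image_with_xception_gradcam(
        image, model, device, pred_class
    )
except Exception as e:
    st.error(f"Error generating GradCAM: {str(e)}")

# The session-state writes can then drop the 'in locals()' guards:
st.session_state.current_overlay = overlay
st.session_state.current_face_box = detected_face_box
```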