Update app.py
Browse files
app.py
CHANGED
@@ -424,10 +424,6 @@ def process_image_with_gradcam(image, model, device, pred_class):
|
|
424 |
|
425 |
# ----- BLIP Image Captioning -----
|
426 |
|
427 |
-
# Define conditional prompts for BLIP
|
428 |
-
ORIGINAL_IMAGE_PROMPT = "an image of" # For the original image
|
429 |
-
GRADCAM_IMAGE_PROMPT = "a heatmap showing" # For the GradCAM visualization
|
430 |
-
|
431 |
# Function to load BLIP captioning model
|
432 |
@st.cache_resource
|
433 |
def load_blip_model():
|
@@ -443,33 +439,39 @@ def load_blip_model():
|
|
443 |
# Function to generate image caption using BLIP
|
444 |
def generate_image_caption(image, processor, model, is_gradcam=False, max_length=75, num_beams=5):
|
445 |
"""
|
446 |
-
Generate a caption for the input image using BLIP model
|
447 |
"""
|
448 |
try:
|
449 |
-
# Select the appropriate prompt based on image type
|
450 |
-
conditional_prompt = GRADCAM_IMAGE_PROMPT if is_gradcam else ORIGINAL_IMAGE_PROMPT
|
451 |
-
|
452 |
# Check for available GPU
|
453 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
454 |
model = model.to(device)
|
455 |
|
456 |
-
#
|
457 |
-
|
458 |
-
|
459 |
-
|
460 |
-
|
|
|
|
|
|
|
461 |
|
462 |
-
#
|
463 |
-
|
464 |
-
|
465 |
|
466 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
467 |
if is_gradcam:
|
468 |
-
|
469 |
else:
|
470 |
-
|
471 |
-
|
472 |
-
return full_info
|
473 |
except Exception as e:
|
474 |
st.error(f"Error generating caption: {str(e)}")
|
475 |
return "Error generating caption"
|
@@ -675,14 +677,18 @@ def main():
|
|
675 |
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
|
676 |
|
677 |
if uploaded_file is not None:
|
678 |
-
# Display the uploaded image
|
679 |
try:
|
|
|
680 |
image = Image.open(uploaded_file).convert("RGB")
|
681 |
-
|
|
|
|
|
|
|
|
|
682 |
|
683 |
# Generate detailed caption for original image if BLIP model is loaded
|
684 |
if st.session_state.blip_model_loaded:
|
685 |
-
with st.spinner("Generating
|
686 |
caption = generate_image_caption(
|
687 |
image,
|
688 |
st.session_state.blip_processor,
|
@@ -690,11 +696,8 @@ def main():
|
|
690 |
is_gradcam=False
|
691 |
)
|
692 |
st.session_state.image_caption = caption
|
693 |
-
st.success(f"📝 Image Description Generated")
|
694 |
|
695 |
-
#
|
696 |
-
st.markdown("### Image Description:")
|
697 |
-
st.markdown(caption)
|
698 |
|
699 |
# Detect with CLIP model if loaded
|
700 |
if st.session_state.clip_model_loaded:
|
@@ -728,11 +731,12 @@ def main():
|
|
728 |
pred_label = "Fake" if pred_class == 1 else "Real"
|
729 |
|
730 |
# Display results
|
731 |
-
|
732 |
-
|
733 |
-
|
734 |
-
|
735 |
-
|
|
|
736 |
|
737 |
# GradCAM visualization
|
738 |
st.subheader("GradCAM Visualization")
|
@@ -740,8 +744,8 @@ def main():
|
|
740 |
image, model, device, pred_class
|
741 |
)
|
742 |
|
743 |
-
# Display GradCAM results
|
744 |
-
st.image(comparison, caption="Original | CAM | Overlay",
|
745 |
|
746 |
# Generate caption for GradCAM overlay image if BLIP model is loaded
|
747 |
if st.session_state.blip_model_loaded:
|
@@ -754,11 +758,8 @@ def main():
|
|
754 |
max_length=150 # Longer for detailed analysis
|
755 |
)
|
756 |
st.session_state.gradcam_caption = gradcam_caption
|
757 |
-
st.success("✅ GradCAM analysis complete")
|
758 |
|
759 |
-
#
|
760 |
-
st.markdown("### GradCAM Analysis:")
|
761 |
-
st.markdown(gradcam_caption)
|
762 |
|
763 |
# Save results in session state for LLM analysis
|
764 |
st.session_state.current_image = image
|
@@ -854,10 +855,10 @@ def main():
|
|
854 |
col1, col2 = st.columns([1, 2])
|
855 |
|
856 |
with col1:
|
857 |
-
# Display original image and overlay side by side
|
858 |
-
st.image(st.session_state.current_image, caption="Original Image",
|
859 |
if hasattr(st.session_state, 'current_overlay'):
|
860 |
-
st.image(st.session_state.current_overlay, caption="GradCAM Overlay",
|
861 |
|
862 |
with col2:
|
863 |
# Detection result
|
|
|
424 |
|
425 |
# ----- BLIP Image Captioning -----
|
426 |
|
|
|
|
|
|
|
|
|
427 |
# Function to load BLIP captioning model
|
428 |
@st.cache_resource
|
429 |
def load_blip_model():
|
|
|
439 |
# Function to generate image caption using BLIP
|
440 |
def generate_image_caption(image, processor, model, is_gradcam=False, max_length=75, num_beams=5):
|
441 |
"""
|
442 |
+
Generate a caption for the input image using BLIP model
|
443 |
"""
|
444 |
try:
|
|
|
|
|
|
|
445 |
# Check for available GPU
|
446 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
447 |
model = model.to(device)
|
448 |
|
449 |
+
# Choose the right prompting method based on image type
|
450 |
+
if is_gradcam:
|
451 |
+
# For GradCAM, use conditional captioning with a specific prompt
|
452 |
+
text = "a heatmap showing"
|
453 |
+
inputs = processor(image, text, return_tensors="pt").to(device)
|
454 |
+
else:
|
455 |
+
# For original image, use unconditional captioning (works better for portraits)
|
456 |
+
inputs = processor(image, return_tensors="pt").to(device)
|
457 |
|
458 |
+
# Generate caption
|
459 |
+
with torch.no_grad():
|
460 |
+
output = model.generate(**inputs, max_length=max_length, num_beams=num_beams)
|
461 |
|
462 |
+
# Decode the output
|
463 |
+
caption = processor.decode(output[0], skip_special_tokens=True)
|
464 |
+
|
465 |
+
# Remove the prompt from the beginning if it appears (for conditional captioning)
|
466 |
+
if is_gradcam and "a heatmap showing" in caption:
|
467 |
+
caption = caption.replace("a heatmap showing", "").strip()
|
468 |
+
|
469 |
+
# Format based on image type
|
470 |
if is_gradcam:
|
471 |
+
return format_gradcam_caption(caption)
|
472 |
else:
|
473 |
+
return format_image_caption(caption)
|
474 |
+
|
|
|
475 |
except Exception as e:
|
476 |
st.error(f"Error generating caption: {str(e)}")
|
477 |
return "Error generating caption"
|
|
|
677 |
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
|
678 |
|
679 |
if uploaded_file is not None:
|
|
|
680 |
try:
|
681 |
+
# Load and display the image (with controlled size)
|
682 |
image = Image.open(uploaded_file).convert("RGB")
|
683 |
+
|
684 |
+
# Display the image with a controlled width
|
685 |
+
col1, col2 = st.columns([1, 2])
|
686 |
+
with col1:
|
687 |
+
st.image(image, caption="Uploaded Image", width=300)
|
688 |
|
689 |
# Generate detailed caption for original image if BLIP model is loaded
|
690 |
if st.session_state.blip_model_loaded:
|
691 |
+
with st.spinner("Generating image description..."):
|
692 |
caption = generate_image_caption(
|
693 |
image,
|
694 |
st.session_state.blip_processor,
|
|
|
696 |
is_gradcam=False
|
697 |
)
|
698 |
st.session_state.image_caption = caption
|
|
|
699 |
|
700 |
+
# Store caption but don't display it here - it will be shown in the summary section
|
|
|
|
|
701 |
|
702 |
# Detect with CLIP model if loaded
|
703 |
if st.session_state.clip_model_loaded:
|
|
|
731 |
pred_label = "Fake" if pred_class == 1 else "Real"
|
732 |
|
733 |
# Display results
|
734 |
+
with col2:
|
735 |
+
result_col1, result_col2 = st.columns(2)
|
736 |
+
with result_col1:
|
737 |
+
st.metric("Prediction", pred_label)
|
738 |
+
with result_col2:
|
739 |
+
st.metric("Confidence", f"{confidence:.2%}")
|
740 |
|
741 |
# GradCAM visualization
|
742 |
st.subheader("GradCAM Visualization")
|
|
|
744 |
image, model, device, pred_class
|
745 |
)
|
746 |
|
747 |
+
# Display GradCAM results (controlled size)
|
748 |
+
st.image(comparison, caption="Original | CAM | Overlay", width=700)
|
749 |
|
750 |
# Generate caption for GradCAM overlay image if BLIP model is loaded
|
751 |
if st.session_state.blip_model_loaded:
|
|
|
758 |
max_length=150 # Longer for detailed analysis
|
759 |
)
|
760 |
st.session_state.gradcam_caption = gradcam_caption
|
|
|
761 |
|
762 |
+
# Store caption but don't display it here - it will be shown in the summary section
|
|
|
|
|
763 |
|
764 |
# Save results in session state for LLM analysis
|
765 |
st.session_state.current_image = image
|
|
|
855 |
col1, col2 = st.columns([1, 2])
|
856 |
|
857 |
with col1:
|
858 |
+
# Display original image and overlay side by side with controlled size
|
859 |
+
st.image(st.session_state.current_image, caption="Original Image", width=300)
|
860 |
if hasattr(st.session_state, 'current_overlay'):
|
861 |
+
st.image(st.session_state.current_overlay, caption="GradCAM Overlay", width=300)
|
862 |
|
863 |
with col2:
|
864 |
# Detection result
|