saakshigupta committed on
Commit
be65f5f
·
verified ·
1 Parent(s): 4048570

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -42
app.py CHANGED
@@ -424,10 +424,6 @@ def process_image_with_gradcam(image, model, device, pred_class):
424
 
425
  # ----- BLIP Image Captioning -----
426
 
427
- # Define conditional prompts for BLIP
428
- ORIGINAL_IMAGE_PROMPT = "an image of" # For the original image
429
- GRADCAM_IMAGE_PROMPT = "a heatmap showing" # For the GradCAM visualization
430
-
431
  # Function to load BLIP captioning model
432
  @st.cache_resource
433
  def load_blip_model():
@@ -443,33 +439,39 @@ def load_blip_model():
443
  # Function to generate image caption using BLIP
444
  def generate_image_caption(image, processor, model, is_gradcam=False, max_length=75, num_beams=5):
445
  """
446
- Generate a caption for the input image using BLIP model's conditional captioning
447
  """
448
  try:
449
- # Select the appropriate prompt based on image type
450
- conditional_prompt = GRADCAM_IMAGE_PROMPT if is_gradcam else ORIGINAL_IMAGE_PROMPT
451
-
452
  # Check for available GPU
453
  device = "cuda" if torch.cuda.is_available() else "cpu"
454
  model = model.to(device)
455
 
456
- # Get conditional caption
457
- conditional_inputs = processor(image, conditional_prompt, return_tensors="pt").to(device)
458
- with torch.no_grad():
459
- conditional_output = model.generate(**conditional_inputs, max_length=max_length, num_beams=num_beams)
460
- conditional_caption = processor.decode(conditional_output[0], skip_special_tokens=True)
 
 
 
461
 
462
- # Remove the prompt from the beginning if it appears
463
- if conditional_prompt in conditional_caption:
464
- conditional_caption = conditional_caption.replace(conditional_prompt, "").strip()
465
 
466
- # Format the caption based on image type
 
 
 
 
 
 
 
467
  if is_gradcam:
468
- full_info = format_gradcam_caption(conditional_caption)
469
  else:
470
- full_info = format_image_caption(conditional_caption)
471
-
472
- return full_info
473
  except Exception as e:
474
  st.error(f"Error generating caption: {str(e)}")
475
  return "Error generating caption"
@@ -675,14 +677,18 @@ def main():
675
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
676
 
677
  if uploaded_file is not None:
678
- # Display the uploaded image
679
  try:
 
680
  image = Image.open(uploaded_file).convert("RGB")
681
- st.image(image, caption="Uploaded Image", use_column_width=True)
 
 
 
 
682
 
683
  # Generate detailed caption for original image if BLIP model is loaded
684
  if st.session_state.blip_model_loaded:
685
- with st.spinner("Generating detailed image description..."):
686
  caption = generate_image_caption(
687
  image,
688
  st.session_state.blip_processor,
@@ -690,11 +696,8 @@ def main():
690
  is_gradcam=False
691
  )
692
  st.session_state.image_caption = caption
693
- st.success(f"📝 Image Description Generated")
694
 
695
- # Format the caption nicely
696
- st.markdown("### Image Description:")
697
- st.markdown(caption)
698
 
699
  # Detect with CLIP model if loaded
700
  if st.session_state.clip_model_loaded:
@@ -728,11 +731,12 @@ def main():
728
  pred_label = "Fake" if pred_class == 1 else "Real"
729
 
730
  # Display results
731
- result_col1, result_col2 = st.columns(2)
732
- with result_col1:
733
- st.metric("Prediction", pred_label)
734
- with result_col2:
735
- st.metric("Confidence", f"{confidence:.2%}")
 
736
 
737
  # GradCAM visualization
738
  st.subheader("GradCAM Visualization")
@@ -740,8 +744,8 @@ def main():
740
  image, model, device, pred_class
741
  )
742
 
743
- # Display GradCAM results
744
- st.image(comparison, caption="Original | CAM | Overlay", use_column_width=True)
745
 
746
  # Generate caption for GradCAM overlay image if BLIP model is loaded
747
  if st.session_state.blip_model_loaded:
@@ -754,11 +758,8 @@ def main():
754
  max_length=150 # Longer for detailed analysis
755
  )
756
  st.session_state.gradcam_caption = gradcam_caption
757
- st.success("✅ GradCAM analysis complete")
758
 
759
- # Format the GradCAM caption nicely
760
- st.markdown("### GradCAM Analysis:")
761
- st.markdown(gradcam_caption)
762
 
763
  # Save results in session state for LLM analysis
764
  st.session_state.current_image = image
@@ -854,10 +855,10 @@ def main():
854
  col1, col2 = st.columns([1, 2])
855
 
856
  with col1:
857
- # Display original image and overlay side by side
858
- st.image(st.session_state.current_image, caption="Original Image", use_column_width=True)
859
  if hasattr(st.session_state, 'current_overlay'):
860
- st.image(st.session_state.current_overlay, caption="GradCAM Overlay", use_column_width=True)
861
 
862
  with col2:
863
  # Detection result
 
424
 
425
  # ----- BLIP Image Captioning -----
426
 
 
 
 
 
427
  # Function to load BLIP captioning model
428
  @st.cache_resource
429
  def load_blip_model():
 
439
  # Function to generate image caption using BLIP
440
  def generate_image_caption(image, processor, model, is_gradcam=False, max_length=75, num_beams=5):
441
  """
442
+ Generate a caption for the input image using BLIP model
443
  """
444
  try:
 
 
 
445
  # Check for available GPU
446
  device = "cuda" if torch.cuda.is_available() else "cpu"
447
  model = model.to(device)
448
 
449
+ # Choose the right prompting method based on image type
450
+ if is_gradcam:
451
+ # For GradCAM, use conditional captioning with a specific prompt
452
+ text = "a heatmap showing"
453
+ inputs = processor(image, text, return_tensors="pt").to(device)
454
+ else:
455
+ # For original image, use unconditional captioning (works better for portraits)
456
+ inputs = processor(image, return_tensors="pt").to(device)
457
 
458
+ # Generate caption
459
+ with torch.no_grad():
460
+ output = model.generate(**inputs, max_length=max_length, num_beams=num_beams)
461
 
462
+ # Decode the output
463
+ caption = processor.decode(output[0], skip_special_tokens=True)
464
+
465
+ # Remove the prompt from the beginning if it appears (for conditional captioning)
466
+ if is_gradcam and "a heatmap showing" in caption:
467
+ caption = caption.replace("a heatmap showing", "").strip()
468
+
469
+ # Format based on image type
470
  if is_gradcam:
471
+ return format_gradcam_caption(caption)
472
  else:
473
+ return format_image_caption(caption)
474
+
 
475
  except Exception as e:
476
  st.error(f"Error generating caption: {str(e)}")
477
  return "Error generating caption"
 
677
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
678
 
679
  if uploaded_file is not None:
 
680
  try:
681
+ # Load and display the image (with controlled size)
682
  image = Image.open(uploaded_file).convert("RGB")
683
+
684
+ # Display the image with a controlled width
685
+ col1, col2 = st.columns([1, 2])
686
+ with col1:
687
+ st.image(image, caption="Uploaded Image", width=300)
688
 
689
  # Generate detailed caption for original image if BLIP model is loaded
690
  if st.session_state.blip_model_loaded:
691
+ with st.spinner("Generating image description..."):
692
  caption = generate_image_caption(
693
  image,
694
  st.session_state.blip_processor,
 
696
  is_gradcam=False
697
  )
698
  st.session_state.image_caption = caption
 
699
 
700
+ # Store caption but don't display it here - it will be shown in the summary section
 
 
701
 
702
  # Detect with CLIP model if loaded
703
  if st.session_state.clip_model_loaded:
 
731
  pred_label = "Fake" if pred_class == 1 else "Real"
732
 
733
  # Display results
734
+ with col2:
735
+ result_col1, result_col2 = st.columns(2)
736
+ with result_col1:
737
+ st.metric("Prediction", pred_label)
738
+ with result_col2:
739
+ st.metric("Confidence", f"{confidence:.2%}")
740
 
741
  # GradCAM visualization
742
  st.subheader("GradCAM Visualization")
 
744
  image, model, device, pred_class
745
  )
746
 
747
+ # Display GradCAM results (controlled size)
748
+ st.image(comparison, caption="Original | CAM | Overlay", width=700)
749
 
750
  # Generate caption for GradCAM overlay image if BLIP model is loaded
751
  if st.session_state.blip_model_loaded:
 
758
  max_length=150 # Longer for detailed analysis
759
  )
760
  st.session_state.gradcam_caption = gradcam_caption
 
761
 
762
+ # Store caption but don't display it here - it will be shown in the summary section
 
 
763
 
764
  # Save results in session state for LLM analysis
765
  st.session_state.current_image = image
 
855
  col1, col2 = st.columns([1, 2])
856
 
857
  with col1:
858
+ # Display original image and overlay side by side with controlled size
859
+ st.image(st.session_state.current_image, caption="Original Image", width=300)
860
  if hasattr(st.session_state, 'current_overlay'):
861
+ st.image(st.session_state.current_overlay, caption="GradCAM Overlay", width=300)
862
 
863
  with col2:
864
  # Detection result