import gradio as gr
import numpy as np
import torch
from PIL import Image
import matplotlib.pyplot as plt
from transformers import pipeline
import warnings
from io import BytesIO
import importlib.util
import os
import openai
from typing import List, Dict

# Suppress warnings
warnings.filterwarnings("ignore")

# Set up OpenAI API key
api_key = os.getenv('OPENAI_API_KEY')
if not api_key:
    print("No OpenAI API key found - will use simple keyword extraction")
elif not api_key.startswith("sk-proj-") and not api_key.startswith("sk-"):
    print("API key found but doesn't look correct")
elif api_key.strip() != api_key:
    print("API key has leading or trailing whitespace - please fix it.")
else:
    print("OpenAI API key found and looks good!")
    openai.api_key = api_key

# Global variables for models
detector = None
sam_predictor = None


def calculate_bbox_area(bbox):
    """Calculate the area of a normalized bounding box."""
    x1, y1, x2, y2 = bbox
    width = abs(x2 - x1)
    height = abs(y2 - y1)
    return width * height
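# Illustrative example (hypothetical values): a box covering the central quarter
# of the image, bbox = [0.25, 0.25, 0.75, 0.75], has normalized area
# 0.5 * 0.5 = 0.25, i.e. one quarter of the image.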
def filter_bbox_outliers(detections: List[Dict],
                         method: str = 'zscore',
                         threshold: float = 2.0,
                         min_score: float = 0.0) -> List[Dict]:
    """
    Filter out outlier bounding boxes based on their area.

    Args:
        detections: List of detection dictionaries with 'bbox', 'label', 'score'
        method: 'iqr' (Interquartile Range) or 'zscore' (Z-score)
        threshold: Multiplier for IQR method or Z-score threshold
        min_score: Minimum confidence score to keep detection

    Returns:
        Filtered list of detections
    """
    if not detections:
        return detections
    # Filter by minimum score first
    detections = [det for det in detections if det['score'] >= min_score]
    if len(detections) <= 2:  # Need at least 3 detections for meaningful outlier removal
        return detections
    # Calculate areas for all bounding boxes
    areas = [calculate_bbox_area(det['bbox']) for det in detections]
    areas = np.array(areas)
    if method == 'iqr':
        # IQR method
        q1 = np.percentile(areas, 25)
        q3 = np.percentile(areas, 75)
        iqr = q3 - q1
        lower_bound = q1 - threshold * iqr
        upper_bound = q3 + threshold * iqr
        valid_indices = np.where((areas >= lower_bound) & (areas <= upper_bound))[0]
    elif method == 'zscore':
        # Z-score method
        if np.std(areas) == 0:  # All areas are the same
            return detections
        mean_area = np.mean(areas)
        std_area = np.std(areas)
        z_scores = np.abs((areas - mean_area) / std_area)
        valid_indices = np.where(z_scores <= threshold)[0]
    else:
        raise ValueError("Method must be 'iqr' or 'zscore'")
    # Return filtered detections
    filtered_detections = [detections[i] for i in valid_indices]
    print(f"Original detections: {len(detections)}")
    print(f"Filtered detections: {len(filtered_detections)}")
    print(f"Removed {len(detections) - len(filtered_detections)} outliers")
    return filtered_detections
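# Usage sketch (detection values are hypothetical):
#   dets = [{'label': 'apple', 'bbox': [0.10, 0.10, 0.20, 0.20], 'score': 0.8},
#           {'label': 'apple', 'bbox': [0.30, 0.30, 0.41, 0.41], 'score': 0.7}, ...]
#   dets = filter_bbox_outliers(dets, method='iqr', threshold=1.5)
# Note: with the default z-score method (np.std uses ddof=0), the largest possible
# z-score among n samples is (n - 1) / sqrt(n), so a threshold of 2.0 can only
# remove a box once there are at least 6 detections.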
"""Load the OWL-ViT detector once and cache it.""" | |
global detector | |
if detector is None: | |
print("Loading OWL-ViT model...") | |
detector = pipeline( | |
model="google/owlv2-base-patch16-ensemble", | |
task="zero-shot-object-detection", | |
device=0 if torch.cuda.is_available() else -1 | |
) | |
print("OWL-ViT model loaded successfully!") | |
def install_sam2_if_needed():
    """Check if SAM2 is installed, and install it if needed."""
    if importlib.util.find_spec("sam2") is not None:
        print("SAM2 is already installed.")
        return True
    try:
        import subprocess
        import sys
        print("Installing SAM2 from GitHub...")
        subprocess.check_call([sys.executable, "-m", "pip", "install",
                               "git+https://github.com/facebookresearch/sam2.git"])
        print("SAM2 installed successfully.")
        return True
    except Exception as e:
        print(f"Error installing SAM2: {e}")
        return False


def load_sam_predictor():
    """Load SAM2 predictor if available."""
    global sam_predictor
    if sam_predictor is None:
        if install_sam2_if_needed():
            try:
                from sam2.sam2_image_predictor import SAM2ImagePredictor
                print("Loading SAM2 model...")
                sam_predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")
                device = "cuda" if torch.cuda.is_available() else "cpu"
                sam_predictor.model.to(device)
                print(f"SAM2 model loaded successfully on {device}!")
                return True
            except Exception as e:
                print(f"Error loading SAM2: {e}")
                return False
    return sam_predictor is not None
def detect_objects_owlv2(text_query, image, threshold=0.1):
    """Detect objects using OWL-ViT."""
    try:
        load_detector()
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        # Clean up the text query
        query_terms = [term.strip() for term in text_query.split(',') if term.strip()]
        if not query_terms:
            query_terms = ["object"]
        print(f"Detecting: {query_terms}")
        predictions = detector(image, candidate_labels=query_terms)
        detections = []
        for pred in predictions:
            if pred['score'] >= threshold:
                bbox = pred['box']
                width, height = image.size
                normalized_bbox = [
                    bbox['xmin'] / width,
                    bbox['ymin'] / height,
                    bbox['xmax'] / width,
                    bbox['ymax'] / height
                ]
                detection = {
                    'label': pred['label'],
                    'bbox': normalized_bbox,
                    'score': pred['score']
                }
                detections.append(detection)
        return detections, image
    except Exception as e:
        print(f"Detection error: {e}")
        return [], image
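# Each returned detection is a plain dict of the form
#   {'label': <str>, 'bbox': [x1, y1, x2, y2] normalized to 0-1, 'score': <float>},
# e.g. (hypothetical) {'label': 'person', 'bbox': [0.12, 0.05, 0.48, 0.97], 'score': 0.41}.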
def generate_masks_sam2(detections, image):
    """Generate segmentation masks using SAM2."""
    try:
        if not load_sam_predictor():
            print("SAM2 not available, skipping mask generation")
            return detections
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        image_np = np.array(image.convert("RGB"))
        H, W = image_np.shape[:2]
        # Set image for SAM2
        sam_predictor.set_image(image_np)
        # Convert normalized bboxes to pixel coordinates
        input_boxes = []
        for det in detections:
            x1, y1, x2, y2 = det['bbox']
            input_boxes.append([int(x1 * W), int(y1 * H), int(x2 * W), int(y2 * H)])
        if not input_boxes:
            return detections
        input_boxes = np.array(input_boxes)
        print(f"Generating masks for {len(input_boxes)} detections...")
        with torch.inference_mode():
            device = "cuda" if torch.cuda.is_available() else "cpu"
            if device == "cuda":
                with torch.autocast("cuda", dtype=torch.bfloat16):
                    masks, scores, _ = sam_predictor.predict(
                        point_coords=None,
                        point_labels=None,
                        box=input_boxes,
                        multimask_output=False
                    )
            else:
                masks, scores, _ = sam_predictor.predict(
                    point_coords=None,
                    point_labels=None,
                    box=input_boxes,
                    multimask_output=False
                )
        # Add masks to detections
        results = []
        for i, det in enumerate(detections):
            new_det = det.copy()
            mask = masks[i]
            if mask.ndim == 3:
                mask = mask[0]  # Remove batch dimension if present
            new_det['mask'] = mask.astype(np.uint8)
            results.append(new_det)
        print(f"Successfully generated {len(results)} masks")
        return results
    except Exception as e:
        print(f"SAM2 mask generation error: {e}")
        return detections
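# On success each detection gains a 'mask' key holding an (H, W) uint8 array where
# non-zero pixels belong to the object; the visualizers below overlay it with 50% alpha.
# If SAM2 is unavailable, the detections pass through unchanged.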
def visualize_detections_with_masks(image, detections_with_masks, show_labels=True, show_boxes=True):
    """
    Visualize the detections with their segmentation masks.
    Returns a PIL Image instead of showing the plot.
    """
    # Load the image
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    image_np = np.array(image.convert("RGB"))
    # Get image dimensions
    height, width = image_np.shape[:2]
    # Create figure
    fig = plt.figure(figsize=(12, 8))
    plt.imshow(image_np)
    # Define colors for different instances
    colors = plt.cm.tab10(np.linspace(0, 1, 10))
    # Plot each detection
    for i, detection in enumerate(detections_with_masks):
        # Get bbox, mask, label, and score
        bbox = detection['bbox']
        label = detection['label']
        score = detection['score']
        # Convert normalized bbox to pixel coordinates
        x1, y1, x2, y2 = bbox
        x1_px, y1_px = int(x1 * width), int(y1 * height)
        x2_px, y2_px = int(x2 * width), int(y2 * height)
        # Color for this instance
        color = colors[i % len(colors)]
        # Display mask with transparency if available
        if 'mask' in detection:
            mask = detection['mask']
            mask_color = np.zeros((height, width, 4), dtype=np.float32)
            mask_color[mask > 0] = [color[0], color[1], color[2], 0.5]
            plt.imshow(mask_color)
        # Draw bounding box if requested
        if show_boxes:
            rect = plt.Rectangle((x1_px, y1_px), x2_px - x1_px, y2_px - y1_px,
                                 fill=False, edgecolor=color, linewidth=2)
            plt.gca().add_patch(rect)
        # Add label and score if requested
        if show_labels:
            plt.text(x1_px, y1_px - 5, f"{label}: {score:.2f}",
                     color='white', bbox=dict(facecolor=color, alpha=0.8), fontsize=10)
    plt.axis('off')
    plt.tight_layout()
    # Convert to PIL Image
    buf = BytesIO()
    plt.savefig(buf, format='png', bbox_inches='tight', dpi=150)
    plt.close(fig)
    buf.seek(0)
    result_image = Image.open(buf)
    return result_image
def visualize_detections(image, detections, show_labels=True):
    """
    Visualize object detections with bounding boxes only.
    Returns a PIL Image instead of showing the plot.
    """
    # Load the image
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    image_np = np.array(image.convert("RGB"))
    # Get image dimensions
    height, width = image_np.shape[:2]
    # Create figure
    fig = plt.figure(figsize=(12, 8))
    plt.imshow(image_np)
    # If we have detections, draw them
    if detections:
        # Define colors for different instances
        colors = plt.cm.tab10(np.linspace(0, 1, 10))
        # Plot each detection
        for i, detection in enumerate(detections):
            # Get bbox, label, and score
            bbox = detection['bbox']
            label = detection['label']
            score = detection['score']
            # Convert normalized bbox to pixel coordinates
            x1, y1, x2, y2 = bbox
            x1_px, y1_px = int(x1 * width), int(y1 * height)
            x2_px, y2_px = int(x2 * width), int(y2 * height)
            # Color for this instance
            color = colors[i % len(colors)]
            # Draw bounding box
            rect = plt.Rectangle((x1_px, y1_px), x2_px - x1_px, y2_px - y1_px,
                                 fill=False, edgecolor=color, linewidth=2)
            plt.gca().add_patch(rect)
            # Add label and score if requested
            if show_labels:
                plt.text(x1_px, y1_px - 5, f"{label}: {score:.2f}",
                         color='white', bbox=dict(facecolor=color, alpha=0.8), fontsize=10)
    # Set title
    plt.title(f'Object Detection Results ({len(detections)} objects found)', fontsize=14, pad=20)
    plt.axis('off')
    plt.tight_layout()
    # Convert to PIL Image
    buf = BytesIO()
    plt.savefig(buf, format='png', bbox_inches='tight', dpi=150)
    plt.close(fig)
    buf.seek(0)
    result_image = Image.open(buf)
    return result_image
def get_optimized_prompt(query_text):
    """
    Use OpenAI to convert a natural language query into an optimal detection prompt.
    Falls back to simple keyword extraction if OpenAI is not available.
    """
    if not query_text.strip():
        return "object"
    # Try OpenAI first if an API key is available
    if hasattr(openai, 'api_key') and openai.api_key:
        try:
            response = openai.chat.completions.create(
                model="gpt-3.5-turbo",
                messages=[{
                    "role": "system",
                    "content": """You are an expert at converting natural language queries into precise object detection terms.
RULES:
1. Return ONLY 1-2 words maximum that describe the object to detect
2. Use the exact object name from the user's query
3. For people: use "person"
4. For vehicles: use "car", "truck", "bicycle"
5. Do NOT include counting words, articles, or explanations
6. Examples:
   - "How many cacao fruits are there?" -> "cacao fruit"
   - "Count the corn in the field" -> "corn"
   - "Find all people" -> "person"
   - "How many cacao pods?" -> "cacao pod"
   - "Detect cars" -> "car"
   - "Count bananas" -> "banana"
   - "How many apples?" -> "apple"
Return ONLY the object name, nothing else."""
                }, {
                    "role": "user",
                    "content": query_text
                }],
                temperature=0.0,  # Make it deterministic
                max_tokens=5      # Force brevity
            )
            llm_result = response.choices[0].message.content.strip().lower()
            # Extra safety: take only the first 2 words
            words = llm_result.split()[:2]
            final_result = " ".join(words)
            print(f"OpenAI suggested prompt: '{final_result}'")
            return final_result
        except Exception as e:
            print(f"OpenAI error: {e}, falling back to keyword extraction")
    # Fallback to simple keyword extraction (no hardcoded object list)
    print("Using keyword extraction (no OpenAI)")
    query_lower = query_text.lower().replace("?", "").strip()
    # Look for common patterns and extract object names
    if "how many" in query_lower:
        parts = query_lower.split("how many")
        if len(parts) > 1:
            remaining = parts[1].strip()
            remaining = remaining.replace("are", "").replace("in", "").replace("the", "").replace("image", "").replace("there", "").strip()
            # Take the first meaningful word(s)
            words = remaining.split()[:2]
            search_terms = " ".join(words) if words else "object"
        else:
            search_terms = "object"
    elif "count" in query_lower:
        parts = query_lower.split("count")
        if len(parts) > 1:
            remaining = parts[1].strip()
            remaining = remaining.replace("the", "").replace("in", "").replace("image", "").strip()
            words = remaining.split()[:2]
            search_terms = " ".join(words) if words else "object"
        else:
            search_terms = "object"
    elif "find" in query_lower:
        parts = query_lower.split("find")
        if len(parts) > 1:
            remaining = parts[1].strip()
            remaining = remaining.replace("all", "").replace("the", "").replace("in", "").replace("image", "").strip()
            words = remaining.split()[:2]
            search_terms = " ".join(words) if words else "object"
        else:
            search_terms = "object"
    else:
        # Extract the first 1-2 meaningful words from the query
        words = query_lower.split()
        meaningful_words = [w for w in words if w not in ["how", "many", "are", "in", "the", "image", "find", "count", "detect", "there", "this", "that", "a", "an"]]
        search_terms = " ".join(meaningful_words[:2]) if meaningful_words else "object"
    return search_terms
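# Example behaviour (the LLM result follows the examples in the system prompt above,
# but exact model output may vary):
#   get_optimized_prompt("Count bananas") -> "banana" via the OpenAI path,
#   or "bananas" via the keyword-extraction fallback when no API key is set.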
def is_count_query(text):
    """Check if the query is asking for counting."""
    count_keywords = ["how many", "count", "number of", "total"]
    return any(keyword in text.lower() for keyword in count_keywords)
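# e.g. is_count_query("How many apples?") -> True, is_count_query("Find all cars") -> False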
def detection_pipeline(query_text, image, threshold, use_sam):
    """Main detection pipeline."""
    if image is None:
        return None, "Please upload an image first!"
    try:
        # Use OpenAI or the fallback to get optimized search terms
        search_terms = get_optimized_prompt(query_text)
        print(f"Processing query: '{query_text}' -> searching for: '{search_terms}'")
        # Run object detection
        detections, processed_image = detect_objects_owlv2(search_terms, image, threshold)
        print(f"Found {len(detections)} initial detections")
        for i, det in enumerate(detections):
            print(f"Detection {i+1}: {det['label']} (score: {det['score']:.3f}, area: {calculate_bbox_area(det['bbox']):.6f})")
        # Filter outliers before SAM2
        if len(detections) > 2:  # Only filter if we have enough detections
            detections = filter_bbox_outliers(detections, method='zscore', threshold=2.0)
            print(f"After outlier filtering: {len(detections)} detections remain")
        # Generate masks if requested
        if use_sam and detections:
            print("Generating SAM2 masks...")
            detections = generate_masks_sam2(detections, processed_image)
        # Create visualization (labels OFF)
        print("Creating visualization...")
        if use_sam and detections and 'mask' in detections[0]:
            result_image = visualize_detections_with_masks(
                processed_image,
                detections,
                show_labels=False,  # Labels OFF
                show_boxes=True
            )
            print("Created visualization with masks")
        else:
            result_image = visualize_detections(
                processed_image,
                detections,
                show_labels=False  # Labels OFF
            )
            print("Created visualization with bounding boxes only")
        # Make sure we have a valid result image
        if result_image is None:
            print("Warning: result_image is None, returning original image")
            result_image = processed_image
        # Generate summary
        count = len(detections)
        summary_parts = []
        summary_parts.append(f"**Original Query**: '{query_text}'")
        summary_parts.append(f"**AI-Optimized Search**: '{search_terms}'")
        summary_parts.append(f"**Threshold**: {threshold}")
        summary_parts.append(f"**SAM2 Segmentation**: {'Enabled' if use_sam else 'Disabled'}")
        if count > 0:
            if is_count_query(query_text):
                summary_parts.append(f"**Answer: {count} {search_terms}(s) found**")
            else:
                summary_parts.append(f"**Found {count} {search_terms}(s)**")
            # Show detection details
            for i, det in enumerate(detections[:5]):  # Show first 5
                summary_parts.append(f"  • Detection {i+1}: {det['score']:.3f} confidence")
            if count > 5:
                summary_parts.append(f"  • ... and {count-5} more detections")
        else:
            summary_parts.append(f"**No {search_terms}(s) detected**")
            summary_parts.append("Try lowering the threshold or using different terms")
        summary_text = "\n".join(summary_parts)
        return result_image, summary_text
    except Exception as e:
        error_msg = f"**Error**: {str(e)}"
        return image, error_msg
# ----------------
# GRADIO INTERFACE
# ----------------
with gr.Blocks(title="Object Detection & Segmentation") as demo:
    gr.Markdown("""
    # Object Detection & Segmentation App
    **Simple and powerful object detection using OWL-ViT + SAM2**
    1. **Enter your query** (e.g., "How many people?", "Find cars", "Count apples")
    2. **Upload an image**
    3. **Adjust detection sensitivity**
    4. **Toggle SAM2 segmentation** for precise masks
    5. **Click Detect!**
    """)
    with gr.Row():
        with gr.Column(scale=1):
            query_input = gr.Textbox(
                label="What do you want to detect?",
                placeholder="e.g., 'How many people are in the image?'",
                value="How many people are in the image?",
                lines=2
            )
            image_input = gr.Image(
                label="Upload your image",
                type="numpy"
            )
            with gr.Row():
                threshold_slider = gr.Slider(
                    minimum=0.01,
                    maximum=0.9,
                    value=0.1,
                    step=0.01,
                    label="Detection Sensitivity"
                )
                sam_checkbox = gr.Checkbox(
                    label="Enable SAM2 Segmentation",
                    value=False,
                    info="Generate precise pixel masks"
                )
            detect_button = gr.Button("Detect Objects!", variant="primary", size="lg")
        with gr.Column(scale=1):
            output_image = gr.Image(label="Detection Results")
            output_text = gr.Textbox(
                label="Detection Summary",
                lines=12,
                show_copy_button=True
            )
    # Event handlers
    detect_button.click(
        fn=detection_pipeline,
        inputs=[query_input, image_input, threshold_slider, sam_checkbox],
        outputs=[output_image, output_text]
    )
    # Also trigger on Enter in the text box
    query_input.submit(
        fn=detection_pipeline,
        inputs=[query_input, image_input, threshold_slider, sam_checkbox],
        outputs=[output_image, output_text]
    )
    # Examples section
    gr.Examples(
        examples=[
            ["How many people are in the image?", None, 0.1, False],
            ["Find all cars", None, 0.15, True],
            ["Count the bottles", None, 0.1, True],
            ["Detect dogs", None, 0.2, False],
            ["How many phones?", None, 0.15, True],
        ],
        inputs=[query_input, image_input, threshold_slider, sam_checkbox],
    )

# Launch
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)