import os
import openai
import gradio as gr
import numpy as np
import torch
from PIL import Image
import matplotlib.pyplot as plt
import importlib.util
from transformers import pipeline
import requests
# Set your OpenAI API key (ensure the environment variable is set or replace with your key)
openai.api_key = os.getenv("OPENAI_API_KEY", "your-openai-api-key-here")
def install_sam2_if_needed():
"""
Check if SAM2 is installed, and install it if needed.
"""
if importlib.util.find_spec("sam2") is not None:
print("SAM2 is already installed.")
return
    try:
        import subprocess
        import sys
        print("Installing SAM2 from GitHub...")
        # pip.main() was removed from pip's public API, so invoke pip via the current interpreter instead.
        subprocess.check_call([
            sys.executable, "-m", "pip", "install",
            "git+https://github.com/facebookresearch/sam2.git"
        ])
        print("SAM2 installed successfully.")
    except Exception as e:
        print(f"Error installing SAM2: {e}")
        print("You may need to manually install SAM2: pip install git+https://github.com/facebookresearch/sam2.git")
        raise
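# Note (optional): if the automatic install fails, e.g. in a restricted environment,
# installing SAM2 ahead of time avoids the delay on the first segmentation request:
#
#     pip install git+https://github.com/facebookresearch/sam2.git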
def detect_objects_owlv2(text_query, image, threshold=0.1):
"""
Detect objects in an image using OWLv2 model.
Args:
text_query (str): Text description of objects to detect
image (PIL.Image or numpy.ndarray): Input image
threshold (float): Detection threshold
Returns:
list: List of detections with bbox, label, and score
"""
    # Initialize the OWLv2 zero-shot object detector.
    # Note: the pipeline is reloaded on every call; caching it at module level would speed up repeated queries.
    detector = pipeline(model="google/owlv2-base-patch16-ensemble", task="zero-shot-object-detection")
# Convert numpy array to PIL Image if needed
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
# Run detection
predictions = detector(image, candidate_labels=[text_query])
# Filter by threshold and format results
detections = []
for pred in predictions:
if pred['score'] >= threshold:
bbox = pred['box']
            # Normalize bbox coordinates (OWLv2 returns absolute pixel coordinates)
width, height = image.size
normalized_bbox = [
bbox['xmin'] / width,
bbox['ymin'] / height,
bbox['xmax'] / width,
bbox['ymax'] / height
]
detection = {
'label': pred['label'],
'bbox': normalized_bbox,
'score': pred['score']
}
detections.append(detection)
return detections
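# Example (sketch): calling the detector directly. The file name "fruits.jpg" is a
# placeholder, not part of the app; bboxes come back normalized to [0, 1].
#
#     detections = detect_objects_owlv2("apple", Image.open("fruits.jpg"), threshold=0.2)
#     for det in detections:
#         print(det['label'], round(det['score'], 3), det['bbox'])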
def generate_masks_from_detections(detections, image, model_name="facebook/sam2-hiera-large"):
"""
Generate segmentation masks for objects detected by OWLv2 using SAM2 from Hugging Face.
Args:
detections (list): List of detections [{'label': str, 'bbox': [x1, y1, x2, y2], 'score': float}, ...]
image (PIL.Image.Image or str): The image or path to the image to analyze
model_name (str): Hugging Face model name for SAM2.
Returns:
list: List of detections with added 'mask' arrays.
"""
install_sam2_if_needed()
from sam2.sam2_image_predictor import SAM2ImagePredictor
# Load image
if isinstance(image, str):
image = Image.open(image)
elif isinstance(image, np.ndarray):
image = Image.fromarray(image)
image_np = np.array(image.convert("RGB"))
H, W = image_np.shape[:2]
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
print(f"Loading SAM2 model from Hugging Face: {model_name}")
predictor = SAM2ImagePredictor.from_pretrained(model_name)
predictor.model.to(device)
# Convert normalized bboxes to pixels
input_boxes = []
for det in detections:
x1, y1, x2, y2 = det['bbox']
input_boxes.append([int(x1 * W), int(y1 * H), int(x2 * W), int(y2 * H)])
input_boxes = np.array(input_boxes)
print(f"Processing image and predicting masks for {len(input_boxes)} boxes...")
with torch.inference_mode():
predictor.set_image(image_np)
if device == "cuda":
with torch.autocast("cuda", dtype=torch.bfloat16):
masks, scores, _ = predictor.predict(
point_coords=None, point_labels=None,
box=input_boxes, multimask_output=False
)
else:
masks, scores, _ = predictor.predict(
point_coords=None, point_labels=None,
box=input_boxes, multimask_output=False
)
# Attach masks to detections, handling both (1,H,W) and (H,W) outputs
results = []
for i, det in enumerate(detections):
raw = masks[i]
if raw.ndim == 3:
mask = raw[0]
else:
mask = raw
mask = mask.astype(np.uint8)
new_det = det.copy()
new_det['mask'] = mask
results.append(new_det)
print(f"Successfully generated {len(results)} masks.")
return results
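# Example (sketch): chaining OWLv2 detections into SAM2 masks. The query and image
# path are placeholders; each returned detection gains an (H, W) uint8 'mask' array.
#
#     dets = detect_objects_owlv2("bottle", Image.open("shelf.jpg"))
#     dets_with_masks = generate_masks_from_detections(dets, "shelf.jpg")
#     print(dets_with_masks[0]['mask'].shape)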
def overlay_detections_on_image(image, detections_with_masks, show_masks=True, show_boxes=True, show_labels=True):
"""
Overlay detections (boxes and/or masks) on the image and return as numpy array.
Args:
image: Input image (PIL.Image or numpy array)
detections_with_masks: List of detections with masks
show_masks: Whether to show segmentation masks
show_boxes: Whether to show bounding boxes
show_labels: Whether to show labels
Returns:
numpy.ndarray: Image with overlaid detections
"""
# Convert to PIL Image if needed
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
image_np = np.array(image.convert("RGB"))
height, width = image_np.shape[:2]
# Create figure without displaying
fig, ax = plt.subplots(1, 1, figsize=(12, 8))
ax.imshow(image_np)
# Define colors for different instances
colors = plt.cm.tab10(np.linspace(0, 1, 10))
# Plot each detection
for i, detection in enumerate(detections_with_masks):
bbox = detection['bbox']
label = detection['label']
score = detection['score']
# Convert normalized bbox to pixel coordinates
x1, y1, x2, y2 = bbox
x1_px, y1_px = int(x1 * width), int(y1 * height)
x2_px, y2_px = int(x2 * width), int(y2 * height)
# Color for this instance
color = colors[i % len(colors)]
# Display mask if available and requested
if show_masks and 'mask' in detection:
mask = detection['mask']
mask_color = np.zeros((height, width, 4), dtype=np.float32)
mask_color[mask > 0] = [color[0], color[1], color[2], 0.5]
ax.imshow(mask_color)
# Draw bounding box if requested
if show_boxes:
rect = plt.Rectangle((x1_px, y1_px), x2_px - x1_px, y2_px - y1_px,
fill=False, edgecolor=color, linewidth=2)
ax.add_patch(rect)
# Add label and score if requested
if show_labels:
ax.text(x1_px, y1_px - 5, f"{label}: {score:.2f}",
color='white', bbox=dict(facecolor=color, alpha=0.8), fontsize=10)
ax.axis('off')
    # Convert the rendered figure to a numpy array.
    # buffer_rgba() is used because FigureCanvas.tostring_rgb() was removed in recent Matplotlib releases.
    fig.canvas.draw()
    result_array = np.asarray(fig.canvas.buffer_rgba())[:, :, :3].copy()
    plt.close(fig)  # Important: close the figure to free memory
return result_array
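# Example (sketch): saving the rendered overlay to disk. The file names are illustrative.
#
#     viz = overlay_detections_on_image(Image.open("shelf.jpg"), dets_with_masks)
#     Image.fromarray(viz).save("overlay.png")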
def get_single_prompt(user_input):
"""
Uses OpenAI to rephrase the user's chatter into a single, concise prompt for object detection.
The generated prompt will not include any question marks.
"""
if not user_input.strip():
user_input = "Detect objects in the image"
prompt_instruction = (
f"Based on the following user input, generate a single, concise prompt for object detection. "
f"Do not include any question marks in the output. "
f"User input: \"{user_input}\""
)
response = openai.chat.completions.create(
model="gpt-4o", # adjust model name if needed
messages=[{"role": "user", "content": prompt_instruction}],
temperature=0.3,
max_tokens=50,
)
generated_prompt = response.choices[0].message.content.strip()
# Ensure no question marks remain
generated_prompt = generated_prompt.replace("?", "")
return generated_prompt
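# Example (sketch): requires OPENAI_API_KEY to be set; the exact wording depends on the model.
#
#     get_single_prompt("How many bottles are on the shelf?")
#     # -> something like "bottles on a shelf" (question marks are stripped)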
def is_count_query(user_input):
"""
Check if the user's input indicates a counting request.
Looks for common keywords such as "count", "how many", "number of", etc.
"""
keywords = ["count", "how many", "number of", "total", "get me a count"]
for kw in keywords:
if kw.lower() in user_input.lower():
return True
return False
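# Example: the check is a simple case-insensitive keyword match.
#
#     is_count_query("How many bicycles can you see?")   # True
#     is_count_query("Highlight the dog in this photo")  # False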
def process_question_and_detect(user_input, image, threshold, use_sam):
"""
1. Uses OpenAI to generate a single, concise prompt (without question marks) from the user's input.
2. Feeds that prompt to the custom detection function.
3. Optionally generates segmentation masks using SAM2.
4. Overlays the detection results on the image.
5. If the user's input implies a counting request, it also returns the count of detected objects.
"""
if image is None:
return None, "Please upload an image."
try:
# Generate the concise prompt from the user's input
generated_prompt = get_single_prompt(user_input)
# Run object detection using the generated prompt
detections = detect_objects_owlv2(generated_prompt, image, threshold=threshold)
# Generate masks if SAM is enabled
if use_sam and len(detections) > 0:
try:
detections_with_masks = generate_masks_from_detections(detections, image)
except Exception as e:
print(f"SAM2 failed, using detections without masks: {e}")
detections_with_masks = detections
else:
detections_with_masks = detections
# Overlay results on the image
viz = overlay_detections_on_image(image, detections_with_masks,
show_masks=use_sam,
show_boxes=True,
show_labels=True)
# If the user's input implies a counting request, include the count
count_text = ""
if is_count_query(user_input):
count = len(detections)
count_text = f"Detected {count} objects."
output_text = f"Generated prompt: {generated_prompt}\n{count_text}"
if len(detections) == 0:
output_text += f"\nNo objects detected with threshold {threshold}. Try lowering the threshold."
print(output_text)
return viz, output_text
except Exception as e:
error_msg = f"Error during detection: {str(e)}"
print(error_msg)
return image, error_msg
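# Example (sketch): running the full pipeline outside Gradio. The image path is a
# placeholder and OPENAI_API_KEY must be set for prompt generation.
#
#     img = np.array(Image.open("shelf.jpg"))
#     viz, text = process_question_and_detect("How many bottles?", img, threshold=0.1, use_sam=False)
#     print(text)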
# Gradio interface
with gr.Blocks() as demo:
gr.Markdown("# Custom Object Detection and Counting App")
gr.Markdown(
"""
Enter your input (for example:
- "What is the number of fruit in my image?"
- "How many bicycles can you see?"
- "Get me a count of my bottles")
and upload an image.
        The app uses OpenAI to generate a single, concise prompt for object detection (without question marks),
        then runs the detection using OWLv2. Optionally, SAM2 can generate precise segmentation masks.
"""
)
with gr.Row():
with gr.Column():
user_input = gr.Textbox(label="Enter your input", placeholder="Type your input here...")
image_input = gr.Image(label="Upload Image", type="numpy")
with gr.Row():
threshold_slider = gr.Slider(
minimum=0.01,
maximum=1.0,
value=0.1,
step=0.01,
label="Detection Threshold",
info="Lower values detect more objects but may include false positives"
)
use_sam_checkbox = gr.Checkbox(
label="Use SAM2 for Segmentation",
value=False,
info="Enable to generate precise segmentation masks (requires additional computation)"
)
submit_btn = gr.Button("Detect and Count")
with gr.Column():
output_image = gr.Image(label="Detection Result")
output_text = gr.Textbox(label="Output Details", lines=3)
submit_btn.click(
fn=process_question_and_detect,
inputs=[user_input, image_input, threshold_slider, use_sam_checkbox],
outputs=[output_image, output_text]
)
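# Launch the Gradio app. On Hugging Face Spaces the default launch() arguments are
# typically sufficient; when running locally, share=True could be passed to expose a
# temporary public link (optional).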
if __name__ == "__main__":
demo.launch()