|
import os |
|
import openai |
|
import gradio as gr |
|
import numpy as np |
|
import torch |
|
from PIL import Image |
|
import matplotlib.pyplot as plt |
|
import importlib.util |
|
from transformers import pipeline |
|
import contextlib
import subprocess
import sys
|
|
|
|
|
# Read the key from the environment; the placeholder default will fail on real API calls.
openai.api_key = os.getenv("OPENAI_API_KEY", "your-openai-api-key-here")
|
|
|
def install_sam2_if_needed(): |
|
""" |
|
Check if SAM2 is installed, and install it if needed. |
|
""" |
|
if importlib.util.find_spec("sam2") is not None: |
|
print("SAM2 is already installed.") |
|
return |
|
|
|
    try:
        print("Installing SAM2 from GitHub...")
        # pip.main() was removed in modern pip; invoke pip through the current interpreter instead.
        subprocess.check_call([
            sys.executable, "-m", "pip", "install",
            "git+https://github.com/facebookresearch/sam2.git",
        ])
        print("SAM2 installed successfully.")
    except Exception as e:
        print(f"Error installing SAM2: {e}")
        print("You may need to install SAM2 manually: pip install git+https://github.com/facebookresearch/sam2.git")
        raise
|
|
|
def detect_objects_owlv2(text_query, image, threshold=0.1): |
|
""" |
|
Detect objects in an image using OWLv2 model. |
|
|
|
Args: |
|
text_query (str): Text description of objects to detect |
|
image (PIL.Image or numpy.ndarray): Input image |
|
threshold (float): Detection threshold |
|
|
|
Returns: |
|
        list: Detections as dicts with 'label', 'score', and a 'bbox' normalized to [x1, y1, x2, y2] in [0, 1]
|
""" |
|
|
|
    # Note: the OWLv2 pipeline is rebuilt on every call; cache it at module level if latency matters.
    detector = pipeline(model="google/owlv2-base-patch16-ensemble", task="zero-shot-object-detection")
|
|
|
|
|
if isinstance(image, np.ndarray): |
|
image = Image.fromarray(image) |
|
|
|
|
|
predictions = detector(image, candidate_labels=[text_query]) |
|
|
|
|
|
detections = [] |
|
for pred in predictions: |
|
if pred['score'] >= threshold: |
|
bbox = pred['box'] |
|
|
|
width, height = image.size |
|
normalized_bbox = [ |
|
bbox['xmin'] / width, |
|
bbox['ymin'] / height, |
|
bbox['xmax'] / width, |
|
bbox['ymax'] / height |
|
] |
|
|
|
detection = { |
|
'label': pred['label'], |
|
'bbox': normalized_bbox, |
|
'score': pred['score'] |
|
} |
|
detections.append(detection) |
|
|
|
return detections |
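
# Hedged usage sketch for detect_objects_owlv2 (kept as a comment so it never runs on
# import; "cats.jpg" and the "cat" query are hypothetical placeholders):
#
#   img = Image.open("cats.jpg")
#   dets = detect_objects_owlv2("cat", img, threshold=0.2)
#   for d in dets:
#       print(d['label'], round(d['score'], 3), d['bbox'])  # bbox is normalized [x1, y1, x2, y2]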
|
|
|
def generate_masks_from_detections(detections, image, model_name="facebook/sam2-hiera-large"): |
|
""" |
|
Generate segmentation masks for objects detected by OWLv2 using SAM2 from Hugging Face. |
|
|
|
Args: |
|
        detections (list): List of detections [{'label': str, 'bbox': [x1, y1, x2, y2] normalized to [0, 1], 'score': float}, ...]
|
image (PIL.Image.Image or str): The image or path to the image to analyze |
|
model_name (str): Hugging Face model name for SAM2. |
|
|
|
Returns: |
|
list: List of detections with added 'mask' arrays. |
|
""" |
|
install_sam2_if_needed() |
|
from sam2.sam2_image_predictor import SAM2ImagePredictor |
|
|
|
|
|
if isinstance(image, str): |
|
image = Image.open(image) |
|
elif isinstance(image, np.ndarray): |
|
image = Image.fromarray(image) |
|
|
|
image_np = np.array(image.convert("RGB")) |
|
H, W = image_np.shape[:2] |
|
|
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
print(f"Using device: {device}") |
|
print(f"Loading SAM2 model from Hugging Face: {model_name}") |
|
predictor = SAM2ImagePredictor.from_pretrained(model_name) |
|
predictor.model.to(device) |
|
|
|
|
|
input_boxes = [] |
|
for det in detections: |
|
x1, y1, x2, y2 = det['bbox'] |
|
input_boxes.append([int(x1 * W), int(y1 * H), int(x2 * W), int(y2 * H)]) |
|
input_boxes = np.array(input_boxes) |
|
|
|
print(f"Processing image and predicting masks for {len(input_boxes)} boxes...") |
|
with torch.inference_mode(): |
|
predictor.set_image(image_np) |
|
        # Autocast to bfloat16 on GPU; fall back to a no-op context on CPU.
        amp_ctx = torch.autocast("cuda", dtype=torch.bfloat16) if device == "cuda" else contextlib.nullcontext()
        with amp_ctx:
            masks, scores, _ = predictor.predict(
                point_coords=None, point_labels=None,
                box=input_boxes, multimask_output=False
            )
|
|
|
|
|
results = [] |
|
for i, det in enumerate(detections): |
|
        raw = masks[i]
        # With multimask_output=False, SAM2 returns one (1, H, W) mask per box; squeeze to (H, W).
        mask = raw[0] if raw.ndim == 3 else raw
        mask = mask.astype(np.uint8)
|
|
|
new_det = det.copy() |
|
new_det['mask'] = mask |
|
results.append(new_det) |
|
|
|
print(f"Successfully generated {len(results)} masks.") |
|
return results |
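
# Hedged usage sketch for generate_masks_from_detections (assumes SAM2 installs cleanly
# and that `dets` came from detect_objects_owlv2 on the same image):
#
#   dets_with_masks = generate_masks_from_detections(dets, img)
#   print(dets_with_masks[0]['mask'].shape)  # (H, W) uint8 array, nonzero inside the object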
|
|
|
def overlay_detections_on_image(image, detections_with_masks, show_masks=True, show_boxes=True, show_labels=True): |
|
""" |
|
Overlay detections (boxes and/or masks) on the image and return as numpy array. |
|
|
|
Args: |
|
image: Input image (PIL.Image or numpy array) |
|
detections_with_masks: List of detections with masks |
|
show_masks: Whether to show segmentation masks |
|
show_boxes: Whether to show bounding boxes |
|
show_labels: Whether to show labels |
|
|
|
Returns: |
|
numpy.ndarray: Image with overlaid detections |
|
""" |
|
|
|
if isinstance(image, np.ndarray): |
|
image = Image.fromarray(image) |
|
|
|
image_np = np.array(image.convert("RGB")) |
|
height, width = image_np.shape[:2] |
|
|
|
|
|
fig, ax = plt.subplots(1, 1, figsize=(12, 8)) |
|
ax.imshow(image_np) |
|
|
|
|
|
colors = plt.cm.tab10(np.linspace(0, 1, 10)) |
|
|
|
|
|
for i, detection in enumerate(detections_with_masks): |
|
bbox = detection['bbox'] |
|
label = detection['label'] |
|
score = detection['score'] |
|
|
|
|
|
x1, y1, x2, y2 = bbox |
|
x1_px, y1_px = int(x1 * width), int(y1 * height) |
|
x2_px, y2_px = int(x2 * width), int(y2 * height) |
|
|
|
|
|
color = colors[i % len(colors)] |
|
|
|
|
|
if show_masks and 'mask' in detection: |
|
mask = detection['mask'] |
|
mask_color = np.zeros((height, width, 4), dtype=np.float32) |
|
mask_color[mask > 0] = [color[0], color[1], color[2], 0.5] |
|
ax.imshow(mask_color) |
|
|
|
|
|
if show_boxes: |
|
rect = plt.Rectangle((x1_px, y1_px), x2_px - x1_px, y2_px - y1_px, |
|
fill=False, edgecolor=color, linewidth=2) |
|
ax.add_patch(rect) |
|
|
|
|
|
if show_labels: |
|
ax.text(x1_px, y1_px - 5, f"{label}: {score:.2f}", |
|
color='white', bbox=dict(facecolor=color, alpha=0.8), fontsize=10) |
|
|
|
ax.axis('off') |
|
|
|
|
|
    fig.canvas.draw()
    # tostring_rgb() was deprecated in Matplotlib 3.8 and removed in 3.10; buffer_rgba() works across versions.
    result_array = np.asarray(fig.canvas.buffer_rgba())[:, :, :3].copy()
|
|
|
plt.close(fig) |
|
|
|
return result_array |
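
# Hedged usage sketch for overlay_detections_on_image ("result.png" is a placeholder path):
#
#   viz = overlay_detections_on_image(img, dets_with_masks, show_masks=True)
#   Image.fromarray(viz).save("result.png")  # round-trip the rendered array back to an image file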
|
|
|
def get_single_prompt(user_input): |
|
""" |
|
    Uses OpenAI to rephrase the user's free-form input into a single, concise prompt for object detection.
|
The generated prompt will not include any question marks. |
|
""" |
|
if not user_input.strip(): |
|
user_input = "Detect objects in the image" |
|
|
|
prompt_instruction = ( |
|
f"Based on the following user input, generate a single, concise prompt for object detection. " |
|
f"Do not include any question marks in the output. " |
|
f"User input: \"{user_input}\"" |
|
) |
|
|
|
response = openai.chat.completions.create( |
|
model="gpt-4o", |
|
messages=[{"role": "user", "content": prompt_instruction}], |
|
temperature=0.3, |
|
max_tokens=50, |
|
) |
|
|
|
generated_prompt = response.choices[0].message.content.strip() |
|
|
|
generated_prompt = generated_prompt.replace("?", "") |
|
return generated_prompt |
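
# Hedged example for get_single_prompt (actual output depends on the model, so this is
# only illustrative): get_single_prompt("How many bottles are on the table?") might
# return something like "bottles on a table".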
|
|
|
def is_count_query(user_input): |
|
""" |
|
Check if the user's input indicates a counting request. |
|
Looks for common keywords such as "count", "how many", "number of", etc. |
|
""" |
|
keywords = ["count", "how many", "number of", "total", "get me a count"] |
|
for kw in keywords: |
|
if kw.lower() in user_input.lower(): |
|
return True |
|
return False |
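
# Quick sanity check (deterministic given the keyword list above):
#   is_count_query("How many bicycles can you see?")  -> True  (matches "how many")
#   is_count_query("Highlight the red car")           -> False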
|
|
|
def process_question_and_detect(user_input, image, threshold, use_sam): |
|
""" |
|
1. Uses OpenAI to generate a single, concise prompt (without question marks) from the user's input. |
|
2. Feeds that prompt to the custom detection function. |
|
3. Optionally generates segmentation masks using SAM2. |
|
4. Overlays the detection results on the image. |
|
5. If the user's input implies a counting request, it also returns the count of detected objects. |
|
""" |
|
if image is None: |
|
return None, "Please upload an image." |
|
|
|
try: |
|
|
|
generated_prompt = get_single_prompt(user_input) |
|
|
|
|
|
detections = detect_objects_owlv2(generated_prompt, image, threshold=threshold) |
|
|
|
|
|
if use_sam and len(detections) > 0: |
|
try: |
|
detections_with_masks = generate_masks_from_detections(detections, image) |
|
except Exception as e: |
|
print(f"SAM2 failed, using detections without masks: {e}") |
|
detections_with_masks = detections |
|
else: |
|
detections_with_masks = detections |
|
|
|
|
|
viz = overlay_detections_on_image(image, detections_with_masks, |
|
show_masks=use_sam, |
|
show_boxes=True, |
|
show_labels=True) |
|
|
|
|
|
count_text = "" |
|
if is_count_query(user_input): |
|
            count = len(detections)
            count_text = f"Detected {count} object{'s' if count != 1 else ''}."
|
|
|
output_text = f"Generated prompt: {generated_prompt}\n{count_text}" |
|
if len(detections) == 0: |
|
output_text += f"\nNo objects detected with threshold {threshold}. Try lowering the threshold." |
|
|
|
print(output_text) |
|
return viz, output_text |
|
|
|
except Exception as e: |
|
error_msg = f"Error during detection: {str(e)}" |
|
print(error_msg) |
|
return image, error_msg |
|
|
|
|
|
with gr.Blocks() as demo: |
|
gr.Markdown("# Custom Object Detection and Counting App") |
|
gr.Markdown( |
|
""" |
|
Enter your input (for example: |
|
- "What is the number of fruit in my image?" |
|
- "How many bicycles can you see?" |
|
- "Get me a count of my bottles") |
|
and upload an image. |
|
The app uses OpenAI to generate a single, concise prompt for object detection (without question marks), |
|
        then runs zero-shot detection with OWLv2. Optionally, SAM2 can generate precise segmentation masks.
|
""" |
|
) |
|
|
|
with gr.Row(): |
|
with gr.Column(): |
|
user_input = gr.Textbox(label="Enter your input", placeholder="Type your input here...") |
|
image_input = gr.Image(label="Upload Image", type="numpy") |
|
|
|
with gr.Row(): |
|
threshold_slider = gr.Slider( |
|
minimum=0.01, |
|
maximum=1.0, |
|
value=0.1, |
|
step=0.01, |
|
label="Detection Threshold", |
|
info="Lower values detect more objects but may include false positives" |
|
) |
|
use_sam_checkbox = gr.Checkbox( |
|
label="Use SAM2 for Segmentation", |
|
value=False, |
|
info="Enable to generate precise segmentation masks (requires additional computation)" |
|
) |
|
|
|
submit_btn = gr.Button("Detect and Count") |
|
|
|
with gr.Column(): |
|
output_image = gr.Image(label="Detection Result") |
|
output_text = gr.Textbox(label="Output Details", lines=3) |
|
|
|
submit_btn.click( |
|
fn=process_question_and_detect, |
|
inputs=[user_input, image_input, threshold_slider, use_sam_checkbox], |
|
outputs=[output_image, output_text] |
|
) |
|
|
|
if __name__ == "__main__": |
|
demo.launch() |