import os
import sys
import subprocess
import importlib.util

import openai
import gradio as gr
import numpy as np
import torch
from PIL import Image
import matplotlib.pyplot as plt
from transformers import pipeline

# Set your OpenAI API key (ensure the environment variable is set or replace with your key)
openai.api_key = os.getenv("OPENAI_API_KEY", "your-openai-api-key-here")

def install_sam2_if_needed():
    """
    Check if SAM2 is installed, and install it if needed.
    """
    if importlib.util.find_spec("sam2") is not None:
        print("SAM2 is already installed.")
        return
    
    try:
        print("Installing SAM2 from GitHub...")
        # pip.main() was removed in pip 10; invoking pip via the current interpreter is the supported route
        subprocess.check_call([sys.executable, "-m", "pip", "install",
                               "git+https://github.com/facebookresearch/sam2.git"])
        print("SAM2 installed successfully.")
    except Exception as e:
        print(f"Error installing SAM2: {e}")
        print("You may need to manually install SAM2: pip install git+https://github.com/facebookresearch/sam2.git")
        raise

def detect_objects_owlv2(text_query, image, threshold=0.1):
    """
    Detect objects in an image using OWLv2 model.
    
    Args:
        text_query (str): Text description of objects to detect
        image (PIL.Image or numpy.ndarray): Input image
        threshold (float): Detection threshold
    
    Returns:
        list: List of detections with bbox, label, and score
    """
    # Initialize the OWLv2 zero-shot object detection pipeline
    # (note: loading the pipeline on every call is simple but slow; cache it if you call this often)
    detector = pipeline(model="google/owlv2-base-patch16-ensemble", task="zero-shot-object-detection")
    
    # Convert numpy array to PIL Image if needed
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    
    # Run detection
    predictions = detector(image, candidate_labels=[text_query])
    
    # Filter by threshold and format results
    detections = []
    for pred in predictions:
        if pred['score'] >= threshold:
            bbox = pred['box']
            # Normalize bbox coordinates (the pipeline returns absolute pixel coordinates)
            width, height = image.size
            normalized_bbox = [
                bbox['xmin'] / width,
                bbox['ymin'] / height, 
                bbox['xmax'] / width,
                bbox['ymax'] / height
            ]
            
            detection = {
                'label': pred['label'],
                'bbox': normalized_bbox,
                'score': pred['score']
            }
            detections.append(detection)
    
    return detections
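
# Example (a sketch, not executed): calling the detector directly. "fruit.jpg" is a
# hypothetical local file; scores and boxes will vary by image and threshold.
#
#   img = Image.open("fruit.jpg")
#   dets = detect_objects_owlv2("apple", img, threshold=0.2)
#   for d in dets:
#       print(d["label"], round(d["score"], 2), d["bbox"])  # bbox is normalized [x1, y1, x2, y2]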

def generate_masks_from_detections(detections, image, model_name="facebook/sam2-hiera-large"):
    """
    Generate segmentation masks for objects detected by OWLv2 using SAM2 from Hugging Face.

    Args:
        detections (list): List of detections [{'label': str, 'bbox': [x1, y1, x2, y2], 'score': float}, ...]
        image (PIL.Image.Image or str): The image or path to the image to analyze
        model_name (str): Hugging Face model name for SAM2.

    Returns:
        list: List of detections with added 'mask' arrays.
    """
    install_sam2_if_needed()
    from sam2.sam2_image_predictor import SAM2ImagePredictor

    # Load image
    if isinstance(image, str):
        image = Image.open(image)
    elif isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    
    image_np = np.array(image.convert("RGB"))
    H, W = image_np.shape[:2]

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")
    print(f"Loading SAM2 model from Hugging Face: {model_name}")
    predictor = SAM2ImagePredictor.from_pretrained(model_name)
    predictor.model.to(device)

    # Convert normalized bboxes to pixels
    input_boxes = []
    for det in detections:
        x1, y1, x2, y2 = det['bbox']
        input_boxes.append([int(x1 * W), int(y1 * H), int(x2 * W), int(y2 * H)])
    input_boxes = np.array(input_boxes)

    print(f"Processing image and predicting masks for {len(input_boxes)} boxes...")
    with torch.inference_mode():
        predictor.set_image(image_np)
        if device == "cuda":
            with torch.autocast("cuda", dtype=torch.bfloat16):
                masks, scores, _ = predictor.predict(
                    point_coords=None, point_labels=None,
                    box=input_boxes, multimask_output=False
                )
        else:
            masks, scores, _ = predictor.predict(
                point_coords=None, point_labels=None,
                box=input_boxes, multimask_output=False
            )

    # Attach masks to detections, handling both (1,H,W) and (H,W) outputs
    results = []
    for i, det in enumerate(detections):
        raw = masks[i]
        if raw.ndim == 3:
            mask = raw[0]
        else:
            mask = raw
        mask = mask.astype(np.uint8)

        new_det = det.copy()
        new_det['mask'] = mask
        results.append(new_det)

    print(f"Successfully generated {len(results)} masks.")
    return results
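
# Example (a sketch, not executed): chaining OWLv2 boxes into SAM2 masks. Assumes the
# detections from the sketch above and enough memory to load "facebook/sam2-hiera-large".
#
#   dets_with_masks = generate_masks_from_detections(dets, img)
#   print(dets_with_masks[0]["mask"].shape)  # (H, W) uint8 array; 1 marks object pixels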

def overlay_detections_on_image(image, detections_with_masks, show_masks=True, show_boxes=True, show_labels=True):
    """
    Overlay detections (boxes and/or masks) on the image and return as numpy array.
    
    Args:
        image: Input image (PIL.Image or numpy array)
        detections_with_masks: List of detections with masks
        show_masks: Whether to show segmentation masks
        show_boxes: Whether to show bounding boxes  
        show_labels: Whether to show labels
    
    Returns:
        numpy.ndarray: Image with overlaid detections
    """
    # Convert to PIL Image if needed
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    
    image_np = np.array(image.convert("RGB"))
    height, width = image_np.shape[:2]
    
    # Create figure without displaying
    fig, ax = plt.subplots(1, 1, figsize=(12, 8))
    ax.imshow(image_np)
    
    # Define colors for different instances
    colors = plt.cm.tab10(np.linspace(0, 1, 10))
    
    # Plot each detection
    for i, detection in enumerate(detections_with_masks):
        bbox = detection['bbox']
        label = detection['label']
        score = detection['score']
        
        # Convert normalized bbox to pixel coordinates
        x1, y1, x2, y2 = bbox
        x1_px, y1_px = int(x1 * width), int(y1 * height)
        x2_px, y2_px = int(x2 * width), int(y2 * height)
        
        # Color for this instance
        color = colors[i % len(colors)]
        
        # Display mask if available and requested
        if show_masks and 'mask' in detection:
            mask = detection['mask']
            mask_color = np.zeros((height, width, 4), dtype=np.float32)
            mask_color[mask > 0] = [color[0], color[1], color[2], 0.5]
            ax.imshow(mask_color)
        
        # Draw bounding box if requested
        if show_boxes:
            rect = plt.Rectangle((x1_px, y1_px), x2_px - x1_px, y2_px - y1_px,
                                fill=False, edgecolor=color, linewidth=2)
            ax.add_patch(rect)
        
        # Add label and score if requested
        if show_labels:
            ax.text(x1_px, y1_px - 5, f"{label}: {score:.2f}",
                    color='white', bbox=dict(facecolor=color, alpha=0.8), fontsize=10)
    
    ax.axis('off')
    
    # Convert the rendered figure to a numpy array
    # (canvas.tostring_rgb() was removed in recent Matplotlib releases; buffer_rgba() is portable)
    fig.canvas.draw()
    result_array = np.asarray(fig.canvas.buffer_rgba())[:, :, :3].copy()
    
    plt.close(fig)  # Important: close the figure to free memory
    
    return result_array
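
# Example (a sketch, not executed): rendering and saving the overlay. "overlay.png" is a
# hypothetical output path; the returned array is RGB uint8.
#
#   viz = overlay_detections_on_image(img, dets_with_masks, show_masks=True)
#   Image.fromarray(viz).save("overlay.png")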

def get_single_prompt(user_input):
    """
    Uses OpenAI to rephrase the user's chatter into a single, concise prompt for object detection.
    The generated prompt will not include any question marks.
    """
    if not user_input.strip():
        user_input = "Detect objects in the image"
    
    prompt_instruction = (
        f"Based on the following user input, generate a single, concise prompt for object detection. "
        f"Do not include any question marks in the output. "
        f"User input: \"{user_input}\""
    )
    
    response = openai.chat.completions.create(
        model="gpt-4o",  # adjust model name if needed
        messages=[{"role": "user", "content": prompt_instruction}],
        temperature=0.3,
        max_tokens=50,
    )
    
    generated_prompt = response.choices[0].message.content.strip()
    # Ensure no question marks remain
    generated_prompt = generated_prompt.replace("?", "")
    return generated_prompt
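
# Example (a sketch, not executed; requires a valid OPENAI_API_KEY). The exact wording of
# the rewritten prompt varies between calls:
#
#   get_single_prompt("How many bottles are in my image?")
#   # -> e.g. "Detect and count the bottles in the image"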

def is_count_query(user_input):
    """
    Check if the user's input indicates a counting request.
    Looks for common keywords such as "count", "how many", "number of", etc.
    """
    keywords = ["count", "how many", "number of", "total", "get me a count"]
    for kw in keywords:
        if kw.lower() in user_input.lower():
            return True
    return False
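
# Example (not executed): the keyword check is case-insensitive.
#
#   is_count_query("How many bicycles can you see?")  # True
#   is_count_query("Highlight the dog")               # False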

def process_question_and_detect(user_input, image, threshold, use_sam):
    """
    1. Uses OpenAI to generate a single, concise prompt (without question marks) from the user's input.
    2. Feeds that prompt to the custom detection function.
    3. Optionally generates segmentation masks using SAM2.
    4. Overlays the detection results on the image.
    5. If the user's input implies a counting request, it also returns the count of detected objects.
    """
    if image is None:
        return None, "Please upload an image."
    
    try:
        # Generate the concise prompt from the user's input
        generated_prompt = get_single_prompt(user_input)
        
        # Run object detection using the generated prompt
        detections = detect_objects_owlv2(generated_prompt, image, threshold=threshold)
        
        # Generate masks if SAM is enabled
        if use_sam and len(detections) > 0:
            try:
                detections_with_masks = generate_masks_from_detections(detections, image)
            except Exception as e:
                print(f"SAM2 failed, using detections without masks: {e}")
                detections_with_masks = detections
        else:
            detections_with_masks = detections
        
        # Overlay results on the image
        viz = overlay_detections_on_image(image, detections_with_masks, 
                                        show_masks=use_sam, 
                                        show_boxes=True, 
                                        show_labels=True)
        
        # If the user's input implies a counting request, include the count
        count_text = ""
        if is_count_query(user_input):
            count = len(detections)
            count_text = f"Detected {count} objects."
        
        output_text = f"Generated prompt: {generated_prompt}\n{count_text}"
        if len(detections) == 0:
            output_text += f"\nNo objects detected with threshold {threshold}. Try lowering the threshold."
        
        print(output_text)
        return viz, output_text
        
    except Exception as e:
        error_msg = f"Error during detection: {str(e)}"
        print(error_msg)
        return image, error_msg
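
# Example (a sketch, not executed): using the full pipeline programmatically, outside Gradio.
# "shelf.jpg" is a hypothetical image; this call needs an OpenAI key and downloads the OWLv2
# (and, with use_sam=True, SAM2) weights on first use.
#
#   viz, details = process_question_and_detect(
#       "How many bottles can you see?",
#       np.array(Image.open("shelf.jpg")),
#       threshold=0.1,
#       use_sam=False,
#   )
#   print(details)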

# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Custom Object Detection and Counting App")
    gr.Markdown(
        """
        Enter a request, for example:
        - "What is the number of fruit in my image?"
        - "How many bicycles can you see?"
        - "Get me a count of my bottles"

        and upload an image.
        The app uses OpenAI to rewrite your request into a single, concise detection prompt,
        then runs zero-shot detection with OWLv2. Optionally, SAM2 can generate precise segmentation masks.
        """
    )
    
    with gr.Row():
        with gr.Column():
            user_input = gr.Textbox(label="Enter your input", placeholder="Type your input here...")
            image_input = gr.Image(label="Upload Image", type="numpy")
            
            with gr.Row():
                threshold_slider = gr.Slider(
                    minimum=0.01, 
                    maximum=1.0, 
                    value=0.1, 
                    step=0.01,
                    label="Detection Threshold",
                    info="Lower values detect more objects but may include false positives"
                )
                use_sam_checkbox = gr.Checkbox(
                    label="Use SAM2 for Segmentation", 
                    value=False,
                    info="Enable to generate precise segmentation masks (requires additional computation)"
                )
            
            submit_btn = gr.Button("Detect and Count")
        
        with gr.Column():
            output_image = gr.Image(label="Detection Result")
            output_text = gr.Textbox(label="Output Details", lines=3)
    
    submit_btn.click(
        fn=process_question_and_detect, 
        inputs=[user_input, image_input, threshold_slider, use_sam_checkbox], 
        outputs=[output_image, output_text]
    )

if __name__ == "__main__":
    demo.launch()