obichimav committed on
Commit e957e17 · verified · 1 Parent(s): 3aab296

Update app.py

Files changed (1)
  1. app.py +365 -256
app.py CHANGED
@@ -1,179 +1,194 @@
- import os
- import openai
  import gradio as gr
  import numpy as np
  import torch
  from PIL import Image
  import matplotlib.pyplot as plt
- import importlib.util
  from transformers import pipeline
- import requests

- # Set your OpenAI API key (ensure the environment variable is set or replace with your key)
- openai.api_key = os.getenv("OPENAI_API_KEY", "your-openai-api-key-here")

  def install_sam2_if_needed():
-     """
-     Check if SAM2 is installed, and install it if needed.
-     """
      if importlib.util.find_spec("sam2") is not None:
          print("SAM2 is already installed.")
-         return

      try:
-         import pip
          print("Installing SAM2 from GitHub...")
-         pip.main(['install', 'git+https://github.com/facebookresearch/sam2.git'])
          print("SAM2 installed successfully.")
      except Exception as e:
          print(f"Error installing SAM2: {e}")
-         print("You may need to manually install SAM2: !pip install git+https://github.com/facebookresearch/sam2.git")
-         raise
-
- def detect_objects_owlv2(text_query, image, threshold=0.1):
-     """
-     Detect objects in an image using OWLv2 model.
-
-     Args:
-         text_query (str): Text description of objects to detect
-         image (PIL.Image or numpy.ndarray): Input image
-         threshold (float): Detection threshold
-
-     Returns:
-         list: List of detections with bbox, label, and score
-     """
-     # Initialize the OWL-ViT model
-     detector = pipeline(model="google/owlv2-base-patch16-ensemble", task="zero-shot-object-detection")
-
-     # Convert numpy array to PIL Image if needed
-     if isinstance(image, np.ndarray):
-         image = Image.fromarray(image)
-
-     # Run detection
-     predictions = detector(image, candidate_labels=[text_query])
-
-     # Filter by threshold and format results
-     detections = []
-     for pred in predictions:
-         if pred['score'] >= threshold:
-             bbox = pred['box']
-             # Normalize bbox coordinates (OWL-ViT returns absolute coordinates)
-             width, height = image.size
-             normalized_bbox = [
-                 bbox['xmin'] / width,
-                 bbox['ymin'] / height,
-                 bbox['xmax'] / width,
-                 bbox['ymax'] / height
-             ]
-
-             detection = {
-                 'label': pred['label'],
-                 'bbox': normalized_bbox,
-                 'score': pred['score']
-             }
-             detections.append(detection)
-
-     return detections

- def generate_masks_from_detections(detections, image, model_name="facebook/sam2-hiera-large"):
-     """
-     Generate segmentation masks for objects detected by OWLv2 using SAM2 from Hugging Face.
-
-     Args:
-         detections (list): List of detections [{'label': str, 'bbox': [x1, y1, x2, y2], 'score': float}, ...]
-         image (PIL.Image.Image or str): The image or path to the image to analyze
-         model_name (str): Hugging Face model name for SAM2.
-
-     Returns:
-         list: List of detections with added 'mask' arrays.
-     """
-     install_sam2_if_needed()
-     from sam2.sam2_image_predictor import SAM2ImagePredictor
-
-     # Load image
-     if isinstance(image, str):
-         image = Image.open(image)
-     elif isinstance(image, np.ndarray):
-         image = Image.fromarray(image)
-
-     image_np = np.array(image.convert("RGB"))
-     H, W = image_np.shape[:2]
-
-     device = "cuda" if torch.cuda.is_available() else "cpu"
-     print(f"Using device: {device}")
-     print(f"Loading SAM2 model from Hugging Face: {model_name}")
-     predictor = SAM2ImagePredictor.from_pretrained(model_name)
-     predictor.model.to(device)

-     # Convert normalized bboxes to pixels
-     input_boxes = []
-     for det in detections:
-         x1, y1, x2, y2 = det['bbox']
-         input_boxes.append([int(x1 * W), int(y1 * H), int(x2 * W), int(y2 * H)])
-     input_boxes = np.array(input_boxes)

-     print(f"Processing image and predicting masks for {len(input_boxes)} boxes...")
-     with torch.inference_mode():
-         predictor.set_image(image_np)
-         if device == "cuda":
-             with torch.autocast("cuda", dtype=torch.bfloat16):
-                 masks, scores, _ = predictor.predict(
-                     point_coords=None, point_labels=None,
-                     box=input_boxes, multimask_output=False
                  )
-         else:
-             masks, scores, _ = predictor.predict(
-                 point_coords=None, point_labels=None,
-                 box=input_boxes, multimask_output=False
-             )
-
-     # Attach masks to detections, handling both (1,H,W) and (H,W) outputs
-     results = []
-     for i, det in enumerate(detections):
-         raw = masks[i]
-         if raw.ndim == 3:
-             mask = raw[0]
-         else:
-             mask = raw
-         mask = mask.astype(np.uint8)
-
-         new_det = det.copy()
-         new_det['mask'] = mask
-         results.append(new_det)
-
-     print(f"Successfully generated {len(results)} masks.")
-     return results

- def overlay_detections_on_image(image, detections_with_masks, show_masks=True, show_boxes=True, show_labels=True):
      """
-     Overlay detections (boxes and/or masks) on the image and return as numpy array.
-
-     Args:
-         image: Input image (PIL.Image or numpy array)
-         detections_with_masks: List of detections with masks
-         show_masks: Whether to show segmentation masks
-         show_boxes: Whether to show bounding boxes
-         show_labels: Whether to show labels
-
-     Returns:
-         numpy.ndarray: Image with overlaid detections
      """
-     # Convert to PIL Image if needed
      if isinstance(image, np.ndarray):
          image = Image.fromarray(image)
-
      image_np = np.array(image.convert("RGB"))
      height, width = image_np.shape[:2]

-     # Create figure without displaying
-     fig, ax = plt.subplots(1, 1, figsize=(12, 8))
-     ax.imshow(image_np)

      # Define colors for different instances
      colors = plt.cm.tab10(np.linspace(0, 1, 10))

      # Plot each detection
      for i, detection in enumerate(detections_with_masks):
          bbox = detection['bbox']
          label = detection['label']
          score = detection['score']
@@ -186,170 +201,264 @@ def overlay_detections_on_image(image, detections_with_masks, show_masks=True, show_boxes=True, show_labels=True):
          # Color for this instance
          color = colors[i % len(colors)]

-         # Display mask if available and requested
-         if show_masks and 'mask' in detection:
              mask = detection['mask']
              mask_color = np.zeros((height, width, 4), dtype=np.float32)
              mask_color[mask > 0] = [color[0], color[1], color[2], 0.5]
-             ax.imshow(mask_color)

          # Draw bounding box if requested
          if show_boxes:
              rect = plt.Rectangle((x1_px, y1_px), x2_px - x1_px, y2_px - y1_px,
                                   fill=False, edgecolor=color, linewidth=2)
-             ax.add_patch(rect)

          # Add label and score if requested
          if show_labels:
-             ax.text(x1_px, y1_px - 5, f"{label}: {score:.2f}",
                      color='white', bbox=dict(facecolor=color, alpha=0.8), fontsize=10)

-     ax.axis('off')
-
-     # Convert plot to numpy array
-     fig.canvas.draw()
-     result_array = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
-     result_array = result_array.reshape(fig.canvas.get_width_height()[::-1] + (3,))

-     plt.close(fig)  # Important: close the figure to free memory

-     return result_array

- def get_single_prompt(user_input):
      """
-     Uses OpenAI to rephrase the user's chatter into a single, concise prompt for object detection.
-     The generated prompt will not include any question marks.
      """
-     if not user_input.strip():
-         user_input = "Detect objects in the image"

-     prompt_instruction = (
-         f"Based on the following user input, generate a single, concise prompt for object detection. "
-         f"Do not include any question marks in the output. "
-         f"User input: \"{user_input}\""
-     )

-     response = openai.chat.completions.create(
-         model="gpt-4o",  # adjust model name if needed
-         messages=[{"role": "user", "content": prompt_instruction}],
-         temperature=0.3,
-         max_tokens=50,
-     )

-     generated_prompt = response.choices[0].message.content.strip()
-     # Ensure no question marks remain
-     generated_prompt = generated_prompt.replace("?", "")
-     return generated_prompt

- def is_count_query(user_input):
-     """
-     Check if the user's input indicates a counting request.
-     Looks for common keywords such as "count", "how many", "number of", etc.
-     """
-     keywords = ["count", "how many", "number of", "total", "get me a count"]
-     for kw in keywords:
-         if kw.lower() in user_input.lower():
-             return True
-     return False

- def process_question_and_detect(user_input, image, threshold, use_sam):
-     """
-     1. Uses OpenAI to generate a single, concise prompt (without question marks) from the user's input.
-     2. Feeds that prompt to the custom detection function.
-     3. Optionally generates segmentation masks using SAM2.
-     4. Overlays the detection results on the image.
-     5. If the user's input implies a counting request, it also returns the count of detected objects.
-     """
      if image is None:
-         return None, "Please upload an image."

      try:
-         # Generate the concise prompt from the user's input
-         generated_prompt = get_single_prompt(user_input)

-         # Run object detection using the generated prompt
-         detections = detect_objects_owlv2(generated_prompt, image, threshold=threshold)

-         # Generate masks if SAM is enabled
-         if use_sam and len(detections) > 0:
-             try:
-                 detections_with_masks = generate_masks_from_detections(detections, image)
-             except Exception as e:
-                 print(f"SAM2 failed, using detections without masks: {e}")
-                 detections_with_masks = detections
          else:
-             detections_with_masks = detections

-         # Overlay results on the image
-         viz = overlay_detections_on_image(image, detections_with_masks,
-                                           show_masks=use_sam,
-                                           show_boxes=True,
-                                           show_labels=True)

-         # If the user's input implies a counting request, include the count
-         count_text = ""
-         if is_count_query(user_input):
-             count = len(detections)
-             count_text = f"Detected {count} objects."

-         output_text = f"Generated prompt: {generated_prompt}\n{count_text}"
-         if len(detections) == 0:
-             output_text += f"\nNo objects detected with threshold {threshold}. Try lowering the threshold."

-         print(output_text)
-         return viz, output_text

      except Exception as e:
-         error_msg = f"Error during detection: {str(e)}"
-         print(error_msg)
          return image, error_msg

- # Gradio interface
- with gr.Blocks() as demo:
-     gr.Markdown("# Custom Object Detection and Counting App")
-     gr.Markdown(
-         """
-         Enter your input (for example:
-         - "What is the number of fruit in my image?"
-         - "How many bicycles can you see?"
-         - "Get me a count of my bottles")
-         and upload an image.
-         The app uses OpenAI to generate a single, concise prompt for object detection (without question marks),
-         then runs the detection using OWL-ViT. Optionally, SAM2 can generate precise segmentation masks.
-         """
-     )

      with gr.Row():
-         with gr.Column():
-             user_input = gr.Textbox(label="Enter your input", placeholder="Type your input here...")
-             image_input = gr.Image(label="Upload Image", type="numpy")

              with gr.Row():
                  threshold_slider = gr.Slider(
-                     minimum=0.01,
-                     maximum=1.0,
-                     value=0.1,
                      step=0.01,
-                     label="Detection Threshold",
-                     info="Lower values detect more objects but may include false positives"
                  )
-                 use_sam_checkbox = gr.Checkbox(
-                     label="Use SAM2 for Segmentation",
                      value=False,
-                     info="Enable to generate precise segmentation masks (requires additional computation)"
                  )

-             submit_btn = gr.Button("Detect and Count")
-
-         with gr.Column():
-             output_image = gr.Image(label="Detection Result")
-             output_text = gr.Textbox(label="Output Details", lines=3)

-     submit_btn.click(
-         fn=process_question_and_detect,
-         inputs=[user_input, image_input, threshold_slider, use_sam_checkbox],
          outputs=[output_image, output_text]
      )

  if __name__ == "__main__":
-     demo.launch()
  import gradio as gr
  import numpy as np
  import torch
  from PIL import Image
  import matplotlib.pyplot as plt
  from transformers import pipeline
+ import warnings
+ from io import BytesIO
+ import importlib.util
+
+ # Suppress warnings
+ warnings.filterwarnings("ignore")
+
+ # Global variables for models
+ detector = None
+ sam_predictor = None

+ def load_detector():
+     """Load the OWL-ViT detector once and cache it."""
+     global detector
+     if detector is None:
+         print("Loading OWL-ViT model...")
+         detector = pipeline(
+             model="google/owlv2-base-patch16-ensemble",
+             task="zero-shot-object-detection",
+             device=0 if torch.cuda.is_available() else -1
+         )
+         print("OWL-ViT model loaded successfully!")

  def install_sam2_if_needed():
+     """Check if SAM2 is installed, and install it if needed."""
      if importlib.util.find_spec("sam2") is not None:
          print("SAM2 is already installed.")
+         return True

      try:
+         import subprocess
+         import sys
          print("Installing SAM2 from GitHub...")
+         subprocess.check_call([sys.executable, "-m", "pip", "install", "git+https://github.com/facebookresearch/sam2.git"])
          print("SAM2 installed successfully.")
+         return True
      except Exception as e:
          print(f"Error installing SAM2: {e}")
+         return False

+ def load_sam_predictor():
+     """Load SAM2 predictor if available."""
+     global sam_predictor
+     if sam_predictor is None:
+         if install_sam2_if_needed():
+             try:
+                 from sam2.sam2_image_predictor import SAM2ImagePredictor
+                 print("Loading SAM2 model...")
+                 sam_predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")
+                 device = "cuda" if torch.cuda.is_available() else "cpu"
+                 sam_predictor.model.to(device)
+                 print(f"SAM2 model loaded successfully on {device}!")
+                 return True
+             except Exception as e:
+                 print(f"Error loading SAM2: {e}")
+                 return False
+     return sam_predictor is not None

+ def detect_objects_owlv2(text_query, image, threshold=0.1):
+     """Detect objects using OWL-ViT."""
+     try:
+         load_detector()
+
+         if isinstance(image, np.ndarray):
+             image = Image.fromarray(image)
+
+         # Clean up the text query
+         query_terms = [term.strip() for term in text_query.split(',') if term.strip()]
+         if not query_terms:
+             query_terms = ["object"]
+
+         print(f"Detecting: {query_terms}")
+         predictions = detector(image, candidate_labels=query_terms)
+
+         detections = []
+         for pred in predictions:
+             if pred['score'] >= threshold:
+                 bbox = pred['box']
+                 width, height = image.size
+                 normalized_bbox = [
+                     bbox['xmin'] / width,
+                     bbox['ymin'] / height,
+                     bbox['xmax'] / width,
+                     bbox['ymax'] / height
+                 ]
+
+                 detection = {
+                     'label': pred['label'],
+                     'bbox': normalized_bbox,
+                     'score': pred['score']
+                 }
+                 detections.append(detection)
+
+         return detections, image
+     except Exception as e:
+         print(f"Detection error: {e}")
+         return [], image

+ def generate_masks_sam2(detections, image):
+     """Generate segmentation masks using SAM2."""
+     try:
+         if not load_sam_predictor():
+             print("SAM2 not available, skipping mask generation")
+             return detections
+
+         if isinstance(image, np.ndarray):
+             image = Image.fromarray(image)
+
+         image_np = np.array(image.convert("RGB"))
+         H, W = image_np.shape[:2]
+
+         # Set image for SAM2
+         sam_predictor.set_image(image_np)
+
+         # Convert normalized bboxes to pixel coordinates
+         input_boxes = []
+         for det in detections:
+             x1, y1, x2, y2 = det['bbox']
+             input_boxes.append([int(x1 * W), int(y1 * H), int(x2 * W), int(y2 * H)])
+
+         if not input_boxes:
+             return detections
+
+         input_boxes = np.array(input_boxes)
+
+         print(f"Generating masks for {len(input_boxes)} detections...")
+
+         with torch.inference_mode():
+             device = "cuda" if torch.cuda.is_available() else "cpu"
+             if device == "cuda":
+                 with torch.autocast("cuda", dtype=torch.bfloat16):
+                     masks, scores, _ = sam_predictor.predict(
+                         point_coords=None,
+                         point_labels=None,
+                         box=input_boxes,
+                         multimask_output=False
+                     )
+             else:
+                 masks, scores, _ = sam_predictor.predict(
+                     point_coords=None,
+                     point_labels=None,
+                     box=input_boxes,
+                     multimask_output=False
                  )
+
+         # Add masks to detections
+         results = []
+         for i, det in enumerate(detections):
+             new_det = det.copy()
+             mask = masks[i]
+             if mask.ndim == 3:
+                 mask = mask[0]  # Remove batch dimension if present
+             new_det['mask'] = mask.astype(np.uint8)
+             results.append(new_det)
+
+         print(f"Successfully generated {len(results)} masks")
+         return results
+
+     except Exception as e:
+         print(f"SAM2 mask generation error: {e}")
+         return detections

+ def visualize_detections_with_masks(image, detections_with_masks, show_labels=True, show_boxes=True):
      """
+     Visualize the detections with their segmentation masks.
+     Returns PIL Image instead of showing plot.
      """
+     # Load the image
      if isinstance(image, np.ndarray):
          image = Image.fromarray(image)
      image_np = np.array(image.convert("RGB"))
+
+     # Get image dimensions
      height, width = image_np.shape[:2]

+     # Create figure
+     fig = plt.figure(figsize=(12, 8))
+     plt.imshow(image_np)

      # Define colors for different instances
      colors = plt.cm.tab10(np.linspace(0, 1, 10))

      # Plot each detection
      for i, detection in enumerate(detections_with_masks):
+         # Get bbox, mask, label, and score
          bbox = detection['bbox']
          label = detection['label']
          score = detection['score']

          # Color for this instance
          color = colors[i % len(colors)]

+         # Display mask with transparency if available
+         if 'mask' in detection:
              mask = detection['mask']
              mask_color = np.zeros((height, width, 4), dtype=np.float32)
              mask_color[mask > 0] = [color[0], color[1], color[2], 0.5]
+             plt.imshow(mask_color)

          # Draw bounding box if requested
          if show_boxes:
              rect = plt.Rectangle((x1_px, y1_px), x2_px - x1_px, y2_px - y1_px,
                                   fill=False, edgecolor=color, linewidth=2)
+             plt.gca().add_patch(rect)

          # Add label and score if requested
          if show_labels:
+             plt.text(x1_px, y1_px - 5, f"{label}: {score:.2f}",
                      color='white', bbox=dict(facecolor=color, alpha=0.8), fontsize=10)

+     plt.axis('off')
+     plt.tight_layout()

+     # Convert to PIL Image using the correct method
+     buf = BytesIO()
+     plt.savefig(buf, format='png', bbox_inches='tight', dpi=150)
+     plt.close(fig)
+     buf.seek(0)

+     result_image = Image.open(buf)
+     return result_image

+ def visualize_detections(image, detections, show_labels=True):
      """
+     Visualize object detections with bounding boxes only.
+     Returns PIL Image instead of showing plot.
      """
+     # Load the image
+     if isinstance(image, np.ndarray):
+         image = Image.fromarray(image)
+     image_np = np.array(image.convert("RGB"))

+     # Get image dimensions
+     height, width = image_np.shape[:2]

+     # Create figure
+     fig = plt.figure(figsize=(12, 8))
+     plt.imshow(image_np)
+
+     # Define colors for different instances
+     colors = plt.cm.tab10(np.linspace(0, 1, 10))
+
+     # Plot each detection
+     for i, detection in enumerate(detections):
+         # Get bbox, label, and score
+         bbox = detection['bbox']
+         label = detection['label']
+         score = detection['score']
+
+         # Convert normalized bbox to pixel coordinates
+         x1, y1, x2, y2 = bbox
+         x1_px, y1_px = int(x1 * width), int(y1 * height)
+         x2_px, y2_px = int(x2 * width), int(y2 * height)
+
+         # Color for this instance
+         color = colors[i % len(colors)]
+
+         # Draw bounding box
+         rect = plt.Rectangle((x1_px, y1_px), x2_px - x1_px, y2_px - y1_px,
+                              fill=False, edgecolor=color, linewidth=2)
+         plt.gca().add_patch(rect)
+
+         # Add label and score if requested
+         if show_labels:
+             plt.text(x1_px, y1_px - 5, f"{label}: {score:.2f}",
+                      color='white', bbox=dict(facecolor=color, alpha=0.8), fontsize=10)
+
+     plt.axis('off')
+     plt.tight_layout()

+     # Convert to PIL Image
+     buf = BytesIO()
+     plt.savefig(buf, format='png', bbox_inches='tight', dpi=150)
+     plt.close(fig)
+     buf.seek(0)
+
+     result_image = Image.open(buf)
+     return result_image

+ def is_count_query(text):
+     """Check if the query is asking for counting."""
+     count_keywords = ["how many", "count", "number of", "total"]
+     return any(keyword in text.lower() for keyword in count_keywords)

+ def detection_pipeline(query_text, image, threshold, use_sam):
+     """Main detection pipeline."""
      if image is None:
+         return None, "⚠️ Please upload an image first!"

      try:
+         # Extract object name from query
+         query_lower = query_text.lower()

+         # Simple keyword extraction
+         if "people" in query_lower or "person" in query_lower:
+             search_terms = "person"
+         elif "car" in query_lower or "vehicle" in query_lower:
+             search_terms = "car"
+         elif "apple" in query_lower:
+             search_terms = "apple"
+         elif "bottle" in query_lower:
+             search_terms = "bottle"
+         elif "phone" in query_lower:
+             search_terms = "phone"
+         elif "dog" in query_lower:
+             search_terms = "dog"
+         elif "cat" in query_lower:
+             search_terms = "cat"
+         else:
+             # Extract last word as potential object
+             words = query_text.strip().split()
+             search_terms = words[-1] if words else "object"

+         print(f"Processing query: '{query_text}' -> searching for: '{search_terms}'")
+
+         # Run object detection
+         detections, processed_image = detect_objects_owlv2(search_terms, image, threshold)
+
+         # Generate masks if requested
+         if use_sam and detections:
+             detections = generate_masks_sam2(detections, processed_image)
+
+         # Create visualization using your proven functions
+         if use_sam and detections:
+             result_image = visualize_detections_with_masks(
+                 processed_image,
+                 detections,
+                 show_labels=True,
+                 show_boxes=True
+             )
          else:
+             result_image = visualize_detections(
+                 processed_image,
+                 detections,
+                 show_labels=True
+             )
+
+         # Generate summary
+         count = len(detections)

+         summary_parts = []
+         summary_parts.append(f"🔍 **Search Query**: '{query_text}'")
+         summary_parts.append(f"🎯 **Detected Object Type**: '{search_terms}'")
+         summary_parts.append(f"⚙️ **Threshold**: {threshold}")
+         summary_parts.append(f"🤖 **SAM2 Segmentation**: {'Enabled' if use_sam else 'Disabled'}")

+         if count > 0:
+             if is_count_query(query_text):
+                 summary_parts.append(f"🔢 **Answer: {count} {search_terms}(s) found**")
+             else:
+                 summary_parts.append(f"✅ **Found {count} {search_terms}(s)**")
+
+             # Show detection details
+             for i, det in enumerate(detections[:5]):  # Show first 5
+                 summary_parts.append(f"  • Detection {i+1}: {det['score']:.3f} confidence")
+             if count > 5:
+                 summary_parts.append(f"  • ... and {count-5} more detections")
+         else:
+             summary_parts.append(f"❌ **No {search_terms}(s) detected**")
+             summary_parts.append("💡 Try lowering the threshold or using different terms")

+         summary_text = "\n".join(summary_parts)

+         return result_image, summary_text

      except Exception as e:
+         error_msg = f"❌ **Error**: {str(e)}"
          return image, error_msg

+ # ----------------
+ # GRADIO INTERFACE
+ # ----------------
+ with gr.Blocks(title="🔍 Object Detection & Segmentation") as demo:
+     gr.Markdown("""
+     # 🔍 Object Detection & Segmentation App
+
+     **Simple and powerful object detection using OWL-ViT + SAM2**
+
+     1. **Enter your query** (e.g., "How many people?", "Find cars", "Count apples")
+     2. **Upload an image**
+     3. **Adjust detection sensitivity**
+     4. **Toggle SAM2 segmentation** for precise masks
+     5. **Click Detect!**
+     """)

      with gr.Row():
+         with gr.Column(scale=1):
+             query_input = gr.Textbox(
+                 label="🗣️ What do you want to detect?",
+                 placeholder="e.g., 'How many people are in the image?'",
+                 value="How many people are in the image?",
+                 lines=2
+             )
+
+             image_input = gr.Image(
+                 label="📸 Upload your image",
+                 type="numpy"
+             )

              with gr.Row():
                  threshold_slider = gr.Slider(
+                     minimum=0.01,
+                     maximum=0.9,
+                     value=0.1,
                      step=0.01,
+                     label="🎚️ Detection Sensitivity"
                  )
+
+             sam_checkbox = gr.Checkbox(
+                 label="🎭 Enable SAM2 Segmentation",
                  value=False,
+                 info="Generate precise pixel masks"
              )

+             detect_button = gr.Button("🔍 Detect Objects!", variant="primary", size="lg")

+         with gr.Column(scale=1):
+             output_image = gr.Image(label="🎯 Detection Results")
+             output_text = gr.Textbox(
+                 label="📊 Detection Summary",
+                 lines=12,
+                 show_copy_button=True
+             )
+
+     # Event handlers
+     detect_button.click(
+         fn=detection_pipeline,
+         inputs=[query_input, image_input, threshold_slider, sam_checkbox],
          outputs=[output_image, output_text]
      )
+
+     # Also trigger on Enter in text box
+     query_input.submit(
+         fn=detection_pipeline,
+         inputs=[query_input, image_input, threshold_slider, sam_checkbox],
+         outputs=[output_image, output_text]
+     )
+
+     # Examples section
+     gr.Examples(
+         examples=[
+             ["How many people are in the image?", None, 0.1, False],
+             ["Find all cars", None, 0.15, True],
+             ["Count the bottles", None, 0.1, True],
+             ["Detect dogs", None, 0.2, False],
+             ["How many phones?", None, 0.15, True],
+         ],
+         inputs=[query_input, image_input, threshold_slider, sam_checkbox],
+     )

+ # Launch
  if __name__ == "__main__":
+     demo.launch(server_name="0.0.0.0", server_port=7860, share=True)