obichimav committed
Commit 3dd83a2 · verified · 1 Parent(s): 1249f8b

Update app.py

Files changed (1)
  1. app.py +651 -8
app.py CHANGED
@@ -1,3 +1,573 @@
  import gradio as gr
  import numpy as np
  import torch
@@ -9,6 +579,7 @@ from io import BytesIO
  import importlib.util
  import os
  import openai

  # Suppress warnings
  warnings.filterwarnings("ignore")
@@ -29,7 +600,75 @@ else:
  detector = None
  sam_predictor = None

- def load_detector():
  """Load the OWL-ViT detector once and cache it."""
  global detector
  if detector is None:
@@ -418,22 +1057,27 @@ def detection_pipeline(query_text, image, threshold, use_sam):
  # Run object detection
  detections, processed_image = detect_objects_owlv2(search_terms, image, threshold)

- print(f"Found {len(detections)} detections")
  for i, det in enumerate(detections):
- print(f"Detection {i+1}: {det['label']} (score: {det['score']:.3f})")

  # Generate masks if requested
  if use_sam and detections:
  print("Generating SAM2 masks...")
  detections = generate_masks_sam2(detections, processed_image)

- # Create visualization using your proven functions
  print("Creating visualization...")
  if use_sam and detections and 'mask' in detections[0]:
  result_image = visualize_detections_with_masks(
  processed_image,
  detections,
- show_labels=True,
  show_boxes=True
  )
  print("Created visualization with masks")
@@ -441,7 +1085,7 @@ def detection_pipeline(query_text, image, threshold, use_sam):
  result_image = visualize_detections(
  processed_image,
  detections,
- show_labels=True
  )
  print("Created visualization with bounding boxes only")

@@ -565,5 +1209,4 @@ with gr.Blocks(title="🔍 Object Detection & Segmentation") as demo:

  # Launch
  if __name__ == "__main__":
- demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
-
 
1
+ # import gradio as gr
2
+ # import numpy as np
3
+ # import torch
4
+ # from PIL import Image
5
+ # import matplotlib.pyplot as plt
6
+ # from transformers import pipeline
7
+ # import warnings
8
+ # from io import BytesIO
9
+ # import importlib.util
10
+ # import os
11
+ # import openai
12
+
13
+ # # Suppress warnings
14
+ # warnings.filterwarnings("ignore")
15
+
16
+ # # Set up OpenAI API key
17
+ # api_key = os.getenv('OPENAI_API_KEY')
18
+ # if not api_key:
19
+ # print("No OpenAI API key found - will use simple keyword extraction")
20
+ # elif not api_key.startswith("sk-proj-") and not api_key.startswith("sk-"):
21
+ # print("API key found but doesn't look correct")
22
+ # elif api_key.strip() != api_key:
23
+ # print("API key has leading or trailing whitespace - please fix it.")
24
+ # else:
25
+ # print("OpenAI API key found and looks good!")
26
+ # openai.api_key = api_key
27
+
28
+ # # Global variables for models
29
+ # detector = None
30
+ # sam_predictor = None
31
+
32
+ # def load_detector():
33
+ # """Load the OWL-ViT detector once and cache it."""
34
+ # global detector
35
+ # if detector is None:
36
+ # print("Loading OWL-ViT model...")
37
+ # detector = pipeline(
38
+ # model="google/owlv2-base-patch16-ensemble",
39
+ # task="zero-shot-object-detection",
40
+ # device=0 if torch.cuda.is_available() else -1
41
+ # )
42
+ # print("OWL-ViT model loaded successfully!")
43
+
44
+ # def install_sam2_if_needed():
45
+ # """Check if SAM2 is installed, and install it if needed."""
46
+ # if importlib.util.find_spec("sam2") is not None:
47
+ # print("SAM2 is already installed.")
48
+ # return True
49
+
50
+ # try:
51
+ # import subprocess
52
+ # import sys
53
+ # print("Installing SAM2 from GitHub...")
54
+ # subprocess.check_call([sys.executable, "-m", "pip", "install", "git+https://github.com/facebookresearch/sam2.git"])
55
+ # print("SAM2 installed successfully.")
56
+ # return True
57
+ # except Exception as e:
58
+ # print(f"Error installing SAM2: {e}")
59
+ # return False
60
+
61
+ # def load_sam_predictor():
62
+ # """Load SAM2 predictor if available."""
63
+ # global sam_predictor
64
+ # if sam_predictor is None:
65
+ # if install_sam2_if_needed():
66
+ # try:
67
+ # from sam2.sam2_image_predictor import SAM2ImagePredictor
68
+ # print("Loading SAM2 model...")
69
+ # sam_predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")
70
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
71
+ # sam_predictor.model.to(device)
72
+ # print(f"SAM2 model loaded successfully on {device}!")
73
+ # return True
74
+ # except Exception as e:
75
+ # print(f"Error loading SAM2: {e}")
76
+ # return False
77
+ # return sam_predictor is not None
78
+
79
+ # def detect_objects_owlv2(text_query, image, threshold=0.1):
80
+ # """Detect objects using OWL-ViT."""
81
+ # try:
82
+ # load_detector()
83
+
84
+ # if isinstance(image, np.ndarray):
85
+ # image = Image.fromarray(image)
86
+
87
+ # # Clean up the text query
88
+ # query_terms = [term.strip() for term in text_query.split(',') if term.strip()]
89
+ # if not query_terms:
90
+ # query_terms = ["object"]
91
+
92
+ # print(f"Detecting: {query_terms}")
93
+ # predictions = detector(image, candidate_labels=query_terms)
94
+
95
+ # detections = []
96
+ # for pred in predictions:
97
+ # if pred['score'] >= threshold:
98
+ # bbox = pred['box']
99
+ # width, height = image.size
100
+ # normalized_bbox = [
101
+ # bbox['xmin'] / width,
102
+ # bbox['ymin'] / height,
103
+ # bbox['xmax'] / width,
104
+ # bbox['ymax'] / height
105
+ # ]
106
+
107
+ # detection = {
108
+ # 'label': pred['label'],
109
+ # 'bbox': normalized_bbox,
110
+ # 'score': pred['score']
111
+ # }
112
+ # detections.append(detection)
113
+
114
+ # return detections, image
115
+ # except Exception as e:
116
+ # print(f"Detection error: {e}")
117
+ # return [], image
118
+
119
+ # def generate_masks_sam2(detections, image):
120
+ # """Generate segmentation masks using SAM2."""
121
+ # try:
122
+ # if not load_sam_predictor():
123
+ # print("SAM2 not available, skipping mask generation")
124
+ # return detections
125
+
126
+ # if isinstance(image, np.ndarray):
127
+ # image = Image.fromarray(image)
128
+
129
+ # image_np = np.array(image.convert("RGB"))
130
+ # H, W = image_np.shape[:2]
131
+
132
+ # # Set image for SAM2
133
+ # sam_predictor.set_image(image_np)
134
+
135
+ # # Convert normalized bboxes to pixel coordinates
136
+ # input_boxes = []
137
+ # for det in detections:
138
+ # x1, y1, x2, y2 = det['bbox']
139
+ # input_boxes.append([int(x1 * W), int(y1 * H), int(x2 * W), int(y2 * H)])
140
+
141
+ # if not input_boxes:
142
+ # return detections
143
+
144
+ # input_boxes = np.array(input_boxes)
145
+
146
+ # print(f"Generating masks for {len(input_boxes)} detections...")
147
+
148
+ # with torch.inference_mode():
149
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
150
+ # if device == "cuda":
151
+ # with torch.autocast("cuda", dtype=torch.bfloat16):
152
+ # masks, scores, _ = sam_predictor.predict(
153
+ # point_coords=None,
154
+ # point_labels=None,
155
+ # box=input_boxes,
156
+ # multimask_output=False
157
+ # )
158
+ # else:
159
+ # masks, scores, _ = sam_predictor.predict(
160
+ # point_coords=None,
161
+ # point_labels=None,
162
+ # box=input_boxes,
163
+ # multimask_output=False
164
+ # )
165
+
166
+ # # Add masks to detections
167
+ # results = []
168
+ # for i, det in enumerate(detections):
169
+ # new_det = det.copy()
170
+ # mask = masks[i]
171
+ # if mask.ndim == 3:
172
+ # mask = mask[0] # Remove batch dimension if present
173
+ # new_det['mask'] = mask.astype(np.uint8)
174
+ # results.append(new_det)
175
+
176
+ # print(f"Successfully generated {len(results)} masks")
177
+ # return results
178
+
179
+ # except Exception as e:
180
+ # print(f"SAM2 mask generation error: {e}")
181
+ # return detections
182
+
183
+ # def visualize_detections_with_masks(image, detections_with_masks, show_labels=True, show_boxes=True):
184
+ # """
185
+ # Visualize the detections with their segmentation masks.
186
+ # Returns PIL Image instead of showing plot.
187
+ # """
188
+ # # Load the image
189
+ # if isinstance(image, np.ndarray):
190
+ # image = Image.fromarray(image)
191
+ # image_np = np.array(image.convert("RGB"))
192
+
193
+ # # Get image dimensions
194
+ # height, width = image_np.shape[:2]
195
+
196
+ # # Create figure
197
+ # fig = plt.figure(figsize=(12, 8))
198
+ # plt.imshow(image_np)
199
+
200
+ # # Define colors for different instances
201
+ # colors = plt.cm.tab10(np.linspace(0, 1, 10))
202
+
203
+ # # Plot each detection
204
+ # for i, detection in enumerate(detections_with_masks):
205
+ # # Get bbox, mask, label, and score
206
+ # bbox = detection['bbox']
207
+ # label = detection['label']
208
+ # score = detection['score']
209
+
210
+ # # Convert normalized bbox to pixel coordinates
211
+ # x1, y1, x2, y2 = bbox
212
+ # x1_px, y1_px = int(x1 * width), int(y1 * height)
213
+ # x2_px, y2_px = int(x2 * width), int(y2 * height)
214
+
215
+ # # Color for this instance
216
+ # color = colors[i % len(colors)]
217
+
218
+ # # Display mask with transparency if available
219
+ # if 'mask' in detection:
220
+ # mask = detection['mask']
221
+ # mask_color = np.zeros((height, width, 4), dtype=np.float32)
222
+ # mask_color[mask > 0] = [color[0], color[1], color[2], 0.5]
223
+ # plt.imshow(mask_color)
224
+
225
+ # # Draw bounding box if requested
226
+ # if show_boxes:
227
+ # rect = plt.Rectangle((x1_px, y1_px), x2_px - x1_px, y2_px - y1_px,
228
+ # fill=False, edgecolor=color, linewidth=2)
229
+ # plt.gca().add_patch(rect)
230
+
231
+ # # Add label and score if requested
232
+ # if show_labels:
233
+ # plt.text(x1_px, y1_px - 5, f"{label}: {score:.2f}",
234
+ # color='white', bbox=dict(facecolor=color, alpha=0.8), fontsize=10)
235
+
236
+ # plt.axis('off')
237
+ # plt.tight_layout()
238
+
239
+ # # Convert to PIL Image using the correct method
240
+ # buf = BytesIO()
241
+ # plt.savefig(buf, format='png', bbox_inches='tight', dpi=150)
242
+ # plt.close(fig)
243
+ # buf.seek(0)
244
+
245
+ # result_image = Image.open(buf)
246
+ # return result_image
247
+
248
+ # def visualize_detections(image, detections, show_labels=True):
249
+ # """
250
+ # Visualize object detections with bounding boxes only.
251
+ # Returns PIL Image instead of showing plot.
252
+ # """
253
+ # # Load the image
254
+ # if isinstance(image, np.ndarray):
255
+ # image = Image.fromarray(image)
256
+ # image_np = np.array(image.convert("RGB"))
257
+
258
+ # # Get image dimensions
259
+ # height, width = image_np.shape[:2]
260
+
261
+ # # Create figure
262
+ # fig = plt.figure(figsize=(12, 8))
263
+ # plt.imshow(image_np)
264
+
265
+ # # If we have detections, draw them
266
+ # if detections:
267
+ # # Define colors for different instances
268
+ # colors = plt.cm.tab10(np.linspace(0, 1, 10))
269
+
270
+ # # Plot each detection
271
+ # for i, detection in enumerate(detections):
272
+ # # Get bbox, label, and score
273
+ # bbox = detection['bbox']
274
+ # label = detection['label']
275
+ # score = detection['score']
276
+
277
+ # # Convert normalized bbox to pixel coordinates
278
+ # x1, y1, x2, y2 = bbox
279
+ # x1_px, y1_px = int(x1 * width), int(y1 * height)
280
+ # x2_px, y2_px = int(x2 * width), int(y2 * height)
281
+
282
+ # # Color for this instance
283
+ # color = colors[i % len(colors)]
284
+
285
+ # # Draw bounding box
286
+ # rect = plt.Rectangle((x1_px, y1_px), x2_px - x1_px, y2_px - y1_px,
287
+ # fill=False, edgecolor=color, linewidth=2)
288
+ # plt.gca().add_patch(rect)
289
+
290
+ # # Add label and score if requested
291
+ # if show_labels:
292
+ # plt.text(x1_px, y1_px - 5, f"{label}: {score:.2f}",
293
+ # color='white', bbox=dict(facecolor=color, alpha=0.8), fontsize=10)
294
+
295
+ # # Set title
296
+ # plt.title(f'Object Detection Results ({len(detections)} objects found)', fontsize=14, pad=20)
297
+ # plt.axis('off')
298
+ # plt.tight_layout()
299
+
300
+ # # Convert to PIL Image
301
+ # buf = BytesIO()
302
+ # plt.savefig(buf, format='png', bbox_inches='tight', dpi=150)
303
+ # plt.close(fig)
304
+ # buf.seek(0)
305
+
306
+ # result_image = Image.open(buf)
307
+ # return result_image
308
+
309
+ # def get_optimized_prompt(query_text):
310
+ # """
311
+ # Use OpenAI to convert natural language query into optimal detection prompt.
312
+ # Falls back to simple extraction if OpenAI is not available.
313
+ # """
314
+ # if not query_text.strip():
315
+ # return "object"
316
+
317
+ # # Try OpenAI first if API key is available
318
+ # if hasattr(openai, 'api_key') and openai.api_key:
319
+ # try:
320
+ # response = openai.chat.completions.create(
321
+ # model="gpt-3.5-turbo",
322
+ # messages=[{
323
+ # "role": "system",
324
+ # "content": """You are an expert at converting natural language queries into precise object detection terms.
325
+
326
+ # RULES:
327
+ # 1. Return ONLY 1-2 words maximum that describe the object to detect
328
+ # 2. Use the exact object name from the user's query
329
+ # 3. For people: use "person"
330
+ # 4. For vehicles: use "car", "truck", "bicycle"
331
+ # 5. Do NOT include counting words, articles, or explanations
332
+ # 6. Examples:
333
+ # - "How many cacao fruits are there?" β†’ "cacao fruit"
334
+ # - "Count the corn in the field" β†’ "corn"
335
+ # - "Find all people" β†’ "person"
336
+ # - "How many cacao pods?" β†’ "cacao pod"
337
+ # - "Detect cars" β†’ "car"
338
+ # - "Count bananas" β†’ "banana"
339
+ # - "How many apples?" β†’ "apple"
340
+
341
+ # Return ONLY the object name, nothing else."""
342
+ # }, {
343
+ # "role": "user",
344
+ # "content": query_text
345
+ # }],
346
+ # temperature=0.0, # Make it deterministic
347
+ # max_tokens=5 # Force brevity
348
+ # )
349
+
350
+ # llm_result = response.choices[0].message.content.strip().lower()
351
+ # # Extra safety: take only first 2 words
352
+ # words = llm_result.split()[:2]
353
+ # final_result = " ".join(words)
354
+
355
+ # print(f"πŸ€– OpenAI suggested prompt: '{final_result}'")
356
+ # return final_result
357
+
358
+ # except Exception as e:
359
+ # print(f"OpenAI error: {e}, falling back to keyword extraction")
360
+
361
+ # # Fallback to simple keyword extraction (no hardcoded fruits)
362
+ # print("πŸ”€ Using keyword extraction (no OpenAI)")
363
+ # query_lower = query_text.lower().replace("?", "").strip()
364
+
365
+ # # Look for common patterns and extract object names
366
+ # if "how many" in query_lower:
367
+ # parts = query_lower.split("how many")
368
+ # if len(parts) > 1:
369
+ # remaining = parts[1].strip()
370
+ # remaining = remaining.replace("are", "").replace("in", "").replace("the", "").replace("image", "").replace("there", "").strip()
371
+ # # Take first meaningful word(s)
372
+ # words = remaining.split()[:2]
373
+ # search_terms = " ".join(words) if words else "object"
374
+ # else:
375
+ # search_terms = "object"
376
+ # elif "count" in query_lower:
377
+ # parts = query_lower.split("count")
378
+ # if len(parts) > 1:
379
+ # remaining = parts[1].strip()
380
+ # remaining = remaining.replace("the", "").replace("in", "").replace("image", "").strip()
381
+ # words = remaining.split()[:2]
382
+ # search_terms = " ".join(words) if words else "object"
383
+ # else:
384
+ # search_terms = "object"
385
+ # elif "find" in query_lower:
386
+ # parts = query_lower.split("find")
387
+ # if len(parts) > 1:
388
+ # remaining = parts[1].strip()
389
+ # remaining = remaining.replace("all", "").replace("the", "").replace("in", "").replace("image", "").strip()
390
+ # words = remaining.split()[:2]
391
+ # search_terms = " ".join(words) if words else "object"
392
+ # else:
393
+ # search_terms = "object"
394
+ # else:
395
+ # # Extract first 1-2 meaningful words from the query
396
+ # words = query_lower.split()
397
+ # meaningful_words = [w for w in words if w not in ["how", "many", "are", "in", "the", "image", "find", "count", "detect", "there", "this", "that", "a", "an"]]
398
+ # search_terms = " ".join(meaningful_words[:2]) if meaningful_words else "object"
399
+
400
+ # return search_terms
401
+
402
+ # def is_count_query(text):
403
+ # """Check if the query is asking for counting."""
404
+ # count_keywords = ["how many", "count", "number of", "total"]
405
+ # return any(keyword in text.lower() for keyword in count_keywords)
406
+
407
+ # def detection_pipeline(query_text, image, threshold, use_sam):
408
+ # """Main detection pipeline."""
409
+ # if image is None:
410
+ # return None, "⚠️ Please upload an image first!"
411
+
412
+ # try:
413
+ # # Use OpenAI or fallback to get optimized search terms
414
+ # search_terms = get_optimized_prompt(query_text)
415
+
416
+ # print(f"Processing query: '{query_text}' -> searching for: '{search_terms}'")
417
+
418
+ # # Run object detection
419
+ # detections, processed_image = detect_objects_owlv2(search_terms, image, threshold)
420
+
421
+ # print(f"Found {len(detections)} detections")
422
+ # for i, det in enumerate(detections):
423
+ # print(f"Detection {i+1}: {det['label']} (score: {det['score']:.3f})")
424
+
425
+ # # Generate masks if requested
426
+ # if use_sam and detections:
427
+ # print("Generating SAM2 masks...")
428
+ # detections = generate_masks_sam2(detections, processed_image)
429
+
430
+ # # Create visualization using your proven functions
431
+ # print("Creating visualization...")
432
+ # if use_sam and detections and 'mask' in detections[0]:
433
+ # result_image = visualize_detections_with_masks(
434
+ # processed_image,
435
+ # detections,
436
+ # show_labels=True,
437
+ # show_boxes=True
438
+ # )
439
+ # print("Created visualization with masks")
440
+ # else:
441
+ # result_image = visualize_detections(
442
+ # processed_image,
443
+ # detections,
444
+ # show_labels=True
445
+ # )
446
+ # print("Created visualization with bounding boxes only")
447
+
448
+ # # Make sure we have a valid result image
449
+ # if result_image is None:
450
+ # print("Warning: result_image is None, returning original image")
451
+ # result_image = processed_image
452
+
453
+ # # Generate summary
454
+ # count = len(detections)
455
+
456
+ # summary_parts = []
457
+ # summary_parts.append(f"πŸ—£οΈ **Original Query**: '{query_text}'")
458
+ # summary_parts.append(f"πŸ€– **AI-Optimized Search**: '{search_terms}'")
459
+ # summary_parts.append(f"βš™οΈ **Threshold**: {threshold}")
460
+ # summary_parts.append(f"🎭 **SAM2 Segmentation**: {'Enabled' if use_sam else 'Disabled'}")
461
+
462
+ # if count > 0:
463
+ # if is_count_query(query_text):
464
+ # summary_parts.append(f"πŸ”’ **Answer: {count} {search_terms}(s) found**")
465
+ # else:
466
+ # summary_parts.append(f"βœ… **Found {count} {search_terms}(s)**")
467
+
468
+ # # Show detection details
469
+ # for i, det in enumerate(detections[:5]): # Show first 5
470
+ # summary_parts.append(f" β€’ Detection {i+1}: {det['score']:.3f} confidence")
471
+ # if count > 5:
472
+ # summary_parts.append(f" β€’ ... and {count-5} more detections")
473
+ # else:
474
+ # summary_parts.append(f"❌ **No {search_terms}(s) detected**")
475
+ # summary_parts.append("πŸ’‘ Try lowering the threshold or using different terms")
476
+
477
+ # summary_text = "\n".join(summary_parts)
478
+
479
+ # return result_image, summary_text
480
+
481
+ # except Exception as e:
482
+ # error_msg = f"❌ **Error**: {str(e)}"
483
+ # return image, error_msg
484
+
485
+ # # ----------------
486
+ # # GRADIO INTERFACE
487
+ # # ----------------
488
+ # with gr.Blocks(title="πŸ” Object Detection & Segmentation") as demo:
489
+ # gr.Markdown("""
490
+ # # πŸ” Object Detection & Segmentation App
491
+
492
+ # **Simple and powerful object detection using OWL-ViT + SAM2**
493
+
494
+ # 1. **Enter your query** (e.g., "How many people?", "Find cars", "Count apples")
495
+ # 2. **Upload an image**
496
+ # 3. **Adjust detection sensitivity**
497
+ # 4. **Toggle SAM2 segmentation** for precise masks
498
+ # 5. **Click Detect!**
499
+ # """)
500
+
501
+ # with gr.Row():
502
+ # with gr.Column(scale=1):
503
+ # query_input = gr.Textbox(
504
+ # label="πŸ—£οΈ What do you want to detect?",
505
+ # placeholder="e.g., 'How many people are in the image?'",
506
+ # value="How many people are in the image?",
507
+ # lines=2
508
+ # )
509
+
510
+ # image_input = gr.Image(
511
+ # label="πŸ“Έ Upload your image",
512
+ # type="numpy"
513
+ # )
514
+
515
+ # with gr.Row():
516
+ # threshold_slider = gr.Slider(
517
+ # minimum=0.01,
518
+ # maximum=0.9,
519
+ # value=0.1,
520
+ # step=0.01,
521
+ # label="🎚️ Detection Sensitivity"
522
+ # )
523
+
524
+ # sam_checkbox = gr.Checkbox(
525
+ # label="🎭 Enable SAM2 Segmentation",
526
+ # value=False,
527
+ # info="Generate precise pixel masks"
528
+ # )
529
+
530
+ # detect_button = gr.Button("πŸ” Detect Objects!", variant="primary", size="lg")
531
+
532
+ # with gr.Column(scale=1):
533
+ # output_image = gr.Image(label="🎯 Detection Results")
534
+ # output_text = gr.Textbox(
535
+ # label="πŸ“Š Detection Summary",
536
+ # lines=12,
537
+ # show_copy_button=True
538
+ # )
539
+
540
+ # # Event handlers
541
+ # detect_button.click(
542
+ # fn=detection_pipeline,
543
+ # inputs=[query_input, image_input, threshold_slider, sam_checkbox],
544
+ # outputs=[output_image, output_text]
545
+ # )
546
+
547
+ # # Also trigger on Enter in text box
548
+ # query_input.submit(
549
+ # fn=detection_pipeline,
550
+ # inputs=[query_input, image_input, threshold_slider, sam_checkbox],
551
+ # outputs=[output_image, output_text]
552
+ # )
553
+
554
+ # # Examples section
555
+ # gr.Examples(
556
+ # examples=[
557
+ # ["How many people are in the image?", None, 0.1, False],
558
+ # ["Find all cars", None, 0.15, True],
559
+ # ["Count the bottles", None, 0.1, True],
560
+ # ["Detect dogs", None, 0.2, False],
561
+ # ["How many phones?", None, 0.15, True],
562
+ # ],
563
+ # inputs=[query_input, image_input, threshold_slider, sam_checkbox],
564
+ # )
565
+
566
+ # # Launch
567
+ # if __name__ == "__main__":
568
+ # demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
569
+
570
+
  import gradio as gr
  import numpy as np
  import torch

  import importlib.util
  import os
  import openai
+ from typing import List, Dict

  # Suppress warnings
  warnings.filterwarnings("ignore")

  detector = None
  sam_predictor = None

+ def calculate_bbox_area(bbox):
+     """Calculate the area of a normalized bounding box."""
+     x1, y1, x2, y2 = bbox
+     width = abs(x2 - x1)
+     height = abs(y2 - y1)
+     return width * height
+
+ def filter_bbox_outliers(detections: List[Dict],
+                          method: str = 'zscore',
+                          threshold: float = 2.0,
+                          min_score: float = 0.0) -> List[Dict]:
+     """
+     Filter out outlier bounding boxes based on their area.
+
+     Args:
+         detections: List of detection dictionaries with 'bbox', 'label', 'score'
+         method: 'iqr' (Interquartile Range) or 'zscore' (Z-score)
+         threshold: Multiplier for IQR method or Z-score threshold
+         min_score: Minimum confidence score to keep detection
+
+     Returns:
+         Filtered list of detections
+     """
+     if not detections:
+         return detections
+
+     # Filter by minimum score first
+     detections = [det for det in detections if det['score'] >= min_score]
+
+     if len(detections) <= 2:  # Need at least 3 detections for meaningful outlier removal
+         return detections
+
+     # Calculate areas for all bounding boxes
+     areas = [calculate_bbox_area(det['bbox']) for det in detections]
+     areas = np.array(areas)
+
+     if method == 'iqr':
+         # IQR method
+         q1 = np.percentile(areas, 25)
+         q3 = np.percentile(areas, 75)
+         iqr = q3 - q1
+
+         lower_bound = q1 - threshold * iqr
+         upper_bound = q3 + threshold * iqr
+
+         valid_indices = np.where((areas >= lower_bound) & (areas <= upper_bound))[0]
+
+     elif method == 'zscore':
+         # Z-score method
+         if np.std(areas) == 0:  # All areas are the same
+             return detections
+
+         mean_area = np.mean(areas)
+         std_area = np.std(areas)
+
+         z_scores = np.abs((areas - mean_area) / std_area)
+         valid_indices = np.where(z_scores <= threshold)[0]
+
+     else:
+         raise ValueError("Method must be 'iqr' or 'zscore'")
+
+     # Return filtered detections
+     filtered_detections = [detections[i] for i in valid_indices]
+
+     print(f"Original detections: {len(detections)}")
+     print(f"Filtered detections: {len(filtered_detections)}")
+     print(f"Removed {len(detections) - len(filtered_detections)} outliers")
+
+     return filtered_detections
  """Load the OWL-ViT detector once and cache it."""
  global detector
  if detector is None:
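For reference, a minimal sketch of how the new filter_bbox_outliers helper behaves with its default z-score settings. The toy detections below are invented for illustration and are not part of the commit; the sketch assumes calculate_bbox_area and filter_bbox_outliers from the hunk above are already in scope.

# Hypothetical sanity check for the new helpers (not part of the commit).
# Five plausible boxes (~0.01 normalized area) plus one oversized false positive (0.08).
toy_detections = [
    {'label': 'apple', 'score': 0.32, 'bbox': [0.10 + 0.12 * i, 0.10, 0.20 + 0.12 * i, 0.20]}
    for i in range(5)
]
toy_detections.append({'label': 'apple', 'score': 0.21, 'bbox': [0.50, 0.60, 0.90, 0.80]})

kept = filter_bbox_outliers(toy_detections, method='zscore', threshold=2.0)
# With np.std's default (population) standard deviation, the large box's |z| is
# roughly 2.24 > 2.0, so it should be dropped and five detections remain.
print(len(kept))  # expected: 5
print([round(calculate_bbox_area(d['bbox']), 3) for d in kept])  # expected: [0.01, 0.01, 0.01, 0.01, 0.01]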
 
  # Run object detection
  detections, processed_image = detect_objects_owlv2(search_terms, image, threshold)

+ print(f"Found {len(detections)} initial detections")
  for i, det in enumerate(detections):
+ print(f"Detection {i+1}: {det['label']} (score: {det['score']:.3f}, area: {calculate_bbox_area(det['bbox']):.6f})")
+
+ # Filter outliers before SAM2
+ if len(detections) > 2: # Only filter if we have enough detections
+ detections = filter_bbox_outliers(detections, method='zscore', threshold=2.0)
+ print(f"After outlier filtering: {len(detections)} detections remain")

  # Generate masks if requested
  if use_sam and detections:
  print("Generating SAM2 masks...")
  detections = generate_masks_sam2(detections, processed_image)

+ # Create visualization using your proven functions (labels OFF)
  print("Creating visualization...")
  if use_sam and detections and 'mask' in detections[0]:
  result_image = visualize_detections_with_masks(
  processed_image,
  detections,
+ show_labels=False, # Labels OFF
  show_boxes=True
  )
  print("Created visualization with masks")

  result_image = visualize_detections(
  processed_image,
  detections,
+ show_labels=False # Labels OFF
  )
  print("Created visualization with bounding boxes only")


  # Launch
  if __name__ == "__main__":
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
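The pipeline hunk above wires the filter in with method='zscore' and threshold=2.0. Worth noting: because np.std defaults to the population standard deviation, the largest attainable |z| among n areas is sqrt(n-1), so at a 2.0 threshold the z-score method cannot discard anything until at least six boxes are present, while the IQR method has no such floor. A small sketch with made-up areas, assuming filter_bbox_outliers from the diff is in scope:

# Hypothetical comparison of the two filtering methods (not part of the commit).
areas = [0.010, 0.011, 0.012, 0.013, 0.080]  # five boxes, one clearly oversized
boxes = [{'label': 'corn', 'score': 0.3, 'bbox': [0.0, 0.0, a, 1.0]} for a in areas]

# z-score, threshold 2.0 (the pipeline's setting): with only five areas the
# maximum attainable |z| is sqrt(4) = 2.0, so nothing is filtered here.
print(len(filter_bbox_outliers(boxes, method='zscore', threshold=2.0)))  # expected: 5

# IQR, threshold 1.5: Q1 = 0.011, Q3 = 0.013, IQR = 0.002, upper bound = 0.016,
# so the 0.080 box falls outside the bounds and is removed.
print(len(filter_bbox_outliers(boxes, method='iqr', threshold=1.5)))  # expected: 4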