obichimav committed · Commit ed0c337 · verified · 1 Parent(s): 3dd83a2

Update app.py

Files changed (1):
  1. app.py  +61 -654

app.py CHANGED
@@ -1,573 +1,3 @@
1
- # import gradio as gr
2
- # import numpy as np
3
- # import torch
4
- # from PIL import Image
5
- # import matplotlib.pyplot as plt
6
- # from transformers import pipeline
7
- # import warnings
8
- # from io import BytesIO
9
- # import importlib.util
10
- # import os
11
- # import openai
12
-
13
- # # Suppress warnings
14
- # warnings.filterwarnings("ignore")
15
-
16
- # # Set up OpenAI API key
17
- # api_key = os.getenv('OPENAI_API_KEY')
18
- # if not api_key:
19
- # print("No OpenAI API key found - will use simple keyword extraction")
20
- # elif not api_key.startswith("sk-proj-") and not api_key.startswith("sk-"):
21
- # print("API key found but doesn't look correct")
22
- # elif api_key.strip() != api_key:
23
- # print("API key has leading or trailing whitespace - please fix it.")
24
- # else:
25
- # print("OpenAI API key found and looks good!")
26
- # openai.api_key = api_key
27
-
28
- # # Global variables for models
29
- # detector = None
30
- # sam_predictor = None
31
-
32
- # def load_detector():
33
- # """Load the OWL-ViT detector once and cache it."""
34
- # global detector
35
- # if detector is None:
36
- # print("Loading OWL-ViT model...")
37
- # detector = pipeline(
38
- # model="google/owlv2-base-patch16-ensemble",
39
- # task="zero-shot-object-detection",
40
- # device=0 if torch.cuda.is_available() else -1
41
- # )
42
- # print("OWL-ViT model loaded successfully!")
43
-
44
- # def install_sam2_if_needed():
45
- # """Check if SAM2 is installed, and install it if needed."""
46
- # if importlib.util.find_spec("sam2") is not None:
47
- # print("SAM2 is already installed.")
48
- # return True
49
-
50
- # try:
51
- # import subprocess
52
- # import sys
53
- # print("Installing SAM2 from GitHub...")
54
- # subprocess.check_call([sys.executable, "-m", "pip", "install", "git+https://github.com/facebookresearch/sam2.git"])
55
- # print("SAM2 installed successfully.")
56
- # return True
57
- # except Exception as e:
58
- # print(f"Error installing SAM2: {e}")
59
- # return False
60
-
61
- # def load_sam_predictor():
62
- # """Load SAM2 predictor if available."""
63
- # global sam_predictor
64
- # if sam_predictor is None:
65
- # if install_sam2_if_needed():
66
- # try:
67
- # from sam2.sam2_image_predictor import SAM2ImagePredictor
68
- # print("Loading SAM2 model...")
69
- # sam_predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")
70
- # device = "cuda" if torch.cuda.is_available() else "cpu"
71
- # sam_predictor.model.to(device)
72
- # print(f"SAM2 model loaded successfully on {device}!")
73
- # return True
74
- # except Exception as e:
75
- # print(f"Error loading SAM2: {e}")
76
- # return False
77
- # return sam_predictor is not None
78
-
79
- # def detect_objects_owlv2(text_query, image, threshold=0.1):
80
- # """Detect objects using OWL-ViT."""
81
- # try:
82
- # load_detector()
83
-
84
- # if isinstance(image, np.ndarray):
85
- # image = Image.fromarray(image)
86
-
87
- # # Clean up the text query
88
- # query_terms = [term.strip() for term in text_query.split(',') if term.strip()]
89
- # if not query_terms:
90
- # query_terms = ["object"]
91
-
92
- # print(f"Detecting: {query_terms}")
93
- # predictions = detector(image, candidate_labels=query_terms)
94
-
95
- # detections = []
96
- # for pred in predictions:
97
- # if pred['score'] >= threshold:
98
- # bbox = pred['box']
99
- # width, height = image.size
100
- # normalized_bbox = [
101
- # bbox['xmin'] / width,
102
- # bbox['ymin'] / height,
103
- # bbox['xmax'] / width,
104
- # bbox['ymax'] / height
105
- # ]
106
-
107
- # detection = {
108
- # 'label': pred['label'],
109
- # 'bbox': normalized_bbox,
110
- # 'score': pred['score']
111
- # }
112
- # detections.append(detection)
113
-
114
- # return detections, image
115
- # except Exception as e:
116
- # print(f"Detection error: {e}")
117
- # return [], image
118
-
119
- # def generate_masks_sam2(detections, image):
120
- # """Generate segmentation masks using SAM2."""
121
- # try:
122
- # if not load_sam_predictor():
123
- # print("SAM2 not available, skipping mask generation")
124
- # return detections
125
-
126
- # if isinstance(image, np.ndarray):
127
- # image = Image.fromarray(image)
128
-
129
- # image_np = np.array(image.convert("RGB"))
130
- # H, W = image_np.shape[:2]
131
-
132
- # # Set image for SAM2
133
- # sam_predictor.set_image(image_np)
134
-
135
- # # Convert normalized bboxes to pixel coordinates
136
- # input_boxes = []
137
- # for det in detections:
138
- # x1, y1, x2, y2 = det['bbox']
139
- # input_boxes.append([int(x1 * W), int(y1 * H), int(x2 * W), int(y2 * H)])
140
-
141
- # if not input_boxes:
142
- # return detections
143
-
144
- # input_boxes = np.array(input_boxes)
145
-
146
- # print(f"Generating masks for {len(input_boxes)} detections...")
147
-
148
- # with torch.inference_mode():
149
- # device = "cuda" if torch.cuda.is_available() else "cpu"
150
- # if device == "cuda":
151
- # with torch.autocast("cuda", dtype=torch.bfloat16):
152
- # masks, scores, _ = sam_predictor.predict(
153
- # point_coords=None,
154
- # point_labels=None,
155
- # box=input_boxes,
156
- # multimask_output=False
157
- # )
158
- # else:
159
- # masks, scores, _ = sam_predictor.predict(
160
- # point_coords=None,
161
- # point_labels=None,
162
- # box=input_boxes,
163
- # multimask_output=False
164
- # )
165
-
166
- # # Add masks to detections
167
- # results = []
168
- # for i, det in enumerate(detections):
169
- # new_det = det.copy()
170
- # mask = masks[i]
171
- # if mask.ndim == 3:
172
- # mask = mask[0] # Remove batch dimension if present
173
- # new_det['mask'] = mask.astype(np.uint8)
174
- # results.append(new_det)
175
-
176
- # print(f"Successfully generated {len(results)} masks")
177
- # return results
178
-
179
- # except Exception as e:
180
- # print(f"SAM2 mask generation error: {e}")
181
- # return detections
182
-
183
- # def visualize_detections_with_masks(image, detections_with_masks, show_labels=True, show_boxes=True):
184
- # """
185
- # Visualize the detections with their segmentation masks.
186
- # Returns PIL Image instead of showing plot.
187
- # """
188
- # # Load the image
189
- # if isinstance(image, np.ndarray):
190
- # image = Image.fromarray(image)
191
- # image_np = np.array(image.convert("RGB"))
192
-
193
- # # Get image dimensions
194
- # height, width = image_np.shape[:2]
195
-
196
- # # Create figure
197
- # fig = plt.figure(figsize=(12, 8))
198
- # plt.imshow(image_np)
199
-
200
- # # Define colors for different instances
201
- # colors = plt.cm.tab10(np.linspace(0, 1, 10))
202
-
203
- # # Plot each detection
204
- # for i, detection in enumerate(detections_with_masks):
205
- # # Get bbox, mask, label, and score
206
- # bbox = detection['bbox']
207
- # label = detection['label']
208
- # score = detection['score']
209
-
210
- # # Convert normalized bbox to pixel coordinates
211
- # x1, y1, x2, y2 = bbox
212
- # x1_px, y1_px = int(x1 * width), int(y1 * height)
213
- # x2_px, y2_px = int(x2 * width), int(y2 * height)
214
-
215
- # # Color for this instance
216
- # color = colors[i % len(colors)]
217
-
218
- # # Display mask with transparency if available
219
- # if 'mask' in detection:
220
- # mask = detection['mask']
221
- # mask_color = np.zeros((height, width, 4), dtype=np.float32)
222
- # mask_color[mask > 0] = [color[0], color[1], color[2], 0.5]
223
- # plt.imshow(mask_color)
224
-
225
- # # Draw bounding box if requested
226
- # if show_boxes:
227
- # rect = plt.Rectangle((x1_px, y1_px), x2_px - x1_px, y2_px - y1_px,
228
- # fill=False, edgecolor=color, linewidth=2)
229
- # plt.gca().add_patch(rect)
230
-
231
- # # Add label and score if requested
232
- # if show_labels:
233
- # plt.text(x1_px, y1_px - 5, f"{label}: {score:.2f}",
234
- # color='white', bbox=dict(facecolor=color, alpha=0.8), fontsize=10)
235
-
236
- # plt.axis('off')
237
- # plt.tight_layout()
238
-
239
- # # Convert to PIL Image using the correct method
240
- # buf = BytesIO()
241
- # plt.savefig(buf, format='png', bbox_inches='tight', dpi=150)
242
- # plt.close(fig)
243
- # buf.seek(0)
244
-
245
- # result_image = Image.open(buf)
246
- # return result_image
247
-
248
- # def visualize_detections(image, detections, show_labels=True):
249
- # """
250
- # Visualize object detections with bounding boxes only.
251
- # Returns PIL Image instead of showing plot.
252
- # """
253
- # # Load the image
254
- # if isinstance(image, np.ndarray):
255
- # image = Image.fromarray(image)
256
- # image_np = np.array(image.convert("RGB"))
257
-
258
- # # Get image dimensions
259
- # height, width = image_np.shape[:2]
260
-
261
- # # Create figure
262
- # fig = plt.figure(figsize=(12, 8))
263
- # plt.imshow(image_np)
264
-
265
- # # If we have detections, draw them
266
- # if detections:
267
- # # Define colors for different instances
268
- # colors = plt.cm.tab10(np.linspace(0, 1, 10))
269
-
270
- # # Plot each detection
271
- # for i, detection in enumerate(detections):
272
- # # Get bbox, label, and score
273
- # bbox = detection['bbox']
274
- # label = detection['label']
275
- # score = detection['score']
276
-
277
- # # Convert normalized bbox to pixel coordinates
278
- # x1, y1, x2, y2 = bbox
279
- # x1_px, y1_px = int(x1 * width), int(y1 * height)
280
- # x2_px, y2_px = int(x2 * width), int(y2 * height)
281
-
282
- # # Color for this instance
283
- # color = colors[i % len(colors)]
284
-
285
- # # Draw bounding box
286
- # rect = plt.Rectangle((x1_px, y1_px), x2_px - x1_px, y2_px - y1_px,
287
- # fill=False, edgecolor=color, linewidth=2)
288
- # plt.gca().add_patch(rect)
289
-
290
- # # Add label and score if requested
291
- # if show_labels:
292
- # plt.text(x1_px, y1_px - 5, f"{label}: {score:.2f}",
293
- # color='white', bbox=dict(facecolor=color, alpha=0.8), fontsize=10)
294
-
295
- # # Set title
296
- # plt.title(f'Object Detection Results ({len(detections)} objects found)', fontsize=14, pad=20)
297
- # plt.axis('off')
298
- # plt.tight_layout()
299
-
300
- # # Convert to PIL Image
301
- # buf = BytesIO()
302
- # plt.savefig(buf, format='png', bbox_inches='tight', dpi=150)
303
- # plt.close(fig)
304
- # buf.seek(0)
305
-
306
- # result_image = Image.open(buf)
307
- # return result_image
308
-
309
- # def get_optimized_prompt(query_text):
310
- # """
311
- # Use OpenAI to convert natural language query into optimal detection prompt.
312
- # Falls back to simple extraction if OpenAI is not available.
313
- # """
314
- # if not query_text.strip():
315
- # return "object"
316
-
317
- # # Try OpenAI first if API key is available
318
- # if hasattr(openai, 'api_key') and openai.api_key:
319
- # try:
320
- # response = openai.chat.completions.create(
321
- # model="gpt-3.5-turbo",
322
- # messages=[{
323
- # "role": "system",
324
- # "content": """You are an expert at converting natural language queries into precise object detection terms.
325
-
326
- # RULES:
327
- # 1. Return ONLY 1-2 words maximum that describe the object to detect
328
- # 2. Use the exact object name from the user's query
329
- # 3. For people: use "person"
330
- # 4. For vehicles: use "car", "truck", "bicycle"
331
- # 5. Do NOT include counting words, articles, or explanations
332
- # 6. Examples:
333
- # - "How many cacao fruits are there?" → "cacao fruit"
334
- # - "Count the corn in the field" → "corn"
335
- # - "Find all people" → "person"
336
- # - "How many cacao pods?" → "cacao pod"
337
- # - "Detect cars" → "car"
338
- # - "Count bananas" → "banana"
339
- # - "How many apples?" → "apple"
340
-
341
- # Return ONLY the object name, nothing else."""
342
- # }, {
343
- # "role": "user",
344
- # "content": query_text
345
- # }],
346
- # temperature=0.0, # Make it deterministic
347
- # max_tokens=5 # Force brevity
348
- # )
349
-
350
- # llm_result = response.choices[0].message.content.strip().lower()
351
- # # Extra safety: take only first 2 words
352
- # words = llm_result.split()[:2]
353
- # final_result = " ".join(words)
354
-
355
- # print(f"🤖 OpenAI suggested prompt: '{final_result}'")
356
- # return final_result
357
-
358
- # except Exception as e:
359
- # print(f"OpenAI error: {e}, falling back to keyword extraction")
360
-
361
- # # Fallback to simple keyword extraction (no hardcoded fruits)
362
- # print("🔤 Using keyword extraction (no OpenAI)")
363
- # query_lower = query_text.lower().replace("?", "").strip()
364
-
365
- # # Look for common patterns and extract object names
366
- # if "how many" in query_lower:
367
- # parts = query_lower.split("how many")
368
- # if len(parts) > 1:
369
- # remaining = parts[1].strip()
370
- # remaining = remaining.replace("are", "").replace("in", "").replace("the", "").replace("image", "").replace("there", "").strip()
371
- # # Take first meaningful word(s)
372
- # words = remaining.split()[:2]
373
- # search_terms = " ".join(words) if words else "object"
374
- # else:
375
- # search_terms = "object"
376
- # elif "count" in query_lower:
377
- # parts = query_lower.split("count")
378
- # if len(parts) > 1:
379
- # remaining = parts[1].strip()
380
- # remaining = remaining.replace("the", "").replace("in", "").replace("image", "").strip()
381
- # words = remaining.split()[:2]
382
- # search_terms = " ".join(words) if words else "object"
383
- # else:
384
- # search_terms = "object"
385
- # elif "find" in query_lower:
386
- # parts = query_lower.split("find")
387
- # if len(parts) > 1:
388
- # remaining = parts[1].strip()
389
- # remaining = remaining.replace("all", "").replace("the", "").replace("in", "").replace("image", "").strip()
390
- # words = remaining.split()[:2]
391
- # search_terms = " ".join(words) if words else "object"
392
- # else:
393
- # search_terms = "object"
394
- # else:
395
- # # Extract first 1-2 meaningful words from the query
396
- # words = query_lower.split()
397
- # meaningful_words = [w for w in words if w not in ["how", "many", "are", "in", "the", "image", "find", "count", "detect", "there", "this", "that", "a", "an"]]
398
- # search_terms = " ".join(meaningful_words[:2]) if meaningful_words else "object"
399
-
400
- # return search_terms
401
-
402
- # def is_count_query(text):
403
- # """Check if the query is asking for counting."""
404
- # count_keywords = ["how many", "count", "number of", "total"]
405
- # return any(keyword in text.lower() for keyword in count_keywords)
406
-
407
- # def detection_pipeline(query_text, image, threshold, use_sam):
408
- # """Main detection pipeline."""
409
- # if image is None:
410
- # return None, "⚠️ Please upload an image first!"
411
-
412
- # try:
413
- # # Use OpenAI or fallback to get optimized search terms
414
- # search_terms = get_optimized_prompt(query_text)
415
-
416
- # print(f"Processing query: '{query_text}' -> searching for: '{search_terms}'")
417
-
418
- # # Run object detection
419
- # detections, processed_image = detect_objects_owlv2(search_terms, image, threshold)
420
-
421
- # print(f"Found {len(detections)} detections")
422
- # for i, det in enumerate(detections):
423
- # print(f"Detection {i+1}: {det['label']} (score: {det['score']:.3f})")
424
-
425
- # # Generate masks if requested
426
- # if use_sam and detections:
427
- # print("Generating SAM2 masks...")
428
- # detections = generate_masks_sam2(detections, processed_image)
429
-
430
- # # Create visualization using your proven functions
431
- # print("Creating visualization...")
432
- # if use_sam and detections and 'mask' in detections[0]:
433
- # result_image = visualize_detections_with_masks(
434
- # processed_image,
435
- # detections,
436
- # show_labels=True,
437
- # show_boxes=True
438
- # )
439
- # print("Created visualization with masks")
440
- # else:
441
- # result_image = visualize_detections(
442
- # processed_image,
443
- # detections,
444
- # show_labels=True
445
- # )
446
- # print("Created visualization with bounding boxes only")
447
-
448
- # # Make sure we have a valid result image
449
- # if result_image is None:
450
- # print("Warning: result_image is None, returning original image")
451
- # result_image = processed_image
452
-
453
- # # Generate summary
454
- # count = len(detections)
455
-
456
- # summary_parts = []
457
- # summary_parts.append(f"🗣️ **Original Query**: '{query_text}'")
458
- # summary_parts.append(f"🤖 **AI-Optimized Search**: '{search_terms}'")
459
- # summary_parts.append(f"⚙️ **Threshold**: {threshold}")
460
- # summary_parts.append(f"🎭 **SAM2 Segmentation**: {'Enabled' if use_sam else 'Disabled'}")
461
-
462
- # if count > 0:
463
- # if is_count_query(query_text):
464
- # summary_parts.append(f"🔢 **Answer: {count} {search_terms}(s) found**")
465
- # else:
466
- # summary_parts.append(f"✅ **Found {count} {search_terms}(s)**")
467
-
468
- # # Show detection details
469
- # for i, det in enumerate(detections[:5]): # Show first 5
470
- # summary_parts.append(f" • Detection {i+1}: {det['score']:.3f} confidence")
471
- # if count > 5:
472
- # summary_parts.append(f" • ... and {count-5} more detections")
473
- # else:
474
- # summary_parts.append(f"❌ **No {search_terms}(s) detected**")
475
- # summary_parts.append("💡 Try lowering the threshold or using different terms")
476
-
477
- # summary_text = "\n".join(summary_parts)
478
-
479
- # return result_image, summary_text
480
-
481
- # except Exception as e:
482
- # error_msg = f"❌ **Error**: {str(e)}"
483
- # return image, error_msg
484
-
485
- # # ----------------
486
- # # GRADIO INTERFACE
487
- # # ----------------
488
- # with gr.Blocks(title="🔍 Object Detection & Segmentation") as demo:
489
- # gr.Markdown("""
490
- # # 🔍 Object Detection & Segmentation App
491
-
492
- # **Simple and powerful object detection using OWL-ViT + SAM2**
493
-
494
- # 1. **Enter your query** (e.g., "How many people?", "Find cars", "Count apples")
495
- # 2. **Upload an image**
496
- # 3. **Adjust detection sensitivity**
497
- # 4. **Toggle SAM2 segmentation** for precise masks
498
- # 5. **Click Detect!**
499
- # """)
500
-
501
- # with gr.Row():
502
- # with gr.Column(scale=1):
503
- # query_input = gr.Textbox(
504
- # label="🗣️ What do you want to detect?",
505
- # placeholder="e.g., 'How many people are in the image?'",
506
- # value="How many people are in the image?",
507
- # lines=2
508
- # )
509
-
510
- # image_input = gr.Image(
511
- # label="📸 Upload your image",
512
- # type="numpy"
513
- # )
514
-
515
- # with gr.Row():
516
- # threshold_slider = gr.Slider(
517
- # minimum=0.01,
518
- # maximum=0.9,
519
- # value=0.1,
520
- # step=0.01,
521
- # label="🎚️ Detection Sensitivity"
522
- # )
523
-
524
- # sam_checkbox = gr.Checkbox(
525
- # label="🎭 Enable SAM2 Segmentation",
526
- # value=False,
527
- # info="Generate precise pixel masks"
528
- # )
529
-
530
- # detect_button = gr.Button("🔍 Detect Objects!", variant="primary", size="lg")
531
-
532
- # with gr.Column(scale=1):
533
- # output_image = gr.Image(label="🎯 Detection Results")
534
- # output_text = gr.Textbox(
535
- # label="📊 Detection Summary",
536
- # lines=12,
537
- # show_copy_button=True
538
- # )
539
-
540
- # # Event handlers
541
- # detect_button.click(
542
- # fn=detection_pipeline,
543
- # inputs=[query_input, image_input, threshold_slider, sam_checkbox],
544
- # outputs=[output_image, output_text]
545
- # )
546
-
547
- # # Also trigger on Enter in text box
548
- # query_input.submit(
549
- # fn=detection_pipeline,
550
- # inputs=[query_input, image_input, threshold_slider, sam_checkbox],
551
- # outputs=[output_image, output_text]
552
- # )
553
-
554
- # # Examples section
555
- # gr.Examples(
556
- # examples=[
557
- # ["How many people are in the image?", None, 0.1, False],
558
- # ["Find all cars", None, 0.15, True],
559
- # ["Count the bottles", None, 0.1, True],
560
- # ["Detect dogs", None, 0.2, False],
561
- # ["How many phones?", None, 0.15, True],
562
- # ],
563
- # inputs=[query_input, image_input, threshold_slider, sam_checkbox],
564
- # )
565
-
566
- # # Launch
567
- # if __name__ == "__main__":
568
- # demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
569
-
570
-
571
  import gradio as gr
572
  import numpy as np
573
  import torch
@@ -579,7 +9,6 @@ from io import BytesIO
  import importlib.util
  import os
  import openai
- from typing import List, Dict

  # Suppress warnings
  warnings.filterwarnings("ignore")
@@ -600,75 +29,7 @@ else:
  detector = None
  sam_predictor = None

- def calculate_bbox_area(bbox):
- """Calculate the area of a normalized bounding box."""
- x1, y1, x2, y2 = bbox
- width = abs(x2 - x1)
- height = abs(y2 - y1)
- return width * height
-
- def filter_bbox_outliers(detections: List[Dict],
- method: str = 'zscore',
- threshold: float = 2.0,
- min_score: float = 0.0) -> List[Dict]:
- """
- Filter out outlier bounding boxes based on their area.
-
- Args:
- detections: List of detection dictionaries with 'bbox', 'label', 'score'
- method: 'iqr' (Interquartile Range) or 'zscore' (Z-score)
- threshold: Multiplier for IQR method or Z-score threshold
- min_score: Minimum confidence score to keep detection
-
- Returns:
- Filtered list of detections
- """
- if not detections:
- return detections
-
- # Filter by minimum score first
- detections = [det for det in detections if det['score'] >= min_score]
-
- if len(detections) <= 2: # Need at least 3 detections for meaningful outlier removal
- return detections
-
- # Calculate areas for all bounding boxes
- areas = [calculate_bbox_area(det['bbox']) for det in detections]
- areas = np.array(areas)
-
- if method == 'iqr':
- # IQR method
- q1 = np.percentile(areas, 25)
- q3 = np.percentile(areas, 75)
- iqr = q3 - q1
-
- lower_bound = q1 - threshold * iqr
- upper_bound = q3 + threshold * iqr
-
- valid_indices = np.where((areas >= lower_bound) & (areas <= upper_bound))[0]
-
- elif method == 'zscore':
- # Z-score method
- if np.std(areas) == 0: # All areas are the same
- return detections
-
- mean_area = np.mean(areas)
- std_area = np.std(areas)
-
- z_scores = np.abs((areas - mean_area) / std_area)
- valid_indices = np.where(z_scores <= threshold)[0]
-
- else:
- raise ValueError("Method must be 'iqr' or 'zscore'")
-
- # Return filtered detections
- filtered_detections = [detections[i] for i in valid_indices]
-
- print(f"Original detections: {len(detections)}")
- print(f"Filtered detections: {len(filtered_detections)}")
- print(f"Removed {len(detections) - len(filtered_detections)} outliers")
-
- return filtered_detections
+ def load_detector():
  """Load the OWL-ViT detector once and cache it."""
  global detector
  if detector is None:
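For context, the outlier filter re-added further down computes box areas from normalized [x1, y1, x2, y2] coordinates; a minimal sketch of that helper, mirroring the removed calculate_bbox_area:

```python
# Minimal sketch of the area helper used by filter_bbox_outliers; boxes are
# assumed to be normalized [x1, y1, x2, y2] values in the 0-1 range.
def calculate_bbox_area(bbox):
    x1, y1, x2, y2 = bbox
    return abs(x2 - x1) * abs(y2 - y1)

# Example: a box covering the left half of the image has area 0.5.
print(calculate_bbox_area([0.0, 0.0, 0.5, 1.0]))  # 0.5
```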
@@ -715,6 +76,54 @@ def load_sam_predictor():
  return False
  return sam_predictor is not None

+ def filter_bbox_outliers(detections: List[Dict],
+ method: str = 'iqr',
+ threshold: float = 1.5,
+ min_score: float = 0.0) -> List[Dict]:
+ """
+ Filter out outlier bounding boxes based on their area.
+
+ Args:
+ detections: List of detection dictionaries with 'bbox', 'label', 'score'
+ method: 'iqr' (Interquartile Range) or 'zscore' (Z-score) or 'percentile'
+ threshold: Multiplier for IQR method or Z-score threshold
+ min_score: Minimum confidence score to keep detection
+
+ Returns:
+ Filtered list of detections
+ """
+ if not detections:
+ return detections
+
+ # Filter by minimum score first
+ detections = [det for det in detections if det['score'] >= min_score]
+
+ # Calculate areas for all bounding boxes
+ areas = [calculate_bbox_area(det['bbox']) for det in detections]
+ areas = np.array(areas)
+
+
+ if method == 'zscore':
+ # Z-score method
+ mean_area = np.mean(areas)
+ std_area = np.std(areas)
+
+ z_scores = np.abs((areas - mean_area) / std_area)
+ valid_indices = np.where(z_scores <= threshold)[0]
+
+
+ else:
+ raise ValueError("Method must be 'iqr', 'zscore', or 'percentile'")
+
+ # Return filtered detections
+ filtered_detections = [detections[i] for i in valid_indices]
+
+ print(f"Original detections: {len(detections)}")
+ print(f"Filtered detections: {len(filtered_detections)}")
+ print(f"Removed {len(detections) - len(filtered_detections)} outliers")
+
+ return filtered_detections
+
  def detect_objects_owlv2(text_query, image, threshold=0.1):
  """Detect objects using OWL-ViT."""
  try:
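The re-added filter defaults to method='iqr' but only implements the z-score branch, so a call with default arguments falls through to the ValueError. A sketch of what a version with both branches could look like; the IQR logic here is an assumption for illustration, not the committed code:

```python
import numpy as np

# Sketch only: an area-based outlier filter with both 'iqr' and 'zscore'
# branches implemented. The IQR branch is an assumption; the committed
# function only implements the z-score branch.
def filter_bbox_outliers_sketch(detections, method='iqr', threshold=1.5, min_score=0.0):
    detections = [d for d in detections if d['score'] >= min_score]
    if len(detections) < 3:
        return detections  # too few boxes for meaningful statistics

    areas = np.array([abs(x2 - x1) * abs(y2 - y1)
                      for x1, y1, x2, y2 in (d['bbox'] for d in detections)])

    if method == 'iqr':
        q1, q3 = np.percentile(areas, [25, 75])
        iqr = q3 - q1
        keep = (areas >= q1 - threshold * iqr) & (areas <= q3 + threshold * iqr)
    elif method == 'zscore':
        std = areas.std()
        if std == 0:
            return detections
        keep = np.abs((areas - areas.mean()) / std) <= threshold
    else:
        raise ValueError("method must be 'iqr' or 'zscore'")

    return [d for d, k in zip(detections, keep) if k]
```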
@@ -749,8 +158,10 @@ def detect_objects_owlv2(text_query, image, threshold=0.1):
  'score': pred['score']
  }
  detections.append(detection)
+
+ print(detections)
+ return filter_bbox_outliers(detections),image

- return detections, image
  except Exception as e:
  print(f"Detection error: {e}")
  return [], image
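For reference, a sketch of the zero-shot detection call that detect_objects_owlv2 wraps, showing how the raw pipeline output is normalized before filtering; the checkpoint is the one named in the diff, while the image path and labels are placeholders:

```python
from PIL import Image
from transformers import pipeline

# Sketch of the underlying zero-shot-object-detection call; "photo.jpg" and
# the candidate labels are placeholders, not values from the repo.
detector = pipeline(task="zero-shot-object-detection",
                    model="google/owlv2-base-patch16-ensemble")
image = Image.open("photo.jpg")
preds = detector(image, candidate_labels=["cacao pod", "person"])

w, h = image.size
for p in preds:
    if p["score"] >= 0.1:
        box = p["box"]
        norm = [box["xmin"] / w, box["ymin"] / h, box["xmax"] / w, box["ymax"] / h]
        print(p["label"], round(p["score"], 3), norm)
```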
@@ -819,7 +230,7 @@ def generate_masks_sam2(detections, image):
  print(f"SAM2 mask generation error: {e}")
  return detections

- def visualize_detections_with_masks(image, detections_with_masks, show_labels=True, show_boxes=True):
+ def visualize_detections_with_masks(image, detections_with_masks, show_labels=False, show_boxes=True):
  """
  Visualize the detections with their segmentation masks.
  Returns PIL Image instead of showing plot.
@@ -884,7 +295,7 @@ def visualize_detections_with_masks(image, detections_with_masks, show_labels=Tr
  result_image = Image.open(buf)
  return result_image

- def visualize_detections(image, detections, show_labels=True):
+ def visualize_detections(image, detections, show_labels=False):
  """
  Visualize object detections with bounding boxes only.
  Returns PIL Image instead of showing plot.
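Both visualizers return a PIL image by rendering with matplotlib into an in-memory buffer; a condensed sketch of that conversion pattern, taken from the functions touched above:

```python
import numpy as np
import matplotlib.pyplot as plt
from io import BytesIO
from PIL import Image

# Condensed version of the figure-to-PIL pattern used by visualize_detections
# and visualize_detections_with_masks.
def figure_to_pil(image_np):
    fig = plt.figure(figsize=(12, 8))
    plt.imshow(image_np)
    plt.axis('off')
    plt.tight_layout()
    buf = BytesIO()
    plt.savefig(buf, format='png', bbox_inches='tight', dpi=150)
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)

# Example with a dummy frame; the real callers draw boxes/masks before saving.
print(figure_to_pil(np.zeros((64, 64, 3), dtype=np.uint8)).size)
```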
@@ -1057,27 +468,22 @@ def detection_pipeline(query_text, image, threshold, use_sam):
  # Run object detection
  detections, processed_image = detect_objects_owlv2(search_terms, image, threshold)

- print(f"Found {len(detections)} initial detections")
+ print(f"Found {len(detections)} detections")
  for i, det in enumerate(detections):
- print(f"Detection {i+1}: {det['label']} (score: {det['score']:.3f}, area: {calculate_bbox_area(det['bbox']):.6f})")
-
- # Filter outliers before SAM2
- if len(detections) > 2: # Only filter if we have enough detections
- detections = filter_bbox_outliers(detections, method='zscore', threshold=2.0)
- print(f"After outlier filtering: {len(detections)} detections remain")
+ print(f"Detection {i+1}: {det['label']} (score: {det['score']:.3f})")

  # Generate masks if requested
  if use_sam and detections:
  print("Generating SAM2 masks...")
  detections = generate_masks_sam2(detections, processed_image)

- # Create visualization using your proven functions (labels OFF)
+ # Create visualization using your proven functions
  print("Creating visualization...")
  if use_sam and detections and 'mask' in detections[0]:
  result_image = visualize_detections_with_masks(
  processed_image,
  detections,
- show_labels=False, # Labels OFF
+ show_labels=True,
  show_boxes=True
  )
  print("Created visualization with masks")
@@ -1085,7 +491,7 @@ def detection_pipeline(query_text, image, threshold, use_sam):
  result_image = visualize_detections(
  processed_image,
  detections,
- show_labels=False # Labels OFF
+ show_labels=True
  )
  print("Created visualization with bounding boxes only")

@@ -1209,4 +615,5 @@ with gr.Blocks(title="🔍 Object Detection & Segmentation") as demo:

  # Launch
  if __name__ == "__main__":
- demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
+