obichimav committed on
Commit 085a78a · verified · 1 Parent(s): e957e17

Update app.py

Files changed (1)
  1. app.py +157 -53
app.py CHANGED
@@ -7,10 +7,24 @@ from transformers import pipeline
 import warnings
 from io import BytesIO
 import importlib.util
+import os
+import openai
 
 # Suppress warnings
 warnings.filterwarnings("ignore")
 
+# Set up OpenAI API key
+api_key = os.getenv('OPENAI_API_KEY')
+if not api_key:
+    print("No OpenAI API key found - will use simple keyword extraction")
+elif not api_key.startswith("sk-proj-") and not api_key.startswith("sk-"):
+    print("API key found but doesn't look correct")
+elif api_key.strip() != api_key:
+    print("API key has leading or trailing whitespace - please fix it.")
+else:
+    print("OpenAI API key found and looks good!")
+    openai.api_key = api_key
+
 # Global variables for models
 detector = None
 sam_predictor = None
@@ -248,34 +262,38 @@ def visualize_detections(image, detections, show_labels=True):
     fig = plt.figure(figsize=(12, 8))
     plt.imshow(image_np)
 
-    # Define colors for different instances
-    colors = plt.cm.tab10(np.linspace(0, 1, 10))
-
-    # Plot each detection
-    for i, detection in enumerate(detections):
-        # Get bbox, label, and score
-        bbox = detection['bbox']
-        label = detection['label']
-        score = detection['score']
-
-        # Convert normalized bbox to pixel coordinates
-        x1, y1, x2, y2 = bbox
-        x1_px, y1_px = int(x1 * width), int(y1 * height)
-        x2_px, y2_px = int(x2 * width), int(y2 * height)
-
-        # Color for this instance
-        color = colors[i % len(colors)]
-
-        # Draw bounding box
-        rect = plt.Rectangle((x1_px, y1_px), x2_px - x1_px, y2_px - y1_px,
-                             fill=False, edgecolor=color, linewidth=2)
-        plt.gca().add_patch(rect)
-
-        # Add label and score if requested
-        if show_labels:
-            plt.text(x1_px, y1_px - 5, f"{label}: {score:.2f}",
-                     color='white', bbox=dict(facecolor=color, alpha=0.8), fontsize=10)
+    # If we have detections, draw them
+    if detections:
+        # Define colors for different instances
+        colors = plt.cm.tab10(np.linspace(0, 1, 10))
+
+        # Plot each detection
+        for i, detection in enumerate(detections):
+            # Get bbox, label, and score
+            bbox = detection['bbox']
+            label = detection['label']
+            score = detection['score']
+
+            # Convert normalized bbox to pixel coordinates
+            x1, y1, x2, y2 = bbox
+            x1_px, y1_px = int(x1 * width), int(y1 * height)
+            x2_px, y2_px = int(x2 * width), int(y2 * height)
+
+            # Color for this instance
+            color = colors[i % len(colors)]
+
+            # Draw bounding box
+            rect = plt.Rectangle((x1_px, y1_px), x2_px - x1_px, y2_px - y1_px,
+                                 fill=False, edgecolor=color, linewidth=2)
+            plt.gca().add_patch(rect)
+
+            # Add label and score if requested
+            if show_labels:
+                plt.text(x1_px, y1_px - 5, f"{label}: {score:.2f}",
+                         color='white', bbox=dict(facecolor=color, alpha=0.8), fontsize=10)
 
+    # Set title
+    plt.title(f'Object Detection Results ({len(detections)} objects found)', fontsize=14, pad=20)
     plt.axis('off')
     plt.tight_layout()
 
@@ -288,6 +306,99 @@ def visualize_detections(image, detections, show_labels=True):
     result_image = Image.open(buf)
     return result_image
 
+def get_optimized_prompt(query_text):
+    """
+    Use OpenAI to convert natural language query into optimal detection prompt.
+    Falls back to simple extraction if OpenAI is not available.
+    """
+    if not query_text.strip():
+        return "object"
+
+    # Try OpenAI first if API key is available
+    if hasattr(openai, 'api_key') and openai.api_key:
+        try:
+            response = openai.chat.completions.create(
+                model="gpt-3.5-turbo",
+                messages=[{
+                    "role": "system",
+                    "content": """You are an expert at converting natural language queries into precise object detection terms.
+
+RULES:
+1. Return ONLY 1-2 words maximum that describe the object to detect
+2. Use the exact object name from the user's query
+3. For people: use "person"
+4. For vehicles: use "car", "truck", "bicycle"
+5. Do NOT include counting words, articles, or explanations
+6. Examples:
+   - "How many cacao fruits are there?" → "cacao fruit"
+   - "Count the corn in the field" → "corn"
+   - "Find all people" → "person"
+   - "How many cacao pods?" → "cacao pod"
+   - "Detect cars" → "car"
+   - "Count bananas" → "banana"
+   - "How many apples?" → "apple"
+
+Return ONLY the object name, nothing else."""
+                }, {
+                    "role": "user",
+                    "content": query_text
+                }],
+                temperature=0.0,  # Make it deterministic
+                max_tokens=5      # Force brevity
+            )
+
+            llm_result = response.choices[0].message.content.strip().lower()
+            # Extra safety: take only first 2 words
+            words = llm_result.split()[:2]
+            final_result = " ".join(words)
+
+            print(f"🤖 OpenAI suggested prompt: '{final_result}'")
+            return final_result
+
+        except Exception as e:
+            print(f"OpenAI error: {e}, falling back to keyword extraction")
+
+    # Fallback to simple keyword extraction (no hardcoded fruits)
+    print("🔀 Using keyword extraction (no OpenAI)")
+    query_lower = query_text.lower().replace("?", "").strip()
+
+    # Look for common patterns and extract object names
+    if "how many" in query_lower:
+        parts = query_lower.split("how many")
+        if len(parts) > 1:
+            remaining = parts[1].strip()
+            remaining = remaining.replace("are", "").replace("in", "").replace("the", "").replace("image", "").replace("there", "").strip()
+            # Take first meaningful word(s)
+            words = remaining.split()[:2]
+            search_terms = " ".join(words) if words else "object"
+        else:
+            search_terms = "object"
+    elif "count" in query_lower:
+        parts = query_lower.split("count")
+        if len(parts) > 1:
+            remaining = parts[1].strip()
+            remaining = remaining.replace("the", "").replace("in", "").replace("image", "").strip()
+            words = remaining.split()[:2]
+            search_terms = " ".join(words) if words else "object"
+        else:
+            search_terms = "object"
+    elif "find" in query_lower:
+        parts = query_lower.split("find")
+        if len(parts) > 1:
+            remaining = parts[1].strip()
+            remaining = remaining.replace("all", "").replace("the", "").replace("in", "").replace("image", "").strip()
+            words = remaining.split()[:2]
+            search_terms = " ".join(words) if words else "object"
+        else:
+            search_terms = "object"
+    else:
+        # Extract first 1-2 meaningful words from the query
+        words = query_lower.split()
+        meaningful_words = [w for w in words if w not in ["how", "many", "are", "in", "the", "image", "find", "count", "detect", "there", "this", "that", "a", "an"]]
+        search_terms = " ".join(meaningful_words[:2]) if meaningful_words else "object"
+
+    return search_terms
+
 def is_count_query(text):
     """Check if the query is asking for counting."""
     count_keywords = ["how many", "count", "number of", "total"]
@@ -299,61 +410,54 @@ def detection_pipeline(query_text, image, threshold, use_sam):
         return None, "⚠️ Please upload an image first!"
 
     try:
-        # Extract object name from query
-        query_lower = query_text.lower()
-
-        # Simple keyword extraction
-        if "people" in query_lower or "person" in query_lower:
-            search_terms = "person"
-        elif "car" in query_lower or "vehicle" in query_lower:
-            search_terms = "car"
-        elif "apple" in query_lower:
-            search_terms = "apple"
-        elif "bottle" in query_lower:
-            search_terms = "bottle"
-        elif "phone" in query_lower:
-            search_terms = "phone"
-        elif "dog" in query_lower:
-            search_terms = "dog"
-        elif "cat" in query_lower:
-            search_terms = "cat"
-        else:
-            # Extract last word as potential object
-            words = query_text.strip().split()
-            search_terms = words[-1] if words else "object"
+        # Use OpenAI or fallback to get optimized search terms
+        search_terms = get_optimized_prompt(query_text)
 
         print(f"Processing query: '{query_text}' -> searching for: '{search_terms}'")
 
         # Run object detection
         detections, processed_image = detect_objects_owlv2(search_terms, image, threshold)
 
+        print(f"Found {len(detections)} detections")
+        for i, det in enumerate(detections):
+            print(f"Detection {i+1}: {det['label']} (score: {det['score']:.3f})")
+
         # Generate masks if requested
         if use_sam and detections:
+            print("Generating SAM2 masks...")
            detections = generate_masks_sam2(detections, processed_image)
 
         # Create visualization using your proven functions
-        if use_sam and detections:
+        print("Creating visualization...")
+        if use_sam and detections and 'mask' in detections[0]:
            result_image = visualize_detections_with_masks(
                processed_image,
                detections,
                show_labels=True,
                show_boxes=True
            )
+            print("Created visualization with masks")
         else:
            result_image = visualize_detections(
                processed_image,
                detections,
                show_labels=True
            )
+            print("Created visualization with bounding boxes only")
+
+        # Make sure we have a valid result image
+        if result_image is None:
+            print("Warning: result_image is None, returning original image")
+            result_image = processed_image
 
         # Generate summary
         count = len(detections)
 
         summary_parts = []
-        summary_parts.append(f"🔍 **Search Query**: '{query_text}'")
-        summary_parts.append(f"🎯 **Detected Object Type**: '{search_terms}'")
+        summary_parts.append(f"🗣️ **Original Query**: '{query_text}'")
+        summary_parts.append(f"🤖 **AI-Optimized Search**: '{search_terms}'")
         summary_parts.append(f"⚙️ **Threshold**: {threshold}")
-        summary_parts.append(f"🤖 **SAM2 Segmentation**: {'Enabled' if use_sam else 'Disabled'}")
+        summary_parts.append(f"🎭 **SAM2 Segmentation**: {'Enabled' if use_sam else 'Disabled'}")
 
         if count > 0:
             if is_count_query(query_text):
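
For reference, the keyword-extraction fallback added in this commit never touches OpenAI, so it can be exercised in isolation. Below is a minimal, hypothetical condensation of get_optimized_prompt's non-OpenAI branches (extract_search_terms is an illustrative name, not a function in app.py); it shows the kind of 1-2 word prompts the fallback produces when no OPENAI_API_KEY is set:

# Simplified sketch of the fallback prompt extraction (no OpenAI key needed).
# extract_search_terms is a hypothetical condensation of get_optimized_prompt's
# keyword branches; the real function has separate "how many"/"count"/"find" paths.

def extract_search_terms(query_text: str) -> str:
    query_lower = query_text.lower().replace("?", "").strip()
    stopwords = {"how", "many", "are", "in", "the", "image", "find",
                 "count", "detect", "there", "this", "that", "a", "an", "all"}
    words = [w for w in query_lower.split() if w not in stopwords]
    # Keep at most two words, mirroring the 1-2 word detection prompts
    return " ".join(words[:2]) if words else "object"

if __name__ == "__main__":
    for q in ["How many cacao pods?", "Count bananas", "Find all people"]:
        print(f"{q!r} -> {extract_search_terms(q)!r}")
    # 'How many cacao pods?' -> 'cacao pods'
    # 'Count bananas' -> 'bananas'
    # 'Find all people' -> 'people'

The resulting short phrase is what detection_pipeline passes to detect_objects_owlv2 as the OWLv2 text query.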