Spaces:
Running
on
Zero
Running
on
Zero
| import logging | |
| import traceback | |
| from typing import Dict, List, Any, Optional | |
logger = logging.getLogger(__name__)


class ObjectExtractor:
    """Extract and pre-process object detection results.

    Converts a YOLO detection result into plain dictionaries (one per kept
    detection), maps objects to functional categories and identifies the
    core objects that characterise a scene type.
    """

    # Name-based fallback categories, listed in matching priority order.
    # The names follow the COCO class names.
    _FALLBACK_CATEGORIES = (
        ("furniture", ("chair", "couch", "bed", "dining table", "toilet")),
        ("plant", ("potted plant",)),
        ("electronics", ("tv", "laptop", "mouse", "remote", "keyboard", "cell phone")),
        ("vehicle", ("bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat")),
        ("person", ("person",)),
        ("kitchen_items", (
            "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl",
            "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog",
            "pizza", "donut", "cake", "refrigerator", "oven", "toaster", "sink", "microwave",
        )),
        ("sports", (
            "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat",
            "baseball glove", "skateboard", "surfboard", "tennis racket",
        )),
        ("personal_items", ("handbag", "tie", "suitcase", "umbrella", "backpack")),
    )

    # Exact class name -> category. Consulted before the substring scan so
    # that e.g. "carrot" is not captured by the vehicle item "car".
    _EXACT_CATEGORY = {
        item: category
        for category, items in _FALLBACK_CATEGORIES
        for item in items
    }

    # Class IDs that define each scene type.
    # NOTE(review): the original living_room list was [57, 58, 62] with the
    # comment "sofa, chair, tv". In the standard 80-class COCO ordering used
    # by every other entry here, 58 is "potted plant" and "chair" is 56, so
    # 56 is used. Confirm against the class_names of the deployed model.
    _SCENE_CORE_CLASS_IDS = {
        "bedroom": [59],  # bed
        "kitchen": [68, 69, 71, 72],  # microwave, oven, sink, refrigerator
        "living_room": [57, 56, 62],  # sofa, chair, tv
        "dining_area": [60, 42, 43],  # dining table, fork, knife
        "office_workspace": [63, 64, 66, 73],  # laptop, mouse, keyboard, book
    }

    # Minimum confidence for a detection to count as a core scene object.
    _CORE_OBJECT_MIN_CONFIDENCE = 0.4

    def __init__(
        self,
        class_names: Optional[Dict[int, str]] = None,
        object_categories: Optional[Dict[str, List[int]]] = None,
    ):
        """Initialise the extractor.

        Args:
            class_names: Mapping from class ID to class name. ``None`` is
                treated as an empty mapping.
            object_categories: Mapping from category name to the class IDs
                belonging to it. ``None`` is treated as an empty mapping.

        Raises:
            Exception: Any error raised during initialisation is logged and
                re-raised unchanged.
        """
        try:
            self.class_names = class_names or {}
            self.object_categories = object_categories or {}

            # Base confidence threshold; fixed here, not configurable via
            # the constructor.
            self.base_conf_threshold = 0.25

            # Per-class adjustment factors (key: lower-case class name).
            # Effective threshold = base_conf_threshold * factor; classes
            # not listed use the base threshold (factor 1.0).
            self.dynamic_conf_map = {
                "traffic light": 0.6,
                "car": 0.8,
                "person": 0.7,
            }

            logger.info(f"ObjectExtractor initialized with {len(self.class_names)} class names and {len(self.object_categories)} object categories")
        except Exception as e:
            logger.error(f"Failed to initialize ObjectExtractor: {str(e)}")
            logger.error(traceback.format_exc())
            raise

    def _get_dynamic_threshold(self, class_name: str) -> float:
        """Return the confidence threshold to apply to ``class_name``.

        The result is ``base_conf_threshold * factor`` where ``factor`` is
        looked up in ``dynamic_conf_map`` by the lower-cased class name and
        defaults to 1.0 when the class is not listed.
        """
        factor = self.dynamic_conf_map.get(class_name.lower(), 1.0)
        return self.base_conf_threshold * factor

    def extract_detected_objects(
        self,
        detection_result: Any,
        confidence_threshold: float = 0.25,
        region_analyzer=None,
    ) -> List[Dict]:
        """Extract per-object information (including position) from a result.

        Args:
            detection_result: YOLO detection result. It must expose
                ``boxes`` (with ``xyxy``, ``cls`` and ``conf``) and
                ``orig_shape``.
            confidence_threshold: Kept for interface compatibility only; it
                is not used. Filtering is done with the per-class dynamic
                threshold from ``_get_dynamic_threshold``.
            region_analyzer: Optional object whose
                ``determine_region(norm_x, norm_y)`` result is stored as the
                object's region. Without it the region is ``"unknown"``.

        Returns:
            A list with one dictionary per kept detection. An empty list is
            returned when the result is ``None``, has no boxes, or an
            unexpected error occurs (the error is logged).
        """
        try:
            # Diagnostic output: current state of the class mapping.
            logger.info("ObjectExtractor.extract_detected_objects called")
            logger.info(f"Current class_names keys: {list(self.class_names.keys()) if self.class_names else 'None'}")

            if detection_result is None:
                logger.warning("Detection result is None")
                return []

            # A missing attribute and an explicit ``boxes = None`` are both
            # treated as "no boxes" instead of failing further down.
            result_boxes = getattr(detection_result, "boxes", None)
            if result_boxes is None:
                logger.error("Detection result does not have boxes attribute")
                return []

            boxes = result_boxes.xyxy.cpu().numpy()
            classes = result_boxes.cls.cpu().numpy().astype(int)
            confidences = result_boxes.conf.cpu().numpy()

            # orig_shape starts with (height, width).
            img_height, img_width = detection_result.orig_shape[:2]

            detected_objects = []
            for box, class_id, confidence in zip(boxes, classes, confidences):
                try:
                    class_name = self.class_names.get(int(class_id), f"unknown_class_{class_id}")

                    # Drop detections below the class-specific threshold.
                    if confidence < self._get_dynamic_threshold(class_name):
                        continue

                    x1, y1, x2, y2 = box
                    width = x2 - x1
                    height = y2 - y1
                    center_x = (x1 + x2) / 2
                    center_y = (y1 + y2) / 2

                    # Values normalised to the image size.
                    norm_x = center_x / img_width
                    norm_y = center_y / img_height
                    norm_width = width / img_width
                    norm_height = height / img_height

                    area = width * height
                    norm_area = area / (img_width * img_height)

                    object_region = "unknown"
                    if region_analyzer:
                        object_region = region_analyzer.determine_region(norm_x, norm_y)

                    # Diagnostic output: outcome of the ID -> name mapping.
                    if class_name.startswith("unknown_class_"):
                        logger.warning(
                            f"Class ID {class_id} not found in class_names. "
                            f"Available keys: {list(self.class_names.keys())}"
                        )
                    else:
                        logger.debug(f"Successfully mapped class ID {class_id} to '{class_name}'")

                    detected_objects.append({
                        "class_id": int(class_id),
                        "class_name": class_name,
                        "confidence": float(confidence),
                        "box": [float(x1), float(y1), float(x2), float(y2)],
                        "center": [float(center_x), float(center_y)],
                        "normalized_center": [float(norm_x), float(norm_y)],
                        "size": [float(width), float(height)],
                        "normalized_size": [float(norm_width), float(norm_height)],
                        "area": float(area),
                        "normalized_area": float(norm_area),
                        "region": object_region,
                    })
                except Exception as e:
                    # One malformed detection must not discard the others.
                    logger.error(f"Error processing object with class_id {class_id}: {str(e)}")
                    continue

            logger.info(f"Extracted {len(detected_objects)} objects from detection result")
            return detected_objects
        except Exception as e:
            logger.error(f"Error extracting detected objects: {str(e)}")
            logger.error(traceback.format_exc())
            return []

    def update_class_names(self, class_names: Dict[int, str]):
        """Replace the class ID -> class name mapping.

        Args:
            class_names: New mapping; ``None`` or an empty mapping resets it
                to an empty dictionary.
        """
        try:
            self.class_names = class_names or {}
            logger.info(f"Class names updated: {len(self.class_names)} classes")
            logger.debug(f"Updated class names: {self.class_names}")
        except Exception as e:
            logger.error(f"Failed to update class names: {str(e)}")

    def categorize_object(self, obj: Dict) -> str:
        """Assign a detected object to a functional category.

        Resolution order:
            1. ``object_categories`` (by class ID), when configured.
            2. Exact match of the lower-cased class name against the
               built-in name lists.
            3. Substring match against the same lists, in priority order,
               for names that are not an exact match.

        Args:
            obj: Object dictionary; ``class_id`` and ``class_name`` are read.

        Returns:
            The category name, or ``"misc"`` when nothing matches or an
            error occurs (the error is logged).
        """
        try:
            class_id = obj.get("class_id", -1)
            class_name = obj.get("class_name", "").lower()

            # Prefer the configured ID-based mapping when available.
            if self.object_categories:
                for category, ids in self.object_categories.items():
                    if class_id in ids:
                        return category

            # Exact names win over substring hits ("carrot" contains "car").
            exact_category = self._EXACT_CATEGORY.get(class_name)
            if exact_category is not None:
                return exact_category

            # Substring fallback for compound or variant names.
            for category, items in self._FALLBACK_CATEGORIES:
                if any(item in class_name for item in items):
                    return category

            return "misc"
        except Exception as e:
            logger.error(f"Error categorizing object: {str(e)}")
            logger.error(traceback.format_exc())
            return "misc"

    def get_object_categories(self, detected_objects: List[Dict]) -> set:
        """Return the set of distinct categories among the detected objects.

        Args:
            detected_objects: List of object dictionaries.

        Returns:
            Set of category names; an empty set when an error occurs (the
            error is logged).
        """
        try:
            object_categories = set()
            for obj in detected_objects:
                category = self.categorize_object(obj)
                if category:
                    object_categories.add(category)

            logger.info(f"Found {len(object_categories)} unique object categories")
            return object_categories
        except Exception as e:
            logger.error(f"Error getting object categories: {str(e)}")
            logger.error(traceback.format_exc())
            return set()

    def identify_core_objects_for_scene(self, detected_objects: List[Dict], scene_type: str) -> List[Dict]:
        """Pick the objects that define the given scene type.

        An object qualifies when its ``class_id`` is listed for
        ``scene_type`` and its confidence is at least 0.4.

        Args:
            detected_objects: List of object dictionaries.
            scene_type: Scene type name, e.g. ``"kitchen"``.

        Returns:
            The qualifying objects in their original order. An empty list is
            returned for unknown scene types or when an error occurs (the
            error is logged).
        """
        try:
            core_objects = []
            core_class_ids = self._SCENE_CORE_CLASS_IDS.get(scene_type)

            # Unknown scene types have no core objects; the input list is
            # not inspected in that case.
            if core_class_ids:
                core_objects = [
                    obj for obj in detected_objects
                    if obj.get("class_id") in core_class_ids
                    and obj.get("confidence", 0) >= self._CORE_OBJECT_MIN_CONFIDENCE
                ]

            logger.info(f"Identified {len(core_objects)} core objects for scene type '{scene_type}'")
            return core_objects
        except Exception as e:
            logger.error(f"Error identifying core objects for scene '{scene_type}': {str(e)}")
            logger.error(traceback.format_exc())
            return []

    def group_objects_by_category_and_region(self, detected_objects: List[Dict]) -> Dict:
        """Group objects first by category and then by region.

        Args:
            detected_objects: List of object dictionaries.

        Returns:
            ``{category: {region: [objects]}}``. Objects without a
            ``region`` key are filed under ``"center"``. An empty dictionary
            is returned when an error occurs (the error is logged).
        """
        try:
            category_regions = {}
            for obj in detected_objects:
                category = self.categorize_object(obj)
                if not category:
                    continue

                region = obj.get("region", "center")
                regions = category_regions.setdefault(category, {})
                regions.setdefault(region, []).append(obj)

            logger.info(f"Grouped objects into {len(category_regions)} categories across regions")
            return category_regions
        except Exception as e:
            logger.error(f"Error grouping objects by category and region: {str(e)}")
            logger.error(traceback.format_exc())
            return {}

    def filter_objects_by_confidence(self, detected_objects: List[Dict], min_confidence: float) -> List[Dict]:
        """Keep only the objects whose confidence is at least ``min_confidence``.

        Args:
            detected_objects: List of object dictionaries.
            min_confidence: Inclusive lower bound on ``confidence``. Objects
                without a ``confidence`` key count as 0.

        Returns:
            The filtered list. If an error occurs it is logged and the
            original list is returned unfiltered.
        """
        try:
            filtered_objects = [
                obj for obj in detected_objects
                if obj.get("confidence", 0) >= min_confidence
            ]
            logger.info(f"Filtered {len(detected_objects)} objects to {len(filtered_objects)} objects with confidence >= {min_confidence}")
            return filtered_objects
        except Exception as e:
            logger.error(f"Error filtering objects by confidence: {str(e)}")
            logger.error(traceback.format_exc())
            return detected_objects  # best effort: fall back to the unfiltered input