import logging
from typing import Dict, List, Tuple, Optional, Any

class ObjectGroupProcessor:
    """
    Object group processor: groups detected objects by class, removes spatially
    duplicated objects, ranks the groups, and turns them into description clauses.

    Collaborators are optional. When ``text_optimizer`` is supplied it is used for
    class-name normalization and count formatting; when ``spatial_handler`` is
    supplied it is used for single-object spatial phrases. Otherwise the simple
    built-in fallbacks below are used.
    """

    # Small decorative classes: relaxed confidence threshold when grouping, and a
    # priority bump when ordering if two or more are present.
    _SMALL_DECOR_CLASSES = frozenset({"potted plant", "vase", "clock", "book"})
    # Classes ranked right after people when ordering groups.
    _FURNITURE_CLASSES = frozenset({"dining table", "chair", "sofa", "bed"})
    # Vehicle / traffic classes ranked after furniture.
    _TRAFFIC_CLASSES = frozenset({"car", "bus", "truck", "traffic light"})
    # Classes whose per-stage counts are emitted at DEBUG level.
    _DEBUG_TRACKED_CLASSES = ("car", "traffic light", "person", "handbag")

    # Fallback region -> phrase mapping. Keys are lower-case with underscores
    # replaced by spaces.
    _REGION_PHRASES: Dict[str, str] = {
        "top left": "in the upper left area",
        "top center": "in the upper area",
        "top right": "in the upper right area",
        "middle left": "on the left side",
        "middle center": "in the center",
        "center": "in the center",
        "middle right": "on the right side",
        "bottom left": "in the lower left area",
        "bottom center": "in the lower area",
        "bottom right": "in the lower right area",
    }

    # Number words used by the fallback count formatter.
    _NUMBER_WORDS: Dict[int, str] = {
        2: "two", 3: "three", 4: "four", 5: "five", 6: "six",
        7: "seven", 8: "eight", 9: "nine", 10: "ten",
        11: "eleven", 12: "twelve",
    }

    # Irregular plurals for the fallback formatter, keyed by the lower-case last
    # word of the class name (e.g. "wine glass" -> "wine glasses").
    _IRREGULAR_PLURALS: Dict[str, str] = {
        "person": "people",
        "people": "people",
        "bus": "buses",
        "glass": "glasses",
        "knife": "knives",
        "mouse": "mice",
        "sheep": "sheep",
    }

    def __init__(self, confidence_threshold_for_description: float = 0.25,
                 spatial_handler: Optional[Any] = None,
                 text_optimizer: Optional[Any] = None):
        """
        Initialize the object group processor.

        Args:
            confidence_threshold_for_description: Base confidence threshold a
                class's average confidence must reach to be described.
            spatial_handler: Optional spatial handler; must expose
                ``generate_spatial_description(obj, width, height, region_analyzer)``.
            text_optimizer: Optional text optimizer; must expose
                ``normalize_object_class_name`` and ``format_object_count_description``.
        """
        self.logger = logging.getLogger(self.__class__.__name__)
        self.confidence_threshold_for_description = confidence_threshold_for_description
        self.spatial_handler = spatial_handler
        self.text_optimizer = text_optimizer

    def _dynamic_confidence_threshold(self, class_name: str, count: int) -> float:
        """Return the per-class confidence threshold (relaxed for decor and crowded classes)."""
        base = self.confidence_threshold_for_description
        if class_name in self._SMALL_DECOR_CLASSES:
            return max(0.15, base * 0.6)
        if count >= 3:
            return max(0.2, base * 0.8)
        return base

    def group_objects_by_class(self, confident_objects: List[Dict],
                               object_statistics: Optional[Dict]) -> Dict[str, List[Dict]]:
        """
        Group objects by class name.

        With ``object_statistics`` (class name -> dict with optional ``"count"``
        and ``"avg_confidence"``), a class is kept only when its count is positive
        and its average confidence reaches the class's dynamic threshold. At most
        ``count`` matching objects are kept for that class. Without statistics,
        every object is grouped by its ``"class_name"``, skipping missing or
        ``"unknown object"`` names.

        Args:
            confident_objects: Objects already filtered by confidence.
            object_statistics: Optional per-class statistics.

        Returns:
            Dict[str, List[Dict]]: Objects grouped by class name.
        """
        objects_by_class: Dict[str, List[Dict]] = {}

        if not object_statistics:
            # Fallback: plain grouping. No threshold is applied here; the input
            # is assumed to be confidence-filtered already.
            for obj in confident_objects:
                name = obj.get("class_name", "unknown object")
                if not name or name == "unknown object":
                    continue
                objects_by_class.setdefault(name, []).append(obj)
            return objects_by_class

        for class_name, stats in object_statistics.items():
            count = stats.get("count", 0)
            avg_confidence = stats.get("avg_confidence", 0)
            threshold = self._dynamic_confidence_threshold(class_name, count)
            if not (count > 0 and avg_confidence >= threshold):
                continue

            matching_objects = [obj for obj in confident_objects
                                if obj.get("class_name") == class_name]
            if not matching_objects:
                continue

            # Cap the group at the statistics count (slicing handles count > len).
            objects_by_class[class_name] = matching_objects[:int(count)]

            if class_name in self._DEBUG_TRACKED_CLASSES:
                self.logger.debug("%s: %d objects before spatial deduplication",
                                  class_name, len(objects_by_class[class_name]))

        return objects_by_class

    def remove_duplicate_objects(self, objects_by_class: Dict[str, List[Dict]],
                                 distance_threshold: float = 0.15) -> Dict[str, List[Dict]]:
        """
        Remove objects whose center is too close to an already kept object.

        Closeness is the Manhattan distance between ``normalized_center`` values
        (default ``(0.5, 0.5)`` when missing). An object is dropped when that
        distance is below ``distance_threshold``.

        NOTE(review): kept centers are shared across ALL classes, in dict order.
        An object of one class can therefore be dropped because it is near a kept
        object of another class (e.g. a handbag next to a person). Confirm that
        this cross-class behaviour is intended.

        Args:
            objects_by_class: Objects grouped by class.
            distance_threshold: Manhattan-distance cut-off in normalized coordinates.

        Returns:
            Dict[str, List[Dict]]: Deduplicated groups. Classes left empty are omitted.
        """
        deduplicated_objects_by_class: Dict[str, List[Dict]] = {}
        kept_centers: List[Tuple[float, float]] = []

        for class_name, group_of_objects in objects_by_class.items():
            unique_objects = []
            for obj in group_of_objects:
                center = obj.get("normalized_center")
                if center is None:
                    center = (0.5, 0.5)
                cx, cy = center[0], center[1]

                is_duplicate = any(
                    abs(cx - kx) + abs(cy - ky) < distance_threshold
                    for kx, ky in kept_centers
                )
                if not is_duplicate:
                    unique_objects.append(obj)
                    kept_centers.append((cx, cy))

            if unique_objects:
                deduplicated_objects_by_class[class_name] = unique_objects

        for class_name in self._DEBUG_TRACKED_CLASSES:
            if class_name in deduplicated_objects_by_class:
                self.logger.debug("%s: %d objects after spatial deduplication",
                                  class_name, len(deduplicated_objects_by_class[class_name]))

        return deduplicated_objects_by_class

    def sort_object_groups(self, objects_by_class: Dict[str, List[Dict]]) -> List[Tuple[str, List[Dict]]]:
        """
        Order object groups for description.

        The sort key, ascending, is:
        1. Priority: people 0, furniture 1, vehicles/traffic 2, any other class
           with three or more objects 2, small decor with two or more objects 2,
           everything else 3.
        2. Larger groups first.
        3. Larger average ``normalized_area`` first.

        Args:
            objects_by_class: Objects grouped by class.

        Returns:
            List[Tuple[str, List[Dict]]]: ``(class_name, objects)`` pairs in order.
        """
        def sort_key(item: Tuple[str, List[Dict]]):
            class_name, group = item
            count = len(group)
            normalized = self._normalize_object_class_name(class_name)

            if normalized == "person":
                priority = 0
            elif normalized in self._FURNITURE_CLASSES:
                priority = 1
            elif normalized in self._TRAFFIC_CLASSES:
                priority = 2
            elif count >= 3:
                priority = 2
            elif normalized in self._SMALL_DECOR_CLASSES and count >= 2:
                priority = 2
            else:
                priority = 3

            avg_area = (sum(o.get("normalized_area", 0.0) for o in group) / count) if count else 0
            return (priority, -count, -avg_area)

        return sorted(objects_by_class.items(), key=sort_key)

    def generate_object_clauses(self, sorted_object_groups: List[Tuple[str, List[Dict]]],
                                object_statistics: Optional[Dict],
                                scene_type: str,
                                image_width: Optional[int],
                                image_height: Optional[int],
                                region_analyzer: Optional[Any] = None) -> List[str]:
        """
        Build one capitalized description clause per non-empty object group.

        The described quantity is the statistics ``"count"`` for the class when
        available, otherwise the group size. The same quantity drives the
        location phrase, so the verb ("is"/"are") agrees with the noun phrase.

        Args:
            sorted_object_groups: Ordered ``(class_name, objects)`` pairs.
            object_statistics: Optional per-class statistics.
            scene_type: Scene type forwarded to the count formatter.
            image_width: Image width, forwarded to the spatial handler.
            image_height: Image height, forwarded to the spatial handler.
            region_analyzer: Optional region analyzer, forwarded to the spatial handler.

        Returns:
            List[str]: Description clauses without trailing punctuation.
        """
        object_clauses: List[str] = []

        for class_name, group_of_objects in sorted_object_groups:
            group_count = len(group_of_objects)
            if class_name in self._DEBUG_TRACKED_CLASSES:
                self.logger.debug("Final count for %s: %d", class_name, group_count)
            if group_count == 0:
                continue

            normalized_class_name = self._normalize_object_class_name(class_name)

            # Prefer the statistics count for an accurate quantity description.
            described_count = group_count
            if object_statistics and class_name in object_statistics:
                described_count = object_statistics[class_name].get("count", group_count)

            formatted_name = self._format_object_count_description(
                normalized_class_name,
                described_count,
                scene_type=scene_type
            )
            if not formatted_name or formatted_name == "no specific objects clearly identified":
                continue

            location_description_suffix = self._generate_location_description(
                group_of_objects, described_count, image_width, image_height, region_analyzer
            )

            formatted_name_capitalized = formatted_name[0].upper() + formatted_name[1:]
            object_clauses.append(f"{formatted_name_capitalized} {location_description_suffix}")

        return object_clauses

    def format_object_clauses(self, object_clauses: List[str]) -> str:
        """
        Join description clauses into a paragraph without mutating the input.

        Args:
            object_clauses: Clauses as produced by ``generate_object_clauses``.

        Returns:
            str: The first clause as a sentence. Any remaining clauses follow
            after "The scene features:", joined by ". " and terminated with ".".
        """
        if not object_clauses:
            return "No common objects were confidently identified for detailed description."

        first_clause, *remaining_clauses = object_clauses
        result = first_clause + "."

        if remaining_clauses:
            joined_object_clauses = ". ".join(remaining_clauses)
            if joined_object_clauses and not joined_object_clauses.endswith("."):
                joined_object_clauses += "."
            result += " The scene features: " + joined_object_clauses

        return result

    @staticmethod
    def _distinct_valid_regions(group_of_objects: List[Dict]) -> List[str]:
        """Return sorted distinct, non-blank, non-"unknown" ``region`` values of the group."""
        regions = {obj.get("region", "") for obj in group_of_objects if obj.get("region")}
        return [r for r in sorted(regions) if r and r != "unknown" and r.strip()]

    def _generate_location_description(self, group_of_objects: List[Dict], count: int,
                                       image_width: Optional[int], image_height: Optional[int],
                                       region_analyzer: Optional[Any] = None) -> str:
        """
        Build the verb + location suffix for a group ("is ..." / "are ...").

        Args:
            group_of_objects: Objects in the group.
            count: Described quantity. 1 selects the singular phrasing.
            image_width: Image width, forwarded to the spatial handler.
            image_height: Image height, forwarded to the spatial handler.
            region_analyzer: Optional region analyzer, forwarded to the spatial handler.

        Returns:
            str: The location suffix, starting with "is" or "are".
        """
        singular = count == 1

        if not group_of_objects:
            return "is positioned in the scene" if singular else "are visible in the scene"

        if singular:
            lead = group_of_objects[0]
            if self.spatial_handler:
                spatial_desc = self.spatial_handler.generate_spatial_description(
                    lead, image_width, image_height, region_analyzer
                )
            else:
                spatial_desc = self._get_spatial_description_phrase(lead.get("region", ""))
            if spatial_desc:
                return f"is {spatial_desc}"

        valid_regions = self._distinct_valid_regions(group_of_objects)

        if singular:
            if not valid_regions:
                return "is positioned in the scene"
            if len(valid_regions) == 1:
                phrase = self._get_spatial_description_phrase(valid_regions[0])
                return f"is primarily {phrase}" if phrase else "is positioned in the scene"
            if len(valid_regions) == 2:
                first, second = (r.replace('_', ' ') for r in valid_regions)
                return f"is mainly across the {first} and {second} areas"
            return "is distributed in various parts of the scene"

        if not valid_regions:
            return "are visible in the scene"
        if len(valid_regions) == 1:
            clean_region = valid_regions[0].replace('_', ' ')
            return f"are primarily in the {clean_region} area"
        if len(valid_regions) == 2:
            first, second = (r.replace('_', ' ') for r in valid_regions)
            return f"are mainly across the {first} and {second} areas"
        return "are distributed in various parts of the scene"

    def _get_spatial_description_phrase(self, region: str) -> str:
        """
        Fallback mapping from a region name to a spatial phrase.

        Args:
            region: Region name such as ``"top_left"`` or ``"middle center"``.

        Returns:
            str: A phrase such as ``"in the upper left area"``, or ``""`` if the
            region is empty, non-string, ``"unknown"``, or unrecognized.
        """
        if not isinstance(region, str) or not region or region == "unknown":
            return ""
        clean_region = region.replace('_', ' ').strip().lower()
        return self._REGION_PHRASES.get(clean_region, "")

    def _normalize_object_class_name(self, class_name: str) -> str:
        """
        Normalize a class name.

        Delegates to ``text_optimizer`` when available. The fallback replaces
        underscores with spaces, strips and lower-cases, and returns ``"object"``
        for empty or non-string input.

        Args:
            class_name: Raw class name.

        Returns:
            str: Normalized class name.
        """
        if self.text_optimizer:
            return self.text_optimizer.normalize_object_class_name(class_name)
        if not class_name or not isinstance(class_name, str):
            return "object"
        return class_name.replace('_', ' ').strip().lower()

    @classmethod
    def _pluralize(cls, noun: str) -> str:
        """Return a simple English plural of *noun*, inflecting only its last word."""
        head, sep, last = noun.rpartition(" ")
        lower = last.lower()
        if lower in cls._IRREGULAR_PLURALS:
            plural_last = cls._IRREGULAR_PLURALS[lower]
        elif lower.endswith("s"):
            # Treat as already plural (e.g. "skis", "scissors").
            plural_last = last
        elif lower.endswith(("ch", "sh", "x")):
            plural_last = last + "es"
        elif len(lower) > 1 and lower.endswith("y") and lower[-2] not in "aeiou":
            plural_last = last[:-1] + "ies"
        else:
            plural_last = last + "s"
        return head + sep + plural_last

    def _format_object_count_description(self, class_name: str, count: int,
                                         scene_type: Optional[str] = None,
                                         detected_objects: Optional[List[Dict]] = None,
                                         avg_confidence: float = 0.0) -> str:
        """
        Format a quantity + class noun phrase (e.g. "a car", "three people").

        Delegates to ``text_optimizer`` when available. The fallback returns
        ``""`` for non-positive counts or an empty name, an indefinite article
        for 1, number words for 2-12, "several" up to 20, and "numerous" above.

        Args:
            class_name: Normalized class name.
            count: Number of objects.
            scene_type: Scene type (used only by ``text_optimizer``).
            detected_objects: All detections of this class (used only by ``text_optimizer``).
            avg_confidence: Average detection confidence (used only by ``text_optimizer``).

        Returns:
            str: The formatted noun phrase, or ``""``.
        """
        if self.text_optimizer:
            return self.text_optimizer.format_object_count_description(
                class_name, count, scene_type, detected_objects, avg_confidence
            )

        if count <= 0 or not class_name:
            return ""
        if count == 1:
            article = "an" if class_name[0].lower() in 'aeiou' else "a"
            return f"{article} {class_name}"

        plural_form = self._pluralize(class_name)
        if count in self._NUMBER_WORDS:
            return f"{self._NUMBER_WORDS[count]} {plural_form}"
        if count <= 20:
            return f"several {plural_form}"
        return f"numerous {plural_form}"