import os
import re
import torch
import logging
import threading
from typing import Dict, Optional, Any
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from huggingface_hub import login

class ModelLoadingError(Exception):
    """Custom exception for model loading failures"""
    pass


class ModelGenerationError(Exception):
    """Custom exception for model generation failures"""
    pass


class LLMModelManager:
    """
    負責LLM模型的載入、設備管理和文本生成。
    管理模型、記憶體優化和設備配置。
    實現單例模式確保全應用程式只有一個模型載入方式。
    """
    
    _instance = None
    _initialized = False
    _lock = threading.Lock()

    def __new__(cls, *args, **kwargs):
        """
        單例模式實現:確保整個應用程式只創建一個 LLMModelManager 
        """
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super(LLMModelManager, cls).__new__(cls)
        return cls._instance

    def __init__(self,
                 model_path: Optional[str] = None,
                 tokenizer_path: Optional[str] = None,
                 device: Optional[str] = None,
                 max_length: int = 2048,
                 temperature: float = 0.3,
                 top_p: float = 0.85):
        """
        初始化模型管理器(只在第一次創建實例時執行)

        Args:
            model_path: LLM模型的路徑或HuggingFace模型名稱,默認使用Llama 3.2
            tokenizer_path: tokenizer的路徑,通常與model_path相同
            device: 運行設備 ('cpu'或'cuda'),None時自動檢測
            max_length: 輸入文本的最大長度
            temperature: 生成文本的溫度參數
            top_p: 生成文本時的核心採樣機率閾值
        """
        # 避免重複初始化
        if self._initialized:
            return
            
        with self._lock:
            if self._initialized:
                return
                
            # set logger
            self.logger = logging.getLogger(self.__class__.__name__)
            if not self.logger.handlers:
                handler = logging.StreamHandler()
                formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
                handler.setFormatter(formatter)
                self.logger.addHandler(handler)
                self.logger.setLevel(logging.INFO)

            # model config
            self.model_path = model_path or "meta-llama/Llama-3.2-3B-Instruct"
            self.tokenizer_path = tokenizer_path or self.model_path

            # device management
            self.device = self._detect_device(device)
            self.logger.info(f"Device selected: {self.device}")

            # 生成參數
            self.max_length = max_length
            self.temperature = temperature
            self.top_p = top_p

            # 模型狀態
            self.model = None
            self.tokenizer = None
            self._model_loaded = False
            self.call_count = 0

            # HuggingFace認證
            self.hf_token = self._setup_huggingface_auth()
            
            # 標記為已初始化
            self._initialized = True
            self.logger.info("LLMModelManager singleton initialized")

    def _detect_device(self, device: Optional[str]) -> str:
        """
        檢測並設置運行設備

        Args:
            device: 用戶指定的設備,None時自動檢測

        Returns:
            str: ('cuda' or 'cpu')
        """
        if device:
            if device == 'cuda' and not torch.cuda.is_available():
                self.logger.warning("CUDA requested but not available, falling back to CPU")
                return 'cpu'
            return device

        detected_device = 'cuda' if torch.cuda.is_available() else 'cpu'

        if detected_device == 'cuda':
            gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
            self.logger.info(f"CUDA detected with {gpu_memory:.2f} GB GPU memory")

        return detected_device

    def _setup_huggingface_auth(self) -> Optional[str]:
        """
        設置HuggingFace認證

        Returns:
            Optional[str]: HuggingFace token,如果可用
        """
        hf_token = os.environ.get("HF_TOKEN")

        if hf_token:
            try:
                login(token=hf_token)
                self.logger.info("Successfully authenticated with HuggingFace")
                return hf_token
            except Exception as e:
                self.logger.error(f"HuggingFace authentication failed: {e}")
                return None
        else:
            self.logger.warning("HF_TOKEN not found. Access to gated models may be limited")
            return None

    def _load_model(self):
        """
        載入LLM模型和tokenizer,使用8位量化以節省記憶體
        增強的狀態檢查確保模型只載入一次

        Raises:
            ModelLoadingError: 當模型載入失敗時
        """
        # 完整的模型狀態檢查
        if (self._model_loaded and 
            hasattr(self, 'model') and self.model is not None and
            hasattr(self, 'tokenizer') and self.tokenizer is not None):
            self.logger.info("Model already loaded, skipping reload")
            return

        try:
            self.logger.info(f"Loading model from {self.model_path} with 8-bit quantization")

            # 清理GPU記憶體
            self._clear_gpu_cache()

            # 設置8位量化配置
            quantization_config = BitsAndBytesConfig(
                load_in_8bit=True,
                llm_int8_enable_fp32_cpu_offload=True
            )

            # 載入tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.tokenizer_path,
                padding_side="left",
                use_fast=False,
                token=self.hf_token
            )

            # 設置特殊標記
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            # 載入模型
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_path,
                quantization_config=quantization_config,
                device_map="auto",
                low_cpu_mem_usage=True,
                token=self.hf_token
            )

            self._model_loaded = True
            self.logger.info("Model loaded successfully (singleton instance)")

        except Exception as e:
            error_msg = f"Failed to load model: {str(e)}"
            self.logger.error(error_msg)
            raise ModelLoadingError(error_msg) from e

    def _clear_gpu_cache(self):
        """清理GPU記憶體緩存"""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            self.logger.debug("GPU cache cleared")

    def generate_response(self, prompt: str, **generation_kwargs) -> str:
        # 確保模型已載入
        if not self._model_loaded:
            self._load_model()

        try:
            self.call_count += 1
            self.logger.info(f"Generating response (call #{self.call_count})")

            # # record input prompt
            # self.logger.info(f"DEBUG: Input prompt length: {len(prompt)}")
            # self.logger.info(f"DEBUG: Input prompt preview: {prompt[:200]}...")

            # clean GPU
            self._clear_gpu_cache()

            # 設置固定種子以提高一致性
            torch.manual_seed(42)

            # prepare input
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=self.max_length
            ).to(self.device)

            # 準備生成參數
            generation_params = self._prepare_generation_params(**generation_kwargs)
            generation_params.update({
                "pad_token_id": self.tokenizer.eos_token_id,
                "attention_mask": inputs.attention_mask,
                "use_cache": True,
            })

            # response
            with torch.no_grad():
                outputs = self.model.generate(inputs.input_ids, **generation_params)

            # 解碼回應
            full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # # record whole response
            # self.logger.info(f"DEBUG: Full LLM response: {full_response}")

            response = self._extract_generated_response(full_response, prompt)

            # # 記錄提取後的回應
            # self.logger.info(f"DEBUG: Extracted response: {response}")

            if not response or len(response.strip()) < 10:
                raise ModelGenerationError("Generated response is too short or empty")

            self.logger.info(f"Response generated successfully ({len(response)} characters)")
            return response

        except Exception as e:
            error_msg = f"Text generation failed: {str(e)}"
            self.logger.error(error_msg)
            raise ModelGenerationError(error_msg) from e

    def _prepare_generation_params(self, **kwargs) -> Dict[str, Any]:
        """
        準備生成參數,支援模型特定的優化

        Args:
            **kwargs: 用戶提供的生成參數

        Returns:
            Dict[str, Any]: 完整的生成參數配置
        """
        # basic parameters
        params = {
            "max_new_tokens": 120,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "do_sample": True,
        }

        # 針對Llama模型的特殊優化
        if "llama" in self.model_path.lower():
            params.update({
                "max_new_tokens": 600,
                "temperature": 0.35, # not too big
                "top_p": 0.75,
                "repetition_penalty": 1.5,
                "num_beams": 5,
                "length_penalty": 1,
                "no_repeat_ngram_size": 3
            })
        else:
            params.update({
                "max_new_tokens": 300,
                "temperature": 0.6,
                "top_p": 0.9,
                "num_beams": 1,
                "repetition_penalty": 1.05
            })

        # 用戶參數覆蓋預設值
        params.update(kwargs)

        return params

    def _extract_generated_response(self, full_response: str, prompt: str) -> str:
        """
        從完整回應中提取生成的部分
        """
        # 尋找assistant標記
        assistant_tag = "<|assistant|>"
        if assistant_tag in full_response:
            response = full_response.split(assistant_tag)[-1].strip()

            # 檢查是否有未閉合的user標記
            user_tag = "<|user|>"
            if user_tag in response:
                response = response.split(user_tag)[0].strip()
        else:
            # 移除輸入提示詞
            if full_response.startswith(prompt):
                response = full_response[len(prompt):].strip()
            else:
                response = full_response.strip()

        # 移除不自然的場景類型前綴
        response = self._remove_scene_type_prefixes(response)

        return response

    def _remove_scene_type_prefixes(self, response: str) -> str:
        """
        移除LLM生成回應中的場景類型前綴

        Args:
            response: 原始LLM回應

        Returns:
            str: 移除前綴後的回應
        """
        if not response:
            return response

        prefix_patterns = [r'^[A-Za-z]+\,\s*']

        # 應用清理模式
        for pattern in prefix_patterns:
            response = re.sub(pattern, '', response, flags=re.IGNORECASE)

        # 確保首字母大寫
        if response and response[0].islower():
            response = response[0].upper() + response[1:]

        return response.strip()

    def reset_context(self):
        """重置模型上下文,清理GPU緩存"""
        if self._model_loaded:
            self._clear_gpu_cache()
            self.logger.info("Model context reset (singleton instance)")
        else:
            self.logger.info("Model not loaded, no context to reset")

    def get_current_device(self) -> str:
        """
        獲取當前運行設備

        Returns:
            str: 當前設備名稱
        """
        return self.device

    def is_model_loaded(self) -> bool:
        """
        檢查模型是否已載入

        Returns:
            bool: 模型載入狀態
        """
        return self._model_loaded

    def get_call_count(self) -> int:
        """
        獲取模型調用次數

        Returns:
            int: 調用次數
        """
        return self.call_count

    def get_model_info(self) -> Dict[str, Any]:
        """
        獲取模型信息

        Returns:
            Dict[str, Any]: 包含模型路徑、設備、載入狀態等信息
        """
        return {
            "model_path": self.model_path,
            "device": self.device,
            "is_loaded": self._model_loaded,
            "call_count": self.call_count,
            "has_hf_token": self.hf_token is not None,
            "is_singleton": True
        }

    @classmethod
    def reset_singleton(cls):
        """
        重置單例實例(僅用於測試或應用程式重啟)
        注意:這會導致模型需要重新載入
        """
        with cls._lock:
            if cls._instance is not None:
                instance = cls._instance
                if hasattr(instance, 'logger'):
                    instance.logger.info("Resetting singleton instance")
                cls._instance = None
                cls._initialized = False