import os
from typing import Any, Dict, List, Optional

from traditional_classifier import TraditionalClassifier

try:
    from modern_classifier import ModernClassifier
    MODERN_MODELS_AVAILABLE = True
except ImportError:
    MODERN_MODELS_AVAILABLE = False


class ModelManager:
    """Manages different types of Arabic text classification models with per-request model selection and caching.

    Models are described declaratively in ``AVAILABLE_MODELS`` and loaded
    lazily on first use; loaded instances are kept in a per-manager cache
    keyed by model name until ``clear_cache`` is called.

    Note: the file paths in ``AVAILABLE_MODELS`` are relative, so they are
    resolved against the process's current working directory.
    """

    AVAILABLE_MODELS = {
        "traditional_svm": {
            "type": "traditional",
            "classifier_path": "models/traditional_svm_classifier.joblib",
            "vectorizer_path": "models/traditional_tfidf_vectorizer_classifier.joblib",
            "description": "Traditional SVM classifier with TF-IDF vectorization"
        },

        "modern_bert": {
            "type": "modern",
            "model_type": "bert",
            "model_path": "models/modern_bert_classifier.safetensors",
            "config_path": "config.json",
            "description": "Modern BERT-based transformer classifier"
        },

        "modern_lstm": {
            "type": "modern",
            "model_type": "lstm",
            "model_path": "models/modern_lstm_classifier.pth",
            "description": "Modern LSTM-based neural network classifier"
        }
    }

    def __init__(self, default_model: str = "traditional_svm"):
        """Create a manager.

        Args:
            default_model: Name (a key of ``AVAILABLE_MODELS``) used whenever a
                method is called without an explicit ``model_name``. It is not
                validated here; an unknown name fails on first use.
        """
        self.default_model = default_model
        # Loaded model instances, keyed by model name.
        self._model_cache: Dict[str, Any] = {}

    def _resolve_model_name(self, model_name: Optional[str]) -> str:
        """Return ``model_name``, or the default model when it is ``None``."""
        return self.default_model if model_name is None else model_name

    def _manager_metadata(self, model_name: str) -> Dict[str, Any]:
        """Build the ``model_manager`` section attached to prediction results."""
        return {
            "model_used": model_name,
            "model_description": self.AVAILABLE_MODELS[model_name]["description"]
        }

    def _load_model(self, model_config: Dict[str, Any]):
        """Instantiate a classifier from one ``AVAILABLE_MODELS`` entry.

        Raises:
            FileNotFoundError: A required model/vectorizer file does not exist.
            ImportError: A modern model was requested but the optional
                ``modern_classifier`` import failed when this module loaded.
            ValueError: ``model_config["type"]`` is not a supported type.
        """
        model_type = model_config["type"]

        if model_type == "traditional":
            classifier_path = model_config["classifier_path"]
            vectorizer_path = model_config["vectorizer_path"]

            if not os.path.exists(classifier_path):
                raise FileNotFoundError(f"Classifier file not found: {classifier_path}")
            if not os.path.exists(vectorizer_path):
                raise FileNotFoundError(f"Vectorizer file not found: {vectorizer_path}")

            return TraditionalClassifier(classifier_path, vectorizer_path)

        if model_type == "modern":
            if not MODERN_MODELS_AVAILABLE:
                raise ImportError("Modern models require PyTorch and transformers")

            model_path = model_config["model_path"]
            if not os.path.exists(model_path):
                raise FileNotFoundError(f"Model file not found: {model_path}")

            # config_path is optional: absent, or configured but missing on
            # disk, both fall back to None.
            config_path = model_config.get("config_path")
            if config_path and not os.path.exists(config_path):
                config_path = None

            return ModernClassifier(
                model_type=model_config["model_type"],
                model_path=model_path,
                config_path=config_path
            )

        # Previously an unknown type fell through and crashed with
        # UnboundLocalError; fail with an explicit message instead.
        raise ValueError(f"Unknown model type '{model_type}' in model configuration")

    def _get_model(self, model_name: str):
        """Get model instance, loading from cache or creating new one.

        Raises:
            ValueError: ``model_name`` is not a key of ``AVAILABLE_MODELS`` or
                its configured type is unsupported.
            FileNotFoundError / ImportError: see ``_load_model``.
        """
        if model_name not in self.AVAILABLE_MODELS:
            raise ValueError(f"Model '{model_name}' not available. Available models: {list(self.AVAILABLE_MODELS.keys())}")

        if model_name in self._model_cache:
            return self._model_cache[model_name]

        # Only cache after a successful load, so failures are retried next call.
        model = self._load_model(self.AVAILABLE_MODELS[model_name])
        self._model_cache[model_name] = model
        return model

    def predict(self, text: str, model_name: Optional[str] = None) -> Dict[str, Any]:
        """Predict using the specified model (or default if none specified).

        The underlying model's result dict is returned with an added
        ``"model_manager"`` entry naming and describing the model used.
        """
        model_name = self._resolve_model_name(model_name)
        result = self._get_model(model_name).predict(text)
        result["model_manager"] = self._manager_metadata(model_name)
        return result

    def predict_batch(self, texts: List[str], model_name: Optional[str] = None) -> List[Dict[str, Any]]:
        """Predict batch using the specified model (or default if none specified).

        Each result dict receives its own (not shared) ``"model_manager"`` entry.
        """
        model_name = self._resolve_model_name(model_name)
        results = self._get_model(model_name).predict_batch(texts)
        for result in results:
            result["model_manager"] = self._manager_metadata(model_name)
        return results

    def get_model_info(self, model_name: Optional[str] = None) -> Dict[str, Any]:
        """Get information about a specific model (or default if none specified).

        The model is loaded if necessary. ``model_manager.is_cached`` reports
        whether the model was already cached *before* this call.
        """
        model_name = self._resolve_model_name(model_name)
        # Must be read before _get_model, which caches on success; checking
        # afterwards would always yield True.
        was_cached = model_name in self._model_cache

        model = self._get_model(model_name)
        model_config = self.AVAILABLE_MODELS[model_name]

        model_info = model.get_model_info()
        model_info["model_manager"] = {
            "model_name": model_name,
            "model_description": model_config["description"],
            # Copy so callers mutating the response cannot alter the shared,
            # class-level registry.
            "model_config": dict(model_config),
            "is_cached": was_cached
        }
        return model_info

    def get_available_models(self) -> Dict[str, Any]:
        """Get list of all available models.

        A model is reported ``available`` only if all of its required files
        exist and, for modern models, the optional ``modern_classifier``
        import succeeded (``dependencies_available``).
        """
        available = {}
        for model_name, config in self.AVAILABLE_MODELS.items():
            model_type = config["type"]

            if model_type == "traditional":
                required_files = [config["classifier_path"], config["vectorizer_path"]]
            elif model_type == "modern":
                required_files = [config["model_path"]]
            else:
                required_files = []

            missing_files = [path for path in required_files if not os.path.exists(path)]
            # Mirrors the ImportError check in _load_model.
            dependencies_available = model_type != "modern" or MODERN_MODELS_AVAILABLE

            available[model_name] = {
                "description": config["description"],
                "type": model_type,
                "available": not missing_files and dependencies_available,
                "dependencies_available": dependencies_available,
                "missing_files": missing_files,
                "is_default": model_name == self.default_model,
                "is_cached": model_name in self._model_cache
            }

        return available

    def clear_cache(self, model_name: Optional[str] = None) -> Dict[str, Any]:
        """Clear model cache (specific model or all models).

        Any falsy ``model_name`` (``None`` or ``""``) clears every cached model.
        """
        if not model_name:
            cleared_count = len(self._model_cache)
            self._model_cache.clear()
            return {"message": f"Cache cleared for {cleared_count} models"}

        if model_name in self._model_cache:
            del self._model_cache[model_name]
            return {"message": f"Cache cleared for model: {model_name}"}
        return {"message": f"Model {model_name} was not cached"}

    def get_cache_status(self) -> Dict[str, Any]:
        """Get information about cached models."""
        return {
            "cached_models": list(self._model_cache.keys()),
            "cache_count": len(self._model_cache),
            "default_model": self.default_model
        }