"""Model manager for Arabic text classification models.

Provides per-request model selection over a fixed registry of
traditional (SVM/TF-IDF) and modern (BERT/LSTM) classifiers, with
in-memory caching of loaded model instances.
"""

import os
from typing import Any, Dict, List, Optional

# Both classifier backends are optional at import time: a missing
# dependency only disables that family of models (reported via
# ImportError when such a model is actually requested) instead of
# breaking the whole module on import.
try:
    from traditional_classifier import TraditionalClassifier
    TRADITIONAL_MODELS_AVAILABLE = True
except ImportError:
    TRADITIONAL_MODELS_AVAILABLE = False

try:
    from modern_classifier import ModernClassifier
    MODERN_MODELS_AVAILABLE = True
except ImportError:
    MODERN_MODELS_AVAILABLE = False


class ModelManager:
    """Manages different types of Arabic text classification models
    with per-request model selection and caching of loaded instances.
    """

    # Registry of every model this manager can serve.
    # Paths are relative to the process working directory.
    AVAILABLE_MODELS: Dict[str, Dict[str, Any]] = {
        "traditional_svm": {
            "type": "traditional",
            "classifier_path": "models/traditional_svm_classifier.joblib",
            "vectorizer_path": "models/traditional_tfidf_vectorizer_classifier.joblib",
            "description": "Traditional SVM classifier with TF-IDF vectorization",
        },
        "modern_bert": {
            "type": "modern",
            "model_type": "bert",
            "model_path": "models/modern_bert_classifier.safetensors",
            "config_path": "config.json",
            "description": "Modern BERT-based transformer classifier",
        },
        "modern_lstm": {
            "type": "modern",
            "model_type": "lstm",
            "model_path": "models/modern_lstm_classifier.pth",
            "description": "Modern LSTM-based neural network classifier",
        },
    }

    def __init__(self, default_model: str = "traditional_svm"):
        """Initialize the manager.

        Args:
            default_model: Name of the model used when a request does not
                specify one. Expected to be a key of ``AVAILABLE_MODELS``.
        """
        self.default_model = default_model
        # Maps model name -> loaded classifier instance.
        self._model_cache: Dict[str, Any] = {}

    def _get_model(self, model_name: str):
        """Return a loaded model instance, using the cache when possible.

        Args:
            model_name: Key into ``AVAILABLE_MODELS``.

        Raises:
            ValueError: Unknown model name or unrecognized model type.
            FileNotFoundError: A required model artifact is missing on disk.
            ImportError: The backend for the requested model family is
                not installed.
        """
        if model_name not in self.AVAILABLE_MODELS:
            raise ValueError(
                f"Model '{model_name}' not available. "
                f"Available models: {list(self.AVAILABLE_MODELS.keys())}"
            )
        if model_name in self._model_cache:
            return self._model_cache[model_name]

        model_config = self.AVAILABLE_MODELS[model_name]
        if model_config["type"] == "traditional":
            if not TRADITIONAL_MODELS_AVAILABLE:
                raise ImportError(
                    "Traditional models require the traditional_classifier dependencies"
                )
            classifier_path = model_config["classifier_path"]
            vectorizer_path = model_config["vectorizer_path"]
            if not os.path.exists(classifier_path):
                raise FileNotFoundError(f"Classifier file not found: {classifier_path}")
            if not os.path.exists(vectorizer_path):
                raise FileNotFoundError(f"Vectorizer file not found: {vectorizer_path}")
            model = TraditionalClassifier(classifier_path, vectorizer_path)
        elif model_config["type"] == "modern":
            if not MODERN_MODELS_AVAILABLE:
                raise ImportError("Modern models require PyTorch and transformers")
            model_path = model_config["model_path"]
            if not os.path.exists(model_path):
                raise FileNotFoundError(f"Model file not found: {model_path}")
            # Config file is optional: fall back to the classifier's own
            # defaults when it is absent on disk.
            config_path = model_config.get("config_path")
            if config_path and not os.path.exists(config_path):
                config_path = None
            model = ModernClassifier(
                model_type=model_config["model_type"],
                model_path=model_path,
                config_path=config_path,
            )
        else:
            # Fix: the original left `model` unbound for an unknown type,
            # which would surface as a confusing NameError below.
            raise ValueError(f"Unknown model type: {model_config['type']}")

        self._model_cache[model_name] = model
        return model

    def predict(self, text: str, model_name: Optional[str] = None) -> Dict[str, Any]:
        """Classify ``text`` using the specified model (or the default).

        The classifier's result dict is annotated with a ``model_manager``
        entry identifying which model produced it.
        """
        if model_name is None:
            model_name = self.default_model
        model = self._get_model(model_name)
        result = model.predict(text)
        result["model_manager"] = {
            "model_used": model_name,
            "model_description": self.AVAILABLE_MODELS[model_name]["description"],
        }
        return result

    def predict_batch(
        self, texts: List[str], model_name: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """Classify a batch of texts using the specified model (or the default).

        Each result dict is annotated in place with a ``model_manager`` entry.
        """
        if model_name is None:
            model_name = self.default_model
        model = self._get_model(model_name)
        results = model.predict_batch(texts)
        for result in results:
            result["model_manager"] = {
                "model_used": model_name,
                "model_description": self.AVAILABLE_MODELS[model_name]["description"],
            }
        return results

    def get_model_info(self, model_name: Optional[str] = None) -> Dict[str, Any]:
        """Get information about a specific model (or the default).

        Note: this loads the model if it is not already cached, so it can
        raise the same errors as ``_get_model``.
        """
        if model_name is None:
            model_name = self.default_model
        model = self._get_model(model_name)
        model_info = model.get_model_info()
        model_info.update(
            {
                "model_manager": {
                    "model_name": model_name,
                    "model_description": self.AVAILABLE_MODELS[model_name]["description"],
                    "model_config": self.AVAILABLE_MODELS[model_name],
                    "is_cached": model_name in self._model_cache,
                }
            }
        )
        return model_info

    def get_available_models(self) -> Dict[str, Any]:
        """Describe all registered models, including on-disk availability.

        Does not load any model; only checks that the required artifact
        files exist.
        """
        available = {}
        for model_name, config in self.AVAILABLE_MODELS.items():
            missing_files = []
            if config["type"] == "traditional":
                for file_key in ("classifier_path", "vectorizer_path"):
                    if not os.path.exists(config[file_key]):
                        missing_files.append(config[file_key])
            elif config["type"] == "modern":
                if not os.path.exists(config["model_path"]):
                    missing_files.append(config["model_path"])
            files_exist = not missing_files
            available[model_name] = {
                "description": config["description"],
                "type": config["type"],
                "available": files_exist,
                "missing_files": missing_files,
                "is_default": model_name == self.default_model,
                "is_cached": model_name in self._model_cache,
            }
        return available

    def clear_cache(self, model_name: Optional[str] = None) -> Dict[str, Any]:
        """Clear the model cache (a specific model, or all models)."""
        if model_name:
            if model_name in self._model_cache:
                del self._model_cache[model_name]
                return {"message": f"Cache cleared for model: {model_name}"}
            return {"message": f"Model {model_name} was not cached"}
        cleared_count = len(self._model_cache)
        self._model_cache.clear()
        return {"message": f"Cache cleared for {cleared_count} models"}

    def get_cache_status(self) -> Dict[str, Any]:
        """Get information about currently cached models."""
        return {
            "cached_models": list(self._model_cache.keys()),
            "cache_count": len(self._model_cache),
            "default_model": self.default_model,
        }