"""Model manager for Arabic text classification: registers, loads, and caches classifiers."""

import os
from typing import Any, Dict, List, Optional

from traditional_classifier import TraditionalClassifier

# The modern (neural) classifiers depend on PyTorch and transformers, so treat
# them as an optional import and record availability in a flag.
try:
    from modern_classifier import ModernClassifier
    MODERN_MODELS_AVAILABLE = True
except ImportError:
    MODERN_MODELS_AVAILABLE = False


class ModelManager:
    """Manages different types of Arabic text classification models with per-request model selection and caching."""

    AVAILABLE_MODELS = {
        "traditional_svm": {
            "type": "traditional",
            "classifier_path": "models/traditional_svm_classifier.joblib",
            "vectorizer_path": "models/traditional_tfidf_vectorizer_classifier.joblib",
            "description": "Traditional SVM classifier with TF-IDF vectorization"
        },
        "modern_bert": {
            "type": "modern",
            "model_type": "bert",
            "model_path": "models/modern_bert_classifier.safetensors",
            "config_path": "config.json",
            "description": "Modern BERT-based transformer classifier"
        },
        "modern_lstm": {
            "type": "modern",
            "model_type": "lstm",
            "model_path": "models/modern_lstm_classifier.pth",
            "description": "Modern LSTM-based neural network classifier"
        }
    }

    def __init__(self, default_model: str = "traditional_svm"):
        self.default_model = default_model
        self._model_cache: Dict[str, Any] = {}

    def _get_model(self, model_name: str):
        """Return a model instance, serving it from the cache or constructing it on first use."""
        if model_name not in self.AVAILABLE_MODELS:
            raise ValueError(
                f"Model '{model_name}' not available. "
                f"Available models: {list(self.AVAILABLE_MODELS.keys())}"
            )

        if model_name in self._model_cache:
            return self._model_cache[model_name]

        model_config = self.AVAILABLE_MODELS[model_name]

        if model_config["type"] == "traditional":
            classifier_path = model_config["classifier_path"]
            vectorizer_path = model_config["vectorizer_path"]

            if not os.path.exists(classifier_path):
                raise FileNotFoundError(f"Classifier file not found: {classifier_path}")
            if not os.path.exists(vectorizer_path):
                raise FileNotFoundError(f"Vectorizer file not found: {vectorizer_path}")

            model = TraditionalClassifier(classifier_path, vectorizer_path)

        elif model_config["type"] == "modern":
            if not MODERN_MODELS_AVAILABLE:
                raise ImportError("Modern models require PyTorch and transformers")

            model_path = model_config["model_path"]

            if not os.path.exists(model_path):
                raise FileNotFoundError(f"Model file not found: {model_path}")

            # A missing config file is tolerated: fall back to the classifier's defaults.
            config_path = model_config.get("config_path")
            if config_path and not os.path.exists(config_path):
                config_path = None

            model = ModernClassifier(
                model_type=model_config["model_type"],
                model_path=model_path,
                config_path=config_path
            )

        else:
            # Guard against a malformed registry entry; without this, `model`
            # would be unbound below and raise a confusing UnboundLocalError.
            raise ValueError(f"Unknown model type: {model_config['type']}")

        self._model_cache[model_name] = model
        return model

    def predict(self, text: str, model_name: Optional[str] = None) -> Dict[str, Any]:
        """Predict using the specified model (or the default if none is given)."""
        if model_name is None:
            model_name = self.default_model

        model = self._get_model(model_name)
        result = model.predict(text)

        # Annotate the result so callers can tell which model produced it.
        result["model_manager"] = {
            "model_used": model_name,
            "model_description": self.AVAILABLE_MODELS[model_name]["description"]
        }
        return result

    def predict_batch(self, texts: List[str], model_name: Optional[str] = None) -> List[Dict[str, Any]]:
        """Predict a batch of texts using the specified model (or the default if none is given)."""
        if model_name is None:
            model_name = self.default_model

        model = self._get_model(model_name)
        results = model.predict_batch(texts)

        for result in results:
            result["model_manager"] = {
                "model_used": model_name,
                "model_description": self.AVAILABLE_MODELS[model_name]["description"]
            }
        return results

    def get_model_info(self, model_name: Optional[str] = None) -> Dict[str, Any]:
        """Get information about a specific model (or the default if none is given)."""
        if model_name is None:
            model_name = self.default_model

        model = self._get_model(model_name)
        model_info = model.get_model_info()
        model_info.update({
            "model_manager": {
                "model_name": model_name,
                "model_description": self.AVAILABLE_MODELS[model_name]["description"],
                "model_config": self.AVAILABLE_MODELS[model_name],
                "is_cached": model_name in self._model_cache
            }
        })
        return model_info

    def get_available_models(self) -> Dict[str, Any]:
        """Describe every registered model, including whether its files exist on disk."""
        available = {}
        for model_name, config in self.AVAILABLE_MODELS.items():
            files_exist = True
            missing_files = []

            if config["type"] == "traditional":
                for file_key in ["classifier_path", "vectorizer_path"]:
                    if not os.path.exists(config[file_key]):
                        files_exist = False
                        missing_files.append(config[file_key])
            elif config["type"] == "modern":
                if not os.path.exists(config["model_path"]):
                    files_exist = False
                    missing_files.append(config["model_path"])

            available[model_name] = {
                "description": config["description"],
                "type": config["type"],
                "available": files_exist,
                "missing_files": missing_files,
                "is_default": model_name == self.default_model,
                "is_cached": model_name in self._model_cache
            }

        return available

    def clear_cache(self, model_name: Optional[str] = None) -> Dict[str, Any]:
        """Clear the model cache for a specific model, or for all models if none is given."""
        if model_name:
            if model_name in self._model_cache:
                del self._model_cache[model_name]
                return {"message": f"Cache cleared for model: {model_name}"}
            else:
                return {"message": f"Model {model_name} was not cached"}
        else:
            cleared_count = len(self._model_cache)
            self._model_cache.clear()
            return {"message": f"Cache cleared for {cleared_count} models"}

    def get_cache_status(self) -> Dict[str, Any]:
        """Get information about the cached models."""
        return {
            "cached_models": list(self._model_cache.keys()),
            "cache_count": len(self._model_cache),
            "default_model": self.default_model
        }
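

# A minimal usage sketch. It assumes the artifact paths registered in
# AVAILABLE_MODELS exist on disk for any model that predict() is called on;
# get_available_models() and get_cache_status() are safe either way. The
# sample text below is arbitrary Arabic input chosen for illustration.
if __name__ == "__main__":
    manager = ModelManager(default_model="traditional_svm")

    # Inspect the registry and the (initially empty) cache without loading anything.
    print(manager.get_available_models())
    print(manager.get_cache_status())

    # Predicting loads the model on first use and caches it for later calls.
    if manager.get_available_models()["traditional_svm"]["available"]:
        result = manager.predict("هذا نص تجريبي للتصنيف")
        print(result["model_manager"])
        print(manager.get_cache_status())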