File size: 7,221 Bytes
123e49c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5fc9256
 
123e49c
 
 
 
 
 
5fc9256
123e49c
 
 
 
 
 
 
5fc9256
123e49c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import os
from typing import Any, Dict, List, Optional

from traditional_classifier import TraditionalClassifier

try:
    from modern_classifier import ModernClassifier
    MODERN_MODELS_AVAILABLE = True
except ImportError:
    MODERN_MODELS_AVAILABLE = False


class ModelManager:
    """Manage Arabic text classification models with per-request selection.

    Models are declared in :data:`AVAILABLE_MODELS` and loaded lazily on
    first use; loaded instances are cached per model name until evicted via
    :meth:`clear_cache`.
    """

    # Registry of supported models.
    # "traditional" entries require both a classifier and a vectorizer file;
    # "modern" entries require a weights file and may name an optional config.
    AVAILABLE_MODELS = {
        "traditional_svm": {
            "type": "traditional",
            "classifier_path": "models/traditional_svm_classifier.joblib",
            "vectorizer_path": "models/traditional_tfidf_vectorizer_classifier.joblib",
            "description": "Traditional SVM classifier with TF-IDF vectorization"
        },

        "modern_bert": {
            "type": "modern",
            "model_type": "bert",
            "model_path": "models/modern_bert_classifier.safetensors",
            "config_path": "config.json",
            "description": "Modern BERT-based transformer classifier"
        },

        "modern_lstm": {
            "type": "modern",
            "model_type": "lstm",
            "model_path": "models/modern_lstm_classifier.pth",
            "description": "Modern LSTM-based neural network classifier"
        }
    }

    def __init__(self, default_model: str = "traditional_svm"):
        """Create a manager that falls back to *default_model* whenever a
        call does not name a model explicitly."""
        self.default_model = default_model
        # Maps model name -> loaded classifier instance (populated lazily
        # by _get_model).
        self._model_cache: Dict[str, Any] = {}

    def _resolve_model_name(self, model_name: Optional[str]) -> str:
        """Return *model_name*, or the configured default when it is None."""
        return self.default_model if model_name is None else model_name

    def _get_model(self, model_name: str):
        """Return a loaded classifier for *model_name*, using the cache.

        Raises:
            ValueError: unknown model name, or a registry entry with an
                unrecognized ``"type"``.
            FileNotFoundError: a required model artifact is missing on disk.
            ImportError: a modern model was requested but PyTorch and
                transformers are not installed.
        """
        if model_name not in self.AVAILABLE_MODELS:
            raise ValueError(f"Model '{model_name}' not available. Available models: {list(self.AVAILABLE_MODELS.keys())}")

        if model_name in self._model_cache:
            return self._model_cache[model_name]

        model_config = self.AVAILABLE_MODELS[model_name]

        if model_config["type"] == "traditional":
            classifier_path = model_config["classifier_path"]
            vectorizer_path = model_config["vectorizer_path"]

            # Fail fast with a precise message rather than letting the
            # classifier constructor raise an opaque error.
            if not os.path.exists(classifier_path):
                raise FileNotFoundError(f"Classifier file not found: {classifier_path}")
            if not os.path.exists(vectorizer_path):
                raise FileNotFoundError(f"Vectorizer file not found: {vectorizer_path}")

            model = TraditionalClassifier(classifier_path, vectorizer_path)

        elif model_config["type"] == "modern":
            if not MODERN_MODELS_AVAILABLE:
                raise ImportError("Modern models require PyTorch and transformers")

            model_path = model_config["model_path"]

            if not os.path.exists(model_path):
                raise FileNotFoundError(f"Model file not found: {model_path}")

            # A missing config file is tolerated: pass None so the classifier
            # falls back to its built-in defaults.
            config_path = model_config.get("config_path")
            if config_path and not os.path.exists(config_path):
                config_path = None

            model = ModernClassifier(
                model_type=model_config["model_type"],
                model_path=model_path,
                config_path=config_path
            )
        else:
            # Guard against malformed registry entries; previously this path
            # fell through and crashed with UnboundLocalError on `model`.
            raise ValueError(f"Unknown model type '{model_config['type']}' for model '{model_name}'")

        self._model_cache[model_name] = model
        return model

    def predict(self, text: str, model_name: Optional[str] = None) -> Dict[str, Any]:
        """Classify *text* with the named model (default model when None).

        The classifier's result dict is annotated with a ``"model_manager"``
        entry identifying which model produced it.
        """
        model_name = self._resolve_model_name(model_name)

        model = self._get_model(model_name)
        result = model.predict(text)

        result["model_manager"] = {
            "model_used": model_name,
            "model_description": self.AVAILABLE_MODELS[model_name]["description"]
        }
        return result

    def predict_batch(self, texts: List[str], model_name: Optional[str] = None) -> List[Dict[str, Any]]:
        """Classify each text in *texts* with the named model (default when
        None), annotating every result dict as in :meth:`predict`."""
        model_name = self._resolve_model_name(model_name)

        model = self._get_model(model_name)
        results = model.predict_batch(texts)

        for result in results:
            result["model_manager"] = {
                "model_used": model_name,
                "model_description": self.AVAILABLE_MODELS[model_name]["description"]
            }
        return results

    def get_model_info(self, model_name: Optional[str] = None) -> Dict[str, Any]:
        """Return the classifier's own info dict for the named model
        (default when None), augmented with registry metadata.

        Note: this loads (and caches) the model if it is not already cached,
        so ``is_cached`` is True in the returned metadata.
        """
        model_name = self._resolve_model_name(model_name)

        model = self._get_model(model_name)
        model_info = model.get_model_info()
        model_info.update({
            "model_manager": {
                "model_name": model_name,
                "model_description": self.AVAILABLE_MODELS[model_name]["description"],
                "model_config": self.AVAILABLE_MODELS[model_name],
                "is_cached": model_name in self._model_cache
            }
        })
        return model_info

    def get_available_models(self) -> Dict[str, Any]:
        """Describe every registered model: description, type, whether its
        artifact files exist on disk, which files are missing, and its
        default/cached status. Does not load any model."""
        available = {}
        for model_name, config in self.AVAILABLE_MODELS.items():
            missing_files = []

            if config["type"] == "traditional":
                for file_key in ("classifier_path", "vectorizer_path"):
                    if not os.path.exists(config[file_key]):
                        missing_files.append(config[file_key])
            elif config["type"] == "modern":
                if not os.path.exists(config["model_path"]):
                    missing_files.append(config["model_path"])

            files_exist = not missing_files
            available[model_name] = {
                "description": config["description"],
                "type": config["type"],
                "available": files_exist,
                "missing_files": missing_files,
                "is_default": model_name == self.default_model,
                "is_cached": model_name in self._model_cache
            }

        return available

    def clear_cache(self, model_name: Optional[str] = None) -> Dict[str, Any]:
        """Evict one cached model (when *model_name* is given) or all of
        them, returning a human-readable status message."""
        if model_name:
            if model_name in self._model_cache:
                del self._model_cache[model_name]
                return {"message": f"Cache cleared for model: {model_name}"}
            else:
                return {"message": f"Model {model_name} was not cached"}
        else:
            cleared_count = len(self._model_cache)
            self._model_cache.clear()
            return {"message": f"Cache cleared for {cleared_count} models"}

    def get_cache_status(self) -> Dict[str, Any]:
        """Report which models are currently cached and the default model."""
        return {
            "cached_models": list(self._model_cache.keys()),
            "cache_count": len(self._model_cache),
            "default_model": self.default_model
        }