import torch
import torch.nn as nn
import numpy as np
from typing import List, Dict, Any, Optional
from preprocessor import preprocess_for_classification
import re
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig
from safetensors.torch import load_file


class LSTMClassifier(nn.Module):
    """LSTM-based Arabic text classifier."""
    
    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, num_layers=2, bidirectional=False):
        super(LSTMClassifier, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.bidirectional = bidirectional
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers,
            batch_first=True,
            dropout=0.3,
            bidirectional=self.bidirectional
        )
        fc_input_dim = hidden_dim * 2 if self.bidirectional else hidden_dim
        self.fc = nn.Linear(fc_input_dim, output_dim)
        self.dropout = nn.Dropout(0.5)
        
    def forward(self, x):
        embedded = self.embedding(x)
        _, (hidden, _) = self.lstm(embedded)
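        # hidden has shape (num_layers * num_directions, batch, hidden_dim); for a
        # bidirectional LSTM the last two slices are the final forward and backward
        # states of the top layer.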
        if self.bidirectional:
            forward_hidden = hidden[-2]
            backward_hidden = hidden[-1]
            combined = torch.cat((forward_hidden, backward_hidden), dim=1)
            h = combined
        else:
            h = hidden[-1]
        output = self.fc(self.dropout(h))
        return output


class ModernClassifier:
    """Modern Arabic text classifier supporting BERT and LSTM models."""
    
    def __init__(self, model_type: str, model_path: str, config_path: Optional[str] = None):
        self.model_type = model_type.lower()
        self.model_path = model_path
        self.config_path = config_path
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
        self.classes = np.array(['culture', 'economy', 'international', 'local', 'religion', 'sports'])
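        # NOTE: this class order is assumed to match the label encoding used when
        # the checkpoints were trained; reorder it if your training labels differ.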
        
        if self.model_type == 'bert':
            self._load_bert_model()
        elif self.model_type == 'lstm':
            self._load_lstm_model()
        else:
            raise ValueError(f"Unsupported model type: {model_type}")
            
        self.model_name = f"{model_type}_classifier"
    
    def _load_bert_model(self):
        """Load BERT model from safetensors."""
        try:
            # Try different Arabic BERT tokenizers that match 32K vocabulary
            tokenizer_options = [
                'asafaya/bert-base-arabic',  # This one has 32K vocab
                'aubmindlab/bert-base-arabertv02',  # Alternative
                'aubmindlab/bert-base-arabertv2'   # Fallback (64K vocab)
            ]
            
            self.tokenizer = None
            for tokenizer_name in tokenizer_options:
                try:
                    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, local_files_only=True)
                    # Test if vocabulary size matches
                    if len(tokenizer.vocab) <= 32000:
                        self.tokenizer = tokenizer
                        print(f"Using tokenizer: {tokenizer_name} (vocab size: {len(tokenizer.vocab)})")
                        break
                except Exception:
                    continue
            
            if self.tokenizer is None:
                # Try downloading if local files don't work
                for tokenizer_name in tokenizer_options:
                    try:
                        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
                        if len(tokenizer.vocab) <= 32000:
                            self.tokenizer = tokenizer
                            print(f"Downloaded tokenizer: {tokenizer_name} (vocab size: {len(tokenizer.vocab)})")
                            break
                    except Exception:
                        continue
            
            if self.tokenizer is None:
                raise RuntimeError("No compatible Arabic BERT tokenizer found with 32K vocabulary")
            
            state_dict = load_file(self.model_path)
            embed_key = next(k for k in state_dict if 'embeddings.word_embeddings.weight' in k)
            checkpoint_vocab_size = state_dict[embed_key].shape[0]
            
            # Try to load config locally first
            try:
                config = AutoConfig.from_pretrained(
                    'aubmindlab/bert-base-arabertv2',
                    num_labels=len(self.classes),
                    vocab_size=checkpoint_vocab_size,
                    local_files_only=True
                )
            except Exception:
                try:
                    config = AutoConfig.from_pretrained(
                        'aubmindlab/bert-base-arabertv2',
                        num_labels=len(self.classes),
                        vocab_size=checkpoint_vocab_size
                    )
                except Exception:
                    # Fallback: create a basic BERT config
                    from transformers import BertConfig
                    config = BertConfig(
                        vocab_size=checkpoint_vocab_size,
                        hidden_size=768,
                        num_hidden_layers=12,
                        num_attention_heads=12,
                        intermediate_size=3072,
                        num_labels=len(self.classes)
                    )
            
            self.model = AutoModelForSequenceClassification.from_config(config)
            self.model.resize_token_embeddings(checkpoint_vocab_size)
            self.model.load_state_dict(state_dict, strict=False)
            self.model.to(self.device)
            self.model.eval()
        except Exception as e:
            raise RuntimeError(f"Error loading BERT model: {e}")
    
    def _load_lstm_model(self):
        """Load LSTM model from .pth file."""
        try:
            checkpoint = torch.load(self.model_path, map_location=self.device)
            state_dict = checkpoint.get('model_state_dict', checkpoint)
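            # Infer architecture hyperparameters from the checkpoint tensor shapes:
            # embedding.weight is (vocab_size, embedding_dim); lstm.weight_hh_l0 is
            # (4 * hidden_dim, hidden_dim), so its second dimension gives hidden_dim.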
            vocab_size, embedding_dim = state_dict['embedding.weight'].shape
            _, hidden_dim = state_dict['lstm.weight_hh_l0'].shape
            layer_pattern = re.compile(r'lstm\.weight_ih_l(\d+)$')
            layer_nums = {int(m.group(1)) for m in map(layer_pattern.match, state_dict) if m}
            num_layers = len(layer_nums)
            # Detect bidirectionality from the presence of reverse-direction weights.
            bidirectional = any(k.startswith('lstm.') and k.endswith('_reverse') for k in state_dict)
            output_dim = len(self.classes)
            self.model = LSTMClassifier(vocab_size, embedding_dim, hidden_dim,
                                        output_dim, num_layers=num_layers,
                                        bidirectional=bidirectional)
            self.model.load_state_dict(state_dict, strict=False)
            self.model.to(self.device)
            self.model.eval()
            self.vocab = checkpoint.get('vocab', {})
        except Exception as e:
            raise RuntimeError(f"Error loading LSTM model: {e}")
    
    def _preprocess_text_for_bert(self, text: str) -> Dict[str, torch.Tensor]:
        """Preprocess text for BERT model."""
        cleaned_text = preprocess_for_classification(text)
        
        inputs = self.tokenizer(
            cleaned_text,
            return_tensors='pt',
            truncation=True,
            padding=True,
            max_length=512
        )
        
        # CRITICAL FIX: Check for vocabulary mismatch and clamp token IDs
        input_ids = inputs['input_ids']
        max_token_id = input_ids.max().item()
        model_vocab_size = self.model.config.vocab_size
        
        if max_token_id >= model_vocab_size:
            # Fix: Clamp token IDs to valid range to prevent "index out of range" error
            inputs['input_ids'] = torch.clamp(input_ids, 0, model_vocab_size - 1)
        
        return {key: value.to(self.device) for key, value in inputs.items()}
    
    def _preprocess_text_for_lstm(self, text: str) -> torch.Tensor:
        """Preprocess text for LSTM model."""
        cleaned_text = preprocess_for_classification(text)
        
        tokens = cleaned_text.split()
        
        if hasattr(self, 'vocab') and self.vocab:
            indices = [self.vocab.get(token, 0) for token in tokens]
        else:
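            # Fallback when the checkpoint carries no vocabulary: hash tokens into a
            # fixed range. These indices will not match the training vocabulary, so
            # predictions in this mode are unreliable.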
            indices = [hash(token) % 10000 for token in tokens]
        
        max_length = 100
        if len(indices) > max_length:
            indices = indices[:max_length]
        else:
            indices.extend([0] * (max_length - len(indices)))
        
        return torch.tensor([indices], dtype=torch.long).to(self.device)
    
    def predict(self, text: str) -> Dict[str, Any]:
        """Predict class with full probability distribution and metadata."""
        cleaned_text = preprocess_for_classification(text)
        
        with torch.no_grad():
            if self.model_type == 'bert':
                inputs = self._preprocess_text_for_bert(text)
                outputs = self.model(**inputs)
                logits = outputs.logits
            elif self.model_type == 'lstm':
                inputs = self._preprocess_text_for_lstm(text)
                logits = self.model(inputs)
            
            probabilities = torch.softmax(logits, dim=-1).cpu().numpy()
            
            # Handle batch dimension
            if len(probabilities.shape) > 1:
                probabilities = probabilities[0]
            
            prediction_index = int(np.argmax(probabilities))
            prediction = self.classes[prediction_index]
            confidence = float(probabilities[prediction_index])
        
        prob_distribution = {}
        for i, class_label in enumerate(self.classes):
            prob_distribution[str(class_label)] = float(probabilities[i])
        
        return {
            "prediction": str(prediction),
            "prediction_index": prediction_index,
            "confidence": confidence,
            "probability_distribution": prob_distribution,
            "cleaned_text": cleaned_text,
            "model_used": self.model_name,
            "prediction_metadata": {
                "max_probability": float(np.max(probabilities)),
                "min_probability": float(np.min(probabilities)),
                "entropy": float(-np.sum(probabilities * np.log(probabilities + 1e-10))),
                "num_classes": len(probabilities),
                "model_type": self.model_type,
                "device": str(self.device)
            },
        }
    
    def predict_batch(self, texts: List[str]) -> List[Dict[str, Any]]:
        """Predict classes for multiple texts using true batch processing."""
        if not texts:
            return []
        
        cleaned_texts = [preprocess_for_classification(text) for text in texts]
        
        with torch.no_grad():
            if self.model_type == 'bert':
                inputs = self.tokenizer(
                    cleaned_texts,
                    return_tensors='pt',
                    truncation=True,
                    padding=True,
                    max_length=512
                )
                inputs = {key: value.to(self.device) for key, value in inputs.items()}
                outputs = self.model(**inputs)
                logits = outputs.logits
                
            elif self.model_type == 'lstm':
                batch_indices = []
                max_length = 100
                
                for cleaned_text in cleaned_texts:
                    tokens = cleaned_text.split()
                    if hasattr(self, 'vocab') and self.vocab:
                        indices = [self.vocab.get(token, 0) for token in tokens]
                    else:
                        indices = [hash(token) % 10000 for token in tokens]
                    
                    if len(indices) > max_length:
                        indices = indices[:max_length]
                    else:
                        indices.extend([0] * (max_length - len(indices)))
                    
                    batch_indices.append(indices)
                
                batch_tensor = torch.tensor(batch_indices, dtype=torch.long).to(self.device)
                logits = self.model(batch_tensor)
            
            probabilities = torch.softmax(logits, dim=-1).cpu().numpy()
            
            results = []
            for i, (text, cleaned_text) in enumerate(zip(texts, cleaned_texts)):
                probs = probabilities[i]
                prediction_index = int(np.argmax(probs))
                prediction = self.classes[prediction_index]
                confidence = float(probs[prediction_index])
                
                prob_distribution = {}
                for j, class_label in enumerate(self.classes):
                    prob_distribution[str(class_label)] = float(probs[j])
                
                result = {
                    "prediction": str(prediction),
                    "prediction_index": prediction_index,
                    "confidence": confidence,
                    "probability_distribution": prob_distribution,
                    "cleaned_text": cleaned_text,
                    "model_used": self.model_name,
                    "prediction_metadata": {
                        "max_probability": float(np.max(probs)),
                        "min_probability": float(np.min(probs)),
                        "entropy": float(-np.sum(probs * np.log(probs + 1e-10))),
                        "num_classes": len(probs),
                        "model_type": self.model_type,
                        "device": str(self.device)
                    },
                }
                results.append(result)
        
        return results
    
    def get_model_info(self) -> Dict[str, Any]:
        """Get model information and capabilities."""
        return {
            "model_name": self.model_name,
            "model_type": self.model_type,
            "model_path": self.model_path,
            "num_classes": len(self.classes),
            "classes": self.classes.tolist(),
            "device": str(self.device),
            "has_predict_proba": True,
            "framework": "pytorch",
            "modern_model": True
        }
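

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): the checkpoint path below is a
    # placeholder and must point to a real .pth (LSTM) or .safetensors (BERT) file.
    classifier = ModernClassifier(
        model_type="lstm",
        model_path="models/lstm_classifier.pth",  # hypothetical path
    )
    result = classifier.predict("الدوري الممتاز يستأنف مبارياته هذا الأسبوع")  # sample Arabic text
    print(result["prediction"], result["confidence"])
    print(classifier.get_model_info())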