import re
from typing import List, Dict, Any, Optional

import torch
import torch.nn as nn
import numpy as np
from safetensors.torch import load_file
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig

from preprocessor import preprocess_for_classification


class LSTMClassifier(nn.Module):
    """LSTM-based Arabic text classifier."""

    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim,
                 num_layers=2, bidirectional=False):
        super(LSTMClassifier, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.bidirectional = bidirectional
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers,
            batch_first=True,
            dropout=0.3,
            bidirectional=self.bidirectional
        )
        fc_input_dim = hidden_dim * 2 if self.bidirectional else hidden_dim
        self.fc = nn.Linear(fc_input_dim, output_dim)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        embedded = self.embedding(x)
        _, (hidden, _) = self.lstm(embedded)
        if self.bidirectional:
            # Concatenate the last layer's forward and backward hidden states.
            forward_hidden = hidden[-2]
            backward_hidden = hidden[-1]
            h = torch.cat((forward_hidden, backward_hidden), dim=1)
        else:
            h = hidden[-1]
        output = self.fc(self.dropout(h))
        return output


class ModernClassifier:
    """Modern Arabic text classifier supporting BERT and LSTM models."""

    def __init__(self, model_type: str, model_path: str, config_path: Optional[str] = None):
        self.model_type = model_type.lower()
        self.model_path = model_path
        self.config_path = config_path
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.classes = np.array(['culture', 'economy', 'international', 'local', 'religion', 'sports'])

        if self.model_type == 'bert':
            self._load_bert_model()
        elif self.model_type == 'lstm':
            self._load_lstm_model()
        else:
            raise ValueError(f"Unsupported model type: {model_type}")

        self.model_name = f"{model_type}_classifier"

    def _load_bert_model(self):
        """Load BERT model from safetensors."""
        try:
            # Try different Arabic BERT tokenizers that match the 32K vocabulary.
            tokenizer_options = [
                'asafaya/bert-base-arabic',         # This one has a 32K vocab
                'aubmindlab/bert-base-arabertv02',  # Alternative
                'aubmindlab/bert-base-arabertv2'    # Fallback (64K vocab)
            ]

            self.tokenizer = None
            for tokenizer_name in tokenizer_options:
                try:
                    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, local_files_only=True)
                    # Test whether the vocabulary size matches.
                    if len(tokenizer.vocab) <= 32000:
                        self.tokenizer = tokenizer
                        print(f"Using tokenizer: {tokenizer_name} (vocab size: {len(tokenizer.vocab)})")
                        break
                except Exception:
                    continue

            if self.tokenizer is None:
                # Try downloading if local files don't work.
                for tokenizer_name in tokenizer_options:
                    try:
                        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
                        if len(tokenizer.vocab) <= 32000:
                            self.tokenizer = tokenizer
                            print(f"Downloaded tokenizer: {tokenizer_name} (vocab size: {len(tokenizer.vocab)})")
                            break
                    except Exception:
                        continue

            if self.tokenizer is None:
                raise RuntimeError("No compatible Arabic BERT tokenizer found with 32K vocabulary")

            state_dict = load_file(self.model_path)
            embed_key = next(k for k in state_dict if 'embeddings.word_embeddings.weight' in k)
            checkpoint_vocab_size = state_dict[embed_key].shape[0]

            # Try to load the config locally first.
            try:
                config = AutoConfig.from_pretrained(
                    'aubmindlab/bert-base-arabertv2',
                    num_labels=len(self.classes),
                    vocab_size=checkpoint_vocab_size,
                    local_files_only=True
                )
            except Exception:
                try:
                    config = AutoConfig.from_pretrained(
                        'aubmindlab/bert-base-arabertv2',
                        num_labels=len(self.classes),
                        vocab_size=checkpoint_vocab_size
                    )
                except Exception:
                    # Fallback: create a basic BERT config.
                    from transformers import BertConfig
                    config = BertConfig(
                        vocab_size=checkpoint_vocab_size,
                        hidden_size=768,
                        num_hidden_layers=12,
                        num_attention_heads=12,
                        intermediate_size=3072,
                        num_labels=len(self.classes)
                    )

            self.model = AutoModelForSequenceClassification.from_config(config)
            self.model.resize_token_embeddings(checkpoint_vocab_size)
            self.model.load_state_dict(state_dict, strict=False)
            self.model.to(self.device)
            self.model.eval()
        except Exception as e:
            raise RuntimeError(f"Error loading BERT model: {e}")

    def _load_lstm_model(self):
        """Load LSTM model from .pth file."""
        try:
            checkpoint = torch.load(self.model_path, map_location=self.device)
            state_dict = checkpoint.get('model_state_dict', checkpoint)

            # Infer the architecture from the checkpoint shapes.
            vocab_size, embedding_dim = state_dict['embedding.weight'].shape
            _, hidden_dim = state_dict['lstm.weight_hh_l0'].shape
            layer_nums = set(
                int(re.match(r'lstm\.weight_ih_l(\d+)$', k).group(1))
                for k in state_dict if re.match(r'lstm\.weight_ih_l(\d+)$', k)
            )
            num_layers = len(layer_nums)
            # Detect bidirectionality from the presence of reverse-direction weights
            # instead of hardcoding it, so unidirectional checkpoints also load correctly.
            bidirectional = any('_reverse' in k for k in state_dict if k.startswith('lstm.'))
            output_dim = len(self.classes)

            self.model = LSTMClassifier(vocab_size, embedding_dim, hidden_dim, output_dim,
                                        num_layers=num_layers, bidirectional=bidirectional)
            self.model.load_state_dict(state_dict, strict=False)
            self.model.to(self.device)
            self.model.eval()

            self.vocab = checkpoint.get('vocab', {})
        except Exception as e:
            raise RuntimeError(f"Error loading LSTM model: {e}")

    def _preprocess_text_for_bert(self, text: str) -> Dict[str, torch.Tensor]:
        """Preprocess text for BERT model."""
        cleaned_text = preprocess_for_classification(text)
        inputs = self.tokenizer(
            cleaned_text,
            return_tensors='pt',
            truncation=True,
            padding=True,
            max_length=512
        )

        # CRITICAL FIX: Check for vocabulary mismatch and clamp token IDs.
        input_ids = inputs['input_ids']
        max_token_id = input_ids.max().item()
        model_vocab_size = self.model.config.vocab_size
        if max_token_id >= model_vocab_size:
            # Clamp token IDs to the valid range to prevent "index out of range" errors.
            inputs['input_ids'] = torch.clamp(input_ids, 0, model_vocab_size - 1)

        return {key: value.to(self.device) for key, value in inputs.items()}

    def _preprocess_text_for_lstm(self, text: str) -> torch.Tensor:
        """Preprocess text for LSTM model."""
        cleaned_text = preprocess_for_classification(text)
        tokens = cleaned_text.split()

        if hasattr(self, 'vocab') and self.vocab:
            indices = [self.vocab.get(token, 0) for token in tokens]
        else:
            indices = [hash(token) % 10000 for token in tokens]

        # Pad or truncate to a fixed sequence length.
        max_length = 100
        if len(indices) > max_length:
            indices = indices[:max_length]
        else:
            indices.extend([0] * (max_length - len(indices)))

        return torch.tensor([indices], dtype=torch.long).to(self.device)

    def predict(self, text: str) -> Dict[str, Any]:
        """Predict class with full probability distribution and metadata."""
        cleaned_text = preprocess_for_classification(text)

        with torch.no_grad():
            if self.model_type == 'bert':
                inputs = self._preprocess_text_for_bert(text)
                outputs = self.model(**inputs)
                logits = outputs.logits
            elif self.model_type == 'lstm':
                inputs = self._preprocess_text_for_lstm(text)
                logits = self.model(inputs)

            probabilities = torch.softmax(logits, dim=-1).cpu().numpy()

        # Handle batch dimension.
        if len(probabilities.shape) > 1:
            probabilities = probabilities[0]

        prediction_index = int(np.argmax(probabilities))
        prediction = self.classes[prediction_index]
        confidence = float(probabilities[prediction_index])

        prob_distribution = {}
        for i, class_label in enumerate(self.classes):
            prob_distribution[str(class_label)] = float(probabilities[i])

        return {
            "prediction": str(prediction),
            "prediction_index": prediction_index,
            "confidence": confidence,
            "probability_distribution": prob_distribution,
            "cleaned_text": cleaned_text,
            "model_used": self.model_name,
            "prediction_metadata": {
                "max_probability": float(np.max(probabilities)),
                "min_probability": float(np.min(probabilities)),
                "entropy": float(-np.sum(probabilities * np.log(probabilities + 1e-10))),
                "num_classes": len(probabilities),
                "model_type": self.model_type,
                "device": str(self.device)
            },
        }

    def predict_batch(self, texts: List[str]) -> List[Dict[str, Any]]:
        """Predict classes for multiple texts using true batch processing."""
        if not texts:
            return []

        cleaned_texts = [preprocess_for_classification(text) for text in texts]

        with torch.no_grad():
            if self.model_type == 'bert':
                inputs = self.tokenizer(
                    cleaned_texts,
                    return_tensors='pt',
                    truncation=True,
                    padding=True,
                    max_length=512
                )
                inputs = {key: value.to(self.device) for key, value in inputs.items()}
                outputs = self.model(**inputs)
                logits = outputs.logits
            elif self.model_type == 'lstm':
                batch_indices = []
                max_length = 100
                for cleaned_text in cleaned_texts:
                    tokens = cleaned_text.split()
                    if hasattr(self, 'vocab') and self.vocab:
                        indices = [self.vocab.get(token, 0) for token in tokens]
                    else:
                        indices = [hash(token) % 10000 for token in tokens]
                    if len(indices) > max_length:
                        indices = indices[:max_length]
                    else:
                        indices.extend([0] * (max_length - len(indices)))
                    batch_indices.append(indices)
                batch_tensor = torch.tensor(batch_indices, dtype=torch.long).to(self.device)
                logits = self.model(batch_tensor)

            probabilities = torch.softmax(logits, dim=-1).cpu().numpy()

        results = []
        for i, (text, cleaned_text) in enumerate(zip(texts, cleaned_texts)):
            probs = probabilities[i]
            prediction_index = int(np.argmax(probs))
            prediction = self.classes[prediction_index]
            confidence = float(probs[prediction_index])

            prob_distribution = {}
            for j, class_label in enumerate(self.classes):
                prob_distribution[str(class_label)] = float(probs[j])

            result = {
                "prediction": str(prediction),
                "prediction_index": prediction_index,
                "confidence": confidence,
                "probability_distribution": prob_distribution,
                "cleaned_text": cleaned_text,
                "model_used": self.model_name,
                "prediction_metadata": {
                    "max_probability": float(np.max(probs)),
                    "min_probability": float(np.min(probs)),
                    "entropy": float(-np.sum(probs * np.log(probs + 1e-10))),
                    "num_classes": len(probs),
                    "model_type": self.model_type,
                    "device": str(self.device)
                },
            }
            results.append(result)

        return results

    def get_model_info(self) -> Dict[str, Any]:
        """Get model information and capabilities."""
        return {
            "model_name": self.model_name,
            "model_type": self.model_type,
            "model_path": self.model_path,
            "num_classes": len(self.classes),
            "classes": self.classes.tolist(),
            "device": str(self.device),
            "has_predict_proba": True,
            "framework": "pytorch",
            "modern_model": True
        }
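
# Minimal usage sketch. The checkpoint paths below are hypothetical placeholders
# (not part of the module) and the Arabic sample string is illustrative only.
if __name__ == "__main__":
    # Load a BERT-based classifier from a safetensors checkpoint and classify one text.
    classifier = ModernClassifier(model_type='bert', model_path='models/bert_classifier.safetensors')
    result = classifier.predict("نص عربي للتصنيف")
    print(result["prediction"], result["confidence"])

    # The same interface works for an LSTM checkpoint saved as a .pth file:
    # lstm_classifier = ModernClassifier(model_type='lstm', model_path='models/lstm_classifier.pth')
    # print(lstm_classifier.get_model_info())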