# modern_classifier.py
import torch
import torch.nn as nn
import numpy as np
from typing import List, Dict, Any, Optional
from preprocessor import preprocess_for_classification
import re
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from safetensors.torch import load_file
from transformers import AutoConfig
class LSTMClassifier(nn.Module):
"""LSTM-based Arabic text classifier."""
def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, num_layers=2, bidirectional=False):
super(LSTMClassifier, self).__init__()
self.embedding = nn.Embedding(vocab_size, embedding_dim)
self.bidirectional = bidirectional
self.lstm = nn.LSTM(
embedding_dim,
hidden_dim,
num_layers,
batch_first=True,
dropout=0.3,
bidirectional=self.bidirectional
)
fc_input_dim = hidden_dim * 2 if self.bidirectional else hidden_dim
self.fc = nn.Linear(fc_input_dim, output_dim)
self.dropout = nn.Dropout(0.5)
def forward(self, x):
embedded = self.embedding(x)
_, (hidden, _) = self.lstm(embedded)
if self.bidirectional:
forward_hidden = hidden[-2]
backward_hidden = hidden[-1]
combined = torch.cat((forward_hidden, backward_hidden), dim=1)
h = combined
else:
h = hidden[-1]
output = self.fc(self.dropout(h))
return output
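
# Shape sketch (an illustration, not part of the original training code): for a batch
# of padded token-ID sequences of shape (batch, seq_len), LSTMClassifier returns raw
# logits of shape (batch, output_dim). A minimal smoke test might look like:
#
#     model = LSTMClassifier(vocab_size=10000, embedding_dim=128, hidden_dim=256,
#                            output_dim=6, num_layers=2, bidirectional=True)
#     dummy_ids = torch.randint(0, 10000, (2, 100))  # hypothetical batch of 2 sequences
#     logits = model(dummy_ids)                      # -> torch.Size([2, 6])
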
class ModernClassifier:
"""Modern Arabic text classifier supporting BERT and LSTM models."""
def __init__(self, model_type: str, model_path: str, config_path: Optional[str] = None):
self.model_type = model_type.lower()
self.model_path = model_path
self.config_path = config_path
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.classes = np.array(['culture', 'economy', 'international', 'local', 'religion', 'sports'])
if self.model_type == 'bert':
self._load_bert_model()
elif self.model_type == 'lstm':
self._load_lstm_model()
else:
raise ValueError(f"Unsupported model type: {model_type}")
        self.model_name = f"{self.model_type}_classifier"
def _load_bert_model(self):
"""Load BERT model from safetensors."""
try:
# Try different Arabic BERT tokenizers that match 32K vocabulary
tokenizer_options = [
'asafaya/bert-base-arabic', # This one has 32K vocab
'aubmindlab/bert-base-arabertv02', # Alternative
'aubmindlab/bert-base-arabertv2' # Fallback (64K vocab)
]
self.tokenizer = None
for tokenizer_name in tokenizer_options:
try:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, local_files_only=True)
# Test if vocabulary size matches
if len(tokenizer.vocab) <= 32000:
self.tokenizer = tokenizer
print(f"Using tokenizer: {tokenizer_name} (vocab size: {len(tokenizer.vocab)})")
break
                except Exception:
                    continue
if self.tokenizer is None:
# Try downloading if local files don't work
for tokenizer_name in tokenizer_options:
try:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
if len(tokenizer.vocab) <= 32000:
self.tokenizer = tokenizer
print(f"Downloaded tokenizer: {tokenizer_name} (vocab size: {len(tokenizer.vocab)})")
break
                    except Exception:
                        continue
if self.tokenizer is None:
raise RuntimeError("No compatible Arabic BERT tokenizer found with 32K vocabulary")
state_dict = load_file(self.model_path)
embed_key = next(k for k in state_dict if 'embeddings.word_embeddings.weight' in k)
checkpoint_vocab_size = state_dict[embed_key].shape[0]
# Try to load config locally first
try:
config = AutoConfig.from_pretrained(
'aubmindlab/bert-base-arabertv2',
num_labels=len(self.classes),
vocab_size=checkpoint_vocab_size,
local_files_only=True
)
            except Exception:
try:
config = AutoConfig.from_pretrained(
'aubmindlab/bert-base-arabertv2',
num_labels=len(self.classes),
vocab_size=checkpoint_vocab_size
)
                except Exception:
# Fallback: create a basic BERT config
from transformers import BertConfig
config = BertConfig(
vocab_size=checkpoint_vocab_size,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
num_labels=len(self.classes)
)
self.model = AutoModelForSequenceClassification.from_config(config)
self.model.resize_token_embeddings(checkpoint_vocab_size)
self.model.load_state_dict(state_dict, strict=False)
self.model.to(self.device)
self.model.eval()
except Exception as e:
raise RuntimeError(f"Error loading BERT model: {e}")
def _load_lstm_model(self):
"""Load LSTM model from .pth file."""
try:
checkpoint = torch.load(self.model_path, map_location=self.device)
state_dict = checkpoint.get('model_state_dict', checkpoint)
vocab_size, embedding_dim = state_dict['embedding.weight'].shape
_, hidden_dim = state_dict['lstm.weight_hh_l0'].shape
            # Infer the number of stacked layers from the forward-direction weight keys
            layer_matches = (re.match(r'lstm\.weight_ih_l(\d+)$', k) for k in state_dict)
            num_layers = len({int(m.group(1)) for m in layer_matches if m})
            # Bidirectional checkpoints expose *_reverse parameter keys
            bidirectional = any(k.startswith('lstm.') and k.endswith('_reverse') for k in state_dict)
output_dim = len(self.classes)
self.model = LSTMClassifier(vocab_size, embedding_dim, hidden_dim,
output_dim, num_layers=num_layers,
bidirectional=bidirectional)
self.model.load_state_dict(state_dict, strict=False)
self.model.to(self.device)
self.model.eval()
self.vocab = checkpoint.get('vocab', {})
except Exception as e:
raise RuntimeError(f"Error loading LSTM model: {e}")
def _preprocess_text_for_bert(self, text: str) -> Dict[str, torch.Tensor]:
"""Preprocess text for BERT model."""
cleaned_text = preprocess_for_classification(text)
inputs = self.tokenizer(
cleaned_text,
return_tensors='pt',
truncation=True,
padding=True,
max_length=512
)
        # Guard against tokenizer/model vocabulary mismatch: clamp token IDs into the
        # model's embedding range to prevent "index out of range" errors.
        input_ids = inputs['input_ids']
        max_token_id = input_ids.max().item()
        model_vocab_size = self.model.config.vocab_size
        if max_token_id >= model_vocab_size:
            inputs['input_ids'] = torch.clamp(input_ids, 0, model_vocab_size - 1)
return {key: value.to(self.device) for key, value in inputs.items()}
def _preprocess_text_for_lstm(self, text: str) -> torch.Tensor:
"""Preprocess text for LSTM model."""
cleaned_text = preprocess_for_classification(text)
tokens = cleaned_text.split()
        if hasattr(self, 'vocab') and self.vocab:
            indices = [self.vocab.get(token, 0) for token in tokens]
        else:
            # Fallback hashing: keep indices inside the embedding table to avoid index errors
            embedding_size = self.model.embedding.num_embeddings
            indices = [hash(token) % embedding_size for token in tokens]
max_length = 100
if len(indices) > max_length:
indices = indices[:max_length]
else:
indices.extend([0] * (max_length - len(indices)))
return torch.tensor([indices], dtype=torch.long).to(self.device)
def predict(self, text: str) -> Dict[str, Any]:
"""Predict class with full probability distribution and metadata."""
cleaned_text = preprocess_for_classification(text)
with torch.no_grad():
if self.model_type == 'bert':
inputs = self._preprocess_text_for_bert(text)
outputs = self.model(**inputs)
logits = outputs.logits
elif self.model_type == 'lstm':
inputs = self._preprocess_text_for_lstm(text)
logits = self.model(inputs)
probabilities = torch.softmax(logits, dim=-1).cpu().numpy()
# Handle batch dimension
if len(probabilities.shape) > 1:
probabilities = probabilities[0]
prediction_index = int(np.argmax(probabilities))
prediction = self.classes[prediction_index]
confidence = float(probabilities[prediction_index])
prob_distribution = {}
for i, class_label in enumerate(self.classes):
prob_distribution[str(class_label)] = float(probabilities[i])
return {
"prediction": str(prediction),
"prediction_index": prediction_index,
"confidence": confidence,
"probability_distribution": prob_distribution,
"cleaned_text": cleaned_text,
"model_used": self.model_name,
"prediction_metadata": {
"max_probability": float(np.max(probabilities)),
"min_probability": float(np.min(probabilities)),
"entropy": float(-np.sum(probabilities * np.log(probabilities + 1e-10))),
"num_classes": len(probabilities),
"model_type": self.model_type,
"device": str(self.device)
},
}
def predict_batch(self, texts: List[str]) -> List[Dict[str, Any]]:
"""Predict classes for multiple texts using true batch processing."""
if not texts:
return []
cleaned_texts = [preprocess_for_classification(text) for text in texts]
with torch.no_grad():
if self.model_type == 'bert':
inputs = self.tokenizer(
cleaned_texts,
return_tensors='pt',
truncation=True,
padding=True,
max_length=512
)
inputs = {key: value.to(self.device) for key, value in inputs.items()}
outputs = self.model(**inputs)
logits = outputs.logits
elif self.model_type == 'lstm':
batch_indices = []
max_length = 100
for cleaned_text in cleaned_texts:
tokens = cleaned_text.split()
                    if hasattr(self, 'vocab') and self.vocab:
                        indices = [self.vocab.get(token, 0) for token in tokens]
                    else:
                        # Same fallback as the single-text path: stay inside the embedding table
                        embedding_size = self.model.embedding.num_embeddings
                        indices = [hash(token) % embedding_size for token in tokens]
if len(indices) > max_length:
indices = indices[:max_length]
else:
indices.extend([0] * (max_length - len(indices)))
batch_indices.append(indices)
batch_tensor = torch.tensor(batch_indices, dtype=torch.long).to(self.device)
logits = self.model(batch_tensor)
probabilities = torch.softmax(logits, dim=-1).cpu().numpy()
results = []
for i, (text, cleaned_text) in enumerate(zip(texts, cleaned_texts)):
probs = probabilities[i]
prediction_index = int(np.argmax(probs))
prediction = self.classes[prediction_index]
confidence = float(probs[prediction_index])
prob_distribution = {}
for j, class_label in enumerate(self.classes):
prob_distribution[str(class_label)] = float(probs[j])
result = {
"prediction": str(prediction),
"prediction_index": prediction_index,
"confidence": confidence,
"probability_distribution": prob_distribution,
"cleaned_text": cleaned_text,
"model_used": self.model_name,
"prediction_metadata": {
"max_probability": float(np.max(probs)),
"min_probability": float(np.min(probs)),
"entropy": float(-np.sum(probs * np.log(probs + 1e-10))),
"num_classes": len(probs),
"model_type": self.model_type,
"device": str(self.device)
},
}
results.append(result)
return results
def get_model_info(self) -> Dict[str, Any]:
"""Get model information and capabilities."""
return {
"model_name": self.model_name,
"model_type": self.model_type,
"model_path": self.model_path,
"num_classes": len(self.classes),
"classes": self.classes.tolist(),
"device": str(self.device),
"has_predict_proba": True,
"framework": "pytorch",
"modern_model": True
}
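

# Minimal usage sketch (the model path below is a hypothetical placeholder; point it
# at the actual .safetensors / .pth checkpoint shipped with this repo):
if __name__ == "__main__":
    classifier = ModernClassifier("lstm", "models/lstm_classifier.pth")  # hypothetical path
    # "The local economy posts growth in the first quarter"
    result = classifier.predict("الاقتصاد المحلي يسجل نموا في الربع الأول")
    print(result["prediction"], round(result["confidence"], 3))
    batch = classifier.predict_batch(["خبر رياضي", "خبر اقتصادي"])  # "a sports story", "an economy story"
    print([r["prediction"] for r in batch])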