import re
from typing import Any, Dict, List, Optional

import numpy as np
import torch
import torch.nn as nn
from safetensors.torch import load_file
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer

from preprocessor import preprocess_for_classification


class LSTMClassifier(nn.Module):
    """LSTM-based Arabic text classifier."""

    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim,
                 num_layers=2, bidirectional=False):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.bidirectional = bidirectional
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers,
            batch_first=True,
            # Inter-layer dropout is only defined for stacked LSTMs; PyTorch
            # warns if it is set while num_layers == 1.
            dropout=0.3 if num_layers > 1 else 0.0,
            bidirectional=self.bidirectional,
        )
        fc_input_dim = hidden_dim * 2 if self.bidirectional else hidden_dim
        self.fc = nn.Linear(fc_input_dim, output_dim)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        embedded = self.embedding(x)
        _, (hidden, _) = self.lstm(embedded)
        if self.bidirectional:
            # The last two entries of `hidden` are the final forward and
            # backward states of the top layer; concatenate them for the head.
            forward_hidden = hidden[-2]
            backward_hidden = hidden[-1]
            h = torch.cat((forward_hidden, backward_hidden), dim=1)
        else:
            h = hidden[-1]
        return self.fc(self.dropout(h))
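

# Shape walk-through for LSTMClassifier.forward (the example sizes are
# illustrative assumptions, not taken from a shipped checkpoint):
#   x        -> (batch, seq_len),                e.g. (8, 100)
#   embedded -> (batch, seq_len, embedding_dim), e.g. (8, 100, 300)
#   hidden   -> (num_layers * num_directions, batch, hidden_dim)
#   h        -> (batch, hidden_dim * 2) if bidirectional else (batch, hidden_dim)
#   logits   -> (batch, output_dim)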


class ModernClassifier:
    """Modern Arabic text classifier supporting BERT and LSTM models."""

    def __init__(self, model_type: str, model_path: str, config_path: Optional[str] = None):
        self.model_type = model_type.lower()
        self.model_path = model_path
        self.config_path = config_path
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.classes = np.array(['culture', 'economy', 'international', 'local', 'religion', 'sports'])

        if self.model_type == 'bert':
            self._load_bert_model()
        elif self.model_type == 'lstm':
            self._load_lstm_model()
        else:
            raise ValueError(f"Unsupported model type: {model_type}")

        self.model_name = f"{model_type}_classifier"

    def _load_bert_model(self):
        """Load the BERT model from a safetensors checkpoint."""
        try:
            tokenizer_options = [
                'asafaya/bert-base-arabic',
                'aubmindlab/bert-base-arabertv02',
                'aubmindlab/bert-base-arabertv2',
            ]

            # Prefer a locally cached tokenizer whose vocabulary fits the
            # 32K checkpoint; only fall back to downloading one.
            self.tokenizer = None
            for tokenizer_name in tokenizer_options:
                try:
                    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, local_files_only=True)
                    if len(tokenizer.vocab) <= 32000:
                        self.tokenizer = tokenizer
                        print(f"Using tokenizer: {tokenizer_name} (vocab size: {len(tokenizer.vocab)})")
                        break
                except Exception:
                    continue

            if self.tokenizer is None:
                for tokenizer_name in tokenizer_options:
                    try:
                        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
                        if len(tokenizer.vocab) <= 32000:
                            self.tokenizer = tokenizer
                            print(f"Downloaded tokenizer: {tokenizer_name} (vocab size: {len(tokenizer.vocab)})")
                            break
                    except Exception:
                        continue

            if self.tokenizer is None:
                raise RuntimeError("No compatible Arabic BERT tokenizer found with 32K vocabulary")

            # Size the model's embedding table from the checkpoint itself so
            # the state dict loads cleanly.
            state_dict = load_file(self.model_path)
            embed_key = next(k for k in state_dict if 'embeddings.word_embeddings.weight' in k)
            checkpoint_vocab_size = state_dict[embed_key].shape[0]

            try:
                config = AutoConfig.from_pretrained(
                    'aubmindlab/bert-base-arabertv2',
                    num_labels=len(self.classes),
                    vocab_size=checkpoint_vocab_size,
                    local_files_only=True
                )
            except Exception:
                try:
                    config = AutoConfig.from_pretrained(
                        'aubmindlab/bert-base-arabertv2',
                        num_labels=len(self.classes),
                        vocab_size=checkpoint_vocab_size
                    )
                except Exception:
                    # Offline with no cached config: fall back to a standard
                    # BERT-base configuration.
                    from transformers import BertConfig
                    config = BertConfig(
                        vocab_size=checkpoint_vocab_size,
                        hidden_size=768,
                        num_hidden_layers=12,
                        num_attention_heads=12,
                        intermediate_size=3072,
                        num_labels=len(self.classes)
                    )

            self.model = AutoModelForSequenceClassification.from_config(config)
            self.model.resize_token_embeddings(checkpoint_vocab_size)
            self.model.load_state_dict(state_dict, strict=False)
            self.model.to(self.device)
            self.model.eval()
        except Exception as e:
            raise RuntimeError(f"Error loading BERT model: {e}") from e
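
    # Sanity-check sketch (an assumption, not enforced anywhere in this
    # module): the tokenizer accepted above should never emit ids beyond the
    # checkpoint vocabulary; mismatches are what the clamping in
    # _preprocess_text_for_bert papers over.
    #
    #   assert len(self.tokenizer.vocab) <= self.model.config.vocab_size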

    def _load_lstm_model(self):
        """Load the LSTM model from a .pth checkpoint, inferring its
        hyperparameters from the weight shapes."""
        try:
            checkpoint = torch.load(self.model_path, map_location=self.device)
            state_dict = checkpoint.get('model_state_dict', checkpoint)
            # embedding.weight is (vocab_size, embedding_dim);
            # lstm.weight_hh_l0 is (4 * hidden_dim, hidden_dim).
            vocab_size, embedding_dim = state_dict['embedding.weight'].shape
            _, hidden_dim = state_dict['lstm.weight_hh_l0'].shape
            layer_pattern = re.compile(r'lstm\.weight_ih_l(\d+)$')
            layer_nums = {int(m.group(1)) for k in state_dict
                          if (m := layer_pattern.match(k))}
            num_layers = len(layer_nums)
            # Reverse-direction weights exist only in bidirectional checkpoints.
            bidirectional = 'lstm.weight_ih_l0_reverse' in state_dict
            output_dim = len(self.classes)
            self.model = LSTMClassifier(vocab_size, embedding_dim, hidden_dim,
                                        output_dim, num_layers=num_layers,
                                        bidirectional=bidirectional)
            self.model.load_state_dict(state_dict, strict=False)
            self.model.to(self.device)
            self.model.eval()
            self.vocab = checkpoint.get('vocab', {})
        except Exception as e:
            raise RuntimeError(f"Error loading LSTM model: {e}") from e

    def _preprocess_text_for_bert(self, text: str) -> Dict[str, torch.Tensor]:
        """Preprocess text for the BERT model."""
        cleaned_text = preprocess_for_classification(text)

        inputs = self.tokenizer(
            cleaned_text,
            return_tensors='pt',
            truncation=True,
            padding=True,
            max_length=512
        )

        # Guard against tokenizer/checkpoint vocabulary mismatches: clamp any
        # token id that falls outside the model's embedding table.
        input_ids = inputs['input_ids']
        max_token_id = input_ids.max().item()
        model_vocab_size = self.model.config.vocab_size
        if max_token_id >= model_vocab_size:
            inputs['input_ids'] = torch.clamp(input_ids, 0, model_vocab_size - 1)

        return {key: value.to(self.device) for key, value in inputs.items()}

    def _preprocess_text_for_lstm(self, text: str) -> torch.Tensor:
        """Preprocess text for the LSTM model."""
        cleaned_text = preprocess_for_classification(text)
        tokens = cleaned_text.split()

        if hasattr(self, 'vocab') and self.vocab:
            indices = [self.vocab.get(token, 0) for token in tokens]
        else:
            # No saved vocabulary: fall back to hashed ids. This keeps the
            # model runnable, but predictions will be unreliable.
            indices = [hash(token) % 10000 for token in tokens]

        # Truncate or zero-pad to a fixed sequence length.
        max_length = 100
        if len(indices) > max_length:
            indices = indices[:max_length]
        else:
            indices.extend([0] * (max_length - len(indices)))

        return torch.tensor([indices], dtype=torch.long).to(self.device)

    def predict(self, text: str) -> Dict[str, Any]:
        """Predict the class of a text, with the full probability distribution and metadata."""
        cleaned_text = preprocess_for_classification(text)

        with torch.no_grad():
            if self.model_type == 'bert':
                inputs = self._preprocess_text_for_bert(text)
                logits = self.model(**inputs).logits
            else:  # 'lstm'; the constructor rejects anything else
                inputs = self._preprocess_text_for_lstm(text)
                logits = self.model(inputs)

        probabilities = torch.softmax(logits, dim=-1).cpu().numpy()
        if probabilities.ndim > 1:
            probabilities = probabilities[0]

        prediction_index = int(np.argmax(probabilities))
        prediction = self.classes[prediction_index]
        confidence = float(probabilities[prediction_index])

        prob_distribution = {str(label): float(probabilities[i])
                             for i, label in enumerate(self.classes)}

        return {
            "prediction": str(prediction),
            "prediction_index": prediction_index,
            "confidence": confidence,
            "probability_distribution": prob_distribution,
            "cleaned_text": cleaned_text,
            "model_used": self.model_name,
            "prediction_metadata": {
                "max_probability": float(np.max(probabilities)),
                "min_probability": float(np.min(probabilities)),
                "entropy": float(-np.sum(probabilities * np.log(probabilities + 1e-10))),
                "num_classes": len(probabilities),
                "model_type": self.model_type,
                "device": str(self.device)
            },
        }
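
    # Illustrative single-text call (the checkpoint path and input text are
    # placeholders, not shipped artifacts):
    #
    #   clf = ModernClassifier('lstm', 'models/lstm_classifier.pth')
    #   result = clf.predict('مباراة كرة القدم اليوم')
    #   result['prediction']   # e.g. 'sports'
    #   result['confidence']   # e.g. 0.91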

    def predict_batch(self, texts: List[str]) -> List[Dict[str, Any]]:
        """Predict classes for multiple texts using true batch processing."""
        if not texts:
            return []

        cleaned_texts = [preprocess_for_classification(text) for text in texts]

        with torch.no_grad():
            if self.model_type == 'bert':
                # Tokenize the whole batch at once; padding aligns lengths.
                inputs = self.tokenizer(
                    cleaned_texts,
                    return_tensors='pt',
                    truncation=True,
                    padding=True,
                    max_length=512
                )
                inputs = {key: value.to(self.device) for key, value in inputs.items()}
                logits = self.model(**inputs).logits
            else:  # 'lstm'
                batch_indices = []
                max_length = 100
                for cleaned_text in cleaned_texts:
                    tokens = cleaned_text.split()
                    if hasattr(self, 'vocab') and self.vocab:
                        indices = [self.vocab.get(token, 0) for token in tokens]
                    else:
                        indices = [hash(token) % 10000 for token in tokens]
                    if len(indices) > max_length:
                        indices = indices[:max_length]
                    else:
                        indices.extend([0] * (max_length - len(indices)))
                    batch_indices.append(indices)

                batch_tensor = torch.tensor(batch_indices, dtype=torch.long).to(self.device)
                logits = self.model(batch_tensor)

        probabilities = torch.softmax(logits, dim=-1).cpu().numpy()

        results = []
        for i, cleaned_text in enumerate(cleaned_texts):
            probs = probabilities[i]
            prediction_index = int(np.argmax(probs))
            prediction = self.classes[prediction_index]
            confidence = float(probs[prediction_index])

            prob_distribution = {str(label): float(probs[j])
                                 for j, label in enumerate(self.classes)}

            results.append({
                "prediction": str(prediction),
                "prediction_index": prediction_index,
                "confidence": confidence,
                "probability_distribution": prob_distribution,
                "cleaned_text": cleaned_text,
                "model_used": self.model_name,
                "prediction_metadata": {
                    "max_probability": float(np.max(probs)),
                    "min_probability": float(np.min(probs)),
                    "entropy": float(-np.sum(probs * np.log(probs + 1e-10))),
                    "num_classes": len(probs),
                    "model_type": self.model_type,
                    "device": str(self.device)
                },
            })

        return results

    def get_model_info(self) -> Dict[str, Any]:
        """Get model information and capabilities."""
        return {
            "model_name": self.model_name,
            "model_type": self.model_type,
            "model_path": self.model_path,
            "num_classes": len(self.classes),
            "classes": self.classes.tolist(),
            "device": str(self.device),
            "has_predict_proba": True,
            "framework": "pytorch",
            "modern_model": True
        }
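

# Minimal smoke test, assuming a trained LSTM checkpoint exists at the
# placeholder path below; adjust the path before running.
if __name__ == '__main__':
    classifier = ModernClassifier('lstm', 'models/lstm_classifier.pth')
    print(classifier.get_model_info())
    for item in classifier.predict_batch(['مباراة كرة القدم اليوم', 'ارتفاع أسعار النفط عالميا']):
        print(item['prediction'], round(item['confidence'], 3))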