# arabic-summarizer-classifier / bert_summarizer.py
import torch
import numpy as np
import re
from typing import Dict, List, Any
from transformers import BertTokenizer, BertModel
from sklearn.metrics.pairwise import cosine_similarity
from preprocessor import preprocess_for_summarization
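
# Extractive approach used below: each sentence is embedded with AraBERT (the CLS
# token of the last hidden state), the document is represented by the mean of the
# sentence embeddings, and sentences are ranked by cosine similarity to that
# centroid. The top-k sentences are returned in their original order.
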
class BERTExtractiveSummarizer:
    def __init__(self, model_name='aubmindlab/bert-base-arabertv02'):
        """Initialize BERT-based Arabic summarizer."""
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {self.device}")

        # Load tokenizer and model
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.model = BertModel.from_pretrained(model_name)
        self.model.to(self.device)
        self.model.eval()

    def get_sentence_embeddings(self, sentences: List[str]) -> np.ndarray:
        """Get BERT embeddings for sentences."""
        embeddings = []
        with torch.no_grad():
            for sentence in sentences:
                # Tokenize
                inputs = self.tokenizer(
                    sentence,
                    return_tensors='pt',
                    max_length=512,
                    truncation=True,
                    padding=True
                ).to(self.device)

                # Get embeddings
                outputs = self.model(**inputs)

                # Use CLS token embedding
                embedding = outputs.last_hidden_state[:, 0, :].cpu().numpy()
                embeddings.append(embedding.squeeze())

        return np.array(embeddings)
    def summarize(self, text: str, num_sentences: int = 3) -> Dict[str, Any]:
        """
        Summarize Arabic text using BERT extractive summarization.
        Returns the same structure as other summarizers for consistency.
        """
        print(f"BERT Summarizer: Processing text with {len(text)} characters")

        # Use the same preprocessing as TF-IDF for fair comparison
        cleaned_text = preprocess_for_summarization(text)
        print(f"BERT Summarizer: After preprocessing: '{cleaned_text[:100]}...'")

        # Split into sentences - same approach as TF-IDF
        sentences = re.split(r'[.!؟\n]+', cleaned_text)
        sentences = [s.strip() for s in sentences if s.strip()]  # Same as TF-IDF
        print(f"BERT Summarizer: Found {len(sentences)} sentences")

        original_sentence_count = len(sentences)

        # If we have fewer sentences than requested, return all
        if len(sentences) <= num_sentences:
            print(f"BERT Summarizer: Returning all {len(sentences)} sentences (fewer than requested)")
            return {
                "summary": cleaned_text.strip(),  # Use cleaned text like TF-IDF
                "original_sentence_count": original_sentence_count,
                "summary_sentence_count": len(sentences),
                "sentences": sentences,
                "selected_indices": list(range(len(sentences))),
                "sentence_scores": [1.0] * len(sentences)  # All sentences selected
            }

        print("BERT Summarizer: Getting sentence embeddings...")

        # Get sentence embeddings
        sentence_embeddings = self.get_sentence_embeddings(sentences)
        print(f"BERT Summarizer: Got embeddings shape: {sentence_embeddings.shape}")

        # Calculate document embedding (mean of all sentences)
        doc_embedding = np.mean(sentence_embeddings, axis=0)

        # Calculate similarity scores
        similarities = cosine_similarity([doc_embedding], sentence_embeddings)[0]
        print(f"BERT Summarizer: Similarity scores: {similarities}")

        # Get top sentences (indices with highest scores)
        top_indices = np.argsort(similarities)[-num_sentences:]
        print(f"BERT Summarizer: Top indices: {top_indices}")

        # Sort indices to maintain original order in summary
        top_indices_sorted = sorted(top_indices)

        # Convert numpy indices to regular ints for JSON serialization
        top_indices_sorted = [int(i) for i in top_indices_sorted]
        print(f"BERT Summarizer: Selected indices (in order): {top_indices_sorted}")

        # Get selected sentences and their scores
        selected_sentences = [sentences[i] for i in top_indices_sorted]
        selected_scores = [float(similarities[i]) for i in top_indices_sorted]
        print(f"BERT Summarizer: Selected sentences: {[s[:50] + '...' for s in selected_sentences]}")

        # Create summary by joining selected sentences
        summary = ' '.join(selected_sentences)

        return {
            "summary": summary,
            "original_sentence_count": original_sentence_count,
            "summary_sentence_count": len(selected_sentences),
            "sentences": sentences,  # All original sentences
            "selected_indices": top_indices_sorted,
            "sentence_scores": selected_scores,
            "top_sentence_scores": selected_scores  # Additional info
        }
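

# Minimal usage sketch. The Arabic sample text and num_sentences value are
# illustrative placeholders; running this downloads the AraBERT weights on first use.
if __name__ == "__main__":
    summarizer = BERTExtractiveSummarizer()
    sample_text = (
        "الذكاء الاصطناعي يغير طريقة عملنا. "
        "تساعد النماذج اللغوية في تلخيص النصوص الطويلة. "
        "التلخيص الاستخراجي يختار أهم الجمل من النص الأصلي. "
        "تعتمد هذه الطريقة على تشابه الجمل مع النص الكامل."
    )
    result = summarizer.summarize(sample_text, num_sentences=2)
    print("Summary:", result["summary"])
    print("Selected indices:", result["selected_indices"])
    print("Sentence scores:", result["sentence_scores"])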