"""Seq2Seq Arabic text summarizer.

Defines a word-level tokenizer, an LSTM encoder-decoder model, and a
Seq2SeqSummarizer wrapper that loads a safetensors checkpoint and produces
abstractive summaries with an extractive fallback.
"""

import re
from collections import Counter
from typing import Any, Dict

import torch
import torch.nn as nn
from safetensors.torch import load_file

from preprocessor import preprocess_for_summarization


class Seq2SeqTokenizer:
    """Word-level Arabic tokenizer for Seq2Seq tasks."""

    def __init__(self, vocab_size=10000):
        self.vocab_size = vocab_size
        # Indices 0-3 are reserved for the special tokens.
        self.word2idx = {'<PAD>': 0, '<SOS>': 1, '<EOS>': 2, '<UNK>': 3}
        self.idx2word = {0: '<PAD>', 1: '<SOS>', 2: '<EOS>', 3: '<UNK>'}
        self.vocab_built = False

    def clean_arabic_text(self, text):
        """Normalize text: keep Arabic letters, digits, and basic punctuation."""
        if text is None or text == "":
            return ""
        text = str(text)

        text = preprocess_for_summarization(text)
        # Drop everything outside the Arabic Unicode blocks, whitespace,
        # digits, and a small set of punctuation marks.
        text = re.sub(r'[^\u0600-\u06FF\u0750-\u077F\s\d.,!?():\-]', ' ', text)
        text = re.sub(r'\s+', ' ', text.strip())

        return text

    def build_vocab_from_texts(self, texts, min_freq=2):
        """Build the vocabulary from raw texts, keeping words seen at least min_freq times."""
        word_counts = Counter()
        total_words = 0

        for text in texts:
            cleaned = self.clean_arabic_text(text)
            words = cleaned.split()
            word_counts.update(words)
            total_words += len(words)

        filtered_words = {word: count for word, count in word_counts.items()
                          if count >= min_freq and len(word.strip()) > 0}

        # Keep the most frequent words; four slots are reserved for the special tokens.
        most_common = sorted(filtered_words.items(), key=lambda x: x[1], reverse=True)
        vocab_words = most_common[:self.vocab_size - 4]

        for word, count in vocab_words:
            if word not in self.word2idx:
                idx = len(self.word2idx)
                self.word2idx[word] = idx
                self.idx2word[idx] = word

        self.vocab_built = True
        return len(self.word2idx)

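    # Usage sketch (illustrative; `corpus` is a hypothetical list of Arabic strings):
    #   tok = Seq2SeqTokenizer(vocab_size=10000)
    #   vocab_len = tok.build_vocab_from_texts(corpus, min_freq=2)
    # Words seen fewer than min_freq times are excluded and later encode to <UNK>.
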
    def encode(self, text, max_len, add_special=False):
        """Convert text to a fixed-length list of token indices."""
        cleaned = self.clean_arabic_text(text)
        words = cleaned.split()

        if add_special:
            # Note: if the text is longer than max_len, the truncation below
            # also drops the trailing <EOS> token.
            words = ['<SOS>'] + words + ['<EOS>']

        indices = []
        for word in words[:max_len]:
            indices.append(self.word2idx.get(word, self.word2idx['<UNK>']))

        # Pad to max_len with <PAD>.
        while len(indices) < max_len:
            indices.append(self.word2idx['<PAD>'])

        return indices[:max_len]

    def decode(self, indices, skip_special=True):
        """Convert token indices back to a space-joined string."""
        words = []
        for idx in indices:
            if isinstance(idx, torch.Tensor):
                idx = idx.item()

            word = self.idx2word.get(int(idx), '<UNK>')

            if skip_special:
                if word in ['<PAD>', '<SOS>']:
                    continue
                elif word == '<EOS>':
                    break

            # <UNK> tokens are also dropped when skipping special tokens.
            if word != '<UNK>' or not skip_special:
                words.append(word)

        return ' '.join(words)


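# Illustrative round trip (a sketch; the exact indices depend on the built vocab):
#   ids = tok.encode("<two-word text>", max_len=8, add_special=True)
#   -> [1, w1, w2, 2, 0, 0, 0, 0]   (<SOS>, words, <EOS>, then <PAD>)
#   tok.decode(ids)                 -> "word1 word2", special tokens skipped

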
class Seq2SeqModel(nn.Module):
    """LSTM encoder-decoder that feeds a static encoder context to the decoder."""

    def __init__(self, vocab_size=10000, embedding_dim=128, encoder_hidden=256, decoder_hidden=256):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.encoder_hidden = encoder_hidden
        self.decoder_hidden = decoder_hidden

        self.encoder_embedding = nn.Embedding(vocab_size, embedding_dim)
        self.decoder_embedding = nn.Embedding(vocab_size, embedding_dim)

        self.encoder_lstm = nn.LSTM(embedding_dim, encoder_hidden, batch_first=True)
        # The decoder consumes the target embedding concatenated with the
        # final encoder hidden state, hence the larger input size.
        self.decoder_lstm = nn.LSTM(embedding_dim + encoder_hidden, decoder_hidden, batch_first=True)

        # These two layers are not used in forward(); they are presumably kept
        # so the checkpoint's state_dict still loads with strict=True.
        self.attention = nn.Linear(encoder_hidden + decoder_hidden, decoder_hidden)
        self.context_combine = nn.Linear(encoder_hidden + decoder_hidden, decoder_hidden)
        self.output_proj = nn.Linear(decoder_hidden, vocab_size)

    def forward(self, src_seq, tgt_seq=None, max_len=50):
        batch_size = src_seq.size(0)

        src_embedded = self.encoder_embedding(src_seq)
        encoder_outputs, (encoder_hidden, encoder_cell) = self.encoder_lstm(src_embedded)

        if tgt_seq is not None:
            # Training path (teacher forcing): feed the full target sequence,
            # each step conditioned on the final encoder hidden state.
            tgt_embedded = self.decoder_embedding(tgt_seq)
            encoder_context = encoder_hidden[-1].unsqueeze(1).repeat(1, tgt_seq.size(1), 1)
            decoder_input = torch.cat([tgt_embedded, encoder_context], dim=2)

            # The decoder is initialized from the encoder state, which
            # requires encoder_hidden == decoder_hidden.
            decoder_outputs, _ = self.decoder_lstm(decoder_input, (encoder_hidden, encoder_cell))
            outputs = self.output_proj(decoder_outputs)
            return outputs
        else:
            # Inference path: greedy decoding, one token at a time.
            outputs = []
            decoder_hidden = encoder_hidden
            decoder_cell = encoder_cell

            # Start every sequence in the batch with <SOS> (index 1).
            decoder_input = torch.ones(batch_size, 1, dtype=torch.long, device=src_seq.device)

            for _ in range(max_len):
                tgt_embedded = self.decoder_embedding(decoder_input)
                encoder_context = encoder_hidden[-1].unsqueeze(1)
                decoder_input_combined = torch.cat([tgt_embedded, encoder_context], dim=2)

                decoder_output, (decoder_hidden, decoder_cell) = self.decoder_lstm(
                    decoder_input_combined, (decoder_hidden, decoder_cell)
                )

                output = self.output_proj(decoder_output)
                outputs.append(output)

                decoder_input = torch.argmax(output, dim=2)

                # Stop once every sequence has emitted <EOS> (index 2).
                # (Calling .item() here would raise for batch sizes > 1.)
                if (decoder_input == 2).all():
                    break

            if outputs:
                return torch.cat(outputs, dim=1)
            return torch.zeros(batch_size, 1, self.vocab_size, device=src_seq.device)


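# Shape sketch (illustrative values):
#   src_seq: (B, S) int64
#   teacher forcing:  forward(src_seq, tgt_seq) -> (B, T, vocab_size)
#   greedy inference: forward(src_seq)          -> (B, <=max_len, vocab_size)
# where B = batch size, S = source length, T = target length.

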
class Seq2SeqSummarizer:
    """Loads a trained Seq2SeqModel checkpoint and exposes a summarize() API."""

    def __init__(self, model_path: str):
        self.model_path = model_path
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Seq2Seq Summarizer: Using device: {self.device}")

        self.tokenizer = Seq2SeqTokenizer(vocab_size=10000)
        self._load_model()

    def _load_model(self):
        try:
            print(f"Seq2Seq Summarizer: Loading model from {self.model_path}")
            state_dict = load_file(self.model_path)

            # Infer the architecture from the checkpoint tensors: an embedding
            # weight is (vocab_size, embedding_dim), and an nn.LSTM's
            # weight_hh_l0 is (4 * hidden_size, hidden_size), so shape[1]
            # gives the hidden size.
            vocab_size, embedding_dim = state_dict['encoder_embedding.weight'].shape
            encoder_hidden = state_dict['encoder_lstm.weight_hh_l0'].shape[1]
            decoder_hidden = state_dict['decoder_lstm.weight_hh_l0'].shape[1]

            print(f"Model architecture: vocab={vocab_size}, emb={embedding_dim}, enc_h={encoder_hidden}, dec_h={decoder_hidden}")

            self.model = Seq2SeqModel(vocab_size, embedding_dim, encoder_hidden, decoder_hidden)

            self.model.load_state_dict(state_dict, strict=True)
            self.model.to(self.device)
            self.model.eval()

            self._build_basic_vocab()

            print("Seq2Seq Summarizer: Model loaded successfully")

        except Exception as e:
            raise RuntimeError(f"Error loading seq2seq model: {e}") from e

    def _build_basic_vocab(self):
        # Seed the tokenizer with a small set of common Arabic function words.
        # Note: this is far smaller than the model's vocabulary, so decoded
        # tokens outside this list map to <UNK> unless a full vocab is built.
        basic_arabic_words = [
            'في', 'من', 'إلى', 'على', 'هذا', 'هذه', 'التي', 'الذي', 'كان', 'كانت',
            'يكون', 'تكون', 'قال', 'قالت', 'يقول', 'تقول', 'بعد', 'قبل', 'أن', 'أنه',
            'ما', 'لا', 'نعم', 'كل', 'بعض', 'جميع', 'هناك', 'هنا', 'حيث', 'كيف',
            'متى', 'أين', 'لماذا', 'والذي', 'والتي', 'أيضا', 'كذلك', 'حول', 'خلال', 'عند'
        ]

        for word in basic_arabic_words:
            if word not in self.tokenizer.word2idx and len(self.tokenizer.word2idx) < self.tokenizer.vocab_size:
                idx = len(self.tokenizer.word2idx)
                self.tokenizer.word2idx[word] = idx
                self.tokenizer.idx2word[idx] = word

        self.tokenizer.vocab_built = True

    def _generate_summary(self, text: str, max_length: int = 50) -> str:
        try:
            src_tokens = self.tokenizer.encode(text, max_len=100, add_special=False)
            src_tensor = torch.tensor([src_tokens], dtype=torch.long, device=self.device)

            with torch.no_grad():
                output = self.model(src_tensor, max_len=max_length)

            if output.numel() > 0:
                predicted_ids = torch.argmax(output, dim=-1)
                predicted_ids = predicted_ids[0].cpu().numpy()

                summary = self.tokenizer.decode(predicted_ids, skip_special=True)

                # Accept the model output only if it is non-trivial and does
                # not contain the unwanted token 'متاح'.
                if summary.strip() and len(summary.strip()) > 5 and 'متاح' not in summary:
                    return summary.strip()

            # Model output was empty or unusable: fall back to extraction.
            return self._extractive_fallback(text)

        except Exception as e:
            print(f"Seq2Seq generation error: {e}")
            return self._extractive_fallback(text)

    def _extractive_fallback(self, text: str) -> str:
        """Return the first one or two unique sentences as a fallback summary."""
        sentences = re.split(r'[.!؟\n]+', text)
        sentences = [s.strip() for s in sentences if s.strip() and len(s.strip()) > 10]

        # De-duplicate while preserving order.
        unique_sentences = []
        for sent in sentences:
            if sent not in unique_sentences:
                unique_sentences.append(sent)

        if len(unique_sentences) >= 2:
            return '. '.join(unique_sentences[:2]) + '.'
        elif len(unique_sentences) == 1:
            return unique_sentences[0] + '.'
        return text[:150] + "..." if len(text) > 150 else text

    def summarize(self, text: str, num_sentences: int = 3) -> Dict[str, Any]:
        print(f"Seq2Seq Summarizer: Processing text with {len(text)} characters")

        cleaned_text = preprocess_for_summarization(text)
        print(f"Seq2Seq Summarizer: After preprocessing: '{cleaned_text[:100]}...'")

        sentences = re.split(r'[.!؟\n]+', cleaned_text)
        sentences = [s.strip() for s in sentences if s.strip()]

        print(f"Seq2Seq Summarizer: Found {len(sentences)} sentences")
        original_sentence_count = len(sentences)

        # Budget roughly 15 generated tokens per requested sentence, capped at 50.
        target_length = min(num_sentences * 15, 50)
        generated_summary = self._generate_summary(cleaned_text, max_length=target_length)
        print(f"Seq2Seq Summarizer: Generated summary: '{generated_summary[:100]}...'")

        summary_sentences = re.split(r'[.!؟\n]+', generated_summary)
        summary_sentences = [s.strip() for s in summary_sentences if s.strip()]

        # Top up with leading source sentences if generation fell short.
        if len(summary_sentences) < num_sentences and len(sentences) > len(summary_sentences):
            remaining_needed = num_sentences - len(summary_sentences)
            summary_sentences.extend(sentences[:remaining_needed])

        summary_sentences = summary_sentences[:num_sentences]
        final_summary = ' '.join(summary_sentences)

        # Scores are placeholders: the seq2seq path does not rank sentences.
        dummy_scores = [1.0] * len(sentences)
        selected_indices = list(range(min(len(sentences), len(summary_sentences))))

        return {
            "summary": final_summary,
            "original_sentence_count": original_sentence_count,
            "summary_sentence_count": len(summary_sentences),
            "sentences": sentences,
            "selected_indices": selected_indices,
            "sentence_scores": dummy_scores,
            "top_sentence_scores": [1.0] * len(summary_sentences),
            "generated_summary": generated_summary,
            "model_type": "seq2seq"
        }
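

# Minimal usage sketch. The checkpoint path below is hypothetical; point it
# at a real safetensors file trained with the Seq2SeqModel layout above.
if __name__ == "__main__":
    summarizer = Seq2SeqSummarizer("models/seq2seq_summarizer.safetensors")  # hypothetical path
    result = summarizer.summarize("النص العربي المراد تلخيصه يوضع هنا.", num_sentences=2)
    print(result["summary"])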