import torch
import torch.nn as nn
from typing import Dict, Any
from safetensors.torch import load_file
from preprocessor import preprocess_for_summarization
from collections import Counter
import re


class Seq2SeqTokenizer:
    """Word-level Arabic tokenizer for Seq2Seq tasks"""

    def __init__(self, vocab_size=10000):
        self.vocab_size = vocab_size
        # Special tokens: <PAD>=0, <SOS>=1, <EOS>=2, <UNK>=3. The greedy
        # decoder in Seq2SeqModel relies on <SOS> being index 1 and <EOS>
        # being index 2.
        self.word2idx = {'<PAD>': 0, '<SOS>': 1, '<EOS>': 2, '<UNK>': 3}
        self.idx2word = {0: '<PAD>', 1: '<SOS>', 2: '<EOS>', 3: '<UNK>'}
        self.vocab_built = False

    def clean_arabic_text(self, text):
        if text is None or text == "":
            return ""
        text = str(text)
        text = preprocess_for_summarization(text)
        # Keep Arabic letters (basic + supplement blocks), whitespace, digits,
        # and basic punctuation; replace everything else with spaces.
        text = re.sub(r'[^\u0600-\u06FF\u0750-\u077F\s\d.,!?():\-]', ' ', text)
        text = re.sub(r'\s+', ' ', text.strip())
        return text

    def build_vocab_from_texts(self, texts, min_freq=2):
        word_counts = Counter()
        for text in texts:
            word_counts.update(self.clean_arabic_text(text).split())
        # Keep words seen at least min_freq times, most frequent first,
        # reserving the four slots already taken by the special tokens.
        filtered_words = {word: count for word, count in word_counts.items()
                          if count >= min_freq and len(word.strip()) > 0}
        most_common = sorted(filtered_words.items(), key=lambda x: x[1], reverse=True)
        for word, _ in most_common[:self.vocab_size - 4]:
            if word not in self.word2idx:
                idx = len(self.word2idx)
                self.word2idx[word] = idx
                self.idx2word[idx] = word
        self.vocab_built = True
        return len(self.word2idx)

    def encode(self, text, max_len, add_special=False):
        words = self.clean_arabic_text(text).split()
        if add_special:
            words = ['<SOS>'] + words + ['<EOS>']
        indices = [self.word2idx.get(word, self.word2idx['<UNK>'])
                   for word in words[:max_len]]
        # Pad to a fixed length.
        while len(indices) < max_len:
            indices.append(self.word2idx['<PAD>'])
        return indices[:max_len]

    def decode(self, indices, skip_special=True):
        words = []
        for idx in indices:
            if isinstance(idx, torch.Tensor):
                idx = idx.item()
            word = self.idx2word.get(int(idx), '<UNK>')
            if skip_special:
                if word in ['<PAD>', '<SOS>']:
                    continue
                elif word == '<EOS>':
                    break
            if word != '<UNK>' or not skip_special:
                words.append(word)
        return ' '.join(words)
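
# A minimal round-trip sketch of the tokenizer (illustrative only; the sample
# strings are placeholders for a real training corpus):
#
#   tok = Seq2SeqTokenizer(vocab_size=10000)
#   tok.build_vocab_from_texts(["النص الأول", "النص الثاني"], min_freq=1)
#   ids = tok.encode("النص الأول", max_len=8, add_special=True)
#   # -> [1, w1, w2, 2, 0, 0, 0, 0]  (<SOS>, word ids, <EOS>, <PAD> padding)
#   print(tok.decode(ids, skip_special=True))  # "النص الأول"
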

class Seq2SeqModel(nn.Module):
    def __init__(self, vocab_size=10000, embedding_dim=128, encoder_hidden=256, decoder_hidden=256):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.encoder_hidden = encoder_hidden
        self.decoder_hidden = decoder_hidden
        self.encoder_embedding = nn.Embedding(vocab_size, embedding_dim)
        self.decoder_embedding = nn.Embedding(vocab_size, embedding_dim)
        self.encoder_lstm = nn.LSTM(embedding_dim, encoder_hidden, batch_first=True)
        # Each decoder step consumes the token embedding concatenated with the
        # final encoder hidden state.
        self.decoder_lstm = nn.LSTM(embedding_dim + encoder_hidden, decoder_hidden, batch_first=True)
        # Defined but never used in forward(); kept so that
        # load_state_dict(strict=True) matches checkpoints containing them.
        self.attention = nn.Linear(encoder_hidden + decoder_hidden, decoder_hidden)
        self.context_combine = nn.Linear(encoder_hidden + decoder_hidden, decoder_hidden)
        self.output_proj = nn.Linear(decoder_hidden, vocab_size)

    def forward(self, src_seq, tgt_seq=None, max_len=50):
        batch_size = src_seq.size(0)
        src_embedded = self.encoder_embedding(src_seq)
        encoder_outputs, (encoder_hidden, encoder_cell) = self.encoder_lstm(src_embedded)
        if tgt_seq is not None:
            # Training path: teacher forcing over the whole target sequence.
            tgt_embedded = self.decoder_embedding(tgt_seq)
            encoder_context = encoder_hidden[-1].unsqueeze(1).repeat(1, tgt_seq.size(1), 1)
            decoder_input = torch.cat([tgt_embedded, encoder_context], dim=2)
            decoder_outputs, _ = self.decoder_lstm(decoder_input, (encoder_hidden, encoder_cell))
            return self.output_proj(decoder_outputs)
        # Inference path: greedy decoding, one token at a time.
        outputs = []
        decoder_hidden = encoder_hidden
        decoder_cell = encoder_cell
        # Start every sequence with <SOS> (index 1).
        decoder_input = torch.ones(batch_size, 1, dtype=torch.long, device=src_seq.device)
        for _ in range(max_len):
            tgt_embedded = self.decoder_embedding(decoder_input)
            encoder_context = encoder_hidden[-1].unsqueeze(1)
            decoder_input_combined = torch.cat([tgt_embedded, encoder_context], dim=2)
            decoder_output, (decoder_hidden, decoder_cell) = self.decoder_lstm(
                decoder_input_combined, (decoder_hidden, decoder_cell)
            )
            output = self.output_proj(decoder_output)
            outputs.append(output)
            decoder_input = torch.argmax(output, dim=2)
            # Stop once every sequence in the batch has emitted <EOS> (index 2).
            if (decoder_input == 2).all():
                break
        if outputs:
            return torch.cat(outputs, dim=1)
        return torch.zeros(batch_size, 1, self.vocab_size, device=src_seq.device)


class Seq2SeqSummarizer:
    def __init__(self, model_path: str):
        self.model_path = model_path
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Seq2Seq Summarizer: Using device: {self.device}")
        self.tokenizer = Seq2SeqTokenizer(vocab_size=10000)
        self._load_model()

    def _load_model(self):
        try:
            print(f"Seq2Seq Summarizer: Loading model from {self.model_path}")
            state_dict = load_file(self.model_path)
            # Infer the architecture from checkpoint tensor shapes: an
            # embedding weight is (vocab_size, embedding_dim), and an LSTM's
            # weight_hh_l0 is (4 * hidden, hidden), so shape[1] is the hidden size.
            vocab_size, embedding_dim = state_dict['encoder_embedding.weight'].shape
            encoder_hidden = state_dict['encoder_lstm.weight_hh_l0'].shape[1]
            decoder_hidden = state_dict['decoder_lstm.weight_hh_l0'].shape[1]
            print(f"Model architecture: vocab={vocab_size}, emb={embedding_dim}, "
                  f"enc_h={encoder_hidden}, dec_h={decoder_hidden}")
            self.model = Seq2SeqModel(vocab_size, embedding_dim, encoder_hidden, decoder_hidden)
            self.model.load_state_dict(state_dict, strict=True)
            self.model.to(self.device)
            self.model.eval()
            self._build_basic_vocab()
            print("Seq2Seq Summarizer: Model loaded successfully")
        except Exception as e:
            raise RuntimeError(f"Error loading seq2seq model: {e}")

    def _build_basic_vocab(self):
        # Seed the tokenizer with common Arabic function words so decoding has
        # a usable vocabulary even when no training corpus is available.
        basic_arabic_words = [
            'في', 'من', 'إلى', 'على', 'هذا', 'هذه', 'التي', 'الذي', 'كان', 'كانت',
            'يكون', 'تكون', 'قال', 'قالت', 'يقول', 'تقول', 'بعد', 'قبل', 'أن', 'أنه',
            'ما', 'لا', 'نعم', 'كل', 'بعض', 'جميع', 'هناك', 'هنا', 'حيث', 'كيف',
            'متى', 'أين', 'لماذا', 'والذي', 'والتي', 'أيضا', 'كذلك', 'حول', 'خلال', 'عند'
        ]
        for word in basic_arabic_words:
            if len(self.tokenizer.word2idx) < self.tokenizer.vocab_size:
                idx = len(self.tokenizer.word2idx)
                self.tokenizer.word2idx[word] = idx
                self.tokenizer.idx2word[idx] = word
        self.tokenizer.vocab_built = True
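
    # Note: _build_basic_vocab seeds only a handful of function words, so the
    # tokenizer's indices are unlikely to match the vocabulary the checkpoint
    # was trained with. The methods below therefore treat neural generation as
    # best-effort and fall back to a simple extractive heuristic (keep the
    # first two unique sentences of the input) whenever generation fails or
    # degenerates.
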
    def _extractive_fallback(self, text: str) -> str:
        sentences = re.split(r'[.!؟\n]+', text)
        sentences = [s.strip() for s in sentences if s.strip() and len(s.strip()) > 10]
        unique_sentences = []
        for sent in sentences:
            if not any(sent == existing for existing in unique_sentences):
                unique_sentences.append(sent)
        if len(unique_sentences) >= 2:
            return '. '.join(unique_sentences[:2]) + '.'
        elif len(unique_sentences) == 1:
            return unique_sentences[0] + '.'
        return text[:150] + "..." if len(text) > 150 else text

    def _generate_summary(self, text: str, max_length: int = 50) -> str:
        try:
            src_tokens = self.tokenizer.encode(text, max_len=100, add_special=False)
            src_tensor = torch.tensor([src_tokens], dtype=torch.long, device=self.device)
            with torch.no_grad():
                output = self.model(src_tensor, max_len=max_length)
            if output.numel() > 0:
                predicted_ids = torch.argmax(output, dim=-1)
                predicted_ids = predicted_ids[0].cpu().numpy()
                summary = self.tokenizer.decode(predicted_ids, skip_special=True)
                # Accept the neural summary only if it is non-trivial and does
                # not contain the filler word 'متاح'.
                if summary.strip() and len(summary.strip()) > 5 and 'متاح' not in summary:
                    return summary.strip()
            return self._extractive_fallback(text)
        except Exception as e:
            print(f"Seq2Seq generation error: {e}")
            return self._extractive_fallback(text)

    def summarize(self, text: str, num_sentences: int = 3) -> Dict[str, Any]:
        print(f"Seq2Seq Summarizer: Processing text with {len(text)} characters")
        cleaned_text = preprocess_for_summarization(text)
        print(f"Seq2Seq Summarizer: After preprocessing: '{cleaned_text[:100]}...'")
        sentences = re.split(r'[.!؟\n]+', cleaned_text)
        sentences = [s.strip() for s in sentences if s.strip()]
        print(f"Seq2Seq Summarizer: Found {len(sentences)} sentences")
        original_sentence_count = len(sentences)
        # Budget roughly 15 generated tokens per requested sentence, capped at 50.
        target_length = min(num_sentences * 15, 50)
        generated_summary = self._generate_summary(cleaned_text, max_length=target_length)
        print(f"Seq2Seq Summarizer: Generated summary: '{generated_summary[:100]}...'")
        summary_sentences = re.split(r'[.!؟\n]+', generated_summary)
        summary_sentences = [s.strip() for s in summary_sentences if s.strip()]
        # Top up with leading source sentences if generation came up short.
        if len(summary_sentences) < num_sentences and len(sentences) > len(summary_sentences):
            remaining_needed = num_sentences - len(summary_sentences)
            summary_sentences.extend(sentences[:remaining_needed])
        summary_sentences = summary_sentences[:num_sentences]
        final_summary = ' '.join(summary_sentences)
        # Placeholder scores and indices keep the return shape compatible with
        # extractive summarizers that report per-sentence relevance.
        dummy_scores = [1.0] * len(sentences)
        selected_indices = list(range(min(len(sentences), len(summary_sentences))))
        return {
            "summary": final_summary,
            "original_sentence_count": original_sentence_count,
            "summary_sentence_count": len(summary_sentences),
            "sentences": sentences,
            "selected_indices": selected_indices,
            "sentence_scores": dummy_scores,
            "top_sentence_scores": [1.0] * len(summary_sentences),
            "generated_summary": generated_summary,
            "model_type": "seq2seq",
        }
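

# Minimal usage sketch (the checkpoint path is hypothetical; any safetensors
# file produced from Seq2SeqModel.state_dict() with matching keys will load):
if __name__ == "__main__":
    summarizer = Seq2SeqSummarizer("models/seq2seq_summarizer.safetensors")
    result = summarizer.summarize("ضع هنا النص العربي المراد تلخيصه.", num_sentences=2)
    print(result["summary"])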