moabos
feat: replace current tensorflow seq2seq model with improved pytorch implementation
07edbf0
from typing import Optional, List, Dict, Any
import os
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from enum import Enum
from summarizer import ArabicSummarizer
from preprocessor import ArabicPreprocessor
from model_manager import ModelManager
from examples import REQUEST_EXAMPLES, RESPONSE_EXAMPLES
from bert_summarizer import BERTExtractiveSummarizer
from seq2seq_summarizer import Seq2SeqSummarizer
class TaskType(str, Enum):
    # Task selector for POST /preprocess. The member's value is forwarded to
    # ArabicPreprocessor.get_preprocessing_steps as its second argument and
    # echoed back as ``task_type`` in the response.
    CLASSIFICATION = "classification"
    SUMMARIZATION = "summarization"
# New enums for frontend compatibility
class ClassificationModelType(str, Enum):
    # Classifier names accepted by POST /classify. Each value is checked
    # against AVAILABLE_MODELS before prediction and echoed back unchanged as
    # ``model_used`` in the response.
    TRADITIONAL_SVM = "traditional_svm"
    MODERN_LSTM = "modern_lstm"
    MODERN_BERT = "modern_bert"
class SummarizationModelType(str, Enum):
    # Summarizer names accepted by POST /summarize. Each value is dispatched by
    # SummarizerManager.get_summarizer and echoed back as ``model_used``.
    # Note that "modern_bert" is shared with ClassificationModelType but refers
    # to a different (summarization) model here.
    TRADITIONAL_TFIDF = "traditional_tfidf"
    MODERN_SEQ2SEQ = "modern_seq2seq"
    MODERN_BERT = "modern_bert"
# Request models
class PreprocessRequest(BaseModel):
    # Request body for POST /preprocess.
    text: str  # raw Arabic input text
    task_type: TaskType  # which task's preprocessing steps to report
    # Example payload surfaced in the generated JSON schema / OpenAPI docs.
    # The example text reads "This is Arabic text for processing".
    model_config = {
        "json_schema_extra": {"example": {"text": "هذا نص عربي للمعالجة", "task_type": "classification"}}
    }
class ClassificationRequest(BaseModel):
    # Request body for POST /classify.
    text: str  # raw Arabic input text
    model: ClassificationModelType  # frontend name of the classifier to use
    # Example payload for the OpenAPI docs; the example text reads
    # "This is Arabic text for classification".
    model_config = {"json_schema_extra": {"example": {"text": "هذا نص عربي للتصنيف", "model": "traditional_svm"}}}
class SummarizationRequest(BaseModel):
    # Request body for POST /summarize.
    text: str  # raw Arabic input text
    # Passed unchanged to the selected summarizer's summarize() call.
    # NOTE(review): no lower bound is enforced, so 0 or negative values reach
    # the summarizer as-is; consider Field(3, ge=1) if that is not intended.
    num_sentences: int = 3
    model: SummarizationModelType  # frontend name of the summarizer to use
    # Example payload for the OpenAPI docs; the example text reads
    # "This is a long Arabic text for summarization".
    model_config = {"json_schema_extra": {"example": {"text": "هذا نص عربي طويل للتلخيص", "num_sentences": 3, "model": "traditional_tfidf"}}}
# Response models
class PreprocessingSteps(BaseModel):
    # Intermediate result after each preprocessing stage. Populated by
    # _create_preprocessing_steps from the dict returned by
    # ArabicPreprocessor.get_preprocessing_steps; any stage whose key is
    # missing from that dict is left as None (presumably stages that are not
    # applied for the requested task type; confirm against ArabicPreprocessor).
    original: str
    stripped_lowered: Optional[str] = None
    normalized: Optional[str] = None
    diacritics_removed: Optional[str] = None
    punctuation_removed: Optional[str] = None
    repeated_chars_reduced: Optional[str] = None
    whitespace_normalized: Optional[str] = None
    numbers_removed: Optional[str] = None
    # From tokenization onward the stages hold token lists, not strings.
    tokenized: Optional[List[str]] = None
    stopwords_removed: Optional[List[str]] = None
    stemmed: Optional[List[str]] = None
    final: str
class PreprocessingResponse(BaseModel):
    # Response body for POST /preprocess.
    task_type: str  # echo of the request's TaskType value
    preprocessing_steps: PreprocessingSteps
class ClassificationResponse(BaseModel):
    # Response body for POST /classify. Every field except ``model_used`` is
    # copied from the dict returned by ModelManager.predict.
    prediction: str
    confidence: float
    # presumably maps each class label to its predicted probability; confirm with ModelManager
    probability_distribution: Dict[str, float]
    cleaned_text: str
    # Echo of the requested frontend model name, not the backend name.
    model_used: str
    # Optional fields for extra info
    prediction_index: Optional[int] = None
    prediction_metadata: Optional[Dict[str, Any]] = None
class SummarizationResponse(BaseModel):
    # Response body for POST /summarize. Every field except ``model_used`` is
    # taken from the result dict of the selected summarizer.
    summary: str
    original_sentence_count: int
    summary_sentence_count: int
    sentences: List[str]
    # presumably indices into ``sentences`` of the sentences kept in the summary; confirm
    selected_indices: List[int]
    # Required list: SummarizerManager.summarize replaces a None value with [].
    sentence_scores: List[float]
    # Echo of the requested frontend model name.
    model_used: str
    # Optional fields for extra info
    top_sentence_scores: Optional[List[float]] = None
# FastAPI application; title/description/version populate the OpenAPI metadata.
app = FastAPI(
    title="Arabic Text Analysis API",
    description="API for Arabic text classification, summarization, and preprocessing with multiple model support",
    version="1.0.0",
)
# Classification backend used by POST /classify.
model_manager = ModelManager(default_model="traditional_svm")
# NOTE(review): this module-level TF-IDF summarizer is not referenced by any
# endpoint in this file; SummarizerManager constructs its own instance from the
# same joblib path. Confirm nothing else imports it before removing it.
summarizer = ArabicSummarizer("models/traditional_tfidf_vectorizer_summarization.joblib")
# Shared preprocessor used by POST /preprocess.
preprocessor = ArabicPreprocessor()
# Summarizer manager for model dispatch
class SummarizerManager:
    """Manage and dispatch to the available Arabic text summarizers.

    The TF-IDF summarizer is constructed eagerly in ``__init__``. The Seq2Seq
    and BERT summarizers are constructed lazily on their first request, so
    application startup does not wait on them.
    """

    def __init__(self):
        # Eagerly construct the traditional TF-IDF summarizer. The path is
        # relative, so it resolves against the process working directory.
        self.traditional_tfidf = ArabicSummarizer("models/traditional_tfidf_vectorizer_summarization.joblib")
        # Heavier summarizers are created on first use (see get_summarizer).
        self.bert_summarizer = None
        self.seq2seq_summarizer = None

    @staticmethod
    def _load(label: str, factory):
        """Build a summarizer by calling ``factory()``, printing progress as ``label``.

        Raises:
            ValueError: if ``factory`` raises anything; the original error is
                chained as ``__cause__`` so its traceback is not lost.
        """
        try:
            print(f"Loading {label} summarizer...")
            instance = factory()
            print(f"{label} summarizer loaded successfully!")
        except Exception as e:
            print(f"Failed to load {label} summarizer: {e}")
            raise ValueError(f"{label} summarizer initialization failed: {e}") from e
        return instance

    @staticmethod
    def _build_seq2seq():
        """Construct the Seq2Seq summarizer from the ``models/`` directory next to this module."""
        model_path = os.path.join(os.path.dirname(__file__), "models", "modern_seq2seq_summarizer.safetensors")
        return Seq2SeqSummarizer(model_path)

    @staticmethod
    def _build_bert():
        """Construct the BERT extractive summarizer."""
        return BERTExtractiveSummarizer()

    def get_summarizer(self, model_type: str):
        """Return the summarizer for ``model_type``, constructing it on first use.

        Raises:
            ValueError: if ``model_type`` is unknown, or a lazily-constructed
                summarizer fails to initialize (the attribute stays ``None`` so
                the next request retries).
        """
        if model_type == "traditional_tfidf":
            return self.traditional_tfidf
        if model_type == "modern_seq2seq":
            if self.seq2seq_summarizer is None:
                self.seq2seq_summarizer = self._load("Seq2Seq", self._build_seq2seq)
            return self.seq2seq_summarizer
        if model_type == "modern_bert":
            if self.bert_summarizer is None:
                self.bert_summarizer = self._load("BERT", self._build_bert)
            return self.bert_summarizer
        raise ValueError(f"Unknown summarizer model: {model_type}")

    def summarize(self, text: str, num_sentences: int, model_type: str) -> Dict[str, Any]:
        """Summarize ``text`` with the summarizer selected by ``model_type``.

        ``num_sentences`` is passed straight through to the summarizer. The
        summarizer's result dict is returned, with ``sentence_scores``
        normalized to ``[]`` when it is ``None``.

        Raises:
            ValueError: for unknown models or initialization failures; for
                ``modern_bert`` every failure is re-raised as a ValueError with
                a dependency/network hint, chained to the original error.
            Exception: errors from non-BERT summarizers propagate unchanged.
        """
        try:
            print(f"SummarizerManager: Using model '{model_type}' for text with {len(text)} characters")
            summarizer_instance = self.get_summarizer(model_type)
            result = summarizer_instance.summarize(text, num_sentences)
            # Debugging output
            print(f"SummarizerManager: {model_type} selected indices: {result.get('selected_indices', [])}")
            print(f"SummarizerManager: {model_type} summary preview: '{result.get('summary', '')[:100]}...'")
            # SummarizationResponse.sentence_scores is a required list.
            if result.get("sentence_scores") is None:
                result["sentence_scores"] = []
            return result
        except Exception as e:
            # If BERT fails, provide a helpful error message.
            if model_type == "modern_bert":
                raise ValueError(f"BERT summarization failed: {str(e)}. This might be due to missing dependencies (torch, transformers) or network issues downloading the model.") from e
            raise
# Module-level manager used by POST /summarize. Instantiation constructs the
# TF-IDF summarizer immediately; Seq2Seq and BERT are deferred to first use.
summarizer_manager = SummarizerManager()
# Check which models are actually available
def check_model_availability():
    """Probe which classification models can actually be used.

    The SVM and LSTM classifiers are assumed available without a check. The
    BERT classifier is probed by importing ``modern_classifier`` and
    constructing a ``ModernClassifier`` from its weights file; any exception
    (including the import failing) marks it unavailable.

    Returns:
        Dict mapping each classification model name to an availability flag.
    """
    available_models = {
        "traditional_svm": True,  # Always available
        "modern_lstm": True,  # Always available
        "modern_bert": False,  # Set to True below only if the probe succeeds
    }
    try:
        from modern_classifier import ModernClassifier
        # Construct and immediately discard an instance purely as a smoke test.
        # NOTE(review): this presumably loads the full BERT model at startup just
        # to verify it; consider a cheaper check if startup time matters.
        ModernClassifier("bert", "models/modern_bert_classifier.safetensors")
    except Exception as e:
        print(f"BERT model not available: {e}")
    else:
        available_models["modern_bert"] = True
    return available_models
# Check model availability once, at import time. The result is read by
# _map_classification_model and GET /models/available and is never refreshed
# in this file, so a model that becomes usable later is only picked up on restart.
AVAILABLE_MODELS = check_model_availability()
def _map_classification_model(frontend_model: str) -> str:
    """Resolve a frontend classification model name to its backend name.

    Only models flagged as available in AVAILABLE_MODELS are accepted. The
    backend names currently coincide with the frontend ones; a name missing
    from the mapping is returned unchanged.

    Raises:
        ValueError: if ``frontend_model`` is not marked available.
    """
    if AVAILABLE_MODELS.get(frontend_model, False):
        frontend_to_backend = dict(
            traditional_svm="traditional_svm",
            modern_lstm="modern_lstm",
            modern_bert="modern_bert",
        )
        return frontend_to_backend.get(frontend_model, frontend_model)
    # The "not available" wording is matched by classify_text to return a 503.
    usable_models = [name for name, is_up in AVAILABLE_MODELS.items() if is_up]
    raise ValueError(f"Model '{frontend_model}' is not available. Available models: {usable_models}")
def _create_preprocessing_steps(steps: Dict[str, Any]) -> PreprocessingSteps:
    """Convert a raw preprocessing-steps dict into a PreprocessingSteps model.

    ``original`` and ``final`` fall back to ``""`` when missing; every other
    stage is copied through and becomes ``None`` when its key is absent.
    """
    optional_step_names = (
        "stripped_lowered",
        "normalized",
        "diacritics_removed",
        "punctuation_removed",
        "repeated_chars_reduced",
        "whitespace_normalized",
        "numbers_removed",
        "tokenized",
        "stopwords_removed",
        "stemmed",
    )
    payload = {name: steps.get(name) for name in optional_step_names}
    payload["original"] = steps.get("original", "")
    payload["final"] = steps.get("final", "")
    return PreprocessingSteps(**payload)
# Main endpoints
@app.get("/")
def read_root() -> Dict[str, Any]:
    """API welcome message and endpoint documentation."""
    # The endpoint listing now includes GET /models/available, which the
    # /classify error messages already direct clients to.
    return {
        "message": "Welcome to the Arabic Text Analysis API!",
        "documentation": {
            "interactive_docs": "/docs",
            "redoc": "/redoc",
            "openapi_schema": "/openapi.json",
        },
        "endpoints": {
            "preprocess": "POST /preprocess - Preprocess text with detailed steps",
            "classify": "POST /classify - Classify Arabic text",
            "summarize": "POST /summarize - Summarize Arabic text",
            "models_available": "GET /models/available - List which models are currently available",
        },
    }
@app.post("/preprocess", response_model=PreprocessingResponse)
def preprocess_text(req: PreprocessRequest) -> PreprocessingResponse:
    """Preprocess text with step-by-step breakdown."""
    # The task type is an already-validated enum, so reading its value cannot fail.
    task_type = req.task_type.value
    try:
        steps = preprocessor.get_preprocessing_steps(req.text, task_type)
        return PreprocessingResponse(
            task_type=task_type,
            preprocessing_steps=_create_preprocessing_steps(steps),
        )
    except Exception as e:
        # Any failure is reported as a 500; chain the original exception so
        # its cause stays attached to the HTTPException.
        raise HTTPException(status_code=500, detail=f"Preprocessing failed: {str(e)}") from e
@app.post("/classify", response_model=ClassificationResponse)
def classify_text(req: ClassificationRequest) -> ClassificationResponse:
    """Classify Arabic text."""
    # Error mapping for model lookup and prediction:
    #   - model not marked available at startup       -> 503
    #   - any other ValueError                        -> 400
    #   - BERT connectivity / loading failures        -> 503 with guidance
    #   - anything else                               -> 500
    try:
        backend_model = _map_classification_model(req.model.value)
        result = model_manager.predict(req.text, backend_model)
    except ValueError as e:
        # Handle model availability errors
        if "not available" in str(e):
            raise HTTPException(
                status_code=503,
                detail=f"Model unavailable: {str(e)}. Check /models/available for current model status."
            ) from e
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        error_msg = str(e)
        lowered = error_msg.lower()
        network_related = "connect" in lowered or "internet" in lowered or "huggingface" in lowered
        # Provide more helpful error messages for common issues
        if "BERT" in error_msg and network_related:
            raise HTTPException(
                status_code=503,
                detail=f"BERT model unavailable: The model requires internet connection to download tokenizer/config from Hugging Face, or the files need to be cached locally. Error: {error_msg}"
            ) from e
        if req.model == ClassificationModelType.MODERN_BERT and "Error loading" in error_msg:
            raise HTTPException(
                status_code=503,
                detail=f"BERT model loading failed: {error_msg}. Please ensure the model files are properly configured and Hugging Face dependencies are available."
            ) from e
        raise HTTPException(status_code=500, detail=f"Classification failed: {error_msg}") from e

    # Build the response outside the handlers above: pydantic's ValidationError
    # subclasses ValueError, so a malformed prediction result would otherwise
    # be misreported to the client as a 400 instead of a server error.
    try:
        return ClassificationResponse(
            prediction=result["prediction"],
            confidence=result["confidence"],
            probability_distribution=result["probability_distribution"],
            cleaned_text=result["cleaned_text"],
            model_used=req.model.value,  # Echo back the frontend model name
            prediction_index=result.get("prediction_index"),
            prediction_metadata=result.get("prediction_metadata")
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Classification failed: {str(e)}") from e
@app.post("/summarize", response_model=SummarizationResponse)
def summarize_text(req: SummarizationRequest) -> SummarizationResponse:
    """Summarize Arabic text."""
    # Every failure (lazy-load failure, summarizer error, malformed result) is
    # reported as a 500; the original exception is chained as __cause__.
    try:
        result = summarizer_manager.summarize(req.text, req.num_sentences, req.model.value)
        return SummarizationResponse(
            summary=result["summary"],
            original_sentence_count=result["original_sentence_count"],
            summary_sentence_count=result["summary_sentence_count"],
            sentences=result["sentences"],
            selected_indices=result["selected_indices"],
            sentence_scores=result["sentence_scores"],
            model_used=req.model.value,  # Echo back the frontend model name
            top_sentence_scores=result.get("top_sentence_scores")
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Summarization failed: {str(e)}") from e
@app.get("/models/available")
def get_available_models() -> Dict[str, Any]:
    """Get information about which models are currently available."""
    # Classification flags come from the one-time startup probe stored in
    # AVAILABLE_MODELS. Summarization models are reported as available
    # unconditionally: Seq2Seq and BERT are only constructed on their first
    # request, so a load failure surfaces as an error from POST /summarize.
    bert_available = AVAILABLE_MODELS.get("modern_bert", False)
    available_names = [k for k, v in AVAILABLE_MODELS.items() if v]
    return {
        "classification_models": {
            "traditional_svm": {
                "available": AVAILABLE_MODELS.get("traditional_svm", False),
                "description": "Traditional SVM classifier with TF-IDF vectorization"
            },
            "modern_lstm": {
                "available": AVAILABLE_MODELS.get("modern_lstm", False),
                "description": "Modern LSTM-based neural network classifier"
            },
            "modern_bert": {
                "available": bert_available,
                "description": "Modern BERT-based transformer classifier",
                "note": None if bert_available else "Requires internet connection or cached Hugging Face models"
            }
        },
        "summarization_models": {
            "traditional_tfidf": {
                "available": True,
                "description": "Traditional TF-IDF based extractive summarization"
            },
            "modern_seq2seq": {
                "available": True,
                # SummarizerManager constructs Seq2SeqSummarizer for this model
                # and raises on load failure; it does not fall back to TF-IDF,
                # so the previous "TF-IDF fallback" wording was stale.
                "description": "Modern sequence-to-sequence summarization model",
                "note": "Model weights are loaded from models/modern_seq2seq_summarizer.safetensors on first use."
            },
            "modern_bert": {
                "available": True,
                "description": "Modern BERT-based extractive summarization using asafaya/bert-base-arabic",
                "note": "Requires torch and transformers dependencies. Model will be downloaded on first use."
            }
        },
        "status": {
            # NOTE(review): both counts below are the number of *available*
            # classification models; one was probably meant to be the total
            # (len(AVAILABLE_MODELS)). Left unchanged for API stability.
            "total_classification_models": len(available_names),
            "total_available": len(available_names),
            "unavailable_models": [k for k, v in AVAILABLE_MODELS.items() if not v]
        }
    }