|
import os
from enum import Enum
from typing import Optional, List, Dict, Any

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field

from summarizer import ArabicSummarizer
from preprocessor import ArabicPreprocessor
from model_manager import ModelManager
from examples import REQUEST_EXAMPLES, RESPONSE_EXAMPLES
from bert_summarizer import BERTExtractiveSummarizer
from seq2seq_summarizer import Seq2SeqSummarizer
|
|
|
|
|
# Preprocessing pipeline selector accepted by POST /preprocess
# (PreprocessRequest.task_type); its .value is forwarded to the preprocessor.
# The str mixin makes members compare equal to their plain string values.
class TaskType(str, Enum):
    CLASSIFICATION = "classification"
    SUMMARIZATION = "summarization"
|
|
|
|
|
|
|
# Classification backends a client may request via POST /classify.
# The chosen value is checked against AVAILABLE_MODELS and then passed to
# ModelManager.predict (see classify_text / _map_classification_model).
class ClassificationModelType(str, Enum):
    TRADITIONAL_SVM = "traditional_svm"
    MODERN_LSTM = "modern_lstm"
    MODERN_BERT = "modern_bert"
|
|
|
|
|
# Summarization backends a client may request via POST /summarize; each
# value is dispatched by SummarizerManager.get_summarizer.
class SummarizationModelType(str, Enum):
    TRADITIONAL_TFIDF = "traditional_tfidf"
    MODERN_SEQ2SEQ = "modern_seq2seq"
    MODERN_BERT = "modern_bert"
|
|
|
|
|
|
|
# Request body for POST /preprocess.
class PreprocessRequest(BaseModel):
    # Raw input text to run through the preprocessor.
    text: str
    # Pipeline selector; its .value is passed to
    # preprocessor.get_preprocessing_steps and echoed in the response.
    task_type: TaskType

    # Example payload shown in the generated OpenAPI docs.
    model_config = {
        "json_schema_extra": {"example": {"text": "هذا نص عربي للمعالجة", "task_type": "classification"}}
    }
|
|
|
|
|
# Request body for POST /classify.
class ClassificationRequest(BaseModel):
    # Arabic text to classify.
    text: str
    # Requested classification backend.
    model: ClassificationModelType

    # Example payload shown in the generated OpenAPI docs.
    model_config = {"json_schema_extra": {"example": {"text": "هذا نص عربي للتصنيف", "model": "traditional_svm"}}}
|
|
|
|
|
# Request body for POST /summarize.
class SummarizationRequest(BaseModel):
    # Arabic text to summarize.
    text: str
    # Number of sentences to keep in the summary (default 3). Constrained to
    # >= 1 so zero/negative counts are rejected with a 422 validation error
    # instead of reaching the summarizers, where a count like 0 can silently
    # select every sentence (e.g. a `scores[-n:]` style top-k).
    num_sentences: int = Field(3, ge=1)
    # Requested summarization backend.
    model: SummarizationModelType

    # Example payload shown in the generated OpenAPI docs.
    model_config = {"json_schema_extra": {"example": {"text": "هذا نص عربي طويل للتلخيص", "num_sentences": 3, "model": "traditional_tfidf"}}}
|
|
|
|
|
|
|
# Per-stage snapshots of the text produced by
# ArabicPreprocessor.get_preprocessing_steps and copied over by
# _create_preprocessing_steps. Stages missing from the preprocessor's output
# are left as None.
class PreprocessingSteps(BaseModel):
    # Input text as received.
    original: str
    # Intermediate string stages (None when not reported by the preprocessor).
    stripped_lowered: Optional[str] = None
    normalized: Optional[str] = None
    diacritics_removed: Optional[str] = None
    punctuation_removed: Optional[str] = None
    repeated_chars_reduced: Optional[str] = None
    whitespace_normalized: Optional[str] = None
    numbers_removed: Optional[str] = None
    # Token-list stages (None when not reported by the preprocessor).
    tokenized: Optional[List[str]] = None
    stopwords_removed: Optional[List[str]] = None
    stemmed: Optional[List[str]] = None
    # Final preprocessed text; _create_preprocessing_steps fills "" if absent.
    final: str
|
|
|
|
|
# Response body for POST /preprocess.
class PreprocessingResponse(BaseModel):
    # String value of the requested TaskType.
    task_type: str
    # Stage-by-stage breakdown of the preprocessing run.
    preprocessing_steps: PreprocessingSteps
|
|
|
|
|
# Response body for POST /classify. Apart from model_used, the fields are
# copied from the dict returned by ModelManager.predict.
class ClassificationResponse(BaseModel):
    # Predicted class as reported by the backend.
    prediction: str
    confidence: float
    # NOTE(review): presumably class label -> probability; confirm against ModelManager.
    probability_distribution: Dict[str, float]
    # Cleaned input text as reported by the backend.
    cleaned_text: str
    # The model name the client requested (req.model.value).
    model_used: str

    # Optional extras, populated only when the backend supplies them.
    prediction_index: Optional[int] = None
    prediction_metadata: Optional[Dict[str, Any]] = None
|
|
|
|
|
# Response body for POST /summarize. Apart from model_used, the fields are
# copied from the dict returned by SummarizerManager.summarize.
class SummarizationResponse(BaseModel):
    summary: str
    original_sentence_count: int
    summary_sentence_count: int
    # NOTE(review): presumably the input split into sentences — confirm with the summarizers.
    sentences: List[str]
    # Indices of the sentences chosen for the summary (presumably into `sentences`).
    selected_indices: List[int]
    # Per-sentence scores; SummarizerManager substitutes [] when the backend returns None.
    sentence_scores: List[float]
    # The model name the client requested (req.model.value).
    model_used: str

    # Optional; populated only when the backend supplies it.
    top_sentence_scores: Optional[List[float]] = None
|
|
|
|
|
# FastAPI application; title/description/version populate the OpenAPI docs.
app = FastAPI(
    title="Arabic Text Analysis API",
    description="API for Arabic text classification, summarization, and preprocessing with multiple model support",
    version="1.0.0",
)

# Shared service objects, constructed once at import time.
# Classification backend used by POST /classify, configured with "traditional_svm" as its default model.
model_manager = ModelManager(default_model="traditional_svm")
# NOTE(review): `summarizer` is not referenced anywhere in the visible code,
# and SummarizerManager loads the same joblib file again. It may be imported
# by other modules — confirm, and remove it if unused to avoid a double load.
summarizer = ArabicSummarizer("models/traditional_tfidf_vectorizer_summarization.joblib")
# Text preprocessor used by POST /preprocess.
preprocessor = ArabicPreprocessor()
|
|
|
|
|
|
|
class SummarizerManager:
    """Manages different types of Arabic text summarizers.

    The TF-IDF summarizer is loaded eagerly when the manager is constructed.
    The Seq2Seq and BERT summarizers are loaded lazily on first request and
    cached on the instance; if a lazy load fails the slot stays ``None``, so
    the next request retries the load.
    """

    def __init__(self):
        # TF-IDF summarizer, loaded up front. The path is relative to the
        # process working directory.
        self.traditional_tfidf = ArabicSummarizer("models/traditional_tfidf_vectorizer_summarization.joblib")

        # Heavier models are deferred until a request asks for them.
        self.bert_summarizer = None
        self.seq2seq_summarizer = None

    def get_summarizer(self, model_type: str):
        """Return the summarizer for ``model_type``, loading it on first use.

        Args:
            model_type: ``"traditional_tfidf"``, ``"modern_seq2seq"`` or
                ``"modern_bert"``.

        Returns:
            An object exposing ``summarize(text, num_sentences)``.

        Raises:
            ValueError: If ``model_type`` is unknown, or if a lazily loaded
                model fails to initialize. In the latter case the original
                exception is chained as ``__cause__``.
        """
        if model_type == "traditional_tfidf":
            return self.traditional_tfidf

        elif model_type == "modern_seq2seq":
            if self.seq2seq_summarizer is None:
                try:
                    print("Loading Seq2Seq summarizer...")
                    # Resolved relative to this file, not the working directory.
                    model_path = os.path.join(os.path.dirname(__file__), "models", "modern_seq2seq_summarizer.safetensors")
                    self.seq2seq_summarizer = Seq2SeqSummarizer(model_path)
                    print("Seq2Seq summarizer loaded successfully!")
                except Exception as e:
                    print(f"Failed to load Seq2Seq summarizer: {e}")
                    # Chain the cause so the real load error stays in tracebacks.
                    raise ValueError(f"Seq2Seq summarizer initialization failed: {e}") from e
            return self.seq2seq_summarizer

        elif model_type == "modern_bert":
            if self.bert_summarizer is None:
                try:
                    print("Loading BERT summarizer...")
                    self.bert_summarizer = BERTExtractiveSummarizer()
                    print("BERT summarizer loaded successfully!")
                except Exception as e:
                    print(f"Failed to load BERT summarizer: {e}")
                    raise ValueError(f"BERT summarizer initialization failed: {e}") from e
            return self.bert_summarizer

        else:
            raise ValueError(f"Unknown summarizer model: {model_type}")

    def summarize(self, text: str, num_sentences: int, model_type: str) -> Dict[str, Any]:
        """Summarize ``text`` using the model named by ``model_type``.

        Args:
            text: Input text.
            num_sentences: Requested summary length, forwarded to the backend.
            model_type: Model name accepted by :meth:`get_summarizer`.

        Returns:
            The backend's result dict. ``sentence_scores`` is replaced by
            ``[]`` when the backend omits it or returns ``None``.

        Raises:
            ValueError: For any BERT failure. It is wrapped with a hint about
                dependencies and network access, and the original error is
                chained as ``__cause__``.
            Exception: Errors from other models, including ValueErrors from
                :meth:`get_summarizer`, are re-raised unchanged.
        """
        try:
            print(f"SummarizerManager: Using model '{model_type}' for text with {len(text)} characters")
            summarizer_instance = self.get_summarizer(model_type)
            result = summarizer_instance.summarize(text, num_sentences)

            print(f"SummarizerManager: {model_type} selected indices: {result.get('selected_indices', [])}")
            print(f"SummarizerManager: {model_type} summary preview: '{result.get('summary', '')[:100]}...'")

            # SummarizationResponse requires a list here, so normalize a missing/None value.
            if result.get("sentence_scores") is None:
                result["sentence_scores"] = []

            return result
        except Exception as e:
            if model_type == "modern_bert":
                raise ValueError(f"BERT summarization failed: {str(e)}. This might be due to missing dependencies (torch, transformers) or network issues downloading the model.") from e
            else:
                raise
|
|
|
|
|
# Process-wide summarizer registry used by POST /summarize. Constructing it
# loads the TF-IDF summarizer immediately; other models load on first use.
summarizer_manager = SummarizerManager()
|
|
|
|
|
|
|
def check_model_availability():
    """Report which classification models can be served.

    The SVM and LSTM models are reported as available unconditionally. BERT
    is reported as available only if its classifier can be imported and
    constructed here.

    Returns:
        Dict mapping classification model name to an availability flag.
    """
    # Probe BERT by actually building it; any exception marks it unavailable.
    try:
        from modern_classifier import ModernClassifier

        ModernClassifier("bert", "models/modern_bert_classifier.safetensors")
    except Exception as e:
        print(f"BERT model not available: {e}")
        bert_ok = False
    else:
        bert_ok = True

    return {
        "traditional_svm": True,
        "modern_lstm": True,
        "modern_bert": bert_ok,
    }
|
|
|
|
|
|
|
# Availability snapshot taken once at import time; probing may construct the
# BERT classifier. _map_classification_model and /models/available read this
# dict, and nothing in this file refreshes it afterwards.
AVAILABLE_MODELS = check_model_availability()
|
|
|
|
|
def _map_classification_model(frontend_model: str) -> str:
    """Translate an API-facing classification model name to its backend name.

    Args:
        frontend_model: Model name as sent by the client.

    Returns:
        The backend model name, which is currently identical to the API name.

    Raises:
        ValueError: If the model is not flagged available in AVAILABLE_MODELS.
            The message contains "not available", which classify_text maps
            to HTTP 503, and lists the models that are available.
    """
    if AVAILABLE_MODELS.get(frontend_model, False):
        # API and backend names coincide one-to-one, so no translation is
        # needed. Add an alias table here if they ever diverge.
        return frontend_model

    usable = [name for name, ok in AVAILABLE_MODELS.items() if ok]
    raise ValueError(f"Model '{frontend_model}' is not available. Available models: {usable}")
|
|
|
|
|
def _create_preprocessing_steps(steps: Dict[str, Any]) -> PreprocessingSteps:
    """Build a PreprocessingSteps model from the preprocessor's raw step dict.

    ``original`` and ``final`` default to an empty string when missing. Every
    other stage is copied when present and left as ``None`` otherwise.
    """
    optional_stages = (
        "stripped_lowered",
        "normalized",
        "diacritics_removed",
        "punctuation_removed",
        "repeated_chars_reduced",
        "whitespace_normalized",
        "numbers_removed",
        "tokenized",
        "stopwords_removed",
        "stemmed",
    )
    stage_values = {stage: steps.get(stage) for stage in optional_stages}
    return PreprocessingSteps(
        original=steps.get("original", ""),
        final=steps.get("final", ""),
        **stage_values,
    )
|
|
|
|
|
|
|
@app.get("/")
def read_root() -> Dict[str, Any]:
    """API welcome message and endpoint documentation."""
    # The docstring is kept verbatim because FastAPI publishes it as the
    # endpoint description. The body returns a greeting plus pointers to the
    # documentation routes and the main POST endpoints.
    docs_links = {
        "interactive_docs": "/docs",
        "redoc": "/redoc",
        "openapi_schema": "/openapi.json",
    }
    endpoint_summaries = {
        "preprocess": "POST /preprocess - Preprocess text with detailed steps",
        "classify": "POST /classify - Classify Arabic text",
        "summarize": "POST /summarize - Summarize Arabic text",
    }
    return {
        "message": "Welcome to the Arabic Text Analysis API!",
        "documentation": docs_links,
        "endpoints": endpoint_summaries,
    }
|
|
|
|
|
@app.post("/preprocess", response_model=PreprocessingResponse)
def preprocess_text(req: PreprocessRequest) -> PreprocessingResponse:
    """Preprocess text with step-by-step breakdown."""
    # Docstring kept verbatim: FastAPI publishes it as the endpoint description.
    task = req.task_type.value
    try:
        raw_steps = preprocessor.get_preprocessing_steps(req.text, task)
        return PreprocessingResponse(
            task_type=task,
            preprocessing_steps=_create_preprocessing_steps(raw_steps),
        )
    except Exception as exc:
        # Any failure in the pipeline or in building the response becomes a 500.
        raise HTTPException(status_code=500, detail=f"Preprocessing failed: {str(exc)}")
|
|
|
|
|
@app.post("/classify", response_model=ClassificationResponse)
def classify_text(req: ClassificationRequest) -> ClassificationResponse:
    """Classify Arabic text."""
    # Docstring kept verbatim: FastAPI publishes it as the endpoint description.
    requested = req.model.value
    try:
        result = model_manager.predict(req.text, _map_classification_model(requested))
        return ClassificationResponse(
            prediction=result["prediction"],
            confidence=result["confidence"],
            probability_distribution=result["probability_distribution"],
            cleaned_text=result["cleaned_text"],
            model_used=requested,
            prediction_index=result.get("prediction_index"),
            prediction_metadata=result.get("prediction_metadata"),
        )
    except ValueError as exc:
        message = str(exc)
        # Plain ValueErrors are treated as bad input (400). The "not available"
        # message from _map_classification_model is a service condition (503).
        if "not available" not in message:
            raise HTTPException(status_code=400, detail=message)
        raise HTTPException(
            status_code=503,
            detail=f"Model unavailable: {message}. Check /models/available for current model status.",
        )
    except Exception as exc:
        message = str(exc)
        lowered = message.lower()
        # Heuristic: BERT errors mentioning connectivity or Hugging Face mean
        # the tokenizer/config could not be fetched.
        looks_network_related = any(hint in lowered for hint in ("connect", "internet", "huggingface"))
        if "BERT" in message and looks_network_related:
            raise HTTPException(
                status_code=503,
                detail=f"BERT model unavailable: The model requires internet connection to download tokenizer/config from Hugging Face, or the files need to be cached locally. Error: {message}",
            )
        if requested == "modern_bert" and "Error loading" in message:
            raise HTTPException(
                status_code=503,
                detail=f"BERT model loading failed: {message}. Please ensure the model files are properly configured and Hugging Face dependencies are available.",
            )
        raise HTTPException(status_code=500, detail=f"Classification failed: {message}")
|
|
|
|
|
@app.post("/summarize", response_model=SummarizationResponse)
def summarize_text(req: SummarizationRequest) -> SummarizationResponse:
    """Summarize Arabic text."""
    # Docstring kept verbatim: FastAPI publishes it as the endpoint description.
    requested = req.model.value
    required_keys = (
        "summary",
        "original_sentence_count",
        "summary_sentence_count",
        "sentences",
        "selected_indices",
        "sentence_scores",
    )
    try:
        outcome = summarizer_manager.summarize(req.text, req.num_sentences, requested)
        # Required keys are indexed in declaration order, so a missing key
        # surfaces as a KeyError and then a 500. top_sentence_scores is optional.
        copied = {key: outcome[key] for key in required_keys}
        return SummarizationResponse(
            **copied,
            model_used=requested,
            top_sentence_scores=outcome.get("top_sentence_scores"),
        )
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Summarization failed: {str(exc)}")
|
|
|
|
|
@app.get("/models/available")
def get_available_models() -> Dict[str, Any]:
    """Get information about which models are currently available."""
    # Docstring kept verbatim: FastAPI publishes it as the endpoint description.
    # Classification availability comes from the import-time AVAILABLE_MODELS
    # snapshot; the summarization entries are fixed values.
    usable = [name for name, ok in AVAILABLE_MODELS.items() if ok]
    unusable = [name for name, ok in AVAILABLE_MODELS.items() if not ok]
    bert_ready = AVAILABLE_MODELS.get("modern_bert", False)

    classification_models = {
        "traditional_svm": {
            "available": AVAILABLE_MODELS.get("traditional_svm", False),
            "description": "Traditional SVM classifier with TF-IDF vectorization",
        },
        "modern_lstm": {
            "available": AVAILABLE_MODELS.get("modern_lstm", False),
            "description": "Modern LSTM-based neural network classifier",
        },
        "modern_bert": {
            "available": bert_ready,
            "description": "Modern BERT-based transformer classifier",
            "note": None if bert_ready else "Requires internet connection or cached Hugging Face models",
        },
    }

    # NOTE(review): the modern_seq2seq text says it falls back to TF-IDF, but
    # SummarizerManager loads a dedicated Seq2SeqSummarizer. Confirm the real
    # behavior and update the description.
    summarization_models = {
        "traditional_tfidf": {
            "available": True,
            "description": "Traditional TF-IDF based extractive summarization",
        },
        "modern_seq2seq": {
            "available": True,
            "description": "Modern sequence-to-sequence summarization (currently uses TF-IDF fallback)",
            "note": "Implementation in progress - currently falls back to TF-IDF",
        },
        "modern_bert": {
            "available": True,
            "description": "Modern BERT-based extractive summarization using asafaya/bert-base-arabic",
            "note": "Requires torch and transformers dependencies. Model will be downloaded on first use.",
        },
    }

    # NOTE(review): "total_available" counts classification models only, so
    # it always equals "total_classification_models".
    return {
        "classification_models": classification_models,
        "summarization_models": summarization_models,
        "status": {
            "total_classification_models": len(usable),
            "total_available": len(usable),
            "unavailable_models": unusable,
        },
    }
|
|