#1 by Fkhrayef - opened
- __pycache__/app.cpython-312.pyc +0 -0
- __pycache__/bert_summarizer.cpython-312.pyc +0 -0
- __pycache__/debug_test.cpython-312.pyc +0 -0
- __pycache__/examples.cpython-312.pyc +0 -0
- __pycache__/model_manager.cpython-312.pyc +0 -0
- __pycache__/modern_classifier.cpython-312.pyc +0 -0
- __pycache__/preprocessor.cpython-312.pyc +0 -0
- __pycache__/seq2seq_summarizer.cpython-312.pyc +0 -0
- __pycache__/summarizer.cpython-312.pyc +0 -0
- __pycache__/traditional_classifier.cpython-312.pyc +0 -0
- app.py +272 -258
- bert_summarizer.py +114 -0
- examples.py +2 -2
- model_manager.py +4 -4
- models/Seq2seq/seq2seq_config.json +1 -0
- models/Seq2seq/seq2seq_model.h5 +3 -0
- models/Seq2seq/src_tokenizer.pkl +3 -0
- models/Seq2seq/tgt_tokenizer.pkl +3 -0
- modern_bert_classifier.safetensors → models/modern_bert_classifier.safetensors +0 -0
- modern_lstm_classifier.pth → models/modern_lstm_classifier.pth +0 -0
- traditional_svm_classifier.joblib → models/traditional_svm_classifier.joblib +0 -0
- traditional_tfidf_vectorizer_classifier.joblib → models/traditional_tfidf_vectorizer_classifier.joblib +0 -0
- traditional_tfidf_vectorizer_summarization.joblib → models/traditional_tfidf_vectorizer_summarization.joblib +0 -0
- modern_classifier.py +76 -7
- summarizer.py +1 -1
- traditional_classifier.py +2 -2
__pycache__/app.cpython-312.pyc
ADDED
Binary file (17 kB)

__pycache__/bert_summarizer.cpython-312.pyc
ADDED
Binary file (5.94 kB)

__pycache__/debug_test.cpython-312.pyc
ADDED
Binary file (3.86 kB)

__pycache__/examples.cpython-312.pyc
ADDED
Binary file (7.97 kB)

__pycache__/model_manager.cpython-312.pyc
ADDED
Binary file (7.56 kB)

__pycache__/modern_classifier.cpython-312.pyc
ADDED
Binary file (17.3 kB)

__pycache__/preprocessor.cpython-312.pyc
ADDED
Binary file (8.7 kB)

__pycache__/seq2seq_summarizer.cpython-312.pyc
ADDED
Binary file (9.74 kB)

__pycache__/summarizer.cpython-312.pyc
ADDED
Binary file (4.7 kB)

__pycache__/traditional_classifier.cpython-312.pyc
ADDED
Binary file (8.89 kB)
app.py
CHANGED
@@ -7,6 +7,7 @@ from summarizer import ArabicSummarizer
 from preprocessor import ArabicPreprocessor
 from model_manager import ModelManager
 from examples import REQUEST_EXAMPLES, RESPONSE_EXAMPLES
+from bert_summarizer import BERTExtractiveSummarizer


 class TaskType(str, Enum):
@@ -14,80 +15,74 @@ class TaskType(str, Enum):
     SUMMARIZATION = "summarization"


-[…]
+# New enums for frontend compatibility
+class ClassificationModelType(str, Enum):
     TRADITIONAL_SVM = "traditional_svm"
-    MODERN_BERT = "modern_bert"
     MODERN_LSTM = "modern_lstm"
+    MODERN_BERT = "modern_bert"


-[…]
-)
-
-model_manager = ModelManager(default_model="traditional_svm")
-summarizer = ArabicSummarizer("traditional_tfidf_vectorizer_summarization.joblib")
-preprocessor = ArabicPreprocessor()
+class SummarizationModelType(str, Enum):
+    TRADITIONAL_TFIDF = "traditional_tfidf"
+    MODERN_SEQ2SEQ = "modern_seq2seq"
+    MODERN_BERT = "modern_bert"


-[…]
+# Request models
+class PreprocessRequest(BaseModel):
     text: str
-[…]
+    task_type: TaskType

-    model_config = {
-[…]
+    model_config = {
+        "json_schema_extra": {"example": {"text": "هذا نص عربي للمعالجة", "task_type": "classification"}}
+    }


-class […]
+class ClassificationRequest(BaseModel):
     text: str
-    model: Optional[ModelType] = None
+    model: ClassificationModelType

-    model_config = {
-        "json_schema_extra": {"example": REQUEST_EXAMPLES["text_input_with_sentences"]}
-    }
+    model_config = {"json_schema_extra": {"example": {"text": "هذا نص عربي للتصنيف", "model": "traditional_svm"}}}


-class […]
-
-    model_config = {
-        "json_schema_extra": {"example": REQUEST_EXAMPLES["batch_text_input"]}
-    }
+class SummarizationRequest(BaseModel):
+    text: str
+    num_sentences: int = 3
+    model: SummarizationModelType
+
+    model_config = {"json_schema_extra": {"example": {"text": "هذا نص عربي طويل للتلخيص", "num_sentences": 3, "model": "traditional_tfidf"}}}


+# Response models
+class PreprocessingSteps(BaseModel):
+    original: str
+    stripped_lowered: Optional[str] = None
+    normalized: Optional[str] = None
+    diacritics_removed: Optional[str] = None
+    punctuation_removed: Optional[str] = None
+    repeated_chars_reduced: Optional[str] = None
+    whitespace_normalized: Optional[str] = None
+    numbers_removed: Optional[str] = None
+    tokenized: Optional[List[str]] = None
+    stopwords_removed: Optional[List[str]] = None
+    stemmed: Optional[List[str]] = None
+    final: str
+
+
+class PreprocessingResponse(BaseModel):
+    task_type: str
+    preprocessing_steps: PreprocessingSteps


 class ClassificationResponse(BaseModel):
     prediction: str
-    prediction_index: int
     confidence: float
     probability_distribution: Dict[str, float]
     cleaned_text: str
     model_used: str
-
-    model_config = {
-        "protected_namespaces": (),
-        "json_schema_extra": {
-            "example": RESPONSE_EXAMPLES["classification"],
-            "schema_extra": {
-                "properties": {
-                    "prediction_index": {
-                        "description": "Numerical index of the predicted class (0=culture, 1=economy, 2=international, 3=local, 4=religion, 5=sports)"
-                    }
-                }
-            },
-        },
-    }
+    # Optional fields for extra info
+    prediction_index: Optional[int] = None
+    prediction_metadata: Optional[Dict[str, Any]] = None


 class SummarizationResponse(BaseModel):
@@ -96,89 +91,142 @@ class SummarizationResponse(BaseModel):
     summary_sentence_count: int
     sentences: List[str]
     selected_indices: List[int]
-    sentence_scores: […]
-    top_sentence_scores: Optional[List[float]]
+    sentence_scores: List[float]
     model_used: str
+    # Optional fields for extra info
+    top_sentence_scores: Optional[List[float]] = None

-    model_config = {
-        "json_schema_extra": {"example": RESPONSE_EXAMPLES["summarization"]}
-    }
-
-
-class TextAnalysisResponse(BaseModel):
-    text: str
-    analysis: Dict[str, Any]
-
-    model_config = {
-        "json_schema_extra": {"example": RESPONSE_EXAMPLES["text_analysis"]}
-    }
-
-
-class BatchClassificationResponse(BaseModel):
-    results: List[ClassificationResponse]
-    total_texts: int
-    model_used: str
-
-    model_config = {
-        "protected_namespaces": (),
-        "json_schema_extra": {"example": RESPONSE_EXAMPLES["batch_classification"]},
-    }
-
-
-class […]
-    model_config = {
-        "json_schema_extra": {"example": RESPONSE_EXAMPLES["sentence_analysis"]}
-    }
-
-
-class CompleteAnalysisResponse(BaseModel):
-    original_text: str
-    text_analysis: Dict[str, Any]
-    classification: ClassificationResponse
-    summarization: SummarizationResponse
-
-[…]
-
-class AvailableModelsResponse(BaseModel):
-    models: Dict[str, Any]
-    current_model: str


+app = FastAPI(
+    title="Arabic Text Analysis API",
+    description="API for Arabic text classification, summarization, and preprocessing with multiple model support",
+    version="1.0.0",
+)
+
+model_manager = ModelManager(default_model="traditional_svm")
+summarizer = ArabicSummarizer("models/traditional_tfidf_vectorizer_summarization.joblib")
+preprocessor = ArabicPreprocessor()
+
+
+# Summarizer manager for model dispatch
+class SummarizerManager:
+    """Manages different types of Arabic text summarizers."""
+
+    def __init__(self):
+        # Initialize the traditional TF-IDF summarizer
+        self.traditional_tfidf = ArabicSummarizer("models/traditional_tfidf_vectorizer_summarization.joblib")
+
+        # Initialize BERT summarizer (lazy loading to avoid startup delays)
+        self.bert_summarizer = None
+
+    def get_summarizer(self, model_type: str):
+        """Get summarizer based on model type."""
+        if model_type == "traditional_tfidf":
+            return self.traditional_tfidf
+        elif model_type == "modern_seq2seq":
+            # TODO: Implement seq2seq summarizer
+            # For now, fallback to TF-IDF
+            return self.traditional_tfidf
+        elif model_type == "modern_bert":
+            # Initialize BERT summarizer on first use
+            if self.bert_summarizer is None:
+                try:
+                    print("Loading BERT summarizer...")
+                    self.bert_summarizer = BERTExtractiveSummarizer()
+                    print("BERT summarizer loaded successfully!")
+                except Exception as e:
+                    print(f"Failed to load BERT summarizer: {e}")
+                    raise ValueError(f"BERT summarizer initialization failed: {e}")
+            return self.bert_summarizer
+        else:
+            raise ValueError(f"Unknown summarizer model: {model_type}")
+
+    def summarize(self, text: str, num_sentences: int, model_type: str) -> Dict[str, Any]:
+        """Summarize text using the specified model."""
+        try:
+            print(f"SummarizerManager: Using model '{model_type}' for text with {len(text)} characters")
+            summarizer_instance = self.get_summarizer(model_type)
+            result = summarizer_instance.summarize(text, num_sentences)
+
+            # Add debugging info
+            print(f"SummarizerManager: {model_type} selected indices: {result.get('selected_indices', [])}")
+            print(f"SummarizerManager: {model_type} summary preview: '{result.get('summary', '')[:100]}...'")
+
+            # Ensure sentence_scores is always a list (not None)
+            if result.get("sentence_scores") is None:
+                result["sentence_scores"] = []
+
+            return result
+        except Exception as e:
+            # If BERT fails, provide helpful error message
+            if model_type == "modern_bert":
+                raise ValueError(f"BERT summarization failed: {str(e)}. This might be due to missing dependencies (torch, transformers) or network issues downloading the model.")
+            else:
+                raise
+
+
+summarizer_manager = SummarizerManager()
+
+
+# Check which models are actually available
+def check_model_availability():
+    """Check which models are actually available and working."""
+    available_models = {
+        "traditional_svm": True,  # Always available
+        "modern_lstm": True,  # Always available
+        "modern_bert": False  # Will be checked
+    }
+
+    # Test BERT model availability
+    try:
+        from modern_classifier import ModernClassifier
+        # Try to create a BERT classifier instance
+        bert_classifier = ModernClassifier("bert", "models/modern_bert_classifier.safetensors")
+        available_models["modern_bert"] = True
+    except Exception as e:
+        print(f"BERT model not available: {e}")
+        available_models["modern_bert"] = False
+
+    return available_models
+
+
+# Check model availability at startup
+AVAILABLE_MODELS = check_model_availability()
+
+
+def _map_classification_model(frontend_model: str) -> str:
+    """Map frontend model names to backend model names."""
+    # Check if the requested model is available
+    if not AVAILABLE_MODELS.get(frontend_model, False):
+        raise ValueError(f"Model '{frontend_model}' is not available. Available models: {[k for k, v in AVAILABLE_MODELS.items() if v]}")
+
+    mapping = {
+        "traditional_svm": "traditional_svm",
+        "modern_lstm": "modern_lstm",
+        "modern_bert": "modern_bert"
+    }
+    return mapping.get(frontend_model, frontend_model)
+
+
+def _create_preprocessing_steps(steps: Dict[str, Any]) -> PreprocessingSteps:
+    """Create preprocessing steps response with only the fields that exist."""
+    return PreprocessingSteps(
+        original=steps.get("original", ""),
+        stripped_lowered=steps.get("stripped_lowered"),
+        normalized=steps.get("normalized"),
+        diacritics_removed=steps.get("diacritics_removed"),
+        punctuation_removed=steps.get("punctuation_removed"),
+        repeated_chars_reduced=steps.get("repeated_chars_reduced"),
+        whitespace_normalized=steps.get("whitespace_normalized"),
+        numbers_removed=steps.get("numbers_removed"),
+        tokenized=steps.get("tokenized"),
+        stopwords_removed=steps.get("stopwords_removed"),
+        stemmed=steps.get("stemmed"),
+        final=steps.get("final", "")
+    )
+
+
+# Main endpoints
 @app.get("/")
 def read_root() -> Dict[str, Any]:
     """API welcome message and endpoint documentation."""
@@ -190,162 +238,128 @@ def read_root() -> Dict[str, Any]:
             "openapi_schema": "/openapi.json",
         },
         "endpoints": {
+            "preprocess": "POST /preprocess - Preprocess text with detailed steps",
             "classify": "POST /classify - Classify Arabic text",
-            "classify_batch": "POST /classify/batch - Classify multiple texts",
             "summarize": "POST /summarize - Summarize Arabic text",
-            "analyze": "POST /analyze - Both classify and summarize",
-            "preprocess": "POST /preprocess - Preprocess text with detailed steps",
-            "text_analysis": "POST /text-analysis - Analyze text characteristics",
-            "sentence_analysis": "POST /sentence-analysis - Detailed sentence analysis",
-            "model_info": "GET /model-info - Get model information",
-            "available_models": "GET /models - Get all available models",
        },
    }


-@app.post("/[…]
-def […]
-    """[…]
-    try:
-        […]
-        return […]
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=f"[…]
-
-
-@app.post("/classify[…]
-def […]
-    """Classify […]
-    try:
-        […]
-        used_model = model_name or model_manager.default_model
-        return […]
-        […]
-
-[…]
-        raise HTTPException(status_code=500, detail=f"Summarization failed: {str(e)}")
-
-
-@app.post("/sentence-analysis", response_model=SentenceAnalysisResponse)
-def analyze_sentences(data: TextInput) -> SentenceAnalysisResponse:
-    """Analyze all sentences with scores and rankings."""
-    try:
-        result = summarizer.get_sentence_analysis(data.text)
-        return result
-    except Exception as e:
-        […]
-
-
-@app.post("/[…]
-def […]
-    """[…]
-    try:
-        […]
-    except Exception as e:
-        raise HTTPException(
-            status_code=500, detail=f"Complete analysis failed: {str(e)}"
-        )
-
-
-@app.post("/preprocess", response_model=PreprocessingResponse)
-def preprocess_text(data: PreprocessingInput) -> PreprocessingResponse:
-    """Preprocess text with step-by-step breakdown."""
-    try:
-        steps = preprocessor.get_preprocessing_steps(data.text, data.task_type.value)
-        return {"task_type": data.task_type.value, "preprocessing_steps": steps}
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=f"[…]
-
-
-@app.post("/text-analysis", response_model=TextAnalysisResponse)
-def analyze_text_characteristics(data: TextInput) -> TextAnalysisResponse:
-    """Analyze text characteristics and statistics."""
-    try:
-        analysis = preprocessor.analyze_text(data.text)
-        return {"text": data.text, "analysis": analysis}
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=f"Text analysis failed: {str(e)}")
-
-
-@app.get("/[…]
-def […]
-    """Get information about […]
-        […]
-    raise HTTPException(
-        status_code=500, detail=f"Failed to get model info: {str(e)}"
-    )
-
-
-@app.get("/models", response_model=AvailableModelsResponse)
-def get_available_models() -> AvailableModelsResponse:
-    """Get all available classification models."""
-    try:
-        models = model_manager.get_available_models()
-        return {"models": models, "current_model": model_manager.default_model}
-    except Exception as e:
-        raise HTTPException(
-            status_code=500, detail=f"Failed to get available models: {str(e)}"
-        )
-
-
-@app.get("/models/cache")
-def get_cache_status() -> Dict[str, Any]:
-    """Get information about cached models."""
-    try:
-        return model_manager.get_cache_status()
-    except Exception as e:
-        raise HTTPException(
-            status_code=500, detail=f"Failed to get cache status: {str(e)}"
-        )
-
-
-@app.post("/models/cache/clear")
-def clear_model_cache(model: Optional[ModelType] = None) -> Dict[str, Any]:
-    """Clear model cache for a specific model or all models."""
-    try:
-        model_name = model.value if model else None
-        return model_manager.clear_cache(model_name)
-    except Exception as e:
-        raise HTTPException(
-            status_code=500, detail=f"Failed to clear cache: {str(e)}"
-        )
+@app.post("/preprocess", response_model=PreprocessingResponse)
+def preprocess_text(req: PreprocessRequest) -> PreprocessingResponse:
+    """Preprocess text with step-by-step breakdown."""
+    try:
+        steps = preprocessor.get_preprocessing_steps(req.text, req.task_type.value)
+        preprocessing_steps = _create_preprocessing_steps(steps)
+        return PreprocessingResponse(
+            task_type=req.task_type.value,
+            preprocessing_steps=preprocessing_steps
+        )
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Preprocessing failed: {str(e)}")


+@app.post("/classify", response_model=ClassificationResponse)
+def classify_text(req: ClassificationRequest) -> ClassificationResponse:
+    """Classify Arabic text."""
+    try:
+        backend_model = _map_classification_model(req.model.value)
+        result = model_manager.predict(req.text, backend_model)
+
+        return ClassificationResponse(
+            prediction=result["prediction"],
+            confidence=result["confidence"],
+            probability_distribution=result["probability_distribution"],
+            cleaned_text=result["cleaned_text"],
+            model_used=req.model.value,  # Echo back the frontend model name
+            prediction_index=result.get("prediction_index"),
+            prediction_metadata=result.get("prediction_metadata")
+        )
+    except ValueError as e:
+        # Handle model availability errors
+        if "not available" in str(e):
+            raise HTTPException(
+                status_code=503,
+                detail=f"Model unavailable: {str(e)}. Check /models/available for current model status."
+            )
+        else:
+            raise HTTPException(status_code=400, detail=str(e))
+    except Exception as e:
+        error_msg = str(e)
+
+        # Provide more helpful error messages for common issues
+        if "BERT" in error_msg and ("connect" in error_msg.lower() or "internet" in error_msg.lower() or "huggingface" in error_msg.lower()):
+            raise HTTPException(
+                status_code=503,
+                detail=f"BERT model unavailable: The model requires internet connection to download tokenizer/config from Hugging Face, or the files need to be cached locally. Error: {error_msg}"
+            )
+        elif "modern_bert" in req.model.value and "Error loading" in error_msg:
+            raise HTTPException(
+                status_code=503,
+                detail=f"BERT model loading failed: {error_msg}. Please ensure the model files are properly configured and Hugging Face dependencies are available."
+            )
+        else:
+            raise HTTPException(status_code=500, detail=f"Classification failed: {error_msg}")


+@app.post("/summarize", response_model=SummarizationResponse)
+def summarize_text(req: SummarizationRequest) -> SummarizationResponse:
+    """Summarize Arabic text."""
+    try:
+        result = summarizer_manager.summarize(req.text, req.num_sentences, req.model.value)
+
+        return SummarizationResponse(
+            summary=result["summary"],
+            original_sentence_count=result["original_sentence_count"],
+            summary_sentence_count=result["summary_sentence_count"],
+            sentences=result["sentences"],
+            selected_indices=result["selected_indices"],
+            sentence_scores=result["sentence_scores"],
+            model_used=req.model.value,  # Echo back the frontend model name
+            top_sentence_scores=result.get("top_sentence_scores")
+        )
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Summarization failed: {str(e)}")


+@app.get("/models/available")
+def get_available_models() -> Dict[str, Any]:
+    """Get information about which models are currently available."""
+    return {
+        "classification_models": {
+            "traditional_svm": {
+                "available": AVAILABLE_MODELS.get("traditional_svm", False),
+                "description": "Traditional SVM classifier with TF-IDF vectorization"
+            },
+            "modern_lstm": {
+                "available": AVAILABLE_MODELS.get("modern_lstm", False),
+                "description": "Modern LSTM-based neural network classifier"
+            },
+            "modern_bert": {
+                "available": AVAILABLE_MODELS.get("modern_bert", False),
+                "description": "Modern BERT-based transformer classifier",
+                "note": "Requires internet connection or cached Hugging Face models" if not AVAILABLE_MODELS.get("modern_bert", False) else None
+            }
+        },
+        "summarization_models": {
+            "traditional_tfidf": {
+                "available": True,
+                "description": "Traditional TF-IDF based extractive summarization"
+            },
+            "modern_seq2seq": {
+                "available": True,
+                "description": "Modern sequence-to-sequence summarization (currently uses TF-IDF fallback)",
+                "note": "Implementation in progress - currently falls back to TF-IDF"
+            },
+            "modern_bert": {
+                "available": True,
+                "description": "Modern BERT-based extractive summarization using asafaya/bert-base-arabic",
+                "note": "Requires torch and transformers dependencies. Model will be downloaded on first use."
+            }
+        },
+        "status": {
+            "total_classification_models": len([k for k, v in AVAILABLE_MODELS.items() if v]),
+            "total_available": len([k for k, v in AVAILABLE_MODELS.items() if v]),
+            "unavailable_models": [k for k, v in AVAILABLE_MODELS.items() if not v]
+        }
+    }
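With this change, clients pick the model per request instead of per deployment. A minimal sketch of how a frontend might call the reworked /summarize route (the host/port and sample text are assumptions, not part of this PR):

# Hypothetical client for the reworked API; assumes the app runs on localhost:8000.
import requests

payload = {
    "text": "هذا نص عربي طويل للتلخيص ...",
    "num_sentences": 3,
    "model": "traditional_tfidf",  # or "modern_seq2seq" / "modern_bert"
}
resp = requests.post("http://localhost:8000/summarize", json=payload)
resp.raise_for_status()
data = resp.json()
print(data["summary"])
print(data["selected_indices"], data["sentence_scores"])

GET /models/available reports which models the server could actually load at startup.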
bert_summarizer.py
ADDED
@@ -0,0 +1,114 @@
+import torch
+import numpy as np
+import re
+from typing import Dict, List, Any
+from transformers import BertTokenizer, BertModel
+from sklearn.metrics.pairwise import cosine_similarity
+from preprocessor import preprocess_for_summarization
+
+
+class BERTExtractiveSummarizer:
+    def __init__(self, model_name='aubmindlab/bert-base-arabertv02'):
+        """Initialize BERT-based Arabic summarizer."""
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        print(f"Using device: {self.device}")
+
+        # Load tokenizer and model
+        self.tokenizer = BertTokenizer.from_pretrained(model_name)
+        self.model = BertModel.from_pretrained(model_name)
+        self.model.to(self.device)
+        self.model.eval()
+
+    def get_sentence_embeddings(self, sentences: List[str]) -> np.ndarray:
+        """Get BERT embeddings for sentences."""
+        embeddings = []
+
+        with torch.no_grad():
+            for sentence in sentences:
+                # Tokenize
+                inputs = self.tokenizer(
+                    sentence,
+                    return_tensors='pt',
+                    max_length=512,
+                    truncation=True,
+                    padding=True
+                ).to(self.device)
+
+                # Get embeddings
+                outputs = self.model(**inputs)
+                # Use CLS token embedding
+                embedding = outputs.last_hidden_state[:, 0, :].cpu().numpy()
+                embeddings.append(embedding.squeeze())
+
+        return np.array(embeddings)
+
+    def summarize(self, text: str, num_sentences: int = 3) -> Dict[str, Any]:
+        """
+        Summarize Arabic text using BERT extractive summarization.
+        Returns the same structure as other summarizers for consistency.
+        """
+        print(f"BERT Summarizer: Processing text with {len(text)} characters")
+
+        # Use the same preprocessing as TF-IDF for fair comparison
+        cleaned_text = preprocess_for_summarization(text)
+        print(f"BERT Summarizer: After preprocessing: '{cleaned_text[:100]}...'")
+
+        # Split into sentences - same approach as TF-IDF
+        sentences = re.split(r'[.!؟\n]+', cleaned_text)
+        sentences = [s.strip() for s in sentences if s.strip()]  # Same as TF-IDF
+
+        print(f"BERT Summarizer: Found {len(sentences)} sentences")
+        original_sentence_count = len(sentences)
+
+        # If we have fewer sentences than requested, return all
+        if len(sentences) <= num_sentences:
+            print(f"BERT Summarizer: Returning all {len(sentences)} sentences (fewer than requested)")
+            return {
+                "summary": cleaned_text.strip(),  # Use cleaned text like TF-IDF
+                "original_sentence_count": original_sentence_count,
+                "summary_sentence_count": len(sentences),
+                "sentences": sentences,
+                "selected_indices": list(range(len(sentences))),
+                "sentence_scores": [1.0] * len(sentences)  # All sentences selected
+            }
+
+        print("BERT Summarizer: Getting sentence embeddings...")
+        # Get sentence embeddings
+        sentence_embeddings = self.get_sentence_embeddings(sentences)
+        print(f"BERT Summarizer: Got embeddings shape: {sentence_embeddings.shape}")
+
+        # Calculate document embedding (mean of all sentences)
+        doc_embedding = np.mean(sentence_embeddings, axis=0)
+
+        # Calculate similarity scores
+        similarities = cosine_similarity([doc_embedding], sentence_embeddings)[0]
+        print(f"BERT Summarizer: Similarity scores: {similarities}")
+
+        # Get top sentences (indices with highest scores)
+        top_indices = np.argsort(similarities)[-num_sentences:]
+        print(f"BERT Summarizer: Top indices: {top_indices}")
+
+        # Sort indices to maintain original order in summary
+        top_indices_sorted = sorted(top_indices)
+        # Convert numpy indices to regular ints for JSON serialization
+        top_indices_sorted = [int(i) for i in top_indices_sorted]
+        print(f"BERT Summarizer: Selected indices (in order): {top_indices_sorted}")
+
+        # Get selected sentences and their scores
+        selected_sentences = [sentences[i] for i in top_indices_sorted]
+        selected_scores = [float(similarities[i]) for i in top_indices_sorted]
+
+        print(f"BERT Summarizer: Selected sentences: {[s[:50] + '...' for s in selected_sentences]}")
+
+        # Create summary by joining selected sentences
+        summary = ' '.join(selected_sentences)
+
+        return {
+            "summary": summary,
+            "original_sentence_count": original_sentence_count,
+            "summary_sentence_count": len(selected_sentences),
+            "sentences": sentences,  # All original sentences
+            "selected_indices": top_indices_sorted,
+            "sentence_scores": selected_scores,
+            "top_sentence_scores": selected_scores  # Additional info
+        }
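For reviewers who want to exercise the class outside the API, a minimal usage sketch (the sample text is a placeholder; the first run downloads the aubmindlab/bert-base-arabertv02 weights from Hugging Face):

from bert_summarizer import BERTExtractiveSummarizer

summarizer = BERTExtractiveSummarizer()  # default model_name shown above
result = summarizer.summarize("... نص عربي من عدة جمل ...", num_sentences=2)
print(result["summary"])
print(result["selected_indices"], result["sentence_scores"])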
examples.py
CHANGED
@@ -269,8 +269,8 @@ RESPONSE_EXAMPLES = {
         "model_description": "Traditional SVM classifier with TF-IDF vectorization",
         "model_config": {
             "type": "traditional",
-            "classifier_path": "traditional_svm_classifier.joblib",
-            "vectorizer_path": "traditional_tfidf_vectorizer_classifier.joblib",
+            "classifier_path": "models/traditional_svm_classifier.joblib",
+            "vectorizer_path": "models/traditional_tfidf_vectorizer_classifier.joblib",
             "description": "Traditional SVM classifier with TF-IDF vectorization"
         },
         "is_cached": True
model_manager.py
CHANGED
@@ -15,15 +15,15 @@ class ModelManager:
     AVAILABLE_MODELS = {
         "traditional_svm": {
             "type": "traditional",
-            "classifier_path": "traditional_svm_classifier.joblib",
-            "vectorizer_path": "traditional_tfidf_vectorizer_classifier.joblib",
+            "classifier_path": "models/traditional_svm_classifier.joblib",
+            "vectorizer_path": "models/traditional_tfidf_vectorizer_classifier.joblib",
             "description": "Traditional SVM classifier with TF-IDF vectorization"
         },

         "modern_bert": {
             "type": "modern",
             "model_type": "bert",
-            "model_path": "modern_bert_classifier.safetensors",
+            "model_path": "models/modern_bert_classifier.safetensors",
             "config_path": "config.json",
             "description": "Modern BERT-based transformer classifier"
         },
@@ -31,7 +31,7 @@ class ModelManager:
         "modern_lstm": {
             "type": "modern",
             "model_type": "lstm",
-            "model_path": "modern_lstm_classifier.pth",
+            "model_path": "models/modern_lstm_classifier.pth",
             "description": "Modern LSTM-based neural network classifier"
         }
     }
models/Seq2seq/seq2seq_config.json
ADDED
@@ -0,0 +1 @@
+{"ENC_MAXLEN": 1900, "DEC_MAXLEN": 178, "SRC_VOCAB_SIZE": 20000, "TGT_VOCAB_SIZE": 10000, "EMB_DIM": 128, "HID_DIM": 256}
models/Seq2seq/seq2seq_model.h5
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a35f8f2f2dc4f77570cc86c77a9fb90a1649d79d3e5e632be92499e889958a27
+size 117152336
models/Seq2seq/src_tokenizer.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ff87d78b4f45fa3aaa9b9a43c0d94e7aecc1f7f18e0ab5c4caed15a0f1ca61ee
+size 12722191
models/Seq2seq/tgt_tokenizer.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ca4e33cc944afd29a11b4fed11da27787ef604e7403b765ab589a7b304059e95
+size 2577556
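SummarizerManager still routes modern_seq2seq to the TF-IDF fallback, so these artifacts are committed but not yet wired up. Assuming they were produced with standard Keras and pickle serialization (the PR does not include the training or inference code), loading them would presumably look like this sketch:

# Sketch only: the seq2seq summarizer itself is still a TODO in app.py.
import json
import pickle

from tensorflow import keras  # assumed, matching the .h5 format

with open("models/Seq2seq/seq2seq_config.json") as f:
    cfg = json.load(f)  # ENC_MAXLEN, DEC_MAXLEN, vocab sizes, EMB_DIM, HID_DIM

model = keras.models.load_model("models/Seq2seq/seq2seq_model.h5")

with open("models/Seq2seq/src_tokenizer.pkl", "rb") as f:
    src_tokenizer = pickle.load(f)
with open("models/Seq2seq/tgt_tokenizer.pkl", "rb") as f:
    tgt_tokenizer = pickle.load(f)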
modern_bert_classifier.safetensors → models/modern_bert_classifier.safetensors
RENAMED
File without changes

modern_lstm_classifier.pth → models/modern_lstm_classifier.pth
RENAMED
File without changes

traditional_svm_classifier.joblib → models/traditional_svm_classifier.joblib
RENAMED
File without changes

traditional_tfidf_vectorizer_classifier.joblib → models/traditional_tfidf_vectorizer_classifier.joblib
RENAMED
File without changes

traditional_tfidf_vectorizer_summarization.joblib → models/traditional_tfidf_vectorizer_summarization.joblib
RENAMED
File without changes
modern_classifier.py
CHANGED
@@ -65,15 +65,71 @@ class ModernClassifier:
     def _load_bert_model(self):
         """Load BERT model from safetensors."""
         try:
-            […]
+            # Try different Arabic BERT tokenizers that match 32K vocabulary
+            tokenizer_options = [
+                'asafaya/bert-base-arabic',  # This one has 32K vocab
+                'aubmindlab/bert-base-arabertv02',  # Alternative
+                'aubmindlab/bert-base-arabertv2'  # Fallback (64K vocab)
+            ]
+
+            self.tokenizer = None
+            for tokenizer_name in tokenizer_options:
+                try:
+                    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, local_files_only=True)
+                    # Test if vocabulary size matches
+                    if len(tokenizer.vocab) <= 32000:
+                        self.tokenizer = tokenizer
+                        print(f"Using tokenizer: {tokenizer_name} (vocab size: {len(tokenizer.vocab)})")
+                        break
+                except:
+                    continue
+
+            if self.tokenizer is None:
+                # Try downloading if local files don't work
+                for tokenizer_name in tokenizer_options:
+                    try:
+                        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+                        if len(tokenizer.vocab) <= 32000:
+                            self.tokenizer = tokenizer
+                            print(f"Downloaded tokenizer: {tokenizer_name} (vocab size: {len(tokenizer.vocab)})")
+                            break
+                    except:
+                        continue
+
+            if self.tokenizer is None:
+                raise RuntimeError("No compatible Arabic BERT tokenizer found with 32K vocabulary")
+
             state_dict = load_file(self.model_path)
             embed_key = next(k for k in state_dict if 'embeddings.word_embeddings.weight' in k)
             checkpoint_vocab_size = state_dict[embed_key].shape[0]
-            […]
+
+            # Try to load config locally first
+            try:
+                config = AutoConfig.from_pretrained(
+                    'aubmindlab/bert-base-arabertv2',
+                    num_labels=len(self.classes),
+                    vocab_size=checkpoint_vocab_size,
+                    local_files_only=True
+                )
+            except:
+                try:
+                    config = AutoConfig.from_pretrained(
+                        'aubmindlab/bert-base-arabertv2',
+                        num_labels=len(self.classes),
+                        vocab_size=checkpoint_vocab_size
+                    )
+                except:
+                    # Fallback: create a basic BERT config
+                    from transformers import BertConfig
+                    config = BertConfig(
+                        vocab_size=checkpoint_vocab_size,
+                        hidden_size=768,
+                        num_hidden_layers=12,
+                        num_attention_heads=12,
+                        intermediate_size=3072,
+                        num_labels=len(self.classes)
+                    )
+
             self.model = AutoModelForSequenceClassification.from_config(config)
             self.model.resize_token_embeddings(checkpoint_vocab_size)
             self.model.load_state_dict(state_dict, strict=False)
@@ -116,6 +172,15 @@ class ModernClassifier:
             max_length=512
         )

+        # CRITICAL FIX: Check for vocabulary mismatch and clamp token IDs
+        input_ids = inputs['input_ids']
+        max_token_id = input_ids.max().item()
+        model_vocab_size = self.model.config.vocab_size
+
+        if max_token_id >= model_vocab_size:
+            # Fix: Clamp token IDs to valid range to prevent "index out of range" error
+            inputs['input_ids'] = torch.clamp(input_ids, 0, model_vocab_size - 1)
+
         return {key: value.to(self.device) for key, value in inputs.items()}

     def _preprocess_text_for_lstm(self, text: str) -> torch.Tensor:
@@ -150,7 +215,11 @@ class ModernClassifier:
         inputs = self._preprocess_text_for_lstm(text)
         logits = self.model(inputs)

-        probabilities = torch.softmax(logits, dim=-1).cpu().numpy()
+        probabilities = torch.softmax(logits, dim=-1).cpu().numpy()
+
+        # Handle batch dimension
+        if len(probabilities.shape) > 1:
+            probabilities = probabilities[0]

         prediction_index = int(np.argmax(probabilities))
         prediction = self.classes[prediction_index]
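The token-ID clamp added to the BERT preprocessing path guards against a tokenizer whose vocabulary is larger than the checkpoint's embedding table. A standalone illustration of the failure mode and the fix (tensor values are made up for the example):

import torch

model_vocab_size = 32000  # embedding rows in the checkpoint
# The last ID could only come from a larger (e.g. 64K) tokenizer vocabulary.
input_ids = torch.tensor([[2, 15, 31999, 63999]])

# An embedding lookup with ID 63999 would raise "index out of range in self";
# clamping keeps the forward pass alive at the cost of mapping out-of-range
# IDs onto an arbitrary valid row.
safe_ids = torch.clamp(input_ids, 0, model_vocab_size - 1)
print(safe_ids)  # tensor([[    2,    15, 31999, 31999]])

The clamp avoids the crash but silently distorts inputs when vocabularies disagree, which is why the loader above searches for a 32K-vocab tokenizer first.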
summarizer.py
CHANGED
@@ -8,7 +8,7 @@ from preprocessor import preprocess_for_summarization
 class ArabicSummarizer:
     """Arabic text summarizer using TF-IDF scoring."""

-    def __init__(self, vectorizer_path: str = "traditional_tfidf_vectorizer_summarization.joblib"):
+    def __init__(self, vectorizer_path: str = "models/traditional_tfidf_vectorizer_summarization.joblib"):
         self.vectorizer = joblib.load(vectorizer_path)

     def summarize(self, text: str, num_sentences: int = 3) -> Dict[str, Any]:
traditional_classifier.py
CHANGED
@@ -9,8 +9,8 @@ class TraditionalClassifier:

     def __init__(
         self,
-        classifier_path: str = "traditional_svm_classifier.joblib",
-        vectorizer_path: str = "traditional_tfidf_vectorizer_classifier.joblib",
+        classifier_path: str = "models/traditional_svm_classifier.joblib",
+        vectorizer_path: str = "models/traditional_tfidf_vectorizer_classifier.joblib",
     ):
         self.model = joblib.load(classifier_path)
         self.vectorizer = joblib.load(vectorizer_path)