Update app.py
Browse files
app.py
CHANGED
@@ -256,11 +256,24 @@ class OptimizedRagSystem:
|
|
256 |
papers = []
|
257 |
search_query = query.strip()
|
258 |
|
259 |
-
# Simple query validation
|
260 |
if not search_query or len(search_query) < 2:
|
261 |
logger.warning("Query too short, using default search")
|
262 |
search_query = "machine learning"
|
263 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
264 |
if categories and len(categories) > 0:
|
265 |
category_filter = " OR ".join([f"cat:{cat.strip()}" for cat in categories if cat.strip()])
|
266 |
if category_filter:
|
@@ -269,6 +282,7 @@ class OptimizedRagSystem:
|
|
269 |
logger.info(f"🔍 ArXiv search attempt {attempt + 1}: '{search_query}'")
|
270 |
|
271 |
# Create search with timeout and retry settings
|
|
|
272 |
search = arxiv.Search(
|
273 |
query=search_query,
|
274 |
max_results=min(max_results, 50), # Limit to prevent API issues
|
@@ -489,15 +503,23 @@ class OptimizedRagSystem:
|
|
489 |
bm25_scores = bm25_scores[:min_length]
|
490 |
chunks = self.chunks[:min_length]
|
491 |
|
492 |
-
# Normalize scores
|
493 |
-
if len(semantic_scores) > 0:
|
494 |
-
semantic_scores = (semantic_scores - semantic_scores.min()) / (semantic_scores.max() - semantic_scores.min()
|
495 |
-
|
496 |
-
|
|
|
|
|
|
|
|
|
|
|
497 |
|
498 |
-
# Combine scores
|
499 |
combined_scores = semantic_weight * semantic_scores + (1 - semantic_weight) * bm25_scores
|
500 |
|
|
|
|
|
|
|
501 |
# Get top results
|
502 |
top_indices = np.argsort(combined_scores)[::-1][:top_k]
|
503 |
results = [(chunks[i], float(combined_scores[i])) for i in top_indices]
|
|
|
256 |
papers = []
|
257 |
search_query = query.strip()
|
258 |
|
259 |
+
# Simple query validation and enhancement
|
260 |
if not search_query or len(search_query) < 2:
|
261 |
logger.warning("Query too short, using default search")
|
262 |
search_query = "machine learning"
|
263 |
|
264 |
+
# Enhance transformer-related queries for better results
|
265 |
+
transformer_keywords = ["transformer", "attention", "bert", "gpt", "llm", "language model"]
|
266 |
+
if any(keyword in search_query.lower() for keyword in transformer_keywords):
|
267 |
+
# Add related terms to improve relevance
|
268 |
+
enhanced_terms = []
|
269 |
+
if "attention" in search_query.lower():
|
270 |
+
enhanced_terms.extend(["self-attention", "multi-head attention", "scaled dot-product"])
|
271 |
+
if "transformer" in search_query.lower():
|
272 |
+
enhanced_terms.extend(["encoder", "decoder", "positional encoding"])
|
273 |
+
|
274 |
+
if enhanced_terms:
|
275 |
+
search_query = f"({search_query}) OR ({' OR '.join(enhanced_terms)})"
|
276 |
+
|
277 |
if categories and len(categories) > 0:
|
278 |
category_filter = " OR ".join([f"cat:{cat.strip()}" for cat in categories if cat.strip()])
|
279 |
if category_filter:
|
|
|
282 |
logger.info(f"🔍 ArXiv search attempt {attempt + 1}: '{search_query}'")
|
283 |
|
284 |
# Create search with timeout and retry settings
|
285 |
+
# Use relevance sorting for better results, but mix with recent papers
|
286 |
search = arxiv.Search(
|
287 |
query=search_query,
|
288 |
max_results=min(max_results, 50), # Limit to prevent API issues
|
|
|
503 |
bm25_scores = bm25_scores[:min_length]
|
504 |
chunks = self.chunks[:min_length]
|
505 |
|
506 |
+
# Normalize scores properly to [0, 1] range
|
507 |
+
if len(semantic_scores) > 0 and semantic_scores.max() > semantic_scores.min():
|
508 |
+
semantic_scores = (semantic_scores - semantic_scores.min()) / (semantic_scores.max() - semantic_scores.min())
|
509 |
+
else:
|
510 |
+
semantic_scores = np.ones_like(semantic_scores) * 0.5
|
511 |
+
|
512 |
+
if len(bm25_scores) > 0 and bm25_scores.max() > bm25_scores.min():
|
513 |
+
bm25_scores = (bm25_scores - bm25_scores.min()) / (bm25_scores.max() - bm25_scores.min())
|
514 |
+
else:
|
515 |
+
bm25_scores = np.ones_like(bm25_scores) * 0.5
|
516 |
|
517 |
+
# Combine scores (both should be in [0, 1] range now)
|
518 |
combined_scores = semantic_weight * semantic_scores + (1 - semantic_weight) * bm25_scores
|
519 |
|
520 |
+
# Ensure final scores are positive
|
521 |
+
combined_scores = np.maximum(combined_scores, 0.0)
|
522 |
+
|
523 |
# Get top results
|
524 |
top_indices = np.argsort(combined_scores)[::-1][:top_k]
|
525 |
results = [(chunks[i], float(combined_scores[i])) for i in top_indices]
|