Update app.py
Browse files
app.py
CHANGED
@@ -261,18 +261,22 @@ class OptimizedRagSystem:
|
|
261 |
logger.warning("Query too short, using default search")
|
262 |
search_query = "machine learning"
|
263 |
|
264 |
-
#
|
265 |
-
|
266 |
-
if
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
|
|
|
|
|
|
|
|
276 |
|
277 |
if categories and len(categories) > 0:
|
278 |
category_filter = " OR ".join([f"cat:{cat.strip()}" for cat in categories if cat.strip()])
|
@@ -352,35 +356,46 @@ class OptimizedRagSystem:
|
|
352 |
# If all attempts failed, try a simple fallback search
|
353 |
logger.warning("All search attempts failed, trying fallback search...")
|
354 |
try:
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
|
|
|
|
361 |
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
|
367 |
-
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
373 |
-
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
384 |
|
385 |
except Exception as e:
|
386 |
logger.error(f"Even fallback search failed: {e}")
|
@@ -546,10 +561,14 @@ class OptimizedRagSystem:
|
|
546 |
with torch.cuda.amp.autocast() if DEVICE == "cuda" else torch.no_grad():
|
547 |
rerank_scores = self.reranker.predict(pairs, show_progress_bar=False)
|
548 |
|
549 |
-
# Combine with original scores
|
550 |
reranked_results = []
|
551 |
for i, (chunk, original_score) in enumerate(results):
|
552 |
-
|
|
|
|
|
|
|
|
|
553 |
reranked_results.append((chunk, combined_score))
|
554 |
|
555 |
# Sort by new scores
|
|
|
261 |
logger.warning("Query too short, using default search")
|
262 |
search_query = "machine learning"
|
263 |
|
264 |
+
# Simplify attention/transformer queries for better arXiv results
|
265 |
+
# arXiv search works better with simple, specific terms
|
266 |
+
if "attention" in search_query.lower() and "transformer" in search_query.lower():
|
267 |
+
search_query = "attention mechanism transformer"
|
268 |
+
elif "transformer" in search_query.lower():
|
269 |
+
search_query = "transformer neural network"
|
270 |
+
elif "attention" in search_query.lower():
|
271 |
+
search_query = "attention mechanism"
|
272 |
+
|
273 |
+
logger.info(f"Simplified query: '{search_query}'")
|
274 |
+
|
275 |
+
# Handle categories - for attention/transformer queries with no categories given, default to relevant ones
|
276 |
+
if "attention" in search_query.lower() or "transformer" in search_query.lower():
|
277 |
+
if not categories or len(categories) == 0:
|
278 |
+
categories = ["cs.CL", "cs.LG", "cs.AI"] # Default to relevant categories
|
279 |
+
logger.info(f"Added default categories for transformer search: {categories}")
|
280 |
|
281 |
if categories and len(categories) > 0:
|
282 |
category_filter = " OR ".join([f"cat:{cat.strip()}" for cat in categories if cat.strip()])
|
|
|
356 |
# If all attempts failed, try a simple fallback search
|
357 |
logger.warning("All search attempts failed, trying fallback search...")
|
358 |
try:
|
359 |
+
# Try a fixed list of fallback queries (mostly transformer-related) in order; return the first that yields papers
|
360 |
+
fallback_queries = [
|
361 |
+
"attention is all you need",
|
362 |
+
"transformer attention mechanism",
|
363 |
+
"BERT language representation",
|
364 |
+
"GPT generative pretrained",
|
365 |
+
"artificial intelligence"
|
366 |
+
]
|
367 |
|
368 |
+
for fallback_query in fallback_queries:
|
369 |
+
logger.info(f"Trying fallback: '{fallback_query}'")
|
370 |
+
fallback_search = arxiv.Search(
|
371 |
+
query=fallback_query,
|
372 |
+
max_results=5,
|
373 |
+
sort_by=arxiv.SortCriterion.Relevance,
|
374 |
+
sort_order=arxiv.SortOrder.Descending
|
375 |
+
)
|
376 |
+
|
377 |
+
papers = []
|
378 |
+
for i, result in enumerate(fallback_search.results()):
|
379 |
+
if i >= 5: # Limit fallback results
|
380 |
+
break
|
381 |
+
try:
|
382 |
+
paper = Paper(
|
383 |
+
id=result.entry_id.split('/')[-1],
|
384 |
+
title=result.title,
|
385 |
+
abstract=result.summary,
|
386 |
+
authors=[author.name for author in result.authors],
|
387 |
+
categories=result.categories,
|
388 |
+
published=result.published,
|
389 |
+
url=result.entry_id
|
390 |
+
)
|
391 |
+
papers.append(paper)
|
392 |
+
except Exception as e:
|
393 |
+
logger.warning(f"Error in fallback paper processing: {e}")
|
394 |
+
continue
|
395 |
+
|
396 |
+
if papers:
|
397 |
+
logger.info(f"🔄 Fallback search '{fallback_query}' returned {len(papers)} papers")
|
398 |
+
return papers
|
399 |
|
400 |
except Exception as e:
|
401 |
logger.error(f"Even fallback search failed: {e}")
|
|
|
561 |
with torch.cuda.amp.autocast() if DEVICE == "cuda" else torch.no_grad():
|
562 |
rerank_scores = self.reranker.predict(pairs, show_progress_bar=False)
|
563 |
|
564 |
+
# Combine with original scores (both terms clamped to be non-negative)
|
565 |
reranked_results = []
|
566 |
for i, (chunk, original_score) in enumerate(results):
|
567 |
+
# Map rerank scores to [0, 1] (assumes raw scores lie in [-1, 1]) and clamp out-of-range values
|
568 |
+
rerank_score = float(rerank_scores[i])
|
569 |
+
rerank_score = max(0.0, min(1.0, (rerank_score + 1) / 2)) # Convert from [-1,1] to [0,1]
|
570 |
+
|
571 |
+
combined_score = 0.6 * rerank_score + 0.4 * max(0.0, original_score)
|
572 |
reranked_results.append((chunk, combined_score))
|
573 |
|
574 |
# Sort by new scores
|