mihirinamdar commited on
Commit
a1e6055
·
verified ·
1 Parent(s): 15f008d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -42
app.py CHANGED
@@ -261,18 +261,22 @@ class OptimizedRagSystem:
261
  logger.warning("Query too short, using default search")
262
  search_query = "machine learning"
263
 
264
- # Enhance transformer-related queries for better results
265
- transformer_keywords = ["transformer", "attention", "bert", "gpt", "llm", "language model"]
266
- if any(keyword in search_query.lower() for keyword in transformer_keywords):
267
- # Add related terms to improve relevance
268
- enhanced_terms = []
269
- if "attention" in search_query.lower():
270
- enhanced_terms.extend(["self-attention", "multi-head attention", "scaled dot-product"])
271
- if "transformer" in search_query.lower():
272
- enhanced_terms.extend(["encoder", "decoder", "positional encoding"])
273
-
274
- if enhanced_terms:
275
- search_query = f"({search_query}) OR ({' OR '.join(enhanced_terms)})"
 
 
 
 
276
 
277
  if categories and len(categories) > 0:
278
  category_filter = " OR ".join([f"cat:{cat.strip()}" for cat in categories if cat.strip()])
@@ -352,35 +356,46 @@ class OptimizedRagSystem:
352
  # If all attempts failed, try a simple fallback search
353
  logger.warning("All search attempts failed, trying fallback search...")
354
  try:
355
- fallback_search = arxiv.Search(
356
- query="artificial intelligence", # Simple fallback
357
- max_results=5,
358
- sort_by=arxiv.SortCriterion.SubmittedDate,
359
- sort_order=arxiv.SortOrder.Descending
360
- )
 
 
361
 
362
- papers = []
363
- for i, result in enumerate(fallback_search.results()):
364
- if i >= 5: # Limit fallback results
365
- break
366
- try:
367
- paper = Paper(
368
- id=result.entry_id.split('/')[-1],
369
- title=result.title,
370
- abstract=result.summary,
371
- authors=[author.name for author in result.authors],
372
- categories=result.categories,
373
- published=result.published,
374
- url=result.entry_id
375
- )
376
- papers.append(paper)
377
- except Exception as e:
378
- logger.warning(f"Error in fallback paper processing: {e}")
379
- continue
380
-
381
- if papers:
382
- logger.info(f"🔄 Fallback search returned {len(papers)} papers")
383
- return papers
 
 
 
 
 
 
 
 
 
384
 
385
  except Exception as e:
386
  logger.error(f"Even fallback search failed: {e}")
@@ -546,10 +561,14 @@ class OptimizedRagSystem:
546
  with torch.cuda.amp.autocast() if DEVICE == "cuda" else torch.no_grad():
547
  rerank_scores = self.reranker.predict(pairs, show_progress_bar=False)
548
 
549
- # Combine with original scores
550
  reranked_results = []
551
  for i, (chunk, original_score) in enumerate(results):
552
- combined_score = 0.6 * float(rerank_scores[i]) + 0.4 * original_score
 
 
 
 
553
  reranked_results.append((chunk, combined_score))
554
 
555
  # Sort by new scores
 
261
  logger.warning("Query too short, using default search")
262
  search_query = "machine learning"
263
 
264
+ # Simplify transformer queries for better ArXiv results
265
+ # ArXiv search works better with simple, specific terms
266
+ if "attention" in search_query.lower() and "transformer" in search_query.lower():
267
+ search_query = "attention mechanism transformer"
268
+ elif "transformer" in search_query.lower():
269
+ search_query = "transformer neural network"
270
+ elif "attention" in search_query.lower():
271
+ search_query = "attention mechanism"
272
+
273
+ logger.info(f"Simplified query: '{search_query}'")
274
+
275
+ # Handle categories - for transformers, default to relevant categories
276
+ if "attention" in search_query.lower() or "transformer" in search_query.lower():
277
+ if not categories or len(categories) == 0:
278
+ categories = ["cs.CL", "cs.LG", "cs.AI"] # Default to relevant categories
279
+ logger.info(f"Added default categories for transformer search: {categories}")
280
 
281
  if categories and len(categories) > 0:
282
  category_filter = " OR ".join([f"cat:{cat.strip()}" for cat in categories if cat.strip()])
 
356
  # If all attempts failed, try a simple fallback search
357
  logger.warning("All search attempts failed, trying fallback search...")
358
  try:
359
+ # Try a specific search that should return transformer papers
360
+ fallback_queries = [
361
+ "attention is all you need",
362
+ "transformer attention mechanism",
363
+ "BERT language representation",
364
+ "GPT generative pretrained",
365
+ "artificial intelligence"
366
+ ]
367
 
368
+ for fallback_query in fallback_queries:
369
+ logger.info(f"Trying fallback: '{fallback_query}'")
370
+ fallback_search = arxiv.Search(
371
+ query=fallback_query,
372
+ max_results=5,
373
+ sort_by=arxiv.SortCriterion.Relevance,
374
+ sort_order=arxiv.SortOrder.Descending
375
+ )
376
+
377
+ papers = []
378
+ for i, result in enumerate(fallback_search.results()):
379
+ if i >= 5: # Limit fallback results
380
+ break
381
+ try:
382
+ paper = Paper(
383
+ id=result.entry_id.split('/')[-1],
384
+ title=result.title,
385
+ abstract=result.summary,
386
+ authors=[author.name for author in result.authors],
387
+ categories=result.categories,
388
+ published=result.published,
389
+ url=result.entry_id
390
+ )
391
+ papers.append(paper)
392
+ except Exception as e:
393
+ logger.warning(f"Error in fallback paper processing: {e}")
394
+ continue
395
+
396
+ if papers:
397
+ logger.info(f"🔄 Fallback search '{fallback_query}' returned {len(papers)} papers")
398
+ return papers
399
 
400
  except Exception as e:
401
  logger.error(f"Even fallback search failed: {e}")
 
561
  with torch.cuda.amp.autocast() if DEVICE == "cuda" else torch.no_grad():
562
  rerank_scores = self.reranker.predict(pairs, show_progress_bar=False)
563
 
564
+ # Combine with original scores (ensure positive scores)
565
  reranked_results = []
566
  for i, (chunk, original_score) in enumerate(results):
567
+ # Normalize rerank scores to [0, 1] and ensure positive
568
+ rerank_score = float(rerank_scores[i])
569
+ rerank_score = max(0.0, min(1.0, (rerank_score + 1) / 2)) # Convert from [-1,1] to [0,1]
570
+
571
+ combined_score = 0.6 * rerank_score + 0.4 * max(0.0, original_score)
572
  reranked_results.append((chunk, combined_score))
573
 
574
  # Sort by new scores