mihirinamdar commited on
Commit
82723d8
·
verified ·
1 Parent(s): ab4f49f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -106
app.py CHANGED
@@ -30,14 +30,14 @@ from nltk.stem import PorterStemmer
30
 
31
  # Download required NLTK data
32
  try:
33
- nltk.data.find("tokenizers/punkt")
34
  except LookupError:
35
- nltk.download("punkt")
36
 
37
  try:
38
- nltk.data.find("corpora/stopwords")
39
  except LookupError:
40
- nltk.download("stopwords")
41
 
42
  # Setup logging
43
  logging.basicConfig(level=logging.INFO)
@@ -74,7 +74,7 @@ class BM25Retriever:
74
  self.avg_doc_length = 0
75
  self.stemmer = PorterStemmer()
76
  try:
77
- self.stop_words = set(stopwords.words("english"))
78
  except:
79
  self.stop_words = set()
80
 
@@ -150,7 +150,7 @@ class SimpleVectorStore:
150
  def query(self, query_embedding: List[float], n_results: int = 10) -> Dict:
151
  """Query the vector store"""
152
  if not self.embeddings:
153
- return {"ids": [[]], "documents": [[]], "metadatas": [[]]}
154
 
155
  # Calculate cosine similarities
156
  query_embedding = np.array(query_embedding)
@@ -167,25 +167,25 @@ class SimpleVectorStore:
167
  top_indices = np.argsort(similarities)[::-1][:n_results]
168
 
169
  return {
170
- "ids": [[self.ids[i] for i in top_indices]],
171
- "documents": [[self.documents[i] for i in top_indices]],
172
- "metadatas": [[self.metadatas[i] for i in top_indices]]
173
  }
174
 
175
  def get(self, ids: Optional[List[str]] = None) -> Dict:
176
  """Get documents by IDs or all documents"""
177
  if ids is None:
178
  return {
179
- "ids": self.ids,
180
- "documents": self.documents,
181
- "metadatas": self.metadatas
182
  }
183
  else:
184
  indices = [self.ids.index(id_) for id_ in ids if id_ in self.ids]
185
  return {
186
- "ids": [self.ids[i] for i in indices],
187
- "documents": [self.documents[i] for i in indices],
188
- "metadatas": [self.metadatas[i] for i in indices]
189
  }
190
 
191
  def clear(self):
@@ -201,29 +201,11 @@ class EnhancedArxivRAG:
201
  def __init__(self):
202
  logger.info("Initializing Enhanced ArXiv RAG System for HF Spaces...")
203
 
204
- # Determine device (GPU if available, else CPU)
205
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
206
- logger.info(f"Using device: {self.device}")
207
-
208
- # Load models with appropriate device settings
209
- try:
210
- logger.info("Loading embedding model...")
211
- self.embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device=self.device)
212
- logger.info("Embedding model loaded.")
213
-
214
- logger.info("Loading reranker model...")
215
- self.reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-2-v2", device=self.device)
216
- logger.info("Reranker model loaded.")
217
-
218
- logger.info("Loading summarizer model...")
219
- # For pipeline, device_map="auto" is often better for ZeroGPU
220
- # If issues persist, try device=0 for the first GPU, or device=self.device
221
- self.summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-6-6", device_map="auto")
222
- logger.info("Summarizer model loaded.")
223
-
224
- except Exception as e:
225
- logger.error(f"Error loading models: {e}")
226
- raise
227
 
228
  # Use simple vector store instead of ChromaDB for HF Spaces
229
  self.vector_store = SimpleVectorStore()
@@ -256,9 +238,9 @@ class EnhancedArxivRAG:
256
  papers = []
257
  for result in search.results():
258
  paper = Paper(
259
- id=result.entry_id.split("/")[-1],
260
- title=result.title.strip().replace("\n", " "),
261
- abstract=result.summary.strip().replace("\n", " "),
262
  authors=[author.name for author in result.authors],
263
  categories=result.categories,
264
  published=result.published.replace(tzinfo=None),
@@ -383,12 +365,12 @@ class EnhancedArxivRAG:
383
  bm25_scores = self.bm25_retriever.score(query, top_k * 2)
384
 
385
  for idx, score in bm25_scores:
386
- if idx < len(all_docs["ids"]):
387
  bm25_results.append({
388
- "id": all_docs["ids"][idx],
389
- "document": all_docs["documents"][idx],
390
- "metadata": all_docs["metadatas"][idx],
391
- "score": score
392
  })
393
 
394
  # Combine results using RRF
@@ -396,13 +378,13 @@ class EnhancedArxivRAG:
396
  bm25_weight = 1.0 - semantic_weight
397
 
398
  # Add semantic scores
399
- for i, doc_id in enumerate(semantic_results["ids"][0]):
400
  rank = i + 1
401
  combined_scores[doc_id] = combined_scores.get(doc_id, 0) + semantic_weight / rank
402
 
403
  # Add BM25 scores
404
  for i, result in enumerate(bm25_results):
405
- doc_id = result["id"]
406
  rank = i + 1
407
  combined_scores[doc_id] = combined_scores.get(doc_id, 0) + bm25_weight / rank
408
 
@@ -413,12 +395,12 @@ class EnhancedArxivRAG:
413
  final_results = []
414
  for doc_id, score in sorted_results[:top_k]:
415
  doc_result = self.vector_store.get(ids=[doc_id])
416
- if doc_result["ids"]:
417
  final_results.append({
418
- "id": doc_id,
419
- "document": doc_result["documents"][0],
420
- "metadata": doc_result["metadatas"][0],
421
- "combined_score": score
422
  })
423
 
424
  return final_results
@@ -429,17 +411,17 @@ class EnhancedArxivRAG:
429
  return results
430
 
431
  # Prepare query-document pairs
432
- query_doc_pairs = [(query, result["document"]) for result in results]
433
 
434
  # Get reranking scores
435
  rerank_scores = self.reranker.predict(query_doc_pairs)
436
 
437
  # Add rerank scores to results
438
  for i, result in enumerate(results):
439
- result["rerank_score"] = float(rerank_scores[i])
440
 
441
  # Sort by rerank score
442
- reranked_results = sorted(results, key=lambda x: x["rerank_score"], reverse=True)
443
  return reranked_results[:top_k]
444
 
445
  def generate_answer(self, query: str, context_chunks: List[Dict]) -> str:
@@ -448,7 +430,7 @@ class EnhancedArxivRAG:
448
  return "No relevant information found to answer your query."
449
 
450
  # Combine context from top chunks
451
- context_texts = [chunk["document"] for chunk in context_chunks[:3]]
452
  combined_context = "\n\n".join(context_texts)
453
 
454
  # Limit context length
@@ -461,12 +443,12 @@ class EnhancedArxivRAG:
461
  summary = self.summarizer(summary_input,
462
  max_length=120,
463
  min_length=30,
464
- do_sample=False)[0]["summary_text"]
465
  return summary
466
  except Exception as e:
467
  logger.error(f"Error generating summary: {e}")
468
- return f"Based on the retrieved papers about \'{query}\', here are the key findings:\n\n" + \
469
- "\n\n".join([chunk["document"][:150] + "..." for chunk in context_chunks[:2]])
470
 
471
  def search_and_answer(self, query: str, max_papers: int = 15,
472
  top_k_retrieval: int = 10, top_k_rerank: int = 5,
@@ -476,10 +458,10 @@ class EnhancedArxivRAG:
476
 
477
  if not query.strip():
478
  return {
479
- "answer": "Please enter a valid research query.",
480
- "papers": [],
481
- "retrieved_chunks": [],
482
- "search_stats": {"papers_found": 0, "chunks_retrieved": 0}
483
  }
484
 
485
  try:
@@ -488,10 +470,10 @@ class EnhancedArxivRAG:
488
 
489
  if not papers:
490
  return {
491
- "answer": "No papers found for your query. Please try different keywords.",
492
- "papers": [],
493
- "retrieved_chunks": [],
494
- "search_stats": {"papers_found": 0, "chunks_retrieved": 0}
495
  }
496
 
497
  # Process and store papers
@@ -509,36 +491,36 @@ class EnhancedArxivRAG:
509
  # Prepare unique papers
510
  unique_papers = {}
511
  for chunk in reranked_results:
512
- paper_id = chunk["id"].split("_")[0]
513
  if paper_id in self.papers_cache and paper_id not in unique_papers:
514
  paper = self.papers_cache[paper_id]
515
  unique_papers[paper_id] = {
516
- "title": paper.title,
517
- "authors": paper.authors,
518
- "abstract": paper.abstract,
519
- "url": paper.url,
520
- "categories": paper.categories,
521
- "published": paper.published.strftime("%Y-%m-%d")
522
  }
523
 
524
  return {
525
- "answer": answer,
526
- "papers": list(unique_papers.values()),
527
- "retrieved_chunks": reranked_results,
528
- "search_stats": {
529
- "papers_found": len(papers),
530
- "chunks_retrieved": len(reranked_results),
531
- "unique_papers_in_results": len(unique_papers)
532
  }
533
  }
534
 
535
  except Exception as e:
536
  logger.error(f"Error in search_and_answer: {e}")
537
  return {
538
- "answer": f"An error occurred while processing your query: {str(e)}",
539
- "papers": [],
540
- "retrieved_chunks": [],
541
- "search_stats": {"papers_found": 0, "chunks_retrieved": 0}
542
  }
543
 
544
  # Global RAG instance
@@ -557,7 +539,7 @@ def search_papers(query: str, max_papers: int = 15, top_k_retrieval: int = 10,
557
  """Main search function for Gradio interface"""
558
 
559
  if not query.strip():
560
- return "❌ Please enter a research topic or question.", "", pd.DataFrame()
561
 
562
  try:
563
  # Initialize RAG system
@@ -566,7 +548,7 @@ def search_papers(query: str, max_papers: int = 15, top_k_retrieval: int = 10,
566
  # Parse categories
567
  category_list = None
568
  if categories.strip():
569
- category_list = [cat.strip() for cat in categories.split(",") if cat.strip()]
570
 
571
  # Perform search
572
  result = rag.search_and_answer(
@@ -579,33 +561,33 @@ def search_papers(query: str, max_papers: int = 15, top_k_retrieval: int = 10,
579
  )
580
 
581
  # Format answer
582
- answer = f"## 🤖 AI-Generated Answer\n\n{result["answer"]}\n\n"
583
  answer += f"**Search Statistics:**\n"
584
- answer += f"- Papers found: {result["search_stats"]["papers_found"]}\n"
585
- answer += f"- Chunks retrieved: {result["search_stats"]["chunks_retrieved"]}\n"
586
- answer += f"- Unique papers in results: {result["search_stats"]["unique_papers_in_results"]}\n\n"
587
 
588
  # Format papers
589
  papers_md = "## 📚 Relevant Papers\n\n"
590
- for i, paper in enumerate(result["papers"], 1):
591
- papers_md += f"### {i}. {paper["title"]}\n\n"
592
- papers_md += f"**Authors:** {", ".join(paper["authors"][:3])}{"..." if len(paper["authors"]) > 3 else ""}\n\n"
593
- papers_md += f"**Categories:** {", ".join(paper["categories"])}\n\n"
594
- papers_md += f"**Published:** {paper["published"]}\n\n"
595
- papers_md += f"**Abstract:** {paper["abstract"][:250]}{"..." if len(paper["abstract"]) > 250 else ""}\n\n"
596
- papers_md += f"**URL:** [{paper["url"]}]({paper["url"]})\n\n"
597
  papers_md += "---\n\n"
598
 
599
  # Create papers dataframe
600
  papers_df = pd.DataFrame([
601
  {
602
- "Title": paper["title"][:50] + "..." if len(paper["title"]) > 50 else paper["title"],
603
- "Authors": ", ".join(paper["authors"][:2]) + ("..." if len(paper["authors"]) > 2 else ""),
604
- "Categories": ", ".join(paper["categories"][:2]),
605
- "Published": paper["published"],
606
- "URL": paper["url"]
607
  }
608
- for paper in result["papers"]
609
  ])
610
 
611
  return answer, papers_md, papers_df
@@ -718,6 +700,5 @@ def create_interface():
718
  # Launch interface
719
  if __name__ == "__main__":
720
  interface = create_interface()
721
- # Remove share=True for Hugging Face Spaces compatibility
722
  interface.launch()
723
 
 
30
 
31
  # Download required NLTK data
32
  try:
33
+ nltk.data.find('tokenizers/punkt')
34
  except LookupError:
35
+ nltk.download('punkt')
36
 
37
  try:
38
+ nltk.data.find('corpora/stopwords')
39
  except LookupError:
40
+ nltk.download('stopwords')
41
 
42
  # Setup logging
43
  logging.basicConfig(level=logging.INFO)
 
74
  self.avg_doc_length = 0
75
  self.stemmer = PorterStemmer()
76
  try:
77
+ self.stop_words = set(stopwords.words('english'))
78
  except:
79
  self.stop_words = set()
80
 
 
150
  def query(self, query_embedding: List[float], n_results: int = 10) -> Dict:
151
  """Query the vector store"""
152
  if not self.embeddings:
153
+ return {'ids': [[]], 'documents': [[]], 'metadatas': [[]]}
154
 
155
  # Calculate cosine similarities
156
  query_embedding = np.array(query_embedding)
 
167
  top_indices = np.argsort(similarities)[::-1][:n_results]
168
 
169
  return {
170
+ 'ids': [[self.ids[i] for i in top_indices]],
171
+ 'documents': [[self.documents[i] for i in top_indices]],
172
+ 'metadatas': [[self.metadatas[i] for i in top_indices]]
173
  }
174
 
175
  def get(self, ids: Optional[List[str]] = None) -> Dict:
176
  """Get documents by IDs or all documents"""
177
  if ids is None:
178
  return {
179
+ 'ids': self.ids,
180
+ 'documents': self.documents,
181
+ 'metadatas': self.metadatas
182
  }
183
  else:
184
  indices = [self.ids.index(id_) for id_ in ids if id_ in self.ids]
185
  return {
186
+ 'ids': [self.ids[i] for i in indices],
187
+ 'documents': [self.documents[i] for i in indices],
188
+ 'metadatas': [self.metadatas[i] for i in indices]
189
  }
190
 
191
  def clear(self):
 
201
  def __init__(self):
202
  logger.info("Initializing Enhanced ArXiv RAG System for HF Spaces...")
203
 
204
+ # Use smaller, faster models for HF Spaces
205
+ self.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
206
+ self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-2-v2') # Smaller reranker
207
+ self.summarizer = pipeline("summarization", model="facebook/bart-large-cnn",
208
+ device=0 if torch.cuda.is_available() else -1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
 
210
  # Use simple vector store instead of ChromaDB for HF Spaces
211
  self.vector_store = SimpleVectorStore()
 
238
  papers = []
239
  for result in search.results():
240
  paper = Paper(
241
+ id=result.entry_id.split('/')[-1],
242
+ title=result.title.strip().replace('\n', ' '),
243
+ abstract=result.summary.strip().replace('\n', ' '),
244
  authors=[author.name for author in result.authors],
245
  categories=result.categories,
246
  published=result.published.replace(tzinfo=None),
 
365
  bm25_scores = self.bm25_retriever.score(query, top_k * 2)
366
 
367
  for idx, score in bm25_scores:
368
+ if idx < len(all_docs['ids']):
369
  bm25_results.append({
370
+ 'id': all_docs['ids'][idx],
371
+ 'document': all_docs['documents'][idx],
372
+ 'metadata': all_docs['metadatas'][idx],
373
+ 'score': score
374
  })
375
 
376
  # Combine results using RRF
 
378
  bm25_weight = 1.0 - semantic_weight
379
 
380
  # Add semantic scores
381
+ for i, doc_id in enumerate(semantic_results['ids'][0]):
382
  rank = i + 1
383
  combined_scores[doc_id] = combined_scores.get(doc_id, 0) + semantic_weight / rank
384
 
385
  # Add BM25 scores
386
  for i, result in enumerate(bm25_results):
387
+ doc_id = result['id']
388
  rank = i + 1
389
  combined_scores[doc_id] = combined_scores.get(doc_id, 0) + bm25_weight / rank
390
 
 
395
  final_results = []
396
  for doc_id, score in sorted_results[:top_k]:
397
  doc_result = self.vector_store.get(ids=[doc_id])
398
+ if doc_result['ids']:
399
  final_results.append({
400
+ 'id': doc_id,
401
+ 'document': doc_result['documents'][0],
402
+ 'metadata': doc_result['metadatas'][0],
403
+ 'combined_score': score
404
  })
405
 
406
  return final_results
 
411
  return results
412
 
413
  # Prepare query-document pairs
414
+ query_doc_pairs = [(query, result['document']) for result in results]
415
 
416
  # Get reranking scores
417
  rerank_scores = self.reranker.predict(query_doc_pairs)
418
 
419
  # Add rerank scores to results
420
  for i, result in enumerate(results):
421
+ result['rerank_score'] = float(rerank_scores[i])
422
 
423
  # Sort by rerank score
424
+ reranked_results = sorted(results, key=lambda x: x['rerank_score'], reverse=True)
425
  return reranked_results[:top_k]
426
 
427
  def generate_answer(self, query: str, context_chunks: List[Dict]) -> str:
 
430
  return "No relevant information found to answer your query."
431
 
432
  # Combine context from top chunks
433
+ context_texts = [chunk['document'] for chunk in context_chunks[:3]]
434
  combined_context = "\n\n".join(context_texts)
435
 
436
  # Limit context length
 
443
  summary = self.summarizer(summary_input,
444
  max_length=120,
445
  min_length=30,
446
+ do_sample=False)[0]['summary_text']
447
  return summary
448
  except Exception as e:
449
  logger.error(f"Error generating summary: {e}")
450
+ return f"Based on the retrieved papers about '{query}', here are the key findings:\n\n" + \
451
+ "\n\n".join([chunk['document'][:150] + "..." for chunk in context_chunks[:2]])
452
 
453
  def search_and_answer(self, query: str, max_papers: int = 15,
454
  top_k_retrieval: int = 10, top_k_rerank: int = 5,
 
458
 
459
  if not query.strip():
460
  return {
461
+ 'answer': "Please enter a valid research query.",
462
+ 'papers': [],
463
+ 'retrieved_chunks': [],
464
+ 'search_stats': {'papers_found': 0, 'chunks_retrieved': 0}
465
  }
466
 
467
  try:
 
470
 
471
  if not papers:
472
  return {
473
+ 'answer': "No papers found for your query. Please try different keywords.",
474
+ 'papers': [],
475
+ 'retrieved_chunks': [],
476
+ 'search_stats': {'papers_found': 0, 'chunks_retrieved': 0}
477
  }
478
 
479
  # Process and store papers
 
491
  # Prepare unique papers
492
  unique_papers = {}
493
  for chunk in reranked_results:
494
+ paper_id = chunk['id'].split('_')[0]
495
  if paper_id in self.papers_cache and paper_id not in unique_papers:
496
  paper = self.papers_cache[paper_id]
497
  unique_papers[paper_id] = {
498
+ 'title': paper.title,
499
+ 'authors': paper.authors,
500
+ 'abstract': paper.abstract,
501
+ 'url': paper.url,
502
+ 'categories': paper.categories,
503
+ 'published': paper.published.strftime('%Y-%m-%d')
504
  }
505
 
506
  return {
507
+ 'answer': answer,
508
+ 'papers': list(unique_papers.values()),
509
+ 'retrieved_chunks': reranked_results,
510
+ 'search_stats': {
511
+ 'papers_found': len(papers),
512
+ 'chunks_retrieved': len(reranked_results),
513
+ 'unique_papers_in_results': len(unique_papers)
514
  }
515
  }
516
 
517
  except Exception as e:
518
  logger.error(f"Error in search_and_answer: {e}")
519
  return {
520
+ 'answer': f"An error occurred while processing your query: {str(e)}",
521
+ 'papers': [],
522
+ 'retrieved_chunks': [],
523
+ 'search_stats': {'papers_found': 0, 'chunks_retrieved': 0}
524
  }
525
 
526
  # Global RAG instance
 
539
  """Main search function for Gradio interface"""
540
 
541
  if not query.strip():
542
+ return "❌ Please enter a research topic or question.", "", ""
543
 
544
  try:
545
  # Initialize RAG system
 
548
  # Parse categories
549
  category_list = None
550
  if categories.strip():
551
+ category_list = [cat.strip() for cat in categories.split(',') if cat.strip()]
552
 
553
  # Perform search
554
  result = rag.search_and_answer(
 
561
  )
562
 
563
  # Format answer
564
+ answer = f"## 🤖 AI-Generated Answer\n\n{result['answer']}\n\n"
565
  answer += f"**Search Statistics:**\n"
566
+ answer += f"- Papers found: {result['search_stats']['papers_found']}\n"
567
+ answer += f"- Chunks retrieved: {result['search_stats']['chunks_retrieved']}\n"
568
+ answer += f"- Unique papers in results: {result['search_stats']['unique_papers_in_results']}\n\n"
569
 
570
  # Format papers
571
  papers_md = "## 📚 Relevant Papers\n\n"
572
+ for i, paper in enumerate(result['papers'], 1):
573
+ papers_md += f"### {i}. {paper['title']}\n\n"
574
+ papers_md += f"**Authors:** {', '.join(paper['authors'][:3])}{'...' if len(paper['authors']) > 3 else ''}\n\n"
575
+ papers_md += f"**Categories:** {', '.join(paper['categories'])}\n\n"
576
+ papers_md += f"**Published:** {paper['published']}\n\n"
577
+ papers_md += f"**Abstract:** {paper['abstract'][:250]}{'...' if len(paper['abstract']) > 250 else ''}\n\n"
578
+ papers_md += f"**URL:** [{paper['url']}]({paper['url']})\n\n"
579
  papers_md += "---\n\n"
580
 
581
  # Create papers dataframe
582
  papers_df = pd.DataFrame([
583
  {
584
+ 'Title': paper['title'][:50] + '...' if len(paper['title']) > 50 else paper['title'],
585
+ 'Authors': ', '.join(paper['authors'][:2]) + ('...' if len(paper['authors']) > 2 else ''),
586
+ 'Categories': ', '.join(paper['categories'][:2]),
587
+ 'Published': paper['published'],
588
+ 'URL': paper['url']
589
  }
590
+ for paper in result['papers']
591
  ])
592
 
593
  return answer, papers_md, papers_df
 
700
  # Launch interface
701
  if __name__ == "__main__":
702
  interface = create_interface()
 
703
  interface.launch()
704