mihirinamdar committed on
Commit
ab4f49f
·
verified ·
1 Parent(s): d28aff5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -98
app.py CHANGED
@@ -1,5 +1,5 @@
1
  """
2
- Enhanced ArXiv RAG System - Hugging Face Spaces Compatible Version (Fixed)
3
  """
4
 
5
  import os
@@ -13,7 +13,6 @@ from datetime import datetime, timedelta
13
  import logging
14
  import tempfile
15
  import shutil
16
- import spaces
17
 
18
  # Core ML libraries
19
  import torch
@@ -31,14 +30,14 @@ from nltk.stem import PorterStemmer
31
 
32
  # Download required NLTK data
33
  try:
34
- nltk.data.find('tokenizers/punkt')
35
  except LookupError:
36
- nltk.download('punkt')
37
 
38
  try:
39
- nltk.data.find('corpora/stopwords')
40
  except LookupError:
41
- nltk.download('stopwords')
42
 
43
  # Setup logging
44
  logging.basicConfig(level=logging.INFO)
@@ -75,7 +74,7 @@ class BM25Retriever:
75
  self.avg_doc_length = 0
76
  self.stemmer = PorterStemmer()
77
  try:
78
- self.stop_words = set(stopwords.words('english'))
79
  except:
80
  self.stop_words = set()
81
 
@@ -151,7 +150,7 @@ class SimpleVectorStore:
151
  def query(self, query_embedding: List[float], n_results: int = 10) -> Dict:
152
  """Query the vector store"""
153
  if not self.embeddings:
154
- return {'ids': [[]], 'documents': [[]], 'metadatas': [[]]}
155
 
156
  # Calculate cosine similarities
157
  query_embedding = np.array(query_embedding)
@@ -168,25 +167,25 @@ class SimpleVectorStore:
168
  top_indices = np.argsort(similarities)[::-1][:n_results]
169
 
170
  return {
171
- 'ids': [[self.ids[i] for i in top_indices]],
172
- 'documents': [[self.documents[i] for i in top_indices]],
173
- 'metadatas': [[self.metadatas[i] for i in top_indices]]
174
  }
175
 
176
  def get(self, ids: Optional[List[str]] = None) -> Dict:
177
  """Get documents by IDs or all documents"""
178
  if ids is None:
179
  return {
180
- 'ids': self.ids,
181
- 'documents': self.documents,
182
- 'metadatas': self.metadatas
183
  }
184
  else:
185
  indices = [self.ids.index(id_) for id_ in ids if id_ in self.ids]
186
  return {
187
- 'ids': [self.ids[i] for i in indices],
188
- 'documents': [self.documents[i] for i in indices],
189
- 'metadatas': [self.metadatas[i] for i in indices]
190
  }
191
 
192
  def clear(self):
@@ -202,16 +201,29 @@ class EnhancedArxivRAG:
202
  def __init__(self):
203
  logger.info("Initializing Enhanced ArXiv RAG System for HF Spaces...")
204
 
205
- # Use CPU-friendly models for HF Spaces
206
- self.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
207
- self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-2-v2')
208
-
209
- # Initialize summarizer without GPU specification
210
- self.summarizer = pipeline(
211
- "summarization",
212
- model="facebook/bart-large-cnn",
213
- device=-1 # Force CPU usage
214
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
215
 
216
  # Use simple vector store instead of ChromaDB for HF Spaces
217
  self.vector_store = SimpleVectorStore()
@@ -244,9 +256,9 @@ class EnhancedArxivRAG:
244
  papers = []
245
  for result in search.results():
246
  paper = Paper(
247
- id=result.entry_id.split('/')[-1],
248
- title=result.title.strip().replace('\n', ' '),
249
- abstract=result.summary.strip().replace('\n', ' '),
250
  authors=[author.name for author in result.authors],
251
  categories=result.categories,
252
  published=result.published.replace(tzinfo=None),
@@ -318,7 +330,6 @@ class EnhancedArxivRAG:
318
 
319
  return chunks
320
 
321
- @spaces.GPU(duration=60) # GPU decorator for processing
322
  def process_and_store(self, papers: List[Paper]):
323
  """Process papers and store in vector store"""
324
  logger.info("Processing and storing papers...")
@@ -372,12 +383,12 @@ class EnhancedArxivRAG:
372
  bm25_scores = self.bm25_retriever.score(query, top_k * 2)
373
 
374
  for idx, score in bm25_scores:
375
- if idx < len(all_docs['ids']):
376
  bm25_results.append({
377
- 'id': all_docs['ids'][idx],
378
- 'document': all_docs['documents'][idx],
379
- 'metadata': all_docs['metadatas'][idx],
380
- 'score': score
381
  })
382
 
383
  # Combine results using RRF
@@ -385,13 +396,13 @@ class EnhancedArxivRAG:
385
  bm25_weight = 1.0 - semantic_weight
386
 
387
  # Add semantic scores
388
- for i, doc_id in enumerate(semantic_results['ids'][0]):
389
  rank = i + 1
390
  combined_scores[doc_id] = combined_scores.get(doc_id, 0) + semantic_weight / rank
391
 
392
  # Add BM25 scores
393
  for i, result in enumerate(bm25_results):
394
- doc_id = result['id']
395
  rank = i + 1
396
  combined_scores[doc_id] = combined_scores.get(doc_id, 0) + bm25_weight / rank
397
 
@@ -402,34 +413,33 @@ class EnhancedArxivRAG:
402
  final_results = []
403
  for doc_id, score in sorted_results[:top_k]:
404
  doc_result = self.vector_store.get(ids=[doc_id])
405
- if doc_result['ids']:
406
  final_results.append({
407
- 'id': doc_id,
408
- 'document': doc_result['documents'][0],
409
- 'metadata': doc_result['metadatas'][0],
410
- 'combined_score': score
411
  })
412
 
413
  return final_results
414
 
415
- @spaces.GPU(duration=30) # GPU decorator for reranking
416
  def rerank_results(self, query: str, results: List[Dict], top_k: int = 5) -> List[Dict]:
417
  """Rerank results using cross-encoder"""
418
  if not results:
419
  return results
420
 
421
  # Prepare query-document pairs
422
- query_doc_pairs = [(query, result['document']) for result in results]
423
 
424
  # Get reranking scores
425
  rerank_scores = self.reranker.predict(query_doc_pairs)
426
 
427
  # Add rerank scores to results
428
  for i, result in enumerate(results):
429
- result['rerank_score'] = float(rerank_scores[i])
430
 
431
  # Sort by rerank score
432
- reranked_results = sorted(results, key=lambda x: x['rerank_score'], reverse=True)
433
  return reranked_results[:top_k]
434
 
435
  def generate_answer(self, query: str, context_chunks: List[Dict]) -> str:
@@ -438,7 +448,7 @@ class EnhancedArxivRAG:
438
  return "No relevant information found to answer your query."
439
 
440
  # Combine context from top chunks
441
- context_texts = [chunk['document'] for chunk in context_chunks[:3]]
442
  combined_context = "\n\n".join(context_texts)
443
 
444
  # Limit context length
@@ -451,14 +461,13 @@ class EnhancedArxivRAG:
451
  summary = self.summarizer(summary_input,
452
  max_length=120,
453
  min_length=30,
454
- do_sample=False)[0]['summary_text']
455
  return summary
456
  except Exception as e:
457
  logger.error(f"Error generating summary: {e}")
458
- return f"Based on the retrieved papers about '{query}', here are the key findings:\n\n" + \
459
- "\n\n".join([chunk['document'][:150] + "..." for chunk in context_chunks[:2]])
460
 
461
- @spaces.GPU(duration=120) # Main GPU decorator for the full pipeline
462
  def search_and_answer(self, query: str, max_papers: int = 15,
463
  top_k_retrieval: int = 10, top_k_rerank: int = 5,
464
  categories: Optional[List[str]] = None,
@@ -467,10 +476,10 @@ class EnhancedArxivRAG:
467
 
468
  if not query.strip():
469
  return {
470
- 'answer': "Please enter a valid research query.",
471
- 'papers': [],
472
- 'retrieved_chunks': [],
473
- 'search_stats': {'papers_found': 0, 'chunks_retrieved': 0}
474
  }
475
 
476
  try:
@@ -479,10 +488,10 @@ class EnhancedArxivRAG:
479
 
480
  if not papers:
481
  return {
482
- 'answer': "No papers found for your query. Please try different keywords.",
483
- 'papers': [],
484
- 'retrieved_chunks': [],
485
- 'search_stats': {'papers_found': 0, 'chunks_retrieved': 0}
486
  }
487
 
488
  # Process and store papers
@@ -500,36 +509,36 @@ class EnhancedArxivRAG:
500
  # Prepare unique papers
501
  unique_papers = {}
502
  for chunk in reranked_results:
503
- paper_id = chunk['id'].split('_')[0]
504
  if paper_id in self.papers_cache and paper_id not in unique_papers:
505
  paper = self.papers_cache[paper_id]
506
  unique_papers[paper_id] = {
507
- 'title': paper.title,
508
- 'authors': paper.authors,
509
- 'abstract': paper.abstract,
510
- 'url': paper.url,
511
- 'categories': paper.categories,
512
- 'published': paper.published.strftime('%Y-%m-%d')
513
  }
514
 
515
  return {
516
- 'answer': answer,
517
- 'papers': list(unique_papers.values()),
518
- 'retrieved_chunks': reranked_results,
519
- 'search_stats': {
520
- 'papers_found': len(papers),
521
- 'chunks_retrieved': len(reranked_results),
522
- 'unique_papers_in_results': len(unique_papers)
523
  }
524
  }
525
 
526
  except Exception as e:
527
  logger.error(f"Error in search_and_answer: {e}")
528
  return {
529
- 'answer': f"An error occurred while processing your query: {str(e)}",
530
- 'papers': [],
531
- 'retrieved_chunks': [],
532
- 'search_stats': {'papers_found': 0, 'chunks_retrieved': 0}
533
  }
534
 
535
  # Global RAG instance
@@ -548,7 +557,7 @@ def search_papers(query: str, max_papers: int = 15, top_k_retrieval: int = 10,
548
  """Main search function for Gradio interface"""
549
 
550
  if not query.strip():
551
- return "❌ Please enter a research topic or question.", "", ""
552
 
553
  try:
554
  # Initialize RAG system
@@ -557,7 +566,7 @@ def search_papers(query: str, max_papers: int = 15, top_k_retrieval: int = 10,
557
  # Parse categories
558
  category_list = None
559
  if categories.strip():
560
- category_list = [cat.strip() for cat in categories.split(',') if cat.strip()]
561
 
562
  # Perform search
563
  result = rag.search_and_answer(
@@ -570,33 +579,33 @@ def search_papers(query: str, max_papers: int = 15, top_k_retrieval: int = 10,
570
  )
571
 
572
  # Format answer
573
- answer = f"## 🤖 AI-Generated Answer\n\n{result['answer']}\n\n"
574
  answer += f"**Search Statistics:**\n"
575
- answer += f"- Papers found: {result['search_stats']['papers_found']}\n"
576
- answer += f"- Chunks retrieved: {result['search_stats']['chunks_retrieved']}\n"
577
- answer += f"- Unique papers in results: {result['search_stats']['unique_papers_in_results']}\n\n"
578
 
579
  # Format papers
580
  papers_md = "## 📚 Relevant Papers\n\n"
581
- for i, paper in enumerate(result['papers'], 1):
582
- papers_md += f"### {i}. {paper['title']}\n\n"
583
- papers_md += f"**Authors:** {', '.join(paper['authors'][:3])}{'...' if len(paper['authors']) > 3 else ''}\n\n"
584
- papers_md += f"**Categories:** {', '.join(paper['categories'])}\n\n"
585
- papers_md += f"**Published:** {paper['published']}\n\n"
586
- papers_md += f"**Abstract:** {paper['abstract'][:250]}{'...' if len(paper['abstract']) > 250 else ''}\n\n"
587
- papers_md += f"**URL:** [{paper['url']}]({paper['url']})\n\n"
588
  papers_md += "---\n\n"
589
 
590
  # Create papers dataframe
591
  papers_df = pd.DataFrame([
592
  {
593
- 'Title': paper['title'][:50] + '...' if len(paper['title']) > 50 else paper['title'],
594
- 'Authors': ', '.join(paper['authors'][:2]) + ('...' if len(paper['authors']) > 2 else ''),
595
- 'Categories': ', '.join(paper['categories'][:2]),
596
- 'Published': paper['published'],
597
- 'URL': paper['url']
598
  }
599
- for paper in result['papers']
600
  ])
601
 
602
  return answer, papers_md, papers_df
@@ -709,4 +718,6 @@ def create_interface():
709
  # Launch interface
710
  if __name__ == "__main__":
711
  interface = create_interface()
712
- interface.launch()
 
 
 
1
  """
2
+ Enhanced ArXiv RAG System - Hugging Face Spaces Compatible Version
3
  """
4
 
5
  import os
 
13
  import logging
14
  import tempfile
15
  import shutil
 
16
 
17
  # Core ML libraries
18
  import torch
 
30
 
31
  # Download required NLTK data
32
  try:
33
+ nltk.data.find("tokenizers/punkt")
34
  except LookupError:
35
+ nltk.download("punkt")
36
 
37
  try:
38
+ nltk.data.find("corpora/stopwords")
39
  except LookupError:
40
+ nltk.download("stopwords")
41
 
42
  # Setup logging
43
  logging.basicConfig(level=logging.INFO)
 
74
  self.avg_doc_length = 0
75
  self.stemmer = PorterStemmer()
76
  try:
77
+ self.stop_words = set(stopwords.words("english"))
78
  except:
79
  self.stop_words = set()
80
 
 
150
  def query(self, query_embedding: List[float], n_results: int = 10) -> Dict:
151
  """Query the vector store"""
152
  if not self.embeddings:
153
+ return {"ids": [[]], "documents": [[]], "metadatas": [[]]}
154
 
155
  # Calculate cosine similarities
156
  query_embedding = np.array(query_embedding)
 
167
  top_indices = np.argsort(similarities)[::-1][:n_results]
168
 
169
  return {
170
+ "ids": [[self.ids[i] for i in top_indices]],
171
+ "documents": [[self.documents[i] for i in top_indices]],
172
+ "metadatas": [[self.metadatas[i] for i in top_indices]]
173
  }
174
 
175
  def get(self, ids: Optional[List[str]] = None) -> Dict:
176
  """Get documents by IDs or all documents"""
177
  if ids is None:
178
  return {
179
+ "ids": self.ids,
180
+ "documents": self.documents,
181
+ "metadatas": self.metadatas
182
  }
183
  else:
184
  indices = [self.ids.index(id_) for id_ in ids if id_ in self.ids]
185
  return {
186
+ "ids": [self.ids[i] for i in indices],
187
+ "documents": [self.documents[i] for i in indices],
188
+ "metadatas": [self.metadatas[i] for i in indices]
189
  }
190
 
191
  def clear(self):
 
201
  def __init__(self):
202
  logger.info("Initializing Enhanced ArXiv RAG System for HF Spaces...")
203
 
204
+ # Determine device (GPU if available, else CPU)
205
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
206
+ logger.info(f"Using device: {self.device}")
207
+
208
+ # Load models with appropriate device settings
209
+ try:
210
+ logger.info("Loading embedding model...")
211
+ self.embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device=self.device)
212
+ logger.info("Embedding model loaded.")
213
+
214
+ logger.info("Loading reranker model...")
215
+ self.reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-2-v2", device=self.device)
216
+ logger.info("Reranker model loaded.")
217
+
218
+ logger.info("Loading summarizer model...")
219
+ # For pipeline, device_map="auto" is often better for ZeroGPU
220
+ # If issues persist, try device=0 for the first GPU, or device=self.device
221
+ self.summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-6-6", device_map="auto")
222
+ logger.info("Summarizer model loaded.")
223
+
224
+ except Exception as e:
225
+ logger.error(f"Error loading models: {e}")
226
+ raise
227
 
228
  # Use simple vector store instead of ChromaDB for HF Spaces
229
  self.vector_store = SimpleVectorStore()
 
256
  papers = []
257
  for result in search.results():
258
  paper = Paper(
259
+ id=result.entry_id.split("/")[-1],
260
+ title=result.title.strip().replace("\n", " "),
261
+ abstract=result.summary.strip().replace("\n", " "),
262
  authors=[author.name for author in result.authors],
263
  categories=result.categories,
264
  published=result.published.replace(tzinfo=None),
 
330
 
331
  return chunks
332
 
 
333
  def process_and_store(self, papers: List[Paper]):
334
  """Process papers and store in vector store"""
335
  logger.info("Processing and storing papers...")
 
383
  bm25_scores = self.bm25_retriever.score(query, top_k * 2)
384
 
385
  for idx, score in bm25_scores:
386
+ if idx < len(all_docs["ids"]):
387
  bm25_results.append({
388
+ "id": all_docs["ids"][idx],
389
+ "document": all_docs["documents"][idx],
390
+ "metadata": all_docs["metadatas"][idx],
391
+ "score": score
392
  })
393
 
394
  # Combine results using RRF
 
396
  bm25_weight = 1.0 - semantic_weight
397
 
398
  # Add semantic scores
399
+ for i, doc_id in enumerate(semantic_results["ids"][0]):
400
  rank = i + 1
401
  combined_scores[doc_id] = combined_scores.get(doc_id, 0) + semantic_weight / rank
402
 
403
  # Add BM25 scores
404
  for i, result in enumerate(bm25_results):
405
+ doc_id = result["id"]
406
  rank = i + 1
407
  combined_scores[doc_id] = combined_scores.get(doc_id, 0) + bm25_weight / rank
408
 
 
413
  final_results = []
414
  for doc_id, score in sorted_results[:top_k]:
415
  doc_result = self.vector_store.get(ids=[doc_id])
416
+ if doc_result["ids"]:
417
  final_results.append({
418
+ "id": doc_id,
419
+ "document": doc_result["documents"][0],
420
+ "metadata": doc_result["metadatas"][0],
421
+ "combined_score": score
422
  })
423
 
424
  return final_results
425
 
 
426
  def rerank_results(self, query: str, results: List[Dict], top_k: int = 5) -> List[Dict]:
427
  """Rerank results using cross-encoder"""
428
  if not results:
429
  return results
430
 
431
  # Prepare query-document pairs
432
+ query_doc_pairs = [(query, result["document"]) for result in results]
433
 
434
  # Get reranking scores
435
  rerank_scores = self.reranker.predict(query_doc_pairs)
436
 
437
  # Add rerank scores to results
438
  for i, result in enumerate(results):
439
+ result["rerank_score"] = float(rerank_scores[i])
440
 
441
  # Sort by rerank score
442
+ reranked_results = sorted(results, key=lambda x: x["rerank_score"], reverse=True)
443
  return reranked_results[:top_k]
444
 
445
  def generate_answer(self, query: str, context_chunks: List[Dict]) -> str:
 
448
  return "No relevant information found to answer your query."
449
 
450
  # Combine context from top chunks
451
+ context_texts = [chunk["document"] for chunk in context_chunks[:3]]
452
  combined_context = "\n\n".join(context_texts)
453
 
454
  # Limit context length
 
461
  summary = self.summarizer(summary_input,
462
  max_length=120,
463
  min_length=30,
464
+ do_sample=False)[0]["summary_text"]
465
  return summary
466
  except Exception as e:
467
  logger.error(f"Error generating summary: {e}")
468
+ return f"Based on the retrieved papers about \'{query}\', here are the key findings:\n\n" + \
469
+ "\n\n".join([chunk["document"][:150] + "..." for chunk in context_chunks[:2]])
470
 
 
471
  def search_and_answer(self, query: str, max_papers: int = 15,
472
  top_k_retrieval: int = 10, top_k_rerank: int = 5,
473
  categories: Optional[List[str]] = None,
 
476
 
477
  if not query.strip():
478
  return {
479
+ "answer": "Please enter a valid research query.",
480
+ "papers": [],
481
+ "retrieved_chunks": [],
482
+ "search_stats": {"papers_found": 0, "chunks_retrieved": 0}
483
  }
484
 
485
  try:
 
488
 
489
  if not papers:
490
  return {
491
+ "answer": "No papers found for your query. Please try different keywords.",
492
+ "papers": [],
493
+ "retrieved_chunks": [],
494
+ "search_stats": {"papers_found": 0, "chunks_retrieved": 0}
495
  }
496
 
497
  # Process and store papers
 
509
  # Prepare unique papers
510
  unique_papers = {}
511
  for chunk in reranked_results:
512
+ paper_id = chunk["id"].split("_")[0]
513
  if paper_id in self.papers_cache and paper_id not in unique_papers:
514
  paper = self.papers_cache[paper_id]
515
  unique_papers[paper_id] = {
516
+ "title": paper.title,
517
+ "authors": paper.authors,
518
+ "abstract": paper.abstract,
519
+ "url": paper.url,
520
+ "categories": paper.categories,
521
+ "published": paper.published.strftime("%Y-%m-%d")
522
  }
523
 
524
  return {
525
+ "answer": answer,
526
+ "papers": list(unique_papers.values()),
527
+ "retrieved_chunks": reranked_results,
528
+ "search_stats": {
529
+ "papers_found": len(papers),
530
+ "chunks_retrieved": len(reranked_results),
531
+ "unique_papers_in_results": len(unique_papers)
532
  }
533
  }
534
 
535
  except Exception as e:
536
  logger.error(f"Error in search_and_answer: {e}")
537
  return {
538
+ "answer": f"An error occurred while processing your query: {str(e)}",
539
+ "papers": [],
540
+ "retrieved_chunks": [],
541
+ "search_stats": {"papers_found": 0, "chunks_retrieved": 0}
542
  }
543
 
544
  # Global RAG instance
 
557
  """Main search function for Gradio interface"""
558
 
559
  if not query.strip():
560
+ return "❌ Please enter a research topic or question.", "", pd.DataFrame()
561
 
562
  try:
563
  # Initialize RAG system
 
566
  # Parse categories
567
  category_list = None
568
  if categories.strip():
569
+ category_list = [cat.strip() for cat in categories.split(",") if cat.strip()]
570
 
571
  # Perform search
572
  result = rag.search_and_answer(
 
579
  )
580
 
581
  # Format answer
582
+ answer = f"## 🤖 AI-Generated Answer\n\n{result["answer"]}\n\n"
583
  answer += f"**Search Statistics:**\n"
584
+ answer += f"- Papers found: {result["search_stats"]["papers_found"]}\n"
585
+ answer += f"- Chunks retrieved: {result["search_stats"]["chunks_retrieved"]}\n"
586
+ answer += f"- Unique papers in results: {result["search_stats"]["unique_papers_in_results"]}\n\n"
587
 
588
  # Format papers
589
  papers_md = "## 📚 Relevant Papers\n\n"
590
+ for i, paper in enumerate(result["papers"], 1):
591
+ papers_md += f"### {i}. {paper["title"]}\n\n"
592
+ papers_md += f"**Authors:** {", ".join(paper["authors"][:3])}{"..." if len(paper["authors"]) > 3 else ""}\n\n"
593
+ papers_md += f"**Categories:** {", ".join(paper["categories"])}\n\n"
594
+ papers_md += f"**Published:** {paper["published"]}\n\n"
595
+ papers_md += f"**Abstract:** {paper["abstract"][:250]}{"..." if len(paper["abstract"]) > 250 else ""}\n\n"
596
+ papers_md += f"**URL:** [{paper["url"]}]({paper["url"]})\n\n"
597
  papers_md += "---\n\n"
598
 
599
  # Create papers dataframe
600
  papers_df = pd.DataFrame([
601
  {
602
+ "Title": paper["title"][:50] + "..." if len(paper["title"]) > 50 else paper["title"],
603
+ "Authors": ", ".join(paper["authors"][:2]) + ("..." if len(paper["authors"]) > 2 else ""),
604
+ "Categories": ", ".join(paper["categories"][:2]),
605
+ "Published": paper["published"],
606
+ "URL": paper["url"]
607
  }
608
+ for paper in result["papers"]
609
  ])
610
 
611
  return answer, papers_md, papers_df
 
718
  # Launch interface
719
  if __name__ == "__main__":
720
  interface = create_interface()
721
+ # share=True is intentionally not passed, for Hugging Face Spaces compatibility
722
+ interface.launch()
723
+