mihirinamdar committed on
Commit
d28aff5
·
verified ·
1 Parent(s): 0ee4114

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -6
app.py CHANGED
@@ -1,5 +1,5 @@
1
  """
2
- Enhanced ArXiv RAG System - Hugging Face Spaces Compatible Version
3
  """
4
 
5
  import os
@@ -13,6 +13,7 @@ from datetime import datetime, timedelta
13
  import logging
14
  import tempfile
15
  import shutil
 
16
 
17
  # Core ML libraries
18
  import torch
@@ -201,11 +202,16 @@ class EnhancedArxivRAG:
201
  def __init__(self):
202
  logger.info("Initializing Enhanced ArXiv RAG System for HF Spaces...")
203
 
204
- # Use smaller, faster models for HF Spaces
205
  self.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
206
- self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-2-v2') # Smaller reranker
207
- self.summarizer = pipeline("summarization", model="facebook/bart-large-cnn",
208
- device=0 if torch.cuda.is_available() else -1)
 
 
 
 
 
209
 
210
  # Use simple vector store instead of ChromaDB for HF Spaces
211
  self.vector_store = SimpleVectorStore()
@@ -312,6 +318,7 @@ class EnhancedArxivRAG:
312
 
313
  return chunks
314
 
 
315
  def process_and_store(self, papers: List[Paper]):
316
  """Process papers and store in vector store"""
317
  logger.info("Processing and storing papers...")
@@ -405,6 +412,7 @@ class EnhancedArxivRAG:
405
 
406
  return final_results
407
 
 
408
  def rerank_results(self, query: str, results: List[Dict], top_k: int = 5) -> List[Dict]:
409
  """Rerank results using cross-encoder"""
410
  if not results:
@@ -450,6 +458,7 @@ class EnhancedArxivRAG:
450
  return f"Based on the retrieved papers about '{query}', here are the key findings:\n\n" + \
451
  "\n\n".join([chunk['document'][:150] + "..." for chunk in context_chunks[:2]])
452
 
 
453
  def search_and_answer(self, query: str, max_papers: int = 15,
454
  top_k_retrieval: int = 10, top_k_rerank: int = 5,
455
  categories: Optional[List[str]] = None,
@@ -700,4 +709,4 @@ def create_interface():
700
  # Launch interface
701
  if __name__ == "__main__":
702
  interface = create_interface()
703
- interface.launch(share=True)
 
1
  """
2
+ Enhanced ArXiv RAG System - Hugging Face Spaces Compatible Version (Fixed)
3
  """
4
 
5
  import os
 
13
  import logging
14
  import tempfile
15
  import shutil
16
+ import spaces
17
 
18
  # Core ML libraries
19
  import torch
 
202
  def __init__(self):
203
  logger.info("Initializing Enhanced ArXiv RAG System for HF Spaces...")
204
 
205
+ # Use CPU-friendly models for HF Spaces
206
  self.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
207
+ self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-2-v2')
208
+
209
+ # Initialize summarizer without GPU specification
210
+ self.summarizer = pipeline(
211
+ "summarization",
212
+ model="facebook/bart-large-cnn",
213
+ device=-1 # Force CPU usage
214
+ )
215
 
216
  # Use simple vector store instead of ChromaDB for HF Spaces
217
  self.vector_store = SimpleVectorStore()
 
318
 
319
  return chunks
320
 
321
+ @spaces.GPU(duration=60) # GPU decorator for processing
322
  def process_and_store(self, papers: List[Paper]):
323
  """Process papers and store in vector store"""
324
  logger.info("Processing and storing papers...")
 
412
 
413
  return final_results
414
 
415
+ @spaces.GPU(duration=30) # GPU decorator for reranking
416
  def rerank_results(self, query: str, results: List[Dict], top_k: int = 5) -> List[Dict]:
417
  """Rerank results using cross-encoder"""
418
  if not results:
 
458
  return f"Based on the retrieved papers about '{query}', here are the key findings:\n\n" + \
459
  "\n\n".join([chunk['document'][:150] + "..." for chunk in context_chunks[:2]])
460
 
461
+ @spaces.GPU(duration=120) # Main GPU decorator for the full pipeline
462
  def search_and_answer(self, query: str, max_papers: int = 15,
463
  top_k_retrieval: int = 10, top_k_rerank: int = 5,
464
  categories: Optional[List[str]] = None,
 
709
  # Launch interface
710
  if __name__ == "__main__":
711
  interface = create_interface()
712
+ interface.launch()