Update app.py
Browse files
app.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
"""
|
2 |
-
Enhanced ArXiv RAG System - Hugging Face Spaces Compatible Version
|
3 |
"""
|
4 |
|
5 |
import os
|
@@ -13,6 +13,7 @@ from datetime import datetime, timedelta
|
|
13 |
import logging
|
14 |
import tempfile
|
15 |
import shutil
|
|
|
16 |
|
17 |
# Core ML libraries
|
18 |
import torch
|
@@ -201,11 +202,16 @@ class EnhancedArxivRAG:
|
|
201 |
def __init__(self):
|
202 |
logger.info("Initializing Enhanced ArXiv RAG System for HF Spaces...")
|
203 |
|
204 |
-
# Use
|
205 |
self.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
|
206 |
-
self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-2-v2')
|
207 |
-
|
208 |
-
|
|
|
|
|
|
|
|
|
|
|
209 |
|
210 |
# Use simple vector store instead of ChromaDB for HF Spaces
|
211 |
self.vector_store = SimpleVectorStore()
|
@@ -312,6 +318,7 @@ class EnhancedArxivRAG:
|
|
312 |
|
313 |
return chunks
|
314 |
|
|
|
315 |
def process_and_store(self, papers: List[Paper]):
|
316 |
"""Process papers and store in vector store"""
|
317 |
logger.info("Processing and storing papers...")
|
@@ -405,6 +412,7 @@ class EnhancedArxivRAG:
|
|
405 |
|
406 |
return final_results
|
407 |
|
|
|
408 |
def rerank_results(self, query: str, results: List[Dict], top_k: int = 5) -> List[Dict]:
|
409 |
"""Rerank results using cross-encoder"""
|
410 |
if not results:
|
@@ -450,6 +458,7 @@ class EnhancedArxivRAG:
|
|
450 |
return f"Based on the retrieved papers about '{query}', here are the key findings:\n\n" + \
|
451 |
"\n\n".join([chunk['document'][:150] + "..." for chunk in context_chunks[:2]])
|
452 |
|
|
|
453 |
def search_and_answer(self, query: str, max_papers: int = 15,
|
454 |
top_k_retrieval: int = 10, top_k_rerank: int = 5,
|
455 |
categories: Optional[List[str]] = None,
|
@@ -700,4 +709,4 @@ def create_interface():
|
|
700 |
# Launch interface
|
701 |
if __name__ == "__main__":
|
702 |
interface = create_interface()
|
703 |
-
interface.launch(
|
|
|
1 |
"""
|
2 |
+
Enhanced ArXiv RAG System - Hugging Face Spaces Compatible Version (Fixed)
|
3 |
"""
|
4 |
|
5 |
import os
|
|
|
13 |
import logging
|
14 |
import tempfile
|
15 |
import shutil
|
16 |
+
import spaces
|
17 |
|
18 |
# Core ML libraries
|
19 |
import torch
|
|
|
202 |
def __init__(self):
|
203 |
logger.info("Initializing Enhanced ArXiv RAG System for HF Spaces...")
|
204 |
|
205 |
+
# Use CPU-friendly models for HF Spaces
|
206 |
self.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
|
207 |
+
self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-2-v2')
|
208 |
+
|
209 |
+
# Initialize summarizer explicitly on CPU (device=-1 below)
|
210 |
+
self.summarizer = pipeline(
|
211 |
+
"summarization",
|
212 |
+
model="facebook/bart-large-cnn",
|
213 |
+
device=-1 # Force CPU usage
|
214 |
+
)
|
215 |
|
216 |
# Use simple vector store instead of ChromaDB for HF Spaces
|
217 |
self.vector_store = SimpleVectorStore()
|
|
|
318 |
|
319 |
return chunks
|
320 |
|
321 |
+
@spaces.GPU(duration=60) # GPU decorator for processing
|
322 |
def process_and_store(self, papers: List[Paper]):
|
323 |
"""Process papers and store in vector store"""
|
324 |
logger.info("Processing and storing papers...")
|
|
|
412 |
|
413 |
return final_results
|
414 |
|
415 |
+
@spaces.GPU(duration=30) # GPU decorator for reranking
|
416 |
def rerank_results(self, query: str, results: List[Dict], top_k: int = 5) -> List[Dict]:
|
417 |
"""Rerank results using cross-encoder"""
|
418 |
if not results:
|
|
|
458 |
return f"Based on the retrieved papers about '{query}', here are the key findings:\n\n" + \
|
459 |
"\n\n".join([chunk['document'][:150] + "..." for chunk in context_chunks[:2]])
|
460 |
|
461 |
+
@spaces.GPU(duration=120) # Main GPU decorator for the full pipeline
|
462 |
def search_and_answer(self, query: str, max_papers: int = 15,
|
463 |
top_k_retrieval: int = 10, top_k_rerank: int = 5,
|
464 |
categories: Optional[List[str]] = None,
|
|
|
709 |
# Launch interface
|
710 |
if __name__ == "__main__":
|
711 |
interface = create_interface()
|
712 |
+
interface.launch()
|