Update app.py
Browse files
app.py
CHANGED
@@ -30,14 +30,14 @@ from nltk.stem import PorterStemmer
|
|
30 |
|
31 |
# Download required NLTK data
|
32 |
try:
|
33 |
-
nltk.data.find(
|
34 |
except LookupError:
|
35 |
-
nltk.download(
|
36 |
|
37 |
try:
|
38 |
-
nltk.data.find(
|
39 |
except LookupError:
|
40 |
-
nltk.download(
|
41 |
|
42 |
# Setup logging
|
43 |
logging.basicConfig(level=logging.INFO)
|
@@ -74,7 +74,7 @@ class BM25Retriever:
|
|
74 |
self.avg_doc_length = 0
|
75 |
self.stemmer = PorterStemmer()
|
76 |
try:
|
77 |
-
self.stop_words = set(stopwords.words(
|
78 |
except:
|
79 |
self.stop_words = set()
|
80 |
|
@@ -150,7 +150,7 @@ class SimpleVectorStore:
|
|
150 |
def query(self, query_embedding: List[float], n_results: int = 10) -> Dict:
|
151 |
"""Query the vector store"""
|
152 |
if not self.embeddings:
|
153 |
-
return {
|
154 |
|
155 |
# Calculate cosine similarities
|
156 |
query_embedding = np.array(query_embedding)
|
@@ -167,25 +167,25 @@ class SimpleVectorStore:
|
|
167 |
top_indices = np.argsort(similarities)[::-1][:n_results]
|
168 |
|
169 |
return {
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
}
|
174 |
|
175 |
def get(self, ids: Optional[List[str]] = None) -> Dict:
|
176 |
"""Get documents by IDs or all documents"""
|
177 |
if ids is None:
|
178 |
return {
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
}
|
183 |
else:
|
184 |
indices = [self.ids.index(id_) for id_ in ids if id_ in self.ids]
|
185 |
return {
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
}
|
190 |
|
191 |
def clear(self):
|
@@ -201,29 +201,11 @@ class EnhancedArxivRAG:
|
|
201 |
def __init__(self):
|
202 |
logger.info("Initializing Enhanced ArXiv RAG System for HF Spaces...")
|
203 |
|
204 |
-
#
|
205 |
-
self.
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
try:
|
210 |
-
logger.info("Loading embedding model...")
|
211 |
-
self.embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device=self.device)
|
212 |
-
logger.info("Embedding model loaded.")
|
213 |
-
|
214 |
-
logger.info("Loading reranker model...")
|
215 |
-
self.reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-2-v2", device=self.device)
|
216 |
-
logger.info("Reranker model loaded.")
|
217 |
-
|
218 |
-
logger.info("Loading summarizer model...")
|
219 |
-
# For pipeline, device_map="auto" is often better for ZeroGPU
|
220 |
-
# If issues persist, try device=0 for the first GPU, or device=self.device
|
221 |
-
self.summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-6-6", device_map="auto")
|
222 |
-
logger.info("Summarizer model loaded.")
|
223 |
-
|
224 |
-
except Exception as e:
|
225 |
-
logger.error(f"Error loading models: {e}")
|
226 |
-
raise
|
227 |
|
228 |
# Use simple vector store instead of ChromaDB for HF Spaces
|
229 |
self.vector_store = SimpleVectorStore()
|
@@ -256,9 +238,9 @@ class EnhancedArxivRAG:
|
|
256 |
papers = []
|
257 |
for result in search.results():
|
258 |
paper = Paper(
|
259 |
-
id=result.entry_id.split(
|
260 |
-
title=result.title.strip().replace(
|
261 |
-
abstract=result.summary.strip().replace(
|
262 |
authors=[author.name for author in result.authors],
|
263 |
categories=result.categories,
|
264 |
published=result.published.replace(tzinfo=None),
|
@@ -383,12 +365,12 @@ class EnhancedArxivRAG:
|
|
383 |
bm25_scores = self.bm25_retriever.score(query, top_k * 2)
|
384 |
|
385 |
for idx, score in bm25_scores:
|
386 |
-
if idx < len(all_docs[
|
387 |
bm25_results.append({
|
388 |
-
|
389 |
-
|
390 |
-
|
391 |
-
|
392 |
})
|
393 |
|
394 |
# Combine results using RRF
|
@@ -396,13 +378,13 @@ class EnhancedArxivRAG:
|
|
396 |
bm25_weight = 1.0 - semantic_weight
|
397 |
|
398 |
# Add semantic scores
|
399 |
-
for i, doc_id in enumerate(semantic_results[
|
400 |
rank = i + 1
|
401 |
combined_scores[doc_id] = combined_scores.get(doc_id, 0) + semantic_weight / rank
|
402 |
|
403 |
# Add BM25 scores
|
404 |
for i, result in enumerate(bm25_results):
|
405 |
-
doc_id = result[
|
406 |
rank = i + 1
|
407 |
combined_scores[doc_id] = combined_scores.get(doc_id, 0) + bm25_weight / rank
|
408 |
|
@@ -413,12 +395,12 @@ class EnhancedArxivRAG:
|
|
413 |
final_results = []
|
414 |
for doc_id, score in sorted_results[:top_k]:
|
415 |
doc_result = self.vector_store.get(ids=[doc_id])
|
416 |
-
if doc_result[
|
417 |
final_results.append({
|
418 |
-
|
419 |
-
|
420 |
-
|
421 |
-
|
422 |
})
|
423 |
|
424 |
return final_results
|
@@ -429,17 +411,17 @@ class EnhancedArxivRAG:
|
|
429 |
return results
|
430 |
|
431 |
# Prepare query-document pairs
|
432 |
-
query_doc_pairs = [(query, result[
|
433 |
|
434 |
# Get reranking scores
|
435 |
rerank_scores = self.reranker.predict(query_doc_pairs)
|
436 |
|
437 |
# Add rerank scores to results
|
438 |
for i, result in enumerate(results):
|
439 |
-
result[
|
440 |
|
441 |
# Sort by rerank score
|
442 |
-
reranked_results = sorted(results, key=lambda x: x[
|
443 |
return reranked_results[:top_k]
|
444 |
|
445 |
def generate_answer(self, query: str, context_chunks: List[Dict]) -> str:
|
@@ -448,7 +430,7 @@ class EnhancedArxivRAG:
|
|
448 |
return "No relevant information found to answer your query."
|
449 |
|
450 |
# Combine context from top chunks
|
451 |
-
context_texts = [chunk[
|
452 |
combined_context = "\n\n".join(context_texts)
|
453 |
|
454 |
# Limit context length
|
@@ -461,12 +443,12 @@ class EnhancedArxivRAG:
|
|
461 |
summary = self.summarizer(summary_input,
|
462 |
max_length=120,
|
463 |
min_length=30,
|
464 |
-
do_sample=False)[0][
|
465 |
return summary
|
466 |
except Exception as e:
|
467 |
logger.error(f"Error generating summary: {e}")
|
468 |
-
return f"Based on the retrieved papers about
|
469 |
-
"\n\n".join([chunk[
|
470 |
|
471 |
def search_and_answer(self, query: str, max_papers: int = 15,
|
472 |
top_k_retrieval: int = 10, top_k_rerank: int = 5,
|
@@ -476,10 +458,10 @@ class EnhancedArxivRAG:
|
|
476 |
|
477 |
if not query.strip():
|
478 |
return {
|
479 |
-
|
480 |
-
|
481 |
-
|
482 |
-
|
483 |
}
|
484 |
|
485 |
try:
|
@@ -488,10 +470,10 @@ class EnhancedArxivRAG:
|
|
488 |
|
489 |
if not papers:
|
490 |
return {
|
491 |
-
|
492 |
-
|
493 |
-
|
494 |
-
|
495 |
}
|
496 |
|
497 |
# Process and store papers
|
@@ -509,36 +491,36 @@ class EnhancedArxivRAG:
|
|
509 |
# Prepare unique papers
|
510 |
unique_papers = {}
|
511 |
for chunk in reranked_results:
|
512 |
-
paper_id = chunk[
|
513 |
if paper_id in self.papers_cache and paper_id not in unique_papers:
|
514 |
paper = self.papers_cache[paper_id]
|
515 |
unique_papers[paper_id] = {
|
516 |
-
|
517 |
-
|
518 |
-
|
519 |
-
|
520 |
-
|
521 |
-
|
522 |
}
|
523 |
|
524 |
return {
|
525 |
-
|
526 |
-
|
527 |
-
|
528 |
-
|
529 |
-
|
530 |
-
|
531 |
-
|
532 |
}
|
533 |
}
|
534 |
|
535 |
except Exception as e:
|
536 |
logger.error(f"Error in search_and_answer: {e}")
|
537 |
return {
|
538 |
-
|
539 |
-
|
540 |
-
|
541 |
-
|
542 |
}
|
543 |
|
544 |
# Global RAG instance
|
@@ -557,7 +539,7 @@ def search_papers(query: str, max_papers: int = 15, top_k_retrieval: int = 10,
|
|
557 |
"""Main search function for Gradio interface"""
|
558 |
|
559 |
if not query.strip():
|
560 |
-
return "❌ Please enter a research topic or question.", "",
|
561 |
|
562 |
try:
|
563 |
# Initialize RAG system
|
@@ -566,7 +548,7 @@ def search_papers(query: str, max_papers: int = 15, top_k_retrieval: int = 10,
|
|
566 |
# Parse categories
|
567 |
category_list = None
|
568 |
if categories.strip():
|
569 |
-
category_list = [cat.strip() for cat in categories.split(
|
570 |
|
571 |
# Perform search
|
572 |
result = rag.search_and_answer(
|
@@ -579,33 +561,33 @@ def search_papers(query: str, max_papers: int = 15, top_k_retrieval: int = 10,
|
|
579 |
)
|
580 |
|
581 |
# Format answer
|
582 |
-
answer = f"## 🤖 AI-Generated Answer\n\n{result[
|
583 |
answer += f"**Search Statistics:**\n"
|
584 |
-
answer += f"- Papers found: {result[
|
585 |
-
answer += f"- Chunks retrieved: {result[
|
586 |
-
answer += f"- Unique papers in results: {result[
|
587 |
|
588 |
# Format papers
|
589 |
papers_md = "## 📚 Relevant Papers\n\n"
|
590 |
-
for i, paper in enumerate(result[
|
591 |
-
papers_md += f"### {i}. {paper[
|
592 |
-
papers_md += f"**Authors:** {
|
593 |
-
papers_md += f"**Categories:** {
|
594 |
-
papers_md += f"**Published:** {paper[
|
595 |
-
papers_md += f"**Abstract:** {paper[
|
596 |
-
papers_md += f"**URL:** [{paper[
|
597 |
papers_md += "---\n\n"
|
598 |
|
599 |
# Create papers dataframe
|
600 |
papers_df = pd.DataFrame([
|
601 |
{
|
602 |
-
|
603 |
-
|
604 |
-
|
605 |
-
|
606 |
-
|
607 |
}
|
608 |
-
for paper in result[
|
609 |
])
|
610 |
|
611 |
return answer, papers_md, papers_df
|
@@ -718,6 +700,5 @@ def create_interface():
|
|
718 |
# Launch interface
|
719 |
if __name__ == "__main__":
|
720 |
interface = create_interface()
|
721 |
-
# Remove share=True for Hugging Face Spaces compatibility
|
722 |
interface.launch()
|
723 |
|
|
|
30 |
|
31 |
# Make sure the punkt tokenizer data and the stopwords corpus are available,
# downloading each one only when the local lookup fails.
for resource_path, package_name in (
    ('tokenizers/punkt', 'punkt'),
    ('corpora/stopwords', 'stopwords'),
):
    try:
        nltk.data.find(resource_path)
    except LookupError:
        nltk.download(package_name)

# Setup logging
logging.basicConfig(level=logging.INFO)
|
|
|
74 |
self.avg_doc_length = 0
|
75 |
self.stemmer = PorterStemmer()
|
76 |
try:
|
77 |
+
self.stop_words = set(stopwords.words('english'))
|
78 |
except:
|
79 |
self.stop_words = set()
|
80 |
|
|
|
150 |
def query(self, query_embedding: List[float], n_results: int = 10) -> Dict:
|
151 |
"""Query the vector store"""
|
152 |
if not self.embeddings:
|
153 |
+
return {'ids': [[]], 'documents': [[]], 'metadatas': [[]]}
|
154 |
|
155 |
# Calculate cosine similarities
|
156 |
query_embedding = np.array(query_embedding)
|
|
|
167 |
top_indices = np.argsort(similarities)[::-1][:n_results]
|
168 |
|
169 |
return {
|
170 |
+
'ids': [[self.ids[i] for i in top_indices]],
|
171 |
+
'documents': [[self.documents[i] for i in top_indices]],
|
172 |
+
'metadatas': [[self.metadatas[i] for i in top_indices]]
|
173 |
}
|
174 |
|
175 |
def get(self, ids: Optional[List[str]] = None) -> Dict:
    """Get documents by IDs or all documents.

    Args:
        ids: IDs to look up. ``None`` (the default) selects everything
            in the store.

    Returns:
        A dict with the keys ``'ids'``, ``'documents'`` and ``'metadatas'``.
        With ``ids=None`` the values are the store's own lists (not
        copies). Otherwise they are new lists holding the matches in the
        order the IDs were requested; IDs that are not stored are skipped.
    """
    if ids is None:
        return {
            'ids': self.ids,
            'documents': self.documents,
            'metadatas': self.metadatas,
        }

    # Index the stored IDs once so each requested ID is a dict lookup rather
    # than a membership scan followed by a second scan in list.index().
    # setdefault keeps the first position of a repeated ID, matching
    # what list.index() would have returned.
    first_position = {}
    for position, stored_id in enumerate(self.ids):
        first_position.setdefault(stored_id, position)

    indices = [first_position[id_] for id_ in ids if id_ in first_position]
    return {
        'ids': [self.ids[i] for i in indices],
        'documents': [self.documents[i] for i in indices],
        'metadatas': [self.metadatas[i] for i in indices],
    }
|
190 |
|
191 |
def clear(self):
|
|
|
201 |
def __init__(self):
|
202 |
logger.info("Initializing Enhanced ArXiv RAG System for HF Spaces...")
|
203 |
|
204 |
+
# Use smaller, faster models for HF Spaces
|
205 |
+
self.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
|
206 |
+
self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-2-v2') # Smaller reranker
|
207 |
+
self.summarizer = pipeline("summarization", model="facebook/bart-large-cnn",
|
208 |
+
device=0 if torch.cuda.is_available() else -1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
209 |
|
210 |
# Use simple vector store instead of ChromaDB for HF Spaces
|
211 |
self.vector_store = SimpleVectorStore()
|
|
|
238 |
papers = []
|
239 |
for result in search.results():
|
240 |
paper = Paper(
|
241 |
+
id=result.entry_id.split('/')[-1],
|
242 |
+
title=result.title.strip().replace('\n', ' '),
|
243 |
+
abstract=result.summary.strip().replace('\n', ' '),
|
244 |
authors=[author.name for author in result.authors],
|
245 |
categories=result.categories,
|
246 |
published=result.published.replace(tzinfo=None),
|
|
|
365 |
bm25_scores = self.bm25_retriever.score(query, top_k * 2)
|
366 |
|
367 |
for idx, score in bm25_scores:
|
368 |
+
if idx < len(all_docs['ids']):
|
369 |
bm25_results.append({
|
370 |
+
'id': all_docs['ids'][idx],
|
371 |
+
'document': all_docs['documents'][idx],
|
372 |
+
'metadata': all_docs['metadatas'][idx],
|
373 |
+
'score': score
|
374 |
})
|
375 |
|
376 |
# Combine results using RRF
|
|
|
378 |
bm25_weight = 1.0 - semantic_weight
|
379 |
|
380 |
# Add semantic scores
|
381 |
+
for i, doc_id in enumerate(semantic_results['ids'][0]):
|
382 |
rank = i + 1
|
383 |
combined_scores[doc_id] = combined_scores.get(doc_id, 0) + semantic_weight / rank
|
384 |
|
385 |
# Add BM25 scores
|
386 |
for i, result in enumerate(bm25_results):
|
387 |
+
doc_id = result['id']
|
388 |
rank = i + 1
|
389 |
combined_scores[doc_id] = combined_scores.get(doc_id, 0) + bm25_weight / rank
|
390 |
|
|
|
395 |
final_results = []
|
396 |
for doc_id, score in sorted_results[:top_k]:
|
397 |
doc_result = self.vector_store.get(ids=[doc_id])
|
398 |
+
if doc_result['ids']:
|
399 |
final_results.append({
|
400 |
+
'id': doc_id,
|
401 |
+
'document': doc_result['documents'][0],
|
402 |
+
'metadata': doc_result['metadatas'][0],
|
403 |
+
'combined_score': score
|
404 |
})
|
405 |
|
406 |
return final_results
|
|
|
411 |
return results
|
412 |
|
413 |
# Prepare query-document pairs
|
414 |
+
query_doc_pairs = [(query, result['document']) for result in results]
|
415 |
|
416 |
# Get reranking scores
|
417 |
rerank_scores = self.reranker.predict(query_doc_pairs)
|
418 |
|
419 |
# Add rerank scores to results
|
420 |
for i, result in enumerate(results):
|
421 |
+
result['rerank_score'] = float(rerank_scores[i])
|
422 |
|
423 |
# Sort by rerank score
|
424 |
+
reranked_results = sorted(results, key=lambda x: x['rerank_score'], reverse=True)
|
425 |
return reranked_results[:top_k]
|
426 |
|
427 |
def generate_answer(self, query: str, context_chunks: List[Dict]) -> str:
|
|
|
430 |
return "No relevant information found to answer your query."
|
431 |
|
432 |
# Combine context from top chunks
|
433 |
+
context_texts = [chunk['document'] for chunk in context_chunks[:3]]
|
434 |
combined_context = "\n\n".join(context_texts)
|
435 |
|
436 |
# Limit context length
|
|
|
443 |
summary = self.summarizer(summary_input,
|
444 |
max_length=120,
|
445 |
min_length=30,
|
446 |
+
do_sample=False)[0]['summary_text']
|
447 |
return summary
|
448 |
except Exception as e:
|
449 |
logger.error(f"Error generating summary: {e}")
|
450 |
+
return f"Based on the retrieved papers about '{query}', here are the key findings:\n\n" + \
|
451 |
+
"\n\n".join([chunk['document'][:150] + "..." for chunk in context_chunks[:2]])
|
452 |
|
453 |
def search_and_answer(self, query: str, max_papers: int = 15,
|
454 |
top_k_retrieval: int = 10, top_k_rerank: int = 5,
|
|
|
458 |
|
459 |
if not query.strip():
|
460 |
return {
|
461 |
+
'answer': "Please enter a valid research query.",
|
462 |
+
'papers': [],
|
463 |
+
'retrieved_chunks': [],
|
464 |
+
'search_stats': {'papers_found': 0, 'chunks_retrieved': 0}
|
465 |
}
|
466 |
|
467 |
try:
|
|
|
470 |
|
471 |
if not papers:
|
472 |
return {
|
473 |
+
'answer': "No papers found for your query. Please try different keywords.",
|
474 |
+
'papers': [],
|
475 |
+
'retrieved_chunks': [],
|
476 |
+
'search_stats': {'papers_found': 0, 'chunks_retrieved': 0}
|
477 |
}
|
478 |
|
479 |
# Process and store papers
|
|
|
491 |
# Prepare unique papers
|
492 |
unique_papers = {}
|
493 |
for chunk in reranked_results:
|
494 |
+
paper_id = chunk['id'].split('_')[0]
|
495 |
if paper_id in self.papers_cache and paper_id not in unique_papers:
|
496 |
paper = self.papers_cache[paper_id]
|
497 |
unique_papers[paper_id] = {
|
498 |
+
'title': paper.title,
|
499 |
+
'authors': paper.authors,
|
500 |
+
'abstract': paper.abstract,
|
501 |
+
'url': paper.url,
|
502 |
+
'categories': paper.categories,
|
503 |
+
'published': paper.published.strftime('%Y-%m-%d')
|
504 |
}
|
505 |
|
506 |
return {
|
507 |
+
'answer': answer,
|
508 |
+
'papers': list(unique_papers.values()),
|
509 |
+
'retrieved_chunks': reranked_results,
|
510 |
+
'search_stats': {
|
511 |
+
'papers_found': len(papers),
|
512 |
+
'chunks_retrieved': len(reranked_results),
|
513 |
+
'unique_papers_in_results': len(unique_papers)
|
514 |
}
|
515 |
}
|
516 |
|
517 |
except Exception as e:
|
518 |
logger.error(f"Error in search_and_answer: {e}")
|
519 |
return {
|
520 |
+
'answer': f"An error occurred while processing your query: {str(e)}",
|
521 |
+
'papers': [],
|
522 |
+
'retrieved_chunks': [],
|
523 |
+
'search_stats': {'papers_found': 0, 'chunks_retrieved': 0}
|
524 |
}
|
525 |
|
526 |
# Global RAG instance
|
|
|
539 |
"""Main search function for Gradio interface"""
|
540 |
|
541 |
if not query.strip():
|
542 |
+
return "❌ Please enter a research topic or question.", "", ""
|
543 |
|
544 |
try:
|
545 |
# Initialize RAG system
|
|
|
548 |
# Parse categories
|
549 |
category_list = None
|
550 |
if categories.strip():
|
551 |
+
category_list = [cat.strip() for cat in categories.split(',') if cat.strip()]
|
552 |
|
553 |
# Perform search
|
554 |
result = rag.search_and_answer(
|
|
|
561 |
)
|
562 |
|
563 |
# Format answer
|
564 |
+
answer = f"## 🤖 AI-Generated Answer\n\n{result['answer']}\n\n"
|
565 |
answer += f"**Search Statistics:**\n"
|
566 |
+
answer += f"- Papers found: {result['search_stats']['papers_found']}\n"
|
567 |
+
answer += f"- Chunks retrieved: {result['search_stats']['chunks_retrieved']}\n"
|
568 |
+
answer += f"- Unique papers in results: {result['search_stats']['unique_papers_in_results']}\n\n"
|
569 |
|
570 |
# Format papers
|
571 |
papers_md = "## 📚 Relevant Papers\n\n"
|
572 |
+
for i, paper in enumerate(result['papers'], 1):
|
573 |
+
papers_md += f"### {i}. {paper['title']}\n\n"
|
574 |
+
papers_md += f"**Authors:** {', '.join(paper['authors'][:3])}{'...' if len(paper['authors']) > 3 else ''}\n\n"
|
575 |
+
papers_md += f"**Categories:** {', '.join(paper['categories'])}\n\n"
|
576 |
+
papers_md += f"**Published:** {paper['published']}\n\n"
|
577 |
+
papers_md += f"**Abstract:** {paper['abstract'][:250]}{'...' if len(paper['abstract']) > 250 else ''}\n\n"
|
578 |
+
papers_md += f"**URL:** [{paper['url']}]({paper['url']})\n\n"
|
579 |
papers_md += "---\n\n"
|
580 |
|
581 |
# Create papers dataframe
|
582 |
papers_df = pd.DataFrame([
|
583 |
{
|
584 |
+
'Title': paper['title'][:50] + '...' if len(paper['title']) > 50 else paper['title'],
|
585 |
+
'Authors': ', '.join(paper['authors'][:2]) + ('...' if len(paper['authors']) > 2 else ''),
|
586 |
+
'Categories': ', '.join(paper['categories'][:2]),
|
587 |
+
'Published': paper['published'],
|
588 |
+
'URL': paper['url']
|
589 |
}
|
590 |
+
for paper in result['papers']
|
591 |
])
|
592 |
|
593 |
return answer, papers_md, papers_df
|
|
|
700 |
# Entry point: build the interface and serve it when run as a script
if __name__ == "__main__":
    interface = create_interface()
    interface.launch()
|
704 |
|