Update app.py
Browse files
app.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
"""
|
2 |
-
Enhanced ArXiv RAG System - Hugging Face Spaces
|
3 |
"""
|
4 |
|
5 |
import os
|
@@ -13,12 +13,16 @@ from datetime import datetime, timedelta
|
|
13 |
import logging
|
14 |
import tempfile
|
15 |
import shutil
|
|
|
|
|
16 |
|
17 |
# Core ML libraries
|
18 |
import torch
|
|
|
19 |
from sentence_transformers import SentenceTransformer, CrossEncoder
|
20 |
-
from transformers import pipeline
|
21 |
import gradio as gr
|
|
|
22 |
|
23 |
# BM25 and text processing
|
24 |
from sklearn.feature_extraction.text import TfidfVectorizer
|
@@ -32,17 +36,25 @@ from nltk.stem import PorterStemmer
|
|
32 |
try:
|
33 |
nltk.data.find('tokenizers/punkt')
|
34 |
except LookupError:
|
35 |
-
nltk.download('punkt')
|
36 |
|
37 |
try:
|
38 |
nltk.data.find('corpora/stopwords')
|
39 |
except LookupError:
|
40 |
-
nltk.download('stopwords')
|
41 |
|
42 |
# Setup logging
|
43 |
logging.basicConfig(level=logging.INFO)
|
44 |
logger = logging.getLogger(__name__)
|
45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
@dataclass
|
47 |
class Paper:
|
48 |
"""Data class for storing paper information"""
|
@@ -63,9 +75,36 @@ class Chunk:
|
|
63 |
chunk_type: str
|
64 |
metadata: Dict[str, Any]
|
65 |
|
66 |
-
class
|
67 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
def __init__(self, k1: float = 1.5, b: float = 0.75):
|
70 |
self.k1 = k1
|
71 |
self.b = b
|
@@ -77,542 +116,598 @@ class BM25Retriever:
|
|
77 |
self.stop_words = set(stopwords.words('english'))
|
78 |
except:
|
79 |
self.stop_words = set()
|
80 |
-
|
81 |
def preprocess_text(self, text: str) -> List[str]:
|
82 |
"""Preprocess text for BM25"""
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
|
|
|
|
|
|
|
|
91 |
def fit(self, documents: List[str]):
|
92 |
-
"""Fit BM25 on documents"""
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
vocab
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
132 |
|
133 |
-
class SimpleVectorStore:
|
134 |
-
"""Simple in-memory vector store for HF Spaces compatibility"""
|
135 |
-
|
136 |
def __init__(self):
|
137 |
-
self.
|
138 |
-
self.
|
139 |
-
self.
|
140 |
-
self.
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
self.
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
for emb in self.embeddings:
|
160 |
-
emb_array = np.array(emb)
|
161 |
-
similarity = np.dot(query_embedding, emb_array) / (
|
162 |
-
np.linalg.norm(query_embedding) * np.linalg.norm(emb_array)
|
163 |
)
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
class EnhancedArxivRAG:
|
199 |
-
"""Enhanced RAG system optimized for Hugging Face Spaces"""
|
200 |
-
|
201 |
-
def __init__(self):
|
202 |
-
logger.info("Initializing Enhanced ArXiv RAG System for HF Spaces...")
|
203 |
-
|
204 |
-
# Use smaller, faster models for HF Spaces
|
205 |
-
self.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
|
206 |
-
self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-2-v2') # Smaller reranker
|
207 |
-
self.summarizer = pipeline("summarization", model="facebook/bart-large-cnn",
|
208 |
-
device=0 if torch.cuda.is_available() else -1)
|
209 |
-
|
210 |
-
# Use simple vector store instead of ChromaDB for HF Spaces
|
211 |
-
self.vector_store = SimpleVectorStore()
|
212 |
-
self.bm25_retriever = BM25Retriever()
|
213 |
-
|
214 |
-
# Cache for papers and chunks
|
215 |
-
self.papers_cache = {}
|
216 |
-
self.chunks_cache = {}
|
217 |
-
self.bm25_fitted = False
|
218 |
-
|
219 |
-
logger.info("RAG system initialized successfully!")
|
220 |
-
|
221 |
-
def fetch_papers(self, query: str, max_results: int = 15,
|
222 |
-
categories: Optional[List[str]] = None) -> List[Paper]:
|
223 |
-
"""Fetch papers from ArXiv"""
|
224 |
-
search_query = query
|
225 |
-
if categories:
|
226 |
-
category_filter = " OR ".join([f"cat:{cat}" for cat in categories])
|
227 |
-
search_query = f"({query}) AND ({category_filter})"
|
228 |
-
|
229 |
-
logger.info(f"Fetching papers with query: {search_query}")
|
230 |
-
|
231 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
232 |
search = arxiv.Search(
|
233 |
query=search_query,
|
234 |
max_results=max_results,
|
235 |
-
sort_by=arxiv.SortCriterion.Relevance
|
|
|
236 |
)
|
237 |
|
238 |
-
papers = []
|
239 |
for result in search.results():
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
251 |
|
252 |
-
logger.info(f"
|
253 |
return papers
|
254 |
|
255 |
except Exception as e:
|
256 |
-
logger.error(f"
|
257 |
return []
|
258 |
-
|
259 |
def create_chunks(self, papers: List[Paper]) -> List[Chunk]:
|
260 |
"""Create text chunks from papers"""
|
261 |
chunks = []
|
262 |
|
263 |
for paper in papers:
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
"
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
277 |
|
278 |
-
|
279 |
-
|
280 |
-
id=f"{paper.id}_abstract",
|
281 |
-
paper_id=paper.id,
|
282 |
-
text=paper.abstract,
|
283 |
-
chunk_type="abstract",
|
284 |
-
metadata={
|
285 |
-
"authors": paper.authors,
|
286 |
-
"categories": paper.categories,
|
287 |
-
"published": paper.published.isoformat(),
|
288 |
-
"url": paper.url
|
289 |
-
}
|
290 |
-
)
|
291 |
|
292 |
-
|
293 |
-
combined_text = f"Title: {paper.title}\n\nAbstract: {paper.abstract}"
|
294 |
-
combined_chunk = Chunk(
|
295 |
-
id=f"{paper.id}_combined",
|
296 |
-
paper_id=paper.id,
|
297 |
-
text=combined_text,
|
298 |
-
chunk_type="combined",
|
299 |
-
metadata={
|
300 |
-
"authors": paper.authors,
|
301 |
-
"categories": paper.categories,
|
302 |
-
"published": paper.published.isoformat(),
|
303 |
-
"url": paper.url
|
304 |
-
}
|
305 |
-
)
|
306 |
|
307 |
-
|
|
|
|
|
308 |
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
bm25_scores = self.
|
366 |
-
|
367 |
-
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
373 |
-
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
-
|
385 |
-
|
386 |
-
|
387 |
-
doc_id = result['id']
|
388 |
-
rank = i + 1
|
389 |
-
combined_scores[doc_id] = combined_scores.get(doc_id, 0) + bm25_weight / rank
|
390 |
-
|
391 |
-
# Sort by combined score
|
392 |
-
sorted_results = sorted(combined_scores.items(), key=lambda x: x[1], reverse=True)
|
393 |
-
|
394 |
-
# Prepare final results
|
395 |
-
final_results = []
|
396 |
-
for doc_id, score in sorted_results[:top_k]:
|
397 |
-
doc_result = self.vector_store.get(ids=[doc_id])
|
398 |
-
if doc_result['ids']:
|
399 |
-
final_results.append({
|
400 |
-
'id': doc_id,
|
401 |
-
'document': doc_result['documents'][0],
|
402 |
-
'metadata': doc_result['metadatas'][0],
|
403 |
-
'combined_score': score
|
404 |
-
})
|
405 |
-
|
406 |
-
return final_results
|
407 |
-
|
408 |
-
def rerank_results(self, query: str, results: List[Dict], top_k: int = 5) -> List[Dict]:
|
409 |
-
"""Rerank results using cross-encoder"""
|
410 |
-
if not results:
|
411 |
return results
|
412 |
-
|
413 |
-
|
414 |
-
|
415 |
-
|
416 |
-
|
417 |
-
|
418 |
-
|
419 |
-
|
420 |
-
|
421 |
-
result['rerank_score'] = float(rerank_scores[i])
|
422 |
-
|
423 |
-
# Sort by rerank score
|
424 |
-
reranked_results = sorted(results, key=lambda x: x['rerank_score'], reverse=True)
|
425 |
-
return reranked_results[:top_k]
|
426 |
-
|
427 |
-
def generate_answer(self, query: str, context_chunks: List[Dict]) -> str:
|
428 |
-
"""Generate answer using retrieved context"""
|
429 |
-
if not context_chunks:
|
430 |
-
return "No relevant information found to answer your query."
|
431 |
-
|
432 |
-
# Combine context from top chunks
|
433 |
-
context_texts = [chunk['document'] for chunk in context_chunks[:3]]
|
434 |
-
combined_context = "\n\n".join(context_texts)
|
435 |
-
|
436 |
-
# Limit context length
|
437 |
-
max_context_length = 800
|
438 |
-
if len(combined_context) > max_context_length:
|
439 |
-
combined_context = combined_context[:max_context_length] + "..."
|
440 |
-
|
441 |
try:
|
442 |
-
|
443 |
-
|
444 |
-
|
445 |
-
|
446 |
-
|
447 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
448 |
except Exception as e:
|
449 |
-
logger.error(f"
|
450 |
-
|
451 |
-
|
452 |
-
|
453 |
-
|
454 |
-
|
455 |
-
|
456 |
-
semantic_weight: float = 0.7) -> Dict[str, Any]:
|
457 |
-
"""Main search and answer pipeline"""
|
458 |
-
|
459 |
-
if not query.strip():
|
460 |
-
return {
|
461 |
-
'answer': "Please enter a valid research query.",
|
462 |
-
'papers': [],
|
463 |
-
'retrieved_chunks': [],
|
464 |
-
'search_stats': {'papers_found': 0, 'chunks_retrieved': 0}
|
465 |
-
}
|
466 |
-
|
467 |
try:
|
468 |
-
|
469 |
-
|
470 |
-
|
471 |
-
|
472 |
-
|
473 |
-
|
474 |
-
|
475 |
-
|
476 |
-
|
477 |
-
|
478 |
-
|
479 |
-
|
480 |
-
|
481 |
-
|
482 |
-
#
|
483 |
-
|
484 |
-
|
485 |
-
|
486 |
-
|
487 |
-
|
488 |
-
|
489 |
-
|
490 |
-
|
491 |
-
|
492 |
-
|
493 |
-
|
494 |
-
|
495 |
-
|
496 |
-
|
497 |
-
|
498 |
-
|
499 |
-
|
500 |
-
|
501 |
-
|
502 |
-
|
503 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
504 |
}
|
|
|
|
|
|
|
505 |
|
506 |
-
|
507 |
-
|
508 |
-
|
509 |
-
|
510 |
-
|
511 |
-
|
512 |
-
|
513 |
-
|
514 |
-
|
515 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
516 |
|
517 |
except Exception as e:
|
518 |
-
logger.error(f"
|
519 |
-
return {
|
520 |
-
|
521 |
-
|
522 |
-
'retrieved_chunks': [],
|
523 |
-
'search_stats': {'papers_found': 0, 'chunks_retrieved': 0}
|
524 |
-
}
|
525 |
-
|
526 |
-
# Global RAG instance
|
527 |
rag_system = None
|
528 |
|
529 |
-
def
|
530 |
-
"""Initialize RAG system"""
|
531 |
global rag_system
|
532 |
-
if rag_system is None:
|
533 |
-
rag_system = EnhancedArxivRAG()
|
534 |
-
return rag_system
|
535 |
-
|
536 |
-
def search_papers(query: str, max_papers: int = 15, top_k_retrieval: int = 10,
|
537 |
-
top_k_rerank: int = 5, categories: str = "",
|
538 |
-
semantic_weight: float = 0.7) -> tuple:
|
539 |
-
"""Main search function for Gradio interface"""
|
540 |
-
|
541 |
-
if not query.strip():
|
542 |
-
return "β Please enter a research topic or question.", "", ""
|
543 |
-
|
544 |
try:
|
545 |
-
|
546 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
547 |
|
548 |
# Parse categories
|
549 |
-
category_list =
|
550 |
if categories.strip():
|
551 |
category_list = [cat.strip() for cat in categories.split(',') if cat.strip()]
|
552 |
|
553 |
-
#
|
554 |
-
|
555 |
-
|
556 |
-
|
557 |
-
|
558 |
-
|
559 |
-
|
560 |
-
|
561 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
562 |
|
563 |
-
#
|
564 |
-
|
565 |
-
|
566 |
-
|
567 |
-
|
568 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
569 |
|
570 |
-
#
|
571 |
-
|
572 |
-
for i, paper in enumerate(result['papers'], 1):
|
573 |
-
papers_md += f"### {i}. {paper['title']}\n\n"
|
574 |
-
papers_md += f"**Authors:** {', '.join(paper['authors'][:3])}{'...' if len(paper['authors']) > 3 else ''}\n\n"
|
575 |
-
papers_md += f"**Categories:** {', '.join(paper['categories'])}\n\n"
|
576 |
-
papers_md += f"**Published:** {paper['published']}\n\n"
|
577 |
-
papers_md += f"**Abstract:** {paper['abstract'][:250]}{'...' if len(paper['abstract']) > 250 else ''}\n\n"
|
578 |
-
papers_md += f"**URL:** [{paper['url']}]({paper['url']})\n\n"
|
579 |
-
papers_md += "---\n\n"
|
580 |
|
581 |
-
#
|
582 |
-
papers_df =
|
583 |
-
{
|
584 |
-
'Title': paper['title'][:50] + '...' if len(paper['title']) > 50 else paper['title'],
|
585 |
-
'Authors': ', '.join(paper['authors'][:2]) + ('...' if len(paper['authors']) > 2 else ''),
|
586 |
-
'Categories': ', '.join(paper['categories'][:2]),
|
587 |
-
'Published': paper['published'],
|
588 |
-
'URL': paper['url']
|
589 |
-
}
|
590 |
-
for paper in result['papers']
|
591 |
-
])
|
592 |
|
593 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
594 |
|
595 |
except Exception as e:
|
596 |
-
logger.error(f"
|
597 |
error_msg = f"β An error occurred: {str(e)}\n\nPlease try different keywords or check your internet connection."
|
598 |
return error_msg, "", pd.DataFrame()
|
599 |
|
600 |
# Create Gradio interface
|
601 |
def create_interface():
|
602 |
-
"""Create Gradio interface"""
|
603 |
|
604 |
css = """
|
605 |
.gradio-container {
|
606 |
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
|
607 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
608 |
"""
|
609 |
|
610 |
-
with gr.Blocks(css=css, title="Enhanced ArXiv RAG System") as interface:
|
611 |
|
612 |
-
gr.HTML("""
|
613 |
<div style="text-align: center; background: linear-gradient(90deg, #667eea 0%, #764ba2 100%); color: white; padding: 2rem; border-radius: 10px; margin-bottom: 2rem;">
|
614 |
<h1>π Enhanced ArXiv RAG System</h1>
|
615 |
-
<p>
|
|
|
|
|
|
|
616 |
</div>
|
617 |
""")
|
618 |
|
@@ -634,29 +729,45 @@ def create_interface():
|
|
634 |
value=""
|
635 |
)
|
636 |
|
637 |
-
with gr.Accordion("Advanced Settings", open=False):
|
638 |
with gr.Row():
|
639 |
top_k_retrieval = gr.Slider(5, 15, value=10, step=1, label="Top-K Retrieval")
|
640 |
top_k_rerank = gr.Slider(3, 8, value=5, step=1, label="Top-K Reranking")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
641 |
|
642 |
-
search_btn = gr.Button("π Search Papers", variant="primary")
|
643 |
|
644 |
with gr.Column(scale=1):
|
645 |
gr.HTML("""
|
646 |
<div style="background: #e3f2fd; padding: 1rem; border-radius: 8px;">
|
647 |
-
<h4>π‘ Tips</h4>
|
648 |
<ul>
|
649 |
<li>Use specific technical terms</li>
|
650 |
<li>Try different category filters</li>
|
651 |
<li>Adjust semantic weight for different search styles</li>
|
|
|
|
|
652 |
</ul>
|
653 |
|
654 |
-
<h4>π Categories</h4>
|
655 |
<ul>
|
656 |
<li><code>cs.AI</code> - Artificial Intelligence</li>
|
657 |
<li><code>cs.CL</code> - Computation and Language</li>
|
658 |
<li><code>cs.LG</code> - Machine Learning</li>
|
659 |
<li><code>cs.CV</code> - Computer Vision</li>
|
|
|
|
|
660 |
</ul>
|
661 |
</div>
|
662 |
""")
|
@@ -669,15 +780,20 @@ def create_interface():
|
|
669 |
papers_output = gr.Markdown(label="Relevant Papers")
|
670 |
|
671 |
with gr.TabItem("π Papers Table"):
|
672 |
-
papers_table = gr.Dataframe(
|
|
|
|
|
|
|
|
|
673 |
|
674 |
# Examples
|
675 |
gr.Examples(
|
676 |
examples=[
|
677 |
["transformer attention mechanisms", 15, 10, 5, "cs.CL, cs.AI", 0.7],
|
678 |
-
["graph neural networks", 12, 8, 4, "cs.LG", 0.6],
|
679 |
["computer vision deep learning", 15, 10, 5, "cs.CV", 0.8],
|
680 |
-
["reinforcement learning", 18, 10, 5, "cs.AI", 0.7]
|
|
|
681 |
],
|
682 |
inputs=[query_input, max_papers, top_k_retrieval, top_k_rerank, categories_input, semantic_weight]
|
683 |
)
|
@@ -691,7 +807,8 @@ def create_interface():
|
|
691 |
|
692 |
gr.HTML("""
|
693 |
<div style="text-align: center; margin-top: 2rem; padding: 1rem; background: #f5f5f5; border-radius: 8px;">
|
694 |
-
<p><strong>Enhanced ArXiv RAG System</strong> | Semantic Search + BM25 + Neural Reranking</p>
|
|
|
695 |
</div>
|
696 |
""")
|
697 |
|
@@ -699,6 +816,16 @@ def create_interface():
|
|
699 |
|
700 |
# Launch interface
|
701 |
if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
|
|
|
702 |
interface = create_interface()
|
703 |
-
interface.launch(
|
704 |
-
|
|
|
|
|
|
|
|
|
|
1 |
"""
|
2 |
+
Enhanced ArXiv RAG System - GPU Optimized for Hugging Face Spaces
|
3 |
"""
|
4 |
|
5 |
import os
|
|
|
13 |
import logging
|
14 |
import tempfile
|
15 |
import shutil
|
16 |
+
import gc
|
17 |
+
import time
|
18 |
|
19 |
# Core ML libraries
|
20 |
import torch
|
21 |
+
import torch.nn.functional as F
|
22 |
from sentence_transformers import SentenceTransformer, CrossEncoder
|
23 |
+
from transformers import pipeline, AutoTokenizer, AutoModel
|
24 |
import gradio as gr
|
25 |
+
import spaces # HuggingFace Spaces GPU support
|
26 |
|
27 |
# BM25 and text processing
|
28 |
from sklearn.feature_extraction.text import TfidfVectorizer
|
|
|
36 |
# Make sure the NLTK data packages used by this module are installed,
# downloading (quietly) only those that nltk.data.find cannot locate.
for _resource_path, _package_name in (
    ('tokenizers/punkt', 'punkt'),
    ('corpora/stopwords', 'stopwords'),
):
    try:
        nltk.data.find(_resource_path)
    except LookupError:
        nltk.download(_package_name, quiet=True)

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# GPU Configuration: use CUDA when torch reports it, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
logger.info(f"Using device: {DEVICE}")

# Log the GPU name and total memory (device 0) when running on CUDA.
if DEVICE == "cuda":
    logger.info(f"GPU: {torch.cuda.get_device_name()}")
    logger.info(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
|
57 |
+
|
58 |
@dataclass
|
59 |
class Paper:
|
60 |
"""Data class for storing paper information"""
|
|
|
75 |
chunk_type: str
|
76 |
metadata: Dict[str, Any]
|
77 |
|
78 |
+
class GPUMemoryManager:
    """Static helpers for releasing and inspecting CUDA memory.

    The CUDA-specific calls are skipped when ``torch.cuda.is_available()``
    is false, so the helpers are safe to call on CPU-only hosts.
    """

    @staticmethod
    def clear_cache():
        """Collect garbage, then release PyTorch's cached CUDA memory."""
        # Collect first: tensors that are only kept alive by reference
        # cycles must be freed before empty_cache() can hand their blocks
        # back to the driver.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    @staticmethod
    def get_memory_info():
        """Return a short human-readable summary of CUDA memory usage.

        Returns:
            ``"Allocated: X.XGB, Cached: Y.YGB"`` built from
            ``torch.cuda.memory_allocated()`` and
            ``torch.cuda.memory_reserved()``, or ``"CPU mode"`` when CUDA
            is not available.
        """
        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated() / 1e9
            cached = torch.cuda.memory_reserved() / 1e9
            return f"Allocated: {allocated:.1f}GB, Cached: {cached:.1f}GB"
        return "CPU mode"

    @staticmethod
    def optimize_memory():
        """Empty the CUDA cache and switch on cuDNN autotuning and TF32 matmul.

        Does nothing when CUDA is not available.
        """
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.backends.cudnn.benchmark = True
            torch.backends.cuda.matmul.allow_tf32 = True
|
104 |
+
|
105 |
+
class BM25Retriever:
|
106 |
+
"""Optimized BM25 retriever for keyword-based search"""
|
107 |
+
|
108 |
def __init__(self, k1: float = 1.5, b: float = 0.75):
|
109 |
self.k1 = k1
|
110 |
self.b = b
|
|
|
116 |
self.stop_words = set(stopwords.words('english'))
|
117 |
except:
|
118 |
self.stop_words = set()
|
119 |
+
|
120 |
def preprocess_text(self, text: str) -> List[str]:
    """Turn raw text into the list of stemmed tokens used by BM25.

    The text is lower-cased and tokenised with ``word_tokenize``. A token
    is kept only if it is purely alphabetic, is not in
    ``self.stop_words`` and is longer than two characters; kept tokens
    are stemmed with ``self.stemmer``.

    Returns:
        The stemmed tokens, or a plain whitespace split of the lower-cased
        text if anything in the pipeline raises.
    """
    try:
        stemmed = []
        for word in word_tokenize(text.lower()):
            if not word.isalpha():
                continue
            if word in self.stop_words:
                continue
            if len(word) <= 2:
                continue
            stemmed.append(self.stemmer.stem(word))
        return stemmed
    except Exception as e:
        # Best-effort fallback so a tokenizer failure does not abort indexing.
        logger.warning(f"Text preprocessing error: {e}")
        return text.lower().split()
|
133 |
+
|
134 |
def fit(self, documents: List[str]):
|
135 |
+
"""Fit BM25 on documents with memory optimization"""
|
136 |
+
try:
|
137 |
+
self.documents = [self.preprocess_text(doc) for doc in documents]
|
138 |
+
self.doc_lengths = [len(doc) for doc in self.documents]
|
139 |
+
self.avg_doc_length = sum(self.doc_lengths) / len(self.doc_lengths) if self.doc_lengths else 0
|
140 |
+
|
141 |
+
# Build vocabulary
|
142 |
+
vocab = set()
|
143 |
+
for doc in self.documents:
|
144 |
+
vocab.update(doc)
|
145 |
+
self.vocab = list(vocab)
|
146 |
+
|
147 |
+
# Calculate term frequencies
|
148 |
+
self.term_freqs = []
|
149 |
+
for doc in self.documents:
|
150 |
+
tf = {}
|
151 |
+
for term in doc:
|
152 |
+
tf[term] = tf.get(term, 0) + 1
|
153 |
+
self.term_freqs.append(tf)
|
154 |
+
|
155 |
+
# Calculate IDF
|
156 |
+
self.idf = {}
|
157 |
+
for term in self.vocab:
|
158 |
+
df = sum(1 for doc in self.documents if term in doc)
|
159 |
+
self.idf[term] = np.log((len(self.documents) - df + 0.5) / (df + 0.5))
|
160 |
+
|
161 |
+
except Exception as e:
|
162 |
+
logger.error(f"BM25 fitting error: {e}")
|
163 |
+
self.documents = []
|
164 |
+
|
165 |
+
def get_scores(self, query: str) -> np.ndarray:
|
166 |
+
"""Get BM25 scores for query"""
|
167 |
+
try:
|
168 |
+
query_terms = self.preprocess_text(query)
|
169 |
+
scores = np.zeros(len(self.documents))
|
170 |
+
|
171 |
+
for i, doc_tf in enumerate(self.term_freqs):
|
172 |
+
score = 0
|
173 |
+
doc_length = self.doc_lengths[i]
|
174 |
+
|
175 |
+
for term in query_terms:
|
176 |
+
if term in doc_tf:
|
177 |
+
tf = doc_tf[term]
|
178 |
+
idf = self.idf.get(term, 0)
|
179 |
+
score += idf * (tf * (self.k1 + 1)) / (
|
180 |
+
tf + self.k1 * (1 - self.b + self.b * (doc_length / self.avg_doc_length))
|
181 |
+
)
|
182 |
+
|
183 |
+
scores[i] = score
|
184 |
+
|
185 |
+
return scores
|
186 |
+
except Exception as e:
|
187 |
+
logger.error(f"BM25 scoring error: {e}")
|
188 |
+
return np.zeros(len(self.documents))
|
189 |
+
|
190 |
+
class OptimizedRagSystem:
|
191 |
+
"""GPU-optimized RAG system for ArXiv papers"""
|
192 |
|
|
|
|
|
|
|
193 |
def __init__(self):
    """Create empty paper/chunk state, the helpers, and load all models."""
    # Corpus state; starts empty.
    self.papers = []
    self.chunks = []
    self.embeddings = None

    # Model handles, populated by _load_models() below.
    self.embedding_model = None
    self.reranker = None
    self.generator = None

    # Keyword retriever and GPU memory helper.
    self.bm25 = BM25Retriever()
    self.memory_manager = GPUMemoryManager()

    # Initialize models
    self._load_models()
|
205 |
+
|
206 |
+
def _load_models(self):
    """Instantiate the embedding model, the reranker and the text generator.

    All three are placed on ``DEVICE``. On CUDA the embedding model is
    converted to half precision and the generator is loaded in float16.

    Raises:
        Exception: any loading error is logged and re-raised unchanged.
    """
    try:
        logger.info("Loading models...")

        use_cuda = DEVICE == "cuda"

        # Sentence embedding model
        embedder = SentenceTransformer(
            'sentence-transformers/all-MiniLM-L6-v2',
            device=DEVICE
        )
        self.embedding_model = embedder
        if use_cuda:
            embedder.half()  # Use FP16 for memory efficiency

        # Cross-encoder used to rerank retrieved chunks
        self.reranker = CrossEncoder(
            'cross-encoder/ms-marco-MiniLM-L-6-v2',
            device=DEVICE
        )

        # Text-generation pipeline; device index 0 on CUDA, -1 for CPU
        pipeline_device = 0 if use_cuda else -1
        pipeline_dtype = torch.float16 if use_cuda else torch.float32
        self.generator = pipeline(
            "text-generation",
            model="microsoft/DialoGPT-small",  # Smaller model for efficiency
            tokenizer="microsoft/DialoGPT-small",
            device=pipeline_device,
            torch_dtype=pipeline_dtype,
            return_full_text=False,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            pad_token_id=50256
        )

        self.memory_manager.optimize_memory()
        logger.info("Models loaded successfully")

    except Exception as e:
        logger.error(f"Model loading error: {e}")
        raise
|
247 |
+
|
248 |
+
def search_arxiv(self, query: str, max_results: int = 15, categories: List[str] = None) -> List[Paper]:
    """Query arXiv and convert the hits into ``Paper`` objects.

    Args:
        query: Free-text search query.
        max_results: Upper bound on the number of results requested.
        categories: Optional arXiv category codes; when given, the query
            is AND-ed with an OR of ``cat:<code>`` clauses. Blank entries
            are ignored.

    Returns:
        The papers that could be parsed, sorted by relevance as requested
        from arXiv. An empty list is returned if the search itself fails.
    """
    try:
        papers = []
        search_query = query

        # Skip blank entries (e.g. from a trailing comma) so the filter
        # never contains an empty "cat:" clause.
        category_names = [cat.strip() for cat in (categories or []) if cat and cat.strip()]
        if category_names:
            category_filter = " OR ".join(f"cat:{cat}" for cat in category_names)
            search_query = f"({query}) AND ({category_filter})"

        logger.info(f"Searching ArXiv for: {search_query}")

        search = arxiv.Search(
            query=search_query,
            max_results=max_results,
            sort_by=arxiv.SortCriterion.Relevance,
            sort_order=arxiv.SortOrder.Descending
        )

        # Newer arxiv.py releases deprecate Search.results() in favour of
        # Client.results(search); fall back to the old call when the
        # installed package has no Client.
        client_cls = getattr(arxiv, "Client", None)
        if client_cls is not None:
            results = client_cls().results(search)
        else:
            results = search.results()

        # NOTE(review): the former time.sleep(0.1) per result was removed.
        # It ran after each already-received entry, so it only delayed the
        # caller; the arxiv client is expected to pace its own page
        # requests (delay_seconds) - confirm against the pinned version.
        for result in results:
            try:
                papers.append(Paper(
                    id=result.entry_id.split('/')[-1],
                    title=result.title,
                    abstract=result.summary,
                    authors=[author.name for author in result.authors],
                    categories=result.categories,
                    published=result.published,
                    url=result.entry_id
                ))
            except Exception as e:
                # A single malformed entry should not discard the rest.
                logger.warning(f"Error processing paper: {e}")
                continue

        logger.info(f"Found {len(papers)} papers")
        return papers

    except Exception as e:
        logger.error(f"ArXiv search error: {e}")
        return []
|
293 |
+
|
294 |
def create_chunks(self, papers: List[Paper]) -> List[Chunk]:
    """Split each paper into one title chunk plus abstract chunks.

    The abstract is sentence-tokenised and grouped three sentences per
    chunk; chunk ids are ``<paper id>_title`` and
    ``<paper id>_abstract_<index of first sentence>``. Every chunk carries
    the originating paper under ``metadata["paper"]``. A paper that raises
    while being processed is logged and skipped (chunks already appended
    for it are kept).
    """
    sentences_per_chunk = 3
    collected = []

    for paper in papers:
        try:
            # Title chunk
            title_chunk = Chunk(
                id=f"{paper.id}_title",
                paper_id=paper.id,
                text=paper.title,
                chunk_type="title",
                metadata={"paper": paper}
            )
            collected.append(title_chunk)

            # Abstract chunks (split if too long)
            sentences = sent_tokenize(paper.abstract)
            start = 0
            while start < len(sentences):
                group = sentences[start:start + sentences_per_chunk]
                collected.append(Chunk(
                    id=f"{paper.id}_abstract_{start}",
                    paper_id=paper.id,
                    text=' '.join(group),
                    chunk_type="abstract",
                    metadata={"paper": paper}
                ))
                start += sentences_per_chunk

        except Exception as e:
            logger.warning(f"Error creating chunks for paper {paper.id}: {e}")
            continue

    return collected
|
328 |
+
|
329 |
+
@spaces.GPU(duration=120)  # HuggingFace Spaces GPU decorator
def embed_chunks(self, chunks: List[Chunk]) -> np.ndarray:
    """Encode the text of every chunk with the embedding model.

    Texts are encoded in batches (32 on CUDA, 8 on CPU), moved to the CPU
    and stacked into one array.

    Args:
        chunks: Chunks whose ``text`` attribute is embedded.

    Returns:
        One row per chunk, in input order. An empty array is returned for
        empty input or if encoding fails (the error is logged).
    """
    try:
        if not chunks:
            return np.array([])

        logger.info(f"Creating embeddings for {len(chunks)} chunks")
        self.memory_manager.clear_cache()

        texts = [chunk.text for chunk in chunks]

        # Batch processing for efficiency
        on_gpu = DEVICE == "cuda"
        batch_size = 32 if on_gpu else 8
        embeddings = []

        for batch_index, start in enumerate(range(0, len(texts), batch_size)):
            batch_texts = texts[start:start + batch_size]

            # torch.autocast replaces the deprecated torch.cuda.amp.autocast;
            # enabled=False makes it a no-op on CPU. no_grad is applied on
            # both devices so no autograd state is kept during inference.
            with torch.no_grad(), torch.autocast(device_type=DEVICE, enabled=on_gpu):
                batch_embeddings = self.embedding_model.encode(
                    batch_texts,
                    convert_to_tensor=True,
                    show_progress_bar=False,
                    batch_size=len(batch_texts)
                )

            # .cpu() is a no-op for tensors that already live on the CPU.
            embeddings.append(batch_embeddings.cpu().numpy())

            # Memory management: release cached blocks after every fourth
            # batch (not before any work has been done).
            if (batch_index + 1) % 4 == 0:
                self.memory_manager.clear_cache()

        result = np.vstack(embeddings) if embeddings else np.array([])
        self.memory_manager.clear_cache()

        logger.info(f"Created embeddings shape: {result.shape}")
        return result

    except Exception as e:
        logger.error(f"Embedding error: {e}")
        self.memory_manager.clear_cache()
        return np.array([])
|
375 |
+
|
376 |
+
@spaces.GPU(duration=60)  # HuggingFace Spaces GPU decorator
def hybrid_retrieval(self, query: str, top_k: int = 10, semantic_weight: float = 0.7) -> List[Tuple[Chunk, float]]:
    """Rank the stored chunks by a blend of embedding similarity and BM25.

    Both score vectors are min-max normalised and combined as
    ``semantic_weight * semantic + (1 - semantic_weight) * bm25``.

    Args:
        query: Search query.
        top_k: Maximum number of chunks to return.
        semantic_weight: Weight of the embedding similarity in the blend.

    Returns:
        Up to ``top_k`` ``(chunk, combined_score)`` pairs, best first.
        An empty list if nothing is indexed, ``top_k`` is not positive, or
        retrieval fails (the error is logged).
    """
    try:
        if not self.chunks or self.embeddings is None or len(self.embeddings) == 0:
            return []
        if top_k <= 0:
            # A negative slice bound would otherwise drop items from the
            # end instead of limiting the result.
            return []

        self.memory_manager.clear_cache()

        # Semantic search. torch.autocast replaces the deprecated
        # torch.cuda.amp.autocast; enabled=False makes it a no-op on CPU.
        on_gpu = DEVICE == "cuda"
        with torch.no_grad(), torch.autocast(device_type=DEVICE, enabled=on_gpu):
            query_embedding = self.embedding_model.encode(
                [query],
                convert_to_tensor=True,
                show_progress_bar=False
            )

        query_embedding = query_embedding.cpu().numpy()
        semantic_scores = cosine_similarity(query_embedding, self.embeddings)[0]

        # BM25 search
        bm25_scores = self.bm25.get_scores(query)

        # Ensure same length; report a mismatch instead of hiding it,
        # because it means the indexes and the chunk list are out of sync.
        min_length = min(len(semantic_scores), len(bm25_scores), len(self.chunks))
        if not (len(semantic_scores) == len(bm25_scores) == len(self.chunks)):
            logger.warning(
                f"Retrieval index size mismatch: semantic={len(semantic_scores)}, "
                f"bm25={len(bm25_scores)}, chunks={len(self.chunks)}"
            )
        semantic_scores = semantic_scores[:min_length]
        bm25_scores = bm25_scores[:min_length]
        chunks = self.chunks[:min_length]

        # Normalize scores to [0, 1]; the epsilon guards a constant vector.
        if len(semantic_scores) > 0:
            semantic_scores = (semantic_scores - semantic_scores.min()) / (semantic_scores.max() - semantic_scores.min() + 1e-8)
        if len(bm25_scores) > 0:
            bm25_scores = (bm25_scores - bm25_scores.min()) / (bm25_scores.max() - bm25_scores.min() + 1e-8)

        # Combine scores
        combined_scores = semantic_weight * semantic_scores + (1 - semantic_weight) * bm25_scores

        # Get top results
        top_indices = np.argsort(combined_scores)[::-1][:top_k]
        results = [(chunks[i], float(combined_scores[i])) for i in top_indices]

        self.memory_manager.clear_cache()
        return results

    except Exception as e:
        logger.error(f"Retrieval error: {e}")
        self.memory_manager.clear_cache()
        return []
|
429 |
+
|
430 |
+
@spaces.GPU(duration=60)  # HuggingFace Spaces GPU decorator
def rerank_results(self, query: str, results: List[Tuple[Chunk, float]], top_k: int = 5) -> List[Tuple[Chunk, float]]:
    """Re-order retrieved chunks with the cross-encoder.

    Each ``(query, chunk.text)`` pair is scored by ``self.reranker``; the
    scores are min-max normalised and blended with the incoming score as
    ``0.6 * rerank + 0.4 * original``.

    Args:
        query: Search query.
        results: ``(chunk, score)`` pairs to rerank.
        top_k: Maximum number of pairs to return.

    Returns:
        Up to ``top_k`` ``(chunk, combined_score)`` pairs, best first. If
        there is nothing to rerank, no reranker is loaded, or reranking
        fails, the first ``top_k`` input pairs are returned unchanged.
    """
    try:
        if not results or not self.reranker:
            return results[:top_k]

        self.memory_manager.clear_cache()

        pairs = [(query, chunk.text) for chunk, _ in results]

        # torch.autocast replaces the deprecated torch.cuda.amp.autocast;
        # enabled=False makes it a no-op on CPU.
        on_gpu = DEVICE == "cuda"
        with torch.no_grad(), torch.autocast(device_type=DEVICE, enabled=on_gpu):
            rerank_scores = self.reranker.predict(pairs, show_progress_bar=False)

        # Cross-encoder scores are not guaranteed to lie in [0, 1]. Bring
        # them onto the same min-max scale the retrieval scores use, so
        # the 0.6 / 0.4 weights mean what they say.
        rerank_scores = np.asarray(rerank_scores, dtype=float)
        low = rerank_scores.min()
        rerank_scores = (rerank_scores - low) / (rerank_scores.max() - low + 1e-8)

        # Combine with original scores
        reranked_results = [
            (chunk, 0.6 * float(rerank_scores[i]) + 0.4 * original_score)
            for i, (chunk, original_score) in enumerate(results)
        ]

        # Sort by new scores
        reranked_results.sort(key=lambda x: x[1], reverse=True)

        self.memory_manager.clear_cache()
        return reranked_results[:top_k]

    except Exception as e:
        logger.error(f"Reranking error: {e}")
        self.memory_manager.clear_cache()
        return results[:top_k]
|
460 |
+
|
461 |
+
@spaces.GPU(duration=90)  # HuggingFace Spaces GPU decorator
def generate_answer(self, query: str, context_chunks: List[Chunk]) -> str:
    """Generate a natural-language answer to ``query`` from retrieved chunks.

    Only the first three chunks are considered, and only those whose
    metadata carries a ``"paper"`` entry contribute to the prompt. The
    joined context is truncated to 2000 characters before generation.

    Args:
        query: The user's question.
        context_chunks: Retrieved chunks, most relevant first.

    Returns:
        The generated answer text; a fixed fallback message when there are
        no chunks or no generator; or an ``"Error generating answer: ..."``
        message if generation raises.
    """
    try:
        if not context_chunks or not self.generator:
            return "No relevant information found to answer your query."

        self.memory_manager.clear_cache()

        # Create context
        context_parts = []
        for chunk in context_chunks[:3]:  # Limit context
            paper = chunk.metadata.get("paper")
            if paper:
                context_parts.append(f"Title: {paper.title}\nContent: {chunk.text}")

        context = "\n\n".join(context_parts)

        # Create prompt
        prompt = f"""Based on the following research papers, provide a comprehensive answer to the query:

Query: {query}

Research Context:
{context[:2000]}

Answer:"""

        with torch.cuda.amp.autocast() if DEVICE == "cuda" else torch.no_grad():
            response = self.generator(
                prompt,
                max_new_tokens=300,
                temperature=0.7,
                do_sample=True,
                # NOTE(review): 50256 is the GPT-2 end-of-text id; presumably
                # the generator is a GPT-2-family model -- confirm.
                pad_token_id=50256
            )

        generated_text = response[0]['generated_text']

        # NOTE(review): self.generator is built outside this method. If it
        # is a causal text-generation pipeline, the output is the prompt
        # followed by the continuation, so the user would see the whole
        # instruction and context echoed back. Drop that echo when present;
        # this is a no-op for pipelines that return only new text.
        if generated_text.startswith(prompt):
            generated_text = generated_text[len(prompt):]

        answer = generated_text.strip()

        self.memory_manager.clear_cache()
        return answer

    except Exception as e:
        logger.error(f"Answer generation error: {e}")
        self.memory_manager.clear_cache()
        return f"Error generating answer: {str(e)}"
|
507 |
+
|
508 |
+
def format_results(self, results: List[Tuple[Chunk, float]]) -> Tuple[str, pd.DataFrame]:
    """Render reranked results as a markdown listing and a summary table.

    Chunks are grouped by the id of the paper found under the ``"paper"``
    metadata key (chunks without one are ignored). Each paper is ranked by
    the best score among its chunks, and at most eight papers are shown.

    Args:
        results: ``(chunk, score)`` pairs to display.

    Returns:
        ``(markdown_text, dataframe)``. For empty input the markdown is a
        "no papers" message with an empty frame; on failure it is an error
        message with an empty frame.
    """
    try:
        if not results:
            return "No relevant papers found.", pd.DataFrame()

        # Collect every chunk under its paper, tracking the best score seen.
        by_paper_id = {}
        for chunk, score in results:
            paper = chunk.metadata.get("paper")
            if not paper:
                continue
            entry = by_paper_id.get(paper.id)
            if entry is None:
                by_paper_id[paper.id] = {
                    'paper': paper,
                    'max_score': score,
                    'chunks': [(chunk, score)]
                }
            else:
                entry['chunks'].append((chunk, score))
                entry['max_score'] = max(entry['max_score'], score)

        # Highest-scoring papers first.
        ranked = list(by_paper_id.values())
        ranked.sort(key=lambda entry: entry['max_score'], reverse=True)

        markdown_parts = []
        table_data = []

        for rank, entry in enumerate(ranked[:8], 1):
            paper = entry['paper']
            score = entry['max_score']

            # Show at most three authors, flagging any that were dropped.
            authors_str = ", ".join(paper.authors[:3]) + (" et al." if len(paper.authors) > 3 else "")
            categories_str = ", ".join(paper.categories[:3])

            # Markdown card for this paper
            markdown_parts.append(f"""
### {rank}. [{paper.title}]({paper.url})

**Authors:** {authors_str}
**Categories:** {categories_str}
**Published:** {paper.published.strftime('%Y-%m-%d')}
**Relevance Score:** {score:.3f}

**Abstract:** {paper.abstract[:300]}{'...' if len(paper.abstract) > 300 else ''}

---
""")

            # Matching row for the summary table
            table_data.append({
                'Rank': rank,
                'Title': paper.title[:60] + ('...' if len(paper.title) > 60 else ''),
                'Authors': authors_str,
                'Categories': categories_str,
                'Published': paper.published.strftime('%Y-%m-%d'),
                'Score': f"{score:.3f}",
                'URL': paper.url
            })

        return "".join(markdown_parts), pd.DataFrame(table_data)

    except Exception as e:
        logger.error(f"Formatting error: {e}")
        return f"Error formatting results: {str(e)}", pd.DataFrame()
|
578 |
+
|
579 |
+
# Process-wide RAG system, created on first use by initialize_system()
rag_system = None


def initialize_system():
    """Create the shared ``rag_system`` instance if it does not exist yet.

    Does nothing when the instance has already been created.

    Raises:
        Exception: Whatever ``OptimizedRagSystem()`` raises; the error is
            logged and then re-raised unchanged.
    """
    global rag_system

    # Already built: nothing to do.
    if rag_system is not None:
        return

    try:
        logger.info("Initializing RAG system...")
        rag_system = OptimizedRagSystem()
        logger.info("RAG system initialized successfully")
    except Exception as e:
        logger.error(f"System initialization error: {e}")
        raise
|
593 |
+
|
594 |
+
# Main search function
|
595 |
+
@spaces.GPU(duration=180)  # HuggingFace Spaces GPU decorator for main function
def search_papers(query: str, max_papers: int = 15, top_k_retrieval: int = 10,
                  top_k_rerank: int = 5, categories: str = "", semantic_weight: float = 0.7):
    """Run the full search pipeline for one query and format the outcome.

    Steps: search ArXiv, chunk the papers, embed the chunks, fit BM25,
    run hybrid retrieval, rerank, generate an answer, and format results.
    Intermediate state (papers, chunks, embeddings) is stored on the global
    ``rag_system`` instance.

    Args:
        query: Free-text search query. ``None`` or blank yields a prompt to
            enter a query.
        max_papers: Maximum number of papers to request from ArXiv.
        top_k_retrieval: Number of chunks kept after hybrid retrieval.
        top_k_rerank: Number of chunks kept after reranking.
        categories: Comma-separated ArXiv categories; ``None`` or blank
            means no category filter.
        semantic_weight: Weight passed to hybrid retrieval for the semantic
            score.

    Returns:
        A 3-tuple ``(summary_markdown, papers_markdown, papers_dataframe)``.
        On any failure or empty stage the first element is a user-facing
        message, the second is ``""`` and the third is an empty DataFrame.
    """
    try:
        # UI components or API callers may hand over None for cleared text
        # inputs; treat that as empty instead of failing on .strip().
        query = query or ""
        categories = categories or ""

        if not query.strip():
            return "β Please enter a search query.", "", pd.DataFrame()

        # The counts are used as result limits and slice bounds, so make
        # sure they are integers even if the caller supplies 10.0.
        max_papers = int(max_papers)
        top_k_retrieval = int(top_k_retrieval)
        top_k_rerank = int(top_k_rerank)

        # Initialize system if needed
        initialize_system()

        start_time = time.time()

        # Parse categories
        category_list = []
        if categories.strip():
            category_list = [cat.strip() for cat in categories.split(',') if cat.strip()]

        # Search ArXiv
        papers = rag_system.search_arxiv(query, max_papers, category_list)

        if not papers:
            return "β No papers found. Try different keywords or check your internet connection.", "", pd.DataFrame()

        # Create chunks and embeddings
        rag_system.papers = papers
        rag_system.chunks = rag_system.create_chunks(papers)

        if not rag_system.chunks:
            return "β Error processing papers.", "", pd.DataFrame()

        # Create embeddings with GPU acceleration
        rag_system.embeddings = rag_system.embed_chunks(rag_system.chunks)

        if rag_system.embeddings is None or len(rag_system.embeddings) == 0:
            return "β Error creating embeddings.", "", pd.DataFrame()

        # Fit BM25
        chunk_texts = [chunk.text for chunk in rag_system.chunks]
        rag_system.bm25.fit(chunk_texts)

        # Hybrid retrieval with GPU acceleration
        retrieved_results = rag_system.hybrid_retrieval(query, top_k_retrieval, semantic_weight)

        if not retrieved_results:
            return "β No relevant content found.", "", pd.DataFrame()

        # Rerank results with GPU acceleration
        reranked_results = rag_system.rerank_results(query, retrieved_results, top_k_rerank)

        # Generate answer with GPU acceleration
        answer = rag_system.generate_answer(query, [chunk for chunk, _ in reranked_results])

        # Format results
        papers_md, papers_df = rag_system.format_results(reranked_results)

        # Create response with statistics
        end_time = time.time()
        processing_time = end_time - start_time

        stats = f"""
## π€ AI-Generated Answer

{answer}

## π Search Statistics

- **Query:** {query}
- **Papers Found:** {len(papers)}
- **Chunks Processed:** {len(rag_system.chunks)}
- **Top Results:** {len(reranked_results)}
- **Processing Time:** {processing_time:.2f}s
- **GPU Memory:** {rag_system.memory_manager.get_memory_info()}
- **Semantic Weight:** {semantic_weight}

---
"""

        # Clean up GPU memory
        rag_system.memory_manager.clear_cache()

        return stats, papers_md, papers_df

    except Exception as e:
        logger.error(f"Search error: {e}")
        error_msg = f"β An error occurred: {str(e)}\n\nPlease try different keywords or check your internet connection."
        return error_msg, "", pd.DataFrame()
|
682 |
|
683 |
# Create Gradio interface
|
684 |
def create_interface():
|
685 |
+
"""Create optimized Gradio interface"""
|
686 |
|
687 |
css = """
|
688 |
.gradio-container {
|
689 |
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
|
690 |
}
|
691 |
+
.gpu-badge {
|
692 |
+
background: linear-gradient(45deg, #00d4aa, #00b4d8);
|
693 |
+
color: white;
|
694 |
+
padding: 0.5rem 1rem;
|
695 |
+
border-radius: 20px;
|
696 |
+
font-weight: bold;
|
697 |
+
display: inline-block;
|
698 |
+
margin-bottom: 1rem;
|
699 |
+
}
|
700 |
"""
|
701 |
|
702 |
+
with gr.Blocks(css=css, title="Enhanced ArXiv RAG System - GPU Optimized") as interface:
|
703 |
|
704 |
+
gr.HTML(f"""
|
705 |
<div style="text-align: center; background: linear-gradient(90deg, #667eea 0%, #764ba2 100%); color: white; padding: 2rem; border-radius: 10px; margin-bottom: 2rem;">
|
706 |
<h1>π Enhanced ArXiv RAG System</h1>
|
707 |
+
<p>GPU-Optimized scientific paper discovery with semantic search, BM25, and neural reranking</p>
|
708 |
+
<div class="gpu-badge">
|
709 |
+
π₯ GPU Accelerated β’ Device: {DEVICE.upper()}
|
710 |
+
</div>
|
711 |
</div>
|
712 |
""")
|
713 |
|
|
|
729 |
value=""
|
730 |
)
|
731 |
|
732 |
+
with gr.Accordion("π§ Advanced GPU Settings", open=False):
|
733 |
with gr.Row():
|
734 |
top_k_retrieval = gr.Slider(5, 15, value=10, step=1, label="Top-K Retrieval")
|
735 |
top_k_rerank = gr.Slider(3, 8, value=5, step=1, label="Top-K Reranking")
|
736 |
+
|
737 |
+
gr.HTML(f"""
|
738 |
+
<div style="background: #e8f5e8; padding: 1rem; border-radius: 8px; margin-top: 1rem;">
|
739 |
+
<h4>β‘ GPU Optimization Info</h4>
|
740 |
+
<ul>
|
741 |
+
<li><strong>Device:</strong> {DEVICE.upper()}</li>
|
742 |
+
<li><strong>Mixed Precision:</strong> {'Enabled' if DEVICE == 'cuda' else 'Disabled'}</li>
|
743 |
+
<li><strong>Memory Management:</strong> Automatic cleanup</li>
|
744 |
+
<li><strong>Batch Processing:</strong> Optimized for GPU</li>
|
745 |
+
</ul>
|
746 |
+
</div>
|
747 |
+
""")
|
748 |
|
749 |
+
search_btn = gr.Button("π Search Papers", variant="primary", size="lg")
|
750 |
|
751 |
with gr.Column(scale=1):
|
752 |
gr.HTML("""
|
753 |
<div style="background: #e3f2fd; padding: 1rem; border-radius: 8px;">
|
754 |
+
<h4>π‘ Tips for Best Results</h4>
|
755 |
<ul>
|
756 |
<li>Use specific technical terms</li>
|
757 |
<li>Try different category filters</li>
|
758 |
<li>Adjust semantic weight for different search styles</li>
|
759 |
+
<li>Higher semantic weight = more conceptual matching</li>
|
760 |
+
<li>Lower semantic weight = more keyword matching</li>
|
761 |
</ul>
|
762 |
|
763 |
+
<h4>π Popular Categories</h4>
|
764 |
<ul>
|
765 |
<li><code>cs.AI</code> - Artificial Intelligence</li>
|
766 |
<li><code>cs.CL</code> - Computation and Language</li>
|
767 |
<li><code>cs.LG</code> - Machine Learning</li>
|
768 |
<li><code>cs.CV</code> - Computer Vision</li>
|
769 |
+
<li><code>cs.RO</code> - Robotics</li>
|
770 |
+
<li><code>stat.ML</code> - Machine Learning (Stats)</li>
|
771 |
</ul>
|
772 |
</div>
|
773 |
""")
|
|
|
780 |
papers_output = gr.Markdown(label="Relevant Papers")
|
781 |
|
782 |
with gr.TabItem("π Papers Table"):
|
783 |
+
papers_table = gr.Dataframe(
|
784 |
+
label="Papers Summary",
|
785 |
+
wrap=True,
|
786 |
+
interactive=False
|
787 |
+
)
|
788 |
|
789 |
# Examples
|
790 |
gr.Examples(
|
791 |
examples=[
|
792 |
["transformer attention mechanisms", 15, 10, 5, "cs.CL, cs.AI", 0.7],
|
793 |
+
["graph neural networks for molecular property prediction", 12, 8, 4, "cs.LG", 0.6],
|
794 |
["computer vision deep learning", 15, 10, 5, "cs.CV", 0.8],
|
795 |
+
["reinforcement learning robotics", 18, 10, 5, "cs.AI, cs.RO", 0.7],
|
796 |
+
["large language models fine-tuning", 20, 12, 6, "cs.CL", 0.75]
|
797 |
],
|
798 |
inputs=[query_input, max_papers, top_k_retrieval, top_k_rerank, categories_input, semantic_weight]
|
799 |
)
|
|
|
807 |
|
808 |
gr.HTML("""
|
809 |
<div style="text-align: center; margin-top: 2rem; padding: 1rem; background: #f5f5f5; border-radius: 8px;">
|
810 |
+
<p><strong>Enhanced ArXiv RAG System</strong> | GPU-Optimized β’ Semantic Search + BM25 + Neural Reranking</p>
|
811 |
+
<p><em>Powered by Hugging Face Spaces GPU β’ Optimized for high-performance research</em></p>
|
812 |
</div>
|
813 |
""")
|
814 |
|
|
|
816 |
|
817 |
# Launch interface
|
818 |
if __name__ == "__main__":
    # Pre-initialize system to reduce first-run latency. This is best
    # effort: on failure the error is logged and the UI still starts, and
    # search_papers() retries initialization on the first request.
    try:
        initialize_system()
    except Exception as e:
        logger.error(f"Pre-initialization failed: {e}")

    interface = create_interface()

    # Enable request queuing through queue() rather than the deprecated
    # launch(enable_queue=True) keyword, which newer Gradio releases reject.
    interface.queue()
    interface.launch(
        show_error=True,
        share=True,
        max_threads=4
    )
|