mihirinamdar committed on
Commit
da1b347
·
verified ·
1 Parent(s): 82723d8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +614 -487
app.py CHANGED
@@ -1,5 +1,5 @@
1
  """
2
- Enhanced ArXiv RAG System - Hugging Face Spaces Compatible Version
3
  """
4
 
5
  import os
@@ -13,12 +13,16 @@ from datetime import datetime, timedelta
13
  import logging
14
  import tempfile
15
  import shutil
 
 
16
 
17
  # Core ML libraries
18
  import torch
 
19
  from sentence_transformers import SentenceTransformer, CrossEncoder
20
- from transformers import pipeline
21
  import gradio as gr
 
22
 
23
  # BM25 and text processing
24
  from sklearn.feature_extraction.text import TfidfVectorizer
@@ -32,17 +36,25 @@ from nltk.stem import PorterStemmer
32
  try:
33
  nltk.data.find('tokenizers/punkt')
34
  except LookupError:
35
- nltk.download('punkt')
36
 
37
  try:
38
  nltk.data.find('corpora/stopwords')
39
  except LookupError:
40
- nltk.download('stopwords')
41
 
42
  # Setup logging
43
  logging.basicConfig(level=logging.INFO)
44
  logger = logging.getLogger(__name__)
45
 
 
 
 
 
 
 
 
 
46
  @dataclass
47
  class Paper:
48
  """Data class for storing paper information"""
@@ -63,9 +75,36 @@ class Chunk:
63
  chunk_type: str
64
  metadata: Dict[str, Any]
65
 
66
- class BM25Retriever:
67
- """BM25 retriever for keyword-based search"""
 
 
 
 
 
 
 
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  def __init__(self, k1: float = 1.5, b: float = 0.75):
70
  self.k1 = k1
71
  self.b = b
@@ -77,542 +116,598 @@ class BM25Retriever:
77
  self.stop_words = set(stopwords.words('english'))
78
  except:
79
  self.stop_words = set()
80
-
81
  def preprocess_text(self, text: str) -> List[str]:
82
  """Preprocess text for BM25"""
83
- tokens = word_tokenize(text.lower())
84
- processed_tokens = [
85
- self.stemmer.stem(token)
86
- for token in tokens
87
- if token.isalpha() and token not in self.stop_words
88
- ]
89
- return processed_tokens
90
-
 
 
 
 
91
  def fit(self, documents: List[str]):
92
- """Fit BM25 on documents"""
93
- self.documents = [self.preprocess_text(doc) for doc in documents]
94
- self.doc_lengths = [len(doc) for doc in self.documents]
95
- self.avg_doc_length = sum(self.doc_lengths) / len(self.doc_lengths) if self.doc_lengths else 0
96
-
97
- vocab = set()
98
- for doc in self.documents:
99
- vocab.update(doc)
100
- self.vocab = list(vocab)
101
-
102
- self.term_freqs = []
103
- for doc in self.documents:
104
- tf = {}
105
- for term in doc:
106
- tf[term] = tf.get(term, 0) + 1
107
- self.term_freqs.append(tf)
108
-
109
- self.idf = {}
110
- for term in self.vocab:
111
- containing_docs = sum(1 for tf in self.term_freqs if term in tf)
112
- self.idf[term] = np.log((len(self.documents) - containing_docs + 0.5) / (containing_docs + 0.5))
113
-
114
- def score(self, query: str, top_k: int = 10) -> List[Tuple[int, float]]:
115
- """Score documents against query"""
116
- query_terms = self.preprocess_text(query)
117
- scores = []
118
-
119
- for i, (doc, tf, doc_len) in enumerate(zip(self.documents, self.term_freqs, self.doc_lengths)):
120
- score = 0
121
- for term in query_terms:
122
- if term in tf:
123
- term_freq = tf[term]
124
- idf = self.idf.get(term, 0)
125
- numerator = term_freq * (self.k1 + 1)
126
- denominator = term_freq + self.k1 * (1 - self.b + self.b * (doc_len / self.avg_doc_length))
127
- score += idf * (numerator / denominator)
128
- scores.append((i, score))
129
-
130
- scores.sort(key=lambda x: x[1], reverse=True)
131
- return scores[:top_k]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
class SimpleVectorStore:
    """Simple in-memory vector store for HF Spaces compatibility.

    Keeps four parallel lists (ids, embeddings, documents, metadatas); the
    entry at position ``i`` of each list describes the same item.
    """

    def __init__(self):
        self.embeddings = []
        self.documents = []
        self.metadatas = []
        self.ids = []

    def add(self, ids: List[str], embeddings: List[List[float]],
            documents: List[str], metadatas: List[Dict]):
        """Append items to the store; the four lists must be aligned by position."""
        self.ids.extend(ids)
        self.embeddings.extend(embeddings)
        self.documents.extend(documents)
        self.metadatas.extend(metadatas)

    def query(self, query_embedding: List[float], n_results: int = 10) -> Dict:
        """Return the ``n_results`` stored items most cosine-similar to the query.

        The result mirrors the nested-list layout used by the callers:
        ``{'ids': [[...]], 'documents': [[...]], 'metadatas': [[...]]}``.
        An empty store yields empty inner lists.
        """
        if not self.embeddings:
            return {'ids': [[]], 'documents': [[]], 'metadatas': [[]]}

        query_vector = np.asarray(query_embedding, dtype=float)
        matrix = np.asarray(self.embeddings, dtype=float)

        # One vectorised pass instead of a Python loop per stored embedding.
        dots = matrix @ query_vector
        norms = np.linalg.norm(matrix, axis=1) * np.linalg.norm(query_vector)

        # A zero-length vector would give 0/0 = NaN, and NaN sorts to the end of
        # argsort -- i.e. to the *front* after reversing. Score such pairs as 0.
        similarities = np.divide(
            dots, norms, out=np.zeros_like(dots), where=norms > 0
        )

        top_indices = np.argsort(similarities)[::-1][:n_results]

        return {
            'ids': [[self.ids[i] for i in top_indices]],
            'documents': [[self.documents[i] for i in top_indices]],
            'metadatas': [[self.metadatas[i] for i in top_indices]]
        }

    def get(self, ids: Optional[List[str]] = None) -> Dict:
        """Return items by id, or every item when ``ids`` is None.

        Unknown ids are skipped silently; results follow the order of the
        requested ids. If an id was stored more than once, its first entry
        is returned.
        """
        if ids is None:
            # Copies, so callers cannot mutate the store through the result.
            return {
                'ids': list(self.ids),
                'documents': list(self.documents),
                'metadatas': list(self.metadatas)
            }

        # Build the id -> position map once rather than calling list.index per id.
        first_index = {}
        for position, stored_id in enumerate(self.ids):
            first_index.setdefault(stored_id, position)

        indices = [first_index[id_] for id_ in ids if id_ in first_index]
        return {
            'ids': [self.ids[i] for i in indices],
            'documents': [self.documents[i] for i in indices],
            'metadatas': [self.metadatas[i] for i in indices]
        }

    def clear(self):
        """Remove every stored item."""
        self.embeddings.clear()
        self.documents.clear()
        self.metadatas.clear()
        self.ids.clear()
197
-
198
- class EnhancedArxivRAG:
199
- """Enhanced RAG system optimized for Hugging Face Spaces"""
200
-
201
- def __init__(self):
202
- logger.info("Initializing Enhanced ArXiv RAG System for HF Spaces...")
203
-
204
- # Use smaller, faster models for HF Spaces
205
- self.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
206
- self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-2-v2') # Smaller reranker
207
- self.summarizer = pipeline("summarization", model="facebook/bart-large-cnn",
208
- device=0 if torch.cuda.is_available() else -1)
209
-
210
- # Use simple vector store instead of ChromaDB for HF Spaces
211
- self.vector_store = SimpleVectorStore()
212
- self.bm25_retriever = BM25Retriever()
213
-
214
- # Cache for papers and chunks
215
- self.papers_cache = {}
216
- self.chunks_cache = {}
217
- self.bm25_fitted = False
218
-
219
- logger.info("RAG system initialized successfully!")
220
-
221
- def fetch_papers(self, query: str, max_results: int = 15,
222
- categories: Optional[List[str]] = None) -> List[Paper]:
223
- """Fetch papers from ArXiv"""
224
- search_query = query
225
- if categories:
226
- category_filter = " OR ".join([f"cat:{cat}" for cat in categories])
227
- search_query = f"({query}) AND ({category_filter})"
228
-
229
- logger.info(f"Fetching papers with query: {search_query}")
230
-
231
  try:
 
 
 
 
 
 
 
 
 
232
  search = arxiv.Search(
233
  query=search_query,
234
  max_results=max_results,
235
- sort_by=arxiv.SortCriterion.Relevance
 
236
  )
237
 
238
- papers = []
239
  for result in search.results():
240
- paper = Paper(
241
- id=result.entry_id.split('/')[-1],
242
- title=result.title.strip().replace('\n', ' '),
243
- abstract=result.summary.strip().replace('\n', ' '),
244
- authors=[author.name for author in result.authors],
245
- categories=result.categories,
246
- published=result.published.replace(tzinfo=None),
247
- url=result.entry_id
248
- )
249
- papers.append(paper)
250
- self.papers_cache[paper.id] = paper
 
 
 
 
 
 
 
251
 
252
- logger.info(f"Fetched {len(papers)} papers")
253
  return papers
254
 
255
  except Exception as e:
256
- logger.error(f"Error fetching papers: {e}")
257
  return []
258
-
259
  def create_chunks(self, papers: List[Paper]) -> List[Chunk]:
260
  """Create text chunks from papers"""
261
  chunks = []
262
 
263
  for paper in papers:
264
- # Title chunk
265
- title_chunk = Chunk(
266
- id=f"{paper.id}_title",
267
- paper_id=paper.id,
268
- text=paper.title,
269
- chunk_type="title",
270
- metadata={
271
- "authors": paper.authors,
272
- "categories": paper.categories,
273
- "published": paper.published.isoformat(),
274
- "url": paper.url
275
- }
276
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
 
278
- # Abstract chunk
279
- abstract_chunk = Chunk(
280
- id=f"{paper.id}_abstract",
281
- paper_id=paper.id,
282
- text=paper.abstract,
283
- chunk_type="abstract",
284
- metadata={
285
- "authors": paper.authors,
286
- "categories": paper.categories,
287
- "published": paper.published.isoformat(),
288
- "url": paper.url
289
- }
290
- )
291
 
292
- # Combined chunk
293
- combined_text = f"Title: {paper.title}\n\nAbstract: {paper.abstract}"
294
- combined_chunk = Chunk(
295
- id=f"{paper.id}_combined",
296
- paper_id=paper.id,
297
- text=combined_text,
298
- chunk_type="combined",
299
- metadata={
300
- "authors": paper.authors,
301
- "categories": paper.categories,
302
- "published": paper.published.isoformat(),
303
- "url": paper.url
304
- }
305
- )
306
 
307
- chunks.extend([title_chunk, abstract_chunk, combined_chunk])
 
 
308
 
309
- # Cache chunks
310
- for chunk in [title_chunk, abstract_chunk, combined_chunk]:
311
- self.chunks_cache[chunk.id] = chunk
312
-
313
- return chunks
314
-
315
- def process_and_store(self, papers: List[Paper]):
316
- """Process papers and store in vector store"""
317
- logger.info("Processing and storing papers...")
318
-
319
- # Clear previous data
320
- self.vector_store.clear()
321
-
322
- # Create chunks
323
- chunks = self.create_chunks(papers)
324
-
325
- if not chunks:
326
- return
327
-
328
- # Generate embeddings
329
- texts = [chunk.text for chunk in chunks]
330
- logger.info("Generating embeddings...")
331
- embeddings = self.embedding_model.encode(texts, show_progress_bar=False)
332
-
333
- # Store in vector store
334
- ids = [chunk.id for chunk in chunks]
335
- metadatas = [chunk.metadata for chunk in chunks]
336
-
337
- self.vector_store.add(
338
- ids=ids,
339
- embeddings=embeddings.tolist(),
340
- documents=texts,
341
- metadatas=metadatas
342
- )
343
-
344
- # Fit BM25
345
- logger.info("Fitting BM25...")
346
- self.bm25_retriever.fit(texts)
347
- self.bm25_fitted = True
348
-
349
- logger.info(f"Stored {len(chunks)} chunks")
350
-
351
- def hybrid_search(self, query: str, top_k: int = 10,
352
- semantic_weight: float = 0.7) -> List[Dict]:
353
- """Perform hybrid search"""
354
- # Semantic search
355
- query_embedding = self.embedding_model.encode([query])
356
- semantic_results = self.vector_store.query(
357
- query_embedding=query_embedding[0].tolist(),
358
- n_results=top_k * 2
359
- )
360
-
361
- # BM25 search
362
- bm25_results = []
363
- if self.bm25_fitted:
364
- all_docs = self.vector_store.get()
365
- bm25_scores = self.bm25_retriever.score(query, top_k * 2)
366
-
367
- for idx, score in bm25_scores:
368
- if idx < len(all_docs['ids']):
369
- bm25_results.append({
370
- 'id': all_docs['ids'][idx],
371
- 'document': all_docs['documents'][idx],
372
- 'metadata': all_docs['metadatas'][idx],
373
- 'score': score
374
- })
375
-
376
- # Combine results using RRF
377
- combined_scores = {}
378
- bm25_weight = 1.0 - semantic_weight
379
-
380
- # Add semantic scores
381
- for i, doc_id in enumerate(semantic_results['ids'][0]):
382
- rank = i + 1
383
- combined_scores[doc_id] = combined_scores.get(doc_id, 0) + semantic_weight / rank
384
-
385
- # Add BM25 scores
386
- for i, result in enumerate(bm25_results):
387
- doc_id = result['id']
388
- rank = i + 1
389
- combined_scores[doc_id] = combined_scores.get(doc_id, 0) + bm25_weight / rank
390
-
391
- # Sort by combined score
392
- sorted_results = sorted(combined_scores.items(), key=lambda x: x[1], reverse=True)
393
-
394
- # Prepare final results
395
- final_results = []
396
- for doc_id, score in sorted_results[:top_k]:
397
- doc_result = self.vector_store.get(ids=[doc_id])
398
- if doc_result['ids']:
399
- final_results.append({
400
- 'id': doc_id,
401
- 'document': doc_result['documents'][0],
402
- 'metadata': doc_result['metadatas'][0],
403
- 'combined_score': score
404
- })
405
-
406
- return final_results
407
-
408
- def rerank_results(self, query: str, results: List[Dict], top_k: int = 5) -> List[Dict]:
409
- """Rerank results using cross-encoder"""
410
- if not results:
411
  return results
412
-
413
- # Prepare query-document pairs
414
- query_doc_pairs = [(query, result['document']) for result in results]
415
-
416
- # Get reranking scores
417
- rerank_scores = self.reranker.predict(query_doc_pairs)
418
-
419
- # Add rerank scores to results
420
- for i, result in enumerate(results):
421
- result['rerank_score'] = float(rerank_scores[i])
422
-
423
- # Sort by rerank score
424
- reranked_results = sorted(results, key=lambda x: x['rerank_score'], reverse=True)
425
- return reranked_results[:top_k]
426
-
427
- def generate_answer(self, query: str, context_chunks: List[Dict]) -> str:
428
- """Generate answer using retrieved context"""
429
- if not context_chunks:
430
- return "No relevant information found to answer your query."
431
-
432
- # Combine context from top chunks
433
- context_texts = [chunk['document'] for chunk in context_chunks[:3]]
434
- combined_context = "\n\n".join(context_texts)
435
-
436
- # Limit context length
437
- max_context_length = 800
438
- if len(combined_context) > max_context_length:
439
- combined_context = combined_context[:max_context_length] + "..."
440
-
441
  try:
442
- summary_input = f"Based on the following research papers, answer this question: {query}\n\nContext: {combined_context}"
443
- summary = self.summarizer(summary_input,
444
- max_length=120,
445
- min_length=30,
446
- do_sample=False)[0]['summary_text']
447
- return summary
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
448
  except Exception as e:
449
- logger.error(f"Error generating summary: {e}")
450
- return f"Based on the retrieved papers about '{query}', here are the key findings:\n\n" + \
451
- "\n\n".join([chunk['document'][:150] + "..." for chunk in context_chunks[:2]])
452
-
453
- def search_and_answer(self, query: str, max_papers: int = 15,
454
- top_k_retrieval: int = 10, top_k_rerank: int = 5,
455
- categories: Optional[List[str]] = None,
456
- semantic_weight: float = 0.7) -> Dict[str, Any]:
457
- """Main search and answer pipeline"""
458
-
459
- if not query.strip():
460
- return {
461
- 'answer': "Please enter a valid research query.",
462
- 'papers': [],
463
- 'retrieved_chunks': [],
464
- 'search_stats': {'papers_found': 0, 'chunks_retrieved': 0}
465
- }
466
-
467
  try:
468
- # Fetch papers
469
- papers = self.fetch_papers(query, max_papers, categories)
470
-
471
- if not papers:
472
- return {
473
- 'answer': "No papers found for your query. Please try different keywords.",
474
- 'papers': [],
475
- 'retrieved_chunks': [],
476
- 'search_stats': {'papers_found': 0, 'chunks_retrieved': 0}
477
- }
478
-
479
- # Process and store papers
480
- self.process_and_store(papers)
481
-
482
- # Hybrid search
483
- search_results = self.hybrid_search(query, top_k_retrieval, semantic_weight)
484
-
485
- # Rerank results
486
- reranked_results = self.rerank_results(query, search_results, top_k_rerank)
487
-
488
- # Generate answer
489
- answer = self.generate_answer(query, reranked_results)
490
-
491
- # Prepare unique papers
492
- unique_papers = {}
493
- for chunk in reranked_results:
494
- paper_id = chunk['id'].split('_')[0]
495
- if paper_id in self.papers_cache and paper_id not in unique_papers:
496
- paper = self.papers_cache[paper_id]
497
- unique_papers[paper_id] = {
498
- 'title': paper.title,
499
- 'authors': paper.authors,
500
- 'abstract': paper.abstract,
501
- 'url': paper.url,
502
- 'categories': paper.categories,
503
- 'published': paper.published.strftime('%Y-%m-%d')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
504
  }
 
 
 
505
 
506
- return {
507
- 'answer': answer,
508
- 'papers': list(unique_papers.values()),
509
- 'retrieved_chunks': reranked_results,
510
- 'search_stats': {
511
- 'papers_found': len(papers),
512
- 'chunks_retrieved': len(reranked_results),
513
- 'unique_papers_in_results': len(unique_papers)
514
- }
515
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
516
 
517
  except Exception as e:
518
- logger.error(f"Error in search_and_answer: {e}")
519
- return {
520
- 'answer': f"An error occurred while processing your query: {str(e)}",
521
- 'papers': [],
522
- 'retrieved_chunks': [],
523
- 'search_stats': {'papers_found': 0, 'chunks_retrieved': 0}
524
- }
525
-
526
- # Global RAG instance
527
  rag_system = None
528
 
529
- def initialize_rag():
530
- """Initialize RAG system"""
531
  global rag_system
532
- if rag_system is None:
533
- rag_system = EnhancedArxivRAG()
534
- return rag_system
535
-
536
- def search_papers(query: str, max_papers: int = 15, top_k_retrieval: int = 10,
537
- top_k_rerank: int = 5, categories: str = "",
538
- semantic_weight: float = 0.7) -> tuple:
539
- """Main search function for Gradio interface"""
540
-
541
- if not query.strip():
542
- return "❌ Please enter a research topic or question.", "", ""
543
-
544
  try:
545
- # Initialize RAG system
546
- rag = initialize_rag()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
547
 
548
  # Parse categories
549
- category_list = None
550
  if categories.strip():
551
  category_list = [cat.strip() for cat in categories.split(',') if cat.strip()]
552
 
553
- # Perform search
554
- result = rag.search_and_answer(
555
- query=query,
556
- max_papers=max_papers,
557
- top_k_retrieval=top_k_retrieval,
558
- top_k_rerank=top_k_rerank,
559
- categories=category_list,
560
- semantic_weight=semantic_weight
561
- )
 
 
 
 
 
 
 
 
 
562
 
563
- # Format answer
564
- answer = f"## πŸ€– AI-Generated Answer\n\n{result['answer']}\n\n"
565
- answer += f"**Search Statistics:**\n"
566
- answer += f"- Papers found: {result['search_stats']['papers_found']}\n"
567
- answer += f"- Chunks retrieved: {result['search_stats']['chunks_retrieved']}\n"
568
- answer += f"- Unique papers in results: {result['search_stats']['unique_papers_in_results']}\n\n"
 
 
 
 
 
 
569
 
570
- # Format papers
571
- papers_md = "## πŸ“š Relevant Papers\n\n"
572
- for i, paper in enumerate(result['papers'], 1):
573
- papers_md += f"### {i}. {paper['title']}\n\n"
574
- papers_md += f"**Authors:** {', '.join(paper['authors'][:3])}{'...' if len(paper['authors']) > 3 else ''}\n\n"
575
- papers_md += f"**Categories:** {', '.join(paper['categories'])}\n\n"
576
- papers_md += f"**Published:** {paper['published']}\n\n"
577
- papers_md += f"**Abstract:** {paper['abstract'][:250]}{'...' if len(paper['abstract']) > 250 else ''}\n\n"
578
- papers_md += f"**URL:** [{paper['url']}]({paper['url']})\n\n"
579
- papers_md += "---\n\n"
580
 
581
- # Create papers dataframe
582
- papers_df = pd.DataFrame([
583
- {
584
- 'Title': paper['title'][:50] + '...' if len(paper['title']) > 50 else paper['title'],
585
- 'Authors': ', '.join(paper['authors'][:2]) + ('...' if len(paper['authors']) > 2 else ''),
586
- 'Categories': ', '.join(paper['categories'][:2]),
587
- 'Published': paper['published'],
588
- 'URL': paper['url']
589
- }
590
- for paper in result['papers']
591
- ])
592
 
593
- return answer, papers_md, papers_df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
594
 
595
  except Exception as e:
596
- logger.error(f"Error processing query: {e}")
597
  error_msg = f"❌ An error occurred: {str(e)}\n\nPlease try different keywords or check your internet connection."
598
  return error_msg, "", pd.DataFrame()
599
 
600
  # Create Gradio interface
601
  def create_interface():
602
- """Create Gradio interface"""
603
 
604
  css = """
605
  .gradio-container {
606
  font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
607
  }
 
 
 
 
 
 
 
 
 
608
  """
609
 
610
- with gr.Blocks(css=css, title="Enhanced ArXiv RAG System") as interface:
611
 
612
- gr.HTML("""
613
  <div style="text-align: center; background: linear-gradient(90deg, #667eea 0%, #764ba2 100%); color: white; padding: 2rem; border-radius: 10px; margin-bottom: 2rem;">
614
  <h1>πŸš€ Enhanced ArXiv RAG System</h1>
615
- <p>Advanced scientific paper discovery with semantic search, BM25, and neural reranking</p>
 
 
 
616
  </div>
617
  """)
618
 
@@ -634,29 +729,45 @@ def create_interface():
634
  value=""
635
  )
636
 
637
- with gr.Accordion("Advanced Settings", open=False):
638
  with gr.Row():
639
  top_k_retrieval = gr.Slider(5, 15, value=10, step=1, label="Top-K Retrieval")
640
  top_k_rerank = gr.Slider(3, 8, value=5, step=1, label="Top-K Reranking")
 
 
 
 
 
 
 
 
 
 
 
 
641
 
642
- search_btn = gr.Button("πŸ” Search Papers", variant="primary")
643
 
644
  with gr.Column(scale=1):
645
  gr.HTML("""
646
  <div style="background: #e3f2fd; padding: 1rem; border-radius: 8px;">
647
- <h4>πŸ’‘ Tips</h4>
648
  <ul>
649
  <li>Use specific technical terms</li>
650
  <li>Try different category filters</li>
651
  <li>Adjust semantic weight for different search styles</li>
 
 
652
  </ul>
653
 
654
- <h4>πŸ“Š Categories</h4>
655
  <ul>
656
  <li><code>cs.AI</code> - Artificial Intelligence</li>
657
  <li><code>cs.CL</code> - Computation and Language</li>
658
  <li><code>cs.LG</code> - Machine Learning</li>
659
  <li><code>cs.CV</code> - Computer Vision</li>
 
 
660
  </ul>
661
  </div>
662
  """)
@@ -669,15 +780,20 @@ def create_interface():
669
  papers_output = gr.Markdown(label="Relevant Papers")
670
 
671
  with gr.TabItem("πŸ“Š Papers Table"):
672
- papers_table = gr.Dataframe(label="Papers Summary")
 
 
 
 
673
 
674
  # Examples
675
  gr.Examples(
676
  examples=[
677
  ["transformer attention mechanisms", 15, 10, 5, "cs.CL, cs.AI", 0.7],
678
- ["graph neural networks", 12, 8, 4, "cs.LG", 0.6],
679
  ["computer vision deep learning", 15, 10, 5, "cs.CV", 0.8],
680
- ["reinforcement learning", 18, 10, 5, "cs.AI", 0.7]
 
681
  ],
682
  inputs=[query_input, max_papers, top_k_retrieval, top_k_rerank, categories_input, semantic_weight]
683
  )
@@ -691,7 +807,8 @@ def create_interface():
691
 
692
  gr.HTML("""
693
  <div style="text-align: center; margin-top: 2rem; padding: 1rem; background: #f5f5f5; border-radius: 8px;">
694
- <p><strong>Enhanced ArXiv RAG System</strong> | Semantic Search + BM25 + Neural Reranking</p>
 
695
  </div>
696
  """)
697
 
@@ -699,6 +816,16 @@ def create_interface():
699
 
700
  # Launch interface
701
  if __name__ == "__main__":
 
 
 
 
 
 
702
  interface = create_interface()
703
- interface.launch()
704
-
 
 
 
 
 
1
  """
2
+ Enhanced ArXiv RAG System - GPU Optimized for Hugging Face Spaces
3
  """
4
 
5
  import os
 
13
  import logging
14
  import tempfile
15
  import shutil
16
+ import gc
17
+ import time
18
 
19
  # Core ML libraries
20
  import torch
21
+ import torch.nn.functional as F
22
  from sentence_transformers import SentenceTransformer, CrossEncoder
23
+ from transformers import pipeline, AutoTokenizer, AutoModel
24
  import gradio as gr
25
+ import spaces # HuggingFace Spaces GPU support
26
 
27
  # BM25 and text processing
28
  from sklearn.feature_extraction.text import TfidfVectorizer
 
36
  try:
37
  nltk.data.find('tokenizers/punkt')
38
  except LookupError:
39
+ nltk.download('punkt', quiet=True)
40
 
41
  try:
42
  nltk.data.find('corpora/stopwords')
43
  except LookupError:
44
+ nltk.download('stopwords', quiet=True)
45
 
46
  # Setup logging
47
  logging.basicConfig(level=logging.INFO)
48
  logger = logging.getLogger(__name__)
49
 
50
# GPU Configuration: choose the compute device once, at import time.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
logger.info(f"Using device: {DEVICE}")

# On CUDA, also log which GPU was picked up and its total memory in GB.
if DEVICE == "cuda":
    logger.info(f"GPU: {torch.cuda.get_device_name()}")
    logger.info(
        f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB"
    )
57
+
58
  @dataclass
59
  class Paper:
60
  """Data class for storing paper information"""
 
75
  chunk_type: str
76
  metadata: Dict[str, Any]
77
 
78
class GPUMemoryManager:
    """Static helpers for releasing and inspecting CUDA memory."""

    @staticmethod
    def clear_cache():
        """Collect garbage, then release cached CUDA memory if CUDA is present.

        The collector runs first so that blocks owned by unreachable tensors
        are handed back to the caching allocator before the cache is emptied;
        in the opposite order those blocks would stay cached until the next call.
        """
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    @staticmethod
    def get_memory_info():
        """Return a one-line summary of CUDA memory use, or "CPU mode" without CUDA."""
        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated() / 1e9
            cached = torch.cuda.memory_reserved() / 1e9
            return f"Allocated: {allocated:.1f}GB, Cached: {cached:.1f}GB"
        return "CPU mode"

    @staticmethod
    def optimize_memory():
        """Empty the CUDA cache and enable cuDNN benchmark mode and TF32 matmul.

        Does nothing when CUDA is unavailable. The two backend flags are
        process-wide settings.
        """
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.backends.cudnn.benchmark = True
            torch.backends.cuda.matmul.allow_tf32 = True
104
+
105
+ class BM25Retriever:
106
+ """Optimized BM25 retriever for keyword-based search"""
107
+
108
  def __init__(self, k1: float = 1.5, b: float = 0.75):
109
  self.k1 = k1
110
  self.b = b
 
116
  self.stop_words = set(stopwords.words('english'))
117
  except:
118
  self.stop_words = set()
119
+
120
  def preprocess_text(self, text: str) -> List[str]:
121
  """Preprocess text for BM25"""
122
+ try:
123
+ tokens = word_tokenize(text.lower())
124
+ processed_tokens = [
125
+ self.stemmer.stem(token)
126
+ for token in tokens
127
+ if token.isalpha() and token not in self.stop_words and len(token) > 2
128
+ ]
129
+ return processed_tokens
130
+ except Exception as e:
131
+ logger.warning(f"Text preprocessing error: {e}")
132
+ return text.lower().split()
133
+
134
  def fit(self, documents: List[str]):
135
+ """Fit BM25 on documents with memory optimization"""
136
+ try:
137
+ self.documents = [self.preprocess_text(doc) for doc in documents]
138
+ self.doc_lengths = [len(doc) for doc in self.documents]
139
+ self.avg_doc_length = sum(self.doc_lengths) / len(self.doc_lengths) if self.doc_lengths else 0
140
+
141
+ # Build vocabulary
142
+ vocab = set()
143
+ for doc in self.documents:
144
+ vocab.update(doc)
145
+ self.vocab = list(vocab)
146
+
147
+ # Calculate term frequencies
148
+ self.term_freqs = []
149
+ for doc in self.documents:
150
+ tf = {}
151
+ for term in doc:
152
+ tf[term] = tf.get(term, 0) + 1
153
+ self.term_freqs.append(tf)
154
+
155
+ # Calculate IDF
156
+ self.idf = {}
157
+ for term in self.vocab:
158
+ df = sum(1 for doc in self.documents if term in doc)
159
+ self.idf[term] = np.log((len(self.documents) - df + 0.5) / (df + 0.5))
160
+
161
+ except Exception as e:
162
+ logger.error(f"BM25 fitting error: {e}")
163
+ self.documents = []
164
+
165
+ def get_scores(self, query: str) -> np.ndarray:
166
+ """Get BM25 scores for query"""
167
+ try:
168
+ query_terms = self.preprocess_text(query)
169
+ scores = np.zeros(len(self.documents))
170
+
171
+ for i, doc_tf in enumerate(self.term_freqs):
172
+ score = 0
173
+ doc_length = self.doc_lengths[i]
174
+
175
+ for term in query_terms:
176
+ if term in doc_tf:
177
+ tf = doc_tf[term]
178
+ idf = self.idf.get(term, 0)
179
+ score += idf * (tf * (self.k1 + 1)) / (
180
+ tf + self.k1 * (1 - self.b + self.b * (doc_length / self.avg_doc_length))
181
+ )
182
+
183
+ scores[i] = score
184
+
185
+ return scores
186
+ except Exception as e:
187
+ logger.error(f"BM25 scoring error: {e}")
188
+ return np.zeros(len(self.documents))
189
+
190
+ class OptimizedRagSystem:
191
+ """GPU-optimized RAG system for ArXiv papers"""
192
 
 
 
 
193
  def __init__(self):
194
+ self.papers = []
195
+ self.chunks = []
196
+ self.embeddings = None
197
+ self.embedding_model = None
198
+ self.reranker = None
199
+ self.bm25 = BM25Retriever()
200
+ self.generator = None
201
+ self.memory_manager = GPUMemoryManager()
202
+
203
+ # Initialize models
204
+ self._load_models()
205
+
206
+ def _load_models(self):
207
+ """Load models with GPU optimization"""
208
+ try:
209
+ logger.info("Loading models...")
210
+
211
+ # Load embedding model
212
+ self.embedding_model = SentenceTransformer(
213
+ 'sentence-transformers/all-MiniLM-L6-v2',
214
+ device=DEVICE
 
 
 
 
 
215
  )
216
+
217
+ # Optimize for GPU if available
218
+ if DEVICE == "cuda":
219
+ self.embedding_model.half() # Use FP16 for memory efficiency
220
+
221
+ # Load reranker (smaller model for efficiency)
222
+ self.reranker = CrossEncoder(
223
+ 'cross-encoder/ms-marco-MiniLM-L-6-v2',
224
+ device=DEVICE
225
+ )
226
+
227
+ # Load text generator with optimization
228
+ self.generator = pipeline(
229
+ "text-generation",
230
+ model="microsoft/DialoGPT-small", # Smaller model for efficiency
231
+ tokenizer="microsoft/DialoGPT-small",
232
+ device=0 if DEVICE == "cuda" else -1,
233
+ torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
234
+ return_full_text=False,
235
+ max_new_tokens=512,
236
+ do_sample=True,
237
+ temperature=0.7,
238
+ pad_token_id=50256
239
+ )
240
+
241
+ self.memory_manager.optimize_memory()
242
+ logger.info("Models loaded successfully")
243
+
244
+ except Exception as e:
245
+ logger.error(f"Model loading error: {e}")
246
+ raise
247
+
248
+ def search_arxiv(self, query: str, max_results: int = 15, categories: List[str] = None) -> List[Paper]:
249
+ """Search ArXiv with error handling and rate limiting"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
  try:
251
+ papers = []
252
+ search_query = query
253
+
254
+ if categories:
255
+ category_filter = " OR ".join([f"cat:{cat.strip()}" for cat in categories])
256
+ search_query = f"({query}) AND ({category_filter})"
257
+
258
+ logger.info(f"Searching ArXiv for: {search_query}")
259
+
260
  search = arxiv.Search(
261
  query=search_query,
262
  max_results=max_results,
263
+ sort_by=arxiv.SortCriterion.Relevance,
264
+ sort_order=arxiv.SortOrder.Descending
265
  )
266
 
 
267
  for result in search.results():
268
+ try:
269
+ paper = Paper(
270
+ id=result.entry_id.split('/')[-1],
271
+ title=result.title,
272
+ abstract=result.summary,
273
+ authors=[author.name for author in result.authors],
274
+ categories=result.categories,
275
+ published=result.published,
276
+ url=result.entry_id
277
+ )
278
+ papers.append(paper)
279
+
280
+ # Rate limiting
281
+ time.sleep(0.1)
282
+
283
+ except Exception as e:
284
+ logger.warning(f"Error processing paper: {e}")
285
+ continue
286
 
287
+ logger.info(f"Found {len(papers)} papers")
288
  return papers
289
 
290
  except Exception as e:
291
+ logger.error(f"ArXiv search error: {e}")
292
  return []
293
+
294
  def create_chunks(self, papers: List[Paper]) -> List[Chunk]:
295
  """Create text chunks from papers"""
296
  chunks = []
297
 
298
  for paper in papers:
299
+ try:
300
+ # Title chunk
301
+ chunks.append(Chunk(
302
+ id=f"{paper.id}_title",
303
+ paper_id=paper.id,
304
+ text=paper.title,
305
+ chunk_type="title",
306
+ metadata={"paper": paper}
307
+ ))
308
+
309
+ # Abstract chunks (split if too long)
310
+ abstract_sentences = sent_tokenize(paper.abstract)
311
+ chunk_size = 3 # sentences per chunk
312
+
313
+ for i in range(0, len(abstract_sentences), chunk_size):
314
+ chunk_text = ' '.join(abstract_sentences[i:i + chunk_size])
315
+ chunks.append(Chunk(
316
+ id=f"{paper.id}_abstract_{i}",
317
+ paper_id=paper.id,
318
+ text=chunk_text,
319
+ chunk_type="abstract",
320
+ metadata={"paper": paper}
321
+ ))
322
+
323
+ except Exception as e:
324
+ logger.warning(f"Error creating chunks for paper {paper.id}: {e}")
325
+ continue
326
+
327
+ return chunks
328
+
329
@spaces.GPU(duration=120)  # HuggingFace Spaces GPU decorator
def embed_chunks(self, chunks: List[Chunk]) -> np.ndarray:
    """Encode chunk texts into a single embedding matrix.

    Texts are encoded in batches (32 per batch on CUDA, 8 otherwise) and the
    per-batch arrays are stacked row-wise.

    Args:
        chunks: Chunks whose ``text`` is embedded.

    Returns:
        The stacked embeddings, or an empty array when ``chunks`` is empty
        or encoding fails (the failure is logged, not raised).
    """
    try:
        if not chunks:
            return np.array([])

        logger.info(f"Creating embeddings for {len(chunks)} chunks")
        self.memory_manager.clear_cache()

        texts = [chunk.text for chunk in chunks]

        on_gpu = DEVICE == "cuda"
        batch_size = 32 if on_gpu else 8
        flush_interval = batch_size * 4
        batch_arrays = []

        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]

            # Mixed precision on CUDA; plain no-grad inference elsewhere.
            inference_ctx = torch.cuda.amp.autocast() if on_gpu else torch.no_grad()
            with inference_ctx:
                encoded = self.embedding_model.encode(
                    batch,
                    convert_to_tensor=True,
                    show_progress_bar=False,
                    batch_size=len(batch)
                )

            if on_gpu:
                encoded = encoded.cpu()

            batch_arrays.append(encoded.numpy())

            # Periodically release cached memory (every fourth batch,
            # including the first).
            if start % flush_interval == 0:
                self.memory_manager.clear_cache()

        if batch_arrays:
            result = np.vstack(batch_arrays)
        else:
            result = np.array([])
        self.memory_manager.clear_cache()

        logger.info(f"Created embeddings shape: {result.shape}")
        return result

    except Exception as e:
        logger.error(f"Embedding error: {e}")
        self.memory_manager.clear_cache()
        return np.array([])
375
+
376
@spaces.GPU(duration=60)  # HuggingFace Spaces GPU decorator
def hybrid_retrieval(self, query: str, top_k: int = 10, semantic_weight: float = 0.7) -> List[Tuple[Chunk, float]]:
    """Rank indexed chunks by a weighted blend of semantic and BM25 scores.

    Both score vectors are min-max normalized before being combined as
    ``semantic_weight * semantic + (1 - semantic_weight) * bm25``.

    Args:
        query: Search query text.
        top_k: Maximum number of chunks to return. Non-positive values
            yield an empty list.
        semantic_weight: Weight given to the embedding-similarity score.

    Returns:
        ``(chunk, combined_score)`` pairs, best first. An empty list is
        returned when nothing is indexed or when retrieval fails (the
        failure is logged, not raised).
    """
    def _min_max(scores):
        # Rescale so the smallest score maps to 0; the epsilon keeps a
        # constant vector from dividing by zero.
        if len(scores) == 0:
            return scores
        return (scores - scores.min()) / (scores.max() - scores.min() + 1e-8)

    try:
        if not self.chunks or self.embeddings is None or len(self.embeddings) == 0:
            return []

        # UI sliders may deliver floats; slicing needs an int.
        top_k = int(top_k)
        if top_k <= 0:
            return []

        self.memory_manager.clear_cache()

        # Semantic search
        with torch.cuda.amp.autocast() if DEVICE == "cuda" else torch.no_grad():
            query_embedding = self.embedding_model.encode(
                [query],
                convert_to_tensor=True,
                show_progress_bar=False
            )

        if DEVICE == "cuda":
            query_embedding = query_embedding.cpu()

        query_embedding = query_embedding.numpy()

        semantic_scores = cosine_similarity(query_embedding, self.embeddings)[0]

        # BM25 search. Coerce to an array so .min()/.max() and the
        # arithmetic below work whatever sequence type get_scores returns.
        bm25_scores = np.asarray(self.bm25.get_scores(query), dtype=float)

        # Ensure same length
        lengths = (len(semantic_scores), len(bm25_scores), len(self.chunks))
        min_length = min(lengths)
        if len(set(lengths)) > 1:
            # A mismatch usually means one index is stale; make it visible
            # instead of truncating silently.
            logger.warning(
                f"Score/chunk length mismatch (semantic, bm25, chunks)={lengths}; "
                f"truncating to {min_length}"
            )
        semantic_scores = semantic_scores[:min_length]
        bm25_scores = bm25_scores[:min_length]
        chunks = self.chunks[:min_length]

        # Normalize scores
        semantic_scores = _min_max(semantic_scores)
        bm25_scores = _min_max(bm25_scores)

        # Combine scores
        combined_scores = semantic_weight * semantic_scores + (1 - semantic_weight) * bm25_scores

        # Get top results
        top_indices = np.argsort(combined_scores)[::-1][:top_k]
        results = [(chunks[i], float(combined_scores[i])) for i in top_indices]

        self.memory_manager.clear_cache()
        return results

    except Exception as e:
        logger.error(f"Retrieval error: {e}")
        self.memory_manager.clear_cache()
        return []
429
+
430
@spaces.GPU(duration=60)  # HuggingFace Spaces GPU decorator
def rerank_results(self, query: str, results: List[Tuple[Chunk, float]], top_k: int = 5) -> List[Tuple[Chunk, float]]:
    """Re-score retrieved chunks with the cross-encoder and keep the best ``top_k``.

    Each chunk's final score is ``0.6 * cross_encoder_score + 0.4 *
    retrieval_score``.

    Args:
        query: Search query text.
        results: ``(chunk, retrieval_score)`` pairs from retrieval.
        top_k: Number of results to keep.

    Returns:
        The ``top_k`` highest-scoring ``(chunk, blended_score)`` pairs. When
        there are no results, no reranker is loaded, or scoring fails, the
        first ``top_k`` input pairs are returned unchanged.
    """
    try:
        if not (results and self.reranker):
            return results[:top_k]

        self.memory_manager.clear_cache()

        query_chunk_pairs = [(query, chunk.text) for chunk, _ in results]

        scoring_ctx = torch.cuda.amp.autocast() if DEVICE == "cuda" else torch.no_grad()
        with scoring_ctx:
            rerank_scores = self.reranker.predict(query_chunk_pairs, show_progress_bar=False)

        # Blend the cross-encoder score with the original retrieval score.
        # NOTE(review): assumes both scores are on comparable scales;
        # confirm for the configured reranker model.
        blended = [
            (chunk, 0.6 * float(rerank_scores[idx]) + 0.4 * original_score)
            for idx, (chunk, original_score) in enumerate(results)
        ]

        # Highest blended score first.
        ranked = sorted(blended, key=lambda pair: pair[1], reverse=True)

        self.memory_manager.clear_cache()
        return ranked[:top_k]

    except Exception as e:
        logger.error(f"Reranking error: {e}")
        self.memory_manager.clear_cache()
        return results[:top_k]
460
+
461
@spaces.GPU(duration=90)  # HuggingFace Spaces GPU decorator
def generate_answer(self, query: str, context_chunks: List[Chunk]) -> str:
    """Generate an answer to ``query`` grounded in the retrieved chunks.

    The prompt is built from the first three chunks that carry a paper in
    ``metadata["paper"]``; the context is capped at 2000 characters.

    Args:
        query: The user's question.
        context_chunks: Retrieved chunks, best first.

    Returns:
        The generated answer text. A fixed "no information" message is
        returned when there is no usable context or no generator, and an
        ``"Error generating answer: ..."`` message when generation fails.
    """
    no_info_message = "No relevant information found to answer your query."
    try:
        if not context_chunks or not self.generator:
            return no_info_message

        self.memory_manager.clear_cache()

        # Create context
        context_parts = []
        for chunk in context_chunks[:3]:  # Limit context
            paper = chunk.metadata.get("paper")
            if paper:
                context_parts.append(f"Title: {paper.title}\nContent: {chunk.text}")

        # Without any attributable context the model would answer
        # ungrounded, so report that nothing relevant was found.
        if not context_parts:
            return no_info_message

        context = "\n\n".join(context_parts)

        # Create prompt
        prompt = (
            "Based on the following research papers, provide a comprehensive answer to the query:\n"
            "\n"
            f"Query: {query}\n"
            "\n"
            "Research Context:\n"
            f"{context[:2000]}\n"
            "\n"
            "Answer:"
        )

        # Prefer the loaded tokenizer's padding/EOS id; 50256 (the GPT-2
        # id the original hard-coded) stays as the last-resort fallback.
        tokenizer = getattr(self.generator, "tokenizer", None)
        pad_token_id = getattr(tokenizer, "pad_token_id", None)
        if pad_token_id is None:
            pad_token_id = getattr(tokenizer, "eos_token_id", None)
        if pad_token_id is None:
            pad_token_id = 50256

        with torch.cuda.amp.autocast() if DEVICE == "cuda" else torch.no_grad():
            response = self.generator(
                prompt,
                max_new_tokens=300,
                temperature=0.7,
                do_sample=True,
                pad_token_id=pad_token_id
            )

        generated = response[0]['generated_text']
        # Causal text-generation pipelines return prompt + continuation;
        # drop the echoed prompt so only the answer is shown.
        if generated.startswith(prompt):
            generated = generated[len(prompt):]
        answer = generated.strip()

        self.memory_manager.clear_cache()
        return answer

    except Exception as e:
        logger.error(f"Answer generation error: {e}")
        self.memory_manager.clear_cache()
        return f"Error generating answer: {str(e)}"
507
+
508
def format_results(self, results: "List[Tuple[Chunk, float]]") -> Tuple[str, pd.DataFrame]:
    """Render ranked chunks as a markdown listing and a summary table.

    Chunks are grouped by paper, each paper is scored by its best chunk,
    and at most eight papers are shown, best first.

    Args:
        results: ``(chunk, score)`` pairs. Chunks without a paper under
            ``metadata["paper"]`` are ignored.

    Returns:
        ``(markdown_text, table)``. When there is nothing to show, the
        markdown is a "no papers" message and the table is empty. On a
        formatting failure the markdown carries the error text.
    """
    try:
        if not results:
            return "No relevant papers found.", pd.DataFrame()

        # Group by paper, tracking the best chunk score per paper.
        papers_dict = {}
        for chunk, score in results:
            paper = chunk.metadata.get("paper")
            if not paper:
                continue
            entry = papers_dict.setdefault(
                paper.id, {'paper': paper, 'max_score': score, 'chunks': []}
            )
            entry['chunks'].append((chunk, score))
            entry['max_score'] = max(entry['max_score'], score)

        if not papers_dict:
            return "No relevant papers found.", pd.DataFrame()

        # Sort by max score
        sorted_papers = sorted(papers_dict.values(), key=lambda x: x['max_score'], reverse=True)

        markdown_parts = []
        table_data = []

        for i, paper_info in enumerate(sorted_papers[:8], 1):
            paper = paper_info['paper']
            score = paper_info['max_score']

            # Titles may contain hard-wrapped newlines; collapse whitespace
            # so the markdown link and the table cell stay on one line.
            title = " ".join(paper.title.split())

            authors_str = ", ".join(paper.authors[:3])
            if len(paper.authors) > 3:
                authors_str += " et al."

            categories_str = ", ".join(paper.categories[:3])
            published_str = paper.published.strftime('%Y-%m-%d')
            abstract_preview = paper.abstract[:300] + ('...' if len(paper.abstract) > 300 else '')

            # Markdown format
            markdown_parts.append(
                "\n"
                f"### {i}. [{title}]({paper.url})\n"
                "\n"
                f"**Authors:** {authors_str}\n"
                f"**Categories:** {categories_str}\n"
                f"**Published:** {published_str}\n"
                f"**Relevance Score:** {score:.3f}\n"
                "\n"
                f"**Abstract:** {abstract_preview}\n"
                "\n"
                "---\n"
            )

            # Table data
            table_data.append({
                'Rank': i,
                'Title': title[:60] + ('...' if len(title) > 60 else ''),
                'Authors': authors_str,
                'Categories': categories_str,
                'Published': published_str,
                'Score': f"{score:.3f}",
                'URL': paper.url
            })

        markdown_text = "".join(markdown_parts)
        df = pd.DataFrame(table_data)

        return markdown_text, df

    except Exception as e:
        logger.error(f"Formatting error: {e}")
        return f"Error formatting results: {str(e)}", pd.DataFrame()
578
+
579
# Global RAG system instance, created on first use by initialize_system()
rag_system = None


def initialize_system():
    """Create the shared RAG system if it does not exist yet.

    Calls made after a successful initialization do nothing.

    Raises:
        Exception: Any error raised while constructing the system is logged
            and re-raised.
    """
    global rag_system

    # Already built: nothing to do.
    if rag_system is not None:
        return

    try:
        logger.info("Initializing RAG system...")
        rag_system = OptimizedRagSystem()
        logger.info("RAG system initialized successfully")
    except Exception as e:
        logger.error(f"System initialization error: {e}")
        raise
593
+
594
+ # Main search function
595
@spaces.GPU(duration=180)  # HuggingFace Spaces GPU decorator for main function
def search_papers(query: str, max_papers: int = 15, top_k_retrieval: int = 10,
                  top_k_rerank: int = 5, categories: str = "", semantic_weight: float = 0.7):
    """Run the full search pipeline for one query.

    Steps: fetch papers from arXiv, chunk and embed them, fit BM25, run
    hybrid retrieval, rerank, generate an answer, and format the output.

    Args:
        query: Search query text.
        max_papers: Maximum number of papers to fetch.
        top_k_retrieval: Number of chunks kept after hybrid retrieval.
        top_k_rerank: Number of chunks kept after reranking.
        categories: Comma-separated arXiv categories; empty for no filter.
        semantic_weight: Weight of the semantic score in hybrid retrieval.

    Returns:
        ``(summary_markdown, papers_markdown, papers_table)``. On any
        failure the first element is a user-facing error message, the
        second is an empty string, and the table is empty.
    """
    try:
        if not query or not query.strip():
            return "❌ Please enter a search query.", "", pd.DataFrame()

        # UI sliders may deliver floats; these are used as counts and
        # slice bounds downstream.
        max_papers = int(max_papers)
        top_k_retrieval = int(top_k_retrieval)
        top_k_rerank = int(top_k_rerank)

        # Initialize system if needed
        initialize_system()

        start_time = time.time()

        # Parse categories (tolerates None and stray commas/whitespace)
        category_list = [cat.strip() for cat in (categories or "").split(',') if cat.strip()]

        # Search ArXiv
        papers = rag_system.search_arxiv(query, max_papers, category_list)

        if not papers:
            return "❌ No papers found. Try different keywords or check your internet connection.", "", pd.DataFrame()

        # Create chunks and embeddings
        rag_system.papers = papers
        rag_system.chunks = rag_system.create_chunks(papers)

        if not rag_system.chunks:
            return "❌ Error processing papers.", "", pd.DataFrame()

        # Create embeddings with GPU acceleration
        rag_system.embeddings = rag_system.embed_chunks(rag_system.chunks)

        if rag_system.embeddings is None or len(rag_system.embeddings) == 0:
            return "❌ Error creating embeddings.", "", pd.DataFrame()

        # Fit BM25
        chunk_texts = [chunk.text for chunk in rag_system.chunks]
        rag_system.bm25.fit(chunk_texts)

        # Hybrid retrieval with GPU acceleration
        retrieved_results = rag_system.hybrid_retrieval(query, top_k_retrieval, semantic_weight)

        if not retrieved_results:
            return "❌ No relevant content found.", "", pd.DataFrame()

        # Rerank results with GPU acceleration
        reranked_results = rag_system.rerank_results(query, retrieved_results, top_k_rerank)

        # Generate answer with GPU acceleration
        answer = rag_system.generate_answer(query, [chunk for chunk, _ in reranked_results])

        # Format results
        papers_md, papers_df = rag_system.format_results(reranked_results)

        # Create response with statistics
        processing_time = time.time() - start_time

        # The heading emoji (robot face, bar chart) are written as escapes
        # so they survive any re-encoding of this file.
        stats = "\n".join([
            "",
            "## \U0001F916 AI-Generated Answer",
            "",
            f"{answer}",
            "",
            "## \U0001F4CA Search Statistics",
            "",
            f"- **Query:** {query}",
            f"- **Papers Found:** {len(papers)}",
            f"- **Chunks Processed:** {len(rag_system.chunks)}",
            f"- **Top Results:** {len(reranked_results)}",
            f"- **Processing Time:** {processing_time:.2f}s",
            f"- **GPU Memory:** {rag_system.memory_manager.get_memory_info()}",
            f"- **Semantic Weight:** {semantic_weight}",
            "",
            "---",
            "",
        ])

        # Clean up GPU memory
        rag_system.memory_manager.clear_cache()

        return stats, papers_md, papers_df

    except Exception as e:
        logger.error(f"Search error: {e}")
        error_msg = f"❌ An error occurred: {str(e)}\n\nPlease try different keywords or check your internet connection."
        return error_msg, "", pd.DataFrame()
682
 
683
  # Create Gradio interface
684
  def create_interface():
685
+ """Create optimized Gradio interface"""
686
 
687
  css = """
688
  .gradio-container {
689
  font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
690
  }
691
+ .gpu-badge {
692
+ background: linear-gradient(45deg, #00d4aa, #00b4d8);
693
+ color: white;
694
+ padding: 0.5rem 1rem;
695
+ border-radius: 20px;
696
+ font-weight: bold;
697
+ display: inline-block;
698
+ margin-bottom: 1rem;
699
+ }
700
  """
701
 
702
+ with gr.Blocks(css=css, title="Enhanced ArXiv RAG System - GPU Optimized") as interface:
703
 
704
+ gr.HTML(f"""
705
  <div style="text-align: center; background: linear-gradient(90deg, #667eea 0%, #764ba2 100%); color: white; padding: 2rem; border-radius: 10px; margin-bottom: 2rem;">
706
  <h1>πŸš€ Enhanced ArXiv RAG System</h1>
707
+ <p>GPU-Optimized scientific paper discovery with semantic search, BM25, and neural reranking</p>
708
+ <div class="gpu-badge">
709
+ πŸ”₯ GPU Accelerated β€’ Device: {DEVICE.upper()}
710
+ </div>
711
  </div>
712
  """)
713
 
 
729
  value=""
730
  )
731
 
732
+ with gr.Accordion("πŸ”§ Advanced GPU Settings", open=False):
733
  with gr.Row():
734
  top_k_retrieval = gr.Slider(5, 15, value=10, step=1, label="Top-K Retrieval")
735
  top_k_rerank = gr.Slider(3, 8, value=5, step=1, label="Top-K Reranking")
736
+
737
+ gr.HTML(f"""
738
+ <div style="background: #e8f5e8; padding: 1rem; border-radius: 8px; margin-top: 1rem;">
739
+ <h4>⚑ GPU Optimization Info</h4>
740
+ <ul>
741
+ <li><strong>Device:</strong> {DEVICE.upper()}</li>
742
+ <li><strong>Mixed Precision:</strong> {'Enabled' if DEVICE == 'cuda' else 'Disabled'}</li>
743
+ <li><strong>Memory Management:</strong> Automatic cleanup</li>
744
+ <li><strong>Batch Processing:</strong> Optimized for GPU</li>
745
+ </ul>
746
+ </div>
747
+ """)
748
 
749
+ search_btn = gr.Button("πŸ” Search Papers", variant="primary", size="lg")
750
 
751
  with gr.Column(scale=1):
752
  gr.HTML("""
753
  <div style="background: #e3f2fd; padding: 1rem; border-radius: 8px;">
754
+ <h4>πŸ’‘ Tips for Best Results</h4>
755
  <ul>
756
  <li>Use specific technical terms</li>
757
  <li>Try different category filters</li>
758
  <li>Adjust semantic weight for different search styles</li>
759
+ <li>Higher semantic weight = more conceptual matching</li>
760
+ <li>Lower semantic weight = more keyword matching</li>
761
  </ul>
762
 
763
+ <h4>πŸ“Š Popular Categories</h4>
764
  <ul>
765
  <li><code>cs.AI</code> - Artificial Intelligence</li>
766
  <li><code>cs.CL</code> - Computation and Language</li>
767
  <li><code>cs.LG</code> - Machine Learning</li>
768
  <li><code>cs.CV</code> - Computer Vision</li>
769
+ <li><code>cs.RO</code> - Robotics</li>
770
+ <li><code>stat.ML</code> - Machine Learning (Stats)</li>
771
  </ul>
772
  </div>
773
  """)
 
780
  papers_output = gr.Markdown(label="Relevant Papers")
781
 
782
  with gr.TabItem("πŸ“Š Papers Table"):
783
+ papers_table = gr.Dataframe(
784
+ label="Papers Summary",
785
+ wrap=True,
786
+ interactive=False
787
+ )
788
 
789
  # Examples
790
  gr.Examples(
791
  examples=[
792
  ["transformer attention mechanisms", 15, 10, 5, "cs.CL, cs.AI", 0.7],
793
+ ["graph neural networks for molecular property prediction", 12, 8, 4, "cs.LG", 0.6],
794
  ["computer vision deep learning", 15, 10, 5, "cs.CV", 0.8],
795
+ ["reinforcement learning robotics", 18, 10, 5, "cs.AI, cs.RO", 0.7],
796
+ ["large language models fine-tuning", 20, 12, 6, "cs.CL", 0.75]
797
  ],
798
  inputs=[query_input, max_papers, top_k_retrieval, top_k_rerank, categories_input, semantic_weight]
799
  )
 
807
 
808
  gr.HTML("""
809
  <div style="text-align: center; margin-top: 2rem; padding: 1rem; background: #f5f5f5; border-radius: 8px;">
810
+ <p><strong>Enhanced ArXiv RAG System</strong> | GPU-Optimized β€’ Semantic Search + BM25 + Neural Reranking</p>
811
+ <p><em>Powered by Hugging Face Spaces GPU β€’ Optimized for high-performance research</em></p>
812
  </div>
813
  """)
814
 
 
816
 
817
  # Launch interface
818
if __name__ == "__main__":
    # Pre-initialize system to reduce first-run latency. A failure here is
    # logged only; search_papers() retries initialization on first request.
    try:
        initialize_system()
    except Exception as e:
        logger.error(f"Pre-initialization failed: {e}")

    interface = create_interface()

    # Gradio 4.x no longer accepts `enable_queue` in launch(); request
    # queueing is turned on through Blocks.queue() instead.
    interface.queue()
    interface.launch(
        show_error=True,
        share=True,
        max_threads=4
    )