mihirinamdar committed on
Commit
a844998
·
verified ·
1 Parent(s): 1117778

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +704 -0
app.py ADDED
@@ -0,0 +1,704 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Enhanced ArXiv RAG System - Hugging Face Spaces Compatible Version
3
+ """
4
+
5
+ import os
6
+ import re
7
+ import arxiv
8
+ import numpy as np
9
+ import pandas as pd
10
+ from typing import List, Dict, Tuple, Optional, Any
11
+ from dataclasses import dataclass
12
+ from datetime import datetime, timedelta
13
+ import logging
14
+ import tempfile
15
+ import shutil
16
+
17
+ # Core ML libraries
18
+ import torch
19
+ from sentence_transformers import SentenceTransformer, CrossEncoder
20
+ from transformers import pipeline
21
+ import gradio as gr
22
+
23
+ # BM25 and text processing
24
+ from sklearn.feature_extraction.text import TfidfVectorizer
25
+ from sklearn.metrics.pairwise import cosine_similarity
26
+ import nltk
27
+ from nltk.corpus import stopwords
28
+ from nltk.tokenize import word_tokenize, sent_tokenize
29
+ from nltk.stem import PorterStemmer
30
+
31
# Setup logging first so that everything below (including model loading
# later in the file) can report through the module logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Download required NLTK data.
# Each entry is (resource path checked with nltk.data.find, package name
# passed to nltk.download).  "punkt_tab" is included because recent NLTK
# releases load it instead of the older "punkt" pickles when tokenizing;
# on older releases the extra download attempt simply reports a failure
# without raising.
_NLTK_RESOURCES = (
    ("tokenizers/punkt", "punkt"),
    ("tokenizers/punkt_tab", "punkt_tab"),
    ("corpora/stopwords", "stopwords"),
)

for _resource_path, _package_name in _NLTK_RESOURCES:
    try:
        nltk.data.find(_resource_path)
    except LookupError:
        nltk.download(_package_name)
45
+
46
@dataclass
class Paper:
    """Record describing one paper returned by an arXiv search."""

    # Short identifier of the paper (the code that builds Paper objects in
    # this file uses the last path segment of the arXiv entry URL).
    id: str
    # Paper title as a single line of text.
    title: str
    # Paper abstract as a single line of text.
    abstract: str
    # Author display names.
    authors: List[str]
    # arXiv category codes attached to the paper.
    categories: List[str]
    # Publication timestamp.
    published: datetime
    # Link to the paper's arXiv entry.
    url: str
56
+
57
@dataclass
class Chunk:
    """Record describing one indexable piece of text cut from a paper."""

    # Unique identifier of this chunk.
    id: str
    # Identifier of the paper the text was taken from.
    paper_id: str
    # The text that gets embedded and searched.
    text: str
    # Label for the kind of chunk (this file uses "title", "abstract"
    # and "combined").
    chunk_type: str
    # Free-form extra information carried along with the chunk.
    metadata: Dict[str, Any]
65
+
66
class BM25Retriever:
    """BM25 retriever for keyword-based search over an in-memory corpus.

    Call :meth:`fit` with the raw document strings, then :meth:`score` with
    a query to obtain ``(document_index, score)`` pairs.

    Args:
        k1: Term-frequency saturation parameter of BM25.
        b: Document-length normalization parameter of BM25.
    """

    def __init__(self, k1: float = 1.5, b: float = 0.75):
        self.k1 = k1
        self.b = b
        self.documents = []
        self.doc_lengths = []
        self.avg_doc_length = 0
        # Initialised here so score() is safe (returns []) before fit().
        self.vocab = []
        self.term_freqs = []
        self.idf = {}
        self.stemmer = PorterStemmer()
        try:
            self.stop_words = set(stopwords.words('english'))
        except Exception as exc:  # best effort: run without stop-word removal
            logger.warning("NLTK stopwords unavailable, continuing without them: %s", exc)
            self.stop_words = set()

    def preprocess_text(self, text: str) -> List[str]:
        """Lower-case and tokenize *text*, keep alphabetic non-stop-word
        tokens, and return their stems."""
        tokens = word_tokenize(text.lower())
        return [
            self.stemmer.stem(token)
            for token in tokens
            if token.isalpha() and token not in self.stop_words
        ]

    def fit(self, documents: List[str]):
        """Index *documents*: token lists, lengths, term and document
        frequencies, and per-term IDF values.

        Calling ``fit`` again replaces the previous index entirely.
        """
        self.documents = [self.preprocess_text(doc) for doc in documents]
        self.doc_lengths = [len(doc) for doc in self.documents]
        n_docs = len(self.documents)
        self.avg_doc_length = sum(self.doc_lengths) / n_docs if n_docs else 0

        # Single pass: per-document term counts plus the number of documents
        # each term occurs in.
        doc_freqs = {}
        self.term_freqs = []
        for doc in self.documents:
            tf = {}
            for term in doc:
                tf[term] = tf.get(term, 0) + 1
            self.term_freqs.append(tf)
            for term in tf:
                doc_freqs[term] = doc_freqs.get(term, 0) + 1

        self.vocab = list(doc_freqs)

        # The "+ 1" inside the log keeps IDF positive.  Without it a term
        # that appears in more than half of the documents gets a negative
        # weight and *lowers* the score of documents that match the query.
        self.idf = {
            term: float(np.log(1.0 + (n_docs - df + 0.5) / (df + 0.5)))
            for term, df in doc_freqs.items()
        }

    def score(self, query: str, top_k: int = 10) -> List[Tuple[int, float]]:
        """Score every indexed document against *query*.

        Returns:
            Up to ``top_k`` ``(document_index, score)`` pairs sorted by
            descending score; ties keep their original document order.
            Empty when nothing has been indexed.
        """
        query_terms = self.preprocess_text(query)
        # A zero average only happens when every document is empty, in which
        # case no term can match; the fallback just avoids dividing by zero.
        avg_len = self.avg_doc_length or 1.0
        scores = []

        for i, (tf, doc_len) in enumerate(zip(self.term_freqs, self.doc_lengths)):
            doc_score = 0.0
            length_norm = self.k1 * (1 - self.b + self.b * (doc_len / avg_len))
            for term in query_terms:
                term_freq = tf.get(term)
                if not term_freq:
                    continue
                idf = self.idf.get(term, 0)
                doc_score += idf * (term_freq * (self.k1 + 1)) / (term_freq + length_norm)
            scores.append((i, doc_score))

        scores.sort(key=lambda pair: pair[1], reverse=True)
        return scores[:top_k]
132
+
133
class SimpleVectorStore:
    """Simple in-memory vector store for HF Spaces compatibility.

    Keeps four parallel lists (ids, embeddings, documents, metadatas) and
    answers nearest-neighbour queries by cosine similarity.
    """

    def __init__(self):
        self.embeddings = []
        self.documents = []
        self.metadatas = []
        self.ids = []

    def add(self, ids: List[str], embeddings: List[List[float]],
            documents: List[str], metadatas: List[Dict]):
        """Append entries to the store.

        Raises:
            ValueError: if the four sequences do not have the same length
                (the store relies on them staying index-aligned).
        """
        if not (len(ids) == len(embeddings) == len(documents) == len(metadatas)):
            raise ValueError(
                "ids, embeddings, documents and metadatas must have the same length"
            )
        self.ids.extend(ids)
        self.embeddings.extend(embeddings)
        self.documents.extend(documents)
        self.metadatas.extend(metadatas)

    def query(self, query_embedding: List[float], n_results: int = 10) -> Dict:
        """Return the ``n_results`` entries most similar to *query_embedding*.

        The result has keys ``'ids'``, ``'documents'`` and ``'metadatas'``,
        each holding a single inner list ordered by descending cosine
        similarity; ties keep insertion order.
        """
        n_results = int(n_results)
        if not self.embeddings or n_results <= 0:
            return {'ids': [[]], 'documents': [[]], 'metadatas': [[]]}

        query_vec = np.asarray(query_embedding, dtype=float)
        matrix = np.asarray(self.embeddings, dtype=float)

        # Cosine similarity for all stored vectors at once.
        dots = matrix @ query_vec
        norms = np.linalg.norm(matrix, axis=1) * np.linalg.norm(query_vec)
        # Zero-norm vectors have no direction; give them similarity 0 rather
        # than NaN, which would otherwise sort to the top of the ranking.
        similarities = np.divide(dots, norms, out=np.zeros_like(dots), where=norms > 0)

        top_indices = [int(i) for i in np.argsort(-similarities, kind="stable")[:n_results]]

        return {
            'ids': [[self.ids[i] for i in top_indices]],
            'documents': [[self.documents[i] for i in top_indices]],
            'metadatas': [[self.metadatas[i] for i in top_indices]]
        }

    def get(self, ids: Optional[List[str]] = None) -> Dict:
        """Get entries by id, or everything when *ids* is ``None``.

        Unknown ids are skipped.  With ``ids=None`` the returned lists are
        the store's own lists, not copies.
        """
        if ids is None:
            return {
                'ids': self.ids,
                'documents': self.documents,
                'metadatas': self.metadatas
            }

        # id -> first position, so each lookup is O(1) instead of a list scan.
        position = {}
        for index, stored_id in enumerate(self.ids):
            position.setdefault(stored_id, index)

        indices = [position[id_] for id_ in ids if id_ in position]
        return {
            'ids': [self.ids[i] for i in indices],
            'documents': [self.documents[i] for i in indices],
            'metadatas': [self.metadatas[i] for i in indices]
        }

    def clear(self):
        """Remove every entry from the store."""
        self.embeddings.clear()
        self.documents.clear()
        self.metadatas.clear()
        self.ids.clear()
197
+
198
class EnhancedArxivRAG:
    """Enhanced RAG system optimized for Hugging Face Spaces.

    One call to :meth:`search_and_answer` runs the whole pipeline: fetch
    papers from arXiv, index their titles/abstracts (embeddings + BM25),
    fuse both rankings, rerank with a cross-encoder and summarize the best
    chunks into an answer.  The index is rebuilt from scratch on every call.
    """

    def __init__(self):
        logger.info("Initializing Enhanced ArXiv RAG System for HF Spaces...")

        # Use smaller, faster models for HF Spaces
        self.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
        self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-2-v2')  # Smaller reranker
        self.summarizer = pipeline("summarization", model="facebook/bart-large-cnn",
                                   device=0 if torch.cuda.is_available() else -1)

        # Use simple vector store instead of ChromaDB for HF Spaces
        self.vector_store = SimpleVectorStore()
        self.bm25_retriever = BM25Retriever()

        # Cache for papers and chunks of the most recently indexed batch
        self.papers_cache = {}
        self.chunks_cache = {}
        self.bm25_fitted = False

        logger.info("RAG system initialized successfully!")

    @staticmethod
    def _empty_response(answer: str) -> Dict[str, Any]:
        """Build a result dict with no papers; carries the same
        ``search_stats`` keys as a successful response."""
        return {
            'answer': answer,
            'papers': [],
            'retrieved_chunks': [],
            'search_stats': {
                'papers_found': 0,
                'chunks_retrieved': 0,
                'unique_papers_in_results': 0
            }
        }

    def fetch_papers(self, query: str, max_results: int = 15,
                     categories: Optional[List[str]] = None) -> List[Paper]:
        """Fetch papers from ArXiv.

        Args:
            query: Free-text search query.
            max_results: Maximum number of papers to request.
            categories: Optional arXiv category codes; when given, results
                are restricted to papers in any of them.

        Returns:
            The fetched papers, or an empty list if the request failed
            (the error is logged, not raised).
        """
        search_query = query
        if categories:
            category_filter = " OR ".join([f"cat:{cat}" for cat in categories])
            search_query = f"({query}) AND ({category_filter})"

        logger.info(f"Fetching papers with query: {search_query}")

        try:
            search = arxiv.Search(
                query=search_query,
                max_results=int(max_results),
                sort_by=arxiv.SortCriterion.Relevance
            )

            # Newer releases of the arxiv package deprecate Search.results()
            # in favour of Client.results(search); fall back for old ones.
            if hasattr(arxiv, "Client"):
                result_iter = arxiv.Client().results(search)
            else:
                result_iter = search.results()

            papers = []
            for result in result_iter:
                paper = Paper(
                    id=result.entry_id.split('/')[-1],
                    title=result.title.strip().replace('\n', ' '),
                    abstract=result.summary.strip().replace('\n', ' '),
                    authors=[author.name for author in result.authors],
                    categories=result.categories,
                    published=result.published.replace(tzinfo=None),
                    url=result.entry_id
                )
                papers.append(paper)
                self.papers_cache[paper.id] = paper

            logger.info(f"Fetched {len(papers)} papers")
            return papers

        except Exception as e:
            logger.error(f"Error fetching papers: {e}")
            return []

    def create_chunks(self, papers: List[Paper]) -> List[Chunk]:
        """Create text chunks from papers.

        Every paper yields three chunks, in this order: its title, its
        abstract, and a combined "Title: ... Abstract: ..." text.  Each
        chunk is also recorded in ``self.chunks_cache`` under its id.
        """
        chunks = []

        for paper in papers:
            texts_by_type = (
                ("title", paper.title),
                ("abstract", paper.abstract),
                ("combined", f"Title: {paper.title}\n\nAbstract: {paper.abstract}"),
            )
            for chunk_type, text in texts_by_type:
                chunk = Chunk(
                    id=f"{paper.id}_{chunk_type}",
                    paper_id=paper.id,
                    text=text,
                    chunk_type=chunk_type,
                    # A separate dict per chunk so they can be edited independently.
                    metadata={
                        "authors": paper.authors,
                        "categories": paper.categories,
                        "published": paper.published.isoformat(),
                        "url": paper.url
                    }
                )
                chunks.append(chunk)
                self.chunks_cache[chunk.id] = chunk

        return chunks

    def process_and_store(self, papers: List[Paper]):
        """Rebuild the vector store and the BM25 index from *papers*.

        All previously indexed data is discarded first, including the paper
        and chunk caches, so memory does not grow from query to query.
        """
        logger.info("Processing and storing papers...")

        # Clear previous data; BM25 is stale until it is fitted again below.
        self.vector_store.clear()
        self.chunks_cache.clear()
        self.bm25_fitted = False
        self.papers_cache = {paper.id: paper for paper in papers}

        # Create chunks
        chunks = self.create_chunks(papers)

        if not chunks:
            return

        # Generate embeddings
        texts = [chunk.text for chunk in chunks]
        logger.info("Generating embeddings...")
        embeddings = self.embedding_model.encode(texts, show_progress_bar=False)

        # Store in vector store
        self.vector_store.add(
            ids=[chunk.id for chunk in chunks],
            embeddings=embeddings.tolist(),
            documents=texts,
            metadatas=[chunk.metadata for chunk in chunks]
        )

        # Fit BM25 on the same texts, in the same order as the vector store,
        # so BM25 document indices map directly onto stored ids.
        logger.info("Fitting BM25...")
        self.bm25_retriever.fit(texts)
        self.bm25_fitted = True

        logger.info(f"Stored {len(chunks)} chunks")

    def hybrid_search(self, query: str, top_k: int = 10,
                      semantic_weight: float = 0.7) -> List[Dict]:
        """Perform hybrid (embedding + BM25) search.

        Both retrievers propose ``2 * top_k`` candidates.  A candidate's
        fused score is the sum of ``weight / rank`` over the rankings it
        appears in, with ``semantic_weight`` for the embedding ranking and
        ``1 - semantic_weight`` for BM25.

        Returns:
            Up to ``top_k`` dicts with keys ``'id'``, ``'document'``,
            ``'metadata'`` and ``'combined_score'``, best first.
        """
        top_k = int(top_k)
        n_candidates = top_k * 2

        # Semantic search
        query_embedding = self.embedding_model.encode([query])
        semantic_results = self.vector_store.query(
            query_embedding=query_embedding[0].tolist(),
            n_results=n_candidates
        )

        # BM25 search
        bm25_ids = []
        if self.bm25_fitted:
            all_ids = self.vector_store.get()['ids']
            for idx, score in self.bm25_retriever.score(query, n_candidates):
                # A non-positive score means no useful keyword match; such
                # documents must not earn rank credit in the fusion below.
                if score > 0 and idx < len(all_ids):
                    bm25_ids.append(all_ids[idx])

        # Combine results using weighted reciprocal ranks
        combined_scores = {}
        bm25_weight = 1.0 - semantic_weight

        for rank, doc_id in enumerate(semantic_results['ids'][0], start=1):
            combined_scores[doc_id] = combined_scores.get(doc_id, 0) + semantic_weight / rank

        for rank, doc_id in enumerate(bm25_ids, start=1):
            combined_scores[doc_id] = combined_scores.get(doc_id, 0) + bm25_weight / rank

        # Sort by combined score
        sorted_results = sorted(combined_scores.items(), key=lambda x: x[1], reverse=True)

        # Prepare final results
        final_results = []
        for doc_id, score in sorted_results[:top_k]:
            doc_result = self.vector_store.get(ids=[doc_id])
            if doc_result['ids']:
                final_results.append({
                    'id': doc_id,
                    'document': doc_result['documents'][0],
                    'metadata': doc_result['metadatas'][0],
                    'combined_score': score
                })

        return final_results

    def rerank_results(self, query: str, results: List[Dict], top_k: int = 5) -> List[Dict]:
        """Rerank results using the cross-encoder.

        Adds a ``'rerank_score'`` entry to each dict in *results* (in place)
        and returns the ``top_k`` highest-scoring ones.
        """
        if not results:
            return results

        # Prepare query-document pairs
        query_doc_pairs = [(query, result['document']) for result in results]

        # Get reranking scores
        rerank_scores = self.reranker.predict(query_doc_pairs)

        # Add rerank scores to results
        for result, rerank_score in zip(results, rerank_scores):
            result['rerank_score'] = float(rerank_score)

        # Sort by rerank score
        reranked_results = sorted(results, key=lambda x: x['rerank_score'], reverse=True)
        return reranked_results[:int(top_k)]

    def generate_answer(self, query: str, context_chunks: List[Dict]) -> str:
        """Generate an answer from the retrieved context.

        The top three chunks are concatenated (capped at 800 characters) and
        passed to the summarization model together with the question.  If
        summarization fails, a plain-text excerpt of the top two chunks is
        returned instead.
        """
        if not context_chunks:
            return "No relevant information found to answer your query."

        # Combine context from top chunks
        context_texts = [chunk['document'] for chunk in context_chunks[:3]]
        combined_context = "\n\n".join(context_texts)

        # Limit context length
        max_context_length = 800
        if len(combined_context) > max_context_length:
            combined_context = combined_context[:max_context_length] + "..."

        try:
            summary_input = f"Based on the following research papers, answer this question: {query}\n\nContext: {combined_context}"
            summary = self.summarizer(summary_input,
                                      max_length=120,
                                      min_length=30,
                                      do_sample=False)[0]['summary_text']
            return summary
        except Exception as e:
            logger.error(f"Error generating summary: {e}")
            return f"Based on the retrieved papers about '{query}', here are the key findings:\n\n" + \
                "\n\n".join([chunk['document'][:150] + "..." for chunk in context_chunks[:2]])

    def search_and_answer(self, query: str, max_papers: int = 15,
                          top_k_retrieval: int = 10, top_k_rerank: int = 5,
                          categories: Optional[List[str]] = None,
                          semantic_weight: float = 0.7) -> Dict[str, Any]:
        """Main search and answer pipeline.

        Returns:
            A dict with keys ``'answer'`` (str), ``'papers'`` (list of paper
            summaries), ``'retrieved_chunks'`` (reranked chunk dicts) and
            ``'search_stats'``.  ``'search_stats'`` always contains
            ``'papers_found'``, ``'chunks_retrieved'`` and
            ``'unique_papers_in_results'``, also for empty and error
            responses.  Errors are reported in ``'answer'``, never raised.
        """
        if not query or not query.strip():
            return self._empty_response("Please enter a valid research query.")

        try:
            # Fetch papers
            papers = self.fetch_papers(query, int(max_papers), categories)

            if not papers:
                return self._empty_response(
                    "No papers found for your query. Please try different keywords."
                )

            # Process and store papers
            self.process_and_store(papers)

            # Hybrid search
            search_results = self.hybrid_search(query, int(top_k_retrieval), semantic_weight)

            # Rerank results
            reranked_results = self.rerank_results(query, search_results, int(top_k_rerank))

            # Generate answer
            answer = self.generate_answer(query, reranked_results)

            # Prepare unique papers, in order of first appearance
            unique_papers = {}
            for chunk in reranked_results:
                cached_chunk = self.chunks_cache.get(chunk['id'])
                if cached_chunk is not None:
                    paper_id = cached_chunk.paper_id
                else:
                    # Chunk ids are "<paper_id>_<chunk_type>"; strip the suffix.
                    paper_id = chunk['id'].rsplit('_', 1)[0]

                if paper_id in self.papers_cache and paper_id not in unique_papers:
                    paper = self.papers_cache[paper_id]
                    unique_papers[paper_id] = {
                        'title': paper.title,
                        'authors': paper.authors,
                        'abstract': paper.abstract,
                        'url': paper.url,
                        'categories': paper.categories,
                        'published': paper.published.strftime('%Y-%m-%d')
                    }

            return {
                'answer': answer,
                'papers': list(unique_papers.values()),
                'retrieved_chunks': reranked_results,
                'search_stats': {
                    'papers_found': len(papers),
                    'chunks_retrieved': len(reranked_results),
                    'unique_papers_in_results': len(unique_papers)
                }
            }

        except Exception as e:
            logger.exception(f"Error in search_and_answer: {e}")
            return self._empty_response(
                f"An error occurred while processing your query: {str(e)}"
            )
525
+
526
+ # Global RAG instance
527
rag_system = None


def initialize_rag():
    """Return the module-wide RAG instance, building it on first use."""
    global rag_system
    if rag_system is not None:
        return rag_system

    # First call: construct the system (this is where the models get loaded)
    # and remember it for every later call.
    rag_system = EnhancedArxivRAG()
    return rag_system
535
+
536
def search_papers(query: str, max_papers: int = 15, top_k_retrieval: int = 10,
                  top_k_rerank: int = 5, categories: str = "",
                  semantic_weight: float = 0.7) -> tuple:
    """Main search function for the Gradio interface.

    Args:
        query: Research topic or question typed by the user.
        max_papers: Maximum number of papers to fetch.
        top_k_retrieval: Number of chunks kept after hybrid search.
        top_k_rerank: Number of chunks kept after reranking.
        categories: Comma-separated arXiv category codes, or empty.
        semantic_weight: Weight of the embedding ranking in the fusion.

    Returns:
        ``(answer_markdown, papers_markdown, papers_dataframe)``.  The third
        element is always a ``pandas.DataFrame`` (empty when there is
        nothing to show) so the table output receives a consistent type.
    """
    table_columns = ['Title', 'Authors', 'Categories', 'Published', 'URL']

    if not query or not query.strip():
        return ("❌ Please enter a research topic or question.", "",
                pd.DataFrame(columns=table_columns))

    try:
        # Initialize RAG system
        rag = initialize_rag()

        # Parse categories; None means "no category filter".
        category_list = [
            cat.strip() for cat in (categories or "").split(',') if cat.strip()
        ] or None

        # Perform search.  UI sliders may hand over floats, while the
        # pipeline uses these values as counts and slice bounds.
        result = rag.search_and_answer(
            query=query,
            max_papers=int(max_papers),
            top_k_retrieval=int(top_k_retrieval),
            top_k_rerank=int(top_k_rerank),
            categories=category_list,
            semantic_weight=float(semantic_weight)
        )

        # Format answer.  Stats are read with .get() so that a response
        # lacking a counter still shows its message instead of a KeyError.
        stats = result.get('search_stats', {})
        answer = "".join([
            f"## 🤖 AI-Generated Answer\n\n{result['answer']}\n\n",
            "**Search Statistics:**\n",
            f"- Papers found: {stats.get('papers_found', 0)}\n",
            f"- Chunks retrieved: {stats.get('chunks_retrieved', 0)}\n",
            f"- Unique papers in results: {stats.get('unique_papers_in_results', 0)}\n\n",
        ])

        # Format papers
        md_parts = ["## 📚 Relevant Papers\n\n"]
        for i, paper in enumerate(result['papers'], 1):
            md_parts.append(f"### {i}. {paper['title']}\n\n")
            md_parts.append(f"**Authors:** {', '.join(paper['authors'][:3])}{'...' if len(paper['authors']) > 3 else ''}\n\n")
            md_parts.append(f"**Categories:** {', '.join(paper['categories'])}\n\n")
            md_parts.append(f"**Published:** {paper['published']}\n\n")
            md_parts.append(f"**Abstract:** {paper['abstract'][:250]}{'...' if len(paper['abstract']) > 250 else ''}\n\n")
            md_parts.append(f"**URL:** [{paper['url']}]({paper['url']})\n\n")
            md_parts.append("---\n\n")
        papers_md = "".join(md_parts)

        # Create papers dataframe (explicit columns keep the header row even
        # when no paper made it into the results)
        papers_df = pd.DataFrame(
            [
                {
                    'Title': paper['title'][:50] + '...' if len(paper['title']) > 50 else paper['title'],
                    'Authors': ', '.join(paper['authors'][:2]) + ('...' if len(paper['authors']) > 2 else ''),
                    'Categories': ', '.join(paper['categories'][:2]),
                    'Published': paper['published'],
                    'URL': paper['url']
                }
                for paper in result['papers']
            ],
            columns=table_columns
        )

        return answer, papers_md, papers_df

    except Exception as e:
        logger.error(f"Error processing query: {e}")
        error_msg = f"❌ An error occurred: {str(e)}\n\nPlease try different keywords or check your internet connection."
        return error_msg, "", pd.DataFrame(columns=table_columns)
599
+
600
+ # Create Gradio interface
601
def create_interface():
    """Create the Gradio interface.

    Builds a Blocks layout with the query controls on the left, a help
    panel on the right, and the answer / papers / table outputs below, and
    wires the search button to ``search_papers``.

    Returns:
        The assembled ``gr.Blocks`` object (not yet launched).
    """

    css = """
    .gradio-container {
        font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
    }
    """

    with gr.Blocks(css=css, title="Enhanced ArXiv RAG System") as interface:

        gr.HTML("""
        <div style="text-align: center; background: linear-gradient(90deg, #667eea 0%, #764ba2 100%); color: white; padding: 2rem; border-radius: 10px; margin-bottom: 2rem;">
        <h1>🚀 Enhanced ArXiv RAG System</h1>
        <p>Advanced scientific paper discovery with semantic search, BM25, and neural reranking</p>
        </div>
        """)

        with gr.Row():
            with gr.Column(scale=2):
                query_input = gr.Textbox(
                    label="Research Query",
                    placeholder="Enter your research question (e.g., 'transformer attention mechanisms in NLP')",
                    lines=2
                )

                with gr.Row():
                    max_papers = gr.Slider(5, 25, value=15, step=1, label="Max Papers")
                    semantic_weight = gr.Slider(0.1, 0.9, value=0.7, step=0.1, label="Semantic Weight")

                categories_input = gr.Textbox(
                    label="ArXiv Categories (Optional)",
                    placeholder="e.g., cs.CL, cs.AI, cs.LG",
                    value=""
                )

                with gr.Accordion("Advanced Settings", open=False):
                    with gr.Row():
                        top_k_retrieval = gr.Slider(5, 15, value=10, step=1, label="Top-K Retrieval")
                        top_k_rerank = gr.Slider(3, 8, value=5, step=1, label="Top-K Reranking")

                search_btn = gr.Button("🔍 Search Papers", variant="primary")

            with gr.Column(scale=1):
                gr.HTML("""
                <div style="background: #e3f2fd; padding: 1rem; border-radius: 8px;">
                <h4>💡 Tips</h4>
                <ul>
                <li>Use specific technical terms</li>
                <li>Try different category filters</li>
                <li>Adjust semantic weight for different search styles</li>
                </ul>

                <h4>📊 Categories</h4>
                <ul>
                <li><code>cs.AI</code> - Artificial Intelligence</li>
                <li><code>cs.CL</code> - Computation and Language</li>
                <li><code>cs.LG</code> - Machine Learning</li>
                <li><code>cs.CV</code> - Computer Vision</li>
                </ul>
                </div>
                """)

        # Results
        answer_output = gr.Markdown(label="AI Answer & Statistics")

        with gr.Tabs():
            with gr.TabItem("📚 Papers"):
                papers_output = gr.Markdown(label="Relevant Papers")

            with gr.TabItem("📊 Papers Table"):
                papers_table = gr.Dataframe(label="Papers Summary")

        # The component order below must match the positional parameters of
        # search_papers; it is shared by the examples and the click handler.
        search_inputs = [query_input, max_papers, top_k_retrieval, top_k_rerank,
                         categories_input, semantic_weight]

        # Examples
        gr.Examples(
            examples=[
                ["transformer attention mechanisms", 15, 10, 5, "cs.CL, cs.AI", 0.7],
                ["graph neural networks", 12, 8, 4, "cs.LG", 0.6],
                ["computer vision deep learning", 15, 10, 5, "cs.CV", 0.8],
                ["reinforcement learning", 18, 10, 5, "cs.AI", 0.7]
            ],
            inputs=search_inputs
        )

        # Connect search function
        search_btn.click(
            fn=search_papers,
            inputs=search_inputs,
            outputs=[answer_output, papers_output, papers_table]
        )

        gr.HTML("""
        <div style="text-align: center; margin-top: 2rem; padding: 1rem; background: #f5f5f5; border-radius: 8px;">
        <p><strong>Enhanced ArXiv RAG System</strong> | Semantic Search + BM25 + Neural Reranking</p>
        </div>
        """)

    return interface
699
+
700
+ # Launch interface
701
if __name__ == "__main__":
    # Script entry point: assemble the Gradio app, then start serving it.
    interface = create_interface()
    interface.launch()
704
+