|
""" |
|
Enhanced ArXiv RAG System - GPU Optimized for Hugging Face Spaces |
|
""" |
|
|
|
import os |
|
import re |
|
import arxiv |
|
import numpy as np |
|
import pandas as pd |
|
from typing import List, Dict, Tuple, Optional, Any |
|
from dataclasses import dataclass |
|
from datetime import datetime, timedelta |
|
import logging |
|
import tempfile |
|
import shutil |
|
import gc |
|
import time |
|
import signal |
|
|
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from sentence_transformers import SentenceTransformer, CrossEncoder |
|
from transformers import pipeline, AutoTokenizer, AutoModel |
|
import gradio as gr |
|
import spaces |
|
|
|
|
|
from sklearn.feature_extraction.text import TfidfVectorizer |
|
from sklearn.metrics.pairwise import cosine_similarity |
|
import nltk |
|
from nltk.corpus import stopwords |
|
from nltk.tokenize import word_tokenize, sent_tokenize |
|
from nltk.stem import PorterStemmer |
|
|
|
|
|
# NLTK data used by this module: the Punkt sentence/word tokenizer models and
# the English stop-word list.  Recent NLTK releases load Punkt from the
# "punkt_tab" package instead of the older pickled "punkt" package, so both are
# requested; whichever one the installed NLTK needs will then be present.
_NLTK_RESOURCES = (
    ("tokenizers/punkt", "punkt"),
    ("tokenizers/punkt_tab", "punkt_tab"),
    ("corpora/stopwords", "stopwords"),
)

for _resource_path, _package_name in _NLTK_RESOURCES:
    try:
        # Raises LookupError when the resource is not installed locally.
        nltk.data.find(_resource_path)
    except LookupError:
        nltk.download(_package_name, quiet=True)
|
|
|
|
|
# Configure root logging once at import time; the rest of the module logs
# through ``logger``.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Use the GPU whenever torch reports CUDA as available, otherwise the CPU.
_cuda_available = torch.cuda.is_available()
DEVICE = "cuda" if _cuda_available else "cpu"
logger.info(f"Using device: {DEVICE}")

if _cuda_available:
    # Report which GPU was picked up and how much memory device 0 has.
    logger.info(f"GPU: {torch.cuda.get_device_name()}")
    _total_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    logger.info(f"GPU Memory: {_total_memory_gb:.1f} GB")
|
|
|
@dataclass
class Paper:
    """Record describing one ArXiv paper as used by the RAG pipeline."""

    # Short identifier (last path segment of the ArXiv entry URL).
    id: str
    # Paper title.
    title: str
    # Abstract text; this is what gets chunked and embedded.
    abstract: str
    # Author display names.
    authors: List[str]
    # ArXiv category codes such as "cs.CL".
    categories: List[str]
    # Publication timestamp.
    published: datetime
    # Link to the paper's ArXiv entry.
    url: str
|
|
|
@dataclass
class Chunk:
    """A retrievable piece of text cut from a paper."""

    # Unique chunk identifier.
    id: str
    # Identifier of the paper this chunk was taken from.
    paper_id: str
    # The chunk's text content.
    text: str
    # Kind of chunk (the pipeline uses "title" and "abstract").
    chunk_type: str
    # Free-form extra data attached to the chunk.
    metadata: Dict[str, Any]
|
|
|
class GPUMemoryManager:
    """Stateless helpers for releasing and inspecting GPU memory.

    Every method is a no-op (or returns a placeholder) when CUDA is not
    available, so callers do not need to branch on the device themselves.
    """

    @staticmethod
    def clear_cache():
        """Run the garbage collector, then release unused cached CUDA memory.

        The collector runs first so that tensors which were only kept alive by
        reference cycles are freed before ``empty_cache`` is called; otherwise
        their memory would still count as in use and stay cached.
        """
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    @staticmethod
    def get_memory_info():
        """Return a short human-readable GPU memory summary.

        Returns:
            ``"Allocated: X.XGB, Cached: Y.YGB"`` when CUDA is available,
            otherwise the string ``"CPU mode"``.
        """
        if not torch.cuda.is_available():
            return "CPU mode"
        allocated = torch.cuda.memory_allocated() / 1e9
        cached = torch.cuda.memory_reserved() / 1e9
        return f"Allocated: {allocated:.1f}GB, Cached: {cached:.1f}GB"

    @staticmethod
    def optimize_memory():
        """Empty the CUDA cache and switch on cuDNN autotuning and TF32 matmul.

        Does nothing when CUDA is not available.
        """
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.backends.cudnn.benchmark = True
            torch.backends.cuda.matmul.allow_tf32 = True
|
|
|
class BM25Retriever:
    """In-memory Okapi BM25 scorer for keyword-based retrieval.

    Call :meth:`fit` with the documents to index, then :meth:`get_scores` with
    a query to obtain one score per indexed document (same order as given to
    ``fit``).
    """

    def __init__(self, k1: float = 1.5, b: float = 0.75):
        """Create an empty retriever.

        Args:
            k1: Term-frequency saturation parameter.
            b: Document-length normalisation strength.
        """
        self.k1 = k1
        self.b = b
        self.stemmer = PorterStemmer()
        try:
            self.stop_words = set(stopwords.words('english'))
        except Exception as e:
            # Best effort: without the corpus we simply keep stop words.
            logger.warning(f"Stop words unavailable, continuing without them: {e}")
            self.stop_words = set()
        self._reset_index()

    def _reset_index(self):
        """Put the index into the empty (unfitted) state."""
        self.documents = []
        self.doc_lengths = []
        self.avg_doc_length = 0
        self.vocab = []
        self.term_freqs = []
        self.idf = {}

    def preprocess_text(self, text: str) -> List[str]:
        """Tokenise ``text`` into lower-cased, stemmed, stop-word-free terms.

        Only alphabetic tokens longer than two characters are kept.  If
        tokenisation fails for any reason the text is split on whitespace
        instead, so this method never raises.
        """
        try:
            tokens = word_tokenize(text.lower())
            return [
                self.stemmer.stem(token)
                for token in tokens
                if token.isalpha() and token not in self.stop_words and len(token) > 2
            ]
        except Exception as e:
            logger.warning(f"Text preprocessing error: {e}")
            return text.lower().split()

    def fit(self, documents: List[str]):
        """Index ``documents`` for scoring.

        The index is replaced atomically: on any error it is reset to the
        empty state (and the error is logged) rather than left half-built.
        """
        try:
            tokenized = [self.preprocess_text(doc) for doc in documents]
            doc_lengths = [len(tokens) for tokens in tokenized]
            avg_doc_length = sum(doc_lengths) / len(doc_lengths) if doc_lengths else 0

            # Per-document term counts.
            term_freqs = []
            for tokens in tokenized:
                tf = {}
                for term in tokens:
                    tf[term] = tf.get(term, 0) + 1
                term_freqs.append(tf)

            # Document frequency from the count dicts: one pass, O(1) lookups,
            # instead of scanning every token list once per vocabulary term.
            doc_freq = {}
            for tf in term_freqs:
                for term in tf:
                    doc_freq[term] = doc_freq.get(term, 0) + 1

            # The "+ 1" inside the log keeps IDF non-negative.  Without it a
            # term present in more than half of the documents gets a negative
            # weight, which ranks documents that contain the query term below
            # documents that do not.
            n_docs = len(tokenized)
            idf = {
                term: float(np.log(1.0 + (n_docs - df + 0.5) / (df + 0.5)))
                for term, df in doc_freq.items()
            }

            self.documents = tokenized
            self.doc_lengths = doc_lengths
            self.avg_doc_length = avg_doc_length
            self.vocab = list(doc_freq)
            self.term_freqs = term_freqs
            self.idf = idf
        except Exception as e:
            logger.error(f"BM25 fitting error: {e}")
            self._reset_index()

    def get_scores(self, query: str) -> np.ndarray:
        """Return the BM25 score of ``query`` against every indexed document.

        Returns:
            A float array with one entry per indexed document.  Documents
            sharing no term with the query score 0.  On error an all-zero
            array is returned and the error is logged.
        """
        try:
            query_terms = self.preprocess_text(query)
            scores = np.zeros(len(self.documents))
            if not query_terms or not self.avg_doc_length:
                return scores

            for i, (doc_tf, doc_length) in enumerate(zip(self.term_freqs, self.doc_lengths)):
                # Length normalisation is the same for every term of a document.
                length_norm = self.k1 * (
                    1 - self.b + self.b * (doc_length / self.avg_doc_length)
                )
                score = 0.0
                for term in query_terms:
                    tf = doc_tf.get(term)
                    if tf:
                        score += self.idf.get(term, 0) * (tf * (self.k1 + 1)) / (tf + length_norm)
                scores[i] = score

            return scores
        except Exception as e:
            logger.error(f"BM25 scoring error: {e}")
            return np.zeros(len(self.documents))
|
|
|
class OptimizedRagSystem: |
|
"""GPU-optimized RAG system for ArXiv papers""" |
|
|
|
def __init__(self):
    """Set up empty retrieval state and load every model up front."""
    # Helpers that need no models.
    self.memory_manager = GPUMemoryManager()
    self.bm25 = BM25Retriever()

    # Per-search corpus state.
    self.papers = []
    self.chunks = []
    self.embeddings = None

    # Model handles, populated by _load_models().
    self.embedding_model = None
    self.reranker = None
    self.generator = None

    self._load_models()
|
|
|
def _load_models(self):
    """Load the embedding model, the reranker and the text generator.

    All three are placed on ``DEVICE``.  On CUDA the embedding model is
    converted to half precision and the generator is loaded in float16.
    Any failure is logged and re-raised.
    """
    on_gpu = DEVICE == "cuda"

    try:
        logger.info("Loading models...")

        # Bi-encoder used for chunk and query embeddings.
        embedder = SentenceTransformer(
            'sentence-transformers/all-MiniLM-L6-v2',
            device=DEVICE,
        )
        self.embedding_model = embedder
        if on_gpu:
            embedder.half()

        # Cross-encoder used for the second-stage reranking.
        self.reranker = CrossEncoder(
            'cross-encoder/ms-marco-MiniLM-L-6-v2',
            device=DEVICE,
        )

        # Text-generation pipeline used to draft the answer.
        generator_options = {
            "model": "microsoft/DialoGPT-small",
            "tokenizer": "microsoft/DialoGPT-small",
            "device": 0 if on_gpu else -1,
            "torch_dtype": torch.float16 if on_gpu else torch.float32,
            "return_full_text": False,
            "max_new_tokens": 512,
            "do_sample": True,
            "temperature": 0.7,
            "pad_token_id": 50256,
        }
        self.generator = pipeline("text-generation", **generator_options)

        self.memory_manager.optimize_memory()
        logger.info("Models loaded successfully")
    except Exception as e:
        logger.error(f"Model loading error: {e}")
        raise
|
|
|
def search_arxiv(self, query: str, max_results: int = 15, categories: Optional[List[str]] = None) -> List["Paper"]:
    """Search ArXiv, retrying on failure and falling back to canned queries.

    The user query is first simplified: anything mentioning "attention"
    and/or "transformer" is replaced by a fixed phrase, and such queries get
    default categories when none were supplied.  The search is attempted up
    to three times with exponential back-off.  If all attempts fail or
    return nothing, a list of generic fallback queries is tried.

    Args:
        query: Free-text search query.
        max_results: Maximum number of papers to return (capped at 50).
        categories: Optional ArXiv category codes used as an OR filter.

    Returns:
        The papers found, or an empty list when every search method failed.
    """
    # Local imports keep this method self-contained.
    import threading
    from contextlib import contextmanager

    max_retries = 3
    retry_delay = 1.0
    search_timeout = 30  # seconds allowed for one search attempt

    @contextmanager
    def _time_limit(seconds):
        """Interrupt the body with TimeoutError via SIGALRM when possible.

        Signal handlers can only be installed from the main thread and
        SIGALRM is not available on every platform; in those situations
        this is a no-op and only the cooperative deadline check applies.
        """
        can_use_alarm = (
            hasattr(signal, "SIGALRM")
            and threading.current_thread() is threading.main_thread()
        )
        if not can_use_alarm:
            yield
            return

        def _on_alarm(signum, frame):
            raise TimeoutError("ArXiv search timeout")

        previous_handler = signal.signal(signal.SIGALRM, _on_alarm)
        signal.alarm(seconds)
        try:
            yield
        finally:
            signal.alarm(0)
            # Put back whatever handler was installed before us.
            signal.signal(signal.SIGALRM, previous_handler)

    def _iter_results(search):
        """Iterate results through arxiv.Client when the library has one."""
        client_cls = getattr(arxiv, "Client", None)
        if client_cls is not None:
            return client_cls().results(search)
        return search.results()

    def _to_paper(result, index):
        """Convert one ArXiv result into a Paper, filling missing fields."""
        entry_id = result.entry_id or ""
        return Paper(
            id=entry_id.split('/')[-1] if entry_id else f"unknown_{index}",
            title=result.title.strip(),
            abstract=result.summary.strip(),
            authors=[author.name for author in (result.authors or [])],
            categories=list(result.categories or []),
            published=result.published or datetime.now(),
            url=entry_id or f"https://arxiv.org/abs/{index}",
        )

    def _collect(search_text, limit):
        """Run one search and return at most ``limit`` usable papers."""
        search = arxiv.Search(
            query=search_text,
            max_results=limit,
            sort_by=arxiv.SortCriterion.Relevance,
            sort_order=arxiv.SortOrder.Descending,
        )
        found = []
        deadline = time.monotonic() + search_timeout
        with _time_limit(search_timeout):
            for result in _iter_results(search):
                # Cooperative timeout, effective in worker threads too.
                if time.monotonic() > deadline:
                    if found:
                        break
                    raise TimeoutError("ArXiv search timeout")

                if not result.title or not result.summary:
                    logger.warning("Skipping paper with missing title/abstract")
                    continue

                try:
                    found.append(_to_paper(result, len(found)))
                except TimeoutError:
                    raise
                except Exception as e:
                    logger.warning(f"Error processing individual paper: {e}")
                    continue

                if len(found) >= limit:
                    break
        return found

    def _build_search_text():
        """Apply the query simplification and the category filter."""
        text = (query or "").strip()
        if len(text) < 2:
            logger.warning("Query too short, using default search")
            text = "machine learning"

        lowered = text.lower()
        has_attention = "attention" in lowered
        has_transformer = "transformer" in lowered
        if has_attention and has_transformer:
            text = "attention mechanism transformer"
        elif has_transformer:
            text = "transformer neural network"
        elif has_attention:
            text = "attention mechanism"
        logger.info(f"Simplified query: '{text}'")

        active_categories = [cat.strip() for cat in (categories or []) if cat and cat.strip()]
        if not active_categories and (has_attention or has_transformer):
            active_categories = ["cs.CL", "cs.LG", "cs.AI"]
            logger.info(f"Added default categories for transformer search: {active_categories}")

        if active_categories:
            category_filter = " OR ".join(f"cat:{cat}" for cat in active_categories)
            text = f"({text}) AND ({category_filter})"
        return text

    limit = max(1, min(int(max_results), 50))
    search_text = _build_search_text()

    for attempt in range(1, max_retries + 1):
        logger.info(f"🔍 ArXiv search attempt {attempt}: '{search_text}'")
        try:
            papers = _collect(search_text, limit)
        except TimeoutError:
            logger.warning(f"ArXiv search timeout on attempt {attempt}")
        except Exception as e:
            logger.error(f"ArXiv search error on attempt {attempt}: {type(e).__name__}: {e}")
        else:
            if papers:
                logger.info(f"✅ Successfully found {len(papers)} papers")
                return papers
            logger.warning(f"No papers found on attempt {attempt}")

        if attempt < max_retries:
            logger.info(f"Retrying in {retry_delay} seconds...")
            time.sleep(retry_delay)
            retry_delay *= 2

    # Last resort: generic queries, so the UI still has something to show.
    logger.warning("All search attempts failed, trying fallback search...")
    fallback_queries = [
        "attention is all you need",
        "transformer attention mechanism",
        "BERT language representation",
        "GPT generative pretrained",
        "artificial intelligence",
    ]
    for fallback_query in fallback_queries:
        logger.info(f"Trying fallback: '{fallback_query}'")
        try:
            papers = _collect(fallback_query, 5)
        except Exception as e:
            # A failure here stops the remaining fallbacks.
            logger.error(f"Even fallback search failed: {e}")
            break
        if papers:
            logger.info(f"🔄 Fallback search '{fallback_query}' returned {len(papers)} papers")
            return papers

    logger.error("❌ All ArXiv search methods failed")
    return []
|
|
|
def create_chunks(self, papers: List["Paper"], chunk_size: int = 3) -> List["Chunk"]:
    """Split papers into retrievable chunks.

    Each paper yields one "title" chunk plus one "abstract" chunk per group
    of ``chunk_size`` consecutive abstract sentences.  Every chunk keeps a
    reference to its paper under ``metadata["paper"]``.

    Args:
        papers: Papers to split.
        chunk_size: Sentences per abstract chunk (values below 1 are
            treated as 1).

    Returns:
        All chunks, in paper order.  A paper that fails part-way is logged
        and skipped; chunks already created for it are kept.
    """
    chunk_size = max(1, int(chunk_size))
    chunks = []

    for paper in papers:
        try:
            chunks.append(Chunk(
                id=f"{paper.id}_title",
                paper_id=paper.id,
                text=paper.title,
                chunk_type="title",
                metadata={"paper": paper},
            ))

            try:
                abstract_sentences = sent_tokenize(paper.abstract)
            except LookupError as e:
                # Tokenizer data is not installed: fall back to a simple
                # punctuation-based split instead of dropping the abstract.
                logger.warning(f"Sentence tokenizer unavailable, using regex split: {e}")
                abstract_sentences = [
                    sentence
                    for sentence in re.split(r'(?<=[.!?])\s+', paper.abstract.strip())
                    if sentence
                ]

            for i in range(0, len(abstract_sentences), chunk_size):
                chunks.append(Chunk(
                    id=f"{paper.id}_abstract_{i}",
                    paper_id=paper.id,
                    text=' '.join(abstract_sentences[i:i + chunk_size]),
                    chunk_type="abstract",
                    metadata={"paper": paper},
                ))
        except Exception as e:
            logger.warning(f"Error creating chunks for paper {paper.id}: {e}")
            continue

    return chunks
|
|
|
@spaces.GPU(duration=120)
def embed_chunks(self, chunks: List[Chunk]) -> np.ndarray:
    """Encode the text of every chunk with the embedding model.

    Texts are encoded in batches (32 on CUDA, 8 otherwise), moved to the
    CPU and stacked into one array.  The GPU cache is cleared before,
    periodically during, and after encoding.

    Returns:
        The stacked embeddings, one row per chunk, or an empty array when
        ``chunks`` is empty or encoding fails (the error is logged).
    """
    def _inference_context():
        # Autocast on CUDA, plain no-grad on CPU.
        if DEVICE == "cuda":
            return torch.cuda.amp.autocast()
        return torch.no_grad()

    try:
        if not chunks:
            return np.array([])

        logger.info(f"Creating embeddings for {len(chunks)} chunks")
        self.memory_manager.clear_cache()

        texts = [chunk.text for chunk in chunks]
        batch_size = 32 if DEVICE == "cuda" else 8
        cleanup_interval = batch_size * 4
        encoded_batches = []

        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]

            with _inference_context():
                batch_tensor = self.embedding_model.encode(
                    batch,
                    convert_to_tensor=True,
                    show_progress_bar=False,
                    batch_size=len(batch),
                )

            if DEVICE == "cuda":
                batch_tensor = batch_tensor.cpu()
            encoded_batches.append(batch_tensor.numpy())

            # Free cached memory every fourth batch.
            if start % cleanup_interval == 0:
                self.memory_manager.clear_cache()

        if encoded_batches:
            result = np.vstack(encoded_batches)
        else:
            result = np.array([])
        self.memory_manager.clear_cache()

        logger.info(f"Created embeddings shape: {result.shape}")
        return result
    except Exception as e:
        logger.error(f"Embedding error: {e}")
        self.memory_manager.clear_cache()
        return np.array([])
|
|
|
@spaces.GPU(duration=60)
def hybrid_retrieval(self, query: str, top_k: int = 10, semantic_weight: float = 0.7) -> List[Tuple[Chunk, float]]:
    """Rank the indexed chunks by a blend of semantic and BM25 scores.

    Both score vectors are min-max normalised to [0, 1] (a constant vector
    becomes all 0.5) and combined as
    ``semantic_weight * semantic + (1 - semantic_weight) * bm25``.
    If the BM25 scores do not line up with the chunks (for example because
    BM25 fitting failed), ranking falls back to the semantic scores alone
    instead of returning nothing.

    Args:
        query: Search query.
        top_k: Number of results to return.
        semantic_weight: Weight of the semantic score, clamped to [0, 1].

    Returns:
        ``(chunk, score)`` pairs sorted by descending score; an empty list
        when nothing is indexed or an error occurs (the error is logged).
    """
    def _min_max(values: np.ndarray) -> np.ndarray:
        """Scale to [0, 1]; constant or empty input maps to 0.5."""
        if len(values) > 0 and values.max() > values.min():
            return (values - values.min()) / (values.max() - values.min())
        return np.ones_like(values) * 0.5

    try:
        if not self.chunks or self.embeddings is None or len(self.embeddings) == 0:
            return []

        top_k = max(0, int(top_k))
        semantic_weight = min(1.0, max(0.0, float(semantic_weight)))

        self.memory_manager.clear_cache()

        with torch.cuda.amp.autocast() if DEVICE == "cuda" else torch.no_grad():
            query_embedding = self.embedding_model.encode(
                [query],
                convert_to_tensor=True,
                show_progress_bar=False,
            )

        if DEVICE == "cuda":
            query_embedding = query_embedding.cpu()
        query_embedding = query_embedding.numpy()

        semantic_scores = cosine_similarity(query_embedding, self.embeddings)[0]

        # Guard against embeddings and chunks getting out of step.
        usable = min(len(semantic_scores), len(self.chunks))
        semantic_scores = semantic_scores[:usable]
        chunks = self.chunks[:usable]

        bm25_scores = np.asarray(self.bm25.get_scores(query), dtype=float)

        if len(bm25_scores) == usable:
            combined_scores = (
                semantic_weight * _min_max(semantic_scores)
                + (1 - semantic_weight) * _min_max(bm25_scores)
            )
        else:
            logger.warning(
                f"BM25 returned {len(bm25_scores)} scores for {usable} chunks; "
                "using semantic scores only"
            )
            combined_scores = _min_max(semantic_scores)

        combined_scores = np.maximum(combined_scores, 0.0)

        top_indices = np.argsort(combined_scores)[::-1][:top_k]
        results = [(chunks[i], float(combined_scores[i])) for i in top_indices]

        self.memory_manager.clear_cache()
        return results
    except Exception as e:
        logger.error(f"Retrieval error: {e}")
        self.memory_manager.clear_cache()
        return []
|
|
|
@spaces.GPU(duration=60)
def rerank_results(self, query: str, results: List[Tuple[Chunk, float]], top_k: int = 5) -> List[Tuple[Chunk, float]]:
    """Re-score retrieved chunks with the cross-encoder and keep the best.

    The final score of each chunk is
    ``0.6 * relevance + 0.4 * max(0, original_score)``, where ``relevance``
    is the cross-encoder output mapped to [0, 1].

    Args:
        query: Search query.
        results: ``(chunk, score)`` pairs from the first retrieval stage.
        top_k: Number of results to return.

    Returns:
        The ``top_k`` pairs by descending combined score.  When there is no
        reranker, no input, or an error occurs, the first ``top_k`` input
        pairs are returned unchanged.
    """
    top_k = max(0, int(top_k))
    try:
        if not results or not self.reranker:
            return results[:top_k]

        self.memory_manager.clear_cache()

        pairs = [(query, chunk.text) for chunk, _ in results]

        with torch.cuda.amp.autocast() if DEVICE == "cuda" else torch.no_grad():
            raw_scores = self.reranker.predict(pairs, show_progress_bar=False)

        raw_scores = np.asarray(raw_scores, dtype=np.float64).reshape(-1)

        # NOTE(review): this cross-encoder appears to return unbounded
        # logits rather than values in [-1, 1]; a linear rescale + clip
        # would then collapse most relevant results onto 1.0.  Scores that
        # leave [0, 1] are therefore squashed with a logistic function,
        # while scores already inside [0, 1] are used as they are.
        if raw_scores.size and (raw_scores.min() < 0.0 or raw_scores.max() > 1.0):
            relevance = 1.0 / (1.0 + np.exp(-np.clip(raw_scores, -60.0, 60.0)))
        else:
            relevance = raw_scores

        reranked_results = [
            (chunk, 0.6 * float(relevance[i]) + 0.4 * max(0.0, original_score))
            for i, (chunk, original_score) in enumerate(results)
        ]
        reranked_results.sort(key=lambda pair: pair[1], reverse=True)

        self.memory_manager.clear_cache()
        return reranked_results[:top_k]
    except Exception as e:
        logger.error(f"Reranking error: {e}")
        self.memory_manager.clear_cache()
        return results[:top_k]
|
|
|
@spaces.GPU(duration=90)
def generate_answer(self, query: str, context_chunks: List[Chunk]) -> str:
    """Draft an answer to ``query`` from the top retrieved chunks.

    Up to three chunks that carry a paper in their metadata are formatted
    as "Title / Content" context (truncated to 2000 characters) and passed
    to the text generator together with the query.

    Returns:
        The generated text; a fixed "no relevant information" message when
        there is no generator or no usable context; a fallback message when
        the model returns empty text; or an "Error generating answer: ..."
        message when generation fails.
    """
    no_context_message = "No relevant information found to answer your query."
    try:
        if not context_chunks or not self.generator:
            return no_context_message

        context_parts = []
        for chunk in context_chunks[:3]:
            paper = chunk.metadata.get("paper")
            if paper:
                context_parts.append(f"Title: {paper.title}\nContent: {chunk.text}")

        # Without any context there is nothing to ground the answer on.
        if not context_parts:
            return no_context_message

        self.memory_manager.clear_cache()

        context = "\n\n".join(context_parts)

        prompt = f"""Based on the following research papers, provide a comprehensive answer to the query:

Query: {query}

Research Context:
{context[:2000]}

Answer:"""

        with torch.cuda.amp.autocast() if DEVICE == "cuda" else torch.no_grad():
            response = self.generator(
                prompt,
                max_new_tokens=300,
                temperature=0.7,
                do_sample=True,
                pad_token_id=50256,  # same id _load_models passes to the pipeline
            )

        answer = response[0]['generated_text'].strip()

        self.memory_manager.clear_cache()

        if not answer:
            # The model can stop immediately; avoid showing a blank answer.
            return "The model did not produce an answer. Please review the retrieved papers below."
        return answer
    except Exception as e:
        logger.error(f"Answer generation error: {e}")
        self.memory_manager.clear_cache()
        return f"Error generating answer: {str(e)}"
|
|
|
def format_results(self, results: List[Tuple["Chunk", float]]) -> Tuple[str, pd.DataFrame]:
    """Render ranked chunks as a markdown report and a summary table.

    Chunks are grouped by the paper stored in ``chunk.metadata["paper"]``;
    each paper is ranked by the best score among its chunks and at most
    eight papers are shown.  Titles and abstracts are whitespace-normalised
    first, because a line break inside a title would break the markdown
    link and the table cell.

    Args:
        results: ``(chunk, score)`` pairs.

    Returns:
        ``(markdown, dataframe)``.  With no results the markdown is
        "No relevant papers found." and the frame is empty; on error the
        markdown carries the error message and the frame is empty.
    """
    try:
        if not results:
            return "No relevant papers found.", pd.DataFrame()

        # Group chunks per paper, tracking the best score of each paper.
        papers_dict = {}
        for chunk, score in results:
            paper = chunk.metadata.get("paper")
            if not paper:
                continue
            entry = papers_dict.get(paper.id)
            if entry is None:
                papers_dict[paper.id] = {
                    'paper': paper,
                    'max_score': score,
                    'chunks': [(chunk, score)],
                }
            else:
                entry['chunks'].append((chunk, score))
                entry['max_score'] = max(entry['max_score'], score)

        sorted_papers = sorted(papers_dict.values(), key=lambda x: x['max_score'], reverse=True)

        markdown_parts = []
        table_data = []

        for i, paper_info in enumerate(sorted_papers[:8], 1):
            paper = paper_info['paper']
            score = paper_info['max_score']

            # Collapse runs of whitespace (including newlines) to one space.
            title = " ".join(paper.title.split())
            abstract = " ".join(paper.abstract.split())

            authors_str = ", ".join(paper.authors[:3])
            if len(paper.authors) > 3:
                authors_str += " et al."

            categories_str = ", ".join(paper.categories[:3])
            published_str = paper.published.strftime('%Y-%m-%d')

            markdown_parts.append(f"""
### {i}. [{title}]({paper.url})

**Authors:** {authors_str}
**Categories:** {categories_str}
**Published:** {published_str}
**Relevance Score:** {score:.3f}

**Abstract:** {abstract[:300]}{'...' if len(abstract) > 300 else ''}

---
""")

            table_data.append({
                'Rank': i,
                'Title': title[:60] + ('...' if len(title) > 60 else ''),
                'Authors': authors_str,
                'Categories': categories_str,
                'Published': published_str,
                'Score': f"{score:.3f}",
                'URL': paper.url,
            })

        markdown_text = "".join(markdown_parts)
        df = pd.DataFrame(table_data)

        return markdown_text, df
    except Exception as e:
        logger.error(f"Formatting error: {e}")
        return f"Error formatting results: {str(e)}", pd.DataFrame()
|
|
|
|
|
# Process-wide RAG system instance, created lazily by initialize_system().
rag_system = None


def initialize_system():
    """Create the global ``rag_system`` on first use.

    Calling it again once the instance exists does nothing.  A failure
    while building the system is logged and re-raised.
    """
    global rag_system

    if rag_system is not None:
        return

    try:
        logger.info("Initializing RAG system...")
        rag_system = OptimizedRagSystem()
        logger.info("RAG system initialized successfully")
    except Exception as e:
        logger.error(f"System initialization error: {e}")
        raise
|
|
|
|
|
@spaces.GPU(duration=180)
def search_papers(query: str, max_papers: int = 15, top_k_retrieval: int = 10,
                  top_k_rerank: int = 5, categories: str = "", semantic_weight: float = 0.7):
    """Run the full pipeline for one UI request.

    Steps: ArXiv search, chunking, embedding, BM25 fitting, hybrid
    retrieval, cross-encoder reranking, answer generation and formatting.
    Numeric inputs are coerced (UI widgets may deliver floats or ``None``)
    and ``semantic_weight`` is clamped to [0, 1].

    Args:
        query: Research question.
        max_papers: Maximum number of papers to fetch.
        top_k_retrieval: Chunks kept after hybrid retrieval.
        top_k_rerank: Chunks kept after reranking.
        categories: Comma-separated ArXiv category codes (may be empty).
        semantic_weight: Weight of semantic vs. keyword scores.

    Returns:
        ``(answer_and_stats_markdown, papers_markdown, papers_dataframe)``.
        On any failure the first element is an error message and the other
        two are empty.
    """
    try:
        query = (query or "").strip()
        if not query:
            return "❌ Please enter a search query.", "", pd.DataFrame()

        max_papers = int(max_papers)
        top_k_retrieval = int(top_k_retrieval)
        top_k_rerank = int(top_k_rerank)
        semantic_weight = min(1.0, max(0.0, float(semantic_weight)))

        initialize_system()
        system = rag_system

        start_time = time.time()

        category_list = [cat.strip() for cat in (categories or "").split(',') if cat.strip()]

        papers = system.search_arxiv(query, max_papers, category_list)
        if not papers:
            return "❌ No papers found. Try different keywords or check your internet connection.", "", pd.DataFrame()

        system.papers = papers
        system.chunks = system.create_chunks(papers)
        if not system.chunks:
            return "❌ Error processing papers.", "", pd.DataFrame()

        system.embeddings = system.embed_chunks(system.chunks)
        if system.embeddings is None or len(system.embeddings) == 0:
            return "❌ Error creating embeddings.", "", pd.DataFrame()

        # Keyword index over the same chunks, in the same order.
        system.bm25.fit([chunk.text for chunk in system.chunks])

        retrieved_results = system.hybrid_retrieval(query, top_k_retrieval, semantic_weight)
        if not retrieved_results:
            return "❌ No relevant content found.", "", pd.DataFrame()

        reranked_results = system.rerank_results(query, retrieved_results, top_k_rerank)

        answer = system.generate_answer(query, [chunk for chunk, _ in reranked_results])

        papers_md, papers_df = system.format_results(reranked_results)

        processing_time = time.time() - start_time

        stats = f"""
## 🤖 AI-Generated Answer

{answer}

## 📊 Search Statistics

- **Query:** {query}
- **Papers Found:** {len(papers)}
- **Chunks Processed:** {len(system.chunks)}
- **Top Results:** {len(reranked_results)}
- **Processing Time:** {processing_time:.2f}s
- **GPU Memory:** {system.memory_manager.get_memory_info()}
- **Semantic Weight:** {semantic_weight}

---
"""

        system.memory_manager.clear_cache()

        return stats, papers_md, papers_df
    except Exception as e:
        logger.error(f"Search error: {e}")
        error_msg = f"❌ An error occurred: {str(e)}\n\nPlease try different keywords or check your internet connection."
        return error_msg, "", pd.DataFrame()
|
|
|
|
|
def create_interface():
    """Build and return the Gradio Blocks UI (not launched here).

    Layout: a header banner; a two-column row with the query controls on
    the left and usage tips on the right; the answer/statistics panel;
    tabs with the paper list and paper table; example queries; a footer.
    The search button is wired to ``search_papers``.

    NOTE(review): the emoji and bullet characters in the labels and HTML
    below were restored from mis-encoded text; where the original glyph
    could not be determined a fitting one was chosen.  Confirm against the
    intended design.
    """
    css = """
    .gradio-container {
        font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
    }
    .gpu-badge {
        background: linear-gradient(45deg, #00d4aa, #00b4d8);
        color: white;
        padding: 0.5rem 1rem;
        border-radius: 20px;
        font-weight: bold;
        display: inline-block;
        margin-bottom: 1rem;
    }
    """

    header_html = f"""
    <div style="text-align: center; background: linear-gradient(90deg, #667eea 0%, #764ba2 100%); color: white; padding: 2rem; border-radius: 10px; margin-bottom: 2rem;">
        <h1>🚀 Enhanced ArXiv RAG System</h1>
        <p>GPU-Optimized scientific paper discovery with semantic search, BM25, and neural reranking</p>
        <div class="gpu-badge">
            🔥 GPU Accelerated • Device: {DEVICE.upper()}
        </div>
    </div>
    """

    gpu_info_html = f"""
    <div style="background: #e8f5e8; padding: 1rem; border-radius: 8px; margin-top: 1rem;">
        <h4>⚡ GPU Optimization Info</h4>
        <ul>
            <li><strong>Device:</strong> {DEVICE.upper()}</li>
            <li><strong>Mixed Precision:</strong> {'Enabled' if DEVICE == 'cuda' else 'Disabled'}</li>
            <li><strong>Memory Management:</strong> Automatic cleanup</li>
            <li><strong>Batch Processing:</strong> Optimized for GPU</li>
        </ul>
    </div>
    """

    tips_html = """
    <div style="background: #e3f2fd; padding: 1rem; border-radius: 8px;">
        <h4>💡 Tips for Best Results</h4>
        <ul>
            <li>Use specific technical terms</li>
            <li>Try different category filters</li>
            <li>Adjust semantic weight for different search styles</li>
            <li>Higher semantic weight = more conceptual matching</li>
            <li>Lower semantic weight = more keyword matching</li>
        </ul>

        <h4>📚 Popular Categories</h4>
        <ul>
            <li><code>cs.AI</code> - Artificial Intelligence</li>
            <li><code>cs.CL</code> - Computation and Language</li>
            <li><code>cs.LG</code> - Machine Learning</li>
            <li><code>cs.CV</code> - Computer Vision</li>
            <li><code>cs.RO</code> - Robotics</li>
            <li><code>stat.ML</code> - Machine Learning (Stats)</li>
        </ul>
    </div>
    """

    footer_html = """
    <div style="text-align: center; margin-top: 2rem; padding: 1rem; background: #f5f5f5; border-radius: 8px;">
        <p><strong>Enhanced ArXiv RAG System</strong> | GPU-Optimized • Semantic Search + BM25 + Neural Reranking</p>
        <p><em>Powered by Hugging Face Spaces GPU • Optimized for high-performance research</em></p>
    </div>
    """

    with gr.Blocks(css=css, title="Enhanced ArXiv RAG System - GPU Optimized") as interface:
        gr.HTML(header_html)

        with gr.Row():
            # Left column: query and search settings.
            with gr.Column(scale=2):
                query_input = gr.Textbox(
                    label="Research Query",
                    placeholder="Enter your research question (e.g., 'transformer attention mechanisms in NLP')",
                    lines=2,
                )

                with gr.Row():
                    max_papers = gr.Slider(5, 25, value=15, step=1, label="Max Papers")
                    semantic_weight = gr.Slider(0.1, 0.9, value=0.7, step=0.1, label="Semantic Weight")

                categories_input = gr.Textbox(
                    label="ArXiv Categories (Optional)",
                    placeholder="e.g., cs.CL, cs.AI, cs.LG",
                    value="",
                )

                with gr.Accordion("🔧 Advanced GPU Settings", open=False):
                    with gr.Row():
                        top_k_retrieval = gr.Slider(5, 15, value=10, step=1, label="Top-K Retrieval")
                        top_k_rerank = gr.Slider(3, 8, value=5, step=1, label="Top-K Reranking")

                    gr.HTML(gpu_info_html)

                search_btn = gr.Button("🔍 Search Papers", variant="primary", size="lg")

            # Right column: static help text.
            with gr.Column(scale=1):
                gr.HTML(tips_html)

        answer_output = gr.Markdown(label="AI Answer & Statistics")

        with gr.Tabs():
            with gr.TabItem("📄 Papers"):
                papers_output = gr.Markdown(label="Relevant Papers")

            with gr.TabItem("📊 Papers Table"):
                papers_table = gr.Dataframe(
                    label="Papers Summary",
                    wrap=True,
                    interactive=False,
                )

        # The input order here matches the positional parameters of search_papers.
        search_inputs = [
            query_input, max_papers, top_k_retrieval,
            top_k_rerank, categories_input, semantic_weight,
        ]

        gr.Examples(
            examples=[
                ["transformer attention mechanisms", 15, 10, 5, "cs.CL, cs.AI", 0.7],
                ["graph neural networks for molecular property prediction", 12, 8, 4, "cs.LG", 0.6],
                ["computer vision deep learning", 15, 10, 5, "cs.CV", 0.8],
                ["reinforcement learning robotics", 18, 10, 5, "cs.AI, cs.RO", 0.7],
                ["large language models fine-tuning", 20, 12, 6, "cs.CL", 0.75],
            ],
            inputs=search_inputs,
        )

        search_btn.click(
            fn=search_papers,
            inputs=search_inputs,
            outputs=[answer_output, papers_output, papers_table],
        )

        gr.HTML(footer_html)

    return interface
|
|
|
|
|
if __name__ == "__main__":
    # Try to load the models before the UI starts.  A failure here is only
    # logged: search_papers() calls initialize_system() again on each request.
    try:
        initialize_system()
    except Exception as e:
        logger.error(f"Pre-initialization failed: {e}")

    launch_options = {"show_error": True, "share": True}
    interface = create_interface()
    interface.launch(**launch_options)