Spaces:
Sleeping
Sleeping
""" | |
Enhanced ArXiv RAG System - Hugging Face Spaces Compatible Version | |
""" | |
import os | |
import re | |
import arxiv | |
import numpy as np | |
import pandas as pd | |
from typing import List, Dict, Tuple, Optional, Any | |
from dataclasses import dataclass | |
from datetime import datetime, timedelta | |
import logging | |
import tempfile | |
import shutil | |
# Core ML libraries | |
import torch | |
from sentence_transformers import SentenceTransformer, CrossEncoder | |
from transformers import pipeline | |
import gradio as gr | |
# BM25 and text processing | |
from sklearn.feature_extraction.text import TfidfVectorizer | |
from sklearn.metrics.pairwise import cosine_similarity | |
import nltk | |
from nltk.corpus import stopwords | |
from nltk.tokenize import word_tokenize, sent_tokenize | |
from nltk.stem import PorterStemmer | |
# Download required NLTK data.
# Each entry is (resource path checked with nltk.data.find, package name passed
# to nltk.download). NLTK >= 3.9 loads the Punkt sentence tokenizer used by
# word_tokenize / sent_tokenize from "punkt_tab" rather than the pickled
# "punkt" model, so both are ensured; missing ones are fetched at import time.
_NLTK_RESOURCES = (
    ("tokenizers/punkt", "punkt"),
    ("tokenizers/punkt_tab", "punkt_tab"),
    ("corpora/stopwords", "stopwords"),
)


def _ensure_nltk_data() -> None:
    """Download every resource in ``_NLTK_RESOURCES`` that is not installed yet."""
    for resource_path, package in _NLTK_RESOURCES:
        try:
            nltk.data.find(resource_path)
        except LookupError:
            nltk.download(package, quiet=True)


_ensure_nltk_data()

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@dataclass
class Paper:
    """Metadata for a single arXiv paper.

    Declared as a dataclass so it can be built with keyword arguments
    (``Paper(id=..., title=..., ...)``); without the decorator the class
    has no generated ``__init__`` and such calls raise ``TypeError``.
    """

    id: str  # arXiv identifier, used as the cache key and chunk-id prefix
    title: str
    abstract: str
    authors: List[str]  # author display names
    categories: List[str]  # arXiv category codes, e.g. "cs.CL"
    published: datetime
    url: str
@dataclass
class Chunk:
    """A retrievable piece of text derived from one paper.

    Declared as a dataclass so keyword construction
    (``Chunk(id=..., paper_id=..., ...)``) works.
    """

    id: str  # unique chunk id, formed as "<paper_id>_<chunk_type>" by the chunker
    paper_id: str  # id of the Paper this text came from
    text: str
    chunk_type: str  # e.g. "title", "abstract" or "combined"
    metadata: Dict[str, Any]  # extra fields stored alongside the embedding
class BM25Retriever:
    """Okapi BM25 retriever for keyword-based search.

    Text is lower-cased, word-tokenized, filtered to alphabetic non-stop-word
    tokens and Porter-stemmed (see ``preprocess_text``) both when indexing
    and when querying.
    """

    def __init__(self, k1: float = 1.5, b: float = 0.75):
        """
        Args:
            k1: Term-frequency saturation parameter.
            b: Strength of document-length normalization (0 = none, 1 = full).
        """
        self.k1 = k1
        self.b = b
        self.documents: List[List[str]] = []
        self.doc_lengths: List[int] = []
        self.avg_doc_length = 0
        # Initialized here so score() before fit() returns [] instead of
        # raising AttributeError.
        self.vocab: List[str] = []
        self.term_freqs: List[Dict[str, int]] = []
        self.idf: Dict[str, float] = {}
        self.stemmer = PorterStemmer()
        try:
            self.stop_words = set(stopwords.words("english"))
        except Exception as exc:
            # Best effort: without the stopwords corpus, index every token.
            logger.warning(f"Stopwords unavailable, continuing without them: {exc}")
            self.stop_words = set()

    def preprocess_text(self, text: str) -> List[str]:
        """Return the stemmed, stop-word-filtered alphabetic tokens of ``text``."""
        tokens = word_tokenize(text.lower())
        return [
            self.stemmer.stem(token)
            for token in tokens
            if token.isalpha() and token not in self.stop_words
        ]

    def fit(self, documents: List[str]):
        """Index ``documents``; positions in this list are the ids returned by ``score``."""
        self.documents = [self.preprocess_text(doc) for doc in documents]
        self.doc_lengths = [len(doc) for doc in self.documents]
        n_docs = len(self.documents)
        self.avg_doc_length = sum(self.doc_lengths) / n_docs if n_docs else 0

        # Term frequencies per document and document frequencies per term,
        # gathered in one pass.
        self.term_freqs = []
        doc_freqs: Dict[str, int] = {}
        for doc in self.documents:
            tf: Dict[str, int] = {}
            for term in doc:
                tf[term] = tf.get(term, 0) + 1
            self.term_freqs.append(tf)
            for term in tf:
                doc_freqs[term] = doc_freqs.get(term, 0) + 1
        self.vocab = list(doc_freqs)

        # Lucene-style IDF, log(1 + (N - n + 0.5) / (n + 0.5)), is always
        # positive. The classic form log((N - n + 0.5) / (n + 0.5)) is negative
        # for terms in more than half the documents, which would rank documents
        # that *contain* a query term below documents that do not.
        self.idf = {
            term: float(np.log1p((n_docs - df + 0.5) / (df + 0.5)))
            for term, df in doc_freqs.items()
        }

    def score(self, query: str, top_k: int = 10) -> List[Tuple[int, float]]:
        """Score every indexed document against ``query``.

        Returns:
            Up to ``top_k`` ``(document_index, score)`` pairs, highest score
            first; equal scores keep index order. Documents sharing no term
            with the query score 0.
        """
        query_terms = self.preprocess_text(query)
        scores: List[Tuple[int, float]] = []
        for i, (tf, doc_len) in enumerate(zip(self.term_freqs, self.doc_lengths)):
            # A matching term implies doc_len >= 1, hence avg_doc_length > 0;
            # the guard only protects the all-empty corpus.
            if self.avg_doc_length:
                length_norm = self.k1 * (1 - self.b + self.b * (doc_len / self.avg_doc_length))
            else:
                length_norm = self.k1
            doc_score = 0.0
            for term in query_terms:
                term_freq = tf.get(term)
                if term_freq:
                    idf = self.idf.get(term, 0.0)
                    doc_score += idf * (term_freq * (self.k1 + 1)) / (term_freq + length_norm)
            scores.append((i, doc_score))
        scores.sort(key=lambda pair: pair[1], reverse=True)
        return scores[:top_k]
class SimpleVectorStore:
    """Simple in-memory vector store for HF Spaces compatibility.

    Ids, embeddings, documents and metadata are kept in parallel lists.
    ``query`` wraps each result list in an outer single-element list (one
    entry per query); ``get`` returns flat lists.
    """

    def __init__(self):
        self.embeddings = []
        self.documents = []
        self.metadatas = []
        self.ids = []
        # id -> position of its first occurrence, for O(1) lookups in get().
        self._id_index: Dict[str, int] = {}

    def add(self, ids: List[str], embeddings: List[List[float]],
            documents: List[str], metadatas: List[Dict]):
        """Append entries; the four sequences are aligned by position."""
        ids = list(ids)
        start = len(self.ids)
        for offset, id_ in enumerate(ids):
            # setdefault keeps the first occurrence, as list.index() would.
            self._id_index.setdefault(id_, start + offset)
        self.ids.extend(ids)
        self.embeddings.extend(embeddings)
        self.documents.extend(documents)
        self.metadatas.extend(metadatas)

    def query(self, query_embedding: List[float], n_results: int = 10) -> Dict:
        """Return the ``n_results`` entries with highest cosine similarity.

        Zero-norm vectors (stored or query) get similarity 0 instead of NaN.
        NaN sorts last in ``argsort`` and would be reversed to the top.
        """
        if not self.embeddings:
            return {"ids": [[]], "documents": [[]], "metadatas": [[]]}
        matrix = np.asarray(self.embeddings, dtype=float)
        query_vec = np.asarray(query_embedding, dtype=float)
        dots = matrix @ query_vec
        denominators = np.linalg.norm(matrix, axis=1) * np.linalg.norm(query_vec)
        similarities = np.divide(
            dots, denominators, out=np.zeros_like(dots), where=denominators > 0
        )
        top_indices = np.argsort(similarities)[::-1][:n_results]
        return {
            "ids": [[self.ids[i] for i in top_indices]],
            "documents": [[self.documents[i] for i in top_indices]],
            "metadatas": [[self.metadatas[i] for i in top_indices]],
        }

    def get(self, ids: Optional[List[str]] = None) -> Dict:
        """Return all entries, or those whose id is in ``ids`` (unknown ids skipped)."""
        if ids is None:
            return {
                "ids": self.ids,
                "documents": self.documents,
                "metadatas": self.metadatas,
            }
        indices = [self._id_index[id_] for id_ in ids if id_ in self._id_index]
        return {
            "ids": [self.ids[i] for i in indices],
            "documents": [self.documents[i] for i in indices],
            "metadatas": [self.metadatas[i] for i in indices],
        }

    def clear(self):
        """Remove every entry from the store."""
        self.embeddings.clear()
        self.documents.clear()
        self.metadatas.clear()
        self.ids.clear()
        self._id_index.clear()
class EnhancedArxivRAG:
    """Hybrid-retrieval RAG pipeline over arXiv abstracts, sized for HF Spaces.

    Each ``search_and_answer`` call rebuilds the index from scratch. It:

    1. fetches papers from arXiv;
    2. splits each paper into title / abstract / combined chunks;
    3. embeds the chunks into an in-memory vector store and fits a BM25 index;
    4. fuses the dense and BM25 rankings;
    5. reranks with a cross-encoder;
    6. summarizes the top chunks.
    """

    def __init__(self):
        """Load the embedding, reranker and summarization models.

        Raises:
            Exception: whatever a model loader raises (logged, then re-raised).
        """
        logger.info("Initializing Enhanced ArXiv RAG System for HF Spaces...")
        # Determine device (GPU if available, else CPU)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {self.device}")
        # Load models with appropriate device settings
        try:
            logger.info("Loading embedding model...")
            self.embedding_model = SentenceTransformer(
                "sentence-transformers/all-MiniLM-L6-v2", device=self.device
            )
            logger.info("Embedding model loaded.")
            logger.info("Loading reranker model...")
            self.reranker = CrossEncoder(
                "cross-encoder/ms-marco-MiniLM-L-2-v2", device=self.device
            )
            logger.info("Reranker model loaded.")
            logger.info("Loading summarizer model...")
            # For pipeline, device_map="auto" is often better for ZeroGPU
            # If issues persist, try device=0 for the first GPU, or device=self.device
            self.summarizer = pipeline(
                "summarization", model="sshleifer/distilbart-cnn-6-6", device_map="auto"
            )
            logger.info("Summarizer model loaded.")
        except Exception as e:
            logger.error(f"Error loading models: {e}")
            raise
        # One arXiv client shared by all queries. Per the arxiv.py docs, a Client
        # spaces out its own requests, so sharing it keeps back-to-back searches polite.
        self.arxiv_client = arxiv.Client()
        # Use simple vector store instead of ChromaDB for HF Spaces
        self.vector_store = SimpleVectorStore()
        self.bm25_retriever = BM25Retriever()
        # Cache for papers and chunks
        self.papers_cache = {}
        self.chunks_cache = {}
        self.bm25_fitted = False
        logger.info("RAG system initialized successfully!")

    @staticmethod
    def _empty_result(answer: str) -> Dict[str, Any]:
        """Build a result with no papers whose stats keys match the success path."""
        return {
            "answer": answer,
            "papers": [],
            "retrieved_chunks": [],
            "search_stats": {
                "papers_found": 0,
                "chunks_retrieved": 0,
                "unique_papers_in_results": 0,
            },
        }

    def fetch_papers(self, query: str, max_results: int = 15,
                     categories: Optional[List[str]] = None) -> List[Paper]:
        """Search arXiv by relevance and cache the resulting papers.

        Args:
            query: arXiv search query.
            max_results: Maximum number of papers to fetch.
            categories: Optional arXiv categories; results must match at least one.

        Returns:
            The fetched papers, or ``[]`` if the request fails (the error is logged).
        """
        search_query = query
        if categories:
            category_filter = " OR ".join(f"cat:{cat}" for cat in categories)
            search_query = f"({query}) AND ({category_filter})"
        logger.info(f"Fetching papers with query: {search_query}")
        try:
            search = arxiv.Search(
                query=search_query,
                max_results=max_results,
                sort_by=arxiv.SortCriterion.Relevance,
            )
            papers = []
            # Client.results() replaces the deprecated Search.results().
            for result in self.arxiv_client.results(search):
                paper = Paper(
                    # The short id keeps the archive prefix of old-style ids
                    # (e.g. "hep-th/9901001v1"), so ids from different archives
                    # stay distinct.
                    id=result.get_short_id(),
                    title=result.title.strip().replace("\n", " "),
                    abstract=result.summary.strip().replace("\n", " "),
                    authors=[author.name for author in result.authors],
                    categories=result.categories,
                    published=result.published.replace(tzinfo=None),
                    url=result.entry_id,
                )
                papers.append(paper)
                self.papers_cache[paper.id] = paper
            logger.info(f"Fetched {len(papers)} papers")
            return papers
        except Exception as e:
            logger.error(f"Error fetching papers: {e}")
            return []

    def create_chunks(self, papers: List[Paper]) -> List[Chunk]:
        """Create title, abstract and combined chunks for each paper and cache them.

        Chunk ids are ``"<paper_id>_<chunk_type>"``.
        """
        chunks: List[Chunk] = []
        for paper in papers:
            base_metadata = {
                "authors": paper.authors,
                "categories": paper.categories,
                "published": paper.published.isoformat(),
                "url": paper.url,
            }
            paper_chunks = [
                Chunk(
                    id=f"{paper.id}_{chunk_type}",
                    paper_id=paper.id,
                    text=text,
                    chunk_type=chunk_type,
                    metadata=dict(base_metadata),
                )
                for chunk_type, text in (
                    ("title", paper.title),
                    ("abstract", paper.abstract),
                    ("combined", f"Title: {paper.title}\n\nAbstract: {paper.abstract}"),
                )
            ]
            chunks.extend(paper_chunks)
            for chunk in paper_chunks:
                self.chunks_cache[chunk.id] = chunk
        return chunks

    def process_and_store(self, papers: List[Paper]):
        """Replace the current index with chunks of ``papers``.

        This covers the vector store, BM25 and the chunk cache.
        """
        logger.info("Processing and storing papers...")
        # Drop the previous query's index, including the BM25 state, so stale
        # chunks can never be returned when `papers` yields no chunks.
        self.vector_store.clear()
        self.chunks_cache.clear()
        self.bm25_fitted = False
        chunks = self.create_chunks(papers)
        if not chunks:
            return
        texts = [chunk.text for chunk in chunks]
        logger.info("Generating embeddings...")
        embeddings = self.embedding_model.encode(texts, show_progress_bar=False)
        self.vector_store.add(
            ids=[chunk.id for chunk in chunks],
            embeddings=embeddings.tolist(),
            documents=texts,
            metadatas=[chunk.metadata for chunk in chunks],
        )
        # BM25 document indices line up with the vector store's insertion order.
        logger.info("Fitting BM25...")
        self.bm25_retriever.fit(texts)
        self.bm25_fitted = True
        logger.info(f"Stored {len(chunks)} chunks")

    def hybrid_search(self, query: str, top_k: int = 10,
                      semantic_weight: float = 0.7) -> List[Dict]:
        """Retrieve chunks by fusing dense (cosine) and BM25 rankings.

        Each ranked list contributes ``weight / rank`` per document: weighted
        reciprocal-rank fusion with no smoothing constant. Semantic results use
        ``semantic_weight``; BM25 results use ``1 - semantic_weight``. Each list
        is cut at ``2 * top_k`` candidates before fusion.

        Returns:
            Up to ``top_k`` dicts with ``id``, ``document``, ``metadata`` and
            ``combined_score``, highest score first.
        """
        candidate_count = top_k * 2
        # id -> (document, metadata) for every candidate either retriever returns.
        candidates: Dict[str, Tuple[str, Dict]] = {}
        combined_scores: Dict[str, float] = {}

        # Semantic search
        query_embedding = self.embedding_model.encode([query])
        semantic_results = self.vector_store.query(
            query_embedding=query_embedding[0].tolist(),
            n_results=candidate_count,
        )
        semantic_hits = zip(
            semantic_results["ids"][0],
            semantic_results["documents"][0],
            semantic_results["metadatas"][0],
        )
        for rank, (doc_id, document, metadata) in enumerate(semantic_hits, start=1):
            candidates.setdefault(doc_id, (document, metadata))
            combined_scores[doc_id] = combined_scores.get(doc_id, 0.0) + semantic_weight / rank

        # BM25 search
        if self.bm25_fitted:
            all_docs = self.vector_store.get()
            bm25_weight = 1.0 - semantic_weight
            # A score <= 0 means the chunk shares no useful query term; ranking
            # it would give unrelated text reciprocal-rank credit.
            bm25_hits = [
                idx
                for idx, score in self.bm25_retriever.score(query, candidate_count)
                if score > 0 and idx < len(all_docs["ids"])
            ]
            for rank, idx in enumerate(bm25_hits, start=1):
                doc_id = all_docs["ids"][idx]
                candidates.setdefault(
                    doc_id, (all_docs["documents"][idx], all_docs["metadatas"][idx])
                )
                combined_scores[doc_id] = combined_scores.get(doc_id, 0.0) + bm25_weight / rank

        # Sort by combined score
        ranked = sorted(combined_scores.items(), key=lambda item: item[1], reverse=True)
        final_results = []
        for doc_id, score in ranked[:top_k]:
            document, metadata = candidates[doc_id]
            final_results.append({
                "id": doc_id,
                "document": document,
                "metadata": metadata,
                "combined_score": score,
            })
        return final_results

    def rerank_results(self, query: str, results: List[Dict], top_k: int = 5) -> List[Dict]:
        """Rerank ``results`` with the cross-encoder and keep the ``top_k`` best.

        Each result dict gains a ``rerank_score`` key (mutated in place).
        """
        if not results:
            return results
        rerank_scores = self.reranker.predict([(query, result["document"]) for result in results])
        for result, rerank_score in zip(results, rerank_scores):
            result["rerank_score"] = float(rerank_score)
        return sorted(results, key=lambda result: result["rerank_score"], reverse=True)[:top_k]

    def generate_answer(self, query: str, context_chunks: List[Dict]) -> str:
        """Summarize the top retrieved chunks into an answer for ``query``.

        The first three chunks are joined and truncated to 800 characters, then
        passed with the question to the summarizer. If summarization fails, the
        first 150 characters of the top two chunks are returned instead.
        """
        if not context_chunks:
            return "No relevant information found to answer your query."
        combined_context = "\n\n".join(chunk["document"] for chunk in context_chunks[:3])
        # Limit context length (in characters) to keep the summarizer input short.
        max_context_length = 800
        if len(combined_context) > max_context_length:
            combined_context = combined_context[:max_context_length] + "..."
        try:
            summary_input = f"Based on the following research papers, answer this question: {query}\n\nContext: {combined_context}"
            summary = self.summarizer(summary_input,
                                      max_length=120,
                                      min_length=30,
                                      do_sample=False)[0]["summary_text"]
            return summary
        except Exception as e:
            logger.error(f"Error generating summary: {e}")
            excerpts = "\n\n".join(chunk["document"][:150] + "..." for chunk in context_chunks[:2])
            return f"Based on the retrieved papers about '{query}', here are the key findings:\n\n" + excerpts

    def search_and_answer(self, query: str, max_papers: int = 15,
                          top_k_retrieval: int = 10, top_k_rerank: int = 5,
                          categories: Optional[List[str]] = None,
                          semantic_weight: float = 0.7) -> Dict[str, Any]:
        """Run the full pipeline for ``query``.

        Returns:
            A dict with ``answer``, ``papers`` (unique papers among the reranked
            chunks), ``retrieved_chunks`` and ``search_stats``. Every path,
            including empty input and errors, includes all three stats keys:
            ``papers_found``, ``chunks_retrieved`` and ``unique_papers_in_results``.
        """
        if not query.strip():
            return self._empty_result("Please enter a valid research query.")
        try:
            papers = self.fetch_papers(query, max_papers, categories)
            if not papers:
                return self._empty_result(
                    "No papers found for your query. Please try different keywords."
                )
            self.process_and_store(papers)
            search_results = self.hybrid_search(query, top_k_retrieval, semantic_weight)
            reranked_results = self.rerank_results(query, search_results, top_k_rerank)
            answer = self.generate_answer(query, reranked_results)

            # Prepare unique papers, in reranked order.
            unique_papers: Dict[str, Dict[str, Any]] = {}
            for chunk in reranked_results:
                # Chunk ids are "<paper_id>_<chunk_type>"; split on the last underscore.
                paper_id = chunk["id"].rsplit("_", 1)[0]
                if paper_id in unique_papers or paper_id not in self.papers_cache:
                    continue
                paper = self.papers_cache[paper_id]
                unique_papers[paper_id] = {
                    "title": paper.title,
                    "authors": paper.authors,
                    "abstract": paper.abstract,
                    "url": paper.url,
                    "categories": paper.categories,
                    "published": paper.published.strftime("%Y-%m-%d"),
                }
            return {
                "answer": answer,
                "papers": list(unique_papers.values()),
                "retrieved_chunks": reranked_results,
                "search_stats": {
                    "papers_found": len(papers),
                    "chunks_retrieved": len(reranked_results),
                    "unique_papers_in_results": len(unique_papers),
                },
            }
        except Exception as e:
            logger.exception(f"Error in search_and_answer: {e}")
            return self._empty_result(
                f"An error occurred while processing your query: {str(e)}"
            )
# Global RAG instance: created lazily by initialize_rag() so the models are
# only loaded when the first search arrives.
rag_system = None


def initialize_rag():
    """Return the shared ``EnhancedArxivRAG``, constructing it on first call."""
    global rag_system
    if rag_system is not None:
        return rag_system
    rag_system = EnhancedArxivRAG()
    return rag_system
def _format_answer(result: Dict[str, Any]) -> str:
    """Render the answer text plus retrieval statistics as Markdown.

    Missing statistics default to 0. Result dicts from early-exit paths may
    omit ``unique_papers_in_results``, and they should still render rather
    than raise ``KeyError``.
    """
    stats = result.get("search_stats", {})
    return (
        f"## π€ AI-Generated Answer\n\n{result['answer']}\n\n"
        "**Search Statistics:**\n"
        f"- Papers found: {stats.get('papers_found', 0)}\n"
        f"- Chunks retrieved: {stats.get('chunks_retrieved', 0)}\n"
        f"- Unique papers in results: {stats.get('unique_papers_in_results', 0)}\n\n"
    )


def _format_papers_markdown(papers: List[Dict[str, Any]]) -> str:
    """Render each paper as a numbered Markdown section.

    Authors are truncated to 3 and abstracts to 250 characters.
    """
    parts = ["## π Relevant Papers\n\n"]
    for i, paper in enumerate(papers, 1):
        authors = paper["authors"]
        abstract = paper["abstract"]
        parts.append(f"### {i}. {paper['title']}\n\n")
        parts.append(f"**Authors:** {', '.join(authors[:3])}{'...' if len(authors) > 3 else ''}\n\n")
        parts.append(f"**Categories:** {', '.join(paper['categories'])}\n\n")
        parts.append(f"**Published:** {paper['published']}\n\n")
        parts.append(f"**Abstract:** {abstract[:250]}{'...' if len(abstract) > 250 else ''}\n\n")
        parts.append(f"**URL:** [{paper['url']}]({paper['url']})\n\n")
        parts.append("---\n\n")
    return "".join(parts)


def _papers_dataframe(papers: List[Dict[str, Any]]) -> pd.DataFrame:
    """Build the compact summary table shown in the "Papers Table" tab."""
    return pd.DataFrame([
        {
            "Title": paper["title"][:50] + "..." if len(paper["title"]) > 50 else paper["title"],
            "Authors": ", ".join(paper["authors"][:2]) + ("..." if len(paper["authors"]) > 2 else ""),
            "Categories": ", ".join(paper["categories"][:2]),
            "Published": paper["published"],
            "URL": paper["url"],
        }
        for paper in papers
    ])


def search_papers(query: str, max_papers: int = 15, top_k_retrieval: int = 10,
                  top_k_rerank: int = 5, categories: str = "",
                  semantic_weight: float = 0.7) -> tuple:
    """Gradio callback: run the RAG pipeline and format its output.

    Args:
        query: Free-text research question.
        max_papers: Maximum number of arXiv papers to fetch.
        top_k_retrieval: Candidates kept after hybrid retrieval.
        top_k_rerank: Results kept after cross-encoder reranking.
        categories: Comma-separated arXiv categories (e.g. "cs.CL, cs.AI"),
            or an empty string for no filter.
        semantic_weight: Weight of dense retrieval; BM25 gets ``1 - semantic_weight``.

    Returns:
        ``(answer_markdown, papers_markdown, papers_dataframe)``. On empty
        input or error, the last two are ``""`` and an empty DataFrame.
    """
    # NOTE(review): the leading "β" / "π…" characters in the user-facing strings
    # look like mis-encoded emoji; restore them from the original source.
    if not (query or "").strip():
        return "β Please enter a research topic or question.", "", pd.DataFrame()
    try:
        rag = initialize_rag()
        # Parse categories; an empty list means "no filter", same as None.
        category_list = [cat.strip() for cat in (categories or "").split(",") if cat.strip()] or None
        # Slider values may arrive as floats (e.g. 15.0), but the pipeline
        # slices lists with these counts, so coerce them explicitly.
        result = rag.search_and_answer(
            query=query,
            max_papers=int(max_papers),
            top_k_retrieval=int(top_k_retrieval),
            top_k_rerank=int(top_k_rerank),
            categories=category_list,
            semantic_weight=float(semantic_weight),
        )
        papers = result["papers"]
        return _format_answer(result), _format_papers_markdown(papers), _papers_dataframe(papers)
    except Exception as e:
        logger.exception(f"Error processing query: {e}")
        error_msg = f"β An error occurred: {str(e)}\n\nPlease try different keywords or check your internet connection."
        return error_msg, "", pd.DataFrame()
# Create Gradio interface
def create_interface():
    """Build the Gradio Blocks UI for the ArXiv RAG search.

    Components are placed in the order they are created inside the nested
    ``gr.Blocks`` / ``gr.Row`` / ``gr.Column`` contexts, so the statements
    below should not be reordered.

    Returns:
        The ``gr.Blocks`` app; the caller is responsible for ``launch()``.
    """
    # Only overrides the font family of the app container.
    css = """
    .gradio-container {
        font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
    }
    """
    with gr.Blocks(css=css, title="Enhanced ArXiv RAG System") as interface:
        # Header banner.
        # NOTE(review): the "π…" characters in the HTML strings below look like
        # mis-encoded emoji; restore them from the original source if available.
        gr.HTML("""
        <div style="text-align: center; background: linear-gradient(90deg, #667eea 0%, #764ba2 100%); color: white; padding: 2rem; border-radius: 10px; margin-bottom: 2rem;">
        <h1>π Enhanced ArXiv RAG System</h1>
        <p>Advanced scientific paper discovery with semantic search, BM25, and neural reranking</p>
        </div>
        """)
        with gr.Row():
            # Left (wider) column: query and search controls.
            with gr.Column(scale=2):
                query_input = gr.Textbox(
                    label="Research Query",
                    placeholder="Enter your research question (e.g., 'transformer attention mechanisms in NLP')",
                    lines=2
                )
                with gr.Row():
                    max_papers = gr.Slider(5, 25, value=15, step=1, label="Max Papers")
                    semantic_weight = gr.Slider(0.1, 0.9, value=0.7, step=0.1, label="Semantic Weight")
                categories_input = gr.Textbox(
                    label="ArXiv Categories (Optional)",
                    placeholder="e.g., cs.CL, cs.AI, cs.LG",
                    value=""
                )
                with gr.Accordion("Advanced Settings", open=False):
                    with gr.Row():
                        top_k_retrieval = gr.Slider(5, 15, value=10, step=1, label="Top-K Retrieval")
                        top_k_rerank = gr.Slider(3, 8, value=5, step=1, label="Top-K Reranking")
                search_btn = gr.Button("π Search Papers", variant="primary")
            # Right column: static usage tips and a category cheat-sheet.
            with gr.Column(scale=1):
                gr.HTML("""
                <div style="background: #e3f2fd; padding: 1rem; border-radius: 8px;">
                <h4>π‘ Tips</h4>
                <ul>
                <li>Use specific technical terms</li>
                <li>Try different category filters</li>
                <li>Adjust semantic weight for different search styles</li>
                </ul>
                <h4>π Categories</h4>
                <ul>
                <li><code>cs.AI</code> - Artificial Intelligence</li>
                <li><code>cs.CL</code> - Computation and Language</li>
                <li><code>cs.LG</code> - Machine Learning</li>
                <li><code>cs.CV</code> - Computer Vision</li>
                </ul>
                </div>
                """)
        # Results: answer + stats, then papers as Markdown and as a table.
        answer_output = gr.Markdown(label="AI Answer & Statistics")
        with gr.Tabs():
            with gr.TabItem("π Papers"):
                papers_output = gr.Markdown(label="Relevant Papers")
            with gr.TabItem("π Papers Table"):
                papers_table = gr.Dataframe(label="Papers Summary")
        # Examples: each row lists values in the same order as `inputs`
        # (query, max papers, top-k retrieval, top-k rerank, categories, semantic weight).
        gr.Examples(
            examples=[
                ["transformer attention mechanisms", 15, 10, 5, "cs.CL, cs.AI", 0.7],
                ["graph neural networks", 12, 8, 4, "cs.LG", 0.6],
                ["computer vision deep learning", 15, 10, 5, "cs.CV", 0.8],
                ["reinforcement learning", 18, 10, 5, "cs.AI", 0.7]
            ],
            inputs=[query_input, max_papers, top_k_retrieval, top_k_rerank, categories_input, semantic_weight]
        )
        # Connect search function. `inputs` follows search_papers' positional
        # parameter order, and `outputs` receives its 3-tuple
        # (answer markdown, papers markdown, papers dataframe).
        search_btn.click(
            fn=search_papers,
            inputs=[query_input, max_papers, top_k_retrieval, top_k_rerank, categories_input, semantic_weight],
            outputs=[answer_output, papers_output, papers_table]
        )
        # Footer.
        gr.HTML("""
        <div style="text-align: center; margin-top: 2rem; padding: 1rem; background: #f5f5f5; border-radius: 8px;">
        <p><strong>Enhanced ArXiv RAG System</strong> | Semantic Search + BM25 + Neural Reranking</p>
        </div>
        """)
    return interface
# Entry point: build the UI and serve it.
if __name__ == "__main__":
    interface = create_interface()
    # No share=True here: it is left out for Hugging Face Spaces compatibility.
    interface.launch()