Update app.py
Browse files
app.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
"""
|
2 |
-
Enhanced ArXiv RAG System - Hugging Face Spaces Compatible Version
|
3 |
"""
|
4 |
|
5 |
import os
|
@@ -13,7 +13,6 @@ from datetime import datetime, timedelta
|
|
13 |
import logging
|
14 |
import tempfile
|
15 |
import shutil
|
16 |
-
import spaces
|
17 |
|
18 |
# Core ML libraries
|
19 |
import torch
|
@@ -31,14 +30,14 @@ from nltk.stem import PorterStemmer
|
|
31 |
|
32 |
# Download required NLTK data
|
33 |
try:
|
34 |
-
nltk.data.find(
|
35 |
except LookupError:
|
36 |
-
nltk.download(
|
37 |
|
38 |
try:
|
39 |
-
nltk.data.find(
|
40 |
except LookupError:
|
41 |
-
nltk.download(
|
42 |
|
43 |
# Setup logging
|
44 |
logging.basicConfig(level=logging.INFO)
|
@@ -75,7 +74,7 @@ class BM25Retriever:
|
|
75 |
self.avg_doc_length = 0
|
76 |
self.stemmer = PorterStemmer()
|
77 |
try:
|
78 |
-
self.stop_words = set(stopwords.words(
|
79 |
except:
|
80 |
self.stop_words = set()
|
81 |
|
@@ -151,7 +150,7 @@ class SimpleVectorStore:
|
|
151 |
def query(self, query_embedding: List[float], n_results: int = 10) -> Dict:
|
152 |
"""Query the vector store"""
|
153 |
if not self.embeddings:
|
154 |
-
return {
|
155 |
|
156 |
# Calculate cosine similarities
|
157 |
query_embedding = np.array(query_embedding)
|
@@ -168,25 +167,25 @@ class SimpleVectorStore:
|
|
168 |
top_indices = np.argsort(similarities)[::-1][:n_results]
|
169 |
|
170 |
return {
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
}
|
175 |
|
176 |
def get(self, ids: Optional[List[str]] = None) -> Dict:
|
177 |
"""Get documents by IDs or all documents"""
|
178 |
if ids is None:
|
179 |
return {
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
}
|
184 |
else:
|
185 |
indices = [self.ids.index(id_) for id_ in ids if id_ in self.ids]
|
186 |
return {
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
}
|
191 |
|
192 |
def clear(self):
|
@@ -202,16 +201,29 @@ class EnhancedArxivRAG:
|
|
202 |
def __init__(self):
|
203 |
logger.info("Initializing Enhanced ArXiv RAG System for HF Spaces...")
|
204 |
|
205 |
-
#
|
206 |
-
self.
|
207 |
-
|
208 |
-
|
209 |
-
#
|
210 |
-
|
211 |
-
"
|
212 |
-
|
213 |
-
|
214 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
215 |
|
216 |
# Use simple vector store instead of ChromaDB for HF Spaces
|
217 |
self.vector_store = SimpleVectorStore()
|
@@ -244,9 +256,9 @@ class EnhancedArxivRAG:
|
|
244 |
papers = []
|
245 |
for result in search.results():
|
246 |
paper = Paper(
|
247 |
-
id=result.entry_id.split(
|
248 |
-
title=result.title.strip().replace(
|
249 |
-
abstract=result.summary.strip().replace(
|
250 |
authors=[author.name for author in result.authors],
|
251 |
categories=result.categories,
|
252 |
published=result.published.replace(tzinfo=None),
|
@@ -318,7 +330,6 @@ class EnhancedArxivRAG:
|
|
318 |
|
319 |
return chunks
|
320 |
|
321 |
-
@spaces.GPU(duration=60) # GPU decorator for processing
|
322 |
def process_and_store(self, papers: List[Paper]):
|
323 |
"""Process papers and store in vector store"""
|
324 |
logger.info("Processing and storing papers...")
|
@@ -372,12 +383,12 @@ class EnhancedArxivRAG:
|
|
372 |
bm25_scores = self.bm25_retriever.score(query, top_k * 2)
|
373 |
|
374 |
for idx, score in bm25_scores:
|
375 |
-
if idx < len(all_docs[
|
376 |
bm25_results.append({
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
})
|
382 |
|
383 |
# Combine results using RRF
|
@@ -385,13 +396,13 @@ class EnhancedArxivRAG:
|
|
385 |
bm25_weight = 1.0 - semantic_weight
|
386 |
|
387 |
# Add semantic scores
|
388 |
-
for i, doc_id in enumerate(semantic_results[
|
389 |
rank = i + 1
|
390 |
combined_scores[doc_id] = combined_scores.get(doc_id, 0) + semantic_weight / rank
|
391 |
|
392 |
# Add BM25 scores
|
393 |
for i, result in enumerate(bm25_results):
|
394 |
-
doc_id = result[
|
395 |
rank = i + 1
|
396 |
combined_scores[doc_id] = combined_scores.get(doc_id, 0) + bm25_weight / rank
|
397 |
|
@@ -402,34 +413,33 @@ class EnhancedArxivRAG:
|
|
402 |
final_results = []
|
403 |
for doc_id, score in sorted_results[:top_k]:
|
404 |
doc_result = self.vector_store.get(ids=[doc_id])
|
405 |
-
if doc_result[
|
406 |
final_results.append({
|
407 |
-
|
408 |
-
|
409 |
-
|
410 |
-
|
411 |
})
|
412 |
|
413 |
return final_results
|
414 |
|
415 |
-
@spaces.GPU(duration=30) # GPU decorator for reranking
|
416 |
def rerank_results(self, query: str, results: List[Dict], top_k: int = 5) -> List[Dict]:
|
417 |
"""Rerank results using cross-encoder"""
|
418 |
if not results:
|
419 |
return results
|
420 |
|
421 |
# Prepare query-document pairs
|
422 |
-
query_doc_pairs = [(query, result[
|
423 |
|
424 |
# Get reranking scores
|
425 |
rerank_scores = self.reranker.predict(query_doc_pairs)
|
426 |
|
427 |
# Add rerank scores to results
|
428 |
for i, result in enumerate(results):
|
429 |
-
result[
|
430 |
|
431 |
# Sort by rerank score
|
432 |
-
reranked_results = sorted(results, key=lambda x: x[
|
433 |
return reranked_results[:top_k]
|
434 |
|
435 |
def generate_answer(self, query: str, context_chunks: List[Dict]) -> str:
|
@@ -438,7 +448,7 @@ class EnhancedArxivRAG:
|
|
438 |
return "No relevant information found to answer your query."
|
439 |
|
440 |
# Combine context from top chunks
|
441 |
-
context_texts = [chunk[
|
442 |
combined_context = "\n\n".join(context_texts)
|
443 |
|
444 |
# Limit context length
|
@@ -451,14 +461,13 @@ class EnhancedArxivRAG:
|
|
451 |
summary = self.summarizer(summary_input,
|
452 |
max_length=120,
|
453 |
min_length=30,
|
454 |
-
do_sample=False)[0][
|
455 |
return summary
|
456 |
except Exception as e:
|
457 |
logger.error(f"Error generating summary: {e}")
|
458 |
-
return f"Based on the retrieved papers about '{query}', here are the key findings:\n\n" + \
|
459 |
-
"\n\n".join([chunk[
|
460 |
|
461 |
-
@spaces.GPU(duration=120) # Main GPU decorator for the full pipeline
|
462 |
def search_and_answer(self, query: str, max_papers: int = 15,
|
463 |
top_k_retrieval: int = 10, top_k_rerank: int = 5,
|
464 |
categories: Optional[List[str]] = None,
|
@@ -467,10 +476,10 @@ class EnhancedArxivRAG:
|
|
467 |
|
468 |
if not query.strip():
|
469 |
return {
|
470 |
-
|
471 |
-
|
472 |
-
|
473 |
-
|
474 |
}
|
475 |
|
476 |
try:
|
@@ -479,10 +488,10 @@ class EnhancedArxivRAG:
|
|
479 |
|
480 |
if not papers:
|
481 |
return {
|
482 |
-
|
483 |
-
|
484 |
-
|
485 |
-
|
486 |
}
|
487 |
|
488 |
# Process and store papers
|
@@ -500,36 +509,36 @@ class EnhancedArxivRAG:
|
|
500 |
# Prepare unique papers
|
501 |
unique_papers = {}
|
502 |
for chunk in reranked_results:
|
503 |
-
paper_id = chunk[
|
504 |
if paper_id in self.papers_cache and paper_id not in unique_papers:
|
505 |
paper = self.papers_cache[paper_id]
|
506 |
unique_papers[paper_id] = {
|
507 |
-
|
508 |
-
|
509 |
-
|
510 |
-
|
511 |
-
|
512 |
-
|
513 |
}
|
514 |
|
515 |
return {
|
516 |
-
|
517 |
-
|
518 |
-
|
519 |
-
|
520 |
-
|
521 |
-
|
522 |
-
|
523 |
}
|
524 |
}
|
525 |
|
526 |
except Exception as e:
|
527 |
logger.error(f"Error in search_and_answer: {e}")
|
528 |
return {
|
529 |
-
|
530 |
-
|
531 |
-
|
532 |
-
|
533 |
}
|
534 |
|
535 |
# Global RAG instance
|
@@ -548,7 +557,7 @@ def search_papers(query: str, max_papers: int = 15, top_k_retrieval: int = 10,
|
|
548 |
"""Main search function for Gradio interface"""
|
549 |
|
550 |
if not query.strip():
|
551 |
-
return "❌ Please enter a research topic or question.", "",
|
552 |
|
553 |
try:
|
554 |
# Initialize RAG system
|
@@ -557,7 +566,7 @@ def search_papers(query: str, max_papers: int = 15, top_k_retrieval: int = 10,
|
|
557 |
# Parse categories
|
558 |
category_list = None
|
559 |
if categories.strip():
|
560 |
-
category_list = [cat.strip() for cat in categories.split(
|
561 |
|
562 |
# Perform search
|
563 |
result = rag.search_and_answer(
|
@@ -570,33 +579,33 @@ def search_papers(query: str, max_papers: int = 15, top_k_retrieval: int = 10,
|
|
570 |
)
|
571 |
|
572 |
# Format answer
|
573 |
-
answer = f"## 🤖 AI-Generated Answer\n\n{result[
|
574 |
answer += f"**Search Statistics:**\n"
|
575 |
-
answer += f"- Papers found: {result[
|
576 |
-
answer += f"- Chunks retrieved: {result[
|
577 |
-
answer += f"- Unique papers in results: {result[
|
578 |
|
579 |
# Format papers
|
580 |
papers_md = "## 📚 Relevant Papers\n\n"
|
581 |
-
for i, paper in enumerate(result[
|
582 |
-
papers_md += f"### {i}. {paper[
|
583 |
-
papers_md += f"**Authors:** {
|
584 |
-
papers_md += f"**Categories:** {
|
585 |
-
papers_md += f"**Published:** {paper[
|
586 |
-
papers_md += f"**Abstract:** {paper[
|
587 |
-
papers_md += f"**URL:** [{paper[
|
588 |
papers_md += "---\n\n"
|
589 |
|
590 |
# Create papers dataframe
|
591 |
papers_df = pd.DataFrame([
|
592 |
{
|
593 |
-
|
594 |
-
|
595 |
-
|
596 |
-
|
597 |
-
|
598 |
}
|
599 |
-
for paper in result[
|
600 |
])
|
601 |
|
602 |
return answer, papers_md, papers_df
|
@@ -709,4 +718,6 @@ def create_interface():
|
|
709 |
# Launch interface
|
710 |
if __name__ == "__main__":
|
711 |
interface = create_interface()
|
712 |
-
|
|
|
|
|
|
1 |
"""
|
2 |
+
Enhanced ArXiv RAG System - Hugging Face Spaces Compatible Version
|
3 |
"""
|
4 |
|
5 |
import os
|
|
|
13 |
import logging
|
14 |
import tempfile
|
15 |
import shutil
|
|
|
16 |
|
17 |
# Core ML libraries
|
18 |
import torch
|
|
|
30 |
|
31 |
# Download required NLTK data
|
32 |
try:
|
33 |
+
nltk.data.find("tokenizers/punkt")
|
34 |
except LookupError:
|
35 |
+
nltk.download("punkt")
|
36 |
|
37 |
try:
|
38 |
+
nltk.data.find("corpora/stopwords")
|
39 |
except LookupError:
|
40 |
+
nltk.download("stopwords")
|
41 |
|
42 |
# Setup logging
|
43 |
logging.basicConfig(level=logging.INFO)
|
|
|
74 |
self.avg_doc_length = 0
|
75 |
self.stemmer = PorterStemmer()
|
76 |
try:
|
77 |
+
self.stop_words = set(stopwords.words("english"))
|
78 |
except:
|
79 |
self.stop_words = set()
|
80 |
|
|
|
150 |
def query(self, query_embedding: List[float], n_results: int = 10) -> Dict:
|
151 |
"""Query the vector store"""
|
152 |
if not self.embeddings:
|
153 |
+
return {"ids": [[]], "documents": [[]], "metadatas": [[]]}
|
154 |
|
155 |
# Calculate cosine similarities
|
156 |
query_embedding = np.array(query_embedding)
|
|
|
167 |
top_indices = np.argsort(similarities)[::-1][:n_results]
|
168 |
|
169 |
return {
|
170 |
+
"ids": [[self.ids[i] for i in top_indices]],
|
171 |
+
"documents": [[self.documents[i] for i in top_indices]],
|
172 |
+
"metadatas": [[self.metadatas[i] for i in top_indices]]
|
173 |
}
|
174 |
|
175 |
def get(self, ids: Optional[List[str]] = None) -> Dict:
    """Get documents by IDs or all documents.

    Args:
        ids: IDs to look up. ``None`` (the default) selects every stored
            entry.

    Returns:
        A dict with the parallel lists ``"ids"``, ``"documents"`` and
        ``"metadatas"``. With ``ids=None`` these are the store's own
        lists (not copies). Otherwise they hold one entry per requested
        ID that exists in the store, in request order; unknown IDs are
        skipped silently.
    """
    if ids is None:
        return {
            "ids": self.ids,
            "documents": self.documents,
            "metadatas": self.metadatas,
        }

    # Build the id -> position map once instead of scanning ``self.ids``
    # twice (``in`` followed by ``.index``) for every requested id.
    # ``setdefault`` keeps the first position of a duplicated id, which is
    # what ``list.index`` returned.
    position_of: Dict[str, int] = {}
    for position, stored_id in enumerate(self.ids):
        position_of.setdefault(stored_id, position)

    indices = [position_of[id_] for id_ in ids if id_ in position_of]
    return {
        "ids": [self.ids[i] for i in indices],
        "documents": [self.documents[i] for i in indices],
        "metadatas": [self.metadatas[i] for i in indices],
    }
|
190 |
|
191 |
def clear(self):
|
|
|
201 |
def __init__(self):
|
202 |
logger.info("Initializing Enhanced ArXiv RAG System for HF Spaces...")
|
203 |
|
204 |
+
# Determine device (GPU if available, else CPU)
|
205 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
206 |
+
logger.info(f"Using device: {self.device}")
|
207 |
+
|
208 |
+
# Load models with appropriate device settings
|
209 |
+
try:
|
210 |
+
logger.info("Loading embedding model...")
|
211 |
+
self.embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device=self.device)
|
212 |
+
logger.info("Embedding model loaded.")
|
213 |
+
|
214 |
+
logger.info("Loading reranker model...")
|
215 |
+
self.reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-2-v2", device=self.device)
|
216 |
+
logger.info("Reranker model loaded.")
|
217 |
+
|
218 |
+
logger.info("Loading summarizer model...")
|
219 |
+
# For pipeline, device_map="auto" is often better for ZeroGPU
|
220 |
+
# If issues persist, try device=0 for the first GPU, or device=self.device
|
221 |
+
self.summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-6-6", device_map="auto")
|
222 |
+
logger.info("Summarizer model loaded.")
|
223 |
+
|
224 |
+
except Exception as e:
|
225 |
+
logger.error(f"Error loading models: {e}")
|
226 |
+
raise
|
227 |
|
228 |
# Use simple vector store instead of ChromaDB for HF Spaces
|
229 |
self.vector_store = SimpleVectorStore()
|
|
|
256 |
papers = []
|
257 |
for result in search.results():
|
258 |
paper = Paper(
|
259 |
+
id=result.entry_id.split("/")[-1],
|
260 |
+
title=result.title.strip().replace("\n", " "),
|
261 |
+
abstract=result.summary.strip().replace("\n", " "),
|
262 |
authors=[author.name for author in result.authors],
|
263 |
categories=result.categories,
|
264 |
published=result.published.replace(tzinfo=None),
|
|
|
330 |
|
331 |
return chunks
|
332 |
|
|
|
333 |
def process_and_store(self, papers: List[Paper]):
|
334 |
"""Process papers and store in vector store"""
|
335 |
logger.info("Processing and storing papers...")
|
|
|
383 |
bm25_scores = self.bm25_retriever.score(query, top_k * 2)
|
384 |
|
385 |
for idx, score in bm25_scores:
|
386 |
+
if idx < len(all_docs["ids"]):
|
387 |
bm25_results.append({
|
388 |
+
"id": all_docs["ids"][idx],
|
389 |
+
"document": all_docs["documents"][idx],
|
390 |
+
"metadata": all_docs["metadatas"][idx],
|
391 |
+
"score": score
|
392 |
})
|
393 |
|
394 |
# Combine results using RRF
|
|
|
396 |
bm25_weight = 1.0 - semantic_weight
|
397 |
|
398 |
# Add semantic scores
|
399 |
+
for i, doc_id in enumerate(semantic_results["ids"][0]):
|
400 |
rank = i + 1
|
401 |
combined_scores[doc_id] = combined_scores.get(doc_id, 0) + semantic_weight / rank
|
402 |
|
403 |
# Add BM25 scores
|
404 |
for i, result in enumerate(bm25_results):
|
405 |
+
doc_id = result["id"]
|
406 |
rank = i + 1
|
407 |
combined_scores[doc_id] = combined_scores.get(doc_id, 0) + bm25_weight / rank
|
408 |
|
|
|
413 |
final_results = []
|
414 |
for doc_id, score in sorted_results[:top_k]:
|
415 |
doc_result = self.vector_store.get(ids=[doc_id])
|
416 |
+
if doc_result["ids"]:
|
417 |
final_results.append({
|
418 |
+
"id": doc_id,
|
419 |
+
"document": doc_result["documents"][0],
|
420 |
+
"metadata": doc_result["metadatas"][0],
|
421 |
+
"combined_score": score
|
422 |
})
|
423 |
|
424 |
return final_results
|
425 |
|
|
|
426 |
def rerank_results(self, query: str, results: List[Dict], top_k: int = 5) -> List[Dict]:
    """Rerank results using cross-encoder.

    Args:
        query: The search query each candidate is scored against.
        results: Candidate dicts; each must carry its text under
            ``"document"``. Every dict is updated in place with a
            ``"rerank_score"`` float.
        top_k: Maximum number of results to return. Values below zero are
            treated as zero rather than slicing from the end of the list.

    Returns:
        The candidates sorted by descending ``"rerank_score"``, cut to
        ``top_k`` entries. An empty ``results`` is returned unchanged
        without calling the reranker.
    """
    if not results:
        return results

    # Prepare query-document pairs
    query_doc_pairs = [(query, result["document"]) for result in results]

    # Get reranking scores
    rerank_scores = self.reranker.predict(query_doc_pairs)

    # Add rerank scores to results
    for result, score in zip(results, rerank_scores):
        result["rerank_score"] = float(score)

    # Sort by rerank score
    reranked_results = sorted(results, key=lambda x: x["rerank_score"], reverse=True)
    return reranked_results[:max(top_k, 0)]
|
444 |
|
445 |
def generate_answer(self, query: str, context_chunks: List[Dict]) -> str:
|
|
|
448 |
return "No relevant information found to answer your query."
|
449 |
|
450 |
# Combine context from top chunks
|
451 |
+
context_texts = [chunk["document"] for chunk in context_chunks[:3]]
|
452 |
combined_context = "\n\n".join(context_texts)
|
453 |
|
454 |
# Limit context length
|
|
|
461 |
summary = self.summarizer(summary_input,
|
462 |
max_length=120,
|
463 |
min_length=30,
|
464 |
+
do_sample=False)[0]["summary_text"]
|
465 |
return summary
|
466 |
except Exception as e:
|
467 |
logger.error(f"Error generating summary: {e}")
|
468 |
+
return f"Based on the retrieved papers about \'{query}\', here are the key findings:\n\n" + \
|
469 |
+
"\n\n".join([chunk["document"][:150] + "..." for chunk in context_chunks[:2]])
|
470 |
|
|
|
471 |
def search_and_answer(self, query: str, max_papers: int = 15,
|
472 |
top_k_retrieval: int = 10, top_k_rerank: int = 5,
|
473 |
categories: Optional[List[str]] = None,
|
|
|
476 |
|
477 |
if not query.strip():
|
478 |
return {
|
479 |
+
"answer": "Please enter a valid research query.",
|
480 |
+
"papers": [],
|
481 |
+
"retrieved_chunks": [],
|
482 |
+
"search_stats": {"papers_found": 0, "chunks_retrieved": 0}
|
483 |
}
|
484 |
|
485 |
try:
|
|
|
488 |
|
489 |
if not papers:
|
490 |
return {
|
491 |
+
"answer": "No papers found for your query. Please try different keywords.",
|
492 |
+
"papers": [],
|
493 |
+
"retrieved_chunks": [],
|
494 |
+
"search_stats": {"papers_found": 0, "chunks_retrieved": 0}
|
495 |
}
|
496 |
|
497 |
# Process and store papers
|
|
|
509 |
# Prepare unique papers
|
510 |
unique_papers = {}
|
511 |
for chunk in reranked_results:
|
512 |
+
paper_id = chunk["id"].split("_")[0]
|
513 |
if paper_id in self.papers_cache and paper_id not in unique_papers:
|
514 |
paper = self.papers_cache[paper_id]
|
515 |
unique_papers[paper_id] = {
|
516 |
+
"title": paper.title,
|
517 |
+
"authors": paper.authors,
|
518 |
+
"abstract": paper.abstract,
|
519 |
+
"url": paper.url,
|
520 |
+
"categories": paper.categories,
|
521 |
+
"published": paper.published.strftime("%Y-%m-%d")
|
522 |
}
|
523 |
|
524 |
return {
|
525 |
+
"answer": answer,
|
526 |
+
"papers": list(unique_papers.values()),
|
527 |
+
"retrieved_chunks": reranked_results,
|
528 |
+
"search_stats": {
|
529 |
+
"papers_found": len(papers),
|
530 |
+
"chunks_retrieved": len(reranked_results),
|
531 |
+
"unique_papers_in_results": len(unique_papers)
|
532 |
}
|
533 |
}
|
534 |
|
535 |
except Exception as e:
|
536 |
logger.error(f"Error in search_and_answer: {e}")
|
537 |
return {
|
538 |
+
"answer": f"An error occurred while processing your query: {str(e)}",
|
539 |
+
"papers": [],
|
540 |
+
"retrieved_chunks": [],
|
541 |
+
"search_stats": {"papers_found": 0, "chunks_retrieved": 0}
|
542 |
}
|
543 |
|
544 |
# Global RAG instance
|
|
|
557 |
"""Main search function for Gradio interface"""
|
558 |
|
559 |
if not query.strip():
|
560 |
+
return "❌ Please enter a research topic or question.", "", pd.DataFrame()
|
561 |
|
562 |
try:
|
563 |
# Initialize RAG system
|
|
|
566 |
# Parse categories
|
567 |
category_list = None
|
568 |
if categories.strip():
|
569 |
+
category_list = [cat.strip() for cat in categories.split(",") if cat.strip()]
|
570 |
|
571 |
# Perform search
|
572 |
result = rag.search_and_answer(
|
|
|
579 |
)
|
580 |
|
581 |
# Format answer
|
582 |
+
answer = f"## 🤖 AI-Generated Answer\n\n{result["answer"]}\n\n"
|
583 |
answer += f"**Search Statistics:**\n"
|
584 |
+
answer += f"- Papers found: {result["search_stats"]["papers_found"]}\n"
|
585 |
+
answer += f"- Chunks retrieved: {result["search_stats"]["chunks_retrieved"]}\n"
|
586 |
+
answer += f"- Unique papers in results: {result["search_stats"]["unique_papers_in_results"]}\n\n"
|
587 |
|
588 |
# Format papers
|
589 |
papers_md = "## 📚 Relevant Papers\n\n"
|
590 |
+
for i, paper in enumerate(result["papers"], 1):
|
591 |
+
papers_md += f"### {i}. {paper["title"]}\n\n"
|
592 |
+
papers_md += f"**Authors:** {", ".join(paper["authors"][:3])}{"..." if len(paper["authors"]) > 3 else ""}\n\n"
|
593 |
+
papers_md += f"**Categories:** {", ".join(paper["categories"])}\n\n"
|
594 |
+
papers_md += f"**Published:** {paper["published"]}\n\n"
|
595 |
+
papers_md += f"**Abstract:** {paper["abstract"][:250]}{"..." if len(paper["abstract"]) > 250 else ""}\n\n"
|
596 |
+
papers_md += f"**URL:** [{paper["url"]}]({paper["url"]})\n\n"
|
597 |
papers_md += "---\n\n"
|
598 |
|
599 |
# Create papers dataframe
|
600 |
papers_df = pd.DataFrame([
|
601 |
{
|
602 |
+
"Title": paper["title"][:50] + "..." if len(paper["title"]) > 50 else paper["title"],
|
603 |
+
"Authors": ", ".join(paper["authors"][:2]) + ("..." if len(paper["authors"]) > 2 else ""),
|
604 |
+
"Categories": ", ".join(paper["categories"][:2]),
|
605 |
+
"Published": paper["published"],
|
606 |
+
"URL": paper["url"]
|
607 |
}
|
608 |
+
for paper in result["papers"]
|
609 |
])
|
610 |
|
611 |
return answer, papers_md, papers_df
|
|
|
718 |
# Launch interface
|
719 |
if __name__ == "__main__":
|
720 |
interface = create_interface()
|
721 |
+
# share=True is intentionally omitted for Hugging Face Spaces compatibility
|
722 |
+
interface.launch()
|
723 |
+
|