"""Gradio app: retrieval-augmented Q&A over arXiv abstracts.

Documents are retrieved either from a local ColBERT index (loaded through
RAGatouille) or from a live arXiv search, rendered as markdown, and fed as
context to a Gemma model served through the Google Generative AI API.
"""

import gradio as gr
import google.generativeai as genai
from google.generativeai.types import generation_types
from ragatouille import RAGPretrainedModel
import arxiv
import os
import re
from datetime import datetime
from utils import get_md_text_abstract
from huggingface_hub import snapshot_download

# --- Core Configuration ---
hf_token = os.getenv("HF_TOKEN")
gemini_api_key = os.getenv("GEMINI_API_KEY")
RAG_SOURCE = os.getenv("RAG_SOURCE")
LOCAL_DATA_DIR = './rag_index_data'
LLM_MODELS_TO_CHOOSE = [
    'google/gemma-3-4b-it',
    'google/gemma-3-12b-it',
    'google/gemma-3-27b-it',
    'None',
]
DEFAULT_LLM_MODEL = 'google/gemma-3-4b-it'
RETRIEVE_RESULTS = 20

# --- Gemini API Configuration ---
if gemini_api_key:
    genai.configure(api_key=gemini_api_key)
else:
    print("CRITICAL WARNING: GEMINI_API_KEY environment variable not set. The application will not function without it.")

GEMINI_GENERATION_CONFIG = genai.types.GenerationConfig(
    temperature=0.2,
    max_output_tokens=450,
    top_p=0.95,
)

# --- RAG & Data Source Setup ---
# Any failure here leaves RAG as None; the search handler then falls back to
# live arXiv search instead of preventing the app from starting.
try:
    gr.Info("Setting up the RAG retriever...")

    # If the local index directory doesn't exist, download it from Hugging Face.
    if not os.path.exists(LOCAL_DATA_DIR):
        if not RAG_SOURCE or not hf_token:
            raise ValueError("RAG index not found locally, and RAG_SOURCE or HF_TOKEN environment variables are not set. Cannot download index.")
        snapshot_download(
            repo_id=RAG_SOURCE,
            repo_type="dataset",
            token=hf_token,
            local_dir=LOCAL_DATA_DIR,
        )
        gr.Info("Index downloaded successfully.")
    else:
        gr.Info(f"Found existing local index at {LOCAL_DATA_DIR}.")

    # Load the RAG model from the (now existing) local index path.
    _index_path = os.path.join(LOCAL_DATA_DIR, "arxiv_colbert")
    gr.Info(f'''Loading index from {_index_path}...''')
    RAG = RAGPretrainedModel.from_index(_index_path)
    _ = RAG.search("Test query", k=1)  # Warm-up query
    gr.Info("Retriever loaded successfully!")
except Exception as e:
    gr.Warning(f"Could not initialize the RAG retriever. The app may not function correctly. Error: {e}")
    RAG = None

# --- UI Text and Metadata ---
MARKDOWN_SEARCH_RESULTS_HEADER = '# 🔍 Search Results\n'
APP_HEADER_TEXT = "# ArXiv CS RAG\n"
INDEX_INFO = "Semantic Search"
ARXIV_SEARCH_INFO = 'Arxiv Search - Latest - (EXPERIMENTAL)'

# Best effort: surface the index date from README.md when it is present and
# parseable; otherwise keep the generic labels defined above.
try:
    with open("README.md", "r", encoding="utf-8") as f:
        mdfile = f.read()
    date_match = re.search(r'Index Last Updated : (\d{4}-\d{2}-\d{2})', mdfile)
    if date_match:
        date = date_match.group(1)
        formatted_date = datetime.strptime(date, '%Y-%m-%d').strftime('%d %b %Y')
        APP_HEADER_TEXT += f'Index Last Updated: {formatted_date}\n'
        INDEX_INFO = f"Semantic Search - up to {formatted_date}"
except Exception:
    print("README.md not found or is invalid. Using default data source info.")

DATABASE_CHOICES = [INDEX_INFO, ARXIV_SEARCH_INFO]
ARX_CLIENT = arxiv.Client()


# --- Helper Functions ---
def get_prompt_text(question, context):
    """Formats the prompt for the Gemma 3 model.

    Args:
        question: The user's query text.
        context: Pre-formatted abstracts to ground the answer on.

    Returns:
        A single string: the fixed instruction, a blank line, then the
        abstracts followed by the question.
    """
    system_instruction = (
        "Based on the provided scientific paper abstracts, provide a comprehensive answer of 6-7 lines. "
        "Synthesize information from multiple sources if possible. Your answer must be grounded in the "
        "details found in the abstracts. Cite the titles of the papers you use as sources in your answer."
    )
    message = f"Abstracts:\n{context}\n\nQuestion: {question}"
    return f"{system_instruction}\n\n{message}"


def update_with_rag_md(message, llm_results_use, database_choice):
    """Fetches documents, updates the UI, and creates the final prompt for the LLM.

    Args:
        message: The user's search query.
        llm_results_use: How many of the top results are included in the
            LLM prompt context (all retrieved results are still rendered).
        database_choice: The selected entry of ``DATABASE_CHOICES``.

    Returns:
        A ``(markdown, prompt)`` tuple: the markdown listing of every
        retrieved paper and the prompt produced by ``get_prompt_text``.
    """
    rag_out = []
    source_used = database_choice
    try:
        if database_choice == INDEX_INFO and RAG:
            rag_out = RAG.search(message, k=RETRIEVE_RESULTS)
        else:
            # This branch is also taken when semantic search was requested
            # but the retriever failed to load, so label the results by the
            # backend that actually produced them.
            # NOTE(review): assumes get_md_text_abstract selects its
            # formatting from `source` - confirm against utils.
            source_used = ARXIV_SEARCH_INFO
            rag_out = list(ARX_CLIENT.results(arxiv.Search(
                query=message,
                max_results=RETRIEVE_RESULTS,
                sort_by=arxiv.SortCriterion.Relevance,
            )))
            if not rag_out:
                gr.Warning("Live Arxiv search returned no results. Falling back to semantic search.")
                if RAG:
                    rag_out = RAG.search(message, k=RETRIEVE_RESULTS)
                    source_used = INDEX_INFO
    except Exception as e:
        gr.Warning(f"An error occurred during search: {e}. Falling back to semantic search.")
        rag_out = []
        if RAG:
            # The fallback can fail as well (it may be the very call that
            # raised); degrade to an empty result list instead of crashing
            # the event handler.
            try:
                rag_out = RAG.search(message, k=RETRIEVE_RESULTS)
                source_used = INDEX_INFO
            except Exception as fallback_error:
                gr.Warning(f"Semantic search fallback also failed: {fallback_error}")
                rag_out = []

    md_sections = [MARKDOWN_SEARCH_RESULTS_HEADER]
    context_entries = []
    for i, rag_answer in enumerate(rag_out):
        md_text_paper, prompt_text = get_md_text_abstract(
            rag_answer, source=source_used, return_prompt_formatting=True
        )
        # Only the top-n results feed the LLM; every result is displayed.
        if i < llm_results_use:
            context_entries.append(f"{i+1}. {prompt_text}\n")
        md_sections.append(md_text_paper)

    md_text_updated = "".join(md_sections)
    final_prompt = get_prompt_text(message, "".join(context_entries))
    return md_text_updated, final_prompt


def ask_gemma_llm(prompt, llm_model_picked, stream_outputs):
    """Sends a prompt to the Google Gemini API and streams the response.

    This is a generator. When streaming, each yielded value is the full
    answer accumulated so far; otherwise the answer is yielded once.
    Failures are reported as yielded text rather than raised.

    Args:
        prompt: The full prompt text.
        llm_model_picked: An entry of ``LLM_MODELS_TO_CHOOSE``; the literal
            string ``'None'`` disables the LLM call.
        stream_outputs: Whether to request a streamed response.

    Yields:
        The answer text, or a human-readable status/error message.
    """
    if not prompt or not prompt.strip():
        yield "Error: The generated prompt is empty. Please try a different query."
        return
    if llm_model_picked == 'None':
        yield "LLM Model is disabled."
        return
    if not gemini_api_key:
        yield "Error: GEMINI_API_KEY is not configured. Cannot contact the LLM."
        return

    try:
        safety_settings = [
            {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"},
            {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_ONLY_HIGH"},
            {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_ONLY_HIGH"},
            {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_ONLY_HIGH"},
        ]
        # The API is addressed by the bare model name, without the
        # "google/" prefix used in the dropdown.
        gemini_model_name = llm_model_picked.split('/')[-1]
        model = genai.GenerativeModel(gemini_model_name)
        response = model.generate_content(
            prompt,
            generation_config=GEMINI_GENERATION_CONFIG,
            stream=stream_outputs,
            safety_settings=safety_settings,
        )

        if stream_outputs:
            output = ""
            for chunk in response:
                try:
                    text = chunk.parts[0].text
                except (IndexError, AttributeError, ValueError):
                    # Ignore empty chunks, which can occur at the end of a
                    # stream. ValueError is included because the `parts`
                    # accessor is documented to raise it when a chunk has
                    # no candidates.
                    continue
                output += text
                yield output
            if not output:
                yield "Model returned an empty or blocked stream. This may be due to the safety settings or the nature of the prompt."
        else:
            # Handle non-streaming responses.
            try:
                yield response.parts[0].text
            except (IndexError, AttributeError, ValueError):
                reason = "UNKNOWN"
                if response.prompt_feedback.block_reason:
                    reason = response.prompt_feedback.block_reason.name
                elif response.candidates and not response.candidates[0].content.parts:
                    reason = response.candidates[0].finish_reason.name
                yield f"Model returned an empty or blocked response. Reason: {reason}"
    except Exception as e:
        error_message = f"An error occurred with the Gemini API: {e}"
        print(error_message)  # Server side log
        gr.Warning("An error occurred with the Gemini API. Check the server logs for details.")
        yield error_message


# --- Gradio User Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(APP_HEADER_TEXT)
    with gr.Group():
        msg = gr.Textbox(label='Search', placeholder='e.g., What is Mixtral?')
        with gr.Accordion("Advanced Settings", open=False):
            llm_model = gr.Dropdown(choices=LLM_MODELS_TO_CHOOSE, value=DEFAULT_LLM_MODEL, label='LLM Model')
            llm_results = gr.Slider(5, 20, value=10, step=1, label="Top n results as context")
            database_src = gr.Dropdown(choices=DATABASE_CHOICES, value=INDEX_INFO, label='Search Source')
            stream_results = gr.Checkbox(value=True, label="Stream output")

    output_text = gr.Textbox(label='LLM Answer', placeholder="The model's answer will appear here...", interactive=False, lines=8)
    # Hidden field that carries the prompt from the retrieval step to the
    # LLM step of the event chain below.
    input_prompt = gr.Textbox(visible=False)
    gr_md = gr.Markdown(MARKDOWN_SEARCH_RESULTS_HEADER)

    # Retrieve and render the documents first, then query the LLM with the
    # prompt produced by the retrieval step.
    msg.submit(
        fn=update_with_rag_md,
        inputs=[msg, llm_results, database_src],
        outputs=[gr_md, input_prompt],
    ).then(
        fn=ask_gemma_llm,
        inputs=[input_prompt, llm_model, stream_results],
        outputs=[output_text],
    )

if __name__ == "__main__":
    # Launch the app
    demo.queue().launch()