# app.py
"""REST API exposing a Retrieval-Augmented Generation (RAG) pipeline.

The RAG components (retriever, chain, application) are built once at server
startup and stored in module-level globals that the endpoints read. They can
be rebuilt at runtime from a Kaggle dataset via ``POST /load-kaggle-dataset``.
"""
import logging
import os
from typing import Optional, Tuple

import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from src.RAGSample import setup_retriever, setup_rag_chain, RAGApplication

# Load environment variables
load_dotenv()

# Configure logging up front so startup and request handlers can use it.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Create FastAPI app
app = FastAPI(
    title="RAG API",
    description="A REST API for Retrieval-Augmented Generation using local vector database",
    version="1.0.0"
)

# Initialize RAG components (this will be done once when the server starts)
retriever = None
rag_chain = None
rag_application = None


# Pydantic model for request
class QuestionRequest(BaseModel):
    """Request body for the question endpoint."""

    question: str


# Pydantic model for response
class QuestionResponse(BaseModel):
    """Response body for the question endpoint: the question and its answer."""

    question: str
    answer: str


def _resolve_kaggle_credentials() -> Tuple[Optional[str], Optional[str]]:
    """Return ``(username, key)`` from the environment, else from kaggle.json.

    Environment variables ``KAGGLE_USERNAME`` / ``KAGGLE_KEY`` win when both
    are set. Otherwise a ``KaggleDataLoader`` is created (which auto-loads
    kaggle.json) and its credentials are used if both are present. Any
    failure while doing so is logged and the environment values (possibly
    ``None``) are returned unchanged.
    """
    username = os.getenv("KAGGLE_USERNAME")
    key = os.getenv("KAGGLE_KEY")
    if username and key:
        return username, key

    try:
        # Imported lazily so the app still starts when the loader (or its
        # dependencies) is unavailable.
        from src.kaggle_loader import KaggleDataLoader

        # Creating a loader without arguments auto-loads from kaggle.json.
        loader = KaggleDataLoader()
        if loader.kaggle_username and loader.kaggle_key:
            username = loader.kaggle_username
            key = loader.kaggle_key
            logger.info("Loaded Kaggle credentials from kaggle.json: %s", username)
    except Exception as e:
        # Best effort: fall back to the default local data source.
        logger.warning("Could not load Kaggle credentials from kaggle.json: %s", e)

    return username, key


@app.on_event("startup")
async def startup_event():
    """Initialize RAG components when the server starts.

    Uses the Kaggle dataset named by ``KAGGLE_DATASET`` when credentials are
    available; otherwise falls back to ``setup_retriever()`` with its default
    data source. Any initialization error is logged and re-raised so the
    server does not start in a half-initialized state.
    """
    global retriever, rag_chain, rag_application

    try:
        logger.info("Initializing RAG components...")

        kaggle_username, kaggle_key = _resolve_kaggle_credentials()
        kaggle_dataset = os.getenv("KAGGLE_DATASET")

        if kaggle_username and kaggle_key and kaggle_dataset:
            logger.info("Loading Kaggle dataset: %s", kaggle_dataset)
            retriever = setup_retriever(
                use_kaggle_data=True,
                kaggle_dataset=kaggle_dataset,
                kaggle_username=kaggle_username,
                kaggle_key=kaggle_key
            )
        else:
            logger.info("Loading mental health FAQ data from local file...")
            # Load mental health FAQ data from local file (default behavior)
            retriever = setup_retriever()

        rag_chain = setup_rag_chain()
        rag_application = RAGApplication(retriever, rag_chain)
        logger.info("RAG components initialized successfully!")
    except Exception:
        logger.exception("Error initializing RAG components")
        raise


@app.get("/")
async def root():
    """Root endpoint with API information."""
    return {
        "message": "RAG API is running",
        "endpoints": {
            # Must match the route registered on ask_question below.
            "ask_question": "/medical/ask",
            "health_check": "/health",
            "load_kaggle_dataset": "/load-kaggle-dataset"
        }
    }


@app.get("/health")
async def health_check():
    """Health check endpoint; also reports whether RAG has been initialized."""
    return {
        "status": "healthy",
        "rag_initialized": rag_application is not None
    }


@app.post("/medical/ask", response_model=QuestionResponse)
async def ask_question(request: QuestionRequest):
    """Ask a question and get an answer using RAG.

    Raises:
        HTTPException: 500 when the RAG application is not initialized or
            when answering the question fails.
    """
    if rag_application is None:
        raise HTTPException(status_code=500, detail="RAG application not initialized")

    try:
        logger.info("Processing question: %s", request.question)

        # Debug aid: record which retriever implementation is active.
        logger.debug(
            "Using retriever type: %s",
            type(rag_application.retriever).__name__,
        )

        answer = rag_application.run(request.question)
    except HTTPException:
        # Do not re-wrap errors that already carry an HTTP status.
        raise
    except Exception as e:
        logger.exception("Error processing question")
        raise HTTPException(
            status_code=500,
            detail=f"Error processing question: {str(e)}"
        ) from e

    return QuestionResponse(question=request.question, answer=answer)


@app.post("/load-kaggle-dataset")
async def load_kaggle_dataset(dataset_name: str):
    """Load a Kaggle dataset for RAG and swap it in as the active data source.

    Args:
        dataset_name: Kaggle dataset identifier, passed as a query parameter.

    Returns:
        A dict with ``status`` ``"success"`` (plus ``message`` and
        ``dataset_path``) or ``"error"`` (plus ``message``).
        NOTE: failures are reported in the body with HTTP 200, which existing
        clients may rely on; this is preserved deliberately.
    """
    global retriever, rag_chain, rag_application

    try:
        from src.kaggle_loader import KaggleDataLoader

        # Create loader without parameters - it will auto-load from kaggle.json
        loader = KaggleDataLoader()

        # Download the dataset
        dataset_path = loader.download_dataset(dataset_name)

        # Build the replacement components in locals first so a failure
        # part-way through leaves the currently active ones untouched.
        new_retriever = setup_retriever(use_kaggle_data=True, kaggle_dataset=dataset_name)
        new_rag_chain = setup_rag_chain()
        new_rag_application = RAGApplication(new_retriever, new_rag_chain)

        # Swap all three together so the module globals stay consistent.
        retriever = new_retriever
        rag_chain = new_rag_chain
        rag_application = new_rag_application

        return {
            "status": "success",
            "message": f"Dataset {dataset_name} loaded successfully",
            "dataset_path": dataset_path
        }
    except Exception as e:
        logger.exception("Error loading Kaggle dataset %s", dataset_name)
        return {"status": "error", "message": str(e)}


@app.get("/models")
async def get_models():
    """Get information about available models (static, hard-coded values)."""
    return {
        "llm_model": "dolphin-llama3:8b",
        "embedding_model": "TF-IDF embeddings",
        "vector_database": "ChromaDB (local)"
    }


if __name__ == "__main__":
    try:
        logger.info("Starting application...")

        # Add any initialization code here with try/except blocks

        port = int(os.getenv("PORT", 7860))
        logger.info("Starting server on port %s", port)

        uvicorn.run(app, host="0.0.0.0", port=port, log_level="info")
    except Exception:
        logger.exception("Failed to start application")
        raise