Spaces:
Sleeping
Sleeping
# app.py | |
from fastapi import FastAPI, HTTPException | |
from pydantic import BaseModel | |
from typing import Optional | |
import uvicorn | |
import logging | |
from src.RAGSample import setup_retriever, setup_rag_chain, RAGApplication | |
import os | |
from dotenv import load_dotenv | |
# Pull settings (Kaggle credentials, PORT, ...) from a local .env file, if one
# exists, before any code below reads them via os.getenv.
load_dotenv()

# The ASGI application object that uvicorn serves.
app = FastAPI(
    title="RAG API",
    description="A REST API for Retrieval-Augmented Generation using local vector database",
    version="1.0.0",
)

# RAG pipeline state. Everything starts out empty; startup_event fills these
# in, and request handlers look at `rag_application` to decide whether the
# pipeline is ready to answer.
retriever = rag_chain = rag_application = None
class QuestionRequest(BaseModel):
    """JSON body accepted by the question-answering endpoint."""

    # Natural-language question to run through the RAG pipeline.
    question: str
class QuestionResponse(BaseModel):
    """JSON body returned by the question-answering endpoint."""

    # The question exactly as the client sent it.
    question: str
    # Text produced by the RAG application for that question.
    answer: str
@app.on_event("startup")
async def startup_event():
    """Build the RAG pipeline once, when the server starts.

    Data source:
      * Kaggle, when ``KAGGLE_DATASET`` is set and credentials are available,
        either from ``KAGGLE_USERNAME``/``KAGGLE_KEY`` or, if those are
        missing, from a no-argument ``KaggleDataLoader()`` (which loads
        kaggle.json).
      * The local mental health FAQ file, in every other case.

    Populates the module-level ``retriever``, ``rag_chain`` and
    ``rag_application``.

    Raises:
        Exception: whatever building the pipeline raises. It is logged and
            re-raised so the failure surfaces at startup instead of leaving
            the API running without a pipeline.
    """
    global retriever, rag_chain, rag_application
    try:
        logger.info("Initializing RAG components...")
        kaggle_username = os.getenv("KAGGLE_USERNAME")
        kaggle_key = os.getenv("KAGGLE_KEY")
        kaggle_dataset = os.getenv("KAGGLE_DATASET")

        # Credentials only matter when a dataset is configured, so don't
        # import the loader or touch kaggle.json otherwise.
        if kaggle_dataset and not (kaggle_username and kaggle_key):
            try:
                from src.kaggle_loader import KaggleDataLoader

                # A no-argument loader picks its credentials up from kaggle.json.
                probe = KaggleDataLoader()
                if probe.kaggle_username and probe.kaggle_key:
                    kaggle_username = probe.kaggle_username
                    kaggle_key = probe.kaggle_key
                    logger.info("Loaded Kaggle credentials from kaggle.json: %s", kaggle_username)
            except Exception as e:
                # Best effort: an unusable kaggle.json just means we fall back
                # to the local data below.
                logger.warning("Could not load Kaggle credentials from kaggle.json: %s", e)

        if kaggle_username and kaggle_key and kaggle_dataset:
            logger.info("Loading Kaggle dataset: %s", kaggle_dataset)
            retriever = setup_retriever(
                use_kaggle_data=True,
                kaggle_dataset=kaggle_dataset,
                kaggle_username=kaggle_username,
                kaggle_key=kaggle_key,
            )
        else:
            logger.info("Loading mental health FAQ data from local file...")
            # No arguments selects the default local-file data source.
            retriever = setup_retriever()

        rag_chain = setup_rag_chain()
        rag_application = RAGApplication(retriever, rag_chain)
        logger.info("RAG components initialized successfully!")
    except Exception as e:
        logger.exception("Error initializing RAG components: %s", e)
        raise
@app.get("/")
async def root():
    """Describe the API: a status message plus the paths of its main endpoints."""
    endpoints = {
        "ask_question": "/ask",
        "health_check": "/health",
        "load_kaggle_dataset": "/load-kaggle-dataset",
    }
    return {"message": "RAG API is running", "endpoints": endpoints}
@app.get("/health")
async def health_check():
    """Liveness probe.

    ``status`` is always ``"healthy"`` whenever the process can answer.
    ``rag_initialized`` reports whether the RAG pipeline has been built yet.
    """
    return {
        "status": "healthy",
        "rag_initialized": rag_application is not None,
    }
@app.post("/ask", response_model=QuestionResponse)
async def ask_question(request: QuestionRequest):
    """Answer ``request.question`` with the RAG pipeline.

    Returns:
        QuestionResponse echoing the question together with the answer.

    Raises:
        HTTPException(503): the RAG pipeline has not been initialised yet.
        HTTPException(400): the question is empty or whitespace only.
        HTTPException(500): the pipeline raised while answering.
    """
    from fastapi.concurrency import run_in_threadpool

    # Take one snapshot of the global so a concurrent dataset reload cannot
    # swap the pipeline out from under this request.
    pipeline = rag_application
    if pipeline is None:
        raise HTTPException(status_code=503, detail="RAG application not initialized")

    question = request.question
    if not question.strip():
        raise HTTPException(status_code=400, detail="Question must not be empty")

    try:
        logger.info("Processing question: %s", question)
        logger.debug("Using retriever type: %s", type(pipeline.retriever).__name__)
        # run() is synchronous (retrieval + LLM generation). Executing it in a
        # worker thread keeps the event loop free for other requests.
        answer = await run_in_threadpool(pipeline.run, question)
    except Exception as e:
        logger.exception("Error processing question: %s", e)
        raise HTTPException(status_code=500, detail=f"Error processing question: {str(e)}") from e

    return QuestionResponse(question=question, answer=answer)
@app.post("/load-kaggle-dataset")
async def load_kaggle_dataset(dataset_name: str):
    """Download a Kaggle dataset and rebuild the RAG pipeline from it.

    Args:
        dataset_name: Kaggle dataset identifier, passed to
            ``KaggleDataLoader.download_dataset`` and ``setup_retriever``
            (presumably ``"owner/dataset-slug"``; confirm against the loader).

    Returns:
        On success, HTTP 200 with
        ``{"status": "success", "message": ..., "dataset_path": ...}``.
        On failure, the same ``{"status": "error", "message": ...}`` body as
        before, but with HTTP 400 (empty name) or 500 (download/build error)
        so clients can rely on the status code.

    NOTE(review): this endpoint is unauthenticated. Any caller can trigger a
    download and a full pipeline rebuild.
    """
    from fastapi.concurrency import run_in_threadpool
    from fastapi.responses import JSONResponse

    global retriever, rag_chain, rag_application

    if not dataset_name.strip():
        return JSONResponse(
            status_code=400,
            content={"status": "error", "message": "dataset_name must not be empty"},
        )

    def _build():
        # Blocking work: download the data and build the new pipeline off the event loop.
        from src.kaggle_loader import KaggleDataLoader

        # A no-argument loader picks its credentials up from kaggle.json.
        loader = KaggleDataLoader()
        path = loader.download_dataset(dataset_name)
        # Forward the loader's credentials like startup_event does, but only
        # when present, so setup_retriever's own defaults still apply otherwise.
        credentials = {}
        if loader.kaggle_username and loader.kaggle_key:
            credentials = {
                "kaggle_username": loader.kaggle_username,
                "kaggle_key": loader.kaggle_key,
            }
        new_retriever = setup_retriever(
            use_kaggle_data=True, kaggle_dataset=dataset_name, **credentials
        )
        new_chain = setup_rag_chain()
        return path, new_retriever, new_chain, RAGApplication(new_retriever, new_chain)

    try:
        dataset_path, new_retriever, new_chain, new_app = await run_in_threadpool(_build)
    except Exception as e:
        logger.exception("Failed to load Kaggle dataset %s", dataset_name)
        return JSONResponse(status_code=500, content={"status": "error", "message": str(e)})

    # Swap in the new pipeline only after it is fully built, and keep all
    # three globals consistent with each other.
    retriever, rag_chain, rag_application = new_retriever, new_chain, new_app
    return {
        "status": "success",
        "message": f"Dataset {dataset_name} loaded successfully",
        "dataset_path": dataset_path,
    }
@app.get("/models")
async def get_models():
    """Report the model and vector-store configuration.

    NOTE(review): these values are hard-coded strings, not read from the
    running pipeline. Keep them in sync by hand with the chain/retriever setup
    (presumably in src.RAGSample).
    """
    return {
        "llm_model": "dolphin-llama3:8b",
        "embedding_model": "TF-IDF embeddings",
        "vector_database": "ChromaDB (local)",
    }
# Logging is configured at import time, so it also applies when the module is
# loaded by an external server process rather than run directly.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

if __name__ == "__main__":
    # Script entry point: `python app.py` serves the app with uvicorn.
    try:
        logger.info("Starting application...")
        # PORT from the environment (or .env) wins; otherwise 7860.
        listen_port = int(os.environ.get("PORT", 7860))
        logger.info(f"Starting server on port {listen_port}")
        # Bind on all interfaces so the server is reachable from outside the host/container.
        uvicorn.run(app, host="0.0.0.0", port=listen_port, log_level="info")
    except Exception as exc:
        # Record why startup failed, then let the exception end the process.
        logger.error(f"Failed to start application: {exc}")
        raise