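"""LangGraph workflow for a RAG chatbot: retrieve context from the Chroma vector
database, route to a clarification request or to Gemini's general knowledge when
needed, and otherwise generate a context-grounded response with the Gemini model."""
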
import os
import sys
from typing import Dict, List, Any, Literal, TypedDict

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langgraph.graph import StateGraph, END
from dotenv import load_dotenv
import google.generativeai as genai
from google.generativeai import GenerativeModel

# Add the parent directory to the path to import utils
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
from utils.create_vectordb import query_chroma_db

load_dotenv()

# Initialize the Gemini model
api_key = os.getenv("GOOGLE_API_KEY")
genai.configure(api_key=api_key)
model = GenerativeModel("gemini-2.5-flash-preview-05-20")
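
# State schema for the graph. Giving StateGraph an annotated schema (a TypedDict
# here) lets LangGraph keep one channel per key and merge the partial dicts
# returned by each node, so values such as chat_history survive across nodes.
# The keys mirror what the node functions below read and write.
class ChatState(TypedDict, total=False):
    user_input: str
    context: str
    next: str
    response: str
    chat_history: List[Dict[str, str]]
    needs_clarification: bool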

def retrieve_context(state: Dict[str, Any]) -> Dict[str, Any]:
    """
    Retrieve relevant context from the vector database based on the user query.
    """
    query = state.get("user_input", "")
    if not query:
        return {"context": "No query provided.", "user_input": query, "next": "request_clarification"}

    # Check whether the query is clear enough: very short queries, or queries with
    # neither a question mark nor a question word, are routed to clarification
    if len(query.split()) < 3 or ("?" not in query and not any(w in query.lower() for w in ["what", "how", "why", "when", "where", "who", "which"])):
        return {"context": "", "user_input": query, "next": "request_clarification"}

    # Query the vector database
    results = query_chroma_db(query, n_results=3)

    # Extract the retrieved documents
    documents = results.get("documents", [[]])[0]
    metadatas = results.get("metadatas", [[]])[0]

    # Format the context
    formatted_context = []
    for i, (doc, metadata) in enumerate(zip(documents, metadatas)):
        source = metadata.get("source", "Unknown")
        formatted_context.append(f"Document {i+1} (Source: {source}):\n{doc}\n")
    context = "\n".join(formatted_context) if formatted_context else ""

    # Determine the next step based on context quality
    if not context or len(context) < 50:
        next_step = "use_gemini_knowledge"
    else:
        next_step = "generate_response"

    return {"context": context, "user_input": query, "next": next_step}

def request_clarification(state: Dict[str, Any]) -> Dict[str, Any]:
    """
    Request clarification from the user when the query is unclear.
    """
    query = state.get("user_input", "")
    clarification_message = model.generate_content(
        f"""The user asked: "{query}"
        This query seems vague or unclear. Generate a polite response asking for more specific details.
        Focus on what additional information would help you understand their request better.
        Keep your response under 3 sentences and make it conversational."""
    )
    response = clarification_message.text

    # Update chat history
    chat_history = state.get("chat_history", [])
    new_chat_history = chat_history + [
        {"role": "user", "content": query},
        {"role": "assistant", "content": response}
    ]

    return {
        "response": response,
        "chat_history": new_chat_history,
        "needs_clarification": True
    }

def use_gemini_knowledge(state: Dict[str, Any]) -> Dict[str, Any]:
    """
    Use Gemini's knowledge base when no relevant information is found in the vector database.
    """
    query = state.get("user_input", "")
    chat_history = state.get("chat_history", [])

    # Construct the prompt
    prompt_template = """
    I couldn't find specific information about this in my local database. However, I can try to answer based on my general knowledge.

    User Question: {query}

    First, acknowledge that you're answering from general knowledge rather than the specific database.
    Then provide a helpful, accurate response based on what you know about the topic.
    """

    # Generate the response
    response = model.generate_content(
        prompt_template.format(query=query)
    )

    # Update chat history
    new_chat_history = chat_history + [
        {"role": "user", "content": query},
        {"role": "assistant", "content": response.text}
    ]

    return {
        "response": response.text,
        "chat_history": new_chat_history
    }

def generate_response(state: Dict[str, Any]) -> Dict[str, Any]:
    """
    Generate a response using the LLM based on the retrieved context and user query.
    """
    context = state.get("context", "")
    query = state.get("user_input", "")
    chat_history = state.get("chat_history", [])

    # Construct the prompt
    prompt_template = """
    You are a helpful assistant that answers questions based on the provided context.

    Context:
    {context}

    Chat History:
    {chat_history}

    User Question: {query}

    Answer the question based only on the provided context. If the context doesn't contain enough information,
    acknowledge this but still try to provide a helpful response based on the available information.
    Provide a clear, concise, and helpful response.
    """

    # Format the chat history for the prompt
    formatted_chat_history = "\n".join([f"{msg['role']}: {msg['content']}" for msg in chat_history])

    # Generate the response
    response = model.generate_content(
        prompt_template.format(
            context=context,
            chat_history=formatted_chat_history,
            query=query
        )
    )

    # Update chat history
    new_chat_history = chat_history + [
        {"role": "user", "content": query},
        {"role": "assistant", "content": response.text}
    ]

    return {
        "response": response.text,
        "chat_history": new_chat_history
    }

def decide_next_step(state: Dict[str, Any]) -> Literal["request_clarification", "use_gemini_knowledge", "generate_response"]:
    """
    Decide the next step in the workflow based on the state.
    """
    return state["next"]

# Define the workflow
def build_graph():
    workflow = StateGraph(ChatState)

    # Add nodes
    workflow.add_node("retrieve_context", retrieve_context)
    workflow.add_node("request_clarification", request_clarification)
    workflow.add_node("use_gemini_knowledge", use_gemini_knowledge)
    workflow.add_node("generate_response", generate_response)

    # Define edges with conditional routing
    workflow.set_entry_point("retrieve_context")
    workflow.add_conditional_edges(
        "retrieve_context",
        decide_next_step,
        {
            "request_clarification": "request_clarification",
            "use_gemini_knowledge": "use_gemini_knowledge",
            "generate_response": "generate_response"
        }
    )

    # Set finish points
    workflow.add_edge("request_clarification", END)
    workflow.add_edge("use_gemini_knowledge", END)
    workflow.add_edge("generate_response", END)

    # Compile the graph
    return workflow.compile()

# Create the graph
graph = build_graph()
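
# Minimal usage sketch: invoke the compiled graph with an initial state holding
# the user's message and any prior chat history; the final state exposes the
# assistant's reply under "response". The sample question is only a placeholder.
if __name__ == "__main__":
    initial_state = {
        "user_input": "What topics are covered in the indexed documents?",
        "chat_history": []
    }
    result = graph.invoke(initial_state)
    print(result["response"])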