"""LangGraph workflow that retrieves context from a Chroma vector store and routes
each query to clarification, Gemini general knowledge, or grounded response generation."""

import os
import sys
from typing import Dict, Any, Literal

from dotenv import load_dotenv
import google.generativeai as genai
from google.generativeai import GenerativeModel
from langgraph.graph import StateGraph, END

# Add the parent directory to the path to import utils
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))

from utils.create_vectordb import query_chroma_db

load_dotenv()

# Initialize the Gemini model
api_key = os.getenv("GOOGLE_API_KEY")
genai.configure(api_key=api_key)
model = GenerativeModel("gemini-2.5-flash-preview-05-20")


def retrieve_context(state: Dict[str, Any]) -> Dict[str, Any]:
    """
    Retrieve relevant context from the vector database based on the user query.
    """
    query = state.get("user_input", "")
    if not query:
        return {"context": "No query provided.", "user_input": query, "next": "request_clarification"}

    # Treat the query as unclear if it is very short, or contains neither a
    # question mark nor a question word
    if len(query.split()) < 3 or ("?" not in query and not any(w in query.lower() for w in ["what", "how", "why", "when", "where", "who", "which"])):
        return {"context": "", "user_input": query, "next": "request_clarification"}

    # Query the vector database
    results = query_chroma_db(query, n_results=3)

    # Extract the retrieved documents
    documents = results.get("documents", [[]])[0]
    metadatas = results.get("metadatas", [[]])[0]

    # Format the context
    formatted_context = []
    for i, (doc, metadata) in enumerate(zip(documents, metadatas)):
        source = metadata.get("source", "Unknown")
        formatted_context.append(f"Document {i+1} (Source: {source}):\n{doc}\n")

    context = "\n".join(formatted_context) if formatted_context else ""

    # Determine the next step based on context quality
    if not context or len(context) < 50:
        next_step = "use_gemini_knowledge"
    else:
        next_step = "generate_response"

    return {"context": context, "user_input": query, "next": next_step}


def request_clarification(state: Dict[str, Any]) -> Dict[str, Any]:
    """
    Request clarification from the user when the query is unclear.
    """
    query = state.get("user_input", "")
    clarification_message = model.generate_content(
        f"""The user asked: "{query}"

        This query seems vague or unclear. Generate a polite response asking for more specific details.
        Focus on what additional information would help you understand their request better.
        Keep your response under 3 sentences and make it conversational."""
    )
    response = clarification_message.text

    # Update chat history
    chat_history = state.get("chat_history", [])
    new_chat_history = chat_history + [
        {"role": "user", "content": query},
        {"role": "assistant", "content": response}
    ]

    return {
        "response": response,
        "chat_history": new_chat_history,
        "needs_clarification": True
    }


def use_gemini_knowledge(state: Dict[str, Any]) -> Dict[str, Any]:
    """
    Use Gemini's general knowledge when no relevant information is found in the vector database.
    """
    query = state.get("user_input", "")
    chat_history = state.get("chat_history", [])

    # Construct the prompt
    prompt_template = """
    I couldn't find specific information about this in my local database.
    However, I can try to answer based on my general knowledge.

    User Question: {query}

    First, acknowledge that you're answering from general knowledge rather than the specific database.
    Then provide a helpful, accurate response based on what you know about the topic.
    """

    # Generate response
    response = model.generate_content(
        prompt_template.format(query=query)
    )

    # Update chat history
    new_chat_history = chat_history + [
        {"role": "user", "content": query},
        {"role": "assistant", "content": response.text}
    ]

    return {
        "response": response.text,
        "chat_history": new_chat_history
    }


def generate_response(state: Dict[str, Any]) -> Dict[str, Any]:
    """
    Generate a response using the LLM based on the retrieved context and user query.
    """
    context = state.get("context", "")
    query = state.get("user_input", "")
    chat_history = state.get("chat_history", [])

    # Construct the prompt
    prompt_template = """
    You are a helpful assistant that answers questions based on the provided context.

    Context:
    {context}

    Chat History:
    {chat_history}

    User Question: {query}

    Answer the question based only on the provided context. If the context doesn't contain
    enough information, acknowledge this but still try to provide a helpful response based
    on the available information. Provide a clear, concise, and helpful response.
    """

    # Format chat history for the prompt
    formatted_chat_history = "\n".join([f"{msg['role']}: {msg['content']}" for msg in chat_history])

    # Generate response
    response = model.generate_content(
        prompt_template.format(
            context=context,
            chat_history=formatted_chat_history,
            query=query
        )
    )

    # Update chat history
    new_chat_history = chat_history + [
        {"role": "user", "content": query},
        {"role": "assistant", "content": response.text}
    ]

    return {
        "response": response.text,
        "chat_history": new_chat_history
    }


def decide_next_step(state: Dict[str, Any]) -> Literal["request_clarification", "use_gemini_knowledge", "generate_response"]:
    """
    Decide the next step in the workflow based on the routing key set by retrieve_context.
    """
    return state["next"]


def build_graph():
    """
    Define and compile the workflow graph.
    """
    workflow = StateGraph(state_schema=Dict[str, Any])

    # Add nodes
    workflow.add_node("retrieve_context", retrieve_context)
    workflow.add_node("request_clarification", request_clarification)
    workflow.add_node("use_gemini_knowledge", use_gemini_knowledge)
    workflow.add_node("generate_response", generate_response)

    # Define edges with conditional routing
    workflow.set_entry_point("retrieve_context")
    workflow.add_conditional_edges(
        "retrieve_context",
        decide_next_step,
        {
            "request_clarification": "request_clarification",
            "use_gemini_knowledge": "use_gemini_knowledge",
            "generate_response": "generate_response"
        }
    )

    # Set finish points
    workflow.add_edge("request_clarification", END)
    workflow.add_edge("use_gemini_knowledge", END)
    workflow.add_edge("generate_response", END)

    # Compile the graph
    return workflow.compile()


# Create the graph
graph = build_graph()