import os
import sys
from typing import Dict, Any, Literal

from dotenv import load_dotenv
import google.generativeai as genai
from google.generativeai import GenerativeModel
from langgraph.graph import StateGraph, END
# Add the parent directory to the path to import utils
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
from utils.create_vectordb import query_chroma_db
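# query_chroma_db is assumed to return a Chroma-style result dict with nested
# "documents" and "metadatas" lists (one inner list per query); retrieve_context
# below unpacks it that way. The helper itself lives in utils/create_vectordb.py.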
load_dotenv()
# Initialize Gemini model
api_key = os.getenv("GOOGLE_API_KEY")
genai.configure(api_key=api_key)
model = GenerativeModel("gemini-2.5-flash-preview-05-20")
def retrieve_context(state: Dict[str, Any]) -> Dict[str, Any]:
"""
Retrieve relevant context from the vector database based on the user query.
"""
query = state.get("user_input", "")
if not query:
return {"context": "No query provided.", "user_input": query, "next": "request_clarification"}
    # Check if the query is clear enough: very short queries that contain neither
    # a question mark nor a question word are routed to the clarification step.
    if len(query.split()) < 3 or ("?" not in query and not any(w in query.lower() for w in ["what", "how", "why", "when", "where", "who", "which"])):
        return {"context": "", "user_input": query, "next": "request_clarification"}
# Query the vector database
results = query_chroma_db(query, n_results=3)
# Extract the retrieved documents
documents = results.get("documents", [[]])[0]
metadatas = results.get("metadatas", [[]])[0]
# Format the context
formatted_context = []
for i, (doc, metadata) in enumerate(zip(documents, metadatas)):
source = metadata.get("source", "Unknown")
formatted_context.append(f"Document {i+1} (Source: {source}):\n{doc}\n")
context = "\n".join(formatted_context) if formatted_context else ""
# Determine next step based on context quality
if not context or len(context) < 50:
next_step = "use_gemini_knowledge"
else:
next_step = "generate_response"
return {"context": context, "user_input": query, "next": next_step}
def request_clarification(state: Dict[str, Any]) -> Dict[str, Any]:
"""
Request clarification from the user when the query is unclear.
"""
query = state.get("user_input", "")
clarification_message = model.generate_content(
f"""The user asked: "{query}"
This query seems vague or unclear. Generate a polite response asking for more specific details.
Focus on what additional information would help you understand their request better.
Keep your response under 3 sentences and make it conversational."""
)
response = clarification_message.text
# Update chat history
chat_history = state.get("chat_history", [])
new_chat_history = chat_history + [
{"role": "user", "content": query},
{"role": "assistant", "content": response}
]
return {
"response": response,
"chat_history": new_chat_history,
"needs_clarification": True
}
def use_gemini_knowledge(state: Dict[str, Any]) -> Dict[str, Any]:
"""
Use Gemini's knowledge base when no relevant information is found in the vector database.
"""
query = state.get("user_input", "")
chat_history = state.get("chat_history", [])
# Construct the prompt
prompt_template = """
I couldn't find specific information about this in my local database. However, I can try to answer based on my general knowledge.
User Question: {query}
First, acknowledge that you're answering from general knowledge rather than the specific database.
Then provide a helpful, accurate response based on what you know about the topic.
"""
# Generate response
response = model.generate_content(
prompt_template.format(query=query)
)
# Update chat history
new_chat_history = chat_history + [
{"role": "user", "content": query},
{"role": "assistant", "content": response.text}
]
return {
"response": response.text,
"chat_history": new_chat_history
}
def generate_response(state: Dict[str, Any]) -> Dict[str, Any]:
"""
Generate a response using the LLM based on the retrieved context and user query.
"""
context = state.get("context", "")
query = state.get("user_input", "")
chat_history = state.get("chat_history", [])
# Construct the prompt
prompt_template = """
You are a helpful assistant that answers questions based on the provided context.
Context:
{context}
Chat History:
{chat_history}
User Question: {query}
Answer the question based only on the provided context. If the context doesn't contain enough information,
acknowledge this but still try to provide a helpful response based on the available information.
Provide a clear, concise, and helpful response.
"""
# Format chat history for the prompt
formatted_chat_history = "\n".join([f"{msg['role']}: {msg['content']}" for msg in chat_history])
# Generate response
response = model.generate_content(
prompt_template.format(
context=context,
chat_history=formatted_chat_history,
query=query
)
)
# Update chat history
new_chat_history = chat_history + [
{"role": "user", "content": query},
{"role": "assistant", "content": response.text}
]
return {
"response": response.text,
"chat_history": new_chat_history
}
def decide_next_step(state: Dict[str, Any]) -> Literal["request_clarification", "use_gemini_knowledge", "generate_response"]:
"""
Decide the next step in the workflow based on the state.
"""
return state["next"]
# Define the workflow
def build_graph():
workflow = StateGraph(state_schema=Dict[str, Any])
# Add nodes
workflow.add_node("retrieve_context", retrieve_context)
workflow.add_node("request_clarification", request_clarification)
workflow.add_node("use_gemini_knowledge", use_gemini_knowledge)
workflow.add_node("generate_response", generate_response)
# Define edges with conditional routing
workflow.set_entry_point("retrieve_context")
workflow.add_conditional_edges(
"retrieve_context",
decide_next_step,
{
"request_clarification": "request_clarification",
"use_gemini_knowledge": "use_gemini_knowledge",
"generate_response": "generate_response"
}
)
# Set finish points
workflow.add_edge("request_clarification", END)
workflow.add_edge("use_gemini_knowledge", END)
workflow.add_edge("generate_response", END)
# Compile the graph
return workflow.compile()
# Create the graph
graph = build_graph()
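# Minimal usage sketch (assumption: this module is run directly; the Space's real
# entry point, e.g. a Gradio or Streamlit app, may call `graph` differently).
# The example question is hypothetical; the state keys match those used by the nodes above.
if __name__ == "__main__":
    initial_state = {"user_input": "What topics does the knowledge base cover?", "chat_history": []}
    final_state = graph.invoke(initial_state)
    print(final_state.get("response", ""))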