Pasindu599 committed
Commit c2db0d4 · Parent(s): 87e9bdc a9ac1e4

Merge branch 'pr/4'

.gitignore CHANGED
@@ -1,4 +1,5 @@
 venv
 .env
 __pycache__
-.vscode
+.vscode
+corpus
app.py CHANGED
@@ -1,5 +1,6 @@
 from fastapi import FastAPI
 from langgraph.agents.summarize_agent.graph import graph
+from langgraph.agents.rag_agent.graph import graph as rag_graph
 from fastapi import Request
 from fastapi.middleware.cors import CORSMiddleware
 
@@ -31,6 +32,23 @@ async def summarize(request: Request):
     notes = data.get("notes")
     return graph.invoke({"urls": urls, "codes": codes, "notes": notes})
 
+@app.post("/chat")
+async def chat(request: Request):
+    data = await request.json()
+    user_input = data.get("message", "")
+    chat_history = data.get("chat_history", [])
+
+    # Invoke the RAG chatbot graph
+    result = rag_graph.invoke({
+        "user_input": user_input,
+        "chat_history": chat_history
+    })
+
+    return {
+        "response": result["response"],
+        "chat_history": result["chat_history"]
+    }
+
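Editor's note: the new /chat route expects a JSON body with a message string and an optional chat_history list, and returns the RAG graph's response plus the updated history. A minimal sketch of exercising it, assuming the app runs locally under uvicorn on port 8000 (host, port, and the sample message are assumptions, not part of the commit):

import requests

# Hypothetical local address; adjust to your uvicorn settings.
resp = requests.post(
    "http://localhost:8000/chat",
    json={"message": "What does the corpus cover?", "chat_history": []},
)
print(resp.json()["response"])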
langgraph/agents/rag_agent/graph.py ADDED
@@ -0,0 +1,207 @@
+import os
+from typing import Dict, List, Any, Literal
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.runnables import RunnablePassthrough
+from langgraph.graph import StateGraph
+from langgraph.graph.graph import END
+from dotenv import load_dotenv
+import google.generativeai as genai
+from google.generativeai import GenerativeModel
+import sys
+
+# Add the parent directory to the path to import utils
+sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
+from utils.create_vectordb import query_chroma_db
+
+load_dotenv()
+
+# Initialize Gemini model
+api_key = os.getenv("GOOGLE_API_KEY")
+genai.configure(api_key=api_key)
+model = GenerativeModel("gemini-2.5-flash-preview-05-20")
+
+def retrieve_context(state: Dict[str, Any]) -> Dict[str, Any]:
+    """
+    Retrieve relevant context from the vector database based on the user query.
+    """
+    query = state.get("user_input", "")
+    if not query:
+        return {"context": "No query provided.", "user_input": query, "next": "request_clarification"}
+
+    # Check if query is clear enough
+    if len(query.split()) < 3 or "?" not in query and not any(w in query.lower() for w in ["what", "how", "why", "when", "where", "who", "which"]):
+        return {"context": "", "user_input": query, "next": "request_clarification"}
+
+    # Query the vector database
+    results = query_chroma_db(query, n_results=3)
+
+    # Extract the retrieved documents
+    documents = results.get("documents", [[]])[0]
+    metadatas = results.get("metadatas", [[]])[0]
+
+    # Format the context
+    formatted_context = []
+    for i, (doc, metadata) in enumerate(zip(documents, metadatas)):
+        source = metadata.get("source", "Unknown")
+        formatted_context.append(f"Document {i+1} (Source: {source}):\n{doc}\n")
+
+    context = "\n".join(formatted_context) if formatted_context else ""
+
+    # Determine next step based on context quality
+    if not context or len(context) < 50:
+        next_step = "use_gemini_knowledge"
+    else:
+        next_step = "generate_response"
+
+    return {"context": context, "user_input": query, "next": next_step}
+
+def request_clarification(state: Dict[str, Any]) -> Dict[str, Any]:
+    """
+    Request clarification from the user when the query is unclear.
+    """
+    query = state.get("user_input", "")
+
+    clarification_message = model.generate_content(
+        f"""The user asked: "{query}"
+
+        This query seems vague or unclear. Generate a polite response asking for more specific details.
+        Focus on what additional information would help you understand their request better.
+        Keep your response under 3 sentences and make it conversational."""
+    )
+
+    response = clarification_message.text
+
+    # Update chat history
+    chat_history = state.get("chat_history", [])
+    new_chat_history = chat_history + [
+        {"role": "user", "content": query},
+        {"role": "assistant", "content": response}
+    ]
+
+    return {
+        "response": response,
+        "chat_history": new_chat_history,
+        "needs_clarification": True
+    }
+
+def use_gemini_knowledge(state: Dict[str, Any]) -> Dict[str, Any]:
+    """
+    Use Gemini's knowledge base when no relevant information is found in the vector database.
+    """
+    query = state.get("user_input", "")
+    chat_history = state.get("chat_history", [])
+
+    # Construct the prompt
+    prompt_template = """
+    I couldn't find specific information about this in my local database. However, I can try to answer based on my general knowledge.
+
+    User Question: {query}
+
+    First, acknowledge that you're answering from general knowledge rather than the specific database.
+    Then provide a helpful, accurate response based on what you know about the topic.
+    """
+
+    # Generate response
+    response = model.generate_content(
+        prompt_template.format(query=query)
+    )
+
+    # Update chat history
+    new_chat_history = chat_history + [
+        {"role": "user", "content": query},
+        {"role": "assistant", "content": response.text}
+    ]
+
+    return {
+        "response": response.text,
+        "chat_history": new_chat_history
+    }
+
+def generate_response(state: Dict[str, Any]) -> Dict[str, Any]:
+    """
+    Generate a response using the LLM based on the retrieved context and user query.
+    """
+    context = state.get("context", "")
+    query = state.get("user_input", "")
+    chat_history = state.get("chat_history", [])
+
+    # Construct the prompt
+    prompt_template = """
+    You are a helpful assistant that answers questions based on the provided context.
+
+    Context:
+    {context}
+
+    Chat History:
+    {chat_history}
+
+    User Question: {query}
+
+    Answer the question based only on the provided context. If the context doesn't contain enough information,
+    acknowledge this but still try to provide a helpful response based on the available information.
+    Provide a clear, concise, and helpful response.
+    """
+
+    # Format chat history for the prompt
+    formatted_chat_history = "\n".join([f"{msg['role']}: {msg['content']}" for msg in chat_history])
+
+    # Generate response
+    response = model.generate_content(
+        prompt_template.format(
+            context=context,
+            chat_history=formatted_chat_history,
+            query=query
+        )
+    )
+
+    # Update chat history
+    new_chat_history = chat_history + [
+        {"role": "user", "content": query},
+        {"role": "assistant", "content": response.text}
+    ]
+
+    return {
+        "response": response.text,
+        "chat_history": new_chat_history
+    }
+
+def decide_next_step(state: Dict[str, Any]) -> Literal["request_clarification", "use_gemini_knowledge", "generate_response"]:
+    """
+    Decide the next step in the workflow based on the state.
+    """
+    return state["next"]
+
+# Define the workflow
+def build_graph():
+    workflow = StateGraph(state_schema=Dict[str, Any])
+
+    # Add nodes
+    workflow.add_node("retrieve_context", retrieve_context)
+    workflow.add_node("request_clarification", request_clarification)
+    workflow.add_node("use_gemini_knowledge", use_gemini_knowledge)
+    workflow.add_node("generate_response", generate_response)
+
+    # Define edges with conditional routing
+    workflow.set_entry_point("retrieve_context")
+    workflow.add_conditional_edges(
+        "retrieve_context",
+        decide_next_step,
+        {
+            "request_clarification": "request_clarification",
+            "use_gemini_knowledge": "use_gemini_knowledge",
+            "generate_response": "generate_response"
+        }
+    )
+
+    # Set finish points
+    workflow.add_edge("request_clarification", END)
+    workflow.add_edge("use_gemini_knowledge", END)
+    workflow.add_edge("generate_response", END)
+
+    # Compile the graph
+    return workflow.compile()
+
+# Create the graph
+graph = build_graph()
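Editor's note: retrieve_context routes each query to exactly one of three terminal nodes (request_clarification for vague queries, use_gemini_knowledge when retrieval comes back thin, generate_response otherwise). A minimal sketch of invoking the compiled graph directly, assuming imports resolve from the repo root; the sample question and empty history are placeholders:

from langgraph.agents.rag_agent.graph import graph as rag_graph

# Initial state keys match what the node functions read via state.get(...).
state = rag_graph.invoke({
    "user_input": "What topics does the corpus cover?",
    "chat_history": [],
})
print(state["response"])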
requirements.txt CHANGED
@@ -3,6 +3,10 @@ uvicorn[standard]
 langgraph
 langsmith
 google-genai
-
+google-generativeai
+chromadb
+langchain
+langchain-community
 python-dotenv
+pypdf
 
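Editor's note: both google-genai and google-generativeai are now pinned; the former was already present (presumably for the summarize agent) while the new RAG code imports the google.generativeai package, so a fresh environment needs a reinstall (pip install -r requirements.txt) before either agent loads.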
utils/__init__.py ADDED
@@ -0,0 +1 @@
+# This file is intentionally left empty to make the directory a Python package
utils/create_vectordb.py ADDED
@@ -0,0 +1,141 @@
+import os
+from typing import Optional, List
+import chromadb
+from chromadb.utils import embedding_functions
+from langchain_community.document_loaders import DirectoryLoader, PyPDFLoader
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from dotenv import load_dotenv
+import google.generativeai as genai
+
+load_dotenv()
+
+# Configure paths
+CORPUS_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "corpus")
+DB_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "vectordb")
+
+# Ensure directories exist
+os.makedirs(CORPUS_DIR, exist_ok=True)
+os.makedirs(DB_DIR, exist_ok=True)
+
+def load_documents(corpus_dir: str = CORPUS_DIR) -> List:
+    """Load documents from the corpus directory."""
+    if not os.path.exists(corpus_dir):
+        raise FileNotFoundError(f"Corpus directory not found: {corpus_dir}")
+    print(f"Loading documents from {corpus_dir}...")
+
+    # Initialize loaders for different file types
+    loaders = {
+        # "txt": DirectoryLoader(corpus_dir, glob="**/*.txt", loader_cls=TextLoader),
+        "pdf": DirectoryLoader(corpus_dir, glob="**/*.pdf", loader_cls=PyPDFLoader),
+        # "docx": DirectoryLoader(corpus_dir, glob="**/*.docx", loader_cls=Docx2txtLoader),
+    }
+
+    documents = []
+    for file_type, loader in loaders.items():
+        try:
+            docs = loader.load()
+            print(f"Loaded {len(docs)} {file_type} documents")
+            documents.extend(docs)
+        except Exception as e:
+            print(f"Error loading {file_type} documents: {e}")
+
+    return documents
+
+def split_documents(documents, chunk_size=1000, chunk_overlap=200):
+    """Split documents into chunks."""
+    text_splitter = RecursiveCharacterTextSplitter(
+        chunk_size=chunk_size,
+        chunk_overlap=chunk_overlap,
+        length_function=len,
+    )
+
+    splits = text_splitter.split_documents(documents)
+    print(f"Split {len(documents)} documents into {len(splits)} chunks")
+
+    return splits
+
+def create_chroma_db(documents, collection_name="corpus_collection", db_dir=DB_DIR):
+    """Create a Chroma vector database from documents."""
+    # Initialize the Gemini embedding function
+    gemini_ef = embedding_functions.GoogleGenerativeAiEmbeddingFunction(
+        api_key=os.getenv("GOOGLE_API_KEY"),
+        model_name="models/embedding-001"
+    )
+
+    # Initialize Chroma client
+    client = chromadb.PersistentClient(path=db_dir)
+
+    # Create or get collection
+    try:
+        collection = client.get_collection(name=collection_name)
+        print(f"Using existing collection: {collection_name}")
+    except:
+        collection = client.create_collection(
+            name=collection_name,
+            embedding_function=gemini_ef
+        )
+        print(f"Created new collection: {collection_name}")
+
+    # Add documents to collection
+    for i, doc in enumerate(documents):
+        collection.add(
+            documents=[doc.page_content],
+            metadatas=[doc.metadata],
+            ids=[f"doc_{i}"]
+        )
+
+    print(f"Added {len(documents)} documents to collection {collection_name}")
+    return collection
+
+def query_chroma_db(query: str, collection_name="corpus_collection", n_results=5, db_dir=DB_DIR):
+    """Query the Chroma vector database."""
+    # Initialize the Gemini embedding function
+    gemini_ef = embedding_functions.GoogleGenerativeAiEmbeddingFunction(
+        api_key=os.getenv("GOOGLE_API_KEY"),
+        model_name="models/embedding-001"
+    )
+
+    # Initialize Chroma client
+    client = chromadb.PersistentClient(path=db_dir)
+
+    # Get collection
+    collection = client.get_collection(name=collection_name, embedding_function=gemini_ef)
+
+    # Query collection
+    results = collection.query(
+        query_texts=[query],
+        n_results=n_results
+    )
+
+    return results
+
+def main():
+    """Main function to create and test the vector database."""
+    print("Starting vector database creation...")
+
+    # Load documents
+    documents = load_documents()
+    if not documents:
+        print("No documents found in corpus directory. Please add documents to proceed.")
+        return
+
+    # Split documents
+    splits = split_documents(documents)
+
+    # Create vector database
+    collection = create_chroma_db(splits)
+
+    # Test query
+    test_query = "What is this corpus about?"
+    print(f"\nTesting query: '{test_query}'")
+    results = query_chroma_db(test_query)
+    print(f"Found {len(results['documents'][0])} matching documents")
+    for i, (doc, metadata) in enumerate(zip(results['documents'][0], results['metadatas'][0])):
+        print(f"\nResult {i+1}:")
+        print(f"Document: {doc[:150]}...")
+        print(f"Source: {metadata.get('source', 'Unknown')}")
+
+    print("\nVector database creation and testing complete!")
+
+if __name__ == "__main__":
+    main()
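Editor's note: to build the store, drop PDFs into corpus/ (now gitignored, per the .gitignore change above) and run the script once, e.g. python utils/create_vectordb.py with GOOGLE_API_KEY set in .env. A minimal sketch of a follow-up query against the persisted collection; the question text is a placeholder:

from utils.create_vectordb import query_chroma_db

# Assumes the collection was already built by running utils/create_vectordb.py.
results = query_chroma_db("example question", n_results=3)
for doc in results["documents"][0]:
    print(doc[:100])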