Spaces:
Sleeping
Sleeping
rag-chat
#4
by
irashperera
- opened
- .gitignore +1 -0
- app.py +0 -56
- utils/create_vectordb.py +9 -21
.gitignore
CHANGED
@@ -3,3 +3,4 @@ venv
|
|
3 |
__pycache__
|
4 |
.vscode
|
5 |
corpus
|
|
|
|
3 |
__pycache__
|
4 |
.vscode
|
5 |
corpus
|
6 |
+
|
app.py
CHANGED
@@ -4,11 +4,6 @@ from langgraph.agents.rag_agent.graph import graph as rag_graph
|
|
4 |
from fastapi import Request
|
5 |
from fastapi.middleware.cors import CORSMiddleware
|
6 |
|
7 |
-
from langchain_core.documents import Document
|
8 |
-
|
9 |
-
from utils.create_vectordb import create_chroma_db_and_document,query_chroma_db
|
10 |
-
|
11 |
-
|
12 |
|
13 |
|
14 |
|
@@ -37,63 +32,12 @@ async def summarize(request: Request):
|
|
37 |
notes = data.get("notes")
|
38 |
return graph.invoke({"urls": urls, "codes": codes, "notes": notes})
|
39 |
|
40 |
-
|
41 |
-
@app.post("/save_summary")
|
42 |
-
async def save_summary(request: Request):
|
43 |
-
data = await request.json()
|
44 |
-
summary = data.get("summary", "")
|
45 |
-
post_id = data.get("post_id", None)
|
46 |
-
title = data.get("title", "")
|
47 |
-
category = data.get("category", "")
|
48 |
-
tags = data.get("tags", [])
|
49 |
-
references = data.get("references", [])
|
50 |
-
|
51 |
-
page_content = f"""
|
52 |
-
Title: {title}
|
53 |
-
Category: {category}
|
54 |
-
Tags: {', '.join(tags)}
|
55 |
-
Summary: {summary}
|
56 |
-
"""
|
57 |
-
|
58 |
-
document = Document(
|
59 |
-
page_content=page_content,
|
60 |
-
id = str(post_id)
|
61 |
-
|
62 |
-
)
|
63 |
-
|
64 |
-
is_added = create_chroma_db_and_document(document)
|
65 |
-
|
66 |
-
if not is_added:
|
67 |
-
return {"error": "Failed to save summary to the database." , "status": "error"}
|
68 |
-
|
69 |
-
return {"message": "Summary saved successfully." , "status": "success"}
|
70 |
-
|
71 |
-
@app.post("/summaries")
|
72 |
-
async def get_summaries(request: Request):
|
73 |
-
|
74 |
-
data = await request.json()
|
75 |
-
print(data)
|
76 |
-
query = data.get("query" , "")
|
77 |
-
|
78 |
-
print(f"Query received: {query}")
|
79 |
-
results = query_chroma_db(query=query)
|
80 |
-
return results
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
@app.post("/chat")
|
86 |
async def chat(request: Request):
|
87 |
data = await request.json()
|
88 |
-
|
89 |
-
print(f"Chat request data: {data}")
|
90 |
-
|
91 |
user_input = data.get("message", "")
|
92 |
chat_history = data.get("chat_history", [])
|
93 |
|
94 |
-
print(f"User input: {user_input}")
|
95 |
-
print(f"Chat history: {chat_history}")
|
96 |
-
|
97 |
# Invoke the RAG chatbot graph
|
98 |
result = rag_graph.invoke({
|
99 |
"user_input": user_input,
|
|
|
4 |
from fastapi import Request
|
5 |
from fastapi.middleware.cors import CORSMiddleware
|
6 |
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
|
9 |
|
|
|
32 |
notes = data.get("notes")
|
33 |
return graph.invoke({"urls": urls, "codes": codes, "notes": notes})
|
34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
@app.post("/chat")
|
36 |
async def chat(request: Request):
|
37 |
data = await request.json()
|
|
|
|
|
|
|
38 |
user_input = data.get("message", "")
|
39 |
chat_history = data.get("chat_history", [])
|
40 |
|
|
|
|
|
|
|
41 |
# Invoke the RAG chatbot graph
|
42 |
result = rag_graph.invoke({
|
43 |
"user_input": user_input,
|
utils/create_vectordb.py
CHANGED
@@ -54,7 +54,7 @@ def split_documents(documents, chunk_size=1000, chunk_overlap=200):
|
|
54 |
|
55 |
return splits
|
56 |
|
57 |
-
def
|
58 |
"""Create a Chroma vector database from documents."""
|
59 |
# Initialize the Gemini embedding function
|
60 |
gemini_ef = embedding_functions.GoogleGenerativeAiEmbeddingFunction(
|
@@ -75,25 +75,17 @@ def create_chroma_db_and_document(document, collection_name="corpus_collection",
|
|
75 |
embedding_function=gemini_ef
|
76 |
)
|
77 |
print(f"Created new collection: {collection_name}")
|
78 |
-
|
79 |
-
|
80 |
-
try:
|
81 |
|
|
|
|
|
82 |
collection.add(
|
83 |
-
documents
|
84 |
-
|
|
|
85 |
)
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
except Exception as e:
|
91 |
-
print(f"Error adding document to collection: {e}")
|
92 |
-
|
93 |
-
return False
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
|
98 |
def query_chroma_db(query: str, collection_name="corpus_collection", n_results=5, db_dir=DB_DIR):
|
99 |
"""Query the Chroma vector database."""
|
@@ -144,10 +136,6 @@ def main():
|
|
144 |
print(f"Source: {metadata.get('source', 'Unknown')}")
|
145 |
|
146 |
print("\nVector database creation and testing complete!")
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
|
152 |
if __name__ == "__main__":
|
153 |
main()
|
|
|
54 |
|
55 |
return splits
|
56 |
|
57 |
+
def create_chroma_db(documents, collection_name="corpus_collection", db_dir=DB_DIR):
|
58 |
"""Create a Chroma vector database from documents."""
|
59 |
# Initialize the Gemini embedding function
|
60 |
gemini_ef = embedding_functions.GoogleGenerativeAiEmbeddingFunction(
|
|
|
75 |
embedding_function=gemini_ef
|
76 |
)
|
77 |
print(f"Created new collection: {collection_name}")
|
|
|
|
|
|
|
78 |
|
79 |
+
# Add documents to collection
|
80 |
+
for i, doc in enumerate(documents):
|
81 |
collection.add(
|
82 |
+
documents=[doc.page_content],
|
83 |
+
metadatas=[doc.metadata],
|
84 |
+
ids=[f"doc_{i}"]
|
85 |
)
|
86 |
+
|
87 |
+
print(f"Added {len(documents)} documents to collection {collection_name}")
|
88 |
+
return collection
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
|
90 |
def query_chroma_db(query: str, collection_name="corpus_collection", n_results=5, db_dir=DB_DIR):
|
91 |
"""Query the Chroma vector database."""
|
|
|
136 |
print(f"Source: {metadata.get('source', 'Unknown')}")
|
137 |
|
138 |
print("\nVector database creation and testing complete!")
|
|
|
|
|
|
|
|
|
139 |
|
140 |
if __name__ == "__main__":
|
141 |
main()
|