Pasindu599 committed on
Commit
51a3d33
·
1 Parent(s): c2db0d4

Add save_summary and get_summaries endpoints to FastAPI app; refactor create_chroma_db to handle single document input

Browse files
Files changed (2) hide show
  1. app.py +51 -0
  2. utils/create_vectordb.py +21 -9
app.py CHANGED
@@ -4,6 +4,11 @@ from langgraph.agents.rag_agent.graph import graph as rag_graph
4
  from fastapi import Request
5
  from fastapi.middleware.cors import CORSMiddleware
6
 
 
 
 
 
 
7
 
8
 
9
 
@@ -32,6 +37,52 @@ async def summarize(request: Request):
32
  notes = data.get("notes")
33
  return graph.invoke({"urls": urls, "codes": codes, "notes": notes})
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  @app.post("/chat")
36
  async def chat(request: Request):
37
  data = await request.json()
 
4
  from fastapi import Request
5
  from fastapi.middleware.cors import CORSMiddleware
6
 
7
+ from langchain_core.documents import Document
8
+
9
+ from utils.create_vectordb import create_chroma_db_and_document,query_chroma_db
10
+
11
+
12
 
13
 
14
 
 
37
  notes = data.get("notes")
38
  return graph.invoke({"urls": urls, "codes": codes, "notes": notes})
39
 
40
+
41
+ @app.post("/save_summary")
42
+ async def save_summary(request: Request):
43
+ data = await request.json()
44
+ summary = data.get("summary", "")
45
+ post_id = data.get("post_id", None)
46
+ title = data.get("title", "")
47
+ category = data.get("category", "")
48
+ tags = data.get("tags", [])
49
+ references = data.get("references", [])
50
+
51
+ page_content = f"""
52
+ Title: {title}
53
+ Category: {category}
54
+ Tags: {', '.join(tags)}
55
+ Summary: {summary}
56
+ """
57
+
58
+ document = Document(
59
+ page_content=page_content,
60
+ id = str(post_id)
61
+
62
+ )
63
+
64
+ is_added = create_chroma_db_and_document(document)
65
+
66
+ if not is_added:
67
+ return {"error": "Failed to save summary to the database." , "status": "error"}
68
+
69
+ return {"message": "Summary saved successfully." , "status": "success"}
70
+
71
+ @app.post("/summaries")
72
+ async def get_summaries(request: Request):
73
+
74
+
75
+ data = await request.json()
76
+ print(data)
77
+ query = data.get("query" , "")
78
+
79
+ print(f"Query received: {query}")
80
+ results = query_chroma_db(query=query)
81
+ return results
82
+
83
+
84
+
85
+
86
  @app.post("/chat")
87
  async def chat(request: Request):
88
  data = await request.json()
utils/create_vectordb.py CHANGED
@@ -54,7 +54,7 @@ def split_documents(documents, chunk_size=1000, chunk_overlap=200):
54
 
55
  return splits
56
 
57
- def create_chroma_db(documents, collection_name="corpus_collection", db_dir=DB_DIR):
58
  """Create a Chroma vector database from documents."""
59
  # Initialize the Gemini embedding function
60
  gemini_ef = embedding_functions.GoogleGenerativeAiEmbeddingFunction(
@@ -75,17 +75,25 @@ def create_chroma_db(documents, collection_name="corpus_collection", db_dir=DB_D
75
  embedding_function=gemini_ef
76
  )
77
  print(f"Created new collection: {collection_name}")
 
 
 
78
 
79
- # Add documents to collection
80
- for i, doc in enumerate(documents):
81
  collection.add(
82
- documents=[doc.page_content],
83
- metadatas=[doc.metadata],
84
- ids=[f"doc_{i}"]
85
  )
86
-
87
- print(f"Added {len(documents)} documents to collection {collection_name}")
88
- return collection
 
 
 
 
 
 
 
 
89
 
90
  def query_chroma_db(query: str, collection_name="corpus_collection", n_results=5, db_dir=DB_DIR):
91
  """Query the Chroma vector database."""
@@ -136,6 +144,10 @@ def main():
136
  print(f"Source: {metadata.get('source', 'Unknown')}")
137
 
138
  print("\nVector database creation and testing complete!")
 
 
 
 
139
 
140
  if __name__ == "__main__":
141
  main()
 
54
 
55
  return splits
56
 
57
+ def create_chroma_db_and_document(document, collection_name="corpus_collection", db_dir=DB_DIR):
58
  """Create a Chroma vector database from documents."""
59
  # Initialize the Gemini embedding function
60
  gemini_ef = embedding_functions.GoogleGenerativeAiEmbeddingFunction(
 
75
  embedding_function=gemini_ef
76
  )
77
  print(f"Created new collection: {collection_name}")
78
+
79
+
80
+ try:
81
 
 
 
82
  collection.add(
83
+ documents = [document.page_content],
84
+ ids = [document.id]
 
85
  )
86
+
87
+ print("Document added to collection successfully.")
88
+ return True
89
+
90
+ except Exception as e:
91
+ print(f"Error adding document to collection: {e}")
92
+
93
+ return False
94
+
95
+
96
+
97
 
98
  def query_chroma_db(query: str, collection_name="corpus_collection", n_results=5, db_dir=DB_DIR):
99
  """Query the Chroma vector database."""
 
144
  print(f"Source: {metadata.get('source', 'Unknown')}")
145
 
146
  print("\nVector database creation and testing complete!")
147
+
148
+
149
+
150
+
151
 
152
  if __name__ == "__main__":
153
  main()