Upload app.py with huggingface_hub
app.py
CHANGED
@@ -1,210 +1,320 @@
-# …
-import os
-import …
-
-import …
-
-
-from …
-
-
-from …
-from …
-from …
-from langchain.retrievers.document_compressors import LLMChainExtractor, CrossEncoderReranker  # Document compressors
-from langchain.retrievers import ContextualCompressionRetriever  # Contextual compression retrievers
-
-# LangChain community & experimental imports
-from langchain_community.vectorstores import Chroma  # Implementations of vector stores like Chroma
-from langchain_community.document_loaders import PyPDFDirectoryLoader, PyPDFLoader  # Document loaders for PDFs
-from langchain_community.cross_encoders import HuggingFaceCrossEncoder  # Cross-encoders from HuggingFace
-from langchain_experimental.text_splitter import SemanticChunker  # Experimental text splitting methods
-from langchain.text_splitter import (
-    CharacterTextSplitter,  # Splitting text by characters
-    RecursiveCharacterTextSplitter  # Recursive splitting of text by characters
-)
-from langchain_core.tools import tool
-from langchain.agents import create_tool_calling_agent, AgentExecutor
 from langchain_core.prompts import ChatPromptTemplate
 
-# LangChain OpenAI
-# Commented out - not used
-#from langchain_openai import AzureOpenAIEmbeddings, AzureChatOpenAI  # OpenAI embeddings and models
-#from langchain.embeddings.openai import OpenAIEmbeddings  # OpenAI embeddings for text vectors
-# Added and used below
 from langchain_openai import OpenAIEmbeddings, ChatOpenAI
-#from langchain_openai import ChatOpenAI
 
-# …
-from …
-from llama_index.core import Settings, SimpleDirectoryReader  # Core functionalities of the LlamaIndex
 
-# …
-from …
 
-# …
-…
 
-# …
-…
 
-# …
-import numpy as np  # Numpy for numerical operations
-from groq import Groq
-from mem0 import MemoryClient
-import streamlit as st
-from datetime import datetime
 
-
-# …
-…
-)
-…
 
 class AgentState(TypedDict):
-    query: str
-    expanded_query: str
-    context: …
-    response: str
-    precision_score: float
-    groundedness_score: float
-    groundedness_loop_count: int
-    precision_loop_count: int
     feedback: str
     query_feedback: str
-    groundedness_check: bool
     loop_max_iter: int
 
-def expand_query(state):
-    """
-    Expands the user query to improve retrieval of nutrition disorder-related information.
-
-    Args:
-        state (Dict): The current state of the workflow, containing the user query.
 
-    Returns:
-        …
-    """
-    print("---------Expanding Query---------")
     system_message = '''
     You are a domain expert assisting in answering questions related to research papers.
     Convert the user query into something that a nutritionist would understand. Use domain related words.
-    Return 3 related search queries based on the user's request
-    Return only 3 versions of the question as a list.
     Perform query expansion on the question received. If there are multiple common ways of phrasing a user question \
     or common synonyms for key words in the question, make sure to return multiple versions \
     of the query with the different phrasings.
-    If the query has multiple parts, split them into separate simpler queries. This is the only case where you can generate more than 3 queries.
     If there are acronyms or words you are not familiar with, do not try to rephrase them.
     Generate only a list of questions. Do not mention anything before or after the list.
-    Use the query …
     '''
-
     expand_prompt = ChatPromptTemplate.from_messages([
         ("system", system_message),
         ("user", "Expand this query: {query} using the feedback: {query_feedback}")
-
     ])
-
-    chain = expand_prompt | llm | StrOutputParser()
     expanded_query = chain.invoke({"query": state['query'], "query_feedback": state["query_feedback"]})
-
     state["expanded_query"] = expanded_query
     return state
 
 
-# …
-vector_store = …
-…
-…
-)
-
-# Create a retriever from the vector store
-retriever = vector_store.as_retriever(
-    search_type='similarity',
-    search_kwargs={'k': 3}
-)
-
-def retrieve_context(state):
-    """
-    Retrieves context from the vector store using the expanded or original query.
 
-    Args:
-        …
-    Returns:
-        Dict: The updated state with the retrieved context.
-    """
-    print("---------retrieve_context---------")
-    # Add original query to the state to improve result.
-    query = f"{state['query']}; {state['expanded_query']}"  # Complete the code to define the key for the expanded query
-    #print("Query used for retrieval:", query)  # Debugging: Print the query
-
-    # Retrieve documents from the vector store
     docs = retriever.invoke(query)
-
 
-    # …
-    context = [
-        {
-            "content": doc.page_content,  # The actual content of the document
-            "metadata": doc.metadata  # The metadata (e.g., source, page number, etc.)
-        }
         for doc in docs
     ]
-    state['context'] = context
-
-    #print(f"Groundedness loop count: {state['groundedness_loop_count']}")
     return state
 
-
-def craft_response(state: Dict) -> Dict:
-    """
-    Generates a response using the retrieved context, focusing on nutrition disorders.
-
-    Args:
-        state (Dict): The current state of the workflow, containing the query and retrieved context.
-
-    Returns:
-        Dict: The updated state with the generated response.
-    """
-    print("---------craft_response---------")
     system_message = '''
     Generates a response to a user query and context provided.
 
@@ -222,41 +332,28 @@ def craft_response(state: Dict) -> Dict:
 
 
     The answer you provide must come from the user queries with context provided.
     If feedback is provided, use it to craft the response.
-    If information provided is not enough to answer the query …
     '''
-
     response_prompt = ChatPromptTemplate.from_messages([
         ("system", system_message),
-        ("user", "Query: …
     ])
-
-    chain = response_prompt | llm
     response = chain.invoke({
         "query": state['query'],
         "context": "\n".join([doc["content"] for doc in state['context']]),
-        "feedback": state["feedback"]
     })
-    state['response'] = response
-
-
     return state
 
-
-
-def score_groundedness(state: Dict) -> Dict:
-    """
-    Checks whether the response is grounded in the retrieved context.
-
-    Args:
-        state (Dict): The current state of the workflow, containing the response and context.
-
-    Returns:
-        Dict: The updated state with the groundedness score.
-    """
-    print("---------check_groundedness---------")
     system_message = '''
     You are tasked with rating AI generated answers to questions posed by users.
     Please act as an impartial judge and evaluate the quality of the provided answer which attempts to answer the provided question based on a provided context.
 
     In the input, the context is {context}, while the AI generated response is {response}.
 
     Evaluation criteria:
@@ -272,37 +369,24 @@ def score_groundedness(state: Dict) -> Dict:
 
 
     Do not show any instructions for deriving your answer.
 
     Output your result as a float number between 0 and 1 using the evaluation criteria.
-    The better the criteria, the …
     '''
-
     groundedness_prompt = ChatPromptTemplate.from_messages([
         ("system", system_message),
         ("user", "Context: {context}\nResponse: {response}\n\nGroundedness score:")
     ])
-
-    chain = groundedness_prompt | llm | StrOutputParser()
     groundedness_score = float(chain.invoke({
         "context": "\n".join([doc["content"] for doc in state['context']]),
-        "response": state['response']
     }))
-    print("groundedness_score: ", groundedness_score)
-    state['groundedness_loop_count'] += 1
-    print("#########Groundedness Incremented###########")
     state['groundedness_score'] = groundedness_score
-
     return state
 
-def check_precision(state: Dict) -> Dict:
-    """
-    Checks whether the response precisely addresses the user's query.
-
-    Args:
-        state (Dict): The current state of the workflow, containing the query and response.
-
-    Returns:
-        Dict: The updated state with the precision score.
-    """
-    print("---------check_precision---------")
     system_message = '''
     Given question, answer and context verify if the context was useful in arriving at the given answer.
     Give verdict as "1" if useful and "0" if not useful.
@@ -311,277 +395,191 @@ def check_precision(state: Dict) -> Dict:
 
 
     0 or near 0 if it is least useful, 0.5 or near 0.5 if retry is warranted, and 1 or close to 1 is most useful.
     Do not show any instructions for deriving your answer.
     '''
-
     precision_prompt = ChatPromptTemplate.from_messages([
         ("system", system_message),
         ("user", "Query: {query}\nResponse: {response}\n\nPrecision score:")
     ])
-
-    chain = precision_prompt | llm | StrOutputParser()  # Complete the code to define the chain of processing
     precision_score = float(chain.invoke({
         "query": state['query'],
-        "response": state['response']
     }))
     state['precision_score'] = precision_score
-    print("precision_score:", precision_score)
     state['precision_loop_count'] += 1
-
     return state
 
-def refine_response(state: Dict) -> Dict:
-    """
-    Suggests improvements for the generated response.
-
-    Args:
-        state (Dict): The current state of the workflow, containing the query and response.
-
-    Returns:
-        Dict: The updated state with response refinement suggestions.
-    """
-    print("---------refine_response---------")
-
     system_message = '''
-    Since the last response …
     use the feedback in terms of the query, context and the last response
     to identify potential gaps, ambiguities, or missing details, and
     to suggest improvements to enhance accuracy and completeness of the response.
     '''
-
     refine_response_prompt = ChatPromptTemplate.from_messages([
         ("system", system_message),
         ("user", "Query: {query}\nResponse: {response}\n\n"
                  "What improvements can be made to enhance accuracy and completeness?")
     ])
-
-    chain = refine_response_prompt | llm | StrOutputParser()
-
-    # Store response suggestions in a structured format
     feedback = f"Previous Response: {state['response']}\nSuggestions: {chain.invoke({'query': state['query'], 'response': state['response']})}"
-    print("feedback: ", feedback)
-    print(f"State: {state}")
     state['feedback'] = feedback
     return state
 
-
-def refine_query(state: Dict) -> Dict:
-    """
-    Suggests improvements for the expanded query.
-
-    Args:
-        state (Dict): The current state of the workflow, containing the query and expanded query.
-
-    Returns:
-        Dict: The updated state with query refinement suggestions.
-    """
-    print("---------refine_query---------")
     system_message = '''
-    Since the last response …
     use the feedback in terms of the query, context and re-generate extended queries
     to identify specific keywords, scope refinements, or missing details, and
     to provide structured suggestions for improvement to enhance accuracy and completeness of the response.
     '''
-
     refine_query_prompt = ChatPromptTemplate.from_messages([
         ("system", system_message),
         ("user", "Original Query: {query}\nExpanded Query: {expanded_query}\n\n"
                  "What improvements can be made for a better search?")
     ])
-
-    chain = refine_query_prompt | llm | StrOutputParser()
-
-    # Store refinement suggestions without modifying the original expanded query
     query_feedback = f"Previous Expanded Query: {state['expanded_query']}\nSuggestions: {chain.invoke({'query': state['query'], 'expanded_query': state['expanded_query']})}"
-    print("query_feedback: ", query_feedback)
-    print(f"Groundedness loop count: {state['groundedness_loop_count']}")
     state['query_feedback'] = query_feedback
     return state
 
-
-
-def should_continue_groundedness(state: Dict) -> str:
-    …
-    if state['groundedness_score'] >= 0.8:  # Complete the code to define the threshold for groundedness
-        print("Moving to precision")
-        return "check_precision"
-    else:
-        if state["groundedness_loop_count"] > state['loop_max_iter']:
-            return "max_iterations_reached"
-        else:
-            print(f"---------Groundedness Score Threshold Not met. Refining Response-----------")
-            return "refine_response"
-
-
-def should_continue_precision(state: Dict) -> str:
-    """Decides if precision is sufficient or needs improvement."""
-    print("---------should_continue_precision---------")
-    print("precision loop count: ", state["precision_loop_count"])
-    if state['precision_score'] > 0.8:  # Threshold for precision
-        return "pass"  # Complete the workflow
     else:
-        if state["precision_loop_count"] > state['loop_max_iter']:
             return "max_iterations_reached"
         else:
-            …
-            return "refine_query"
-
-
-def max_iterations_reached(state: Dict) -> Dict:
-    """Handles the case when the maximum number of iterations is reached."""
-    print("---------max_iterations_reached---------")
     response = "I'm unable to refine the response further. Please provide more context or clarify your question."
     state['response'] = response
     return state
 
 
-def create_workflow() -> StateGraph:
-    """Creates the updated workflow for the AI nutrition agent."""
-    workflow = StateGraph(AgentState)  # Complete the code to define the initial state of the agent
 
-    workflow.add_node("expand_query", expand_query)
-    workflow.add_node("retrieve_context", retrieve_context)
-    workflow.add_node("craft_response", craft_response)
-    workflow.add_node("score_groundedness", score_groundedness)
-    workflow.add_node("refine_response", refine_response)
-    workflow.add_node("check_precision", check_precision)
-    workflow.add_node("refine_query", refine_query)
-    workflow.add_node("max_iterations_reached", max_iterations_reached)  # Step 8: Handle max iterations. Complete with the function to handle max iterations
 
-    # Main flow edges
     workflow.add_edge(START, "expand_query")
     workflow.add_edge("expand_query", "retrieve_context")
     workflow.add_edge("retrieve_context", "craft_response")
     workflow.add_edge("craft_response", "score_groundedness")
 
-    # Conditional edges based on groundedness check
     workflow.add_conditional_edges(
         "score_groundedness",
-        should_continue_groundedness,
         {
-            "check_precision": "check_precision",
-            "refine_response": "refine_response",
-            "max_iterations_reached": "max_iterations_reached"
         }
     )
 
-    workflow.add_edge("refine_response", "craft_response")  # Refined responses are reprocessed.
-
-    # Conditional edges based on precision check
     workflow.add_conditional_edges(
         "check_precision",
-        should_continue_precision,
         {
-            "pass": END,
-            "refine_query": "refine_query",
-            "max_iterations_reached": "max_iterations_reached"
         }
     )
-
-    workflow.add_edge("refine_query", "expand_query")  # Refined queries go through expansion again.
-
     workflow.add_edge("max_iterations_reached", END)
 
     return workflow
 
-
-
-#=========================== Defining the agentic rag tool ====================#
-
-WORKFLOW_APP = create_workflow().compile()
-
-@tool
-def agentic_rag(query: str):
-    """
-    Runs the RAG-based agent with conversation history for context-aware responses.
-
-    Args:
-        query (str): The current user query.
-
-    Returns:
-        Dict[str, Any]: The updated state with the generated response and conversation history.
-    """
-    # Initialize state with necessary parameters
-    inputs = {
-        "query": query,  # Current user query
-        "expanded_query": "",  # Complete the code to define the expanded version of the query
-        "context": [],  # Retrieved documents (initially empty)
-        "response": "",  # Complete the code to define the AI-generated response
-        "precision_score": 0.0,  # Complete the code to define the precision score of the response
-        "groundedness_score": 0.0,  # Complete the code to define the groundedness score of the response
-        "groundedness_loop_count": 0,  # Complete the code to define the counter for groundedness loops
-        "precision_loop_count": 0,  # Complete the code to define the counter for precision loops
-        "feedback": "",  # Complete the code to define the feedback
-        "query_feedback": "",  # Complete the code to define the query feedback
-        "loop_max_iter": 3  # Complete the code to define the maximum number of iterations for loops
-    }
-
-    output = WORKFLOW_APP.invoke(inputs)
-
-    return output
-
-#================================ Guardrails ===========================#
-llama_guard_client = Groq(api_key=groq_api_key)  # Groq(api_key=llama_api_key)
-# Function to filter user input with Llama Guard
-#def filter_input_with_llama_guard(user_input, model="llama-guard-3-8b"):
-def filter_input_with_llama_guard(user_input, model="meta-llama/llama-guard-4-12b"):
-    """
-    Filters user input using Llama Guard to ensure it is safe.
-
-    Parameters:
-    - user_input: The input provided by the user.
-    - model: The Llama Guard model to be used for filtering (default is "llama-guard-3-8b").
-
-    Returns:
-    - The filtered and safe input.
-    """
-    try:
-        # Create a request to Llama Guard to filter the user input
-        response = llama_guard_client.chat.completions.create(
-            messages=[{"role": "user", "content": user_input}],
-            model=model,
-        )
-        # Return the filtered input
-        return response.choices[0].message.content.strip()
-    except Exception as e:
-        print(f"Error with Llama Guard: {e}")
-        return None
-
-
-#============================= Adding Memory to the agent using mem0 ===============================#
 
class NutritionBot:
 
-    def __init__(self):
         """
         Initialize the NutritionBot class, setting up memory, the LLM client, tools, and the agent executor.
         """
 
         # Initialize a memory client to store and retrieve customer interactions
-        self.memory = MemoryClient(api_key=…)
 
         self.client = ChatOpenAI(
-            #model_name="gpt-4o",  # Specify the model to use (e.g., GPT-4 optimized version)
             model="gpt-4o-mini",
-            …
-            openai_api_base=endpoint,  # Fill in the endpoint
-            temperature=0  # Controls randomness in responses; 0 ensures deterministic results
         )
 
-        # Define tools available to the chatbot
-        tools = […]
 
-        # Define the system prompt
         system_prompt = """You are a caring and knowledgeable Medical Support Agent, specializing in nutrition disorder-related guidance. Your goal is to provide accurate, empathetic, and tailored nutritional recommendations while ensuring a seamless customer experience.
         Guidelines for Interaction:
         Maintain a polite, professional, and reassuring tone.
@@ -590,16 +588,16 @@ class NutritionBot:
 
         Engage with the customer by asking about their food preferences, dietary restrictions, and lifestyle before offering recommendations.
         Ensure consistent and accurate information across conversations.
         If any detail is unclear or missing, proactively ask for clarification.
-        Always use the …
         Keep track of ongoing issues and follow-ups to ensure continuity in support.
         Your primary goal is to help customers make informed nutrition decisions that align with their health conditions and personal preferences.
         """
 
         # Build the prompt template for the agent
         prompt = ChatPromptTemplate.from_messages([
-            ("system", system_prompt),
-            ("human", "{input}"),
-            ("placeholder", "{agent_scratchpad}")
         ])
 
         # Create an agent capable of interacting with tools and executing tasks
@@ -609,80 +607,35 @@ class NutritionBot:
 
         self.agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
 
     def store_customer_interaction(self, user_id: str, message: str, response: str, metadata: Dict = None):
-        """
-        Store customer interaction in memory for future reference.
-
-        Args:
-            user_id (str): Unique identifier for the customer.
-            message (str): Customer's query or message.
-            response (str): Chatbot's response.
-            metadata (Dict, optional): Additional metadata for the interaction.
-        """
         if metadata is None:
             metadata = {}
-
-        # Add a timestamp to the metadata for tracking purposes
         metadata["timestamp"] = datetime.now().isoformat()
-
-        # Format the conversation for storage
         conversation = [
             {"role": "user", "content": message},
             {"role": "assistant", "content": response}
         ]
-
-        # Store the interaction in the memory client
-        self.memory.add(
-            conversation,
-            user_id=user_id,
-            output_format="v1.1",
-            metadata=metadata
-        )
-
 
     def get_relevant_history(self, user_id: str, query: str) -> List[Dict]:
-        """
-        Retrieve past interactions relevant to the current query.
-
-        Args:
-            user_id (str): Unique identifier for the customer.
-            query (str): The customer's current query.
-
-        Returns:
-            List[Dict]: A list of relevant past interactions.
-        """
-        return self.memory.search(
-            query=query,  # Search for interactions related to the query
-            user_id=user_id,  # Restrict search to the specific user
-            limit=5  # Complete the code to define the limit for retrieved interactions
-        )
-
 
     def handle_customer_query(self, user_id: str, query: str) -> str:
-        """
-        Process a customer's query and provide a response, taking into account past interactions.
-
-        Args:
-            user_id (str): Unique identifier for the customer.
-            query (str): Customer's query.
-
-        Returns:
-            str: Chatbot's response.
-        """
-
-        # Retrieve relevant past interactions for context
         relevant_history = self.get_relevant_history(user_id, query)
 
-        # Build a context string from the relevant history
         context = "Previous relevant interactions:\n"
-        for …
-            …
             context += "---\n"
 
-        # Print context for debugging purposes
-        print("Context: ", context)
-
-        # Prepare a prompt combining past context and the current query
         prompt = f"""
         Context:
         {context}
@@ -691,62 +644,108 @@ class NutritionBot:
 
 
         Provide a helpful response that takes into account any relevant past interactions.
         """
 
-        # Generate a response using the agent
-        response = self.agent_executor.invoke({"input": prompt})
-        st.write("Context: ", response)
-
-        # Store the current interaction for future reference
-        self.store_customer_interaction(
-            user_id=user_id,
-            message=query,
-            response=response["output"],
-            metadata={"type": "support_query"}
-        )
 
-        …
 
-def nutrition_disorder_streamlit():
-    """
-    …
-    """
-    st.…
-    st.…
 
-    # Initialize session state
     if 'chat_history' not in st.session_state:
         st.session_state.chat_history = []
     if 'user_id' not in st.session_state:
         st.session_state.user_id = None
-
-
     if st.session_state.user_id is None:
         with st.form("login_form", clear_on_submit=True):
-            …
             submit_button = st.form_submit_button("Login")
-            if submit_button and …:
-                st.session_state.user_id = …
                 st.session_state.chat_history.append({
                     "role": "assistant",
-                    "content": f"Welcome, {user_id}! How can I help you with nutrition disorders today?"
                 })
-
-
     # Display chat history
     for message in st.session_state.chat_history:
         with st.chat_message(message["role"]):
             st.write(message["content"])
 
-
-
     if user_query:
         if user_query.lower() == "exit":
             st.session_state.chat_history.append({"role": "user", "content": "exit"})
@@ -756,7 +755,9 @@ def nutrition_disorder_streamlit():
 
             st.session_state.chat_history.append({"role": "assistant", "content": goodbye_msg})
             with st.chat_message("assistant"):
                 st.write(goodbye_msg)
-            st.session_state.user_id = None
             st.rerun()
             return
 
@@ -765,44 +766,35 @@ def nutrition_disorder_streamlit():
 
             st.write(user_query)
 
         # Filter input using Llama Guard
-        …
-        …
-        …
 
-        if 'chatbot' not in st.session_state:
-            st.session_state.chatbot = NutritionBot()  # Blank #6: Fill in with the chatbot class initialization (e.g., NutritionBot)
-
-        st.write("chatbot is calling handle_customer_query...")
-        st.write("user_id: ", st.session_state.user_id)
-        st.write("user_query: ", user_query)
-
-        response = st.session_state.chatbot.handle_customer_query(st.session_state.user_id, user_query)
-        st.write("response is returned.")
-        # Blank #7: Fill in with the method to handle queries (e.g., handle_customer_query)
-
-        st.write(response)
-        st.session_state.chat_history.append({"role": "assistant", "content": response})
-    except Exception as e:
-        st.write(f"Sorry, I encountered an error while processing your query. Please try again. Error: {e}", str(e))
-        st.session_state.chat_history.append({"role": "assistant", "content": f"Sorry, I encountered an error while processing your query. Please try again. Error: {e}"})
     else:
-        inappropriate_msg = "I apologize, but I cannot process that input as it may be inappropriate. Please try again."
-        st.write(inappropriate_msg)
         st.session_state.chat_history.append({"role": "assistant", "content": inappropriate_msg})
 
 if __name__ == "__main__":
-    …
 
+# --- 0. Library Imports ---
+import os
+import json
+import shutil
+import time
+import numpy as np
+from datetime import datetime
+from typing import Dict, List, Any, TypedDict, Tuple
+
+# LangChain Core & Community
+from langchain_core.documents import Document
+from langchain_core.runnables import RunnablePassthrough
+from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.tools import tool
+from langchain_community.vectorstores import Chroma
+from langchain_community.document_loaders import PyPDFDirectoryLoader
+from langchain_community.cross_encoders import HuggingFaceCrossEncoder
+from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter
 
+# LangChain OpenAI
 from langchain_openai import OpenAIEmbeddings, ChatOpenAI
 
+# LangChain Experimental
+from langchain_experimental.text_splitter import SemanticChunker
 
+# LangChain Agents & Graph
+from langchain.agents import create_tool_calling_agent, AgentExecutor
+from langgraph.graph import StateGraph, END, START
 
+# External Libraries
+import chromadb
+from llama_parse import LlamaParse  # For PDF parsing
+from groq import Groq  # For Llama Guard
+from mem0 import MemoryClient  # For memory
+import streamlit as st  # For Web UI
 
+# Fix for numpy deprecation warning if necessary
+np.float_ = np.float64
 
+#import nest_asyncio
 
+# --- 1. Configuration and Setup Utilities ---
+
+def load_config_from_env() -> Dict:
+    """Loads API keys and endpoints from environment variables."""
+    # Prioritize environment variables for deployment
+    config = {
+        "AZURE_OPENAI_API_KEY": os.getenv("AZURE_OPENAI_API_KEY"),
+        "AZURE_OPENAI_API_BASE": os.getenv("AZURE_OPENAI_API_BASE"),
+        "LLAMA_KEY": os.getenv("LLAMA_KEY"),  # LlamaParse API key
+        "MEM0_API_KEY": os.getenv("MEM0_API_KEY"),
+        "GROQ_API_KEY": os.getenv("GROQ_API_KEY"),  # For Llama Guard via Groq
+    }
+    # Basic validation
+    for key, value in config.items():
+        if not value:
+            st.warning(f"Warning: Environment variable '{key}' is not set.")
+    return config
+
+def initialize_llms_and_embeddings(config: Dict) -> Tuple[OpenAIEmbeddings, ChatOpenAI, chromadb.utils.embedding_functions.OpenAIEmbeddingFunction, Groq]:
+    """Initializes LLM, embedding models, and API clients."""
+    api_key = config["AZURE_OPENAI_API_KEY"]
+    endpoint = config["AZURE_OPENAI_API_BASE"]
+    groq_api_key = config["GROQ_API_KEY"]
+
+    # Initialize ChromaDB embedding function (used for collection creation)
+    embedding_function = chromadb.utils.embedding_functions.OpenAIEmbeddingFunction(
+        api_base=endpoint,
+        api_key=api_key,
+        model_name='text-embedding-ada-002'  # Specify model explicitly
+    )
+
+    # Initialize LangChain OpenAI embeddings (used for `SemanticChunker` and the `Chroma` vector store)
+    embedding_model = OpenAIEmbeddings(
+        openai_api_base=endpoint,
+        openai_api_key=api_key,
+        model='text-embedding-ada-002'  # Specify model explicitly
+    )
+
+    # Initialize LangChain Chat OpenAI model
+    llm = ChatOpenAI(
+        openai_api_base=endpoint,
+        openai_api_key=api_key,
+        model="gpt-4o-mini",
+        streaming=False,
+        temperature=0.0  # Ensure deterministic behavior for evaluations
+    )
+
+    # Initialize Groq client for Llama Guard
+    llama_guard_client = Groq(api_key=groq_api_key)
+
+    return embedding_model, llm, embedding_function, llama_guard_client
+
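# Illustrative only (a hypothetical helper, not invoked anywhere in the app): one plausible
# way to wire the two utilities above at startup, caching the clients in Streamlit session
# state the same way the rest of this file does.
def _bootstrap_clients_sketch():
    if 'llm' not in st.session_state:
        config = load_config_from_env()
        embedding_model, llm, _embedding_fn, guard_client = initialize_llms_and_embeddings(config)
        st.session_state.llm = llm  # graph node functions read the LLM from here
        st.session_state.embedding_model = embedding_model
        st.session_state.llama_guard_client = guard_client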
+def filter_input_with_llama_guard(user_input: str, llama_guard_client: Groq, model: str = "meta-llama/llama-guard-4-12b") -> str:
+    """
+    Filters user input using Llama Guard to ensure it is safe.
+    Returns "safe", "UNSAFE" (with category), or None on error.
+    """
+    try:
+        response = llama_guard_client.chat.completions.create(
+            messages=[{"role": "user", "content": user_input}],
+            model=model,
+        )
+        return response.choices[0].message.content.strip()
+    except Exception as e:
+        st.error(f"Error with Llama Guard: {e}")
+        return None
+
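# Illustrative only (a hypothetical helper): acting on the Llama Guard verdict. The model
# replies with a string whose first line is "safe" or "unsafe" (a category may follow), so
# a caller can gate the pipeline on an exact "safe" match.
def _is_safe_input_sketch(user_input: str, guard_client: Groq) -> bool:
    verdict = filter_input_with_llama_guard(user_input, guard_client)
    return verdict is not None and verdict.splitlines()[0].strip().lower() == "safe"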
+# --- 2. Data Preparation (Parsing & Chunking) ---
+
+# Note: In a deployed app, PDF parsing and vector DB creation would typically be
+# a separate, offline process, and the pre-built vector DB would be loaded.
+# For this template, we'll assume the nutritional_db is pre-existing and loaded.
+
+def load_and_split_documents(folder_path: str, embedding_model: OpenAIEmbeddings) -> List[Document]:
+    """Loads PDF documents from a folder and semantically chunks them."""
+    semantic_text_splitter = SemanticChunker(
+        embedding_model,
+        breakpoint_threshold_type='percentile',
+        breakpoint_threshold_amount=80
+    )
+    pdf_loader = PyPDFDirectoryLoader(folder_path)
+    chunks = pdf_loader.load_and_split(semantic_text_splitter)
+    return chunks
+
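# Illustrative only (hypothetical helper and folder path): a one-off chunking pass. With
# breakpoint_threshold_type='percentile' and amount=80, the splitter breaks wherever the
# embedding distance between adjacent sentences exceeds the 80th percentile of distances.
def _chunking_example_sketch(embedding_model: OpenAIEmbeddings) -> List[Document]:
    chunks = load_and_split_documents("./data/pdfs", embedding_model)
    print(f"Produced {len(chunks)} semantic chunks")
    return chunks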
+def parse_pdf_tables_with_llamaparse(pdf_path: str, llamaparse_api_key: str) -> Tuple[Dict, Dict]:
+    """Parses a PDF file using LlamaParse and extracts page texts and tables."""
+    # This requires `nest_asyncio.apply()` to be called once at the start of the app if running async.
+    # In a Streamlit app, ensure it's at the very top level if needed.
+    # import nest_asyncio; nest_asyncio.apply()  # Uncomment if needed for async parser
+
+    parser = LlamaParse(
+        result_type="markdown",
+        skip_diagonal_text=True,
+        fast_mode=False,
+        num_workers=1,  # Adjust as per environment capabilities
+        check_interval=10,
+        api_key=llamaparse_api_key
+    )
+    json_objs = parser.get_json_result(pdf_path)
+    page_texts, tables = {}, {}
+    for obj in json_objs:
+        json_list = obj['pages']
+        name = obj["file_path"].split("/")[-1]
+        page_texts[name] = {}
+        tables[name] = {}
+        for json_item in json_list:
+            for component in json_item['items']:
+                if component['type'] == 'table':
+                    tables[name][json_item['page']] = component['rows']
+    return page_texts, tables
+
+def generate_hypothetical_questions(llm: ChatOpenAI, docs: List[Document], is_table: bool = False) -> List[Document]:
+    """Generates hypothetical questions for text chunks or tables."""
+    prompt_template = """
+    Generate a list of exactly three (3) hypothetical questions that the below nutritional disorder {content_type} could be used to answer:
+    {content}
+    Ensure that the questions are specific in the context of nutrition, dietary deficiencies, metabolic disorders, vitamin and mineral imbalances, obesity, and related health conditions.
+    Generate only a list of questions.
+    Do not mention anything before or after the list.
+    If the content cannot answer any questions, return an empty list.
+    """
+    hyp_docs = []
+    content_type = "table" if is_table else "document"
+
+    for i, doc in enumerate(docs):
+        content_to_use = str(doc) if is_table else doc.page_content  # Tables are often raw data, stringify
+        try:
+            response = llm.invoke(prompt_template.format(content_type=content_type, content=content_to_use))
+            questions = response.content
+        except Exception as e:
+            st.error(f"Error generating hypothetical questions for {'table' if is_table else 'text'} chunk ID {doc.id}: {e}")
+            questions = "[]"  # Return empty list string on error
+
+        questions_metadata = {
+            'original_content': content_to_use,
+            'source': doc.metadata.get('source', 'unknown'),
+            'page': doc.metadata.get('page', -1),
+            'type': content_type
+        }
+        hyp_docs.append(
+            Document(
+                id=f"{'table_' if is_table else 'text_chunk_'}{doc.id or i}",  # Ensure unique IDs
+                page_content=questions,
+                metadata=questions_metadata
+            )
+        )
+        time.sleep(0.1)  # Small delay to avoid rate limits
+    return hyp_docs
+
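# Illustrative only (a hypothetical helper): querying an index built from the hypothetical
# questions. Matches are scored against the generated questions (page_content), while the
# text to answer from lives in metadata['original_content'], so map back before answering.
def _hypothetical_question_lookup_sketch(vector_store: Chroma, query: str) -> List[str]:
    hits = vector_store.similarity_search(query, k=3)
    return [hit.metadata.get('original_content', '') for hit in hits]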
+
+def create_and_persist_vector_db(
+    documents: List[Document],
+    embedding_model: OpenAIEmbeddings,
+    collection_name: str,
+    persist_directory: str
+):
+    """Creates/updates a Chroma vector store and persists it."""
+    # Ensure IDs are strings as required by ChromaDB
+    doc_ids = [str(d.id) for d in documents] if documents else []
+    if not doc_ids:
+        st.warning(f"No documents to add to collection '{collection_name}'.")
+        return
+
+    # Initialize or connect to Chroma vector store
+    vector_store = Chroma.from_documents(
+        documents,
+        embedding_model,
+        collection_name=collection_name,
+        persist_directory=persist_directory
+    )
+    st.info(f"Initialized ChromaDB with collection '{collection_name}' at '{persist_directory}'. "
+            f"Documents count: {len(documents)}")
+    return vector_store
+
+def load_vector_db(
+    embedding_model: OpenAIEmbeddings,
+    collection_name: str,
+    persist_directory: str
+) -> Chroma:
+    """Loads an existing Chroma vector store."""
+    try:
+        # Check if the directory exists and contains ChromaDB files
+        if not os.path.exists(persist_directory) or not os.listdir(persist_directory):
+            st.error(f"Vector DB directory '{persist_directory}' is empty or does not exist.")
+            print(f"Vector DB directory '{persist_directory}' is empty or does not exist.")
+            st.warning("Please ensure the 'nutritional_db' folder is correctly placed/mounted.")
+            return None
+
+        vector_store = Chroma(
+            collection_name=collection_name,
+            persist_directory=persist_directory,
+            embedding_function=embedding_model
+        )
+        st.success(f"Successfully loaded ChromaDB collection '{collection_name}' from '{persist_directory}'.")
+        # You can add a check for the number of documents loaded for verification
+        # Example: print(vector_store._collection.count())
+        return vector_store
+    except Exception as e:
+        st.error(f"Error loading ChromaDB from '{persist_directory}': {e}")
+        print(f"Error loading ChromaDB from '{persist_directory}': {e}")
+        return None
+
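# Illustrative only (hypothetical helper; the folder and collection names are assumptions):
# the offline build step the note in section 2 refers to. Run once outside the Space to
# produce the persisted directory that load_vector_db() expects at startup.
def _build_nutritional_db_sketch(embedding_model: OpenAIEmbeddings):
    chunks = load_and_split_documents("./data/pdfs", embedding_model)
    create_and_persist_vector_db(
        documents=chunks,
        embedding_model=embedding_model,
        collection_name="nutritional_db",
        persist_directory="./nutritional_db"
    )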
+# --- 3. Agent Workflow Definition (LangGraph) ---
 
 class AgentState(TypedDict):
+    """Represents the state of the AI agent at different stages of the workflow."""
+    query: str
+    expanded_query: str
+    context: List[Dict[str, Any]]
+    response: str
+    precision_score: float
+    groundedness_score: float
+    groundedness_loop_count: int
+    precision_loop_count: int
     feedback: str
     query_feedback: str
+    groundedness_check: bool  # This field isn't used in should_continue_groundedness and could be removed
     loop_max_iter: int
 
+# Node functions for LangGraph
+def expand_query(state: AgentState) -> AgentState:
+    st.write("---Expanding Query---")
     system_message = '''
     You are a domain expert assisting in answering questions related to research papers.
     Convert the user query into something that a nutritionist would understand. Use domain related words.
+    Return three (3) related search queries based on the user's request separated by newline.
+    Return only three (3) versions of the question as a list.
     Perform query expansion on the question received. If there are multiple common ways of phrasing a user question \
     or common synonyms for key words in the question, make sure to return multiple versions \
     of the query with the different phrasings.
+    If the query has multiple parts, split them into separate simpler queries. This is the only case where you can generate more than three (3) queries.
     If there are acronyms or words you are not familiar with, do not try to rephrase them.
     Generate only a list of questions. Do not mention anything before or after the list.
+    Use the query feedback if provided to craft the search queries.
     '''
     expand_prompt = ChatPromptTemplate.from_messages([
         ("system", system_message),
         ("user", "Expand this query: {query} using the feedback: {query_feedback}")
     ])
+    chain = expand_prompt | st.session_state.llm | StrOutputParser()  # Use llm from session state
     expanded_query = chain.invoke({"query": state['query'], "query_feedback": state["query_feedback"]})
+    st.write(f"Expanded query:\n{expanded_query}")
     state["expanded_query"] = expanded_query
     return state
 
+def retrieve_context(state: AgentState) -> AgentState:
+    st.write("---Retrieving Context---")
+    query = f"{state['query']}; {state['expanded_query']}"
+    st.write(f"Query used for retrieval:\n{query}")
 
+    # Ensure vector_store is loaded and available in session_state
+    if 'vector_store' not in st.session_state or st.session_state.vector_store is None:
+        st.error("Vector store not initialized. Cannot retrieve context.")
+        state['context'] = []
+        return state
 
+    retriever = st.session_state.vector_store.as_retriever(
+        search_type='similarity',
+        search_kwargs={'k': 3}
+    )
     docs = retriever.invoke(query)
+    st.write(f"Retrieved documents (first 100 chars each):\n{[doc.page_content[:100] for doc in docs]}")
 
+    context = [
+        {"content": doc.page_content, "metadata": doc.metadata}
         for doc in docs
     ]
+    state['context'] = context
+    #st.write(f"Extracted context with metadata:\n{context}")  # Too verbose for production UI
     return state
 
+def craft_response(state: AgentState) -> AgentState:
+    st.write("---Crafting Response---")
     system_message = '''
     Generates a response to a user query and context provided.
 
     …
 
     The answer you provide must come from the user queries with context provided.
     If feedback is provided, use it to craft the response.
+    If information provided is not enough to answer the query respond with 'I don't know the answer. Not in my records.'
     '''
     response_prompt = ChatPromptTemplate.from_messages([
         ("system", system_message),
+        ("user", "Query:\n{query}\nContext:\n{context}\n\nfeedback:\n{feedback}")
     ])
+    chain = response_prompt | st.session_state.llm  # Use llm from session state
     response = chain.invoke({
         "query": state['query'],
         "context": "\n".join([doc["content"] for doc in state['context']]),
+        "feedback": state["feedback"]
     })
+    state['response'] = response.content  # Access content from AIMessage
+    st.write(f"Intermediate response:\n{state['response']}")
     return state
 
+def score_groundedness(state: AgentState) -> AgentState:
+    st.write("---Checking Groundedness---")
     system_message = '''
     You are tasked with rating AI generated answers to questions posed by users.
     Please act as an impartial judge and evaluate the quality of the provided answer which attempts to answer the provided question based on a provided context.
+
     In the input, the context is {context}, while the AI generated response is {response}.
 
     Evaluation criteria:
 
     …
 
     Do not show any instructions for deriving your answer.
 
     Output your result as a float number between 0 and 1 using the evaluation criteria.
+    The better the criteria, the closer it is to 1 and the worse the criteria, the closer it is to 0.
     '''
     groundedness_prompt = ChatPromptTemplate.from_messages([
         ("system", system_message),
         ("user", "Context: {context}\nResponse: {response}\n\nGroundedness score:")
     ])
+    chain = groundedness_prompt | st.session_state.llm | StrOutputParser()  # Use llm from session state
     groundedness_score = float(chain.invoke({
         "context": "\n".join([doc["content"] for doc in state['context']]),
+        "response": state['response']
     }))
     state['groundedness_score'] = groundedness_score
+    state['groundedness_loop_count'] += 1
+    st.write(f"Groundedness score: {groundedness_score}")
     return state
 
+def check_precision(state: AgentState) -> AgentState:
+    st.write("---Checking Precision---")
     system_message = '''
     Given question, answer and context verify if the context was useful in arriving at the given answer.
     Give verdict as "1" if useful and "0" if not useful.
 
     …
 
     0 or near 0 if it is least useful, 0.5 or near 0.5 if retry is warranted, and 1 or close to 1 is most useful.
     Do not show any instructions for deriving your answer.
     '''
     precision_prompt = ChatPromptTemplate.from_messages([
         ("system", system_message),
         ("user", "Query: {query}\nResponse: {response}\n\nPrecision score:")
     ])
+    chain = precision_prompt | st.session_state.llm | StrOutputParser()  # Use llm from session state
     precision_score = float(chain.invoke({
         "query": state['query'],
+        "response": state['response']
     }))
     state['precision_score'] = precision_score
     state['precision_loop_count'] += 1
+    st.write(f"Precision score: {precision_score}")
     return state
 
+def refine_response(state: AgentState) -> AgentState:
+    st.write("---Refining Response---")
     system_message = '''
+    Since the last response failed the groundedness test, and is deemed not satisfactory,
     use the feedback in terms of the query, context and the last response
     to identify potential gaps, ambiguities, or missing details, and
     to suggest improvements to enhance accuracy and completeness of the response.
     '''
     refine_response_prompt = ChatPromptTemplate.from_messages([
         ("system", system_message),
         ("user", "Query: {query}\nResponse: {response}\n\n"
                  "What improvements can be made to enhance accuracy and completeness?")
     ])
+    chain = refine_response_prompt | st.session_state.llm | StrOutputParser()  # Use llm from session state
     feedback = f"Previous Response: {state['response']}\nSuggestions: {chain.invoke({'query': state['query'], 'response': state['response']})}"
     state['feedback'] = feedback
+    st.write(f"Refinement feedback:\n{feedback}")
     return state
 
+def refine_query(state: AgentState) -> AgentState:
+    st.write("---Refining Query---")
     system_message = '''
+    Since the last response failed the precision test, and is deemed not satisfactory,
     use the feedback in terms of the query, context and re-generate extended queries
     to identify specific keywords, scope refinements, or missing details, and
     to provide structured suggestions for improvement to enhance accuracy and completeness of the response.
     '''
     refine_query_prompt = ChatPromptTemplate.from_messages([
         ("system", system_message),
         ("user", "Original Query: {query}\nExpanded Query: {expanded_query}\n\n"
                  "What improvements can be made for a better search?")
     ])
+    chain = refine_query_prompt | st.session_state.llm | StrOutputParser()  # Use llm from session state
     query_feedback = f"Previous Expanded Query: {state['expanded_query']}\nSuggestions: {chain.invoke({'query': state['query'], 'expanded_query': state['expanded_query']})}"
     state['query_feedback'] = query_feedback
+    st.write(f"Query refinement feedback:\n{query_feedback}")
     return state
 
+def should_continue_groundedness(state: AgentState) -> str:
+    st.write("---Deciding Groundedness Continuation---")
+    st.write(f"Groundedness loop count: {state['groundedness_loop_count']}")
+    if state['groundedness_score'] >= 0.8:
+        st.write("Moving to precision check.")
+        return "check_precision"
     else:
+        if state["groundedness_loop_count"] >= state['loop_max_iter']:
+            st.write("Max iterations reached for groundedness.")
             return "max_iterations_reached"
         else:
+            st.write("Groundedness score not met. Refining response.")
+            return "refine_response"
+
+def should_continue_precision(state: AgentState) -> str:
+    st.write("---Deciding Precision Continuation---")
+    st.write(f"Precision loop count: {state['precision_loop_count']}")
+    if state['precision_score'] > 0.8:
+        st.write("Precision sufficient. Ending workflow.")
+        return "pass"
+    else:
+        if state["precision_loop_count"] >= state['loop_max_iter']:
+            st.write("Max iterations reached for precision.")
+            return "max_iterations_reached"
+        else:
+            st.write("Precision score not met. Refining query.")
+            return "refine_query"
 
+def max_iterations_reached(state: AgentState) -> AgentState:
+    st.write("---Max Iterations Reached---")
     response = "I'm unable to refine the response further. Please provide more context or clarify your question."
     state['response'] = response
     return state
 
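# Note on loop budgets: both deciders compare a per-loop counter against loop_max_iter
# (3, as initialized in agentic_rag_tool below), so each refinement cycle runs at most
# three times before the graph routes to max_iterations_reached and ends.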
+@tool
+def agentic_rag_tool(query: str) -> Dict[str, Any]:
+    """
+    Runs the RAG-based agent workflow to generate context-aware responses.
+    This function is exposed as a tool for the overall chatbot.
+    """
+    # Initialize state for the LangGraph workflow
+    inputs = {
+        "query": query,
+        "expanded_query": "",
+        "context": [],
+        "response": "",
+        "precision_score": 0.0,
+        "groundedness_score": 0.0,
+        "groundedness_loop_count": 0,
+        "precision_loop_count": 0,
+        "feedback": "",
+        "query_feedback": "",
+        "loop_max_iter": 3
+    }
+    # Compile the workflow once and store it in session state if not already done
+    if 'workflow_app' not in st.session_state:
+        st.session_state.workflow_app = create_rag_workflow().compile()
 
+    # Invoke the compiled LangGraph workflow
+    output = st.session_state.workflow_app.invoke(inputs)
+    return output
 
+def create_rag_workflow() -> StateGraph:
+    """Creates the LangGraph workflow for the RAG agent."""
+    workflow = StateGraph(AgentState)
 
+    workflow.add_node("expand_query", expand_query)
+    workflow.add_node("retrieve_context", retrieve_context)
+    workflow.add_node("craft_response", craft_response)
+    workflow.add_node("score_groundedness", score_groundedness)
+    workflow.add_node("refine_response", refine_response)
+    workflow.add_node("check_precision", check_precision)
+    workflow.add_node("refine_query", refine_query)
+    workflow.add_node("max_iterations_reached", max_iterations_reached)
 
     workflow.add_edge(START, "expand_query")
     workflow.add_edge("expand_query", "retrieve_context")
     workflow.add_edge("retrieve_context", "craft_response")
     workflow.add_edge("craft_response", "score_groundedness")
 
     workflow.add_conditional_edges(
         "score_groundedness",
+        should_continue_groundedness,
         {
+            "check_precision": "check_precision",
+            "refine_response": "refine_response",
+            "max_iterations_reached": "max_iterations_reached"
         }
     )
+    workflow.add_edge("refine_response", "craft_response")
 
     workflow.add_conditional_edges(
         "check_precision",
+        should_continue_precision,
         {
+            "pass": END,
+            "refine_query": "refine_query",
+            "max_iterations_reached": "max_iterations_reached"
         }
     )
+    workflow.add_edge("refine_query", "expand_query")
     workflow.add_edge("max_iterations_reached", END)
 
     return workflow
 
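# Illustrative only (a hypothetical helper): exercising the compiled graph directly through
# the tool wrapper, bypassing the agent layer. Assumes st.session_state.llm and
# st.session_state.vector_store are already set, since the node functions read both.
def _run_workflow_directly_sketch(question: str) -> str:
    result = agentic_rag_tool.invoke(question)  # single-input tools accept a raw string
    return result["response"]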
+# --- 4. Main Chatbot Class (Integrating Memory & Agent) ---
 
557 |
class NutritionBot:
|
558 |
+
def __init__(self, config: Dict):
|
559 |
"""
|
560 |
Initialize the NutritionBot class, setting up memory, the LLM client, tools, and the agent executor.
|
561 |
"""
|
562 |
+
mem0_api_key = config["MEM0_API_KEY"]
|
563 |
+
openai_api_key = config["AZURE_OPENAI_API_KEY"]
|
564 |
+
openai_api_base = config["AZURE_OPENAI_API_BASE"]
|
565 |
|
566 |
# Initialize a memory client to store and retrieve customer interactions
|
567 |
+
self.memory = MemoryClient(api_key=mem0_api_key)
|
568 |
|
569 |
+
# Initialize the OpenAI client (LangChain ChatOpenAI)
|
570 |
self.client = ChatOpenAI(
|
|
|
571 |
model="gpt-4o-mini",
|
572 |
+
openai_api_key=openai_api_key,
|
573 |
+
openai_api_base=openai_api_base,
|
574 |
+
temperature=0
|
|
|
|
|
575 |
)
|
576 |
+
# Store LLM in session state for use in graph nodes
|
577 |
+
st.session_state.llm = self.client
|
578 |
|
579 |
+
# Define tools available to the chatbot
|
580 |
+
tools = [agentic_rag_tool]
|
581 |
|
582 |
+
# Define the system prompt for the agent
|
583 |
system_prompt = """You are a caring and knowledgeable Medical Support Agent, specializing in nutrition disorder-related guidance. Your goal is to provide accurate, empathetic, and tailored nutritional recommendations while ensuring a seamless customer experience.
|
584 |
Guidelines for Interaction:
|
585 |
Maintain a polite, professional, and reassuring tone.
|
|
|
588 |
Engage with the customer by asking about their food preferences, dietary restrictions, and lifestyle before offering recommendations.
|
589 |
Ensure consistent and accurate information across conversations.
|
590 |
If any detail is unclear or missing, proactively ask for clarification.
|
591 |
+
Always use the agentic_rag_tool to retrieve up-to-date and evidence-based nutrition insights.
|
592 |
Keep track of ongoing issues and follow-ups to ensure continuity in support.
|
593 |
Your primary goal is to help customers make informed nutrition decisions that align with their health conditions and personal preferences.
|
594 |
"""
|
595 |
|
596 |
# Build the prompt template for the agent
|
597 |
prompt = ChatPromptTemplate.from_messages([
|
598 |
+
("system", system_prompt),
|
599 |
+
("human", "{input}"),
|
600 |
+
("placeholder", "{agent_scratchpad}")
|
601 |
])
|
602 |
|
603 |
# Create an agent capable of interacting with tools and executing tasks
|
|
|
607 |
self.agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
|
608 |
|
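
    # Illustrative construction (placeholder values, not real credentials):
    #   bot = NutritionBot({
    #       "MEM0_API_KEY": "m0-...",
    #       "AZURE_OPENAI_API_KEY": "sk-...",
    #       "AZURE_OPENAI_API_BASE": "https://<resource>.openai.azure.com/",
    #   })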

    def store_customer_interaction(self, user_id: str, message: str, response: str, metadata: Dict = None):
        """Store customer interaction in memory for future reference."""
        if metadata is None:
            metadata = {}
        metadata["timestamp"] = datetime.now().isoformat()

        conversation = [
            {"role": "user", "content": message},
            {"role": "assistant", "content": response}
        ]
        self.memory.add(conversation, user_id=user_id, output_format="v1.1", metadata=metadata)

    def get_relevant_history(self, user_id: str, query: str) -> List[Dict]:
        """Retrieve past interactions relevant to the current query."""
        return self.memory.search(query=query, user_id=user_id, limit=5)
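
    # Result shape assumed by `handle_customer_query` below (illustrative, not the
    # full mem0 schema): each hit is a dict whose 'memory' field is either a list
    # of {"role": ..., "content": ...} messages (v1.1 output format) or a string:
    #   [{"memory": [{"role": "user", "content": "..."}, ...], ...}, ...]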

    def handle_customer_query(self, user_id: str, query: str) -> str:
        """Process a customer's query and provide a response, taking into account past interactions."""
        relevant_history = self.get_relevant_history(user_id, query)

        context = "Previous relevant interactions:\n"
        for memory_item in relevant_history:
            # Mem0's 'memory' field is typically a list of role/content dicts or a string.
            # With the "v1.1" output format used in `store_customer_interaction`,
            # `memory_item['memory']` should be structured.
            if isinstance(memory_item.get('memory'), list):
                for part in memory_item['memory']:
                    context += f"{part['role'].capitalize()}: {part['content']}\n"
            else:  # Fallback for older formats or if it's a simple string
                context += f"History: {memory_item.get('memory', 'N/A')}\n"
            context += "---\n"

        # Compose the agent prompt from the retrieved history and the current query
        # (the query line is elided in the diff; it is reconstructed here since the
        # agent needs the question alongside the context)
        prompt = f"""
        Context:
        {context}

        Current customer query: {query}

        Provide a helpful response that takes into account any relevant past interactions.
        """
        st.write("Prompt sent to agent executor:", prompt)  # Debugging output

        try:
            response_dict = self.agent_executor.invoke({"input": prompt})
            response_content = response_dict.get('output', "I'm sorry, I couldn't generate a response.")
        except Exception as e:
            st.error(f"Error during agent execution: {e}")
            response_content = f"I'm sorry, I encountered an internal error: {e}"

        # Persist the exchange so future queries can draw on it
        self.store_customer_interaction(user_id=user_id, message=query, response=response_content, metadata={"type": "support_query"})
        return response_content

# --- 5. Streamlit UI ---

def nutrition_disorder_streamlit_app():
    """Streamlit-based UI for the Nutrition Disorder Specialist Agent."""
    st.set_page_config(page_title="Nutrition Disorder Specialist", layout="centered")

    st.title("👨‍⚕️ Nutrition Disorder Specialist")
    st.markdown("Ask me anything about nutrition disorders, symptoms, causes, treatments, and more.")
    st.markdown("---")

    # Initialize session state variables
    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []
    if 'user_id' not in st.session_state:
        st.session_state.user_id = None
    if 'chatbot' not in st.session_state:
        st.session_state.chatbot = None
    if 'config_loaded' not in st.session_state:
        st.session_state.config_loaded = False
    if 'vector_store_loaded' not in st.session_state:
        st.session_state.vector_store_loaded = False

    # --- Configuration Loading and Model Initialization ---
    if not st.session_state.config_loaded:
        with st.spinner("Loading configurations and initializing models..."):
            config = load_config_from_env()
            if not all(config.values()):
                st.error("Some environment variables are missing. Please set them up for the app to function.")
                st.stop()  # Stop execution if critical configs are missing

            # Step 1. Initialize the embedding model, LLM, and Llama Guard client
            embedding_model, llm_instance, embedding_function, llama_guard_client_instance = initialize_llms_and_embeddings(config)

            # Step 2. Store initialized objects in session state
            st.session_state.config = config
            st.session_state.embedding_model = embedding_model
            st.session_state.llm = llm_instance
            st.session_state.embedding_function = embedding_function  # Used during vector store creation/loading
            st.session_state.llama_guard_client = llama_guard_client_instance
            st.session_state.config_loaded = True
            st.rerun()  # Rerun to update the UI after loading
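
    # `load_config_from_env` (defined elsewhere in this file) must supply at least
    # MEM0_API_KEY, AZURE_OPENAI_API_KEY, and AZURE_OPENAI_API_BASE, since
    # NutritionBot reads exactly those keys, plus whatever credentials
    # `initialize_llms_and_embeddings` needs for the embeddings and Llama Guard.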

    # --- Vector Store Loading ---
    if st.session_state.config_loaded and not st.session_state.vector_store_loaded:
        with st.spinner("Loading nutrition knowledge base (vector database)..."):
            # The nutritional_db directory must exist relative to app.py;
            # in Docker, this means the folder should be copied into /app
            persist_dir = "./nutritional_db"
            if not os.path.exists(persist_dir):
                st.error(f"Required data directory '{persist_dir}' not found. Please ensure it's copied into the Docker image.")
                st.stop()

            st.session_state.vector_store = load_vector_db(
                embedding_model=st.session_state.embedding_model,
                collection_name="nutritional_hypotheticals",
                persist_directory=persist_dir
            )
            if st.session_state.vector_store is None:
                st.error("Failed to load vector database. Chat functionality will be limited.")
            st.session_state.vector_store_loaded = True
            st.rerun()  # Rerun to update the UI after loading
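
    # `load_vector_db` (defined elsewhere in this file) presumably opens the
    # persisted Chroma collection; `collection_name` and `persist_directory`
    # mirror Chroma's own constructor parameters.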

    # --- Login Form ---
    if st.session_state.user_id is None:
        with st.form("login_form", clear_on_submit=True):
            user_id_input = st.text_input("Please enter your name to begin:", key="user_id_input")
            submit_button = st.form_submit_button("Login")
            if submit_button and user_id_input:
                st.session_state.user_id = user_id_input.strip()
                st.session_state.chat_history.append({
                    "role": "assistant",
                    "content": f"Welcome, {st.session_state.user_id}! How can I help you with nutrition disorders today?"
                })
                # Initialize the chatbot only once config and vector store are ready
                if st.session_state.config_loaded and st.session_state.vector_store_loaded:
                    st.session_state.chatbot = NutritionBot(st.session_state.config)
                else:
                    st.warning("Chatbot initialization delayed as configurations or vector store are still loading.")
                st.rerun()
# --- Chat Interface ---
|
741 |
+
elif st.session_state.chatbot: # Only show chat if chatbot is initialized
|
742 |
# Display chat history
|
743 |
for message in st.session_state.chat_history:
|
744 |
with st.chat_message(message["role"]):
|
745 |
st.write(message["content"])
|
746 |
|
747 |
+
user_query = st.chat_input("Type your question here (e.g., 'What are dietary deficiencies?')")
|
748 |
+
|
749 |
if user_query:
|
750 |
if user_query.lower() == "exit":
|
751 |
st.session_state.chat_history.append({"role": "user", "content": "exit"})
|
|
|
755 |
st.session_state.chat_history.append({"role": "assistant", "content": goodbye_msg})
|
756 |
with st.chat_message("assistant"):
|
757 |
st.write(goodbye_msg)
|
758 |
+
st.session_state.user_id = None # Log out
|
759 |
+
st.session_state.chatbot = None # Clear chatbot instance
|
760 |
+
st.session_state.chat_history = [] # Clear history on exit
|
761 |
st.rerun()
|
762 |
return
|
763 |
|
|
|
766 |
st.write(user_query)
|
767 |
|

            # Filter input using Llama Guard
            with st.spinner("Filtering input for safety..."):
                filtered_result = filter_input_with_llama_guard(user_query, st.session_state.llama_guard_client)
                if filtered_result:
                    filtered_result = filtered_result.replace("\n", " ").strip()
                    st.info(f"Llama Guard says: {filtered_result}")  # Show Llama Guard's verdict

            # Process the user query only if it was judged safe. Llama Guard
            # answers "safe" or "unsafe <category>"; a bare substring test for
            # "safe" would also match "unsafe", so check the prefix instead.
            # Category S7 is explicitly allowed as well.
            if filtered_result and (filtered_result.lower().startswith("safe") or "s7" in filtered_result.lower()):
                with st.spinner("Thinking..."):
                    try:
                        response = st.session_state.chatbot.handle_customer_query(st.session_state.user_id, user_query)
                        st.session_state.chat_history.append({"role": "assistant", "content": response})
                        with st.chat_message("assistant"):
                            st.write(response)
                    except Exception as e:
                        error_msg = f"Sorry, I encountered an error while processing your query. Please try again. Error: {e}"
                        st.error(error_msg)
                        st.session_state.chat_history.append({"role": "assistant", "content": error_msg})
            else:
                inappropriate_msg = "I apologize, but I cannot process that input as it may be inappropriate or unsafe. Please try again."
                st.session_state.chat_history.append({"role": "assistant", "content": inappropriate_msg})
                with st.chat_message("assistant"):
                    st.write(inappropriate_msg)
                st.rerun()  # Rerun to update the chat history instantly

    elif st.session_state.user_id:  # User is logged in but the chatbot is not ready yet
        st.info("Initializing chatbot. Please wait...")


# --- Main entry point for the Streamlit app ---
if __name__ == "__main__":
    nutrition_disorder_streamlit_app()
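
# To run locally (assuming dependencies and environment variables are set up):
#   streamlit run app.py
# On Hugging Face Spaces, the container is expected to launch this entry point.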