jenngang committed on
Commit 6fdd408 · verified · 1 Parent(s): 8668829

Upload app.py with huggingface_hub

Files changed (1)
  1. app.py +523 -531
app.py CHANGED
@@ -1,210 +1,320 @@
 
- # Import necessary libraries
- import os # Interacting with the operating system (reading/writing files)
- import chromadb # High-performance vector database for storing/querying dense vectors
- from dotenv import load_dotenv # Loading environment variables from a .env file
- import json # Parsing and handling JSON data
-
- # LangChain imports
- from langchain_core.documents import Document # Document data structures
- from langchain_core.runnables import RunnablePassthrough # LangChain core library for running pipelines
- from langchain_core.output_parsers import StrOutputParser # String output parser
- from langchain.prompts import ChatPromptTemplate # Template for chat prompts
- from langchain.chains.query_constructor.base import AttributeInfo # Base classes for query construction
- from langchain.retrievers.self_query.base import SelfQueryRetriever # Base classes for self-querying retrievers
- from langchain.retrievers.document_compressors import LLMChainExtractor, CrossEncoderReranker # Document compressors
- from langchain.retrievers import ContextualCompressionRetriever # Contextual compression retrievers
-
- # LangChain community & experimental imports
- from langchain_community.vectorstores import Chroma # Implementations of vector stores like Chroma
- from langchain_community.document_loaders import PyPDFDirectoryLoader, PyPDFLoader # Document loaders for PDFs
- from langchain_community.cross_encoders import HuggingFaceCrossEncoder # Cross-encoders from HuggingFace
- from langchain_experimental.text_splitter import SemanticChunker # Experimental text splitting methods
- from langchain.text_splitter import (
- CharacterTextSplitter, # Splitting text by characters
- RecursiveCharacterTextSplitter # Recursive splitting of text by characters
- )
- from langchain_core.tools import tool
- from langchain.agents import create_tool_calling_agent, AgentExecutor
  from langchain_core.prompts import ChatPromptTemplate

- # LangChain OpenAI imports
- # Commented out - not used
- #from langchain_openai import AzureOpenAIEmbeddings, AzureChatOpenAI # OpenAI embeddings and models
- #from langchain.embeddings.openai import OpenAIEmbeddings # OpenAI embeddings for text vectors
- # Added and used below
  from langchain_openai import OpenAIEmbeddings, ChatOpenAI
- #from langchain_openai import ChatOpenAI

- # LlamaParse & LlamaIndex imports
- from llama_parse import LlamaParse # Document parsing library
- from llama_index.core import Settings, SimpleDirectoryReader # Core functionalities of the LlamaIndex

- # LangGraph import
- from langgraph.graph import StateGraph, END, START # State graph for managing states in LangChain

- # Pydantic import
- from pydantic import BaseModel # Pydantic for data validation

- # Typing imports
- from typing import Dict, List, Tuple, Any, TypedDict # Python typing for function annotations

- # Other utilities
- import numpy as np # Numpy for numerical operations
- from groq import Groq
- from mem0 import MemoryClient
- import streamlit as st
- from datetime import datetime

- #====================================SETUP=====================================#
- # Fetch secrets from Hugging Face Spaces
- api_key = os.environ["API_KEY"]
- endpoint = os.environ["OPENAI_API_BASE"]
- groq_api_key = os.environ['LLAMA_API_KEY'] # llama_api_key = os.environ['GROQ_API_KEY']
- MEM0_api_key = os.environ['MEM0_API_KEY'] # MEM0_api_key = os.environ['mem0']
- #my_api_key = os.environ["MY_API_KEY"]
-
- # Initialize the OpenAI embedding function for Chroma
- embedding_function = chromadb.utils.embedding_functions.OpenAIEmbeddingFunction(
- api_base=endpoint, # Complete the code to define the API base endpoint
- api_key=api_key # Complete the code to define the API key
- )
- #model_name='text-embedding-ada-002' # This is a fixed value and does not need modification
-
- # This initializes the OpenAI embedding function for the Chroma vectorstore, using the provided endpoint and API key.
-
- # Initialize the OpenAI Embeddings
- embedding_model = OpenAIEmbeddings(
- openai_api_base=endpoint,
- openai_api_key=api_key
- )
- #model='text-embedding-ada-002'
-
- # Initialize the Chat OpenAI model
- llm = ChatOpenAI(
- openai_api_base=endpoint,
- openai_api_key=api_key,
- model="gpt-4o-mini", # used gpt4o instead of gpt-4o-mini to get improved results
- streaming=False
- )
-
- # This initializes the Chat OpenAI model with the provided endpoint, API key, deployment name, and a temperature setting of 0 (to control response variability).
-
- # set the LLM and embedding model in the LlamaIndex settings.
- Settings.llm = llm # Complete the code to define the LLM model
- Settings.embedding = embedding_model # Complete the code to define the embedding model
-
- #================================Creating Langgraph agent======================#

  class AgentState(TypedDict):
- query: str # The current user query
- expanded_query: str # The expanded version of the user query
- context: List[Dict[str, Any]] # Retrieved documents (content and metadata)
- response: str # The generated response to the user query
- precision_score: float # The precision score of the response
- groundedness_score: float # The groundedness score of the response
- groundedness_loop_count: int # Counter for groundedness refinement loops
- precision_loop_count: int # Counter for precision refinement loops
  feedback: str
  query_feedback: str
- groundedness_check: bool
  loop_max_iter: int

- def expand_query(state):
- """
- Expands the user query to improve retrieval of nutrition disorder-related information.
-
- Args:
- state (Dict): The current state of the workflow, containing the user query.
-
- Returns:
- Dict: The updated state with the expanded query.
- """
- print("---------Expanding Query---------")
  system_message = '''
  You are a domain expert assisting in answering questions related to research papers.
  Convert the user query into something that a nutritionist would understand. Use domain related words.
- Return 3 related search queries based on the user's request seperated by newline.
- Return only 3 versions of the question as a list.
  Perform query expansion on the question received. If there are multiple common ways of phrasing a user question \
  or common synonyms for key words in the question, make sure to return multiple versions \
  of the query with the different phrasings.
- If the query has multiple parts, split them into separate simpler queries. This is the only case where you can generate more than 3 queries.
  If there are acronyms or words you are not familiar with, do not try to rephrase them.
  Generate only a list of questions. Do not mention anything before or after the list.
- Use the query feeback if provided to craft the search queries.
  '''
-
  expand_prompt = ChatPromptTemplate.from_messages([
  ("system", system_message),
  ("user", "Expand this query: {query} using the feedback: {query_feedback}")
-
  ])
-
- chain = expand_prompt | llm | StrOutputParser()
  expanded_query = chain.invoke({"query": state['query'], "query_feedback":state["query_feedback"]})
- print("expanded_query", expanded_query)
  state["expanded_query"] = expanded_query
  return state

- # Initialize the Chroma vector store for retrieving documents
- vector_store = Chroma(
- collection_name="nutritional_hypotheticals",
- persist_directory="./nutritional_db",
- embedding_function=embedding_model
- )
-
- # Create a retriever from the vector store
- retriever = vector_store.as_retriever(
- search_type='similarity',
- search_kwargs={'k': 3}
- )
-
- def retrieve_context(state):
- """
- Retrieves context from the vector store using the expanded or original query.
-
- Args:
- state (Dict): The current state of the workflow, containing the query and expanded query.
-
- Returns:
- Dict: The updated state with the retrieved context.
- """
- print("---------retrieve_context---------")
- # Add original query to the state to improve result.
- query = f"{state['query']}; {state['expanded_query']}" # Complete the code to define the key for the expanded query
- #print("Query used for retrieval:", query) # Debugging: Print the query
-
- # Retrieve documents from the vector store
  docs = retriever.invoke(query)
- print("Retrieved documents:", docs) # Debugging: Print the raw docs object

- # Extract both page_content and metadata from each document
- context= [
- {
- "content": doc.page_content, # The actual content of the document
- "metadata": doc.metadata # The metadata (e.g., source, page number, etc.)
- }
  for doc in docs
  ]
- state['context'] = context # Complete the code to define the key for storing the context
- print("Extracted context with metadata:", context) # Debugging: Print the extracted context
- #print(f"Groundedness loop count: {state['groundedness_loop_count']}")
  return state

-
- def craft_response(state: Dict) -> Dict:
- """
- Generates a response using the retrieved context, focusing on nutrition disorders.
-
- Args:
- state (Dict): The current state of the workflow, containing the query and retrieved context.
-
- Returns:
- Dict: The updated state with the generated response.
- """
- print("---------craft_response---------")
  system_message = '''
  Generates a response to a user query and context provided.
@@ -222,41 +332,28 @@ def craft_response(state: Dict) -> Dict:

  The answer you provide must come from the user queries with context provided.
  If feedback is provided, use it to craft the response.
- If information provided is not enough to answer the query respons with 'I don't know the answer. Not in my records.'
  '''
-
  response_prompt = ChatPromptTemplate.from_messages([
  ("system", system_message),
- ("user", "Query: {query}\nContext: {context}\n\nfeedback: {feedback}")
  ])
-
- chain = response_prompt | llm
  response = chain.invoke({
  "query": state['query'],
  "context": "\n".join([doc["content"] for doc in state['context']]),
- "feedback": state["feedback"] # add feedback to the prompt
  })
- state['response'] = response
- print("intermediate response: ", response)
-
  return state

-
-
- def score_groundedness(state: Dict) -> Dict:
- """
- Checks whether the response is grounded in the retrieved context.
-
- Args:
- state (Dict): The current state of the workflow, containing the response and context.
-
- Returns:
- Dict: The updated state with the groundedness score.
- """
- print("---------check_groundedness---------")
  system_message = '''
  You are tasked with rating AI generated answers to questions posed by users.
  Please act as an impartial judge and evaluate the quality of the provided answer which attempts to answer the provided question based on a provided context.
  In the input, the context is {context}, while the AI generated response is {response}.

  Evaluation criteria:
@@ -272,37 +369,24 @@ def score_groundedness(state: Dict) -> Dict:
  Do not show any instructions for deriving your answer.

  Output your result as a float number between 0 and 1 using the evaluation criteria.
- The better the criteria, the cloase it is to 1 and the worse the criteria, the closer it is to 0.
  '''
-
  groundedness_prompt = ChatPromptTemplate.from_messages([
  ("system", system_message),
  ("user", "Context: {context}\nResponse: {response}\n\nGroundedness score:")
  ])
-
- chain = groundedness_prompt | llm | StrOutputParser()
  groundedness_score = float(chain.invoke({
  "context": "\n".join([doc["content"] for doc in state['context']]),
- "response": state['response'] # Complete the code to define the response
  }))
- print("groundedness_score: ", groundedness_score)
- state['groundedness_loop_count'] += 1
- print("#########Groundedness Incremented###########")
  state['groundedness_score'] = groundedness_score
-
  return state

- def check_precision(state: Dict) -> Dict:
- """
- Checks whether the response precisely addresses the user's query.
-
- Args:
- state (Dict): The current state of the workflow, containing the query and response.
-
- Returns:
- Dict: The updated state with the precision score.
- """
- print("---------check_precision---------")
  system_message = '''
  Given the question, answer and context, verify if the context was useful in arriving at the given answer.
  Give verdict as "1" if useful and "0" if not useful.
@@ -311,277 +395,191 @@ def check_precision(state: Dict) -> Dict:
  0 or near 0 if it is least useful, 0.5 or near 0.5 if retry is warranted, and 1 or close to 1 is most useful.
  Do not show any instructions for deriving your answer.
  '''
-
  precision_prompt = ChatPromptTemplate.from_messages([
  ("system", system_message),
  ("user", "Query: {query}\nResponse: {response}\n\nPrecision score:")
  ])
-
- chain = precision_prompt | llm | StrOutputParser() # Complete the code to define the chain of processing
  precision_score = float(chain.invoke({
  "query": state['query'],
- "response": state['response'] # Complete the code to access the response from the state
  }))
  state['precision_score'] = precision_score
- print("precision_score:", precision_score)
  state['precision_loop_count'] += 1
- print("#########Precision Incremented###########")
  return state

- def refine_response(state: Dict) -> Dict:
- """
- Suggests improvements for the generated response.
-
- Args:
- state (Dict): The current state of the workflow, containing the query and response.
-
- Returns:
- Dict: The updated state with response refinement suggestions.
- """
- print("---------refine_response---------")
-
  system_message = '''
- Since the last response failded the groundedness test, and is deemed not satisfactory,
  use the feedback in terms of the query, context and the last response
  to identify potential gaps, ambiguities, or missing details, and
  to suggest improvements to enhance accuracy and completeness of the response.
  '''
-
  refine_response_prompt = ChatPromptTemplate.from_messages([
  ("system", system_message),
  ("user", "Query: {query}\nResponse: {response}\n\n"
  "What improvements can be made to enhance accuracy and completeness?")
  ])
-
- chain = refine_response_prompt | llm| StrOutputParser()
-
- # Store response suggestions in a structured format
  feedback = f"Previous Response: {state['response']}\nSuggestions: {chain.invoke({'query': state['query'], 'response': state['response']})}"
- print("feedback: ", feedback)
- print(f"State: {state}")
  state['feedback'] = feedback
  return state

-
-
- def refine_query(state: Dict) -> Dict:
- """
- Suggests improvements for the expanded query.
-
- Args:
- state (Dict): The current state of the workflow, containing the query and expanded query.
-
- Returns:
- Dict: The updated state with query refinement suggestions.
- """
- print("---------refine_query---------")
  system_message = '''
- Since the last response failded the precision test, and is deemed not satisfactory,
  use the feedback in terms of the query, context and re-generate extended queries
  to identify specific keywords, scope refinements, or missing details, and
  to provide structured suggestions for improvement to enhance accuracy and completeness of the response.
  '''
-
  refine_query_prompt = ChatPromptTemplate.from_messages([
  ("system", system_message),
  ("user", "Original Query: {query}\nExpanded Query: {expanded_query}\n\n"
  "What improvements can be made for a better search?")
  ])
-
- chain = refine_query_prompt | llm | StrOutputParser()
-
- # Store refinement suggestions without modifying the original expanded query
  query_feedback = f"Previous Expanded Query: {state['expanded_query']}\nSuggestions: {chain.invoke({'query': state['query'], 'expanded_query': state['expanded_query']})}"
- print("query_feedback: ", query_feedback)
- print(f"Groundedness loop count: {state['groundedness_loop_count']}")
  state['query_feedback'] = query_feedback
  return state

-
-
- def should_continue_groundedness(state):
- """Decides if groundedness is sufficient or needs improvement."""
- print("---------should_continue_groundedness---------")
- print("groundedness loop count: ", state['groundedness_loop_count'])
- if state['groundedness_score'] >= 0.8: # Complete the code to define the threshold for groundedness
- print("Moving to precision")
- return "check_precision"
- else:
- if state["groundedness_loop_count"] > state['loop_max_iter']:
- return "max_iterations_reached"
- else:
- print(f"---------Groundedness Score Threshold Not met. Refining Response-----------")
- return "refine_response"
-
-
- def should_continue_precision(state: Dict) -> str:
- """Decides if precision is sufficient or needs improvement."""
- print("---------should_continue_precision---------")
- print("precision loop count: ", state["precision_loop_count"])
- if state['precision_score'] > 0.8: # Threshold for precision
- return "pass" # Complete the workflow
  else:
- if state["precision_loop_count"] >= state['loop_max_iter']: # Maximum allowed loops
  return "max_iterations_reached"
  else:
- print(f"---------Precision Score Threshold Not met. Refining Query-----------") # Debugging
- return "refine_query" # Refine the query
-
-
- def max_iterations_reached(state: Dict) -> Dict:
- """Handles the case when the maximum number of iterations is reached."""
- print("---------max_iterations_reached---------")
- """Handles the case when the maximum number of iterations is reached."""
  response = "I'm unable to refine the response further. Please provide more context or clarify your question."
  state['response'] = response
  return state

- from langgraph.graph import END, StateGraph, START
-
- def create_workflow() -> StateGraph:
- """Creates the updated workflow for the AI nutrition agent."""
- workflow = StateGraph(AgentState) # Complete the code to define the initial state of the agent

- # Add processing nodes
- workflow.add_node("expand_query", expand_query) # Step 1: Expand user query. Complete with the function to expand the query
- workflow.add_node("retrieve_context", retrieve_context) # Step 2: Retrieve relevant documents. Complete with the function to retrieve context
- workflow.add_node("craft_response", craft_response) # Step 3: Generate a response based on retrieved data. Complete with the function to craft a response
- workflow.add_node("score_groundedness", score_groundedness) # Step 4: Evaluate response grounding. Complete with the function to score groundedness
- workflow.add_node("refine_response", refine_response) # Step 5: Improve response if it's weakly grounded. Complete with the function to refine the response
- workflow.add_node("check_precision", check_precision) # Step 6: Evaluate response precision. Complete with the function to check precision
- workflow.add_node("refine_query", refine_query) # Step 7: Improve query if response lacks precision. Complete with the function to refine the query
- workflow.add_node("max_iterations_reached", max_iterations_reached) # Step 8: Handle max iterations. Complete with the function to handle max iterations

- # Main flow edges
  workflow.add_edge(START, "expand_query")
  workflow.add_edge("expand_query", "retrieve_context")
  workflow.add_edge("retrieve_context", "craft_response")
  workflow.add_edge("craft_response", "score_groundedness")

- # Conditional edges based on groundedness check
  workflow.add_conditional_edges(
  "score_groundedness",
- should_continue_groundedness, # Use the conditional function
  {
- "check_precision": "check_precision", # If well-grounded, proceed to precision check.
- "refine_response": "refine_response", # If not, refine the response.
- "max_iterations_reached": "max_iterations_reached" # If max loops reached, exit.
  }
  )

- workflow.add_edge("refine_response", "craft_response") # Refined responses are reprocessed.
-
- # Conditional edges based on precision check
  workflow.add_conditional_edges(
  "check_precision",
- should_continue_precision, # Use the conditional function
  {
- "pass": END, # If precise, complete the workflow.
- "refine_query": "refine_query", # If imprecise, refine the query.
- "max_iterations_reached": "max_iterations_reached" # If max loops reached, exit.
  }
  )
-
- workflow.add_edge("refine_query", "expand_query") # Refined queries go through expansion again.
-
  workflow.add_edge("max_iterations_reached", END)

  return workflow

-
- #=========================== Defining the agentic rag tool ====================#
-
- WORKFLOW_APP = create_workflow().compile()
-
- @tool
- def agentic_rag(query: str):
- """
- Runs the RAG-based agent with conversation history for context-aware responses.
-
- Args:
- query (str): The current user query.
-
- Returns:
- Dict[str, Any]: The updated state with the generated response and conversation history.
- """
- # Initialize state with necessary parameters
- inputs = {
- "query": query, # Current user query
- "expanded_query": "", # Complete the code to define the expanded version of the query
- "context": [], # Retrieved documents (initially empty)
- "response": "", # Complete the code to define the AI-generated response
- "precision_score": 0.0, # Complete the code to define the precision score of the response
- "groundedness_score": 0.0, # Complete the code to define the groundedness score of the response
- "groundedness_loop_count": 0, # Complete the code to define the counter for groundedness loops
- "precision_loop_count": 0, # Complete the code to define the counter for precision loops
- "feedback": "", # Complete the code to define the feedback
- "query_feedback": "", # Complete the code to define the query feedback
- "loop_max_iter": 3 # Complete the code to define the maximum number of iterations for loops
- }
-
- output = WORKFLOW_APP.invoke(inputs)
-
- return output
-
- #================================ Guardrails ===========================#
- llama_guard_client = Groq(api_key=groq_api_key) # Groq(api_key=llama_api_key)
- # Function to filter user input with Llama Guard
- #def filter_input_with_llama_guard(user_input, model="llama-guard-3-8b"):
- def filter_input_with_llama_guard(user_input, model="meta-llama/llama-guard-4-12b"):
- """
- Filters user input using Llama Guard to ensure it is safe.
-
- Parameters:
- - user_input: The input provided by the user.
- - model: The Llama Guard model to be used for filtering (default is "llama-guard-3-8b").
-
- Returns:
- - The filtered and safe input.
- """
- try:
- # Create a request to Llama Guard to filter the user input
- response = llama_guard_client.chat.completions.create(
- messages=[{"role": "user", "content": user_input}],
- model=model,
- )
- # Return the filtered input
- return response.choices[0].message.content.strip()
- except Exception as e:
- print(f"Error with Llama Guard: {e}")
- return None
-
- #============================= Adding Memory to the agent using mem0 ===============================#

  class NutritionBot:
- def __init__(self):
  """
  Initialize the NutritionBot class, setting up memory, the LLM client, tools, and the agent executor.
  """

  # Initialize a memory client to store and retrieve customer interactions
- self.memory = MemoryClient(api_key=MEM0_api_key) # userdata.get("mem0")) # Complete the code to define the memory client API key

  self.client = ChatOpenAI(
- #model_name="gpt-4o", # Specify the model to use (e.g., GPT-4 optimized version)
  model="gpt-4o-mini",
- #api_key = api_key, # config.get("API_KEY"), # API key for authentication
- openai_api_key=api_key, # Fill in the API key
- #base_url = endpoint, # config.get("OPENAI_API_BASE"),
- openai_api_base=endpoint, # Fill in the endpoint
- temperature=0 # Controls randomness in responses; 0 ensures deterministic results
  )

- # Define tools available to the chatbot, such as web search
- tools = [agentic_rag]

- # Define the system prompt to set the behavior of the chatbot
  system_prompt = """You are a caring and knowledgeable Medical Support Agent, specializing in nutrition disorder-related guidance. Your goal is to provide accurate, empathetic, and tailored nutritional recommendations while ensuring a seamless customer experience.
  Guidelines for Interaction:
  Maintain a polite, professional, and reassuring tone.
  Maintain a polite, professional, and reassuring tone.
@@ -590,16 +588,16 @@ class NutritionBot:
590
  Engage with the customer by asking about their food preferences, dietary restrictions, and lifestyle before offering recommendations.
591
  Ensure consistent and accurate information across conversations.
592
  If any detail is unclear or missing, proactively ask for clarification.
593
- Always use the agentic_rag tool to retrieve up-to-date and evidence-based nutrition insights.
594
  Keep track of ongoing issues and follow-ups to ensure continuity in support.
595
  Your primary goal is to help customers make informed nutrition decisions that align with their health conditions and personal preferences.
596
  """
597
 
598
  # Build the prompt template for the agent
599
  prompt = ChatPromptTemplate.from_messages([
600
- ("system", system_prompt), # System instructions
601
- ("human", "{input}"), # Placeholder for human input
602
- ("placeholder", "{agent_scratchpad}") # Placeholder for intermediate reasoning steps
603
  ])
604
 
605
  # Create an agent capable of interacting with tools and executing tasks
@@ -609,80 +607,35 @@ class NutritionBot:
  self.agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)

  def store_customer_interaction(self, user_id: str, message: str, response: str, metadata: Dict = None):
- """
- Store customer interaction in memory for future reference.
-
- Args:
- user_id (str): Unique identifier for the customer.
- message (str): Customer's query or message.
- response (str): Chatbot's response.
- metadata (Dict, optional): Additional metadata for the interaction.
- """
  if metadata is None:
  metadata = {}
-
- # Add a timestamp to the metadata for tracking purposes
  metadata["timestamp"] = datetime.now().isoformat()
-
- # Format the conversation for storage
  conversation = [
  {"role": "user", "content": message},
  {"role": "assistant", "content": response}
  ]
-
- # Store the interaction in the memory client
- self.memory.add(
- conversation,
- user_id=user_id,
- output_format="v1.1",
- metadata=metadata
- )
-

  def get_relevant_history(self, user_id: str, query: str) -> List[Dict]:
- """
- Retrieve past interactions relevant to the current query.
-
- Args:
- user_id (str): Unique identifier for the customer.
- query (str): The customer's current query.
-
- Returns:
- List[Dict]: A list of relevant past interactions.
- """
- return self.memory.search(
- query=query, # Search for interactions related to the query
- user_id=user_id, # Restrict search to the specific user
- limit=5 # Complete the code to define the limit for retrieved interactions
- )
-

  def handle_customer_query(self, user_id: str, query: str) -> str:
- """
- Process a customer's query and provide a response, taking into account past interactions.
-
- Args:
- user_id (str): Unique identifier for the customer.
- query (str): Customer's query.
-
- Returns:
- str: Chatbot's response.
- """
-
- # Retrieve relevant past interactions for context
  relevant_history = self.get_relevant_history(user_id, query)

- # Build a context string from the relevant history
  context = "Previous relevant interactions:\n"
- for memory in relevant_history:
- context += f"Customer: {memory['memory']}\n" # Customer's past messages
- context += f"Support: {memory['memory']}\n" # Chatbot's past responses
  context += "---\n"

- # Print context for debugging purposes
- print("Context: ", context)
-
- # Prepare a prompt combining past context and the current query
  prompt = f"""
  Context:
  {context}
@@ -691,62 +644,108 @@ class NutritionBot:

  Provide a helpful response that takes into account any relevant past interactions.
  """
- #st.write("Context: ", prompt)
-
- # Generate a response using the agent
- response = self.agent_executor.invoke({"input": prompt})
- st.write("Context: ", response)
-
- # Store the current interaction for future reference
- self.store_customer_interaction(
- user_id=user_id,
- message=query,
- response=response["output"],
- metadata={"type": "support_query"}
- )

- # Return the chatbot's response
- return response['output']

- #=====================User Interface using streamlit ===========================#
- def nutrition_disorder_streamlit():
- """
- A Streamlit-based UI for the Nutrition Disorder Specialist Agent.
- """
- st.title("Nutrition Disorder Specialist")
- st.write("Ask me anything about nutrition disorders, symptoms, causes, treatments, and more.")
- st.write("Type 'exit' to end the conversation.")

- # Initialize session state for chat history and user_id if they don't exist
  if 'chat_history' not in st.session_state:
  st.session_state.chat_history = []
  if 'user_id' not in st.session_state:
  st.session_state.user_id = None
-
- # Login form: Only if user is not logged in
  if st.session_state.user_id is None:
  with st.form("login_form", clear_on_submit=True):
- user_id = st.text_input("Please enter your name to begin:")
  submit_button = st.form_submit_button("Login")
- if submit_button and user_id:
- st.session_state.user_id = user_id
  st.session_state.chat_history.append({
  "role": "assistant",
- "content": f"Welcome, {user_id}! How can I help you with nutrition disorders today?"
  })
- st.session_state.login_submitted = True # Set flag to trigger rerun
- if st.session_state.get("login_submitted", False):
- st.session_state.pop("login_submitted")
- st.rerun()
- else:
  # Display chat history
  for message in st.session_state.chat_history:
  with st.chat_message(message["role"]):
  st.write(message["content"])

- # Chat input with custom placeholder text
- user_query = st.chat_input("Type your question here (or 'exit' to end)...") # Blank #1: Fill in the chat input prompt (e.g., "Type your question here (or 'exit' to end)...")
  if user_query:
  if user_query.lower() == "exit":
  st.session_state.chat_history.append({"role": "user", "content": "exit"})
@@ -756,7 +755,9 @@ def nutrition_disorder_streamlit():
  st.session_state.chat_history.append({"role": "assistant", "content": goodbye_msg})
  with st.chat_message("assistant"):
  st.write(goodbye_msg)
- st.session_state.user_id = None
  st.rerun()
  return

@@ -765,44 +766,35 @@ def nutrition_disorder_streamlit():
  st.write(user_query)

  # Filter input using Llama Guard
- filtered_result = filter_input_with_llama_guard(user_query) # Blank #2: Fill in with the function name for filtering input (e.g., filter_input_with_llama_guard)
- #filtered_result = filtered_result.replace("\n", " ") # Normalize the result
- #print(f"filtered_result (1): {filtered_result}")
- st.write(f"filtered_result (1): {filtered_result}")
- if filtered_result is None:
- #print("Agent: Sorry, I encountered an error while filtering your input. Please try again.")
- st.write("Agent: Sorry, I encountered an error while filtering your input. Please try again.")
- #continue
- else:
- filtered_result = filtered_result.replace("\n", " ") # Normalize the result
- #print(f"filtered_result (2): {filtered_result}")
- st.write(f"filtered_result (2): {filtered_result}")
-
- # Check if input is safe based on allowed statuses
- if filtered_result in ["safe", "safe S7", "safe S6"]: # Blanks #3, #4, #5: Fill in with allowed safe statuses (e.g., "safe", "unsafe S7", "unsafe S6")
- st.write("Input is safe.")
- try:
-
- if 'chatbot' not in st.session_state:
- st.session_state.chatbot = NutritionBot() # Blank #6: Fill in with the chatbot class initialization (e.g., NutritionBot)
-
- st.write("chatbot is calling handle_customer_query...")
- st.write("user_id: ", st.session_state.user_id)
- st.write("user_query: ", user_query)
-
- response = st.session_state.chatbot.handle_customer_query(st.session_state.user_id, user_query)
- st.write("response is returned.")
- # Blank #7: Fill in with the method to handle queries (e.g., handle_customer_query)
-
- st.write(response)
- st.session_state.chat_history.append({"role": "assistant", "content": response})
- except Exception as e:
- st.write(f"Sorry, I encountered an error while processing your query. Please try again. Error: {e}", str(e))
- st.session_state.chat_history.append({"role": "assistant", "content": f"Sorry, I encountered an error while processing your query. Please try again. Error: {e}"})
  else:
- inappropriate_msg = "I apologize, but I cannot process that input as it may be inappropriate. Please try again."
- st.write(inappropriate_msg)
  st.session_state.chat_history.append({"role": "assistant", "content": inappropriate_msg})

  if __name__ == "__main__":
- nutrition_disorder_streamlit()
 
+ # --- 0. Library Imports ---
+ import os
+ import json
+ import shutil
+ import time
+ import numpy as np
+ from datetime import datetime
+ from typing import Dict, List, Any, TypedDict, Tuple
+
+ # LangChain Core & Community
+ from langchain_core.documents import Document
+ from langchain_core.runnables import RunnablePassthrough
+ from langchain_core.output_parsers import StrOutputParser
  from langchain_core.prompts import ChatPromptTemplate
+ from langchain_core.tools import tool
+ from langchain_community.vectorstores import Chroma
+ from langchain_community.document_loaders import PyPDFDirectoryLoader
+ from langchain_community.cross_encoders import HuggingFaceCrossEncoder
+ from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter

+ # LangChain OpenAI
  from langchain_openai import OpenAIEmbeddings, ChatOpenAI

+ # LangChain Experimental
+ from langchain_experimental.text_splitter import SemanticChunker

+ # LangChain Agents & Graph
+ from langchain.agents import create_tool_calling_agent, AgentExecutor
+ from langgraph.graph import StateGraph, END, START

+ # External Libraries
+ import chromadb
+ from llama_parse import LlamaParse # For PDF parsing
+ from groq import Groq # For Llama Guard
+ from mem0 import MemoryClient # For memory
+ import streamlit as st # For Web UI

+ # Fix for numpy deprecation warning if necessary
+ np.float_ = np.float64

+ #import nest_asyncio

+ # --- 1. Configuration and Setup Utilities ---
+
+ def load_config_from_env() -> Dict:
+ """Loads API keys and endpoints from environment variables."""
+ # Prioritize environment variables for deployment
+ config = {
+ "AZURE_OPENAI_API_KEY": os.getenv("AZURE_OPENAI_API_KEY"),
+ "AZURE_OPENAI_API_BASE": os.getenv("AZURE_OPENAI_API_BASE"),
+ "LLAMA_KEY": os.getenv("LLAMA_KEY"), # LlamaParse API Key
+ "MEM0_API_KEY": os.getenv("MEM0_API_KEY"),
+ "GROQ_API_KEY": os.getenv("GROQ_API_KEY"), # For Llama Guard via Groq
+ }
+ # Basic validation
+ for key, value in config.items():
+ if not value:
+ st.warning(f"Warning: Environment variable '{key}' is not set.")
+ return config
+
+ def initialize_llms_and_embeddings(config: Dict) -> Tuple[OpenAIEmbeddings, ChatOpenAI, chromadb.utils.embedding_functions.OpenAIEmbeddingFunction, Groq]:
+ """Initializes LLM, Embedding models, and API clients."""
+ api_key = config["AZURE_OPENAI_API_KEY"]
+ endpoint = config["AZURE_OPENAI_API_BASE"]
+ groq_api_key = config["GROQ_API_KEY"]
+
+ # Initialize ChromaDB embedding function (used for collection creation)
+ embedding_function = chromadb.utils.embedding_functions.OpenAIEmbeddingFunction(
+ api_base=endpoint,
+ api_key=api_key,
+ model_name='text-embedding-ada-002' # Specify model explicitly
+ )
+
+ # Initialize LangChain OpenAI Embeddings (used for `SemanticChunker` and `Chroma` vectorstore)
+ embedding_model = OpenAIEmbeddings(
+ openai_api_base=endpoint,
+ openai_api_key=api_key,
+ model='text-embedding-ada-002' # Specify model explicitly
+ )
+
+ # Initialize LangChain Chat OpenAI model
+ llm = ChatOpenAI(
+ openai_api_base=endpoint,
+ openai_api_key=api_key,
+ model="gpt-4o-mini",
+ streaming=False,
+ temperature=0.0 # Ensure deterministic behavior for evaluations
+ )
+
+ # Initialize Groq client for Llama Guard
+ llama_guard_client = Groq(api_key=groq_api_key)
+
+ return embedding_model, llm, embedding_function, llama_guard_client
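
A minimal sketch of how these helpers might be wired together at app startup, caching the clients in st.session_state so the LangGraph node functions below can reach them (this wiring is an illustration inferred from how those functions read st.session_state.llm, not a line of the committed file):

    # Hypothetical startup block, run once per Streamlit session
    if 'llm' not in st.session_state:
        config = load_config_from_env()
        (st.session_state.embedding_model,
         st.session_state.llm,
         st.session_state.embedding_function,
         st.session_state.llama_guard_client) = initialize_llms_and_embeddings(config)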
+
+ def filter_input_with_llama_guard(user_input: str, llama_guard_client: Groq, model: str = "meta-llama/llama-guard-4-12b") -> str:
+ """
+ Filters user input using Llama Guard to ensure it is safe.
+ Returns "safe", "UNSAFE" (with category), or None on error.
+ """
+ try:
+ response = llama_guard_client.chat.completions.create(
+ messages=[{"role": "user", "content": user_input}],
+ model=model,
+ )
+ return response.choices[0].message.content.strip()
+ except Exception as e:
+ st.error(f"Error with Llama Guard: {e}")
+ return None
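
A short sketch of consuming the verdict, mirroring the allow-list check the old UI code in this diff used ("safe", "safe S7", "safe S6"); the exact set of accepted statuses is an application choice, not a Llama Guard guarantee, and the client comes from the session-state wiring assumed above:

    verdict = filter_input_with_llama_guard(user_query, st.session_state.llama_guard_client)
    if verdict is None:
        pass  # client error: ask the user to retry
    elif verdict.replace("\n", " ") in ["safe", "safe S7", "safe S6"]:
        pass  # proceed to answer the query
    else:
        pass  # refuse: Llama Guard flagged the input (e.g. an "unsafe S6" verdict)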
+
+ # --- 2. Data Preparation (Parsing & Chunking) ---
+
+ # Note: In a deployed app, PDF parsing and vector DB creation would typically be
+ # a separate, offline process, and the pre-built vector DB would be loaded.
+ # For this template, we'll assume the nutritional_db is pre-existing and loaded.
+
+ def load_and_split_documents(folder_path: str, embedding_model: OpenAIEmbeddings) -> List[Document]:
+ """Loads PDF documents from a folder and semantically chunks them."""
+ semantic_text_splitter = SemanticChunker(
+ embedding_model,
+ breakpoint_threshold_type='percentile',
+ breakpoint_threshold_amount=80
+ )
+ pdf_loader = PyPDFDirectoryLoader(folder_path)
+ chunks = pdf_loader.load_and_split(semantic_text_splitter)
+ return chunks
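
A usage sketch (the folder name is hypothetical); with the percentile settings above, chunk boundaries fall where adjacent-sentence embedding similarity drops past the 80th-percentile breakpoint rather than at fixed character counts:

    chunks = load_and_split_documents("./nutrition_pdfs", embedding_model)
    print(len(chunks), chunks[0].metadata)  # e.g. source file and page number per chunk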
+
+ def parse_pdf_tables_with_llamaparse(pdf_path: str, llamaparse_api_key: str) -> Tuple[Dict, Dict]:
+ """Parses a PDF file using LlamaParse and extracts page texts and tables."""
+ # This requires `nest_asyncio.apply()` to be called once at the start of the app if running async.
+ # In a Streamlit app, ensure it's at the very top level if needed.
+ # import nest_asyncio; nest_asyncio.apply() # Uncomment if needed for async parser
+
+ parser = LlamaParse(
+ result_type="markdown",
+ skip_diagonal_text=True,
+ fast_mode=False,
+ num_workers=1, # Adjust as per environment capabilities
+ check_interval=10,
+ api_key=llamaparse_api_key
+ )
+ json_objs = parser.get_json_result(pdf_path)
+ page_texts, tables = {}, {}
+ for obj in json_objs:
+ json_list = obj['pages']
+ name = obj["file_path"].split("/")[-1]
+ page_texts[name] = {}
+ tables[name] = {}
+ for json_item in json_list:
+ for component in json_item['items']:
+ if component['type'] == 'table':
+ tables[name][json_item['page']] = component['rows']
+ return page_texts, tables
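
A sketch of turning the returned tables into Document objects that generate_hypothetical_questions() below can consume with is_table=True (the file name and loop are illustrative, not committed code; the metadata keys mirror the ones that function reads):

    page_texts, tables = parse_pdf_tables_with_llamaparse("report.pdf", config["LLAMA_KEY"])
    table_docs = [
        Document(page_content=str(rows), metadata={"source": fname, "page": page})
        for fname, per_page in tables.items()
        for page, rows in per_page.items()
    ]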
+
+ def generate_hypothetical_questions(llm: ChatOpenAI, docs: List[Document], is_table: bool = False) -> List[Document]:
+ """Generates hypothetical questions for text chunks or tables."""
+ prompt_template = """
+ Generate a list of exactly three (3) hypothetical questions that the below nutritional disorder {content_type} could be used to answer:
+ {content}
+ Ensure that the questions are specific in the context of nutrition, dietary deficiencies, metabolic disorders, vitamin and mineral imbalances, obesity, and related health conditions.
+ Generate only a list of questions.
+ Do not mention anything before or after the list.
+ If the content cannot answer any questions, return an empty list.
+ """
+ hyp_docs = []
+ content_type = "table" if is_table else "document"
+
+ for i, doc in enumerate(docs):
+ content_to_use = str(doc) if is_table else doc.page_content # Tables are often raw data, stringify
+ try:
+ response = llm.invoke(prompt_template.format(content_type=content_type, content=content_to_use))
+ questions = response.content
+ except Exception as e:
+ st.error(f"Error generating hypothetical questions for {'table' if is_table else 'text'} chunk ID {doc.id}: {e}")
+ questions = "[]" # Return empty list string on error
+
+ questions_metadata = {
+ 'original_content': content_to_use,
+ 'source': doc.metadata.get('source', 'unknown'),
+ 'page': doc.metadata.get('page', -1),
+ 'type': content_type
+ }
+ hyp_docs.append(
+ Document(
+ id=f"{'table_' if is_table else 'text_chunk_'}{doc.id or i}", # Ensure unique IDs
+ page_content=questions,
+ metadata=questions_metadata
+ )
+ )
+ time.sleep(0.1) # Small delay to avoid rate limits
+ return hyp_docs
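
Because the vector store ends up indexing the generated questions rather than the source text, the original chunk can be recovered at query time from the original_content metadata stored above. A minimal sketch, assuming a populated store:

    hits = vector_store.similarity_search("What causes vitamin B12 deficiency?", k=3)
    originals = [h.metadata["original_content"] for h in hits]  # source text behind each matched question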
+
+ def create_and_persist_vector_db(
+ documents: List[Document],
+ embedding_model: OpenAIEmbeddings,
+ collection_name: str,
+ persist_directory: str
+ ):
+ """Creates/updates a Chroma vector store and persists it."""
+ # Ensure IDs are strings as required by ChromaDB
+ doc_ids = [str(d.id) for d in documents] if documents else []
+ if not doc_ids:
+ st.warning(f"No documents to add to collection '{collection_name}'.")
+ return
+
+ # Initialize or connect to Chroma vectorstore
+ vector_store = Chroma.from_documents(
+ documents,
+ embedding_model,
+ collection_name=collection_name,
+ persist_directory=persist_directory
+ )
+ st.info(f"Initialized ChromaDB with collection '{collection_name}' at '{persist_directory}'. "
+ f"Documents count: {len(documents)}")
+ return vector_store
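
Putting the data-preparation helpers together, an offline index build (run once during setup, per the note at the top of this section) might look like the following sketch; the folder name is hypothetical, while the collection name and persist directory follow the nutritional_db convention used elsewhere in this file:

    config = load_config_from_env()
    embedding_model, llm, _, _ = initialize_llms_and_embeddings(config)
    chunks = load_and_split_documents("./nutrition_pdfs", embedding_model)  # hypothetical folder
    hyp_questions = generate_hypothetical_questions(llm, chunks)
    create_and_persist_vector_db(hyp_questions, embedding_model,
                                 collection_name="nutritional_hypotheticals",
                                 persist_directory="./nutritional_db")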
+
+ def load_vector_db(
+ embedding_model: OpenAIEmbeddings,
+ collection_name: str,
+ persist_directory: str
+ ) -> Chroma:
+ """Loads an existing Chroma vector store."""
+ try:
+ # Check if the directory exists and contains ChromaDB files
+ if not os.path.exists(persist_directory) or not os.listdir(persist_directory):
+ st.error(f"Vector DB directory '{persist_directory}' is empty or does not exist.")
+ st.warning("Please ensure the 'nutritional_db' folder is correctly placed/mounted.")
+ return None
+
+ vector_store = Chroma(
+ collection_name=collection_name,
+ persist_directory=persist_directory,
+ embedding_function=embedding_model
+ )
+ st.success(f"Successfully loaded ChromaDB collection '{collection_name}' from '{persist_directory}'.")
+ # You can add a check for the number of documents loaded for verification
+ # Example: print(vector_store._collection.count())
+ return vector_store
+ except Exception as e:
+ st.error(f"Error loading ChromaDB from '{persist_directory}': {e}")
+ return None
+
+ # --- 3. Agent Workflow Definition (LangGraph) ---

  class AgentState(TypedDict):
+ """Represents the state of the AI agent at different stages of the workflow."""
+ query: str
+ expanded_query: str
+ context: List[Dict[str, Any]]
+ response: str
+ precision_score: float
+ groundedness_score: float
+ groundedness_loop_count: int
+ precision_loop_count: int
  feedback: str
  query_feedback: str
+ groundedness_check: bool # This field isn't used in should_continue_groundedness, can be removed
  loop_max_iter: int

+ # Node functions for LangGraph
+ def expand_query(state: AgentState) -> AgentState:
+ st.write("---Expanding Query---")
  system_message = '''
  You are a domain expert assisting in answering questions related to research papers.
  Convert the user query into something that a nutritionist would understand. Use domain related words.
+ Return three (3) related search queries based on the user's request separated by newline.
+ Return only three (3) versions of the question as a list.
  Perform query expansion on the question received. If there are multiple common ways of phrasing a user question \
  or common synonyms for key words in the question, make sure to return multiple versions \
  of the query with the different phrasings.
+ If the query has multiple parts, split them into separate simpler queries. This is the only case where you can generate more than three (3) queries.
  If there are acronyms or words you are not familiar with, do not try to rephrase them.
  Generate only a list of questions. Do not mention anything before or after the list.
+ Use the query feedback if provided to craft the search queries.
  '''
  expand_prompt = ChatPromptTemplate.from_messages([
  ("system", system_message),
  ("user", "Expand this query: {query} using the feedback: {query_feedback}")
  ])
+ chain = expand_prompt | st.session_state.llm | StrOutputParser() # Use llm from session state
  expanded_query = chain.invoke({"query": state['query'], "query_feedback":state["query_feedback"]})
+ st.write(f"Expanded query:\n{expanded_query}")
  state["expanded_query"] = expanded_query
  return state

+ def retrieve_context(state: AgentState) -> AgentState:
+ st.write("---Retrieving Context---")
+ query = f"{state['query']}; {state['expanded_query']}"
+ st.write(f"Query used for retrieval:\n{query}")
+
+ # Ensure vector_store is loaded and available in session_state
+ if 'vector_store' not in st.session_state or st.session_state.vector_store is None:
+ st.error("Vector store not initialized. Cannot retrieve context.")
+ state['context'] = []
+ return state
+
+ retriever = st.session_state.vector_store.as_retriever(
+ search_type='similarity',
+ search_kwargs={'k': 3}
+ )
  docs = retriever.invoke(query)
+ st.write(f"Retrieved documents (first 100 chars each):\n{[doc.page_content[:100] for doc in docs]}")

+ context = [
+ {"content": doc.page_content, "metadata": doc.metadata}
  for doc in docs
  ]
+ state['context'] = context
+ #st.write(f"Extracted context with metadata:\n{context}") # Too verbose for production UI
  return state

+ def craft_response(state: AgentState) -> AgentState:
+ st.write("---Crafting Response---")
  system_message = '''
  Generates a response to a user query and context provided.

  The answer you provide must come from the user queries with context provided.
  If feedback is provided, use it to craft the response.
+ If the information provided is not enough to answer the query, respond with 'I don't know the answer. Not in my records.'
  '''
  response_prompt = ChatPromptTemplate.from_messages([
  ("system", system_message),
+ ("user", "Query:\n{query}\nContext:\n{context}\n\nfeedback:\n{feedback}")
  ])
+ chain = response_prompt | st.session_state.llm # Use llm from session state
  response = chain.invoke({
  "query": state['query'],
  "context": "\n".join([doc["content"] for doc in state['context']]),
+ "feedback": state["feedback"]
  })
+ state['response'] = response.content # Access content from AIMessage
+ st.write(f"Intermediate response:\n{state['response']}")
  return state

+ def score_groundedness(state: AgentState) -> AgentState:
+ st.write("---Checking Groundedness---")
  system_message = '''
  You are tasked with rating AI generated answers to questions posed by users.
  Please act as an impartial judge and evaluate the quality of the provided answer which attempts to answer the provided question based on a provided context.
+
  In the input, the context is {context}, while the AI generated response is {response}.

  Evaluation criteria:

  Do not show any instructions for deriving your answer.

  Output your result as a float number between 0 and 1 using the evaluation criteria.
+ The better the criteria, the closer it is to 1 and the worse the criteria, the closer it is to 0.
  '''
  groundedness_prompt = ChatPromptTemplate.from_messages([
  ("system", system_message),
  ("user", "Context: {context}\nResponse: {response}\n\nGroundedness score:")
  ])
+ chain = groundedness_prompt | st.session_state.llm | StrOutputParser() # Use llm from session state
  groundedness_score = float(chain.invoke({
  "context": "\n".join([doc["content"] for doc in state['context']]),
+ "response": state['response']
  }))
  state['groundedness_score'] = groundedness_score
+ state['groundedness_loop_count'] += 1
+ st.write(f"Groundedness score: {groundedness_score}")
  return state

+ def check_precision(state: AgentState) -> AgentState:
+ st.write("---Checking Precision---")
  system_message = '''
  Given the question, answer and context, verify if the context was useful in arriving at the given answer.
  Give verdict as "1" if useful and "0" if not useful.

  0 or near 0 if it is least useful, 0.5 or near 0.5 if retry is warranted, and 1 or close to 1 is most useful.
  Do not show any instructions for deriving your answer.
  '''
  precision_prompt = ChatPromptTemplate.from_messages([
  ("system", system_message),
  ("user", "Query: {query}\nResponse: {response}\n\nPrecision score:")
  ])
+ chain = precision_prompt | st.session_state.llm | StrOutputParser() # Use llm from session state
  precision_score = float(chain.invoke({
  "query": state['query'],
+ "response": state['response']
  }))
  state['precision_score'] = precision_score
  state['precision_loop_count'] += 1
+ st.write(f"Precision score: {precision_score}")
  return state

+ def refine_response(state: AgentState) -> AgentState:
+ st.write("---Refining Response---")
  system_message = '''
+ Since the last response failed the groundedness test, and is deemed not satisfactory,
  use the feedback in terms of the query, context and the last response
  to identify potential gaps, ambiguities, or missing details, and
  to suggest improvements to enhance accuracy and completeness of the response.
  '''
  refine_response_prompt = ChatPromptTemplate.from_messages([
  ("system", system_message),
  ("user", "Query: {query}\nResponse: {response}\n\n"
  "What improvements can be made to enhance accuracy and completeness?")
  ])
+ chain = refine_response_prompt | st.session_state.llm | StrOutputParser() # Use llm from session state
  feedback = f"Previous Response: {state['response']}\nSuggestions: {chain.invoke({'query': state['query'], 'response': state['response']})}"
  state['feedback'] = feedback
+ st.write(f"Refinement feedback:\n{feedback}")
  return state

+ def refine_query(state: AgentState) -> AgentState:
+ st.write("---Refining Query---")
  system_message = '''
+ Since the last response failed the precision test, and is deemed not satisfactory,
  use the feedback in terms of the query, context and re-generate extended queries
  to identify specific keywords, scope refinements, or missing details, and
  to provide structured suggestions for improvement to enhance accuracy and completeness of the response.
  '''
  refine_query_prompt = ChatPromptTemplate.from_messages([
  ("system", system_message),
  ("user", "Original Query: {query}\nExpanded Query: {expanded_query}\n\n"
  "What improvements can be made for a better search?")
  ])
+ chain = refine_query_prompt | st.session_state.llm | StrOutputParser() # Use llm from session state
  query_feedback = f"Previous Expanded Query: {state['expanded_query']}\nSuggestions: {chain.invoke({'query': state['query'], 'expanded_query': state['expanded_query']})}"
  state['query_feedback'] = query_feedback
+ st.write(f"Query refinement feedback:\n{query_feedback}")
  return state
449
 
def should_continue_groundedness(state: AgentState) -> str:
    st.write("---Deciding Groundedness Continuation---")
    st.write(f"Groundedness loop count: {state['groundedness_loop_count']}")
    if state['groundedness_score'] >= 0.8:
        st.write("Moving to precision check.")
        return "check_precision"
    else:
        if state["groundedness_loop_count"] >= state['loop_max_iter']:
            st.write("Max iterations reached for groundedness.")
            return "max_iterations_reached"
        else:
            st.write("Groundedness score not met. Refining response.")
            return "refine_response"

def should_continue_precision(state: AgentState) -> str:
    st.write("---Deciding Precision Continuation---")
    st.write(f"Precision loop count: {state['precision_loop_count']}")
    if state['precision_score'] > 0.8:
        st.write("Precision sufficient. Ending workflow.")
        return "pass"
    else:
        if state["precision_loop_count"] >= state['loop_max_iter']:
            st.write("Max iterations reached for precision.")
            return "max_iterations_reached"
        else:
            st.write("Precision score not met. Refining query.")
            return "refine_query"
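# Hedged tuning note: the 0.8 cut-offs above and loop_max_iter (set to 3 in
# agentic_rag_tool) are the main quality/latency knobs. For example, a looser
# precision gate accepts slightly less precise answers but saves
# refine_query -> expand_query round trips:
#
#   if state['precision_score'] > 0.7:  # looser gate, fewer retries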
 
def max_iterations_reached(state: AgentState) -> AgentState:
    st.write("---Max Iterations Reached---")
    response = "I'm unable to refine the response further. Please provide more context or clarify your question."
    state['response'] = response
    return state
 
@tool
def agentic_rag_tool(query: str) -> Dict[str, Any]:
    """
    Runs the RAG-based agent workflow to generate context-aware responses.
    This function is exposed as a tool for the overall chatbot.
    """
    # Initialize state for the LangGraph workflow
    inputs = {
        "query": query,
        "expanded_query": "",
        "context": [],
        "response": "",
        "precision_score": 0.0,
        "groundedness_score": 0.0,
        "groundedness_loop_count": 0,
        "precision_loop_count": 0,
        "feedback": "",
        "query_feedback": "",
        "loop_max_iter": 3
    }
    # Compile the workflow once and store it in session state if not already done
    if 'workflow_app' not in st.session_state:
        st.session_state.workflow_app = create_rag_workflow().compile()

    # Invoke the compiled LangGraph workflow
    output = st.session_state.workflow_app.invoke(inputs)
    return output
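# Hedged usage sketch (illustrative only): a LangChain @tool can be exercised
# directly with .invoke(), which makes it easy to debug the graph without the
# full chatbot. The sample query is made up, and this must run inside a
# Streamlit session because the workflow and LLM live in st.session_state.
#
#   result = agentic_rag_tool.invoke({"query": "What causes iron-deficiency anemia?"})
#   print(result["response"], result["groundedness_score"], result["precision_score"])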
 
def create_rag_workflow() -> StateGraph:
    """Creates the LangGraph workflow for the RAG agent."""
    workflow = StateGraph(AgentState)

    workflow.add_node("expand_query", expand_query)
    workflow.add_node("retrieve_context", retrieve_context)
    workflow.add_node("craft_response", craft_response)
    workflow.add_node("score_groundedness", score_groundedness)
    workflow.add_node("refine_response", refine_response)
    workflow.add_node("check_precision", check_precision)
    workflow.add_node("refine_query", refine_query)
    workflow.add_node("max_iterations_reached", max_iterations_reached)

    workflow.add_edge(START, "expand_query")
    workflow.add_edge("expand_query", "retrieve_context")
    workflow.add_edge("retrieve_context", "craft_response")
    workflow.add_edge("craft_response", "score_groundedness")

    workflow.add_conditional_edges(
        "score_groundedness",
        should_continue_groundedness,
        {
            "check_precision": "check_precision",
            "refine_response": "refine_response",
            "max_iterations_reached": "max_iterations_reached"
        }
    )
    workflow.add_edge("refine_response", "craft_response")

    workflow.add_conditional_edges(
        "check_precision",
        should_continue_precision,
        {
            "pass": END,
            "refine_query": "refine_query",
            "max_iterations_reached": "max_iterations_reached"
        }
    )
    workflow.add_edge("refine_query", "expand_query")

    workflow.add_edge("max_iterations_reached", END)

    return workflow
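# Hedged debugging sketch: a compiled LangGraph app also exposes .stream(),
# which yields one update per executed node and makes it easy to watch the
# groundedness/precision loops converge. Illustrative only; `inputs` is the
# same dict built in agentic_rag_tool.
#
#   app = create_rag_workflow().compile()
#   for step in app.stream(inputs):
#       print(list(step.keys()))  # name of the node that just ran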
 
# --- 4. Main Chatbot Class (Integrating Memory & Agent) ---

class NutritionBot:
    def __init__(self, config: Dict):
        """
        Initialize the NutritionBot class, setting up memory, the LLM client, tools, and the agent executor.
        """
        mem0_api_key = config["MEM0_API_KEY"]
        openai_api_key = config["AZURE_OPENAI_API_KEY"]
        openai_api_base = config["AZURE_OPENAI_API_BASE"]

        # Initialize a memory client to store and retrieve customer interactions
        self.memory = MemoryClient(api_key=mem0_api_key)

        # Initialize the OpenAI client (LangChain ChatOpenAI)
        self.client = ChatOpenAI(
            model="gpt-4o-mini",
            openai_api_key=openai_api_key,
            openai_api_base=openai_api_base,
            temperature=0
        )
        # Store LLM in session state for use in graph nodes
        st.session_state.llm = self.client

        # Define tools available to the chatbot
        tools = [agentic_rag_tool]

        # Define the system prompt for the agent
        system_prompt = """You are a caring and knowledgeable Medical Support Agent, specializing in nutrition disorder-related guidance. Your goal is to provide accurate, empathetic, and tailored nutritional recommendations while ensuring a seamless customer experience.
        Guidelines for Interaction:
        Maintain a polite, professional, and reassuring tone.
        Engage with the customer by asking about their food preferences, dietary restrictions, and lifestyle before offering recommendations.
        Ensure consistent and accurate information across conversations.
        If any detail is unclear or missing, proactively ask for clarification.
        Always use the agentic_rag_tool to retrieve up-to-date and evidence-based nutrition insights.
        Keep track of ongoing issues and follow-ups to ensure continuity in support.
        Your primary goal is to help customers make informed nutrition decisions that align with their health conditions and personal preferences.
        """

        # Build the prompt template for the agent
        prompt = ChatPromptTemplate.from_messages([
            ("system", system_prompt),
            ("human", "{input}"),
            ("placeholder", "{agent_scratchpad}")
        ])

        # Create an agent capable of interacting with tools and executing tasks
        agent = create_tool_calling_agent(self.client, tools, prompt)
        self.agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
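    # Hedged sketch of what the executor does per turn (conceptual summary,
    # not the LangChain internals): the LLM either answers directly or emits a
    # tool call; AgentExecutor runs the tool, feeds the observation back via
    # {agent_scratchpad}, and loops until the LLM returns a final answer.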
 
    def store_customer_interaction(self, user_id: str, message: str, response: str, metadata: Dict = None):
        """Store customer interaction in memory for future reference."""
        if metadata is None:
            metadata = {}
        metadata["timestamp"] = datetime.now().isoformat()

        conversation = [
            {"role": "user", "content": message},
            {"role": "assistant", "content": response}
        ]
        self.memory.add(conversation, user_id=user_id, output_format="v1.1", metadata=metadata)
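    # Hedged usage sketch (illustrative): interactions can carry custom
    # metadata for later filtering; the "channel" key below is made up.
    #
    #   bot.store_customer_interaction(
    #       user_id="demo-user",
    #       message="I get dizzy when I skip meals.",
    #       response="That can be a sign of low blood sugar...",
    #       metadata={"type": "support_query", "channel": "web"},
    #   )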
 
    def get_relevant_history(self, user_id: str, query: str) -> List[Dict]:
        """Retrieve past interactions relevant to the current query."""
        return self.memory.search(query=query, user_id=user_id, limit=5)
 
    def handle_customer_query(self, user_id: str, query: str) -> str:
        """Process a customer's query and provide a response, taking into account past interactions."""
        relevant_history = self.get_relevant_history(user_id, query)

        context = "Previous relevant interactions:\n"
        for memory_item in relevant_history:
            # Mem0's 'memory' field is typically a list of dicts or a string.
            # Assuming the 'v1.1' output format from `memory.add` means `memory_item['memory']` is structured.
            if isinstance(memory_item.get('memory'), list):
                for part in memory_item['memory']:
                    context += f"{part['role'].capitalize()}: {part['content']}\n"
            else:  # Fallback for older formats or if it's a simple string
                context += f"History: {memory_item.get('memory', 'N/A')}\n"
            context += "---\n"

        prompt = f"""
        Context:
        {context}
        Current Query: {query}

        Provide a helpful response that takes into account any relevant past interactions.
        """
        st.write("Prompt sent to agent executor:", prompt)  # Debugging

        try:
            response_dict = self.agent_executor.invoke({"input": prompt})
            response_content = response_dict.get('output', "I'm sorry, I couldn't generate a response.")
        except Exception as e:
            st.error(f"Error during agent execution: {e}")
            response_content = f"I'm sorry, I encountered an internal error: {e}"

        self.store_customer_interaction(user_id=user_id, message=query, response=response_content, metadata={"type": "support_query"})
        return response_content
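# Hedged usage sketch (illustrative, not executed): within a Streamlit session,
# the bot can be driven directly once a config dict with MEM0_API_KEY,
# AZURE_OPENAI_API_KEY, and AZURE_OPENAI_API_BASE is available. The user id
# and query below are made up.
#
#   bot = NutritionBot(st.session_state.config)
#   answer = bot.handle_customer_query("demo-user", "Which foods help with a B12 deficiency?")
#   st.write(answer)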
 
# --- 5. Streamlit UI ---

def nutrition_disorder_streamlit_app():
    """Streamlit-based UI for the Nutrition Disorder Specialist Agent."""
    st.set_page_config(page_title="Nutrition Disorder Specialist", layout="centered")

    st.title("👨‍⚕️ Nutrition Disorder Specialist")
    st.markdown("Ask me anything about nutrition disorders, symptoms, causes, treatments, and more.")
    st.markdown("---")
 
    # Initialize session state variables
    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []
    if 'user_id' not in st.session_state:
        st.session_state.user_id = None
    if 'chatbot' not in st.session_state:
        st.session_state.chatbot = None
    if 'config_loaded' not in st.session_state:
        st.session_state.config_loaded = False
    if 'vector_store_loaded' not in st.session_state:
        st.session_state.vector_store_loaded = False

    # --- Configuration Loading and Model Initialization ---
    if not st.session_state.config_loaded:
        with st.spinner("Loading configurations and initializing models..."):
            config = load_config_from_env()
            if not all(config.values()):
                st.error("Some environment variables are missing. Please set them up for the app to function.")
                st.stop()  # Stop execution if critical configs are missing

            # Step 1. Initialize the embedding model, LLM, embedding function, and Llama Guard client
            embedding_model, llm_instance, embedding_function, llama_guard_client_instance = initialize_llms_and_embeddings(config)

            # Step 2. Store initialized objects in session state
            st.session_state.config = config
            st.session_state.embedding_model = embedding_model
            st.session_state.llm = llm_instance
            st.session_state.embedding_function = embedding_function  # Used during vector_store creation/loading
            st.session_state.llama_guard_client = llama_guard_client_instance
            st.session_state.config_loaded = True
            st.rerun()  # Rerun to update UI after loading
    # --- Vector Store Loading ---
    if st.session_state.config_loaded and not st.session_state.vector_store_loaded:
        with st.spinner("Loading nutrition knowledge base (vector database)..."):
            # Ensure the nutritional_db directory exists relative to app.py.
            # In Docker, this means the folder should be copied into /app.
            persist_dir = "./nutritional_db"
            if not os.path.exists(persist_dir):
                st.error(f"Required data directory '{persist_dir}' not found. Please ensure it's copied into the Docker image.")
                st.stop()

            st.session_state.vector_store = load_vector_db(
                embedding_model=st.session_state.embedding_model,
                collection_name="nutritional_hypotheticals",
                persist_directory=persist_dir
            )
            if st.session_state.vector_store is None:
                st.error("Failed to load vector database. Chat functionality will be limited.")
            st.session_state.vector_store_loaded = True
            st.rerun()  # Rerun to update UI after loading
    # --- Login Form ---
    if st.session_state.user_id is None:
        with st.form("login_form", clear_on_submit=True):
            user_id_input = st.text_input("Please enter your name to begin:", key="user_id_input")
            submit_button = st.form_submit_button("Login")
            if submit_button and user_id_input:
                st.session_state.user_id = user_id_input.strip()
                st.session_state.chat_history.append({
                    "role": "assistant",
                    "content": f"Welcome, {st.session_state.user_id}! How can I help you with nutrition disorders today?"
                })
                # Initialize the chatbot only after the config and vector store are ready
                if st.session_state.config_loaded and st.session_state.vector_store_loaded:
                    st.session_state.chatbot = NutritionBot(st.session_state.config)
                else:
                    st.warning("Chatbot initialization delayed because configurations or the vector store are still loading.")
                st.rerun()
    # --- Chat Interface ---
    elif st.session_state.chatbot:  # Only show chat if the chatbot is initialized
        # Display chat history
        for message in st.session_state.chat_history:
            with st.chat_message(message["role"]):
                st.write(message["content"])

        user_query = st.chat_input("Type your question here (e.g., 'What are dietary deficiencies?')")

        if user_query:
            if user_query.lower() == "exit":
                st.session_state.chat_history.append({"role": "user", "content": "exit"})
                with st.chat_message("user"):
                    st.write("exit")
                goodbye_msg = "It was nice talking to you! Goodbye."
                st.session_state.chat_history.append({"role": "assistant", "content": goodbye_msg})
                with st.chat_message("assistant"):
                    st.write(goodbye_msg)
                st.session_state.user_id = None  # Log out
                st.session_state.chatbot = None  # Clear chatbot instance
                st.session_state.chat_history = []  # Clear history on exit
                st.rerun()
                return

            st.session_state.chat_history.append({"role": "user", "content": user_query})
            with st.chat_message("user"):
                st.write(user_query)

            # Filter input using Llama Guard
            with st.spinner("Filtering input for safety..."):
                filtered_result = filter_input_with_llama_guard(user_query, st.session_state.llama_guard_client)
                if filtered_result:
                    filtered_result = filtered_result.replace("\n", " ").strip()
                    st.info(f"Llama Guard says: {filtered_result}")  # Show Llama Guard's verdict

            # Process the user query if safe. A plain substring test for "safe"
            # would also match "unsafe", so check the start of the verdict
            # instead; "S7"-tagged verdicts (e.g. "safe S7") remain allowed.
            verdict = filtered_result.lower() if filtered_result else ""
            if verdict.startswith("safe") or "s7" in verdict:
                with st.spinner("Thinking..."):
                    try:
                        response = st.session_state.chatbot.handle_customer_query(st.session_state.user_id, user_query)
                        st.session_state.chat_history.append({"role": "assistant", "content": response})
                        with st.chat_message("assistant"):
                            st.write(response)
                    except Exception as e:
                        error_msg = f"Sorry, I encountered an error while processing your query. Please try again. Error: {e}"
                        st.error(error_msg)
                        st.session_state.chat_history.append({"role": "assistant", "content": error_msg})
            else:
                inappropriate_msg = "I apologize, but I cannot process that input as it may be inappropriate or unsafe. Please try again."
                st.session_state.chat_history.append({"role": "assistant", "content": inappropriate_msg})
                with st.chat_message("assistant"):
                    st.write(inappropriate_msg)
            st.rerun()  # Rerun to update chat history instantly
    elif st.session_state.user_id:  # User is logged in but the chatbot is not ready yet
        st.info("Initializing chatbot. Please wait...")

# --- Main entry point for Streamlit App ---
if __name__ == "__main__":
    nutrition_disorder_streamlit_app()