# nutribot / app.py
# Author: arajshiva
# Uploaded to Hugging Face Spaces with huggingface_hub (commit 41421cb, verified)
######################## WRITE YOUR CODE HERE #########################
# Import necessary libraries
import os # Interacting with the operating system (reading/writing files)
import chromadb # High-performance vector database for storing/querying dense vectors
from dotenv import load_dotenv # Loading environment variables from a .env file
import json # Parsing and handling JSON data
# LangChain imports
from langchain_core.documents import Document # Document data structures
from langchain_core.runnables import RunnablePassthrough # LangChain core library for running pipelines
from langchain_core.output_parsers import StrOutputParser # String output parser
from langchain.prompts import ChatPromptTemplate # Template for chat prompts
from langchain.chains.query_constructor.base import AttributeInfo # Base classes for query construction
from langchain.retrievers.self_query.base import SelfQueryRetriever # Base classes for self-querying retrievers
from langchain.retrievers.document_compressors import LLMChainExtractor, CrossEncoderReranker # Document compressors
from langchain.retrievers import ContextualCompressionRetriever # Contextual compression retrievers
# LangChain community & experimental imports
from langchain_community.vectorstores import Chroma # Implementations of vector stores like Chroma
from langchain_community.document_loaders import PyPDFDirectoryLoader, PyPDFLoader # Document loaders for PDFs
from langchain_community.cross_encoders import HuggingFaceCrossEncoder # Cross-encoders from HuggingFace
from langchain_experimental.text_splitter import SemanticChunker # Experimental text splitting methods
from langchain.text_splitter import (
CharacterTextSplitter, # Splitting text by characters
RecursiveCharacterTextSplitter # Recursive splitting of text by characters
)
from langchain_core.tools import tool
from langchain.agents import create_tool_calling_agent, AgentExecutor
from langchain_core.prompts import ChatPromptTemplate
# LangChain OpenAI imports
from langchain_openai import AzureOpenAIEmbeddings, AzureChatOpenAI # OpenAI embeddings and models
from langchain.embeddings.openai import OpenAIEmbeddings # OpenAI embeddings for text vectors
# LlamaParse & LlamaIndex imports
from llama_parse import LlamaParse # Document parsing library
from llama_index.core import Settings, SimpleDirectoryReader # Core functionalities of the LlamaIndex
# LangGraph import
from langgraph.graph import StateGraph, END, START # State graph for managing states in LangChain
# Pydantic import
from pydantic import BaseModel # Pydantic for data validation
# Typing imports
from typing import Dict, List, Tuple, Any, TypedDict # Python typing for function annotations
# Other utilities
import numpy as np # Numpy for numerical operations
from groq import Groq
from mem0 import MemoryClient
import streamlit as st
from datetime import datetime
#====================================SETUP=====================================#
# Fetch secrets from environment variables (configured as Hugging Face Spaces
# secrets). Each lookup raises KeyError at import time if the variable is unset.
api_key = os.environ['AZURE_OPENAI_API_KEY']  # Azure OpenAI key for the chat models
endpoint = os.environ['AZURE_OPENAI_ENDPOINT']  # Azure OpenAI endpoint for the chat models
api_version = os.environ['AZURE_OPENAI_APIVERSION']  # used by NutritionBot's client; `llm` below hard-codes its own version
model_name = os.environ['CHATGPT_MODEL']  # used by NutritionBot's client only
emb_key = os.environ['EMB_MODEL_KEY']  # key for the embedding deployment
emb_endpoint = os.environ['EMB_DEPLOYMENT']  # endpoint passed to both embedding clients below
# Groq API key for the Llama Guard input filter; the commented-out line is an
# alternative environment-variable name.
#llama_api_key = os.environ['GROQ_API_KEY']
llama_api_key = os.environ['LLAMA_API_KEY']
# Chroma-native OpenAI embedding function pointed at the Azure embedding deployment.
# NOTE(review): `embedding_function` is not referenced anywhere else in this file;
# the Chroma vector store here is built with `embedding_model` instead. Confirm
# whether this object is still needed.
embedding_function = chromadb.utils.embedding_functions.OpenAIEmbeddingFunction(
    api_base= emb_endpoint,  # Azure embedding endpoint
    api_key= emb_key,  # Azure embedding key
    api_type='azure',  # fixed value
    api_version='2023-05-15',  # fixed value
    model_name='text-embedding-ada-002'  # fixed value
)
# LangChain embedding client for the same Azure deployment. This is the
# embedding object registered with the LlamaIndex Settings below.
embedding_model = AzureOpenAIEmbeddings(
    azure_endpoint= emb_endpoint,  # Azure embedding endpoint
    api_key= emb_key,  # Azure embedding key
    api_version='2023-05-15',  # fixed value
    model='text-embedding-ada-002'  # fixed value
)
# Chat model used by the LLM-backed LangGraph nodes in this file; temperature=0
# minimizes response variability (useful for the numeric scoring steps).
# NOTE(review): the API version and deployment ('gpt-4o') are hard-coded here
# instead of using `api_version` / `model_name` read above; confirm intended.
llm = AzureChatOpenAI(
    azure_endpoint=endpoint,
    api_key=api_key,
    api_version='2024-05-01-preview',
    azure_deployment='gpt-4o',
    temperature=0
)
# Register the models with LlamaIndex's global Settings.
# NOTE(review): LlamaIndex documents the embedding slot as `Settings.embed_model`;
# `Settings.embedding` appears not to be read by LlamaIndex. Settings is not used
# anywhere else in this file, so confirm whether these assignments are needed.
Settings.llm = llm
Settings.embedding = embedding_model
#================================Creating Langgraph agent======================#
# Shared state dict passed through every node of the LangGraph workflow: each
# node reads some keys, writes others, and returns the same dict.
AgentState = TypedDict(
    "AgentState",
    {
        "query": str,  # the user's original question
        "expanded_query": str,  # LLM-expanded reformulation(s) used for retrieval
        "context": List[Dict[str, Any]],  # retrieved chunks as {"content": ..., "metadata": ...}
        "response": str,  # latest generated answer
        "precision_score": float,  # most recent precision verdict
        "groundedness_score": float,  # most recent groundedness score
        "groundedness_loop_count": int,  # number of groundedness scorings so far
        "precision_loop_count": int,  # number of precision checks so far
        "feedback": str,  # answer-refinement suggestions fed back into response drafting
        "query_feedback": str,  # query-refinement suggestions fed back into query expansion
        "groundedness_check": bool,  # declared, but no node in this file writes it
        "loop_max_iter": int,  # iteration budget for the refinement loops
    },
)
def expand_query(state):
    """
    Rewrite the user's question into several domain-specific search queries.

    The LLM is asked for alternative phrasings (and a split of multi-part
    questions) so retrieval from the nutrition reference store has a better
    chance of matching. Any ``query_feedback`` produced by ``refine_query`` on a
    previous pass is included in the request.

    Args:
        state (Dict): Workflow state; reads ``query`` and ``query_feedback``.

    Returns:
        Dict: The same state object with ``expanded_query`` set to the raw LLM
        output (one string containing the list of questions).
    """
    print("---------Expanding Query---------")
    system_message = """
You are a domain expert assisting in answering questions related to nutrition disorder-related information.
Convert the user query into something that a nutritionist would understand. Use domain related words.
Perform query expansion on the question received. If there are multiple common ways of phrasing a user question \
or common synonyms for key words in the question, make sure to return multiple versions \
of the query with the different phrasings.
If the query has multiple parts, split them into separate simpler queries. This is the only case where you can generate more than 3 queries.
If there are acronyms or words you are not familiar with, do not try to rephrase them.
Return only 3 versions of the question as a list.
Generate only a list of questions. Do not mention anything before or after the list.
"""
    expansion_chain = (
        ChatPromptTemplate.from_messages([
            ("system", system_message),
            ("user", "Expand this query: {query} using the feedback: {query_feedback}"),
        ])
        | llm
        | StrOutputParser()
    )
    prompt_inputs = {
        "query": state['query'],
        "query_feedback": state["query_feedback"],
    }
    expanded = expansion_chain.invoke(prompt_inputs)
    print("expanded_query", expanded)
    state["expanded_query"] = expanded
    return state
print("Current Working Directory:", os.getcwd())
# Resolve the Chroma persistence directory relative to this file rather than the
# process working directory. A bare "./research_db" depends on where the app was
# launched from, and may silently open (and create) an empty store elsewhere.
# NOTE(review): assumes research_db ships alongside this file; confirm.
RESEARCH_DB_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "research_db")
print("Chroma persist directory:", RESEARCH_DB_DIR)

# Initialize the Chroma vector store for the "nutritional-medical-reference" collection.
vector_store = Chroma(
    collection_name="nutritional-medical-reference",
    persist_directory=RESEARCH_DB_DIR,
    embedding_function=embedding_model,
)
# Similarity-search retriever returning the top 3 chunks per query.
retriever = vector_store.as_retriever(
    search_type='similarity',
    search_kwargs={'k': 3},
)
def retrieve_context(state):
    """
    Retrieve supporting passages from the vector store.

    Uses the expanded query when it is non-blank and falls back to the original
    user query otherwise (e.g. if query expansion produced an empty string), so
    retrieval never runs on an empty search string.

    Args:
        state (Dict): Workflow state; reads ``expanded_query`` and ``query``.

    Returns:
        Dict: The same state with ``context`` set to a list of
        ``{"content": page_content, "metadata": metadata}`` dicts.
    """
    print("---------retrieve_context---------")
    expanded = state.get('expanded_query') or ""
    query = expanded if expanded.strip() else state['query']
    # Retrieve documents from the vector store.
    docs = retriever.invoke(query)
    print("Retrieved documents:", docs)  # Debugging: raw retriever output
    # Keep both the text and its metadata (source, page, ...) for each hit.
    context = [
        {
            "content": doc.page_content,
            "metadata": doc.metadata,
        }
        for doc in docs
    ]
    state['context'] = context
    print("Extracted context with metadata:", context)  # Debugging
    return state
def craft_response(state: Dict) -> Dict:
    """
    Draft an answer to the user's query from the retrieved context.

    Args:
        state (Dict): Workflow state; reads ``query``, ``context`` (a list of
            ``{"content", "metadata"}`` dicts) and ``feedback`` (suggestions from
            ``refine_response``; empty on the first pass).

    Returns:
        Dict: The same state with ``response`` set to the answer as plain text.
    """
    print("---------craft_response---------")
    system_message = """
You are a knowledgeable nutritionist specialized in nutrition and health.
Use the provided context to generate a helpful, accurate, and empathetic response to the user's query.
Focus on identifying, explaining, or addressing nutrition disorders where relevant. Be clear and concise.
"""
    response_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Query: {query}\nContext: {context}\n\nfeedback: {feedback}")
    ])
    # StrOutputParser converts the chat model's message object into plain text.
    # Without it, `response` would be the message object itself rather than the
    # str that AgentState declares. The groundedness/precision prompts and the
    # feedback f-string would then render its full string form, including
    # metadata, instead of just the answer.
    chain = response_prompt | llm | StrOutputParser()
    response = chain.invoke({
        "query": state['query'],
        "context": "\n".join(doc["content"] for doc in state['context']),
        "feedback": state.get('feedback', ""),
    })
    state['response'] = response
    print("intermediate response: ", response)
    return state
def score_groundedness(state: Dict) -> Dict:
    """
    Score how well the current response is supported by the retrieved context.

    Args:
        state (Dict): Workflow state; reads ``context`` and ``response`` and
            increments ``groundedness_loop_count``.

    Returns:
        Dict: The same state with ``groundedness_score`` set to a float in
        [0, 1] (0.0 when the evaluator's reply contains no number).
    """
    import re  # only needed to pull the number out of the evaluator's reply

    print("---------check_groundedness---------")
    system_message = '''You are an objective evaluator tasked with scoring the groundedness of a response
based on the retrieved context provided.
Definition of "groundedness":
- A response is considered grounded if it strictly uses information present in the provided context.
- It should avoid hallucinating, fabricating, or introducing any claims that are not explicitly supported by the context.
Scoring Guidelines:
- Return a numeric score between 0 and 1.
- 1.0: The response is entirely grounded in the context.
- 0.5: The response is partially grounded (some parts supported, others not).
- 0.0: The response is not grounded at all (hallucinated or irrelevant).
Important:
- Do NOT explain your score.
- Do NOT provide justification.
- ONLY return the score as a number (e.g., 1.0, 0.5, or 0.0).
'''
    groundedness_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Context: {context}\nResponse: {response}\n\nGroundedness score:")
    ])
    chain = groundedness_prompt | llm | StrOutputParser()
    raw_score = chain.invoke({
        "context": "\n".join(doc["content"] for doc in state['context']),
        "response": state['response'],
    })
    # The evaluator is told to reply with a bare number, but a reply such as
    # "Score: 0.8" or "**1.0**" would make float() raise and abort the whole
    # graph run. Take the first number in the reply and clamp it to [0, 1]. An
    # unparseable reply counts as 0.0, which routes into the bounded refine loop.
    match = re.search(r"\d+(?:\.\d+)?|\.\d+", raw_score)
    if match:
        groundedness_score = min(max(float(match.group()), 0.0), 1.0)
    else:
        print("Could not parse groundedness score from:", repr(raw_score))
        groundedness_score = 0.0
    print("groundedness_score: ", groundedness_score)
    state['groundedness_loop_count'] += 1
    print("#########Groundedness Incremented###########")
    state['groundedness_score'] = groundedness_score
    return state
def check_precision(state: Dict) -> Dict:
    """
    Judge whether the retrieved context was useful for producing the response.

    The evaluator gives a binary verdict ("1" useful, "0" not) from the
    question, the answer and the context. All three are sent in the user message
    so the model can see the context it is asked to judge.

    Args:
        state (Dict): Workflow state; reads ``query``, ``context`` and
            ``response`` and increments ``precision_loop_count``.

    Returns:
        Dict: The same state with ``precision_score`` set to a float in [0, 1]
        (0.0 when the evaluator's reply contains no number).
    """
    import re  # only needed to pull the verdict out of the evaluator's reply

    print("---------check_precision---------")
    system_message = '''Given question, answer and context verify if the context was useful in arriving at the given answer.
Give verdict as "1" if useful and "0" if not '''
    precision_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Query: {query}\nContext: {context}\nResponse: {response}\n\nPrecision score:")
    ])
    chain = precision_prompt | llm | StrOutputParser()
    raw_score = chain.invoke({
        "query": state['query'],
        "context": "\n".join(doc["content"] for doc in (state.get('context') or [])),
        "response": state['response'],
    })
    # Tolerate replies like "Verdict: 1" instead of crashing in float(); clamp to
    # [0, 1] and treat an unparseable reply as "not useful" (0.0).
    match = re.search(r"\d+(?:\.\d+)?|\.\d+", raw_score)
    if match:
        precision_score = min(max(float(match.group()), 0.0), 1.0)
    else:
        print("Could not parse precision score from:", repr(raw_score))
        precision_score = 0.0
    state['precision_score'] = precision_score
    print("precision_score:", precision_score)
    state['precision_loop_count'] += 1
    print("#########Precision Incremented###########")
    return state
def refine_response(state: Dict) -> Dict:
    """
    Ask the LLM for concrete suggestions to improve the current response.

    The suggestions, prefixed with the response they refer to, are written to
    ``state['feedback']``, which ``craft_response`` includes in its next prompt.

    Args:
        state (Dict): Workflow state; reads ``query`` and ``response``.

    Returns:
        Dict: The same state with ``feedback`` updated.
    """
    print("---------refine_response---------")
    system_message = '''You are a response refinement expert tasked with reviewing and improving AI-generated answers.
Your role is to:
- Carefully analyze the given response in light of the original user query.
- Identify any factual inaccuracies, gaps, or lack of clarity.
- Suggest improvements that make the response more complete, precise, and aligned with the query intent.
Guidelines:
- Be constructive and focused.
- Suggest rewordings, additions, or clarifications where needed.
- Highlight if any information is missing or should be cited.
- Avoid introducing new facts unless they are universally accepted and directly relevant.
Output Format:
- ONLY return specific suggestions for improving the response.
- Do NOT rewrite the full response.
- Do NOT return general praise. Focus on actionable refinements.'''
    critique_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Query: {query}\nResponse: {response}\n\n"
                 "What improvements can be made to enhance accuracy and completeness?"),
    ])
    critic = critique_prompt | llm | StrOutputParser()

    previous_response = state['response']
    suggestions = critic.invoke({'query': state['query'], 'response': previous_response})
    feedback = f"Previous Response: {previous_response}\nSuggestions: {suggestions}"

    print("feedback: ", feedback)
    print(f"State: {state}")
    state['feedback'] = feedback
    return state
def refine_query(state: Dict) -> Dict:
    """
    Ask the LLM how the expanded search query could be improved.

    The expanded query itself is left untouched. The suggestions, prefixed with
    the expanded query they refer to, go into ``state['query_feedback']``, which
    ``expand_query`` passes to the LLM on the next pass.

    Args:
        state (Dict): Workflow state; reads ``query``, ``expanded_query`` and
            ``groundedness_loop_count`` (for logging).

    Returns:
        Dict: The same state with ``query_feedback`` updated.
    """
    print("---------refine_query---------")
    system_message = '''
You are an expert in information retrieval and query optimization.
Your job is to analyze an expanded search query that was generated from a user's original question, and suggest specific improvements that will help a search or retrieval system return more relevant, high-quality results.
Guidelines:
- Ensure the expanded query is clear, concise, and aligned with the user's original intent.
- Eliminate any ambiguity or redundancy.
- Suggest adding important synonyms, rephrasings, or domain-specific terminology if helpful.
- Avoid suggesting overly broad or overly narrow queries.
- Do NOT rewrite the query. Just offer targeted suggestions for improvement.
Output Format:
- Provide bullet-point suggestions for improving the expanded query.
- Focus on changes that will improve retrieval quality without losing the user's intent.
'''
    advisor = (
        ChatPromptTemplate.from_messages([
            ("system", system_message),
            ("user", "Original Query: {query}\nExpanded Query: {expanded_query}\n\n"
                     "What improvements can be made for a better search?"),
        ])
        | llm
        | StrOutputParser()
    )

    current_expansion = state['expanded_query']
    advice = advisor.invoke({'query': state['query'], 'expanded_query': current_expansion})
    query_feedback = f"Previous Expanded Query: {current_expansion}\nSuggestions: {advice}"

    print("query_feedback: ", query_feedback)
    print(f"Groundedness loop count: {state['groundedness_loop_count']}")
    state['query_feedback'] = query_feedback
    return state
def should_continue_groundedness(state):
    """
    Route the workflow after groundedness scoring.

    Returns:
        "check_precision" when ``groundedness_score`` is at least 0.5;
        otherwise "max_iterations_reached" once ``groundedness_loop_count``
        exceeds ``loop_max_iter``, else "refine_response".
    """
    print("---------should_continue_groundedness---------")
    print("groundedness loop count: ", state['groundedness_loop_count'])

    threshold = 0.5
    if state['groundedness_score'] >= threshold:
        print("Moving to precision")
        return "check_precision"

    # Below threshold: stop once the loop budget is exhausted, otherwise refine.
    if state["groundedness_loop_count"] > state['loop_max_iter']:
        return "max_iterations_reached"
    print("---------Groundedness Score Threshold Not met. Refining Response-----------")
    return "refine_response"
def should_continue_precision(state: Dict) -> str:
    """
    Route the workflow after the precision check.

    Returns:
        "pass" when the precision verdict is 1; otherwise
        "max_iterations_reached" once ``precision_loop_count`` reaches the
        iteration budget, else "refine_query".

    The budget is read from ``state['loop_max_iter']``, the same key the
    groundedness router uses, instead of a hard-coded 3. It defaults to 3 when
    the key is absent, so existing callers see the same behavior.
    """
    print("---------should_continue_precision---------")
    print("precision loop count: ", state['precision_loop_count'])
    if state['precision_score'] >= 1.0:  # verdict "1": context was useful
        return "pass"  # Complete the workflow
    if state['precision_loop_count'] >= state.get('loop_max_iter', 3):
        return "max_iterations_reached"
    print("---------Precision Score Threshold Not met. Refining Query-----------")
    return "refine_query"
def max_iterations_reached(state: Dict) -> Dict:
    """
    Handles the case when the maximum number of iterations is reached.

    Overwrites ``state['response']`` with a fixed fallback message asking the
    user for more context, and returns the same state object.
    """
    print("---------max_iterations_reached---------")
    fallback_message = "I'm unable to refine the response further. Please provide more context or clarify your question."
    state['response'] = fallback_message
    return state
from langgraph.graph import END, StateGraph, START
def create_workflow() -> StateGraph:
    """
    Build (without compiling) the LangGraph workflow for the AI nutrition agent.

    Flow:
        START -> expand_query -> retrieve_context -> craft_response -> score_groundedness
        score_groundedness: grounded        -> check_precision
                            not grounded    -> refine_response -> craft_response
                            budget exceeded -> max_iterations_reached
        check_precision:    precise         -> END
                            imprecise       -> refine_query -> expand_query
                            budget exceeded -> max_iterations_reached
        max_iterations_reached -> END

    Returns:
        StateGraph: The uncompiled graph over ``AgentState``.
    """
    graph = StateGraph(AgentState)

    # Each node is registered under a name matching its function.
    node_table = (
        ("expand_query", expand_query),                      # 1. expand the user query
        ("retrieve_context", retrieve_context),              # 2. fetch relevant documents
        ("craft_response", craft_response),                  # 3. draft an answer
        ("score_groundedness", score_groundedness),          # 4. grade grounding
        ("refine_response", refine_response),                # 5. suggest answer fixes
        ("check_precision", check_precision),                # 6. grade precision
        ("refine_query", refine_query),                      # 7. suggest query fixes
        ("max_iterations_reached", max_iterations_reached),  # 8. fallback answer
    )
    for node_name, node_fn in node_table:
        graph.add_node(node_name, node_fn)

    # Straight-line part of the pipeline.
    main_path = [START, "expand_query", "retrieve_context", "craft_response", "score_groundedness"]
    for upstream, downstream in zip(main_path, main_path[1:]):
        graph.add_edge(upstream, downstream)

    # Groundedness loop: route names coincide with node names.
    grounded_routes = ("check_precision", "refine_response", "max_iterations_reached")
    graph.add_conditional_edges(
        "score_groundedness",
        should_continue_groundedness,
        {route: route for route in grounded_routes},
    )
    graph.add_edge("refine_response", "craft_response")

    # Precision loop.
    graph.add_conditional_edges(
        "check_precision",
        should_continue_precision,
        {
            "pass": END,
            "refine_query": "refine_query",
            "max_iterations_reached": "max_iterations_reached",
        },
    )
    graph.add_edge("refine_query", "expand_query")
    graph.add_edge("max_iterations_reached", END)
    return graph
#=========================== Defining the agentic rag tool ====================#
# Compile the workflow once at import time; every tool call reuses it.
WORKFLOW_APP = create_workflow().compile()

# Upper bound on node executions per run. With loop_max_iter=3 the worst path
# (three answer refinements plus two query refinements) takes roughly 27 node
# steps, which exceeds LangGraph's default recursion limit (25 at the time of
# writing; confirm for the installed version) and would raise instead of
# returning the fallback answer.
_WORKFLOW_RECURSION_LIMIT = 50


@tool
def agentic_rag(query: str):
    """
    Answer a question about nutrition disorders using the nutrition reference knowledge base.

    Runs a retrieval-augmented workflow for a single query: query expansion,
    document retrieval, answer drafting, and self-checks for groundedness and
    precision with bounded refinement loops.

    Args:
        query (str): The user's nutrition-related question.

    Returns:
        Dict[str, Any]: The final workflow state; its "response" field holds the answer.
    """
    # Initial workflow state: every key declared on AgentState gets a value.
    inputs = {
        "query": query,                 # current user query
        "expanded_query": "",           # filled by expand_query
        "context": [],                  # filled by retrieve_context
        "response": "",                 # filled by craft_response
        "precision_score": 0.0,
        "groundedness_score": 0.0,
        "groundedness_loop_count": 0,
        "precision_loop_count": 0,
        "feedback": "",                 # answer-refinement suggestions
        "query_feedback": "",           # query-refinement suggestions
        "groundedness_check": False,    # declared on AgentState; initialized for completeness
        "loop_max_iter": 3,             # iteration budget for the refinement loops
    }
    output = WORKFLOW_APP.invoke(inputs, config={"recursion_limit": _WORKFLOW_RECURSION_LIMIT})
    return output
#================================ Guardrails ===========================#
# Groq client used to call the Llama Guard moderation model.
llama_guard_client = Groq(api_key=llama_api_key)


def filter_input_with_llama_guard(user_input, model="meta-llama/llama-guard-4-12b"):
    """
    Classify ``user_input`` with a Llama Guard model served through Groq.

    Parameters:
    - user_input: The text provided by the user.
    - model: Groq model id of the Llama Guard variant to use
      (default "meta-llama/llama-guard-4-12b").

    Returns:
    - The model's verdict text with surrounding whitespace stripped, or None if
      the request (or reading its result) raises.
    """
    request = {
        "messages": [{"role": "user", "content": user_input}],
        "model": model,
    }
    try:
        completion = llama_guard_client.chat.completions.create(**request)
        verdict = completion.choices[0].message.content
        return verdict.strip()
    except Exception as exc:
        print(f"Error with Llama Guard: {exc}")
        return None
#============================= Adding Memory to the agent using mem0 ===============================#
class NutritionBot:
    """
    Nutrition assistant combining a tool-calling agent with mem0 memory.

    Each query is enriched with relevant past memories for the same user,
    answered by an agent whose only tool is ``agentic_rag``, and then stored back
    to memory. If the memory client cannot be created, the bot still answers
    queries, just without history.
    """

    def __init__(self):
        """
        Initialize the NutritionBot class, setting up memory, the LLM client, tools, and the agent executor.

        If the mem0 client cannot be created, an error is shown in the Streamlit
        UI and ``self.memory`` is set to None (memory features are skipped).
        """
        try:
            self.memory = MemoryClient(os.environ["mem0"])
        except Exception as e:
            st.error(f"Failed to initialize MemoryClient: {e}")
            # Leave a defined attribute so later calls degrade to "no memory"
            # instead of failing with AttributeError on self.memory.
            self.memory = None

        # Azure OpenAI chat client used by the agent (deterministic, temperature 0).
        self.client = AzureChatOpenAI(
            model_name=model_name,
            api_key=api_key,
            azure_endpoint=endpoint,
            api_version=api_version,
            temperature=0
        )

        # The agent's only tool is the agentic RAG workflow.
        tools = [agentic_rag]

        system_prompt = """You are a helpful nutrition assistant.
Answer user questions about nutrition disorders accurately, clearly, and respectfully using available information."""

        prompt = ChatPromptTemplate.from_messages([
            ("system", system_prompt),               # system instructions
            ("human", "{input}"),                    # user input (history + query)
            ("placeholder", "{agent_scratchpad}")    # intermediate tool-calling steps
        ])

        agent = create_tool_calling_agent(self.client, tools, prompt)
        self.agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)

    def store_customer_interaction(self, user_id: str, message: str, response: str, metadata: Dict = None):
        """
        Store a user/assistant exchange in memory for future reference.

        Does nothing when no memory client is available.

        Args:
            user_id (str): Unique identifier for the user.
            message (str): The user's query or message.
            response (str): The chatbot's response.
            metadata (Dict, optional): Extra metadata; a copy is stored with an
                added ISO-format "timestamp". The caller's dict is not modified.
        """
        memory = getattr(self, "memory", None)
        if memory is None:
            return

        # Copy so the timestamp below does not leak into the caller's dict.
        metadata = dict(metadata) if metadata else {}
        metadata["timestamp"] = datetime.now().isoformat()

        conversation = [
            {"role": "user", "content": message},
            {"role": "assistant", "content": response}
        ]
        memory.add(
            conversation,
            user_id=user_id,
            output_format="v1.1",
            metadata=metadata
        )

    def get_relevant_history(self, user_id: str, query: str) -> List[Dict]:
        """
        Retrieve up to 3 stored memories relevant to the current query.

        Args:
            user_id (str): Unique identifier for the user.
            query (str): The user's current query.

        Returns:
            List[Dict]: Matching memory entries; an empty list when no memory
            client is available.
        """
        memory = getattr(self, "memory", None)
        if memory is None:
            return []
        results = memory.search(
            query=query,      # search for memories related to the query
            user_id=user_id,  # restrict search to this user
            limit=3           # number of memories pulled into the prompt
        )
        # NOTE: depending on the output format, the mem0 client may wrap hits in
        # {"results": [...]}; iterating that dict directly would yield its keys.
        if isinstance(results, dict):
            results = results.get("results", [])
        return list(results or [])

    def handle_customer_query(self, user_id: str, query: str) -> str:
        """
        Answer a user's query, taking relevant past memories into account.

        Args:
            user_id (str): Unique identifier for the user.
            query (str): The user's query.

        Returns:
            str: The agent's response (also stored to memory).
        """
        relevant_history = self.get_relevant_history(user_id, query)

        # Each search hit carries one "memory" text; list it once rather than
        # repeating the same text as both a customer and a support turn.
        memory_lines = []
        for item in relevant_history:
            text = item.get("memory") if isinstance(item, dict) else item
            if text:
                memory_lines.append(f"- {text}")
        context = ("Previous relevant interactions:\n" + "\n".join(memory_lines)) if memory_lines else ""

        print("Context: ", context)  # debugging

        prompt = f"{context}\n\nUser: {query}" if context else f"User: {query}"
        response = self.agent_executor.invoke({"input": prompt})
        answer = response["output"]

        self.store_customer_interaction(
            user_id=user_id,
            message=query,
            response=answer,
            metadata={"type": "support_query"}
        )
        return answer
#=====================User Interface using streamlit ===========================#
def nutrition_disorder_streamlit():
    """
    A Streamlit-based UI for the Nutrition Disorder Specialist Agent.

    Shows a login form until the user enters a name (used as the user id for
    the chatbot), then replays the chat history and answers new messages. Each
    message is screened with Llama Guard before reaching the NutritionBot, which
    is created lazily and kept in ``st.session_state``. Typing "exit" logs out.
    """
    st.title("Nutrition Disorder Specialist")
    st.write("Ask me anything about nutrition disorders, symptoms, causes, treatments, and more.")
    st.write("Type 'exit' to end the conversation.")

    # Initialize session state for chat history and user_id if they don't exist.
    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []
    if 'user_id' not in st.session_state:
        st.session_state.user_id = None

    def reply(text):
        """Render an assistant message in a chat bubble and record it in the history."""
        with st.chat_message("assistant"):
            st.write(text)
        st.session_state.chat_history.append({"role": "assistant", "content": text})

    # Login form: shown only while no user is logged in.
    if st.session_state.user_id is None:
        with st.form("login_form", clear_on_submit=True):
            user_id = st.text_input("Please enter your name to begin:")
            submit_button = st.form_submit_button("Login")
        if submit_button and user_id:
            st.session_state.user_id = user_id
            st.session_state.chat_history.append({
                "role": "assistant",
                "content": f"Welcome, {user_id}! How can I help you with nutrition disorders today?"
            })
            # Rerun so the page switches from the login form to the chat view.
            st.rerun()
        return

    # Replay chat history.
    for message in st.session_state.chat_history:
        with st.chat_message(message["role"]):
            st.write(message["content"])

    user_query = st.chat_input("Type your question here (or 'exit' to end)...")
    if not user_query:
        return

    if user_query.strip().lower() == "exit":
        st.session_state.chat_history.append({"role": "user", "content": "exit"})
        with st.chat_message("user"):
            st.write("exit")
        reply("Goodbye! Feel free to return if you have more questions about nutrition disorders.")
        st.session_state.user_id = None
        st.rerun()

    st.session_state.chat_history.append({"role": "user", "content": user_query})
    with st.chat_message("user"):
        st.write(user_query)

    # Screen the input with Llama Guard. The filter returns None when the
    # moderation call fails; report that instead of crashing on .replace().
    filtered_result = filter_input_with_llama_guard(user_query)
    if filtered_result is None:
        reply("Sorry, I couldn't check your message right now. Please try again in a moment.")
        return
    filtered_result = filtered_result.replace("\n", " ")  # "unsafe\nS6" -> "unsafe S6"

    # Allowed verdicts: "safe" plus categories S6 and S7, which the author chose
    # to admit (presumably Llama Guard's "specialized advice" and "privacy"
    # categories, which health questions tend to trigger; confirm against the
    # model card).
    if filtered_result not in ("safe", "unsafe S6", "unsafe S7"):
        reply("I apologize, but I cannot process that input as it may be inappropriate. Please try again.")
        return

    try:
        if 'chatbot' not in st.session_state:
            st.session_state.chatbot = NutritionBot()
        response = st.session_state.chatbot.handle_customer_query(st.session_state.user_id, user_query)
    except Exception as e:
        reply(f"Sorry, I encountered an error while processing your query. Please try again. Error: {str(e)}")
        return
    reply(response)


if __name__ == "__main__":
    nutrition_disorder_streamlit()