# --- 0. Library Imports ---
import os
import json
import shutil
import time
import numpy as np
from datetime import datetime
from typing import Dict, List, Any, TypedDict, Tuple

# LangChain Core & Community
from langchain_core.documents import Document
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.tools import tool
from langchain_community.vectorstores import Chroma
from langchain_community.document_loaders import PyPDFDirectoryLoader
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter

# LangChain OpenAI
from langchain_openai import OpenAIEmbeddings, ChatOpenAI

# LangChain Experimental
from langchain_experimental.text_splitter import SemanticChunker

# LangChain Agents & Graph
from langchain.agents import create_tool_calling_agent, AgentExecutor
from langgraph.graph import StateGraph, END, START

# External Libraries
import chromadb
from llama_parse import LlamaParse  # For PDF parsing
from groq import Groq  # For Llama Guard
from mem0 import MemoryClient  # For memory
import streamlit as st  # For Web UI

# Fix for numpy deprecation warning if necessary
np.float_ = np.float64
# import nest_asyncio
# --- 1. Configuration and Setup Utilities ---

def load_config_from_env() -> Dict:
    """Loads API keys and endpoints from environment variables."""
    # Prioritize environment variables for deployment
    config = {
        "AZURE_OPENAI_API_KEY": os.getenv("AZURE_OPENAI_API_KEY"),
        "AZURE_OPENAI_API_BASE": os.getenv("AZURE_OPENAI_API_BASE"),
        "LLAMA_KEY": os.getenv("LLAMA_KEY"),  # LlamaParse API key
        "MEM0_API_KEY": os.getenv("MEM0_API_KEY"),
        "GROQ_API_KEY": os.getenv("GROQ_API_KEY"),  # For Llama Guard via Groq
    }
    # Basic validation
    for key, value in config.items():
        if not value:
            st.warning(f"Warning: Environment variable '{key}' is not set.")
    return config
def initialize_llms_and_embeddings(config: Dict) -> Tuple[OpenAIEmbeddings, ChatOpenAI, chromadb.utils.embedding_functions.OpenAIEmbeddingFunction, Groq]:
    """Initializes LLM, embedding models, and API clients."""
    api_key = config["AZURE_OPENAI_API_KEY"]
    endpoint = config["AZURE_OPENAI_API_BASE"]
    groq_api_key = config["GROQ_API_KEY"]

    # Initialize ChromaDB embedding function (used for collection creation)
    embedding_function = chromadb.utils.embedding_functions.OpenAIEmbeddingFunction(
        api_base=endpoint,
        api_key=api_key,
        model_name='text-embedding-ada-002'  # Specify model explicitly
    )

    # Initialize LangChain OpenAI embeddings (used for `SemanticChunker` and the `Chroma` vectorstore)
    embedding_model = OpenAIEmbeddings(
        openai_api_base=endpoint,
        openai_api_key=api_key,
        model='text-embedding-ada-002'  # Specify model explicitly
    )

    # Initialize LangChain chat model
    llm = ChatOpenAI(
        openai_api_base=endpoint,
        openai_api_key=api_key,
        model="gpt-4o-mini",
        streaming=False,
        temperature=0.0  # Ensure deterministic behavior for evaluations
    )

    # Initialize Groq client for Llama Guard
    llama_guard_client = Groq(api_key=groq_api_key)

    return embedding_model, llm, embedding_function, llama_guard_client
def filter_input_with_llama_guard(user_input: str, llama_guard_client: Groq, model: str = "meta-llama/llama-guard-4-12b") -> str:
    """
    Filters user input using Llama Guard to check that it is safe.
    Returns the raw verdict string (e.g. "safe", or "unsafe" plus the violated category),
    or None if the call fails.
    """
    try:
        response = llama_guard_client.chat.completions.create(
            messages=[{"role": "user", "content": user_input}],
            model=model,
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        st.error(f"Error with Llama Guard: {e}")
        return None
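
# Illustrative usage sketch (not wired into the app): how the raw Llama Guard verdict returned
# above can be reduced to a boolean. The exact category codes ("S1"..."S14") come from the
# Llama Guard taxonomy and are an assumption here, not something this app defines.
def is_input_safe_example(user_input: str, llama_guard_client: Groq) -> bool:
    """Example helper: interpret a Llama Guard verdict string as safe/unsafe."""
    verdict = filter_input_with_llama_guard(user_input, llama_guard_client)
    if verdict is None:
        return False  # Treat API errors as unsafe by default
    # Llama Guard replies with "safe", or "unsafe" followed by a category code on the next line.
    return verdict.replace("\n", " ").strip().lower().startswith("safe")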
# --- 2. Data Preparation (Parsing & Chunking) ---
# Note: In a deployed app, PDF parsing and vector DB creation would typically be
# a separate, offline process, and the pre-built vector DB would be loaded.
# For this template, we assume the nutritional_db is pre-existing and loaded;
# a sketch of that offline step follows create_and_persist_vector_db below.

def load_and_split_documents(folder_path: str, embedding_model: OpenAIEmbeddings) -> List[Document]:
    """Loads PDF documents from a folder and semantically chunks them."""
    semantic_text_splitter = SemanticChunker(
        embedding_model,
        breakpoint_threshold_type='percentile',
        breakpoint_threshold_amount=80
    )
    pdf_loader = PyPDFDirectoryLoader(folder_path)
    chunks = pdf_loader.load_and_split(semantic_text_splitter)
    return chunks
def parse_pdf_tables_with_llamaparse(pdf_path: str, llamaparse_api_key: str) -> Tuple[Dict, Dict]:
    """Parses a PDF file using LlamaParse and extracts page texts and tables."""
    # This requires `nest_asyncio.apply()` to be called once at the start of the app if running async.
    # In a Streamlit app, ensure it's at the very top level if needed.
    # import nest_asyncio; nest_asyncio.apply()  # Uncomment if needed for the async parser
    parser = LlamaParse(
        result_type="markdown",
        skip_diagonal_text=True,
        fast_mode=False,
        num_workers=1,  # Adjust as per environment capabilities
        check_interval=10,
        api_key=llamaparse_api_key
    )
    json_objs = parser.get_json_result(pdf_path)
    page_texts, tables = {}, {}
    for obj in json_objs:
        json_list = obj['pages']
        name = obj["file_path"].split("/")[-1]
        page_texts[name] = {}
        tables[name] = {}
        for json_item in json_list:
            # Keep the page-level text so the function actually returns page texts as documented
            page_texts[name][json_item['page']] = json_item.get('md', json_item.get('text', ''))
            for component in json_item['items']:
                if component['type'] == 'table':
                    tables[name][json_item['page']] = component['rows']
    return page_texts, tables
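
# Illustrative sketch (an assumption, not part of the original pipeline wiring): wrapping the
# table rows returned by parse_pdf_tables_with_llamaparse into Document objects so they can be
# passed to generate_hypothetical_questions(..., is_table=True) defined below.
def tables_to_documents_example(tables: Dict) -> List[Document]:
    """Example: flatten the {file_name: {page: rows}} mapping into Document objects."""
    table_docs = []
    for file_name, pages in tables.items():
        for page_number, rows in pages.items():
            table_docs.append(
                Document(
                    id=f"table_{file_name}_p{page_number}",
                    page_content=str(rows),  # Raw rows are stringified, matching the table handling below
                    metadata={"source": file_name, "page": page_number, "type": "table"}
                )
            )
    return table_docs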
def generate_hypothetical_questions(llm: ChatOpenAI, docs: List[Document], is_table: bool = False) -> List[Document]:
    """Generates hypothetical questions for text chunks or tables."""
    prompt_template = """
    Generate a list of exactly three (3) hypothetical questions that the below nutritional disorder {content_type} could be used to answer:
    {content}
    Ensure that the questions are specific in the context of nutrition, dietary deficiencies, metabolic disorders, vitamin and mineral imbalances, obesity, and related health conditions.
    Generate only a list of questions.
    Do not mention anything before or after the list.
    If the content cannot answer any questions, return an empty list.
    """
    hyp_docs = []
    content_type = "table" if is_table else "document"
    for i, doc in enumerate(docs):
        content_to_use = str(doc) if is_table else doc.page_content  # Tables are often raw data, so stringify
        try:
            response = llm.invoke(prompt_template.format(content_type=content_type, content=content_to_use))
            questions = response.content
        except Exception as e:
            st.error(f"Error generating hypothetical questions for {'table' if is_table else 'text'} chunk ID {doc.id}: {e}")
            questions = "[]"  # Return an empty list string on error
        questions_metadata = {
            'original_content': content_to_use,
            'source': doc.metadata.get('source', 'unknown'),
            'page': doc.metadata.get('page', -1),
            'type': content_type
        }
        hyp_docs.append(
            Document(
                id=f"{'table_' if is_table else 'text_chunk_'}{doc.id or i}",  # Ensure unique IDs
                page_content=questions,
                metadata=questions_metadata
            )
        )
        time.sleep(0.1)  # Small delay to avoid rate limits
    return hyp_docs
def create_and_persist_vector_db(
    documents: List[Document],
    embedding_model: OpenAIEmbeddings,
    collection_name: str,
    persist_directory: str
):
    """Creates/updates a Chroma vector store and persists it."""
    # Ensure IDs are strings as required by ChromaDB
    doc_ids = [str(d.id) for d in documents] if documents else []
    if not doc_ids:
        st.warning(f"No documents to add to collection '{collection_name}'.")
        return
    # Initialize or connect to the Chroma vectorstore, using the string IDs prepared above
    vector_store = Chroma.from_documents(
        documents,
        embedding_model,
        ids=doc_ids,
        collection_name=collection_name,
        persist_directory=persist_directory
    )
    st.info(f"Initialized ChromaDB with collection '{collection_name}' at '{persist_directory}'. "
            f"Documents count: {len(documents)}")
    return vector_store
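
# Offline build sketch (runs outside the deployed app, as noted at the top of this section).
# The PDF source folder "./nutritional_pdfs" is an assumed location, not something the app defines;
# the collection name matches the one loaded by the Streamlit UI below.
def build_nutritional_db_offline_example(
    pdf_folder: str = "./nutritional_pdfs",      # Assumed location of the source PDFs
    persist_directory: str = "./nutritional_db"
):
    """Example offline step: chunk PDFs, generate hypothetical questions, persist to Chroma."""
    config = load_config_from_env()
    embedding_model, llm, _, _ = initialize_llms_and_embeddings(config)
    chunks = load_and_split_documents(pdf_folder, embedding_model)
    hypothetical_docs = generate_hypothetical_questions(llm, chunks)
    create_and_persist_vector_db(
        hypothetical_docs,
        embedding_model,
        collection_name="nutritional_hypotheticals",
        persist_directory=persist_directory
    )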
def load_vector_db(
    embedding_model: OpenAIEmbeddings,
    collection_name: str,
    persist_directory: str
) -> Chroma:
    """Loads an existing Chroma vector store."""
    try:
        # Check if the directory exists and contains ChromaDB files
        if not os.path.exists(persist_directory) or not os.listdir(persist_directory):
            st.error(f"Vector DB directory '{persist_directory}' is empty or does not exist.")
            st.warning("Please ensure the 'nutritional_db' folder is correctly placed/mounted.")
            return None
        vector_store = Chroma(
            collection_name=collection_name,
            persist_directory=persist_directory,
            embedding_function=embedding_model
        )
        st.success(f"Successfully loaded ChromaDB collection '{collection_name}' from '{persist_directory}'.")
        # You can add a check for the number of documents loaded for verification
        # Example: print(vector_store._collection.count())
        return vector_store
    except Exception as e:
        st.error(f"Error loading ChromaDB from '{persist_directory}': {e}")
        return None
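
# Quick verification sketch for a freshly loaded store. The count() call mirrors the comment in
# load_vector_db above; the sample query string is only an illustrative assumption.
def verify_vector_db_example(vector_store: Chroma):
    """Example: report the document count and run a sample similarity search."""
    print(f"Documents in collection: {vector_store._collection.count()}")
    for doc in vector_store.similarity_search("What causes vitamin D deficiency?", k=2):
        print(doc.metadata.get("source", "unknown"), "-", doc.page_content[:80])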
# --- 3. Agent Workflow Definition (LangGraph) ---

class AgentState(TypedDict):
    """Represents the state of the AI agent at different stages of the workflow."""
    query: str
    expanded_query: str
    context: List[Dict[str, Any]]
    response: str
    precision_score: float
    groundedness_score: float
    groundedness_loop_count: int
    precision_loop_count: int
    feedback: str
    query_feedback: str
    groundedness_check: bool  # This field isn't used in should_continue_groundedness; can be removed
    loop_max_iter: int
# Node functions for LangGraph

def expand_query(state: AgentState) -> AgentState:
    st.write("---Expanding Query---")
    system_message = '''
    You are a domain expert assisting in answering questions related to research papers.
    Convert the user query into something that a nutritionist would understand. Use domain-related words.
    Return three (3) related search queries based on the user's request, separated by newlines.
    Return only three (3) versions of the question as a list.
    Perform query expansion on the question received. If there are multiple common ways of phrasing a user question \
    or common synonyms for key words in the question, make sure to return multiple versions \
    of the query with the different phrasings.
    If the query has multiple parts, split them into separate simpler queries. This is the only case where you can generate more than three (3) queries.
    If there are acronyms or words you are not familiar with, do not try to rephrase them.
    Generate only a list of questions. Do not mention anything before or after the list.
    Use the query feedback if provided to craft the search queries.
    '''
    expand_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Expand this query: {query} using the feedback: {query_feedback}")
    ])
    chain = expand_prompt | st.session_state.llm | StrOutputParser()  # Use llm from session state
    expanded_query = chain.invoke({"query": state['query'], "query_feedback": state["query_feedback"]})
    st.write(f"Expanded query:\n{expanded_query}")
    state["expanded_query"] = expanded_query
    return state
def retrieve_context(state: AgentState) -> AgentState:
    st.write("---Retrieving Context---")
    query = f"{state['query']}; {state['expanded_query']}"
    st.write(f"Query used for retrieval:\n{query}")
    # Ensure vector_store is loaded and available in session_state
    if 'vector_store' not in st.session_state or st.session_state.vector_store is None:
        st.error("Vector store not initialized. Cannot retrieve context.")
        state['context'] = []
        return state
    retriever = st.session_state.vector_store.as_retriever(
        search_type='similarity',
        search_kwargs={'k': 3}
    )
    docs = retriever.invoke(query)
    st.write(f"Retrieved documents (first 100 chars each):\n{[doc.page_content[:100] for doc in docs]}")
    context = [
        {"content": doc.page_content, "metadata": doc.metadata}
        for doc in docs
    ]
    state['context'] = context
    # st.write(f"Extracted context with metadata:\n{context}")  # Too verbose for the production UI
    return state
def craft_response(state: AgentState) -> AgentState:
    st.write("---Crafting Response---")
    system_message = '''
    You are a nutrition-disorder assistant. Generate a response to the user's query using the context provided.
    You will receive:
    - query: the user's query along with expanded versions of it.
    - context: the documents retrieved as relevant to the queries.
    Requirements:
    - The answer you provide must come only from the queries and the context provided.
    - If feedback is provided, use it to improve the response.
    - If the information provided is not enough to answer the query, respond with 'I don't know the answer. Not in my records.'
    '''
    response_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Query:\n{query}\nContext:\n{context}\n\nfeedback:\n{feedback}")
    ])
    chain = response_prompt | st.session_state.llm  # Use llm from session state
    response = chain.invoke({
        "query": state['query'],
        "context": "\n".join([doc["content"] for doc in state['context']]),
        "feedback": state["feedback"]
    })
    state['response'] = response.content  # Access content from the AIMessage
    st.write(f"Intermediate response:\n{state['response']}")
    return state
def score_groundedness(state: AgentState) -> AgentState:
    st.write("---Checking Groundedness---")
    system_message = '''
    You are tasked with rating AI-generated answers to questions posed by users.
    Act as an impartial judge and evaluate how well the provided answer is grounded in the provided context.
    The answer should be derived only from the information presented in the context.
    Output your result as a single float between 0 and 1, where 1 means the answer is fully supported
    by the context and 0 means it is not supported at all.
    Do not show any instructions for deriving your answer; output only the numeric score.
    '''
    groundedness_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Context: {context}\nResponse: {response}\n\nGroundedness score:")
    ])
    chain = groundedness_prompt | st.session_state.llm | StrOutputParser()  # Use llm from session state
    groundedness_score = float(chain.invoke({
        "context": "\n".join([doc["content"] for doc in state['context']]),
        "response": state['response']
    }))
    state['groundedness_score'] = groundedness_score
    state['groundedness_loop_count'] += 1
    st.write(f"Groundedness score: {groundedness_score}")
    return state
def check_precision(state: AgentState) -> AgentState:
    st.write("---Checking Precision---")
    system_message = '''
    Given a question and an answer, verify how useful the answer is for the given question.
    Give your verdict as a scaled numeric value of type float between 0 and 1, such that
    0 or near 0 if it is least useful, 0.5 or near 0.5 if a retry is warranted, and 1 or close to 1 if it is most useful.
    Do not show any instructions for deriving your answer; output only the numeric score.
    '''
    precision_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Query: {query}\nResponse: {response}\n\nPrecision score:")
    ])
    chain = precision_prompt | st.session_state.llm | StrOutputParser()  # Use llm from session state
    precision_score = float(chain.invoke({
        "query": state['query'],
        "response": state['response']
    }))
    state['precision_score'] = precision_score
    state['precision_loop_count'] += 1
    st.write(f"Precision score: {precision_score}")
    return state
def refine_response(state: AgentState) -> AgentState:
    st.write("---Refining Response---")
    system_message = '''
    The last response failed the groundedness test and is deemed not satisfactory.
    Use the query, the context, and the last response to identify potential gaps, ambiguities, or missing details,
    and suggest improvements to enhance the accuracy and completeness of the response.
    '''
    refine_response_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Query: {query}\nResponse: {response}\n\n"
                 "What improvements can be made to enhance accuracy and completeness?")
    ])
    chain = refine_response_prompt | st.session_state.llm | StrOutputParser()  # Use llm from session state
    feedback = f"Previous Response: {state['response']}\nSuggestions: {chain.invoke({'query': state['query'], 'response': state['response']})}"
    state['feedback'] = feedback
    st.write(f"Refinement feedback:\n{feedback}")
    return state
def refine_query(state: AgentState) -> AgentState:
    st.write("---Refining Query---")
    system_message = '''
    The last response failed the precision test and is deemed not satisfactory.
    Use the original query and the expanded query to identify specific keywords, scope refinements, or missing details,
    and provide structured suggestions for improving the search queries so the next retrieval is more accurate and complete.
    '''
    refine_query_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Original Query: {query}\nExpanded Query: {expanded_query}\n\n"
                 "What improvements can be made for a better search?")
    ])
    chain = refine_query_prompt | st.session_state.llm | StrOutputParser()  # Use llm from session state
    query_feedback = f"Previous Expanded Query: {state['expanded_query']}\nSuggestions: {chain.invoke({'query': state['query'], 'expanded_query': state['expanded_query']})}"
    state['query_feedback'] = query_feedback
    st.write(f"Query refinement feedback:\n{query_feedback}")
    return state
def should_continue_groundedness(state: AgentState) -> str:
    st.write("---Deciding Groundedness Continuation---")
    st.write(f"Groundedness loop count: {state['groundedness_loop_count']}")
    if state['groundedness_score'] >= 0.8:
        st.write("Moving to precision check.")
        return "check_precision"
    elif state["groundedness_loop_count"] >= state['loop_max_iter']:
        st.write("Max iterations reached for groundedness.")
        return "max_iterations_reached"
    else:
        st.write("Groundedness score not met. Refining response.")
        return "refine_response"

def should_continue_precision(state: AgentState) -> str:
    st.write("---Deciding Precision Continuation---")
    st.write(f"Precision loop count: {state['precision_loop_count']}")
    if state['precision_score'] > 0.8:
        st.write("Precision sufficient. Ending workflow.")
        return "pass"
    elif state["precision_loop_count"] >= state['loop_max_iter']:
        st.write("Max iterations reached for precision.")
        return "max_iterations_reached"
    else:
        st.write("Precision score not met. Refining query.")
        return "refine_query"

def max_iterations_reached(state: AgentState) -> AgentState:
    st.write("---Max Iterations Reached---")
    response = "I'm unable to refine the response further. Please provide more context or clarify your question."
    state['response'] = response
    return state
@tool
def agentic_rag_tool(query: str) -> Dict[str, Any]:
    """
    Runs the RAG-based agent workflow to generate context-aware responses.
    This function is exposed as a LangChain tool for the overall chatbot.
    """
    # Initialize state for the LangGraph workflow
    inputs = {
        "query": query,
        "expanded_query": "",
        "context": [],
        "response": "",
        "precision_score": 0.0,
        "groundedness_score": 0.0,
        "groundedness_loop_count": 0,
        "precision_loop_count": 0,
        "feedback": "",
        "query_feedback": "",
        "loop_max_iter": 3
    }
    # Compile the workflow once and store it in session state if not already done
    if 'workflow_app' not in st.session_state:
        st.session_state.workflow_app = create_rag_workflow().compile()
    # Invoke the compiled LangGraph workflow
    output = st.session_state.workflow_app.invoke(inputs)
    return output
def create_rag_workflow() -> StateGraph:
    """Creates the LangGraph workflow for the RAG agent."""
    workflow = StateGraph(AgentState)
    workflow.add_node("expand_query", expand_query)
    workflow.add_node("retrieve_context", retrieve_context)
    workflow.add_node("craft_response", craft_response)
    workflow.add_node("score_groundedness", score_groundedness)
    workflow.add_node("refine_response", refine_response)
    workflow.add_node("check_precision", check_precision)
    workflow.add_node("refine_query", refine_query)
    workflow.add_node("max_iterations_reached", max_iterations_reached)

    workflow.add_edge(START, "expand_query")
    workflow.add_edge("expand_query", "retrieve_context")
    workflow.add_edge("retrieve_context", "craft_response")
    workflow.add_edge("craft_response", "score_groundedness")
    workflow.add_conditional_edges(
        "score_groundedness",
        should_continue_groundedness,
        {
            "check_precision": "check_precision",
            "refine_response": "refine_response",
            "max_iterations_reached": "max_iterations_reached"
        }
    )
    workflow.add_edge("refine_response", "craft_response")
    workflow.add_conditional_edges(
        "check_precision",
        should_continue_precision,
        {
            "pass": END,
            "refine_query": "refine_query",
            "max_iterations_reached": "max_iterations_reached"
        }
    )
    workflow.add_edge("refine_query", "expand_query")
    workflow.add_edge("max_iterations_reached", END)
    return workflow
# --- 4. Main Chatbot Class (Integrating Memory & Agent) ---

class NutritionBot:
    def __init__(self, config: Dict):
        """
        Initialize the NutritionBot class, setting up memory, the LLM client, tools, and the agent executor.
        """
        mem0_api_key = config["MEM0_API_KEY"]
        openai_api_key = config["AZURE_OPENAI_API_KEY"]
        openai_api_base = config["AZURE_OPENAI_API_BASE"]

        # Initialize a memory client to store and retrieve customer interactions
        self.memory = MemoryClient(api_key=mem0_api_key)

        # Initialize the OpenAI client (LangChain ChatOpenAI)
        self.client = ChatOpenAI(
            model="gpt-4o-mini",
            openai_api_key=openai_api_key,
            openai_api_base=openai_api_base,
            temperature=0
        )
        # Store the LLM in session state for use in the graph nodes
        st.session_state.llm = self.client

        # Define tools available to the chatbot
        tools = [agentic_rag_tool]

        # Define the system prompt for the agent
        system_prompt = """You are a caring and knowledgeable Medical Support Agent, specializing in nutrition disorder-related guidance. Your goal is to provide accurate, empathetic, and tailored nutritional recommendations while ensuring a seamless customer experience.
        Guidelines for Interaction:
        - Maintain a polite, professional, and reassuring tone.
        - Show genuine empathy for customer concerns and health challenges.
        - Reference past interactions to provide personalized and consistent advice.
        - Engage with the customer by asking about their food preferences, dietary restrictions, and lifestyle before offering recommendations.
        - Ensure consistent and accurate information across conversations.
        - If any detail is unclear or missing, proactively ask for clarification.
        - Always use the agentic_rag_tool to retrieve up-to-date and evidence-based nutrition insights.
        - Keep track of ongoing issues and follow-ups to ensure continuity in support.
        Your primary goal is to help customers make informed nutrition decisions that align with their health conditions and personal preferences.
        """

        # Build the prompt template for the agent
        prompt = ChatPromptTemplate.from_messages([
            ("system", system_prompt),
            ("human", "{input}"),
            ("placeholder", "{agent_scratchpad}")
        ])

        # Create an agent capable of interacting with tools and executing tasks
        agent = create_tool_calling_agent(self.client, tools, prompt)

        # Wrap the agent in an executor to manage tool interactions and execution flow
        self.agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
    def store_customer_interaction(self, user_id: str, message: str, response: str, metadata: Dict = None):
        """Store a customer interaction in memory for future reference."""
        if metadata is None:
            metadata = {}
        metadata["timestamp"] = datetime.now().isoformat()
        conversation = [
            {"role": "user", "content": message},
            {"role": "assistant", "content": response}
        ]
        self.memory.add(conversation, user_id=user_id, output_format="v1.1", metadata=metadata)

    def get_relevant_history(self, user_id: str, query: str) -> List[Dict]:
        """Retrieve past interactions relevant to the current query."""
        return self.memory.search(query=query, user_id=user_id, limit=5)
    def handle_customer_query(self, user_id: str, query: str) -> str:
        """Process a customer's query and provide a response, taking into account past interactions."""
        relevant_history = self.get_relevant_history(user_id, query)
        context = "Previous relevant interactions:\n"
        for memory_item in relevant_history:
            # Mem0's 'memory' field is typically a list of dicts or a string.
            # With the 'v1.1' output format used in `memory.add`, `memory_item['memory']` is structured.
            if isinstance(memory_item.get('memory'), list):
                for part in memory_item['memory']:
                    context += f"{part['role'].capitalize()}: {part['content']}\n"
            else:  # Fallback for older formats or if it's a simple string
                context += f"History: {memory_item.get('memory', 'N/A')}\n"
            context += "---\n"
        prompt = f"""
        Context:
        {context}
        Current customer query: {query}
        Provide a helpful response that takes into account any relevant past interactions.
        """
        st.write("Prompt sent to agent executor:", prompt)  # Debugging
        try:
            response_dict = self.agent_executor.invoke({"input": prompt})
            response_content = response_dict.get('output', "I'm sorry, I couldn't generate a response.")
        except Exception as e:
            st.error(f"Error during agent execution: {e}")
            response_content = f"I'm sorry, I encountered an internal error: {e}"
        self.store_customer_interaction(user_id=user_id, message=query, response=response_content, metadata={"type": "support_query"})
        return response_content
# --- 5. Streamlit UI ---

def nutrition_disorder_streamlit_app():
    """Streamlit-based UI for the Nutrition Disorder Specialist Agent."""
    st.set_page_config(page_title="Nutrition Disorder Specialist", layout="centered")
    st.title("👨‍⚕️ Nutrition Disorder Specialist")
    st.markdown("Ask me anything about nutrition disorders, symptoms, causes, treatments, and more.")
    st.markdown("---")

    # Initialize session state variables
    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []
    if 'user_id' not in st.session_state:
        st.session_state.user_id = None
    if 'chatbot' not in st.session_state:
        st.session_state.chatbot = None
    if 'config_loaded' not in st.session_state:
        st.session_state.config_loaded = False
    if 'vector_store_loaded' not in st.session_state:
        st.session_state.vector_store_loaded = False

    # --- Configuration Loading and Model Initialization ---
    if not st.session_state.config_loaded:
        with st.spinner("Loading configurations and initializing models..."):
            config = load_config_from_env()
            if not all(config.values()):
                st.error("Some environment variables are missing. Please set them up for the app to function.")
                st.stop()  # Stop execution if critical configs are missing
            # Step 1: Initialize embedding model, LLM, ChromaDB embedding function, and Llama Guard client
            embedding_model, llm_instance, embedding_function, llama_guard_client_instance = initialize_llms_and_embeddings(config)
            # Step 2: Store the initialized objects in session state
            st.session_state.config = config
            st.session_state.embedding_model = embedding_model
            st.session_state.llm = llm_instance
            st.session_state.embedding_function = embedding_function  # Used during vector_store creation/loading
            st.session_state.llama_guard_client = llama_guard_client_instance
            st.session_state.config_loaded = True
            st.rerun()  # Rerun to update the UI after loading
    # --- Vector Store Loading ---
    if st.session_state.config_loaded and not st.session_state.vector_store_loaded:
        with st.spinner("Loading nutrition knowledge base (vector database)..."):
            # Ensure the nutritional_db directory exists relative to app.py.
            # In Docker, this means the folder should be copied into /app.
            persist_dir = "./nutritional_db"
            if not os.path.exists(persist_dir):
                st.error(f"Required data directory '{persist_dir}' not found. Please ensure it's copied into the Docker image.")
                st.stop()
            st.session_state.vector_store = load_vector_db(
                embedding_model=st.session_state.embedding_model,
                collection_name="nutritional_hypotheticals",
                persist_directory=persist_dir
            )
            if st.session_state.vector_store is None:
                st.error("Failed to load vector database. Chat functionality will be limited.")
            st.session_state.vector_store_loaded = True
            st.rerun()  # Rerun to update the UI after loading

    # --- Login Form ---
    if st.session_state.user_id is None:
        with st.form("login_form", clear_on_submit=True):
            user_id_input = st.text_input("Please enter your name to begin:", key="user_id_input")
            submit_button = st.form_submit_button("Login")
            if submit_button and user_id_input:
                st.session_state.user_id = user_id_input.strip()
                st.session_state.chat_history.append({
                    "role": "assistant",
                    "content": f"Welcome, {st.session_state.user_id}! How can I help you with nutrition disorders today?"
                })
                # Initialize the chatbot only after config and vector store are ready
                if st.session_state.config_loaded and st.session_state.vector_store_loaded:
                    st.session_state.chatbot = NutritionBot(st.session_state.config)
                else:
                    st.warning("Chatbot initialization delayed as configurations or vector store are still loading.")
                st.rerun()
    # --- Chat Interface ---
    elif st.session_state.chatbot:  # Only show the chat if the chatbot is initialized
        # Display chat history
        for message in st.session_state.chat_history:
            with st.chat_message(message["role"]):
                st.write(message["content"])

        user_query = st.chat_input("Type your question here (e.g., 'What are dietary deficiencies?')")
        if user_query:
            if user_query.lower() == "exit":
                st.session_state.chat_history.append({"role": "user", "content": "exit"})
                with st.chat_message("user"):
                    st.write("exit")
                goodbye_msg = "Goodbye! Feel free to return if you have more questions about nutrition disorders."
                st.session_state.chat_history.append({"role": "assistant", "content": goodbye_msg})
                with st.chat_message("assistant"):
                    st.write(goodbye_msg)
                st.session_state.user_id = None  # Log out
                st.session_state.chatbot = None  # Clear chatbot instance
                st.session_state.chat_history = []  # Clear history on exit
                st.rerun()
                return

            st.session_state.chat_history.append({"role": "user", "content": user_query})
            with st.chat_message("user"):
                st.write(user_query)

            # Filter input using Llama Guard
            with st.spinner("Filtering input for safety..."):
                filtered_result = filter_input_with_llama_guard(user_query, st.session_state.llama_guard_client)
            verdict = filtered_result.replace("\n", " ").strip().lower() if filtered_result else ""
            if verdict:
                st.info(f"Llama Guard says: {verdict}")  # Show Llama Guard's verdict

            # Process the user query only if the verdict starts with "safe" (a plain substring check
            # would also match "unsafe"); "S7" content is still allowed, as in the original logic.
            if verdict.startswith("safe") or "s7" in verdict:
                with st.spinner("Thinking..."):
                    try:
                        response = st.session_state.chatbot.handle_customer_query(st.session_state.user_id, user_query)
                        st.session_state.chat_history.append({"role": "assistant", "content": response})
                        with st.chat_message("assistant"):
                            st.write(response)
                    except Exception as e:
                        error_msg = f"Sorry, I encountered an error while processing your query. Please try again. Error: {e}"
                        st.error(error_msg)
                        st.session_state.chat_history.append({"role": "assistant", "content": error_msg})
            else:
                inappropriate_msg = "I apologize, but I cannot process that input as it may be inappropriate or unsafe. Please try again."
                st.session_state.chat_history.append({"role": "assistant", "content": inappropriate_msg})
                with st.chat_message("assistant"):
                    st.write(inappropriate_msg)
            st.rerun()  # Rerun to update the chat history instantly

    elif st.session_state.user_id:  # User is logged in but the chatbot is not ready
        st.info("Initializing chatbot. Please wait...")
# --- Main entry point for the Streamlit app ---
if __name__ == "__main__":
    nutrition_disorder_streamlit_app()