Spaces:
Running
Running
Upload app.py with huggingface_hub
Browse files
app.py
CHANGED
@@ -1,611 +1,785 @@
|
|
1 |
-
import streamlit as st
|
2 |
-
import json
|
3 |
-
import os
|
4 |
-
import uuid
|
5 |
|
6 |
-
|
7 |
-
|
8 |
-
import
|
9 |
-
import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
-
|
12 |
-
from
|
13 |
-
from
|
14 |
-
#from langchain.embeddings import OpenAIEmbeddings
|
15 |
-
from langchain_community.embeddings import OpenAIEmbeddings
|
16 |
-
from langchain_openai import OpenAIEmbeddings
|
17 |
-
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
|
18 |
-
from langchain_community.vectorstores import Chroma
|
19 |
-
from langchain_core.caches import BaseCache
|
20 |
|
|
|
|
|
|
|
21 |
|
|
|
|
|
22 |
|
23 |
-
|
24 |
-
from
|
25 |
-
from langchain_community.agent_toolkits import create_sql_agent
|
26 |
-
from langchain.agents import create_tool_calling_agent, AgentExecutor
|
27 |
-
from langchain_core.tools import tool
|
28 |
|
|
|
|
|
29 |
|
30 |
-
|
31 |
-
|
|
|
|
|
|
|
|
|
32 |
|
33 |
-
#BaseCache.register_cache_type("memory", lambda: None)
|
34 |
-
#ChatOpenAI.model_rebuild()
|
35 |
#====================================SETUP=====================================#
|
36 |
# Fetch secrets from Hugging Face Spaces
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
#
|
47 |
-
db_loc = 'ecomm.db'
|
48 |
-
|
49 |
-
# Create a SQLDatabase instance from the SQLite database URI
|
50 |
-
db = SQLDatabase.from_uri(f"sqlite:///{db_loc}")
|
51 |
-
|
52 |
-
# Retrieve the schema information of the database tables
|
53 |
-
database_schema = db.get_table_info()
|
54 |
-
|
55 |
-
|
56 |
-
# Let's initiate w&b weave with a project name - this will automatically save all the llm calls made using openai or gemini
|
57 |
-
# Make sure to save your w&b api key in secrets as WANDB_API_KEY
|
58 |
-
weave.init('ecomm_support')
|
59 |
-
# <--------------------------------------------------------- Uncomment to log to WANDB
|
60 |
-
|
61 |
-
#=================================Setup Logging=====================================#
|
62 |
-
|
63 |
-
log_file = Path("logs/") / f"data_{uuid.uuid4()}.json"
|
64 |
-
log_folder = log_file.parent
|
65 |
-
|
66 |
-
log_scheduler = CommitScheduler(
|
67 |
-
repo_id="chatbot-logs", #Dataset name where we want to save the logs.
|
68 |
-
repo_type="dataset",
|
69 |
-
folder_path=log_folder,
|
70 |
-
path_in_repo="data",
|
71 |
-
every=5 # Saves data every x minute
|
72 |
)
|
73 |
|
|
|
74 |
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
repo_type="dataset",
|
81 |
-
folder_path=history_folder,
|
82 |
-
path_in_repo="data",
|
83 |
-
every=5 # Saves data every x minute
|
84 |
-
)
|
85 |
-
|
86 |
-
#=================================SQL_AGENT=====================================#
|
87 |
-
|
88 |
-
# Define the system message for the agent, including instructions and available tables
|
89 |
-
system_message = f"""You are a SQLite expert agent designed to interact with a SQLite database.
|
90 |
-
Given an input question, create a syntactically correct SQLite query to run, then look at the results of the query and return the answer.
|
91 |
-
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 100 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database..
|
92 |
-
You can order the results by a relevant column to return the most interesting examples in the database.
|
93 |
-
You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
|
94 |
-
You have access to tools for interacting with the database.
|
95 |
-
Only use the given tools. Only use the information returned by the tools to construct your final answer.
|
96 |
-
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
|
97 |
-
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
|
98 |
-
You are not allowed to make dummy data.
|
99 |
-
If the question does not seem related to the database, just return "I don't know" as the answer.
|
100 |
-
Before you execute the query, tell us why you are executing it and what you expect to find briefly.
|
101 |
-
Only use the following tables:
|
102 |
-
{database_schema}
|
103 |
-
"""
|
104 |
-
|
105 |
-
# Create a full prompt template for the agent using the system message and placeholders
|
106 |
-
full_prompt = ChatPromptTemplate.from_messages(
|
107 |
-
[
|
108 |
-
("system", system_message),
|
109 |
-
("human", '{input}'),
|
110 |
-
MessagesPlaceholder("agent_scratchpad")
|
111 |
-
]
|
112 |
)
|
113 |
|
114 |
-
# Initialize the
|
115 |
llm = ChatOpenAI(
|
116 |
openai_api_base=endpoint,
|
117 |
openai_api_key=api_key,
|
118 |
-
model="gpt-4o",
|
119 |
-
streaming=False
|
120 |
-
)
|
121 |
-
|
122 |
-
# Create the SQL agent using the ChatOpenAI model, database, and prompt template
|
123 |
-
sqlite_agent = create_sql_agent(
|
124 |
-
llm=llm,
|
125 |
-
db=db,
|
126 |
-
prompt=full_prompt,
|
127 |
-
agent_type="openai-tools",
|
128 |
-
agent_executor_kwargs={'handle_parsing_errors': True},
|
129 |
-
max_iterations=5,
|
130 |
-
verbose=True
|
131 |
)
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
136 |
"""
|
137 |
-
|
138 |
-
|
139 |
Args:
|
140 |
-
|
|
|
141 |
Returns:
|
142 |
-
|
143 |
"""
|
144 |
-
|
145 |
-
|
146 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
147 |
|
148 |
-
|
149 |
-
prediction = response['output']
|
150 |
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
qna_system_message = """
|
160 |
-
You are an assistant to a support agent. Your task is to provide relevant information about the Python package Streamlit.
|
161 |
-
User input will include the necessary context for you to answer their questions. This context will begin with the token: ###Context.
|
162 |
-
The context contains references to specific portions of documents relevant to the user's query, along with source links.
|
163 |
-
The source for a context will begin with the token ###Source
|
164 |
-
When crafting your response:
|
165 |
-
1. Select only context relevant to answer the question.
|
166 |
-
2. User questions will begin with the token: ###Question.
|
167 |
-
3. If the context provided doesn't answer the question respond with - "I do not have sufficient information to answer that"
|
168 |
-
4. If user asks for product - list all the products that are relevant to his query. If you don't have that product try to cross sell with one of the products we have that is related to what they are interested in.
|
169 |
-
You should get information about similar products in the context.
|
170 |
-
Please adhere to the following guidelines:
|
171 |
-
- Your response should only be about the question asked and nothing else.
|
172 |
-
- Answer only using the context provided.
|
173 |
-
- Do not mention anything about the context in your final answer.
|
174 |
-
- If the answer is not found in the context, it is very very important for you to respond with "I don't know."
|
175 |
-
- Always quote the source when you use the context. Cite the relevant source at the end of your response under the section - Source:
|
176 |
-
- Do not make up sources. Use the links provided in the sources section of the context and nothing else. You are prohibited from providing other links/sources.
|
177 |
-
Here is an example of how to structure your response:
|
178 |
-
Answer:
|
179 |
-
[Answer]
|
180 |
-
Source:
|
181 |
-
[Source]
|
182 |
-
"""
|
183 |
-
|
184 |
-
qna_user_message_template = """
|
185 |
-
###Context
|
186 |
-
Here are some documents and their source that may be relevant to the question mentioned below.
|
187 |
-
{context}
|
188 |
-
###Question
|
189 |
-
{question}
|
190 |
-
"""
|
191 |
-
# Load the persisted DB
|
192 |
-
persisted_vectordb_location = 'policy_docs/policy_docs'
|
193 |
-
#Create a Colelction Name
|
194 |
-
collection_name = 'policy_docs'
|
195 |
-
|
196 |
-
embedding_model = SentenceTransformerEmbeddings(model_name='thenlper/gte-large')
|
197 |
-
# Load the persisted DB
|
198 |
vector_store = Chroma(
|
199 |
-
collection_name=
|
200 |
-
persist_directory=
|
201 |
embedding_function=embedding_model
|
202 |
-
|
203 |
)
|
204 |
|
|
|
205 |
retriever = vector_store.as_retriever(
|
206 |
search_type='similarity',
|
207 |
-
search_kwargs={'k':
|
208 |
)
|
209 |
|
|
|
|
|
|
|
210 |
|
211 |
-
|
212 |
-
|
213 |
-
base_url=endpoint
|
214 |
-
)
|
215 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
216 |
|
217 |
-
@tool
|
218 |
-
def rag(user_input: str) -> str:
|
219 |
|
|
|
220 |
"""
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
|
|
|
|
226 |
"""
|
|
|
|
|
|
|
|
|
227 |
|
228 |
-
|
229 |
-
|
|
|
230 |
|
231 |
-
|
|
|
232 |
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
question=user_input
|
238 |
-
)
|
239 |
-
}
|
240 |
-
]
|
241 |
-
try:
|
242 |
-
response = client.chat.completions.create(
|
243 |
-
model="gpt-4o",
|
244 |
-
messages=prompt
|
245 |
-
)
|
246 |
|
247 |
-
|
248 |
-
|
249 |
-
|
|
|
250 |
|
|
|
|
|
|
|
|
|
251 |
|
252 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
253 |
|
|
|
254 |
|
255 |
-
#=================================== Other TOOLS======================================#
|
256 |
|
257 |
-
# Function to log actions
|
258 |
-
def log_history(email: str,chat_history: list) -> None:
|
259 |
-
# Save the log to the file
|
260 |
-
with history_scheduler.lock:
|
261 |
-
# Open the log file in append mode
|
262 |
-
with history_file.open("a") as f:
|
263 |
-
f.write(json.dumps({
|
264 |
-
"email": email,
|
265 |
-
"chat_history": chat_history,
|
266 |
-
"timestamp": str(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
|
267 |
-
}))
|
268 |
|
269 |
-
|
|
|
|
|
270 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
271 |
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
#
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
}))
|
282 |
|
|
|
283 |
|
284 |
-
|
285 |
-
def register_feedback(intent, customer_id, feedback, rating):
|
286 |
"""
|
287 |
-
|
|
|
288 |
Args:
|
289 |
-
|
290 |
-
|
291 |
-
feedback (str): The feedback provided by the customer.
|
292 |
-
rating(int): The rating provided by the customer out of 5
|
293 |
Returns:
|
294 |
-
|
295 |
"""
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
305 |
|
306 |
-
|
307 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
308 |
"""
|
309 |
-
|
|
|
310 |
Args:
|
311 |
-
|
312 |
-
|
313 |
-
reason (str): The reason why the query cannot be resolved by the chatbot.
|
314 |
Returns:
|
315 |
-
|
316 |
"""
|
|
|
317 |
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
|
|
324 |
|
325 |
-
|
326 |
-
|
|
|
|
|
|
|
327 |
|
|
|
328 |
|
329 |
-
|
330 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
331 |
"""
|
332 |
-
|
|
|
333 |
Args:
|
334 |
-
|
335 |
-
"""
|
336 |
-
try:
|
337 |
-
# Convert the delivered_date string to a datetime object
|
338 |
-
delivered_date = datetime.strptime(delivered_date, '%Y-%m-%d')
|
339 |
-
today = datetime.today()
|
340 |
-
|
341 |
-
# Calculate the difference in days
|
342 |
-
days_difference = (today - delivered_date).days
|
343 |
-
|
344 |
-
return str(days_difference)
|
345 |
-
except ValueError as e:
|
346 |
-
return f"Error: {e}"
|
347 |
-
|
348 |
-
def build_prompt(df):
|
349 |
-
|
350 |
-
system_message = f"""
|
351 |
-
|
352 |
-
You are an intelligent e-commerce chatbot designed to assist users with pre-order and post-order queries. Your job is to
|
353 |
-
|
354 |
-
Gather necessary information from the user to help them with their query.
|
355 |
-
If at any point you cannot determine the next steps - defer to human. you do not have clearance to go beyond the scope the following flow.
|
356 |
-
Do not provide sql inputs to the sql tool - you only need to ask in natural language what information you need.
|
357 |
-
You are only allowed to provide information relevant to the particular customer and the customer information is provided below. you can provide information of this customer only. Following is the information about the customer from the last 2 weeks:
|
358 |
-
|
359 |
-
{df}
|
360 |
-
|
361 |
-
If this information is not enough to answer question, identify the customer from data above and fetch necessary information usign the sql_tool or rag tool - do not fetch information of other customers.
|
362 |
-
use the details provided in the above file to fetch information from sql tool - like customer id, email and phone. Refrain from asking customers details unless necessary.
|
363 |
-
If customer asks about a product, you should act as a sales representative and help them understand the product as much as possible and provide all the necessary information for them. You should also provide them the link to the product which you can get from the source of the information.
|
364 |
-
If a customer asks a query about a policy, be grounded to the context provided to you. if at any point you don't the right thing to say, politely tell the customer that you are not the right person to answer this and defer it to a human.
|
365 |
-
Any time you defer it to a human, you should tell the customer why you did it in a polite manner.
|
366 |
-
MANDATORY STEP:
|
367 |
-
After helping the customer with their concern,
|
368 |
-
- Ask if the customer needs help with anything else. If they ask for anything from the above list help them and along with that,
|
369 |
-
1. Ask for their feedback and rating out of 5.
|
370 |
-
2. then, Use the `register_feedback` tool to log it. - you MUST ask customer feedback along with asking customer what else they need help with.
|
371 |
-
3. After receving customer feedback exit the chat by responding with 'Bye'.
|
372 |
-
|
373 |
-
---
|
374 |
-
### **Handling Out-of-Scope Queries:**
|
375 |
-
If the user's query, at any point is not covered by the workflows above:
|
376 |
-
- Respond:
|
377 |
-
> "This is beyond my skill. Let me connect you to a customer service agent" and get necessary details from the customer and use the defer_to_human tool.
|
378 |
-
- Get customer feedback and rating out of 5.
|
379 |
-
- After getting feedback, end the conversation by saying 'Bye'.
|
380 |
-
---
|
381 |
-
### **IMPORTANT Notes for the Model:**
|
382 |
-
- Always fetch additional required details from the database and do not blindly believe details provided by the customer like customer id, email and phone number. You should get the customer id from the system prompt. Cross check with the database and stay loyal to the database.
|
383 |
-
- Be empathetic to the customer but loyal to the instructions provided to you. Try to deescalate a situation before deferring it to human and defer to human only once.
|
384 |
-
- Always aim to minimize the number of questions asked by retrieving as much information as possible from `sql_tool` and `rag` tool.
|
385 |
-
- Follow the exact workflows for each query category.
|
386 |
-
- You will always confirm the order id even if the customer has only one order before you fetch any details.
|
387 |
-
"""
|
388 |
|
389 |
-
|
390 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
391 |
("system", system_message),
|
392 |
-
("
|
393 |
-
|
394 |
])
|
395 |
|
396 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
397 |
|
398 |
|
399 |
-
#===============================================Streamlit=========================================#
|
400 |
|
401 |
|
402 |
-
def
|
403 |
-
|
|
|
|
|
|
|
|
|
|
|
404 |
|
405 |
-
email = st.text_input("Email")
|
406 |
-
password = st.text_input("Password", type="password")
|
407 |
|
408 |
-
login_button = st.button("Login")
|
409 |
|
410 |
-
|
411 |
-
if authenticate_user(email, password):
|
412 |
-
st.session_state.logged_in = True
|
413 |
-
st.session_state.email = email
|
414 |
-
st.success("Login successful! Redirecting to Chatbot...")
|
415 |
-
st.rerun()
|
416 |
-
else:
|
417 |
-
st.error("Invalid email or password.")
|
418 |
|
419 |
-
def
|
420 |
-
|
421 |
-
|
422 |
|
423 |
-
|
424 |
-
|
425 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
426 |
|
427 |
-
|
428 |
-
|
429 |
-
|
|
|
|
|
430 |
|
431 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
432 |
|
433 |
-
|
434 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
435 |
|
436 |
-
|
437 |
-
connection = sqlite3.connect("ecomm.db") # Replace with your .db file path
|
438 |
-
cursor = connection.cursor()
|
439 |
-
|
440 |
-
query = f"""
|
441 |
-
SELECT
|
442 |
-
c.customer_id,
|
443 |
-
c.first_name || ' ' || c.last_name AS customer_name,
|
444 |
-
c.email,
|
445 |
-
c.phone,
|
446 |
-
c.address AS customer_address,
|
447 |
-
o.order_id,
|
448 |
-
o.order_date,
|
449 |
-
o.status AS order_status,
|
450 |
-
o.price AS order_price,
|
451 |
-
p.name AS product_name,
|
452 |
-
p.price AS product_price,
|
453 |
-
i.invoice_date,
|
454 |
-
i.amount AS invoice_amount,
|
455 |
-
i.invoice_url,
|
456 |
-
s.delivery_date,
|
457 |
-
s.shipping_status,
|
458 |
-
s.shipping_address,
|
459 |
-
r.refund_amount,
|
460 |
-
r.refund_status
|
461 |
-
FROM Customers c
|
462 |
-
LEFT JOIN Orders o ON c.customer_id = o.customer_id
|
463 |
-
LEFT JOIN Products p ON o.product_id = p.product_id
|
464 |
-
LEFT JOIN Invoices i ON o.order_id = i.order_id
|
465 |
-
LEFT JOIN Shipping s ON o.order_id = s.order_id
|
466 |
-
LEFT JOIN Refund r ON o.order_id = r.order_id
|
467 |
-
WHERE o.order_date >= datetime('now', '-30 days')
|
468 |
-
AND c.email = ?
|
469 |
-
ORDER BY o.order_date DESC;
|
470 |
-
"""
|
471 |
|
472 |
-
|
473 |
-
columns = [description[0] for description in cursor.description] # Extract column names
|
474 |
-
results = cursor.fetchall() # Fetch all rows
|
475 |
-
#st.write(results)
|
476 |
-
# Convert results into a list of dictionaries
|
477 |
-
details = [dict(zip(columns, row)) for row in results]
|
478 |
-
#st.write(details)
|
479 |
-
return str(details).replace("{","/").replace("}","/")
|
480 |
|
481 |
-
|
482 |
-
st.write(f"Error: {e}")
|
483 |
-
finally:
|
484 |
-
# Close the connection
|
485 |
-
if connection:
|
486 |
-
cursor.close()
|
487 |
-
connection.close()
|
488 |
|
489 |
-
# Function to process user input and generate a chatbot response
|
490 |
|
491 |
-
def chatbot_interface():
|
492 |
-
st.title("E-Commerce Chatbot")
|
493 |
|
494 |
-
if 'conversation_history' not in st.session_state:
|
495 |
-
st.session_state.conversation_history = [{"role": "assistant", "content": "welcome! I am Raha, how can I help you on this beautiful day?"}]
|
496 |
|
|
|
|
|
|
|
|
|
|
|
|
|
497 |
|
498 |
-
|
499 |
-
|
500 |
-
|
501 |
-
|
502 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
503 |
|
|
|
504 |
|
505 |
-
|
506 |
-
openai_api_base=endpoint,
|
507 |
-
openai_api_key=api_key,
|
508 |
-
model="gpt-4o",
|
509 |
-
streaming=False, # Explicitly disabling streaming
|
510 |
-
temperature=0
|
511 |
-
)
|
512 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
513 |
try:
|
514 |
-
|
515 |
-
|
516 |
-
|
517 |
-
|
518 |
-
|
|
|
|
|
519 |
except Exception as e:
|
520 |
-
|
521 |
-
|
522 |
-
|
523 |
-
|
|
|
524 |
|
525 |
-
|
526 |
-
|
|
|
|
|
|
|
527 |
|
|
|
|
|
528 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
529 |
|
530 |
-
|
531 |
-
|
532 |
-
|
533 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
534 |
|
535 |
-
|
536 |
-
if user_input := st.chat_input("You: ", key="chat_input"):
|
537 |
-
# Display user message in chat message container
|
538 |
-
st.chat_message("user").markdown(user_input)
|
539 |
-
with st.spinner("Processing..."):
|
540 |
|
541 |
-
|
542 |
-
|
|
|
|
|
|
|
|
|
543 |
|
544 |
-
|
545 |
-
|
546 |
-
)
|
547 |
|
548 |
-
|
549 |
-
|
550 |
-
response = agent_executor.invoke({"input": conversation_input})
|
551 |
|
552 |
-
# Add the chatbot's response to the history
|
553 |
-
chatbot_response = response['output']
|
554 |
-
st.session_state.conversation_history.append({"role": "assistant", "content": chatbot_response})
|
555 |
-
# Check if the assistant's response contains "exit"
|
556 |
-
if "bye" in chatbot_response.lower():
|
557 |
-
log_history(st.session_state.email,st.session_state.conversation_history)
|
558 |
|
559 |
-
|
560 |
-
|
561 |
-
|
562 |
-
|
563 |
-
|
564 |
-
|
565 |
-
|
566 |
-
|
567 |
-
|
568 |
-
|
569 |
-
|
570 |
-
|
571 |
-
|
572 |
-
|
573 |
-
|
574 |
-
|
575 |
-
|
576 |
-
|
577 |
-
|
578 |
-
|
579 |
-
|
580 |
-
|
581 |
-
|
582 |
-
|
583 |
-
|
584 |
-
|
585 |
-
|
586 |
-
|
587 |
-
|
588 |
-
|
589 |
-
|
590 |
-
|
591 |
-
|
592 |
-
|
593 |
-
|
594 |
-
|
595 |
-
|
596 |
-
|
597 |
-
|
598 |
-
|
599 |
-
|
600 |
-
|
601 |
-
|
602 |
-
|
603 |
-
|
604 |
-
|
605 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
606 |
else:
|
607 |
-
#
|
608 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
609 |
|
610 |
if __name__ == "__main__":
|
611 |
-
|
|
|
|
|
|
|
|
|
|
|
1 |
|
2 |
+
# Import necessary libraries
|
3 |
+
import os # Interacting with the operating system (reading/writing files)
|
4 |
+
import chromadb # High-performance vector database for storing/querying dense vectors
|
5 |
+
from dotenv import load_dotenv # Loading environment variables from a .env file
|
6 |
+
import json # Parsing and handling JSON data
|
7 |
+
|
8 |
+
# LangChain imports
|
9 |
+
from langchain_core.documents import Document # Document data structures
|
10 |
+
from langchain_core.runnables import RunnablePassthrough # LangChain core library for running pipelines
|
11 |
+
from langchain_core.output_parsers import StrOutputParser # String output parser
|
12 |
+
from langchain.prompts import ChatPromptTemplate # Template for chat prompts
|
13 |
+
from langchain.chains.query_constructor.base import AttributeInfo # Base classes for query construction
|
14 |
+
from langchain.retrievers.self_query.base import SelfQueryRetriever # Base classes for self-querying retrievers
|
15 |
+
from langchain.retrievers.document_compressors import LLMChainExtractor, CrossEncoderReranker # Document compressors
|
16 |
+
from langchain.retrievers import ContextualCompressionRetriever # Contextual compression retrievers
|
17 |
+
|
18 |
+
# LangChain community & experimental imports
|
19 |
+
from langchain_community.vectorstores import Chroma # Implementations of vector stores like Chroma
|
20 |
+
from langchain_community.document_loaders import PyPDFDirectoryLoader, PyPDFLoader # Document loaders for PDFs
|
21 |
+
from langchain_community.cross_encoders import HuggingFaceCrossEncoder # Cross-encoders from HuggingFace
|
22 |
+
from langchain_experimental.text_splitter import SemanticChunker # Experimental text splitting methods
|
23 |
+
from langchain.text_splitter import (
|
24 |
+
CharacterTextSplitter, # Splitting text by characters
|
25 |
+
RecursiveCharacterTextSplitter # Recursive splitting of text by characters
|
26 |
+
)
|
27 |
+
from langchain_core.tools import tool
|
28 |
+
from langchain.agents import create_tool_calling_agent, AgentExecutor
|
29 |
+
from langchain_core.prompts import ChatPromptTemplate
|
30 |
|
31 |
+
# LangChain OpenAI imports
|
32 |
+
from langchain_openai import AzureOpenAIEmbeddings, AzureChatOpenAI # OpenAI embeddings and models
|
33 |
+
from langchain.embeddings.openai import OpenAIEmbeddings # OpenAI embeddings for text vectors
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
|
35 |
+
# LlamaParse & LlamaIndex imports
|
36 |
+
from llama_parse import LlamaParse # Document parsing library
|
37 |
+
from llama_index.core import Settings, SimpleDirectoryReader # Core functionalities of the LlamaIndex
|
38 |
|
39 |
+
# LangGraph import
|
40 |
+
from langgraph.graph import StateGraph, END, START # State graph for managing states in LangChain
|
41 |
|
42 |
+
# Pydantic import
|
43 |
+
from pydantic import BaseModel # Pydantic for data validation
|
|
|
|
|
|
|
44 |
|
45 |
+
# Typing imports
|
46 |
+
from typing import Dict, List, Tuple, Any, TypedDict # Python typing for function annotations
|
47 |
|
48 |
+
# Other utilities
|
49 |
+
import numpy as np # Numpy for numerical operations
|
50 |
+
from groq import Groq
|
51 |
+
from mem0 import MemoryClient
|
52 |
+
import streamlit as st
|
53 |
+
from datetime import datetime
|
54 |
|
|
|
|
|
55 |
#====================================SETUP=====================================#
|
56 |
# Fetch secrets from Hugging Face Spaces
|
57 |
+
api_key = config.get("API_KEY")
|
58 |
+
endpoint = config.get("OPENAI_API_BASE")
|
59 |
+
groq_api_key = config.get('LLAMA_API_KEY') # llama_api_key = os.environ['GROQ_API_KEY']
|
60 |
+
MEM0_api_key = config.get('mem0') # MEM0_api_key = os.environ['mem0']
|
61 |
+
|
62 |
+
# Initialize the OpenAI embedding function for Chroma
|
63 |
+
embedding_function = chromadb.utils.embedding_functions.OpenAIEmbeddingFunction(
|
64 |
+
api_base=endpoint, # Complete the code to define the API base endpoint
|
65 |
+
api_key=api_key, # Complete the code to define the API key
|
66 |
+
model_name='text-embedding-ada-002' # This is a fixed value and does not need modification
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
)
|
68 |
|
69 |
+
# This initializes the OpenAI embedding function for the Chroma vectorstore, using the provided endpoint and API key.
|
70 |
|
71 |
+
# Initialize the OpenAI Embeddings
|
72 |
+
embedding_model = OpenAIEmbeddings(
|
73 |
+
openai_api_base=endpoint,
|
74 |
+
openai_api_key=api_key,
|
75 |
+
model='text-embedding-ada-002'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
)
|
77 |
|
78 |
+
# Initialize the Chat OpenAI model
|
79 |
llm = ChatOpenAI(
|
80 |
openai_api_base=endpoint,
|
81 |
openai_api_key=api_key,
|
82 |
+
model="gpt-4o", # used gpt4o instead of gpt-4o-mini to get improved results
|
83 |
+
streaming=False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
)
|
85 |
+
# This initializes the Chat OpenAI model with the provided endpoint, API key, deployment name, and a temperature setting of 0 (to control response variability).
|
86 |
+
|
87 |
+
# set the LLM and embedding model in the LlamaIndex settings.
|
88 |
+
Settings.llm = llm # _____ # Complete the code to define the LLM model
|
89 |
+
Settings.embedding = embedding_model # _____ # Complete the code to define the embedding model
|
90 |
+
|
91 |
+
#================================Creating Langgraph agent======================#
|
92 |
+
|
93 |
+
class AgentState(TypedDict):
|
94 |
+
query: str # The current user query
|
95 |
+
expanded_query: str # The expanded version of the user query
|
96 |
+
context: List[Dict[str, Any]] # Retrieved documents (content and metadata)
|
97 |
+
response: str # The generated response to the user query
|
98 |
+
precision_score: float # The precision score of the response
|
99 |
+
groundedness_score: float # The groundedness score of the response
|
100 |
+
groundedness_loop_count: int # Counter for groundedness refinement loops
|
101 |
+
precision_loop_count: int # Counter for precision refinement loops
|
102 |
+
feedback: str
|
103 |
+
query_feedback: str
|
104 |
+
groundedness_check: bool
|
105 |
+
loop_max_iter: int
|
106 |
+
|
107 |
+
def expand_query(state):
|
108 |
"""
|
109 |
+
Expands the user query to improve retrieval of nutrition disorder-related information.
|
110 |
+
|
111 |
Args:
|
112 |
+
state (Dict): The current state of the workflow, containing the user query.
|
113 |
+
|
114 |
Returns:
|
115 |
+
Dict: The updated state with the expanded query.
|
116 |
"""
|
117 |
+
print("---------Expanding Query---------")
|
118 |
+
system_message = # ________________________
|
119 |
+
'''
|
120 |
+
You are a domain expert assisting in answering questions related to research papers.
|
121 |
+
Convert the user query into something that a nutritionist would understand. Use domain related words.
|
122 |
+
Return 3 related search queries based on the user's request seperated by newline.
|
123 |
+
Return only 3 versions of the question as a list.
|
124 |
+
Perform query expansion on the question received. If there are multiple common ways of phrasing a user question \
|
125 |
+
or common synonyms for key words in the question, make sure to return multiple versions \
|
126 |
+
of the query with the different phrasings.
|
127 |
+
If the query has multiple parts, split them into separate simpler queries. This is the only case where you can generate more than 3 queries.
|
128 |
+
If there are acronyms or words you are not familiar with, do not try to rephrase them.
|
129 |
+
Generate only a list of questions. Do not mention anything before or after the list.
|
130 |
+
Use the query feeback if provided to craft the search queries.
|
131 |
+
'''
|
132 |
+
|
133 |
+
expand_prompt = ChatPromptTemplate.from_messages([
|
134 |
+
("system", system_message),
|
135 |
+
("user", "Expand this query: {query} using the feedback: {query_feedback}")
|
136 |
|
137 |
+
])
|
|
|
138 |
|
139 |
+
chain = expand_prompt | llm | StrOutputParser()
|
140 |
+
expanded_query = chain.invoke({"query": state['query'], "query_feedback":state["query_feedback"]})
|
141 |
+
print("expanded_query", expanded_query)
|
142 |
+
state["expanded_query"] = expanded_query
|
143 |
+
return state
|
144 |
+
|
145 |
+
|
146 |
+
# Initialize the Chroma vector store for retrieving documents
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
147 |
vector_store = Chroma(
|
148 |
+
collection_name="nutritional_hypotheticals",
|
149 |
+
persist_directory="./nutritional_db2",
|
150 |
embedding_function=embedding_model
|
|
|
151 |
)
|
152 |
|
153 |
+
# Create a retriever from the vector store
|
154 |
retriever = vector_store.as_retriever(
|
155 |
search_type='similarity',
|
156 |
+
search_kwargs={'k': 3}
|
157 |
)
|
158 |
|
159 |
+
def retrieve_context(state):
|
160 |
+
"""
|
161 |
+
Retrieves context from the vector store using the expanded or original query.
|
162 |
|
163 |
+
Args:
|
164 |
+
state (Dict): The current state of the workflow, containing the query and expanded query.
|
|
|
|
|
165 |
|
166 |
+
Returns:
|
167 |
+
Dict: The updated state with the retrieved context.
|
168 |
+
"""
|
169 |
+
print("---------retrieve_context---------")
|
170 |
+
query = state['query'] # state['_____'] # Complete the code to define the key for the expanded query
|
171 |
+
#print("Query used for retrieval:", query) # Debugging: Print the query
|
172 |
+
|
173 |
+
# Retrieve documents from the vector store
|
174 |
+
docs = retriever.invoke(query)
|
175 |
+
print("Retrieved documents:", docs) # Debugging: Print the raw docs object
|
176 |
+
|
177 |
+
# Extract both page_content and metadata from each document
|
178 |
+
context= [
|
179 |
+
{
|
180 |
+
"content": doc.page_content, # The actual content of the document
|
181 |
+
"metadata": doc.metadata # The metadata (e.g., source, page number, etc.)
|
182 |
+
}
|
183 |
+
for doc in docs
|
184 |
+
]
|
185 |
+
state['context'] = context # state['_____'] = context # Complete the code to define the key for storing the context
|
186 |
+
print("Extracted context with metadata:", context) # Debugging: Print the extracted context
|
187 |
+
#print(f"Groundedness loop count: {state['groundedness_loop_count']}")
|
188 |
+
return state
|
189 |
|
|
|
|
|
190 |
|
191 |
+
def craft_response(state: Dict) -> Dict:
|
192 |
"""
|
193 |
+
Generates a response using the retrieved context, focusing on nutrition disorders.
|
194 |
+
|
195 |
+
Args:
|
196 |
+
state (Dict): The current state of the workflow, containing the query and retrieved context.
|
197 |
+
|
198 |
+
Returns:
|
199 |
+
Dict: The updated state with the generated response.
|
200 |
"""
|
201 |
+
print("---------craft_response---------")
|
202 |
+
system_message = # ________________________
|
203 |
+
'''
|
204 |
+
Generates a response to a user query and context provided.
|
205 |
|
206 |
+
Parameters:
|
207 |
+
query (str): The user's query and expanded queries based on user's query.
|
208 |
+
context (str): The documents retrieved relevant to the queries.
|
209 |
|
210 |
+
Returns:
|
211 |
+
response (str): The response generated by the model.
|
212 |
|
213 |
+
The function performs the following steps:
|
214 |
+
1. Constructs a prompt containing system and user prompts.
|
215 |
+
2. Sends the prompt containing user queries with context provided to the GPT model to generate a response.
|
216 |
+
3. Displays the response.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
217 |
|
218 |
+
The answer you provide must come from the user queries with context provided.
|
219 |
+
If feedback is provided, use it to craft the response.
|
220 |
+
If information provided is not enough to answer the query respons with 'I don't know the answer. Not in my records.'
|
221 |
+
'''
|
222 |
|
223 |
+
response_prompt = ChatPromptTemplate.from_messages([
|
224 |
+
("system", system_message),
|
225 |
+
("user", "Query: {query}\nContext: {context}\n\nfeedback: {feedback}")
|
226 |
+
])
|
227 |
|
228 |
+
chain = response_prompt | llm
|
229 |
+
response = chain.invoke({
|
230 |
+
"query": state['query'],
|
231 |
+
"context": "\n".join([doc["content"] for doc in state['context']]),
|
232 |
+
"feedback": state["feedback"] # ________________ # add feedback to the prompt
|
233 |
+
})
|
234 |
+
state['response'] = response
|
235 |
+
print("intermediate response: ", response)
|
236 |
|
237 |
+
return state
|
238 |
|
|
|
239 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
240 |
|
241 |
+
def score_groundedness(state: Dict) -> Dict:
|
242 |
+
"""
|
243 |
+
Checks whether the response is grounded in the retrieved context.
|
244 |
|
245 |
+
Args:
|
246 |
+
state (Dict): The current state of the workflow, containing the response and context.
|
247 |
+
|
248 |
+
Returns:
|
249 |
+
Dict: The updated state with the groundedness score.
|
250 |
+
"""
|
251 |
+
print("---------check_groundedness---------")
|
252 |
+
system_message = # ________________________
|
253 |
+
'''
|
254 |
+
You are tasked with rating AI generated answers to questions posed by users.
|
255 |
+
Please act as an impartial judge and evaluate the quality of the provided answer which attempts to answer the provided question based on a provided context.
|
256 |
+
In the input, the context is {context}, while the AI generated response is {response}.
|
257 |
+
|
258 |
+
Evaluation criteria:
|
259 |
+
The task is to judge the extent to which the metric is followed by the answer.
|
260 |
+
1 - The metric is not followed at all
|
261 |
+
2 - The metric is followed only to a limited extent
|
262 |
+
3 - The metric is followed to a good extent
|
263 |
+
4 - The metric is followed mostly
|
264 |
+
5 - The metric is followed completely
|
265 |
+
|
266 |
+
The answer should be derived only from the information presented in the context
|
267 |
+
|
268 |
+
Do not show any instructions for deriving your answer.
|
269 |
+
|
270 |
+
Output your result as a float number between 0 and 1 using the evaluation criteria.
|
271 |
+
The better the criteria, the cloase it is to 1 and the worse the criteria, the closer it is to 0.
|
272 |
+
'''
|
273 |
+
|
274 |
+
groundedness_prompt = ChatPromptTemplate.from_messages([
|
275 |
+
("system", system_message),
|
276 |
+
("user", "Context: {context}\nResponse: {response}\n\nGroundedness score:")
|
277 |
+
])
|
278 |
|
279 |
+
chain = groundedness_prompt | llm | StrOutputParser()
|
280 |
+
groundedness_score = float(chain.invoke({
|
281 |
+
"context": "\n".join([doc["content"] for doc in state['context']]),
|
282 |
+
"response": state['response'] # __________ # Complete the code to define the response
|
283 |
+
}))
|
284 |
+
print("groundedness_score: ", groundedness_score)
|
285 |
+
state['groundedness_loop_count'] += 1
|
286 |
+
print("#########Groundedness Incremented###########")
|
287 |
+
state['groundedness_score'] = groundedness_score
|
|
|
288 |
|
289 |
+
return state
|
290 |
|
291 |
+
def check_precision(state: Dict) -> Dict:
|
|
|
292 |
"""
|
293 |
+
Checks whether the response precisely addresses the user’s query.
|
294 |
+
|
295 |
Args:
|
296 |
+
state (Dict): The current state of the workflow, containing the query and response.
|
297 |
+
|
|
|
|
|
298 |
Returns:
|
299 |
+
Dict: The updated state with the precision score.
|
300 |
"""
|
301 |
+
print("---------check_precision---------")
|
302 |
+
system_message = # ________________________
|
303 |
+
'''
|
304 |
+
Given question, answer and context verify if the context was useful in arriving at the given answer.
|
305 |
+
Give verdict as "1" if useful and "0" if not useful.
|
306 |
+
Output your result as a float number between 0 and 1
|
307 |
+
Give verdict as a scaled numeric value of type float between 0 and 1, such that
|
308 |
+
0 or near 0 if it is least useful, 0.5 or near 0.5 if retry is warranted, and 1 or close to 1 is most useful.
|
309 |
+
Do not show any instructions for deriving your answer.
|
310 |
+
'''
|
311 |
+
|
312 |
+
precision_prompt = ChatPromptTemplate.from_messages([
|
313 |
+
("system", system_message),
|
314 |
+
("user", "Query: {query}\nResponse: {response}\n\nPrecision score:")
|
315 |
+
])
|
316 |
|
317 |
+
chain = precision_prompt | llm | StrOutputParser() # _____________ | llm | StrOutputParser() # Complete the code to define the chain of processing
|
318 |
+
precision_score = float(chain.invoke({
|
319 |
+
"query": state['query'],
|
320 |
+
"response": state['response'] # ______________ # Complete the code to access the response from the state
|
321 |
+
}))
|
322 |
+
state['precision_score'] = precision_score
|
323 |
+
print("precision_score:", precision_score)
|
324 |
+
state['precision_loop_count'] +=1
|
325 |
+
print("#########Precision Incremented###########")
|
326 |
+
return state
|
327 |
+
|
328 |
+
def refine_response(state: Dict) -> Dict:
|
329 |
"""
|
330 |
+
Suggests improvements for the generated response.
|
331 |
+
|
332 |
Args:
|
333 |
+
state (Dict): The current state of the workflow, containing the query and response.
|
334 |
+
|
|
|
335 |
Returns:
|
336 |
+
Dict: The updated state with response refinement suggestions.
|
337 |
"""
|
338 |
+
print("---------refine_response---------")
|
339 |
|
340 |
+
system_message = # ________________________
|
341 |
+
'''
|
342 |
+
Since the last response failded the groundedness test, and is deemed not satisfactory,
|
343 |
+
use the feedback in terms of the query, context and the last response
|
344 |
+
to identify potential gaps, ambiguities, or missing details, and
|
345 |
+
to suggest improvements to enhance accuracy and completeness of the response.
|
346 |
+
'''
|
347 |
|
348 |
+
refine_response_prompt = ChatPromptTemplate.from_messages([
|
349 |
+
("system", system_message),
|
350 |
+
("user", "Query: {query}\nResponse: {response}\n\n"
|
351 |
+
"What improvements can be made to enhance accuracy and completeness?")
|
352 |
+
])
|
353 |
|
354 |
+
chain = refine_response_prompt | llm| StrOutputParser()
|
355 |
|
356 |
+
# Store response suggestions in a structured format
|
357 |
+
feedback = f"Previous Response: {state['response']}\nSuggestions: {chain.invoke({'query': state['query'], 'response': state['response']})}"
|
358 |
+
print("feedback: ", feedback)
|
359 |
+
print(f"State: {state}")
|
360 |
+
state['feedback'] = feedback
|
361 |
+
return state
|
362 |
+
|
363 |
+
|
364 |
+
|
365 |
+
def refine_query(state: Dict) -> Dict:
|
366 |
"""
|
367 |
+
Suggests improvements for the expanded query.
|
368 |
+
|
369 |
Args:
|
370 |
+
state (Dict): The current state of the workflow, containing the query and expanded query.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
371 |
|
372 |
+
Returns:
|
373 |
+
Dict: The updated state with query refinement suggestions.
|
374 |
+
"""
|
375 |
+
print("---------refine_query---------")
|
376 |
+
system_message = # ________________________
|
377 |
+
'''
|
378 |
+
Since the last response failded the precision test, and is deemed not satisfactory,
|
379 |
+
use the feedback in terms of the query, context and re-generate extended queries
|
380 |
+
to identify specific keywords, scope refinements, or missing details, and
|
381 |
+
to provides structured suggestions for improvement to enhance accuracy and completeness of the response.
|
382 |
+
'''
|
383 |
+
|
384 |
+
refine_query_prompt = ChatPromptTemplate.from_messages([
|
385 |
("system", system_message),
|
386 |
+
("user", "Original Query: {query}\nExpanded Query: {expanded_query}\n\n"
|
387 |
+
"What improvements can be made for a better search?")
|
388 |
])
|
389 |
|
390 |
+
chain = refine_query_prompt | llm | StrOutputParser()
|
391 |
+
|
392 |
+
# Store refinement suggestions without modifying the original expanded query
|
393 |
+
query_feedback = f"Previous Expanded Query: {state['expanded_query']}\nSuggestions: {chain.invoke({'query': state['query'], 'expanded_query': state['expanded_query']})}"
|
394 |
+
print("query_feedback: ", query_feedback)
|
395 |
+
print(f"Groundedness loop count: {state['groundedness_loop_count']}")
|
396 |
+
state['query_feedback'] = query_feedback
|
397 |
+
return state
|
398 |
+
|
399 |
+
|
400 |
+
|
401 |
+
def should_continue_groundedness(state):
|
402 |
+
"""Decides if groundedness is sufficient or needs improvement."""
|
403 |
+
print("---------should_continue_groundedness---------")
|
404 |
+
print("groundedness loop count: ", state['groundedness_loop_count'])
|
405 |
+
if state['groundedness_score'] >= 0.8 # _____: # Complete the code to define the threshold for groundedness
|
406 |
+
print("Moving to precision")
|
407 |
+
return "check_precision"
|
408 |
+
else:
|
409 |
+
if state["groundedness_loop_count"] > state['loop_max_iter']:
|
410 |
+
return "max_iterations_reached"
|
411 |
+
else:
|
412 |
+
print(f"---------Groundedness Score Threshold Not met. Refining Response-----------")
|
413 |
+
return "refine_response"
|
414 |
+
|
415 |
+
|
416 |
+
def should_continue_precision(state: Dict) -> str:
|
417 |
+
"""Decides if precision is sufficient or needs improvement."""
|
418 |
+
print("---------should_continue_precision---------")
|
419 |
+
print("precision loop count: ", state["precision_loop_count"])
|
420 |
+
if state['precision_score'] > 0.8: # ___________: # Threshold for precision
|
421 |
+
return "pass" # Complete the workflow
|
422 |
+
else:
|
423 |
+
if state["precision_loop_count"] >= state['loop_max_iter']: # ___________: # Maximum allowed loops
|
424 |
+
return "max_iterations_reached"
|
425 |
+
else:
|
426 |
+
print(f"---------Precision Score Threshold Not met. Refining Query-----------") # Debugging
|
427 |
+
return "refine_query" # ____________ # Refine the query
|
428 |
|
429 |
|
|
|
430 |
|
431 |
|
432 |
+
def max_iterations_reached(state: Dict) -> Dict:
|
433 |
+
"""Handles the case when the maximum number of iterations is reached."""
|
434 |
+
print("---------max_iterations_reached---------")
|
435 |
+
"""Handles the case when the maximum number of iterations is reached."""
|
436 |
+
response = "I'm unable to refine the response further. Please provide more context or clarify your question."
|
437 |
+
state['response'] = response
|
438 |
+
return state
|
439 |
|
|
|
|
|
440 |
|
|
|
441 |
|
442 |
+
from langgraph.graph import END, StateGraph, START
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
443 |
|
444 |
+
def create_workflow() -> StateGraph:
|
445 |
+
"""Creates the updated workflow for the AI nutrition agent."""
|
446 |
+
workflow = StateGraph(State) # StateGraph(_____ ) # Complete the code to define the initial state of the agent
|
447 |
|
448 |
+
# Add processing nodes
|
449 |
+
workflow.add_node("expand_query", expand_query) # _____ ) # Step 1: Expand user query. Complete with the function to expand the query
|
450 |
+
workflow.add_node("retrieve_context", retrieve_context) #_____ ) # Step 2: Retrieve relevant documents. Complete with the function to retrieve context
|
451 |
+
workflow.add_node("craft_response", craft_response) # _____ ) # Step 3: Generate a response based on retrieved data. Complete with the function to craft a response
|
452 |
+
workflow.add_node("score_groundedness", score_groundedness) # _____ ) # Step 4: Evaluate response grounding. Complete with the function to score groundedness
|
453 |
+
workflow.add_node("refine_response", refine_response) # _____ ) # Step 5: Improve response if it's weakly grounded. Complete with the function to refine the response
|
454 |
+
workflow.add_node("check_precision", check_precision) # _____ ) # Step 6: Evaluate response precision. Complete with the function to check precision
|
455 |
+
workflow.add_node("refine_query", refine_query) # _____ ) # Step 7: Improve query if response lacks precision. Complete with the function to refine the query
|
456 |
+
workflow.add_node("max_iterations_reached", max_iterations_reached) # _____ ) # Step 8: Handle max iterations. Complete with the function to handle max iterations
|
457 |
|
458 |
+
# Main flow edges
|
459 |
+
workflow.add_edge(START, "expand_query")
|
460 |
+
workflow.add_edge("expand_query", "retrieve_context")
|
461 |
+
workflow.add_edge("retrieve_context", "craft_response")
|
462 |
+
workflow.add_edge("craft_response", "score_groundedness")
|
463 |
|
464 |
+
# Conditional edges based on groundedness check
|
465 |
+
workflow.add_conditional_edges(
|
466 |
+
"score_groundedness",
|
467 |
+
should_continue_groundedness, # ___________, # Use the conditional function
|
468 |
+
{
|
469 |
+
"check_precision": "check_precision", # ___________, # If well-grounded, proceed to precision check.
|
470 |
+
"refine_response": "refine_response", # ___________, # If not, refine the response.
|
471 |
+
"max_iterations_reached": max_iterations_reached # ___________ # If max loops reached, exit.
|
472 |
+
}
|
473 |
+
)
|
474 |
|
475 |
+
workflow.add_edge("refine_response", "craft_response") # __________, ___________) # Refined responses are reprocessed.
|
476 |
+
|
477 |
+
# Conditional edges based on precision check
|
478 |
+
workflow.add_conditional_edges(
|
479 |
+
"check_precision",
|
480 |
+
should_continue_precision, # ___________, # Use the conditional function
|
481 |
+
{
|
482 |
+
"pass": END, # ___________, # If precise, complete the workflow.
|
483 |
+
"refine_query": "refine_query" # ___________, # If imprecise, refine the query.
|
484 |
+
"max_iterations_reached": "max_iterations_reached" # ___________ # If max loops reached, exit.
|
485 |
+
}
|
486 |
+
)
|
487 |
|
488 |
+
workflow.add_edge("refine_query", "expand_query") # __________, ___________) # Refined queries go through expansion again.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
489 |
|
490 |
+
workflow.add_edge("max_iterations_reached", END)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
491 |
|
492 |
+
return workflow
|
|
|
|
|
|
|
|
|
|
|
|
|
493 |
|
|
|
494 |
|
|
|
|
|
495 |
|
|
|
|
|
496 |
|
497 |
+
#=========================== Defining the agentic rag tool ====================#
|
498 |
+
WORKFLOW_APP = create_workflow().compile()
|
499 |
+
@tool
|
500 |
+
def agentic_rag(query: str):
|
501 |
+
"""
|
502 |
+
Runs the RAG-based agent with conversation history for context-aware responses.
|
503 |
|
504 |
+
Args:
|
505 |
+
query (str): The current user query.
|
506 |
+
|
507 |
+
Returns:
|
508 |
+
Dict[str, Any]: The updated state with the generated response and conversation history.
|
509 |
+
"""
|
510 |
+
# Initialize state with necessary parameters
|
511 |
+
inputs = {
|
512 |
+
"query": query, # Current user query
|
513 |
+
"expanded_query": "", # "_____", # Complete the code to define the expanded version of the query
|
514 |
+
"context": [], # Retrieved documents (initially empty)
|
515 |
+
"response": "", # "_____", # Complete the code to define the AI-generated response
|
516 |
+
"precision_score": 0.0, # _____, # Complete the code to define the precision score of the response
|
517 |
+
"groundedness_score": 0.0 # _____, # Complete the code to define the groundedness score of the response
|
518 |
+
"groundedness_loop_count": 0 # _____, # Complete the code to define the counter for groundedness loops
|
519 |
+
"precision_loop_count": 0, # _____, # Complete the code to define the counter for precision loops
|
520 |
+
"feedback": "", # "_____", # Complete the code to define the feedback
|
521 |
+
"query_feedback": "", # "_____", # Complete the code to define the query feedback
|
522 |
+
"loop_max_iter": 3 # _____ # Complete the code to define the maximum number of iterations for loops
|
523 |
+
}
|
524 |
|
525 |
+
output = WORKFLOW_APP.invoke(inputs)
|
526 |
|
527 |
+
return output
|
|
|
|
|
|
|
|
|
|
|
|
|
528 |
|
529 |
+
|
530 |
+
#================================ Guardrails ===========================#
|
531 |
+
llama_guard_client = Groq(api_key=groq_api_key) # Groq(api_key=llama_api_key)
|
532 |
+
# Function to filter user input with Llama Guard
|
533 |
+
def filter_input_with_llama_guard(user_input, model="llama-guard-3-8b"):
|
534 |
+
"""
|
535 |
+
Filters user input using Llama Guard to ensure it is safe.
|
536 |
+
|
537 |
+
Parameters:
|
538 |
+
- user_input: The input provided by the user.
|
539 |
+
- model: The Llama Guard model to be used for filtering (default is "llama-guard-3-8b").
|
540 |
+
|
541 |
+
Returns:
|
542 |
+
- The filtered and safe input.
|
543 |
+
"""
|
544 |
try:
|
545 |
+
# Create a request to Llama Guard to filter the user input
|
546 |
+
response = llama_guard_client.chat.completions.create(
|
547 |
+
messages=[{"role": "user", "content": user_input}],
|
548 |
+
model=model,
|
549 |
+
)
|
550 |
+
# Return the filtered input
|
551 |
+
return response.choices[0].message.content.strip()
|
552 |
except Exception as e:
|
553 |
+
print(f"Error with Llama Guard: {e}")
|
554 |
+
return None
|
555 |
+
|
556 |
+
|
557 |
+
#============================= Adding Memory to the agent using mem0 ===============================#
|
558 |
|
559 |
+
class NutritionBot:
|
560 |
+
def __init__(self):
|
561 |
+
"""
|
562 |
+
Initialize the NutritionBot class, setting up memory, the LLM client, tools, and the agent executor.
|
563 |
+
"""
|
564 |
|
565 |
+
# Initialize a memory client to store and retrieve customer interactions
|
566 |
+
self.memory = MemoryClient(api_key=MEM0_api_key) # userdata.get("mem0")) # Complete the code to define the memory client API key
|
567 |
|
568 |
+
# Initialize the OpenAI client using the provided credentials
|
569 |
+
self.client = ChatOpenAI(
|
570 |
+
model_name="gpt-4o", # Used gpt-4o to get improved results; Specify the model to use (e.g., GPT-4 optimized version)
|
571 |
+
api_key=config.get("API_KEY"), # API key for authentication
|
572 |
+
endpoint = config.get("OPENAI_API_BASE"),
|
573 |
+
temperature=0 # Controls randomness in responses; 0 ensures deterministic results
|
574 |
+
)
|
575 |
|
576 |
+
# Define tools available to the chatbot, such as web search
|
577 |
+
tools = [agentic_rag]
|
578 |
+
|
579 |
+
# Define the system prompt to set the behavior of the chatbot
|
580 |
+
system_prompt = """You are a caring and knowledgeable Medical Support Agent, specializing in nutrition disorder-related guidance. Your goal is to provide accurate, empathetic, and tailored nutritional recommendations while ensuring a seamless customer experience.
|
581 |
+
Guidelines for Interaction:
|
582 |
+
Maintain a polite, professional, and reassuring tone.
|
583 |
+
Show genuine empathy for customer concerns and health challenges.
|
584 |
+
Reference past interactions to provide personalized and consistent advice.
|
585 |
+
Engage with the customer by asking about their food preferences, dietary restrictions, and lifestyle before offering recommendations.
|
586 |
+
Ensure consistent and accurate information across conversations.
|
587 |
+
If any detail is unclear or missing, proactively ask for clarification.
|
588 |
+
Always use the agentic_rag tool to retrieve up-to-date and evidence-based nutrition insights.
|
589 |
+
Keep track of ongoing issues and follow-ups to ensure continuity in support.
|
590 |
+
Your primary goal is to help customers make informed nutrition decisions that align with their health conditions and personal preferences.
|
591 |
|
592 |
+
"""
|
|
|
|
|
|
|
|
|
593 |
|
594 |
+
# Build the prompt template for the agent
|
595 |
+
prompt = ChatPromptTemplate.from_messages([
|
596 |
+
("system", system_prompt), # System instructions
|
597 |
+
("human", "{input}"), # Placeholder for human input
|
598 |
+
("placeholder", "{agent_scratchpad}") # Placeholder for intermediate reasoning steps
|
599 |
+
])
|
600 |
|
601 |
+
# Create an agent capable of interacting with tools and executing tasks
|
602 |
+
agent = create_tool_calling_agent(self.client, tools, prompt)
|
|
|
603 |
|
604 |
+
# Wrap the agent in an executor to manage tool interactions and execution flow
|
605 |
+
self.agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
|
|
|
606 |
|
|
|
|
|
|
|
|
|
|
|
|
|
607 |
|
608 |
+
def store_customer_interaction(self, user_id: str, message: str, response: str, metadata: Dict = None):
|
609 |
+
"""
|
610 |
+
Store customer interaction in memory for future reference.
|
611 |
+
|
612 |
+
Args:
|
613 |
+
user_id (str): Unique identifier for the customer.
|
614 |
+
message (str): Customer's query or message.
|
615 |
+
response (str): Chatbot's response.
|
616 |
+
metadata (Dict, optional): Additional metadata for the interaction.
|
617 |
+
"""
|
618 |
+
if metadata is None:
|
619 |
+
metadata = {}
|
620 |
+
|
621 |
+
# Add a timestamp to the metadata for tracking purposes
|
622 |
+
metadata["timestamp"] = datetime.now().isoformat()
|
623 |
+
|
624 |
+
# Format the conversation for storage
|
625 |
+
conversation = [
|
626 |
+
{"role": "user", "content": message},
|
627 |
+
{"role": "assistant", "content": response}
|
628 |
+
]
|
629 |
+
|
630 |
+
# Store the interaction in the memory client
|
631 |
+
self.memory.add(
|
632 |
+
conversation,
|
633 |
+
user_id=user_id,
|
634 |
+
output_format="v1.1",
|
635 |
+
metadata=metadata
|
636 |
+
)
|
637 |
+
|
638 |
+
|
639 |
+
def get_relevant_history(self, user_id: str, query: str) -> List[Dict]:
|
640 |
+
"""
|
641 |
+
Retrieve past interactions relevant to the current query.
|
642 |
+
|
643 |
+
Args:
|
644 |
+
user_id (str): Unique identifier for the customer.
|
645 |
+
query (str): The customer's current query.
|
646 |
+
|
647 |
+
Returns:
|
648 |
+
List[Dict]: A list of relevant past interactions.
|
649 |
+
"""
|
650 |
+
return self.memory.search(
|
651 |
+
query=query, # Search for interactions related to the query
|
652 |
+
user_id=user_id, # Restrict search to the specific user
|
653 |
+
limit=_____ # Complete the code to define the limit for retrieved interactions
|
654 |
+
)
|
655 |
+
|
656 |
+
|
657 |
+
def handle_customer_query(self, user_id: str, query: str) -> str:
|
658 |
+
"""
|
659 |
+
Process a customer's query and provide a response, taking into account past interactions.
|
660 |
+
|
661 |
+
Args:
|
662 |
+
user_id (str): Unique identifier for the customer.
|
663 |
+
query (str): Customer's query.
|
664 |
+
|
665 |
+
Returns:
|
666 |
+
str: Chatbot's response.
|
667 |
+
"""
|
668 |
+
|
669 |
+
# Retrieve relevant past interactions for context
|
670 |
+
relevant_history = self.get_relevant_history(user_id, query)
|
671 |
+
|
672 |
+
# Build a context string from the relevant history
|
673 |
+
context = "Previous relevant interactions:\n"
|
674 |
+
for memory in relevant_history:
|
675 |
+
context += f"Customer: {memory['memory']}\n" # Customer's past messages
|
676 |
+
context += f"Support: {memory['memory']}\n" # Chatbot's past responses
|
677 |
+
context += "---\n"
|
678 |
+
|
679 |
+
# Print context for debugging purposes
|
680 |
+
print("Context: ", context)
|
681 |
+
|
682 |
+
# Prepare a prompt combining past context and the current query
|
683 |
+
prompt = f"""
|
684 |
+
Context:
|
685 |
+
{context}
|
686 |
+
|
687 |
+
Current customer query: {query}
|
688 |
+
|
689 |
+
Provide a helpful response that takes into account any relevant past interactions.
|
690 |
+
"""
|
691 |
+
|
692 |
+
# Generate a response using the agent
|
693 |
+
response = self.agent_executor.invoke({"input": prompt})
|
694 |
+
|
695 |
+
# Store the current interaction for future reference
|
696 |
+
self.store_customer_interaction(
|
697 |
+
user_id=user_id,
|
698 |
+
message=query,
|
699 |
+
response=response["output"],
|
700 |
+
metadata={"type": "support_query"}
|
701 |
+
)
|
702 |
+
|
703 |
+
# Return the chatbot's response
|
704 |
+
return response['output']
|
705 |
+
|
706 |
+
|
707 |
+
#=====================User Interface using streamlit ===========================#
|
708 |
+
def nutrition_disorder_streamlit():
|
709 |
+
"""
|
710 |
+
A Streamlit-based UI for the Nutrition Disorder Specialist Agent.
|
711 |
+
"""
|
712 |
+
st.title("Nutrition Disorder Specialist")
|
713 |
+
st.write("Ask me anything about nutrition disorders, symptoms, causes, treatments, and more.")
|
714 |
+
st.write("Type 'exit' to end the conversation.")
|
715 |
+
|
716 |
+
# Initialize session state for chat history and user_id if they don't exist
|
717 |
+
if 'chat_history' not in st.session_state:
|
718 |
+
st.session_state.chat_history = []
|
719 |
+
if 'user_id' not in st.session_state:
|
720 |
+
st.session_state.user_id = None
|
721 |
+
|
722 |
+
# Login form: Only if user is not logged in
|
723 |
+
if st.session_state.user_id is None:
|
724 |
+
with st.form("login_form", clear_on_submit=True):
|
725 |
+
user_id = st.text_input("Please enter your name to begin:")
|
726 |
+
submit_button = st.form_submit_button("Login")
|
727 |
+
if submit_button and user_id:
|
728 |
+
st.session_state.user_id = user_id
|
729 |
+
st.session_state.chat_history.append({
|
730 |
+
"role": "assistant",
|
731 |
+
"content": f"Welcome, {user_id}! How can I help you with nutrition disorders today?"
|
732 |
+
})
|
733 |
+
st.session_state.login_submitted = True # Set flag to trigger rerun
|
734 |
+
if st.session_state.get("login_submitted", False):
|
735 |
+
st.session_state.pop("login_submitted")
|
736 |
+
st.rerun()
|
737 |
else:
|
738 |
+
# Display chat history
|
739 |
+
for message in st.session_state.chat_history:
|
740 |
+
with st.chat_message(message["role"]):
|
741 |
+
st.write(message["content"])
|
742 |
+
|
743 |
+
# Chat input with custom placeholder text
|
744 |
+
user_query = st.chat_input("Type your question here (or 'exit' to end)...") # __________) # Blank #1: Fill in the chat input prompt (e.g., "Type your question here (or 'exit' to end)...")
|
745 |
+
if user_query:
|
746 |
+
if user_query.lower() == "exit":
|
747 |
+
st.session_state.chat_history.append({"role": "user", "content": "exit"})
|
748 |
+
with st.chat_message("user"):
|
749 |
+
st.write("exit")
|
750 |
+
goodbye_msg = "Goodbye! Feel free to return if you have more questions about nutrition disorders."
|
751 |
+
st.session_state.chat_history.append({"role": "assistant", "content": goodbye_msg})
|
752 |
+
with st.chat_message("assistant"):
|
753 |
+
st.write(goodbye_msg)
|
754 |
+
st.session_state.user_id = None
|
755 |
+
st.rerun()
|
756 |
+
return
|
757 |
+
|
758 |
+
st.session_state.chat_history.append({"role": "user", "content": user_query})
|
759 |
+
with st.chat_message("user"):
|
760 |
+
st.write(user_query)
|
761 |
+
|
762 |
+
# Filter input using Llama Guard
|
763 |
+
filtered_result = filter_input_with_llama_guard(user_query) # __________(user_query) # Blank #2: Fill in with the function name for filtering input (e.g., filter_input_with_llama_guard)
|
764 |
+
filtered_result = filtered_result.replace("\n", " ") # Normalize the result
|
765 |
+
|
766 |
+
# Check if input is safe based on allowed statuses
|
767 |
+
if filtered_result in ["safe", "safe S7", "safe S6"]: # __________, __________, __________]: # Blanks #3, #4, #5: Fill in with allowed safe statuses (e.g., "safe", "unsafe S7", "unsafe S6")
|
768 |
+
try:
|
769 |
+
if 'chatbot' not in st.session_state:
|
770 |
+
st.session_state.chatbot = NutritionBot # () __________() # Blank #6: Fill in with the chatbot class initialization (e.g., NutritionBot)
|
771 |
+
response = st.session_state.chatbot.handle_customer_query(st.session_state.user_id, user_query) # __________(st.session_state.user_id, user_query)
|
772 |
+
# Blank #7: Fill in with the method to handle queries (e.g., handle_customer_query)
|
773 |
+
st.write(response)
|
774 |
+
st.session_state.chat_history.append({"role": "assistant", "content": response})
|
775 |
+
except Exception as e:
|
776 |
+
error_msg = f"Sorry, I encountered an error while processing your query. Please try again. Error: {str(e)}"
|
777 |
+
st.write(error_msg)
|
778 |
+
st.session_state.chat_history.append({"role": "assistant", "content": error_msg})
|
779 |
+
else:
|
780 |
+
inappropriate_msg = "I apologize, but I cannot process that input as it may be inappropriate. Please try again."
|
781 |
+
st.write(inappropriate_msg)
|
782 |
+
st.session_state.chat_history.append({"role": "assistant", "content": inappropriate_msg})
|
783 |
|
784 |
if __name__ == "__main__":
|
785 |
+
nutrition_disorder_streamlit()
|