jenngang committed on
Commit
4f82aa4
verified ·
1 Parent(s): 0fdd6fd

Upload app.py with huggingface_hub

Files changed (1)
  1. app.py +675 -501
app.py CHANGED
@@ -1,611 +1,785 @@
- import streamlit as st
- import json
- import os
- import uuid
- import pandas as pd
- from datetime import datetime
- import sqlite3
- import weave
- from langchain.memory import ConversationSummaryBufferMemory
- from openai import OpenAI
- from langchain_openai import ChatOpenAI
- #from langchain.embeddings import OpenAIEmbeddings
- from langchain_community.embeddings import OpenAIEmbeddings
- from langchain_openai import OpenAIEmbeddings
- from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
- from langchain_community.vectorstores import Chroma
- from langchain_core.caches import BaseCache
- from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
- from langchain_community.utilities.sql_database import SQLDatabase
- from langchain_community.agent_toolkits import create_sql_agent
- from langchain.agents import create_tool_calling_agent, AgentExecutor
- from langchain_core.tools import tool
- from huggingface_hub import CommitScheduler
- from pathlib import Path
- #BaseCache.register_cache_type("memory", lambda: None)
- #ChatOpenAI.model_rebuild()
  #====================================SETUP=====================================#
  # Fetch secrets from Hugging Face Spaces
-
- model_name = "gpt-4o"
-
- # Extract the OpenAI key and endpoint from the configuration
- api_key = os.environ["API_KEY"]
- endpoint = os.environ["OPENAI_API_BASE"]
-
- # Define the location of the SQLite database
- db_loc = 'ecomm.db'
-
- # Create a SQLDatabase instance from the SQLite database URI
- db = SQLDatabase.from_uri(f"sqlite:///{db_loc}")
-
- # Retrieve the schema information of the database tables
- database_schema = db.get_table_info()
-
- # Let's initiate W&B Weave with a project name - this will automatically save all the LLM calls made using OpenAI or Gemini
- # Make sure to save your W&B API key in secrets as WANDB_API_KEY
- weave.init('ecomm_support')
- # <--------------------------------------------------------- Uncomment to log to WANDB
-
- #=================================Setup Logging=====================================#
-
- log_file = Path("logs/") / f"data_{uuid.uuid4()}.json"
- log_folder = log_file.parent
-
- log_scheduler = CommitScheduler(
-     repo_id="chatbot-logs", # Dataset name where we want to save the logs
-     repo_type="dataset",
-     folder_path=log_folder,
-     path_in_repo="data",
-     every=5 # Saves data every 5 minutes
  )
-
- history_file = Path("history/") / f"data_{uuid.uuid4()}.json"
- history_folder = history_file.parent
-
- history_scheduler = CommitScheduler(
-     repo_id="chatbot-history", # Dataset name where we want to save the history
-     repo_type="dataset",
-     folder_path=history_folder,
-     path_in_repo="data",
-     every=5 # Saves data every 5 minutes
- )
-
- #=================================SQL_AGENT=====================================#
-
- # Define the system message for the agent, including instructions and available tables
- system_message = f"""You are a SQLite expert agent designed to interact with a SQLite database.
- Given an input question, create a syntactically correct SQLite query to run, then look at the results of the query and return the answer.
- Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 100 results using the LIMIT clause as per SQLite.
- You can order the results by a relevant column to return the most informative examples in the database.
- You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
- You have access to tools for interacting with the database.
- Only use the given tools. Only use the information returned by the tools to construct your final answer.
- You MUST double-check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
- DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
- You are not allowed to make dummy data.
- If the question does not seem related to the database, just return "I don't know" as the answer.
- Before you execute the query, briefly tell us why you are executing it and what you expect to find.
- Only use the following tables:
- {database_schema}
- """
-
- # Create a full prompt template for the agent using the system message and placeholders
- full_prompt = ChatPromptTemplate.from_messages(
-     [
-         ("system", system_message),
-         ("human", '{input}'),
-         MessagesPlaceholder("agent_scratchpad")
-     ]
  )

- # Initialize the ChatOpenAI model with the extracted configuration
  llm = ChatOpenAI(
      openai_api_base=endpoint,
      openai_api_key=api_key,
-     model="gpt-4o",
-     streaming=False # Explicitly disabling streaming
- )
-
- # Create the SQL agent using the ChatOpenAI model, database, and prompt template
- sqlite_agent = create_sql_agent(
-     llm=llm,
-     db=db,
-     prompt=full_prompt,
-     agent_type="openai-tools",
-     agent_executor_kwargs={'handle_parsing_errors': True},
-     max_iterations=5,
-     verbose=True
  )
- #### Let's convert the SQL agent into a tool that our main agent can use.
-
- @tool
- def sql_tool(user_input):
      """
-     Gathers information regarding purchases, transactions, returns, refunds, etc.
-     Executes a SQL query using the sqlite_agent and returns the result.
      Args:
-         user_input (str): A natural language query string explaining what information is required, while also providing the necessary details to get the information.
      Returns:
-         str: The result of the SQL query execution. If an error occurs, the exception is returned as a string.
      """
-     try:
-         # Invoke the sqlite_agent with the user input (natural language query)
-         response = sqlite_agent.invoke(user_input)
-
-         # Extract the output from the response
-         prediction = response['output']
-
-     except Exception as e:
-         # If an exception occurs, capture the exception message
-         prediction = e
-
-     # Return the result or the exception message
-     return prediction
-
- #=================================== RAG TOOL======================================#
- qna_system_message = """
- You are an assistant to a support agent. Your task is to provide relevant information about the business's products and policies.
- User input will include the necessary context for you to answer their questions. This context will begin with the token: ###Context.
- The context contains references to specific portions of documents relevant to the user's query, along with source links.
- The source for a context will begin with the token ###Source.
- When crafting your response:
- 1. Select only context relevant to answer the question.
- 2. User questions will begin with the token: ###Question.
- 3. If the context provided doesn't answer the question, respond with - "I do not have sufficient information to answer that"
- 4. If the user asks for a product, list all the products that are relevant to their query. If you don't have that product, try to cross-sell with one of the related products we do have.
- You should get information about similar products in the context.
- Please adhere to the following guidelines:
- - Your response should only be about the question asked and nothing else.
- - Answer only using the context provided.
- - Do not mention anything about the context in your final answer.
- - If the answer is not found in the context, it is very important for you to respond with "I don't know."
- - Always quote the source when you use the context. Cite the relevant source at the end of your response under the section - Source:
- - Do not make up sources. Use the links provided in the sources section of the context and nothing else. You are prohibited from providing other links/sources.
- Here is an example of how to structure your response:
- Answer:
- [Answer]
- Source:
- [Source]
- """
-
- qna_user_message_template = """
- ###Context
- Here are some documents and their source that may be relevant to the question mentioned below.
- {context}
- ###Question
- {question}
- """
- # Location of the persisted vector DB
- persisted_vectordb_location = 'policy_docs/policy_docs'
- # Create a collection name
- collection_name = 'policy_docs'
-
- embedding_model = SentenceTransformerEmbeddings(model_name='thenlper/gte-large')
- # Load the persisted DB
  vector_store = Chroma(
-     collection_name=collection_name,
-     persist_directory=persisted_vectordb_location,
      embedding_function=embedding_model
  )

  retriever = vector_store.as_retriever(
      search_type='similarity',
-     search_kwargs={'k': 5}
  )

- client = OpenAI(
-     api_key=api_key,
-     base_url=endpoint
- )

- @tool
- def rag(user_input: str) -> str:
      """
-     Answers questions regarding products and policies using product descriptions, product policies, and general business policies via RAG.
-     Args:
-         user_input (str): The input question or query from the user.
-     Returns:
-         response (str): The generated response, or an error message if an exception occurs.
      """
-     relevant_document_chunks = retriever.invoke(user_input)
-     context_list = [d.page_content + "\n ###Source: " + d.metadata['source'] + "\n\n " for d in relevant_document_chunks]
-     context_for_query = ". ".join(context_list)
-
-     prompt = [
-         {'role': 'system', 'content': qna_system_message},
-         {'role': 'user', 'content': qna_user_message_template.format(
-             context=context_for_query,
-             question=user_input
-         )}
-     ]
-     try:
-         response = client.chat.completions.create(
-             model="gpt-4o",
-             messages=prompt
-         )
-         prediction = response.choices[0].message.content
-     except Exception as e:
-         prediction = f'Sorry, I encountered the following error: \n {e}'
-
-     return prediction
-
- #=================================== Other TOOLS======================================#
-
- # Function to log chat history
- def log_history(email: str, chat_history: list) -> None:
-     # Save the log to the file
-     with history_scheduler.lock:
-         # Open the log file in append mode
-         with history_file.open("a") as f:
-             f.write(json.dumps({
-                 "email": email,
-                 "chat_history": chat_history,
-                 "timestamp": str(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
-             }))
-     #st.write("chat_recorded")

- def log_action(customer_id: str, task: str, details: str) -> None:
-     # Save the log to the file
-     with log_scheduler.lock:
-         # Open the log file in append mode
-         with log_file.open("a") as f:
-             f.write(json.dumps({
-                 "customer_id": customer_id,
-                 "task": task,
-                 "details": details
-             }))

- @tool
- def register_feedback(intent, customer_id, feedback, rating):
      """
-     Logs customer feedback into the feedback log.
      Args:
-         intent (str): The category of the support query (e.g., "cancel_order", "get_refund").
-         customer_id (int): The unique ID of the customer.
-         feedback (str): The feedback provided by the customer.
-         rating (int): The rating provided by the customer out of 5.
      Returns:
-         str: Success message.
      """
-     details = {
-         "intent": intent,
-         "customer_id": customer_id,
-         "feedback": feedback,
-         "rating": rating
-     }
-     log_action(customer_id, "register_feedback", details)
-     #print("register_feedback success")
-     #return "Feedback registered successfully!"

- @tool
- def defer_to_human(customer_id, query, intent, reason):
      """
-     Logs customer details and the reason for deferring to a human agent.
      Args:
-         customer_id (int): The unique ID of the customer whose query is being deferred.
-         query (str): The customer's query or issue that needs human intervention.
-         reason (str): The reason why the query cannot be resolved by the chatbot.
      Returns:
-         str: Success message indicating the deferral was logged.
      """
-     details = {
-         "customer_id": customer_id,
-         "query": query,
-         "reason": reason,
-         "intent": intent
-     }
-     log_action(customer_id, "defer_to_human", details)
-     #return "Case deferred to human agent and logged successfully!"

- @tool
- def days_since(delivered_date: str) -> str:
      """
-     Calculates the number of days since the product was delivered. This helps in determining whether the product is within the return period.
      Args:
-         delivered_date (str): The date when the product was delivered in the format 'YYYY-MM-DD'.
-     """
-     try:
-         # Convert the delivered_date string to a datetime object
-         delivered_date = datetime.strptime(delivered_date, '%Y-%m-%d')
-         today = datetime.today()
-
-         # Calculate the difference in days
-         days_difference = (today - delivered_date).days
-
-         return str(days_difference)
-     except ValueError as e:
-         return f"Error: {e}"
-
- def build_prompt(df):
-
-     system_message = f"""
-     You are an intelligent e-commerce chatbot designed to assist users with pre-order and post-order queries. Your job is to
-     gather the necessary information from the user to help them with their query.
-     If at any point you cannot determine the next steps - defer to human. You do not have clearance to go beyond the scope of the following flow.
-     Do not provide SQL inputs to the sql tool - you only need to ask in natural language for the information you need.
-     You are only allowed to provide information relevant to the particular customer, and the customer information is provided below. You can provide information for this customer only. Following is the information about the customer from the last 2 weeks:
-
-     {df}
-
-     If this information is not enough to answer the question, identify the customer from the data above and fetch the necessary information using the sql_tool or rag tool - do not fetch information of other customers.
-     Use the details provided above to fetch information from the sql tool - like customer id, email and phone. Refrain from asking customers for details unless necessary.
-     If the customer asks about a product, you should act as a sales representative and help them understand the product as much as possible and provide all the necessary information for them. You should also provide them the link to the product, which you can get from the source of the information.
-     If a customer asks a query about a policy, stay grounded to the context provided to you. If at any point you don't know the right thing to say, politely tell the customer that you are not the right person to answer this and defer it to a human.
-     Any time you defer to a human, you should tell the customer why you did it in a polite manner.
-     MANDATORY STEP:
-     After helping the customer with their concern,
-     - Ask if the customer needs help with anything else. If they ask for anything from the above list, help them, and along with that:
-     1. Ask for their feedback and rating out of 5.
-     2. Then use the `register_feedback` tool to log it - you MUST ask for customer feedback along with asking the customer what else they need help with.
-     3. After receiving customer feedback, exit the chat by responding with 'Bye'.
-
-     ---
-     ### **Handling Out-of-Scope Queries:**
-     If the user's query, at any point, is not covered by the workflows above:
-     - Respond:
-     > "This is beyond my skill. Let me connect you to a customer service agent" and get the necessary details from the customer and use the defer_to_human tool.
-     - Get customer feedback and a rating out of 5.
-     - After getting feedback, end the conversation by saying 'Bye'.
-     ---
-     ### **IMPORTANT Notes for the Model:**
-     - Always fetch additional required details from the database and do not blindly believe details provided by the customer like customer id, email and phone number. You should get the customer id from the system prompt. Cross-check with the database and stay loyal to the database.
-     - Be empathetic to the customer but loyal to the instructions provided to you. Try to de-escalate a situation before deferring it to a human, and defer to a human only once.
-     - Always aim to minimize the number of questions asked by retrieving as much information as possible from the `sql_tool` and `rag` tool.
-     - Follow the exact workflows for each query category.
-     - You will always confirm the order id, even if the customer has only one order, before you fetch any details.
-     """
-
-     #st.write(system_message)
-     prompt = ChatPromptTemplate.from_messages([
          ("system", system_message),
-         ("human", "{input}"),
-         ("placeholder", "{agent_scratchpad}"),
      ])
-
-     return prompt
- #===============================================Streamlit=========================================#

- def login_page():
-     st.title("Login Page")
-
-     email = st.text_input("Email")
-     password = st.text_input("Password", type="password")
-
-     login_button = st.button("Login")
-
-     if login_button:
-         if authenticate_user(email, password):
-             st.session_state.logged_in = True
-             st.session_state.email = email
-             st.success("Login successful! Redirecting to Chatbot...")
-             st.rerun()
-         else:
-             st.error("Invalid email or password.")

- def authenticate_user(email, phone):
-     connection = sqlite3.connect("ecomm.db") # Replace with your .db file path
-     cursor = connection.cursor()
-
-     query = "SELECT first_name FROM customers WHERE email = ? AND phone = ?"
-     cursor.execute(query, (email, phone))
-     user = cursor.fetchone()
-
-     if user:
-         return True # Login successful
-     return False # Login failed

- ### Prefetch details

- def fetch_details(email):
-     try:
-         # Connect to the SQLite database
-         connection = sqlite3.connect("ecomm.db") # Replace with your .db file path
-         cursor = connection.cursor()
-
-         query = f"""
-             SELECT
-                 c.customer_id,
-                 c.first_name || ' ' || c.last_name AS customer_name,
-                 c.email,
-                 c.phone,
-                 c.address AS customer_address,
-                 o.order_id,
-                 o.order_date,
-                 o.status AS order_status,
-                 o.price AS order_price,
-                 p.name AS product_name,
-                 p.price AS product_price,
-                 i.invoice_date,
-                 i.amount AS invoice_amount,
-                 i.invoice_url,
-                 s.delivery_date,
-                 s.shipping_status,
-                 s.shipping_address,
-                 r.refund_amount,
-                 r.refund_status
-             FROM Customers c
-             LEFT JOIN Orders o ON c.customer_id = o.customer_id
-             LEFT JOIN Products p ON o.product_id = p.product_id
-             LEFT JOIN Invoices i ON o.order_id = i.order_id
-             LEFT JOIN Shipping s ON o.order_id = s.order_id
-             LEFT JOIN Refund r ON o.order_id = r.order_id
-             WHERE o.order_date >= datetime('now', '-30 days')
-             AND c.email = ?
-             ORDER BY o.order_date DESC;
-         """
-
-         cursor.execute(query, (email,))
-         columns = [description[0] for description in cursor.description] # Extract column names
-         results = cursor.fetchall() # Fetch all rows
-         #st.write(results)
-         # Convert results into a list of dictionaries
-         details = [dict(zip(columns, row)) for row in results]
-         #st.write(details)
-         return str(details).replace("{","/").replace("}","/")
-
-     except Exception as e:
-         st.write(f"Error: {e}")
-     finally:
-         # Close the connection
-         if connection:
-             cursor.close()
-             connection.close()

- # Function to process user input and generate a chatbot response

- def chatbot_interface():
-     st.title("E-Commerce Chatbot")
-
-     if 'conversation_history' not in st.session_state:
-         st.session_state.conversation_history = [{"role": "assistant", "content": "Welcome! I am Raha, how can I help you on this beautiful day?"}]
-
-     details = fetch_details(st.session_state.email)
-     # st.write(details)
-     prompt = build_prompt(details)
-     tools = []
-     #[sql_tool, defer_to_human, rag, register_feedback, days_since]
-
-     chatbot = ChatOpenAI(
-         openai_api_base=endpoint,
-         openai_api_key=api_key,
-         model="gpt-4o",
-         streaming=False, # Explicitly disabling streaming
-         temperature=0
-     )
-
      try:
-         st.write("Attempting direct LLM test...")
-         test_response = chatbot.invoke("Hello, can you hear me?")
-         st.success(f"Direct LLM Test OK: Received response.")
-         # Optionally display part of the response if needed for confirmation
-         # st.write(test_response.content[:100])
      except Exception as e:
-         st.error(f"Direct LLM Test FAILED: {e}")
-         st.error("The basic connection to the LLM endpoint might be failing. Check API Key, Endpoint URL, and Network.")
-         # You might want to stop execution here if the basic test fails
-         st.stop()

-     agent = create_tool_calling_agent(chatbot, tools, prompt)
-     agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)

-     # Display chat messages from history on app rerun
-     for message in st.session_state.conversation_history:
-         with st.chat_message(message["role"]):
-             st.markdown(message["content"])

-     # React to user input
-     if user_input := st.chat_input("You: ", key="chat_input"):
-         # Display user message in chat message container
-         st.chat_message("user").markdown(user_input)
-         with st.spinner("Processing..."):

-             # Add user message to conversation history
-             st.session_state.conversation_history.append({"role": "user", "content": user_input})

-             conversation_input = "\n".join(
-                 [f"{turn['role'].capitalize()}: {turn['content']}" for turn in st.session_state.conversation_history]
-             )

-             try:
-                 # Pass the history to the agent
-                 response = agent_executor.invoke({"input": conversation_input})

-                 # Add the chatbot's response to the history
-                 chatbot_response = response['output']
-                 st.session_state.conversation_history.append({"role": "assistant", "content": chatbot_response})
-                 # Check if the assistant's response signals the end of the chat
-                 if "bye" in chatbot_response.lower():
-                     log_history(st.session_state.email, st.session_state.conversation_history)

-                 # Display the chatbot's response
-                 with st.chat_message("assistant"):
-                     st.markdown(chatbot_response)
-
-             except ValueError as ve:
-                 if "No generation chunks were returned" in str(ve):
-                     st.error(f"Agent Error: Failed to get a response from the LLM during agent execution (ValueError: {ve}).")
-                     st.error("This often indicates an issue with the custom API endpoint (compatibility, content filtering, or error) or network problems.")
-                     st.error("Check the Hugging Face Space logs for more details, especially network errors or messages related to content policy.")
-                     # Log the error for debugging
-                     print(f"Caught ValueError: {ve}") # This will print to HF Space logs
-                     # Optionally add the error to chat history for visibility
-                     st.session_state.conversation_history.append({"role": "assistant", "content": f"Sorry, I encountered an internal error (ValueError: No generation chunks). Please check logs or try again later."})
-                     with st.chat_message("assistant"):
-                         st.markdown("Sorry, I encountered an internal error (ValueError: No generation chunks). Please check logs or try again later.")
-
-                 else:
-                     # Handle other ValueErrors if necessary
-                     st.error(f"Agent Error: An unexpected ValueError occurred: {ve}")
-                     print(f"Caught other ValueError: {ve}")
-                     st.session_state.conversation_history.append({"role": "assistant", "content": f"Sorry, I encountered an unexpected internal error (ValueError)."})
-                     with st.chat_message("assistant"):
-                         st.markdown("Sorry, I encountered an unexpected internal error (ValueError).")
-
-             except Exception as e:
-                 st.error(f"Agent Error: An unexpected error occurred: {e}")
-                 st.error("Check Hugging Face Space logs for the full traceback.")
-                 # Log the full error and traceback for debugging
-                 import traceback
-                 print(f"Caught Exception: {e}")
-                 print(traceback.format_exc()) # Print full traceback to logs
-                 # Check if it looks like a content policy error based on the message
-                 if "policy" in str(e).lower() or "safety" in str(e).lower() or "blocked" in str(e).lower():
-                     st.warning("The error message suggests the request might have been blocked by a content policy.")
-                     st.session_state.conversation_history.append({"role": "assistant", "content": f"Sorry, my response might have been blocked by a content policy. Error: {e}"})
-                     with st.chat_message("assistant"):
-                         st.markdown(f"Sorry, my response might have been blocked by a content policy. Error: {e}")
-                 else:
-                     st.session_state.conversation_history.append({"role": "assistant", "content": f"Sorry, I encountered an unexpected error: {e}"})
-                     with st.chat_message("assistant"):
-                         st.markdown(f"Sorry, I encountered an unexpected error: {e}")
-
- def main():
-     # Check if the user is logged in
-     if "logged_in" in st.session_state and st.session_state["logged_in"]:
-         # Show chatbot page if logged in
-         chatbot_interface()
      else:
-         # Show login page if not logged in
-         login_page()

  if __name__ == "__main__":
-     main()
+ # Import necessary libraries
+ import os # Interacting with the operating system (reading/writing files)
+ import chromadb # High-performance vector database for storing/querying dense vectors
+ from dotenv import load_dotenv # Loading environment variables from a .env file
+ import json # Parsing and handling JSON data
+
+ # LangChain imports
+ from langchain_core.documents import Document # Document data structures
+ from langchain_core.runnables import RunnablePassthrough # LangChain core library for running pipelines
+ from langchain_core.output_parsers import StrOutputParser # String output parser
+ from langchain.prompts import ChatPromptTemplate # Template for chat prompts
+ from langchain.chains.query_constructor.base import AttributeInfo # Base classes for query construction
+ from langchain.retrievers.self_query.base import SelfQueryRetriever # Base classes for self-querying retrievers
+ from langchain.retrievers.document_compressors import LLMChainExtractor, CrossEncoderReranker # Document compressors
+ from langchain.retrievers import ContextualCompressionRetriever # Contextual compression retrievers
+
+ # LangChain community & experimental imports
+ from langchain_community.vectorstores import Chroma # Implementations of vector stores like Chroma
+ from langchain_community.document_loaders import PyPDFDirectoryLoader, PyPDFLoader # Document loaders for PDFs
+ from langchain_community.cross_encoders import HuggingFaceCrossEncoder # Cross-encoders from HuggingFace
+ from langchain_experimental.text_splitter import SemanticChunker # Experimental text splitting methods
+ from langchain.text_splitter import (
+     CharacterTextSplitter, # Splitting text by characters
+     RecursiveCharacterTextSplitter # Recursive splitting of text by characters
+ )
+ from langchain_core.tools import tool
+ from langchain.agents import create_tool_calling_agent, AgentExecutor
+ from langchain_core.prompts import ChatPromptTemplate
+
+ # LangChain OpenAI imports
+ from langchain_openai import ChatOpenAI, AzureOpenAIEmbeddings, AzureChatOpenAI # Chat models and embeddings (ChatOpenAI added here: it is used below but was not imported)
+ from langchain.embeddings.openai import OpenAIEmbeddings # OpenAI embeddings for text vectors
+
+ # LlamaParse & LlamaIndex imports
+ from llama_parse import LlamaParse # Document parsing library
+ from llama_index.core import Settings, SimpleDirectoryReader # Core functionalities of the LlamaIndex
+
+ # LangGraph import
+ from langgraph.graph import StateGraph, END, START # State graph for managing states in LangChain
+
+ # Pydantic import
+ from pydantic import BaseModel # Pydantic for data validation
+
+ # Typing imports
+ from typing import Dict, List, Tuple, Any, TypedDict # Python typing for function annotations
+
+ # Other utilities
+ import numpy as np # Numpy for numerical operations
+ from groq import Groq
+ from mem0 import MemoryClient
+ import streamlit as st
+ from datetime import datetime
+
  #====================================SETUP=====================================#
  # Fetch secrets from Hugging Face Spaces
+ config = os.environ # Assumption: `config` was undefined in the original; secrets are read from the Space's environment variables
+ api_key = config.get("API_KEY")
+ endpoint = config.get("OPENAI_API_BASE")
+ groq_api_key = config.get('LLAMA_API_KEY') # llama_api_key = os.environ['GROQ_API_KEY']
+ MEM0_api_key = config.get('mem0') # MEM0_api_key = os.environ['mem0']
+
+ # Initialize the OpenAI embedding function for Chroma
+ embedding_function = chromadb.utils.embedding_functions.OpenAIEmbeddingFunction(
+     api_base=endpoint, # API base endpoint
+     api_key=api_key, # API key
+     model_name='text-embedding-ada-002' # This is a fixed value and does not need modification
  )

+ # This initializes the OpenAI embedding function for the Chroma vectorstore, using the provided endpoint and API key.

+ # Initialize the OpenAI Embeddings
+ embedding_model = OpenAIEmbeddings(
+     openai_api_base=endpoint,
+     openai_api_key=api_key,
+     model='text-embedding-ada-002'
  )

+ # Initialize the Chat OpenAI model
  llm = ChatOpenAI(
      openai_api_base=endpoint,
      openai_api_key=api_key,
+     model="gpt-4o", # used gpt-4o instead of gpt-4o-mini to get improved results
+     streaming=False
  )
+ # This initializes the Chat OpenAI model with the provided endpoint, API key and model name, with streaming disabled.
+
+ # Set the LLM and embedding model in the LlamaIndex settings.
+ Settings.llm = llm
+ Settings.embed_model = embedding_model # LlamaIndex exposes `embed_model`; the original assigned to `Settings.embedding`, which has no effect
+
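As a quick sanity check of this setup (a sketch, not part of the commit; it assumes the endpoint above actually serves text-embedding-ada-002), both embedding clients can be exercised directly:

    # Sketch: verify both embedding clients return vectors of the expected size.
    sample = "What are the symptoms of anorexia nervosa?"
    vec = embedding_model.embed_query(sample)   # LangChain client
    print(len(vec))                             # ada-002 vectors have 1536 dimensions
    vecs = embedding_function([sample])         # Chroma client; takes a list of texts
    print(len(vecs[0]))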
+ #================================Creating Langgraph agent======================#
+
+ class AgentState(TypedDict):
+     query: str # The current user query
+     expanded_query: str # The expanded version of the user query
+     context: List[Dict[str, Any]] # Retrieved documents (content and metadata)
+     response: str # The generated response to the user query
+     precision_score: float # The precision score of the response
+     groundedness_score: float # The groundedness score of the response
+     groundedness_loop_count: int # Counter for groundedness refinement loops
+     precision_loop_count: int # Counter for precision refinement loops
+     feedback: str
+     query_feedback: str
+     groundedness_check: bool
+     loop_max_iter: int
+
+ def expand_query(state):
      """
+     Expands the user query to improve retrieval of nutrition disorder-related information.
+
      Args:
+         state (Dict): The current state of the workflow, containing the user query.
+
      Returns:
+         Dict: The updated state with the expanded query.
      """
+     print("---------Expanding Query---------")
+     system_message = '''
+     You are a domain expert assisting in answering questions related to research papers.
+     Convert the user query into something that a nutritionist would understand. Use domain-related words.
+     Return 3 related search queries based on the user's request, separated by newlines.
+     Return only 3 versions of the question as a list.
+     Perform query expansion on the question received. If there are multiple common ways of phrasing a user question \
+     or common synonyms for key words in the question, make sure to return multiple versions \
+     of the query with the different phrasings.
+     If the query has multiple parts, split them into separate simpler queries. This is the only case where you can generate more than 3 queries.
+     If there are acronyms or words you are not familiar with, do not try to rephrase them.
+     Generate only a list of questions. Do not mention anything before or after the list.
+     Use the query feedback if provided to craft the search queries.
+     '''
+
+     expand_prompt = ChatPromptTemplate.from_messages([
+         ("system", system_message),
+         ("user", "Expand this query: {query} using the feedback: {query_feedback}")
+     ])
+
+     chain = expand_prompt | llm | StrOutputParser()
+     expanded_query = chain.invoke({"query": state['query'], "query_feedback": state["query_feedback"]})
+     print("expanded_query", expanded_query)
+     state["expanded_query"] = expanded_query
+     return state
+
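For illustration, this node can be exercised on its own with a minimal state dict (a sketch; only the keys expand_query actually reads are filled in):

    # Sketch: run the expand_query node standalone with a minimal state.
    demo_state = {"query": "foods that help with iron deficiency", "query_feedback": ""}
    demo_state = expand_query(demo_state)
    print(demo_state["expanded_query"])  # three rephrased search queries, newline-separated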
+ # Initialize the Chroma vector store for retrieving documents
  vector_store = Chroma(
+     collection_name="nutritional_hypotheticals",
+     persist_directory="./nutritional_db2",
      embedding_function=embedding_model
  )

+ # Create a retriever from the vector store
  retriever = vector_store.as_retriever(
      search_type='similarity',
+     search_kwargs={'k': 3}
  )

+ def retrieve_context(state):
+     """
+     Retrieves context from the vector store using the expanded or original query.
+
+     Args:
+         state (Dict): The current state of the workflow, containing the query and expanded query.
+
+     Returns:
+         Dict: The updated state with the retrieved context.
+     """
+     print("---------retrieve_context---------")
+     query = state['query'] # Note: retrieval uses the original query; state['expanded_query'] could be used here instead
+     #print("Query used for retrieval:", query) # Debugging: Print the query
+
+     # Retrieve documents from the vector store
+     docs = retriever.invoke(query)
+     print("Retrieved documents:", docs) # Debugging: Print the raw docs object
+
+     # Extract both page_content and metadata from each document
+     context = [
+         {
+             "content": doc.page_content, # The actual content of the document
+             "metadata": doc.metadata # The metadata (e.g., source, page number, etc.)
+         }
+         for doc in docs
+     ]
+     state['context'] = context # Store the extracted context in the state
+     print("Extracted context with metadata:", context) # Debugging: Print the extracted context
+     #print(f"Groundedness loop count: {state['groundedness_loop_count']}")
+     return state
+ def craft_response(state: Dict) -> Dict:
      """
+     Generates a response using the retrieved context, focusing on nutrition disorders.
+
+     Args:
+         state (Dict): The current state of the workflow, containing the query and retrieved context.
+
+     Returns:
+         Dict: The updated state with the generated response.
      """
+     print("---------craft_response---------")
+     system_message = '''
+     Generate a response to the user query from the context provided.
+
+     Parameters:
+     query (str): The user's query and expanded queries based on the user's query.
+     context (str): The documents retrieved as relevant to the queries.
+
+     Returns:
+     response (str): The response generated by the model.
+
+     The answer you provide must come from the user queries with the context provided.
+     If feedback is provided, use it to craft the response.
+     If the information provided is not enough to answer the query, respond with 'I don't know the answer. Not in my records.'
+     '''
+
+     response_prompt = ChatPromptTemplate.from_messages([
+         ("system", system_message),
+         ("user", "Query: {query}\nContext: {context}\n\nfeedback: {feedback}")
+     ])
+
+     chain = response_prompt | llm
+     response = chain.invoke({
+         "query": state['query'],
+         "context": "\n".join([doc["content"] for doc in state['context']]),
+         "feedback": state["feedback"] # Feedback from a failed groundedness check, if any
+     })
+     state['response'] = response
+     print("intermediate response: ", response)
+
+     return state
+
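Every node in this graph uses the same LCEL composition (prompt piped into the model, optionally into a parser). As a standalone sketch of that pattern:

    # Sketch: the prompt | llm | parser pipe pattern used throughout this file.
    demo_prompt = ChatPromptTemplate.from_messages([
        ("system", "You answer in one sentence."),
        ("user", "{question}"),
    ])
    demo_chain = demo_prompt | llm | StrOutputParser()
    print(demo_chain.invoke({"question": "What is a balanced diet?"}))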
+ def score_groundedness(state: Dict) -> Dict:
+     """
+     Checks whether the response is grounded in the retrieved context.
+
+     Args:
+         state (Dict): The current state of the workflow, containing the response and context.
+
+     Returns:
+         Dict: The updated state with the groundedness score.
+     """
+     print("---------check_groundedness---------")
+     system_message = '''
+     You are tasked with rating AI-generated answers to questions posed by users.
+     Please act as an impartial judge and evaluate the quality of the provided answer, which attempts to answer the provided question based on a provided context.
+     In the input, the context is {context}, while the AI-generated response is {response}.
+
+     Evaluation criteria:
+     The task is to judge the extent to which the metric is followed by the answer.
+     1 - The metric is not followed at all
+     2 - The metric is followed only to a limited extent
+     3 - The metric is followed to a good extent
+     4 - The metric is followed mostly
+     5 - The metric is followed completely
+
+     The answer should be derived only from the information presented in the context.
+
+     Do not show any instructions for deriving your answer.
+
+     Output your result as a float number between 0 and 1 using the evaluation criteria.
+     The better the answer meets the criteria, the closer it is to 1; the worse, the closer it is to 0.
+     '''
+
+     groundedness_prompt = ChatPromptTemplate.from_messages([
+         ("system", system_message),
+         ("user", "Context: {context}\nResponse: {response}\n\nGroundedness score:")
+     ])
+
+     chain = groundedness_prompt | llm | StrOutputParser()
+     groundedness_score = float(chain.invoke({
+         "context": "\n".join([doc["content"] for doc in state['context']]),
+         "response": state['response']
+     }))
+     print("groundedness_score: ", groundedness_score)
+     state['groundedness_loop_count'] += 1
+     print("#########Groundedness Incremented###########")
+     state['groundedness_score'] = groundedness_score
+
+     return state
+
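Note that float(chain.invoke(...)) assumes the model returns a bare number; a more defensive parse (a sketch, not in the original) would extract the first number and clamp it:

    # Sketch: defensive score parsing, assuming the model may wrap the number in extra text.
    import re

    def parse_score(raw: str, default: float = 0.0) -> float:
        match = re.search(r"\d*\.?\d+", raw)
        score = float(match.group()) if match else default
        return max(0.0, min(1.0, score))  # clamp into [0, 1]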
+ def check_precision(state: Dict) -> Dict:
      """
+     Checks whether the response precisely addresses the user's query.
+
      Args:
+         state (Dict): The current state of the workflow, containing the query and response.
+
      Returns:
+         Dict: The updated state with the precision score.
      """
+     print("---------check_precision---------")
+     system_message = '''
+     Given a question, answer and context, verify if the context was useful in arriving at the given answer.
+     Give the verdict as a scaled numeric value of type float between 0 and 1, such that it is
+     0 or near 0 if least useful, 0.5 or near 0.5 if a retry is warranted, and 1 or close to 1 if most useful.
+     Do not show any instructions for deriving your answer.
+     '''
+
+     precision_prompt = ChatPromptTemplate.from_messages([
+         ("system", system_message),
+         ("user", "Query: {query}\nResponse: {response}\n\nPrecision score:")
+     ])
+
+     chain = precision_prompt | llm | StrOutputParser()
+     precision_score = float(chain.invoke({
+         "query": state['query'],
+         "response": state['response']
+     }))
+     state['precision_score'] = precision_score
+     print("precision_score:", precision_score)
+     state['precision_loop_count'] += 1
+     print("#########Precision Incremented###########")
+     return state
+
+ def refine_response(state: Dict) -> Dict:
      """
+     Suggests improvements for the generated response.
+
      Args:
+         state (Dict): The current state of the workflow, containing the query and response.
+
      Returns:
+         Dict: The updated state with response refinement suggestions.
      """
+     print("---------refine_response---------")
+
+     system_message = '''
+     Since the last response failed the groundedness test and is deemed not satisfactory,
+     use the feedback in terms of the query, context and the last response
+     to identify potential gaps, ambiguities, or missing details, and
+     to suggest improvements to enhance the accuracy and completeness of the response.
+     '''
+
+     refine_response_prompt = ChatPromptTemplate.from_messages([
+         ("system", system_message),
+         ("user", "Query: {query}\nResponse: {response}\n\n"
+                  "What improvements can be made to enhance accuracy and completeness?")
+     ])
+
+     chain = refine_response_prompt | llm | StrOutputParser()
+
+     # Store response suggestions in a structured format
+     feedback = f"Previous Response: {state['response']}\nSuggestions: {chain.invoke({'query': state['query'], 'response': state['response']})}"
+     print("feedback: ", feedback)
+     print(f"State: {state}")
+     state['feedback'] = feedback
+     return state
+
+
+ def refine_query(state: Dict) -> Dict:
      """
+     Suggests improvements for the expanded query.
+
      Args:
+         state (Dict): The current state of the workflow, containing the query and expanded query.
+
+     Returns:
+         Dict: The updated state with query refinement suggestions.
+     """
+     print("---------refine_query---------")
+     system_message = '''
+     Since the last response failed the precision test and is deemed not satisfactory,
+     use the feedback in terms of the query and context to re-generate extended queries,
+     to identify specific keywords, scope refinements, or missing details, and
+     to provide structured suggestions for improvement to enhance the accuracy and completeness of the response.
+     '''
+
+     refine_query_prompt = ChatPromptTemplate.from_messages([
          ("system", system_message),
+         ("user", "Original Query: {query}\nExpanded Query: {expanded_query}\n\n"
+                  "What improvements can be made for a better search?")
      ])

+     chain = refine_query_prompt | llm | StrOutputParser()
+
+     # Store refinement suggestions without modifying the original expanded query
+     query_feedback = f"Previous Expanded Query: {state['expanded_query']}\nSuggestions: {chain.invoke({'query': state['query'], 'expanded_query': state['expanded_query']})}"
+     print("query_feedback: ", query_feedback)
+     print(f"Groundedness loop count: {state['groundedness_loop_count']}")
+     state['query_feedback'] = query_feedback
+     return state
+
+
+ def should_continue_groundedness(state):
+     """Decides if groundedness is sufficient or needs improvement."""
+     print("---------should_continue_groundedness---------")
+     print("groundedness loop count: ", state['groundedness_loop_count'])
+     if state['groundedness_score'] >= 0.8: # Threshold for groundedness (a colon was missing here)
+         print("Moving to precision")
+         return "check_precision"
+     else:
+         if state["groundedness_loop_count"] > state['loop_max_iter']:
+             return "max_iterations_reached"
+         else:
+             print(f"---------Groundedness Score Threshold Not met. Refining Response-----------")
+             return "refine_response"
+
+
+ def should_continue_precision(state: Dict) -> str:
+     """Decides if precision is sufficient or needs improvement."""
+     print("---------should_continue_precision---------")
+     print("precision loop count: ", state["precision_loop_count"])
+     if state['precision_score'] > 0.8: # Threshold for precision
+         return "pass" # Complete the workflow
+     else:
+         if state["precision_loop_count"] >= state['loop_max_iter']: # Maximum allowed loops
+             return "max_iterations_reached"
+         else:
+             print(f"---------Precision Score Threshold Not met. Refining Query-----------") # Debugging
+             return "refine_query"
+
+ def max_iterations_reached(state: Dict) -> Dict:
+     """Handles the case when the maximum number of iterations is reached."""
+     print("---------max_iterations_reached---------")
+     response = "I'm unable to refine the response further. Please provide more context or clarify your question."
+     state['response'] = response
+     return state
+
+
+ from langgraph.graph import END, StateGraph, START
+
+ def create_workflow() -> StateGraph:
+     """Creates the updated workflow for the AI nutrition agent."""
+     workflow = StateGraph(AgentState) # The graph is typed with the AgentState defined above (the original passed an undefined `State`)
+
+     # Add processing nodes
+     workflow.add_node("expand_query", expand_query) # Step 1: Expand the user query
+     workflow.add_node("retrieve_context", retrieve_context) # Step 2: Retrieve relevant documents
+     workflow.add_node("craft_response", craft_response) # Step 3: Generate a response based on retrieved data
+     workflow.add_node("score_groundedness", score_groundedness) # Step 4: Evaluate response grounding
+     workflow.add_node("refine_response", refine_response) # Step 5: Improve the response if it is weakly grounded
+     workflow.add_node("check_precision", check_precision) # Step 6: Evaluate response precision
+     workflow.add_node("refine_query", refine_query) # Step 7: Improve the query if the response lacks precision
+     workflow.add_node("max_iterations_reached", max_iterations_reached) # Step 8: Handle max iterations
+
+     # Main flow edges
+     workflow.add_edge(START, "expand_query")
+     workflow.add_edge("expand_query", "retrieve_context")
+     workflow.add_edge("retrieve_context", "craft_response")
+     workflow.add_edge("craft_response", "score_groundedness")
+
+     # Conditional edges based on groundedness check
+     workflow.add_conditional_edges(
+         "score_groundedness",
+         should_continue_groundedness, # The conditional function
+         {
+             "check_precision": "check_precision", # If well-grounded, proceed to the precision check.
+             "refine_response": "refine_response", # If not, refine the response.
+             "max_iterations_reached": "max_iterations_reached" # If max loops reached, exit (the original mapped to the function object instead of the node name).
+         }
+     )
+
+     workflow.add_edge("refine_response", "craft_response") # Refined responses are reprocessed.
+
+     # Conditional edges based on precision check
+     workflow.add_conditional_edges(
+         "check_precision",
+         should_continue_precision, # The conditional function
+         {
+             "pass": END, # If precise, complete the workflow.
+             "refine_query": "refine_query", # If imprecise, refine the query (a comma was missing here).
+             "max_iterations_reached": "max_iterations_reached" # If max loops reached, exit.
+         }
+     )
+
+     workflow.add_edge("refine_query", "expand_query") # Refined queries go through expansion again.
+
+     workflow.add_edge("max_iterations_reached", END)
+
+     return workflow
+
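For inspection, the compiled graph's wiring can be printed (a sketch; draw_ascii requires the optional grandalf dependency):

    # Sketch: visualize the node wiring before running the graph.
    print(create_workflow().compile().get_graph().draw_ascii())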
+ #=========================== Defining the agentic rag tool ====================#
+ WORKFLOW_APP = create_workflow().compile()
+
+ @tool
+ def agentic_rag(query: str):
+     """
+     Runs the RAG-based agent with conversation history for context-aware responses.
+
+     Args:
+         query (str): The current user query.
+
+     Returns:
+         Dict[str, Any]: The updated state with the generated response and conversation history.
+     """
+     # Initialize state with necessary parameters
+     inputs = {
+         "query": query, # Current user query
+         "expanded_query": "", # Expanded version of the query (filled in by the graph)
+         "context": [], # Retrieved documents (initially empty)
+         "response": "", # AI-generated response
+         "precision_score": 0.0, # Precision score of the response
+         "groundedness_score": 0.0, # Groundedness score of the response (commas were missing after this and the next entry)
+         "groundedness_loop_count": 0, # Counter for groundedness loops
+         "precision_loop_count": 0, # Counter for precision loops
+         "feedback": "", # Response feedback
+         "query_feedback": "", # Query feedback
+         "loop_max_iter": 3 # Maximum number of iterations for refinement loops
+     }
+
+     output = WORKFLOW_APP.invoke(inputs)
+
+     return output
+
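Because agentic_rag is a LangChain tool, it can also be invoked directly for testing (a sketch; the final answer lands in the returned state's "response" key):

    # Sketch: call the tool directly, outside of any agent.
    result = agentic_rag.invoke({"query": "What are common symptoms of vitamin B12 deficiency?"})
    print(result["response"])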
+ #================================ Guardrails ===========================#
+ llama_guard_client = Groq(api_key=groq_api_key) # Groq client used to call Llama Guard
+
+ # Function to filter user input with Llama Guard
+ def filter_input_with_llama_guard(user_input, model="llama-guard-3-8b"):
+     """
+     Filters user input using Llama Guard to ensure it is safe.
+
+     Parameters:
+     - user_input: The input provided by the user.
+     - model: The Llama Guard model to be used for filtering (default is "llama-guard-3-8b").
+
+     Returns:
+     - Llama Guard's verdict on the input (e.g. "safe", or "unsafe" plus a category), or None on error.
+     """
+     try:
+         # Create a request to Llama Guard to classify the user input
+         response = llama_guard_client.chat.completions.create(
+             messages=[{"role": "user", "content": user_input}],
+             model=model,
+         )
+         # Return the verdict
+         return response.choices[0].message.content.strip()
+     except Exception as e:
+         print(f"Error with Llama Guard: {e}")
+         return None
+
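A typical gate around the chatbot might look like this (a sketch; it assumes the verdict string starts with "safe" for acceptable inputs):

    # Sketch: only forward input to the agent when Llama Guard deems it safe.
    verdict = filter_input_with_llama_guard("What should I eat for dinner?")
    if verdict and verdict.lower().startswith("safe"):
        print("input accepted")
    else:
        print("input rejected by guardrail")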
+
557
+ #============================= Adding Memory to the agent using mem0 ===============================#
558
 
559
+ class NutritionBot:
560
+ def __init__(self):
561
+ """
562
+ Initialize the NutritionBot class, setting up memory, the LLM client, tools, and the agent executor.
563
+ """
564
 
565
+ # Initialize a memory client to store and retrieve customer interactions
566
+ self.memory = MemoryClient(api_key=MEM0_api_key) # userdata.get("mem0")) # Complete the code to define the memory client API key
567
 
568
+ # Initialize the OpenAI client using the provided credentials
569
+ self.client = ChatOpenAI(
570
+ model_name="gpt-4o", # Used gpt-4o to get improved results; Specify the model to use (e.g., GPT-4 optimized version)
571
+ api_key=config.get("API_KEY"), # API key for authentication
572
+ endpoint = config.get("OPENAI_API_BASE"),
573
+ temperature=0 # Controls randomness in responses; 0 ensures deterministic results
574
+ )
575
 
576
+ # Define tools available to the chatbot, such as web search
577
+ tools = [agentic_rag]
578
+
579
+ # Define the system prompt to set the behavior of the chatbot
580
+ system_prompt = """You are a caring and knowledgeable Medical Support Agent, specializing in nutrition disorder-related guidance. Your goal is to provide accurate, empathetic, and tailored nutritional recommendations while ensuring a seamless customer experience.
581
+ Guidelines for Interaction:
582
+ Maintain a polite, professional, and reassuring tone.
583
+ Show genuine empathy for customer concerns and health challenges.
584
+ Reference past interactions to provide personalized and consistent advice.
585
+ Engage with the customer by asking about their food preferences, dietary restrictions, and lifestyle before offering recommendations.
586
+ Ensure consistent and accurate information across conversations.
587
+ If any detail is unclear or missing, proactively ask for clarification.
588
+ Always use the agentic_rag tool to retrieve up-to-date and evidence-based nutrition insights.
589
+ Keep track of ongoing issues and follow-ups to ensure continuity in support.
590
+ Your primary goal is to help customers make informed nutrition decisions that align with their health conditions and personal preferences.
591
 
592
+ """
 
 
 
 
+        # Build the prompt template for the agent
+        prompt = ChatPromptTemplate.from_messages([
+            ("system", system_prompt),             # System instructions
+            ("human", "{input}"),                  # Placeholder for human input
+            ("placeholder", "{agent_scratchpad}")  # Placeholder for intermediate reasoning steps
+        ])

+        # Create an agent capable of calling the registered tools
+        agent = create_tool_calling_agent(self.client, tools, prompt)

+        # Wrap the agent in an executor to manage tool calls and execution flow
+        self.agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)

+    def store_customer_interaction(self, user_id: str, message: str, response: str, metadata: Dict = None):
+        """
+        Store a customer interaction in memory for future reference.
+
+        Args:
+            user_id (str): Unique identifier for the customer.
+            message (str): Customer's query or message.
+            response (str): Chatbot's response.
+            metadata (Dict, optional): Additional metadata for the interaction.
+        """
+        if metadata is None:
+            metadata = {}
+
+        # Add a timestamp to the metadata for tracking purposes
+        metadata["timestamp"] = datetime.now().isoformat()
+
+        # Format the conversation for storage
+        conversation = [
+            {"role": "user", "content": message},
+            {"role": "assistant", "content": response}
+        ]
+
+        # Store the interaction in the memory client
+        self.memory.add(
+            conversation,
+            user_id=user_id,
+            output_format="v1.1",
+            metadata=metadata
+        )
+
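+    # Illustrative usage (assumes a configured mem0 backend; "alice" is a
+    # made-up user id):
+    #     bot.store_customer_interaction(
+    #         user_id="alice",
+    #         message="I get dizzy when I skip meals.",
+    #         response="That can be a sign of low blood sugar...",
+    #         metadata={"type": "support_query"},
+    #     )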
+    def get_relevant_history(self, user_id: str, query: str) -> List[Dict]:
+        """
+        Retrieve past interactions relevant to the current query.
+
+        Args:
+            user_id (str): Unique identifier for the customer.
+            query (str): The customer's current query.
+
+        Returns:
+            List[Dict]: A list of relevant past interactions.
+        """
+        return self.memory.search(
+            query=query,      # Search for interactions related to the query
+            user_id=user_id,  # Restrict the search to this user
+            limit=5           # Assumed cap on retrieved interactions; tune as needed
+        )
+
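+    # Shape note: each search() hit is a dict whose "memory" key holds the
+    # stored text (this is the field handle_customer_query reads below),
+    # roughly: [{"memory": "User asked about iron-rich foods...", ...}]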
+    def handle_customer_query(self, user_id: str, query: str) -> str:
+        """
+        Process a customer's query and provide a response, taking into account past interactions.
+
+        Args:
+            user_id (str): Unique identifier for the customer.
+            query (str): Customer's query.
+
+        Returns:
+            str: Chatbot's response.
+        """
+        # Retrieve relevant past interactions for context
+        relevant_history = self.get_relevant_history(user_id, query)
+
+        # Build a context string from the relevant history. Each mem0 hit
+        # carries a single combined "memory" string, so it is included once.
+        context = "Previous relevant interactions:\n"
+        for memory in relevant_history:
+            context += f"{memory['memory']}\n"
+            context += "---\n"
+
+        # Print context for debugging purposes
+        print("Context: ", context)
+
+        # Prepare a prompt combining past context and the current query
+        prompt = f"""
+        Context:
+        {context}
+
+        Current customer query: {query}
+
+        Provide a helpful response that takes into account any relevant past interactions.
+        """
+
+        # Generate a response using the agent
+        response = self.agent_executor.invoke({"input": prompt})
+
+        # Store the current interaction for future reference
+        self.store_customer_interaction(
+            user_id=user_id,
+            message=query,
+            response=response["output"],
+            metadata={"type": "support_query"}
+        )
+
+        # Return the chatbot's response
+        return response["output"]
+
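+# Illustrative usage sketch (assumes MEM0_api_key and the LLM credentials are
+# configured; "alice" is a made-up user id):
+#     bot = NutritionBot()
+#     print(bot.handle_customer_query("alice", "Which foods are rich in iron?"))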
+#=====================User Interface using streamlit ===========================#
+def nutrition_disorder_streamlit():
+    """
+    A Streamlit-based UI for the Nutrition Disorder Specialist Agent.
+    """
+    st.title("Nutrition Disorder Specialist")
+    st.write("Ask me anything about nutrition disorders, symptoms, causes, treatments, and more.")
+    st.write("Type 'exit' to end the conversation.")
+
+    # Initialize session state for chat history and user_id if they don't exist
+    if 'chat_history' not in st.session_state:
+        st.session_state.chat_history = []
+    if 'user_id' not in st.session_state:
+        st.session_state.user_id = None
+
+    # Login form: shown only while the user is not logged in
+    if st.session_state.user_id is None:
+        with st.form("login_form", clear_on_submit=True):
+            user_id = st.text_input("Please enter your name to begin:")
+            submit_button = st.form_submit_button("Login")
+            if submit_button and user_id:
+                st.session_state.user_id = user_id
+                st.session_state.chat_history.append({
+                    "role": "assistant",
+                    "content": f"Welcome, {user_id}! How can I help you with nutrition disorders today?"
+                })
+                st.session_state.login_submitted = True  # Set flag to trigger rerun
+        if st.session_state.get("login_submitted", False):
+            st.session_state.pop("login_submitted")
+            st.rerun()
     else:
+        # Display chat history
+        for message in st.session_state.chat_history:
+            with st.chat_message(message["role"]):
+                st.write(message["content"])
+
+        # Chat input with custom placeholder text
+        user_query = st.chat_input("Type your question here (or 'exit' to end)...")
+        if user_query:
+            if user_query.lower() == "exit":
+                st.session_state.chat_history.append({"role": "user", "content": "exit"})
+                with st.chat_message("user"):
+                    st.write("exit")
+                goodbye_msg = "Goodbye! Feel free to return if you have more questions about nutrition disorders."
+                st.session_state.chat_history.append({"role": "assistant", "content": goodbye_msg})
+                with st.chat_message("assistant"):
+                    st.write(goodbye_msg)
+                st.session_state.user_id = None
+                st.rerun()
+                return
+
+            st.session_state.chat_history.append({"role": "user", "content": user_query})
+            with st.chat_message("user"):
+                st.write(user_query)
+
+            # Filter input using Llama Guard; a failed call (None) is treated as unsafe
+            filtered_result = filter_input_with_llama_guard(user_query)
+            filtered_result = (filtered_result or "unsafe").replace("\n", " ")  # Normalize the verdict
+
+            # Allow "safe" input, plus the S6 (specialized advice) and S7 (privacy)
+            # categories, which a nutrition-guidance bot routinely triggers
+            if filtered_result in ["safe", "unsafe S6", "unsafe S7"]:
+                try:
+                    if 'chatbot' not in st.session_state:
+                        st.session_state.chatbot = NutritionBot()  # Instantiate the chatbot once per session
+                    response = st.session_state.chatbot.handle_customer_query(st.session_state.user_id, user_query)
+                    with st.chat_message("assistant"):
+                        st.write(response)
+                    st.session_state.chat_history.append({"role": "assistant", "content": response})
+                except Exception as e:
+                    error_msg = f"Sorry, I encountered an error while processing your query. Please try again. Error: {str(e)}"
+                    with st.chat_message("assistant"):
+                        st.write(error_msg)
+                    st.session_state.chat_history.append({"role": "assistant", "content": error_msg})
+            else:
+                inappropriate_msg = "I apologize, but I cannot process that input as it may be inappropriate. Please try again."
+                with st.chat_message("assistant"):
+                    st.write(inappropriate_msg)
+                st.session_state.chat_history.append({"role": "assistant", "content": inappropriate_msg})

 if __name__ == "__main__":
+    nutrition_disorder_streamlit()
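+# To launch the UI locally, the standard Streamlit invocation applies:
+#     streamlit run app.py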