Spaces
Upload app.py with huggingface_hub
app.py CHANGED
@@ -86,8 +86,8 @@ llm = ChatOpenAI(
 # This initializes the Chat OpenAI model with the provided endpoint, API key, deployment name, and a temperature setting of 0 (to control response variability).
 
 # Set the LLM and embedding model in the LlamaIndex settings.
-Settings.llm = llm #
-Settings.embedding = embedding_model #
+Settings.llm = llm  # LLM used by LlamaIndex
+Settings.embedding = embedding_model  # Embedding model used by LlamaIndex
 
 #================================Creating Langgraph agent======================#
 
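Note: LlamaIndex reads the embedding model from Settings.embed_model; Settings.embedding is not a recognized attribute, so this assignment appears to leave the default embedding model in effect. A minimal sketch of the usual spelling, assuming llm and embedding_model are the objects constructed above:

    from llama_index.core import Settings

    Settings.llm = llm                      # chat model configured above
    Settings.embed_model = embedding_model  # the attribute LlamaIndex actually reads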
@@ -167,7 +167,7 @@ def retrieve_context(state):
         Dict: The updated state with the retrieved context.
     """
     print("---------retrieve_context---------")
-    query = state['query'] #
+    query = state['query']  # Query used for retrieval
     #print("Query used for retrieval:", query)  # Debugging: Print the query
 
     # Retrieve documents from the vector store
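Note: retrieve_context reads state['query'] even though the graph runs expand_query first, so retrieval presumably should use the expanded query. A hedged one-line variant (expanded_query is already a key in the workflow state initialized further down):

    query = state.get('expanded_query') or state['query']  # fall back to the raw query if expansion is empty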
@@ -182,7 +182,7 @@ def retrieve_context(state):
         }
         for doc in docs
     ]
-    state['context'] = context #
+    state['context'] = context  # Store the retrieved context in the state
     print("Extracted context with metadata:", context)  # Debugging: Print the extracted context
     #print(f"Groundedness loop count: {state['groundedness_loop_count']}")
     return state
@@ -228,7 +228,7 @@ def craft_response(state: Dict) -> Dict:
     response = chain.invoke({
         "query": state['query'],
         "context": "\n".join([doc["content"] for doc in state['context']]),
-        "feedback": state["feedback"] #
+        "feedback": state["feedback"]  # Add feedback to the prompt
     })
     state['response'] = response
     print("intermediate response: ", response)
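Note: passing "feedback" to chain.invoke only has an effect if the prompt template declares a matching {feedback} placeholder. A hypothetical template illustrating the shape (the real one is defined earlier in app.py):

    from langchain_core.output_parsers import StrOutputParser
    from langchain_core.prompts import ChatPromptTemplate

    response_prompt = ChatPromptTemplate.from_messages([
        ("system", "Answer strictly from the provided context. If feedback is given, revise accordingly."),
        ("user", "Query: {query}\nContext: {context}\nFeedback: {feedback}"),
    ])
    chain = response_prompt | llm | StrOutputParser()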
@@ -277,7 +277,7 @@ def score_groundedness(state: Dict) -> Dict:
     chain = groundedness_prompt | llm | StrOutputParser()
     groundedness_score = float(chain.invoke({
         "context": "\n".join([doc["content"] for doc in state['context']]),
-        "response": state['response'] #
+        "response": state['response']  # Response being evaluated
     }))
     print("groundedness_score: ", groundedness_score)
     state['groundedness_loop_count'] += 1
@@ -311,10 +311,10 @@ def check_precision(state: Dict) -> Dict:
         ("user", "Query: {query}\nResponse: {response}\n\nPrecision score:")
     ])
 
-    chain = precision_prompt | llm | StrOutputParser() #
+    chain = precision_prompt | llm | StrOutputParser()  # Prompt -> LLM -> string output
     precision_score = float(chain.invoke({
         "query": state['query'],
-        "response": state['response'] #
+        "response": state['response']  # Response being evaluated
     }))
     state['precision_score'] = precision_score
     print("precision_score:", precision_score)
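Note: both scorers call float() directly on the model's reply, which raises ValueError if the LLM returns anything beyond a bare number (e.g. "Score: 0.9"). A defensive sketch; parse_score is a hypothetical helper, not part of app.py:

    import re

    def parse_score(text: str) -> float:
        """Pull the first number out of an LLM reply; return 0.0 if none is found."""
        match = re.search(r"\d+(?:\.\d+)?", text)
        return float(match.group()) if match else 0.0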
@@ -397,7 +397,7 @@ def should_continue_groundedness(state):
     """Decides if groundedness is sufficient or needs improvement."""
     print("---------should_continue_groundedness---------")
     print("groundedness loop count: ", state['groundedness_loop_count'])
-    if state['groundedness_score'] >= 0.8 #
+    if state['groundedness_score'] >= 0.8:  # Threshold for groundedness
         print("Moving to precision")
         return "check_precision"
     else:
@@ -412,14 +412,14 @@ def should_continue_precision(state: Dict) -> str:
     """Decides if precision is sufficient or needs improvement."""
     print("---------should_continue_precision---------")
     print("precision loop count: ", state["precision_loop_count"])
-    if state['precision_score'] > 0.8: #
+    if state['precision_score'] > 0.8:  # Threshold for precision
         return "pass"  # Complete the workflow
     else:
-        if state["precision_loop_count"] >= state['loop_max_iter']: #
+        if state["precision_loop_count"] >= state['loop_max_iter']:  # Maximum allowed loops
             return "max_iterations_reached"
         else:
             print(f"---------Precision Score Threshold Not met. Refining Query-----------")  # Debugging
-            return "refine_query" #
+            return "refine_query"  # Refine the query
 
 
@@ -438,17 +438,17 @@ from langgraph.graph import END, StateGraph, START
 
 def create_workflow() -> StateGraph:
     """Creates the updated workflow for the AI nutrition agent."""
-    workflow = StateGraph(
+    workflow = StateGraph(START)  # Initialize the graph
 
     # Add processing nodes
-    workflow.add_node("expand_query", expand_query) #
-    workflow.add_node("retrieve_context", retrieve_context) #
-    workflow.add_node("craft_response", craft_response) #
-    workflow.add_node("score_groundedness", score_groundedness) #
-    workflow.add_node("refine_response", refine_response) #
-    workflow.add_node("check_precision", check_precision) #
-    workflow.add_node("refine_query", refine_query) #
-    workflow.add_node("max_iterations_reached", max_iterations_reached) #
+    workflow.add_node("expand_query", expand_query)  # Step 1: Expand the user query
+    workflow.add_node("retrieve_context", retrieve_context)  # Step 2: Retrieve relevant documents
+    workflow.add_node("craft_response", craft_response)  # Step 3: Generate a response from the retrieved data
+    workflow.add_node("score_groundedness", score_groundedness)  # Step 4: Evaluate how grounded the response is
+    workflow.add_node("refine_response", refine_response)  # Step 5: Improve a weakly grounded response
+    workflow.add_node("check_precision", check_precision)  # Step 6: Evaluate response precision
+    workflow.add_node("refine_query", refine_query)  # Step 7: Improve the query if the response lacks precision
+    workflow.add_node("max_iterations_reached", max_iterations_reached)  # Step 8: Handle hitting the iteration cap
 
     # Main flow edges
     workflow.add_edge(START, "expand_query")
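Note: StateGraph expects a state schema as its first argument, not the START sentinel, so StateGraph(START) will not build a usable graph. A sketch under the assumption that the state keys match the inputs dict initialized later in the file; AgentState is a hypothetical name, use whatever schema app.py actually defines:

    from typing import List, TypedDict
    from langgraph.graph import StateGraph

    class AgentState(TypedDict):  # hypothetical schema name
        query: str
        expanded_query: str
        context: List[dict]
        response: str
        precision_score: float
        groundedness_score: float
        groundedness_loop_count: int
        precision_loop_count: int
        feedback: str
        query_feedback: str
        loop_max_iter: int

    workflow = StateGraph(AgentState)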
@@ -459,28 +459,28 @@ def create_workflow() -> StateGraph:
     # Conditional edges based on groundedness check
     workflow.add_conditional_edges(
         "score_groundedness",
-        should_continue_groundedness, #
+        should_continue_groundedness,  # Use the conditional function
         {
-            "check_precision": "check_precision", #
-            "refine_response": "refine_response", #
-            "max_iterations_reached": max_iterations_reached #
+            "check_precision": "check_precision",  # If well-grounded, proceed to the precision check.
+            "refine_response": "refine_response",  # If not, refine the response.
+            "max_iterations_reached": max_iterations_reached  # If max loops reached, exit.
         }
     )
 
-    workflow.add_edge("refine_response", "craft_response") #
+    workflow.add_edge("refine_response", "craft_response")  # Refined responses are reprocessed.
 
     # Conditional edges based on precision check
     workflow.add_conditional_edges(
         "check_precision",
-        should_continue_precision, #
+        should_continue_precision,  # Use the conditional function
         {
-            "pass": END, #
-            "refine_query": "refine_query", #
-            "max_iterations_reached": "max_iterations_reached" #
+            "pass": END,  # If precise, complete the workflow.
+            "refine_query": "refine_query",  # If imprecise, refine the query.
+            "max_iterations_reached": "max_iterations_reached"  # If max loops reached, exit.
         }
     )
 
-    workflow.add_edge("refine_query", "expand_query") #
+    workflow.add_edge("refine_query", "expand_query")  # Refined queries go through expansion again.
 
     workflow.add_edge("max_iterations_reached", END)
 
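Note: in the groundedness branch, "max_iterations_reached" maps to the function object, while conditional-edge targets should be node names (or END), as the precision branch does it. The intended entry is presumably:

    "max_iterations_reached": "max_iterations_reached"  # route by node name, matching the precision branch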
@@ -505,16 +505,16 @@ def agentic_rag(query: str):
     # Initialize state with necessary parameters
     inputs = {
         "query": query,  # Current user query
-        "expanded_query": "", #
+        "expanded_query": "",  # Expanded version of the query (filled in by expand_query)
         "context": [],  # Retrieved documents (initially empty)
-        "response": "", #
-        "precision_score": 0.0, #
-        "groundedness_score": 0.0, #
-        "groundedness_loop_count": 0, #
-        "precision_loop_count": 0, #
-        "feedback": "", #
-        "query_feedback": "", #
-        "loop_max_iter": 3 #
+        "response": "",  # AI-generated response (initially empty)
+        "precision_score": 0.0,  # Precision score of the response
+        "groundedness_score": 0.0,  # Groundedness score of the response
+        "groundedness_loop_count": 0,  # Counter for groundedness refinement loops
+        "precision_loop_count": 0,  # Counter for precision refinement loops
+        "feedback": "",  # Feedback on the response
+        "query_feedback": "",  # Feedback on the query
+        "loop_max_iter": 3  # Maximum number of refinement iterations
     }
 
     output = WORKFLOW_APP.invoke(inputs)
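Note: WORKFLOW_APP is invoked here but defined elsewhere in app.py; presumably it is the compiled graph, along the lines of:

    WORKFLOW_APP = create_workflow().compile()  # assumption: compiled once at module load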
@@ -645,7 +645,7 @@ class NutritionBot:
         return self.memory.search(
             query=query,  # Search for interactions related to the query
             user_id=user_id,  # Restrict search to the specific user
-            limit=
+            limit=5  # Cap the number of retrieved interactions
         )
 
 
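Note: depending on the mem0 release, Memory.search returns either a plain list of hits or a dict wrapping them under "results". A hedged normalization sketch:

    hits = self.memory.search(query=query, user_id=user_id, limit=5)
    if isinstance(hits, dict):  # some mem0 releases return {"results": [...]}
        hits = hits.get("results", [])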
@@ -736,7 +736,7 @@ def nutrition_disorder_streamlit():
             st.write(message["content"])
 
     # Chat input with custom placeholder text
-    user_query = st.chat_input("Type your question here (or 'exit' to end)...") #
+    user_query = st.chat_input("Type your question here (or 'exit' to end)...")
     if user_query:
         if user_query.lower() == "exit":
             st.session_state.chat_history.append({"role": "user", "content": "exit"})
@@ -755,15 +755,15 @@ def nutrition_disorder_streamlit():
             st.write(user_query)
 
         # Filter input using Llama Guard
-        filtered_result = filter_input_with_llama_guard(user_query) #
+        filtered_result = filter_input_with_llama_guard(user_query)  # Screen the query before answering
         filtered_result = filtered_result.replace("\n", " ")  # Normalize the result
 
         # Check if input is safe based on allowed statuses
-        if filtered_result in ["safe", "safe S7", "safe S6"]: #
+        if filtered_result in ["safe", "safe S7", "safe S6"]:  # Allowed safety statuses
             try:
                 if 'chatbot' not in st.session_state:
-                    st.session_state.chatbot = NutritionBot() #
-                    response = st.session_state.chatbot.handle_customer_query(st.session_state.user_id, user_query)
+                    st.session_state.chatbot = NutritionBot()  # Initialize the chatbot once per session
+                response = st.session_state.chatbot.handle_customer_query(st.session_state.user_id, user_query)
                 # Blank #7: Fill in with the method to handle queries (e.g., handle_customer_query)
                 st.write(response)
                 st.session_state.chat_history.append({"role": "assistant", "content": response})
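Note: Llama Guard replies "safe", or "unsafe" followed by a category code on the next line, so after the replace("\n", " ") normalization an unsafe reply reads like "unsafe S6". The literals "safe S7" and "safe S6" can therefore never occur, and those branches are dead. If categories S6/S7 are meant to be allowed through for this domain, the check presumably wants the "unsafe" spellings:

    allowed_statuses = {"safe", "unsafe S6", "unsafe S7"}  # assumption: S6/S7 deemed acceptable here
    if filtered_result in allowed_statuses:
        ...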