jenngang commited on
Commit
5eb7f83
·
verified ·
1 Parent(s): 9cab537

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +46 -46
app.py CHANGED
@@ -86,8 +86,8 @@ llm = ChatOpenAI(
86
  # This initializes the Chat OpenAI model with the provided endpoint, API key, deployment name, and a temperature setting of 0 (to control response variability).
87
 
88
  # set the LLM and embedding model in the LlamaIndex settings.
89
- Settings.llm = llm # _____ # Complete the code to define the LLM model
90
- Settings.embedding = embedding_model # _____ # Complete the code to define the embedding model
91
 
92
  #================================Creating Langgraph agent======================#
93
 
@@ -167,7 +167,7 @@ def retrieve_context(state):
167
  Dict: The updated state with the retrieved context.
168
  """
169
  print("---------retrieve_context---------")
170
- query = state['query'] # state['_____'] # Complete the code to define the key for the expanded query
171
  #print("Query used for retrieval:", query) # Debugging: Print the query
172
 
173
  # Retrieve documents from the vector store
@@ -182,7 +182,7 @@ def retrieve_context(state):
182
  }
183
  for doc in docs
184
  ]
185
- state['context'] = context # state['_____'] = context # Complete the code to define the key for storing the context
186
  print("Extracted context with metadata:", context) # Debugging: Print the extracted context
187
  #print(f"Groundedness loop count: {state['groundedness_loop_count']}")
188
  return state
@@ -228,7 +228,7 @@ def craft_response(state: Dict) -> Dict:
228
  response = chain.invoke({
229
  "query": state['query'],
230
  "context": "\n".join([doc["content"] for doc in state['context']]),
231
- "feedback": state["feedback"] # ________________ # add feedback to the prompt
232
  })
233
  state['response'] = response
234
  print("intermediate response: ", response)
@@ -277,7 +277,7 @@ def score_groundedness(state: Dict) -> Dict:
277
  chain = groundedness_prompt | llm | StrOutputParser()
278
  groundedness_score = float(chain.invoke({
279
  "context": "\n".join([doc["content"] for doc in state['context']]),
280
- "response": state['response'] # __________ # Complete the code to define the response
281
  }))
282
  print("groundedness_score: ", groundedness_score)
283
  state['groundedness_loop_count'] += 1
@@ -311,10 +311,10 @@ def check_precision(state: Dict) -> Dict:
311
  ("user", "Query: {query}\nResponse: {response}\n\nPrecision score:")
312
  ])
313
 
314
- chain = precision_prompt | llm | StrOutputParser() # _____________ | llm | StrOutputParser() # Complete the code to define the chain of processing
315
  precision_score = float(chain.invoke({
316
  "query": state['query'],
317
- "response": state['response'] # ______________ # Complete the code to access the response from the state
318
  }))
319
  state['precision_score'] = precision_score
320
  print("precision_score:", precision_score)
@@ -397,7 +397,7 @@ def should_continue_groundedness(state):
397
  """Decides if groundedness is sufficient or needs improvement."""
398
  print("---------should_continue_groundedness---------")
399
  print("groundedness loop count: ", state['groundedness_loop_count'])
400
- if state['groundedness_score'] >= 0.8 # _____: # Complete the code to define the threshold for groundedness
401
  print("Moving to precision")
402
  return "check_precision"
403
  else:
@@ -412,14 +412,14 @@ def should_continue_precision(state: Dict) -> str:
412
  """Decides if precision is sufficient or needs improvement."""
413
  print("---------should_continue_precision---------")
414
  print("precision loop count: ", state["precision_loop_count"])
415
- if state['precision_score'] > 0.8: # ___________: # Threshold for precision
416
  return "pass" # Complete the workflow
417
  else:
418
- if state["precision_loop_count"] >= state['loop_max_iter']: # ___________: # Maximum allowed loops
419
  return "max_iterations_reached"
420
  else:
421
  print(f"---------Precision Score Threshold Not met. Refining Query-----------") # Debugging
422
- return "refine_query" # ____________ # Refine the query
423
 
424
 
425
 
@@ -438,17 +438,17 @@ from langgraph.graph import END, StateGraph, START
438
 
439
  def create_workflow() -> StateGraph:
440
  """Creates the updated workflow for the AI nutrition agent."""
441
- workflow = StateGraph(State) # StateGraph(_____ ) # Complete the code to define the initial state of the agent
442
 
443
  # Add processing nodes
444
- workflow.add_node("expand_query", expand_query) # _____ ) # Step 1: Expand user query. Complete with the function to expand the query
445
- workflow.add_node("retrieve_context", retrieve_context) #_____ ) # Step 2: Retrieve relevant documents. Complete with the function to retrieve context
446
- workflow.add_node("craft_response", craft_response) # _____ ) # Step 3: Generate a response based on retrieved data. Complete with the function to craft a response
447
- workflow.add_node("score_groundedness", score_groundedness) # _____ ) # Step 4: Evaluate response grounding. Complete with the function to score groundedness
448
- workflow.add_node("refine_response", refine_response) # _____ ) # Step 5: Improve response if it's weakly grounded. Complete with the function to refine the response
449
- workflow.add_node("check_precision", check_precision) # _____ ) # Step 6: Evaluate response precision. Complete with the function to check precision
450
- workflow.add_node("refine_query", refine_query) # _____ ) # Step 7: Improve query if response lacks precision. Complete with the function to refine the query
451
- workflow.add_node("max_iterations_reached", max_iterations_reached) # _____ ) # Step 8: Handle max iterations. Complete with the function to handle max iterations
452
 
453
  # Main flow edges
454
  workflow.add_edge(START, "expand_query")
@@ -459,28 +459,28 @@ def create_workflow() -> StateGraph:
459
  # Conditional edges based on groundedness check
460
  workflow.add_conditional_edges(
461
  "score_groundedness",
462
- should_continue_groundedness, # ___________, # Use the conditional function
463
  {
464
- "check_precision": "check_precision", # ___________, # If well-grounded, proceed to precision check.
465
- "refine_response": "refine_response", # ___________, # If not, refine the response.
466
- "max_iterations_reached": max_iterations_reached # ___________ # If max loops reached, exit.
467
  }
468
  )
469
 
470
- workflow.add_edge("refine_response", "craft_response") # __________, ___________) # Refined responses are reprocessed.
471
 
472
  # Conditional edges based on precision check
473
  workflow.add_conditional_edges(
474
  "check_precision",
475
- should_continue_precision, # ___________, # Use the conditional function
476
  {
477
- "pass": END, # ___________, # If precise, complete the workflow.
478
- "refine_query": "refine_query", # ___________, # If imprecise, refine the query.
479
- "max_iterations_reached": "max_iterations_reached" # ___________ # If max loops reached, exit.
480
  }
481
  )
482
 
483
- workflow.add_edge("refine_query", "expand_query") # __________, ___________) # Refined queries go through expansion again.
484
 
485
  workflow.add_edge("max_iterations_reached", END)
486
 
@@ -505,16 +505,16 @@ def agentic_rag(query: str):
505
  # Initialize state with necessary parameters
506
  inputs = {
507
  "query": query, # Current user query
508
- "expanded_query": "", # "_____", # Complete the code to define the expanded version of the query
509
  "context": [], # Retrieved documents (initially empty)
510
- "response": "", # "_____", # Complete the code to define the AI-generated response
511
- "precision_score": 0.0, # _____, # Complete the code to define the precision score of the response
512
- "groundedness_score": 0.0, # _____, # Complete the code to define the groundedness score of the response
513
- "groundedness_loop_count": 0, # _____, # Complete the code to define the counter for groundedness loops
514
- "precision_loop_count": 0, # _____, # Complete the code to define the counter for precision loops
515
- "feedback": "", # "_____", # Complete the code to define the feedback
516
- "query_feedback": "", # "_____", # Complete the code to define the query feedback
517
- "loop_max_iter": 3 # _____ # Complete the code to define the maximum number of iterations for loops
518
  }
519
 
520
  output = WORKFLOW_APP.invoke(inputs)
@@ -645,7 +645,7 @@ class NutritionBot:
645
  return self.memory.search(
646
  query=query, # Search for interactions related to the query
647
  user_id=user_id, # Restrict search to the specific user
648
- limit=_____ # Complete the code to define the limit for retrieved interactions
649
  )
650
 
651
 
@@ -736,7 +736,7 @@ def nutrition_disorder_streamlit():
736
  st.write(message["content"])
737
 
738
  # Chat input with custom placeholder text
739
- user_query = st.chat_input("Type your question here (or 'exit' to end)...") # __________) # Blank #1: Fill in the chat input prompt (e.g., "Type your question here (or 'exit' to end)...")
740
  if user_query:
741
  if user_query.lower() == "exit":
742
  st.session_state.chat_history.append({"role": "user", "content": "exit"})
@@ -755,15 +755,15 @@ def nutrition_disorder_streamlit():
755
  st.write(user_query)
756
 
757
  # Filter input using Llama Guard
758
- filtered_result = filter_input_with_llama_guard(user_query) # __________(user_query) # Blank #2: Fill in with the function name for filtering input (e.g., filter_input_with_llama_guard)
759
  filtered_result = filtered_result.replace("\n", " ") # Normalize the result
760
 
761
  # Check if input is safe based on allowed statuses
762
- if filtered_result in ["safe", "safe S7", "safe S6"]: # __________, __________, __________]: # Blanks #3, #4, #5: Fill in with allowed safe statuses (e.g., "safe", "unsafe S7", "unsafe S6")
763
  try:
764
  if 'chatbot' not in st.session_state:
765
- st.session_state.chatbot = NutritionBot() # __________() # Blank #6: Fill in with the chatbot class initialization (e.g., NutritionBot)
766
- response = st.session_state.chatbot.handle_customer_query(st.session_state.user_id, user_query) # __________(st.session_state.user_id, user_query)
767
  # Blank #7: Fill in with the method to handle queries (e.g., handle_customer_query)
768
  st.write(response)
769
  st.session_state.chat_history.append({"role": "assistant", "content": response})
 
86
  # This initializes the Chat OpenAI model with the provided endpoint, API key, deployment name, and a temperature setting of 0 (to control response variability).
87
 
88
  # set the LLM and embedding model in the LlamaIndex settings.
89
+ Settings.llm = llm # Complete the code to define the LLM model
90
+ Settings.embedding = embedding_model # Complete the code to define the embedding model
91
 
92
  #================================Creating Langgraph agent======================#
93
 
 
167
  Dict: The updated state with the retrieved context.
168
  """
169
  print("---------retrieve_context---------")
170
+ query = state['query'] # Complete the code to define the key for the expanded query
171
  #print("Query used for retrieval:", query) # Debugging: Print the query
172
 
173
  # Retrieve documents from the vector store
 
182
  }
183
  for doc in docs
184
  ]
185
+ state['context'] = context # Complete the code to define the key for storing the context
186
  print("Extracted context with metadata:", context) # Debugging: Print the extracted context
187
  #print(f"Groundedness loop count: {state['groundedness_loop_count']}")
188
  return state
 
228
  response = chain.invoke({
229
  "query": state['query'],
230
  "context": "\n".join([doc["content"] for doc in state['context']]),
231
+ "feedback": state["feedback"] # add feedback to the prompt
232
  })
233
  state['response'] = response
234
  print("intermediate response: ", response)
 
277
  chain = groundedness_prompt | llm | StrOutputParser()
278
  groundedness_score = float(chain.invoke({
279
  "context": "\n".join([doc["content"] for doc in state['context']]),
280
+ "response": state['response'] # Complete the code to define the response
281
  }))
282
  print("groundedness_score: ", groundedness_score)
283
  state['groundedness_loop_count'] += 1
 
311
  ("user", "Query: {query}\nResponse: {response}\n\nPrecision score:")
312
  ])
313
 
314
+ chain = precision_prompt | llm | StrOutputParser() # Complete the code to define the chain of processing
315
  precision_score = float(chain.invoke({
316
  "query": state['query'],
317
+ "response": state['response'] # Complete the code to access the response from the state
318
  }))
319
  state['precision_score'] = precision_score
320
  print("precision_score:", precision_score)
 
397
  """Decides if groundedness is sufficient or needs improvement."""
398
  print("---------should_continue_groundedness---------")
399
  print("groundedness loop count: ", state['groundedness_loop_count'])
400
+ if state['groundedness_score'] >= 0.8 # Complete the code to define the threshold for groundedness
401
  print("Moving to precision")
402
  return "check_precision"
403
  else:
 
412
  """Decides if precision is sufficient or needs improvement."""
413
  print("---------should_continue_precision---------")
414
  print("precision loop count: ", state["precision_loop_count"])
415
+ if state['precision_score'] > 0.8: # Threshold for precision
416
  return "pass" # Complete the workflow
417
  else:
418
+ if state["precision_loop_count"] >= state['loop_max_iter']: # Maximum allowed loops
419
  return "max_iterations_reached"
420
  else:
421
  print(f"---------Precision Score Threshold Not met. Refining Query-----------") # Debugging
422
+ return "refine_query" # Refine the query
423
 
424
 
425
 
 
438
 
439
  def create_workflow() -> StateGraph:
440
  """Creates the updated workflow for the AI nutrition agent."""
441
+ workflow = StateGraph(START) # Complete the code to define the initial state of the agent
442
 
443
  # Add processing nodes
444
+ workflow.add_node("expand_query", expand_query) # Step 1: Expand user query. Complete with the function to expand the query
445
+ workflow.add_node("retrieve_context", retrieve_context) # Step 2: Retrieve relevant documents. Complete with the function to retrieve context
446
+ workflow.add_node("craft_response", craft_response) # Step 3: Generate a response based on retrieved data. Complete with the function to craft a response
447
+ workflow.add_node("score_groundedness", score_groundedness) # Step 4: Evaluate response grounding. Complete with the function to score groundedness
448
+ workflow.add_node("refine_response", refine_response) # Step 5: Improve response if it's weakly grounded. Complete with the function to refine the response
449
+ workflow.add_node("check_precision", check_precision) # Step 6: Evaluate response precision. Complete with the function to check precision
450
+ workflow.add_node("refine_query", refine_query) # Step 7: Improve query if response lacks precision. Complete with the function to refine the query
451
+ workflow.add_node("max_iterations_reached", max_iterations_reached) # Step 8: Handle max iterations. Complete with the function to handle max iterations
452
 
453
  # Main flow edges
454
  workflow.add_edge(START, "expand_query")
 
459
  # Conditional edges based on groundedness check
460
  workflow.add_conditional_edges(
461
  "score_groundedness",
462
+ should_continue_groundedness, # Use the conditional function
463
  {
464
+ "check_precision": "check_precision", # If well-grounded, proceed to precision check.
465
+ "refine_response": "refine_response", # If not, refine the response.
466
+ "max_iterations_reached": max_iterations_reached # If max loops reached, exit.
467
  }
468
  )
469
 
470
+ workflow.add_edge("refine_response", "craft_response") # Refined responses are reprocessed.
471
 
472
  # Conditional edges based on precision check
473
  workflow.add_conditional_edges(
474
  "check_precision",
475
+ should_continue_precision, # Use the conditional function
476
  {
477
+ "pass": END, # If precise, complete the workflow.
478
+ "refine_query": "refine_query", # If imprecise, refine the query.
479
+ "max_iterations_reached": "max_iterations_reached" # If max loops reached, exit.
480
  }
481
  )
482
 
483
+ workflow.add_edge("refine_query", "expand_query") # Refined queries go through expansion again.
484
 
485
  workflow.add_edge("max_iterations_reached", END)
486
 
 
505
  # Initialize state with necessary parameters
506
  inputs = {
507
  "query": query, # Current user query
508
+ "expanded_query": "", # Complete the code to define the expanded version of the query
509
  "context": [], # Retrieved documents (initially empty)
510
+ "response": "", # Complete the code to define the AI-generated response
511
+ "precision_score": 0.0, # Complete the code to define the precision score of the response
512
+ "groundedness_score": 0.0, # Complete the code to define the groundedness score of the response
513
+ "groundedness_loop_count": 0, # Complete the code to define the counter for groundedness loops
514
+ "precision_loop_count": 0, # Complete the code to define the counter for precision loops
515
+ "feedback": "", # Complete the code to define the feedback
516
+ "query_feedback": "", # Complete the code to define the query feedback
517
+ "loop_max_iter": 3 # Complete the code to define the maximum number of iterations for loops
518
  }
519
 
520
  output = WORKFLOW_APP.invoke(inputs)
 
645
  return self.memory.search(
646
  query=query, # Search for interactions related to the query
647
  user_id=user_id, # Restrict search to the specific user
648
+ limit=5 # Complete the code to define the limit for retrieved interactions
649
  )
650
 
651
 
 
736
  st.write(message["content"])
737
 
738
  # Chat input with custom placeholder text
739
+ user_query = st.chat_input("Type your question here (or 'exit' to end)...") # Blank #1: Fill in the chat input prompt (e.g., "Type your question here (or 'exit' to end)...")
740
  if user_query:
741
  if user_query.lower() == "exit":
742
  st.session_state.chat_history.append({"role": "user", "content": "exit"})
 
755
  st.write(user_query)
756
 
757
  # Filter input using Llama Guard
758
+ filtered_result = filter_input_with_llama_guard(user_query) # Blank #2: Fill in with the function name for filtering input (e.g., filter_input_with_llama_guard)
759
  filtered_result = filtered_result.replace("\n", " ") # Normalize the result
760
 
761
  # Check if input is safe based on allowed statuses
762
+ if filtered_result in ["safe", "safe S7", "safe S6"]: # Blanks #3, #4, #5: Fill in with allowed safe statuses (e.g., "safe", "unsafe S7", "unsafe S6")
763
  try:
764
  if 'chatbot' not in st.session_state:
765
+ st.session_state.chatbot = NutritionBot() # Blank #6: Fill in with the chatbot class initialization (e.g., NutritionBot)
766
+ response = st.session_state.chatbot.handle_customer_query(st.session_state.user_id, user_query)
767
  # Blank #7: Fill in with the method to handle queries (e.g., handle_customer_query)
768
  st.write(response)
769
  st.session_state.chat_history.append({"role": "assistant", "content": response})