agenticRAGpilot / app.py
lukehilasak's picture
fix gradio
1fdd28b
# Import necessary libraries
import faiss
from sentence_transformers import SentenceTransformer
import numpy as np
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import re # For parsing tool calls
import gradio as gr # For the Gradio interface
# --- Configuration and Initialization ---
# File path for the knowledge base
KNOWLEDGE_FILE = 'knowledge_base.txt'

# Seed content for the dummy knowledge file. load_knowledge() splits the file on
# blank lines, so each FAQ question is followed by its answer after a SINGLE
# newline: question and answer then land in the same chunk, and retrieving the
# question also retrieves its answer.
# NOTE: an already-existing knowledge_base.txt is left untouched.
_DUMMY_KNOWLEDGE = (
    "FAQ: How to reset my router?\n"
    "To reset your router, locate the small reset button, usually on the back. Press and hold it for 10-15 seconds until the lights blink.\n\n"
    "FAQ: What is your return policy?\n"
    "You can return most items within 30 days of purchase if they are in original condition. Some electronics have a 15-day policy.\n\n"
    "Simulated Data: Order ID 12345 is in status 'Shipped'. Items: Laptop, Mouse.\n\n"
    "Simulated Data: Order ID 67890 is in status 'Processing'. Items: Keyboard.\n\n"
    "Simulated Data: Order ID 11223 is in status 'Delivered'. Items: Monitor, Webcam.\n\n"
    "Troubleshooting: If your device won't connect to Wi-Fi, try restarting both the device and the router."
)

# Create the dummy knowledge file if it doesn't exist, so the app runs in the
# Space without the file having been created manually first.
if not os.path.exists(KNOWLEDGE_FILE):
    print(f"Creating dummy knowledge file: {KNOWLEDGE_FILE}")
    with open(KNOWLEDGE_FILE, 'w', encoding='utf-8') as f:
        f.write(_DUMMY_KNOWLEDGE)
# --- RAG Component ---
# 1. Load knowledge and chunk (simple chunking)
def load_knowledge(filepath):
    """Read a UTF-8 text file and return its blank-line-separated chunks.

    Each chunk is stripped of surrounding whitespace, and empty chunks are
    dropped. Failures are reported with print() rather than raised: a missing
    file, or any other error while reading, yields an empty list.
    """
    print(f"Loading knowledge from {filepath}...")
    try:
        with open(filepath, 'r', encoding='utf-8') as handle:
            raw_text = handle.read()
        # A blank line ("\n\n") separates chunks; filter(None, ...) drops empties.
        stripped_pieces = (piece.strip() for piece in raw_text.split('\n\n'))
        chunks = list(filter(None, stripped_pieces))
        print(f"Loaded {len(chunks)} chunks.")
        return chunks
    except FileNotFoundError:
        print(f"Error: Knowledge file not found at {filepath}")
        return []
    except Exception as exc:
        print(f"Error loading knowledge file: {exc}")
        return []
# 2. Load Embedding Model
# A small, fast sentence-embedding model is sufficient for the pilot.
EMBEDDING_MODEL_NAME = 'all-MiniLM-L6-v2'
embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME)
print(f"Embedding model loaded: {EMBEDDING_MODEL_NAME}")
# 3. Create Embeddings and FAISS Index
def create_faiss_index(chunks, model):
    """Embed *chunks* with *model* and build a flat (exhaustive) L2 FAISS index.

    Args:
        chunks: List of text chunks to index.
        model: Embedding model exposing ``encode(list[str])``.

    Returns:
        ``(index, chunks)``: the populated ``faiss.IndexFlatL2`` and the same
        chunk list, so that index positions map back to chunk text. Returns
        ``(None, [])`` when there is nothing to index.
    """
    if not chunks:
        print("No chunks to index.")
        return None, []

    print("Creating embeddings...")
    vectors = np.array(model.encode(chunks)).astype('float32')
    # The index dimension is taken from the embedding width.
    index = faiss.IndexFlatL2(vectors.shape[1])
    index.add(vectors)
    print(f"FAISS index created with {index.ntotal} vectors.")
    return index, chunks
# 4. Search Function in FAISS
def search_faiss(query, index, chunks, model, k=3):
    """Return up to *k* chunks whose embeddings are nearest to *query*.

    Args:
        query: Free-text search string.
        index: FAISS index built by ``create_faiss_index``, or ``None``.
        chunks: Texts whose list positions match the vectors stored in *index*.
        model: Embedding model exposing ``encode(list[str])``.
        k: Maximum number of chunks to return.

    Returns:
        Matching chunks, in the order FAISS ranks them. The list is shorter
        than *k* when the index holds fewer vectors. When no index or chunks
        are available, a one-element list with a placeholder message is
        returned.
    """
    if index is None or not chunks:
        return ["Knowledge base not available."]
    # Never ask FAISS for more neighbours than it holds. Missing neighbours are
    # reported with label -1, which Python indexing would silently turn into
    # chunks[-1] (the last chunk).
    k = min(k, index.ntotal)
    if k <= 0:
        return []
    print(f"Searching FAISS for query: '{query}' (k={k})")
    query_embedding = model.encode([query])[0]
    _, indices = index.search(np.array([query_embedding]).astype('float32'), k)
    # Defensive: drop any label that does not address a real chunk.
    results = [chunks[i] for i in indices[0] if 0 <= i < len(chunks)]
    print(f"Found {len(results)} relevant chunks.")
    return results
# Build the knowledge chunks and their FAISS index once, when the module loads.
# `indexed_chunks` is the list whose positions line up with the index vectors.
knowledge_chunks = load_knowledge(KNOWLEDGE_FILE)
faiss_index, indexed_chunks = create_faiss_index(
    knowledge_chunks,
    embedding_model,
)
# --- LLM Component (Hugging Face) ---
# gpt2 is small and easy to run on CPU.
model_name = "gpt2"
# Run on the GPU when torch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

print(f"Loading LLM: {model_name}...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    # Some tokenizers (gpt2 among them) have no pad token; reuse EOS instead.
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
print(f"LLM loaded: {model_name} on {device}")
# --- Tool Definitions (Simulated) ---
# In-memory stand-in for an order database, keyed by order ID (as a string).
simulated_orders = {
    "12345": {
        "status": "Shipped",
        "items": ["Laptop", "Mouse"],
        "address": "123 Main St",
    },
    "67890": {
        "status": "Processing",
        "items": ["Keyboard"],
        "address": "456 Oak Ave",
    },
    "11223": {
        "status": "Delivered",
        "items": ["Monitor", "Webcam"],
        "address": "789 Pine Ln",
    },
}
def search_knowledge_base_tool(query: str) -> str:
    """
    Look up knowledge-base chunks relevant to *query*.
    Args:
        query (str): Free-text search query.
    Returns:
        str: "Retrieved Information:" followed by the matching chunks joined
        with '---' separator lines, or an explanatory message when the index
        is unavailable or nothing matched.
    """
    print(f"\n--- Calling Tool: search_knowledge_base with query='{query}' ---")
    knowledge_ready = faiss_index is not None and bool(indexed_chunks)
    if not knowledge_ready:
        return "Error: Knowledge base is not available or empty."

    hits = search_faiss(query, faiss_index, indexed_chunks, embedding_model)
    if not hits:
        return "No relevant information found in the knowledge base."

    joined_hits = "\n---\n".join(hits)
    formatted_results = f"Retrieved Information:\n{joined_hits}"
    print(f"--- Tool Output: {formatted_results} ---")
    return formatted_results
def get_order_status_tool(order_id: str) -> str:
    """
    Describe the status and items of the order with the given ID.
    Args:
        order_id (str): Order ID to look up in `simulated_orders`.
    Returns:
        str: "Order <id> status: <status>. Items: <a, b>." when the order
        exists, otherwise a not-found message. The result is also printed.
    """
    print(f"\n--- Calling Tool: get_order_status with order_id='{order_id}' ---")
    order_info = simulated_orders.get(order_id)
    if not order_info:
        message = f"Order {order_id} not found in our system."
    else:
        items = ", ".join(order_info['items'])
        message = f"Order {order_id} status: {order_info['status']}. Items: {items}."
    print(f"--- Tool Output: {message} ---")
    return message
def initiate_return_tool(order_id: str, item_name: str) -> str:
    """
    Start a (simulated) return for one item of an order.
    Args:
        order_id (str): Order ID to look up in `simulated_orders`.
        item_name (str): Item to return; matched case-insensitively against
            the order's items.
    Returns:
        str: A confirmation message, or an error message when the order or
        the item is not found. The result is also printed.
    """
    print(f"\n--- Calling Tool: initiate_return for order_id='{order_id}', item='{item_name}' ---")
    order_info = simulated_orders.get(order_id)
    if not order_info:
        message = f"Order {order_id} not found, cannot initiate return."
    elif any(item.lower() == item_name.lower() for item in order_info['items']):
        # A real system would call a returns API here; this only confirms.
        message = (
            f"Return initiated successfully for '{item_name}' in order {order_id}. "
            "Please check your email for instructions."
        )
    else:
        message = f"Item '{item_name}' not found in order {order_id}."
    print(f"--- Tool Output: {message} ---")
    return message
# Registry from the tool names the LLM may emit (as listed in SYSTEM_PROMPT)
# to the Python functions that implement them.
available_tools = dict(
    search_knowledge_base=search_knowledge_base_tool,
    get_order_status=get_order_status_tool,
    initiate_return=initiate_return_tool,
)
# --- Agentic Logic (Inspired by smolagents) ---
# System prompt placed at the start of every LLM call made by run_agent_step.
# It describes the three tools registered in `available_tools` and the
# `TOOL_USE: name(key=value, ...)` syntax that parse_tool_call recognises.
# It ends with "Conversation History:\n" because run_agent_step appends the
# "User: ..."/"Agent: ..." transcript directly after it.
# NOTE(review): the phrase "the simple parser below" refers to this source
# file, which the model never sees — consider rewording it for the model.
SYSTEM_PROMPT = """You are a helpful and friendly customer support agent.
Your goal is to assist the user with their queries regarding orders, products, or technical issues.
You have access to several tools to help you gather information or perform actions. Use them when necessary.
Available tools:
1. `search_knowledge_base(query: str)`: Use this tool to find information in the knowledge base about products, FAQs, or troubleshooting steps. Use it whenever you need information you don't have or to confirm details.
2. `get_order_status(order_id: str)`: Use this tool to check the current status and items of a customer's order. Requires an order ID.
3. `initiate_return(order_id: str, item_name: str)`: Use this tool to start a return process for a specific item in an order. Requires the order ID and the exact item name from the order.
To use a tool, respond *only* with the following format on a new line, followed by no other text:
TOOL_USE: tool_name(arg1=value1, arg2=value2, ...)
Ensure arguments are correctly formatted (e.g., strings in quotes if necessary, though the simple parser below expects simple key=value).
If you have enough information, have completed the user's request, or cannot use a tool, respond directly to the user in natural language.
Always be polite, empathetic, and try to resolve the user's issue efficiently.
If the user asks for something unrelated to your tools or knowledge base, politely state that you can only help with support queries.
Manage the conversation flow. Ask clarifying questions if needed (e.g., ask for an order ID if the user asks about an order).
After a tool is used and its output is provided, you will receive the output. Interpret the tool output and decide the next step: either use another tool, ask a clarifying question, or provide the final answer to the user.
Conversation History:
"""
def parse_tool_call(text: str) -> tuple[str | None, dict | None]:
"""
Parses a string to extract a tool name and arguments.
Expected format: TOOL_USE: tool_name(arg1=value1, arg2=value2, ...)
Returns: (tool_name, args_dict) or (None, None) if parsing fails.
"""
tool_call_prefix = "TOOL_USE:"
if text.strip().startswith(tool_call_prefix):
call_string = text.strip()[len(tool_call_prefix):].strip()
# Use regex for more robust parsing (still simplified)
match = re.match(r"(\w+)\((.*)\)", call_string)
if match:
tool_name = match.group(1)
args_string = match.group(2)
args = {}
# Simple parsing for key=value pairs
if args_string:
# Split by comma, but handle potential commas within values (not handled here for simplicity)
arg_pairs = args_string.split(',')
for pair in arg_pairs:
if '=' in pair:
key, value = pair.split('=', 1)
# Basic cleaning of value (remove leading/trailing whitespace and quotes)
key = key.strip()
value = value.strip().strip("'\"")
args[key] = value
return tool_name, args
else:
print(f"DEBUG: Failed regex parse on tool call string: '{call_string}'")
return None, None
return None, None
# Transcript prefixes this module writes into prompts. A base model such as
# GPT-2 tends to keep writing past its own turn and invent further lines with
# these prefixes, so generated text is cut at the first one.
_TURN_MARKERS = ("\nUser:", "\nAgent:", "\nAgent (", "\nTool Output:")


def _truncate_at_next_turn(text: str) -> str:
    """Return *text* up to the first invented transcript turn, stripped."""
    cut = len(text)
    for marker in _TURN_MARKERS:
        position = text.find(marker)
        if position != -1:
            cut = min(cut, position)
    return text[:cut].strip()


def _token_budget(llm_model, max_new_tokens: int) -> int | None:
    """Return how many prompt tokens fit next to *max_new_tokens*, or None if unknown."""
    config = getattr(llm_model, "config", None)
    context_limit = getattr(config, "max_position_embeddings", None)
    if not isinstance(context_limit, int) or context_limit <= max_new_tokens:
        return None
    return context_limit - max_new_tokens


def _build_prompt(prior_history: list[str], turn_lines: list[str], suffix: str,
                  llm_tokenizer, token_budget: int | None) -> str:
    """Assemble SYSTEM_PROMPT + transcript + *suffix* within *token_budget* tokens.

    The oldest lines of *prior_history* are dropped until the prompt fits; the
    lines of the current turn (*turn_lines*) are never dropped.
    """
    kept = list(prior_history)
    while True:
        prompt = SYSTEM_PROMPT + "\n" + "\n".join(kept + turn_lines) + suffix
        if token_budget is None or not kept:
            return prompt
        if len(llm_tokenizer(prompt).input_ids) <= token_budget:
            return prompt
        kept.pop(0)


def _generate(prompt: str, llm_model, llm_tokenizer, max_new_tokens: int,
              token_budget: int | None) -> str:
    """Generate a continuation of *prompt*, cut at the next invented turn."""
    inputs = llm_tokenizer(prompt, return_tensors="pt").to(llm_model.device)
    if token_budget is not None and inputs["input_ids"].shape[-1] > token_budget:
        # Last resort when the current turn alone is too long: keep only the
        # most recent tokens so prompt + output fits the model's positions.
        inputs = {name: tensor[:, -token_budget:] for name, tensor in inputs.items()}
    prompt_length = inputs["input_ids"].shape[-1]
    output_tokens = llm_model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
        pad_token_id=llm_tokenizer.eos_token_id,
    )
    # Decode only the newly generated tokens, not the echoed prompt.
    raw_text = llm_tokenizer.decode(output_tokens[0][prompt_length:], skip_special_tokens=True)
    return _truncate_at_next_turn(raw_text)


def run_agent_step(user_input: str, conversation_history: list[str], llm_model, llm_tokenizer, tools: dict) -> tuple[str, list[str]]:
    """Run one user turn of the agent loop.

    The LLM is prompted with SYSTEM_PROMPT, the prior transcript and the new
    user message. If its reply is a ``TOOL_USE: ...`` call naming a key of
    *tools*, that tool runs, and the LLM is prompted a second time with the
    tool's output (or error) so it can formulate the answer.

    Args:
        user_input: The user's new message.
        conversation_history: Prior transcript lines ("User: ...",
            "Agent: ..."). Not modified.
        llm_model: Hugging Face causal LM used for generation.
        llm_tokenizer: Tokenizer matching *llm_model*.
        tools: Mapping of tool name -> callable taking keyword string args.

    Returns:
        ``(response, history)``: the text to show the user, and the full
        transcript including this turn. When a tool ran, the turn also
        contains the tool-call line and the "Tool Output: ..." line.

    Prompts are kept within the model's context window
    (``config.max_position_embeddings``) by dropping the oldest prior
    transcript lines. Generated text is cut where the model starts inventing
    a new "User:"/"Agent:"/"Tool Output:" line.
    """
    max_tokens = 250  # Adjust based on desired verbosity
    token_budget = _token_budget(llm_model, max_tokens)
    turn = [f"User: {user_input}"]

    prompt = _build_prompt(conversation_history, turn, "\nAgent:", llm_tokenizer, token_budget)
    print(f"\n--- Sending Prompt to LLM ---\n{prompt}\n--- End Prompt ---")
    generated_text = _generate(prompt, llm_model, llm_tokenizer, max_tokens, token_budget)
    print(f"\n--- LLM Raw Output ---\n{generated_text}\n--- End LLM Output ---")

    tool_name, tool_args = parse_tool_call(generated_text)
    if not tool_name or tool_name not in tools:
        # Not a tool call, or tool not recognized: treat as a direct response.
        print("DEBUG: No valid tool call detected. Treating as direct response.")
        return generated_text, conversation_history + turn + [f"Agent: {generated_text}"]

    print(f"DEBUG: Detected tool call: {tool_name} with args {tool_args}")
    # Guard only the tool call itself. Failures become the observation shown
    # to the LLM; errors from the follow-up generation are not masked.
    try:
        tool_output = tools[tool_name](**tool_args)
    except TypeError as e:
        # Usually the arguments do not match the tool's signature.
        error_msg = f"Error executing tool '{tool_name}': Invalid arguments. {e}"
        print(f"DEBUG: {error_msg}")
        observation = f"Tool Output: Error - {error_msg}"
        suffix = "\nAgent (interpreting tool error):"
    except Exception as e:
        error_msg = f"An unexpected error occurred during tool execution: {e}"
        print(f"DEBUG: {error_msg}")
        observation = f"Tool Output: Error - {error_msg}"
        suffix = "\nAgent (interpreting tool error):"
    else:
        observation = f"Tool Output: {tool_output}"
        suffix = "\nAgent (interpreting tool output):"

    # Observation step: re-prompt with the tool call and its result so the
    # agent can interpret it and answer the user.
    turn += [f"Agent: {generated_text}", observation]
    print("\n--- Re-prompting LLM with Tool Output ---")
    second_prompt = _build_prompt(conversation_history, turn, suffix, llm_tokenizer, token_budget)
    print(f"--- Second Prompt ---\n{second_prompt}\n--- End Second Prompt ---")
    final_response = _generate(second_prompt, llm_model, llm_tokenizer, max_tokens, token_budget)
    return final_response, conversation_history + turn + [f"Agent: {final_response}"]
# --- Gradio Interface ---
def _is_message_entry(entry) -> bool:
    """Return True if *entry* is a Gradio "messages"-format item (role/content)."""
    return isinstance(entry, dict) or hasattr(entry, "role")


def _chat_history_to_transcript(chat_history) -> list[str]:
    """Convert Gradio chat history into the "User: ..."/"Agent: ..." lines
    that run_agent_step expects.

    Both Gradio history formats are accepted:
    * "messages": a list of {"role": ..., "content": ...} entries (dicts, or
      objects exposing ``role``/``content`` attributes);
    * "tuples": a list of [user_msg, agent_msg] pairs.
    Only plain-text messages are carried over. Non-string content (e.g.
    files) and roles other than user/assistant are skipped.
    """
    transcript = []
    for entry in chat_history:
        if _is_message_entry(entry):
            if isinstance(entry, dict):
                role, content = entry.get("role"), entry.get("content")
            else:
                role, content = entry.role, getattr(entry, "content", None)
            if not isinstance(content, str):
                continue
            if role == "user":
                transcript.append(f"User: {content}")
            elif role == "assistant":
                transcript.append(f"Agent: {content}")
        else:
            user_msg, agent_msg = entry
            transcript.append(f"User: {user_msg}")
            if agent_msg is not None:  # the agent may not have replied yet
                transcript.append(f"Agent: {agent_msg}")
    return transcript


def respond(message, chat_history):
    """Handle one chat submission from the Gradio UI.

    Args:
        message: Text the user typed.
        chat_history: Current Chatbot value. This is either "messages" format
            (list of role/content dicts) or "tuples" format (list of
            [user, agent] pairs); ``None`` is treated as an empty history.

    Returns:
        ``("", updated_history)``: an empty string that clears the textbox,
        and the history with this turn appended. The new turn uses the same
        format as the incoming history; an empty history gets "messages"
        format, matching the Chatbot's ``type='messages'``.

    Only the displayed messages of earlier turns are passed to the agent.
    Tool calls and tool outputs from previous turns are not retained.
    """
    history = list(chat_history or [])
    # Ignore empty submissions instead of running the LLM on nothing.
    if not message or not message.strip():
        return "", history

    transcript = _chat_history_to_transcript(history)
    agent_response, _ = run_agent_step(message, transcript, model, tokenizer, available_tools)

    if history and not _is_message_entry(history[0]):
        history.append((message, agent_response))
    else:
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": agent_response})
    return "", history
# Create the Gradio interface. Components are laid out in creation order:
# header, chat transcript, input box, Send button, Clear button.
with gr.Blocks() as demo:
    gr.Markdown(
        """
# Agentic RAG Pilot Demo
This is a small-scale pilot demonstrating an agentic system with RAG and tool use
for customer support queries, inspired by the smolagents framework.
Ask me about order status (try Order ID 12345 or 67890) or how to reset your router.
"""
    )
    # Chat transcript; type='messages' selects Gradio's role/content format.
    chatbot = gr.Chatbot(height=400, type='messages')
    msg = gr.Textbox(label="Your Message")
    send_button = gr.Button("Send")
    clear = gr.ClearButton([msg, chatbot])

    # Clicking Send and pressing Enter in the textbox both call `respond`
    # with (message, history). Its outputs clear the textbox and update the chat.
    for listener in (send_button.click, msg.submit):
        listener(respond, inputs=[msg, chatbot], outputs=[msg, chatbot])

if __name__ == "__main__":
    # Queueing is enabled via demo.queue() rather than an enable_queue parameter.
    demo.queue().launch()