# Import necessary libraries
import faiss
from sentence_transformers import SentenceTransformer
import numpy as np
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import re  # For parsing tool calls
import gradio as gr  # For the Gradio interface

# --- Configuration and Initialization ---

# File path for the knowledge base
KNOWLEDGE_FILE = 'knowledge_base.txt'

# Create a dummy knowledge file if it doesn't exist, for demonstration purposes.
# This ensures the app runs even without manually creating the file first in the Space.
if not os.path.exists(KNOWLEDGE_FILE):
    print(f"Creating dummy knowledge file: {KNOWLEDGE_FILE}")
    with open(KNOWLEDGE_FILE, 'w', encoding='utf-8') as f:
        f.write("FAQ: How to reset my router?\n\n")
        f.write("To reset your router, locate the small reset button, usually on the back. Press and hold it for 10-15 seconds until the lights blink.\n\n")
        f.write("FAQ: What is your return policy?\n\n")
        f.write("You can return most items within 30 days of purchase if they are in original condition. Some electronics have a 15-day policy.\n\n")
        f.write("Simulated Data: Order ID 12345 is in status 'Shipped'. Items: Laptop, Mouse.\n\n")
        f.write("Simulated Data: Order ID 67890 is in status 'Processing'. Items: Keyboard.\n\n")
        f.write("Simulated Data: Order ID 11223 is in status 'Delivered'. Items: Monitor, Webcam.\n\n")
        f.write("Troubleshooting: If your device won't connect to Wi-Fi, try restarting both the device and the router.")

# --- RAG Component ---

# 1. Load knowledge and split it into chunks (simple chunking)
def load_knowledge(filepath):
    """Loads knowledge from a text file and splits it into chunks."""
    print(f"Loading knowledge from {filepath}...")
    try:
        with open(filepath, 'r', encoding='utf-8') as f:
            content = f.read()
        # Simple chunking by splitting on double newlines
        chunks = [chunk.strip() for chunk in content.split('\n\n') if chunk.strip()]
        print(f"Loaded {len(chunks)} chunks.")
        return chunks
    except FileNotFoundError:
        print(f"Error: Knowledge file not found at {filepath}")
        return []
    except Exception as e:
        print(f"Error loading knowledge file: {e}")
        return []


# 2. Load the embedding model
# Using a small, fast model for the pilot
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
print("Embedding model loaded: all-MiniLM-L6-v2")

# 3. Create embeddings and the FAISS index
def create_faiss_index(chunks, model):
    """Creates a FAISS index from text chunks."""
    if not chunks:
        print("No chunks to index.")
        return None, []
    print("Creating embeddings...")
    # Encode chunks to get embeddings
    embeddings = model.encode(chunks)
    dimension = embeddings.shape[1]  # Dimension of the embeddings
    # Create a FAISS index (IndexFlatL2 is simple and good enough for a pilot)
    index = faiss.IndexFlatL2(dimension)
    # Add embeddings to the index
    index.add(np.array(embeddings).astype('float32'))
    print(f"FAISS index created with {index.ntotal} vectors.")
    return index, chunks  # Return the index and the original chunks


# 4. Search function over the FAISS index
def search_faiss(query, index, chunks, model, k=3):
    """Searches the FAISS index for the chunks most similar to the query."""
    if index is None or not chunks:
        return ["Knowledge base not available."]
    print(f"Searching FAISS for query: '{query}' (k={k})")
    # Encode the query
    query_embedding = model.encode([query])[0]
    # Search the index
    distances, indices = index.search(np.array([query_embedding]).astype('float32'), k)
    # Retrieve the actual text chunks based on the returned indices
    results = [chunks[i] for i in indices[0]]
    print(f"Found {len(results)} relevant chunks.")
    return results


# Prepare the knowledge base and FAISS index on startup
knowledge_chunks = load_knowledge(KNOWLEDGE_FILE)
faiss_index, indexed_chunks = create_faiss_index(knowledge_chunks, embedding_model)
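# Illustrative retrieval (not executed here): with the dummy knowledge file above,
#   search_faiss("How do I reset my router?", faiss_index, indexed_chunks, embedding_model, k=2)
# would be expected to surface the router-reset FAQ chunk and the Wi-Fi troubleshooting chunk,
# since those are the closest matches in embedding space (actual ranking may vary).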

# --- LLM Component (Hugging Face) ---

# Choose an LLM (gpt2 is small and easy to run on CPU)
model_name = "gpt2"

# Load the tokenizer and model
print(f"Loading LLM: {model_name}...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Some models like gpt2 don't have a pad token; use the eos token instead
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name)

# Move the model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(f"LLM loaded: {model_name} on {device}")


# --- Tool Definitions (Simulated) ---

# Simulate a small order database
simulated_orders = {
    "12345": {"status": "Shipped", "items": ["Laptop", "Mouse"], "address": "123 Main St"},
    "67890": {"status": "Processing", "items": ["Keyboard"], "address": "456 Oak Ave"},
    "11223": {"status": "Delivered", "items": ["Monitor", "Webcam"], "address": "789 Pine Ln"},
}

def search_knowledge_base_tool(query: str) -> str:
    """
    Searches the knowledge base for relevant information based on the query.

    Args:
        query (str): The search query.

    Returns:
        str: A string containing the retrieved information chunks, separated by '---'.
             Returns a message if no information is found or if RAG is not available.
    """
    print(f"\n--- Calling Tool: search_knowledge_base with query='{query}' ---")
    if faiss_index is None or not indexed_chunks:
        return "Error: Knowledge base is not available or empty."
    results = search_faiss(query, faiss_index, indexed_chunks, embedding_model)
    if not results:
        return "No relevant information found in the knowledge base."
    formatted_results = "Retrieved Information:\n" + "\n---\n".join(results)
    print(f"--- Tool Output: {formatted_results} ---")
    return formatted_results


def get_order_status_tool(order_id: str) -> str:
    """
    Gets the current status and items of an order given its ID.

    Args:
        order_id (str): The ID of the order to look up.

    Returns:
        str: A string describing the order status and items, or an error message if not found.
    """
    print(f"\n--- Calling Tool: get_order_status with order_id='{order_id}' ---")
    order_info = simulated_orders.get(order_id)
    if order_info:
        items_list = ", ".join(order_info['items'])
        status_message = f"Order {order_id} status: {order_info['status']}. Items: {items_list}."
        print(f"--- Tool Output: {status_message} ---")
        return status_message
    else:
        error_message = f"Order {order_id} not found in our system."
        print(f"--- Tool Output: {error_message} ---")
        return error_message


def initiate_return_tool(order_id: str, item_name: str) -> str:
    """
    Initiates a return process for a specific item in an order.

    Args:
        order_id (str): The ID of the order.
        item_name (str): The name of the item to return.

    Returns:
        str: A confirmation message, or an error message if the order/item is not found.
    """
    print(f"\n--- Calling Tool: initiate_return for order_id='{order_id}', item='{item_name}' ---")
    order_info = simulated_orders.get(order_id)
    if order_info:
        # Simple check whether the item is in the order
        if any(item.lower() == item_name.lower() for item in order_info['items']):
            # In a real system, this would interact with a returns API
            confirmation_message = f"Return initiated successfully for '{item_name}' in order {order_id}. Please check your email for instructions."
            print(f"--- Tool Output: {confirmation_message} ---")
            return confirmation_message
        else:
            error_message = f"Item '{item_name}' not found in order {order_id}."
            print(f"--- Tool Output: {error_message} ---")
            return error_message
    else:
        error_message = f"Order {order_id} not found, cannot initiate return."
        print(f"--- Tool Output: {error_message} ---")
        return error_message

# Map tool names to their functions
available_tools = {
    "search_knowledge_base": search_knowledge_base_tool,
    "get_order_status": get_order_status_tool,
    "initiate_return": initiate_return_tool,
}
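# Quick sanity checks (illustrative, not executed; values follow the simulated_orders data above):
#   get_order_status_tool("12345")              -> "Order 12345 status: Shipped. Items: Laptop, Mouse."
#   initiate_return_tool("67890", "Keyboard")   -> return confirmation for 'Keyboard' in order 67890
#   get_order_status_tool("99999")              -> "Order 99999 not found in our system."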

# --- Agentic Logic (Inspired by smolagents) ---

# System prompt to guide the LLM
SYSTEM_PROMPT = """You are a helpful and friendly customer support agent.
Your goal is to assist the user with their queries regarding orders, products, or technical issues.
You have access to several tools to help you gather information or perform actions. Use them when necessary.

Available tools:
1. `search_knowledge_base(query: str)`: Use this tool to find information in the knowledge base about products, FAQs, or troubleshooting steps. Use it whenever you need information you don't have or want to confirm details.
2. `get_order_status(order_id: str)`: Use this tool to check the current status and items of a customer's order. Requires an order ID.
3. `initiate_return(order_id: str, item_name: str)`: Use this tool to start a return process for a specific item in an order. Requires the order ID and the exact item name from the order.

To use a tool, respond *only* with the following format on a new line, with no other text:
TOOL_USE: tool_name(arg1=value1, arg2=value2, ...)
Keep arguments as simple key=value pairs; nested quotes or commas inside values are not supported.

If you have enough information, have completed the user's request, or cannot use a tool, respond directly to the user in natural language.
Always be polite, empathetic, and try to resolve the user's issue efficiently.
If the user asks for something unrelated to your tools or knowledge base, politely state that you can only help with support queries.
Manage the conversation flow, and ask clarifying questions if needed (e.g., ask for an order ID if the user asks about an order).
After a tool is used, you will receive its output. Interpret the tool output and decide the next step: use another tool, ask a clarifying question, or give the final answer to the user.

Conversation History:
"""

def parse_tool_call(text: str) -> tuple[str | None, dict | None]:
    """
    Parses a string to extract a tool name and arguments.
    Expected format: TOOL_USE: tool_name(arg1=value1, arg2=value2, ...)
    Returns: (tool_name, args_dict), or (None, None) if parsing fails.
    """
    tool_call_prefix = "TOOL_USE:"
    if text.strip().startswith(tool_call_prefix):
        call_string = text.strip()[len(tool_call_prefix):].strip()
        # Use a regex for slightly more robust parsing (still simplified)
        match = re.match(r"(\w+)\((.*)\)", call_string)
        if match:
            tool_name = match.group(1)
            args_string = match.group(2)
            args = {}
            # Simple parsing of key=value pairs
            if args_string:
                # Split on commas; commas inside values are not handled, for simplicity
                arg_pairs = args_string.split(',')
                for pair in arg_pairs:
                    if '=' in pair:
                        key, value = pair.split('=', 1)
                        # Basic cleaning: strip whitespace and surrounding quotes
                        key = key.strip()
                        value = value.strip().strip("'\"")
                        args[key] = value
            return tool_name, args
        else:
            print(f"DEBUG: Failed regex parse on tool call string: '{call_string}'")
            return None, None
    return None, None
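# Minimal examples of what parse_tool_call accepts (illustrative, not executed):
#   parse_tool_call("TOOL_USE: get_order_status(order_id=12345)")
#     -> ("get_order_status", {"order_id": "12345"})
#   parse_tool_call("Sure, let me check that for you.")
#     -> (None, None)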

def run_agent_step(user_input: str, conversation_history: list[str], llm_model, llm_tokenizer, tools: dict) -> tuple[str, list[str]]:
    """
    Runs one step of the agent's interaction loop.
    Determines whether a tool needs to be called or generates a direct response,
    and handles the follow-up LLM call that interprets a tool's output.
    """
    max_tokens = 250  # Adjust based on desired verbosity

    def generate(prompt: str) -> str:
        """Generates a completion for the given prompt and returns only the newly generated text."""
        inputs = llm_tokenizer(prompt, return_tensors="pt").to(device)
        output_tokens = llm_model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            num_return_sequences=1,
            pad_token_id=llm_tokenizer.eos_token_id,
            # Add other generation parameters here if needed (e.g., temperature, top_p)
        )
        # Decode the generated text, excluding the input prompt part
        return llm_tokenizer.decode(
            output_tokens[0][inputs.input_ids.shape[-1]:],
            skip_special_tokens=True,
        ).strip()

    # Add the user input to the history for context
    current_history = conversation_history + [f"User: {user_input}"]

    # Prepare the full prompt for the LLM: system prompt first, then the history joined by newlines
    prompt = SYSTEM_PROMPT + "\n" + "\n".join(current_history) + "\nAgent:"
    print(f"\n--- Sending Prompt to LLM ---\n{prompt}\n--- End Prompt ---")

    # Generate the agent's first response
    generated_text = generate(prompt)
    print(f"\n--- LLM Raw Output ---\n{generated_text}\n--- End LLM Output ---")

    # Try to parse the output as a tool call
    tool_name, tool_args = parse_tool_call(generated_text)

    if tool_name and tool_name in tools:
        # It's a tool call!
        print(f"DEBUG: Detected tool call: {tool_name} with args {tool_args}")
        try:
            # Execute the tool
            tool_output = tools[tool_name](**tool_args)
            # Add the tool call and its output to the history (the agent's observation step)
            updated_history = current_history + [f"Agent: {generated_text}", f"Tool Output: {tool_output}"]
            # Call the LLM *again* with the updated history including the tool output,
            # so the agent can interpret the tool's result and formulate the final response.
            print("\n--- Re-prompting LLM with Tool Output ---")
            second_prompt = SYSTEM_PROMPT + "\n" + "\n".join(updated_history) + "\nAgent (interpreting tool output):"
            print(f"--- Second Prompt ---\n{second_prompt}\n--- End Second Prompt ---")
            final_response = generate(second_prompt)
        except TypeError as e:
            # The parsed arguments don't match the tool function's signature
            error_msg = f"Error executing tool '{tool_name}': Invalid arguments. {e}"
            print(f"DEBUG: {error_msg}")
            # Inform the agent about the error and let it respond
            updated_history = current_history + [f"Agent: {generated_text}", f"Tool Output: Error - {error_msg}"]
            second_prompt = SYSTEM_PROMPT + "\n" + "\n".join(updated_history) + "\nAgent (interpreting tool error):"
            final_response = generate(second_prompt)
        except Exception as e:
            # Any other error during tool execution
            error_msg = f"An unexpected error occurred during tool execution: {e}"
            print(f"DEBUG: {error_msg}")
            # Inform the agent about the error and let it respond
            updated_history = current_history + [f"Agent: {generated_text}", f"Tool Output: Error - {error_msg}"]
            second_prompt = SYSTEM_PROMPT + "\n" + "\n".join(updated_history) + "\nAgent (interpreting tool error):"
            final_response = generate(second_prompt)

        # Add the final response to the history
        final_history = updated_history + [f"Agent: {final_response}"]
        return final_response, final_history
    else:
        # Not a tool call, or tool not recognized: treat it as a direct response
        print("DEBUG: No valid tool call detected. Treating as direct response.")
        final_history = current_history + [f"Agent: {generated_text}"]
        return generated_text, final_history
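# Illustrative single turn (not executed; actual gpt2 output varies and often fails to follow
# the TOOL_USE format reliably):
#   response, history = run_agent_step("Where is order 12345?", [], model, tokenizer, available_tools)
# If the model emits "TOOL_USE: get_order_status(order_id=12345)", the tool runs, its output is
# appended to the history as a "Tool Output:" line, and the model is re-prompted to produce the
# user-facing answer; otherwise the raw generation is returned directly.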

# --- Gradio Interface ---

# This function is called by Gradio whenever the user sends a message.
# It receives the user input and the current chat history (managed by Gradio)
# and returns the cleared textbox value plus the updated chat history.
def respond(message, chat_history):
    # With type='messages', the Gradio history is a list of dicts:
    # [{"role": "user", "content": ...}, {"role": "assistant", "content": ...}, ...]
    # Convert it to the agent's internal list-of-strings format.
    agent_internal_history = []
    for entry in chat_history:
        prefix = "User" if entry["role"] == "user" else "Agent"
        agent_internal_history.append(f"{prefix}: {entry['content']}")
    # Note: this simple conversion only carries over the displayed user/agent messages of
    # previous turns; the intermediate Tool Use / Tool Output steps of past turns are lost.
    # A more advanced implementation would store the full step history returned by
    # run_agent_step in Gradio state and pass that back in instead. For this pilot, the
    # current turn's tool use/output is added *within* run_agent_step before the final
    # response is generated, which is enough to demonstrate the logic.

    # Run the agent logic for this turn (the detailed per-turn history it returns is not
    # persisted here, as explained above).
    agent_response, _ = run_agent_step(message, agent_internal_history, model, tokenizer, available_tools)

    # Append the user message and the agent's final response to the Gradio chat history
    chat_history.append({"role": "user", "content": message})
    chat_history.append({"role": "assistant", "content": agent_response})

    # Return an empty string to clear the textbox, and the updated history
    return "", chat_history

# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Agentic RAG Pilot Demo
        This is a small-scale pilot demonstrating an agentic system with RAG and tool use
        for customer support queries, inspired by the smolagents framework.
        Ask me about order status (try Order ID 12345 or 67890) or how to reset your router.
        """
    )
    # Chatbot component using the role/content 'messages' history format
    chatbot = gr.Chatbot(height=400, type='messages')
    # Textbox for user input
    msg = gr.Textbox(label="Your Message")
    # Button to send a message
    send_button = gr.Button("Send")
    # Button to clear the chat
    clear = gr.ClearButton([msg, chatbot])

    # Wire both the send button and textbox submission to the respond function.
    # When the button is clicked OR the user presses Enter in the textbox, respond() is
    # called with the message and current chat history; its outputs update the message
    # box (clearing it) and the chatbot.
    send_button.click(respond, inputs=[msg, chatbot], outputs=[msg, chatbot])
    msg.submit(respond, inputs=[msg, chatbot], outputs=[msg, chatbot])

# Launch the Gradio app
if __name__ == "__main__":
    demo.queue().launch()  # queue() replaces the deprecated enable_queue parameter
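# Assumed dependencies for this Space (not pinned here): gradio, transformers, torch,
# sentence-transformers, faiss-cpu, and numpy; on Hugging Face Spaces these would typically
# be listed in requirements.txt.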