File size: 22,466 Bytes
b93c73a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1fdd28b
 
b93c73a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1fdd28b
 
 
b93c73a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
# Import necessary libraries
import faiss
from sentence_transformers import SentenceTransformer
import numpy as np
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import re # For parsing tool calls
import gradio as gr # For the Gradio interface

# --- Configuration and Initialization ---

# Path (relative to the working directory) of the plain-text knowledge base.
KNOWLEDGE_FILE = 'knowledge_base.txt'

# Seed a small demo knowledge base on first run so the Space works without a
# manually uploaded file. Paragraphs are separated by a blank line ("\n\n"),
# which is the delimiter load_knowledge() splits on; the last paragraph has
# no trailing separator. An existing file is never overwritten.
if not os.path.exists(KNOWLEDGE_FILE):
    print(f"Creating dummy knowledge file: {KNOWLEDGE_FILE}")
    with open(KNOWLEDGE_FILE, 'w', encoding='utf-8') as f:
        f.write("\n\n".join((
            "FAQ: How to reset my router?",
            "To reset your router, locate the small reset button, usually on the back. Press and hold it for 10-15 seconds until the lights blink.",
            "FAQ: What is your return policy?",
            "You can return most items within 30 days of purchase if they are in original condition. Some electronics have a 15-day policy.",
            "Simulated Data: Order ID 12345 is in status 'Shipped'. Items: Laptop, Mouse.",
            "Simulated Data: Order ID 67890 is in status 'Processing'. Items: Keyboard.",
            "Simulated Data: Order ID 11223 is in status 'Delivered'. Items: Monitor, Webcam.",
            "Troubleshooting: If your device won't connect to Wi-Fi, try restarting both the device and the router.",
        )))


# --- RAG Component ---

# 1. Load knowledge and chunk (simple chunking)
def load_knowledge(filepath):
    """Read a UTF-8 text file and break it into paragraph chunks.

    Paragraphs are delimited by the literal two-newline sequence. Each piece is
    whitespace-stripped and pieces that end up empty are discarded.

    Args:
        filepath: Path of the knowledge file to read.

    Returns:
        list[str]: The non-empty, stripped chunks in file order. An empty list
        if the file does not exist or cannot be read/decoded (the error is
        printed rather than raised).
    """
    print(f"Loading knowledge from {filepath}...")
    try:
        with open(filepath, 'r', encoding='utf-8') as handle:
            pieces = handle.read().split('\n\n')
        chunks = [piece for piece in map(str.strip, pieces) if piece]
        print(f"Loaded {len(chunks)} chunks.")
        return chunks
    except FileNotFoundError:
        print(f"Error: Knowledge file not found at {filepath}")
        return []
    except Exception as e:
        # Best-effort loader: any other failure degrades to an empty knowledge base.
        print(f"Error loading knowledge file: {e}")
        return []

# 2. Load Embedding Model
# all-MiniLM-L6-v2 is a compact sentence-embedding model, small and fast enough
# for this pilot. It is loaded once, at import time, and shared by indexing and
# search below.
embedding_model = SentenceTransformer(model_name_or_path='all-MiniLM-L6-v2')
print("Embedding model loaded: all-MiniLM-L6-v2")

# 3. Create Embeddings and FAISS Index
def create_faiss_index(chunks, model):
    """Embed ``chunks`` and store the vectors in an exact (brute-force) L2 FAISS index.

    Args:
        chunks: List of text chunks to index.
        model: Embedding model exposing ``encode(list_of_str)`` whose result has
            a ``.shape`` whose second axis is the embedding dimension
            (e.g. a SentenceTransformer).

    Returns:
        ``(index, chunks)`` where vector ``i`` in the index corresponds to
        ``chunks[i]``, or ``(None, [])`` when ``chunks`` is empty.
    """
    if chunks:
        print("Creating embeddings...")
        embeddings = model.encode(chunks)
        # IndexFlatL2 does exhaustive search: no training, fine for a small corpus.
        index = faiss.IndexFlatL2(embeddings.shape[1])
        # FAISS expects float32 vectors.
        index.add(np.array(embeddings).astype('float32'))
        print(f"FAISS index created with {index.ntotal} vectors.")
        return index, chunks

    print("No chunks to index.")
    return None, []

# 4. Search Function in FAISS
def search_faiss(query, index, chunks, model, k=3):
    """Return the (up to) ``k`` chunks whose embeddings are nearest to ``query``.

    Args:
        query: Free-text search string.
        index: FAISS index built over the embeddings of ``chunks``
            (vector ``i`` corresponds to ``chunks[i]``).
        chunks: The indexed texts, in index order.
        model: Embedding model exposing ``encode(list_of_str)``.
        k: Maximum number of results to return.

    Returns:
        list[str]: Matching chunks ordered nearest-first. Fewer than ``k`` when
        the index holds fewer vectors (empty if it holds none). If the index or
        chunk list is missing, a one-element list with an explanatory message.
    """
    if index is None or not chunks:
        return ["Knowledge base not available."]
    print(f"Searching FAISS for query: '{query}' (k={k})")
    # FAISS pads results with label -1 when k exceeds the number of stored
    # vectors; indexing chunks[-1] would then silently return the last chunk.
    # Never ask for more neighbours than exist.
    k = min(k, index.ntotal)
    if k <= 0:
        print("Found 0 relevant chunks.")
        return []
    query_embedding = model.encode([query])[0]
    _distances, indices = index.search(np.array([query_embedding]).astype('float32'), k)
    # Defensive filter: skip padding (-1) or ids outside the chunk list.
    results = [chunks[i] for i in indices[0] if 0 <= i < len(chunks)]
    print(f"Found {len(results)} relevant chunks.")
    return results

# Build the retrieval corpus and its FAISS index once, at import time, so every
# request reuses them.
knowledge_chunks = load_knowledge(KNOWLEDGE_FILE)
faiss_index, indexed_chunks = create_faiss_index(knowledge_chunks, embedding_model)


# --- LLM Component (Hugging Face) ---

# gpt2 is small enough to run on CPU for the pilot.
# NOTE(review): gpt2 is not instruction-tuned, so it is unlikely to follow the
# TOOL_USE format reliably; consider a small instruct model.
model_name = "gpt2"

print(f"Loading LLM: {model_name}...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
# GPT-2 ships without a padding token; reuse EOS so padding-aware calls work.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Prefer a CUDA GPU when one is visible, otherwise run on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
print(f"LLM loaded: {model_name} on {device}")

# --- Tool Definitions (Simulated) ---

# In-memory stand-in for an order database, keyed by order ID (a string).
# Each record holds a status, the list of item names, and a shipping address.
simulated_orders = {
    "12345": {
        "status": "Shipped",
        "items": ["Laptop", "Mouse"],
        "address": "123 Main St",
    },
    "67890": {
        "status": "Processing",
        "items": ["Keyboard"],
        "address": "456 Oak Ave",
    },
    "11223": {
        "status": "Delivered",
        "items": ["Monitor", "Webcam"],
        "address": "789 Pine Ln",
    },
}

def search_knowledge_base_tool(query: str) -> str:
    """
    Tool ``search_knowledge_base``: fetch knowledge-base passages relevant to a query.

    Uses the module-level FAISS index, chunk list and embedding model.

    Args:
        query (str): Free-text question or keywords to search for.
    Returns:
        str: "Retrieved Information:" followed by the matching chunks joined by
             "---" separator lines. If the index is unavailable or nothing
             matched, an explanatory message instead.
    """
    print(f"\n--- Calling Tool: search_knowledge_base with query='{query}' ---")
    index_ready = faiss_index is not None and bool(indexed_chunks)
    if not index_ready:
        return "Error: Knowledge base is not available or empty."

    hits = search_faiss(query, faiss_index, indexed_chunks, embedding_model)
    if not hits:
        return "No relevant information found in the knowledge base."

    joined_hits = "\n---\n".join(hits)
    formatted_results = f"Retrieved Information:\n{joined_hits}"
    print(f"--- Tool Output: {formatted_results} ---")
    return formatted_results

def get_order_status_tool(order_id: str) -> str:
    """
    Tool ``get_order_status``: describe an order from ``simulated_orders``.

    Args:
        order_id (str): Order ID to look up (must match a key exactly).
    Returns:
        str: "Order <id> status: <status>. Items: <a, b>." when the order
             exists, otherwise a not-found message. The result is also printed.
    """
    print(f"\n--- Calling Tool: get_order_status with order_id='{order_id}' ---")
    record = simulated_orders.get(order_id)
    if not record:
        outcome = f"Order {order_id} not found in our system."
    else:
        item_summary = ", ".join(record['items'])
        outcome = f"Order {order_id} status: {record['status']}. Items: {item_summary}."
    print(f"--- Tool Output: {outcome} ---")
    return outcome

def initiate_return_tool(order_id: str, item_name: str) -> str:
    """
    Tool ``initiate_return``: simulate starting a return for one item of an order.

    The item is matched case-insensitively against the order's item list. No
    state is changed; only a confirmation or error message is produced.

    Args:
        order_id (str): ID of the order (must match a ``simulated_orders`` key).
        item_name (str): Name of the item to return.
    Returns:
        str: A confirmation message, or an error if the order or item is not
             found. The result is also printed.
    """
    print(f"\n--- Calling Tool: initiate_return for order_id='{order_id}', item='{item_name}' ---")
    order_info = simulated_orders.get(order_id)
    if not order_info:
        outcome = f"Order {order_id} not found, cannot initiate return."
    elif any(ordered.lower() == item_name.lower() for ordered in order_info['items']):
        # A real system would call a returns API at this point.
        outcome = f"Return initiated successfully for '{item_name}' in order {order_id}. Please check your email for instructions."
    else:
        outcome = f"Item '{item_name}' not found in order {order_id}."
    print(f"--- Tool Output: {outcome} ---")
    return outcome

# Registry used by the agent loop: the tool name the LLM writes after
# "TOOL_USE:" is looked up here to find the Python callable to invoke.
available_tools = dict(
    search_knowledge_base=search_knowledge_base_tool,
    get_order_status=get_order_status_tool,
    initiate_return=initiate_return_tool,
)

# --- Agentic Logic (Inspired by smolagents) ---

# System prompt prepended to every LLM call. run_agent_step appends the
# transcript lines ("User: ...", "Agent: ...", "Tool Output: ...") right after
# the trailing "Conversation History:" line, followed by a speaker cue such as
# "Agent:" for the model to continue from.
# The tool names listed below must match the keys of available_tools, and the
# TOOL_USE line format must stay in sync with parse_tool_call, which only
# recognises output that starts with "TOOL_USE:".
# NOTE(review): the phrase "the simple parser below" is shown to the model,
# for which there is no "below"; consider rewording that instruction.
SYSTEM_PROMPT = """You are a helpful and friendly customer support agent.
Your goal is to assist the user with their queries regarding orders, products, or technical issues.
You have access to several tools to help you gather information or perform actions. Use them when necessary.

Available tools:
1.  `search_knowledge_base(query: str)`: Use this tool to find information in the knowledge base about products, FAQs, or troubleshooting steps. Use it whenever you need information you don't have or to confirm details.
2.  `get_order_status(order_id: str)`: Use this tool to check the current status and items of a customer's order. Requires an order ID.
3.  `initiate_return(order_id: str, item_name: str)`: Use this tool to start a return process for a specific item in an order. Requires the order ID and the exact item name from the order.

To use a tool, respond *only* with the following format on a new line, followed by no other text:
TOOL_USE: tool_name(arg1=value1, arg2=value2, ...)
Ensure arguments are correctly formatted (e.g., strings in quotes if necessary, though the simple parser below expects simple key=value).

If you have enough information, have completed the user's request, or cannot use a tool, respond directly to the user in natural language.
Always be polite, empathetic, and try to resolve the user's issue efficiently.
If the user asks for something unrelated to your tools or knowledge base, politely state that you can only help with support queries.
Manage the conversation flow. Ask clarifying questions if needed (e.g., ask for an order ID if the user asks about an order).
After a tool is used and its output is provided, you will receive the output. Interpret the tool output and decide the next step: either use another tool, ask a clarifying question, or provide the final answer to the user.

Conversation History:
"""

def parse_tool_call(text: str) -> tuple[str | None, dict | None]:
    """
    Parses a string to extract a tool name and arguments.
    Expected format: TOOL_USE: tool_name(arg1=value1, arg2=value2, ...)
    Returns: (tool_name, args_dict) or (None, None) if parsing fails.
    """
    tool_call_prefix = "TOOL_USE:"
    if text.strip().startswith(tool_call_prefix):
        call_string = text.strip()[len(tool_call_prefix):].strip()
        # Use regex for more robust parsing (still simplified)
        match = re.match(r"(\w+)\((.*)\)", call_string)
        if match:
            tool_name = match.group(1)
            args_string = match.group(2)
            args = {}
            # Simple parsing for key=value pairs
            if args_string:
                # Split by comma, but handle potential commas within values (not handled here for simplicity)
                arg_pairs = args_string.split(',')
                for pair in arg_pairs:
                    if '=' in pair:
                        key, value = pair.split('=', 1)
                        # Basic cleaning of value (remove leading/trailing whitespace and quotes)
                        key = key.strip()
                        value = value.strip().strip("'\"")
                        args[key] = value
            return tool_name, args
        else:
            print(f"DEBUG: Failed regex parse on tool call string: '{call_string}'")
            return None, None
    return None, None


# Start of a new speaker turn that the LLM tends to hallucinate after its own
# reply (e.g. "\nUser: ...", "\nTool Output: ...", "\nAgent (interpreting ...):").
_NEXT_TURN_PATTERN = re.compile(r"\n\s*(?:User|Agent|Tool Output)\b[^\n:]*:")


def _trim_to_own_turn(text: str) -> str:
    """Cut generated text at the first line where the model starts another speaker's turn."""
    return _NEXT_TURN_PATTERN.split(text, maxsplit=1)[0].strip()


def _generate_reply(prompt: str, llm_model, llm_tokenizer, max_new_tokens: int) -> str:
    """Generate a continuation of ``prompt`` and return only the model's own new turn.

    If the prompt plus ``max_new_tokens`` would exceed the model's position
    limit (``config.max_position_embeddings``; 1024 for GPT-2), the oldest
    prompt tokens are dropped. Otherwise ``generate`` would index past the
    position-embedding table and raise. Models whose config exposes no such
    limit are left untruncated.
    """
    inputs = llm_tokenizer(prompt, return_tensors="pt")
    context_limit = getattr(getattr(llm_model, "config", None), "max_position_embeddings", None)
    if context_limit:
        budget = context_limit - max_new_tokens
        prompt_len = inputs["input_ids"].shape[-1]
        if 0 < budget < prompt_len:
            print(f"DEBUG: Prompt has {prompt_len} tokens; keeping the last {budget} "
                  f"to fit the {context_limit}-token context.")
            inputs = {name: tensor[:, -budget:] for name, tensor in inputs.items()}
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    output_tokens = llm_model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
        pad_token_id=llm_tokenizer.eos_token_id,
    )
    # Decode only the newly generated tokens, not the echoed prompt.
    prompt_len = inputs["input_ids"].shape[-1]
    generated = llm_tokenizer.decode(output_tokens[0][prompt_len:], skip_special_tokens=True)
    return _trim_to_own_turn(generated)


def run_agent_step(user_input: str, conversation_history: list[str], llm_model, llm_tokenizer, tools: dict) -> tuple[str, list[str]]:
    """
    Run one user turn of the agent loop.

    The LLM is prompted with SYSTEM_PROMPT plus the transcript. If its reply
    parses as a ``TOOL_USE:`` call to a name in ``tools``, the tool is executed.
    Its output (or error) is appended to the transcript and the LLM is prompted
    once more to interpret it. Otherwise the reply is returned directly.

    Args:
        user_input: The user's new message.
        conversation_history: Prior transcript lines ("User: ...", "Agent: ...",
            "Tool Output: ..."); not mutated.
        llm_model: Causal LM used for generation.
        llm_tokenizer: Tokenizer matching ``llm_model``.
        tools: Mapping of tool name to callable taking keyword string arguments.

    Returns:
        ``(response, history)``: the agent's final reply, and the transcript
        extended with this turn (including any tool call and its output).
    """
    max_tokens = 250  # Adjust based on desired verbosity
    current_history = conversation_history + [f"User: {user_input}"]

    # System prompt first, then the transcript, then the cue for the agent's turn.
    prompt = SYSTEM_PROMPT + "\n" + "\n".join(current_history) + "\nAgent:"
    print(f"\n--- Sending Prompt to LLM ---\n{prompt}\n--- End Prompt ---")
    generated_text = _generate_reply(prompt, llm_model, llm_tokenizer, max_tokens)
    print(f"\n--- LLM Raw Output ---\n{generated_text}\n--- End LLM Output ---")

    tool_name, tool_args = parse_tool_call(generated_text)
    if not (tool_name and tool_name in tools):
        print("DEBUG: No valid tool call detected. Treating as direct response.")
        return generated_text, current_history + [f"Agent: {generated_text}"]

    print(f"DEBUG: Detected tool call: {tool_name} with args {tool_args}")
    try:
        tool_output = tools[tool_name](**(tool_args or {}))
    except TypeError as e:
        # Usually the LLM supplied arguments that don't match the tool signature.
        error_msg = f"Error executing tool '{tool_name}': Invalid arguments. {e}"
        print(f"DEBUG: {error_msg}")
        observation = f"Tool Output: Error - {error_msg}"
        cue = "Agent (interpreting tool error):"
    except Exception as e:
        # Tool boundary: report any failure to the LLM instead of crashing the turn.
        error_msg = f"An unexpected error occurred during tool execution: {e}"
        print(f"DEBUG: {error_msg}")
        observation = f"Tool Output: Error - {error_msg}"
        cue = "Agent (interpreting tool error):"
    else:
        observation = f"Tool Output: {tool_output}"
        cue = "Agent (interpreting tool output):"

    # Observation step: feed the tool result back so the LLM can answer with it.
    updated_history = current_history + [f"Agent: {generated_text}", observation]
    print("\n--- Re-prompting LLM with Tool Output ---")
    second_prompt = SYSTEM_PROMPT + "\n" + "\n".join(updated_history) + "\n" + cue
    print(f"--- Second Prompt ---\n{second_prompt}\n--- End Second Prompt ---")
    final_response = _generate_reply(second_prompt, llm_model, llm_tokenizer, max_tokens)
    return final_response, updated_history + [f"Agent: {final_response}"]

# --- Gradio Interface ---

def _chat_history_to_agent_lines(chat_history) -> list[str]:
    """Flatten a Gradio chat history into "User: ..."/"Agent: ..." transcript lines.

    Accepts both the ``type='messages'`` format (dicts with "role"/"content")
    and the legacy pair format (``[user_msg, agent_msg]``). Only what was shown
    in the chat is recovered; tool calls and outputs of earlier turns are not.
    """
    lines = []
    for entry in chat_history:
        if isinstance(entry, dict):
            content = entry.get("content")
            if content is None:
                continue
            role = entry.get("role")
            if role == "user":
                lines.append(f"User: {content}")
            elif role == "assistant":
                lines.append(f"Agent: {content}")
        else:
            user_msg, agent_msg = entry
            lines.append(f"User: {user_msg}")
            if agent_msg is not None:
                lines.append(f"Agent: {agent_msg}")
    return lines


def respond(message, chat_history):
    """Gradio callback: run one agent turn for ``message`` and append it to the chat.

    Args:
        message: Text the user typed.
        chat_history: Current Chatbot value. The Chatbot in this file uses
            ``type='messages'`` (a list of ``{"role", "content"}`` dicts).
            Legacy ``[user, agent]`` pairs are also accepted.

    Returns:
        ``("", history)``: the empty string clears the textbox. ``history`` is
        the chat with this turn appended, in the same format it arrived in
        (messages format when the history was empty). Blank messages leave the
        chat unchanged.
    """
    history = list(chat_history or [])
    if not message or not message.strip():
        return "", history

    agent_lines = _chat_history_to_agent_lines(history)
    # Only the final reply is kept for the UI; the per-turn transcript returned
    # by run_agent_step (tool calls/outputs) is not carried into later turns.
    agent_response, _ = run_agent_step(message, agent_lines, model, tokenizer, available_tools)

    if not history or isinstance(history[0], dict):
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": agent_response})
    else:
        history.append((message, agent_response))
    return "", history


# Build the UI: intro text, chat transcript, input box, Send and Clear buttons
# (created in this order, which is also their layout order).
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Agentic RAG Pilot Demo
        This is a small-scale pilot demonstrating an agentic system with RAG and tool use
        for customer support queries, inspired by the smolagents framework.
        Ask me about order status (try Order ID 12345 or 67890) or how to reset your router.
        """
    )
    # Transcript in Gradio's "messages" format (role/content dicts).
    chatbot = gr.Chatbot(height=400, type='messages')
    msg = gr.Textbox(label="Your Message")
    send_button = gr.Button("Send")
    # Resets both the input box and the transcript.
    clear = gr.ClearButton([msg, chatbot])

    # Clicking Send and pressing Enter run the same handler:
    # respond(msg, chatbot) -> ("", updated chatbot), which clears the textbox.
    for register_event in (send_button.click, msg.submit):
        register_event(respond, inputs=[msg, chatbot], outputs=[msg, chatbot])


# Launch with request queueing enabled (queue() replaces the removed
# enable_queue launch parameter).
if __name__ == "__main__":
    queued_demo = demo.queue()
    queued_demo.launch()