import asyncio
from langchain_mcp_adapters.tools import load_mcp_tools
from langchain_mcp_adapters.sessions import SSEConnection
from langgraph.prebuilt import create_react_agent
from langchain_ollama.chat_models import ChatOllama
from langchain_anthropic import ChatAnthropic
import gradio as gr
import re
from dotenv import load_dotenv
import os
import json
from datetime import datetime
from typing import List, Any

load_dotenv()
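
# Configuration comes from the environment (via .env):
#   MCP_SERVER_URL     - SSE endpoint of the MCP tool server
#   LLM_PROVIDER       - "anthropic" or "ollama" (default: ollama)
#   ANTHROPIC_API_KEY  - fallback key when the client doesn't send one
#   ANTHROPIC_MODEL    - Anthropic model name (default: claude-3-sonnet-20240229)
#   OLLAMA_MODEL       - Ollama model name (default: qwen3:8b)
#   GRADIO_PORT        - server port (default: 7860)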

# Global variable to store execution history
execution_history = []
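# Note: this list is shared by every session/tab in the process;
# entries carry a tab_id so they can be told apart in the history view.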

def format_message_for_display(message):
    """Format a message for display in the chat interface"""
    if hasattr(message, 'content'):
        content = message.content
    else:
        content = str(message)
    
    if hasattr(message, 'tool_calls') and message.tool_calls:
        tool_info = []
        for tool_call in message.tool_calls:
            tool_info.append(f"🔧 **Tool Call**: {tool_call['name']}")
            if 'args' in tool_call:
                tool_info.append(f"   **Args**: {json.dumps(tool_call['args'], indent=2)}")
        content += "\n\n" + "\n".join(tool_info)
    
    return content

def add_to_execution_history(step_type: str, data: Any, tab_id: str | None = None):
    """Add a step to the execution history"""
    timestamp = datetime.now().strftime("%H:%M:%S")
    execution_history.append({
        "timestamp": timestamp,
        "type": step_type,
        "data": data,
        "tab_id": tab_id
    })

def format_execution_history():
    """Format the execution history for display"""
    if not execution_history:
        return "No execution history yet."
    
    formatted_history = []
    for entry in execution_history:
        timestamp = entry["timestamp"]
        step_type = entry["type"]
        tab_id = entry.get("tab_id") or "N/A"  # stored tab_id may be None
        
        if step_type == "user_input":
            formatted_history.append(f"**[{timestamp}] 👤 User (Tab: {tab_id})**\n\n{entry['data']}\n\n")
        elif step_type == "agent_response":
            formatted_history.append(f"**[{timestamp}] 🤖 Agent**\n\n{entry['data']}\n\n")
        elif step_type == "tool_call":
            tool_data = entry['data']
            formatted_history.append(f"**[{timestamp}] 🔧 Tool Call**\n\n**Tool**: {tool_data['name']}\n\n**Arguments**: \n\n```json\n{json.dumps(tool_data.get('args', {}), indent=2)}\n```\n\n")
        elif step_type == "tool_result":
            formatted_history.append(f"**[{timestamp}] ✅ Tool Result**\n\n```\n{entry['data']}\n```\n\n")
        elif step_type == "error":
            formatted_history.append(f"**[{timestamp}] ❌ Error**\n\n{entry['data']}\n\n")
        
        formatted_history.append("---\n\n")
    
    return "".join(formatted_history)

async def initialize_tools():
    """
    Initializes the SSE connection and loads the MCP tools.
    We can reuse this because the tools don't depend on the Anthropic API key.
    """
    connection = SSEConnection(url=os.getenv("MCP_SERVER_URL"), transport="sse")
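    # With session=None, each tool invocation opens a fresh session over this connection.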
    tools = await load_mcp_tools(session=None, connection=connection)
    return tools

async def create_agent_with_llm(llm_provider: str, anthropic_key: str | None, ollama_model: str | None, tools):
    """
    Creates a LangGraph ReAct agent dynamically, injecting the Anthropic API key if requested.
    """
    if llm_provider == "anthropic":
        # Use the provided key, fall back to the environment, and fail fast if neither exists.
        if not anthropic_key:
            anthropic_key = os.getenv("ANTHROPIC_API_KEY")
        if not anthropic_key:
            raise ValueError("Anthropic API key is required for the 'anthropic' provider.")
        llm = ChatAnthropic(
            model=os.getenv("ANTHROPIC_MODEL", "claude-3-sonnet-20240229"),
            anthropic_api_key=anthropic_key
        )
    else:
        # In the case of Ollama, we don't depend on a key.
        llm = ChatOllama(model=ollama_model or os.getenv("OLLAMA_MODEL", "qwen3:8b"))

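    # The system prompt is re-read on every agent creation, so edits to
    # prompt.txt take effect without restarting the server.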
    with open("prompt.txt", "r") as f:
        prompt = f.read()
        
    agent = create_react_agent(llm, tools, prompt=prompt)
    return agent

# We can initialize the tools only once, as they don't depend on the key.
tools = asyncio.run(initialize_tools())

async def chat(history: list, tab_id: str | None = None, anthropic_api_key: str | None = None):
    """
    Original API function for compatibility - now with history tracking
    history: list of messages [{"role": "user"/"assistant", "content": "..."}]
    tab_id: a string that the client wants to correlate
    anthropic_api_key: the key sent by the client in each request
    """
    # Record the last user message in the execution history
    if history:
        last_message = history[-1]["content"]
        add_to_execution_history("user_input", last_message, tab_id)

        # Guarded by the check above so an empty history can't raise IndexError
        if tab_id:
            history[-1]["content"] += f"\nThis is your tab_id: {tab_id}"

    llm_provider = os.getenv("LLM_PROVIDER", "ollama").lower()
    ollama_model = os.getenv("OLLAMA_MODEL", "qwen3:8b")

    try:
        agent = await create_agent_with_llm(llm_provider, anthropic_api_key, ollama_model, tools)
    except ValueError as e:
        error_msg = str(e)
        add_to_execution_history("error", error_msg, tab_id)
        return error_msg

    try:
        result = await agent.ainvoke({"messages": history})
        
        # Process all messages in the result to track tool calls
        all_messages = result["messages"]
        
        # Track tool calls and responses
        for msg in all_messages:
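            # LangChain surfaces structured tool calls as a list of dicts
            # with "name", "args", and "id" keys on AI messages.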
            if hasattr(msg, 'tool_calls') and msg.tool_calls:
                for tool_call in msg.tool_calls:
                    add_to_execution_history("tool_call", {
                        "name": tool_call.get("name", "unknown"),
                        "args": tool_call.get("args", {})
                    }, tab_id)

            # Check if it's a tool message (result of tool execution)
            if hasattr(msg, 'name') and msg.name:
                add_to_execution_history("tool_result", msg.content, tab_id)
        
        output = all_messages[-1].content
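        # Strip <think>...</think> reasoning traces that models like qwen3 emit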
        cleaned = re.sub(r'<think>.*?</think>', '', output, flags=re.DOTALL).strip()
        
        add_to_execution_history("agent_response", cleaned, tab_id)
        return cleaned
        
    except Exception as e:
        error_msg = f"Error during execution: {str(e)}"
        add_to_execution_history("error", error_msg, tab_id)
        return error_msg

async def chat_with_history_tracking(message: str, history: List, tab_id: str | None = None, anthropic_api_key: str | None = None):
    """
    Enhanced chat function that tracks all execution steps
    """
    # Add user input to execution history
    add_to_execution_history("user_input", message, tab_id)
    
    # Convert history format for LangGraph (keeping compatibility)
    messages = []
    for h in history:
        if isinstance(h, dict):
            messages.append(h)
        else:
            # Gradio's tuple format stores [user_message, assistant_reply]
            # pairs, so each pair expands to up to two dict messages.
            user_msg, assistant_msg = h
            messages.append({"role": "user", "content": user_msg})
            if assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})
    
    # Add current message
    messages.append({"role": "user", "content": message})
    
    if tab_id:
        messages[-1]["content"] += f"\nThis is your tab_id: {tab_id}"

    llm_provider = os.getenv("LLM_PROVIDER", "ollama").lower()
    ollama_model = os.getenv("OLLAMA_MODEL", "qwen3:8b")

    try:
        agent = await create_agent_with_llm(llm_provider, anthropic_api_key, ollama_model, tools)
    except ValueError as e:
        error_msg = str(e)
        add_to_execution_history("error", error_msg, tab_id)
        history.append([message, error_msg])
        return history, format_execution_history()

    try:
        # Invoke the agent; the returned state contains every intermediate message
        result = await agent.ainvoke({"messages": messages})
        
        # Process all messages in the result
        all_messages = result["messages"]
        
        # Track tool calls and responses
        for msg in all_messages:
            if hasattr(msg, 'tool_calls') and msg.tool_calls:
                for tool_call in msg.tool_calls:
                    add_to_execution_history("tool_call", {
                        "name": tool_call.get("name", "unknown"),
                        "args": tool_call.get("args", {})
                    }, tab_id)
            
            # Check if it's a tool message (result of tool execution)
            if hasattr(msg, 'name') and msg.name:
                add_to_execution_history("tool_result", msg.content, tab_id)

        # Get the final output
        output = all_messages[-1].content
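        # Strip <think>...</think> reasoning traces, as in chat() above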
        cleaned = re.sub(r'<think>.*?</think>', '', output, flags=re.DOTALL).strip()
        
        add_to_execution_history("agent_response", cleaned, tab_id)
        history.append([message, cleaned])
        
        return history, format_execution_history()
        
    except Exception as e:
        error_msg = f"Error during execution: {str(e)}"
        add_to_execution_history("error", error_msg, tab_id)
        history.append([message, error_msg])
        return history, format_execution_history()

def clear_history():
    """Clear the execution history"""
    global execution_history
    execution_history = []
    return [], "Execution history cleared."

# Create the enhanced Gradio interface
with gr.Blocks(title="OwlBear Agent - Complete History", theme=gr.themes.Default()) as demo:
    gr.Markdown("# 🦉 OwlBear Agent - Complete Execution View")
    gr.Markdown("This interface shows the complete agent execution process, including tool calls and intermediate steps.")
    gr.Markdown("**Note:** All messages sent to the original API also appear here automatically.")
    
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## 💬 Chat")
            chatbot = gr.Chatbot(
                label="Conversation",
                height=400,
                show_label=True,
                container=True,
            )
            
            with gr.Row():
                msg = gr.Textbox(
                    label="Message",
                    placeholder="Type your message here...",
                    lines=2,
                    scale=4
                )
                send_btn = gr.Button("Send", variant="primary", scale=1)
            
            with gr.Row():
                tab_id = gr.Textbox(
                    label="Tab ID",
                    placeholder="Tab ID (optional)",
                    value="main",
                    scale=1
                )
                anthropic_key = gr.Textbox(
                    label="Anthropic API Key",
                    placeholder="Anthropic API Key (optional)",
                    type="password",
                    scale=2
                )
            
            clear_btn = gr.Button("Clear Chat", variant="secondary")
        
        with gr.Column(scale=1):
            gr.Markdown("## 📊 Detailed Execution History")
            gr.Markdown("*Updates automatically every 2 seconds*")
            execution_display = gr.Markdown(
                value="No execution history yet.",
                label="Complete History",
                height=600,
                container=True,
            )
            
            refresh_btn = gr.Button("Refresh History", variant="secondary")
            clear_history_btn = gr.Button("Clear History", variant="secondary")
    
    # Auto-refresh timer for execution history
    timer = gr.Timer(value=2)  # Refresh every 2 seconds
    timer.tick(format_execution_history, outputs=[execution_display], show_api=False)
    
    # Event handlers
    def send_message(message, history, tab_id, anthropic_key):
        if not message.strip():
            return history, "", format_execution_history()

        # Gradio runs this handler in a worker thread with no running event
        # loop, so asyncio.run can drive the async chat function directly.
        new_history, execution_history_display = asyncio.run(
            chat_with_history_tracking(message, history, tab_id, anthropic_key)
        )
        return new_history, "", execution_history_display
    
    send_btn.click(
        send_message,
        inputs=[msg, chatbot, tab_id, anthropic_key],
        outputs=[chatbot, msg, execution_display],
        show_api=False
    )
    
    msg.submit(
        send_message,
        inputs=[msg, chatbot, tab_id, anthropic_key],
        outputs=[chatbot, msg, execution_display],
        show_api=False
    )
    
    clear_btn.click(
        lambda: ([], ""),
        outputs=[chatbot, msg],
        show_api=False
    )
    
    refresh_btn.click(
        format_execution_history,
        outputs=[execution_display],
        show_api=False
    )
    
    clear_history_btn.click(
        clear_history,
        outputs=[chatbot, execution_display],
        show_api=False
    )

api_demo = gr.Interface(
    fn=chat,
    inputs=[
        gr.JSON(label="history"),
        gr.Textbox(label="tab_id"),
        gr.Textbox(label="anthropic_api_key"),
    ],
    outputs="text",    title="OwlBear Agent - Original API"
)
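
# Example client call against the original API (a sketch; the endpoint name
# and port are assumptions based on Gradio's defaults):
#
#   from gradio_client import Client
#   client = Client("http://localhost:7860/")
#   reply = client.predict(
#       [{"role": "user", "content": "Roll for initiative"}],  # history
#       "tab-1",                                               # tab_id
#       "",                                                    # anthropic_api_key
#       api_name="/chat",
#   )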

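# Load the README for the intro tab, stripping any YAML front matter
# (the leading "---" block, e.g. Hugging Face Spaces metadata).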
with open("README.md", "r", encoding="utf-8") as f:
    readme = f.read()
    if readme.startswith("---"):
        parts = readme.split("---", 2)
        if len(parts) >= 3:
            readme = parts[2]


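# Pull fenced ```html blocks out of the README so they can be rendered
# with gr.HTML; the surrounding prose is rendered as Markdown.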
html_blocks = re.findall(r'```html\n(.*?)\n```', readme, re.DOTALL)
for i, html_block in enumerate(html_blocks):
    readme = readme.replace(f"```html\n{html_block}\n```", f"{{HTML_BLOCK_{i}}}")

with gr.Blocks() as intro_demo:
    parts = re.split(r'({HTML_BLOCK_\d+})', readme)
    
    for part in parts:
        if part.startswith("{HTML_BLOCK_"):
            block_idx = int(part.replace("{HTML_BLOCK_", "").replace("}", ""))
            gr.HTML(html_blocks[block_idx])
        else:
            if part.strip():
                gr.Markdown(part)

# Combined interface with tabs
combined_demo = gr.TabbedInterface(
    [intro_demo, demo, api_demo],
    ["README", "Complete View with History", "Original API"],
    title="🧙🏼‍♂️ LLM Game Master - Agent"
)

if __name__ == "__main__":
    combined_demo.launch(server_port=int(os.getenv("GRADIO_PORT", "7860")))