import gradio as gr
import asyncio
import logging
import tempfile
import json
import re
import requests
from typing import Optional, Dict, Any, List, Tuple
from services.audio_service import AudioService
from services.llm_service import LLMService
from services.screen_service import ScreenService
from config.settings import Settings
from config.prompts import get_generic_prompt, get_vision_prompt

# Configure root logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class MCPRestClient:
    def __init__(self, base_url: str = "http://localhost:8000"):
        self.base_url = base_url.rstrip('/')

    async def initialize(self) -> bool:
        """Test the connection to the MCP server."""
        try:
            response = requests.get(f"{self.base_url}/", timeout=5)
            if response.status_code == 200:
                logger.info("Successfully connected to MCP server")
                return True
            raise ConnectionError(f"MCP server returned status {response.status_code}")
        except Exception as e:
            logger.error(f"Failed to connect to MCP server at {self.base_url}: {e}")
            logger.info("IRIS did not detect any MCP server. If you're running this in a HuggingFace Space, please refer to the README.md documentation.")
            return False

    async def get_available_tools(self) -> Dict[str, Dict]:
        """Get the list of available tools from the MCP server."""
        try:
            response = requests.get(f"{self.base_url}/tools", timeout=5)
            if response.status_code == 200:
                data = response.json()
                tools = {}
                for tool in data.get("tools", []):
                    tools[tool["name"]] = {
                        "description": tool.get("description", ""),
                        "inputSchema": tool.get("inputSchema", {})
                    }
                return tools
            else:
                logger.error(f"Failed to get tools: HTTP {response.status_code}")
                return {}
        except Exception as e:
            logger.error(f"Failed to get tools: {e}")
            return {}

    async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
        """Call a tool on the MCP server."""
        try:
            payload = {
                "name": tool_name,
                "arguments": arguments
            }
            response = requests.post(
                f"{self.base_url}/tools/call",
                json=payload,
                timeout=30
            )
            if response.status_code == 200:
                data = response.json()
                if data.get("success"):
                    return data.get("result")
                else:
                    return {"error": data.get("error", "Unknown error")}
            else:
                return {"error": f"HTTP {response.status_code}: {response.text}"}
        except Exception as e:
            return {"error": str(e)}

    async def close(self):
        """Nothing to close with requests."""
        pass
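
# Illustrative standalone usage of MCPRestClient: a minimal sketch, assuming an
# MCP REST server is listening on localhost:8000 ("some_tool" is a placeholder):
#
#   async def demo():
#       client = MCPRestClient("http://localhost:8000")
#       if await client.initialize():
#           tools = await client.get_available_tools()
#           result = await client.call_tool("some_tool", {"arg": "value"})
#           print(list(tools), result)
#       await client.close()
#
#   asyncio.run(demo())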

class AgenticChatbot:
    def __init__(self):
        self.settings = Settings()

        # AudioService
        audio_api_key = (
            self.settings.hf_token
            if self.settings.effective_audio_provider == "huggingface"
            else self.settings.openai_api_key
        )
        self.audio_service = AudioService(
            api_key=audio_api_key,
            stt_provider="fal-ai",
            stt_model=self.settings.stt_model,
            tts_model=self.settings.tts_model,
        )

        # LLMService
        self.llm_service = LLMService(
            api_key=self.settings.llm_api_key,
            model_name=self.settings.effective_model_name,
        )

        # MCPService, now using the REST client
        mcp_server_url = getattr(self.settings, 'mcp_server_url', 'http://localhost:8000')
        self.mcp_service = MCPRestClient(mcp_server_url)

        # ScreenService
        self.screen_service = ScreenService(
            prompt=get_vision_prompt(),
            model=self.settings.NEBIUS_MODEL,
            fps=0.05,
            queue_size=2,
            monitor=1,
            compression_quality=self.settings.screen_compression_quality,
            max_width=self.settings.max_width,
            max_height=self.settings.max_height,
        )

        self.latest_screen_context: str = ""
        self.conversation_history: List[Dict[str, Any]] = []

    async def initialize(self) -> bool:
        try:
            mcp_available = await self.mcp_service.initialize()
            if not mcp_available:
                return False
            tools = await self.mcp_service.get_available_tools()
            logger.info(f"Initialized with {len(tools)} MCP tools")
            return True
        except Exception as e:
            logger.error(f"MCP init failed: {e}")
            return False

    # Screen callbacks
    def _on_screen_result(self, resp: dict, latency: float, frame_b64: str):
        try:
            content = resp.choices[0].message.content
        except Exception:
            content = str(resp)
        self.latest_screen_context = content
        logger.info(f"[Screen] {latency*1000:.0f}ms -> {content}")

    def _get_conversation_history(self) -> List[Dict[str, str]]:
        """Return a copy of the current conversation history for the screen service."""
        return self.conversation_history.copy()

    def start_screen_sharing(self) -> str:
        self.latest_screen_context = ""
        # Pass the history getter to the screen service
        self.screen_service.start(
            self._on_screen_result,
            history_getter=self._get_conversation_history
        )
        return "✅ Screen sharing started."

    async def stop_screen_sharing(
        self,
        history: Optional[List[Dict[str, str]]]
    ) -> Tuple[List[Dict[str, str]], str, Optional[str]]:
        """Stop screen sharing and append an LLM-generated summary to the chat."""
        # Stop capture
        self.screen_service.stop()
        # Get the latest vision context
        vision_ctx = self.latest_screen_context
        if vision_ctx and history is not None:
            # Call process_message with the vision context as the user input
            updated_history, audio_path = await self.process_message(
                text_input=f"VISION MODEL OUTPUT: {vision_ctx}",
                audio_input=None,
                history=history
            )
            return updated_history, "🛑 Screen sharing stopped.", audio_path
        # If there is no vision context or history, just return
        return history or [], "🛑 Screen sharing stopped.", None

    async def execute_tool_calls(self, response_text: str) -> str:
        """Parse and execute function calls from the LLM response using regex parsing."""
        # Clean the response text: strip code-block markers and extra formatting
        cleaned_text = re.sub(r'```[a-zA-Z]*\n?', '', response_text)  # Opening code-block markers
        cleaned_text = re.sub(r'\n```', '', cleaned_text)  # Closing code blocks

        # Pattern for function calls: function_name(arg1="value1", arg2=value2, arg3=true)
        function_pattern = r'(\w+)\s*\(\s*([^)]*)\s*\)'
        results = []

        # Fetch the available tools once, rather than once per match
        available_tools = await self.mcp_service.get_available_tools()

        # Find all function calls in the cleaned response
        for match in re.finditer(function_pattern, cleaned_text):
            tool_name = match.group(1)
            args_str = match.group(2).strip()
            # Skip anything that isn't actually an available tool
            if tool_name not in available_tools:
                continue
            args = {}
            try:
                # Parse arguments using a regex for key=value pairs
                if args_str:
                    # Pattern for key=value pairs, handling quoted strings, numbers, booleans
                    arg_pattern = r'(\w+)\s*=\s*(?:"([^"]*)"|\'([^\']*)\'|(\w+))'
                    for arg_match in re.finditer(arg_pattern, args_str):
                        key = arg_match.group(1)
                        # Take the value from whichever group matched; checking against
                        # None (rather than truthiness) keeps empty quoted strings intact
                        value = next(g for g in arg_match.groups()[1:] if g is not None)
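                        # Examples of the coercions applied below (illustrative):
                        #   "true" -> True, "42" -> 42, "3.14" -> 3.14,
                        #   anything else stays a string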
                        # Type conversion for common types
                        if value.lower() == 'true':
                            args[key] = True
                        elif value.lower() == 'false':
                            args[key] = False
                        elif value.isdigit():
                            args[key] = int(value)
                        elif re.fullmatch(r'\d+\.\d+', value):
                            args[key] = float(value)
                        else:
                            args[key] = value

                # Execute the tool
                logger.info(f"Executing tool: {tool_name} with args: {args}")
                result = await self.mcp_service.call_tool(tool_name, args)
                results.append({
                    'tool': tool_name,
                    'args': args,
                    'result': result
                })
            except Exception as e:
                results.append({
                    'tool': tool_name,
                    'args': args,
                    'error': str(e)
                })

        # Format results for the LLM
        if not results:
            return ""
        formatted_results = []
        for result in results:
            if 'error' in result:
                formatted_results.append(
                    f"Tool {result['tool']} failed: {result['error']}"
                )
            else:
                formatted_results.append(
                    f"Tool {result['tool']} executed successfully:\n{json.dumps(result['result'], indent=2)}"
                )
        return "\n\n".join(formatted_results)

    # Chat / tool integration
    async def generate_response(
        self,
        user_input: str,
        screen_context: str = "",
        tool_result: str = ""
    ) -> str:
        # Retrieve available tools metadata
        tools = await self.mcp_service.get_available_tools()
        # Format the tool list for the prompt
        tool_desc = "\n".join(f"- {name}: {info.get('description', '')}" for name, info in tools.items())
        # Build messages
        messages: List[Dict[str, str]] = [
            {"role": "system", "content": get_generic_prompt()},
        ]
        # Inform the LLM about available tools
        if tool_desc:
            messages.append({"role": "system", "content": f"Available tools:\n{tool_desc}"})
        # Give the LLM the latest screen context, if any
        if screen_context:
            messages.append({"role": "system", "content": f"Current screen context:\n{screen_context}"})
        messages.append({"role": "user", "content": user_input})
        if tool_result:
            messages.append({"role": "assistant", "content": tool_result})
        return await self.llm_service.get_chat_completion(messages)

    async def process_message(
        self,
        text_input: str,
        audio_input: Optional[str],
        history: List[Dict[str, str]]
    ) -> Tuple[List[Dict[str, str]], Optional[str]]:
        # Debug: log the tail of the incoming history
        logger.info("=== PROCESS_MESSAGE START ===")
        tail = history[-3:]
        for i, msg in enumerate(tail, start=len(history) - len(tail)):
            logger.info(f"  {i}: {msg.get('role')} - {msg.get('content', '')[:100]}...")
        # Update the internal conversation history to match the UI history
        self.conversation_history = history.copy()
        # STT
        transcript = ""
        if audio_input:
            transcript = await self.audio_service.speech_to_text(audio_input)
        user_input = ((text_input or "") + " " + transcript).strip()
        # If there is no input, return unchanged
        if not user_input:
            return history, None
        # Check whether this is a vision-model output being processed
        is_vision_output = user_input.startswith("VISION MODEL OUTPUT:")
        # Always add the user message to both histories
        user_message = {"role": "user", "content": user_input}
        history.append(user_message)
        self.conversation_history.append(user_message)
        # Handle screen context - only for regular user inputs, not vision outputs
        screen_ctx = ""
        if not is_vision_output and self.latest_screen_context:
            screen_ctx = self.latest_screen_context
            # Clear the screen context after using it to prevent reuse
            self.latest_screen_context = ""
        # Get the initial LLM response (may include tool calls)
        assistant_reply = await self.generate_response(user_input, screen_ctx)
        # Check whether the response contains function calls and execute them
        tool_results = await self.execute_tool_calls(assistant_reply)
        if tool_results:
            tool_message = {"role": "assistant", "content": tool_results}
            history.append(tool_message)
            self.conversation_history.append(tool_message)
            # Get the final response after tool execution
            assistant_reply = await self.generate_response(user_input, screen_ctx, tool_results)
        # Always add the final assistant response to both histories
        assistant_message = {"role": "assistant", "content": assistant_reply}
        history.append(assistant_message)
        self.conversation_history.append(assistant_message)
        # TTS: speak the assistant reply
        audio_path = None
        audio_bytes = await self.audio_service.text_to_speech(assistant_reply)
        if audio_bytes:
            # delete=False so Gradio can read the file after we close it
            tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
            tmp.write(audio_bytes)
            tmp.close()
            audio_path = tmp.name
        logger.info("=== PROCESS_MESSAGE END ===")
        return history, audio_path

    async def cleanup(self):
        """Cleanup resources."""
        await self.mcp_service.close()

# ──────────────────────────────────
# Gradio interface setup
# ──────────────────────────────────
chatbot = AgenticChatbot()

async def setup_gradio_interface() -> gr.Blocks:
    # Try to initialize MCP; if it fails, we'll show a static message in Gradio
    mcp_available = await chatbot.initialize()
    with gr.Blocks(title="Agentic Chatbot", theme=gr.themes.Soft()) as demo:
        # If no MCP server is reachable, show a banner
        if not mcp_available:
            gr.Markdown(
                """
                <div style="padding:10px; border:2px solid #f00; border-radius:5px; background-color:#fee; color: #000;">
                **⚠️ No MCP detected. Please refer to the README documentation. IRIS requires a Virtual Environment to run.**
                </div>
                """
            )
        chat = gr.Chatbot(type="messages", label="Conversation")
        text_input = gr.Textbox(lines=2, placeholder="Type your message…", label="Text")
        audio_input = gr.Audio(sources=["microphone"], type="filepath", label="Voice")

        # Screen-sharing controls
        screen_status = gr.Textbox(label="Screen Sharing Status", interactive=False)
        start_btn = gr.Button("Start sharing screen")
        stop_btn = gr.Button("Stop sharing screen")

        # AI response audio player
        audio_output = gr.Audio(label="AI Response", autoplay=True)

        # Message send
        send_btn = gr.Button("Send", variant="primary")

        # Wire up buttons
        start_btn.click(fn=chatbot.start_screen_sharing, inputs=None, outputs=screen_status)
        stop_btn.click(
            fn=chatbot.stop_screen_sharing,
            inputs=[chat],
            outputs=[chat, screen_status, audio_output]
        )
        send_btn.click(
            chatbot.process_message,
            inputs=[text_input, audio_input, chat],
            outputs=[chat, audio_output]
        )
        text_input.submit(
            chatbot.process_message,
            inputs=[text_input, audio_input, chat],
            outputs=[chat, audio_output]
        )
    return demo

if __name__ == "__main__":
    demo = asyncio.run(setup_gradio_interface())
    demo.launch(server_name="0.0.0.0", server_port=7860)
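
# A minimal sketch of a local run (assumes this file is the Space's app
# entrypoint and that the services/ and config/ packages are importable):
#
#   python app.py    # then open http://localhost:7860
#
# The MCP REST server is expected to run separately, by default on
# http://localhost:8000 (or wherever settings.mcp_server_url points).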