# Spaces: Running  (hosting-page status text that appears to have been captured with the file; not part of the module)
import pdb | |
import pyperclip | |
from typing import Optional, Type, Callable, Dict, Any, Union, Awaitable, TypeVar | |
from pydantic import BaseModel | |
from browser_use.agent.views import ActionResult | |
from browser_use.browser.context import BrowserContext | |
from browser_use.controller.service import Controller, DoneAction | |
from browser_use.controller.registry.service import Registry, RegisteredAction | |
from main_content_extractor import MainContentExtractor | |
from browser_use.controller.views import ( | |
ClickElementAction, | |
DoneAction, | |
ExtractPageContentAction, | |
GoToUrlAction, | |
InputTextAction, | |
OpenTabAction, | |
ScrollAction, | |
SearchGoogleAction, | |
SendKeysAction, | |
SwitchTabAction, | |
) | |
import logging | |
import inspect | |
import asyncio | |
import os | |
from langchain_core.language_models.chat_models import BaseChatModel | |
from browser_use.agent.views import ActionModel, ActionResult | |
from src.utils.mcp_client import create_tool_param_model, setup_mcp_client_and_tools | |
from browser_use.utils import time_execution_sync | |
# Module-level logger named after this module, used by every log call below.
logger = logging.getLogger(__name__)

# Type variable for the opaque, caller-supplied `context` value accepted by
# `CustomController.act`.
Context = TypeVar('Context')
class CustomController(Controller):
    """Browser-use controller extended with custom actions and MCP tools.

    On top of the base ``Controller`` this class:

    * registers an ``ask_for_assistant`` action that forwards a question to
      an optional user-supplied callback,
    * registers an ``upload_file`` action that sets a file on a file-upload
      element of the current page,
    * can start an MCP client and expose each of its tools as a registry
      action named ``mcp.<server>.<tool>``; ``act`` invokes those tools
      directly instead of going through the registry executor.
    """

    def __init__(
        self,
        exclude_actions: Optional[list[str]] = None,
        output_model: Optional[Type[BaseModel]] = None,
        ask_assistant_callback: Optional[
            Union[
                Callable[[str, BrowserContext], Dict[str, Any]],
                Callable[[str, BrowserContext], Awaitable[Dict[str, Any]]],
            ]
        ] = None,
    ):
        """Create the controller and register the custom actions.

        Args:
            exclude_actions: Names of actions the base controller should not
                register. ``None`` (the default) means "exclude nothing".
            output_model: Optional model forwarded unchanged to the base
                controller.
            ask_assistant_callback: Sync or async callable invoked as
                ``callback(query, browser)`` by the ``ask_for_assistant``
                action. Its result is indexed with ``['response']``.
        """
        # A fresh list per instance: a literal `[]` default would be shared
        # by every controller created without this argument.
        if exclude_actions is None:
            exclude_actions = []
        super().__init__(exclude_actions=exclude_actions, output_model=output_model)

        # Set instance state before registering the actions that close over it.
        self.ask_assistant_callback = ask_assistant_callback
        self.mcp_client = None
        self.mcp_server_config = None
        self._register_custom_actions()

    def _register_custom_actions(self):
        """Register all custom browser actions"""

        @self.registry.action(
            "When executing tasks, prioritize autonomous completion. However, if you "
            "encounter a definitive blocker that prevents you from proceeding "
            "independently - such as needing credentials you don't possess, requiring "
            "subjective human judgment, needing a physical action performed, "
            "encountering complex CAPTCHAs, or facing limitations in your "
            "capabilities - you must request human assistance."
        )
        async def ask_for_assistant(query: str, browser: BrowserContext):
            """Forward ``query`` to the user callback and report the answer."""
            callback = self.ask_assistant_callback
            if not callback:
                return ActionResult(
                    extracted_content="Human cannot help you. Please try another way.",
                    include_in_memory=True,
                )

            user_response = callback(query, browser)
            # Await whatever awaitable the callback produced. This also covers
            # partials and callable objects that `iscoroutinefunction` misses.
            if inspect.isawaitable(user_response):
                user_response = await user_response

            msg = f"AI ask: {query}. User response: {user_response['response']}"
            logger.info(msg)
            return ActionResult(extracted_content=msg, include_in_memory=True)

        @self.registry.action(
            "Upload file to interactive element with file path",
        )
        async def upload_file(
            index: int,
            path: str,
            browser: BrowserContext,
            available_file_paths: list[str],
        ):
            """Set ``path`` on the file-upload element found at ``index``."""
            # `act` forwards `available_file_paths=None` by default; treat that
            # as "no files allowed" instead of failing on `in None`.
            if not available_file_paths or path not in available_file_paths:
                return ActionResult(error=f'File path {path} is not available')

            if not os.path.exists(path):
                return ActionResult(error=f'File {path} does not exist')

            dom_el = await browser.get_dom_element_by_index(index)

            file_upload_dom_el = dom_el.get_file_upload_element()
            if file_upload_dom_el is None:
                msg = f'No file upload element found at index {index}'
                logger.info(msg)
                return ActionResult(error=msg)

            file_upload_el = await browser.get_locate_element(file_upload_dom_el)
            if file_upload_el is None:
                msg = f'No file upload element found at index {index}'
                logger.info(msg)
                return ActionResult(error=msg)

            try:
                await file_upload_el.set_input_files(path)
            except Exception as e:
                # Reported to the agent as an action error rather than raised.
                msg = f'Failed to upload file to index {index}: {str(e)}'
                logger.info(msg)
                return ActionResult(error=msg)

            msg = f'Successfully uploaded file to index {index}'
            logger.info(msg)
            return ActionResult(extracted_content=msg, include_in_memory=True)

    async def act(
        self,
        action: ActionModel,
        browser_context: Optional[BrowserContext] = None,
        #
        page_extraction_llm: Optional[BaseChatModel] = None,
        sensitive_data: Optional[Dict[str, str]] = None,
        available_file_paths: Optional[list[str]] = None,
        #
        context: Context | None = None,
    ) -> ActionResult:
        """Execute the first populated action in ``action``.

        Actions whose name starts with ``"mcp"`` are treated as MCP tools and
        invoked through the tool's ``ainvoke``; every other action goes
        through ``self.registry.execute_action``.

        Args:
            action: Model whose set fields map action names to parameters.
            browser_context: Passed to registry actions as ``browser``.
            page_extraction_llm: Forwarded to registry actions.
            sensitive_data: Forwarded to registry actions.
            available_file_paths: Forwarded to registry actions.
            context: Opaque caller context forwarded to registry actions.

        Returns:
            The action's ``ActionResult``; a string result is wrapped as
            ``extracted_content``; ``None`` (or no populated action) yields an
            empty ``ActionResult``.

        Raises:
            ValueError: If an MCP action is not registered, or if an action
                returns something other than ``str``, ``ActionResult`` or
                ``None``.
        """
        for action_name, params in action.model_dump(exclude_unset=True).items():
            if params is None:
                continue

            if action_name.startswith("mcp"):
                # this is a mcp tool
                logger.debug(f"Invoke MCP tool: {action_name}")
                registered = self.registry.registry.actions.get(action_name)
                if registered is None:
                    # Fail with a clear message instead of an AttributeError
                    # on `None.function`.
                    raise ValueError(f'MCP tool {action_name} is not registered')
                result = await registered.function.ainvoke(params)
            else:
                result = await self.registry.execute_action(
                    action_name,
                    params,
                    browser=browser_context,
                    page_extraction_llm=page_extraction_llm,
                    sensitive_data=sensitive_data,
                    available_file_paths=available_file_paths,
                    context=context,
                )

            # Only the first populated action is executed; its result is
            # normalised to an ActionResult and returned immediately.
            if isinstance(result, str):
                return ActionResult(extracted_content=result)
            if isinstance(result, ActionResult):
                return result
            if result is None:
                return ActionResult()
            raise ValueError(f'Invalid action result type: {type(result)} of {result}')

        return ActionResult()

    async def setup_mcp_client(self, mcp_server_config: Optional[Dict[str, Any]] = None):
        """Start an MCP client for ``mcp_server_config`` and register its tools.

        The configuration is always stored on the instance. When it is empty
        or ``None`` no client is started and any existing client is left
        untouched.
        """
        self.mcp_server_config = mcp_server_config
        if not self.mcp_server_config:
            return

        # Close a client from an earlier call so it is not leaked when the
        # reference is overwritten below.
        await self.close_mcp_client()
        self.mcp_client = await setup_mcp_client_and_tools(self.mcp_server_config)
        self.register_mcp_tools()

    def register_mcp_tools(self):
        """
        Register the MCP tools used by this controller.

        Each tool is stored in the action registry under
        ``mcp.<server_name>.<tool.name>``. Logs a warning and does nothing
        when no MCP client is set.
        """
        if not self.mcp_client:
            logger.warning("MCP client not started.")
            return

        for server_name, tools in self.mcp_client.server_name_to_tools.items():
            for tool in tools:
                tool_name = f"mcp.{server_name}.{tool.name}"
                self.registry.registry.actions[tool_name] = RegisteredAction(
                    name=tool_name,
                    description=tool.description,
                    function=tool,
                    param_model=create_tool_param_model(tool),
                )
                logger.info(f"Add mcp tool: {tool_name}")
            logger.debug(
                f"Registered {len(tools)} mcp tools for {server_name}")

    async def close_mcp_client(self):
        """Exit the MCP client, if one is set, and drop the reference to it."""
        if self.mcp_client:
            # Clear the attribute first so a repeated call (or a failing exit)
            # never tries to exit the same client twice.
            client, self.mcp_client = self.mcp_client, None
            await client.__aexit__(None, None, None)