Spaces:
Running
Running
File size: 7,920 Bytes
b1f90a5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 |
import pdb
import pyperclip
from typing import Optional, Type, Callable, Dict, Any, Union, Awaitable, TypeVar
from pydantic import BaseModel
from browser_use.agent.views import ActionResult
from browser_use.browser.context import BrowserContext
from browser_use.controller.service import Controller, DoneAction
from browser_use.controller.registry.service import Registry, RegisteredAction
from main_content_extractor import MainContentExtractor
from browser_use.controller.views import (
ClickElementAction,
DoneAction,
ExtractPageContentAction,
GoToUrlAction,
InputTextAction,
OpenTabAction,
ScrollAction,
SearchGoogleAction,
SendKeysAction,
SwitchTabAction,
)
import logging
import inspect
import asyncio
import os
from langchain_core.language_models.chat_models import BaseChatModel
from browser_use.agent.views import ActionModel, ActionResult
from src.utils.mcp_client import create_tool_param_model, setup_mcp_client_and_tools
from browser_use.utils import time_execution_sync
# Module-wide logger, named after this module so log output can be filtered per file.
logger: logging.Logger = logging.getLogger(name=__name__)
# Generic type variable for the optional user context forwarded to registry actions.
Context = TypeVar("Context")
class CustomController(Controller):
    """browser_use Controller extended with custom actions and MCP tool support.

    Adds two registry actions (``ask_for_assistant`` and ``upload_file``).
    It can also register tools from an MCP client as actions named
    ``mcp.<server>.<tool>``, which ``act`` dispatches directly to the tool.
    """

    def __init__(self, exclude_actions: Optional[list[str]] = None,
                 output_model: Optional[Type[BaseModel]] = None,
                 ask_assistant_callback: Optional[Union[Callable[[str, BrowserContext], Dict[str, Any]], Callable[
                     [str, BrowserContext], Awaitable[Dict[str, Any]]]]] = None,
                 ):
        """Create the controller.

        Args:
            exclude_actions: Action names to exclude from the base registry.
                Defaults to a fresh empty list per instance; the previous
                shared mutable ``[]`` default was a latent bug.
            output_model: Optional pydantic model forwarded to the base Controller.
            ask_assistant_callback: Sync or async callable ``(query, browser)``
                returning a dict with a ``'response'`` key. It is used by the
                ``ask_for_assistant`` action.
        """
        super().__init__(
            exclude_actions=list(exclude_actions) if exclude_actions is not None else [],
            output_model=output_model,
        )
        # Set state before registering actions so the closures never observe a
        # half-initialised instance.
        self.ask_assistant_callback = ask_assistant_callback
        self.mcp_client = None
        self.mcp_server_config = None
        self._register_custom_actions()

    def _register_custom_actions(self):
        """Register all custom browser actions."""

        @self.registry.action(
            "When executing tasks, prioritize autonomous completion. However, if you encounter a definitive blocker "
            "that prevents you from proceeding independently – such as needing credentials you don't possess, "
            "requiring subjective human judgment, needing a physical action performed, encountering complex CAPTCHAs, "
            "or facing limitations in your capabilities – you must request human assistance."
        )
        async def ask_for_assistant(query: str, browser: BrowserContext):
            """Forward ``query`` to the human-assistant callback and record the reply."""
            if not self.ask_assistant_callback:
                return ActionResult(extracted_content="Human cannot help you. Please try another way.",
                                    include_in_memory=True)
            # Calling the callback first and then awaiting its result supports
            # both ``async def`` callbacks and plain callables that return an
            # awaitable. ``iscoroutinefunction`` alone misses the latter.
            user_response = self.ask_assistant_callback(query, browser)
            if inspect.isawaitable(user_response):
                user_response = await user_response
            msg = f"AI ask: {query}. User response: {user_response['response']}"
            logger.info(msg)
            return ActionResult(extracted_content=msg, include_in_memory=True)

        @self.registry.action(
            'Upload file to interactive element with file path ',
        )
        async def upload_file(index: int, path: str, browser: BrowserContext, available_file_paths: list[str]):
            """Set ``path`` as the input file of the upload element at ``index``.

            Only paths listed in ``available_file_paths`` that exist on disk are accepted.
            """
            # ``act`` may forward ``available_file_paths=None``. In that case no
            # path is allowed, and we avoid the TypeError that ``in None`` raises.
            if not available_file_paths or path not in available_file_paths:
                return ActionResult(error=f'File path {path} is not available')
            if not os.path.exists(path):
                return ActionResult(error=f'File {path} does not exist')

            dom_el = await browser.get_dom_element_by_index(index)
            file_upload_dom_el = dom_el.get_file_upload_element()
            file_upload_el = None
            if file_upload_dom_el is not None:
                file_upload_el = await browser.get_locate_element(file_upload_dom_el)
            if file_upload_el is None:
                msg = f'No file upload element found at index {index}'
                logger.info(msg)
                return ActionResult(error=msg)

            try:
                await file_upload_el.set_input_files(path)
            except Exception as e:
                msg = f'Failed to upload file to index {index}: {str(e)}'
                logger.info(msg)
                return ActionResult(error=msg)
            msg = f'Successfully uploaded file to index {index}'
            logger.info(msg)
            return ActionResult(extracted_content=msg, include_in_memory=True)

    # NOTE(review): time_execution_sync wraps a synchronous call. On an
    # ``async def`` it likely measures only coroutine creation, not execution.
    # Confirm whether an async-aware timing decorator should be used here.
    @time_execution_sync('--act')
    async def act(
            self,
            action: ActionModel,
            browser_context: Optional[BrowserContext] = None,
            # forwarded to registry actions that request them
            page_extraction_llm: Optional[BaseChatModel] = None,
            sensitive_data: Optional[Dict[str, str]] = None,
            available_file_paths: Optional[list[str]] = None,
            # optional user context forwarded to registry actions
            context: Context | None = None,
    ) -> ActionResult:
        """Execute the first set action in ``action``.

        MCP tools (registered as ``mcp.<server>.<tool>``) are invoked directly
        via ``ainvoke``. Every other action goes through
        ``registry.execute_action``.

        Returns:
            ``ActionResult`` built from the action's result: a ``str`` becomes
            ``extracted_content``, ``None`` becomes an empty result, and an
            ``ActionResult`` is passed through.

        Raises:
            ValueError: if an MCP tool name is not registered, or the action
                returns an unsupported type.
        """
        for action_name, params in action.model_dump(exclude_unset=True).items():
            if params is None:
                continue
            # Only names produced by register_mcp_tools count as MCP tools.
            # A bare "mcp" prefix would also capture unrelated actions such as "mcp_foo".
            if action_name.startswith("mcp."):
                logger.debug(f"Invoke MCP tool: {action_name}")
                registered = self.registry.registry.actions.get(action_name)
                if registered is None:
                    raise ValueError(f'MCP tool {action_name} is not registered')
                result = await registered.function.ainvoke(params)
            else:
                result = await self.registry.execute_action(
                    action_name,
                    params,
                    browser=browser_context,
                    page_extraction_llm=page_extraction_llm,
                    sensitive_data=sensitive_data,
                    available_file_paths=available_file_paths,
                    context=context,
                )

            if isinstance(result, str):
                return ActionResult(extracted_content=result)
            if isinstance(result, ActionResult):
                return result
            if result is None:
                return ActionResult()
            raise ValueError(f'Invalid action result type: {type(result)} of {result}')
        return ActionResult()

    async def setup_mcp_client(self, mcp_server_config: Optional[Dict[str, Any]] = None):
        """Store ``mcp_server_config`` and, if it is truthy, start the MCP client and register its tools."""
        self.mcp_server_config = mcp_server_config
        if self.mcp_server_config:
            self.mcp_client = await setup_mcp_client_and_tools(self.mcp_server_config)
            self.register_mcp_tools()

    def register_mcp_tools(self):
        """Register the MCP tools used by this controller.

        Each tool is added to the registry as ``mcp.<server_name>.<tool.name>``.
        It uses the tool itself as the function and a param model built by
        ``create_tool_param_model``.
        """
        if not self.mcp_client:
            logger.warning("MCP client not started.")
            return
        for server_name, tools in self.mcp_client.server_name_to_tools.items():
            for tool in tools:
                tool_name = f"mcp.{server_name}.{tool.name}"
                self.registry.registry.actions[tool_name] = RegisteredAction(
                    name=tool_name,
                    description=tool.description,
                    function=tool,
                    param_model=create_tool_param_model(tool),
                )
                logger.info(f"Add mcp tool: {tool_name}")
            logger.debug(f"Registered {len(tools)} mcp tools for {server_name}")

    async def close_mcp_client(self):
        """Exit the MCP client's async context, if one is open.

        The reference is cleared even if ``__aexit__`` raises, so repeated
        calls cannot exit the same client twice.
        """
        if self.mcp_client:
            try:
                await self.mcp_client.__aexit__(None, None, None)
            finally:
                self.mcp_client = None
|