import inspect
import logging
import uuid
from datetime import date, datetime, time
from enum import Enum
from typing import Any, Dict, List, Optional, Set, Type, Union, get_type_hints

from browser_use.controller.registry.views import ActionModel
from langchain.tools import BaseTool
from langchain_mcp_adapters.client import MultiServerMCPClient
from pydantic import BaseModel, Field, create_model

logger = logging.getLogger(__name__)


async def setup_mcp_client_and_tools(mcp_server_config: Dict[str, Any]) -> Optional[MultiServerMCPClient]:
    """
    Initializes the MultiServerMCPClient from the given server configuration and
    enters its async context so that its tools become available.

    Returns:
        MultiServerMCPClient | None: The initialized and started client instance,
        or None if no configuration was provided or setup failed.
    """
    logger.info("Initializing MultiServerMCPClient...")

    if not mcp_server_config:
        logger.error("No MCP server configuration provided.")
        return None

    try:
        if "mcpServers" in mcp_server_config:
            mcp_server_config = mcp_server_config["mcpServers"]
        client = MultiServerMCPClient(mcp_server_config)
        await client.__aenter__()
        return client
    except Exception as e:
        logger.error(f"Failed to setup MCP client or fetch tools: {e}", exc_info=True)
        return None
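
# Example usage (illustrative sketch, not part of the original module): the
# "filesystem" server entry below is a hypothetical placeholder; adjust the
# command/args/transport to match your actual MCP servers. How tools are read
# afterwards (e.g. client.get_tools()) depends on the installed
# langchain_mcp_adapters version.
#
#     config = {
#         "mcpServers": {
#             "filesystem": {
#                 "command": "npx",
#                 "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"],
#                 "transport": "stdio",
#             }
#         }
#     }
#     client = await setup_mcp_client_and_tools(config)  # inside an async context
#     if client is not None:
#         tools = client.get_tools()
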
def create_tool_param_model(tool: BaseTool) -> Type[BaseModel]:
    """Creates a Pydantic model from a LangChain tool's schema."""
    # Get tool schema information
    json_schema = tool.args_schema
    tool_name = tool.name

    # If the tool already has a schema defined, convert it to a new param_model
    if json_schema is not None:
        # Create new parameter model
        params = {}

        # Process properties if they exist
        if 'properties' in json_schema:
            # Find required fields
            required_fields: Set[str] = set(json_schema.get('required', []))

            for prop_name, prop_details in json_schema['properties'].items():
                field_type = resolve_type(prop_details, f"{tool_name}_{prop_name}")

                # Check if parameter is required
                is_required = prop_name in required_fields

                # Get default value and description
                default_value = prop_details.get('default', ... if is_required else None)
                description = prop_details.get('description', '')

                # Add field constraints
                field_kwargs = {'default': default_value}
                if description:
                    field_kwargs['description'] = description

                # Add additional constraints if present
                if 'minimum' in prop_details:
                    field_kwargs['ge'] = prop_details['minimum']
                if 'maximum' in prop_details:
                    field_kwargs['le'] = prop_details['maximum']
                if 'minLength' in prop_details:
                    field_kwargs['min_length'] = prop_details['minLength']
                if 'maxLength' in prop_details:
                    field_kwargs['max_length'] = prop_details['maxLength']
                if 'pattern' in prop_details:
                    field_kwargs['pattern'] = prop_details['pattern']

                # Add to parameters dictionary
                params[prop_name] = (field_type, Field(**field_kwargs))

        return create_model(
            f'{tool_name}_parameters',
            __base__=ActionModel,
            **params,  # type: ignore
        )

    # If no schema is defined, extract parameters from the _run method
    run_method = tool._run
    sig = inspect.signature(run_method)

    # Get type hints for better type information
    try:
        type_hints = get_type_hints(run_method)
    except Exception:
        type_hints = {}

    params = {}
    for name, param in sig.parameters.items():
        # Skip 'self' parameter and any other parameters you want to exclude
        if name == 'self':
            continue

        # Get annotation from type hints if available, otherwise from signature
        annotation = type_hints.get(name, param.annotation)
        if annotation == inspect.Parameter.empty:
            annotation = Any

        # Use default value if available, otherwise make it required
        if param.default != param.empty:
            params[name] = (annotation, param.default)
        else:
            params[name] = (annotation, ...)

    return create_model(
        f'{tool_name}_parameters',
        __base__=ActionModel,
        **params,  # type: ignore
    )
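
# Example usage (illustrative sketch): build a parameter model for an MCP tool
# obtained from the client above and validate arguments with it. The tool
# instance, the field names, and the final ainvoke call are assumptions made
# for illustration; the real fields depend on the tool's own schema.
#
#     param_model = create_tool_param_model(some_mcp_tool)
#     args = param_model(query="hello", max_results=5)
#     result = await some_mcp_tool.ainvoke(args.model_dump())
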
def resolve_type(prop_details: Dict[str, Any], prefix: str = "") -> Any:
    """Recursively resolves a JSON schema type to a Python/Pydantic type."""
    # Handle reference types
    if '$ref' in prop_details:
        # In a real application, reference resolution would be needed
        return Any

    # Basic type mapping
    type_mapping = {
        'string': str,
        'integer': int,
        'number': float,
        'boolean': bool,
        'array': List,
        'object': Dict,
        'null': type(None),
    }

    # Handle formatted strings
    if prop_details.get('type') == 'string' and 'format' in prop_details:
        format_mapping = {
            'date-time': datetime,
            'date': date,
            'time': time,
            'email': str,
            'uri': str,
            'url': str,
            'uuid': uuid.UUID,
            'binary': bytes,
        }
        return format_mapping.get(prop_details['format'], str)

    # Handle enum types
    if 'enum' in prop_details:
        enum_values = prop_details['enum']
        # Create dynamic enum class with safe names
        enum_dict = {}
        for i, v in enumerate(enum_values):
            # Ensure enum names are valid Python identifiers
            if isinstance(v, str):
                key = v.upper().replace(' ', '_').replace('-', '_')
                if not key.isidentifier():
                    key = f"VALUE_{i}"
            else:
                key = f"VALUE_{i}"
            enum_dict[key] = v

        # Only create enum if we have values
        if enum_dict:
            return Enum(f"{prefix}_Enum", enum_dict)
        return str  # Fallback

    # Handle array types
    if prop_details.get('type') == 'array' and 'items' in prop_details:
        item_type = resolve_type(prop_details['items'], f"{prefix}_item")
        return List[item_type]  # type: ignore

    # Handle object types with properties
    if prop_details.get('type') == 'object' and 'properties' in prop_details:
        nested_params = {}
        for nested_name, nested_details in prop_details['properties'].items():
            nested_type = resolve_type(nested_details, f"{prefix}_{nested_name}")

            # Get required field info
            required_fields = prop_details.get('required', [])
            is_required = nested_name in required_fields

            default_value = nested_details.get('default', ... if is_required else None)
            description = nested_details.get('description', '')

            field_kwargs = {'default': default_value}
            if description:
                field_kwargs['description'] = description

            nested_params[nested_name] = (nested_type, Field(**field_kwargs))

        # Create nested model
        nested_model = create_model(f"{prefix}_Model", **nested_params)
        return nested_model

    # Handle union types (oneOf, anyOf)
    if 'oneOf' in prop_details or 'anyOf' in prop_details:
        union_schema = prop_details.get('oneOf') or prop_details.get('anyOf')
        union_types = []
        for i, t in enumerate(union_schema):
            union_types.append(resolve_type(t, f"{prefix}_{i}"))

        if union_types:
            return Union.__getitem__(tuple(union_types))  # type: ignore
        return Any

    # Handle allOf (intersection types)
    if 'allOf' in prop_details:
        nested_params = {}
        for i, schema_part in enumerate(prop_details['allOf']):
            if 'properties' in schema_part:
                for nested_name, nested_details in schema_part['properties'].items():
                    nested_type = resolve_type(nested_details, f"{prefix}_allOf_{i}_{nested_name}")

                    # Check if required
                    required_fields = schema_part.get('required', [])
                    is_required = nested_name in required_fields

                    nested_params[nested_name] = (nested_type, ... if is_required else None)

        # Create composite model
        if nested_params:
            composite_model = create_model(f"{prefix}_CompositeModel", **nested_params)
            return composite_model
        return Dict

    # Default to basic types
    schema_type = prop_details.get('type', 'string')
    if isinstance(schema_type, list):
        # Handle multiple types (e.g., ["string", "null"])
        non_null_types = [t for t in schema_type if t != 'null']
        if non_null_types:
            primary_type = type_mapping.get(non_null_types[0], Any)
            if 'null' in schema_type:
                return Optional[primary_type]  # type: ignore
            return primary_type
        return Any

    return type_mapping.get(schema_type, Any)
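
# Example usage (illustrative sketch): resolving a few hand-written JSON schema
# fragments; the property shapes and prefixes are made up for illustration.
#
#     resolve_type({'type': 'string', 'format': 'date-time'}, 'example')    # -> datetime
#     resolve_type({'type': 'array', 'items': {'type': 'integer'}}, 'ids')  # -> List[int]
#     resolve_type({'type': ['string', 'null']}, 'maybe_name')              # -> Optional[str]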