import inspect
import logging
import uuid
from datetime import date, datetime, time
from enum import Enum
from typing import Any, Dict, List, Optional, Set, Type, Union, get_type_hints

from browser_use.controller.registry.views import ActionModel
from langchain.tools import BaseTool
from langchain_mcp_adapters.client import MultiServerMCPClient
from pydantic import BaseModel, Field, create_model
from pydantic.v1 import BaseModel, Field

logger = logging.getLogger(__name__)


async def setup_mcp_client_and_tools(mcp_server_config: Dict[str, Any]) -> Optional[MultiServerMCPClient]:
    """
    Initializes the MultiServerMCPClient, connects to servers, fetches tools,
    filters them, and returns a flat list of usable tools and the client instance.

    Returns:
        A tuple containing:
        - list[BaseTool]: The filtered list of usable LangChain tools.
        - MultiServerMCPClient | None: The initialized and started client instance, or None on failure.
    """

    logger.info("Initializing MultiServerMCPClient...")

    if not mcp_server_config:
        logger.error("No MCP server configuration provided.")
        return None

    try:
        if "mcpServers" in mcp_server_config:
            mcp_server_config = mcp_server_config["mcpServers"]
        client = MultiServerMCPClient(mcp_server_config)
        await client.__aenter__()
        return client

    except Exception as e:
        logger.error(f"Failed to setup MCP client or fetch tools: {e}", exc_info=True)
        return None


def create_tool_param_model(tool: BaseTool) -> Type[BaseModel]:
    """Creates a Pydantic model from a LangChain tool's schema"""

    # Get tool schema information
    json_schema = tool.args_schema
    tool_name = tool.name

    # If the tool already has a schema defined, convert it to a new param_model
    if json_schema is not None:

        # Create new parameter model
        params = {}

        # Process properties if they exist
        if 'properties' in json_schema:
            # Find required fields
            required_fields: Set[str] = set(json_schema.get('required', []))

            for prop_name, prop_details in json_schema['properties'].items():
                field_type = resolve_type(prop_details, f"{tool_name}_{prop_name}")

                # Check if parameter is required
                is_required = prop_name in required_fields

                # Get default value and description
                default_value = prop_details.get('default', ... if is_required else None)
                description = prop_details.get('description', '')

                # Add field constraints
                field_kwargs = {'default': default_value}
                if description:
                    field_kwargs['description'] = description

                # Add additional constraints if present
                if 'minimum' in prop_details:
                    field_kwargs['ge'] = prop_details['minimum']
                if 'maximum' in prop_details:
                    field_kwargs['le'] = prop_details['maximum']
                if 'minLength' in prop_details:
                    field_kwargs['min_length'] = prop_details['minLength']
                if 'maxLength' in prop_details:
                    field_kwargs['max_length'] = prop_details['maxLength']
                if 'pattern' in prop_details:
                    field_kwargs['pattern'] = prop_details['pattern']

                # Add to parameters dictionary
                params[prop_name] = (field_type, Field(**field_kwargs))

        return create_model(
            f'{tool_name}_parameters',
            __base__=ActionModel,
            **params,  # type: ignore
        )

    # If no schema is defined, extract parameters from the _run method
    run_method = tool._run
    sig = inspect.signature(run_method)

    # Get type hints for better type information
    try:
        type_hints = get_type_hints(run_method)
    except Exception:
        type_hints = {}

    params = {}
    for name, param in sig.parameters.items():
        # Skip 'self' parameter and any other parameters you want to exclude
        if name == 'self':
            continue

        # Get annotation from type hints if available, otherwise from signature
        annotation = type_hints.get(name, param.annotation)
        if annotation == inspect.Parameter.empty:
            annotation = Any

        # Use default value if available, otherwise make it required
        if param.default != param.empty:
            params[name] = (annotation, param.default)
        else:
            params[name] = (annotation, ...)

    return create_model(
        f'{tool_name}_parameters',
        __base__=ActionModel,
        **params,  # type: ignore
    )


def resolve_type(prop_details: Dict[str, Any], prefix: str = "") -> Any:
    """Recursively resolves JSON schema type to Python/Pydantic type"""

    # Handle reference types
    if '$ref' in prop_details:
        # In a real application, reference resolution would be needed
        return Any

    # Basic type mapping
    type_mapping = {
        'string': str,
        'integer': int,
        'number': float,
        'boolean': bool,
        'array': List,
        'object': Dict,
        'null': type(None),
    }

    # Handle formatted strings
    if prop_details.get('type') == 'string' and 'format' in prop_details:
        format_mapping = {
            'date-time': datetime,
            'date': date,
            'time': time,
            'email': str,
            'uri': str,
            'url': str,
            'uuid': uuid.UUID,
            'binary': bytes,
        }
        return format_mapping.get(prop_details['format'], str)

    # Handle enum types
    if 'enum' in prop_details:
        enum_values = prop_details['enum']
        # Create dynamic enum class with safe names
        enum_dict = {}
        for i, v in enumerate(enum_values):
            # Ensure enum names are valid Python identifiers
            if isinstance(v, str):
                key = v.upper().replace(' ', '_').replace('-', '_')
                if not key.isidentifier():
                    key = f"VALUE_{i}"
            else:
                key = f"VALUE_{i}"
            enum_dict[key] = v

        # Only create enum if we have values
        if enum_dict:
            return Enum(f"{prefix}_Enum", enum_dict)
        return str  # Fallback

    # Handle array types
    if prop_details.get('type') == 'array' and 'items' in prop_details:
        item_type = resolve_type(prop_details['items'], f"{prefix}_item")
        return List[item_type]  # type: ignore

    # Handle object types with properties
    if prop_details.get('type') == 'object' and 'properties' in prop_details:
        nested_params = {}
        for nested_name, nested_details in prop_details['properties'].items():
            nested_type = resolve_type(nested_details, f"{prefix}_{nested_name}")
            # Get required field info
            required_fields = prop_details.get('required', [])
            is_required = nested_name in required_fields
            default_value = nested_details.get('default', ... if is_required else None)
            description = nested_details.get('description', '')

            field_kwargs = {'default': default_value}
            if description:
                field_kwargs['description'] = description

            nested_params[nested_name] = (nested_type, Field(**field_kwargs))

        # Create nested model
        nested_model = create_model(f"{prefix}_Model", **nested_params)
        return nested_model

    # Handle union types (oneOf, anyOf)
    if 'oneOf' in prop_details or 'anyOf' in prop_details:
        union_schema = prop_details.get('oneOf') or prop_details.get('anyOf')
        union_types = []
        for i, t in enumerate(union_schema):
            union_types.append(resolve_type(t, f"{prefix}_{i}"))

        if union_types:
            return Union.__getitem__(tuple(union_types))  # type: ignore
        return Any

    # Handle allOf (intersection types)
    if 'allOf' in prop_details:
        nested_params = {}
        for i, schema_part in enumerate(prop_details['allOf']):
            if 'properties' in schema_part:
                for nested_name, nested_details in schema_part['properties'].items():
                    nested_type = resolve_type(nested_details, f"{prefix}_allOf_{i}_{nested_name}")
                    # Check if required
                    required_fields = schema_part.get('required', [])
                    is_required = nested_name in required_fields
                    nested_params[nested_name] = (nested_type, ... if is_required else None)

        # Create composite model
        if nested_params:
            composite_model = create_model(f"{prefix}_CompositeModel", **nested_params)
            return composite_model
        return Dict

    # Default to basic types
    schema_type = prop_details.get('type', 'string')
    if isinstance(schema_type, list):
        # Handle multiple types (e.g., ["string", "null"])
        non_null_types = [t for t in schema_type if t != 'null']
        if non_null_types:
            primary_type = type_mapping.get(non_null_types[0], Any)
            if 'null' in schema_type:
                return Optional[primary_type]  # type: ignore
            return primary_type
        return Any

    return type_mapping.get(schema_type, Any)