#!/usr/bin/env python3
"""
Base classes and interfaces for GAIA tools.
"""

import asyncio
import concurrent.futures
import functools
import signal
import threading
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Union

from ..utils.exceptions import ToolError, ToolValidationError, ToolExecutionError, ToolTimeoutError


class ToolStatus(Enum):
    """Tool execution status."""

    SUCCESS = "success"
    ERROR = "error"
    TIMEOUT = "timeout"
    VALIDATION_FAILED = "validation_failed"


@dataclass
class ToolResult:
    """Standardized tool result format.

    Attributes:
        status: Outcome of the execution.
        output: Value produced by the tool on success; ``None`` on failure.
        error_message: Failure description, if any.
        execution_time: Elapsed time of the call in seconds.
        metadata: Extra information attached to the result.
    """

    status: ToolStatus
    output: Any
    error_message: Optional[str] = None
    execution_time: Optional[float] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    @property
    def is_success(self) -> bool:
        """Check if tool execution was successful."""
        return self.status == ToolStatus.SUCCESS

    @property
    def is_error(self) -> bool:
        """Check if tool execution failed (error, timeout or validation)."""
        return self.status in (ToolStatus.ERROR, ToolStatus.TIMEOUT, ToolStatus.VALIDATION_FAILED)

    def get_output_or_error(self) -> str:
        """Return ``str(output)`` on success, otherwise the error message."""
        if self.is_success:
            return str(self.output)
        return self.error_message or "Unknown error"


class GAIATool(ABC):
    """Abstract base class for all GAIA tools.

    Subclasses implement ``_validate_input`` and ``_execute``. Callers use
    ``execute`` (or call the instance directly), which runs validation and
    execution under a timeout and reports the outcome as a ``ToolResult``
    instead of raising.
    """

    def __init__(self, name: str, description: str, timeout: Optional[float] = 60):
        """
        Args:
            name: Tool name, used in results and error messages.
            description: Human-readable description of the tool.
            timeout: Maximum seconds one execution may take. ``None`` or a
                non-positive value disables the limit; fractional seconds
                are honoured. The limit is only enforced where SIGALRM
                timers are available (Unix, main thread).
        """
        self.name = name
        self.description = description
        self.timeout = timeout
        self._execution_count = 0
        self._total_execution_time = 0.0

    @abstractmethod
    def _execute(self, **kwargs) -> Any:
        """Execute the tool logic. Must be implemented by subclasses."""
        pass

    @abstractmethod
    def _validate_input(self, **kwargs) -> None:
        """Validate input parameters. Must be implemented by subclasses.

        Raises:
            ToolValidationError: If the inputs are not acceptable.
        """
        pass

    def execute(self, **kwargs) -> ToolResult:
        """Execute the tool with standardized error handling and timing.

        Never raises for tool failures; they are mapped onto the returned
        ``ToolResult``:

        * ``ToolValidationError`` -> ``VALIDATION_FAILED``
        * ``ToolTimeoutError``    -> ``TIMEOUT``
        * any other ``Exception`` -> ``ERROR``

        Only successful runs update the execution statistics and carry
        execution metadata.
        """
        start_time = time.perf_counter()  # monotonic, unlike time.time()

        def failed(status: ToolStatus, message: str) -> ToolResult:
            return ToolResult(
                status=status,
                output=None,
                error_message=message,
                execution_time=time.perf_counter() - start_time,
            )

        try:
            self._validate_input(**kwargs)
            result = self._execute_with_timeout(**kwargs)
        except ToolValidationError as e:
            return failed(ToolStatus.VALIDATION_FAILED, str(e))
        except ToolTimeoutError as e:
            return failed(ToolStatus.TIMEOUT, str(e))
        except Exception as e:
            return failed(ToolStatus.ERROR, f"{self.name} execution failed: {e}")

        execution_time = time.perf_counter() - start_time
        self._record_execution(execution_time)
        return ToolResult(
            status=ToolStatus.SUCCESS,
            output=result,
            execution_time=execution_time,
            metadata=self._get_execution_metadata(),
        )

    def _execute_with_timeout(self, **kwargs) -> Any:
        """Run ``_execute`` bounded by ``self.timeout`` when the platform allows it.

        The limit is enforced with a SIGALRM interval timer
        (``signal.setitimer``). Such a timer only exists on Unix, and signal
        handlers may only be installed from the main thread. On other
        platforms or threads, or when the timeout is disabled, the tool runs
        unbounded instead of failing outright.

        Raises:
            ToolTimeoutError: If the timer expires before ``_execute`` returns.
        """
        timeout = self.timeout
        can_use_timer = (
            timeout is not None
            and timeout > 0
            and hasattr(signal, "SIGALRM")
            and hasattr(signal, "setitimer")
            and threading.current_thread() is threading.main_thread()
        )
        if not can_use_timer:
            return self._execute(**kwargs)

        def timeout_handler(signum, frame):
            raise ToolTimeoutError(f"Tool {self.name} timed out after {self.timeout} seconds")

        old_handler = signal.signal(signal.SIGALRM, timeout_handler)
        # setitimer accepts fractional seconds and reports any timer that was
        # already armed (e.g. by an enclosing tool call), so it can be restored.
        outer_delay, outer_interval = signal.setitimer(signal.ITIMER_REAL, timeout)
        armed_at = time.monotonic()
        try:
            return self._execute(**kwargs)
        finally:
            # Disarm first, on every exit path: a timer left running after the
            # previous handler is restored could hit SIG_DFL and kill the process.
            signal.setitimer(signal.ITIMER_REAL, 0)
            signal.signal(signal.SIGALRM, old_handler)
            if outer_delay > 0:
                # Give the enclosing timer back whatever time it had left.
                remaining = outer_delay - (time.monotonic() - armed_at)
                signal.setitimer(signal.ITIMER_REAL, max(remaining, 1e-6), outer_interval)

    def _record_execution(self, execution_time: float) -> None:
        """Accumulate the run count and total execution time."""
        self._execution_count += 1
        self._total_execution_time += execution_time

    def _get_execution_metadata(self) -> Dict[str, Any]:
        """Return the tool name, run count and average execution time so far."""
        return {
            "tool_name": self.name,
            "execution_count": self._execution_count,
            "average_execution_time": self._total_execution_time / max(1, self._execution_count),
        }

    def __call__(self, **kwargs) -> ToolResult:
        """Make tool callable; equivalent to ``execute(**kwargs)``."""
        return self.execute(**kwargs)

    def __str__(self) -> str:
        return f"{self.name}: {self.description}"


class AsyncGAIATool(GAIATool):
    """Base class for tools whose logic is a coroutine (``_execute_async``)."""

    @abstractmethod
    async def _execute_async(self, **kwargs) -> Any:
        """Async execute method. Must be implemented by subclasses."""
        pass

    def _execute(self, **kwargs) -> Any:
        """Synchronously run ``_execute_async`` to completion.

        ``asyncio.run`` raises if an event loop is already running in this
        thread (e.g. when called from async code or a notebook). In that case
        the coroutine runs on a fresh loop in a worker thread, and this call
        blocks until it finishes.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            return asyncio.run(self._execute_async(**kwargs))

        pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        try:
            return pool.submit(lambda: asyncio.run(self._execute_async(**kwargs))).result()
        finally:
            # Don't wait here: if a timeout interrupted result(), blocking on
            # the worker would defeat it.
            pool.shutdown(wait=False)


def tool_with_retry(max_retries: int = 3, backoff_factor: float = 2.0):
    """Class decorator that retries a tool's ``_execute`` on failure.

    The wait before retry ``k`` (0-based) is ``backoff_factor ** k`` seconds,
    so the first wait is always 1 second. ``ToolTimeoutError`` is never
    retried: the timeout timer is one-shot, so a retry would run unbounded.

    Args:
        max_retries: Additional attempts after the first; must be >= 0.
        backoff_factor: Base of the exponential back-off.

    Raises:
        ValueError: If ``max_retries`` is negative.
    """
    if max_retries < 0:
        raise ValueError(f"max_retries must be >= 0, got {max_retries}")

    def decorator(tool_class):
        original_execute = tool_class._execute

        @functools.wraps(original_execute)
        def execute_with_retry(self, **kwargs):
            for attempt in range(max_retries + 1):
                try:
                    return original_execute(self, **kwargs)
                except ToolTimeoutError:
                    raise
                except Exception:
                    if attempt == max_retries:
                        raise
                    time.sleep(backoff_factor ** attempt)

        tool_class._execute = execute_with_retry
        return tool_class

    return decorator


def validate_required_params(*required_params):
    """Decorator for ``_validate_input`` that checks required keyword arguments.

    Raises ``ToolValidationError`` if any listed parameter is absent, or is
    present but ``None``, before calling the wrapped validator.
    """

    def decorator(validate_method):
        @functools.wraps(validate_method)
        def wrapper(self, **kwargs):
            missing_params = [param for param in required_params if param not in kwargs]
            if missing_params:
                raise ToolValidationError(
                    f"Missing required parameters for {self.name}: {missing_params}"
                )

            none_params = [param for param in required_params if kwargs.get(param) is None]
            if none_params:
                raise ToolValidationError(
                    f"Required parameters cannot be None for {self.name}: {none_params}"
                )

            return validate_method(self, **kwargs)

        return wrapper

    return decorator


class ToolCategory(Enum):
    """Tool categories for organization."""

    MULTIMEDIA = "multimedia"
    RESEARCH = "research"
    FILE_PROCESSING = "file_processing"
    CHESS = "chess"
    MATH = "math"
    UTILITY = "utility"


@dataclass
class ToolMetadata:
    """Metadata for tool registration and discovery."""

    name: str
    description: str
    category: ToolCategory
    input_schema: Dict[str, Any]
    output_schema: Dict[str, Any]
    examples: List[Dict[str, Any]] = field(default_factory=list)
    version: str = "1.0.0"
    author: Optional[str] = None
    dependencies: List[str] = field(default_factory=list)