Spaces:
Running
Running
File size: 8,103 Bytes
ba68fc1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 |
#!/usr/bin/env python3
"""
Base classes and interfaces for GAIA tools.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Union, List
from enum import Enum
import time
import functools
from ..utils.exceptions import ToolError, ToolValidationError, ToolExecutionError, ToolTimeoutError
class ToolStatus(Enum):
    """Outcome of a single tool execution, stored in ``ToolResult.status``."""
    # Validation and execution both completed without raising (set by GAIATool.execute).
    SUCCESS = "success"
    # Any exception other than ToolValidationError / ToolTimeoutError was raised.
    ERROR = "error"
    # Execution raised ToolTimeoutError.
    TIMEOUT = "timeout"
    # Input validation (or the run itself) raised ToolValidationError.
    VALIDATION_FAILED = "validation_failed"
@dataclass
class ToolResult:
    """Uniform envelope describing the outcome of one tool run.

    Fields:
        status: The ``ToolStatus`` the run ended with.
        output: The tool's return value (any type; stringified by
            ``get_output_or_error``).
        error_message: Human-readable failure reason, if any.
        execution_time: Elapsed seconds, if measured.
        metadata: Free-form extra information; a fresh dict per instance.
    """

    status: ToolStatus
    output: Any
    error_message: Optional[str] = None
    execution_time: Optional[float] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    @property
    def is_success(self) -> bool:
        """Whether ``status`` equals ``ToolStatus.SUCCESS``."""
        return self.status == ToolStatus.SUCCESS

    @property
    def is_error(self) -> bool:
        """Whether ``status`` is one of the three failure statuses."""
        failure_statuses = (
            ToolStatus.ERROR,
            ToolStatus.TIMEOUT,
            ToolStatus.VALIDATION_FAILED,
        )
        return self.status in failure_statuses

    def get_output_or_error(self) -> str:
        """Return ``str(output)`` on success, else the error message or a fallback."""
        if not self.is_success:
            return self.error_message or "Unknown error"
        return str(self.output)
class GAIATool(ABC):
    """Abstract base class for all GAIA tools.

    Subclasses implement ``_validate_input`` and ``_execute``. Callers use
    ``execute`` (or call the instance), which validates the keyword arguments,
    runs ``_execute`` under ``self.timeout``, times the call and reports the
    outcome as a ``ToolResult`` instead of raising.
    """

    def __init__(self, name: str, description: str, timeout: int = 60):
        """Initialise the tool.

        Args:
            name: Tool name; used in error messages, metadata and ``__str__``.
            description: Human-readable description; used by ``__str__``.
            timeout: Seconds allowed for one ``_execute`` call. ``0``, a negative
                value or ``None`` disables the limit; fractional values are honoured.
        """
        self.name = name
        self.description = description
        self.timeout = timeout
        # Running totals, updated only for successful executions (see ``execute``).
        self._execution_count = 0
        self._total_execution_time = 0.0

    @abstractmethod
    def _execute(self, **kwargs) -> Any:
        """Execute the tool logic. Must be implemented by subclasses."""
        pass

    @abstractmethod
    def _validate_input(self, **kwargs) -> None:
        """Validate input parameters. Must be implemented by subclasses.

        Raise ``ToolValidationError`` for bad input so that ``execute`` reports
        ``ToolStatus.VALIDATION_FAILED``.
        """
        pass

    def execute(self, **kwargs) -> ToolResult:
        """Validate, run and time the tool; never raises for ``Exception`` subclasses.

        Returns:
            A ``ToolResult`` with status SUCCESS (output plus execution metadata),
            VALIDATION_FAILED (``ToolValidationError``), TIMEOUT
            (``ToolTimeoutError``) or ERROR (any other exception).
        """
        # perf_counter is monotonic, unlike time.time(), so durations stay valid
        # across system clock adjustments.
        start_time = time.perf_counter()
        try:
            self._validate_input(**kwargs)
            result = self._execute_with_timeout(**kwargs)
            # Bookkeeping stays inside the try so execute() keeps its "never raises"
            # contract even if an overridden hook fails.
            execution_time = time.perf_counter() - start_time
            self._record_execution(execution_time)
            return ToolResult(
                status=ToolStatus.SUCCESS,
                output=result,
                execution_time=execution_time,
                metadata=self._get_execution_metadata(),
            )
        except ToolValidationError as e:
            return self._failure_result(ToolStatus.VALIDATION_FAILED, str(e), start_time)
        except ToolTimeoutError as e:
            return self._failure_result(ToolStatus.TIMEOUT, str(e), start_time)
        except Exception as e:
            return self._failure_result(
                ToolStatus.ERROR, f"{self.name} execution failed: {e}", start_time
            )

    def _failure_result(self, status: ToolStatus, message: str, start_time: float) -> ToolResult:
        """Build a non-success ``ToolResult`` whose duration is measured from ``start_time``."""
        return ToolResult(
            status=status,
            output=None,
            error_message=message,
            execution_time=time.perf_counter() - start_time,
        )

    def _execute_with_timeout(self, **kwargs) -> Any:
        """Run ``_execute``, raising ``ToolTimeoutError`` if it exceeds ``self.timeout``.

        A SIGALRM timer is used when possible (POSIX, main thread). ``signal.signal``
        raises ``ValueError`` outside the main thread (e.g. web-server or thread-pool
        workers), and SIGALRM does not exist on Windows. In those cases the call runs
        in a daemon worker thread instead of failing outright.
        """
        import signal
        import threading

        if self.timeout is None or self.timeout <= 0:
            # No limit requested (signal.alarm(0) likewise arms no timer).
            return self._execute(**kwargs)
        if (
            hasattr(signal, "SIGALRM")
            and hasattr(signal, "setitimer")
            and threading.current_thread() is threading.main_thread()
        ):
            return self._execute_with_alarm(**kwargs)
        return self._execute_in_worker_thread(**kwargs)

    def _execute_with_alarm(self, **kwargs) -> Any:
        """Run ``_execute`` under a one-shot SIGALRM interval timer (main thread only)."""
        import signal

        def timeout_handler(signum, frame):
            raise ToolTimeoutError(f"Tool {self.name} timed out after {self.timeout} seconds")

        old_handler = signal.signal(signal.SIGALRM, timeout_handler)
        # setitimer (unlike signal.alarm) accepts fractional seconds.
        signal.setitimer(signal.ITIMER_REAL, self.timeout)
        try:
            return self._execute(**kwargs)
        finally:
            # Disarm on every exit path. A timer still pending after an exception
            # would later fire into the restored handler; the default SIGALRM
            # action terminates the process.
            signal.setitimer(signal.ITIMER_REAL, 0)
            # signal.signal() returns None when the previous handler was not
            # installed from Python; None cannot be passed back to signal.signal().
            if old_handler is not None:
                signal.signal(signal.SIGALRM, old_handler)

    def _execute_in_worker_thread(self, **kwargs) -> Any:
        """Run ``_execute`` in a daemon thread, waiting at most ``self.timeout`` seconds.

        A thread cannot be killed, so on timeout the worker is abandoned and may keep
        running in the background. Because it is a daemon, it does not block
        interpreter exit.
        """
        import threading

        outcome: Dict[str, Any] = {}

        def run() -> None:
            try:
                outcome["result"] = self._execute(**kwargs)
            except BaseException as exc:  # handed back to the calling thread below
                outcome["error"] = exc

        worker = threading.Thread(target=run, name=f"{self.name}-worker", daemon=True)
        worker.start()
        worker.join(self.timeout)
        if worker.is_alive():
            raise ToolTimeoutError(f"Tool {self.name} timed out after {self.timeout} seconds")
        if "error" in outcome:
            raise outcome["error"]
        return outcome["result"]

    def _record_execution(self, execution_time: float) -> None:
        """Add one successful execution and its duration to the running totals."""
        self._execution_count += 1
        self._total_execution_time += execution_time

    def _get_execution_metadata(self) -> Dict[str, Any]:
        """Return the tool name, execution count and mean duration of successful runs."""
        return {
            "tool_name": self.name,
            "execution_count": self._execution_count,
            # max(1, ...) avoids division by zero before any execution is recorded.
            "average_execution_time": self._total_execution_time / max(1, self._execution_count),
        }

    def __call__(self, **kwargs) -> ToolResult:
        """Make the tool callable; equivalent to ``execute(**kwargs)``."""
        return self.execute(**kwargs)

    def __str__(self) -> str:
        return f"{self.name}: {self.description}"
class AsyncGAIATool(GAIATool):
    """Base class for tools whose logic is a coroutine.

    Subclasses implement ``_execute_async``. The inherited ``execute`` pipeline
    (validation, timeout, ``ToolResult`` wrapping) drives it through the
    synchronous ``_execute`` bridge below.
    """

    @abstractmethod
    async def _execute_async(self, **kwargs) -> Any:
        """Async execute method. Must be implemented by subclasses."""
        pass

    def _execute(self, **kwargs) -> Any:
        """Run ``_execute_async`` to completion on a fresh event loop and return its result.

        ``asyncio.run`` raises ``RuntimeError`` when the calling thread already has a
        running event loop (async callers, notebooks). In that case the coroutine is
        run by ``asyncio.run`` in a helper thread, and this call blocks until it
        finishes (blocking the caller's loop meanwhile).
        """
        import asyncio

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop running in this thread: asyncio.run can be used directly.
            return asyncio.run(self._execute_async(**kwargs))

        import concurrent.futures

        pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        try:
            return pool.submit(asyncio.run, self._execute_async(**kwargs)).result()
        finally:
            # Don't wait for the helper when leaving early, e.g. when the SIGALRM
            # timeout handler raises ToolTimeoutError while blocked on result().
            pool.shutdown(wait=False)
def tool_with_retry(max_retries: int = 3, backoff_factor: float = 2.0):
    """Class decorator that retries a tool's ``_execute`` on failure.

    The decorated class's ``_execute`` is replaced by a wrapper that calls the
    original up to ``max_retries + 1`` times. After failed attempt ``k``
    (0-based) it sleeps ``backoff_factor ** k`` seconds before trying again
    (1, 2, 4, ... seconds with the default factor). The last failure is re-raised
    with its original traceback.

    ``ToolTimeoutError`` is never retried. The tool's timeout is armed once
    around the whole (retrying) ``_execute`` call, so a retry after a timeout
    would run with no time limit.

    Args:
        max_retries: Number of retries after the first attempt; must be >= 0.
        backoff_factor: Base of the exponential delay between attempts.

    Raises:
        ValueError: If ``max_retries`` is negative. Such a value would otherwise
            make the wrapper return ``None`` without ever calling ``_execute``.
    """
    if max_retries < 0:
        raise ValueError(f"max_retries must be >= 0, got {max_retries}")

    def decorator(tool_class):
        original_execute = tool_class._execute

        @functools.wraps(original_execute)
        def execute_with_retry(self, **kwargs):
            for attempt in range(max_retries + 1):
                try:
                    return original_execute(self, **kwargs)
                except ToolTimeoutError:
                    raise
                except Exception:
                    if attempt == max_retries:
                        raise
                    time.sleep(backoff_factor ** attempt)

        tool_class._execute = execute_with_retry
        return tool_class

    return decorator
def validate_required_params(*required_params):
    """Build a decorator that enforces required keyword arguments before validation.

    The wrapped ``_validate_input``-style method receives ``self`` plus keyword
    arguments. Before delegating to it, the wrapper raises ``ToolValidationError``
    when any name in ``required_params`` is absent. Only if none are absent does
    it then check for required names whose value is ``None``. Each message names
    the tool (``self.name``) and lists the offending parameters in declaration
    order.
    """
    def decorator(validate_method):
        @functools.wraps(validate_method)
        def wrapper(self, **kwargs):
            # Checks run in order and stop at the first one with offenders, so the
            # second check can index kwargs directly: every name is present by then.
            checks = (
                (lambda p: p not in kwargs,
                 "Missing required parameters for {name}: {bad}"),
                (lambda p: kwargs[p] is None,
                 "Required parameters cannot be None for {name}: {bad}"),
            )
            for is_bad, template in checks:
                offenders = [p for p in required_params if is_bad(p)]
                if offenders:
                    raise ToolValidationError(template.format(name=self.name, bad=offenders))
            return validate_method(self, **kwargs)
        return wrapper
    return decorator
class ToolCategory(Enum):
    """Tool categories for organization; stored in ``ToolMetadata.category``.

    Each value is the lowercase form of the member name.
    """
    MULTIMEDIA = "multimedia"
    RESEARCH = "research"
    FILE_PROCESSING = "file_processing"
    CHESS = "chess"
    MATH = "math"
    # Presumably a catch-all for tools fitting no other category — confirm.
    UTILITY = "utility"
@dataclass
class ToolMetadata:
    """Metadata for tool registration and discovery.

    A plain data container: nothing here validates the schemas, examples or
    dependency names.
    """

    # Tool identifier; presumably matches the tool's GAIATool.name — confirm with the registry.
    name: str
    description: str
    category: ToolCategory
    # Shape of accepted inputs / produced outputs. Format not enforced here
    # (JSON-Schema-like? — confirm with the consumers of this metadata).
    input_schema: Dict[str, Any]
    output_schema: Dict[str, Any]
    # Example invocations; the keys of each dict are not defined in this module.
    examples: List[Dict[str, Any]] = field(default_factory=list)
    version: str = "1.0.0"
    author: Optional[str] = None
    # Presumably names of packages the tool requires — confirm.
    dependencies: List[str] = field(default_factory=list)