Spaces:
Running
Running
#!/usr/bin/env python3 | |
""" | |
Base classes and interfaces for GAIA tools. | |
""" | |
from abc import ABC, abstractmethod | |
from dataclasses import dataclass, field | |
from typing import Any, Dict, Optional, Union, List | |
from enum import Enum | |
import time | |
import functools | |
from ..utils.exceptions import ToolError, ToolValidationError, ToolExecutionError, ToolTimeoutError | |
class ToolStatus(Enum):
    """Outcome of a single tool execution.

    Each member's value is the lowercase form of its name.
    """

    # The tool ran without raising.
    SUCCESS = "success"
    # The tool raised an exception that is neither a timeout nor a validation error.
    ERROR = "error"
    # The tool exceeded its time limit.
    TIMEOUT = "timeout"
    # Input validation rejected the arguments before the tool ran.
    VALIDATION_FAILED = "validation_failed"
@dataclass
class ToolResult:
    """Standardized tool result format.

    Attributes:
        status: Outcome of the execution.
        output: Value produced by the tool; ``None`` for failed executions
            built by ``GAIATool.execute``.
        error_message: Human-readable failure description, if any.
        execution_time: Wall-clock duration of the execution in seconds.
        metadata: Extra information about the execution (defaults to an
            empty dict, created fresh for each instance).
    """

    status: ToolStatus
    output: Any
    error_message: Optional[str] = None
    execution_time: Optional[float] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def is_success(self) -> bool:
        """Return True if the tool execution was successful."""
        return self.status == ToolStatus.SUCCESS

    def is_error(self) -> bool:
        """Return True if the execution ended in error, timeout or validation failure."""
        return self.status in (
            ToolStatus.ERROR,
            ToolStatus.TIMEOUT,
            ToolStatus.VALIDATION_FAILED,
        )

    def get_output_or_error(self) -> str:
        """Return the output as a string if successful, otherwise the error message.

        Falls back to ``"Unknown error"`` when a failed result carries no
        message.
        """
        # is_success must be *called*; the bare bound method is always truthy.
        if self.is_success():
            return str(self.output)
        return self.error_message or "Unknown error"
class GAIATool(ABC):
    """Abstract base class for all GAIA tools.

    Subclasses implement ``_execute`` with the tool logic and may override
    ``_validate_input``. Callers invoke ``execute`` (or call the instance),
    which validates the input, runs the tool under a timeout where possible,
    records timing statistics, and reports the outcome as a ``ToolResult``
    instead of raising.
    """

    def __init__(self, name: str, description: str, timeout: int = 60):
        """Initialize the tool.

        Args:
            name: Tool name; used in error messages and result metadata.
            description: Human-readable description; used by ``__str__``.
            timeout: Time limit in whole seconds, enforced with
                ``signal.alarm`` when possible. A falsy value (0 or ``None``)
                disables the limit.
        """
        self.name = name
        self.description = description
        self.timeout = timeout
        # Running totals over successful executions, reported in metadata.
        self._execution_count = 0
        self._total_execution_time = 0.0

    @abstractmethod
    def _execute(self, **kwargs) -> Any:
        """Execute the tool logic. Must be implemented by subclasses."""

    def _validate_input(self, **kwargs) -> None:
        """Validate input parameters before execution.

        Subclasses should override this and raise ``ToolValidationError`` for
        invalid input. The default implementation accepts everything.
        """

    def execute(self, **kwargs) -> ToolResult:
        """Validate, run and time the tool, reporting failures as results.

        Returns:
            A ``ToolResult`` with status SUCCESS and the tool's output, or
            with VALIDATION_FAILED / TIMEOUT / ERROR and an error message when
            the corresponding exception is raised. Only successful runs update
            the execution statistics and carry metadata.
        """
        start_time = time.perf_counter()
        try:
            self._validate_input(**kwargs)
            result = self._execute_with_timeout(**kwargs)
        except ToolValidationError as e:
            return self._failure_result(ToolStatus.VALIDATION_FAILED, str(e), start_time)
        except ToolTimeoutError as e:
            return self._failure_result(ToolStatus.TIMEOUT, str(e), start_time)
        except Exception as e:
            return self._failure_result(
                ToolStatus.ERROR, f"{self.name} execution failed: {str(e)}", start_time
            )

        execution_time = time.perf_counter() - start_time
        self._record_execution(execution_time)
        return ToolResult(
            status=ToolStatus.SUCCESS,
            output=result,
            execution_time=execution_time,
            metadata=self._get_execution_metadata(),
        )

    def _failure_result(self, status: ToolStatus, message: str, start_time: float) -> ToolResult:
        """Build a non-success ``ToolResult`` timed from ``start_time``."""
        return ToolResult(
            status=status,
            output=None,
            error_message=message,
            execution_time=time.perf_counter() - start_time,
        )

    def _execute_with_timeout(self, **kwargs) -> Any:
        """Run ``_execute``, enforcing ``self.timeout`` via SIGALRM when possible.

        Raises:
            ToolTimeoutError: if the alarm fires before ``_execute`` returns.
        """
        import signal
        import threading

        # SIGALRM does not exist on Windows, and signal.signal() raises
        # ValueError outside the main thread. In those cases, and when no
        # timeout is configured, run without a time limit rather than failing
        # every call.
        if (
            not self.timeout
            or not hasattr(signal, "SIGALRM")
            or threading.current_thread() is not threading.main_thread()
        ):
            return self._execute(**kwargs)

        def timeout_handler(signum, frame):
            raise ToolTimeoutError(f"Tool {self.name} timed out after {self.timeout} seconds")

        old_handler = signal.signal(signal.SIGALRM, timeout_handler)
        signal.alarm(self.timeout)
        try:
            return self._execute(**kwargs)
        finally:
            # Cancel the alarm even when _execute raises; otherwise it would
            # fire later under the restored handler (SIG_DFL terminates the
            # process).
            signal.alarm(0)
            # signal.signal() returns None when the previous handler was not
            # installed from Python, and None cannot be passed back to it.
            if old_handler is not None:
                signal.signal(signal.SIGALRM, old_handler)

    def _record_execution(self, execution_time: float) -> None:
        """Record execution statistics for one successful run."""
        self._execution_count += 1
        self._total_execution_time += execution_time

    def _get_execution_metadata(self) -> Dict[str, Any]:
        """Return the tool name, run count and mean duration of recorded runs."""
        return {
            "tool_name": self.name,
            "execution_count": self._execution_count,
            # max(1, ...) avoids division by zero before any run is recorded.
            "average_execution_time": self._total_execution_time / max(1, self._execution_count),
        }

    def __call__(self, **kwargs) -> ToolResult:
        """Make the tool callable; equivalent to ``execute(**kwargs)``."""
        return self.execute(**kwargs)

    def __str__(self) -> str:
        return f"{self.name}: {self.description}"
class AsyncGAIATool(GAIATool):
    """Base class for tools whose logic is a coroutine.

    Subclasses implement ``_execute_async``; the inherited synchronous
    ``execute`` pipeline drives it through ``_execute``.
    """

    @abstractmethod
    async def _execute_async(self, **kwargs) -> Any:
        """Async execute method. Must be implemented by subclasses."""

    def _execute(self, **kwargs) -> Any:
        """Run ``_execute_async`` to completion and return its result.

        ``asyncio.run`` raises RuntimeError when the calling thread already
        has a running event loop (e.g. the tool is invoked from async code).
        In that case the coroutine runs on a fresh loop in a worker thread,
        and the calling thread blocks until it finishes.
        """
        import asyncio
        import concurrent.futures

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No event loop in this thread: the common, direct path.
            return asyncio.run(self._execute_async(**kwargs))

        pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        try:
            return pool.submit(asyncio.run, self._execute_async(**kwargs)).result()
        finally:
            # Don't wait for the worker here, so an exception raised in this
            # thread while waiting (e.g. a timeout) propagates promptly.
            pool.shutdown(wait=False)
def tool_with_retry(max_retries: int = 3, backoff_factor: float = 2.0):
    """Class decorator that retries a tool's ``_execute`` with exponential backoff.

    The decorated class's ``_execute`` is replaced by a wrapper that calls the
    original up to ``max_retries + 1`` times. Before retry number ``k``
    (counting from 0) it sleeps ``backoff_factor ** k`` seconds. The last
    failure is re-raised once the attempts are exhausted.

    ``ToolTimeoutError`` is never retried. The timeout in
    ``GAIATool._execute_with_timeout`` is a one-shot ``signal.alarm``, so a
    retry after it fires would run with no time limit at all.

    Args:
        max_retries: Number of retries after the first attempt. Negative
            values are treated as 0, so the tool always runs at least once.
        backoff_factor: Base of the exponential sleep between attempts.

    Returns:
        A decorator that mutates and returns the given tool class.
    """
    attempts = max(0, max_retries) + 1

    def decorator(tool_class):
        original_execute = tool_class._execute

        @functools.wraps(original_execute)
        def execute_with_retry(self, **kwargs):
            for attempt in range(attempts):
                try:
                    return original_execute(self, **kwargs)
                except ToolTimeoutError:
                    raise
                except Exception:
                    if attempt == attempts - 1:
                        raise
                    time.sleep(backoff_factor ** attempt)

        tool_class._execute = execute_with_retry
        return tool_class

    return decorator
def validate_required_params(*required_params):
    """Decorator for ``_validate_input`` methods that enforces required kwargs.

    Before the wrapped validator runs, every name in ``required_params`` must
    be present in the keyword arguments and must not be ``None``. Otherwise
    ``ToolValidationError`` is raised, and its message includes ``self.name``.

    Args:
        *required_params: Keyword-argument names that must be supplied.

    Returns:
        A decorator that wraps the validation method. The method's name and
        docstring are preserved via ``functools.wraps``.
    """
    def decorator(validate_method):
        @functools.wraps(validate_method)
        def wrapper(self, **kwargs):
            missing_params = [param for param in required_params if param not in kwargs]
            if missing_params:
                raise ToolValidationError(
                    f"Missing required parameters for {self.name}: {missing_params}"
                )
            # Every required name is present at this point, so indexing is safe.
            none_params = [param for param in required_params if kwargs[param] is None]
            if none_params:
                raise ToolValidationError(
                    f"Required parameters cannot be None for {self.name}: {none_params}"
                )
            return validate_method(self, **kwargs)

        return wrapper

    return decorator
class ToolCategory(Enum):
    """Categories used to group tools for registration and discovery.

    Each member's value is the lowercase form of its name.
    """

    MULTIMEDIA = "multimedia"
    RESEARCH = "research"
    FILE_PROCESSING = "file_processing"
    CHESS = "chess"
    MATH = "math"
    UTILITY = "utility"
@dataclass
class ToolMetadata:
    """Metadata for tool registration and discovery.

    Attributes:
        name: Tool name.
        description: Human-readable description of the tool.
        category: Category used to group the tool.
        input_schema: Description of accepted inputs (format not enforced here).
        output_schema: Description of produced outputs (format not enforced here).
        examples: Example invocations; a fresh empty list by default.
        version: Version string of the tool, "1.0.0" by default.
        author: Optional author name.
        dependencies: Names of required dependencies; a fresh empty list by default.
    """

    name: str
    description: str
    category: ToolCategory
    input_schema: Dict[str, Any]
    output_schema: Dict[str, Any]
    examples: List[Dict[str, Any]] = field(default_factory=list)
    version: str = "1.0.0"
    author: Optional[str] = None
    dependencies: List[str] = field(default_factory=list)