tonthatthienvu's picture
feat: major refactoring - transform monolithic architecture into modular system
ba68fc1
#!/usr/bin/env python3
"""
Base classes and interfaces for GAIA tools.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Union, List
from enum import Enum
import time
import functools
from ..utils.exceptions import ToolError, ToolValidationError, ToolExecutionError, ToolTimeoutError
class ToolStatus(Enum):
    """Outcome of one tool invocation, stored on ``ToolResult.status``."""

    # Validation and execution both completed without raising.
    SUCCESS = "success"
    # Execution raised an exception other than a validation or timeout error.
    ERROR = "error"
    # Execution raised ToolTimeoutError.
    TIMEOUT = "timeout"
    # Input validation raised ToolValidationError.
    VALIDATION_FAILED = "validation_failed"
@dataclass
class ToolResult:
    """Uniform envelope describing the outcome of a tool invocation.

    Attributes:
        status: Final ``ToolStatus`` of the invocation.
        output: Value produced by the tool (``None`` on failure paths).
        error_message: Failure description, if any.
        execution_time: Elapsed seconds, if measured.
        metadata: Free-form extra information; a fresh dict per instance.
    """

    status: ToolStatus
    output: Any
    error_message: Optional[str] = None
    execution_time: Optional[float] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    @property
    def is_success(self) -> bool:
        """True only when ``status`` is ``ToolStatus.SUCCESS``."""
        return self.status == ToolStatus.SUCCESS

    @property
    def is_error(self) -> bool:
        """True when ``status`` is any of the failure states."""
        failure_states = (
            ToolStatus.ERROR,
            ToolStatus.TIMEOUT,
            ToolStatus.VALIDATION_FAILED,
        )
        return self.status in failure_states

    def get_output_or_error(self) -> str:
        """Return the stringified output on success, else the error text.

        Falls back to ``"Unknown error"`` when a failed result carries no
        (or an empty) error message.
        """
        if not self.is_success:
            return self.error_message or "Unknown error"
        return str(self.output)
class GAIATool(ABC):
    """Abstract base class for all GAIA tools.

    Subclasses implement ``_validate_input`` and ``_execute``. Callers use
    ``execute`` (or call the instance directly). It runs validation, runs
    ``_execute`` under a timeout where the platform allows it, and times
    the call. Every outcome is converted into a ``ToolResult`` instead of
    being raised.
    """

    def __init__(self, name: str, description: str, timeout: int = 60):
        """Initialise the tool.

        Args:
            name: Tool name; used in error messages and result metadata.
            description: Human-readable description (used by ``__str__``).
            timeout: Seconds ``_execute`` may run before ``ToolTimeoutError``
                is raised. Values ``<= 0`` disable the timeout. Fractional
                seconds are accepted.
        """
        self.name = name
        self.description = description
        self.timeout = timeout
        # Running statistics; only successful executions are recorded.
        self._execution_count = 0
        self._total_execution_time = 0.0

    @abstractmethod
    def _execute(self, **kwargs) -> Any:
        """Execute the tool logic. Must be implemented by subclasses."""
        pass

    @abstractmethod
    def _validate_input(self, **kwargs) -> None:
        """Validate input parameters. Must be implemented by subclasses.

        Raising ``ToolValidationError`` makes ``execute`` report
        ``ToolStatus.VALIDATION_FAILED``.
        """
        pass

    def execute(self, **kwargs) -> ToolResult:
        """Validate, run and time the tool without raising for tool failures.

        Returns:
            A ``ToolResult`` whose status is one of the following.
            ``SUCCESS``: output, timing and execution metadata are filled in.
            ``VALIDATION_FAILED``: a ``ToolValidationError`` was raised.
            ``TIMEOUT``: a ``ToolTimeoutError`` was raised.
            ``ERROR``: any other exception was raised.
        """
        # perf_counter is monotonic: clock adjustments cannot skew durations
        # (or make them negative) the way time.time() can.
        start_time = time.perf_counter()
        try:
            self._validate_input(**kwargs)
            result = self._execute_with_timeout(**kwargs)
            execution_time = time.perf_counter() - start_time
            self._record_execution(execution_time)
            return ToolResult(
                status=ToolStatus.SUCCESS,
                output=result,
                execution_time=execution_time,
                metadata=self._get_execution_metadata(),
            )
        except ToolValidationError as e:
            return self._build_failure_result(
                ToolStatus.VALIDATION_FAILED, str(e), start_time
            )
        except ToolTimeoutError as e:
            return self._build_failure_result(ToolStatus.TIMEOUT, str(e), start_time)
        except Exception as e:
            return self._build_failure_result(
                ToolStatus.ERROR,
                f"{self.name} execution failed: {str(e)}",
                start_time,
            )

    @staticmethod
    def _build_failure_result(
        status: ToolStatus, message: str, start_time: float
    ) -> ToolResult:
        """Build a failed ``ToolResult`` timed from a ``perf_counter`` start."""
        return ToolResult(
            status=status,
            output=None,
            error_message=message,
            execution_time=time.perf_counter() - start_time,
        )

    def _execute_with_timeout(self, **kwargs) -> Any:
        """Run ``_execute``, enforcing ``self.timeout`` via SIGALRM when possible.

        The timer is armed only when all of the following hold:
        - ``self.timeout`` is positive.
        - The platform provides ``SIGALRM``/``setitimer`` (Windows does not).
        - We are on the main thread; ``signal.signal`` raises ``ValueError``
          elsewhere.
        In any other case ``_execute`` runs without a timeout rather than
        failing outright.

        Raises:
            ToolTimeoutError: ``_execute`` exceeded ``self.timeout`` seconds.
        """
        import signal
        import threading

        can_use_alarm = (
            self.timeout > 0
            and hasattr(signal, "SIGALRM")
            and hasattr(signal, "setitimer")
            and threading.current_thread() is threading.main_thread()
        )
        if not can_use_alarm:
            return self._execute(**kwargs)

        def timeout_handler(signum, frame):
            raise ToolTimeoutError(f"Tool {self.name} timed out after {self.timeout} seconds")

        old_handler = signal.signal(signal.SIGALRM, timeout_handler)
        try:
            # setitimer (unlike alarm) accepts fractional seconds.
            signal.setitimer(signal.ITIMER_REAL, self.timeout)
            return self._execute(**kwargs)
        finally:
            # Disarm on EVERY exit path. If _execute raised before the deadline,
            # a still-pending timer would otherwise fire later into the restored
            # handler (the default SIGALRM action terminates the process).
            signal.setitimer(signal.ITIMER_REAL, 0)
            # signal.signal returns None when the previous handler was not
            # installed from Python; None cannot be re-installed, so fall back
            # to the default disposition.
            signal.signal(
                signal.SIGALRM,
                old_handler if old_handler is not None else signal.SIG_DFL,
            )

    def _record_execution(self, execution_time: float) -> None:
        """Add one successful run of ``execution_time`` seconds to the stats."""
        self._execution_count += 1
        self._total_execution_time += execution_time

    def _get_execution_metadata(self) -> Dict[str, Any]:
        """Return the tool name, run count and mean duration of recorded runs."""
        return {
            "tool_name": self.name,
            "execution_count": self._execution_count,
            # max(1, ...) avoids division by zero before any run is recorded.
            "average_execution_time": self._total_execution_time / max(1, self._execution_count),
        }

    def __call__(self, **kwargs) -> ToolResult:
        """Make the tool callable; equivalent to ``execute(**kwargs)``."""
        return self.execute(**kwargs)

    def __str__(self) -> str:
        return f"{self.name}: {self.description}"
class AsyncGAIATool(GAIATool):
    """GAIA tool whose core logic is written as a coroutine.

    Subclasses implement ``_execute_async`` (plus ``_validate_input``). This
    class supplies the synchronous ``_execute`` hook that ``GAIATool``
    requires, and drives the coroutine to completion with ``asyncio.run``.
    """

    @abstractmethod
    async def _execute_async(self, **kwargs) -> Any:
        """Coroutine holding the tool logic; subclasses must override it."""

    def _execute(self, **kwargs) -> Any:
        """Run ``_execute_async(**kwargs)`` on a new event loop; return its result.

        ``asyncio.run`` cannot be used while another event loop is running in
        the same thread, so this bridge is only usable from synchronous code.
        """
        import asyncio as _asyncio

        coroutine = self._execute_async(**kwargs)
        return _asyncio.run(coroutine)
def tool_with_retry(max_retries: int = 3, backoff_factor: float = 2.0):
    """Class decorator adding retry with exponential backoff to ``_execute``.

    The decorated class's ``_execute`` is replaced in place by a wrapper that
    makes up to ``max_retries + 1`` attempts. After failed attempt ``k``
    (0-based) it sleeps ``backoff_factor ** k`` seconds. The first wait is
    therefore always 1 second; with the defaults the waits are 1 s, 2 s and
    4 s. When attempts run out, the last exception is re-raised.

    ``ToolTimeoutError`` is never retried. The timeout armed around
    ``_execute`` fires once, so a retry would run with no timeout at all and
    silently defeat it.

    Args:
        max_retries: Retries after the first attempt. Negative values are
            treated as 0, so ``_execute`` always runs at least once.
        backoff_factor: Base of the exponential backoff.

    Returns:
        A decorator that patches the class it receives and returns that class.
    """
    # Previously a negative value produced zero attempts and a silent None.
    total_attempts = max(0, max_retries) + 1

    def decorator(tool_class):
        original_execute = tool_class._execute

        @functools.wraps(original_execute)
        def execute_with_retry(self, **kwargs):
            for attempt in range(total_attempts):
                try:
                    return original_execute(self, **kwargs)
                except ToolTimeoutError:
                    # The overall time budget is spent; retrying would run unbounded.
                    raise
                except Exception:
                    if attempt == total_attempts - 1:
                        raise
                    time.sleep(backoff_factor ** attempt)

        tool_class._execute = execute_with_retry
        return tool_class

    return decorator
def validate_required_params(*required_params):
    """Decorate a ``_validate_input`` method with required-parameter checks.

    Before delegating to the wrapped method, the wrapper raises
    ``ToolValidationError`` in two cases:
    - A name in ``required_params`` is absent from the keyword arguments.
    - Otherwise, one of them was passed explicitly as ``None``.
    The absence check runs first, so an omitted parameter is reported as
    missing, not as ``None``.

    Args:
        *required_params: Keyword-argument names that must be supplied.
    """
    def decorator(validate_method):
        @functools.wraps(validate_method)
        def wrapper(self, **kwargs):
            absent = [name for name in required_params if name not in kwargs]
            if absent:
                raise ToolValidationError(
                    f"Missing required parameters for {self.name}: {absent}"
                )

            # Every required name is present at this point, so direct indexing is safe.
            nulls = [name for name in required_params if kwargs[name] is None]
            if nulls:
                raise ToolValidationError(
                    f"Required parameters cannot be None for {self.name}: {nulls}"
                )

            return validate_method(self, **kwargs)

        return wrapper

    return decorator
class ToolCategory(Enum):
    """Closed set of categories used to group tools (see ``ToolMetadata``)."""

    # Member order is preserved; iteration order follows this listing.
    MULTIMEDIA = "multimedia"
    RESEARCH = "research"
    FILE_PROCESSING = "file_processing"
    CHESS = "chess"
    MATH = "math"
    UTILITY = "utility"
@dataclass
class ToolMetadata:
    """Descriptive record used to register and discover a tool.

    ``name``, ``description``, ``category`` and both schemas are required.
    The remaining fields have defaults, and each instance gets its own fresh
    list for the list-valued ones.
    """

    # Identity and classification.
    name: str
    description: str
    category: ToolCategory
    # Input/output schema dicts; their structure is not validated here.
    input_schema: Dict[str, Any]
    output_schema: Dict[str, Any]
    # Optional descriptive extras.
    examples: List[Dict[str, Any]] = field(default_factory=list)
    version: str = "1.0.0"
    author: Optional[str] = None
    dependencies: List[str] = field(default_factory=list)