Spaces:
Running
Running
#!/usr/bin/env python3 | |
""" | |
Tool registry for managing and discovering GAIA tools. | |
""" | |
from typing import Dict, List, Optional, Type, Any | |
from dataclasses import dataclass, field | |
from .base import GAIATool, ToolCategory, ToolMetadata | |
from ..utils.exceptions import ToolNotFoundError | |
class ToolRegistry:
    """Registry for managing GAIA tools.

    Maps tool names to their implementing classes and metadata, and caches
    tool instances created through :meth:`get_tool`.
    """

    def __init__(self):
        # Tool name -> registered tool class.
        self._tools: Dict[str, Type[GAIATool]] = {}
        # Tool name -> metadata supplied at registration time.
        self._metadata: Dict[str, ToolMetadata] = {}
        # (tool name, frozenset of init-kwarg items) -> cached tool instance.
        self._instances: Dict[tuple, GAIATool] = {}

    def register(self, tool_class: Type[GAIATool], metadata: ToolMetadata) -> None:
        """Register ``tool_class`` under ``metadata.name``.

        Registering a name that already exists replaces the previous class
        and metadata. Any instances cached for that name are discarded, so
        later :meth:`get_tool` calls build objects of the new class instead
        of returning stale ones.
        """
        name = metadata.name
        self._tools[name] = tool_class
        self._metadata[name] = metadata
        # Evict cached instances built from a previously registered class.
        stale_keys = [key for key in self._instances if key[0] == name]
        for key in stale_keys:
            del self._instances[key]

    def get_tool(self, name: str, **init_kwargs) -> GAIATool:
        """Return an instance of the tool registered under ``name``.

        Instances are cached per (name, keyword arguments) pair, so repeated
        calls with equal arguments return the same object. If any keyword
        argument value is unhashable (e.g. a list or dict), the call cannot
        be cached and a fresh instance is created each time.

        Raises:
            ToolNotFoundError: if no tool is registered under ``name``.
        """
        if name not in self._tools:
            raise ToolNotFoundError(f"Tool '{name}' not found in registry")
        tool_class = self._tools[name]
        try:
            # Use the (name, kwargs) pair itself as the dict key rather than a
            # string embedding hash(...). Dict lookup compares keys by
            # equality, so distinct kwarg sets that share a hash never collide.
            cache_key = (name, frozenset(init_kwargs.items()))
        except TypeError:
            # An unhashable kwarg value cannot participate in the cache key.
            return tool_class(**init_kwargs)
        if cache_key not in self._instances:
            self._instances[cache_key] = tool_class(**init_kwargs)
        return self._instances[cache_key]

    def get_tools_by_category(self, category: ToolCategory) -> List[str]:
        """Return the names of all tools whose metadata category equals ``category``."""
        return [
            name for name, metadata in self._metadata.items()
            if metadata.category == category
        ]

    def get_all_tools(self) -> List[str]:
        """Return all registered tool names, in registration order."""
        return list(self._tools)

    def get_metadata(self, name: str) -> ToolMetadata:
        """Return the metadata registered for ``name``.

        Raises:
            ToolNotFoundError: if no tool is registered under ``name``.
        """
        try:
            return self._metadata[name]
        except KeyError:
            raise ToolNotFoundError(f"Tool '{name}' not found in registry") from None

    def search_tools(self, query: str) -> List[str]:
        """Return names of tools whose name or description contains ``query``.

        Matching is a case-insensitive substring test.
        """
        needle = query.lower()
        return [
            name for name, metadata in self._metadata.items()
            if needle in name.lower() or needle in metadata.description.lower()
        ]

    def validate_dependencies(self, name: str) -> bool:
        """Return True if every dependency listed for ``name`` is itself registered.

        Raises:
            ToolNotFoundError: if ``name`` is not registered.
        """
        metadata = self.get_metadata(name)
        return all(dep in self._tools for dep in metadata.dependencies)

    def get_tool_info(self, name: str) -> Dict[str, Any]:
        """Return a dict summarising the metadata registered for ``name``.

        Includes ``dependencies_satisfied``, computed by
        :meth:`validate_dependencies`.

        Raises:
            ToolNotFoundError: if ``name`` is not registered.
        """
        metadata = self.get_metadata(name)
        return {
            "name": metadata.name,
            "description": metadata.description,
            "category": metadata.category.value,
            "version": metadata.version,
            "author": metadata.author,
            "input_schema": metadata.input_schema,
            "output_schema": metadata.output_schema,
            "examples": metadata.examples,
            "dependencies": metadata.dependencies,
            "dependencies_satisfied": self.validate_dependencies(name),
        }
# Module-level shared registry. The register_tool decorator records every
# decorated tool class in this instance.
tool_registry: ToolRegistry = ToolRegistry()
def register_tool(metadata: ToolMetadata):
    """Create a class decorator that registers a tool in ``tool_registry``.

    Usage::

        @register_tool(metadata)
        class MyTool(GAIATool):
            ...

    Args:
        metadata: Metadata stored alongside the decorated class; the class is
            registered under ``metadata.name``.

    Returns:
        A decorator that registers the class it receives and returns that
        same class unchanged.
    """
    def _register(tool_class: Type[GAIATool]) -> Type[GAIATool]:
        tool_registry.register(tool_class, metadata)
        return tool_class

    return _register