File size: 3,697 Bytes
ba68fc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
#!/usr/bin/env python3
"""
Tool registry for managing and discovering GAIA tools.
"""

from typing import Dict, List, Optional, Type, Any
from dataclasses import dataclass, field

from .base import GAIATool, ToolCategory, ToolMetadata
from ..utils.exceptions import ToolNotFoundError


class ToolRegistry:
    """Registry for managing GAIA tools.

    Stores each tool's class and metadata under ``metadata.name`` and caches
    tool instances per (name, constructor kwargs) combination.
    """

    def __init__(self):
        self._tools: Dict[str, Type[GAIATool]] = {}
        self._metadata: Dict[str, ToolMetadata] = {}
        # Cached instances keyed by ``(name, frozenset(init_kwargs.items()))``.
        # The kwargs themselves (not merely their hash) form the key, so two
        # different kwarg sets can never collide onto the same instance.
        self._instances: Dict[tuple, GAIATool] = {}

    def register(self, tool_class: Type[GAIATool], metadata: ToolMetadata) -> None:
        """Register a tool class under ``metadata.name``.

        Re-registering an existing name replaces its class and metadata and
        discards any instances cached for that name, so ``get_tool`` never
        returns an object built from the previously registered class.
        """
        name = metadata.name
        self._tools[name] = tool_class
        self._metadata[name] = metadata
        # Evict instances created from a class that may just have been replaced.
        stale_keys = [key for key in self._instances if key[0] == name]
        for key in stale_keys:
            del self._instances[key]

    def get_tool(self, name: str, **init_kwargs) -> GAIATool:
        """Get a tool instance by name.

        Instances are cached per distinct set of ``init_kwargs``: calling again
        with equal kwargs returns the same object. If any kwarg value is
        unhashable (e.g. a list or dict), it cannot be part of a cache key, so
        a fresh, uncached instance is created on every such call.

        Raises:
            ToolNotFoundError: if ``name`` is not registered.
        """
        if name not in self._tools:
            raise ToolNotFoundError(f"Tool '{name}' not found in registry")

        tool_class = self._tools[name]
        try:
            # frozenset() hashes every (key, value) pair, so this is where an
            # unhashable kwarg value surfaces as TypeError.
            cache_key = (name, frozenset(init_kwargs.items()))
        except TypeError:
            return tool_class(**init_kwargs)

        if cache_key not in self._instances:
            self._instances[cache_key] = tool_class(**init_kwargs)
        return self._instances[cache_key]

    def get_tools_by_category(self, category: ToolCategory) -> List[str]:
        """Get the names of all tools whose metadata category equals ``category``."""
        return [
            name for name, metadata in self._metadata.items()
            if metadata.category == category
        ]

    def get_all_tools(self) -> List[str]:
        """Get all registered tool names, in registration order."""
        return list(self._tools.keys())

    def get_metadata(self, name: str) -> ToolMetadata:
        """Get tool metadata by name.

        Raises:
            ToolNotFoundError: if ``name`` is not registered.
        """
        if name not in self._metadata:
            raise ToolNotFoundError(f"Tool '{name}' not found in registry")
        return self._metadata[name]

    def search_tools(self, query: str) -> List[str]:
        """Return names of tools whose name or description contains ``query``.

        Matching is a case-insensitive substring test.
        """
        query_lower = query.lower()
        return [
            name for name, metadata in self._metadata.items()
            if query_lower in name.lower()
            or query_lower in metadata.description.lower()
        ]

    def validate_dependencies(self, name: str) -> bool:
        """Return True if every dependency listed for ``name`` is registered.

        Raises:
            ToolNotFoundError: if ``name`` itself is not registered.
        """
        metadata = self.get_metadata(name)
        return all(dep in self._tools for dep in metadata.dependencies)

    def get_tool_info(self, name: str) -> Dict[str, Any]:
        """Get a dictionary summarizing the tool's metadata.

        Includes whether its dependencies are currently registered.

        Raises:
            ToolNotFoundError: if ``name`` is not registered.
        """
        metadata = self.get_metadata(name)

        return {
            "name": metadata.name,
            "description": metadata.description,
            "category": metadata.category.value,
            "version": metadata.version,
            "author": metadata.author,
            "input_schema": metadata.input_schema,
            "output_schema": metadata.output_schema,
            "examples": metadata.examples,
            "dependencies": metadata.dependencies,
            "dependencies_satisfied": self.validate_dependencies(name)
        }


# Module-level registry instance; the default target of the register_tool decorator below.
tool_registry: ToolRegistry = ToolRegistry()


def register_tool(metadata: ToolMetadata, *, registry: Optional[ToolRegistry] = None):
    """Decorator to register a tool class.

    Usage::

        @register_tool(metadata)
        class MyTool(GAIATool):
            ...

    Args:
        metadata: Metadata to register the decorated class with
            (``ToolRegistry`` keys it by ``metadata.name``).
        registry: Registry to register into. Defaults to the module-level
            ``tool_registry``. The default is looked up when the decorator
            is applied, as before.

    Returns:
        A class decorator that registers the class and returns it unchanged.
    """

    def decorator(tool_class: Type[GAIATool]) -> Type[GAIATool]:
        target = tool_registry if registry is None else registry
        target.register(tool_class, metadata)
        return tool_class

    return decorator