File size: 2,903 Bytes
8047d47
e1a8ab5
 
 
b87d3a9
 
 
 
 
 
 
 
 
 
 
 
e1a8ab5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
878fd2c
e1a8ab5
 
 
 
878fd2c
e1a8ab5
 
 
 
 
 
 
8047d47
e1a8ab5
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from typing import Callable, Any, Dict, get_type_hints, Optional, _GenericAlias
from dataclasses import dataclass
import inspect


# Dataclasses are regular Python classes for which common methods such as initialization (__init__), comparison (__eq__), and printing (__repr__) are generated automatically.
# Syntax: @dataclasses.dataclass(*, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False)

# Parameters:

# init: If true  __init__() method will be generated
# repr: If true  __repr__() method will be generated
# eq: If true  __eq__() method will be generated
# order: If true  __lt__(), __le__(), __gt__(), and __ge__() methods will be generated.
# unsafe_hash: If False __hash__() method is generated according to how eq and frozen are set
# frozen: If true assigning to fields will generate an exception.
@dataclass
class Tool:
    """A named, described callable together with metadata about its parameters.

    Calling a ``Tool`` instance forwards every positional and keyword
    argument to the wrapped ``func`` and returns whatever it returns.
    """

    # Identifier of the tool.
    name: str
    # Human-readable description of the tool.
    description: str
    # The wrapped callable that is invoked by ``__call__``.
    func: Callable[..., str]
    # Per-parameter metadata: parameter name -> mapping of string attributes.
    parameters: Dict[str, Dict[str, str]]

    def __call__(self, *args, **kwargs) -> str:
        """Invoke the wrapped callable with the given arguments."""
        wrapped = self.func
        return wrapped(*args, **kwargs)

def parse_docstring_params(docstring: str) -> Dict[str, str]:
    """Extract parameter descriptions from a docstring.

    A parameter section starts at a line beginning with ``Parameters:`` and
    ends at the first blank line. Inside the section, every line starting
    with ``-`` or ``*`` is read as ``name: description``; subsequent
    non-bullet lines are appended to the description of the most recent
    parameter, separated by a single space.

    Args:
        docstring: The docstring text to scan. A falsy value (``None`` or
            an empty string) yields an empty result.

    Returns:
        A mapping of parameter name to its description. A bullet without a
        colon maps to an empty description.
    """
    if not docstring:
        return {}

    params: Dict[str, str] = {}
    in_params = False
    current_param = None

    for raw_line in docstring.split('\n'):
        line = raw_line.strip()
        if line.startswith('Parameters:'):
            in_params = True
            # Do not let continuation lines attach to a parameter that
            # belongs to an earlier section.
            current_param = None
        elif not in_params:
            continue
        elif line.startswith(('-', '*')):
            # partition() splits on the FIRST colon only, so descriptions
            # that themselves contain colons (URLs, times, ...) stay intact,
            # and a bullet with no colon no longer raises IndexError.
            name, _, text = line.lstrip('- *').partition(':')
            name = name.strip()
            if name:
                current_param = name
                params[name] = text.strip()
            else:
                # A bullet with no usable name documents nothing.
                current_param = None
        elif not line:
            # A blank line terminates the parameter section.
            in_params = False
            current_param = None
        elif current_param:
            # Continuation line: extend the current description. The outer
            # strip() avoids a leading space when the bullet had no text.
            params[current_param] = f"{params[current_param]} {line}".strip()

    return params

def get_type_description(type_hint: Any) -> str:
    """Get a human-readable description of a type hint.

    Args:
        type_hint: The annotation object to describe.

    Returns:
        ``type_hint.__name__`` when that attribute exists and is a string
        (e.g. ``"int"`` for ``int``); otherwise ``str(type_hint)`` with any
        ``"typing."`` prefix removed.
    """
    # Not every annotation object exposes ``__name__`` (a plain string
    # annotation does not, and neither do some typing constructs depending
    # on the Python version), so fall back to its string form instead of
    # raising AttributeError.
    name = getattr(type_hint, "__name__", None)
    if isinstance(name, str):
        return name
    return str(type_hint).replace("typing.", "")

def tool(name: Optional[str] = None):
    """Create a decorator that wraps a function in a ``Tool``.

    Args:
        name: Name for the resulting tool. When omitted (or otherwise
            falsy), the decorated function's ``__name__`` is used.

    Returns:
        A decorator that takes a function and returns a ``Tool`` whose
        ``description`` is the first paragraph of the function's docstring
        and whose ``parameters`` maps each signature parameter to its type
        description and its docstring description.
    """
    def decorator(func: Callable[..., str]) -> Tool:
        tool_name = name or func.__name__
        description = inspect.getdoc(func) or "No description available"

        type_hints = get_type_hints(func)
        param_docs = parse_docstring_params(description)

        # One entry per parameter in the signature; unannotated parameters
        # are described as ``Any`` and undocumented ones get a placeholder.
        params = {
            param_name: {
                "type": get_type_description(type_hints.get(param_name, Any)),
                "description": param_docs.get(param_name, "No description available"),
            }
            for param_name in inspect.signature(func).parameters
        }

        return Tool(
            name=tool_name,
            # Only the text before the first blank line is used as the
            # tool's description.
            description=description.split('\n\n')[0],
            func=func,
            parameters=params,
        )
    return decorator