# intelligent_systems_course / tools_registry.py
# Last commit by dkolarova: "Update tools_registry.py" (82af58d, verified)
from typing import Callable, Any, Dict, get_type_hints, Optional, _GenericAlias
from dataclasses import dataclass
import inspect
# Dataclasses are ordinary Python classes whose boilerplate methods (initialization, comparison, and printing via __repr__) are generated automatically.
# Syntax: @dataclasses.dataclass(*, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False)
# Parameters:
# init: If true __init__() method will be generated
# repr: If true __repr__() method will be generated
# eq: If true __eq__() method will be generated
# order: If true __lt__(), __le__(), __gt__(), and __ge__() methods will be generated.
# unsafe_hash: If False (the default), __hash__() is set according to how eq and frozen are set; if True, __hash__() is generated regardless.
# frozen: If true assigning to fields will generate an exception.
@dataclass
class Tool:
    """A callable wrapper that pairs a function with descriptive metadata.

    Because this is a dataclass, ``__init__``, ``__repr__`` and ``__eq__`` are
    generated from the four fields below. Calling a ``Tool`` instance simply
    delegates to the wrapped ``func``.
    """

    name: str  # name of the tool
    description: str  # human-readable description of what the tool does
    func: Callable[..., str]  # the wrapped implementation that does the work
    # Per-parameter metadata keyed by parameter name; the ``tool`` decorator
    # fills each entry with "type" and "description" strings.
    parameters: Dict[str, Dict[str, str]]

    def __call__(self, *call_args, **call_kwargs) -> str:
        """Invoke the wrapped function, forwarding every argument unchanged."""
        target = self.func
        return target(*call_args, **call_kwargs)
def parse_docstring_params(docstring: str) -> Dict[str, str]:
    """Extract parameter descriptions from a docstring.

    Recognised layout: a line starting with ``Parameters:`` opens a section.
    Inside it, each bullet line starting with ``-`` or ``*`` has the form
    ``name: description``. Any following non-blank, non-bullet lines are
    appended to that parameter's description. A blank line closes the section.

    Leading ``-``, ``*`` and space characters are stripped from the bullet, so
    ``- **kwargs: ...`` is recorded under ``kwargs``. Only the first colon
    separates the name from the description, so descriptions may contain colons.
    A bullet without a colon records the name with an empty description.

    Parameters:
    - docstring: The docstring text; ``None`` or ``""`` yields an empty dict.

    Returns a dict mapping parameter name -> description string.
    """
    if not docstring:
        return {}
    params: Dict[str, str] = {}
    in_params = False
    current_param: Optional[str] = None
    for raw_line in docstring.splitlines():
        line = raw_line.strip()
        if line.startswith('Parameters:'):
            in_params = True
            # A new section must not continue the previous section's last entry.
            current_param = None
        elif in_params:
            if line.startswith(('-', '*')):
                # partition() splits on the first colon only and never raises,
                # unlike split(':')[1] which fails on "- name" and truncates
                # descriptions that themselves contain ':'.
                param_name, _, param_desc = line.lstrip('- *').partition(':')
                current_param = param_name.strip() or None
                if current_param:
                    params[current_param] = param_desc.strip()
            elif current_param and line:
                # Continuation line; strip() avoids a leading space when the
                # bullet itself had no description.
                params[current_param] = (params[current_param] + ' ' + line).strip()
            elif not line:
                in_params = False
    return params
def get_type_description(type_hint: Any) -> str:
    """Get a human-readable description of a type hint.

    - A string (e.g. an unresolved forward reference) is returned unchanged.
    - Objects with a string ``__name__`` (plain classes such as ``int`` and
      many ``typing`` objects) return that name.
    - Anything else, which includes some ``typing`` constructs on some Python
      versions and ``None``, falls back to its ``repr`` with ``typing.`` prefixes
      removed, e.g. ``typing.Union[int, str]`` -> ``Union[int, str]``.
    """
    if isinstance(type_hint, str):
        return type_hint
    # getattr with a default instead of ``.__name__`` so hints lacking the
    # attribute produce a description rather than an AttributeError.
    hint_name = getattr(type_hint, "__name__", None)
    if isinstance(hint_name, str):
        return hint_name
    return repr(type_hint).replace("typing.", "")
def tool(name: Optional[str] = None):
    """Build a decorator that converts a function into a ``Tool``.

    Usage::

        @tool()              # tool name defaults to the function's __name__
        def search(query: str) -> str: ...

        @tool("web_search")  # explicit tool name
        def search(query: str) -> str: ...

    Parameters:
    - name: Tool name to use; when falsy, the decorated function's ``__name__`` is used.

    The returned decorator produces a ``Tool`` whose ``description`` is the first
    paragraph of the function's docstring (or "No description available"). Its
    ``parameters`` map every signature parameter to ``{"type", "description"}``,
    with descriptions taken from the docstring's ``Parameters:`` section.
    Errors from ``typing.get_type_hints`` (e.g. unresolvable forward references)
    propagate at decoration time.
    """
    def decorator(func: Callable[..., str]) -> Tool:
        tool_name = name or func.__name__
        # inspect.getdoc returns the cleaned (dedented) docstring, or None.
        description = inspect.getdoc(func) or "No description available"
        type_hints = get_type_hints(func)
        param_docs = parse_docstring_params(description)
        # Unannotated parameters are described as Any.
        params = {
            param_name: {
                "type": get_type_description(type_hints.get(param_name, Any)),
                "description": param_docs.get(param_name, "No description available"),
            }
            for param_name in inspect.signature(func).parameters
        }
        return Tool(
            name=tool_name,
            # Keep only the summary paragraph; per-parameter details live in `parameters`.
            description=description.split('\n\n')[0],
            func=func,
            parameters=params,
        )

    return decorator