Moonshot_DeepResearch / src /smolagents /tool_validation.py
MaoShen's picture
Upload folder using huggingface_hub
2eb41d7 verified
import ast
import builtins
from itertools import zip_longest
from typing import Set
from .utils import BASE_BUILTIN_MODULES, get_source
# Every name exposed by the `builtins` module (functions, exceptions, constants);
# used below to decide whether a bare name needs no local definition.
_BUILTIN_NAMES = set(builtins.__dict__)
class MethodChecker(ast.NodeVisitor):
    """
    Checks that a method only uses defined names.

    A name counts as defined when it is a builtin, one of BASE_BUILTIN_MODULES,
    "self", a parameter, a class attribute passed to the constructor, an import
    made inside the method, or a name bound inside the method (assignment, loop
    target, `with`/`except` alias, comprehension variable, walrus target).

    Imports found in the method are recorded in `imports` / `from_imports`.
    `check_imports` is stored on the instance, but no visitor in this class
    reads it.

    Every problem found is appended to `errors` as a human-readable string.
    The analysis follows traversal order and is not scope-aware: a name bound
    anywhere earlier in the traversal is considered defined afterwards.
    """

    def __init__(self, class_attributes: Set[str], check_imports: bool = True):
        """
        Args:
            class_attributes: Names assigned at class level; always treated as defined.
            check_imports: Stored as `self.check_imports` (not consulted by this class).
        """
        self.undefined_names = set()
        self.imports = {}  # local alias -> imported module name
        self.from_imports = {}  # local alias -> (module, original name)
        self.assigned_names = set()
        self.arg_names = set()
        self.class_attributes = class_attributes
        self.errors = []
        self.check_imports = check_imports
        self.typing_names = {"Any"}

    # ------------------------------------------------------------------ helpers

    def _bind_target(self, target):
        """Record every plain name bound by an assignment-like target.

        Handles nested tuple/list unpacking and starred elements, e.g.
        `a, (b, *c) = value`. Attribute and subscript targets bind no new name
        and are ignored.
        """
        if isinstance(target, ast.Name):
            self.assigned_names.add(target.id)
        elif isinstance(target, (ast.Tuple, ast.List)):
            for element in target.elts:
                self._bind_target(element)
        elif isinstance(target, ast.Starred):
            self._bind_target(target.value)

    def _is_defined(self, name, allow_typing_names=True):
        """Return True when `name` is resolvable inside the checked method.

        All tests are plain membership checks, so their order does not affect
        the result; names local to the method are tested first.
        """
        return (
            name == "self"
            or name in self.arg_names
            or name in self.assigned_names
            or name in self.imports
            or name in self.from_imports
            or name in self.class_attributes
            or (allow_typing_names and name in self.typing_names)
            or name in _BUILTIN_NAMES
            or name in BASE_BUILTIN_MODULES
        )

    def _report_undefined(self, name):
        self.errors.append(f"Name '{name}' is undefined.")

    # ----------------------------------------------------------------- bindings

    def visit_arguments(self, node):
        """Collect function arguments"""
        # `posonlyargs` only exists on Python 3.8+ ASTs.
        positional_only = getattr(node, "posonlyargs", [])
        # Accumulate rather than replace: a nested lambda/def must not erase
        # the parameters of the enclosing method.
        for arg in (*positional_only, *node.args, *node.kwonlyargs):
            self.arg_names.add(arg.arg)
        if node.kwarg:
            self.arg_names.add(node.kwarg.arg)
        if node.vararg:
            self.arg_names.add(node.vararg.arg)

    def visit_Import(self, node):
        """Record `import x` / `import x as y` under the locally bound name."""
        for alias in node.names:
            self.imports[alias.asname or alias.name] = alias.name

    def visit_ImportFrom(self, node):
        """Record `from m import x` / `from m import x as y` under the locally bound name."""
        module = node.module or ""
        for alias in node.names:
            self.from_imports[alias.asname or alias.name] = (module, alias.name)

    def visit_Assign(self, node):
        """Track assigned names (including unpacking), then check the value."""
        for target in node.targets:
            self._bind_target(target)
        self.visit(node.value)

    def visit_AnnAssign(self, node):
        """Track annotated assignments."""
        self._bind_target(node.target)
        if node.value:
            self.visit(node.value)

    def visit_NamedExpr(self, node):
        """Track the target of an assignment expression (`name := value`)."""
        self._bind_target(node.target)
        self.generic_visit(node)

    def visit_With(self, node):
        """Track aliases in 'with' statements (the 'y' in 'with X as y')"""
        for item in node.items:
            if item.optional_vars:
                self._bind_target(item.optional_vars)
        self.generic_visit(node)

    # `async with` has the same shape as `with`.
    visit_AsyncWith = visit_With

    def visit_ExceptHandler(self, node):
        """Track exception aliases (the 'e' in 'except Exception as e')"""
        if node.name:
            self.assigned_names.add(node.name)
        self.generic_visit(node)

    def visit_For(self, node):
        """Track loop variables, then check the iterable and the body."""
        self._bind_target(node.target)
        self.generic_visit(node)

    # `async for` has the same shape as `for`.
    visit_AsyncFor = visit_For

    def _handle_comprehension_generators(self, generators):
        """Helper method to handle generators in all types of comprehensions"""
        for generator in generators:
            self._bind_target(generator.target)

    def _visit_comprehension(self, node):
        # Targets must be registered first: generic_visit reaches the element
        # expression before the `for` clauses that bind its variables.
        self._handle_comprehension_generators(node.generators)
        self.generic_visit(node)

    def visit_ListComp(self, node):
        """Track variables in list comprehensions"""
        self._visit_comprehension(node)

    def visit_DictComp(self, node):
        """Track variables in dictionary comprehensions"""
        self._visit_comprehension(node)

    def visit_SetComp(self, node):
        """Track variables in set comprehensions"""
        self._visit_comprehension(node)

    def visit_GeneratorExp(self, node):
        """Track variables in generator expressions"""
        self._visit_comprehension(node)

    # ------------------------------------------------------------------- checks

    def visit_Attribute(self, node):
        """Skip `self.<attr>` accesses; check the base of any other attribute chain."""
        if not (isinstance(node.value, ast.Name) and node.value.id == "self"):
            self.generic_visit(node)

    def visit_Name(self, node):
        """Report a name that is read without being defined."""
        if isinstance(node.ctx, ast.Load) and not self._is_defined(node.id):
            self._report_undefined(node.id)

    def visit_Call(self, node):
        """Check the callee of a call, then its arguments."""
        callee = node.func
        if not isinstance(callee, ast.Name):
            self.generic_visit(node)
            return
        # Names in `typing_names` are accepted when read but not when called.
        if not self._is_defined(callee.id, allow_typing_names=False):
            self._report_undefined(callee.id)
        # Visit only the arguments: re-visiting `callee` through generic_visit
        # would report the same undefined name a second time.
        for argument in node.args:
            self.visit(argument)
        for keyword in node.keywords:
            self.visit(keyword)
def validate_tool_attributes(cls, check_imports: bool = True) -> None:
    """
    Validates that a Tool class follows the proper patterns:
    0. Any argument of __init__ should have a default.
    Args chosen at init are not traceable, so we cannot rebuild the source code for them, thus any important arg should be defined as a class attribute.
    1. About the class:
        - Class attributes should only be built from literals (constants, dicts, lists, sets)
        - Class attributes cannot be complex attributes
    2. About all class methods:
        - Every name read inside a method must be defined (checked by MethodChecker)

    Args:
        cls: The class to validate; its source is obtained with `get_source(cls)`.
        check_imports: Forwarded unchanged to each MethodChecker.

    Raises:
        ValueError: If the source does not start with a class definition, or if
            any validation error was collected (all errors are reported together).

    Returns None when no error is found.
    """
    # AST node types accepted as "literal". `ast.Constant` covers string and
    # number literals on Python 3.8+; the deprecated `ast.Str` / `ast.Num`
    # aliases are deliberately not referenced because they are removed in
    # Python 3.14.
    literal_nodes = (ast.Constant, ast.Dict, ast.List, ast.Set)
    # `ast.walk` also yields the Load/Store context node attached to list
    # literals; it carries no value and must not make a list look "complex".
    literal_or_context_nodes = literal_nodes + (ast.expr_context,)

    class ClassLevelChecker(ast.NodeVisitor):
        """Collects class attributes and `__init__` parameter problems."""

        def __init__(self):
            self.imported_names = set()
            self.complex_attributes = set()
            self.class_attributes = set()
            self.non_defaults = set()
            self.non_literal_defaults = set()
            self.in_method = False

        def visit_FunctionDef(self, node):
            if node.name == "__init__":
                self._check_init_function_parameters(node)
            # Assignments inside a method body are locals, not class attributes.
            old_context = self.in_method
            self.in_method = True
            self.generic_visit(node)
            self.in_method = old_context

        # `async def` methods must also switch to method context, otherwise
        # their local assignments would be recorded as class attributes.
        visit_AsyncFunctionDef = visit_FunctionDef

        def visit_Assign(self, node):
            if self.in_method:
                return
            # Track class attributes
            assigned = [target.id for target in node.targets if isinstance(target, ast.Name)]
            self.class_attributes.update(assigned)
            # Check if the assignment is more complex than simple literals
            is_literal = all(isinstance(val, literal_or_context_nodes) for val in ast.walk(node.value))
            if not is_literal:
                self.complex_attributes.update(assigned)

        def _record_default(self, arg, default):
            """Classify one parameter; `default` is None when it has no default."""
            if default is None:
                if arg.arg != "self":
                    self.non_defaults.add(arg.arg)
            elif not isinstance(default, literal_nodes):
                self.non_literal_defaults.add(arg.arg)

        def _check_init_function_parameters(self, node):
            # Check defaults in parameters
            arguments = node.args
            # `defaults` is right-aligned over positional-only + regular parameters.
            positional = [*getattr(arguments, "posonlyargs", []), *arguments.args]
            first_with_default = len(positional) - len(arguments.defaults)
            for index, arg in enumerate(positional):
                has_default = index >= first_with_default
                default = arguments.defaults[index - first_with_default] if has_default else None
                self._record_default(arg, default)
            # `kw_defaults` is aligned with `kwonlyargs`; None marks a required one.
            for arg, default in zip(arguments.kwonlyargs, arguments.kw_defaults):
                self._record_default(arg, default)

    class_level_checker = ClassLevelChecker()
    source = get_source(cls)
    tree = ast.parse(source)
    # Guard against empty source so the documented ValueError is raised
    # instead of an IndexError.
    class_node = tree.body[0] if tree.body else None
    if not isinstance(class_node, ast.ClassDef):
        raise ValueError("Source code must define a class")
    class_level_checker.visit(class_node)

    errors = []
    # Names are sorted so the message is stable from run to run.
    if class_level_checker.complex_attributes:
        errors.append(
            f"Complex attributes should be defined in __init__, not as class attributes: "
            f"{', '.join(sorted(class_level_checker.complex_attributes))}"
        )
    if class_level_checker.non_defaults:
        errors.append(
            f"Parameters in __init__ must have default values, found required parameters: "
            f"{', '.join(sorted(class_level_checker.non_defaults))}"
        )
    if class_level_checker.non_literal_defaults:
        errors.append(
            f"Parameters in __init__ must have literal default values, found non-literal defaults: "
            f"{', '.join(sorted(class_level_checker.non_literal_defaults))}"
        )

    # Run checks on all methods
    for node in class_node.body:
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            # A fresh checker per method keeps parameters and locals from leaking between methods.
            method_checker = MethodChecker(class_level_checker.class_attributes, check_imports=check_imports)
            method_checker.visit(node)
            errors += [f"- {node.name}: {error}" for error in method_checker.errors]

    if errors:
        raise ValueError(f"Tool validation failed for {cls.__name__}:\n" + "\n".join(errors))
    return