Spaces:
Running
Running
import ast | |
import builtins | |
from itertools import zip_longest | |
from typing import Set | |
from .utils import BASE_BUILTIN_MODULES, get_source | |
# Every name exposed by the ``builtins`` module (e.g. ``len``, ``print``,
# ``ValueError``); method bodies may reference these without importing them.
_BUILTIN_NAMES = {builtin_name for builtin_name in builtins.__dict__}
class MethodChecker(ast.NodeVisitor):
    """
    Checks that a method
    - only uses defined names
    - contains no local imports (e.g. numpy is ok but local_script is not)

    Every name bound inside the visited method is collected into flat sets:
    parameters, assignment targets, loop / comprehension / ``with`` / ``except``
    targets, walrus targets and nested function names. Scoping is approximate,
    so a name bound in an inner scope counts as defined for the rest of the method.
    Each load of a name that cannot be resolved appends a message to
    ``self.errors``.

    NOTE(review): import statements are only recorded, so that imported names
    count as defined. No visitor here rejects local imports, and
    ``check_imports`` is stored but not consulted. Confirm whether that check
    lives elsewhere.
    """

    def __init__(self, class_attributes: Set[str], check_imports: bool = True):
        self.undefined_names = set()
        self.imports = {}
        self.from_imports = {}
        self.assigned_names = set()
        self.arg_names = set()
        self.class_attributes = class_attributes
        self.errors = []
        self.check_imports = check_imports
        self.typing_names = {"Any"}
        # Nesting depth of (async) function definitions; 0 means "outside the
        # method being checked".
        self._function_depth = 0

    # ----------------------------------------------------------------- helpers
    def _is_defined(self, name: str) -> bool:
        """Return True if ``name`` resolves to something visible inside the method."""
        # Method-local sources are consulted first; module-level sets last.
        return (
            name == "self"
            or name in self.arg_names
            or name in self.assigned_names
            or name in self.imports
            or name in self.from_imports
            or name in self.class_attributes
            or name in self.typing_names
            or name in _BUILTIN_NAMES
            or name in BASE_BUILTIN_MODULES
        )

    def _add_target_names(self, target):
        """Record every plain name bound by an assignment-like target.

        Handles nested tuple/list unpacking and starred targets, e.g.
        ``a, (b, *c) = ...``. Attribute and subscript targets bind no new name.
        """
        if isinstance(target, ast.Name):
            self.assigned_names.add(target.id)
        elif isinstance(target, (ast.Tuple, ast.List)):
            for element in target.elts:
                self._add_target_names(element)
        elif isinstance(target, ast.Starred):
            self._add_target_names(target.value)

    def _handle_comprehension_generators(self, generators):
        """Helper method to handle generators in all types of comprehensions"""
        for generator in generators:
            self._add_target_names(generator.target)

    # ---------------------------------------------------------------- visitors
    def visit_FunctionDef(self, node):
        """Record nested helper names, then visit the function's contents."""
        if self._function_depth > 0:
            # Only functions defined *inside* the checked method become local
            # names; the method's own name is not in scope as a bare name.
            self.assigned_names.add(node.name)
        self._function_depth += 1
        try:
            self.generic_visit(node)
        finally:
            self._function_depth -= 1

    visit_AsyncFunctionDef = visit_FunctionDef

    def visit_arguments(self, node):
        """Collect the names of every parameter.

        Covers positional-only, regular, keyword-only, ``*args`` and
        ``**kwargs``. Names are added to ``arg_names`` rather than replacing
        it, so a lambda or nested function does not hide the enclosing
        method's parameters. Parameter defaults and annotations are not visited.
        """
        params = [*getattr(node, "posonlyargs", []), *node.args, *node.kwonlyargs]
        if node.vararg:
            params.append(node.vararg)
        if node.kwarg:
            params.append(node.kwarg)
        self.arg_names.update(param.arg for param in params)

    def visit_Import(self, node):
        for name in node.names:
            actual_name = name.asname or name.name
            self.imports[actual_name] = name.name

    def visit_ImportFrom(self, node):
        module = node.module or ""
        for name in node.names:
            actual_name = name.asname or name.name
            self.from_imports[actual_name] = (module, name.name)

    def visit_Assign(self, node):
        """Track assigned names (including unpacking); check only the value."""
        for target in node.targets:
            self._add_target_names(target)
        self.visit(node.value)

    def visit_AnnAssign(self, node):
        """Track annotated assignments; the annotation itself is not checked."""
        self._add_target_names(node.target)
        if node.value:
            self.visit(node.value)

    def visit_NamedExpr(self, node):
        """Track walrus targets (the 'y' in '(y := expr)')."""
        self._add_target_names(node.target)
        self.visit(node.value)

    def visit_With(self, node):
        """Track aliases in 'with' statements (the 'y' in 'with X as y')"""
        for item in node.items:
            if item.optional_vars:  # This is the 'y' in 'with X as y'
                self._add_target_names(item.optional_vars)
        self.generic_visit(node)

    visit_AsyncWith = visit_With

    def visit_ExceptHandler(self, node):
        """Track exception aliases (the 'e' in 'except Exception as e')"""
        if node.name:  # This is the 'e' in 'except Exception as e'
            self.assigned_names.add(node.name)
        self.generic_visit(node)

    def visit_For(self, node):
        """Track loop targets, including nested unpacking."""
        self._add_target_names(node.target)
        self.generic_visit(node)

    visit_AsyncFor = visit_For

    def visit_ListComp(self, node):
        """Track variables bound by comprehension / generator-expression targets"""
        # Targets are registered before visiting so the element expression,
        # which precedes the ``for`` clause in the AST, sees them as defined.
        self._handle_comprehension_generators(node.generators)
        self.generic_visit(node)

    visit_SetComp = visit_ListComp
    visit_DictComp = visit_ListComp
    visit_GeneratorExp = visit_ListComp

    def visit_Attribute(self, node):
        # Attributes reached through ``self`` are not name lookups; skip them.
        if not (isinstance(node.value, ast.Name) and node.value.id == "self"):
            self.generic_visit(node)

    def visit_Name(self, node):
        if isinstance(node.ctx, ast.Load) and not self._is_defined(node.id):
            self.errors.append(f"Name '{node.id}' is undefined.")

    def visit_Call(self, node):
        """Visit a call expression.

        A bare-name callee is checked by ``visit_Name`` when ``generic_visit``
        reaches ``node.func``. Checking it here as well would report each
        undefined callee twice.
        """
        self.generic_visit(node)
def validate_tool_attributes(cls, check_imports: bool = True) -> None:
    """
    Validates that a Tool class follows the proper patterns:
    0. Any argument of __init__ should have a default.
    Args chosen at init are not traceable, so we cannot rebuild the source code for them, thus any important arg should be defined as a class attribute.
    1. About the class:
    - Class attributes should only be literal values: constants (strings, numbers, booleans, None)
      or dict/list/set displays built only from such values
    - Class attributes cannot be complex attributes (calls, names, comprehensions, ...)
    2. About all class methods:
    - Imports must be from packages, not local files
    - All methods must be self-contained

    Args:
        cls: Class to validate. Its source is fetched with ``get_source``, and the
            first top-level statement of that source must be the class definition.
        check_imports: Forwarded to the ``MethodChecker`` run on every method.

    Raises:
        ValueError: If the source does not start with a class definition, or if any
            validation check fails. In that case the message lists every error found.

    Returns:
        None when no error is found.
    """
    # AST node types accepted as literals. On Python >= 3.8, ``ast.Constant`` covers
    # str/bytes/numbers/bool/None. The legacy ``ast.Str``/``ast.Num`` aliases are
    # deprecated (warning on access since 3.12) and removed in 3.14.
    literal_nodes = (ast.Constant, ast.Dict, ast.List, ast.Set)

    def is_literal_value(value_node) -> bool:
        """Return True if ``value_node`` and every sub-node of it is a literal."""
        # ast.walk also yields expression-context markers, such as the ``Load()`` of
        # a list display. They carry no value; without this skip, every list literal
        # (even one nested in a dict) would be reported as complex.
        return all(
            isinstance(sub_node, literal_nodes) or isinstance(sub_node, ast.expr_context)
            for sub_node in ast.walk(value_node)
        )

    class ClassLevelChecker(ast.NodeVisitor):
        """Collects class-level attributes and ``__init__`` parameter problems."""

        def __init__(self):
            self.imported_names = set()
            self.complex_attributes = set()
            self.class_attributes = set()
            self.non_defaults = set()
            self.non_literal_defaults = set()
            self.in_method = False

        def visit_FunctionDef(self, node):
            if node.name == "__init__":
                self._check_init_function_parameters(node)
            # Assignments inside a method are locals, not class attributes.
            old_context = self.in_method
            self.in_method = True
            self.generic_visit(node)
            self.in_method = old_context

        # Async methods are methods too; their assignments must not be treated
        # as class attributes.
        visit_AsyncFunctionDef = visit_FunctionDef

        def visit_Assign(self, node):
            if self.in_method:
                return
            # Track class attributes
            for target in node.targets:
                if isinstance(target, ast.Name):
                    self.class_attributes.add(target.id)
            # Check if the assignment is more complex than simple literals
            if not is_literal_value(node.value):
                for target in node.targets:
                    if isinstance(target, ast.Name):
                        self.complex_attributes.add(target.id)

        def _check_init_function_parameters(self, node):
            """Record ``__init__`` parameters lacking a default or with a non-literal one."""
            arguments = node.args
            # ``defaults`` align with the *last* positional parameters, and positional-only
            # parameters come first in that sequence.
            positional = [*getattr(arguments, "posonlyargs", []), *arguments.args]
            missing = len(positional) - len(arguments.defaults)
            padded_defaults = [None] * missing + list(arguments.defaults)
            # ``kw_defaults`` holds None for keyword-only parameters without a default.
            params = list(zip(positional, padded_defaults)) + list(
                zip(arguments.kwonlyargs, arguments.kw_defaults)
            )
            for arg, default in params:
                if default is None:
                    if arg.arg != "self":
                        self.non_defaults.add(arg.arg)
                elif not isinstance(default, literal_nodes):
                    self.non_literal_defaults.add(arg.arg)

    class_level_checker = ClassLevelChecker()
    source = get_source(cls)
    tree = ast.parse(source)
    class_node = tree.body[0] if tree.body else None
    if not isinstance(class_node, ast.ClassDef):
        raise ValueError("Source code must define a class")
    class_level_checker.visit(class_node)

    errors = []
    # Names are sorted so the error message is deterministic.
    if class_level_checker.complex_attributes:
        errors.append(
            f"Complex attributes should be defined in __init__, not as class attributes: "
            f"{', '.join(sorted(class_level_checker.complex_attributes))}"
        )
    if class_level_checker.non_defaults:
        errors.append(
            f"Parameters in __init__ must have default values, found required parameters: "
            f"{', '.join(sorted(class_level_checker.non_defaults))}"
        )
    if class_level_checker.non_literal_defaults:
        errors.append(
            f"Parameters in __init__ must have literal default values, found non-literal defaults: "
            f"{', '.join(sorted(class_level_checker.non_literal_defaults))}"
        )

    # Run checks on all methods (sync and async)
    for member in class_node.body:
        if isinstance(member, (ast.FunctionDef, ast.AsyncFunctionDef)):
            method_checker = MethodChecker(class_level_checker.class_attributes, check_imports=check_imports)
            method_checker.visit(member)
            errors += [f"- {member.name}: {error}" for error in method_checker.errors]

    if errors:
        raise ValueError(f"Tool validation failed for {cls.__name__}:\n" + "\n".join(errors))
    return