|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import ast
import builtins
import difflib
import operator
import sys
from collections.abc import Mapping
from contextlib import ExitStack
from importlib import import_module
from typing import Any, Callable, Dict, List, Optional

import numpy as np

from ..utils import is_pandas_available


if is_pandas_available():
    import pandas as pd
|
|
|
|
|
class InterpreterError(ValueError):
    """
    Raised when the interpreter refuses or fails to evaluate a piece of Python code, e.g. because the code uses an
    unsupported construct or operation, an undefined or unauthorized name, or exceeds an execution limit.
    """
|
|
|
|
|
# Every builtin exception class, keyed by name, so interpreted code can raise and catch them
# (e.g. `raise ValueError(...)`, `except KeyError:`) without them being passed in as tools.
ERRORS = {
    name: candidate
    for name, candidate in ((attr, getattr(builtins, attr)) for attr in dir(builtins))
    if isinstance(candidate, type) and issubclass(candidate, BaseException)
}
|
|
|
|
|
# Standard-library modules that interpreted code may import when no explicit `authorized_imports` list is given.
LIST_SAFE_MODULES = [
    "random",
    "collections",
    "math",
    "time",
    "queue",
    "itertools",
    "re",
    "stat",
    "statistics",
    "unicodedata",
]

# Text collected from interpreted `print` calls, and the length past which it is truncated when reported.
PRINT_OUTPUTS = ""
MAX_LEN_OUTPUT = 50000

# Number of AST nodes evaluated in the current run, and the cap that stops runaway code.
OPERATIONS_COUNT = 0
MAX_OPERATIONS = 10000000
|
|
|
|
|
class BreakException(Exception):
    """Internal signal raised when interpreted code executes `break`; caught by the enclosing loop evaluator."""
|
|
|
|
|
class ContinueException(Exception):
    """Internal signal raised when interpreted code executes `continue`; caught by the enclosing loop evaluator."""
|
|
|
|
|
class ReturnException(Exception):
    """
    Internal signal raised when interpreted code executes `return`; carries the returned value up to the
    function-call wrapper that catches it.
    """

    def __init__(self, value):
        super().__init__(value)
        # The value produced by the `return` statement (None for a bare `return`).
        self.value = value
|
|
|
|
|
def get_iterable(obj):
    """
    Return `obj` as a list: lists are returned unchanged (same object), other objects exposing `__iter__` are
    materialized with `list()`.

    Raises:
        InterpreterError: if `obj` has no `__iter__` attribute.
    """
    if isinstance(obj, list):
        return obj
    if not hasattr(obj, "__iter__"):
        raise InterpreterError("Object is not iterable")
    return list(obj)
|
|
|
|
|
def evaluate_unaryop(expression, state, static_tools, custom_tools):
    """
    Evaluate a unary operation (`-x`, `+x`, `not x`, `~x`).

    Note that unary plus returns the operand itself rather than calling `__pos__` on it.

    Raises:
        InterpreterError: for any other unary operator node.
    """
    value = evaluate_ast(expression.operand, state, static_tools, custom_tools)
    op = expression.op
    if isinstance(op, ast.Not):
        return not value
    if isinstance(op, ast.USub):
        return -value
    if isinstance(op, ast.Invert):
        return ~value
    if isinstance(op, ast.UAdd):
        return value
    raise InterpreterError(f"Unary operation {op.__class__.__name__} is not supported.")
|
|
|
|
|
def evaluate_lambda(lambda_expression, state, static_tools, custom_tools):
    """
    Turn a `lambda` expression into a Python callable.

    Default values are evaluated once, when the lambda expression itself is evaluated (as in Python). Each call
    evaluates the body in a shallow copy of `state` taken at call time, in which the defaults are bound first,
    then the positional arguments in order, then any keyword arguments. Surplus positional arguments are ignored.
    `*args` / `**kwargs` / keyword-only parameters are not supported.

    Returns:
        A function `lambda_func(*values, **kwargs)` returning the value of the lambda body.
    """
    spec = lambda_expression.args
    args = [arg.arg for arg in spec.args]
    # `spec.defaults` aligns with the LAST len(defaults) positional parameters.
    default_values = [evaluate_ast(d, state, static_tools, custom_tools) for d in spec.defaults]
    defaults = dict(zip(args[len(args) - len(default_values) :], default_values))

    def lambda_func(*values, **kwargs):
        new_state = state.copy()
        new_state.update(defaults)
        for arg, value in zip(args, values):
            new_state[arg] = value
        new_state.update(kwargs)
        return evaluate_ast(lambda_expression.body, new_state, static_tools, custom_tools)

    return lambda_func
|
|
|
|
|
def evaluate_while(while_loop, state, static_tools, custom_tools, max_iterations=1000):
    """
    Evaluate a `while` loop, including its optional `else:` clause.

    The `else` body runs only when the loop ends because its condition became false, not after a `break`.

    Args:
        max_iterations (`int`, defaults to 1000):
            Maximum number of times the body may run before the loop is aborted.

    Returns:
        None.

    Raises:
        InterpreterError: if the condition is still true after `max_iterations` runs of the body.
    """
    iterations = 0
    while evaluate_ast(while_loop.test, state, static_tools, custom_tools):
        # Check before running the body so at most `max_iterations` iterations execute.
        if iterations >= max_iterations:
            raise InterpreterError(f"Maximum number of {max_iterations} iterations in While loop exceeded")
        iterations += 1
        try:
            for node in while_loop.body:
                evaluate_ast(node, state, static_tools, custom_tools)
        except BreakException:
            # `break` also skips the `else:` clause.
            return None
        except ContinueException:
            continue
    for node in while_loop.orelse:
        evaluate_ast(node, state, static_tools, custom_tools)
    return None
|
|
|
|
|
def create_function(func_def, state, static_tools, custom_tools):
    """
    Build a Python callable that runs the body of an interpreted `def` statement.

    Default values (positional and keyword-only) are evaluated once, here, when the `def` runs. Each call evaluates
    the body in a shallow copy of `state` taken at call time. Parameters are bound as in Python:
      - positional(-only) parameters take the positional arguments in order;
      - `*args` receives only the surplus positional arguments;
      - `**kwargs` receives only keywords that match no named parameter;
      - an unsupplied parameter falls back to its default, even if `state` already has a variable of that name;
      - a first parameter named `self` also exposes `__class__` (the instance's class) for zero-argument `super()`.

    This is more lenient than Python in two cases. Without a `**kwargs` parameter, unknown keywords are still
    bound as locals. Without a `*args` parameter, surplus positional arguments are dropped.

    Returns:
        `new_func(*args, **kwargs)`, which returns the value of an explicit `return`, or None otherwise.
        When called, `new_func` raises InterpreterError if a parameter without a default receives no value.
    """
    spec = func_def.args
    positional_names = [arg.arg for arg in spec.posonlyargs + spec.args]
    keyword_only_names = [arg.arg for arg in spec.kwonlyargs]
    parameter_names = set(positional_names) | set(keyword_only_names)

    # `spec.defaults` aligns with the LAST len(defaults) positional parameters.
    default_values = [evaluate_ast(d, state, static_tools, custom_tools) for d in spec.defaults]
    defaults = dict(zip(positional_names[len(positional_names) - len(default_values) :], default_values))
    # `spec.kw_defaults` has one entry per keyword-only parameter; None means "no default".
    for arg, default in zip(spec.kwonlyargs, spec.kw_defaults):
        if default is not None:
            defaults[arg.arg] = evaluate_ast(default, state, static_tools, custom_tools)

    def new_func(*args, **kwargs):
        func_state = state.copy()

        # Bind into a separate dict so defaults are decided independently of outer variables in `state`.
        bound = dict(zip(positional_names, args))
        extra_args = args[len(positional_names) :]
        extra_kwargs = {}
        for name, value in kwargs.items():
            if spec.kwarg and name not in parameter_names:
                extra_kwargs[name] = value
            else:
                bound[name] = value
        for name, value in defaults.items():
            bound.setdefault(name, value)

        missing = [name for name in positional_names + keyword_only_names if name not in bound]
        if missing:
            raise InterpreterError(f"{func_def.name}() missing required argument(s): {', '.join(missing)}")

        func_state.update(bound)
        if spec.vararg:
            func_state[spec.vararg.arg] = extra_args
        if spec.kwarg:
            func_state[spec.kwarg.arg] = extra_kwargs

        if positional_names and positional_names[0] == "self" and args:
            func_state["self"] = args[0]
            func_state["__class__"] = args[0].__class__

        try:
            for stmt in func_def.body:
                evaluate_ast(stmt, func_state, static_tools, custom_tools)
        except ReturnException as e:
            return e.value
        # Falling off the end of a function returns None, as in Python.
        return None

    return new_func
|
|
|
|
|
def create_class(class_name, class_bases, class_body):
    """
    Create a class named `class_name` deriving from `class_bases`, whose namespace is a copy of the
    `class_body` mapping.
    """
    namespace = {key: value for key, value in class_body.items()}
    return type(class_name, tuple(class_bases), namespace)
|
|
|
|
|
def evaluate_function_def(func_def, state, static_tools, custom_tools):
    """
    Evaluate a `def` statement: build the callable and register it in `custom_tools` under the function's name.

    Decorators on the definition are not applied.

    Returns:
        The created function.
    """
    function = create_function(func_def, state, static_tools, custom_tools)
    custom_tools[func_def.name] = function
    return function
|
|
|
|
|
def evaluate_class_def(class_def, state, static_tools, custom_tools):
    """
    Evaluate a `class` statement and bind the resulting class in `state`.

    Supported class-body statements:
      - method definitions;
      - plain and annotated assignments;
      - `pass`;
      - expression statements. A leading string literal becomes the docstring; other expressions are evaluated
        for their side effects.

    Class decorators and class keywords (e.g. `metaclass=`) are ignored.

    Returns:
        The newly created class.

    Raises:
        InterpreterError: if the body contains any other kind of statement.
    """
    class_name = class_def.name
    bases = [evaluate_ast(base, state, static_tools, custom_tools) for base in class_def.bases]
    class_dict = {}

    for position, stmt in enumerate(class_def.body):
        if isinstance(stmt, ast.FunctionDef):
            # Build methods directly (not via evaluate_function_def) so they do not leak into `custom_tools`
            # as callable top-level functions.
            class_dict[stmt.name] = create_function(stmt, state, static_tools, custom_tools)
        elif isinstance(stmt, ast.Assign):
            # Evaluate the right-hand side once, even for chained targets (`a = b = value`).
            value = evaluate_ast(stmt.value, state, static_tools, custom_tools)
            for target in stmt.targets:
                if isinstance(target, ast.Name):
                    class_dict[target.id] = value
                elif isinstance(target, ast.Attribute):
                    class_dict[target.attr] = value
        elif isinstance(stmt, ast.AnnAssign):
            # `name: T = value`; the annotation is not evaluated, and a bare `name: T` binds nothing.
            if stmt.value is not None and isinstance(stmt.target, ast.Name):
                class_dict[stmt.target.id] = evaluate_ast(stmt.value, state, static_tools, custom_tools)
        elif isinstance(stmt, ast.Pass):
            continue
        elif isinstance(stmt, ast.Expr):
            is_docstring = (
                position == 0 and isinstance(stmt.value, ast.Constant) and isinstance(stmt.value.value, str)
            )
            if is_docstring:
                class_dict["__doc__"] = stmt.value.value
            else:
                evaluate_ast(stmt.value, state, static_tools, custom_tools)
        else:
            raise InterpreterError(f"Unsupported statement in class body: {stmt.__class__.__name__}")

    new_class = type(class_name, tuple(bases), class_dict)
    state[class_name] = new_class
    return new_class
|
|
|
|
|
# In-place operator functions for augmented assignment (`x op= y`). `operator.iadd` & co. call `__iadd__` & co.
# when defined (e.g. `list += iterable` extends the same list) and fall back to the binary operator otherwise.
_AUGMENTED_OPERATORS = {
    ast.Add: operator.iadd,
    ast.Sub: operator.isub,
    ast.Mult: operator.imul,
    ast.Div: operator.itruediv,
    ast.Mod: operator.imod,
    ast.Pow: operator.ipow,
    ast.FloorDiv: operator.ifloordiv,
    ast.BitAnd: operator.iand,
    ast.BitOr: operator.ior,
    ast.BitXor: operator.ixor,
    ast.LShift: operator.ilshift,
    ast.RShift: operator.irshift,
    ast.MatMult: operator.imatmul,
}


def evaluate_augassign(expression, state, static_tools, custom_tools):
    """
    Evaluate an augmented assignment such as `x += 1`, `d[k] -= v` or `obj.attr *= 2`.

    Python's in-place semantics apply, so `lst += other` mutates the existing list. For subscript and attribute
    targets, the container and key are evaluated only once. An undefined plain name starts from 0.

    Returns:
        The updated value.

    Raises:
        InterpreterError: if the operator or the target kind is not supported, or the target name is a static tool.
    """
    operation = _AUGMENTED_OPERATORS.get(type(expression.op))
    if operation is None:
        raise InterpreterError(f"Operation {type(expression.op).__name__} is not supported.")

    target = expression.target
    # In every branch the current value is read before the right-hand side is evaluated, as in Python.
    if isinstance(target, ast.Name):
        current_value = state.get(target.id, 0)
        value = evaluate_ast(expression.value, state, static_tools, custom_tools)
        updated_value = operation(current_value, value)
        # set_value enforces the "cannot overwrite a static tool" rule.
        set_value(target, updated_value, state, static_tools, custom_tools)
    elif isinstance(target, ast.Subscript):
        container = evaluate_ast(target.value, state, static_tools, custom_tools)
        key = evaluate_ast(target.slice, state, static_tools, custom_tools)
        current_value = container[key]
        value = evaluate_ast(expression.value, state, static_tools, custom_tools)
        updated_value = operation(current_value, value)
        container[key] = updated_value
    elif isinstance(target, ast.Attribute):
        owner = evaluate_ast(target.value, state, static_tools, custom_tools)
        current_value = getattr(owner, target.attr)
        value = evaluate_ast(expression.value, state, static_tools, custom_tools)
        updated_value = operation(current_value, value)
        setattr(owner, target.attr, updated_value)
    else:
        raise InterpreterError(f"AugAssign not supported for {type(target)} targets.")

    return updated_value
|
|
|
|
|
def evaluate_boolop(node, state, static_tools, custom_tools):
    """
    Evaluate a short-circuiting `and` / `or` chain with Python semantics.

    `and` returns the first falsy operand, or the last operand if all are truthy. `or` returns the first truthy
    operand, or the last operand if all are falsy. Operands after the deciding one are not evaluated, so
    `d.get(k) or default` yields `default` rather than `True`.
    """
    stop_on_truthy = isinstance(node.op, ast.Or)
    value = None
    for operand in node.values:
        value = evaluate_ast(operand, state, static_tools, custom_tools)
        if bool(value) == stop_on_truthy:
            return value
    return value
|
|
|
|
|
# Binary AST operator node type -> function implementing it (`operator.add(a, b)` is exactly `a + b`, etc.).
_BINARY_OPERATORS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.Mod: operator.mod,
    ast.Pow: operator.pow,
    ast.FloorDiv: operator.floordiv,
    ast.BitAnd: operator.and_,
    ast.BitOr: operator.or_,
    ast.BitXor: operator.xor,
    ast.LShift: operator.lshift,
    ast.RShift: operator.rshift,
    ast.MatMult: operator.matmul,
}


def evaluate_binop(binop, state, static_tools, custom_tools):
    """
    Evaluate a binary operation `left <op> right`, including matrix multiplication (`a @ b`).

    Both operands are evaluated (left first) before the operator is applied.

    Raises:
        NotImplementedError: if the operator node type is unknown.
    """
    left_val = evaluate_ast(binop.left, state, static_tools, custom_tools)
    right_val = evaluate_ast(binop.right, state, static_tools, custom_tools)

    operation = _BINARY_OPERATORS.get(type(binop.op))
    if operation is None:
        raise NotImplementedError(f"Binary operation {type(binop.op).__name__} is not implemented.")
    return operation(left_val, right_val)
|
|
|
|
|
def evaluate_assign(assign, state, static_tools, custom_tools):
    """
    Evaluate an assignment statement, including chained assignment (`a = b = value`).

    The right-hand side is evaluated once and bound to every target from left to right, as in Python.
    Unpacking (`a, b = ...`) is handled per target by `set_value`.

    Returns:
        The assigned value.
    """
    result = evaluate_ast(assign.value, state, static_tools, custom_tools)
    for target in assign.targets:
        set_value(target, result, state, static_tools, custom_tools)
    return result
|
|
|
|
|
def set_value(target, value, state, static_tools, custom_tools):
    """
    Bind `value` to an assignment target node.

    Supported targets are names, subscripts (`obj[key]`) and attributes (`obj.attr`). Tuple and list patterns
    (`a, b`, `[a, b]`, nested, and with one starred element such as `head, *rest`) unpack `value`. A starred
    element receives a list, as in Python.

    Raises:
        InterpreterError: when assigning to a static tool name; when unpacking a non-iterable or a str/bytes value;
            on a length mismatch while unpacking; or for an unsupported target kind.
    """
    if isinstance(target, ast.Name):
        if target.id in static_tools:
            raise InterpreterError(f"Cannot assign to name '{target.id}': doing this would erase the existing tool!")
        state[target.id] = value
    elif isinstance(target, (ast.Tuple, ast.List)):
        if not isinstance(value, tuple):
            if hasattr(value, "__iter__") and not isinstance(value, (str, bytes)):
                value = tuple(value)
            else:
                raise InterpreterError("Cannot unpack non-tuple value")
        elements = target.elts
        starred_positions = [i for i, elem in enumerate(elements) if isinstance(elem, ast.Starred)]
        if not starred_positions:
            if len(elements) != len(value):
                raise InterpreterError("Cannot unpack tuple of wrong size")
            for elem, item in zip(elements, value):
                set_value(elem, item, state, static_tools, custom_tools)
            return
        # One starred element (the parser rejects more): it absorbs everything between the fixed head and tail.
        star = starred_positions[0]
        tail_len = len(elements) - star - 1
        if len(value) < len(elements) - 1:
            raise InterpreterError("Cannot unpack tuple of wrong size")
        tail_start = len(value) - tail_len
        for elem, item in zip(elements[:star], value[:star]):
            set_value(elem, item, state, static_tools, custom_tools)
        set_value(elements[star].value, list(value[star:tail_start]), state, static_tools, custom_tools)
        for elem, item in zip(elements[star + 1 :], value[tail_start:]):
            set_value(elem, item, state, static_tools, custom_tools)
    elif isinstance(target, ast.Subscript):
        obj = evaluate_ast(target.value, state, static_tools, custom_tools)
        key = evaluate_ast(target.slice, state, static_tools, custom_tools)
        obj[key] = value
    elif isinstance(target, ast.Attribute):
        obj = evaluate_ast(target.value, state, static_tools, custom_tools)
        setattr(obj, target.attr, value)
    else:
        raise InterpreterError(f"Assignment to {type(target).__name__} targets is not supported.")
|
|
|
|
|
def evaluate_call(call, state, static_tools, custom_tools):
    """
    Evaluate a call node `func(*args, **kwargs)`.

    Only two callee shapes are accepted. An attribute (`obj.method(...)`) is looked up on the evaluated object.
    A bare name is resolved in `state`, then `static_tools`, then `custom_tools`, then the builtin exception
    classes. Each argument expression is evaluated exactly once, left to right; `*iterable` and `**mapping`
    are expanded.

    Bare-name calls to `print` and `super` are handled specially. `print` appends its output (honouring `sep`
    and `end`) to the global `PRINT_OUTPUTS` buffer instead of writing to stdout. Zero-argument `super()` uses the
    `__class__` / `self` entries that interpreted methods place in their local state.

    Raises:
        InterpreterError: if the callee shape is unsupported, the attribute does not exist, the name is not an
            allowed callable, `*args` unpacks a non-iterable or str/bytes value, or `super()` is misused.
    """
    global PRINT_OUTPUTS

    if isinstance(call.func, ast.Attribute):
        obj = evaluate_ast(call.func.value, state, static_tools, custom_tools)
        func_name = call.func.attr
        if not hasattr(obj, func_name):
            raise InterpreterError(f"Object {obj} has no attribute {func_name}")
        func = getattr(obj, func_name)
    elif isinstance(call.func, ast.Name):
        func_name = call.func.id
        if func_name in state:
            func = state[func_name]
        elif func_name in static_tools:
            func = static_tools[func_name]
        elif func_name in custom_tools:
            func = custom_tools[func_name]
        elif func_name in ERRORS:
            func = ERRORS[func_name]
        else:
            raise InterpreterError(
                f"It is not permitted to evaluate other functions than the provided tools or functions defined in previous code (tried to execute {call.func.id})."
            )
    else:
        raise InterpreterError(f"This is not a correct function: {call.func}.")

    args = []
    for arg in call.args:
        if isinstance(arg, ast.Starred):
            unpacked = evaluate_ast(arg.value, state, static_tools, custom_tools)
            if not hasattr(unpacked, "__iter__") or isinstance(unpacked, (str, bytes)):
                raise InterpreterError(f"Cannot unpack non-iterable value {unpacked}")
            args.extend(unpacked)
        else:
            args.append(evaluate_ast(arg, state, static_tools, custom_tools))

    kwargs = {}
    for keyword in call.keywords:
        keyword_value = evaluate_ast(keyword.value, state, static_tools, custom_tools)
        if keyword.arg is None:
            # `**mapping` in the call: merge the mapping into the keyword arguments.
            kwargs.update(keyword_value)
        else:
            kwargs[keyword.arg] = keyword_value

    if isinstance(func, type) and len(func.__module__.split(".")) > 1:
        # NOTE(review): classes from dotted modules are built with explicit `__new__` + `__init__`. This includes
        # classes created by `evaluate_class_def`, whose `__module__` is this package module. Presumably this
        # keeps an interpreted `__init__` that returns a value from tripping `type.__call__`'s
        # "must return None" check. Confirm before simplifying to `func(*args, **kwargs)`.
        obj = func.__new__(func)
        if hasattr(obj, "__init__"):
            obj.__init__(*args, **kwargs)
        return obj

    # The special cases apply only to the bare builtins, not to methods that happen to share their names.
    is_bare_name = isinstance(call.func, ast.Name)
    if is_bare_name and func_name == "super":
        if not args:
            if "__class__" in state and "self" in state:
                return super(state["__class__"], state["self"])
            raise InterpreterError("super() needs at least one argument")
        cls = args[0]
        if not isinstance(cls, type):
            raise InterpreterError("super() argument 1 must be type")
        if len(args) == 1:
            return super(cls)
        if len(args) == 2:
            return super(cls, args[1])
        raise InterpreterError("super() takes at most 2 arguments")

    if is_bare_name and func_name == "print":
        sep = kwargs.get("sep")
        end = kwargs.get("end")
        separator = " " if sep is None else sep
        terminator = "\n" if end is None else end
        PRINT_OUTPUTS += separator.join(map(str, args)) + terminator
        return None

    return func(*args, **kwargs)
|
|
|
|
|
def evaluate_subscript(subscript, state, static_tools, custom_tools):
    """
    Evaluate `value[index]`.

    Pandas objects (when pandas is installed), numpy arrays and slices are indexed directly. Lists, tuples and
    strings get an explicit bounds check, and list/tuple indices are cast with `int()`. Anything else is indexed
    directly, so mapping hooks such as `__missing__` (defaultdict, Counter) behave as in Python. When a
    string-keyed Mapping lookup fails, the closest existing key (per `difflib.get_close_matches`) is used as a
    typo-tolerant fallback.

    Raises:
        InterpreterError: when subscripting a str with a str, when a list/tuple/str index is out of bounds, or when
            the lookup fails with no close key.
    """
    index = evaluate_ast(subscript.slice, state, static_tools, custom_tools)
    value = evaluate_ast(subscript.value, state, static_tools, custom_tools)

    if isinstance(value, str) and isinstance(index, str):
        raise InterpreterError("You're trying to subscript a string with a string index, which is impossible")

    # pandas is optional: the module-level `pd` only exists when it could be imported.
    pandas = globals().get("pd")
    if pandas is not None:
        if isinstance(value, pandas.core.indexing._LocIndexer):
            parent_object = value.obj
            return parent_object.loc[index]
        if isinstance(value, (pandas.DataFrame, pandas.Series, pandas.core.groupby.generic.DataFrameGroupBy)):
            return value[index]

    if isinstance(value, np.ndarray) or isinstance(index, slice):
        return value[index]
    if isinstance(value, (list, tuple)):
        if not (-len(value) <= index < len(value)):
            raise InterpreterError(f"Index {index} out of bounds for list of length {len(value)}")
        return value[int(index)]
    if isinstance(value, str):
        if not (-len(value) <= index < len(value)):
            raise InterpreterError(f"Index {index} out of bounds for string of length {len(value)}")
        return value[index]

    try:
        return value[index]
    except (KeyError, IndexError, TypeError) as error:
        lookup_error = error

    if isinstance(index, str) and isinstance(value, Mapping):
        close_matches = difflib.get_close_matches(index, list(value.keys()))
        if close_matches:
            return value[close_matches[0]]
    raise InterpreterError(f"Could not index {value} with '{index}'.") from lookup_error
|
|
|
|
|
def evaluate_name(name, state, static_tools, custom_tools):
    """
    Resolve a variable name.

    Lookup order matches `evaluate_call`: interpreter `state`, then `static_tools`, then `custom_tools`, then
    builtin exception classes. Because functions defined with `def` are stored in `custom_tools`, they can be used
    as values, e.g. passed as a `key=` callback. If nothing matches, the closest-spelled variable in `state` (per
    `difflib.get_close_matches`) is returned as a typo-tolerant fallback.

    Raises:
        InterpreterError: if the name is unknown and has no close match in `state`.
    """
    if name.id in state:
        return state[name.id]
    if name.id in static_tools:
        return static_tools[name.id]
    if name.id in custom_tools:
        return custom_tools[name.id]
    if name.id in ERRORS:
        return ERRORS[name.id]
    close_matches = difflib.get_close_matches(name.id, list(state.keys()))
    if close_matches:
        return state[close_matches[0]]
    raise InterpreterError(f"The variable `{name.id}` is not defined.")
|
|
|
|
|
# Comparison AST operator node type -> function implementing it.
_COMPARISON_OPERATORS = {
    ast.Eq: operator.eq,
    ast.NotEq: operator.ne,
    ast.Lt: operator.lt,
    ast.LtE: operator.le,
    ast.Gt: operator.gt,
    ast.GtE: operator.ge,
    ast.Is: operator.is_,
    ast.IsNot: operator.is_not,
    ast.In: lambda left, right: left in right,
    ast.NotIn: lambda left, right: left not in right,
}


def evaluate_condition(condition, state, static_tools, custom_tools):
    """
    Evaluate a (possibly chained) comparison such as `a < b <= c`.

    All comparator expressions are evaluated up front. The pairwise results are combined with `&`, which keeps
    element-wise results for array-like operands. Evaluation stops early once a plain `bool` result is False.

    Returns:
        A plain bool result as-is, a pandas Series as-is (so it can be used as a boolean mask), and otherwise the
        `.all()` reduction of the combined result.

    Raises:
        InterpreterError: if a comparison operator is not supported.
    """
    left = evaluate_ast(condition.left, state, static_tools, custom_tools)
    comparators = [evaluate_ast(c, state, static_tools, custom_tools) for c in condition.comparators]

    result = True
    current_left = left
    for op_node, comparator in zip(condition.ops, comparators):
        op = type(op_node)
        compare = _COMPARISON_OPERATORS.get(op)
        if compare is None:
            raise InterpreterError(f"Operator not supported: {op}")
        result = result & compare(current_left, comparator)
        current_left = comparator
        if isinstance(result, bool) and not result:
            break

    if isinstance(result, bool):
        return result
    # pandas is optional: the module-level `pd` only exists when it could be imported.
    pandas = globals().get("pd")
    if pandas is not None and isinstance(result, pandas.Series):
        return result
    return result.all()
|
|
|
|
|
def evaluate_if(if_statement, state, static_tools, custom_tools):
    """
    Evaluate an `if` statement by running either its body or its `else` branch.

    Returns:
        The last non-None value produced by a statement of the executed branch, or None.
    """
    chosen_branch = (
        if_statement.body
        if evaluate_ast(if_statement.test, state, static_tools, custom_tools)
        else if_statement.orelse
    )
    last_value = None
    for statement in chosen_branch:
        value = evaluate_ast(statement, state, static_tools, custom_tools)
        if value is not None:
            last_value = value
    return last_value
|
|
|
|
|
def evaluate_for(for_loop, state, static_tools, custom_tools):
    """
    Evaluate a `for` loop, including its optional `else:` clause.

    `continue` skips the rest of the current iteration, and `break` leaves the loop and also skips the `else`
    clause, as in Python.

    Returns:
        The last non-None value produced by a body (or `else`) statement, or None.
    """
    result = None
    iterator = evaluate_ast(for_loop.iter, state, static_tools, custom_tools)
    for item in iterator:
        set_value(for_loop.target, item, state, static_tools, custom_tools)
        # The try wraps the whole body so `continue` abandons the remaining statements of this iteration.
        try:
            for node in for_loop.body:
                line_result = evaluate_ast(node, state, static_tools, custom_tools)
                if line_result is not None:
                    result = line_result
        except BreakException:
            break
        except ContinueException:
            continue
    else:
        for node in for_loop.orelse:
            line_result = evaluate_ast(node, state, static_tools, custom_tools)
            if line_result is not None:
                result = line_result
    return result
|
|
|
|
|
def evaluate_listcomp(listcomp, state, static_tools, custom_tools):
    """
    Evaluate a comprehension node that has `elt` and `generators` (list comprehension or generator expression)
    into a list. Generator expressions are therefore evaluated eagerly.

    Generators nest left to right, and `if` clauses filter at their own level. Loop variables are bound with
    `set_value`, so nested tuple targets like `for i, (k, v) in ...` work. Loop variables live in copies of
    `state` and do not leak into the caller's namespace.
    """

    def inner_evaluate(generators, index, current_state):
        # All loop variables bound: produce one element.
        if index >= len(generators):
            return [evaluate_ast(listcomp.elt, current_state, static_tools, custom_tools)]
        generator = generators[index]
        iter_value = evaluate_ast(generator.iter, current_state, static_tools, custom_tools)
        result = []
        for value in iter_value:
            new_state = current_state.copy()
            set_value(generator.target, value, new_state, static_tools, custom_tools)
            if all(evaluate_ast(if_clause, new_state, static_tools, custom_tools) for if_clause in generator.ifs):
                result.extend(inner_evaluate(generators, index + 1, new_state))
        return result

    return inner_evaluate(listcomp.generators, 0, state)
|
|
|
|
|
def evaluate_try(try_node, state, static_tools, custom_tools):
    """
    Evaluate a `try` / `except` / `else` / `finally` statement.

    The first handler whose type matches (or a bare `except:`) runs, with the exception bound to its name if
    given. Unmatched exceptions propagate. `else` runs only when the body raised nothing, and `finally` always runs.
    The interpreter's own control-flow signals (`return`, `break`, `continue` inside the body) are never offered
    to user handlers, so `except Exception:` cannot swallow them.
    """
    try:
        for stmt in try_node.body:
            evaluate_ast(stmt, state, static_tools, custom_tools)
    except (BreakException, ContinueException, ReturnException):
        raise
    except Exception as e:
        for handler in try_node.handlers:
            if handler.type is None or isinstance(e, evaluate_ast(handler.type, state, static_tools, custom_tools)):
                if handler.name:
                    state[handler.name] = e
                for stmt in handler.body:
                    evaluate_ast(stmt, state, static_tools, custom_tools)
                break
        else:
            # No handler matched: re-raise the original exception with its traceback.
            raise
    else:
        for stmt in try_node.orelse:
            evaluate_ast(stmt, state, static_tools, custom_tools)
    finally:
        for stmt in try_node.finalbody:
            evaluate_ast(stmt, state, static_tools, custom_tools)
|
|
|
|
|
def evaluate_raise(raise_node, state, static_tools, custom_tools):
    """
    Evaluate a `raise` statement.

    - `raise exc` raises the evaluated exception (class or instance).
    - `raise exc from cause` chains it; `from None` suppresses the implicit context, as in Python.
    - A bare `raise` re-raises the exception currently being handled. Interpreted `except` bodies run inside
      `evaluate_try`'s own `except` clause, so that exception is visible via `sys.exc_info()`.

    Raises:
        InterpreterError: for a bare `raise` when no exception is being handled.
    """
    if raise_node.exc is None:
        if sys.exc_info()[1] is None:
            raise InterpreterError("Re-raise is not supported without an active exception")
        raise

    exc = evaluate_ast(raise_node.exc, state, static_tools, custom_tools)
    if raise_node.cause is not None:
        cause = evaluate_ast(raise_node.cause, state, static_tools, custom_tools)
        raise exc from cause
    raise exc
|
|
|
|
|
def evaluate_assert(assert_node, state, static_tools, custom_tools):
    """
    Evaluate an `assert` statement.

    Raises:
        AssertionError: if the test is falsy. The message is the evaluated `msg` expression when one is given,
            otherwise "Assertion failed: <test source>".
    """
    passed = evaluate_ast(assert_node.test, state, static_tools, custom_tools)
    if passed:
        return None
    if not assert_node.msg:
        # No explicit message: report the failing condition as source text.
        failing_source = ast.unparse(assert_node.test)
        raise AssertionError(f"Assertion failed: {failing_source}")
    raise AssertionError(evaluate_ast(assert_node.msg, state, static_tools, custom_tools))
|
|
|
|
|
def evaluate_with(with_node, state, static_tools, custom_tools):
    """
    Evaluate a `with` statement with one or more context managers.

    Managers are entered left to right and exited in reverse order through `contextlib.ExitStack`, so Python's
    semantics hold:
      - `__exit__` is called on the manager object itself, not on what `__enter__` returned;
      - a truthy `__exit__` return value suppresses the exception;
      - managers already entered are exited if a later one fails to enter.

    `as` targets are bound with `set_value`, so tuple targets work and static tool names cannot be overwritten.
    """
    with ExitStack() as stack:
        for item in with_node.items:
            manager = evaluate_ast(item.context_expr, state, static_tools, custom_tools)
            entered = stack.enter_context(manager)
            if item.optional_vars is not None:
                set_value(item.optional_vars, entered, state, static_tools, custom_tools)
        for stmt in with_node.body:
            evaluate_ast(stmt, state, static_tools, custom_tools)
|
|
|
|
|
def import_modules(expression, state, authorized_imports):
    """
    Execute an `import` / `from ... import ...` statement, restricted to authorized modules.

    A module is authorized when it, or any parent package of it, is listed in `authorized_imports`.

    - `import a.b` binds `a`, as in Python, but only if `a` itself is authorized; otherwise an alias is required.
    - `import a.b as x` binds `x` to `a.b`.
    - `from m import *` binds `m.__all__`, or else every public name of `m`.
    - Relative imports are rejected.

    Returns:
        None.

    Raises:
        InterpreterError: if the import is relative or not authorized.
    """

    def check_module_authorized(module_name):
        module_path = module_name.split(".")
        module_subpaths = [".".join(module_path[:i]) for i in range(1, len(module_path) + 1)]
        return any(subpath in authorized_imports for subpath in module_subpaths)

    if isinstance(expression, ast.Import):
        for alias in expression.names:
            if not check_module_authorized(alias.name):
                raise InterpreterError(
                    f"Import of {alias.name} is not allowed. Authorized imports are: {str(authorized_imports)}"
                )
            module = import_module(alias.name)
            if alias.asname:
                state[alias.asname] = module
                continue
            top_level = alias.name.split(".")[0]
            if top_level == alias.name:
                state[top_level] = module
            elif check_module_authorized(top_level):
                # `import a.b` binds the top-level package `a` (which now has `b` loaded as an attribute).
                state[top_level] = import_module(top_level)
            else:
                raise InterpreterError(
                    f"Import of {alias.name} without an alias would expose the unauthorized package "
                    f"{top_level}. Use `import {alias.name} as <name>` or `from ... import ...` instead."
                )
        return None
    elif isinstance(expression, ast.ImportFrom):
        if expression.level or expression.module is None:
            raise InterpreterError("Relative imports are not supported.")
        if not check_module_authorized(expression.module):
            raise InterpreterError(f"Import from {expression.module} is not allowed.")
        module = __import__(expression.module, fromlist=[alias.name for alias in expression.names])
        for alias in expression.names:
            if alias.name == "*":
                public_names = getattr(module, "__all__", None)
                if public_names is None:
                    public_names = [name for name in dir(module) if not name.startswith("_")]
                for name in public_names:
                    state[name] = getattr(module, name)
            else:
                state[alias.asname or alias.name] = getattr(module, alias.name)
        return None
|
|
|
|
|
def evaluate_dictcomp(dictcomp, state, static_tools, custom_tools):
    """
    Evaluate a dict comprehension `{key: value for ... in ... if ...}`.

    Generators nest left to right as in Python: each `for` clause iterates inside the scope produced by the
    previous one, and its `if` clauses filter before the next clause runs. Loop variables are bound with
    `set_value` in copies of `state`, so they do not leak into the caller's namespace. For each entry the key is
    evaluated before the value.
    """
    result = {}
    generators = dictcomp.generators

    def fill(index, scope):
        # Innermost level reached: every loop variable is bound, emit one entry.
        if index == len(generators):
            key = evaluate_ast(dictcomp.key, scope, static_tools, custom_tools)
            result[key] = evaluate_ast(dictcomp.value, scope, static_tools, custom_tools)
            return
        generator = generators[index]
        for item in evaluate_ast(generator.iter, scope, static_tools, custom_tools):
            inner_scope = scope.copy()
            set_value(generator.target, item, inner_scope, static_tools, custom_tools)
            if all(evaluate_ast(cond, inner_scope, static_tools, custom_tools) for cond in generator.ifs):
                fill(index + 1, inner_scope)

    fill(0, state)
    return result
|
|
|
|
|
def _evaluate_elements(nodes, state, static_tools, custom_tools):
    """Evaluate the elements of a list/tuple/set display, expanding `*iterable` entries in place."""
    values = []
    for node in nodes:
        if isinstance(node, ast.Starred):
            values.extend(evaluate_ast(node.value, state, static_tools, custom_tools))
        else:
            values.append(evaluate_ast(node, state, static_tools, custom_tools))
    return values


def evaluate_ast(
    expression: ast.AST,
    state: Dict[str, Any],
    static_tools: Dict[str, Callable],
    custom_tools: Dict[str, Callable],
    authorized_imports: List[str] = LIST_SAFE_MODULES,
):
    """
    Evaluate an abstract syntax tree using the content of the variables stored in a state and only evaluating a given
    set of functions.

    This function recurses through the nodes of the tree provided, dispatching each node type to its evaluator.
    Every call counts towards the global `OPERATIONS_COUNT` budget.

    Args:
        expression (`ast.AST`):
            The code to evaluate, as an abstract syntax tree.
        state (`Dict[str, Any]`):
            A dictionary mapping variable names to values. The `state` is updated if need be when the evaluation
            encounters assignments.
        static_tools (`Dict[str, Callable]`):
            Functions that may be called during the evaluation. Trying to change one of these tools will raise an error.
        custom_tools (`Dict[str, Callable]`):
            Functions that may be called during the evaluation. These tools can be overwritten.
        authorized_imports (`List[str]`):
            The list of modules that can be imported by the code. By default, only a few safe modules are allowed.
            Add more at your own risk!
            NOTE: the statement helpers (`evaluate_if`, `evaluate_for`, `create_function`, ...) call back into
            `evaluate_ast` without this argument. Imports nested inside blocks or function bodies are therefore
            checked against the default `LIST_SAFE_MODULES`.

    Returns:
        The value of the node; for statements, whatever the corresponding evaluator returns.

    Raises:
        InterpreterError: if the operation budget is exhausted or the node type is not supported.
    """
    global OPERATIONS_COUNT
    if OPERATIONS_COUNT >= MAX_OPERATIONS:
        raise InterpreterError(
            f"Reached the max number of operations of {MAX_OPERATIONS}. Maybe there is an infinite loop somewhere in the code, or you're just asking too many calculations."
        )
    OPERATIONS_COUNT += 1

    if isinstance(expression, ast.Assign):
        return evaluate_assign(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.AugAssign):
        return evaluate_augassign(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.AnnAssign):
        # `target: T = value`; the annotation is not evaluated, and a bare `target: T` binds nothing.
        if expression.value is None:
            return None
        value = evaluate_ast(expression.value, state, static_tools, custom_tools)
        set_value(expression.target, value, state, static_tools, custom_tools)
        return value
    elif isinstance(expression, ast.Call):
        return evaluate_call(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.Constant):
        return expression.value
    elif isinstance(expression, ast.Tuple):
        return tuple(_evaluate_elements(expression.elts, state, static_tools, custom_tools))
    elif isinstance(expression, (ast.ListComp, ast.GeneratorExp)):
        return evaluate_listcomp(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.SetComp):
        # Same node shape (`elt` + `generators`) as a list comprehension.
        return set(evaluate_listcomp(expression, state, static_tools, custom_tools))
    elif isinstance(expression, ast.UnaryOp):
        return evaluate_unaryop(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.Starred):
        return evaluate_ast(expression.value, state, static_tools, custom_tools)
    elif isinstance(expression, ast.BoolOp):
        return evaluate_boolop(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.Break):
        raise BreakException()
    elif isinstance(expression, ast.Continue):
        raise ContinueException()
    elif isinstance(expression, ast.Pass):
        return None
    elif isinstance(expression, ast.BinOp):
        return evaluate_binop(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.Compare):
        return evaluate_condition(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.Lambda):
        return evaluate_lambda(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.FunctionDef):
        return evaluate_function_def(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.Dict):
        # Keys and values are evaluated pairwise, left to right; a None key marks a `**mapping` entry.
        result = {}
        for key_node, value_node in zip(expression.keys, expression.values):
            if key_node is None:
                result.update(evaluate_ast(value_node, state, static_tools, custom_tools))
            else:
                key = evaluate_ast(key_node, state, static_tools, custom_tools)
                result[key] = evaluate_ast(value_node, state, static_tools, custom_tools)
        return result
    elif isinstance(expression, ast.Expr):
        return evaluate_ast(expression.value, state, static_tools, custom_tools)
    elif isinstance(expression, ast.For):
        return evaluate_for(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.FormattedValue):
        # One `{value!conversion:format_spec}` field of an f-string.
        value = evaluate_ast(expression.value, state, static_tools, custom_tools)
        if expression.conversion == ord("r"):
            value = repr(value)
        elif expression.conversion == ord("s"):
            value = str(value)
        elif expression.conversion == ord("a"):
            value = ascii(value)
        format_spec = ""
        if expression.format_spec is not None:
            # The spec is itself a JoinedStr and may contain nested fields (e.g. `{x:.{digits}f}`).
            format_spec = evaluate_ast(expression.format_spec, state, static_tools, custom_tools)
        return format(value, format_spec)
    elif isinstance(expression, ast.If):
        return evaluate_if(expression, state, static_tools, custom_tools)
    elif hasattr(ast, "Index") and isinstance(expression, ast.Index):
        return evaluate_ast(expression.value, state, static_tools, custom_tools)
    elif isinstance(expression, ast.JoinedStr):
        return "".join([str(evaluate_ast(v, state, static_tools, custom_tools)) for v in expression.values])
    elif isinstance(expression, ast.List):
        return _evaluate_elements(expression.elts, state, static_tools, custom_tools)
    elif isinstance(expression, ast.Name):
        return evaluate_name(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.Subscript):
        return evaluate_subscript(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.IfExp):
        test_val = evaluate_ast(expression.test, state, static_tools, custom_tools)
        if test_val:
            return evaluate_ast(expression.body, state, static_tools, custom_tools)
        return evaluate_ast(expression.orelse, state, static_tools, custom_tools)
    elif isinstance(expression, ast.Attribute):
        value = evaluate_ast(expression.value, state, static_tools, custom_tools)
        return getattr(value, expression.attr)
    elif isinstance(expression, ast.Slice):
        return slice(
            evaluate_ast(expression.lower, state, static_tools, custom_tools) if expression.lower is not None else None,
            evaluate_ast(expression.upper, state, static_tools, custom_tools) if expression.upper is not None else None,
            evaluate_ast(expression.step, state, static_tools, custom_tools) if expression.step is not None else None,
        )
    elif isinstance(expression, ast.DictComp):
        return evaluate_dictcomp(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.While):
        return evaluate_while(expression, state, static_tools, custom_tools)
    elif isinstance(expression, (ast.Import, ast.ImportFrom)):
        return import_modules(expression, state, authorized_imports)
    elif isinstance(expression, ast.ClassDef):
        return evaluate_class_def(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.Try):
        return evaluate_try(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.Raise):
        return evaluate_raise(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.Assert):
        return evaluate_assert(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.With):
        return evaluate_with(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.Set):
        return set(_evaluate_elements(expression.elts, state, static_tools, custom_tools))
    elif isinstance(expression, ast.Return):
        raise ReturnException(
            evaluate_ast(expression.value, state, static_tools, custom_tools) if expression.value else None
        )
    else:
        raise InterpreterError(f"{expression.__class__.__name__} is not supported.")
|
|
|
|
|
def evaluate_python_code(
    code: str,
    static_tools: Optional[Dict[str, Callable]] = None,
    custom_tools: Optional[Dict[str, Callable]] = None,
    state: Optional[Dict[str, Any]] = None,
    authorized_imports: List[str] = LIST_SAFE_MODULES,
):
    """
    Evaluate a python expression using the content of the variables stored in a state and only evaluating a given set
    of functions.

    This function will recurse through the nodes of the tree provided. The global print buffer and operation
    counter are reset at the start of every call.

    Args:
        code (`str`):
            The code to evaluate.
        static_tools (`Dict[str, Callable]`):
            The functions that may be called during the evaluation.
            These tools cannot be overwritten in the code: any assignment to their name will raise an error.
        custom_tools (`Dict[str, Callable]`):
            The functions that may be called during the evaluation.
            These tools can be overwritten in the code: any assignment to their name will overwrite them.
        state (`Dict[str, Any]`):
            A dictionary mapping variable names to values. The `state` should contain the initial inputs but will be
            updated by this function to contain all variables as they are evaluated.
            The print outputs will be stored in the state under the key 'print_outputs', truncated to
            `MAX_LEN_OUTPUT` characters (with a notice) when longer.
        authorized_imports (`List[str]`):
            Modules (and their submodules) the code may import at top level. Defaults to `LIST_SAFE_MODULES`.

    Returns:
        The value of the last top-level statement that was evaluated, or None for empty code.

    Raises:
        SyntaxError: if `code` cannot be parsed (chained to the original parser error).
        InterpreterError: if a statement fails with an interpreter error. The message includes the print outputs
            gathered so far (possibly truncated) and the source of the failing statement.
    """
    try:
        expression = ast.parse(code)
    except SyntaxError as e:
        raise SyntaxError(f"The code generated by the agent is not valid.\n{e}") from e
    if state is None:
        state = {}
    if static_tools is None:
        static_tools = {}
    if custom_tools is None:
        custom_tools = {}
    result = None

    global PRINT_OUTPUTS, OPERATIONS_COUNT
    PRINT_OUTPUTS = ""
    OPERATIONS_COUNT = 0
    truncation_notice = f"_Print outputs were over {MAX_LEN_OUTPUT} characters, so they have been truncated._"

    for node in expression.body:
        try:
            result = evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)
        except InterpreterError as e:
            msg = ""
            if len(PRINT_OUTPUTS) > 0:
                if len(PRINT_OUTPUTS) < MAX_LEN_OUTPUT:
                    msg += f"Print outputs:\n{PRINT_OUTPUTS}\n====\n"
                else:
                    msg += f"Print outputs:\n{PRINT_OUTPUTS[:MAX_LEN_OUTPUT]}\n{truncation_notice}\n====\n"
            msg += f"EXECUTION FAILED:\nEvaluation stopped at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}"
            raise InterpreterError(msg) from e
        finally:
            # Keep `state["print_outputs"]` current even when a statement fails.
            if len(PRINT_OUTPUTS) < MAX_LEN_OUTPUT:
                state["print_outputs"] = PRINT_OUTPUTS
            else:
                state["print_outputs"] = PRINT_OUTPUTS[:MAX_LEN_OUTPUT] + f"\n{truncation_notice}"

    return result
|
|