Spaces:
Running
Running
#!/usr/bin/env python | |
# coding=utf-8 | |
# Copyright 2024 The HuggingFace Inc. team. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import ast | |
import base64 | |
import importlib.metadata | |
import importlib.util | |
import inspect | |
import json | |
import os | |
import re | |
import textwrap | |
import types | |
from functools import lru_cache | |
from io import BytesIO | |
from typing import TYPE_CHECKING, Any, Dict, Tuple, Union | |
if TYPE_CHECKING: | |
from smolagents.memory import AgentLogger | |
__all__ = ["AgentError"] | |
def _is_package_available(package_name: str) -> bool:
    """Tell whether a distribution named `package_name` is installed.

    Args:
        package_name: Distribution name to look up in the installed package metadata.

    Returns:
        bool: `True` when `importlib.metadata.version` resolves the distribution,
        `False` when it raises `PackageNotFoundError`.
    """
    try:
        importlib.metadata.version(package_name)
    except importlib.metadata.PackageNotFoundError:
        return False
    else:
        return True
def _is_pillow_available():
    """Tell whether the `PIL` module can be located (it is looked up, not imported)."""
    pil_spec = importlib.util.find_spec("PIL")
    return pil_spec is not None
# Names of standard-library modules forming the base list of built-in modules.
# NOTE(review): how this list is consumed is not visible in this file.
BASE_BUILTIN_MODULES = [
    "collections",
    "datetime",
    "itertools",
    "math",
    "queue",
    "random",
    "re",
    "stat",
    "statistics",
    "time",
    "unicodedata",
]
def escape_code_brackets(text: str) -> str:
    """Backslash-escape bracketed segments, except those made only of Rich style words.

    Every `[...]` group in `text` is inspected: once the known style keywords,
    whitespace and `#rrggbb` colors are removed from its content, a group with
    nothing left is kept verbatim, while any other group is rewritten as `\\[...\\]`.

    Args:
        text: The text to process.

    Returns:
        str: The text with non-style bracket groups escaped.
    """
    style_words = r"bold|red|green|blue|yellow|magenta|cyan|white|black|italic|dim|\s|#[0-9a-fA-F]{6}"

    def _escape_unless_style(found):
        inner = found.group(1)
        leftover = re.sub(style_words, "", inner)
        if leftover.strip():
            # Something other than style words remains: treat the group as code.
            return f"\\[{inner}\\]"
        return f"[{inner}]"

    return re.sub(r"\[([^\]]*)\]", _escape_unless_style, text)
class AgentError(Exception):
    """Base class for other agent-related exceptions.

    Args:
        message: Description of the error; stored on `self.message` and passed to `Exception`.
        logger: Optional object exposing `log_error(message)`. When provided, the
            message is logged as soon as the exception is constructed. It is optional
            so that the exception can also be raised from code that has no logger at
            hand (`parse_json_tool_call` in this module raises `AgentParsingError`
            with a message only).
    """

    def __init__(self, message, logger: Union["AgentLogger", None] = None):
        super().__init__(message)
        self.message = message
        if logger is not None:
            logger.log_error(message)

    def dict(self) -> Dict[str, str]:
        """Return the error as a dict holding its class name and its message."""
        return {"type": self.__class__.__name__, "message": str(self.message)}
class AgentParsingError(AgentError):
    """Agent error signalling a failure while parsing."""
class AgentExecutionError(AgentError):
    """Agent error signalling a failure during execution."""
class AgentMaxStepsError(AgentError):
    """Exception raised for errors related to the agent's maximum number of steps"""

    pass
class AgentGenerationError(AgentError):
    """Agent error signalling a failure during generation."""
def make_json_serializable(obj: Any) -> Any:
    """Recursively convert `obj` into something `json` can serialize.

    Args:
        obj: Any value.

    Returns:
        - `None`, numbers and booleans unchanged;
        - strings unchanged, unless they start/end like a JSON object or array
          and decode successfully, in which case the decoded value is converted;
        - lists and tuples as lists of converted items;
        - dicts with keys turned into strings and values converted;
        - objects with a `__dict__` as a dict of their converted attributes plus
          a `"_type"` entry holding the class name;
        - anything else as `str(obj)`.
    """
    if obj is None:
        return None

    if isinstance(obj, str):
        try:
            looks_like_object = obj.startswith("{") and obj.endswith("}")
            looks_like_array = obj.startswith("[") and obj.endswith("]")
            if looks_like_object or looks_like_array:
                return make_json_serializable(json.loads(obj))
        except json.JSONDecodeError:
            # Not valid JSON after all: keep the raw string.
            pass
        return obj

    if isinstance(obj, (int, float, bool)):
        return obj

    if isinstance(obj, (list, tuple)):
        return [make_json_serializable(element) for element in obj]

    if isinstance(obj, dict):
        return {str(key): make_json_serializable(value) for key, value in obj.items()}

    if hasattr(obj, "__dict__"):
        # Custom object: expose its attributes next to its class name.
        serialized = {"_type": obj.__class__.__name__}
        serialized.update((attr, make_json_serializable(value)) for attr, value in obj.__dict__.items())
        return serialized

    # Last resort for any other type.
    return str(obj)
def parse_json_blob(json_blob: str) -> Dict[str, str]:
    """Extract and decode the JSON object embedded in `json_blob`.

    The text between the first `{` and the last `}` is kept, escaped double
    quotes (`\\"`) are replaced by single quotes, and the result is decoded with
    `json.loads(..., strict=False)`.

    Args:
        json_blob: Text containing a JSON object, possibly surrounded by other text.

    Returns:
        The decoded JSON data.

    Raises:
        ValueError: If no `{...}` span is found, if decoding fails (the message
            quotes the part of the blob where decoding stopped), or if any other
            error occurs while parsing.
    """
    try:
        first_accolade_index = json_blob.find("{")
        last_accolade_index = json_blob.rfind("}")
        if first_accolade_index == -1 or last_accolade_index < first_accolade_index:
            # Re-raised below with the generic "Error in parsing the JSON blob" prefix.
            raise ValueError("no JSON object delimited by '{' and '}' was found in the input")
        json_blob = json_blob[first_accolade_index : last_accolade_index + 1].replace('\\"', "'")
        json_data = json.loads(json_blob, strict=False)
        return json_data
    except json.JSONDecodeError as e:
        # `e.pos` indexes the trimmed blob: decoding only starts after the reassignment above.
        place = e.pos
        # Clamp slice starts at 0: a negative start would wrap around to the end of the blob.
        if json_blob[max(place - 1, 0) : place + 2] == "},\n":
            raise ValueError(
                "JSON is invalid: you probably tried to provide multiple tool calls in one action. PROVIDE ONLY ONE TOOL CALL."
            ) from e
        raise ValueError(
            f"The JSON blob you used is invalid due to the following error: {e}.\n"
            f"JSON blob was: {json_blob}, decoding failed on that specific part of the blob:\n"
            f"'{json_blob[max(place - 4, 0) : place + 5]}'."
        ) from e
    except Exception as e:
        raise ValueError(f"Error in parsing the JSON blob: {e}") from e
def parse_code_blobs(code_blob: str) -> str:
    """Extract the code contained in an LLM output.

    Args:
        code_blob: Raw text produced by the LLM.

    Returns:
        str: The stripped contents of every fenced block (three backticks,
        optionally tagged `py` or `python`) joined by blank lines, or
        `code_blob` itself when it has no fenced block but parses as Python.

    Raises:
        ValueError: If no fenced block is found and the text is not valid Python.
            The message shows the expected format, with a dedicated hint when the
            text mentions both "final" and "answer".
    """
    pattern = r"```(?:py|python)?\n(.*?)\n```"
    fenced_snippets = re.findall(pattern, code_blob, re.DOTALL)
    if fenced_snippets:
        return "\n\n".join(snippet.strip() for snippet in fenced_snippets)

    # No fenced block: maybe the LLM outputted code directly.
    try:
        ast.parse(code_blob)
    except SyntaxError:
        pass
    else:
        return code_blob

    complaint = (
        f"Your code snippet is invalid, because the regex pattern {pattern} was not found in it.\n"
        "Here is your code snippet:\n"
        f"{code_blob}\n"
    )
    if "final" in code_blob and "answer" in code_blob:
        raise ValueError(
            complaint
            + "It seems like you're trying to return the final answer, you can do it as follows:\n"
            + "Code:\n"
            + "```py\n"
            + 'final_answer("YOUR FINAL ANSWER HERE")\n'
            + "```<end_code>"
        )
    raise ValueError(
        complaint
        + "Make sure to include code with the correct pattern, for instance:\n"
        + "Thoughts: Your thoughts\n"
        + "Code:\n"
        + "```py\n"
        + "# Your python code here\n"
        + "```<end_code>"
    )
def parse_json_tool_call(json_blob: str) -> Tuple[str, Union[str, None]]:
    """Extract the tool name and tool arguments from a JSON tool call.

    Markdown code fences are removed from `json_blob` before it is decoded with
    `parse_json_blob`. The tool name is read from one of the keys `action`,
    `tool_name`, `tool`, `name`, `function`, and the arguments from one of
    `action_input`, `tool_arguments`, `tool_args`, `parameters`. When several
    candidate keys are present, the one listed last is used.

    Args:
        json_blob: Text containing the JSON tool call.

    Returns:
        A `(tool_name, tool_arguments)` tuple; `tool_arguments` is `None` when no
        arguments key is present.

    Raises:
        AgentParsingError: If none of the tool name keys is present.
    """
    json_blob = json_blob.replace("```json", "").replace("```", "")
    tool_call = parse_json_blob(json_blob)

    def _last_present(candidates):
        # Scanning backwards and stopping at the first hit selects the last listed key.
        return next((key for key in reversed(candidates) if key in tool_call), None)

    name_key = _last_present(("action", "tool_name", "tool", "name", "function"))
    arguments_key = _last_present(("action_input", "tool_arguments", "tool_args", "parameters"))

    if name_key is None:
        error_msg = "No tool name key found in tool call!" + f" Tool call: {json_blob}"
        raise AgentParsingError(error_msg)

    if arguments_key is None:
        return tool_call[name_key], None
    return tool_call[name_key], tool_call[arguments_key]
MAX_LENGTH_TRUNCATE_CONTENT = 20000


def truncate_content(content: str, max_length: int = MAX_LENGTH_TRUNCATE_CONTENT) -> str:
    """Shorten `content` by keeping its beginning and its end.

    Args:
        content: The text to shorten.
        max_length: Maximum number of characters of `content` that are kept.
            Negative values are treated as 0.

    Returns:
        str: `content` unchanged when it is at most `max_length` characters long.
        Otherwise the first `max_length // 2` characters, a truncation notice,
        and the last remaining characters of the budget. The notice is added on
        top of the kept characters, so the result is longer than `max_length`.
    """
    if len(content) <= max_length:
        return content

    budget = max(max_length, 0)
    head_length = budget // 2
    tail_length = budget - head_length
    # Use an explicit start index: with an empty tail, `content[-0:]` would be the whole string.
    tail_start = len(content) - tail_length
    return (
        content[:head_length]
        + f"\n..._This content has been truncated to stay below {max_length} characters_...\n"
        + content[tail_start:]
    )
class ImportFinder(ast.NodeVisitor):
    """AST visitor collecting the base names of the packages imported by some code.

    After `visit(tree)`, `packages` holds the part before the first dot of every
    module named in an `import x` or absolute `from x import y` statement.
    """

    def __init__(self):
        # Base package names found so far.
        self.packages = set()

    def visit_Import(self, node):
        for alias in node.names:
            # Get the base package name (before any dots)
            self.packages.add(alias.name.split(".")[0])

    def visit_ImportFrom(self, node):
        # `from . import y` has no module name. Relative imports (level > 0) refer to
        # the current package, not to a top-level package, so they are skipped.
        if node.module and not node.level:
            self.packages.add(node.module.split(".")[0])
def get_method_source(method):
    """Return the source of `method`, unwrapping bound methods to their function first."""
    target = method.__func__ if isinstance(method, types.MethodType) else method
    return get_source(target)
def is_same_method(method1, method2):
    """Tell whether two methods have the same source code, decorator lines excluded.

    Returns:
        bool: `True` when both sources are equal once every line starting with
        `@` is dropped; `False` when they differ or when retrieving a source
        raises `TypeError` or `OSError`.
    """

    def _without_decorators(source):
        kept_lines = [line for line in source.split("\n") if not line.strip().startswith("@")]
        return "\n".join(kept_lines)

    try:
        first_source = _without_decorators(get_method_source(method1))
        second_source = _without_decorators(get_method_source(method2))
        return first_source == second_source
    except (TypeError, OSError):
        return False
def is_same_item(item1, item2):
    """Compare two class items: by source when both are callable, with `==` otherwise."""
    both_callable = callable(item1) and callable(item2)
    if not both_callable:
        return item1 == item2
    return is_same_method(item1, item2)
def _same_compiled_code(func, other) -> bool:
    """Tell whether two callables expose code objects with equal bytecode, constants and names."""
    code = getattr(func, "__code__", None)
    other_code = getattr(other, "__code__", None)
    if code is None or other_code is None:
        # Without code objects there is nothing to compare: treat the callables as different.
        return False
    # Bytecode alone is not enough: `return 1` and `return 2` share the same `co_code`.
    return (
        code.co_code == other_code.co_code
        and code.co_consts == other_code.co_consts
        and code.co_names == other_code.co_names
    )


def instance_to_source(instance, base_cls=None):
    """Convert an instance to its class source code representation.

    The generated source contains, in order: an import of `base_cls` (when given),
    one `import` line per package imported inside the class code, the class header,
    the class docstring, the class-level attributes and the methods defined on the
    class itself. Attributes and methods found unchanged on `base_cls` are omitted.

    Args:
        instance: Object whose class is rendered.
        base_cls: Optional class used as the single base of the generated class.

    Returns:
        str: The source code of the class.
    """
    cls = instance.__class__
    class_name = cls.__name__

    # Start building class lines
    class_lines = []
    if base_cls:
        class_lines.append(f"class {class_name}({base_cls.__name__}):")
    else:
        class_lines.append(f"class {class_name}:")

    # Add docstring if it exists and differs from base
    if cls.__doc__ and (not base_cls or cls.__doc__ != base_cls.__doc__):
        class_lines.append(f'    """{cls.__doc__}"""')

    # Add class-level attributes
    class_attrs = {
        name: value
        for name, value in cls.__dict__.items()
        if not name.startswith("__")
        and not callable(value)
        and not (base_cls and hasattr(base_cls, name) and getattr(base_cls, name) == value)
    }

    for name, value in class_attrs.items():
        if isinstance(value, str):
            # multiline value
            if "\n" in value:
                escaped_value = value.replace('"""', r"\"\"\"")  # Escape triple quotes
                class_lines.append(f'    {name} = """{escaped_value}"""')
            else:
                class_lines.append(f"    {name} = {json.dumps(value)}")
        else:
            class_lines.append(f"    {name} = {repr(value)}")

    if class_attrs:
        class_lines.append("")

    # Add methods defined on the class itself, skipping those identical to the base class's
    methods = {
        name: func
        for name, func in cls.__dict__.items()
        if callable(func)
        and not (base_cls and hasattr(base_cls, name) and _same_compiled_code(getattr(base_cls, name), func))
    }

    for method in methods.values():
        method_source = get_source(method)
        # Clean up the indentation
        method_lines = method_source.split("\n")
        first_line = method_lines[0]
        indent = len(first_line) - len(first_line.lstrip())
        method_lines = [line[indent:] for line in method_lines]
        method_source = "\n".join(["    " + line if line.strip() else line for line in method_lines])
        class_lines.append(method_source)
        class_lines.append("")

    # A bare class header is not valid Python: give an otherwise empty class a body
    if len(class_lines) == 1:
        class_lines.append("    pass")

    # Find required imports using ImportFinder
    import_finder = ImportFinder()
    import_finder.visit(ast.parse("\n".join(class_lines)))
    required_imports = import_finder.packages

    # Build final code with imports
    final_lines = []

    # Add base class import if needed
    if base_cls:
        final_lines.append(f"from {base_cls.__module__} import {base_cls.__name__}")

    # Add discovered imports, sorted so that the output is deterministic
    for package in sorted(required_imports):
        final_lines.append(f"import {package}")

    if final_lines:  # Add empty line after imports
        final_lines.append("")

    # Add the class code
    final_lines.extend(class_lines)

    return "\n".join(final_lines)
def get_source(obj) -> str:
    """Return the source code of a class or callable object (e.g.: function, method).

    `inspect.getsource` is tried first. If it fails with `OSError` (as can happen
    in a dynamic environment such as Jupyter or IPython), the definition is looked
    up by name in the input history of the active IPython shell.

    Args:
        obj: A class or callable object (e.g.: function, method)

    Returns:
        str: The source code of the object, dedented and stripped

    Raises:
        TypeError: If object is not a class or callable
        OSError: If source code cannot be retrieved from any source
        ValueError: If source cannot be found in IPython history

    Note:
        TODO: handle Python standard REPL
    """
    is_class_or_callable = isinstance(obj, type) or callable(obj)
    if not is_class_or_callable:
        raise TypeError(f"Expected class or callable, got {type(obj)}")

    try:
        raw_source = inspect.getsource(obj)
    except OSError as e:
        # Keep the exception around: it is raised if the fallback fails too.
        inspect_error = e
    else:
        return textwrap.dedent(raw_source).strip()

    try:
        import IPython

        shell = IPython.get_ipython()
        if not shell:
            raise ImportError("No active IPython shell found")

        history = "\n".join(shell.user_ns.get("In", [])).strip()
        if not history:
            raise ValueError("No code cells found in IPython session")

        history_lines = history.split("\n")
        for node in ast.walk(ast.parse(history)):
            if isinstance(node, (ast.ClassDef, ast.FunctionDef)) and node.name == obj.__name__:
                snippet = "\n".join(history_lines[node.lineno - 1 : node.end_lineno])
                return textwrap.dedent(snippet).strip()
        raise ValueError(f"Could not find source code for {obj.__name__} in IPython history")
    except ImportError:
        # IPython (or an active shell) is not available: raise the original inspect error
        raise inspect_error
    except ValueError as e:
        # IPython is available but the source was not found: raise that error instead
        raise e from inspect_error
def encode_image_base64(image):
    """Save `image` as PNG into memory and return the bytes as a base64 string.

    Args:
        image: Object exposing `save(fp, format=...)`.
            NOTE(review): presumably a PIL image — confirm against callers.

    Returns:
        str: The base64 encoding of the saved bytes, decoded as UTF-8 text.
    """
    png_buffer = BytesIO()
    image.save(png_buffer, format="PNG")
    encoded = base64.b64encode(png_buffer.getvalue())
    return encoded.decode("utf-8")
def make_image_url(base64_image):
    """Wrap base64-encoded image data into a `data:image/png;base64,` URL."""
    return "data:image/png;base64,{}".format(base64_image)
def make_init_file(folder: str):
    """Create `folder` (and missing parents) and put an empty `__init__.py` in it.

    Args:
        folder: Path of the directory that receives the `__init__.py` file.
    """
    os.makedirs(folder, exist_ok=True)
    init_path = os.path.join(folder, "__init__.py")
    # Opening in write mode creates the file, or empties it if it already exists.
    with open(init_path, "w"):
        pass