#!/usr/bin/env python
# coding=utf-8

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
import base64
import importlib.metadata
import importlib.util
import inspect
import json
import os
import re
import textwrap
import types
from functools import lru_cache
from io import BytesIO
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union


if TYPE_CHECKING:
    from smolagents.memory import AgentLogger


__all__ = ["AgentError"]


@lru_cache
def _is_package_available(package_name: str) -> bool:
    """Return True when `importlib.metadata` can resolve a version for `package_name`.

    The result is memoized per package name by `lru_cache`.
    """
    try:
        importlib.metadata.version(package_name)
        return True
    except importlib.metadata.PackageNotFoundError:
        return False


@lru_cache
def _is_pillow_available():
    """Return True when an import spec for the `PIL` module can be found (memoized)."""
    return importlib.util.find_spec("PIL") is not None


# Names of standard-library modules.
# NOTE(review): presumably the default import allow-list for executed code; confirm against callers.
BASE_BUILTIN_MODULES = [
    "collections",
    "datetime",
    "itertools",
    "math",
    "queue",
    "random",
    "re",
    "stat",
    "statistics",
    "time",
    "unicodedata",
]


def escape_code_brackets(text: str) -> str:
    """Escapes square brackets in code segments while preserving Rich styling tags.

    A bracketed span whose content consists only of the recognised style keywords,
    whitespace and `#rrggbb` colours is left untouched; any other bracketed span gets
    its brackets backslash-escaped.
    """

    def replace_bracketed_content(match):
        content = match.group(1)
        # Strip everything that looks like styling; whatever remains is "real" content.
        cleaned = re.sub(
            r"bold|red|green|blue|yellow|magenta|cyan|white|black|italic|dim|\s|#[0-9a-fA-F]{6}", "", content
        )
        return f"\\[{content}\\]" if cleaned.strip() else f"[{content}]"

    return re.sub(r"\[([^\]]*)\]", replace_bracketed_content, text)


class AgentError(Exception):
    """Base class for other agent-related exceptions"""

    def __init__(self, message, logger: Optional["AgentLogger"] = None):
        """Store `message` and, when a logger is given, report it through `logger.log_error`.

        Args:
            message: The error message; kept on `self.message`.
            logger: Optional object exposing `log_error(message)`. It is optional so that
                code paths without access to a logger can still raise agent errors.
        """
        super().__init__(message)
        self.message = message
        if logger is not None:
            logger.log_error(message)

    def dict(self) -> Dict[str, str]:
        """Return the exception class name and message as a plain dictionary."""
        return {"type": self.__class__.__name__, "message": str(self.message)}


class AgentParsingError(AgentError):
    """Exception raised for errors in parsing in the agent"""

    pass


class AgentExecutionError(AgentError):
    """Exception raised for errors in execution in the agent"""

    pass


class AgentMaxStepsError(AgentError):
    """Exception raised for errors related to the maximum number of steps of the agent"""

    pass


class AgentGenerationError(AgentError):
    """Exception raised for errors in generation in the agent"""

    pass


def make_json_serializable(obj: Any) -> Any:
    """Recursive function to make objects JSON serializable

    Strings that look like a JSON object/array are parsed and converted recursively,
    lists and tuples become lists, dict keys are stringified, objects with a `__dict__`
    become a dict tagged with `_type`, and anything else is converted with `str`.
    """
    if obj is None:
        return None
    elif isinstance(obj, (str, int, float, bool)):
        # Try to parse string as JSON if it looks like a JSON object/array
        if isinstance(obj, str):
            try:
                if (obj.startswith("{") and obj.endswith("}")) or (obj.startswith("[") and obj.endswith("]")):
                    parsed = json.loads(obj)
                    return make_json_serializable(parsed)
            except json.JSONDecodeError:
                # Not valid JSON after all: keep the raw string.
                pass
        return obj
    elif isinstance(obj, (list, tuple)):
        return [make_json_serializable(item) for item in obj]
    elif isinstance(obj, dict):
        return {str(k): make_json_serializable(v) for k, v in obj.items()}
    elif hasattr(obj, "__dict__"):
        # For custom objects, convert their __dict__ to a serializable format
        return {"_type": obj.__class__.__name__, **{k: make_json_serializable(v) for k, v in obj.__dict__.items()}}
    else:
        # For any other type, convert to string
        return str(obj)


def parse_json_blob(json_blob: str) -> Dict[str, str]:
    """Extract and decode the JSON object embedded in `json_blob`.

    The span from the first `{` to the last `}` is kept, escaped double quotes (`\\"`)
    are replaced by single quotes, and the result is decoded with `strict=False`.

    Args:
        json_blob: Text containing a JSON object, possibly surrounded by other text.

    Returns:
        The decoded JSON data.

    Raises:
        ValueError: If no brace-delimited object is present, if the JSON is invalid,
            or if any other error occurs while parsing.
    """
    try:
        first_accolade_index = json_blob.find("{")
        last_accolade_index = json_blob.rfind("}")
        if first_accolade_index == -1 or last_accolade_index == -1:
            # Caught by the generic handler below, which adds the common message prefix.
            raise ValueError("no JSON object delimited by curly braces was found in the blob.")
        json_blob = json_blob[first_accolade_index : last_accolade_index + 1].replace('\\"', "'")
        json_data = json.loads(json_blob, strict=False)
        return json_data
    except json.JSONDecodeError as e:
        place = e.pos
        # Clamp slice starts at 0: a negative start would index from the end of the blob.
        if json_blob[max(place - 1, 0) : place + 2] == "},\n":
            raise ValueError(
                "JSON is invalid: you probably tried to provide multiple tool calls in one action. PROVIDE ONLY ONE TOOL CALL."
            ) from e
        raise ValueError(
            f"The JSON blob you used is invalid due to the following error: {e}.\n"
            f"JSON blob was: {json_blob}, decoding failed on that specific part of the blob:\n"
            f"'{json_blob[max(place - 4, 0) : place + 5]}'."
        ) from e
    except Exception as e:
        raise ValueError(f"Error in parsing the JSON blob: {e}") from e


def parse_code_blobs(code_blob: str) -> str:
    """Parses the LLM's output to get any code blob inside. Will return the code directly if it's code.

    Args:
        code_blob: Raw text, expected to contain one or more fenced code blocks.

    Returns:
        The stripped contents of all fenced blocks joined by blank lines, or `code_blob`
        itself when it contains no fence but parses as Python.

    Raises:
        ValueError: If no fenced block is found and `code_blob` is not valid Python.
    """
    pattern = r"```(?:py|python)?\n(.*?)\n```"
    matches = re.findall(pattern, code_blob, re.DOTALL)
    if len(matches) == 0:
        try:  # Maybe the LLM outputted a code blob directly
            ast.parse(code_blob)
            return code_blob
        except SyntaxError:
            pass

        if "final" in code_blob and "answer" in code_blob:
            raise ValueError(
                f"""
Your code snippet is invalid, because the regex pattern {pattern} was not found in it.
Here is your code snippet:
{code_blob}
It seems like you're trying to return the final answer, you can do it as follows:
Code:
```py
final_answer("YOUR FINAL ANSWER HERE")
```""".strip()
            )
        raise ValueError(
            f"""
Your code snippet is invalid, because the regex pattern {pattern} was not found in it.
Here is your code snippet:
{code_blob}
Make sure to include code with the correct pattern, for instance:
Thoughts: Your thoughts
Code:
```py
# Your python code here
```""".strip()
        )
    return "\n\n".join(match.strip() for match in matches)


def parse_json_tool_call(json_blob: str) -> Tuple[str, Union[str, None]]:
    """Extract the tool name and tool arguments from a JSON tool call.

    Markdown code fences are removed before decoding. Several key spellings are accepted
    for both the name and the arguments; when more than one is present, the last one in
    the candidate list wins.

    Args:
        json_blob: Text containing the JSON tool call.

    Returns:
        A `(tool_name, tool_arguments)` tuple; `tool_arguments` is None when no
        arguments key is present.

    Raises:
        ValueError: If the JSON cannot be decoded (raised by `parse_json_blob`).
        AgentParsingError: If no tool name key is present in the decoded call.
    """
    json_blob = json_blob.replace("```json", "").replace("```", "")
    tool_call = parse_json_blob(json_blob)
    tool_name_key, tool_arguments_key = None, None
    for possible_tool_name_key in ["action", "tool_name", "tool", "name", "function"]:
        if possible_tool_name_key in tool_call:
            tool_name_key = possible_tool_name_key
    for possible_tool_arguments_key in [
        "action_input",
        "tool_arguments",
        "tool_args",
        "parameters",
    ]:
        if possible_tool_arguments_key in tool_call:
            tool_arguments_key = possible_tool_arguments_key
    if tool_name_key is not None:
        if tool_arguments_key is not None:
            return tool_call[tool_name_key], tool_call[tool_arguments_key]
        else:
            return tool_call[tool_name_key], None
    error_msg = "No tool name key found in tool call!" + f" Tool call: {json_blob}"
    # No logger is available in this function, so the error is raised without one.
    raise AgentParsingError(error_msg)


MAX_LENGTH_TRUNCATE_CONTENT = 20000


def truncate_content(content: str, max_length: int = MAX_LENGTH_TRUNCATE_CONTENT) -> str:
    """Shorten `content` by keeping its head and tail around a truncation notice.

    Content no longer than `max_length` is returned unchanged. Otherwise the first
    `max_length // 2` characters and the trailing characters selected by
    `content[-max_length // 2 :]` are kept, separated by a notice (the notice itself
    is not counted against `max_length`).
    """
    if len(content) <= max_length:
        return content
    else:
        return (
            content[: max_length // 2]
            + f"\n..._This content has been truncated to stay below {max_length} characters_...\n"
            + content[-max_length // 2 :]
        )


class ImportFinder(ast.NodeVisitor):
    """AST visitor that collects the top-level package name of every import it visits."""

    def __init__(self):
        self.packages = set()

    def visit_Import(self, node):
        """Record the base package of each name in an `import x.y` statement."""
        for alias in node.names:
            # Get the base package name (before any dots)
            base_package = alias.name.split(".")[0]
            self.packages.add(base_package)

    def visit_ImportFrom(self, node):
        """Record the base package of a `from x import y` statement (skipped when `node.module` is empty)."""
        if node.module:  # for "from x import y" statements
            # Get the base package name (before any dots)
            base_package = node.module.split(".")[0]
            self.packages.add(base_package)


def get_method_source(method):
    """Get source code for a method, including bound methods."""
    if isinstance(method, types.MethodType):
        # Bound methods wrap the underlying function; unwrap it first.
        method = method.__func__
    return get_source(method)


def is_same_method(method1, method2):
    """Compare two methods by their source code.

    Lines starting with `@` (decorators) are ignored. Returns False when the source of
    either method cannot be retrieved (`TypeError` or `OSError`).
    """
    try:
        source1 = get_method_source(method1)
        source2 = get_method_source(method2)

        # Remove method decorators if any
        source1 = "\n".join(line for line in source1.split("\n") if not line.strip().startswith("@"))
        source2 = "\n".join(line for line in source2.split("\n") if not line.strip().startswith("@"))

        return source1 == source2
    except (TypeError, OSError):
        return False


def is_same_item(item1, item2):
    """Compare two class items (methods or attributes) for equality."""
    if callable(item1) and callable(item2):
        return is_same_method(item1, item2)
    else:
        return item1 == item2


def instance_to_source(instance, base_cls=None):
    """Convert an instance to its class source code representation.

    Args:
        instance: Object whose class is rendered as source code.
        base_cls: Optional base class. When given, the generated class inherits from it,
            an import for it is emitted, and attributes/methods equal to the base's are
            omitted.

    Returns:
        The generated source: import lines (if any), a blank line, then the class
        definition.
    """
    cls = instance.__class__
    class_name = cls.__name__

    # Start building class lines
    class_lines = []
    if base_cls:
        class_lines.append(f"class {class_name}({base_cls.__name__}):")
    else:
        class_lines.append(f"class {class_name}:")

    # Add docstring if it exists and differs from base
    if cls.__doc__ and (not base_cls or cls.__doc__ != base_cls.__doc__):
        class_lines.append(f'    """{cls.__doc__}"""')

    # Add class-level attributes
    class_attrs = {
        name: value
        for name, value in cls.__dict__.items()
        if not name.startswith("__")
        and not callable(value)
        and not (base_cls and hasattr(base_cls, name) and getattr(base_cls, name) == value)
    }

    for name, value in class_attrs.items():
        if isinstance(value, str):
            # multiline value
            if "\n" in value:
                escaped_value = value.replace('"""', r"\"\"\"")  # Escape triple quotes
                class_lines.append(f'    {name} = """{escaped_value}"""')
            else:
                class_lines.append(f"    {name} = {json.dumps(value)}")
        else:
            class_lines.append(f"    {name} = {repr(value)}")

    if class_attrs:
        class_lines.append("")

    # Add methods
    methods = {
        name: func
        for name, func in cls.__dict__.items()
        if callable(func)
        and not (
            base_cls and hasattr(base_cls, name) and getattr(base_cls, name).__code__.co_code == func.__code__.co_code
        )
    }

    for name, method in methods.items():
        method_source = get_source(method)
        # Clean up the indentation
        method_lines = method_source.split("\n")
        first_line = method_lines[0]
        indent = len(first_line) - len(first_line.lstrip())
        method_lines = [line[indent:] for line in method_lines]
        method_source = "\n".join(["    " + line if line.strip() else line for line in method_lines])
        class_lines.append(method_source)
        class_lines.append("")

    # Find required imports using ImportFinder
    import_finder = ImportFinder()
    import_finder.visit(ast.parse("\n".join(class_lines)))
    required_imports = import_finder.packages

    # Build final code with imports
    final_lines = []

    # Add base class import if needed
    if base_cls:
        final_lines.append(f"from {base_cls.__module__} import {base_cls.__name__}")

    # Add discovered imports
    for package in required_imports:
        final_lines.append(f"import {package}")

    if final_lines:  # Add empty line after imports
        final_lines.append("")

    # Add the class code
    final_lines.extend(class_lines)

    return "\n".join(final_lines)


def get_source(obj) -> str:
    """Get the source code of a class or callable object (e.g.: function, method).
    First attempts to get the source code using `inspect.getsource`.
    In a dynamic environment (e.g.: Jupyter, IPython), if this fails,
    falls back to retrieving the source code from the current interactive shell session.

    Args:
        obj: A class or callable object (e.g.: function, method)

    Returns:
        str: The source code of the object, dedented and stripped

    Raises:
        TypeError: If object is not a class or callable
        OSError: If source code cannot be retrieved from any source
        ValueError: If source cannot be found in IPython history

    Note:
        TODO: handle Python standard REPL
    """
    if not (isinstance(obj, type) or callable(obj)):
        raise TypeError(f"Expected class or callable, got {type(obj)}")

    inspect_error = None
    try:
        return textwrap.dedent(inspect.getsource(obj)).strip()
    except OSError as e:
        # let's keep track of the exception to raise it if all further methods fail
        inspect_error = e
    try:
        import IPython

        shell = IPython.get_ipython()
        if not shell:
            raise ImportError("No active IPython shell found")
        all_cells = "\n".join(shell.user_ns.get("In", [])).strip()
        if not all_cells:
            raise ValueError("No code cells found in IPython session")

        tree = ast.parse(all_cells)
        for node in ast.walk(tree):
            if isinstance(node, (ast.ClassDef, ast.FunctionDef)) and node.name == obj.__name__:
                return textwrap.dedent("\n".join(all_cells.split("\n")[node.lineno - 1 : node.end_lineno])).strip()
        raise ValueError(f"Could not find source code for {obj.__name__} in IPython history")
    except ImportError:
        # IPython is not available, let's just raise the original inspect error
        raise inspect_error
    except ValueError as e:
        # IPython is available but we couldn't find the source code, let's raise the error
        raise e from inspect_error


def encode_image_base64(image):
    """Save `image` as PNG into an in-memory buffer and return the bytes base64-encoded as a str.

    `image` must expose a `save(fp, format=...)` method.
    NOTE(review): presumably a `PIL.Image.Image`; confirm against callers.
    """
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def make_image_url(base64_image):
    """Wrap a base64-encoded PNG payload in a `data:image/png;base64,` URL."""
    return f"data:image/png;base64,{base64_image}"


def make_init_file(folder: str):
    """Create `folder` if needed and write an empty `__init__.py` inside it (truncating any existing one)."""
    os.makedirs(folder, exist_ok=True)
    # Create __init__
    with open(os.path.join(folder, "__init__.py"), "w"):
        pass