Spaces:
Running
Running
from dataclasses import asdict, dataclass | |
from logging import getLogger | |
from typing import TYPE_CHECKING, Any, Dict, List, TypedDict, Union | |
from smolagents.models import ChatMessage, MessageRole | |
from smolagents.monitoring import AgentLogger, LogLevel | |
from smolagents.utils import AgentError, make_json_serializable | |
if TYPE_CHECKING: | |
from smolagents.models import ChatMessage | |
from smolagents.monitoring import AgentLogger | |
logger = getLogger(__name__) | |
class Message(TypedDict):
    """Chat message passed to or produced for the model."""

    # Author of the message (system, user, assistant, tool response, ...).
    role: MessageRole
    # Either plain text or a list of typed content parts, e.g.
    # {"type": "text", "text": ...} or {"type": "image", "image": ...}.
    content: Union[str, list[dict]]
@dataclass
class ToolCall:
    """A single tool invocation requested by the model.

    Attributes:
        name: Name of the tool being called.
        arguments: Arguments passed to the tool; any value that
            `make_json_serializable` can handle.
        id: Identifier of the call, echoed back in tool-response messages.
    """

    name: str
    arguments: Any
    id: str

    def dict(self) -> Dict[str, Any]:
        """Return the call as a JSON-friendly function-call dictionary.

        Returns:
            A dict with the keys "id", "type" (always "function") and
            "function" (holding the tool name and its serialized arguments).
        """
        return {
            "id": self.id,
            "type": "function",
            "function": {
                "name": self.name,
                # Arguments may hold arbitrary objects, so they are converted first.
                "arguments": make_json_serializable(self.arguments),
            },
        }
@dataclass
class MemoryStep:
    """Base class for the entries kept in the agent's memory.

    Declared as a dataclass so that `dict()` (which relies on
    `dataclasses.asdict`) works on the base class as well as on dataclass
    subclasses.
    """

    def dict(self) -> Dict[str, Any]:
        """Return the step's dataclass fields as a plain dictionary."""
        return asdict(self)

    def to_messages(self, **kwargs) -> List[Dict[str, Any]]:
        """Convert the step into chat messages.

        Raises:
            NotImplementedError: Always; concrete step types must override this.
        """
        raise NotImplementedError
@dataclass
class ActionStep(MemoryStep):
    """Memory entry describing one action taken by the agent.

    Every field is optional and defaults to None, so a step can be created
    empty and filled in as the action progresses.
    """

    model_input_messages: List[Message] | None = None
    tool_calls: List[ToolCall] | None = None
    start_time: float | None = None
    end_time: float | None = None
    step_number: int | None = None
    error: AgentError | None = None
    duration: float | None = None
    model_output_message: ChatMessage | None = None
    model_output: str | None = None
    observations: str | None = None
    observations_images: List[str] | None = None
    action_output: Any = None

    def dict(self) -> Dict[str, Any]:
        """Return the step as a dictionary.

        This overrides the base implementation so that `tool_calls`, `error`
        and `action_output` are converted by hand. The step number is exposed
        under the key "step", and `observations_images` is not included.
        """
        return {
            "model_input_messages": self.model_input_messages,
            "tool_calls": [tc.dict() for tc in self.tool_calls] if self.tool_calls else [],
            "start_time": self.start_time,
            "end_time": self.end_time,
            "step": self.step_number,
            "error": self.error.dict() if self.error else None,
            "duration": self.duration,
            "model_output_message": self.model_output_message,
            "model_output": self.model_output,
            "observations": self.observations,
            "action_output": make_json_serializable(self.action_output),
        }

    def _call_id_prefix(self) -> str:
        """Return the "Call id: ..." line for the first tool call, or "" if there is none."""
        return f"Call id: {self.tool_calls[0].id}\n" if self.tool_calls else ""

    def to_messages(self, summary_mode: bool = False, show_model_input_messages: bool = False) -> List[Message]:
        """Convert the step into the chat messages that describe it.

        Args:
            summary_mode: When True, the raw model output is left out.
            show_model_input_messages: When True, the messages that were fed to
                the model are included as a system message.

        Returns:
            Messages in this order, each only when its data is present:
            model input, model output, tool calls, observations, error and
            observed images.
        """
        messages = []
        if self.model_input_messages is not None and show_model_input_messages:
            messages.append(Message(role=MessageRole.SYSTEM, content=self.model_input_messages))
        if self.model_output is not None and not summary_mode:
            messages.append(
                Message(role=MessageRole.ASSISTANT, content=[{"type": "text", "text": self.model_output.strip()}])
            )

        if self.tool_calls is not None:
            messages.append(
                Message(
                    role=MessageRole.ASSISTANT,
                    content=[
                        {
                            "type": "text",
                            "text": "Calling tools:\n" + str([tc.dict() for tc in self.tool_calls]),
                        }
                    ],
                )
            )

        if self.observations is not None:
            # Observations can exist without any tool call, so the call id is
            # only prepended when one is available (same rule as for errors).
            messages.append(
                Message(
                    role=MessageRole.TOOL_RESPONSE,
                    content=[
                        {
                            "type": "text",
                            "text": self._call_id_prefix() + f"Observation:\n{self.observations}",
                        }
                    ],
                )
            )
        if self.error is not None:
            error_message = (
                "Error:\n"
                + str(self.error)
                + "\nNow let's retry: take care not to repeat previous errors! If you have retried several times, try a completely different approach.\n"
            )
            message_content = self._call_id_prefix() + error_message
            messages.append(
                Message(role=MessageRole.TOOL_RESPONSE, content=[{"type": "text", "text": message_content}])
            )

        if self.observations_images:
            messages.append(
                Message(
                    role=MessageRole.USER,
                    content=[{"type": "text", "text": "Here are the observed images:"}]
                    + [
                        {
                            "type": "image",
                            "image": image,
                        }
                        for image in self.observations_images
                    ],
                )
            )
        return messages
@dataclass
class PlanningStep(MemoryStep):
    """Memory entry holding the facts list and the plan produced while planning.

    Attributes:
        model_input_messages: Messages that were sent to the model.
        model_output_message_facts: Model message from which `facts` was taken.
        facts: Text of the facts list.
        model_output_message_plan: Model message from which `plan` was taken.
        plan: Text of the plan.
    """

    model_input_messages: List[Message]
    model_output_message_facts: ChatMessage
    facts: str
    model_output_message_plan: ChatMessage
    plan: str

    def to_messages(self, summary_mode: bool = False, **kwargs) -> List[Message]:
        """Convert the step into assistant messages.

        Args:
            summary_mode: When True, only the facts are returned. Defaults to
                False, matching the other step types.
            **kwargs: Accepted for signature compatibility and ignored.

        Returns:
            The facts message, followed by the plan message unless
            `summary_mode` is set.
        """
        messages = [
            Message(
                role=MessageRole.ASSISTANT, content=[{"type": "text", "text": f"[FACTS LIST]:\n{self.facts.strip()}"}]
            )
        ]
        if not summary_mode:  # This step is not shown to a model writing a plan to avoid influencing the new plan
            messages.append(
                Message(
                    role=MessageRole.ASSISTANT, content=[{"type": "text", "text": f"[PLAN]:\n{self.plan.strip()}"}]
                )
            )
        return messages
@dataclass
class TaskStep(MemoryStep):
    """Memory entry recording a task given to the agent.

    Attributes:
        task: Text of the task.
        task_images: Optional images attached to the task.
    """

    task: str
    task_images: List[str] | None = None

    def to_messages(self, summary_mode: bool = False, **kwargs) -> List[Message]:
        """Return a single user message holding the task text and its images.

        Args:
            summary_mode: Accepted for signature compatibility; it has no effect here.
            **kwargs: Accepted for signature compatibility and ignored.
        """
        content = [{"type": "text", "text": f"New task:\n{self.task}"}]
        # `task_images` may be None or empty; in both cases nothing is added.
        content.extend({"type": "image", "image": image} for image in self.task_images or [])
        return [Message(role=MessageRole.USER, content=content)]
@dataclass
class SystemPromptStep(MemoryStep):
    """Memory entry holding the agent's system prompt.

    Attributes:
        system_prompt: Text of the system prompt.
    """

    system_prompt: str

    def to_messages(self, summary_mode: bool = False, **kwargs) -> List[Message]:
        """Return the system prompt as a single system message.

        Args:
            summary_mode: When True, no message is returned.
            **kwargs: Accepted for signature compatibility and ignored.
        """
        if summary_mode:
            return []
        return [Message(role=MessageRole.SYSTEM, content=[{"type": "text", "text": self.system_prompt}])]
class AgentMemory:
    """Holds the agent's system prompt and the ordered list of steps it went through."""

    def __init__(self, system_prompt: str):
        """Create an empty memory.

        Args:
            system_prompt: Text of the system prompt; it is wrapped in a
                `SystemPromptStep` and kept apart from `steps`.
        """
        self.system_prompt = SystemPromptStep(system_prompt=system_prompt)
        self.steps: List[Union[TaskStep, ActionStep, PlanningStep]] = []

    def reset(self):
        """Forget all recorded steps; the system prompt is kept."""
        self.steps = []

    def get_succinct_steps(self) -> list[dict]:
        """Return every step as a dict, without the "model_input_messages" entry."""
        return [
            {key: value for key, value in step.dict().items() if key != "model_input_messages"} for step in self.steps
        ]

    def get_full_steps(self) -> list[dict]:
        """Return every step as a dict, including the model input messages."""
        return [step.dict() for step in self.steps]

    def replay(self, logger: "AgentLogger", detailed: bool = False):
        """Prints a pretty replay of the agent's steps.

        Args:
            logger (AgentLogger): The logger to print replay logs to.
            detailed (bool, optional): If True, also displays the memory at each step. Defaults to False.
                Careful: this greatly increases log length. Use only for debugging.
        """
        # NOTE(review): every entry is logged at LogLevel.ERROR, presumably so the
        # replay is printed whatever the logger's verbosity - confirm against AgentLogger.
        logger.console.log("Replaying the agent's steps:")
        for step in self.steps:
            # `steps` is annotated as holding task/action/planning steps only, so
            # this branch fires only if a SystemPromptStep was appended by hand.
            if isinstance(step, SystemPromptStep) and detailed:
                logger.log_markdown(title="System prompt", content=step.system_prompt, level=LogLevel.ERROR)
            elif isinstance(step, TaskStep):
                logger.log_task(step.task, "", level=LogLevel.ERROR)
            elif isinstance(step, ActionStep):
                logger.log_rule(f"Step {step.step_number}", level=LogLevel.ERROR)
                # Both fields default to None on ActionStep; skip them when unset.
                if detailed and step.model_input_messages is not None:
                    logger.log_messages(step.model_input_messages, level=LogLevel.ERROR)
                if step.model_output is not None:
                    logger.log_markdown(title="Agent output:", content=step.model_output, level=LogLevel.ERROR)
            elif isinstance(step, PlanningStep):
                logger.log_rule("Planning step", level=LogLevel.ERROR)
                if detailed and step.model_input_messages is not None:
                    logger.log_messages(step.model_input_messages, level=LogLevel.ERROR)
                logger.log_markdown(title="Agent output:", content=step.facts + "\n" + step.plan, level=LogLevel.ERROR)
# Names exported by a star-import of this module.
__all__ = [
    "AgentMemory",
]