Spaces:
Running
Running
from __future__ import annotations | |
import abc | |
import asyncio | |
from collections.abc import AsyncIterator | |
from dataclasses import dataclass, field | |
from typing import TYPE_CHECKING, Any, cast | |
from typing_extensions import TypeVar | |
from ._run_impl import QueueCompleteSentinel | |
from .agent import Agent | |
from .agent_output import AgentOutputSchemaBase | |
from .exceptions import ( | |
AgentsException, | |
InputGuardrailTripwireTriggered, | |
MaxTurnsExceeded, | |
RunErrorDetails, | |
) | |
from .guardrail import InputGuardrailResult, OutputGuardrailResult | |
from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem | |
from .logger import logger | |
from .run_context import RunContextWrapper | |
from .stream_events import StreamEvent | |
from .tracing import Trace | |
from .util._pretty_print import ( | |
pretty_print_result, | |
pretty_print_run_result_streaming, | |
) | |
if TYPE_CHECKING: | |
from ._run_impl import QueueCompleteSentinel | |
from .agent import Agent | |
T = TypeVar("T")


@dataclass
class RunResultBase(abc.ABC):
    """Fields and helpers shared by `RunResult` and `RunResultStreaming`.

    Subclasses must implement the `last_agent` property.
    """

    input: str | list[TResponseInputItem]
    """The original input items i.e. the items before run() was called. This may be a mutated
    version of the input, if there are handoff input filters that mutate the input.
    """

    new_items: list[RunItem]
    """The new items generated during the agent run. These include things like new messages, tool
    calls and their outputs, etc.
    """

    raw_responses: list[ModelResponse]
    """The raw LLM responses generated by the model during the agent run."""

    final_output: Any
    """The output of the last agent."""

    input_guardrail_results: list[InputGuardrailResult]
    """Guardrail results for the input messages."""

    output_guardrail_results: list[OutputGuardrailResult]
    """Guardrail results for the final output of the agent."""

    context_wrapper: RunContextWrapper[Any]
    """The context wrapper for the agent run."""

    @property
    @abc.abstractmethod
    def last_agent(self) -> Agent[Any]:
        """The last agent that was run. Must be provided by subclasses."""

    def final_output_as(self, cls: type[T], raise_if_incorrect_type: bool = False) -> T:
        """A convenience method to cast the final output to a specific type. By default, the cast
        is only for the typechecker. If you set `raise_if_incorrect_type` to True, we'll raise a
        TypeError if the final output is not of the given type.

        Args:
            cls: The type to cast the final output to.
            raise_if_incorrect_type: If True, we'll raise a TypeError if the final output is not of
                the given type.

        Returns:
            The final output casted to the given type.

        Raises:
            TypeError: If `raise_if_incorrect_type` is True and the final output is not an
                instance of `cls`.
        """
        if raise_if_incorrect_type and not isinstance(self.final_output, cls):
            raise TypeError(f"Final output is not of type {cls.__name__}")

        return cast(T, self.final_output)

    def to_input_list(self) -> list[TResponseInputItem]:
        """Creates a new input list, merging the original input with all the new items generated.

        The original input is normalized to a list first, then each new run item is converted
        to an input item and appended in order.
        """
        original_items: list[TResponseInputItem] = ItemHelpers.input_to_new_input_list(self.input)
        new_items = [item.to_input_item() for item in self.new_items]

        return original_items + new_items

    @property
    def last_response_id(self) -> str | None:
        """The response ID of the last model response, or None if no model responses were
        recorded."""
        if not self.raw_responses:
            return None

        return self.raw_responses[-1].response_id
@dataclass
class RunResult(RunResultBase):
    """Result of an agent run in non-streaming mode (the counterpart of `RunResultStreaming`)."""

    # The agent that ran last; exposed read-only through the `last_agent` property.
    _last_agent: Agent[Any]

    @property
    def last_agent(self) -> Agent[Any]:
        """The last agent that was run."""
        return self._last_agent

    def __str__(self) -> str:
        # Delegates human-readable formatting to the shared pretty-print helper.
        return pretty_print_result(self)
@dataclass
class RunResultStreaming(RunResultBase):
    """The result of an agent run in streaming mode. You can use the `stream_events` method to
    receive semantic events as they are generated.

    The streaming method will raise:
    - A MaxTurnsExceeded exception if the agent exceeds the max_turns limit.
    - An InputGuardrailTripwireTriggered exception if an input guardrail is tripped.
    - Any exception raised by the background run or guardrail tasks.
    """

    current_agent: Agent[Any]
    """The current agent that is running."""

    current_turn: int
    """The current turn number."""

    max_turns: int
    """The maximum number of turns the agent can run for."""

    final_output: Any
    """The final output of the agent. This is None until the agent has finished running."""

    _current_agent_output_schema: AgentOutputSchemaBase | None = field(repr=False)

    trace: Trace | None = field(repr=False)

    is_complete: bool = False
    """Whether the agent has finished running."""

    # Queues that the background run_loop writes to
    _event_queue: asyncio.Queue[StreamEvent | QueueCompleteSentinel] = field(
        default_factory=asyncio.Queue, repr=False
    )
    _input_guardrail_queue: asyncio.Queue[InputGuardrailResult] = field(
        default_factory=asyncio.Queue, repr=False
    )

    # Store the asyncio tasks that we're waiting on
    _run_impl_task: asyncio.Task[Any] | None = field(default=None, repr=False)
    _input_guardrails_task: asyncio.Task[Any] | None = field(default=None, repr=False)
    _output_guardrails_task: asyncio.Task[Any] | None = field(default=None, repr=False)
    _stored_exception: Exception | None = field(default=None, repr=False)

    @property
    def last_agent(self) -> Agent[Any]:
        """The last agent that was run. Updates as the agent run progresses, so the true last agent
        is only available after the agent run is complete.
        """
        return self.current_agent

    def cancel(self) -> None:
        """Cancels the streaming run, stopping all background tasks and marking the run as
        complete.

        Pending events and guardrail results are discarded, and a completion sentinel is
        enqueued so that a consumer currently awaiting the next event in `stream_events`
        wakes up and finishes instead of waiting forever.
        """
        self._cleanup_tasks()  # Cancel all running tasks
        self.is_complete = True  # Mark the run as complete to stop event streaming

        # Clear the queues to prevent processing stale events
        while not self._event_queue.empty():
            self._event_queue.get_nowait()
        while not self._input_guardrail_queue.empty():
            self._input_guardrail_queue.get_nowait()

        # The background tasks were just cancelled, so we cannot rely on them to enqueue the
        # completion sentinel; without it a consumer blocked in `_event_queue.get()` would
        # never be released. The queue is unbounded, so put_nowait cannot raise QueueFull.
        self._event_queue.put_nowait(QueueCompleteSentinel())

    async def stream_events(self) -> AsyncIterator[StreamEvent]:
        """Stream deltas for new items as they are generated. We're using the types from the
        OpenAI Responses API, so these are semantic events: each event has a `type` field that
        describes the type of the event, along with the data for that event.

        This will raise:
        - A MaxTurnsExceeded exception if the agent exceeds the max_turns limit.
        - An InputGuardrailTripwireTriggered exception if an input guardrail is tripped.
        - Any exception raised by the background run or guardrail tasks.
        """
        while True:
            self._check_errors()
            if self._stored_exception is not None:
                logger.debug("Breaking due to stored exception")
                self.is_complete = True
                break

            if self.is_complete and self._event_queue.empty():
                break

            try:
                item = await self._event_queue.get()
            except asyncio.CancelledError:
                # Cancellation of the consumer while waiting ends the stream quietly.
                break

            if isinstance(item, QueueCompleteSentinel):
                self._event_queue.task_done()
                # Check for errors, in case the queue was completed due to an exception
                self._check_errors()
                break

            yield item
            self._event_queue.task_done()

        self._cleanup_tasks()

        if self._stored_exception is not None:
            raise self._stored_exception

    def _create_error_details(self) -> RunErrorDetails:
        """Return a `RunErrorDetails` object considering the current attributes of the class."""
        return RunErrorDetails(
            input=self.input,
            new_items=self.new_items,
            raw_responses=self.raw_responses,
            last_agent=self.current_agent,
            context_wrapper=self.context_wrapper,
            input_guardrail_results=self.input_guardrail_results,
            output_guardrail_results=self.output_guardrail_results,
        )

    def _get_task_exception(self, task: asyncio.Task[Any] | None) -> Exception | None:
        """Return the `Exception` a finished background task raised, or None.

        AgentsException instances without run data get it attached here.
        """
        # Task.exception() itself raises CancelledError for a cancelled task (e.g. after
        # cancel()), so cancelled and still-running tasks are treated as error-free.
        if task is None or not task.done() or task.cancelled():
            return None

        exc = task.exception()
        # BaseExceptions that are not Exceptions (and a None result) are ignored.
        if not isinstance(exc, Exception):
            return None

        if isinstance(exc, AgentsException) and exc.run_data is None:
            exc.run_data = self._create_error_details()
        return exc

    def _check_errors(self) -> None:
        """Record the most recent pending error, if any, in `_stored_exception`.

        Sources are checked in order: the turn limit, tripped input guardrails, then the run,
        input-guardrail and output-guardrail tasks. A later source overwrites an earlier one.
        """
        if self.current_turn > self.max_turns:
            max_turns_exc = MaxTurnsExceeded(f"Max turns ({self.max_turns}) exceeded")
            max_turns_exc.run_data = self._create_error_details()
            self._stored_exception = max_turns_exc

        # Fetch all the completed guardrail results from the queue and raise if needed
        while not self._input_guardrail_queue.empty():
            guardrail_result = self._input_guardrail_queue.get_nowait()
            if guardrail_result.output.tripwire_triggered:
                tripwire_exc = InputGuardrailTripwireTriggered(guardrail_result)
                tripwire_exc.run_data = self._create_error_details()
                self._stored_exception = tripwire_exc

        # Check the tasks for any exceptions
        for task in (
            self._run_impl_task,
            self._input_guardrails_task,
            self._output_guardrails_task,
        ):
            task_exc = self._get_task_exception(task)
            if task_exc is not None:
                self._stored_exception = task_exc

    def _cleanup_tasks(self) -> None:
        """Cancel any background task that has not finished yet."""
        if self._run_impl_task and not self._run_impl_task.done():
            self._run_impl_task.cancel()

        if self._input_guardrails_task and not self._input_guardrails_task.done():
            self._input_guardrails_task.cancel()

        if self._output_guardrails_task and not self._output_guardrails_task.done():
            self._output_guardrails_task.cancel()

    def __str__(self) -> str:
        return pretty_print_run_result_streaming(self)