Spaces:
Running
Running
import abc | |
from dataclasses import dataclass | |
from typing import Any | |
from pydantic import BaseModel, TypeAdapter | |
from typing_extensions import TypedDict, get_args, get_origin | |
from .exceptions import ModelBehaviorError, UserError | |
from .strict_schema import ensure_strict_json_schema | |
from .tracing import SpanError | |
from .util import _error_tracing, _json | |
# Key under which output types that are not JSON objects get wrapped, so the model is always
# asked for an object of the form {"response": <value>} and the value is unwrapped afterwards.
_WRAPPER_DICT_KEY = "response"
class AgentOutputSchemaBase(abc.ABC):
    """An object that captures the JSON schema of the output, as well as validating/parsing JSON
    produced by the LLM into the output type.

    Every method below is abstract: a subclass must implement all of them before it can be
    instantiated, so a forgotten method fails loudly instead of silently returning `None`.
    """

    @abc.abstractmethod
    def is_plain_text(self) -> bool:
        """Whether the output type is plain text (versus a JSON object)."""
        pass

    @abc.abstractmethod
    def name(self) -> str:
        """The name of the output type."""
        pass

    @abc.abstractmethod
    def json_schema(self) -> dict[str, Any]:
        """Returns the JSON schema of the output. Will only be called if the output type is not
        plain text.
        """
        pass

    @abc.abstractmethod
    def is_strict_json_schema(self) -> bool:
        """Whether the JSON schema is in strict mode. Strict mode constrains the JSON schema
        features, but guarantees valid JSON. See here for details:
        https://platform.openai.com/docs/guides/structured-outputs#supported-schemas
        """
        pass

    @abc.abstractmethod
    def validate_json(self, json_str: str) -> Any:
        """Validate a JSON string against the output type. You must return the validated object,
        or raise a `ModelBehaviorError` if the JSON is invalid.
        """
        pass
class AgentOutputSchema(AgentOutputSchemaBase):
    """An object that captures the JSON schema of the output, as well as validating/parsing JSON
    produced by the LLM into the output type.
    """

    output_type: type[Any]
    """The type of the output."""

    _type_adapter: TypeAdapter[Any]
    """A type adapter that wraps the output type, so that we can validate JSON."""

    _is_wrapped: bool
    """Whether the output type is wrapped in a dictionary. This is generally done if the base
    output type cannot be represented as a JSON Schema object.
    """

    _output_schema: dict[str, Any]
    """The JSON schema of the output."""

    _strict_json_schema: bool
    """Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True,
    as it increases the likelihood of correct JSON input.
    """

    def __init__(self, output_type: type[Any], strict_json_schema: bool = True):
        """
        Args:
            output_type: The type of the output. `None` or `str` means plain text.
            strict_json_schema: Whether the JSON schema is in strict mode. We **strongly** recommend
                setting this to True, as it increases the likelihood of correct JSON input.

        Raises:
            UserError: If `strict_json_schema` is True and the generated schema cannot be made
                strict by `ensure_strict_json_schema`.
        """
        self.output_type = output_type
        self._strict_json_schema = strict_json_schema

        if output_type is None or output_type is str:
            # Plain text: never wrapped and never made strict. The schema is still computed so
            # the attribute is always populated, but `json_schema()` refuses to expose it.
            self._is_wrapped = False
            self._type_adapter = TypeAdapter(output_type)
            self._output_schema = self._type_adapter.json_schema()
            return

        # Wrap everything that would not produce a top-level JSON Schema object (e.g. `int`,
        # `list[str]`), so the model is asked for {_WRAPPER_DICT_KEY: <value>} instead.
        self._is_wrapped = not _is_subclass_of_base_model_or_dict(output_type)
        if self._is_wrapped:
            OutputType = TypedDict(
                "OutputType",
                {
                    _WRAPPER_DICT_KEY: output_type,  # type: ignore
                },
            )
            self._type_adapter = TypeAdapter(OutputType)
        else:
            self._type_adapter = TypeAdapter(output_type)
        self._output_schema = self._type_adapter.json_schema()

        if self._strict_json_schema:
            try:
                self._output_schema = ensure_strict_json_schema(self._output_schema)
            except UserError as e:
                # Point users at the knob that actually exists on this class.
                raise UserError(
                    "Strict JSON schema is enabled, but the output type is not valid. "
                    "Either make the output type strict, or wrap your type with "
                    "AgentOutputSchema(YourType, strict_json_schema=False)"
                ) from e

    def is_plain_text(self) -> bool:
        """Whether the output type is plain text (versus a JSON object)."""
        return self.output_type is None or self.output_type is str

    def is_strict_json_schema(self) -> bool:
        """Whether the JSON schema is in strict mode."""
        return self._strict_json_schema

    def json_schema(self) -> dict[str, Any]:
        """The JSON schema of the output type.

        Raises:
            UserError: If the output type is plain text.
        """
        if self.is_plain_text():
            raise UserError("Output type is plain text, so no JSON schema is available")
        return self._output_schema

    def validate_json(self, json_str: str) -> Any:
        """Validate a JSON string against the output type. Returns the validated object, or raises
        a `ModelBehaviorError` if the JSON is invalid.

        For wrapped output types, the value stored under the wrapper key is returned. Errors
        raised by `_json.validate_json` itself propagate unchanged.
        """
        validated = _json.validate_json(json_str, self._type_adapter, partial=False)
        if not self._is_wrapped:
            return validated

        # Wrapped outputs must come back as {_WRAPPER_DICT_KEY: <value>}; check and unwrap.
        if not isinstance(validated, dict):
            _error_tracing.attach_error_to_current_span(
                SpanError(
                    message="Invalid JSON",
                    data={"details": f"Expected a dict, got {type(validated)}"},
                )
            )
            raise ModelBehaviorError(
                f"Expected a dict, got {type(validated)} for JSON: {json_str}"
            )

        if _WRAPPER_DICT_KEY not in validated:
            _error_tracing.attach_error_to_current_span(
                SpanError(
                    message="Invalid JSON",
                    data={"details": f"Could not find key {_WRAPPER_DICT_KEY} in JSON"},
                )
            )
            raise ModelBehaviorError(
                f"Could not find key {_WRAPPER_DICT_KEY} in JSON: {json_str}"
            )

        return validated[_WRAPPER_DICT_KEY]

    def name(self) -> str:
        """The name of the output type, e.g. `list[int]`."""
        return _type_to_str(self.output_type)
def _is_subclass_of_base_model_or_dict(t: Any) -> bool:
    """Return True if `t` is a pydantic `BaseModel` subclass or a `dict` type.

    Parameterized generics such as `dict[str, int]` or `typing.Dict[str, int]` are judged by
    their origin class (`dict`). Anything whose origin is not a class (e.g. `None`, unions,
    `Literal[...]`) returns False.
    """
    # Resolve the origin *before* the class check. Since Python 3.11, `isinstance(dict[str, int],
    # type)` is False (and `typing.Dict[...]` was never a `type`), so gating on `t` first would
    # make the generic-alias handling unreachable and the result version-dependent.
    origin = get_origin(t)
    candidate = t if origin is None else origin
    if not isinstance(candidate, type):
        return False
    return issubclass(candidate, (BaseModel, dict))
def _type_to_str(t: type[Any]) -> str:
    """Render a type annotation as a short human-readable string, e.g. `list[int]`.

    Values that appear where a type might (`None`, `Literal` members) have no `__name__`;
    they are rendered with `repr` instead of raising `AttributeError`.
    """
    origin = get_origin(t)
    args = get_args(t)
    if origin is None:
        # A simple type like `str`, `int`, a class, or a bare value such as `None` / "a".
        return getattr(t, "__name__", repr(t))
    elif args:
        args_str = ", ".join(_type_to_str(arg) for arg in args)
        # Some origins (special forms on older Pythons) lack `__name__`; fall back to `str`.
        origin_name = getattr(origin, "__name__", str(origin))
        return f"{origin_name}[{args_str}]"
    else:
        return str(t)