Spaces:
Running
Running
File size: 7,127 Bytes
d631808 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
import abc
from dataclasses import dataclass
from typing import Any
from pydantic import BaseModel, TypeAdapter
from typing_extensions import TypedDict, get_args, get_origin
from .exceptions import ModelBehaviorError, UserError
from .strict_schema import ensure_strict_json_schema
from .tracing import SpanError
from .util import _error_tracing, _json
# Key under which the output type is nested when AgentOutputSchema wraps it in a TypedDict;
# AgentOutputSchema.validate_json reads the validated value back out from this key.
_WRAPPER_DICT_KEY = "response"
class AgentOutputSchemaBase(abc.ABC):
    """Interface for objects that describe an agent's output: they expose the JSON schema of
    the output and turn JSON produced by the LLM into the validated output type.
    """

    @abc.abstractmethod
    def is_plain_text(self) -> bool:
        """Return True when the output is plain text rather than a JSON object."""

    @abc.abstractmethod
    def name(self) -> str:
        """Return the name of the output type."""

    @abc.abstractmethod
    def json_schema(self) -> dict[str, Any]:
        """Return the JSON schema describing the output.

        Only called when the output type is not plain text.
        """

    @abc.abstractmethod
    def is_strict_json_schema(self) -> bool:
        """Return True when the JSON schema is in strict mode.

        Strict mode restricts which JSON schema features may be used, but guarantees valid
        JSON. Details:
        https://platform.openai.com/docs/guides/structured-outputs#supported-schemas
        """

    @abc.abstractmethod
    def validate_json(self, json_str: str) -> Any:
        """Validate `json_str` against the output type.

        Implementations must return the validated object, or raise a `ModelBehaviorError`
        when the JSON is invalid.
        """
@dataclass(init=False)
class AgentOutputSchema(AgentOutputSchemaBase):
    """Describes an agent's output type: holds its JSON schema and validates/parses the JSON
    emitted by the LLM into that type.
    """

    output_type: type[Any]
    """The output type this schema describes."""

    _type_adapter: TypeAdapter[Any]
    """Type adapter built around the (possibly wrapped) output type; used to validate JSON."""

    _is_wrapped: bool
    """True when the output type is nested inside a dictionary, which is generally done when
    the bare output type cannot be represented as a JSON Schema object.
    """

    _output_schema: dict[str, Any]
    """The JSON schema generated for the output."""

    _strict_json_schema: bool
    """Whether the JSON schema is in strict mode. Setting this to True is **strongly**
    recommended, as it increases the likelihood of correct JSON input.
    """

    def __init__(self, output_type: type[Any], strict_json_schema: bool = True):
        """
        Args:
            output_type: The type of the output.
            strict_json_schema: Whether the JSON schema is in strict mode. Setting this to
                True is **strongly** recommended, as it increases the likelihood of correct
                JSON input.

        Raises:
            UserError: If strict mode is enabled and `ensure_strict_json_schema` rejects the
                generated schema.
        """
        self.output_type = output_type
        self._strict_json_schema = strict_json_schema

        if output_type is None or output_type is str:
            # Plain text output: never wrapped, and the strict-schema pass is skipped.
            self._is_wrapped = False
            self._type_adapter = TypeAdapter(output_type)
            self._output_schema = self._type_adapter.json_schema()
            return

        # Anything that is neither a BaseModel nor a dict subclass is nested under
        # _WRAPPER_DICT_KEY, since it would definitely not be a JSON Schema object.
        needs_wrapper = not _is_subclass_of_base_model_or_dict(output_type)
        self._is_wrapped = needs_wrapper

        adapter_target: Any = output_type
        if needs_wrapper:
            OutputType = TypedDict(  # type: ignore
                "OutputType",
                {_WRAPPER_DICT_KEY: output_type},
            )
            adapter_target = OutputType

        self._type_adapter = TypeAdapter(adapter_target)
        self._output_schema = self._type_adapter.json_schema()

        if not self._strict_json_schema:
            return

        try:
            self._output_schema = ensure_strict_json_schema(self._output_schema)
        except UserError as e:
            raise UserError(
                "Strict JSON schema is enabled, but the output type is not valid. "
                "Either make the output type strict, or pass output_schema_strict=False to "
                "your Agent()"
            ) from e

    def is_plain_text(self) -> bool:
        """Return True when the output type is plain text (`None` or `str`) rather than a
        JSON object.
        """
        return self.output_type is None or self.output_type is str

    def is_strict_json_schema(self) -> bool:
        """Return whether the JSON schema is in strict mode."""
        return self._strict_json_schema

    def json_schema(self) -> dict[str, Any]:
        """Return the JSON schema of the output type.

        Raises:
            UserError: If the output type is plain text, in which case no schema is available.
        """
        if self.is_plain_text():
            raise UserError("Output type is plain text, so no JSON schema is available")
        return self._output_schema

    def validate_json(self, json_str: str) -> Any:
        """Validate a JSON string against the output type and return the validated object.

        When the output type was wrapped, the value stored under the wrapper key is returned
        instead of the enclosing dictionary.

        Raises:
            ModelBehaviorError: If the output type is wrapped and the validated value is not a
                dict or lacks the wrapper key. Failures inside `_json.validate_json` are
                raised by that helper.
        """
        validated = _json.validate_json(json_str, self._type_adapter, partial=False)

        if not self._is_wrapped:
            return validated

        if not isinstance(validated, dict):
            # Record the failure on the current tracing span before raising.
            _error_tracing.attach_error_to_current_span(
                SpanError(
                    message="Invalid JSON",
                    data={"details": f"Expected a dict, got {type(validated)}"},
                )
            )
            raise ModelBehaviorError(
                f"Expected a dict, got {type(validated)} for JSON: {json_str}"
            )

        if _WRAPPER_DICT_KEY not in validated:
            _error_tracing.attach_error_to_current_span(
                SpanError(
                    message="Invalid JSON",
                    data={"details": f"Could not find key {_WRAPPER_DICT_KEY} in JSON"},
                )
            )
            raise ModelBehaviorError(
                f"Could not find key {_WRAPPER_DICT_KEY} in JSON: {json_str}"
            )

        return validated[_WRAPPER_DICT_KEY]

    def name(self) -> str:
        """Return the name of the output type, rendered by `_type_to_str`."""
        return _type_to_str(self.output_type)
def _is_subclass_of_base_model_or_dict(t: Any) -> bool:
if not isinstance(t, type):
return False
# If it's a generic alias, 'origin' will be the actual type, e.g. 'list'
origin = get_origin(t)
allowed_types = (BaseModel, dict)
# If it's a generic alias e.g. list[str], then we should check the origin type i.e. list
return issubclass(origin or t, allowed_types)
def _type_to_str(t: type[Any]) -> str:
origin = get_origin(t)
args = get_args(t)
if origin is None:
# It's a simple type like `str`, `int`, etc.
return t.__name__
elif args:
args_str = ", ".join(_type_to_str(arg) for arg in args)
return f"{origin.__name__}[{args_str}]"
else:
return str(t)
|