ConradLinus's picture
Upload folder using huggingface_hub
d631808 verified
from __future__ import annotations

import copy
from typing import Any

from openai import NOT_GIVEN
from typing_extensions import TypeGuard

from .exceptions import UserError
# Canonical strict schema for an object with no properties. Treat it as
# read-only: `ensure_strict_json_schema` hands out copies, never this instance.
_EMPTY_SCHEMA = {
    "additionalProperties": False,
    "type": "object",
    "properties": {},
    "required": [],
}


def ensure_strict_json_schema(
    schema: dict[str, Any],
) -> dict[str, Any]:
    """Return `schema` in the `strict` form that the OpenAI API expects.

    A non-empty schema is mutated in place and the same dictionary is returned.
    An empty schema (`{}`) is left untouched; a fresh copy of the canonical
    empty object schema is returned instead.

    Args:
        schema: The JSON schema to normalise.

    Returns:
        The strict schema.

    Raises:
        TypeError: If a schema node that must be a dictionary is not one.
        UserError: If an object node sets a truthy `additionalProperties`.
        ValueError: If a `$ref` that has to be inlined is malformed.
    """
    if schema == {}:
        # Deep-copy so that callers mutating the result (including the nested
        # `properties` dict and `required` list) cannot corrupt the shared
        # module-level constant for later calls.
        return copy.deepcopy(_EMPTY_SCHEMA)
    return _ensure_strict_json_schema(schema, path=(), root=schema)
# Adapted from https://github.com/openai/openai-python/blob/main/src/openai/lib/_pydantic.py
def _ensure_strict_json_schema(
    json_schema: object,
    *,
    path: tuple[str, ...],
    root: dict[str, object],
) -> dict[str, Any]:
    """Recursively rewrite one schema node, in place, into strict form.

    Args:
        json_schema: The node to process; must be a dictionary.
        path: Keys leading from the root schema to this node. Only used in the
            error message raised for a non-dictionary node.
        root: The top-level schema, used to resolve `$ref` pointers.

    Returns:
        The same dictionary object that was passed in, after mutation.

    Raises:
        TypeError: If `json_schema` is not a dictionary.
        UserError: If an object node has a truthy `additionalProperties`.
        ValueError: If a `$ref` that must be inlined does not resolve to a
            dictionary (or is rejected by `resolve_ref`).
        AssertionError: If a `$ref` that must be inlined is not a string.
    """
    if not is_dict(json_schema):
        raise TypeError(f"Expected {json_schema} to be a dictionary; path={path}")

    # Shared definitions may live under either keyword. They are fixed up in
    # place, so the return value of the recursive call is not needed here.
    for defs_keyword in ("$defs", "definitions"):
        named_schemas = json_schema.get(defs_keyword)
        if is_dict(named_schemas):
            for schema_name, named_schema in named_schemas.items():
                _ensure_strict_json_schema(
                    named_schema, path=(*path, defs_keyword, schema_name), root=root
                )

    # Object nodes must forbid additional properties: default the keyword to
    # False when absent and reject any truthy value that was set explicitly.
    if json_schema.get("type") == "object":
        if "additionalProperties" not in json_schema:
            json_schema["additionalProperties"] = False
        elif json_schema["additionalProperties"]:
            raise UserError(
                "additionalProperties should not be set for object types. This could be because "
                "you're using an older version of Pydantic, or because you configured additional "
                "properties to be allowed. If you really need this, update the function or output tool "
                "to not use a strict schema."
            )

    # Object members, e.g. {'type': 'object', 'properties': {'a': {...}}}.
    # Every declared property becomes required.
    props = json_schema.get("properties")
    if is_dict(props):
        json_schema["required"] = list(props)
        strict_props: dict[str, Any] = {}
        for prop_name, prop_schema in props.items():
            strict_props[prop_name] = _ensure_strict_json_schema(
                prop_schema, path=(*path, "properties", prop_name), root=root
            )
        json_schema["properties"] = strict_props

    # Array elements, e.g. {'type': 'array', 'items': {...}}.
    item_schema = json_schema.get("items")
    if is_dict(item_schema):
        json_schema["items"] = _ensure_strict_json_schema(
            item_schema, path=(*path, "items"), root=root
        )

    # Unions.
    variants = json_schema.get("anyOf")
    if is_list(variants):
        json_schema["anyOf"] = [
            _ensure_strict_json_schema(
                variant, path=(*path, "anyOf", str(index)), root=root
            )
            for index, variant in enumerate(variants)
        ]

    # Intersections. A single-entry `allOf` is folded into this node;
    # otherwise each entry is processed where it stands.
    members = json_schema.get("allOf")
    if is_list(members):
        if len(members) != 1:
            json_schema["allOf"] = [
                _ensure_strict_json_schema(
                    member, path=(*path, "allOf", str(index)), root=root
                )
                for index, member in enumerate(members)
            ]
        else:
            only_member = _ensure_strict_json_schema(
                members[0], path=(*path, "allOf", "0"), root=root
            )
            json_schema.update(only_member)
            json_schema.pop("allOf")

    # A `None` default carries no extra meaning: the schema is still nullable
    # and the model will fall back to `None` anyway, so drop it.
    if "default" in json_schema and json_schema["default"] is None:
        del json_schema["default"]

    # A `$ref` cannot be combined with sibling keys, e.g.
    # `{"$ref": "...", "description": "my description"}`, so in that case the
    # reference is unravelled into
    # `{"type": "string", "description": "my description"}`.
    ref = json_schema.get("$ref")
    if not ref or not has_more_than_n_keys(json_schema, 1):
        return json_schema

    assert isinstance(ref, str), f"Received non-string $ref - {ref}"

    target = resolve_ref(root=root, ref=ref)
    if not is_dict(target):
        raise ValueError(
            f"Expected `$ref: {ref}` to resolved to a dictionary but got {target}"
        )

    # Keys already present on this node take priority over the referenced ones.
    merged = dict(target)
    merged.update(json_schema)
    json_schema.update(merged)
    json_schema.pop("$ref")

    # The inlined content might not have `additionalProperties: false` applied
    # yet, so run the node through the normaliser once more.
    return _ensure_strict_json_schema(json_schema, path=path, root=root)
def resolve_ref(*, root: dict[str, object], ref: str) -> object:
    """Follow a local `$ref` pointer through `root` and return its target.

    Args:
        root: The top-level schema the pointer is relative to.
        ref: A reference of the form `#/key/.../key`.

    Returns:
        The dictionary found by walking the `/`-separated keys from `root`.

    Raises:
        ValueError: If `ref` does not start with `#/`.
        KeyError: If a key along the pointer is missing.
        AssertionError: If a value along the pointer is not a dictionary.
    """
    if not ref.startswith("#/"):
        raise ValueError(f"Unexpected $ref format {ref!r}; Does not start with #/")

    node = root
    # Drop the leading "#/" and descend one key at a time.
    for segment in ref[2:].split("/"):
        child = node[segment]
        assert is_dict(child), (
            f"encountered non-dictionary entry while resolving {ref} - {node}"
        )
        node = child
    return node
def is_dict(obj: object) -> TypeGuard[dict[str, object]]:
    """Return True when `obj` is a `dict`, narrowing its type for checkers.

    Keys are assumed, not verified, to be strings: checking every key is not
    worth the performance cost.
    """
    return isinstance(obj, dict)
def is_list(obj: object) -> TypeGuard[list[object]]:
    """Return True when `obj` is a `list`, narrowing its type for checkers."""
    return isinstance(obj, list)
def has_more_than_n_keys(obj: dict[str, object], n: int) -> bool:
    """Return True when `obj` holds strictly more than `n` keys.

    Args:
        obj: The dictionary to inspect.
        n: The threshold to exceed.

    Returns:
        `True` if `len(obj) > n`, otherwise `False`.
    """
    # `len` on a dict is O(1), so counting keys by iteration is unnecessary.
    # It also gives the right answer for an empty dict and a negative `n`.
    return len(obj) > n