File size: 4,289 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
from enum import Enum
from functools import lru_cache
from typing import TYPE_CHECKING, Any, List, Literal, Mapping, Optional, Sequence
from wandb.sdk.artifacts._validators import (
REGISTRY_PREFIX,
validate_artifact_types_list,
)
if TYPE_CHECKING:
from wandb_gql import Client
from wandb_gql import gql
class _Visibility(str, Enum):
# names are what users see/pass into Python methods
# values are what's expected by backend API
organization = "PRIVATE"
restricted = "RESTRICTED"
@classmethod
def _missing_(cls, value: object) -> Any:
return next(
(e for e in cls if e.name == value),
None,
)
def _format_gql_artifact_types_input(
artifact_types: Optional[List[str]] = None,
):
"""Format the artifact types for the GQL input.
Args:
artifact_types: The artifact types to add to the registry.
Returns:
The artifact types for the GQL input.
"""
if artifact_types is None:
return []
new_types = validate_artifact_types_list(artifact_types)
return [{"name": type} for type in new_types]
def _gql_to_registry_visibility(
    visibility: str,
) -> Literal["organization", "restricted"]:
    """Convert the GQL visibility to the registry visibility.

    Args:
        visibility: The visibility string returned by the backend (a
            ``_Visibility`` value such as ``"PRIVATE"``). Because
            ``_Visibility`` also resolves member names, ``"organization"``
            and ``"restricted"`` are accepted as well.

    Returns:
        The registry visibility name (``"organization"`` or ``"restricted"``).

    Raises:
        ValueError: If ``visibility`` matches no known visibility.
    """
    try:
        return _Visibility(visibility).name
    except ValueError:
        # Replace Enum's generic message with a clearer one; `from None`
        # drops the redundant implicit "during handling of..." context.
        raise ValueError(f"Invalid visibility: {visibility!r} from backend") from None
def _registry_visibility_to_gql(
    visibility: Literal["organization", "restricted"],
) -> str:
    """Convert the registry visibility to the GQL visibility.

    Args:
        visibility: The user-facing visibility, which must be a member name of
            ``_Visibility`` (``"organization"`` or ``"restricted"``).

    Returns:
        The backend value for that visibility (e.g. ``"PRIVATE"``).

    Raises:
        ValueError: If ``visibility`` is not a known visibility name.
    """
    try:
        return _Visibility[visibility].value
    except KeyError:
        # Name lookup failed; report the valid names instead. `from None`
        # hides the uninformative KeyError context.
        valid_names = ", ".join(repr(member.name) for member in _Visibility)
        raise ValueError(
            f"Invalid visibility: {visibility!r}. Must be one of: {valid_names}"
        ) from None
def _ensure_registry_prefix_on_names(query, in_name=False):
"""Traverse the filter to prepend the `name` key value with the registry prefix unless the value is a regex.
- in_name: True if we are under a "name" key (or propagating from one).
EX: {"name": "model"} -> {"name": "wandb-registry-model"}
"""
if isinstance((txt := query), str):
if in_name:
return txt if txt.startswith(REGISTRY_PREFIX) else f"{REGISTRY_PREFIX}{txt}"
return txt
if isinstance((dct := query), Mapping):
new_dict = {}
for key, obj in dct.items():
if key == "name":
new_dict[key] = _ensure_registry_prefix_on_names(obj, in_name=True)
elif key == "$regex":
# For regex operator, we skip transformation of its value.
new_dict[key] = obj
else:
# For any other key, propagate the in_name and skip_transform flags as-is.
new_dict[key] = _ensure_registry_prefix_on_names(obj, in_name=in_name)
return new_dict
if isinstance((objs := query), Sequence):
return list(
map(lambda x: _ensure_registry_prefix_on_names(x, in_name=in_name), objs)
)
return query
@lru_cache(maxsize=10)
def _fetch_org_entity_from_organization(client: "Client", organization: str) -> str:
    """Fetch the org entity from the organization.

    Results are memoized per ``(client, organization)`` pair, up to 10 entries.
    Calls that raise are not cached.

    Args:
        client (Client): Graphql client.
        organization (str): The organization to fetch the org entity for.

    Returns:
        The name of the organization's org entity.

    Raises:
        ValueError: If executing the query fails, or if the response does not
            contain a non-empty org entity name for ``organization``.
    """
    query = gql(
        """
        query FetchOrgEntityFromOrganization($organization: String!) {
            organization(name: $organization) {
                orgEntity {
                    name
                }
            }
        }
        """
    )
    try:
        response = client.execute(query, variable_values={"organization": organization})
    except Exception as e:
        # Any failure while executing the query is reported to callers as a
        # ValueError, with the original error chained as the cause.
        raise ValueError(
            f"Error fetching org entity for organization: {organization}"
        ) from e

    # Each level may be null or missing. Use .get so both cases produce the
    # same ValueError below instead of a KeyError/TypeError.
    org = (response or {}).get("organization") or {}
    org_entity = org.get("orgEntity") or {}
    org_name = org_entity.get("name")
    if not org_name:
        raise ValueError(f"Organization entity for {organization} not found.")
    return org_name
|