|
from enum import Enum |
|
from functools import lru_cache |
|
from typing import TYPE_CHECKING, Any, List, Literal, Mapping, Optional, Sequence |
|
|
|
from wandb.sdk.artifacts._validators import ( |
|
REGISTRY_PREFIX, |
|
validate_artifact_types_list, |
|
) |
|
|
|
if TYPE_CHECKING: |
|
from wandb_gql import Client |
|
|
|
from wandb_gql import gql |
|
|
|
|
|
class _Visibility(str, Enum):
    """Two-way mapping between registry visibility names and GQL values.

    Each member's name is the Python-facing visibility string and its value
    is the string used on the GQL side.
    """

    organization = "PRIVATE"
    restricted = "RESTRICTED"

    @classmethod
    def _missing_(cls, value: object) -> Any:
        """Fall back to a lookup by member name when lookup by value fails.

        Returning None lets the enum machinery raise its usual ValueError.
        """
        for member in cls:
            if member.name == value:
                return member
        return None
|
|
|
|
|
def _format_gql_artifact_types_input(
    artifact_types: Optional[List[str]] = None,
):
    """Build the GQL input payload for a registry's artifact types.

    Args:
        artifact_types: The artifact types to add to the registry, or None
            when there are none to add.

    Returns:
        A list with one `{"name": <type name>}` dict per validated artifact
        type; an empty list when `artifact_types` is None.
    """
    if artifact_types is not None:
        # Validation is delegated to the shared artifact-types validator.
        validated_names = validate_artifact_types_list(artifact_types)
        return [{"name": type_name} for type_name in validated_names]
    return []
|
|
|
|
|
def _gql_to_registry_visibility(
    visibility: str,
) -> Literal["organization", "restricted"]:
    """Convert the GQL visibility to the registry visibility.

    Args:
        visibility: The GQL visibility.

    Returns:
        The registry visibility.

    Raises:
        ValueError: If `visibility` does not match any known visibility.
    """
    try:
        return _Visibility(visibility).name
    except ValueError as exc:
        # Chain explicitly so the traceback marks the failed enum lookup as
        # the cause instead of an unrelated error raised during handling.
        raise ValueError(f"Invalid visibility: {visibility!r} from backend") from exc
|
|
|
|
|
def _registry_visibility_to_gql(
    visibility: Literal["organization", "restricted"],
) -> str:
    """Convert the registry visibility to the GQL visibility.

    Args:
        visibility: The registry visibility.

    Returns:
        The GQL visibility.

    Raises:
        ValueError: If `visibility` is not the name of a known visibility.
    """
    try:
        return _Visibility[visibility].value
    except KeyError as exc:
        valid_names = ", ".join(repr(member.name) for member in _Visibility)
        # Chain explicitly so the traceback marks the failed name lookup as
        # the cause instead of an unrelated error raised during handling.
        raise ValueError(
            f"Invalid visibility: {visibility!r}. Must be one of: {valid_names}"
        ) from exc
|
|
|
|
|
def _ensure_registry_prefix_on_names(query, in_name=False):
    """Return a copy of the filter with registry-prefixed `name` values.

    Walks the filter recursively. String values found under a "name" key (at
    any depth below it) get the registry prefix prepended unless they already
    start with it. Values of "$regex" keys are left untouched.

    - in_name: True if we are under a "name" key (or propagating from one).

    EX: {"name": "model"} -> {"name": "wandb-registry-model"}
    """
    if isinstance(query, str):
        if not in_name or query.startswith(REGISTRY_PREFIX):
            return query
        return f"{REGISTRY_PREFIX}{query}"

    if isinstance(query, Mapping):
        prefixed = {}
        for key, value in query.items():
            if key == "$regex":
                # Regex patterns are passed through as-is.
                prefixed[key] = value
            elif key == "name":
                # Everything below a "name" key is eligible for prefixing.
                prefixed[key] = _ensure_registry_prefix_on_names(value, in_name=True)
            else:
                # Any other key inherits the current setting.
                prefixed[key] = _ensure_registry_prefix_on_names(
                    value, in_name=in_name
                )
        return prefixed

    if isinstance(query, Sequence):
        return [
            _ensure_registry_prefix_on_names(item, in_name=in_name)
            for item in query
        ]

    # Anything else (numbers, None, ...) is returned unchanged.
    return query
|
|
|
|
|
@lru_cache(maxsize=10)
def _fetch_org_entity_from_organization(client: "Client", organization: str) -> str:
    """Look up the name of an organization's org entity.

    Results are memoized per (client, organization) pair by `lru_cache`.

    Args:
        client (Client): Graphql client.
        organization (str): The organization to fetch the org entity for.

    Returns:
        The name of the organization's org entity.

    Raises:
        ValueError: If executing the query fails, or if the response has no
            organization, org entity, or org entity name.
    """
    query = gql(
        """
        query FetchOrgEntityFromOrganization($organization: String!) {
            organization(name: $organization) {
                orgEntity {
                    name
                }
            }
        }
        """
    )
    variables = {"organization": organization}
    try:
        response = client.execute(query, variable_values=variables)
    except Exception as exc:
        raise ValueError(
            f"Error fetching org entity for organization: {organization}"
        ) from exc

    # Walk response -> organization -> orgEntity -> name, stopping at the
    # first missing or empty level.
    org = response["organization"]
    org_entity = org["orgEntity"] if org else None
    org_name = org_entity["name"] if org_entity else None
    if not org_name:
        raise ValueError(f"Organization entity for {organization} not found.")

    return org_name
|
|