jamtur01's picture
Upload folder using huggingface_hub
9c6594c verified
from enum import Enum
from functools import lru_cache
from typing import TYPE_CHECKING, Any, List, Literal, Mapping, Optional, Sequence
from wandb.sdk.artifacts._validators import (
REGISTRY_PREFIX,
validate_artifact_types_list,
)
if TYPE_CHECKING:
from wandb_gql import Client
from wandb_gql import gql
class _Visibility(str, Enum):
# names are what users see/pass into Python methods
# values are what's expected by backend API
organization = "PRIVATE"
restricted = "RESTRICTED"
@classmethod
def _missing_(cls, value: object) -> Any:
return next(
(e for e in cls if e.name == value),
None,
)
def _format_gql_artifact_types_input(
artifact_types: Optional[List[str]] = None,
):
"""Format the artifact types for the GQL input.
Args:
artifact_types: The artifact types to add to the registry.
Returns:
The artifact types for the GQL input.
"""
if artifact_types is None:
return []
new_types = validate_artifact_types_list(artifact_types)
return [{"name": type} for type in new_types]
def _gql_to_registry_visibility(
    visibility: str,
) -> Literal["organization", "restricted"]:
    """Convert the GQL visibility to the registry visibility.

    Args:
        visibility: The GQL visibility.

    Returns:
        The name of the `_Visibility` member that matches `visibility`.

    Raises:
        ValueError: If `visibility` does not match any `_Visibility` member.
    """
    try:
        member = _Visibility(visibility)
    except ValueError:
        # `from None` keeps the enum's own lookup error out of the traceback;
        # the message below already carries the offending value.
        raise ValueError(
            f"Invalid visibility: {visibility!r} from backend"
        ) from None
    return member.name
def _registry_visibility_to_gql(
    visibility: Literal["organization", "restricted"],
) -> str:
    """Convert the registry visibility to the GQL visibility.

    Args:
        visibility: The registry visibility, i.e. the name of a
            `_Visibility` member.

    Returns:
        The value of the `_Visibility` member named `visibility`.

    Raises:
        ValueError: If `visibility` is not the name of a `_Visibility` member.
    """
    try:
        return _Visibility[visibility].value
    except KeyError:
        valid_names = ", ".join(repr(member.name) for member in _Visibility)
        # `from None` hides the internal KeyError so callers see only the
        # actionable message listing the accepted names.
        raise ValueError(
            f"Invalid visibility: {visibility!r}. "
            f"Must be one of: {valid_names}"
        ) from None
def _ensure_registry_prefix_on_names(query, in_name=False):
"""Traverse the filter to prepend the `name` key value with the registry prefix unless the value is a regex.
- in_name: True if we are under a "name" key (or propagating from one).
EX: {"name": "model"} -> {"name": "wandb-registry-model"}
"""
if isinstance((txt := query), str):
if in_name:
return txt if txt.startswith(REGISTRY_PREFIX) else f"{REGISTRY_PREFIX}{txt}"
return txt
if isinstance((dct := query), Mapping):
new_dict = {}
for key, obj in dct.items():
if key == "name":
new_dict[key] = _ensure_registry_prefix_on_names(obj, in_name=True)
elif key == "$regex":
# For regex operator, we skip transformation of its value.
new_dict[key] = obj
else:
# For any other key, propagate the in_name and skip_transform flags as-is.
new_dict[key] = _ensure_registry_prefix_on_names(obj, in_name=in_name)
return new_dict
if isinstance((objs := query), Sequence):
return list(
map(lambda x: _ensure_registry_prefix_on_names(x, in_name=in_name), objs)
)
return query
@lru_cache(maxsize=10)
def _fetch_org_entity_from_organization(client: "Client", organization: str) -> str:
    """Look up the name of an organization's org entity.

    Successful results are memoized by ``lru_cache`` (up to 10 entries),
    keyed on the ``(client, organization)`` arguments.

    Args:
        client (Client): Graphql client.
        organization (str): The organization to fetch the org entity for.

    Returns:
        The ``name`` of the organization's ``orgEntity``.

    Raises:
        ValueError: If executing the query fails, or if the response has no
            organization, org entity, or org entity name.
    """
    document = gql(
        """
query FetchOrgEntityFromOrganization($organization: String!) {
organization(name: $organization) {
orgEntity {
name
}
}
}
"""
    )

    variables = {"organization": organization}
    try:
        response = client.execute(document, variable_values=variables)
    except Exception as exc:
        raise ValueError(
            f"Error fetching org entity for organization: {organization}"
        ) from exc

    # Walk down the response one level at a time; any empty level means the
    # org entity could not be resolved.
    org = response["organization"]
    org_entity = org["orgEntity"] if org else None
    org_name = org_entity["name"] if org_entity else None
    if not org_name:
        raise ValueError(f"Organization entity for {organization} not found.")
    return org_name