|
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional |
|
|
|
from wandb_gql import gql |
|
|
|
import wandb |
|
from wandb.apis.public.registries._freezable_list import AddOnlyArtifactTypesList |
|
from wandb.apis.public.registries.registries_search import Collections, Versions |
|
from wandb.apis.public.registries.utils import ( |
|
_fetch_org_entity_from_organization, |
|
_format_gql_artifact_types_input, |
|
_gql_to_registry_visibility, |
|
_registry_visibility_to_gql, |
|
) |
|
from wandb.proto.wandb_internal_pb2 import ServerFeature |
|
from wandb.sdk.artifacts._validators import REGISTRY_PREFIX, validate_project_name |
|
from wandb.sdk.internal.internal_api import Api as InternalApi |
|
from wandb.sdk.projects._generated.delete_project import DeleteProject |
|
from wandb.sdk.projects._generated.operations import ( |
|
DELETE_PROJECT_GQL, |
|
FETCH_REGISTRY_GQL, |
|
RENAME_PROJECT_GQL, |
|
UPSERT_REGISTRY_PROJECT_GQL, |
|
) |
|
from wandb.sdk.projects._generated.rename_project import RenameProject |
|
from wandb.sdk.projects._generated.upsert_registry_project import UpsertRegistryProject |
|
|
|
if TYPE_CHECKING: |
|
from wandb_gql import Client |
|
|
|
|
|
class Registry:
    """A single registry in the Registry."""

    def __init__(
        self,
        client: "Client",
        organization: str,
        entity: str,
        name: str,
        attrs: Optional[Dict[str, Any]] = None,
    ):
        """Create a handle to a registry.

        Args:
            client: The GraphQL client used for all backend requests.
            organization: The organization name the registry belongs to.
            entity: The organization entity that owns the registry project.
            name: The registry name without the `wandb-registry-` prefix.
            attrs: Optional backend payload used to populate the registry's
                fields immediately (see `_update_attributes`). When omitted,
                call `load()` before reading those fields.
        """
        self.client = client
        self._name = name
        # Name as last persisted on the server. `save()` compares it with
        # `name` to detect (and apply) a pending rename.
        self._saved_name = name
        self._entity = entity
        self._organization = organization
        if attrs is not None:
            self._update_attributes(attrs)

    def _update_attributes(self, attrs: Dict[str, Any]) -> None:
        """Populate instance attributes from a backend project payload.

        Args:
            attrs: The registry project payload. Keys read: `id`,
                `description`, `allowAllArtifactTypesInRegistry`,
                `artifactTypes` (shaped `{"edges": [{"node": {"name": ...}}]}`),
                `createdAt`, `updatedAt`, `access`.

        Raises:
            ValueError: If the payload has no `id` or a null `id`.
        """
        # Default to None (not "") so that a missing key is caught below as
        # well as an explicit null.
        self._id = attrs.get("id")
        if self._id is None:
            raise ValueError(f"Registry {self.name}'s id is not found")

        self._description = attrs.get("description", "")
        self._allow_all_artifact_types = attrs.get(
            "allowAllArtifactTypesInRegistry", False
        )
        # Tolerate `artifactTypes: null` or `edges: null` in the payload.
        artifact_type_edges = (attrs.get("artifactTypes") or {}).get("edges") or []
        self._artifact_types = AddOnlyArtifactTypesList(
            edge["node"]["name"] for edge in artifact_type_edges
        )
        self._created_at = attrs.get("createdAt", "")
        self._updated_at = attrs.get("updatedAt", "")
        self._visibility = _gql_to_registry_visibility(attrs.get("access", ""))

    @property
    def full_name(self) -> str:
        """Full name of the registry including the `wandb-registry-` prefix."""
        return f"{REGISTRY_PREFIX}{self.name}"

    @property
    def name(self) -> str:
        """Name of the registry without the `wandb-registry-` prefix."""
        return self._name

    @name.setter
    def name(self, value: str):
        """Set a new name locally; it is applied on the next `save()`."""
        self._name = value

    @property
    def entity(self) -> str:
        """Organization entity of the registry."""
        return self._entity

    @property
    def organization(self) -> str:
        """Organization name of the registry."""
        return self._organization

    @property
    def description(self) -> str:
        """Description of the registry."""
        return self._description

    @description.setter
    def description(self, value: str):
        """Set the description of the registry."""
        self._description = value

    @property
    def allow_all_artifact_types(self) -> bool:
        """Returns whether all artifact types are allowed in the registry.

        If `True` then artifacts of any type can be added to this registry.
        If `False` then artifacts are restricted to the types in `artifact_types` for this registry.
        """
        return self._allow_all_artifact_types

    @allow_all_artifact_types.setter
    def allow_all_artifact_types(self, value: bool):
        """Set whether all artifact types are allowed in the registry."""
        self._allow_all_artifact_types = value

    @property
    def artifact_types(self) -> AddOnlyArtifactTypesList:
        """Returns the artifact types allowed in the registry.

        If `allow_all_artifact_types` is `True` then `artifact_types` reflects the
        types previously saved or currently used in the registry.
        If `allow_all_artifact_types` is `False` then artifacts are restricted to the
        types in `artifact_types`.

        Note:
            Previously saved artifact types cannot be removed.

        Example:
            ```python
            registry.artifact_types.append("model")
            registry.save()  # once saved, the artifact type `model` cannot be removed
            registry.artifact_types.append("accidentally_added")
            registry.artifact_types.remove(
                "accidentally_added"
            )  # Types can only be removed if they have not been saved yet
            ```
        """
        return self._artifact_types

    @property
    def created_at(self) -> str:
        """Timestamp of when the registry was created."""
        return self._created_at

    @property
    def updated_at(self) -> str:
        """Timestamp of when the registry was last updated."""
        return self._updated_at

    @property
    def path(self) -> List[str]:
        """The `[entity, full_name]` pair identifying the registry project."""
        return [self.entity, self.full_name]

    @property
    def visibility(self) -> Literal["organization", "restricted"]:
        """Visibility of the registry.

        Returns:
            Literal["organization", "restricted"]: The visibility level.
                - "organization": Anyone in the organization can view this registry.
                  You can edit their roles later from the settings in the UI.
                - "restricted": Only invited members via the UI can access this registry.
                  Public sharing is disabled.
        """
        return self._visibility

    @visibility.setter
    def visibility(self, value: Literal["organization", "restricted"]):
        """Set the visibility of the registry.

        Args:
            value: The visibility level. Options are:
                - "organization": Anyone in the organization can view this registry.
                  You can edit their roles later from the settings in the UI.
                - "restricted": Only invited members via the UI can access this registry.
                  Public sharing is disabled.
        """
        self._visibility = value

    def collections(self, filter: Optional[Dict[str, Any]] = None) -> Collections:
        """Returns the collections belonging to the registry.

        Args:
            filter: Optional extra filter forwarded to `Collections`.
        """
        registry_filter = {
            "name": self.full_name,
        }
        return Collections(self.client, self.organization, registry_filter, filter)

    def versions(self, filter: Optional[Dict[str, Any]] = None) -> Versions:
        """Returns the versions belonging to the registry.

        Args:
            filter: Optional extra filter forwarded to `Versions` (as its
                final argument; no collection filter is applied).
        """
        registry_filter = {
            "name": self.full_name,
        }
        return Versions(self.client, self.organization, registry_filter, None, filter)

    @classmethod
    def create(
        cls,
        client: "Client",
        organization: str,
        name: str,
        visibility: Literal["organization", "restricted"],
        description: Optional[str] = None,
        artifact_types: Optional[List[str]] = None,
    ):
        """Create a new registry.

        The registry name must be unique within the organization.
        This function should be called using `api.create_registry()`

        Args:
            client: The GraphQL client.
            organization: The name of the organization.
            name: The name of the registry (without the `wandb-registry-` prefix).
            visibility: The visibility level ('organization' or 'restricted').
            description: An optional description for the registry.
            artifact_types: An optional list of allowed artifact types. When
                empty or omitted, all artifact types are allowed.

        Returns:
            Registry: The newly created Registry object.

        Raises:
            ValueError: If a registry with the same name already exists in the
                organization or if the creation fails.
        """
        org_entity = _fetch_org_entity_from_organization(client, organization)
        full_name = REGISTRY_PREFIX + name
        validate_project_name(full_name)

        accepted_artifact_types = []
        if artifact_types:
            accepted_artifact_types = _format_gql_artifact_types_input(artifact_types)
        visibility_value = _registry_visibility_to_gql(visibility)
        registry_creation_error = (
            f"Failed to create registry {name!r} in organization {organization!r}."
        )

        # NOTE(review): this goes through an *upsert* mutation; when the
        # registry already exists the request is still sent and only the
        # `inserted` flag below turns it into an error. Confirm whether the
        # server leaves the existing registry untouched in that case.
        try:
            response = client.execute(
                gql(UPSERT_REGISTRY_PROJECT_GQL),
                variable_values={
                    "description": description,
                    "entityName": org_entity,
                    "name": full_name,
                    "access": visibility_value,
                    # No explicit types means every type is allowed.
                    "allowAllArtifactTypesInRegistry": not accepted_artifact_types,
                    "artifactTypes": accepted_artifact_types,
                },
            )
        except Exception as e:
            raise ValueError(registry_creation_error) from e
        if not response["upsertModel"]["inserted"]:
            raise ValueError(registry_creation_error)

        return cls(
            client,
            organization,
            org_entity,
            name,
            response["upsertModel"]["project"],
        )

    def delete(self) -> None:
        """Delete the registry. This is irreversible.

        Raises:
            ValueError: If the request fails or the backend reports that the
                deletion did not succeed.
        """
        delete_failure_message = (
            f"Failed to delete registry: {self.name!r} in organization: {self.organization!r}"
        )
        try:
            response = self.client.execute(
                gql(DELETE_PROJECT_GQL), variable_values={"id": self._id}
            )
            result = DeleteProject.model_validate(response)
        except Exception as e:
            raise ValueError(delete_failure_message) from e
        if not result.delete_model.success:
            raise ValueError(delete_failure_message)

    def load(self) -> None:
        """Load the registry attributes from the backend to reflect the latest saved state.

        Raises:
            ValueError: If the request fails or the registry (or its
                organization entity) is not found.
        """
        load_failure_message = (
            f"Failed to load registry {self.name!r} "
            f"in organization {self.organization!r}."
        )
        try:
            response = self.client.execute(
                gql(FETCH_REGISTRY_GQL),
                variable_values={
                    "name": self.full_name,
                    "entityName": self.entity,
                },
            )
        except Exception as e:
            raise ValueError(load_failure_message) from e
        if response["entity"] is None:
            raise ValueError(load_failure_message)
        self.attrs = response["entity"]["project"]
        if self.attrs is None:
            raise ValueError(load_failure_message)
        self._update_attributes(self.attrs)

    def save(self) -> None:
        """Save registry attributes to the backend.

        Upserts the registry under its last-saved name, refreshes the local
        attributes from the server response, and then renames the registry if
        `name` was changed since it was last saved or loaded.

        Raises:
            RuntimeError: If the server does not support saving registries.
            ValueError: If artifact types were added while
                `allow_all_artifact_types` is True, if the new name fails
                validation, or if the upsert request fails.
        """
        if not InternalApi()._server_supports(
            ServerFeature.INCLUDE_ARTIFACT_TYPES_IN_REGISTRY_CREATION
        ):
            raise RuntimeError(
                "saving the registry is not enabled on this wandb server version. "
                "Please upgrade your server version or contact support at support@wandb.com."
            )

        if self._no_updating_registry_types():
            raise ValueError(
                f"Cannot update artifact types when `allow_all_artifact_types` is {True!r}. Set it to {False!r} first."
            )

        validate_project_name(self.full_name)
        visibility_value = _registry_visibility_to_gql(self.visibility)
        newly_added_types = _format_gql_artifact_types_input(self.artifact_types.draft)
        registry_save_error = f"Failed to save and update registry: {self.name} in organization: {self.organization}"
        # Upsert against the name the server currently knows; a pending rename
        # is applied separately afterwards.
        full_saved_name = f"{REGISTRY_PREFIX}{self._saved_name}"
        try:
            response = self.client.execute(
                gql(UPSERT_REGISTRY_PROJECT_GQL),
                variable_values={
                    "description": self.description,
                    "entityName": self.entity,
                    "name": full_saved_name,
                    "access": visibility_value,
                    "allowAllArtifactTypesInRegistry": self.allow_all_artifact_types,
                    "artifactTypes": newly_added_types,
                },
            )
            upsert_result = UpsertRegistryProject.model_validate(response)
        except Exception as e:
            raise ValueError(registry_save_error) from e

        if upsert_result.upsert_model.inserted:
            # Only expected when `_saved_name` does not match an existing
            # registry, i.e. the save created a new one.
            wandb.termlog(
                f"Created registry {self.name!r} in organization {self.organization!r} on save"
            )
        # Refresh local state from the server whether the registry was inserted
        # or updated, so `artifact_types` etc. reflect what was saved.
        self._update_attributes(response["upsertModel"]["project"])

        # Apply a pending rename after the attribute update succeeded.
        if self._saved_name != self.name:
            rename_response = self.client.execute(
                gql(RENAME_PROJECT_GQL),
                variable_values={
                    "entityName": self.entity,
                    "oldProjectName": full_saved_name,
                    "newProjectName": self.full_name,
                },
            )
            rename_result = RenameProject.model_validate(rename_response)
            self._saved_name = self.name
            if rename_result.rename_project.inserted:
                wandb.termlog(f"Created new registry {self.name!r} on save")

    def _no_updating_registry_types(self) -> bool:
        """Return True if unsaved artifact types were added while all types are allowed.

        `save()` rejects this combination and asks the caller to set
        `allow_all_artifact_types` to False first.
        """
        return len(self.artifact_types.draft) > 0 and self.allow_all_artifact_types
|
|