from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional

from wandb_gql import gql

import wandb
from wandb.apis.public.registries._freezable_list import AddOnlyArtifactTypesList
from wandb.apis.public.registries.registries_search import Collections, Versions
from wandb.apis.public.registries.utils import (
    _fetch_org_entity_from_organization,
    _format_gql_artifact_types_input,
    _gql_to_registry_visibility,
    _registry_visibility_to_gql,
)
from wandb.proto.wandb_internal_pb2 import ServerFeature
from wandb.sdk.artifacts._validators import REGISTRY_PREFIX, validate_project_name
from wandb.sdk.internal.internal_api import Api as InternalApi
from wandb.sdk.projects._generated.delete_project import DeleteProject
from wandb.sdk.projects._generated.operations import (
    DELETE_PROJECT_GQL,
    FETCH_REGISTRY_GQL,
    RENAME_PROJECT_GQL,
    UPSERT_REGISTRY_PROJECT_GQL,
)
from wandb.sdk.projects._generated.rename_project import RenameProject
from wandb.sdk.projects._generated.upsert_registry_project import UpsertRegistryProject

if TYPE_CHECKING:
    from wandb_gql import Client


class Registry:
    """A single registry in the Registry."""

    def __init__(
        self,
        client: "Client",
        organization: str,
        entity: str,
        name: str,
        attrs: Optional[Dict[str, Any]] = None,
    ):
        """Create a local handle for a registry.

        Args:
            client: The GraphQL client used for all backend requests.
            organization: The organization name the registry belongs to.
            entity: The organization entity that owns the registry project.
            name: The registry name without the `wandb-registry-` prefix.
            attrs: Optional raw attribute mapping used to populate the
                instance (presumably a GraphQL project node — confirm with
                callers). When omitted, no attributes are loaded.
        """
        self.client = client
        self._name = name
        # Name last persisted to the backend; `save()` uses it to find the
        # existing registry and to detect a pending rename.
        self._saved_name = name
        self._entity = entity
        self._organization = organization
        if attrs is not None:
            self._update_attributes(attrs)

    def _update_attributes(self, attrs: Dict[str, Any]) -> None:
        """Helper method to update instance attributes from a dictionary.

        Raises:
            ValueError: If `attrs` has no `id` or its `id` is `None`.
        """
        # No default here: a missing id must be reported, not silently
        # replaced by "" (which `delete()` would later send to the server).
        self._id = attrs.get("id")
        if self._id is None:
            raise ValueError(f"Registry {self.name}'s id is not found")
        self._description = attrs.get("description", "")
        self._allow_all_artifact_types = attrs.get(
            "allowAllArtifactTypesInRegistry", False
        )
        # `artifactTypes` / `edges` may be present but null in the response,
        # so fall back to empty containers instead of calling `.get` on None.
        artifact_type_edges = (attrs.get("artifactTypes") or {}).get("edges") or []
        self._artifact_types = AddOnlyArtifactTypesList(
            t["node"]["name"] for t in artifact_type_edges
        )
        self._created_at = attrs.get("createdAt", "")
        self._updated_at = attrs.get("updatedAt", "")
        self._visibility = _gql_to_registry_visibility(attrs.get("access", ""))

    @property
    def full_name(self) -> str:
        """Full name of the registry including the `wandb-registry-` prefix."""
        # Use the shared constant so this stays in sync with `create`/`save`.
        return f"{REGISTRY_PREFIX}{self.name}"

    @property
    def name(self) -> str:
        """Name of the registry without the `wandb-registry-` prefix."""
        return self._name

    @name.setter
    def name(self, value: str) -> None:
        """Set a new name; the rename is applied on the next `save()`."""
        self._name = value

    @property
    def entity(self) -> str:
        """Organization entity of the registry."""
        return self._entity

    @property
    def organization(self) -> str:
        """Organization name of the registry."""
        return self._organization

    @property
    def description(self) -> str:
        """Description of the registry."""
        return self._description

    @description.setter
    def description(self, value: str) -> None:
        """Set the description of the registry."""
        self._description = value

    @property
    def allow_all_artifact_types(self) -> bool:
        """Returns whether all artifact types are allowed in the registry.

        If `True` then artifacts of any type can be added to this registry.
        If `False` then artifacts are restricted to the types in `artifact_types` for this registry.
        """
        return self._allow_all_artifact_types

    @allow_all_artifact_types.setter
    def allow_all_artifact_types(self, value: bool) -> None:
        """Set whether all artifact types are allowed in the registry."""
        self._allow_all_artifact_types = value

    @property
    def artifact_types(self) -> AddOnlyArtifactTypesList:
        """Returns the artifact types allowed in the registry.

        If `allow_all_artifact_types` is `True` then `artifact_types` reflects the
        types previously saved or currently used in the registry.
        If `allow_all_artifact_types` is `False` then artifacts are restricted to the
        types in `artifact_types`.

        Note:
            Previously saved artifact types cannot be removed.

        Example:
        ```python
        registry.artifact_types.append("model")
        registry.save()  # once saved, the artifact type `model` cannot be removed
        registry.artifact_types.append("accidentally_added")
        registry.artifact_types.remove(
            "accidentally_added"
        )  # Types can only be removed if it has not been saved yet
        ```
        """
        return self._artifact_types

    @property
    def created_at(self) -> str:
        """Timestamp of when the registry was created."""
        return self._created_at

    @property
    def updated_at(self) -> str:
        """Timestamp of when the registry was last updated."""
        return self._updated_at

    @property
    def path(self) -> List[str]:
        """Path of the registry as `[entity, full_name]`."""
        return [self.entity, self.full_name]

    @property
    def visibility(self) -> Literal["organization", "restricted"]:
        """Visibility of the registry.

        Returns:
            Literal["organization", "restricted"]: The visibility level.
                - "organization": Anyone in the organization can view this registry.
                  You can edit their roles later from the settings in the UI.
                - "restricted": Only invited members via the UI can access this registry.
                  Public sharing is disabled.
        """
        return self._visibility

    @visibility.setter
    def visibility(self, value: Literal["organization", "restricted"]) -> None:
        """Set the visibility of the registry.

        Args:
            value: The visibility level. Options are:
                - "organization": Anyone in the organization can view this registry.
                  You can edit their roles later from the settings in the UI.
                - "restricted": Only invited members via the UI can access this registry.
                  Public sharing is disabled.
        """
        self._visibility = value

    def collections(self, filter: Optional[Dict[str, Any]] = None) -> Collections:
        """Returns the collections belonging to the registry."""
        registry_filter = {
            "name": self.full_name,
        }
        return Collections(self.client, self.organization, registry_filter, filter)

    def versions(self, filter: Optional[Dict[str, Any]] = None) -> Versions:
        """Returns the versions belonging to the registry."""
        registry_filter = {
            "name": self.full_name,
        }
        return Versions(self.client, self.organization, registry_filter, None, filter)

    @classmethod
    def create(
        cls,
        client: "Client",
        organization: str,
        name: str,
        visibility: Literal["organization", "restricted"],
        description: Optional[str] = None,
        artifact_types: Optional[List[str]] = None,
    ):
        """Create a new registry.

        The registry name must be unique within the organization.
        This function should be called using `api.create_registry()`

        Args:
            client: The GraphQL client.
            organization: The name of the organization.
            name: The name of the registry (without the `wandb-registry-` prefix).
            visibility: The visibility level ('organization' or 'restricted').
            description: An optional description for the registry.
            artifact_types: An optional list of allowed artifact types. When
                empty or omitted, all artifact types are allowed.

        Returns:
            Registry: The newly created Registry object.

        Raises:
            ValueError: If a registry with the same name already exists in the
                organization or if the creation fails.
        """
        org_entity = _fetch_org_entity_from_organization(client, organization)
        full_name = REGISTRY_PREFIX + name
        validate_project_name(full_name)

        accepted_artifact_types = (
            _format_gql_artifact_types_input(artifact_types) if artifact_types else []
        )
        visibility_value = _registry_visibility_to_gql(visibility)
        registry_creation_error = (
            f"Failed to create registry {name!r} in organization {organization!r}."
        )
        try:
            response = client.execute(
                gql(UPSERT_REGISTRY_PROJECT_GQL),
                variable_values={
                    "description": description,
                    "entityName": org_entity,
                    "name": full_name,
                    "access": visibility_value,
                    # No explicit types means "allow everything".
                    "allowAllArtifactTypesInRegistry": not accepted_artifact_types,
                    "artifactTypes": accepted_artifact_types,
                },
            )
        except Exception as e:
            raise ValueError(registry_creation_error) from e

        # The mutation is an upsert: `inserted` is false when a registry with
        # this name already existed, which counts as a creation failure here.
        if not response["upsertModel"]["inserted"]:
            raise ValueError(registry_creation_error)

        return cls(
            client,
            organization,
            org_entity,
            name,
            response["upsertModel"]["project"],
        )

    def delete(self) -> None:
        """Delete the registry. This is irreversible.

        Raises:
            ValueError: If the request fails or the server reports the deletion
                as unsuccessful.
        """
        delete_failure_message = (
            f"Failed to delete registry: {self.name!r} "
            f"in organization: {self.organization!r}"
        )
        try:
            response = self.client.execute(
                gql(DELETE_PROJECT_GQL), variable_values={"id": self._id}
            )
            result = DeleteProject.model_validate(response)
        except Exception as e:
            raise ValueError(delete_failure_message) from e

        if not result.delete_model.success:
            raise ValueError(delete_failure_message)

    def load(self) -> None:
        """Load the registry attributes from the backend to reflect the latest saved state.

        Raises:
            ValueError: If the request fails or the entity/registry is not found.
        """
        load_failure_message = (
            f"Failed to load registry {self.name!r} "
            f"in organization {self.organization!r}."
        )
        try:
            response = self.client.execute(
                gql(FETCH_REGISTRY_GQL),
                variable_values={
                    "name": self.full_name,
                    "entityName": self.entity,
                },
            )
        except Exception as e:
            raise ValueError(load_failure_message) from e

        if response["entity"] is None:
            raise ValueError(load_failure_message)
        self.attrs = response["entity"]["project"]
        if self.attrs is None:
            raise ValueError(load_failure_message)
        self._update_attributes(self.attrs)

    def save(self) -> None:
        """Save registry attributes to the backend.

        Upserts the registry under its last-saved name. If `name` was changed
        locally, it then issues a rename request.

        Raises:
            RuntimeError: If the connected server version does not support
                saving registries.
            ValueError: If new artifact types were added while
                `allow_all_artifact_types` is `True`, or if the upsert request
                fails. Errors from `validate_project_name` and from the rename
                request propagate unchanged.
        """
        if not InternalApi()._server_supports(
            ServerFeature.INCLUDE_ARTIFACT_TYPES_IN_REGISTRY_CREATION
        ):
            raise RuntimeError(
                "saving the registry is not enabled on this wandb server version. "
                "Please upgrade your server version or contact support at support@wandb.com."
            )

        if self._no_updating_registry_types():
            raise ValueError(
                f"Cannot update artifact types when `allow_all_artifact_types` is {True!r}. "
                f"Set it to {False!r} first."
            )

        validate_project_name(self.full_name)
        visibility_value = _registry_visibility_to_gql(self.visibility)
        # Only unsaved (draft) types are sent; saved ones cannot be removed.
        newly_added_types = _format_gql_artifact_types_input(self.artifact_types.draft)
        registry_save_error = f"Failed to save and update registry: {self.name} in organization: {self.organization}"
        full_saved_name = f"{REGISTRY_PREFIX}{self._saved_name}"

        try:
            response = self.client.execute(
                gql(UPSERT_REGISTRY_PROJECT_GQL),
                variable_values={
                    "description": self.description,
                    "entityName": self.entity,
                    # Use the saved name so we update the original registry
                    # even if `name` has been changed locally.
                    "name": full_saved_name,
                    "access": visibility_value,
                    "allowAllArtifactTypesInRegistry": self.allow_all_artifact_types,
                    "artifactTypes": newly_added_types,
                },
            )
            result = UpsertRegistryProject.model_validate(response)
        except Exception as e:
            raise ValueError(registry_save_error) from e

        if result.upsert_model.inserted:
            # This is not supposed to trigger unless the user has messed with the `_saved_name` variable
            wandb.termlog(
                f"Created registry {self.name!r} in organization {self.organization!r} on save"
            )
        # Refreshing from the response also rebuilds `artifact_types`, which
        # clears the draft list of now-saved types.
        self._update_attributes(response["upsertModel"]["project"])

        # Update the name of the registry if it has changed
        if self._saved_name != self.name:
            response = self.client.execute(
                gql(RENAME_PROJECT_GQL),
                variable_values={
                    "entityName": self.entity,
                    "oldProjectName": full_saved_name,
                    "newProjectName": self.full_name,
                },
            )
            result = RenameProject.model_validate(response)
            self._saved_name = self.name
            if result.rename_project.inserted:
                # This is not supposed to trigger unless the user has messed with the `_saved_name` variable
                wandb.termlog(f"Created new registry {self.name!r} on save")

    def _no_updating_registry_types(self) -> bool:
        """Return True if unsaved artifact types conflict with allow-all mode."""
        # artifact types draft means user assigned types to add that are not yet saved
        return len(self.artifact_types.draft) > 0 and self.allow_all_artifact_types