"""Public API: artifacts.""" from __future__ import annotations import json import re from copy import copy from typing import TYPE_CHECKING, Any, Iterable, Literal, Mapping, Sequence from typing_extensions import override from wandb_gql import Client, gql import wandb from wandb.apis import public from wandb.apis.normalize import normalize_exceptions from wandb.apis.paginator import Paginator, SizedPaginator from wandb.errors.term import termlog from wandb.proto.wandb_deprecated import Deprecated from wandb.proto.wandb_internal_pb2 import ServerFeature from wandb.sdk.artifacts._generated import ( ARTIFACT_COLLECTION_MEMBERSHIP_FILES_GQL, ARTIFACT_VERSION_FILES_GQL, CREATE_ARTIFACT_COLLECTION_TAG_ASSIGNMENTS_GQL, DELETE_ARTIFACT_COLLECTION_TAG_ASSIGNMENTS_GQL, DELETE_ARTIFACT_PORTFOLIO_GQL, DELETE_ARTIFACT_SEQUENCE_GQL, MOVE_ARTIFACT_COLLECTION_GQL, PROJECT_ARTIFACT_COLLECTION_GQL, PROJECT_ARTIFACT_COLLECTIONS_GQL, PROJECT_ARTIFACT_TYPE_GQL, PROJECT_ARTIFACT_TYPES_GQL, PROJECT_ARTIFACTS_GQL, RUN_INPUT_ARTIFACTS_GQL, RUN_OUTPUT_ARTIFACTS_GQL, UPDATE_ARTIFACT_PORTFOLIO_GQL, UPDATE_ARTIFACT_SEQUENCE_GQL, ArtifactCollectionMembershipFiles, ArtifactCollectionsFragment, ArtifactsFragment, ArtifactTypeFragment, ArtifactTypesFragment, ArtifactVersionFiles, FilesFragment, ProjectArtifactCollection, ProjectArtifactCollections, ProjectArtifacts, ProjectArtifactType, ProjectArtifactTypes, RunInputArtifactsProjectRunInputArtifacts, RunOutputArtifactsProjectRunOutputArtifacts, ) from wandb.sdk.artifacts._graphql_fragments import omit_artifact_fields from wandb.sdk.artifacts._validators import ( SOURCE_ARTIFACT_COLLECTION_TYPE, validate_artifact_name, validate_artifact_type, ) from wandb.sdk.internal.internal_api import Api as InternalApi from wandb.sdk.lib import deprecate from .utils import gql_compat if TYPE_CHECKING: from wandb.sdk.artifacts.artifact import Artifact from . import RetryingClient, Run class ArtifactTypes(Paginator["ArtifactType"]): QUERY = gql(PROJECT_ARTIFACT_TYPES_GQL) last_response: ArtifactTypesFragment | None def __init__( self, client: Client, entity: str, project: str, per_page: int = 50, ): self.entity = entity self.project = project variable_values = { "entityName": entity, "projectName": project, } super().__init__(client, variable_values, per_page) @override def _update_response(self) -> None: """Fetch and validate the response data for the current page.""" data = self.client.execute(self.QUERY, variable_values=self.variables) result = ProjectArtifactTypes.model_validate(data) # Extract the inner `*Connection` result for faster/easier access. if not ((proj := result.project) and (conn := proj.artifact_types)): raise ValueError(f"Unable to parse {type(self).__name__!r} response data") self.last_response = ArtifactTypesFragment.model_validate(conn) @property def length(self) -> None: # TODO return None @property def more(self) -> bool: if self.last_response is None: return True return self.last_response.page_info.has_next_page @property def cursor(self) -> str | None: if self.last_response is None: return None return self.last_response.edges[-1].cursor def update_variables(self) -> None: self.variables.update({"cursor": self.cursor}) def convert_objects(self) -> list[ArtifactType]: if self.last_response is None: return [] return [ ArtifactType( client=self.client, entity=self.entity, project=self.project, type_name=node.name, attrs=node.model_dump(exclude_unset=True), ) for edge in self.last_response.edges if edge.node and (node := ArtifactTypeFragment.model_validate(edge.node)) ] class ArtifactType: def __init__( self, client: Client, entity: str, project: str, type_name: str, attrs: Mapping[str, Any] | None = None, ): self.client = client self.entity = entity self.project = project self.type = type_name self._attrs = attrs if self._attrs is None: self.load() def load(self) -> Mapping[str, Any]: data: Mapping[str, Any] | None = self.client.execute( gql(PROJECT_ARTIFACT_TYPE_GQL), variable_values={ "entityName": self.entity, "projectName": self.project, "artifactTypeName": self.type, }, ) result = ProjectArtifactType.model_validate(data) if not ((proj := result.project) and (artifact_type := proj.artifact_type)): raise ValueError(f"Could not find artifact type {self.type}") self._attrs = artifact_type.model_dump(exclude_unset=True) return self._attrs @property def id(self) -> str: return self._attrs["id"] @property def name(self) -> str: return self._attrs["name"] @normalize_exceptions def collections(self, per_page: int = 50) -> ArtifactCollections: """Artifact collections.""" return ArtifactCollections(self.client, self.entity, self.project, self.type) def collection(self, name: str) -> ArtifactCollection: return ArtifactCollection( self.client, self.entity, self.project, name, self.type ) def __repr__(self) -> str: return f"" class ArtifactCollections(SizedPaginator["ArtifactCollection"]): last_response: ArtifactCollectionsFragment | None def __init__( self, client: Client, entity: str, project: str, type_name: str, per_page: int = 50, ): self.entity = entity self.project = project self.type_name = type_name variable_values = { "entityName": entity, "projectName": project, "artifactTypeName": type_name, } if server_supports_artifact_collections_gql_edges(client): rename_fields = None else: rename_fields = {"artifactCollections": "artifactSequences"} self.QUERY = gql_compat( PROJECT_ARTIFACT_COLLECTIONS_GQL, rename_fields=rename_fields ) super().__init__(client, variable_values, per_page) @override def _update_response(self) -> None: """Fetch and validate the response data for the current page.""" data = self.client.execute(self.QUERY, variable_values=self.variables) result = ProjectArtifactCollections.model_validate(data) # Extract the inner `*Connection` result for faster/easier access. if not ( (proj := result.project) and (type_ := proj.artifact_type) and (conn := type_.artifact_collections) ): raise ValueError(f"Unable to parse {type(self).__name__!r} response data") self.last_response = ArtifactCollectionsFragment.model_validate(conn) @property def length(self): if self.last_response is None: return None return self.last_response.total_count @property def more(self): if self.last_response is None: return True return self.last_response.page_info.has_next_page @property def cursor(self): if self.last_response is None: return None return self.last_response.edges[-1].cursor def update_variables(self) -> None: self.variables.update({"cursor": self.cursor}) def convert_objects(self) -> list[ArtifactCollection]: if self.last_response is None: return [] return [ ArtifactCollection( client=self.client, entity=self.entity, project=self.project, name=node.name, type=self.type_name, ) for edge in self.last_response.edges if (node := edge.node) ] class ArtifactCollection: def __init__( self, client: Client, entity: str, project: str, name: str, type: str, organization: str | None = None, attrs: Mapping[str, Any] | None = None, is_sequence: bool | None = None, ): self.client = client self.entity = entity self.project = project self._name = validate_artifact_name(name) self._saved_name = name self._type = type self._saved_type = type self._attrs = attrs if is_sequence is not None: self._is_sequence = is_sequence if (attrs is None) or (is_sequence is None): self.load() self._aliases = [a["node"]["alias"] for a in self._attrs["aliases"]["edges"]] self._description = self._attrs["description"] self._created_at = self._attrs["createdAt"] self._tags = [a["node"]["name"] for a in self._attrs["tags"]["edges"]] self._saved_tags = copy(self._tags) self.organization = organization @property def id(self) -> str: return self._attrs["id"] @normalize_exceptions def artifacts(self, per_page: int = 50) -> Artifacts: """Artifacts.""" return Artifacts( client=self.client, entity=self.entity, project=self.project, collection_name=self._saved_name, type=self._saved_type, per_page=per_page, ) @property def aliases(self) -> list[str]: """Artifact Collection Aliases.""" return self._aliases @property def created_at(self) -> str: return self._created_at def load(self): if server_supports_artifact_collections_gql_edges(self.client): rename_fields = None else: rename_fields = {"artifactCollection": "artifactSequence"} response = self.client.execute( gql_compat(PROJECT_ARTIFACT_COLLECTION_GQL, rename_fields=rename_fields), variable_values={ "entityName": self.entity, "projectName": self.project, "artifactTypeName": self._saved_type, "artifactCollectionName": self._saved_name, }, ) result = ProjectArtifactCollection.model_validate(response) if not ( result.project and (proj := result.project) and (type_ := proj.artifact_type) and (collection := type_.artifact_collection) ): raise ValueError(f"Could not find artifact type {self._saved_type}") sequence = type_.artifact_sequence self._is_sequence = ( sequence is not None ) and sequence.typename__ == SOURCE_ARTIFACT_COLLECTION_TYPE if self._attrs is None: self._attrs = collection.model_dump(exclude_unset=True) return self._attrs @normalize_exceptions def change_type(self, new_type: str) -> None: """Deprecated, change type directly with `save` instead.""" deprecate.deprecate( field_name=Deprecated.artifact_collection__change_type, warning_message="ArtifactCollection.change_type(type) is deprecated, use ArtifactCollection.save() instead.", ) if self._saved_type != new_type: try: validate_artifact_type(self._saved_type, self.name) except ValueError as e: raise ValueError( f"The current type '{self._saved_type!r}' is an internal type and cannot be changed." ) from e # Check that the new type is not going to conflict with internal types validate_artifact_type(new_type, self.name) if not self.is_sequence(): raise ValueError("Artifact collection needs to be a sequence") termlog( f"Changing artifact collection type of {self._saved_type} to {new_type}" ) self.client.execute( gql(MOVE_ARTIFACT_COLLECTION_GQL), variable_values={ "artifactSequenceID": self.id, "destinationArtifactTypeName": new_type, }, ) self._saved_type = new_type self._type = new_type def is_sequence(self) -> bool: """Return whether the artifact collection is a sequence.""" return self._is_sequence @normalize_exceptions def delete(self) -> None: """Delete the entire artifact collection.""" self.client.execute( gql( DELETE_ARTIFACT_SEQUENCE_GQL if self.is_sequence() else DELETE_ARTIFACT_PORTFOLIO_GQL ), variable_values={"id": self.id}, ) @property def description(self) -> str: """A description of the artifact collection.""" return self._description @description.setter def description(self, description: str | None) -> None: self._description = description @property def tags(self) -> list[str]: """The tags associated with the artifact collection.""" return self._tags @tags.setter def tags(self, tags: list[str]) -> None: if any(not re.match(r"^[-\w]+([ ]+[-\w]+)*$", tag) for tag in tags): raise ValueError( "Tags must only contain alphanumeric characters or underscores separated by spaces or hyphens" ) self._tags = tags @property def name(self) -> str: """The name of the artifact collection.""" return self._name @name.setter def name(self, name: str) -> None: self._name = validate_artifact_name(name) @property def type(self): """The type of the artifact collection.""" return self._type @type.setter def type(self, type: list[str]) -> None: if not self.is_sequence(): raise ValueError( "Type can only be changed if the artifact collection is a sequence." ) self._type = type def _update_collection(self) -> None: self.client.execute( gql( UPDATE_ARTIFACT_SEQUENCE_GQL if self.is_sequence() else UPDATE_ARTIFACT_PORTFOLIO_GQL ), variable_values={ "id": self.id, "name": self.name, "description": self.description, }, ) self._saved_name = self._name def _update_collection_type(self) -> None: self.client.execute( gql(MOVE_ARTIFACT_COLLECTION_GQL), variable_values={ "artifactSequenceID": self.id, "destinationArtifactTypeName": self.type, }, ) self._saved_type = self._type def _add_tags(self, tags_to_add: Iterable[str]) -> None: self.client.execute( gql(CREATE_ARTIFACT_COLLECTION_TAG_ASSIGNMENTS_GQL), variable_values={ "entityName": self.entity, "projectName": self.project, "artifactCollectionName": self._saved_name, "tags": [{"tagName": tag} for tag in tags_to_add], }, ) def _delete_tags(self, tags_to_delete: Iterable[str]) -> None: self.client.execute( gql(DELETE_ARTIFACT_COLLECTION_TAG_ASSIGNMENTS_GQL), variable_values={ "entityName": self.entity, "projectName": self.project, "artifactCollectionName": self._saved_name, "tags": [{"tagName": tag} for tag in tags_to_delete], }, ) @normalize_exceptions def save(self) -> None: """Persist any changes made to the artifact collection.""" if self._saved_type != self.type: try: validate_artifact_type(self.type, self._name) except ValueError as e: raise ValueError(f"Failed to save artifact collection: {e}") from e try: validate_artifact_type(self._saved_type, self._name) except ValueError as e: raise ValueError( f"Failed to save artifact collection '{self._name}': " f"The current type '{self._saved_type!r}' is an internal type and cannot be changed." ) from e self._update_collection() if self.is_sequence() and (self._saved_type != self._type): self._update_collection_type() current_tags = set(self._tags) saved_tags = set(self._saved_tags) if tags_to_add := (current_tags - saved_tags): self._add_tags(tags_to_add) if tags_to_delete := (saved_tags - current_tags): self._delete_tags(tags_to_delete) self._saved_tags = copy(self._tags) def __repr__(self) -> str: return f"" class Artifacts(SizedPaginator["Artifact"]): """An iterable collection of artifact versions associated with a project and optional filter. This is generally used indirectly via the `Api`.artifact_versions method. """ last_response: ArtifactsFragment | None def __init__( self, client: Client, entity: str, project: str, collection_name: str, type: str, filters: Mapping[str, Any] | None = None, order: str | None = None, per_page: int = 50, tags: str | list[str] | None = None, ): self.entity = entity self.collection_name = collection_name self.type = type self.project = project self.filters = {"state": "COMMITTED"} if filters is None else filters self.tags = [tags] if isinstance(tags, str) else tags self.order = order variables = { "project": self.project, "entity": self.entity, "order": self.order, "type": self.type, "collection": self.collection_name, "filters": json.dumps(self.filters), } if server_supports_artifact_collections_gql_edges(client): rename_fields = None else: rename_fields = {"artifactCollection": "artifactSequence"} self.QUERY = gql_compat( PROJECT_ARTIFACTS_GQL, omit_fields=omit_artifact_fields(api=InternalApi()), rename_fields=rename_fields, ) super().__init__(client, variables, per_page) @override def _update_response(self) -> None: data = self.client.execute(self.QUERY, variable_values=self.variables) result = ProjectArtifacts.model_validate(data) # Extract the inner `*Connection` result for faster/easier access. if not ( (proj := result.project) and (type_ := proj.artifact_type) and (collection := type_.artifact_collection) and (conn := collection.artifacts) ): raise ValueError(f"Unable to parse {type(self).__name__!r} response data") self.last_response = ArtifactsFragment.model_validate(conn) @property def length(self) -> int | None: if self.last_response is None: return None return self.last_response.total_count @property def more(self) -> bool: if self.last_response is None: return True return self.last_response.page_info.has_next_page @property def cursor(self) -> str | None: if self.last_response is None: return None return self.last_response.edges[-1].cursor def convert_objects(self) -> list[Artifact]: if self.last_response is None: return [] artifact_edges = (edge for edge in self.last_response.edges if edge.node) artifacts = ( wandb.Artifact._from_attrs( entity=self.entity, project=self.project, name=f"{self.collection_name}:{edge.version}", attrs=edge.node.model_dump(exclude_unset=True), client=self.client, ) for edge in artifact_edges ) required_tags = set(self.tags or []) return [art for art in artifacts if required_tags.issubset(art.tags)] class RunArtifacts(SizedPaginator["Artifact"]): last_response: ( RunOutputArtifactsProjectRunOutputArtifacts | RunInputArtifactsProjectRunInputArtifacts ) #: The pydantic model used to parse the (inner part of the) raw response. _response_cls: type[ RunOutputArtifactsProjectRunOutputArtifacts | RunInputArtifactsProjectRunInputArtifacts ] def __init__( self, client: Client, run: Run, mode: Literal["logged", "used"] = "logged", per_page: int = 50, ): self.run = run if mode == "logged": self.run_key = "outputArtifacts" self.QUERY = gql_compat( RUN_OUTPUT_ARTIFACTS_GQL, omit_fields=omit_artifact_fields(api=InternalApi()), ) self._response_cls = RunOutputArtifactsProjectRunOutputArtifacts elif mode == "used": self.run_key = "inputArtifacts" self.QUERY = gql_compat( RUN_INPUT_ARTIFACTS_GQL, omit_fields=omit_artifact_fields(api=InternalApi()), ) self._response_cls = RunInputArtifactsProjectRunInputArtifacts else: raise ValueError("mode must be logged or used") variable_values = { "entity": run.entity, "project": run.project, "runName": run.id, } super().__init__(client, variable_values, per_page) @override def _update_response(self) -> None: data = self.client.execute(self.QUERY, variable_values=self.variables) # Extract the inner `*Connection` result for faster/easier access. inner_data = data["project"]["run"][self.run_key] self.last_response = self._response_cls.model_validate(inner_data) @property def length(self) -> int | None: if self.last_response is None: return None return self.last_response.total_count @property def more(self) -> bool: if self.last_response is None: return True return self.last_response.page_info.has_next_page @property def cursor(self) -> str | None: if self.last_response is None: return None return self.last_response.edges[-1].cursor def convert_objects(self) -> list[Artifact]: if self.last_response is None: return [] return [ wandb.Artifact._from_attrs( entity=proj.entity_name, project=proj.name, name=f"{artifact_seq.name}:v{node.version_index}", attrs=node.model_dump(exclude_unset=True), client=self.client, ) for edge in self.last_response.edges if (node := edge.node) and (artifact_seq := node.artifact_sequence) and (proj := artifact_seq.project) ] class ArtifactFiles(SizedPaginator["public.File"]): last_response: FilesFragment | None def __init__( self, client: Client, artifact: Artifact, names: Sequence[str] | None = None, per_page: int = 50, ): self.query_via_membership = InternalApi()._server_supports( ServerFeature.ARTIFACT_COLLECTION_MEMBERSHIP_FILES ) self.artifact = artifact if self.query_via_membership: query_str = ARTIFACT_COLLECTION_MEMBERSHIP_FILES_GQL variables = { "entityName": artifact.entity, "projectName": artifact.project, "artifactName": artifact.name.split(":")[0], "artifactVersionIndex": artifact.version, "fileNames": names, } else: query_str = ARTIFACT_VERSION_FILES_GQL variables = { "entityName": artifact.source_entity, "projectName": artifact.source_project, "artifactName": artifact.source_name, "artifactTypeName": artifact.type, "fileNames": names, } # The server must advertise at least SDK 0.12.21 # to get storagePath if not client.version_supported("0.12.21"): self.QUERY = gql_compat(query_str, omit_fields={"storagePath"}) else: self.QUERY = gql(query_str) super().__init__(client, variables, per_page) @override def _update_response(self) -> None: data = self.client.execute(self.QUERY, variable_values=self.variables) # Extract the inner `*Connection` result for faster/easier access. if self.query_via_membership: result = ArtifactCollectionMembershipFiles.model_validate(data) conn = result.project.artifact_collection.artifact_membership.files else: result = ArtifactVersionFiles.model_validate(data) conn = result.project.artifact_type.artifact.files if conn is None: raise ValueError(f"Unable to parse {type(self).__name__!r} response data") self.last_response = FilesFragment.model_validate(conn) @property def path(self) -> list[str]: return [self.artifact.entity, self.artifact.project, self.artifact.name] @property def length(self) -> int: return self.artifact.file_count @property def more(self) -> bool: if self.last_response is None: return True return self.last_response.page_info.has_next_page @property def cursor(self) -> str | None: if self.last_response is None: return None return self.last_response.edges[-1].cursor def update_variables(self) -> None: self.variables.update({"fileLimit": self.per_page, "fileCursor": self.cursor}) def convert_objects(self) -> list[public.File]: if self.last_response is None: return [] return [ public.File( client=self.client, attrs=node.model_dump(exclude_unset=True), ) for edge in self.last_response.edges if (node := edge.node) ] def __repr__(self) -> str: path_str = "/".join(self.path) return f"" def server_supports_artifact_collections_gql_edges( client: RetryingClient, warn: bool = False ) -> bool: # TODO: Validate this version # Edges were merged into core on Mar 2, 2022: https://github.com/wandb/core/commit/81c90b29eaacfe0a96dc1ebd83c53560ca763e8b # CLI version was bumped to "0.12.11" on Mar 3, 2022: https://github.com/wandb/core/commit/328396fa7c89a2178d510a1be9c0d4451f350d7b supported = client.version_supported("0.12.11") # edges were merged on if not supported and warn: # First local release to include the above is 0.9.50: https://github.com/wandb/local/releases/tag/0.9.50 wandb.termwarn( "W&B Local Server version does not support ArtifactCollection gql edges; falling back to using legacy ArtifactSequence. Please update server to at least version 0.9.50." ) return supported