|
"""Public API: artifacts.""" |
|
|
|
from __future__ import annotations |
|
|
|
import json |
|
import re |
|
from copy import copy |
|
from typing import TYPE_CHECKING, Any, Iterable, Literal, Mapping, Sequence |
|
|
|
from typing_extensions import override |
|
from wandb_gql import Client, gql |
|
|
|
import wandb |
|
from wandb.apis import public |
|
from wandb.apis.normalize import normalize_exceptions |
|
from wandb.apis.paginator import Paginator, SizedPaginator |
|
from wandb.errors.term import termlog |
|
from wandb.proto.wandb_deprecated import Deprecated |
|
from wandb.proto.wandb_internal_pb2 import ServerFeature |
|
from wandb.sdk.artifacts._generated import ( |
|
ARTIFACT_COLLECTION_MEMBERSHIP_FILES_GQL, |
|
ARTIFACT_VERSION_FILES_GQL, |
|
CREATE_ARTIFACT_COLLECTION_TAG_ASSIGNMENTS_GQL, |
|
DELETE_ARTIFACT_COLLECTION_TAG_ASSIGNMENTS_GQL, |
|
DELETE_ARTIFACT_PORTFOLIO_GQL, |
|
DELETE_ARTIFACT_SEQUENCE_GQL, |
|
MOVE_ARTIFACT_COLLECTION_GQL, |
|
PROJECT_ARTIFACT_COLLECTION_GQL, |
|
PROJECT_ARTIFACT_COLLECTIONS_GQL, |
|
PROJECT_ARTIFACT_TYPE_GQL, |
|
PROJECT_ARTIFACT_TYPES_GQL, |
|
PROJECT_ARTIFACTS_GQL, |
|
RUN_INPUT_ARTIFACTS_GQL, |
|
RUN_OUTPUT_ARTIFACTS_GQL, |
|
UPDATE_ARTIFACT_PORTFOLIO_GQL, |
|
UPDATE_ARTIFACT_SEQUENCE_GQL, |
|
ArtifactCollectionMembershipFiles, |
|
ArtifactCollectionsFragment, |
|
ArtifactsFragment, |
|
ArtifactTypeFragment, |
|
ArtifactTypesFragment, |
|
ArtifactVersionFiles, |
|
FilesFragment, |
|
ProjectArtifactCollection, |
|
ProjectArtifactCollections, |
|
ProjectArtifacts, |
|
ProjectArtifactType, |
|
ProjectArtifactTypes, |
|
RunInputArtifactsProjectRunInputArtifacts, |
|
RunOutputArtifactsProjectRunOutputArtifacts, |
|
) |
|
from wandb.sdk.artifacts._graphql_fragments import omit_artifact_fields |
|
from wandb.sdk.artifacts._validators import ( |
|
SOURCE_ARTIFACT_COLLECTION_TYPE, |
|
validate_artifact_name, |
|
validate_artifact_type, |
|
) |
|
from wandb.sdk.internal.internal_api import Api as InternalApi |
|
from wandb.sdk.lib import deprecate |
|
|
|
from .utils import gql_compat |
|
|
|
if TYPE_CHECKING: |
|
from wandb.sdk.artifacts.artifact import Artifact |
|
|
|
from . import RetryingClient, Run |
|
|
|
|
|
class ArtifactTypes(Paginator["ArtifactType"]):
    """Paginated iterator over the artifact types of a project.

    Pages are fetched with `PROJECT_ARTIFACT_TYPES_GQL`; every node of a page
    is wrapped in an `ArtifactType`.

    Args:
        client: Client used to execute the GraphQL query.
        entity: Entity name, sent as the `entityName` query variable.
        project: Project name, sent as the `projectName` query variable.
        per_page: Page size, forwarded to the base paginator.
    """

    QUERY = gql(PROJECT_ARTIFACT_TYPES_GQL)

    # Parsed connection of the most recently fetched page (None before that).
    last_response: ArtifactTypesFragment | None

    def __init__(
        self,
        client: Client,
        entity: str,
        project: str,
        per_page: int = 50,
    ):
        self.entity = entity
        self.project = project
        super().__init__(
            client,
            {"entityName": entity, "projectName": project},
            per_page,
        )

    @override
    def _update_response(self) -> None:
        """Execute the query with the current variables and store the page.

        Raises:
            ValueError: If the response has no project or no artifact types.
        """
        raw = self.client.execute(self.QUERY, variable_values=self.variables)
        project = ProjectArtifactTypes.model_validate(raw).project

        # Only descend into the connection when the project itself is present.
        connection = project.artifact_types if project else None
        if not connection:
            raise ValueError(f"Unable to parse {type(self).__name__!r} response data")

        self.last_response = ArtifactTypesFragment.model_validate(connection)

    @property
    def length(self) -> None:
        """Always `None`; this paginator does not report a total."""
        return None

    @property
    def more(self) -> bool:
        """Whether another page may exist (`True` until a page is fetched)."""
        response = self.last_response
        return True if response is None else response.page_info.has_next_page

    @property
    def cursor(self) -> str | None:
        """Cursor of the last edge on the current page, or `None` if no page yet."""
        response = self.last_response
        return None if response is None else response.edges[-1].cursor

    def update_variables(self) -> None:
        """Point the query variables at the position after the current page."""
        self.variables["cursor"] = self.cursor

    def convert_objects(self) -> list[ArtifactType]:
        """Wrap every node of the current page in an `ArtifactType`."""
        response = self.last_response
        if response is None:
            return []

        converted = []
        for edge in response.edges:
            # Edges without a node carry nothing to convert.
            if not edge.node:
                continue
            fragment = ArtifactTypeFragment.model_validate(edge.node)
            if not fragment:
                continue
            converted.append(
                ArtifactType(
                    client=self.client,
                    entity=self.entity,
                    project=self.project,
                    type_name=fragment.name,
                    attrs=fragment.model_dump(exclude_unset=True),
                )
            )
        return converted
|
|
|
|
|
class ArtifactType:
    """An artifact type within a project.

    Args:
        client: Client used to execute GraphQL queries.
        entity: Name of the entity that owns the project.
        project: Name of the project the type belongs to.
        type_name: Name of the artifact type.
        attrs: Previously fetched attributes of the type. When omitted, the
            attributes are loaded from the server during construction.
    """

    def __init__(
        self,
        client: Client,
        entity: str,
        project: str,
        type_name: str,
        attrs: Mapping[str, Any] | None = None,
    ):
        self.client = client
        self.entity = entity
        self.project = project
        self.type = type_name
        self._attrs = attrs
        if self._attrs is None:
            self.load()

    def load(self) -> Mapping[str, Any]:
        """Fetch the type's attributes from the server and cache them.

        Returns:
            The attribute mapping that was stored on the instance.

        Raises:
            ValueError: If the response has no project or no artifact type.
        """
        data: Mapping[str, Any] | None = self.client.execute(
            gql(PROJECT_ARTIFACT_TYPE_GQL),
            variable_values={
                "entityName": self.entity,
                "projectName": self.project,
                "artifactTypeName": self.type,
            },
        )
        result = ProjectArtifactType.model_validate(data)
        if not ((proj := result.project) and (artifact_type := proj.artifact_type)):
            raise ValueError(f"Could not find artifact type {self.type}")

        self._attrs = artifact_type.model_dump(exclude_unset=True)
        return self._attrs

    @property
    def id(self) -> str:
        """The `id` entry of the loaded attributes."""
        return self._attrs["id"]

    @property
    def name(self) -> str:
        """The `name` entry of the loaded attributes."""
        return self._attrs["name"]

    @normalize_exceptions
    def collections(self, per_page: int = 50) -> ArtifactCollections:
        """Return a paginator over the artifact collections of this type.

        Args:
            per_page: Page size passed on to `ArtifactCollections`.
        """
        # `per_page` must be forwarded explicitly; otherwise the paginator
        # silently falls back to its own default page size.
        return ArtifactCollections(
            self.client,
            self.entity,
            self.project,
            self.type,
            per_page=per_page,
        )

    def collection(self, name: str) -> ArtifactCollection:
        """Return the artifact collection called `name` of this type."""
        return ArtifactCollection(
            self.client, self.entity, self.project, name, self.type
        )

    def __repr__(self) -> str:
        return f"<ArtifactType {self.type}>"
|
|
|
|
|
class ArtifactCollections(SizedPaginator["ArtifactCollection"]):
    """Paginated iterator over the artifact collections of one artifact type.

    Args:
        client: Client used to execute the GraphQL query.
        entity: Entity name, sent as the `entityName` query variable.
        project: Project name, sent as the `projectName` query variable.
        type_name: Artifact type name, sent as `artifactTypeName`.
        per_page: Page size, forwarded to the base paginator.
    """

    # Parsed connection of the most recently fetched page (None before that).
    last_response: ArtifactCollectionsFragment | None

    def __init__(
        self,
        client: Client,
        entity: str,
        project: str,
        type_name: str,
        per_page: int = 50,
    ):
        self.entity = entity
        self.project = project
        self.type_name = type_name

        # When the server check fails, the query is rewritten to use the
        # `artifactSequences` field in place of `artifactCollections`.
        use_legacy_field = not server_supports_artifact_collections_gql_edges(client)
        self.QUERY = gql_compat(
            PROJECT_ARTIFACT_COLLECTIONS_GQL,
            rename_fields=(
                {"artifactCollections": "artifactSequences"}
                if use_legacy_field
                else None
            ),
        )

        super().__init__(
            client,
            {
                "entityName": entity,
                "projectName": project,
                "artifactTypeName": type_name,
            },
            per_page,
        )

    @override
    def _update_response(self) -> None:
        """Execute the query with the current variables and store the page.

        Raises:
            ValueError: If the project, artifact type or collections
                connection is missing from the response.
        """
        raw = self.client.execute(self.QUERY, variable_values=self.variables)
        project = ProjectArtifactCollections.model_validate(raw).project

        # Walk down one level at a time, stopping at the first missing link.
        artifact_type = project.artifact_type if project else None
        connection = artifact_type.artifact_collections if artifact_type else None
        if not connection:
            raise ValueError(f"Unable to parse {type(self).__name__!r} response data")

        self.last_response = ArtifactCollectionsFragment.model_validate(connection)

    @property
    def length(self) -> int | None:
        """Total count reported by the last page, or `None` if no page yet."""
        response = self.last_response
        return None if response is None else response.total_count

    @property
    def more(self) -> bool:
        """Whether another page may exist (`True` until a page is fetched)."""
        response = self.last_response
        return True if response is None else response.page_info.has_next_page

    @property
    def cursor(self) -> str | None:
        """Cursor of the last edge on the current page, or `None` if no page yet."""
        response = self.last_response
        return None if response is None else response.edges[-1].cursor

    def update_variables(self) -> None:
        """Point the query variables at the position after the current page."""
        self.variables["cursor"] = self.cursor

    def convert_objects(self) -> list[ArtifactCollection]:
        """Wrap every node of the current page in an `ArtifactCollection`."""
        response = self.last_response
        if response is None:
            return []

        collections = []
        for edge in response.edges:
            node = edge.node
            # Edges without a node carry nothing to convert.
            if not node:
                continue
            collections.append(
                ArtifactCollection(
                    client=self.client,
                    entity=self.entity,
                    project=self.project,
                    name=node.name,
                    type=self.type_name,
                )
            )
        return collections
|
|
|
|
|
class ArtifactCollection:
    """A named collection of artifact versions within a project.

    Changes made through the `name`, `description`, `type` and `tags`
    setters are kept locally and only sent to the server by `save()`.

    Args:
        client: Client used to execute GraphQL queries and mutations.
        entity: Name of the entity that owns the project.
        project: Name of the project the collection belongs to.
        name: Name of the collection; checked with `validate_artifact_name`.
        type: Name of the artifact type of the collection.
        organization: Stored on the instance as `organization`.
        attrs: Previously fetched attributes of the collection.
        is_sequence: Whether the collection is known to be a sequence.
            When either `attrs` or `is_sequence` is omitted, the collection
            is loaded from the server during construction.
    """

    def __init__(
        self,
        client: Client,
        entity: str,
        project: str,
        name: str,
        type: str,
        organization: str | None = None,
        attrs: Mapping[str, Any] | None = None,
        is_sequence: bool | None = None,
    ):
        self.client = client
        self.entity = entity
        self.project = project
        # `_name` / `_type` hold pending edits, while `_saved_name` /
        # `_saved_type` are the values used when talking to the server.
        self._name = validate_artifact_name(name)
        self._saved_name = name
        self._type = type
        self._saved_type = type
        self._attrs = attrs
        if is_sequence is not None:
            self._is_sequence = is_sequence
        if (attrs is None) or (is_sequence is None):
            self.load()
        self._aliases = [a["node"]["alias"] for a in self._attrs["aliases"]["edges"]]
        self._description = self._attrs["description"]
        self._created_at = self._attrs["createdAt"]
        self._tags = [a["node"]["name"] for a in self._attrs["tags"]["edges"]]
        # Snapshot used by `save()` to work out which tags to add or delete.
        self._saved_tags = copy(self._tags)
        self.organization = organization

    @property
    def id(self) -> str:
        """The `id` entry of the loaded attributes."""
        return self._attrs["id"]

    @normalize_exceptions
    def artifacts(self, per_page: int = 50) -> Artifacts:
        """Return a paginator over the artifacts of this collection.

        Args:
            per_page: Page size passed on to `Artifacts`.
        """
        return Artifacts(
            client=self.client,
            entity=self.entity,
            project=self.project,
            collection_name=self._saved_name,
            type=self._saved_type,
            per_page=per_page,
        )

    @property
    def aliases(self) -> list[str]:
        """Artifact Collection Aliases."""
        return self._aliases

    @property
    def created_at(self) -> str:
        """The `createdAt` entry of the loaded attributes."""
        return self._created_at

    def load(self):
        """Fetch the collection from the server.

        Sets `_is_sequence` from the response and, when no attributes were
        supplied to the constructor, stores the fetched attributes.

        Returns:
            The attribute mapping held by the instance.

        Raises:
            ValueError: If the project, artifact type or collection is
                missing from the response.
        """
        # When the server check fails, the query is rewritten to use the
        # `artifactSequence` field in place of `artifactCollection`.
        if server_supports_artifact_collections_gql_edges(self.client):
            rename_fields = None
        else:
            rename_fields = {"artifactCollection": "artifactSequence"}

        response = self.client.execute(
            gql_compat(PROJECT_ARTIFACT_COLLECTION_GQL, rename_fields=rename_fields),
            variable_values={
                "entityName": self.entity,
                "projectName": self.project,
                "artifactTypeName": self._saved_type,
                "artifactCollectionName": self._saved_name,
            },
        )

        result = ProjectArtifactCollection.model_validate(response)

        if not (
            (proj := result.project)
            and (type_ := proj.artifact_type)
            and (collection := type_.artifact_collection)
        ):
            # NOTE(review): the message names the type even when only the
            # collection is missing; kept unchanged for compatibility.
            raise ValueError(f"Could not find artifact type {self._saved_type}")

        sequence = type_.artifact_sequence
        self._is_sequence = (
            sequence is not None
        ) and sequence.typename__ == SOURCE_ARTIFACT_COLLECTION_TYPE

        if self._attrs is None:
            self._attrs = collection.model_dump(exclude_unset=True)
        return self._attrs

    @normalize_exceptions
    def change_type(self, new_type: str) -> None:
        """Deprecated, change type directly with `save` instead."""
        deprecate.deprecate(
            field_name=Deprecated.artifact_collection__change_type,
            warning_message="ArtifactCollection.change_type(type) is deprecated, use ArtifactCollection.save() instead.",
        )

        if self._saved_type != new_type:
            # The current type must itself pass validation before it may be
            # replaced; a failure here is reported as an internal type.
            try:
                validate_artifact_type(self._saved_type, self.name)
            except ValueError as e:
                raise ValueError(
                    f"The current type {self._saved_type!r} is an internal type and cannot be changed."
                ) from e

        validate_artifact_type(new_type, self.name)

        if not self.is_sequence():
            raise ValueError("Artifact collection needs to be a sequence")
        termlog(
            f"Changing artifact collection type of {self._saved_type} to {new_type}"
        )
        self.client.execute(
            gql(MOVE_ARTIFACT_COLLECTION_GQL),
            variable_values={
                "artifactSequenceID": self.id,
                "destinationArtifactTypeName": new_type,
            },
        )
        self._saved_type = new_type
        self._type = new_type

    def is_sequence(self) -> bool:
        """Return whether the artifact collection is a sequence."""
        return self._is_sequence

    @normalize_exceptions
    def delete(self) -> None:
        """Delete the entire artifact collection."""
        self.client.execute(
            gql(
                DELETE_ARTIFACT_SEQUENCE_GQL
                if self.is_sequence()
                else DELETE_ARTIFACT_PORTFOLIO_GQL
            ),
            variable_values={"id": self.id},
        )

    @property
    def description(self) -> str | None:
        """A description of the artifact collection."""
        return self._description

    @description.setter
    def description(self, description: str | None) -> None:
        self._description = description

    @property
    def tags(self) -> list[str]:
        """The tags associated with the artifact collection."""
        return self._tags

    @tags.setter
    def tags(self, tags: list[str]) -> None:
        # `fullmatch` is required here: with `re.match`, a trailing `$` also
        # matches just before a final newline, so "tag\n" would be accepted.
        if any(re.fullmatch(r"[-\w]+([ ]+[-\w]+)*", tag) is None for tag in tags):
            raise ValueError(
                "Tags must only contain alphanumeric characters or underscores separated by spaces or hyphens"
            )
        self._tags = tags

    @property
    def name(self) -> str:
        """The name of the artifact collection."""
        return self._name

    @name.setter
    def name(self, name: str) -> None:
        self._name = validate_artifact_name(name)

    @property
    def type(self):
        """The type of the artifact collection."""
        return self._type

    @type.setter
    def type(self, type: str) -> None:
        if not self.is_sequence():
            raise ValueError(
                "Type can only be changed if the artifact collection is a sequence."
            )
        self._type = type

    def _update_collection(self) -> None:
        """Send the pending name and description to the server."""
        self.client.execute(
            gql(
                UPDATE_ARTIFACT_SEQUENCE_GQL
                if self.is_sequence()
                else UPDATE_ARTIFACT_PORTFOLIO_GQL
            ),
            variable_values={
                "id": self.id,
                "name": self.name,
                "description": self.description,
            },
        )
        self._saved_name = self._name

    def _update_collection_type(self) -> None:
        """Send the pending type to the server."""
        self.client.execute(
            gql(MOVE_ARTIFACT_COLLECTION_GQL),
            variable_values={
                "artifactSequenceID": self.id,
                "destinationArtifactTypeName": self.type,
            },
        )
        self._saved_type = self._type

    def _add_tags(self, tags_to_add: Iterable[str]) -> None:
        """Create tag assignments for `tags_to_add` on the server."""
        self.client.execute(
            gql(CREATE_ARTIFACT_COLLECTION_TAG_ASSIGNMENTS_GQL),
            variable_values={
                "entityName": self.entity,
                "projectName": self.project,
                "artifactCollectionName": self._saved_name,
                "tags": [{"tagName": tag} for tag in tags_to_add],
            },
        )

    def _delete_tags(self, tags_to_delete: Iterable[str]) -> None:
        """Delete the tag assignments for `tags_to_delete` on the server."""
        self.client.execute(
            gql(DELETE_ARTIFACT_COLLECTION_TAG_ASSIGNMENTS_GQL),
            variable_values={
                "entityName": self.entity,
                "projectName": self.project,
                "artifactCollectionName": self._saved_name,
                "tags": [{"tagName": tag} for tag in tags_to_delete],
            },
        )

    @normalize_exceptions
    def save(self) -> None:
        """Persist any changes made to the artifact collection."""
        if self._saved_type != self.type:
            # Both the requested type and the current type must validate
            # before a type change is attempted.
            try:
                validate_artifact_type(self.type, self._name)
            except ValueError as e:
                raise ValueError(f"Failed to save artifact collection: {e}") from e
            try:
                validate_artifact_type(self._saved_type, self._name)
            except ValueError as e:
                raise ValueError(
                    f"Failed to save artifact collection '{self._name}': "
                    f"The current type {self._saved_type!r} is an internal type and cannot be changed."
                ) from e

        self._update_collection()

        if self.is_sequence() and (self._saved_type != self._type):
            self._update_collection_type()

        # Only the difference to the last saved tag set is sent.
        current_tags = set(self._tags)
        saved_tags = set(self._saved_tags)
        if tags_to_add := (current_tags - saved_tags):
            self._add_tags(tags_to_add)
        if tags_to_delete := (saved_tags - current_tags):
            self._delete_tags(tags_to_delete)
        self._saved_tags = copy(self._tags)

    def __repr__(self) -> str:
        return f"<ArtifactCollection {self._name} ({self._type})>"
|
|
|
|
|
class Artifacts(SizedPaginator["Artifact"]):
    """An iterable collection of artifact versions associated with a project and optional filter.

    This is generally used indirectly via the `Api`.artifact_versions method.

    Args:
        client: Client used to execute the GraphQL query.
        entity: Entity name, sent as the `entity` query variable.
        project: Project name, sent as the `project` query variable.
        collection_name: Collection name, sent as `collection`.
        type: Artifact type name, sent as `type`.
        filters: Filter mapping, sent JSON-encoded as `filters`. Defaults to
            `{"state": "COMMITTED"}` when omitted.
        order: Sent as the `order` query variable.
        per_page: Page size, forwarded to the base paginator.
        tags: A tag or list of tags; only artifacts carrying all of them are
            kept when a page is converted.
    """

    # Parsed connection of the most recently fetched page (None before that).
    last_response: ArtifactsFragment | None

    def __init__(
        self,
        client: Client,
        entity: str,
        project: str,
        collection_name: str,
        type: str,
        filters: Mapping[str, Any] | None = None,
        order: str | None = None,
        per_page: int = 50,
        tags: str | list[str] | None = None,
    ):
        self.entity = entity
        self.collection_name = collection_name
        self.type = type
        self.project = project

        if filters is None:
            filters = {"state": "COMMITTED"}
        self.filters = filters

        # A single tag is normalised to a one-element list.
        if isinstance(tags, str):
            tags = [tags]
        self.tags = tags

        self.order = order

        query_variables = {
            "project": project,
            "entity": entity,
            "order": order,
            "type": type,
            "collection": collection_name,
            "filters": json.dumps(filters),
        }

        # When the server check fails, the query is rewritten to use the
        # `artifactSequence` field in place of `artifactCollection`.
        use_legacy_field = not server_supports_artifact_collections_gql_edges(client)
        self.QUERY = gql_compat(
            PROJECT_ARTIFACTS_GQL,
            omit_fields=omit_artifact_fields(api=InternalApi()),
            rename_fields=(
                {"artifactCollection": "artifactSequence"}
                if use_legacy_field
                else None
            ),
        )

        super().__init__(client, query_variables, per_page)

    @override
    def _update_response(self) -> None:
        """Execute the query with the current variables and store the page.

        Raises:
            ValueError: If the project, artifact type, collection or
                artifacts connection is missing from the response.
        """
        raw = self.client.execute(self.QUERY, variable_values=self.variables)
        project = ProjectArtifacts.model_validate(raw).project

        # Walk down one level at a time, stopping at the first missing link.
        artifact_type = project.artifact_type if project else None
        collection = artifact_type.artifact_collection if artifact_type else None
        connection = collection.artifacts if collection else None
        if not connection:
            raise ValueError(f"Unable to parse {type(self).__name__!r} response data")

        self.last_response = ArtifactsFragment.model_validate(connection)

    @property
    def length(self) -> int | None:
        """Total count reported by the last page, or `None` if no page yet."""
        response = self.last_response
        return None if response is None else response.total_count

    @property
    def more(self) -> bool:
        """Whether another page may exist (`True` until a page is fetched)."""
        response = self.last_response
        return True if response is None else response.page_info.has_next_page

    @property
    def cursor(self) -> str | None:
        """Cursor of the last edge on the current page, or `None` if no page yet."""
        response = self.last_response
        return None if response is None else response.edges[-1].cursor

    def convert_objects(self) -> list[Artifact]:
        """Build artifacts from the current page, keeping those with all required tags."""
        response = self.last_response
        if response is None:
            return []

        edges = response.edges
        required_tags = set(self.tags or [])

        matching = []
        for edge in edges:
            # Edges without a node carry nothing to convert.
            if not edge.node:
                continue
            artifact = wandb.Artifact._from_attrs(
                entity=self.entity,
                project=self.project,
                name=f"{self.collection_name}:{edge.version}",
                attrs=edge.node.model_dump(exclude_unset=True),
                client=self.client,
            )
            if required_tags.issubset(artifact.tags):
                matching.append(artifact)
        return matching
|
|
|
|
|
class RunArtifacts(SizedPaginator["Artifact"]):
    """Paginated iterator over the artifacts logged or used by a run.

    Args:
        client: Client used to execute the GraphQL query.
        run: Run whose `entity`, `project` and `id` are sent as the
            `entity`, `project` and `runName` query variables.
        mode: `"logged"` queries the run's `outputArtifacts`, `"used"` its
            `inputArtifacts`.
        per_page: Page size, forwarded to the base paginator.

    Raises:
        ValueError: If `mode` is neither `"logged"` nor `"used"`.
    """

    # Parsed connection of the most recently fetched page (None before that).
    last_response: (
        RunOutputArtifactsProjectRunOutputArtifacts
        | RunInputArtifactsProjectRunInputArtifacts
        | None
    )

    # Model class used to validate the connection selected by `mode`.
    _response_cls: type[
        RunOutputArtifactsProjectRunOutputArtifacts
        | RunInputArtifactsProjectRunInputArtifacts
    ]

    def __init__(
        self,
        client: Client,
        run: Run,
        mode: Literal["logged", "used"] = "logged",
        per_page: int = 50,
    ):
        self.run = run

        if mode == "logged":
            self.run_key = "outputArtifacts"
            self.QUERY = gql_compat(
                RUN_OUTPUT_ARTIFACTS_GQL,
                omit_fields=omit_artifact_fields(api=InternalApi()),
            )
            self._response_cls = RunOutputArtifactsProjectRunOutputArtifacts
        elif mode == "used":
            self.run_key = "inputArtifacts"
            self.QUERY = gql_compat(
                RUN_INPUT_ARTIFACTS_GQL,
                omit_fields=omit_artifact_fields(api=InternalApi()),
            )
            self._response_cls = RunInputArtifactsProjectRunInputArtifacts
        else:
            raise ValueError("mode must be logged or used")

        variable_values = {
            "entity": run.entity,
            "project": run.project,
            "runName": run.id,
        }
        super().__init__(client, variable_values, per_page)

    @override
    def _update_response(self) -> None:
        """Execute the query with the current variables and store the page.

        Raises:
            ValueError: If the project, run or artifacts connection cannot
                be located in the response.
        """
        data = self.client.execute(self.QUERY, variable_values=self.variables)

        # Only the inner connection is validated. A null or missing project
        # or run is reported as a ValueError, like the other paginators in
        # this module, rather than as a bare TypeError/KeyError.
        try:
            inner_data = data["project"]["run"][self.run_key]
        except (KeyError, TypeError) as e:
            raise ValueError(
                f"Unable to parse {type(self).__name__!r} response data"
            ) from e

        self.last_response = self._response_cls.model_validate(inner_data)

    @property
    def length(self) -> int | None:
        """Total count reported by the last page, or `None` if no page yet."""
        if self.last_response is None:
            return None
        return self.last_response.total_count

    @property
    def more(self) -> bool:
        """Whether another page may exist (`True` until a page is fetched)."""
        if self.last_response is None:
            return True
        return self.last_response.page_info.has_next_page

    @property
    def cursor(self) -> str | None:
        """Cursor of the last edge on the current page, or `None` if no page yet."""
        if self.last_response is None:
            return None
        return self.last_response.edges[-1].cursor

    def convert_objects(self) -> list[Artifact]:
        """Build artifacts from the current page.

        Edges lacking a node, an artifact sequence or a sequence project are
        skipped.
        """
        if self.last_response is None:
            return []

        return [
            wandb.Artifact._from_attrs(
                entity=proj.entity_name,
                project=proj.name,
                name=f"{artifact_seq.name}:v{node.version_index}",
                attrs=node.model_dump(exclude_unset=True),
                client=self.client,
            )
            for edge in self.last_response.edges
            if (node := edge.node)
            and (artifact_seq := node.artifact_sequence)
            and (proj := artifact_seq.project)
        ]
|
|
|
|
|
class ArtifactFiles(SizedPaginator["public.File"]):
    """Paginated iterator over the files of an artifact.

    Depending on a server feature check, files are queried either through
    the artifact's collection membership or through the artifact version.

    Args:
        client: Client used to execute the GraphQL query.
        artifact: Artifact whose files are listed.
        names: Sent as the `fileNames` query variable.
        per_page: Page size, forwarded to the base paginator and sent as
            `fileLimit` on every page request.
    """

    # Parsed connection of the most recently fetched page (None before that).
    last_response: FilesFragment | None

    def __init__(
        self,
        client: Client,
        artifact: Artifact,
        names: Sequence[str] | None = None,
        per_page: int = 50,
    ):
        self.query_via_membership = InternalApi()._server_supports(
            ServerFeature.ARTIFACT_COLLECTION_MEMBERSHIP_FILES
        )
        self.artifact = artifact

        if self.query_via_membership:
            query_str = ARTIFACT_COLLECTION_MEMBERSHIP_FILES_GQL
            variables = {
                "entityName": artifact.entity,
                "projectName": artifact.project,
                # Only the part of the artifact name before the first ":" is sent.
                "artifactName": artifact.name.split(":")[0],
                "artifactVersionIndex": artifact.version,
                "fileNames": names,
            }
        else:
            query_str = ARTIFACT_VERSION_FILES_GQL
            variables = {
                "entityName": artifact.source_entity,
                "projectName": artifact.source_project,
                "artifactName": artifact.source_name,
                "artifactTypeName": artifact.type,
                "fileNames": names,
            }

        # The `storagePath` field is dropped from the query when the client
        # reports that version 0.12.21 is not supported.
        if not client.version_supported("0.12.21"):
            self.QUERY = gql_compat(query_str, omit_fields={"storagePath"})
        else:
            self.QUERY = gql(query_str)

        super().__init__(client, variables, per_page)

    @staticmethod
    def _resolve(obj: Any, *attrs: str) -> Any:
        """Follow `attrs` from `obj`, returning `None` at the first `None` link."""
        for attr in attrs:
            if obj is None:
                return None
            obj = getattr(obj, attr)
        return obj

    @override
    def _update_response(self) -> None:
        """Execute the query with the current variables and store the page.

        Raises:
            ValueError: If any link on the way to the files connection is
                missing from the response.
        """
        data = self.client.execute(self.QUERY, variable_values=self.variables)

        # Walk the response defensively: a null project, collection,
        # membership, type or artifact must end in the ValueError below, not
        # in an AttributeError raised by chained attribute access.
        if self.query_via_membership:
            result = ArtifactCollectionMembershipFiles.model_validate(data)
            conn = self._resolve(
                result,
                "project",
                "artifact_collection",
                "artifact_membership",
                "files",
            )
        else:
            result = ArtifactVersionFiles.model_validate(data)
            conn = self._resolve(
                result, "project", "artifact_type", "artifact", "files"
            )

        if conn is None:
            raise ValueError(f"Unable to parse {type(self).__name__!r} response data")

        self.last_response = FilesFragment.model_validate(conn)

    @property
    def path(self) -> list[str]:
        """The artifact's entity, project and name, in that order."""
        return [self.artifact.entity, self.artifact.project, self.artifact.name]

    @property
    def length(self) -> int:
        """The `file_count` of the artifact."""
        return self.artifact.file_count

    @property
    def more(self) -> bool:
        """Whether another page may exist (`True` until a page is fetched)."""
        if self.last_response is None:
            return True
        return self.last_response.page_info.has_next_page

    @property
    def cursor(self) -> str | None:
        """Cursor of the last edge on the current page, or `None` if no page yet."""
        if self.last_response is None:
            return None
        return self.last_response.edges[-1].cursor

    def update_variables(self) -> None:
        """Set the `fileLimit` and `fileCursor` variables for the next page."""
        self.variables.update({"fileLimit": self.per_page, "fileCursor": self.cursor})

    def convert_objects(self) -> list[public.File]:
        """Wrap every node of the current page in a `public.File`."""
        if self.last_response is None:
            return []

        return [
            public.File(
                client=self.client,
                attrs=node.model_dump(exclude_unset=True),
            )
            for edge in self.last_response.edges
            if (node := edge.node)
        ]

    def __repr__(self) -> str:
        path_str = "/".join(self.path)
        return f"<ArtifactFiles {path_str} ({len(self)})>"
|
|
|
|
|
def server_supports_artifact_collections_gql_edges(
    client: RetryingClient, warn: bool = False
) -> bool:
    """Report whether `client` supports version "0.12.11".

    The result of `client.version_supported("0.12.11")` is returned as-is.
    Callers in this module use it to decide between the `artifactCollection`
    GraphQL fields and the legacy `artifactSequence` ones.

    Args:
        client: Client whose `version_supported` method is consulted.
        warn: When true and the version is not supported, emit a terminal
            warning about the fallback.

    Returns:
        Whatever `client.version_supported("0.12.11")` returned.
    """
    # NOTE(review): the check uses "0.12.11" while the warning text mentions
    # server version 0.9.50 - presumably two different version schemes; confirm.
    if (supported := client.version_supported("0.12.11")) or not warn:
        return supported

    wandb.termwarn(
        "W&B Local Server version does not support ArtifactCollection gql edges; falling back to using legacy ArtifactSequence. Please update server to at least version 0.9.50."
    )
    return supported
|
|