File size: 8,497 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 |
"""Public API: files."""
import io
import os
from typing import Optional
import requests
from wandb_gql import gql
from wandb_gql.client import RetryError
import wandb
from wandb import util
from wandb.apis.attrs import Attrs
from wandb.apis.normalize import normalize_exceptions
from wandb.apis.paginator import SizedPaginator
from wandb.apis.public import utils
from wandb.apis.public.api import Api
from wandb.apis.public.const import RETRY_TIMEDELTA
from wandb.sdk.lib import retry
# GraphQL fragment on `Run` that selects one page of the run's files: per-file
# metadata for every edge, each edge's cursor, and the page's `pageInfo`.
# It is appended to `Files.QUERY` via `str.format`; the `$fileNames`,
# `$fileCursor`, `$fileLimit` and `$upload` variables it references are
# declared by that enclosing query, not by the fragment itself.
FILE_FRAGMENT = """fragment RunFilesFragment on Run {
files(names: $fileNames, after: $fileCursor, first: $fileLimit) {
edges {
node {
id
name
url(upload: $upload)
directUrl
sizeBytes
mimetype
updatedAt
md5
}
cursor
}
pageInfo {
endCursor
hasNextPage
}
}
}"""
class Files(SizedPaginator["File"]):
    """An iterable collection of `File` objects.

    Pages of files belonging to a single run are fetched with the `RunFiles`
    GraphQL query, and every returned edge is wrapped in a `File`.

    Args:
        client: GraphQL client passed to the paginator base class and to
            every `File` created from a page.
        run: The run whose files are listed. Its `project`, `entity` and
            `id` are sent as the `project`, `entity` and `name` variables.
        names: Optional list of file names, sent as `fileNames`. `None` is
            sent as an empty list.
        per_page: Page size, sent as `fileLimit`.
        upload: Sent as the `upload` argument of each file's `url` field.
    """

    # The doubled braces are `str.format` escapes; the single `{}` at the end
    # is where FILE_FRAGMENT is appended to the query document.
    QUERY = gql(
        """
query RunFiles($project: String!, $entity: String!, $name: String!, $fileCursor: String,
$fileLimit: Int = 50, $fileNames: [String] = [], $upload: Boolean = false) {{
project(name: $project, entityName: $entity) {{
internalId
run(name: $name) {{
fileCount
...RunFilesFragment
}}
}}
}}
{}
""".format(FILE_FRAGMENT)
    )

    def __init__(self, client, run, names=None, per_page=50, upload=False):
        self.run = run
        variables = {
            "project": run.project,
            "entity": run.entity,
            "name": run.id,
            "fileNames": names or [],
            "upload": upload,
        }
        super().__init__(client, variables, per_page)

    @property
    def length(self):
        """The run's `fileCount` from the last response, or `None` if no page was loaded yet."""
        if not self.last_response:
            return None
        return self.last_response["project"]["run"]["fileCount"]

    @property
    def more(self):
        """Whether another page is available.

        `True` before the first page is loaded, afterwards the `hasNextPage`
        flag of the last response.
        """
        if not self.last_response:
            return True
        return self.last_response["project"]["run"]["files"]["pageInfo"]["hasNextPage"]

    @property
    def cursor(self):
        """The cursor to continue after, or `None` if no page was loaded yet.

        This is the cursor of the last edge of the last response. When that
        response contains no edges, the page's `pageInfo.endCursor` (already
        requested by the query) is used instead of raising `IndexError`.
        """
        if not self.last_response:
            return None
        files = self.last_response["project"]["run"]["files"]
        edges = files["edges"]
        if edges:
            return edges[-1]["cursor"]
        # Empty page: there is no edge to read a cursor from.
        return files["pageInfo"].get("endCursor")

    def update_variables(self):
        """Write the current page size and cursor into the query variables."""
        self.variables.update({"fileLimit": self.per_page, "fileCursor": self.cursor})

    def convert_objects(self):
        """Wrap every edge node of the last response in a `File` bound to this run."""
        edges = self.last_response["project"]["run"]["files"]["edges"]
        return [File(self.client, edge["node"], self.run) for edge in edges]

    def __repr__(self):
        # NOTE: `len(self)` is delegated to the paginator base class.
        return "<Files {} ({})>".format("/".join(self.run.path), len(self))
class File(Attrs):
    """A file saved by wandb.

    Attributes:
        name (string): filename
        url (string): path to file
        direct_url (string): path to file in the bucket
        md5 (string): md5 of file
        mimetype (string): mimetype of file
        updated_at (string): timestamp of last update
        size (int): size of file in bytes
        path_uri (str): path to file in the bucket, currently only available for files stored in S3
    """

    def __init__(self, client, attrs, run=None):
        self.client = client
        self._attrs = attrs
        self.run = run
        # Result of the server capability probe used by `delete`; `None`
        # means the probe has not been run yet for this instance.
        self.server_supports_delete_file_with_project_id: Optional[bool] = None
        super().__init__(dict(attrs))

    @property
    def size(self):
        """Size of the file in bytes, or 0 when the server reported no size."""
        size_bytes = self._attrs["sizeBytes"]
        if size_bytes is None:
            return 0
        return int(size_bytes)

    @property
    def path_uri(self) -> str:
        """Return the URI of the file in the storage bucket.

        Returns an empty string, after emitting a warning, when the direct
        URL cannot be converted (`ValueError`) or is missing (`LookupError`).
        """
        try:
            return utils.parse_s3_url_to_s3_uri(self._attrs["directUrl"])
        except ValueError:
            wandb.termwarn("path_uri is only available for files stored in S3")
        except LookupError:
            wandb.termwarn("Unable to find direct_url of file")
        return ""

    @normalize_exceptions
    @retry.retriable(
        retry_timedelta=RETRY_TIMEDELTA,
        check_retry_fn=util.no_retry_auth,
        retryable_exceptions=(RetryError, requests.RequestException),
    )
    def download(
        self,
        root: str = ".",
        replace: bool = False,
        exist_ok: bool = False,
        api: Optional[Api] = None,
    ) -> io.TextIOWrapper:
        """Downloads a file previously saved by a run from the wandb server.

        Args:
            replace (boolean): If `True`, download will overwrite a local file
                if it exists. Defaults to `False`.
            root (str): Local directory to save the file. Defaults to ".".
            exist_ok (boolean): If `True`, will not raise ValueError if file already
                exists and will not re-download unless replace=True. Defaults to `False`.
            api (Api, optional): If given, the `Api` instance used to download the file.

        Returns:
            The local file opened with `open(path)`. The caller is
            responsible for closing the returned handle.

        Raises:
            `ValueError` if file already exists, replace=False and exist_ok=False.
        """
        if api is None:
            api = wandb.Api()

        path = os.path.join(root, self.name)
        if os.path.exists(path) and not replace:
            if not exist_ok:
                raise ValueError(
                    "File already exists, pass replace=True to overwrite or exist_ok=True to leave it as is and don't error."
                )
            # Keep the existing local copy untouched.
            return open(path)

        util.download_file_from_url(path, self.url, api.api_key)
        return open(path)

    @normalize_exceptions
    def delete(self):
        """Delete this file on the server through the `deleteFiles` mutation.

        `projectId` is included in the mutation only when this file is bound
        to a run and the server's `DeleteFilesInput` type accepts that field.
        Without a run there is no project id to send, so the field is omitted
        (and the capability probe is skipped) instead of failing on
        `self.run` being `None`.
        """
        project_id_mutation_fragment = ""
        project_id_variable_fragment = ""
        variable_values = {
            "files": [self.id],
        }

        # Add projectId to mutation and variables if the server supports it.
        # Otherwise, do not include projectId in mutation for older server versions which do not support it.
        if (
            self.run is not None
            and self._server_accepts_project_id_for_delete_file()
        ):
            variable_values["projectId"] = self.run._project_internal_id
            project_id_variable_fragment = ", $projectId: Int"
            project_id_mutation_fragment = "projectId: $projectId"

        # Doubled braces are `str.format` escapes; the two `{}` slots receive
        # the optional variable declaration and input field.
        mutation_string = """
mutation deleteFiles($files: [ID!]!{}) {{
deleteFiles(input: {{
files: $files
{}
}}) {{
success
}}
}}
""".format(project_id_variable_fragment, project_id_mutation_fragment)
        mutation = gql(mutation_string)

        self.client.execute(
            mutation,
            variable_values=variable_values,
        )

    def __repr__(self):
        return "<File {} ({}) {}>".format(
            self.name,
            self.mimetype,
            util.to_human_size(self.size, units=util.POW_2_BYTES),
        )

    @normalize_exceptions
    def _server_accepts_project_id_for_delete_file(self) -> bool:
        """Returns True if the server supports deleting files with a projectId.

        This check is done by utilizing GraphQL introspection in the available fields on the DeleteFiles API.
        The result is stored on the instance, so the introspection query is
        executed at most once per `File`. A missing or null type, a missing
        or null `inputFields` list, and entries without a `name` are all
        treated as "not supported".
        """
        # Only perform the query once to avoid extra network calls
        if self.server_supports_delete_file_with_project_id is None:
            query_string = """
query ProbeDeleteFilesProjectIdInput {
DeleteFilesProjectIdInputType: __type(name:"DeleteFilesInput") {
inputFields{
name
}
}
}
"""
            res = self.client.execute(gql(query_string)) or {}

            # `__type` is nullable, so the key can be present with a null
            # value; `or` covers both "absent" and "null".
            input_type = res.get("DeleteFilesProjectIdInputType") or {}
            input_fields = input_type.get("inputFields") or []

            # If projectId is in the inputFields, the server supports deleting files with a projectId
            self.server_supports_delete_file_with_project_id = any(
                (field or {}).get("name") == "projectId" for field in input_fields
            )

        return self.server_supports_delete_file_with_project_id
|