File size: 2,289 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
"""progress."""
import os
from typing import IO, TYPE_CHECKING, Optional
from wandb.errors import CommError
if TYPE_CHECKING:
    from typing import Protocol

    class ProgressFn(Protocol):
        """Structural type of a progress callback.

        The callback receives two integers: ``new_bytes``, the change in
        the byte count for this step, and ``total_bytes``, the running
        total after the step. It returns nothing.
        """

        def __call__(self, new_bytes: int, total_bytes: int) -> None:
            ...
class Progress:
    """Wrap a binary file so that every read reports progress to a callback.

    The wrapper records the file size once, at construction time, and
    counts the bytes handed out by :meth:`read`. Attributes that are not
    defined here are looked up on the wrapped file object.
    """

    # Chunk size requested for each step when the wrapper is iterated.
    ITER_BYTES = 1024 * 1024

    def __init__(
        self, file: IO[bytes], callback: Optional["ProgressFn"] = None
    ) -> None:
        """Initialize the wrapper.

        Args:
            file: Open binary file. It must expose ``fileno()`` because the
                size is taken from ``os.fstat``.
            callback: Called as ``callback(new_bytes, total_bytes)`` after
                each read and on rewind. Defaults to a no-op.
        """
        self.file = file
        if callback is None:

            def callback_(new_bytes: int, total_bytes: int) -> None:
                pass

            callback = callback_

        self.callback: "ProgressFn" = callback
        self.bytes_read = 0
        # Size snapshot; read() compares against it to detect truncation.
        self.len = os.fstat(file.fileno()).st_size

    def read(self, size: Optional[int] = -1) -> bytes:
        """Read up to ``size`` bytes and report them to the callback.

        Args:
            size: Passed straight to the wrapped file's ``read``.

        Returns:
            The bytes read; empty at end of file or when ``size`` is 0.

        Raises:
            CommError: If the file ends before the size recorded at
                construction time has been read.
        """
        bites = self.file.read(size)
        self.bytes_read += len(bites)
        # A zero-length request always yields b"" and says nothing about
        # end of file, so it must not be mistaken for truncation.
        reached_eof = not bites and size != 0
        if reached_eof and self.bytes_read < self.len:
            # Files shrinking during uploads causes request timeouts. Maybe
            # we could avoid those by updating the self.len in real-time, but
            # files getting truncated while uploading seems like something
            # that shouldn't really be happening anyway.
            raise CommError(
                f"File {self.file.name} size shrank from {self.len} to {self.bytes_read} while it was being uploaded."
            )
        # Growing files are also likely to be bad, but our code didn't break
        # on those in the past, so it's riskier to make that an error now.
        self.callback(len(bites), self.bytes_read)
        return bites

    def rewind(self) -> None:
        """Seek back to the start and retract the progress reported so far."""
        # A negative delta tells the callback to undo the bytes already counted.
        self.callback(-self.bytes_read, 0)
        self.bytes_read = 0
        self.file.seek(0)

    def __getattr__(self, name: str):
        """Fall back to the wrapped file for attributes not defined here.

        Raises:
            AttributeError: If neither the wrapper nor the file has ``name``.
        """
        # __getattr__ only runs when normal lookup fails. If "file" itself is
        # missing (instance created without __init__), reading self.file
        # below would re-enter this method and recurse without end.
        if name == "file":
            raise AttributeError(name)
        try:
            return getattr(self.file, name)
        except AttributeError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None

    def __iter__(self):
        """Return self; iteration yields chunks of up to ``ITER_BYTES`` bytes."""
        return self

    def __next__(self) -> bytes:
        """Return the next chunk, raising ``StopIteration`` on an empty read."""
        bites = self.read(self.ITER_BYTES)
        if not bites:
            raise StopIteration
        return bites

    def __len__(self) -> int:
        """Return the file size recorded at construction time."""
        return self.len

    # Alias so that callers using the ``.next()`` method name keep working.
    next = __next__
|