File size: 2,289 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""progress."""

import os
from typing import IO, TYPE_CHECKING, Optional

from wandb.errors import CommError

if TYPE_CHECKING:
    from typing import Protocol

    class ProgressFn(Protocol):
        def __call__(self, new_bytes: int, total_bytes: int) -> None:
            pass


class Progress:
    """A helper class for displaying progress."""

    ITER_BYTES = 1024 * 1024

    def __init__(
        self, file: IO[bytes], callback: Optional["ProgressFn"] = None
    ) -> None:
        self.file = file
        if callback is None:

            def callback_(new_bytes: int, total_bytes: int) -> None:
                pass

            callback = callback_

        self.callback: ProgressFn = callback
        self.bytes_read = 0
        self.len = os.fstat(file.fileno()).st_size

    def read(self, size=-1):
        """Read bytes and call the callback."""
        bites = self.file.read(size)
        self.bytes_read += len(bites)
        if not bites and self.bytes_read < self.len:
            # Files shrinking during uploads causes request timeouts. Maybe
            # we could avoid those by updating the self.len in real-time, but
            # files getting truncated while uploading seems like something
            # that shouldn't really be happening anyway.
            raise CommError(
                f"File {self.file.name} size shrank from {self.len} to {self.bytes_read} while it was being uploaded."
            )
        # Growing files are also likely to be bad, but our code didn't break
        # on those in the past, so it's riskier to make that an error now.
        self.callback(len(bites), self.bytes_read)
        return bites

    def rewind(self) -> None:
        self.callback(-self.bytes_read, 0)
        self.bytes_read = 0
        self.file.seek(0)

    def __getattr__(self, name):
        """Fallback to the file object for attrs not defined here."""
        if hasattr(self.file, name):
            return getattr(self.file, name)
        else:
            raise AttributeError

    def __iter__(self):
        return self

    def __next__(self):
        bites = self.read(self.ITER_BYTES)
        if len(bites) == 0:
            raise StopIteration
        return bites

    def __len__(self):
        return self.len

    next = __next__