# File size: 3,622 Bytes
# 9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from types import SimpleNamespace
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    Iterator,
    List,
    Optional,
    TypeVar,
    Union,
)

from . import cache, filters, shardlists
from .filters import reraise_exception
from .pipeline import DataPipeline

T = TypeVar('T')
Sample = Dict[str, Any]
Handler = Callable[[Exception], bool]

class FluidInterface:
    """Chainable pipeline-building methods shared by WebDataset, FluidWrapper and WebLoader.

    Every builder method is annotated ``(self: T, ...) -> T`` (the PEP 484
    self-type pattern) instead of ``-> 'FluidInterface'``. A chain started on
    a concrete pipeline, e.g. ``WebDataset(...).shuffle(100).decode()``,
    therefore keeps its concrete type for type checkers. Under the old
    annotation the result was a bare FluidInterface, which declares no
    ``__iter__`` and none of the DataPipeline methods.

    NOTE(review): this assumes the runtime implementations return an instance
    of the receiver's class (e.g. a copy produced via ``compose``) -- confirm
    against the implementation module.
    """

    def batched(
        self: T,
        batchsize: int,
        collation_fn: Optional[Callable] = filters.default_collation_fn,
        partial: bool = True,
    ) -> T: ...
    def unbatched(self: T) -> T: ...
    def listed(self: T, batchsize: int, partial: bool = True) -> T: ...
    def unlisted(self: T) -> T: ...
    def log_keys(self: T, logfile: Optional[str] = None) -> T: ...
    def shuffle(self: T, size: int, **kw: Any) -> T: ...
    def map(self: T, f: Callable, handler: Handler = reraise_exception) -> T: ...
    def decode(
        self: T,
        *args: Union[str, Callable],
        pre: Optional[List[Callable]] = None,
        post: Optional[List[Callable]] = None,
        only: Optional[List[str]] = None,
        partial: bool = False,
        handler: Handler = reraise_exception,
    ) -> T: ...
    def map_dict(self: T, handler: Handler = reraise_exception, **kw: Callable) -> T: ...
    def select(self: T, predicate: Callable[[Any], bool], **kw: Any) -> T: ...
    def to_tuple(self: T, *args: str, **kw: Any) -> T: ...
    def map_tuple(self: T, *args: Callable, handler: Handler = reraise_exception) -> T: ...
    def slice(self: T, *args: int) -> T: ...
    def rename(self: T, **kw: str) -> T: ...
    def rsample(self: T, p: float = 0.5) -> T: ...
    def rename_keys(self: T, *args: Any, **kw: Any) -> T: ...
    def extract_keys(self: T, *args: str, **kw: Any) -> T: ...
    def xdecode(self: T, *args: Any, **kw: Any) -> T: ...
    def mcached(self: T) -> T: ...
    def lmdb_cached(self: T, *args: Any, **kw: Any) -> T: ...

    # Kept loosely typed (``Any`` in and out) exactly as before; only the
    # builder methods above are self-typed.
    def compose(self, other: Any) -> Any: ...

def check_empty(
    source: Iterable[Sample],
) -> Iterator[Sample]:
    """Stub signature: accepts an iterable of ``Sample`` dicts (``Dict[str, Any]``)
    and is declared to return an iterator of them. No body is provided here.
    """

class WebDataset(DataPipeline, FluidInterface):
    """Typed signature of the WebDataset pipeline class.

    Combines DataPipeline with the chainable FluidInterface methods. Only
    signatures are declared here; none of the members below has a body.
    """

    # NOTE(review): presumably populated from the ``seed`` constructor
    # argument, yet ``seed`` in ``__init__`` is Optional[int] while this
    # attribute is declared as plain int -- confirm how None is resolved.
    seed: int

    def __init__(
        self,
        urls: Union[str, Dict[str, Any], Iterable[str]],
        # Uses the module's ``Handler`` alias (Callable[[Exception], bool]),
        # matching the ``handler`` parameters of the FluidInterface methods.
        handler: Handler = reraise_exception,
        mode: Optional[str] = None,
        resampled: bool = False,
        repeat: bool = False,
        shardshuffle: Optional[Union[bool, int]] = None,
        cache_size: int = -1,
        cache_dir: Optional[str] = None,
        url_to_name: Any = cache.pipe_cleaner,
        detshuffle: bool = False,
        nodesplitter: Optional[Callable] = shardlists.single_node_only,
        workersplitter: Optional[Callable] = shardlists.split_by_worker,
        select_files: Optional[Callable[[str], bool]] = None,
        rename_files: Optional[Callable[[str], str]] = None,
        empty_check: bool = True,
        verbose: bool = False,
        seed: Optional[int] = None,
    ) -> None: ...

    def update_cache_info(self, args: SimpleNamespace) -> None: ...
    def create_url_iterator(self, args: SimpleNamespace) -> None: ...

    # Context-manager protocol.
    def __enter__(self) -> 'WebDataset': ...
    def __exit__(self, *args: Any) -> None: ...

class FluidWrapper(DataPipeline, FluidInterface):
    """Stub combining DataPipeline with the FluidInterface chaining methods.

    Only the constructor is declared; it takes a single ``initial`` argument
    whose type is left unconstrained.
    """

    def __init__(self, initial: Any) -> None:
        """Constructor stub; ``initial`` is typed as ``Any``."""

class WebLoader(DataPipeline, FluidInterface):
    """Stub combining DataPipeline with the FluidInterface chaining methods.

    Only the constructor is declared; it accepts arbitrary positional and
    keyword arguments, all typed as ``Any``.
    """

    def __init__(self, *args: Any, **kw: Any) -> None:
        """Constructor stub; arguments are unconstrained at the type level."""