File size: 1,909 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
from typing import Callable, Optional
from .. import Features, NamedSplit, Split
from ..packaged_modules.generator.generator import Generator
from .abc import AbstractDatasetInputStream
class GeneratorDatasetInputStream(AbstractDatasetInputStream):
    """Dataset input stream backed by a user-supplied Python generator.

    Wraps the packaged ``Generator`` builder: ``__init__`` configures it and
    ``read`` materializes either a streaming dataset or a prepared
    (map-style) dataset, depending on ``streaming``.
    """

    def __init__(
        self,
        generator: Callable,
        features: Optional[Features] = None,
        cache_dir: Optional[str] = None,
        keep_in_memory: bool = False,
        streaming: bool = False,
        gen_kwargs: Optional[dict] = None,
        num_proc: Optional[int] = None,
        split: NamedSplit = Split.TRAIN,
        **kwargs,
    ):
        """Configure the stream and its underlying ``Generator`` builder.

        Args:
            generator: Callable handed to the ``Generator`` builder as the
                example source.
            features: Optional dataset features, forwarded to both the base
                class and the builder.
            cache_dir: Optional cache directory, forwarded to both the base
                class and the builder.
            keep_in_memory: Forwarded to the base class; ``read`` passes it as
                ``in_memory`` when building a map-style dataset.
            streaming: Forwarded to the base class; ``read`` uses it to choose
                between a streaming and a prepared dataset.
            gen_kwargs: Optional keyword arguments for the generator, handed
                to the builder.
            num_proc: Forwarded to the base class; ``read`` passes it to
                ``download_and_prepare``.
            split: Split name the builder is configured with; ``read`` later
                reads it back from ``self.builder.config.split``.
            **kwargs: Extra options, forwarded to BOTH the base class and the
                builder.
        """
        super().__init__(
            features=features,
            cache_dir=cache_dir,
            keep_in_memory=keep_in_memory,
            streaming=streaming,
            num_proc=num_proc,
            **kwargs,
        )
        self.builder = Generator(
            cache_dir=cache_dir,
            features=features,
            generator=generator,
            gen_kwargs=gen_kwargs,
            split=split,
            **kwargs,
        )

    def read(self):
        """Build and return the dataset for the builder's configured split.

        ``self.streaming``, ``self.num_proc`` and ``self.keep_in_memory`` are
        not assigned in this class. They are presumably stored by
        ``AbstractDatasetInputStream.__init__`` (confirm there).

        Returns:
            If ``self.streaming`` is true, the result of
            ``self.builder.as_streaming_dataset`` (an iterable dataset).
            Otherwise, the result of ``self.builder.as_dataset`` (a map-style
            dataset), built after ``download_and_prepare`` has run.
        """
        # Build iterable dataset: nothing is prepared on disk up front.
        if self.streaming:
            return self.builder.as_streaming_dataset(split=self.builder.config.split)

        # Build regular (map-style) dataset. Every option except num_proc is
        # deliberately left at None so the builder applies its own defaults.
        self.builder.download_and_prepare(
            download_config=None,
            download_mode=None,
            verification_mode=None,
            base_path=None,
            num_proc=self.num_proc,
        )
        return self.builder.as_dataset(
            split=self.builder.config.split,
            verification_mode=None,
            in_memory=self.keep_in_memory,
        )
|