File size: 8,196 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
import glob
import json
import os
import shutil
import time
from pathlib import Path
from typing import Optional, Union

import pyarrow as pa

import datasets
import datasets.config
import datasets.data_files
from datasets.naming import camelcase_to_snakecase, filenames_for_dataset_split


# Module-level logger, obtained through the `datasets` logging utilities and named after this module.
logger = datasets.utils.logging.get_logger(__name__)


def _get_modification_time(cached_directory_path):
    """Return the last-modification timestamp (``st_mtime``) of ``cached_directory_path``."""
    stat_result = Path(cached_directory_path).stat()
    return stat_result.st_mtime


def _is_matching_cache_directory(cached_directory_path: str, has_extra_params: bool) -> bool:
    """Tell whether a globbed path is a usable ``<config_id>/<version>/<hash>`` cache directory.

    Args:
        cached_directory_path: Path returned by the glob over the dataset cache root.
        has_extra_params: True when config kwargs or custom features were supplied; in that
            case the comparison against ``dataset_info.json`` is skipped.

    Returns:
        ``True`` if the path is a directory and, when no extra parameters were supplied, its
        ``dataset_info.json`` declares a ``config_name`` equal to the config folder name.
        A directory whose ``dataset_info.json`` is missing, unreadable or malformed is
        reported as non-matching instead of raising.
    """
    if not os.path.isdir(cached_directory_path):
        return False
    if has_extra_params:
        return True
    try:
        dataset_info = json.loads(Path(cached_directory_path, "dataset_info.json").read_text(encoding="utf-8"))
        # no extra params => config_id == config_name
        return dataset_info["config_name"] == Path(cached_directory_path).parts[-3]
    except (OSError, ValueError, KeyError, TypeError):
        # e.g. a directory without a (valid) dataset_info.json: it cannot be matched, so skip it
        return False


def _find_hash_in_cache(
    dataset_name: str,
    config_name: Optional[str],
    cache_dir: Optional[str],
    config_kwargs: dict,
    custom_features: Optional[datasets.Features],
) -> tuple[str, str, str]:
    """Locate the most recently modified cached version of a dataset.

    The cache is searched under ``<cache_dir>/<namespace>___<dataset_name>/<config_id>/<version>/<hash>``,
    where the last component of ``dataset_name`` is converted to snake case.

    Args:
        dataset_name: Dataset name, optionally prefixed by a namespace (``"namespace/name"``).
        config_name: Requested configuration name, or ``None`` to accept any cached configuration.
        cache_dir: Cache root; falls back to ``datasets.config.HF_DATASETS_CACHE`` when falsy.
        config_kwargs: Extra configuration parameters used to compute the config id.
        custom_features: Custom features used to compute the config id.

    Returns:
        A ``(config_name, version, hash)`` tuple taken from the path components of the most
        recently modified matching cache directory.

    Raises:
        ValueError: If no matching cache directory exists, or if no configuration was requested
            and several configurations share the selected version and hash.
    """
    if config_name or config_kwargs or custom_features:
        config_id = datasets.BuilderConfig(config_name or "default").create_config_id(
            config_kwargs=config_kwargs, custom_features=custom_features
        )
    else:
        config_id = None
    has_extra_params = bool(config_kwargs or custom_features)
    cache_dir = os.path.expanduser(str(cache_dir or datasets.config.HF_DATASETS_CACHE))
    namespace_and_dataset_name = dataset_name.split("/")
    namespace_and_dataset_name[-1] = camelcase_to_snakecase(namespace_and_dataset_name[-1])
    cached_relative_path = "___".join(namespace_and_dataset_name)
    cached_datasets_directory_path_root = os.path.join(cache_dir, cached_relative_path)
    # Literal path components are escaped so that glob metacharacters they may contain
    # ("[", "*", "?") are matched verbatim; only the explicit "*" segments act as wildcards.
    escaped_root = glob.escape(cached_datasets_directory_path_root)
    cached_directory_paths = [
        cached_directory_path
        for cached_directory_path in glob.glob(
            os.path.join(escaped_root, glob.escape(config_id) if config_id else "*", "*", "*")
        )
        if _is_matching_cache_directory(cached_directory_path, has_extra_params)
    ]
    if not cached_directory_paths:
        # Nothing matched: list every cached configuration to make the error actionable.
        cached_directory_paths = [
            cached_directory_path
            for cached_directory_path in glob.glob(os.path.join(escaped_root, "*", "*", "*"))
            if os.path.isdir(cached_directory_path)
        ]
        available_configs = sorted(
            {Path(cached_directory_path).parts[-3] for cached_directory_path in cached_directory_paths}
        )
        raise ValueError(
            f"Couldn't find cache for {dataset_name}"
            + (f" for config '{config_id}'" if config_id else "")
            + (f"\nAvailable configs in the cache: {available_configs}" if available_configs else "")
        )
    # get most recent
    cached_directory_path = Path(sorted(cached_directory_paths, key=_get_modification_time)[-1])
    version, dataset_hash = cached_directory_path.parts[-2:]
    # Every configuration that was cached with the same version and hash as the selected directory.
    other_configs = [
        Path(_cached_directory_path).parts[-3]
        for _cached_directory_path in glob.glob(
            os.path.join(escaped_root, "*", glob.escape(version), glob.escape(dataset_hash))
        )
        if _is_matching_cache_directory(_cached_directory_path, has_extra_params)
    ]
    if not config_id and len(other_configs) > 1:
        raise ValueError(
            f"There are multiple '{dataset_name}' configurations in the cache: {', '.join(other_configs)}"
            f"\nPlease specify which configuration to reload from the cache, e.g."
            f"\n\tload_dataset('{dataset_name}', '{other_configs[0]}')"
        )
    config_name = cached_directory_path.parts[-3]
    warning_msg = (
        f"Found the latest cached dataset configuration '{config_name}' at {cached_directory_path} "
        f"(last modified on {time.ctime(_get_modification_time(cached_directory_path))})."
    )
    logger.warning(warning_msg)
    return config_name, version, dataset_hash


class Cache(datasets.ArrowBasedBuilder):
    """Arrow-based builder that serves a dataset from an existing cache directory."""

    def __init__(
        self,
        cache_dir: Optional[str] = None,
        dataset_name: Optional[str] = None,
        config_name: Optional[str] = None,
        version: Optional[str] = "0.0.0",
        hash: Optional[str] = None,
        base_path: Optional[str] = None,
        info: Optional[datasets.DatasetInfo] = None,
        features: Optional[datasets.Features] = None,
        token: Optional[Union[bool, str]] = None,
        repo_id: Optional[str] = None,
        data_files: Optional[Union[str, list, dict, datasets.data_files.DataFilesDict]] = None,
        data_dir: Optional[str] = None,
        storage_options: Optional[dict] = None,
        writer_batch_size: Optional[int] = None,
        **config_kwargs,
    ):
        """Set up the builder, optionally resolving the cached config, version and hash.

        When both ``hash`` and ``version`` are ``"auto"``, they (and ``config_name``) are looked up
        in the cache with ``_find_hash_in_cache``. ``features``, ``data_files``, ``data_dir`` and the
        remaining ``config_kwargs`` are only used for that lookup; they are not forwarded to the
        parent constructor.

        Raises:
            ValueError: If neither ``repo_id`` nor ``dataset_name`` is given.
            NotImplementedError: If only one of ``hash`` / ``version`` is ``"auto"``.
        """
        if dataset_name is None and repo_id is None:
            raise ValueError("repo_id or dataset_name is required for the Cache dataset builder")

        # Fold the explicitly named data options into the kwargs used for the cache lookup.
        for option_name, option_value in (("data_files", data_files), ("data_dir", data_dir)):
            if option_value is not None:
                config_kwargs[option_name] = option_value

        auto_hash = hash == "auto"
        auto_version = version == "auto"
        if auto_hash and auto_version:
            resolved = _find_hash_in_cache(
                dataset_name=repo_id or dataset_name,
                config_name=config_name,
                cache_dir=cache_dir,
                config_kwargs=config_kwargs,
                custom_features=features,
            )
            config_name, version, hash = resolved
        elif auto_hash or auto_version:
            raise NotImplementedError("Pass both hash='auto' and version='auto' instead")

        super().__init__(
            cache_dir=cache_dir,
            dataset_name=dataset_name,
            config_name=config_name,
            version=version,
            hash=hash,
            base_path=base_path,
            info=info,
            token=token,
            repo_id=repo_id,
            storage_options=storage_options,
            writer_batch_size=writer_batch_size,
        )

    def _info(self) -> datasets.DatasetInfo:
        """Return an empty ``DatasetInfo``."""
        empty_info = datasets.DatasetInfo()
        return empty_info

    def download_and_prepare(self, output_dir: Optional[str] = None, *args, **kwargs):
        """Check that the cache directory exists and optionally copy it to ``output_dir``.

        Extra positional and keyword arguments are accepted and ignored.

        Raises:
            ValueError: If the cache directory does not exist.
        """
        source_dir = self.cache_dir
        if not os.path.exists(source_dir):
            raise ValueError(f"Cache directory for {self.dataset_name} doesn't exist at {self.cache_dir}")
        if output_dir is None or output_dir == source_dir:
            # Nothing to copy: the data is used in place.
            return
        shutil.copytree(source_dir, output_dir)

    def _split_generators(self, dl_manager):
        """Build one ``SplitGenerator`` per split recorded in ``self.info.splits``.

        Used to stream from cache. ``dl_manager`` is not used.

        Raises:
            ValueError: If ``self.info.splits`` is not a ``datasets.SplitDict``.
        """
        splits = self.info.splits
        if not isinstance(splits, datasets.SplitDict):
            raise ValueError(f"Missing splits info for {self.dataset_name} in cache directory {self.cache_dir}")

        generators = []
        for split_info in splits.values():
            arrow_files = filenames_for_dataset_split(
                self.cache_dir,
                dataset_name=self.dataset_name,
                split=split_info.name,
                filetype_suffix="arrow",
                shard_lengths=split_info.shard_lengths,
            )
            generators.append(datasets.SplitGenerator(name=split_info.name, gen_kwargs={"files": arrow_files}))
        return generators

    def _generate_tables(self, files):
        """Yield ``("<file_idx>_<batch_idx>", table)`` pairs, one per record batch of each file.

        Used to stream from cache. Each file is opened in binary mode and read as an Arrow IPC
        stream. A ``ValueError`` raised while reading is logged and re-raised.
        """
        for file_idx, arrow_path in enumerate(files):
            with open(arrow_path, "rb") as arrow_stream:
                try:
                    batch_reader = pa.ipc.open_stream(arrow_stream)
                    for batch_idx, batch in enumerate(batch_reader):
                        yield f"{file_idx}_{batch_idx}", pa.Table.from_batches([batch])
                except ValueError as err:
                    logger.error(f"Failed to read file '{arrow_path}' with error {type(err)}: {err}")
                    raise