File size: 8,196 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 |
import glob
import json
import os
import shutil
import time
from pathlib import Path
from typing import Optional, Union
import pyarrow as pa
import datasets
import datasets.config
import datasets.data_files
from datasets.naming import camelcase_to_snakecase, filenames_for_dataset_split
# Module-level logger: used to warn which cached configuration was selected and to report
# Arrow files that fail to read.
logger = datasets.utils.logging.get_logger(__name__)
def _get_modification_time(cached_directory_path):
return (Path(cached_directory_path)).stat().st_mtime
def _list_matching_cache_dirs(
    pattern: str,
    config_kwargs: dict,
    custom_features: Optional[datasets.Features],
) -> list[str]:
    """Return the directories matched by glob *pattern* that belong to the requested config.

    When extra config parameters were passed (``config_kwargs`` or ``custom_features``),
    every matched directory is accepted. Otherwise the config id equals the config name, so a
    directory is kept only if the ``config_name`` stored in its ``dataset_info.json`` equals
    its ``<config>`` path component (third from the end).

    Directories whose ``dataset_info.json`` is missing or is not valid JSON (e.g. an
    interrupted write) are skipped instead of aborting the whole search.
    """
    matches = []
    for cached_directory_path in glob.glob(pattern):
        if not os.path.isdir(cached_directory_path):
            continue
        if config_kwargs or custom_features:
            matches.append(cached_directory_path)
            continue
        try:
            dataset_info = json.loads(
                Path(cached_directory_path, "dataset_info.json").read_text(encoding="utf-8")
            )
        except (FileNotFoundError, json.JSONDecodeError):
            continue
        # no extra params => config_id == config_name
        if dataset_info.get("config_name") == Path(cached_directory_path).parts[-3]:
            matches.append(cached_directory_path)
    return matches


def _find_hash_in_cache(
    dataset_name: str,
    config_name: Optional[str],
    cache_dir: Optional[str],
    config_kwargs: dict,
    custom_features: Optional[datasets.Features],
) -> tuple[str, str, str]:
    """Locate the most recently modified cached copy of a dataset configuration.

    Candidate directories are searched with the layout
    ``<cache_dir>/<namespace>___<snake_case_name>/<config_id>/<version>/<hash>``.

    Args:
        dataset_name: Dataset name, optionally ``"namespace/name"``.
        config_name: Requested config name, or ``None``/empty for any config.
        cache_dir: Cache root; defaults to ``datasets.config.HF_DATASETS_CACHE``.
        config_kwargs: Extra config parameters; folded into the config id.
        custom_features: Custom features; folded into the config id.

    Returns:
        ``(config_name, version, hash)`` taken from the selected cache directory.

    Raises:
        ValueError: if no matching cache directory exists, or if no config was requested
            and several configs share the selected version/hash.
    """
    if config_name or config_kwargs or custom_features:
        config_id = datasets.BuilderConfig(config_name or "default").create_config_id(
            config_kwargs=config_kwargs, custom_features=custom_features
        )
    else:
        config_id = None
    cache_dir = os.path.expanduser(str(cache_dir or datasets.config.HF_DATASETS_CACHE))
    namespace_and_dataset_name = dataset_name.split("/")
    namespace_and_dataset_name[-1] = camelcase_to_snakecase(namespace_and_dataset_name[-1])
    cached_relative_path = "___".join(namespace_and_dataset_name)
    cached_datasets_directory_path_root = os.path.join(cache_dir, cached_relative_path)
    # Escape literal path parts so that glob metacharacters ("[", "?", "*") in the cache
    # directory or config id are matched literally rather than as wildcards.
    escaped_root = glob.escape(cached_datasets_directory_path_root)
    config_pattern = glob.escape(config_id) if config_id else "*"
    cached_directory_paths = _list_matching_cache_dirs(
        os.path.join(escaped_root, config_pattern, "*", "*"),
        config_kwargs,
        custom_features,
    )
    if not cached_directory_paths:
        available_configs = sorted(
            {
                Path(candidate).parts[-3]
                for candidate in glob.glob(os.path.join(escaped_root, "*", "*", "*"))
                if os.path.isdir(candidate)
            }
        )
        raise ValueError(
            f"Couldn't find cache for {dataset_name}"
            + (f" for config '{config_id}'" if config_id else "")
            + (f"\nAvailable configs in the cache: {available_configs}" if available_configs else "")
        )
    # get most recent (stable sort: on mtime ties the last-globbed path wins, as before)
    cached_directory_path = Path(sorted(cached_directory_paths, key=_get_modification_time)[-1])
    version, config_hash = cached_directory_path.parts[-2:]
    other_configs = [
        Path(candidate).parts[-3]
        for candidate in _list_matching_cache_dirs(
            os.path.join(escaped_root, "*", glob.escape(version), glob.escape(config_hash)),
            config_kwargs,
            custom_features,
        )
    ]
    if not config_id and len(other_configs) > 1:
        raise ValueError(
            f"There are multiple '{dataset_name}' configurations in the cache: {', '.join(other_configs)}"
            f"\nPlease specify which configuration to reload from the cache, e.g."
            f"\n\tload_dataset('{dataset_name}', '{other_configs[0]}')"
        )
    config_name = cached_directory_path.parts[-3]
    warning_msg = (
        f"Found the latest cached dataset configuration '{config_name}' at {cached_directory_path} "
        f"(last modified on {time.ctime(_get_modification_time(cached_directory_path))})."
    )
    logger.warning(warning_msg)
    return config_name, version, config_hash
class Cache(datasets.ArrowBasedBuilder):
    """Arrow builder that reloads a dataset already prepared in a local cache directory.

    Passing ``version="auto"`` together with ``hash="auto"`` selects the most recently
    modified matching cache directory via ``_find_hash_in_cache``.
    """

    def __init__(
        self,
        cache_dir: Optional[str] = None,
        dataset_name: Optional[str] = None,
        config_name: Optional[str] = None,
        version: Optional[str] = "0.0.0",
        hash: Optional[str] = None,
        base_path: Optional[str] = None,
        info: Optional[datasets.DatasetInfo] = None,
        features: Optional[datasets.Features] = None,
        token: Optional[Union[bool, str]] = None,
        repo_id: Optional[str] = None,
        data_files: Optional[Union[str, list, dict, datasets.data_files.DataFilesDict]] = None,
        data_dir: Optional[str] = None,
        storage_options: Optional[dict] = None,
        writer_batch_size: Optional[int] = None,
        **config_kwargs,
    ):
        """Resolve which cached config/version/hash to load and initialize the parent builder.

        ``features``, ``data_files``, ``data_dir`` and any extra ``config_kwargs`` are used only
        for the ``"auto"`` cache lookup; they are not forwarded to the parent builder.

        Raises:
            ValueError: if neither ``repo_id`` nor ``dataset_name`` is given, or if the
                ``"auto"`` lookup finds no usable cache.
            NotImplementedError: if only one of ``hash`` / ``version`` is ``"auto"``.
        """
        if repo_id is None and dataset_name is None:
            raise ValueError("repo_id or dataset_name is required for the Cache dataset builder")
        if data_files is not None:
            config_kwargs["data_files"] = data_files
        if data_dir is not None:
            config_kwargs["data_dir"] = data_dir
        if hash == "auto" and version == "auto":
            config_name, version, hash = _find_hash_in_cache(
                dataset_name=repo_id or dataset_name,
                config_name=config_name,
                cache_dir=cache_dir,
                config_kwargs=config_kwargs,
                custom_features=features,
            )
        elif hash == "auto" or version == "auto":
            raise NotImplementedError("Pass both hash='auto' and version='auto' instead")
        super().__init__(
            cache_dir=cache_dir,
            dataset_name=dataset_name,
            config_name=config_name,
            version=version,
            hash=hash,
            base_path=base_path,
            info=info,
            token=token,
            repo_id=repo_id,
            storage_options=storage_options,
            writer_batch_size=writer_batch_size,
        )

    def _info(self) -> datasets.DatasetInfo:
        """Return an empty ``DatasetInfo``.

        NOTE(review): presumably the real info is read from the cache directory by the
        parent builder; confirm.
        """
        return datasets.DatasetInfo()

    def download_and_prepare(self, output_dir: Optional[str] = None, *args, **kwargs):
        """Check that the cache exists and optionally copy it to *output_dir*; nothing is downloaded.

        Extra positional and keyword arguments are accepted for API compatibility and ignored.

        Raises:
            ValueError: if ``self.cache_dir`` does not exist.
        """
        if not os.path.exists(self.cache_dir):
            raise ValueError(f"Cache directory for {self.dataset_name} doesn't exist at {self.cache_dir}")
        # Compare normalized absolute paths: a pathlib.Path, trailing slash or relative spelling
        # of the cache directory itself must not trigger copytree onto an existing destination.
        if output_dir is not None and os.path.abspath(output_dir) != os.path.abspath(self.cache_dir):
            shutil.copytree(self.cache_dir, output_dir)

    def _split_generators(self, dl_manager):
        """Return one split generator per split recorded in ``self.info``.

        Each generator reads that split's ``.arrow`` shard files from the cache directory
        (used to stream from cache).

        Raises:
            ValueError: if ``self.info.splits`` is not a ``datasets.SplitDict``.
        """
        if isinstance(self.info.splits, datasets.SplitDict):
            split_infos: list[datasets.SplitInfo] = list(self.info.splits.values())
        else:
            raise ValueError(f"Missing splits info for {self.dataset_name} in cache directory {self.cache_dir}")
        return [
            datasets.SplitGenerator(
                name=split_info.name,
                gen_kwargs={
                    "files": filenames_for_dataset_split(
                        self.cache_dir,
                        dataset_name=self.dataset_name,
                        split=split_info.name,
                        filetype_suffix="arrow",
                        shard_lengths=split_info.shard_lengths,
                    )
                },
            )
            for split_info in split_infos
        ]

    def _generate_tables(self, files):
        """Yield ``("<file_idx>_<batch_idx>", table)`` for each record batch in each Arrow stream file.

        A ``ValueError`` raised while reading a file is logged with the file name and re-raised.
        """
        for file_idx, file in enumerate(files):
            with open(file, "rb") as f:
                try:
                    for batch_idx, record_batch in enumerate(pa.ipc.open_stream(f)):
                        pa_table = pa.Table.from_batches([record_batch])
                        yield f"{file_idx}_{batch_idx}", pa_table
                except ValueError as e:
                    logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}")
                    raise
|