import glob
import json
import os
import shutil
import time
from pathlib import Path
from typing import Optional, Union

import pyarrow as pa

import datasets
import datasets.config
import datasets.data_files
from datasets.naming import camelcase_to_snakecase, filenames_for_dataset_split


logger = datasets.utils.logging.get_logger(__name__)


def _get_modification_time(cached_directory_path):
    return Path(cached_directory_path).stat().st_mtime


def _find_hash_in_cache(
    dataset_name: str,
    config_name: Optional[str],
    cache_dir: Optional[str],
    config_kwargs: dict,
    custom_features: Optional[datasets.Features],
) -> tuple[str, str, str]:
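    """Find the most recently modified cache directory matching the given dataset
    and config, and return its ``(config_name, version, hash)`` triple.

    Raises ``ValueError`` if no matching cache entry exists, or if several
    configurations match and none was explicitly requested.
    """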
    if config_name or config_kwargs or custom_features:
        config_id = datasets.BuilderConfig(config_name or "default").create_config_id(
            config_kwargs=config_kwargs, custom_features=custom_features
        )
    else:
        config_id = None
    cache_dir = os.path.expanduser(str(cache_dir or datasets.config.HF_DATASETS_CACHE))
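    # Cache layout: <cache_dir>/<namespace>___<dataset_name>/<config_id>/<version>/<hash>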
    namespace_and_dataset_name = dataset_name.split("/")
    namespace_and_dataset_name[-1] = camelcase_to_snakecase(namespace_and_dataset_name[-1])
    cached_relative_path = "___".join(namespace_and_dataset_name)
    cached_datasets_directory_path_root = os.path.join(cache_dir, cached_relative_path)
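    # Keep only real cache directories; when no custom kwargs or features were
    # passed, also require that the config name recorded in dataset_info.json
    # matches the config directory name.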
    cached_directory_paths = [
        cached_directory_path
        for cached_directory_path in glob.glob(
            os.path.join(cached_datasets_directory_path_root, config_id or "*", "*", "*")
        )
        if os.path.isdir(cached_directory_path)
        and (
            config_kwargs
            or custom_features
            or json.loads(Path(cached_directory_path, "dataset_info.json").read_text(encoding="utf-8"))["config_name"]
            == Path(cached_directory_path).parts[-3]
        )
    ]
    if not cached_directory_paths:
        cached_directory_paths = [
            cached_directory_path
            for cached_directory_path in glob.glob(os.path.join(cached_datasets_directory_path_root, "*", "*", "*"))
            if os.path.isdir(cached_directory_path)
        ]
        available_configs = sorted(
            {Path(cached_directory_path).parts[-3] for cached_directory_path in cached_directory_paths}
        )
        raise ValueError(
            f"Couldn't find cache for {dataset_name}"
            + (f" for config '{config_id}'" if config_id else "")
            + (f"\nAvailable configs in the cache: {available_configs}" if available_configs else "")
        )
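    # Reload from the most recently modified matching cache directory.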
    cached_directory_path = Path(sorted(cached_directory_paths, key=_get_modification_time)[-1])
    version, hash = cached_directory_path.parts[-2:]
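    # Collect sibling configurations sharing this version and hash; if several
    # exist and no config was requested, the choice would be ambiguous.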
    other_configs = [
        Path(_cached_directory_path).parts[-3]
        for _cached_directory_path in glob.glob(os.path.join(cached_datasets_directory_path_root, "*", version, hash))
        if os.path.isdir(_cached_directory_path)
        and (
            config_kwargs
            or custom_features
            or json.loads(Path(_cached_directory_path, "dataset_info.json").read_text(encoding="utf-8"))["config_name"]
            == Path(_cached_directory_path).parts[-3]
        )
    ]
    if not config_id and len(other_configs) > 1:
        raise ValueError(
            f"There are multiple '{dataset_name}' configurations in the cache: {', '.join(other_configs)}"
            f"\nPlease specify which configuration to reload from the cache, e.g."
            f"\n\tload_dataset('{dataset_name}', '{other_configs[0]}')"
        )
    config_name = cached_directory_path.parts[-3]
    warning_msg = (
        f"Found the latest cached dataset configuration '{config_name}' at {cached_directory_path} "
        f"(last modified on {time.ctime(_get_modification_time(cached_directory_path))})."
    )
    logger.warning(warning_msg)
    return config_name, version, hash


class Cache(datasets.ArrowBasedBuilder):
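    """Dataset builder that reloads a dataset from the local cache instead of
    downloading and regenerating it.

    Illustrative usage (a minimal sketch; assumes "squad" was already cached by
    an earlier `load_dataset` call, and the names are examples only):

        builder = Cache(dataset_name="squad", version="auto", hash="auto")
        builder.download_and_prepare()
        dataset = builder.as_dataset()
    """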

    def __init__(
        self,
        cache_dir: Optional[str] = None,
        dataset_name: Optional[str] = None,
        config_name: Optional[str] = None,
        version: Optional[str] = "0.0.0",
        hash: Optional[str] = None,
        base_path: Optional[str] = None,
        info: Optional[datasets.DatasetInfo] = None,
        features: Optional[datasets.Features] = None,
        token: Optional[Union[bool, str]] = None,
        repo_id: Optional[str] = None,
        data_files: Optional[Union[str, list, dict, datasets.data_files.DataFilesDict]] = None,
        data_dir: Optional[str] = None,
        storage_options: Optional[dict] = None,
        writer_batch_size: Optional[int] = None,
        **config_kwargs,
    ):
        if repo_id is None and dataset_name is None:
            raise ValueError("repo_id or dataset_name is required for the Cache dataset builder")
        if data_files is not None:
            config_kwargs["data_files"] = data_files
        if data_dir is not None:
            config_kwargs["data_dir"] = data_dir
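        # version="auto" together with hash="auto" means: look up the most
        # recently modified matching entry already present in the cache.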
        if hash == "auto" and version == "auto":
            config_name, version, hash = _find_hash_in_cache(
                dataset_name=repo_id or dataset_name,
                config_name=config_name,
                cache_dir=cache_dir,
                config_kwargs=config_kwargs,
                custom_features=features,
            )
        elif hash == "auto" or version == "auto":
            raise NotImplementedError("Pass both hash='auto' and version='auto' instead")
        super().__init__(
            cache_dir=cache_dir,
            dataset_name=dataset_name,
            config_name=config_name,
            version=version,
            hash=hash,
            base_path=base_path,
            info=info,
            token=token,
            repo_id=repo_id,
            storage_options=storage_options,
            writer_batch_size=writer_batch_size,
        )

    def _info(self) -> datasets.DatasetInfo:
        # Return empty info here; the base builder loads the real metadata from
        # the dataset_info.json stored in the cache directory.
        return datasets.DatasetInfo()

    def download_and_prepare(self, output_dir: Optional[str] = None, *args, **kwargs):
        # Nothing to download or generate: the data already lives in the cache.
        if not os.path.exists(self.cache_dir):
            raise ValueError(f"Cache directory for {self.dataset_name} doesn't exist at {self.cache_dir}")
        if output_dir is not None and output_dir != self.cache_dir:
            shutil.copytree(self.cache_dir, output_dir)

    def _split_generators(self, dl_manager):
        # Rebuild one split generator per cached split, pointing at the Arrow
        # shard files written by the original `download_and_prepare` run.
        if isinstance(self.info.splits, datasets.SplitDict):
            split_infos: list[datasets.SplitInfo] = list(self.info.splits.values())
        else:
            raise ValueError(f"Missing splits info for {self.dataset_name} in cache directory {self.cache_dir}")
        return [
            datasets.SplitGenerator(
                name=split_info.name,
                gen_kwargs={
                    "files": filenames_for_dataset_split(
                        self.cache_dir,
                        dataset_name=self.dataset_name,
                        split=split_info.name,
                        filetype_suffix="arrow",
                        shard_lengths=split_info.shard_lengths,
                    )
                },
            )
            for split_info in split_infos
        ]

    def _generate_tables(self, files):
        # Each cached shard is an Arrow IPC stream; re-emit it batch by batch so
        # the dataset can be reloaded without materializing whole files in memory.
        for file_idx, file in enumerate(files):
            with open(file, "rb") as f:
                try:
                    for batch_idx, record_batch in enumerate(pa.ipc.open_stream(f)):
                        pa_table = pa.Table.from_batches([record_batch])
                        yield f"{file_idx}_{batch_idx}", pa_table
                except ValueError as e:
                    logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}")
                    raise