jamtur01's picture
Upload folder using huggingface_hub
9c6594c verified
import glob
import json
import os
import shutil
import time
from pathlib import Path
from typing import Optional, Union
import pyarrow as pa
import datasets
import datasets.config
import datasets.data_files
from datasets.naming import camelcase_to_snakecase, filenames_for_dataset_split
# Module-level logger, obtained through the `datasets` logging utilities and named after this module.
logger = datasets.utils.logging.get_logger(__name__)
def _get_modification_time(cached_directory_path):
    """Return the last-modification timestamp (``st_mtime``) of the given path.

    Args:
        cached_directory_path: Path (string or path-like) to stat.

    Returns:
        The ``st_mtime`` value reported by the filesystem for that path.
    """
    directory = Path(cached_directory_path)
    stat_result = directory.stat()
    return stat_result.st_mtime
def _is_matching_cached_directory(cached_directory_path, config_kwargs, custom_features) -> bool:
    """Tell whether a globbed cache path is a usable candidate directory.

    A candidate must be a directory. When neither ``config_kwargs`` nor
    ``custom_features`` is given, the ``config_name`` recorded in the
    directory's ``dataset_info.json`` must also equal the config directory name
    (third path component from the end).

    Directories whose ``dataset_info.json`` is missing, unreadable, not valid
    JSON, or lacks ``config_name`` are treated as non-matching instead of
    aborting the whole cache lookup.
    """
    if not os.path.isdir(cached_directory_path):
        return False
    if config_kwargs or custom_features:
        # With extra params the directory is accepted without inspecting its metadata.
        return True
    try:
        dataset_info = json.loads(
            Path(cached_directory_path, "dataset_info.json").read_text(encoding="utf-8")
        )
        recorded_config_name = dataset_info["config_name"]
    except (OSError, ValueError, KeyError, TypeError) as err:
        # OSError: file missing/unreadable; ValueError: bad JSON or bad encoding;
        # KeyError/TypeError: JSON content without a usable "config_name" entry.
        logger.debug(
            "Skipping cache directory %s: cannot read 'config_name' from dataset_info.json (%s)",
            cached_directory_path,
            err,
        )
        return False
    # no extra params => config_id == config_name
    return recorded_config_name == Path(cached_directory_path).parts[-3]


def _find_hash_in_cache(
    dataset_name: str,
    config_name: Optional[str],
    cache_dir: Optional[str],
    config_kwargs: dict,
    custom_features: Optional[datasets.Features],
) -> tuple[str, str, str]:
    """Find the most recently modified cached copy of a dataset.

    Cached directories are searched with the layout
    ``<cache_dir>/<namespace>___<dataset_name>/<config_id>/<version>/<hash>``,
    where the last component of ``dataset_name`` is converted to snake case.

    Args:
        dataset_name: Dataset name, optionally prefixed by a namespace and ``/``.
        config_name: Requested configuration name, or ``None``.
        cache_dir: Cache root; falls back to ``datasets.config.HF_DATASETS_CACHE``.
        config_kwargs: Extra config parameters used to build the config id.
        custom_features: Custom features used to build the config id.

    Returns:
        A ``(config_name, version, hash)`` tuple taken from the last three path
        components of the selected cache directory.

    Raises:
        ValueError: If no matching cache directory exists, or if no
            configuration was requested and several configurations are cached
            for the selected version and hash.
    """
    if config_name or config_kwargs or custom_features:
        config_id = datasets.BuilderConfig(config_name or "default").create_config_id(
            config_kwargs=config_kwargs, custom_features=custom_features
        )
    else:
        config_id = None
    cache_dir = os.path.expanduser(str(cache_dir or datasets.config.HF_DATASETS_CACHE))
    namespace_and_dataset_name = dataset_name.split("/")
    namespace_and_dataset_name[-1] = camelcase_to_snakecase(namespace_and_dataset_name[-1])
    cached_relative_path = "___".join(namespace_and_dataset_name)
    cached_datasets_directory_path_root = os.path.join(cache_dir, cached_relative_path)
    # Every literal part of a glob pattern is escaped so that special characters
    # ("*", "?", "[") in the cache path or names are matched verbatim.
    escaped_root = glob.escape(cached_datasets_directory_path_root)
    config_pattern = glob.escape(config_id) if config_id else "*"
    cached_directory_paths = [
        cached_directory_path
        for cached_directory_path in glob.glob(os.path.join(escaped_root, config_pattern, "*", "*"))
        if _is_matching_cached_directory(cached_directory_path, config_kwargs, custom_features)
    ]
    if not cached_directory_paths:
        # Nothing matched: list whatever configs exist to make the error actionable.
        available_configs = sorted(
            {
                Path(cached_directory_path).parts[-3]
                for cached_directory_path in glob.glob(os.path.join(escaped_root, "*", "*", "*"))
                if os.path.isdir(cached_directory_path)
            }
        )
        raise ValueError(
            f"Couldn't find cache for {dataset_name}"
            + (f" for config '{config_id}'" if config_id else "")
            + (f"\nAvailable configs in the cache: {available_configs}" if available_configs else "")
        )
    # get most recent
    cached_directory_path = Path(sorted(cached_directory_paths, key=_get_modification_time)[-1])
    version, dataset_hash = cached_directory_path.parts[-2:]
    # Configurations that are cached for the very same version and hash.
    other_configs = [
        Path(_cached_directory_path).parts[-3]
        for _cached_directory_path in glob.glob(
            os.path.join(escaped_root, "*", glob.escape(version), glob.escape(dataset_hash))
        )
        if _is_matching_cached_directory(_cached_directory_path, config_kwargs, custom_features)
    ]
    if not config_id and len(other_configs) > 1:
        raise ValueError(
            f"There are multiple '{dataset_name}' configurations in the cache: {', '.join(other_configs)}"
            f"\nPlease specify which configuration to reload from the cache, e.g."
            f"\n\tload_dataset('{dataset_name}', '{other_configs[0]}')"
        )
    config_name = cached_directory_path.parts[-3]
    warning_msg = (
        f"Found the latest cached dataset configuration '{config_name}' at {cached_directory_path} "
        f"(last modified on {time.ctime(_get_modification_time(cached_directory_path))})."
    )
    logger.warning(warning_msg)
    return config_name, version, dataset_hash
class Cache(datasets.ArrowBasedBuilder):
    """Arrow-based builder that serves a dataset from an existing cache directory.

    Nothing is downloaded: ``download_and_prepare`` only checks that the cache
    directory exists (optionally copying it), and the Arrow files listed for
    each split are read back as tables.
    """

    def __init__(
        self,
        cache_dir: Optional[str] = None,
        dataset_name: Optional[str] = None,
        config_name: Optional[str] = None,
        version: Optional[str] = "0.0.0",
        hash: Optional[str] = None,
        base_path: Optional[str] = None,
        info: Optional[datasets.DatasetInfo] = None,
        features: Optional[datasets.Features] = None,
        token: Optional[Union[bool, str]] = None,
        repo_id: Optional[str] = None,
        data_files: Optional[Union[str, list, dict, datasets.data_files.DataFilesDict]] = None,
        data_dir: Optional[str] = None,
        storage_options: Optional[dict] = None,
        writer_batch_size: Optional[int] = None,
        **config_kwargs,
    ):
        """Set up the builder, resolving ``"auto"`` version/hash from the cache.

        Raises:
            ValueError: If neither ``repo_id`` nor ``dataset_name`` is given.
            NotImplementedError: If only one of ``hash`` and ``version`` is ``"auto"``.
        """
        if repo_id is None and dataset_name is None:
            raise ValueError("repo_id or dataset_name is required for the Cache dataset builder")

        # Fold the explicit data location arguments into the config kwargs
        # (data_files first, then data_dir).
        for key, value in (("data_files", data_files), ("data_dir", data_dir)):
            if value is not None:
                config_kwargs[key] = value

        auto_hash = hash == "auto"
        auto_version = version == "auto"
        if auto_hash != auto_version:
            # Only one of the two was set to "auto".
            raise NotImplementedError("Pass both hash='auto' and version='auto' instead")
        if auto_hash:
            # Both are "auto": look the actual values up in the cache.
            config_name, version, hash = _find_hash_in_cache(
                dataset_name=repo_id or dataset_name,
                config_name=config_name,
                cache_dir=cache_dir,
                config_kwargs=config_kwargs,
                custom_features=features,
            )

        super().__init__(
            cache_dir=cache_dir,
            dataset_name=dataset_name,
            config_name=config_name,
            version=version,
            hash=hash,
            base_path=base_path,
            info=info,
            token=token,
            repo_id=repo_id,
            storage_options=storage_options,
            writer_batch_size=writer_batch_size,
        )

    def _info(self) -> datasets.DatasetInfo:
        """Return an empty ``DatasetInfo``."""
        return datasets.DatasetInfo()

    def download_and_prepare(self, output_dir: Optional[str] = None, *args, **kwargs):
        """Check that the cache directory exists and optionally copy it.

        The cache directory is copied to ``output_dir`` only when ``output_dir``
        is given and differs from it. Extra arguments are accepted and ignored.

        Raises:
            ValueError: If the cache directory does not exist.
        """
        if not os.path.exists(self.cache_dir):
            raise ValueError(f"Cache directory for {self.dataset_name} doesn't exist at {self.cache_dir}")
        if output_dir is None or output_dir == self.cache_dir:
            return
        shutil.copytree(self.cache_dir, output_dir)

    def _split_generators(self, dl_manager):
        """Build one split generator per split recorded in the dataset info.

        Used to stream from cache.

        Raises:
            ValueError: If the dataset info carries no ``SplitDict``.
        """
        splits = self.info.splits
        if not isinstance(splits, datasets.SplitDict):
            raise ValueError(f"Missing splits info for {self.dataset_name} in cache directory {self.cache_dir}")

        split_infos: list[datasets.SplitInfo] = list(splits.values())
        generators = []
        for split_info in split_infos:
            arrow_files = filenames_for_dataset_split(
                self.cache_dir,
                dataset_name=self.dataset_name,
                split=split_info.name,
                filetype_suffix="arrow",
                shard_lengths=split_info.shard_lengths,
            )
            generators.append(datasets.SplitGenerator(name=split_info.name, gen_kwargs={"files": arrow_files}))
        return generators

    def _generate_tables(self, files):
        """Yield ``("<file_idx>_<batch_idx>", table)`` for every record batch of every file.

        Used to stream from cache. A ``ValueError`` raised while reading a file
        is logged and re-raised.
        """
        for file_idx, file in enumerate(files):
            with open(file, "rb") as f:
                try:
                    batch_reader = pa.ipc.open_stream(f)
                    for batch_idx, record_batch in enumerate(batch_reader):
                        table = pa.Table.from_batches([record_batch])
                        # Uncomment for debugging (will print the Arrow table size and elements)
                        # logger.warning(f"table: {table} num rows: {table.num_rows}")
                        # logger.warning('\n'.join(str(table.slice(i, 1).to_pydict()) for i in range(table.num_rows)))
                        yield f"{file_idx}_{batch_idx}", table
                except ValueError as e:
                    logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}")
                    raise