import os
import re
from functools import partial
from glob import has_magic
from pathlib import Path, PurePath
from typing import Callable, Optional, Union
import huggingface_hub
from fsspec.core import url_to_fs
from huggingface_hub import HfFileSystem
from packaging import version
from tqdm.contrib.concurrent import thread_map
from . import config
from .download import DownloadConfig
from .naming import _split_re
from .splits import Split
from .utils import logging
from .utils import tqdm as hf_tqdm
from .utils.file_utils import _prepare_path_and_storage_options, is_local_path, is_relative_path, xbasename, xjoin
from .utils.py_utils import glob_pattern_to_regex, string_to_dict
SingleOriginMetadata = Union[tuple[str, str], tuple[str], tuple[()]]
SANITIZED_DEFAULT_SPLIT = str(Split.TRAIN)
logger = logging.get_logger(__name__)
class Url(str):
pass
class EmptyDatasetError(FileNotFoundError):
pass
SPLIT_PATTERN_SHARDED = "data/{split}-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*"
SPLIT_KEYWORDS = {
Split.TRAIN: ["train", "training"],
Split.VALIDATION: ["validation", "valid", "dev", "val"],
Split.TEST: ["test", "testing", "eval", "evaluation"],
}
NON_WORDS_CHARS = "-._ 0-9"
if config.FSSPEC_VERSION < version.parse("2023.9.0"):
KEYWORDS_IN_FILENAME_BASE_PATTERNS = ["**[{sep}/]{keyword}[{sep}]*", "{keyword}[{sep}]*"]
KEYWORDS_IN_DIR_NAME_BASE_PATTERNS = [
"{keyword}/**",
"{keyword}[{sep}]*/**",
"**[{sep}/]{keyword}/**",
"**[{sep}/]{keyword}[{sep}]*/**",
]
elif config.FSSPEC_VERSION < version.parse("2023.12.0"):
KEYWORDS_IN_FILENAME_BASE_PATTERNS = ["**/*[{sep}/]{keyword}[{sep}]*", "{keyword}[{sep}]*"]
KEYWORDS_IN_DIR_NAME_BASE_PATTERNS = [
"{keyword}/**/*",
"{keyword}[{sep}]*/**/*",
"**/*[{sep}/]{keyword}/**/*",
"**/*[{sep}/]{keyword}[{sep}]*/**/*",
]
else:
KEYWORDS_IN_FILENAME_BASE_PATTERNS = ["**/{keyword}[{sep}]*", "**/*[{sep}]{keyword}[{sep}]*"]
KEYWORDS_IN_DIR_NAME_BASE_PATTERNS = [
"**/{keyword}/**",
"**/{keyword}[{sep}]*/**",
"**/*[{sep}]{keyword}/**",
"**/*[{sep}]{keyword}[{sep}]*/**",
]
DEFAULT_SPLITS = [Split.TRAIN, Split.VALIDATION, Split.TEST]
DEFAULT_PATTERNS_SPLIT_IN_FILENAME = {
split: [
pattern.format(keyword=keyword, sep=NON_WORDS_CHARS)
for keyword in SPLIT_KEYWORDS[split]
for pattern in KEYWORDS_IN_FILENAME_BASE_PATTERNS
]
for split in DEFAULT_SPLITS
}
DEFAULT_PATTERNS_SPLIT_IN_DIR_NAME = {
split: [
pattern.format(keyword=keyword, sep=NON_WORDS_CHARS)
for keyword in SPLIT_KEYWORDS[split]
for pattern in KEYWORDS_IN_DIR_NAME_BASE_PATTERNS
]
for split in DEFAULT_SPLITS
}
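# Illustration (a non-normative sketch, assuming fsspec>=2023.12.0, i.e. the last branch above):
# the "train" filename patterns include "**/train[-._ 0-9]*" and "**/*[-._ 0-9]train[-._ 0-9]*",
# so files such as "train.csv" or "my_train_0.csv" are matched, while directory-name patterns
# such as "**/train/**" match files like "data/train/shard_0.csv".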
DEFAULT_PATTERNS_ALL = {
Split.TRAIN: ["**"],
}
ALL_SPLIT_PATTERNS = [SPLIT_PATTERN_SHARDED]
ALL_DEFAULT_PATTERNS = [
DEFAULT_PATTERNS_SPLIT_IN_DIR_NAME,
DEFAULT_PATTERNS_SPLIT_IN_FILENAME,
DEFAULT_PATTERNS_ALL,
]
WILDCARD_CHARACTERS = "*[]"
FILES_TO_IGNORE = [
"README.md",
"config.json",
"dataset_info.json",
"dataset_infos.json",
"dummy_data.zip",
"dataset_dict.json",
]
def contains_wildcards(pattern: str) -> bool:
return any(wildcard_character in pattern for wildcard_character in WILDCARD_CHARACTERS)
def sanitize_patterns(patterns: Union[dict, list, str]) -> dict[str, Union[list[str], "DataFilesList"]]:
"""
Take the data_files patterns from the user, and format them into a dictionary.
Each key is the name of the split, and each value is a list of data files patterns (paths or urls).
The default split is "train".
Returns:
patterns: dictionary of split_name -> list of patterns
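
    Example (illustrative; the outputs follow directly from the rules above)::

        >>> sanitize_patterns("train.csv")
        {'train': ['train.csv']}
        >>> sanitize_patterns({"train": "data/train-*.csv", "test": ["data/test.csv"]})
        {'train': ['data/train-*.csv'], 'test': ['data/test.csv']}
        >>> sanitize_patterns([{"split": "train", "path": "a.csv"}, {"split": "test", "path": ["b.csv"]}])
        {'train': ['a.csv'], 'test': ['b.csv']}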
"""
if isinstance(patterns, dict):
return {str(key): value if isinstance(value, list) else [value] for key, value in patterns.items()}
elif isinstance(patterns, str):
return {SANITIZED_DEFAULT_SPLIT: [patterns]}
elif isinstance(patterns, list):
if any(isinstance(pattern, dict) for pattern in patterns):
for pattern in patterns:
if not (
isinstance(pattern, dict)
and len(pattern) == 2
and "split" in pattern
and isinstance(pattern.get("path"), (str, list))
):
raise ValueError(
f"Expected each split to have a 'path' key which can be a string or a list of strings, but got {pattern}"
)
splits = [pattern["split"] for pattern in patterns]
if len(set(splits)) != len(splits):
raise ValueError(f"Some splits are duplicated in data_files: {splits}")
return {
str(pattern["split"]): pattern["path"] if isinstance(pattern["path"], list) else [pattern["path"]]
for pattern in patterns
}
else:
return {SANITIZED_DEFAULT_SPLIT: patterns}
else:
return sanitize_patterns(list(patterns))
def _is_inside_unrequested_special_dir(matched_rel_path: str, pattern: str) -> bool:
"""
When a path matches a pattern, we additionally check if it's inside a special directory
we ignore by default (if it starts with a double underscore).
    Users can still explicitly request a filepath inside such a directory if the special
    directory (e.g. "__pycache__") is mentioned explicitly in the requested pattern.
Some examples:
base directory:
./
└── __pycache__
└── b.txt
>>> _is_inside_unrequested_special_dir("__pycache__/b.txt", "**")
True
>>> _is_inside_unrequested_special_dir("__pycache__/b.txt", "*/b.txt")
True
>>> _is_inside_unrequested_special_dir("__pycache__/b.txt", "__pycache__/*")
False
>>> _is_inside_unrequested_special_dir("__pycache__/b.txt", "__*/*")
False
"""
    # We just need to check that every special directory from the path is present explicitly in the pattern.
    # Since we assume that the path matches the pattern, this is equivalent to checking that both
    # the parent path and the parent pattern contain the same number of special directories.
data_dirs_to_ignore_in_path = [part for part in PurePath(matched_rel_path).parent.parts if part.startswith("__")]
data_dirs_to_ignore_in_pattern = [part for part in PurePath(pattern).parent.parts if part.startswith("__")]
return len(data_dirs_to_ignore_in_path) != len(data_dirs_to_ignore_in_pattern)
def _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(matched_rel_path: str, pattern: str) -> bool:
"""
When a path matches a pattern, we additionally check if it's a hidden file or if it's inside
a hidden directory we ignore by default, i.e. if the file name or a parent directory name starts with a dot.
Users can still explicitly request a filepath that is hidden or is inside a hidden directory
if the hidden part is mentioned explicitly in the requested pattern.
Some examples:
base directory:
./
└── .hidden_file.txt
>>> _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(".hidden_file.txt", "**")
True
>>> _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(".hidden_file.txt", ".*")
False
base directory:
./
└── .hidden_dir
└── a.txt
>>> _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(".hidden_dir/a.txt", "**")
True
>>> _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(".hidden_dir/a.txt", ".*/*")
False
>>> _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(".hidden_dir/a.txt", ".hidden_dir/*")
False
base directory:
./
└── .hidden_dir
└── .hidden_file.txt
>>> _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(".hidden_dir/.hidden_file.txt", "**")
True
>>> _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(".hidden_dir/.hidden_file.txt", ".*/*")
True
>>> _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(".hidden_dir/.hidden_file.txt", ".*/.*")
False
>>> _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(".hidden_dir/.hidden_file.txt", ".hidden_dir/*")
True
>>> _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(".hidden_dir/.hidden_file.txt", ".hidden_dir/.*")
False
"""
# We just need to check if every hidden part from the path is present explicitly in the pattern.
    # Since we assume that the path matches the pattern, this is equivalent to checking that both
# the path and the pattern have the same number of hidden parts.
hidden_directories_in_path = [
part for part in PurePath(matched_rel_path).parts if part.startswith(".") and not set(part) == {"."}
]
hidden_directories_in_pattern = [
part for part in PurePath(pattern).parts if part.startswith(".") and not set(part) == {"."}
]
return len(hidden_directories_in_path) != len(hidden_directories_in_pattern)
def _get_data_files_patterns(pattern_resolver: Callable[[str], list[str]]) -> dict[str, list[str]]:
"""
Get the default pattern from a directory or repository by testing all the supported patterns.
    The first set of patterns that returns a non-empty list of data files is used.
In order, it first tests if SPLIT_PATTERN_SHARDED works, otherwise it tests the patterns in ALL_DEFAULT_PATTERNS.
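
    Example (a minimal sketch using ``fnmatch`` as a stand-in resolver; real resolvers rely on
    fsspec glob semantics, which can differ slightly)::

        >>> import fnmatch
        >>> files = ["data/train-00000-of-00001.parquet", "data/test-00000-of-00001.parquet"]
        >>> _get_data_files_patterns(lambda pattern: fnmatch.filter(files, pattern))
        {'train': ['data/train-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*'], 'test': ['data/test-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*']}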
"""
# first check the split patterns like data/{split}-00000-of-00001.parquet
for split_pattern in ALL_SPLIT_PATTERNS:
pattern = split_pattern.replace("{split}", "*")
try:
data_files = pattern_resolver(pattern)
except FileNotFoundError:
continue
if len(data_files) > 0:
splits: set[str] = set()
for p in data_files:
p_parts = string_to_dict(xbasename(p), glob_pattern_to_regex(xbasename(split_pattern)))
assert p_parts is not None
splits.add(p_parts["split"])
if any(not re.match(_split_re, split) for split in splits):
                raise ValueError(f"Split name should match '{_split_re}' but got '{splits}'.")
sorted_splits = [str(split) for split in DEFAULT_SPLITS if split in splits] + sorted(
splits - {str(split) for split in DEFAULT_SPLITS}
)
return {split: [split_pattern.format(split=split)] for split in sorted_splits}
# then check the default patterns based on train/valid/test splits
for patterns_dict in ALL_DEFAULT_PATTERNS:
non_empty_splits = []
for split, patterns in patterns_dict.items():
for pattern in patterns:
try:
data_files = pattern_resolver(pattern)
except FileNotFoundError:
continue
if len(data_files) > 0:
non_empty_splits.append(split)
break
if non_empty_splits:
return {split: patterns_dict[split] for split in non_empty_splits}
raise FileNotFoundError(f"Couldn't resolve pattern {pattern} with resolver {pattern_resolver}")
def resolve_pattern(
pattern: str,
base_path: str,
allowed_extensions: Optional[list[str]] = None,
download_config: Optional[DownloadConfig] = None,
) -> list[str]:
"""
Resolve the paths and URLs of the data files from the pattern passed by the user.
You can use patterns to resolve multiple local files. Here are a few examples:
- *.csv to match all the CSV files at the first level
- **.csv to match all the CSV files at any level
- data/* to match all the files inside "data"
- data/** to match all the files inside "data" and its subdirectories
The patterns are resolved using the fsspec glob. In fsspec>=2023.12.0 this is equivalent to
Python's glob.glob, Path.glob, Path.match and fnmatch where ** is unsupported with a prefix/suffix
other than a forward slash /.
More generally:
- '*' matches any character except a forward-slash (to match just the file or directory name)
- '**' matches any character including a forward-slash /
Hidden files and directories (i.e. whose names start with a dot) are ignored, unless they are explicitly requested.
The same applies to special directories that start with a double underscore like "__pycache__".
You can still include one if the pattern explicitly mentions it:
- to include a hidden file: "*/.hidden.txt" or "*/.*"
- to include a hidden directory: ".hidden/*" or ".*/*"
- to include a special directory: "__special__/*" or "__*/*"
Example::
>>> from datasets.data_files import resolve_pattern
>>> base_path = "."
>>> resolve_pattern("docs/**/*.py", base_path)
        ['/Users/mariosasko/Desktop/projects/datasets/docs/source/_config.py']
Args:
pattern (str): Unix pattern or paths or URLs of the data files to resolve.
The paths can be absolute or relative to base_path.
Remote filesystems using fsspec are supported, e.g. with the hf:// protocol.
base_path (str): Base path to use when resolving relative paths.
allowed_extensions (Optional[list], optional): White-list of file extensions to use. Defaults to None (all extensions).
For example: allowed_extensions=[".csv", ".json", ".txt", ".parquet"]
download_config ([`DownloadConfig`], *optional*): Specific download configuration parameters.
Returns:
List[str]: List of paths or URLs to the local or remote files that match the patterns.
"""
if is_relative_path(pattern):
pattern = xjoin(base_path, pattern)
elif is_local_path(pattern):
base_path = os.path.splitdrive(pattern)[0] + os.sep
else:
base_path = ""
pattern, storage_options = _prepare_path_and_storage_options(pattern, download_config=download_config)
fs, fs_pattern = url_to_fs(pattern, **storage_options)
files_to_ignore = set(FILES_TO_IGNORE) - {xbasename(pattern)}
protocol = fs.protocol if isinstance(fs.protocol, str) else fs.protocol[0]
protocol_prefix = protocol + "://" if protocol != "file" else ""
glob_kwargs = {}
if protocol == "hf" and config.HF_HUB_VERSION >= version.parse("0.20.0"):
# 10 times faster glob with detail=True (ignores costly info like lastCommit)
glob_kwargs["expand_info"] = False
matched_paths = [
filepath if filepath.startswith(protocol_prefix) else protocol_prefix + filepath
for filepath, info in fs.glob(pattern, detail=True, **glob_kwargs).items()
if (info["type"] == "file" or (info.get("islink") and os.path.isfile(os.path.realpath(filepath))))
and (xbasename(filepath) not in files_to_ignore)
and not _is_inside_unrequested_special_dir(filepath, fs_pattern)
and not _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(filepath, fs_pattern)
    ]  # ignore hidden files/dirs (e.g. ".ipynb_checkpoints") and special dirs like "__pycache__", but keep "/../"
if allowed_extensions is not None:
out = [
filepath
for filepath in matched_paths
if any("." + suffix in allowed_extensions for suffix in xbasename(filepath).split(".")[1:])
]
if len(out) < len(matched_paths):
invalid_matched_files = list(set(matched_paths) - set(out))
logger.info(
f"Some files matched the pattern '{pattern}' but don't have valid data file extensions: {invalid_matched_files}"
)
else:
out = matched_paths
if not out:
error_msg = f"Unable to find '{pattern}'"
if allowed_extensions is not None:
error_msg += f" with any supported extension {list(allowed_extensions)}"
raise FileNotFoundError(error_msg)
return out
def get_data_patterns(base_path: str, download_config: Optional[DownloadConfig] = None) -> dict[str, list[str]]:
"""
    Get the default pattern from a directory by testing all the supported patterns.
    The first set of patterns that returns a non-empty list of data files is used.
Some examples of supported patterns:
Input:
my_dataset_repository/
β”œβ”€β”€ README.md
└── dataset.csv
Output:
{'train': ['**']}
Input:
my_dataset_repository/
β”œβ”€β”€ README.md
β”œβ”€β”€ train.csv
└── test.csv
my_dataset_repository/
β”œβ”€β”€ README.md
└── data/
β”œβ”€β”€ train.csv
└── test.csv
my_dataset_repository/
β”œβ”€β”€ README.md
β”œβ”€β”€ train_0.csv
β”œβ”€β”€ train_1.csv
β”œβ”€β”€ train_2.csv
β”œβ”€β”€ train_3.csv
β”œβ”€β”€ test_0.csv
└── test_1.csv
Output:
{'train': ['**/train[-._ 0-9]*', '**/*[-._ 0-9]train[-._ 0-9]*', '**/training[-._ 0-9]*', '**/*[-._ 0-9]training[-._ 0-9]*'],
'test': ['**/test[-._ 0-9]*', '**/*[-._ 0-9]test[-._ 0-9]*', '**/testing[-._ 0-9]*', '**/*[-._ 0-9]testing[-._ 0-9]*', ...]}
Input:
my_dataset_repository/
β”œβ”€β”€ README.md
└── data/
β”œβ”€β”€ train/
β”‚ β”œβ”€β”€ shard_0.csv
β”‚ β”œβ”€β”€ shard_1.csv
β”‚ β”œβ”€β”€ shard_2.csv
β”‚ └── shard_3.csv
└── test/
β”œβ”€β”€ shard_0.csv
└── shard_1.csv
Output:
{'train': ['**/train/**', '**/train[-._ 0-9]*/**', '**/*[-._ 0-9]train/**', '**/*[-._ 0-9]train[-._ 0-9]*/**', ...],
'test': ['**/test/**', '**/test[-._ 0-9]*/**', '**/*[-._ 0-9]test/**', '**/*[-._ 0-9]test[-._ 0-9]*/**', ...]}
Input:
my_dataset_repository/
β”œβ”€β”€ README.md
└── data/
β”œβ”€β”€ train-00000-of-00003.csv
β”œβ”€β”€ train-00001-of-00003.csv
β”œβ”€β”€ train-00002-of-00003.csv
β”œβ”€β”€ test-00000-of-00001.csv
β”œβ”€β”€ random-00000-of-00003.csv
β”œβ”€β”€ random-00001-of-00003.csv
└── random-00002-of-00003.csv
Output:
{'train': ['data/train-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*'],
'test': ['data/test-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*'],
'random': ['data/random-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*']}
In order, it first tests if SPLIT_PATTERN_SHARDED works, otherwise it tests the patterns in ALL_DEFAULT_PATTERNS.
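
    Example (usage sketch; "./my_dataset_repository" is a hypothetical local directory laid out
    like one of the inputs above)::

        >>> from datasets.data_files import get_data_patterns
        >>> patterns = get_data_patterns("./my_dataset_repository")  # -> one of the outputs above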
"""
resolver = partial(resolve_pattern, base_path=base_path, download_config=download_config)
try:
return _get_data_files_patterns(resolver)
except FileNotFoundError:
raise EmptyDatasetError(f"The directory at {base_path} doesn't contain any data files") from None
def _get_single_origin_metadata(
data_file: str,
download_config: Optional[DownloadConfig] = None,
) -> SingleOriginMetadata:
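    """
    Return hashable origin metadata for a single data file: ``(repo_id, revision)`` for files hosted
    on the Hugging Face Hub, a one-element tuple with the ETag (remote files) or mtime (local files)
    otherwise, and an empty tuple if none of these are available.
    """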
data_file, storage_options = _prepare_path_and_storage_options(data_file, download_config=download_config)
fs, *_ = url_to_fs(data_file, **storage_options)
if isinstance(fs, HfFileSystem):
resolved_path = fs.resolve_path(data_file)
return resolved_path.repo_id, resolved_path.revision
elif data_file.startswith(config.HF_ENDPOINT):
hffs = HfFileSystem(endpoint=config.HF_ENDPOINT, token=download_config.token)
data_file = "hf://" + data_file[len(config.HF_ENDPOINT) + 1 :].replace("/resolve/", "@", 1)
resolved_path = hffs.resolve_path(data_file)
return resolved_path.repo_id, resolved_path.revision
info = fs.info(data_file)
# s3fs uses "ETag", gcsfs uses "etag", and for local we simply check mtime
for key in ["ETag", "etag", "mtime"]:
if key in info:
return (str(info[key]),)
return ()
def _get_origin_metadata(
data_files: list[str],
download_config: Optional[DownloadConfig] = None,
max_workers: Optional[int] = None,
) -> list[SingleOriginMetadata]:
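    """
    Fetch the origin metadata of several data files in parallel using a thread pool
    (``max_workers`` threads, defaulting to ``config.HF_DATASETS_MULTITHREADING_MAX_WORKERS``).
    """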
max_workers = max_workers if max_workers is not None else config.HF_DATASETS_MULTITHREADING_MAX_WORKERS
return thread_map(
partial(_get_single_origin_metadata, download_config=download_config),
data_files,
max_workers=max_workers,
tqdm_class=hf_tqdm,
desc="Resolving data files",
# set `disable=None` rather than `disable=False` by default to disable progress bar when no TTY attached
disable=len(data_files) <= 16 or None,
)
class DataFilesList(list[str]):
"""
List of data files (absolute local paths or URLs).
It has two construction methods given the user's data files patterns:
- ``from_hf_repo``: resolve patterns inside a dataset repository
    - ``from_local_or_remote``: resolve patterns from a local path or remote URLs
Moreover, DataFilesList has an additional attribute ``origin_metadata``.
It can store:
- the last modified time of local files
- ETag of remote files
- commit sha of a dataset repository
Thanks to this additional attribute, it is possible to hash the list
and get a different hash if and only if at least one file changed.
This is useful for caching Dataset objects that are obtained from a list of data files.
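
    Example (illustrative sketch; the base path and file names are hypothetical)::

        >>> data_files = DataFilesList.from_patterns(["data/*.csv"], base_path="/path/to/dataset")
        >>> data_files.origin_metadata  # e.g. one (mtime,) tuple per resolved local file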
"""
def __init__(self, data_files: list[str], origin_metadata: list[SingleOriginMetadata]) -> None:
super().__init__(data_files)
self.origin_metadata = origin_metadata
def __add__(self, other: "DataFilesList") -> "DataFilesList":
return DataFilesList([*self, *other], self.origin_metadata + other.origin_metadata)
@classmethod
def from_hf_repo(
cls,
patterns: list[str],
dataset_info: huggingface_hub.hf_api.DatasetInfo,
base_path: Optional[str] = None,
allowed_extensions: Optional[list[str]] = None,
download_config: Optional[DownloadConfig] = None,
) -> "DataFilesList":
base_path = f"hf://datasets/{dataset_info.id}@{dataset_info.sha}/{base_path or ''}".rstrip("/")
return cls.from_patterns(
patterns, base_path=base_path, allowed_extensions=allowed_extensions, download_config=download_config
)
@classmethod
def from_local_or_remote(
cls,
patterns: list[str],
base_path: Optional[str] = None,
allowed_extensions: Optional[list[str]] = None,
download_config: Optional[DownloadConfig] = None,
) -> "DataFilesList":
base_path = base_path if base_path is not None else Path().resolve().as_posix()
return cls.from_patterns(
patterns, base_path=base_path, allowed_extensions=allowed_extensions, download_config=download_config
)
@classmethod
def from_patterns(
cls,
patterns: list[str],
base_path: Optional[str] = None,
allowed_extensions: Optional[list[str]] = None,
download_config: Optional[DownloadConfig] = None,
) -> "DataFilesList":
base_path = base_path if base_path is not None else Path().resolve().as_posix()
data_files = []
for pattern in patterns:
try:
data_files.extend(
resolve_pattern(
pattern,
base_path=base_path,
allowed_extensions=allowed_extensions,
download_config=download_config,
)
)
except FileNotFoundError:
if not has_magic(pattern):
raise
origin_metadata = _get_origin_metadata(data_files, download_config=download_config)
return cls(data_files, origin_metadata)
def filter(
self, *, extensions: Optional[list[str]] = None, file_names: Optional[list[str]] = None
) -> "DataFilesList":
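        """
        Return a new [`DataFilesList`] keeping only the files whose name ends with one of ``extensions``
        (optionally followed by a secondary extension, e.g. ".csv.gz") or whose path ends with one of
        ``file_names``. If neither filter is given, a copy of the list is returned.
        """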
patterns = []
if extensions:
ext_pattern = "|".join(re.escape(ext) for ext in extensions)
patterns.append(re.compile(f".*({ext_pattern})(\\..+)?$"))
if file_names:
fn_pattern = "|".join(re.escape(fn) for fn in file_names)
patterns.append(re.compile(rf".*[\/]?({fn_pattern})$"))
if patterns:
return DataFilesList(
[data_file for data_file in self if any(pattern.match(data_file) for pattern in patterns)],
origin_metadata=self.origin_metadata,
)
else:
return DataFilesList(list(self), origin_metadata=self.origin_metadata)
class DataFilesDict(dict[str, DataFilesList]):
"""
Dict of split_name -> list of data files (absolute local paths or URLs).
    It has two construction methods given the user's data files patterns:
- ``from_hf_repo``: resolve patterns inside a dataset repository
    - ``from_local_or_remote``: resolve patterns from a local path or remote URLs
Moreover, each list is a DataFilesList. It is possible to hash the dictionary
and get a different hash if and only if at least one file changed.
For more info, see [`DataFilesList`].
This is useful for caching Dataset objects that are obtained from a list of data files.
Changing the order of the keys of this dictionary also doesn't change its hash.
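
    Example (illustrative sketch; the base path and file names are hypothetical)::

        >>> patterns = sanitize_patterns({"train": ["train.csv"], "test": ["test.csv"]})
        >>> data_files = DataFilesDict.from_local_or_remote(patterns, base_path="/path/to/dataset")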
"""
@classmethod
def from_local_or_remote(
cls,
patterns: dict[str, Union[list[str], DataFilesList]],
base_path: Optional[str] = None,
allowed_extensions: Optional[list[str]] = None,
download_config: Optional[DownloadConfig] = None,
) -> "DataFilesDict":
out = cls()
for key, patterns_for_key in patterns.items():
out[key] = (
patterns_for_key
if isinstance(patterns_for_key, DataFilesList)
else DataFilesList.from_local_or_remote(
patterns_for_key,
base_path=base_path,
allowed_extensions=allowed_extensions,
download_config=download_config,
)
)
return out
@classmethod
def from_hf_repo(
cls,
patterns: dict[str, Union[list[str], DataFilesList]],
dataset_info: huggingface_hub.hf_api.DatasetInfo,
base_path: Optional[str] = None,
allowed_extensions: Optional[list[str]] = None,
download_config: Optional[DownloadConfig] = None,
) -> "DataFilesDict":
out = cls()
for key, patterns_for_key in patterns.items():
out[key] = (
patterns_for_key
if isinstance(patterns_for_key, DataFilesList)
else DataFilesList.from_hf_repo(
patterns_for_key,
dataset_info=dataset_info,
base_path=base_path,
allowed_extensions=allowed_extensions,
download_config=download_config,
)
)
return out
@classmethod
def from_patterns(
cls,
patterns: dict[str, Union[list[str], DataFilesList]],
base_path: Optional[str] = None,
allowed_extensions: Optional[list[str]] = None,
download_config: Optional[DownloadConfig] = None,
) -> "DataFilesDict":
out = cls()
for key, patterns_for_key in patterns.items():
out[key] = (
patterns_for_key
if isinstance(patterns_for_key, DataFilesList)
else DataFilesList.from_patterns(
patterns_for_key,
base_path=base_path,
allowed_extensions=allowed_extensions,
download_config=download_config,
)
)
return out
def filter(
self, *, extensions: Optional[list[str]] = None, file_names: Optional[list[str]] = None
) -> "DataFilesDict":
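        """Apply [`DataFilesList.filter`] to the files of each split and return the result as a new [`DataFilesDict`]."""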
out = type(self)()
for key, data_files_list in self.items():
out[key] = data_files_list.filter(extensions=extensions, file_names=file_names)
return out
class DataFilesPatternsList(list[str]):
"""
List of data files patterns (absolute local paths or URLs).
For each pattern there should also be a list of allowed extensions
    to keep, or None to keep all the files for the pattern.
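
    Example (illustrative sketch; the base path is hypothetical)::

        >>> patterns = DataFilesPatternsList.from_patterns(["data/*"], allowed_extensions=[".csv"])
        >>> data_files = patterns.resolve(base_path="/path/to/dataset")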
"""
def __init__(
self,
patterns: list[str],
allowed_extensions: list[Optional[list[str]]],
):
super().__init__(patterns)
self.allowed_extensions = allowed_extensions
def __add__(self, other):
        return DataFilesPatternsList([*self, *other], self.allowed_extensions + other.allowed_extensions)
@classmethod
def from_patterns(
cls, patterns: list[str], allowed_extensions: Optional[list[str]] = None
) -> "DataFilesPatternsList":
return cls(patterns, [allowed_extensions] * len(patterns))
def resolve(
self,
base_path: str,
download_config: Optional[DownloadConfig] = None,
) -> "DataFilesList":
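        """Resolve the patterns into a [`DataFilesList`] of actual data file paths/URLs under ``base_path``."""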
base_path = base_path if base_path is not None else Path().resolve().as_posix()
data_files = []
for pattern, allowed_extensions in zip(self, self.allowed_extensions):
try:
data_files.extend(
resolve_pattern(
pattern,
base_path=base_path,
allowed_extensions=allowed_extensions,
download_config=download_config,
)
)
except FileNotFoundError:
if not has_magic(pattern):
raise
origin_metadata = _get_origin_metadata(data_files, download_config=download_config)
return DataFilesList(data_files, origin_metadata)
def filter_extensions(self, extensions: list[str]) -> "DataFilesPatternsList":
        return DataFilesPatternsList(
            self,
            # `None` means "keep all files" for a pattern, so filtering narrows it down to `extensions`
            [exts + extensions if exts is not None else extensions for exts in self.allowed_extensions],
        )
class DataFilesPatternsDict(dict[str, DataFilesPatternsList]):
"""
Dict of split_name -> list of data files patterns (absolute local paths or URLs).
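
    Example (illustrative sketch; the base path is hypothetical)::

        >>> patterns = DataFilesPatternsDict.from_patterns({"train": ["data/train-*"]}, allowed_extensions=[".parquet"])
        >>> data_files = patterns.resolve(base_path="/path/to/dataset")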
"""
@classmethod
def from_patterns(
cls, patterns: dict[str, list[str]], allowed_extensions: Optional[list[str]] = None
) -> "DataFilesPatternsDict":
out = cls()
for key, patterns_for_key in patterns.items():
out[key] = (
patterns_for_key
if isinstance(patterns_for_key, DataFilesPatternsList)
else DataFilesPatternsList.from_patterns(
patterns_for_key,
allowed_extensions=allowed_extensions,
)
)
return out
def resolve(
self,
base_path: str,
download_config: Optional[DownloadConfig] = None,
) -> "DataFilesDict":
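        """Resolve the patterns of each split into a [`DataFilesDict`] of actual data file paths/URLs under ``base_path``."""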
out = DataFilesDict()
for key, data_files_patterns_list in self.items():
out[key] = data_files_patterns_list.resolve(base_path, download_config)
return out
def filter_extensions(self, extensions: list[str]) -> "DataFilesPatternsDict":
out = type(self)()
for key, data_files_patterns_list in self.items():
out[key] = data_files_patterns_list.filter_extensions(extensions)
return out