|
import os |
|
import re |
|
from functools import partial |
|
from glob import has_magic |
|
from pathlib import Path, PurePath |
|
from typing import Callable, Optional, Union |
|
|
|
import huggingface_hub |
|
from fsspec.core import url_to_fs |
|
from huggingface_hub import HfFileSystem |
|
from packaging import version |
|
from tqdm.contrib.concurrent import thread_map |
|
|
|
from . import config |
|
from .download import DownloadConfig |
|
from .naming import _split_re |
|
from .splits import Split |
|
from .utils import logging |
|
from .utils import tqdm as hf_tqdm |
|
from .utils.file_utils import _prepare_path_and_storage_options, is_local_path, is_relative_path, xbasename, xjoin |
|
from .utils.py_utils import glob_pattern_to_regex, string_to_dict |
|
|
|
|
|
SingleOriginMetadata = Union[tuple[str, str], tuple[str], tuple[()]] |
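# (repo_id, revision) for Hub datasets, (etag,) or (mtime,) for other files, or () when unavailable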
|
|
|
|
|
SANITIZED_DEFAULT_SPLIT = str(Split.TRAIN) |
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
class Url(str): |
|
pass |
|
|
|
|
|
class EmptyDatasetError(FileNotFoundError): |
|
pass |
|
|
|
|
|
SPLIT_PATTERN_SHARDED = "data/{split}-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*" |
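# e.g. for split="train" this matches "data/train-00000-of-00042.csv"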
|
|
|
SPLIT_KEYWORDS = { |
|
Split.TRAIN: ["train", "training"], |
|
Split.VALIDATION: ["validation", "valid", "dev", "val"], |
|
Split.TEST: ["test", "testing", "eval", "evaluation"], |
|
} |
|
NON_WORDS_CHARS = "-._ 0-9" |
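# characters that may separate or pad a split keyword in file and directory names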
|
if config.FSSPEC_VERSION < version.parse("2023.9.0"): |
|
KEYWORDS_IN_FILENAME_BASE_PATTERNS = ["**[{sep}/]{keyword}[{sep}]*", "{keyword}[{sep}]*"] |
|
KEYWORDS_IN_DIR_NAME_BASE_PATTERNS = [ |
|
"{keyword}/**", |
|
"{keyword}[{sep}]*/**", |
|
"**[{sep}/]{keyword}/**", |
|
"**[{sep}/]{keyword}[{sep}]*/**", |
|
] |
|
elif config.FSSPEC_VERSION < version.parse("2023.12.0"): |
|
KEYWORDS_IN_FILENAME_BASE_PATTERNS = ["**/*[{sep}/]{keyword}[{sep}]*", "{keyword}[{sep}]*"] |
|
KEYWORDS_IN_DIR_NAME_BASE_PATTERNS = [ |
|
"{keyword}/**/*", |
|
"{keyword}[{sep}]*/**/*", |
|
"**/*[{sep}/]{keyword}/**/*", |
|
"**/*[{sep}/]{keyword}[{sep}]*/**/*", |
|
] |
|
else: |
|
KEYWORDS_IN_FILENAME_BASE_PATTERNS = ["**/{keyword}[{sep}]*", "**/*[{sep}]{keyword}[{sep}]*"] |
|
KEYWORDS_IN_DIR_NAME_BASE_PATTERNS = [ |
|
"**/{keyword}/**", |
|
"**/{keyword}[{sep}]*/**", |
|
"**/*[{sep}]{keyword}/**", |
|
"**/*[{sep}]{keyword}[{sep}]*/**", |
|
] |
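# E.g. with keyword="train" and sep=NON_WORDS_CHARS, the filename pattern "**/*[{sep}]{keyword}[{sep}]*"
# becomes "**/*[-._ 0-9]train[-._ 0-9]*", which matches files like "data/my_train_file-001.csv".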
|
|
|
DEFAULT_SPLITS = [Split.TRAIN, Split.VALIDATION, Split.TEST] |
|
DEFAULT_PATTERNS_SPLIT_IN_FILENAME = { |
|
split: [ |
|
pattern.format(keyword=keyword, sep=NON_WORDS_CHARS) |
|
for keyword in SPLIT_KEYWORDS[split] |
|
for pattern in KEYWORDS_IN_FILENAME_BASE_PATTERNS |
|
] |
|
for split in DEFAULT_SPLITS |
|
} |
|
DEFAULT_PATTERNS_SPLIT_IN_DIR_NAME = { |
|
split: [ |
|
pattern.format(keyword=keyword, sep=NON_WORDS_CHARS) |
|
for keyword in SPLIT_KEYWORDS[split] |
|
for pattern in KEYWORDS_IN_DIR_NAME_BASE_PATTERNS |
|
] |
|
for split in DEFAULT_SPLITS |
|
} |
|
|
|
|
|
DEFAULT_PATTERNS_ALL = { |
|
Split.TRAIN: ["**"], |
|
} |
|
|
|
ALL_SPLIT_PATTERNS = [SPLIT_PATTERN_SHARDED] |
|
ALL_DEFAULT_PATTERNS = [ |
|
DEFAULT_PATTERNS_SPLIT_IN_DIR_NAME, |
|
DEFAULT_PATTERNS_SPLIT_IN_FILENAME, |
|
DEFAULT_PATTERNS_ALL, |
|
] |
|
WILDCARD_CHARACTERS = "*[]" |
|
FILES_TO_IGNORE = [ |
|
"README.md", |
|
"config.json", |
|
"dataset_info.json", |
|
"dataset_infos.json", |
|
"dummy_data.zip", |
|
"dataset_dict.json", |
|
] |
|
|
|
|
|
def contains_wildcards(pattern: str) -> bool: |
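    """
    Return whether the pattern contains any of the glob wildcard characters "*", "[" or "]".

    >>> contains_wildcards("data/*.csv")
    True
    >>> contains_wildcards("data/train.csv")
    False
    """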
|
return any(wildcard_character in pattern for wildcard_character in WILDCARD_CHARACTERS) |
|
|
|
|
|
def sanitize_patterns(patterns: Union[dict, list, str]) -> dict[str, Union[list[str], "DataFilesList"]]: |
|
""" |
|
Take the data_files patterns from the user, and format them into a dictionary. |
|
Each key is the name of the split, and each value is a list of data files patterns (paths or urls). |
|
The default split is "train". |
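
    Example::

        >>> sanitize_patterns("data/*.csv")
        {'train': ['data/*.csv']}
        >>> sanitize_patterns(["train.csv", "test.csv"])
        {'train': ['train.csv', 'test.csv']}
        >>> sanitize_patterns([{"split": "train", "path": "train.csv"}, {"split": "test", "path": "test.csv"}])
        {'train': ['train.csv'], 'test': ['test.csv']}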
|
|
|
Returns: |
|
patterns: dictionary of split_name -> list of patterns |
|
""" |
|
if isinstance(patterns, dict): |
|
return {str(key): value if isinstance(value, list) else [value] for key, value in patterns.items()} |
|
elif isinstance(patterns, str): |
|
return {SANITIZED_DEFAULT_SPLIT: [patterns]} |
|
elif isinstance(patterns, list): |
|
if any(isinstance(pattern, dict) for pattern in patterns): |
|
for pattern in patterns: |
|
if not ( |
|
isinstance(pattern, dict) |
|
and len(pattern) == 2 |
|
and "split" in pattern |
|
and isinstance(pattern.get("path"), (str, list)) |
|
): |
|
raise ValueError( |
|
f"Expected each split to have a 'path' key which can be a string or a list of strings, but got {pattern}" |
|
) |
|
splits = [pattern["split"] for pattern in patterns] |
|
if len(set(splits)) != len(splits): |
|
raise ValueError(f"Some splits are duplicated in data_files: {splits}") |
|
return { |
|
str(pattern["split"]): pattern["path"] if isinstance(pattern["path"], list) else [pattern["path"]] |
|
for pattern in patterns |
|
} |
|
else: |
|
return {SANITIZED_DEFAULT_SPLIT: patterns} |
|
else: |
|
return sanitize_patterns(list(patterns)) |
|
|
|
|
|
def _is_inside_unrequested_special_dir(matched_rel_path: str, pattern: str) -> bool: |
|
""" |
|
When a path matches a pattern, we additionally check if it's inside a special directory |
|
we ignore by default (if it starts with a double underscore). |
|
|
|
Users can still explicitly request a filepath inside such a directory if "__pycache__" is |
|
mentioned explicitly in the requested pattern. |
|
|
|
Some examples: |
|
|
|
base directory: |
|
|
|
        ./
        └── __pycache__
            └── b.txt
|
|
|
>>> _is_inside_unrequested_special_dir("__pycache__/b.txt", "**") |
|
True |
|
>>> _is_inside_unrequested_special_dir("__pycache__/b.txt", "*/b.txt") |
|
True |
|
>>> _is_inside_unrequested_special_dir("__pycache__/b.txt", "__pycache__/*") |
|
False |
|
>>> _is_inside_unrequested_special_dir("__pycache__/b.txt", "__*/*") |
|
False |
|
""" |
|
|
|
|
|
|
|
data_dirs_to_ignore_in_path = [part for part in PurePath(matched_rel_path).parent.parts if part.startswith("__")] |
|
data_dirs_to_ignore_in_pattern = [part for part in PurePath(pattern).parent.parts if part.startswith("__")] |
|
return len(data_dirs_to_ignore_in_path) != len(data_dirs_to_ignore_in_pattern) |
|
|
|
|
|
def _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(matched_rel_path: str, pattern: str) -> bool: |
|
""" |
|
When a path matches a pattern, we additionally check if it's a hidden file or if it's inside |
|
a hidden directory we ignore by default, i.e. if the file name or a parent directory name starts with a dot. |
|
|
|
Users can still explicitly request a filepath that is hidden or is inside a hidden directory |
|
if the hidden part is mentioned explicitly in the requested pattern. |
|
|
|
Some examples: |
|
|
|
base directory: |
|
|
|
        ./
        └── .hidden_file.txt
|
|
|
>>> _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(".hidden_file.txt", "**") |
|
True |
|
>>> _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(".hidden_file.txt", ".*") |
|
False |
|
|
|
base directory: |
|
|
|
        ./
        └── .hidden_dir
            └── a.txt
|
|
|
>>> _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(".hidden_dir/a.txt", "**") |
|
True |
|
>>> _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(".hidden_dir/a.txt", ".*/*") |
|
False |
|
>>> _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(".hidden_dir/a.txt", ".hidden_dir/*") |
|
False |
|
|
|
base directory: |
|
|
|
        ./
        └── .hidden_dir
            └── .hidden_file.txt
|
|
|
>>> _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(".hidden_dir/.hidden_file.txt", "**") |
|
True |
|
>>> _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(".hidden_dir/.hidden_file.txt", ".*/*") |
|
True |
|
>>> _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(".hidden_dir/.hidden_file.txt", ".*/.*") |
|
False |
|
>>> _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(".hidden_dir/.hidden_file.txt", ".hidden_dir/*") |
|
True |
|
>>> _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(".hidden_dir/.hidden_file.txt", ".hidden_dir/.*") |
|
False |
|
""" |
|
|
|
|
|
|
|
hidden_directories_in_path = [ |
|
part for part in PurePath(matched_rel_path).parts if part.startswith(".") and not set(part) == {"."} |
|
] |
|
hidden_directories_in_pattern = [ |
|
part for part in PurePath(pattern).parts if part.startswith(".") and not set(part) == {"."} |
|
] |
|
return len(hidden_directories_in_path) != len(hidden_directories_in_pattern) |
|
|
|
|
|
def _get_data_files_patterns(pattern_resolver: Callable[[str], list[str]]) -> dict[str, list[str]]: |
|
""" |
|
Get the default pattern from a directory or repository by testing all the supported patterns. |
|
    The first pattern to return a non-empty list of data files is returned.
|
|
|
In order, it first tests if SPLIT_PATTERN_SHARDED works, otherwise it tests the patterns in ALL_DEFAULT_PATTERNS. |
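
    For example, if the directory contains "data/train-00000-of-00002.csv" and "data/test-00000-of-00001.csv",
    the sharded pattern matches first and the result is
    {"train": ["data/train-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*"],
    "test": ["data/test-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*"]}.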
|
""" |
|
|
|
for split_pattern in ALL_SPLIT_PATTERNS: |
|
pattern = split_pattern.replace("{split}", "*") |
|
try: |
|
data_files = pattern_resolver(pattern) |
|
except FileNotFoundError: |
|
continue |
|
if len(data_files) > 0: |
|
splits: set[str] = set() |
|
for p in data_files: |
|
p_parts = string_to_dict(xbasename(p), glob_pattern_to_regex(xbasename(split_pattern))) |
|
assert p_parts is not None |
|
splits.add(p_parts["split"]) |
|
|
|
if any(not re.match(_split_re, split) for split in splits): |
|
raise ValueError(f"Split name should match '{_split_re}'' but got '{splits}'.") |
|
sorted_splits = [str(split) for split in DEFAULT_SPLITS if split in splits] + sorted( |
|
splits - {str(split) for split in DEFAULT_SPLITS} |
|
) |
|
return {split: [split_pattern.format(split=split)] for split in sorted_splits} |
|
|
|
for patterns_dict in ALL_DEFAULT_PATTERNS: |
|
non_empty_splits = [] |
|
for split, patterns in patterns_dict.items(): |
|
for pattern in patterns: |
|
try: |
|
data_files = pattern_resolver(pattern) |
|
except FileNotFoundError: |
|
continue |
|
if len(data_files) > 0: |
|
non_empty_splits.append(split) |
|
break |
|
if non_empty_splits: |
|
return {split: patterns_dict[split] for split in non_empty_splits} |
|
raise FileNotFoundError(f"Couldn't resolve pattern {pattern} with resolver {pattern_resolver}") |
|
|
|
|
|
def resolve_pattern( |
|
pattern: str, |
|
base_path: str, |
|
allowed_extensions: Optional[list[str]] = None, |
|
download_config: Optional[DownloadConfig] = None, |
|
) -> list[str]: |
|
""" |
|
Resolve the paths and URLs of the data files from the pattern passed by the user. |
|
|
|
You can use patterns to resolve multiple local files. Here are a few examples: |
|
- *.csv to match all the CSV files at the first level |
|
- **.csv to match all the CSV files at any level |
|
- data/* to match all the files inside "data" |
|
- data/** to match all the files inside "data" and its subdirectories |
|
|
|
The patterns are resolved using the fsspec glob. In fsspec>=2023.12.0 this is equivalent to |
|
Python's glob.glob, Path.glob, Path.match and fnmatch where ** is unsupported with a prefix/suffix |
|
other than a forward slash /. |
|
|
|
More generally: |
|
- '*' matches any character except a forward-slash (to match just the file or directory name) |
|
- '**' matches any character including a forward-slash / |
|
|
|
Hidden files and directories (i.e. whose names start with a dot) are ignored, unless they are explicitly requested. |
|
The same applies to special directories that start with a double underscore like "__pycache__". |
|
You can still include one if the pattern explicitly mentions it: |
|
- to include a hidden file: "*/.hidden.txt" or "*/.*" |
|
- to include a hidden directory: ".hidden/*" or ".*/*" |
|
- to include a special directory: "__special__/*" or "__*/*" |
|
|
|
Example:: |
|
|
|
>>> from datasets.data_files import resolve_pattern |
|
>>> base_path = "." |
|
>>> resolve_pattern("docs/**/*.py", base_path) |
|
        ['/Users/mariosasko/Desktop/projects/datasets/docs/source/_config.py']
|
|
|
Args: |
|
pattern (str): Unix pattern or paths or URLs of the data files to resolve. |
|
The paths can be absolute or relative to base_path. |
|
Remote filesystems using fsspec are supported, e.g. with the hf:// protocol. |
|
base_path (str): Base path to use when resolving relative paths. |
|
allowed_extensions (Optional[list], optional): White-list of file extensions to use. Defaults to None (all extensions). |
|
For example: allowed_extensions=[".csv", ".json", ".txt", ".parquet"] |
|
download_config ([`DownloadConfig`], *optional*): Specific download configuration parameters. |
|
Returns: |
|
List[str]: List of paths or URLs to the local or remote files that match the patterns. |
|
""" |
|
if is_relative_path(pattern): |
|
pattern = xjoin(base_path, pattern) |
|
elif is_local_path(pattern): |
|
base_path = os.path.splitdrive(pattern)[0] + os.sep |
|
else: |
|
base_path = "" |
|
pattern, storage_options = _prepare_path_and_storage_options(pattern, download_config=download_config) |
|
fs, fs_pattern = url_to_fs(pattern, **storage_options) |
|
files_to_ignore = set(FILES_TO_IGNORE) - {xbasename(pattern)} |
|
protocol = fs.protocol if isinstance(fs.protocol, str) else fs.protocol[0] |
|
protocol_prefix = protocol + "://" if protocol != "file" else "" |
|
glob_kwargs = {} |
|
if protocol == "hf" and config.HF_HUB_VERSION >= version.parse("0.20.0"): |
|
|
|
glob_kwargs["expand_info"] = False |
|
matched_paths = [ |
|
filepath if filepath.startswith(protocol_prefix) else protocol_prefix + filepath |
|
for filepath, info in fs.glob(pattern, detail=True, **glob_kwargs).items() |
|
if (info["type"] == "file" or (info.get("islink") and os.path.isfile(os.path.realpath(filepath)))) |
|
and (xbasename(filepath) not in files_to_ignore) |
|
and not _is_inside_unrequested_special_dir(filepath, fs_pattern) |
|
and not _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(filepath, fs_pattern) |
|
] |
|
if allowed_extensions is not None: |
|
out = [ |
|
filepath |
|
for filepath in matched_paths |
|
if any("." + suffix in allowed_extensions for suffix in xbasename(filepath).split(".")[1:]) |
|
] |
|
if len(out) < len(matched_paths): |
|
invalid_matched_files = list(set(matched_paths) - set(out)) |
|
logger.info( |
|
f"Some files matched the pattern '{pattern}' but don't have valid data file extensions: {invalid_matched_files}" |
|
) |
|
else: |
|
out = matched_paths |
|
if not out: |
|
error_msg = f"Unable to find '{pattern}'" |
|
if allowed_extensions is not None: |
|
error_msg += f" with any supported extension {list(allowed_extensions)}" |
|
raise FileNotFoundError(error_msg) |
|
return out |
|
|
|
|
|
def get_data_patterns(base_path: str, download_config: Optional[DownloadConfig] = None) -> dict[str, list[str]]: |
|
""" |
|
Get the default pattern from a directory testing all the supported patterns. |
|
    The first pattern to return a non-empty list of data files is returned.
|
|
|
Some examples of supported patterns: |
|
|
|
Input: |
|
|
|
        my_dataset_repository/
        ├── README.md
        └── dataset.csv
|
|
|
Output: |
|
|
|
{'train': ['**']} |
|
|
|
Input: |
|
|
|
        my_dataset_repository/
        ├── README.md
        ├── train.csv
        └── test.csv
|
|
|
        my_dataset_repository/
        ├── README.md
        └── data/
            ├── train.csv
            └── test.csv
|
|
|
        my_dataset_repository/
        ├── README.md
        ├── train_0.csv
        ├── train_1.csv
        ├── train_2.csv
        ├── train_3.csv
        ├── test_0.csv
        └── test_1.csv
|
|
|
Output: |
|
|
|
{'train': ['**/train[-._ 0-9]*', '**/*[-._ 0-9]train[-._ 0-9]*', '**/training[-._ 0-9]*', '**/*[-._ 0-9]training[-._ 0-9]*'], |
|
'test': ['**/test[-._ 0-9]*', '**/*[-._ 0-9]test[-._ 0-9]*', '**/testing[-._ 0-9]*', '**/*[-._ 0-9]testing[-._ 0-9]*', ...]} |
|
|
|
Input: |
|
|
|
        my_dataset_repository/
        ├── README.md
        └── data/
            ├── train/
            │   ├── shard_0.csv
            │   ├── shard_1.csv
            │   ├── shard_2.csv
            │   └── shard_3.csv
            └── test/
                ├── shard_0.csv
                └── shard_1.csv
|
|
|
Output: |
|
|
|
{'train': ['**/train/**', '**/train[-._ 0-9]*/**', '**/*[-._ 0-9]train/**', '**/*[-._ 0-9]train[-._ 0-9]*/**', ...], |
|
'test': ['**/test/**', '**/test[-._ 0-9]*/**', '**/*[-._ 0-9]test/**', '**/*[-._ 0-9]test[-._ 0-9]*/**', ...]} |
|
|
|
Input: |
|
|
|
        my_dataset_repository/
        ├── README.md
        └── data/
            ├── train-00000-of-00003.csv
            ├── train-00001-of-00003.csv
            ├── train-00002-of-00003.csv
            ├── test-00000-of-00001.csv
            ├── random-00000-of-00003.csv
            ├── random-00001-of-00003.csv
            └── random-00002-of-00003.csv
|
|
|
Output: |
|
|
|
{'train': ['data/train-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*'], |
|
'test': ['data/test-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*'], |
|
'random': ['data/random-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*']} |
|
|
|
In order, it first tests if SPLIT_PATTERN_SHARDED works, otherwise it tests the patterns in ALL_DEFAULT_PATTERNS. |
|
""" |
|
resolver = partial(resolve_pattern, base_path=base_path, download_config=download_config) |
|
try: |
|
return _get_data_files_patterns(resolver) |
|
except FileNotFoundError: |
|
raise EmptyDatasetError(f"The directory at {base_path} doesn't contain any data files") from None |
|
|
|
|
|
def _get_single_origin_metadata( |
|
data_file: str, |
|
download_config: Optional[DownloadConfig] = None, |
|
) -> SingleOriginMetadata: |
|
data_file, storage_options = _prepare_path_and_storage_options(data_file, download_config=download_config) |
|
fs, *_ = url_to_fs(data_file, **storage_options) |
|
if isinstance(fs, HfFileSystem): |
|
resolved_path = fs.resolve_path(data_file) |
|
return resolved_path.repo_id, resolved_path.revision |
|
elif data_file.startswith(config.HF_ENDPOINT): |
|
        hffs = HfFileSystem(endpoint=config.HF_ENDPOINT, token=download_config.token if download_config else None)
|
data_file = "hf://" + data_file[len(config.HF_ENDPOINT) + 1 :].replace("/resolve/", "@", 1) |
|
resolved_path = hffs.resolve_path(data_file) |
|
return resolved_path.repo_id, resolved_path.revision |
|
    info = fs.info(data_file)
    # s3fs uses "ETag", gcsfs uses "etag", and for local paths we fall back to the file's "mtime"
    for key in ["ETag", "etag", "mtime"]:
|
if key in info: |
|
return (str(info[key]),) |
|
return () |
|
|
|
|
|
def _get_origin_metadata( |
|
data_files: list[str], |
|
download_config: Optional[DownloadConfig] = None, |
|
max_workers: Optional[int] = None, |
|
) -> list[SingleOriginMetadata]: |
|
max_workers = max_workers if max_workers is not None else config.HF_DATASETS_MULTITHREADING_MAX_WORKERS |
|
return thread_map( |
|
partial(_get_single_origin_metadata, download_config=download_config), |
|
data_files, |
|
max_workers=max_workers, |
|
tqdm_class=hf_tqdm, |
|
desc="Resolving data files", |
|
|
|
        # disable the progress bar for short lists; otherwise pass None so tqdm only shows it when a TTY is available
        disable=len(data_files) <= 16 or None,
|
) |
|
|
|
|
|
class DataFilesList(list[str]): |
|
""" |
|
List of data files (absolute local paths or URLs). |
|
It has two construction methods given the user's data files patterns: |
|
- ``from_hf_repo``: resolve patterns inside a dataset repository |
|
- ``from_local_or_remote``: resolve patterns from a local path |
|
|
|
Moreover, DataFilesList has an additional attribute ``origin_metadata``. |
|
It can store: |
|
- the last modified time of local files |
|
- ETag of remote files |
|
- commit sha of a dataset repository |
|
|
|
Thanks to this additional attribute, it is possible to hash the list |
|
and get a different hash if and only if at least one file changed. |
|
This is useful for caching Dataset objects that are obtained from a list of data files. |
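
    Example::

        >>> from datasets.data_files import DataFilesList
        >>> # resolve CSV shards under a hypothetical local "data/" directory
        >>> data_files = DataFilesList.from_patterns(["data/*.csv"], base_path=".")  # doctest: +SKIP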
|
""" |
|
|
|
def __init__(self, data_files: list[str], origin_metadata: list[SingleOriginMetadata]) -> None: |
|
super().__init__(data_files) |
|
self.origin_metadata = origin_metadata |
|
|
|
def __add__(self, other: "DataFilesList") -> "DataFilesList": |
|
return DataFilesList([*self, *other], self.origin_metadata + other.origin_metadata) |
|
|
|
@classmethod |
|
def from_hf_repo( |
|
cls, |
|
patterns: list[str], |
|
dataset_info: huggingface_hub.hf_api.DatasetInfo, |
|
base_path: Optional[str] = None, |
|
allowed_extensions: Optional[list[str]] = None, |
|
download_config: Optional[DownloadConfig] = None, |
|
) -> "DataFilesList": |
|
base_path = f"hf://datasets/{dataset_info.id}@{dataset_info.sha}/{base_path or ''}".rstrip("/") |
|
return cls.from_patterns( |
|
patterns, base_path=base_path, allowed_extensions=allowed_extensions, download_config=download_config |
|
) |
|
|
|
@classmethod |
|
def from_local_or_remote( |
|
cls, |
|
patterns: list[str], |
|
base_path: Optional[str] = None, |
|
allowed_extensions: Optional[list[str]] = None, |
|
download_config: Optional[DownloadConfig] = None, |
|
) -> "DataFilesList": |
|
base_path = base_path if base_path is not None else Path().resolve().as_posix() |
|
return cls.from_patterns( |
|
patterns, base_path=base_path, allowed_extensions=allowed_extensions, download_config=download_config |
|
) |
|
|
|
@classmethod |
|
def from_patterns( |
|
cls, |
|
patterns: list[str], |
|
base_path: Optional[str] = None, |
|
allowed_extensions: Optional[list[str]] = None, |
|
download_config: Optional[DownloadConfig] = None, |
|
) -> "DataFilesList": |
|
base_path = base_path if base_path is not None else Path().resolve().as_posix() |
|
data_files = [] |
|
for pattern in patterns: |
|
try: |
|
data_files.extend( |
|
resolve_pattern( |
|
pattern, |
|
base_path=base_path, |
|
allowed_extensions=allowed_extensions, |
|
download_config=download_config, |
|
) |
|
) |
|
except FileNotFoundError: |
|
if not has_magic(pattern): |
|
raise |
|
origin_metadata = _get_origin_metadata(data_files, download_config=download_config) |
|
return cls(data_files, origin_metadata) |
|
|
|
def filter( |
|
self, *, extensions: Optional[list[str]] = None, file_names: Optional[list[str]] = None |
|
) -> "DataFilesList": |
|
patterns = [] |
|
if extensions: |
|
ext_pattern = "|".join(re.escape(ext) for ext in extensions) |
|
patterns.append(re.compile(f".*({ext_pattern})(\\..+)?$")) |
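            # e.g. extensions=[".csv"] keeps "data.csv" and also compressed variants like "data.csv.gz"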
|
if file_names: |
|
fn_pattern = "|".join(re.escape(fn) for fn in file_names) |
|
patterns.append(re.compile(rf".*[\/]?({fn_pattern})$")) |
|
if patterns: |
|
return DataFilesList( |
|
[data_file for data_file in self if any(pattern.match(data_file) for pattern in patterns)], |
|
origin_metadata=self.origin_metadata, |
|
) |
|
else: |
|
return DataFilesList(list(self), origin_metadata=self.origin_metadata) |
|
|
|
|
|
class DataFilesDict(dict[str, DataFilesList]): |
|
""" |
|
Dict of split_name -> list of data files (absolute local paths or URLs). |
|
    It has two construction methods given the user's data files patterns:
|
- ``from_hf_repo``: resolve patterns inside a dataset repository |
|
- ``from_local_or_remote``: resolve patterns from a local path |
|
|
|
Moreover, each list is a DataFilesList. It is possible to hash the dictionary |
|
and get a different hash if and only if at least one file changed. |
|
For more info, see [`DataFilesList`]. |
|
|
|
This is useful for caching Dataset objects that are obtained from a list of data files. |
|
|
|
Changing the order of the keys of this dictionary also doesn't change its hash. |
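
    Example::

        >>> from datasets.data_files import DataFilesDict
        >>> # assuming train/test CSV shards under a hypothetical "data/" directory
        >>> data_files = DataFilesDict.from_patterns(
        ...     {"train": ["data/train-*.csv"], "test": ["data/test-*.csv"]}
        ... )  # doctest: +SKIP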
|
""" |
|
|
|
@classmethod |
|
def from_local_or_remote( |
|
cls, |
|
patterns: dict[str, Union[list[str], DataFilesList]], |
|
base_path: Optional[str] = None, |
|
allowed_extensions: Optional[list[str]] = None, |
|
download_config: Optional[DownloadConfig] = None, |
|
) -> "DataFilesDict": |
|
out = cls() |
|
for key, patterns_for_key in patterns.items(): |
|
out[key] = ( |
|
patterns_for_key |
|
if isinstance(patterns_for_key, DataFilesList) |
|
else DataFilesList.from_local_or_remote( |
|
patterns_for_key, |
|
base_path=base_path, |
|
allowed_extensions=allowed_extensions, |
|
download_config=download_config, |
|
) |
|
) |
|
return out |
|
|
|
@classmethod |
|
def from_hf_repo( |
|
cls, |
|
patterns: dict[str, Union[list[str], DataFilesList]], |
|
dataset_info: huggingface_hub.hf_api.DatasetInfo, |
|
base_path: Optional[str] = None, |
|
allowed_extensions: Optional[list[str]] = None, |
|
download_config: Optional[DownloadConfig] = None, |
|
) -> "DataFilesDict": |
|
out = cls() |
|
for key, patterns_for_key in patterns.items(): |
|
out[key] = ( |
|
patterns_for_key |
|
if isinstance(patterns_for_key, DataFilesList) |
|
else DataFilesList.from_hf_repo( |
|
patterns_for_key, |
|
dataset_info=dataset_info, |
|
base_path=base_path, |
|
allowed_extensions=allowed_extensions, |
|
download_config=download_config, |
|
) |
|
) |
|
return out |
|
|
|
@classmethod |
|
def from_patterns( |
|
cls, |
|
patterns: dict[str, Union[list[str], DataFilesList]], |
|
base_path: Optional[str] = None, |
|
allowed_extensions: Optional[list[str]] = None, |
|
download_config: Optional[DownloadConfig] = None, |
|
) -> "DataFilesDict": |
|
out = cls() |
|
for key, patterns_for_key in patterns.items(): |
|
out[key] = ( |
|
patterns_for_key |
|
if isinstance(patterns_for_key, DataFilesList) |
|
else DataFilesList.from_patterns( |
|
patterns_for_key, |
|
base_path=base_path, |
|
allowed_extensions=allowed_extensions, |
|
download_config=download_config, |
|
) |
|
) |
|
return out |
|
|
|
def filter( |
|
self, *, extensions: Optional[list[str]] = None, file_names: Optional[list[str]] = None |
|
) -> "DataFilesDict": |
|
out = type(self)() |
|
for key, data_files_list in self.items(): |
|
out[key] = data_files_list.filter(extensions=extensions, file_names=file_names) |
|
return out |
|
|
|
|
|
class DataFilesPatternsList(list[str]): |
|
""" |
|
List of data files patterns (absolute local paths or URLs). |
|
    For each pattern there should also be a list of allowed extensions
    to keep, or None to keep all the files matching the pattern.
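
    Example::

        >>> from datasets.data_files import DataFilesPatternsList
        >>> # lazily declare the patterns, then resolve them against a hypothetical local "data/" directory
        >>> patterns = DataFilesPatternsList.from_patterns(["data/*"], allowed_extensions=[".csv"])
        >>> data_files = patterns.resolve(base_path=".")  # doctest: +SKIP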
|
""" |
|
|
|
def __init__( |
|
self, |
|
patterns: list[str], |
|
allowed_extensions: list[Optional[list[str]]], |
|
): |
|
super().__init__(patterns) |
|
self.allowed_extensions = allowed_extensions |
|
|
|
    def __add__(self, other: "DataFilesPatternsList") -> "DataFilesPatternsList":
        return DataFilesPatternsList([*self, *other], self.allowed_extensions + other.allowed_extensions)
|
|
|
@classmethod |
|
def from_patterns( |
|
cls, patterns: list[str], allowed_extensions: Optional[list[str]] = None |
|
) -> "DataFilesPatternsList": |
|
return cls(patterns, [allowed_extensions] * len(patterns)) |
|
|
|
def resolve( |
|
self, |
|
base_path: str, |
|
download_config: Optional[DownloadConfig] = None, |
|
) -> "DataFilesList": |
|
base_path = base_path if base_path is not None else Path().resolve().as_posix() |
|
data_files = [] |
|
for pattern, allowed_extensions in zip(self, self.allowed_extensions): |
|
try: |
|
data_files.extend( |
|
resolve_pattern( |
|
pattern, |
|
base_path=base_path, |
|
allowed_extensions=allowed_extensions, |
|
download_config=download_config, |
|
) |
|
) |
|
except FileNotFoundError: |
|
if not has_magic(pattern): |
|
raise |
|
origin_metadata = _get_origin_metadata(data_files, download_config=download_config) |
|
return DataFilesList(data_files, origin_metadata) |
|
|
|
    def filter_extensions(self, extensions: list[str]) -> "DataFilesPatternsList":
        # treat a None entry (keep all files) as an empty allow-list before adding the new extensions
        return DataFilesPatternsList(
            self, [(allowed_extensions or []) + extensions for allowed_extensions in self.allowed_extensions]
        )
|
|
|
|
|
class DataFilesPatternsDict(dict[str, DataFilesPatternsList]): |
|
""" |
|
Dict of split_name -> list of data files patterns (absolute local paths or URLs). |
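
    Example::

        >>> from datasets.data_files import DataFilesPatternsDict
        >>> # assuming training shards under a hypothetical "data/train/" directory
        >>> patterns = DataFilesPatternsDict.from_patterns({"train": ["data/train/*"]})
        >>> data_files = patterns.resolve(base_path=".")  # doctest: +SKIP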
|
""" |
|
|
|
@classmethod |
|
def from_patterns( |
|
cls, patterns: dict[str, list[str]], allowed_extensions: Optional[list[str]] = None |
|
) -> "DataFilesPatternsDict": |
|
out = cls() |
|
for key, patterns_for_key in patterns.items(): |
|
out[key] = ( |
|
patterns_for_key |
|
if isinstance(patterns_for_key, DataFilesPatternsList) |
|
else DataFilesPatternsList.from_patterns( |
|
patterns_for_key, |
|
allowed_extensions=allowed_extensions, |
|
) |
|
) |
|
return out |
|
|
|
def resolve( |
|
self, |
|
base_path: str, |
|
download_config: Optional[DownloadConfig] = None, |
|
) -> "DataFilesDict": |
|
out = DataFilesDict() |
|
for key, data_files_patterns_list in self.items(): |
|
out[key] = data_files_patterns_list.resolve(base_path, download_config) |
|
return out |
|
|
|
def filter_extensions(self, extensions: list[str]) -> "DataFilesPatternsDict": |
|
out = type(self)() |
|
for key, data_files_patterns_list in self.items(): |
|
out[key] = data_files_patterns_list.filter_extensions(extensions) |
|
return out |
|
|