|
import logging |
|
import os |
|
from argparse import ArgumentParser |
|
from collections.abc import Generator |
|
from pathlib import Path |
|
from shutil import copyfile, rmtree |
|
from typing import Optional |
|
|
|
import datasets.config |
|
from datasets.builder import DatasetBuilder |
|
from datasets.commands import BaseDatasetsCLICommand |
|
from datasets.download.download_manager import DownloadMode |
|
from datasets.load import dataset_module_factory, import_main_class |
|
from datasets.utils.info_utils import VerificationMode |
|
from datasets.utils.logging import ERROR, get_logger |
|
|
|
|
|
logger = get_logger(__name__) |
|
|
|
|
|
def _test_command_factory(args): |
|
return TestCommand( |
|
args.dataset, |
|
args.name, |
|
args.cache_dir, |
|
args.data_dir, |
|
args.all_configs, |
|
args.save_info or args.save_infos, |
|
args.ignore_verifications, |
|
args.force_redownload, |
|
args.clear_cache, |
|
args.num_proc, |
|
args.trust_remote_code, |
|
) |
|
|
|
|
|
class TestCommand(BaseDatasetsCLICommand): |
|
__test__ = False |
|
|
|
@staticmethod |
|
def register_subcommand(parser: ArgumentParser): |
|
test_parser = parser.add_parser("test", help="Test dataset implementation.") |
|
test_parser.add_argument("--name", type=str, default=None, help="Dataset processing name") |
|
test_parser.add_argument( |
|
"--cache_dir", |
|
type=str, |
|
default=None, |
|
help="Cache directory where the datasets are stored.", |
|
) |
|
test_parser.add_argument( |
|
"--data_dir", |
|
type=str, |
|
default=None, |
|
help="Can be used to specify a manual directory to get the files from.", |
|
) |
|
test_parser.add_argument("--all_configs", action="store_true", help="Test all dataset configurations") |
|
test_parser.add_argument( |
|
"--save_info", action="store_true", help="Save the dataset infos in the dataset card (README.md)" |
|
) |
|
test_parser.add_argument( |
|
"--ignore_verifications", |
|
action="store_true", |
|
help="Run the test without checksums and splits checks.", |
|
) |
|
test_parser.add_argument("--force_redownload", action="store_true", help="Force dataset redownload") |
|
test_parser.add_argument( |
|
"--clear_cache", |
|
action="store_true", |
|
help="Remove downloaded files and cached datasets after each config test", |
|
) |
|
test_parser.add_argument("--num_proc", type=int, default=None, help="Number of processes") |
|
test_parser.add_argument( |
|
"--trust_remote_code", action="store_true", help="whether to trust the code execution of the load script" |
|
) |
|
|
|
test_parser.add_argument("--save_infos", action="store_true", help="alias to save_info") |
|
test_parser.add_argument("dataset", type=str, help="Name of the dataset to download") |
|
test_parser.set_defaults(func=_test_command_factory) |
|
|
|
def __init__( |
|
self, |
|
dataset: str, |
|
name: str, |
|
cache_dir: str, |
|
data_dir: str, |
|
all_configs: bool, |
|
save_infos: bool, |
|
ignore_verifications: bool, |
|
force_redownload: bool, |
|
clear_cache: bool, |
|
num_proc: int, |
|
trust_remote_code: Optional[bool], |
|
): |
|
self._dataset = dataset |
|
self._name = name |
|
self._cache_dir = cache_dir |
|
self._data_dir = data_dir |
|
self._all_configs = all_configs |
|
self._save_infos = save_infos |
|
self._ignore_verifications = ignore_verifications |
|
self._force_redownload = force_redownload |
|
self._clear_cache = clear_cache |
|
self._num_proc = num_proc |
|
self._trust_remote_code = trust_remote_code |
|
if clear_cache and not cache_dir: |
|
print( |
|
"When --clear_cache is used, specifying a cache directory is mandatory.\n" |
|
"The 'download' folder of the cache directory and the dataset builder cache will be deleted after each configuration test.\n" |
|
"Please provide a --cache_dir that will be used to test the dataset script." |
|
) |
|
exit(1) |
|
if save_infos: |
|
self._ignore_verifications = True |
|
|
|
def run(self): |
|
logging.getLogger("filelock").setLevel(ERROR) |
|
if self._name is not None and self._all_configs: |
|
print("Both parameters `config` and `all_configs` can't be used at once.") |
|
exit(1) |
|
path, config_name = self._dataset, self._name |
|
module = dataset_module_factory(path, trust_remote_code=self._trust_remote_code) |
|
builder_cls = import_main_class(module.module_path) |
|
n_builders = len(builder_cls.BUILDER_CONFIGS) if self._all_configs and builder_cls.BUILDER_CONFIGS else 1 |
|
|
|
def get_builders() -> Generator[DatasetBuilder, None, None]: |
|
if self._all_configs and builder_cls.BUILDER_CONFIGS: |
|
for i, config in enumerate(builder_cls.BUILDER_CONFIGS): |
|
if "config_name" in module.builder_kwargs: |
|
yield builder_cls( |
|
cache_dir=self._cache_dir, |
|
data_dir=self._data_dir, |
|
**module.builder_kwargs, |
|
) |
|
else: |
|
yield builder_cls( |
|
config_name=config.name, |
|
cache_dir=self._cache_dir, |
|
data_dir=self._data_dir, |
|
**module.builder_kwargs, |
|
) |
|
else: |
|
if "config_name" in module.builder_kwargs: |
|
yield builder_cls(cache_dir=self._cache_dir, data_dir=self._data_dir, **module.builder_kwargs) |
|
else: |
|
yield builder_cls( |
|
config_name=config_name, |
|
cache_dir=self._cache_dir, |
|
data_dir=self._data_dir, |
|
**module.builder_kwargs, |
|
) |
|
|
|
for j, builder in enumerate(get_builders()): |
|
print(f"Testing builder '{builder.config.name}' ({j + 1}/{n_builders})") |
|
builder._record_infos = os.path.exists( |
|
os.path.join(builder.get_imported_module_dir(), datasets.config.DATASETDICT_INFOS_FILENAME) |
|
) |
|
builder.download_and_prepare( |
|
download_mode=DownloadMode.REUSE_CACHE_IF_EXISTS |
|
if not self._force_redownload |
|
else DownloadMode.FORCE_REDOWNLOAD, |
|
verification_mode=VerificationMode.NO_CHECKS |
|
if self._ignore_verifications |
|
else VerificationMode.ALL_CHECKS, |
|
num_proc=self._num_proc, |
|
) |
|
builder.as_dataset() |
|
if self._save_infos: |
|
builder._save_infos() |
|
|
|
|
|
|
|
|
|
|
|
|
|
if self._save_infos: |
|
dataset_readme_path = os.path.join( |
|
builder_cls.get_imported_module_dir(), datasets.config.REPOCARD_FILENAME |
|
) |
|
name = Path(path).name + ".py" |
|
combined_path = os.path.join(path, name) |
|
if os.path.isfile(path): |
|
dataset_dir = os.path.dirname(path) |
|
elif os.path.isfile(combined_path): |
|
dataset_dir = path |
|
elif os.path.isdir(path): |
|
dataset_dir = path |
|
else: |
|
dataset_dir = None |
|
print(f"Dataset card saved at {dataset_readme_path}") |
|
|
|
|
|
if dataset_dir is not None: |
|
user_dataset_readme_path = os.path.join(dataset_dir, datasets.config.REPOCARD_FILENAME) |
|
copyfile(dataset_readme_path, user_dataset_readme_path) |
|
print(f"Dataset card saved at {user_dataset_readme_path}") |
|
|
|
|
|
if self._clear_cache: |
|
if os.path.isdir(builder._cache_dir): |
|
logger.warning(f"Clearing cache at {builder._cache_dir}") |
|
rmtree(builder._cache_dir) |
|
download_dir = os.path.join(self._cache_dir, datasets.config.DOWNLOADED_DATASETS_DIR) |
|
if os.path.isdir(download_dir): |
|
logger.warning(f"Clearing cache at {download_dir}") |
|
rmtree(download_dir) |
|
|
|
print("Test successful.") |
|
|