File size: 9,115 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 |
import logging
import os
from argparse import ArgumentParser
from collections.abc import Generator
from pathlib import Path
from shutil import copyfile, rmtree
from typing import Optional
import datasets.config
from datasets.builder import DatasetBuilder
from datasets.commands import BaseDatasetsCLICommand
from datasets.download.download_manager import DownloadMode
from datasets.load import dataset_module_factory, import_main_class
from datasets.utils.info_utils import VerificationMode
from datasets.utils.logging import ERROR, get_logger
logger = get_logger(__name__)
def _test_command_factory(args):
return TestCommand(
args.dataset,
args.name,
args.cache_dir,
args.data_dir,
args.all_configs,
args.save_info or args.save_infos,
args.ignore_verifications,
args.force_redownload,
args.clear_cache,
args.num_proc,
args.trust_remote_code,
)
class TestCommand(BaseDatasetsCLICommand):
__test__ = False # to tell pytest it's not a test class
@staticmethod
def register_subcommand(parser: ArgumentParser):
test_parser = parser.add_parser("test", help="Test dataset implementation.")
test_parser.add_argument("--name", type=str, default=None, help="Dataset processing name")
test_parser.add_argument(
"--cache_dir",
type=str,
default=None,
help="Cache directory where the datasets are stored.",
)
test_parser.add_argument(
"--data_dir",
type=str,
default=None,
help="Can be used to specify a manual directory to get the files from.",
)
test_parser.add_argument("--all_configs", action="store_true", help="Test all dataset configurations")
test_parser.add_argument(
"--save_info", action="store_true", help="Save the dataset infos in the dataset card (README.md)"
)
test_parser.add_argument(
"--ignore_verifications",
action="store_true",
help="Run the test without checksums and splits checks.",
)
test_parser.add_argument("--force_redownload", action="store_true", help="Force dataset redownload")
test_parser.add_argument(
"--clear_cache",
action="store_true",
help="Remove downloaded files and cached datasets after each config test",
)
test_parser.add_argument("--num_proc", type=int, default=None, help="Number of processes")
test_parser.add_argument(
"--trust_remote_code", action="store_true", help="whether to trust the code execution of the load script"
)
# aliases
test_parser.add_argument("--save_infos", action="store_true", help="alias to save_info")
test_parser.add_argument("dataset", type=str, help="Name of the dataset to download")
test_parser.set_defaults(func=_test_command_factory)
def __init__(
self,
dataset: str,
name: str,
cache_dir: str,
data_dir: str,
all_configs: bool,
save_infos: bool,
ignore_verifications: bool,
force_redownload: bool,
clear_cache: bool,
num_proc: int,
trust_remote_code: Optional[bool],
):
self._dataset = dataset
self._name = name
self._cache_dir = cache_dir
self._data_dir = data_dir
self._all_configs = all_configs
self._save_infos = save_infos
self._ignore_verifications = ignore_verifications
self._force_redownload = force_redownload
self._clear_cache = clear_cache
self._num_proc = num_proc
self._trust_remote_code = trust_remote_code
if clear_cache and not cache_dir:
print(
"When --clear_cache is used, specifying a cache directory is mandatory.\n"
"The 'download' folder of the cache directory and the dataset builder cache will be deleted after each configuration test.\n"
"Please provide a --cache_dir that will be used to test the dataset script."
)
exit(1)
if save_infos:
self._ignore_verifications = True
def run(self):
logging.getLogger("filelock").setLevel(ERROR)
if self._name is not None and self._all_configs:
print("Both parameters `config` and `all_configs` can't be used at once.")
exit(1)
path, config_name = self._dataset, self._name
module = dataset_module_factory(path, trust_remote_code=self._trust_remote_code)
builder_cls = import_main_class(module.module_path)
n_builders = len(builder_cls.BUILDER_CONFIGS) if self._all_configs and builder_cls.BUILDER_CONFIGS else 1
def get_builders() -> Generator[DatasetBuilder, None, None]:
if self._all_configs and builder_cls.BUILDER_CONFIGS:
for i, config in enumerate(builder_cls.BUILDER_CONFIGS):
if "config_name" in module.builder_kwargs:
yield builder_cls(
cache_dir=self._cache_dir,
data_dir=self._data_dir,
**module.builder_kwargs,
)
else:
yield builder_cls(
config_name=config.name,
cache_dir=self._cache_dir,
data_dir=self._data_dir,
**module.builder_kwargs,
)
else:
if "config_name" in module.builder_kwargs:
yield builder_cls(cache_dir=self._cache_dir, data_dir=self._data_dir, **module.builder_kwargs)
else:
yield builder_cls(
config_name=config_name,
cache_dir=self._cache_dir,
data_dir=self._data_dir,
**module.builder_kwargs,
)
for j, builder in enumerate(get_builders()):
print(f"Testing builder '{builder.config.name}' ({j + 1}/{n_builders})")
builder._record_infos = os.path.exists(
os.path.join(builder.get_imported_module_dir(), datasets.config.DATASETDICT_INFOS_FILENAME)
) # record checksums only if we need to update a (deprecated) dataset_infos.json
builder.download_and_prepare(
download_mode=DownloadMode.REUSE_CACHE_IF_EXISTS
if not self._force_redownload
else DownloadMode.FORCE_REDOWNLOAD,
verification_mode=VerificationMode.NO_CHECKS
if self._ignore_verifications
else VerificationMode.ALL_CHECKS,
num_proc=self._num_proc,
)
builder.as_dataset()
if self._save_infos:
builder._save_infos()
# If save_infos=True, the dataset card (README.md) is created next to the loaded module file.
# The dataset_infos are saved in the YAML part of the README.md
# Let's move it to the original directory of the dataset script, to allow the user to
# upload them on S3 at the same time afterwards.
if self._save_infos:
dataset_readme_path = os.path.join(
builder_cls.get_imported_module_dir(), datasets.config.REPOCARD_FILENAME
)
name = Path(path).name + ".py"
combined_path = os.path.join(path, name)
if os.path.isfile(path):
dataset_dir = os.path.dirname(path)
elif os.path.isfile(combined_path):
dataset_dir = path
elif os.path.isdir(path): # for local directories containing only data files
dataset_dir = path
else: # in case of a remote dataset
dataset_dir = None
print(f"Dataset card saved at {dataset_readme_path}")
# Move dataset_info back to the user
if dataset_dir is not None:
user_dataset_readme_path = os.path.join(dataset_dir, datasets.config.REPOCARD_FILENAME)
copyfile(dataset_readme_path, user_dataset_readme_path)
print(f"Dataset card saved at {user_dataset_readme_path}")
# If clear_cache=True, the download folder and the dataset builder cache directory are deleted
if self._clear_cache:
if os.path.isdir(builder._cache_dir):
logger.warning(f"Clearing cache at {builder._cache_dir}")
rmtree(builder._cache_dir)
download_dir = os.path.join(self._cache_dir, datasets.config.DOWNLOADED_DATASETS_DIR)
if os.path.isdir(download_dir):
logger.warning(f"Clearing cache at {download_dir}")
rmtree(download_dir)
print("Test successful.")
|