import enum
import os
from typing import Optional

from huggingface_hub.utils import insecure_hashlib

from .. import config
from ..exceptions import (
    ExpectedMoreDownloadedFilesError,
    ExpectedMoreSplitsError,
    NonMatchingChecksumError,
    NonMatchingSplitsSizesError,
    UnexpectedDownloadedFileError,
    UnexpectedSplitsError,
)
from .logging import get_logger


logger = get_logger(__name__)


class VerificationMode(enum.Enum):
    """`Enum` that specifies which verification checks to run.

    The default mode is `BASIC_CHECKS`, which will perform only rudimentary checks to avoid slowdowns
    when generating/downloading a dataset for the first time.

    The verification modes:

    |                           | Verification checks                                                           |
    |---------------------------|-------------------------------------------------------------------------------|
    | `ALL_CHECKS`              | Split checks, uniqueness of the keys yielded in case of the GeneratorBuilder  |
    |                           | and the validity (number of files, checksums, etc.) of downloaded files       |
    | `BASIC_CHECKS` (default)  | Same as `ALL_CHECKS` but without checking downloaded files                    |
    | `NO_CHECKS`               | None                                                                          |
    """

    ALL_CHECKS = "all_checks"
    BASIC_CHECKS = "basic_checks"
    NO_CHECKS = "no_checks"
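
# A minimal usage sketch (illustration only): callers typically build the mode
# from its string value, e.g. when resolving a user-facing `verification_mode`
# argument:
#
#   mode = VerificationMode("basic_checks")
#   assert mode is VerificationMode.BASIC_CHECKS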


def verify_checksums(expected_checksums: Optional[dict], recorded_checksums: dict, verification_name=None):
    """Verify that the recorded checksums of the downloaded files match the expected ones.

    Raises an error if a file is missing, unexpected, or has a non-matching checksum.
    """
    if expected_checksums is None:
        logger.info("Unable to verify checksums.")
        return
    if len(set(expected_checksums) - set(recorded_checksums)) > 0:
        raise ExpectedMoreDownloadedFilesError(str(set(expected_checksums) - set(recorded_checksums)))
    if len(set(recorded_checksums) - set(expected_checksums)) > 0:
        raise UnexpectedDownloadedFileError(str(set(recorded_checksums) - set(expected_checksums)))
    bad_urls = [url for url in expected_checksums if expected_checksums[url] != recorded_checksums[url]]
    for_verification_name = " for " + verification_name if verification_name is not None else ""
    if len(bad_urls) > 0:
        raise NonMatchingChecksumError(
            f"Checksums didn't match{for_verification_name}:\n"
            f"{bad_urls}\n"
            "Set `verification_mode='no_checks'` to skip checksums verification and ignore this error"
        )
    logger.info("All the checksums matched successfully" + for_verification_name)
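
# A minimal usage sketch (the URL and checksum values below are hypothetical; real
# ones come from the dataset's recorded infos and the download manager):
#
#   expected = {"https://example.com/train.csv": {"num_bytes": 42, "checksum": "abc123"}}
#   recorded = {"https://example.com/train.csv": {"num_bytes": 42, "checksum": "abc123"}}
#   verify_checksums(expected, recorded, verification_name="dataset source files")
#   # logs: "All the checksums matched successfully for dataset source files"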


def verify_splits(expected_splits: Optional[dict], recorded_splits: dict):
    """Verify that the recorded splits match the expected ones, both in names and in number of examples.

    Raises an error if a split is missing, unexpected, or has a non-matching size.
    """
    if expected_splits is None:
        logger.info("Unable to verify splits sizes.")
        return
    if len(set(expected_splits) - set(recorded_splits)) > 0:
        raise ExpectedMoreSplitsError(str(set(expected_splits) - set(recorded_splits)))
    if len(set(recorded_splits) - set(expected_splits)) > 0:
        raise UnexpectedSplitsError(str(set(recorded_splits) - set(expected_splits)))
    bad_splits = [
        {"expected": expected_splits[name], "recorded": recorded_splits[name]}
        for name in expected_splits
        if expected_splits[name].num_examples != recorded_splits[name].num_examples
    ]
    if len(bad_splits) > 0:
        raise NonMatchingSplitsSizesError(str(bad_splits))
    logger.info("All the splits matched successfully.")
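
# A minimal usage sketch (using stand-in objects; in practice the dict values are
# `SplitInfo`-like objects exposing a `num_examples` attribute):
#
#   from types import SimpleNamespace
#   expected = {"train": SimpleNamespace(num_examples=100)}
#   recorded = {"train": SimpleNamespace(num_examples=100)}
#   verify_splits(expected, recorded)
#   # logs: "All the splits matched successfully."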


def get_size_checksum_dict(path: str, record_checksum: bool = True) -> dict:
    """Compute the file size and the sha256 checksum of a file.

    If `record_checksum` is False, the checksum is skipped and recorded as None.
    """
    if record_checksum:
        m = insecure_hashlib.sha256()
        with open(path, "rb") as f:
            # Hash the file incrementally, in 1 MiB chunks, to keep memory usage low
            for chunk in iter(lambda: f.read(1 << 20), b""):
                m.update(chunk)
        checksum = m.hexdigest()
    else:
        checksum = None
    return {"num_bytes": os.path.getsize(path), "checksum": checksum}
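
# A minimal usage sketch (the path is hypothetical):
#
#   info = get_size_checksum_dict("/tmp/example.bin")
#   # -> {"num_bytes": <size in bytes>, "checksum": "<sha256 hex digest>"}
#   info = get_size_checksum_dict("/tmp/example.bin", record_checksum=False)
#   # -> {"num_bytes": <size in bytes>, "checksum": None}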


def is_small_dataset(dataset_size):
    """Check if `dataset_size` is smaller than `config.IN_MEMORY_MAX_SIZE`.

    Args:
        dataset_size (int): Dataset size in bytes.

    Returns:
        bool: Whether `dataset_size` is smaller than `config.IN_MEMORY_MAX_SIZE`.
    """
    if dataset_size and config.IN_MEMORY_MAX_SIZE:
        return dataset_size < config.IN_MEMORY_MAX_SIZE
    else:
        return False
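
# A minimal usage sketch. Note that a falsy size, or a zero/unset
# `config.IN_MEMORY_MAX_SIZE` (configurable via the `HF_DATASETS_IN_MEMORY_MAX_SIZE`
# environment variable), makes this return False:
#
#   is_small_dataset(50 << 20)  # True only if IN_MEMORY_MAX_SIZE exceeds 50 MiB
#   is_small_dataset(0)         # False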