jamtur01's picture
Upload folder using huggingface_hub
9c6594c verified
import enum
import os
from typing import Optional
from huggingface_hub.utils import insecure_hashlib
from .. import config
from ..exceptions import (
ExpectedMoreDownloadedFilesError,
ExpectedMoreSplitsError,
NonMatchingChecksumError,
NonMatchingSplitsSizesError,
UnexpectedDownloadedFileError,
UnexpectedSplitsError,
)
from .logging import get_logger
# Module-level logger, named after this module and obtained from the package's logging helper.
logger = get_logger(__name__)
class VerificationMode(enum.Enum):
    """Selects which verification checks are run.

    `BASIC_CHECKS` is the default: it performs only rudimentary checks so that
    generating/downloading a dataset for the first time is not slowed down.

    Members:
        ALL_CHECKS: Split checks, uniqueness of the keys yielded in case of the
            GeneratorBuilder, and the validity (number of files, checksums,
            etc.) of downloaded files.
        BASIC_CHECKS: Same as `ALL_CHECKS` but without checking downloaded
            files (default).
        NO_CHECKS: No verification at all.
    """

    ALL_CHECKS = "all_checks"
    BASIC_CHECKS = "basic_checks"
    NO_CHECKS = "no_checks"
def verify_checksums(expected_checksums: Optional[dict], recorded_checksums: dict, verification_name=None):
    """Compare recorded checksums against the expected ones.

    Args:
        expected_checksums: Mapping of the entries that should be present, or
            `None` when there is nothing to verify against (the check is then
            skipped with an info log).
        recorded_checksums: Mapping of the entries that were actually recorded.
            Must have exactly the same keys as `expected_checksums`.
        verification_name: Optional label appended to the log/error messages.

    Raises:
        ExpectedMoreDownloadedFilesError: Some expected keys were not recorded.
        UnexpectedDownloadedFileError: Some recorded keys were not expected.
        NonMatchingChecksumError: Keys match but at least one value differs.
    """
    if expected_checksums is None:
        logger.info("Unable to verify checksums.")
        return

    # Key sets must be identical before individual values are compared.
    missing = set(expected_checksums) - set(recorded_checksums)
    if missing:
        raise ExpectedMoreDownloadedFilesError(str(missing))

    unexpected = set(recorded_checksums) - set(expected_checksums)
    if unexpected:
        raise UnexpectedDownloadedFileError(str(unexpected))

    mismatched = [
        url for url, expected in expected_checksums.items() if expected != recorded_checksums[url]
    ]
    suffix = "" if verification_name is None else " for " + verification_name
    if mismatched:
        raise NonMatchingChecksumError(
            f"Checksums didn't match{suffix}:\n"
            f"{mismatched}\n"
            "Set `verification_mode='no_checks'` to skip checksums verification and ignore this error"
        )
    logger.info("All the checksums matched successfully" + suffix)
def verify_splits(expected_splits: Optional[dict], recorded_splits: dict):
    """Check that the recorded splits match the expected ones.

    Args:
        expected_splits: Mapping from split name to an object exposing
            `num_examples`, or `None` when there is nothing to verify against
            (the check is then skipped with an info log).
        recorded_splits: Mapping from split name to an object exposing
            `num_examples`. Must have exactly the same names as
            `expected_splits`.

    Raises:
        ExpectedMoreSplitsError: Some expected split names were not recorded.
        UnexpectedSplitsError: Some recorded split names were not expected.
        NonMatchingSplitsSizesError: Names match but at least one split has a
            different `num_examples`.
    """
    if expected_splits is None:
        logger.info("Unable to verify splits sizes.")
        return

    # Split names must be identical before sizes are compared.
    missing = set(expected_splits) - set(recorded_splits)
    if missing:
        raise ExpectedMoreSplitsError(str(missing))

    unexpected = set(recorded_splits) - set(expected_splits)
    if unexpected:
        raise UnexpectedSplitsError(str(unexpected))

    size_mismatches = []
    for name, expected in expected_splits.items():
        recorded = recorded_splits[name]
        if expected.num_examples != recorded.num_examples:
            size_mismatches.append({"expected": expected, "recorded": recorded})

    if size_mismatches:
        raise NonMatchingSplitsSizesError(str(size_mismatches))
    logger.info("All the splits matched successfully.")
def get_size_checksum_dict(path: str, record_checksum: bool = True) -> dict:
    """Return the size of a file and, optionally, its sha256 hex digest.

    Args:
        path: Path of the file to inspect.
        record_checksum: When true, the file is read and hashed; otherwise the
            checksum is reported as `None` and the file is not opened.

    Returns:
        A dict with keys `"num_bytes"` (from `os.path.getsize`) and
        `"checksum"` (hex digest string or `None`).
    """
    checksum = None
    if record_checksum:
        hasher = insecure_hashlib.sha256()
        with open(path, "rb") as stream:
            # Hash in 1 MiB pieces so the whole file is never held in memory.
            while True:
                piece = stream.read(1 << 20)
                if piece == b"":
                    break
                hasher.update(piece)
        checksum = hasher.hexdigest()
    return {"num_bytes": os.path.getsize(path), "checksum": checksum}
def is_small_dataset(dataset_size):
    """Tell whether a dataset is below the `config.IN_MEMORY_MAX_SIZE` limit.

    Args:
        dataset_size (int): Dataset size in bytes.

    Returns:
        bool: `True` only when both `dataset_size` and
        `config.IN_MEMORY_MAX_SIZE` are truthy and `dataset_size` is strictly
        smaller than the limit; `False` otherwise.
    """
    # A falsy size (None or 0) never counts as "small".
    if not dataset_size:
        return False
    # A falsy limit disables the check as well.
    if not config.IN_MEMORY_MAX_SIZE:
        return False
    return dataset_size < config.IN_MEMORY_MAX_SIZE