jamtur01's picture
Upload folder using huggingface_hub
9c6594c verified
import re
import textwrap
from collections import Counter
from itertools import groupby
from operator import itemgetter
from typing import Any, ClassVar, Optional
import yaml
from huggingface_hub import DatasetCardData
from ..config import METADATA_CONFIGS_FIELD
from ..features import Features
from ..info import DatasetInfo, DatasetInfosDict
from ..naming import _split_re
from ..utils.logging import get_logger
logger = get_logger(__name__)
class _NoDuplicateSafeLoader(yaml.SafeLoader):
def _check_no_duplicates_on_constructed_node(self, node):
keys = [self.constructed_objects[key_node] for key_node, _ in node.value]
keys = [tuple(key) if isinstance(key, list) else key for key in keys]
counter = Counter(keys)
duplicate_keys = [key for key in counter if counter[key] > 1]
if duplicate_keys:
raise TypeError(f"Got duplicate yaml keys: {duplicate_keys}")
def construct_mapping(self, node, deep=False):
mapping = super().construct_mapping(node, deep=deep)
self._check_no_duplicates_on_constructed_node(node)
return mapping
def _split_yaml_from_readme(readme_content: str) -> tuple[Optional[str], str]:
full_content = list(readme_content.splitlines())
if full_content and full_content[0] == "---" and "---" in full_content[1:]:
sep_idx = full_content[1:].index("---") + 1
yamlblock = "\n".join(full_content[1:sep_idx])
return yamlblock, "\n".join(full_content[sep_idx + 1 :])
return None, "\n".join(full_content)
class MetadataConfigs(dict[str, dict[str, Any]]):
"""Should be in format {config_name: {**config_params}}."""
FIELD_NAME: ClassVar[str] = METADATA_CONFIGS_FIELD
@staticmethod
def _raise_if_data_files_field_not_valid(metadata_config: dict):
yaml_data_files = metadata_config.get("data_files")
if yaml_data_files is not None:
yaml_error_message = textwrap.dedent(
f"""
Expected data_files in YAML to be either a string or a list of strings
or a list of dicts with two keys: 'split' and 'path', but got {yaml_data_files}
Examples of data_files in YAML:
data_files: data.csv
data_files: data/*.png
data_files:
- part0/*
- part1/*
data_files:
- split: train
path: train/*
- split: test
path: test/*
data_files:
- split: train
path:
- train/part1/*
- train/part2/*
- split: test
path: test/*
PS: some symbols like dashes '-' are not allowed in split names
"""
)
if not isinstance(yaml_data_files, (list, str)):
raise ValueError(yaml_error_message)
if isinstance(yaml_data_files, list):
for yaml_data_files_item in yaml_data_files:
if (
not isinstance(yaml_data_files_item, (str, dict))
or isinstance(yaml_data_files_item, dict)
and not (
len(yaml_data_files_item) == 2
and "split" in yaml_data_files_item
and re.match(_split_re, yaml_data_files_item["split"])
and isinstance(yaml_data_files_item.get("path"), (str, list))
)
):
raise ValueError(yaml_error_message)
@classmethod
def _from_exported_parquet_files_and_dataset_infos(
cls,
parquet_commit_hash: str,
exported_parquet_files: list[dict[str, Any]],
dataset_infos: DatasetInfosDict,
) -> "MetadataConfigs":
metadata_configs = {
config_name: {
"data_files": [
{
"split": split_name,
"path": [
parquet_file["url"].replace("refs%2Fconvert%2Fparquet", parquet_commit_hash)
for parquet_file in parquet_files_for_split
],
}
for split_name, parquet_files_for_split in groupby(parquet_files_for_config, itemgetter("split"))
],
"version": str(dataset_infos.get(config_name, DatasetInfo()).version or "0.0.0"),
}
for config_name, parquet_files_for_config in groupby(exported_parquet_files, itemgetter("config"))
}
if dataset_infos:
# Preserve order of configs and splits
metadata_configs = {
config_name: {
"data_files": [
data_file
for split_name in dataset_info.splits
for data_file in metadata_configs[config_name]["data_files"]
if data_file["split"] == split_name
],
"version": metadata_configs[config_name]["version"],
}
for config_name, dataset_info in dataset_infos.items()
}
return cls(metadata_configs)
@classmethod
def from_dataset_card_data(cls, dataset_card_data: DatasetCardData) -> "MetadataConfigs":
if dataset_card_data.get(cls.FIELD_NAME):
metadata_configs = dataset_card_data[cls.FIELD_NAME]
if not isinstance(metadata_configs, list):
raise ValueError(f"Expected {cls.FIELD_NAME} to be a list, but got '{metadata_configs}'")
for metadata_config in metadata_configs:
if "config_name" not in metadata_config:
raise ValueError(
f"Each config must include `config_name` field with a string name of a config, "
f"but got {metadata_config}. "
)
cls._raise_if_data_files_field_not_valid(metadata_config)
return cls(
{
config.pop("config_name"): {
param: value if param != "features" else Features._from_yaml_list(value)
for param, value in config.items()
}
for metadata_config in metadata_configs
if (config := metadata_config.copy())
}
)
return cls()
def to_dataset_card_data(self, dataset_card_data: DatasetCardData) -> None:
if self:
for metadata_config in self.values():
self._raise_if_data_files_field_not_valid(metadata_config)
current_metadata_configs = self.from_dataset_card_data(dataset_card_data)
total_metadata_configs = dict(sorted({**current_metadata_configs, **self}.items()))
for config_name, config_metadata in total_metadata_configs.items():
config_metadata.pop("config_name", None)
dataset_card_data[self.FIELD_NAME] = [
{"config_name": config_name, **config_metadata}
for config_name, config_metadata in total_metadata_configs.items()
]
def get_default_config_name(self) -> Optional[str]:
default_config_name = None
for config_name, metadata_config in self.items():
if len(self) == 1 or config_name == "default" or metadata_config.get("default"):
if default_config_name is None:
default_config_name = config_name
else:
raise ValueError(
f"Dataset has several default configs: '{default_config_name}' and '{config_name}'."
)
return default_config_name
# DEPRECATED - just here to support old versions of evaluate like 0.2.2
# To support new tasks on the Hugging Face Hub, please open a PR for this file:
# https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/pipelines.ts
known_task_ids = {
"image-classification": [],
"translation": [],
"image-segmentation": [],
"fill-mask": [],
"automatic-speech-recognition": [],
"token-classification": [],
"sentence-similarity": [],
"audio-classification": [],
"question-answering": [],
"summarization": [],
"zero-shot-classification": [],
"table-to-text": [],
"feature-extraction": [],
"other": [],
"multiple-choice": [],
"text-classification": [],
"text-to-image": [],
"text2text-generation": [],
"zero-shot-image-classification": [],
"tabular-classification": [],
"tabular-regression": [],
"image-to-image": [],
"tabular-to-text": [],
"unconditional-image-generation": [],
"text-retrieval": [],
"text-to-speech": [],
"object-detection": [],
"audio-to-audio": [],
"text-generation": [],
"conversational": [],
"table-question-answering": [],
"visual-question-answering": [],
"image-to-text": [],
"reinforcement-learning": [],
"voice-activity-detection": [],
"time-series-forecasting": [],
"document-question-answering": [],
}