# Copyright 2020 The HuggingFace Datasets Authors and the TensorFlow Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""DatasetBuilder base class."""
import abc
import contextlib
import copy
import inspect
import os
import posixpath
import shutil
import textwrap
import time
import urllib
from collections.abc import Iterable, Mapping
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Union
from unittest.mock import patch
import fsspec
from fsspec.core import url_to_fs
from multiprocess import Pool
from tqdm.contrib.concurrent import thread_map
from . import config, utils
from .arrow_dataset import Dataset
from .arrow_reader import (
ArrowReader,
ReadInstruction,
)
from .arrow_writer import ArrowWriter, ParquetWriter, SchemaInferenceError
from .data_files import DataFilesDict, DataFilesPatternsDict, sanitize_patterns
from .dataset_dict import DatasetDict, IterableDatasetDict
from .download.download_config import DownloadConfig
from .download.download_manager import DownloadManager, DownloadMode
from .download.streaming_download_manager import StreamingDownloadManager, xjoin
from .exceptions import DatasetGenerationCastError, DatasetGenerationError, FileFormatError, ManualDownloadError
from .features import Features
from .filesystems import (
is_remote_filesystem,
rename,
)
from .fingerprint import Hasher
from .info import DatasetInfo, DatasetInfosDict, PostProcessedInfo
from .iterable_dataset import ArrowExamplesIterable, ExamplesIterable, IterableDataset
from .keyhash import DuplicatedKeysError
from .naming import INVALID_WINDOWS_CHARACTERS_IN_PATH, camelcase_to_snakecase
from .splits import Split, SplitDict, SplitGenerator, SplitInfo
from .streaming import extend_dataset_builder_for_streaming
from .table import CastError
from .utils import logging
from .utils import tqdm as hf_tqdm
from .utils._filelock import FileLock
from .utils.file_utils import is_remote_url
from .utils.info_utils import VerificationMode, get_size_checksum_dict, verify_checksums, verify_splits
from .utils.py_utils import (
classproperty,
convert_file_size_to_int,
has_sufficient_disk_space,
iflatmap_unordered,
map_nested,
memoize,
size_str,
temporary_assignment,
)
from .utils.sharding import _number_of_shards_in_gen_kwargs, _split_gen_kwargs
from .utils.track import tracked_list
if TYPE_CHECKING:
from .load import DatasetModule
logger = logging.get_logger(__name__)
class InvalidConfigName(ValueError):
pass
@dataclass
class BuilderConfig:
"""Base class for `DatasetBuilder` data configuration.
`DatasetBuilder` subclasses with data configuration options should subclass
`BuilderConfig` and add their own properties.
Attributes:
name (`str`, defaults to `default`):
The name of the configuration.
version (`Version` or `str`, defaults to `0.0.0`):
The version of the configuration.
data_dir (`str`, *optional*):
Path to the directory containing the source data.
data_files (`str` or `Sequence` or `Mapping`, *optional*):
Path(s) to source data file(s).
description (`str`, *optional*):
A human description of the configuration.
"""
name: str = "default"
version: Optional[Union[utils.Version, str]] = utils.Version("0.0.0")
data_dir: Optional[str] = None
data_files: Optional[Union[DataFilesDict, DataFilesPatternsDict]] = None
description: Optional[str] = None
def __post_init__(self):
# The config name is used to name the cache directory.
for invalid_char in INVALID_WINDOWS_CHARACTERS_IN_PATH:
if invalid_char in self.name:
raise InvalidConfigName(
f"Bad characters from black list '{INVALID_WINDOWS_CHARACTERS_IN_PATH}' found in '{self.name}'. "
f"They could create issues when creating a directory for this config on Windows filesystem."
)
if self.data_files is not None and not isinstance(self.data_files, (DataFilesDict, DataFilesPatternsDict)):
raise ValueError(f"Expected a DataFilesDict or DataFilesPatternsDict in data_files but got {self.data_files}")
def __eq__(self, o):
# we need to override the default dataclass __eq__ since it doesn't check for
# other attributes than the ones in the signature.
if set(self.__dict__.keys()) != set(o.__dict__.keys()):
return False
return all((k, getattr(self, k)) == (k, getattr(o, k)) for k in self.__dict__.keys())
def create_config_id(
self,
config_kwargs: dict,
custom_features: Optional[Features] = None,
) -> str:
"""
The config id is used to build the cache directory.
By default it is equal to the config name.
However the name of a config is not sufficient to have a unique identifier for the dataset being generated
since it doesn't take into account:
- the config kwargs that can be used to overwrite attributes
- the custom features used to write the dataset
- the data_files for json/text/csv/pandas datasets
Therefore the config id is just the config name with an optional suffix based on these.
"""
# Possibly add a suffix to the name to handle custom features/data_files/config_kwargs
suffix: Optional[str] = None
config_kwargs_to_add_to_suffix = config_kwargs.copy()
# name and version are already used to build the cache directory
config_kwargs_to_add_to_suffix.pop("name", None)
config_kwargs_to_add_to_suffix.pop("version", None)
# data dir handling (when specified it points to the manually downloaded data):
# it was previously ignored before the introduction of config id because we didn't want
# to change the config name. Now it's fine to take it into account for the config id.
# config_kwargs_to_add_to_suffix.pop("data_dir", None)
if "data_dir" in config_kwargs_to_add_to_suffix:
if config_kwargs_to_add_to_suffix["data_dir"] is None:
config_kwargs_to_add_to_suffix.pop("data_dir", None)
else:
# canonicalize the data dir to avoid two paths to the same location having different
# hashes
data_dir = config_kwargs_to_add_to_suffix["data_dir"]
data_dir = os.path.normpath(data_dir)
config_kwargs_to_add_to_suffix["data_dir"] = data_dir
if config_kwargs_to_add_to_suffix:
# we don't care about the order of the kwargs
config_kwargs_to_add_to_suffix = {
k: config_kwargs_to_add_to_suffix[k] for k in sorted(config_kwargs_to_add_to_suffix)
}
if all(isinstance(v, (str, bool, int, float)) for v in config_kwargs_to_add_to_suffix.values()):
suffix = ",".join(
str(k) + "=" + urllib.parse.quote_plus(str(v)) for k, v in config_kwargs_to_add_to_suffix.items()
)
if len(suffix) > 32: # hash if too long
suffix = Hasher.hash(config_kwargs_to_add_to_suffix)
else:
suffix = Hasher.hash(config_kwargs_to_add_to_suffix)
if custom_features is not None:
m = Hasher()
if suffix:
m.update(suffix)
m.update(custom_features)
suffix = m.hexdigest()
if suffix:
config_id = self.name + "-" + suffix
if len(config_id) > config.MAX_DATASET_CONFIG_ID_READABLE_LENGTH:
config_id = self.name + "-" + Hasher.hash(suffix)
return config_id
else:
return self.name
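# Illustrative sketch (comment only, values are hypothetical): primitive config_kwargs
# are url-encoded into a readable suffix, anything else falls back to a hash.
#
#   BuilderConfig(name="default").create_config_id({"sep": ";"})
#   # -> "default-sep=%3B"
#   BuilderConfig(name="default").create_config_id({"data_files": {"train": ["f.csv"]}})
#   # -> "default-<hash of the kwargs>"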
def _resolve_data_files(self, base_path: str, download_config: DownloadConfig) -> None:
if isinstance(self.data_files, DataFilesPatternsDict):
base_path = xjoin(base_path, self.data_dir) if self.data_dir else base_path
self.data_files = self.data_files.resolve(base_path, download_config)
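# Sketch of how a dataset typically extends BuilderConfig with its own options
# (hypothetical fields, for illustration only):
#
#   @dataclass
#   class SquadLikeConfig(BuilderConfig):
#       context_size: int = 512
#       include_titles: bool = True
#
# The subclass is then set as BUILDER_CONFIG_CLASS on the corresponding DatasetBuilder.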
class DatasetBuilder:
"""Abstract base class for all datasets.
`DatasetBuilder` has 3 key methods:
- [`DatasetBuilder.info`]: Documents the dataset, including feature
names, types, shapes, version, splits, citation, etc.
- [`DatasetBuilder.download_and_prepare`]: Downloads the source data
and writes it to disk.
- [`DatasetBuilder.as_dataset`]: Generates a [`Dataset`].
Some `DatasetBuilder`s expose multiple variants of the
dataset by defining a [`BuilderConfig`] subclass and accepting a
config object (or name) on construction. Configurable datasets expose a
pre-defined set of configurations in [`DatasetBuilder.builder_configs`].
Args:
cache_dir (`str`, *optional*):
Directory to cache data. Defaults to `"~/.cache/huggingface/datasets"`.
dataset_name (`str`, *optional*):
Name of the dataset, if different from the builder name. Useful for packaged builders
like csv, imagefolder, audiofolder, etc. to reflect the difference between datasets
that use the same packaged builder.
config_name (`str`, *optional*):
Name of the dataset configuration.
It affects the data generated on disk. Different configurations will have their own subdirectories and
versions.
If not provided, the default configuration is used (if it exists).
<Added version="2.3.0">
Parameter `name` was renamed to `config_name`.
</Added>
hash (`str`, *optional*):
Hash specific to the dataset code. Used to update the caching directory when the
dataset loading script code is updated (to avoid reusing old data).
The typical caching directory (defined in `self._relative_data_dir`) is `name/version/hash/`.
base_path (`str`, *optional*):
Base path for relative paths that are used to download files.
This can be a remote URL.
features ([`Features`], *optional*):
Features types to use with this dataset.
It can be used to change the [`Features`] types of a dataset, for example.
token (`str` or `bool`, *optional*):
String or boolean to use as Bearer token for remote files on the
Datasets Hub. If `True`, will get token from `"~/.huggingface"`.
repo_id (`str`, *optional*):
ID of the dataset repository.
Used to distinguish builders with the same name but not coming from the same namespace, for example "rajpurkar/squad"
and "lhoestq/squad" repo IDs. In the latter, the builder name would be "lhoestq___squad".
data_files (`str` or `Sequence` or `Mapping`, *optional*):
Path(s) to source data file(s).
For builders like "csv" or "json" that need the user to specify data files. They can be either
local or remote files. For convenience, you can use a `DataFilesDict`.
data_dir (`str`, *optional*):
Path to directory containing source data file(s).
Use only if `data_files` is not passed, in which case it is equivalent to passing
`os.path.join(data_dir, "**")` as `data_files`.
For builders that require manual download, it must be the path to the local directory containing the
manually downloaded data.
storage_options (`dict`, *optional*):
Key/value pairs to be passed on to the dataset file-system backend, if any.
writer_batch_size (`int`, *optional*):
Batch size used by the ArrowWriter.
It defines the number of samples that are kept in memory before writing them
and also the length of the arrow chunks.
None means that the ArrowWriter will use its default value.
**config_kwargs (additional keyword arguments): Keyword arguments to be passed to the corresponding builder
configuration class, set on the class attribute [`DatasetBuilder.BUILDER_CONFIG_CLASS`]. The builder
configuration class is [`BuilderConfig`] or a subclass of it.
"""
# Default version
VERSION = None # Default version set in BuilderConfig
# Class for the builder config.
BUILDER_CONFIG_CLASS = BuilderConfig
# Named configurations that modify the data generated by download_and_prepare.
BUILDER_CONFIGS = []
# Optional default config name to be used when name is None
DEFAULT_CONFIG_NAME = None
# Default batch size used by the ArrowWriter
# It defines the number of samples that are kept in memory before writing them
# and also the length of the arrow chunks
# None means that the ArrowWriter will use its default value
DEFAULT_WRITER_BATCH_SIZE = None
def __init__(
self,
cache_dir: Optional[str] = None,
dataset_name: Optional[str] = None,
config_name: Optional[str] = None,
hash: Optional[str] = None,
base_path: Optional[str] = None,
info: Optional[DatasetInfo] = None,
features: Optional[Features] = None,
token: Optional[Union[bool, str]] = None,
repo_id: Optional[str] = None,
data_files: Optional[Union[str, list, dict, DataFilesDict]] = None,
data_dir: Optional[str] = None,
storage_options: Optional[dict] = None,
writer_batch_size: Optional[int] = None,
**config_kwargs,
):
# DatasetBuilder name
self.name: str = camelcase_to_snakecase(self.__module__.split(".")[-1])
self.hash: Optional[str] = hash
self.base_path = base_path
self.token = token
self.repo_id = repo_id
self.storage_options = storage_options or {}
self.dataset_name = camelcase_to_snakecase(dataset_name) if dataset_name else self.name
self._writer_batch_size = writer_batch_size or self.DEFAULT_WRITER_BATCH_SIZE
if data_files is not None and not isinstance(data_files, DataFilesDict):
data_files = DataFilesDict.from_patterns(
sanitize_patterns(data_files),
base_path=base_path,
download_config=DownloadConfig(token=token, storage_options=self.storage_options),
)
# Prepare config: DatasetConfig contains name, version and description but can be extended by each dataset
if "features" in inspect.signature(self.BUILDER_CONFIG_CLASS.__init__).parameters and features is not None:
config_kwargs["features"] = features
if data_files is not None:
config_kwargs["data_files"] = data_files
if data_dir is not None:
config_kwargs["data_dir"] = data_dir
self.config_kwargs = config_kwargs
self.config, self.config_id = self._create_builder_config(
config_name=config_name,
custom_features=features,
**config_kwargs,
)
# prepare info: DatasetInfo are a standardized dataclass across all datasets
# Prefill datasetinfo
if info is None:
# TODO FOR PACKAGED MODULES IT IMPORTS DATA FROM src/packaged_modules which doesn't make sense
info = self.get_exported_dataset_info()
info.update(self._info())
info.builder_name = self.name
info.dataset_name = self.dataset_name
info.config_name = self.config.name
info.version = self.config.version
self.info = info
# update info with user specified infos
if features is not None:
self.info.features = features
# Prepare data dirs:
# cache_dir can be a remote bucket on GCS or S3
self._cache_dir_root = str(cache_dir or config.HF_DATASETS_CACHE)
self._cache_dir_root = (
self._cache_dir_root if is_remote_url(self._cache_dir_root) else os.path.expanduser(self._cache_dir_root)
)
self._cache_downloaded_dir = (
posixpath.join(self._cache_dir_root, config.DOWNLOADED_DATASETS_DIR)
if cache_dir
else str(config.DOWNLOADED_DATASETS_PATH)
)
self._cache_downloaded_dir = (
self._cache_downloaded_dir
if is_remote_url(self._cache_downloaded_dir)
else os.path.expanduser(self._cache_downloaded_dir)
)
# In case there exists a legacy cache directory
self._legacy_relative_data_dir = None
self._cache_dir = self._build_cache_dir()
if not is_remote_url(self._cache_dir_root):
os.makedirs(self._cache_dir_root, exist_ok=True)
lock_path = os.path.join(
self._cache_dir_root, Path(self._cache_dir).as_posix().replace("/", "_") + ".lock"
)
with FileLock(lock_path):
if os.path.exists(self._cache_dir): # check if data exist
if len(os.listdir(self._cache_dir)) > 0:
if os.path.exists(os.path.join(self._cache_dir, config.DATASET_INFO_FILENAME)):
logger.info("Overwrite dataset info from restored data version if exists.")
self.info = DatasetInfo.from_directory(self._cache_dir)
else: # dir exists but no data, remove the empty dir as data aren't available anymore
logger.warning(
f"Old caching folder {self._cache_dir} for dataset {self.dataset_name} exists but no data were found. Removing it. "
)
os.rmdir(self._cache_dir)
# Store in the cache by default unless the user specifies a custom output_dir to download_and_prepare
self._output_dir = self._cache_dir
self._fs: fsspec.AbstractFileSystem = fsspec.filesystem("file")
# Set download manager
self.dl_manager = None
# Set to True by "datasets-cli test" to generate file checksums for (deprecated) dataset_infos.json independently of verification_mode value.
self._record_infos = False
# Set in `.download_and_prepare` once the format of the generated dataset is known
self._file_format = None
# Enable streaming (e.g. it patches "open" to work with remote files)
extend_dataset_builder_for_streaming(self)
def __getstate__(self):
return self.__dict__
def __setstate__(self, d):
self.__dict__ = d
# Re-enable streaming, since patched functions are not kept when pickling
extend_dataset_builder_for_streaming(self)
# Must be set for datasets that use 'data_dir' functionality - the ones
# that require users to do additional steps to download the data
# (this is usually due to some external regulations / rules).
# This field should contain a string with user instructions, including
# the list of files that should be present. It will be
# displayed in the dataset documentation.
@property
def manual_download_instructions(self) -> Optional[str]:
return None
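# Sketch of a typical override for a dataset that requires manual download
# (hypothetical wording, for illustration only):
#
#   @property
#   def manual_download_instructions(self):
#       return (
#           "This dataset must be downloaded manually from example.org. "
#           "Extract the archive and pass its path with "
#           "`load_dataset(..., data_dir='<path/to/extracted/data>')`."
#       )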
def _check_legacy_cache(self) -> Optional[str]:
"""Check for the old cache directory template {cache_dir}/{namespace}___{builder_name} from 2.13"""
if (
self.__module__.startswith("datasets.")
and not is_remote_url(self._cache_dir_root)
and self.config.name == "default"
):
from .packaged_modules import _PACKAGED_DATASETS_MODULES
namespace = self.repo_id.split("/")[0] if self.repo_id and self.repo_id.count("/") > 0 else None
config_name = self.repo_id.replace("/", "--") if self.repo_id is not None else self.dataset_name
config_id = config_name + self.config_id[len(self.config.name) :]
hash = _PACKAGED_DATASETS_MODULES.get(self.name, "missing")[1]
legacy_relative_data_dir = posixpath.join(
self.dataset_name if namespace is None else f"{namespace}___{self.dataset_name}",
config_id,
"0.0.0",
hash,
)
legacy_cache_dir = posixpath.join(self._cache_dir_root, legacy_relative_data_dir)
if os.path.isdir(legacy_cache_dir):
return legacy_relative_data_dir
def _check_legacy_cache2(self, dataset_module: "DatasetModule") -> Optional[str]:
"""Check for the old cache directory template {cache_dir}/{namespace}___{dataset_name}/{config_name}-xxx from 2.14 and 2.15"""
if (
self.__module__.startswith("datasets.")
and not is_remote_url(self._cache_dir_root)
and not (set(self.config_kwargs) - {"data_files", "data_dir"})
):
from .packaged_modules import _PACKAGED_DATASETS_MODULES_2_15_HASHES
from .utils._dill import Pickler
def update_hash_with_config_parameters(hash: str, config_parameters: dict) -> str:
"""
Used to update hash of packaged modules which is used for creating unique cache directories to reflect
different config parameters which are passed in metadata from readme.
"""
params_to_exclude = {"config_name", "version", "description"}
params_to_add_to_hash = {
param: value
for param, value in sorted(config_parameters.items())
if param not in params_to_exclude
}
m = Hasher()
m.update(hash)
m.update(params_to_add_to_hash)
return m.hexdigest()
namespace = self.repo_id.split("/")[0] if self.repo_id and self.repo_id.count("/") > 0 else None
with patch.object(Pickler, "_legacy_no_dict_keys_sorting", True):
config_id = self.config.name + "-" + Hasher.hash({"data_files": self.config.data_files})
hash = _PACKAGED_DATASETS_MODULES_2_15_HASHES.get(self.name, "missing")
if (
dataset_module.builder_configs_parameters.metadata_configs
and self.config.name in dataset_module.builder_configs_parameters.metadata_configs
):
hash = update_hash_with_config_parameters(
hash, dataset_module.builder_configs_parameters.metadata_configs[self.config.name]
)
legacy_relative_data_dir = posixpath.join(
self.dataset_name if namespace is None else f"{namespace}___{self.dataset_name}",
config_id,
"0.0.0",
hash,
)
legacy_cache_dir = posixpath.join(self._cache_dir_root, legacy_relative_data_dir)
if os.path.isdir(legacy_cache_dir):
return legacy_relative_data_dir
@classmethod
def get_all_exported_dataset_infos(cls) -> DatasetInfosDict:
"""Return the exported dataset infos for all configs, or an empty dict if they don't exist.
Example:
```py
>>> from datasets import load_dataset_builder
>>> ds_builder = load_dataset_builder('vivos')
>>> ds_builder.get_all_exported_dataset_infos()
{'default': DatasetInfo(description='', citation='', homepage='', license='', features={'speaker_id': Value(dtype='string', id=None), 'path': Value(dtype='string', id=None), 'audio': Audio(sampling_rate=16000, mono=True, decode=True, id=None), 'sentence': Value(dtype='string', id=None)}, post_processed=None, supervised_keys=None, builder_name=None, dataset_name=None, config_name='default', version=None, splits={'train': SplitInfo(name='train', num_bytes=1722002133, num_examples=11660, shard_lengths=None, dataset_name=None), 'test': SplitInfo(name='test', num_bytes=86120227, num_examples=760, shard_lengths=None, dataset_name=None)}, download_checksums=None, download_size=1475540500, post_processing_size=None, dataset_size=1808122360, size_in_bytes=None)}
```
"""
return DatasetInfosDict.from_directory(cls.get_imported_module_dir())
def get_exported_dataset_info(self) -> DatasetInfo:
"""Return the exported `DatasetInfo` for the current config, or an empty `DatasetInfo` if it doesn't exist.
Example:
```py
>>> from datasets import load_dataset_builder
>>> ds_builder = load_dataset_builder('vivos')
>>> ds_builder.get_exported_dataset_info()
DatasetInfo(description='', citation='', homepage='', license='', features={'speaker_id': Value(dtype='string', id=None), 'path': Value(dtype='string', id=None), 'audio': Audio(sampling_rate=16000, mono=True, decode=True, id=None), 'sentence': Value(dtype='string', id=None)}, post_processed=None, supervised_keys=None, builder_name=None, dataset_name=None, config_name='default', version=None, splits={'train': SplitInfo(name='train', num_bytes=1722002133, num_examples=11660, shard_lengths=None, dataset_name=None), 'test': SplitInfo(name='test', num_bytes=86120227, num_examples=760, shard_lengths=None, dataset_name=None)}, download_checksums=None, download_size=1475540500, post_processing_size=None, dataset_size=1808122360, size_in_bytes=None)
```
"""
return self.get_all_exported_dataset_infos().get(self.config.name, DatasetInfo())
def _create_builder_config(
self, config_name=None, custom_features=None, **config_kwargs
) -> tuple[BuilderConfig, str]:
"""Create and validate BuilderConfig object as well as a unique config id for this config.
Raises ValueError if there are multiple builder configs and config_name and DEFAULT_CONFIG_NAME are None.
config_kwargs override the default kwargs of the config.
"""
builder_config = None
# try default config
if config_name is None and self.BUILDER_CONFIGS:
if self.DEFAULT_CONFIG_NAME is not None:
builder_config = self.builder_configs.get(self.DEFAULT_CONFIG_NAME)
logger.info(f"No config specified, defaulting to: {self.dataset_name}/{builder_config.name}")
else:
if len(self.BUILDER_CONFIGS) > 1:
if not config_kwargs:
example_of_usage = (
f"load_dataset('{self.repo_id or self.dataset_name}', '{self.BUILDER_CONFIGS[0].name}')"
)
raise ValueError(
"Config name is missing."
f"\nPlease pick one among the available configs: {list(self.builder_configs.keys())}"
+ f"\nExample of usage:\n\t`{example_of_usage}`"
)
else:
builder_config = self.BUILDER_CONFIGS[0]
logger.info(
f"No config specified, defaulting to the single config: {self.dataset_name}/{builder_config.name}"
)
# try to get config by name
if isinstance(config_name, str):
builder_config = self.builder_configs.get(config_name)
if builder_config is None and self.BUILDER_CONFIGS:
raise ValueError(
f"BuilderConfig '{config_name}' not found. Available: {list(self.builder_configs.keys())}"
)
# if not using an existing config, then create a new config on the fly
if not builder_config:
if config_name is not None:
config_kwargs["name"] = config_name
elif self.DEFAULT_CONFIG_NAME and not config_kwargs:
# Use DEFAULT_CONFIG_NAME only if no config_kwargs are passed
config_kwargs["name"] = self.DEFAULT_CONFIG_NAME
if "version" not in config_kwargs and hasattr(self, "VERSION") and self.VERSION:
config_kwargs["version"] = self.VERSION
builder_config = self.BUILDER_CONFIG_CLASS(**config_kwargs)
# otherwise use the config_kwargs to overwrite the attributes
else:
builder_config = copy.deepcopy(builder_config) if config_kwargs else builder_config
for key, value in config_kwargs.items():
if value is not None:
if not hasattr(builder_config, key):
raise ValueError(f"BuilderConfig {builder_config} doesn't have a '{key}' key.")
setattr(builder_config, key, value)
if not builder_config.name:
raise ValueError(f"BuilderConfig must have a name, got {builder_config.name}")
# resolve data files if needed
builder_config._resolve_data_files(
base_path=self.base_path,
download_config=DownloadConfig(token=self.token, storage_options=self.storage_options),
)
# compute the config id that is going to be used for caching
config_id = builder_config.create_config_id(
config_kwargs,
custom_features=custom_features,
)
is_custom = (config_id not in self.builder_configs) and config_id != "default"
if is_custom:
logger.info(f"Using custom data configuration {config_id}")
else:
if (
builder_config.name in self.builder_configs
and builder_config != self.builder_configs[builder_config.name]
):
raise ValueError(
"Cannot name a custom BuilderConfig the same as an available "
f"BuilderConfig. Change the name. Available BuilderConfigs: {list(self.builder_configs.keys())}"
)
if not builder_config.version:
raise ValueError(f"BuilderConfig {builder_config.name} must have a version")
return builder_config, config_id
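# Resolution sketch (illustrative, assuming a builder with predefined configs "clean" and
# "raw" and DEFAULT_CONFIG_NAME = "clean"):
#
#   builder._create_builder_config()                        # -> ("clean" config, "clean")
#   builder._create_builder_config("raw")                   # -> ("raw" config, "raw")
#   builder._create_builder_config("raw", data_dir="./d")   # -> copy of "raw" with the
#                                                           #    override applied; the
#                                                           #    config id gains a suffix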
@classproperty
@classmethod
@memoize()
def builder_configs(cls) -> dict[str, BuilderConfig]:
"""Dictionary of pre-defined configurations for this builder class."""
configs = {config.name: config for config in cls.BUILDER_CONFIGS}
if len(configs) != len(cls.BUILDER_CONFIGS):
names = [config.name for config in cls.BUILDER_CONFIGS]
raise ValueError(f"Names in BUILDER_CONFIGS must not be duplicated. Got {names}")
return configs
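# Minimal sketch of how a builder class usually declares its configs (hypothetical names,
# for illustration only):
#
#   class MyDataset(GeneratorBasedBuilder):
#       BUILDER_CONFIGS = [
#           BuilderConfig(name="clean", version="1.0.0", description="filtered subset"),
#           BuilderConfig(name="raw", version="1.0.0", description="unfiltered dump"),
#       ]
#
#   MyDataset.builder_configs  # -> {"clean": BuilderConfig(...), "raw": BuilderConfig(...)}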
@property
def cache_dir(self):
return self._cache_dir
def _use_legacy_cache_dir_if_possible(self, dataset_module: "DatasetModule"):
# Check for the legacy cache directory template (datasets<3.0.0)
self._legacy_relative_data_dir = (
self._check_legacy_cache2(dataset_module) or self._check_legacy_cache() or None
)
self._cache_dir = self._build_cache_dir()
self._output_dir = self._cache_dir
def _relative_data_dir(self, with_version=True, with_hash=True) -> str:
"""Relative path of this dataset in cache_dir:
Will be:
self.dataset_name/self.config.version/self.hash/
or if a repo_id with a namespace has been specified:
self.namespace___self.dataset_name/self.config.version/self.hash/
If any of these elements is missing or if ``with_version=False`` the corresponding subfolders are dropped.
"""
if self._legacy_relative_data_dir is not None and with_version and with_hash:
return self._legacy_relative_data_dir
namespace = self.repo_id.split("/")[0] if self.repo_id and self.repo_id.count("/") > 0 else None
builder_data_dir = self.dataset_name if namespace is None else f"{namespace}___{self.dataset_name}"
builder_data_dir = posixpath.join(builder_data_dir, self.config_id)
if with_version:
builder_data_dir = posixpath.join(builder_data_dir, str(self.config.version))
if with_hash and self.hash and isinstance(self.hash, str):
builder_data_dir = posixpath.join(builder_data_dir, self.hash)
return builder_data_dir
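# Illustrative layout (hypothetical values): for repo_id "user/my_dataset", config_id
# "default", version "1.0.0" and hash "abc123" this resolves to
#   user___my_dataset/default/1.0.0/abc123
# and, without a namespace, simply to
#   my_dataset/default/1.0.0/abc123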
def _build_cache_dir(self):
"""Return the data directory for the current version."""
builder_data_dir = posixpath.join(self._cache_dir_root, self._relative_data_dir(with_version=False))
version_data_dir = posixpath.join(self._cache_dir_root, self._relative_data_dir(with_version=True))
def _other_versions_on_disk():
"""Returns previous versions on disk."""
if not os.path.exists(builder_data_dir):
return []
version_dirnames = []
for dir_name in os.listdir(builder_data_dir):
try:
version_dirnames.append((utils.Version(dir_name), dir_name))
except ValueError: # Invalid version (ex: incomplete data dir)
pass
version_dirnames.sort(reverse=True)
return version_dirnames
# Check and warn if other versions exist
if not is_remote_url(builder_data_dir):
version_dirs = _other_versions_on_disk()
if version_dirs:
other_version = version_dirs[0][0]
if other_version != self.config.version:
warn_msg = (
f"Found a different version {str(other_version)} of dataset {self.dataset_name} in "
f"cache_dir {self._cache_dir_root}. Using currently defined version "
f"{str(self.config.version)}."
)
logger.warning(warn_msg)
return version_data_dir
@abc.abstractmethod
def _info(self) -> DatasetInfo:
"""Construct the DatasetInfo object. See `DatasetInfo` for details.
Warning: This function is only called once and the result is cached for all
following .info() calls.
Returns:
info: (DatasetInfo) The dataset information
"""
raise NotImplementedError
@classmethod
def get_imported_module_dir(cls):
"""Return the path of the module of this class or subclass."""
return os.path.dirname(inspect.getfile(inspect.getmodule(cls)))
def _rename(self, src: str, dst: str):
rename(self._fs, src, dst)
def download_and_prepare(
self,
output_dir: Optional[str] = None,
download_config: Optional[DownloadConfig] = None,
download_mode: Optional[Union[DownloadMode, str]] = None,
verification_mode: Optional[Union[VerificationMode, str]] = None,
dl_manager: Optional[DownloadManager] = None,
base_path: Optional[str] = None,
file_format: str = "arrow",
max_shard_size: Optional[Union[int, str]] = None,
num_proc: Optional[int] = None,
storage_options: Optional[dict] = None,
**download_and_prepare_kwargs,
):
"""Downloads and prepares dataset for reading.
Args:
output_dir (`str`, *optional*):
Output directory for the dataset.
Default to this builder's `cache_dir`, which is inside `~/.cache/huggingface/datasets` by default.
<Added version="2.5.0"/>
download_config (`DownloadConfig`, *optional*):
Specific download configuration parameters.
download_mode ([`DownloadMode`] or `str`, *optional*):
Select the download/generate mode, default to `REUSE_DATASET_IF_EXISTS`.
verification_mode ([`VerificationMode`] or `str`, defaults to `BASIC_CHECKS`):
Verification mode determining the checks to run on the downloaded/processed dataset information (checksums/size/splits/...).
<Added version="2.9.1"/>
dl_manager (`DownloadManager`, *optional*):
Specific `DownloadManger` to use.
base_path (`str`, *optional*):
Base path for relative paths that are used to download files. This can be a remote url.
If not specified, the value of the `base_path` attribute (`self.base_path`) will be used instead.
file_format (`str`, *optional*):
Format of the data files in which the dataset will be written.
Supported formats: "arrow", "parquet". Default to "arrow" format.
If the format is "parquet", then image and audio data are embedded into the Parquet files instead of pointing to local files.
<Added version="2.5.0"/>
max_shard_size (`Union[str, int]`, *optional*):
Maximum number of bytes written per shard, default is "500MB".
The size is based on uncompressed data size, so in practice your shard files may be smaller than
`max_shard_size` thanks to Parquet compression for example.
<Added version="2.5.0"/>
num_proc (`int`, *optional*, defaults to `None`):
Number of processes when downloading and generating the dataset locally.
Multiprocessing is disabled by default.
<Added version="2.7.0"/>
storage_options (`dict`, *optional*):
Key/value pairs to be passed on to the caching file-system backend, if any.
<Added version="2.5.0"/>
**download_and_prepare_kwargs (additional keyword arguments): Keyword arguments.
Example:
Download and prepare the dataset as Arrow files that can be loaded as a Dataset using `builder.as_dataset()`:
```py
>>> from datasets import load_dataset_builder
>>> builder = load_dataset_builder("cornell-movie-review-data/rotten_tomatoes")
>>> builder.download_and_prepare()
```
Download and prepare the dataset as sharded Parquet files locally:
```py
>>> from datasets import load_dataset_builder
>>> builder = load_dataset_builder("cornell-movie-review-data/rotten_tomatoes")
>>> builder.download_and_prepare("./output_dir", file_format="parquet")
```
Download and prepare the dataset as sharded Parquet files in a cloud storage:
```py
>>> from datasets import load_dataset_builder
>>> storage_options = {"key": aws_access_key_id, "secret": aws_secret_access_key}
>>> builder = load_dataset_builder("cornell-movie-review-data/rotten_tomatoes")
>>> builder.download_and_prepare("s3://my-bucket/my_rotten_tomatoes", storage_options=storage_options, file_format="parquet")
```
"""
output_dir = output_dir if output_dir is not None else self._cache_dir
# output_dir can be a remote bucket on GCS or S3
fs, output_dir = url_to_fs(output_dir, **(storage_options or {}))
self._fs = fs
self._output_dir = output_dir if not is_remote_filesystem(self._fs) else self._fs.unstrip_protocol(output_dir)
download_mode = DownloadMode(download_mode or DownloadMode.REUSE_DATASET_IF_EXISTS)
verification_mode = VerificationMode(verification_mode or VerificationMode.BASIC_CHECKS)
base_path = base_path if base_path is not None else self.base_path
if file_format is not None and file_format not in ["arrow", "parquet"]:
raise ValueError(f"Unsupported file_format: {file_format}. Expected 'arrow' or 'parquet'")
self._file_format = file_format
if self._fs._strip_protocol(self._output_dir) == "":
# We don't support the root directory, because it has no dirname,
# and we need a dirname to use a <dirname>.incomplete directory
# when the dataset is being written
raise RuntimeError(
f"Unable to download and prepare the dataset at the root {self._output_dir}. "
f"Please specify a subdirectory, e.g. '{self._output_dir + self.dataset_name}'"
)
if dl_manager is None:
if download_config is None:
download_config = DownloadConfig(
cache_dir=self._cache_downloaded_dir,
force_download=download_mode == DownloadMode.FORCE_REDOWNLOAD,
force_extract=download_mode == DownloadMode.FORCE_REDOWNLOAD,
use_etag=False,
num_proc=num_proc,
token=self.token,
storage_options=self.storage_options,
) # We don't use etag for data files to speed up the process
dl_manager = DownloadManager(
dataset_name=self.dataset_name,
download_config=download_config,
data_dir=self.config.data_dir,
base_path=base_path,
record_checksums=(self._record_infos or verification_mode == VerificationMode.ALL_CHECKS),
)
is_local = not is_remote_filesystem(self._fs)
self.dl_manager = dl_manager
# Prevent parallel local disk operations
if is_local:
# Create parent directory of the output_dir to put the lock file in there
Path(self._output_dir).parent.mkdir(parents=True, exist_ok=True)
lock_path = self._output_dir + "_builder.lock"
# File locking only with local paths; no file locking on GCS or S3
with FileLock(lock_path) if is_local else contextlib.nullcontext():
# Check if the data already exists
data_exists = self._fs.exists(posixpath.join(self._output_dir, config.DATASET_INFO_FILENAME))
if data_exists and download_mode == DownloadMode.REUSE_DATASET_IF_EXISTS:
logger.info(f"Found cached dataset {self.dataset_name} ({self._output_dir})")
# We need to update the info in case some splits were added in the meantime
# for example when calling load_dataset from multiple workers.
self.info = self._load_info()
self.download_post_processing_resources(dl_manager)
return
logger.info(f"Generating dataset {self.dataset_name} ({self._output_dir})")
if is_local: # if cache dir is local, check for available space
if not has_sufficient_disk_space(
self.info.size_in_bytes or 0, directory=Path(self._output_dir).parent
):
raise OSError(
f"Not enough disk space. Needed: {size_str(self.info.size_in_bytes or 0)} (download: {size_str(self.info.download_size or 0)}, generated: {size_str(self.info.dataset_size or 0)}, post-processed: {size_str(self.info.post_processing_size or 0)})"
)
@contextlib.contextmanager
def incomplete_dir(dirname):
"""Create temporary dir for dirname and rename on exit."""
if not is_local:
self._fs.makedirs(dirname, exist_ok=True)
yield dirname
else:
tmp_dir = dirname + ".incomplete"
os.makedirs(tmp_dir, exist_ok=True)
try:
yield tmp_dir
if os.path.isdir(dirname):
shutil.rmtree(dirname)
# LocalFileSystem.mv does copy + rm, it is more efficient to simply rename a local directory
shutil.move(tmp_dir, dirname)
finally:
if os.path.exists(tmp_dir):
shutil.rmtree(tmp_dir)
# This log is intentional: we want the user to always have the information needed
# to cancel the download/preparation if needed.
# It comes right before the progress bar.
if self.info.size_in_bytes:
logger.info(
f"Downloading and preparing dataset {self.dataset_name}/{self.config.name} "
f"(download: {size_str(self.info.download_size)}, generated: {size_str(self.info.dataset_size)}, "
f"post-processed: {size_str(self.info.post_processing_size)}, "
f"total: {size_str(self.info.size_in_bytes)}) to {self._output_dir}..."
)
else:
_dest = self._fs._strip_protocol(self._output_dir) if is_local else self._output_dir
logger.info(f"Downloading and preparing dataset {self.dataset_name}/{self.config.name} to {_dest}...")
self._check_manual_download(dl_manager)
# Create a tmp dir and rename to self._output_dir on successful exit.
with incomplete_dir(self._output_dir) as tmp_output_dir:
# Temporarily assign _output_dir to tmp_data_dir to avoid having to forward
# it to every sub function.
with temporary_assignment(self, "_output_dir", tmp_output_dir):
prepare_split_kwargs = {"file_format": file_format}
if max_shard_size is not None:
prepare_split_kwargs["max_shard_size"] = max_shard_size
if num_proc is not None:
prepare_split_kwargs["num_proc"] = num_proc
self._download_and_prepare(
dl_manager=dl_manager,
verification_mode=verification_mode,
**prepare_split_kwargs,
**download_and_prepare_kwargs,
)
# Sync info
self.info.dataset_size = sum(split.num_bytes for split in self.info.splits.values())
self.info.download_checksums = dl_manager.get_recorded_sizes_checksums()
if self.info.download_size is not None:
self.info.size_in_bytes = self.info.dataset_size + self.info.download_size
# Save info
self._save_info()
# Download post processing resources
self.download_post_processing_resources(dl_manager)
logger.info(
f"Dataset {self.dataset_name} downloaded and prepared to {self._output_dir}. "
f"Subsequent calls will reuse this data."
)
def _check_manual_download(self, dl_manager):
if self.manual_download_instructions is not None and dl_manager.manual_dir is None:
raise ManualDownloadError(
textwrap.dedent(
f"""\
The dataset {self.dataset_name} with config {self.config.name} requires manual data.
Please follow the manual download instructions:
{self.manual_download_instructions}
Manual data can be loaded with:
datasets.load_dataset("{self.repo_id or self.dataset_name}", data_dir="<path/to/manual/data>")"""
)
)
def _download_and_prepare(self, dl_manager, verification_mode, **prepare_split_kwargs):
"""Downloads and prepares dataset for reading.
This is the internal implementation to override; it is called when the user calls
`download_and_prepare`. It should download all required data and generate
the pre-processed datasets files.
Args:
dl_manager ([`DownloadManager`]):
`DownloadManager` used to download and cache data.
verification_mode ([`VerificationMode`]):
if `ALL_CHECKS`, perform all the verifications including checksums.
if `BASIC_CHECKS`, do not perform checksums, only perform split tests.
if `NO_CHECKS`, do not perform any verification.
prepare_split_kwargs: Additional options, such as `file_format`, `max_shard_size`
"""
# Generating data for all splits
split_dict = SplitDict(dataset_name=self.dataset_name)
split_generators_kwargs = self._make_split_generators_kwargs(prepare_split_kwargs)
split_generators = self._split_generators(dl_manager, **split_generators_kwargs)
# Checksums verification
if verification_mode == VerificationMode.ALL_CHECKS and dl_manager.record_checksums:
verify_checksums(
self.info.download_checksums, dl_manager.get_recorded_sizes_checksums(), "dataset source files"
)
# Build splits
for split_generator in split_generators:
if str(split_generator.split_info.name).lower() == "all":
raise ValueError(
"`all` is a special split keyword corresponding to the "
"union of all splits, so cannot be used as key in "
"._split_generators()."
)
logger.info(f"Generating {split_generator.split_info.name} split")
split_dict.add(split_generator.split_info)
try:
# Prepare split will record examples associated to the split
self._prepare_split(split_generator, **prepare_split_kwargs)
except OSError as e:
raise OSError(
"Cannot find data file. "
+ (self.manual_download_instructions or "")
+ "\nOriginal error:\n"
+ str(e)
) from None
# If check_duplicates is set to True, catch DuplicatedKeysError and re-raise it with a fix message
except DuplicatedKeysError as e:
raise DuplicatedKeysError(
e.key,
e.duplicate_key_indices,
fix_msg=f"To avoid duplicate keys, please fix the dataset script {self.name}.py",
) from None
dl_manager.manage_extracted_files()
if verification_mode == VerificationMode.BASIC_CHECKS or verification_mode == VerificationMode.ALL_CHECKS:
verify_splits(self.info.splits, split_dict)
# Update the info object with the splits.
self.info.splits = split_dict
self.info.download_size = dl_manager.downloaded_size
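# Flow summary (descriptive comment): _split_generators() returns one SplitGenerator per
# split; for each of them _prepare_split() writes the arrow/parquet shards and records the
# number of examples and bytes in split_dict, which is then stored in self.info.splits.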
def download_post_processing_resources(self, dl_manager):
for split in self.info.splits or []:
for resource_name, resource_file_name in self._post_processing_resources(split).items():
if is_remote_filesystem(self._fs):
raise NotImplementedError(f"Post processing is not supported on filesystem {self._fs}")
if os.sep in resource_file_name:
raise ValueError(f"Resources shouldn't be in a sub-directory: {resource_file_name}")
resource_path = os.path.join(self._output_dir, resource_file_name)
if not os.path.exists(resource_path):
downloaded_resource_path = self._download_post_processing_resources(
split, resource_name, dl_manager
)
if downloaded_resource_path:
logger.info(f"Downloaded post-processing resource {resource_name} as {resource_file_name}")
shutil.move(downloaded_resource_path, resource_path)
def _load_info(self) -> DatasetInfo:
return DatasetInfo.from_directory(self._output_dir, storage_options=self._fs.storage_options)
def _save_info(self):
file_lock = (
FileLock(self._output_dir + "_info.lock")
if not is_remote_filesystem(self._fs)
else contextlib.nullcontext()
)
with file_lock:
self.info.write_to_directory(self._output_dir, storage_options=self._fs.storage_options)
def _save_infos(self):
file_lock = (
FileLock(self._output_dir + "_infos.lock")
if not is_remote_filesystem(self._fs)
else contextlib.nullcontext()
)
with file_lock:
DatasetInfosDict(**{self.config.name: self.info}).write_to_directory(self.get_imported_module_dir())
def _make_split_generators_kwargs(self, prepare_split_kwargs):
"""Get kwargs for `self._split_generators()` from `prepare_split_kwargs`."""
del prepare_split_kwargs
return {}
def as_dataset(
self,
split: Optional[Split] = None,
run_post_process=True,
verification_mode: Optional[Union[VerificationMode, str]] = None,
in_memory=False,
) -> Union[Dataset, DatasetDict]:
"""Return a Dataset for the specified split.
Args:
split (`datasets.Split`):
Which subset of the data to return.
run_post_process (`bool`, defaults to `True`):
Whether to run post-processing dataset transforms and/or add
indexes.
verification_mode ([`VerificationMode`] or `str`, defaults to `BASIC_CHECKS`):
Verification mode determining the checks to run on the
downloaded/processed dataset information (checksums/size/splits/...).
<Added version="2.9.1"/>
in_memory (`bool`, defaults to `False`):
Whether to copy the data in-memory.
Returns:
datasets.Dataset
Example:
```py
>>> from datasets import load_dataset_builder
>>> builder = load_dataset_builder('cornell-movie-review-data/rotten_tomatoes')
>>> builder.download_and_prepare()
>>> ds = builder.as_dataset(split='train')
>>> ds
Dataset({
features: ['text', 'label'],
num_rows: 8530
})
```
"""
if self._file_format is not None and self._file_format != "arrow":
raise FileFormatError('Loading a dataset not written in the "arrow" format is not supported.')
if is_remote_filesystem(self._fs):
raise NotImplementedError(f"Loading a dataset cached in a {type(self._fs).__name__} is not supported.")
if not os.path.exists(self._output_dir):
raise FileNotFoundError(
f"Dataset {self.dataset_name}: could not find data in {self._output_dir}. Please make sure to call "
"builder.download_and_prepare(), or use "
"datasets.load_dataset() before trying to access the Dataset object."
)
logger.debug(f"Constructing Dataset for split {split or ', '.join(self.info.splits)}, from {self._output_dir}")
# By default, return all splits
if split is None:
split = {s: s for s in self.info.splits}
verification_mode = VerificationMode(verification_mode or VerificationMode.BASIC_CHECKS)
# Create a dataset for each of the given splits
datasets = map_nested(
partial(
self._build_single_dataset,
run_post_process=run_post_process,
verification_mode=verification_mode,
in_memory=in_memory,
),
split,
map_tuple=True,
disable_tqdm=True,
)
if isinstance(datasets, dict):
datasets = DatasetDict(datasets)
return datasets
def _build_single_dataset(
self,
split: Union[str, ReadInstruction, Split],
run_post_process: bool,
verification_mode: VerificationMode,
in_memory: bool = False,
):
"""as_dataset for a single split."""
if not isinstance(split, ReadInstruction):
split = str(split)
if split == "all":
split = "+".join(self.info.splits.keys())
split = Split(split)
# Build base dataset
ds = self._as_dataset(
split=split,
in_memory=in_memory,
)
if run_post_process:
for resource_file_name in self._post_processing_resources(split).values():
if os.sep in resource_file_name:
raise ValueError(f"Resources shouldn't be in a sub-directory: {resource_file_name}")
resources_paths = {
resource_name: os.path.join(self._output_dir, resource_file_name)
for resource_name, resource_file_name in self._post_processing_resources(split).items()
}
post_processed = self._post_process(ds, resources_paths)
if post_processed is not None:
ds = post_processed
recorded_checksums = {}
record_checksums = False
for resource_name, resource_path in resources_paths.items():
size_checksum = get_size_checksum_dict(resource_path)
recorded_checksums[resource_name] = size_checksum
if verification_mode == VerificationMode.ALL_CHECKS and record_checksums:
if self.info.post_processed is None or self.info.post_processed.resources_checksums is None:
expected_checksums = None
else:
expected_checksums = self.info.post_processed.resources_checksums.get(split)
verify_checksums(expected_checksums, recorded_checksums, "post processing resources")
if self.info.post_processed is None:
self.info.post_processed = PostProcessedInfo()
if self.info.post_processed.resources_checksums is None:
self.info.post_processed.resources_checksums = {}
self.info.post_processed.resources_checksums[str(split)] = recorded_checksums
self.info.post_processing_size = sum(
checksums_dict["num_bytes"]
for split_checksums_dicts in self.info.post_processed.resources_checksums.values()
for checksums_dict in split_checksums_dicts.values()
)
if self.info.dataset_size is not None and self.info.download_size is not None:
self.info.size_in_bytes = (
self.info.dataset_size + self.info.download_size + self.info.post_processing_size
)
self._save_info()
ds._info.post_processed = self.info.post_processed
ds._info.post_processing_size = self.info.post_processing_size
ds._info.size_in_bytes = self.info.size_in_bytes
if self.info.post_processed.features is not None:
if self.info.post_processed.features.type != ds.features.type:
raise ValueError(
f"Post-processed features info don't match the dataset:\nGot\n{self.info.post_processed.features}\nbut expected something like\n{ds.features}"
)
else:
ds.info.features = self.info.post_processed.features
return ds
def _as_dataset(self, split: Union[ReadInstruction, Split] = Split.TRAIN, in_memory: bool = False) -> Dataset:
"""Constructs a `Dataset`.
This is the internal implementation to override; it is called when the user calls
`as_dataset`. It should read the pre-processed datasets files and generate
the `Dataset` object.
Args:
split (`datasets.Split`):
which subset of the data to read.
in_memory (`bool`, defaults to `False`):
Whether to copy the data in-memory.
Returns:
`Dataset`
"""
cache_dir = self._fs._strip_protocol(self._output_dir)
dataset_name = self.dataset_name
if self._check_legacy_cache():
dataset_name = self.name
dataset_kwargs = ArrowReader(cache_dir, self.info).read(
name=dataset_name,
instructions=split,
split_infos=self.info.splits.values(),
in_memory=in_memory,
)
fingerprint = self._get_dataset_fingerprint(split)
return Dataset(fingerprint=fingerprint, **dataset_kwargs)
def _get_dataset_fingerprint(self, split: Union[ReadInstruction, Split]) -> str:
"""The dataset fingerprint is the hash of the relative directory dataset_name/config_name/version/hash, as well as the split specs."""
hasher = Hasher()
hasher.update(Path(self._relative_data_dir()).as_posix())
hasher.update(str(split)) # for example: train, train+test, train[:10%], test[:33%](pct1_dropremainder)
fingerprint = hasher.hexdigest()
return fingerprint
def as_streaming_dataset(
self,
split: Optional[str] = None,
base_path: Optional[str] = None,
) -> Union[dict[str, IterableDataset], IterableDataset]:
if is_remote_filesystem(self._fs):
raise NotImplementedError(
f"Loading a streaming dataset cached in a {type(self._fs).__name__} is not supported yet."
)
dl_manager = StreamingDownloadManager(
base_path=base_path or self.base_path,
download_config=DownloadConfig(token=self.token, storage_options=self.storage_options),
dataset_name=self.dataset_name,
data_dir=self.config.data_dir,
)
self._check_manual_download(dl_manager)
splits_generators = {sg.name: sg for sg in self._split_generators(dl_manager)}
# By default, return all splits
if split is None:
splits_generator = splits_generators
elif split in splits_generators:
splits_generator = splits_generators[split]
else:
raise ValueError(f"Bad split: {split}. Available splits: {list(splits_generators)}")
# Create a dataset for each of the given splits
datasets = map_nested(
self._as_streaming_dataset_single,
splits_generator,
map_tuple=True,
)
if isinstance(datasets, dict):
datasets = IterableDatasetDict(datasets)
return datasets
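# Illustrative usage sketch (comment only, not part of the library code): streaming
# reads examples on the fly from the raw data files instead of the prepared Arrow cache,
# so no call to `download_and_prepare()` is needed, e.g.:
#
#     streaming_train = builder.as_streaming_dataset(split="train")  # an `IterableDataset`
#     for example in streaming_train.take(3):
#         print(example)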
def _as_streaming_dataset_single(
self,
splits_generator,
) -> IterableDataset:
ex_iterable = self._get_examples_iterable_for_split(splits_generator)
# add auth to be able to access and decode audio/image files from private repositories.
token_per_repo_id = {self.repo_id: self.token} if self.repo_id else {}
return IterableDataset(
ex_iterable, info=self.info, split=splits_generator.name, token_per_repo_id=token_per_repo_id
)
def _post_process(self, dataset: Dataset, resources_paths: Mapping[str, str]) -> Optional[Dataset]:
"""Run dataset transforms or add indexes"""
return None
def _post_processing_resources(self, split: str) -> dict[str, str]:
"""Mapping resource_name -> resource_file_name"""
return {}
def _download_post_processing_resources(
self, split: str, resource_name: str, dl_manager: DownloadManager
) -> Optional[str]:
"""Download the resource using the download manager and return the downloaded path."""
return None
@abc.abstractmethod
def _split_generators(self, dl_manager: Union[DownloadManager, StreamingDownloadManager]):
"""Specify feature dictionary generators and dataset splits.
This function returns a list of `SplitGenerator`s defining how to generate
data and what splits to use.
Example:
return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={'file': 'train_data.zip'},
),
datasets.SplitGenerator(
name=datasets.Split.TEST,
gen_kwargs={'file': 'test_data.zip'},
),
]
The above code will first call `_generate_examples(file='train_data.zip')`
to write the train data, then `_generate_examples(file='test_data.zip')` to
write the test data.
Datasets are typically split into different subsets to be used at various
stages of training and evaluation.
Note that for datasets without a `VALIDATION` split, you can use a
fraction of the `TRAIN` data for evaluation as you iterate on your model
so as not to overfit to the `TEST` data.
For downloads and extractions, use the given `download_manager`.
Note that the `DownloadManager` caches downloads, so it is fine to have each
generator attempt to download the source data.
A good practice is to download all data in this function, and then
distribute the relevant parts to each split with the `gen_kwargs` argument.
Args:
dl_manager (`Union[DownloadManager, StreamingDownloadManager]`):
Download manager to download the data
Returns:
`list<SplitGenerator>`.
"""
raise NotImplementedError()
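# Hedged sketch (comment only): a typical subclass implementation of `_split_generators`,
# assuming hypothetical data file URLs. The download manager returns local or
# streaming-compatible paths that are forwarded to the example generator via `gen_kwargs`:
#
#     def _split_generators(self, dl_manager):
#         files = dl_manager.download_and_extract(
#             {"train": "https://example.com/train.jsonl", "test": "https://example.com/test.jsonl"}
#         )
#         return [
#             SplitGenerator(name=Split.TRAIN, gen_kwargs={"filepath": files["train"]}),
#             SplitGenerator(name=Split.TEST, gen_kwargs={"filepath": files["test"]}),
#         ]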
@abc.abstractmethod
def _prepare_split(
self,
split_generator: SplitGenerator,
file_format: str = "arrow",
max_shard_size: Optional[Union[str, int]] = None,
num_proc: Optional[int] = None,
**kwargs,
):
"""Generate the examples and record them on disk.
Args:
split_generator (`SplitGenerator`):
Split generator to process
file_format (`str`, *optional*):
Format of the data files in which the dataset will be written.
Supported formats: "arrow", "parquet". Defaults to the "arrow" format.
max_shard_size (`Union[str, int]`, *optional*):
Maximum number of bytes written per shard, default is "500MB".
The size is based on uncompressed data size, so in practice your shard files may be smaller than
`max_shard_size` thanks to Parquet compression for example.
num_proc (`int`, *optional*, defaults to `None`):
Number of processes when downloading and generating the dataset locally.
Multiprocessing is disabled by default.
<Added version="2.7.0"/>
**kwargs: Additional kwargs forwarded from _download_and_prepare
"""
raise NotImplementedError()
def _get_examples_iterable_for_split(self, split_generator: SplitGenerator) -> ExamplesIterable:
"""Generate the examples on the fly.
Args:
split_generator (`SplitGenerator`):
Split generator to process
"""
raise NotImplementedError()
class GeneratorBasedBuilder(DatasetBuilder):
"""Base class for datasets with data generation based on dict generators.
`GeneratorBasedBuilder` is a convenience class that abstracts away much
of the data writing and reading of `DatasetBuilder`. It expects subclasses to
define the dataset splits (`_split_generators`) and to implement a generator of
feature dictionaries across those splits (`_generate_examples`). See the method
docstrings for details.
"""
@abc.abstractmethod
def _generate_examples(self, **kwargs):
"""Default function generating examples for each `SplitGenerator`.
This function preprocesses the examples from the raw data into the preprocessed
dataset files.
This function is called once for each `SplitGenerator` defined in
`_split_generators`. The examples yielded here will be written on
disk.
Args:
**kwargs (additional keyword arguments):
Arguments forwarded from the SplitGenerator.gen_kwargs
Yields:
key: `str` or `int`, a unique deterministic example identification key.
* Unique: An error will be raised if two examples are yielded with the
same key.
* Deterministic: When generating the dataset twice, the same example
should have the same key.
Good keys can be the image id, or line number if examples are extracted
from a text file.
The key will be hashed and sorted to shuffle examples deterministically,
so that generating the dataset multiple times keeps examples in the
same order.
example: `dict<str feature_name, feature_value>`, a feature dictionary
ready to be encoded and written to disk. The example will be
encoded with `self.info.features.encode_example({...})`.
"""
raise NotImplementedError()
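# Hedged sketch (comment only): a minimal `_generate_examples` for a hypothetical
# JSON-lines file passed via `gen_kwargs` (assuming the standard `json` module is
# imported); the line number serves as the unique, deterministic key:
#
#     def _generate_examples(self, filepath):
#         with open(filepath, encoding="utf-8") as f:
#             for line_number, line in enumerate(f):
#                 record = json.loads(line)
#                 yield line_number, {"text": record["text"], "label": record["label"]}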
def _prepare_split(
self,
split_generator: SplitGenerator,
check_duplicate_keys: bool,
file_format="arrow",
num_proc: Optional[int] = None,
max_shard_size: Optional[Union[int, str]] = None,
):
max_shard_size = convert_file_size_to_int(max_shard_size or config.MAX_SHARD_SIZE)
if self.info.splits is not None:
split_info = self.info.splits[split_generator.name]
else:
split_info = split_generator.split_info
SUFFIX = "-JJJJJ-SSSSS-of-NNNNN"
fname = f"{self.dataset_name}-{split_generator.name}{SUFFIX}.{file_format}"
fpath = posixpath.join(self._output_dir, fname)
if num_proc and num_proc > 1:
num_input_shards = _number_of_shards_in_gen_kwargs(split_generator.gen_kwargs)
if num_input_shards <= 1:
logger.warning(
f"Setting num_proc from {num_proc} back to 1 for the {split_info.name} split to disable multiprocessing as it only contains one shard."
)
num_proc = 1
elif num_input_shards < num_proc:
logger.warning(
f"Setting num_proc from {num_proc} to {num_input_shards} for the {split_info.name} split as it only contains {num_input_shards} shards."
)
num_proc = num_input_shards
pbar = hf_tqdm(
unit=" examples",
total=split_info.num_examples,
desc=f"Generating {split_info.name} split",
)
_prepare_split_args = {
"fpath": fpath,
"file_format": file_format,
"max_shard_size": max_shard_size,
"split_info": split_info,
"check_duplicate_keys": check_duplicate_keys,
}
if num_proc is None or num_proc == 1:
result = None
gen_kwargs = split_generator.gen_kwargs
job_id = 0
with pbar:
for job_id, done, content in self._prepare_split_single(
gen_kwargs=gen_kwargs, job_id=job_id, **_prepare_split_args
):
if done:
result = content
else:
pbar.update(content)
# wrapping everything into lists for consistency with the multiprocessed code path
assert result is not None, "Failed to retrieve results from prepare_split"
examples_per_job, bytes_per_job, features_per_job, shards_per_job, shard_lengths_per_job = (
[item] for item in result
)
else:
kwargs_per_job = [
{"gen_kwargs": gen_kwargs, "job_id": job_id, **_prepare_split_args}
for job_id, gen_kwargs in enumerate(
_split_gen_kwargs(split_generator.gen_kwargs, max_num_jobs=num_proc)
)
]
num_jobs = len(kwargs_per_job)
examples_per_job = [None] * num_jobs
bytes_per_job = [None] * num_jobs
features_per_job = [None] * num_jobs
shards_per_job = [None] * num_jobs
shard_lengths_per_job = [None] * num_jobs
with Pool(num_proc) as pool:
with pbar:
for job_id, done, content in iflatmap_unordered(
pool, self._prepare_split_single, kwargs_iterable=kwargs_per_job
):
if done:
# the content is the result of the job
(
examples_per_job[job_id],
bytes_per_job[job_id],
features_per_job[job_id],
shards_per_job[job_id],
shard_lengths_per_job[job_id],
) = content
else:
# the content is the number of examples progress update
pbar.update(content)
assert None not in examples_per_job, (
f"Failed to retrieve results from prepare_split: result list {examples_per_job} still contains None - at least one worker failed to return its results"
)
total_shards = sum(shards_per_job)
total_num_examples = sum(examples_per_job)
total_num_bytes = sum(bytes_per_job)
features = features_per_job[0]
split_generator.split_info.num_examples = total_num_examples
split_generator.split_info.num_bytes = total_num_bytes
# should rename everything at the end
logger.debug(f"Renaming {total_shards} shards.")
if total_shards > 1:
# use the -SSSSS-of-NNNNN pattern
def _rename_shard(shard_and_job: tuple[int, int]):
shard_id, job_id = shard_and_job
global_shard_id = sum(shards_per_job[:job_id]) + shard_id
self._rename(
fpath.replace("SSSSS", f"{shard_id:05d}").replace("JJJJJ", f"{job_id:05d}"),
fpath.replace("JJJJJ-SSSSS", f"{global_shard_id:05d}").replace("NNNNN", f"{total_shards:05d}"),
)
shards_and_jobs = [
(shard_id, job_id)
for job_id, num_shards in enumerate(shards_per_job)
for shard_id in range(num_shards)
]
thread_map(_rename_shard, shards_and_jobs, disable=True, max_workers=64)
split_generator.split_info.shard_lengths = [
shard_length for shard_lengths in shard_lengths_per_job for shard_length in shard_lengths
]
else:
# don't use any pattern
shard_id, job_id = 0, 0
self._rename(
fpath.replace("SSSSS", f"{shard_id:05d}").replace("JJJJJ", f"{job_id:05d}"),
fpath.replace(SUFFIX, ""),
)
if self.info.features is None:
self.info.features = features
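# Hedged note on the shard renaming above (illustrative values only): each job first
# writes files following the "-JJJJJ-SSSSS-of-NNNNN" suffix, where JJJJJ is the job id
# and SSSSS is the shard id local to that job; once all jobs are done, the shards are
# renamed to a single global "SSSSS-of-NNNNN" numbering. For example, with
# shards_per_job = [2, 3] (5 shards total), the shard (shard_id=0, job_id=1) gets
# global_shard_id = sum(shards_per_job[:1]) + 0 = 2, i.e. roughly:
#
#     "my_dataset-train-00001-00000-of-NNNNN.arrow" -> "my_dataset-train-00002-of-00005.arrow"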
def _prepare_split_single(
self,
gen_kwargs: dict,
fpath: str,
file_format: str,
max_shard_size: int,
split_info: SplitInfo,
check_duplicate_keys: bool,
job_id: int,
) -> Iterable[tuple[int, bool, Union[int, tuple]]]:
generator = self._generate_examples(**gen_kwargs)
writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter
embed_local_files = file_format == "parquet"
shard_lengths = []
total_num_examples, total_num_bytes = 0, 0
shard_id = 0
num_examples_progress_update = 0
try:
writer = writer_class(
features=self.info.features,
path=fpath.replace("SSSSS", f"{shard_id:05d}").replace("JJJJJ", f"{job_id:05d}"),
writer_batch_size=self._writer_batch_size,
hash_salt=split_info.name,
check_duplicates=check_duplicate_keys,
storage_options=self._fs.storage_options,
embed_local_files=embed_local_files,
)
try:
_time = time.time()
for key, record in generator:
if max_shard_size is not None and writer._num_bytes > max_shard_size:
num_examples, num_bytes = writer.finalize()
writer.close()
shard_lengths.append(num_examples)
total_num_examples += num_examples
total_num_bytes += num_bytes
shard_id += 1
writer = writer_class(
features=writer._features,
path=fpath.replace("SSSSS", f"{shard_id:05d}").replace("JJJJJ", f"{job_id:05d}"),
writer_batch_size=self._writer_batch_size,
hash_salt=split_info.name,
check_duplicates=check_duplicate_keys,
storage_options=self._fs.storage_options,
embed_local_files=embed_local_files,
)
example = self.info.features.encode_example(record) if self.info.features is not None else record
writer.write(example, key)
num_examples_progress_update += 1
if time.time() > _time + config.PBAR_REFRESH_TIME_INTERVAL:
_time = time.time()
yield job_id, False, num_examples_progress_update
num_examples_progress_update = 0
finally:
yield job_id, False, num_examples_progress_update
num_shards = shard_id + 1
num_examples, num_bytes = writer.finalize()
writer.close()
shard_lengths.append(num_examples)
total_num_examples += num_examples
total_num_bytes += num_bytes
except Exception as e:
# Ignore the writer's error about no examples being written to the file if it was caused by an error in _generate_examples before the first example was yielded
if isinstance(e, SchemaInferenceError) and e.__context__ is not None:
e = e.__context__
raise DatasetGenerationError("An error occurred while generating the dataset") from e
yield job_id, True, (total_num_examples, total_num_bytes, writer._features, num_shards, shard_lengths)
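# Hedged note (comment only): `_prepare_split_single` is a generator of
# `(job_id, done, content)` tuples. While the job is running it yields `done=False` with
# `content` set to an incremental example count used to refresh the progress bar; the
# final tuple has `done=True` and `content` set to
# `(total_num_examples, total_num_bytes, features, num_shards, shard_lengths)`, which is
# how `_prepare_split` collects the per-job results in both the single-process and
# multiprocessed code paths above.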
def _download_and_prepare(self, dl_manager, verification_mode, **prepare_splits_kwargs):
super()._download_and_prepare(
dl_manager,
verification_mode,
check_duplicate_keys=verification_mode == VerificationMode.BASIC_CHECKS
or verification_mode == VerificationMode.ALL_CHECKS,
**prepare_splits_kwargs,
)
def _get_examples_iterable_for_split(self, split_generator: SplitGenerator) -> ExamplesIterable:
return ExamplesIterable(self._generate_examples, split_generator.gen_kwargs)
class ArrowBasedBuilder(DatasetBuilder):
"""Base class for datasets with data generation based on Arrow loading functions (CSV/JSON/Parquet)."""
@abc.abstractmethod
def _generate_tables(self, **kwargs):
"""Default function generating examples for each `SplitGenerator`.
This function preprocess the examples from the raw data to the preprocessed
dataset files.
This function is called once for each `SplitGenerator` defined in
`_split_generators`. The examples yielded here will be written on
disk.
Args:
**kwargs (additional keyword arguments):
Arguments forwarded from the SplitGenerator.gen_kwargs
Yields:
key: `str` or `int`, a unique deterministic example identification key.
* Unique: An error will be raised if two examples are yielded with the
same key.
* Deterministic: When generating the dataset twice, the same example
should have the same key.
Good keys can be the image id, or line number if examples are extracted
from a text file.
The key will be hashed and sorted to shuffle examples deterministically,
so that generating the dataset multiple times keeps examples in the
same order.
example: `pyarrow.Table`, a feature table
ready to be written to disk.
"""
raise NotImplementedError()
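# Hedged sketch (comment only): a minimal `_generate_tables` for hypothetical CSV files
# passed via `gen_kwargs`; each yielded value is a `pyarrow.Table` keyed by the file index:
#
#     def _generate_tables(self, files):
#         import pyarrow.csv as pacsv
#         for file_idx, file in enumerate(files):
#             yield file_idx, pacsv.read_csv(file)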
def _prepare_split(
self,
split_generator: SplitGenerator,
file_format: str = "arrow",
num_proc: Optional[int] = None,
max_shard_size: Optional[Union[str, int]] = None,
):
max_shard_size = convert_file_size_to_int(max_shard_size or config.MAX_SHARD_SIZE)
try:
split_info = self.info.splits[split_generator.name]
except Exception:
split_info = split_generator.split_info
SUFFIX = "-JJJJJ-SSSSS-of-NNNNN"
fname = f"{self.dataset_name}-{split_generator.name}{SUFFIX}.{file_format}"
fpath = posixpath.join(self._output_dir, fname)
if num_proc and num_proc > 1:
num_input_shards = _number_of_shards_in_gen_kwargs(split_generator.gen_kwargs)
if num_input_shards <= 1:
logger.warning(
f"Setting num_proc from {num_proc} back to 1 for the {split_info.name} split to disable multiprocessing as it only contains one shard."
)
num_proc = 1
elif num_input_shards < num_proc:
logger.warning(
f"Setting num_proc from {num_proc} to {num_input_shards} for the {split_info.name} split as it only contains {num_input_shards} shards."
)
num_proc = num_input_shards
pbar = hf_tqdm(
unit=" examples",
total=split_info.num_examples,
desc=f"Generating {split_info.name} split",
)
_prepare_split_args = {
"fpath": fpath,
"file_format": file_format,
"max_shard_size": max_shard_size,
}
if num_proc is None or num_proc == 1:
result = None
gen_kwargs = split_generator.gen_kwargs
job_id = 0
with pbar:
for job_id, done, content in self._prepare_split_single(
gen_kwargs=gen_kwargs, job_id=job_id, **_prepare_split_args
):
if done:
result = content
else:
pbar.update(content)
# wrapping everything into lists for consistency with the multiprocessed code path
assert result is not None, "Failed to retrieve results from prepare_split"
examples_per_job, bytes_per_job, features_per_job, shards_per_job, shard_lengths_per_job = (
[item] for item in result
)
else:
kwargs_per_job = [
{"gen_kwargs": gen_kwargs, "job_id": job_id, **_prepare_split_args}
for job_id, gen_kwargs in enumerate(
_split_gen_kwargs(split_generator.gen_kwargs, max_num_jobs=num_proc)
)
]
num_jobs = len(kwargs_per_job)
examples_per_job = [None] * num_jobs
bytes_per_job = [None] * num_jobs
features_per_job = [None] * num_jobs
shards_per_job = [None] * num_jobs
shard_lengths_per_job = [None] * num_jobs
with Pool(num_proc) as pool:
with pbar:
for job_id, done, content in iflatmap_unordered(
pool, self._prepare_split_single, kwargs_iterable=kwargs_per_job
):
if done:
# the content is the result of the job
(
examples_per_job[job_id],
bytes_per_job[job_id],
features_per_job[job_id],
shards_per_job[job_id],
shard_lengths_per_job[job_id],
) = content
else:
# the content is the number of examples progress update
pbar.update(content)
assert None not in examples_per_job, (
f"Failed to retrieve results from prepare_split: result list {examples_per_job} still contains None - at least one worker failed to return its results"
)
total_shards = sum(shards_per_job)
total_num_examples = sum(examples_per_job)
total_num_bytes = sum(bytes_per_job)
features = features_per_job[0]
split_generator.split_info.num_examples = total_num_examples
split_generator.split_info.num_bytes = total_num_bytes
# should rename everything at the end
logger.debug(f"Renaming {total_shards} shards.")
if total_shards > 1:
# use the -SSSSS-of-NNNNN pattern
def _rename_shard(shard_id_and_job: tuple[int, int]):
shard_id, job_id = shard_id_and_job
global_shard_id = sum(shards_per_job[:job_id]) + shard_id
self._rename(
fpath.replace("SSSSS", f"{shard_id:05d}").replace("JJJJJ", f"{job_id:05d}"),
fpath.replace("JJJJJ-SSSSS", f"{global_shard_id:05d}").replace("NNNNN", f"{total_shards:05d}"),
)
shard_ids_and_jobs = [
(shard_id, job_id)
for job_id, num_shards in enumerate(shards_per_job)
for shard_id in range(num_shards)
]
thread_map(_rename_shard, shard_ids_and_jobs, disable=True, max_workers=64)
split_generator.split_info.shard_lengths = [
shard_length for shard_lengths in shard_lengths_per_job for shard_length in shard_lengths
]
else:
# don't use any pattern
shard_id, job_id = 0, 0
self._rename(
fpath.replace("SSSSS", f"{shard_id:05d}").replace("JJJJJ", f"{job_id:05d}"),
fpath.replace(SUFFIX, ""),
)
if self.info.features is None:
self.info.features = features
def _prepare_split_single(
self, gen_kwargs: dict, fpath: str, file_format: str, max_shard_size: int, job_id: int
) -> Iterable[tuple[int, bool, Union[int, tuple]]]:
gen_kwargs = {k: tracked_list(v) if isinstance(v, list) else v for k, v in gen_kwargs.items()}
generator = self._generate_tables(**gen_kwargs)
writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter
embed_local_files = file_format == "parquet"
shard_lengths = []
total_num_examples, total_num_bytes = 0, 0
shard_id = 0
num_examples_progress_update = 0
try:
writer = writer_class(
features=self.info.features,
path=fpath.replace("SSSSS", f"{shard_id:05d}").replace("JJJJJ", f"{job_id:05d}"),
writer_batch_size=self._writer_batch_size,
storage_options=self._fs.storage_options,
embed_local_files=embed_local_files,
)
try:
_time = time.time()
for _, table in generator:
if max_shard_size is not None and writer._num_bytes > max_shard_size:
num_examples, num_bytes = writer.finalize()
writer.close()
shard_lengths.append(num_examples)
total_num_examples += num_examples
total_num_bytes += num_bytes
shard_id += 1
writer = writer_class(
features=writer._features,
path=fpath.replace("SSSSS", f"{shard_id:05d}").replace("JJJJJ", f"{job_id:05d}"),
writer_batch_size=self._writer_batch_size,
storage_options=self._fs.storage_options,
embed_local_files=embed_local_files,
)
try:
writer.write_table(table)
except CastError as cast_error:
raise DatasetGenerationCastError.from_cast_error(
cast_error=cast_error,
builder_name=self.info.builder_name,
gen_kwargs=gen_kwargs,
token=self.token,
)
num_examples_progress_update += len(table)
if time.time() > _time + config.PBAR_REFRESH_TIME_INTERVAL:
_time = time.time()
yield job_id, False, num_examples_progress_update
num_examples_progress_update = 0
finally:
yield job_id, False, num_examples_progress_update
num_shards = shard_id + 1
num_examples, num_bytes = writer.finalize()
writer.close()
shard_lengths.append(num_examples)
total_num_examples += num_examples
total_num_bytes += num_bytes
except Exception as e:
# Ignore the writer's error about no examples being written to the file if it was caused by an error in _generate_tables before the first table was yielded
if isinstance(e, SchemaInferenceError) and e.__context__ is not None:
e = e.__context__
if isinstance(e, DatasetGenerationError):
raise
raise DatasetGenerationError("An error occurred while generating the dataset") from e
yield job_id, True, (total_num_examples, total_num_bytes, writer._features, num_shards, shard_lengths)
def _get_examples_iterable_for_split(self, split_generator: SplitGenerator) -> ExamplesIterable:
return ArrowExamplesIterable(self._generate_tables, kwargs=split_generator.gen_kwargs)