|
import os |
|
import shutil |
|
import tempfile |
|
from contextlib import contextmanager |
|
from pathlib import Path |
|
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union |
|
|
|
import torch |
|
|
|
from .folder import ImageFolder |
|
from .utils import check_integrity, extract_archive, verify_str_arg |
|
|
|
# For each part of the ILSVRC2012 release: (archive file name, expected MD5 checksum).
# The checksum is what ``_verify_archive`` checks the archive against before extraction.
ARCHIVE_META = dict(
    train=("ILSVRC2012_img_train.tar", "1d675b47d978889d74fa0da5fadfb00e"),
    val=("ILSVRC2012_img_val.tar", "29b22e2961454d5413ddabcf34fc5622"),
    devkit=("ILSVRC2012_devkit_t12.tar.gz", "fa75699e90414af021442c21a62c3abf"),
)

# Name of the file, inside the dataset root, that holds the metadata parsed from the devkit.
META_FILE = "meta.bin"
|
|
|
|
|
class ImageNet(ImageFolder):
    """`ImageNet <http://image-net.org/>`_ 2012 Classification Dataset.

    .. note::
        Before using this class, it is required to download ImageNet 2012 dataset from
        `here <https://image-net.org/challenges/LSVRC/2012/2012-downloads.php>`_ and
        place the files ``ILSVRC2012_devkit_t12.tar.gz`` and ``ILSVRC2012_img_train.tar``
        or ``ILSVRC2012_img_val.tar`` based on ``split`` in the root directory.

    On construction, the devkit is parsed into a meta file if no valid one exists.
    The archive for the requested split is extracted if its folder is missing.

    Args:
        root (str or ``pathlib.Path``): Root directory of the ImageNet Dataset.
        split (string, optional): The dataset split, supports ``train``, or ``val``.
        transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader,
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
            By default, it uses PIL as its image loader, but users could also pass in
            ``torchvision.io.decode_image`` for decoding image data into tensors directly.

    Attributes:
        classes (list): List of the class name tuples, one tuple per WordNet ID.
        class_to_idx (dict): Dict with items (class_name, class_index). Every name in a
            class tuple maps to that class's index. If a name appears in several tuples,
            the highest index wins.
        wnids (list): List of the WordNet IDs.
        wnid_to_idx (dict): Dict with items (wordnet_id, class_index).
        imgs (list): List of (image path, class_index) tuples
        targets (list): The class_index value for each image in the dataset
    """

    def __init__(self, root: Union[str, Path], split: str = "train", **kwargs: Any) -> None:
        dataset_root = os.path.expanduser(root)
        # ``split_folder`` (used by parse_archives and below) is derived from self.root,
        # so it must be set before anything else.
        self.root = dataset_root
        self.split = verify_str_arg(split, "split", ("train", "val"))

        self.parse_archives()
        names_by_wnid = load_meta_file(self.root)[0]

        super().__init__(self.split_folder, **kwargs)
        # The parent initializer is handed the split folder; put the dataset root back.
        self.root = dataset_root

        # The folder names found by ImageFolder are WordNet IDs; keep them under
        # wnid-specific attribute names and expose human-readable names instead.
        self.wnids = self.classes
        self.wnid_to_idx = self.class_to_idx
        self.classes = [names_by_wnid[wnid] for wnid in self.wnids]
        self.class_to_idx = {
            name: position
            for position, names in enumerate(self.classes)
            for name in names
        }

    def parse_archives(self) -> None:
        """Create the meta file and split folder from the archives when they are missing."""
        meta_path = os.path.join(self.root, META_FILE)
        if not check_integrity(meta_path):
            parse_devkit_archive(self.root)

        if os.path.isdir(self.split_folder):
            return

        parsers_by_split = {"train": parse_train_archive, "val": parse_val_archive}
        parser = parsers_by_split.get(self.split)
        if parser is not None:
            parser(self.root)

    @property
    def split_folder(self) -> str:
        """Directory holding the images of the selected split (``root/split``)."""
        return os.path.join(self.root, self.split)

    def extra_repr(self) -> str:
        return "Split: {}".format(self.split)
|
|
|
|
|
def load_meta_file(
    root: Union[str, Path], file: Optional[str] = None
) -> Tuple[Dict[str, Tuple[str, ...]], List[str]]:
    """Load the metadata file written by :func:`parse_devkit_archive`.

    Args:
        root (str or ``pathlib.Path``): Directory containing the meta file.
        file (str, optional): Name of the meta file inside ``root``. Defaults to
            ``META_FILE``.

    Returns:
        The ``(wnid_to_classes, val_wnids)`` tuple that ``parse_devkit_archive``
        saves. ``wnid_to_classes`` maps each WordNet ID to its tuple of class names.
        ``val_wnids`` lists one WordNet ID per validation image, in ground-truth file
        order.

    Raises:
        RuntimeError: If the file is missing or rejected by ``check_integrity``.
    """
    meta_path = os.path.join(root, META_FILE if file is None else file)

    if not check_integrity(meta_path):
        msg = (
            "The meta file {} is not present in the root directory or is corrupted. "
            "This file is automatically created by the ImageNet dataset."
        )
        raise RuntimeError(msg.format(meta_path))

    # weights_only=True restricts unpickling to an allow-list of plain types, which
    # is all the meta file is expected to contain (dicts, lists, tuples, strings).
    return torch.load(meta_path, weights_only=True)
|
|
|
|
|
def _verify_archive(root: Union[str, Path], file: str, md5: str) -> None:
    """Raise ``RuntimeError`` if ``check_integrity`` rejects ``root/file`` for ``md5``.

    Args:
        root: Directory expected to contain the archive.
        file: Archive file name relative to ``root``.
        md5: Expected MD5 checksum of the archive.
    """
    archive_path = os.path.join(root, file)
    if check_integrity(archive_path, md5):
        return

    template = (
        "The archive {} is not present in the root directory or is corrupted. "
        "You need to download it externally and place it in {}."
    )
    raise RuntimeError(template.format(file, root))
|
|
|
|
|
def parse_devkit_archive(root: Union[str, Path], file: Optional[str] = None) -> None:
    """Parse the devkit archive of the ImageNet2012 classification dataset and save
    the meta information in a binary file.

    The archive is verified with ``_verify_archive`` and extracted into a temporary
    directory that is removed afterwards. ``(wnid_to_classes, val_wnids)`` is then
    written to ``META_FILE`` inside ``root`` with ``torch.save``.

    Args:
        root (str or ``pathlib.Path``): Root directory containing the devkit archive
        file (str, optional): Name of devkit archive. Defaults to
            'ILSVRC2012_devkit_t12.tar.gz'

    Raises:
        RuntimeError: If the archive is missing or fails the checksum verification.
    """
    import scipy.io as sio

    def parse_meta_mat(devkit_root: str) -> Tuple[Dict[int, str], Dict[str, Tuple[str, ...]]]:
        """Return ``(idx_to_wnid, wnid_to_classes)`` for the leaf synsets in meta.mat."""
        metafile = os.path.join(devkit_root, "data", "meta.mat")
        meta = sio.loadmat(metafile, squeeze_me=True)["synsets"]
        # Field 4 of each synset record is its number of children. Keeping only
        # childless synsets selects the leaves of the hierarchy.
        nums_children = list(zip(*meta))[4]
        meta = [meta[idx] for idx, num_children in enumerate(nums_children) if num_children == 0]
        # Fields 0-2: numeric id, WordNet ID, and the comma-separated class names.
        idcs, wnids, classes = list(zip(*meta))[:3]
        classes = [tuple(clss.split(", ")) for clss in classes]
        idx_to_wnid = dict(zip(idcs, wnids))
        wnid_to_classes = dict(zip(wnids, classes))
        return idx_to_wnid, wnid_to_classes

    def parse_val_groundtruth_txt(devkit_root: str) -> List[int]:
        """Return the numeric class id listed on each line of the validation ground truth."""
        file = os.path.join(devkit_root, "data", "ILSVRC2012_validation_ground_truth.txt")
        with open(file) as txtfh:
            return [int(line) for line in txtfh]

    archive_meta = ARCHIVE_META["devkit"]
    if file is None:
        file = archive_meta[0]
    # The known checksum is used even when a custom archive name is given.
    md5 = archive_meta[1]

    _verify_archive(root, file, md5)

    # The stdlib TemporaryDirectory replaces the hand-rolled mkdtemp/rmtree context
    # manager. It still removes the directory even if parsing raises.
    with tempfile.TemporaryDirectory() as tmp_dir:
        extract_archive(os.path.join(root, file), tmp_dir)

        devkit_root = os.path.join(tmp_dir, "ILSVRC2012_devkit_t12")
        idx_to_wnid, wnid_to_classes = parse_meta_mat(devkit_root)
        val_idcs = parse_val_groundtruth_txt(devkit_root)
        # Translate each validation image's numeric class id into its WordNet ID.
        val_wnids = [idx_to_wnid[idx] for idx in val_idcs]

        torch.save((wnid_to_classes, val_wnids), os.path.join(root, META_FILE))
|
|
|
|
|
def parse_train_archive(root: Union[str, Path], file: Optional[str] = None, folder: str = "train") -> None:
    """Parse the train images archive of the ImageNet2012 classification dataset and
    prepare it for usage with the ImageNet dataset.

    The archive is verified with ``_verify_archive`` and extracted into
    ``root/folder``. Each archive file found there is extracted into a sibling
    directory named after it without its extension. The inner archives are
    extracted with ``remove_finished=True``.

    Args:
        root (str or ``pathlib.Path``): Root directory containing the train images archive
        file (str, optional): Name of train images archive. Defaults to
            'ILSVRC2012_img_train.tar'
        folder (str, optional): Optional name for train images folder. Defaults to
            'train'

    Raises:
        RuntimeError: If the archive is missing or fails the checksum verification.
    """
    archive_meta = ARCHIVE_META["train"]
    if file is None:
        file = archive_meta[0]
    md5 = archive_meta[1]

    _verify_archive(root, file, md5)

    train_root = os.path.join(root, folder)
    extract_archive(os.path.join(root, file), train_root)

    # Only regular files can be inner archives. Directories, such as class folders
    # left over from an earlier (possibly interrupted) run, are skipped rather than
    # handed to extract_archive. Sorting makes the extraction order deterministic.
    entries = (os.path.join(train_root, name) for name in os.listdir(train_root))
    archives = sorted(path for path in entries if os.path.isfile(path))
    for archive in archives:
        extract_archive(archive, os.path.splitext(archive)[0], remove_finished=True)
|
|
|
|
|
def parse_val_archive(
    root: Union[str, Path], file: Optional[str] = None, wnids: Optional[List[str]] = None, folder: str = "val"
) -> None:
    """Parse the validation images archive of the ImageNet2012 classification dataset
    and prepare it for usage with the ImageNet dataset.

    The archive is extracted into ``root/folder``. Each image is then moved into a
    subdirectory named after its WordNet ID. The i-th entry of ``wnids`` labels the
    i-th image in sorted file-path order.

    Args:
        root (str or ``pathlib.Path``): Root directory containing the validation images archive
        file (str, optional): Name of validation images archive. Defaults to
            'ILSVRC2012_img_val.tar'
        wnids (list, optional): List of WordNet IDs of the validation images. If None
            is given, the IDs are loaded from the meta file in the root directory
        folder (str, optional): Optional name for validation images folder. Defaults to
            'val'

    Raises:
        RuntimeError: If the meta file (when ``wnids`` is None) or the archive is
            missing or fails verification.
    """
    archive_meta = ARCHIVE_META["val"]
    if file is None:
        file = archive_meta[0]
    md5 = archive_meta[1]
    if wnids is None:
        wnids = load_meta_file(root)[1]

    _verify_archive(root, file, md5)

    val_root = os.path.join(root, folder)
    extract_archive(os.path.join(root, file), val_root)

    # Only regular files are images. Class directories left by an earlier run must
    # not enter the sorted list, because sort order is what pairs images with wnids.
    entries = (os.path.join(val_root, name) for name in os.listdir(val_root))
    images = sorted(path for path in entries if os.path.isfile(path))

    for wnid in set(wnids):
        # exist_ok lets this run again over a folder that already has class dirs.
        os.makedirs(os.path.join(val_root, wnid), exist_ok=True)

    # zip stops at the shorter of the two sequences.
    for wnid, img_file in zip(wnids, images):
        shutil.move(img_file, os.path.join(val_root, wnid, os.path.basename(img_file)))
|
|