|
import csv |
|
import os |
|
from collections import namedtuple |
|
from pathlib import Path |
|
from typing import Any, Callable, List, Optional, Tuple, Union |
|
|
|
import PIL |
|
import torch |
|
|
|
from .utils import check_integrity, download_file_from_google_drive, extract_archive, verify_str_arg |
|
from .vision import VisionDataset |
|
|
|
# Parsed annotation file as produced by ``CelebA._load_csv``:
#   header - column names taken from the header row (empty list when no header row was requested)
#   index  - first field of every data row
#   data   - remaining fields of every data row
CSV = namedtuple("CSV", ("header", "index", "data"))
|
|
|
|
|
class CelebA(VisionDataset):
    """`Large-scale CelebFaces Attributes (CelebA) Dataset <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_ Dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
        split (string): One of {'train', 'valid', 'test', 'all'}.
            Accordingly dataset is selected.
        target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``,
            or ``landmarks``. Can also be a list to output a tuple with all specified target types.
            The targets represent:

                - ``attr`` (Tensor shape=(40,) dtype=int): binary (0, 1) labels for attributes
                - ``identity`` (int): label for each person (data points with the same identity are the same person)
                - ``bbox`` (Tensor shape=(4,) dtype=int): bounding box (x, y, width, height)
                - ``landmarks`` (Tensor shape=(10,) dtype=int): landmark points (lefteye_x, lefteye_y, righteye_x,
                  righteye_y, nose_x, nose_y, leftmouth_x, leftmouth_y, rightmouth_x, rightmouth_y)

            Defaults to ``attr``. If empty, ``None`` will be returned as target.

        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    Raises:
        RuntimeError: If ``target_transform`` is given while ``target_type`` is empty, or if the
            dataset files are missing or fail the integrity check.

    .. warning::

        To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
    """

    # Sub-directory of ``root`` that holds every file of this dataset.
    base_folder = "celeba"

    # One entry per required file: (Google Drive file id, MD5 checksum, file name).
    file_list = [
        ("0B7EVK8r0v71pZjFTYXZWM3FlRnM", "00d2c5bc6d35e252742224ab0c1e8fcb", "img_align_celeba.zip"),
        ("0B7EVK8r0v71pblRyaVFSWGxPY0U", "75e246fa4810816ffd6ee81facbd244c", "list_attr_celeba.txt"),
        ("1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", "32bd1bd63d3c78cd57e08160ec5ed1e2", "identity_CelebA.txt"),
        ("0B7EVK8r0v71pbThiMVRxWXZ4dU0", "00566efa6fedff7a56946cd1c10f1c16", "list_bbox_celeba.txt"),
        ("0B7EVK8r0v71pd0FJY3Blby1HUTQ", "cc24ecafdb5b50baae59b03474781f8c", "list_landmarks_align_celeba.txt"),
        ("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"),
    ]

    def __init__(
        self,
        root: Union[str, Path],
        split: str = "train",
        target_type: Union[List[str], str] = "attr",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.split = split
        # Always store the requested target types as a list, even for a single string.
        if isinstance(target_type, list):
            self.target_type = target_type
        else:
            self.target_type = [target_type]

        if not self.target_type and self.target_transform is not None:
            raise RuntimeError("target_transform is specified but target_type is empty")

        # Validate ``split`` *before* any download or integrity check so that a typo
        # is reported immediately instead of after the files have been fetched.
        # The mapped value is the partition id stored in ``list_eval_partition.txt``;
        # ``None`` selects every sample.
        split_map = {
            "train": 0,
            "valid": 1,
            "test": 2,
            "all": None,
        }
        split_ = split_map[
            verify_str_arg(
                split.lower() if isinstance(split, str) else split,
                "split",
                tuple(split_map),
            )
        ]

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        splits = self._load_csv("list_eval_partition.txt")
        identity = self._load_csv("identity_CelebA.txt")
        # ``header=1``: row 0 is skipped and row 1 is used as the column-name row.
        bbox = self._load_csv("list_bbox_celeba.txt", header=1)
        landmarks_align = self._load_csv("list_landmarks_align_celeba.txt", header=1)
        attr = self._load_csv("list_attr_celeba.txt", header=1)

        if split_ is None:
            # split == "all": keep every row.
            mask = slice(None)
            self.filename = splits.index
        else:
            # ``flatten()`` rather than a dimension-less ``squeeze()``: the latter collapses a
            # single-row file or a single match to a 0-d tensor, which can neither be iterated
            # nor used as a per-row mask.
            mask = (splits.data == split_).flatten()
            self.filename = [splits.index[i] for i in torch.nonzero(mask).flatten().tolist()]

        self.identity = identity.data[mask]
        self.bbox = bbox.data[mask]
        self.landmarks_align = landmarks_align.data[mask]
        self.attr = attr.data[mask]

        # (x + 1) // 2 maps -1 -> 0 and 1 -> 1, giving the binary (0, 1) labels promised in the
        # class docstring (presumably the annotation file encodes attributes as -1 / 1).
        self.attr = torch.div(self.attr + 1, 2, rounding_mode="floor")
        self.attr_names = attr.header

    def _load_csv(
        self,
        filename: str,
        header: Optional[int] = None,
    ) -> CSV:
        """Parse one space-separated annotation file located in ``root/base_folder``.

        Args:
            filename: Name of the annotation file.
            header: Index of the row holding the column names. That row and every row before it
                are excluded from the data. ``None`` means the file has no header row.

        Returns:
            A ``CSV`` tuple with the header fields (empty list without a header), the first field
            of each data row, and the remaining fields of all data rows as an integer tensor.
        """
        with open(os.path.join(self.root, self.base_folder, filename)) as csv_file:
            # Runs of spaces are treated as a single separator.
            rows = list(csv.reader(csv_file, delimiter=" ", skipinitialspace=True))

        if header is None:
            headers = []
        else:
            headers = rows[header]
            rows = rows[header + 1 :]

        indices = [row[0] for row in rows]
        values = [[int(field) for field in row[1:]] for row in rows]

        return CSV(headers, indices, torch.tensor(values))

    def _check_integrity(self) -> bool:
        """Return ``True`` if all annotation files match their MD5 and the image folder exists."""
        for _, md5, filename in self.file_list:
            fpath = os.path.join(self.root, self.base_folder, filename)
            _, ext = os.path.splitext(filename)
            # Archives are exempt from the checksum test, so they may be deleted once extracted;
            # only the extracted image directory (checked below) is required.
            if ext not in [".zip", ".7z"] and not check_integrity(fpath, md5):
                return False

        # Only the presence of the extracted directory is checked, not the images themselves.
        return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba"))

    def download(self) -> None:
        """Fetch every file of ``file_list`` and extract the image archive, unless already present."""
        if self._check_integrity():
            return

        for file_id, md5, filename in self.file_list:
            download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5)

        extract_archive(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return ``(image, target)`` for the sample at ``index``.

        ``target`` is a single value when one target type was requested, a tuple (in the order of
        ``target_type``) when several were requested, and ``None`` when ``target_type`` is empty.

        Raises:
            ValueError: If ``target_type`` contains an unknown entry.
        """
        X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))

        target: Any = []
        for t in self.target_type:
            if t == "attr":
                target.append(self.attr[index, :])
            elif t == "identity":
                target.append(self.identity[index, 0])
            elif t == "bbox":
                target.append(self.bbox[index, :])
            elif t == "landmarks":
                target.append(self.landmarks_align[index, :])
            else:
                raise ValueError(f'Target type "{t}" is not recognized.')

        if self.transform is not None:
            X = self.transform(X)

        if target:
            target = tuple(target) if len(target) > 1 else target[0]

            # ``target_transform`` only runs when there is a target to transform.
            if self.target_transform is not None:
                target = self.target_transform(target)
        else:
            target = None

        return X, target

    def __len__(self) -> int:
        """Number of samples in the selected split."""
        return len(self.attr)

    def extra_repr(self) -> str:
        """Return the target-type and split lines, filled in from the instance attributes."""
        lines = ["Target type: {target_type}", "Split: {split}"]
        return "\n".join(lines).format(**self.__dict__)
|
|