File size: 5,285 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import pathlib
from typing import Any, Callable, Optional, Tuple, Union

from PIL import Image

from .utils import _decompress, download_file_from_google_drive, verify_str_arg
from .vision import VisionDataset


class PCAM(VisionDataset):
    """`PCAM Dataset <https://github.com/basveeling/pcam>`_.

    The PatchCamelyon dataset is a binary classification dataset with 327,680
    color images (96px x 96px), extracted from histopathologic scans of lymph node
    sections. Each image is annotated with a binary label indicating presence of
    metastatic tissue.

    This dataset requires the ``h5py`` package which you can install with ``pip install h5py``.

    Args:
         root (str or ``pathlib.Path``): Root directory of the dataset.
         split (string, optional): The dataset split, supports ``"train"`` (default), ``"test"`` or ``"val"``.
         transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
             version. E.g., ``transforms.RandomCrop``.
         target_transform (callable, optional): A function/transform that takes in the target and transforms it.
         download (bool, optional): If True, downloads the dataset from the internet and puts it into ``root/pcam``. If
             dataset is already downloaded, it is not downloaded again.

             .. warning::

                To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.

    Raises:
        RuntimeError: If ``h5py`` cannot be imported, or if the files of the requested split are not
            present under ``root/pcam`` after the optional download step.
    """

    # For every split: role ("images" / "targets") -> (HDF5 file name, Google Drive ID, md5 hash).
    # The file is downloaded as ``<file name>.gz`` and the md5 is handed to the downloader for that archive.
    _FILES = {
        "train": {
            "images": (
                "camelyonpatch_level_2_split_train_x.h5",  # Data file name
                "1Ka0XfEMiwgCYPdTI-vv6eUElOBnKFKQ2",  # Google Drive ID
                "1571f514728f59376b705fc836ff4b63",  # md5 hash
            ),
            "targets": (
                "camelyonpatch_level_2_split_train_y.h5",
                "1269yhu3pZDP8UYFQs-NYs3FPwuK-nGSG",
                "35c2d7259d906cfc8143347bb8e05be7",
            ),
        },
        "test": {
            "images": (
                "camelyonpatch_level_2_split_test_x.h5",
                "1qV65ZqZvWzuIVthK8eVDhIwrbnsJdbg_",
                "d8c2d60d490dbd479f8199bdfa0cf6ec",
            ),
            "targets": (
                "camelyonpatch_level_2_split_test_y.h5",
                "17BHrSrwWKjYsOgTMmoqrIjDy6Fa2o_gP",
                "60a7035772fbdb7f34eb86d4420cf66a",
            ),
        },
        "val": {
            "images": (
                "camelyonpatch_level_2_split_valid_x.h5",
                "1hgshYGWK8V-eGRy8LToWJJgDU_rXWVJ3",
                "d5b63470df7cfa627aeec8b9dc0c066e",
            ),
            "targets": (
                "camelyonpatch_level_2_split_valid_y.h5",
                "1bH8ZRbhSVAhScTS0p9-ZzGnX91cHT3uO",
                "2b85f58b927af9964a4c15b8f7e8f179",
            ),
        },
    }

    def __init__(
        self,
        root: Union[str, pathlib.Path],
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ):
        # h5py is an optional dependency, so it is imported lazily and a missing
        # installation is reported with an actionable message.
        try:
            import h5py
        except ImportError as err:
            raise RuntimeError(
                "h5py is not found. This dataset needs to have h5py installed: please run pip install h5py"
            ) from err

        # Kept on the instance so the other methods can reach the module without
        # re-importing it. It is excluded from the pickled state (see __getstate__).
        self.h5py = h5py

        self._split = verify_str_arg(split, "split", ("train", "test", "val"))

        super().__init__(root, transform=transform, target_transform=target_transform)
        self._base_folder = pathlib.Path(self.root) / "pcam"

        if download:
            self._download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

    def __getstate__(self) -> dict:
        """Return the instance state without the ``h5py`` module reference.

        Module objects cannot be pickled, so leaving ``self.h5py`` in the state would make
        pickling or deep-copying the dataset fail (for example when a dataset has to be
        sent to another process).
        """
        state = self.__dict__.copy()
        state.pop("h5py", None)
        return state

    def __setstate__(self, state: dict) -> None:
        """Restore the instance state and re-attach the ``h5py`` module."""
        self.__dict__.update(state)
        import h5py

        self.h5py = h5py

    def __len__(self) -> int:
        """Return the number of images stored in the images file of the selected split."""
        images_file = self._FILES[self._split]["images"][0]
        # Explicit read-only mode: never create or modify the data file.
        with self.h5py.File(self._base_folder / images_file, "r") as images_data:
            return images_data["x"].shape[0]

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """Load one sample.

        Args:
            idx (int): Index of the sample.

        Returns:
            tuple: ``(image, target)`` where ``image`` is an RGB PIL image and ``target`` is an
            ``int`` label, each passed through ``transform`` / ``target_transform`` when set.
        """
        # The HDF5 files are opened per access (read-only) and closed right away,
        # so no open file handle is kept on the dataset instance.
        images_file = self._FILES[self._split]["images"][0]
        with self.h5py.File(self._base_folder / images_file, "r") as images_data:
            image = Image.fromarray(images_data["x"][idx]).convert("RGB")

        targets_file = self._FILES[self._split]["targets"][0]
        with self.h5py.File(self._base_folder / targets_file, "r") as targets_data:
            target = int(targets_data["y"][idx, 0, 0, 0])  # shape is [num_images, 1, 1, 1]

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            target = self.target_transform(target)

        return image, target

    def _check_exists(self) -> bool:
        """Return True when both the images and the targets file of the split exist on disk."""
        images_file = self._FILES[self._split]["images"][0]
        targets_file = self._FILES[self._split]["targets"][0]
        return all(self._base_folder.joinpath(h5_file).exists() for h5_file in (images_file, targets_file))

    def _download(self) -> None:
        """Download and decompress the files of the selected split unless they already exist."""
        if self._check_exists():
            return

        for file_name, file_id, md5 in self._FILES[self._split].values():
            # Files are fetched as ``<name>.gz`` and then decompressed.
            archive_name = file_name + ".gz"
            download_file_from_google_drive(file_id, str(self._base_folder), filename=archive_name, md5=md5)
            _decompress(str(self._base_folder / archive_name))