# File size: 5,637 Bytes (revision 9c6594c)
import csv
import os
from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple, Union
from PIL import Image
from .utils import download_and_extract_archive
from .vision import VisionDataset
class Kitti(VisionDataset):
    """`KITTI <http://www.cvlibs.net/datasets/kitti/eval_object.php?obj_benchmark>`_ Dataset.

    It corresponds to the "left color images of object" dataset, for object detection.

    Args:
        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
            Expects the following folder structure if download=False:

            .. code::

                <root>
                    └── Kitti
                        └─ raw
                            ├── training
                            |   ├── image_2
                            |   └── label_2
                            └── testing
                                └── image_2
        train (bool, optional): Use ``train`` split if true, else ``test`` split.
            Defaults to ``train``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample
            and its target as entry and returns a transformed version.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    Raises:
        RuntimeError: If the expected folders for the selected split are missing
            (after the optional download step).
    """

    data_url = "https://s3.eu-central-1.amazonaws.com/avg-kitti/"
    resources = [
        "data_object_image_2.zip",
        "data_object_label_2.zip",
    ]
    image_dir_name = "image_2"
    labels_dir_name = "label_2"

    def __init__(
        self,
        root: Union[str, Path],
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(
            root,
            transform=transform,
            target_transform=target_transform,
            transforms=transforms,
        )
        # Parallel lists: images[i] is an image path, targets[i] its label path.
        # ``targets`` stays empty for the test split, which has no label folder.
        self.images: List[str] = []
        self.targets: List[str] = []
        self.train = train
        self._location = "training" if self.train else "testing"

        if download:
            self.download()
        if not self._check_exists():
            raise RuntimeError("Dataset not found. You may use download=True to download it.")

        image_dir = os.path.join(self._raw_folder, self._location, self.image_dir_name)
        labels_dir = (
            os.path.join(self._raw_folder, self._location, self.labels_dir_name) if self.train else None
        )
        # os.listdir() returns entries in arbitrary order; sorting keeps the
        # index -> sample mapping reproducible across runs and machines.
        for img_file in sorted(os.listdir(image_dir)):
            self.images.append(os.path.join(image_dir, img_file))
            if self.train:
                # The label file shares the image's base name (text before the first dot).
                self.targets.append(os.path.join(labels_dir, f"{img_file.split('.')[0]}.txt"))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Get item at a given index.

        Args:
            index (int): Index

        Returns:
            tuple: (image, target), where
            target is a list of dictionaries with the following keys:

            - type: str
            - truncated: float
            - occluded: int
            - alpha: float
            - bbox: float[4]
            - dimensions: float[3]
            - location: float[3]
            - rotation_y: float

            For the test split no labels are read and ``target`` is ``None``
            before ``transforms`` is applied.
        """
        image = Image.open(self.images[index])
        target = self._parse_target(index) if self.train else None
        if self.transforms:
            image, target = self.transforms(image, target)
        return image, target

    def _parse_target(self, index: int) -> List:
        """Parse the space-delimited label file for ``index`` into a list of dicts.

        Each non-blank line yields one dictionary; blank or whitespace-only
        lines are skipped instead of raising.
        """
        target = []
        # newline="" lets the csv module handle line endings itself, as its docs recommend.
        with open(self.targets[index], newline="") as inp:
            content = csv.reader(inp, delimiter=" ")
            for line in content:
                if not any(line):
                    # Empty row or only separators: nothing to parse.
                    continue
                target.append(
                    {
                        "type": line[0],
                        "truncated": float(line[1]),
                        "occluded": int(line[2]),
                        "alpha": float(line[3]),
                        "bbox": [float(x) for x in line[4:8]],
                        "dimensions": [float(x) for x in line[8:11]],
                        "location": [float(x) for x in line[11:14]],
                        "rotation_y": float(line[14]),
                    }
                )
        return target

    def __len__(self) -> int:
        """Return the number of images found for the selected split."""
        return len(self.images)

    @property
    def _raw_folder(self) -> str:
        """Directory holding the extracted data: ``<root>/<ClassName>/raw``."""
        return os.path.join(self.root, self.__class__.__name__, "raw")

    def _check_exists(self) -> bool:
        """Check if the data directory exists."""
        folders = [self.image_dir_name]
        if self.train:
            # Labels are only required for the training split.
            folders.append(self.labels_dir_name)
        return all(os.path.isdir(os.path.join(self._raw_folder, self._location, fname)) for fname in folders)

    def download(self) -> None:
        """Download the KITTI data if it doesn't exist already."""
        if self._check_exists():
            return

        os.makedirs(self._raw_folder, exist_ok=True)

        # download files
        for fname in self.resources:
            download_and_extract_archive(
                url=f"{self.data_url}{fname}",
                download_root=self._raw_folder,
                filename=fname,
            )
|