|
import json |
|
import os |
|
from collections import namedtuple |
|
from pathlib import Path |
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union |
|
|
|
from PIL import Image |
|
|
|
from .utils import extract_archive, iterable_to_str, verify_str_arg |
|
from .vision import VisionDataset |
|
|
|
|
|
class Cityscapes(VisionDataset):
    """`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory of dataset where directory ``leftImg8bit``
            and ``gtFine`` or ``gtCoarse`` are located.
        split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="fine"
            otherwise ``train``, ``train_extra`` or ``val``. Default: ``train``.
        mode (string, optional): The quality mode to use, ``fine`` or ``coarse``. Default: ``fine``.
        target_type (string or list, optional): Type of target to use, ``instance``, ``semantic``, ``polygon``
            or ``color``. Can also be a list (or tuple) to output a tuple with all specified target types.
            Default: ``instance``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.

    If the image directory or the target directory for the requested split is missing,
    the matching ``*.zip`` archives found directly in ``root`` are extracted into ``root``.
    If those archives are missing too, a ``RuntimeError`` is raised.

    Samples are indexed in sorted (city, file name) order, so a given index refers to
    the same image regardless of the file system's directory listing order.

    Examples:

        Get semantic segmentation target

        .. code-block:: python

            dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
                                 target_type='semantic')

            img, smnt = dataset[0]

        Get multiple targets

        .. code-block:: python

            dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
                                 target_type=['instance', 'color', 'polygon'])

            img, (inst, col, poly) = dataset[0]

        Validate on the "coarse" set

        .. code-block:: python

            dataset = Cityscapes('./data/cityscapes', split='val', mode='coarse',
                                 target_type='semantic')

            img, smnt = dataset[0]
    """

    CityscapesClass = namedtuple(
        "CityscapesClass",
        ["name", "id", "train_id", "category", "category_id", "has_instances", "ignore_in_eval", "color"],
    )

    classes = [
        CityscapesClass("unlabeled", 0, 255, "void", 0, False, True, (0, 0, 0)),
        CityscapesClass("ego vehicle", 1, 255, "void", 0, False, True, (0, 0, 0)),
        CityscapesClass("rectification border", 2, 255, "void", 0, False, True, (0, 0, 0)),
        CityscapesClass("out of roi", 3, 255, "void", 0, False, True, (0, 0, 0)),
        CityscapesClass("static", 4, 255, "void", 0, False, True, (0, 0, 0)),
        CityscapesClass("dynamic", 5, 255, "void", 0, False, True, (111, 74, 0)),
        CityscapesClass("ground", 6, 255, "void", 0, False, True, (81, 0, 81)),
        CityscapesClass("road", 7, 0, "flat", 1, False, False, (128, 64, 128)),
        CityscapesClass("sidewalk", 8, 1, "flat", 1, False, False, (244, 35, 232)),
        CityscapesClass("parking", 9, 255, "flat", 1, False, True, (250, 170, 160)),
        CityscapesClass("rail track", 10, 255, "flat", 1, False, True, (230, 150, 140)),
        CityscapesClass("building", 11, 2, "construction", 2, False, False, (70, 70, 70)),
        CityscapesClass("wall", 12, 3, "construction", 2, False, False, (102, 102, 156)),
        CityscapesClass("fence", 13, 4, "construction", 2, False, False, (190, 153, 153)),
        CityscapesClass("guard rail", 14, 255, "construction", 2, False, True, (180, 165, 180)),
        CityscapesClass("bridge", 15, 255, "construction", 2, False, True, (150, 100, 100)),
        CityscapesClass("tunnel", 16, 255, "construction", 2, False, True, (150, 120, 90)),
        CityscapesClass("pole", 17, 5, "object", 3, False, False, (153, 153, 153)),
        CityscapesClass("polegroup", 18, 255, "object", 3, False, True, (153, 153, 153)),
        CityscapesClass("traffic light", 19, 6, "object", 3, False, False, (250, 170, 30)),
        CityscapesClass("traffic sign", 20, 7, "object", 3, False, False, (220, 220, 0)),
        CityscapesClass("vegetation", 21, 8, "nature", 4, False, False, (107, 142, 35)),
        CityscapesClass("terrain", 22, 9, "nature", 4, False, False, (152, 251, 152)),
        CityscapesClass("sky", 23, 10, "sky", 5, False, False, (70, 130, 180)),
        CityscapesClass("person", 24, 11, "human", 6, True, False, (220, 20, 60)),
        CityscapesClass("rider", 25, 12, "human", 6, True, False, (255, 0, 0)),
        CityscapesClass("car", 26, 13, "vehicle", 7, True, False, (0, 0, 142)),
        CityscapesClass("truck", 27, 14, "vehicle", 7, True, False, (0, 0, 70)),
        CityscapesClass("bus", 28, 15, "vehicle", 7, True, False, (0, 60, 100)),
        CityscapesClass("caravan", 29, 255, "vehicle", 7, True, True, (0, 0, 90)),
        CityscapesClass("trailer", 30, 255, "vehicle", 7, True, True, (0, 0, 110)),
        CityscapesClass("train", 31, 16, "vehicle", 7, True, False, (0, 80, 100)),
        CityscapesClass("motorcycle", 32, 17, "vehicle", 7, True, False, (0, 0, 230)),
        CityscapesClass("bicycle", 33, 18, "vehicle", 7, True, False, (119, 11, 32)),
        CityscapesClass("license plate", -1, -1, "vehicle", 7, False, True, (0, 0, 142)),
    ]

    def __init__(
        self,
        root: Union[str, Path],
        split: str = "train",
        mode: str = "fine",
        target_type: Union[List[str], str] = "instance",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transforms, transform, target_transform)

        # Validate ``mode`` before anything is derived from it.
        verify_str_arg(mode, "mode", ("fine", "coarse"))
        self.mode = "gtFine" if mode == "fine" else "gtCoarse"
        self.images_dir = os.path.join(self.root, "leftImg8bit", split)
        self.targets_dir = os.path.join(self.root, self.mode, split)
        self.split = split
        self.images: List[str] = []
        # One list of target paths per image, parallel to ``self.target_type``.
        self.targets: List[List[str]] = []

        if mode == "fine":
            valid_modes = ("train", "test", "val")
        else:
            valid_modes = ("train", "train_extra", "val")
        msg = "Unknown value '{}' for argument split if mode is '{}'. Valid values are {{{}}}."
        msg = msg.format(split, mode, iterable_to_str(valid_modes))
        verify_str_arg(split, "split", valid_modes, msg)

        # Keep a private copy so later mutation of the caller's list cannot desync
        # ``__getitem__`` from the target paths precomputed below.
        if isinstance(target_type, (list, tuple)):
            self.target_type = list(target_type)
        else:
            self.target_type = [target_type]
        for value in self.target_type:
            verify_str_arg(value, "target_type", ("instance", "semantic", "polygon", "color"))

        if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir):
            if split == "train_extra":
                image_dir_zip = os.path.join(self.root, "leftImg8bit_trainextra.zip")
            else:
                image_dir_zip = os.path.join(self.root, "leftImg8bit_trainvaltest.zip")

            if self.mode == "gtFine":
                target_dir_zip = os.path.join(self.root, f"{self.mode}_trainvaltest.zip")
            else:  # "gtCoarse": the only other value allowed by the mode check above
                target_dir_zip = os.path.join(self.root, f"{self.mode}.zip")

            if os.path.isfile(image_dir_zip) and os.path.isfile(target_dir_zip):
                extract_archive(from_path=image_dir_zip, to_path=self.root)
                extract_archive(from_path=target_dir_zip, to_path=self.root)
            else:
                raise RuntimeError(
                    "Dataset not found or incomplete. Please make sure all required folders for the"
                    ' specified "split" and "mode" are inside the "root" directory'
                )

        # ``os.listdir`` order is arbitrary; sort so indices are reproducible.
        for city in sorted(os.listdir(self.images_dir)):
            img_dir = os.path.join(self.images_dir, city)
            target_dir = os.path.join(self.targets_dir, city)
            for file_name in sorted(os.listdir(img_dir)):
                # Target files share the part of the image name before "_leftImg8bit".
                stem = file_name.split("_leftImg8bit")[0]
                target_paths = [
                    os.path.join(target_dir, f"{stem}_{self._get_target_suffix(self.mode, t)}")
                    for t in self.target_type
                ]
                self.images.append(os.path.join(img_dir, file_name))
                self.targets.append(target_paths)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is a tuple of all target types if target_type is a list with more
            than one item. Otherwise, target is a json object if target_type="polygon", else the image segmentation.
        """
        image = Image.open(self.images[index]).convert("RGB")

        targets: Any = []
        for i, t in enumerate(self.target_type):
            if t == "polygon":
                target = self._load_json(self.targets[index][i])
            else:
                target = Image.open(self.targets[index][i])
            targets.append(target)

        target = tuple(targets) if len(targets) > 1 else targets[0]

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        """Return the number of images found for the configured split."""
        return len(self.images)

    def extra_repr(self) -> str:
        """Return the split, mode and target type(s), one per line."""
        lines = ["Split: {split}", "Mode: {mode}", "Type: {target_type}"]
        return "\n".join(lines).format(**self.__dict__)

    def _load_json(self, path: str) -> Dict[str, Any]:
        """Parse the JSON file at ``path``.

        The file is decoded as UTF-8 explicitly rather than with the locale-dependent
        default encoding of ``open``.
        """
        with open(path, encoding="utf-8") as file:
            data = json.load(file)
        return data

    def _get_target_suffix(self, mode: str, target_type: str) -> str:
        """Return the target file-name suffix for ``mode`` ("gtFine"/"gtCoarse") and ``target_type``.

        Any ``target_type`` other than instance/semantic/color maps to the polygon JSON file.
        """
        if target_type == "instance":
            return f"{mode}_instanceIds.png"
        elif target_type == "semantic":
            return f"{mode}_labelIds.png"
        elif target_type == "color":
            return f"{mode}_color.png"
        else:
            return f"{mode}_polygons.json"
|
|