|
import operator
import os.path
from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple, Union

from PIL import Image

from .vision import VisionDataset
|
|
|
|
|
class CocoDetection(VisionDataset):
    """`MS Coco Detection <https://cocodataset.org/#detection-2016>`_ Dataset.

    It requires `pycocotools <https://github.com/ppwwyyxx/cocoapi>`_ to be installed,
    which can be installed via ``pip install pycocotools`` or ``conda install conda-forge::pycocotools``.

    Args:
        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.
    """

    def __init__(
        self,
        root: Union[str, Path],
        annFile: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transforms, transform, target_transform)
        # Imported here rather than at module level so that importing this
        # module does not itself require pycocotools.
        from pycocotools.coco import COCO

        self.coco = COCO(annFile)
        # Sorting fixes a deterministic dataset-index -> COCO-image-id mapping.
        self.ids = sorted(self.coco.imgs.keys())

    def _load_image(self, id: int) -> Image.Image:
        """Load the image with COCO image id ``id`` as an RGB PIL image.

        The file path is ``root`` joined with the image's ``file_name`` entry.
        """
        path = self.coco.loadImgs(id)[0]["file_name"]
        # ``convert`` produces an independent, fully loaded image, so the file
        # can be closed deterministically here instead of relying on PIL or the
        # garbage collector (important with many DataLoader workers).
        with Image.open(os.path.join(self.root, path)) as img:
            return img.convert("RGB")

    def _load_target(self, id: int) -> List[Any]:
        """Return the list of annotation dicts attached to COCO image id ``id``."""
        return self.coco.loadAnns(self.coco.getAnnIds(id))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return ``(image, target)`` for the ``index``-th image.

        Args:
            index (int): Position in the sorted list of image ids. Any
                integer-like object (one supporting ``__index__``, e.g. a numpy
                integer) is accepted.

        Returns:
            tuple: ``(image, target)``, after ``self.transforms`` is applied
            when it is set.

        Raises:
            ValueError: If ``index`` is not integer-like.
        """
        try:
            index = operator.index(index)
        except TypeError:
            raise ValueError(f"Index must be of type integer, got {type(index)} instead.") from None

        img_id = self.ids[index]
        image = self._load_image(img_id)
        target = self._load_target(img_id)

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        """Return the number of images in the dataset."""
        return len(self.ids)
|
|
|
|
|
class CocoCaptions(CocoDetection):
    """`MS Coco Captions <https://cocodataset.org/#captions-2015>`_ Dataset.

    It requires `pycocotools <https://github.com/ppwwyyxx/cocoapi>`_ to be installed,
    which can be installed via ``pip install pycocotools`` or ``conda install conda-forge::pycocotools``.

    Args:
        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.

    Example:

        .. code:: python

            import torchvision.datasets as dset
            import torchvision.transforms as transforms
            cap = dset.CocoCaptions(root = 'dir where images are',
                                    annFile = 'json annotation file',
                                    transform=transforms.PILToTensor())

            print('Number of samples: ', len(cap))
            img, target = cap[3] # load 4th sample

            print("Image Size: ", img.size())
            print(target)

        Output: ::

            Number of samples: 82783
            Image Size: torch.Size([3, 427, 640])
            ['A plane emitting smoke stream flying over a mountain.',
            'A plane darts across a bright blue sky behind a mountain covered in snow',
            'A plane leaves a contrail above the snowy mountain top.',
            'A mountain that has a plane flying overheard in the distance.',
            'A mountain view with a plume of smoke in the background']

    """

    def _load_target(self, id: int) -> List[str]:
        """Return only the caption strings of the annotations for image id ``id``.

        Raises:
            KeyError: If an annotation has no ``"caption"`` field (e.g. a
                non-captions annotation file was passed as ``annFile``).
        """
        return [ann["caption"] for ann in super()._load_target(id)]
|
|