|
import glob |
|
import os |
|
from collections import defaultdict |
|
from html.parser import HTMLParser |
|
from pathlib import Path |
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union |
|
|
|
from .folder import default_loader |
|
from .vision import VisionDataset |
|
|
|
|
|
class Flickr8kParser(HTMLParser):
    """Parser for extracting captions from the Flickr8k dataset web page.

    Feed the annotation HTML through :meth:`feed`. Afterwards ``annotations``
    maps the path of every image that was resolved under ``root`` to the list
    of captions collected for it.

    Args:
        root (str or ``pathlib.Path``): Directory that is searched for the image files.
    """

    def __init__(self, root: Union[str, Path]) -> None:
        super().__init__()

        self.root = root

        # Result: image path (as returned by glob) -> list of caption strings.
        self.annotations: Dict[str, List[str]] = {}

        # Parser state.
        self.in_table = False  # True while inside a <table> element
        self.current_tag: Optional[str] = None  # most recently opened, not yet closed tag
        self.current_img: Optional[str] = None  # image that following <li> captions belong to

    def handle_starttag(self, tag: str, attrs: List[Tuple[str, Optional[str]]]) -> None:
        """Remember the tag that was just opened and track entry into a table."""
        self.current_tag = tag

        if tag == "table":
            self.in_table = True

    def handle_endtag(self, tag: str) -> None:
        """Forget the current tag and track exit from a table."""
        self.current_tag = None

        if tag == "table":
            self.in_table = False

    def handle_data(self, data: str) -> None:
        """Collect image ids from ``<a>`` text and captions from ``<li>`` text.

        Only text inside a table is considered. The text of an ``<a>`` element
        is split on ``"/"`` and its second-to-last component is used as the
        image id; the image file is then looked up as ``<root>/<id>_*.jpg``.

        Raises:
            IndexError: If the link text contains no ``"/"`` or no file under
                ``root`` matches the image id.
        """
        if not self.in_table:
            return

        if data == "Image Not Found":
            # Captions that follow must not be attached to a previous image.
            self.current_img = None
        elif self.current_tag == "a":
            img_id = data.split("/")[-2]
            # Escape the root so that glob metacharacters in the directory name
            # ("[", "*", "?") are matched literally; only the "_*" suffix is a
            # wildcard. For roots without such characters the result is unchanged.
            pattern = os.path.join(glob.escape(os.fspath(self.root)), img_id + "_*.jpg")
            img_path = glob.glob(pattern)[0]
            self.current_img = img_path
            self.annotations[img_path] = []
        elif self.current_tag == "li" and self.current_img:
            self.annotations[self.current_img].append(data.strip())
|
|
|
|
|
class Flickr8k(VisionDataset):
    """`Flickr8k Entities <http://hockenmaier.cs.illinois.edu/8k-pictures.html>`_ Dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
        ann_file (string): Path to annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor,
            depending on the given loader, and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
            By default, it uses PIL as its image loader, but users could also pass in
            ``torchvision.io.decode_image`` for decoding image data into tensors directly.
    """

    def __init__(
        self,
        root: Union[str, Path],
        ann_file: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        loader: Callable[[str], Any] = default_loader,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.ann_file = os.path.expanduser(ann_file)

        # Parse the annotation page into a mapping of image path -> captions.
        caption_parser = Flickr8kParser(self.root)
        with open(self.ann_file) as ann_fh:
            page = ann_fh.read()
        caption_parser.feed(page)
        self.annotations = caption_parser.annotations

        # Sorted image paths give every index a stable meaning.
        self.ids = sorted(self.annotations)
        self.loader = loader

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return the sample stored at position ``index``.

        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target). target is a list of captions for the image.
        """
        # The ids are the image paths resolved by the parser, so they are
        # handed to the loader as-is.
        path = self.ids[index]

        image = self.loader(path)
        if self.transform is not None:
            image = self.transform(image)

        captions = self.annotations[path]
        if self.target_transform is not None:
            captions = self.target_transform(captions)

        return image, captions

    def __len__(self) -> int:
        """Return the number of images that have annotations."""
        return len(self.ids)
|
|
|
|
|
class Flickr30k(VisionDataset):
    """`Flickr30k Entities <https://bryanplummer.com/Flickr30kEntities/>`_ Dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
        ann_file (string): Path to annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor,
            depending on the given loader, and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
            By default, it uses PIL as its image loader, but users could also pass in
            ``torchvision.io.decode_image`` for decoding image data into tensors directly.
    """

    def __init__(
        self,
        root: Union[str, Path],
        ann_file: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        loader: Callable[[str], Any] = default_loader,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.ann_file = os.path.expanduser(ann_file)

        # Read annotations and group the captions by image file name.
        self.annotations: Dict[str, List[str]] = defaultdict(list)
        with open(self.ann_file) as fh:
            for line in fh:
                line = line.strip()
                if not line:
                    # Tolerate blank lines (e.g. an empty line at the end of the file).
                    continue
                # Split on the first tab only so a caption may itself contain tabs.
                img_id, caption = line.split("\t", 1)
                # Drop the last two characters of the id field
                # (presumably a "#<n>" caption index -- confirm against the annotation file format).
                self.annotations[img_id[:-2]].append(caption)

        # Sorted ids give every index a stable meaning.
        self.ids = sorted(self.annotations)
        self.loader = loader

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return the sample stored at position ``index``.

        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target). target is a list of captions for the image.
        """
        img_id = self.ids[index]

        # Image: ids are file names relative to the dataset root.
        filename = os.path.join(self.root, img_id)
        img = self.loader(filename)
        if self.transform is not None:
            img = self.transform(img)

        # Captions
        target = self.annotations[img_id]
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        """Return the number of images that have annotations."""
        return len(self.ids)
|
|