File size: 6,189 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 |
import glob
import os
from collections import defaultdict
from html.parser import HTMLParser
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from .folder import default_loader
from .vision import VisionDataset
class Flickr8kParser(HTMLParser):
    """Parser for extracting captions from the Flickr8k dataset web page.

    Only text found inside ``<table>`` elements is considered:

    * Text inside an ``<a>`` tag is split on ``/`` and its second-to-last
      component is used as the image id. The id is resolved to the first file
      in ``root`` matching ``<id>_*.jpg`` and becomes the current image.
    * Text inside an ``<li>`` tag is stored (stripped) as a caption of the
      current image.
    * The literal text ``Image Not Found`` clears the current image, so the
      captions that follow it are dropped.

    Args:
        root (str or ``pathlib.Path``): Directory containing the image files.
    """

    def __init__(self, root: Union[str, Path]) -> None:
        super().__init__()

        self.root = root

        # Maps the resolved image file path to the list of its captions
        self.annotations: Dict[str, List[str]] = {}

        # State variables
        self.in_table = False  # True while inside a <table> element
        self.current_tag: Optional[str] = None  # most recently opened, not yet closed tag
        self.current_img: Optional[str] = None  # image the next captions belong to

    def handle_starttag(self, tag: str, attrs: List[Tuple[str, Optional[str]]]) -> None:
        """Record the opened tag and note when a table starts."""
        self.current_tag = tag

        if tag == "table":
            self.in_table = True

    def handle_endtag(self, tag: str) -> None:
        """Forget the current tag and note when a table ends."""
        self.current_tag = None

        if tag == "table":
            self.in_table = False

    def handle_data(self, data: str) -> None:
        """Interpret text inside a table as an image link or a caption.

        Raises:
            IndexError: If a link is found but no file in ``root`` matches
                ``<id>_*.jpg``, or the link text has fewer than two
                ``/``-separated components.
        """
        if not self.in_table:
            return

        if data == "Image Not Found":
            self.current_img = None
        elif self.current_tag == "a":
            img_id = data.split("/")[-2]
            # Escape the literal parts of the pattern: without this, a root
            # directory whose name contains glob metacharacters ("[", "*",
            # "?") would never match and every lookup would fail.
            pattern = os.path.join(
                glob.escape(os.fspath(self.root)),
                glob.escape(img_id) + "_*.jpg",
            )
            img_path = glob.glob(pattern)[0]
            self.current_img = img_path
            self.annotations[img_path] = []
        elif self.current_tag == "li" and self.current_img:
            self.annotations[self.current_img].append(data.strip())
class Flickr8k(VisionDataset):
    """`Flickr8k Entities <http://hockenmaier.cs.illinois.edu/8k-pictures.html>`_ Dataset.

    Each sample is an ``(image, captions)`` pair, where the captions are read
    from the HTML annotation file by :class:`Flickr8kParser`.

    Args:
        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
        ann_file (string): Path to annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor,
            depends on the given loader, and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
            By default, it uses PIL as its image loader, but users could also pass in
            ``torchvision.io.decode_image`` for decoding image data into tensors directly.
    """

    def __init__(
        self,
        root: Union[str, Path],
        ann_file: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        loader: Callable[[str], Any] = default_loader,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.ann_file = os.path.expanduser(ann_file)
        self.loader = loader

        # Parse the whole annotation page in one go; the parser resolves
        # every image id to a file path under self.root.
        caption_parser = Flickr8kParser(self.root)
        with open(self.ann_file) as ann_fh:
            page = ann_fh.read()
        caption_parser.feed(page)
        self.annotations = caption_parser.annotations

        # Sample order is the sorted order of the image paths
        self.ids = sorted(self.annotations)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target). target is a list of captions for the image.
        """
        path = self.ids[index]

        # The ids are already full file paths, so they go to the loader as-is
        image = self.loader(path)
        if self.transform is not None:
            image = self.transform(image)

        captions = self.annotations[path]
        if self.target_transform is not None:
            captions = self.target_transform(captions)

        return image, captions

    def __len__(self) -> int:
        """Return the number of images that have annotations."""
        return len(self.ids)
class Flickr30k(VisionDataset):
    """`Flickr30k Entities <https://bryanplummer.com/Flickr30kEntities/>`_ Dataset.

    The annotation file is read line by line. Each non-blank line is expected
    to hold a key and a caption separated by a tab; captions are grouped per
    image under the key with its last two characters removed.

    Args:
        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
        ann_file (string): Path to annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor,
            depends on the given loader, and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
            By default, it uses PIL as its image loader, but users could also pass in
            ``torchvision.io.decode_image`` for decoding image data into tensors directly.

    Raises:
        ValueError: If a non-blank line of the annotation file contains no tab.
    """

    def __init__(
        self,
        root: Union[str, Path],
        ann_file: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        loader: Callable[[str], Any] = default_loader,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.ann_file = os.path.expanduser(ann_file)

        # Read annotations and store in a dict: image file name -> captions
        self.annotations = defaultdict(list)
        with open(self.ann_file) as fh:
            for line in fh:
                line = line.strip()
                if not line:
                    # Blank lines (e.g. a trailing empty line) carry no
                    # annotation and must not abort loading.
                    continue
                # Split on the first tab only so that a caption which itself
                # contains a tab is kept intact instead of raising.
                img_id, caption = line.split("\t", 1)
                # Drop the last two characters of the key; presumably the
                # "#<n>" caption index - confirm against the annotation format.
                self.annotations[img_id[:-2]].append(caption)

        self.ids = sorted(self.annotations)
        self.loader = loader

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target). target is a list of captions for the image.
        """
        img_id = self.ids[index]

        # Image: ids are file names relative to the dataset root
        filename = os.path.join(self.root, img_id)
        img = self.loader(filename)
        if self.transform is not None:
            img = self.transform(img)

        # Captions
        target = self.annotations[img_id]
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        """Return the number of images that have annotations."""
        return len(self.ids)
|