import glob
import os
from collections import defaultdict
from html.parser import HTMLParser
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from .folder import default_loader
from .vision import VisionDataset


class Flickr8kParser(HTMLParser):
    """Parser for extracting captions from the Flickr8k dataset web page.

    While inside a ``table`` element, the text of every ``a`` element is treated
    as an image link: its second-to-last ``/``-separated component is taken as the
    image id and resolved to a file named ``<id>_*.jpg`` under ``root``. The text of
    each subsequent ``li`` element is stored as a caption of that image.

    Args:
        root (str or ``pathlib.Path``): Directory that is searched for the image files.

    Attributes are filled in by :meth:`feed`; ``annotations`` maps each resolved
    image path to the list of its captions.
    """

    def __init__(self, root: Union[str, Path]) -> None:
        super().__init__()

        self.root = root

        # Data structure to store captions: resolved image path -> captions
        self.annotations: Dict[str, List[str]] = {}

        # State variables
        self.in_table = False
        self.current_tag: Optional[str] = None
        self.current_img: Optional[str] = None

    def handle_starttag(self, tag: str, attrs: List[Tuple[str, Optional[str]]]) -> None:
        """Remember the tag that was just opened and note when a table starts."""
        self.current_tag = tag

        if tag == "table":
            self.in_table = True

    def handle_endtag(self, tag: str) -> None:
        """Forget the current tag and note when a table ends."""
        self.current_tag = None

        if tag == "table":
            self.in_table = False

    def handle_data(self, data: str) -> None:
        """Collect image paths and captions from text found inside a table.

        Raises:
            IndexError: If a link text has fewer than two ``/``-separated
                components, or if no file matching the image id exists under ``root``.
        """
        if not self.in_table:
            return

        if data == "Image Not Found":
            # Captions that follow belong to no image and must be ignored.
            self.current_img = None
        elif self.current_tag == "a":
            img_id = data.split("/")[-2]
            # Escape the literal part of the pattern so that glob metacharacters
            # in ``root`` (e.g. "[" or "*") are matched verbatim; only the
            # trailing "_*.jpg" is meant to be a wildcard. ``os.fspath`` is
            # needed because ``root`` may be a ``pathlib.Path``.
            literal_prefix = os.path.join(os.fspath(self.root), img_id)
            pattern = glob.escape(literal_prefix) + "_*.jpg"
            # glob returns matches in arbitrary order; sort so that the selected
            # file is reproducible if several files share the same id.
            img_path = sorted(glob.glob(pattern))[0]
            self.current_img = img_path
            self.annotations[img_path] = []
        elif self.current_tag == "li" and self.current_img:
            self.annotations[self.current_img].append(data.strip())


class Flickr8k(VisionDataset):
    """Flickr8k Entities Dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
        ann_file (string): Path to annotation file. The file is read as HTML and
            parsed with :class:`Flickr8kParser`.
        transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on
            the given loader, and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
            By default, it uses PIL as its image loader, but users could also pass in
            ``torchvision.io.decode_image`` for decoding image data into tensors directly.
    """

    def __init__(
        self,
        root: Union[str, Path],
        ann_file: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        loader: Callable[[str], Any] = default_loader,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.ann_file = os.path.expanduser(ann_file)

        # Read annotations and store in a dict (image path -> captions)
        parser = Flickr8kParser(self.root)
        with open(self.ann_file) as fh:
            parser.feed(fh.read())
        self.annotations = parser.annotations

        # Sorted image paths give every sample a stable integer index.
        self.ids = sorted(self.annotations)
        self.loader = loader

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target). target is a list of captions for the image.
        """
        # The ids of this dataset are already full image paths.
        img_id = self.ids[index]

        # Image
        img = self.loader(img_id)
        if self.transform is not None:
            img = self.transform(img)

        # Captions
        target = self.annotations[img_id]
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        """Return the number of images that have annotations."""
        return len(self.ids)


class Flickr30k(VisionDataset):
    """Flickr30k Entities Dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
        ann_file (string): Path to annotation file. Each non-blank line must hold an
            image key and a caption separated by a tab; the last two characters of
            the key are dropped to obtain the image file name (presumably a
            ``#<n>`` caption index -- confirm against the annotation file).
        transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on
            the given loader, and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
            By default, it uses PIL as its image loader, but users could also pass in
            ``torchvision.io.decode_image`` for decoding image data into tensors directly.
    """

    def __init__(
        self,
        root: Union[str, Path],
        ann_file: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        loader: Callable[[str], Any] = default_loader,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.ann_file = os.path.expanduser(ann_file)

        # Read annotations and store in a dict (image file name -> captions)
        self.annotations: Dict[str, List[str]] = defaultdict(list)
        with open(self.ann_file) as fh:
            for line in fh:
                entry = line.strip()
                if not entry:
                    # Tolerate blank lines (e.g. trailing newlines at end of file).
                    continue
                # Split on the first tab only so that a caption which itself
                # contains a tab is kept intact instead of breaking the unpacking.
                # A line without any tab still raises ValueError.
                img_id, caption = entry.split("\t", 1)
                self.annotations[img_id[:-2]].append(caption)

        # Sorted image names give every sample a stable integer index.
        self.ids = sorted(self.annotations)
        self.loader = loader

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target). target is a list of captions for the image.
        """
        img_id = self.ids[index]

        # Image
        filename = os.path.join(self.root, img_id)
        img = self.loader(filename)
        if self.transform is not None:
            img = self.transform(img)

        # Captions
        target = self.annotations[img_id]
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        """Return the number of images that have annotations."""
        return len(self.ids)