|
import json |
|
from pathlib import Path |
|
from typing import Any, Callable, Optional, Tuple, Union |
|
|
|
from .folder import default_loader |
|
|
|
from .utils import download_and_extract_archive, verify_str_arg |
|
from .vision import VisionDataset |
|
|
|
|
|
class Food101(VisionDataset):
    """`The Food-101 Data Set <https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/>`_.

    The Food-101 is a challenging data set of 101 food categories with 101,000 images.
    For each class, 250 manually reviewed test images are provided as well as 750 training images.
    On purpose, the training images were not cleaned, and thus still contain some amount of noise.
    This comes mostly in the form of intense colors and sometimes wrong labels. All images were
    rescaled to have a maximum side length of 512 pixels.

    Args:
        root (str or ``pathlib.Path``): Root directory of the dataset.
        split (str, optional): The dataset split, supports ``"train"`` (default) and ``"test"``.
        transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor,
            depending on the given loader, and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again. Default is False.
        loader (callable, optional): A function to load an image given its path.
            By default, it uses PIL as its image loader, but users could also pass in
            ``torchvision.io.decode_image`` for decoding image data into tensors directly.

    Attributes:
        classes (list): Class names from the split's metadata file, sorted alphabetically.
        class_to_idx (dict): Maps each class name to its index in ``classes``.
    """

    _URL = "http://data.vision.ee.ethz.ch/cvl/food-101.tar.gz"
    _MD5 = "85eeb15f3717b99a5da872d97d918f87"

    def __init__(
        self,
        root: Union[str, Path],
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
        loader: Callable[[Union[str, Path]], Any] = default_loader,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        # verify_str_arg rejects any split other than "train" / "test".
        self._split = verify_str_arg(split, "split", ("train", "test"))
        self._base_folder = Path(self.root) / "food-101"
        self._meta_folder = self._base_folder / "meta"
        self._images_folder = self._base_folder / "images"

        if download:
            self._download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        # meta/<split>.json maps each class name to a list of image paths that are
        # relative to ``images/``, use "/" separators and carry no file extension.
        with open(self._meta_folder / f"{self._split}.json", encoding="utf-8") as f:
            metadata = json.load(f)

        self.classes = sorted(metadata.keys())
        self.class_to_idx = {class_name: idx for idx, class_name in enumerate(self.classes)}

        # Parallel lists: _labels[i] is the class index of _image_files[i].
        self._labels = []
        self._image_files = []
        for class_label, im_rel_paths in metadata.items():
            self._labels += [self.class_to_idx[class_label]] * len(im_rel_paths)
            self._image_files += [
                # Split on "/" so the path is assembled with the native separator on every OS.
                self._images_folder.joinpath(*f"{im_rel_path}.jpg".split("/"))
                for im_rel_path in im_rel_paths
            ]
        self.loader = loader

    def __len__(self) -> int:
        """Return the number of images in the selected split."""
        return len(self._image_files)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """Load sample ``idx``.

        Returns:
            tuple: ``(image, label)`` where ``image`` is whatever ``self.loader`` returns
            (after ``transform``, if set) and ``label`` is the class index into
            ``self.classes`` (after ``target_transform``, if set).
        """
        image_file, label = self._image_files[idx], self._labels[idx]
        image = self.loader(image_file)

        # Compare against None rather than relying on truthiness: a callable that
        # defines __len__/__bool__ must still be applied.
        if self.transform is not None:
            image = self.transform(image)

        if self.target_transform is not None:
            label = self.target_transform(label)

        return image, label

    def extra_repr(self) -> str:
        return f"split={self._split}"

    def _check_exists(self) -> bool:
        """Return True if both the ``meta`` and ``images`` directories are present."""
        # Path.is_dir() is already False for paths that do not exist.
        return all(folder.is_dir() for folder in (self._meta_folder, self._images_folder))

    def _download(self) -> None:
        """Download and extract the archive into ``root`` unless the dataset already exists."""
        if self._check_exists():
            return
        # The expected MD5 is passed along so the downloaded archive can be checked.
        download_and_extract_archive(self._URL, download_root=self.root, md5=self._MD5)
|
|