# MMaDA/training/imagenet_dataset.py
# (Uploaded via huggingface_hub by jamtur01; commit 9c6594c, verified.)
# coding=utf-8
# Copyright 2024 NUS Show Lab.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
from typing import Any, Callable, Optional
import torch
from torchvision.datasets.folder import DatasetFolder, default_loader
from training.utils import image_transform
class ImageNetDataset(DatasetFolder):
    """ImageNet-style folder dataset yielding image tensors with text labels.

    Each sample is a dict with:
      - 'images':    the transformed image (via ``image_transform``)
      - 'input_ids': the human-readable class description string
      - 'class_ids': the integer class index as a 0-d tensor
    """

    def __init__(
        self,
        root: str,
        loader: Callable[[str], Any] = default_loader,
        is_valid_file: Optional[Callable[[str], bool]] = None,
        image_size: int = 256,
        label_mapping_path: str = './training/imagenet_label_mapping',
    ):
        """
        Args:
            root: Root directory of the one-subfolder-per-class image tree.
            loader: Callable that loads an image given its path.
            is_valid_file: Optional predicate used to filter files. When
                provided, the extension whitelist is disabled (DatasetFolder
                accepts exactly one of the two).
            image_size: Resolution passed to the transform for each image.
            label_mapping_path: Path to the "<id>: <description>" mapping file
                (one class per line). Default preserves the original location.
        """
        IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")
        self.transform = image_transform
        self.image_size = image_size
        super().__init__(
            root,
            loader,
            IMG_EXTENSIONS if is_valid_file is None else None,
            transform=self.transform,
            target_transform=None,
            is_valid_file=is_valid_file,
        )
        # Parse "num: description" lines into {class_index: description}.
        self.labels = {}
        with open(label_mapping_path, 'r') as f:
            for line in f:
                num, description = line.split(":")
                self.labels[int(num)] = description.strip()
        print("ImageNet dataset loaded.")

    def __getitem__(self, idx):
        try:
            path, target = self.samples[idx]
            image = self.loader(path)
            image = self.transform(image, resolution=self.image_size)
            input_ids = "{}".format(self.labels[target])
            class_ids = torch.tensor(target)
            return {'images': image, 'input_ids': input_ids, 'class_ids': class_ids}
        except Exception as e:
            # Best-effort skip of an unreadable/corrupt sample. Wrap with
            # modulo so a failure on the last index cannot run past the end
            # of self.samples and recurse until RecursionError (original bug:
            # it retried idx+1 unbounded).
            print(e)
            return self.__getitem__((idx + 1) % len(self.samples))

    def collate_fn(self, batch):
        """Stack tensor fields; keep 'input_ids' as a plain list of strings."""
        batched = collections.defaultdict(list)
        for data in batch:
            for k, v in data.items():
                batched[k].append(v)
        for k, v in batched.items():
            # Fixed: the original wrote `k not in ('input_ids')`, where
            # ('input_ids') is just a string, making this a substring test
            # rather than tuple membership. A one-element tuple restores the
            # intended exact-key check.
            if k not in ('input_ids',):
                batched[k] = torch.stack(v, dim=0)
        return batched
if __name__ == "__main__":
    # Module is import-only; no standalone entry point.
    pass