File size: 658 Bytes
c08ab4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch
from torch.utils.data import Dataset

class HumanActionDataset(Dataset):
    """Map-style dataset that yields ``(image, label)`` pairs.

    Wraps an indexable collection whose items are mappings with an
    ``"image"`` entry and a ``"labels"`` entry, optionally running a
    transform over the image before it is returned.
    """

    def __init__(self, hf_dataset_split, transform=None):
        """Store the wrapped split and the optional image transform.

        Args:
            hf_dataset_split: Hugging Face dataset split, e.g. ``ds['train']``.
                It must support ``len()`` and indexing, and each item must
                expose the ``"image"`` and ``"labels"`` keys.
            transform: Optional callable applied to every image (e.g.
                torchvision transforms). ``None`` leaves images untouched.
        """
        self.dataset = hf_dataset_split
        self.transform = transform

    def __len__(self) -> int:
        """Return the number of items in the wrapped split."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at ``idx``.

        Args:
            idx: Index forwarded unchanged to the wrapped split.

        Returns:
            A tuple ``(image, label)``, where ``image`` is the transformed
            image when a transform was supplied and the raw ``"image"`` value
            otherwise, and ``label`` is the item's ``"labels"`` value.
        """
        item = self.dataset[idx]
        image = item["image"]  # expected to be a PIL.Image.Image
        label = item["labels"]

        # Compare against None rather than relying on truthiness: a transform
        # object may legitimately be falsy (e.g. it defines __len__ or
        # __bool__) and must still be applied.
        if self.transform is not None:
            image = self.transform(image)

        return image, label