# Dataset and collator helpers for single-class semantic segmentation fine-tuning.
| import os | |
| from PIL import Image | |
| import torch | |
| from torch.utils.data import Dataset | |
| import numpy as np | |
class SingleClassSegmentationDataset(Dataset):
    """Dataset yielding image / per-pixel-label pairs for single-class samples.

    Each element of ``dataset`` is a dict with keys ``img_path``,
    ``mask_path`` and ``label``. The binary mask is turned into an integer
    label map: background pixels are 0 and every foreground pixel carries
    the index of the item's class name within ``class_labels``.

    NOTE(review): index 0 doubles as the background id — if
    ``class_labels[0]`` is a real foreground class its pixels become
    indistinguishable from background. Presumably ``class_labels[0]`` is a
    background entry; confirm against the caller.
    """

    def __init__(self, dataset, class_labels, image_size=352, transform=None):
        """
        Args:
            dataset: sequence of dicts with ``img_path``/``mask_path``/``label``.
            class_labels: ordered class names; a name's position in this list
                becomes its integer id in the output mask.
            image_size: side length in pixels both image and mask are resized to.
            transform: optional callable ``(image, mask) -> (image, mask)``
                applied after resizing.
        """
        self.items = dataset
        self.class_labels = class_labels
        self.image_size = image_size
        self.transform = transform
        # Fix: build the name -> index mapping once instead of running an
        # O(len(class_labels)) list.index() scan on every __getitem__ call.
        self._label_to_index = {name: i for i, name in enumerate(class_labels)}

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        """Load, label-encode and resize one sample.

        Returns:
            dict with ``image`` (PIL image, or whatever ``transform`` yields)
            and ``labels`` (LongTensor of per-pixel class indices).

        Raises:
            ValueError: if the item's label is not in ``class_labels``
                (same exception type list.index() raised before).
        """
        item = self.items[idx]
        image = Image.open(item["img_path"]).convert("RGB")
        mask = Image.open(item["mask_path"]).convert("L")

        class_name = item["label"]
        class_index = self._label_to_index.get(class_name)
        if class_index is None:
            raise ValueError(f"{class_name!r} is not in list")

        background_index = 0
        # Any nonzero mask pixel counts as foreground for this item's class.
        mask_np = np.array(mask) > 0
        final_mask = np.full(mask_np.shape, background_index, dtype=np.uint8)
        final_mask[mask_np] = class_index

        # BILINEAR for the image; NEAREST for the mask so integer class ids
        # are never interpolated into invalid in-between values.
        image = image.resize((self.image_size, self.image_size), Image.BILINEAR)
        final_mask = Image.fromarray(final_mask).resize(
            (self.image_size, self.image_size), Image.NEAREST
        )

        if self.transform:
            image, final_mask = self.transform(image, final_mask)

        return {
            "image": image,
            "labels": torch.from_numpy(np.array(final_mask)).long(),
        }
class SegmentationCollator:
    """Collate a batch by pairing every image with every class prompt.

    For a batch of B samples and C class labels, the processor receives
    B * C (image, text) pairs in image-major order: all C prompts for the
    first image, then all C prompts for the second, and so on.
    """

    def __init__(self, processor, class_labels):
        self.processor = processor
        self.class_labels = class_labels

    def __call__(self, batch):
        """Return pixel values, tokenized prompts and stacked label masks."""
        images = [sample["image"] for sample in batch]
        labels = [sample["labels"] for sample in batch]

        # Build the expanded image list and the prompt list in lockstep so
        # pair number i * C + j is (image_i, class_labels[j]).
        expanded_images = []
        prompts = []
        for img in images:
            for class_name in self.class_labels:
                expanded_images.append(img)
                prompts.append(class_name)

        inputs = self.processor(
            images=expanded_images,
            text=prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )
        return {
            "pixel_values": inputs["pixel_values"],
            "input_ids": inputs["input_ids"],
            "labels": torch.stack(labels),
        }