|
import torch
|
|
import math
|
|
import numbers
|
|
import torch.nn.functional as F
|
|
import torchvision.transforms.functional as f
|
|
import torchvision.transforms as T
|
|
import torchvision.transforms.v2 as v2
|
|
|
|
from torchvision import transforms as _transforms
|
|
from typing import List, Optional, Tuple, Union
|
|
from scipy import ndimage
|
|
from torch import Tensor
|
|
|
|
from sn_calibration.src.evaluate_extremities import mirror_labels
|
|
|
|
class ToTensor(torch.nn.Module):
    """Convert a raw sample dict into float tensors.

    ``sample['image']`` is converted with
    ``torchvision.transforms.functional.to_tensor``; ``sample['target']`` and
    ``sample['mask']`` are passed to ``torch.from_numpy`` and so must be NumPy
    arrays. All three outputs are cast to ``float32`` via ``.float()``.
    """

    def forward(self, sample):
        """Return a new dict with 'image', 'target' and 'mask' as float tensors.

        Implemented as ``forward`` (not ``__call__``) so that ``nn.Module``'s
        call machinery, including hooks, is used, matching the other transforms
        in this module.
        """
        image, target, mask = sample['image'], sample['target'], sample['mask']
        return {'image': f.to_tensor(image).float(),
                'target': torch.from_numpy(target).float(),
                'mask': torch.from_numpy(mask).float()}

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"
|
|
|
|
class RandomHorizontalFlip(torch.nn.Module):
    """Randomly mirror a sample left/right while keeping keypoint ids consistent.

    With probability ``p`` the image and the target are flipped horizontally
    (``f.hflip``). The first-dimension channels of ``target`` and the entries of
    ``mask`` are then permuted according to ``swap_dict``, so that each
    keypoint channel moves to the id it corresponds to after mirroring.
    Otherwise the sample's tensors are returned unchanged in a new dict.
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p
        # 1-based keypoint id -> id it is relabelled to after a horizontal
        # flip. Ids 13, 14, 17 and 24 map to themselves.
        self.swap_dict = {1: 4, 2: 5, 3: 6, 4: 1, 5: 2, 6: 3, 7: 10, 8: 12, 9: 11, 10: 7, 11: 9, 12: 8, 13: 13,
                          14: 14, 15: 16, 16: 15, 17: 17, 18: 21, 19: 22, 20: 23, 21: 18, 22: 19, 23: 20, 24: 24}

    def forward(self, sample):
        """Flip with probability ``p``; always return a new sample dict.

        Args:
            sample: dict with 'image', 'target' and 'mask' tensors (as produced
                by the tensor-conversion step earlier in the pipeline).

        Returns:
            dict with keys 'image', 'target', 'mask'.
        """
        if torch.rand(1) < self.p:
            image = f.hflip(sample['image'])
            target = f.hflip(sample['target'])
            # The mask is not spatially flipped here; it is only permuted by
            # keypoint index in swap_layers.
            target_swap, mask_swap = self.swap_layers(target, sample['mask'])
            return {'image': image,
                    'target': target_swap,
                    'mask': mask_swap}

        return {'image': sample['image'],
                'target': sample['target'],
                'mask': sample['mask']}

    def swap_layers(self, target, mask):
        """Permute per-keypoint channels according to ``swap_dict``.

        Args:
            target: tensor whose first dimension is indexed by ``kp - 1`` for a
                1-based keypoint id ``kp`` (presumably (C, H, W) heatmaps;
                TODO confirm).
            mask: tensor whose first dimension is indexed the same way.

        Returns:
            ``(target_swap, mask_swap)``: new tensors where channel ``kp - 1``
            of the input has been moved to ``swap_dict[kp] - 1``.
        """
        # Start from copies rather than zeros, so that any channel not listed
        # in swap_dict (e.g. an extra trailing channel) is kept instead of
        # being silently wiped out on every flip.
        target_swap = target.clone()
        mask_swap = mask.clone()
        for kp, kp_swap in self.swap_dict.items():
            # Index assignment copies the data, so no explicit .clone() is needed.
            target_swap[kp_swap - 1] = target[kp - 1]
            mask_swap[kp_swap - 1] = mask[kp - 1]

        return target_swap, mask_swap

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"
|
|
|
|
|
|
class AddGaussianNoise(torch.nn.Module):
    """Add i.i.d. Gaussian noise to ``sample['image']`` and clip to ``[0, 1]``.

    ``sample['target']`` and ``sample['mask']`` are passed through untouched.
    """

    def __init__(self, mean=0., std=2.):
        # nn.Module.__init__ must run before attributes are assigned; without
        # it the module lacks its internal registries and calls such as
        # .parameters(), .to() or hook registration raise AttributeError.
        super().__init__()
        self.std = std
        self.mean = mean

    def forward(self, sample):
        """Return a new sample dict with a noisy, clipped copy of the image.

        Args:
            sample: dict with 'image' (a floating-point tensor), 'target' and
                'mask'.

        Returns:
            dict with keys 'image', 'target', 'mask'.
        """
        image = sample['image']
        # randn_like matches the image's dtype and device (torch.randn(size)
        # always produced a CPU tensor). The out-of-place add leaves the
        # caller's tensor unmodified.
        noise = torch.randn_like(image) * self.std + self.mean
        image = torch.clip(image + noise, 0, 1)

        return {'image': image,
                'target': sample['target'],
                'mask': sample['mask']}

    def __repr__(self):
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"
|
|
|
|
|
|
class ColorJitter(torch.nn.Module):
    """Randomly change brightness, contrast, saturation and hue of ``sample['image']``.

    This is a sample-dict variant of torchvision's ``ColorJitter``. Only the
    image is modified; ``target`` and ``mask`` are passed through.

    Each argument is either a single non-negative number ``v`` (sampled from
    ``[center - v, center + v]``, with center 1 for brightness, contrast and
    saturation, and 0 for hue) or a ``(min, max)`` pair. A range that collapses
    to the center disables that adjustment (stored as ``None``).
    """

    def __init__(
        self,
        brightness: Union[float, Tuple[float, float]] = 0,
        contrast: Union[float, Tuple[float, float]] = 0,
        saturation: Union[float, Tuple[float, float]] = 0,
        hue: Union[float, Tuple[float, float]] = 0,
    ) -> None:
        super().__init__()
        self.brightness = self._check_input(brightness, "brightness")
        self.contrast = self._check_input(contrast, "contrast")
        self.saturation = self._check_input(saturation, "saturation")
        self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)

    @torch.jit.unused
    def _check_input(self, value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True):
        """Normalise a jitter argument into a ``(min, max)`` tuple or ``None``.

        Raises:
            ValueError: a single number is negative, or the range falls
                outside ``bound`` / has ``min > max``.
            TypeError: ``value`` is neither a number nor a length-2 list/tuple.
        """
        if isinstance(value, numbers.Number):
            if value < 0:
                raise ValueError(f"If {name} is a single number, it must be non negative.")
            value = [center - float(value), center + float(value)]
            if clip_first_on_zero:
                # Brightness/contrast/saturation factors cannot be negative.
                value[0] = max(value[0], 0.0)
        elif isinstance(value, (tuple, list)) and len(value) == 2:
            value = [float(value[0]), float(value[1])]
        else:
            raise TypeError(f"{name} should be a single number or a list/tuple with length 2.")

        if not bound[0] <= value[0] <= value[1] <= bound[1]:
            raise ValueError(f"{name} values should be between {bound}, but got {value}.")

        # A degenerate range at the center is a no-op; None disables the op.
        if value[0] == value[1] == center:
            return None
        else:
            return tuple(value)

    @staticmethod
    def get_params(
        brightness: Optional[List[float]],
        contrast: Optional[List[float]],
        saturation: Optional[List[float]],
        hue: Optional[List[float]],
    ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]:
        """Get the parameters for the randomized transform to be applied on image.

        Args:
            brightness (tuple of float (min, max), optional): The range from which the brightness_factor is chosen
                uniformly. Pass None to turn off the transformation.
            contrast (tuple of float (min, max), optional): The range from which the contrast_factor is chosen
                uniformly. Pass None to turn off the transformation.
            saturation (tuple of float (min, max), optional): The range from which the saturation_factor is chosen
                uniformly. Pass None to turn off the transformation.
            hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly.
                Pass None to turn off the transformation.

        Returns:
            tuple: A random permutation of ``0..3`` (the order in which the
            four adjustments are applied), followed by the sampled
            brightness, contrast, saturation and hue factors (``None`` for
            disabled ones).
        """
        fn_idx = torch.randperm(4)

        b = None if brightness is None else float(torch.empty(1).uniform_(brightness[0], brightness[1]))
        c = None if contrast is None else float(torch.empty(1).uniform_(contrast[0], contrast[1]))
        s = None if saturation is None else float(torch.empty(1).uniform_(saturation[0], saturation[1]))
        h = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1]))

        return fn_idx, b, c, s, h

    def forward(self, sample):
        """Apply the enabled colour adjustments to ``sample['image']`` in random order.

        Args:
            sample (dict): Sample with keys 'image' (PIL Image or Tensor
                accepted by ``torchvision.transforms.functional.adjust_*``),
                'target' and 'mask'.

        Returns:
            dict: New sample dict with the colour-jittered 'image' and the
            original 'target' and 'mask'.
        """
        fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params(
            self.brightness, self.contrast, self.saturation, self.hue
        )

        image = sample['image']

        for fn_id in fn_idx:
            if fn_id == 0 and brightness_factor is not None:
                image = f.adjust_brightness(image, brightness_factor)
            elif fn_id == 1 and contrast_factor is not None:
                image = f.adjust_contrast(image, contrast_factor)
            elif fn_id == 2 and saturation_factor is not None:
                image = f.adjust_saturation(image, saturation_factor)
            elif fn_id == 3 and hue_factor is not None:
                image = f.adjust_hue(image, hue_factor)

        return {'image': image,
                'target': sample['target'],
                'mask': sample['mask']}

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"brightness={self.brightness}"
            f", contrast={self.contrast}"
            f", saturation={self.saturation}"
            f", hue={self.hue})"
        )
        return s
|
|
|
|
|
|
|
|
# Training-time augmentation: tensor conversion, a random left/right flip that
# also permutes mirrored keypoint channels, colour jitter of +/-0.05 on every
# channel property, then additive Gaussian noise (mean 0, std 0.1).
# Note: each jitter value is a plain float (not a tuple), i.e. a symmetric
# range around the property's neutral value.
transforms = v2.Compose([
    ToTensor(),
    RandomHorizontalFlip(p=0.5),
    ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.05),
    AddGaussianNoise(mean=0, std=0.1),
])

# Evaluation-time pipeline: tensor conversion only, no randomness.
no_transforms = v2.Compose([
    ToTensor(),
])
|
|
|
|
|
|
|