File size: 6,013 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
from collections.abc import Sequence
from typing import Union

import torch
from torch import Tensor
from torch.nn import functional as F  # noqa: N812


def _gaussian(kernel_size: int, sigma: float, dtype: torch.dtype, device: Union[torch.device, str]) -> Tensor:
    """Compute 1D gaussian kernel.

    Args:
        kernel_size: size of the gaussian kernel
        sigma: Standard deviation of the gaussian kernel
        dtype: data type of the output tensor
        device: device of the output tensor

    Example:
        >>> _gaussian(3, 1, torch.float, 'cpu')
        tensor([[0.2741, 0.4519, 0.2741]])

    """
    dist = torch.arange(start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2, step=1, dtype=dtype, device=device)
    gauss = torch.exp(-torch.pow(dist / sigma, 2) / 2)
    return (gauss / gauss.sum()).unsqueeze(dim=0)  # (1, kernel_size)


def _gaussian_kernel_2d(
    channel: int,
    kernel_size: Sequence[int],
    sigma: Sequence[float],
    dtype: torch.dtype,
    device: Union[torch.device, str],
) -> Tensor:
    """Compute 2D gaussian kernel.

    Args:
        channel: number of channels in the image
        kernel_size: size of the gaussian kernel as a tuple (h, w)
        sigma: Standard deviation of the gaussian kernel
        dtype: data type of the output tensor
        device: device of the output tensor

    Example:
        >>> _gaussian_kernel_2d(1, (5,5), (1,1), torch.float, "cpu")
        tensor([[[[0.0030, 0.0133, 0.0219, 0.0133, 0.0030],
                  [0.0133, 0.0596, 0.0983, 0.0596, 0.0133],
                  [0.0219, 0.0983, 0.1621, 0.0983, 0.0219],
                  [0.0133, 0.0596, 0.0983, 0.0596, 0.0133],
                  [0.0030, 0.0133, 0.0219, 0.0133, 0.0030]]]])

    """
    gaussian_kernel_x = _gaussian(kernel_size[0], sigma[0], dtype, device)
    gaussian_kernel_y = _gaussian(kernel_size[1], sigma[1], dtype, device)
    kernel = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y)  # (kernel_size, 1) * (1, kernel_size)

    return kernel.expand(channel, 1, kernel_size[0], kernel_size[1])


def _uniform_weight_bias_conv2d(inputs: Tensor, window_size: int) -> tuple[Tensor, Tensor]:
    """Construct uniform weight and bias for a 2d convolution.

    Args:
        inputs: Input image
        window_size: size of convolutional kernel

    Return:
        The weight and bias for 2d convolution

    """
    kernel_weight = torch.ones(1, 1, window_size, window_size, dtype=inputs.dtype, device=inputs.device)
    kernel_weight /= window_size**2
    kernel_bias = torch.zeros(1, dtype=inputs.dtype, device=inputs.device)
    return kernel_weight, kernel_bias


def _single_dimension_pad(inputs: Tensor, dim: int, pad: int, outer_pad: int = 0) -> Tensor:
    """Apply single-dimension reflection padding to match scipy implementation.

    Args:
        inputs: Input image
        dim: A dimension the image should be padded over
        pad: Number of pads
        outer_pad: Number of outer pads

    Return:
        Image padded over a single dimension

    """
    _max = inputs.shape[dim]
    x = torch.index_select(inputs, dim, torch.arange(pad - 1, -1, -1).to(inputs.device))
    y = torch.index_select(inputs, dim, torch.arange(_max - 1, _max - pad - outer_pad, -1).to(inputs.device))
    return torch.cat((x, inputs, y), dim)


def _reflection_pad_2d(inputs: Tensor, pad: int, outer_pad: int = 0) -> Tensor:
    """Apply reflection padding to the input image.

    Args:
        inputs: Input image
        pad: Number of pads
        outer_pad: Number of outer pads

    Return:
        Padded image

    """
    for dim in [2, 3]:
        inputs = _single_dimension_pad(inputs, dim, pad, outer_pad)
    return inputs


def _uniform_filter(inputs: Tensor, window_size: int) -> Tensor:
    """Apply uniform filter with a window of a given size over the input image.

    Args:
        inputs: Input image
        window_size: Sliding window used for rmse calculation

    Return:
        Image transformed with the uniform input

    """
    inputs = _reflection_pad_2d(inputs, window_size // 2, window_size % 2)
    kernel_weight, kernel_bias = _uniform_weight_bias_conv2d(inputs, window_size)
    # Iterate over channels
    return torch.cat(
        [
            F.conv2d(inputs[:, channel].unsqueeze(1), kernel_weight, kernel_bias, padding=0)
            for channel in range(inputs.shape[1])
        ],
        dim=1,
    )


def _gaussian_kernel_3d(
    channel: int, kernel_size: Sequence[int], sigma: Sequence[float], dtype: torch.dtype, device: torch.device
) -> Tensor:
    """Compute 3D gaussian kernel.

    Args:
        channel: number of channels in the image
        kernel_size: size of the gaussian kernel as a tuple (h, w, d)
        sigma: Standard deviation of the gaussian kernel
        dtype: data type of the output tensor
        device: device of the output tensor

    """
    gaussian_kernel_x = _gaussian(kernel_size[0], sigma[0], dtype, device)
    gaussian_kernel_y = _gaussian(kernel_size[1], sigma[1], dtype, device)
    gaussian_kernel_z = _gaussian(kernel_size[2], sigma[2], dtype, device)
    kernel_xy = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y)  # (kernel_size, 1) * (1, kernel_size)
    kernel = torch.mul(
        kernel_xy.unsqueeze(-1).repeat(1, 1, kernel_size[2]),
        gaussian_kernel_z.expand(kernel_size[0], kernel_size[1], kernel_size[2]),
    )
    return kernel.expand(channel, 1, kernel_size[0], kernel_size[1], kernel_size[2])


def _reflection_pad_3d(inputs: Tensor, pad_h: int, pad_w: int, pad_d: int) -> Tensor:
    """Reflective padding of 3d input.

    Args:
        inputs: tensor to pad, should be a 3D tensor of shape ``[N, C, H, W, D]``
        pad_w: amount of padding in the height dimension
        pad_h: amount of padding in the width dimension
        pad_d: amount of padding in the depth dimension

    Returns:
        padded input tensor

    """
    return F.pad(inputs, (pad_h, pad_h, pad_w, pad_w, pad_d, pad_d), mode="reflect")