# File size: 11,320 Bytes
# 9c6594c
import functools
from typing import List, Union
import torch
import torch.fx
from torch import nn, Tensor
from torch._dynamo.utils import is_compile_supported
from torch.jit.annotations import BroadcastingList2
from torch.nn.modules.utils import _pair
from torchvision.extension import _assert_has_ops, _has_ops
from ..utils import _log_api_usage_once
from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format
def lazy_compile(**compile_kwargs):
    """Defer ``torch.compile`` of the decorated function until its first call.

    Compiling at decoration time would import dynamo eagerly. Instead, the
    decorated name is initially bound to a small hook; the first call
    compiles the function with ``compile_kwargs``, rebinds the function's
    name in this module's globals to the compiled callable, and forwards
    the call to it.
    """

    def decorator(target):
        @functools.wraps(target)
        def first_call_hook(*args, **kwargs):
            compiled = torch.compile(target, **compile_kwargs)
            # update_wrapper copies target's metadata onto ``compiled`` and
            # returns that same object, which replaces the hook in globals.
            globals()[target.__name__] = functools.update_wrapper(compiled, target)
            return compiled(*args, **kwargs)

        return first_call_hook

    return decorator
# NB: all inputs are tensors
def _bilinear_interpolate(
input, # [N, C, H, W]
roi_batch_ind, # [K]
y, # [K, PH, IY]
x, # [K, PW, IX]
ymask, # [K, IY]
xmask, # [K, IX]
):
_, channels, height, width = input.size()
# deal with inverse element out of feature map boundary
y = y.clamp(min=0)
x = x.clamp(min=0)
y_low = y.int()
x_low = x.int()
y_high = torch.where(y_low >= height - 1, height - 1, y_low + 1)
y_low = torch.where(y_low >= height - 1, height - 1, y_low)
y = torch.where(y_low >= height - 1, y.to(input.dtype), y)
x_high = torch.where(x_low >= width - 1, width - 1, x_low + 1)
x_low = torch.where(x_low >= width - 1, width - 1, x_low)
x = torch.where(x_low >= width - 1, x.to(input.dtype), x)
ly = y - y_low
lx = x - x_low
hy = 1.0 - ly
hx = 1.0 - lx
# do bilinear interpolation, but respect the masking!
# TODO: It's possible the masking here is unnecessary if y and
# x were clamped appropriately; hard to tell
def masked_index(
y, # [K, PH, IY]
x, # [K, PW, IX]
):
if ymask is not None:
assert xmask is not None
y = torch.where(ymask[:, None, :], y, 0)
x = torch.where(xmask[:, None, :], x, 0)
return input[
roi_batch_ind[:, None, None, None, None, None],
torch.arange(channels, device=input.device)[None, :, None, None, None, None],
y[:, None, :, None, :, None], # prev [K, PH, IY]
x[:, None, None, :, None, :], # prev [K, PW, IX]
] # [K, C, PH, PW, IY, IX]
v1 = masked_index(y_low, x_low)
v2 = masked_index(y_low, x_high)
v3 = masked_index(y_high, x_low)
v4 = masked_index(y_high, x_high)
# all ws preemptively [K, C, PH, PW, IY, IX]
def outer_prod(y, x):
return y[:, None, :, None, :, None] * x[:, None, None, :, None, :]
w1 = outer_prod(hy, hx)
w2 = outer_prod(hy, lx)
w3 = outer_prod(ly, hx)
w4 = outer_prod(ly, lx)
val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4
return val
# TODO: this doesn't actually cache
# TODO: main library should make this easier to do
def maybe_cast(tensor):
    """Return ``tensor.float()`` when autocast applies to it, else ``tensor``.

    The cast happens only when autocast is enabled, the tensor is on a CUDA
    device, and its dtype is not ``torch.double``; in every other case the
    input object is returned unchanged.
    """
    needs_upcast = torch.is_autocast_enabled() and tensor.is_cuda and tensor.dtype != torch.double
    return tensor.float() if needs_upcast else tensor
# This is a pure Python and differentiable implementation of roi_align. When
# run in eager mode, it uses a lot of memory, but when compiled it has
# acceptable memory usage. The main point of this implementation is that
# its backwards is deterministic.
# It is transcribed directly off of the roi_align CUDA kernel, see
# https://dev-discuss.pytorch.org/t/a-pure-python-implementation-of-roi-align-that-looks-just-like-its-cuda-kernel/1266
@lazy_compile(dynamic=True)
def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
    # Remember the caller's dtype; the computation may run in a wider one.
    result_dtype = input.dtype
    input = maybe_cast(input)
    rois = maybe_cast(rois)

    device = input.device
    _, _, height, width = input.size()

    bin_rows = torch.arange(pooled_height, device=device)  # [PH]
    bin_cols = torch.arange(pooled_width, device=device)  # [PW]

    # input: [N, C, H, W]
    # rois: [K, 5]
    roi_batch_ind = rois[:, 0].int()  # [K]

    shift = 0.5 if aligned else 0.0

    def scaled_column(col):
        # One RoI coordinate column mapped into feature-map coordinates.
        return rois[:, col] * spatial_scale - shift  # [K]

    roi_start_w = scaled_column(1)
    roi_start_h = scaled_column(2)
    roi_end_w = scaled_column(3)
    roi_end_h = scaled_column(4)

    roi_width = roi_end_w - roi_start_w  # [K]
    roi_height = roi_end_h - roi_start_h  # [K]
    if not aligned:
        # Legacy mode forces every RoI to be at least 1x1.
        roi_width = roi_width.clamp(min=1.0)  # [K]
        roi_height = roi_height.clamp(min=1.0)  # [K]

    bin_size_h = roi_height / pooled_height  # [K]
    bin_size_w = roi_width / pooled_width  # [K]

    exact_sampling = sampling_ratio > 0
    if exact_sampling:
        # Fixed grid: the same number of samples per bin for every RoI.
        roi_bin_grid_h = sampling_ratio  # scalar
        roi_bin_grid_w = sampling_ratio  # scalar
        count = max(roi_bin_grid_h * roi_bin_grid_w, 1)  # scalar
        iy = torch.arange(roi_bin_grid_h, device=device)  # [IY]
        ix = torch.arange(roi_bin_grid_w, device=device)  # [IX]
        ymask = None
        xmask = None
    else:
        roi_bin_grid_h = torch.ceil(roi_height / pooled_height)  # [K]
        roi_bin_grid_w = torch.ceil(roi_width / pooled_width)  # [K]
        count = torch.clamp(roi_bin_grid_h * roi_bin_grid_w, min=1)  # [K]
        # With adaptive sampling the number of samples is data-dependent
        # (it follows the RoI size), which a fixed tensor dimension cannot
        # express. So we inefficiently enumerate ALL candidate sample
        # indices (one per feature-map row / column) and mask out the ones
        # a given RoI does not need.
        iy = torch.arange(height, device=device)  # [IY]
        ix = torch.arange(width, device=device)  # [IX]
        ymask = iy[None, :] < roi_bin_grid_h[:, None]  # [K, IY]
        xmask = ix[None, :] < roi_bin_grid_w[:, None]  # [K, IX]

    def sample_positions(start, bin_size, bin_index, sample_index, grid):
        # start + bin * bin_size + (sample + 0.5) * (bin_size / grid),
        # broadcast to [K, P, I].
        return (
            start[:, None, None]
            + bin_index[None, :, None] * bin_size[:, None, None]
            + (sample_index[None, None, :] + 0.5).to(input.dtype) * (bin_size / grid)[:, None, None]
        )

    y = sample_positions(roi_start_h, bin_size_h, bin_rows, iy, roi_bin_grid_h)  # [K, PH, IY]
    x = sample_positions(roi_start_w, bin_size_w, bin_cols, ix, roi_bin_grid_w)  # [K, PW, IX]

    val = _bilinear_interpolate(input, roi_batch_ind, y, x, ymask, xmask)  # [K, C, PH, PW, IY, IX]

    # Mask out samples that weren't actually adaptively needed
    if not exact_sampling:
        val = torch.where(ymask[:, None, None, None, :, None], val, 0)
        val = torch.where(xmask[:, None, None, None, None, :], val, 0)

    output = val.sum((-1, -2))  # remove IY, IX ~> [K, C, PH, PW]

    # Average over the samples of each bin; count is per-RoI when adaptive.
    divisor = count[:, None, None, None] if isinstance(count, torch.Tensor) else count
    output /= divisor

    return output.to(result_dtype)
@torch.fx.wrap
def roi_align(
    input: Tensor,
    boxes: Union[Tensor, List[Tensor]],
    output_size: BroadcastingList2[int],
    spatial_scale: float = 1.0,
    sampling_ratio: int = -1,
    aligned: bool = False,
) -> Tensor:
    """
    Performs Region of Interest (RoI) Align operator with average pooling, as described in Mask R-CNN.

    Args:
        input (Tensor[N, C, H, W]): The input tensor, i.e. a batch with ``N`` elements. Each element
            contains ``C`` feature maps of dimensions ``H x W``.
            If the tensor is quantized, we expect a batch size of ``N == 1``.
        boxes (Tensor[K, 5] or List[Tensor[L, 4]]): the box coordinates in (x1, y1, x2, y2)
            format where the regions will be taken from.
            The coordinates must satisfy ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
            If a single Tensor is passed, then the first column should
            contain the index of the corresponding element in the batch, i.e. a number in ``[0, N - 1]``.
            If a list of Tensors is passed, then each Tensor will correspond to the boxes for an element i
            in the batch.
        output_size (int or Tuple[int, int]): the size of the output (in bins or pixels) after the pooling
            is performed, as (height, width).
        spatial_scale (float): a scaling factor that maps the box coordinates to
            the input coordinates. For example, if your boxes are defined on the scale
            of a 224x224 image and your input is a 112x112 feature map (resulting from a 0.5x scaling of
            the original image), you'll want to set this to 0.5. Default: 1.0
        sampling_ratio (int): number of sampling points in the interpolation grid
            used to compute the output value of each pooled output bin. If > 0,
            then exactly ``sampling_ratio x sampling_ratio`` sampling points per bin are used. If
            <= 0, then an adaptive number of grid points are used (computed as
            ``ceil(roi_width / output_width)``, and likewise for height). Default: -1
        aligned (bool): If False, use the legacy implementation.
            If True, shift the box coordinates by -0.5 pixels for a better alignment with the two
            neighboring pixel indices. This version is used in Detectron2.

    Returns:
        Tensor[K, C, output_size[0], output_size[1]]: The pooled RoIs.
    """
    # Usage logging is skipped while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(roi_align)
    check_roi_boxes_shape(boxes)
    rois = boxes
    # Accept a single int as well as an (height, width) pair.
    output_size = _pair(output_size)
    # A list of per-image boxes is converted to the single-tensor RoI format
    # (presumably Tensor[K, 5] with the batch index first -- see convert_boxes_to_roi_format).
    if not isinstance(rois, torch.Tensor):
        rois = convert_boxes_to_roi_format(rois)
    # The pure-Python path is never taken under TorchScript scripting.
    if not torch.jit.is_scripting():
        # Use the Python implementation when the compiled torchvision ops are
        # not available, or when deterministic algorithms are requested for a
        # CUDA / MPS / XPU input -- in both cases only if torch.compile
        # supports the input's device type.
        if (
            not _has_ops()
            or (torch.are_deterministic_algorithms_enabled() and (input.is_cuda or input.is_mps or input.is_xpu))
        ) and is_compile_supported(input.device.type):
            return _roi_align(input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned)
    # Otherwise dispatch to the registered torchvision operator.
    _assert_has_ops()
    return torch.ops.torchvision.roi_align(
        input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned
    )
class RoIAlign(nn.Module):
    """
    Module form of :func:`roi_align`.

    The constructor stores the pooling configuration; ``forward`` passes it,
    together with the input and RoIs, to :func:`roi_align`.
    """

    def __init__(
        self,
        output_size: BroadcastingList2[int],
        spatial_scale: float,
        sampling_ratio: int,
        aligned: bool = False,
    ):
        super().__init__()
        _log_api_usage_once(self)
        # Stored as given; forwarded unchanged on every call.
        self.output_size = output_size
        self.spatial_scale = spatial_scale
        self.sampling_ratio = sampling_ratio
        self.aligned = aligned

    def forward(self, input: Tensor, rois: Union[Tensor, List[Tensor]]) -> Tensor:
        return roi_align(
            input,
            rois,
            self.output_size,
            self.spatial_scale,
            self.sampling_ratio,
            self.aligned,
        )

    def __repr__(self) -> str:
        # Render as ClassName(field=value, ...) in constructor-argument order.
        shown = ("output_size", "spatial_scale", "sampling_ratio", "aligned")
        rendered = ", ".join(f"{field}={getattr(self, field)}" for field in shown)
        return f"{self.__class__.__name__}({rendered})"
|