import functools
from typing import List, Union

import torch
import torch.fx
from torch import nn, Tensor
from torch._dynamo.utils import is_compile_supported
from torch.jit.annotations import BroadcastingList2
from torch.nn.modules.utils import _pair
from torchvision.extension import _assert_has_ops, _has_ops

from ..utils import _log_api_usage_once
from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format


def lazy_compile(**compile_kwargs):
    """Lazily wrap a function with torch.compile on the first call.

    This avoids eagerly importing dynamo.
    """

    def decorate_fn(fn):
        @functools.wraps(fn)
        def compile_hook(*args, **kwargs):
            # Compile on first use, then rebind the module-level name so
            # subsequent calls go straight to the compiled function.
            compiled_fn = torch.compile(fn, **compile_kwargs)
            globals()[fn.__name__] = functools.wraps(fn)(compiled_fn)
            return compiled_fn(*args, **kwargs)

        return compile_hook

    return decorate_fn

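
# A minimal usage sketch for ``lazy_compile`` (illustrative; ``f`` is a
# hypothetical function). Note that ``globals()`` inside the hook is *this*
# module's namespace, so the rebinding trick assumes the decorated function is
# defined in this file, as ``_roi_align`` below is.
#
#     @lazy_compile(dynamic=True)
#     def f(x):
#         return x * 2
#
#     f(torch.ones(3))  # first call: torch.compile(f) runs, ``f`` is rebound
#     f(torch.ones(3))  # second call: hits the compiled function directly

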
# NB: all inputs are tensors
def _bilinear_interpolate(
    input,  # [N, C, H, W]
    roi_batch_ind,  # [K]
    y,  # [K, PH, IY]
    x,  # [K, PW, IX]
    ymask,  # [K, IY]
    xmask,  # [K, IX]
):
    _, channels, height, width = input.size()

    # Clamp sample points to the feature map; samples at or beyond the
    # bottom/right edge collapse onto the last row/column.
    y = y.clamp(min=0)
    x = x.clamp(min=0)
    y_low = y.int()
    x_low = x.int()
    y_high = torch.where(y_low >= height - 1, height - 1, y_low + 1)
    y_low = torch.where(y_low >= height - 1, height - 1, y_low)
    # Snap y onto the boundary row so that ly below becomes 0 there
    y = torch.where(y_low >= height - 1, y_low.to(input.dtype), y)

    x_high = torch.where(x_low >= width - 1, width - 1, x_low + 1)
    x_low = torch.where(x_low >= width - 1, width - 1, x_low)
    # Likewise snap x onto the boundary column so that lx becomes 0 there
    x = torch.where(x_low >= width - 1, x_low.to(input.dtype), x)

    # Fractional offsets within the cell and their complements
    ly = y - y_low
    lx = x - x_low
    hy = 1.0 - ly
    hx = 1.0 - lx

    # Gather the four corner values. When masks are given (adaptive sampling),
    # out-of-grid sample indices are zeroed first so they are safe to index
    # with; their contributions are masked out again in _roi_align.
    def masked_index(
        y,  # [K, PH, IY]
        x,  # [K, PW, IX]
    ):
        if ymask is not None:
            assert xmask is not None
            y = torch.where(ymask[:, None, :], y, 0)
            x = torch.where(xmask[:, None, :], x, 0)
        return input[
            roi_batch_ind[:, None, None, None, None, None],
            torch.arange(channels, device=input.device)[None, :, None, None, None, None],
            y[:, None, :, None, :, None],  # prev [K, PH, IY]
            x[:, None, None, :, None, :],  # prev [K, PW, IX]
        ]  # [K, C, PH, PW, IY, IX]

    v1 = masked_index(y_low, x_low)
    v2 = masked_index(y_low, x_high)
    v3 = masked_index(y_high, x_low)
    v4 = masked_index(y_high, x_high)

    # all ws preemptively [K, C, PH, PW, IY, IX]
    def outer_prod(y, x):
        return y[:, None, :, None, :, None] * x[:, None, None, :, None, :]

    w1 = outer_prod(hy, hx)
    w2 = outer_prod(hy, lx)
    w3 = outer_prod(ly, hx)
    w4 = outer_prod(ly, lx)

    # Standard bilinear blend of the four corners
    val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4
    return val

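
# A scalar sanity check for the weights above (illustrative): sampling at
# (y, x) = (1.25, 2.5) gives ly = 0.25, lx = 0.5, hy = 0.75, hx = 0.5, so the
# four corners (1, 2), (1, 3), (2, 2), (2, 3) are blended with weights
# w1 = 0.375, w2 = 0.375, w3 = 0.125, w4 = 0.125, which sum to 1.
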

def maybe_cast(tensor):
    # Under autocast, CUDA inputs may arrive in half precision; upcast to
    # float32 (but never touch float64) so the interpolation arithmetic is
    # done in full precision.
    if torch.is_autocast_enabled() and tensor.is_cuda and tensor.dtype != torch.double:
        return tensor.float()
    else:
        return tensor


# Pure-PyTorch, differentiable implementation of roi_align. It is slower than
# the compiled torchvision op, but its backward pass is deterministic;
# roi_align() below dispatches to it when the compiled ops are unavailable or
# when deterministic algorithms are enabled for CUDA/MPS/XPU inputs.
@lazy_compile(dynamic=True)
def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
    orig_dtype = input.dtype

    input = maybe_cast(input)
    rois = maybe_cast(rois)

    _, _, height, width = input.size()

    ph = torch.arange(pooled_height, device=input.device)  # [PH]
    pw = torch.arange(pooled_width, device=input.device)  # [PW]

    # rois is [K, 5]: (batch_index, x1, y1, x2, y2) scaled into input coordinates
    roi_batch_ind = rois[:, 0].int()  # [K]
    offset = 0.5 if aligned else 0.0
    roi_start_w = rois[:, 1] * spatial_scale - offset  # [K]
    roi_start_h = rois[:, 2] * spatial_scale - offset  # [K]
    roi_end_w = rois[:, 3] * spatial_scale - offset  # [K]
    roi_end_h = rois[:, 4] * spatial_scale - offset  # [K]

    roi_width = roi_end_w - roi_start_w  # [K]
    roi_height = roi_end_h - roi_start_h  # [K]
    if not aligned:
        # Legacy behavior: force ROIs to be at least 1x1
        roi_width = torch.clamp(roi_width, min=1.0)
        roi_height = torch.clamp(roi_height, min=1.0)

    bin_size_h = roi_height / pooled_height  # [K]
    bin_size_w = roi_width / pooled_width  # [K]

    exact_sampling = sampling_ratio > 0

    roi_bin_grid_h = sampling_ratio if exact_sampling else torch.ceil(roi_height / pooled_height)  # scalar or [K]
    roi_bin_grid_w = sampling_ratio if exact_sampling else torch.ceil(roi_width / pooled_width)  # scalar or [K]

    if exact_sampling:
        count = max(roi_bin_grid_h * roi_bin_grid_w, 1)  # scalar
        iy = torch.arange(roi_bin_grid_h, device=input.device)  # [IY]
        ix = torch.arange(roi_bin_grid_w, device=input.device)  # [IX]
        ymask = None
        xmask = None
    else:
        count = torch.clamp(roi_bin_grid_h * roi_bin_grid_w, min=1)  # [K]
        # With adaptive sampling the grid size is data-dependent (it varies
        # per ROI), so allocate the largest possible grid and mask out the
        # entries each ROI does not need.
        iy = torch.arange(height, device=input.device)  # [IY]
        ix = torch.arange(width, device=input.device)  # [IX]
        ymask = iy[None, :] < roi_bin_grid_h[:, None]  # [K, IY]
        xmask = ix[None, :] < roi_bin_grid_w[:, None]  # [K, IX]

    def from_K(t):
        return t[:, None, None]

    # Sample point coordinates: ROI start + bin offset + intra-bin offset
    y = (
        from_K(roi_start_h)
        + ph[None, :, None] * from_K(bin_size_h)
        + (iy[None, None, :] + 0.5).to(input.dtype) * from_K(bin_size_h / roi_bin_grid_h)
    )  # [K, PH, IY]
    x = (
        from_K(roi_start_w)
        + pw[None, :, None] * from_K(bin_size_w)
        + (ix[None, None, :] + 0.5).to(input.dtype) * from_K(bin_size_w / roi_bin_grid_w)
    )  # [K, PW, IX]
    val = _bilinear_interpolate(input, roi_batch_ind, y, x, ymask, xmask)  # [K, C, PH, PW, IY, IX]

    # Mask out samples that weren't actually adaptively needed
    if not exact_sampling:
        val = torch.where(ymask[:, None, None, None, :, None], val, 0)
        val = torch.where(xmask[:, None, None, None, None, :], val, 0)

    # Average over the sampling grid: sum over IY, IX ~> [K, C, PH, PW]
    output = val.sum((-1, -2))
    if isinstance(count, torch.Tensor):
        output /= count[:, None, None, None]
    else:
        output /= count

    output = output.to(orig_dtype)

    return output

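
# A hedged parity-check sketch against the compiled kernel (illustrative; the
# op below exists in torchvision, but tolerances may need adjusting):
#
#     x = torch.rand(2, 3, 32, 32)
#     rois = torch.tensor([[0.0, 4.0, 4.0, 20.0, 20.0],
#                          [1.0, 0.0, 0.0, 16.0, 16.0]])
#     out = _roi_align(x, rois, 1.0, 7, 7, 2, True)  # this pure-PyTorch path
#     ref = torch.ops.torchvision.roi_align(x, rois, 1.0, 7, 7, 2, True)
#     torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)

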
@torch.fx.wrap
def roi_align(
    input: Tensor,
    boxes: Union[Tensor, List[Tensor]],
    output_size: BroadcastingList2[int],
    spatial_scale: float = 1.0,
    sampling_ratio: int = -1,
    aligned: bool = False,
) -> Tensor:
    """
    Performs the Region of Interest (RoI) Align operator with average pooling, as described in Mask R-CNN.

    Args:
        input (Tensor[N, C, H, W]): The input tensor, i.e. a batch with ``N`` elements. Each element
            contains ``C`` feature maps of dimensions ``H x W``.
            If the tensor is quantized, we expect a batch size of ``N == 1``.
        boxes (Tensor[K, 5] or List[Tensor[L, 4]]): the box coordinates in (x1, y1, x2, y2)
            format where the regions will be taken from.
            The coordinates must satisfy ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
            If a single Tensor is passed, then the first column should
            contain the index of the corresponding element in the batch, i.e. a number in ``[0, N - 1]``.
            If a list of Tensors is passed, then each Tensor will correspond to the boxes for an element i
            in the batch.
        output_size (int or Tuple[int, int]): the size of the output (in bins or pixels) after the pooling
            is performed, as (height, width).
        spatial_scale (float): a scaling factor that maps the box coordinates to
            the input coordinates. For example, if your boxes are defined on the scale
            of a 224x224 image and your input is a 112x112 feature map (resulting from a 0.5x scaling of
            the original image), you'll want to set this to 0.5. Default: 1.0
        sampling_ratio (int): number of sampling points in the interpolation grid
            used to compute the output value of each pooled output bin. If > 0,
            then exactly ``sampling_ratio x sampling_ratio`` sampling points per bin are used. If
            <= 0, then an adaptive number of grid points is used (computed as
            ``ceil(roi_width / output_width)``, and likewise for height). Default: -1
        aligned (bool): If False, use the legacy implementation.
            If True, shift the box coordinates by -0.5 for a better alignment with the two
            neighboring pixel indices. This version is used in Detectron2. Default: False

    Returns:
        Tensor[K, C, output_size[0], output_size[1]]: The pooled RoIs.
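
    Examples::

        >>> # An illustrative sketch; names and values here are arbitrary.
        >>> feats = torch.rand(1, 256, 50, 50)
        >>> rois = torch.tensor([[0.0, 10.0, 10.0, 30.0, 30.0]])  # (batch_idx, x1, y1, x2, y2)
        >>> pooled = roi_align(feats, rois, output_size=(7, 7), aligned=True)
        >>> pooled.shape
        torch.Size([1, 256, 7, 7])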
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(roi_align)
    check_roi_boxes_shape(boxes)
    rois = boxes
    output_size = _pair(output_size)
    if not isinstance(rois, torch.Tensor):
        rois = convert_boxes_to_roi_format(rois)
    if not torch.jit.is_scripting():
        # Use the pure-PyTorch fallback when the compiled ops are missing, or
        # when deterministic algorithms are requested on a backend whose
        # kernel is nondeterministic, provided torch.compile supports the
        # device.
        if (
            not _has_ops()
            or (torch.are_deterministic_algorithms_enabled() and (input.is_cuda or input.is_mps or input.is_xpu))
        ) and is_compile_supported(input.device.type):
            return _roi_align(input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned)
    _assert_has_ops()
    return torch.ops.torchvision.roi_align(
        input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned
    )

class RoIAlign(nn.Module):
    """
    See :func:`roi_align`.
    """

    def __init__(
        self,
        output_size: BroadcastingList2[int],
        spatial_scale: float,
        sampling_ratio: int,
        aligned: bool = False,
    ):
        super().__init__()
        _log_api_usage_once(self)
        self.output_size = output_size
        self.spatial_scale = spatial_scale
        self.sampling_ratio = sampling_ratio
        self.aligned = aligned

    def forward(self, input: Tensor, rois: Union[Tensor, List[Tensor]]) -> Tensor:
        return roi_align(input, rois, self.output_size, self.spatial_scale, self.sampling_ratio, self.aligned)

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"output_size={self.output_size}"
            f", spatial_scale={self.spatial_scale}"
            f", sampling_ratio={self.sampling_ratio}"
            f", aligned={self.aligned}"
            f")"
        )
        return s
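
# A minimal usage sketch for the module wrapper (illustrative; feature_map and
# boxes are hypothetical): pooling 7x7 features from boxes given in image
# coordinates over a stride-4 feature map.
#
#     pool = RoIAlign(output_size=(7, 7), spatial_scale=0.25, sampling_ratio=2, aligned=True)
#     out = pool(feature_map, boxes)  # boxes: Tensor[K, 5] or List[Tensor[L, 4]]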