File size: 11,920 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 |
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.fx
import torchvision
from torch import nn, Tensor
from torchvision.ops.boxes import box_area
from ..utils import _log_api_usage_once
from .roi_align import roi_align
# copying result_idx_in_level to a specific index in result[]
# is not supported by ONNX tracing yet.
# _onnx_merge_levels() is an implementation supported by ONNX
# that merges the levels to the right indices
@torch.jit.unused
def _onnx_merge_levels(levels: Tensor, unmerged_results: List[Tensor]) -> Tensor:
first_result = unmerged_results[0]
dtype, device = first_result.dtype, first_result.device
res = torch.zeros(
(levels.size(0), first_result.size(1), first_result.size(2), first_result.size(3)), dtype=dtype, device=device
)
for level in range(len(unmerged_results)):
index = torch.where(levels == level)[0].view(-1, 1, 1, 1)
index = index.expand(
index.size(0),
unmerged_results[level].size(1),
unmerged_results[level].size(2),
unmerged_results[level].size(3),
)
res = res.scatter(0, index, unmerged_results[level])
return res
# TODO: (eellison) T54974082 https://github.com/pytorch/pytorch/issues/26744
def initLevelMapper(
    k_min: int,
    k_max: int,
    canonical_scale: int = 224,
    canonical_level: int = 4,
    eps: float = 1e-6,
):
    """Build a ``LevelMapper``, supplying default values for the optional arguments.

    Every argument is forwarded positionally, unchanged, to ``LevelMapper``;
    see that class for what each one means.

    Returns:
        LevelMapper: the newly constructed mapper.
    """
    # NOTE(review): presumably this free-function wrapper exists so that
    # TorchScript callers never depend on constructor default arguments
    # (pytorch issue 26744) — confirm before inlining it.
    mapper = LevelMapper(k_min, k_max, canonical_scale, canonical_level, eps)
    return mapper
class LevelMapper:
    """Assign every RoI to a feature-pyramid level using Eqn. (1) of the FPN paper.

    The level of a box is
    ``floor(canonical_level + log2(sqrt(area) / canonical_scale) + eps)``,
    clamped to ``[k_min, k_max]`` and then shifted so that ``k_min`` becomes 0.

    Args:
        k_min (int): lowest pyramid level that may be assigned.
        k_max (int): highest pyramid level that may be assigned.
        canonical_scale (int): box side length (``sqrt(area)``) that maps to
            ``canonical_level``. Default: 224.
        canonical_level (int): level assigned to a box whose side equals
            ``canonical_scale``. Default: 4.
        eps (float): small constant added before flooring. Default: 1e-6.
    """

    def __init__(
        self,
        k_min: int,
        k_max: int,
        canonical_scale: int = 224,
        canonical_level: int = 4,
        eps: float = 1e-6,
    ):
        self.k_min = k_min
        self.k_max = k_max
        self.s0 = canonical_scale  # reference box side length
        self.lvl0 = canonical_level  # level of a box whose side is s0
        self.eps = eps

    def __call__(self, boxlists: List[Tensor]) -> Tensor:
        """Return the 0-based level index of every box.

        Args:
            boxlists (List[Tensor]): per-image box tensors; they are
                concatenated in order, so the output follows the same order.

        Returns:
            Tensor: int64 tensor with one entry per box, in
            ``[0, k_max - k_min]`` when ``k_min <= k_max``.
        """
        box_sides = torch.sqrt(torch.cat([box_area(boxlist) for boxlist in boxlists]))
        eps = torch.tensor(self.eps, dtype=box_sides.dtype)
        # Eqn. (1) in the FPN paper; eps presumably absorbs floating-point error
        # that would otherwise push an exact level just below an integer.
        raw_levels = torch.floor(self.lvl0 + torch.log2(box_sides / self.s0) + eps)
        clamped = torch.clamp(raw_levels, min=self.k_min, max=self.k_max)
        return clamped.to(torch.int64) - self.k_min
def _convert_to_roi_format(boxes: List[Tensor]) -> Tensor:
concat_boxes = torch.cat(boxes, dim=0)
device, dtype = concat_boxes.device, concat_boxes.dtype
ids = torch.cat(
[torch.full_like(b[:, :1], i, dtype=dtype, layout=torch.strided, device=device) for i, b in enumerate(boxes)],
dim=0,
)
rois = torch.cat([ids, concat_boxes], dim=1)
return rois
def _infer_scale(feature: Tensor, original_size: List[int]) -> float:
# assumption: the scale is of the form 2 ** (-k), with k integer
size = feature.shape[-2:]
possible_scales: List[float] = []
for s1, s2 in zip(size, original_size):
approx_scale = float(s1) / float(s2)
scale = 2 ** float(torch.tensor(approx_scale).log2().round())
possible_scales.append(scale)
return possible_scales[0]
@torch.fx.wrap
def _setup_scales(
    features: List[Tensor], image_shapes: List[Tuple[int, int]], canonical_scale: int, canonical_level: int
) -> Tuple[List[float], LevelMapper]:
    """Infer the scale of each feature map and build the matching level mapper.

    Args:
        features (List[Tensor]): feature maps that will be pooled. The mapper's
            lowest level is derived from the first entry and its highest level
            from the last, so they are expected to run from finest (largest
            scale) to coarsest.
        image_shapes (List[Tuple[int, int]]): (height, width) of each image
            before feature extraction.
        canonical_scale (int): forwarded to the level mapper.
        canonical_level (int): forwarded to the level mapper.

    Returns:
        Tuple[List[float], LevelMapper]: one scale per feature map, and the
        mapper that assigns boxes to indices into ``features``.

    Raises:
        ValueError: if ``image_shapes`` or ``features`` is empty.
    """
    if not image_shapes:
        raise ValueError("images list should not be empty")
    # An empty list here (e.g. featmap_names matched none of the input keys)
    # previously surfaced as an opaque IndexError at scales[0] below.
    if not features:
        raise ValueError("features list should not be empty")
    # Reference size: per-dimension maximum over the batch.
    max_h = 0
    max_w = 0
    for shape in image_shapes:
        max_h = max(shape[0], max_h)
        max_w = max(shape[1], max_w)
    original_input_shape = (max_h, max_w)
    scales = [_infer_scale(feat, original_input_shape) for feat in features]
    # get the levels in the feature map by leveraging the fact that the network always
    # downsamples by a factor of 2 at each level.
    lvl_min = -torch.log2(torch.tensor(scales[0], dtype=torch.float32)).item()
    lvl_max = -torch.log2(torch.tensor(scales[-1], dtype=torch.float32)).item()
    map_levels = initLevelMapper(
        int(lvl_min),
        int(lvl_max),
        canonical_scale=canonical_scale,
        canonical_level=canonical_level,
    )
    return scales, map_levels
@torch.fx.wrap
def _filter_input(x: Dict[str, Tensor], featmap_names: List[str]) -> List[Tensor]:
x_filtered = []
for k, v in x.items():
if k in featmap_names:
x_filtered.append(v)
return x_filtered
@torch.fx.wrap
def _multiscale_roi_align(
    x_filtered: List[Tensor],
    boxes: List[Tensor],
    output_size: List[int],
    sampling_ratio: int,
    scales: Optional[List[float]],
    mapper: Optional[LevelMapper],
) -> Tensor:
    """Pool every box from the feature level that ``mapper`` assigns it to.

    Args:
        x_filtered (List[Tensor]): feature maps to pool from, one per level, in
            the same order as ``scales``. The output's channel count, dtype and
            device are taken from ``x_filtered[0]``.
        boxes (List[Tensor[N, 4]]): boxes to be used to perform the pooling operation, in
            (x1, y1, x2, y2) format and in the image reference size, not the feature map
            reference. The coordinate must satisfy ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
        output_size (List[int]): pooled output size. NOTE(review): on the
            multi-level path it is concatenated onto a tuple, so a Python
            ``list`` would raise ``TypeError`` there; ``MultiScaleRoIAlign``
            passes a tuple.
        sampling_ratio (int): sampling ratio for ROIAlign
        scales (Optional[List[float]]): spatial scale of each entry of
            ``x_filtered``. Must not be None.
        mapper (Optional[LevelMapper]): maps boxes to indices into
            ``x_filtered``. Must not be None.

    Returns:
        result (Tensor): with a single level, the ``roi_align`` output for that
        level. Otherwise a tensor of shape ``(num_rois, C) + output_size`` whose
        row ``i`` holds the pooled features of the i-th box (boxes of all images
        concatenated in order). Rows whose level index has no matching feature
        map are left as zeros.

    Raises:
        ValueError: if ``scales`` or ``mapper`` is None.
    """
    if scales is None or mapper is None:
        raise ValueError("scales and mapper should not be None")
    num_levels = len(x_filtered)
    # One row per box: [image_index, x1, y1, x2, y2].
    rois = _convert_to_roi_format(boxes)
    # Single level: no level assignment needed, pool everything from it.
    # NOTE(review): an empty x_filtered falls through to x_filtered[0] below
    # and raises IndexError.
    if num_levels == 1:
        return roi_align(
            x_filtered[0],
            rois,
            output_size=output_size,
            spatial_scale=scales[0],
            sampling_ratio=sampling_ratio,
        )
    levels = mapper(boxes)
    num_rois = len(rois)
    num_channels = x_filtered[0].shape[1]
    dtype, device = x_filtered[0].dtype, x_filtered[0].device
    result = torch.zeros(
        (
            num_rois,
            num_channels,
        )
        + output_size,
        dtype=dtype,
        device=device,
    )
    # Under tracing, per-level results are collected and merged at the end by
    # _onnx_merge_levels instead of being written into `result` by index.
    tracing_results = []
    for level, (per_level_feature, scale) in enumerate(zip(x_filtered, scales)):
        idx_in_level = torch.where(levels == level)[0]
        rois_per_level = rois[idx_in_level]
        result_idx_in_level = roi_align(
            per_level_feature,
            rois_per_level,
            output_size=output_size,
            spatial_scale=scale,
            sampling_ratio=sampling_ratio,
        )
        if torchvision._is_tracing():
            tracing_results.append(result_idx_in_level.to(dtype))
        else:
            # result and result_idx_in_level's dtypes are based on dtypes of different
            # elements in x_filtered. x_filtered contains tensors output by different
            # layers. When autocast is active, it may choose different dtypes for
            # different layers' outputs. Therefore, we defensively match result's dtype
            # before copying elements from result_idx_in_level in the following op.
            # We need to cast manually (can't rely on autocast to cast for us) because
            # the op acts on result in-place, and autocast only affects out-of-place ops.
            result[idx_in_level] = result_idx_in_level.to(result.dtype)
    if torchvision._is_tracing():
        # The zero-filled `result` allocated above is discarded on this path.
        result = _onnx_merge_levels(levels, tracing_results)
    return result
class MultiScaleRoIAlign(nn.Module):
    """
    Multi-scale RoIAlign pooling, which is useful for detection with or without FPN.

    It infers the scale of the pooling via the heuristics specified in eq. 1
    of the `Feature Pyramid Network paper <https://arxiv.org/abs/1612.03144>`_.
    The keyword-only parameters ``canonical_scale`` and ``canonical_level``
    correspond respectively to ``224`` and ``k0=4`` in eq. 1, and
    have the following meaning: ``canonical_level`` is the target level of the pyramid from
    which to pool a region of interest with ``w x h = canonical_scale x canonical_scale``.

    The per-level scales and the level mapper are inferred from the inputs of
    the first ``forward`` call and cached on the module; later calls reuse them.

    Args:
        featmap_names (List[str]): the names of the feature maps that will be used
            for the pooling.
        output_size (int or Tuple[int, int] or List[int]): output size for the pooled
            region; an int ``s`` is expanded to ``(s, s)``.
        sampling_ratio (int): sampling ratio for ROIAlign
        canonical_scale (int, optional): canonical_scale for LevelMapper. Default: 224.
        canonical_level (int, optional): canonical_level for LevelMapper. Default: 4.

    Examples::

        >>> m = torchvision.ops.MultiScaleRoIAlign(['feat1', 'feat3'], 3, 2)
        >>> i = OrderedDict()
        >>> i['feat1'] = torch.rand(1, 5, 64, 64)
        >>> i['feat2'] = torch.rand(1, 5, 32, 32)  # this feature won't be used in the pooling
        >>> i['feat3'] = torch.rand(1, 5, 16, 16)
        >>> # create some random bounding boxes
        >>> boxes = torch.rand(6, 4) * 256; boxes[:, 2:] += boxes[:, :2]
        >>> # original image size, before computing the feature maps
        >>> image_sizes = [(512, 512)]
        >>> output = m(i, [boxes], image_sizes)
        >>> print(output.shape)
        torch.Size([6, 5, 3, 3])
    """

    # NOTE(review): presumably declares the types of the attributes that start
    # out as None so TorchScript can type them — confirm before removing.
    __annotations__ = {"scales": Optional[List[float]], "map_levels": Optional[LevelMapper]}

    def __init__(
        self,
        featmap_names: List[str],
        output_size: Union[int, Tuple[int], List[int]],
        sampling_ratio: int,
        *,
        canonical_scale: int = 224,
        canonical_level: int = 4,
    ):
        super().__init__()
        _log_api_usage_once(self)
        if isinstance(output_size, int):
            output_size = (output_size, output_size)
        self.featmap_names = featmap_names
        self.sampling_ratio = sampling_ratio
        # Stored as a tuple: _multiscale_roi_align concatenates it onto a tuple.
        self.output_size = tuple(output_size)
        # Filled lazily by the first forward() call.
        self.scales = None
        self.map_levels = None
        self.canonical_scale = canonical_scale
        self.canonical_level = canonical_level

    def forward(
        self,
        x: Dict[str, Tensor],
        boxes: List[Tensor],
        image_shapes: List[Tuple[int, int]],
    ) -> Tensor:
        """
        Args:
            x (OrderedDict[Tensor]): feature maps for each level. They are assumed to have
                all the same number of channels, but they can have different sizes. Only
                entries whose key is in ``featmap_names`` are used, in the dict's order.
            boxes (List[Tensor[N, 4]]): boxes to be used to perform the pooling operation, in
                (x1, y1, x2, y2) format and in the image reference size, not the feature map
                reference. The coordinate must satisfy ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
            image_shapes (List[Tuple[height, width]]): the sizes of each image before they
                have been fed to a CNN to obtain feature maps. This allows us to infer the
                scale factor for each one of the levels to be pooled. Only used on the
                first call, when the scales are inferred and cached.
        Returns:
            result (Tensor)
        """
        x_filtered = _filter_input(x, self.featmap_names)
        # NOTE(review): scales are inferred once and then reused even if later
        # calls pass different image_shapes — confirm this is intended.
        if self.scales is None or self.map_levels is None:
            self.scales, self.map_levels = _setup_scales(
                x_filtered, image_shapes, self.canonical_scale, self.canonical_level
            )
        return _multiscale_roi_align(
            x_filtered,
            boxes,
            self.output_size,
            self.sampling_ratio,
            self.scales,
            self.map_levels,
        )

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(featmap_names={self.featmap_names}, "
            f"output_size={self.output_size}, sampling_ratio={self.sampling_ratio})"
        )
|