|
from typing import Tuple |
|
|
|
import torch |
|
import torchvision |
|
from torch import Tensor |
|
from torchvision.extension import _assert_has_ops |
|
|
|
from ..utils import _log_api_usage_once |
|
from ._box_convert import _box_cxcywh_to_xyxy, _box_xywh_to_xyxy, _box_xyxy_to_cxcywh, _box_xyxy_to_xywh |
|
from ._utils import _upcast |
|
|
|
|
|
def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor:
    """
    Run non-maximum suppression (NMS) over ``boxes`` using intersection-over-union (IoU).

    Lower-scoring boxes are discarded whenever their IoU with another, higher-scoring
    box is greater than ``iou_threshold``; this is repeated until no such pair remains.

    When several boxes share exactly the same score and meet the IoU criterion against
    a reference box, which one is kept is not guaranteed to match between CPU and GPU,
    in the same way PyTorch's ``argsort`` may order repeated values differently.

    Args:
        boxes (Tensor[N, 4]): candidate boxes in ``(x1, y1, x2, y2)`` format, expected to
            satisfy ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
        scores (Tensor[N]): one score per box.
        iou_threshold (float): boxes overlapping a kept box with IoU > iou_threshold are removed.

    Returns:
        Tensor: int64 indices of the boxes kept by NMS, ordered by decreasing score.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(nms)
    # The work is done by the compiled torchvision operator; fail early if it is missing.
    _assert_has_ops()
    kept_indices = torch.ops.torchvision.nms(boxes, scores, iou_threshold)
    return kept_indices
|
|
|
|
|
def batched_nms(
    boxes: Tensor,
    scores: Tensor,
    idxs: Tensor,
    iou_threshold: float,
) -> Tensor:
    """
    Apply non-maximum suppression independently within each category.

    Each value in ``idxs`` identifies a category, and boxes belonging to different
    categories are never allowed to suppress one another.

    Args:
        boxes (Tensor[N, 4]): boxes to run NMS on, in ``(x1, y1, x2, y2)`` format with
            ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
        scores (Tensor[N]): one score per box.
        idxs (Tensor[N]): category index of every box.
        iou_threshold (float): boxes overlapping a kept box with IoU > iou_threshold are removed.

    Returns:
        Tensor: int64 indices of the boxes kept by NMS, ordered by decreasing score.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(batched_nms)
    # Above a device-dependent element count the per-category loop is used; otherwise
    # (and always while tracing) a single NMS call on offset boxes is used.
    # NOTE(review): thresholds are presumably benchmark-derived; the tracing exclusion
    # presumably avoids unrolling the data-dependent loop into the traced graph.
    element_limit = 4000 if boxes.device.type == "cpu" else 100_000
    use_vanilla = boxes.numel() > element_limit and not torchvision._is_tracing()
    if not use_vanilla:
        return _batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)
    return _batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
|
|
|
|
|
@torch.jit._script_if_tracing
def _batched_nms_coordinate_trick(
    boxes: Tensor,
    scores: Tensor,
    idxs: Tensor,
    iou_threshold: float,
) -> Tensor:
    """
    Batched NMS implemented as a single ``nms`` call on category-separated boxes.

    Every box is translated by an offset proportional to its category index so that
    boxes from different categories cannot overlap, and therefore cannot suppress
    each other. The per-category step is the full coordinate span of ``boxes``
    (``max - min + 1``) rather than ``max + 1``. This keeps categories disjoint even
    when some coordinates are negative, e.g. unclipped predictions.

    Args:
        boxes (Tensor[N, 4]): boxes in ``(x1, y1, x2, y2)`` format.
        scores (Tensor[N]): one score per box.
        idxs (Tensor[N]): category index of every box (assumed integer-valued).
        iou_threshold (float): IoU threshold forwarded to ``nms``.

    Returns:
        Tensor: int64 indices of the kept boxes, as returned by ``nms``.
    """
    if boxes.numel() == 0:
        return torch.empty((0,), dtype=torch.int64, device=boxes.device)
    max_coordinate = boxes.max()
    min_coordinate = boxes.min()
    # Category k then occupies [min + k * step, max + k * step], which ends strictly
    # before category k + 1 begins, so cross-category IoU is always 0.
    step = max_coordinate - min_coordinate + torch.tensor(1).to(boxes)
    offsets = idxs.to(boxes) * step
    boxes_for_nms = boxes + offsets[:, None]
    keep = nms(boxes_for_nms, scores, iou_threshold)
    return keep
|
|
|
|
|
@torch.jit._script_if_tracing
def _batched_nms_vanilla(
    boxes: Tensor,
    scores: Tensor,
    idxs: Tensor,
    iou_threshold: float,
) -> Tensor:
    """
    Batched NMS performed by calling ``nms`` once per distinct value in ``idxs``.

    Returns:
        Tensor: indices of all kept boxes across categories, ordered by decreasing score.
    """
    # One flag per input box: did it survive NMS within its own category?
    survived = torch.zeros_like(scores, dtype=torch.bool)
    for category in torch.unique(idxs):
        members = torch.where(idxs == category)[0]
        kept_local = nms(boxes[members], scores[members], iou_threshold)
        # kept_local indexes into `members`; translate back to positions in `boxes`.
        survived[members[kept_local]] = True
    kept = torch.where(survived)[0]
    by_score = scores[kept].sort(descending=True)[1]
    return kept[by_score]
|
|
|
|
|
def remove_small_boxes(boxes: Tensor, min_size: float) -> Tensor:
    """
    Drop every box in ``boxes`` that has a side shorter than ``min_size``.

    .. note::
        For sanitizing a :class:`~torchvision.tv_tensors.BoundingBoxes` object, consider using
        the transform :func:`~torchvision.transforms.v2.SanitizeBoundingBoxes` instead.

    Args:
        boxes (Tensor[N, 4]): boxes in ``(x1, y1, x2, y2)`` format
            with ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
        min_size (float): smallest allowed side length (inclusive).

    Returns:
        Tensor[K]: indices of the boxes whose width and height are both at least ``min_size``.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(remove_small_boxes)
    widths = boxes[:, 2] - boxes[:, 0]
    heights = boxes[:, 3] - boxes[:, 1]
    large_enough = (widths >= min_size) & (heights >= min_size)
    return torch.where(large_enough)[0]
|
|
|
|
|
def clip_boxes_to_image(boxes: Tensor, size: Tuple[int, int]) -> Tensor:
    """
    Clamp box coordinates so every box lies within an image of the given ``size``.

    .. note::
        For clipping a :class:`~torchvision.tv_tensors.BoundingBoxes` object, consider using
        the transform :func:`~torchvision.transforms.v2.ClampBoundingBoxes` instead.

    Args:
        boxes (Tensor[N, 4]): boxes in ``(x1, y1, x2, y2)`` format
            with ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
        size (Tuple[height, width]): image size.

    Returns:
        Tensor[N, 4]: boxes with x clamped to ``[0, width]`` and y clamped to ``[0, height]``.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(clip_boxes_to_image)
    height, width = size
    # Even positions along the last axis hold x coordinates, odd positions hold y.
    xs = boxes[..., 0::2]
    ys = boxes[..., 1::2]

    if not torchvision._is_tracing():
        xs = xs.clamp(min=0, max=width)
        ys = ys.clamp(min=0, max=height)
    else:
        # NOTE(review): while tracing, bounds are passed as tensors to torch.max/torch.min
        # instead of Python scalars to clamp(); presumably for export compatibility.
        lower = torch.tensor(0, dtype=boxes.dtype, device=boxes.device)
        x_upper = torch.tensor(width, dtype=boxes.dtype, device=boxes.device)
        y_upper = torch.tensor(height, dtype=boxes.dtype, device=boxes.device)
        xs = torch.min(torch.max(xs, lower), x_upper)
        ys = torch.min(torch.max(ys, lower), y_upper)

    # Pair x/y back up along a new trailing axis, then flatten to (x1, y1, x2, y2) order.
    interleaved = torch.stack((xs, ys), dim=boxes.dim())
    return interleaved.reshape(boxes.shape)
|
|
|
|
|
def box_convert(boxes: Tensor, in_fmt: str, out_fmt: str) -> Tensor:
    """
    Convert :class:`torch.Tensor` boxes from the ``in_fmt`` representation to ``out_fmt``.

    .. note::
        For converting a :class:`torch.Tensor` or a :class:`~torchvision.tv_tensors.BoundingBoxes` object
        between different formats,
        consider using :func:`~torchvision.transforms.v2.functional.convert_bounding_box_format` instead.
        Or see the corresponding transform :func:`~torchvision.transforms.v2.ConvertBoundingBoxFormat`.

    Supported ``in_fmt`` and ``out_fmt`` strings are:

    ``'xyxy'``: boxes given by two corners; (x1, y1) is the top-left and (x2, y2) the bottom-right.
    This is the format that torchvision utilities expect.

    ``'xywh'``: boxes given by the top-left corner (x1, y1) plus width w and height h.

    ``'cxcywh'``: boxes given by their center (cx, cy) plus width w and height h.

    Args:
        boxes (Tensor[N, 4]): boxes to convert.
        in_fmt (str): format of ``boxes``; one of ['xyxy', 'xywh', 'cxcywh'].
        out_fmt (str): desired output format; one of ['xyxy', 'xywh', 'cxcywh'].

    Returns:
        Tensor[N, 4]: the boxes expressed in ``out_fmt``. A copy is returned when
        ``in_fmt == out_fmt``.

    Raises:
        ValueError: if either format string is not supported.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(box_convert)
    allowed_fmts = ("xyxy", "xywh", "cxcywh")
    if in_fmt not in allowed_fmts or out_fmt not in allowed_fmts:
        raise ValueError("Unsupported Bounding Box Conversions for given in_fmt and out_fmt")

    if in_fmt == out_fmt:
        return boxes.clone()

    # Every conversion goes through "xyxy": first normalize the input to it...
    if in_fmt == "xywh":
        boxes = _box_xywh_to_xyxy(boxes)
    elif in_fmt == "cxcywh":
        boxes = _box_cxcywh_to_xyxy(boxes)

    # ...then derive the requested output from "xyxy" (nothing to do if out_fmt is "xyxy").
    if out_fmt == "xywh":
        boxes = _box_xyxy_to_xywh(boxes)
    elif out_fmt == "cxcywh":
        boxes = _box_xyxy_to_cxcywh(boxes)
    return boxes
|
|
|
|
|
def box_area(boxes: Tensor) -> Tensor:
    """
    Compute the area of each box, where boxes are given by their ``(x1, y1, x2, y2)`` corners.

    Args:
        boxes (Tensor[N, 4]): boxes whose area is computed, in ``(x1, y1, x2, y2)``
            format with ``0 <= x1 < x2`` and ``0 <= y1 < y2``.

    Returns:
        Tensor[N]: area of every box.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(box_area)
    # NOTE(review): _upcast presumably widens the dtype so the product cannot overflow.
    boxes = _upcast(boxes)
    width = boxes[:, 2] - boxes[:, 0]
    height = boxes[:, 3] - boxes[:, 1]
    return width * height
|
|
|
|
|
|
|
|
|
def _box_inter_union(boxes1: Tensor, boxes2: Tensor) -> Tuple[Tensor, Tensor]:
    """
    Pairwise intersection and union areas between two sets of ``(x1, y1, x2, y2)`` boxes.

    Args:
        boxes1 (Tensor[N, 4]): first set of boxes.
        boxes2 (Tensor[M, 4]): second set of boxes.

    Returns:
        Tuple[Tensor[N, M], Tensor[N, M]]: ``(intersection, union)`` for every pair.
    """
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    # boxes1 is broadcast as [N, 1, 2] against boxes2's [M, 2], giving [N, M, 2].
    top_left = torch.max(boxes1[:, None, :2], boxes2[:, :2])
    bottom_right = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])

    # Non-overlapping pairs produce negative extents; clamp them to zero.
    extent = _upcast(bottom_right - top_left).clamp(min=0)
    intersection = extent[:, :, 0] * extent[:, :, 1]

    union = area1[:, None] + area2 - intersection
    return intersection, union
|
|
|
|
|
def box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
    """
    Compute the pairwise intersection-over-union (Jaccard index) of two box sets.

    Both sets are expected in ``(x1, y1, x2, y2)`` format with
    ``0 <= x1 < x2`` and ``0 <= y1 < y2``.

    Args:
        boxes1 (Tensor[N, 4]): first set of boxes
        boxes2 (Tensor[M, 4]): second set of boxes

    Returns:
        Tensor[N, M]: IoU between every box of ``boxes1`` and every box of ``boxes2``.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(box_iou)
    intersection, union = _box_inter_union(boxes1, boxes2)
    return intersection / union
|
|
|
|
|
|
|
def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
    """
    Compute the pairwise generalized intersection-over-union (GIoU) of two box sets.

    Both sets are expected in ``(x1, y1, x2, y2)`` format with
    ``0 <= x1 < x2`` and ``0 <= y1 < y2``.

    Args:
        boxes1 (Tensor[N, 4]): first set of boxes
        boxes2 (Tensor[M, 4]): second set of boxes

    Returns:
        Tensor[N, M]: GIoU between every box of ``boxes1`` and every box of ``boxes2``.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(generalized_box_iou)

    intersection, union = _box_inter_union(boxes1, boxes2)
    iou = intersection / union

    # Smallest axis-aligned box enclosing each pair, broadcast to [N, M, 2].
    enclose_tl = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    enclose_br = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
    enclose_wh = _upcast(enclose_br - enclose_tl).clamp(min=0)
    enclose_area = enclose_wh[:, :, 0] * enclose_wh[:, :, 1]

    # Penalize the share of the enclosing box not covered by the union.
    return iou - (enclose_area - union) / enclose_area
|
|
|
|
|
def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tensor:
    """
    Compute the pairwise complete intersection-over-union (CIoU) of two box sets.

    Both sets are expected in ``(x1, y1, x2, y2)`` format with
    ``0 <= x1 < x2`` and ``0 <= y1 < y2``.

    Args:
        boxes1 (Tensor[N, 4]): first set of boxes
        boxes2 (Tensor[M, 4]): second set of boxes
        eps (float, optional): small number to prevent division by zero. Default: 1e-7

    Returns:
        Tensor[N, M]: CIoU between every box of ``boxes1`` and every box of ``boxes2``.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(complete_box_iou)

    boxes1 = _upcast(boxes1)
    boxes2 = _upcast(boxes2)

    diou, iou = _box_diou_iou(boxes1, boxes2, eps)

    # boxes1 sizes keep a trailing axis ([N, 1]) so they broadcast against boxes2 ([M]).
    w1 = boxes1[:, None, 2] - boxes1[:, None, 0]
    h1 = boxes1[:, None, 3] - boxes1[:, None, 1]
    w2 = boxes2[:, 2] - boxes2[:, 0]
    h2 = boxes2[:, 3] - boxes2[:, 1]

    # Aspect-ratio consistency term.
    aspect_gap = torch.atan(w1 / h1) - torch.atan(w2 / h2)
    v = (4 / (torch.pi**2)) * torch.pow(aspect_gap, 2)
    with torch.no_grad():
        # Weight of the aspect term; computed without tracking gradients.
        alpha = v / (1 - iou + v + eps)
    return diou - alpha * v
|
|
|
|
|
def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tensor:
    """
    Compute the pairwise distance intersection-over-union (DIoU) of two box sets.

    Both sets are expected in ``(x1, y1, x2, y2)`` format with
    ``0 <= x1 < x2`` and ``0 <= y1 < y2``.

    Args:
        boxes1 (Tensor[N, 4]): first set of boxes
        boxes2 (Tensor[M, 4]): second set of boxes
        eps (float, optional): small number to prevent division by zero. Default: 1e-7

    Returns:
        Tensor[N, M]: DIoU between every box of ``boxes1`` and every box of ``boxes2``.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(distance_box_iou)

    boxes1 = _upcast(boxes1)
    boxes2 = _upcast(boxes2)
    # The helper returns (diou, iou); only the DIoU is needed here.
    return _box_diou_iou(boxes1, boxes2, eps=eps)[0]
|
|
|
|
|
def _box_diou_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tuple[Tensor, Tensor]:
    """
    Shared helper computing distance IoU together with plain IoU.

    Args:
        boxes1 (Tensor[N, 4]): first set of ``(x1, y1, x2, y2)`` boxes.
        boxes2 (Tensor[M, 4]): second set of ``(x1, y1, x2, y2)`` boxes.
        eps (float): added to the squared enclosing diagonal to avoid division by zero.

    Returns:
        Tuple[Tensor[N, M], Tensor[N, M]]: ``(diou, iou)``.
    """
    iou = box_iou(boxes1, boxes2)

    # Squared diagonal of the smallest box enclosing each pair.
    enclose_tl = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    enclose_br = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
    enclose_wh = _upcast(enclose_br - enclose_tl).clamp(min=0)
    diag_sq = (enclose_wh[:, :, 0] ** 2) + (enclose_wh[:, :, 1] ** 2) + eps

    # Box centers.
    cx1 = (boxes1[:, 0] + boxes1[:, 2]) / 2
    cy1 = (boxes1[:, 1] + boxes1[:, 3]) / 2
    cx2 = (boxes2[:, 0] + boxes2[:, 2]) / 2
    cy2 = (boxes2[:, 1] + boxes2[:, 3]) / 2

    # Squared distance between the centers of every pair, shape [N, M].
    dx = _upcast(cx1[:, None] - cx2[None, :])
    dy = _upcast(cy1[:, None] - cy2[None, :])
    center_dist_sq = (dx**2) + (dy**2)

    return iou - (center_dist_sq / diag_sq), iou
|
|
|
|
|
def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
    """
    Compute the bounding boxes around the provided masks.

    Returns a [N, 4] tensor containing bounding boxes. The boxes are in ``(x1, y1, x2, y2)`` format with
    ``0 <= x1 <= x2`` and ``0 <= y1 <= y2``.

    .. warning::

        In most cases the output will guarantee ``x1 < x2`` and ``y1 < y2``. But
        if the input is degenerate, e.g. if a mask is a single row or a single
        column, then the output may have x1 = x2 or y1 = y2.

    A mask with no non-zero element gets the all-zero box ``(0, 0, 0, 0)`` instead
    of raising an error.

    Args:
        masks (Tensor[N, H, W]): masks to transform where N is the number of masks
            and (H, W) are the spatial dimensions.

    Returns:
        Tensor[N, 4]: bounding boxes, one row per mask.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(masks_to_boxes)

    n = masks.shape[0]
    # Always one row per mask; rows of empty masks stay zero.
    bounding_boxes = torch.zeros((n, 4), device=masks.device, dtype=torch.float)

    for index, mask in enumerate(masks):
        y, x = torch.where(mask != 0)
        if x.numel() == 0:
            # torch.min/torch.max raise on empty tensors; keep the zero box instead.
            continue

        bounding_boxes[index, 0] = torch.min(x)
        bounding_boxes[index, 1] = torch.min(y)
        bounding_boxes[index, 2] = torch.max(x)
        bounding_boxes[index, 3] = torch.max(y)

    return bounding_boxes
|
|