|
from typing import Tuple |
|
|
|
import torch |
|
|
|
from ..utils import _log_api_usage_once |
|
from ._utils import _loss_inter_union, _upcast_non_float |
|
|
|
|
|
def distance_box_iou_loss( |
|
boxes1: torch.Tensor, |
|
boxes2: torch.Tensor, |
|
reduction: str = "none", |
|
eps: float = 1e-7, |
|
) -> torch.Tensor: |
|
|
|
""" |
|
Gradient-friendly IoU loss with an additional penalty that is non-zero when the |
|
distance between boxes' centers isn't zero. Indeed, for two exactly overlapping |
|
boxes, the distance IoU is the same as the IoU loss. |
|
This loss is symmetric, so the boxes1 and boxes2 arguments are interchangeable. |
|
|
|
Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with |
|
``0 <= x1 < x2`` and ``0 <= y1 < y2``, and The two boxes should have the |
|
same dimensions. |
|
|
|
Args: |
|
boxes1 (Tensor[N, 4]): first set of boxes |
|
boxes2 (Tensor[N, 4]): second set of boxes |
|
reduction (string, optional): Specifies the reduction to apply to the output: |
|
``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: No reduction will be |
|
applied to the output. ``'mean'``: The output will be averaged. |
|
``'sum'``: The output will be summed. Default: ``'none'`` |
|
eps (float, optional): small number to prevent division by zero. Default: 1e-7 |
|
|
|
Returns: |
|
Tensor: Loss tensor with the reduction option applied. |
|
|
|
Reference: |
|
Zhaohui Zheng et al.: Distance Intersection over Union Loss: |
|
https://arxiv.org/abs/1911.08287 |
|
""" |
|
|
|
|
|
|
|
if not torch.jit.is_scripting() and not torch.jit.is_tracing(): |
|
_log_api_usage_once(distance_box_iou_loss) |
|
|
|
boxes1 = _upcast_non_float(boxes1) |
|
boxes2 = _upcast_non_float(boxes2) |
|
|
|
loss, _ = _diou_iou_loss(boxes1, boxes2, eps) |
|
|
|
|
|
if reduction == "none": |
|
pass |
|
elif reduction == "mean": |
|
loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum() |
|
elif reduction == "sum": |
|
loss = loss.sum() |
|
else: |
|
raise ValueError( |
|
f"Invalid Value for arg 'reduction': '{reduction} \n Supported reduction modes: 'none', 'mean', 'sum'" |
|
) |
|
return loss |
|
|
|
|
|
def _diou_iou_loss( |
|
boxes1: torch.Tensor, |
|
boxes2: torch.Tensor, |
|
eps: float = 1e-7, |
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
|
intsct, union = _loss_inter_union(boxes1, boxes2) |
|
iou = intsct / (union + eps) |
|
|
|
x1, y1, x2, y2 = boxes1.unbind(dim=-1) |
|
x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1) |
|
xc1 = torch.min(x1, x1g) |
|
yc1 = torch.min(y1, y1g) |
|
xc2 = torch.max(x2, x2g) |
|
yc2 = torch.max(y2, y2g) |
|
|
|
diagonal_distance_squared = ((xc2 - xc1) ** 2) + ((yc2 - yc1) ** 2) + eps |
|
|
|
x_p = (x2 + x1) / 2 |
|
y_p = (y2 + y1) / 2 |
|
x_g = (x1g + x2g) / 2 |
|
y_g = (y1g + y2g) / 2 |
|
|
|
centers_distance_squared = ((x_p - x_g) ** 2) + ((y_p - y_g) ** 2) |
|
|
|
|
|
loss = 1 - iou + (centers_distance_squared / diagonal_distance_squared) |
|
return loss, iou |
|
|