from typing import List, Optional, Tuple, Union

import torch
from torch import nn, Tensor
|
|
def _cat(tensors: List[Tensor], dim: int = 0) -> Tensor:
    """
    Efficient version of torch.cat that avoids a copy if there is only a single element in a list.
    """
    if len(tensors) == 1:
        return tensors[0]
    return torch.cat(tensors, dim)
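
# A minimal illustration (assumed usage, not part of the public API): with a
# single-element list, _cat returns the tensor itself instead of copying it.
#
#   a = torch.rand(2, 4)
#   _cat([a]) is a          # True: no copy for a single element
#   _cat([a, a]).shape      # torch.Size([4, 4])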
|
|
def convert_boxes_to_roi_format(boxes: List[Tensor]) -> Tensor:
    # Concatenate the per-image boxes and prepend the image index to each row,
    # producing the Tensor[K, 5] layout expected by the ROI ops.
    concat_boxes = _cat([b for b in boxes], dim=0)
    temp = []
    for i, b in enumerate(boxes):
        temp.append(torch.full_like(b[:, :1], i))
    ids = _cat(temp, dim=0)
    rois = torch.cat([ids, concat_boxes], dim=1)
    return rois
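
# Illustrative sketch (the shapes below are assumptions for the example):
# per-image Tensor[L, 4] boxes become a single Tensor[K, 5] where column 0
# holds the image index within the batch.
#
#   boxes = [torch.rand(3, 4), torch.rand(2, 4)]   # two images: 3 and 2 boxes
#   rois = convert_boxes_to_roi_format(boxes)      # shape: torch.Size([5, 5])
#   rois[:, 0]                                     # tensor([0., 0., 0., 1., 1.])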
|
|
def check_roi_boxes_shape(boxes: Union[Tensor, List[Tensor]]):
    if isinstance(boxes, (list, tuple)):
        for _tensor in boxes:
            torch._assert(
                _tensor.size(1) == 4, "The shape of the tensor in the boxes list is not correct as List[Tensor[L, 4]]"
            )
    elif isinstance(boxes, torch.Tensor):
        torch._assert(boxes.size(1) == 5, "The boxes tensor shape is not correct as Tensor[K, 5]")
    else:
        torch._assert(False, "boxes is expected to be a Tensor[K, 5] or a List[Tensor[L, 4]]")
    return
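
# Both accepted layouts, sketched under assumed shapes:
#
#   check_roi_boxes_shape([torch.rand(3, 4), torch.rand(2, 4)])  # List[Tensor[L, 4]]: passes
#   check_roi_boxes_shape(torch.rand(5, 5))                      # Tensor[K, 5]: passes
#   check_roi_boxes_shape(torch.rand(5, 4))                      # fails: a plain tensor needs 5 columns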
|
|
def split_normalization_params(
    model: nn.Module, norm_classes: Optional[List[type]] = None
) -> Tuple[List[Tensor], List[Tensor]]:
    # Split a model's trainable parameters into normalization-layer parameters
    # and everything else, e.g. so that weight decay can be disabled on the
    # normalization parameters only.
    if not norm_classes:
        norm_classes = [
            nn.modules.batchnorm._BatchNorm,
            nn.LayerNorm,
            nn.GroupNorm,
            nn.modules.instancenorm._InstanceNorm,
            nn.LocalResponseNorm,
        ]

    for t in norm_classes:
        if not issubclass(t, nn.Module):
            raise ValueError(f"Class {t} is not a subclass of nn.Module.")

    classes = tuple(norm_classes)

    norm_params = []
    other_params = []
    for module in model.modules():
        if next(module.children(), None):
            # Container modules: take only the parameters they own directly;
            # their children are visited separately by model.modules().
            other_params.extend(p for p in module.parameters(recurse=False) if p.requires_grad)
        elif isinstance(module, classes):
            norm_params.extend(p for p in module.parameters() if p.requires_grad)
        else:
            other_params.extend(p for p in module.parameters() if p.requires_grad)
    return norm_params, other_params
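
# A typical (assumed) use: build optimizer parameter groups so that weight
# decay applies to everything except the normalization parameters. The model
# choice and hyper-parameter values below are placeholders for illustration.
#
#   import torchvision
#   model = torchvision.models.resnet18()
#   norm_params, other_params = split_normalization_params(model)
#   param_groups = [
#       {"params": other_params, "weight_decay": 1e-4},
#       {"params": norm_params, "weight_decay": 0.0},
#   ]
#   optimizer = torch.optim.SGD(param_groups, lr=0.1, momentum=0.9)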
|
|
def _upcast(t: Tensor) -> Tensor:
    # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
    if t.is_floating_point():
        return t if t.dtype in (torch.float32, torch.float64) else t.float()
    else:
        return t if t.dtype in (torch.int32, torch.int64) else t.int()
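
# Why the upcast matters, as a sketch (values are illustrative): squaring near
# the int16 limit wraps around, while the upcast result does not.
#
#   t = torch.tensor([300], dtype=torch.int16)
#   t * t                      # wraps: 90000 exceeds the int16 max of 32767
#   _upcast(t) * _upcast(t)    # int32: tensor([90000], dtype=torch.int32)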
|
|
def _upcast_non_float(t: Tensor) -> Tensor:
    # Promotes anything that is not already float32/float64 (including float16
    # and integer tensors) to float32, protecting multiplications from overflow
    if t.dtype not in (torch.float32, torch.float64):
        return t.float()
    return t
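
# Sketch: integer coordinates are promoted to float32 before area
# computations, so products that would exceed the integer range stay finite.
#
#   b = torch.tensor([46341], dtype=torch.int32)
#   b * b                                          # wraps: 46341**2 > 2**31 - 1
#   _upcast_non_float(b) * _upcast_non_float(b)    # float32, ~2.147e9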
|
|
def _loss_inter_union(
    boxes1: torch.Tensor,
    boxes2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Compute per-pair intersection and union areas for (x1, y1, x2, y2) boxes.
    x1, y1, x2, y2 = boxes1.unbind(dim=-1)
    x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1)

    # Intersection keypoints
    xkis1 = torch.max(x1, x1g)
    ykis1 = torch.max(y1, y1g)
    xkis2 = torch.min(x2, x2g)
    ykis2 = torch.min(y2, y2g)

    # Intersection area is zero unless the boxes actually overlap
    intsctk = torch.zeros_like(x1)
    mask = (ykis2 > ykis1) & (xkis2 > xkis1)
    intsctk[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask])
    unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsctk

    return intsctk, unionk
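
# Worked example (values chosen for clarity): two 2x2 boxes overlapping in a
# unit square give intersection 1 and union 4 + 4 - 1 = 7, so IoU = 1/7. This
# is the kind of quantity IoU-style box losses are built from.
#
#   b1 = torch.tensor([[0.0, 0.0, 2.0, 2.0]])
#   b2 = torch.tensor([[1.0, 1.0, 3.0, 3.0]])
#   inter, union = _loss_inter_union(b1, b2)   # tensor([1.]), tensor([7.])
#   iou = inter / union                        # tensor([0.1429])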