|
import torch |
|
import torch.nn.functional as F |
|
|
|
from ..utils import _log_api_usage_once |
|
|
|
|
|
def sigmoid_focal_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    alpha: float = 0.25,
    gamma: float = 2,
    reduction: str = "none",
) -> torch.Tensor:
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.

    Args:
        inputs (Tensor): A float tensor of arbitrary shape.
                The predictions for each example, as raw logits (a sigmoid is
                applied internally).
        targets (Tensor): A float tensor with the same shape as inputs. Stores the binary
                classification label for each element in inputs
                (0 for the negative class and 1 for the positive class).
        alpha (float): Weighting factor in range [0, 1] to balance
                positive vs negative examples or -1 for ignore. Default: ``0.25``.
        gamma (float): Exponent of the modulating factor (1 - p_t) to
                balance easy vs hard examples. Default: ``2``.
        reduction (str): ``'none'`` | ``'mean'`` | ``'sum'``
                ``'none'``: No reduction will be applied to the output.
                ``'mean'``: The output will be averaged.
                ``'sum'``: The output will be summed. Default: ``'none'``.

    Returns:
        Loss tensor with the reduction option applied.

    Raises:
        ValueError: If ``alpha`` is neither in ``[0, 1]`` nor ``-1``, or if
            ``reduction`` is not one of ``'none'``, ``'mean'``, ``'sum'``.
    """
    # Validate all arguments before any tensor work so a bad configuration
    # fails fast instead of after computing the full element-wise loss.
    if not (0 <= alpha <= 1) and alpha != -1:
        raise ValueError(f"Invalid alpha value: {alpha}. alpha must be in the range [0,1] or -1 for ignore.")
    # Explicit comparisons (rather than a container membership test) keep this TorchScript-friendly.
    if reduction != "none" and reduction != "mean" and reduction != "sum":
        raise ValueError(
            f"Invalid Value for arg 'reduction': '{reduction}' \n Supported reduction modes: 'none', 'mean', 'sum'"
        )

    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(sigmoid_focal_loss)

    p = torch.sigmoid(inputs)
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    # For 0/1 targets, p_t is the predicted probability of the true class.
    p_t = p * targets + (1 - p) * (1 - targets)
    # Modulating factor: the larger p_t (easier example), the more the loss is down-weighted.
    loss = ce_loss * ((1 - p_t) ** gamma)

    # alpha == -1 disables the positive/negative class-balancing weight.
    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    # reduction was validated above; "none" returns the element-wise loss unchanged.
    if reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()

    return loss
|
|