from typing import Optional

import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss, MSELoss

from .loss_deformable_detr import DeformableDetrForObjectDetectionLoss, DeformableDetrForSegmentationLoss
from .loss_for_object_detection import ForObjectDetectionLoss, ForSegmentationLoss
from .loss_rt_detr import RTDetrForObjectDetectionLoss

def fixed_cross_entropy(source, target, num_items_in_batch: Optional[int] = None, ignore_index: int = -100, **kwargs):
    # Sum per-element losses when the caller supplies the global item count
    # (e.g. across gradient-accumulation steps), then renormalize by that count;
    # otherwise fall back to a plain mean reduction.
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
    if reduction == "sum":
        loss = loss / num_items_in_batch
    return loss
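
# Illustrative sketch (not part of the module; shapes are assumptions): with
# `num_items_in_batch` set, the loss is summed and renormalized by the true
# token count, keeping gradient accumulation consistent across uneven
# micro-batches, e.g.:
#
#   logits = torch.randn(8, 100)            # (num_tokens, vocab_size)
#   labels = torch.randint(0, 100, (8,))
#   loss = fixed_cross_entropy(logits, labels, num_items_in_batch=32)
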

def ForCausalLMLoss(
    logits, labels, vocab_size: int, num_items_in_batch: Optional[int] = None, ignore_index: int = -100, **kwargs
):
    # Upcast to float so the loss is computed in full precision.
    logits = logits.float()

    # Shift so that tokens < n predict n.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    # Flatten the tokens.
    shift_logits = shift_logits.view(-1, vocab_size)
    shift_labels = shift_labels.view(-1)

    # Move labels to the logits device (supports model parallelism).
    shift_labels = shift_labels.to(shift_logits.device)
    loss = fixed_cross_entropy(shift_logits, shift_labels, num_items_in_batch, ignore_index, **kwargs)
    return loss
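
# Illustrative sketch (toy sizes are assumptions): the shift happens inside the
# function, so callers pass unshifted labels aligned with the input tokens, e.g.:
#
#   batch, seq_len, vocab = 2, 16, 100
#   logits = torch.randn(batch, seq_len, vocab)
#   labels = torch.randint(0, vocab, (batch, seq_len))
#   loss = ForCausalLMLoss(logits, labels, vocab_size=vocab)
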

def ForSequenceClassificationLoss(labels, pooled_logits, config, **kwargs):
    num_labels = config.num_labels

    # Infer the problem type once from num_labels and the label dtype, then
    # cache it on the config.
    if config.problem_type is None:
        if num_labels == 1:
            config.problem_type = "regression"
        elif num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
            config.problem_type = "single_label_classification"
        else:
            config.problem_type = "multi_label_classification"

    if config.problem_type == "regression":
        loss_fct = MSELoss()
        if num_labels == 1:
            loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
        else:
            loss = loss_fct(pooled_logits, labels)
    elif config.problem_type == "single_label_classification":
        loss = fixed_cross_entropy(pooled_logits.view(-1, num_labels), labels.view(-1), **kwargs)
    elif config.problem_type == "multi_label_classification":
        loss_fct = BCEWithLogitsLoss()
        loss = loss_fct(pooled_logits, labels)
    return loss
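
# Illustrative sketch (the bare config object is an assumption): integer labels
# with num_labels > 1 route to single-label cross-entropy, and the inferred
# problem type is written back onto the config, e.g.:
#
#   from types import SimpleNamespace
#   config = SimpleNamespace(num_labels=3, problem_type=None)
#   loss = ForSequenceClassificationLoss(
#       torch.randint(0, 3, (4,)), torch.randn(4, 3), config
#   )
#   assert config.problem_type == "single_label_classification"
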

def ForQuestionAnsweringLoss(start_logits, end_logits, start_positions, end_positions, **kwargs):
    total_loss = None
    if start_positions is not None and end_positions is not None:
        # On multi-GPU, split adds an extra dimension; squeeze it away.
        if len(start_positions.size()) > 1:
            start_positions = start_positions.squeeze(-1).to(start_logits.device)
        if len(end_positions.size()) > 1:
            end_positions = end_positions.squeeze(-1).to(end_logits.device)

        # Positions outside the model inputs are clamped to the ignored index
        # so the cross-entropy skips those terms.
        ignored_index = start_logits.size(1)
        start_positions = start_positions.clamp(0, ignored_index)
        end_positions = end_positions.clamp(0, ignored_index)

        start_loss = fixed_cross_entropy(start_logits, start_positions, ignore_index=ignored_index, **kwargs)
        end_loss = fixed_cross_entropy(end_logits, end_positions, ignore_index=ignored_index, **kwargs)
        total_loss = (start_loss + end_loss) / 2
    return total_loss
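
# Illustrative sketch (shapes are assumptions): positions past the sequence
# length are clamped to seq_len and then ignored by the cross-entropy, e.g.:
#
#   start_logits = torch.randn(4, 128)      # (batch, seq_len)
#   end_logits = torch.randn(4, 128)
#   start_positions = torch.randint(0, 128, (4,))
#   end_positions = torch.randint(0, 128, (4,))
#   loss = ForQuestionAnsweringLoss(start_logits, end_logits, start_positions, end_positions)
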

def ForTokenClassification(logits, labels, config, **kwargs):
    # Flatten predictions and labels, then upcast to float so the loss is
    # computed in full precision.
    logits = logits.view(-1, config.num_labels)
    labels = labels.view(-1)
    logits = logits.float()
    return fixed_cross_entropy(logits, labels, **kwargs)
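
# Illustrative sketch (the bare config object is an assumption): per-token
# logits are flattened, so padding can be masked with ignore_index labels, e.g.:
#
#   from types import SimpleNamespace
#   config = SimpleNamespace(num_labels=9)
#   logits = torch.randn(2, 16, 9)          # (batch, seq_len, num_labels)
#   labels = torch.randint(0, 9, (2, 16))
#   loss = ForTokenClassification(logits, labels, config)
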

LOSS_MAPPING = {
    "ForCausalLM": ForCausalLMLoss,
    "ForQuestionAnswering": ForQuestionAnsweringLoss,
    "ForSequenceClassification": ForSequenceClassificationLoss,
    "ForTokenClassification": ForTokenClassification,
    "ForSegmentation": ForSegmentationLoss,
    "ForObjectDetection": ForObjectDetectionLoss,
    "DeformableDetrForObjectDetection": DeformableDetrForObjectDetectionLoss,
    "ConditionalDetrForObjectDetection": DeformableDetrForObjectDetectionLoss,
    "GroundingDinoForObjectDetection": DeformableDetrForObjectDetectionLoss,
    "ConditionalDetrForSegmentation": DeformableDetrForSegmentationLoss,
    "RTDetrForObjectDetection": RTDetrForObjectDetectionLoss,
}
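
# Illustrative sketch (caller behavior is an assumption, not defined in this
# module): a model head can resolve its loss function by key lookup, e.g.:
#
#   loss_fn = LOSS_MAPPING["ForCausalLM"]
#   loss = loss_fn(logits, labels, vocab_size=config.vocab_size,
#                  num_items_in_batch=num_items_in_batch)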