gabrielsemiceki9's picture
Upload 125 files
b72e09b verified
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import torch
import torch.nn as nn
import torch.nn.functional as F
class GANLoss(nn.Module):
    r"""N+1-class GAN loss.

    Each discriminator output is a dict holding a ``'pred'`` tensor and a
    ``'label'`` tensor, where ``pred`` has exactly one more channel than
    ``label``. The first N prediction channels score the N label classes and
    the last channel scores an extra "fake" class. Real samples are scored
    with a per-pixel cross-entropy against ``label``. Fake samples (only
    during a discriminator update) are scored by the log-probability of the
    last channel.
    """

    def __init__(self, target_real_label=1.0, target_fake_label=0.0):
        r"""GAN loss constructor.

        Args:
            target_real_label (float): Desired output label for the real images.
            target_fake_label (float): Desired output label for the fake images.
                Both values are only stored as attributes; the N+1 loss
                computed by this class does not read them.
        """
        super().__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_tensor = None
        self.fake_label_tensor = None

    def forward(self, input_x, t_real, weight=None,
                reduce_dim=True, dis_update=True):
        r"""GAN loss computation.

        Args:
            input_x (dict or list): A single discriminator output (a dict with
                ``'pred'`` and ``'label'`` tensors) or a list of them, e.g. one
                per discriminator resolution. A list element that is itself a
                list contributes only its last entry.
            t_real (boolean): Is this output value for real images.
            weight (float or tensor, optional): Multiplier applied to the
                per-pixel loss before averaging.
            reduce_dim (boolean): If True, average each output to a scalar.
                If False, average each output to one value per batch sample.
                This makes a difference when we use multi-resolution
                discriminators.
            dis_update (boolean): Updating the discriminator (True) or the
                generator (False).

        Returns:
            loss (tensor): For a dict input, a scalar (``reduce_dim=True``) or
            a ``(batch,)`` tensor (``reduce_dim=False``). For a list input,
            the mean over list entries, with shape ``(1,)`` or ``(batch,)``
            respectively.
        """
        if not isinstance(input_x, list):
            return self.loss(input_x, t_real, weight, reduce_dim, dis_update)

        total = 0
        for pred_i in input_x:
            if isinstance(pred_i, list):
                # Nested lists (e.g. intermediate features): use the final output.
                pred_i = pred_i[-1]
            loss_tensor = self.loss(pred_i, t_real, weight,
                                    reduce_dim, dis_update)
            # A scalar loss is treated as a batch of one so that every scale
            # contributes a tensor of matching shape to the running sum.
            bs = 1 if loss_tensor.dim() == 0 else loss_tensor.size(0)
            total += torch.mean(loss_tensor.reshape(bs, -1), dim=1)
        return total / len(input_x)

    def loss(self, input_x, t_real, weight=None,
             reduce_dim=True, dis_update=True):
        r"""N+1 label GAN loss computation for one discriminator output.

        Args:
            input_x (dict): Holds ``'pred'`` (indexed as a 4-D tensor with
                N+1 channels) and ``'label'`` (N channels). Neither tensor is
                modified; both are cloned first.
            t_real (boolean): Is this output value for real images.
            weight (float or tensor, optional): Multiplier applied to the
                per-pixel loss before averaging.
            reduce_dim (boolean): If True, return the mean over everything.
                If False, return the mean per batch sample.
            dis_update (boolean): Updating the discriminator or the generator.
                A generator update must use ``t_real=True``.

        Returns:
            loss (tensor): A scalar if ``reduce_dim`` is True, otherwise a
            ``(batch,)`` tensor.
        """
        pred = input_x['pred'].clone()
        label = input_x['label'].clone()
        # The last prediction channel is the extra "fake" class.
        assert pred.size(1) == (label.size(1) + 1)
        batch_size = pred.size(0)

        # Class 0 is excluded from the target: its label channel is zeroed, so
        # it carries no target mass. Its logit is reset to 0 rather than
        # removed, so it still adds exp(0) to the softmax normalizer.
        label[:, 0, ...] = 0
        pred[:, 0, ...] = 0
        pred = F.log_softmax(pred, dim=1)

        if not dis_update:
            assert t_real, "GAN loss must be aiming for real."

        if dis_update and not t_real:
            # Fake samples: maximize the log-probability of the N+1-th class.
            loss = -pred[:, -1, None, :, :]
        else:
            # Real samples (and every generator update): cross-entropy of the
            # first N classes against the label map, summed over channels.
            loss = torch.sum(-label * pred[:, :-1, :, :], dim=1, keepdim=True)

        if weight is not None:
            loss = loss * weight
        if reduce_dim:
            return torch.mean(loss)
        # reshape (not view) so a non-contiguous loss tensor is handled safely.
        return loss.reshape(batch_size, -1).mean(dim=1)