lino / src /utils /loss.py
algohunt
initial_commit
c295391
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from math import exp
class MultiScaleDerivativeLoss(nn.Module):
    """Multi-scale loss on image derivatives.

    Both inputs are turned into an image pyramid. At every level a fixed 3x3
    derivative filter is applied per channel: Scharr x/y first derivatives, or
    a Laplacian. The L1 or squared difference between the filtered prediction
    and the filtered target is then reduced. The per-level losses are averaged
    over the number of levels.
    """

    def __init__(self, operator='scharr', p=1, reduction='mean', normalize_input=False, num_scales=4):
        """
        Args:
            operator: 'scharr' (first-order) or 'laplace' (second-order).
            p: 1 for an L1 penalty, 2 for a squared (L2) penalty.
            reduction: 'mean' or 'sum' over the elements of each level.
            normalize_input: if True, L2-normalize along the channel dimension
                at every level before filtering (e.g. for normal maps).
            num_scales: number of pyramid levels, e.g. 4 -> full, 1/2, 1/4, 1/8.
        """
        super().__init__()
        assert operator in ['scharr', 'laplace'], f"unknown operator {operator!r}"
        assert p in [1, 2], f"p must be 1 or 2, got {p!r}"
        assert reduction in ['mean', 'sum'], f"unknown reduction {reduction!r}"
        assert num_scales >= 1, f"num_scales must be >= 1, got {num_scales!r}"
        self.operator = operator
        self.p = p
        self.reduction = reduction
        self.normalize_input = normalize_input
        self.num_scales = num_scales

    def forward(self, pred, gt):
        """Return the derivative loss averaged over all pyramid levels.

        Args:
            pred, gt: tensors of shape [B, C, H, W].
        Returns:
            Scalar tensor.
        """
        total_loss = 0.0
        for pred_i, gt_i in zip(self._build_pyramid(pred), self._build_pyramid(gt)):
            if self.normalize_input:
                pred_i = F.normalize(pred_i, dim=1)
                gt_i = F.normalize(gt_i, dim=1)
            diff = self._compute_gradient(pred_i) - self._compute_gradient(gt_i)
            diff = torch.abs(diff) if self.p == 1 else diff ** 2
            total_loss += diff.mean() if self.reduction == 'mean' else diff.sum()
        return total_loss / self.num_scales

    def _build_pyramid(self, img):
        """Return ``num_scales`` levels: img, then img downsampled by 1/2, 1/4, ...

        Each level is resampled directly from the full-resolution input with
        factor 0.5**i. If each level were instead derived from the previous one
        using 0.5**i, the factors would compound (1, 1/2, 1/8, 1/64, ...).
        """
        pyramid = [img]
        for i in range(1, self.num_scales):
            level = F.interpolate(
                img,
                scale_factor=0.5 ** i,
                mode='bicubic',
                align_corners=False,
                recompute_scale_factor=True,
                antialias=True,
            )
            pyramid.append(level)
        return pyramid

    def _compute_gradient(self, img):
        """Apply the configured derivative filter channel-wise (depthwise conv).

        Returns [B, 2C, H, W] (x- then y-derivatives) for 'scharr', and
        [B, C, H, W] for 'laplace'. Zero padding keeps the spatial size.
        """
        C = img.shape[1]
        # Build kernels in the input's dtype/device so conv2d works for
        # float16/float64 inputs as well as float32.
        opts = dict(device=img.device, dtype=img.dtype)
        if self.operator == 'scharr':
            kernel_x = torch.tensor([[[-3., 0., 3.],
                                      [-10., 0., 10.],
                                      [-3., 0., 3.]]], **opts) / 16.0
            kernel_y = torch.tensor([[[-3., -10., -3.],
                                      [0., 0., 0.],
                                      [3., 10., 3.]]], **opts) / 16.0
            kernel_x = kernel_x.unsqueeze(0).expand(C, 1, 3, 3)
            kernel_y = kernel_y.unsqueeze(0).expand(C, 1, 3, 3)
            grad_x = F.conv2d(img, kernel_x, padding=1, groups=C)
            grad_y = F.conv2d(img, kernel_y, padding=1, groups=C)
            return torch.cat([grad_x, grad_y], dim=1)  # [B, 2C, H, W]
        elif self.operator == 'laplace':
            kernel = torch.tensor([[[0., 1., 0.],
                                    [1., -4., 1.],
                                    [0., 1., 0.]]], **opts)
            kernel = kernel.unsqueeze(0).expand(C, 1, 3, 3)
            return F.conv2d(img, kernel, padding=1, groups=C)  # [B, C, H, W]
class CosineLoss(torch.nn.Module):
    """Dot-product loss between ground-truth and predicted vector maps.

    ``forward`` returns a pair ``(cos_loss, mse)``:

    * ``cos_loss``: mean of ``1 - <N, N_hat>`` over pixels where the
      ground-truth vector is non-zero. The dot product is not normalized, so it
      is a cosine distance only when both inputs are unit vectors.
    * ``mse``: mean squared error over all elements, multiplied by H * W / 2048.
    """

    def __init__(self):
        super().__init__()

    def forward(self, N, N_hat):
        """
        Args:
            N: ground-truth vectors, shape (B, C, H, W).
            N_hat: predicted vectors, same shape as N.
        Returns:
            (cos_loss, mse) scalar tensors. ``cos_loss`` is 0 (not NaN) when
            N contains no non-zero pixel.
        """
        _, _, H, W = N.shape
        # Valid pixels: ground-truth vector has a non-zero L2 norm over channels.
        valid = N.norm(p=2, dim=1, keepdim=True) > 0  # (B, 1, H, W)
        # NOTE(review): the H * W / 2048 rescaling is kept unchanged; its intent
        # is not visible from this file.
        mse = F.mse_loss(N, N_hat, reduction='mean') * H * W / 2048
        dot_product = torch.sum(N * N_hat, dim=1, keepdim=True)  # (B, 1, H, W)
        selected = (1 - dot_product)[valid]
        if selected.numel() == 0:
            # .mean() of an empty tensor is NaN; sum() gives a differentiable 0.
            return selected.sum(), mse
        return selected.mean(), mse
def gaussian(window_size, sigma):
    """Return a normalized 1-D Gaussian of length ``window_size``.

    The peak sits at index ``window_size // 2``, so even sizes produce a
    slightly asymmetric kernel. The weights sum to 1.
    """
    center = window_size // 2
    denom = float(2 * sigma ** 2)
    weights = torch.Tensor([exp(-(offset - center) ** 2 / denom) for offset in range(window_size)])
    return weights / weights.sum()
def create_window(window_size, channel):
    """Build a depthwise 2-D Gaussian kernel (sigma = 1.5).

    Returns a contiguous float tensor of shape
    (channel, 1, window_size, window_size). It is the outer product of a 1-D
    Gaussian with itself, replicated per channel for use with
    ``groups=channel``.
    """
    column = gaussian(window_size, 1.5).unsqueeze(1)  # (window_size, 1)
    kernel_2d = (column @ column.t()).float()[None, None]  # (1, 1, ws, ws)
    return kernel_2d.expand(channel, 1, window_size, window_size).contiguous()
def _ssim(img1, img2, window, window_size, channel, size_average = True, stride=None):
mu1 = F.conv2d(img1, window, padding = (window_size-1)//2, groups = channel, stride=stride)
mu2 = F.conv2d(img2, window, padding = (window_size-1)//2, groups = channel, stride=stride)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1*mu2
sigma1_sq = F.conv2d(img1*img1, window, padding = (window_size-1)//2, groups = channel, stride=stride) - mu1_sq
sigma2_sq = F.conv2d(img2*img2, window, padding = (window_size-1)//2, groups = channel, stride=stride) - mu2_sq
sigma12 = F.conv2d(img1*img2, window, padding = (window_size-1)//2, groups = channel, stride=stride) - mu1_mu2
C1 = 0.01**2
C2 = 0.03**2
ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
if size_average:
return ssim_map.mean()
else:
return ssim_map.mean(1).mean(1).mean(1)
class SSIM(torch.nn.Module):
    """SSIM as a module, caching its Gaussian window between calls.

    Args:
        window_size: side length of the Gaussian window (sigma = 1.5).
        size_average: if True return a scalar, else a per-sample tensor (B,).
        stride: stride of the windowed convolutions. Windows do not overlap
            when stride >= window_size.
    """

    def __init__(self, window_size = 3, size_average = True, stride=3):
        super().__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.stride = stride
        # Plain attribute (not a registered buffer), so forward() moves/casts
        # it to match each input.
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        """
        Args:
            img1, img2: tensors of shape (B, C, H, W).
        Returns:
            SSIM as computed by ``_ssim`` with this module's settings.
        """
        (_, channel, _, _) = img1.size()
        window = self.window
        if channel != self.channel:
            window = create_window(self.window_size, channel)
        # Compare the actual device and dtype instead of the legacy type string.
        # 'torch.cuda.FloatTensor' is identical for every GPU, so the string
        # check reused a window left on another device. .to() also covers
        # non-CUDA accelerators.
        if window.device != img1.device or window.dtype != img1.dtype:
            window = window.to(device=img1.device, dtype=img1.dtype)
        self.window = window
        self.channel = channel
        return _ssim(img1, img2, window, self.window_size, channel, self.size_average, stride=self.stride)
def ssim(img1, img2, window_size = 11, size_average = True):
    """Functional SSIM with a freshly built Gaussian window (dense, stride 1).

    Args:
        img1, img2: tensors of shape (B, C, H, W).
        window_size: side length of the Gaussian window.
        size_average: if True return a scalar, else a per-sample tensor (B,).
    """
    (_, channel, _, _) = img1.size()
    window = create_window(window_size, channel).to(device=img1.device, dtype=img1.dtype)
    # Pass the stride explicitly. Otherwise _ssim's None default would be
    # forwarded to F.conv2d, which expects an int/tuple stride.
    return _ssim(img1, img2, window, window_size, channel, size_average, stride=1)
class S3IM(torch.nn.Module):
    """Stochastic structural similarity (S3IM) loss over batches of pixel vectors.

    For each sample, the N pixel vectors are gathered ``repeat_time`` times: the
    first copy in original order, the others randomly permuted. The result is
    folded into a pseudo image of size (patch_height, patch_width * repeat_time)
    and compared with ``SSIM``. The loss is ``1 - SSIM``.

    N must equal ``patch_height * patch_width`` for the reshape to succeed.
    """

    def __init__(self, kernel_size=4, stride=4, repeat_time=10, patch_height=64, patch_width=32):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.repeat_time = repeat_time
        self.patch_height = patch_height
        self.patch_width = patch_width
        self.ssim_loss = SSIM(window_size=self.kernel_size, stride=self.stride)

    def _virtual_patch(self, vec, index, channels):
        """Gather ``vec[index]`` and fold it into a (1, C, H, W * repeat_time) image."""
        gathered = vec[index]  # (repeat_time * N, C)
        return gathered.t().reshape(1, channels, self.patch_height, self.patch_width * self.repeat_time)

    def forward(self, src_vec, tar_vec):
        """
        Args:
            src_vec: (B, N, C) predicted pixel values.
            tar_vec: (B, N, C) target pixel values.
        Returns:
            Scalar tensor ``1 - SSIM``. The inner SSIM averages over the whole
            batch because it is built with the default size_average=True.
        """
        batch, num_pixels, channels = src_vec.shape
        device = src_vec.device
        src_patches = []
        tar_patches = []
        for b in range(batch):
            # Copy 0 keeps the original ordering; every other copy is a shuffle.
            orders = [
                torch.arange(num_pixels, device=device) if copy == 0
                else torch.randperm(num_pixels, device=device)
                for copy in range(self.repeat_time)
            ]
            index = torch.cat(orders)  # (repeat_time * N,)
            tar_patches.append(self._virtual_patch(tar_vec[b], index, channels))
            src_patches.append(self._virtual_patch(src_vec[b], index, channels))
        src_tensor = torch.cat(src_patches, dim=0)  # (B, C, H, W * repeat_time)
        tar_tensor = torch.cat(tar_patches, dim=0)
        similarity = self.ssim_loss(src_tensor, tar_tensor)
        return 1.0 - similarity
# NOTE(review): this seeds PyTorch's global RNG every time this module is
# imported. It is a module-level side effect that affects later torch random
# ops in the importing program (e.g. the randperm calls in S3IM.forward).
# Consider moving it under an `if __name__ == "__main__":` guard together
# with the demo below.
torch.manual_seed(0)
# Demo (disabled): assume each image yields 64 x 32 pixels, 3 channels each.
# H, W, C = 64, 32, 3
# N = H * W
# B = 4
# # Randomly generate two batches of pixel feature vectors: [B, N, C]
# src_vec = torch.rand(B, N, C)  # simulated reconstructed image
# tar_vec = torch.rand(B, N, C)  # simulated ground-truth image
# # Initialize the S3IM module
# s3im_loss_fn = S3IM(kernel_size=4, stride=4, repeat_time=10, patch_height=64, patch_width=32)
# # Compute the loss
# loss = s3im_loss_fn(src_vec, tar_vec)
def weighted_huber_loss(
    input: torch.Tensor,
    target: torch.Tensor,
    weight: torch.Tensor,  # per-element confidence weights
    reduction: str = 'mean',
    delta: float = 1.0,
) -> torch.Tensor:
    """Huber loss with per-element weights.

    With r = input - target, each element contributes
        0.5 * r**2                    if |r| <= delta
        delta * (|r| - 0.5 * delta)   otherwise
    multiplied by ``weight``. Input, target and weight are broadcast to a
    common shape.

    Args:
        input, target: tensors broadcastable to each other.
        weight: weights broadcastable with input and target.
        reduction: 'mean' (plain mean over all elements, not normalized by the
            weights), 'sum', or 'none' (per-element weighted loss).
        delta: threshold between the quadratic and linear regimes.

    Raises:
        ValueError: if ``reduction`` is not one of the values above.
    """
    pred, tgt, w = torch.broadcast_tensors(input, target, weight)
    residual = pred - tgt
    magnitude = residual.abs()
    quadratic = 0.5 * (residual ** 2)
    linear = delta * (magnitude - 0.5 * delta)
    per_element = torch.where(magnitude <= delta, quadratic, linear)
    weighted = w * per_element

    if reduction == 'none':
        return weighted
    if reduction == 'sum':
        return torch.sum(weighted)
    if reduction == 'mean':
        return torch.mean(weighted)
    raise ValueError(f"Unsupported reduction: {reduction}")