# Collection of reconstruction losses: multi-scale derivative loss,
# cosine normal loss, SSIM / S3IM, and a confidence-weighted Huber loss.
from math import exp

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
class MultiScaleDerivativeLoss(nn.Module):
    """Multi-scale derivative (gradient) matching loss.

    Builds an image pyramid from prediction and ground truth, applies a
    derivative operator at every scale, and penalizes the difference.

    Args:
        operator: 'scharr' (first-order) or 'laplace' (second-order).
        p: 1 for an L1 penalty, 2 for L2.
        reduction: 'mean' or 'sum', applied per scale.
        normalize_input: L2-normalize along the channel dim first
            (useful when the inputs are normal maps).
        num_scales: pyramid depth, e.g. 4 = full res, 1/2, 1/4, 1/8.
    """

    def __init__(self, operator='scharr', p=1, reduction='mean',
                 normalize_input=False, num_scales=4):
        super().__init__()
        assert operator in ['scharr', 'laplace']
        assert p in [1, 2]
        assert reduction in ['mean', 'sum']
        assert num_scales >= 1
        self.operator = operator
        self.p = p
        self.reduction = reduction
        self.normalize_input = normalize_input
        self.num_scales = num_scales

    def forward(self, pred, gt):
        """pred, gt: [B, C, H, W] tensors. Returns a scalar averaged over scales."""
        pred_pyramid = self._build_pyramid(pred)
        gt_pyramid = self._build_pyramid(gt)
        total_loss = 0.0
        for pred_i, gt_i in zip(pred_pyramid, gt_pyramid):
            if self.normalize_input:
                pred_i = F.normalize(pred_i, dim=1)
                gt_i = F.normalize(gt_i, dim=1)
            diff = self._compute_gradient(pred_i) - self._compute_gradient(gt_i)
            diff = torch.abs(diff) if self.p == 1 else diff ** 2
            total_loss = total_loss + (diff.mean() if self.reduction == 'mean' else diff.sum())
        return total_loss / self.num_scales

    def _build_pyramid(self, img):
        """Return [img, img@1/2, img@1/4, ...] with num_scales levels.

        BUG FIX: the previous version applied scale_factor=0.5**i to the
        *already downsampled* level, producing scales 1, 1/2, 1/8, 1/64
        instead of the documented 1, 1/2, 1/4, 1/8. Each level is now
        exactly half the size of the previous one.
        """
        pyramid = [img]
        for _ in range(1, self.num_scales):
            img = F.interpolate(img, scale_factor=0.5, mode='bicubic',
                                align_corners=False,
                                recompute_scale_factor=True, antialias=True)
            pyramid.append(img)
        return pyramid

    def _compute_gradient(self, img):
        """Apply the chosen derivative operator channel-wise (depthwise conv).

        Returns [B, 2C, H, W] (x- and y-gradients stacked) for 'scharr',
        [B, C, H, W] for 'laplace'.
        """
        B, C, H, W = img.shape
        device = img.device
        if self.operator == 'scharr':
            # Scharr kernels, divided by 16 to keep responses bounded.
            kernel_x = torch.tensor([[[-3., 0., 3.],
                                      [-10., 0., 10.],
                                      [-3., 0., 3.]]], device=device) / 16.0
            kernel_y = torch.tensor([[[-3., -10., -3.],
                                      [0., 0., 0.],
                                      [3., 10., 3.]]], device=device) / 16.0
            kernel_x = kernel_x.unsqueeze(0).expand(C, 1, 3, 3)
            kernel_y = kernel_y.unsqueeze(0).expand(C, 1, 3, 3)
            grad_x = F.conv2d(img, kernel_x, padding=1, groups=C)
            grad_y = F.conv2d(img, kernel_y, padding=1, groups=C)
            return torch.cat([grad_x, grad_y], dim=1)  # [B, 2C, H, W]
        elif self.operator == 'laplace':
            # 4-neighbour discrete Laplacian.
            kernel = torch.tensor([[[0., 1., 0.],
                                    [1., -4., 1.],
                                    [0., 1., 0.]]], device=device)
            kernel = kernel.unsqueeze(0).expand(C, 1, 3, 3)
            return F.conv2d(img, kernel, padding=1, groups=C)  # [B, C, H, W]
class CosineLoss(torch.nn.Module):
    """Cosine-similarity loss for normal maps, plus a resolution-scaled MSE.

    Only pixels whose ground-truth vector is non-zero contribute to the
    cosine term. The dot product is used directly, so inputs are expected
    to already be unit vectors (not normalized here).
    """

    def __init__(self):
        super(CosineLoss, self).__init__()

    def forward(self, N, N_hat):
        """
        N: ground-truth normals, shape (B, C, H, W)
        N_hat: predicted normals, same shape as N
        Returns (cosine_loss, scaled_mse).
        """
        _, _, height, width = N.shape
        # Valid-pixel mask: True where the ground-truth vector is non-zero.
        valid = N.norm(p=2, dim=1, keepdim=True) > 0  # (B, 1, H, W)
        # MSE rescaled by resolution relative to a 2048-pixel reference.
        mse = F.mse_loss(N, N_hat, reduction='mean') * height * width / 2048
        cos = (N * N_hat).sum(dim=1, keepdim=True)  # (B, 1, H, W)
        # 1 - cos, restricted to valid pixels only.
        residual = (1 - cos)[valid]
        return residual.mean(), mse
def gaussian(window_size, sigma):
    """Return a 1-D Gaussian kernel of length window_size, summing to 1."""
    center = window_size // 2
    weights = [exp(-((i - center) ** 2) / float(2 * sigma ** 2))
               for i in range(window_size)]
    kernel = torch.Tensor(weights)
    return kernel / kernel.sum()
def create_window(window_size, channel):
    """Build a [channel, 1, ws, ws] Gaussian window (sigma=1.5) for grouped conv2d.

    The 2-D window is the outer product of a 1-D Gaussian with itself.
    The deprecated torch.autograd.Variable wrapper has been removed — it
    has been a no-op since PyTorch 0.4, so behavior is unchanged.
    """
    _1d = gaussian(window_size, 1.5).unsqueeze(1)
    _2d = _1d.mm(_1d.t()).float().unsqueeze(0).unsqueeze(0)
    return _2d.expand(channel, 1, window_size, window_size).contiguous()
def _ssim(img1, img2, window, window_size, channel, size_average = True, stride=None): | |
mu1 = F.conv2d(img1, window, padding = (window_size-1)//2, groups = channel, stride=stride) | |
mu2 = F.conv2d(img2, window, padding = (window_size-1)//2, groups = channel, stride=stride) | |
mu1_sq = mu1.pow(2) | |
mu2_sq = mu2.pow(2) | |
mu1_mu2 = mu1*mu2 | |
sigma1_sq = F.conv2d(img1*img1, window, padding = (window_size-1)//2, groups = channel, stride=stride) - mu1_sq | |
sigma2_sq = F.conv2d(img2*img2, window, padding = (window_size-1)//2, groups = channel, stride=stride) - mu2_sq | |
sigma12 = F.conv2d(img1*img2, window, padding = (window_size-1)//2, groups = channel, stride=stride) - mu1_mu2 | |
C1 = 0.01**2 | |
C2 = 0.03**2 | |
ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) | |
if size_average: | |
return ssim_map.mean() | |
else: | |
return ssim_map.mean(1).mean(1).mean(1) | |
class SSIM(torch.nn.Module):
    """SSIM as a module with a lazily cached Gaussian window.

    The window is rebuilt whenever the input's channel count, dtype, or
    device no longer matches the cached one.
    """

    def __init__(self, window_size=3, size_average=True, stride=3):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.stride = stride
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        """
        img1, img2: torch.Tensor([b, c, h, w])
        """
        channel = img1.size(1)
        # Rebuild the cached window on channel/dtype/device mismatch.
        # Replaces the deprecated tensor.data.type() comparison; also fixes
        # a stale-dtype window for non-float32 CPU inputs.
        if (channel != self.channel
                or self.window.dtype != img1.dtype
                or self.window.device != img1.device):
            self.window = create_window(self.window_size, channel).to(
                device=img1.device, dtype=img1.dtype)
            self.channel = channel
        return _ssim(img1, img2, self.window, self.window_size, channel,
                     self.size_average, stride=self.stride)
def ssim(img1, img2, window_size=11, size_average=True):
    """Functional SSIM between two [B, C, H, W] tensors.

    BUG FIX: passes stride=1 explicitly. _ssim's default stride of None was
    forwarded to F.conv2d, which rejects None, so this helper always crashed.
    The window is also moved to the input's device/dtype in one step.
    """
    channel = img1.size(1)
    window = create_window(window_size, channel).to(device=img1.device, dtype=img1.dtype)
    return _ssim(img1, img2, window, window_size, channel, size_average, stride=1)
class S3IM(torch.nn.Module):
    """Stochastic patch-based SSIM loss.

    Flattened pixel vectors are repeatedly shuffled and tiled into a wide
    pseudo-patch per batch item, then scored with SSIM; the loss is
    1 - SSIM. N must equal patch_height * patch_width.
    """

    def __init__(self, kernel_size=4, stride=4, repeat_time=10, patch_height=64, patch_width=32):
        super(S3IM, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.repeat_time = repeat_time
        self.patch_height = patch_height
        self.patch_width = patch_width
        self.ssim_loss = SSIM(window_size=self.kernel_size, stride=self.stride)

    def forward(self, src_vec, tar_vec):
        """
        Args:
            src_vec: [B, N, C] e.g., [batch, pixels, channels]
            tar_vec: [B, N, C]
        Returns:
            loss: scalar tensor
        """
        B, N, C = src_vec.shape
        device = src_vec.device
        wide = self.patch_width * self.repeat_time
        src_patches, tar_patches = [], []
        for b in range(B):
            # First pass keeps the original order; the rest are random shuffles.
            perms = [torch.arange(N, device=device)]
            perms.extend(torch.randperm(N, device=device)
                         for _ in range(self.repeat_time - 1))
            order = torch.cat(perms)  # [repeat_time * N]
            # Gather, then lay the repeats out side by side as [1, C, H, W*M].
            tar_patches.append(
                tar_vec[b][order].permute(1, 0).reshape(1, C, self.patch_height, wide))
            src_patches.append(
                src_vec[b][order].permute(1, 0).reshape(1, C, self.patch_height, wide))
        # Stack all batch items into [B, C, H, W*M].
        tar_tensor = torch.cat(tar_patches, dim=0)
        src_tensor = torch.cat(src_patches, dim=0)
        # SSIM (size_average=True) yields a scalar; loss is its complement.
        return 1.0 - self.ssim_loss(src_tensor, tar_tensor)
# Fix the global RNG so the shuffle-based losses (S3IM) are reproducible.
torch.manual_seed(0)

# Example usage (kept for reference):
#   Suppose each image contributes 64 x 32 pixels with 3 channels.
# H, W, C = 64, 32, 3
# N = H * W
# B = 4
# # Two random feature tensors of shape [B, N, C]:
# src_vec = torch.rand(B, N, C)  # simulated reconstruction
# tar_vec = torch.rand(B, N, C)  # simulated ground-truth image
# # Build the S3IM module
# s3im_loss_fn = S3IM(kernel_size=4, stride=4, repeat_time=10, patch_height=64, patch_width=32)
# # Compute the loss
# loss = s3im_loss_fn(src_vec, tar_vec)
def weighted_huber_loss(
    input: torch.Tensor,
    target: torch.Tensor,
    weight: torch.Tensor,  # per-element confidence weights
    reduction: str = 'mean',
    delta: float = 1.0,
) -> torch.Tensor:
    """Huber (smooth-L1) loss with per-element confidence weights.

    Args:
        input: predictions.
        target: ground truth, broadcastable against input.
        weight: weights, broadcastable against the broadcast shape of
            input/target.
        reduction: 'mean', 'sum', or 'none'.
        delta: threshold between the quadratic and linear regimes.

    Returns:
        A scalar for 'mean'/'sum'; the weighted element-wise loss for 'none'.

    Raises:
        ValueError: if reduction is not one of 'mean', 'sum', 'none'.
    """
    # Validate up front so a bad reduction fails before any compute.
    if reduction not in ('mean', 'sum', 'none'):
        raise ValueError(f"Unsupported reduction: {reduction}")
    # Broadcast input/target together, then broadcast the weight against
    # the *expanded* result (more robust than broadcasting against input
    # alone when target is the larger operand).
    expanded_input, expanded_target = torch.broadcast_tensors(input, target)
    expanded_weight, _ = torch.broadcast_tensors(weight, expanded_input)
    diff = expanded_input - expanded_target
    abs_diff = torch.abs(diff)
    # Piecewise Huber: quadratic inside |diff| <= delta, linear outside.
    loss = torch.where(
        abs_diff <= delta,
        0.5 * (diff ** 2),
        delta * (abs_diff - 0.5 * delta)
    )
    weighted_loss = expanded_weight * loss
    if reduction == 'mean':
        return torch.mean(weighted_loss)
    elif reduction == 'sum':
        return torch.sum(weighted_loss)
    return weighted_loss