kywind
update
f96995c
"""
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file found here:
# https://github.com/graphdeco-inria/gaussian-splatting/blob/main/LICENSE.md
#
# For inquiries contact george.drettakis@inria.fr
#######################################################################################################################
##### NOTE: CODE IN THIS FILE IS NOT INCLUDED IN THE OVERALL PROJECT'S MIT LICENSE #####
##### USE OF THIS CODE FOLLOWS THE COPYRIGHT NOTICE ABOVE #####
#######################################################################################################################
"""
import torch
import torch.nn.functional as func
from torch.autograd import Variable
from math import exp
def build_rotation(q):
norm = torch.sqrt(q[:, 0] * q[:, 0] + q[:, 1] * q[:, 1] + q[:, 2] * q[:, 2] + q[:, 3] * q[:, 3])
q = q / norm[:, None]
rot = torch.zeros((q.size(0), 3, 3), device='cuda')
r = q[:, 0]
x = q[:, 1]
y = q[:, 2]
z = q[:, 3]
rot[:, 0, 0] = 1 - 2 * (y * y + z * z)
rot[:, 0, 1] = 2 * (x * y - r * z)
rot[:, 0, 2] = 2 * (x * z + r * y)
rot[:, 1, 0] = 2 * (x * y + r * z)
rot[:, 1, 1] = 1 - 2 * (x * x + z * z)
rot[:, 1, 2] = 2 * (y * z - r * x)
rot[:, 2, 0] = 2 * (x * z - r * y)
rot[:, 2, 1] = 2 * (y * z + r * x)
rot[:, 2, 2] = 1 - 2 * (x * x + y * y)
return rot
def calc_mse(img1, img2):
return ((img1 - img2) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
def calc_psnr(img1, img2):
mse = ((img1 - img2) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
return 20 * torch.log10(1.0 / torch.sqrt(mse))
def gaussian(window_size, sigma):
"""
Generate a 1D Gaussian kernel.
Parameters:
- window_size: The size (length) of the output Gaussian kernel.
- sigma: The standard deviation of the Gaussian distribution.
Returns:
- A 1D tensor representing the Gaussian kernel normalized to have a sum of 1.
"""
# For each position in the desired window size, calculate the Gaussian value.
# The middle of the window corresponds to the peak of the Gaussian.
gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
# Normalize the Gaussian kernel to have a sum of 1 and return it.
return gauss / gauss.sum()
def create_window(window_size, channel):
"""
Generate a 2D Gaussian kernel window.
Parameters:
- window_size: The size (width and height) of the output 2D Gaussian kernel.
- channel: Number of channels for which the window will be replicated.
Returns:
- A 4D tensor representing the Gaussian window for the specified number of channels.
"""
# Create a 1D Gaussian kernel of size 'window_size' with standard deviation 1.5.
# The unsqueeze operation adds an extra dimension, making it a 2D tensor.
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
# Compute the outer product of the 1D Gaussian kernel with itself to get a 2D Gaussian kernel.
# This results in a symmetric 2D Gaussian kernel.
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
# Expand the 2D window to have the desired number of channels.
# The expand operation replicates the 2D window for each channel.
window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
return window
def calc_ssim(img1, img2, window_size=11, size_average=True):
channel = img1.size(-3)
window = create_window(window_size, channel)
# print('img1', img1.device)
# window = window.to(device=img1.device)
# assert torch.isfinite(img1).all(), "img1 contains NaN or Inf"
# assert torch.isfinite(window).all(), "window contains NaN or Inf"
if img1.is_cuda:
window = window.cuda(img1.get_device())
window = window.type_as(img1)
return _ssim(img1, img2, window, window_size, channel, size_average)
def _ssim(img1, img2, window, window_size, channel, size_average=True):
mu1 = func.conv2d(img1, window, padding=window_size // 2, groups=channel)
mu2 = func.conv2d(img2, window, padding=window_size // 2, groups=channel)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2
sigma1_sq = func.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
sigma2_sq = func.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
sigma12 = func.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
c1 = 0.01 ** 2
c2 = 0.03 ** 2
ssim_map = ((2 * mu1_mu2 + c1) * (2 * sigma12 + c2)) / ((mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2))
if size_average:
return ssim_map.mean()
else:
return ssim_map.mean(1).mean(1).mean(1)
def accumulate_mean2d_gradient(variables):
variables['means2D_gradient_accum'][variables['seen']] += torch.norm(
variables['means2D'].grad[variables['seen'], :2], dim=-1)
variables['denom'][variables['seen']] += 1
return variables
def update_params_and_optimizer(new_params, params, optimizer):
for k, v in new_params.items():
group = [x for x in optimizer.param_groups if x["name"] == k][0]
stored_state = optimizer.state.get(group['params'][0], None)
stored_state["exp_avg"] = torch.zeros_like(v)
stored_state["exp_avg_sq"] = torch.zeros_like(v)
del optimizer.state[group['params'][0]]
group["params"][0] = torch.nn.Parameter(v.requires_grad_(True))
optimizer.state[group['params'][0]] = stored_state
params[k] = group["params"][0]
return params
def cat_params_to_optimizer(new_params, params, optimizer):
for k, v in new_params.items():
group = [g for g in optimizer.param_groups if g['name'] == k][0]
stored_state = optimizer.state.get(group['params'][0], None)
if stored_state is not None:
stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(v)), dim=0)
stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(v)), dim=0)
del optimizer.state[group['params'][0]]
group["params"][0] = torch.nn.Parameter(torch.cat((group["params"][0], v), dim=0).requires_grad_(True))
optimizer.state[group['params'][0]] = stored_state
params[k] = group["params"][0]
else:
group["params"][0] = torch.nn.Parameter(torch.cat((group["params"][0], v), dim=0).requires_grad_(True))
params[k] = group["params"][0]
return params
def remove_points(to_remove, params, variables, optimizer):
"""
Parameters:
- to_remove: A boolean tensor where 'True' indicates the points to remove.
- params: A dictionary containing parameters.
- variables: A dictionary containing various variables.
- optimizer: An optimizer object containing optimization information.
Returns:
- Updated params and variables dictionaries after removal.
"""
# Find the points that we want to keep (the opposite of `to_remove`).
to_keep = ~to_remove
# Extract the keys from `params` except for 'cam_m' and 'cam_c'.
keys = [k for k in params.keys() if k not in ['cam_m', 'cam_c']]
for k in keys:
# Find the parameter group associated with the current key in the optimizer.
group = [g for g in optimizer.param_groups if g['name'] == k][0]
# Try to get the state of this group from the optimizer (this contains momentum information, etc. for optimizers like Adam).
stored_state = optimizer.state.get(group['params'][0], None)
if stored_state is not None:
# Update the stored state by keeping only the desired entries.
stored_state["exp_avg"] = stored_state["exp_avg"][to_keep]
stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][to_keep]
# Delete the old state and set a new parameter tensor, keeping only the desired entries and ensuring gradients can be computed.
del optimizer.state[group['params'][0]]
group["params"][0] = torch.nn.Parameter((group["params"][0][to_keep].requires_grad_(True)))
optimizer.state[group['params'][0]] = stored_state
params[k] = group["params"][0]
else:
# If there's no stored state, just update the parameter tensor.
group["params"][0] = torch.nn.Parameter(group["params"][0][to_keep].requires_grad_(True))
params[k] = group["params"][0]
variables['means2D_gradient_accum'] = variables['means2D_gradient_accum'][to_keep]
variables['denom'] = variables['denom'][to_keep]
variables['max_2D_radius'] = variables['max_2D_radius'][to_keep]
return params, variables
def inverse_sigmoid(x):
return torch.log(x / (1 - x))
def densify(params, variables, optimizer, i, grad_thresh, remove_thresh, remove_thresh_5k, scale_scene_radius):
"""
Adjusts the density of points based on various conditions and thresholds.
Parameters:
- params: A dictionary containing parameters.
- variables: A dictionary containing various variables.
- optimizer: An optimizer object containing optimization information.
- i: An iteration or step count.
- remove_thresh: A threshold for removing points.
Returns:
- Updated params and variables dictionaries after adjustment.
"""
if i <= 5000:
variables = accumulate_mean2d_gradient(variables)
if (i >= 500) and (i % 100 == 0):
# Calculate the gradient of the means2D values and handle NaNs.
grads = variables['means2D_gradient_accum'] / variables['denom']
grads[grads.isnan()] = 0.0
# Define points that should be cloned based on gradient thresholds and scales of the points.
to_clone = torch.logical_and(grads >= grad_thresh, (
torch.max(torch.exp(params['log_scales']), dim=1).values <= scale_scene_radius * variables['scene_radius']))
# Extract parameters for points that need cloning.
new_params = {k: v[to_clone] for k, v in params.items() if k not in ['cam_m', 'cam_c']}
params = cat_params_to_optimizer(new_params, params, optimizer)
num_pts = params['means3D'].shape[0]
padded_grad = torch.zeros(num_pts, device="cuda")
padded_grad[:grads.shape[0]] = grads
to_split = torch.logical_and(padded_grad >= grad_thresh,
torch.max(torch.exp(params['log_scales']), dim=1).values > scale_scene_radius * variables[
'scene_radius'])
n = 2 # number to split into
new_params = {k: v[to_split].repeat(n, 1) for k, v in params.items() if k not in ['cam_m', 'cam_c']}
stds = torch.exp(params['log_scales'])[to_split].repeat(n, 1)
means = torch.zeros((stds.size(0), 3), device="cuda")
samples = torch.normal(mean=means, std=stds)
rots = build_rotation(params['unnorm_rotations'][to_split]).repeat(n, 1, 1)
new_params['means3D'] += torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1)
new_params['log_scales'] = torch.log(torch.exp(new_params['log_scales']) / (0.8 * n))
params = cat_params_to_optimizer(new_params, params, optimizer)
num_pts = params['means3D'].shape[0]
variables['means2D_gradient_accum'] = torch.zeros(num_pts, device="cuda")
variables['denom'] = torch.zeros(num_pts, device="cuda")
variables['max_2D_radius'] = torch.zeros(num_pts, device="cuda")
to_remove = torch.cat((to_split, torch.zeros(n * to_split.sum(), dtype=torch.bool, device="cuda")))
params, variables = remove_points(to_remove, params, variables, optimizer)
remove_threshold = remove_thresh_5k if i == 5000 else remove_thresh
to_remove = (torch.sigmoid(params['logit_opacities']) < remove_threshold).squeeze()
# print("num of to remove: ", to_remove.sum())
if i >= 3000:
big_points_ws = torch.exp(params['log_scales']).max(dim=1).values > 0.1 * variables['scene_radius']
# print("num of big points: ", big_points_ws.sum())
to_remove = torch.logical_or(to_remove, big_points_ws)
params, variables = remove_points(to_remove, params, variables, optimizer)
torch.cuda.empty_cache()
if i > 0 and i % 3000 == 0:
new_params = {'logit_opacities': inverse_sigmoid(torch.ones_like(params['logit_opacities']) * 0.01)}
params = update_params_and_optimizer(new_params, params, optimizer)
num_pts = params['means3D'].shape[0]
return params, variables, num_pts