File size: 7,372 Bytes
f96995c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
from pathlib import Path
import torch
import numpy as np
import random
from PIL import Image

from pgnd.utils import get_root
import sys
root: Path = get_root(__file__)

from diff_gaussian_rasterization import GaussianRasterizer as Renderer
from gs.helpers import setup_camera, l1_loss_v1, o3d_knn, params2rendervar
from gs.external import calc_ssim, calc_psnr


def get_custom_dataset(img_list, seg_list, metadata):
    """Assemble a list of per-camera training samples.

    Each entry pairs a rasterization camera (built from the per-camera
    intrinsics/extrinsics in ``metadata``) with the corresponding RGB image
    and a two-channel-style segmentation color tensor.

    Args:
        img_list: per-camera images; each item is either a file path or an
            HxWx3 array-like (values in [0, 255] or already normalized).
        seg_list: per-camera binary segmentation masks; file path or array.
        metadata: dict with 'w', 'h', per-camera 'k' (intrinsics) and
            'w2c' (world-to-camera) entries.

    Returns:
        List of dicts with keys 'cam', 'im', 'seg', 'id'.
    """
    # Image dimensions are shared across all cameras.
    w, h = metadata['w'], metadata['h']
    dataset = []

    for idx, img_entry in enumerate(img_list):
        # Camera for this view, with fixed near/far clipping planes.
        cam = setup_camera(w, h, metadata['k'][idx], metadata['w2c'][idx], near=0.01, far=100)

        # Load the RGB image (from disk or from an in-memory array) as a
        # CHW float tensor in [0, 1] on the GPU.
        if isinstance(img_entry, str):
            im = np.array(Image.open(img_entry))
            im = torch.tensor(im).float().cuda().permute(2, 0, 1) / 255
        else:
            im = torch.tensor(img_entry).permute(2, 0, 1).float().cuda()
            # Heuristic: values above 2.0 indicate an un-normalized [0, 255] image.
            if im.max() > 2.0:
                im = im / 255

        # Load the binary segmentation mask the same way.
        seg_entry = seg_list[idx]
        if isinstance(seg_entry, str):
            seg = np.array(Image.open(seg_entry)).astype(np.float32)
        else:
            seg = seg_entry.astype(np.float32)
        seg = torch.tensor(seg).float().cuda()

        # Encode the binary mask as colors: red = object, blue = background.
        seg_col = torch.stack((seg, torch.zeros_like(seg), 1 - seg))

        dataset.append({'cam': cam, 'im': im, 'seg': seg_col, 'id': idx})

    return dataset


def initialize_params(init_pt_cld, metadata, max_cams=4):
    """Initialize Gaussian parameters and auxiliary state from a point cloud.

    Args:
        init_pt_cld: (N, 7) array — columns 0:3 are xyz, 3:6 are rgb,
            column 6 is a binary segmentation label.
        metadata: dict containing 'w2c', the per-camera world-to-camera
            matrices, used to derive the scene radius.
        max_cams: number of per-camera exposure-parameter slots to allocate
            (default 4, matching the previously hard-coded value).

    Returns:
        Tuple ``(params, variables)``: ``params`` is a dict of GPU
        ``torch.nn.Parameter`` tensors to optimize; ``variables`` holds
        non-optimized bookkeeping state (scene radius, densification
        accumulators).
    """
    # Binary segmentation label per point (column 6 of the cloud).
    seg = init_pt_cld[:, 6]

    # Mean squared distance to the 3 nearest neighbors sets each Gaussian's
    # initial scale; clip to avoid log(0) for duplicated points.
    sq_dist, _ = o3d_knn(init_pt_cld[:, :3], 3)
    mean3_sq_dist = sq_dist.mean(-1).clip(min=0.0000001)

    params = {
        'means3D': init_pt_cld[:, :3],                          # 3D coordinates of the points.
        'rgb_colors': init_pt_cld[:, 3:6],                      # RGB color values for the points.
        'seg_colors': np.stack((seg, np.zeros_like(seg), 1 - seg), -1),  # Segmentation colors (red=object, blue=background).
        'unnorm_rotations': np.tile([1, 0, 0, 0], (seg.shape[0], 1)),  # Identity quaternion per point.
        'logit_opacities': np.zeros((seg.shape[0], 1)),         # Zero logits -> 0.5 opacity after sigmoid.
        'log_scales': np.tile(np.log(np.sqrt(mean3_sq_dist))[..., None], (1, 3)),  # Isotropic log-scale from knn distance.
        'cam_m': np.zeros((max_cams, 3)),                       # Per-camera multiplicative exposure (log-space).
        'cam_c': np.zeros((max_cams, 3)),                       # Per-camera additive color offset.
    }

    # Convert everything to trainable GPU tensors.
    params = {k: torch.nn.Parameter(torch.tensor(v).cuda().float().contiguous().requires_grad_(True)) for k, v in params.items()}
    # params['rgb_colors'].requires_grad = False

    # Scene radius: 1.1x the largest camera-center distance from the mean
    # camera center (camera centers are the translation of the inverted w2c).
    cam_centers = np.linalg.inv(metadata['w2c'])[:, :3, 3]
    scene_radius = 1.1 * np.max(np.linalg.norm(cam_centers - np.mean(cam_centers, 0)[None], axis=-1))

    # Non-optimized per-point bookkeeping used by densification logic.
    variables = {
        'max_2D_radius': torch.zeros(params['means3D'].shape[0]).cuda().float(), # Max observed 2D splat radius.
        'scene_radius': scene_radius,                                          # Scene radius.
        'means2D_gradient_accum': torch.zeros(params['means3D'].shape[0]).cuda().float(), # Screen-space gradient accumulator.
        'denom': torch.zeros(params['means3D'].shape[0]).cuda().float()        # Accumulation denominator.
    }

    return params, variables


def initialize_optimizer(params, variables):
    """Build an Adam optimizer with a dedicated learning rate per parameter.

    The position learning rate is scaled by the scene radius; color
    parameters are frozen via a zero learning rate.
    """
    # Per-parameter learning rates (keys must match ``params``).
    per_param_lr = {
        'means3D': 0.00016 * variables['scene_radius'],
        'rgb_colors': 0.0,
        'seg_colors': 0.0,
        'unnorm_rotations': 0.001,
        'logit_opacities': 0.05,
        'log_scales': 0.001,
        'cam_m': 1e-4,
        'cam_c': 1e-4,
    }
    groups = []
    for name, tensor in params.items():
        groups.append({'params': [tensor], 'name': name, 'lr': per_param_lr[name]})
    # Group-level lrs override the (unused) top-level lr of 0.
    return torch.optim.Adam(groups, lr=0.0, eps=1e-15)


def get_loss(params, curr_data, variables, loss_weights):
    """Compute the weighted photometric + segmentation loss for one view.

    Args:
        params: dict of optimizable Gaussian parameters (means3D, colors,
            per-camera exposure 'cam_m'/'cam_c', etc.).
        curr_data: one dataset entry with keys 'cam', 'im', 'seg', 'id'.
        variables: auxiliary state dict; mutated in place ('means2D',
            'seen', 'max_2D_radius') for later densification bookkeeping.
        loss_weights: dict mapping loss name ('im', 'seg') to its weight.

    Returns:
        Tuple ``(loss, variables)``: scalar total loss tensor and the
        (in-place updated) variables dict.
    """

    # Initialize dictionary to store various loss components
    losses = {}

    # Convert parameters to rendering variables and retain gradient for 'means2D'
    # (needed later to read screen-space gradients for densification).
    rendervar = params2rendervar(params)
    rendervar['means2D'].retain_grad()

    # Perform rendering to obtain image, radius, and other outputs
    im, radius, _ = Renderer(raster_settings=curr_data['cam'])(**rendervar)

    # Apply per-camera exposure correction: multiplicative gain (exp of
    # 'cam_m') plus additive offset 'cam_c', broadcast over H and W.
    curr_id = curr_data['id']
    im = torch.exp(params['cam_m'][curr_id])[:, None, None] * im + params['cam_c'][curr_id][:, None, None]

    # Calculate image loss using L1 loss and ssim
    losses['im'] = 0.8 * l1_loss_v1(im, curr_data['im']) + 0.2 * (1.0 - calc_ssim(im, curr_data['im']))
    variables['means2D'] = rendervar['means2D']  # Gradient only accum from colour render for densification

    # Second render pass with segmentation colors substituted for RGB.
    segrendervar = params2rendervar(params)
    segrendervar['colors_precomp'] = params['seg_colors']
    seg, _, _, = Renderer(raster_settings=curr_data['cam'])(**segrendervar)

    # Calculate segmentation loss
    losses['seg'] = 0.8 * l1_loss_v1(seg, curr_data['seg']) + 0.2 * (1.0 - calc_ssim(seg, curr_data['seg']))

    # Calculate total loss as weighted sum of individual losses
    loss = sum([loss_weights[k] * v for k, v in losses.items()])

    # Update variables related to rendering radius and seen areas
    # ('seen' marks Gaussians with non-zero rendered radius this iteration).
    seen = radius > 0
    variables['max_2D_radius'][seen] = torch.max(radius[seen], variables['max_2D_radius'][seen])
    variables['seen'] = seen
    return loss, variables


def report_progress(params, data, i, progress_bar, num_pts, every_i=100, vis_dir=None):
    """Every ``every_i`` iterations: render view 0, optionally save it, and
    push the PSNR and point count onto the tqdm progress bar."""
    if i % every_i != 0:
        return
    # Render the current Gaussians from this view's camera.
    im, _, _ = Renderer(raster_settings=data['cam'])(**params2rendervar(params))
    # Apply the per-camera exposure correction (gain + offset).
    cam_id = data['id']
    gain = torch.exp(params['cam_m'][cam_id])[:, None, None]
    im = gain * im + params['cam_c'][cam_id][:, None, None]
    if vis_dir:
        # Save the rendered frame as an HxWx3 uint8 PNG.
        frame = (im.cpu().numpy().clip(0, 1) * 255).astype(np.uint8).transpose(1, 2, 0)
        Image.fromarray(frame).save(f"{vis_dir}/{i:06d}.png")
    psnr = calc_psnr(im, data['im']).mean()
    progress_bar.set_postfix({"img 0 PSNR": f"{psnr:.{7}f}, number of points: {num_pts}"})
    progress_bar.update(every_i)


def get_batch(todo_dataset, dataset):
    """Pop and return a random sample, cycling through ``dataset`` epochs.

    BUG FIX: the original did ``todo_dataset = dataset.copy()``, which only
    rebinds the local name — the caller's (empty) list was never refilled, so
    every call re-copied the full dataset and sampled WITH replacement.
    Refilling in place (``todo_dataset[:] = dataset``) preserves the intended
    without-replacement epoch behavior across calls.

    Args:
        todo_dataset: mutable list of samples remaining in the current epoch;
            refilled in place from ``dataset`` when exhausted (``None`` is
            tolerated and treated as an empty, caller-invisible pool).
        dataset: full list of samples; never mutated.

    Returns:
        One randomly chosen sample, removed from ``todo_dataset``.
    """
    if todo_dataset is None:
        # No caller-owned pool: fall back to a private copy (old behavior).
        todo_dataset = list(dataset)
    elif not todo_dataset:
        # Refill IN PLACE so the caller's list reflects the remaining epoch.
        todo_dataset[:] = dataset
    return todo_dataset.pop(random.randint(0, len(todo_dataset) - 1))