Spaces:
Sleeping
Sleeping
# Multi-HMR
# Copyright (c) 2024-present NAVER Corp.
# CC BY-NC-SA 4.0 license
| import torch | |
| from torch import nn | |
| import numpy as np | |
class FourierPositionEncoding(nn.Module):
    """Parameter-free module that turns positions into Fourier features.

    The module only stores its configuration; the encoding itself is computed
    by the module-level ``_generate_fourier_features`` helper.
    """

    def __init__(self, n, num_bands, max_resolution):
        """Store the encoding configuration (nothing is learned).

        Args:
            n: Number of positional dimensions; ``max_resolution`` is
                repeated once per dimension.
            num_bands: Number of frequency bands per positional dimension.
            max_resolution: Resolution value shared by all ``n`` dimensions.
        """
        super().__init__()
        self.num_bands = num_bands
        # One resolution entry per positional dimension, all identical.
        self.max_resolution = n * [max_resolution]

    def channels(self):
        """Return the size of the last dimension of the encoding.

        Each positional dimension contributes ``num_bands`` sine values,
        ``num_bands`` cosine values, and one slot for the raw coordinate that
        is concatenated with the features.
        """
        dims = len(self.max_resolution)
        per_dim = 2 * self.num_bands + 1  # sin + cos bands, plus the raw value
        return dims * per_dim

    def forward(self, pos):
        """Compute the Fourier positional encoding of ``pos``.

        ``pos`` is passed unchanged to ``_generate_fourier_features`` together
        with the stored ``num_bands`` and ``max_resolution``.
        """
        return _generate_fourier_features(
            pos,
            num_bands=self.num_bands,
            max_resolution=self.max_resolution,
        )
| def _generate_fourier_features(pos, num_bands, max_resolution): | |
| """Generate fourier features from a given set of positions and frequencies""" | |
| b, n = pos.shape[:2] | |
| device = pos.device | |
| # Linear frequency sampling | |
| min_freq = 1.0 | |
| freq_bands = torch.stack([torch.linspace(start=min_freq, end=res / 2, steps=num_bands, device=device) for res in max_resolution], dim=0) | |
| # Stacking | |
| per_pos_features = torch.stack([pos[i, :, :][:, :, None] * freq_bands[None, :, :] for i in range(b)], 0) | |
| per_pos_features = per_pos_features.reshape(b, n, -1) | |
| # Sin-Cos | |
| per_pos_features = torch.cat([torch.sin(np.pi * per_pos_features), torch.cos(np.pi * per_pos_features)], dim=-1) | |
| # Concat with initial pos | |
| per_pos_features = torch.cat([pos, per_pos_features], dim=-1) | |
| return per_pos_features |