# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Point-sampling and small tensor helpers for dense feature maps."""

import torch
import torch.nn.functional as F

# The three packages below are only needed by `procrustes_analysis` (easydict)
# and `vis_PCA` (scikit-learn, matplotlib). They are guarded so that a missing
# visualisation dependency does not prevent the sampling helpers from being
# imported; the functions that need them raise a clear ImportError instead.
try:
    from easydict import EasyDict as edict
except ImportError:
    edict = None

try:
    from sklearn.decomposition import PCA
except ImportError:
    PCA = None

try:
    import matplotlib.pyplot as plt
except ImportError:
    plt = None

EPS = 1e-6


def _prepare_points(im, x, y):
    """Validate sampler inputs and broadcast the coordinates over the batch.

    Returns ``(H, W, x, y)`` where ``x`` and ``y`` are float32 tensors of
    shape ``(B, N)``.
    """
    if im.dim() not in (4, 5):
        raise ValueError(
            f"im must be (B, C, H, W) or (B, N, C, H, W), got {tuple(im.shape)}"
        )
    B = im.shape[0]
    H, W = im.shape[-2:]
    N = x.shape[1]
    if im.dim() == 5 and im.shape[1] != N:
        raise ValueError(
            f"im holds {im.shape[1]} per-point maps but {N} points were given"
        )
    # Coordinates given with a leading dimension of 1 are shared by every
    # batch element; expanding here keeps all later shapes equal to (B, N).
    x = x.float().expand(B, N)
    y = y.float().expand(B, N)
    return H, W, x, y


def _gather_pixels(im, x_idx, y_idx):
    """Look up ``im`` at integer pixel indices of shape (B, N); returns (B, N, C)."""
    B, N = x_idx.shape
    batch = torch.arange(B, device=x_idx.device).unsqueeze(1)  # B, 1
    if im.dim() == 5:
        # Point n reads from its own map im[:, n].
        point = torch.arange(N, device=x_idx.device).unsqueeze(0)  # 1, N
        return im[batch, point, :, y_idx, x_idx]
    return im[batch, :, y_idx, x_idx]


def _inbounds_mask(x, y, H, W):
    """Return a float (B, N) mask that is 1 where (x, y) lies inside the image."""
    valid = (x > -0.5) & (x < W - 0.5) & (y > -0.5) & (y < H - 0.5)
    return valid.float()


def nearest_sample2d(im, x, y, return_inbounds=False):
    """Sample ``im`` at the pixel nearest to each (x, y) location.

    NOTE(review): in the source as received only the prologue and the
    in-bounds epilogue of this function survived; the lookup itself was
    reconstructed from the contract stated in the original comments
    ("x and y are each B, N", "output is B, C, N"). Confirm against upstream.

    Args:
        im: image of shape (B, C, H, W), or (B, N, C, H, W) holding one map
            per point.
        x, y: pixel coordinates of shape (B, N); a leading dimension of 1 is
            broadcast over the batch. Locations outside the image are clamped
            to the border.
        return_inbounds: if True, also return the in-bounds mask.

    Returns:
        Tensor of shape (B, C, N); with ``return_inbounds`` a tuple
        ``(output, inbounds)`` where ``inbounds`` is a float (B, N) mask that
        is 1 for -0.5 < x < W - 0.5 and -0.5 < y < H - 0.5.
    """
    H, W, x, y = _prepare_points(im, x, y)

    # Round half up, then clamp so out-of-range points read the border pixel.
    x_idx = torch.floor(x + 0.5).clamp(0, W - 1).long()
    y_idx = torch.floor(y + 0.5).clamp(0, H - 1).long()

    output = _gather_pixels(im, x_idx, y_idx).permute(0, 2, 1)  # B, C, N

    if return_inbounds:
        return output, _inbounds_mask(x, y, H, W)
    return output  # B, C, N


def smart_cat(tensor1, tensor2, dim):
    """Concatenate two tensors along ``dim``; return ``tensor2`` if ``tensor1`` is None."""
    if tensor1 is None:
        return tensor2
    return torch.cat([tensor1, tensor2], dim=dim)


def normalize_single(d):
    """Rescale a tensor of any shape to [0, 1) using its global min and max."""
    dmin = torch.min(d)
    dmax = torch.max(d)
    # EPS keeps the division finite when the tensor is constant.
    return (d - dmin) / (EPS + (dmax - dmin))


def normalize(d):
    """Apply `normalize_single` independently to each element of the batch.

    Args:
        d: tensor of shape (B, ...).

    Returns:
        Tensor with the shape and device of ``d``. Floating-point inputs keep
        their dtype; other inputs produce the default float dtype.
    """
    out_dtype = d.dtype if d.is_floating_point() else torch.get_default_dtype()
    # zeros_like keeps the result on d's own device (not just "some GPU").
    out = torch.zeros_like(d, dtype=out_dtype)
    for b in range(d.shape[0]):
        out[b] = normalize_single(d[b])
    return out


def meshgrid2d(B, Y, X, stack=False, norm=False, device="cuda"):
    """Build a pixel-coordinate meshgrid of size B x Y x X.

    Args:
        B: batch size.
        Y, X: grid height and width.
        stack: if True, return a single (B, Y, X, 2) tensor.
        norm: accepted for API compatibility; it has no effect in this
            implementation (coordinates are always in pixels).
        device: device on which the grid is created.

    Returns:
        ``(grid_y, grid_x)``, each of shape (B, Y, X), or the stacked grid
        when ``stack`` is True.
    """
    grid_y = torch.linspace(0.0, Y - 1, Y, device=torch.device(device))
    grid_y = torch.reshape(grid_y, [1, Y, 1])
    grid_y = grid_y.repeat(B, 1, X)

    grid_x = torch.linspace(0.0, X - 1, X, device=torch.device(device))
    grid_x = torch.reshape(grid_x, [1, 1, X])
    grid_x = grid_x.repeat(B, Y, 1)

    if stack:
        # note we stack in xy order
        # (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample)
        return torch.stack([grid_x, grid_y], dim=-1)
    return grid_y, grid_x


def reduce_masked_mean(x, mask, dim=None, keepdim=False):
    """Mean of ``x`` over the entries selected by ``mask``.

    Args:
        x: values.
        mask: weights; its sizes are compared with those of ``x`` axis by
            axis from the left, and a mismatch trips an assertion.
        dim: axis or list of axes to reduce; None reduces everything.
        keepdim: keep the reduced axes with size 1.

    Returns:
        ``sum(x * mask) / (EPS + sum(mask))`` over the requested axes.
    """
    for a, b in zip(x.size(), mask.size()):
        assert a == b, "x and mask must have matching sizes"  # some shape mismatch!
    prod = x * mask
    if dim is None:
        numer = torch.sum(prod)
        denom = EPS + torch.sum(mask)
    else:
        numer = torch.sum(prod, dim=dim, keepdim=keepdim)
        denom = EPS + torch.sum(mask, dim=dim, keepdim=keepdim)
    return numer / denom


def bilinear_sample2d(im, x, y, return_inbounds=False):
    """Bilinearly sample ``im`` at sub-pixel (x, y) locations.

    NOTE(review): in the source as received only the prologue and the
    in-bounds epilogue of this function survived; the interpolation itself
    was reconstructed from the contract stated in the original comments
    ("x and y are each B, N", "output is B, C, N"). Confirm against upstream.

    Args:
        im: image of shape (B, C, H, W), or (B, N, C, H, W) holding one map
            per point.
        x, y: pixel coordinates of shape (B, N); a leading dimension of 1 is
            broadcast over the batch. Neighbours outside the image are
            clamped to the border.
        return_inbounds: if True, also return the in-bounds mask.

    Returns:
        Tensor of shape (B, C, N); with ``return_inbounds`` a tuple
        ``(output, inbounds)`` where ``inbounds`` is a float (B, N) mask that
        is 1 for -0.5 < x < W - 0.5 and -0.5 < y < H - 0.5.
    """
    H, W, x, y = _prepare_points(im, x, y)

    x0 = torch.floor(x)
    y0 = torch.floor(y)

    # Interpolation weights come from the unclamped corners, so a point that
    # lies outside the image still gets weights summing to one.
    w_x1 = x - x0
    w_y1 = y - y0
    w_x0 = 1.0 - w_x1
    w_y0 = 1.0 - w_y1

    x_lo = x0.clamp(0, W - 1).long()
    x_hi = (x0 + 1).clamp(0, W - 1).long()
    y_lo = y0.clamp(0, H - 1).long()
    y_hi = (y0 + 1).clamp(0, H - 1).long()

    output = (
        (w_x0 * w_y0).unsqueeze(2) * _gather_pixels(im, x_lo, y_lo)
        + (w_x1 * w_y0).unsqueeze(2) * _gather_pixels(im, x_hi, y_lo)
        + (w_x0 * w_y1).unsqueeze(2) * _gather_pixels(im, x_lo, y_hi)
        + (w_x1 * w_y1).unsqueeze(2) * _gather_pixels(im, x_hi, y_hi)
    )  # B, N, C
    output = output.permute(0, 2, 1)  # B, C, N

    if return_inbounds:
        return output, _inbounds_mask(x, y, H, W)
    return output  # B, C, N


def procrustes_analysis(X0, X1, Weight):  # [B,N,3]
    """Estimate the rigid transform that aligns ``X1`` to ``X0``.

    The result satisfies ``X0 ~= (X1 - t1) @ R.transpose(-1, -2) + t0`` (with
    ``t0``/``t1`` unsqueezed on the point axis) in the least-squares sense.

    Args:
        X0, X1: corresponding point sets of identical shape (B, N, D).
        Weight: currently unused; accepted for API compatibility.

    Returns:
        EasyDict with ``t0`` and ``t1`` (centroids, shape (B, D)) and ``R``
        (float32 rotations, shape (B, D, D)).

    Raises:
        ImportError: if ``easydict`` is not installed.
        ValueError: if the inputs are not 3-D tensors of the same shape.
    """
    if edict is None:
        raise ImportError("procrustes_analysis requires the 'easydict' package")
    if X0.dim() != 3 or X0.shape != X1.shape:
        raise ValueError(
            "X0 and X1 must both be (B, N, D), got "
            f"{tuple(X0.shape)} and {tuple(X1.shape)}"
        )

    # translation
    t0 = X0.mean(dim=1, keepdim=True)
    t1 = X1.mean(dim=1, keepdim=True)
    X0c = X0 - t0
    X1c = X1 - t1

    # rotation (use double for SVD, float loses precision)
    cov = (X0c.transpose(1, 2) @ X1c).double()  # B, D, D
    U, _, Vh = torch.linalg.svd(cov, full_matrices=False)

    # Kabsch correction: where U @ Vh is a reflection, flip the direction
    # belonging to the smallest singular value so that det(R) = +1.
    flip = torch.ones(U.shape[0], U.shape[-1], dtype=U.dtype, device=U.device)
    flip[:, -1] = torch.where(torch.det(U @ Vh) < 0, -1.0, 1.0).to(U.dtype)
    R = ((U * flip.unsqueeze(1)) @ Vh).float()

    # align X1 to X0: X1to0 = (X1 - t1) @ R^T + t0
    return edict(t0=t0[:, 0], t1=t1[:, 0], R=R)


def bilinear_sampler(input, coords, align_corners=True, padding_mode="border", interp_mode="bilinear"):
    r"""Sample a tensor using bilinear interpolation

    `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
    coordinates :attr:`coords` using bilinear interpolation. It is the same
    as `torch.nn.functional.grid_sample()` but with a different coordinate
    convention.

    The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
    :math:`B` is the batch size, :math:`C` is the number of channels,
    :math:`H` is the height of the image, and :math:`W` is the width of the
    image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
    interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.

    Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
    in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
    that in this case the order of the components is slightly different
    from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.

    If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
    in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
    left-most image pixel and :math:`W-1` to the center of the right-most
    pixel.

    If `align_corners` is `False`, the coordinate :math:`x` is assumed to
    be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
    the left-most pixel and :math:`W` to the right edge of the right-most
    pixel.

    Similar conventions apply to the :math:`y` for the range
    :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
    :math:`[0,T-1]` and :math:`[0,T]`.

    Args:
        input (Tensor): batch of input images.
        coords (Tensor): batch of coordinates.
        align_corners (bool, optional): Coordinate convention. Defaults to `True`.
        padding_mode (str, optional): Padding mode. Defaults to `"border"`.
        interp_mode (str, optional): Interpolation mode forwarded to
            `grid_sample()` as `mode`. Defaults to `"bilinear"`.

    Returns:
        Tensor: sampled points.
    """
    sizes = input.shape[2:]

    assert len(sizes) in [2, 3]

    if len(sizes) == 3:
        # t x y -> x y t to match dimensions T H W in grid_sample
        coords = coords[..., [1, 2, 0]]

    # Per-axis factors mapping pixel coordinates onto grid_sample's [0, 2]
    # span, listed in (x, y[, t]) order, i.e. the reverse of `sizes`.
    if align_corners:
        scale = [2 / max(size - 1, 1) for size in reversed(sizes)]
    else:
        scale = [2 / size for size in reversed(sizes)]

    # The multiplication allocates a new tensor, so the caller's coords are
    # never modified.
    coords = coords * torch.tensor(scale, device=coords.device) - 1

    return F.grid_sample(
        input,
        coords,
        align_corners=align_corners,
        padding_mode=padding_mode,
        mode=interp_mode,
    )


def sample_features4d(input, coords, interp_mode="bilinear"):
    r"""Sample spatial features

    `sample_features4d(input, coords)` samples the spatial features
    :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.

    The field is sampled at coordinates :attr:`coords` using bilinear
    interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
    2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
    same convention as :func:`bilinear_sampler` with `align_corners=True`.

    The output tensor has one feature per point, and has shape :math:`(B,
    R, C)`.

    Args:
        input (Tensor): spatial features.
        coords (Tensor): points.
        interp_mode (str, optional): Interpolation mode forwarded to
            :func:`bilinear_sampler`. Defaults to `"bilinear"`.

    Returns:
        Tensor: sampled features.
    """
    B, _, _, _ = input.shape

    # B R 2 -> B R 1 2
    coords = coords.unsqueeze(2)

    # B C R 1
    feats = bilinear_sampler(input, coords, interp_mode=interp_mode)

    return feats.permute(0, 2, 1, 3).view(
        B, -1, feats.shape[1] * feats.shape[3]
    )  # B C R 1 -> B R C


def sample_features5d(input, coords, interp_mode="bilinear"):
    r"""Sample spatio-temporal features

    `sample_features5d(input, coords)` works in the same way as
    :func:`sample_features4d` but for spatio-temporal features and points:
    :attr:`input` is a 5D tensor :math:`(B, T, C, H, W)`, :attr:`coords` is
    a :math:`(B, R1, R2, 3)` tensor of spatio-temporal point :math:`(t_i,
    x_i, y_i)`. The output tensor has shape :math:`(B, R1, R2, C)`.

    Args:
        input (Tensor): spatio-temporal features.
        coords (Tensor): spatio-temporal points.
        interp_mode (str, optional): Interpolation mode forwarded to
            :func:`bilinear_sampler`. Defaults to `"bilinear"`.

    Returns:
        Tensor: sampled features.
    """
    B, _, _, _, _ = input.shape

    # B T C H W -> B C T H W
    input = input.permute(0, 2, 1, 3, 4)

    # B R1 R2 3 -> B R1 R2 1 3
    coords = coords.unsqueeze(3)

    # B C R1 R2 1
    feats = bilinear_sampler(input, coords, interp_mode=interp_mode)

    return feats.permute(0, 2, 3, 1, 4).view(
        B, feats.shape[2], feats.shape[3], feats.shape[1]
    )  # B C R1 R2 1 -> B R1 R2 C


def vis_PCA(fmaps, save_dir):
    """
    visualize the PCA of the feature maps
    args:
        fmaps: feature maps 1 C H W
        save_dir: passed straight to ``plt.imsave`` as the output target
    raises:
        ImportError: if scikit-learn or matplotlib is not installed
    """
    if PCA is None or plt is None:
        raise ImportError("vis_PCA requires scikit-learn and matplotlib")

    fmap_vis = fmaps[0, ...]
    H_vis, W_vis = fmap_vis.shape[1:]

    # Min-max normalise; a constant map would otherwise divide by zero, so
    # only a zero span is replaced by EPS (other inputs are unaffected).
    span = fmap_vis.max() - fmap_vis.min()
    span = span + EPS * (span == 0)
    fmap_vnorm = (fmap_vis - fmap_vis.min()) / span

    # C H W -> (H*W) C: one C-dimensional sample per pixel
    fmap_vnorm = fmap_vnorm.reshape(fmap_vnorm.shape[0], -1).permute(1, 0)

    projector = PCA(n_components=3)
    fmap_pca = projector.fit_transform(fmap_vnorm.detach().cpu().numpy())

    # The three principal components are shown as the RGB channels.
    rgb = fmap_pca.reshape(H_vis, W_vis, 3)
    rgb_span = rgb.max() - rgb.min()
    if rgb_span == 0:
        rgb_span = EPS
    plt.imsave(save_dir, (rgb - rgb.min()) / rgb_span)