File size: 1,407 Bytes
f96995c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math


def chamfer(x, y):
    """Symmetric Chamfer distance between two batches of point sets.

    Args:
        x: tensor of shape (B, N, D).
        y: tensor of shape (B, M, D).

    Returns:
        Tensor of shape (B,). For each batch element it is the mean Euclidean
        distance from every point of ``y`` to its nearest point in ``x``,
        plus the mean distance from every point of ``x`` to its nearest
        point in ``y``.
    """
    # Broadcast (B, 1, N, D) against (B, M, 1, D) instead of materialising two
    # repeated (B, M, N, D) copies first; the pairwise values are identical.
    dis = torch.norm(x[:, None] - y[:, :, None], 2, dim=-1)  # (B, M, N)
    # y -> x: nearest x point for each y point (min over N), averaged over M.
    dis_yx = torch.mean(dis.min(dim=2).values, dim=1)
    # x -> y: nearest y point for each x point (min over M), averaged over N.
    dis_xy = torch.mean(dis.min(dim=1).values, dim=1)
    return dis_yx + dis_xy


def batch_chamfer_dist(xyz, xyz_gt):
    """Symmetric Chamfer distance from each predicted cloud to a target cloud.

    Args:
        xyz: predicted points, shape (B, N, 3).
        xyz_gt: target points. Either shape (M, 3), shared by every batch
            element, or (B, M, 3) / (1, M, 3) for per-sample targets.

    Returns:
        Tensor of shape (B,). Each entry is the mean distance from each
        predicted point to its nearest target point, plus the mean distance
        from each target point to its nearest predicted point.
    """
    if xyz_gt.dim() == 2:
        xyz_gt = xyz_gt[None]  # (1, M, 3), broadcasts over the batch
    # Pairwise Euclidean distances, shape (B, N, M). torch.norm's backward is
    # defined as 0 at zero norm, whereas sqrt(sum(d ** 2)) yields NaN gradients
    # whenever a predicted point coincides exactly with a target point.
    dist = torch.norm(xyz[:, :, None] - xyz_gt[:, None], 2, dim=-1)
    # The two directions must reduce over different axes: dim=2 (M) for
    # pred -> gt and dim=1 (N) for gt -> pred.
    pred_to_gt = torch.mean(dist.min(dim=2).values, dim=1)  # mean over N
    gt_to_pred = torch.mean(dist.min(dim=1).values, dim=1)  # mean over M
    return pred_to_gt + gt_to_pred


def angle_normalize(x):
    """Wrap angle(s) ``x`` (radians) into [-pi, pi), up to float rounding.

    Accepts Python floats or tensors/arrays supporting ``%``. The modulo
    takes the sign of the divisor, so negative angles wrap correctly.
    """
    full_turn = 2 * math.pi
    shifted = x + math.pi
    return shifted % full_turn - math.pi


def clip_actions(action, action_lower_lim, action_upper_lim):
    """Return a copy of ``action`` clamped element-wise to the given limits.

    The limits may be scalars or tensors broadcastable to ``action``, and
    ``action`` itself is left unmodified. The clamp is applied through
    ``.data``, so autograd never records it. In the backward pass the result
    therefore acts as a plain clone of ``action``: gradients pass straight
    through, even for entries that were clipped.
    """
    clipped = action.clone()
    clipped.data.clamp_(min=action_lower_lim, max=action_upper_lim)
    return clipped