import torch
from torch import nn


def knn(x, k):
    """Find the k nearest neighbours of every point.

    Args:
        x: point features laid out as (batch, channels, num_points).
        k: number of neighbours to return per point.

    Returns:
        A pair ``(idx, pairwise_distance)``:
        - ``idx`` has shape (batch, num_points, k) and holds the indices of the
          k largest entries of each row of ``pairwise_distance``.
        - ``pairwise_distance`` has shape (batch, num_points, num_points) and
          holds the negative squared Euclidean distance between every pair of
          points.
    """
    gram = torch.matmul(x.transpose(2, 1), x)  # (batch, num_points, num_points)
    inner = -2 * gram
    sq_norms = torch.sum(x ** 2, dim=1, keepdim=True)  # (batch, 1, num_points)

    # -|a|^2 + 2<a, b> - |b|^2 == -|a - b|^2, so larger values mean closer points.
    neg_sq_dist = -sq_norms - inner - sq_norms.transpose(2, 1)

    _, nearest = neg_sq_dist.topk(k=k, dim=-1)  # (batch_size, num_points, k)
    return nearest, neg_sq_dist


def local_operator(x, k):
    """Build per-edge features from each point's k nearest neighbours.

    Args:
        x: tensor whose size(0) is the batch and size(2) the number of points;
            it is viewed as (batch, channels, num_points).
        k: number of neighbours, passed to ``knn``.

    Returns:
        Tensor of shape (batch, 2 * channels, num_points, k). For every point
        and neighbour it holds the relative offset ``neighbour - point``
        concatenated with the absolute neighbour features.
    """
    batch_size = x.size(0)
    num_points = x.size(2)
    x = x.view(batch_size, -1, num_points)
    idx, _ = knn(x, k=k)  # (batch_size, num_points, k)

    # Shift each batch's neighbour indices so they address rows of the
    # flattened (batch_size * num_points, channels) tensor below. The offsets
    # must live on x's device: a hard-coded CPU tensor cannot be added to
    # indices computed on a GPU.
    idx_base = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points
    idx = (idx + idx_base).view(-1)

    num_dims = x.size(1)
    x = x.transpose(2, 1).contiguous()  # (batch_size, num_points, num_dims)

    neighbor = x.view(batch_size * num_points, -1)[idx, :]
    neighbor = neighbor.view(batch_size, num_points, k, num_dims)

    # Broadcasting over the k axis gives the same values as repeating the
    # centre point k times, without materialising the copy.
    center = x.view(batch_size, num_points, 1, num_dims)

    # Local (relative) and global (absolute) information together.
    feature = torch.cat((neighbor - center, neighbor), dim=3).permute(0, 3, 1, 2)
    return feature


def local_operator_withnorm(x, norm_plt, k):
    """Build per-edge features that also carry the neighbours' ``norm_plt`` values.

    Neighbours are selected by running ``knn`` on ``x`` only. The same
    neighbour indices are then used to gather rows of ``norm_plt``.

    Args:
        x: tensor whose size(0) is the batch and size(2) the number of points;
            it is viewed as (batch, channels, num_points).
        norm_plt: per-point tensor with the same batch and point counts as
            ``x``; it is viewed as (batch, norm_channels, num_points).
            ``norm_channels`` may differ from ``channels``.
        k: number of neighbours, passed to ``knn``.

    Returns:
        Tensor of shape (batch, 2 * channels + norm_channels, num_points, k).
        Along the channel axis it concatenates ``neighbour - point``, the
        neighbour's ``x`` features, and the neighbour's ``norm_plt`` features.
    """
    batch_size = x.size(0)
    num_points = x.size(2)
    x = x.view(batch_size, -1, num_points)
    norm_plt = norm_plt.view(batch_size, -1, num_points)
    idx, _ = knn(x, k=k)  # (batch_size, num_points, k)

    # Shift each batch's neighbour indices into the flattened row space.
    # Build the offsets on x's device so GPU inputs work.
    idx_base = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points
    idx = (idx + idx_base).view(-1)

    num_dims = x.size(1)

    x = x.transpose(2, 1).contiguous()  # (batch_size, num_points, num_dims)
    norm_plt = norm_plt.transpose(2, 1).contiguous()  # (batch_size, num_points, norm_channels)

    neighbor = x.view(batch_size * num_points, -1)[idx, :]
    neighbor_norm = norm_plt.view(batch_size * num_points, -1)[idx, :]

    neighbor = neighbor.view(batch_size, num_points, k, num_dims)
    # Infer the last dim so norm_plt may have a different channel count than x.
    neighbor_norm = neighbor_norm.view(batch_size, num_points, k, -1)

    # Broadcast the centre point over the k neighbours instead of repeating it.
    center = x.view(batch_size, num_points, 1, num_dims)

    feature = torch.cat((neighbor - center, neighbor, neighbor_norm), dim=3).permute(0, 3, 1, 2)
    return feature


def GDM(x, M, k=64, tau=0.2, sigma=2):
    """
    Geometry-Disentangle Module.

    Splits each point cloud into a sharp-variation component and a
    gentle-variation component of M points each.

    Args:
        x: tensor whose size(0) is the batch and size(2) the number of points;
            it is viewed as (batch, channels, num_points).
        M: number of disentangled points in both the sharp and the gentle
            variation components (must not exceed the number of points).
        k: number of neighbours that decides the range of j in Eq.(5).
            Column 0 of the kNN result is dropped, so k - 1 neighbours are used.
        tau: threshold in Eq.(2). Pairs whose Euclidean distance is not below
            tau get zero edge weight.
        sigma: parameter of f, the Gaussian function in Eq.(2).

    Returns:
        (xs, xg), each of shape (batch, M, channels). They hold the M points
        with the largest and the smallest variation score from Eq.(5).
    """
    device = x.device
    batch_size = x.size(0)
    num_points = x.size(2)
    x = x.view(batch_size, -1, num_points)

    # ---- Graph construction ----
    idx, p = knn(x, k=k)  # idx: (b, n, k); p: -[(x1-x2)^2+...], (b, n, n)

    # Keep only edges whose Euclidean distance is below tau.
    mask = torch.sqrt(torch.abs(p)) < tau

    # Gaussian edge weights exp(-d^2 / sigma^2), zeroed outside the tau ball.
    w = torch.exp(p / (sigma * sigma))
    w = torch.mul(mask.float(), w)  # (b, n, n)

    # Normalized adjacency A_hat = D^-1 W with D diagonal. Scaling row i of w
    # by the inverse degree gives the same result as multiplying by a dense
    # diagonal matrix, without the O(n^3) matmul. Summing over dim=1 equals
    # the row sum because w is symmetric (up to floating-point rounding).
    deg_inv = 1 / torch.sum(w, dim=1)  # (b, n)
    A = deg_inv.unsqueeze(-1) * w  # (b, n, n)

    # A_ij restricted to each point's neighbours. Column 0 of idx is the
    # nearest entry (normally the point itself) and is skipped.
    nbr_idx = idx[:, :, 1:]  # (b, n, k-1)
    A = torch.gather(A, 2, nbr_idx)  # Aij: (b, n, k-1)

    # ---- Disentangling into sharp (xs) and gentle (xg) variation components ----
    # Shift neighbour indices into the flattened (b*n, c) row space.
    idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points
    flat_nbr_idx = (nbr_idx + idx_base).reshape(-1)

    num_dims = x.size(1)
    x = x.transpose(2, 1).contiguous()  # (b, n, c)
    x_flat = x.view(batch_size * num_points, -1)

    neighbor = x_flat[flat_nbr_idx, :].view(batch_size, num_points, k - 1, num_dims)  # (b, n, k-1, c)

    # Graph-filtered point: A-weighted sum of its neighbours, (b, n, c).
    filtered = torch.sum(A.unsqueeze(-1) * neighbor, dim=2)

    # Eq.(5): squared distance between each point and its filtered version.
    pai = torch.norm(x - filtered, dim=-1).pow(2)  # (b, n)
    pais = pai.topk(k=M, dim=-1)[1]  # first M points: sharp variation component
    paig = (-pai).topk(k=M, dim=-1)[1]  # last M points: gentle variation component

    pai_base = torch.arange(0, batch_size, device=device).view(-1, 1) * num_points
    xs = x_flat[(pais + pai_base).view(-1), :].view(batch_size, M, -1)  # (b, M, c)
    xg = x_flat[(paig + pai_base).view(-1), :].view(batch_size, M, -1)  # (b, M, c)

    return xs, xg


class SGCAM(nn.Module):
    """Sharp-Gentle Complementary Attention Module.

    Non-local style attention. Queries are projected from ``x``; keys and
    values are projected from ``x_2``. The attended features are mapped back
    to ``in_channels`` by ``self.W`` and added to ``x`` as a residual.

    ``self.W`` is zero-initialised: the BatchNorm affine parameters when
    ``bn_layer`` is True, otherwise the conv weight and bias. The residual
    branch therefore outputs zeros at construction time.
    """

    def __init__(self, in_channels, inter_channels=None, bn_layer=True):
        super().__init__()

        self.in_channels = in_channels
        self.inter_channels = inter_channels

        # Default bottleneck width is half the input width, but at least 1.
        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        conv_nd = nn.Conv1d
        bn = nn.BatchNorm1d

        # Value projection (applied to x_2).
        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                         kernel_size=1, stride=1, padding=0)

        # Output projection back to in_channels, zero-initialised.
        if bn_layer:
            self.W = nn.Sequential(
                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                        kernel_size=1, stride=1, padding=0),
                bn(self.in_channels)
            )
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                             kernel_size=1, stride=1, padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)

        # Query projection (applied to x).
        self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=1, stride=1, padding=0)

        # Key projection (applied to x_2).
        self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                           kernel_size=1, stride=1, padding=0)

    def forward(self, x, x_2):
        """Attend from ``x`` to ``x_2`` and add the result to ``x``.

        Args:
            x: query features, expected (batch, in_channels, N1) since the
                projections are Conv1d layers.
            x_2: key/value features, expected (batch, in_channels, N2).

        Returns:
            Tensor with the same shape as ``x``: ``x + W(attended)``.
        """
        batch_size = x.size(0)

        g_x = self.g(x_2).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)  # (b, N2, inter)

        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)  # (b, N1, inter)
        phi_x = self.phi(x_2).view(batch_size, self.inter_channels, -1)  # (b, inter, N2)

        attn = torch.matmul(theta_x, phi_x)  # attention matrix, (b, N1, N2)
        attn = attn / attn.size(-1)  # scale by the number of key points N2

        y = torch.matmul(attn, g_x)  # (b, N1, inter)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        return self.W(y) + x