import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF

from .taming_blocks import Encoder
from .nnutils import SPADEResnetBlock, get_edges, initWave

from libs.nnutils import poolfeat, upfeat
from libs.utils import label2one_hot_torch

from swapae.models.networks.stylegan2_layers import ConvLayer
from torch_geometric.nn import GCNConv
from torch_geometric.utils import softmax
from .loss import styleLossMaskv3

class GCN(nn.Module):
    """Two-layer GCN over a superpixel graph, with optional soft clustering.

    Node codes are refined by two ``GCNConv`` layers.  Before each layer, the
    edge weights are recomputed from the cosine similarity of the current node
    features.  The result is then passed through ``upfeat`` together with the
    superpixel map.  With ``clustering=True`` pixels are additionally
    soft-assigned to ``n_cluster`` groups by ``pool1``, and each pixel receives
    the (soft) combination of its groups' pooled features.

    Args:
        n_cluster: number of groups predicted by ``pool1``.
        temperature: multiplier applied to the group logits before softmax
            (larger values give sharper assignments).
        add_self_loops: forwarded to both ``GCNConv`` layers.
        hidden_dim: node feature width (input and output of both layers).
    """

    def __init__(self, n_cluster, temperature = 1, add_self_loops = True, hidden_dim = 256):
        super().__init__()
        self.gcnconv1 = GCNConv(hidden_dim, hidden_dim, add_self_loops = add_self_loops)
        self.gcnconv2 = GCNConv(hidden_dim, hidden_dim, add_self_loops = add_self_loops)
        self.pool1 = nn.Sequential(nn.Conv2d(hidden_dim, n_cluster, 3, 1, 1))
        self.temperature = temperature

    def compute_edge_score_softmax(self, raw_edge_score, edge_index, num_nodes):
        """Softmax-normalize edge scores over the edges sharing a target node (``edge_index[1]``)."""
        return softmax(raw_edge_score, edge_index[1], num_nodes=num_nodes)

    def compute_edge_weight(self, node_feature, edge_index):
        """Score every edge by the cosine similarity of its endpoint features.

        Args:
            node_feature: (num_nodes, C) node features.
            edge_index: (2, num_edges) source / target node indices.

        Returns:
            ``(raw_edge_weight, edge_weight)``, both 1-D of length num_edges:
            the raw cosine similarities and their per-target-node softmax.
        """
        src_feat = node_feature[edge_index[0]]
        tgt_feat = node_feature[edge_index[1]]
        raw_edge_weight = F.cosine_similarity(src_feat, tgt_feat, dim=1, eps=1e-6)
        edge_weight = self.compute_edge_score_softmax(raw_edge_weight, edge_index, node_feature.shape[0])
        # reshape(-1) rather than squeeze(): squeeze() collapsed a single-edge
        # graph to a 0-d tensor, which is not a valid per-edge weight vector.
        return raw_edge_weight.reshape(-1), edge_weight.reshape(-1)

    def forward(self, sp_code, slic, clustering = False):
        """Propagate superpixel codes over the graph and map them back with ``upfeat``.

        Args:
            sp_code: batched node features; ``sp_code[i]`` is (num_nodes, C).
                ``sp_code.shape[1]`` is passed to ``get_edges`` (presumably the
                number of superpixels -- TODO confirm).
            slic: superpixel assignment map whose argmax over dim 1 gives each
                pixel's hard superpixel label; presumably (B, num_nodes, H, W).
            clustering: if True, soft-group pixels with ``pool1`` and replace
                each pixel's feature by its groups' pooled feature.

        Returns:
            ``(prop_code, sp_assign, conv_feats)``, each concatenated over the
            batch.  ``prop_code`` holds the propagated (optionally clustered)
            features.  ``sp_assign`` holds the assignment used: ``slic`` itself,
            or the superpixel-pooled group logits when clustering.
            ``conv_feats`` holds ``upfeat`` of the GCN output before any
            clustering.
        """
        # get_edges also returns a second value that is not needed here.
        edges, _ = get_edges(torch.argmax(slic, dim = 1).unsqueeze(1), sp_code.shape[1])
        prop_code = []
        sp_assign = []
        conv_feats = []
        for i in range(sp_code.shape[0]):
            node_feat = sp_code[i]
            sp_map = slic[i:i+1]
            edge_index = edges[i]

            # Layer 1: edge weights from the similarity of the input codes.
            _, edge_weight = self.compute_edge_weight(node_feat, edge_index)
            feat = self.gcnconv1(node_feat, edge_index, edge_weight = edge_weight)
            # Layer 2: edge weights recomputed from layer-1 features (pre-activation).
            _, edge_weight = self.compute_edge_weight(feat, edge_index)
            feat = F.leaky_relu(feat, 0.2)
            feat = self.gcnconv2(feat, edge_index, edge_weight = edge_weight)

            conv_feat = upfeat(feat, sp_map)
            conv_feats.append(conv_feat)
            if not clustering:
                feat = conv_feat
                pred_mask = sp_map
            else:
                pred_mask = self.pool1(conv_feat)
                # enforce pixels belonging to the same superpixel to share a grouping label
                pred_mask = upfeat(poolfeat(pred_mask, sp_map), sp_map)
                s_ = F.softmax(pred_mask * self.temperature, dim = 1)
                # texture code per group, mapped back through the soft assignment
                # (a hard one-hot assignment would be the alternative here)
                pool_feat = poolfeat(conv_feat, s_, avg = True)
                feat = upfeat(pool_feat, s_)

            prop_code.append(feat)
            sp_assign.append(pred_mask)
        return torch.cat(prop_code), torch.cat(sp_assign), torch.cat(conv_feats)

class SPADEGenerator(nn.Module):
    """SPADE decoder mapping an input code map to a 3-channel image.

    A head block is followed by six SPADE residual blocks, all conditioned on
    ``texon``.  Three of those stages are preceded by a 2x nearest-neighbour
    upsample, so the output is 8x the input's spatial size.  Channel width
    shrinks from ``hidden_dim`` down to ``hidden_dim // 16``.
    """

    def __init__(self, in_dim, hidden_dim):
        super().__init__()
        base = hidden_dim // 16  # channel width of the last stage

        # Keep this creation order: it determines parameter order and the
        # sequence in which weights are initialised.
        self.head_0 = SPADEResnetBlock(in_dim, 16 * base)

        self.G_middle_0 = SPADEResnetBlock(16 * base, 16 * base)
        self.G_middle_1 = SPADEResnetBlock(16 * base, 16 * base)

        self.up_0 = SPADEResnetBlock(16 * base, 8 * base)
        self.up_1 = SPADEResnetBlock(8 * base, 4 * base)
        self.up_2 = SPADEResnetBlock(4 * base, 2 * base)
        self.up_3 = SPADEResnetBlock(2 * base, base)

        self.conv_img = nn.Conv2d(base, 3, 3, padding=1)

        self.up = nn.Upsample(scale_factor=2)

    def forward(self, sine_wave, texon):
        """Decode ``sine_wave`` into an image, conditioning every block on ``texon``."""
        out = self.head_0(sine_wave, texon)

        # (upsample before this block?, block) in execution order
        schedule = (
            (True, self.G_middle_0),
            (False, self.G_middle_1),
            (True, self.up_0),
            (True, self.up_1),
            (False, self.up_2),
            (False, self.up_3),
        )
        for upsample_first, block in schedule:
            if upsample_first:
                out = self.up(out)
            out = block(out, texon)

        return self.conv_img(F.leaky_relu(out, 0.2))

class Waver(nn.Module):
    """Predict per-pixel wave numbers for the sine-wave spatial code.

    The output is the fixed initial wave numbers from ``initWave(zPeriodic)``
    plus a learned correction from a small 1x1-conv network.  The output has
    ``2 * zPeriodic`` channels.

    Args:
        tex_code_dim: channel count of the input feature map (also used as the
            hidden width of the 1x1-conv network).
        zPeriodic: number of periodic components; output has ``2 * zPeriodic``
            channels.
    """

    def __init__(self, tex_code_dim, zPeriodic):
        super(Waver, self).__init__()
        K = tex_code_dim
        layers =  [nn.Conv2d(tex_code_dim, K, 1)]
        layers += [nn.ReLU(True)]
        layers += [nn.Conv2d(K, 2 * zPeriodic, 1)]
        self.learnedWN =  nn.Sequential(*layers)
        # NOTE(review): stored as a plain attribute (assuming initWave returns a
        # plain tensor).  It then does not follow .to()/.cuda() on the module
        # and is not in state_dict, hence the explicit device move in forward.
        self.waveNumbers = initWave(zPeriodic)

    def forward(self, GLZ=None):
        """Return ``waveNumbers + learnedWN(GLZ)`` on ``GLZ``'s device.

        Args:
            GLZ: input feature map with ``tex_code_dim`` channels.  Required;
                the ``None`` default is kept only for signature compatibility.

        Raises:
            ValueError: if ``GLZ`` is None (previously an opaque AttributeError).
        """
        if GLZ is None:
            raise ValueError("Waver.forward requires an input feature map GLZ")
        return self.waveNumbers.to(GLZ.device) + self.learnedWN(GLZ)

class AE(nn.Module):
    """Model container built from ``args``.

    It holds the encoder (``enc``), the SPADE decoder (``G``), the ``ToTexCode``
    conv stack, the superpixel ``GCN``, the learned sine-wave generator
    (``learnedWN``), a style loss, and optionally the ``ChannelWeight`` stack.

    NOTE(review): in this file ``forward`` is a stub that returns ``None``; the
    real forward pass presumably lives elsewhere -- TODO confirm.
    """

    def __init__(self, args, **ignore_kwargs):
        super(AE, self).__init__()

        # In 'sine_wave_noise' mode the decoder input and the ChannelWeight
        # output carry twice the sine-wave channel count.
        if args.dec_input_mode == 'sine_wave_noise':
            dec_in_dim = args.spatial_code_dim * 2
        else:
            dec_in_dim = args.spatial_code_dim

        # encoder & decoder (module creation order below is deliberate: it
        # fixes parameter order and weight-initialisation RNG consumption)
        self.enc = Encoder(ch=64, out_ch=3, ch_mult=[1,2,4,8], num_res_blocks=1, attn_resolutions=[],
                           in_channels=3, resolution=args.crop_size, z_channels=args.hidden_dim, double_z=False)
        self.G = SPADEGenerator(dec_in_dim, args.hidden_dim)

        self.add_module(
            "ToTexCode",
            nn.Sequential(
                ConvLayer(args.hidden_dim, args.hidden_dim, kernel_size=3, activate=True, bias=True),
                ConvLayer(args.hidden_dim, args.tex_code_dim, kernel_size=3, activate=True, bias=True),
                ConvLayer(args.tex_code_dim, args.hidden_dim, kernel_size=1, activate=False, bias=False)
            )
        )
        self.gcn = GCN(n_cluster = args.n_cluster, temperature = args.temperature, add_self_loops = (args.add_self_loops == 1), hidden_dim = args.hidden_dim)

        self.add_gcn_epoch = args.add_gcn_epoch
        self.add_clustering_epoch = args.add_clustering_epoch
        self.add_texture_epoch = args.add_texture_epoch

        self.patch_size = args.patch_size
        self.sine_wave_dim = args.spatial_code_dim

        # inpainting network
        self.learnedWN = Waver(args.hidden_dim, zPeriodic = args.spatial_code_dim)
        self.dec_input_mode = args.dec_input_mode
        self.style_loss = styleLossMaskv3(device = args.device)

        if args.sine_weight:
            hd = args.hidden_dim
            self.add_module(
                "ChannelWeight",
                nn.Sequential(
                    ConvLayer(hd, hd//2, kernel_size=3, activate=True, bias=True, downsample=True),
                    ConvLayer(hd//2, hd//4, kernel_size=3, activate=True, bias=True, downsample=True),
                    ConvLayer(hd//4, dec_in_dim, kernel_size=1, activate=False, bias=False, downsample=True)))

    def get_sine_wave(self, GL, offset_mode = 'random'):
        """Build a sinusoidal spatial code at 1/8 of ``GL``'s width.

        ``GL`` is resized (nearest) to an S x S map with ``S = GL.shape[-1] // 8``.
        ``self.learnedWN`` maps it to ``2 * sine_wave_dim`` wave-number
        channels.  Each consecutive pair of channels weights the (row, column)
        pixel coordinates, and the pair's sum is passed through ``sin``.

        Args:
            GL: input feature map for ``learnedWN``; presumably
                (B, hidden_dim, H, W) -- TODO confirm.
            offset_mode: ``'random'`` adds a per-sample, per-channel phase drawn
                uniformly from [-6.28, 6.28); ``'rec'`` adds no phase offset.

        Returns:
            Tensor of shape (B, sine_wave_dim, S, S).

        Raises:
            ValueError: for any other ``offset_mode`` (previously this hit an
                UnboundLocalError on the result).
        """
        if offset_mode not in ('random', 'rec'):
            raise ValueError(f"unknown offset_mode {offset_mode!r}; expected 'random' or 'rec'")
        img_size = GL.shape[-1] // 8
        GL = F.interpolate(GL, size = (img_size, img_size), mode = 'nearest')

        # Coordinate grid built directly on GL's device: channel 0 = row index,
        # channel 1 = column index ('ij' convention), shape (1, 2, S, S).
        axis = torch.arange(img_size, dtype=torch.float32, device=GL.device)
        rows = axis.view(-1, 1).expand(img_size, img_size)
        cols = axis.view(1, -1).expand(img_size, img_size)
        c = torch.stack([rows, cols]).unsqueeze(0)
        # Tile to (B, 2 * sine_wave_dim, S, S): alternating (row, col) channels.
        c = c.repeat(GL.shape[0], self.sine_wave_dim, 1, 1)

        period = self.learnedWN(GL)  # (B, 2 * sine_wave_dim, S, S)
        raw = period * c
        phase = raw[:, ::2] + raw[:, 1::2]
        if offset_mode == 'random':
            # 6.28 ~ 2*pi; drawn on GL's device, one value per (sample, channel).
            offset = torch.zeros((GL.shape[0], self.sine_wave_dim, 1, 1), device=GL.device).uniform_(-1, 1) * 6.28
            phase = phase + offset  # broadcasts over the spatial dims
        return torch.sin(phase)

    def forward(self, rgb_img, slic, epoch = 0, test_time = False, test = False, tex_idx = None):
        # NOTE(review): stub -- returns None; the real forward is presumably
        # defined elsewhere (TODO confirm).
        return