"""SLIC dataset
 - Returns an image together with its SLIC segmentation map.
"""
from glob import glob

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data as data
import torchvision.transforms as transforms
from PIL import Image
from skimage.color import rgb2lab
from skimage.segmentation import slic

from .utils import label2one_hot_torch

class RandomResizedCrop(object):
    """Per-sample random square crop followed by a bilinear resize.

    The crop scale and crop position for every dataset index are drawn once at
    construction time. The same index therefore always receives the same crop,
    which lets one geometric transform be re-applied to several tensors that
    belong to the same sample.

    Args:
        N: number of dataset samples to pre-draw crop parameters for.
        res: target output resolution (side length) after resizing.
        scale: ``(low, high)`` range of the crop side, as a fraction of the
            shorter spatial side of the input.
    """

    def __init__(self, N, res, scale=(0.5, 1.0)):
        self.res = res
        self.scale = scale
        self.rscale = [np.random.uniform(*scale) for _ in range(N)]
        # (row_offset_fraction, col_offset_fraction) per sample, both in [0, 1).
        self.rcrop = [(np.random.uniform(0, 1), np.random.uniform(0, 1)) for _ in range(N)]

    def random_crop(self, idx, img):
        """Crop ``img`` with the pre-drawn parameters of sample ``idx``.

        ``img`` is sliced as ``img[:, :, rows, cols]``, so it needs four
        dimensions with height and width last. The crop is square, with side
        ``int(rscale[idx] * min(H, W))``. Its top-left corner sits at the
        pre-drawn fraction of the free space along each axis, so the crop always
        lies inside the image, including for non-square inputs.
        """
        off_h, off_w = self.rcrop[idx]
        height, width = int(img.size(-2)), int(img.size(-1))
        # Use the shorter side so the square crop fits along both axes.
        # For square inputs this equals the previous width-based size.
        side = int(self.rscale[idx] * min(height, width))
        top = int(round((height - side) * off_h))
        left = int(round((width - side) * off_w))

        return img[:, :, top:top + side, left:left + side]

    def __call__(self, indice, image):
        """Crop and resize each batch item with its own pre-drawn parameters.

        Args:
            indice: dataset indices of the batch items, one per entry along
                dim 0 of ``image``.
            image: 4-D tensor, presumably (batch, channels, H, W).

        Returns:
            Tensor of shape (batch, channels, res_tar, res_tar).
        """
        # Heuristic kept from the original code: tensors with more than 5
        # channels (e.g. a one-hot map) are resized to a quarter of ``res``.
        # All others are resized to ``res``.
        res_tar = self.res // 4 if image.size(1) > 5 else self.res

        crops = [
            F.interpolate(self.random_crop(idx, image[[i]]), res_tar,
                          mode='bilinear', align_corners=False)
            for i, idx in enumerate(indice)
        ]
        return torch.cat(crops)

class RandomVerticalFlip(object):
    """Per-sample random vertical flip, with the decision fixed per dataset index.

    One uniform value in [0, 1) is drawn per dataset sample at construction.
    Sample ``idx`` is flipped whenever its value is below ``p``, so repeated
    calls with the same index make the same decision.
    """

    def __init__(self, N, p=0.5):
        self.p_ref = p
        self.plist = np.random.random_sample(N)

    def __call__(self, indice, image):
        """Vertically flip the batch items whose pre-drawn value is below ``p``.

        Args:
            indice: dataset indices of the batch items (one per entry along
                dim 0 of ``image``); used to index a numpy array.
            image: batched tensor. For 3-D input dim 1 is flipped, otherwise dim 2.

        Returns:
            A new tensor with the same shape as ``image``; ``image`` itself is
            not modified. An empty batch is returned as an empty copy.
        """
        flip_mask = torch.as_tensor(self.plist[indice] < self.p_ref,
                                    device=image.device)
        flip_dim = 1 if image.dim() == 3 else 2

        # Copy once, then overwrite only the selected items. This is linear in
        # the batch size, unlike a per-item membership search.
        flipped = image.clone()
        flipped[flip_mask] = image[flip_mask].flip([flip_dim])
        return flipped

class RandomHorizontalTensorFlip(object):
    """Per-sample random horizontal flip, with the decision fixed per dataset index.

    One uniform value in [0, 1) is drawn per dataset sample at construction.
    Sample ``idx`` is flipped whenever its value is below ``p``, so repeated
    calls with the same index make the same decision.
    """

    def __init__(self, N, p=0.5):
        self.p_ref = p
        self.plist = np.random.random_sample(N)

    def __call__(self, indice, image, is_label=False):
        """Horizontally flip the batch items whose pre-drawn value is below ``p``.

        Args:
            indice: dataset indices of the batch items (one per entry along
                dim 0 of ``image``); used to index a numpy array.
            image: batched tensor. For 3-D input dim 2 is flipped, otherwise dim 3.
            is_label: accepted for API compatibility; not used.

        Returns:
            A new tensor with the same shape as ``image``; ``image`` itself is
            not modified. An empty batch is returned as an empty copy.
        """
        flip_mask = torch.as_tensor(self.plist[indice] < self.p_ref,
                                    device=image.device)
        flip_dim = 2 if image.dim() == 3 else 3

        # Copy once, then overwrite only the selected items. This is linear in
        # the batch size, unlike a per-item membership search.
        flipped = image.clone()
        flipped[flip_mask] = image[flip_mask].flip([flip_dim])
        return flipped

class Dataset(data.Dataset):
    """Image dataset that returns each image together with its SLIC superpixels.

    Every ``*.jpg`` file found recursively under ``data_dir`` is one sample.
    ``__getitem__`` returns a list containing, in this order:

    * the CIE-Lab image as a float tensor (3, H, W), only if ``lab`` is True;
    * the RGB image as a float tensor (3, H, W), scaled to [0, 1];
    * the SLIC one-hot map from ``label2one_hot_torch``, only if ``slic`` is True;
    * the integer sample index.

    Args:
        data_dir: root directory searched recursively for ``*.jpg`` files.
        img_size: size passed to ``transforms.Resize``, and the target
            resolution of the random resized crop used by ``transform_eqv``.
        crop_size: size of the crop taken after resizing.
        test: if True, use resize + center crop. Otherwise use one random
            color jitter, random flips, resize and a random crop.
        sp_num: requested number of SLIC superpixels. Also the number of
            one-hot classes; labels >= ``sp_num`` are clamped to ``sp_num - 1``.
        slic: whether to compute and return the superpixel map (stored as
            ``self.slic``; ``__getitem__`` still calls the skimage function).
        lab: whether to also return the image converted to CIE-Lab.
    """

    def __init__(self, data_dir, img_size=256, crop_size=128, test=False,
                 sp_num=256, slic=True, lab=False):
        super(Dataset, self).__init__()
        extensions = ["*.jpg"]
        paths = []
        for ext in extensions:
            paths.extend(glob(data_dir + '/**/' + ext, recursive=True))
        # glob order depends on the filesystem. Sorting keeps the index -> file
        # mapping, and thus the per-index augmentation parameters drawn below,
        # stable across machines and runs.
        self.data_list = sorted(paths)
        self.sp_num = sp_num
        self.slic = slic
        self.lab = lab
        if test:
            self.transform = transforms.Compose([
                             transforms.Resize(img_size),
                             transforms.CenterCrop(crop_size)])
        else:
            self.transform = transforms.Compose([
                             transforms.RandomChoice([
                                transforms.ColorJitter(brightness=0.05),
                                transforms.ColorJitter(contrast=0.05),
                                transforms.ColorJitter(saturation=0.01),
                                transforms.ColorJitter(hue=0.01)]),
                             transforms.RandomHorizontalFlip(),
                             transforms.RandomVerticalFlip(),
                             transforms.Resize(int(img_size)),
                             transforms.RandomCrop(crop_size)])

        # Batch-level augmentations whose random parameters are fixed per
        # dataset index (see transform_eqv).
        N = len(self.data_list)
        self.random_horizontal_flip = RandomHorizontalTensorFlip(N=N)
        self.random_vertical_flip = RandomVerticalFlip(N=N)
        self.random_resized_crop = RandomResizedCrop(N=N, res=img_size)
        self.eqv_list = ['random_crop', 'h_flip']

    def transform_eqv(self, indice, image):
        """Apply the augmentations enabled in ``self.eqv_list`` to a batch.

        The order is fixed: random resized crop, then horizontal flip, then
        vertical flip. Each transform draws its parameters from ``indice``
        (dataset indices of the batch items).
        """
        if 'random_crop' in self.eqv_list:
            image = self.random_resized_crop(indice, image)
        if 'h_flip' in self.eqv_list:
            image = self.random_horizontal_flip(indice, image)
        if 'v_flip' in self.eqv_list:
            image = self.random_vertical_flip(indice, image)

        return image

    def _slic_one_hot(self, rgb):
        """Run SLIC on an (H, W, 3) uint8 image and one-hot encode the labels."""
        labels = slic(rgb, n_segments=self.sp_num, compactness=10, start_label=0,
                      min_size_factor=0.3)
        labels = torch.from_numpy(labels)
        # SLIC can return more segments than requested. Fold the excess into
        # the last class so every label is < C passed to label2one_hot_torch.
        labels[labels >= self.sp_num] = self.sp_num - 1
        return label2one_hot_torch(labels.unsqueeze(0).unsqueeze(0), C=self.sp_num).squeeze()

    def __getitem__(self, index):
        data_path = self.data_list[index]
        # Convert to RGB before anything else. Grayscale, palette, RGBA and
        # CMYK files all become 3-channel RGB, so SLIC, Lab and the RGB tensor
        # see the same pixels. The context manager closes the file handle.
        with Image.open(data_path) as raw:
            rgb = raw.convert('RGB')
        ori_img = np.array(self.transform(rgb))  # (H, W, 3) uint8

        rets = []
        if self.lab:
            lab_img = rgb2lab(ori_img)
            rets.append(torch.from_numpy(lab_img).float().permute(2, 0, 1))

        rets.append(torch.from_numpy(ori_img).float().permute(2, 0, 1) / 255.0)

        if self.slic:
            rets.append(self._slic_one_hot(ori_img))

        rets.append(index)

        return rets

    def __len__(self):
        return len(self.data_list)

if __name__ == '__main__':
    # Quick visual smoke test: load one random sample, then save the image and
    # its SLIC superpixel map to the working directory.
    import torchvision.utils as vutils

    dataset = Dataset('/home/xtli/DATA/texture_data/')
    loader = torch.utils.data.DataLoader(dataset=dataset,
                                         batch_size=1,
                                         shuffle=True,
                                         num_workers=1,
                                         drop_last=True)
    # With the defaults (slic=True, lab=False) a sample is
    # [rgb_image, slic_one_hot, index].
    img, slic_oh, _ = next(iter(loader))

    # Collapse the one-hot channels into a label map. This assumes the class
    # axis is dim 1 after batching; confirm against label2one_hot_torch.
    # Labels are then scaled to [0, 1] so save_image writes them as grey levels.
    seg = slic_oh.argmax(dim=1, keepdim=True).float()
    seg = seg / max(seg.max().item(), 1.0)

    vutils.save_image(img, 'img.png')
    vutils.save_image(seg, 'slic.png')