import os
import sys
import cv2
import numpy as np
import torch
from PIL import Image
from .utils import gen_new_name, prompts
import torch
from omegaconf import OmegaConf
import numpy as np
import wget
from .inpainting_src.ldm_inpainting.ldm.models.diffusion.ddim import DDIMSampler
from .inpainting_src.ldm_inpainting.ldm.util import instantiate_from_config
from .utils import cal_dilate_factor, dilate_mask


def make_batch(image, mask, device):
    """Pack an image and its mask into the tensor dict used by the LDM inpainter.

    Args:
        image: HxWxC array with values expected in [0, 255].
        mask: 2-D (HxW) array with values expected in [0, 255]. After division
            by 255, pixels >= 0.5 become 1 (region to inpaint) and pixels
            < 0.5 become 0.
        device: torch device every returned tensor is moved to.

    Returns:
        dict with keys ``"image"`` (1xCxHxW), ``"mask"`` (1x1xHxW) and
        ``"masked_image"`` (the image with masked pixels zeroed). All three are
        float32 and are linearly rescaled from [0, 1] to [-1, 1] after the
        masking, so zeroed pixels end up at -1.
    """
    # HWC [0, 255] -> 1xCxHxW float32 in [0, 1]
    pixels = torch.from_numpy(
        (image.astype(np.float32) / 255.0)[np.newaxis].transpose(0, 3, 1, 2)
    )

    # HxW [0, 255] -> 1x1xHxW float32, binarised at 0.5
    region = (mask.astype(np.float32) / 255.0)[np.newaxis, np.newaxis]
    is_low, is_high = region < 0.5, region >= 0.5
    region[is_low] = 0
    region[is_high] = 1
    region = torch.from_numpy(region)

    parts = {
        "image": pixels,
        "mask": region,
        "masked_image": (1 - region) * pixels,
    }
    # Move each tensor to the target device, then rescale [0, 1] -> [-1, 1].
    return {key: tensor.to(device=device) * 2.0 - 1.0 for key, tensor in parts.items()}


class LDMInpainting:
    """Object-removal tool backed by the latent-diffusion (LDM) inpainting model.

    The model is built from the bundled ``config.yaml``. Its weights are loaded
    from ``model_zoo/ldm_inpainting_big.ckpt``, which is downloaded on first
    use, and samples are drawn with a DDIM sampler.
    """

    def __init__(self, device):
        """Build the model, fetch its weights if missing, and move it to *device*.

        Args:
            device: torch device (or device string such as ``'cuda:0'``) that
                the model and all input tensors are placed on.
        """
        # Both paths are relative to the current working directory, so the
        # process must be started from the directory that contains
        # `model_zoo/` and `iGPT/`.
        self.model_checkpoint_path = 'model_zoo/ldm_inpainting_big.ckpt'
        config = './iGPT/models/inpainting_src/ldm_inpainting/config.yaml'
        self.ddim_steps = 50
        self.device = device
        config = OmegaConf.load(config)
        model = instantiate_from_config(config.model)
        self.download_parameters()
        # Load onto the CPU first so that a checkpoint saved from a GPU is not
        # materialised on that GPU before the model is moved to `device`.
        # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True;
        # confirm this checkpoint still loads under the installed torch.
        checkpoint = torch.load(self.model_checkpoint_path, map_location='cpu')
        # strict=False: keys present in only one of checkpoint/model are skipped.
        model.load_state_dict(checkpoint["state_dict"], strict=False)
        self.model = model.to(device=device)
        self.sampler = DDIMSampler(model)

    def download_parameters(self):
        """Download the inpainting checkpoint unless it already exists locally."""
        url = 'https://heibox.uni-heidelberg.de/f/4d9ac7ea40c64582b7c9/?dl=1'
        if os.path.exists(self.model_checkpoint_path):
            return
        # wget.download does not create missing parent directories, so make
        # sure the target directory (e.g. `model_zoo/`) exists first.
        ckpt_dir = os.path.dirname(self.model_checkpoint_path)
        if ckpt_dir:
            os.makedirs(ckpt_dir, exist_ok=True)
        wget.download(url, out=self.model_checkpoint_path)

    @prompts(name="Remove the Masked Object",
             description="useful when you want to remove an object by masking the region in the image. "
                         "like: remove masked object or inpaint the masked region.. "
                         "The input to this tool should be a comma separated string of two, "
                         "representing the image_path and mask_path")
    @torch.no_grad()
    def inference(self, inputs):
        """Inpaint the masked region of an image and save the result.

        Args:
            inputs: ``"<image_path>,<mask_path>"``. Whitespace around each path
                is stripped, and any further comma-separated fields are ignored.

        Returns:
            Path of the saved inpainted image. The name comes from
            ``gen_new_name``, and the image is resized back to the input
            image's original size.
        """
        print(f'inputs: {inputs}')
        fields = inputs.split(',')
        img_path, mask_path = fields[0].strip(), fields[1].strip()

        # Force 3-channel RGB: RGBA/grayscale/palette images would otherwise
        # break the HWC -> NCHW transpose in make_batch or hand the encoder the
        # wrong number of channels. The `with` blocks close the files promptly.
        with Image.open(img_path) as img_file:
            image = img_file.convert('RGB')
        with Image.open(mask_path) as mask_file:
            mask = mask_file.convert('L')
        w, h = image.size

        # The model runs at a fixed 512x512 resolution; the result is resized
        # back to (w, h) before saving.
        image = np.array(image.resize((512, 512)))
        mask = np.array(mask.resize((512, 512)))
        # Presumably enlarges the masked region (helpers live in .utils).
        dilate_factor = cal_dilate_factor(mask.astype(np.uint8))
        mask = dilate_mask(mask, dilate_factor)

        with self.model.ema_scope():
            batch = make_batch(image, mask, device=self.device)
            # Conditioning = encoded masked image concatenated with the mask
            # resized to the encoding's spatial size.
            c = self.model.cond_stage_model.encode(batch["masked_image"])
            cc = torch.nn.functional.interpolate(batch["mask"], size=c.shape[-2:])
            c = torch.cat((c, cc), dim=1)

            # Sample shape: the conditioning's channels minus the 1 mask channel.
            shape = (c.shape[1] - 1,) + c.shape[2:]
            samples_ddim, _ = self.sampler.sample(S=self.ddim_steps,
                                                  conditioning=c,
                                                  batch_size=c.shape[0],
                                                  shape=shape,
                                                  verbose=False)
            x_samples_ddim = self.model.decode_first_stage(samples_ddim)

            # Undo make_batch's [0, 1] -> [-1, 1] scaling (the decoder output
            # is mapped the same way) and clamp to [0, 1].
            image = torch.clamp((batch["image"] + 1.0) / 2.0, min=0.0, max=1.0)
            mask = torch.clamp((batch["mask"] + 1.0) / 2.0, min=0.0, max=1.0)
            predicted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0,
                                          min=0.0, max=1.0)

            # Keep original pixels outside the mask; use the model output inside.
            inpainted = (1 - mask) * image + mask * predicted_image
            inpainted = inpainted.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255

        new_img_name = gen_new_name(img_path, 'LDMInpainter')
        new_img = Image.fromarray(inpainted.astype(np.uint8)).resize((w, h))
        new_img.save(new_img_name)
        print(
            f"\nProcessed LDMInpainting, Inputs: {inputs}, "
            f"Output Image: {new_img_name}")
        return new_img_name

# Manual smoke test, kept disabled because it downloads/loads the full model
# onto 'cuda:0'. Because of the relative imports above, run this file as a
# module (`python -m ...`) after uncommenting:
#
# if __name__ == '__main__':
#     painting = LDMInpainting('cuda:0')
#     res = painting.inference('image/82e612_fe54ca_raw.png,image/04a785_fe54ca_mask.png')
#     print(res)