import os
import json
from contextlib import contextmanager

import torch
import numpy as np
from einops import rearrange

import torch.nn.functional as F
import torch.distributed as dist
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only

from taming.modules.vqvae.quantize import VectorQuantizer as VectorQuantizer

from core.modules.networks.ae_modules import Encoder, Decoder
from core.distributions import DiagonalGaussianDistribution
from utils.utils import instantiate_from_config
from utils.save_video import tensor2videogrids
from core.common import shape_to_str, gather_data


class AutoencoderKL(pl.LightningModule):
    def __init__(
        self,
        ddconfig,
        lossconfig,
        embed_dim,
        ckpt_path=None,
        ignore_keys=None,
        image_key="image",
        colorize_nlabels=None,
        monitor=None,
        test=False,
        logdir=None,
        input_dim=4,
        test_args=None,
    ):
        """Build the encoder/decoder pair, the loss and the latent projections.

        Args:
            ddconfig: Mapping expanded as keyword arguments into both
                ``Encoder`` and ``Decoder``. Must provide a truthy
                ``"double_z"`` entry and a ``"z_channels"`` entry.
            lossconfig: Config handed to ``instantiate_from_config`` to build
                ``self.loss``.
            embed_dim: Channel count consumed by ``post_quant_conv``;
                ``quant_conv`` emits ``2 * embed_dim`` channels.
            ckpt_path: Optional checkpoint path; when given, weights are
                restored through ``init_from_ckpt``.
            ignore_keys: Optional iterable of state-dict key prefixes to drop
                while restoring ``ckpt_path``. ``None`` means "drop nothing".
            image_key: Key used to fetch the input tensor from a batch.
            colorize_nlabels: When given, registers a random
                ``(3, colorize_nlabels, 1, 1)`` buffer named ``colorize``.
            monitor: Stored as ``self.monitor`` only when not ``None``.
            test: When truthy, ``init_test`` runs at the end of construction.
            logdir: Base directory used by ``init_test`` for test outputs.
            input_dim: Stored as ``self.input_dim``; ``get_input`` flattens 5-D
                inputs only when this equals 4.
            test_args: Options object read by the test hooks.
        """
        super().__init__()
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        assert ddconfig["double_z"], "ddconfig['double_z'] must be enabled"
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim
        self.input_dim = input_dim
        self.test = test
        self.test_args = test_args
        self.logdir = logdir
        # init_test reads _cur_epoch; give it the same placeholder that
        # init_from_ckpt uses so test mode also works without a checkpoint.
        self._cur_epoch = "null"
        if colorize_nlabels is not None:
            assert isinstance(colorize_nlabels, int), "colorize_nlabels must be an int"
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        if ckpt_path is not None:
            # Normalise here so the default is never a shared mutable list.
            prefixes = [] if ignore_keys is None else list(ignore_keys)
            self.init_from_ckpt(ckpt_path, ignore_keys=prefixes)
        if self.test:
            self.init_test()

    def init_test(
        self,
    ):
        """Prepare output directories and bookkeeping state for test runs.

        Reads ``self.logdir`` and ``self.test_args`` and sets ``self.root``
        plus the ``zs`` / ``reconstructions`` / ``inputs`` sub-directory paths,
        creating only the directories whose ``save_*`` flag is truthy. Also
        resets the sample counter and the metric containers.

        Raises:
            AssertionError: If ``self.test_args`` is ``None``.
        """
        # Check up front: every statement below dereferences test_args.
        assert self.test_args is not None, "test_args is required in test mode"
        self.test = True
        save_dir = os.path.join(self.logdir, "test")
        if "ckpt" in self.test_args:
            # _cur_epoch is only assigned by init_from_ckpt; fall back to the
            # placeholder that method uses when no checkpoint has been loaded.
            cur_epoch = getattr(self, "_cur_epoch", "null")
            ckpt_name = (
                os.path.basename(self.test_args.ckpt).split(".ckpt")[0]
                + f"_epoch{cur_epoch}"
            )
            self.root = os.path.join(save_dir, ckpt_name)
        else:
            self.root = save_dir
        # An explicit sub-directory takes precedence over the checkpoint name.
        if "test_subdir" in self.test_args:
            self.root = os.path.join(save_dir, self.test_args.test_subdir)

        self.root_zs = os.path.join(self.root, "zs")
        self.root_dec = os.path.join(self.root, "reconstructions")
        self.root_inputs = os.path.join(self.root, "inputs")
        os.makedirs(self.root, exist_ok=True)

        if self.test_args.save_z:
            os.makedirs(self.root_zs, exist_ok=True)
        if self.test_args.save_reconstruction:
            os.makedirs(self.root_dec, exist_ok=True)
        if self.test_args.save_input:
            os.makedirs(self.root_inputs, exist_ok=True)

        # Optional cap on processed samples; None disables the cap in test_step.
        self.test_maximum = getattr(
            self.test_args, "test_maximum", None
        )  # 1500 # 12000/8
        self.count = 0
        self.eval_metrics = {}
        self.decodes = []
        self.save_decode_samples = 2048
        if getattr(self.test_args, "cal_metrics", False):
            # NOTE(review): EvalLpips is neither defined nor imported in the
            # visible part of this file -- confirm where it comes from.
            self.EvalLpips = EvalLpips()

    def init_from_ckpt(self, path, ignore_keys=None):
        """Restore weights from ``path`` (non-strict), dropping ignored keys.

        The file may either be a bare state dict or a dict that stores the
        weights under ``"state_dict"`` (optionally with an ``"epoch"`` entry).
        ``self._cur_epoch`` is set to the stored epoch, or to ``"null"`` when
        the file has no ``"epoch"`` entry.

        Args:
            path: Checkpoint file to read.
            ignore_keys: Optional iterable of key prefixes; every state-dict
                entry starting with one of them is removed before loading.
                ``None`` means "drop nothing".
        """
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        sd = torch.load(path, map_location="cpu")
        # Treat the two entries independently so a checkpoint that has
        # "state_dict" but no "epoch" is still unwrapped instead of being
        # passed whole to load_state_dict (which would restore nothing).
        self._cur_epoch = sd.get("epoch", "null")
        if "state_dict" in sd:
            sd = sd["state_dict"]
        prefixes = tuple(ignore_keys) if ignore_keys else ()
        # Iterate over a snapshot of the keys because entries are deleted.
        for k in list(sd.keys()):
            # One startswith call over all prefixes deletes each key at most
            # once, even when several prefixes match it.
            if k.startswith(prefixes):
                print("Deleting key {} from state_dict.".format(k))
                del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    def encode(self, x, **kwargs):
        """Encode ``x`` and return the resulting posterior distribution.

        The encoder output is projected by ``quant_conv`` and wrapped in a
        ``DiagonalGaussianDistribution``. Extra keyword arguments are accepted
        but not used.
        """
        moments = self.quant_conv(self.encoder(x))
        return DiagonalGaussianDistribution(moments)

    def decode(self, z, **kwargs):
        """Decode latent ``z``: apply ``post_quant_conv``, then the decoder.

        Extra keyword arguments are accepted but not used.
        """
        projected = self.post_quant_conv(z)
        return self.decoder(projected)

    def forward(self, input, sample_posterior=True):
        """Encode ``input``, pick a latent, and decode it.

        Args:
            input: Tensor passed to ``encode``.
            sample_posterior: When truthy the latent is ``posterior.sample()``,
                otherwise ``posterior.mode()``.

        Returns:
            Tuple ``(reconstruction, posterior)``.
        """
        posterior = self.encode(input)
        latent = posterior.sample() if sample_posterior else posterior.mode()
        return self.decode(latent), posterior

    def get_input(self, batch, k):
        """Fetch ``batch[k]``, flattening 5-D input when ``input_dim`` is 4.

        A 5-D tensor is reshaped with the pattern
        ``"b c t h w -> (b t) c h w"`` and its first and third axis sizes are
        remembered as ``self.b`` and ``self.t``. Any other input is returned
        untouched.
        """
        x = batch[k]
        if not (x.dim() == 5 and self.input_dim == 4):
            return x

        # Remember the sizes of the two axes that get merged.
        size_b, _, size_t, _, _ = x.shape
        self.b = size_b
        self.t = size_t
        return rearrange(x, "b c t h w -> (b t) c h w")

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Run one optimisation step for the optimizer selected by index.

        ``optimizer_idx == 0`` trains encoder+decoder+logvar and logs
        ``"aeloss"``; ``optimizer_idx == 1`` trains the discriminator and logs
        ``"discloss"``. Any other index returns ``None`` without evaluating
        the loss.

        Returns:
            The loss value for the selected optimizer, or ``None``.
        """
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)

        if optimizer_idx not in (0, 1):
            return None

        # Both optimizers share the same loss call; only the index differs.
        loss_value, loss_logs = self.loss(
            inputs,
            reconstructions,
            posterior,
            optimizer_idx,
            self.global_step,
            last_layer=self.get_last_layer(),
            split="train",
        )

        metric_name = "aeloss" if optimizer_idx == 0 else "discloss"
        self.log(
            metric_name,
            loss_value,
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=True,
        )
        self.log_dict(
            loss_logs, prog_bar=False, logger=True, on_step=True, on_epoch=False
        )
        return loss_value

    def validation_step(self, batch, batch_idx):
        """Evaluate both loss terms on a validation batch and log them.

        The loss is called once with optimizer index 0 and once with index 1,
        both with ``split="val"``. ``"val/rec_loss"`` is logged explicitly in
        addition to the two returned log dictionaries.

        Returns:
            A single dict merging the autoencoder and discriminator log
            entries. (The previous code returned the bound ``self.log_dict``
            method itself, which carries no metric values.)
        """
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)
        aeloss, log_dict_ae = self.loss(
            inputs,
            reconstructions,
            posterior,
            0,
            self.global_step,
            last_layer=self.get_last_layer(),
            split="val",
        )

        discloss, log_dict_disc = self.loss(
            inputs,
            reconstructions,
            posterior,
            1,
            self.global_step,
            last_layer=self.get_last_layer(),
            split="val",
        )

        self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        # Return the logged values rather than the logging method.
        return {**log_dict_ae, **log_dict_disc}

    def test_step(self, batch, batch_idx):
        """Reconstruct one test batch and write the requested artefacts.

        Depending on the ``save_z`` / ``save_reconstruction`` / ``save_input``
        flags of ``self.test_args`` this stores the sampled latent (``.pt``),
        the reconstruction and input videos (``.mp4``) and a uint8 copy of the
        reconstruction (``.npz``). ``self.count`` is advanced by the latent
        batch size, and the process exits once it exceeds
        ``self.test_maximum`` (when that limit is set).
        """
        inputs = self.get_input(batch, self.image_key)
        # forward
        sample_posterior = True
        posterior = self.encode(inputs)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)

        # logs
        if self.test_args.save_z:
            torch.save(
                z,
                os.path.join(
                    self.root_zs,
                    f"zs_batch{batch_idx}_rank{self.global_rank}_shape{shape_to_str(z)}.pt",
                ),
            )
        if self.test_args.save_reconstruction:
            tensor2videogrids(
                dec,
                self.root_dec,
                f"reconstructions_batch{batch_idx}_rank{self.global_rank}_shape{shape_to_str(z)}.mp4",
                fps=10,
            )
        if self.test_args.save_input:
            tensor2videogrids(
                inputs,
                self.root_inputs,
                f"inputs_batch{batch_idx}_rank{self.global_rank}_shape{shape_to_str(z)}.mp4",
                fps=10,
            )

        # NOTE(review): the numpy dump of the reconstruction is gated on
        # save_z, not on a reconstruction flag -- confirm this is intended.
        if "save_z" in self.test_args and self.test_args.save_z:
            dec_np = dec.detach().cpu().numpy()
            # Move axis 1 to the end. For 5-D input this equals the former
            # transpose(0, 2, 3, 4, 1), and it no longer fails on the 4-D
            # tensors produced when get_input flattens the batch.
            dec_np = np.moveaxis(dec_np, 1, -1)
            # Map [-1, 1] to [0, 255] and clip first, so out-of-range decoder
            # outputs saturate instead of wrapping around in uint8.
            dec_np = np.clip((dec_np + 1) / 2 * 255, 0, 255).astype(np.uint8)
            self.root_dec_np = os.path.join(self.root, "reconstructions_np")
            os.makedirs(self.root_dec_np, exist_ok=True)
            np.savez(
                os.path.join(
                    self.root_dec_np,
                    f"reconstructions_batch{batch_idx}_rank{self.global_rank}_shape{shape_to_str(dec_np)}.npz",
                ),
                dec_np,
            )

        self.count += z.shape[0]

        # misc
        self.log("batch_idx", batch_idx, prog_bar=True)
        self.log_dict(self.eval_metrics, prog_bar=True, logger=True)
        torch.cuda.empty_cache()
        if self.test_maximum is not None:
            if self.count > self.test_maximum:
                import sys

                # NOTE(review): exiting here presumably skips on_test_end, so
                # no metrics file is written -- confirm this is acceptable.
                sys.exit()
            else:
                prog = self.count / self.test_maximum * 100
                print(f"Test progress: {prog:.2f}% [{self.count}/{self.test_maximum}]")

    @rank_zero_only
    def on_test_end(self):
        """Average the collected metrics and dump them as JSON.

        Does nothing unless ``test_args.cal_metrics`` is truthy. Each value of
        ``self.eval_metrics`` must provide ``"psnr"``, ``"ssim"``,
        ``"lpips"`` and ``"n_samples"``. The means, the number of entries and
        the total sample count are written to
        ``<self.root>/reconstruction_metrics.json``. When no metrics were
        collected the file is not written.
        """
        if not getattr(self.test_args, "cal_metrics", False):
            return
        if not self.eval_metrics:
            # Averaging an empty collection would divide by zero.
            print("No evaluation metrics collected; skipping metrics dump.")
            return

        psnrs, ssims, lpipses = [], [], []
        n_samples = 0
        for v in self.eval_metrics.values():
            psnrs.append(v["psnr"])
            ssims.append(v["ssim"])
            lpipses.append(v["lpips"])
            n_samples += v["n_samples"]

        overall = {
            "psnr": sum(psnrs) / len(psnrs),
            "ssim": sum(ssims) / len(ssims),
            "lpips": sum(lpipses) / len(lpipses),
            "n_batches": len(psnrs),
            "n_samples": n_samples,
        }
        # json cannot serialise tensor / numpy scalars; convert them to float.
        overall = {
            k: float(v) if isinstance(v, (torch.Tensor, np.generic)) else v
            for k, v in overall.items()
        }
        # The with-statement closes the file; no explicit close() is needed.
        with open(os.path.join(self.root, "reconstruction_metrics.json"), "w") as f:
            json.dump(overall, f)

    def configure_optimizers(self):
        """Create the Adam optimizers for the autoencoder and discriminator.

        Both use ``self.learning_rate`` and ``betas=(0.5, 0.9)``. The first
        optimizer covers the encoder, decoder and both 1x1 convolutions; the
        second covers ``self.loss.discriminator``.

        Returns:
            ``([opt_ae, opt_disc], [])`` -- the optimizers and an empty list.
        """
        lr = self.learning_rate
        adam_betas = (0.5, 0.9)

        ae_params = []
        for module in (
            self.encoder,
            self.decoder,
            self.quant_conv,
            self.post_quant_conv,
        ):
            ae_params.extend(module.parameters())

        opt_ae = torch.optim.Adam(ae_params, lr=lr, betas=adam_betas)
        opt_disc = torch.optim.Adam(
            self.loss.discriminator.parameters(), lr=lr, betas=adam_betas
        )
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        """Return the weight of the decoder's ``conv_out`` layer."""
        output_conv = self.decoder.conv_out
        return output_conv.weight

    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, **kwargs):
        """Collect tensors for visual logging.

        Args:
            batch: Batch from which the input is read via ``get_input``.
            only_inputs: When truthy, only the ``"inputs"`` entry is produced.

        Returns:
            Dict with ``"inputs"`` and, unless ``only_inputs`` is set, also
            ``"samples"`` (decoded random latents) and ``"reconstructions"``.
        """
        log = {}
        x = self.get_input(batch, self.image_key).to(self.device)

        if only_inputs:
            log["inputs"] = x
            return log

        xrec, posterior = self(x)
        if x.shape[1] > 3:
            # colorize with random projection
            assert xrec.shape[1] > 3
            x = self.to_rgb(x)
            xrec = self.to_rgb(xrec)
        # Noise with the same shape/dtype/device as a posterior sample.
        noise = torch.randn_like(posterior.sample())
        log["samples"] = self.decode(noise)
        log["reconstructions"] = xrec
        log["inputs"] = x
        return log

    def to_rgb(self, x):
        """Project ``x`` to 3 channels and rescale the result to ``[-1, 1]``.

        Only valid when ``image_key`` is ``"segmentation"``. A random 1x1
        projection is registered lazily as the ``colorize`` buffer if it does
        not exist yet.
        """
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            projection = torch.randn(3, x.shape[1], 1, 1).to(x)
            self.register_buffer("colorize", projection)
        projected = F.conv2d(x, weight=self.colorize)
        lowest = projected.min()
        highest = projected.max()
        # Min-max normalisation onto [-1, 1].
        return 2.0 * (projected - lowest) / (highest - lowest) - 1.0


class IdentityFirstStage(torch.nn.Module):
    """First-stage stand-in whose encode/decode/forward return the input as is.

    Args:
        vq_interface: When truthy, ``quantize`` returns the 3-tuple
            ``(x, None, [None, None, None])`` instead of just ``x``.

    Extra positional and keyword constructor arguments are accepted and
    ignored.
    """

    def __init__(self, *args, vq_interface=False, **kwargs):
        # nn.Module requires the parent initialiser to run before attributes
        # are assigned on the instance.
        super().__init__()
        self.vq_interface = vq_interface

    def encode(self, x, *args, **kwargs):
        """Return ``x`` unchanged."""
        return x

    def decode(self, x, *args, **kwargs):
        """Return ``x`` unchanged."""
        return x

    def quantize(self, x, *args, **kwargs):
        """Return ``x``, wrapped in a 3-tuple when ``vq_interface`` is set."""
        if self.vq_interface:
            return x, None, [None, None, None]
        return x

    def forward(self, x, *args, **kwargs):
        """Return ``x`` unchanged."""
        return x