from dataclasses import dataclass, field

import json
import copy

import torch
from PIL import Image

from diffusers.training_utils import (
    compute_density_for_timestep_sampling,
    compute_loss_weighting_for_sd3,
)
import step1x3d_geometry
from step1x3d_geometry.systems.base import BaseSystem
from step1x3d_geometry.utils.misc import get_rank
from step1x3d_geometry.utils.typing import *
from step1x3d_geometry.systems.utils import preprocess_image, flow_sample


def get_sigmas(noise_scheduler, timesteps, n_dim=4, dtype=torch.float32):
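    """Map each sampled timestep to its sigma in the scheduler's schedule and
    reshape the result to ``n_dim`` dimensions so it broadcasts against the
    latent tensor."""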
    sigmas = noise_scheduler.sigmas.to(device=timesteps.device, dtype=dtype)
    schedule_timesteps = noise_scheduler.timesteps.to(timesteps.device)
    step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

    sigma = sigmas[step_indices].flatten()
    while len(sigma.shape) < n_dim:
        sigma = sigma.unsqueeze(-1)
    return sigma


@step1x3d_geometry.register("rectified-flow-system")
class RectifiedFlowSystem(BaseSystem):
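    """Rectified-flow (flow-matching) system: trains a denoiser over the
    frozen shape-VAE latent space with optional visual, caption, and label
    conditioning, and samples new shape latents for validation meshes."""
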
    @dataclass
    class Config(BaseSystem.Config):
        skip_validation: bool = True
        val_samples_json: str = ""
        bounds: float = 1.05
        mc_level: float = 0.0
        octree_resolution: int = 256

        # diffusion config
        guidance_scale: float = 7.5
        num_inference_steps: int = 30
        eta: float = 0.0
        snr_gamma: float = 5.0

        # flow
        weighting_scheme: str = "logit_normal"
        logit_mean: float = 0.0
        logit_std: float = 1.0
        mode_scale: float = 1.29
        precondition_outputs: bool = True
        precondition_t: int = 1000

        # shape vae model
        shape_model_type: Optional[str] = None
        shape_model: dict = field(default_factory=dict)

        # condition model
        visual_condition_type: Optional[str] = None
        visual_condition: dict = field(default_factory=dict)
        caption_condition_type: Optional[str] = None
        caption_condition: dict = field(default_factory=dict)
        label_condition_type: Optional[str] = None
        label_condition: dict = field(default_factory=dict)

        # diffusion model
        denoiser_model_type: Optional[str] = None
        denoiser_model: dict = field(default_factory=dict)

        # noise scheduler
        noise_scheduler_type: Optional[str] = None
        noise_scheduler: dict = field(default_factory=dict)

        # denoise scheduler
        denoise_scheduler_type: Optional[str] = None
        denoise_scheduler: dict = field(default_factory=dict)

        # lora
        use_lora: bool = False
        lora_layers: Optional[str] = None
        rank: int = 128  # The dimension of the LoRA update matrices.
        alpha: int = 128

    cfg: Config

    def configure(self):
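        """Build the frozen shape VAE, the optional condition encoders, the
        flow denoiser (optionally LoRA-wrapped), and both noise schedulers."""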
        super().configure()

        self.shape_model = step1x3d_geometry.find(self.cfg.shape_model_type)(
            self.cfg.shape_model
        )
        self.shape_model.eval()
        self.shape_model.requires_grad_(False)

        if self.cfg.visual_condition_type is not None:
            self.visual_condition = step1x3d_geometry.find(
                self.cfg.visual_condition_type
            )(self.cfg.visual_condition)
            self.visual_condition.requires_grad_(False)

        if self.cfg.caption_condition_type is not None:
            self.caption_condition = step1x3d_geometry.find(
                self.cfg.caption_condition_type
            )(self.cfg.caption_condition)
            self.caption_condition.requires_grad_(False)

        if self.cfg.label_condition_type is not None:
            self.label_condition = step1x3d_geometry.find(
                self.cfg.label_condition_type
            )(self.cfg.label_condition)

        self.denoiser_model = step1x3d_geometry.find(self.cfg.denoiser_model_type)(
            self.cfg.denoiser_model
        )
        if self.cfg.use_lora:  # We only train the additional adapter LoRA layers
            self.denoiser_model.requires_grad_(False)

        self.noise_scheduler = step1x3d_geometry.find(self.cfg.noise_scheduler_type)(
            **self.cfg.noise_scheduler
        )
        self.noise_scheduler_copy = copy.deepcopy(self.noise_scheduler)

        self.denoise_scheduler = step1x3d_geometry.find(
            self.cfg.denoise_scheduler_type
        )(**self.cfg.denoise_scheduler)

        if self.cfg.use_lora:
            from peft import LoraConfig, set_peft_model_state_dict

            if self.cfg.lora_layers is not None:
                self.target_modules = [
                    layer.strip() for layer in self.cfg.lora_layers.split(",")
                ]
            else:
                self.target_modules = [
                    "attn.to_k",
                    "attn.to_q",
                    "attn.to_v",
                    "attn.to_out.0",
                    "attn.add_k_proj",
                    "attn.add_q_proj",
                    "attn.add_v_proj",
                    "attn.to_add_out",
                    "ff.net.0.proj",
                    "ff.net.2",
                    "ff_context.net.0.proj",
                    "ff_context.net.2",
                ]
            # attach the LoRA adapter for both the custom and default
            # target-module lists
            self.transformer_lora_config = LoraConfig(
                r=self.cfg.rank,
                lora_alpha=self.cfg.alpha,
                init_lora_weights="gaussian",
                target_modules=self.target_modules,
            )
            self.denoiser_model.dit_model.add_adapter(self.transformer_lora_config)

    def forward(self, batch: Dict[str, Any], skip_noise=False) -> Dict[str, Any]:
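        """One training forward pass: encode surfaces into latents, gather the
        conditioning, form the noisy flow interpolant, run the denoiser, and
        return the weighted flow-matching loss along with diagnostics."""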
        # 1. encode shape latents
        if "sharp_surface" in batch.keys():
            sharp_surface = batch["sharp_surface"][
                ..., : 3 + self.cfg.shape_model.point_feats
            ]
        else:
            sharp_surface = None
        shape_embeds, latents, _ = self.shape_model.encode(
            batch["surface"][..., : 3 + self.cfg.shape_model.point_feats],
            sample_posterior=True,
            sharp_surface=sharp_surface,
        )

        # default batch bookkeeping; overwritten when images are provided
        bs, n_images = latents.shape[0], 1

        # 2. obtain the visual condition
        visual_cond = None
        if self.cfg.visual_condition_type is not None:
            assert "image" in batch, "image is required for the visual encoder"
            if "image" in batch and batch["image"].dim() == 5:
                if self.training:
                    bs, n_images = batch["image"].shape[:2]
                    batch["image"] = batch["image"].view(
                        bs * n_images, *batch["image"].shape[-3:]
                    )
                else:
                    batch["image"] = batch["image"][:, 0, ...]
                    n_images = 1
                    bs = batch["image"].shape[0]
                visual_cond = self.visual_condition(batch).to(latents)
                latents = latents.unsqueeze(1).repeat(1, n_images, 1, 1)
                latents = latents.view(bs * n_images, *latents.shape[-2:])
            else:
                visual_cond = self.visual_condition(batch).to(latents)
                bs = visual_cond.shape[0]
                n_images = 1

        ## 2.1 text condition if provided
        caption_cond = None
        if self.cfg.caption_condition_type is not None:
            assert "caption" in batch.keys(), "caption is required for caption encoder"
            assert bs == len(
                batch["caption"]
            ), "Batch size must be the same as the caption length."
            caption_cond = (
                self.caption_condition(batch)
                .repeat_interleave(n_images, dim=0)
                .to(latents)
            )

        ## 2.2 label condition if provided
        label_cond = None
        if self.cfg.label_condition_type is not None:
            assert "label" in batch.keys(), "label is required for label encoder"
            assert bs == len(
                batch["label"]
            ), "Batch size must be the same as the label length."
            label_cond = (
                self.label_condition(batch)
                .repeat_interleave(n_images, dim=0)
                .to(latents)
            )

        # 3. sample noise that we'll add to the latents
        noise = torch.randn_like(latents).to(
            latents
        )  # [batch_size, n_token, latent_dim]

        # 4. Sample a random timestep
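        # logit-normal density sampling (the SD3 recipe) biases timesteps
        # toward the middle of the schedule rather than sampling uniformly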
        u = compute_density_for_timestep_sampling(
            weighting_scheme=self.cfg.weighting_scheme,
            batch_size=bs * n_images,
            logit_mean=self.cfg.logit_mean,
            logit_std=self.cfg.logit_std,
            mode_scale=self.cfg.mode_scale,
        )
        indices = (u * self.cfg.noise_scheduler.num_train_timesteps).long()
        timesteps = self.noise_scheduler_copy.timesteps[indices].to(
            device=latents.device
        )

        # 5. add noise
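        # rectified flow interpolates linearly between data and noise:
        # x_t = (1 - sigma_t) * x_0 + sigma_t * eps, so sigma_t = 1 is pure
        # noise and sigma_t = 0 is the clean latent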
        sigmas = get_sigmas(
            self.noise_scheduler_copy, timesteps, n_dim=3, dtype=latents.dtype
        )
        noisy_z = (1.0 - sigmas) * latents + sigmas * noise

        # 6. diffusion model forward
        output = self.denoiser_model(
            noisy_z, timesteps.long(), visual_cond, caption_cond, label_cond
        ).sample

        # 7. compute loss
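        # with velocity prediction v = eps - x_0, preconditioning recovers an
        # x_0 estimate: x_t - sigma_t * v = x_0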
        if self.cfg.precondition_outputs:
            output = output * (-sigmas) + noisy_z
        # some weighting schemes post-weight the loss here instead of (or in
        # addition to) biasing the timestep sampling
        weighting = compute_loss_weighting_for_sd3(
            weighting_scheme=self.cfg.weighting_scheme, sigmas=sigmas
        )
        # flow-matching target: the clean latents when the output was
        # preconditioned into an x_0 estimate above, otherwise the velocity
        if self.cfg.precondition_outputs:
            target = latents
        else:
            target = noise - latents

        # Compute regular loss.
        loss = torch.mean(
            (weighting.float() * (output.float() - target.float()) ** 2).reshape(
                target.shape[0], -1
            ),
            1,
        )
        loss = loss.mean()

        return {
            "loss_diffusion": loss,
            "latents": latents,
            "x_t": noisy_z,
            "noise": noise,
            "noise_pred": output,
            "timesteps": timesteps,
        }

    def training_step(self, batch, batch_idx):
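        # sum every loss_* output weighted by its lambda_* coefficient from
        # the loss config; log_* outputs are logged but not optimized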
        out = self(batch)

        loss = 0.0
        for name, value in out.items():
            if name.startswith("loss_"):
                self.log(f"train/{name}", value)
                loss += value * self.C(self.cfg.loss[name.replace("loss_", "lambda_")])
            if name.startswith("log_"):
                self.log(f"log/{name.replace('log_', '')}", value.mean())

        for name, value in self.cfg.loss.items():
            if name.startswith("lambda_"):
                self.log(f"train_params/{name}", self.C(value))

        return {"loss": loss}

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        if self.cfg.skip_validation:
            return {}
        self.eval()

        if get_rank() == 0:
            with open(self.cfg.val_samples_json) as f:
                sample_inputs = json.load(f)  # conditioning inputs
            sample_inputs_ = copy.deepcopy(sample_inputs)
            sample_outputs = self.sample(sample_inputs)  # list
            for i, latents in enumerate(sample_outputs["latents"]):
                meshes = self.shape_model.extract_geometry(
                    latents,
                    bounds=self.cfg.bounds,
                    mc_level=self.cfg.mc_level,
                    octree_resolution=self.cfg.octree_resolution,
                    enable_pbar=False,
                )

                for j in range(len(meshes)):
                    name = ""
                    if "image" in sample_inputs_:
                        name += (
                            sample_inputs_["image"][j]
                            .split("/")[-1]
                            .replace(".png", "")
                        )

                    elif "mvimages" in sample_inputs_:
                        name += (
                            sample_inputs_["mvimages"][j][0]
                            .split("/")[-2]
                            .replace(".png", "")
                        )

                    if "caption" in sample_inputs_:
                        name += "_" + sample_inputs_["caption"][j].replace(
                            " ", "_"
                        ).replace(".", "")

                    if "label" in sample_inputs_:
                        name += (
                            "_"
                            + sample_inputs_["label"][j]["symmetry"]
                            + sample_inputs_["label"][j]["edge_type"]
                        )

                    if (
                        meshes[j].verts is not None
                        and meshes[j].verts.shape[0] > 0
                        and meshes[j].faces is not None
                        and meshes[j].faces.shape[0] > 0
                    ):
                        self.save_mesh(
                            f"it{self.true_global_step}/{name}_{i}.obj",
                            meshes[j].verts,
                            meshes[j].faces,
                        )
                        torch.cuda.empty_cache()

        out = self(batch)
        if self.global_step == 0:
            latents = self.shape_model.decode(out["latents"])
            meshes = self.shape_model.extract_geometry(
                latents,
                bounds=self.cfg.bounds,
                mc_level=self.cfg.mc_level,
                octree_resolution=self.cfg.octree_resolution,
                enable_pbar=False,
            )

            for i, mesh in enumerate(meshes):
                self.save_mesh(
                    f"it{self.true_global_step}/{batch['uid'][i]}.obj",
                    mesh.verts,
                    mesh.faces,
                )

        return {"val/loss": out["loss_diffusion"]}

    @torch.no_grad()
    def sample(
        self,
        sample_inputs: Dict[str, Union[torch.FloatTensor, List[str]]],
        sample_times: int = 1,
        steps: Optional[int] = None,
        guidance_scale: Optional[float] = None,
        eta: float = 0.0,
        seed: Optional[int] = None,
        **kwargs,
    ):
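        """Sample shape latents with the flow sampler, applying classifier-free
        guidance when guidance_scale != 1, and decode them with the shape VAE.
        """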

        if steps is None:
            steps = self.cfg.num_inference_steps
        if guidance_scale is None:
            guidance_scale = self.cfg.guidance_scale
        do_classifier_free_guidance = guidance_scale != 1.0

        # conditional encode
        visual_cond = None
        if "image" in sample_inputs:
            sample_inputs["image"] = [
                Image.open(img) if isinstance(img, str) else img
                for img in sample_inputs["image"]
            ]
            sample_inputs["image"] = preprocess_image(sample_inputs["image"], **kwargs)
            cond = self.visual_condition.encode_image(sample_inputs["image"])
            if do_classifier_free_guidance:
                un_cond = self.visual_condition.empty_image_embeds.repeat(
                    len(sample_inputs["image"]), 1, 1
                ).to(cond)
                visual_cond = torch.cat([un_cond, cond], dim=0)
            else:
                visual_cond = cond
        caption_cond = None
        if "caption" in sample_inputs:
            cond = self.caption_condition.encode_caption(sample_inputs["caption"])
            if do_classifier_free_guidance:
                un_cond = self.caption_condition.empty_caption_embeds.repeat(
                    len(sample_inputs["caption"]), 1, 1
                ).to(cond)
                caption_cond = torch.cat([un_cond, cond], dim=0)
            else:
                caption_cond = cond
        label_cond = None
        if "label" in sample_inputs:
            cond = self.label_condition.encode_label(sample_inputs["label"])
            if do_classifier_free_guidance:
                un_cond = self.label_condition.empty_label_embeds.repeat(
                    len(sample_inputs["label"]), 1, 1
                ).to(cond)
                label_cond = torch.cat([un_cond, cond], dim=0)
            else:
                label_cond = cond

        latents_list = []
        if seed is not None:
            generator = torch.Generator(device=self.device).manual_seed(seed)
        else:
            generator = None

        for _ in range(sample_times):
            sample_loop = flow_sample(
                self.denoise_scheduler,
                self.denoiser_model.eval(),
                shape=self.shape_model.latent_shape,
                visual_cond=visual_cond,
                caption_cond=caption_cond,
                label_cond=label_cond,
                steps=steps,
                guidance_scale=guidance_scale,
                do_classifier_free_guidance=do_classifier_free_guidance,
                device=self.device,
                eta=eta,
                disable_prog=False,
                generator=generator,
            )
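            # iterate the full denoising trajectory; keep the final sample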
            for sample, t in sample_loop:
                latents = sample
            latents_list.append(self.shape_model.decode(latents))

        return {"latents": latents_list, "inputs": sample_inputs}

    def on_validation_epoch_end(self):
        pass

    def test_step(self, batch, batch_idx):
        return