import pdb

import torch
import numpy as np
from .utils import _broadcast_tensor, _extract_into_tensor



class _WrappedModel_DiT:
    def __init__(self, model, diffusion, device=None, class_emb_null=None):
        self.model = model
        self.diffusion = diffusion
        self._predict_xstart_from_eps = diffusion._predict_xstart_from_eps
        
        self.diffusion_t_map = list(diffusion.use_timesteps)
        self.diffusion_t_map.sort()
        
        self.diffusion_t = [self.diffusion_t_map[i] for i in range(diffusion.num_timesteps)] # list(range(diffusion.num_timesteps))
        self.diffusion_t = np.array(self.diffusion_t)
        
        self.diffusion_sqrt_alpha_cumprod = np.array([diffusion.sqrt_alphas_cumprod[i] for i in range(diffusion.num_timesteps)])
        self.fm_steps = [(1 - self.diffusion_sqrt_alpha_cumprod[i]**2)**0.5/(self.diffusion_sqrt_alpha_cumprod[i] + (1 - self.diffusion_sqrt_alpha_cumprod[i]**2)**0.5) for i in range(len(self.diffusion_t))]
        
        self.fm_steps = torch.tensor([0] + self.fm_steps, device=device)
        self.y_null = class_emb_null




        
    def __call__(self, x, t, y, kwargs):

        N = len(self.diffusion_t)
        B,C,H,W = x.shape
        diffusion_x = torch.zeros_like(x)
        diffusion_t = _extract_into_tensor(self.diffusion_t, t-1, t.shape).long()


        t_fm = self.fm_steps[t]
        diffusion_x_tmp = _extract_into_tensor(self.diffusion.sqrt_alphas_cumprod, t-1, x.shape) * x / ( 1 + 1e-4 - _broadcast_tensor(t_fm,x.shape)) 
        diffusion_x_tmp = diffusion_x_tmp.to(torch.float)
        diffusion_x = torch.where(_broadcast_tensor(t,x.shape) == N, x, diffusion_x_tmp)


        y_null_batch = torch.cat([self.y_null[0].unsqueeze(0)]*B, dim=0)
        y_new = torch.cat([y, y_null_batch], 0)
        

        model_output = self.model(torch.cat([diffusion_x,diffusion_x],dim=0), torch.cat([diffusion_t,diffusion_t],dim=0), y_new, **kwargs)
        model_output = model_output[:B]
        model_output, _ = torch.split(model_output, C, dim=1)
        x0_diffusion = self._predict_xstart_from_eps(x_t=diffusion_x, t=t-1, eps=model_output)
        vt = (x - x0_diffusion) / (_broadcast_tensor(t_fm,x.shape))
        vt = vt.to(diffusion_x.dtype)
        return vt



class _WrappedModel_Sora:
    def __init__(self, model, guidance_scale, y_null, timesteps, num_timesteps, mask_t):
        self.model = model
        self.guidance_scale = guidance_scale    
        self.y_null = y_null
        
        self.timesteps = [torch.tensor([0], device=model.device)] + timesteps[::-1]
        self.timesteps = torch.cat(self.timesteps, dim=0)
        self.fm_steps = [x/num_timesteps for x in self.timesteps]
        self.mask_t = mask_t
        
    def __call__(self, x, t, y, kwargs):
        y = torch.cat([y, self.y_null], dim=0)
        
        t_in = self.timesteps[t]        
        
        x_in = torch.cat([x,x], dim=0)
        # breakpoint()
        mask_t_upper = self.mask_t >= t_in.unsqueeze(1)
        kwargs["x_mask"] = mask_t_upper.repeat(2, 1)

        t_in = torch.cat([t_in,t_in], dim=0)
        with torch.no_grad():
            pred = self.model(x_in, t_in, y,  **kwargs).chunk(2, dim=1)[0]
        # breakpoint()
        pred_cond, pred_uncond = pred.chunk(2, dim=0)
        v_pred = pred_uncond + self.guidance_scale * (pred_cond - pred_uncond)
        
        return -v_pred



class _WrappedModel_Wan:
    def __init__(self, model, timesteps, num_timesteps, context_null, guide_scale):
        self.model = model
        self.context_null = context_null
        self.guide_scale = guide_scale
        fm_steps = torch.cat([timesteps,torch.zeros_like(timesteps[0]).view(1)])
        self.time_steps = torch.flip(fm_steps, dims=[0])
        self.fm_steps = self.time_steps/num_timesteps
        
        
    def __call__(self, x, t, y, kwargs):
        self.time_steps = self.time_steps.to(t.device)
        t = self.time_steps[t]
        noise_pred_cond = self.model(x, t=t, context=y, **kwargs)[0]
        noise_pred_uncond = self.model(x, t=t, context=self.context_null, **kwargs)[0]
        noise_pred = noise_pred_uncond + self.guide_scale * (noise_pred_cond - noise_pred_uncond)
        return noise_pred




class _WrappedModel_FLUX:
    def __init__(self, model, timesteps, num_timesteps):
        self.model = model
        fm_steps = torch.cat([timesteps,torch.zeros_like(timesteps[0]).view(1)])
        self.time_steps = torch.flip(fm_steps, dims=[0])
        self.fm_steps = self.time_steps/num_timesteps
            
        
    def __call__(self, x, t, y, kwargs):
        t = self.time_steps[t]
        t = t.expand(x.shape[0]).to(x.dtype) / 1000
        pred = self.model(hidden_states=x, timestep=t, **kwargs)[0]
        return pred