from model.video_diffusion.models.controlnet3d import ControlNet3DModel
from model.video_diffusion.models.unet_3d_condition import UNetPseudo3DConditionModel
from model.video_diffusion.pipelines.pipeline_stable_diffusion_controlnet3d import Controlnet3DStableDiffusionPipeline
from transformers import DPTForDepthEstimation
from model.annotator.hed import HEDNetwork
import torch
from einops import rearrange,repeat
import imageio
import numpy as np
import cv2
import torch.nn.functional as F
from PIL import Image
import argparse
import tempfile
import os
import gradio as gr


# Conditioning signal used by the ControlNet: one of 'depth', 'canny', 'hed'.
# It selects both the Hugging Face checkpoint repo and the annotator model.
control_mode = 'depth'
if control_mode not in ('depth', 'canny', 'hed'):
    # Fail fast, before any weights are downloaded, instead of loading every
    # model and then hitting a NameError on `annotator_model` below.
    raise ValueError(
        f"Unsupported control_mode {control_mode!r}; expected 'depth', 'canny' or 'hed'"
    )
control_net_path = f"wf-genius/controlavideo-{control_mode}"

# The pseudo-3D UNet and the 3D ControlNet are loaded in float16 on the GPU.
unet = UNetPseudo3DConditionModel.from_pretrained(control_net_path,
                        torch_dtype = torch.float16,
                        subfolder='unet',
                        ).to("cuda")
controlnet = ControlNet3DModel.from_pretrained(control_net_path,
                        torch_dtype = torch.float16,
                        subfolder='controlnet',
                        ).to("cuda")

if control_mode == 'depth':
    # DPT (hybrid MiDaS) depth estimator; note it is loaded without torch_dtype,
    # unlike the models above.
    annotator_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
elif control_mode == 'canny':
    # Canny edges are computed with OpenCV at inference time; no annotator model.
    annotator_model = None
elif control_mode == 'hed':
    # The HED weights must exist locally; download them first from
    # https://huggingface.co/wf-genius/controlavideo-hed/resolve/main/hed-network.pth
    annotator_model = HEDNetwork('hed-network.pth').to("cuda")

# Full video pipeline built from the checkpoint repo, reusing the components
# loaded above.
video_controlnet_pipe = Controlnet3DStableDiffusionPipeline.from_pretrained(control_net_path, unet=unet,
                        controlnet=controlnet, annotator_model=annotator_model,
                        torch_dtype = torch.float16,
                        ).to("cuda")


def to_video(frames, fps: int) -> str:
    """Encode ``frames`` into a new MP4 file and return its path.

    The file is created with ``delete=False`` so the path stays valid after this
    function returns (it is handed back to the Gradio output); the caller owns
    cleanup of the file.

    Args:
        frames: Iterable of frames; each is converted with ``np.array`` before
            being appended by the imageio FFMPEG writer.
        fps: Frame rate written into the video.

    Returns:
        Filesystem path of the written ``.mp4`` file.
    """
    # Only a unique path is needed. Close our handle right away so it is not
    # leaked and so FFMPEG can reopen the file (an open NamedTemporaryFile
    # cannot be reopened by another writer on Windows).
    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as out_file:
        out_path = out_file.name
    # The context manager finalizes the container even if a frame fails.
    with imageio.get_writer(out_path, format='FFMPEG', fps=fps) as writer:
        for frame in frames:
            writer.append_data(np.array(frame))
    return out_path

def _compute_control_maps(np_frames, h, w, num_sample_frames):
    """Return ControlNet hints for ``np_frames`` laid out as ``b c f h w``.

    ``np_frames`` is indexed as ``f h w c`` with 0..255 values. The annotator is
    chosen by the module-level ``control_mode``. The maps are cast to the
    ControlNet's dtype/device, resized to ``(h, w)`` and, when single-channel,
    repeated to 3 channels.
    """
    if control_mode == 'depth':
        # The depth annotator receives frames scaled to [-1, 1], laid out f c h w.
        frames = torch.from_numpy(np_frames).div(255) * 2 - 1
        frames = rearrange(frames, 'f h w c -> f c h w')
        control_maps = video_controlnet_pipe.get_depth_map(frames, h, w, return_standard_norm=False)  # (b f) 1 h w
    elif control_mode in ('canny', 'hed'):
        if control_mode == 'canny':
            edge_maps = [cv2.Canny(frame, 100, 200) for frame in np_frames]
        else:
            edge_maps = [video_controlnet_pipe.get_hed_map(frame) for frame in np_frames]
        # Stack to f h w, add a channel axis, then scale to [0, 1].
        control_maps = repeat(np.stack(edge_maps), 'f h w -> f c h w', c=1)
        control_maps = torch.from_numpy(control_maps).div(255)
    else:
        raise ValueError(
            f"Unsupported control_mode {control_mode!r}; expected 'depth', 'canny' or 'hed'"
        )
    control_maps = control_maps.to(dtype=controlnet.dtype, device=controlnet.device)
    control_maps = F.interpolate(control_maps, size=(h, w), mode='bilinear', align_corners=False)
    control_maps = rearrange(control_maps, '(b f) c h w -> b c f h w', f=num_sample_frames)
    if control_maps.shape[1] == 1:
        control_maps = repeat(control_maps, 'b c f h w -> b (n c) f h w', n=3)
    return control_maps


def _prepare_input_frames(np_frames, h, w, num_sample_frames):
    """Return ``np_frames`` scaled to [0, 1], resized to ``(h, w)``, as ``b c f h w``."""
    frames = rearrange(torch.from_numpy(np_frames).div(255), 'f h w c -> f c h w')
    frames = F.interpolate(frames, size=(h, w), mode='bicubic', antialias=True)
    return rearrange(frames, '(b f) c h w -> b c f h w', f=num_sample_frames)


def inference(input_video,
                prompt,
                seed,
                num_inference_steps,
                guidance_scale,
                sampling_rate,
                video_scale,
                init_noise_thres,
                each_sample_frame,
                iter_times,
                h,
                w,
                ):
    """Re-render ``input_video`` under ``prompt`` and return the path of an MP4.

    ``iter_times * each_sample_frame`` frames are read from the video. They are
    generated in ``iter_times`` chunks of ``each_sample_frame`` frames.
    Consecutive chunks overlap by one frame: the previous chunk's last output
    is passed as ``first_frame_output`` so the chunks join smoothly. The first
    frame of every chunk with more than one frame is dropped from the result.

    Args:
        input_video: Value from the Gradio video input, forwarded to
            ``get_frames_preprocess`` (presumably a file path — confirm there).
        prompt: Text prompt for generation.
        seed: Seed for a CUDA ``torch.Generator``; a fresh generator with this
            seed is created for every chunk.
        num_inference_steps: Denoising steps per chunk.
        guidance_scale: Classifier-free guidance scale.
        sampling_rate: Frame sampling rate passed to ``get_frames_preprocess``.
        video_scale: Forwarded to the pipeline as ``video_scale``.
        init_noise_thres: Forwarded as ``init_noise_by_residual_thres``.
        each_sample_frame: Frames per pipeline call (chunk length).
        iter_times: Number of chunks.
        h, w: Output height and width in pixels.

    Returns:
        Path of the generated MP4, written at 8 fps.

    Raises:
        gr.Error: No input video was provided.
        ValueError: ``control_mode`` is not 'depth', 'canny' or 'hed'.
    """
    if input_video is None:
        raise gr.Error("Please upload an input video first.")
    # Defensive: make the integral UI values real ints. They are used in
    # range(), slicing, interpolate sizes and Generator.manual_seed.
    seed = int(seed)
    num_inference_steps = int(num_inference_steps)
    sampling_rate = int(sampling_rate)
    each_sample_frame = int(each_sample_frame)
    iter_times = int(iter_times)
    h, w = int(h), int(w)

    num_sample_frames = iter_times * each_sample_frame
    np_frames, _fps_vid = Controlnet3DStableDiffusionPipeline.get_frames_preprocess(
        input_video, num_frames=num_sample_frames, sampling_rate=sampling_rate, return_np=True)
    control_maps = _compute_control_maps(np_frames, h, w, num_sample_frames)
    v2v_input_frames = _prepare_input_frames(np_frames, h, w, num_sample_frames)

    out = []
    for i in range(iter_times):
        # Chunk i begins at the input frame whose output is out[-1]. Every chunk
        # adds only each_sample_frame - 1 new frames, so the stride must be
        # each_sample_frame - 1 for that overlap frame to line up for every
        # chunk. The last slice ends at iter_times*(each_sample_frame-1)+1,
        # which is <= num_sample_frames.
        start = i * (each_sample_frame - 1)
        stop = start + each_sample_frame
        result = video_controlnet_pipe(
                controlnet_hint=control_maps[:, :, start:stop],
                images=v2v_input_frames[:, :, start:stop],
                first_frame_output=out[-1] if i > 0 else None,
                prompt=[prompt],
                num_inference_steps=num_inference_steps,
                width=w,
                height=h,
                guidance_scale=guidance_scale,
                generator=[torch.Generator(device="cuda").manual_seed(seed)],
                video_scale=video_scale,
                init_noise_by_residual_thres=init_noise_thres,    # residual-based init. larger thres ==> more smooth.
                controlnet_conditioning_scale=1.0,
                fix_first_frame=True,
                in_domain=True,
        )
        chunk = result.images[0]
        if len(chunk) > 1:
            # Drop the first frame; for i > 0 it corresponds to out[-1].
            chunk = chunk[1:]
        out.extend(chunk)

    return to_video(out, 8)


# Demo rows for gr.Examples. Each row gives only the input video and the
# prompt, i.e. the first two entries of the UI's `inputs` list.
examples = [
    ["bear.mp4", "a bear walking through stars, artstation"],
    ["car-shadow.mp4", "a car, sunset, cartoon style, artstation."],
    ["libby.mp4", "a dog running, chinese ink painting."],
]

def preview_inference(
        input_video,
        prompt, seed,
        num_inference_steps, guidance_scale,
        sampling_rate, video_scale, init_noise_thres,
        each_sample_frame,iter_times, h, w,
    ):
    """Produce a quick single-frame preview via ``inference``.

    Takes the same arguments as ``inference`` so it can be wired to the same UI
    input list. It ignores ``video_scale``, ``init_noise_thres``,
    ``each_sample_frame`` and ``iter_times``, forcing them to 0.0, 0.0, 1 and 1.
    That makes the run generate exactly one frame.
    """
    return inference(
        input_video,
        prompt,
        seed,
        num_inference_steps,
        guidance_scale,
        sampling_rate,
        video_scale=0.0,
        init_noise_thres=0.0,
        each_sample_frame=1,
        iter_times=1,
        h=h,
        w=w,
    )

if __name__ == '__main__':
    # One row: input video | frame/size controls | sampling controls + prompt |
    # generated video.
    with gr.Blocks() as demo:
        with gr.Row():
            input_video = gr.Video(
                label="Input Video", source='upload', format="mp4", visible=True)
            with gr.Column():
                init_noise_thres = gr.Slider(0, 1, value=0.1, step=0.1, label="init_noise_thres")
                each_sample_frame = gr.Slider(6, 16, value=8, step=1, label="each_sample_frame")
                iter_times = gr.Slider(1, 4, value=1, step=1, label="iter_times")
                sampling_rate = gr.Slider(1, 8, value=3, step=1, label="sampling_rate")
                h = gr.Slider(256, 768, value=512, step=64, label="height")
                w = gr.Slider(256, 768, value=512, step=64, label="width")
            with gr.Column():
                seed = gr.Slider(0, 6666, value=1, step=1, label="seed")
                num_inference_steps = gr.Slider(5, 50, value=20, step=1, label="num_inference_steps")
                guidance_scale = gr.Slider(1, 20, value=7.5, step=0.5, label="guidance_scale")
                video_scale = gr.Slider(0, 2.5, value=1.5, step=0.1, label="video_scale")
                prompt = gr.Textbox(label='Prompt')
                # preview_button = gr.Button('Preview')
                run_button = gr.Button('Generate Video')
            result = gr.Video(label="Generated Video")

        # Gradio passes these positionally, so the order must match the
        # parameter order of inference().
        inputs = [
                input_video,
                prompt,
                seed,
                num_inference_steps,
                guidance_scale,
                sampling_rate,
                video_scale,
                init_noise_thres,
                each_sample_frame,
                iter_times,
                h,
                w,
            ]

        # Example rows only fill the video and prompt inputs; they are neither
        # cached nor run automatically when clicked.
        gr.Examples(examples=examples,
                    inputs=inputs,
                    outputs=result,
                    fn=inference,
                    cache_examples=False,
                    run_on_click=False,
                    )

        run_button.click(fn=inference,
                            inputs=inputs,
                            outputs=result,)
        # Disabled single-frame preview (see preview_inference); re-enable it
        # together with the 'Preview' button above.
        # preview_button.click(fn=preview_inference,
        #                     inputs=inputs,
        #                     outputs=result,)

    demo.launch(server_name="0.0.0.0", server_port=7860)