import jax
import jax.numpy as jnp
import numpy as np
from diffusers import FlaxControlNetModel, FlaxStableDiffusionControlNetPipeline
from flax.jax_utils import replicate
from flax.training.common_utils import shard

import gradio_utils
import utils


def create_key(seed=0):
    # Build a reproducible JAX PRNG key from an integer seed.
    return jax.random.PRNGKey(seed)
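
# Usage sketch (mirrors Model.inference below): one PRNG key per device for
# the pmap'd (jit=True) sampling call:
#   rng = jax.random.split(create_key(42), jax.device_count())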

class Model:
    def __init__(self, **kwargs):
        # OpenPose ControlNet; the checkpoint is PyTorch weights converted on load (from_pt=True).
        self.base_controlnet, self.base_controlnet_params = FlaxControlNetModel.from_pretrained(
            # "JFoz/dog-cat-pose", dtype=jnp.bfloat16
            "lllyasviel/control_v11p_sd15_openpose", dtype=jnp.bfloat16, from_pt=True
        )
        # Stable Diffusion 1.5 base pipeline, loaded from its native Flax revision.
        self.pipe, self.params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5", controlnet=self.base_controlnet,
            revision="flax", dtype=jnp.bfloat16,
        )

    def infer_frame(self, frame_id, prompt, negative_prompt, rng, **kwargs):
        print(f'Frame {frame_id}: {prompt[frame_id]}')

        num_samples = 1
        prompt_ids = self.pipe.prepare_text_inputs([prompt[frame_id]] * num_samples)
        negative_prompt_ids = self.pipe.prepare_text_inputs([negative_prompt[frame_id]] * num_samples)
        processed_image = self.pipe.prepare_image_inputs([kwargs['image'][frame_id]] * num_samples)

        # Attach the ControlNet weights to the pipeline's parameter tree.
        self.params["controlnet"] = self.base_controlnet_params

        # Replicate parameters across devices and shard inputs for the pmap'd call.
        p_params = replicate(self.params)
        prompt_ids = shard(prompt_ids)
        negative_prompt_ids = shard(negative_prompt_ids)
        processed_image = shard(processed_image)

        output = self.pipe(
            prompt_ids=prompt_ids,
            image=processed_image,
            params=p_params,
            prng_seed=rng,
            # Honor caller-supplied values instead of hardcoding them; the
            # previous code silently ignored these kwargs and always ran 50 steps.
            num_inference_steps=kwargs.get('num_inference_steps', 50),
            guidance_scale=kwargs.get('guidance_scale', 7.5),
            controlnet_conditioning_scale=kwargs.get('controlnet_conditioning_scale', 1.0),
            neg_prompt_ids=negative_prompt_ids,
            jit=True,
        ).images

        # Collapse the (devices, per-device batch) leading axes back to (num_samples, h, w, c).
        output_images = np.asarray(output.reshape((num_samples,) + output.shape[-3:]))
        return output_images

    def inference(self, **kwargs):
        seed = kwargs.pop('seed', 0)

        # One PRNG key per device, derived from the user-supplied seed
        # (previously the popped seed was ignored and 0 was always used).
        rng = create_key(seed)
        rng = jax.random.split(rng, jax.device_count())

        f = len(kwargs['image'])
        print('frames', f)

        # Broadcast the single prompt / negative prompt to every frame.
        assert 'prompt' in kwargs
        prompt = [kwargs.pop('prompt')] * f
        negative_prompt = [kwargs.pop('negative_prompt', '')] * f

        result = []
        for i in range(f):
            print(f'Processing frame {i + 1} / {f}')
            result.append(self.infer_frame(frame_id=i,
                                           prompt=prompt,
                                           negative_prompt=negative_prompt,
                                           rng=rng,
                                           **kwargs))
        result = np.stack(result, axis=0)
        return result
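
    # Note: `result` has shape (f, num_samples, h, w, c) with float pixel
    # values in [0, 1], as returned by the Flax pipeline.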

    def process_controlnet_pose(self,
                                video_path,
                                prompt,
                                num_inference_steps=20,
                                controlnet_conditioning_scale=1.0,
                                guidance_scale=9.0,
                                seed=42,
                                eta=0.0,
                                resolution=512,
                                save_path=None):
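        # End-to-end: motion/pose video -> per-frame OpenPose-conditioned SD frames -> GIF.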
        print("Module Pose")
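        # Resolve the Gradio motion selection into an on-disk pose-video path.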
        video_path = gradio_utils.motion_to_video_path(video_path)


        added_prompt = 'best quality, extremely detailed, HD, ultra-realistic, 8K, HQ, masterpiece, trending on artstation, art, smooth'
        negative_prompts = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, deformed body, bloated, ugly, unrealistic'

        # Resample the video to 4 fps at the requested resolution, then build
        # per-frame pose control maps (apply_pose_detect=False assumes the
        # frames already contain pose skeletons).
        video, fps = utils.prepare_video(
            video_path, resolution, False, output_fps=4)
        control = utils.pre_process_pose(
            video, apply_pose_detect=False)

        print('N frames', len(control))
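        # video is channels-first: (frames, channels, height, width)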
        f, _, h, w = video.shape

        result = self.inference(image=control,
                                prompt=prompt + ', ' + added_prompt,
                                height=h,
                                width=w,
                                negative_prompt=negative_prompts,
                                num_inference_steps=num_inference_steps,
                                guidance_scale=guidance_scale,
                                controlnet_conditioning_scale=controlnet_conditioning_scale,
                                eta=eta,
                                seed=seed,
                                output_type='numpy',
                                )
        return utils.create_gif(result.astype(np.float16), fps, path=save_path)
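

# Minimal usage sketch. The input and output paths below are hypothetical
# placeholders; substitute any motion input accepted by gradio_utils. Model
# weights are downloaded from the Hugging Face Hub on first run.
if __name__ == "__main__":
    model = Model()
    model.process_controlnet_pose(
        video_path="motion.mp4",  # hypothetical: a pose/skeleton video on disk
        prompt="an astronaut dancing on the moon",
        save_path="output.gif",   # hypothetical output location
    )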