import spaces
import rembg
import torch
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, AutoPipelineForImage2Image
import cv2
from transformers import pipeline
import numpy as np
from PIL import Image
import gradio as gr

# pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, use_safetensors=True, variant="fp16")
# pipe.to("cuda")

def check_prompt(prompt):
    """Validate that the user supplied a usable prompt.

    Args:
        prompt: Prompt text coming from the UI.

    Raises:
        gr.Error: If ``prompt`` is ``None``, empty, or whitespace only, so the
            message is surfaced in the Gradio UI instead of running a
            generation with no user text.
    """
    # Checking only for None let "" and "   " through; an empty text input
    # presumably arrives as a string rather than None, so reject blank text too.
    if prompt is None or not prompt.strip():
        raise gr.Error("Please enter a prompt!")

# Shared image-to-image pipeline (SDXL refiner), loaded once at import time and
# moved to CUDA lazily inside generate_imgtoimg.
# NOTE(review): kept in float32 — possibly because the stock SDXL VAE is not
# fp16-safe (an fp16-fix VAE is referenced in a commented-out line below);
# confirm before switching to float16.
imagepipe = AutoPipelineForImage2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    torch_dtype=torch.float32,
    use_safetensors=True,
)

# ControlNet weights keyed by the UI choice; each is loaded once, in float16,
# in the order listed (Normal first, then Depth).
controlNet_MAP = {
    choice: ControlNetModel.from_pretrained(repo_id, torch_dtype=torch.float16)
    for choice, repo_id in (
        ("Normal", "fusing/stable-diffusion-v1-5-controlnet-normal"),
        ("Depth", "lllyasviel/sd-controlnet-depth"),
    )
}
# Individual handles kept for any code that refers to them by name.
controlNet_normal = controlNet_MAP["Normal"]
controlNet_depth = controlNet_MAP["Depth"]

# vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, use_safetensors=True)

# Function to generate an image from text using diffusion
@spaces.GPU
def generate_txttoimg(prompt, control_image, controlnet):
    """Generate a background-free product image from text plus a ControlNet hint.

    Args:
        prompt: User prompt; fixed product-photo keywords are appended to it.
        control_image: Source image from which the conditioning map is built.
        controlnet: Key of ``controlNet_MAP`` ("Normal" or "Depth"); selects
            both the ControlNet weights and the matching preprocessor
            (``get_normal`` / ``get_depth``).

    Returns:
        The first generated image after ``rembg.remove``.

    Raises:
        KeyError: If ``controlnet`` is not a key of ``controlNet_MAP``.
    """
    # The leading ", " keeps the suffix from fusing with the user's last word
    # ("red sneaker" previously became "red sneakerno background, ...") and
    # matches the suffix used by generate_imgtoimg.
    prompt += ", no background, side view, minimalist shot, single shoe, no legs, product photo"

    # Resolve the ControlNet up front so an unknown choice fails before any
    # preprocessing or model loading.
    controlnet_model = controlNet_MAP[controlnet]

    if controlnet == "Normal":
        control_image = get_normal(control_image)
    elif controlnet == "Depth":
        control_image = get_depth(control_image)

    # NOTE(review): the SD 1.5 pipeline is rebuilt on every call; caching it
    # would avoid repeated loads, but how a cached CUDA pipeline behaves across
    # @spaces.GPU invocations is unverified.
    textpipe = StableDiffusionControlNetPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        controlnet=controlnet_model,
        torch_dtype=torch.float16,
        safety_checker=None,
    )
    textpipe.to("cuda")

    image = textpipe(prompt, image=control_image).images[0]
    return rembg.remove(image)

@spaces.GPU
def generate_imgtoimg(prompt, init_image, strength=0.5):
    """Re-render ``init_image`` with the shared image-to-image pipeline.

    Args:
        prompt: User prompt; fixed product-photo keywords are appended to it.
        init_image: Image handed to the pipeline as its ``image`` argument.
        strength: Forwarded unchanged as the pipeline's ``strength`` argument.

    Returns:
        The first generated image after ``rembg.remove``.
    """
    styled_prompt = prompt + ", no background, side view, minimalist shot, single shoe, no legs, product photo"

    # The module-level pipeline is moved to the GPU inside the decorated call.
    imagepipe.to("cuda")

    result = imagepipe(styled_prompt, image=init_image, strength=strength)
    first_image = result.images[0]
    return rembg.remove(first_image)



def get_normal(image):
    """Build a ControlNet normal-map hint from a monocular depth estimate.

    Depth is predicted with ``Intel/dpt-hybrid-midas``. The horizontal and
    vertical Sobel gradients of that depth become the x/y components of each
    normal, with a constant z of ``2*pi``. Pixels whose min-max normalised
    depth is below ``bg_threshold`` are treated as background: their
    gradients are zeroed, which yields a flat normal there.

    Args:
        image: Input handed directly to the transformers depth-estimation
            pipeline.

    Returns:
        PIL.Image.Image: uint8 RGB normal map, with unit normals mapped from
        [-1, 1] to [0, 255].
    """
    # NOTE(review): the estimator is rebuilt on every call; hoisting it would
    # save load time, but its device placement under @spaces.GPU is unverified.
    depth_estimator = pipeline("depth-estimation", model="Intel/dpt-hybrid-midas")

    predicted = depth_estimator(image)["predicted_depth"]
    # NOTE(review): presumably a raw (1, H, W) tensor in some transformers
    # releases and an (H, W) map in others — drop the batch axis only when it
    # exists, instead of unconditionally taking row 0.
    if predicted.ndim == 3:
        predicted = predicted[0]
    # detach/cpu keep .numpy() working if the pipeline ran on an accelerator;
    # float32 is a dtype cv2.Sobel accepts for a CV_32F output.
    depth = predicted.detach().cpu().numpy().astype(np.float32, copy=False)

    # Min-max normalise for the background test; a perfectly flat map would
    # otherwise divide 0 by 0 and produce NaNs.
    depth_min = np.min(depth)
    depth_range = np.max(depth) - depth_min
    if depth_range > 0:
        normalized = (depth - depth_min) / depth_range
    else:
        normalized = np.zeros_like(depth)

    bg_threshold = 0.4
    background = normalized < bg_threshold

    grad_x = cv2.Sobel(depth, cv2.CV_32F, 1, 0, ksize=3)
    grad_x[background] = 0

    grad_y = cv2.Sobel(depth, cv2.CV_32F, 0, 1, ksize=3)
    grad_y[background] = 0

    grad_z = np.ones_like(grad_x) * np.pi * 2.0

    normals = np.stack([grad_x, grad_y, grad_z], axis=2)
    # z is a non-zero constant, so every vector has non-zero length.
    normals /= np.sum(normals ** 2.0, axis=2, keepdims=True) ** 0.5
    normals = (normals * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
    return Image.fromarray(normals)

def get_depth(image):
    """Build a 3-channel ControlNet depth hint for ``image``.

    Args:
        image: Input handed directly to the transformers depth-estimation
            pipeline.

    Returns:
        PIL.Image.Image: The pipeline's ``"depth"`` image replicated into three
        identical channels (assumes that image is single-channel, as the
        ``[:, :, None]`` expansion below requires).
    """
    # Pin the checkpoint explicitly. With no ``model`` argument, transformers
    # falls back to a task default (Intel/dpt-large at the time of writing),
    # warns about it, and that default can change between releases.
    # NOTE(review): the estimator is still rebuilt on every call.
    depth_estimator = pipeline("depth-estimation", model="Intel/dpt-large")

    depth = np.array(depth_estimator(image)["depth"])
    depth = depth[:, :, None]
    depth = np.concatenate([depth, depth, depth], axis=2)
    return Image.fromarray(depth)

# def get_canny(image):
#     image = np.array(image)

#     low_threshold = 100
#     high_threshold = 200

#     image = cv2.Canny(image,low_threshold,high_threshold)
#     image = image[:,:,None]
#     image = np.concatenate([image, image, image], axis=2)
#     canny_image = Image.fromarray(image)
#     return canny_image

def update_image(image):
    """Identity callback: hand ``image`` back unchanged (same object)."""
    passthrough = image
    return passthrough