import os
import sys
import cv2
import re
from shared import path_manager
import modules.async_worker as worker
from tqdm import tqdm

from modules.util import generate_temp_filename

from PIL import Image
import imageio.v3 as iio
import numpy as np
import torch
import insightface

from importlib.abc import MetaPathFinder, Loader
from importlib.util import spec_from_loader, module_from_spec

class ImportRedirector(MetaPathFinder):
    def __init__(self, redirect_map):
        self.redirect_map = redirect_map

    def find_spec(self, fullname, path, target=None):
        if fullname in self.redirect_map:
            return spec_from_loader(fullname, ImportLoader(self.redirect_map[fullname]))
        return None

class ImportLoader(Loader):
    def __init__(self, redirect):
        self.redirect = redirect

    def create_module(self, spec):
        return None

    def exec_module(self, module):
        import importlib
        redirected = importlib.import_module(self.redirect)
        module.__dict__.update(redirected.__dict__)

# Set up the redirection
redirect_map = {
    'torchvision.transforms.functional_tensor': 'torchvision.transforms.functional'
}

sys.meta_path.insert(0, ImportRedirector(redirect_map))

import gfpgan
from facexlib.utils.face_restoration_helper import FaceRestoreHelper

# Requirements:
# insightface==0.7.3
# onnxruntime-gpu==1.16.1
# gfpgan==1.3.8
#
# add to settings/powerup.json
#
#  "Faceswap": {
#    "type": "faceswap"
#  }
#
# Models in models/faceswap/
# https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth
# and inswapper_128.onnx from where you can find it

class pipeline:
    pipeline_type = ["faceswap"]

    analyser_model = None
    analyser_hash = ""
    swapper_model = None
    swapper_hash = ""
    gfpgan_model = None

    def parse_gen_data(self, gen_data):
        gen_data["original_image_number"] = gen_data["image_number"]
        gen_data["image_number"] = 1
        gen_data["show_preview"] = False
        return gen_data

    def load_base_model(self, name):
        model_name = "inswapper_128.onnx"
        if not self.swapper_hash == model_name:
            print(f"Loading swapper model: {model_name}")
            model_path = os.path.join(path_manager.model_paths["faceswap_path"], model_name)
            try:
                with open(os.devnull, "w") as sys.stdout:
                    self.swapper_model = insightface.model_zoo.get_model(
                        model_path,
                        download=False,
                        download_zip=False,
                    )
                    self.swapper_hash = model_name
                sys.stdout = sys.__stdout__
            except:
                print(f"Failed loading model! {model_path}")

        model_name = "buffalo_l"
        det_thresh = 0.5
        if not self.analyser_hash == model_name:
            print(f"Loading analyser model: {model_name}")
            try:
                with open(os.devnull, "w") as sys.stdout:
                    self.analyser_model = insightface.app.FaceAnalysis(name=model_name)
                    self.analyser_model.prepare(
                        ctx_id=0, det_thresh=det_thresh, det_size=(640, 640)
                    )
                    self.analyser_hash = model_name
                sys.stdout = sys.__stdout__
            except:
                print(f"Failed loading model! {model_name}")

    def load_gfpgan_model(self):
        if self.gfpgan_model is None:
            channel_multiplier = 2

            model_name = "GFPGANv1.4.pth"
            model_path = os.path.join(path_manager.model_paths["faceswap_path"], model_name)

            # https://github.com/TencentARC/GFPGAN/blob/master/inference_gfpgan.py
            self.gfpgan_model = gfpgan.GFPGANer
            self.gfpgan_model.bg_upsampler = None
            # initialize model
            self.gfpgan_model.device = torch.device(
                "cuda" if torch.cuda.is_available() else "cpu"
            )

            upscale = 2
            self.gfpgan_model.face_helper = FaceRestoreHelper(
                upscale,
                det_model="retinaface_resnet50",
                model_rootpath=path_manager.model_paths["faceswap_path"],
            )
            # face_size=512,
            # crop_ratio=(1, 1),
            # save_ext='png',
            # use_parse=True,
            # device=self.device,

            self.gfpgan_model.gfpgan = gfpgan.GFPGANv1Clean(
                out_size=512,
                num_style_feat=512,
                channel_multiplier=channel_multiplier,
                decoder_load_path=None,
                fix_decoder=False,
                num_mlp=8,
                input_is_latent=True,
                different_w=True,
                narrow=1,
                sft_half=True,
            )

            loadnet = torch.load(model_path)
            if "params_ema" in loadnet:
                keyname = "params_ema"
            else:
                keyname = "params"
            self.gfpgan_model.gfpgan.load_state_dict(loadnet[keyname], strict=True)
            self.gfpgan_model.gfpgan.eval()
            self.gfpgan_model.gfpgan = self.gfpgan_model.gfpgan.to(
                self.gfpgan_model.device
            )

    def load_keywords(self, lora):
        return ""

    def load_loras(self, loras):
        return

    def refresh_controlnet(self, name=None):
        return

    def clean_prompt_cond_caches(self):
        return

    def swap_faces(self, original_image, input_faces, out_faces):
        idx = 0
        for out_face in out_faces:
            original_image = self.swapper_model.get(
                original_image,
                out_face,
                input_faces[idx % len(input_faces)],
                paste_back=True,
            )
            idx += 1
        return original_image

    def restore_faces(self, image):
        self.load_gfpgan_model()

        image_bgr = image[:, :, ::-1]
        _cropped_faces, _restored_faces, gfpgan_output_bgr = self.gfpgan_model.enhance(
            self.gfpgan_model,
            image_bgr,
            has_aligned=False,
            only_center_face=False,
            paste_back=True,
            weight=0.5,
        )
        image = gfpgan_output_bgr[:, :, ::-1]

        return image

    def process(
        self,
        gen_data=None,
        callback=None,
    ):
        worker.add_result(
            gen_data["task_id"],
            "preview",
            (-1, f"Generating ...", None)
        )

        input_image = gen_data["input_image"]
        input_image = cv2.cvtColor(np.asarray(input_image), cv2.COLOR_RGB2BGR)
        input_faces = sorted(
            self.analyser_model.get(input_image), key=lambda x: x.bbox[0]
        )

        prompt = gen_data["prompt"].strip()
        if re.fullmatch("https?://.*\.gif", prompt, re.IGNORECASE) is not None:
            x = iio.immeta(prompt)
            duration = x["duration"]
            loop = x["loop"]
            gif = cv2.VideoCapture(prompt)

            # Swap
            in_imgs = []
            out_imgs = []
            while True:
                ret, frame = gif.read()
                if not ret:
                    break
                in_imgs.append(frame)

            with tqdm(total=len(in_imgs), desc="Groop", unit="frames") as progress:
                i=0
                steps=len(in_imgs)
                for frame in in_imgs:
                    out_faces = sorted(
                        self.analyser_model.get(frame), key=lambda x: x.bbox[0]
                    )
                    frame = self.swap_faces(frame, input_faces, out_faces)
                    out_imgs.append(
                        Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                    )
                    i+=1
                    callback(i, 0, 0, steps, out_imgs[-1])
                    progress.update(1)
            images = generate_temp_filename(
                folder=path_manager.model_paths["temp_outputs_path"], extension="gif"
            )
            os.makedirs(os.path.dirname(images), exist_ok=True)
            out_imgs[0].save(
                images,
                save_all=True,
                append_images=out_imgs[1:],
                optimize=True,
                duration=duration,
                loop=loop,
            )
        else:
            output_image = cv2.imread(gen_data["main_view"])
            if output_image is None:
                images = "html/error.png"
            else:
                output_faces = sorted(
                    self.analyser_model.get(output_image), key=lambda x: x.bbox[0]
                )
                result_image = self.swap_faces(output_image, input_faces, output_faces)
                result_image = self.restore_faces(result_image)
                images = Image.fromarray(cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB))

        return [images]