import argparse
import logging
import os
import re
from typing import Callable

import cv2
import gradio as gr
import nh3
import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor

from lisa_on_cuda import app_logger
from lisa_on_cuda.LISA import LISAForCausalLM
from lisa_on_cuda.llava import conversation as conversation_lib
from lisa_on_cuda.llava.mm_utils import tokenizer_image_token
from lisa_on_cuda.segment_anything.utils.transforms import ResizeLongestSide
from . import constants, utils

# Placeholder values built once at import time; the "no_seg_out" entry is used by the
# inference function as the default image/mask output when no mask is produced.
placeholders = utils.create_placeholder_variables()


def get_device_map_kwargs(device_map="auto", device="cuda"):
    """Build the ``device_map`` keyword argument used when loading the model.

    Args:
        device_map: value used as-is when ``device`` is "cuda"
        device: target device; any value other than "cuda" maps the whole model to it

    Returns:
        dict with the single key "device_map"
    """
    # for non-cuda devices the empty-string key assigns every module to that device
    selected_map = device_map if device == "cuda" else {"": device}
    return {"device_map": selected_map}


def parse_args(args_to_parse, internal_logger=None):
    """Parse the command line style arguments for the LISA chat application.

    Args:
        args_to_parse: list of argument strings (an empty list yields the defaults)
        internal_logger: logger; falls back to ``app_logger`` when None

    Returns:
        argparse.Namespace with the parsed values
    """
    if internal_logger is None:
        internal_logger = app_logger
    internal_logger.info(f"ROOT_PROJECT:{utils.PROJECT_ROOT_FOLDER}, default vis_output:{utils.VIS_OUTPUT}.")
    parser = argparse.ArgumentParser(description="LISA chat")
    parser.add_argument("--version", default="xinlai/LISA-13B-llama2-v1-explanatory")
    parser.add_argument("--vis_save_path", default=str(utils.VIS_OUTPUT), type=str)
    parser.add_argument(
        "--precision",
        default="fp16",
        type=str,
        choices=["fp32", "bf16", "fp16"],
        help="precision for inference",
    )
    parser.add_argument("--image_size", default=1024, type=int, help="image size")
    parser.add_argument("--model_max_length", default=512, type=int)
    parser.add_argument("--lora_r", default=8, type=int)
    parser.add_argument(
        "--vision-tower", default="openai/clip-vit-large-patch14", type=str
    )
    parser.add_argument("--local-rank", default=0, type=int, help="node rank")
    parser.add_argument("--load_in_8bit", action="store_true", default=False)
    # A "store_true" flag whose default is already True can never be switched off.
    # BooleanOptionalAction keeps "--load_in_4bit" / "--use_mm_start_end" working and adds
    # the "--no-load_in_4bit" / "--no-use_mm_start_end" forms to disable them.
    parser.add_argument("--load_in_4bit", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--use_mm_start_end", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument(
        "--conv_type",
        default="llava_v1",
        type=str,
        choices=["llava_v1", "llava_llama_2"],
    )
    return parser.parse_args(args_to_parse)


def get_cleaned_input(input_str, internal_logger=None):
    """Sanitize the user text with ``nh3.clean`` using a fixed allow-list.

    Args:
        input_str: raw text to clean
        internal_logger: logger; falls back to ``app_logger`` when None

    Returns:
        the cleaned string
    """
    log = app_logger if internal_logger is None else internal_logger
    log.info(f"start cleaning of input_str: {input_str}.")

    # allow-lists handed to nh3: permitted tags, per-tag attributes and url schemes
    allowed_tags = {
        "a", "abbr", "acronym", "b", "blockquote", "code",
        "em", "i", "li", "ol", "strong", "ul",
    }
    allowed_attributes = {
        "a": {"href", "title"},
        "abbr": {"title"},
        "acronym": {"title"},
    }
    allowed_schemes = {"http", "https", "mailto"}

    cleaned = nh3.clean(
        input_str,
        tags=allowed_tags,
        attributes=allowed_attributes,
        url_schemes=allowed_schemes,
        link_rel=None,
    )
    log.info(f"cleaned input_str: {cleaned}.")
    return cleaned


def set_image_precision_by_args(input_image, precision):
    """Cast the image tensor to the dtype selected by the ``precision`` string.

    Args:
        input_image: tensor to cast
        precision: "bf16" -> bfloat16, "fp16" -> half, anything else -> float

    Returns:
        the cast tensor
    """
    if precision == "bf16":
        return input_image.bfloat16()
    if precision == "fp16":
        return input_image.half()
    return input_image.float()


def preprocess(
        x,
        pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
        pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
        img_size=1024,
) -> torch.Tensor:
    """Normalize the pixel values, then zero-pad bottom/right up to ``img_size``.

    Args:
        x: image tensor whose last two dimensions are height and width
        pixel_mean: per-channel mean subtracted from ``x``
        pixel_std: per-channel standard deviation ``x`` is divided by
        img_size: side length of the padded output

    Returns:
        the normalized and padded tensor
    """
    logging.info("preprocess started")
    # color normalization
    centered = x - pixel_mean
    normalized = centered / pixel_std
    # padding is appended only after the last column and the last row
    height, width = x.shape[-2:]
    padding = (0, img_size - width, 0, img_size - height)
    padded = F.pad(normalized, padding)
    logging.info("preprocess ended")
    return padded


def load_model_for_causal_llm_pretrained(
        version, torch_dtype, load_in_8bit, load_in_4bit, seg_token_idx, vision_tower,
        internal_logger: logging = None, device_map="auto", device="cuda"
):
    """Load ``LISAForCausalLM`` from a pretrained checkpoint, optionally quantized.

    Args:
        version: model name or path handed to ``from_pretrained``
        torch_dtype: dtype used when no quantization is requested
        load_in_8bit: request 8 bit quantization (only considered when ``load_in_4bit`` is falsy)
        load_in_4bit: request 4 bit quantization (takes precedence over ``load_in_8bit``)
        seg_token_idx: forwarded to ``from_pretrained``
        vision_tower: forwarded to ``from_pretrained``
        internal_logger: logger; falls back to ``app_logger`` when None
        device_map: see ``get_device_map_kwargs``
        device: see ``get_device_map_kwargs``

    Returns:
        the loaded model
    """
    log = app_logger if internal_logger is None else internal_logger
    log.debug(f"prepare kwargs, 4bit:{load_in_4bit}, 8bit:{load_in_8bit}.")
    load_kwargs = {
        "torch_dtype": torch_dtype,
        **get_device_map_kwargs(device_map=device_map, device=device),
    }

    # pick the quantization config first; both variants skip the "visual_model" modules
    quantization_config = None
    if load_in_4bit:
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            llm_int8_skip_modules=["visual_model"],
        )
    elif load_in_8bit:
        quantization_config = BitsAndBytesConfig(
            llm_int8_skip_modules=["visual_model"],
            load_in_8bit=True,
        )

    if quantization_config is not None:
        # a quantized load always overrides the requested dtype with half precision
        load_kwargs["torch_dtype"] = torch.half
        load_kwargs["quantization_config"] = quantization_config

    log.debug(f"start loading model:{version}.")
    # the device_map travels inside load_kwargs: try to avoid CUDA init RuntimeError
    # on ZeroGPU huggingface hardware
    loaded_model = LISAForCausalLM.from_pretrained(
        version,
        low_cpu_mem_usage=True,
        vision_tower=vision_tower,
        seg_token_idx=seg_token_idx,
        **load_kwargs
    )
    log.debug("model loaded!")
    return loaded_model


def gpu_init_zero(internal_logger: logging = None):
    """Log the "GPU init..." message; the function performs no other work.

    Args:
        internal_logger: logger; falls back to ``app_logger`` when None
    """
    active_logger = app_logger if internal_logger is None else internal_logger
    active_logger.info("GPU init...")


def get_model(args_to_parse, internal_logger: logging = None, inference_decorator: Callable = None, device_map="auto", device="cpu", device2="cuda"):
    """Prepare the LISA model together with its tokenizer and image pre-processing helpers.

    Written with ZeroGPU (spaces 0.30.2) in mind.

    Args:
        args_to_parse: parsed arguments (see ``parse_args``); ``seg_token_idx`` is set on it here
        internal_logger: logger; falls back to ``app_logger`` when None. It is also forwarded
            to the helper functions called from here.
        inference_decorator: optional decorator (e.g. the ZeroGPU ``spaces.GPU`` decorator).
            NOTE(review): it is applied to the *results* of the helper calls (the helpers run
            eagerly), not to the helper functions themselves - confirm this is intended.
        device_map: device map forwarded to the model loader
        device: device forwarded to the model loader, default "cpu".
            NOTE(review): only forwarded when ``inference_decorator`` is given; otherwise the
            loader's own default is used - confirm this asymmetry is intended.
        device2: device the vision tower is moved to, default "cuda"

    Returns:
        tuple ``(model, clip_image_processor, tokenizer, transform)``
    """
    if internal_logger is None:
        internal_logger = app_logger
    internal_logger.info(f"starting model preparation, folder creation for path: {args_to_parse.vis_save_path}.")
    if inference_decorator:
        internal_logger.info(f"try explicit gpu init with decorator {inference_decorator.__name__}...")
        inference_decorator(gpu_init_zero(internal_logger=internal_logger))
        internal_logger.info("gpu explicitly initialized!")
    try:
        vis_save_path_exists = os.path.isdir(args_to_parse.vis_save_path)
        # use the injected logger (not the root logger) like the rest of this function
        internal_logger.info(f"vis_save_path_exists:{vis_save_path_exists}.")
        os.makedirs(args_to_parse.vis_save_path, exist_ok=True)
    except PermissionError as pex:
        # best effort: a missing output folder does not prevent loading the model
        internal_logger.info(f"PermissionError: {pex}, folder:{args_to_parse.vis_save_path}.")

    # Create tokenizer
    internal_logger.info(f"creating tokenizer: {args_to_parse.version}, max_length:{args_to_parse.model_max_length}.")
    _tokenizer = AutoTokenizer.from_pretrained(
        args_to_parse.version,
        cache_dir=None,
        model_max_length=args_to_parse.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    _tokenizer.pad_token = _tokenizer.unk_token
    internal_logger.info("tokenizer ok")
    # first token id of "[SEG]", stored on the arguments and forwarded to the model loader
    args_to_parse.seg_token_idx = _tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
    torch_dtype = torch.float32
    if args_to_parse.precision == "bf16":
        torch_dtype = torch.bfloat16
    elif args_to_parse.precision == "fp16":
        torch_dtype = torch.half

    # Create model
    internal_logger.debug(f"start loading causal llm:{args_to_parse.version}...")
    loader_kwargs = {
        "torch_dtype": torch_dtype,
        "load_in_8bit": args_to_parse.load_in_8bit,
        "load_in_4bit": args_to_parse.load_in_4bit,
        "seg_token_idx": args_to_parse.seg_token_idx,
        "vision_tower": args_to_parse.vision_tower,
        "internal_logger": internal_logger,
        # try to avoid CUDA init RuntimeError on ZeroGPU huggingface hardware
        "device_map": device_map,
    }
    if inference_decorator:
        _model = inference_decorator(
            load_model_for_causal_llm_pretrained(args_to_parse.version, device=device, **loader_kwargs)
        )
    else:
        _model = load_model_for_causal_llm_pretrained(args_to_parse.version, **loader_kwargs)
    internal_logger.debug("causal llm loaded!")

    _model.config.eos_token_id = _tokenizer.eos_token_id
    _model.config.bos_token_id = _tokenizer.bos_token_id
    _model.config.pad_token_id = _tokenizer.pad_token_id
    _model.get_model().initialize_vision_modules(_model.get_model().config)

    internal_logger.debug(f"start vision tower:{args_to_parse.vision_tower}...")
    if inference_decorator:
        _model, vision_tower = inference_decorator(
            prepare_model_vision_tower(_model, args_to_parse, torch_dtype, internal_logger=internal_logger)
        )
    else:
        _model, vision_tower = prepare_model_vision_tower(
            _model, args_to_parse, torch_dtype, internal_logger=internal_logger
        )
    internal_logger.debug(f"_model type:{type(_model)}, vision_tower type:{type(vision_tower)}.")
    # set device to "cuda" try to avoid CUDA init RuntimeError on ZeroGPU huggingface hardware
    vision_tower.to(device=device2)
    internal_logger.debug("vision tower loaded, prepare clip image processor...")
    _clip_image_processor = CLIPImageProcessor.from_pretrained(_model.config.vision_tower)
    internal_logger.debug("clip image processor done.")
    _transform = ResizeLongestSide(args_to_parse.image_size)
    internal_logger.debug("start model evaluation...")
    if inference_decorator:
        inference_decorator(_model.eval())
    else:
        _model.eval()
    internal_logger.info("model preparation ok!")
    return _model, _clip_image_processor, _tokenizer, _transform


def prepare_model_vision_tower(_model, args_to_parse, torch_dtype, internal_logger: logging = None):
    """Cast the model and its vision tower according to the requested precision.

    Args:
        _model: model exposing ``get_model().get_vision_tower()``
        args_to_parse: parsed arguments; ``precision``, ``load_in_4bit`` and ``load_in_8bit`` are read
        torch_dtype: dtype the vision tower is cast to before the precision specific handling
        internal_logger: logger; falls back to ``app_logger`` when None

    Returns:
        tuple ``(model, vision_tower)``; the model may be a different object than the one passed in
    """
    log = app_logger if internal_logger is None else internal_logger
    log.debug(f"start vision tower preparation, torch dtype:{torch_dtype}, args_to_parse:{args_to_parse}.")
    _model.get_model().get_vision_tower().to(dtype=torch_dtype)

    precision = args_to_parse.precision
    if precision == "bf16":
        log.debug(f"vision tower precision bf16? {precision}, 1.")
        _model = _model.bfloat16().cuda()
    elif precision == "fp16" and not (args_to_parse.load_in_4bit or args_to_parse.load_in_8bit):
        # un-quantized fp16: the vision tower is detached while deepspeed wraps the model
        # and is re-attached (half precision, on cuda) afterwards
        log.debug(f"vision tower precision fp16? {precision}, 2.")
        detached_tower = _model.get_model().get_vision_tower()
        _model.model.vision_tower = None
        import deepspeed

        engine = deepspeed.init_inference(
            model=_model,
            dtype=torch.half,
            replace_with_kernel_inject=True,
            replace_method="auto",
        )
        _model = engine.module
        _model.model.vision_tower = detached_tower.half().cuda()
    elif precision == "fp32":
        log.debug(f"vision tower precision fp32? {precision}, 3.")
        _model = _model.float().cuda()

    # fetch the tower again: the model object may have been replaced above
    final_tower = _model.get_model().get_vision_tower()
    log.debug("vision tower ok!")
    return _model, final_tower


def get_inference_model_by_args(
        args_to_parse, internal_logger0: logging = None, inference_decorator: Callable = None, device_map="auto", device="cuda"
):
    """Load model and inference function with arguments. Compatible with ZeroGPU (spaces 0.30.2)

    Args:
        args_to_parse: default input arguments
        internal_logger0: logger used while building the model; falls back to ``app_logger`` when None
        inference_decorator: inference decorator (now it's supported and tested ZeroGPU spaces.GPU decorator)
        device_map: device type needed for ZeroGPU cuda hw
        device: device type needed for ZeroGPU cuda hw

    Returns:
        inference function with LISA model (wrapped by ``inference_decorator`` when one is given)
    """
    if internal_logger0 is None:
        internal_logger0 = app_logger
    internal_logger0.info(f"args_to_parse:{args_to_parse}, creating model...")
    # forward the injected logger so the model preparation logs end up in the same place
    model, clip_image_processor, tokenizer, transform = get_model(
        args_to_parse,
        internal_logger=internal_logger0,
        device_map=device_map,
        device=device,
        inference_decorator=inference_decorator
    )
    internal_logger0.info("created model, preparing inference function")
    no_seg_out = placeholders["no_seg_out"]

    def inference(
            input_str: str,
            input_image: str | np.ndarray,
            internal_logger: logging = None,
            embedding_key: str = None
    ):
        """Run the LISA model on a text instruction and an image.

        Args:
            input_str: text instruction; after cleaning it must contain only ASCII letters,
                spaces and the punctuation characters ``,.!?'"``
            input_image: image file path (read with cv2 and converted BGR -> RGB) or an
                already loaded array (used as is)
            internal_logger: logger; falls back to ``app_logger`` when None
            embedding_key: optional key forwarded to ``model.evaluate``; when None it is
                derived from the pre-processed image with ``get_hash_array``

        Returns:
            tuple ``(output_image, output_mask, output_str)``, or a FastAPI ``JSONResponse``
            with status 422 when the input text does not pass validation
        """
        if internal_logger is None:
            internal_logger = app_logger

        # filter out special chars
        input_str = get_cleaned_input(input_str, internal_logger=internal_logger)
        internal_logger.info(f" input_str type: {type(input_str)}, input_image type: {type(input_image)}.")
        internal_logger.info(f"input_str: {input_str}, input_image: {type(input_image)}.")

        # input valid check: the "+" quantifier already rejects the empty string
        if not re.match(r"^[A-Za-z ,.!?\'\"]+$", input_str):
            output_str = f"[Error] Unprocessable Entity input: {input_str}."
            internal_logger.error(output_str)

            from fastapi import status
            from fastapi.responses import JSONResponse

            return JSONResponse(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                content={"msg": "Error - Unprocessable Entity"}
            )

        # Model Inference
        conv = conversation_lib.conv_templates[args_to_parse.conv_type].copy()
        conv.messages = []

        prompt = utils.DEFAULT_IMAGE_TOKEN + "\n" + input_str
        if args_to_parse.use_mm_start_end:
            # wrap the image token between the start/end image tokens
            replace_token = (
                    utils.DEFAULT_IM_START_TOKEN + utils.DEFAULT_IMAGE_TOKEN + utils.DEFAULT_IM_END_TOKEN
            )
            prompt = prompt.replace(utils.DEFAULT_IMAGE_TOKEN, replace_token)

        conv.append_message(conv.roles[0], prompt)
        conv.append_message(conv.roles[1], "")
        prompt = conv.get_prompt()

        internal_logger.info("read and preprocess image.")
        image_np = input_image
        if isinstance(input_image, str):
            image_np = cv2.imread(input_image)
            image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
        original_size_list = [image_np.shape[:2]]
        internal_logger.debug("start clip_image_processor.preprocess")
        image_clip = (
            clip_image_processor.preprocess(image_np, return_tensors="pt")[
                "pixel_values"
            ][0]
            .unsqueeze(0)
            .cuda()
        )
        internal_logger.debug("done clip_image_processor.preprocess")
        internal_logger.info(f"image_clip type: {type(image_clip)}.")
        image_clip = set_image_precision_by_args(image_clip, args_to_parse.precision)

        image = transform.apply_image(image_np)
        resize_list = [image.shape[:2]]

        internal_logger.debug(f"starting preprocess image: {type(image_clip)}.")
        # HWC array -> CHW tensor, normalized/padded, with a leading batch dimension
        image = (
            preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())
            .unsqueeze(0)
            .cuda()
        )
        internal_logger.info(f"done preprocess image:{type(image)}, image_clip type: {type(image_clip)}.")
        image = set_image_precision_by_args(image, args_to_parse.precision)

        input_ids = tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
        input_ids = input_ids.unsqueeze(0).cuda()

        embedding_key = get_hash_array(embedding_key, image, internal_logger)
        internal_logger.info(f"start model evaluation with embedding_key {embedding_key}.")
        output_ids, pred_masks = model.evaluate(
            image_clip,
            image,
            input_ids,
            resize_list,
            original_size_list,
            max_new_tokens=512,
            tokenizer=tokenizer,
            model_logger=internal_logger,
            embedding_key=embedding_key
        )
        internal_logger.info("model evaluation done, start token decoding...")
        # drop the image placeholder tokens before decoding
        output_ids = output_ids[0][output_ids[0] != utils.IMAGE_TOKEN_INDEX]

        text_output = tokenizer.decode(output_ids, skip_special_tokens=False)
        text_output = text_output.replace("\n", "").replace("  ", " ")
        text_output = text_output.split("ASSISTANT: ")[-1]

        internal_logger.info(
            f"token decoding ended,found n {len(pred_masks)} prediction masks, "
            f"text_output type: {type(text_output)}, text_output: {text_output}."
        )
        output_image = no_seg_out
        output_mask = no_seg_out
        # every non-empty mask overwrites the outputs, so the last one wins
        for pred_mask in pred_masks:
            if pred_mask.shape[0] == 0 or pred_mask.shape[1] == 0:
                continue
            pred_mask = pred_mask.detach().cpu().numpy()[0]
            pred_mask_bool = pred_mask > 0
            output_mask = pred_mask_bool.astype(np.uint8) * 255

            # blend the masked pixels half/half with (255, 0, 0)
            output_image = image_np.copy()
            output_image[pred_mask_bool] = (
                    image_np * 0.5
                    + pred_mask_bool[:, :, None].astype(np.uint8) * np.array([255, 0, 0]) * 0.5
            )[pred_mask_bool]

        output_str = f"ASSISTANT: {text_output} ..."
        internal_logger.info(f"output_image type: {type(output_mask)}.")
        return output_image, output_mask, output_str

    internal_logger0.info("prepared inference function.")
    internal_logger0.info(f"inference decorator none? {type(inference_decorator)}.")
    if inference_decorator:
        return inference_decorator(inference)

    return inference


def get_gradio_interface(
        fn_inference: Callable,
        args: argparse.Namespace = None
):
    """Build the Gradio interface wrapping the inference function.

    Args:
        fn_inference: function called by the interface with (text instruction, image file path);
            its three outputs are shown as segmentation image, mask image and text
        args: optional parsed arguments; when given, every attribute is listed in the
            article section ahead of the default article text

    Returns:
        the ``gr.Interface`` instance
    """
    article_and_demo_parameters = constants.article
    if args is not None:
        article_and_demo_parameters = constants.demo_parameters
        args_dict = dict(vars(args))
        for arg_k, arg_v in args_dict.items():
            # print the argument name under the "arg_k" label (it used to repeat the value)
            print(f"arg_k:{arg_k}, arg_v:{arg_v}.")
            article_and_demo_parameters += f" * {arg_k}: {arg_v};\n"

        print(f"args_dict:{args_dict}.")
        print(f"description_and_demo_parameters:{article_and_demo_parameters}.")
        article_and_demo_parameters += "\n\n" + constants.article

    return gr.Interface(
        fn_inference,
        inputs=[
            gr.Textbox(lines=1, placeholder=None, label="Text Instruction"),
            gr.Image(type="filepath", label="Input Image")
        ],
        outputs=[
            gr.Image(type="pil", label="segmentation Output"),
            gr.Image(type="pil", label="mask Output"),
            gr.Textbox(lines=1, placeholder=None, label="Text Output")
        ],
        title=constants.title,
        description=constants.description,
        article=article_and_demo_parameters,
        examples=constants.examples,
        allow_flagging="auto"
    )


def get_hash_array(embedding_key: str, arr: np.ndarray | torch.Tensor, model_logger: logging):
    from base64 import b64encode
    from hashlib import sha256

    model_logger.debug(f"embedding_key {embedding_key} is None? {embedding_key is None}.")
    if embedding_key is None:
        img2hash = arr
        if isinstance(arr, torch.Tensor):
            model_logger.debug("images variable is a Tensor, start converting back to numpy")
            img2hash = arr.numpy(force=True)
            model_logger.debug("done Tensor converted back to numpy")
        model_logger.debug("start image hashing")
        img2hash_fn = sha256(img2hash)
        embedding_key = b64encode(img2hash_fn.digest())
        embedding_key = embedding_key.decode("utf-8")
        model_logger.debug(f"done image hashing, now embedding_key is {embedding_key}.")
    return embedding_key


if __name__ == '__main__':
    # Script entry point: parse an empty argument list (so only the defaults apply)
    # and print the resulting namespace.
    parsed_args = parse_args([])
    print("arrrrg:", parsed_args)