import csv
import json
import os
import random
import subprocess
import time
import traceback
from collections import defaultdict

import gradio as gr
import safetensors
import safetensors.torch
import timm
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from torchvision.transforms import InterpolationMode
from torchvision.transforms import transforms
from transformers import AutoProcessor

from florence2_implementation.modeling_florence2 import Florence2ForConditionalGeneration

# Inference only: disable autograd globally so no gradient graphs are built.
torch.set_grad_enabled(False)

# HF (as of Feb 20, 2025) imposes a 1 GB storage limit, so the JTP tagger
# weights are pulled from another repository at startup.
JTP_WEIGHTS_URL = (
    "https://huggingface.co/RedRocket/JointTaggerProject/resolve/main/"
    "JTP_PILOT2/JTP_PILOT2-e3-vit_so400m_patch14_siglip_384.safetensors"
)
JTP_WEIGHTS_PATH = "JTP_PILOT2-e3-vit_so400m_patch14_siglip_384.safetensors"
if not os.path.exists(JTP_WEIGHTS_PATH):
    # Download to a temporary name and rename only on success, so an
    # interrupted download never leaves a truncated file that a later run would
    # treat as complete. check=True makes a wget failure stop startup here,
    # instead of surfacing later as a confusing weight-loading error.
    _partial_path = JTP_WEIGHTS_PATH + ".part"
    subprocess.run(["wget", "-nv", "-O", _partial_path, JTP_WEIGHTS_URL], check=True)
    os.replace(_partial_path, JTP_WEIGHTS_PATH)


# Tag-category ids (string form, matching the `category` column Pruner reads)
# mapped to readable names. Category 3 (copyright) is not included.
# NOTE(review): this mapping is not referenced elsewhere in this file.
category_id_to_str = {
    "0": "general",
    # 3 copyright
    "4": "character",
    "5": "species",
    "7": "meta",
    "8": "lore",
    "1": "artist",
}
class Pruner:
    """Filter tagger output down to a known tag vocabulary and build captioner prompts.

    The vocabulary comes from a tag-list CSV with (at least) the columns
    ``name``, ``category`` and ``post_count``. A tag is kept only if its
    post count is strictly greater than ``min_post_count`` and its category
    is "0" (general), "5" (species) or "7" (meta). Species tags are also
    tracked separately so they can be written on their own prompt line.
    """

    def __init__(self, path_to_tag_list_csv, min_post_count=20):
        """Load the allowed and species tag sets from *path_to_tag_list_csv*.

        Args:
            path_to_tag_list_csv: Path to the tag-list CSV.
            min_post_count: Tags need strictly more posts than this to be kept.

        Raises:
            ValueError: If ``name``, ``category`` or ``post_count`` is missing
                from the CSV header.
        """
        species_tags = set()
        allowed_tags = set()
        # newline="" is what the csv module requires so quoted fields that
        # contain newlines are parsed correctly.
        with open(path_to_tag_list_csv, "r", newline="", encoding="utf-8") as f:
            reader = csv.reader(f)
            header = next(reader)
            name_index = header.index("name")
            category_index = header.index("category")
            post_count_index = header.index("post_count")
            for row in reader:
                if int(row[post_count_index]) <= min_post_count:
                    continue
                category = row[category_index]
                name = row[name_index]
                if category == "5":
                    species_tags.add(name)
                if category in ("0", "5", "7"):
                    allowed_tags.add(name)

        self.species_tags = species_tags
        self.allowed_tags = allowed_tags

    def _prune_not_allowed_tags(self, raw_tags):
        """Return the set of tags from *raw_tags* that are in the allowed vocabulary."""
        return {tag for tag in raw_tags if tag in self.allowed_tags}

    def _find_and_format_species_tags(self, tag_set):
        """Extract species tags from *tag_set* and format the prompt's species line.

        Returns:
            A tuple ``(formatted_line, species_tags)``. ``formatted_line`` is
            ``"species: <space-separated tags>\\n"``. ``species_tags`` lists the
            species tags in the iteration order of *tag_set*.
        """
        this_specie_tags = [tag for tag in tag_set if tag in self.species_tags]
        formatted_tags = f"species: {' '.join(this_specie_tags)}\n"
        return formatted_tags, this_specie_tags

    def prompt_construction_pipeline_florence2(self, tags, length):
        """Build the captioner prompt from raw tags.

        Args:
            tags: Space-separated tag string, or any iterable of tag strings.
                The caller's object is never modified.
            length: Value written on the prompt's ``length:`` line.

        Returns:
            The prompt: non-species tags, a species line, the length line, and
            the ``STYLE1 FURRY CAPTION:`` marker. Tags are randomly shuffled,
            deduplicated, and restricted to the allowed vocabulary.
        """
        # Work on a copy so shuffling never mutates a list owned by the caller.
        tag_list = tags.split(" ") if isinstance(tags, str) else list(tags)
        random.shuffle(tag_list)
        # Deduplicate while keeping the shuffled order; collapsing into a set
        # here (as before) silently discarded the shuffle.
        kept_tags = [t for t in dict.fromkeys(tag_list) if t in self.allowed_tags]
        formatted_species_tags, this_specie_tags = self._find_and_format_species_tags(kept_tags)
        species = set(this_specie_tags)
        non_species_tags = [t for t in kept_tags if t not in species]
        prompt = f"{' '.join(non_species_tags)}\n{formatted_species_tags}\nlength: {length}\n\nSTYLE1 FURRY CAPTION:"
        return prompt



class Fit(torch.nn.Module):
    """Resize a PIL image to fit inside ``bounds`` while keeping its aspect ratio.

    Args:
        bounds: ``(height, width)`` limit, or a single int for a square limit.
        interpolation: Resampling mode passed to ``TF.resize``.
        grow: If False, images already inside the bounds are never upscaled.
        pad: If not None, the result is padded (centred) up to exactly
            ``bounds`` using this fill value. If None, no padding is added and
            the output may be smaller than ``bounds`` in one dimension.
    """

    def __init__(
            self,
            bounds: tuple[int, int] | int,
            interpolation=InterpolationMode.LANCZOS,
            grow: bool = True,
            pad: float | None = None
    ):
        super().__init__()

        self.bounds = (bounds, bounds) if isinstance(bounds, int) else bounds
        self.interpolation = interpolation
        self.grow = grow
        self.pad = pad

    def forward(self, img: Image.Image) -> Image.Image:
        """Return *img* resized (and optionally padded) to fit ``self.bounds``."""
        wimg, himg = img.size  # PIL reports (width, height)
        hbound, wbound = self.bounds

        hscale = hbound / himg
        wscale = wbound / wimg

        if not self.grow:
            hscale = min(hscale, 1.0)
            wscale = min(wscale, 1.0)

        scale = min(hscale, wscale)
        if scale == 1.0:
            # Already fits: skip the resize but still fall through to padding.
            # Returning early here used to ignore `pad` for such images.
            hnew, wnew = himg, wimg
        else:
            # Clamp to guard against rounding overshooting the bound by a pixel.
            hnew = min(round(himg * scale), hbound)
            wnew = min(round(wimg * scale), wbound)
            img = TF.resize(img, (hnew, wnew), self.interpolation)

        if self.pad is None:
            return img

        hpad = hbound - hnew
        wpad = wbound - wnew

        # Split the padding so the image is centred; any odd pixel goes to the
        # bottom/right.
        tpad = hpad // 2
        bpad = hpad - tpad

        lpad = wpad // 2
        rpad = wpad - lpad

        # TF.pad takes (left, top, right, bottom).
        return TF.pad(img, (lpad, tpad, rpad, bpad), self.pad)

    def __repr__(self) -> str:
        return (
                f"{self.__class__.__name__}(" +
                f"bounds={self.bounds}, " +
                f"interpolation={self.interpolation.value}, " +
                f"grow={self.grow}, " +
                f"pad={self.pad})"
        )


class CompositeAlpha(torch.nn.Module):
    def __init__(
            self,
            background: tuple[float, float, float] | float,
    ):
        super().__init__()

        self.background = (background, background, background) if isinstance(background, float) else background
        self.background = torch.tensor(self.background).unsqueeze(1).unsqueeze(2)

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        if img.shape[-3] == 3:
            return img

        alpha = img[..., 3, None, :, :]

        img[..., :3, :, :] *= alpha

        background = self.background.expand(-1, img.shape[-2], img.shape[-1])
        if background.ndim == 1:
            background = background[:, None, None]
        elif background.ndim == 2:
            background = background[None, :, :]

        img[..., :3, :, :] += (1.0 - alpha) * background
        return img[..., :3, :, :]

    def __repr__(self) -> str:
        return (
                f"{self.__class__.__name__}(" +
                f"background={self.background})"
        )


class GatedHead(torch.nn.Module):
    """Classification head whose per-class scores are modulated by a learned gate.

    A single linear layer produces ``2 * num_classes`` outputs. The first half,
    passed through a sigmoid, is the class score. The second half, also passed
    through a sigmoid, is a gate that multiplies it. The output has
    ``num_classes`` values in [0, 1] along the last dimension.

    Args:
        num_features: Width of the incoming feature vector (last dim of input).
        num_classes: Number of output classes.
    """

    def __init__(self,
        num_features: int,
        num_classes: int
    ):
        super().__init__()
        self.num_classes = num_classes
        self.linear = torch.nn.Linear(num_features, num_classes * 2)

        self.act = torch.nn.Sigmoid()
        self.gate = torch.nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear(x)
        # Split along the last (feature) dim, where Linear put its outputs.
        # Slicing dim 1 was only correct for 2-D (batch, features) inputs.
        return self.act(x[..., :self.num_classes]) * self.gate(x[..., self.num_classes:])

# Captioning model: the Florence-2 checkpoint published under `model_id`,
# switched to inference mode. No explicit device placement is done here.
model_id = "lodestone-horizon/furrence2-large"
model = Florence2ForConditionalGeneration.from_pretrained(model_id)
model.eval()
# The processor (image preprocessing + tokenizer) comes from the local
# implementation directory, not from `model_id`.
processor = AutoProcessor.from_pretrained("./florence2_implementation/", trust_remote_code=True)


# Active tag implications, grouped as consequent_name -> [antecedent_name, ...].
# NOTE(review): nothing in this file reads `tree`; loading it only costs
# startup time unless a caller elsewhere uses it.
tree = defaultdict(list)
# newline='' is required by the csv module so quoted fields with embedded
# newlines parse correctly; the encoding is explicit so the result does not
# depend on the platform default.
with open('tag_implications-2024-05-05.csv', 'rt', newline='', encoding='utf-8') as csvfile:
    reader = csv.DictReader(csvfile)
    for row in reader:
        if row["status"] == "active":
            tree[row["consequent_name"]].append(row["antecedent_name"])


# Header HTML rendered at the top of the Gradio page.
title = """<h1 align="center">Furrence2 Captioner Demo</h1>"""
# Removed stray closing </a> tags that had no matching opening tag.
description = (
    """<br> The captioner is prompted with tags from the JTP Pilot2 tagger. You may use hand-curated tags to get better results.
    <br> This demo is running on CPU. For faster inference without waiting in the queue, you may duplicate the space and upgrade to GPU in settings."""
)
# Preprocessing for the JTP Pilot2 tagger input:
#   1. aspect-preserving resize to fit within 384x384 (no padding yet),
#   2. PIL -> float tensor (values in [0, 1]),
#   3. flatten any alpha channel onto a mid-grey (0.5) background,
#   4. map [0, 1] to [-1, 1] via mean/std 0.5, in place,
#   5. centre-crop to exactly 384x384 (torchvision pads smaller inputs with 0,
#      i.e. grey after step 4).
tagger_transform = transforms.Compose(
    [
        Fit(bounds=(384, 384)),
        transforms.ToTensor(),
        CompositeAlpha(background=0.5),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3, inplace=True),
        transforms.CenterCrop(size=(384, 384)),
    ]
)

# Tagger predictions with probability strictly above this are kept (see generate_prompt).
THRESHOLD = 0.2
# Output classes of the JTP Pilot2 checkpoint. Class indices are used to index
# the tag list loaded from JTP_PILOT2_tags.json, so the two are expected to agree.
JTP_NUM_CLASSES = 9083
tagger_model = timm.create_model(
    "vit_so400m_patch14_siglip_384.webli",
    pretrained=False,
    num_classes=JTP_NUM_CLASSES,
)
# Replace timm's plain linear head with the gated head the checkpoint expects.
# Reuse that head's input width directly instead of guessing it as
# min(weight.shape), which only works while num_classes exceeds the feature width.
tagger_model.head = GatedHead(tagger_model.head.in_features, JTP_NUM_CLASSES)
safetensors.torch.load_model(tagger_model, "JTP_PILOT2-e3-vit_so400m_patch14_siglip_384.safetensors")

tagger_model.eval()

# Tag vocabulary of the tagger: generate_prompt maps output class i to
# allowed_tags[i]. This relies on the JSON object's key order matching the
# tagger's class order. NOTE(review): confirm against the checkpoint.
with open("JTP_PILOT2_tags.json", "r", encoding="utf-8") as file:
    tags = json.load(file)  # type: dict
# The former `for idx, tag in enumerate(...): allowed_tags[idx] = tag` loop was
# a no-op and has been removed.
allowed_tags = list(tags)

# Tag-list CSV used to prune tagger output and assemble the captioner prompt.
pruner = Pruner("tags-2024-05-05.csv")

def generate_prompt(image, expected_caption_length):
    """Tag *image* with the JTP tagger and turn the tags into a captioner prompt.

    Args:
        image: A PIL image. It is converted to RGBA before preprocessing.
        expected_caption_length: Value written on the prompt's ``length:`` line.

    Returns:
        The prompt string built by ``pruner.prompt_construction_pipeline_florence2``
        from every tag whose probability is strictly above ``THRESHOLD``.
    """
    # Only module-level names are read here, so no `global` statement is needed.
    tagger_input = tagger_transform(image.convert('RGBA')).unsqueeze(0)
    # The batch holds exactly one image; take its row of per-class probabilities.
    probabilities = tagger_model(tagger_input)[0]
    indices = torch.where(probabilities > THRESHOLD)[0]
    # Most confident tags first (the pruner shuffles them again afterwards).
    order = torch.argsort(probabilities[indices], descending=True)
    final_tags = [allowed_tags[int(i)] for i in indices[order]]
    return pruner.prompt_construction_pipeline_florence2(" ".join(final_tags), expected_caption_length)


def inference_caption(image, expected_caption_length, seq_len=512,):
    """Caption *image*: tag it, build a prompt, and decode with the Florence-2 model.

    Args:
        image: PIL image from the Gradio input, or None if nothing was uploaded.
        expected_caption_length: Forwarded to the prompt's ``length:`` line.
        seq_len: Maximum number of new tokens to generate.

    Returns:
        The decoded caption. If no image was given or any step failed, a short
        message string is returned instead; on failure the full traceback is
        printed to the console.
    """
    if image is None:
        return "Please upload an image first."
    try:
        # Tagging is inside the try as well, so a tagger failure is reported
        # like a captioning failure instead of raising into Gradio.
        start_time = time.perf_counter()
        prompt_input = generate_prompt(image, expected_caption_length)
        print(f"Finished tagging in {time.perf_counter() - start_time:.3f} seconds")

        pixel_values = processor.image_processor(image, return_tensors="pt", )["pixel_values"]
        # Don't pass padding="max_length", truncation or max_length here;
        # they cause problems during inference.
        encoder_inputs = processor.tokenizer(
            text=prompt_input,
            return_tensors="pt",
        )

        start_time = time.perf_counter()
        generated_ids = model.generate(
            input_ids=encoder_inputs["input_ids"],
            attention_mask=encoder_inputs["attention_mask"],
            pixel_values=pixel_values,
            max_new_tokens=seq_len,
            early_stopping=False,
            do_sample=False,
            num_beams=3,
        )
        print(f"Finished captioning in {time.perf_counter() - start_time:.3f} seconds")
        return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    except Exception as e:
        # UI boundary: keep the app alive, but log enough to debug the failure.
        print("error message:", e)
        traceback.print_exc()
        return "An error occurred."


def main():
    """Build the Gradio interface for the captioner and launch it.

    The left column holds the inputs (image, output cut-off length, expected
    caption length). The right column holds the button and the caption box.
    Clicking the button calls
    ``inference_caption(image, expected_length, seq_len)``.
    """
    with gr.Blocks() as iface:
        gr.Markdown(title)
        gr.Markdown(description)

        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(type="pil")
                seq_len = gr.Number(
                    label="Output Cutoff Length",
                    value=512,
                    precision=0,
                    interactive=True,
                )
                expected_length = gr.Number(
                    label="Expected Caption Length",
                    value=100,
                    minimum=50,
                    maximum=200,
                    precision=0,
                    interactive=True,
                )

            with gr.Column(scale=1), gr.Column():
                caption_button = gr.Button(
                    value="Caption it!",
                    variant="primary",
                    interactive=True,
                )
                caption_output = gr.Textbox(label="Caption Output", lines=1)
                # Input order must match inference_caption's positional parameters.
                caption_button.click(
                    fn=inference_caption,
                    inputs=[image_input, expected_length, seq_len],
                    outputs=[caption_output],
                )

    iface.launch(share=False)


if __name__ == "__main__":
    main()