from PIL import Image
import numpy as np
from matplotlib import colormaps
import msgspec
import torch
from torchvision.transforms import transforms
from torchvision.transforms import InterpolationMode
import torchvision.transforms.functional as TF
import timm
from timm.models import VisionTransformer
import safetensors.torch
import gradio as gr
import spaces
from huggingface_hub import hf_hub_download

class Fit(torch.nn.Module):
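    """Resize a PIL image to fit within `bounds` while preserving aspect ratio.

    When `grow` is False, images smaller than the bounds are left untouched; when
    `pad` is not None, the result is padded out to the exact bounds with that fill value.
    """
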
    def __init__(
        self,
        bounds: tuple[int, int] | int,
        interpolation = InterpolationMode.LANCZOS,
        grow: bool = True,
        pad: float | None = None
    ):
        super().__init__()

        self.bounds = (bounds, bounds) if isinstance(bounds, int) else bounds
        self.interpolation = interpolation
        self.grow = grow
        self.pad = pad

    def forward(self, img: Image.Image) -> Image.Image:
        wimg, himg = img.size
        hbound, wbound = self.bounds

        hscale = hbound / himg
        wscale = wbound / wimg

        if not self.grow:
            hscale = min(hscale, 1.0)
            wscale = min(wscale, 1.0)

        scale = min(hscale, wscale)
        if scale == 1.0:
            return img

        hnew = min(round(himg * scale), hbound)
        wnew = min(round(wimg * scale), wbound)

        img = TF.resize(img, (hnew, wnew), self.interpolation)

        if self.pad is None:
            return img

        hpad = hbound - hnew
        wpad = wbound - wnew

        tpad = hpad // 2
        bpad = hpad - tpad

        lpad = wpad // 2
        rpad = wpad - lpad

        return TF.pad(img, (lpad, tpad, rpad, bpad), self.pad)

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(" +
            f"bounds={self.bounds}, " +
            f"interpolation={self.interpolation.value}, " +
            f"grow={self.grow}, " +
            f"pad={self.pad})"
        )

class CompositeAlpha(torch.nn.Module):
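    """Composite an RGBA image tensor over a constant background color and return a
    3-channel RGB tensor; 3-channel inputs pass through unchanged."""
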
    def __init__(
        self,
        background: tuple[float, float, float] | float,
    ):
        super().__init__()

        self.background = (background, background, background) if isinstance(background, float) else background
        self.background = torch.tensor(self.background).unsqueeze(1).unsqueeze(2)

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        if img.shape[-3] == 3:
            return img

        alpha = img[..., 3, None, :, :]

        img[..., :3, :, :] *= alpha

        background = self.background.expand(-1, img.shape[-2], img.shape[-1])
        if background.ndim == 1:
            background = background[:, None, None]
        elif background.ndim == 2:
            background = background[None, :, :]

        img[..., :3, :, :] += (1.0 - alpha) * background
        return img[..., :3, :, :]

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(" +
            f"background={self.background})"
        )

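# Preprocessing: fit within 384x384, convert to a tensor, flatten any alpha channel
# over 50% grey, normalize to [-1, 1], and center-crop/pad to the final 384x384 input.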
transform = transforms.Compose([
    Fit((384, 384)),
    transforms.ToTensor(),
    CompositeAlpha(0.5),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    transforms.CenterCrop((384, 384)),
])

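# SigLIP so400m ViT backbone (384px input) with a 9083-way multi-label head; the
# trained weights are pulled from the RedRocket JointTaggerProject JTP_PILOT checkpoint.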
model = timm.create_model(
    "vit_so400m_patch14_siglip_384.webli",
    pretrained=False,
    num_classes=9083,
) # type: VisionTransformer

cached_model = hf_hub_download(
    repo_id="RedRocket/JointTaggerProject",
    subfolder="JTP_PILOT",
    filename="JTP_PILOT-e4-vit_so400m_patch14_siglip_384.safetensors"
)

safetensors.torch.load_model(model, cached_model)
model.eval()

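# tagger_tags.json maps tag name -> model output index; underscores are replaced with
# spaces for display. Indexing predictions through allowed_tags assumes the JSON
# entries are ordered by their index values.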
with open("tagger_tags.json", "rb") as file:
    tags = msgspec.json.decode(file.read(), type=dict[str, int])

for tag in list(tags.keys()):
    tags[tag.replace("_", " ")] = tags.pop(tag)

allowed_tags = list(tags.keys())

@spaces.GPU(duration=5)
def run_classifier(image: Image.Image, threshold):
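    """Tag an image: returns the tag string, the thresholded score dict, the
    RGBA-converted image (kept for CAM rendering), and the full sorted score dict."""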
    img = image.convert('RGBA')
    tensor = transform(img).unsqueeze(0)

    with torch.no_grad():
        logits = model(tensor)
        probits = torch.sigmoid(logits[0])
        values, indices = probits.cpu().topk(250)

    tag_score = {allowed_tags[idx.item()]: val.item() for idx, val in zip(indices, values)}

    sorted_tag_score = dict(sorted(tag_score.items(), key=lambda item: item[1], reverse=True))

    return *create_tags(threshold, sorted_tag_score), img, sorted_tag_score

def create_tags(threshold, sorted_tag_score: dict):
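    """Filter the cached scores by threshold and build the comma-separated tag string."""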
    filtered_tag_score = {key: value for key, value in sorted_tag_score.items() if value > threshold}
    text_no_impl = ", ".join(filtered_tag_score.keys())
    return text_no_impl, filtered_tag_score

def clear_image():
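    """Reset the outputs and cached state when the source image is removed."""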
    return "", {}, None, {}, None

@spaces.GPU(duration=5)
def cam_inference(img, threshold, alpha, evt: gr.SelectData):
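    """Grad-CAM-style visualization for the clicked tag: forward/backward hooks on the
    final norm layer capture patch-token activations and gradients, the gradients are
    averaged into per-channel weights, and the weighted activations form a 27x27 heatmap."""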
    target_tag_index = tags[evt.value]
    tensor = transform(img).unsqueeze(0)

    gradients = {}
    activations = {}

    def hook_forward(module, input, output):
        activations['value'] = output

    def hook_backward(module, grad_in, grad_out):
        gradients['value'] = grad_out[0]

    handle_forward = model.norm.register_forward_hook(hook_forward)
    handle_backward = model.norm.register_full_backward_hook(hook_backward)

    logits = model(tensor)
    probits = torch.sigmoid(logits[0])
 
    model.zero_grad()
    probits[target_tag_index].backward(retain_graph=True)

    with torch.no_grad():
        patch_grads = gradients.get('value')
        patch_acts = activations.get('value')

        # Average each channel's gradient over the patch tokens to get per-channel weights
        weights = torch.mean(patch_grads, dim=1).squeeze(0)

        # Weight the patch activations by channel importance and keep only positive contributions
        cam_1d = torch.einsum('pe,e->p', patch_acts.squeeze(0), weights)
        cam_1d = torch.relu(cam_1d)

        # Reshape the per-token scores into the 27x27 grid of patch positions
        cam = cam_1d.reshape(27, 27).detach().cpu().numpy()

    handle_forward.remove()
    handle_backward.remove()

    return create_cam_visualization_pil(img, cam, alpha=alpha, vis_threshold=threshold), cam

def create_cam_visualization_pil(image_pil, cam, alpha=0.6, vis_threshold=0.2):
    """
    Overlays CAM on image and returns a PIL image.
    Args:
        image_pil: PIL Image (RGBA)
        cam: 2D numpy array (activation map)
        alpha: float, blending factor
        vis_threshold: float, minimum normalized CAM value to show color
    Returns:
        PIL.Image.Image with overlay
    """
    if cam is None:
        return image_pil
    w, h = image_pil.size
    size = max(w, h)

    # Normalize CAM to [0, 1] (epsilon guards against an all-zero map)
    cam = cam - cam.min()
    cam = cam / (cam.max() + 1e-8)

    # Create heatmap using matplotlib colormap
    colormap = colormaps["inferno"]
    cam_rgb = colormap(cam)[:, :, :3]  # RGB

    # Create alpha channel
    cam_alpha = (cam >= vis_threshold).astype(np.float32) * alpha  # Alpha mask
    cam_rgba = np.dstack((cam_rgb, cam_alpha))  # Shape: (H, W, 4)
    
    # Coarse upscale for CAM output -- keeps "blocky" effect that is truer to what is measured
    cam_pil = Image.fromarray((cam_rgba * 255).astype(np.uint8), mode="RGBA")
    cam_pil = cam_pil.resize((216,216), resample=Image.Resampling.NEAREST)

    # Model uses padded image as input, this matches attention map to input image aspect ratio
    cam_pil = cam_pil.resize((size, size), resample=Image.Resampling.BICUBIC)
    cam_pil = transforms.CenterCrop((h, w))(cam_pil)

    # Composite over original
    composite = Image.alpha_composite(image_pil, cam_pil)

    return composite

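# Custom CSS: hide the Label component's `.output-class` element, give the CAM
# threshold slider an inferno-colormap gradient track, and keep the image container
# square with the image letterboxed inside it.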
custom_css = """
.output-class { display: none; }
.inferno-slider input[type=range] {
    background: linear-gradient(to right,
        #000004, #1b0c41, #4a0c6b, #781c6d,
        #a52c60, #cf4446, #ed6925, #fb9b06,
        #f7d13d, #fcffa4
    ) !important;
    background-size: 100% 100% !important;
}
#image_container-image {
    width: 100%;
    aspect-ratio: 1 / 1;
    max-height: 100%;
}
#image_container img {
    object-fit: contain !important;
}
"""

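# UI layout: source image and CAM controls on the left; tag threshold, tag string,
# and per-tag predictions on the right.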
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("## Joint Tagger Project: JTP-PILOT Demo")
    original_image_state = gr.State() # stash a copy of the input image
    sorted_tag_score_state = gr.State(value={}) # stash the sorted tag scores so sliders can refilter without re-running the model
    cam_state = gr.State()
    with gr.Row():
        with gr.Column():
            image = gr.Image(label="Source", sources=['upload', 'clipboard'], type='pil', show_label=False, elem_id="image_container")
            cam_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.40, label="CAM Threshold", elem_classes="inferno-slider")
            alpha_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.60, label="CAM Alpha")
        with gr.Column():
            threshold_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.20, label="Tag Threshold")
            tag_string = gr.Textbox(label="Tag String")
            label_box = gr.Label(label="Tag Predictions", num_top_classes=250, show_label=False)

    gr.Markdown("""
    This tagger is designed for use on furry images (though it may very well work on out-of-distribution images, potentially with funny results). A threshold of 0.2 is recommended. Lower thresholds often turn up more valid tags, but can also result in some amount of hallucinated tags.

    This tagger is the result of joint efforts between members of the RedRocket team, with distinctions given to Thessalo for creating the foundation for this project with his efforts, RedHotTensors for redesigning the process into a second-order method that models information expectation, and drhead for dataset prep, creation of training code, and supervision of training runs.

    Thanks to metal63 for providing initial code for attention visualization (click a tag in the tag list to try it out!).

    Special thanks to Minotoro at frosting.ai for providing the compute power for this project.
    """)

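    # Event wiring: uploading an image runs the classifier; the tag threshold slider
    # refilters the cached scores; clicking a tag in the predictions computes its CAM
    # overlay; the CAM threshold and alpha sliders re-render the cached CAM.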
    image.upload(
        fn=run_classifier,
        inputs=[image, threshold_slider],
        outputs=[tag_string, label_box, original_image_state, sorted_tag_score_state],
        show_progress='minimal'
    )

    image.clear(
        fn=clear_image,
        inputs=[],
        outputs=[tag_string, label_box, original_image_state, sorted_tag_score_state, cam_state]
    )

    threshold_slider.input(
        fn=create_tags,
        inputs=[threshold_slider, sorted_tag_score_state],
        outputs=[tag_string, label_box],
        show_progress='hidden'
    )

    label_box.select(
        fn=cam_inference,
        inputs=[original_image_state, cam_slider, alpha_slider],
        outputs=[image, cam_state],
        show_progress='minimal'
    )

    cam_slider.input(
        fn=create_cam_visualization_pil,
        inputs=[original_image_state, cam_state, alpha_slider, cam_slider],
        outputs=[image],
        show_progress='hidden'
    )

    alpha_slider.input(
        fn=create_cam_visualization_pil,
        inputs=[original_image_state, cam_state, alpha_slider, cam_slider],
        outputs=[image],
        show_progress='hidden'
    )

if __name__ == "__main__":
    demo.launch()