from typing import Union, List

import gradio as gr
import torch
from pytorch_lightning.utilities.types import EPOCH_OUTPUT

import numpy as np
from PIL import Image
import albumentations as A
import albumentations.pytorch as al_pytorch
import torchvision
from pl_bolts.models.gans import Pix2Pix

""" Class """


class OverpoweredPix2Pix(Pix2Pix):
    """pl_bolts Pix2Pix with validation-loss logging and sample-image grids added."""

    def validation_step(self, batch, batch_idx):
        """Log the PatchGAN discriminator and generator losses on the validation set."""
        real, condition = batch
        with torch.no_grad():
            loss = self._disc_step(real, condition)
            self.log("val_PatchGAN_loss", loss)

            loss = self._gen_step(real, condition)
            self.log("val_generator_loss", loss)

        # Keep the batch around so validation_epoch_end can render a sample grid.
        return {
            'sketch': real,
            'colour': condition,
        }

    def validation_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None:
        """Log a (sketch, ground-truth colour, generated colour) grid for the first val batch."""
        sketch = outputs[0]['sketch']
        colour = outputs[0]['colour']
        with torch.no_grad():
            gen_coloured = self.gen(sketch)
        grid_image = torchvision.utils.make_grid(
            [sketch[0], colour[0], gen_coloured[0]],
            normalize=True,
        )
        self.logger.experiment.add_image(f'Image Grid {self.current_epoch}', grid_image, self.current_epoch)
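
# A hedged sketch of how this class would be trained (assumes a hypothetical
# `SketchColourDataset` yielding (sketch, colour) tensor pairs, and
# `import pytorch_lightning as pl`; neither is part of this file):
#
#     train_loader = torch.utils.data.DataLoader(SketchColourDataset("train/"), batch_size=16)
#     val_loader = torch.utils.data.DataLoader(SketchColourDataset("val/"), batch_size=16)
#     pix2pix = OverpoweredPix2Pix(in_channels=3, out_channels=3)
#     pl.Trainer(max_epochs=100).fit(pix2pix, train_loader, val_loader)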


""" Load the model """
model_checkpoint_path = "model/lightning_bolts_model/epoch=99-step=89000.ckpt"
# model_checkpoint_path = "model/pix2pix_lightning_model/version_0/checkpoints/epoch=199-step=355600.ckpt"
# model_checkpoint_path = "model/pix2pix_lightning_model/gen.pth"

model = OverpoweredPix2Pix.load_from_checkpoint(
    model_checkpoint_path
)

model_chk = torch.load(
    model_checkpoint_path, map_location=torch.device('cpu')
)
# model = gen().load_state_dict(model_chk)

model.eval()
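
# If only the generator's weights had been exported (see the commented-out
# gen.pth path above), a plain state-dict load on the generator submodule
# would work instead -- a hedged sketch, assuming the file stores the
# generator's state dict directly:
#
#     gen_state = torch.load("model/pix2pix_lightning_model/gen.pth", map_location="cpu")
#     model.gen.load_state_dict(gen_state)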


def predict(img: Image.Image):
    """Colour a sketch with the trained generator and return the result."""
    image = np.asarray(img)
    # Inference-time preprocessing: resize to the training resolution and
    # normalize each channel to [-1, 1], matching the training transforms.
    inference_transform = A.Compose([
        A.Resize(width=256, height=256),
        A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
        al_pytorch.ToTensorV2(),
    ])
    inference_img = inference_transform(
        image=image
    )['image'].unsqueeze(0)  # add a batch dimension: (1, 3, 256, 256)
    with torch.no_grad():
        result = model.gen(inference_img)
        # normalize=True rescales the generator output from [-1, 1] to [0, 255]
        torchvision.utils.save_image(result, "inference_image.png", normalize=True)

    """
    result_grid = torchvision.utils.make_grid(
        [result[0]],
        normalize=True
    )
    # plt.imsave("coloured_grid.png", (result_grid.permute(1,2,0).detach().numpy()*255).astype(int))
    torchvision.utils.save_image(
        result_grid, "coloured_image.png", normalize=True
    )
    """
    return "inference_image.png"  # 'coloured_image.png',


iface = gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(type="pil"),
    examples=[
        "examples/thesis_test.png",
        "examples/thesis_test2.png",
        "examples/thesis1.png",
        "examples/thesis4.png",
        "examples/thesis5.png",
        "examples/thesis6.png",
    ],
    outputs=gr.outputs.Image(type="pil"),
    title="Colour your sketches!",
    description="Upload a sketch and the conditional GAN will colour it for you!",
    article="WIP repo lives here - https://github.com/nmud19/thesisGAN",
)
iface.launch()
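
# If a shareable public link is needed (e.g. when running from a notebook),
# launch() also accepts share=True:
#
#     iface.launch(share=True)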