# Nikhil Mudhalwadkar
# cloned
# 0fb096b
from typing import Union, List
import gradio as gr
import matplotlib
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning.utilities.types import EPOCH_OUTPUT
matplotlib.use('Agg')
import numpy as np
from PIL import Image
import albumentations as A
import albumentations.pytorch as al_pytorch
import torchvision
from pl_bolts.models.gans import Pix2Pix
""" Class """
class OverpoweredPix2Pix(Pix2Pix):
def validation_step(self, batch, batch_idx):
""" Validation step """
real, condition = batch
with torch.no_grad():
loss = self._disc_step(real, condition)
self.log("val_PatchGAN_loss", loss)
loss = self._gen_step(real, condition)
self.log("val_generator_loss", loss)
return {
'sketch': real,
'colour': condition
}
def validation_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None:
sketch = outputs[0]['sketch']
colour = outputs[0]['colour']
with torch.no_grad():
gen_coloured = self.gen(sketch)
grid_image = torchvision.utils.make_grid(
[
sketch[0], colour[0], gen_coloured[0],
],
normalize=True
)
self.logger.experiment.add_image(f'Image Grid {str(self.current_epoch)}', grid_image, self.current_epoch)
""" Load the model """
model_checkpoint_path = "model/lightning_bolts_model/epoch=99-step=89000.ckpt"
# model_checkpoint_path = "model/pix2pix_lightning_model/version_0/checkpoints/epoch=199-step=355600.ckpt"
# model_checkpoint_path = "model/pix2pix_lightning_model/gen.pth"
model = OverpoweredPix2Pix.load_from_checkpoint(
model_checkpoint_path
)
model_chk = torch.load(
model_checkpoint_path, map_location=torch.device('cpu')
)
# model = gen().load_state_dict(model_chk)
model.eval()
def greet(name):
    """Build a greeting string of the form ``Hello <name>!!``.

    Not wired into the Gradio interface in this file.
    """
    return "".join(("Hello ", name, "!!"))
# Preprocessing for every input sketch: resize to 256x256, map pixel values
# from [0, 255] to [-1, 1] via (x - 0.5*255) / (0.5*255), then HWC -> CHW tensor.
# Built once at import instead of on every request.
_INFERENCE_TRANSFORM = A.Compose([
    A.Resize(width=256, height=256),
    A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
    al_pytorch.ToTensorV2(),
])


def predict(img: Image.Image) -> str:
    """Colour a sketch with the Pix2Pix generator and save the result as PNG.

    Args:
        img: Input sketch, normally a PIL image. PIL images of any mode are
            converted to RGB first; array-like inputs are used as-is.

    Returns:
        Path of the PNG file the generated image was written to.

    Raises:
        ValueError: if ``img`` is None (e.g. nothing was uploaded).
    """
    if img is None:
        raise ValueError("No input image provided")
    # Grayscale ('L') or RGBA sketches would not match the 3-channel
    # mean/std of the Normalize step, so force RGB for PIL inputs.
    if isinstance(img, Image.Image):
        img = img.convert("RGB")
    image = np.asarray(img)
    # Add a leading batch dimension of 1 for the generator.
    inference_img = _INFERENCE_TRANSFORM(image=image)['image'].unsqueeze(0)
    with torch.no_grad():
        result = model.gen(inference_img)
    # normalize=True rescales the output to [0, 1] using its own min/max.
    # NOTE(review): this fixed filename is shared by all requests, so
    # concurrent calls can overwrite each other's result.
    output_path = "inference_image.png"
    torchvision.utils.save_image(result, output_path, normalize=True)
    return output_path
# Gradio UI: upload a sketch, receive the image produced by predict().
# NOTE(review): predict returns a file path while the output component is
# declared with type="pil" — confirm the installed gradio version accepts that.
iface = gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(type="pil"),
    examples=[
        f"examples/{stem}.png"
        for stem in (
            "thesis_test",
            "thesis_test2",
            "thesis1",
            "thesis4",
            "thesis5",
            "thesis6",
        )
    ],
    outputs=gr.outputs.Image(type="pil"),
    title="Colour your sketches!",
    description=" Upload a sketch and the conditional gan will colour it for you!",
    article="WIP repo lives here - https://github.com/nmud19/thesisGAN ",
)
iface.launch()