from typing import Union, List

import gradio as gr
import matplotlib

matplotlib.use('Agg')  # headless backend for server environments

import albumentations as A
import albumentations.pytorch as al_pytorch
import numpy as np
import torch
import torchvision
from PIL import Image
from pl_bolts.models.gans import Pix2Pix
from pytorch_lightning.utilities.types import EPOCH_OUTPUT
""" Class """
class OverpoweredPix2Pix(Pix2Pix):
    def validation_step(self, batch, batch_idx):
        """Log discriminator (PatchGAN) and generator losses on a validation batch."""
        real, condition = batch
        with torch.no_grad():
            loss = self._disc_step(real, condition)
            self.log("val_PatchGAN_loss", loss)

            loss = self._gen_step(real, condition)
            self.log("val_generator_loss", loss)
        return {
            'sketch': real,
            'colour': condition,
        }
    def validation_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None:
        """Log a sketch / ground truth / generated triplet from the first validation batch."""
        sketch = outputs[0]['sketch']
        colour = outputs[0]['colour']
        with torch.no_grad():
            gen_coloured = self.gen(sketch)
        grid_image = torchvision.utils.make_grid(
            [sketch[0], colour[0], gen_coloured[0]],
            normalize=True,
        )
        self.logger.experiment.add_image(
            f'Image Grid {self.current_epoch}', grid_image, self.current_epoch
        )
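
# NOTE: `validation_epoch_end` is the PyTorch Lightning 1.x hook; Lightning 2.x
# removed it in favour of `on_validation_epoch_end`, so this class assumes a
# 1.x-era pytorch_lightning / pl_bolts install.
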
""" Load the model """
model_checkpoint_path = "model/lightning_bolts_model/epoch=99-step=89000.ckpt"
# model_checkpoint_path = "model/pix2pix_lightning_model/version_0/checkpoints/epoch=199-step=355600.ckpt"
# model_checkpoint_path = "model/pix2pix_lightning_model/gen.pth"
model = OverpoweredPix2Pix.load_from_checkpoint(
    model_checkpoint_path,
    map_location=torch.device('cpu'),  # force CPU load; safe even if the checkpoint was saved on GPU
)
# The raw torch.load below is only needed to restore a bare generator
# state_dict (e.g. the gen.pth path above) rather than a full Lightning checkpoint:
# model_chk = torch.load(model_checkpoint_path, map_location=torch.device('cpu'))
# model.gen.load_state_dict(model_chk)
model.eval()
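
# Optional smoke test (assumes the 1x3x256x256 input produced by the predict()
# transform below); uncomment to verify the generator restored correctly:
# with torch.no_grad():
#     print(model.gen(torch.randn(1, 3, 256, 256)).shape)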
def predict(img: Image.Image):
    """Colour an input sketch with the trained Pix2Pix generator."""
    image = np.asarray(img)
    # image = image[:, image.shape[1] // 2:, :]  # crop for side-by-side paired inputs
    # Inference-time preprocessing: resize to the 256x256 training resolution and
    # map pixel values to [-1, 1] (mean/std of 0.5 over a 0-255 range).
    inference_transform = A.Compose([
        A.Resize(width=256, height=256),
        A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
        al_pytorch.ToTensorV2(),
    ])
    inference_img = inference_transform(image=image)['image'].unsqueeze(0)
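    # inference_img is now a float tensor of shape (1, 3, 256, 256) in [-1, 1],
    # the input format the generator expects.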
    with torch.no_grad():
        result = model.gen(inference_img)
    # normalize=True rescales the generator output from [-1, 1] to displayable pixels
    torchvision.utils.save_image(result, "inference_image.png", normalize=True)
"""
result_grid = torchvision.utils.make_grid(
[result[0]],
normalize=True
)
# plt.imsave("coloured_grid.png", (result_grid.permute(1,2,0).detach().numpy()*255).astype(int))
torchvision.utils.save_image(
result_grid, "coloured_image.png", normalize=True
)
"""
return "inference_image.png" # 'coloured_image.png',
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),  # predict() expects a PIL image
examples=[
"examples/thesis_test.png",
"examples/thesis_test2.png",
"examples/thesis1.png",
"examples/thesis4.png",
"examples/thesis5.png",
"examples/thesis6.png",
# "examples/1000000.png"
],
    outputs=gr.Image(type="filepath"),  # predict() returns a saved file path
    title="Colour your sketches!",
    description="Upload a sketch and the conditional GAN will colour it for you!",
    article="WIP repo lives here - https://github.com/nmud19/thesisGAN",
)
iface.launch()
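# launch() blocks and serves the UI (http://127.0.0.1:7860 by default when run locally).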