from typing import Union, List

import gradio as gr
import matplotlib
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning.utilities.types import EPOCH_OUTPUT

matplotlib.use('Agg')

import numpy as np
from PIL import Image
import albumentations as A
import albumentations.pytorch as al_pytorch
import torchvision
from pl_bolts.models.gans import Pix2Pix


# ---------------------------------------------------------------------------
# Model definition
# ---------------------------------------------------------------------------


class OverpoweredPix2Pix(Pix2Pix):
    """pl_bolts ``Pix2Pix`` that also logs validation losses and a sample grid.

    Validation batches are unpacked as ``real, condition = batch`` and passed
    to the parent's ``_disc_step`` / ``_gen_step`` in that order.
    """

    def validation_step(self, batch, batch_idx):
        """Log discriminator and generator losses for one validation batch.

        Args:
            batch: a ``(real, condition)`` pair of image tensors.
            batch_idx: index of the batch (unused).

        Returns:
            A dict holding the batch tensors so ``validation_epoch_end`` can
            render a sample grid. NOTE(review): the key names are kept for
            compatibility; ``'sketch'`` holds ``real`` and ``'colour'`` holds
            ``condition``.
        """
        real, condition = batch
        with torch.no_grad():
            disc_loss = self._disc_step(real, condition)
            self.log("val_PatchGAN_loss", disc_loss)

            gen_loss = self._gen_step(real, condition)
            self.log("val_generator_loss", gen_loss)

        return {
            'sketch': real,
            'colour': condition,
        }

    def validation_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None:
        """Log a ``[real | condition | generated]`` grid from the first batch."""
        if not outputs:
            # No validation batch ran (e.g. limit_val_batches=0): nothing to draw.
            return

        real = outputs[0]['sketch']
        condition = outputs[0]['colour']
        with torch.no_grad():
            # The generator is trained to map ``condition`` -> ``real``
            # (pl_bolts' ``_gen_step`` calls ``self.gen(conditioned_images)``),
            # so the sample has to be generated from ``condition`` too.
            generated = self.gen(condition)

        grid_image = torchvision.utils.make_grid(
            [
                real[0],
                condition[0],
                generated[0],
            ],
            normalize=True,
        )
        self.logger.experiment.add_image(
            f'Image Grid {self.current_epoch}', grid_image, self.current_epoch
        )


# ---------------------------------------------------------------------------
# Load the model
# ---------------------------------------------------------------------------

model_checkpoint_path = "model/lightning_bolts_model/epoch=99-step=89000.ckpt"
# Alternative checkpoints:
# model_checkpoint_path = "model/pix2pix_lightning_model/version_0/checkpoints/epoch=199-step=355600.ckpt"
# model_checkpoint_path = "model/pix2pix_lightning_model/gen.pth"

# Map every tensor onto the CPU so a checkpoint saved on a GPU also loads on a
# CPU-only host. The checkpoint is read once; there is no second raw
# ``torch.load`` of the same file.
model = OverpoweredPix2Pix.load_from_checkpoint(
    model_checkpoint_path,
    map_location=torch.device('cpu'),
)
model.eval()


def greet(name):
    """Return the greeting ``"Hello <name>!!"``."""
    return f"Hello {name}!!"
import tempfile

# Inference preprocessing: resize to 256x256, then scale pixels from [0, 255]
# to [-1, 1] via (x - 0.5 * 255) / (0.5 * 255), then convert HWC -> CHW tensor.
# Built once at import time instead of on every request.
_INFERENCE_TRANSFORM = A.Compose([
    A.Resize(width=256, height=256),
    A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
    al_pytorch.ToTensorV2(),
])


def predict(img: Image.Image):
    """Colour a sketch with the loaded Pix2Pix generator.

    Args:
        img: the input sketch, normally a PIL image supplied by Gradio. Any
            object ``np.asarray`` turns into an H x W x 3 array also works.

    Returns:
        Path (str) to a PNG file containing the generated image.

    Raises:
        ValueError: if no image was supplied.
    """
    if img is None:
        raise ValueError("No input image was provided.")
    if isinstance(img, Image.Image):
        # Normalize above expects exactly 3 channels; bring RGBA, greyscale
        # and palette images to RGB (PIL simply discards any alpha channel).
        img = img.convert("RGB")
    image = np.asarray(img)

    inference_img = _INFERENCE_TRANSFORM(image=image)['image'].unsqueeze(0)
    with torch.no_grad():
        result = model.gen(inference_img)

    # Write to a fresh file per request. A fixed filename in the working
    # directory is shared by concurrent requests, so one user could be served
    # another user's result.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as handle:
        output_path = handle.name
    torchvision.utils.save_image(result, output_path, normalize=True)
    return output_path


iface = gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(type="pil"),
    examples=[
        "examples/thesis_test.png",
        "examples/thesis_test2.png",
        "examples/thesis1.png",
        "examples/thesis4.png",
        "examples/thesis5.png",
        "examples/thesis6.png",
    ],
    outputs=gr.outputs.Image(type="pil"),
    title="Colour your sketches!",
    description=" Upload a sketch and the conditional gan will colour it for you!",
    article="WIP repo lives here - https://github.com/nmud19/thesisGAN ",
)
iface.launch()