import os

import albumentations as A
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from albumentations.pytorch import ToTensorV2
from PIL import Image


class cnnBlock(nn.Module):
    """Conv (or transposed conv) -> InstanceNorm -> optional ReLU."""

    def __init__(self, in_channels, out_channels, up_sample=False, use_act=True, **kwargs):
        super().__init__()
        self.cnn_block = nn.Sequential(
            # Transposed convolution when upsampling, reflection-padded convolution otherwise.
            nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, **kwargs)
            if up_sample else
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, padding_mode="reflect", **kwargs),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(inplace=True) if use_act else nn.Identity(),
        )

    def forward(self, x):
        return self.cnn_block(x)


class residualBlock(nn.Module):
    """Two 3x3 conv blocks with a skip connection; the second block has no activation."""

    def __init__(self, channels):
        super().__init__()
        self.resBlock = nn.Sequential(
            cnnBlock(channels, channels, kernel_size=3, padding=1),
            cnnBlock(channels, channels, use_act=False, kernel_size=3, padding=1),
        )

    def forward(self, x):
        return x + self.resBlock(x)


class Generator(nn.Module):
    """CycleGAN generator: 7x7 stem, two stride-2 downsampling blocks,
    num_residual residual blocks, two stride-2 upsampling blocks, 7x7 head."""

    def __init__(self, img_channels=3, features=64, num_residual=9):
        super().__init__()
        self.initial = nn.Sequential(
            # Use the `features` argument rather than a hardcoded 64 so the width is configurable.
            nn.Conv2d(img_channels, features, kernel_size=7, stride=1, padding=3, padding_mode="reflect"),
            nn.ReLU(),
        )
        self.downBlock = nn.ModuleList([
            cnnBlock(features, features * 2, kernel_size=3, stride=2, padding=1),
            cnnBlock(features * 2, features * 4, kernel_size=3, stride=2, padding=1),
        ])
        self.resBlock = nn.Sequential(*[residualBlock(features * 4) for _ in range(num_residual)])
        self.upBlock = nn.ModuleList([
            cnnBlock(features * 4, features * 2, up_sample=True, kernel_size=3, stride=2, padding=1, output_padding=1),
            cnnBlock(features * 2, features, up_sample=True, kernel_size=3, stride=2, padding=1, output_padding=1),
        ])
        self.final = nn.Conv2d(features, img_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect")

    def forward(self, x):
        x = self.initial(x)
        for layer in self.downBlock:
            x = layer(x)
        x = self.resBlock(x)
        for layer in self.upBlock:
            x = layer(x)
        x = self.final(x)
        # tanh maps the output into [-1, 1], matching the normalization of the inputs.
        return torch.tanh(x)
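

# Illustrative shape check (a hypothetical helper, not called by the app): with two
# stride-2 down blocks and matching up blocks using output_padding=1, the generator
# should preserve spatial size for inputs whose sides are divisible by 4.
def _check_generator_shape():
    with torch.no_grad():
        dummy = torch.randn(1, 3, 256, 256)
        out = Generator()(dummy)
        assert out.shape == dummy.shape, f"unexpected output shape: {out.shape}"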


DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Dataset paths and training hyperparameters, kept for reference; with
# NUM_EPOCHS = 0, LOAD_MODEL = True, and SAVE_MODEL = False this script only
# runs inference from the saved generator checkpoints.
TRAIN_DIR = "/kaggle/input/vangogh2photo/vangogh2photo/train"
VAL_DIR = "/kaggle/input/vangogh2photo/vangogh2photo/val"
BATCH_SIZE = 1
LEARNING_RATE = 2e-4
LAMBDA_IDENTITY = 0.0
LAMBDA_CYCLE = 10
NUM_WORKERS = 4
NUM_EPOCHS = 0
LOAD_MODEL = True
SAVE_MODEL = False
CHECKPOINT_GEN_A = f"{os.getcwd()}/genA.pth.tar"
CHECKPOINT_GEN_B = f"{os.getcwd()}/genB.pth.tar"

# Training-time augmentation: the same random flip is applied to both images
# passed in one call (see the note below).
transforms = A.Compose(
    [
        A.Resize(width=256, height=256),
        A.HorizontalFlip(p=0.5),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
        ToTensorV2(),
    ],
    additional_targets={"image0": "image"},
    is_check_shapes=False,
)
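
# additional_targets={"image0": "image"} lets a single call transform two images
# with identical random parameters, which is how a CycleGAN dataset typically
# loads one image from each domain per sample. Illustrative call (img_a and
# img_b are hypothetical numpy arrays):
#   augmented = transforms(image=img_a, image0=img_b)
#   tensor_a, tensor_b = augmented["image"], augmented["image0"]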


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    print("=> Loading checkpoint")
    checkpoint = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    # Restore the requested learning rate: loading the optimizer state would
    # otherwise carry over whatever lr the checkpoint was saved with.
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

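
# For reference, a matching save routine (a sketch; the training script that
# produced genA.pth.tar / genB.pth.tar is assumed to have written these keys,
# which are the ones load_checkpoint reads back):
def save_checkpoint(model, optimizer, filename):
    torch.save({"state_dict": model.state_dict(), "optimizer": optimizer.state_dict()}, filename)
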

# The two CycleGAN generators, one per translation direction; the checkpoint
# naming suggests genA produces domain-A (Van Gogh style) images, and it is
# the one used for stylization below.
genB = Generator().to(DEVICE)
genA = Generator().to(DEVICE)
optim_gen = optim.Adam(list(genB.parameters()) + list(genA.parameters()), lr=LEARNING_RATE, betas=(0.5, 0.999))

load_checkpoint(CHECKPOINT_GEN_A, genA, optim_gen, LEARNING_RATE)
load_checkpoint(CHECKPOINT_GEN_B, genB, optim_gen, LEARNING_RATE)
genA.eval()
genB.eval()


def postprocess_and_show(output):
    # Drop the batch dimension and move the tensor off the GPU.
    output = output.squeeze(0).detach().cpu()
    # Map the tanh output from [-1, 1] back to [0, 1].
    output = (output + 1) / 2.0
    # CHW -> HWC, then scale to 8-bit pixel values.
    output_image = output.permute(1, 2, 0).numpy()
    output_image = (output_image * 255).astype(np.uint8)
    output_pil = Image.fromarray(output_image)
    return output_pil


# Inference-time preprocessing: same resize and normalization, no random flip.
transforms2 = A.Compose(
    [
        A.Resize(width=256, height=256),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
        ToTensorV2(),
    ],
    is_check_shapes=False,
)


def style_transfer(img_file):
    # Force three channels so grayscale or RGBA uploads don't break the model.
    img = np.array(Image.open(img_file).convert("RGB"))
    transform_img = transforms2(image=img)
    # Add the batch dimension the generator expects: (C, H, W) -> (1, C, H, W).
    input_img = transform_img["image"].unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        output_img = genA(input_img)
    return postprocess_and_show(output_img)
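
# Headless usage sketch, bypassing Gradio ("photo.jpg" is a hypothetical local file):
#   stylized = style_transfer("photo.jpg")
#   stylized.save("photo_stylized.jpg")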


image_input = gr.Image(type="filepath")
image_output = gr.Image()

demo = gr.Interface(fn=style_transfer, inputs=image_input, outputs=image_output, title="Style Transfer with CycleGAN")

if __name__ == "__main__":
    demo.launch()