import torch
import torch.nn as nn
import torch.optim as optim
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
from tqdm import tqdm
from torchvision.utils import save_image
import gradio as gr
class cnnBlock(nn.Module):
    """Conv (or transposed conv) -> InstanceNorm -> optional ReLU."""

    def __init__(self, in_channels, out_channels, up_sample=False, use_act=True, **kwargs):
        super().__init__()
        self.cnn_block = nn.Sequential(
            # Transposed convolution when up-sampling; otherwise a reflection-padded
            # convolution, which reduces border artifacts in generated images.
            nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, **kwargs)
            if up_sample else
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, padding_mode="reflect", **kwargs),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(inplace=True) if use_act else nn.Identity(),
        )

    def forward(self, x):
        return self.cnn_block(x)

class residualBlock(nn.Module):
    """Two 3x3 conv blocks with a skip connection; the second block has no activation."""

    def __init__(self, channels):
        super().__init__()
        self.resBlock = nn.Sequential(
            cnnBlock(channels, channels, kernel_size=3, padding=1),
            cnnBlock(channels, channels, use_act=False, kernel_size=3, padding=1),
        )

    def forward(self, x):
        return x + self.resBlock(x)

class Generator(nn.Module):
    """CycleGAN generator: a 7x7 input conv, two stride-2 down-sampling blocks,
    num_residual residual blocks, two stride-2 up-sampling blocks, and a 7x7
    output conv with tanh."""

    def __init__(self, img_channels=3, features=64, num_residual=9):
        super().__init__()
        self.initial = nn.Sequential(
            # Use `features` here (was hardcoded to 64) so the width argument
            # actually controls the network.
            nn.Conv2d(img_channels, features, kernel_size=7, stride=1, padding=3, padding_mode="reflect"),
            nn.ReLU(),
        )
        self.downBlock = nn.ModuleList([
            cnnBlock(features, features * 2, kernel_size=3, stride=2, padding=1),
            cnnBlock(features * 2, features * 4, kernel_size=3, stride=2, padding=1),
        ])
        self.resBlock = nn.Sequential(*[residualBlock(features * 4) for _ in range(num_residual)])
        self.upBlock = nn.ModuleList([
            cnnBlock(features * 4, features * 2, up_sample=True, kernel_size=3, stride=2, padding=1, output_padding=1),
            cnnBlock(features * 2, features, up_sample=True, kernel_size=3, stride=2, padding=1, output_padding=1),
        ])
        self.final = nn.Conv2d(features, img_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect")

    def forward(self, x):
        x = self.initial(x)
        for layer in self.downBlock:
            x = layer(x)
        x = self.resBlock(x)
        for layer in self.upBlock:
            x = layer(x)
        x = self.final(x)
        # tanh maps the output into [-1, 1], matching the normalized inputs.
        return torch.tanh(x)
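
# Quick shape sanity check (illustrative, not part of the original script):
# the two stride-2 down-samplings (256 -> 128 -> 64) are undone by the two
# stride-2 up-samplings (64 -> 128 -> 256), so output size matches input.
# gen = Generator()
# assert gen(torch.randn(1, 3, 256, 256)).shape == (1, 3, 256, 256)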

# ---- Configuration ----
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TRAIN_DIR = "/kaggle/input/vangogh2photo/vangogh2photo/train"
VAL_DIR = "/kaggle/input/vangogh2photo/vangogh2photo/val"
BATCH_SIZE = 1
LEARNING_RATE = 2e-4
LAMBDA_IDENTITY = 0.0  # identity-loss weight (disabled here)
LAMBDA_CYCLE = 10  # cycle-consistency loss weight
NUM_WORKERS = 4
NUM_EPOCHS = 0  # no training in this script; pretrained checkpoints are loaded below
LOAD_MODEL = True
SAVE_MODEL = False
CHECKPOINT_GEN_A = f"{os.getcwd()}/genA.pth.tar"
CHECKPOINT_GEN_B = f"{os.getcwd()}/genB.pth.tar"

# Paired training transform: the same resize/flip is applied to both domains
# via additional_targets; Normalize with mean=std=0.5 maps [0, 255] to [-1, 1].
transforms = A.Compose(
    [
        A.Resize(width=256, height=256),
        A.HorizontalFlip(p=0.5),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
        ToTensorV2(),
    ],
    additional_targets={"image0": "image"},
    is_check_shapes=False,
)
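
# Usage sketch for the paired transform (illustrative; photo_np and vangogh_np
# are placeholder NumPy arrays, one per domain):
# aug = transforms(image=photo_np, image0=vangogh_np)
# photo_t, vangogh_t = aug["image"], aug["image0"]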

def load_checkpoint(checkpoint_file, model, optimizer, lr):
    print("=> Loading checkpoint")
    checkpoint = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    # Reset the learning rate; otherwise the optimizer keeps the stale rate
    # stored in the checkpoint, which can cost many hours of debugging.
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr
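
# Matching save routine, sketched for completeness (assumed, not shown in this
# excerpt); the "state_dict"/"optimizer" keys mirror what load_checkpoint expects.
def save_checkpoint(model, optimizer, filename):
    print("=> Saving checkpoint")
    torch.save({"state_dict": model.state_dict(), "optimizer": optimizer.state_dict()}, filename)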

# Two CycleGAN generators; genA is the one used below to stylize input photos.
genB = Generator().to(DEVICE)
genA = Generator().to(DEVICE)
optim_gen = optim.Adam(list(genB.parameters()) + list(genA.parameters()), lr=LEARNING_RATE, betas=(0.5, 0.999))
load_checkpoint(CHECKPOINT_GEN_A, genA, optim_gen, LEARNING_RATE)
load_checkpoint(CHECKPOINT_GEN_B, genB, optim_gen, LEARNING_RATE)
# Inference only: switch both generators to eval mode.
genA.eval()
genB.eval()

def postprocess_and_show(output):
    # Detach from the graph, move to CPU, and drop the batch dimension.
    output = output.squeeze(0).detach().cpu()
    # Map from [-1, 1] back to [0, 1].
    output = (output + 1) / 2.0
    # Tensor -> NumPy array, transposing (C, H, W) to (H, W, C).
    output_image = output.permute(1, 2, 0).numpy()
    # Scale to a [0, 255] 8-bit image.
    output_image = (output_image * 255).astype(np.uint8)
    # Wrap in a PIL image so it can be saved, displayed, or returned to Gradio.
    return Image.fromarray(output_image)

# Inference transform: same resize and normalization as training, but no flip.
transforms2 = A.Compose(
    [
        A.Resize(width=256, height=256),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
        ToTensorV2(),
    ],
    is_check_shapes=False,
)

def style_transfer(img_file):
    # Force three channels: uploaded PNGs may carry an alpha channel.
    img = np.array(Image.open(img_file).convert("RGB"))
    input_img = transforms2(image=img)["image"]
    # Add the batch dimension that the model (and squeeze(0) above) expects.
    input_img = input_img.unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        output_img = genA(input_img)
    return postprocess_and_show(output_img)
    # save_image(output_img * 0.5 + 0.5, "/kaggle/working/output.png")
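
# Direct usage sketch, without the Gradio UI (illustrative; "my_photo.jpg"
# is a placeholder path):
# stylized = style_transfer("my_photo.jpg")
# stylized.save("stylized.png")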

image_input = gr.Image(type="filepath")
image_output = gr.Image()
demo = gr.Interface(fn=style_transfer, inputs=image_input, outputs=image_output, title="Style Transfer with CycleGAN")

if __name__ == "__main__":
    demo.launch()