# DeF0017's picture
# Update app.py
# 4d6ed90 verified
import torch
import torch.nn as nn
import torch.optim as optim
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
from tqdm import tqdm
from torchvision.utils import save_image
import gradio as gr
class cnnBlock(nn.Module):
    """Convolution -> InstanceNorm -> optional ReLU building block.

    Args:
        in_channels: Number of channels in the input feature map.
        out_channels: Number of channels produced by the convolution.
        up_sample: When true the convolution is an ``nn.ConvTranspose2d``;
            otherwise it is an ``nn.Conv2d`` with ``padding_mode="reflect"``.
        use_act: When true the block ends with ``nn.ReLU(inplace=True)``;
            otherwise with ``nn.Identity``.
        **kwargs: Forwarded unchanged to the convolution constructor
            (e.g. ``kernel_size``, ``stride``, ``padding``,
            ``output_padding``).
    """

    def __init__(self, in_channels, out_channels, up_sample=False, use_act=True, **kwargs):
        super().__init__()
        if up_sample:
            # The reflect padding mode is only requested on the Conv2d
            # branch, so the transposed convolution keeps its default.
            conv = nn.ConvTranspose2d(
                in_channels=in_channels, out_channels=out_channels, **kwargs
            )
        else:
            conv = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                padding_mode="reflect",
                **kwargs,
            )
        activation = nn.ReLU(inplace=True) if use_act else nn.Identity()
        # The attribute name and the (conv, norm, act) order determine the
        # parameter names in the state dict; keep both so that saved
        # checkpoints continue to load.
        self.cnn_block = nn.Sequential(
            conv,
            nn.InstanceNorm2d(out_channels),
            activation,
        )

    def forward(self, x):
        """Apply the convolution, normalisation and activation to ``x``."""
        return self.cnn_block(x)
class residualBlock(nn.Module):
    """Residual unit: two 3x3 ``cnnBlock`` layers plus a skip connection.

    Args:
        channels: Channel count of both the input and the output; the block
            keeps it unchanged so the input can be added to the result.
    """

    def __init__(self, channels):
        super().__init__()
        # The second block has no activation, so the sum in ``forward`` is
        # taken on the un-activated output of the normalisation layer.
        self.resBlock = nn.Sequential(
            cnnBlock(channels, channels, kernel_size=3, padding=1),
            cnnBlock(channels, channels, use_act=False, kernel_size=3, padding=1),
        )

    def forward(self, x):
        """Return ``x`` plus the output of the two stacked conv blocks."""
        return x + self.resBlock(x)
class Generator(nn.Module):
    """Encoder / residual / decoder image-to-image generator.

    Layout: a 7x7 stem convolution, two stride-2 down-sampling blocks, a
    stack of residual blocks, two stride-2 transposed-convolution
    up-sampling blocks, and a final 7x7 convolution followed by ``tanh``.

    Args:
        img_channels: Channel count of the input and output images.
        features: Channel width of the stem; the down-sampling blocks widen
            it to ``features * 2`` and ``features * 4``.
        num_residual: Number of residual blocks in the bottleneck.
    """

    def __init__(self, img_channels=3, features=64, num_residual=9):
        super().__init__()
        self.initial = nn.Sequential(
            # The stem must emit ``features`` channels because the first
            # down-sampling block expects exactly that many. A hard-coded 64
            # here only worked for the default ``features=64``.
            nn.Conv2d(
                img_channels,
                features,
                kernel_size=7,
                stride=1,
                padding=3,
                padding_mode="reflect",
            ),
            nn.ReLU(),
        )
        self.downBlock = nn.ModuleList([
            cnnBlock(features, features * 2, kernel_size=3, stride=2, padding=1),
            cnnBlock(features * 2, features * 4, kernel_size=3, stride=2, padding=1),
        ])
        self.resBlock = nn.Sequential(
            *[residualBlock(features * 4) for _ in range(num_residual)]
        )
        self.upBlock = nn.ModuleList([
            cnnBlock(
                features * 4,
                features * 2,
                up_sample=True,
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=1,
            ),
            cnnBlock(
                features * 2,
                features,
                up_sample=True,
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=1,
            ),
        ])
        self.final = nn.Conv2d(
            features,
            img_channels,
            kernel_size=7,
            stride=1,
            padding=3,
            padding_mode="reflect",
        )

    def forward(self, x):
        """Run ``x`` through the network; ``tanh`` bounds the result to [-1, 1]."""
        x = self.initial(x)
        for layer in self.downBlock:
            x = layer(x)
        x = self.resBlock(x)
        for layer in self.upBlock:
            x = layer(x)
        x = self.final(x)
        return torch.tanh(x)
# --- Compute device -------------------------------------------------------
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# --- Dataset locations ----------------------------------------------------
# NOTE(review): not referenced anywhere else in the code visible here;
# presumably left over from the training script.
TRAIN_DIR = "/kaggle/input/vangogh2photo/vangogh2photo/train"
VAL_DIR = "/kaggle/input/vangogh2photo/vangogh2photo/val"

# --- Optimisation hyper-parameters ----------------------------------------
BATCH_SIZE = 1
LEARNING_RATE = 2e-4
LAMBDA_IDENTITY = 0.0
LAMBDA_CYCLE = 10
NUM_WORKERS = 4
NUM_EPOCHS = 0

# --- Checkpoint handling --------------------------------------------------
LOAD_MODEL = True
SAVE_MODEL = False
# Generator checkpoints are looked up in the current working directory.
CHECKPOINT_GEN_A = f"{os.getcwd()}/genA.pth.tar"
CHECKPOINT_GEN_B = f"{os.getcwd()}/genB.pth.tar"
# Augmentation pipeline with a random horizontal flip. ``additional_targets``
# registers a second input named ``image0`` that is processed as an image
# alongside the primary ``image``.
# NOTE(review): not referenced anywhere else in the code visible here;
# presumably the training-time pipeline.
transforms = A.Compose(
    [
        # Bring every input to 256x256.
        A.Resize(height=256, width=256),
        A.HorizontalFlip(p=0.5),
        # With mean = std = 0.5 and max_pixel_value = 255, pixel values in
        # 0..255 are mapped to the range -1..1.
        A.Normalize(
            mean=[0.5, 0.5, 0.5],
            std=[0.5, 0.5, 0.5],
            max_pixel_value=255,
        ),
        ToTensorV2(),
    ],
    is_check_shapes=False,
    additional_targets={"image0": "image"},
)
def load_checkpoint(checkpoint_file, model, optimizer=None, lr=None):
    """Restore model weights, and optionally optimizer state, from a checkpoint.

    The checkpoint is expected to be a mapping with a ``"state_dict"`` entry
    and, when ``optimizer`` is given, an ``"optimizer"`` entry.

    Args:
        checkpoint_file: Path (or file-like object) accepted by ``torch.load``.
        model: Module whose ``load_state_dict`` receives the saved weights.
        optimizer: Optimizer to restore. Pass ``None`` (the default) to load
            the weights only, e.g. for inference.
        lr: Learning rate written into every parameter group after the
            optimizer state is restored. ``None`` keeps the rate stored in
            the checkpoint.

    Raises:
        KeyError: If the checkpoint lacks a required entry.
    """
    print("=> Loading checkpoint")
    # NOTE(security): torch.load deserialises with pickle; only load
    # checkpoint files that come from a trusted source.
    checkpoint = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(checkpoint["state_dict"])

    if optimizer is None:
        return

    optimizer.load_state_dict(checkpoint["optimizer"])
    if lr is not None:
        # Restoring the optimizer also restores the learning rate saved in
        # the checkpoint; overwrite it so the caller's value takes effect.
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr
# Build both generators on the selected device; genB is created first.
genB, genA = Generator().to(DEVICE), Generator().to(DEVICE)

# A single Adam optimizer covers the parameters of both generators.
optim_gen = optim.Adam(
    [*genB.parameters(), *genA.parameters()],
    lr=LEARNING_RATE,
    betas=(0.5, 0.999),
)

# Restore the saved weights. Each call also loads that checkpoint's
# optimizer state into the shared optimizer and resets its learning rate.
load_checkpoint(
    checkpoint_file=CHECKPOINT_GEN_A,
    model=genA,
    optimizer=optim_gen,
    lr=LEARNING_RATE,
)
load_checkpoint(
    checkpoint_file=CHECKPOINT_GEN_B,
    model=genB,
    optimizer=optim_gen,
    lr=LEARNING_RATE,
)
def postprocess_and_show(output):
    """Convert a generator output tensor into an 8-bit PIL image.

    Args:
        output: Tensor with values nominally in [-1, 1], laid out as
            ``(C, H, W)`` or ``(N, C, H, W)``. For a batched tensor only the
            first image is converted.

    Returns:
        A ``PIL.Image.Image`` built from the ``(H, W, C)`` uint8 array.
    """
    # Drop autograd tracking and make sure the data lives in host memory.
    image = output.detach().cpu()
    if image.dim() == 4:
        # Strip the batch dimension explicitly rather than with squeeze(0),
        # which silently does nothing when the leading dimension is not 1.
        image = image[0]

    # Map [-1, 1] to [0, 1]. The clamp stops slightly out-of-range values
    # from wrapping around in the uint8 cast below.
    image = ((image + 1) / 2.0).clamp(0.0, 1.0)

    # (C, H, W) -> (H, W, C), the layout Image.fromarray expects.
    hwc = image.permute(1, 2, 0).numpy()

    # Round to the nearest integer level; a bare astype would truncate and
    # bias every pixel downwards.
    pixels = np.rint(hwc * 255).astype(np.uint8)
    return Image.fromarray(pixels)
# Deterministic preprocessing used by ``style_transfer``: resize, scale the
# pixel values, and convert to a tensor. No random augmentation is applied.
transforms2 = A.Compose(
    [
        # Bring every input to 256x256.
        A.Resize(height=256, width=256),
        # With mean = std = 0.5 and max_pixel_value = 255, pixel values in
        # 0..255 are mapped to the range -1..1.
        A.Normalize(
            mean=[0.5, 0.5, 0.5],
            std=[0.5, 0.5, 0.5],
            max_pixel_value=255,
        ),
        ToTensorV2(),
    ],
    is_check_shapes=False,
)
def style_transfer(img_file):
    """Run the uploaded image through ``genA`` and return the stylised result.

    Args:
        img_file: Path of the image file supplied by the Gradio input
            component, or ``None`` when nothing was uploaded.

    Returns:
        The generated image as a ``PIL.Image.Image``, or ``None`` when
        ``img_file`` is ``None``.
    """
    if img_file is None:
        # Nothing to process, so return an empty result rather than crash.
        return None

    # Force three channels: the normalisation in ``transforms2`` is defined
    # for exactly three, so RGBA, palette or grayscale uploads would
    # otherwise fail or be misinterpreted. The context manager closes the
    # file handle once the pixels have been copied.
    with Image.open(img_file) as pil_img:
        img = np.array(pil_img.convert("RGB"))

    input_img = transforms2(image=img)["image"]
    # Add the batch dimension the convolution layers conventionally expect.
    input_img = input_img.unsqueeze(0).to(DEVICE)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        output_img = genA(input_img)
    return postprocess_and_show(output_img)
# Gradio UI: the input component passes the uploaded image to
# ``style_transfer`` as a file path, and the returned image is shown in the
# output component.
image_input = gr.Image(type="filepath")
image_output = gr.Image()
demo = gr.Interface(
    fn=style_transfer,
    inputs=image_input,
    outputs=image_output,
    title="Style Transfer with CycleGAN",
)

# Start the web server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()