DeF0017 committed
Commit 325a7f8 · verified · 1 Parent(s): c360f38

Update app.py

Files changed (1)
app.py +127 -0
app.py CHANGED
@@ -1,3 +1,130 @@
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ import albumentations as A
+ from albumentations.pytorch import ToTensorV2
+ from PIL import Image
+ from torch.utils.data import Dataset, DataLoader
+ import numpy as np
+ import os
+ from tqdm import tqdm
+ from torchvision.utils import save_image
+ import gradio as gr
+
+ class cnnBlock(nn.Module):
+     def __init__(self, in_channels, out_channels, up_sample=False, use_act=True, **kwargs):
+         super().__init__()
+         # Transposed convolution when upsampling, reflect-padded convolution otherwise
+         self.cnn_block = nn.Sequential(
+             nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, **kwargs)
+             if up_sample else
+             nn.Conv2d(in_channels=in_channels, out_channels=out_channels, padding_mode="reflect", **kwargs),
+             nn.InstanceNorm2d(out_channels),
+             nn.ReLU(inplace=True) if use_act else nn.Identity(),
+         )
+
+     def forward(self, x):
+         return self.cnn_block(x)
+
+ class residualBlock(nn.Module):
+     def __init__(self, channels):
+         super().__init__()
+         self.resBlock = nn.Sequential(
+             cnnBlock(channels, channels, kernel_size=3, padding=1),
+             cnnBlock(channels, channels, use_act=False, kernel_size=3, padding=1),
+         )
+
+     def forward(self, x):
+         # Skip connection: kernel_size=3 with padding=1 preserves spatial dims,
+         # so input and residual can be summed directly
+         return x + self.resBlock(x)
+
+ class Generator(nn.Module):
+     def __init__(self, img_channels=3, features=64, num_residual=9):
+         super().__init__()
+         # Use `features` rather than a hardcoded 64 so the width stays configurable
+         self.initial = nn.Sequential(
+             nn.Conv2d(img_channels, features, kernel_size=7, stride=1, padding=3, padding_mode="reflect"),
+             nn.ReLU(),
+         )
+         self.downBlock = nn.ModuleList([
+             cnnBlock(features, features*2, kernel_size=3, stride=2, padding=1),
+             cnnBlock(features*2, features*4, kernel_size=3, stride=2, padding=1),
+         ])
+         self.resBlock = nn.Sequential(*[residualBlock(features*4) for _ in range(num_residual)])
+         self.upBlock = nn.ModuleList([
+             cnnBlock(features*4, features*2, up_sample=True, kernel_size=3, stride=2, padding=1, output_padding=1),
+             cnnBlock(features*2, features, up_sample=True, kernel_size=3, stride=2, padding=1, output_padding=1),
+         ])
+         self.final = nn.Conv2d(features, img_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect")
+
+     def forward(self, x):
+         x = self.initial(x)
+         for layer in self.downBlock:
+             x = layer(x)
+         x = self.resBlock(x)
+         for layer in self.upBlock:
+             x = layer(x)
+         x = self.final(x)
+         # tanh keeps outputs in [-1, 1], matching the Normalize(0.5, 0.5) input range
+         return torch.tanh(x)
+
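# Quick sanity check (a sketch, not part of the diff; `g` and `dummy` are
# illustrative names): two stride-2 downsampling blocks followed by two
# output_padding=1 upsampling blocks should return the input resolution.
g = Generator()
dummy = torch.randn(1, 3, 256, 256)
assert g(dummy).shape == (1, 3, 256, 256)  # values in [-1, 1] from tanh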
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ TRAIN_DIR = "/kaggle/input/vangogh2photo/vangogh2photo/train"
+ VAL_DIR = "/kaggle/input/vangogh2photo/vangogh2photo/val"
+ BATCH_SIZE = 1
+ LEARNING_RATE = 2e-4
+ LAMBDA_IDENTITY = 0.0
+ LAMBDA_CYCLE = 10
+ NUM_WORKERS = 4
+ NUM_EPOCHS = 0  # inference only: the app loads weights and never trains
+ LOAD_MODEL = True
+ SAVE_MODEL = False
+ CHECKPOINT_GEN_B = "/kaggle/input/checkpoints/genB.pth.tar"
+ CHECKPOINT_GEN_A = "/kaggle/input/checkpoints/genA.pth.tar"
+ CHECKPOINT_DISC_A = "/kaggle/input/checkpoints/discA.pth.tar"
+ CHECKPOINT_DISC_B = "/kaggle/input/checkpoints/discB.pth.tar"
+
+ transforms = A.Compose(
+     [
+         A.Resize(width=256, height=256),
+         A.HorizontalFlip(p=0.5),
+         A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
+         ToTensorV2(),
+     ],
+     additional_targets={"image0": "image"},
+     is_check_shapes=False,
+ )
+
+ def load_checkpoint(checkpoint_file, model, optimizer, lr):
+     print("=> Loading checkpoint")
+     checkpoint = torch.load(checkpoint_file, map_location=DEVICE)
+     model.load_state_dict(checkpoint["state_dict"])
+     optimizer.load_state_dict(checkpoint["optimizer"])
+
+     # Reset the learning rate explicitly; otherwise the optimizer keeps the
+     # old checkpoint's LR, which can lead to many hours of debugging
+     for param_group in optimizer.param_groups:
+         param_group["lr"] = lr
+
+ genA = Generator().to(DEVICE)
+ # The checkpoint also stores optimizer state, so an optimizer must exist
+ # before loading (Adam with beta1=0.5, as is standard for CycleGAN)
+ optim_gen = optim.Adam(genA.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
+
+ load_checkpoint(CHECKPOINT_GEN_A, genA, optim_gen, LEARNING_RATE)
+
+ def postprocess_and_show(output):
+     # Detach from the graph, move to CPU, and drop the batch dimension
+     output = output.squeeze(0).detach().cpu()
+
+     # Map tanh output from [-1, 1] to [0, 1]
+     output = (output + 1) / 2.0
+
+     # Tensor (C, H, W) -> NumPy array (H, W, C)
+     output_image = output.permute(1, 2, 0).numpy()
+
+     # Scale to [0, 255] uint8, which PIL expects
+     output_image = (output_image * 255).astype(np.uint8)
+
+     # Wrap in a PIL image for saving or display (e.g. as a Gradio output)
+     output_pil = Image.fromarray(output_image)
+
+     return output_pil
+
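# Assumed end-to-end wiring (a sketch, not in the diff): preprocess with the
# albumentations pipeline, run genA, then postprocess. Note that `transforms`
# includes a random HorizontalFlip, so a deterministic pipeline such as
# transforms2 below is the better fit at inference time; `run_inference`
# and `pil_img` are illustrative names.
def run_inference(pil_img):
    arr = np.array(pil_img.convert("RGB"))
    x = transforms(image=arr)["image"].unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        fake = genA(x)
    return postprocess_and_show(fake)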
  transforms2 = A.Compose(
      [
          A.Resize(width=256, height=256),