Upload 125 files

This view is limited to 50 files because it contains too many changes.
- .gitattributes +2 -0
- activation.py +18 -0
- app_gradio.py +137 -0
- assets/biome_image.png +0 -0
- assets/sample_traj.gif +3 -0
- assets/teaser.gif +3 -0
- configs/img2lmdb.yaml +177 -0
- configs/landscape1m.yaml +175 -0
- configs/scenedreamer_inference.yaml +93 -0
- configs/scenedreamer_train.yaml +223 -0
- encoding.py +67 -0
- environment.yaml +44 -0
- gridencoder/__init__.py +1 -0
- gridencoder/backend.py +40 -0
- gridencoder/grid.py +224 -0
- gridencoder/setup.py +50 -0
- gridencoder/src/bindings.cpp +8 -0
- gridencoder/src/gridencoder.cu +478 -0
- gridencoder/src/gridencoder.h +15 -0
- imaginaire/__init__.py +4 -0
- imaginaire/config.py +238 -0
- imaginaire/discriminators/__init__.py +0 -0
- imaginaire/discriminators/gancraft.py +278 -0
- imaginaire/generators/__init__.py +4 -0
- imaginaire/generators/gancraft_base.py +603 -0
- imaginaire/generators/scenedreamer.py +851 -0
- imaginaire/generators/spade.py +571 -0
- imaginaire/layers/__init__.py +27 -0
- imaginaire/layers/activation_norm.py +629 -0
- imaginaire/layers/conv.py +1377 -0
- imaginaire/layers/misc.py +61 -0
- imaginaire/layers/non_local.py +88 -0
- imaginaire/layers/nonlinearity.py +65 -0
- imaginaire/layers/residual.py +1411 -0
- imaginaire/layers/residual_deep.py +346 -0
- imaginaire/layers/vit.py +204 -0
- imaginaire/layers/weight_norm.py +267 -0
- imaginaire/losses/__init__.py +18 -0
- imaginaire/losses/feature_matching.py +38 -0
- imaginaire/losses/gan.py +173 -0
- imaginaire/losses/info_nce.py +87 -0
- imaginaire/losses/kl.py +23 -0
- imaginaire/losses/perceptual.py +395 -0
- imaginaire/losses/weighted_mse.py +28 -0
- imaginaire/model_utils/__init__.py +4 -0
- imaginaire/model_utils/gancraft/camctl.py +679 -0
- imaginaire/model_utils/gancraft/gaugan_lbl2col.csv +182 -0
- imaginaire/model_utils/gancraft/gaugan_reduction.csv +182 -0
- imaginaire/model_utils/gancraft/id2name_gg.csv +680 -0
- imaginaire/model_utils/gancraft/loss.py +96 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/sample_traj.gif filter=lfs diff=lfs merge=lfs -text
+assets/teaser.gif filter=lfs diff=lfs merge=lfs -text
activation.py
ADDED
@@ -0,0 +1,18 @@
import torch
from torch.autograd import Function
from torch.cuda.amp import custom_bwd, custom_fwd

class _trunc_exp(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32) # cast to float32
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.exp(x)

    @staticmethod
    @custom_bwd
    def backward(ctx, g):
        x = ctx.saved_tensors[0]
        return g * torch.exp(x.clamp(-15, 15))

trunc_exp = _trunc_exp.apply
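The trunc_exp exported above behaves like torch.exp in the forward pass but clamps the saved input to [-15, 15] when computing the gradient, which keeps gradients finite for large raw values. A minimal usage sketch (variable names are illustrative, not from this repository):

import torch
from activation import trunc_exp

raw = torch.randn(4, requires_grad=True)   # hypothetical raw network output
sigma = trunc_exp(raw)                      # forward: exp(raw), cast to float32 under autocast
sigma.sum().backward()                      # backward: grad * exp(clamp(raw, -15, 15))
print(raw.grad)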
app_gradio.py
ADDED
@@ -0,0 +1,137 @@
import os
import torch
import torch.nn as nn
import importlib
import argparse
from imaginaire.config import Config
from imaginaire.utils.cudnn import init_cudnn
import gradio as gr
from PIL import Image


class WrappedModel(nn.Module):
    r"""Dummy wrapping the module.
    """

    def __init__(self, module):
        super(WrappedModel, self).__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        r"""PyTorch module forward function overload."""
        return self.module(*args, **kwargs)

def parse_args():
    parser = argparse.ArgumentParser(description='Training')
    parser.add_argument('--config', type=str, default='./configs/scenedreamer_inference.yaml', help='Path to the training config file.')
    parser.add_argument('--checkpoint', default='./scenedreamer_released.pt',
                        help='Checkpoint path.')
    parser.add_argument('--output_dir', type=str, default='./test/',
                        help='Location to save the image outputs')
    parser.add_argument('--seed', type=int, default=8888,
                        help='Random seed.')
    args = parser.parse_args()
    return args


args = parse_args()
cfg = Config(args.config)

# Initialize cudnn.
init_cudnn(cfg.cudnn.deterministic, cfg.cudnn.benchmark)

# Initialize data loaders and models.

lib_G = importlib.import_module(cfg.gen.type)
net_G = lib_G.Generator(cfg.gen, cfg.data)
net_G = net_G.to('cuda')
net_G = WrappedModel(net_G)

if args.checkpoint == '':
    raise NotImplementedError("No checkpoint is provided for inference!")

# Load checkpoint.
# trainer.load_checkpoint(cfg, args.checkpoint)
checkpoint = torch.load(args.checkpoint, map_location='cpu')
net_G.load_state_dict(checkpoint['net_G'])

# Do inference.
net_G = net_G.module
net_G.eval()
for name, param in net_G.named_parameters():
    param.requires_grad = False
torch.cuda.empty_cache()
world_dir = os.path.join(args.output_dir)
os.makedirs(world_dir, exist_ok=True)



def get_bev(seed):
    print('[PCGGenerator] Generating BEV scene representation...')
    os.system('python terrain_generator.py --size {} --seed {} --outdir {}'.format(net_G.voxel.sample_size, seed, world_dir))
    heightmap_path = os.path.join(world_dir, 'heightmap.png')
    semantic_path = os.path.join(world_dir, 'colormap.png')
    heightmap = Image.open(heightmap_path)
    semantic = Image.open(semantic_path)
    return semantic, heightmap

def get_video(seed, num_frames, reso_h, reso_w):
    device = torch.device('cuda')
    rng_cuda = torch.Generator(device=device)
    rng_cuda = rng_cuda.manual_seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    net_G.voxel.next_world(device, world_dir, checkpoint)
    cam_mode = cfg.inference_args.camera_mode
    cfg.inference_args.cam_maxstep = num_frames
    cfg.inference_args.resolution_hw = [reso_h, reso_w]
    current_outdir = os.path.join(world_dir, 'camera_{:02d}'.format(cam_mode))
    os.makedirs(current_outdir, exist_ok=True)
    z = torch.empty(1, net_G.style_dims, dtype=torch.float32, device=device)
    z.normal_(generator=rng_cuda)
    net_G.inference_givenstyle(z, current_outdir, **vars(cfg.inference_args))
    return os.path.join(current_outdir, 'rgb_render.mp4')

markdown=f'''
# SceneDreamer: Unbounded 3D Scene Generation from 2D Image Collections

Authored by Zhaoxi Chen, Guangcong Wang, Ziwei Liu
### Useful links:
- [Official Github Repo](https://github.com/FrozenBurning/SceneDreamer)
- [Project Page](https://scene-dreamer.github.io/)
- [arXiv Link](https://arxiv.org/abs/2302.01330)
Licensed under the S-Lab License.
We offer a sampled scene whose BEVs are shown on the right. You can also use the button "Generate BEV" to randomly sample a new 3D world represented by a height map and a semantic map. But it requires a long time.

To render video, push the button "Render" to generate a camera trajectory flying through the world. You can specify rendering options as shown below!
'''

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown(markdown)
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    semantic = gr.Image(value='./test/colormap.png',type="pil", shape=(512, 512))
                with gr.Column():
                    height = gr.Image(value='./test/heightmap.png', type="pil", shape=(512, 512))
            with gr.Row():
                # with gr.Column():
                #     image = gr.Image(type='pil', shape(540, 960))
                with gr.Column():
                    video = gr.Video()
            with gr.Row():
                num_frames = gr.Slider(minimum=10, maximum=200, value=20, step=1, label='Number of rendered frames')
                user_seed = gr.Slider(minimum=0, maximum=999999, value=8888, step=1, label='Random seed')
                resolution_h = gr.Slider(minimum=256, maximum=2160, value=270, step=1, label='Height of rendered image')
                resolution_w = gr.Slider(minimum=256, maximum=3840, value=480, step=1, label='Width of rendered image')

            with gr.Row():
                btn = gr.Button(value="Generate BEV")
                btn_2=gr.Button(value="Render")

            btn.click(get_bev,[user_seed],[semantic, height])
            btn_2.click(get_video,[user_seed, num_frames, resolution_h, resolution_w], [video])

demo.launch(debug=True)
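Given the argument defaults in parse_args() above, the Space's entry point would be launched locally roughly like this (the config and checkpoint paths must exist; the released checkpoint file itself is not part of this diff):

python app_gradio.py --config ./configs/scenedreamer_inference.yaml \
    --checkpoint ./scenedreamer_released.pt --output_dir ./test/ --seed 8888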
assets/biome_image.png
ADDED
assets/sample_traj.gif
ADDED
assets/teaser.gif
ADDED
configs/img2lmdb.yaml
ADDED
@@ -0,0 +1,177 @@
inference_args:
    random_style: True
    use_fixed_random_style: False
    keep_original_size: True

image_save_iter: 5000
snapshot_save_epoch: 5
max_epoch: 400
logging_iter: 100
trainer:
    type: imaginaire.trainers.spade
    model_average_config:
        enabled: True
        beta: 0.9999
        start_iteration: 1000
        num_batch_norm_estimation_iterations: 30
    amp_config:
        enabled: True
    gan_mode: hinge
    gan_relativistic: False
    perceptual_loss:
        mode: 'vgg19'
        layers: ['relu_1_1', 'relu_2_1', 'relu_3_1', 'relu_4_1', 'relu_5_1']
        weights: [0.03125, 0.0625, 0.125, 0.25, 1.0]
        fp16: True
    loss_weight:
        gan: 1.0
        perceptual: 10.0
        feature_matching: 10.0
        kl: 0.05
    init:
        type: xavier
        gain: 0.02
gen_opt:
    type: adam
    lr: 0.0001
    adam_beta1: 0.
    adam_beta2: 0.999
    lr_policy:
        iteration_mode: False
        type: step
        step_size: 400
        gamma: 0.1
dis_opt:
    type: adam
    lr: 0.0004
    adam_beta1: 0.
    adam_beta2: 0.999
    lr_policy:
        iteration_mode: False
        type: step
        step_size: 400
        gamma: 0.1
gen:
    type: imaginaire.generators.spade
    version: v20
    style_dims: 256
    num_filters: 128
    kernel_size: 3
    weight_norm_type: 'spectral'
    use_posenc_in_input_layer: False
    global_adaptive_norm_type: 'sync_batch'
    activation_norm_params:
        num_filters: 128
        kernel_size: 5
        separate_projection: True
        activation_norm_type: 'sync_batch'
    style_enc:
        num_filters: 64
        kernel_size: 3
dis:
    type: imaginaire.discriminators.spade
    kernel_size: 4
    num_filters: 128
    max_num_filters: 512
    num_discriminators: 2
    num_layers: 5
    activation_norm_type: 'none'
    weight_norm_type: 'spectral'

# Data options.
data:
    type: imaginaire.datasets.paired_images
    # How many data loading workers per GPU?
    num_workers: 8
    input_types:
        - images:
            ext: jpg
            num_channels: 3
            normalize: True
            use_dont_care: False
        - seg_maps:
            ext: jpg
            num_channels: 1
            is_mask: True
            normalize: False
        # - edge_maps:
        #     ext: png
        #     num_channels: 1
        #     normalize: False

    full_data_ops: imaginaire.model_utils.label::make_one_hot, imaginaire.model_utils.label::concat_labels
    use_dont_care: True
    one_hot_num_classes:
        seg_maps: 183
    input_labels:
        - seg_maps
        # - edge_maps

    # Which lmdb contains the ground truth image.
    input_image:
        - images

    # Train dataset details.
    train:
        # Input LMDBs.
        roots:
            - ./data/lhq/train
        # Batch size per GPU.
        batch_size: 4
        # Data augmentations to be performed in given order.
        augmentations:
            resize_smallest_side: 256
            # Rotate in (-rotate, rotate) in degrees.
            rotate: 0
            # Scale image by factor \in [1, 1+random_scale_limit].
            random_scale_limit: 0.2
            # Horizontal flip?
            horizontal_flip: True
            # Crop size.
            random_crop_h_w: 256, 256
    # Train dataset details.
    val:
        # Input LMDBs.
        roots:
            - ./data/lhq/val
        # Batch size per GPU.
        batch_size: 4
        # Data augmentations to be performed in given order.
        augmentations:
            # Crop size.
            resize_h_w: 256, 256

test_data:
    type: imaginaire.datasets.paired_images
    num_workers: 8
    input_types:
        - seg_maps:
            ext: jpg
            num_channels: 1
            is_mask: True
            normalize: False
        # - edge_maps:
        #     ext: png
        #     num_channels: 1
        #     normalize: False

    full_data_ops: imaginaire.model_utils.label::make_one_hot, imaginaire.model_utils.label::concat_labels
    use_dont_care: True
    one_hot_num_classes:
        seg_maps: 183
    input_labels:
        - seg_maps
        # - edge_maps

    paired: True
    # Validation dataset details.
    test:
        is_lmdb: False
        roots:
            - ./data/lhq/train
        # Batch size per GPU.
        batch_size: 1
        # If resize_h_w is not given, then it is assumed to be same as crop_h_w.
        augmentations:
            resize_h_w: 256, 256
            horizontal_flip: False
configs/landscape1m.yaml
ADDED
@@ -0,0 +1,175 @@
pretrained_weight: ./landscape1m-segformer.pt

inference_args:
    random_style: True
    use_fixed_random_style: False
    keep_original_size: True

image_save_iter: 5000
snapshot_save_epoch: 5
snapshot_save_iter: 30000
max_epoch: 400
logging_iter: 100
trainer:
    type: imaginaire.trainers.spade
    model_average_config:
        enabled: True
        beta: 0.9999
        start_iteration: 1000
        num_batch_norm_estimation_iterations: 30
    amp_config:
        enabled: True
    gan_mode: hinge
    gan_relativistic: False
    perceptual_loss:
        mode: 'vgg19'
        layers: ['relu_1_1', 'relu_2_1', 'relu_3_1', 'relu_4_1', 'relu_5_1']
        weights: [0.03125, 0.0625, 0.125, 0.25, 1.0]
        fp16: True
    loss_weight:
        gan: 1.0
        perceptual: 10.0
        feature_matching: 10.0
        kl: 0.05
    init:
        type: xavier
        gain: 0.02
gen_opt:
    type: adam
    lr: 0.0001
    adam_beta1: 0.
    adam_beta2: 0.999
    lr_policy:
        iteration_mode: False
        type: step
        step_size: 400
        gamma: 0.1
dis_opt:
    type: adam
    lr: 0.0004
    adam_beta1: 0.
    adam_beta2: 0.999
    lr_policy:
        iteration_mode: False
        type: step
        step_size: 400
        gamma: 0.1
gen:
    type: imaginaire.generators.spade
    version: v20
    output_multiplier: 0.5
    image_channels: 3
    num_labels: 184
    style_dims: 256
    num_filters: 128
    kernel_size: 3
    weight_norm_type: 'spectral'
    use_posenc_in_input_layer: False
    global_adaptive_norm_type: 'sync_batch'
    activation_norm_params:
        num_filters: 128
        kernel_size: 5
        separate_projection: True
        activation_norm_type: 'sync_batch'
    style_enc:
        num_filters: 64
        kernel_size: 3
dis:
    type: imaginaire.discriminators.spade
    kernel_size: 4
    num_filters: 128
    max_num_filters: 512
    num_discriminators: 2
    num_layers: 5
    activation_norm_type: 'none'
    weight_norm_type: 'spectral'

# Data options.
data:
    type: imaginaire.datasets.paired_images
    # How many data loading workers per GPU?
    num_workers: 8
    input_types:
        - images:
            ext: jpg
            num_channels: 3
            normalize: True
            use_dont_care: False
        - seg_maps:
            ext: jpg
            num_channels: 1
            is_mask: True
            normalize: False

    full_data_ops: imaginaire.model_utils.label::make_one_hot, imaginaire.model_utils.label::concat_labels
    use_dont_care: True
    one_hot_num_classes:
        seg_maps: 183
    input_labels:
        - seg_maps

    # Which lmdb contains the ground truth image.
    input_image:
        - images

    # Train dataset details.
    train:
        # Input LMDBs.
        dataset_type: lmdb
        roots:
            - ./data/lhq_lmdb/train
        # Batch size per GPU.
        batch_size: 4
        # Data augmentations to be performed in given order.
        augmentations:
            resize_smallest_side: 512
            # Rotate in (-rotate, rotate) in degrees.
            rotate: 0
            # Scale image by factor \in [1, 1+random_scale_limit].
            random_scale_limit: 0.2
            # Horizontal flip?
            horizontal_flip: True
            # Crop size.
            random_crop_h_w: 512, 512
    # Train dataset details.
    val:
        dataset_type: lmdb
        # Input LMDBs.
        roots:
            - ./data/lhq_lmdb/val
        # Batch size per GPU.
        batch_size: 4
        # Data augmentations to be performed in given order.
        augmentations:
            # Crop size.
            resize_h_w: 512, 512

test_data:
    type: imaginaire.datasets.paired_images
    num_workers: 8
    input_types:
        - seg_maps:
            ext: jpg
            num_channels: 1
            is_mask: True
            normalize: False

    full_data_ops: imaginaire.model_utils.label::make_one_hot, imaginaire.model_utils.label::concat_labels
    use_dont_care: True
    one_hot_num_classes:
        seg_maps: 183
    input_labels:
        - seg_maps

    paired: True
    # Validation dataset details.
    test:
        is_lmdb: True
        roots:
            - ./data/lhq_lmdb/val
        # Batch size per GPU.
        batch_size: 1
        # If resize_h_w is not given, then it is assumed to be same as crop_h_w.
        augmentations:
            resize_h_w: 256, 256
            horizontal_flip: False
configs/scenedreamer_inference.yaml
ADDED
@@ -0,0 +1,93 @@
inference_args:
    # 0: Camera orbiting the scene & looking at the center
    # 1: Camera orbiting the scene & zooming in
    # 2: Camera orbiting the scene & coming closer and closer to the center
    # 3: Similar to 2, camera orbiting at the opposite direction
    # 4: Simliar to 2, camera stays further away from the center
    # 5: Camera sits at the center and look outwards
    # 6: Camera rises while looking down
    # 7: Camera really far away looking down at a 45deg angle
    # 8: Camera for perpetual view generation, non-sliding window
    # 9: Camera for infinite world generation, sliding window
    camera_mode: 4

    cam_maxstep: 40
    resolution_hw: [540, 960]
    num_samples: 40
    cam_ang: 72

gen:
    type: imaginaire.generators.scenedreamer
    pcg_dataset_path: None
    pcg_cache: False
    scene_size: 2048

    blk_feat_dim: 64

    pe_lvl_feat: 4
    pe_incl_orig_feat: False
    pe_no_pe_feat_dim: 40
    pe_lvl_raydir: 0
    pe_incl_orig_raydir: False
    style_dims: 128 # Set to 0 to disable style.
    interm_style_dims: 256
    final_feat_dim: 64

    # Number of pixels removed from each edge to reduce boundary artifact of CNN
    # both sides combined (8 -> 4 on left and 4 on right).
    pad: 6

    # ======== Sky network ========
    pe_lvl_raydir_sky: 5
    pe_incl_orig_raydir_sky: True

    # ======== Style Encoder =========
    # Comment out to disable style encoder.
    style_enc:
        num_filters: 64
        kernel_size: 3
        weight_norm_type: 'none'

    stylenet_model: StyleMLP
    stylenet_model_kwargs:
        normalize_input: True
        num_layers: 5

    mlp_model: RenderMLP
    mlp_model_kwargs:
        use_seg: True

    # ======== Ray Casting Params ========
    num_blocks_early_stop: 6
    num_samples: 24 # Original model uses 24. Reduced to 4 to allow training on 12GB GPUs (with significant performance penalty)
    sample_depth: 3 # Stop the ray after certain depth
    coarse_deterministic_sampling: False
    sample_use_box_boundaries: False # Including voxel boundaries into the sample

    # ======== Blender ========
    raw_noise_std: 0.0
    dists_scale: 0.25
    clip_feat_map: True
    # Prevent sky from leaking to the foreground.
    keep_sky_out: True
    keep_sky_out_avgpool: True
    sky_global_avgpool: True

    # ======== Label translator ========
    reduced_label_set: True
    use_label_smooth: True
    use_label_smooth_real: True
    use_label_smooth_pgt: True
    label_smooth_dia: 11

    # ======== Camera sampler ========
    camera_sampler_type: 'traditional'
    cam_res: [360, 640] # Camera resolution before cropping.
    crop_size: [256, 256] # Actual crop size is crop_size+pad. It should generally match random_crop_h_w in dataloader.

    # Threshold for rejecting camera poses that will result in a seg mask with low entropy.
    # Generally, 0.5 min, 0.8 max.
    camera_min_entropy: 0.75

    # Threshold for rejecting camera poses that are too close to the objects.
    camera_rej_avg_depth: 2.0
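These inference options are read through imaginaire's Config object and can be overridden per request, which is exactly what app_gradio.py above does. A short sketch of that pattern (the override values here are just examples):

from imaginaire.config import Config

cfg = Config('./configs/scenedreamer_inference.yaml')
cfg.inference_args.camera_mode = 1            # pick one of the camera trajectories listed above
cfg.inference_args.cam_maxstep = 20           # number of rendered frames
cfg.inference_args.resolution_hw = [270, 480]
# The generator built from cfg.gen is then called with **vars(cfg.inference_args),
# as in get_video() in app_gradio.py.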
configs/scenedreamer_train.yaml
ADDED
@@ -0,0 +1,223 @@
image_save_iter: 5000
snapshot_save_epoch: 5
snapshot_save_iter: 10000
max_epoch: 400
logging_iter: 10

trainer:
    type: imaginaire.trainers.gancraft
    model_average_config:
        enabled: False
    amp_config:
        enabled: False
    perceptual_loss:
        mode: 'vgg19'
        layers: ['relu_3_1', 'relu_4_1', 'relu_5_1']
        weights: [0.125, 0.25, 1.0]
    loss_weight:
        l2: 10.0
        gan: 0.5
        pseudo_gan: 0.5
        perceptual: 10.0
        kl: 0.05
    init:
        type: xavier
        gain: 0.02

    # SPADE/GauGAN model for pseudo-GT generation.
    gaugan_loader:
        config: configs/landscape1m.yaml

    image_to_tensorboard: True
    distributed_data_parallel_params:
        find_unused_parameters: False
        broadcast_buffers: False

gen_opt:
    type: adam
    lr: 0.0001
    eps: 1.e-7
    adam_beta1: 0.
    adam_beta2: 0.999
    lr_policy:
        iteration_mode: False
        type: step
        step_size: 400
        gamma: 0.1
    param_groups:
        world_encoder:
            lr: 0.0005
        hash_encoder:
            lr: 0.0001
        render_net:
            lr: 0.0001
        sky_net:
            lr: 0.0001
        style_net:
            lr: 0.0001
        style_encoder:
            lr: 0.0001
        denoiser:
            lr: 0.0001

dis_opt:
    type: adam
    lr: 0.0004
    eps: 1.e-7
    adam_beta1: 0.
    adam_beta2: 0.999
    lr_policy:
        iteration_mode: False
        type: step
        step_size: 400
        gamma: 0.1

gen:
    type: imaginaire.generators.scenedreamer
    pcg_dataset_path: ./data/terrain_cache
    pcg_cache: True
    scene_size: 2048

    blk_feat_dim: 64

    pe_lvl_feat: 4
    pe_incl_orig_feat: False
    pe_no_pe_feat_dim: 40
    pe_lvl_raydir: 0
    pe_incl_orig_raydir: False
    style_dims: 128 # Set to 0 to disable style.
    interm_style_dims: 256
    final_feat_dim: 64

    # Number of pixels removed from each edge to reduce boundary artifact of CNN
    # both sides combined (8 -> 4 on left and 4 on right).
    pad: 6

    # ======== Sky network ========
    pe_lvl_raydir_sky: 5
    pe_incl_orig_raydir_sky: True

    # ======== Style Encoder =========
    # Comment out to disable style encoder.
    style_enc:
        num_filters: 64
        kernel_size: 3
        weight_norm_type: 'none'

    stylenet_model: StyleMLP
    stylenet_model_kwargs:
        normalize_input: True
        num_layers: 5

    mlp_model: RenderMLP
    mlp_model_kwargs:
        use_seg: True

    # ======== Ray Casting Params ========
    num_blocks_early_stop: 6
    num_samples: 24 # Decrease it if you got OOM on lowend GPU
    sample_depth: 3 # Stop the ray after certain depth
    coarse_deterministic_sampling: False
    sample_use_box_boundaries: False # Including voxel boundaries into the sample

    # ======== Blender ========
    raw_noise_std: 0.0
    dists_scale: 0.25
    clip_feat_map: True
    # Prevent sky from leaking to the foreground.
    keep_sky_out: True
    keep_sky_out_avgpool: True
    sky_global_avgpool: True

    # ======== Label translator ========
    reduced_label_set: True
    use_label_smooth: True
    use_label_smooth_real: True
    use_label_smooth_pgt: True
    label_smooth_dia: 11

    # ======== Camera sampler ========
    camera_sampler_type: 'traditional'
    cam_res: [360, 640] # Camera resolution before cropping.
    crop_size: [256, 256] # Actual crop size is crop_size+pad. It should generally match random_crop_h_w in dataloader.

    # Threshold for rejecting camera poses that will result in a seg mask with low entropy.
    # Generally, 0.5 min, 0.8 max.
    camera_min_entropy: 0.75

    # Threshold for rejecting camera poses that are too close to the objects.
    camera_rej_avg_depth: 2.0

dis:
    type: imaginaire.discriminators.gancraft
    image_channels: 3
    num_labels: 12 # Same as num_reduced_lbls.
    use_label: True
    num_filters: 128
    fpse_kernel_size: 3
    activation_norm_type: 'none'
    weight_norm_type: spectral
    smooth_resample: True

# Data options.
data:
    type: imaginaire.datasets.paired_images
    num_workers: 8
    input_types:
        - images:
            ext: jpg
            num_channels: 3
            normalize: True
            use_dont_care: False
        - seg_maps:
            ext: png
            num_channels: 1
            is_mask: True
            normalize: False

    full_data_ops: imaginaire.model_utils.label::make_one_hot, imaginaire.model_utils.label::concat_labels
    use_dont_care: False
    one_hot_num_classes:
        seg_maps: 184
    input_labels:
        - seg_maps

    # Which lmdb contains the ground truth image.
    input_image:
        - images

    # Train dataset details.
    train:
        dataset_type: lmdb
        # Input LMDBs.
        roots:
            - ./data/lhq_lmdb/train
        # Batch size per GPU.
        batch_size: 1
        # Data augmentations to be performed in given order.
        augmentations:
            resize_smallest_side: 256
            # Rotate in (-rotate, rotate) in degrees.
            rotate: 0
            # Scale image by factor \in [1, 1+random_scale_limit].
            random_scale_limit: 0.2
            # Horizontal flip?
            horizontal_flip: True
            # Crop size.
            random_crop_h_w: 256, 256
    # Train dataset details.
    val:
        dataset_type: lmdb
        # Input LMDBs.
        roots:
            - ./data/lhq_lmdb/val
        # Batch size per GPU.
        batch_size: 1
        # Data augmentations to be performed in given order.
        augmentations:
            # Crop size.
            resize_h_w: 256, 256

test_data:
    type: imaginaire.datasets.dummy
    num_workers: 0
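This config wires the SceneDreamer generator to the GanCraft-style trainer and to a SPADE model used for pseudo-ground-truth supervision (gaugan_loader). The training entry point itself is not part of this upload; assuming an imaginaire-style train.py script, a launch would look roughly like the line below (script name and flag are assumptions, not shown in this diff):

# Hypothetical command; train.py is not included in this file list.
python train.py --config configs/scenedreamer_train.yaml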
encoding.py
ADDED
@@ -0,0 +1,67 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class FreqEncoder(nn.Module):
    def __init__(self, input_dim, max_freq_log2, N_freqs,
                 log_sampling=True, include_input=True,
                 periodic_fns=(torch.sin, torch.cos)):

        super().__init__()

        self.input_dim = input_dim
        self.include_input = include_input
        self.periodic_fns = periodic_fns

        self.output_dim = 0
        if self.include_input:
            self.output_dim += self.input_dim

        self.output_dim += self.input_dim * N_freqs * len(self.periodic_fns)

        if log_sampling:
            self.freq_bands = 2. ** torch.linspace(0., max_freq_log2, N_freqs)
        else:
            self.freq_bands = torch.linspace(2. ** 0., 2. ** max_freq_log2, N_freqs)

        self.freq_bands = self.freq_bands.numpy().tolist()

    def forward(self, input, **kwargs):

        out = []
        if self.include_input:
            out.append(input)

        for i in range(len(self.freq_bands)):
            freq = self.freq_bands[i]
            for p_fn in self.periodic_fns:
                out.append(p_fn(input * freq))

        out = torch.cat(out, dim=-1)


        return out

def get_encoder(encoding, input_dim=3,
                multires=6,
                degree=4,
                num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=2048, align_corners=False,
                **kwargs):

    if encoding == 'None':
        return lambda x, **kwargs: x, input_dim

    elif encoding == 'hashgrid':
        from gridencoder import GridEncoder
        encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='hash', align_corners=align_corners)

    elif encoding == 'tiledgrid':
        from gridencoder import GridEncoder
        encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='tiled', align_corners=align_corners)
    elif encoding == 'varhashgrid':
        from gridencoder.grid import VarGridEncoder
        encoder = VarGridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='tiled', align_corners=align_corners, hash_entries = kwargs['hash_feat_dim'])
    else:
        raise NotImplementedError('Unknown encoding mode, choose from [None, frequency, sphere_harmonics, hashgrid, tiledgrid]')

    return encoder, encoder.output_dim
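A small sketch of how get_encoder() above is typically called to build the multiresolution hash-grid encoder (this needs the CUDA extension in gridencoder/ and a GPU; the tensor shapes are illustrative):

import torch
from encoding import get_encoder

encoder, feat_dim = get_encoder('hashgrid', input_dim=3, num_levels=16, level_dim=2,
                                log2_hashmap_size=19, desired_resolution=2048)
encoder = encoder.cuda()

xyz = torch.rand(4096, 3, device='cuda') * 2 - 1   # points in [-1, 1]
feats = encoder(xyz, bound=1)                       # -> [4096, num_levels * level_dim] = [4096, 32]
print(feat_dim, feats.shape)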
environment.yaml
ADDED
@@ -0,0 +1,44 @@
name: scenedreamer
channels:
  - pytorch
  - nvidia
dependencies:
  - python=3.9
  - pytorch=1.12.0
  - cudatoolkit=11.3
  - torchvision
  - pip
  - numpy
  - scipy
  - scikit-image
  - pip:
    - einops
    - noise
    - opencv-python
    - cmake
    - pynvml
    - Pillow>=8.3.2
    - tqdm==4.35.0
    - wget
    - cython
    - lmdb
    - av
    - opencv-python
    - opencv-contrib-python
    - imutils
    - imageio-ffmpeg
    - qimage2ndarray
    - albumentations
    - requests==2.25.1
    - nvidia-ml-py3==7.352.0
    - pyglet
    - timm
    - diskcache
    - boto3
    - awscli_plugin_endpoint
    - awscli
    - rsa
    - wandb
    - tensorboard
    - lpips
    - matplotlib
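The environment is created the standard conda way; a one-liner for reference:

conda env create -f environment.yaml && conda activate scenedreamer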
gridencoder/__init__.py
ADDED
@@ -0,0 +1 @@
from .grid import GridEncoder
gridencoder/backend.py
ADDED
@@ -0,0 +1,40 @@
import os
from torch.utils.cpp_extension import load

_src_path = os.path.dirname(os.path.abspath(__file__))

nvcc_flags = [
    '-O3', '-std=c++17',
    '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
]

if os.name == "posix":
    c_flags = ['-O3', '-std=c++17']
elif os.name == "nt":
    c_flags = ['/O2', '/std:c++17']

    # find cl.exe
    def find_cl_path():
        import glob
        for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
            paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
            if paths:
                return paths[0]

    # If cl.exe is not on path, try to find it.
    if os.system("where cl.exe >nul 2>nul") != 0:
        cl_path = find_cl_path()
        if cl_path is None:
            raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
        os.environ["PATH"] += ";" + cl_path

_backend = load(name='_grid_encoder',
                extra_cflags=c_flags,
                extra_cuda_cflags=nvcc_flags,
                sources=[os.path.join(_src_path, 'src', f) for f in [
                    'gridencoder.cu',
                    'bindings.cpp',
                ]],
                )

__all__ = ['_backend']
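backend.py is the just-in-time fallback: grid.py below first tries the pre-built _gridencoder extension and only falls back to this module, which compiles gridencoder.cu and bindings.cpp via torch.utils.cpp_extension.load on first import. In practice that means the first of these imports can take a while while nvcc runs:

# Triggers the JIT build if the extension was not installed via gridencoder/setup.py.
from gridencoder import GridEncoder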
gridencoder/grid.py
ADDED
@@ -0,0 +1,224 @@
import numpy as np

import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.cuda.amp import custom_bwd, custom_fwd

try:
    import _gridencoder as _backend
except ImportError:
    from .backend import _backend

_gridtype_to_id = {
    'hash': 0,
    'tiled': 1,
}

class _grid_encode(Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False, gridtype=0, align_corners=False):
        # inputs: [B, D], float in [0, 1]
        # embeddings: [sO, C], float
        # offsets: [L + 1], int
        # RETURN: [B, F], float

        inputs = inputs.contiguous()

        B, D = inputs.shape # batch size, coord dim
        L = offsets.shape[0] - 1 # level
        C = embeddings.shape[1] # embedding dim for each level
        S = np.log2(per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f
        H = base_resolution # base resolution

        # manually handle autocast (only use half precision embeddings, inputs must be float for enough precision)
        # if C % 2 != 0, force float, since half for atomicAdd is very slow.
        if torch.is_autocast_enabled() and C % 2 == 0:
            embeddings = embeddings.to(torch.half)

        # L first, optimize cache for cuda kernel, but needs an extra permute later
        outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype)

        if calc_grad_inputs:
            dy_dx = torch.empty(B, L * D * C, device=inputs.device, dtype=embeddings.dtype)
        else:
            dy_dx = torch.empty(1, device=inputs.device, dtype=embeddings.dtype) # placeholder... TODO: a better way?

        _backend.grid_encode_forward(inputs, embeddings, offsets, outputs, B, D, C, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners)

        # permute back to [B, L * C]
        outputs = outputs.permute(1, 0, 2).reshape(B, L * C)

        ctx.save_for_backward(inputs, embeddings, offsets, dy_dx)
        ctx.dims = [B, D, C, L, S, H, gridtype]
        ctx.calc_grad_inputs = calc_grad_inputs
        ctx.align_corners = align_corners

        return outputs

    @staticmethod
    #@once_differentiable
    @custom_bwd
    def backward(ctx, grad):

        inputs, embeddings, offsets, dy_dx = ctx.saved_tensors
        B, D, C, L, S, H, gridtype = ctx.dims
        calc_grad_inputs = ctx.calc_grad_inputs
        align_corners = ctx.align_corners

        # grad: [B, L * C] --> [L, B, C]
        grad = grad.view(B, L, C).permute(1, 0, 2).contiguous()

        grad_embeddings = torch.zeros_like(embeddings)

        if calc_grad_inputs:
            grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype)
        else:
            grad_inputs = torch.zeros(1, device=inputs.device, dtype=embeddings.dtype)

        _backend.grid_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings, B, D, C, L, S, H, calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners)

        if calc_grad_inputs:
            grad_inputs = grad_inputs.to(inputs.dtype)
            return grad_inputs, grad_embeddings, None, None, None, None, None, None
        else:
            return None, grad_embeddings, None, None, None, None, None, None


grid_encode = _grid_encode.apply


class GridEncoder(nn.Module):
    def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=None, gridtype='hash', align_corners=False):
        super().__init__()

        # the finest resolution desired at the last level, if provided, overridee per_level_scale
        if desired_resolution is not None:
            per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1))

        self.input_dim = input_dim # coord dims, 2 or 3
        self.num_levels = num_levels # num levels, each level multiply resolution by 2
        self.level_dim = level_dim # encode channels per level
        self.per_level_scale = per_level_scale # multiply resolution by this scale at each level.
        self.log2_hashmap_size = log2_hashmap_size
        self.base_resolution = base_resolution
        self.output_dim = num_levels * level_dim
        self.gridtype = gridtype
        self.gridtype_id = _gridtype_to_id[gridtype] # "tiled" or "hash"
        self.align_corners = align_corners

        # allocate parameters
        offsets = []
        offset = 0
        self.max_params = 2 ** log2_hashmap_size
        for i in range(num_levels):
            resolution = int(np.ceil(base_resolution * per_level_scale ** i))
            params_in_level = min(self.max_params, (resolution if align_corners else resolution + 1) ** input_dim) # limit max number
            params_in_level = int(np.ceil(params_in_level / 8) * 8) # make divisible
            offsets.append(offset)
            offset += params_in_level
        offsets.append(offset)
        offsets = torch.from_numpy(np.array(offsets, dtype=np.int32))
        self.register_buffer('offsets', offsets)

        self.n_params = offsets[-1] * level_dim

        # parameters
        self.embeddings = nn.Parameter(torch.empty(offset, level_dim))

        self.reset_parameters()

    def reset_parameters(self):
        std = 1e-4
        self.embeddings.data.uniform_(-std, std)

    def __repr__(self):
        return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners}"

    def forward(self, inputs, bound=1):
        # inputs: [..., input_dim], normalized real world positions in [-bound, bound]
        # return: [..., num_levels * level_dim]

        inputs = (inputs + bound) / (2 * bound) # map to [0, 1]

        #print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item())

        prefix_shape = list(inputs.shape[:-1])
        inputs = inputs.view(-1, self.input_dim)

        outputs = grid_encode(inputs, self.embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad, self.gridtype_id, self.align_corners)
        outputs = outputs.view(prefix_shape + [self.output_dim])

        #print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item())

        return outputs

class VarGridEncoder(nn.Module):
    def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=None, gridtype='hash', align_corners=False, hash_entries=None):
        super().__init__()

        # the finest resolution desired at the last level, if provided, overridee per_level_scale
        if desired_resolution is not None:
            per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1))

        self.input_dim = input_dim # coord dims, 2 or 3
        self.num_levels = num_levels # num levels, each level multiply resolution by 2
        self.level_dim = level_dim # encode channels per level
        self.per_level_scale = per_level_scale # multiply resolution by this scale at each level.
        self.log2_hashmap_size = log2_hashmap_size
        self.base_resolution = base_resolution
        self.output_dim = num_levels * level_dim
        self.gridtype = gridtype
        self.gridtype_id = _gridtype_to_id[gridtype] # "tiled" or "hash"
        self.align_corners = align_corners

        # allocate parameters
        offsets = []
        offset = 0
        self.max_params = 2 ** log2_hashmap_size
        for i in range(num_levels):
            resolution = int(np.ceil(base_resolution * per_level_scale ** i))
            params_in_level = min(self.max_params, (resolution if align_corners else resolution + 1) ** input_dim) # limit max number
            params_in_level = int(np.ceil(params_in_level / 8) * 8) # make divisible
            offsets.append(offset)
            offset += params_in_level
        offsets.append(offset)
        offsets = torch.from_numpy(np.array(offsets, dtype=np.int32))
        self.register_buffer('offsets', offsets)

        self.n_params = offsets[-1] * level_dim
        self.level_dim = level_dim
        self.offset = offset

        # parameters
        self.embeddings = nn.Parameter(torch.empty(offset - hash_entries, level_dim))

        self.reset_parameters()

    def reset_parameters(self):
        std = 1e-4
        self.embeddings.data.uniform_(-std, std)

    def __repr__(self):
        return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners}"

    def forward(self, inputs, embeddings, bound=1):
        # inputs: [..., input_dim], normalized real world positions in [-bound, bound]
        # return: [..., num_levels * level_dim]
        input_embeddings = torch.cat([embeddings, self.embeddings], dim=0)

        inputs = (inputs + bound) / (2 * bound) # map to [0, 1]

        #print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item())

        prefix_shape = list(inputs.shape[:-1])
        inputs = inputs.view(-1, self.input_dim)

        outputs = grid_encode(inputs, input_embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad, self.gridtype_id, self.align_corners)
        outputs = outputs.view(prefix_shape + [self.output_dim])

        #print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item())

        return outputs
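A minimal sketch of exercising GridEncoder directly and checking that gradients reach the hash-table embeddings (CUDA is required, since the forward dispatches to the compiled kernel):

import torch
from gridencoder import GridEncoder

enc = GridEncoder(input_dim=3, num_levels=16, level_dim=2,
                  log2_hashmap_size=19, desired_resolution=2048).cuda()
print(enc)   # __repr__ summarizes levels, resolution range and parameter count

x = torch.rand(4096, 3, device='cuda') * 2 - 1   # positions in [-bound, bound]
y = enc(x, bound=1)                               # [4096, 32]
y.sum().backward()
print(enc.embeddings.grad.abs().max())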
gridencoder/setup.py
ADDED
@@ -0,0 +1,50 @@
import os
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

_src_path = os.path.dirname(os.path.abspath(__file__))

nvcc_flags = [
    '-O3', '-std=c++17',
    '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
]

if os.name == "posix":
    c_flags = ['-O3', '-std=c++17']
elif os.name == "nt":
    c_flags = ['/O2', '/std:c++17']

    # find cl.exe
    def find_cl_path():
        import glob
        for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
            paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
            if paths:
                return paths[0]

    # If cl.exe is not on path, try to find it.
    if os.system("where cl.exe >nul 2>nul") != 0:
        cl_path = find_cl_path()
        if cl_path is None:
            raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
        os.environ["PATH"] += ";" + cl_path

setup(
    name='gridencoder', # package name, import this to use python API
    ext_modules=[
        CUDAExtension(
            name='_gridencoder', # extension name, import this to use CUDA API
            sources=[os.path.join(_src_path, 'src', f) for f in [
                'gridencoder.cu',
                'bindings.cpp',
            ]],
            extra_compile_args={
                'cxx': c_flags,
                'nvcc': nvcc_flags,
            }
        ),
    ],
    cmdclass={
        'build_ext': BuildExtension,
    }
)
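With this setup.py the extension can be built ahead of time instead of relying on the JIT path in backend.py; for example:

cd gridencoder
pip install .
# or build in place:
python setup.py build_ext --inplace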
gridencoder/src/bindings.cpp
ADDED
@@ -0,0 +1,8 @@
#include <torch/extension.h>

#include "gridencoder.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)");
    m.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)");
}
gridencoder/src/gridencoder.cu
ADDED
@@ -0,0 +1,478 @@
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <ATen/cuda/CUDAContext.h>
#include <torch/torch.h>

#include <algorithm>
#include <stdexcept>

#include <stdint.h>
#include <cstdio>


#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")


// just for compatability of half precision in AT_DISPATCH_FLOATING_TYPES_AND_HALF...
static inline __device__ at::Half atomicAdd(at::Half *address, at::Half val) {
  // requires CUDA >= 10 and ARCH >= 70
  // this is very slow compared to float or __half2, and never used.
  //return atomicAdd(reinterpret_cast<__half*>(address), val);
}


template <typename T>
static inline __host__ __device__ T div_round_up(T val, T divisor) {
    return (val + divisor - 1) / divisor;
}


template <uint32_t D>
__device__ uint32_t fast_hash(const uint32_t pos_grid[D]) {
    static_assert(D <= 7, "fast_hash can only hash up to 7 dimensions.");

    // While 1 is technically not a good prime for hashing (or a prime at all), it helps memory coherence
    // and is sufficient for our use case of obtaining a uniformly colliding index from high-dimensional
    // coordinates.
    constexpr uint32_t primes[7] = { 1, 2654435761, 805459861, 3674653429, 2097192037, 1434869437, 2165219737 };

    uint32_t result = 0;
    #pragma unroll
    for (uint32_t i = 0; i < D; ++i) {
        result ^= pos_grid[i] * primes[i];
    }

    return result;
}


template <uint32_t D, uint32_t C>
__device__ uint32_t get_grid_index(const uint32_t gridtype, const bool align_corners, const uint32_t ch, const uint32_t hashmap_size, const uint32_t resolution, const uint32_t pos_grid[D]) {
    uint32_t stride = 1;
    uint32_t index = 0;

    #pragma unroll
    for (uint32_t d = 0; d < D && stride <= hashmap_size; d++) {
        index += pos_grid[d] * stride;
        stride *= align_corners ? resolution: (resolution + 1);
    }

    // NOTE: for NeRF, the hash is in fact not necessary. Check https://github.com/NVlabs/instant-ngp/issues/97.
    // gridtype: 0 == hash, 1 == tiled
    if (gridtype == 0 && stride > hashmap_size) {
        index = fast_hash<D>(pos_grid);
    }

    return (index % hashmap_size) * C + ch;
}


template <typename scalar_t, uint32_t D, uint32_t C>
__global__ void kernel_grid(
    const float * __restrict__ inputs,
    const scalar_t * __restrict__ grid,
    const int * __restrict__ offsets,
    scalar_t * __restrict__ outputs,
    const uint32_t B, const uint32_t L, const float S, const uint32_t H,
    const bool calc_grad_inputs,
    scalar_t * __restrict__ dy_dx,
    const uint32_t gridtype,
    const bool align_corners
) {
    const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;

    if (b >= B) return;

    const uint32_t level = blockIdx.y;

    // locate
    grid += (uint32_t)offsets[level] * C;
    inputs += b * D;
    outputs += level * B * C + b * C;

    // check input range (should be in [0, 1])
    bool flag_oob = false;
    #pragma unroll
    for (uint32_t d = 0; d < D; d++) {
        if (inputs[d] < 0 || inputs[d] > 1) {
            flag_oob = true;
        }
    }
    // if input out of bound, just set output to 0
    if (flag_oob) {
        #pragma unroll
        for (uint32_t ch = 0; ch < C; ch++) {
            outputs[ch] = 0;
        }
        if (calc_grad_inputs) {
            dy_dx += b * D * L * C + level * D * C; // B L D C
            #pragma unroll
            for (uint32_t d = 0; d < D; d++) {
                #pragma unroll
                for (uint32_t ch = 0; ch < C; ch++) {
                    dy_dx[d * C + ch] = 0;
                }
            }
        }
        return;
    }

    const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
    const float scale = exp2f(level * S) * H - 1.0f;
    const uint32_t resolution = (uint32_t)ceil(scale) + 1;
|
128 |
+
|
129 |
+
// calculate coordinate
|
130 |
+
float pos[D];
|
131 |
+
uint32_t pos_grid[D];
|
132 |
+
|
133 |
+
#pragma unroll
|
134 |
+
for (uint32_t d = 0; d < D; d++) {
|
135 |
+
pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
|
136 |
+
pos_grid[d] = floorf(pos[d]);
|
137 |
+
pos[d] -= (float)pos_grid[d];
|
138 |
+
}
|
139 |
+
|
140 |
+
//printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]);
|
141 |
+
|
142 |
+
// interpolate
|
143 |
+
scalar_t results[C] = {0}; // temp results in register
|
144 |
+
|
145 |
+
#pragma unroll
|
146 |
+
for (uint32_t idx = 0; idx < (1 << D); idx++) {
|
147 |
+
float w = 1;
|
148 |
+
uint32_t pos_grid_local[D];
|
149 |
+
|
150 |
+
#pragma unroll
|
151 |
+
for (uint32_t d = 0; d < D; d++) {
|
152 |
+
if ((idx & (1 << d)) == 0) {
|
153 |
+
w *= 1 - pos[d];
|
154 |
+
pos_grid_local[d] = pos_grid[d];
|
155 |
+
} else {
|
156 |
+
w *= pos[d];
|
157 |
+
pos_grid_local[d] = pos_grid[d] + 1;
|
158 |
+
}
|
159 |
+
}
|
160 |
+
|
161 |
+
uint32_t index = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);
|
162 |
+
|
163 |
+
// writing to register (fast)
|
164 |
+
#pragma unroll
|
165 |
+
for (uint32_t ch = 0; ch < C; ch++) {
|
166 |
+
results[ch] += w * grid[index + ch];
|
167 |
+
}
|
168 |
+
|
169 |
+
//printf("[b=%d, l=%d] int %d, idx %d, w %f, val %f\n", b, level, idx, index, w, grid[index]);
|
170 |
+
}
|
171 |
+
|
172 |
+
// writing to global memory (slow)
|
173 |
+
#pragma unroll
|
174 |
+
for (uint32_t ch = 0; ch < C; ch++) {
|
175 |
+
outputs[ch] = results[ch];
|
176 |
+
}
|
177 |
+
|
178 |
+
// prepare dy_dx for calc_grad_inputs
|
179 |
+
// differentiable (soft) indexing: https://discuss.pytorch.org/t/differentiable-indexing/17647/9
|
180 |
+
if (calc_grad_inputs) {
|
181 |
+
|
182 |
+
dy_dx += b * D * L * C + level * D * C; // B L D C
|
183 |
+
|
184 |
+
#pragma unroll
|
185 |
+
for (uint32_t gd = 0; gd < D; gd++) {
|
186 |
+
|
187 |
+
scalar_t results_grad[C] = {0};
|
188 |
+
|
189 |
+
#pragma unroll
|
190 |
+
for (uint32_t idx = 0; idx < (1 << (D - 1)); idx++) {
|
191 |
+
float w = scale;
|
192 |
+
uint32_t pos_grid_local[D];
|
193 |
+
|
194 |
+
#pragma unroll
|
195 |
+
for (uint32_t nd = 0; nd < D - 1; nd++) {
|
196 |
+
const uint32_t d = (nd >= gd) ? (nd + 1) : nd;
|
197 |
+
|
198 |
+
if ((idx & (1 << nd)) == 0) {
|
199 |
+
w *= 1 - pos[d];
|
200 |
+
pos_grid_local[d] = pos_grid[d];
|
201 |
+
} else {
|
202 |
+
w *= pos[d];
|
203 |
+
pos_grid_local[d] = pos_grid[d] + 1;
|
204 |
+
}
|
205 |
+
}
|
206 |
+
|
207 |
+
pos_grid_local[gd] = pos_grid[gd];
|
208 |
+
uint32_t index_left = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);
|
209 |
+
pos_grid_local[gd] = pos_grid[gd] + 1;
|
210 |
+
uint32_t index_right = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);
|
211 |
+
|
212 |
+
#pragma unroll
|
213 |
+
for (uint32_t ch = 0; ch < C; ch++) {
|
214 |
+
results_grad[ch] += w * (grid[index_right + ch] - grid[index_left + ch]);
|
215 |
+
}
|
216 |
+
}
|
217 |
+
|
218 |
+
#pragma unroll
|
219 |
+
for (uint32_t ch = 0; ch < C; ch++) {
|
220 |
+
dy_dx[gd * C + ch] = results_grad[ch];
|
221 |
+
}
|
222 |
+
}
|
223 |
+
}
|
224 |
+
}
|
225 |
+
|
226 |
+
|
227 |
+
template <typename scalar_t, uint32_t D, uint32_t C, uint32_t N_C>
|
228 |
+
__global__ void kernel_grid_backward(
|
229 |
+
const scalar_t * __restrict__ grad,
|
230 |
+
const float * __restrict__ inputs,
|
231 |
+
const scalar_t * __restrict__ grid,
|
232 |
+
const int * __restrict__ offsets,
|
233 |
+
scalar_t * __restrict__ grad_grid,
|
234 |
+
const uint32_t B, const uint32_t L, const float S, const uint32_t H,
|
235 |
+
const uint32_t gridtype,
|
236 |
+
const bool align_corners
|
237 |
+
) {
|
238 |
+
const uint32_t b = (blockIdx.x * blockDim.x + threadIdx.x) * N_C / C;
|
239 |
+
if (b >= B) return;
|
240 |
+
|
241 |
+
const uint32_t level = blockIdx.y;
|
242 |
+
const uint32_t ch = (blockIdx.x * blockDim.x + threadIdx.x) * N_C - b * C;
|
243 |
+
|
244 |
+
// locate
|
245 |
+
grad_grid += offsets[level] * C;
|
246 |
+
inputs += b * D;
|
247 |
+
grad += level * B * C + b * C + ch; // L, B, C
|
248 |
+
|
249 |
+
const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
|
250 |
+
const float scale = exp2f(level * S) * H - 1.0f;
|
251 |
+
const uint32_t resolution = (uint32_t)ceil(scale) + 1;
|
252 |
+
|
253 |
+
// check input range (should be in [0, 1])
|
254 |
+
#pragma unroll
|
255 |
+
for (uint32_t d = 0; d < D; d++) {
|
256 |
+
if (inputs[d] < 0 || inputs[d] > 1) {
|
257 |
+
return; // grad is init as 0, so we simply return.
|
258 |
+
}
|
259 |
+
}
|
260 |
+
|
261 |
+
// calculate coordinate
|
262 |
+
float pos[D];
|
263 |
+
uint32_t pos_grid[D];
|
264 |
+
|
265 |
+
#pragma unroll
|
266 |
+
for (uint32_t d = 0; d < D; d++) {
|
267 |
+
pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
|
268 |
+
pos_grid[d] = floorf(pos[d]);
|
269 |
+
pos[d] -= (float)pos_grid[d];
|
270 |
+
}
|
271 |
+
|
272 |
+
scalar_t grad_cur[N_C] = {0}; // fetch to register
|
273 |
+
#pragma unroll
|
274 |
+
for (uint32_t c = 0; c < N_C; c++) {
|
275 |
+
grad_cur[c] = grad[c];
|
276 |
+
}
|
277 |
+
|
278 |
+
// interpolate
|
279 |
+
#pragma unroll
|
280 |
+
for (uint32_t idx = 0; idx < (1 << D); idx++) {
|
281 |
+
float w = 1;
|
282 |
+
uint32_t pos_grid_local[D];
|
283 |
+
|
284 |
+
#pragma unroll
|
285 |
+
for (uint32_t d = 0; d < D; d++) {
|
286 |
+
if ((idx & (1 << d)) == 0) {
|
287 |
+
w *= 1 - pos[d];
|
288 |
+
pos_grid_local[d] = pos_grid[d];
|
289 |
+
} else {
|
290 |
+
w *= pos[d];
|
291 |
+
pos_grid_local[d] = pos_grid[d] + 1;
|
292 |
+
}
|
293 |
+
}
|
294 |
+
|
295 |
+
uint32_t index = get_grid_index<D, C>(gridtype, align_corners, ch, hashmap_size, resolution, pos_grid_local);
|
296 |
+
|
297 |
+
// atomicAdd for __half is slow (especially for large values), so we use __half2 if N_C % 2 == 0
|
298 |
+
// TODO: use float which is better than __half, if N_C % 2 != 0
|
299 |
+
if (std::is_same<scalar_t, at::Half>::value && N_C % 2 == 0) {
|
300 |
+
#pragma unroll
|
301 |
+
for (uint32_t c = 0; c < N_C; c += 2) {
|
302 |
+
// process two __half at once (by interpreting as a __half2)
|
303 |
+
__half2 v = {(__half)(w * grad_cur[c]), (__half)(w * grad_cur[c + 1])};
|
304 |
+
atomicAdd((__half2*)&grad_grid[index + c], v);
|
305 |
+
}
|
306 |
+
// float, or __half when N_C % 2 != 0 (which means C == 1)
|
307 |
+
} else {
|
308 |
+
#pragma unroll
|
309 |
+
for (uint32_t c = 0; c < N_C; c++) {
|
310 |
+
atomicAdd(&grad_grid[index + c], w * grad_cur[c]);
|
311 |
+
}
|
312 |
+
}
|
313 |
+
}
|
314 |
+
}
|
315 |
+
|
316 |
+
|
317 |
+
template <typename scalar_t, uint32_t D, uint32_t C>
|
318 |
+
__global__ void kernel_input_backward(
|
319 |
+
const scalar_t * __restrict__ grad,
|
320 |
+
const scalar_t * __restrict__ dy_dx,
|
321 |
+
scalar_t * __restrict__ grad_inputs,
|
322 |
+
uint32_t B, uint32_t L
|
323 |
+
) {
|
324 |
+
const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
|
325 |
+
if (t >= B * D) return;
|
326 |
+
|
327 |
+
const uint32_t b = t / D;
|
328 |
+
const uint32_t d = t - b * D;
|
329 |
+
|
330 |
+
dy_dx += b * L * D * C;
|
331 |
+
|
332 |
+
scalar_t result = 0;
|
333 |
+
|
334 |
+
# pragma unroll
|
335 |
+
for (int l = 0; l < L; l++) {
|
336 |
+
# pragma unroll
|
337 |
+
for (int ch = 0; ch < C; ch++) {
|
338 |
+
result += grad[l * B * C + b * C + ch] * dy_dx[l * D * C + d * C + ch];
|
339 |
+
}
|
340 |
+
}
|
341 |
+
|
342 |
+
grad_inputs[t] = result;
|
343 |
+
}
|
344 |
+
|
345 |
+
|
346 |
+
template <typename scalar_t, uint32_t D>
|
347 |
+
void kernel_grid_wrapper(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners) {
|
348 |
+
static constexpr uint32_t N_THREAD = 512;
|
349 |
+
const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), L, 1 };
|
350 |
+
switch (C) {
|
351 |
+
case 1: kernel_grid<scalar_t, D, 1><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break;
|
352 |
+
case 2: kernel_grid<scalar_t, D, 2><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break;
|
353 |
+
case 4: kernel_grid<scalar_t, D, 4><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break;
|
354 |
+
case 8: kernel_grid<scalar_t, D, 8><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break;
|
355 |
+
default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
|
356 |
+
}
|
357 |
+
}
|
358 |
+
|
359 |
+
// inputs: [B, D], float, in [0, 1]
|
360 |
+
// embeddings: [sO, C], float
|
361 |
+
// offsets: [L + 1], uint32_t
|
362 |
+
// outputs: [L, B, C], float (L first, so only one level of hashmap needs to fit into cache at a time.)
|
363 |
+
// H: base resolution
|
364 |
+
// dy_dx: [B, L * D * C]
|
365 |
+
template <typename scalar_t>
|
366 |
+
void grid_encode_forward_cuda(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners) {
|
367 |
+
switch (D) {
|
368 |
+
case 2: kernel_grid_wrapper<scalar_t, 2>(inputs, embeddings, offsets, outputs, B, C, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break;
|
369 |
+
case 3: kernel_grid_wrapper<scalar_t, 3>(inputs, embeddings, offsets, outputs, B, C, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break;
|
370 |
+
case 4: kernel_grid_wrapper<scalar_t, 4>(inputs, embeddings, offsets, outputs, B, C, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break;
|
371 |
+
case 5: kernel_grid_wrapper<scalar_t, 5>(inputs, embeddings, offsets, outputs, B, C, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break;
|
372 |
+
default: throw std::runtime_error{"GridEncoding: D (input dims) must be 2, 3, 4, or 5."};
|
373 |
+
}
|
374 |
+
|
375 |
+
}
|
376 |
+
|
377 |
+
template <typename scalar_t, uint32_t D>
|
378 |
+
void kernel_grid_backward_wrapper(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners) {
|
379 |
+
static constexpr uint32_t N_THREAD = 256;
|
380 |
+
const uint32_t N_C = std::min(2u, C); // n_features_per_thread
|
381 |
+
const dim3 blocks_hashgrid = { div_round_up(B * C / N_C, N_THREAD), L, 1 };
|
382 |
+
switch (C) {
|
383 |
+
case 1:
|
384 |
+
kernel_grid_backward<scalar_t, D, 1, 1><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
|
385 |
+
if (calc_grad_inputs) kernel_input_backward<scalar_t, D, 1><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
|
386 |
+
break;
|
387 |
+
case 2:
|
388 |
+
kernel_grid_backward<scalar_t, D, 2, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
|
389 |
+
if (calc_grad_inputs) kernel_input_backward<scalar_t, D, 2><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
|
390 |
+
break;
|
391 |
+
case 4:
|
392 |
+
kernel_grid_backward<scalar_t, D, 4, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
|
393 |
+
if (calc_grad_inputs) kernel_input_backward<scalar_t, D, 4><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
|
394 |
+
break;
|
395 |
+
case 8:
|
396 |
+
kernel_grid_backward<scalar_t, D, 8, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
|
397 |
+
if (calc_grad_inputs) kernel_input_backward<scalar_t, D, 8><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
|
398 |
+
break;
|
399 |
+
default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
|
400 |
+
}
|
401 |
+
}
|
402 |
+
|
403 |
+
|
404 |
+
// grad: [L, B, C], float
|
405 |
+
// inputs: [B, D], float, in [0, 1]
|
406 |
+
// embeddings: [sO, C], float
|
407 |
+
// offsets: [L + 1], uint32_t
|
408 |
+
// grad_embeddings: [sO, C]
|
409 |
+
// H: base resolution
|
410 |
+
template <typename scalar_t>
|
411 |
+
void grid_encode_backward_cuda(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners) {
|
412 |
+
switch (D) {
|
413 |
+
case 2: kernel_grid_backward_wrapper<scalar_t, 2>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners); break;
|
414 |
+
case 3: kernel_grid_backward_wrapper<scalar_t, 3>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners); break;
|
415 |
+
case 4: kernel_grid_backward_wrapper<scalar_t, 4>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners); break;
|
416 |
+
case 5: kernel_grid_backward_wrapper<scalar_t, 5>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners); break;
|
417 |
+
default: throw std::runtime_error{"GridEncoding: D (input dims) must be 2, 3, 4, or 5."};
|
418 |
+
}
|
419 |
+
}
|
420 |
+
|
421 |
+
|
422 |
+
|
423 |
+
void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, at::Tensor dy_dx, const uint32_t gridtype, const bool align_corners) {
|
424 |
+
CHECK_CUDA(inputs);
|
425 |
+
CHECK_CUDA(embeddings);
|
426 |
+
CHECK_CUDA(offsets);
|
427 |
+
CHECK_CUDA(outputs);
|
428 |
+
CHECK_CUDA(dy_dx);
|
429 |
+
|
430 |
+
CHECK_CONTIGUOUS(inputs);
|
431 |
+
CHECK_CONTIGUOUS(embeddings);
|
432 |
+
CHECK_CONTIGUOUS(offsets);
|
433 |
+
CHECK_CONTIGUOUS(outputs);
|
434 |
+
CHECK_CONTIGUOUS(dy_dx);
|
435 |
+
|
436 |
+
CHECK_IS_FLOATING(inputs);
|
437 |
+
CHECK_IS_FLOATING(embeddings);
|
438 |
+
CHECK_IS_INT(offsets);
|
439 |
+
CHECK_IS_FLOATING(outputs);
|
440 |
+
CHECK_IS_FLOATING(dy_dx);
|
441 |
+
|
442 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
443 |
+
embeddings.scalar_type(), "grid_encode_forward", ([&] {
|
444 |
+
grid_encode_forward_cuda<scalar_t>(inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(), outputs.data_ptr<scalar_t>(), B, D, C, L, S, H, calc_grad_inputs, dy_dx.data_ptr<scalar_t>(), gridtype, align_corners);
|
445 |
+
}));
|
446 |
+
}
|
447 |
+
|
448 |
+
void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, const at::Tensor dy_dx, at::Tensor grad_inputs, const uint32_t gridtype, const bool align_corners) {
|
449 |
+
CHECK_CUDA(grad);
|
450 |
+
CHECK_CUDA(inputs);
|
451 |
+
CHECK_CUDA(embeddings);
|
452 |
+
CHECK_CUDA(offsets);
|
453 |
+
CHECK_CUDA(grad_embeddings);
|
454 |
+
CHECK_CUDA(dy_dx);
|
455 |
+
CHECK_CUDA(grad_inputs);
|
456 |
+
|
457 |
+
CHECK_CONTIGUOUS(grad);
|
458 |
+
CHECK_CONTIGUOUS(inputs);
|
459 |
+
CHECK_CONTIGUOUS(embeddings);
|
460 |
+
CHECK_CONTIGUOUS(offsets);
|
461 |
+
CHECK_CONTIGUOUS(grad_embeddings);
|
462 |
+
CHECK_CONTIGUOUS(dy_dx);
|
463 |
+
CHECK_CONTIGUOUS(grad_inputs);
|
464 |
+
|
465 |
+
CHECK_IS_FLOATING(grad);
|
466 |
+
CHECK_IS_FLOATING(inputs);
|
467 |
+
CHECK_IS_FLOATING(embeddings);
|
468 |
+
CHECK_IS_INT(offsets);
|
469 |
+
CHECK_IS_FLOATING(grad_embeddings);
|
470 |
+
CHECK_IS_FLOATING(dy_dx);
|
471 |
+
CHECK_IS_FLOATING(grad_inputs);
|
472 |
+
|
473 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
474 |
+
grad.scalar_type(), "grid_encode_backward", ([&] {
|
475 |
+
grid_encode_backward_cuda<scalar_t>(grad.data_ptr<scalar_t>(), inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(), grad_embeddings.data_ptr<scalar_t>(), B, D, C, L, S, H, calc_grad_inputs, dy_dx.data_ptr<scalar_t>(), grad_inputs.data_ptr<scalar_t>(), gridtype, align_corners);
|
476 |
+
}));
|
477 |
+
|
478 |
+
}
|
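The per-level indexing logic in kernel_grid is compact but easy to misread, so here is a pure-Python rendering of the same math (per-level scale and resolution, tiled-versus-hashed index) that mirrors fast_hash and get_grid_index above; the numeric values for S, H, the hashmap size, and the sample coordinate are made up for the demo.

    import math

    PRIMES = (1, 2654435761, 805459861, 3674653429, 2097192037, 1434869437, 2165219737)

    def fast_hash(pos_grid):
        # XOR of coordinate * prime, truncated to uint32 like the CUDA version.
        result = 0
        for i, p in enumerate(pos_grid):
            result ^= p * PRIMES[i]
        return result & 0xFFFFFFFF

    def get_grid_index(gridtype, align_corners, hashmap_size, resolution, pos_grid, C, ch=0):
        stride, index = 1, 0
        for p in pos_grid:
            if stride > hashmap_size:
                break
            index += p * stride
            stride *= resolution if align_corners else resolution + 1
        # gridtype: 0 == hash, 1 == tiled; hash only once the dense level no longer fits.
        if gridtype == 0 and stride > hashmap_size:
            index = fast_hash(pos_grid)
        return (index % hashmap_size) * C + ch

    # Per-level scale and resolution exactly as computed in the kernels.
    level, S, H = 3, math.log2(1.5), 16
    scale = 2.0 ** (level * S) * H - 1.0
    resolution = math.ceil(scale) + 1
    print(resolution, get_grid_index(0, False, 2 ** 19, resolution, (5, 7, 9), C=2))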
gridencoder/src/gridencoder.h
ADDED
@@ -0,0 +1,15 @@
#ifndef _HASH_ENCODE_H
#define _HASH_ENCODE_H

#include <stdint.h>
#include <torch/torch.h>

// inputs: [B, D], float, in [0, 1]
// embeddings: [sO, C], float
// offsets: [L + 1], uint32_t
// outputs: [B, L * C], float
// H: base resolution
void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, at::Tensor dy_dx, const uint32_t gridtype, const bool align_corners);
void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, const at::Tensor dy_dx, at::Tensor grad_inputs, const uint32_t gridtype, const bool align_corners);

#endif
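The comments above spell out the tensor contract for the two entry points. The following shape-check sketch calls grid_encode_forward from Python, assuming a compiled handle such as the _backend from the earlier loading sketch; all sizes are arbitrary demo values, S is one plausible choice of per-level log2 growth, and the output buffer is allocated as (L, B, C) because the kernel writes level-major.

    import math
    import torch

    B, D, C, L, H = 1024, 3, 2, 16, 16
    S = math.log2(2048 / H) / (L - 1)        # assumed per-level log2 growth for the demo
    entries_per_level = 2 ** 19
    offsets = torch.arange(L + 1, dtype=torch.int32, device="cuda") * entries_per_level
    sO = int(offsets[-1].item())

    inputs = torch.rand(B, D, device="cuda")              # must lie in [0, 1]
    embeddings = torch.zeros(sO, C, device="cuda")
    outputs = torch.empty(L, B, C, device="cuda")
    dy_dx = torch.empty(B, L * D * C, device="cuda")       # only filled when calc_grad_inputs=True

    _backend.grid_encode_forward(inputs, embeddings, offsets, outputs,
                                 B, D, C, L, S, H, False, dy_dx, 0, False)
    # Per-sample feature view, matching the [B, L * C] layout mentioned in the header comment.
    print(outputs.permute(1, 0, 2).reshape(B, L * C).shape)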
imaginaire/__init__.py
ADDED
@@ -0,0 +1,4 @@
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
imaginaire/config.py
ADDED
@@ -0,0 +1,238 @@
1 |
+
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# This work is made available under the Nvidia Source Code License-NC.
|
4 |
+
# To view a copy of this license, check out LICENSE.md
|
5 |
+
"""Config utilities for yml file."""
|
6 |
+
|
7 |
+
import collections
|
8 |
+
import functools
|
9 |
+
import os
|
10 |
+
import re
|
11 |
+
|
12 |
+
import yaml
|
13 |
+
from imaginaire.utils.distributed import master_only_print as print
|
14 |
+
|
15 |
+
DEBUG = False
|
16 |
+
USE_JIT = False
|
17 |
+
|
18 |
+
|
19 |
+
class AttrDict(dict):
|
20 |
+
"""Dict as attribute trick."""
|
21 |
+
|
22 |
+
def __init__(self, *args, **kwargs):
|
23 |
+
super(AttrDict, self).__init__(*args, **kwargs)
|
24 |
+
self.__dict__ = self
|
25 |
+
for key, value in self.__dict__.items():
|
26 |
+
if isinstance(value, dict):
|
27 |
+
self.__dict__[key] = AttrDict(value)
|
28 |
+
elif isinstance(value, (list, tuple)):
|
29 |
+
if isinstance(value[0], dict):
|
30 |
+
self.__dict__[key] = [AttrDict(item) for item in value]
|
31 |
+
else:
|
32 |
+
self.__dict__[key] = value
|
33 |
+
|
34 |
+
def yaml(self):
|
35 |
+
"""Convert object to yaml dict and return."""
|
36 |
+
yaml_dict = {}
|
37 |
+
for key, value in self.__dict__.items():
|
38 |
+
if isinstance(value, AttrDict):
|
39 |
+
yaml_dict[key] = value.yaml()
|
40 |
+
elif isinstance(value, list):
|
41 |
+
if isinstance(value[0], AttrDict):
|
42 |
+
new_l = []
|
43 |
+
for item in value:
|
44 |
+
new_l.append(item.yaml())
|
45 |
+
yaml_dict[key] = new_l
|
46 |
+
else:
|
47 |
+
yaml_dict[key] = value
|
48 |
+
else:
|
49 |
+
yaml_dict[key] = value
|
50 |
+
return yaml_dict
|
51 |
+
|
52 |
+
def __repr__(self):
|
53 |
+
"""Print all variables."""
|
54 |
+
ret_str = []
|
55 |
+
for key, value in self.__dict__.items():
|
56 |
+
if isinstance(value, AttrDict):
|
57 |
+
ret_str.append('{}:'.format(key))
|
58 |
+
child_ret_str = value.__repr__().split('\n')
|
59 |
+
for item in child_ret_str:
|
60 |
+
ret_str.append(' ' + item)
|
61 |
+
elif isinstance(value, list):
|
62 |
+
if isinstance(value[0], AttrDict):
|
63 |
+
ret_str.append('{}:'.format(key))
|
64 |
+
for item in value:
|
65 |
+
# Treat as AttrDict above.
|
66 |
+
child_ret_str = item.__repr__().split('\n')
|
67 |
+
for item in child_ret_str:
|
68 |
+
ret_str.append(' ' + item)
|
69 |
+
else:
|
70 |
+
ret_str.append('{}: {}'.format(key, value))
|
71 |
+
else:
|
72 |
+
ret_str.append('{}: {}'.format(key, value))
|
73 |
+
return '\n'.join(ret_str)
|
74 |
+
|
75 |
+
|
76 |
+
class Config(AttrDict):
|
77 |
+
r"""Configuration class. This should include every human specifiable
|
78 |
+
hyperparameter value for your training."""
|
79 |
+
|
80 |
+
def __init__(self, filename=None, verbose=False):
|
81 |
+
super(Config, self).__init__()
|
82 |
+
self.source_filename = filename
|
83 |
+
# Set default parameters.
|
84 |
+
# Logging.
|
85 |
+
large_number = 1000000000
|
86 |
+
self.snapshot_save_iter = large_number
|
87 |
+
self.snapshot_save_epoch = large_number
|
88 |
+
self.metrics_iter = None
|
89 |
+
self.metrics_epoch = None
|
90 |
+
self.snapshot_save_start_iter = 0
|
91 |
+
self.snapshot_save_start_epoch = 0
|
92 |
+
self.image_save_iter = large_number
|
93 |
+
self.image_display_iter = large_number
|
94 |
+
self.max_epoch = large_number
|
95 |
+
self.max_iter = large_number
|
96 |
+
self.logging_iter = 100
|
97 |
+
self.speed_benchmark = False
|
98 |
+
|
99 |
+
# Trainer.
|
100 |
+
self.trainer = AttrDict(
|
101 |
+
model_average_config=AttrDict(enabled=False,
|
102 |
+
beta=0.9999,
|
103 |
+
start_iteration=1000,
|
104 |
+
num_batch_norm_estimation_iterations=30,
|
105 |
+
remove_sn=True),
|
106 |
+
# model_average=False,
|
107 |
+
# model_average_beta=0.9999,
|
108 |
+
# model_average_start_iteration=1000,
|
109 |
+
# model_average_batch_norm_estimation_iteration=30,
|
110 |
+
# model_average_remove_sn=True,
|
111 |
+
image_to_tensorboard=False,
|
112 |
+
hparam_to_tensorboard=False,
|
113 |
+
distributed_data_parallel='pytorch',
|
114 |
+
distributed_data_parallel_params=AttrDict(
|
115 |
+
find_unused_parameters=False),
|
116 |
+
delay_allreduce=True,
|
117 |
+
gan_relativistic=False,
|
118 |
+
gen_step=1,
|
119 |
+
dis_step=1,
|
120 |
+
gan_decay_k=1.,
|
121 |
+
gan_min_k=1.,
|
122 |
+
gan_separate_topk=False,
|
123 |
+
aug_policy='',
|
124 |
+
channels_last=False,
|
125 |
+
strict_resume=True,
|
126 |
+
amp_gp=False,
|
127 |
+
amp_config=AttrDict(init_scale=65536.0,
|
128 |
+
growth_factor=2.0,
|
129 |
+
backoff_factor=0.5,
|
130 |
+
growth_interval=2000,
|
131 |
+
enabled=False))
|
132 |
+
|
133 |
+
# Networks.
|
134 |
+
self.gen = AttrDict(type='imaginaire.generators.dummy')
|
135 |
+
self.dis = AttrDict(type='imaginaire.discriminators.dummy')
|
136 |
+
|
137 |
+
# Optimizers.
|
138 |
+
self.gen_opt = AttrDict(type='adam',
|
139 |
+
fused_opt=False,
|
140 |
+
lr=0.0001,
|
141 |
+
adam_beta1=0.0,
|
142 |
+
adam_beta2=0.999,
|
143 |
+
eps=1e-8,
|
144 |
+
lr_policy=AttrDict(iteration_mode=False,
|
145 |
+
type='step',
|
146 |
+
step_size=large_number,
|
147 |
+
gamma=1))
|
148 |
+
self.dis_opt = AttrDict(type='adam',
|
149 |
+
fused_opt=False,
|
150 |
+
lr=0.0001,
|
151 |
+
adam_beta1=0.0,
|
152 |
+
adam_beta2=0.999,
|
153 |
+
eps=1e-8,
|
154 |
+
lr_policy=AttrDict(iteration_mode=False,
|
155 |
+
type='step',
|
156 |
+
step_size=large_number,
|
157 |
+
gamma=1))
|
158 |
+
# Data.
|
159 |
+
self.data = AttrDict(name='dummy',
|
160 |
+
type='imaginaire.datasets.images',
|
161 |
+
num_workers=0)
|
162 |
+
self.test_data = AttrDict(name='dummy',
|
163 |
+
type='imaginaire.datasets.images',
|
164 |
+
num_workers=0,
|
165 |
+
test=AttrDict(is_lmdb=False,
|
166 |
+
roots='',
|
167 |
+
batch_size=1))
|
168 |
+
|
169 |
+
|
170 |
+
# Cudnn.
|
171 |
+
self.cudnn = AttrDict(deterministic=False,
|
172 |
+
benchmark=True)
|
173 |
+
|
174 |
+
# Others.
|
175 |
+
self.pretrained_weight = ''
|
176 |
+
self.inference_args = AttrDict()
|
177 |
+
|
178 |
+
# Update with given configurations.
|
179 |
+
assert os.path.exists(filename), 'File {} does not exist.'.format(filename)
|
180 |
+
loader = yaml.SafeLoader
|
181 |
+
loader.add_implicit_resolver(
|
182 |
+
u'tag:yaml.org,2002:float',
|
183 |
+
re.compile(u'''^(?:
|
184 |
+
[-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
|
185 |
+
|[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
|
186 |
+
|\\.[0-9_]+(?:[eE][-+][0-9]+)?
|
187 |
+
|[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
|
188 |
+
|[-+]?\\.(?:inf|Inf|INF)
|
189 |
+
|\\.(?:nan|NaN|NAN))$''', re.X),
|
190 |
+
list(u'-+0123456789.'))
|
191 |
+
try:
|
192 |
+
with open(filename, 'r') as f:
|
193 |
+
cfg_dict = yaml.load(f, Loader=loader)
|
194 |
+
except EnvironmentError:
|
195 |
+
print('Please check the config file "{}".'.format(filename))
|
196 |
+
recursive_update(self, cfg_dict)
|
197 |
+
|
198 |
+
# Put common opts in both gen and dis.
|
199 |
+
if 'common' in cfg_dict:
|
200 |
+
self.common = AttrDict(**cfg_dict['common'])
|
201 |
+
self.gen.common = self.common
|
202 |
+
self.dis.common = self.common
|
203 |
+
|
204 |
+
if verbose:
|
205 |
+
print(' imaginaire config '.center(80, '-'))
|
206 |
+
print(self.__repr__())
|
207 |
+
print(''.center(80, '-'))
|
208 |
+
|
209 |
+
|
210 |
+
def rsetattr(obj, attr, val):
|
211 |
+
"""Recursively find object and set value"""
|
212 |
+
pre, _, post = attr.rpartition('.')
|
213 |
+
return setattr(rgetattr(obj, pre) if pre else obj, post, val)
|
214 |
+
|
215 |
+
|
216 |
+
def rgetattr(obj, attr, *args):
|
217 |
+
"""Recursively find object and return value"""
|
218 |
+
|
219 |
+
def _getattr(obj, attr):
|
220 |
+
r"""Get attribute."""
|
221 |
+
return getattr(obj, attr, *args)
|
222 |
+
|
223 |
+
return functools.reduce(_getattr, [obj] + attr.split('.'))
|
224 |
+
|
225 |
+
|
226 |
+
def recursive_update(d, u):
|
227 |
+
"""Recursively update AttrDict d with AttrDict u"""
|
228 |
+
for key, value in u.items():
|
229 |
+
if isinstance(value, collections.abc.Mapping):
|
230 |
+
d.__dict__[key] = recursive_update(d.get(key, AttrDict({})), value)
|
231 |
+
elif isinstance(value, (list, tuple)):
|
232 |
+
if isinstance(value[0], dict):
|
233 |
+
d.__dict__[key] = [AttrDict(item) for item in value]
|
234 |
+
else:
|
235 |
+
d.__dict__[key] = value
|
236 |
+
else:
|
237 |
+
d.__dict__[key] = value
|
238 |
+
return d
|
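A minimal usage sketch for the Config/AttrDict machinery above, assuming the imaginaire package from this upload is importable; the YAML keys are illustrative and only need to overlap the built-in defaults to show how recursive_update merges them.

    import os
    import tempfile
    from imaginaire.config import Config

    yaml_text = "max_iter: 100000\ngen_opt:\n  lr: 0.0002\n"
    with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
        f.write(yaml_text)
        path = f.name

    cfg = Config(path)
    print(cfg.max_iter)      # 100000, overrides the large_number default
    print(cfg.gen_opt.lr)    # 0.0002, nested keys become attributes
    print(cfg.logging_iter)  # 100, untouched defaults survive the merge
    os.remove(path)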
imaginaire/discriminators/__init__.py
ADDED
File without changes
|
imaginaire/discriminators/gancraft.py
ADDED
@@ -0,0 +1,278 @@
1 |
+
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# This work is made available under the Nvidia Source Code License-NC.
|
4 |
+
# To view a copy of this license, check out LICENSE.md
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
import functools
|
10 |
+
from imaginaire.layers import Conv2dBlock
|
11 |
+
|
12 |
+
from imaginaire.utils.data import get_paired_input_label_channel_number, get_paired_input_image_channel_number
|
13 |
+
from imaginaire.utils.distributed import master_only_print as print
|
14 |
+
|
15 |
+
|
16 |
+
class Discriminator(nn.Module):
|
17 |
+
r"""Multi-resolution patch discriminator. Based on FPSE discriminator but with N+1 labels.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
dis_cfg (obj): Discriminator definition part of the yaml config file.
|
21 |
+
data_cfg (obj): Data definition part of the yaml config file.
|
22 |
+
"""
|
23 |
+
|
24 |
+
def __init__(self, dis_cfg, data_cfg):
|
25 |
+
super(Discriminator, self).__init__()
|
26 |
+
# We assume the first datum is the ground truth image.
|
27 |
+
image_channels = get_paired_input_image_channel_number(data_cfg)
|
28 |
+
# Calculate number of channels in the input label.
|
29 |
+
num_labels = get_paired_input_label_channel_number(data_cfg)
|
30 |
+
|
31 |
+
self.use_label = getattr(dis_cfg, 'use_label', True)
|
32 |
+
# Override number of input channels
|
33 |
+
if hasattr(dis_cfg, 'image_channels'):
|
34 |
+
image_channels = dis_cfg.image_channels
|
35 |
+
if hasattr(dis_cfg, 'num_labels'):
|
36 |
+
num_labels = dis_cfg.num_labels
|
37 |
+
else:
|
38 |
+
# We assume the first datum is the ground truth image.
|
39 |
+
image_channels = get_paired_input_image_channel_number(data_cfg)
|
40 |
+
# Calculate number of channels in the input label.
|
41 |
+
num_labels = get_paired_input_label_channel_number(data_cfg)
|
42 |
+
|
43 |
+
if not self.use_label:
|
44 |
+
num_labels = 2 # ignore + true
|
45 |
+
|
46 |
+
# Build the discriminator.
|
47 |
+
num_filters = getattr(dis_cfg, 'num_filters', 128)
|
48 |
+
weight_norm_type = getattr(dis_cfg, 'weight_norm_type', 'spectral')
|
49 |
+
|
50 |
+
fpse_kernel_size = getattr(dis_cfg, 'fpse_kernel_size', 3)
|
51 |
+
fpse_activation_norm_type = getattr(dis_cfg,
|
52 |
+
'fpse_activation_norm_type',
|
53 |
+
'none')
|
54 |
+
do_multiscale = getattr(dis_cfg, 'do_multiscale', False)
|
55 |
+
smooth_resample = getattr(dis_cfg, 'smooth_resample', False)
|
56 |
+
no_label_except_largest_scale = getattr(dis_cfg, 'no_label_except_largest_scale', False)
|
57 |
+
|
58 |
+
self.fpse_discriminator = FPSEDiscriminator(
|
59 |
+
image_channels,
|
60 |
+
num_labels,
|
61 |
+
num_filters,
|
62 |
+
fpse_kernel_size,
|
63 |
+
weight_norm_type,
|
64 |
+
fpse_activation_norm_type,
|
65 |
+
do_multiscale,
|
66 |
+
smooth_resample,
|
67 |
+
no_label_except_largest_scale)
|
68 |
+
|
69 |
+
def _single_forward(self, input_label, input_image, weights):
|
70 |
+
output_list, features_list = self.fpse_discriminator(input_image, input_label, weights)
|
71 |
+
return output_list, [features_list]
|
72 |
+
|
73 |
+
def forward(self, data, net_G_output, weights=None, incl_real=False, incl_pseudo_real=False):
|
74 |
+
r"""GANcraft discriminator forward.
|
75 |
+
|
76 |
+
Args:
|
77 |
+
data (dict):
|
78 |
+
- data (N x C1 x H x W tensor) : Ground truth images.
|
79 |
+
- label (N x C2 x H x W tensor) : Semantic representations.
|
80 |
+
- z (N x style_dims tensor): Gaussian random noise.
|
81 |
+
net_G_output (dict):
|
82 |
+
- fake_images (N x C1 x H x W tensor) : Fake images.
|
83 |
+
Returns:
|
84 |
+
output_x (dict):
|
85 |
+
- real_outputs (list): list of output tensors produced by
|
86 |
+
individual patch discriminators for real images.
|
87 |
+
- real_features (list): list of lists of features produced by
|
88 |
+
individual patch discriminators for real images.
|
89 |
+
- fake_outputs (list): list of output tensors produced by
|
90 |
+
individual patch discriminators for fake images.
|
91 |
+
- fake_features (list): list of lists of features produced by
|
92 |
+
individual patch discriminators for fake images.
|
93 |
+
"""
|
94 |
+
output_x = dict()
|
95 |
+
|
96 |
+
# Fake.
|
97 |
+
fake_images = net_G_output['fake_images']
|
98 |
+
if self.use_label:
|
99 |
+
fake_labels = data['fake_masks']
|
100 |
+
else:
|
101 |
+
fake_labels = torch.zeros([fake_images.size(0), 2, fake_images.size(
|
102 |
+
2), fake_images.size(3)], device=fake_images.device, dtype=fake_images.dtype)
|
103 |
+
fake_labels[:, 1, :, :] = 1
|
104 |
+
output_x['fake_outputs'], output_x['fake_features'] = \
|
105 |
+
self._single_forward(fake_labels, fake_images, None)
|
106 |
+
|
107 |
+
# Real.
|
108 |
+
if incl_real:
|
109 |
+
real_images = data['images']
|
110 |
+
if self.use_label:
|
111 |
+
real_labels = data['real_masks']
|
112 |
+
else:
|
113 |
+
real_labels = torch.zeros([real_images.size(0), 2, real_images.size(
|
114 |
+
2), real_images.size(3)], device=real_images.device, dtype=real_images.dtype)
|
115 |
+
real_labels[:, 1, :, :] = 1
|
116 |
+
output_x['real_outputs'], output_x['real_features'] = \
|
117 |
+
self._single_forward(real_labels, real_images, None)
|
118 |
+
|
119 |
+
# pseudo-Real.
|
120 |
+
if incl_pseudo_real:
|
121 |
+
preal_images = data['pseudo_real_img']
|
122 |
+
preal_labels = data['fake_masks']
|
123 |
+
if not self.use_label:
|
124 |
+
preal_labels = torch.zeros([preal_images.size(0), 2, preal_images.size(
|
125 |
+
2), preal_images.size(3)], device=preal_images.device, dtype=preal_images.dtype)
|
126 |
+
preal_labels[:, 1, :, :] = 1
|
127 |
+
output_x['pseudo_real_outputs'], output_x['pseudo_real_features'] = \
|
128 |
+
self._single_forward(preal_labels, preal_images, None)
|
129 |
+
|
130 |
+
return output_x
|
131 |
+
|
132 |
+
|
133 |
+
class FPSEDiscriminator(nn.Module):
|
134 |
+
def __init__(self,
|
135 |
+
num_input_channels,
|
136 |
+
num_labels,
|
137 |
+
num_filters,
|
138 |
+
kernel_size,
|
139 |
+
weight_norm_type,
|
140 |
+
activation_norm_type,
|
141 |
+
do_multiscale,
|
142 |
+
smooth_resample,
|
143 |
+
no_label_except_largest_scale):
|
144 |
+
super().__init__()
|
145 |
+
|
146 |
+
self.do_multiscale = do_multiscale
|
147 |
+
self.no_label_except_largest_scale = no_label_except_largest_scale
|
148 |
+
|
149 |
+
padding = int(np.ceil((kernel_size - 1.0) / 2))
|
150 |
+
nonlinearity = 'leakyrelu'
|
151 |
+
stride1_conv2d_block = \
|
152 |
+
functools.partial(Conv2dBlock,
|
153 |
+
kernel_size=kernel_size,
|
154 |
+
stride=1,
|
155 |
+
padding=padding,
|
156 |
+
weight_norm_type=weight_norm_type,
|
157 |
+
activation_norm_type=activation_norm_type,
|
158 |
+
nonlinearity=nonlinearity,
|
159 |
+
# inplace_nonlinearity=True,
|
160 |
+
order='CNA')
|
161 |
+
down_conv2d_block = \
|
162 |
+
functools.partial(Conv2dBlock,
|
163 |
+
kernel_size=kernel_size,
|
164 |
+
stride=2,
|
165 |
+
padding=padding,
|
166 |
+
weight_norm_type=weight_norm_type,
|
167 |
+
activation_norm_type=activation_norm_type,
|
168 |
+
nonlinearity=nonlinearity,
|
169 |
+
# inplace_nonlinearity=True,
|
170 |
+
order='CNA')
|
171 |
+
latent_conv2d_block = \
|
172 |
+
functools.partial(Conv2dBlock,
|
173 |
+
kernel_size=1,
|
174 |
+
stride=1,
|
175 |
+
weight_norm_type=weight_norm_type,
|
176 |
+
activation_norm_type=activation_norm_type,
|
177 |
+
nonlinearity=nonlinearity,
|
178 |
+
# inplace_nonlinearity=True,
|
179 |
+
order='CNA')
|
180 |
+
# bottom-up pathway
|
181 |
+
self.enc1 = down_conv2d_block(num_input_channels, num_filters) # 3
|
182 |
+
self.enc2 = down_conv2d_block(1 * num_filters, 2 * num_filters) # 7
|
183 |
+
self.enc3 = down_conv2d_block(2 * num_filters, 4 * num_filters) # 15
|
184 |
+
self.enc4 = down_conv2d_block(4 * num_filters, 8 * num_filters) # 31
|
185 |
+
self.enc5 = down_conv2d_block(8 * num_filters, 8 * num_filters) # 63
|
186 |
+
|
187 |
+
# top-down pathway
|
188 |
+
# self.lat1 = latent_conv2d_block(num_filters, 2 * num_filters) # Zekun
|
189 |
+
self.lat2 = latent_conv2d_block(2 * num_filters, 4 * num_filters)
|
190 |
+
self.lat3 = latent_conv2d_block(4 * num_filters, 4 * num_filters)
|
191 |
+
self.lat4 = latent_conv2d_block(8 * num_filters, 4 * num_filters)
|
192 |
+
self.lat5 = latent_conv2d_block(8 * num_filters, 4 * num_filters)
|
193 |
+
|
194 |
+
# upsampling
|
195 |
+
self.upsample2x = nn.Upsample(scale_factor=2, mode='bilinear',
|
196 |
+
align_corners=False)
|
197 |
+
|
198 |
+
# final layers
|
199 |
+
self.final2 = stride1_conv2d_block(4 * num_filters, 2 * num_filters)
|
200 |
+
self.output = Conv2dBlock(num_filters * 2, num_labels+1, kernel_size=1)
|
201 |
+
|
202 |
+
if self.do_multiscale:
|
203 |
+
self.final3 = stride1_conv2d_block(4 * num_filters, 2 * num_filters)
|
204 |
+
self.final4 = stride1_conv2d_block(4 * num_filters, 2 * num_filters)
|
205 |
+
if self.no_label_except_largest_scale:
|
206 |
+
self.output3 = Conv2dBlock(num_filters * 2, 2, kernel_size=1)
|
207 |
+
self.output4 = Conv2dBlock(num_filters * 2, 2, kernel_size=1)
|
208 |
+
else:
|
209 |
+
self.output3 = Conv2dBlock(num_filters * 2, num_labels+1, kernel_size=1)
|
210 |
+
self.output4 = Conv2dBlock(num_filters * 2, num_labels+1, kernel_size=1)
|
211 |
+
|
212 |
+
self.interpolator = functools.partial(F.interpolate, mode='nearest')
|
213 |
+
if smooth_resample:
|
214 |
+
self.interpolator = self.smooth_interp
|
215 |
+
|
216 |
+
@staticmethod
|
217 |
+
def smooth_interp(x, size):
|
218 |
+
r"""Smooth interpolation of segmentation maps.
|
219 |
+
|
220 |
+
Args:
|
221 |
+
x (4D tensor): Segmentation maps.
|
222 |
+
size(2D list): Target size (H, W).
|
223 |
+
"""
|
224 |
+
x = F.interpolate(x, size=size, mode='area')
|
225 |
+
onehot_idx = torch.argmax(x, dim=-3, keepdim=True)
|
226 |
+
x.fill_(0.0)
|
227 |
+
x.scatter_(1, onehot_idx, 1.0)
|
228 |
+
return x
|
229 |
+
|
230 |
+
# Weights: [N C]
|
231 |
+
def forward(self, images, segmaps, weights=None):
|
232 |
+
# Assume images 256x256
|
233 |
+
# bottom-up pathway
|
234 |
+
feat11 = self.enc1(images) # 128
|
235 |
+
feat12 = self.enc2(feat11) # 64
|
236 |
+
feat13 = self.enc3(feat12) # 32
|
237 |
+
feat14 = self.enc4(feat13) # 16
|
238 |
+
feat15 = self.enc5(feat14) # 8
|
239 |
+
# top-down pathway and lateral connections
|
240 |
+
feat25 = self.lat5(feat15) # 8
|
241 |
+
feat24 = self.upsample2x(feat25) + self.lat4(feat14) # 16
|
242 |
+
feat23 = self.upsample2x(feat24) + self.lat3(feat13) # 32
|
243 |
+
feat22 = self.upsample2x(feat23) + self.lat2(feat12) # 64
|
244 |
+
|
245 |
+
# final prediction layers
|
246 |
+
feat32 = self.final2(feat22)
|
247 |
+
|
248 |
+
results = []
|
249 |
+
label_map = self.interpolator(segmaps, size=feat32.size()[2:])
|
250 |
+
pred2 = self.output(feat32) # N, num_labels+1, H//4, W//4
|
251 |
+
|
252 |
+
features = [feat11, feat12, feat13, feat14, feat15, feat25, feat24, feat23, feat22]
|
253 |
+
if weights is not None:
|
254 |
+
label_map = label_map * weights[..., None, None]
|
255 |
+
results.append({'pred': pred2, 'label': label_map})
|
256 |
+
|
257 |
+
if self.do_multiscale:
|
258 |
+
feat33 = self.final3(feat23)
|
259 |
+
pred3 = self.output3(feat33)
|
260 |
+
|
261 |
+
feat34 = self.final4(feat24)
|
262 |
+
pred4 = self.output4(feat34)
|
263 |
+
|
264 |
+
if self.no_label_except_largest_scale:
|
265 |
+
label_map3 = torch.ones([pred3.size(0), 1, pred3.size(2), pred3.size(3)], device=pred3.device)
|
266 |
+
label_map4 = torch.ones([pred4.size(0), 1, pred4.size(2), pred4.size(3)], device=pred4.device)
|
267 |
+
else:
|
268 |
+
label_map3 = self.interpolator(segmaps, size=pred3.size()[2:])
|
269 |
+
label_map4 = self.interpolator(segmaps, size=pred4.size()[2:])
|
270 |
+
|
271 |
+
if weights is not None:
|
272 |
+
label_map3 = label_map3 * weights[..., None, None]
|
273 |
+
label_map4 = label_map4 * weights[..., None, None]
|
274 |
+
|
275 |
+
results.append({'pred': pred3, 'label': label_map3})
|
276 |
+
results.append({'pred': pred4, 'label': label_map4})
|
277 |
+
|
278 |
+
return results, features
|
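The smooth_interp trick above (area-downsample a one-hot segmentation map, then re-binarize with a per-pixel argmax) is easy to verify in isolation. A self-contained sketch follows; it uses zeros_like instead of the in-place fill_ of the original, which does not change the result.

    import torch
    import torch.nn.functional as F

    def smooth_interp(x, size):
        x = F.interpolate(x, size=size, mode="area")
        onehot_idx = torch.argmax(x, dim=-3, keepdim=True)
        x = torch.zeros_like(x)
        x.scatter_(1, onehot_idx, 1.0)
        return x

    segmap = F.one_hot(torch.randint(0, 5, (1, 64, 64)), num_classes=5)
    segmap = segmap.permute(0, 3, 1, 2).float()      # N x C x H x W one-hot
    small = smooth_interp(segmap, size=(16, 16))
    print(small.shape, small.sum(dim=1).unique())    # every output pixel stays one-hot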
imaginaire/generators/__init__.py
ADDED
@@ -0,0 +1,4 @@
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
imaginaire/generators/gancraft_base.py
ADDED
@@ -0,0 +1,603 @@
1 |
+
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# This work is made available under the Nvidia Source Code License-NC.
|
4 |
+
# To view a copy of this license, check out LICENSE.md
|
5 |
+
import functools
|
6 |
+
import re
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
|
13 |
+
from imaginaire.layers import Conv2dBlock, LinearBlock
|
14 |
+
from imaginaire.model_utils.layers import AffineMod, ModLinear
|
15 |
+
import imaginaire.model_utils.gancraft.mc_utils as mc_utils
|
16 |
+
import imaginaire.model_utils.gancraft.voxlib as voxlib
|
17 |
+
from imaginaire.utils.distributed import master_only_print as print
|
18 |
+
|
19 |
+
|
20 |
+
class RenderMLP(nn.Module):
|
21 |
+
r""" MLP with affine modulation."""
|
22 |
+
|
23 |
+
def __init__(self, in_channels, style_dim, viewdir_dim, mask_dim=680,
|
24 |
+
out_channels_s=1, out_channels_c=3, hidden_channels=256,
|
25 |
+
use_seg=True):
|
26 |
+
super(RenderMLP, self).__init__()
|
27 |
+
|
28 |
+
self.use_seg = use_seg
|
29 |
+
if self.use_seg:
|
30 |
+
self.fc_m_a = nn.Linear(mask_dim, hidden_channels, bias=False)
|
31 |
+
|
32 |
+
self.fc_viewdir = None
|
33 |
+
if viewdir_dim > 0:
|
34 |
+
self.fc_viewdir = nn.Linear(viewdir_dim, hidden_channels, bias=False)
|
35 |
+
|
36 |
+
self.fc_1 = nn.Linear(in_channels, hidden_channels)
|
37 |
+
|
38 |
+
self.fc_2 = ModLinear(hidden_channels, hidden_channels, style_dim, bias=False, mod_bias=True, output_mode=True)
|
39 |
+
self.fc_3 = ModLinear(hidden_channels, hidden_channels, style_dim, bias=False, mod_bias=True, output_mode=True)
|
40 |
+
self.fc_4 = ModLinear(hidden_channels, hidden_channels, style_dim, bias=False, mod_bias=True, output_mode=True)
|
41 |
+
|
42 |
+
self.fc_sigma = nn.Linear(hidden_channels, out_channels_s)
|
43 |
+
|
44 |
+
if viewdir_dim > 0:
|
45 |
+
self.fc_5 = nn.Linear(hidden_channels, hidden_channels, bias=False)
|
46 |
+
self.mod_5 = AffineMod(hidden_channels, style_dim, mod_bias=True)
|
47 |
+
else:
|
48 |
+
self.fc_5 = ModLinear(hidden_channels, hidden_channels, style_dim,
|
49 |
+
bias=False, mod_bias=True, output_mode=True)
|
50 |
+
self.fc_6 = ModLinear(hidden_channels, hidden_channels, style_dim, bias=False, mod_bias=True, output_mode=True)
|
51 |
+
self.fc_out_c = nn.Linear(hidden_channels, out_channels_c)
|
52 |
+
|
53 |
+
self.act = nn.LeakyReLU(negative_slope=0.2)
|
54 |
+
|
55 |
+
def forward(self, x, raydir, z, m):
|
56 |
+
r""" Forward network
|
57 |
+
|
58 |
+
Args:
|
59 |
+
x (N x H x W x M x in_channels tensor): Projected features.
|
60 |
+
raydir (N x H x W x 1 x viewdir_dim tensor): Ray directions.
|
61 |
+
z (N x style_dim tensor): Style codes.
|
62 |
+
m (N x H x W x M x mask_dim tensor): One-hot segmentation maps.
|
63 |
+
"""
|
64 |
+
b, h, w, n, _ = x.size()
|
65 |
+
z = z[:, None, None, None, :]
|
66 |
+
|
67 |
+
f = self.fc_1(x)
|
68 |
+
if self.use_seg:
|
69 |
+
f = f + self.fc_m_a(m)
|
70 |
+
# Common MLP
|
71 |
+
f = self.act(f)
|
72 |
+
f = self.act(self.fc_2(f, z))
|
73 |
+
f = self.act(self.fc_3(f, z))
|
74 |
+
f = self.act(self.fc_4(f, z))
|
75 |
+
|
76 |
+
# Sigma MLP
|
77 |
+
sigma = self.fc_sigma(f)
|
78 |
+
|
79 |
+
# Color MLP
|
80 |
+
if self.fc_viewdir is not None:
|
81 |
+
f = self.fc_5(f)
|
82 |
+
f = f + self.fc_viewdir(raydir)
|
83 |
+
f = self.act(self.mod_5(f, z))
|
84 |
+
else:
|
85 |
+
f = self.act(self.fc_5(f, z))
|
86 |
+
f = self.act(self.fc_6(f, z))
|
87 |
+
c = self.fc_out_c(f)
|
88 |
+
return sigma, c
|
89 |
+
|
90 |
+
|
91 |
+
class StyleMLP(nn.Module):
|
92 |
+
r"""MLP converting style code to intermediate style representation."""
|
93 |
+
|
94 |
+
def __init__(self, style_dim, out_dim, hidden_channels=256, leaky_relu=True, num_layers=5, normalize_input=True,
|
95 |
+
output_act=True):
|
96 |
+
super(StyleMLP, self).__init__()
|
97 |
+
|
98 |
+
self.normalize_input = normalize_input
|
99 |
+
self.output_act = output_act
|
100 |
+
fc_layers = []
|
101 |
+
fc_layers.append(nn.Linear(style_dim, hidden_channels, bias=True))
|
102 |
+
for i in range(num_layers-1):
|
103 |
+
fc_layers.append(nn.Linear(hidden_channels, hidden_channels, bias=True))
|
104 |
+
self.fc_layers = nn.ModuleList(fc_layers)
|
105 |
+
|
106 |
+
self.fc_out = nn.Linear(hidden_channels, out_dim, bias=True)
|
107 |
+
|
108 |
+
if leaky_relu:
|
109 |
+
self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
110 |
+
else:
|
111 |
+
self.act = functools.partial(F.relu, inplace=True)
|
112 |
+
|
113 |
+
def forward(self, z):
|
114 |
+
r""" Forward network
|
115 |
+
|
116 |
+
Args:
|
117 |
+
z (N x style_dim tensor): Style codes.
|
118 |
+
"""
|
119 |
+
if self.normalize_input:
|
120 |
+
z = F.normalize(z, p=2, dim=-1)
|
121 |
+
for fc_layer in self.fc_layers:
|
122 |
+
z = self.act(fc_layer(z))
|
123 |
+
z = self.fc_out(z)
|
124 |
+
if self.output_act:
|
125 |
+
z = self.act(z)
|
126 |
+
return z
|
127 |
+
|
128 |
+
|
129 |
+
class SKYMLP(nn.Module):
|
130 |
+
r"""MLP converting ray directions to sky features."""
|
131 |
+
|
132 |
+
def __init__(self, in_channels, style_dim, out_channels_c=3,
|
133 |
+
hidden_channels=256, leaky_relu=True):
|
134 |
+
super(SKYMLP, self).__init__()
|
135 |
+
self.fc_z_a = nn.Linear(style_dim, hidden_channels, bias=False)
|
136 |
+
|
137 |
+
self.fc1 = nn.Linear(in_channels, hidden_channels)
|
138 |
+
self.fc2 = nn.Linear(hidden_channels, hidden_channels)
|
139 |
+
self.fc3 = nn.Linear(hidden_channels, hidden_channels)
|
140 |
+
self.fc4 = nn.Linear(hidden_channels, hidden_channels)
|
141 |
+
self.fc5 = nn.Linear(hidden_channels, hidden_channels)
|
142 |
+
|
143 |
+
self.fc_out_c = nn.Linear(hidden_channels, out_channels_c)
|
144 |
+
|
145 |
+
if leaky_relu:
|
146 |
+
self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
147 |
+
else:
|
148 |
+
self.act = functools.partial(F.relu, inplace=True)
|
149 |
+
|
150 |
+
def forward(self, x, z):
|
151 |
+
r"""Forward network
|
152 |
+
|
153 |
+
Args:
|
154 |
+
x (... x in_channels tensor): Ray direction embeddings.
|
155 |
+
z (... x style_dim tensor): Style codes.
|
156 |
+
"""
|
157 |
+
|
158 |
+
z = self.fc_z_a(z)
|
159 |
+
while z.dim() < x.dim():
|
160 |
+
z = z.unsqueeze(1)
|
161 |
+
|
162 |
+
y = self.act(self.fc1(x) + z)
|
163 |
+
y = self.act(self.fc2(y))
|
164 |
+
y = self.act(self.fc3(y))
|
165 |
+
y = self.act(self.fc4(y))
|
166 |
+
y = self.act(self.fc5(y))
|
167 |
+
c = self.fc_out_c(y)
|
168 |
+
|
169 |
+
return c
|
170 |
+
|
171 |
+
|
172 |
+
class RenderCNN(nn.Module):
|
173 |
+
r"""CNN converting intermediate feature map to final image."""
|
174 |
+
|
175 |
+
def __init__(self, in_channels, style_dim, hidden_channels=256,
|
176 |
+
leaky_relu=True):
|
177 |
+
super(RenderCNN, self).__init__()
|
178 |
+
self.fc_z_cond = nn.Linear(style_dim, 2 * 2 * hidden_channels)
|
179 |
+
|
180 |
+
self.conv1 = nn.Conv2d(in_channels, hidden_channels, 1, stride=1, padding=0)
|
181 |
+
self.conv2a = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1)
|
182 |
+
self.conv2b = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1, bias=False)
|
183 |
+
|
184 |
+
self.conv3a = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1)
|
185 |
+
self.conv3b = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1, bias=False)
|
186 |
+
|
187 |
+
self.conv4a = nn.Conv2d(hidden_channels, hidden_channels, 1, stride=1, padding=0)
|
188 |
+
self.conv4b = nn.Conv2d(hidden_channels, hidden_channels, 1, stride=1, padding=0)
|
189 |
+
|
190 |
+
self.conv4 = nn.Conv2d(hidden_channels, 3, 1, stride=1, padding=0)
|
191 |
+
|
192 |
+
if leaky_relu:
|
193 |
+
self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
194 |
+
else:
|
195 |
+
self.act = functools.partial(F.relu, inplace=True)
|
196 |
+
|
197 |
+
def modulate(self, x, w, b):
|
198 |
+
w = w[..., None, None]
|
199 |
+
b = b[..., None, None]
|
200 |
+
return x * (w+1) + b
|
201 |
+
|
202 |
+
def forward(self, x, z):
|
203 |
+
r"""Forward network.
|
204 |
+
|
205 |
+
Args:
|
206 |
+
x (N x in_channels x H x W tensor): Intermediate feature map
|
207 |
+
z (N x style_dim tensor): Style codes.
|
208 |
+
"""
|
209 |
+
z = self.fc_z_cond(z)
|
210 |
+
adapt = torch.chunk(z, 2 * 2, dim=-1)
|
211 |
+
|
212 |
+
y = self.act(self.conv1(x))
|
213 |
+
|
214 |
+
y = y + self.conv2b(self.act(self.conv2a(y)))
|
215 |
+
y = self.act(self.modulate(y, adapt[0], adapt[1]))
|
216 |
+
|
217 |
+
y = y + self.conv3b(self.act(self.conv3a(y)))
|
218 |
+
y = self.act(self.modulate(y, adapt[2], adapt[3]))
|
219 |
+
|
220 |
+
y = y + self.conv4b(self.act(self.conv4a(y)))
|
221 |
+
y = self.act(y)
|
222 |
+
|
223 |
+
y = self.conv4(y)
|
224 |
+
|
225 |
+
return y
|
226 |
+
|
227 |
+
|
228 |
+
class StyleEncoder(nn.Module):
|
229 |
+
r"""Style Encoder constructor.
|
230 |
+
|
231 |
+
Args:
|
232 |
+
style_enc_cfg (obj): Style encoder definition file.
|
233 |
+
"""
|
234 |
+
|
235 |
+
def __init__(self, style_enc_cfg):
|
236 |
+
super(StyleEncoder, self).__init__()
|
237 |
+
input_image_channels = style_enc_cfg.input_image_channels
|
238 |
+
num_filters = style_enc_cfg.num_filters
|
239 |
+
kernel_size = style_enc_cfg.kernel_size
|
240 |
+
padding = int(np.ceil((kernel_size - 1.0) / 2))
|
241 |
+
style_dims = style_enc_cfg.style_dims
|
242 |
+
weight_norm_type = style_enc_cfg.weight_norm_type
|
243 |
+
self.no_vae = getattr(style_enc_cfg, 'no_vae', False)
|
244 |
+
activation_norm_type = 'none'
|
245 |
+
nonlinearity = 'leakyrelu'
|
246 |
+
base_conv2d_block = \
|
247 |
+
functools.partial(Conv2dBlock,
|
248 |
+
kernel_size=kernel_size,
|
249 |
+
stride=2,
|
250 |
+
padding=padding,
|
251 |
+
weight_norm_type=weight_norm_type,
|
252 |
+
activation_norm_type=activation_norm_type,
|
253 |
+
# inplace_nonlinearity=True,
|
254 |
+
nonlinearity=nonlinearity)
|
255 |
+
self.layer1 = base_conv2d_block(input_image_channels, num_filters)
|
256 |
+
self.layer2 = base_conv2d_block(num_filters * 1, num_filters * 2)
|
257 |
+
self.layer3 = base_conv2d_block(num_filters * 2, num_filters * 4)
|
258 |
+
self.layer4 = base_conv2d_block(num_filters * 4, num_filters * 8)
|
259 |
+
self.layer5 = base_conv2d_block(num_filters * 8, num_filters * 8)
|
260 |
+
self.layer6 = base_conv2d_block(num_filters * 8, num_filters * 8)
|
261 |
+
self.fc_mu = LinearBlock(num_filters * 8 * 4 * 4, style_dims)
|
262 |
+
if not self.no_vae:
|
263 |
+
self.fc_var = LinearBlock(num_filters * 8 * 4 * 4, style_dims)
|
264 |
+
|
265 |
+
def forward(self, input_x):
|
266 |
+
r"""SPADE Style Encoder forward.
|
267 |
+
|
268 |
+
Args:
|
269 |
+
input_x (N x 3 x H x W tensor): input images.
|
270 |
+
Returns:
|
271 |
+
mu (N x C tensor): Mean vectors.
|
272 |
+
logvar (N x C tensor): Log-variance vectors.
|
273 |
+
z (N x C tensor): Style code vectors.
|
274 |
+
"""
|
275 |
+
if input_x.size(2) != 256 or input_x.size(3) != 256:
|
276 |
+
input_x = F.interpolate(input_x, size=(256, 256), mode='bilinear')
|
277 |
+
x = self.layer1(input_x)
|
278 |
+
x = self.layer2(x)
|
279 |
+
x = self.layer3(x)
|
280 |
+
x = self.layer4(x)
|
281 |
+
x = self.layer5(x)
|
282 |
+
x = self.layer6(x)
|
283 |
+
x = x.view(x.size(0), -1)
|
284 |
+
mu = self.fc_mu(x)
|
285 |
+
if not self.no_vae:
|
286 |
+
logvar = self.fc_var(x)
|
287 |
+
std = torch.exp(0.5 * logvar)
|
288 |
+
eps = torch.randn_like(std)
|
289 |
+
z = eps.mul(std) + mu
|
290 |
+
else:
|
291 |
+
z = mu
|
292 |
+
logvar = torch.zeros_like(mu)
|
293 |
+
return mu, logvar, z
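
When no_vae is False, z is drawn with the reparameterization trick so gradients reach mu and logvar. A minimal sketch of that sampling together with the KL regularizer such a VAE-style encoder is normally trained with; the KL term itself is an assumption here, not something defined in this file:

```python
import torch

mu = torch.zeros(4, 128)                  # illustrative style_dims = 128
logvar = torch.zeros(4, 128)

std = torch.exp(0.5 * logvar)
z = torch.randn_like(std) * std + mu      # reparameterized sample, as in forward()

# KL(q(z|x) || N(0, I)) -- the usual regularizer applied to mu / logvar (assumed, not from this file).
kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1).mean()
print(z.shape, kl.item())                 # torch.Size([4, 128]) 0.0
```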
|
294 |
+
|
295 |
+
|
296 |
+
class Base3DGenerator(nn.Module):
|
297 |
+
r"""Minecraft 3D generator constructor.
|
298 |
+
|
299 |
+
Args:
|
300 |
+
gen_cfg (obj): Generator definition part of the yaml config file.
|
301 |
+
data_cfg (obj): Data definition part of the yaml config file.
|
302 |
+
"""
|
303 |
+
|
304 |
+
def __init__(self, gen_cfg, data_cfg):
|
305 |
+
super(Base3DGenerator, self).__init__()
|
306 |
+
print('Base3DGenerator initialization.')
|
307 |
+
|
308 |
+
# ---------------------- Main Network ------------------------
|
309 |
+
# Exclude some of the features from positional encoding
|
310 |
+
self.pe_no_pe_feat_dim = getattr(gen_cfg, 'pe_no_pe_feat_dim', 0)
|
311 |
+
|
312 |
+
# blk_feat passes through PE
|
313 |
+
input_dim = (gen_cfg.blk_feat_dim-self.pe_no_pe_feat_dim)*(gen_cfg.pe_lvl_feat*2) + self.pe_no_pe_feat_dim
|
314 |
+
if (gen_cfg.pe_incl_orig_feat):
|
315 |
+
input_dim += (gen_cfg.blk_feat_dim-self.pe_no_pe_feat_dim)
|
316 |
+
print('[Base3DGenerator] Expected input dimensions: ', input_dim)
|
317 |
+
self.input_dim = input_dim
|
318 |
+
|
319 |
+
self.mlp_model_kwargs = gen_cfg.mlp_model_kwargs
|
320 |
+
self.pe_lvl_localcoords = getattr(gen_cfg, 'pe_lvl_localcoords', 0)
|
321 |
+
if self.pe_lvl_localcoords > 0:
|
322 |
+
self.mlp_model_kwargs['poscode_dim'] = self.pe_lvl_localcoords * 2 * 3
|
323 |
+
|
324 |
+
# Set pe_lvl_raydir=0 and pe_incl_orig_raydir=False to disable view direction input
|
325 |
+
input_dim_viewdir = 3*(gen_cfg.pe_lvl_raydir*2)
|
326 |
+
if (gen_cfg.pe_incl_orig_raydir):
|
327 |
+
input_dim_viewdir += 3
|
328 |
+
print('[Base3DGenerator] Expected viewdir input dimensions: ', input_dim_viewdir)
|
329 |
+
self.input_dim_viewdir = input_dim_viewdir
|
330 |
+
|
331 |
+
self.pe_params = [gen_cfg.pe_lvl_feat, gen_cfg.pe_incl_orig_feat,
|
332 |
+
gen_cfg.pe_lvl_raydir, gen_cfg.pe_incl_orig_raydir]
|
333 |
+
|
334 |
+
# Style input dimension
|
335 |
+
style_dims = gen_cfg.style_dims
|
336 |
+
self.style_dims = style_dims
|
337 |
+
interm_style_dims = getattr(gen_cfg, 'interm_style_dims', style_dims)
|
338 |
+
self.interm_style_dims = interm_style_dims
|
339 |
+
# ---------------------- Style MLP --------------------------
|
340 |
+
self.style_net = globals()[gen_cfg.stylenet_model](
|
341 |
+
style_dims, interm_style_dims, **gen_cfg.stylenet_model_kwargs)
|
342 |
+
|
343 |
+
# number of output channels for MLP (before blending)
|
344 |
+
final_feat_dim = getattr(gen_cfg, 'final_feat_dim', 16)
|
345 |
+
self.final_feat_dim = final_feat_dim
|
346 |
+
|
347 |
+
# ----------------------- Sky Network -------------------------
|
348 |
+
sky_input_dim_base = 3
|
349 |
+
# Dedicated sky network input dimensions
|
350 |
+
sky_input_dim = sky_input_dim_base*(gen_cfg.pe_lvl_raydir_sky*2)
|
351 |
+
if (gen_cfg.pe_incl_orig_raydir_sky):
|
352 |
+
sky_input_dim += sky_input_dim_base
|
353 |
+
print('[Base3DGenerator] Expected sky input dimensions: ', sky_input_dim)
|
354 |
+
self.pe_params_sky = [gen_cfg.pe_lvl_raydir_sky, gen_cfg.pe_incl_orig_raydir_sky]
|
355 |
+
self.sky_net = SKYMLP(sky_input_dim, style_dim=interm_style_dims, out_channels_c=final_feat_dim)
|
356 |
+
|
357 |
+
# ----------------------- Style Encoder -------------------------
|
358 |
+
style_enc_cfg = getattr(gen_cfg, 'style_enc', None)
|
359 |
+
setattr(style_enc_cfg, 'input_image_channels', 3)
|
360 |
+
setattr(style_enc_cfg, 'style_dims', gen_cfg.style_dims)
|
361 |
+
self.style_encoder = StyleEncoder(style_enc_cfg)
|
362 |
+
|
363 |
+
# ---------------------- Ray Caster -------------------------
|
364 |
+
self.num_blocks_early_stop = gen_cfg.num_blocks_early_stop
|
365 |
+
self.num_samples = gen_cfg.num_samples
|
366 |
+
self.sample_depth = gen_cfg.sample_depth
|
367 |
+
self.coarse_deterministic_sampling = getattr(gen_cfg, 'coarse_deterministic_sampling', True)
|
368 |
+
self.sample_use_box_boundaries = getattr(gen_cfg, 'sample_use_box_boundaries', True)
|
369 |
+
|
370 |
+
# ---------------------- Blender -------------------------
|
371 |
+
self.raw_noise_std = getattr(gen_cfg, 'raw_noise_std', 0.0)
|
372 |
+
self.dists_scale = getattr(gen_cfg, 'dists_scale', 0.25)
|
373 |
+
self.clip_feat_map = getattr(gen_cfg, 'clip_feat_map', True)
|
374 |
+
self.keep_sky_out = getattr(gen_cfg, 'keep_sky_out', False)
|
375 |
+
self.keep_sky_out_avgpool = getattr(gen_cfg, 'keep_sky_out_avgpool', False)
|
376 |
+
keep_sky_out_learnbg = getattr(gen_cfg, 'keep_sky_out_learnbg', False)
|
377 |
+
self.sky_global_avgpool = getattr(gen_cfg, 'sky_global_avgpool', False)
|
378 |
+
if self.keep_sky_out:
|
379 |
+
self.sky_replace_color = None
|
380 |
+
if keep_sky_out_learnbg:
|
381 |
+
sky_replace_color = torch.zeros([final_feat_dim])
|
382 |
+
sky_replace_color.requires_grad = True
|
383 |
+
self.sky_replace_color = torch.nn.Parameter(sky_replace_color)
|
384 |
+
# ---------------------- render_cnn -------------------------
|
385 |
+
self.denoiser = RenderCNN(final_feat_dim, style_dim=interm_style_dims)
|
386 |
+
self.pad = gen_cfg.pad
|
387 |
+
|
388 |
+
def get_param_groups(self, cfg_opt):
|
389 |
+
print('[Generator] get_param_groups')
|
390 |
+
|
391 |
+
if hasattr(cfg_opt, 'ignore_parameters'):
|
392 |
+
print('[Generator::get_param_groups] [x]: ignored.')
|
393 |
+
optimize_parameters = []
|
394 |
+
for k, x in self.named_parameters():
|
395 |
+
match = False
|
396 |
+
for m in cfg_opt.ignore_parameters:
|
397 |
+
if re.match(m, k) is not None:
|
398 |
+
match = True
|
399 |
+
print(' [x]', k)
|
400 |
+
break
|
401 |
+
if match is False:
|
402 |
+
print(' [v]', k)
|
403 |
+
optimize_parameters.append(x)
|
404 |
+
else:
|
405 |
+
optimize_parameters = self.parameters()
|
406 |
+
|
407 |
+
param_groups = []
|
408 |
+
param_groups.append({'params': optimize_parameters})
|
409 |
+
|
410 |
+
if hasattr(cfg_opt, 'param_groups'):
|
411 |
+
optimized_param_names = []
|
412 |
+
all_param_names = [k for k, v in self.named_parameters()]
|
413 |
+
param_groups = []
|
414 |
+
for k, v in cfg_opt.param_groups.items():
|
415 |
+
print('[Generator::get_param_groups] Adding param group from config:', k, v)
|
416 |
+
params = getattr(self, k)
|
417 |
+
named_parameters = [k]
|
418 |
+
if issubclass(type(params), nn.Module):
|
419 |
+
named_parameters = [k+'.'+pname for pname, _ in params.named_parameters()]
|
420 |
+
params = params.parameters()
|
421 |
+
param_groups.append({'params': params, **v})
|
422 |
+
optimized_param_names.extend(named_parameters)
|
423 |
+
|
424 |
+
print('[Generator::get_param_groups] UNOPTIMIZED PARAMETERS:\n ',
|
425 |
+
set(all_param_names) - set(optimized_param_names))
|
426 |
+
|
427 |
+
return param_groups
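
get_param_groups reads two optional fields from the optimizer config: ignore_parameters (regexes for parameters to freeze) and param_groups (per-submodule optimizer settings, which take precedence). A hedged sketch of how such a config fragment could be assembled and passed in; the attribute names mirror what the method reads, the values are purely illustrative:

```python
from types import SimpleNamespace

# Hypothetical optimizer config: give the hash encoder and the per-point MLP
# their own learning rates. Any parameter not listed is reported as unoptimized.
cfg_opt = SimpleNamespace(
    param_groups={'hash_encoder': {'lr': 1e-2}, 'render_net': {'lr': 1e-3}},
)

# groups = generator.get_param_groups(cfg_opt)     # assuming `generator` is an instance
# optimizer = torch.optim.Adam(groups, lr=1e-4)
```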
|
428 |
+
|
429 |
+
def _forward_perpix_sub(self, blk_feats, worldcoord2, raydirs_in, z, mc_masks_onehot=None):
|
430 |
+
r"""Forwarding the MLP.
|
431 |
+
|
432 |
+
Args:
|
433 |
+
blk_feats (K x C1 tensor): Sparse block features.
|
434 |
+
worldcoord2 (N x H x W x L x 3 tensor): 3D world coordinates of sampled points. L is number of samples; N is batch size, always 1.
|
435 |
+
raydirs_in (N x H x W x 1 x C2 tensor or None): ray direction embeddings.
|
436 |
+
z (N x C3 tensor): Intermediate style vectors.
|
437 |
+
mc_masks_onehot (N x H x W x L x C4): One-hot segmentation maps.
|
438 |
+
Returns:
|
439 |
+
net_out_s (N x H x W x L x 1 tensor): Opacities.
|
440 |
+
net_out_c (N x H x W x L x C5 tensor): Color embeddings.
|
441 |
+
"""
|
442 |
+
proj_feature = voxlib.sparse_trilinear_interp_worldcoord(
|
443 |
+
blk_feats, self.voxel.corner_t, worldcoord2, ign_zero=True)
|
444 |
+
|
445 |
+
render_net_extra_kwargs = {}
|
446 |
+
if self.pe_lvl_localcoords > 0:
|
447 |
+
local_coords = torch.remainder(worldcoord2, 1.0) * 2.0
|
448 |
+
# Scale to [0, 2], as the positional encoding function doesn't have internal x2
|
449 |
+
local_coords[torch.isnan(local_coords)] = 0.0
|
450 |
+
local_coords = local_coords.contiguous()
|
451 |
+
poscode = voxlib.positional_encoding(local_coords, self.pe_lvl_localcoords, -1, False)
|
452 |
+
render_net_extra_kwargs['poscode'] = poscode
|
453 |
+
|
454 |
+
if self.pe_params[0] == 0 and self.pe_params[1] is True: # no PE shortcut, saves ~400MB
|
455 |
+
feature_in = proj_feature
|
456 |
+
else:
|
457 |
+
if self.pe_no_pe_feat_dim > 0:
|
458 |
+
feature_in = voxlib.positional_encoding(
|
459 |
+
proj_feature[..., :-self.pe_no_pe_feat_dim].contiguous(), self.pe_params[0], -1, self.pe_params[1])
|
460 |
+
feature_in = torch.cat([feature_in, proj_feature[..., -self.pe_no_pe_feat_dim:]], dim=-1)
|
461 |
+
else:
|
462 |
+
feature_in = voxlib.positional_encoding(
|
463 |
+
proj_feature.contiguous(), self.pe_params[0], -1, self.pe_params[1])
|
464 |
+
|
465 |
+
net_out_s, net_out_c = self.render_net(feature_in, raydirs_in, z, mc_masks_onehot, **render_net_extra_kwargs)
|
466 |
+
|
467 |
+
if self.raw_noise_std > 0.:
|
468 |
+
noise = torch.randn_like(net_out_s) * self.raw_noise_std
|
469 |
+
net_out_s = net_out_s + noise
|
470 |
+
|
471 |
+
return net_out_s, net_out_c
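
The MLP input width computed in the constructor has to match the feature_in assembled here: every encoded channel contributes 2 * pe_lvl_feat sin/cos values, the pe_no_pe_feat_dim channels bypass the encoding, and the raw features are optionally appended. A short worked example of that bookkeeping with illustrative numbers:

```python
# Worked example of the positional-encoding dimension bookkeeping (illustrative values).
blk_feat_dim = 64          # per-voxel feature channels
pe_no_pe_feat_dim = 4      # channels passed through without encoding
pe_lvl_feat = 4            # frequency levels -> one sin and one cos each
pe_incl_orig_feat = True   # also concatenate the raw (unencoded) features

encoded = (blk_feat_dim - pe_no_pe_feat_dim) * (pe_lvl_feat * 2)   # 60 * 8 = 480
input_dim = encoded + pe_no_pe_feat_dim                            # 484
if pe_incl_orig_feat:
    input_dim += blk_feat_dim - pe_no_pe_feat_dim                  # 544
print(input_dim)
```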
|
472 |
+
|
473 |
+
def _forward_perpix(self, blk_feats, voxel_id, depth2, raydirs, cam_ori_t, z):
|
474 |
+
r"""Sample points along rays, forwarding the per-point MLP and aggregate pixel features
|
475 |
+
|
476 |
+
Args:
|
477 |
+
blk_feats (K x C1 tensor): Sparse block features.
|
478 |
+
voxel_id (N x H x W x M x 1 tensor): Voxel ids from the ray-voxel intersection test. M: number of intersected voxels recorded per ray (bounded by num_blocks_early_stop).
|
479 |
+
depth2 (N x 2 x H x W x M x 1 tensor): Depths of entrance and exit points for each ray-voxel intersection.
|
480 |
+
raydirs (N x H x W x 1 x 3 tensor): The direction of each ray.
|
481 |
+
cam_ori_t (N x 3 tensor): Camera origins.
|
482 |
+
z (N x C3 tensor): Intermediate style vectors.
|
483 |
+
"""
|
484 |
+
# Generate sky_mask; PE transform on ray direction.
|
485 |
+
with torch.no_grad():
|
486 |
+
raydirs_in = raydirs.expand(-1, -1, -1, 1, -1).contiguous()
|
487 |
+
if self.pe_params[2] == 0 and self.pe_params[3] is True:
|
488 |
+
raydirs_in = raydirs_in
|
489 |
+
elif self.pe_params[2] == 0 and self.pe_params[3] is False: # Not using raydir at all
|
490 |
+
raydirs_in = None
|
491 |
+
else:
|
492 |
+
raydirs_in = voxlib.positional_encoding(raydirs_in, self.pe_params[2], -1, self.pe_params[3])
|
493 |
+
|
494 |
+
# sky_mask: when True, ray finally hits sky
|
495 |
+
sky_mask = voxel_id[:, :, :, [-1], :] == 0
|
496 |
+
# sky_only_mask: when True, ray hits nothing but sky
|
497 |
+
sky_only_mask = voxel_id[:, :, :, [0], :] == 0
|
498 |
+
|
499 |
+
with torch.no_grad():
|
500 |
+
# Random sample points along the ray
|
501 |
+
num_samples = self.num_samples + 1
|
502 |
+
if self.sample_use_box_boundaries:
|
503 |
+
num_samples = self.num_samples - self.num_blocks_early_stop
|
504 |
+
|
505 |
+
# 10 samples per ray + 4 intersections - 2
|
506 |
+
rand_depth, new_dists, new_idx = mc_utils.sample_depth_batched(
|
507 |
+
depth2, num_samples, deterministic=self.coarse_deterministic_sampling,
|
508 |
+
use_box_boundaries=self.sample_use_box_boundaries, sample_depth=self.sample_depth)
|
509 |
+
|
510 |
+
worldcoord2 = raydirs * rand_depth + cam_ori_t[:, None, None, None, :]
|
511 |
+
|
512 |
+
# Generate per-sample segmentation label
|
513 |
+
voxel_id_reduced = self.label_trans.mc2reduced(voxel_id, ign2dirt=True)
|
514 |
+
mc_masks = torch.gather(voxel_id_reduced, -2, new_idx) # B 256 256 N 1
|
515 |
+
mc_masks = mc_masks.long()
|
516 |
+
mc_masks_onehot = torch.zeros([mc_masks.size(0), mc_masks.size(1), mc_masks.size(
|
517 |
+
2), mc_masks.size(3), self.num_reduced_labels], dtype=torch.float, device=voxel_id.device)
|
518 |
+
# mc_masks_onehot: [B, H, W, Nlayer, num_reduced_labels]
|
519 |
+
mc_masks_onehot.scatter_(-1, mc_masks, 1.0)
|
520 |
+
|
521 |
+
net_out_s, net_out_c = self._forward_perpix_sub(blk_feats, worldcoord2, raydirs_in, z, mc_masks_onehot)
|
522 |
+
|
523 |
+
# Handle sky
|
524 |
+
sky_raydirs_in = raydirs.expand(-1, -1, -1, 1, -1).contiguous()
|
525 |
+
sky_raydirs_in = voxlib.positional_encoding(sky_raydirs_in, self.pe_params_sky[0], -1, self.pe_params_sky[1])
|
526 |
+
skynet_out_c = self.sky_net(sky_raydirs_in, z)
|
527 |
+
|
528 |
+
# Blending
|
529 |
+
weights = mc_utils.volum_rendering_relu(net_out_s, new_dists * self.dists_scale, dim=-2)
|
530 |
+
|
531 |
+
# If a ray exclusively hits the sky (no intersection with the voxels), set its weight to zero.
|
532 |
+
weights = weights * torch.logical_not(sky_only_mask).float()
|
533 |
+
total_weights_raw = torch.sum(weights, dim=-2, keepdim=True) # 256 256 1 1
|
534 |
+
total_weights = total_weights_raw
|
535 |
+
|
536 |
+
is_gnd = worldcoord2[..., [0]] <= 1.0 # Y X Z, [256, 256, 4, 3], nan < 1.0 == False
|
537 |
+
is_gnd = is_gnd.any(dim=-2, keepdim=True)
|
538 |
+
nosky_mask = torch.logical_or(torch.logical_not(sky_mask), is_gnd)
|
539 |
+
nosky_mask = nosky_mask.float()
|
540 |
+
|
541 |
+
# Avoid sky leakage
|
542 |
+
sky_weight = 1.0-total_weights
|
543 |
+
if self.keep_sky_out:
|
544 |
+
# keep_sky_out_avgpool overrides sky_replace_color
|
545 |
+
if self.sky_replace_color is None or self.keep_sky_out_avgpool:
|
546 |
+
if self.keep_sky_out_avgpool:
|
547 |
+
if hasattr(self, 'sky_avg'):
|
548 |
+
sky_avg = self.sky_avg
|
549 |
+
else:
|
550 |
+
if self.sky_global_avgpool:
|
551 |
+
sky_avg = torch.mean(skynet_out_c, dim=[1, 2], keepdim=True)
|
552 |
+
else:
|
553 |
+
skynet_out_c_nchw = skynet_out_c.permute(0, 4, 1, 2, 3).squeeze(-1).contiguous()
|
554 |
+
sky_avg = F.avg_pool2d(skynet_out_c_nchw, 31, stride=1, padding=15, count_include_pad=False)
|
555 |
+
sky_avg = sky_avg.permute(0, 2, 3, 1).unsqueeze(-2).contiguous()
|
556 |
+
# print(sky_avg.shape)
|
557 |
+
skynet_out_c = skynet_out_c * (1.0-nosky_mask) + sky_avg*(nosky_mask)
|
558 |
+
else:
|
559 |
+
sky_weight = sky_weight * (1.0-nosky_mask)
|
560 |
+
else:
|
561 |
+
skynet_out_c = skynet_out_c * (1.0-nosky_mask) + self.sky_replace_color*(nosky_mask)
|
562 |
+
|
563 |
+
if self.clip_feat_map is True: # intermediate feature before blending & CNN
|
564 |
+
rgbs = torch.clamp(net_out_c, -1, 1) + 1
|
565 |
+
rgbs_sky = torch.clamp(skynet_out_c, -1, 1) + 1
|
566 |
+
net_out = torch.sum(weights*rgbs, dim=-2, keepdim=True) + sky_weight * \
|
567 |
+
rgbs_sky # 576, 768, 4, 3 -> 576, 768, 3
|
568 |
+
net_out = net_out.squeeze(-2)
|
569 |
+
net_out = net_out - 1
|
570 |
+
elif self.clip_feat_map is False:
|
571 |
+
rgbs = net_out_c
|
572 |
+
rgbs_sky = skynet_out_c
|
573 |
+
net_out = torch.sum(weights*rgbs, dim=-2, keepdim=True) + sky_weight * \
|
574 |
+
rgbs_sky # 576, 768, 4, 3 -> 576, 768, 3
|
575 |
+
net_out = net_out.squeeze(-2)
|
576 |
+
elif self.clip_feat_map == 'tanh':
|
577 |
+
rgbs = torch.tanh(net_out_c)
|
578 |
+
rgbs_sky = torch.tanh(skynet_out_c)
|
579 |
+
net_out = torch.sum(weights*rgbs, dim=-2, keepdim=True) + sky_weight * \
|
580 |
+
rgbs_sky # 576, 768, 4, 3 -> 576, 768, 3
|
581 |
+
net_out = net_out.squeeze(-2)
|
582 |
+
else:
|
583 |
+
raise NotImplementedError
|
584 |
+
|
585 |
+
return net_out, new_dists, weights, total_weights_raw, rand_depth, net_out_s, net_out_c, skynet_out_c, \
|
586 |
+
nosky_mask, sky_mask, sky_only_mask, new_idx
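
mc_utils.volum_rendering_relu turns per-sample opacities and segment lengths into compositing weights; its source is not part of this file, so the sketch below is only the standard NeRF-style recurrence it is assumed to implement (w_i = alpha_i * prod_{j<i}(1 - alpha_j), with alpha_i = 1 - exp(-relu(sigma_i) * dist_i)):

```python
import torch
import torch.nn.functional as F

def volume_weights(sigma, dists, dim=-2):
    """Assumed NeRF-style weights: w_i = alpha_i * prod_{j<i} (1 - alpha_j)."""
    alpha = 1.0 - torch.exp(-F.relu(sigma) * dists)
    ones = torch.ones_like(alpha.narrow(dim, 0, 1))
    trans = torch.cumprod(torch.cat([ones, 1.0 - alpha + 1e-10], dim=dim), dim=dim)
    trans = trans.narrow(dim, 0, alpha.shape[dim])   # exclusive cumulative product
    return alpha * trans

sigma = torch.randn(1, 8, 8, 5, 1)    # N x H x W x L x 1 raw opacities
dists = torch.rand(1, 8, 8, 5, 1)     # per-sample segment lengths
weights = volume_weights(sigma, dists)
print(weights.shape)                  # torch.Size([1, 8, 8, 5, 1])
```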
|
587 |
+
|
588 |
+
def _forward_global(self, net_out, z):
|
589 |
+
r"""Forward the CNN
|
590 |
+
|
591 |
+
Args:
|
592 |
+
net_out (N x C5 x H x W tensor): Intermediate feature maps.
|
593 |
+
z (N x C3 tensor): Intermediate style vectors.
|
594 |
+
|
595 |
+
Returns:
|
596 |
+
fake_images (N x 3 x H x W tensor): Output image.
|
597 |
+
fake_images_raw (N x 3 x H x W tensor): Output image before TanH.
|
598 |
+
"""
|
599 |
+
fake_images = net_out.permute(0, 3, 1, 2).contiguous()
|
600 |
+
fake_images_raw = self.denoiser(fake_images, z)
|
601 |
+
fake_images = torch.tanh(fake_images_raw)
|
602 |
+
|
603 |
+
return fake_images, fake_images_raw
|
imaginaire/generators/scenedreamer.py
ADDED
@@ -0,0 +1,851 @@
1 |
+
# Using Hashgrid as backbone representation
|
2 |
+
|
3 |
+
import os
|
4 |
+
import cv2
|
5 |
+
import imageio
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
|
11 |
+
import imaginaire.model_utils.gancraft.camctl as camctl
|
12 |
+
import imaginaire.model_utils.gancraft.mc_utils as mc_utils
|
13 |
+
import imaginaire.model_utils.gancraft.voxlib as voxlib
|
14 |
+
from imaginaire.model_utils.pcg_gen import PCGVoxelGenerator, PCGCache
|
15 |
+
from imaginaire.utils.distributed import master_only_print as print
|
16 |
+
from imaginaire.generators.gancraft_base import Base3DGenerator
|
17 |
+
from encoding import get_encoder
|
18 |
+
|
19 |
+
from imaginaire.model_utils.layers import LightningMLP, ConditionalHashGrid
|
20 |
+
|
21 |
+
class Generator(Base3DGenerator):
|
22 |
+
r"""SceneDreamer generator constructor.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
gen_cfg (obj): Generator definition part of the yaml config file.
|
26 |
+
data_cfg (obj): Data definition part of the yaml config file.
|
27 |
+
"""
|
28 |
+
|
29 |
+
def __init__(self, gen_cfg, data_cfg):
|
30 |
+
super(Generator, self).__init__(gen_cfg, data_cfg)
|
31 |
+
print('SceneDreamer[Hash] on ALL Scenes generator initialization.')
|
32 |
+
|
33 |
+
# Expects a set of height maps and semantic maps describing the scene.
|
34 |
+
if gen_cfg.pcg_cache:
|
35 |
+
print('[Generator] Loading PCG dataset: ', gen_cfg.pcg_dataset_path)
|
36 |
+
self.voxel = PCGCache(gen_cfg.pcg_dataset_path)
|
37 |
+
print('[Generator] Loaded PCG dataset.')
|
38 |
+
else:
|
39 |
+
self.voxel = PCGVoxelGenerator(gen_cfg.scene_size)
|
40 |
+
self.blk_feats = None
|
41 |
+
# Minecraft -> SPADE label translator.
|
42 |
+
self.label_trans = mc_utils.MCLabelTranslator()
|
43 |
+
self.num_reduced_labels = self.label_trans.get_num_reduced_lbls()
|
44 |
+
self.reduced_label_set = getattr(gen_cfg, 'reduced_label_set', False)
|
45 |
+
self.use_label_smooth = getattr(gen_cfg, 'use_label_smooth', False)
|
46 |
+
self.use_label_smooth_real = getattr(gen_cfg, 'use_label_smooth_real', self.use_label_smooth)
|
47 |
+
self.use_label_smooth_pgt = getattr(gen_cfg, 'use_label_smooth_pgt', False)
|
48 |
+
self.label_smooth_dia = getattr(gen_cfg, 'label_smooth_dia', 11)
|
49 |
+
|
50 |
+
# Load MLP model.
|
51 |
+
self.hash_encoder, self.hash_in_dim = get_encoder(encoding='hashgrid', input_dim=5, desired_resolution=2048 * 1, level_dim=8)
|
52 |
+
self.render_net = LightningMLP(self.hash_in_dim, viewdir_dim=self.input_dim_viewdir, style_dim=self.interm_style_dims, mask_dim=self.num_reduced_labels, out_channels_s=1, out_channels_c=self.final_feat_dim, **self.mlp_model_kwargs)
|
53 |
+
print(self.hash_encoder)
|
54 |
+
self.world_encoder = ConditionalHashGrid()
|
55 |
+
|
56 |
+
# Camera sampler.
|
57 |
+
self.camera_sampler_type = getattr(gen_cfg, 'camera_sampler_type', "random")
|
58 |
+
assert self.camera_sampler_type in ['random', 'traditional']
|
59 |
+
self.camera_min_entropy = getattr(gen_cfg, 'camera_min_entropy', -1)
|
60 |
+
self.camera_rej_avg_depth = getattr(gen_cfg, 'camera_rej_avg_depth', -1)
|
61 |
+
self.cam_res = gen_cfg.cam_res
|
62 |
+
self.crop_size = gen_cfg.crop_size
|
63 |
+
|
64 |
+
print('Done with the SceneDreamer initialization.')
|
65 |
+
|
66 |
+
def custom_init(self):
|
67 |
+
r"""Weight initialization."""
|
68 |
+
|
69 |
+
def init_func(m):
|
70 |
+
if hasattr(m, 'weight'):
|
71 |
+
try:
|
72 |
+
nn.init.kaiming_normal_(m.weight.data, a=0.2, nonlinearity='leaky_relu')
|
73 |
+
except:
|
74 |
+
print(m.name)
|
75 |
+
m.weight.data *= 0.5
|
76 |
+
if hasattr(m, 'bias') and m.bias is not None:
|
77 |
+
m.bias.data.fill_(0.0)
|
78 |
+
self.apply(init_func)
|
79 |
+
|
80 |
+
def _get_batch(self, batch_size, device):
|
81 |
+
r"""Sample camera poses and perform ray-voxel intersection.
|
82 |
+
|
83 |
+
Args:
|
84 |
+
batch_size (int): Expected batch size of the current batch
|
85 |
+
device (torch.device): Device on which the tensors should be stored
|
86 |
+
"""
|
87 |
+
with torch.no_grad():
|
88 |
+
self.voxel.sample_world(device)
|
89 |
+
voxel_id_batch = []
|
90 |
+
depth2_batch = []
|
91 |
+
raydirs_batch = []
|
92 |
+
cam_ori_t_batch = []
|
93 |
+
for b in range(batch_size):
|
94 |
+
while True: # Rejection sampling.
|
95 |
+
# Sample camera pose.
|
96 |
+
if self.camera_sampler_type == 'random':
|
97 |
+
cam_res = self.cam_res
|
98 |
+
cam_ori_t, cam_dir_t, cam_up_t = camctl.rand_camera_pose_thridperson2(self.voxel)
|
99 |
+
# ~24mm fov horizontal.
|
100 |
+
cam_f = 0.5/np.tan(np.deg2rad(73/2) * (np.random.rand(1)*0.5+0.5)) * (cam_res[1]-1)
|
101 |
+
cam_c = [(cam_res[0]-1)/2, (cam_res[1]-1)/2]
|
102 |
+
cam_res_crop = [self.crop_size[0] + self.pad, self.crop_size[1] + self.pad]
|
103 |
+
cam_c = mc_utils.rand_crop(cam_c, cam_res, cam_res_crop)
|
104 |
+
elif self.camera_sampler_type == 'traditional':
|
105 |
+
cam_res = self.cam_res
|
106 |
+
cam_c = [(cam_res[0]-1)/2, (cam_res[1]-1)/2]
|
107 |
+
dice = torch.rand(1).item()
|
108 |
+
if dice > 0.5:
|
109 |
+
cam_ori_t, cam_dir_t, cam_up_t, cam_f = \
|
110 |
+
camctl.rand_camera_pose_tour(self.voxel)
|
111 |
+
cam_f = cam_f * (cam_res[1]-1)
|
112 |
+
else:
|
113 |
+
cam_ori_t, cam_dir_t, cam_up_t = \
|
114 |
+
camctl.rand_camera_pose_thridperson2(self.voxel)
|
115 |
+
# ~24mm fov horizontal.
|
116 |
+
cam_f = 0.5 / np.tan(np.deg2rad(73/2) * (np.random.rand(1)*0.5+0.5)) * (cam_res[1]-1)
|
117 |
+
|
118 |
+
cam_res_crop = [self.crop_size[0] + self.pad, self.crop_size[1] + self.pad]
|
119 |
+
cam_c = mc_utils.rand_crop(cam_c, cam_res, cam_res_crop)
|
120 |
+
else:
|
121 |
+
raise NotImplementedError(
|
122 |
+
'Unknown self.camera_sampler_type: {}'.format(self.camera_sampler_type))
|
123 |
+
|
124 |
+
# Run ray-voxel intersection test
|
125 |
+
voxel_id, depth2, raydirs = voxlib.ray_voxel_intersection_perspective(
|
126 |
+
self.voxel.voxel_t, cam_ori_t, cam_dir_t, cam_up_t, cam_f, cam_c, cam_res_crop,
|
127 |
+
self.num_blocks_early_stop)
|
128 |
+
|
129 |
+
if self.camera_rej_avg_depth > 0:
|
130 |
+
depth_map = depth2[0, :, :, 0, :]
|
131 |
+
avg_depth = torch.mean(depth_map[~torch.isnan(depth_map)])
|
132 |
+
if avg_depth < self.camera_rej_avg_depth:
|
133 |
+
continue
|
134 |
+
|
135 |
+
# Reject low entropy.
|
136 |
+
if self.camera_min_entropy > 0:
|
137 |
+
# Check entropy.
|
138 |
+
maskcnt = torch.bincount(
|
139 |
+
torch.flatten(voxel_id[:, :, 0, 0]), weights=None, minlength=680).float() / \
|
140 |
+
(voxel_id.size(0)*voxel_id.size(1))
|
141 |
+
maskentropy = -torch.sum(maskcnt * torch.log(maskcnt+1e-10))
|
142 |
+
if maskentropy < self.camera_min_entropy:
|
143 |
+
continue
|
144 |
+
break
|
145 |
+
|
146 |
+
voxel_id_batch.append(voxel_id)
|
147 |
+
depth2_batch.append(depth2)
|
148 |
+
raydirs_batch.append(raydirs)
|
149 |
+
cam_ori_t_batch.append(cam_ori_t)
|
150 |
+
voxel_id = torch.stack(voxel_id_batch, dim=0)
|
151 |
+
depth2 = torch.stack(depth2_batch, dim=0)
|
152 |
+
raydirs = torch.stack(raydirs_batch, dim=0)
|
153 |
+
cam_ori_t = torch.stack(cam_ori_t_batch, dim=0).to(device)
|
154 |
+
cam_poses = None
|
155 |
+
return voxel_id, depth2, raydirs, cam_ori_t, cam_poses
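
The rejection loop above drops candidate camera poses whose first-hit label map is too uniform, using the entropy of the label histogram as the diversity score. A standalone sketch of that test with an illustrative threshold:

```python
import torch

# Hypothetical first-hit voxel-id map for one candidate view (H x W).
labels = torch.randint(0, 680, (256, 256))

counts = torch.bincount(labels.flatten(), minlength=680).float()
probs = counts / labels.numel()
entropy = -torch.sum(probs * torch.log(probs + 1e-10))

camera_min_entropy = 0.75               # illustrative threshold, not a config value
print(entropy.item(), entropy.item() >= camera_min_entropy)
```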
|
156 |
+
|
157 |
+
|
158 |
+
def get_pseudo_gt(self, pseudo_gen, voxel_id, z=None, style_img=None, resize_512=True, deterministic=False):
|
159 |
+
r"""Evaluating img2img network to obtain pseudo-ground truth images.
|
160 |
+
|
161 |
+
Args:
|
162 |
+
pseudo_gen (callable): Function converting mask to image using img2img network.
|
163 |
+
voxel_id (N x img_dims[0] x img_dims[1] x max_samples x 1 tensor): IDs of intersected voxels along
|
164 |
+
each ray.
|
165 |
+
z (N x C tensor): Optional style code passed to pseudo_gen.
|
166 |
+
style_img (N x 3 x H x W tensor): Optional style image passed to pseudo_gen.
|
167 |
+
resize_512 (bool): If True, evaluate pseudo_gen at 512x512 regardless of input resolution.
|
168 |
+
deterministic (bool): If True, disable stochastic label mapping.
|
169 |
+
"""
|
170 |
+
with torch.no_grad():
|
171 |
+
mc_mask = voxel_id[:, :, :, 0, :].permute(0, 3, 1, 2).long().contiguous()
|
172 |
+
coco_mask = self.label_trans.mc2coco(mc_mask) - 1
|
173 |
+
coco_mask[coco_mask < 0] = 183
|
174 |
+
|
175 |
+
if not deterministic:
|
176 |
+
# Stochastic mapping
|
177 |
+
dice = torch.rand(1).item()
|
178 |
+
if dice > 0.5 and dice < 0.9:
|
179 |
+
coco_mask[coco_mask == self.label_trans.gglbl2ggid('sky')] = self.label_trans.gglbl2ggid('clouds')
|
180 |
+
elif dice >= 0.9:
|
181 |
+
coco_mask[coco_mask == self.label_trans.gglbl2ggid('sky')] = self.label_trans.gglbl2ggid('fog')
|
182 |
+
dice = torch.rand(1).item()
|
183 |
+
if dice > 0.33 and dice < 0.66:
|
184 |
+
coco_mask[coco_mask == self.label_trans.gglbl2ggid('water')] = self.label_trans.gglbl2ggid('sea')
|
185 |
+
elif dice >= 0.66:
|
186 |
+
coco_mask[coco_mask == self.label_trans.gglbl2ggid('water')] = self.label_trans.gglbl2ggid('river')
|
187 |
+
|
188 |
+
fake_masks = torch.zeros([coco_mask.size(0), 185, coco_mask.size(2), coco_mask.size(3)],
|
189 |
+
dtype=torch.half, device=voxel_id.device)
|
190 |
+
fake_masks.scatter_(1, coco_mask, 1.0)
|
191 |
+
|
192 |
+
if self.use_label_smooth_pgt:
|
193 |
+
fake_masks = mc_utils.segmask_smooth(fake_masks, kernel_size=self.label_smooth_dia)
|
194 |
+
if self.pad > 0:
|
195 |
+
fake_masks = fake_masks[:, :, self.pad//2:-self.pad//2, self.pad//2:-self.pad//2]
|
196 |
+
|
197 |
+
# Generate pseudo GT using GauGAN.
|
198 |
+
if resize_512:
|
199 |
+
fake_masks_512 = F.interpolate(fake_masks, size=[512, 512], mode='nearest')
|
200 |
+
else:
|
201 |
+
fake_masks_512 = fake_masks
|
202 |
+
pseudo_real_img = pseudo_gen(fake_masks_512, z=z, style_img=style_img)
|
203 |
+
|
204 |
+
# NaN/Inf guard. NaN can occur on Volta GPUs.
|
205 |
+
nan_mask = torch.isnan(pseudo_real_img)
|
206 |
+
inf_mask = torch.isinf(pseudo_real_img)
|
207 |
+
pseudo_real_img[nan_mask | inf_mask] = 0.0
|
208 |
+
if resize_512:
|
209 |
+
pseudo_real_img = F.interpolate(
|
210 |
+
pseudo_real_img, size=[fake_masks.size(2), fake_masks.size(3)], mode='area')
|
211 |
+
pseudo_real_img = torch.clamp(pseudo_real_img, -1, 1)
|
212 |
+
|
213 |
+
return pseudo_real_img, fake_masks
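
The translated COCO label map is expanded into a one-hot tensor with scatter_ before being handed to the img2img network. A minimal sketch of that conversion (185 classes as in the code above; the spatial size is illustrative):

```python
import torch

coco_mask = torch.randint(0, 185, (1, 1, 64, 64))    # N x 1 x H x W class ids
fake_masks = torch.zeros(1, 185, 64, 64)
fake_masks.scatter_(1, coco_mask, 1.0)                # one-hot along the channel dim

assert fake_masks.sum(dim=1).eq(1).all()              # exactly one class per pixel
```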
|
214 |
+
|
215 |
+
|
216 |
+
def sample_camera(self, data, pseudo_gen):
|
217 |
+
r"""Sample camera randomly and precompute everything used by both Gen and Dis.
|
218 |
+
|
219 |
+
Args:
|
220 |
+
data (dict):
|
221 |
+
images (N x 3 x H x W tensor) : Real images
|
222 |
+
label (N x C2 x H x W tensor) : Segmentation map
|
223 |
+
pseudo_gen (callable): Function converting mask to image using img2img network.
|
224 |
+
Returns:
|
225 |
+
ret (dict):
|
226 |
+
voxel_id (N x H x W x max_samples x 1 tensor): IDs of intersected voxels along each ray.
|
227 |
+
depth2 (N x 2 x H x W x max_samples x 1 tensor): Depths of entrance and exit points for each ray-voxel
|
228 |
+
intersection.
|
229 |
+
raydirs (N x H x W x 1 x 3 tensor): The direction of each ray.
|
230 |
+
cam_ori_t (N x 3 tensor): Camera origins.
|
231 |
+
pseudo_real_img (N x 3 x H x W tensor): Pseudo-ground truth image.
|
232 |
+
real_masks (N x C3 x H x W tensor): One-hot segmentation map for real images, with translated labels.
|
233 |
+
fake_masks (N x C3 x H x W tensor): One-hot segmentation map for sampled camera views.
|
234 |
+
"""
|
235 |
+
device = torch.device('cuda')
|
236 |
+
batch_size = data['images'].size(0)
|
237 |
+
# ================ Assemble a batch ==================
|
238 |
+
# Requires: voxel_id, depth2, raydirs, cam_ori_t.
|
239 |
+
voxel_id, depth2, raydirs, cam_ori_t, _ = self._get_batch(batch_size, device)
|
240 |
+
ret = {'voxel_id': voxel_id, 'depth2': depth2, 'raydirs': raydirs, 'cam_ori_t': cam_ori_t}
|
241 |
+
|
242 |
+
if pseudo_gen is not None:
|
243 |
+
pseudo_real_img, _ = self.get_pseudo_gt(pseudo_gen, voxel_id)
|
244 |
+
ret['pseudo_real_img'] = pseudo_real_img.float()
|
245 |
+
|
246 |
+
# =============== Mask translation ================
|
247 |
+
real_masks = data['label']
|
248 |
+
if self.reduced_label_set:
|
249 |
+
# Translate fake mask (directly from mcid).
|
250 |
+
# convert unrecognized labels to 'dirt'.
|
251 |
+
# N C H W [1 1 80 80]
|
252 |
+
reduce_fake_mask = self.label_trans.mc2reduced(
|
253 |
+
voxel_id[:, :, :, 0, :].permute(0, 3, 1, 2).long().contiguous()
|
254 |
+
, ign2dirt=True)
|
255 |
+
reduce_fake_mask_onehot = torch.zeros([
|
256 |
+
reduce_fake_mask.size(0), self.num_reduced_labels, reduce_fake_mask.size(2), reduce_fake_mask.size(3)],
|
257 |
+
dtype=torch.float, device=device)
|
258 |
+
reduce_fake_mask_onehot.scatter_(1, reduce_fake_mask, 1.0)
|
259 |
+
fake_masks = reduce_fake_mask_onehot
|
260 |
+
if self.pad != 0:
|
261 |
+
fake_masks = fake_masks[:, :, self.pad//2:-self.pad//2, self.pad//2:-self.pad//2]
|
262 |
+
|
263 |
+
# Translate real mask (data['label']), which is onehot.
|
264 |
+
real_masks_idx = torch.argmax(real_masks, dim=1, keepdim=True)
|
265 |
+
real_masks_idx[real_masks_idx > 182] = 182
|
266 |
+
|
267 |
+
reduced_real_mask = self.label_trans.coco2reduced(real_masks_idx)
|
268 |
+
reduced_real_mask_onehot = torch.zeros([
|
269 |
+
reduced_real_mask.size(0), self.num_reduced_labels, reduced_real_mask.size(2),
|
270 |
+
reduced_real_mask.size(3)], dtype=torch.float, device=device)
|
271 |
+
reduced_real_mask_onehot.scatter_(1, reduced_real_mask, 1.0)
|
272 |
+
real_masks = reduced_real_mask_onehot
|
273 |
+
|
274 |
+
# Mask smoothing.
|
275 |
+
if self.use_label_smooth:
|
276 |
+
fake_masks = mc_utils.segmask_smooth(fake_masks, kernel_size=self.label_smooth_dia)
|
277 |
+
if self.use_label_smooth_real:
|
278 |
+
real_masks = mc_utils.segmask_smooth(real_masks, kernel_size=self.label_smooth_dia)
|
279 |
+
|
280 |
+
ret['real_masks'] = real_masks
|
281 |
+
ret['fake_masks'] = fake_masks
|
282 |
+
|
283 |
+
return ret
|
284 |
+
|
285 |
+
def _forward_perpix_sub(self, blk_feats, worldcoord2, raydirs_in, z, mc_masks_onehot=None, global_enc=None):
|
286 |
+
r"""Per-pixel rendering forwarding
|
287 |
+
|
288 |
+
Args:
|
289 |
+
blk_feats: Deprecated
|
290 |
+
worldcoord2 (N x H x W x L x 3 tensor): 3D world coordinates of sampled points. L is number of samples; N is batch size, always 1.
|
291 |
+
raydirs_in (N x H x W x 1 x C2 tensor or None): ray direction embeddings.
|
292 |
+
z (N x C3 tensor): Intermediate style vectors.
|
293 |
+
mc_masks_onehot (N x H x W x L x C4): One-hot segmentation maps.
|
294 |
+
Returns:
|
295 |
+
net_out_s (N x H x W x L x 1 tensor): Opacities.
|
296 |
+
net_out_c (N x H x W x L x C5 tensor): Color embeddings.
|
297 |
+
"""
|
298 |
+
_x, _y, _z = self.voxel.voxel_t.shape
|
299 |
+
delimeter = torch.Tensor([_x, _y, _z]).to(worldcoord2)
|
300 |
+
normalized_coord = worldcoord2 / delimeter * 2 - 1
|
301 |
+
global_enc = global_enc[:, None, None, None, :].repeat(1, normalized_coord.shape[1], normalized_coord.shape[2], normalized_coord.shape[3], 1)
|
302 |
+
normalized_coord = torch.cat([normalized_coord, global_enc], dim=-1)
|
303 |
+
feature_in = self.hash_encoder(normalized_coord)
|
304 |
+
|
305 |
+
net_out_s, net_out_c = self.render_net(feature_in, raydirs_in, z, mc_masks_onehot)
|
306 |
+
|
307 |
+
if self.raw_noise_std > 0.:
|
308 |
+
noise = torch.randn_like(net_out_s) * self.raw_noise_std
|
309 |
+
net_out_s = net_out_s + noise
|
310 |
+
|
311 |
+
return net_out_s, net_out_c
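
The sampled world coordinates are normalized by the voxel grid extent to roughly [-1, 1] and concatenated with the per-scene code from the world encoder; since the hash grid was created with input_dim=5 and the coordinates use 3 channels, the global code is assumed to be 2-dimensional. A sketch of that input assembly with illustrative shapes (the hash encoder call itself is omitted):

```python
import torch

grid_extent = torch.tensor([512.0, 256.0, 512.0])        # illustrative voxel grid size
worldcoord = torch.rand(1, 8, 8, 4, 3) * grid_extent      # N x H x W x L x 3 samples
global_enc = torch.randn(1, 2)                            # assumed 2-dim scene code

normalized = worldcoord / grid_extent * 2 - 1             # roughly in [-1, 1]
cond = global_enc[:, None, None, None, :].expand(1, 8, 8, 4, 2)
hash_input = torch.cat([normalized, cond], dim=-1)        # matches input_dim = 5
print(hash_input.shape)                                   # torch.Size([1, 8, 8, 4, 5])
```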
|
312 |
+
|
313 |
+
def _forward_perpix(self, blk_feats, voxel_id, depth2, raydirs, cam_ori_t, z, global_enc):
|
314 |
+
r"""Sample points along rays, forwarding the per-point MLP and aggregate pixel features
|
315 |
+
|
316 |
+
Args:
|
317 |
+
blk_feats (K x C1 tensor): Deprecated
|
318 |
+
voxel_id (N x H x W x M x 1 tensor): Voxel ids from the ray-voxel intersection test. M: number of intersected voxels recorded per ray (bounded by num_blocks_early_stop).
|
319 |
+
depth2 (N x 2 x H x W x M x 1 tensor): Depths of entrance and exit points for each ray-voxel intersection.
|
320 |
+
raydirs (N x H x W x 1 x 3 tensor): The direction of each ray.
|
321 |
+
cam_ori_t (N x 3 tensor): Camera origins.
|
322 |
+
z (N x C3 tensor): Intermediate style vectors.
|
323 |
+
"""
|
324 |
+
# Generate sky_mask; PE transform on ray direction.
|
325 |
+
with torch.no_grad():
|
326 |
+
raydirs_in = raydirs.expand(-1, -1, -1, 1, -1).contiguous()
|
327 |
+
if self.pe_params[2] == 0 and self.pe_params[3] is True:
|
328 |
+
raydirs_in = raydirs_in
|
329 |
+
elif self.pe_params[2] == 0 and self.pe_params[3] is False: # Not using raydir at all
|
330 |
+
raydirs_in = None
|
331 |
+
else:
|
332 |
+
raydirs_in = voxlib.positional_encoding(raydirs_in, self.pe_params[2], -1, self.pe_params[3])
|
333 |
+
|
334 |
+
# sky_mask: when True, ray finally hits sky
|
335 |
+
sky_mask = voxel_id[:, :, :, [-1], :] == 0
|
336 |
+
# sky_only_mask: when True, ray hits nothing but sky
|
337 |
+
sky_only_mask = voxel_id[:, :, :, [0], :] == 0
|
338 |
+
|
339 |
+
with torch.no_grad():
|
340 |
+
# Random sample points along the ray
|
341 |
+
num_samples = self.num_samples + 1
|
342 |
+
if self.sample_use_box_boundaries:
|
343 |
+
num_samples = self.num_samples - self.num_blocks_early_stop
|
344 |
+
|
345 |
+
# 10 samples per ray + 4 intersections - 2
|
346 |
+
rand_depth, new_dists, new_idx = mc_utils.sample_depth_batched(
|
347 |
+
depth2, num_samples, deterministic=self.coarse_deterministic_sampling,
|
348 |
+
use_box_boundaries=self.sample_use_box_boundaries, sample_depth=self.sample_depth)
|
349 |
+
|
350 |
+
nan_mask = torch.isnan(rand_depth)
|
351 |
+
inf_mask = torch.isinf(rand_depth)
|
352 |
+
rand_depth[nan_mask | inf_mask] = 0.0
|
353 |
+
|
354 |
+
worldcoord2 = raydirs * rand_depth + cam_ori_t[:, None, None, None, :]
|
355 |
+
|
356 |
+
# Generate per-sample segmentation label
|
357 |
+
voxel_id_reduced = self.label_trans.mc2reduced(voxel_id, ign2dirt=True)
|
358 |
+
mc_masks = torch.gather(voxel_id_reduced, -2, new_idx) # B 256 256 N 1
|
359 |
+
mc_masks = mc_masks.long()
|
360 |
+
mc_masks_onehot = torch.zeros([mc_masks.size(0), mc_masks.size(1), mc_masks.size(
|
361 |
+
2), mc_masks.size(3), self.num_reduced_labels], dtype=torch.float, device=voxel_id.device)
|
362 |
+
# mc_masks_onehot: [B, H, W, Nlayer, num_reduced_labels]
|
363 |
+
mc_masks_onehot.scatter_(-1, mc_masks, 1.0)
|
364 |
+
|
365 |
+
net_out_s, net_out_c = self._forward_perpix_sub(blk_feats, worldcoord2, raydirs_in, z, mc_masks_onehot, global_enc)
|
366 |
+
|
367 |
+
# Handle sky
|
368 |
+
sky_raydirs_in = raydirs.expand(-1, -1, -1, 1, -1).contiguous()
|
369 |
+
sky_raydirs_in = voxlib.positional_encoding(sky_raydirs_in, self.pe_params_sky[0], -1, self.pe_params_sky[1])
|
370 |
+
skynet_out_c = self.sky_net(sky_raydirs_in, z)
|
371 |
+
|
372 |
+
# Blending
|
373 |
+
weights = mc_utils.volum_rendering_relu(net_out_s, new_dists * self.dists_scale, dim=-2)
|
374 |
+
|
375 |
+
# If a ray exclusively hits the sky (no intersection with the voxels), set its weight to zero.
|
376 |
+
weights = weights * torch.logical_not(sky_only_mask).float()
|
377 |
+
total_weights_raw = torch.sum(weights, dim=-2, keepdim=True) # 256 256 1 1
|
378 |
+
total_weights = total_weights_raw
|
379 |
+
|
380 |
+
is_gnd = worldcoord2[..., [0]] <= 1.0 # Y X Z, [256, 256, 4, 3], nan < 1.0 == False
|
381 |
+
is_gnd = is_gnd.any(dim=-2, keepdim=True)
|
382 |
+
nosky_mask = torch.logical_or(torch.logical_not(sky_mask), is_gnd)
|
383 |
+
nosky_mask = nosky_mask.float()
|
384 |
+
|
385 |
+
# Avoid sky leakage
|
386 |
+
sky_weight = 1.0-total_weights
|
387 |
+
if self.keep_sky_out:
|
388 |
+
# keep_sky_out_avgpool overrides sky_replace_color
|
389 |
+
if self.sky_replace_color is None or self.keep_sky_out_avgpool:
|
390 |
+
if self.keep_sky_out_avgpool:
|
391 |
+
if hasattr(self, 'sky_avg'):
|
392 |
+
sky_avg = self.sky_avg
|
393 |
+
else:
|
394 |
+
if self.sky_global_avgpool:
|
395 |
+
sky_avg = torch.mean(skynet_out_c, dim=[1, 2], keepdim=True)
|
396 |
+
else:
|
397 |
+
skynet_out_c_nchw = skynet_out_c.permute(0, 4, 1, 2, 3).squeeze(-1).contiguous()
|
398 |
+
sky_avg = F.avg_pool2d(skynet_out_c_nchw, 31, stride=1, padding=15, count_include_pad=False)
|
399 |
+
sky_avg = sky_avg.permute(0, 2, 3, 1).unsqueeze(-2).contiguous()
|
400 |
+
# print(sky_avg.shape)
|
401 |
+
skynet_out_c = skynet_out_c * (1.0-nosky_mask) + sky_avg*(nosky_mask)
|
402 |
+
else:
|
403 |
+
sky_weight = sky_weight * (1.0-nosky_mask)
|
404 |
+
else:
|
405 |
+
skynet_out_c = skynet_out_c * (1.0-nosky_mask) + self.sky_replace_color*(nosky_mask)
|
406 |
+
|
407 |
+
if self.clip_feat_map is True: # intermediate feature before blending & CNN
|
408 |
+
rgbs = torch.clamp(net_out_c, -1, 1) + 1
|
409 |
+
rgbs_sky = torch.clamp(skynet_out_c, -1, 1) + 1
|
410 |
+
net_out = torch.sum(weights*rgbs, dim=-2, keepdim=True) + sky_weight * \
|
411 |
+
rgbs_sky # 576, 768, 4, 3 -> 576, 768, 3
|
412 |
+
net_out = net_out.squeeze(-2)
|
413 |
+
net_out = net_out - 1
|
414 |
+
elif self.clip_feat_map is False:
|
415 |
+
rgbs = net_out_c
|
416 |
+
rgbs_sky = skynet_out_c
|
417 |
+
net_out = torch.sum(weights*rgbs, dim=-2, keepdim=True) + sky_weight * \
|
418 |
+
rgbs_sky # 576, 768, 4, 3 -> 576, 768, 3
|
419 |
+
net_out = net_out.squeeze(-2)
|
420 |
+
elif self.clip_feat_map == 'tanh':
|
421 |
+
rgbs = torch.tanh(net_out_c)
|
422 |
+
rgbs_sky = torch.tanh(skynet_out_c)
|
423 |
+
net_out = torch.sum(weights*rgbs, dim=-2, keepdim=True) + sky_weight * \
|
424 |
+
rgbs_sky # 576, 768, 4, 3 -> 576, 768, 3
|
425 |
+
net_out = net_out.squeeze(-2)
|
426 |
+
else:
|
427 |
+
raise NotImplementedError
|
428 |
+
|
429 |
+
return net_out, new_dists, weights, total_weights_raw, rand_depth, net_out_s, net_out_c, skynet_out_c, \
|
430 |
+
nosky_mask, sky_mask, sky_only_mask, new_idx
|
431 |
+
|
432 |
+
def forward(self, data, random_style=False):
|
433 |
+
r"""SceneDreamer forward.
|
434 |
+
"""
|
435 |
+
device = torch.device('cuda')
|
436 |
+
batch_size = data['images'].size(0)
|
437 |
+
# Requires: voxel_id, depth2, raydirs, cam_ori_t.
|
438 |
+
voxel_id, depth2, raydirs, cam_ori_t = data['voxel_id'], data['depth2'], data['raydirs'], data['cam_ori_t']
|
439 |
+
if 'pseudo_real_img' in data:
|
440 |
+
pseudo_real_img = data['pseudo_real_img']
|
441 |
+
|
442 |
+
global_enc = self.world_encoder(self.voxel.current_height_map, self.voxel.current_semantic_map)
|
443 |
+
|
444 |
+
z, mu, logvar = None, None, None
|
445 |
+
if random_style:
|
446 |
+
if self.style_dims > 0:
|
447 |
+
z = torch.randn(batch_size, self.style_dims, dtype=torch.float32, device=device)
|
448 |
+
else:
|
449 |
+
if self.style_encoder is None:
|
450 |
+
# ================ Get Style Code =================
|
451 |
+
if self.style_dims > 0:
|
452 |
+
z = torch.randn(batch_size, self.style_dims, dtype=torch.float32, device=device)
|
453 |
+
else:
|
454 |
+
mu, logvar, z = self.style_encoder(pseudo_real_img)
|
455 |
+
|
456 |
+
# ================ Network Forward ================
|
457 |
+
# Forward StyleNet
|
458 |
+
if self.style_net is not None:
|
459 |
+
z = self.style_net(z)
|
460 |
+
|
461 |
+
# Forward per-pixel net.
|
462 |
+
net_out, new_dists, weights, total_weights_raw, rand_depth, net_out_s, net_out_c, skynet_out_c, nosky_mask, \
|
463 |
+
sky_mask, sky_only_mask, new_idx = self._forward_perpix(
|
464 |
+
self.blk_feats, voxel_id, depth2, raydirs, cam_ori_t, z, global_enc)
|
465 |
+
|
466 |
+
# Forward global net.
|
467 |
+
fake_images, fake_images_raw = self._forward_global(net_out, z)
|
468 |
+
if self.pad != 0:
|
469 |
+
fake_images = fake_images[:, :, self.pad//2:-self.pad//2, self.pad//2:-self.pad//2]
|
470 |
+
|
471 |
+
# =============== Arrange Return Values ================
|
472 |
+
output = {}
|
473 |
+
output['fake_images'] = fake_images
|
474 |
+
output['mu'] = mu
|
475 |
+
output['logvar'] = logvar
|
476 |
+
return output
|
477 |
+
|
478 |
+
|
479 |
+
def inference_givenstyle(self, style,
|
480 |
+
output_dir,
|
481 |
+
camera_mode,
|
482 |
+
style_img_path=None,
|
483 |
+
seed=1,
|
484 |
+
pad=30,
|
485 |
+
num_samples=40,
|
486 |
+
num_blocks_early_stop=6,
|
487 |
+
sample_depth=3,
|
488 |
+
tile_size=128,
|
489 |
+
resolution_hw=[540, 960],
|
490 |
+
cam_ang=72,
|
491 |
+
cam_maxstep=10):
|
492 |
+
r"""Compute result images according to the provided camera trajectory and save the results in the specified
|
493 |
+
folder. The full image is evaluated in multiple tiles to save memory.
|
494 |
+
|
495 |
+
Args:
|
496 |
+
output_dir (str): Where should the results be stored.
|
497 |
+
camera_mode (int): Which camera trajectory to use.
|
498 |
+
style_img_path (str): Path to the style-conditioning image.
|
499 |
+
seed (int): Random seed (controls style when style_image_path is not specified).
|
500 |
+
pad (int): Pixels to remove from the image tiles before stitching. Should be equal to or larger than the
|
501 |
+
receptive field of the CNN to avoid border artifacts.
|
502 |
+
num_samples (int): Number of samples per ray (different from training).
|
503 |
+
num_blocks_early_stop (int): Max number of intersected boxes per ray before stopping
|
504 |
+
(different from training).
|
505 |
+
sample_depth (float): Max distance traveled through boxes before stopping (different from training).
|
506 |
+
tile_size (int): Max size of a tile in pixels.
|
507 |
+
resolution_hw (list [H, W]): Resolution of the output image.
|
508 |
+
cam_ang (float): Horizontal FOV of the camera (may be adjusted by the camera controller).
|
509 |
+
cam_maxstep (int): Number of frames sampled from the camera trajectory.
|
510 |
+
"""
|
511 |
+
|
512 |
+
def write_img(path, img, rgb_input=False):
|
513 |
+
img = ((img*0.5+0.5)*255).detach().cpu().numpy().astype(np.uint8)
|
514 |
+
img = img[0].transpose(1, 2, 0)
|
515 |
+
if rgb_input:
|
516 |
+
img = img[..., [2, 1, 0]]
|
517 |
+
cv2.imwrite(path, img, [cv2.IMWRITE_PNG_COMPRESSION, 4])
|
518 |
+
return img[..., ::-1]
|
519 |
+
|
520 |
+
def read_img(path):
|
521 |
+
img = cv2.imread(path).astype(np.float32)[..., [2, 1, 0]].transpose(2, 0, 1) / 255
|
522 |
+
img = img * 2 - 1
|
523 |
+
img = torch.from_numpy(img)
return img
|
524 |
+
|
525 |
+
print('Saving to', output_dir)
|
526 |
+
|
527 |
+
# Use provided random seed.
|
528 |
+
device = torch.device('cuda')
|
529 |
+
|
530 |
+
global_enc = self.world_encoder(self.voxel.current_height_map, self.voxel.current_semantic_map)
|
531 |
+
|
532 |
+
biome_colors = torch.Tensor([
|
533 |
+
[255, 255, 178],
|
534 |
+
[184, 200, 98],
|
535 |
+
[188, 161, 53],
|
536 |
+
[190, 255, 242],
|
537 |
+
[106, 144, 38],
|
538 |
+
[33, 77, 41],
|
539 |
+
[86, 179, 106],
|
540 |
+
[34, 61, 53],
|
541 |
+
[35, 114, 94],
|
542 |
+
[0, 0, 255],
|
543 |
+
[0, 255, 0],
|
544 |
+
]).to(device) / 255 * 2 - 1
|
545 |
+
semantic_map = torch.argmax(self.voxel.current_semantic_map, dim=1)
|
546 |
+
|
547 |
+
self.pad = pad
|
548 |
+
self.num_samples = num_samples
|
549 |
+
self.num_blocks_early_stop = num_blocks_early_stop
|
550 |
+
self.sample_depth = sample_depth
|
551 |
+
|
552 |
+
self.coarse_deterministic_sampling = True
|
553 |
+
self.crop_size = resolution_hw
|
554 |
+
self.cam_res = [self.crop_size[0]+self.pad, self.crop_size[1]+self.pad]
|
555 |
+
self.use_label_smooth_pgt = False
|
556 |
+
|
557 |
+
# Make output dirs.
|
558 |
+
output_dir = os.path.join(output_dir, 'rgb_render')
|
559 |
+
os.makedirs(output_dir, exist_ok=True)
|
560 |
+
fout = imageio.get_writer(output_dir + '.mp4', fps=10)
|
561 |
+
|
562 |
+
write_img(os.path.join(output_dir, 'semantic_map.png'), biome_colors[semantic_map].permute(0, 3, 1, 2), rgb_input=True)
|
563 |
+
write_img(os.path.join(output_dir, 'height_map.png'), self.voxel.current_height_map)
|
564 |
+
np.save(os.path.join(output_dir, 'style.npy'), style.detach().cpu().numpy())
|
565 |
+
evalcamctl = camctl.EvalCameraController(
|
566 |
+
self.voxel, maxstep=cam_maxstep, pattern=camera_mode, cam_ang=cam_ang,
|
567 |
+
smooth_decay_multiplier=150/cam_maxstep)
|
568 |
+
|
569 |
+
# Get output style.
|
570 |
+
z = self.style_net(style)
|
571 |
+
|
572 |
+
# Generate required output images.
|
573 |
+
for id, (cam_ori_t, cam_dir_t, cam_up_t, cam_f) in enumerate(evalcamctl):
|
574 |
+
print('Rendering frame', id)
|
575 |
+
cam_f = cam_f * (self.crop_size[1]-1) # So that the view does not depend on the padding
|
576 |
+
cam_c = [(self.cam_res[0]-1)/2, (self.cam_res[1]-1)/2]
|
577 |
+
|
578 |
+
voxel_id, depth2, raydirs = voxlib.ray_voxel_intersection_perspective(
|
579 |
+
self.voxel.voxel_t, cam_ori_t, cam_dir_t, cam_up_t, cam_f, cam_c, self.cam_res,
|
580 |
+
self.num_blocks_early_stop)
|
581 |
+
|
582 |
+
voxel_id = voxel_id.unsqueeze(0)
|
583 |
+
depth2 = depth2.unsqueeze(0)
|
584 |
+
raydirs = raydirs.unsqueeze(0)
|
585 |
+
cam_ori_t = cam_ori_t.unsqueeze(0).to(device)
|
586 |
+
|
587 |
+
voxel_id_all = voxel_id
|
588 |
+
depth2_all = depth2
|
589 |
+
raydirs_all = raydirs
|
590 |
+
|
591 |
+
# Evaluate sky in advance to get a consistent sky in the semi-transparent region.
|
592 |
+
if self.sky_global_avgpool:
|
593 |
+
sky_raydirs_in = raydirs.expand(-1, -1, -1, 1, -1).contiguous()
|
594 |
+
sky_raydirs_in = voxlib.positional_encoding(
|
595 |
+
sky_raydirs_in, self.pe_params_sky[0], -1, self.pe_params_sky[1])
|
596 |
+
skynet_out_c = self.sky_net(sky_raydirs_in, z)
|
597 |
+
sky_avg = torch.mean(skynet_out_c, dim=[1, 2], keepdim=True)
|
598 |
+
self.sky_avg = sky_avg
|
599 |
+
|
600 |
+
num_strips_h = (self.cam_res[0]-self.pad+tile_size-1)//tile_size
|
601 |
+
num_strips_w = (self.cam_res[1]-self.pad+tile_size-1)//tile_size
|
602 |
+
|
603 |
+
fake_images_chunks_v = []
|
604 |
+
# For each horizontal strip.
|
605 |
+
for strip_id_h in range(num_strips_h):
|
606 |
+
strip_begin_h = strip_id_h * tile_size
|
607 |
+
strip_end_h = np.minimum(strip_id_h * tile_size + tile_size + self.pad, self.cam_res[0])
|
608 |
+
# For each vertical strip.
|
609 |
+
fake_images_chunks_h = []
|
610 |
+
for strip_id_w in range(num_strips_w):
|
611 |
+
strip_begin_w = strip_id_w * tile_size
|
612 |
+
strip_end_w = np.minimum(strip_id_w * tile_size + tile_size + self.pad, self.cam_res[1])
|
613 |
+
|
614 |
+
voxel_id = voxel_id_all[:, strip_begin_h:strip_end_h, strip_begin_w:strip_end_w, :, :]
|
615 |
+
depth2 = depth2_all[:, :, strip_begin_h:strip_end_h, strip_begin_w:strip_end_w, :, :]
|
616 |
+
raydirs = raydirs_all[:, strip_begin_h:strip_end_h, strip_begin_w:strip_end_w, :, :]
|
617 |
+
|
618 |
+
net_out, new_dists, weights, total_weights_raw, rand_depth, net_out_s, net_out_c, skynet_out_c, \
|
619 |
+
nosky_mask, sky_mask, sky_only_mask, new_idx = self._forward_perpix(
|
620 |
+
self.blk_feats, voxel_id, depth2, raydirs, cam_ori_t, z, global_enc)
|
621 |
+
fake_images, _ = self._forward_global(net_out, z)
|
622 |
+
|
623 |
+
if self.pad != 0:
|
624 |
+
fake_images = fake_images[:, :, self.pad//2:-self.pad//2, self.pad//2:-self.pad//2]
|
625 |
+
fake_images_chunks_h.append(fake_images)
|
626 |
+
fake_images_h = torch.cat(fake_images_chunks_h, dim=-1)
|
627 |
+
fake_images_chunks_v.append(fake_images_h)
|
628 |
+
fake_images = torch.cat(fake_images_chunks_v, dim=-2)
|
629 |
+
rgb = write_img(os.path.join(output_dir,
|
630 |
+
'{:05d}.png'.format(id)), fake_images, rgb_input=True)
|
631 |
+
fout.append_data(rgb)
|
632 |
+
fout.close()
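
The tiled evaluation above covers the padded camera resolution with ceil-divided strips of tile_size pixels; each tile is rendered with pad extra pixels and cropped by pad // 2 on every side before stitching. A short worked example of that arithmetic using the default-looking values from this method's signature:

```python
# Worked example of the tiling arithmetic (values mirror the default arguments).
pad, tile_size = 30, 128
resolution_hw = [540, 960]
cam_res = [resolution_hw[0] + pad, resolution_hw[1] + pad]          # 570 x 990

num_strips_h = (cam_res[0] - pad + tile_size - 1) // tile_size      # ceil(540 / 128) = 5
num_strips_w = (cam_res[1] - pad + tile_size - 1) // tile_size      # ceil(960 / 128) = 8

# Each tile spans up to tile_size + pad pixels; after cropping pad // 2 per side
# the stitched strips reassemble the 540 x 960 frame.
print(num_strips_h, num_strips_w)
```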
|
633 |
+
|
634 |
+
|
635 |
+
|
636 |
+
def inference_givenstyle_depth(self, style,
|
637 |
+
output_dir,
|
638 |
+
camera_mode,
|
639 |
+
style_img_path=None,
|
640 |
+
seed=1,
|
641 |
+
pad=30,
|
642 |
+
num_samples=40,
|
643 |
+
num_blocks_early_stop=6,
|
644 |
+
sample_depth=3,
|
645 |
+
tile_size=128,
|
646 |
+
resolution_hw=[540, 960],
|
647 |
+
cam_ang=72,
|
648 |
+
cam_maxstep=10):
|
649 |
+
r"""Compute result images according to the provided camera trajectory and save the results in the specified
|
650 |
+
folder. The full image is evaluated in multiple tiles to save memory.
|
651 |
+
|
652 |
+
Args:
|
653 |
+
output_dir (str): Where should the results be stored.
|
654 |
+
camera_mode (int): Which camera trajectory to use.
|
655 |
+
style_img_path (str): Path to the style-conditioning image.
|
656 |
+
seed (int): Random seed (controls style when style_image_path is not specified).
|
657 |
+
pad (int): Pixels to remove from the image tiles before stitching. Should be equal to or larger than the
|
658 |
+
receptive field of the CNN to avoid border artifacts.
|
659 |
+
num_samples (int): Number of samples per ray (different from training).
|
660 |
+
num_blocks_early_stop (int): Max number of intersected boxes per ray before stopping
|
661 |
+
(different from training).
|
662 |
+
sample_depth (float): Max distance traveled through boxes before stopping (different from training).
|
663 |
+
tile_size (int): Max size of a tile in pixels.
|
664 |
+
resolution_hw (list [H, W]): Resolution of the output image.
|
665 |
+
cam_ang (float): Horizontal FOV of the camera (may be adjusted by the camera controller).
|
666 |
+
cam_maxstep (int): Number of frames sampled from the camera trajectory.
|
667 |
+
"""
|
668 |
+
|
669 |
+
def write_img(path, img, rgb_input=False):
|
670 |
+
img = ((img*0.5+0.5)*255).detach().cpu().numpy().astype(np.uint8)
|
671 |
+
img = img[0].transpose(1, 2, 0)
|
672 |
+
if rgb_input:
|
673 |
+
img = img[..., [2, 1, 0]]
|
674 |
+
cv2.imwrite(path, img, [cv2.IMWRITE_PNG_COMPRESSION, 4])
|
675 |
+
return img[..., ::-1]
|
676 |
+
|
677 |
+
def read_img(path):
|
678 |
+
img = cv2.imread(path).astype(np.float32)[..., [2, 1, 0]].transpose(2, 0, 1) / 255
|
679 |
+
img = img * 2 - 1
|
680 |
+
img = torch.from_numpy(img)
return img
|
681 |
+
|
682 |
+
print('Saving to', output_dir)
|
683 |
+
|
684 |
+
# Use provided random seed.
|
685 |
+
device = torch.device('cuda')
|
686 |
+
|
687 |
+
global_enc = self.world_encoder(self.voxel.current_height_map, self.voxel.current_semantic_map)
|
688 |
+
|
689 |
+
biome_colors = torch.Tensor([
|
690 |
+
[255, 255, 178],
|
691 |
+
[184, 200, 98],
|
692 |
+
[188, 161, 53],
|
693 |
+
[190, 255, 242],
|
694 |
+
[106, 144, 38],
|
695 |
+
[33, 77, 41],
|
696 |
+
[86, 179, 106],
|
697 |
+
[34, 61, 53],
|
698 |
+
[35, 114, 94],
|
699 |
+
[0, 0, 255],
|
700 |
+
[0, 255, 0],
|
701 |
+
]) / 255 * 2 - 1
|
702 |
+
print(self.voxel.current_height_map[0].shape)
|
703 |
+
semantic_map = torch.argmax(self.voxel.current_semantic_map, dim=1)
|
704 |
+
print(torch.unique(semantic_map, return_counts=True))
|
705 |
+
print(semantic_map.min())
|
706 |
+
|
707 |
+
self.pad = pad
|
708 |
+
self.num_samples = num_samples
|
709 |
+
self.num_blocks_early_stop = num_blocks_early_stop
|
710 |
+
self.sample_depth = sample_depth
|
711 |
+
|
712 |
+
self.coarse_deterministic_sampling = True
|
713 |
+
self.crop_size = resolution_hw
|
714 |
+
self.cam_res = [self.crop_size[0]+self.pad, self.crop_size[1]+self.pad]
|
715 |
+
self.use_label_smooth_pgt = False
|
716 |
+
|
717 |
+
# Make output dirs.
|
718 |
+
gancraft_outputs_dir = os.path.join(output_dir, 'gancraft_outputs')
|
719 |
+
os.makedirs(gancraft_outputs_dir, exist_ok=True)
|
720 |
+
gancraft_depth_outputs_dir = os.path.join(output_dir, 'depth')
|
721 |
+
os.makedirs(gancraft_depth_outputs_dir, exist_ok=True)
|
722 |
+
vis_masks_dir = os.path.join(output_dir, 'vis_masks')
|
723 |
+
os.makedirs(vis_masks_dir, exist_ok=True)
|
724 |
+
fout = imageio.get_writer(gancraft_outputs_dir + '.mp4', fps=10)
|
725 |
+
fout_cat = imageio.get_writer(gancraft_outputs_dir + '-vis_masks.mp4', fps=10)
|
726 |
+
|
727 |
+
write_img(os.path.join(output_dir, 'semantic_map.png'), biome_colors[semantic_map].permute(0, 3, 1, 2), rgb_input=True)
|
728 |
+
write_img(os.path.join(output_dir, 'heightmap.png'), self.voxel.current_height_map)
|
729 |
+
|
730 |
+
evalcamctl = camctl.EvalCameraController(
|
731 |
+
self.voxel, maxstep=cam_maxstep, pattern=camera_mode, cam_ang=cam_ang,
|
732 |
+
smooth_decay_multiplier=150/cam_maxstep)
|
733 |
+
|
734 |
+
# import pickle
|
735 |
+
# with open(os.path.join(output_dir,'camera.pkl'), 'wb') as f:
|
736 |
+
# pickle.dump(evalcamctl, f)
|
737 |
+
|
738 |
+
# Get output style.
|
739 |
+
z = self.style_net(style)
|
740 |
+
|
741 |
+
# Generate required output images.
|
742 |
+
for id, (cam_ori_t, cam_dir_t, cam_up_t, cam_f) in enumerate(evalcamctl):
|
743 |
+
# print('Rendering frame', id)
|
744 |
+
cam_f = cam_f * (self.crop_size[1]-1) # So that the view is not depending on the padding
|
745 |
+
cam_c = [(self.cam_res[0]-1)/2, (self.cam_res[1]-1)/2]
|
746 |
+
|
747 |
+
voxel_id, depth2, raydirs = voxlib.ray_voxel_intersection_perspective(
|
748 |
+
self.voxel.voxel_t, cam_ori_t, cam_dir_t, cam_up_t, cam_f, cam_c, self.cam_res,
|
749 |
+
self.num_blocks_early_stop)
|
750 |
+
|
751 |
+
voxel_id = voxel_id.unsqueeze(0)
|
752 |
+
depth2 = depth2.unsqueeze(0)
|
753 |
+
raydirs = raydirs.unsqueeze(0)
|
754 |
+
cam_ori_t = cam_ori_t.unsqueeze(0).to(device)
|
755 |
+
|
756 |
+
# Save 3D voxel rendering.
|
757 |
+
mc_rgb = self.label_trans.mc_color(voxel_id[0, :, :, 0, 0].cpu().numpy())
|
758 |
+
# Diffused shading, co-located light.
|
759 |
+
first_intersection_depth = depth2[:, 0, :, :, 0, None, :] # [1, 542, 542, 1, 1].
|
760 |
+
first_intersection_point = raydirs * first_intersection_depth + cam_ori_t[:, None, None, None, :]
|
761 |
+
fip_local_coords = torch.remainder(first_intersection_point, 1.0)
|
762 |
+
fip_wall_proximity = torch.minimum(fip_local_coords, 1.0-fip_local_coords)
|
763 |
+
fip_wall_orientation = torch.argmin(fip_wall_proximity, dim=-1, keepdim=False)
|
764 |
+
# 0: [1,0,0]; 1: [0,1,0]; 2: [0,0,1]
|
765 |
+
lut = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=torch.float32,
|
766 |
+
device=fip_wall_orientation.device)
|
767 |
+
fip_normal = lut[fip_wall_orientation] # [1, 542, 542, 1, 3]
|
768 |
+
diffuse_shade = torch.abs(torch.sum(fip_normal * raydirs, dim=-1))
|
769 |
+
|
770 |
+
mc_rgb = (mc_rgb.astype(np.float) / 255) ** 2.2
|
771 |
+
mc_rgb = mc_rgb * diffuse_shade[0, :, :, :].cpu().numpy()
|
772 |
+
mc_rgb = (mc_rgb ** (1/2.2)) * 255
|
773 |
+
mc_rgb = mc_rgb.astype(np.uint8)
|
774 |
+
if self.pad > 0:
|
775 |
+
mc_rgb = mc_rgb[self.pad//2:-self.pad//2, self.pad//2:-self.pad//2]
|
776 |
+
cv2.imwrite(os.path.join(vis_masks_dir, '{:05d}.png'.format(id)), mc_rgb, [cv2.IMWRITE_PNG_COMPRESSION, 4])
|
777 |
+
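The shading step above recovers an axis-aligned face normal purely from the first ray-voxel hit: because voxels sit on an integer lattice, the fractional part of the intersection point is closest to a face along exactly one axis, and |n·d| then gives a head-light diffuse term. A minimal, self-contained sketch of the same idea with synthetic inputs (not the renderer's actual tensors):

```python
import torch

# Synthetic stand-ins: H x W unit ray directions and hit points in voxel units.
H, W = 4, 4
raydirs = torch.nn.functional.normalize(torch.randn(H, W, 3), dim=-1)
hit_points = torch.rand(H, W, 3) * 10.0           # points on the voxel lattice

frac = torch.remainder(hit_points, 1.0)            # position inside the hit voxel, in [0, 1)
wall_proximity = torch.minimum(frac, 1.0 - frac)   # distance to the nearest face along each axis
face_axis = torch.argmin(wall_proximity, dim=-1)   # axis whose face was hit: 0=x, 1=y, 2=z

lut = torch.eye(3)                                 # axis index -> axis-aligned normal
normal = lut[face_axis]                            # (H, W, 3)

# Head-light (co-located light) Lambertian term; abs() ignores the face sign.
diffuse_shade = torch.abs((normal * raydirs).sum(dim=-1))   # (H, W), values in [0, 1]
print(diffuse_shade.shape, float(diffuse_shade.max()))
```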
+            # Tiled eval of GANcraft.
+            voxel_id_all = voxel_id
+            depth2_all = depth2
+            raydirs_all = raydirs
+
+            # Evaluate sky in advance to get a consistent sky in the semi-transparent region.
+            if self.sky_global_avgpool:
+                sky_raydirs_in = raydirs.expand(-1, -1, -1, 1, -1).contiguous()
+                sky_raydirs_in = voxlib.positional_encoding(
+                    sky_raydirs_in, self.pe_params_sky[0], -1, self.pe_params_sky[1])
+                skynet_out_c = self.sky_net(sky_raydirs_in, z)
+                sky_avg = torch.mean(skynet_out_c, dim=[1, 2], keepdim=True)
+                self.sky_avg = sky_avg
+
+            num_strips_h = (self.cam_res[0] - self.pad + tile_size - 1) // tile_size
+            num_strips_w = (self.cam_res[1] - self.pad + tile_size - 1) // tile_size
+
+            fake_images_chunks_v = []
+            fake_depth_chunks_v = []
+            # For each horizontal strip.
+            for strip_id_h in range(num_strips_h):
+                strip_begin_h = strip_id_h * tile_size
+                strip_end_h = np.minimum(strip_id_h * tile_size + tile_size + self.pad, self.cam_res[0])
+                # For each vertical strip.
+                fake_images_chunks_h = []
+                fake_depth_chunks_h = []
+                for strip_id_w in range(num_strips_w):
+                    strip_begin_w = strip_id_w * tile_size
+                    strip_end_w = np.minimum(strip_id_w * tile_size + tile_size + self.pad, self.cam_res[1])
+
+                    voxel_id = voxel_id_all[:, strip_begin_h:strip_end_h, strip_begin_w:strip_end_w, :, :]
+                    depth2 = depth2_all[:, :, strip_begin_h:strip_end_h, strip_begin_w:strip_end_w, :, :]
+                    raydirs = raydirs_all[:, strip_begin_h:strip_end_h, strip_begin_w:strip_end_w, :, :]
+
+                    net_out, new_dists, weights, total_weights_raw, rand_depth, net_out_s, net_out_c, skynet_out_c, \
+                        nosky_mask, sky_mask, sky_only_mask, new_idx = self._forward_perpix(
+                            self.blk_feats, voxel_id, depth2, raydirs, cam_ori_t, z, global_enc)
+                    fake_images, _ = self._forward_global(net_out, z)
+                    depth_map = torch.sum(weights * rand_depth, -2)
+                    # disp_map = 1. / torch.max(1e-10 * torch.ones_like(depth_map).to(depth_map), depth_map / torch.sum(weights, -2))
+                    # depth_map = torch.clip(depth_map, 0, 100.)
+                    # disp_map = 1. / (depth_map.permute(0, 3, 1, 2))
+                    disp_map = depth_map.permute(0, 3, 1, 2)
+                    if self.pad != 0:
+                        fake_images = fake_images[:, :, self.pad // 2:-self.pad // 2, self.pad // 2:-self.pad // 2]
+                        disp_map = disp_map[:, :, self.pad // 2:-self.pad // 2, self.pad // 2:-self.pad // 2]
+                    fake_images_chunks_h.append(fake_images)
+                    fake_depth_chunks_h.append(disp_map)
+                fake_images_h = torch.cat(fake_images_chunks_h, dim=-1)
+                fake_depth_h = torch.cat(fake_depth_chunks_h, dim=-1)
+                fake_images_chunks_v.append(fake_images_h)
+                fake_depth_chunks_v.append(fake_depth_h)
+            fake_images = torch.cat(fake_images_chunks_v, dim=-2)
+            fake_depth = torch.cat(fake_depth_chunks_v, dim=-2)
+            # fake_depth = ((fake_depth - fake_depth.mean()) / fake_depth.std() + 1) / 2
+            # fake_depth = torch.clip(1. / (fake_depth + 1e-4), 0., 1.)
+            # fake_depth = ((fake_depth - fake_depth.mean()) / fake_depth.std() + 1) / 2
+            mmask = fake_depth > 0
+            tmp = fake_depth[mmask]
+            # tmp = 1. / (tmp + 1e-4)
+            tmp = (tmp - tmp.min()) / (tmp.max() - tmp.min())
+            # tmp = ((tmp - tmp.mean()) / tmp.std() + 1) / 2.
+            fake_depth[~mmask] = 1
+            fake_depth[mmask] = tmp
+            # fake_depth = (fake_depth - fake_depth.min()) / (fake_depth.max() - fake_depth.min())
+
+            cv2.imwrite(os.path.join(gancraft_depth_outputs_dir, '{:05d}.png'.format(id)), fake_depth[0].permute(1, 2, 0).detach().cpu().numpy() * 255)
+            rgb = write_img(os.path.join(gancraft_outputs_dir,
+                                         '{:05d}.png'.format(id)), fake_images, rgb_input=True)
+            fout.append_data(rgb)
+            fout_cat.append_data(np.concatenate((mc_rgb[..., ::-1], rgb), axis=1))
+        fout.close()
+        fout_cat.close()
+
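The tile loop above renders overlapping crops of size tile_size + pad and trims pad // 2 pixels from every border before concatenating along width and then height, so each output pixel is produced with full receptive-field context and seams never fall inside a tile's unpadded interior. A minimal sketch of that stitching arithmetic with a hypothetical render_tile stand-in (here simply a crop of a precomputed canvas), assuming pad is even and the output resolution is a multiple of tile_size:

```python
import torch

def stitch_tiled(full_res_hw, tile_size, pad, render_tile):
    """Render an (H + pad, W + pad) canvas in overlapping tiles and stitch the trimmed interiors."""
    cam_h, cam_w = full_res_hw[0] + pad, full_res_hw[1] + pad
    rows = []
    for y0 in range(0, cam_h - pad, tile_size):
        cols = []
        for x0 in range(0, cam_w - pad, tile_size):
            y1 = min(y0 + tile_size + pad, cam_h)
            x1 = min(x0 + tile_size + pad, cam_w)
            tile = render_tile(y0, y1, x0, x1)                      # (C, y1-y0, x1-x0)
            tile = tile[:, pad // 2:-pad // 2, pad // 2:-pad // 2]  # trim the overlap border
            cols.append(tile)
        rows.append(torch.cat(cols, dim=-1))
    return torch.cat(rows, dim=-2)                                  # (C, H, W)

# Toy check: "rendering" is just cropping a fixed canvas, so stitching must reproduce its interior.
pad, tile_size, H, W = 4, 16, 32, 48
canvas = torch.arange((H + pad) * (W + pad), dtype=torch.float32).reshape(1, H + pad, W + pad)
out = stitch_tiled((H, W), tile_size, pad, lambda y0, y1, x0, x1: canvas[:, y0:y1, x0:x1])
assert out.shape == (1, H, W)
assert torch.equal(out, canvas[:, pad // 2:H + pad // 2, pad // 2:W + pad // 2])
```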
imaginaire/generators/spade.py
ADDED
@@ -0,0 +1,571 @@
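For context on the file below: the SPADE blocks it builds normalize activations and then re-modulate them with a per-pixel scale and shift predicted from the (resized) segmentation map, i.e. out = norm(x) * (1 + gamma(seg)) + beta(seg). A minimal sketch of that operation, written in plain PyTorch and independent of the imaginaire layer classes:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinySPADE(nn.Module):
    """Parameter-free normalization followed by spatially adaptive modulation."""
    def __init__(self, num_features, num_labels, hidden=64):
        super().__init__()
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        self.shared = nn.Sequential(nn.Conv2d(num_labels, hidden, 3, padding=1), nn.ReLU())
        self.gamma = nn.Conv2d(hidden, num_features, 3, padding=1)
        self.beta = nn.Conv2d(hidden, num_features, 3, padding=1)

    def forward(self, x, seg):
        seg = F.interpolate(seg, size=x.shape[2:], mode='nearest')  # match the feature resolution
        h = self.shared(seg)
        return self.norm(x) * (1 + self.gamma(h)) + self.beta(h)

x = torch.randn(2, 32, 16, 16)            # feature map
seg = torch.randn(2, 12, 64, 64)          # label map at image resolution
print(TinySPADE(32, 12)(x, seg).shape)    # torch.Size([2, 32, 16, 16])
```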
1 |
+
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# This work is made available under the Nvidia Source Code License-NC.
|
4 |
+
# To view a copy of this license, check out LICENSE.md
|
5 |
+
import functools
|
6 |
+
import math
|
7 |
+
import types
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
import torch.nn.functional as F
|
13 |
+
from torch.nn import Upsample as NearestUpsample
|
14 |
+
|
15 |
+
from imaginaire.layers import Conv2dBlock, LinearBlock, Res2dBlock
|
16 |
+
from imaginaire.utils.data import (get_crop_h_w,
|
17 |
+
get_paired_input_image_channel_number,
|
18 |
+
get_paired_input_label_channel_number)
|
19 |
+
from imaginaire.utils.distributed import master_only_print as print
|
20 |
+
|
21 |
+
|
22 |
+
class Generator(nn.Module):
|
23 |
+
r"""SPADE generator constructor.
|
24 |
+
|
25 |
+
Args:
|
26 |
+
gen_cfg (obj): Generator definition part of the yaml config file.
|
27 |
+
data_cfg (obj): Data definition part of the yaml config file.
|
28 |
+
"""
|
29 |
+
|
30 |
+
def __init__(self, gen_cfg, data_cfg):
|
31 |
+
super(Generator, self).__init__()
|
32 |
+
print('SPADE generator initialization.')
|
33 |
+
# We assume the first datum is the ground truth image.
|
34 |
+
image_channels = getattr(gen_cfg, 'image_channels', None)
|
35 |
+
if image_channels is None:
|
36 |
+
image_channels = get_paired_input_image_channel_number(data_cfg)
|
37 |
+
num_labels = getattr(gen_cfg, 'num_labels', None)
|
38 |
+
if num_labels is None:
|
39 |
+
# Calculate number of channels in the input label when not specified.
|
40 |
+
num_labels = get_paired_input_label_channel_number(data_cfg)
|
41 |
+
crop_h, crop_w = get_crop_h_w(data_cfg.train.augmentations)
|
42 |
+
# Build the generator
|
43 |
+
out_image_small_side_size = crop_w if crop_w < crop_h else crop_h
|
44 |
+
num_filters = getattr(gen_cfg, 'num_filters', 128)
|
45 |
+
kernel_size = getattr(gen_cfg, 'kernel_size', 3)
|
46 |
+
weight_norm_type = getattr(gen_cfg, 'weight_norm_type', 'spectral')
|
47 |
+
|
48 |
+
cond_dims = 0
|
49 |
+
# Check whether we use the style code.
|
50 |
+
style_dims = getattr(gen_cfg, 'style_dims', None)
|
51 |
+
self.style_dims = style_dims
|
52 |
+
if style_dims is not None:
|
53 |
+
print('\tStyle code dimensions: %d' % style_dims)
|
54 |
+
cond_dims += style_dims
|
55 |
+
self.use_style = True
|
56 |
+
else:
|
57 |
+
self.use_style = False
|
58 |
+
# Check whether we use the attribute code.
|
59 |
+
if hasattr(gen_cfg, 'attribute_dims'):
|
60 |
+
self.use_attribute = True
|
61 |
+
self.attribute_dims = gen_cfg.attribute_dims
|
62 |
+
cond_dims += gen_cfg.attribute_dims
|
63 |
+
else:
|
64 |
+
self.use_attribute = False
|
65 |
+
|
66 |
+
if not self.use_style and not self.use_attribute:
|
67 |
+
self.use_style_encoder = False
|
68 |
+
else:
|
69 |
+
self.use_style_encoder = True
|
70 |
+
print('\tBase filter number: %d' % num_filters)
|
71 |
+
print('\tConvolution kernel size: %d' % kernel_size)
|
72 |
+
print('\tWeight norm type: %s' % weight_norm_type)
|
73 |
+
skip_activation_norm = \
|
74 |
+
getattr(gen_cfg, 'skip_activation_norm', True)
|
75 |
+
activation_norm_params = getattr(gen_cfg, 'activation_norm_params', None)
|
76 |
+
if activation_norm_params is None:
|
77 |
+
activation_norm_params = types.SimpleNamespace()
|
78 |
+
if not hasattr(activation_norm_params, 'num_filters'):
|
79 |
+
setattr(activation_norm_params, 'num_filters', 128)
|
80 |
+
if not hasattr(activation_norm_params, 'kernel_size'):
|
81 |
+
setattr(activation_norm_params, 'kernel_size', 3)
|
82 |
+
if not hasattr(activation_norm_params, 'activation_norm_type'):
|
83 |
+
setattr(activation_norm_params, 'activation_norm_type', 'sync_batch')
|
84 |
+
if not hasattr(activation_norm_params, 'separate_projection'):
|
85 |
+
setattr(activation_norm_params, 'separate_projection', False)
|
86 |
+
if not hasattr(activation_norm_params, 'activation_norm_params'):
|
87 |
+
activation_norm_params.activation_norm_params = types.SimpleNamespace()
|
88 |
+
activation_norm_params.activation_norm_params.affine = True
|
89 |
+
setattr(activation_norm_params, 'cond_dims', num_labels)
|
90 |
+
if not hasattr(activation_norm_params, 'weight_norm_type'):
|
91 |
+
setattr(activation_norm_params, 'weight_norm_type', weight_norm_type)
|
92 |
+
global_adaptive_norm_type = getattr(gen_cfg, 'global_adaptive_norm_type', 'sync_batch')
|
93 |
+
use_posenc_in_input_layer = getattr(gen_cfg, 'use_posenc_in_input_layer', True)
|
94 |
+
output_multiplier = getattr(gen_cfg, 'output_multiplier', 1.0)
|
95 |
+
print(activation_norm_params)
|
96 |
+
self.spade_generator = SPADEGenerator(num_labels,
|
97 |
+
out_image_small_side_size,
|
98 |
+
image_channels,
|
99 |
+
num_filters,
|
100 |
+
kernel_size,
|
101 |
+
cond_dims,
|
102 |
+
activation_norm_params,
|
103 |
+
weight_norm_type,
|
104 |
+
global_adaptive_norm_type,
|
105 |
+
skip_activation_norm,
|
106 |
+
use_posenc_in_input_layer,
|
107 |
+
self.use_style_encoder,
|
108 |
+
output_multiplier)
|
109 |
+
if self.use_style:
|
110 |
+
# Build the encoder.
|
111 |
+
style_enc_cfg = getattr(gen_cfg, 'style_enc', None)
|
112 |
+
if style_enc_cfg is None:
|
113 |
+
style_enc_cfg = types.SimpleNamespace()
|
114 |
+
if not hasattr(style_enc_cfg, 'num_filters'):
|
115 |
+
setattr(style_enc_cfg, 'num_filters', 128)
|
116 |
+
if not hasattr(style_enc_cfg, 'kernel_size'):
|
117 |
+
setattr(style_enc_cfg, 'kernel_size', 3)
|
118 |
+
if not hasattr(style_enc_cfg, 'weight_norm_type'):
|
119 |
+
setattr(style_enc_cfg, 'weight_norm_type', weight_norm_type)
|
120 |
+
setattr(style_enc_cfg, 'input_image_channels', image_channels)
|
121 |
+
setattr(style_enc_cfg, 'style_dims', style_dims)
|
122 |
+
self.style_encoder = StyleEncoder(style_enc_cfg)
|
123 |
+
|
124 |
+
self.z = None
|
125 |
+
print('Done with the SPADE generator initialization.')
|
126 |
+
|
127 |
+
def forward(self, data, random_style=False):
|
128 |
+
r"""SPADE Generator forward.
|
129 |
+
|
130 |
+
Args:
|
131 |
+
data (dict):
|
132 |
+
- images (N x C1 x H x W tensor) : Ground truth images
|
133 |
+
- label (N x C2 x H x W tensor) : Semantic representations
|
134 |
+
- z (N x style_dims tensor): Gaussian random noise
|
135 |
+
- random_style (bool): Whether to sample a random style vector.
|
136 |
+
Returns:
|
137 |
+
(dict):
|
138 |
+
- fake_images (N x 3 x H x W tensor): fake images
|
139 |
+
- mu (N x C1 tensor): mean vectors
|
140 |
+
- logvar (N x C1 tensor): log-variance vectors
|
141 |
+
"""
|
142 |
+
if self.use_style_encoder:
|
143 |
+
if random_style:
|
144 |
+
bs = data['label'].size(0)
|
145 |
+
z = torch.randn(
|
146 |
+
bs, self.style_dims, dtype=torch.float32).cuda()
|
147 |
+
if data['label'].dtype == torch.float16:
|
149 |
+
z = z.half()
|
150 |
+
mu = None
|
151 |
+
logvar = None
|
152 |
+
else:
|
153 |
+
mu, logvar, z = self.style_encoder(data['images'])
|
154 |
+
if self.use_attribute:
|
155 |
+
data['z'] = torch.cat((z, data['attributes'].squeeze(1)), dim=1)
|
156 |
+
else:
|
157 |
+
data['z'] = z
|
158 |
+
output = self.spade_generator(data)
|
159 |
+
if self.use_style_encoder:
|
160 |
+
output['mu'] = mu
|
161 |
+
output['logvar'] = logvar
|
162 |
+
return output
|
163 |
+
|
164 |
+
def inference(self,
|
165 |
+
data,
|
166 |
+
random_style=False,
|
167 |
+
use_fixed_random_style=False,
|
168 |
+
keep_original_size=False):
|
169 |
+
r"""Compute results images for a batch of input data and save the
|
170 |
+
results in the specified folder.
|
171 |
+
|
172 |
+
Args:
|
173 |
+
data (dict):
|
174 |
+
- images (N x C1 x H x W tensor) : Ground truth images
|
175 |
+
- label (N x C2 x H x W tensor) : Semantic representations
|
176 |
+
- z (N x style_dims tensor): Gaussian random noise
|
177 |
+
random_style (bool): Whether to sample a random style vector.
|
178 |
+
use_fixed_random_style (bool): Sample random style once and use it
|
179 |
+
for all the remaining inference.
|
180 |
+
keep_original_size (bool): Keep original size of the input.
|
181 |
+
Returns:
|
182 |
+
(dict):
|
183 |
+
- fake_images (N x 3 x H x W tensor): fake images
|
184 |
+
- mu (N x C1 tensor): mean vectors
|
185 |
+
- logvar (N x C1 tensor): log-variance vectors
|
186 |
+
"""
|
187 |
+
self.eval()
|
188 |
+
self.spade_generator.eval()
|
189 |
+
|
190 |
+
if self.use_style_encoder:
|
191 |
+
if random_style and self.use_style_encoder:
|
192 |
+
if self.z is None or not use_fixed_random_style:
|
193 |
+
bs = data['label'].size(0)
|
194 |
+
z = torch.randn(
|
195 |
+
bs, self.style_dims, dtype=torch.float32).to('cuda')
|
196 |
+
if data['label'].dtype == torch.float16:
|
199 |
+
z = z.half()
|
200 |
+
self.z = z
|
201 |
+
else:
|
202 |
+
z = self.z
|
203 |
+
else:
|
204 |
+
mu, logvar, z = self.style_encoder(data['images'])
|
205 |
+
data['z'] = z
|
206 |
+
|
207 |
+
output = self.spade_generator(data)
|
208 |
+
output_images = output['fake_images']
|
209 |
+
|
210 |
+
if keep_original_size:
|
211 |
+
height = data['original_h_w'][0][0]
|
212 |
+
width = data['original_h_w'][0][1]
|
213 |
+
output_images = torch.nn.functional.interpolate(
|
214 |
+
output_images, size=[height, width])
|
215 |
+
|
216 |
+
for key in data['key'].keys():
|
217 |
+
if 'segmaps' in key or 'seg_maps' in key:
|
218 |
+
file_names = data['key'][key][0]
|
219 |
+
break
|
220 |
+
for key in data['key'].keys():
|
221 |
+
if 'edgemaps' in key or 'edge_maps' in key:
|
222 |
+
file_names = data['key'][key][0]
|
223 |
+
break
|
224 |
+
|
225 |
+
return output_images, file_names
|
226 |
+
|
227 |
+
|
228 |
+
class SPADEGenerator(nn.Module):
|
229 |
+
r"""SPADE Image Generator constructor.
|
230 |
+
|
231 |
+
Args:
|
232 |
+
num_labels (int): Number of different labels.
|
233 |
+
out_image_small_side_size (int): min(width, height)
|
234 |
+
image_channels (int): Num. of channels of the output image.
|
235 |
+
num_filters (int): Base filter numbers.
|
236 |
+
kernel_size (int): Convolution kernel size.
|
237 |
+
style_dims (int): Dimensions of the style code.
|
238 |
+
activation_norm_params (obj): Spatially adaptive normalization param.
|
239 |
+
weight_norm_type (str): Type of weight normalization.
|
240 |
+
``'none'``, ``'spectral'``, or ``'weight'``.
|
241 |
+
global_adaptive_norm_type (str): Type of normalization in SPADE.
|
242 |
+
skip_activation_norm (bool): If ``True``, applies activation norm to the
|
243 |
+
shortcut connection in residual blocks.
|
244 |
+
use_style_encoder (bool): Whether to use global adaptive norm
|
245 |
+
like conditional batch norm or adaptive instance norm.
|
246 |
+
output_multiplier (float): A positive number multiplied to the output
|
247 |
+
"""
|
248 |
+
|
249 |
+
def __init__(self,
|
250 |
+
num_labels,
|
251 |
+
out_image_small_side_size,
|
252 |
+
image_channels,
|
253 |
+
num_filters,
|
254 |
+
kernel_size,
|
255 |
+
style_dims,
|
256 |
+
activation_norm_params,
|
257 |
+
weight_norm_type,
|
258 |
+
global_adaptive_norm_type,
|
259 |
+
skip_activation_norm,
|
260 |
+
use_posenc_in_input_layer,
|
261 |
+
use_style_encoder,
|
262 |
+
output_multiplier):
|
263 |
+
super(SPADEGenerator, self).__init__()
|
264 |
+
self.output_multiplier = output_multiplier
|
265 |
+
self.use_style_encoder = use_style_encoder
|
266 |
+
self.use_posenc_in_input_layer = use_posenc_in_input_layer
|
267 |
+
self.out_image_small_side_size = out_image_small_side_size
|
268 |
+
self.num_filters = num_filters
|
269 |
+
padding = int(np.ceil((kernel_size - 1.0) / 2))
|
270 |
+
nonlinearity = 'leakyrelu'
|
271 |
+
activation_norm_type = 'spatially_adaptive'
|
272 |
+
base_res2d_block = \
|
273 |
+
functools.partial(Res2dBlock,
|
274 |
+
kernel_size=kernel_size,
|
275 |
+
padding=padding,
|
276 |
+
bias=[True, True, False],
|
277 |
+
weight_norm_type=weight_norm_type,
|
278 |
+
activation_norm_type=activation_norm_type,
|
279 |
+
activation_norm_params=activation_norm_params,
|
280 |
+
skip_activation_norm=skip_activation_norm,
|
281 |
+
nonlinearity=nonlinearity,
|
282 |
+
order='NACNAC')
|
283 |
+
if self.use_style_encoder:
|
284 |
+
self.fc_0 = LinearBlock(style_dims, 2 * style_dims,
|
285 |
+
weight_norm_type=weight_norm_type,
|
286 |
+
nonlinearity='relu',
|
287 |
+
order='CAN')
|
288 |
+
self.fc_1 = LinearBlock(2 * style_dims, 2 * style_dims,
|
289 |
+
weight_norm_type=weight_norm_type,
|
290 |
+
nonlinearity='relu',
|
291 |
+
order='CAN')
|
292 |
+
|
293 |
+
adaptive_norm_params = types.SimpleNamespace()
|
294 |
+
if not hasattr(adaptive_norm_params, 'cond_dims'):
|
295 |
+
setattr(adaptive_norm_params, 'cond_dims', 2 * style_dims)
|
296 |
+
if not hasattr(adaptive_norm_params, 'activation_norm_type'):
|
297 |
+
setattr(adaptive_norm_params, 'activation_norm_type', global_adaptive_norm_type)
|
298 |
+
if not hasattr(adaptive_norm_params, 'weight_norm_type'):
|
299 |
+
setattr(adaptive_norm_params, 'weight_norm_type', activation_norm_params.weight_norm_type)
|
300 |
+
if not hasattr(adaptive_norm_params, 'separate_projection'):
|
301 |
+
setattr(adaptive_norm_params, 'separate_projection', activation_norm_params.separate_projection)
|
302 |
+
adaptive_norm_params.activation_norm_params = types.SimpleNamespace()
|
303 |
+
setattr(adaptive_norm_params.activation_norm_params, 'affine',
|
304 |
+
activation_norm_params.activation_norm_params.affine)
|
305 |
+
base_cbn2d_block = \
|
306 |
+
functools.partial(Conv2dBlock,
|
307 |
+
kernel_size=kernel_size,
|
308 |
+
stride=1,
|
309 |
+
padding=padding,
|
310 |
+
bias=True,
|
311 |
+
weight_norm_type=weight_norm_type,
|
312 |
+
activation_norm_type='adaptive',
|
313 |
+
activation_norm_params=adaptive_norm_params,
|
314 |
+
nonlinearity=nonlinearity,
|
315 |
+
order='NAC')
|
316 |
+
else:
|
317 |
+
base_conv2d_block = \
|
318 |
+
functools.partial(Conv2dBlock,
|
319 |
+
kernel_size=kernel_size,
|
320 |
+
stride=1,
|
321 |
+
padding=padding,
|
322 |
+
bias=True,
|
323 |
+
weight_norm_type=weight_norm_type,
|
324 |
+
nonlinearity=nonlinearity,
|
325 |
+
order='NAC')
|
326 |
+
in_num_labels = num_labels
|
327 |
+
in_num_labels += 2 if self.use_posenc_in_input_layer else 0
|
328 |
+
self.head_0 = Conv2dBlock(in_num_labels, 8 * num_filters,
|
329 |
+
kernel_size=kernel_size, stride=1,
|
330 |
+
padding=padding,
|
331 |
+
weight_norm_type=weight_norm_type,
|
332 |
+
activation_norm_type='none',
|
333 |
+
nonlinearity=nonlinearity)
|
334 |
+
if self.use_style_encoder:
|
335 |
+
self.cbn_head_0 = base_cbn2d_block(
|
336 |
+
8 * num_filters, 16 * num_filters)
|
337 |
+
else:
|
338 |
+
self.conv_head_0 = base_conv2d_block(
|
339 |
+
8 * num_filters, 16 * num_filters)
|
340 |
+
self.head_1 = base_res2d_block(16 * num_filters, 16 * num_filters)
|
341 |
+
self.head_2 = base_res2d_block(16 * num_filters, 16 * num_filters)
|
342 |
+
|
343 |
+
self.up_0a = base_res2d_block(16 * num_filters, 8 * num_filters)
|
344 |
+
if self.use_style_encoder:
|
345 |
+
self.cbn_up_0a = base_cbn2d_block(
|
346 |
+
8 * num_filters, 8 * num_filters)
|
347 |
+
else:
|
348 |
+
self.conv_up_0a = base_conv2d_block(
|
349 |
+
8 * num_filters, 8 * num_filters)
|
350 |
+
self.up_0b = base_res2d_block(8 * num_filters, 8 * num_filters)
|
351 |
+
|
352 |
+
self.up_1a = base_res2d_block(8 * num_filters, 4 * num_filters)
|
353 |
+
if self.use_style_encoder:
|
354 |
+
self.cbn_up_1a = base_cbn2d_block(
|
355 |
+
4 * num_filters, 4 * num_filters)
|
356 |
+
else:
|
357 |
+
self.conv_up_1a = base_conv2d_block(
|
358 |
+
4 * num_filters, 4 * num_filters)
|
359 |
+
self.up_1b = base_res2d_block(4 * num_filters, 4 * num_filters)
|
360 |
+
self.up_2a = base_res2d_block(4 * num_filters, 4 * num_filters)
|
361 |
+
if self.use_style_encoder:
|
362 |
+
self.cbn_up_2a = base_cbn2d_block(
|
363 |
+
4 * num_filters, 4 * num_filters)
|
364 |
+
else:
|
365 |
+
self.conv_up_2a = base_conv2d_block(
|
366 |
+
4 * num_filters, 4 * num_filters)
|
367 |
+
self.up_2b = base_res2d_block(4 * num_filters, 2 * num_filters)
|
368 |
+
self.conv_img256 = Conv2dBlock(2 * num_filters, image_channels,
|
369 |
+
5, stride=1, padding=2,
|
370 |
+
weight_norm_type=weight_norm_type,
|
371 |
+
activation_norm_type='none',
|
372 |
+
nonlinearity=nonlinearity,
|
373 |
+
order='ANC')
|
374 |
+
self.base = 16
|
375 |
+
if self.out_image_small_side_size == 512:
|
376 |
+
self.up_3a = base_res2d_block(2 * num_filters, 1 * num_filters)
|
377 |
+
self.up_3b = base_res2d_block(1 * num_filters, 1 * num_filters)
|
378 |
+
self.conv_img512 = Conv2dBlock(1 * num_filters, image_channels,
|
379 |
+
5, stride=1, padding=2,
|
380 |
+
weight_norm_type=weight_norm_type,
|
381 |
+
activation_norm_type='none',
|
382 |
+
nonlinearity=nonlinearity,
|
383 |
+
order='ANC')
|
384 |
+
self.base = 32
|
385 |
+
if self.out_image_small_side_size == 1024:
|
386 |
+
self.up_3a = base_res2d_block(2 * num_filters, 1 * num_filters)
|
387 |
+
self.up_3b = base_res2d_block(1 * num_filters, 1 * num_filters)
|
388 |
+
self.conv_img512 = Conv2dBlock(1 * num_filters, image_channels,
|
389 |
+
5, stride=1, padding=2,
|
390 |
+
weight_norm_type=weight_norm_type,
|
391 |
+
activation_norm_type='none',
|
392 |
+
nonlinearity=nonlinearity,
|
393 |
+
order='ANC')
|
394 |
+
self.up_4a = base_res2d_block(num_filters, num_filters // 2)
|
395 |
+
self.up_4b = base_res2d_block(num_filters // 2, num_filters // 2)
|
396 |
+
self.conv_img1024 = Conv2dBlock(num_filters // 2, image_channels,
|
397 |
+
5, stride=1, padding=2,
|
398 |
+
weight_norm_type=weight_norm_type,
|
399 |
+
activation_norm_type='none',
|
400 |
+
nonlinearity=nonlinearity,
|
401 |
+
order='ANC')
|
402 |
+
self.nearest_upsample4x = NearestUpsample(scale_factor=4, mode='nearest')
|
403 |
+
self.base = 64
|
404 |
+
if self.out_image_small_side_size != 256 and self.out_image_small_side_size != 512 \
|
405 |
+
and self.out_image_small_side_size != 1024:
|
406 |
+
raise ValueError('Generation image size (%d, %d) not supported' %
|
407 |
+
(self.out_image_small_side_size,
|
408 |
+
self.out_image_small_side_size))
|
409 |
+
self.nearest_upsample2x = NearestUpsample(scale_factor=2, mode='nearest')
|
410 |
+
|
411 |
+
xv, yv = torch.meshgrid(
|
412 |
+
[torch.arange(-1, 1.1, 2. / 15), torch.arange(-1, 1.1, 2. / 15)])
|
413 |
+
self.xy = torch.cat((xv.unsqueeze(0), yv.unsqueeze(0)), 0).unsqueeze(0)
|
414 |
+
self.xy = self.xy.cuda()
|
415 |
+
|
416 |
+
def forward(self, data):
|
417 |
+
r"""SPADE Generator forward.
|
418 |
+
|
419 |
+
Args:
|
420 |
+
data (dict):
|
421 |
+
- data (N x C1 x H x W tensor) : Ground truth images.
|
422 |
+
- label (N x C2 x H x W tensor) : Semantic representations.
|
423 |
+
- z (N x style_dims tensor): Gaussian random noise.
|
424 |
+
Returns:
|
425 |
+
output (dict):
|
426 |
+
- fake_images (N x 3 x H x W tensor): Fake images.
|
427 |
+
"""
|
428 |
+
seg = data['label']
|
429 |
+
|
430 |
+
if self.use_style_encoder:
|
431 |
+
z = data['z']
|
432 |
+
z = self.fc_0(z)
|
433 |
+
z = self.fc_1(z)
|
434 |
+
|
435 |
+
# The code piece below makes sure that the input size is always 16x16
|
436 |
+
sy = math.floor(seg.size()[2] * 1.0 / self.base)
|
437 |
+
sx = math.floor(seg.size()[3] * 1.0 / self.base)
|
438 |
+
|
439 |
+
in_seg = F.interpolate(seg, size=[sy, sx], mode='nearest')
|
440 |
+
if self.use_posenc_in_input_layer:
|
441 |
+
in_xy = F.interpolate(self.xy, size=[sy, sx], mode='bicubic')
|
442 |
+
in_seg_xy = torch.cat(
|
443 |
+
(in_seg, in_xy.expand(in_seg.size()[0], 2, sy, sx)), 1)
|
444 |
+
else:
|
445 |
+
in_seg_xy = in_seg
|
446 |
+
# 16x16
|
447 |
+
x = self.head_0(in_seg_xy)
|
448 |
+
if self.use_style_encoder:
|
449 |
+
x = self.cbn_head_0(x, z)
|
450 |
+
else:
|
451 |
+
x = self.conv_head_0(x)
|
452 |
+
x = self.head_1(x, seg)
|
453 |
+
x = self.head_2(x, seg)
|
454 |
+
x = self.nearest_upsample2x(x)
|
455 |
+
# 32x32
|
456 |
+
x = self.up_0a(x, seg)
|
457 |
+
if self.use_style_encoder:
|
458 |
+
x = self.cbn_up_0a(x, z)
|
459 |
+
else:
|
460 |
+
x = self.conv_up_0a(x)
|
461 |
+
x = self.up_0b(x, seg)
|
462 |
+
x = self.nearest_upsample2x(x)
|
463 |
+
# 64x64
|
464 |
+
x = self.up_1a(x, seg)
|
465 |
+
if self.use_style_encoder:
|
466 |
+
x = self.cbn_up_1a(x, z)
|
467 |
+
else:
|
468 |
+
x = self.conv_up_1a(x)
|
469 |
+
x = self.up_1b(x, seg)
|
470 |
+
x = self.nearest_upsample2x(x)
|
471 |
+
# 128x128
|
472 |
+
x = self.up_2a(x, seg)
|
473 |
+
if self.use_style_encoder:
|
474 |
+
x = self.cbn_up_2a(x, z)
|
475 |
+
else:
|
476 |
+
x = self.conv_up_2a(x)
|
477 |
+
x = self.up_2b(x, seg)
|
478 |
+
x = self.nearest_upsample2x(x)
|
479 |
+
# 256x256
|
480 |
+
if self.out_image_small_side_size == 256:
|
481 |
+
x256 = self.conv_img256(x)
|
482 |
+
x = torch.tanh(self.output_multiplier * x256)
|
483 |
+
# 512x512
|
484 |
+
elif self.out_image_small_side_size == 512:
|
485 |
+
x256 = self.conv_img256(x)
|
486 |
+
x256 = self.nearest_upsample2x(x256)
|
487 |
+
x = self.up_3a(x, seg)
|
488 |
+
x = self.up_3b(x, seg)
|
489 |
+
x = self.nearest_upsample2x(x)
|
490 |
+
x512 = self.conv_img512(x)
|
491 |
+
x = torch.tanh(self.output_multiplier * (x256 + x512))
|
492 |
+
# 1024x1024
|
493 |
+
elif self.out_image_small_side_size == 1024:
|
494 |
+
x256 = self.conv_img256(x)
|
495 |
+
x256 = self.nearest_upsample4x(x256)
|
496 |
+
x = self.up_3a(x, seg)
|
497 |
+
x = self.up_3b(x, seg)
|
498 |
+
x = self.nearest_upsample2x(x)
|
499 |
+
x512 = self.conv_img512(x)
|
500 |
+
x512 = self.nearest_upsample2x(x512)
|
501 |
+
x = self.up_4a(x, seg)
|
502 |
+
x = self.up_4b(x, seg)
|
503 |
+
x = self.nearest_upsample2x(x)
|
504 |
+
x1024 = self.conv_img1024(x)
|
505 |
+
x = torch.tanh(self.output_multiplier * (x256 + x512 + x1024))
|
506 |
+
output = dict()
|
507 |
+
output['fake_images'] = x
|
508 |
+
return output
|
509 |
+
|
510 |
+
|
511 |
+
class StyleEncoder(nn.Module):
|
512 |
+
r"""Style Encode constructor.
|
513 |
+
|
514 |
+
Args:
|
515 |
+
style_enc_cfg (obj): Style encoder definition file.
|
516 |
+
"""
|
517 |
+
|
518 |
+
def __init__(self, style_enc_cfg):
|
519 |
+
super(StyleEncoder, self).__init__()
|
520 |
+
input_image_channels = style_enc_cfg.input_image_channels
|
521 |
+
num_filters = style_enc_cfg.num_filters
|
522 |
+
kernel_size = style_enc_cfg.kernel_size
|
523 |
+
padding = int(np.ceil((kernel_size - 1.0) / 2))
|
524 |
+
style_dims = style_enc_cfg.style_dims
|
525 |
+
weight_norm_type = style_enc_cfg.weight_norm_type
|
526 |
+
activation_norm_type = 'none'
|
527 |
+
nonlinearity = 'leakyrelu'
|
528 |
+
base_conv2d_block = \
|
529 |
+
functools.partial(Conv2dBlock,
|
530 |
+
kernel_size=kernel_size,
|
531 |
+
stride=2,
|
532 |
+
padding=padding,
|
533 |
+
weight_norm_type=weight_norm_type,
|
534 |
+
activation_norm_type=activation_norm_type,
|
535 |
+
# inplace_nonlinearity=True,
|
536 |
+
nonlinearity=nonlinearity)
|
537 |
+
self.layer1 = base_conv2d_block(input_image_channels, num_filters)
|
538 |
+
self.layer2 = base_conv2d_block(num_filters * 1, num_filters * 2)
|
539 |
+
self.layer3 = base_conv2d_block(num_filters * 2, num_filters * 4)
|
540 |
+
self.layer4 = base_conv2d_block(num_filters * 4, num_filters * 8)
|
541 |
+
self.layer5 = base_conv2d_block(num_filters * 8, num_filters * 8)
|
542 |
+
self.layer6 = base_conv2d_block(num_filters * 8, num_filters * 8)
|
543 |
+
self.fc_mu = LinearBlock(num_filters * 8 * 4 * 4, style_dims)
|
544 |
+
self.fc_var = LinearBlock(num_filters * 8 * 4 * 4, style_dims)
|
545 |
+
|
546 |
+
def forward(self, input_x):
|
547 |
+
r"""SPADE Style Encoder forward.
|
548 |
+
|
549 |
+
Args:
|
550 |
+
input_x (N x 3 x H x W tensor): input images.
|
551 |
+
Returns:
|
552 |
+
(tuple):
|
553 |
+
- mu (N x C tensor): Mean vectors.
|
554 |
+
- logvar (N x C tensor): Log-variance vectors.
|
555 |
+
- z (N x C tensor): Style code vectors.
|
556 |
+
"""
|
557 |
+
if input_x.size(2) != 256 or input_x.size(3) != 256:
|
558 |
+
input_x = F.interpolate(input_x, size=(256, 256), mode='bilinear')
|
559 |
+
x = self.layer1(input_x)
|
560 |
+
x = self.layer2(x)
|
561 |
+
x = self.layer3(x)
|
562 |
+
x = self.layer4(x)
|
563 |
+
x = self.layer5(x)
|
564 |
+
x = self.layer6(x)
|
565 |
+
x = x.view(x.size(0), -1)
|
566 |
+
mu = self.fc_mu(x)
|
567 |
+
logvar = self.fc_var(x)
|
568 |
+
std = torch.exp(0.5 * logvar)
|
569 |
+
eps = torch.randn_like(std)
|
570 |
+
z = eps.mul(std) + mu
|
571 |
+
return mu, logvar, z
|
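The StyleEncoder above ends with the standard VAE reparameterization: z = mu + eps * exp(0.5 * logvar), which keeps sampling differentiable with respect to mu and logvar. A small numerical sketch of that step together with the KL term usually paired with it (the loss wiring here is illustrative, not this repo's training code):

```python
import torch

def reparameterize(mu, logvar):
    """Sample z ~ N(mu, sigma^2) with sigma = exp(0.5 * logvar), differentiably."""
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std

def kl_divergence(mu, logvar):
    """KL(N(mu, sigma^2) || N(0, 1)), summed over the style dimensions."""
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)

mu = torch.zeros(4, 256, requires_grad=True)
logvar = torch.zeros(4, 256, requires_grad=True)
z = reparameterize(mu, logvar)          # (4, 256) style codes
kl = kl_divergence(mu, logvar).mean()
kl.backward()                           # gradients flow into mu and logvar
print(z.shape, float(kl))
```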
imaginaire/layers/__init__.py
ADDED
@@ -0,0 +1,27 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+from .conv import LinearBlock, Conv1dBlock, Conv2dBlock, Conv3dBlock, \
+    HyperConv2dBlock, MultiOutConv2dBlock, \
+    PartialConv2dBlock, PartialConv3dBlock
+from .residual import ResLinearBlock, Res1dBlock, Res2dBlock, Res3dBlock, \
+    HyperRes2dBlock, MultiOutRes2dBlock, UpRes2dBlock, DownRes2dBlock, \
+    PartialRes2dBlock, PartialRes3dBlock
+from .non_local import NonLocal2dBlock
+
+__all__ = ['Conv1dBlock', 'Conv2dBlock', 'Conv3dBlock', 'LinearBlock',
+           'HyperConv2dBlock', 'MultiOutConv2dBlock',
+           'PartialConv2dBlock', 'PartialConv3dBlock',
+           'Res1dBlock', 'Res2dBlock', 'Res3dBlock',
+           'UpRes2dBlock', 'DownRes2dBlock',
+           'ResLinearBlock', 'HyperRes2dBlock', 'MultiOutRes2dBlock',
+           'PartialRes2dBlock', 'PartialRes3dBlock',
+           'NonLocal2dBlock']
+
+try:
+    from .repvgg import RepVGG1dBlock, RepVGG2dBlock, RepVGG3dBlock
+    from .attn import MultiheadAttention
+    __all__.extend(['RepVGG1dBlock', 'RepVGG2dBlock', 'RepVGG3dBlock'])
+except:  # noqa
+    pass
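imaginaire/layers/__init__.py above only re-exports the block classes; the constructor arguments used elsewhere in this upload (kernel_size, stride, padding, weight_norm_type, nonlinearity) suggest a drop-in usage like the following. Treat the exact signatures as an assumption taken from the calls in spade.py, not from library documentation:

```python
import torch
from imaginaire.layers import Conv2dBlock, Res2dBlock

# A spectral-norm conv block followed by a residual block, mirroring how spade.py builds its trunk.
conv = Conv2dBlock(3, 64, kernel_size=3, stride=1, padding=1,
                   weight_norm_type='spectral', nonlinearity='leakyrelu')
res = Res2dBlock(64, 64, kernel_size=3, padding=1,
                 weight_norm_type='spectral', nonlinearity='leakyrelu')

x = torch.randn(1, 3, 64, 64)
y = res(conv(x))
print(y.shape)  # expected: torch.Size([1, 64, 64, 64])
```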
imaginaire/layers/activation_norm.py
ADDED
@@ -0,0 +1,629 @@
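The file below starts with AdaptiveNorm, which differs from SPADE only in where gamma and beta come from: a linear projection of a global style vector instead of convolutions over a label map, applied per channel. A minimal sketch of that conditional-normalization pattern (plain PyTorch, not the AdaptiveNorm class itself):

```python
import torch
import torch.nn as nn

class TinyAdaptiveNorm(nn.Module):
    """Normalize, then scale/shift every channel from a projected style code: x_hat * (1 + gamma) + beta."""
    def __init__(self, num_features, style_dims):
        super().__init__()
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        self.fc = nn.Linear(style_dims, num_features * 2)

    def forward(self, x, style):
        gamma, beta = self.fc(style).chunk(2, dim=1)      # (N, C) each
        gamma = gamma[:, :, None, None]
        beta = beta[:, :, None, None]
        return self.norm(x) * (1 + gamma) + beta

x = torch.randn(2, 64, 8, 8)
style = torch.randn(2, 128)
print(TinyAdaptiveNorm(64, 128)(x, style).shape)  # torch.Size([2, 64, 8, 8])
```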
1 |
+
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# This work is made available under the Nvidia Source Code License-NC.
|
4 |
+
# To view a copy of this license, check out LICENSE.md
|
5 |
+
# flake8: noqa E722
|
6 |
+
from types import SimpleNamespace
|
7 |
+
|
8 |
+
import torch
|
9 |
+
|
10 |
+
try:
|
11 |
+
from torch.nn import SyncBatchNorm
|
12 |
+
except ImportError:
|
13 |
+
from torch.nn import BatchNorm2d as SyncBatchNorm
|
14 |
+
from torch import nn
|
15 |
+
from torch.nn import functional as F
|
16 |
+
from .conv import LinearBlock, Conv2dBlock, HyperConv2d, PartialConv2dBlock
|
17 |
+
from .misc import PartialSequential, ApplyNoise
|
18 |
+
|
19 |
+
|
20 |
+
class AdaptiveNorm(nn.Module):
|
21 |
+
r"""Adaptive normalization layer. The layer first normalizes the input, then
|
22 |
+
performs an affine transformation using parameters computed from the
|
23 |
+
conditional inputs.
|
24 |
+
|
25 |
+
Args:
|
26 |
+
num_features (int): Number of channels in the input tensor.
|
27 |
+
cond_dims (int): Number of channels in the conditional inputs.
|
28 |
+
weight_norm_type (str): Type of weight normalization.
|
29 |
+
``'none'``, ``'spectral'``, ``'weight'``, or ``'weight_demod'``.
|
30 |
+
projection (bool): If ``True``, project the conditional input to gamma
|
31 |
+
and beta using a fully connected layer, otherwise directly use
|
32 |
+
the conditional input as gamma and beta.
|
33 |
+
projection_bias (bool): If ``True``, use bias in the fully connected
|
34 |
+
projection layer.
|
35 |
+
separate_projection (bool): If ``True``, we will use two different
|
36 |
+
layers for gamma and beta. Otherwise, we will use one layer. It
|
37 |
+
matters only if you apply any weight norms to this layer.
|
38 |
+
input_dim (int): Number of dimensions of the input tensor.
|
39 |
+
activation_norm_type (str):
|
40 |
+
Type of activation normalization.
|
41 |
+
``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
|
42 |
+
``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
|
43 |
+
``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
|
44 |
+
activation_norm_params (obj, optional, default=None):
|
45 |
+
Parameters of activation normalization.
|
46 |
+
If not ``None``, ``activation_norm_params.__dict__`` will be used as
|
47 |
+
keyword arguments when initializing activation normalization.
|
48 |
+
"""
|
49 |
+
|
50 |
+
def __init__(self, num_features, cond_dims, weight_norm_type='',
|
51 |
+
projection=True,
|
52 |
+
projection_bias=True,
|
53 |
+
separate_projection=False,
|
54 |
+
input_dim=2,
|
55 |
+
activation_norm_type='instance',
|
56 |
+
activation_norm_params=None,
|
57 |
+
apply_noise=False,
|
58 |
+
add_bias=True,
|
59 |
+
input_scale=1.0,
|
60 |
+
init_gain=1.0):
|
61 |
+
super().__init__()
|
62 |
+
if activation_norm_params is None:
|
63 |
+
activation_norm_params = SimpleNamespace(affine=False)
|
64 |
+
self.norm = get_activation_norm_layer(num_features,
|
65 |
+
activation_norm_type,
|
66 |
+
input_dim,
|
67 |
+
**vars(activation_norm_params))
|
68 |
+
if apply_noise:
|
69 |
+
self.noise_layer = ApplyNoise()
|
70 |
+
else:
|
71 |
+
self.noise_layer = None
|
72 |
+
|
73 |
+
if projection:
|
74 |
+
if separate_projection:
|
75 |
+
self.fc_gamma = \
|
76 |
+
LinearBlock(cond_dims, num_features,
|
77 |
+
weight_norm_type=weight_norm_type,
|
78 |
+
bias=projection_bias)
|
79 |
+
self.fc_beta = \
|
80 |
+
LinearBlock(cond_dims, num_features,
|
81 |
+
weight_norm_type=weight_norm_type,
|
82 |
+
bias=projection_bias)
|
83 |
+
else:
|
84 |
+
self.fc = LinearBlock(cond_dims, num_features * 2,
|
85 |
+
weight_norm_type=weight_norm_type,
|
86 |
+
bias=projection_bias)
|
87 |
+
|
88 |
+
self.projection = projection
|
89 |
+
self.separate_projection = separate_projection
|
90 |
+
self.input_scale = input_scale
|
91 |
+
self.add_bias = add_bias
|
92 |
+
self.conditional = True
|
93 |
+
self.init_gain = init_gain
|
94 |
+
|
95 |
+
def forward(self, x, y, noise=None, **_kwargs):
|
96 |
+
r"""Adaptive Normalization forward.
|
97 |
+
|
98 |
+
Args:
|
99 |
+
x (N x C1 x * tensor): Input tensor.
|
100 |
+
y (N x C2 tensor): Conditional information.
|
101 |
+
Returns:
|
102 |
+
out (N x C1 x * tensor): Output tensor.
|
103 |
+
"""
|
104 |
+
y = y * self.input_scale
|
105 |
+
if self.projection:
|
106 |
+
if self.separate_projection:
|
107 |
+
gamma = self.fc_gamma(y)
|
108 |
+
beta = self.fc_beta(y)
|
109 |
+
for _ in range(x.dim() - gamma.dim()):
|
110 |
+
gamma = gamma.unsqueeze(-1)
|
111 |
+
beta = beta.unsqueeze(-1)
|
112 |
+
else:
|
113 |
+
y = self.fc(y)
|
114 |
+
for _ in range(x.dim() - y.dim()):
|
115 |
+
y = y.unsqueeze(-1)
|
116 |
+
gamma, beta = y.chunk(2, 1)
|
117 |
+
else:
|
118 |
+
for _ in range(x.dim() - y.dim()):
|
119 |
+
y = y.unsqueeze(-1)
|
120 |
+
gamma, beta = y.chunk(2, 1)
|
121 |
+
if self.norm is not None:
|
122 |
+
x = self.norm(x)
|
123 |
+
if self.noise_layer is not None:
|
124 |
+
x = self.noise_layer(x, noise=noise)
|
125 |
+
if self.add_bias:
|
126 |
+
x = torch.addcmul(beta, x, 1 + gamma)
|
127 |
+
return x
|
128 |
+
else:
|
129 |
+
return x * (1 + gamma), beta.squeeze(3).squeeze(2)
|
130 |
+
|
131 |
+
|
132 |
+
class SpatiallyAdaptiveNorm(nn.Module):
|
133 |
+
r"""Spatially Adaptive Normalization (SPADE) initialization.
|
134 |
+
|
135 |
+
Args:
|
136 |
+
num_features (int) : Number of channels in the input tensor.
|
137 |
+
cond_dims (int or list of int) : List of numbers of channels
|
138 |
+
in the input.
|
139 |
+
num_filters (int): Number of filters in SPADE.
|
140 |
+
kernel_size (int): Kernel size of the convolutional filters in
|
141 |
+
the SPADE layer.
|
142 |
+
weight_norm_type (str): Type of weight normalization.
|
143 |
+
``'none'``, ``'spectral'``, or ``'weight'``.
|
144 |
+
separate_projection (bool): If ``True``, we will use two different
|
145 |
+
layers for gamma and beta. Otherwise, we will use one layer. It
|
146 |
+
matters only if you apply any weight norms to this layer.
|
147 |
+
activation_norm_type (str):
|
148 |
+
Type of activation normalization.
|
149 |
+
``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
|
150 |
+
``'layer'``, ``'layer_2d'``, ``'group'``.
|
151 |
+
activation_norm_params (obj, optional, default=None):
|
152 |
+
Parameters of activation normalization.
|
153 |
+
If not ``None``, ``activation_norm_params.__dict__`` will be used as
|
154 |
+
keyword arguments when initializing activation normalization.
|
155 |
+
"""
|
156 |
+
|
157 |
+
def __init__(self,
|
158 |
+
num_features,
|
159 |
+
cond_dims,
|
160 |
+
num_filters=128,
|
161 |
+
kernel_size=3,
|
162 |
+
weight_norm_type='',
|
163 |
+
separate_projection=False,
|
164 |
+
activation_norm_type='sync_batch',
|
165 |
+
activation_norm_params=None,
|
166 |
+
bias_only=False,
|
167 |
+
partial=False,
|
168 |
+
interpolation='nearest'):
|
169 |
+
super().__init__()
|
170 |
+
if activation_norm_params is None:
|
171 |
+
activation_norm_params = SimpleNamespace(affine=False)
|
172 |
+
padding = kernel_size // 2
|
173 |
+
self.separate_projection = separate_projection
|
174 |
+
self.mlps = nn.ModuleList()
|
175 |
+
self.gammas = nn.ModuleList()
|
176 |
+
self.betas = nn.ModuleList()
|
177 |
+
self.bias_only = bias_only
|
178 |
+
self.interpolation = interpolation
|
179 |
+
|
180 |
+
# Make cond_dims a list.
|
181 |
+
if type(cond_dims) != list:
|
182 |
+
cond_dims = [cond_dims]
|
183 |
+
|
184 |
+
# Make num_filters a list.
|
185 |
+
if not isinstance(num_filters, list):
|
186 |
+
num_filters = [num_filters] * len(cond_dims)
|
187 |
+
else:
|
188 |
+
assert len(num_filters) >= len(cond_dims)
|
189 |
+
|
190 |
+
# Make partial a list.
|
191 |
+
if not isinstance(partial, list):
|
192 |
+
partial = [partial] * len(cond_dims)
|
193 |
+
else:
|
194 |
+
assert len(partial) >= len(cond_dims)
|
195 |
+
|
196 |
+
for i, cond_dim in enumerate(cond_dims):
|
197 |
+
mlp = []
|
198 |
+
conv_block = PartialConv2dBlock if partial[i] else Conv2dBlock
|
199 |
+
sequential = PartialSequential if partial[i] else nn.Sequential
|
200 |
+
|
201 |
+
if num_filters[i] > 0:
|
202 |
+
mlp += [conv_block(cond_dim,
|
203 |
+
num_filters[i],
|
204 |
+
kernel_size,
|
205 |
+
padding=padding,
|
206 |
+
weight_norm_type=weight_norm_type,
|
207 |
+
nonlinearity='relu')]
|
208 |
+
mlp_ch = cond_dim if num_filters[i] == 0 else num_filters[i]
|
209 |
+
|
210 |
+
if self.separate_projection:
|
211 |
+
if partial[i]:
|
212 |
+
raise NotImplementedError(
|
213 |
+
'Separate projection not yet implemented for ' +
|
214 |
+
'partial conv')
|
215 |
+
self.mlps.append(nn.Sequential(*mlp))
|
216 |
+
self.gammas.append(
|
217 |
+
conv_block(mlp_ch, num_features,
|
218 |
+
kernel_size,
|
219 |
+
padding=padding,
|
220 |
+
weight_norm_type=weight_norm_type))
|
221 |
+
self.betas.append(
|
222 |
+
conv_block(mlp_ch, num_features,
|
223 |
+
kernel_size,
|
224 |
+
padding=padding,
|
225 |
+
weight_norm_type=weight_norm_type))
|
226 |
+
else:
|
227 |
+
mlp += [conv_block(mlp_ch, num_features * 2, kernel_size,
|
228 |
+
padding=padding,
|
229 |
+
weight_norm_type=weight_norm_type)]
|
230 |
+
self.mlps.append(sequential(*mlp))
|
231 |
+
|
232 |
+
self.norm = get_activation_norm_layer(num_features,
|
233 |
+
activation_norm_type,
|
234 |
+
2,
|
235 |
+
**vars(activation_norm_params))
|
236 |
+
self.conditional = True
|
237 |
+
|
238 |
+
def forward(self, x, *cond_inputs, **_kwargs):
|
239 |
+
r"""Spatially Adaptive Normalization (SPADE) forward.
|
240 |
+
|
241 |
+
Args:
|
242 |
+
x (N x C1 x H x W tensor) : Input tensor.
|
243 |
+
cond_inputs (list of tensors) : Conditional maps for SPADE.
|
244 |
+
Returns:
|
245 |
+
output (4D tensor) : Output tensor.
|
246 |
+
"""
|
247 |
+
output = self.norm(x) if self.norm is not None else x
|
248 |
+
for i in range(len(cond_inputs)):
|
249 |
+
if cond_inputs[i] is None:
|
250 |
+
continue
|
251 |
+
label_map = F.interpolate(cond_inputs[i], size=x.size()[2:], mode=self.interpolation)
|
252 |
+
if self.separate_projection:
|
253 |
+
hidden = self.mlps[i](label_map)
|
254 |
+
gamma = self.gammas[i](hidden)
|
255 |
+
beta = self.betas[i](hidden)
|
256 |
+
else:
|
257 |
+
affine_params = self.mlps[i](label_map)
|
258 |
+
gamma, beta = affine_params.chunk(2, dim=1)
|
259 |
+
if self.bias_only:
|
260 |
+
output = output + beta
|
261 |
+
else:
|
262 |
+
output = output * (1 + gamma) + beta
|
263 |
+
return output
|
264 |
+
|
265 |
+
|
266 |
+
class DualAdaptiveNorm(nn.Module):
|
267 |
+
def __init__(self,
|
268 |
+
num_features,
|
269 |
+
cond_dims,
|
270 |
+
projection_bias=True,
|
271 |
+
weight_norm_type='',
|
272 |
+
activation_norm_type='instance',
|
273 |
+
activation_norm_params=None,
|
274 |
+
apply_noise=False,
|
275 |
+
bias_only=False,
|
276 |
+
init_gain=1.0,
|
277 |
+
fc_scale=None,
|
278 |
+
is_spatial=None):
|
279 |
+
super().__init__()
|
280 |
+
if activation_norm_params is None:
|
281 |
+
activation_norm_params = SimpleNamespace(affine=False)
|
282 |
+
self.mlps = nn.ModuleList()
|
283 |
+
self.gammas = nn.ModuleList()
|
284 |
+
self.betas = nn.ModuleList()
|
285 |
+
self.bias_only = bias_only
|
286 |
+
|
287 |
+
# Make cond_dims a list.
|
288 |
+
if type(cond_dims) != list:
|
289 |
+
cond_dims = [cond_dims]
|
290 |
+
|
291 |
+
if is_spatial is None:
|
292 |
+
is_spatial = [False for _ in range(len(cond_dims))]
|
293 |
+
self.is_spatial = is_spatial
|
294 |
+
|
295 |
+
for cond_dim, this_is_spatial in zip(cond_dims, is_spatial):
|
296 |
+
kwargs = dict(weight_norm_type=weight_norm_type,
|
297 |
+
bias=projection_bias,
|
298 |
+
init_gain=init_gain,
|
299 |
+
output_scale=fc_scale)
|
300 |
+
if this_is_spatial:
|
301 |
+
self.gammas.append(Conv2dBlock(cond_dim, num_features, 1, 1, 0, **kwargs))
|
302 |
+
self.betas.append(Conv2dBlock(cond_dim, num_features, 1, 1, 0, **kwargs))
|
303 |
+
else:
|
304 |
+
self.gammas.append(LinearBlock(cond_dim, num_features, **kwargs))
|
305 |
+
self.betas.append(LinearBlock(cond_dim, num_features, **kwargs))
|
306 |
+
|
307 |
+
self.norm = get_activation_norm_layer(num_features,
|
308 |
+
activation_norm_type,
|
309 |
+
2,
|
310 |
+
**vars(activation_norm_params))
|
311 |
+
self.conditional = True
|
312 |
+
|
313 |
+
def forward(self, x, *cond_inputs, **_kwargs):
|
314 |
+
assert len(cond_inputs) == len(self.gammas)
|
315 |
+
output = self.norm(x) if self.norm is not None else x
|
316 |
+
for cond, gamma_layer, beta_layer in zip(cond_inputs, self.gammas, self.betas):
|
317 |
+
if cond is None:
|
318 |
+
continue
|
319 |
+
gamma = gamma_layer(cond)
|
320 |
+
beta = beta_layer(cond)
|
321 |
+
if cond.dim() == 4 and gamma.shape != x.shape:
|
322 |
+
gamma = F.interpolate(gamma, size=x.size()[2:], mode='bilinear')
|
323 |
+
beta = F.interpolate(beta, size=x.size()[2:], mode='bilinear')
|
324 |
+
elif cond.dim() == 2:
|
325 |
+
gamma = gamma[:, :, None, None]
|
326 |
+
beta = beta[:, :, None, None]
|
327 |
+
if self.bias_only:
|
328 |
+
output = output + beta
|
329 |
+
else:
|
330 |
+
output = output * (1 + gamma) + beta
|
331 |
+
return output
|
332 |
+
|
333 |
+
|
334 |
+
class HyperSpatiallyAdaptiveNorm(nn.Module):
|
335 |
+
r"""Spatially Adaptive Normalization (SPADE) initialization.
|
336 |
+
|
337 |
+
Args:
|
338 |
+
num_features (int) : Number of channels in the input tensor.
|
339 |
+
cond_dims (int or list of int) : List of numbers of channels
|
340 |
+
in the conditional input.
|
341 |
+
num_filters (int): Number of filters in SPADE.
|
342 |
+
kernel_size (int): Kernel size of the convolutional filters in
|
343 |
+
the SPADE layer.
|
344 |
+
weight_norm_type (str): Type of weight normalization.
|
345 |
+
``'none'``, ``'spectral'``, or ``'weight'``.
|
346 |
+
activation_norm_type (str):
|
347 |
+
Type of activation normalization.
|
348 |
+
``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
|
349 |
+
``'layer'``, ``'layer_2d'``, ``'group'``.
|
350 |
+
is_hyper (bool): Whether to use hyper SPADE.
|
351 |
+
"""
|
352 |
+
|
353 |
+
def __init__(self, num_features, cond_dims,
|
354 |
+
num_filters=0, kernel_size=3,
|
355 |
+
weight_norm_type='',
|
356 |
+
activation_norm_type='sync_batch', is_hyper=True):
|
357 |
+
super().__init__()
|
358 |
+
padding = kernel_size // 2
|
359 |
+
self.mlps = nn.ModuleList()
|
360 |
+
if type(cond_dims) != list:
|
361 |
+
cond_dims = [cond_dims]
|
362 |
+
|
363 |
+
for i, cond_dim in enumerate(cond_dims):
|
364 |
+
mlp = []
|
365 |
+
if not is_hyper or (i != 0):
|
366 |
+
if num_filters > 0:
|
367 |
+
mlp += [Conv2dBlock(cond_dim, num_filters, kernel_size,
|
368 |
+
padding=padding,
|
369 |
+
weight_norm_type=weight_norm_type,
|
370 |
+
nonlinearity='relu')]
|
371 |
+
mlp_ch = cond_dim if num_filters == 0 else num_filters
|
372 |
+
mlp += [Conv2dBlock(mlp_ch, num_features * 2, kernel_size,
|
373 |
+
padding=padding,
|
374 |
+
weight_norm_type=weight_norm_type)]
|
375 |
+
mlp = nn.Sequential(*mlp)
|
376 |
+
else:
|
377 |
+
if num_filters > 0:
|
378 |
+
raise ValueError('Multi hyper layer not supported yet.')
|
379 |
+
mlp = HyperConv2d(padding=padding)
|
380 |
+
self.mlps.append(mlp)
|
381 |
+
|
382 |
+
self.norm = get_activation_norm_layer(num_features,
|
383 |
+
activation_norm_type,
|
384 |
+
2,
|
385 |
+
affine=False)
|
386 |
+
|
387 |
+
self.conditional = True
|
388 |
+
|
389 |
+
def forward(self, x, *cond_inputs,
|
390 |
+
norm_weights=(None, None), **_kwargs):
|
391 |
+
r"""Spatially Adaptive Normalization (SPADE) forward.
|
392 |
+
|
393 |
+
Args:
|
394 |
+
x (4D tensor) : Input tensor.
|
395 |
+
cond_inputs (list of tensors) : Conditional maps for SPADE.
|
396 |
+
norm_weights (5D tensor or list of tensors): conv weights or
|
397 |
+
[weights, biases].
|
398 |
+
Returns:
|
399 |
+
output (4D tensor) : Output tensor.
|
400 |
+
"""
|
401 |
+
output = self.norm(x)
|
402 |
+
for i in range(len(cond_inputs)):
|
403 |
+
if cond_inputs[i] is None:
|
404 |
+
continue
|
405 |
+
if type(cond_inputs[i]) == list:
|
406 |
+
cond_input, mask = cond_inputs[i]
|
407 |
+
mask = F.interpolate(mask, size=x.size()[2:], mode='bilinear', align_corners=False)
|
408 |
+
else:
|
409 |
+
cond_input = cond_inputs[i]
|
410 |
+
mask = None
|
411 |
+
label_map = F.interpolate(cond_input, size=x.size()[2:])
|
412 |
+
if norm_weights is None or norm_weights[0] is None or i != 0:
|
413 |
+
affine_params = self.mlps[i](label_map)
|
414 |
+
else:
|
415 |
+
affine_params = self.mlps[i](label_map,
|
416 |
+
conv_weights=norm_weights)
|
417 |
+
gamma, beta = affine_params.chunk(2, dim=1)
|
418 |
+
if mask is not None:
|
419 |
+
gamma = gamma * (1 - mask)
|
420 |
+
beta = beta * (1 - mask)
|
421 |
+
output = output * (1 + gamma) + beta
|
422 |
+
return output
|
423 |
+
|
424 |
+
|
425 |
+
class LayerNorm2d(nn.Module):
|
426 |
+
r"""Layer Normalization as introduced in
|
427 |
+
https://arxiv.org/abs/1607.06450.
|
428 |
+
This is the usual way to apply layer normalization in CNNs.
|
429 |
+
Note that unlike the pytorch implementation which applies per-element
|
430 |
+
scale and bias, here it applies per-channel scale and bias, similar to
|
431 |
+
batch/instance normalization.
|
432 |
+
|
433 |
+
Args:
|
434 |
+
num_features (int): Number of channels in the input tensor.
|
435 |
+
eps (float, optional, default=1e-5): a value added to the
|
436 |
+
denominator for numerical stability.
|
437 |
+
affine (bool, optional, default=False): If ``True``, performs
|
438 |
+
affine transformation after normalization.
|
439 |
+
"""
|
440 |
+
|
441 |
+
def __init__(self, num_features, eps=1e-5, channel_only=False, affine=True):
|
442 |
+
super(LayerNorm2d, self).__init__()
|
443 |
+
self.num_features = num_features
|
444 |
+
self.affine = affine
|
445 |
+
self.eps = eps
|
446 |
+
self.channel_only = channel_only
|
447 |
+
|
448 |
+
if self.affine:
|
449 |
+
self.gamma = nn.Parameter(torch.Tensor(num_features).fill_(1.0))
|
450 |
+
self.beta = nn.Parameter(torch.zeros(num_features))
|
451 |
+
|
452 |
+
def forward(self, x):
|
453 |
+
r"""
|
454 |
+
|
455 |
+
Args:
|
456 |
+
x (tensor): Input tensor.
|
457 |
+
"""
|
458 |
+
shape = [-1] + [1] * (x.dim() - 1)
|
459 |
+
if self.channel_only:
|
460 |
+
mean = x.mean(1, keepdim=True)
|
461 |
+
std = x.std(1, keepdim=True)
|
462 |
+
else:
|
463 |
+
mean = x.view(x.size(0), -1).mean(1).view(*shape)
|
464 |
+
std = x.view(x.size(0), -1).std(1).view(*shape)
|
465 |
+
|
466 |
+
x = (x - mean) / (std + self.eps)
|
467 |
+
|
468 |
+
if self.affine:
|
469 |
+
shape = [1, -1] + [1] * (x.dim() - 2)
|
470 |
+
x = x * self.gamma.view(*shape) + self.beta.view(*shape)
|
471 |
+
return x
|
472 |
+
|
473 |
+
|
474 |
+
class ScaleNorm(nn.Module):
|
475 |
+
r"""Scale normalization:
|
476 |
+
"Transformers without Tears: Improving the Normalization of Self-Attention"
|
477 |
+
Modified from:
|
478 |
+
https://github.com/tnq177/transformers_without_tears
|
479 |
+
"""
|
480 |
+
|
481 |
+
def __init__(self, dim=-1, learned_scale=True, eps=1e-5):
|
482 |
+
super().__init__()
|
483 |
+
# scale = num_features ** 0.5
|
484 |
+
if learned_scale:
|
485 |
+
self.scale = nn.Parameter(torch.tensor(1.))
|
486 |
+
else:
|
487 |
+
self.scale = 1.
|
488 |
+
# self.num_features = num_features
|
489 |
+
self.dim = dim
|
490 |
+
self.eps = eps
|
491 |
+
self.learned_scale = learned_scale
|
492 |
+
|
493 |
+
def forward(self, x):
|
494 |
+
# noinspection PyArgumentList
|
495 |
+
scale = self.scale * torch.rsqrt(torch.mean(x ** 2, dim=self.dim, keepdim=True) + self.eps)
|
496 |
+
return x * scale
|
497 |
+
|
498 |
+
def extra_repr(self):
|
499 |
+
s = 'learned_scale={learned_scale}'
|
500 |
+
return s.format(**self.__dict__)
|
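ScaleNorm above rescales activations by their root-mean-square along the chosen dimension; a short equivalent sketch with assumed shapes:

import torch

x = torch.randn(2, 10, 512)
scale = torch.tensor(1.0)   # learnable when learned_scale=True
y = x * scale * torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + 1e-5)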
501 |
+
|
502 |
+
|
503 |
+
class PixelNorm(ScaleNorm):
|
504 |
+
def __init__(self, learned_scale=False, eps=1e-5, **_kwargs):
|
505 |
+
super().__init__(1, learned_scale, eps)
|
506 |
+
|
507 |
+
|
508 |
+
class SplitMeanStd(nn.Module):
|
509 |
+
def __init__(self, num_features, eps=1e-5, **kwargs):
|
510 |
+
super().__init__()
|
511 |
+
self.num_features = num_features
|
512 |
+
self.eps = eps
|
513 |
+
self.multiple_outputs = True
|
514 |
+
|
515 |
+
def forward(self, x):
|
516 |
+
b, c, h, w = x.size()
|
517 |
+
mean = x.view(b, c, -1).mean(-1)[:, :, None, None]
|
518 |
+
var = x.view(b, c, -1).var(-1)[:, :, None, None]
|
519 |
+
std = torch.sqrt(var + self.eps)
|
520 |
+
|
521 |
+
# x = (x - mean) / std
|
522 |
+
return x, torch.cat((mean, std), dim=1)
|
523 |
+
|
524 |
+
|
525 |
+
class ScaleNorm(nn.Module):
|
526 |
+
r"""Scale normalization:
|
527 |
+
"Transformers without Tears: Improving the Normalization of Self-Attention"
|
528 |
+
Modified from:
|
529 |
+
https://github.com/tnq177/transformers_without_tears
|
530 |
+
"""
|
531 |
+
|
532 |
+
def __init__(self, dim=-1, learned_scale=True, eps=1e-5):
|
533 |
+
super().__init__()
|
534 |
+
# scale = num_features ** 0.5
|
535 |
+
if learned_scale:
|
536 |
+
self.scale = nn.Parameter(torch.tensor(1.))
|
537 |
+
else:
|
538 |
+
self.scale = 1.
|
539 |
+
# self.num_features = num_features
|
540 |
+
self.dim = dim
|
541 |
+
self.eps = eps
|
542 |
+
self.learned_scale = learned_scale
|
543 |
+
|
544 |
+
def forward(self, x):
|
545 |
+
# noinspection PyArgumentList
|
546 |
+
scale = self.scale * torch.rsqrt(
|
547 |
+
torch.mean(x ** 2, dim=self.dim, keepdim=True) + self.eps)
|
548 |
+
return x * scale
|
549 |
+
|
550 |
+
def extra_repr(self):
|
551 |
+
s = 'learned_scale={learned_scale}'
|
552 |
+
return s.format(**self.__dict__)
|
553 |
+
|
554 |
+
|
555 |
+
class PixelLayerNorm(nn.Module):
|
556 |
+
def __init__(self, *args, **kwargs):
|
557 |
+
super().__init__()
|
558 |
+
self.norm = nn.LayerNorm(*args, **kwargs)
|
559 |
+
|
560 |
+
def forward(self, x):
|
561 |
+
if x.dim() == 4:
|
562 |
+
b, c, h, w = x.shape
|
563 |
+
return self.norm(x.permute(0, 2, 3, 1).view(-1, c).contiguous()).view(b, h, w, c).permute(0, 3, 1, 2).contiguous()
|
564 |
+
else:
|
565 |
+
return self.norm(x)
|
566 |
+
|
567 |
+
|
568 |
+
def get_activation_norm_layer(num_features, norm_type, input_dim, **norm_params):
|
569 |
+
r"""Return an activation normalization layer.
|
570 |
+
|
571 |
+
Args:
|
572 |
+
num_features (int): Number of feature channels.
|
573 |
+
norm_type (str):
|
574 |
+
Type of activation normalization.
|
575 |
+
``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
|
576 |
+
``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, ``'dual_adaptive'``,
|
577 |
+
``'pixel'``, ``'pixel_layer'``, ``'scale'``, ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
|
578 |
+
input_dim (int): Number of input dimensions.
|
579 |
+
norm_params: Arbitrary keyword arguments that will be used to
|
580 |
+
initialize the activation normalization.
|
581 |
+
"""
|
582 |
+
input_dim = max(input_dim, 1) # Norm1d works with both 0d and 1d inputs
|
583 |
+
|
584 |
+
if norm_type == 'none' or norm_type == '':
|
585 |
+
norm_layer = None
|
586 |
+
elif norm_type == 'batch':
|
587 |
+
norm = getattr(nn, 'BatchNorm%dd' % input_dim)
|
588 |
+
norm_layer = norm(num_features, **norm_params)
|
589 |
+
elif norm_type == 'instance':
|
590 |
+
affine = norm_params.pop('affine', True) # Use affine=True by default
|
591 |
+
norm = getattr(nn, 'InstanceNorm%dd' % input_dim)
|
592 |
+
norm_layer = norm(num_features, affine=affine, **norm_params)
|
593 |
+
elif norm_type == 'sync_batch':
|
594 |
+
norm_layer = SyncBatchNorm(num_features, **norm_params)
|
595 |
+
elif norm_type == 'layer':
|
596 |
+
norm_layer = nn.LayerNorm(num_features, **norm_params)
|
597 |
+
elif norm_type == 'layer_2d':
|
598 |
+
norm_layer = LayerNorm2d(num_features, **norm_params)
|
599 |
+
elif norm_type == 'pixel_layer':
|
600 |
+
elementwise_affine = norm_params.pop('affine', True) # Use affine=True by default
|
601 |
+
norm_layer = PixelLayerNorm(num_features, elementwise_affine=elementwise_affine, **norm_params)
|
602 |
+
elif norm_type == 'scale':
|
603 |
+
norm_layer = ScaleNorm(**norm_params)
|
604 |
+
elif norm_type == 'pixel':
|
605 |
+
norm_layer = PixelNorm(**norm_params)
|
606 |
+
import imaginaire.config
|
607 |
+
if imaginaire.config.USE_JIT:
|
608 |
+
norm_layer = torch.jit.script(norm_layer)
|
609 |
+
elif norm_type == 'group':
|
610 |
+
num_groups = norm_params.pop('num_groups', 4)
|
611 |
+
norm_layer = nn.GroupNorm(num_channels=num_features, num_groups=num_groups, **norm_params)
|
612 |
+
elif norm_type == 'adaptive':
|
613 |
+
norm_layer = AdaptiveNorm(num_features, **norm_params)
|
614 |
+
elif norm_type == 'dual_adaptive':
|
615 |
+
norm_layer = DualAdaptiveNorm(num_features, **norm_params)
|
616 |
+
elif norm_type == 'spatially_adaptive':
|
617 |
+
if input_dim != 2:
|
618 |
+
raise ValueError('Spatially adaptive normalization layers '
|
619 |
+
'only supports 2D input')
|
620 |
+
norm_layer = SpatiallyAdaptiveNorm(num_features, **norm_params)
|
621 |
+
elif norm_type == 'hyper_spatially_adaptive':
|
622 |
+
if input_dim != 2:
|
623 |
+
raise ValueError('Spatially adaptive normalization layers '
|
624 |
+
'only supports 2D input')
|
625 |
+
norm_layer = HyperSpatiallyAdaptiveNorm(num_features, **norm_params)
|
626 |
+
else:
|
627 |
+
raise ValueError('Activation norm layer %s '
|
628 |
+
'is not recognized' % norm_type)
|
629 |
+
return norm_layer
|
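A hedged usage sketch for the factory above; the import path assumes this file is installed as imaginaire/layers/activation_norm.py, as listed in this upload.

from imaginaire.layers.activation_norm import get_activation_norm_layer

# 2D instance norm over 64 channels ('instance' uses affine=True unless overridden).
inst = get_activation_norm_layer(64, 'instance', input_dim=2)
# Group norm with an explicit keyword override (num_groups defaults to 4).
grp = get_activation_norm_layer(64, 'group', input_dim=2, num_groups=8)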
imaginaire/layers/conv.py
ADDED
@@ -0,0 +1,1377 @@
1 |
+
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# This work is made available under the Nvidia Source Code License-NC.
|
4 |
+
# To view a copy of this license, check out LICENSE.md
|
5 |
+
import warnings
|
6 |
+
from types import SimpleNamespace
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from torch import nn
|
10 |
+
from torch.nn import functional as F
|
11 |
+
|
12 |
+
from .misc import ApplyNoise
|
13 |
+
from imaginaire.third_party.upfirdn2d.upfirdn2d import Blur
|
14 |
+
|
15 |
+
|
16 |
+
class _BaseConvBlock(nn.Module):
|
17 |
+
r"""An abstract wrapper class that wraps a torch convolution or linear layer
|
18 |
+
with normalization and nonlinearity.
|
19 |
+
"""
|
20 |
+
|
21 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode,
|
22 |
+
weight_norm_type, weight_norm_params, activation_norm_type, activation_norm_params, nonlinearity,
|
23 |
+
inplace_nonlinearity, apply_noise, blur, order, input_dim, clamp, blur_kernel, output_scale,
|
24 |
+
init_gain):
|
25 |
+
super().__init__()
|
26 |
+
from .nonlinearity import get_nonlinearity_layer
|
27 |
+
from .weight_norm import get_weight_norm_layer
|
28 |
+
from .activation_norm import get_activation_norm_layer
|
29 |
+
self.weight_norm_type = weight_norm_type
|
30 |
+
self.stride = stride
|
31 |
+
self.clamp = clamp
|
32 |
+
self.init_gain = init_gain
|
33 |
+
|
34 |
+
# Nonlinearity layer.
|
35 |
+
if 'fused' in nonlinearity:
|
36 |
+
# Fusing nonlinearity with bias.
|
37 |
+
lr_mul = getattr(weight_norm_params, 'lr_mul', 1)
|
38 |
+
conv_before_nonlinearity = order.find('C') < order.find('A')
|
39 |
+
if conv_before_nonlinearity:
|
40 |
+
assert bias is True
|
41 |
+
bias = False
|
42 |
+
channel = out_channels if conv_before_nonlinearity else in_channels
|
43 |
+
nonlinearity_layer = get_nonlinearity_layer(
|
44 |
+
nonlinearity, inplace=inplace_nonlinearity,
|
45 |
+
num_channels=channel, lr_mul=lr_mul)
|
46 |
+
else:
|
47 |
+
nonlinearity_layer = get_nonlinearity_layer(
|
48 |
+
nonlinearity, inplace=inplace_nonlinearity)
|
49 |
+
|
50 |
+
# Noise injection layer.
|
51 |
+
if apply_noise:
|
52 |
+
order = order.replace('C', 'CG')
|
53 |
+
noise_layer = ApplyNoise()
|
54 |
+
else:
|
55 |
+
noise_layer = None
|
56 |
+
|
57 |
+
# Convolutional layer.
|
58 |
+
if blur:
|
59 |
+
assert blur_kernel is not None
|
60 |
+
if stride == 2:
|
61 |
+
# Blur - Conv - Noise - Activate
|
62 |
+
p = (len(blur_kernel) - 2) + (kernel_size - 1)
|
63 |
+
pad0, pad1 = (p + 1) // 2, p // 2
|
64 |
+
padding = 0
|
65 |
+
blur_layer = Blur(
|
66 |
+
blur_kernel, pad=(pad0, pad1), padding_mode=padding_mode
|
67 |
+
)
|
68 |
+
order = order.replace('C', 'BC')
|
69 |
+
elif stride == 0.5:
|
70 |
+
# Conv - Blur - Noise - Activate
|
71 |
+
padding = 0
|
72 |
+
p = (len(blur_kernel) - 2) - (kernel_size - 1)
|
73 |
+
pad0, pad1 = (p + 1) // 2 + 1, p // 2 + 1
|
74 |
+
blur_layer = Blur(
|
75 |
+
blur_kernel, pad=(pad0, pad1), padding_mode=padding_mode
|
76 |
+
)
|
77 |
+
order = order.replace('C', 'CB')
|
78 |
+
elif stride == 1:
|
79 |
+
# No blur for now
|
80 |
+
blur_layer = nn.Identity()
|
81 |
+
else:
|
82 |
+
raise NotImplementedError
|
83 |
+
else:
|
84 |
+
blur_layer = nn.Identity()
|
85 |
+
|
86 |
+
if weight_norm_params is None:
|
87 |
+
weight_norm_params = SimpleNamespace()
|
88 |
+
weight_norm = get_weight_norm_layer(
|
89 |
+
weight_norm_type, **vars(weight_norm_params))
|
90 |
+
conv_layer = weight_norm(self._get_conv_layer(
|
91 |
+
in_channels, out_channels, kernel_size, stride, padding, dilation,
|
92 |
+
groups, bias, padding_mode, input_dim))
|
93 |
+
|
94 |
+
# Normalization layer.
|
95 |
+
conv_before_norm = order.find('C') < order.find('N')
|
96 |
+
norm_channels = out_channels if conv_before_norm else in_channels
|
97 |
+
if activation_norm_params is None:
|
98 |
+
activation_norm_params = SimpleNamespace()
|
99 |
+
activation_norm_layer = get_activation_norm_layer(
|
100 |
+
norm_channels,
|
101 |
+
activation_norm_type,
|
102 |
+
input_dim,
|
103 |
+
**vars(activation_norm_params))
|
104 |
+
|
105 |
+
# Mapping from operation names to layers.
|
106 |
+
mappings = {'C': {'conv': conv_layer},
|
107 |
+
'N': {'norm': activation_norm_layer},
|
108 |
+
'A': {'nonlinearity': nonlinearity_layer}}
|
109 |
+
mappings.update({'B': {'blur': blur_layer}})
|
110 |
+
mappings.update({'G': {'noise': noise_layer}})
|
111 |
+
|
112 |
+
# All layers in order.
|
113 |
+
self.layers = nn.ModuleDict()
|
114 |
+
for op in order:
|
115 |
+
if list(mappings[op].values())[0] is not None:
|
116 |
+
self.layers.update(mappings[op])
|
117 |
+
|
118 |
+
# Whether this block expects conditional inputs.
|
119 |
+
self.conditional = \
|
120 |
+
getattr(conv_layer, 'conditional', False) or \
|
121 |
+
getattr(activation_norm_layer, 'conditional', False)
|
122 |
+
|
123 |
+
# Scale the output by a learnable scaler parameter.
|
124 |
+
if output_scale is not None:
|
125 |
+
self.output_scale = nn.Parameter(torch.tensor(output_scale))
|
126 |
+
else:
|
127 |
+
self.register_parameter("output_scale", None)
|
128 |
+
|
129 |
+
def forward(self, x, *cond_inputs, **kw_cond_inputs):
|
130 |
+
r"""
|
131 |
+
|
132 |
+
Args:
|
133 |
+
x (tensor): Input tensor.
|
134 |
+
cond_inputs (list of tensors) : Conditional input tensors.
|
135 |
+
kw_cond_inputs (dict) : Keyword conditional inputs.
|
136 |
+
"""
|
137 |
+
for key, layer in self.layers.items():
|
138 |
+
if getattr(layer, 'conditional', False):
|
139 |
+
# Layers that require conditional inputs.
|
140 |
+
x = layer(x, *cond_inputs, **kw_cond_inputs)
|
141 |
+
else:
|
142 |
+
x = layer(x)
|
143 |
+
if self.clamp is not None and isinstance(layer, nn.Conv2d):
|
144 |
+
x.clamp_(max=self.clamp)
|
145 |
+
if key == 'conv':
|
146 |
+
if self.output_scale is not None:
|
147 |
+
x = x * self.output_scale
|
148 |
+
return x
|
149 |
+
|
150 |
+
def _get_conv_layer(self, in_channels, out_channels, kernel_size, stride,
|
151 |
+
padding, dilation, groups, bias, padding_mode,
|
152 |
+
input_dim):
|
153 |
+
# Returns the convolutional layer.
|
154 |
+
if input_dim == 0:
|
155 |
+
layer = nn.Linear(in_channels, out_channels, bias)
|
156 |
+
else:
|
157 |
+
if stride < 1: # Fractionally-strided convolution.
|
158 |
+
padding_mode = 'zeros'
|
159 |
+
assert padding == 0
|
160 |
+
layer_type = getattr(nn, f'ConvTranspose{input_dim}d')
|
161 |
+
stride = round(1 / stride)
|
162 |
+
else:
|
163 |
+
layer_type = getattr(nn, f'Conv{input_dim}d')
|
164 |
+
layer = layer_type(
|
165 |
+
in_channels, out_channels, kernel_size, stride, padding,
|
166 |
+
dilation=dilation, groups=groups, bias=bias,
|
167 |
+
padding_mode=padding_mode
|
168 |
+
)
|
169 |
+
|
170 |
+
return layer
|
171 |
+
|
172 |
+
def __repr__(self):
|
173 |
+
main_str = self._get_name() + '('
|
174 |
+
child_lines = []
|
175 |
+
for name, layer in self.layers.items():
|
176 |
+
mod_str = repr(layer)
|
177 |
+
if name == 'conv' and self.weight_norm_type != 'none' and \
|
178 |
+
self.weight_norm_type != '':
|
179 |
+
mod_str = mod_str[:-1] + \
|
180 |
+
', weight_norm={}'.format(self.weight_norm_type) + ')'
|
181 |
+
if name == 'conv' and getattr(layer, 'base_lr_mul', 1) != 1:
|
182 |
+
mod_str = mod_str[:-1] + \
|
183 |
+
', lr_mul={}'.format(layer.base_lr_mul) + ')'
|
184 |
+
mod_str = self._addindent(mod_str, 2)
|
185 |
+
child_lines.append(mod_str)
|
186 |
+
if len(child_lines) == 1:
|
187 |
+
main_str += child_lines[0]
|
188 |
+
else:
|
189 |
+
main_str += '\n ' + '\n '.join(child_lines) + '\n'
|
190 |
+
|
191 |
+
main_str += ')'
|
192 |
+
return main_str
|
193 |
+
|
194 |
+
@staticmethod
|
195 |
+
def _addindent(s_, numSpaces):
|
196 |
+
s = s_.split('\n')
|
197 |
+
# don't do anything for single-line stuff
|
198 |
+
if len(s) == 1:
|
199 |
+
return s_
|
200 |
+
first = s.pop(0)
|
201 |
+
s = [(numSpaces * ' ') + line for line in s]
|
202 |
+
s = '\n'.join(s)
|
203 |
+
s = first + '\n' + s
|
204 |
+
return s
|
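The order string of _BaseConvBlock is expanded character by character into the ModuleDict built above ('C' conv, 'N' norm, 'A' activation, with 'B' and 'G' spliced in when blur or noise is enabled); a toy illustration of that dispatch, using stand-in layers that are not part of this file:

import torch.nn as nn

mappings = {'C': {'conv': nn.Conv2d(3, 8, 3, padding=1)},
            'N': {'norm': nn.BatchNorm2d(8)},
            'A': {'nonlinearity': nn.ReLU()}}
layers = nn.ModuleDict()
for op in 'CNA':            # order='CNA': convolution, then normalization, then activation
    layers.update(mappings[op])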
205 |
+
|
206 |
+
|
207 |
+
class ModulatedConv2dBlock(_BaseConvBlock):
|
208 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
|
209 |
+
padding=0, dilation=1, groups=1, bias=True,
|
210 |
+
padding_mode='zeros',
|
211 |
+
weight_norm_type='none', weight_norm_params=None,
|
212 |
+
activation_norm_type='none', activation_norm_params=None,
|
213 |
+
nonlinearity='none', inplace_nonlinearity=False,
|
214 |
+
apply_noise=True, blur=True, order='CNA', demodulate=True,
|
215 |
+
eps=1e-8, style_dim=None, clamp=None, blur_kernel=(1, 3, 3, 1), output_scale=None, init_gain=1.0):
|
216 |
+
self.eps = eps
|
217 |
+
self.demodulate = demodulate
|
218 |
+
assert style_dim is not None
|
219 |
+
|
220 |
+
super().__init__(in_channels, out_channels, kernel_size, stride,
|
221 |
+
padding, dilation, groups, bias, padding_mode,
|
222 |
+
weight_norm_type, weight_norm_params,
|
223 |
+
activation_norm_type, activation_norm_params,
|
224 |
+
nonlinearity, inplace_nonlinearity, apply_noise, blur,
|
225 |
+
order, 2, clamp, blur_kernel, output_scale, init_gain)
|
226 |
+
self.modulation = LinearBlock(style_dim, in_channels,
|
227 |
+
weight_norm_type=weight_norm_type,
|
228 |
+
weight_norm_params=weight_norm_params)
|
229 |
+
|
230 |
+
def _get_conv_layer(self, in_channels, out_channels, kernel_size, stride,
|
231 |
+
padding, dilation, groups, bias, padding_mode,
|
232 |
+
input_dim):
|
233 |
+
assert input_dim == 2
|
234 |
+
layer = ModulatedConv2d(
|
235 |
+
in_channels, out_channels, kernel_size, stride, padding,
|
236 |
+
dilation, groups, bias, padding_mode, self.demodulate, self.eps)
|
237 |
+
return layer
|
238 |
+
|
239 |
+
def forward(self, x, *cond_inputs, **kw_cond_inputs):
|
240 |
+
for layer in self.layers.values():
|
241 |
+
if getattr(layer, 'conditional', False):
|
242 |
+
# Layers that require conditional inputs.
|
243 |
+
assert len(cond_inputs) == 1
|
244 |
+
style = cond_inputs[0]
|
245 |
+
x = layer(
|
246 |
+
x, self.modulation(style), **kw_cond_inputs
|
247 |
+
)
|
248 |
+
else:
|
249 |
+
x = layer(x)
|
250 |
+
if self.clamp is not None and isinstance(layer, ModulatedConv2d):
|
251 |
+
x.clamp_(max=self.clamp)
|
252 |
+
return x
|
253 |
+
|
254 |
+
def __repr__(self):
|
255 |
+
main_str = self._get_name() + '('
|
256 |
+
child_lines = []
|
257 |
+
for name, layer in self.layers.items():
|
258 |
+
mod_str = repr(layer)
|
259 |
+
if name == 'conv' and self.weight_norm_type != 'none' and \
|
260 |
+
self.weight_norm_type != '':
|
261 |
+
mod_str = mod_str[:-1] + \
|
262 |
+
', weight_norm={}'.format(self.weight_norm_type) + \
|
263 |
+
', demodulate={}'.format(self.demodulate) + ')'
|
264 |
+
mod_str = self._addindent(mod_str, 2)
|
265 |
+
child_lines.append(mod_str)
|
266 |
+
child_lines.append(
|
267 |
+
self._addindent('Modulation(' + repr(self.modulation) + ')', 2)
|
268 |
+
)
|
269 |
+
if len(child_lines) == 1:
|
270 |
+
main_str += child_lines[0]
|
271 |
+
else:
|
272 |
+
main_str += '\n ' + '\n '.join(child_lines) + '\n'
|
273 |
+
|
274 |
+
main_str += ')'
|
275 |
+
return main_str
|
276 |
+
|
277 |
+
|
278 |
+
class ModulatedConv2d(nn.Module):
|
279 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride, padding,
|
280 |
+
dilation, groups, bias, padding_mode, demodulate=True,
|
281 |
+
eps=1e-8):
|
282 |
+
# in_channels, out_channels, kernel_size, stride, padding,
|
283 |
+
# dilation, groups, bias, padding_mode
|
284 |
+
assert dilation == 1 and groups == 1
|
285 |
+
|
286 |
+
super().__init__()
|
287 |
+
|
288 |
+
self.eps = eps
|
289 |
+
self.kernel_size = kernel_size
|
290 |
+
self.in_channels = in_channels
|
291 |
+
self.out_channels = out_channels
|
292 |
+
self.padding = padding
|
293 |
+
self.stride = stride
|
294 |
+
self.padding_mode = padding_mode
|
295 |
+
# kernel_size // 2
|
296 |
+
# assert self.padding == padding
|
297 |
+
|
298 |
+
self.weight = nn.Parameter(
|
299 |
+
torch.randn(out_channels, in_channels, kernel_size, kernel_size)
|
300 |
+
)
|
301 |
+
|
302 |
+
if bias:
|
303 |
+
self.bias = nn.Parameter(torch.Tensor(out_channels))
|
304 |
+
else:
|
305 |
+
# noinspection PyTypeChecker
|
306 |
+
self.register_parameter('bias', None)
|
307 |
+
|
308 |
+
# self.modulation = LinearBlock(style_dim, in_channels,
|
309 |
+
# weight_norm_type=weight_norm_type)
|
310 |
+
self.demodulate = demodulate
|
311 |
+
self.conditional = True
|
312 |
+
|
313 |
+
def forward(self, x, style, **_kwargs):
|
314 |
+
batch, in_channel, height, width = x.shape
|
315 |
+
|
316 |
+
# style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
|
317 |
+
# We assume the modulation layer is outside this module.
|
318 |
+
style = style.view(batch, 1, in_channel, 1, 1)
|
319 |
+
weight = self.weight.unsqueeze(0) * style
|
320 |
+
|
321 |
+
if self.demodulate:
|
322 |
+
demod = torch.rsqrt(
|
323 |
+
weight.pow(2).sum([2, 3, 4]) + self.eps)
|
324 |
+
weight = weight * demod.view(batch, self.out_channels, 1, 1, 1)
|
325 |
+
|
326 |
+
weight = weight.view(
|
327 |
+
batch * self.out_channels,
|
328 |
+
in_channel, self.kernel_size, self.kernel_size
|
329 |
+
)
|
330 |
+
if self.bias is not None:
|
331 |
+
bias = self.bias.repeat(batch)
|
332 |
+
else:
|
333 |
+
bias = self.bias
|
334 |
+
|
335 |
+
x = x.view(1, batch * in_channel, height, width)
|
336 |
+
|
337 |
+
if self.padding_mode != 'zeros':
|
338 |
+
x = F.pad(x, self._reversed_padding_repeated_twice,
|
339 |
+
mode=self.padding_mode)
|
340 |
+
padding = (0, 0)
|
341 |
+
else:
|
342 |
+
padding = self.padding
|
343 |
+
|
344 |
+
if self.stride == 0.5:
|
345 |
+
weight = weight.view(
|
346 |
+
batch, self.out_channels, in_channel,
|
347 |
+
self.kernel_size, self.kernel_size
|
348 |
+
)
|
349 |
+
weight = weight.transpose(1, 2).reshape(
|
350 |
+
batch * in_channel, self.out_channels,
|
351 |
+
self.kernel_size, self.kernel_size
|
352 |
+
)
|
353 |
+
out = F.conv_transpose2d(
|
354 |
+
x, weight, bias, padding=padding, stride=2, groups=batch
|
355 |
+
)
|
356 |
+
|
357 |
+
elif self.stride == 2:
|
358 |
+
out = F.conv2d(
|
359 |
+
x, weight, bias, padding=padding, stride=2, groups=batch
|
360 |
+
)
|
361 |
+
|
362 |
+
else:
|
363 |
+
out = F.conv2d(x, weight, bias, padding=padding, groups=batch)
|
364 |
+
|
365 |
+
_, _, height, width = out.shape
|
366 |
+
out = out.view(batch, self.out_channels, height, width)
|
367 |
+
|
368 |
+
return out
|
369 |
+
|
370 |
+
def extra_repr(self):
|
371 |
+
s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
|
372 |
+
', stride={stride}')
|
373 |
+
if self.bias is None:
|
374 |
+
s += ', bias=False'
|
375 |
+
if self.padding_mode != 'zeros':
|
376 |
+
s += ', padding_mode={padding_mode}'
|
377 |
+
return s.format(**self.__dict__)
|
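The forward pass above follows the StyleGAN2-style modulate/demodulate scheme, folding the batch into the groups argument of a single convolution; a compact standalone sketch with assumed sizes:

import torch
import torch.nn.functional as F

B, Cin, Cout, K = 2, 8, 16, 3
x = torch.randn(B, Cin, 32, 32)
weight = torch.randn(Cout, Cin, K, K)
style = torch.randn(B, Cin)                               # output of the modulation MLP

w = weight.unsqueeze(0) * style.view(B, 1, Cin, 1, 1)     # modulate per sample
demod = torch.rsqrt(w.pow(2).sum([2, 3, 4]) + 1e-8)       # demodulate
w = (w * demod.view(B, Cout, 1, 1, 1)).view(B * Cout, Cin, K, K)
out = F.conv2d(x.view(1, B * Cin, 32, 32), w, padding=1, groups=B)
out = out.view(B, Cout, 32, 32)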
378 |
+
|
379 |
+
|
380 |
+
class LinearBlock(_BaseConvBlock):
|
381 |
+
r"""A Wrapper class that wraps ``torch.nn.Linear`` with normalization and
|
382 |
+
nonlinearity.
|
383 |
+
|
384 |
+
Args:
|
385 |
+
in_features (int): Number of channels in the input tensor.
|
386 |
+
out_features (int): Number of channels in the output tensor.
|
387 |
+
bias (bool, optional, default=True):
|
388 |
+
If ``True``, adds a learnable bias to the output.
|
389 |
+
weight_norm_type (str, optional, default='none'):
|
390 |
+
Type of weight normalization.
|
391 |
+
``'none'``, ``'spectral'``, ``'weight'``
|
392 |
+
or ``'weight_demod'``.
|
393 |
+
weight_norm_params (obj, optional, default=None):
|
394 |
+
Parameters of weight normalization.
|
395 |
+
If not ``None``, ``weight_norm_params.__dict__`` will be used as
|
396 |
+
keyword arguments when initializing weight normalization.
|
397 |
+
activation_norm_type (str, optional, default='none'):
|
398 |
+
Type of activation normalization.
|
399 |
+
``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
|
400 |
+
``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
|
401 |
+
``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
|
402 |
+
activation_norm_params (obj, optional, default=None):
|
403 |
+
Parameters of activation normalization.
|
404 |
+
If not ``None``, ``activation_norm_params.__dict__`` will be used as
|
405 |
+
keyword arguments when initializing activation normalization.
|
406 |
+
nonlinearity (str, optional, default='none'):
|
407 |
+
Type of nonlinear activation function.
|
408 |
+
``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
|
409 |
+
``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
|
410 |
+
inplace_nonlinearity (bool, optional, default=False): If ``True``,
|
411 |
+
set ``inplace=True`` when initializing the nonlinearity layer.
|
412 |
+
apply_noise (bool, optional, default=False): If ``True``, add
|
413 |
+
Gaussian noise with learnable magnitude after the
|
414 |
+
fully-connected layer.
|
415 |
+
order (str, optional, default='CNA'): Order of operations.
|
416 |
+
``'C'``: fully-connected,
|
417 |
+
``'N'``: normalization,
|
418 |
+
``'A'``: nonlinear activation.
|
419 |
+
For example, a block initialized with ``order='CNA'`` will
|
420 |
+
do convolution first, then normalization, then nonlinearity.
|
421 |
+
"""
|
422 |
+
|
423 |
+
def __init__(self, in_features, out_features, bias=True,
|
424 |
+
weight_norm_type='none', weight_norm_params=None,
|
425 |
+
activation_norm_type='none', activation_norm_params=None,
|
426 |
+
nonlinearity='none', inplace_nonlinearity=False,
|
427 |
+
apply_noise=False, order='CNA', clamp=None, blur_kernel=(1, 3, 3, 1), output_scale=None,
|
428 |
+
init_gain=1.0, **_kwargs):
|
429 |
+
if bool(_kwargs):
|
430 |
+
warnings.warn(f"Unused keyword arguments {_kwargs}")
|
431 |
+
super().__init__(in_features, out_features, None, None,
|
432 |
+
None, None, None, bias,
|
433 |
+
None, weight_norm_type, weight_norm_params,
|
434 |
+
activation_norm_type, activation_norm_params,
|
435 |
+
nonlinearity, inplace_nonlinearity, apply_noise,
|
436 |
+
False, order, 0, clamp, blur_kernel, output_scale,
|
437 |
+
init_gain)
|
438 |
+
|
439 |
+
|
440 |
+
class EmbeddingBlock(_BaseConvBlock):
|
441 |
+
def __init__(self, in_features, out_features, bias=True,
|
442 |
+
weight_norm_type='none', weight_norm_params=None,
|
443 |
+
activation_norm_type='none', activation_norm_params=None,
|
444 |
+
nonlinearity='none', inplace_nonlinearity=False,
|
445 |
+
apply_noise=False, order='CNA', clamp=None, output_scale=None,
|
446 |
+
init_gain=1.0, **_kwargs):
|
447 |
+
if bool(_kwargs):
|
448 |
+
warnings.warn(f"Unused keyword arguments {_kwargs}")
|
449 |
+
super().__init__(in_features, out_features, None, None,
|
450 |
+
None, None, None, bias,
|
451 |
+
None, weight_norm_type, weight_norm_params,
|
452 |
+
activation_norm_type, activation_norm_params,
|
453 |
+
nonlinearity, inplace_nonlinearity, apply_noise,
|
454 |
+
False, order, 0, clamp, None, output_scale,
|
455 |
+
init_gain)
|
456 |
+
|
457 |
+
def _get_conv_layer(self, in_channels, out_channels, kernel_size, stride,
|
458 |
+
padding, dilation, groups, bias, padding_mode,
|
459 |
+
input_dim):
|
460 |
+
assert input_dim == 0
|
461 |
+
return nn.Embedding(in_channels, out_channels)
|
462 |
+
|
463 |
+
|
464 |
+
class Embedding2dBlock(_BaseConvBlock):
|
465 |
+
def __init__(self, in_features, out_features, bias=True,
|
466 |
+
weight_norm_type='none', weight_norm_params=None,
|
467 |
+
activation_norm_type='none', activation_norm_params=None,
|
468 |
+
nonlinearity='none', inplace_nonlinearity=False,
|
469 |
+
apply_noise=False, order='CNA', clamp=None, output_scale=None,
|
470 |
+
init_gain=1.0, **_kwargs):
|
471 |
+
if bool(_kwargs):
|
472 |
+
warnings.warn(f"Unused keyword arguments {_kwargs}")
|
473 |
+
super().__init__(in_features, out_features, None, None,
|
474 |
+
None, None, None, bias,
|
475 |
+
None, weight_norm_type, weight_norm_params,
|
476 |
+
activation_norm_type, activation_norm_params,
|
477 |
+
nonlinearity, inplace_nonlinearity, apply_noise,
|
478 |
+
False, order, 0, clamp, None, output_scale,
|
479 |
+
init_gain)
|
480 |
+
|
481 |
+
def _get_conv_layer(self, in_channels, out_channels, kernel_size, stride,
|
482 |
+
padding, dilation, groups, bias, padding_mode,
|
483 |
+
input_dim):
|
484 |
+
assert input_dim == 0
|
485 |
+
return Embedding2d(in_channels, out_channels)
|
486 |
+
|
487 |
+
|
488 |
+
class Conv1dBlock(_BaseConvBlock):
|
489 |
+
r"""A Wrapper class that wraps ``torch.nn.Conv1d`` with normalization and
|
490 |
+
nonlinearity.
|
491 |
+
|
492 |
+
Args:
|
493 |
+
in_channels (int): Number of channels in the input tensor.
|
494 |
+
out_channels (int): Number of channels in the output tensor.
|
495 |
+
kernel_size (int or tuple): Size of the convolving kernel.
|
496 |
+
stride (int or float or tuple, optional, default=1):
|
497 |
+
Stride of the convolution.
|
498 |
+
padding (int or tuple, optional, default=0):
|
499 |
+
Zero-padding added to both sides of the input.
|
500 |
+
dilation (int or tuple, optional, default=1):
|
501 |
+
Spacing between kernel elements.
|
502 |
+
groups (int, optional, default=1): Number of blocked connections
|
503 |
+
from input channels to output channels.
|
504 |
+
bias (bool, optional, default=True):
|
505 |
+
If ``True``, adds a learnable bias to the output.
|
506 |
+
padding_mode (string, optional, default='zeros'): Type of padding:
|
507 |
+
``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
|
508 |
+
weight_norm_type (str, optional, default='none'):
|
509 |
+
Type of weight normalization.
|
510 |
+
``'none'``, ``'spectral'``, ``'weight'``
|
511 |
+
or ``'weight_demod'``.
|
512 |
+
weight_norm_params (obj, optional, default=None):
|
513 |
+
Parameters of weight normalization.
|
514 |
+
If not ``None``, ``weight_norm_params.__dict__`` will be used as
|
515 |
+
keyword arguments when initializing weight normalization.
|
516 |
+
activation_norm_type (str, optional, default='none'):
|
517 |
+
Type of activation normalization.
|
518 |
+
``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
|
519 |
+
``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
|
520 |
+
``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
|
521 |
+
activation_norm_params (obj, optional, default=None):
|
522 |
+
Parameters of activation normalization.
|
523 |
+
If not ``None``, ``activation_norm_params.__dict__`` will be used as
|
524 |
+
keyword arguments when initializing activation normalization.
|
525 |
+
nonlinearity (str, optional, default='none'):
|
526 |
+
Type of nonlinear activation function.
|
527 |
+
``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
|
528 |
+
``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
|
529 |
+
inplace_nonlinearity (bool, optional, default=False): If ``True``,
|
530 |
+
set ``inplace=True`` when initializing the nonlinearity layer.
|
531 |
+
apply_noise (bool, optional, default=False): If ``True``, adds
|
532 |
+
Gaussian noise with learnable magnitude to the convolution output.
|
533 |
+
order (str, optional, default='CNA'): Order of operations.
|
534 |
+
``'C'``: convolution,
|
535 |
+
``'N'``: normalization,
|
536 |
+
``'A'``: nonlinear activation.
|
537 |
+
For example, a block initialized with ``order='CNA'`` will
|
538 |
+
do convolution first, then normalization, then nonlinearity.
|
539 |
+
"""
|
540 |
+
|
541 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
|
542 |
+
padding=0, dilation=1, groups=1, bias=True,
|
543 |
+
padding_mode='zeros',
|
544 |
+
weight_norm_type='none', weight_norm_params=None,
|
545 |
+
activation_norm_type='none', activation_norm_params=None,
|
546 |
+
nonlinearity='none', inplace_nonlinearity=False,
|
547 |
+
apply_noise=False, blur=False, order='CNA', clamp=None, output_scale=None, init_gain=1.0, **_kwargs):
|
548 |
+
super().__init__(in_channels, out_channels, kernel_size, stride,
|
549 |
+
padding, dilation, groups, bias, padding_mode,
|
550 |
+
weight_norm_type, weight_norm_params,
|
551 |
+
activation_norm_type, activation_norm_params,
|
552 |
+
nonlinearity, inplace_nonlinearity, apply_noise,
|
553 |
+
blur, order, 1, clamp, None, output_scale, init_gain)
|
554 |
+
|
555 |
+
|
556 |
+
class Conv2dBlock(_BaseConvBlock):
|
557 |
+
r"""A Wrapper class that wraps ``torch.nn.Conv2d`` with normalization and
|
558 |
+
nonlinearity.
|
559 |
+
|
560 |
+
Args:
|
561 |
+
in_channels (int): Number of channels in the input tensor.
|
562 |
+
out_channels (int): Number of channels in the output tensor.
|
563 |
+
kernel_size (int or tuple): Size of the convolving kernel.
|
564 |
+
stride (int or float or tuple, optional, default=1):
|
565 |
+
Stride of the convolution.
|
566 |
+
padding (int or tuple, optional, default=0):
|
567 |
+
Zero-padding added to both sides of the input.
|
568 |
+
dilation (int or tuple, optional, default=1):
|
569 |
+
Spacing between kernel elements.
|
570 |
+
groups (int, optional, default=1): Number of blocked connections
|
571 |
+
from input channels to output channels.
|
572 |
+
bias (bool, optional, default=True):
|
573 |
+
If ``True``, adds a learnable bias to the output.
|
574 |
+
padding_mode (string, optional, default='zeros'): Type of padding:
|
575 |
+
``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
|
576 |
+
weight_norm_type (str, optional, default='none'):
|
577 |
+
Type of weight normalization.
|
578 |
+
``'none'``, ``'spectral'``, ``'weight'``
|
579 |
+
or ``'weight_demod'``.
|
580 |
+
weight_norm_params (obj, optional, default=None):
|
581 |
+
Parameters of weight normalization.
|
582 |
+
If not ``None``, ``weight_norm_params.__dict__`` will be used as
|
583 |
+
keyword arguments when initializing weight normalization.
|
584 |
+
activation_norm_type (str, optional, default='none'):
|
585 |
+
Type of activation normalization.
|
586 |
+
``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
|
587 |
+
``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
|
588 |
+
``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
|
589 |
+
activation_norm_params (obj, optional, default=None):
|
590 |
+
Parameters of activation normalization.
|
591 |
+
If not ``None``, ``activation_norm_params.__dict__`` will be used as
|
592 |
+
keyword arguments when initializing activation normalization.
|
593 |
+
nonlinearity (str, optional, default='none'):
|
594 |
+
Type of nonlinear activation function.
|
595 |
+
``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
|
596 |
+
``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
|
597 |
+
inplace_nonlinearity (bool, optional, default=False): If ``True``,
|
598 |
+
set ``inplace=True`` when initializing the nonlinearity layer.
|
599 |
+
apply_noise (bool, optional, default=False): If ``True``, adds
|
600 |
+
Gaussian noise with learnable magnitude to the convolution output.
|
601 |
+
order (str, optional, default='CNA'): Order of operations.
|
602 |
+
``'C'``: convolution,
|
603 |
+
``'N'``: normalization,
|
604 |
+
``'A'``: nonlinear activation.
|
605 |
+
For example, a block initialized with ``order='CNA'`` will
|
606 |
+
do convolution first, then normalization, then nonlinearity.
|
607 |
+
"""
|
608 |
+
|
609 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
|
610 |
+
padding=0, dilation=1, groups=1, bias=True,
|
611 |
+
padding_mode='zeros',
|
612 |
+
weight_norm_type='none', weight_norm_params=None,
|
613 |
+
activation_norm_type='none', activation_norm_params=None,
|
614 |
+
nonlinearity='none', inplace_nonlinearity=False,
|
615 |
+
apply_noise=False, blur=False, order='CNA', clamp=None, blur_kernel=(1, 3, 3, 1),
|
616 |
+
output_scale=None, init_gain=1.0):
|
617 |
+
super().__init__(in_channels, out_channels, kernel_size, stride,
|
618 |
+
padding, dilation, groups, bias, padding_mode,
|
619 |
+
weight_norm_type, weight_norm_params,
|
620 |
+
activation_norm_type, activation_norm_params,
|
621 |
+
nonlinearity, inplace_nonlinearity,
|
622 |
+
apply_noise, blur, order, 2, clamp, blur_kernel, output_scale, init_gain)
|
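A usage sketch for Conv2dBlock; it assumes the class is re-exported by imaginaire/layers/__init__.py (included in this upload), and the parameter choices are illustrative only.

import torch
from imaginaire.layers import Conv2dBlock   # assumed re-export

block = Conv2dBlock(3, 64, kernel_size=3, padding=1,
                    weight_norm_type='spectral',
                    activation_norm_type='instance',
                    nonlinearity='leakyrelu',
                    order='CNA')                 # conv -> instance norm -> LeakyReLU
y = block(torch.randn(1, 3, 128, 128))           # -> (1, 64, 128, 128)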
623 |
+
|
624 |
+
|
625 |
+
class Conv3dBlock(_BaseConvBlock):
|
626 |
+
r"""A Wrapper class that wraps ``torch.nn.Conv3d`` with normalization and
|
627 |
+
nonlinearity.
|
628 |
+
|
629 |
+
Args:
|
630 |
+
in_channels (int): Number of channels in the input tensor.
|
631 |
+
out_channels (int): Number of channels in the output tensor.
|
632 |
+
kernel_size (int or tuple): Size of the convolving kernel.
|
633 |
+
stride (int or float or tuple, optional, default=1):
|
634 |
+
Stride of the convolution.
|
635 |
+
padding (int or tuple, optional, default=0):
|
636 |
+
Zero-padding added to both sides of the input.
|
637 |
+
dilation (int or tuple, optional, default=1):
|
638 |
+
Spacing between kernel elements.
|
639 |
+
groups (int, optional, default=1): Number of blocked connections
|
640 |
+
from input channels to output channels.
|
641 |
+
bias (bool, optional, default=True):
|
642 |
+
If ``True``, adds a learnable bias to the output.
|
643 |
+
padding_mode (string, optional, default='zeros'): Type of padding:
|
644 |
+
``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
|
645 |
+
weight_norm_type (str, optional, default='none'):
|
646 |
+
Type of weight normalization.
|
647 |
+
``'none'``, ``'spectral'``, ``'weight'``
|
648 |
+
or ``'weight_demod'``.
|
649 |
+
weight_norm_params (obj, optional, default=None):
|
650 |
+
Parameters of weight normalization.
|
651 |
+
If not ``None``, ``weight_norm_params.__dict__`` will be used as
|
652 |
+
keyword arguments when initializing weight normalization.
|
653 |
+
activation_norm_type (str, optional, default='none'):
|
654 |
+
Type of activation normalization.
|
655 |
+
``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
|
656 |
+
``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
|
657 |
+
``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
|
658 |
+
activation_norm_params (obj, optional, default=None):
|
659 |
+
Parameters of activation normalization.
|
660 |
+
If not ``None``, ``activation_norm_params.__dict__`` will be used as
|
661 |
+
keyword arguments when initializing activation normalization.
|
662 |
+
nonlinearity (str, optional, default='none'):
|
663 |
+
Type of nonlinear activation function.
|
664 |
+
``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
|
665 |
+
``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
|
666 |
+
inplace_nonlinearity (bool, optional, default=False): If ``True``,
|
667 |
+
set ``inplace=True`` when initializing the nonlinearity layer.
|
668 |
+
apply_noise (bool, optional, default=False): If ``True``, adds
|
669 |
+
Gaussian noise with learnable magnitude to the convolution output.
|
670 |
+
order (str, optional, default='CNA'): Order of operations.
|
671 |
+
``'C'``: convolution,
|
672 |
+
``'N'``: normalization,
|
673 |
+
``'A'``: nonlinear activation.
|
674 |
+
For example, a block initialized with ``order='CNA'`` will
|
675 |
+
do convolution first, then normalization, then nonlinearity.
|
676 |
+
"""
|
677 |
+
|
678 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
|
679 |
+
padding=0, dilation=1, groups=1, bias=True,
|
680 |
+
padding_mode='zeros',
|
681 |
+
weight_norm_type='none', weight_norm_params=None,
|
682 |
+
activation_norm_type='none', activation_norm_params=None,
|
683 |
+
nonlinearity='none', inplace_nonlinearity=False,
|
684 |
+
apply_noise=False, blur=False, order='CNA', clamp=None, blur_kernel=(1, 3, 3, 1), output_scale=None,
|
685 |
+
init_gain=1.0):
|
686 |
+
super().__init__(in_channels, out_channels, kernel_size, stride,
|
687 |
+
padding, dilation, groups, bias, padding_mode,
|
688 |
+
weight_norm_type, weight_norm_params,
|
689 |
+
activation_norm_type, activation_norm_params,
|
690 |
+
nonlinearity, inplace_nonlinearity,
|
691 |
+
apply_noise, blur, order, 3, clamp, blur_kernel, output_scale, init_gain)
|
692 |
+
|
693 |
+
|
694 |
+
class _BaseHyperConvBlock(_BaseConvBlock):
|
695 |
+
r"""An abstract wrapper class that wraps a hyper convolutional layer
|
696 |
+
with normalization and nonlinearity.
|
697 |
+
"""
|
698 |
+
|
699 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride,
|
700 |
+
padding, dilation, groups, bias,
|
701 |
+
padding_mode,
|
702 |
+
weight_norm_type, weight_norm_params,
|
703 |
+
activation_norm_type, activation_norm_params,
|
704 |
+
nonlinearity, inplace_nonlinearity, apply_noise, blur,
|
705 |
+
is_hyper_conv, is_hyper_norm, order, input_dim, clamp=None, blur_kernel=(1, 3, 3, 1),
|
706 |
+
output_scale=None, init_gain=1.0):
|
707 |
+
self.is_hyper_conv = is_hyper_conv
|
708 |
+
if is_hyper_conv:
|
709 |
+
weight_norm_type = 'none'
|
710 |
+
if is_hyper_norm:
|
711 |
+
activation_norm_type = 'hyper_' + activation_norm_type
|
712 |
+
super().__init__(in_channels, out_channels, kernel_size, stride,
|
713 |
+
padding, dilation, groups, bias, padding_mode,
|
714 |
+
weight_norm_type, weight_norm_params,
|
715 |
+
activation_norm_type, activation_norm_params,
|
716 |
+
nonlinearity, inplace_nonlinearity, apply_noise, blur,
|
717 |
+
order, input_dim, clamp, blur_kernel, output_scale, init_gain)
|
718 |
+
|
719 |
+
def _get_conv_layer(self, in_channels, out_channels, kernel_size, stride,
|
720 |
+
padding, dilation, groups, bias, padding_mode,
|
721 |
+
input_dim):
|
722 |
+
if input_dim == 0:
|
723 |
+
raise ValueError('HyperLinearBlock is not supported.')
|
724 |
+
else:
|
725 |
+
name = 'HyperConv' if self.is_hyper_conv else 'nn.Conv'
|
726 |
+
layer_type = eval(name + '%dd' % input_dim)
|
727 |
+
layer = layer_type(
|
728 |
+
in_channels, out_channels, kernel_size, stride, padding,
|
729 |
+
dilation, groups, bias, padding_mode)
|
730 |
+
return layer
|
731 |
+
|
732 |
+
|
733 |
+
class HyperConv2dBlock(_BaseHyperConvBlock):
|
734 |
+
r"""A Wrapper class that wraps ``HyperConv2d`` with normalization and
|
735 |
+
nonlinearity.
|
736 |
+
|
737 |
+
Args:
|
738 |
+
in_channels (int): Number of channels in the input tensor.
|
739 |
+
out_channels (int): Number of channels in the output tensor.
|
740 |
+
kernel_size (int or tuple): Size of the convolving kernel.
|
741 |
+
stride (int or float or tuple, optional, default=1):
|
742 |
+
Stride of the convolution.
|
743 |
+
padding (int or tuple, optional, default=0):
|
744 |
+
Zero-padding added to both sides of the input.
|
745 |
+
dilation (int or tuple, optional, default=1):
|
746 |
+
Spacing between kernel elements.
|
747 |
+
groups (int, optional, default=1): Number of blocked connections
|
748 |
+
from input channels to output channels.
|
749 |
+
bias (bool, optional, default=True):
|
750 |
+
If ``True``, adds a learnable bias to the output.
|
751 |
+
padding_mode (string, optional, default='zeros'): Type of padding:
|
752 |
+
``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
|
753 |
+
weight_norm_type (str, optional, default='none'):
|
754 |
+
Type of weight normalization.
|
755 |
+
``'none'``, ``'spectral'``, ``'weight'``
|
756 |
+
or ``'weight_demod'``.
|
757 |
+
weight_norm_params (obj, optional, default=None):
|
758 |
+
Parameters of weight normalization.
|
759 |
+
If not ``None``, ``weight_norm_params.__dict__`` will be used as
|
760 |
+
keyword arguments when initializing weight normalization.
|
761 |
+
activation_norm_type (str, optional, default='none'):
|
762 |
+
Type of activation normalization.
|
763 |
+
``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
|
764 |
+
``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
|
765 |
+
``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
|
766 |
+
activation_norm_params (obj, optional, default=None):
|
767 |
+
Parameters of activation normalization.
|
768 |
+
If not ``None``, ``activation_norm_params.__dict__`` will be used as
|
769 |
+
keyword arguments when initializing activation normalization.
|
770 |
+
is_hyper_conv (bool, optional, default=False): If ``True``, use
|
771 |
+
``HyperConv2d``, otherwise use ``torch.nn.Conv2d``.
|
772 |
+
is_hyper_norm (bool, optional, default=False): If ``True``, use
|
773 |
+
hyper normalizations.
|
774 |
+
nonlinearity (str, optional, default='none'):
|
775 |
+
Type of nonlinear activation function.
|
776 |
+
``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
|
777 |
+
``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
|
778 |
+
inplace_nonlinearity (bool, optional, default=False): If ``True``,
|
779 |
+
set ``inplace=True`` when initializing the nonlinearity layer.
|
780 |
+
apply_noise (bool, optional, default=False): If ``True``, adds
|
781 |
+
Gaussian noise with learnable magnitude to the convolution output.
|
782 |
+
order (str, optional, default='CNA'): Order of operations.
|
783 |
+
``'C'``: convolution,
|
784 |
+
``'N'``: normalization,
|
785 |
+
``'A'``: nonlinear activation.
|
786 |
+
For example, a block initialized with ``order='CNA'`` will
|
787 |
+
do convolution first, then normalization, then nonlinearity.
|
788 |
+
"""
|
789 |
+
|
790 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
|
791 |
+
padding=0, dilation=1, groups=1, bias=True,
|
792 |
+
padding_mode='zeros',
|
793 |
+
weight_norm_type='none', weight_norm_params=None,
|
794 |
+
activation_norm_type='none', activation_norm_params=None,
|
795 |
+
is_hyper_conv=False, is_hyper_norm=False,
|
796 |
+
nonlinearity='none', inplace_nonlinearity=False,
|
797 |
+
apply_noise=False, blur=False, order='CNA', clamp=None):
|
798 |
+
super().__init__(in_channels, out_channels, kernel_size, stride,
|
799 |
+
padding, dilation, groups, bias, padding_mode,
|
800 |
+
weight_norm_type, weight_norm_params,
|
801 |
+
activation_norm_type, activation_norm_params,
|
802 |
+
nonlinearity, inplace_nonlinearity, apply_noise, blur,
|
803 |
+
is_hyper_conv, is_hyper_norm, order, 2, clamp)
|
804 |
+
|
805 |
+
|
806 |
+
class HyperConv2d(nn.Module):
|
807 |
+
r"""Hyper Conv2d initialization.
|
808 |
+
|
809 |
+
Args:
|
810 |
+
in_channels (int): Dummy parameter.
|
811 |
+
out_channels (int): Dummy parameter.
|
812 |
+
kernel_size (int or tuple): Dummy parameter.
|
813 |
+
stride (int or float or tuple, optional, default=1):
|
814 |
+
Stride of the convolution. Default: 1
|
815 |
+
padding (int or tuple, optional, default=0):
|
816 |
+
Zero-padding added to both sides of the input.
|
817 |
+
padding_mode (string, optional, default='zeros'):
|
818 |
+
``'zeros'``, ``'reflect'``, ``'replicate'``
|
819 |
+
or ``'circular'``.
|
820 |
+
dilation (int or tuple, optional, default=1):
|
821 |
+
Spacing between kernel elements.
|
822 |
+
groups (int, optional, default=1): Number of blocked connections
|
823 |
+
from input channels to output channels.
|
824 |
+
bias (bool, optional, default=True): If ``True``,
|
825 |
+
adds a learnable bias to the output.
|
826 |
+
"""
|
827 |
+
|
828 |
+
def __init__(self, in_channels=0, out_channels=0, kernel_size=3,
|
829 |
+
stride=1, padding=1, dilation=1, groups=1, bias=True,
|
830 |
+
padding_mode='zeros'):
|
831 |
+
super().__init__()
|
832 |
+
self.stride = stride
|
833 |
+
self.padding = padding
|
834 |
+
self.dilation = dilation
|
835 |
+
self.groups = groups
|
836 |
+
self.use_bias = bias
|
837 |
+
self.padding_mode = padding_mode
|
838 |
+
self.conditional = True
|
839 |
+
|
840 |
+
def forward(self, x, *args, conv_weights=(None, None), **kwargs):
|
841 |
+
r"""Hyper Conv2d forward. Convolve x using the provided weight and bias.
|
842 |
+
|
843 |
+
Args:
|
844 |
+
x (N x C x H x W tensor): Input tensor.
|
845 |
+
conv_weights (N x C2 x C1 x k x k tensor or list of tensors):
|
846 |
+
Convolution weights or [weight, bias].
|
847 |
+
Returns:
|
848 |
+
y (N x C2 x H x W tensor): Output tensor.
|
849 |
+
"""
|
850 |
+
if conv_weights is None:
|
851 |
+
conv_weight, conv_bias = None, None
|
852 |
+
elif isinstance(conv_weights, torch.Tensor):
|
853 |
+
conv_weight, conv_bias = conv_weights, None
|
854 |
+
else:
|
855 |
+
conv_weight, conv_bias = conv_weights
|
856 |
+
|
857 |
+
if conv_weight is None:
|
858 |
+
return x
|
859 |
+
if conv_bias is None:
|
860 |
+
if self.use_bias:
|
861 |
+
raise ValueError('bias not provided but set to true during '
|
862 |
+
'initialization')
|
863 |
+
conv_bias = [None] * x.size(0)
|
864 |
+
if self.padding_mode != 'zeros':
|
865 |
+
x = F.pad(x, [self.padding] * 4, mode=self.padding_mode)
|
866 |
+
padding = 0
|
867 |
+
else:
|
868 |
+
padding = self.padding
|
869 |
+
|
870 |
+
y = None
|
871 |
+
# noinspection PyArgumentList
|
872 |
+
for i in range(x.size(0)):
|
873 |
+
if self.stride >= 1:
|
874 |
+
yi = F.conv2d(x[i: i + 1],
|
875 |
+
weight=conv_weight[i], bias=conv_bias[i],
|
876 |
+
stride=self.stride, padding=padding,
|
877 |
+
dilation=self.dilation, groups=self.groups)
|
878 |
+
else:
|
879 |
+
yi = F.conv_transpose2d(x[i: i + 1], weight=conv_weight[i],
|
880 |
+
bias=conv_bias[i], padding=self.padding,
|
881 |
+
stride=int(1 / self.stride),
|
882 |
+
dilation=self.dilation,
|
883 |
+
output_padding=self.padding,
|
884 |
+
groups=self.groups)
|
885 |
+
y = torch.cat([y, yi]) if y is not None else yi
|
886 |
+
return y
|
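HyperConv2d convolves each sample with weights predicted elsewhere (e.g. by a hypernetwork); a hedged sketch of calling it directly with random per-sample weights, with all shapes being assumptions:

import torch
from imaginaire.layers.conv import HyperConv2d   # assumed import path

B, Cin, Cout, K = 2, 4, 8, 3
conv = HyperConv2d(stride=1, padding=1, bias=False)
x = torch.randn(B, Cin, 16, 16)
w = torch.randn(B, Cout, Cin, K, K)               # one weight tensor per sample
y = conv(x, conv_weights=w)                       # -> (B, Cout, 16, 16)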
887 |
+
|
888 |
+
|
889 |
+
class _BasePartialConvBlock(_BaseConvBlock):
|
890 |
+
r"""An abstract wrapper class that wraps a partial convolutional layer
|
891 |
+
with normalization and nonlinearity.
|
892 |
+
"""
|
893 |
+
|
894 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride,
|
895 |
+
padding, dilation, groups, bias, padding_mode,
|
896 |
+
weight_norm_type, weight_norm_params,
|
897 |
+
activation_norm_type, activation_norm_params,
|
898 |
+
nonlinearity, inplace_nonlinearity,
|
899 |
+
multi_channel, return_mask,
|
900 |
+
apply_noise, order, input_dim, clamp=None, blur_kernel=(1, 3, 3, 1), output_scale=None, init_gain=1.0):
|
901 |
+
self.multi_channel = multi_channel
|
902 |
+
self.return_mask = return_mask
|
903 |
+
self.partial_conv = True
|
904 |
+
super().__init__(in_channels, out_channels, kernel_size, stride,
|
905 |
+
padding, dilation, groups, bias, padding_mode,
|
906 |
+
weight_norm_type, weight_norm_params,
|
907 |
+
activation_norm_type, activation_norm_params,
|
908 |
+
nonlinearity, inplace_nonlinearity, apply_noise,
|
909 |
+
False, order, input_dim, clamp, blur_kernel, output_scale, init_gain)
|
910 |
+
|
911 |
+
def _get_conv_layer(self, in_channels, out_channels, kernel_size, stride,
|
912 |
+
padding, dilation, groups, bias, padding_mode,
|
913 |
+
input_dim):
|
914 |
+
if input_dim == 2:
|
915 |
+
layer_type = PartialConv2d
|
916 |
+
elif input_dim == 3:
|
917 |
+
layer_type = PartialConv3d
|
918 |
+
else:
|
919 |
+
raise ValueError('Partial conv only supports 2D and 3D conv now.')
|
920 |
+
layer = layer_type(
|
921 |
+
in_channels, out_channels, kernel_size, stride, padding,
|
922 |
+
dilation, groups, bias, padding_mode,
|
923 |
+
multi_channel=self.multi_channel, return_mask=self.return_mask)
|
924 |
+
return layer
|
925 |
+
|
926 |
+
def forward(self, x, *cond_inputs, mask_in=None, **kw_cond_inputs):
|
927 |
+
r"""
|
928 |
+
|
929 |
+
Args:
|
930 |
+
x (tensor): Input tensor.
|
931 |
+
cond_inputs (list of tensors) : Conditional input tensors.
|
932 |
+
mask_in (tensor, optional, default=``None``) If not ``None``,
|
933 |
+
it masks the valid input region.
|
934 |
+
kw_cond_inputs (dict) : Keyword conditional inputs.
|
935 |
+
Returns:
|
936 |
+
(tuple):
|
937 |
+
- x (tensor): Output tensor.
|
938 |
+
- mask_out (tensor, optional): Masks the valid output region.
|
939 |
+
"""
|
940 |
+
mask_out = None
|
941 |
+
for layer in self.layers.values():
|
942 |
+
if getattr(layer, 'conditional', False):
|
943 |
+
x = layer(x, *cond_inputs, **kw_cond_inputs)
|
944 |
+
elif getattr(layer, 'partial_conv', False):
|
945 |
+
x = layer(x, mask_in=mask_in, **kw_cond_inputs)
|
946 |
+
if type(x) == tuple:
|
947 |
+
x, mask_out = x
|
948 |
+
else:
|
949 |
+
x = layer(x)
|
950 |
+
|
951 |
+
if mask_out is not None:
|
952 |
+
return x, mask_out
|
953 |
+
return x
|
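The partial-convolution blocks thread a validity mask through the convolution and return the updated mask; a usage sketch for PartialConv2dBlock (defined just below; import path assumed):

import torch
from imaginaire.layers.conv import PartialConv2dBlock   # assumed import path

block = PartialConv2dBlock(3, 16, kernel_size=3, padding=1,
                           nonlinearity='relu', return_mask=True)
x = torch.randn(1, 3, 64, 64)
mask = torch.ones(1, 1, 64, 64)                  # 1 = valid pixel, 0 = hole
mask[:, :, 20:40, 20:40] = 0
y, mask_out = block(x, mask_in=mask)             # features and the propagated mask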
954 |
+
|
955 |
+
|
956 |
+
class PartialConv2dBlock(_BasePartialConvBlock):
|
957 |
+
r"""A Wrapper class that wraps ``PartialConv2d`` with normalization and
|
958 |
+
nonlinearity.
|
959 |
+
|
960 |
+
Args:
|
961 |
+
in_channels (int): Number of channels in the input tensor.
|
962 |
+
out_channels (int): Number of channels in the output tensor.
|
963 |
+
kernel_size (int or tuple): Size of the convolving kernel.
|
964 |
+
stride (int or float or tuple, optional, default=1):
|
965 |
+
Stride of the convolution.
|
966 |
+
padding (int or tuple, optional, default=0):
|
967 |
+
Zero-padding added to both sides of the input.
|
968 |
+
dilation (int or tuple, optional, default=1):
|
969 |
+
Spacing between kernel elements.
|
970 |
+
groups (int, optional, default=1): Number of blocked connections
|
971 |
+
from input channels to output channels.
|
972 |
+
bias (bool, optional, default=True):
|
973 |
+
If ``True``, adds a learnable bias to the output.
|
974 |
+
padding_mode (string, optional, default='zeros'): Type of padding:
|
975 |
+
``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
|
976 |
+
weight_norm_type (str, optional, default='none'):
|
977 |
+
Type of weight normalization.
|
978 |
+
``'none'``, ``'spectral'``, ``'weight'``
|
979 |
+
or ``'weight_demod'``.
|
980 |
+
weight_norm_params (obj, optional, default=None):
|
981 |
+
Parameters of weight normalization.
|
982 |
+
If not ``None``, ``weight_norm_params.__dict__`` will be used as
|
983 |
+
keyword arguments when initializing weight normalization.
|
984 |
+
activation_norm_type (str, optional, default='none'):
|
985 |
+
Type of activation normalization.
|
986 |
+
``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
|
987 |
+
``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
|
988 |
+
``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
|
989 |
+
activation_norm_params (obj, optional, default=None):
|
990 |
+
Parameters of activation normalization.
|
991 |
+
If not ``None``, ``activation_norm_params.__dict__`` will be used as
|
992 |
+
keyword arguments when initializing activation normalization.
|
993 |
+
nonlinearity (str, optional, default='none'):
|
994 |
+
Type of nonlinear activation function.
|
995 |
+
``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
|
996 |
+
``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
|
997 |
+
inplace_nonlinearity (bool, optional, default=False): If ``True``,
|
998 |
+
set ``inplace=True`` when initializing the nonlinearity layer.
|
999 |
+
apply_noise (bool, optional, default=False): If ``True``, adds
|
1000 |
+
Gaussian noise with learnable magnitude to the convolution output.
|
1001 |
+
order (str, optional, default='CNA'): Order of operations.
|
1002 |
+
``'C'``: convolution,
|
1003 |
+
``'N'``: normalization,
|
1004 |
+
``'A'``: nonlinear activation.
|
1005 |
+
For example, a block initialized with ``order='CNA'`` will
|
1006 |
+
do convolution first, then normalization, then nonlinearity.
|
1007 |
+
multi_channel (bool, optional, default=False): If ``True``, use
|
1008 |
+
different masks for different channels.
|
1009 |
+
return_mask (bool, optional, default=True): If ``True``, the
|
1010 |
+
forward call also returns a new mask.
|
1011 |
+
"""
|
1012 |
+
|
1013 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
|
1014 |
+
padding=0, dilation=1, groups=1, bias=True,
|
1015 |
+
padding_mode='zeros',
|
1016 |
+
weight_norm_type='none', weight_norm_params=None,
|
1017 |
+
activation_norm_type='none', activation_norm_params=None,
|
1018 |
+
nonlinearity='none', inplace_nonlinearity=False,
|
1019 |
+
multi_channel=False, return_mask=True,
|
1020 |
+
apply_noise=False, order='CNA', clamp=None):
|
1021 |
+
super().__init__(in_channels, out_channels, kernel_size, stride,
|
1022 |
+
padding, dilation, groups, bias, padding_mode,
|
1023 |
+
weight_norm_type, weight_norm_params,
|
1024 |
+
activation_norm_type, activation_norm_params,
|
1025 |
+
nonlinearity, inplace_nonlinearity,
|
1026 |
+
multi_channel, return_mask, apply_noise, order, 2,
|
1027 |
+
clamp)
|
1028 |
+
|
1029 |
+
|
1030 |
+
class PartialConv3dBlock(_BasePartialConvBlock):
|
1031 |
+
r"""A Wrapper class that wraps ``PartialConv3d`` with normalization and
|
1032 |
+
nonlinearity.
|
1033 |
+
|
1034 |
+
Args:
|
1035 |
+
in_channels (int): Number of channels in the input tensor.
|
1036 |
+
out_channels (int): Number of channels in the output tensor.
|
1037 |
+
kernel_size (int or tuple): Size of the convolving kernel.
|
1038 |
+
stride (int or float or tuple, optional, default=1):
|
1039 |
+
Stride of the convolution.
|
1040 |
+
padding (int or tuple, optional, default=0):
|
1041 |
+
Zero-padding added to both sides of the input.
|
1042 |
+
dilation (int or tuple, optional, default=1):
|
1043 |
+
Spacing between kernel elements.
|
1044 |
+
groups (int, optional, default=1): Number of blocked connections
|
1045 |
+
from input channels to output channels.
|
1046 |
+
bias (bool, optional, default=True):
|
1047 |
+
If ``True``, adds a learnable bias to the output.
|
1048 |
+
padding_mode (string, optional, default='zeros'): Type of padding:
|
1049 |
+
``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
|
1050 |
+
weight_norm_type (str, optional, default='none'):
|
1051 |
+
Type of weight normalization.
|
1052 |
+
``'none'``, ``'spectral'``, ``'weight'``
|
1053 |
+
or ``'weight_demod'``.
|
1054 |
+
weight_norm_params (obj, optional, default=None):
|
1055 |
+
Parameters of weight normalization.
|
1056 |
+
If not ``None``, ``weight_norm_params.__dict__`` will be used as
|
1057 |
+
keyword arguments when initializing weight normalization.
|
1058 |
+
activation_norm_type (str, optional, default='none'):
|
1059 |
+
Type of activation normalization.
|
1060 |
+
``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
|
1061 |
+
``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
|
1062 |
+
``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
|
1063 |
+
activation_norm_params (obj, optional, default=None):
|
1064 |
+
Parameters of activation normalization.
|
1065 |
+
If not ``None``, ``activation_norm_params.__dict__`` will be used as
|
1066 |
+
keyword arguments when initializing activation normalization.
|
1067 |
+
nonlinearity (str, optional, default='none'):
|
1068 |
+
Type of nonlinear activation function.
|
1069 |
+
``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
|
1070 |
+
``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
|
1071 |
+
inplace_nonlinearity (bool, optional, default=False): If ``True``,
|
1072 |
+
set ``inplace=True`` when initializing the nonlinearity layer.
|
1073 |
+
apply_noise (bool, optional, default=False): If ``True``, adds
|
1074 |
+
Gaussian noise with learnable magnitude to the convolution output.
|
1075 |
+
order (str, optional, default='CNA'): Order of operations.
|
1076 |
+
``'C'``: convolution,
|
1077 |
+
``'N'``: normalization,
|
1078 |
+
``'A'``: nonlinear activation.
|
1079 |
+
For example, a block initialized with ``order='CNA'`` will
|
1080 |
+
do convolution first, then normalization, then nonlinearity.
|
1081 |
+
multi_channel (bool, optional, default=False): If ``True``, use
|
1082 |
+
different masks for different channels.
|
1083 |
+
return_mask (bool, optional, default=True): If ``True``, the
|
1084 |
+
forward call also returns a new mask.
|
1085 |
+
"""
|
1086 |
+
|
1087 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
|
1088 |
+
padding=0, dilation=1, groups=1, bias=True,
|
1089 |
+
padding_mode='zeros',
|
1090 |
+
weight_norm_type='none', weight_norm_params=None,
|
1091 |
+
activation_norm_type='none', activation_norm_params=None,
|
1092 |
+
nonlinearity='none', inplace_nonlinearity=False,
|
1093 |
+
multi_channel=False, return_mask=True,
|
1094 |
+
apply_noise=False, order='CNA', clamp=None):
|
1095 |
+
super().__init__(in_channels, out_channels, kernel_size, stride,
|
1096 |
+
padding, dilation, groups, bias, padding_mode,
|
1097 |
+
weight_norm_type, weight_norm_params,
|
1098 |
+
activation_norm_type, activation_norm_params,
|
1099 |
+
nonlinearity, inplace_nonlinearity,
|
1100 |
+
multi_channel, return_mask, apply_noise, order, 3,
|
1101 |
+
clamp)
|
1102 |
+
|
1103 |
+
|
1104 |
+
class _MultiOutBaseConvBlock(_BaseConvBlock):
|
1105 |
+
r"""An abstract wrapper class that wraps a hyper convolutional layer with
|
1106 |
+
normalization and nonlinearity. It can return multiple outputs, if some
|
1107 |
+
layers in the block return more than one output.
|
1108 |
+
"""
|
1109 |
+
|
1110 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode,
|
1111 |
+
weight_norm_type, weight_norm_params, activation_norm_type, activation_norm_params, nonlinearity,
|
1112 |
+
inplace_nonlinearity, apply_noise, blur, order, input_dim, clamp=None, blur_kernel=(1, 3, 3, 1),
|
1113 |
+
output_scale=None, init_gain=1.0):
|
1114 |
+
super().__init__(in_channels, out_channels, kernel_size, stride,
|
1115 |
+
padding, dilation, groups, bias, padding_mode,
|
1116 |
+
weight_norm_type, weight_norm_params,
|
1117 |
+
activation_norm_type, activation_norm_params,
|
1118 |
+
nonlinearity, inplace_nonlinearity,
|
1119 |
+
apply_noise, blur, order, input_dim, clamp, blur_kernel, output_scale, init_gain)
|
1120 |
+
self.multiple_outputs = True
|
1121 |
+
|
1122 |
+
def forward(self, x, *cond_inputs, **kw_cond_inputs):
|
1123 |
+
r"""
|
1124 |
+
|
1125 |
+
Args:
|
1126 |
+
x (tensor): Input tensor.
|
1127 |
+
cond_inputs (list of tensors) : Conditional input tensors.
|
1128 |
+
kw_cond_inputs (dict) : Keyword conditional inputs.
|
1129 |
+
Returns:
|
1130 |
+
(tuple):
|
1131 |
+
- x (tensor): Main output tensor.
|
1132 |
+
- other_outputs (list of tensors): Other output tensors.
|
1133 |
+
"""
|
1134 |
+
other_outputs = []
|
1135 |
+
for layer in self.layers.values():
|
1136 |
+
if getattr(layer, 'conditional', False):
|
1137 |
+
x = layer(x, *cond_inputs, **kw_cond_inputs)
|
1138 |
+
if getattr(layer, 'multiple_outputs', False):
|
1139 |
+
x, other_output = layer(x)
|
1140 |
+
other_outputs.append(other_output)
|
1141 |
+
else:
|
1142 |
+
x = layer(x)
|
1143 |
+
return (x, *other_outputs)
|
1144 |
+
|
1145 |
+
|
1146 |
+
class MultiOutConv2dBlock(_MultiOutBaseConvBlock):
|
1147 |
+
r"""A Wrapper class that wraps ``torch.nn.Conv2d`` with normalization and
|
1148 |
+
nonlinearity. It can return multiple outputs, if some layers in the block
|
1149 |
+
return more than one output.
|
1150 |
+
|
1151 |
+
Args:
|
1152 |
+
in_channels (int): Number of channels in the input tensor.
|
1153 |
+
out_channels (int): Number of channels in the output tensor.
|
1154 |
+
kernel_size (int or tuple): Size of the convolving kernel.
|
1155 |
+
stride (int or float or tuple, optional, default=1):
|
1156 |
+
Stride of the convolution.
|
1157 |
+
padding (int or tuple, optional, default=0):
|
1158 |
+
Zero-padding added to both sides of the input.
|
1159 |
+
dilation (int or tuple, optional, default=1):
|
1160 |
+
Spacing between kernel elements.
|
1161 |
+
groups (int, optional, default=1): Number of blocked connections
|
1162 |
+
from input channels to output channels.
|
1163 |
+
bias (bool, optional, default=True):
|
1164 |
+
If ``True``, adds a learnable bias to the output.
|
1165 |
+
padding_mode (string, optional, default='zeros'): Type of padding:
|
1166 |
+
``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
|
1167 |
+
weight_norm_type (str, optional, default='none'):
|
1168 |
+
Type of weight normalization.
|
1169 |
+
``'none'``, ``'spectral'``, ``'weight'``
|
1170 |
+
or ``'weight_demod'``.
|
1171 |
+
weight_norm_params (obj, optional, default=None):
|
1172 |
+
Parameters of weight normalization.
|
1173 |
+
If not ``None``, ``weight_norm_params.__dict__`` will be used as
|
1174 |
+
keyword arguments when initializing weight normalization.
|
1175 |
+
activation_norm_type (str, optional, default='none'):
|
1176 |
+
Type of activation normalization.
|
1177 |
+
``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
|
1178 |
+
``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
|
1179 |
+
``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
|
1180 |
+
activation_norm_params (obj, optional, default=None):
|
1181 |
+
Parameters of activation normalization.
|
1182 |
+
If not ``None``, ``activation_norm_params.__dict__`` will be used as
|
1183 |
+
keyword arguments when initializing activation normalization.
|
1184 |
+
nonlinearity (str, optional, default='none'):
|
1185 |
+
Type of nonlinear activation function.
|
1186 |
+
``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
|
1187 |
+
``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
|
1188 |
+
inplace_nonlinearity (bool, optional, default=False): If ``True``,
|
1189 |
+
set ``inplace=True`` when initializing the nonlinearity layer.
|
1190 |
+
apply_noise (bool, optional, default=False): If ``True``, adds
|
1191 |
+
Gaussian noise with learnable magnitude to the convolution output.
|
1192 |
+
order (str, optional, default='CNA'): Order of operations.
|
1193 |
+
``'C'``: convolution,
|
1194 |
+
``'N'``: normalization,
|
1195 |
+
``'A'``: nonlinear activation.
|
1196 |
+
For example, a block initialized with ``order='CNA'`` will
|
1197 |
+
do convolution first, then normalization, then nonlinearity.
|
1198 |
+
"""
|
1199 |
+
|
1200 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
|
1201 |
+
padding=0, dilation=1, groups=1, bias=True,
|
1202 |
+
padding_mode='zeros',
|
1203 |
+
weight_norm_type='none', weight_norm_params=None,
|
1204 |
+
activation_norm_type='none', activation_norm_params=None,
|
1205 |
+
nonlinearity='none', inplace_nonlinearity=False,
|
1206 |
+
apply_noise=False, blur=False, order='CNA', clamp=None):
|
1207 |
+
super().__init__(in_channels, out_channels, kernel_size, stride,
|
1208 |
+
padding, dilation, groups, bias, padding_mode,
|
1209 |
+
weight_norm_type, weight_norm_params,
|
1210 |
+
activation_norm_type, activation_norm_params,
|
1211 |
+
nonlinearity, inplace_nonlinearity,
|
1212 |
+
apply_noise, blur, order, 2, clamp)
|
1213 |
+
|
1214 |
+
|
1215 |
+
###############################################################################
|
1216 |
+
# BSD 3-Clause License
|
1217 |
+
#
|
1218 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
1219 |
+
#
|
1220 |
+
# Author & Contact: Guilin Liu (guilinl@nvidia.com)
|
1221 |
+
###############################################################################
|
1222 |
+
class PartialConv2d(nn.Conv2d):
|
1223 |
+
r"""Partial 2D convolution in
|
1224 |
+
"Image inpainting for irregular holes using partial convolutions."
|
1225 |
+
Liu et al., ECCV 2018
|
1226 |
+
"""
|
1227 |
+
|
1228 |
+
def __init__(self, *args, multi_channel=False, return_mask=True, **kwargs):
|
1229 |
+
# whether the mask is multi-channel or not
|
1230 |
+
self.multi_channel = multi_channel
|
1231 |
+
self.return_mask = return_mask
|
1232 |
+
super(PartialConv2d, self).__init__(*args, **kwargs)
|
1233 |
+
|
1234 |
+
if self.multi_channel:
|
1235 |
+
self.weight_maskUpdater = torch.ones(self.out_channels,
|
1236 |
+
self.in_channels,
|
1237 |
+
self.kernel_size[0],
|
1238 |
+
self.kernel_size[1])
|
1239 |
+
else:
|
1240 |
+
self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0],
|
1241 |
+
self.kernel_size[1])
|
1242 |
+
|
1243 |
+
shape = self.weight_maskUpdater.shape
|
1244 |
+
self.slide_winsize = shape[1] * shape[2] * shape[3]
|
1245 |
+
|
1246 |
+
self.last_size = (None, None, None, None)
|
1247 |
+
self.update_mask = None
|
1248 |
+
self.mask_ratio = None
|
1249 |
+
self.partial_conv = True
|
1250 |
+
|
1251 |
+
def forward(self, x, mask_in=None):
|
1252 |
+
r"""
|
1253 |
+
|
1254 |
+
Args:
|
1255 |
+
x (tensor): Input tensor.
|
1256 |
+
mask_in (tensor, optional, default=``None``) If not ``None``,
|
1257 |
+
it masks the valid input region.
|
1258 |
+
"""
|
1259 |
+
assert len(x.shape) == 4
|
1260 |
+
if mask_in is not None or self.last_size != tuple(x.shape):
|
1261 |
+
self.last_size = tuple(x.shape)
|
1262 |
+
|
1263 |
+
with torch.no_grad():
|
1264 |
+
if self.weight_maskUpdater.type() != x.type():
|
1265 |
+
self.weight_maskUpdater = self.weight_maskUpdater.to(x)
|
1266 |
+
|
1267 |
+
if mask_in is None:
|
1268 |
+
# If mask is not provided, create a mask.
|
1269 |
+
if self.multi_channel:
|
1270 |
+
mask = torch.ones(x.data.shape[0],
|
1271 |
+
x.data.shape[1],
|
1272 |
+
x.data.shape[2],
|
1273 |
+
x.data.shape[3]).to(x)
|
1274 |
+
else:
|
1275 |
+
mask = torch.ones(1, 1, x.data.shape[2],
|
1276 |
+
x.data.shape[3]).to(x)
|
1277 |
+
else:
|
1278 |
+
mask = mask_in
|
1279 |
+
|
1280 |
+
self.update_mask = F.conv2d(mask, self.weight_maskUpdater,
|
1281 |
+
bias=None, stride=self.stride,
|
1282 |
+
padding=self.padding,
|
1283 |
+
dilation=self.dilation, groups=1)
|
1284 |
+
|
1285 |
+
# For mixed precision training, eps from 1e-8 to 1e-6.
|
1286 |
+
eps = 1e-6
|
1287 |
+
self.mask_ratio = self.slide_winsize / (self.update_mask + eps)
|
1288 |
+
self.update_mask = torch.clamp(self.update_mask, 0, 1)
|
1289 |
+
self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask)
|
1290 |
+
|
1291 |
+
raw_out = super(PartialConv2d, self).forward(
|
1292 |
+
torch.mul(x, mask) if mask_in is not None else x)
|
1293 |
+
|
1294 |
+
if self.bias is not None:
|
1295 |
+
bias_view = self.bias.view(1, self.out_channels, 1, 1)
|
1296 |
+
output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view
|
1297 |
+
output = torch.mul(output, self.update_mask)
|
1298 |
+
else:
|
1299 |
+
output = torch.mul(raw_out, self.mask_ratio)
|
1300 |
+
|
1301 |
+
if self.return_mask:
|
1302 |
+
return output, self.update_mask
|
1303 |
+
else:
|
1304 |
+
return output
|
1305 |
+
|
1306 |
+
|
1307 |
+
class PartialConv3d(nn.Conv3d):
|
1308 |
+
r"""Partial 3D convolution in
|
1309 |
+
"Image inpainting for irregular holes using partial convolutions."
|
1310 |
+
Liu et al., ECCV 2018
|
1311 |
+
"""
|
1312 |
+
|
1313 |
+
def __init__(self, *args, multi_channel=False, return_mask=True, **kwargs):
|
1314 |
+
# whether the mask is multi-channel or not
|
1315 |
+
self.multi_channel = multi_channel
|
1316 |
+
self.return_mask = return_mask
|
1317 |
+
super(PartialConv3d, self).__init__(*args, **kwargs)
|
1318 |
+
|
1319 |
+
if self.multi_channel:
|
1320 |
+
self.weight_maskUpdater = \
|
1321 |
+
torch.ones(self.out_channels, self.in_channels,
|
1322 |
+
self.kernel_size[0], self.kernel_size[1],
|
1323 |
+
self.kernel_size[2])
|
1324 |
+
else:
|
1325 |
+
self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0],
|
1326 |
+
self.kernel_size[1],
|
1327 |
+
self.kernel_size[2])
|
1328 |
+
self.weight_maskUpdater = self.weight_maskUpdater.to('cuda')
|
1329 |
+
|
1330 |
+
shape = self.weight_maskUpdater.shape
|
1331 |
+
self.slide_winsize = shape[1] * shape[2] * shape[3] * shape[4]
|
1332 |
+
self.partial_conv = True
|
1333 |
+
|
1334 |
+
def forward(self, x, mask_in=None):
|
1335 |
+
r"""
|
1336 |
+
|
1337 |
+
Args:
|
1338 |
+
x (tensor): Input tensor.
|
1339 |
+
mask_in (tensor, optional, default=``None``) If not ``None``, it
|
1340 |
+
masks the valid input region.
|
1341 |
+
"""
|
1342 |
+
assert len(x.shape) == 5
|
1343 |
+
|
1344 |
+
with torch.no_grad():
|
1345 |
+
mask = mask_in
|
1346 |
+
update_mask = F.conv3d(mask, self.weight_maskUpdater, bias=None,
|
1347 |
+
stride=self.stride, padding=self.padding,
|
1348 |
+
dilation=self.dilation, groups=1)
|
1349 |
+
|
1350 |
+
mask_ratio = self.slide_winsize / (update_mask + 1e-8)
|
1351 |
+
update_mask = torch.clamp(update_mask, 0, 1)
|
1352 |
+
mask_ratio = torch.mul(mask_ratio, update_mask)
|
1353 |
+
|
1354 |
+
raw_out = super(PartialConv3d, self).forward(torch.mul(x, mask_in))
|
1355 |
+
|
1356 |
+
if self.bias is not None:
|
1357 |
+
bias_view = self.bias.view(1, self.out_channels, 1, 1, 1)
|
1358 |
+
output = torch.mul(raw_out - bias_view, mask_ratio) + bias_view
|
1359 |
+
if mask_in is not None:
|
1360 |
+
output = torch.mul(output, update_mask)
|
1361 |
+
else:
|
1362 |
+
output = torch.mul(raw_out, mask_ratio)
|
1363 |
+
|
1364 |
+
if self.return_mask:
|
1365 |
+
return output, update_mask
|
1366 |
+
else:
|
1367 |
+
return output
|
1368 |
+
|
1369 |
+
|
1370 |
+
class Embedding2d(nn.Embedding):
|
1371 |
+
def __init__(self, in_channels, out_channels):
|
1372 |
+
super().__init__(in_channels, out_channels)
|
1373 |
+
|
1374 |
+
def forward(self, x):
|
1375 |
+
return F.embedding(
|
1376 |
+
x.squeeze(1).long(), self.weight, self.padding_idx, self.max_norm,
|
1377 |
+
self.norm_type, self.scale_grad_by_freq, self.sparse).permute(0, 3, 1, 2).contiguous()
|
imaginaire/layers/misc.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# This work is made available under the Nvidia Source Code License-NC.
|
4 |
+
# To view a copy of this license, check out LICENSE.md
|
5 |
+
import torch
|
6 |
+
from torch import nn
|
7 |
+
|
8 |
+
|
9 |
+
class ApplyNoise(nn.Module):
|
10 |
+
r"""Add Gaussian noise to the input tensor."""
|
11 |
+
|
12 |
+
def __init__(self):
|
13 |
+
super().__init__()
|
14 |
+
# scale of the noise
|
15 |
+
self.scale = nn.Parameter(torch.zeros(1))
|
16 |
+
self.conditional = True
|
17 |
+
|
18 |
+
def forward(self, x, *_args, noise=None, **_kwargs):
|
19 |
+
r"""
|
20 |
+
|
21 |
+
Args:
|
22 |
+
x (tensor): Input tensor.
|
23 |
+
noise (tensor, optional, default=``None``) : Noise tensor to be
|
24 |
+
added to the input.
|
25 |
+
"""
|
26 |
+
if noise is None:
|
27 |
+
sz = x.size()
|
28 |
+
noise = x.new_empty(sz[0], 1, *sz[2:]).normal_()
|
29 |
+
|
30 |
+
return x + self.scale * noise
|
31 |
+
|
32 |
+
|
33 |
+
class PartialSequential(nn.Sequential):
|
34 |
+
r"""Sequential block for partial convolutions."""
|
35 |
+
def __init__(self, *modules):
|
36 |
+
super(PartialSequential, self).__init__(*modules)
|
37 |
+
|
38 |
+
def forward(self, x):
|
39 |
+
r"""
|
40 |
+
|
41 |
+
Args:
|
42 |
+
x (tensor): Input tensor.
|
43 |
+
"""
|
44 |
+
act = x[:, :-1]
|
45 |
+
mask = x[:, -1].unsqueeze(1)
|
46 |
+
for module in self:
|
47 |
+
act, mask = module(act, mask_in=mask)
|
48 |
+
return act
|
49 |
+
|
50 |
+
|
51 |
+
class ConstantInput(nn.Module):
|
52 |
+
def __init__(self, channel, size=4):
|
53 |
+
super().__init__()
|
54 |
+
if isinstance(size, int):
|
55 |
+
h, w = size, size
|
56 |
+
else:
|
57 |
+
h, w = size
|
58 |
+
self.input = nn.Parameter(torch.randn(1, channel, h, w))
|
59 |
+
|
60 |
+
def forward(self):
|
61 |
+
return self.input
|
imaginaire/layers/non_local.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# This work is made available under the Nvidia Source Code License-NC.
|
4 |
+
# To view a copy of this license, check out LICENSE.md
|
5 |
+
from functools import partial
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
|
10 |
+
from imaginaire.layers import Conv2dBlock
|
11 |
+
|
12 |
+
|
13 |
+
class NonLocal2dBlock(nn.Module):
|
14 |
+
r"""Self attention Layer
|
15 |
+
|
16 |
+
Args:
|
17 |
+
in_channels (int): Number of channels in the input tensor.
|
18 |
+
scale (bool, optional, default=True): If ``True``, scale the
|
19 |
+
output by a learnable parameter.
|
20 |
+
clamp (bool, optional, default=``False``): If ``True``, clamp the
|
21 |
+
scaling parameter to (-1, 1).
|
22 |
+
weight_norm_type (str, optional, default='none'):
|
23 |
+
Type of weight normalization.
|
24 |
+
``'none'``, ``'spectral'``, ``'weight'``.
|
25 |
+
weight_norm_params (obj, optional, default=None):
|
26 |
+
Parameters of weight normalization.
|
27 |
+
If not ``None``, weight_norm_params.__dict__ will be used as
|
28 |
+
keyword arguments when initializing weight normalization.
|
29 |
+
bias (bool, optional, default=True): If ``True``, adds bias in the
|
30 |
+
convolutional blocks.
|
31 |
+
"""
|
32 |
+
|
33 |
+
def __init__(self,
|
34 |
+
in_channels,
|
35 |
+
scale=True,
|
36 |
+
clamp=False,
|
37 |
+
weight_norm_type='none',
|
38 |
+
weight_norm_params=None,
|
39 |
+
bias=True):
|
40 |
+
super(NonLocal2dBlock, self).__init__()
|
41 |
+
self.clamp = clamp
|
42 |
+
self.gamma = nn.Parameter(torch.zeros(1)) if scale else 1.0
|
43 |
+
self.in_channels = in_channels
|
44 |
+
base_conv2d_block = partial(Conv2dBlock,
|
45 |
+
kernel_size=1,
|
46 |
+
stride=1,
|
47 |
+
padding=0,
|
48 |
+
weight_norm_type=weight_norm_type,
|
49 |
+
weight_norm_params=weight_norm_params,
|
50 |
+
bias=bias)
|
51 |
+
self.theta = base_conv2d_block(in_channels, in_channels // 8)
|
52 |
+
self.phi = base_conv2d_block(in_channels, in_channels // 8)
|
53 |
+
self.g = base_conv2d_block(in_channels, in_channels // 2)
|
54 |
+
self.out_conv = base_conv2d_block(in_channels // 2, in_channels)
|
55 |
+
self.softmax = nn.Softmax(dim=-1)
|
56 |
+
self.max_pool = nn.MaxPool2d(2)
|
57 |
+
|
58 |
+
def forward(self, x):
|
59 |
+
r"""
|
60 |
+
|
61 |
+
Args:
|
62 |
+
x (tensor) : input feature maps (B X C X W X H)
|
63 |
+
Returns:
|
64 |
+
(tuple):
|
65 |
+
- out (tensor) : self attention value + input feature
|
66 |
+
- attention (tensor): B x N x N (N is Width*Height)
|
67 |
+
"""
|
68 |
+
n, c, h, w = x.size()
|
69 |
+
theta = self.theta(x).view(n, -1, h * w).permute(0, 2, 1).contiguous()
|
70 |
+
|
71 |
+
phi = self.phi(x)
|
72 |
+
phi = self.max_pool(phi).view(n, -1, h * w // 4)
|
73 |
+
|
74 |
+
energy = torch.bmm(theta, phi)
|
75 |
+
attention = self.softmax(energy)
|
76 |
+
|
77 |
+
g = self.g(x)
|
78 |
+
g = self.max_pool(g).view(n, -1, h * w // 4)
|
79 |
+
|
80 |
+
out = torch.bmm(g, attention.permute(0, 2, 1).contiguous())
|
81 |
+
out = out.view(n, c // 2, h, w)
|
82 |
+
out = self.out_conv(out)
|
83 |
+
|
84 |
+
if self.clamp:
|
85 |
+
out = self.gamma.clamp(-1, 1) * out + x
|
86 |
+
else:
|
87 |
+
out = self.gamma * out + x
|
88 |
+
return out
|
imaginaire/layers/nonlinearity.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# This work is made available under the Nvidia Source Code License-NC.
|
4 |
+
# To view a copy of this license, check out LICENSE.md
|
5 |
+
import torch
|
6 |
+
from torch import nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
from imaginaire.third_party.bias_act.bias_act import FusedNonlinearity
|
10 |
+
|
11 |
+
|
12 |
+
class ScaledLeakyReLU(nn.Module):
|
13 |
+
def __init__(self, negative_slope=0.2, scale=2 ** 0.5, inplace=False):
|
14 |
+
super().__init__()
|
15 |
+
|
16 |
+
self.negative_slope = negative_slope
|
17 |
+
self.scale = scale
|
18 |
+
self.inplace = inplace
|
19 |
+
|
20 |
+
def forward(self, x):
|
21 |
+
return F.leaky_relu(x, self.negative_slope, inplace=self.inplace) * self.scale
|
22 |
+
# return _fused_scaled_leakyrelu(x, self.negative_slope, self.inplace, self.scale)
|
23 |
+
|
24 |
+
|
25 |
+
# @torch.jit.script
|
26 |
+
# def _fused_scaled_leakyrelu(x: torch.Tensor, negative_slope: float, inplace: bool, scale: float):
|
27 |
+
# return F.leaky_relu(x, negative_slope, inplace=inplace) * scale
|
28 |
+
|
29 |
+
|
30 |
+
def get_nonlinearity_layer(nonlinearity_type, inplace, **kwargs):
|
31 |
+
r"""Return a nonlinearity layer.
|
32 |
+
|
33 |
+
Args:
|
34 |
+
nonlinearity_type (str):
|
35 |
+
Type of nonlinear activation function.
|
36 |
+
``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
|
37 |
+
``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
|
38 |
+
inplace (bool): If ``True``, set ``inplace=True`` when initializing
|
39 |
+
the nonlinearity layer.
|
40 |
+
"""
|
41 |
+
if nonlinearity_type.startswith('fused'):
|
42 |
+
nonlinearity = FusedNonlinearity(nonlinearity=nonlinearity_type[6:], **kwargs)
|
43 |
+
elif nonlinearity_type == 'relu':
|
44 |
+
nonlinearity = nn.ReLU(inplace=inplace)
|
45 |
+
elif nonlinearity_type == 'leakyrelu':
|
46 |
+
nonlinearity = nn.LeakyReLU(0.2, inplace=inplace)
|
47 |
+
elif nonlinearity_type == 'scaled_leakyrelu':
|
48 |
+
nonlinearity = ScaledLeakyReLU(0.2, inplace=inplace)
|
49 |
+
import imaginaire.config
|
50 |
+
if imaginaire.config.USE_JIT:
|
51 |
+
nonlinearity = torch.jit.script(nonlinearity)
|
52 |
+
elif nonlinearity_type == 'prelu':
|
53 |
+
nonlinearity = nn.PReLU()
|
54 |
+
elif nonlinearity_type == 'tanh':
|
55 |
+
nonlinearity = nn.Tanh()
|
56 |
+
elif nonlinearity_type == 'sigmoid':
|
57 |
+
nonlinearity = nn.Sigmoid()
|
58 |
+
elif nonlinearity_type.startswith('softmax'):
|
59 |
+
dim = nonlinearity_type.split(',')[1] if ',' in nonlinearity_type else 1
|
60 |
+
nonlinearity = nn.Softmax(dim=int(dim))
|
61 |
+
elif nonlinearity_type == 'none' or nonlinearity_type == '':
|
62 |
+
nonlinearity = None
|
63 |
+
else:
|
64 |
+
raise ValueError('Nonlinearity %s is not recognized' % nonlinearity_type)
|
65 |
+
return nonlinearity
|
imaginaire/layers/residual.py
ADDED
@@ -0,0 +1,1411 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# This work is made available under the Nvidia Source Code License-NC.
|
4 |
+
# To view a copy of this license, check out LICENSE.md
|
5 |
+
import functools
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from torch import nn
|
9 |
+
from torch.nn import Upsample as NearestUpsample
|
10 |
+
from torch.utils.checkpoint import checkpoint
|
11 |
+
|
12 |
+
from .conv import (Conv1dBlock, Conv2dBlock, Conv3dBlock, HyperConv2dBlock,
|
13 |
+
LinearBlock, MultiOutConv2dBlock, PartialConv2dBlock,
|
14 |
+
PartialConv3dBlock, ModulatedConv2dBlock)
|
15 |
+
from imaginaire.third_party.upfirdn2d.upfirdn2d import BlurUpsample
|
16 |
+
|
17 |
+
|
18 |
+
class _BaseResBlock(nn.Module):
|
19 |
+
r"""An abstract class for residual blocks.
|
20 |
+
"""
|
21 |
+
|
22 |
+
def __init__(self, in_channels, out_channels, kernel_size,
|
23 |
+
stride, padding, dilation, groups, bias, padding_mode,
|
24 |
+
weight_norm_type, weight_norm_params,
|
25 |
+
activation_norm_type, activation_norm_params,
|
26 |
+
skip_activation_norm, skip_nonlinearity,
|
27 |
+
nonlinearity, inplace_nonlinearity, apply_noise,
|
28 |
+
hidden_channels_equal_out_channels,
|
29 |
+
order, block, learn_shortcut, clamp, output_scale,
|
30 |
+
skip_block=None, blur=False, upsample_first=True, skip_weight_norm=True):
|
31 |
+
super().__init__()
|
32 |
+
self.in_channels = in_channels
|
33 |
+
self.out_channels = out_channels
|
34 |
+
self.output_scale = output_scale
|
35 |
+
self.upsample_first = upsample_first
|
36 |
+
self.stride = stride
|
37 |
+
self.blur = blur
|
38 |
+
if skip_block is None:
|
39 |
+
skip_block = block
|
40 |
+
|
41 |
+
if order == 'pre_act':
|
42 |
+
order = 'NACNAC'
|
43 |
+
if isinstance(bias, bool):
|
44 |
+
# The bias for conv_block_0, conv_block_1, and conv_block_s.
|
45 |
+
biases = [bias, bias, bias]
|
46 |
+
elif isinstance(bias, list):
|
47 |
+
if len(bias) == 3:
|
48 |
+
biases = bias
|
49 |
+
else:
|
50 |
+
raise ValueError('Bias list must be 3.')
|
51 |
+
else:
|
52 |
+
raise ValueError('Bias must be either an integer or s list.')
|
53 |
+
if learn_shortcut is None:
|
54 |
+
self.learn_shortcut = (in_channels != out_channels)
|
55 |
+
else:
|
56 |
+
self.learn_shortcut = learn_shortcut
|
57 |
+
if len(order) > 6 or len(order) < 5:
|
58 |
+
raise ValueError('order must be either 5 or 6 characters')
|
59 |
+
if hidden_channels_equal_out_channels:
|
60 |
+
hidden_channels = out_channels
|
61 |
+
else:
|
62 |
+
hidden_channels = min(in_channels, out_channels)
|
63 |
+
|
64 |
+
# Parameters.
|
65 |
+
residual_params = {}
|
66 |
+
shortcut_params = {}
|
67 |
+
base_params = dict(dilation=dilation,
|
68 |
+
groups=groups,
|
69 |
+
padding_mode=padding_mode,
|
70 |
+
clamp=clamp)
|
71 |
+
residual_params.update(base_params)
|
72 |
+
residual_params.update(
|
73 |
+
dict(activation_norm_type=activation_norm_type,
|
74 |
+
activation_norm_params=activation_norm_params,
|
75 |
+
weight_norm_type=weight_norm_type,
|
76 |
+
weight_norm_params=weight_norm_params,
|
77 |
+
padding=padding,
|
78 |
+
apply_noise=apply_noise))
|
79 |
+
shortcut_params.update(base_params)
|
80 |
+
shortcut_params.update(dict(kernel_size=1))
|
81 |
+
if skip_activation_norm:
|
82 |
+
shortcut_params.update(
|
83 |
+
dict(activation_norm_type=activation_norm_type,
|
84 |
+
activation_norm_params=activation_norm_params,
|
85 |
+
apply_noise=False))
|
86 |
+
if skip_weight_norm:
|
87 |
+
shortcut_params.update(
|
88 |
+
dict(weight_norm_type=weight_norm_type,
|
89 |
+
weight_norm_params=weight_norm_params))
|
90 |
+
|
91 |
+
# Residual branch.
|
92 |
+
if order.find('A') < order.find('C') and \
|
93 |
+
(activation_norm_type == '' or activation_norm_type == 'none'):
|
94 |
+
# Nonlinearity is the first operation in the residual path.
|
95 |
+
# In-place nonlinearity will modify the input variable and cause
|
96 |
+
# backward error.
|
97 |
+
first_inplace = False
|
98 |
+
else:
|
99 |
+
first_inplace = inplace_nonlinearity
|
100 |
+
|
101 |
+
(first_stride, second_stride, shortcut_stride,
|
102 |
+
first_blur, second_blur, shortcut_blur) = self._get_stride_blur()
|
103 |
+
self.conv_block_0 = block(
|
104 |
+
in_channels, hidden_channels,
|
105 |
+
kernel_size=kernel_size,
|
106 |
+
bias=biases[0],
|
107 |
+
nonlinearity=nonlinearity,
|
108 |
+
order=order[0:3],
|
109 |
+
inplace_nonlinearity=first_inplace,
|
110 |
+
stride=first_stride,
|
111 |
+
blur=first_blur,
|
112 |
+
**residual_params
|
113 |
+
)
|
114 |
+
self.conv_block_1 = block(
|
115 |
+
hidden_channels, out_channels,
|
116 |
+
kernel_size=kernel_size,
|
117 |
+
bias=biases[1],
|
118 |
+
nonlinearity=nonlinearity,
|
119 |
+
order=order[3:],
|
120 |
+
inplace_nonlinearity=inplace_nonlinearity,
|
121 |
+
stride=second_stride,
|
122 |
+
blur=second_blur,
|
123 |
+
**residual_params
|
124 |
+
)
|
125 |
+
|
126 |
+
# Shortcut branch.
|
127 |
+
if self.learn_shortcut:
|
128 |
+
if skip_nonlinearity:
|
129 |
+
skip_nonlinearity_type = nonlinearity
|
130 |
+
else:
|
131 |
+
skip_nonlinearity_type = ''
|
132 |
+
self.conv_block_s = skip_block(in_channels, out_channels,
|
133 |
+
bias=biases[2],
|
134 |
+
nonlinearity=skip_nonlinearity_type,
|
135 |
+
order=order[0:3],
|
136 |
+
stride=shortcut_stride,
|
137 |
+
blur=shortcut_blur,
|
138 |
+
**shortcut_params)
|
139 |
+
elif in_channels < out_channels:
|
140 |
+
if skip_nonlinearity:
|
141 |
+
skip_nonlinearity_type = nonlinearity
|
142 |
+
else:
|
143 |
+
skip_nonlinearity_type = ''
|
144 |
+
self.conv_block_s = skip_block(in_channels,
|
145 |
+
out_channels - in_channels,
|
146 |
+
bias=biases[2],
|
147 |
+
nonlinearity=skip_nonlinearity_type,
|
148 |
+
order=order[0:3],
|
149 |
+
stride=shortcut_stride,
|
150 |
+
blur=shortcut_blur,
|
151 |
+
**shortcut_params)
|
152 |
+
|
153 |
+
# Whether this block expects conditional inputs.
|
154 |
+
self.conditional = \
|
155 |
+
getattr(self.conv_block_0, 'conditional', False) or \
|
156 |
+
getattr(self.conv_block_1, 'conditional', False)
|
157 |
+
|
158 |
+
def _get_stride_blur(self):
|
159 |
+
if self.stride > 1:
|
160 |
+
# Downsampling.
|
161 |
+
first_stride, second_stride = 1, self.stride
|
162 |
+
first_blur, second_blur = False, self.blur
|
163 |
+
shortcut_stride = self.stride
|
164 |
+
shortcut_blur = self.blur
|
165 |
+
self.upsample = None
|
166 |
+
elif self.stride < 1:
|
167 |
+
# Upsampling.
|
168 |
+
first_stride, second_stride = self.stride, 1
|
169 |
+
first_blur, second_blur = self.blur, False
|
170 |
+
shortcut_blur = False
|
171 |
+
shortcut_stride = 1
|
172 |
+
if self.blur:
|
173 |
+
# The shortcut branch uses blur_upsample + stride-1 conv
|
174 |
+
self.upsample = BlurUpsample()
|
175 |
+
else:
|
176 |
+
shortcut_stride = self.stride
|
177 |
+
self.upsample = nn.Upsample(scale_factor=2)
|
178 |
+
else:
|
179 |
+
first_stride = second_stride = 1
|
180 |
+
first_blur = second_blur = False
|
181 |
+
shortcut_stride = 1
|
182 |
+
shortcut_blur = False
|
183 |
+
self.upsample = None
|
184 |
+
return (first_stride, second_stride, shortcut_stride,
|
185 |
+
first_blur, second_blur, shortcut_blur)
|
186 |
+
|
187 |
+
def conv_blocks(
|
188 |
+
self, x, *cond_inputs, separate_cond=False, **kw_cond_inputs
|
189 |
+
):
|
190 |
+
r"""Returns the output of the residual branch.
|
191 |
+
|
192 |
+
Args:
|
193 |
+
x (tensor): Input tensor.
|
194 |
+
cond_inputs (list of tensors) : Conditional input tensors.
|
195 |
+
kw_cond_inputs (dict) : Keyword conditional inputs.
|
196 |
+
Returns:
|
197 |
+
dx (tensor): Output tensor.
|
198 |
+
"""
|
199 |
+
if separate_cond:
|
200 |
+
dx = self.conv_block_0(x, cond_inputs[0],
|
201 |
+
**kw_cond_inputs.get('kwargs_0', {}))
|
202 |
+
dx = self.conv_block_1(dx, cond_inputs[1],
|
203 |
+
**kw_cond_inputs.get('kwargs_1', {}))
|
204 |
+
else:
|
205 |
+
dx = self.conv_block_0(x, *cond_inputs, **kw_cond_inputs)
|
206 |
+
dx = self.conv_block_1(dx, *cond_inputs, **kw_cond_inputs)
|
207 |
+
return dx
|
208 |
+
|
209 |
+
def forward(self, x, *cond_inputs, do_checkpoint=False, separate_cond=False,
|
210 |
+
**kw_cond_inputs):
|
211 |
+
r"""
|
212 |
+
|
213 |
+
Args:
|
214 |
+
x (tensor): Input tensor.
|
215 |
+
cond_inputs (list of tensors) : Conditional input tensors.
|
216 |
+
do_checkpoint (bool, optional, default=``False``) If ``True``,
|
217 |
+
trade compute for memory by checkpointing the model.
|
218 |
+
kw_cond_inputs (dict) : Keyword conditional inputs.
|
219 |
+
Returns:
|
220 |
+
output (tensor): Output tensor.
|
221 |
+
"""
|
222 |
+
if do_checkpoint:
|
223 |
+
dx = checkpoint(self.conv_blocks, x, *cond_inputs,
|
224 |
+
separate_cond=separate_cond, **kw_cond_inputs)
|
225 |
+
else:
|
226 |
+
dx = self.conv_blocks(x, *cond_inputs,
|
227 |
+
separate_cond=separate_cond, **kw_cond_inputs)
|
228 |
+
|
229 |
+
if self.upsample_first and self.upsample is not None:
|
230 |
+
x = self.upsample(x)
|
231 |
+
if self.learn_shortcut:
|
232 |
+
if separate_cond:
|
233 |
+
x_shortcut = self.conv_block_s(
|
234 |
+
x, cond_inputs[2], **kw_cond_inputs.get('kwargs_2', {})
|
235 |
+
)
|
236 |
+
else:
|
237 |
+
x_shortcut = self.conv_block_s(
|
238 |
+
x, *cond_inputs, **kw_cond_inputs
|
239 |
+
)
|
240 |
+
elif self.in_channels < self.out_channels:
|
241 |
+
if separate_cond:
|
242 |
+
x_shortcut_pad = self.conv_block_s(
|
243 |
+
x, cond_inputs[2], **kw_cond_inputs.get('kwargs_2', {})
|
244 |
+
)
|
245 |
+
else:
|
246 |
+
x_shortcut_pad = self.conv_block_s(
|
247 |
+
x, *cond_inputs, **kw_cond_inputs
|
248 |
+
)
|
249 |
+
x_shortcut = torch.cat((x, x_shortcut_pad), dim=1)
|
250 |
+
elif self.in_channels > self.out_channels:
|
251 |
+
x_shortcut = x[:, :self.out_channels, :, :]
|
252 |
+
else:
|
253 |
+
x_shortcut = x
|
254 |
+
if not self.upsample_first and self.upsample is not None:
|
255 |
+
x_shortcut = self.upsample(x_shortcut)
|
256 |
+
|
257 |
+
output = x_shortcut + dx
|
258 |
+
return self.output_scale * output
|
259 |
+
|
260 |
+
def extra_repr(self):
|
261 |
+
s = 'output_scale={output_scale}'
|
262 |
+
return s.format(**self.__dict__)
|
263 |
+
|
264 |
+
|
265 |
+
class ModulatedRes2dBlock(_BaseResBlock):
|
266 |
+
def __init__(self, in_channels, out_channels, style_dim, kernel_size=3,
|
267 |
+
stride=1, padding=1, dilation=1, groups=1, bias=True,
|
268 |
+
padding_mode='zeros',
|
269 |
+
weight_norm_type='none', weight_norm_params=None,
|
270 |
+
activation_norm_type='none', activation_norm_params=None,
|
271 |
+
skip_activation_norm=True, skip_nonlinearity=False,
|
272 |
+
nonlinearity='leakyrelu', inplace_nonlinearity=False,
|
273 |
+
apply_noise=True, hidden_channels_equal_out_channels=False,
|
274 |
+
order='CNACNA', learn_shortcut=None, clamp=None, output_scale=1,
|
275 |
+
demodulate=True, eps=1e-8):
|
276 |
+
block = functools.partial(ModulatedConv2dBlock,
|
277 |
+
style_dim=style_dim,
|
278 |
+
demodulate=demodulate, eps=eps)
|
279 |
+
skip_block = Conv2dBlock
|
280 |
+
super().__init__(in_channels, out_channels, kernel_size, stride,
|
281 |
+
padding, dilation, groups, bias, padding_mode,
|
282 |
+
weight_norm_type, weight_norm_params,
|
283 |
+
activation_norm_type, activation_norm_params,
|
284 |
+
skip_activation_norm, skip_nonlinearity, nonlinearity,
|
285 |
+
inplace_nonlinearity, apply_noise,
|
286 |
+
hidden_channels_equal_out_channels, order, block,
|
287 |
+
learn_shortcut, clamp, output_scale, skip_block=skip_block)
|
288 |
+
|
289 |
+
def conv_blocks(self, x, *cond_inputs, **kw_cond_inputs):
|
290 |
+
assert len(list(cond_inputs)) == 2
|
291 |
+
dx = self.conv_block_0(x, cond_inputs[0], **kw_cond_inputs)
|
292 |
+
dx = self.conv_block_1(dx, cond_inputs[1], **kw_cond_inputs)
|
293 |
+
return dx
|
294 |
+
|
295 |
+
|
296 |
+
class ResLinearBlock(_BaseResBlock):
|
297 |
+
r"""Residual block with full-connected layers.
|
298 |
+
|
299 |
+
Args:
|
300 |
+
in_channels (int) : Number of channels in the input tensor.
|
301 |
+
out_channels (int) : Number of channels in the output tensor.
|
302 |
+
weight_norm_type (str, optional, default='none'):
|
303 |
+
Type of weight normalization.
|
304 |
+
``'none'``, ``'spectral'``, ``'weight'``
|
305 |
+
or ``'weight_demod'``.
|
306 |
+
weight_norm_params (obj, optional, default=None):
|
307 |
+
Parameters of weight normalization.
|
308 |
+
If not ``None``, ``weight_norm_params.__dict__`` will be used as
|
309 |
+
keyword arguments when initializing weight normalization.
|
310 |
+
activation_norm_type (str, optional, default='none'):
|
311 |
+
Type of activation normalization.
|
312 |
+
``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
|
313 |
+
``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
|
314 |
+
``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
|
315 |
+
activation_norm_params (obj, optional, default=None):
|
316 |
+
Parameters of activation normalization.
|
317 |
+
If not ``None``, ``activation_norm_params.__dict__`` will be used as
|
318 |
+
keyword arguments when initializing activation normalization.
|
319 |
+
skip_activation_norm (bool, optional, default=True): If ``True`` and
|
320 |
+
``learn_shortcut`` is also ``True``, applies activation norm to the
|
321 |
+
learned shortcut connection.
|
322 |
+
skip_nonlinearity (bool, optional, default=True): If ``True`` and
|
323 |
+
``learn_shortcut`` is also ``True``, applies nonlinearity to the
|
324 |
+
learned shortcut connection.
|
325 |
+
nonlinearity (str, optional, default='none'):
|
326 |
+
Type of nonlinear activation function in the residual link.
|
327 |
+
``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
|
328 |
+
``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
|
329 |
+
inplace_nonlinearity (bool, optional, default=False): If ``True``,
|
330 |
+
set ``inplace=True`` when initializing the nonlinearity layers.
|
331 |
+
apply_noise (bool, optional, default=False): If ``True``, add
|
332 |
+
Gaussian noise with learnable magnitude after the
|
333 |
+
fully-connected layer.
|
334 |
+
hidden_channels_equal_out_channels (bool, optional, default=False):
|
335 |
+
If ``True``, set the hidden channel number to be equal to the
|
336 |
+
output channel number. If ``False``, the hidden channel number
|
337 |
+
equals to the smaller of the input channel number and the
|
338 |
+
output channel number.
|
339 |
+
order (str, optional, default='CNACNA'): Order of operations
|
340 |
+
in the residual link.
|
341 |
+
``'C'``: fully-connected,
|
342 |
+
``'N'``: normalization,
|
343 |
+
``'A'``: nonlinear activation.
|
344 |
+
learn_shortcut (bool, optional, default=False): If ``True``, always use
|
345 |
+
a convolutional shortcut instead of an identity one, otherwise only
|
346 |
+
use a convolutional one if input and output have different number of
|
347 |
+
channels.
|
348 |
+
"""
|
349 |
+
|
350 |
+
def __init__(self, in_channels, out_channels, bias=True,
|
351 |
+
weight_norm_type='none', weight_norm_params=None,
|
352 |
+
activation_norm_type='none', activation_norm_params=None,
|
353 |
+
skip_activation_norm=True, skip_nonlinearity=False,
|
354 |
+
nonlinearity='leakyrelu', inplace_nonlinearity=False,
|
355 |
+
apply_noise=False, hidden_channels_equal_out_channels=False,
|
356 |
+
order='CNACNA', learn_shortcut=None, clamp=None,
|
357 |
+
output_scale=1):
|
358 |
+
super().__init__(in_channels, out_channels, None, 1, None, None,
|
359 |
+
None, bias, None, weight_norm_type, weight_norm_params,
|
360 |
+
activation_norm_type, activation_norm_params,
|
361 |
+
skip_activation_norm, skip_nonlinearity, nonlinearity,
|
362 |
+
inplace_nonlinearity, apply_noise,
|
363 |
+
hidden_channels_equal_out_channels, order, LinearBlock,
|
364 |
+
learn_shortcut, clamp, output_scale)
|
365 |
+
|
366 |
+
|
367 |
+
class Res1dBlock(_BaseResBlock):
|
368 |
+
r"""Residual block for 1D input.
|
369 |
+
|
370 |
+
Args:
|
371 |
+
in_channels (int) : Number of channels in the input tensor.
|
372 |
+
out_channels (int) : Number of channels in the output tensor.
|
373 |
+
kernel_size (int, optional, default=3): Kernel size for the
|
374 |
+
convolutional filters in the residual link.
|
375 |
+
padding (int, optional, default=1): Padding size.
|
376 |
+
dilation (int, optional, default=1): Dilation factor.
|
377 |
+
groups (int, optional, default=1): Number of convolutional/linear
|
378 |
+
groups.
|
379 |
+
padding_mode (string, optional, default='zeros'): Type of padding:
|
380 |
+
``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
|
381 |
+
weight_norm_type (str, optional, default='none'):
|
382 |
+
Type of weight normalization.
|
383 |
+
``'none'``, ``'spectral'``, ``'weight'``
|
384 |
+
or ``'weight_demod'``.
|
385 |
+
weight_norm_params (obj, optional, default=None):
|
386 |
+
Parameters of weight normalization.
|
387 |
+
If not ``None``, ``weight_norm_params.__dict__`` will be used as
|
388 |
+
keyword arguments when initializing weight normalization.
|
389 |
+
activation_norm_type (str, optional, default='none'):
|
390 |
+
Type of activation normalization.
|
391 |
+
``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
|
392 |
+
``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
|
393 |
+
``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
|
394 |
+
activation_norm_params (obj, optional, default=None):
|
395 |
+
Parameters of activation normalization.
|
396 |
+
If not ``None``, ``activation_norm_params.__dict__`` will be used as
|
397 |
+
keyword arguments when initializing activation normalization.
|
398 |
+
skip_activation_norm (bool, optional, default=True): If ``True`` and
|
399 |
+
``learn_shortcut`` is also ``True``, applies activation norm to the
|
400 |
+
learned shortcut connection.
|
401 |
+
skip_nonlinearity (bool, optional, default=True): If ``True`` and
|
402 |
+
``learn_shortcut`` is also ``True``, applies nonlinearity to the
|
403 |
+
learned shortcut connection.
|
404 |
+
nonlinearity (str, optional, default='none'):
|
405 |
+
Type of nonlinear activation function in the residual link.
|
406 |
+
``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
|
407 |
+
``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
|
408 |
+
inplace_nonlinearity (bool, optional, default=False): If ``True``,
|
409 |
+
set ``inplace=True`` when initializing the nonlinearity layers.
|
410 |
+
apply_noise (bool, optional, default=False): If ``True``, adds
|
411 |
+
Gaussian noise with learnable magnitude to the convolution output.
|
412 |
+
hidden_channels_equal_out_channels (bool, optional, default=False):
|
413 |
+
If ``True``, set the hidden channel number to be equal to the
|
414 |
+
output channel number. If ``False``, the hidden channel number
|
415 |
+
equals to the smaller of the input channel number and the
|
416 |
+
output channel number.
|
417 |
+
order (str, optional, default='CNACNA'): Order of operations
|
418 |
+
in the residual link.
|
419 |
+
``'C'``: convolution,
|
420 |
+
``'N'``: normalization,
|
421 |
+
``'A'``: nonlinear activation.
|
422 |
+
learn_shortcut (bool, optional, default=False): If ``True``, always use
|
423 |
+
a convolutional shortcut instead of an identity one, otherwise only
|
424 |
+
use a convolutional one if input and output have different number of
|
425 |
+
channels.
|
426 |
+
"""
|
427 |
+
|
428 |
+
def __init__(self, in_channels, out_channels, kernel_size=3,
|
429 |
+
stride=1, padding=1, dilation=1, groups=1, bias=True,
|
430 |
+
padding_mode='zeros',
|
431 |
+
weight_norm_type='none', weight_norm_params=None,
|
432 |
+
activation_norm_type='none', activation_norm_params=None,
|
433 |
+
skip_activation_norm=True, skip_nonlinearity=False,
|
434 |
+
nonlinearity='leakyrelu', inplace_nonlinearity=False,
|
435 |
+
apply_noise=False, hidden_channels_equal_out_channels=False,
|
436 |
+
order='CNACNA', learn_shortcut=None, clamp=None,
|
437 |
+
output_scale=1):
|
438 |
+
super().__init__(in_channels, out_channels, kernel_size, stride,
|
439 |
+
padding, dilation, groups, bias, padding_mode,
|
440 |
+
weight_norm_type, weight_norm_params,
|
441 |
+
activation_norm_type, activation_norm_params,
|
442 |
+
skip_activation_norm, skip_nonlinearity, nonlinearity,
|
443 |
+
inplace_nonlinearity, apply_noise,
|
444 |
+
hidden_channels_equal_out_channels, order, Conv1dBlock,
|
445 |
+
learn_shortcut, clamp, output_scale)
|
446 |
+
|
447 |
+
|
448 |
+
class Res2dBlock(_BaseResBlock):
|
449 |
+
r"""Residual block for 2D input.
|
450 |
+
|
451 |
+
Args:
|
452 |
+
in_channels (int) : Number of channels in the input tensor.
|
453 |
+
out_channels (int) : Number of channels in the output tensor.
|
454 |
+
kernel_size (int, optional, default=3): Kernel size for the
|
455 |
+
convolutional filters in the residual link.
|
456 |
+
padding (int, optional, default=1): Padding size.
|
457 |
+
dilation (int, optional, default=1): Dilation factor.
|
458 |
+
groups (int, optional, default=1): Number of convolutional/linear
|
459 |
+
groups.
|
460 |
+
padding_mode (string, optional, default='zeros'): Type of padding:
|
461 |
+
``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
|
462 |
+
weight_norm_type (str, optional, default='none'):
|
463 |
+
Type of weight normalization.
|
464 |
+
``'none'``, ``'spectral'``, ``'weight'``
|
465 |
+
or ``'weight_demod'``.
|
466 |
+
weight_norm_params (obj, optional, default=None):
|
467 |
+
Parameters of weight normalization.
|
468 |
+
If not ``None``, ``weight_norm_params.__dict__`` will be used as
|
469 |
+
keyword arguments when initializing weight normalization.
|
470 |
+
activation_norm_type (str, optional, default='none'):
|
471 |
+
Type of activation normalization.
|
472 |
+
``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
|
473 |
+
``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
|
474 |
+
``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
|
475 |
+
activation_norm_params (obj, optional, default=None):
|
476 |
+
Parameters of activation normalization.
|
477 |
+
If not ``None``, ``activation_norm_params.__dict__`` will be used as
|
478 |
+
keyword arguments when initializing activation normalization.
|
479 |
+
skip_activation_norm (bool, optional, default=True): If ``True`` and
|
480 |
+
``learn_shortcut`` is also ``True``, applies activation norm to the
|
481 |
+
learned shortcut connection.
|
482 |
+
skip_nonlinearity (bool, optional, default=True): If ``True`` and
|
483 |
+
``learn_shortcut`` is also ``True``, applies nonlinearity to the
|
484 |
+
learned shortcut connection.
|
485 |
+
nonlinearity (str, optional, default='none'):
|
486 |
+
Type of nonlinear activation function in the residual link.
|
487 |
+
``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
|
488 |
+
``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
|
489 |
+
inplace_nonlinearity (bool, optional, default=False): If ``True``,
|
490 |
+
set ``inplace=True`` when initializing the nonlinearity layers.
|
491 |
+
apply_noise (bool, optional, default=False): If ``True``, adds
|
492 |
+
Gaussian noise with learnable magnitude to the convolution output.
|
493 |
+
hidden_channels_equal_out_channels (bool, optional, default=False):
|
494 |
+
If ``True``, set the hidden channel number to be equal to the
|
495 |
+
output channel number. If ``False``, the hidden channel number
|
496 |
+
equals to the smaller of the input channel number and the
|
497 |
+
output channel number.
|
498 |
+
order (str, optional, default='CNACNA'): Order of operations
|
499 |
+
in the residual link.
|
500 |
+
``'C'``: convolution,
|
501 |
+
``'N'``: normalization,
|
502 |
+
``'A'``: nonlinear activation.
|
503 |
+
learn_shortcut (bool, optional, default=False): If ``True``, always use
|
504 |
+
a convolutional shortcut instead of an identity one, otherwise only
|
505 |
+
use a convolutional one if input and output have different number of
|
506 |
+
channels.
|
507 |
+
"""
|
508 |
+
|
509 |
+
def __init__(self, in_channels, out_channels, kernel_size=3,
|
510 |
+
stride=1, padding=1, dilation=1, groups=1, bias=True,
|
511 |
+
padding_mode='zeros',
|
512 |
+
weight_norm_type='none', weight_norm_params=None,
|
513 |
+
activation_norm_type='none', activation_norm_params=None,
|
514 |
+
skip_activation_norm=True, skip_nonlinearity=False,
|
515 |
+
skip_weight_norm=True,
|
516 |
+
nonlinearity='leakyrelu', inplace_nonlinearity=False,
|
517 |
+
apply_noise=False, hidden_channels_equal_out_channels=False,
|
518 |
+
order='CNACNA', learn_shortcut=None, clamp=None,
|
519 |
+
output_scale=1, blur=False, upsample_first=True):
|
520 |
+
super().__init__(in_channels, out_channels, kernel_size, stride,
|
521 |
+
padding, dilation, groups, bias, padding_mode,
|
522 |
+
weight_norm_type, weight_norm_params,
|
523 |
+
activation_norm_type, activation_norm_params,
|
524 |
+
skip_activation_norm, skip_nonlinearity, nonlinearity,
|
525 |
+
inplace_nonlinearity, apply_noise,
|
526 |
+
hidden_channels_equal_out_channels, order, Conv2dBlock,
|
527 |
+
learn_shortcut, clamp, output_scale, blur=blur,
|
528 |
+
upsample_first=upsample_first,
|
529 |
+
skip_weight_norm=skip_weight_norm)
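# ---------------------------------------------------------------------------
# Illustrative usage sketch (added for documentation, not part of the original
# API): with the default stride 1 and padding 1 a Res2dBlock preserves spatial
# resolution, and when in_channels == out_channels the shortcut stays an
# identity, so blocks can be stacked without extra bookkeeping. The channel
# counts and shapes below are assumptions chosen only for this example.
def _example_res2d_block():
    import torch
    block = Res2dBlock(64, 64,
                       activation_norm_type='instance',
                       nonlinearity='leakyrelu')
    x = torch.randn(2, 64, 32, 32)
    y = block(x)
    return y.shape  # expected: torch.Size([2, 64, 32, 32])
# ---------------------------------------------------------------------------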
|
530 |
+
|
531 |
+
|
532 |
+
class Res3dBlock(_BaseResBlock):
|
533 |
+
r"""Residual block for 3D input.
|
534 |
+
|
535 |
+
Args:
|
536 |
+
in_channels (int) : Number of channels in the input tensor.
|
537 |
+
out_channels (int) : Number of channels in the output tensor.
|
538 |
+
kernel_size (int, optional, default=3): Kernel size for the
|
539 |
+
convolutional filters in the residual link.
|
540 |
+
padding (int, optional, default=1): Padding size.
|
541 |
+
dilation (int, optional, default=1): Dilation factor.
|
542 |
+
groups (int, optional, default=1): Number of convolutional/linear
|
543 |
+
groups.
|
544 |
+
padding_mode (string, optional, default='zeros'): Type of padding:
|
545 |
+
``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
|
546 |
+
weight_norm_type (str, optional, default='none'):
|
547 |
+
Type of weight normalization.
|
548 |
+
``'none'``, ``'spectral'``, ``'weight'``
|
549 |
+
or ``'weight_demod'``.
|
550 |
+
weight_norm_params (obj, optional, default=None):
|
551 |
+
Parameters of weight normalization.
|
552 |
+
If not ``None``, ``weight_norm_params.__dict__`` will be used as
|
553 |
+
keyword arguments when initializing weight normalization.
|
554 |
+
activation_norm_type (str, optional, default='none'):
|
555 |
+
Type of activation normalization.
|
556 |
+
``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
|
557 |
+
``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
|
558 |
+
``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
|
559 |
+
activation_norm_params (obj, optional, default=None):
|
560 |
+
Parameters of activation normalization.
|
561 |
+
If not ``None``, ``activation_norm_params.__dict__`` will be used as
|
562 |
+
keyword arguments when initializing activation normalization.
|
563 |
+
skip_activation_norm (bool, optional, default=True): If ``True`` and
|
564 |
+
``learn_shortcut`` is also ``True``, applies activation norm to the
|
565 |
+
learned shortcut connection.
|
566 |
+
skip_nonlinearity (bool, optional, default=True): If ``True`` and
|
567 |
+
``learn_shortcut`` is also ``True``, applies nonlinearity to the
|
568 |
+
learned shortcut connection.
|
569 |
+
nonlinearity (str, optional, default='none'):
|
570 |
+
Type of nonlinear activation function in the residual link.
|
571 |
+
``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
|
572 |
+
``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
|
573 |
+
inplace_nonlinearity (bool, optional, default=False): If ``True``,
|
574 |
+
set ``inplace=True`` when initializing the nonlinearity layers.
|
575 |
+
apply_noise (bool, optional, default=False): If ``True``, adds
|
576 |
+
Gaussian noise with learnable magnitude to the convolution output.
|
577 |
+
hidden_channels_equal_out_channels (bool, optional, default=False):
|
578 |
+
If ``True``, set the hidden channel number to be equal to the
|
579 |
+
output channel number. If ``False``, the hidden channel number
|
580 |
+
equals to the smaller of the input channel number and the
|
581 |
+
output channel number.
|
582 |
+
order (str, optional, default='CNACNA'): Order of operations
|
583 |
+
in the residual link.
|
584 |
+
``'C'``: convolution,
|
585 |
+
``'N'``: normalization,
|
586 |
+
``'A'``: nonlinear activation.
|
587 |
+
learn_shortcut (bool, optional, default=False): If ``True``, always use
|
588 |
+
a convolutional shortcut instead of an identity one, otherwise only
|
589 |
+
use a convolutional one if input and output have different number of
|
590 |
+
channels.
|
591 |
+
"""
|
592 |
+
|
593 |
+
def __init__(self, in_channels, out_channels, kernel_size=3,
|
594 |
+
stride=1, padding=1, dilation=1, groups=1, bias=True,
|
595 |
+
padding_mode='zeros',
|
596 |
+
weight_norm_type='none', weight_norm_params=None,
|
597 |
+
activation_norm_type='none', activation_norm_params=None,
|
598 |
+
skip_activation_norm=True, skip_nonlinearity=False,
|
599 |
+
nonlinearity='leakyrelu', inplace_nonlinearity=False,
|
600 |
+
apply_noise=False, hidden_channels_equal_out_channels=False,
|
601 |
+
order='CNACNA', learn_shortcut=None, clamp=None,
|
602 |
+
output_scale=1):
|
603 |
+
super().__init__(in_channels, out_channels, kernel_size, stride,
|
604 |
+
padding, dilation, groups, bias, padding_mode,
|
605 |
+
weight_norm_type, weight_norm_params,
|
606 |
+
activation_norm_type, activation_norm_params,
|
607 |
+
skip_activation_norm, skip_nonlinearity, nonlinearity,
|
608 |
+
inplace_nonlinearity, apply_noise,
|
609 |
+
hidden_channels_equal_out_channels, order, Conv3dBlock,
|
610 |
+
learn_shortcut, clamp, output_scale)
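# ---------------------------------------------------------------------------
# Illustrative usage sketch (documentation only): Res3dBlock is the 3D
# counterpart and expects 5D input (N, C, D, H, W). Because in_channels and
# out_channels differ here, the block falls back to a learned convolutional
# shortcut. Shapes are assumptions for the example.
def _example_res3d_block():
    import torch
    block = Res3dBlock(16, 32)
    x = torch.randn(1, 16, 8, 32, 32)
    return block(x).shape  # expected: torch.Size([1, 32, 8, 32, 32])
# ---------------------------------------------------------------------------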
|
611 |
+
|
612 |
+
|
613 |
+
class _BaseHyperResBlock(_BaseResBlock):
|
614 |
+
r"""An abstract class for hyper residual blocks.
|
615 |
+
"""
|
616 |
+
|
617 |
+
def __init__(self, in_channels, out_channels, kernel_size,
|
618 |
+
stride, padding, dilation, groups, bias, padding_mode,
|
619 |
+
weight_norm_type, weight_norm_params,
|
620 |
+
activation_norm_type, activation_norm_params,
|
621 |
+
skip_activation_norm, skip_nonlinearity,
|
622 |
+
nonlinearity, inplace_nonlinearity, apply_noise,
|
623 |
+
hidden_channels_equal_out_channels,
|
624 |
+
order, is_hyper_conv, is_hyper_norm, block, learn_shortcut,
|
625 |
+
clamp=None, output_scale=1):
|
626 |
+
block = functools.partial(block,
|
627 |
+
is_hyper_conv=is_hyper_conv,
|
628 |
+
is_hyper_norm=is_hyper_norm)
|
629 |
+
super().__init__(in_channels, out_channels, kernel_size, stride,
|
630 |
+
padding, dilation, groups, bias, padding_mode,
|
631 |
+
weight_norm_type, weight_norm_params,
|
632 |
+
activation_norm_type, activation_norm_params,
|
633 |
+
skip_activation_norm, skip_nonlinearity, nonlinearity,
|
634 |
+
inplace_nonlinearity, apply_noise,
|
635 |
+
hidden_channels_equal_out_channels, order, block,
|
636 |
+
learn_shortcut, clamp, output_scale)
|
637 |
+
|
638 |
+
def forward(self, x, *cond_inputs, conv_weights=(None,) * 3,
|
639 |
+
norm_weights=(None,) * 3, **kw_cond_inputs):
|
640 |
+
r"""
|
641 |
+
|
642 |
+
Args:
|
643 |
+
x (tensor): Input tensor.
|
644 |
+
cond_inputs (list of tensors) : Conditional input tensors.
|
645 |
+
conv_weights (list of tensors): Convolution weights for
|
646 |
+
three convolutional layers respectively.
|
647 |
+
norm_weights (list of tensors): Normalization weights for
|
648 |
+
three convolutional layers respectively.
|
649 |
+
kw_cond_inputs (dict) : Keyword conditional inputs.
|
650 |
+
Returns:
|
651 |
+
output (tensor): Output tensor.
|
652 |
+
"""
|
653 |
+
dx = self.conv_block_0(x, *cond_inputs, conv_weights=conv_weights[0],
|
654 |
+
norm_weights=norm_weights[0])
|
655 |
+
dx = self.conv_block_1(dx, *cond_inputs, conv_weights=conv_weights[1],
|
656 |
+
norm_weights=norm_weights[1])
|
657 |
+
if self.learn_shortcut:
|
658 |
+
x_shortcut = self.conv_block_s(x, *cond_inputs,
|
659 |
+
conv_weights=conv_weights[2],
|
660 |
+
norm_weights=norm_weights[2])
|
661 |
+
else:
|
662 |
+
x_shortcut = x
|
663 |
+
output = x_shortcut + dx
|
664 |
+
return self.output_scale * output
|
665 |
+
|
666 |
+
|
667 |
+
class HyperRes2dBlock(_BaseHyperResBlock):
|
668 |
+
r"""Hyper residual block for 2D input.
|
669 |
+
|
670 |
+
Args:
|
671 |
+
in_channels (int) : Number of channels in the input tensor.
|
672 |
+
out_channels (int) : Number of channels in the output tensor.
|
673 |
+
kernel_size (int, optional, default=3): Kernel size for the
|
674 |
+
convolutional filters in the residual link.
|
675 |
+
padding (int, optional, default=1): Padding size.
|
676 |
+
dilation (int, optional, default=1): Dilation factor.
|
677 |
+
groups (int, optional, default=1): Number of convolutional/linear
|
678 |
+
groups.
|
679 |
+
padding_mode (string, optional, default='zeros'): Type of padding:
|
680 |
+
``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
|
681 |
+
weight_norm_type (str, optional, default='none'):
|
682 |
+
Type of weight normalization.
|
683 |
+
``'none'``, ``'spectral'``, ``'weight'``
|
684 |
+
or ``'weight_demod'``.
|
685 |
+
weight_norm_params (obj, optional, default=None):
|
686 |
+
Parameters of weight normalization.
|
687 |
+
If not ``None``, ``weight_norm_params.__dict__`` will be used as
|
688 |
+
keyword arguments when initializing weight normalization.
|
689 |
+
activation_norm_type (str, optional, default='none'):
|
690 |
+
Type of activation normalization.
|
691 |
+
``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
|
692 |
+
``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
|
693 |
+
``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
|
694 |
+
activation_norm_params (obj, optional, default=None):
|
695 |
+
Parameters of activation normalization.
|
696 |
+
If not ``None``, ``activation_norm_params.__dict__`` will be used as
|
697 |
+
keyword arguments when initializing activation normalization.
|
698 |
+
skip_activation_norm (bool, optional, default=True): If ``True`` and
|
699 |
+
``learn_shortcut`` is also ``True``, applies activation norm to the
|
700 |
+
learned shortcut connection.
|
701 |
+
skip_nonlinearity (bool, optional, default=True): If ``True`` and
|
702 |
+
``learn_shortcut`` is also ``True``, applies nonlinearity to the
|
703 |
+
learned shortcut connection.
|
704 |
+
nonlinearity (str, optional, default='none'):
|
705 |
+
Type of nonlinear activation function in the residual link.
|
706 |
+
``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
|
707 |
+
``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
|
708 |
+
inplace_nonlinearity (bool, optional, default=False): If ``True``,
|
709 |
+
set ``inplace=True`` when initializing the nonlinearity layers.
|
710 |
+
apply_noise (bool, optional, default=False): If ``True``, adds
|
711 |
+
Gaussian noise with learnable magnitude to the convolution output.
|
712 |
+
hidden_channels_equal_out_channels (bool, optional, default=False):
|
713 |
+
If ``True``, set the hidden channel number to be equal to the
|
714 |
+
output channel number. If ``False``, the hidden channel number
|
715 |
+
equals to the smaller of the input channel number and the
|
716 |
+
output channel number.
|
717 |
+
order (str, optional, default='CNACNA'): Order of operations
|
718 |
+
in the residual link.
|
719 |
+
``'C'``: convolution,
|
720 |
+
``'N'``: normalization,
|
721 |
+
``'A'``: nonlinear activation.
|
722 |
+
is_hyper_conv (bool, optional, default=False): If ``True``, use
|
723 |
+
``HyperConv2d``, otherwise use ``torch.nn.Conv2d``.
|
724 |
+
is_hyper_norm (bool, optional, default=False): If ``True``, use
|
725 |
+
hyper normalizations.
|
726 |
+
learn_shortcut (bool, optional, default=False): If ``True``, always use
|
727 |
+
a convolutional shortcut instead of an identity one, otherwise only
|
728 |
+
use a convolutional one if input and output have different number of
|
729 |
+
channels.
|
730 |
+
"""
|
731 |
+
|
732 |
+
def __init__(self, in_channels, out_channels, kernel_size=3,
|
733 |
+
stride=1, padding=1, dilation=1, groups=1, bias=True,
|
734 |
+
padding_mode='zeros',
|
735 |
+
weight_norm_type='', weight_norm_params=None,
|
736 |
+
activation_norm_type='', activation_norm_params=None,
|
737 |
+
skip_activation_norm=True, skip_nonlinearity=False,
|
738 |
+
nonlinearity='leakyrelu', inplace_nonlinearity=False,
|
739 |
+
apply_noise=False, hidden_channels_equal_out_channels=False,
|
740 |
+
order='CNACNA', is_hyper_conv=False, is_hyper_norm=False,
|
741 |
+
learn_shortcut=None, clamp=None, output_scale=1):
|
742 |
+
super().__init__(in_channels, out_channels, kernel_size,
|
743 |
+
stride, padding, dilation, groups, bias, padding_mode,
|
744 |
+
weight_norm_type, weight_norm_params,
|
745 |
+
activation_norm_type, activation_norm_params,
|
746 |
+
skip_activation_norm, skip_nonlinearity,
|
747 |
+
nonlinearity, inplace_nonlinearity, apply_noise,
|
748 |
+
hidden_channels_equal_out_channels,
|
749 |
+
order, is_hyper_conv, is_hyper_norm,
|
750 |
+
HyperConv2dBlock, learn_shortcut, clamp, output_scale)
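# ---------------------------------------------------------------------------
# Illustrative usage sketch (documentation only): with the default
# is_hyper_conv=False and is_hyper_norm=False a HyperRes2dBlock behaves like a
# regular residual block. Externally predicted weights would instead be passed
# per call through the conv_weights / norm_weights arguments of forward();
# their exact layout is defined by HyperConv2d and the hyper normalization
# layers and is not assumed here. Shapes below are for the example only.
def _example_hyper_res2d_block():
    import torch
    block = HyperRes2dBlock(32, 32, weight_norm_type='spectral')
    x = torch.randn(1, 32, 16, 16)
    return block(x).shape  # expected: torch.Size([1, 32, 16, 16])
# ---------------------------------------------------------------------------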
|
751 |
+
|
752 |
+
|
753 |
+
class _BaseDownResBlock(_BaseResBlock):
|
754 |
+
r"""An abstract class for residual blocks with downsampling.
|
755 |
+
"""
|
756 |
+
|
757 |
+
def __init__(self, in_channels, out_channels, kernel_size,
|
758 |
+
stride, padding, dilation, groups, bias, padding_mode,
|
759 |
+
weight_norm_type, weight_norm_params,
|
760 |
+
activation_norm_type, activation_norm_params,
|
761 |
+
skip_activation_norm, skip_nonlinearity,
|
762 |
+
nonlinearity, inplace_nonlinearity,
|
763 |
+
apply_noise, hidden_channels_equal_out_channels,
|
764 |
+
order, block, pooling, down_factor, learn_shortcut,
|
765 |
+
clamp=None, output_scale=1):
|
766 |
+
super().__init__(in_channels, out_channels, kernel_size,
|
767 |
+
stride, padding, dilation, groups, bias, padding_mode,
|
768 |
+
weight_norm_type, weight_norm_params,
|
769 |
+
activation_norm_type, activation_norm_params,
|
770 |
+
skip_activation_norm, skip_nonlinearity, nonlinearity,
|
771 |
+
inplace_nonlinearity, apply_noise,
|
772 |
+
hidden_channels_equal_out_channels, order, block,
|
773 |
+
learn_shortcut, clamp, output_scale)
|
774 |
+
self.pooling = pooling(down_factor)
|
775 |
+
|
776 |
+
def forward(self, x, *cond_inputs):
|
777 |
+
r"""
|
778 |
+
|
779 |
+
Args:
|
780 |
+
x (tensor) : Input tensor.
|
781 |
+
cond_inputs (list of tensors) : conditional input.
|
782 |
+
Returns:
|
783 |
+
output (tensor) : Output tensor.
|
784 |
+
"""
|
785 |
+
dx = self.conv_block_0(x, *cond_inputs)
|
786 |
+
dx = self.conv_block_1(dx, *cond_inputs)
|
787 |
+
dx = self.pooling(dx)
|
788 |
+
if self.learn_shortcut:
|
789 |
+
x_shortcut = self.conv_block_s(x, *cond_inputs)
|
790 |
+
else:
|
791 |
+
x_shortcut = x
|
792 |
+
x_shortcut = self.pooling(x_shortcut)
|
793 |
+
output = x_shortcut + dx
|
794 |
+
return self.output_scale * output
|
795 |
+
|
796 |
+
|
797 |
+
class DownRes2dBlock(_BaseDownResBlock):
|
798 |
+
r"""Residual block for 2D input with downsampling.
|
799 |
+
|
800 |
+
Args:
|
801 |
+
in_channels (int) : Number of channels in the input tensor.
|
802 |
+
out_channels (int) : Number of channels in the output tensor.
|
803 |
+
kernel_size (int, optional, default=3): Kernel size for the
|
804 |
+
convolutional filters in the residual link.
|
805 |
+
padding (int, optional, default=1): Padding size.
|
806 |
+
dilation (int, optional, default=1): Dilation factor.
|
807 |
+
groups (int, optional, default=1): Number of convolutional/linear
|
808 |
+
groups.
|
809 |
+
padding_mode (string, optional, default='zeros'): Type of padding:
|
810 |
+
``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
|
811 |
+
weight_norm_type (str, optional, default='none'):
|
812 |
+
Type of weight normalization.
|
813 |
+
``'none'``, ``'spectral'``, ``'weight'``
|
814 |
+
or ``'weight_demod'``.
|
815 |
+
weight_norm_params (obj, optional, default=None):
|
816 |
+
Parameters of weight normalization.
|
817 |
+
If not ``None``, ``weight_norm_params.__dict__`` will be used as
|
818 |
+
keyword arguments when initializing weight normalization.
|
819 |
+
activation_norm_type (str, optional, default='none'):
|
820 |
+
Type of activation normalization.
|
821 |
+
``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
|
822 |
+
``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
|
823 |
+
``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
|
824 |
+
activation_norm_params (obj, optional, default=None):
|
825 |
+
Parameters of activation normalization.
|
826 |
+
If not ``None``, ``activation_norm_params.__dict__`` will be used as
|
827 |
+
keyword arguments when initializing activation normalization.
|
828 |
+
skip_activation_norm (bool, optional, default=True): If ``True`` and
|
829 |
+
``learn_shortcut`` is also ``True``, applies activation norm to the
|
830 |
+
learned shortcut connection.
|
831 |
+
skip_nonlinearity (bool, optional, default=True): If ``True`` and
|
832 |
+
``learn_shortcut`` is also ``True``, applies nonlinearity to the
|
833 |
+
learned shortcut connection.
|
834 |
+
nonlinearity (str, optional, default='none'):
|
835 |
+
Type of nonlinear activation function in the residual link.
|
836 |
+
``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
|
837 |
+
``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
|
838 |
+
inplace_nonlinearity (bool, optional, default=False): If ``True``,
|
839 |
+
set ``inplace=True`` when initializing the nonlinearity layers.
|
840 |
+
apply_noise (bool, optional, default=False): If ``True``, adds
|
841 |
+
Gaussian noise with learnable magnitude to the convolution output.
|
842 |
+
hidden_channels_equal_out_channels (bool, optional, default=False):
|
843 |
+
If ``True``, set the hidden channel number to be equal to the
|
844 |
+
output channel number. If ``False``, the hidden channel number
|
845 |
+
equals to the smaller of the input channel number and the
|
846 |
+
output channel number.
|
847 |
+
order (str, optional, default='CNACNA'): Order of operations
|
848 |
+
in the residual link.
|
849 |
+
``'C'``: convolution,
|
850 |
+
``'N'``: normalization,
|
851 |
+
``'A'``: nonlinear activation.
|
852 |
+
pooling (class, optional, default=nn.AvgPool2d): Pytorch pooling
|
853 |
+
layer to be used.
|
854 |
+
down_factor (int, optional, default=2): Downsampling factor.
|
855 |
+
learn_shortcut (bool, optional, default=False): If ``True``, always use
|
856 |
+
a convolutional shortcut instead of an identity one, otherwise only
|
857 |
+
use a convolutional one if input and output have different number of
|
858 |
+
channels.
|
859 |
+
"""
|
860 |
+
|
861 |
+
def __init__(self, in_channels, out_channels, kernel_size=3,
|
862 |
+
stride=1, padding=1, dilation=1, groups=1, bias=True,
|
863 |
+
padding_mode='zeros',
|
864 |
+
weight_norm_type='none', weight_norm_params=None,
|
865 |
+
activation_norm_type='none', activation_norm_params=None,
|
866 |
+
skip_activation_norm=True, skip_nonlinearity=False,
|
867 |
+
nonlinearity='leakyrelu', inplace_nonlinearity=False,
|
868 |
+
apply_noise=False, hidden_channels_equal_out_channels=False,
|
869 |
+
order='CNACNA', pooling=nn.AvgPool2d, down_factor=2,
|
870 |
+
learn_shortcut=None, clamp=None, output_scale=1):
|
871 |
+
super().__init__(in_channels, out_channels, kernel_size,
|
872 |
+
stride, padding, dilation, groups, bias, padding_mode,
|
873 |
+
weight_norm_type, weight_norm_params,
|
874 |
+
activation_norm_type, activation_norm_params,
|
875 |
+
skip_activation_norm, skip_nonlinearity,
|
876 |
+
nonlinearity, inplace_nonlinearity, apply_noise,
|
877 |
+
hidden_channels_equal_out_channels,
|
878 |
+
order, Conv2dBlock, pooling,
|
879 |
+
down_factor, learn_shortcut, clamp, output_scale)
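# ---------------------------------------------------------------------------
# Illustrative usage sketch (documentation only): DownRes2dBlock runs the two
# convolutions and then the pooling layer, so with the default down_factor=2
# the spatial size is halved on both the residual and the shortcut path.
# Shapes are assumptions for the example.
def _example_downres2d_block():
    import torch
    block = DownRes2dBlock(32, 64, activation_norm_type='instance')
    x = torch.randn(1, 32, 64, 64)
    return block(x).shape  # expected: torch.Size([1, 64, 32, 32])
# ---------------------------------------------------------------------------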
|
880 |
+
|
881 |
+
|
882 |
+
class _BaseUpResBlock(_BaseResBlock):
|
883 |
+
r"""An abstract class for residual blocks with upsampling.
|
884 |
+
"""
|
885 |
+
|
886 |
+
def __init__(self, in_channels, out_channels, kernel_size,
|
887 |
+
stride, padding, dilation, groups, bias, padding_mode,
|
888 |
+
weight_norm_type, weight_norm_params,
|
889 |
+
activation_norm_type, activation_norm_params,
|
890 |
+
skip_activation_norm, skip_nonlinearity,
|
891 |
+
nonlinearity, inplace_nonlinearity,
|
892 |
+
apply_noise, hidden_channels_equal_out_channels,
|
893 |
+
order, block, upsample, up_factor, learn_shortcut, clamp=None,
|
894 |
+
output_scale=1):
|
895 |
+
super().__init__(in_channels, out_channels, kernel_size,
|
896 |
+
stride, padding, dilation, groups, bias, padding_mode,
|
897 |
+
weight_norm_type, weight_norm_params,
|
898 |
+
activation_norm_type, activation_norm_params,
|
899 |
+
skip_activation_norm, skip_nonlinearity, nonlinearity,
|
900 |
+
inplace_nonlinearity, apply_noise,
|
901 |
+
hidden_channels_equal_out_channels, order, block,
|
902 |
+
learn_shortcut, clamp, output_scale)
|
903 |
+
self.order = order
|
904 |
+
self.upsample = upsample(scale_factor=up_factor)
|
905 |
+
|
906 |
+
def _get_stride_blur(self):
|
907 |
+
# Upsampling.
|
908 |
+
first_stride, second_stride = self.stride, 1
|
909 |
+
first_blur, second_blur = self.blur, False
|
910 |
+
shortcut_blur = False
|
911 |
+
shortcut_stride = 1
|
912 |
+
# if self.upsample == 'blur_deconv':
|
913 |
+
|
914 |
+
if self.blur:
|
915 |
+
# The shortcut branch uses blur_upsample + stride-1 conv
|
916 |
+
self.upsample = BlurUpsample()
|
917 |
+
else:
|
918 |
+
shortcut_stride = self.stride
|
919 |
+
self.upsample = nn.Upsample(scale_factor=2)
|
920 |
+
|
921 |
+
return (first_stride, second_stride, shortcut_stride,
|
922 |
+
first_blur, second_blur, shortcut_blur)
|
923 |
+
|
924 |
+
def forward(self, x, *cond_inputs):
|
925 |
+
r"""Implementation of the up residual block forward function.
|
926 |
+
If the order is 'NAC' for the first residual block, we will first
|
927 |
+
do the activation norm and nonlinearity at the original resolution.
|
928 |
+
We will then upsample the activation map to a higher resolution. We
|
929 |
+
then do the convolution.
|
930 |
+
If the order is different, we first do the whole processing and
|
931 |
+
then upsample.
|
932 |
+
|
933 |
+
Args:
|
934 |
+
x (tensor) : Input tensor.
|
935 |
+
cond_inputs (list of tensors) : Conditional input.
|
936 |
+
Returns:
|
937 |
+
output (tensor) : Output tensor.
|
938 |
+
"""
|
939 |
+
# In this particular upsample residual block operation, we first
|
940 |
+
# upsample the skip connection.
|
941 |
+
if self.learn_shortcut:
|
942 |
+
x_shortcut = self.upsample(x)
|
943 |
+
x_shortcut = self.conv_block_s(x_shortcut, *cond_inputs)
|
944 |
+
else:
|
945 |
+
x_shortcut = self.upsample(x)
|
946 |
+
|
947 |
+
if self.order[0:3] == 'NAC':
|
948 |
+
for ix, layer in enumerate(self.conv_block_0.layers.values()):
|
949 |
+
if getattr(layer, 'conditional', False):
|
950 |
+
x = layer(x, *cond_inputs)
|
951 |
+
else:
|
952 |
+
x = layer(x)
|
953 |
+
if ix == 1:
|
954 |
+
x = self.upsample(x)
|
955 |
+
else:
|
956 |
+
x = self.conv_block_0(x, *cond_inputs)
|
957 |
+
x = self.upsample(x)
|
958 |
+
x = self.conv_block_1(x, *cond_inputs)
|
959 |
+
|
960 |
+
output = x_shortcut + x
|
961 |
+
return self.output_scale * output
|
962 |
+
|
963 |
+
|
964 |
+
class UpRes2dBlock(_BaseUpResBlock):
|
965 |
+
r"""Residual block for 2D input with downsampling.
|
966 |
+
|
967 |
+
Args:
|
968 |
+
in_channels (int) : Number of channels in the input tensor.
|
969 |
+
out_channels (int) : Number of channels in the output tensor.
|
970 |
+
kernel_size (int, optional, default=3): Kernel size for the
|
971 |
+
convolutional filters in the residual link.
|
972 |
+
padding (int, optional, default=1): Padding size.
|
973 |
+
dilation (int, optional, default=1): Dilation factor.
|
974 |
+
groups (int, optional, default=1): Number of convolutional/linear
|
975 |
+
groups.
|
976 |
+
padding_mode (string, optional, default='zeros'): Type of padding:
|
977 |
+
``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
|
978 |
+
weight_norm_type (str, optional, default='none'):
|
979 |
+
Type of weight normalization.
|
980 |
+
``'none'``, ``'spectral'``, ``'weight'``
|
981 |
+
or ``'weight_demod'``.
|
982 |
+
weight_norm_params (obj, optional, default=None):
|
983 |
+
Parameters of weight normalization.
|
984 |
+
If not ``None``, ``weight_norm_params.__dict__`` will be used as
|
985 |
+
keyword arguments when initializing weight normalization.
|
986 |
+
activation_norm_type (str, optional, default='none'):
|
987 |
+
Type of activation normalization.
|
988 |
+
``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
|
989 |
+
``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
|
990 |
+
``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
|
991 |
+
activation_norm_params (obj, optional, default=None):
|
992 |
+
Parameters of activation normalization.
|
993 |
+
If not ``None``, ``activation_norm_params.__dict__`` will be used as
|
994 |
+
keyword arguments when initializing activation normalization.
|
995 |
+
skip_activation_norm (bool, optional, default=True): If ``True`` and
|
996 |
+
``learn_shortcut`` is also ``True``, applies activation norm to the
|
997 |
+
learned shortcut connection.
|
998 |
+
skip_nonlinearity (bool, optional, default=True): If ``True`` and
|
999 |
+
``learn_shortcut`` is also ``True``, applies nonlinearity to the
|
1000 |
+
learned shortcut connection.
|
1001 |
+
nonlinearity (str, optional, default='none'):
|
1002 |
+
Type of nonlinear activation function in the residual link.
|
1003 |
+
``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
|
1004 |
+
``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
|
1005 |
+
inplace_nonlinearity (bool, optional, default=False): If ``True``,
|
1006 |
+
set ``inplace=True`` when initializing the nonlinearity layers.
|
1007 |
+
apply_noise (bool, optional, default=False): If ``True``, adds
|
1008 |
+
Gaussian noise with learnable magnitude to the convolution output.
|
1009 |
+
hidden_channels_equal_out_channels (bool, optional, default=False):
|
1010 |
+
If ``True``, set the hidden channel number to be equal to the
|
1011 |
+
output channel number. If ``False``, the hidden channel number
|
1012 |
+
equals to the smaller of the input channel number and the
|
1013 |
+
output channel number.
|
1014 |
+
order (str, optional, default='CNACNA'): Order of operations
|
1015 |
+
in the residual link.
|
1016 |
+
``'C'``: convolution,
|
1017 |
+
``'N'``: normalization,
|
1018 |
+
``'A'``: nonlinear activation.
|
1019 |
+
upsample (class, optional, default=NearestUpsample): Pytorch
|
1020 |
+
upsampling layer to be used.
|
1021 |
+
up_factor (int, optional, default=2): Upsampling factor.
|
1022 |
+
learn_shortcut (bool, optional, default=False): If ``True``, always use
|
1023 |
+
a convolutional shortcut instead of an identity one, otherwise only
|
1024 |
+
use a convolutional one if input and output have different number of
|
1025 |
+
channels.
|
1026 |
+
"""
|
1027 |
+
|
1028 |
+
def __init__(self, in_channels, out_channels, kernel_size=3,
|
1029 |
+
stride=1, padding=1, dilation=1, groups=1, bias=True,
|
1030 |
+
padding_mode='zeros',
|
1031 |
+
weight_norm_type='none', weight_norm_params=None,
|
1032 |
+
activation_norm_type='none', activation_norm_params=None,
|
1033 |
+
skip_activation_norm=True, skip_nonlinearity=False,
|
1034 |
+
nonlinearity='leakyrelu', inplace_nonlinearity=False,
|
1035 |
+
apply_noise=False, hidden_channels_equal_out_channels=False,
|
1036 |
+
order='CNACNA', upsample=NearestUpsample, up_factor=2,
|
1037 |
+
learn_shortcut=None, clamp=None, output_scale=1):
|
1038 |
+
super().__init__(in_channels, out_channels, kernel_size,
|
1039 |
+
stride, padding, dilation, groups, bias, padding_mode,
|
1040 |
+
weight_norm_type, weight_norm_params,
|
1041 |
+
activation_norm_type, activation_norm_params,
|
1042 |
+
skip_activation_norm, skip_nonlinearity,
|
1043 |
+
nonlinearity, inplace_nonlinearity,
|
1044 |
+
apply_noise, hidden_channels_equal_out_channels,
|
1045 |
+
order, Conv2dBlock,
|
1046 |
+
upsample, up_factor, learn_shortcut, clamp,
|
1047 |
+
output_scale)
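# ---------------------------------------------------------------------------
# Illustrative usage sketch (documentation only): UpRes2dBlock upsamples both
# the residual path and the shortcut by up_factor (nearest-neighbour by
# default), so the output spatial size doubles with the default up_factor=2.
# Shapes are assumptions for the example.
def _example_upres2d_block():
    import torch
    block = UpRes2dBlock(64, 32, activation_norm_type='instance')
    x = torch.randn(1, 64, 16, 16)
    return block(x).shape  # expected: torch.Size([1, 32, 32, 32])
# ---------------------------------------------------------------------------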
|
1048 |
+
|
1049 |
+
|
1050 |
+
class _BasePartialResBlock(_BaseResBlock):
|
1051 |
+
r"""An abstract class for residual blocks with partial convolution.
|
1052 |
+
"""
|
1053 |
+
|
1054 |
+
def __init__(self, in_channels, out_channels, kernel_size,
|
1055 |
+
stride, padding, dilation, groups, bias, padding_mode,
|
1056 |
+
weight_norm_type, weight_norm_params,
|
1057 |
+
activation_norm_type, activation_norm_params,
|
1058 |
+
skip_activation_norm, skip_nonlinearity,
|
1059 |
+
nonlinearity, inplace_nonlinearity,
|
1060 |
+
multi_channel, return_mask,
|
1061 |
+
apply_noise, hidden_channels_equal_out_channels,
|
1062 |
+
order, block, learn_shortcut, clamp=None, output_scale=1):
|
1063 |
+
block = functools.partial(block,
|
1064 |
+
multi_channel=multi_channel,
|
1065 |
+
return_mask=return_mask)
|
1066 |
+
self.partial_conv = True
|
1067 |
+
super().__init__(in_channels, out_channels, kernel_size, stride,
|
1068 |
+
padding, dilation, groups, bias, padding_mode,
|
1069 |
+
weight_norm_type, weight_norm_params,
|
1070 |
+
activation_norm_type, activation_norm_params,
|
1071 |
+
skip_activation_norm, skip_nonlinearity, nonlinearity,
|
1072 |
+
inplace_nonlinearity, apply_noise,
|
1073 |
+
hidden_channels_equal_out_channels, order, block,
|
1074 |
+
learn_shortcut, clamp, output_scale)
|
1075 |
+
|
1076 |
+
def forward(self, x, *cond_inputs, mask_in=None, **kw_cond_inputs):
|
1077 |
+
r"""
|
1078 |
+
|
1079 |
+
Args:
|
1080 |
+
x (tensor): Input tensor.
|
1081 |
+
cond_inputs (list of tensors) : Conditional input tensors.
|
1082 |
+
mask_in (tensor, optional, default=``None``): If not ``None``,
|
1083 |
+
it masks the valid input region.
|
1084 |
+
kw_cond_inputs (dict) : Keyword conditional inputs.
|
1085 |
+
Returns:
|
1086 |
+
(tuple):
|
1087 |
+
- output (tensor): Output tensor.
|
1088 |
+
- mask_out (tensor, optional): Masks the valid output region.
|
1089 |
+
"""
|
1090 |
+
if self.conv_block_0.layers.conv.return_mask:
|
1091 |
+
dx, mask_out = self.conv_block_0(x, *cond_inputs,
|
1092 |
+
mask_in=mask_in, **kw_cond_inputs)
|
1093 |
+
dx, mask_out = self.conv_block_1(dx, *cond_inputs,
|
1094 |
+
mask_in=mask_out, **kw_cond_inputs)
|
1095 |
+
else:
|
1096 |
+
dx = self.conv_block_0(x, *cond_inputs,
|
1097 |
+
mask_in=mask_in, **kw_cond_inputs)
|
1098 |
+
dx = self.conv_block_1(dx, *cond_inputs,
|
1099 |
+
mask_in=mask_in, **kw_cond_inputs)
|
1100 |
+
mask_out = None
|
1101 |
+
|
1102 |
+
if self.learn_shortcut:
|
1103 |
+
x_shortcut = self.conv_block_s(x, mask_in=mask_in, *cond_inputs,
|
1104 |
+
**kw_cond_inputs)
|
1105 |
+
if type(x_shortcut) == tuple:
|
1106 |
+
x_shortcut, _ = x_shortcut
|
1107 |
+
else:
|
1108 |
+
x_shortcut = x
|
1109 |
+
output = x_shortcut + dx
|
1110 |
+
|
1111 |
+
if mask_out is not None:
|
1112 |
+
return output, mask_out
|
1113 |
+
return self.output_scale * output
|
1114 |
+
|
1115 |
+
|
1116 |
+
class PartialRes2dBlock(_BasePartialResBlock):
|
1117 |
+
r"""Residual block for 2D input with partial convolution.
|
1118 |
+
|
1119 |
+
Args:
|
1120 |
+
in_channels (int) : Number of channels in the input tensor.
|
1121 |
+
out_channels (int) : Number of channels in the output tensor.
|
1122 |
+
kernel_size (int, optional, default=3): Kernel size for the
|
1123 |
+
convolutional filters in the residual link.
|
1124 |
+
padding (int, optional, default=1): Padding size.
|
1125 |
+
dilation (int, optional, default=1): Dilation factor.
|
1126 |
+
groups (int, optional, default=1): Number of convolutional/linear
|
1127 |
+
groups.
|
1128 |
+
padding_mode (string, optional, default='zeros'): Type of padding:
|
1129 |
+
``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
|
1130 |
+
weight_norm_type (str, optional, default='none'):
|
1131 |
+
Type of weight normalization.
|
1132 |
+
``'none'``, ``'spectral'``, ``'weight'``
|
1133 |
+
or ``'weight_demod'``.
|
1134 |
+
weight_norm_params (obj, optional, default=None):
|
1135 |
+
Parameters of weight normalization.
|
1136 |
+
If not ``None``, ``weight_norm_params.__dict__`` will be used as
|
1137 |
+
keyword arguments when initializing weight normalization.
|
1138 |
+
activation_norm_type (str, optional, default='none'):
|
1139 |
+
Type of activation normalization.
|
1140 |
+
``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
|
1141 |
+
``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
|
1142 |
+
``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
|
1143 |
+
activation_norm_params (obj, optional, default=None):
|
1144 |
+
Parameters of activation normalization.
|
1145 |
+
If not ``None``, ``activation_norm_params.__dict__`` will be used as
|
1146 |
+
keyword arguments when initializing activation normalization.
|
1147 |
+
skip_activation_norm (bool, optional, default=True): If ``True`` and
|
1148 |
+
``learn_shortcut`` is also ``True``, applies activation norm to the
|
1149 |
+
learned shortcut connection.
|
1150 |
+
skip_nonlinearity (bool, optional, default=True): If ``True`` and
|
1151 |
+
``learn_shortcut`` is also ``True``, applies nonlinearity to the
|
1152 |
+
learned shortcut connection.
|
1153 |
+
nonlinearity (str, optional, default='none'):
|
1154 |
+
Type of nonlinear activation function in the residual link.
|
1155 |
+
``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
|
1156 |
+
``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
|
1157 |
+
inplace_nonlinearity (bool, optional, default=False): If ``True``,
|
1158 |
+
set ``inplace=True`` when initializing the nonlinearity layers.
|
1159 |
+
apply_noise (bool, optional, default=False): If ``True``, adds
|
1160 |
+
Gaussian noise with learnable magnitude to the convolution output.
|
1161 |
+
hidden_channels_equal_out_channels (bool, optional, default=False):
|
1162 |
+
If ``True``, set the hidden channel number to be equal to the
|
1163 |
+
output channel number. If ``False``, the hidden channel number
|
1164 |
+
equals to the smaller of the input channel number and the
|
1165 |
+
output channel number.
|
1166 |
+
order (str, optional, default='CNACNA'): Order of operations
|
1167 |
+
in the residual link.
|
1168 |
+
``'C'``: convolution,
|
1169 |
+
``'N'``: normalization,
|
1170 |
+
``'A'``: nonlinear activation.
|
1171 |
+
learn_shortcut (bool, optional, default=False): If ``True``, always use
|
1172 |
+
a convolutional shortcut instead of an identity one, otherwise only
|
1173 |
+
use a convolutional one if input and output have different number of
|
1174 |
+
channels.
|
1175 |
+
"""
|
1176 |
+
|
1177 |
+
def __init__(self, in_channels, out_channels, kernel_size=3,
|
1178 |
+
stride=1, padding=1, dilation=1, groups=1, bias=True,
|
1179 |
+
padding_mode='zeros',
|
1180 |
+
weight_norm_type='none', weight_norm_params=None,
|
1181 |
+
activation_norm_type='none', activation_norm_params=None,
|
1182 |
+
skip_activation_norm=True, skip_nonlinearity=False,
|
1183 |
+
nonlinearity='leakyrelu', inplace_nonlinearity=False,
|
1184 |
+
multi_channel=False, return_mask=True,
|
1185 |
+
apply_noise=False,
|
1186 |
+
hidden_channels_equal_out_channels=False,
|
1187 |
+
order='CNACNA', learn_shortcut=None, clamp=None,
|
1188 |
+
output_scale=1):
|
1189 |
+
super().__init__(in_channels, out_channels, kernel_size,
|
1190 |
+
stride, padding, dilation, groups, bias,
|
1191 |
+
padding_mode, weight_norm_type, weight_norm_params,
|
1192 |
+
activation_norm_type, activation_norm_params,
|
1193 |
+
skip_activation_norm, skip_nonlinearity, nonlinearity,
|
1194 |
+
inplace_nonlinearity, multi_channel, return_mask,
|
1195 |
+
apply_noise, hidden_channels_equal_out_channels,
|
1196 |
+
order, PartialConv2dBlock, learn_shortcut, clamp,
|
1197 |
+
output_scale)
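# ---------------------------------------------------------------------------
# Illustrative usage sketch (documentation only): with the default
# return_mask=True the partial-convolution block propagates a validity mask
# alongside the features, so forward() returns a (features, mask) pair. The
# single-channel mask shape follows the usual partial-convolution convention
# and is an assumption for this example.
def _example_partial_res2d_block():
    import torch
    block = PartialRes2dBlock(16, 16)
    x = torch.randn(1, 16, 32, 32)
    mask = torch.ones(1, 1, 32, 32)
    out, mask_out = block(x, mask_in=mask)
    return out.shape, mask_out.shape
# ---------------------------------------------------------------------------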
|
1198 |
+
|
1199 |
+
|
1200 |
+
class PartialRes3dBlock(_BasePartialResBlock):
|
1201 |
+
r"""Residual block for 3D input with partial convolution.
|
1202 |
+
|
1203 |
+
Args:
|
1204 |
+
in_channels (int) : Number of channels in the input tensor.
|
1205 |
+
out_channels (int) : Number of channels in the output tensor.
|
1206 |
+
kernel_size (int, optional, default=3): Kernel size for the
|
1207 |
+
convolutional filters in the residual link.
|
1208 |
+
padding (int, optional, default=1): Padding size.
|
1209 |
+
dilation (int, optional, default=1): Dilation factor.
|
1210 |
+
groups (int, optional, default=1): Number of convolutional/linear
|
1211 |
+
groups.
|
1212 |
+
padding_mode (string, optional, default='zeros'): Type of padding:
|
1213 |
+
``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
|
1214 |
+
weight_norm_type (str, optional, default='none'):
|
1215 |
+
Type of weight normalization.
|
1216 |
+
``'none'``, ``'spectral'``, ``'weight'``
|
1217 |
+
or ``'weight_demod'``.
|
1218 |
+
weight_norm_params (obj, optional, default=None):
|
1219 |
+
Parameters of weight normalization.
|
1220 |
+
If not ``None``, ``weight_norm_params.__dict__`` will be used as
|
1221 |
+
keyword arguments when initializing weight normalization.
|
1222 |
+
activation_norm_type (str, optional, default='none'):
|
1223 |
+
Type of activation normalization.
|
1224 |
+
``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
|
1225 |
+
``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
|
1226 |
+
``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
|
1227 |
+
activation_norm_params (obj, optional, default=None):
|
1228 |
+
Parameters of activation normalization.
|
1229 |
+
If not ``None``, ``activation_norm_params.__dict__`` will be used as
|
1230 |
+
keyword arguments when initializing activation normalization.
|
1231 |
+
skip_activation_norm (bool, optional, default=True): If ``True`` and
|
1232 |
+
``learn_shortcut`` is also ``True``, applies activation norm to the
|
1233 |
+
learned shortcut connection.
|
1234 |
+
skip_nonlinearity (bool, optional, default=True): If ``True`` and
|
1235 |
+
``learn_shortcut`` is also ``True``, applies nonlinearity to the
|
1236 |
+
learned shortcut connection.
|
1237 |
+
nonlinearity (str, optional, default='none'):
|
1238 |
+
Type of nonlinear activation function in the residual link.
|
1239 |
+
``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
|
1240 |
+
``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
|
1241 |
+
inplace_nonlinearity (bool, optional, default=False): If ``True``,
|
1242 |
+
set ``inplace=True`` when initializing the nonlinearity layers.
|
1243 |
+
apply_noise (bool, optional, default=False): If ``True``, adds
|
1244 |
+
Gaussian noise with learnable magnitude to the convolution output.
|
1245 |
+
hidden_channels_equal_out_channels (bool, optional, default=False):
|
1246 |
+
If ``True``, set the hidden channel number to be equal to the
|
1247 |
+
output channel number. If ``False``, the hidden channel number
|
1248 |
+
equals to the smaller of the input channel number and the
|
1249 |
+
output channel number.
|
1250 |
+
order (str, optional, default='CNACNA'): Order of operations
|
1251 |
+
in the residual link.
|
1252 |
+
``'C'``: convolution,
|
1253 |
+
``'N'``: normalization,
|
1254 |
+
``'A'``: nonlinear activation.
|
1255 |
+
learn_shortcut (bool, optional, default=False): If ``True``, always use
|
1256 |
+
a convolutional shortcut instead of an identity one, otherwise only
|
1257 |
+
use a convolutional one if input and output have different number of
|
1258 |
+
channels.
|
1259 |
+
"""
|
1260 |
+
|
1261 |
+
def __init__(self, in_channels, out_channels, kernel_size=3,
|
1262 |
+
stride=1, padding=1, dilation=1, groups=1, bias=True,
|
1263 |
+
padding_mode='zeros',
|
1264 |
+
weight_norm_type='none', weight_norm_params=None,
|
1265 |
+
activation_norm_type='none', activation_norm_params=None,
|
1266 |
+
skip_activation_norm=True, skip_nonlinearity=False,
|
1267 |
+
nonlinearity='leakyrelu', inplace_nonlinearity=False,
|
1268 |
+
multi_channel=False, return_mask=True,
|
1269 |
+
apply_noise=False, hidden_channels_equal_out_channels=False,
|
1270 |
+
order='CNACNA', learn_shortcut=None, clamp=None,
|
1271 |
+
output_scale=1):
|
1272 |
+
super().__init__(in_channels, out_channels, kernel_size,
|
1273 |
+
stride, padding, dilation, groups, bias,
|
1274 |
+
padding_mode, weight_norm_type, weight_norm_params,
|
1275 |
+
activation_norm_type, activation_norm_params,
|
1276 |
+
skip_activation_norm, skip_nonlinearity,
|
1277 |
+
nonlinearity, inplace_nonlinearity, multi_channel,
|
1278 |
+
return_mask, apply_noise,
|
1279 |
+
hidden_channels_equal_out_channels,
|
1280 |
+
order, PartialConv3dBlock, learn_shortcut, clamp,
|
1281 |
+
output_scale)
|
1282 |
+
|
1283 |
+
|
1284 |
+
class _BaseMultiOutResBlock(_BaseResBlock):
|
1285 |
+
r"""An abstract class for residual blocks that can returns multiple outputs.
|
1286 |
+
"""
|
1287 |
+
|
1288 |
+
def __init__(self, in_channels, out_channels, kernel_size,
|
1289 |
+
stride, padding, dilation, groups, bias, padding_mode,
|
1290 |
+
weight_norm_type, weight_norm_params,
|
1291 |
+
activation_norm_type, activation_norm_params,
|
1292 |
+
skip_activation_norm, skip_nonlinearity,
|
1293 |
+
nonlinearity, inplace_nonlinearity,
|
1294 |
+
apply_noise, hidden_channels_equal_out_channels,
|
1295 |
+
order, block, learn_shortcut, clamp=None, output_scale=1,
|
1296 |
+
blur=False, upsample_first=True):
|
1297 |
+
self.multiple_outputs = True
|
1298 |
+
super().__init__(in_channels, out_channels, kernel_size, stride,
|
1299 |
+
padding, dilation, groups, bias, padding_mode,
|
1300 |
+
weight_norm_type, weight_norm_params,
|
1301 |
+
activation_norm_type, activation_norm_params,
|
1302 |
+
skip_activation_norm, skip_nonlinearity, nonlinearity,
|
1303 |
+
inplace_nonlinearity, apply_noise,
|
1304 |
+
hidden_channels_equal_out_channels, order, block,
|
1305 |
+
learn_shortcut, clamp, output_scale, blur=blur,
|
1306 |
+
upsample_first=upsample_first)
|
1307 |
+
|
1308 |
+
def forward(self, x, *cond_inputs):
|
1309 |
+
r"""
|
1310 |
+
|
1311 |
+
Args:
|
1312 |
+
x (tensor): Input tensor.
|
1313 |
+
cond_inputs (list of tensors) : Conditional input tensors.
|
1314 |
+
Returns:
|
1315 |
+
(tuple):
|
1316 |
+
- output (tensor): Output tensor.
|
1317 |
+
- aux_outputs_0 (tensor): Auxiliary output of the first block.
|
1318 |
+
- aux_outputs_1 (tensor): Auxiliary output of the second block.
|
1319 |
+
"""
|
1320 |
+
dx, aux_outputs_0 = self.conv_block_0(x, *cond_inputs)
|
1321 |
+
dx, aux_outputs_1 = self.conv_block_1(dx, *cond_inputs)
|
1322 |
+
if self.learn_shortcut:
|
1323 |
+
# We are not using the auxiliary outputs of self.conv_block_s.
|
1324 |
+
x_shortcut, _ = self.conv_block_s(x, *cond_inputs)
|
1325 |
+
else:
|
1326 |
+
x_shortcut = x
|
1327 |
+
output = x_shortcut + dx
|
1328 |
+
return self.output_scale * output, aux_outputs_0, aux_outputs_1
|
1329 |
+
|
1330 |
+
|
1331 |
+
class MultiOutRes2dBlock(_BaseMultiOutResBlock):
|
1332 |
+
r"""Residual block for 2D input. It can return multiple outputs, if some
|
1333 |
+
layers in the block return more than one output.
|
1334 |
+
|
1335 |
+
Args:
|
1336 |
+
in_channels (int) : Number of channels in the input tensor.
|
1337 |
+
out_channels (int) : Number of channels in the output tensor.
|
1338 |
+
kernel_size (int, optional, default=3): Kernel size for the
|
1339 |
+
convolutional filters in the residual link.
|
1340 |
+
padding (int, optional, default=1): Padding size.
|
1341 |
+
dilation (int, optional, default=1): Dilation factor.
|
1342 |
+
groups (int, optional, default=1): Number of convolutional/linear
|
1343 |
+
groups.
|
1344 |
+
padding_mode (string, optional, default='zeros'): Type of padding:
|
1345 |
+
``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
|
1346 |
+
weight_norm_type (str, optional, default='none'):
|
1347 |
+
Type of weight normalization.
|
1348 |
+
``'none'``, ``'spectral'``, ``'weight'``
|
1349 |
+
or ``'weight_demod'``.
|
1350 |
+
weight_norm_params (obj, optional, default=None):
|
1351 |
+
Parameters of weight normalization.
|
1352 |
+
If not ``None``, ``weight_norm_params.__dict__`` will be used as
|
1353 |
+
keyword arguments when initializing weight normalization.
|
1354 |
+
activation_norm_type (str, optional, default='none'):
|
1355 |
+
Type of activation normalization.
|
1356 |
+
``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
|
1357 |
+
``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
|
1358 |
+
``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
|
1359 |
+
activation_norm_params (obj, optional, default=None):
|
1360 |
+
Parameters of activation normalization.
|
1361 |
+
If not ``None``, ``activation_norm_params.__dict__`` will be used as
|
1362 |
+
keyword arguments when initializing activation normalization.
|
1363 |
+
skip_activation_norm (bool, optional, default=True): If ``True`` and
|
1364 |
+
``learn_shortcut`` is also ``True``, applies activation norm to the
|
1365 |
+
learned shortcut connection.
|
1366 |
+
skip_nonlinearity (bool, optional, default=True): If ``True`` and
|
1367 |
+
``learn_shortcut`` is also ``True``, applies nonlinearity to the
|
1368 |
+
learned shortcut connection.
|
1369 |
+
nonlinearity (str, optional, default='none'):
|
1370 |
+
Type of nonlinear activation function in the residual link.
|
1371 |
+
``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
|
1372 |
+
``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
|
1373 |
+
inplace_nonlinearity (bool, optional, default=False): If ``True``,
|
1374 |
+
set ``inplace=True`` when initializing the nonlinearity layers.
|
1375 |
+
apply_noise (bool, optional, default=False): If ``True``, adds
|
1376 |
+
Gaussian noise with learnable magnitude to the convolution output.
|
1377 |
+
hidden_channels_equal_out_channels (bool, optional, default=False):
|
1378 |
+
If ``True``, set the hidden channel number to be equal to the
|
1379 |
+
output channel number. If ``False``, the hidden channel number
|
1380 |
+
equals to the smaller of the input channel number and the
|
1381 |
+
output channel number.
|
1382 |
+
order (str, optional, default='CNACNA'): Order of operations
|
1383 |
+
in the residual link.
|
1384 |
+
``'C'``: convolution,
|
1385 |
+
``'N'``: normalization,
|
1386 |
+
``'A'``: nonlinear activation.
|
1387 |
+
learn_shortcut (bool, optional, default=False): If ``True``, always use
|
1388 |
+
a convolutional shortcut instead of an identity one, otherwise only
|
1389 |
+
use a convolutional one if input and output have different number of
|
1390 |
+
channels.
|
1391 |
+
"""
|
1392 |
+
|
1393 |
+
def __init__(self, in_channels, out_channels, kernel_size=3,
|
1394 |
+
stride=1, padding=1, dilation=1, groups=1, bias=True,
|
1395 |
+
padding_mode='zeros',
|
1396 |
+
weight_norm_type='none', weight_norm_params=None,
|
1397 |
+
activation_norm_type='none', activation_norm_params=None,
|
1398 |
+
skip_activation_norm=True, skip_nonlinearity=False,
|
1399 |
+
nonlinearity='leakyrelu', inplace_nonlinearity=False,
|
1400 |
+
apply_noise=False, hidden_channels_equal_out_channels=False,
|
1401 |
+
order='CNACNA', learn_shortcut=None, clamp=None,
|
1402 |
+
output_scale=1, blur=False, upsample_first=True):
|
1403 |
+
super().__init__(in_channels, out_channels, kernel_size, stride,
|
1404 |
+
padding, dilation, groups, bias, padding_mode,
|
1405 |
+
weight_norm_type, weight_norm_params,
|
1406 |
+
activation_norm_type, activation_norm_params,
|
1407 |
+
skip_activation_norm, skip_nonlinearity, nonlinearity,
|
1408 |
+
inplace_nonlinearity, apply_noise,
|
1409 |
+
hidden_channels_equal_out_channels, order,
|
1410 |
+
MultiOutConv2dBlock, learn_shortcut, clamp,
|
1411 |
+
output_scale, blur=blur, upsample_first=upsample_first)
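# ---------------------------------------------------------------------------
# Illustrative usage sketch (documentation only): MultiOutRes2dBlock behaves
# like a regular residual block but also returns the auxiliary outputs of its
# two MultiOutConv2dBlock layers, giving an (output, aux_0, aux_1) triple.
# What the auxiliary tensors contain is defined by MultiOutConv2dBlock and is
# not assumed here; shapes below are for the example only.
def _example_multi_out_res2d_block():
    import torch
    block = MultiOutRes2dBlock(32, 32)
    x = torch.randn(1, 32, 16, 16)
    out, aux_0, aux_1 = block(x)
    return out.shape  # expected: torch.Size([1, 32, 16, 16])
# ---------------------------------------------------------------------------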
|
imaginaire/layers/residual_deep.py
ADDED
@@ -0,0 +1,346 @@
1 |
+
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# This work is made available under the Nvidia Source Code License-NC.
|
4 |
+
# To view a copy of this license, check out LICENSE.md
|
5 |
+
import torch
|
6 |
+
from torch import nn
|
7 |
+
from torch.utils.checkpoint import checkpoint
|
8 |
+
|
9 |
+
from imaginaire.third_party.upfirdn2d import BlurDownsample, BlurUpsample
|
10 |
+
from .conv import Conv2dBlock
|
11 |
+
|
12 |
+
|
13 |
+
class _BaseDeepResBlock(nn.Module):
|
14 |
+
def __init__(self, in_channels, out_channels, kernel_size,
|
15 |
+
stride, padding, dilation, groups, bias, padding_mode,
|
16 |
+
weight_norm_type, weight_norm_params,
|
17 |
+
activation_norm_type, activation_norm_params,
|
18 |
+
skip_activation_norm, skip_nonlinearity,
|
19 |
+
nonlinearity, inplace_nonlinearity, apply_noise,
|
20 |
+
hidden_channels_equal_out_channels,
|
21 |
+
order, block, learn_shortcut, output_scale, skip_block=None,
|
22 |
+
blur=True, border_free=True, resample_first=True,
|
23 |
+
skip_weight_norm=True, hidden_channel_ratio=4):
|
24 |
+
super().__init__()
|
25 |
+
self.in_channels = in_channels
|
26 |
+
self.out_channels = out_channels
|
27 |
+
self.output_scale = output_scale
|
28 |
+
self.resample_first = resample_first
|
29 |
+
self.stride = stride
|
30 |
+
self.blur = blur
|
31 |
+
self.border_free = border_free
|
32 |
+
assert not border_free
|
33 |
+
if skip_block is None:
|
34 |
+
skip_block = block
|
35 |
+
|
36 |
+
if order == 'pre_act':
|
37 |
+
order = 'NACNAC'
|
38 |
+
if isinstance(bias, bool):
|
39 |
+
# The bias for conv_block_0, conv_block_1, and conv_block_s.
|
40 |
+
biases = [bias, bias, bias]
|
41 |
+
elif isinstance(bias, list):
|
42 |
+
if len(bias) == 3:
|
43 |
+
biases = bias
|
44 |
+
else:
|
45 |
+
raise ValueError('Bias list must have 3 elements.')
|
46 |
+
else:
|
47 |
+
raise ValueError('Bias must be either a bool or a list.')
|
48 |
+
self.learn_shortcut = learn_shortcut
|
49 |
+
if len(order) > 6 or len(order) < 5:
|
50 |
+
raise ValueError('order must be either 5 or 6 characters')
|
51 |
+
hidden_channels = in_channels // hidden_channel_ratio
|
52 |
+
|
53 |
+
# Parameters.
|
54 |
+
residual_params = {}
|
55 |
+
shortcut_params = {}
|
56 |
+
base_params = dict(dilation=dilation,
|
57 |
+
groups=groups,
|
58 |
+
padding_mode=padding_mode)
|
59 |
+
residual_params.update(base_params)
|
60 |
+
residual_params.update(
|
61 |
+
dict(activation_norm_type=activation_norm_type,
|
62 |
+
activation_norm_params=activation_norm_params,
|
63 |
+
weight_norm_type=weight_norm_type,
|
64 |
+
weight_norm_params=weight_norm_params,
|
65 |
+
apply_noise=apply_noise)
|
66 |
+
)
|
67 |
+
shortcut_params.update(base_params)
|
68 |
+
shortcut_params.update(dict(kernel_size=1))
|
69 |
+
if skip_activation_norm:
|
70 |
+
shortcut_params.update(
|
71 |
+
dict(activation_norm_type=activation_norm_type,
|
72 |
+
activation_norm_params=activation_norm_params,
|
73 |
+
apply_noise=False))
|
74 |
+
if skip_weight_norm:
|
75 |
+
shortcut_params.update(
|
76 |
+
dict(weight_norm_type=weight_norm_type,
|
77 |
+
weight_norm_params=weight_norm_params))
|
78 |
+
|
79 |
+
# Residual branch.
|
80 |
+
if order.find('A') < order.find('C') and \
|
81 |
+
(activation_norm_type == '' or activation_norm_type == 'none'):
|
82 |
+
# Nonlinearity is the first operation in the residual path.
|
83 |
+
# In-place nonlinearity will modify the input variable and cause
|
84 |
+
# backward error.
|
85 |
+
first_inplace = False
|
86 |
+
else:
|
87 |
+
first_inplace = inplace_nonlinearity
|
88 |
+
|
89 |
+
(first_stride, second_stride, shortcut_stride,
|
90 |
+
first_blur, second_blur, shortcut_blur) = self._get_stride_blur()
|
91 |
+
|
92 |
+
self.conv_block_1x1_in = block(
|
93 |
+
in_channels, hidden_channels,
|
94 |
+
1, 1, 0,
|
95 |
+
bias=biases[0],
|
96 |
+
nonlinearity=nonlinearity,
|
97 |
+
order=order[0:3],
|
98 |
+
inplace_nonlinearity=first_inplace,
|
99 |
+
**residual_params
|
100 |
+
)
|
101 |
+
|
102 |
+
self.conv_block_0 = block(
|
103 |
+
hidden_channels, hidden_channels,
|
104 |
+
kernel_size=2 if self.border_free and first_stride < 1 else
|
105 |
+
kernel_size,
|
106 |
+
padding=padding,
|
107 |
+
bias=biases[0],
|
108 |
+
nonlinearity=nonlinearity,
|
109 |
+
order=order[0:3],
|
110 |
+
inplace_nonlinearity=inplace_nonlinearity,
|
111 |
+
            stride=first_stride,
            blur=first_blur,
            **residual_params
        )
        self.conv_block_1 = block(
            hidden_channels, hidden_channels,
            kernel_size=kernel_size,
            padding=padding,
            bias=biases[1],
            nonlinearity=nonlinearity,
            order=order[3:],
            inplace_nonlinearity=inplace_nonlinearity,
            stride=second_stride,
            blur=second_blur,
            **residual_params
        )

        self.conv_block_1x1_out = block(
            hidden_channels, out_channels,
            1, 1, 0,
            bias=biases[1],
            nonlinearity=nonlinearity,
            order=order[0:3],
            inplace_nonlinearity=inplace_nonlinearity,
            **residual_params
        )

        # Shortcut branch.
        if self.learn_shortcut:
            if skip_nonlinearity:
                skip_nonlinearity_type = nonlinearity
            else:
                skip_nonlinearity_type = ''
            self.conv_block_s = skip_block(in_channels, out_channels,
                                           bias=biases[2],
                                           nonlinearity=skip_nonlinearity_type,
                                           order=order[0:3],
                                           stride=shortcut_stride,
                                           blur=shortcut_blur,
                                           **shortcut_params)
        elif in_channels < out_channels:
            if skip_nonlinearity:
                skip_nonlinearity_type = nonlinearity
            else:
                skip_nonlinearity_type = ''
            self.conv_block_s = skip_block(in_channels,
                                           out_channels - in_channels,
                                           bias=biases[2],
                                           nonlinearity=skip_nonlinearity_type,
                                           order=order[0:3],
                                           stride=shortcut_stride,
                                           blur=shortcut_blur,
                                           **shortcut_params)

        # Whether this block expects conditional inputs.
        self.conditional = \
            getattr(self.conv_block_0, 'conditional', False) or \
            getattr(self.conv_block_1, 'conditional', False) or \
            getattr(self.conv_block_1x1_in, 'conditional', False) or \
            getattr(self.conv_block_1x1_out, 'conditional', False)

    def _get_stride_blur(self):
        if self.stride > 1:
            # Downsampling.
            first_stride, second_stride = 1, self.stride
            first_blur, second_blur = False, self.blur
            shortcut_blur = False
            shortcut_stride = 1
            if self.blur:
                # The shortcut branch uses blur_downsample + stride-1 conv.
                if self.border_free:
                    self.resample = nn.AvgPool2d(2)
                else:
                    self.resample = BlurDownsample()
            else:
                shortcut_stride = self.stride
                self.resample = nn.AvgPool2d(2)
        elif self.stride < 1:
            # Upsampling.
            first_stride, second_stride = self.stride, 1
            first_blur, second_blur = self.blur, False
            shortcut_blur = False
            shortcut_stride = 1
            if self.blur:
                # The shortcut branch uses blur_upsample + stride-1 conv.
                if self.border_free:
                    self.resample = nn.Upsample(scale_factor=2,
                                                mode='bilinear')
                else:
                    self.resample = BlurUpsample()
            else:
                shortcut_stride = self.stride
                self.resample = nn.Upsample(scale_factor=2)
        else:
            first_stride = second_stride = 1
            first_blur = second_blur = False
            shortcut_stride = 1
            shortcut_blur = False
            self.resample = None
        return (first_stride, second_stride, shortcut_stride,
                first_blur, second_blur, shortcut_blur)

    def conv_blocks(
            self, x, *cond_inputs, separate_cond=False, **kw_cond_inputs
    ):
        if separate_cond:
            assert len(list(cond_inputs)) == 4
            dx = self.conv_block_1x1_in(x, cond_inputs[0],
                                        **kw_cond_inputs.get('kwargs_0', {}))
            dx = self.conv_block_0(dx, cond_inputs[1],
                                   **kw_cond_inputs.get('kwargs_1', {}))
            dx = self.conv_block_1(dx, cond_inputs[2],
                                   **kw_cond_inputs.get('kwargs_2', {}))
            dx = self.conv_block_1x1_out(dx, cond_inputs[3],
                                         **kw_cond_inputs.get('kwargs_3', {}))
        else:
            dx = self.conv_block_1x1_in(x, *cond_inputs, **kw_cond_inputs)
            dx = self.conv_block_0(dx, *cond_inputs, **kw_cond_inputs)
            dx = self.conv_block_1(dx, *cond_inputs, **kw_cond_inputs)
            dx = self.conv_block_1x1_out(dx, *cond_inputs, **kw_cond_inputs)
        return dx

    def forward(self, x, *cond_inputs, do_checkpoint=False, **kw_cond_inputs):
        if do_checkpoint:
            dx = checkpoint(self.conv_blocks, x, *cond_inputs, **kw_cond_inputs)
        else:
            dx = self.conv_blocks(x, *cond_inputs, **kw_cond_inputs)

        if self.resample_first and self.resample is not None:
            x = self.resample(x)
        if self.learn_shortcut:
            x_shortcut = self.conv_block_s(
                x, *cond_inputs, **kw_cond_inputs
            )
        elif self.in_channels < self.out_channels:
            x_shortcut_pad = self.conv_block_s(
                x, *cond_inputs, **kw_cond_inputs
            )
            x_shortcut = torch.cat((x, x_shortcut_pad), dim=1)
        elif self.in_channels > self.out_channels:
            x_shortcut = x[:, :self.out_channels, :, :]
        else:
            x_shortcut = x
        if not self.resample_first and self.resample is not None:
            x_shortcut = self.resample(x_shortcut)

        output = x_shortcut + dx
        return self.output_scale * output

    def extra_repr(self):
        s = 'output_scale={output_scale}'
        return s.format(**self.__dict__)


class DeepRes2dBlock(_BaseDeepResBlock):
    r"""Residual block for 2D input.

    Args:
        in_channels (int) : Number of channels in the input tensor.
        out_channels (int) : Number of channels in the output tensor.
        kernel_size (int, optional, default=3): Kernel size for the
            convolutional filters in the residual link.
        padding (int, optional, default=1): Padding size.
        dilation (int, optional, default=1): Dilation factor.
        groups (int, optional, default=1): Number of convolutional/linear
            groups.
        padding_mode (string, optional, default='zeros'): Type of padding:
            ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
        weight_norm_type (str, optional, default='none'):
            Type of weight normalization.
            ``'none'``, ``'spectral'``, ``'weight'``
            or ``'weight_demod'``.
        weight_norm_params (obj, optional, default=None):
            Parameters of weight normalization.
            If not ``None``, ``weight_norm_params.__dict__`` will be used as
            keyword arguments when initializing weight normalization.
        activation_norm_type (str, optional, default='none'):
            Type of activation normalization.
            ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
            ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
            ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
        activation_norm_params (obj, optional, default=None):
            Parameters of activation normalization.
            If not ``None``, ``activation_norm_params.__dict__`` will be used as
            keyword arguments when initializing activation normalization.
        skip_activation_norm (bool, optional, default=True): If ``True`` and
            ``learn_shortcut`` is also ``True``, applies activation norm to the
            learned shortcut connection.
        skip_nonlinearity (bool, optional, default=True): If ``True`` and
            ``learn_shortcut`` is also ``True``, applies nonlinearity to the
            learned shortcut connection.
        nonlinearity (str, optional, default='none'):
            Type of nonlinear activation function in the residual link.
            ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
            ``'tanh'``, ``'sigmoid'`` or ``'softmax'``.
        inplace_nonlinearity (bool, optional, default=False): If ``True``,
            set ``inplace=True`` when initializing the nonlinearity layers.
        apply_noise (bool, optional, default=False): If ``True``, adds
            Gaussian noise with learnable magnitude to the convolution output.
        hidden_channels_equal_out_channels (bool, optional, default=False):
            If ``True``, set the hidden channel number to be equal to the
            output channel number. If ``False``, the hidden channel number
            equals the smaller of the input channel number and the
            output channel number.
        order (str, optional, default='CNACNA'): Order of operations
            in the residual link.
            ``'C'``: convolution,
            ``'N'``: normalization,
            ``'A'``: nonlinear activation.
        learn_shortcut (bool, optional, default=False): If ``True``, always use
            a convolutional shortcut instead of an identity one, otherwise only
            use a convolutional one if input and output have a different number
            of channels.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3,
                 stride=1, padding=1, dilation=1, groups=1, bias=True,
                 padding_mode='zeros',
                 weight_norm_type='none', weight_norm_params=None,
                 activation_norm_type='none', activation_norm_params=None,
                 skip_activation_norm=True, skip_nonlinearity=False,
                 skip_weight_norm=True,
                 nonlinearity='leakyrelu', inplace_nonlinearity=False,
                 apply_noise=False, hidden_channels_equal_out_channels=False,
                 order='CNACNA', learn_shortcut=False, output_scale=1,
                 blur=True, resample_first=True, border_free=False):
        super().__init__(in_channels, out_channels, kernel_size, stride,
                         padding, dilation, groups, bias, padding_mode,
                         weight_norm_type, weight_norm_params,
                         activation_norm_type, activation_norm_params,
                         skip_activation_norm, skip_nonlinearity, nonlinearity,
                         inplace_nonlinearity, apply_noise,
                         hidden_channels_equal_out_channels, order, Conv2dBlock,
                         learn_shortcut, output_scale, blur=blur,
                         resample_first=resample_first, border_free=border_free,
                         skip_weight_norm=skip_weight_norm)
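A minimal usage sketch for DeepRes2dBlock (editor's illustration, not part of the uploaded files; the channel sizes and constructor settings are assumptions):

# Hedged example: assumes the imaginaire.layers package is importable.
import torch
from imaginaire.layers.residual_deep import DeepRes2dBlock

block = DeepRes2dBlock(64, 128, kernel_size=3, padding=1,
                       weight_norm_type='spectral', nonlinearity='leakyrelu')
x = torch.randn(4, 64, 32, 32)
# in_channels < out_channels and learn_shortcut=False, so the shortcut conv only
# produces the extra channels, which are concatenated with x before the residual add.
y = block(x)  # (4, 128, 32, 32)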
imaginaire/layers/vit.py
ADDED
@@ -0,0 +1,204 @@
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
from types import SimpleNamespace

import torch
from torch import nn

from .misc import ApplyNoise
from imaginaire.third_party.upfirdn2d.upfirdn2d import Blur


class ViT2dBlock(nn.Module):
    r"""An abstract wrapper class that wraps a torch convolution or linear layer
    with normalization and nonlinearity.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 padding, dilation, groups, bias, padding_mode,
                 weight_norm_type, weight_norm_params,
                 activation_norm_type, activation_norm_params,
                 nonlinearity, inplace_nonlinearity,
                 apply_noise, blur, order, input_dim, clamp,
                 blur_kernel=(1, 3, 3, 1), output_scale=None,
                 init_gain=1.0):
        super().__init__()
        from .nonlinearity import get_nonlinearity_layer
        from .weight_norm import get_weight_norm_layer
        from .activation_norm import get_activation_norm_layer
        self.weight_norm_type = weight_norm_type
        self.stride = stride
        self.clamp = clamp
        self.init_gain = init_gain

        # Nonlinearity layer.
        if 'fused' in nonlinearity:
            # Fusing nonlinearity with bias.
            lr_mul = getattr(weight_norm_params, 'lr_mul', 1)
            conv_before_nonlinearity = order.find('C') < order.find('A')
            if conv_before_nonlinearity:
                assert bias
                bias = False
            channel = out_channels if conv_before_nonlinearity else in_channels
            nonlinearity_layer = get_nonlinearity_layer(
                nonlinearity, inplace=inplace_nonlinearity,
                num_channels=channel, lr_mul=lr_mul)
        else:
            nonlinearity_layer = get_nonlinearity_layer(
                nonlinearity, inplace=inplace_nonlinearity)

        # Noise injection layer.
        if apply_noise:
            order = order.replace('C', 'CG')
            noise_layer = ApplyNoise()
        else:
            noise_layer = None

        # Convolutional layer.
        if blur:
            if stride == 2:
                # Blur - Conv - Noise - Activate
                p = (len(blur_kernel) - 2) + (kernel_size - 1)
                pad0, pad1 = (p + 1) // 2, p // 2
                padding = 0
                blur_layer = Blur(
                    blur_kernel, pad=(pad0, pad1), padding_mode=padding_mode
                )
                order = order.replace('C', 'BC')
            elif stride == 0.5:
                # Conv - Blur - Noise - Activate
                padding = 0
                p = (len(blur_kernel) - 2) - (kernel_size - 1)
                pad0, pad1 = (p + 1) // 2 + 1, p // 2 + 1
                blur_layer = Blur(
                    blur_kernel, pad=(pad0, pad1), padding_mode=padding_mode
                )
                order = order.replace('C', 'CB')
            elif stride == 1:
                # No blur for now.
                blur_layer = nn.Identity()
            else:
                raise NotImplementedError
        else:
            blur_layer = nn.Identity()

        if weight_norm_params is None:
            weight_norm_params = SimpleNamespace()
        weight_norm = get_weight_norm_layer(
            weight_norm_type, **vars(weight_norm_params))
        conv_layer = weight_norm(self._get_conv_layer(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, bias, padding_mode, input_dim))

        # Normalization layer.
        conv_before_norm = order.find('C') < order.find('N')
        norm_channels = out_channels if conv_before_norm else in_channels
        if activation_norm_params is None:
            activation_norm_params = SimpleNamespace()
        activation_norm_layer = get_activation_norm_layer(
            norm_channels,
            activation_norm_type,
            input_dim,
            **vars(activation_norm_params))

        # Mapping from operation names to layers.
        mappings = {'C': {'conv': conv_layer},
                    'N': {'norm': activation_norm_layer},
                    'A': {'nonlinearity': nonlinearity_layer}}
        mappings.update({'B': {'blur': blur_layer}})
        mappings.update({'G': {'noise': noise_layer}})

        # All layers in order.
        self.layers = nn.ModuleDict()
        for op in order:
            if list(mappings[op].values())[0] is not None:
                self.layers.update(mappings[op])

        # Whether this block expects conditional inputs.
        self.conditional = \
            getattr(conv_layer, 'conditional', False) or \
            getattr(activation_norm_layer, 'conditional', False)

        if output_scale is not None:
            self.output_scale = nn.Parameter(torch.tensor(output_scale))
        else:
            self.register_parameter("output_scale", None)

    def forward(self, x, *cond_inputs, **kw_cond_inputs):
        r"""

        Args:
            x (tensor): Input tensor.
            cond_inputs (list of tensors) : Conditional input tensors.
            kw_cond_inputs (dict) : Keyword conditional inputs.
        """
        for key, layer in self.layers.items():
            if getattr(layer, 'conditional', False):
                # Layers that require conditional inputs.
                x = layer(x, *cond_inputs, **kw_cond_inputs)
            else:
                x = layer(x)
            if self.clamp is not None and isinstance(layer, nn.Conv2d):
                x.clamp_(max=self.clamp)
            if key == 'conv':
                if self.output_scale is not None:
                    x = x * self.output_scale
        return x

    def _get_conv_layer(self, in_channels, out_channels, kernel_size, stride,
                        padding, dilation, groups, bias, padding_mode,
                        input_dim):
        # Returns the convolutional layer.
        if input_dim == 0:
            layer = nn.Linear(in_channels, out_channels, bias)
        else:
            if stride < 1:  # Fractionally-strided convolution.
                padding_mode = 'zeros'
                assert padding == 0
                layer_type = getattr(nn, f'ConvTranspose{input_dim}d')
                stride = round(1 / stride)
            else:
                layer_type = getattr(nn, f'Conv{input_dim}d')
            layer = layer_type(
                in_channels, out_channels, kernel_size, stride, padding,
                dilation=dilation, groups=groups, bias=bias,
                padding_mode=padding_mode
            )

        return layer

    def __repr__(self):
        main_str = self._get_name() + '('
        child_lines = []
        for name, layer in self.layers.items():
            mod_str = repr(layer)
            if name == 'conv' and self.weight_norm_type != 'none' and \
                    self.weight_norm_type != '':
                mod_str = mod_str[:-1] + \
                    ', weight_norm={}'.format(self.weight_norm_type) + ')'
            if name == 'conv' and getattr(layer, 'base_lr_mul', 1) != 1:
                mod_str = mod_str[:-1] + \
                    ', lr_mul={}'.format(layer.base_lr_mul) + ')'
            mod_str = self._addindent(mod_str, 2)
            child_lines.append(mod_str)
        if len(child_lines) == 1:
            main_str += child_lines[0]
        else:
            main_str += '\n  ' + '\n  '.join(child_lines) + '\n'

        main_str += ')'
        return main_str

    @staticmethod
    def _addindent(s_, numSpaces):
        s = s_.split('\n')
        # Don't do anything for single-line stuff.
        if len(s) == 1:
            return s_
        first = s.pop(0)
        s = [(numSpaces * ' ') + line for line in s]
        s = '\n'.join(s)
        s = first + '\n' + s
        return s
imaginaire/layers/weight_norm.py
ADDED
@@ -0,0 +1,267 @@
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import collections
import functools

import torch
from torch import nn
from torch.nn.utils import spectral_norm, weight_norm
from torch.nn.utils.spectral_norm import SpectralNorm, \
    SpectralNormStateDictHook, SpectralNormLoadStateDictPreHook

from .conv import LinearBlock


class WeightDemodulation(nn.Module):
    r"""Weight demodulation in
    "Analyzing and Improving the Image Quality of StyleGAN", Karras et al.

    Args:
        conv (torch.nn.Modules): Convolutional layer.
        cond_dims (int): The number of channels in the conditional input.
        eps (float, optional, default=1e-8): a value added to the
            denominator for numerical stability.
        adaptive_bias (bool, optional, default=False): If ``True``, adaptively
            predicts bias from the conditional input.
        demod (bool, optional, default=False): If ``True``, performs
            weight demodulation.
    """

    def __init__(self, conv, cond_dims, eps=1e-8,
                 adaptive_bias=False, demod=True):
        super().__init__()
        self.conv = conv
        self.adaptive_bias = adaptive_bias
        if adaptive_bias:
            self.conv.register_parameter('bias', None)
            self.fc_beta = LinearBlock(cond_dims, self.conv.out_channels)
        self.fc_gamma = LinearBlock(cond_dims, self.conv.in_channels)
        self.eps = eps
        self.demod = demod
        self.conditional = True

    def forward(self, x, y, **_kwargs):
        r"""Weight demodulation forward"""
        b, c, h, w = x.size()
        self.conv.groups = b
        gamma = self.fc_gamma(y)
        gamma = gamma[:, None, :, None, None]
        weight = self.conv.weight[None, :, :, :, :] * gamma

        if self.demod:
            d = torch.rsqrt(
                (weight ** 2).sum(
                    dim=(2, 3, 4), keepdim=True) + self.eps)
            weight = weight * d

        x = x.reshape(1, -1, h, w)
        _, _, *ws = weight.shape
        weight = weight.reshape(b * self.conv.out_channels, *ws)
        x = self.conv._conv_forward(x, weight)

        x = x.reshape(-1, self.conv.out_channels, h, w)
        if self.adaptive_bias:
            x += self.fc_beta(y)[:, :, None, None]
        return x


def weight_demod(
        conv, cond_dims=256, eps=1e-8, adaptive_bias=False, demod=True):
    r"""Weight demodulation."""
    return WeightDemodulation(conv, cond_dims, eps, adaptive_bias, demod)


class ScaledLR(object):
    def __init__(self, weight_name, bias_name):
        self.weight_name = weight_name
        self.bias_name = bias_name

    def compute_weight(self, module):
        weight = getattr(module, self.weight_name + '_ori')
        return weight * module.weight_scale

    def compute_bias(self, module):
        bias = getattr(module, self.bias_name + '_ori')
        if bias is not None:
            return bias * module.bias_scale
        else:
            return None

    @staticmethod
    def apply(module, weight_name, bias_name, lr_mul, equalized):
        assert weight_name == 'weight'
        assert bias_name == 'bias'
        fn = ScaledLR(weight_name, bias_name)
        module.register_forward_pre_hook(fn)

        if hasattr(module, bias_name):
            # module.bias is a parameter (can be None).
            bias = getattr(module, bias_name)
            delattr(module, bias_name)
            module.register_parameter(bias_name + '_ori', bias)
        else:
            # module.bias does not exist.
            bias = None
            setattr(module, bias_name + '_ori', bias)
        if bias is not None:
            setattr(module, bias_name, bias.data)
        else:
            setattr(module, bias_name, None)
        module.register_buffer('bias_scale', torch.tensor(lr_mul))

        if hasattr(module, weight_name + '_orig'):
            # The module has been wrapped with spectral normalization.
            # We only want to keep a single weight parameter.
            weight = getattr(module, weight_name + '_orig')
            delattr(module, weight_name + '_orig')
            module.register_parameter(weight_name + '_ori', weight)
            setattr(module, weight_name + '_orig', weight.data)
            # Put this hook before the spectral norm hook.
            module._forward_pre_hooks = collections.OrderedDict(
                reversed(list(module._forward_pre_hooks.items()))
            )
            module.use_sn = True
        else:
            weight = getattr(module, weight_name)
            delattr(module, weight_name)
            module.register_parameter(weight_name + '_ori', weight)
            setattr(module, weight_name, weight.data)
            module.use_sn = False

        # assert weight.dim() == 4 or weight.dim() == 2
        if equalized:
            fan_in = weight.data.size(1) * weight.data[0][0].numel()
            # Theoretically, the gain should be sqrt(2) instead of 1.
            # The official StyleGAN2 uses 1 for some reason.
            module.register_buffer(
                'weight_scale', torch.tensor(lr_mul * ((1 / fan_in) ** 0.5))
            )
        else:
            module.register_buffer('weight_scale', torch.tensor(lr_mul))

        module.lr_mul = module.weight_scale
        module.base_lr_mul = lr_mul

        return fn

    def remove(self, module):
        with torch.no_grad():
            weight = self.compute_weight(module)
        delattr(module, self.weight_name + '_ori')

        if module.use_sn:
            setattr(module, self.weight_name + '_orig', weight.detach())
        else:
            delattr(module, self.weight_name)
            module.register_parameter(self.weight_name,
                                      torch.nn.Parameter(weight.detach()))

        with torch.no_grad():
            bias = self.compute_bias(module)
        delattr(module, self.bias_name)
        delattr(module, self.bias_name + '_ori')
        if bias is not None:
            module.register_parameter(self.bias_name,
                                      torch.nn.Parameter(bias.detach()))
        else:
            module.register_parameter(self.bias_name, None)

        module.lr_mul = 1.0
        module.base_lr_mul = 1.0

    def __call__(self, module, input):
        weight = self.compute_weight(module)
        if module.use_sn:
            # The following spectral norm hook will compute the SN of
            # "module.weight_orig" and store the normalized weight in
            # "module.weight".
            setattr(module, self.weight_name + '_orig', weight)
        else:
            setattr(module, self.weight_name, weight)
        bias = self.compute_bias(module)
        setattr(module, self.bias_name, bias)


def remove_weight_norms(module, weight_name='weight', bias_name='bias'):
    if hasattr(module, 'weight_ori') or hasattr(module, 'weight_orig'):
        for k in list(module._forward_pre_hooks.keys()):
            hook = module._forward_pre_hooks[k]
            if (isinstance(hook, ScaledLR) or isinstance(hook, SpectralNorm)):
                hook.remove(module)
                del module._forward_pre_hooks[k]

        for k, hook in module._state_dict_hooks.items():
            if isinstance(hook, SpectralNormStateDictHook) and \
                    hook.fn.name == weight_name:
                del module._state_dict_hooks[k]
                break

        for k, hook in module._load_state_dict_pre_hooks.items():
            if isinstance(hook, SpectralNormLoadStateDictPreHook) and \
                    hook.fn.name == weight_name:
                del module._load_state_dict_pre_hooks[k]
                break

    return module


def remove_equalized_lr(module, weight_name='weight', bias_name='bias'):
    for k, hook in module._forward_pre_hooks.items():
        if isinstance(hook, ScaledLR) and hook.weight_name == weight_name:
            hook.remove(module)
            del module._forward_pre_hooks[k]
            break
    else:
        raise ValueError("Equalized learning rate not found")

    return module


def scaled_lr(
        module, weight_name='weight', bias_name='bias', lr_mul=1.,
        equalized=False,
):
    ScaledLR.apply(module, weight_name, bias_name, lr_mul, equalized)
    return module


def get_weight_norm_layer(norm_type, **norm_params):
    r"""Return weight normalization.

    Args:
        norm_type (str):
            Type of weight normalization.
            ``'none'``, ``'spectral'``, ``'weight'``
            or ``'weight_demod'``.
        norm_params: Arbitrary keyword arguments that will be used to
            initialize the weight normalization.
    """
    if norm_type == 'none' or norm_type == '':  # no normalization
        return lambda x: x
    elif norm_type == 'spectral':  # spectral normalization
        return functools.partial(spectral_norm, **norm_params)
    elif norm_type == 'weight':  # weight normalization
        return functools.partial(weight_norm, **norm_params)
    elif norm_type == 'weight_demod':  # weight demodulation
        return functools.partial(weight_demod, **norm_params)
    elif norm_type == 'equalized_lr':  # equalized learning rate
        return functools.partial(scaled_lr, equalized=True, **norm_params)
    elif norm_type == 'scaled_lr':  # scaled learning rate
        return functools.partial(scaled_lr, **norm_params)
    elif norm_type == 'equalized_lr_spectral':
        lr_mul = norm_params.pop('lr_mul', 1.0)
        return lambda x: functools.partial(
            scaled_lr, equalized=True, lr_mul=lr_mul)(
            functools.partial(spectral_norm, **norm_params)(x)
        )
    elif norm_type == 'scaled_lr_spectral':
        lr_mul = norm_params.pop('lr_mul', 1.0)
        return lambda x: functools.partial(
            scaled_lr, lr_mul=lr_mul)(
            functools.partial(spectral_norm, **norm_params)(x)
        )
    else:
        raise ValueError(
            'Weight norm layer %s is not recognized' % norm_type)
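A short usage sketch for get_weight_norm_layer (editor's illustration; the wrapped layer and the norm type are assumptions):

import torch.nn as nn
from imaginaire.layers.weight_norm import get_weight_norm_layer

# 'spectral' returns a partial of torch.nn.utils.spectral_norm, so the callable
# wraps the layer in place with a spectral-norm forward pre-hook.
wrap = get_weight_norm_layer('spectral')
conv = wrap(nn.Conv2d(32, 64, 3, padding=1))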
imaginaire/losses/__init__.py
ADDED
@@ -0,0 +1,18 @@
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
from .gan import GANLoss
from .perceptual import PerceptualLoss
from .feature_matching import FeatureMatchingLoss
from .kl import GaussianKLLoss

__all__ = ['GANLoss', 'PerceptualLoss', 'FeatureMatchingLoss', 'GaussianKLLoss',
           'MaskedL1Loss', 'FlowLoss', 'DictLoss',
           'WeightedMSELoss']

try:
    from .gradient_penalty import GradientPenaltyLoss
    __all__.extend(['GradientPenaltyLoss'])
except:  # noqa
    pass
imaginaire/losses/feature_matching.py
ADDED
@@ -0,0 +1,38 @@
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import torch.nn as nn


class FeatureMatchingLoss(nn.Module):
    r"""Compute feature matching loss"""
    def __init__(self, criterion='l1'):
        super(FeatureMatchingLoss, self).__init__()
        if criterion == 'l1':
            self.criterion = nn.L1Loss()
        elif criterion == 'l2' or criterion == 'mse':
            self.criterion = nn.MSELoss()
        else:
            raise ValueError('Criterion %s is not recognized' % criterion)

    def forward(self, fake_features, real_features):
        r"""Return the feature matching loss between the discriminator
        features of the fake and the real images.

        Args:
            fake_features (list of lists): Discriminator features of fake images.
            real_features (list of lists): Discriminator features of real images.

        Returns:
            (tensor): Loss value.
        """
        num_d = len(fake_features)
        dis_weight = 1.0 / num_d
        loss = fake_features[0][0].new_tensor(0)
        for i in range(num_d):
            for j in range(len(fake_features[i])):
                tmp_loss = self.criterion(fake_features[i][j],
                                          real_features[i][j].detach())
                loss += dis_weight * tmp_loss
        return loss
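A short usage sketch for FeatureMatchingLoss (editor's illustration; the feature shapes are assumptions). The inputs mirror what a multi-scale patch discriminator returns: one list per discriminator, each holding intermediate feature tensors.

import torch
from imaginaire.losses.feature_matching import FeatureMatchingLoss

fm_loss = FeatureMatchingLoss(criterion='l1')
fake_features = [[torch.randn(2, 64, 32, 32), torch.randn(2, 128, 16, 16)]]
real_features = [[torch.randn(2, 64, 32, 32), torch.randn(2, 128, 16, 16)]]
loss = fm_loss(fake_features, real_features)  # scalar tensor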
imaginaire/losses/gan.py
ADDED
@@ -0,0 +1,173 @@
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from imaginaire.utils.distributed import master_only_print as print


@torch.jit.script
def fuse_math_min_mean_pos(x):
    r"""Fuse operation min mean for hinge loss computation of positive
    samples"""
    minval = torch.min(x - 1, x * 0)
    loss = -torch.mean(minval)
    return loss


@torch.jit.script
def fuse_math_min_mean_neg(x):
    r"""Fuse operation min mean for hinge loss computation of negative
    samples"""
    minval = torch.min(-x - 1, x * 0)
    loss = -torch.mean(minval)
    return loss


class GANLoss(nn.Module):
    r"""GAN loss constructor.

    Args:
        gan_mode (str): Type of GAN loss. ``'hinge'``, ``'least_square'``,
            ``'non_saturated'``, ``'wasserstein'``.
        target_real_label (float): The desired output label for real images.
        target_fake_label (float): The desired output label for fake images.
        decay_k (float): The decay factor per epoch for top-k training.
        min_k (float): The minimum percentage of samples to select.
        separate_topk (bool): If ``True``, selects top-k for each sample
            separately, otherwise selects top-k among all samples.
    """
    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0,
                 decay_k=1., min_k=1., separate_topk=False):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_tensor = None
        self.fake_label_tensor = None
        self.gan_mode = gan_mode
        self.decay_k = decay_k
        self.min_k = min_k
        self.separate_topk = separate_topk
        self.register_buffer('k', torch.tensor(1.0))
        print('GAN mode: %s' % gan_mode)

    def forward(self, dis_output, t_real, dis_update=True, reduce=True):
        r"""GAN loss computation.

        Args:
            dis_output (tensor or list of tensors): Discriminator outputs.
            t_real (bool): If ``True``, uses the real label as target, otherwise
                uses the fake label as target.
            dis_update (bool): If ``True``, the loss will be used to update the
                discriminator, otherwise the generator.
            reduce (bool): If ``True``, when a list of discriminator outputs are
                provided, it will return the average of all losses, otherwise it
                will return a list of losses.
        Returns:
            loss (tensor): Loss value.
        """
        if isinstance(dis_output, list):
            # For multi-scale discriminators.
            # In this implementation, the loss is first averaged for each scale
            # (batch size and number of locations) then averaged across scales,
            # so that the gradient is not dominated by the discriminator that
            # has the most output values (highest resolution).
            losses = []
            for dis_output_i in dis_output:
                assert isinstance(dis_output_i, torch.Tensor)
                losses.append(self.loss(dis_output_i, t_real, dis_update))
            if reduce:
                return torch.mean(torch.stack(losses))
            else:
                return losses
        else:
            return self.loss(dis_output, t_real, dis_update)

    def loss(self, dis_output, t_real, dis_update=True):
        r"""GAN loss computation.

        Args:
            dis_output (tensor): Discriminator outputs.
            t_real (bool): If ``True``, uses the real label as target, otherwise
                uses the fake label as target.
            dis_update (bool): Updating the discriminator or the generator.
        Returns:
            loss (tensor): Loss value.
        """
        if not dis_update:
            assert t_real, \
                "The target should be real when updating the generator."

        if not dis_update and self.k < 1:
            r"""
            Use top-k training:
            "Top-k Training of GANs: Improving GAN Performance by Throwing
            Away Bad Samples"
            Here, each sample may have multiple discriminator output values
            (patch discriminator). We could either select top-k for each sample
            separately (when ``self.separate_topk=True``), or collect values
            from all samples and then select top-k (default, when
            ``self.separate_topk=False``).
            """
            if self.separate_topk:
                dis_output = dis_output.view(dis_output.size(0), -1)
            else:
                dis_output = dis_output.view(-1)
            k = math.ceil(self.k * dis_output.size(-1))
            dis_output, _ = torch.topk(dis_output, k)

        if self.gan_mode == 'non_saturated':
            target_tensor = self.get_target_tensor(dis_output, t_real)
            loss = F.binary_cross_entropy_with_logits(dis_output,
                                                      target_tensor)
        elif self.gan_mode == 'least_square':
            target_tensor = self.get_target_tensor(dis_output, t_real)
            loss = 0.5 * F.mse_loss(dis_output, target_tensor)
        elif self.gan_mode == 'hinge':
            if dis_update:
                if t_real:
                    loss = fuse_math_min_mean_pos(dis_output)
                else:
                    loss = fuse_math_min_mean_neg(dis_output)
            else:
                loss = -torch.mean(dis_output)
        elif self.gan_mode == 'wasserstein':
            if t_real:
                loss = -torch.mean(dis_output)
            else:
                loss = torch.mean(dis_output)
        elif self.gan_mode == 'softplus':
            target_tensor = self.get_target_tensor(dis_output, t_real)
            loss = F.binary_cross_entropy_with_logits(dis_output,
                                                      target_tensor)
        else:
            raise ValueError('Unexpected gan_mode {}'.format(self.gan_mode))
        return loss

    def get_target_tensor(self, dis_output, t_real):
        r"""Return the target vector for the binary cross entropy loss
        computation.

        Args:
            dis_output (tensor): Discriminator outputs.
            t_real (bool): If ``True``, uses the real label as target, otherwise
                uses the fake label as target.
        Returns:
            target (tensor): Target tensor vector.
        """
        if t_real:
            if self.real_label_tensor is None:
                self.real_label_tensor = dis_output.new_tensor(self.real_label)
            return self.real_label_tensor.expand_as(dis_output)
        else:
            if self.fake_label_tensor is None:
                self.fake_label_tensor = dis_output.new_tensor(self.fake_label)
            return self.fake_label_tensor.expand_as(dis_output)

    def topk_anneal(self):
        r"""Anneal k after each epoch."""
        if self.decay_k < 1:
            # noinspection PyAttributeOutsideInit
            self.k.fill_(max(self.decay_k * self.k, self.min_k))
            print("Top-k training: update k to {}.".format(self.k))
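A short usage sketch for GANLoss in hinge mode (editor's illustration; the logit shapes are assumptions):

import torch
from imaginaire.losses.gan import GANLoss

criterion = GANLoss(gan_mode='hinge')
real_logits = torch.randn(4, 1, 30, 30)   # patch-discriminator outputs on real images
fake_logits = torch.randn(4, 1, 30, 30)   # patch-discriminator outputs on fake images
dis_loss = criterion(real_logits, t_real=True, dis_update=True) + \
    criterion(fake_logits, t_real=False, dis_update=True)
gen_loss = criterion(fake_logits, t_real=True, dis_update=False)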
imaginaire/losses/info_nce.py
ADDED
@@ -0,0 +1,87 @@
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

from imaginaire.utils.distributed import get_world_size, get_rank, \
    dist_all_reduce_tensor


class GatherLayer(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        output = [torch.zeros_like(input) for _ in range(dist.get_world_size())]
        dist.all_gather(output, input)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        input, = ctx.saved_tensors
        grad_out = torch.zeros_like(input)
        all_grads = torch.stack(grads)
        all_grads = dist_all_reduce_tensor(all_grads, reduce='sum')
        grad_out[:] = all_grads[get_rank()]
        return grad_out


class InfoNCELoss(nn.Module):
    def __init__(self,
                 temperature=0.07,
                 gather_distributed=True,
                 learn_temperature=True,
                 single_direction=False,
                 flatten=True):
        super(InfoNCELoss, self).__init__()
        self.logit_scale = nn.Parameter(torch.tensor([math.log(1 / temperature)]))
        self.logit_scale.requires_grad = learn_temperature
        self.gather_distributed = gather_distributed
        self.single_direction = single_direction
        self.flatten = flatten

    def forward(self, features_a, features_b, gather_distributed=None, eps=1e-8):
        if gather_distributed is None:
            gather_distributed = self.gather_distributed

        if features_a is None or features_b is None:
            return torch.tensor(0, device='cuda'), torch.tensor(0, device='cuda')

        bs_a, bs_b = features_a.size(0), features_b.size(0)
        if self.flatten:
            features_a, features_b = features_a.reshape(bs_a, -1), features_b.reshape(bs_b, -1)
        else:
            features_a = features_a.reshape(bs_a, features_a.size(1), -1).mean(-1)
            features_b = features_b.reshape(bs_b, features_b.size(1), -1).mean(-1)

        # Temperature clipping.
        self.logit_scale.data = torch.clamp(self.logit_scale.data, 0, 4.6052)

        # Normalized features.
        features_a = features_a / (features_a.norm(dim=1, keepdim=True) + eps)
        features_b = features_b / (features_b.norm(dim=1, keepdim=True) + eps)

        loss_a = self._forward_single_direction(features_a, features_b, gather_distributed)
        if self.single_direction:
            return loss_a
        else:
            loss_b = self._forward_single_direction(features_b, features_a, gather_distributed)
            return loss_a + loss_b

    def _forward_single_direction(
            self, features_a, features_b, gather_distributed):
        bs_a = features_a.shape[0]
        logit_scale = self.logit_scale.exp()
        if get_world_size() > 1 and gather_distributed:
            gather_features_b = torch.cat(GatherLayer.apply(features_b))
            gather_labels_a = torch.arange(bs_a, device='cuda') + get_rank() * bs_a
            logits_a = logit_scale * features_a @ gather_features_b.t()
        else:
            gather_labels_a = torch.arange(bs_a, device='cuda')
            logits_a = logit_scale * features_a @ features_b.t()
        loss_a = F.cross_entropy(logits_a, gather_labels_a)
        return loss_a
imaginaire/losses/kl.py
ADDED
@@ -0,0 +1,23 @@
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import torch
import torch.nn as nn


class GaussianKLLoss(nn.Module):
    r"""Compute KL loss in VAE for Gaussian distributions"""
    def __init__(self):
        super(GaussianKLLoss, self).__init__()

    def forward(self, mu, logvar=None):
        r"""Compute loss

        Args:
            mu (tensor): mean
            logvar (tensor): logarithm of variance
        """
        if logvar is None:
            logvar = torch.zeros_like(mu)
        return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
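A short sanity-check sketch for GaussianKLLoss (editor's illustration). With mu = 0 and logvar = 0 the posterior equals the standard normal prior, so the summed term 1 + logvar - mu^2 - exp(logvar) vanishes and the loss is zero:

import torch
from imaginaire.losses.kl import GaussianKLLoss

kl = GaussianKLLoss()
mu = torch.zeros(4, 16)
logvar = torch.zeros(4, 16)
print(kl(mu, logvar))  # zero KL for a standard-normal posterior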
imaginaire/losses/perceptual.py
ADDED
@@ -0,0 +1,395 @@
1 |
+
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# This work is made available under the Nvidia Source Code License-NC.
|
4 |
+
# To view a copy of this license, check out LICENSE.md
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import torchvision
|
8 |
+
from torch import nn, distributed as dist
|
9 |
+
|
10 |
+
from imaginaire.losses.info_nce import InfoNCELoss
|
11 |
+
from imaginaire.utils.distributed import master_only_print as print, \
|
12 |
+
is_local_master
|
13 |
+
from imaginaire.utils.misc import apply_imagenet_normalization, to_float
|
14 |
+
|
15 |
+
|
16 |
+
class PerceptualLoss(nn.Module):
|
17 |
+
r"""Perceptual loss initialization.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
network (str) : The name of the loss network: 'vgg16' | 'vgg19'.
|
21 |
+
layers (str or list of str) : The layers used to compute the loss.
|
22 |
+
weights (float or list of float : The loss weights of each layer.
|
23 |
+
criterion (str): The type of distance function: 'l1' | 'l2'.
|
24 |
+
resize (bool) : If ``True``, resize the input images to 224x224.
|
25 |
+
resize_mode (str): Algorithm used for resizing.
|
26 |
+
num_scales (int): The loss will be evaluated at original size and
|
27 |
+
this many times downsampled sizes.
|
28 |
+
per_sample_weight (bool): Output loss for individual samples in the
|
29 |
+
batch instead of mean loss.
|
30 |
+
"""
|
31 |
+
|
32 |
+
def __init__(self, network='vgg19', layers='relu_4_1', weights=None,
|
33 |
+
criterion='l1', resize=False, resize_mode='bilinear',
|
34 |
+
num_scales=1, per_sample_weight=False,
|
35 |
+
info_nce_temperature=0.07,
|
36 |
+
info_nce_gather_distributed=True,
|
37 |
+
info_nce_learn_temperature=True,
|
38 |
+
info_nce_flatten=True):
|
39 |
+
super().__init__()
|
40 |
+
if isinstance(layers, str):
|
41 |
+
layers = [layers]
|
42 |
+
if weights is None:
|
43 |
+
weights = [1.] * len(layers)
|
44 |
+
elif isinstance(layers, float) or isinstance(layers, int):
|
45 |
+
weights = [weights]
|
46 |
+
|
47 |
+
if dist.is_initialized() and not is_local_master():
|
48 |
+
# Make sure only the first process in distributed training downloads
|
49 |
+
# the model, and the others will use the cache
|
50 |
+
# noinspection PyUnresolvedReferences
|
51 |
+
torch.distributed.barrier()
|
52 |
+
|
53 |
+
assert len(layers) == len(weights), \
|
54 |
+
'The number of layers (%s) must be equal to ' \
|
55 |
+
'the number of weights (%s).' % (len(layers), len(weights))
|
56 |
+
if network == 'vgg19':
|
57 |
+
self.model = _vgg19(layers)
|
58 |
+
elif network == 'vgg16':
|
59 |
+
self.model = _vgg16(layers)
|
60 |
+
elif network == 'alexnet':
|
61 |
+
self.model = _alexnet(layers)
|
62 |
+
elif network == 'inception_v3':
|
63 |
+
self.model = _inception_v3(layers)
|
64 |
+
elif network == 'resnet50':
|
65 |
+
self.model = _resnet50(layers)
|
66 |
+
elif network == 'robust_resnet50':
|
67 |
+
self.model = _robust_resnet50(layers)
|
68 |
+
elif network == 'vgg_face_dag':
|
69 |
+
self.model = _vgg_face_dag(layers)
|
70 |
+
else:
|
71 |
+
raise ValueError('Network %s is not recognized' % network)
|
72 |
+
|
73 |
+
if dist.is_initialized() and is_local_master():
|
74 |
+
# Make sure only the first process in distributed training downloads
|
75 |
+
# the model, and the others will use the cache
|
76 |
+
# noinspection PyUnresolvedReferences
|
77 |
+
torch.distributed.barrier()
|
78 |
+
|
79 |
+
self.num_scales = num_scales
|
80 |
+
self.layers = layers
|
81 |
+
self.weights = weights
|
82 |
+
reduction = 'mean' if not per_sample_weight else 'none'
|
83 |
+
if criterion == 'l1':
|
84 |
+
self.criterion = nn.L1Loss(reduction=reduction)
|
85 |
+
elif criterion == 'l2' or criterion == 'mse':
|
86 |
+
self.criterion = nn.MSELoss(reduction=reduction)
|
87 |
+
elif criterion == 'info_nce':
|
88 |
+
self.criterion = InfoNCELoss(
|
89 |
+
temperature=info_nce_temperature,
|
90 |
+
gather_distributed=info_nce_gather_distributed,
|
91 |
+
learn_temperature=info_nce_learn_temperature,
|
92 |
+
flatten=info_nce_flatten,
|
93 |
+
single_direction=True
|
94 |
+
)
|
95 |
+
else:
|
96 |
+
raise ValueError('Criterion %s is not recognized' % criterion)
|
97 |
+
self.resize = resize
|
98 |
+
self.resize_mode = resize_mode
|
99 |
+
print('Perceptual loss:')
|
100 |
+
print('\tMode: {}'.format(network))
|
101 |
+
|
102 |
+
def forward(self, inp, target, per_sample_weights=None):
|
103 |
+
r"""Perceptual loss forward.
|
104 |
+
|
105 |
+
Args:
|
106 |
+
inp (4D tensor) : Input tensor.
|
107 |
+
target (4D tensor) : Ground truth tensor, same shape as the input.
|
108 |
+
per_sample_weight (bool): Output loss for individual samples in the
|
109 |
+
batch instead of mean loss.
|
110 |
+
Returns:
|
111 |
+
(scalar tensor) : The perceptual loss.
|
112 |
+
"""
|
113 |
+
if not torch.is_autocast_enabled():
|
114 |
+
inp, target = to_float([inp, target])
|
115 |
+
|
116 |
+
# Perceptual loss should operate in eval mode by default.
|
117 |
+
self.model.eval()
|
118 |
+
inp, target = apply_imagenet_normalization(inp), apply_imagenet_normalization(target)
|
119 |
+
if self.resize:
|
120 |
+
inp = F.interpolate(inp, mode=self.resize_mode, size=(224, 224), align_corners=False)
|
121 |
+
target = F.interpolate(target, mode=self.resize_mode, size=(224, 224), align_corners=False)
|
122 |
+
|
123 |
+
# Evaluate perceptual loss at each scale.
|
124 |
+
loss = 0
|
125 |
+
for scale in range(self.num_scales):
|
126 |
+
input_features, target_features = self.model(inp), self.model(target)
|
127 |
+
|
128 |
+
for layer, weight in zip(self.layers, self.weights):
|
129 |
+
# Example per-layer VGG19 loss values after applying
|
130 |
+
# [0.03125, 0.0625, 0.125, 0.25, 1.0] weighting.
|
131 |
+
# relu_1_1, 0.014698
|
132 |
+
# relu_2_1, 0.085817
|
133 |
+
# relu_3_1, 0.349977
|
134 |
+
# relu_4_1, 0.544188
|
135 |
+
# relu_5_1, 0.906261
|
136 |
+
# print('%s, %f' % (
|
137 |
+
# layer,
|
138 |
+
# weight * self.criterion(
|
139 |
+
# input_features[layer],
|
140 |
+
# target_features[
|
141 |
+
# layer].detach()).item()))
|
142 |
+
l_tmp = self.criterion(input_features[layer], target_features[layer].detach())
|
143 |
+
if per_sample_weights is not None:
|
144 |
+
l_tmp = l_tmp.mean(1).mean(1).mean(1)
|
145 |
+
loss += weight * l_tmp
|
146 |
+
# Downsample the input and target.
|
147 |
+
if scale != self.num_scales - 1:
|
148 |
+
inp = F.interpolate(
|
149 |
+
inp, mode=self.resize_mode, scale_factor=0.5,
|
150 |
+
align_corners=False, recompute_scale_factor=True)
|
151 |
+
target = F.interpolate(
|
152 |
+
target, mode=self.resize_mode, scale_factor=0.5,
|
153 |
+
align_corners=False, recompute_scale_factor=True)
|
154 |
+
|
155 |
+
return loss.float()
|
156 |
+
|
157 |
+
|
158 |
+
class _PerceptualNetwork(nn.Module):
|
159 |
+
r"""The network that extracts features to compute the perceptual loss.
|
160 |
+
|
161 |
+
Args:
|
162 |
+
network (nn.Sequential) : The network that extracts features.
|
163 |
+
layer_name_mapping (dict) : The dictionary that
|
164 |
+
maps a layer's index to its name.
|
165 |
+
layers (list of str): The list of layer names that we are using.
|
166 |
+
"""
|
167 |
+
|
168 |
+
def __init__(self, network, layer_name_mapping, layers):
|
169 |
+
super().__init__()
|
170 |
+
assert isinstance(network, nn.Sequential), \
|
171 |
+
'The network needs to be of type "nn.Sequential".'
|
172 |
+
self.network = network
|
173 |
+
self.layer_name_mapping = layer_name_mapping
|
174 |
+
self.layers = layers
|
175 |
+
for param in self.parameters():
|
176 |
+
param.requires_grad = False
|
177 |
+
|
178 |
+
def forward(self, x):
|
179 |
+
r"""Extract perceptual features."""
|
180 |
+
output = {}
|
181 |
+
for i, layer in enumerate(self.network):
|
182 |
+
x = layer(x)
|
183 |
+
layer_name = self.layer_name_mapping.get(i, None)
|
184 |
+
if layer_name in self.layers:
|
185 |
+
# If the current layer is used by the perceptual loss.
|
186 |
+
output[layer_name] = x
|
187 |
+
return output
|
188 |
+
|
189 |
+
|
190 |
+
def _vgg19(layers):
|
191 |
+
r"""Get vgg19 layers"""
|
192 |
+
vgg = torchvision.models.vgg19(pretrained=True)
|
193 |
+
# network = vgg.features
|
194 |
+
network = torch.nn.Sequential(*(list(vgg.features) + [vgg.avgpool] + [nn.Flatten()] + list(vgg.classifier)))
|
195 |
+
layer_name_mapping = {1: 'relu_1_1',
|
196 |
+
3: 'relu_1_2',
|
197 |
+
6: 'relu_2_1',
|
198 |
+
8: 'relu_2_2',
|
199 |
+
11: 'relu_3_1',
|
200 |
+
13: 'relu_3_2',
|
201 |
+
15: 'relu_3_3',
|
202 |
+
17: 'relu_3_4',
|
203 |
+
20: 'relu_4_1',
|
204 |
+
22: 'relu_4_2',
|
205 |
+
24: 'relu_4_3',
|
206 |
+
26: 'relu_4_4',
|
207 |
+
29: 'relu_5_1',
|
208 |
+
31: 'relu_5_2',
|
209 |
+
33: 'relu_5_3',
|
210 |
+
35: 'relu_5_4',
|
211 |
+
36: 'pool_5',
|
212 |
+
42: 'fc_2'}
|
213 |
+
return _PerceptualNetwork(network, layer_name_mapping, layers)
|
214 |
+
|
215 |
+
|
216 |
+
def _vgg16(layers):
|
217 |
+
r"""Get vgg16 layers"""
|
218 |
+
network = torchvision.models.vgg16(pretrained=True).features
|
219 |
+
layer_name_mapping = {1: 'relu_1_1',
|
220 |
+
3: 'relu_1_2',
|
221 |
+
6: 'relu_2_1',
|
222 |
+
8: 'relu_2_2',
|
223 |
+
11: 'relu_3_1',
|
224 |
+
13: 'relu_3_2',
|
225 |
+
15: 'relu_3_3',
|
226 |
+
18: 'relu_4_1',
|
227 |
+
20: 'relu_4_2',
|
228 |
+
22: 'relu_4_3',
|
229 |
+
25: 'relu_5_1'}
|
230 |
+
return _PerceptualNetwork(network, layer_name_mapping, layers)
|
231 |
+
|
232 |
+
|
233 |
+
def _alexnet(layers):
|
234 |
+
r"""Get alexnet layers"""
|
235 |
+
network = torchvision.models.alexnet(pretrained=True).features
|
236 |
+
layer_name_mapping = {0: 'conv_1',
|
237 |
+
1: 'relu_1',
|
238 |
+
3: 'conv_2',
|
239 |
+
4: 'relu_2',
|
240 |
+
6: 'conv_3',
|
241 |
+
7: 'relu_3',
|
242 |
+
8: 'conv_4',
|
243 |
+
9: 'relu_4',
|
244 |
+
10: 'conv_5',
|
245 |
+
11: 'relu_5'}
|
246 |
+
return _PerceptualNetwork(network, layer_name_mapping, layers)
|
247 |
+
|
248 |
+
|
249 |
+
def _inception_v3(layers):
|
250 |
+
r"""Get inception v3 layers"""
|
251 |
+
inception = torchvision.models.inception_v3(pretrained=True)
|
252 |
+
network = nn.Sequential(inception.Conv2d_1a_3x3,
|
253 |
+
inception.Conv2d_2a_3x3,
|
254 |
+
inception.Conv2d_2b_3x3,
|
255 |
+
nn.MaxPool2d(kernel_size=3, stride=2),
|
256 |
+
inception.Conv2d_3b_1x1,
|
257 |
+
inception.Conv2d_4a_3x3,
|
258 |
+
nn.MaxPool2d(kernel_size=3, stride=2),
|
259 |
+
inception.Mixed_5b,
|
260 |
+
inception.Mixed_5c,
|
261 |
+
inception.Mixed_5d,
|
262 |
+
inception.Mixed_6a,
|
263 |
+
inception.Mixed_6b,
|
264 |
+
inception.Mixed_6c,
|
265 |
+
inception.Mixed_6d,
|
266 |
+
inception.Mixed_6e,
|
267 |
+
inception.Mixed_7a,
|
268 |
+
inception.Mixed_7b,
|
269 |
+
inception.Mixed_7c,
|
270 |
+
nn.AdaptiveAvgPool2d(output_size=(1, 1)))
|
271 |
+
layer_name_mapping = {3: 'pool_1',
|
272 |
+
6: 'pool_2',
|
273 |
+
14: 'mixed_6e',
|
274 |
+
18: 'pool_3'}
|
275 |
+
return _PerceptualNetwork(network, layer_name_mapping, layers)
|
276 |
+
|
277 |
+
|
278 |
+
def _resnet50(layers):
|
279 |
+
r"""Get resnet50 layers"""
|
280 |
+
resnet50 = torchvision.models.resnet50(pretrained=True)
|
281 |
+
network = nn.Sequential(resnet50.conv1,
|
282 |
+
resnet50.bn1,
|
283 |
+
resnet50.relu,
|
284 |
+
resnet50.maxpool,
|
285 |
+
resnet50.layer1,
|
286 |
+
resnet50.layer2,
|
287 |
+
resnet50.layer3,
|
288 |
+
resnet50.layer4,
|
289 |
+
resnet50.avgpool)
|
290 |
+
layer_name_mapping = {4: 'layer_1',
|
291 |
+
5: 'layer_2',
|
292 |
+
6: 'layer_3',
|
293 |
+
7: 'layer_4'}
|
294 |
+
return _PerceptualNetwork(network, layer_name_mapping, layers)
|
295 |
+
|
296 |
+
|
297 |
+
def _robust_resnet50(layers):
|
298 |
+
r"""Get robust resnet50 layers"""
|
299 |
+
resnet50 = torchvision.models.resnet50(pretrained=False)
|
300 |
+
state_dict = torch.utils.model_zoo.load_url(
|
301 |
+
'http://andrewilyas.com/ImageNet.pt')
|
302 |
+
new_state_dict = {}
|
303 |
+
for k, v in state_dict['model'].items():
|
304 |
+
if k.startswith('module.model.'):
|
305 |
+
new_state_dict[k[13:]] = v
|
306 |
+
resnet50.load_state_dict(new_state_dict)
|
307 |
+
network = nn.Sequential(resnet50.conv1,
|
308 |
+
resnet50.bn1,
|
309 |
+
resnet50.relu,
|
310 |
+
resnet50.maxpool,
|
311 |
+
resnet50.layer1,
|
312 |
+
resnet50.layer2,
|
313 |
+
resnet50.layer3,
|
314 |
+
resnet50.layer4,
|
315 |
+
resnet50.avgpool)
|
316 |
+
layer_name_mapping = {4: 'layer_1',
|
317 |
+
5: 'layer_2',
|
318 |
+
6: 'layer_3',
|
319 |
+
7: 'layer_4'}
|
320 |
+
return _PerceptualNetwork(network, layer_name_mapping, layers)
|
321 |
+
|
322 |
+
|
323 |
+
def _vgg_face_dag(layers):
|
324 |
+
network = torchvision.models.vgg16(num_classes=2622)
|
325 |
+
state_dict = torch.utils.model_zoo.load_url(
|
326 |
+
'http://www.robots.ox.ac.uk/~albanie/models/pytorch-mcn/'
|
327 |
+
'vgg_face_dag.pth')
|
328 |
+
feature_layer_name_mapping = {
|
329 |
+
0: 'conv1_1',
|
330 |
+
2: 'conv1_2',
|
331 |
+
5: 'conv2_1',
|
332 |
+
7: 'conv2_2',
|
333 |
+
10: 'conv3_1',
|
334 |
+
12: 'conv3_2',
|
335 |
+
14: 'conv3_3',
|
336 |
+
17: 'conv4_1',
|
337 |
+
19: 'conv4_2',
|
338 |
+
21: 'conv4_3',
|
339 |
+
24: 'conv5_1',
|
340 |
+
26: 'conv5_2',
|
341 |
+
28: 'conv5_3'}
|
342 |
+
new_state_dict = {}
|
343 |
+
for k, v in feature_layer_name_mapping.items():
|
344 |
+
new_state_dict['features.' + str(k) + '.weight'] = \
|
345 |
+
state_dict[v + '.weight']
|
346 |
+
new_state_dict['features.' + str(k) + '.bias'] = \
|
347 |
+
state_dict[v + '.bias']
|
348 |
+
|
349 |
+
classifier_layer_name_mapping = {
|
350 |
+
0: 'fc6',
|
351 |
+
3: 'fc7',
|
352 |
+
6: 'fc8'}
|
353 |
+
for k, v in classifier_layer_name_mapping.items():
|
354 |
+
new_state_dict['classifier.' + str(k) + '.weight'] = \
|
355 |
+
state_dict[v + '.weight']
|
356 |
+
new_state_dict['classifier.' + str(k) + '.bias'] = \
|
357 |
+
state_dict[v + '.bias']
|
358 |
+
|
359 |
+
network.load_state_dict(new_state_dict)
|
360 |
+
|
361 |
+
class Flatten(nn.Module):
|
362 |
+
def forward(self, x):
|
363 |
+
return x.view(x.shape[0], -1)
|
364 |
+
|
365 |
+
layer_name_mapping = {
|
366 |
+
0: 'conv_1_1',
|
367 |
+
1: 'relu_1_1',
|
368 |
+
2: 'conv_1_2',
|
369 |
+
5: 'conv_2_1', # 1/2
|
370 |
+
6: 'relu_2_1',
|
371 |
+
7: 'conv_2_2',
|
372 |
+
10: 'conv_3_1', # 1/4
|
373 |
+
11: 'relu_3_1',
|
374 |
+
12: 'conv_3_2',
|
375 |
+
14: 'conv_3_3',
|
376 |
+
17: 'conv_4_1', # 1/8
|
377 |
+
18: 'relu_4_1',
|
378 |
+
19: 'conv_4_2',
|
379 |
+
21: 'conv_4_3',
|
380 |
+
24: 'conv_5_1', # 1/16
|
381 |
+
25: 'relu_5_1',
|
382 |
+
26: 'conv_5_2',
|
383 |
+
28: 'conv_5_3',
|
384 |
+
33: 'fc6',
|
385 |
+
36: 'fc7',
|
386 |
+
39: 'fc8'
|
387 |
+
}
|
388 |
+
seq_layers = []
|
389 |
+
for feature in network.features:
|
390 |
+
seq_layers += [feature]
|
391 |
+
seq_layers += [network.avgpool, Flatten()]
|
392 |
+
for classifier in network.classifier:
|
393 |
+
seq_layers += [classifier]
|
394 |
+
network = nn.Sequential(*seq_layers)
|
395 |
+
return _PerceptualNetwork(network, layer_name_mapping, layers)
|
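Usage sketch (illustrative, not part of the uploaded files): the backbone factories above differ only in the wrapped network and its layer_name_mapping, and all return a _PerceptualNetwork. The snippet assumes _PerceptualNetwork, defined earlier in perceptual.py, runs the wrapped module and returns a dict of activations keyed by the requested layer names; the input tensor and layer choices are hypothetical.

import torch

net = _vgg16(layers=['relu_1_1', 'relu_2_1'])   # pretrained VGG16 feature extractor
net.eval()
with torch.no_grad():
    feats = net(torch.randn(1, 3, 224, 224))    # hypothetical input batch
# feats['relu_1_1'] and feats['relu_2_1'] (assumed dict output) would then feed
# an L1/MSE perceptual term between generated and target images.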
imaginaire/losses/weighted_mse.py
ADDED
@@ -0,0 +1,28 @@
1 |
+
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# This work is made available under the Nvidia Source Code License-NC.
|
4 |
+
# To view a copy of this license, check out LICENSE.md
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
|
9 |
+
class WeightedMSELoss(nn.Module):
|
10 |
+
r"""Compute Weighted MSE loss"""
|
11 |
+
def __init__(self, reduction='mean'):
|
12 |
+
super(WeightedMSELoss, self).__init__()
|
13 |
+
self.reduction = reduction
|
14 |
+
|
15 |
+
def forward(self, input, target, weight):
|
16 |
+
r"""Return weighted MSE Loss.
|
17 |
+
Args:
|
18 |
+
input (tensor): Input tensor.
|
19 |
+
target (tensor): Target tensor.
|
20 |
+
weight (tensor): Per-element weight tensor (broadcastable to the input shape).
|
21 |
+
Returns:
|
22 |
+
(tensor): Loss value.
|
23 |
+
"""
|
24 |
+
if self.reduction == 'mean':
|
25 |
+
loss = torch.mean(weight * (input - target) ** 2)
|
26 |
+
else:
|
27 |
+
loss = torch.sum(weight * (input - target) ** 2)
|
28 |
+
return loss
|
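Usage sketch (illustrative, not part of the uploaded files): WeightedMSELoss simply scales the squared error by a weight map before reducing, so a binary mask can be used to ignore invalid pixels. Shapes below are hypothetical; the weight broadcasts against the inputs.

import torch

criterion = WeightedMSELoss(reduction='mean')
pred   = torch.randn(2, 3, 64, 64)
target = torch.randn(2, 3, 64, 64)
mask   = torch.ones(2, 1, 64, 64)       # e.g. set 0 where a pixel should be ignored
loss = criterion(pred, target, mask)    # mean of mask * (pred - target) ** 2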
imaginaire/model_utils/__init__.py
ADDED
@@ -0,0 +1,4 @@
1 |
+
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# This work is made available under the Nvidia Source Code License-NC.
|
4 |
+
# To view a copy of this license, check out LICENSE.md
|
imaginaire/model_utils/gancraft/camctl.py
ADDED
@@ -0,0 +1,679 @@
1 |
+
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# This work is made available under the Nvidia Source Code License-NC.
|
4 |
+
# To view a copy of this license, check out LICENSE.md
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
|
8 |
+
|
9 |
+
class EvalCameraController:
|
10 |
+
def __init__(self, voxel, maxstep=128, pattern=0, cam_ang=73, smooth_decay_multiplier=1.0):
|
11 |
+
self.voxel = voxel
|
12 |
+
self.maxstep = maxstep
|
13 |
+
self.camera_poses = [] # ori, dir, up, f
|
14 |
+
circle = torch.linspace(0, 2*np.pi, steps=maxstep)
|
15 |
+
size = min(voxel.voxel_t.size(1), voxel.voxel_t.size(2)) / 2
|
16 |
+
# Shrink the circle a bit.
|
17 |
+
shift = size * 0.2
|
18 |
+
size = size * 0.8
|
19 |
+
|
20 |
+
if pattern == 0:
|
21 |
+
height_history = []
|
22 |
+
# Calculate smooth height.
|
23 |
+
for i in range(maxstep):
|
24 |
+
farpoint = torch.tensor([
|
25 |
+
70,
|
26 |
+
torch.sin(circle[i])*size + voxel.voxel_t.size(1)/2 + shift,
|
27 |
+
torch.cos(circle[i])*size + voxel.voxel_t.size(2)/2 + shift])
|
28 |
+
height_history.append(self._get_height(farpoint[1], farpoint[2], farpoint[0]))
|
29 |
+
|
30 |
+
# Filtfilt
|
31 |
+
height_history = self.filtfilt(height_history, decay=0.2*smooth_decay_multiplier)
|
32 |
+
|
33 |
+
for i in range(maxstep):
|
34 |
+
farpoint = torch.tensor([
|
35 |
+
70,
|
36 |
+
torch.sin(circle[i])*size + voxel.voxel_t.size(1)/2 + shift,
|
37 |
+
torch.cos(circle[i])*size + voxel.voxel_t.size(2)/2 + shift])
|
38 |
+
|
39 |
+
farpoint[0] = height_history[i]
|
40 |
+
|
41 |
+
nearpoint = torch.tensor([
|
42 |
+
60,
|
43 |
+
torch.sin(circle[i]+0.5*np.pi)*size*0.5 + voxel.voxel_t.size(1)/2 + shift,
|
44 |
+
torch.cos(circle[i]+0.5*np.pi)*size*0.5 + voxel.voxel_t.size(2)/2 + shift])
|
45 |
+
cam_ori = self.voxel.world2local(farpoint)
|
46 |
+
cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True)
|
47 |
+
cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
|
48 |
+
cam_f = 0.5/np.tan(np.deg2rad(cam_ang/2)) # about 24mm fov
|
49 |
+
|
50 |
+
self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
|
51 |
+
|
52 |
+
elif pattern == 1:
|
53 |
+
zoom = torch.linspace(1.0, 0.25, steps=maxstep)
|
54 |
+
height_history = []
|
55 |
+
for i in range(maxstep):
|
56 |
+
farpoint = torch.tensor([
|
57 |
+
90,
|
58 |
+
torch.sin(circle[i])*size + voxel.voxel_t.size(1)/2 + shift,
|
59 |
+
torch.cos(circle[i])*size + voxel.voxel_t.size(2)/2 + shift])
|
60 |
+
|
61 |
+
height_history.append(self._get_height(farpoint[1], farpoint[2], farpoint[0]))
|
62 |
+
|
63 |
+
height_history = self.filtfilt(height_history, decay=0.2*smooth_decay_multiplier)
|
64 |
+
|
65 |
+
for i in range(maxstep):
|
66 |
+
farpoint = torch.tensor([
|
67 |
+
90,
|
68 |
+
torch.sin(circle[i])*size + voxel.voxel_t.size(1)/2 + shift,
|
69 |
+
torch.cos(circle[i])*size + voxel.voxel_t.size(2)/2 + shift])
|
70 |
+
|
71 |
+
farpoint[0] = height_history[i]
|
72 |
+
|
73 |
+
nearpoint = torch.tensor([
|
74 |
+
60,
|
75 |
+
torch.sin(circle[i]-0.3*np.pi)*size*0.3 + voxel.voxel_t.size(1)/2 + shift,
|
76 |
+
torch.cos(circle[i]-0.3*np.pi)*size*0.3 + voxel.voxel_t.size(2)/2 + shift])
|
77 |
+
cam_ori = self.voxel.world2local(farpoint)
|
78 |
+
cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True)
|
79 |
+
cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
|
80 |
+
cam_f = 0.5/np.tan(np.deg2rad(cam_ang/2)*zoom[i]) # about 24mm fov
|
81 |
+
|
82 |
+
self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
|
83 |
+
|
84 |
+
elif pattern == 2:
|
85 |
+
move = torch.linspace(1.0, 0.2, steps=maxstep)
|
86 |
+
height_history = []
|
87 |
+
for i in range(maxstep):
|
88 |
+
farpoint = torch.tensor([
|
89 |
+
90,
|
90 |
+
torch.sin(circle[i])*size*move[i] + voxel.voxel_t.size(1)/2 + shift,
|
91 |
+
torch.cos(circle[i])*size*move[i] + voxel.voxel_t.size(2)/2 + shift])
|
92 |
+
|
93 |
+
height_history.append(self._get_height(farpoint[1], farpoint[2], farpoint[0]))
|
94 |
+
|
95 |
+
height_history = self.filtfilt(height_history, decay=0.2*smooth_decay_multiplier)
|
96 |
+
|
97 |
+
for i in range(maxstep):
|
98 |
+
farpoint = torch.tensor([
|
99 |
+
90,
|
100 |
+
torch.sin(circle[i])*size*move[i] + voxel.voxel_t.size(1)/2 + shift,
|
101 |
+
torch.cos(circle[i])*size*move[i] + voxel.voxel_t.size(2)/2 + shift])
|
102 |
+
|
103 |
+
farpoint[0] = height_history[i]
|
104 |
+
|
105 |
+
nearpoint = torch.tensor([
|
106 |
+
60,
|
107 |
+
torch.sin(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(1)/2 + shift,
|
108 |
+
torch.cos(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(2)/2 + shift])
|
109 |
+
cam_ori = self.voxel.world2local(farpoint)
|
110 |
+
cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True)
|
111 |
+
cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
|
112 |
+
cam_f = 0.5/np.tan(np.deg2rad(cam_ang/2)) # about 24mm fov
|
113 |
+
|
114 |
+
self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
|
115 |
+
|
116 |
+
elif pattern == 3:
|
117 |
+
move = torch.linspace(0.75, 0.2, steps=maxstep)
|
118 |
+
height_history = []
|
119 |
+
for i in range(maxstep):
|
120 |
+
farpoint = torch.tensor([
|
121 |
+
70,
|
122 |
+
torch.sin(-circle[i])*size*move[i] + voxel.voxel_t.size(1)/2 + shift,
|
123 |
+
torch.cos(-circle[i])*size*move[i] + voxel.voxel_t.size(2)/2 + shift])
|
124 |
+
|
125 |
+
height_history.append(self._get_height(farpoint[1], farpoint[2], farpoint[0]))
|
126 |
+
|
127 |
+
height_history = self.filtfilt(height_history, decay=0.2*smooth_decay_multiplier)
|
128 |
+
|
129 |
+
for i in range(maxstep):
|
130 |
+
farpoint = torch.tensor([
|
131 |
+
70,
|
132 |
+
torch.sin(-circle[i])*size*move[i] + voxel.voxel_t.size(1)/2 + shift,
|
133 |
+
torch.cos(-circle[i])*size*move[i] + voxel.voxel_t.size(2)/2 + shift])
|
134 |
+
|
135 |
+
farpoint[0] = height_history[i]
|
136 |
+
|
137 |
+
nearpoint = torch.tensor([
|
138 |
+
60,
|
139 |
+
torch.sin(-circle[i]-0.4*np.pi)*size*0.9*move[i] + voxel.voxel_t.size(1)/2 + shift,
|
140 |
+
torch.cos(-circle[i]-0.4*np.pi)*size*0.9*move[i] + voxel.voxel_t.size(2)/2 + shift])
|
141 |
+
cam_ori = self.voxel.world2local(farpoint)
|
142 |
+
cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True)
|
143 |
+
cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
|
144 |
+
cam_f = 0.5/np.tan(np.deg2rad(cam_ang/2)) # about 24mm fov
|
145 |
+
|
146 |
+
self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
|
147 |
+
|
148 |
+
elif pattern == 4:
|
149 |
+
move = torch.linspace(1.0, 0.5, steps=maxstep)
|
150 |
+
height_history = []
|
151 |
+
for i in range(maxstep):
|
152 |
+
farpoint = torch.tensor([
|
153 |
+
90,
|
154 |
+
torch.sin(circle[i])*size*move[i] + voxel.voxel_t.size(1)/2 + shift,
|
155 |
+
torch.cos(circle[i])*size*move[i] + voxel.voxel_t.size(2)/2 + shift])
|
156 |
+
|
157 |
+
height_history.append(self._get_height(farpoint[1], farpoint[2], farpoint[0]))
|
158 |
+
|
159 |
+
height_history = self.filtfilt(height_history, decay=0.2*smooth_decay_multiplier)
|
160 |
+
|
161 |
+
for i in range(maxstep):
|
162 |
+
farpoint = torch.tensor([
|
163 |
+
90,
|
164 |
+
torch.sin(circle[i])*size*move[i] + voxel.voxel_t.size(1)/2 + shift,
|
165 |
+
torch.cos(circle[i])*size*move[i] + voxel.voxel_t.size(2)/2 + shift])
|
166 |
+
|
167 |
+
farpoint[0] = height_history[i]
|
168 |
+
|
169 |
+
nearpoint = torch.tensor([
|
170 |
+
60,
|
171 |
+
torch.sin(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(1)/2 + shift,
|
172 |
+
torch.cos(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(2)/2 + shift])
|
173 |
+
cam_ori = self.voxel.world2local(farpoint)
|
174 |
+
cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True)
|
175 |
+
cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
|
176 |
+
cam_f = 0.5/np.tan(np.deg2rad(cam_ang/2)) # about 24mm fov
|
177 |
+
|
178 |
+
self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
|
179 |
+
|
180 |
+
# look outward
|
181 |
+
elif pattern == 5:
|
182 |
+
move = torch.linspace(1.0, 0.5, steps=maxstep)
|
183 |
+
height_history = []
|
184 |
+
for i in range(maxstep):
|
185 |
+
nearpoint = torch.tensor([
|
186 |
+
60,
|
187 |
+
torch.sin(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(1)/2 + shift,
|
188 |
+
torch.cos(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(2)/2 + shift])
|
189 |
+
|
190 |
+
height_history.append(self._get_height(nearpoint[1], nearpoint[2], nearpoint[0]))
|
191 |
+
|
192 |
+
height_history = self.filtfilt(height_history, decay=0.2*smooth_decay_multiplier)
|
193 |
+
|
194 |
+
for i in range(maxstep):
|
195 |
+
nearpoint = torch.tensor([
|
196 |
+
60,
|
197 |
+
torch.sin(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(1)/2 + shift,
|
198 |
+
torch.cos(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(2)/2 + shift])
|
199 |
+
|
200 |
+
nearpoint[0] = height_history[i]
|
201 |
+
|
202 |
+
farpoint = torch.tensor([
|
203 |
+
60,
|
204 |
+
torch.sin(circle[i])*size*move[i] + voxel.voxel_t.size(1)/2 + shift,
|
205 |
+
torch.cos(circle[i])*size*move[i] + voxel.voxel_t.size(2)/2 + shift])
|
206 |
+
|
207 |
+
cam_ori = self.voxel.world2local(nearpoint)
|
208 |
+
cam_dir = self.voxel.world2local(farpoint - nearpoint, is_vec=True)
|
209 |
+
cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
|
210 |
+
cam_f = 0.5/np.tan(np.deg2rad(cam_ang/2)) # about 24mm fov
|
211 |
+
|
212 |
+
self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
|
213 |
+
# Rise
|
214 |
+
elif pattern == 6:
|
215 |
+
shift = 0
|
216 |
+
lift = torch.linspace(0.0, 200.0, steps=maxstep)
|
217 |
+
zoom = torch.linspace(0.8, 1.6, steps=maxstep)
|
218 |
+
for i in range(maxstep):
|
219 |
+
farpoint = torch.tensor([
|
220 |
+
80+lift[i],
|
221 |
+
torch.sin(circle[i]/4)*size*0.2 + voxel.voxel_t.size(1)/2 + shift,
|
222 |
+
torch.cos(circle[i]/4)*size*0.2 + voxel.voxel_t.size(2)/2 + shift])
|
223 |
+
|
224 |
+
farpoint[0] = self._get_height(farpoint[1], farpoint[2], farpoint[0])
|
225 |
+
|
226 |
+
nearpoint = torch.tensor([
|
227 |
+
65,
|
228 |
+
torch.sin(circle[i]/4+0.5*np.pi)*size*0.1 + voxel.voxel_t.size(1)/2 + shift,
|
229 |
+
torch.cos(circle[i]/4+0.5*np.pi)*size*0.1 + voxel.voxel_t.size(2)/2 + shift])
|
230 |
+
cam_ori = self.voxel.world2local(farpoint)
|
231 |
+
cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True)
|
232 |
+
cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
|
233 |
+
cam_f = 0.5/np.tan(np.deg2rad(73/2)*zoom[i]) # about 24mm fov
|
234 |
+
|
235 |
+
self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
|
236 |
+
# 45deg
|
237 |
+
elif pattern == 7:
|
238 |
+
rad = torch.tensor([np.deg2rad(45).astype(np.float32)])
|
239 |
+
size = 1536
|
240 |
+
for i in range(maxstep):
|
241 |
+
farpoint = torch.tensor([
|
242 |
+
61+size,
|
243 |
+
torch.sin(rad)*size + voxel.voxel_t.size(1)/2,
|
244 |
+
torch.cos(rad)*size + voxel.voxel_t.size(2)/2])
|
245 |
+
|
246 |
+
nearpoint = torch.tensor([
|
247 |
+
61,
|
248 |
+
voxel.voxel_t.size(1)/2,
|
249 |
+
voxel.voxel_t.size(2)/2])
|
250 |
+
cam_ori = self.voxel.world2local(farpoint)
|
251 |
+
cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True)
|
252 |
+
cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
|
253 |
+
cam_f = 0.5/np.tan(np.deg2rad(19.5/2)) # about 50mm fov
|
254 |
+
|
255 |
+
self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
|
256 |
+
|
257 |
+
elif pattern == 8:
|
258 |
+
size = self.voxel.voxel_t.size(1) // 2
|
259 |
+
for i in range(maxstep):
|
260 |
+
farpoint = torch.tensor([
|
261 |
+
300,
|
262 |
+
0*size + voxel.voxel_t.size(1)//2,
|
263 |
+
-1*size + voxel.voxel_t.size(2)/2 + size // maxstep * (i - maxstep // 4)])
|
264 |
+
nearpoint = torch.tensor([
|
265 |
+
120,
|
266 |
+
0*size*0.5 + voxel.voxel_t.size(1)//2,
|
267 |
+
-1*size*0.5 + voxel.voxel_t.size(2)/2 + size // maxstep * (i - maxstep // 4)])
|
268 |
+
cam_ori = self.voxel.world2local(farpoint)
|
269 |
+
cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True)
|
270 |
+
cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
|
271 |
+
cam_f = 0.5/np.tan(np.deg2rad(cam_ang/2)) # about 24mm fov
|
272 |
+
|
273 |
+
self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
|
274 |
+
|
275 |
+
elif pattern == 9:
|
276 |
+
size = self.voxel.voxel_t.size(2) // 2
|
277 |
+
for i in range(maxstep):
|
278 |
+
farpoint = torch.tensor([
|
279 |
+
140,
|
280 |
+
voxel.voxel_t.size(1)//2,
|
281 |
+
-size // 4 + size * 8 // maxstep * i]
|
282 |
+
, dtype=torch.float32)
|
283 |
+
nearpoint = torch.tensor([
|
284 |
+
100,
|
285 |
+
voxel.voxel_t.size(1)//2,
|
286 |
+
size * 8 // maxstep * i]
|
287 |
+
, dtype=torch.float32)
|
288 |
+
cam_ori = self.voxel.world2local(farpoint)
|
289 |
+
cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True)
|
290 |
+
cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
|
291 |
+
cam_f = 0.5/np.tan(np.deg2rad(cam_ang/2)) # about 24mm fov
|
292 |
+
|
293 |
+
self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
|
294 |
+
|
295 |
+
|
296 |
+
def _get_height(self, loc0, loc1, minheight):
|
297 |
+
loc0 = int(loc0)
|
298 |
+
loc1 = int(loc1)
|
299 |
+
height = minheight
|
300 |
+
for dx in range(-3, 4):
|
301 |
+
for dy in range(-3, 4):
|
302 |
+
if (loc0+dx) < 0 or (loc0+dx) >= self.voxel.heightmap.shape[0] or (loc1+dy) < 0 or \
|
303 |
+
(loc1+dy) >= self.voxel.heightmap.shape[1]:
|
304 |
+
height = max(height, minheight)
|
305 |
+
else:
|
306 |
+
height = max(height, self.voxel.heightmap[loc0+dx, loc1+dy] + 2)
|
307 |
+
return height
|
308 |
+
|
309 |
+
def filtfilt(self, height_history, decay=0.2):
|
310 |
+
# Filtfilt
|
311 |
+
height_history2 = []
|
312 |
+
maxstep = len(height_history)
|
313 |
+
prev_height = height_history[0]
|
314 |
+
for i in range(maxstep):
|
315 |
+
prev_height = prev_height - decay
|
316 |
+
if prev_height < height_history[i]:
|
317 |
+
prev_height = height_history[i]
|
318 |
+
height_history2.append(prev_height)
|
319 |
+
prev_height = height_history[-1]
|
320 |
+
for i in range(maxstep-1, -1, -1):
|
321 |
+
prev_height = prev_height - decay
|
322 |
+
if prev_height < height_history[i]:
|
323 |
+
prev_height = height_history[i]
|
324 |
+
height_history2[i] = max(prev_height, height_history2[i])
|
325 |
+
return height_history2
|
326 |
+
|
327 |
+
def __len__(self):
|
328 |
+
return len(self.camera_poses)
|
329 |
+
|
330 |
+
def __getitem__(self, idx):
|
331 |
+
return self.camera_poses[idx]
|
332 |
+
|
333 |
+
|
334 |
+
class TourCameraController:
|
335 |
+
def __init__(self, voxel, maxstep=128):
|
336 |
+
self.voxel = voxel
|
337 |
+
self.maxstep = maxstep
|
338 |
+
self.camera_poses = [] # ori, dir, up, f
|
339 |
+
circle = torch.linspace(0, 2*np.pi, steps=maxstep//4)
|
340 |
+
size = min(voxel.voxel_t.size(1), voxel.voxel_t.size(2)) / 2
|
341 |
+
# Shrink the circle a bit
|
342 |
+
shift = size * 0.2
|
343 |
+
size = size * 0.8
|
344 |
+
|
345 |
+
for i in range(maxstep//4):
|
346 |
+
farpoint = torch.tensor([
|
347 |
+
70,
|
348 |
+
torch.sin(circle[i])*size + voxel.voxel_t.size(1)/2 + shift,
|
349 |
+
torch.cos(circle[i])*size + voxel.voxel_t.size(2)/2 + shift])
|
350 |
+
|
351 |
+
farpoint[0] = self._get_height(farpoint[1], farpoint[2], farpoint[0])
|
352 |
+
|
353 |
+
nearpoint = torch.tensor([
|
354 |
+
60,
|
355 |
+
torch.sin(circle[i]+0.5*np.pi)*size*0.5 + voxel.voxel_t.size(1)/2 + shift,
|
356 |
+
torch.cos(circle[i]+0.5*np.pi)*size*0.5 + voxel.voxel_t.size(2)/2 + shift])
|
357 |
+
cam_ori = self.voxel.world2local(farpoint)
|
358 |
+
cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True)
|
359 |
+
cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
|
360 |
+
cam_f = 0.5/np.tan(np.deg2rad(73/2)) # about 24mm fov
|
361 |
+
|
362 |
+
self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
|
363 |
+
|
364 |
+
zoom = torch.linspace(1.0, 0.25, steps=maxstep//4)
|
365 |
+
for i in range(maxstep//4):
|
366 |
+
farpoint = torch.tensor([
|
367 |
+
90,
|
368 |
+
torch.sin(circle[i])*size + voxel.voxel_t.size(1)/2 + shift,
|
369 |
+
torch.cos(circle[i])*size + voxel.voxel_t.size(2)/2 + shift])
|
370 |
+
|
371 |
+
farpoint[0] = self._get_height(farpoint[1], farpoint[2], farpoint[0])
|
372 |
+
|
373 |
+
nearpoint = torch.tensor([
|
374 |
+
60,
|
375 |
+
torch.sin(circle[i]-0.3*np.pi)*size*0.3 + voxel.voxel_t.size(1)/2 + shift,
|
376 |
+
torch.cos(circle[i]-0.3*np.pi)*size*0.3 + voxel.voxel_t.size(2)/2 + shift])
|
377 |
+
cam_ori = self.voxel.world2local(farpoint)
|
378 |
+
cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True)
|
379 |
+
cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
|
380 |
+
cam_f = 0.5/np.tan(np.deg2rad(73/2)*zoom[i]) # about 24mm fov
|
381 |
+
|
382 |
+
self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
|
383 |
+
|
384 |
+
move = torch.linspace(1.0, 0.2, steps=maxstep//4)
|
385 |
+
for i in range(maxstep//4):
|
386 |
+
farpoint = torch.tensor([
|
387 |
+
90,
|
388 |
+
torch.sin(circle[i])*size*move[i] + voxel.voxel_t.size(1)/2 + shift,
|
389 |
+
torch.cos(circle[i])*size*move[i] + voxel.voxel_t.size(2)/2 + shift])
|
390 |
+
|
391 |
+
farpoint[0] = self._get_height(farpoint[1], farpoint[2], farpoint[0])
|
392 |
+
|
393 |
+
nearpoint = torch.tensor([
|
394 |
+
60,
|
395 |
+
torch.sin(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(1)/2 + shift,
|
396 |
+
torch.cos(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(2)/2 + shift])
|
397 |
+
cam_ori = self.voxel.world2local(farpoint)
|
398 |
+
cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True)
|
399 |
+
cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
|
400 |
+
cam_f = 0.5/np.tan(np.deg2rad(73/2)) # about 24mm fov
|
401 |
+
|
402 |
+
self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
|
403 |
+
|
404 |
+
lift = torch.linspace(0.0, 200.0, steps=maxstep//4)
|
405 |
+
zoom = torch.linspace(0.6, 1.2, steps=maxstep//4)
|
406 |
+
for i in range(maxstep//4):
|
407 |
+
farpoint = torch.tensor([
|
408 |
+
80+lift[i],
|
409 |
+
torch.sin(circle[i])*size*0.2 + voxel.voxel_t.size(1)/2 + shift,
|
410 |
+
torch.cos(circle[i])*size*0.2 + voxel.voxel_t.size(2)/2 + shift])
|
411 |
+
|
412 |
+
farpoint[0] = self._get_height(farpoint[1], farpoint[2], farpoint[0])
|
413 |
+
|
414 |
+
nearpoint = torch.tensor([
|
415 |
+
60,
|
416 |
+
torch.sin(circle[i]+0.5*np.pi)*size*0.1 + voxel.voxel_t.size(1)/2 + shift,
|
417 |
+
torch.cos(circle[i]+0.5*np.pi)*size*0.1 + voxel.voxel_t.size(2)/2 + shift])
|
418 |
+
cam_ori = self.voxel.world2local(farpoint)
|
419 |
+
cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True)
|
420 |
+
cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
|
421 |
+
cam_f = 0.5/np.tan(np.deg2rad(73/2)*zoom[i]) # about 24mm fov
|
422 |
+
|
423 |
+
self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
|
424 |
+
|
425 |
+
def _get_height(self, loc0, loc1, minheight):
|
426 |
+
loc0 = int(loc0)
|
427 |
+
loc1 = int(loc1)
|
428 |
+
height = minheight
|
429 |
+
for dx in range(-3, 4):
|
430 |
+
for dy in range(-3, 4):
|
431 |
+
if (loc0+dx) < 0 or (loc0+dx) >= self.voxel.heightmap.shape[0] or (loc1+dy) < 0 or \
|
432 |
+
(loc1+dy) >= self.voxel.heightmap.shape[1]:
|
433 |
+
height = max(height, minheight)
|
434 |
+
else:
|
435 |
+
height = max(height, self.voxel.heightmap[loc0+dx, loc1+dy] + 2)
|
436 |
+
return height
|
437 |
+
|
438 |
+
def __len__(self):
|
439 |
+
return len(self.camera_poses)
|
440 |
+
|
441 |
+
def __getitem__(self, idx):
|
442 |
+
return self.camera_poses[idx]
|
443 |
+
|
444 |
+
|
445 |
+
def rand_camera_pose_birdseye(voxel, border=128):
|
446 |
+
r"""Generating random camera pose in the upper hemisphere, in the format of origin-direction-up
|
447 |
+
Assuming [Y X Z] coordinate. Y is negative gravity direction.
|
448 |
+
The camera pose is converted into the voxel coordinate system so that it can be used directly for rendering
|
449 |
+
1. Uniformly sample a point on the upper hemisphere of a unit sphere, as cam_ori.
|
450 |
+
2. Set cam_dir to be from cam_ori to the origin
|
451 |
+
3. cam_up is always pointing towards sky
|
452 |
+
4. move cam_ori to random place according to voxel size
|
453 |
+
"""
|
454 |
+
cam_dir = torch.randn(3, dtype=torch.float32)
|
455 |
+
cam_dir = cam_dir / torch.sqrt(torch.sum(cam_dir*cam_dir))
|
456 |
+
cam_dir[0] = -torch.abs(cam_dir[0])
|
457 |
+
cam_up = torch.tensor([1, 0, 0], dtype=torch.float32)
|
458 |
+
|
459 |
+
# generate camera lookat target
|
460 |
+
r = np.random.rand(2)
|
461 |
+
r[0] *= voxel.voxel_t.size(1)-border-border
|
462 |
+
r[1] *= voxel.voxel_t.size(2)-border-border
|
463 |
+
r = r + border
|
464 |
+
y = voxel.heightmap[int(r[0]+0.5), int(r[1]+0.5)] + (np.random.rand(1)-0.5) * 5
|
465 |
+
cam_target = torch.tensor([y, r[0], r[1]], dtype=torch.float32)
|
466 |
+
cam_ori = cam_target - cam_dir * (np.random.rand(1).item() * 100)
|
467 |
+
cam_ori[0] = max(voxel.heightmap[int(cam_ori[1]+0.5), int(cam_ori[2]+0.5)]+2, cam_ori[0])
|
468 |
+
# Translate to voxel coordinate
|
469 |
+
cam_ori = voxel.world2local(cam_ori)
|
470 |
+
cam_dir = voxel.world2local(cam_dir, is_vec=True)
|
471 |
+
cam_up = voxel.world2local(cam_up, is_vec=True)
|
472 |
+
|
473 |
+
return cam_ori, cam_dir, cam_up
|
474 |
+
|
475 |
+
|
476 |
+
def get_neighbor_height(heightmap, loc0, loc1, minheight, neighbor_size=7):
|
477 |
+
loc0 = int(loc0)
|
478 |
+
loc1 = int(loc1)
|
479 |
+
height = 0
|
480 |
+
for dx in range(-neighbor_size//2, neighbor_size//2+1):
|
481 |
+
for dy in range(-neighbor_size//2, neighbor_size//2+1):
|
482 |
+
if (loc0+dx) < 0 or (loc0+dx) >= heightmap.shape[0] or (loc1+dy) < 0 or (loc1+dy) >= heightmap.shape[1]:
|
483 |
+
height = max(height, minheight)
|
484 |
+
else:
|
485 |
+
height = max(minheight, heightmap[loc0+dx, loc1+dy] + 2)
|
486 |
+
return height
|
487 |
+
|
488 |
+
|
489 |
+
def rand_camera_pose_firstperson(voxel, border=128):
|
490 |
+
r"""Generating random camera pose in the upper hemisphere, in the format of origin-direction-up
|
491 |
+
"""
|
492 |
+
r = np.random.rand(5)
|
493 |
+
r[0] *= voxel.voxel_t.size(1)-border-border
|
494 |
+
r[1] *= voxel.voxel_t.size(2)-border-border
|
495 |
+
r[0] = r[0] + border
|
496 |
+
r[1] = r[1] + border
|
497 |
+
|
498 |
+
y = get_neighbor_height(voxel.heightmap, r[0], r[1], 0) + np.random.rand(1) * 15
|
499 |
+
|
500 |
+
cam_ori = torch.tensor([y, r[0], r[1]], dtype=torch.float32)
|
501 |
+
|
502 |
+
rand_ang_h = r[2] * 2 * np.pi
|
503 |
+
cam_target = torch.tensor([0, cam_ori[1]+np.sin(rand_ang_h)*border*r[4], cam_ori[2] +
|
504 |
+
np.cos(rand_ang_h)*border*r[4]], dtype=torch.float32)
|
505 |
+
cam_target[0] = get_neighbor_height(voxel.heightmap, cam_target[1],
|
506 |
+
cam_target[2], 0, neighbor_size=1) - 2 + r[3] * 10
|
507 |
+
|
508 |
+
cam_dir = cam_target - cam_ori
|
509 |
+
|
510 |
+
cam_up = torch.tensor([1, 0, 0], dtype=torch.float32)
|
511 |
+
|
512 |
+
cam_ori = voxel.world2local(cam_ori)
|
513 |
+
cam_dir = voxel.world2local(cam_dir, is_vec=True)
|
514 |
+
cam_up = voxel.world2local(cam_up, is_vec=True)
|
515 |
+
|
516 |
+
return cam_ori, cam_dir, cam_up
|
517 |
+
|
518 |
+
|
519 |
+
def rand_camera_pose_thridperson(voxel, border=96):
|
520 |
+
r = torch.rand(2)
|
521 |
+
r[0] *= voxel.voxel_t.size(1)
|
522 |
+
r[1] *= voxel.voxel_t.size(2)
|
523 |
+
rand_height = 60 + torch.rand(1) * 40
|
524 |
+
rand_height = get_neighbor_height(voxel.heightmap, r[0], r[1], rand_height, neighbor_size=5)
|
525 |
+
farpoint = torch.tensor([rand_height, r[0], r[1]], dtype=torch.float32)
|
526 |
+
|
527 |
+
r = torch.rand(2)
|
528 |
+
r[0] *= voxel.voxel_t.size(1) - border - border
|
529 |
+
r[1] *= voxel.voxel_t.size(2) - border - border
|
530 |
+
r[0] = r[0] + border
|
531 |
+
r[1] = r[1] + border
|
532 |
+
rand_height = get_neighbor_height(voxel.heightmap, r[0], r[1], 65, neighbor_size=1) - 5
|
533 |
+
nearpoint = torch.tensor([rand_height, r[0], r[1]], dtype=torch.float32)
|
534 |
+
|
535 |
+
cam_ori = voxel.world2local(farpoint)
|
536 |
+
cam_dir = voxel.world2local(nearpoint - farpoint, is_vec=True)
|
537 |
+
cam_up = voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
|
538 |
+
|
539 |
+
return cam_ori, cam_dir, cam_up
|
540 |
+
|
541 |
+
|
542 |
+
def rand_camera_pose_thridperson2(voxel, border=48):
|
543 |
+
r = torch.rand(2)
|
544 |
+
r[0] *= voxel.voxel_t.size(1) - border - border
|
545 |
+
r[1] *= voxel.voxel_t.size(2) - border - border
|
546 |
+
r[0] = r[0] + border
|
547 |
+
r[1] = r[1] + border
|
548 |
+
rand_height = 60 + torch.rand(1) * 40
|
549 |
+
rand_height = get_neighbor_height(voxel.heightmap, r[0], r[1], rand_height, neighbor_size=5)
|
550 |
+
farpoint = torch.tensor([rand_height, r[0], r[1]], dtype=torch.float32)
|
551 |
+
|
552 |
+
r = torch.rand(2)
|
553 |
+
r[0] *= voxel.voxel_t.size(1) - border - border
|
554 |
+
r[1] *= voxel.voxel_t.size(2) - border - border
|
555 |
+
r[0] = r[0] + border
|
556 |
+
r[1] = r[1] + border
|
557 |
+
rand_height = get_neighbor_height(voxel.heightmap, r[0], r[1], 65, neighbor_size=1) - 5
|
558 |
+
nearpoint = torch.tensor([rand_height, r[0], r[1]], dtype=torch.float32)
|
559 |
+
|
560 |
+
# Random Up vector (tilt a little bit)
|
561 |
+
# up = torch.randn(3) * 0.05 # cutoff +-0.1, Tan(10deg) = 0.176
|
562 |
+
up = torch.randn(3) * 0.02
|
563 |
+
up[0] = 1.0
|
564 |
+
up = up / up.norm(p=2)
|
565 |
+
cam_ori = voxel.world2local(farpoint)
|
566 |
+
cam_dir = voxel.world2local(nearpoint - farpoint, is_vec=True)
|
567 |
+
cam_up = voxel.world2local(up, is_vec=True)
|
568 |
+
|
569 |
+
return cam_ori, cam_dir, cam_up
|
570 |
+
|
571 |
+
|
572 |
+
def rand_camera_pose_thridperson3(voxel, border=64):
|
573 |
+
r"""Attempting to solve the camera too close to wall problem and the lack of aerial poses."""
|
574 |
+
r = torch.rand(2)
|
575 |
+
r[0] *= voxel.voxel_t.size(1) - border - border
|
576 |
+
r[1] *= voxel.voxel_t.size(2) - border - border
|
577 |
+
r[0] = r[0] + border
|
578 |
+
r[1] = r[1] + border
|
579 |
+
rand_height = 60 + torch.rand(1) * 40
|
580 |
+
if torch.rand(1) > 0.8:
|
581 |
+
rand_height = 60 + torch.rand(1) * 60
|
582 |
+
rand_height = get_neighbor_height(voxel.heightmap, r[0], r[1], rand_height, neighbor_size=7)
|
583 |
+
farpoint = torch.tensor([rand_height, r[0], r[1]], dtype=torch.float32)
|
584 |
+
|
585 |
+
r = torch.rand(2)
|
586 |
+
r[0] *= voxel.voxel_t.size(1) - border - border
|
587 |
+
r[1] *= voxel.voxel_t.size(2) - border - border
|
588 |
+
r[0] = r[0] + border
|
589 |
+
r[1] = r[1] + border
|
590 |
+
rand_height = get_neighbor_height(voxel.heightmap, r[0], r[1], 65, neighbor_size=3) - 5
|
591 |
+
nearpoint = torch.tensor([rand_height, r[0], r[1]], dtype=torch.float32)
|
592 |
+
|
593 |
+
# Random Up vector (tilt a little bit)
|
594 |
+
# up = torch.randn(3) * 0.05 # cutoff +-0.1, Tan(10deg) = 0.176
|
595 |
+
up = torch.randn(3) * 0.02
|
596 |
+
up[0] = 1.0
|
597 |
+
up = up / up.norm(p=2)
|
598 |
+
# print(up)
|
599 |
+
cam_ori = voxel.world2local(farpoint)
|
600 |
+
cam_dir = voxel.world2local(nearpoint - farpoint, is_vec=True)
|
601 |
+
cam_up = voxel.world2local(up, is_vec=True)
|
602 |
+
|
603 |
+
return cam_ori, cam_dir, cam_up
|
604 |
+
|
605 |
+
|
606 |
+
def rand_camera_pose_tour(voxel):
|
607 |
+
size = min(voxel.voxel_t.size(1), voxel.voxel_t.size(2)) / 2
|
608 |
+
center = [voxel.voxel_t.size(1)/2, voxel.voxel_t.size(2)/2]
|
609 |
+
|
610 |
+
rnd = torch.rand(8)
|
611 |
+
|
612 |
+
rnd_deg = torch.rand(1) * 2 * np.pi
|
613 |
+
far_radius = rnd[0]*0.8+0.2
|
614 |
+
far_height = rnd[1]*30 + 60
|
615 |
+
farpoint = torch.tensor([
|
616 |
+
far_height,
|
617 |
+
torch.sin(rnd_deg)*size*far_radius + center[0],
|
618 |
+
torch.cos(rnd_deg)*size*far_radius + center[1]])
|
619 |
+
|
620 |
+
farpoint[0] = get_neighbor_height(voxel.heightmap, farpoint[1], farpoint[2], farpoint[0], neighbor_size=7)
|
621 |
+
|
622 |
+
near_radius = far_radius * rnd[2]
|
623 |
+
near_shift_rad = np.pi*(rnd[3]-0.5)
|
624 |
+
near_height = 60 + rnd[4] * 10
|
625 |
+
nearpoint = torch.tensor([
|
626 |
+
near_height,
|
627 |
+
torch.sin(rnd_deg+near_shift_rad)*size*near_radius + center[0],
|
628 |
+
torch.cos(rnd_deg+near_shift_rad)*size*near_radius + center[1]])
|
629 |
+
|
630 |
+
# Random Up vector (tilt a little bit)
|
631 |
+
# up = torch.randn(3) * 0.05 # cutoff +-0.1, Tan(10deg) = 0.176
|
632 |
+
up = torch.randn(3) * 0.02
|
633 |
+
up[0] = 1.0
|
634 |
+
up = up / up.norm(p=2)
|
635 |
+
cam_ori = voxel.world2local(farpoint)
|
636 |
+
cam_dir = voxel.world2local(nearpoint - farpoint, is_vec=True)
|
637 |
+
cam_up = voxel.world2local(up, is_vec=True)
|
638 |
+
cam_f = 0.5/np.tan(np.deg2rad(73/2)*(rnd[5]*0.75+0.25)) # about 24mm fov
|
639 |
+
|
640 |
+
return cam_ori, cam_dir, cam_up, cam_f
|
641 |
+
|
642 |
+
# Look from center to outward
|
643 |
+
|
644 |
+
|
645 |
+
def rand_camera_pose_insideout(voxel):
|
646 |
+
size = min(voxel.voxel_t.size(1), voxel.voxel_t.size(2)) / 2
|
647 |
+
center = [voxel.voxel_t.size(1)/2, voxel.voxel_t.size(2)/2]
|
648 |
+
|
649 |
+
rnd = torch.rand(8)
|
650 |
+
|
651 |
+
rnd_deg = torch.rand(1) * 2 * np.pi
|
652 |
+
far_radius = rnd[0]*0.8+0.2
|
653 |
+
far_height = rnd[1]*10 + 60
|
654 |
+
farpoint = torch.tensor([
|
655 |
+
far_height,
|
656 |
+
torch.sin(rnd_deg)*size*far_radius + center[0],
|
657 |
+
torch.cos(rnd_deg)*size*far_radius + center[1]])
|
658 |
+
|
659 |
+
near_radius = far_radius * rnd[2]
|
660 |
+
near_shift_rad = np.pi*(rnd[3]-0.5)
|
661 |
+
near_height = 60 + rnd[4] * 30
|
662 |
+
nearpoint = torch.tensor([
|
663 |
+
near_height,
|
664 |
+
torch.sin(rnd_deg+near_shift_rad)*size*near_radius + center[0],
|
665 |
+
torch.cos(rnd_deg+near_shift_rad)*size*near_radius + center[1]])
|
666 |
+
|
667 |
+
nearpoint[0] = get_neighbor_height(voxel.heightmap, nearpoint[1], nearpoint[2], nearpoint[0], neighbor_size=7)
|
668 |
+
|
669 |
+
# Random Up vector (tilt a little bit)
|
670 |
+
# up = torch.randn(3) * 0.05 # cutoff +-0.1, Tan(10deg) = 0.176
|
671 |
+
up = torch.randn(3) * 0.02
|
672 |
+
up[0] = 1.0
|
673 |
+
up = up / up.norm(p=2)
|
674 |
+
cam_ori = voxel.world2local(nearpoint)
|
675 |
+
cam_dir = voxel.world2local(farpoint-nearpoint, is_vec=True)
|
676 |
+
cam_up = voxel.world2local(up, is_vec=True)
|
677 |
+
cam_f = 0.5/np.tan(np.deg2rad(73/2)*(rnd[5]*0.75+0.25)) # about 24mm fov
|
678 |
+
|
679 |
+
return cam_ori, cam_dir, cam_up, cam_f
|
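Usage sketch (illustrative, not part of the uploaded files): EvalCameraController exposes __len__/__getitem__, so it can be iterated directly to drive a render loop. Here `voxel` is assumed to be the voxel wrapper used elsewhere in this upload (it must expose voxel_t, heightmap and world2local), and `render(...)` is a hypothetical stand-in for whatever renderer consumes the pose.

cam_ctl = EvalCameraController(voxel, maxstep=128, pattern=0, cam_ang=73)
for cam_ori, cam_dir, cam_up, cam_f in cam_ctl:
    # each pose is already in the voxel's local coordinate frame
    frame = render(cam_ori, cam_dir, cam_up, cam_f)   # hypothetical renderer call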
imaginaire/model_utils/gancraft/gaugan_lbl2col.csv
ADDED
@@ -0,0 +1,182 @@
1 |
+
person,#00AC0D
|
2 |
+
bicycle,#012F47
|
3 |
+
car,#0275B8
|
4 |
+
motorcycle,#03C098
|
5 |
+
airplane,#04434F
|
6 |
+
bus,#05FB29
|
7 |
+
train,#06C312
|
8 |
+
truck,#076728
|
9 |
+
boat,#0809B6
|
10 |
+
traffic-light,#09D3CF
|
11 |
+
fire-hydrant,#0A150B
|
12 |
+
street-sign,#0BF2A6
|
13 |
+
stop-sign,#0C246F
|
14 |
+
parking-meter,#0D575D
|
15 |
+
bench,#0E46F9
|
16 |
+
bird,#0FD881
|
17 |
+
cat,#1058DF
|
18 |
+
dog,#118C76
|
19 |
+
horse,#123A2C
|
20 |
+
sheep,#13C1D8
|
21 |
+
cow,#14E67D
|
22 |
+
elephant,#152718
|
23 |
+
bear,#165743
|
24 |
+
zebra,#17AED2
|
25 |
+
giraffe,#1858EF
|
26 |
+
hat,#195103
|
27 |
+
backpack,#1AA5EA
|
28 |
+
umbrella,#1B19CC
|
29 |
+
shoe,#1C4DE6
|
30 |
+
eye-glasses,#1D4823
|
31 |
+
handbag,#1E09D6
|
32 |
+
tie,#1F94FE
|
33 |
+
suitcase,#2073BD
|
34 |
+
frisbee,#21D0C5
|
35 |
+
skis,#22F3D7
|
36 |
+
snowboard,#23C52B
|
37 |
+
sports-ball,#24FE20
|
38 |
+
kite,#254F0B
|
39 |
+
baseball-bat,#26AF68
|
40 |
+
baseball-glove,#27C0D4
|
41 |
+
skateboard,#28528A
|
42 |
+
surfboard,#2963B6
|
43 |
+
tennis-racket,#2AD8EB
|
44 |
+
bottle,#2BB1A5
|
45 |
+
plate,#2CF37D
|
46 |
+
wine-glass,#2D1D9C
|
47 |
+
cup,#2E936F
|
48 |
+
fork,#2F93E8
|
49 |
+
knife,#308E02
|
50 |
+
spoon,#31A71B
|
51 |
+
bowl,#3220D3
|
52 |
+
banana,#33C1D9
|
53 |
+
apple,#340997
|
54 |
+
sandwich,#35B935
|
55 |
+
orange,#367F33
|
56 |
+
broccoli,#3720AE
|
57 |
+
carrot,#381F94
|
58 |
+
hot-dog,#39CAB5
|
59 |
+
pizza,#3AF41D
|
60 |
+
donut,#3B9743
|
61 |
+
cake,#3CA323
|
62 |
+
chair,#3DFE27
|
63 |
+
couch,#3ECB89
|
64 |
+
potted-plant,#3F7249
|
65 |
+
bed,#40B729
|
66 |
+
mirror,#411C97
|
67 |
+
dining-table,#422283
|
68 |
+
window,#43802E
|
69 |
+
desk,#4480DA
|
70 |
+
toilet,#45A4B2
|
71 |
+
door,#46356C
|
72 |
+
tv,#478503
|
73 |
+
laptop,#48261F
|
74 |
+
mouse,#49E809
|
75 |
+
remote,#4AF48A
|
76 |
+
keyboard,#4B111B
|
77 |
+
cell-phone,#4C4FAD
|
78 |
+
microwave,#4D84C7
|
79 |
+
oven,#4E69A7
|
80 |
+
toaster,#4F2A3D
|
81 |
+
sink,#50BA55
|
82 |
+
refrigerator,#511F61
|
83 |
+
blender,#52782C
|
84 |
+
book,#530122
|
85 |
+
clock,#5441A2
|
86 |
+
vase,#55E758
|
87 |
+
scissors,#56A921
|
88 |
+
teddy-bear,#573985
|
89 |
+
hair-drier,#5823E8
|
90 |
+
toothbrush,#5966FF
|
91 |
+
hair-brush,#5A7724
|
92 |
+
banner,#5B0B00
|
93 |
+
blanket,#5CAECB
|
94 |
+
branch,#5D5222
|
95 |
+
bridge,#5E5BC5
|
96 |
+
building-other,#5F807E
|
97 |
+
bush,#606E32
|
98 |
+
cabinet,#6163FE
|
99 |
+
cage,#623550
|
100 |
+
cardboard,#638CBE
|
101 |
+
carpet,#647988
|
102 |
+
ceiling-other,#65AABD
|
103 |
+
ceiling-tile,#665481
|
104 |
+
cloth,#67CBD1
|
105 |
+
clothes,#684470
|
106 |
+
clouds,#696969
|
107 |
+
counter,#6AC478
|
108 |
+
cupboard,#6B2F5B
|
109 |
+
curtain,#6C7FA8
|
110 |
+
desk-stuff,#6DF474
|
111 |
+
dirt,#6E6E28
|
112 |
+
door-stuff,#6FCCB0
|
113 |
+
fence,#706419
|
114 |
+
floor-marble,#71B443
|
115 |
+
floor-other,#72E867
|
116 |
+
floor-stone,#734EFC
|
117 |
+
floor-tile,#748F23
|
118 |
+
floor-wood,#759472
|
119 |
+
flower,#760000
|
120 |
+
fog,#77BA1D
|
121 |
+
food-other,#7817F1
|
122 |
+
fruit,#79CF21
|
123 |
+
furniture-other,#7A8D92
|
124 |
+
grass,#7BC800
|
125 |
+
gravel,#7C32C8
|
126 |
+
ground-other,#7D3054
|
127 |
+
hill,#7EC864
|
128 |
+
house,#7F4502
|
129 |
+
leaves,#80A945
|
130 |
+
light,#81A365
|
131 |
+
mat,#82C08C
|
132 |
+
metal,#835F2C
|
133 |
+
mirror-stuff,#84C575
|
134 |
+
moss,#855EFD
|
135 |
+
mountain,#869664
|
136 |
+
mud,#87716F
|
137 |
+
napkin,#88B25B
|
138 |
+
net,#892455
|
139 |
+
paper,#8AA2A7
|
140 |
+
pavement,#8B3027
|
141 |
+
pillow,#8C5DCB
|
142 |
+
plant,#8DE61E
|
143 |
+
plastic,#8E629E
|
144 |
+
platform,#8F2A91
|
145 |
+
playingfield,#90CDC6
|
146 |
+
railing,#9170C7
|
147 |
+
railroad,#92E712
|
148 |
+
river,#9364C8
|
149 |
+
road,#946E28
|
150 |
+
rock,#956432
|
151 |
+
roof,#9600B1
|
152 |
+
rug,#978A29
|
153 |
+
salad,#98725D
|
154 |
+
sand,#999900
|
155 |
+
sea,#9AC6DA
|
156 |
+
shelf,#9B7FC9
|
157 |
+
sky,#9CEEDD
|
158 |
+
skyscraper,#9DBBF2
|
159 |
+
snow,#9E9EAA
|
160 |
+
solid-other,#9F79DB
|
161 |
+
stairs,#A06249
|
162 |
+
stone,#A1A164
|
163 |
+
straw,#A2A3EB
|
164 |
+
structural,#A3DED1
|
165 |
+
table,#A47B69
|
166 |
+
tent,#A5C3BA
|
167 |
+
textile-other,#A65280
|
168 |
+
towel,#A7AED6
|
169 |
+
tree,#A8C832
|
170 |
+
vegetable,#A99410
|
171 |
+
wall-brick,#AAD16A
|
172 |
+
wall-concrete,#AB32A4
|
173 |
+
wall-other,#AC9B5E
|
174 |
+
wall-panel,#AD0E18
|
175 |
+
wall-stone,#AE2974
|
176 |
+
wall-tile,#AF3ABF
|
177 |
+
wall-wood,#B0C1C3
|
178 |
+
water,#B1C8FF
|
179 |
+
waterdrops,#B20A88
|
180 |
+
window-blind,#B356B8
|
181 |
+
window-other,#B42B5B
|
182 |
+
wood,#B57B00
|
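Usage sketch (illustrative, not part of the uploaded files): gaugan_lbl2col.csv maps each GauGAN label to a hex color and has no header row. A minimal way to load it into an RGB lookup, with hypothetical variable names:

import csv

lbl2col = {}
with open('imaginaire/model_utils/gancraft/gaugan_lbl2col.csv') as f:
    for name, hexcol in csv.reader(f):
        # '#RRGGBB' -> (R, G, B) integer tuple
        lbl2col[name] = tuple(int(hexcol[i:i + 2], 16) for i in (1, 3, 5))
# e.g. lbl2col['grass'] == (123, 200, 0)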
imaginaire/model_utils/gancraft/gaugan_reduction.csv
ADDED
@@ -0,0 +1,182 @@
1 |
+
person,ignore
|
2 |
+
bicycle,ignore
|
3 |
+
car,ignore
|
4 |
+
motorcycle,ignore
|
5 |
+
airplane,ignore
|
6 |
+
bus,ignore
|
7 |
+
train,ignore
|
8 |
+
truck,ignore
|
9 |
+
boat,ignore
|
10 |
+
traffic-light,ignore
|
11 |
+
fire-hydrant,ignore
|
12 |
+
street-sign,ignore
|
13 |
+
stop-sign,ignore
|
14 |
+
parking-meter,ignore
|
15 |
+
bench,ignore
|
16 |
+
bird,ignore
|
17 |
+
cat,ignore
|
18 |
+
dog,ignore
|
19 |
+
horse,ignore
|
20 |
+
sheep,ignore
|
21 |
+
cow,ignore
|
22 |
+
elephant,ignore
|
23 |
+
bear,ignore
|
24 |
+
zebra,ignore
|
25 |
+
giraffe,ignore
|
26 |
+
hat,ignore
|
27 |
+
backpack,ignore
|
28 |
+
umbrella,ignore
|
29 |
+
shoe,ignore
|
30 |
+
eye-glasses,ignore
|
31 |
+
handbag,ignore
|
32 |
+
tie,ignore
|
33 |
+
suitcase,ignore
|
34 |
+
frisbee,ignore
|
35 |
+
skis,ignore
|
36 |
+
snowboard,ignore
|
37 |
+
sports-ball,ignore
|
38 |
+
kite,ignore
|
39 |
+
baseball-bat,ignore
|
40 |
+
baseball-glove,ignore
|
41 |
+
skateboard,ignore
|
42 |
+
surfboard,ignore
|
43 |
+
tennis-racket,ignore
|
44 |
+
bottle,ignore
|
45 |
+
plate,ignore
|
46 |
+
wine-glass,ignore
|
47 |
+
cup,ignore
|
48 |
+
fork,ignore
|
49 |
+
knife,ignore
|
50 |
+
spoon,ignore
|
51 |
+
bowl,ignore
|
52 |
+
banana,ignore
|
53 |
+
apple,ignore
|
54 |
+
sandwich,ignore
|
55 |
+
orange,ignore
|
56 |
+
broccoli,ignore
|
57 |
+
carrot,ignore
|
58 |
+
hot-dog,ignore
|
59 |
+
pizza,ignore
|
60 |
+
donut,ignore
|
61 |
+
cake,ignore
|
62 |
+
chair,ignore
|
63 |
+
couch,ignore
|
64 |
+
potted-plant,ignore
|
65 |
+
bed,ignore
|
66 |
+
mirror,ignore
|
67 |
+
dining-table,ignore
|
68 |
+
window,ignore
|
69 |
+
desk,ignore
|
70 |
+
toilet,ignore
|
71 |
+
door,ignore
|
72 |
+
tv,ignore
|
73 |
+
laptop,ignore
|
74 |
+
mouse,ignore
|
75 |
+
remote,ignore
|
76 |
+
keyboard,ignore
|
77 |
+
cell-phone,ignore
|
78 |
+
microwave,ignore
|
79 |
+
oven,ignore
|
80 |
+
toaster,ignore
|
81 |
+
sink,ignore
|
82 |
+
refrigerator,ignore
|
83 |
+
blender,ignore
|
84 |
+
book,ignore
|
85 |
+
clock,ignore
|
86 |
+
vase,ignore
|
87 |
+
scissors,ignore
|
88 |
+
teddy-bear,ignore
|
89 |
+
hair-drier,ignore
|
90 |
+
toothbrush,ignore
|
91 |
+
hair-brush,ignore
|
92 |
+
banner,ignore
|
93 |
+
blanket,ignore
|
94 |
+
branch,tree
|
95 |
+
bridge,ignore
|
96 |
+
building-other,ignore
|
97 |
+
bush,tree
|
98 |
+
cabinet,ignore
|
99 |
+
cage,ignore
|
100 |
+
cardboard,ignore
|
101 |
+
carpet,ignore
|
102 |
+
ceiling-other,ignore
|
103 |
+
ceiling-tile,ignore
|
104 |
+
cloth,ignore
|
105 |
+
clothes,ignore
|
106 |
+
clouds,sky
|
107 |
+
counter,ignore
|
108 |
+
cupboard,ignore
|
109 |
+
curtain,ignore
|
110 |
+
desk-stuff,ignore
|
111 |
+
dirt,dirt
|
112 |
+
door-stuff,ignore
|
113 |
+
fence,ignore
|
114 |
+
floor-marble,ignore
|
115 |
+
floor-other,ignore
|
116 |
+
floor-stone,ignore
|
117 |
+
floor-tile,ignore
|
118 |
+
floor-wood,ignore
|
119 |
+
flower,flower
|
120 |
+
fog,sky
|
121 |
+
food-other,ignore
|
122 |
+
fruit,ignore
|
123 |
+
furniture-other,ignore
|
124 |
+
grass,grass
|
125 |
+
gravel,gravel
|
126 |
+
ground-other,ignore
|
127 |
+
hill,grass
|
128 |
+
house,ignore
|
129 |
+
leaves,tree
|
130 |
+
light,ignore
|
131 |
+
mat,ignore
|
132 |
+
metal,ignore
|
133 |
+
mirror-stuff,ignore
|
134 |
+
moss,grass
|
135 |
+
mountain,grass
|
136 |
+
mud,dirt
|
137 |
+
napkin,ignore
|
138 |
+
net,ignore
|
139 |
+
paper,ignore
|
140 |
+
pavement,ignore
|
141 |
+
pillow,ignore
|
142 |
+
plant,flower
|
143 |
+
plastic,ignore
|
144 |
+
platform,ignore
|
145 |
+
playingfield,ignore
|
146 |
+
railing,ignore
|
147 |
+
railroad,ignore
|
148 |
+
river,water
|
149 |
+
road,ignore
|
150 |
+
rock,rock
|
151 |
+
roof,ignore
|
152 |
+
rug,ignore
|
153 |
+
salad,ignore
|
154 |
+
sand,sand
|
155 |
+
sea,water
|
156 |
+
shelf,ignore
|
157 |
+
sky,sky
|
158 |
+
skyscraper,ignore
|
159 |
+
snow,snow
|
160 |
+
solid-other,ignore
|
161 |
+
stairs,ignore
|
162 |
+
stone,stone
|
163 |
+
straw,grass
|
164 |
+
structural,ignore
|
165 |
+
table,ignore
|
166 |
+
tent,ignore
|
167 |
+
textile-other,ignore
|
168 |
+
towel,ignore
|
169 |
+
tree,tree
|
170 |
+
vegetable,ignore
|
171 |
+
wall-brick,ignore
|
172 |
+
wall-concrete,ignore
|
173 |
+
wall-other,ignore
|
174 |
+
wall-panel,ignore
|
175 |
+
wall-stone,ignore
|
176 |
+
wall-tile,ignore
|
177 |
+
wall-wood,ignore
|
178 |
+
water,water
|
179 |
+
waterdrops,ignore
|
180 |
+
window-blind,ignore
|
181 |
+
window-other,ignore
|
182 |
+
wood,ignore
|
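Usage sketch (illustrative, not part of the uploaded files): gaugan_reduction.csv collapses the full GauGAN label set onto the small terrain palette used by the GANcraft-style generator (grass, dirt, tree, water, ...), with 'ignore' marking labels that are dropped. A minimal loader, with hypothetical variable names:

import csv

reduction = {}
with open('imaginaire/model_utils/gancraft/gaugan_reduction.csv') as f:
    for src, dst in csv.reader(f):
        if dst != 'ignore':
            reduction[src] = dst
# e.g. reduction['sea'] == 'water'; 'person' is absent because it maps to ignore.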
imaginaire/model_utils/gancraft/id2name_gg.csv
ADDED
@@ -0,0 +1,680 @@
1 |
+
0,air,0,sky
|
2 |
+
1,stone,7368816,stone
|
3 |
+
2,granite,7368816,rock
|
4 |
+
3,polished_granite,7368816,rock
|
5 |
+
4,diorite,7368816,rock
|
6 |
+
5,polished_diorite,7368816,rock
|
7 |
+
6,andesite,7368816,rock
|
8 |
+
7,polished_andesite,7368816,rock
|
9 |
+
8,grass_block,8368696,grass
|
10 |
+
9,dirt,9923917,dirt
|
11 |
+
10,coarse_dirt,9923917,dirt
|
12 |
+
11,podzol,9923917,dirt
|
13 |
+
12,cobblestone,7368816,stone
|
14 |
+
13,oak_planks,9402184,wood
|
15 |
+
14,spruce_planks,9402184,wood
|
16 |
+
15,birch_planks,9402184,wood
|
17 |
+
16,jungle_planks,9402184,wood
|
18 |
+
17,acacia_planks,9402184,wood
|
19 |
+
18,dark_oak_planks,9402184,wood
|
20 |
+
19,oak_sapling,31744,plant
|
21 |
+
20,spruce_sapling,31744,plant
|
22 |
+
21,birch_sapling,31744,plant
|
23 |
+
22,jungle_sapling,31744,plant
|
24 |
+
23,acacia_sapling,31744,plant
|
25 |
+
24,dark_oak_sapling,31744,plant
|
26 |
+
25,bedrock,7368816,rock
|
27 |
+
26,water,4210943,water
|
28 |
+
27,lava,16711680,
|
29 |
+
28,sand,16247203,sand
|
30 |
+
29,red_sand,16247203,sand
|
31 |
+
30,gravel,16247203,gravel
|
32 |
+
31,gold_ore,7368816,rock
|
33 |
+
32,iron_ore,7368816,rock
|
34 |
+
33,coal_ore,7368816,rock
|
35 |
+
34,oak_log,9402184,tree
|
36 |
+
35,spruce_log,9402184,tree
|
37 |
+
36,birch_log,9402184,tree
|
38 |
+
37,jungle_log,9402184,tree
|
39 |
+
38,acacia_log,9402184,tree
|
40 |
+
39,dark_oak_log,9402184,tree
|
41 |
+
40,stripped_spruce_log,9402184,wood
|
42 |
+
41,stripped_birch_log,9402184,wood
|
43 |
+
42,stripped_jungle_log,9402184,wood
|
44 |
+
43,stripped_acacia_log,9402184,wood
|
45 |
+
44,stripped_dark_oak_log,9402184,wood
|
46 |
+
45,stripped_oak_log,9402184,wood
|
47 |
+
46,oak_wood,9402184,wood
|
48 |
+
47,spruce_wood,9402184,wood
|
49 |
+
48,birch_wood,9402184,wood
|
50 |
+
49,jungle_wood,9402184,wood
|
51 |
+
50,acacia_wood,9402184,wood
|
52 |
+
51,dark_oak_wood,9402184,wood
|
53 |
+
52,stripped_oak_wood,9402184,wood
|
54 |
+
53,stripped_spruce_wood,9402184,wood
|
55 |
+
54,stripped_birch_wood,9402184,wood
|
56 |
+
55,stripped_jungle_wood,9402184,wood
|
57 |
+
56,stripped_acacia_wood,9402184,wood
|
58 |
+
57,stripped_dark_oak_wood,9402184,wood
|
59 |
+
58,oak_leaves,31744,tree
|
60 |
+
59,spruce_leaves,31744,tree
|
61 |
+
60,birch_leaves,31744,tree
|
62 |
+
61,jungle_leaves,31744,tree
|
63 |
+
62,acacia_leaves,31744,tree
|
64 |
+
63,dark_oak_leaves,31744,tree
|
65 |
+
64,sponge,15066419,
|
66 |
+
65,wet_sponge,15066419,
|
67 |
+
66,glass,0,
|
68 |
+
67,lapis_ore,7368816,
|
69 |
+
68,lapis_block,10987431,
|
70 |
+
69,dispenser,7368816,
|
71 |
+
70,sandstone,7368816,sand
|
72 |
+
71,chiseled_sandstone,7368816,sand
|
73 |
+
72,cut_sandstone,7368816,sand
|
74 |
+
73,note_block,9402184,
|
75 |
+
74,white_bed,13092807,
|
76 |
+
75,orange_bed,13092807,
|
77 |
+
76,magenta_bed,13092807,
|
78 |
+
77,light_blue_bed,13092807,
|
79 |
+
78,yellow_bed,13092807,
|
80 |
+
79,lime_bed,13092807,
|
81 |
+
80,pink_bed,13092807,
|
82 |
+
81,gray_bed,13092807,
|
83 |
+
82,light_gray_bed,13092807,
|
84 |
+
83,cyan_bed,13092807,
|
85 |
+
84,purple_bed,13092807,
|
86 |
+
85,blue_bed,13092807,
|
87 |
+
86,brown_bed,13092807,
|
88 |
+
87,green_bed,13092807,
|
89 |
+
88,red_bed,13092807,
|
90 |
+
89,black_bed,13092807,
|
91 |
+
90,powered_rail,0,
|
92 |
+
91,detector_rail,0,
|
93 |
+
92,sticky_piston,7368816,
|
94 |
+
93,cobweb,13092807,
|
95 |
+
94,grass,31744,grass
|
96 |
+
95,fern,31744,grass
|
97 |
+
96,dead_bush,31744,grass
|
98 |
+
97,seagrass,4210943,water
|
99 |
+
98,tall_seagrass,4210943,water
|
100 |
+
99,piston,7368816,
|
101 |
+
100,piston_head,7368816,
|
102 |
+
101,white_wool,13092807,
|
103 |
+
102,orange_wool,13092807,
|
104 |
+
103,magenta_wool,13092807,
|
105 |
+
104,light_blue_wool,13092807,
|
106 |
+
105,yellow_wool,13092807,
|
107 |
+
106,lime_wool,13092807,
|
108 |
+
107,pink_wool,13092807,
|
109 |
+
108,gray_wool,13092807,
|
110 |
+
109,light_gray_wool,13092807,
|
111 |
+
110,cyan_wool,13092807,
|
112 |
+
111,purple_wool,13092807,
|
113 |
+
112,blue_wool,13092807,
|
114 |
+
113,brown_wool,13092807,
|
115 |
+
114,green_wool,13092807,
|
116 |
+
115,red_wool,13092807,
|
117 |
+
116,black_wool,13092807,
|
118 |
+
117,moving_piston,7368816,
|
119 |
+
118,dandelion,31744,flower
|
120 |
+
119,poppy,31744,flower
|
121 |
+
120,blue_orchid,31744,flower
|
122 |
+
121,allium,31744,flower
|
123 |
+
122,azure_bluet,31744,flower
|
124 |
+
123,red_tulip,31744,flower
|
125 |
+
124,orange_tulip,31744,flower
|
126 |
+
125,white_tulip,31744,flower
|
127 |
+
126,pink_tulip,31744,flower
|
128 |
+
127,oxeye_daisy,31744,flower
|
129 |
+
128,cornflower,31744,flower
|
130 |
+
129,wither_rose,31744,flower
|
131 |
+
130,lily_of_the_valley,31744,flower
|
132 |
+
131,brown_mushroom,31744,flower
|
133 |
+
132,red_mushroom,31744,flower
|
134 |
+
133,gold_block,10987431,
|
135 |
+
134,iron_block,10987431,
|
136 |
+
135,bricks,7368816,
|
137 |
+
136,tnt,16711680,
|
138 |
+
137,bookshelf,9402184,
|
139 |
+
138,mossy_cobblestone,7368816,
|
140 |
+
139,obsidian,7368816,
|
141 |
+
140,torch,0,
|
142 |
+
141,wall_torch,0,
|
143 |
+
142,fire,0,
|
144 |
+
143,spawner,7368816,
|
145 |
+
144,oak_stairs,9402184,
|
146 |
+
145,chest,9402184,
|
147 |
+
146,redstone_wire,0,
|
148 |
+
147,diamond_ore,7368816,
|
149 |
+
148,diamond_block,10987431,
|
150 |
+
149,crafting_table,9402184,
|
151 |
+
150,wheat,31744,
|
152 |
+
151,farmland,9923917,
|
153 |
+
152,furnace,7368816,
|
154 |
+
153,oak_sign,9402184,
|
155 |
+
154,spruce_sign,9402184,
|
156 |
+
155,birch_sign,9402184,
|
157 |
+
156,acacia_sign,9402184,
|
158 |
+
157,jungle_sign,9402184,
|
159 |
+
158,dark_oak_sign,9402184,
|
160 |
+
159,oak_door,9402184,
|
161 |
+
160,ladder,0,
|
162 |
+
161,rail,0,
|
163 |
+
162,cobblestone_stairs,7368816,
|
164 |
+
163,oak_wall_sign,9402184,
|
165 |
+
164,spruce_wall_sign,9402184,
|
166 |
+
165,birch_wall_sign,9402184,
|
167 |
+
166,acacia_wall_sign,9402184,
|
168 |
+
167,jungle_wall_sign,9402184,
|
169 |
+
168,dark_oak_wall_sign,9402184,
|
170 |
+
169,lever,0,
|
171 |
+
170,stone_pressure_plate,7368816,
|
172 |
+
171,iron_door,10987431,
|
173 |
+
172,oak_pressure_plate,9402184,
|
174 |
+
173,spruce_pressure_plate,9402184,
|
175 |
+
174,birch_pressure_plate,9402184,
|
176 |
+
175,jungle_pressure_plate,9402184,
|
177 |
+
176,acacia_pressure_plate,9402184,
|
178 |
+
177,dark_oak_pressure_plate,9402184,
|
179 |
+
178,redstone_ore,7368816,
|
180 |
+
179,redstone_torch,0,
|
181 |
+
180,redstone_wall_torch,0,
|
182 |
+
181,stone_button,0,
|
183 |
+
182,snow,16777215,snow
|
184 |
+
183,ice,10526975,snow
|
185 |
+
184,snow_block,16777215,snow
|
186 |
+
185,cactus,31744,plant
|
187 |
+
186,clay,10791096,
|
188 |
+
187,sugar_cane,31744,plant
|
189 |
+
188,jukebox,9402184,
|
190 |
+
189,oak_fence,9402184,
|
191 |
+
190,pumpkin,31744,
|
192 |
+
191,netherrack,7368816,
|
193 |
+
192,soul_sand,16247203,
|
194 |
+
193,glowstone,0,
|
195 |
+
194,nether_portal,0,
|
196 |
+
195,carved_pumpkin,31744,
|
197 |
+
196,jack_o_lantern,31744,
|
198 |
+
197,cake,0,
|
199 |
+
198,repeater,0,
|
200 |
+
199,white_stained_glass,0,
|
201 |
+
200,orange_stained_glass,0,
|
202 |
+
201,magenta_stained_glass,0,
|
203 |
+
202,light_blue_stained_glass,0,
|
204 |
+
203,yellow_stained_glass,0,
|
205 |
+
204,lime_stained_glass,0,
|
206 |
+
205,pink_stained_glass,0,
|
207 |
+
206,gray_stained_glass,0,
|
208 |
+
207,light_gray_stained_glass,0,
|
209 |
+
208,cyan_stained_glass,0,
|
210 |
+
209,purple_stained_glass,0,
|
211 |
+
210,blue_stained_glass,0,
|
212 |
+
211,brown_stained_glass,0,
|
213 |
+
212,green_stained_glass,0,
|
214 |
+
213,red_stained_glass,0,
|
215 |
+
214,black_stained_glass,0,
|
216 |
+
215,oak_trapdoor,9402184,
|
217 |
+
216,spruce_trapdoor,9402184,
|
218 |
+
217,birch_trapdoor,9402184,
|
219 |
+
218,jungle_trapdoor,9402184,
|
220 |
+
219,acacia_trapdoor,9402184,
|
221 |
+
220,dark_oak_trapdoor,9402184,
|
222 |
+
221,stone_bricks,7368816,
|
223 |
+
222,mossy_stone_bricks,7368816,
|
224 |
+
223,cracked_stone_bricks,7368816,
|
225 |
+
224,chiseled_stone_bricks,7368816,
|
226 |
+
225,infested_stone,10791096,
|
227 |
+
226,infested_cobblestone,10791096,
|
228 |
+
227,infested_stone_bricks,10791096,
|
229 |
+
228,infested_mossy_stone_bricks,10791096,
|
230 |
+
229,infested_cracked_stone_bricks,10791096,
|
231 |
+
230,infested_chiseled_stone_bricks,10791096,
|
232 |
+
231,brown_mushroom_block,9402184,tree
|
233 |
+
232,red_mushroom_block,9402184,tree
|
234 |
+
233,mushroom_stem,9402184,tree
|
235 |
+
234,iron_bars,10987431,
|
236 |
+
235,glass_pane,0,
|
237 |
+
236,melon,31744,
|
238 |
+
237,attached_pumpkin_stem,31744,
|
239 |
+
238,attached_melon_stem,31744,
|
240 |
+
239,pumpkin_stem,31744,
|
241 |
+
240,melon_stem,31744,
|
242 |
+
241,vine,31744,plant
|
243 |
+
242,oak_fence_gate,9402184,
|
244 |
+
243,brick_stairs,7368816,
|
245 |
+
244,stone_brick_stairs,7368816,
|
246 |
+
245,mycelium,8368696,
|
247 |
+
246,lily_pad,31744,grass
|
248 |
+
247,nether_bricks,7368816,
|
249 |
+
248,nether_brick_fence,7368816,
|
250 |
+
249,nether_brick_stairs,7368816,
|
251 |
+
250,nether_wart,31744,
|
252 |
+
251,enchanting_table,7368816,
|
253 |
+
252,brewing_stand,10987431,
|
254 |
+
253,cauldron,10987431,
|
255 |
+
254,end_portal,0,
|
256 |
+
255,end_portal_frame,7368816,
|
257 |
+
256,end_stone,7368816,
|
258 |
+
257,dragon_egg,31744,
|
259 |
+
258,redstone_lamp,0,
|
260 |
+
259,cocoa,31744,
|
261 |
+
260,sandstone_stairs,7368816,
|
262 |
+
261,emerald_ore,7368816,
|
263 |
+
262,ender_chest,7368816,
|
264 |
+
263,tripwire_hook,0,
|
265 |
+
264,tripwire,0,
|
266 |
+
265,emerald_block,10987431,
|
267 |
+
266,spruce_stairs,9402184,
|
268 |
+
267,birch_stairs,9402184,
|
269 |
+
268,jungle_stairs,9402184,
|
270 |
+
269,command_block,10987431,
|
271 |
+
270,beacon,0,
|
272 |
+
271,cobblestone_wall,7368816,
|
273 |
+
272,mossy_cobblestone_wall,7368816,
|
274 |
+
273,flower_pot,0,
|
275 |
+
274,potted_oak_sapling,0,
|
276 |
+
275,potted_spruce_sapling,0,
|
277 |
+
276,potted_birch_sapling,0,
|
278 |
+
277,potted_jungle_sapling,0,
|
279 |
+
278,potted_acacia_sapling,0,
|
280 |
+
279,potted_dark_oak_sapling,0,
|
281 |
+
280,potted_fern,0,
|
282 |
+
281,potted_dandelion,0,
|
283 |
+
282,potted_poppy,0,
|
284 |
+
283,potted_blue_orchid,0,
|
285 |
+
284,potted_allium,0,
|
286 |
+
285,potted_azure_bluet,0,
|
287 |
+
286,potted_red_tulip,0,
|
288 |
+
287,potted_orange_tulip,0,
|
289 |
+
288,potted_white_tulip,0,
|
290 |
+
289,potted_pink_tulip,0,
|
291 |
+
290,potted_oxeye_daisy,0,
|
292 |
+
291,potted_cornflower,0,
|
293 |
+
292,potted_lily_of_the_valley,0,
|
294 |
+
293,potted_wither_rose,0,
|
295 |
+
294,potted_red_mushroom,0,
|
296 |
+
295,potted_brown_mushroom,0,
|
297 |
+
296,potted_dead_bush,0,
|
298 |
+
297,potted_cactus,0,
|
299 |
+
298,carrots,31744,
|
300 |
+
299,potatoes,31744,
|
301 |
+
300,oak_button,0,
|
302 |
+
301,spruce_button,0,
|
303 |
+
302,birch_button,0,
|
304 |
+
303,jungle_button,0,
|
305 |
+
304,acacia_button,0,
|
306 |
+
305,dark_oak_button,0,
|
307 |
+
306,skeleton_skull,0,
|
308 |
+
307,skeleton_wall_skull,0,
|
309 |
+
308,wither_skeleton_skull,0,
|
310 |
+
309,wither_skeleton_wall_skull,0,
|
311 |
+
310,zombie_head,0,
|
312 |
+
311,zombie_wall_head,0,
|
313 |
+
312,player_head,0,
|
314 |
+
313,player_wall_head,0,
|
315 |
+
314,creeper_head,0,
|
316 |
+
315,creeper_wall_head,0,
|
317 |
+
316,dragon_head,0,
|
318 |
+
317,dragon_wall_head,0,
|
319 |
+
318,anvil,10987431,
|
320 |
+
319,chipped_anvil,10987431,
|
321 |
+
320,damaged_anvil,10987431,
|
322 |
+
321,trapped_chest,9402184,
|
323 |
+
322,light_weighted_pressure_plate,10987431,
|
324 |
+
323,heavy_weighted_pressure_plate,10987431,
|
325 |
+
324,comparator,0,
|
326 |
+
325,daylight_detector,9402184,
|
327 |
+
326,redstone_block,10987431,
|
328 |
+
327,nether_quartz_ore,7368816,
|
329 |
+
328,hopper,10987431,
|
330 |
+
329,quartz_block,7368816,
|
331 |
+
330,chiseled_quartz_block,7368816,
|
332 |
+
331,quartz_pillar,7368816,
|
333 |
+
332,quartz_stairs,7368816,
|
334 |
+
333,activator_rail,0,
|
335 |
+
334,dropper,7368816,
|
336 |
+
335,white_terracotta,7368816,
|
337 |
+
336,orange_terracotta,7368816,
|
338 |
+
337,magenta_terracotta,7368816,
|
339 |
+
338,light_blue_terracotta,7368816,
|
340 |
+
339,yellow_terracotta,7368816,
|
341 |
+
340,lime_terracotta,7368816,
|
342 |
+
341,pink_terracotta,7368816,
|
343 |
+
342,gray_terracotta,7368816,
|
344 |
+
343,light_gray_terracotta,7368816,
|
345 |
+
344,cyan_terracotta,7368816,
|
346 |
+
345,purple_terracotta,7368816,
|
347 |
+
346,blue_terracotta,7368816,
|
348 |
+
347,brown_terracotta,7368816,
|
349 |
+
348,green_terracotta,7368816,
|
350 |
+
349,red_terracotta,7368816,
|
351 |
+
350,black_terracotta,7368816,
|
352 |
+
351,white_stained_glass_pane,0,
|
353 |
+
352,orange_stained_glass_pane,0,
|
354 |
+
353,magenta_stained_glass_pane,0,
|
355 |
+
354,light_blue_stained_glass_pane,0,
|
356 |
+
355,yellow_stained_glass_pane,0,
|
357 |
+
356,lime_stained_glass_pane,0,
|
358 |
+
357,pink_stained_glass_pane,0,
|
359 |
+
358,gray_stained_glass_pane,0,
|
360 |
+
359,light_gray_stained_glass_pane,0,
|
361 |
+
360,cyan_stained_glass_pane,0,
|
362 |
+
361,purple_stained_glass_pane,0,
|
363 |
+
362,blue_stained_glass_pane,0,
|
364 |
+
363,brown_stained_glass_pane,0,
|
365 |
+
364,green_stained_glass_pane,0,
|
366 |
+
365,red_stained_glass_pane,0,
|
367 |
+
366,black_stained_glass_pane,0,
|
368 |
+
367,acacia_stairs,9402184,
|
369 |
+
368,dark_oak_stairs,9402184,
|
370 |
+
369,slime_block,10791096,
|
371 |
+
370,barrier,0,
|
372 |
+
371,iron_trapdoor,10987431,
|
373 |
+
372,prismarine,7368816,
|
374 |
+
373,prismarine_bricks,7368816,
|
375 |
+
374,dark_prismarine,7368816,
|
376 |
+
375,prismarine_stairs,7368816,
|
377 |
+
376,prismarine_brick_stairs,7368816,
|
378 |
+
377,dark_prismarine_stairs,7368816,
|
379 |
+
378,prismarine_slab,7368816,
|
380 |
+
379,prismarine_brick_slab,7368816,
|
381 |
+
380,dark_prismarine_slab,7368816,
|
382 |
+
381,sea_lantern,0,
|
383 |
+
382,hay_block,8368696,
|
384 |
+
383,white_carpet,13092807,
|
385 |
+
384,orange_carpet,13092807,
|
386 |
+
385,magenta_carpet,13092807,
|
387 |
+
386,light_blue_carpet,13092807,
|
388 |
+
387,yellow_carpet,13092807,
|
389 |
+
388,lime_carpet,13092807,
|
390 |
+
389,pink_carpet,13092807,
|
391 |
+
390,gray_carpet,13092807,
|
392 |
+
391,light_gray_carpet,13092807,
|
393 |
+
392,cyan_carpet,13092807,
|
394 |
+
393,purple_carpet,13092807,
|
395 |
+
394,blue_carpet,13092807,
|
396 |
+
395,brown_carpet,13092807,
|
397 |
+
396,green_carpet,13092807,
|
398 |
+
397,red_carpet,13092807,
|
399 |
+
398,black_carpet,13092807,
|
400 |
+
399,terracotta,7368816,
|
401 |
+
400,coal_block,7368816,
|
402 |
+
401,packed_ice,10526975,
|
403 |
+
402,sunflower,31744,flower
|
404 |
+
403,lilac,31744,flower
|
405 |
+
404,rose_bush,31744,flower
|
406 |
+
405,peony,31744,flower
|
407 |
+
406,tall_grass,31744,plant
|
408 |
+
407,large_fern,31744,plant
|
409 |
+
408,white_banner,9402184,
|
410 |
+
409,orange_banner,9402184,
|
411 |
+
410,magenta_banner,9402184,
|
412 |
+
411,light_blue_banner,9402184,
|
413 |
+
412,yellow_banner,9402184,
|
414 |
+
413,lime_banner,9402184,
|
415 |
+
414,pink_banner,9402184,
|
416 |
+
415,gray_banner,9402184,
|
417 |
+
416,light_gray_banner,9402184,
|
418 |
+
417,cyan_banner,9402184,
|
419 |
+
418,purple_banner,9402184,
|
420 |
+
419,blue_banner,9402184,
|
421 |
+
420,brown_banner,9402184,
|
422 |
+
421,green_banner,9402184,
|
423 |
+
422,red_banner,9402184,
|
424 |
+
423,black_banner,9402184,
|
425 |
+
424,white_wall_banner,9402184,
|
426 |
+
425,orange_wall_banner,9402184,
|
427 |
+
426,magenta_wall_banner,9402184,
|
428 |
+
427,light_blue_wall_banner,9402184,
|
429 |
+
428,yellow_wall_banner,9402184,
|
430 |
+
429,lime_wall_banner,9402184,
|
431 |
+
430,pink_wall_banner,9402184,
|
432 |
+
431,gray_wall_banner,9402184,
|
433 |
+
432,light_gray_wall_banner,9402184,
|
434 |
+
433,cyan_wall_banner,9402184,
|
435 |
+
434,purple_wall_banner,9402184,
|
436 |
+
435,blue_wall_banner,9402184,
|
437 |
+
436,brown_wall_banner,9402184,
|
438 |
+
437,green_wall_banner,9402184,
|
439 |
+
438,red_wall_banner,9402184,
|
440 |
+
439,black_wall_banner,9402184,
|
441 |
+
440,red_sandstone,7368816,
|
442 |
+
441,chiseled_red_sandstone,7368816,
|
443 |
+
442,cut_red_sandstone,7368816,
|
444 |
+
443,red_sandstone_stairs,7368816,
|
445 |
+
444,oak_slab,9402184,
|
446 |
+
445,spruce_slab,9402184,
|
447 |
+
446,birch_slab,9402184,
|
448 |
+
447,jungle_slab,9402184,
|
449 |
+
448,acacia_slab,9402184,
|
450 |
+
449,dark_oak_slab,9402184,
|
451 |
+
450,stone_slab,7368816,
|
452 |
+
451,smooth_stone_slab,7368816,
|
453 |
+
452,sandstone_slab,7368816,
|
454 |
+
453,cut_sandstone_slab,7368816,
|
455 |
+
454,petrified_oak_slab,7368816,
|
456 |
+
455,cobblestone_slab,7368816,
|
457 |
+
456,brick_slab,7368816,
|
458 |
+
457,stone_brick_slab,7368816,
|
459 |
+
458,nether_brick_slab,7368816,
|
460 |
+
459,quartz_slab,7368816,
|
461 |
+
460,red_sandstone_slab,7368816,
|
462 |
+
461,cut_red_sandstone_slab,7368816,
|
463 |
+
462,purpur_slab,7368816,
|
464 |
+
463,smooth_stone,7368816,
|
465 |
+
464,smooth_sandstone,7368816,
|
466 |
+
465,smooth_quartz,7368816,
|
467 |
+
466,smooth_red_sandstone,7368816,
|
468 |
+
467,spruce_fence_gate,9402184,
|
469 |
+
468,birch_fence_gate,9402184,
|
470 |
+
469,jungle_fence_gate,9402184,
|
471 |
+
470,acacia_fence_gate,9402184,
|
472 |
+
471,dark_oak_fence_gate,9402184,
|
473 |
+
472,spruce_fence,9402184,
|
474 |
+
473,birch_fence,9402184,
|
475 |
+
474,jungle_fence,9402184,
|
476 |
+
475,acacia_fence,9402184,
|
477 |
+
476,dark_oak_fence,9402184,
|
478 |
+
477,spruce_door,9402184,
|
479 |
+
478,birch_door,9402184,
|
480 |
+
479,jungle_door,9402184,
|
481 |
+
480,acacia_door,9402184,
|
482 |
+
481,dark_oak_door,9402184,
|
483 |
+
482,end_rod,0,
|
484 |
+
483,chorus_plant,31744,
|
485 |
+
484,chorus_flower,31744,
|
486 |
+
485,purpur_block,7368816,
|
487 |
+
486,purpur_pillar,7368816,
|
488 |
+
487,purpur_stairs,7368816,
|
489 |
+
488,end_stone_bricks,7368816,
|
490 |
+
489,beetroots,31744,
|
491 |
+
490,grass_path,9923917,
|
492 |
+
491,end_gateway,0,
|
493 |
+
492,repeating_command_block,10987431,
|
494 |
+
493,chain_command_block,10987431,
|
495 |
+
494,frosted_ice,10526975,
|
496 |
+
495,magma_block,7368816,
|
497 |
+
496,nether_wart_block,8368696,
|
498 |
+
497,red_nether_bricks,7368816,
|
499 |
+
498,bone_block,7368816,
|
500 |
+
499,structure_void,0,
|
501 |
+
500,observer,7368816,
|
502 |
+
501,shulker_box,8339378,
|
503 |
+
502,white_shulker_box,8339378,
|
504 |
+
503,orange_shulker_box,8339378,
|
505 |
+
504,magenta_shulker_box,8339378,
|
506 |
+
505,light_blue_shulker_box,8339378,
|
507 |
+
506,yellow_shulker_box,8339378,
|
508 |
+
507,lime_shulker_box,8339378,
|
509 |
+
508,pink_shulker_box,8339378,
|
510 |
+
509,gray_shulker_box,8339378,
|
511 |
+
510,light_gray_shulker_box,8339378,
|
512 |
+
511,cyan_shulker_box,8339378,
|
513 |
+
512,purple_shulker_box,8339378,
|
514 |
+
513,blue_shulker_box,8339378,
|
515 |
+
514,brown_shulker_box,8339378,
|
516 |
+
515,green_shulker_box,8339378,
|
517 |
+
516,red_shulker_box,8339378,
|
518 |
+
517,black_shulker_box,8339378,
|
519 |
+
518,white_glazed_terracotta,7368816,
|
520 |
+
519,orange_glazed_terracotta,7368816,
|
521 |
+
520,magenta_glazed_terracotta,7368816,
|
522 |
+
521,light_blue_glazed_terracotta,7368816,
|
523 |
+
522,yellow_glazed_terracotta,7368816,
|
524 |
+
523,lime_glazed_terracotta,7368816,
|
525 |
+
524,pink_glazed_terracotta,7368816,
|
526 |
+
525,gray_glazed_terracotta,7368816,
|
527 |
+
526,light_gray_glazed_terracotta,7368816,
|
528 |
+
527,cyan_glazed_terracotta,7368816,
|
529 |
+
528,purple_glazed_terracotta,7368816,
|
530 |
+
529,blue_glazed_terracotta,7368816,
|
531 |
+
530,brown_glazed_terracotta,7368816,
|
532 |
+
531,green_glazed_terracotta,7368816,
|
533 |
+
532,red_glazed_terracotta,7368816,
|
534 |
+
533,black_glazed_terracotta,7368816,
|
535 |
+
534,white_concrete,7368816,
|
536 |
+
535,orange_concrete,7368816,
|
537 |
+
536,magenta_concrete,7368816,
|
538 |
+
537,light_blue_concrete,7368816,
|
539 |
+
538,yellow_concrete,7368816,
|
540 |
+
539,lime_concrete,7368816,
|
541 |
+
540,pink_concrete,7368816,
|
542 |
+
541,gray_concrete,7368816,
|
543 |
+
542,light_gray_concrete,7368816,
|
544 |
+
543,cyan_concrete,7368816,
|
545 |
+
544,purple_concrete,7368816,
|
546 |
+
545,blue_concrete,7368816,
|
547 |
+
546,brown_concrete,7368816,
|
548 |
+
547,green_concrete,7368816,
|
549 |
+
548,red_concrete,7368816,
|
550 |
+
549,black_concrete,7368816,
|
551 |
+
550,white_concrete_powder,16247203,
|
552 |
+
551,orange_concrete_powder,16247203,
|
553 |
+
552,magenta_concrete_powder,16247203,
|
554 |
+
553,light_blue_concrete_powder,16247203,
|
555 |
+
554,yellow_concrete_powder,16247203,
|
556 |
+
555,lime_concrete_powder,16247203,
|
557 |
+
556,pink_concrete_powder,16247203,
|
558 |
+
557,gray_concrete_powder,16247203,
|
559 |
+
558,light_gray_concrete_powder,16247203,
|
560 |
+
559,cyan_concrete_powder,16247203,
|
561 |
+
560,purple_concrete_powder,16247203,
|
562 |
+
561,blue_concrete_powder,16247203,
|
563 |
+
562,brown_concrete_powder,16247203,
|
564 |
+
563,green_concrete_powder,16247203,
|
565 |
+
564,red_concrete_powder,16247203,
|
566 |
+
565,black_concrete_powder,16247203,
|
567 |
+
566,kelp,4210943,
|
568 |
+
567,kelp_plant,4210943,
|
569 |
+
568,dried_kelp_block,8368696,
|
570 |
+
569,turtle_egg,31744,
|
571 |
+
570,dead_tube_coral_block,7368816,
|
572 |
+
571,dead_brain_coral_block,7368816,
|
573 |
+
572,dead_bubble_coral_block,7368816,
|
574 |
+
573,dead_fire_coral_block,7368816,
|
575 |
+
574,dead_horn_coral_block,7368816,
|
576 |
+
575,tube_coral_block,7368816,
|
577 |
+
576,brain_coral_block,7368816,
|
578 |
+
577,bubble_coral_block,7368816,
|
579 |
+
578,fire_coral_block,7368816,
|
580 |
+
579,horn_coral_block,7368816,
|
581 |
+
580,dead_tube_coral,7368816,
|
582 |
+
581,dead_brain_coral,7368816,
|
583 |
+
582,dead_bubble_coral,7368816,
|
584 |
+
583,dead_fire_coral,7368816,
|
585 |
+
584,dead_horn_coral,7368816,
|
586 |
+
585,tube_coral,4210943,
|
587 |
+
586,brain_coral,4210943,
|
588 |
+
587,bubble_coral,4210943,
|
589 |
+
588,fire_coral,4210943,
|
590 |
+
589,horn_coral,4210943,
|
591 |
+
590,dead_tube_coral_fan,7368816,
|
592 |
+
591,dead_brain_coral_fan,7368816,
|
593 |
+
592,dead_bubble_coral_fan,7368816,
|
594 |
+
593,dead_fire_coral_fan,7368816,
|
595 |
+
594,dead_horn_coral_fan,7368816,
|
596 |
+
595,tube_coral_fan,4210943,
|
597 |
+
596,brain_coral_fan,4210943,
|
598 |
+
597,bubble_coral_fan,4210943,
|
599 |
+
598,fire_coral_fan,4210943,
|
600 |
+
599,horn_coral_fan,4210943,
|
601 |
+
600,dead_tube_coral_wall_fan,7368816,
|
602 |
+
601,dead_brain_coral_wall_fan,7368816,
|
603 |
+
602,dead_bubble_coral_wall_fan,7368816,
|
604 |
+
603,dead_fire_coral_wall_fan,7368816,
|
605 |
+
604,dead_horn_coral_wall_fan,7368816,
|
606 |
+
605,tube_coral_wall_fan,4210943,
|
607 |
+
606,brain_coral_wall_fan,4210943,
|
608 |
+
607,bubble_coral_wall_fan,4210943,
|
609 |
+
608,fire_coral_wall_fan,4210943,
|
610 |
+
609,horn_coral_wall_fan,4210943,
|
611 |
+
610,sea_pickle,4210943,
|
612 |
+
611,blue_ice,10526975,
|
613 |
+
612,conduit,0,
|
614 |
+
613,bamboo_sapling,9402184,plant
|
615 |
+
614,bamboo,9402184,plant
|
616 |
+
615,potted_bamboo,0,
|
617 |
+
616,void_air,0,dirt
|
618 |
+
617,cave_air,0,dirt
|
619 |
+
618,bubble_column,4210943,
|
620 |
+
619,polished_granite_stairs,7368816,
|
621 |
+
620,smooth_red_sandstone_stairs,7368816,
|
622 |
+
621,mossy_stone_brick_stairs,7368816,
|
623 |
+
622,polished_diorite_stairs,7368816,
|
624 |
+
623,mossy_cobblestone_stairs,7368816,
|
625 |
+
624,end_stone_brick_stairs,7368816,
|
626 |
+
625,stone_stairs,7368816,
|
627 |
+
626,smooth_sandstone_stairs,7368816,
|
628 |
+
627,smooth_quartz_stairs,7368816,
|
629 |
+
628,granite_stairs,7368816,
|
630 |
+
629,andesite_stairs,7368816,
|
631 |
+
630,red_nether_brick_stairs,7368816,
|
632 |
+
631,polished_andesite_stairs,7368816,
|
633 |
+
632,diorite_stairs,7368816,
|
634 |
+
633,polished_granite_slab,7368816,
|
635 |
+
634,smooth_red_sandstone_slab,7368816,
|
636 |
+
635,mossy_stone_brick_slab,7368816,
|
637 |
+
636,polished_diorite_slab,7368816,
|
638 |
+
637,mossy_cobblestone_slab,7368816,
|
639 |
+
638,end_stone_brick_slab,7368816,
|
640 |
+
639,smooth_sandstone_slab,7368816,
|
641 |
+
640,smooth_quartz_slab,7368816,
|
642 |
+
641,granite_slab,7368816,
|
643 |
+
642,andesite_slab,7368816,
|
644 |
+
643,red_nether_brick_slab,7368816,
|
645 |
+
644,polished_andesite_slab,7368816,
|
646 |
+
645,diorite_slab,7368816,
|
647 |
+
646,brick_wall,7368816,
|
648 |
+
647,prismarine_wall,7368816,
|
649 |
+
648,red_sandstone_wall,7368816,
|
650 |
+
649,mossy_stone_brick_wall,7368816,
|
651 |
+
650,granite_wall,7368816,
|
652 |
+
651,stone_brick_wall,7368816,
|
653 |
+
652,nether_brick_wall,7368816,
|
654 |
+
653,andesite_wall,7368816,
|
655 |
+
654,red_nether_brick_wall,7368816,
|
656 |
+
655,sandstone_wall,7368816,
|
657 |
+
656,end_stone_brick_wall,7368816,
|
658 |
+
657,diorite_wall,7368816,
|
659 |
+
658,scaffolding,0,
|
660 |
+
659,loom,9402184,
|
661 |
+
660,barrel,9402184,
|
662 |
+
661,smoker,7368816,
|
663 |
+
662,blast_furnace,7368816,
|
664 |
+
663,cartography_table,9402184,
|
665 |
+
664,fletching_table,9402184,
|
666 |
+
665,grindstone,10987431,
|
667 |
+
666,lectern,9402184,
|
668 |
+
667,smithing_table,9402184,
|
669 |
+
668,stonecutter,7368816,
|
670 |
+
669,bell,10987431,
|
671 |
+
670,lantern,10987431,
|
672 |
+
671,campfire,9402184,
|
673 |
+
672,sweet_berry_bush,31744,
|
674 |
+
673,structure_block,10987431,
|
675 |
+
674,jigsaw,10987431,
|
676 |
+
675,composter,9402184,
|
677 |
+
676,bee_nest,9402184,
|
678 |
+
677,beehive,9402184,
|
679 |
+
678,honey_block,10791096,
|
680 |
+
679,honeycomb_block,10791096,
|
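The rows above appear to map an integer block id to a block name, a packed 24-bit RGB colour (e.g. 31744 = 0x007C00, a green), and an optional coarse semantic class such as tree, grass, sand or water. Below is a minimal sketch of how such an "id,name,packed_rgb,label" table could be loaded into Python lookup dictionaries; the helper name, the path and the column interpretation are assumptions for illustration, not code from this repository.

    # Minimal sketch (not from the repo): parse the block table into lookups.
    import csv

    def load_block_table(path):
        id2name, id2rgb, id2label = {}, {}, {}
        with open(path, newline="") as f:
            for row in csv.reader(f):
                if len(row) < 4 or not row[0].isdigit():
                    continue  # skip a header or malformed row
                idx, name, packed, label = int(row[0]), row[1], int(row[2]), row[3]
                id2name[idx] = name
                # the third column appears to be a packed 24-bit RGB value
                id2rgb[idx] = ((packed >> 16) & 0xFF, (packed >> 8) & 0xFF, packed & 0xFF)
                id2label[idx] = label if label else None  # empty field -> no coarse class
        return id2name, id2rgb, id2label

    # Example call (path is illustrative):
    # names, colors, classes = load_block_table("id2name_gg.csv")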
imaginaire/model_utils/gancraft/loss.py
ADDED
@@ -0,0 +1,96 @@
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import torch
import torch.nn as nn
import torch.nn.functional as F


class GANLoss(nn.Module):
    def __init__(self, target_real_label=1.0, target_fake_label=0.0):
        r"""GAN loss constructor.

        Args:
            target_real_label (float): Desired output label for the real images.
            target_fake_label (float): Desired output label for the fake images.
        """
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_tensor = None
        self.fake_label_tensor = None

    def forward(self, input_x, t_real, weight=None,
                reduce_dim=True, dis_update=True):
        r"""GAN loss computation.

        Args:
            input_x (tensor or list of tensors): Output values.
            t_real (boolean): Is this output value for real images.
            reduce_dim (boolean): Whether we reduce the dimensions first. This makes a difference when we use
                multi-resolution discriminators.
            weight (float): Weight to scale the loss value.
            dis_update (boolean): Updating the discriminator or the generator.
        Returns:
            loss (tensor): Loss value.
        """
        if isinstance(input_x, list):
            loss = 0
            for pred_i in input_x:
                if isinstance(pred_i, list):
                    pred_i = pred_i[-1]
                loss_tensor = self.loss(pred_i, t_real, weight,
                                        reduce_dim, dis_update)
                bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0)
                new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1)
                loss += new_loss
            return loss / len(input_x)
        else:
            return self.loss(input_x, t_real, weight, reduce_dim, dis_update)

    def loss(self, input_x, t_real, weight=None,
             reduce_dim=True, dis_update=True):
        r"""N+1 label GAN loss computation.

        Args:
            input_x (tensor): Output values.
            t_real (boolean): Is this output value for real images.
            reduce_dim (boolean): Whether we reduce the dimensions first. This makes a difference when we use
                multi-resolution discriminators.
            weight (float): Weight to scale the loss value.
            dis_update (boolean): Updating the discriminator or the generator.
        Returns:
            loss (tensor): Loss value.
        """
        assert reduce_dim is True
        pred = input_x['pred'].clone()
        label = input_x['label'].clone()
        batch_size = pred.size(0)

        # ignore label 0
        label[:, 0, ...] = 0
        pred[:, 0, ...] = 0
        pred = F.log_softmax(pred, dim=1)
        assert pred.size(1) == (label.size(1) + 1)
        if dis_update:
            if t_real:
                pred_real = pred[:, :-1, :, :]
                loss = - label * pred_real
                loss = torch.sum(loss, dim=1, keepdim=True)
            else:
                pred_fake = pred[:, -1, None, :, :]  # N plus 1
                loss = - pred_fake
        else:
            assert t_real, "GAN loss must be aiming for real."
            pred_real = pred[:, :-1, :, :]
            loss = - label * pred_real
            loss = torch.sum(loss, dim=1, keepdim=True)

        if weight is not None:
            loss = loss * weight
        if reduce_dim:
            loss = torch.mean(loss)
        else:
            loss = loss.view(batch_size, -1).mean(dim=1)
        return loss
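Assuming the GANLoss class added above (imaginaire/model_utils/gancraft/loss.py) is importable, a minimal call sketch follows. The batch size, spatial size and the choice of 5 semantic classes plus one extra "fake" channel are illustrative assumptions, not values taken from the training configs.

    import torch
    from imaginaire.model_utils.gancraft.loss import GANLoss  # file added above

    gan_loss = GANLoss()

    # Illustrative shapes: discriminator output with N+1 = 6 channels
    # (5 semantic classes + 1 "fake" channel) and a 5-channel one-hot label map.
    out = {'pred': torch.randn(2, 6, 8, 8),
           'label': torch.zeros(2, 5, 8, 8)}
    out['label'][:, 1] = 1.0

    d_real = gan_loss(out, t_real=True, dis_update=True)    # discriminator, real branch
    d_fake = gan_loss(out, t_real=False, dis_update=True)   # discriminator, fake branch
    g_loss = gan_loss(out, t_real=True, dis_update=False)   # generator update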