gabrielsemiceki9 committed on
Commit b72e09b · verified · 1 Parent(s): 5f7ebb7

Upload 125 files

This view is limited to 50 files because it contains too many changes. See the raw diff for the complete change set.
Files changed (50)
  1. .gitattributes +2 -0
  2. activation.py +18 -0
  3. app_gradio.py +137 -0
  4. assets/biome_image.png +0 -0
  5. assets/sample_traj.gif +3 -0
  6. assets/teaser.gif +3 -0
  7. configs/img2lmdb.yaml +177 -0
  8. configs/landscape1m.yaml +175 -0
  9. configs/scenedreamer_inference.yaml +93 -0
  10. configs/scenedreamer_train.yaml +223 -0
  11. encoding.py +67 -0
  12. environment.yaml +44 -0
  13. gridencoder/__init__.py +1 -0
  14. gridencoder/backend.py +40 -0
  15. gridencoder/grid.py +224 -0
  16. gridencoder/setup.py +50 -0
  17. gridencoder/src/bindings.cpp +8 -0
  18. gridencoder/src/gridencoder.cu +478 -0
  19. gridencoder/src/gridencoder.h +15 -0
  20. imaginaire/__init__.py +4 -0
  21. imaginaire/config.py +238 -0
  22. imaginaire/discriminators/__init__.py +0 -0
  23. imaginaire/discriminators/gancraft.py +278 -0
  24. imaginaire/generators/__init__.py +4 -0
  25. imaginaire/generators/gancraft_base.py +603 -0
  26. imaginaire/generators/scenedreamer.py +851 -0
  27. imaginaire/generators/spade.py +571 -0
  28. imaginaire/layers/__init__.py +27 -0
  29. imaginaire/layers/activation_norm.py +629 -0
  30. imaginaire/layers/conv.py +1377 -0
  31. imaginaire/layers/misc.py +61 -0
  32. imaginaire/layers/non_local.py +88 -0
  33. imaginaire/layers/nonlinearity.py +65 -0
  34. imaginaire/layers/residual.py +1411 -0
  35. imaginaire/layers/residual_deep.py +346 -0
  36. imaginaire/layers/vit.py +204 -0
  37. imaginaire/layers/weight_norm.py +267 -0
  38. imaginaire/losses/__init__.py +18 -0
  39. imaginaire/losses/feature_matching.py +38 -0
  40. imaginaire/losses/gan.py +173 -0
  41. imaginaire/losses/info_nce.py +87 -0
  42. imaginaire/losses/kl.py +23 -0
  43. imaginaire/losses/perceptual.py +395 -0
  44. imaginaire/losses/weighted_mse.py +28 -0
  45. imaginaire/model_utils/__init__.py +4 -0
  46. imaginaire/model_utils/gancraft/camctl.py +679 -0
  47. imaginaire/model_utils/gancraft/gaugan_lbl2col.csv +182 -0
  48. imaginaire/model_utils/gancraft/gaugan_reduction.csv +182 -0
  49. imaginaire/model_utils/gancraft/id2name_gg.csv +680 -0
  50. imaginaire/model_utils/gancraft/loss.py +96 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/sample_traj.gif filter=lfs diff=lfs merge=lfs -text
+assets/teaser.gif filter=lfs diff=lfs merge=lfs -text
activation.py ADDED
@@ -0,0 +1,18 @@
import torch
from torch.autograd import Function
from torch.cuda.amp import custom_bwd, custom_fwd

class _trunc_exp(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)  # cast to float32
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.exp(x)

    @staticmethod
    @custom_bwd
    def backward(ctx, g):
        x = ctx.saved_tensors[0]
        return g * torch.exp(x.clamp(-15, 15))

trunc_exp = _trunc_exp.apply
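`trunc_exp` behaves like `torch.exp` in the forward pass (always evaluated in float32 under autocast) but clamps the input to [-15, 15] when computing the gradient, which keeps mixed-precision training from producing inf/NaN gradients on large activations. A minimal usage sketch (the tensor names are illustrative, not from the repo):

import torch
from activation import trunc_exp  # assuming this file is importable as `activation`

raw_density = torch.randn(4096, device='cuda', requires_grad=True)
density = trunc_exp(raw_density)   # forward: exact exp
density.sum().backward()           # backward: gradient uses exp(clamp(x, -15, 15))
assert torch.isfinite(raw_density.grad).all()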
app_gradio.py ADDED
@@ -0,0 +1,137 @@
import os
import torch
import torch.nn as nn
import importlib
import argparse
from imaginaire.config import Config
from imaginaire.utils.cudnn import init_cudnn
import gradio as gr
from PIL import Image


class WrappedModel(nn.Module):
    r"""Dummy wrapper around the module."""

    def __init__(self, module):
        super(WrappedModel, self).__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        r"""PyTorch module forward function overload."""
        return self.module(*args, **kwargs)


def parse_args():
    parser = argparse.ArgumentParser(description='Training')
    parser.add_argument('--config', type=str, default='./configs/scenedreamer_inference.yaml',
                        help='Path to the training config file.')
    parser.add_argument('--checkpoint', default='./scenedreamer_released.pt',
                        help='Checkpoint path.')
    parser.add_argument('--output_dir', type=str, default='./test/',
                        help='Location to save the image outputs')
    parser.add_argument('--seed', type=int, default=8888,
                        help='Random seed.')
    args = parser.parse_args()
    return args


args = parse_args()
cfg = Config(args.config)

# Initialize cudnn.
init_cudnn(cfg.cudnn.deterministic, cfg.cudnn.benchmark)

# Initialize data loaders and models.
lib_G = importlib.import_module(cfg.gen.type)
net_G = lib_G.Generator(cfg.gen, cfg.data)
net_G = net_G.to('cuda')
net_G = WrappedModel(net_G)

if args.checkpoint == '':
    raise NotImplementedError("No checkpoint is provided for inference!")

# Load checkpoint.
# trainer.load_checkpoint(cfg, args.checkpoint)
checkpoint = torch.load(args.checkpoint, map_location='cpu')
net_G.load_state_dict(checkpoint['net_G'])

# Do inference.
net_G = net_G.module
net_G.eval()
for name, param in net_G.named_parameters():
    param.requires_grad = False
torch.cuda.empty_cache()
world_dir = os.path.join(args.output_dir)
os.makedirs(world_dir, exist_ok=True)


def get_bev(seed):
    print('[PCGGenerator] Generating BEV scene representation...')
    os.system('python terrain_generator.py --size {} --seed {} --outdir {}'.format(
        net_G.voxel.sample_size, seed, world_dir))
    heightmap_path = os.path.join(world_dir, 'heightmap.png')
    semantic_path = os.path.join(world_dir, 'colormap.png')
    heightmap = Image.open(heightmap_path)
    semantic = Image.open(semantic_path)
    return semantic, heightmap


def get_video(seed, num_frames, reso_h, reso_w):
    device = torch.device('cuda')
    rng_cuda = torch.Generator(device=device)
    rng_cuda = rng_cuda.manual_seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    net_G.voxel.next_world(device, world_dir, checkpoint)
    cam_mode = cfg.inference_args.camera_mode
    cfg.inference_args.cam_maxstep = num_frames
    cfg.inference_args.resolution_hw = [reso_h, reso_w]
    current_outdir = os.path.join(world_dir, 'camera_{:02d}'.format(cam_mode))
    os.makedirs(current_outdir, exist_ok=True)
    z = torch.empty(1, net_G.style_dims, dtype=torch.float32, device=device)
    z.normal_(generator=rng_cuda)
    net_G.inference_givenstyle(z, current_outdir, **vars(cfg.inference_args))
    return os.path.join(current_outdir, 'rgb_render.mp4')


markdown = f'''
# SceneDreamer: Unbounded 3D Scene Generation from 2D Image Collections

Authored by Zhaoxi Chen, Guangcong Wang, Ziwei Liu
### Useful links:
- [Official Github Repo](https://github.com/FrozenBurning/SceneDreamer)
- [Project Page](https://scene-dreamer.github.io/)
- [arXiv Link](https://arxiv.org/abs/2302.01330)
Licensed under the S-Lab License.
We provide a pre-sampled scene whose BEV maps are shown on the right. You can also press the "Generate BEV" button to randomly sample a new 3D world, represented by a height map and a semantic map, but note that this takes a while.

To render a video, press the "Render" button to generate a camera trajectory flying through the world. You can adjust the rendering options shown below!
'''

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown(markdown)
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    semantic = gr.Image(value='./test/colormap.png', type="pil", shape=(512, 512))
                with gr.Column():
                    height = gr.Image(value='./test/heightmap.png', type="pil", shape=(512, 512))
            with gr.Row():
                # with gr.Column():
                #     image = gr.Image(type='pil', shape=(540, 960))
                with gr.Column():
                    video = gr.Video()
            with gr.Row():
                num_frames = gr.Slider(minimum=10, maximum=200, value=20, step=1, label='Number of rendered frames')
                user_seed = gr.Slider(minimum=0, maximum=999999, value=8888, step=1, label='Random seed')
                resolution_h = gr.Slider(minimum=256, maximum=2160, value=270, step=1, label='Height of rendered image')
                resolution_w = gr.Slider(minimum=256, maximum=3840, value=480, step=1, label='Width of rendered image')

            with gr.Row():
                btn = gr.Button(value="Generate BEV")
                btn_2 = gr.Button(value="Render")

    btn.click(get_bev, [user_seed], [semantic, height])
    btn_2.click(get_video, [user_seed, num_frames, resolution_h, resolution_w], [video])

demo.launch(debug=True)
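Given the argparse defaults above, the demo can presumably be launched from the repository root with `python app_gradio.py --checkpoint ./scenedreamer_released.pt` (the config path defaults to `./configs/scenedreamer_inference.yaml`).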
assets/biome_image.png ADDED
assets/sample_traj.gif ADDED

Git LFS Details

  • SHA256: 9bff3115871f1a78fbe237b44ba51e1bf3eb0578c2850815b0b71dd012c3be6a
  • Pointer size: 132 Bytes
  • Size of remote file: 6.58 MB
assets/teaser.gif ADDED

Git LFS Details

  • SHA256: 12c4e75a99a9a5dc17b89fc1c272fa994068bdd83e0636202f001e875f203a05
  • Pointer size: 133 Bytes
  • Size of remote file: 24.5 MB
configs/img2lmdb.yaml ADDED
@@ -0,0 +1,177 @@
1
+ inference_args:
2
+ random_style: True
3
+ use_fixed_random_style: False
4
+ keep_original_size: True
5
+
6
+ image_save_iter: 5000
7
+ snapshot_save_epoch: 5
8
+ max_epoch: 400
9
+ logging_iter: 100
10
+ trainer:
11
+ type: imaginaire.trainers.spade
12
+ model_average_config:
13
+ enabled: True
14
+ beta: 0.9999
15
+ start_iteration: 1000
16
+ num_batch_norm_estimation_iterations: 30
17
+ amp_config:
18
+ enabled: True
19
+ gan_mode: hinge
20
+ gan_relativistic: False
21
+ perceptual_loss:
22
+ mode: 'vgg19'
23
+ layers: ['relu_1_1', 'relu_2_1', 'relu_3_1', 'relu_4_1', 'relu_5_1']
24
+ weights: [0.03125, 0.0625, 0.125, 0.25, 1.0]
25
+ fp16: True
26
+ loss_weight:
27
+ gan: 1.0
28
+ perceptual: 10.0
29
+ feature_matching: 10.0
30
+ kl: 0.05
31
+ init:
32
+ type: xavier
33
+ gain: 0.02
34
+ gen_opt:
35
+ type: adam
36
+ lr: 0.0001
37
+ adam_beta1: 0.
38
+ adam_beta2: 0.999
39
+ lr_policy:
40
+ iteration_mode: False
41
+ type: step
42
+ step_size: 400
43
+ gamma: 0.1
44
+ dis_opt:
45
+ type: adam
46
+ lr: 0.0004
47
+ adam_beta1: 0.
48
+ adam_beta2: 0.999
49
+ lr_policy:
50
+ iteration_mode: False
51
+ type: step
52
+ step_size: 400
53
+ gamma: 0.1
54
+ gen:
55
+ type: imaginaire.generators.spade
56
+ version: v20
57
+ style_dims: 256
58
+ num_filters: 128
59
+ kernel_size: 3
60
+ weight_norm_type: 'spectral'
61
+ use_posenc_in_input_layer: False
62
+ global_adaptive_norm_type: 'sync_batch'
63
+ activation_norm_params:
64
+ num_filters: 128
65
+ kernel_size: 5
66
+ separate_projection: True
67
+ activation_norm_type: 'sync_batch'
68
+ style_enc:
69
+ num_filters: 64
70
+ kernel_size: 3
71
+ dis:
72
+ type: imaginaire.discriminators.spade
73
+ kernel_size: 4
74
+ num_filters: 128
75
+ max_num_filters: 512
76
+ num_discriminators: 2
77
+ num_layers: 5
78
+ activation_norm_type: 'none'
79
+ weight_norm_type: 'spectral'
80
+
81
+ # Data options.
82
+ data:
83
+ type: imaginaire.datasets.paired_images
84
+ # How many data loading workers per GPU?
85
+ num_workers: 8
86
+ input_types:
87
+ - images:
88
+ ext: jpg
89
+ num_channels: 3
90
+ normalize: True
91
+ use_dont_care: False
92
+ - seg_maps:
93
+ ext: jpg
94
+ num_channels: 1
95
+ is_mask: True
96
+ normalize: False
97
+ # - edge_maps:
98
+ # ext: png
99
+ # num_channels: 1
100
+ # normalize: False
101
+
102
+ full_data_ops: imaginaire.model_utils.label::make_one_hot, imaginaire.model_utils.label::concat_labels
103
+ use_dont_care: True
104
+ one_hot_num_classes:
105
+ seg_maps: 183
106
+ input_labels:
107
+ - seg_maps
108
+ # - edge_maps
109
+
110
+ # Which lmdb contains the ground truth image.
111
+ input_image:
112
+ - images
113
+
114
+ # Train dataset details.
115
+ train:
116
+ # Input LMDBs.
117
+ roots:
118
+ - ./data/lhq/train
119
+ # Batch size per GPU.
120
+ batch_size: 4
121
+ # Data augmentations to be performed in given order.
122
+ augmentations:
123
+ resize_smallest_side: 256
124
+ # Rotate in (-rotate, rotate) in degrees.
125
+ rotate: 0
126
+ # Scale image by factor \in [1, 1+random_scale_limit].
127
+ random_scale_limit: 0.2
128
+ # Horizontal flip?
129
+ horizontal_flip: True
130
+ # Crop size.
131
+ random_crop_h_w: 256, 256
132
+ # Validation dataset details.
133
+ val:
134
+ # Input LMDBs.
135
+ roots:
136
+ - ./data/lhq/val
137
+ # Batch size per GPU.
138
+ batch_size: 4
139
+ # Data augmentations to be performed in given order.
140
+ augmentations:
141
+ # Crop size.
142
+ resize_h_w: 256, 256
143
+
144
+ test_data:
145
+ type: imaginaire.datasets.paired_images
146
+ num_workers: 8
147
+ input_types:
148
+ - seg_maps:
149
+ ext: jpg
150
+ num_channels: 1
151
+ is_mask: True
152
+ normalize: False
153
+ # - edge_maps:
154
+ # ext: png
155
+ # num_channels: 1
156
+ # normalize: False
157
+
158
+ full_data_ops: imaginaire.model_utils.label::make_one_hot, imaginaire.model_utils.label::concat_labels
159
+ use_dont_care: True
160
+ one_hot_num_classes:
161
+ seg_maps: 183
162
+ input_labels:
163
+ - seg_maps
164
+ # - edge_maps
165
+
166
+ paired: True
167
+ # Test dataset details.
168
+ test:
169
+ is_lmdb: False
170
+ roots:
171
+ - ./data/lhq/train
172
+ # Batch size per GPU.
173
+ batch_size: 1
174
+ # If resize_h_w is not given, then it is assumed to be same as crop_h_w.
175
+ augmentations:
176
+ resize_h_w: 256, 256
177
+ horizontal_flip: False
configs/landscape1m.yaml ADDED
@@ -0,0 +1,175 @@
1
+ pretrained_weight: ./landscape1m-segformer.pt
2
+
3
+ inference_args:
4
+ random_style: True
5
+ use_fixed_random_style: False
6
+ keep_original_size: True
7
+
8
+ image_save_iter: 5000
9
+ snapshot_save_epoch: 5
10
+ snapshot_save_iter: 30000
11
+ max_epoch: 400
12
+ logging_iter: 100
13
+ trainer:
14
+ type: imaginaire.trainers.spade
15
+ model_average_config:
16
+ enabled: True
17
+ beta: 0.9999
18
+ start_iteration: 1000
19
+ num_batch_norm_estimation_iterations: 30
20
+ amp_config:
21
+ enabled: True
22
+ gan_mode: hinge
23
+ gan_relativistic: False
24
+ perceptual_loss:
25
+ mode: 'vgg19'
26
+ layers: ['relu_1_1', 'relu_2_1', 'relu_3_1', 'relu_4_1', 'relu_5_1']
27
+ weights: [0.03125, 0.0625, 0.125, 0.25, 1.0]
28
+ fp16: True
29
+ loss_weight:
30
+ gan: 1.0
31
+ perceptual: 10.0
32
+ feature_matching: 10.0
33
+ kl: 0.05
34
+ init:
35
+ type: xavier
36
+ gain: 0.02
37
+ gen_opt:
38
+ type: adam
39
+ lr: 0.0001
40
+ adam_beta1: 0.
41
+ adam_beta2: 0.999
42
+ lr_policy:
43
+ iteration_mode: False
44
+ type: step
45
+ step_size: 400
46
+ gamma: 0.1
47
+ dis_opt:
48
+ type: adam
49
+ lr: 0.0004
50
+ adam_beta1: 0.
51
+ adam_beta2: 0.999
52
+ lr_policy:
53
+ iteration_mode: False
54
+ type: step
55
+ step_size: 400
56
+ gamma: 0.1
57
+ gen:
58
+ type: imaginaire.generators.spade
59
+ version: v20
60
+ output_multiplier: 0.5
61
+ image_channels: 3
62
+ num_labels: 184
63
+ style_dims: 256
64
+ num_filters: 128
65
+ kernel_size: 3
66
+ weight_norm_type: 'spectral'
67
+ use_posenc_in_input_layer: False
68
+ global_adaptive_norm_type: 'sync_batch'
69
+ activation_norm_params:
70
+ num_filters: 128
71
+ kernel_size: 5
72
+ separate_projection: True
73
+ activation_norm_type: 'sync_batch'
74
+ style_enc:
75
+ num_filters: 64
76
+ kernel_size: 3
77
+ dis:
78
+ type: imaginaire.discriminators.spade
79
+ kernel_size: 4
80
+ num_filters: 128
81
+ max_num_filters: 512
82
+ num_discriminators: 2
83
+ num_layers: 5
84
+ activation_norm_type: 'none'
85
+ weight_norm_type: 'spectral'
86
+
87
+ # Data options.
88
+ data:
89
+ type: imaginaire.datasets.paired_images
90
+ # How many data loading workers per GPU?
91
+ num_workers: 8
92
+ input_types:
93
+ - images:
94
+ ext: jpg
95
+ num_channels: 3
96
+ normalize: True
97
+ use_dont_care: False
98
+ - seg_maps:
99
+ ext: jpg
100
+ num_channels: 1
101
+ is_mask: True
102
+ normalize: False
103
+
104
+ full_data_ops: imaginaire.model_utils.label::make_one_hot, imaginaire.model_utils.label::concat_labels
105
+ use_dont_care: True
106
+ one_hot_num_classes:
107
+ seg_maps: 183
108
+ input_labels:
109
+ - seg_maps
110
+
111
+ # Which lmdb contains the ground truth image.
112
+ input_image:
113
+ - images
114
+
115
+ # Train dataset details.
116
+ train:
117
+ # Input LMDBs.
118
+ dataset_type: lmdb
119
+ roots:
120
+ - ./data/lhq_lmdb/train
121
+ # Batch size per GPU.
122
+ batch_size: 4
123
+ # Data augmentations to be performed in given order.
124
+ augmentations:
125
+ resize_smallest_side: 512
126
+ # Rotate in (-rotate, rotate) in degrees.
127
+ rotate: 0
128
+ # Scale image by factor \in [1, 1+random_scale_limit].
129
+ random_scale_limit: 0.2
130
+ # Horizontal flip?
131
+ horizontal_flip: True
132
+ # Crop size.
133
+ random_crop_h_w: 512, 512
134
+ # Validation dataset details.
135
+ val:
136
+ dataset_type: lmdb
137
+ # Input LMDBs.
138
+ roots:
139
+ - ./data/lhq_lmdb/val
140
+ # Batch size per GPU.
141
+ batch_size: 4
142
+ # Data augmentations to be performed in given order.
143
+ augmentations:
144
+ # Crop size.
145
+ resize_h_w: 512, 512
146
+
147
+ test_data:
148
+ type: imaginaire.datasets.paired_images
149
+ num_workers: 8
150
+ input_types:
151
+ - seg_maps:
152
+ ext: jpg
153
+ num_channels: 1
154
+ is_mask: True
155
+ normalize: False
156
+
157
+ full_data_ops: imaginaire.model_utils.label::make_one_hot, imaginaire.model_utils.label::concat_labels
158
+ use_dont_care: True
159
+ one_hot_num_classes:
160
+ seg_maps: 183
161
+ input_labels:
162
+ - seg_maps
163
+
164
+ paired: True
165
+ # Test dataset details.
166
+ test:
167
+ is_lmdb: True
168
+ roots:
169
+ - ./data/lhq_lmdb/val
170
+ # Batch size per GPU.
171
+ batch_size: 1
172
+ # If resize_h_w is not given, then it is assumed to be same as crop_h_w.
173
+ augmentations:
174
+ resize_h_w: 256, 256
175
+ horizontal_flip: False
configs/scenedreamer_inference.yaml ADDED
@@ -0,0 +1,93 @@
inference_args:
    # 0: Camera orbiting the scene & looking at the center
    # 1: Camera orbiting the scene & zooming in
    # 2: Camera orbiting the scene & coming closer and closer to the center
    # 3: Similar to 2, but the camera orbits in the opposite direction
    # 4: Similar to 2, but the camera stays further away from the center
    # 5: Camera sits at the center and looks outwards
    # 6: Camera rises while looking down
    # 7: Camera really far away, looking down at a 45deg angle
    # 8: Camera for perpetual view generation, non-sliding window
    # 9: Camera for infinite world generation, sliding window
    camera_mode: 4

    cam_maxstep: 40
    resolution_hw: [540, 960]
    num_samples: 40
    cam_ang: 72

gen:
    type: imaginaire.generators.scenedreamer
    pcg_dataset_path: None
    pcg_cache: False
    scene_size: 2048

    blk_feat_dim: 64

    pe_lvl_feat: 4
    pe_incl_orig_feat: False
    pe_no_pe_feat_dim: 40
    pe_lvl_raydir: 0
    pe_incl_orig_raydir: False
    style_dims: 128 # Set to 0 to disable style.
    interm_style_dims: 256
    final_feat_dim: 64

    # Number of pixels removed from each edge to reduce the boundary artifact of the CNN,
    # both sides combined (8 -> 4 on the left and 4 on the right).
    pad: 6

    # ======== Sky network ========
    pe_lvl_raydir_sky: 5
    pe_incl_orig_raydir_sky: True

    # ======== Style Encoder =========
    # Comment out to disable the style encoder.
    style_enc:
        num_filters: 64
        kernel_size: 3
        weight_norm_type: 'none'

    stylenet_model: StyleMLP
    stylenet_model_kwargs:
        normalize_input: True
        num_layers: 5

    mlp_model: RenderMLP
    mlp_model_kwargs:
        use_seg: True

    # ======== Ray Casting Params ========
    num_blocks_early_stop: 6
    num_samples: 24 # The original model uses 24. Reduce to 4 to allow training on 12GB GPUs (with a significant performance penalty).
    sample_depth: 3 # Stop the ray after a certain depth.
    coarse_deterministic_sampling: False
    sample_use_box_boundaries: False # Include voxel boundaries in the samples.

    # ======== Blender ========
    raw_noise_std: 0.0
    dists_scale: 0.25
    clip_feat_map: True
    # Prevent the sky from leaking into the foreground.
    keep_sky_out: True
    keep_sky_out_avgpool: True
    sky_global_avgpool: True

    # ======== Label translator ========
    reduced_label_set: True
    use_label_smooth: True
    use_label_smooth_real: True
    use_label_smooth_pgt: True
    label_smooth_dia: 11

    # ======== Camera sampler ========
    camera_sampler_type: 'traditional'
    cam_res: [360, 640] # Camera resolution before cropping.
    crop_size: [256, 256] # The actual crop size is crop_size + pad. It should generally match random_crop_h_w in the dataloader.

    # Threshold for rejecting camera poses that would result in a seg mask with low entropy.
    # Generally, 0.5 min, 0.8 max.
    camera_min_entropy: 0.75

    # Threshold for rejecting camera poses that are too close to the objects.
    camera_rej_avg_depth: 2.0
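The fields under `inference_args` are exactly what `app_gradio.py` forwards to the generator: the config is loaded with `imaginaire.config.Config` and the (possibly overridden) arguments are splatted into `inference_givenstyle`. A minimal sketch of that flow, assuming a constructed `net_G` as in `app_gradio.py`:

from imaginaire.config import Config

cfg = Config('./configs/scenedreamer_inference.yaml')

# Override a few knobs before rendering, as the Gradio callbacks do.
cfg.inference_args.camera_mode = 9           # "infinite world generation, sliding window"
cfg.inference_args.cam_maxstep = 100         # number of frames along the trajectory
cfg.inference_args.resolution_hw = [540, 960]

# net_G.inference_givenstyle(z, output_dir, **vars(cfg.inference_args))  # as in app_gradio.py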
configs/scenedreamer_train.yaml ADDED
@@ -0,0 +1,223 @@
1
+ image_save_iter: 5000
2
+ snapshot_save_epoch: 5
3
+ snapshot_save_iter: 10000
4
+ max_epoch: 400
5
+ logging_iter: 10
6
+
7
+ trainer:
8
+ type: imaginaire.trainers.gancraft
9
+ model_average_config:
10
+ enabled: False
11
+ amp_config:
12
+ enabled: False
13
+ perceptual_loss:
14
+ mode: 'vgg19'
15
+ layers: ['relu_3_1', 'relu_4_1', 'relu_5_1']
16
+ weights: [0.125, 0.25, 1.0]
17
+ loss_weight:
18
+ l2: 10.0
19
+ gan: 0.5
20
+ pseudo_gan: 0.5
21
+ perceptual: 10.0
22
+ kl: 0.05
23
+ init:
24
+ type: xavier
25
+ gain: 0.02
26
+
27
+ # SPADE/GauGAN model for pseudo-GT generation.
28
+ gaugan_loader:
29
+ config: configs/landscape1m.yaml
30
+
31
+ image_to_tensorboard: True
32
+ distributed_data_parallel_params:
33
+ find_unused_parameters: False
34
+ broadcast_buffers: False
35
+
36
+ gen_opt:
37
+ type: adam
38
+ lr: 0.0001
39
+ eps: 1.e-7
40
+ adam_beta1: 0.
41
+ adam_beta2: 0.999
42
+ lr_policy:
43
+ iteration_mode: False
44
+ type: step
45
+ step_size: 400
46
+ gamma: 0.1
47
+ param_groups:
48
+ world_encoder:
49
+ lr: 0.0005
50
+ hash_encoder:
51
+ lr: 0.0001
52
+ render_net:
53
+ lr: 0.0001
54
+ sky_net:
55
+ lr: 0.0001
56
+ style_net:
57
+ lr: 0.0001
58
+ style_encoder:
59
+ lr: 0.0001
60
+ denoiser:
61
+ lr: 0.0001
62
+
63
+ dis_opt:
64
+ type: adam
65
+ lr: 0.0004
66
+ eps: 1.e-7
67
+ adam_beta1: 0.
68
+ adam_beta2: 0.999
69
+ lr_policy:
70
+ iteration_mode: False
71
+ type: step
72
+ step_size: 400
73
+ gamma: 0.1
74
+
75
+ gen:
76
+ type: imaginaire.generators.scenedreamer
77
+ pcg_dataset_path: ./data/terrain_cache
78
+ pcg_cache: True
79
+ scene_size: 2048
80
+
81
+ blk_feat_dim: 64
82
+
83
+ pe_lvl_feat: 4
84
+ pe_incl_orig_feat: False
85
+ pe_no_pe_feat_dim: 40
86
+ pe_lvl_raydir: 0
87
+ pe_incl_orig_raydir: False
88
+ style_dims: 128 # Set to 0 to disable style.
89
+ interm_style_dims: 256
90
+ final_feat_dim: 64
91
+
92
+ # Number of pixels removed from each edge to reduce boundary artifact of CNN
93
+ # both sides combined (8 -> 4 on left and 4 on right).
94
+ pad: 6
95
+
96
+ # ======== Sky network ========
97
+ pe_lvl_raydir_sky: 5
98
+ pe_incl_orig_raydir_sky: True
99
+
100
+ # ======== Style Encoder =========
101
+ # Comment out to disable style encoder.
102
+ style_enc:
103
+ num_filters: 64
104
+ kernel_size: 3
105
+ weight_norm_type: 'none'
106
+
107
+ stylenet_model: StyleMLP
108
+ stylenet_model_kwargs:
109
+ normalize_input: True
110
+ num_layers: 5
111
+
112
+ mlp_model: RenderMLP
113
+ mlp_model_kwargs:
114
+ use_seg: True
115
+
116
+ # ======== Ray Casting Params ========
117
+ num_blocks_early_stop: 6
118
+ num_samples: 24 # Decrease this if you run out of memory on a low-end GPU.
119
+ sample_depth: 3 # Stop the ray after certain depth
120
+ coarse_deterministic_sampling: False
121
+ sample_use_box_boundaries: False # Including voxel boundaries into the sample
122
+
123
+ # ======== Blender ========
124
+ raw_noise_std: 0.0
125
+ dists_scale: 0.25
126
+ clip_feat_map: True
127
+ # Prevent sky from leaking to the foreground.
128
+ keep_sky_out: True
129
+ keep_sky_out_avgpool: True
130
+ sky_global_avgpool: True
131
+
132
+ # ======== Label translator ========
133
+ reduced_label_set: True
134
+ use_label_smooth: True
135
+ use_label_smooth_real: True
136
+ use_label_smooth_pgt: True
137
+ label_smooth_dia: 11
138
+
139
+ # ======== Camera sampler ========
140
+ camera_sampler_type: 'traditional'
141
+ cam_res: [360, 640] # Camera resolution before cropping.
142
+ crop_size: [256, 256] # Actual crop size is crop_size+pad. It should generally match random_crop_h_w in dataloader.
143
+
144
+ # Threshold for rejecting camera poses that will result in a seg mask with low entropy.
145
+ # Generally, 0.5 min, 0.8 max.
146
+ camera_min_entropy: 0.75
147
+
148
+ # Threshold for rejecting camera poses that are too close to the objects.
149
+ camera_rej_avg_depth: 2.0
150
+
151
+ dis:
152
+ type: imaginaire.discriminators.gancraft
153
+ image_channels: 3
154
+ num_labels: 12 # Same as num_reduced_lbls.
155
+ use_label: True
156
+ num_filters: 128
157
+ fpse_kernel_size: 3
158
+ activation_norm_type: 'none'
159
+ weight_norm_type: spectral
160
+ smooth_resample: True
161
+
162
+ # Data options.
163
+ data:
164
+ type: imaginaire.datasets.paired_images
165
+ num_workers: 8
166
+ input_types:
167
+ - images:
168
+ ext: jpg
169
+ num_channels: 3
170
+ normalize: True
171
+ use_dont_care: False
172
+ - seg_maps:
173
+ ext: png
174
+ num_channels: 1
175
+ is_mask: True
176
+ normalize: False
177
+
178
+ full_data_ops: imaginaire.model_utils.label::make_one_hot, imaginaire.model_utils.label::concat_labels
179
+ use_dont_care: False
180
+ one_hot_num_classes:
181
+ seg_maps: 184
182
+ input_labels:
183
+ - seg_maps
184
+
185
+ # Which lmdb contains the ground truth image.
186
+ input_image:
187
+ - images
188
+
189
+ # Train dataset details.
190
+ train:
191
+ dataset_type: lmdb
192
+ # Input LMDBs.
193
+ roots:
194
+ - ./data/lhq_lmdb/train
195
+ # Batch size per GPU.
196
+ batch_size: 1
197
+ # Data augmentations to be performed in given order.
198
+ augmentations:
199
+ resize_smallest_side: 256
200
+ # Rotate in (-rotate, rotate) in degrees.
201
+ rotate: 0
202
+ # Scale image by factor \in [1, 1+random_scale_limit].
203
+ random_scale_limit: 0.2
204
+ # Horizontal flip?
205
+ horizontal_flip: True
206
+ # Crop size.
207
+ random_crop_h_w: 256, 256
208
+ # Validation dataset details.
209
+ val:
210
+ dataset_type: lmdb
211
+ # Input LMDBs.
212
+ roots:
213
+ - ./data/lhq_lmdb/val
214
+ # Batch size per GPU.
215
+ batch_size: 1
216
+ # Data augmentations to be performed in given order.
217
+ augmentations:
218
+ # Crop size.
219
+ resize_h_w: 256, 256
220
+
221
+ test_data:
222
+ type: imaginaire.datasets.dummy
223
+ num_workers: 0
encoding.py ADDED
@@ -0,0 +1,67 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class FreqEncoder(nn.Module):
    def __init__(self, input_dim, max_freq_log2, N_freqs,
                 log_sampling=True, include_input=True,
                 periodic_fns=(torch.sin, torch.cos)):

        super().__init__()

        self.input_dim = input_dim
        self.include_input = include_input
        self.periodic_fns = periodic_fns

        self.output_dim = 0
        if self.include_input:
            self.output_dim += self.input_dim

        self.output_dim += self.input_dim * N_freqs * len(self.periodic_fns)

        if log_sampling:
            self.freq_bands = 2. ** torch.linspace(0., max_freq_log2, N_freqs)
        else:
            self.freq_bands = torch.linspace(2. ** 0., 2. ** max_freq_log2, N_freqs)

        self.freq_bands = self.freq_bands.numpy().tolist()

    def forward(self, input, **kwargs):
        out = []
        if self.include_input:
            out.append(input)

        for i in range(len(self.freq_bands)):
            freq = self.freq_bands[i]
            for p_fn in self.periodic_fns:
                out.append(p_fn(input * freq))

        out = torch.cat(out, dim=-1)

        return out


def get_encoder(encoding, input_dim=3,
                multires=6,
                degree=4,
                num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=19,
                desired_resolution=2048, align_corners=False,
                **kwargs):

    if encoding == 'None':
        return lambda x, **kwargs: x, input_dim

    elif encoding == 'hashgrid':
        from gridencoder import GridEncoder
        encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='hash', align_corners=align_corners)

    elif encoding == 'tiledgrid':
        from gridencoder import GridEncoder
        encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='tiled', align_corners=align_corners)

    elif encoding == 'varhashgrid':
        from gridencoder.grid import VarGridEncoder
        encoder = VarGridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='tiled', align_corners=align_corners, hash_entries=kwargs['hash_feat_dim'])

    else:
        raise NotImplementedError("Unknown encoding mode, choose from ['None', 'hashgrid', 'tiledgrid', 'varhashgrid']")

    return encoder, encoder.output_dim
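A small sketch of how `get_encoder` is meant to be called (shapes follow the comments in `gridencoder/grid.py` below; the `hashgrid` branch needs the compiled `gridencoder` CUDA extension and a GPU):

import torch
from encoding import get_encoder

encoder, feat_dim = get_encoder('hashgrid', input_dim=3, num_levels=16,
                                level_dim=2, desired_resolution=2048)
encoder = encoder.cuda()

xyz = torch.rand(8192, 3, device='cuda') * 2 - 1   # positions in [-1, 1]
feats = encoder(xyz, bound=1)                      # -> [8192, num_levels * level_dim]
print(feat_dim, feats.shape)                       # 32, torch.Size([8192, 32])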
environment.yaml ADDED
@@ -0,0 +1,44 @@
name: scenedreamer
channels:
  - pytorch
  - nvidia
dependencies:
  - python=3.9
  - pytorch=1.12.0
  - cudatoolkit=11.3
  - torchvision
  - pip
  - numpy
  - scipy
  - scikit-image
  - pip:
    - einops
    - noise
    - opencv-python
    - cmake
    - pynvml
    - Pillow>=8.3.2
    - tqdm==4.35.0
    - wget
    - cython
    - lmdb
    - av
    - opencv-python
    - opencv-contrib-python
    - imutils
    - imageio-ffmpeg
    - qimage2ndarray
    - albumentations
    - requests==2.25.1
    - nvidia-ml-py3==7.352.0
    - pyglet
    - timm
    - diskcache
    - boto3
    - awscli_plugin_endpoint
    - awscli
    - rsa
    - wandb
    - tensorboard
    - lpips
    - matplotlib
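This is a standard conda environment spec; it can presumably be instantiated with `conda env create -f environment.yaml` followed by `conda activate scenedreamer` (the `name:` field above).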
gridencoder/__init__.py ADDED
@@ -0,0 +1 @@
from .grid import GridEncoder
gridencoder/backend.py ADDED
@@ -0,0 +1,40 @@
import os
from torch.utils.cpp_extension import load

_src_path = os.path.dirname(os.path.abspath(__file__))

nvcc_flags = [
    '-O3', '-std=c++17',
    '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
]

if os.name == "posix":
    c_flags = ['-O3', '-std=c++17']
elif os.name == "nt":
    c_flags = ['/O2', '/std:c++17']

    # find cl.exe
    def find_cl_path():
        import glob
        for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
            paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
            if paths:
                return paths[0]

    # If cl.exe is not on path, try to find it.
    if os.system("where cl.exe >nul 2>nul") != 0:
        cl_path = find_cl_path()
        if cl_path is None:
            raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
        os.environ["PATH"] += ";" + cl_path

_backend = load(name='_grid_encoder',
                extra_cflags=c_flags,
                extra_cuda_cflags=nvcc_flags,
                sources=[os.path.join(_src_path, 'src', f) for f in [
                    'gridencoder.cu',
                    'bindings.cpp',
                ]],
                )

__all__ = ['_backend']
gridencoder/grid.py ADDED
@@ -0,0 +1,224 @@
import numpy as np

import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.cuda.amp import custom_bwd, custom_fwd

try:
    import _gridencoder as _backend
except ImportError:
    from .backend import _backend

_gridtype_to_id = {
    'hash': 0,
    'tiled': 1,
}

class _grid_encode(Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False, gridtype=0, align_corners=False):
        # inputs: [B, D], float in [0, 1]
        # embeddings: [sO, C], float
        # offsets: [L + 1], int
        # RETURN: [B, F], float

        inputs = inputs.contiguous()

        B, D = inputs.shape  # batch size, coord dim
        L = offsets.shape[0] - 1  # level
        C = embeddings.shape[1]  # embedding dim for each level
        S = np.log2(per_level_scale)  # resolution multiplier at each level, apply log2 for later CUDA exp2f
        H = base_resolution  # base resolution

        # manually handle autocast (only use half precision embeddings, inputs must be float for enough precision)
        # if C % 2 != 0, force float, since half for atomicAdd is very slow.
        if torch.is_autocast_enabled() and C % 2 == 0:
            embeddings = embeddings.to(torch.half)

        # L first, optimize cache for cuda kernel, but needs an extra permute later
        outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype)

        if calc_grad_inputs:
            dy_dx = torch.empty(B, L * D * C, device=inputs.device, dtype=embeddings.dtype)
        else:
            dy_dx = torch.empty(1, device=inputs.device, dtype=embeddings.dtype)  # placeholder... TODO: a better way?

        _backend.grid_encode_forward(inputs, embeddings, offsets, outputs, B, D, C, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners)

        # permute back to [B, L * C]
        outputs = outputs.permute(1, 0, 2).reshape(B, L * C)

        ctx.save_for_backward(inputs, embeddings, offsets, dy_dx)
        ctx.dims = [B, D, C, L, S, H, gridtype]
        ctx.calc_grad_inputs = calc_grad_inputs
        ctx.align_corners = align_corners

        return outputs

    @staticmethod
    #@once_differentiable
    @custom_bwd
    def backward(ctx, grad):

        inputs, embeddings, offsets, dy_dx = ctx.saved_tensors
        B, D, C, L, S, H, gridtype = ctx.dims
        calc_grad_inputs = ctx.calc_grad_inputs
        align_corners = ctx.align_corners

        # grad: [B, L * C] --> [L, B, C]
        grad = grad.view(B, L, C).permute(1, 0, 2).contiguous()

        grad_embeddings = torch.zeros_like(embeddings)

        if calc_grad_inputs:
            grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype)
        else:
            grad_inputs = torch.zeros(1, device=inputs.device, dtype=embeddings.dtype)

        _backend.grid_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings, B, D, C, L, S, H, calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners)

        if calc_grad_inputs:
            grad_inputs = grad_inputs.to(inputs.dtype)
            return grad_inputs, grad_embeddings, None, None, None, None, None, None
        else:
            return None, grad_embeddings, None, None, None, None, None, None


grid_encode = _grid_encode.apply


class GridEncoder(nn.Module):
    def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=None, gridtype='hash', align_corners=False):
        super().__init__()

        # the finest resolution desired at the last level; if provided, it overrides per_level_scale
        if desired_resolution is not None:
            per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1))

        self.input_dim = input_dim  # coord dims, 2 or 3
        self.num_levels = num_levels  # num levels, each level multiplies the resolution by per_level_scale
        self.level_dim = level_dim  # encode channels per level
        self.per_level_scale = per_level_scale  # multiply resolution by this scale at each level.
        self.log2_hashmap_size = log2_hashmap_size
        self.base_resolution = base_resolution
        self.output_dim = num_levels * level_dim
        self.gridtype = gridtype
        self.gridtype_id = _gridtype_to_id[gridtype]  # "tiled" or "hash"
        self.align_corners = align_corners

        # allocate parameters
        offsets = []
        offset = 0
        self.max_params = 2 ** log2_hashmap_size
        for i in range(num_levels):
            resolution = int(np.ceil(base_resolution * per_level_scale ** i))
            params_in_level = min(self.max_params, (resolution if align_corners else resolution + 1) ** input_dim)  # limit max number
            params_in_level = int(np.ceil(params_in_level / 8) * 8)  # make divisible
            offsets.append(offset)
            offset += params_in_level
        offsets.append(offset)
        offsets = torch.from_numpy(np.array(offsets, dtype=np.int32))
        self.register_buffer('offsets', offsets)

        self.n_params = offsets[-1] * level_dim

        # parameters
        self.embeddings = nn.Parameter(torch.empty(offset, level_dim))

        self.reset_parameters()

    def reset_parameters(self):
        std = 1e-4
        self.embeddings.data.uniform_(-std, std)

    def __repr__(self):
        return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners}"

    def forward(self, inputs, bound=1):
        # inputs: [..., input_dim], normalized real world positions in [-bound, bound]
        # return: [..., num_levels * level_dim]

        inputs = (inputs + bound) / (2 * bound)  # map to [0, 1]

        #print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item())

        prefix_shape = list(inputs.shape[:-1])
        inputs = inputs.view(-1, self.input_dim)

        outputs = grid_encode(inputs, self.embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad, self.gridtype_id, self.align_corners)
        outputs = outputs.view(prefix_shape + [self.output_dim])

        #print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item())

        return outputs


class VarGridEncoder(nn.Module):
    def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=None, gridtype='hash', align_corners=False, hash_entries=None):
        super().__init__()

        # the finest resolution desired at the last level; if provided, it overrides per_level_scale
        if desired_resolution is not None:
            per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1))

        self.input_dim = input_dim  # coord dims, 2 or 3
        self.num_levels = num_levels  # num levels, each level multiplies the resolution by per_level_scale
        self.level_dim = level_dim  # encode channels per level
        self.per_level_scale = per_level_scale  # multiply resolution by this scale at each level.
        self.log2_hashmap_size = log2_hashmap_size
        self.base_resolution = base_resolution
        self.output_dim = num_levels * level_dim
        self.gridtype = gridtype
        self.gridtype_id = _gridtype_to_id[gridtype]  # "tiled" or "hash"
        self.align_corners = align_corners

        # allocate parameters
        offsets = []
        offset = 0
        self.max_params = 2 ** log2_hashmap_size
        for i in range(num_levels):
            resolution = int(np.ceil(base_resolution * per_level_scale ** i))
            params_in_level = min(self.max_params, (resolution if align_corners else resolution + 1) ** input_dim)  # limit max number
            params_in_level = int(np.ceil(params_in_level / 8) * 8)  # make divisible
            offsets.append(offset)
            offset += params_in_level
        offsets.append(offset)
        offsets = torch.from_numpy(np.array(offsets, dtype=np.int32))
        self.register_buffer('offsets', offsets)

        self.n_params = offsets[-1] * level_dim
        self.level_dim = level_dim
        self.offset = offset

        # parameters: the first `hash_entries` rows of the table are supplied externally in forward().
        self.embeddings = nn.Parameter(torch.empty(offset - hash_entries, level_dim))

        self.reset_parameters()

    def reset_parameters(self):
        std = 1e-4
        self.embeddings.data.uniform_(-std, std)

    def __repr__(self):
        return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners}"

    def forward(self, inputs, embeddings, bound=1):
        # inputs: [..., input_dim], normalized real world positions in [-bound, bound]
        # embeddings: externally provided hash-table rows, placed in front of this module's own rows
        # return: [..., num_levels * level_dim]
        input_embeddings = torch.cat([embeddings, self.embeddings], dim=0)

        inputs = (inputs + bound) / (2 * bound)  # map to [0, 1]

        #print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item())

        prefix_shape = list(inputs.shape[:-1])
        inputs = inputs.view(-1, self.input_dim)

        outputs = grid_encode(inputs, input_embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad, self.gridtype_id, self.align_corners)
        outputs = outputs.view(prefix_shape + [self.output_dim])

        #print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item())

        return outputs
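`VarGridEncoder` appears to be the project-specific variant of `GridEncoder`: the first `hash_entries` rows of the hash table are not parameters of the encoder but are passed into `forward` and concatenated in front of the module's own rows. A minimal sketch under that reading (sizes are illustrative; requires the compiled CUDA extension and a GPU):

import torch
from gridencoder.grid import VarGridEncoder

hash_feat_dim = 4096  # number of externally supplied hash-table rows (illustrative)
enc = VarGridEncoder(input_dim=3, num_levels=16, level_dim=2,
                     desired_resolution=2048, gridtype='tiled',
                     hash_entries=hash_feat_dim).cuda()

# Rows produced elsewhere occupy the first part of the table.
external_rows = torch.randn(hash_feat_dim, 2, device='cuda')

xyz = torch.rand(1024, 3, device='cuda') * 2 - 1   # positions in [-1, 1]
feats = enc(xyz, external_rows, bound=1)           # -> [1024, 32]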
gridencoder/setup.py ADDED
@@ -0,0 +1,50 @@
import os
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

_src_path = os.path.dirname(os.path.abspath(__file__))

nvcc_flags = [
    '-O3', '-std=c++17',
    '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
]

if os.name == "posix":
    c_flags = ['-O3', '-std=c++17']
elif os.name == "nt":
    c_flags = ['/O2', '/std:c++17']

    # find cl.exe
    def find_cl_path():
        import glob
        for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
            paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
            if paths:
                return paths[0]

    # If cl.exe is not on path, try to find it.
    if os.system("where cl.exe >nul 2>nul") != 0:
        cl_path = find_cl_path()
        if cl_path is None:
            raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
        os.environ["PATH"] += ";" + cl_path

setup(
    name='gridencoder',  # package name, import this to use python API
    ext_modules=[
        CUDAExtension(
            name='_gridencoder',  # extension name, import this to use CUDA API
            sources=[os.path.join(_src_path, 'src', f) for f in [
                'gridencoder.cu',
                'bindings.cpp',
            ]],
            extra_compile_args={
                'cxx': c_flags,
                'nvcc': nvcc_flags,
            }
        ),
    ],
    cmdclass={
        'build_ext': BuildExtension,
    }
)
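Either build path should work: `pip install ./gridencoder` (or `python setup.py install` from inside `gridencoder/`) produces the `_gridencoder` extension that `grid.py` tries to import first, while `backend.py` above is the JIT fallback that compiles the same sources on first import.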
gridencoder/src/bindings.cpp ADDED
@@ -0,0 +1,8 @@
#include <torch/extension.h>

#include "gridencoder.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)");
    m.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)");
}
gridencoder/src/gridencoder.cu ADDED
@@ -0,0 +1,478 @@
1
+ #include <cuda.h>
2
+ #include <cuda_fp16.h>
3
+ #include <cuda_runtime.h>
4
+
5
+ #include <ATen/cuda/CUDAContext.h>
6
+ #include <torch/torch.h>
7
+
8
+ #include <algorithm>
9
+ #include <stdexcept>
10
+
11
+ #include <stdint.h>
12
+ #include <cstdio>
13
+
14
+
15
+ #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
16
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
17
+ #define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
18
+ #define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")
19
+
20
+
21
+ // just for compatibility of half precision in AT_DISPATCH_FLOATING_TYPES_AND_HALF...
22
+ static inline __device__ at::Half atomicAdd(at::Half *address, at::Half val) {
23
+ // requires CUDA >= 10 and ARCH >= 70
24
+ // this is very slow compared to float or __half2, and never used.
25
+ //return atomicAdd(reinterpret_cast<__half*>(address), val);
26
+ }
27
+
28
+
29
+ template <typename T>
30
+ static inline __host__ __device__ T div_round_up(T val, T divisor) {
31
+ return (val + divisor - 1) / divisor;
32
+ }
33
+
34
+
35
+ template <uint32_t D>
36
+ __device__ uint32_t fast_hash(const uint32_t pos_grid[D]) {
37
+ static_assert(D <= 7, "fast_hash can only hash up to 7 dimensions.");
38
+
39
+ // While 1 is technically not a good prime for hashing (or a prime at all), it helps memory coherence
40
+ // and is sufficient for our use case of obtaining a uniformly colliding index from high-dimensional
41
+ // coordinates.
42
+ constexpr uint32_t primes[7] = { 1, 2654435761, 805459861, 3674653429, 2097192037, 1434869437, 2165219737 };
43
+
44
+ uint32_t result = 0;
45
+ #pragma unroll
46
+ for (uint32_t i = 0; i < D; ++i) {
47
+ result ^= pos_grid[i] * primes[i];
48
+ }
49
+
50
+ return result;
51
+ }
52
+
53
+
54
+ template <uint32_t D, uint32_t C>
55
+ __device__ uint32_t get_grid_index(const uint32_t gridtype, const bool align_corners, const uint32_t ch, const uint32_t hashmap_size, const uint32_t resolution, const uint32_t pos_grid[D]) {
56
+ uint32_t stride = 1;
57
+ uint32_t index = 0;
58
+
59
+ #pragma unroll
60
+ for (uint32_t d = 0; d < D && stride <= hashmap_size; d++) {
61
+ index += pos_grid[d] * stride;
62
+ stride *= align_corners ? resolution: (resolution + 1);
63
+ }
64
+
65
+ // NOTE: for NeRF, the hash is in fact not necessary. Check https://github.com/NVlabs/instant-ngp/issues/97.
66
+ // gridtype: 0 == hash, 1 == tiled
67
+ if (gridtype == 0 && stride > hashmap_size) {
68
+ index = fast_hash<D>(pos_grid);
69
+ }
70
+
71
+ return (index % hashmap_size) * C + ch;
72
+ }
73
+
74
+
75
+ template <typename scalar_t, uint32_t D, uint32_t C>
76
+ __global__ void kernel_grid(
77
+ const float * __restrict__ inputs,
78
+ const scalar_t * __restrict__ grid,
79
+ const int * __restrict__ offsets,
80
+ scalar_t * __restrict__ outputs,
81
+ const uint32_t B, const uint32_t L, const float S, const uint32_t H,
82
+ const bool calc_grad_inputs,
83
+ scalar_t * __restrict__ dy_dx,
84
+ const uint32_t gridtype,
85
+ const bool align_corners
86
+ ) {
87
+ const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;
88
+
89
+ if (b >= B) return;
90
+
91
+ const uint32_t level = blockIdx.y;
92
+
93
+ // locate
94
+ grid += (uint32_t)offsets[level] * C;
95
+ inputs += b * D;
96
+ outputs += level * B * C + b * C;
97
+
98
+ // check input range (should be in [0, 1])
99
+ bool flag_oob = false;
100
+ #pragma unroll
101
+ for (uint32_t d = 0; d < D; d++) {
102
+ if (inputs[d] < 0 || inputs[d] > 1) {
103
+ flag_oob = true;
104
+ }
105
+ }
106
+ // if input out of bound, just set output to 0
107
+ if (flag_oob) {
108
+ #pragma unroll
109
+ for (uint32_t ch = 0; ch < C; ch++) {
110
+ outputs[ch] = 0;
111
+ }
112
+ if (calc_grad_inputs) {
113
+ dy_dx += b * D * L * C + level * D * C; // B L D C
114
+ #pragma unroll
115
+ for (uint32_t d = 0; d < D; d++) {
116
+ #pragma unroll
117
+ for (uint32_t ch = 0; ch < C; ch++) {
118
+ dy_dx[d * C + ch] = 0;
119
+ }
120
+ }
121
+ }
122
+ return;
123
+ }
124
+
125
+ const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
126
+ const float scale = exp2f(level * S) * H - 1.0f;
127
+ const uint32_t resolution = (uint32_t)ceil(scale) + 1;
128
+
129
+ // calculate coordinate
130
+ float pos[D];
131
+ uint32_t pos_grid[D];
132
+
133
+ #pragma unroll
134
+ for (uint32_t d = 0; d < D; d++) {
135
+ pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
136
+ pos_grid[d] = floorf(pos[d]);
137
+ pos[d] -= (float)pos_grid[d];
138
+ }
139
+
140
+ //printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]);
141
+
142
+ // interpolate
143
+ scalar_t results[C] = {0}; // temp results in register
144
+
145
+ #pragma unroll
146
+ for (uint32_t idx = 0; idx < (1 << D); idx++) {
147
+ float w = 1;
148
+ uint32_t pos_grid_local[D];
149
+
150
+ #pragma unroll
151
+ for (uint32_t d = 0; d < D; d++) {
152
+ if ((idx & (1 << d)) == 0) {
153
+ w *= 1 - pos[d];
154
+ pos_grid_local[d] = pos_grid[d];
155
+ } else {
156
+ w *= pos[d];
157
+ pos_grid_local[d] = pos_grid[d] + 1;
158
+ }
159
+ }
160
+
161
+ uint32_t index = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);
162
+
163
+ // writing to register (fast)
164
+ #pragma unroll
165
+ for (uint32_t ch = 0; ch < C; ch++) {
166
+ results[ch] += w * grid[index + ch];
167
+ }
168
+
169
+ //printf("[b=%d, l=%d] int %d, idx %d, w %f, val %f\n", b, level, idx, index, w, grid[index]);
170
+ }
171
+
172
+ // writing to global memory (slow)
173
+ #pragma unroll
174
+ for (uint32_t ch = 0; ch < C; ch++) {
175
+ outputs[ch] = results[ch];
176
+ }
177
+
178
+ // prepare dy_dx for calc_grad_inputs
179
+ // differentiable (soft) indexing: https://discuss.pytorch.org/t/differentiable-indexing/17647/9
180
+ if (calc_grad_inputs) {
181
+
182
+ dy_dx += b * D * L * C + level * D * C; // B L D C
183
+
184
+ #pragma unroll
185
+ for (uint32_t gd = 0; gd < D; gd++) {
186
+
187
+ scalar_t results_grad[C] = {0};
188
+
189
+ #pragma unroll
190
+ for (uint32_t idx = 0; idx < (1 << (D - 1)); idx++) {
191
+ float w = scale;
192
+ uint32_t pos_grid_local[D];
193
+
194
+ #pragma unroll
195
+ for (uint32_t nd = 0; nd < D - 1; nd++) {
196
+ const uint32_t d = (nd >= gd) ? (nd + 1) : nd;
197
+
198
+ if ((idx & (1 << nd)) == 0) {
199
+ w *= 1 - pos[d];
200
+ pos_grid_local[d] = pos_grid[d];
201
+ } else {
202
+ w *= pos[d];
203
+ pos_grid_local[d] = pos_grid[d] + 1;
204
+ }
205
+ }
206
+
207
+ pos_grid_local[gd] = pos_grid[gd];
208
+ uint32_t index_left = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);
209
+ pos_grid_local[gd] = pos_grid[gd] + 1;
210
+ uint32_t index_right = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);
211
+
212
+ #pragma unroll
213
+ for (uint32_t ch = 0; ch < C; ch++) {
214
+ results_grad[ch] += w * (grid[index_right + ch] - grid[index_left + ch]);
215
+ }
216
+ }
217
+
218
+ #pragma unroll
219
+ for (uint32_t ch = 0; ch < C; ch++) {
220
+ dy_dx[gd * C + ch] = results_grad[ch];
221
+ }
222
+ }
223
+ }
224
+ }
225
+
226
+
227
+ template <typename scalar_t, uint32_t D, uint32_t C, uint32_t N_C>
228
+ __global__ void kernel_grid_backward(
229
+ const scalar_t * __restrict__ grad,
230
+ const float * __restrict__ inputs,
231
+ const scalar_t * __restrict__ grid,
232
+ const int * __restrict__ offsets,
233
+ scalar_t * __restrict__ grad_grid,
234
+ const uint32_t B, const uint32_t L, const float S, const uint32_t H,
235
+ const uint32_t gridtype,
236
+ const bool align_corners
237
+ ) {
238
+ const uint32_t b = (blockIdx.x * blockDim.x + threadIdx.x) * N_C / C;
239
+ if (b >= B) return;
240
+
241
+ const uint32_t level = blockIdx.y;
242
+ const uint32_t ch = (blockIdx.x * blockDim.x + threadIdx.x) * N_C - b * C;
243
+
244
+ // locate
245
+ grad_grid += offsets[level] * C;
246
+ inputs += b * D;
247
+ grad += level * B * C + b * C + ch; // L, B, C
248
+
249
+ const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
250
+ const float scale = exp2f(level * S) * H - 1.0f;
251
+ const uint32_t resolution = (uint32_t)ceil(scale) + 1;
252
+
253
+ // check input range (should be in [0, 1])
254
+ #pragma unroll
255
+ for (uint32_t d = 0; d < D; d++) {
256
+ if (inputs[d] < 0 || inputs[d] > 1) {
257
+ return; // grad is init as 0, so we simply return.
258
+ }
259
+ }
260
+
261
+ // calculate coordinate
262
+ float pos[D];
263
+ uint32_t pos_grid[D];
264
+
265
+ #pragma unroll
266
+ for (uint32_t d = 0; d < D; d++) {
267
+ pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
268
+ pos_grid[d] = floorf(pos[d]);
269
+ pos[d] -= (float)pos_grid[d];
270
+ }
271
+
272
+ scalar_t grad_cur[N_C] = {0}; // fetch to register
273
+ #pragma unroll
274
+ for (uint32_t c = 0; c < N_C; c++) {
275
+ grad_cur[c] = grad[c];
276
+ }
277
+
278
+ // interpolate
279
+ #pragma unroll
280
+ for (uint32_t idx = 0; idx < (1 << D); idx++) {
281
+ float w = 1;
282
+ uint32_t pos_grid_local[D];
283
+
284
+ #pragma unroll
285
+ for (uint32_t d = 0; d < D; d++) {
286
+ if ((idx & (1 << d)) == 0) {
287
+ w *= 1 - pos[d];
288
+ pos_grid_local[d] = pos_grid[d];
289
+ } else {
290
+ w *= pos[d];
291
+ pos_grid_local[d] = pos_grid[d] + 1;
292
+ }
293
+ }
294
+
295
+ uint32_t index = get_grid_index<D, C>(gridtype, align_corners, ch, hashmap_size, resolution, pos_grid_local);
296
+
297
+ // atomicAdd for __half is slow (especially for large values), so we use __half2 if N_C % 2 == 0
298
+ // TODO: use float which is better than __half, if N_C % 2 != 0
299
+ if (std::is_same<scalar_t, at::Half>::value && N_C % 2 == 0) {
300
+ #pragma unroll
301
+ for (uint32_t c = 0; c < N_C; c += 2) {
302
+ // process two __half at once (by interpreting as a __half2)
303
+ __half2 v = {(__half)(w * grad_cur[c]), (__half)(w * grad_cur[c + 1])};
304
+ atomicAdd((__half2*)&grad_grid[index + c], v);
305
+ }
306
+ // float, or __half when N_C % 2 != 0 (which means C == 1)
307
+ } else {
308
+ #pragma unroll
309
+ for (uint32_t c = 0; c < N_C; c++) {
310
+ atomicAdd(&grad_grid[index + c], w * grad_cur[c]);
311
+ }
312
+ }
313
+ }
314
+ }
315
+
316
+
317
+ template <typename scalar_t, uint32_t D, uint32_t C>
318
+ __global__ void kernel_input_backward(
319
+ const scalar_t * __restrict__ grad,
320
+ const scalar_t * __restrict__ dy_dx,
321
+ scalar_t * __restrict__ grad_inputs,
322
+ uint32_t B, uint32_t L
323
+ ) {
324
+ const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
325
+ if (t >= B * D) return;
326
+
327
+ const uint32_t b = t / D;
328
+ const uint32_t d = t - b * D;
329
+
330
+ dy_dx += b * L * D * C;
331
+
332
+ scalar_t result = 0;
333
+
334
+ # pragma unroll
335
+ for (int l = 0; l < L; l++) {
336
+ # pragma unroll
337
+ for (int ch = 0; ch < C; ch++) {
338
+ result += grad[l * B * C + b * C + ch] * dy_dx[l * D * C + d * C + ch];
339
+ }
340
+ }
341
+
342
+ grad_inputs[t] = result;
343
+ }
344
+
345
+
346
+ template <typename scalar_t, uint32_t D>
347
+ void kernel_grid_wrapper(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners) {
348
+ static constexpr uint32_t N_THREAD = 512;
349
+ const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), L, 1 };
350
+ switch (C) {
351
+ case 1: kernel_grid<scalar_t, D, 1><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break;
352
+ case 2: kernel_grid<scalar_t, D, 2><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break;
353
+ case 4: kernel_grid<scalar_t, D, 4><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break;
354
+ case 8: kernel_grid<scalar_t, D, 8><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break;
355
+ default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
356
+ }
357
+ }
358
+
359
+ // inputs: [B, D], float, in [0, 1]
360
+ // embeddings: [sO, C], float
361
+ // offsets: [L + 1], uint32_t
362
+ // outputs: [L, B, C], float (L first, so only one level of hashmap needs to fit into cache at a time.)
363
+ // H: base resolution
364
+ // dy_dx: [B, L * D * C]
365
+ template <typename scalar_t>
366
+ void grid_encode_forward_cuda(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners) {
367
+ switch (D) {
368
+ case 2: kernel_grid_wrapper<scalar_t, 2>(inputs, embeddings, offsets, outputs, B, C, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break;
369
+ case 3: kernel_grid_wrapper<scalar_t, 3>(inputs, embeddings, offsets, outputs, B, C, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break;
370
+ case 4: kernel_grid_wrapper<scalar_t, 4>(inputs, embeddings, offsets, outputs, B, C, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break;
371
+ case 5: kernel_grid_wrapper<scalar_t, 5>(inputs, embeddings, offsets, outputs, B, C, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break;
372
+ default: throw std::runtime_error{"GridEncoding: D must be 2, 3, 4, or 5."};
373
+ }
374
+
375
+ }
376
+
377
+ template <typename scalar_t, uint32_t D>
378
+ void kernel_grid_backward_wrapper(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners) {
379
+ static constexpr uint32_t N_THREAD = 256;
380
+ const uint32_t N_C = std::min(2u, C); // n_features_per_thread
381
+ const dim3 blocks_hashgrid = { div_round_up(B * C / N_C, N_THREAD), L, 1 };
382
+ switch (C) {
383
+ case 1:
384
+ kernel_grid_backward<scalar_t, D, 1, 1><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
385
+ if (calc_grad_inputs) kernel_input_backward<scalar_t, D, 1><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
386
+ break;
387
+ case 2:
388
+ kernel_grid_backward<scalar_t, D, 2, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
389
+ if (calc_grad_inputs) kernel_input_backward<scalar_t, D, 2><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
390
+ break;
391
+ case 4:
392
+ kernel_grid_backward<scalar_t, D, 4, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
393
+ if (calc_grad_inputs) kernel_input_backward<scalar_t, D, 4><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
394
+ break;
395
+ case 8:
396
+ kernel_grid_backward<scalar_t, D, 8, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
397
+ if (calc_grad_inputs) kernel_input_backward<scalar_t, D, 8><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
398
+ break;
399
+ default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
400
+ }
401
+ }
402
+
403
+
404
+ // grad: [L, B, C], float
405
+ // inputs: [B, D], float, in [0, 1]
406
+ // embeddings: [sO, C], float
407
+ // offsets: [L + 1], uint32_t
408
+ // grad_embeddings: [sO, C]
409
+ // H: base resolution
410
+ template <typename scalar_t>
411
+ void grid_encode_backward_cuda(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners) {
412
+ switch (D) {
413
+ case 2: kernel_grid_backward_wrapper<scalar_t, 2>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners); break;
414
+ case 3: kernel_grid_backward_wrapper<scalar_t, 3>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners); break;
415
+ case 4: kernel_grid_backward_wrapper<scalar_t, 4>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners); break;
416
+ case 5: kernel_grid_backward_wrapper<scalar_t, 5>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners); break;
417
+ default: throw std::runtime_error{"GridEncoding: D must be 2, 3, 4, or 5."};
418
+ }
419
+ }
420
+
421
+
422
+
423
+ void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, at::Tensor dy_dx, const uint32_t gridtype, const bool align_corners) {
424
+ CHECK_CUDA(inputs);
425
+ CHECK_CUDA(embeddings);
426
+ CHECK_CUDA(offsets);
427
+ CHECK_CUDA(outputs);
428
+ CHECK_CUDA(dy_dx);
429
+
430
+ CHECK_CONTIGUOUS(inputs);
431
+ CHECK_CONTIGUOUS(embeddings);
432
+ CHECK_CONTIGUOUS(offsets);
433
+ CHECK_CONTIGUOUS(outputs);
434
+ CHECK_CONTIGUOUS(dy_dx);
435
+
436
+ CHECK_IS_FLOATING(inputs);
437
+ CHECK_IS_FLOATING(embeddings);
438
+ CHECK_IS_INT(offsets);
439
+ CHECK_IS_FLOATING(outputs);
440
+ CHECK_IS_FLOATING(dy_dx);
441
+
442
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
443
+ embeddings.scalar_type(), "grid_encode_forward", ([&] {
444
+ grid_encode_forward_cuda<scalar_t>(inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(), outputs.data_ptr<scalar_t>(), B, D, C, L, S, H, calc_grad_inputs, dy_dx.data_ptr<scalar_t>(), gridtype, align_corners);
445
+ }));
446
+ }
447
+
448
+ void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, const at::Tensor dy_dx, at::Tensor grad_inputs, const uint32_t gridtype, const bool align_corners) {
449
+ CHECK_CUDA(grad);
450
+ CHECK_CUDA(inputs);
451
+ CHECK_CUDA(embeddings);
452
+ CHECK_CUDA(offsets);
453
+ CHECK_CUDA(grad_embeddings);
454
+ CHECK_CUDA(dy_dx);
455
+ CHECK_CUDA(grad_inputs);
456
+
457
+ CHECK_CONTIGUOUS(grad);
458
+ CHECK_CONTIGUOUS(inputs);
459
+ CHECK_CONTIGUOUS(embeddings);
460
+ CHECK_CONTIGUOUS(offsets);
461
+ CHECK_CONTIGUOUS(grad_embeddings);
462
+ CHECK_CONTIGUOUS(dy_dx);
463
+ CHECK_CONTIGUOUS(grad_inputs);
464
+
465
+ CHECK_IS_FLOATING(grad);
466
+ CHECK_IS_FLOATING(inputs);
467
+ CHECK_IS_FLOATING(embeddings);
468
+ CHECK_IS_INT(offsets);
469
+ CHECK_IS_FLOATING(grad_embeddings);
470
+ CHECK_IS_FLOATING(dy_dx);
471
+ CHECK_IS_FLOATING(grad_inputs);
472
+
473
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
474
+ grad.scalar_type(), "grid_encode_backward", ([&] {
475
+ grid_encode_backward_cuda<scalar_t>(grad.data_ptr<scalar_t>(), inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(), grad_embeddings.data_ptr<scalar_t>(), B, D, C, L, S, H, calc_grad_inputs, dy_dx.data_ptr<scalar_t>(), grad_inputs.data_ptr<scalar_t>(), gridtype, align_corners);
476
+ }));
477
+
478
+ }
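
For orientation, the two entry points defined above (`grid_encode_forward` / `grid_encode_backward`) are meant to be driven from Python through the compiled extension. Below is a minimal sketch of such an autograd wrapper, assuming the extension is importable as `_gridencoder` and that tensor shapes follow the comments in this file; it is an illustration, not the repository's actual Python wrapper.

```python
# Minimal sketch of a Python-side autograd wrapper around the CUDA entry points
# above. Assumptions: the compiled extension is importable as `_gridencoder`,
# inputs are [B, D] floats in [0, 1], embeddings are [sO, C], offsets is [L + 1]
# int32, and outputs come back as [L, B, C] (then reshaped to [B, L * C]).
import math
import torch
import _gridencoder as _backend  # assumed module name for the compiled extension


class GridEncodeSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution,
                calc_grad_inputs=False, gridtype=0, align_corners=False):
        B, D = inputs.shape
        L = offsets.shape[0] - 1
        C = embeddings.shape[1]
        S = math.log2(per_level_scale)  # the kernels expect log2 of the per-level scale
        outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype)
        dy_dx = torch.empty(B, L * D * C, device=inputs.device, dtype=embeddings.dtype)
        _backend.grid_encode_forward(inputs, embeddings, offsets, outputs, B, D, C, L, S,
                                     base_resolution, calc_grad_inputs, dy_dx,
                                     gridtype, align_corners)
        ctx.save_for_backward(inputs, embeddings, offsets, dy_dx)
        ctx.dims = (B, D, C, L, S, base_resolution, calc_grad_inputs, gridtype, align_corners)
        return outputs.permute(1, 0, 2).reshape(B, L * C)

    @staticmethod
    def backward(ctx, grad):
        inputs, embeddings, offsets, dy_dx = ctx.saved_tensors
        B, D, C, L, S, H, calc_grad_inputs, gridtype, align_corners = ctx.dims
        grad = grad.view(B, L, C).permute(1, 0, 2).contiguous()  # back to [L, B, C]
        grad_embeddings = torch.zeros_like(embeddings)
        grad_inputs = torch.zeros(B, D, device=inputs.device, dtype=embeddings.dtype)
        _backend.grid_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings,
                                      B, D, C, L, S, H, calc_grad_inputs, dy_dx, grad_inputs,
                                      gridtype, align_corners)
        grad_inputs = grad_inputs.to(inputs.dtype) if calc_grad_inputs else None
        return grad_inputs, grad_embeddings, None, None, None, None, None, None
```

The layout choice mirrored here is the one stated in the comments above: the forward kernel writes outputs level-first ([L, B, C]) so only one hash level needs to stay in cache at a time, and the wrapper transposes back to [B, L * C] for the rest of the network.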
gridencoder/src/gridencoder.h ADDED
@@ -0,0 +1,15 @@
+ #ifndef _HASH_ENCODE_H
+ #define _HASH_ENCODE_H
+
+ #include <stdint.h>
+ #include <torch/torch.h>
+
+ // inputs: [B, D], float, in [0, 1]
+ // embeddings: [sO, C], float
+ // offsets: [L + 1], uint32_t
+ // outputs: [B, L * C], float
+ // H: base resolution
+ void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, at::Tensor dy_dx, const uint32_t gridtype, const bool align_corners);
+ void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, const at::Tensor dy_dx, at::Tensor grad_inputs, const uint32_t gridtype, const bool align_corners);
+
+ #endif
imaginaire/__init__.py ADDED
@@ -0,0 +1,4 @@
+ # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ #
+ # This work is made available under the Nvidia Source Code License-NC.
+ # To view a copy of this license, check out LICENSE.md
imaginaire/config.py ADDED
@@ -0,0 +1,238 @@
1
+ # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, check out LICENSE.md
5
+ """Config utilities for yml file."""
6
+
7
+ import collections
8
+ import functools
9
+ import os
10
+ import re
11
+
12
+ import yaml
13
+ from imaginaire.utils.distributed import master_only_print as print
14
+
15
+ DEBUG = False
16
+ USE_JIT = False
17
+
18
+
19
+ class AttrDict(dict):
20
+ """Dict as attribute trick."""
21
+
22
+ def __init__(self, *args, **kwargs):
23
+ super(AttrDict, self).__init__(*args, **kwargs)
24
+ self.__dict__ = self
25
+ for key, value in self.__dict__.items():
26
+ if isinstance(value, dict):
27
+ self.__dict__[key] = AttrDict(value)
28
+ elif isinstance(value, (list, tuple)):
29
+ if isinstance(value[0], dict):
30
+ self.__dict__[key] = [AttrDict(item) for item in value]
31
+ else:
32
+ self.__dict__[key] = value
33
+
34
+ def yaml(self):
35
+ """Convert object to yaml dict and return."""
36
+ yaml_dict = {}
37
+ for key, value in self.__dict__.items():
38
+ if isinstance(value, AttrDict):
39
+ yaml_dict[key] = value.yaml()
40
+ elif isinstance(value, list):
41
+ if isinstance(value[0], AttrDict):
42
+ new_l = []
43
+ for item in value:
44
+ new_l.append(item.yaml())
45
+ yaml_dict[key] = new_l
46
+ else:
47
+ yaml_dict[key] = value
48
+ else:
49
+ yaml_dict[key] = value
50
+ return yaml_dict
51
+
52
+ def __repr__(self):
53
+ """Print all variables."""
54
+ ret_str = []
55
+ for key, value in self.__dict__.items():
56
+ if isinstance(value, AttrDict):
57
+ ret_str.append('{}:'.format(key))
58
+ child_ret_str = value.__repr__().split('\n')
59
+ for item in child_ret_str:
60
+ ret_str.append(' ' + item)
61
+ elif isinstance(value, list):
62
+ if isinstance(value[0], AttrDict):
63
+ ret_str.append('{}:'.format(key))
64
+ for item in value:
65
+ # Treat as AttrDict above.
66
+ child_ret_str = item.__repr__().split('\n')
67
+ for item in child_ret_str:
68
+ ret_str.append(' ' + item)
69
+ else:
70
+ ret_str.append('{}: {}'.format(key, value))
71
+ else:
72
+ ret_str.append('{}: {}'.format(key, value))
73
+ return '\n'.join(ret_str)
74
+
75
+
76
+ class Config(AttrDict):
77
+ r"""Configuration class. This should include every human specifiable
78
+ hyperparameter values for your training."""
79
+
80
+ def __init__(self, filename=None, verbose=False):
81
+ super(Config, self).__init__()
82
+ self.source_filename = filename
83
+ # Set default parameters.
84
+ # Logging.
85
+ large_number = 1000000000
86
+ self.snapshot_save_iter = large_number
87
+ self.snapshot_save_epoch = large_number
88
+ self.metrics_iter = None
89
+ self.metrics_epoch = None
90
+ self.snapshot_save_start_iter = 0
91
+ self.snapshot_save_start_epoch = 0
92
+ self.image_save_iter = large_number
93
+ self.image_display_iter = large_number
94
+ self.max_epoch = large_number
95
+ self.max_iter = large_number
96
+ self.logging_iter = 100
97
+ self.speed_benchmark = False
98
+
99
+ # Trainer.
100
+ self.trainer = AttrDict(
101
+ model_average_config=AttrDict(enabled=False,
102
+ beta=0.9999,
103
+ start_iteration=1000,
104
+ num_batch_norm_estimation_iterations=30,
105
+ remove_sn=True),
106
+ # model_average=False,
107
+ # model_average_beta=0.9999,
108
+ # model_average_start_iteration=1000,
109
+ # model_average_batch_norm_estimation_iteration=30,
110
+ # model_average_remove_sn=True,
111
+ image_to_tensorboard=False,
112
+ hparam_to_tensorboard=False,
113
+ distributed_data_parallel='pytorch',
114
+ distributed_data_parallel_params=AttrDict(
115
+ find_unused_parameters=False),
116
+ delay_allreduce=True,
117
+ gan_relativistic=False,
118
+ gen_step=1,
119
+ dis_step=1,
120
+ gan_decay_k=1.,
121
+ gan_min_k=1.,
122
+ gan_separate_topk=False,
123
+ aug_policy='',
124
+ channels_last=False,
125
+ strict_resume=True,
126
+ amp_gp=False,
127
+ amp_config=AttrDict(init_scale=65536.0,
128
+ growth_factor=2.0,
129
+ backoff_factor=0.5,
130
+ growth_interval=2000,
131
+ enabled=False))
132
+
133
+ # Networks.
134
+ self.gen = AttrDict(type='imaginaire.generators.dummy')
135
+ self.dis = AttrDict(type='imaginaire.discriminators.dummy')
136
+
137
+ # Optimizers.
138
+ self.gen_opt = AttrDict(type='adam',
139
+ fused_opt=False,
140
+ lr=0.0001,
141
+ adam_beta1=0.0,
142
+ adam_beta2=0.999,
143
+ eps=1e-8,
144
+ lr_policy=AttrDict(iteration_mode=False,
145
+ type='step',
146
+ step_size=large_number,
147
+ gamma=1))
148
+ self.dis_opt = AttrDict(type='adam',
149
+ fused_opt=False,
150
+ lr=0.0001,
151
+ adam_beta1=0.0,
152
+ adam_beta2=0.999,
153
+ eps=1e-8,
154
+ lr_policy=AttrDict(iteration_mode=False,
155
+ type='step',
156
+ step_size=large_number,
157
+ gamma=1))
158
+ # Data.
159
+ self.data = AttrDict(name='dummy',
160
+ type='imaginaire.datasets.images',
161
+ num_workers=0)
162
+ self.test_data = AttrDict(name='dummy',
163
+ type='imaginaire.datasets.images',
164
+ num_workers=0,
165
+ test=AttrDict(is_lmdb=False,
166
+ roots='',
167
+ batch_size=1))
168
+
169
+
170
+ # Cudnn.
171
+ self.cudnn = AttrDict(deterministic=False,
172
+ benchmark=True)
173
+
174
+ # Others.
175
+ self.pretrained_weight = ''
176
+ self.inference_args = AttrDict()
177
+
178
+ # Update with given configurations.
179
+ assert os.path.exists(filename), 'File {} not exist.'.format(filename)
180
+ loader = yaml.SafeLoader
181
+ loader.add_implicit_resolver(
182
+ u'tag:yaml.org,2002:float',
183
+ re.compile(u'''^(?:
184
+ [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
185
+ |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
186
+ |\\.[0-9_]+(?:[eE][-+][0-9]+)?
187
+ |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
188
+ |[-+]?\\.(?:inf|Inf|INF)
189
+ |\\.(?:nan|NaN|NAN))$''', re.X),
190
+ list(u'-+0123456789.'))
191
+ try:
192
+ with open(filename, 'r') as f:
193
+ cfg_dict = yaml.load(f, Loader=loader)
194
+ except EnvironmentError:
195
+ print('Please check the file with name of "{}"'.format(filename))
196
+ recursive_update(self, cfg_dict)
197
+
198
+ # Put common opts in both gen and dis.
199
+ if 'common' in cfg_dict:
200
+ self.common = AttrDict(**cfg_dict['common'])
201
+ self.gen.common = self.common
202
+ self.dis.common = self.common
203
+
204
+ if verbose:
205
+ print(' imaginaire config '.center(80, '-'))
206
+ print(self.__repr__())
207
+ print(''.center(80, '-'))
208
+
209
+
210
+ def rsetattr(obj, attr, val):
211
+ """Recursively find object and set value"""
212
+ pre, _, post = attr.rpartition('.')
213
+ return setattr(rgetattr(obj, pre) if pre else obj, post, val)
214
+
215
+
216
+ def rgetattr(obj, attr, *args):
217
+ """Recursively find object and return value"""
218
+
219
+ def _getattr(obj, attr):
220
+ r"""Get attribute."""
221
+ return getattr(obj, attr, *args)
222
+
223
+ return functools.reduce(_getattr, [obj] + attr.split('.'))
224
+
225
+
226
+ def recursive_update(d, u):
227
+ """Recursively update AttrDict d with AttrDict u"""
228
+ for key, value in u.items():
229
+ if isinstance(value, collections.abc.Mapping):
230
+ d.__dict__[key] = recursive_update(d.get(key, AttrDict({})), value)
231
+ elif isinstance(value, (list, tuple)):
232
+ if isinstance(value[0], dict):
233
+ d.__dict__[key] = [AttrDict(item) for item in value]
234
+ else:
235
+ d.__dict__[key] = value
236
+ else:
237
+ d.__dict__[key] = value
238
+ return d
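
Since `Config` turns nested YAML sections into chained attribute access via `AttrDict`, and `rsetattr`/`rgetattr` address nested fields with dotted strings, a small usage example may help. The YAML path below is an illustrative assumption, not a guaranteed file.

```python
# Hypothetical usage of the Config / AttrDict helpers above.
# The YAML path is an illustrative assumption.
from imaginaire.config import Config, rgetattr, rsetattr

cfg = Config('configs/scenedreamer_train.yaml', verbose=False)

# Nested YAML sections become chained attribute access.
print(cfg.gen.type)                # generator module path from the YAML (or the dummy default)
print(cfg.gen_opt.lr_policy.type)  # defaults set in Config.__init__ unless overridden

# Dotted-path helpers are convenient for command-line overrides.
rsetattr(cfg, 'gen_opt.lr', 2e-4)
assert rgetattr(cfg, 'gen_opt.lr') == 2e-4
```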
imaginaire/discriminators/__init__.py ADDED
File without changes
imaginaire/discriminators/gancraft.py ADDED
@@ -0,0 +1,278 @@
1
+ # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, check out LICENSE.md
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import functools
10
+ from imaginaire.layers import Conv2dBlock
11
+
12
+ from imaginaire.utils.data import get_paired_input_label_channel_number, get_paired_input_image_channel_number
13
+ from imaginaire.utils.distributed import master_only_print as print
14
+
15
+
16
+ class Discriminator(nn.Module):
17
+ r"""Multi-resolution patch discriminator. Based on FPSE discriminator but with N+1 labels.
18
+
19
+ Args:
20
+ dis_cfg (obj): Discriminator definition part of the yaml config file.
21
+ data_cfg (obj): Data definition part of the yaml config file.
22
+ """
23
+
24
+ def __init__(self, dis_cfg, data_cfg):
25
+ super(Discriminator, self).__init__()
26
+ # We assume the first datum is the ground truth image.
27
+ image_channels = get_paired_input_image_channel_number(data_cfg)
28
+ # Calculate number of channels in the input label.
29
+ num_labels = get_paired_input_label_channel_number(data_cfg)
30
+
31
+ self.use_label = getattr(dis_cfg, 'use_label', True)
32
+ # Override number of input channels
33
+ if hasattr(dis_cfg, 'image_channels'):
34
+ image_channels = dis_cfg.image_channels
35
+ if hasattr(dis_cfg, 'num_labels'):
36
+ num_labels = dis_cfg.num_labels
37
+ else:
38
+ # We assume the first datum is the ground truth image.
39
+ image_channels = get_paired_input_image_channel_number(data_cfg)
40
+ # Calculate number of channels in the input label.
41
+ num_labels = get_paired_input_label_channel_number(data_cfg)
42
+
43
+ if not self.use_label:
44
+ num_labels = 2 # ignore + true
45
+
46
+ # Build the discriminator.
47
+ num_filters = getattr(dis_cfg, 'num_filters', 128)
48
+ weight_norm_type = getattr(dis_cfg, 'weight_norm_type', 'spectral')
49
+
50
+ fpse_kernel_size = getattr(dis_cfg, 'fpse_kernel_size', 3)
51
+ fpse_activation_norm_type = getattr(dis_cfg,
52
+ 'fpse_activation_norm_type',
53
+ 'none')
54
+ do_multiscale = getattr(dis_cfg, 'do_multiscale', False)
55
+ smooth_resample = getattr(dis_cfg, 'smooth_resample', False)
56
+ no_label_except_largest_scale = getattr(dis_cfg, 'no_label_except_largest_scale', False)
57
+
58
+ self.fpse_discriminator = FPSEDiscriminator(
59
+ image_channels,
60
+ num_labels,
61
+ num_filters,
62
+ fpse_kernel_size,
63
+ weight_norm_type,
64
+ fpse_activation_norm_type,
65
+ do_multiscale,
66
+ smooth_resample,
67
+ no_label_except_largest_scale)
68
+
69
+ def _single_forward(self, input_label, input_image, weights):
70
+ output_list, features_list = self.fpse_discriminator(input_image, input_label, weights)
71
+ return output_list, [features_list]
72
+
73
+ def forward(self, data, net_G_output, weights=None, incl_real=False, incl_pseudo_real=False):
74
+ r"""GANcraft discriminator forward.
75
+
76
+ Args:
77
+ data (dict):
78
+ - data (N x C1 x H x W tensor) : Ground truth images.
79
+ - label (N x C2 x H x W tensor) : Semantic representations.
80
+ - z (N x style_dims tensor): Gaussian random noise.
81
+ net_G_output (dict):
82
+ - fake_images (N x C1 x H x W tensor) : Fake images.
83
+ Returns:
84
+ output_x (dict):
85
+ - real_outputs (list): list of output tensors produced by
86
+ individual patch discriminators for real images.
87
+ - real_features (list): list of lists of features produced by
88
+ individual patch discriminators for real images.
89
+ - fake_outputs (list): list of output tensors produced by
90
+ individual patch discriminators for fake images.
91
+ - fake_features (list): list of lists of features produced by
92
+ individual patch discriminators for fake images.
93
+ """
94
+ output_x = dict()
95
+
96
+ # Fake.
97
+ fake_images = net_G_output['fake_images']
98
+ if self.use_label:
99
+ fake_labels = data['fake_masks']
100
+ else:
101
+ fake_labels = torch.zeros([fake_images.size(0), 2, fake_images.size(
102
+ 2), fake_images.size(3)], device=fake_images.device, dtype=fake_images.dtype)
103
+ fake_labels[:, 1, :, :] = 1
104
+ output_x['fake_outputs'], output_x['fake_features'] = \
105
+ self._single_forward(fake_labels, fake_images, None)
106
+
107
+ # Real.
108
+ if incl_real:
109
+ real_images = data['images']
110
+ if self.use_label:
111
+ real_labels = data['real_masks']
112
+ else:
113
+ real_labels = torch.zeros([real_images.size(0), 2, real_images.size(
114
+ 2), real_images.size(3)], device=real_images.device, dtype=real_images.dtype)
115
+ real_labels[:, 1, :, :] = 1
116
+ output_x['real_outputs'], output_x['real_features'] = \
117
+ self._single_forward(real_labels, real_images, None)
118
+
119
+ # pseudo-Real.
120
+ if incl_pseudo_real:
121
+ preal_images = data['pseudo_real_img']
122
+ preal_labels = data['fake_masks']
123
+ if not self.use_label:
124
+ preal_labels = torch.zeros([preal_images.size(0), 2, preal_images.size(
125
+ 2), preal_images.size(3)], device=preal_images.device, dtype=preal_images.dtype)
126
+ preal_labels[:, 1, :, :] = 1
127
+ output_x['pseudo_real_outputs'], output_x['pseudo_real_features'] = \
128
+ self._single_forward(preal_labels, preal_images, None)
129
+
130
+ return output_x
131
+
132
+
133
+ class FPSEDiscriminator(nn.Module):
134
+ def __init__(self,
135
+ num_input_channels,
136
+ num_labels,
137
+ num_filters,
138
+ kernel_size,
139
+ weight_norm_type,
140
+ activation_norm_type,
141
+ do_multiscale,
142
+ smooth_resample,
143
+ no_label_except_largest_scale):
144
+ super().__init__()
145
+
146
+ self.do_multiscale = do_multiscale
147
+ self.no_label_except_largest_scale = no_label_except_largest_scale
148
+
149
+ padding = int(np.ceil((kernel_size - 1.0) / 2))
150
+ nonlinearity = 'leakyrelu'
151
+ stride1_conv2d_block = \
152
+ functools.partial(Conv2dBlock,
153
+ kernel_size=kernel_size,
154
+ stride=1,
155
+ padding=padding,
156
+ weight_norm_type=weight_norm_type,
157
+ activation_norm_type=activation_norm_type,
158
+ nonlinearity=nonlinearity,
159
+ # inplace_nonlinearity=True,
160
+ order='CNA')
161
+ down_conv2d_block = \
162
+ functools.partial(Conv2dBlock,
163
+ kernel_size=kernel_size,
164
+ stride=2,
165
+ padding=padding,
166
+ weight_norm_type=weight_norm_type,
167
+ activation_norm_type=activation_norm_type,
168
+ nonlinearity=nonlinearity,
169
+ # inplace_nonlinearity=True,
170
+ order='CNA')
171
+ latent_conv2d_block = \
172
+ functools.partial(Conv2dBlock,
173
+ kernel_size=1,
174
+ stride=1,
175
+ weight_norm_type=weight_norm_type,
176
+ activation_norm_type=activation_norm_type,
177
+ nonlinearity=nonlinearity,
178
+ # inplace_nonlinearity=True,
179
+ order='CNA')
180
+ # bottom-up pathway
181
+ self.enc1 = down_conv2d_block(num_input_channels, num_filters) # 3
182
+ self.enc2 = down_conv2d_block(1 * num_filters, 2 * num_filters) # 7
183
+ self.enc3 = down_conv2d_block(2 * num_filters, 4 * num_filters) # 15
184
+ self.enc4 = down_conv2d_block(4 * num_filters, 8 * num_filters) # 31
185
+ self.enc5 = down_conv2d_block(8 * num_filters, 8 * num_filters) # 63
186
+
187
+ # top-down pathway
188
+ # self.lat1 = latent_conv2d_block(num_filters, 2 * num_filters) # Zekun
189
+ self.lat2 = latent_conv2d_block(2 * num_filters, 4 * num_filters)
190
+ self.lat3 = latent_conv2d_block(4 * num_filters, 4 * num_filters)
191
+ self.lat4 = latent_conv2d_block(8 * num_filters, 4 * num_filters)
192
+ self.lat5 = latent_conv2d_block(8 * num_filters, 4 * num_filters)
193
+
194
+ # upsampling
195
+ self.upsample2x = nn.Upsample(scale_factor=2, mode='bilinear',
196
+ align_corners=False)
197
+
198
+ # final layers
199
+ self.final2 = stride1_conv2d_block(4 * num_filters, 2 * num_filters)
200
+ self.output = Conv2dBlock(num_filters * 2, num_labels+1, kernel_size=1)
201
+
202
+ if self.do_multiscale:
203
+ self.final3 = stride1_conv2d_block(4 * num_filters, 2 * num_filters)
204
+ self.final4 = stride1_conv2d_block(4 * num_filters, 2 * num_filters)
205
+ if self.no_label_except_largest_scale:
206
+ self.output3 = Conv2dBlock(num_filters * 2, 2, kernel_size=1)
207
+ self.output4 = Conv2dBlock(num_filters * 2, 2, kernel_size=1)
208
+ else:
209
+ self.output3 = Conv2dBlock(num_filters * 2, num_labels+1, kernel_size=1)
210
+ self.output4 = Conv2dBlock(num_filters * 2, num_labels+1, kernel_size=1)
211
+
212
+ self.interpolator = functools.partial(F.interpolate, mode='nearest')
213
+ if smooth_resample:
214
+ self.interpolator = self.smooth_interp
215
+
216
+ @staticmethod
217
+ def smooth_interp(x, size):
218
+ r"""Smooth interpolation of segmentation maps.
219
+
220
+ Args:
221
+ x (4D tensor): Segmentation maps.
222
+ size(2D list): Target size (H, W).
223
+ """
224
+ x = F.interpolate(x, size=size, mode='area')
225
+ onehot_idx = torch.argmax(x, dim=-3, keepdims=True)
226
+ x.fill_(0.0)
227
+ x.scatter_(1, onehot_idx, 1.0)
228
+ return x
229
+
230
+ # Weights: [N C]
231
+ def forward(self, images, segmaps, weights=None):
232
+ # Assume images 256x256
233
+ # bottom-up pathway
234
+ feat11 = self.enc1(images) # 128
235
+ feat12 = self.enc2(feat11) # 64
236
+ feat13 = self.enc3(feat12) # 32
237
+ feat14 = self.enc4(feat13) # 16
238
+ feat15 = self.enc5(feat14) # 8
239
+ # top-down pathway and lateral connections
240
+ feat25 = self.lat5(feat15) # 8
241
+ feat24 = self.upsample2x(feat25) + self.lat4(feat14) # 16
242
+ feat23 = self.upsample2x(feat24) + self.lat3(feat13) # 32
243
+ feat22 = self.upsample2x(feat23) + self.lat2(feat12) # 64
244
+
245
+ # final prediction layers
246
+ feat32 = self.final2(feat22)
247
+
248
+ results = []
249
+ label_map = self.interpolator(segmaps, size=feat32.size()[2:])
250
+ pred2 = self.output(feat32) # N, num_labels+1, H//4, W//4
251
+
252
+ features = [feat11, feat12, feat13, feat14, feat15, feat25, feat24, feat23, feat22]
253
+ if weights is not None:
254
+ label_map = label_map * weights[..., None, None]
255
+ results.append({'pred': pred2, 'label': label_map})
256
+
257
+ if self.do_multiscale:
258
+ feat33 = self.final3(feat23)
259
+ pred3 = self.output3(feat33)
260
+
261
+ feat34 = self.final4(feat24)
262
+ pred4 = self.output4(feat34)
263
+
264
+ if self.no_label_except_largest_scale:
265
+ label_map3 = torch.ones([pred3.size(0), 1, pred3.size(2), pred3.size(3)], device=pred3.device)
266
+ label_map4 = torch.ones([pred4.size(0), 1, pred4.size(2), pred4.size(3)], device=pred4.device)
267
+ else:
268
+ label_map3 = self.interpolator(segmaps, size=pred3.size()[2:])
269
+ label_map4 = self.interpolator(segmaps, size=pred4.size()[2:])
270
+
271
+ if weights is not None:
272
+ label_map3 = label_map3 * weights[..., None, None]
273
+ label_map4 = label_map4 * weights[..., None, None]
274
+
275
+ results.append({'pred': pred3, 'label': label_map3})
276
+ results.append({'pred': pred4, 'label': label_map4})
277
+
278
+ return results, features
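
Each entry of `results` pairs a (num_labels + 1)-channel prediction with the label map resampled to the same resolution, which is what makes an N+1-label objective (per-pixel semantic class plus a dedicated "fake" class) possible. The sketch below shows one plausible way such pairs could be consumed; it is an illustration under stated assumptions, not necessarily the loss used by this repository.

```python
# Hedged sketch: one way the per-scale {'pred', 'label'} pairs returned by
# FPSEDiscriminator.forward could feed an N+1-label hinge objective.
# Assumption: channel 0 of `pred` is the dedicated "fake" class and channels
# 1..num_labels line up with the one-hot `label` map. Illustration only.
import torch.nn.functional as F


def n_plus_one_hinge_dis_loss(real_results, fake_results):
    loss = 0.0
    for real, fake in zip(real_results, fake_results):
        real_pred, real_label = real['pred'], real['label']   # [N, K+1, h, w], [N, K, h, w]
        fake_pred = fake['pred']
        # Real pixels should score highly on their own semantic class...
        real_score = (real_pred[:, 1:] * real_label).sum(dim=1)
        loss = loss + F.relu(1.0 - real_score).mean()
        # ...while fake pixels should be pushed toward the dedicated fake class.
        loss = loss + F.relu(1.0 - fake_pred[:, 0]).mean()
    return loss
```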
imaginaire/generators/__init__.py ADDED
@@ -0,0 +1,4 @@
+ # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ #
+ # This work is made available under the Nvidia Source Code License-NC.
+ # To view a copy of this license, check out LICENSE.md
imaginaire/generators/gancraft_base.py ADDED
@@ -0,0 +1,603 @@
1
+ # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, check out LICENSE.md
5
+ import functools
6
+ import re
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from imaginaire.layers import Conv2dBlock, LinearBlock
14
+ from imaginaire.model_utils.layers import AffineMod, ModLinear
15
+ import imaginaire.model_utils.gancraft.mc_utils as mc_utils
16
+ import imaginaire.model_utils.gancraft.voxlib as voxlib
17
+ from imaginaire.utils.distributed import master_only_print as print
18
+
19
+
20
+ class RenderMLP(nn.Module):
21
+ r""" MLP with affine modulation."""
22
+
23
+ def __init__(self, in_channels, style_dim, viewdir_dim, mask_dim=680,
24
+ out_channels_s=1, out_channels_c=3, hidden_channels=256,
25
+ use_seg=True):
26
+ super(RenderMLP, self).__init__()
27
+
28
+ self.use_seg = use_seg
29
+ if self.use_seg:
30
+ self.fc_m_a = nn.Linear(mask_dim, hidden_channels, bias=False)
31
+
32
+ self.fc_viewdir = None
33
+ if viewdir_dim > 0:
34
+ self.fc_viewdir = nn.Linear(viewdir_dim, hidden_channels, bias=False)
35
+
36
+ self.fc_1 = nn.Linear(in_channels, hidden_channels)
37
+
38
+ self.fc_2 = ModLinear(hidden_channels, hidden_channels, style_dim, bias=False, mod_bias=True, output_mode=True)
39
+ self.fc_3 = ModLinear(hidden_channels, hidden_channels, style_dim, bias=False, mod_bias=True, output_mode=True)
40
+ self.fc_4 = ModLinear(hidden_channels, hidden_channels, style_dim, bias=False, mod_bias=True, output_mode=True)
41
+
42
+ self.fc_sigma = nn.Linear(hidden_channels, out_channels_s)
43
+
44
+ if viewdir_dim > 0:
45
+ self.fc_5 = nn.Linear(hidden_channels, hidden_channels, bias=False)
46
+ self.mod_5 = AffineMod(hidden_channels, style_dim, mod_bias=True)
47
+ else:
48
+ self.fc_5 = ModLinear(hidden_channels, hidden_channels, style_dim,
49
+ bias=False, mod_bias=True, output_mode=True)
50
+ self.fc_6 = ModLinear(hidden_channels, hidden_channels, style_dim, bias=False, mod_bias=True, output_mode=True)
51
+ self.fc_out_c = nn.Linear(hidden_channels, out_channels_c)
52
+
53
+ self.act = nn.LeakyReLU(negative_slope=0.2)
54
+
55
+ def forward(self, x, raydir, z, m):
56
+ r""" Forward network
57
+
58
+ Args:
59
+ x (N x H x W x M x in_channels tensor): Projected features.
60
+ raydir (N x H x W x 1 x viewdir_dim tensor): Ray directions.
61
+ z (N x style_dim tensor): Style codes.
62
+ m (N x H x W x M x mask_dim tensor): One-hot segmentation maps.
63
+ """
64
+ b, h, w, n, _ = x.size()
65
+ z = z[:, None, None, None, :]
66
+
67
+ f = self.fc_1(x)
68
+ if self.use_seg:
69
+ f = f + self.fc_m_a(m)
70
+ # Common MLP
71
+ f = self.act(f)
72
+ f = self.act(self.fc_2(f, z))
73
+ f = self.act(self.fc_3(f, z))
74
+ f = self.act(self.fc_4(f, z))
75
+
76
+ # Sigma MLP
77
+ sigma = self.fc_sigma(f)
78
+
79
+ # Color MLP
80
+ if self.fc_viewdir is not None:
81
+ f = self.fc_5(f)
82
+ f = f + self.fc_viewdir(raydir)
83
+ f = self.act(self.mod_5(f, z))
84
+ else:
85
+ f = self.act(self.fc_5(f, z))
86
+ f = self.act(self.fc_6(f, z))
87
+ c = self.fc_out_c(f)
88
+ return sigma, c
89
+
90
+
91
+ class StyleMLP(nn.Module):
92
+ r"""MLP converting style code to intermediate style representation."""
93
+
94
+ def __init__(self, style_dim, out_dim, hidden_channels=256, leaky_relu=True, num_layers=5, normalize_input=True,
95
+ output_act=True):
96
+ super(StyleMLP, self).__init__()
97
+
98
+ self.normalize_input = normalize_input
99
+ self.output_act = output_act
100
+ fc_layers = []
101
+ fc_layers.append(nn.Linear(style_dim, hidden_channels, bias=True))
102
+ for i in range(num_layers-1):
103
+ fc_layers.append(nn.Linear(hidden_channels, hidden_channels, bias=True))
104
+ self.fc_layers = nn.ModuleList(fc_layers)
105
+
106
+ self.fc_out = nn.Linear(hidden_channels, out_dim, bias=True)
107
+
108
+ if leaky_relu:
109
+ self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
110
+ else:
111
+ self.act = functools.partial(F.relu, inplace=True)
112
+
113
+ def forward(self, z):
114
+ r""" Forward network
115
+
116
+ Args:
117
+ z (N x style_dim tensor): Style codes.
118
+ """
119
+ if self.normalize_input:
120
+ z = F.normalize(z, p=2, dim=-1)
121
+ for fc_layer in self.fc_layers:
122
+ z = self.act(fc_layer(z))
123
+ z = self.fc_out(z)
124
+ if self.output_act:
125
+ z = self.act(z)
126
+ return z
127
+
128
+
129
+ class SKYMLP(nn.Module):
130
+ r"""MLP converting ray directions to sky features."""
131
+
132
+ def __init__(self, in_channels, style_dim, out_channels_c=3,
133
+ hidden_channels=256, leaky_relu=True):
134
+ super(SKYMLP, self).__init__()
135
+ self.fc_z_a = nn.Linear(style_dim, hidden_channels, bias=False)
136
+
137
+ self.fc1 = nn.Linear(in_channels, hidden_channels)
138
+ self.fc2 = nn.Linear(hidden_channels, hidden_channels)
139
+ self.fc3 = nn.Linear(hidden_channels, hidden_channels)
140
+ self.fc4 = nn.Linear(hidden_channels, hidden_channels)
141
+ self.fc5 = nn.Linear(hidden_channels, hidden_channels)
142
+
143
+ self.fc_out_c = nn.Linear(hidden_channels, out_channels_c)
144
+
145
+ if leaky_relu:
146
+ self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
147
+ else:
148
+ self.act = functools.partial(F.relu, inplace=True)
149
+
150
+ def forward(self, x, z):
151
+ r"""Forward network
152
+
153
+ Args:
154
+ x (... x in_channels tensor): Ray direction embeddings.
155
+ z (... x style_dim tensor): Style codes.
156
+ """
157
+
158
+ z = self.fc_z_a(z)
159
+ while z.dim() < x.dim():
160
+ z = z.unsqueeze(1)
161
+
162
+ y = self.act(self.fc1(x) + z)
163
+ y = self.act(self.fc2(y))
164
+ y = self.act(self.fc3(y))
165
+ y = self.act(self.fc4(y))
166
+ y = self.act(self.fc5(y))
167
+ c = self.fc_out_c(y)
168
+
169
+ return c
170
+
171
+
172
+ class RenderCNN(nn.Module):
173
+ r"""CNN converting intermediate feature map to final image."""
174
+
175
+ def __init__(self, in_channels, style_dim, hidden_channels=256,
176
+ leaky_relu=True):
177
+ super(RenderCNN, self).__init__()
178
+ self.fc_z_cond = nn.Linear(style_dim, 2 * 2 * hidden_channels)
179
+
180
+ self.conv1 = nn.Conv2d(in_channels, hidden_channels, 1, stride=1, padding=0)
181
+ self.conv2a = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1)
182
+ self.conv2b = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1, bias=False)
183
+
184
+ self.conv3a = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1)
185
+ self.conv3b = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1, bias=False)
186
+
187
+ self.conv4a = nn.Conv2d(hidden_channels, hidden_channels, 1, stride=1, padding=0)
188
+ self.conv4b = nn.Conv2d(hidden_channels, hidden_channels, 1, stride=1, padding=0)
189
+
190
+ self.conv4 = nn.Conv2d(hidden_channels, 3, 1, stride=1, padding=0)
191
+
192
+ if leaky_relu:
193
+ self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
194
+ else:
195
+ self.act = functools.partial(F.relu, inplace=True)
196
+
197
+ def modulate(self, x, w, b):
198
+ w = w[..., None, None]
199
+ b = b[..., None, None]
200
+ return x * (w+1) + b
201
+
202
+ def forward(self, x, z):
203
+ r"""Forward network.
204
+
205
+ Args:
206
+ x (N x in_channels x H x W tensor): Intermediate feature map
207
+ z (N x style_dim tensor): Style codes.
208
+ """
209
+ z = self.fc_z_cond(z)
210
+ adapt = torch.chunk(z, 2 * 2, dim=-1)
211
+
212
+ y = self.act(self.conv1(x))
213
+
214
+ y = y + self.conv2b(self.act(self.conv2a(y)))
215
+ y = self.act(self.modulate(y, adapt[0], adapt[1]))
216
+
217
+ y = y + self.conv3b(self.act(self.conv3a(y)))
218
+ y = self.act(self.modulate(y, adapt[2], adapt[3]))
219
+
220
+ y = y + self.conv4b(self.act(self.conv4a(y)))
221
+ y = self.act(y)
222
+
223
+ y = self.conv4(y)
224
+
225
+ return y
226
+
227
+
228
+ class StyleEncoder(nn.Module):
229
+ r"""Style Encoder constructor.
230
+
231
+ Args:
232
+ style_enc_cfg (obj): Style encoder definition file.
233
+ """
234
+
235
+ def __init__(self, style_enc_cfg):
236
+ super(StyleEncoder, self).__init__()
237
+ input_image_channels = style_enc_cfg.input_image_channels
238
+ num_filters = style_enc_cfg.num_filters
239
+ kernel_size = style_enc_cfg.kernel_size
240
+ padding = int(np.ceil((kernel_size - 1.0) / 2))
241
+ style_dims = style_enc_cfg.style_dims
242
+ weight_norm_type = style_enc_cfg.weight_norm_type
243
+ self.no_vae = getattr(style_enc_cfg, 'no_vae', False)
244
+ activation_norm_type = 'none'
245
+ nonlinearity = 'leakyrelu'
246
+ base_conv2d_block = \
247
+ functools.partial(Conv2dBlock,
248
+ kernel_size=kernel_size,
249
+ stride=2,
250
+ padding=padding,
251
+ weight_norm_type=weight_norm_type,
252
+ activation_norm_type=activation_norm_type,
253
+ # inplace_nonlinearity=True,
254
+ nonlinearity=nonlinearity)
255
+ self.layer1 = base_conv2d_block(input_image_channels, num_filters)
256
+ self.layer2 = base_conv2d_block(num_filters * 1, num_filters * 2)
257
+ self.layer3 = base_conv2d_block(num_filters * 2, num_filters * 4)
258
+ self.layer4 = base_conv2d_block(num_filters * 4, num_filters * 8)
259
+ self.layer5 = base_conv2d_block(num_filters * 8, num_filters * 8)
260
+ self.layer6 = base_conv2d_block(num_filters * 8, num_filters * 8)
261
+ self.fc_mu = LinearBlock(num_filters * 8 * 4 * 4, style_dims)
262
+ if not self.no_vae:
263
+ self.fc_var = LinearBlock(num_filters * 8 * 4 * 4, style_dims)
264
+
265
+ def forward(self, input_x):
266
+ r"""SPADE Style Encoder forward.
267
+
268
+ Args:
269
+ input_x (N x 3 x H x W tensor): input images.
270
+ Returns:
271
+ mu (N x C tensor): Mean vectors.
272
+ logvar (N x C tensor): Log-variance vectors.
273
+ z (N x C tensor): Style code vectors.
274
+ """
275
+ if input_x.size(2) != 256 or input_x.size(3) != 256:
276
+ input_x = F.interpolate(input_x, size=(256, 256), mode='bilinear')
277
+ x = self.layer1(input_x)
278
+ x = self.layer2(x)
279
+ x = self.layer3(x)
280
+ x = self.layer4(x)
281
+ x = self.layer5(x)
282
+ x = self.layer6(x)
283
+ x = x.view(x.size(0), -1)
284
+ mu = self.fc_mu(x)
285
+ if not self.no_vae:
286
+ logvar = self.fc_var(x)
287
+ std = torch.exp(0.5 * logvar)
288
+ eps = torch.randn_like(std)
289
+ z = eps.mul(std) + mu
290
+ else:
291
+ z = mu
292
+ logvar = torch.zeros_like(mu)
293
+ return mu, logvar, z
294
+
295
+
296
+ class Base3DGenerator(nn.Module):
297
+ r"""Minecraft 3D generator constructor.
298
+
299
+ Args:
300
+ gen_cfg (obj): Generator definition part of the yaml config file.
301
+ data_cfg (obj): Data definition part of the yaml config file.
302
+ """
303
+
304
+ def __init__(self, gen_cfg, data_cfg):
305
+ super(Base3DGenerator, self).__init__()
306
+ print('Base3DGenerator initialization.')
307
+
308
+ # ---------------------- Main Network ------------------------
309
+ # Exclude some of the features from positional encoding
310
+ self.pe_no_pe_feat_dim = getattr(gen_cfg, 'pe_no_pe_feat_dim', 0)
311
+
312
+ # blk_feat passes through PE
313
+ input_dim = (gen_cfg.blk_feat_dim-self.pe_no_pe_feat_dim)*(gen_cfg.pe_lvl_feat*2) + self.pe_no_pe_feat_dim
314
+ if (gen_cfg.pe_incl_orig_feat):
315
+ input_dim += (gen_cfg.blk_feat_dim-self.pe_no_pe_feat_dim)
316
+ print('[Base3DGenerator] Expected input dimensions: ', input_dim)
317
+ self.input_dim = input_dim
318
+
319
+ self.mlp_model_kwargs = gen_cfg.mlp_model_kwargs
320
+ self.pe_lvl_localcoords = getattr(gen_cfg, 'pe_lvl_localcoords', 0)
321
+ if self.pe_lvl_localcoords > 0:
322
+ self.mlp_model_kwargs['poscode_dim'] = self.pe_lvl_localcoords * 2 * 3
323
+
324
+ # Set pe_lvl_raydir=0 and pe_incl_orig_raydir=False to disable view direction input
325
+ input_dim_viewdir = 3*(gen_cfg.pe_lvl_raydir*2)
326
+ if (gen_cfg.pe_incl_orig_raydir):
327
+ input_dim_viewdir += 3
328
+ print('[Base3DGenerator] Expected viewdir input dimensions: ', input_dim_viewdir)
329
+ self.input_dim_viewdir = input_dim_viewdir
330
+
331
+ self.pe_params = [gen_cfg.pe_lvl_feat, gen_cfg.pe_incl_orig_feat,
332
+ gen_cfg.pe_lvl_raydir, gen_cfg.pe_incl_orig_raydir]
333
+
334
+ # Style input dimension
335
+ style_dims = gen_cfg.style_dims
336
+ self.style_dims = style_dims
337
+ interm_style_dims = getattr(gen_cfg, 'interm_style_dims', style_dims)
338
+ self.interm_style_dims = interm_style_dims
339
+ # ---------------------- Style MLP --------------------------
340
+ self.style_net = globals()[gen_cfg.stylenet_model](
341
+ style_dims, interm_style_dims, **gen_cfg.stylenet_model_kwargs)
342
+
343
+ # number of output channels for MLP (before blending)
344
+ final_feat_dim = getattr(gen_cfg, 'final_feat_dim', 16)
345
+ self.final_feat_dim = final_feat_dim
346
+
347
+ # ----------------------- Sky Network -------------------------
348
+ sky_input_dim_base = 3
349
+ # Dedicated sky network input dimensions
350
+ sky_input_dim = sky_input_dim_base*(gen_cfg.pe_lvl_raydir_sky*2)
351
+ if (gen_cfg.pe_incl_orig_raydir_sky):
352
+ sky_input_dim += sky_input_dim_base
353
+ print('[Base3DGenerator] Expected sky input dimensions: ', sky_input_dim)
354
+ self.pe_params_sky = [gen_cfg.pe_lvl_raydir_sky, gen_cfg.pe_incl_orig_raydir_sky]
355
+ self.sky_net = SKYMLP(sky_input_dim, style_dim=interm_style_dims, out_channels_c=final_feat_dim)
356
+
357
+ # ----------------------- Style Encoder -------------------------
358
+ style_enc_cfg = getattr(gen_cfg, 'style_enc', None)
359
+ setattr(style_enc_cfg, 'input_image_channels', 3)
360
+ setattr(style_enc_cfg, 'style_dims', gen_cfg.style_dims)
361
+ self.style_encoder = StyleEncoder(style_enc_cfg)
362
+
363
+ # ---------------------- Ray Caster -------------------------
364
+ self.num_blocks_early_stop = gen_cfg.num_blocks_early_stop
365
+ self.num_samples = gen_cfg.num_samples
366
+ self.sample_depth = gen_cfg.sample_depth
367
+ self.coarse_deterministic_sampling = getattr(gen_cfg, 'coarse_deterministic_sampling', True)
368
+ self.sample_use_box_boundaries = getattr(gen_cfg, 'sample_use_box_boundaries', True)
369
+
370
+ # ---------------------- Blender -------------------------
371
+ self.raw_noise_std = getattr(gen_cfg, 'raw_noise_std', 0.0)
372
+ self.dists_scale = getattr(gen_cfg, 'dists_scale', 0.25)
373
+ self.clip_feat_map = getattr(gen_cfg, 'clip_feat_map', True)
374
+ self.keep_sky_out = getattr(gen_cfg, 'keep_sky_out', False)
375
+ self.keep_sky_out_avgpool = getattr(gen_cfg, 'keep_sky_out_avgpool', False)
376
+ keep_sky_out_learnbg = getattr(gen_cfg, 'keep_sky_out_learnbg', False)
377
+ self.sky_global_avgpool = getattr(gen_cfg, 'sky_global_avgpool', False)
378
+ if self.keep_sky_out:
379
+ self.sky_replace_color = None
380
+ if keep_sky_out_learnbg:
381
+ sky_replace_color = torch.zeros([final_feat_dim])
382
+ sky_replace_color.requires_grad = True
383
+ self.sky_replace_color = torch.nn.Parameter(sky_replace_color)
384
+ # ---------------------- render_cnn -------------------------
385
+ self.denoiser = RenderCNN(final_feat_dim, style_dim=interm_style_dims)
386
+ self.pad = gen_cfg.pad
387
+
388
+ def get_param_groups(self, cfg_opt):
389
+ print('[Generator] get_param_groups')
390
+
391
+ if hasattr(cfg_opt, 'ignore_parameters'):
392
+ print('[Generator::get_param_groups] [x]: ignored.')
393
+ optimize_parameters = []
394
+ for k, x in self.named_parameters():
395
+ match = False
396
+ for m in cfg_opt.ignore_parameters:
397
+ if re.match(m, k) is not None:
398
+ match = True
399
+ print(' [x]', k)
400
+ break
401
+ if match is False:
402
+ print(' [v]', k)
403
+ optimize_parameters.append(x)
404
+ else:
405
+ optimize_parameters = self.parameters()
406
+
407
+ param_groups = []
408
+ param_groups.append({'params': optimize_parameters})
409
+
410
+ if hasattr(cfg_opt, 'param_groups'):
411
+ optimized_param_names = []
412
+ all_param_names = [k for k, v in self.named_parameters()]
413
+ param_groups = []
414
+ for k, v in cfg_opt.param_groups.items():
415
+ print('[Generator::get_param_groups] Adding param group from config:', k, v)
416
+ params = getattr(self, k)
417
+ named_parameters = [k]
418
+ if issubclass(type(params), nn.Module):
419
+ named_parameters = [k+'.'+pname for pname, _ in params.named_parameters()]
420
+ params = params.parameters()
421
+ param_groups.append({'params': params, **v})
422
+ optimized_param_names.extend(named_parameters)
423
+
424
+ print('[Generator::get_param_groups] UNOPTIMIZED PARAMETERS:\n ',
425
+ set(all_param_names) - set(optimized_param_names))
426
+
427
+ return param_groups
428
+
429
+ def _forward_perpix_sub(self, blk_feats, worldcoord2, raydirs_in, z, mc_masks_onehot=None):
430
+ r"""Forwarding the MLP.
431
+
432
+ Args:
433
+ blk_feats (K x C1 tensor): Sparse block features.
434
+ worldcoord2 (N x H x W x L x 3 tensor): 3D world coordinates of sampled points. L is number of samples; N is batch size, always 1.
435
+ raydirs_in (N x H x W x 1 x C2 tensor or None): ray direction embeddings.
436
+ z (N x C3 tensor): Intermediate style vectors.
437
+ mc_masks_onehot (N x H x W x L x C4): One-hot segmentation maps.
438
+ Returns:
439
+ net_out_s (N x H x W x L x 1 tensor): Opacities.
440
+ net_out_c (N x H x W x L x C5 tensor): Color embeddings.
441
+ """
442
+ proj_feature = voxlib.sparse_trilinear_interp_worldcoord(
443
+ blk_feats, self.voxel.corner_t, worldcoord2, ign_zero=True)
444
+
445
+ render_net_extra_kwargs = {}
446
+ if self.pe_lvl_localcoords > 0:
447
+ local_coords = torch.remainder(worldcoord2, 1.0) * 2.0
448
+ # Scale to [0, 2], as the positional encoding function doesn't have internal x2
449
+ local_coords[torch.isnan(local_coords)] = 0.0
450
+ local_coords = local_coords.contiguous()
451
+ poscode = voxlib.positional_encoding(local_coords, self.pe_lvl_localcoords, -1, False)
452
+ render_net_extra_kwargs['poscode'] = poscode
453
+
454
+ if self.pe_params[0] == 0 and self.pe_params[1] is True: # no PE shortcut, saves ~400MB
455
+ feature_in = proj_feature
456
+ else:
457
+ if self.pe_no_pe_feat_dim > 0:
458
+ feature_in = voxlib.positional_encoding(
459
+ proj_feature[..., :-self.pe_no_pe_feat_dim].contiguous(), self.pe_params[0], -1, self.pe_params[1])
460
+ feature_in = torch.cat([feature_in, proj_feature[..., -self.pe_no_pe_feat_dim:]], dim=-1)
461
+ else:
462
+ feature_in = voxlib.positional_encoding(
463
+ proj_feature.contiguous(), self.pe_params[0], -1, self.pe_params[1])
464
+
465
+ net_out_s, net_out_c = self.render_net(feature_in, raydirs_in, z, mc_masks_onehot, **render_net_extra_kwargs)
466
+
467
+ if self.raw_noise_std > 0.:
468
+ noise = torch.randn_like(net_out_s) * self.raw_noise_std
469
+ net_out_s = net_out_s + noise
470
+
471
+ return net_out_s, net_out_c
472
+
473
+ def _forward_perpix(self, blk_feats, voxel_id, depth2, raydirs, cam_ori_t, z):
474
+ r"""Sample points along rays, forwarding the per-point MLP and aggregate pixel features
475
+
476
+ Args:
477
+ blk_feats (K x C1 tensor): Sparse block features.
478
+ voxel_id (N x H x W x M x 1 tensor): Voxel ids from ray-voxel intersection test. M: num intersected voxels, why always 6?
479
+ depth2 (N x 2 x H x W x M x 1 tensor): Depths of entrance and exit points for each ray-voxel intersection.
480
+ raydirs (N x H x W x 1 x 3 tensor): The direction of each ray.
481
+ cam_ori_t (N x 3 tensor): Camera origins.
482
+ z (N x C3 tensor): Intermediate style vectors.
483
+ """
484
+ # Generate sky_mask; PE transform on ray direction.
485
+ with torch.no_grad():
486
+ raydirs_in = raydirs.expand(-1, -1, -1, 1, -1).contiguous()
487
+ if self.pe_params[2] == 0 and self.pe_params[3] is True:
488
+ raydirs_in = raydirs_in
489
+ elif self.pe_params[2] == 0 and self.pe_params[3] is False: # Not using raydir at all
490
+ raydirs_in = None
491
+ else:
492
+ raydirs_in = voxlib.positional_encoding(raydirs_in, self.pe_params[2], -1, self.pe_params[3])
493
+
494
+ # sky_mask: when True, ray finally hits sky
495
+ sky_mask = voxel_id[:, :, :, [-1], :] == 0
496
+ # sky_only_mask: when True, ray hits nothing but sky
497
+ sky_only_mask = voxel_id[:, :, :, [0], :] == 0
498
+
499
+ with torch.no_grad():
500
+ # Random sample points along the ray
501
+ num_samples = self.num_samples + 1
502
+ if self.sample_use_box_boundaries:
503
+ num_samples = self.num_samples - self.num_blocks_early_stop
504
+
505
+ # 10 samples per ray + 4 intersections - 2
506
+ rand_depth, new_dists, new_idx = mc_utils.sample_depth_batched(
507
+ depth2, num_samples, deterministic=self.coarse_deterministic_sampling,
508
+ use_box_boundaries=self.sample_use_box_boundaries, sample_depth=self.sample_depth)
509
+
510
+ worldcoord2 = raydirs * rand_depth + cam_ori_t[:, None, None, None, :]
511
+
512
+ # Generate per-sample segmentation label
513
+ voxel_id_reduced = self.label_trans.mc2reduced(voxel_id, ign2dirt=True)
514
+ mc_masks = torch.gather(voxel_id_reduced, -2, new_idx) # B 256 256 N 1
515
+ mc_masks = mc_masks.long()
516
+ mc_masks_onehot = torch.zeros([mc_masks.size(0), mc_masks.size(1), mc_masks.size(
517
+ 2), mc_masks.size(3), self.num_reduced_labels], dtype=torch.float, device=voxel_id.device)
518
+ # mc_masks_onehot: [B H W Nlayer 680]
519
+ mc_masks_onehot.scatter_(-1, mc_masks, 1.0)
520
+
521
+ net_out_s, net_out_c = self._forward_perpix_sub(blk_feats, worldcoord2, raydirs_in, z, mc_masks_onehot)
522
+
523
+ # Handle sky
524
+ sky_raydirs_in = raydirs.expand(-1, -1, -1, 1, -1).contiguous()
525
+ sky_raydirs_in = voxlib.positional_encoding(sky_raydirs_in, self.pe_params_sky[0], -1, self.pe_params_sky[1])
526
+ skynet_out_c = self.sky_net(sky_raydirs_in, z)
527
+
528
+ # Blending
529
+ weights = mc_utils.volum_rendering_relu(net_out_s, new_dists * self.dists_scale, dim=-2)
530
+
531
+ # If a ray exclusively hits the sky (no intersection with the voxels), set its weight to zero.
532
+ weights = weights * torch.logical_not(sky_only_mask).float()
533
+ total_weights_raw = torch.sum(weights, dim=-2, keepdim=True) # 256 256 1 1
534
+ total_weights = total_weights_raw
535
+
536
+ is_gnd = worldcoord2[..., [0]] <= 1.0 # Y X Z, [256, 256, 4, 3], nan < 1.0 == False
537
+ is_gnd = is_gnd.any(dim=-2, keepdim=True)
538
+ nosky_mask = torch.logical_or(torch.logical_not(sky_mask), is_gnd)
539
+ nosky_mask = nosky_mask.float()
540
+
541
+ # Avoid sky leakage
542
+ sky_weight = 1.0-total_weights
543
+ if self.keep_sky_out:
544
+ # keep_sky_out_avgpool overrides sky_replace_color
545
+ if self.sky_replace_color is None or self.keep_sky_out_avgpool:
546
+ if self.keep_sky_out_avgpool:
547
+ if hasattr(self, 'sky_avg'):
548
+ sky_avg = self.sky_avg
549
+ else:
550
+ if self.sky_global_avgpool:
551
+ sky_avg = torch.mean(skynet_out_c, dim=[1, 2], keepdim=True)
552
+ else:
553
+ skynet_out_c_nchw = skynet_out_c.permute(0, 4, 1, 2, 3).squeeze(-1).contiguous()
554
+ sky_avg = F.avg_pool2d(skynet_out_c_nchw, 31, stride=1, padding=15, count_include_pad=False)
555
+ sky_avg = sky_avg.permute(0, 2, 3, 1).unsqueeze(-2).contiguous()
556
+ # print(sky_avg.shape)
557
+ skynet_out_c = skynet_out_c * (1.0-nosky_mask) + sky_avg*(nosky_mask)
558
+ else:
559
+ sky_weight = sky_weight * (1.0-nosky_mask)
560
+ else:
561
+ skynet_out_c = skynet_out_c * (1.0-nosky_mask) + self.sky_replace_color*(nosky_mask)
562
+
563
+ if self.clip_feat_map is True: # intermediate feature before blending & CNN
564
+ rgbs = torch.clamp(net_out_c, -1, 1) + 1
565
+ rgbs_sky = torch.clamp(skynet_out_c, -1, 1) + 1
566
+ net_out = torch.sum(weights*rgbs, dim=-2, keepdim=True) + sky_weight * \
567
+ rgbs_sky # 576, 768, 4, 3 -> 576, 768, 3
568
+ net_out = net_out.squeeze(-2)
569
+ net_out = net_out - 1
570
+ elif self.clip_feat_map is False:
571
+ rgbs = net_out_c
572
+ rgbs_sky = skynet_out_c
573
+ net_out = torch.sum(weights*rgbs, dim=-2, keepdim=True) + sky_weight * \
574
+ rgbs_sky # 576, 768, 4, 3 -> 576, 768, 3
575
+ net_out = net_out.squeeze(-2)
576
+ elif self.clip_feat_map == 'tanh':
577
+ rgbs = torch.tanh(net_out_c)
578
+ rgbs_sky = torch.tanh(skynet_out_c)
579
+ net_out = torch.sum(weights*rgbs, dim=-2, keepdim=True) + sky_weight * \
580
+ rgbs_sky # 576, 768, 4, 3 -> 576, 768, 3
581
+ net_out = net_out.squeeze(-2)
582
+ else:
583
+ raise NotImplementedError
584
+
585
+ return net_out, new_dists, weights, total_weights_raw, rand_depth, net_out_s, net_out_c, skynet_out_c, \
586
+ nosky_mask, sky_mask, sky_only_mask, new_idx
587
+
588
+ def _forward_global(self, net_out, z):
589
+ r"""Forward the CNN
590
+
591
+ Args:
592
+ net_out (N x C5 x H x W tensor): Intermediate feature maps.
593
+ z (N x C3 tensor): Intermediate style vectors.
594
+
595
+ Returns:
596
+ fake_images (N x 3 x H x W tensor): Output image.
597
+ fake_images_raw (N x 3 x H x W tensor): Output image before TanH.
598
+ """
599
+ fake_images = net_out.permute(0, 3, 1, 2).contiguous()
600
+ fake_images_raw = self.denoiser(fake_images, z)
601
+ fake_images = torch.tanh(fake_images_raw)
602
+
603
+ return fake_images, fake_images_raw
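For reference, the blending above relies on mc_utils.volum_rendering_relu to turn per-sample densities into compositing weights. A minimal, self-contained sketch of that kind of weight computation (an assumption about its behavior, not the repository's implementation):

import torch

def volume_rendering_weights(sigma, dists, dim=-2):
    # Non-negative densities -> per-sample alphas: alpha = 1 - exp(-relu(sigma) * dist).
    alpha = 1.0 - torch.exp(-torch.relu(sigma) * dists)
    # Transmittance: product of (1 - alpha) over all earlier samples along the ray.
    ones = torch.ones_like(alpha.narrow(dim, 0, 1))
    trans = torch.cumprod(torch.cat([ones, 1.0 - alpha + 1e-10], dim=dim), dim=dim)
    trans = trans.narrow(dim, 0, alpha.size(dim))
    return alpha * trans  # weights along `dim`; their sum is the accumulated opacity

weights = volume_rendering_weights(torch.randn(1, 4, 4, 10, 1), torch.rand(1, 4, 4, 10, 1))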
imaginaire/generators/scenedreamer.py ADDED
@@ -0,0 +1,851 @@
1
+ # Using Hashgrid as backbone representation
2
+
3
+ import os
4
+ import cv2
5
+ import imageio
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ import imaginaire.model_utils.gancraft.camctl as camctl
12
+ import imaginaire.model_utils.gancraft.mc_utils as mc_utils
13
+ import imaginaire.model_utils.gancraft.voxlib as voxlib
14
+ from imaginaire.model_utils.pcg_gen import PCGVoxelGenerator, PCGCache
15
+ from imaginaire.utils.distributed import master_only_print as print
16
+ from imaginaire.generators.gancraft_base import Base3DGenerator
17
+ from encoding import get_encoder
18
+
19
+ from imaginaire.model_utils.layers import LightningMLP, ConditionalHashGrid
20
+
21
+ class Generator(Base3DGenerator):
22
+ r"""SceneDreamer generator constructor.
23
+
24
+ Args:
25
+ gen_cfg (obj): Generator definition part of the yaml config file.
26
+ data_cfg (obj): Data definition part of the yaml config file.
27
+ """
28
+
29
+ def __init__(self, gen_cfg, data_cfg):
30
+ super(Generator, self).__init__(gen_cfg, data_cfg)
31
+ print('SceneDreamer[Hash] on ALL Scenes generator initialization.')
32
+
33
+ # here should be a list of height maps and semantic maps
34
+ if gen_cfg.pcg_cache:
35
+ print('[Generator] Loading PCG dataset: ', gen_cfg.pcg_dataset_path)
36
+ self.voxel = PCGCache(gen_cfg.pcg_dataset_path)
37
+ print('[Generator] Loaded PCG dataset.')
38
+ else:
39
+ self.voxel = PCGVoxelGenerator(gen_cfg.scene_size)
40
+ self.blk_feats = None
41
+ # Minecraft -> SPADE label translator.
42
+ self.label_trans = mc_utils.MCLabelTranslator()
43
+ self.num_reduced_labels = self.label_trans.get_num_reduced_lbls()
44
+ self.reduced_label_set = getattr(gen_cfg, 'reduced_label_set', False)
45
+ self.use_label_smooth = getattr(gen_cfg, 'use_label_smooth', False)
46
+ self.use_label_smooth_real = getattr(gen_cfg, 'use_label_smooth_real', self.use_label_smooth)
47
+ self.use_label_smooth_pgt = getattr(gen_cfg, 'use_label_smooth_pgt', False)
48
+ self.label_smooth_dia = getattr(gen_cfg, 'label_smooth_dia', 11)
49
+
50
+ # Load MLP model.
51
+ self.hash_encoder, self.hash_in_dim = get_encoder(encoding='hashgrid', input_dim=5, desired_resolution=2048 * 1, level_dim=8)
52
+ self.render_net = LightningMLP(self.hash_in_dim, viewdir_dim=self.input_dim_viewdir, style_dim=self.interm_style_dims, mask_dim=self.num_reduced_labels, out_channels_s=1, out_channels_c=self.final_feat_dim, **self.mlp_model_kwargs)
53
+ print(self.hash_encoder)
54
+ self.world_encoder = ConditionalHashGrid()
55
+
56
+ # Camera sampler.
57
+ self.camera_sampler_type = getattr(gen_cfg, 'camera_sampler_type', "random")
58
+ assert self.camera_sampler_type in ['random', 'traditional']
59
+ self.camera_min_entropy = getattr(gen_cfg, 'camera_min_entropy', -1)
60
+ self.camera_rej_avg_depth = getattr(gen_cfg, 'camera_rej_avg_depth', -1)
61
+ self.cam_res = gen_cfg.cam_res
62
+ self.crop_size = gen_cfg.crop_size
63
+
64
+ print('Done with the SceneDreamer initialization.')
65
+
66
+ def custom_init(self):
67
+ r"""Weight initialization."""
68
+
69
+ def init_func(m):
70
+ if hasattr(m, 'weight'):
71
+ try:
72
+ nn.init.kaiming_normal_(m.weight.data, a=0.2, nonlinearity='leaky_relu')
73
+ except Exception:
74
+ print(m)
75
+ m.weight.data *= 0.5
76
+ if hasattr(m, 'bias') and m.bias is not None:
77
+ m.bias.data.fill_(0.0)
78
+ self.apply(init_func)
79
+
80
+ def _get_batch(self, batch_size, device):
81
+ r"""Sample camera poses and perform ray-voxel intersection.
82
+
83
+ Args:
84
+ batch_size (int): Expected batch size of the current batch
85
+ device (torch.device): Device on which the tensors should be stored
86
+ """
87
+ with torch.no_grad():
88
+ self.voxel.sample_world(device)
89
+ voxel_id_batch = []
90
+ depth2_batch = []
91
+ raydirs_batch = []
92
+ cam_ori_t_batch = []
93
+ for b in range(batch_size):
94
+ while True: # Rejection sampling.
95
+ # Sample camera pose.
96
+ if self.camera_sampler_type == 'random':
97
+ cam_res = self.cam_res
98
+ cam_ori_t, cam_dir_t, cam_up_t = camctl.rand_camera_pose_thridperson2(self.voxel)
99
+ # ~24mm fov horizontal.
100
+ cam_f = 0.5/np.tan(np.deg2rad(73/2) * (np.random.rand(1)*0.5+0.5)) * (cam_res[1]-1)
101
+ cam_c = [(cam_res[0]-1)/2, (cam_res[1]-1)/2]
102
+ cam_res_crop = [self.crop_size[0] + self.pad, self.crop_size[1] + self.pad]
103
+ cam_c = mc_utils.rand_crop(cam_c, cam_res, cam_res_crop)
104
+ elif self.camera_sampler_type == 'traditional':
105
+ cam_res = self.cam_res
106
+ cam_c = [(cam_res[0]-1)/2, (cam_res[1]-1)/2]
107
+ dice = torch.rand(1).item()
108
+ if dice > 0.5:
109
+ cam_ori_t, cam_dir_t, cam_up_t, cam_f = \
110
+ camctl.rand_camera_pose_tour(self.voxel)
111
+ cam_f = cam_f * (cam_res[1]-1)
112
+ else:
113
+ cam_ori_t, cam_dir_t, cam_up_t = \
114
+ camctl.rand_camera_pose_thridperson2(self.voxel)
115
+ # ~24mm fov horizontal.
116
+ cam_f = 0.5 / np.tan(np.deg2rad(73/2) * (np.random.rand(1)*0.5+0.5)) * (cam_res[1]-1)
117
+
118
+ cam_res_crop = [self.crop_size[0] + self.pad, self.crop_size[1] + self.pad]
119
+ cam_c = mc_utils.rand_crop(cam_c, cam_res, cam_res_crop)
120
+ else:
121
+ raise NotImplementedError(
122
+ 'Unknown self.camera_sampler_type: {}'.format(self.camera_sampler_type))
123
+
124
+ # Run ray-voxel intersection test
125
+ voxel_id, depth2, raydirs = voxlib.ray_voxel_intersection_perspective(
126
+ self.voxel.voxel_t, cam_ori_t, cam_dir_t, cam_up_t, cam_f, cam_c, cam_res_crop,
127
+ self.num_blocks_early_stop)
128
+
129
+ if self.camera_rej_avg_depth > 0:
130
+ depth_map = depth2[0, :, :, 0, :]
131
+ avg_depth = torch.mean(depth_map[~torch.isnan(depth_map)])
132
+ if avg_depth < self.camera_rej_avg_depth:
133
+ continue
134
+
135
+ # Reject low entropy.
136
+ if self.camera_min_entropy > 0:
137
+ # Check entropy.
138
+ maskcnt = torch.bincount(
139
+ torch.flatten(voxel_id[:, :, 0, 0]), weights=None, minlength=680).float() / \
140
+ (voxel_id.size(0)*voxel_id.size(1))
141
+ maskentropy = -torch.sum(maskcnt * torch.log(maskcnt+1e-10))
142
+ if maskentropy < self.camera_min_entropy:
143
+ continue
144
+ break
145
+
146
+ voxel_id_batch.append(voxel_id)
147
+ depth2_batch.append(depth2)
148
+ raydirs_batch.append(raydirs)
149
+ cam_ori_t_batch.append(cam_ori_t)
150
+ voxel_id = torch.stack(voxel_id_batch, dim=0)
151
+ depth2 = torch.stack(depth2_batch, dim=0)
152
+ raydirs = torch.stack(raydirs_batch, dim=0)
153
+ cam_ori_t = torch.stack(cam_ori_t_batch, dim=0).to(device)
154
+ cam_poses = None
155
+ return voxel_id, depth2, raydirs, cam_ori_t, cam_poses
156
+
157
+
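# Illustrative sketch (not part of the file above): the rejection sampling in _get_batch keeps
# drawing camera poses until the rendered view passes a minimum average-depth check and a
# minimum label-entropy check. Thresholds and shapes below are assumptions for illustration.
import torch

def label_entropy(label_map, num_classes=680):
    counts = torch.bincount(label_map.flatten(), minlength=num_classes).float()
    p = counts / label_map.numel()
    return -torch.sum(p * torch.log(p + 1e-10))

def accept_view(depth_map, label_map, min_avg_depth=2.0, min_entropy=0.75):
    avg_depth = torch.mean(depth_map[~torch.isnan(depth_map)])
    return bool(avg_depth >= min_avg_depth) and bool(label_entropy(label_map) >= min_entropy)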
158
+ def get_pseudo_gt(self, pseudo_gen, voxel_id, z=None, style_img=None, resize_512=True, deterministic=False):
159
+ r"""Evaluating img2img network to obtain pseudo-ground truth images.
160
+
161
+ Args:
162
+ pseudo_gen (callable): Function converting mask to image using img2img network.
163
+ voxel_id (N x img_dims[0] x img_dims[1] x max_samples x 1 tensor): IDs of intersected voxels along
164
+ each ray.
165
+ z (N x C tensor): Optional style code passed to pseudo_gen.
166
+ style_img (N x 3 x H x W tensor): Optional style image passed to pseudo_gen.
167
+ resize_512 (bool): If True, evaluate pseudo_gen at 512x512 regardless of input resolution.
168
+ deterministic (bool): If True, disable stochastic label mapping.
169
+ """
170
+ with torch.no_grad():
171
+ mc_mask = voxel_id[:, :, :, 0, :].permute(0, 3, 1, 2).long().contiguous()
172
+ coco_mask = self.label_trans.mc2coco(mc_mask) - 1
173
+ coco_mask[coco_mask < 0] = 183
174
+
175
+ if not deterministic:
176
+ # Stochastic mapping
177
+ dice = torch.rand(1).item()
178
+ if dice > 0.5 and dice < 0.9:
179
+ coco_mask[coco_mask == self.label_trans.gglbl2ggid('sky')] = self.label_trans.gglbl2ggid('clouds')
180
+ elif dice >= 0.9:
181
+ coco_mask[coco_mask == self.label_trans.gglbl2ggid('sky')] = self.label_trans.gglbl2ggid('fog')
182
+ dice = torch.rand(1).item()
183
+ if dice > 0.33 and dice < 0.66:
184
+ coco_mask[coco_mask == self.label_trans.gglbl2ggid('water')] = self.label_trans.gglbl2ggid('sea')
185
+ elif dice >= 0.66:
186
+ coco_mask[coco_mask == self.label_trans.gglbl2ggid('water')] = self.label_trans.gglbl2ggid('river')
187
+
188
+ fake_masks = torch.zeros([coco_mask.size(0), 185, coco_mask.size(2), coco_mask.size(3)],
189
+ dtype=torch.half, device=voxel_id.device)
190
+ fake_masks.scatter_(1, coco_mask, 1.0)
191
+
192
+ if self.use_label_smooth_pgt:
193
+ fake_masks = mc_utils.segmask_smooth(fake_masks, kernel_size=self.label_smooth_dia)
194
+ if self.pad > 0:
195
+ fake_masks = fake_masks[:, :, self.pad//2:-self.pad//2, self.pad//2:-self.pad//2]
196
+
197
+ # Generate pseudo GT using GauGAN.
198
+ if resize_512:
199
+ fake_masks_512 = F.interpolate(fake_masks, size=[512, 512], mode='nearest')
200
+ else:
201
+ fake_masks_512 = fake_masks
202
+ pseudo_real_img = pseudo_gen(fake_masks_512, z=z, style_img=style_img)
203
+
204
+ # NaN/Inf guard. NaN can occur on Volta GPUs.
205
+ nan_mask = torch.isnan(pseudo_real_img)
206
+ inf_mask = torch.isinf(pseudo_real_img)
207
+ pseudo_real_img[nan_mask | inf_mask] = 0.0
208
+ if resize_512:
209
+ pseudo_real_img = F.interpolate(
210
+ pseudo_real_img, size=[fake_masks.size(2), fake_masks.size(3)], mode='area')
211
+ pseudo_real_img = torch.clamp(pseudo_real_img, -1, 1)
212
+
213
+ return pseudo_real_img, fake_masks
214
+
215
+
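# Illustrative sketch (not part of the file above): the one-hot masks in get_pseudo_gt and
# sample_camera are built with scatter_ along the channel dimension, e.g.:
import torch

labels = torch.randint(0, 185, (1, 1, 4, 4))            # N x 1 x H x W integer labels
onehot = torch.zeros(1, 185, 4, 4, dtype=torch.float)   # N x C x H x W
onehot.scatter_(1, labels, 1.0)                          # one 1.0 per pixel, at its class channel
assert torch.equal(onehot.argmax(dim=1, keepdim=True), labels)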
216
+ def sample_camera(self, data, pseudo_gen):
217
+ r"""Sample camera randomly and precompute everything used by both Gen and Dis.
218
+
219
+ Args:
220
+ data (dict):
221
+ images (N x 3 x H x W tensor) : Real images
222
+ label (N x C2 x H x W tensor) : Segmentation map
223
+ pseudo_gen (callable): Function converting mask to image using img2img network.
224
+ Returns:
225
+ ret (dict):
226
+ voxel_id (N x H x W x max_samples x 1 tensor): IDs of intersected voxels along each ray.
227
+ depth2 (N x 2 x H x W x max_samples x 1 tensor): Depths of entrance and exit points for each ray-voxel
228
+ intersection.
229
+ raydirs (N x H x W x 1 x 3 tensor): The direction of each ray.
230
+ cam_ori_t (N x 3 tensor): Camera origins.
231
+ pseudo_real_img (N x 3 x H x W tensor): Pseudo-ground truth image.
232
+ real_masks (N x C3 x H x W tensor): One-hot segmentation map for real images, with translated labels.
233
+ fake_masks (N x C3 x H x W tensor): One-hot segmentation map for sampled camera views.
234
+ """
235
+ device = torch.device('cuda')
236
+ batch_size = data['images'].size(0)
237
+ # ================ Assemble a batch ==================
238
+ # Requires: voxel_id, depth2, raydirs, cam_ori_t.
239
+ voxel_id, depth2, raydirs, cam_ori_t, _ = self._get_batch(batch_size, device)
240
+ ret = {'voxel_id': voxel_id, 'depth2': depth2, 'raydirs': raydirs, 'cam_ori_t': cam_ori_t}
241
+
242
+ if pseudo_gen is not None:
243
+ pseudo_real_img, _ = self.get_pseudo_gt(pseudo_gen, voxel_id)
244
+ ret['pseudo_real_img'] = pseudo_real_img.float()
245
+
246
+ # =============== Mask translation ================
247
+ real_masks = data['label']
248
+ if self.reduced_label_set:
249
+ # Translate fake mask (directly from mcid).
250
+ # convert unrecognized labels to 'dirt'.
251
+ # N C H W [1 1 80 80]
252
+ reduce_fake_mask = self.label_trans.mc2reduced(
253
+ voxel_id[:, :, :, 0, :].permute(0, 3, 1, 2).long().contiguous()
254
+ , ign2dirt=True)
255
+ reduce_fake_mask_onehot = torch.zeros([
256
+ reduce_fake_mask.size(0), self.num_reduced_labels, reduce_fake_mask.size(2), reduce_fake_mask.size(3)],
257
+ dtype=torch.float, device=device)
258
+ reduce_fake_mask_onehot.scatter_(1, reduce_fake_mask, 1.0)
259
+ fake_masks = reduce_fake_mask_onehot
260
+ if self.pad != 0:
261
+ fake_masks = fake_masks[:, :, self.pad//2:-self.pad//2, self.pad//2:-self.pad//2]
262
+
263
+ # Translate real mask (data['label']), which is onehot.
264
+ real_masks_idx = torch.argmax(real_masks, dim=1, keepdim=True)
265
+ real_masks_idx[real_masks_idx > 182] = 182
266
+
267
+ reduced_real_mask = self.label_trans.coco2reduced(real_masks_idx)
268
+ reduced_real_mask_onehot = torch.zeros([
269
+ reduced_real_mask.size(0), self.num_reduced_labels, reduced_real_mask.size(2),
270
+ reduced_real_mask.size(3)], dtype=torch.float, device=device)
271
+ reduced_real_mask_onehot.scatter_(1, reduced_real_mask, 1.0)
272
+ real_masks = reduced_real_mask_onehot
273
+
274
+ # Mask smoothing.
275
+ if self.use_label_smooth:
276
+ fake_masks = mc_utils.segmask_smooth(fake_masks, kernel_size=self.label_smooth_dia)
277
+ if self.use_label_smooth_real:
278
+ real_masks = mc_utils.segmask_smooth(real_masks, kernel_size=self.label_smooth_dia)
279
+
280
+ ret['real_masks'] = real_masks
281
+ ret['fake_masks'] = fake_masks
282
+
283
+ return ret
284
+
285
+ def _forward_perpix_sub(self, blk_feats, worldcoord2, raydirs_in, z, mc_masks_onehot=None, global_enc=None):
286
+ r"""Per-pixel rendering forwarding
287
+
288
+ Args:
289
+ blk_feats: Deprecated
290
+ worldcoord2 (N x H x W x L x 3 tensor): 3D world coordinates of sampled points. L is number of samples; N is batch size, always 1.
291
+ raydirs_in (N x H x W x 1 x C2 tensor or None): ray direction embeddings.
292
+ z (N x C3 tensor): Intermediate style vectors.
293
+ mc_masks_onehot (N x H x W x L x C4): One-hot segmentation maps.
294
+ Returns:
295
+ net_out_s (N x H x W x L x 1 tensor): Opacities.
296
+ net_out_c (N x H x W x L x C5 tensor): Color embeddings.
297
+ """
298
+ _x, _y, _z = self.voxel.voxel_t.shape
299
+ delimeter = torch.Tensor([_x, _y, _z]).to(worldcoord2)
300
+ normalized_coord = worldcoord2 / delimeter * 2 - 1
301
+ global_enc = global_enc[:, None, None, None, :].repeat(1, normalized_coord.shape[1], normalized_coord.shape[2], normalized_coord.shape[3], 1)
302
+ normalized_coord = torch.cat([normalized_coord, global_enc], dim=-1)
303
+ feature_in = self.hash_encoder(normalized_coord)
304
+
305
+ net_out_s, net_out_c = self.render_net(feature_in, raydirs_in, z, mc_masks_onehot)
306
+
307
+ if self.raw_noise_std > 0.:
308
+ noise = torch.randn_like(net_out_s) * self.raw_noise_std
309
+ net_out_s = net_out_s + noise
310
+
311
+ return net_out_s, net_out_c
312
+
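# Illustrative sketch (not part of the file above): _forward_perpix_sub maps world coordinates
# into [-1, 1] using the voxel grid extents and broadcasts the per-scene code to every sample
# before hash encoding. Shapes and the 2-D code dimension are assumptions for illustration.
import torch

worldcoord = torch.rand(1, 8, 8, 10, 3) * 1024.0    # N x H x W x L x 3, in voxel units
extents = torch.tensor([1024.0, 1024.0, 1024.0])    # voxel grid size (Y, X, Z)
normalized = worldcoord / extents * 2 - 1            # roughly in [-1, 1]

global_code = torch.rand(1, 2)                       # per-scene conditioning vector
global_code = global_code[:, None, None, None, :].expand(-1, 8, 8, 10, -1)
hash_input = torch.cat([normalized, global_code], dim=-1)   # last dim = 3 + 2 = 5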
313
+ def _forward_perpix(self, blk_feats, voxel_id, depth2, raydirs, cam_ori_t, z, global_enc):
314
+ r"""Sample points along rays, forwarding the per-point MLP and aggregate pixel features
315
+
316
+ Args:
317
+ blk_feats (K x C1 tensor): Deprecated
318
+ voxel_id (N x H x W x M x 1 tensor): Voxel ids from the ray-voxel intersection test. M is the number of intersected voxels per ray (num_blocks_early_stop).
319
+ depth2 (N x 2 x H x W x M x 1 tensor): Depths of entrance and exit points for each ray-voxel intersection.
320
+ raydirs (N x H x W x 1 x 3 tensor): The direction of each ray.
321
+ cam_ori_t (N x 3 tensor): Camera origins.
322
+ z (N x C3 tensor): Intermediate style vectors.
323
+ """
324
+ # Generate sky_mask; PE transform on ray direction.
325
+ with torch.no_grad():
326
+ raydirs_in = raydirs.expand(-1, -1, -1, 1, -1).contiguous()
327
+ if self.pe_params[2] == 0 and self.pe_params[3] is True:
328
+ raydirs_in = raydirs_in
329
+ elif self.pe_params[2] == 0 and self.pe_params[3] is False: # Not using raydir at all
330
+ raydirs_in = None
331
+ else:
332
+ raydirs_in = voxlib.positional_encoding(raydirs_in, self.pe_params[2], -1, self.pe_params[3])
333
+
334
+ # sky_mask: when True, ray finally hits sky
335
+ sky_mask = voxel_id[:, :, :, [-1], :] == 0
336
+ # sky_only_mask: when True, ray hits nothing but sky
337
+ sky_only_mask = voxel_id[:, :, :, [0], :] == 0
338
+
339
+ with torch.no_grad():
340
+ # Random sample points along the ray
341
+ num_samples = self.num_samples + 1
342
+ if self.sample_use_box_boundaries:
343
+ num_samples = self.num_samples - self.num_blocks_early_stop
344
+
345
+ # 10 samples per ray + 4 intersections - 2
346
+ rand_depth, new_dists, new_idx = mc_utils.sample_depth_batched(
347
+ depth2, num_samples, deterministic=self.coarse_deterministic_sampling,
348
+ use_box_boundaries=self.sample_use_box_boundaries, sample_depth=self.sample_depth)
349
+
350
+ nan_mask = torch.isnan(rand_depth)
351
+ inf_mask = torch.isinf(rand_depth)
352
+ rand_depth[nan_mask | inf_mask] = 0.0
353
+
354
+ worldcoord2 = raydirs * rand_depth + cam_ori_t[:, None, None, None, :]
355
+
356
+ # Generate per-sample segmentation label
357
+ voxel_id_reduced = self.label_trans.mc2reduced(voxel_id, ign2dirt=True)
358
+ mc_masks = torch.gather(voxel_id_reduced, -2, new_idx) # B 256 256 N 1
359
+ mc_masks = mc_masks.long()
360
+ mc_masks_onehot = torch.zeros([mc_masks.size(0), mc_masks.size(1), mc_masks.size(
361
+ 2), mc_masks.size(3), self.num_reduced_labels], dtype=torch.float, device=voxel_id.device)
362
+ # mc_masks_onehot: [B H W Nlayer 680]
363
+ mc_masks_onehot.scatter_(-1, mc_masks, 1.0)
364
+
365
+ net_out_s, net_out_c = self._forward_perpix_sub(blk_feats, worldcoord2, raydirs_in, z, mc_masks_onehot, global_enc)
366
+
367
+ # Handle sky
368
+ sky_raydirs_in = raydirs.expand(-1, -1, -1, 1, -1).contiguous()
369
+ sky_raydirs_in = voxlib.positional_encoding(sky_raydirs_in, self.pe_params_sky[0], -1, self.pe_params_sky[1])
370
+ skynet_out_c = self.sky_net(sky_raydirs_in, z)
371
+
372
+ # Blending
373
+ weights = mc_utils.volum_rendering_relu(net_out_s, new_dists * self.dists_scale, dim=-2)
374
+
375
+ # If a ray exclusively hits the sky (no intersection with the voxels), set its weight to zero.
376
+ weights = weights * torch.logical_not(sky_only_mask).float()
377
+ total_weights_raw = torch.sum(weights, dim=-2, keepdim=True) # 256 256 1 1
378
+ total_weights = total_weights_raw
379
+
380
+ is_gnd = worldcoord2[..., [0]] <= 1.0 # Y X Z, [256, 256, 4, 3], nan < 1.0 == False
381
+ is_gnd = is_gnd.any(dim=-2, keepdim=True)
382
+ nosky_mask = torch.logical_or(torch.logical_not(sky_mask), is_gnd)
383
+ nosky_mask = nosky_mask.float()
384
+
385
+ # Avoid sky leakage
386
+ sky_weight = 1.0-total_weights
387
+ if self.keep_sky_out:
388
+ # keep_sky_out_avgpool overrides sky_replace_color
389
+ if self.sky_replace_color is None or self.keep_sky_out_avgpool:
390
+ if self.keep_sky_out_avgpool:
391
+ if hasattr(self, 'sky_avg'):
392
+ sky_avg = self.sky_avg
393
+ else:
394
+ if self.sky_global_avgpool:
395
+ sky_avg = torch.mean(skynet_out_c, dim=[1, 2], keepdim=True)
396
+ else:
397
+ skynet_out_c_nchw = skynet_out_c.permute(0, 4, 1, 2, 3).squeeze(-1).contiguous()
398
+ sky_avg = F.avg_pool2d(skynet_out_c_nchw, 31, stride=1, padding=15, count_include_pad=False)
399
+ sky_avg = sky_avg.permute(0, 2, 3, 1).unsqueeze(-2).contiguous()
400
+ # print(sky_avg.shape)
401
+ skynet_out_c = skynet_out_c * (1.0-nosky_mask) + sky_avg*(nosky_mask)
402
+ else:
403
+ sky_weight = sky_weight * (1.0-nosky_mask)
404
+ else:
405
+ skynet_out_c = skynet_out_c * (1.0-nosky_mask) + self.sky_replace_color*(nosky_mask)
406
+
407
+ if self.clip_feat_map is True: # intermediate feature before blending & CNN
408
+ rgbs = torch.clamp(net_out_c, -1, 1) + 1
409
+ rgbs_sky = torch.clamp(skynet_out_c, -1, 1) + 1
410
+ net_out = torch.sum(weights*rgbs, dim=-2, keepdim=True) + sky_weight * \
411
+ rgbs_sky # 576, 768, 4, 3 -> 576, 768, 3
412
+ net_out = net_out.squeeze(-2)
413
+ net_out = net_out - 1
414
+ elif self.clip_feat_map is False:
415
+ rgbs = net_out_c
416
+ rgbs_sky = skynet_out_c
417
+ net_out = torch.sum(weights*rgbs, dim=-2, keepdim=True) + sky_weight * \
418
+ rgbs_sky # 576, 768, 4, 3 -> 576, 768, 3
419
+ net_out = net_out.squeeze(-2)
420
+ elif self.clip_feat_map == 'tanh':
421
+ rgbs = torch.tanh(net_out_c)
422
+ rgbs_sky = torch.tanh(skynet_out_c)
423
+ net_out = torch.sum(weights*rgbs, dim=-2, keepdim=True) + sky_weight * \
424
+ rgbs_sky # 576, 768, 4, 3 -> 576, 768, 3
425
+ net_out = net_out.squeeze(-2)
426
+ else:
427
+ raise NotImplementedError
428
+
429
+ return net_out, new_dists, weights, total_weights_raw, rand_depth, net_out_s, net_out_c, skynet_out_c, \
430
+ nosky_mask, sky_mask, sky_only_mask, new_idx
431
+
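# Illustrative sketch (not part of the file above): the blending at the end of _forward_perpix
# composites the per-sample colors with the volume-rendering weights and fills the remaining
# transmittance with the sky color.
import torch

weights = torch.rand(1, 4, 4, 10, 1)      # N x H x W x L x 1 volume-rendering weights
rgbs = torch.rand(1, 4, 4, 10, 3)         # per-sample color features
sky_rgb = torch.rand(1, 4, 4, 1, 3)       # per-ray sky color

opacity = weights.sum(dim=-2, keepdim=True)
pixel = (weights * rgbs).sum(dim=-2, keepdim=True) + (1.0 - opacity) * sky_rgb
pixel = pixel.squeeze(-2)                 # N x H x W x 3 blended feature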
432
+ def forward(self, data, random_style=False):
433
+ r"""SceneDreamer forward.
434
+ """
435
+ device = torch.device('cuda')
436
+ batch_size = data['images'].size(0)
437
+ # Requires: voxel_id, depth2, raydirs, cam_ori_t.
438
+ voxel_id, depth2, raydirs, cam_ori_t = data['voxel_id'], data['depth2'], data['raydirs'], data['cam_ori_t']
439
+ if 'pseudo_real_img' in data:
440
+ pseudo_real_img = data['pseudo_real_img']
441
+
442
+ global_enc = self.world_encoder(self.voxel.current_height_map, self.voxel.current_semantic_map)
443
+
444
+ z, mu, logvar = None, None, None
445
+ if random_style:
446
+ if self.style_dims > 0:
447
+ z = torch.randn(batch_size, self.style_dims, dtype=torch.float32, device=device)
448
+ else:
449
+ if self.style_encoder is None:
450
+ # ================ Get Style Code =================
451
+ if self.style_dims > 0:
452
+ z = torch.randn(batch_size, self.style_dims, dtype=torch.float32, device=device)
453
+ else:
454
+ mu, logvar, z = self.style_encoder(pseudo_real_img)
455
+
456
+ # ================ Network Forward ================
457
+ # Forward StyleNet
458
+ if self.style_net is not None:
459
+ z = self.style_net(z)
460
+
461
+ # Forward per-pixel net.
462
+ net_out, new_dists, weights, total_weights_raw, rand_depth, net_out_s, net_out_c, skynet_out_c, nosky_mask, \
463
+ sky_mask, sky_only_mask, new_idx = self._forward_perpix(
464
+ self.blk_feats, voxel_id, depth2, raydirs, cam_ori_t, z, global_enc)
465
+
466
+ # Forward global net.
467
+ fake_images, fake_images_raw = self._forward_global(net_out, z)
468
+ if self.pad != 0:
469
+ fake_images = fake_images[:, :, self.pad//2:-self.pad//2, self.pad//2:-self.pad//2]
470
+
471
+ # =============== Arrange Return Values ================
472
+ output = {}
473
+ output['fake_images'] = fake_images
474
+ output['mu'] = mu
475
+ output['logvar'] = logvar
476
+ return output
477
+
478
+
479
+ def inference_givenstyle(self, style,
480
+ output_dir,
481
+ camera_mode,
482
+ style_img_path=None,
483
+ seed=1,
484
+ pad=30,
485
+ num_samples=40,
486
+ num_blocks_early_stop=6,
487
+ sample_depth=3,
488
+ tile_size=128,
489
+ resolution_hw=[540, 960],
490
+ cam_ang=72,
491
+ cam_maxstep=10):
492
+ r"""Compute result images according to the provided camera trajectory and save the results in the specified
493
+ folder. The full image is evaluated in multiple tiles to save memory.
494
+
495
+ Args:
496
+ output_dir (str): Where the results should be stored.
497
+ camera_mode (int): Which camera trajectory to use.
498
+ style_img_path (str): Path to the style-conditioning image.
499
+ seed (int): Random seed (controls style when style_img_path is not specified).
500
+ pad (int): Pixels to remove from the image tiles before stitching. Should be equal to or larger than the
501
+ receptive field of the CNN to avoid border artifacts.
502
+ num_samples (int): Number of samples per ray (different from training).
503
+ num_blocks_early_stop (int): Max number of intersected boxes per ray before stopping
504
+ (different from training).
505
+ sample_depth (float): Max distance traveled through boxes before stopping (different from training).
506
+ tile_size (int): Max size of a tile in pixels.
507
+ resolution_hw (list [H, W]): Resolution of the output image.
508
+ cam_ang (float): Horizontal FOV of the camera (may be adjusted by the camera controller).
509
+ cam_maxstep (int): Number of frames sampled from the camera trajectory.
510
+ """
511
+
512
+ def write_img(path, img, rgb_input=False):
513
+ img = ((img*0.5+0.5)*255).detach().cpu().numpy().astype(np.uint8)
514
+ img = img[0].transpose(1, 2, 0)
515
+ if rgb_input:
516
+ img = img[..., [2, 1, 0]]
517
+ cv2.imwrite(path, img, [cv2.IMWRITE_PNG_COMPRESSION, 4])
518
+ return img[..., ::-1]
519
+
520
+ def read_img(path):
521
+ img = cv2.imread(path).astype(np.float32)[..., [2, 1, 0]].transpose(2, 0, 1) / 255
522
+ img = img * 2 - 1
523
+ return torch.from_numpy(img)
524
+
525
+ print('Saving to', output_dir)
526
+
527
+ # Use provided random seed.
528
+ device = torch.device('cuda')
529
+
530
+ global_enc = self.world_encoder(self.voxel.current_height_map, self.voxel.current_semantic_map)
531
+
532
+ biome_colors = torch.Tensor([
533
+ [255, 255, 178],
534
+ [184, 200, 98],
535
+ [188, 161, 53],
536
+ [190, 255, 242],
537
+ [106, 144, 38],
538
+ [33, 77, 41],
539
+ [86, 179, 106],
540
+ [34, 61, 53],
541
+ [35, 114, 94],
542
+ [0, 0, 255],
543
+ [0, 255, 0],
544
+ ]).to(device) / 255 * 2 - 1
545
+ semantic_map = torch.argmax(self.voxel.current_semantic_map, dim=1)
546
+
547
+ self.pad = pad
548
+ self.num_samples = num_samples
549
+ self.num_blocks_early_stop = num_blocks_early_stop
550
+ self.sample_depth = sample_depth
551
+
552
+ self.coarse_deterministic_sampling = True
553
+ self.crop_size = resolution_hw
554
+ self.cam_res = [self.crop_size[0]+self.pad, self.crop_size[1]+self.pad]
555
+ self.use_label_smooth_pgt = False
556
+
557
+ # Make output dirs.
558
+ output_dir = os.path.join(output_dir, 'rgb_render')
559
+ os.makedirs(output_dir, exist_ok=True)
560
+ fout = imageio.get_writer(output_dir + '.mp4', fps=10)
561
+
562
+ write_img(os.path.join(output_dir, 'semantic_map.png'), biome_colors[semantic_map].permute(0, 3, 1, 2), rgb_input=True)
563
+ write_img(os.path.join(output_dir, 'height_map.png'), self.voxel.current_height_map)
564
+ np.save(os.path.join(output_dir, 'style.npy'), style.detach().cpu().numpy())
565
+ evalcamctl = camctl.EvalCameraController(
566
+ self.voxel, maxstep=cam_maxstep, pattern=camera_mode, cam_ang=cam_ang,
567
+ smooth_decay_multiplier=150/cam_maxstep)
568
+
569
+ # Get output style.
570
+ z = self.style_net(style)
571
+
572
+ # Generate required output images.
573
+ for id, (cam_ori_t, cam_dir_t, cam_up_t, cam_f) in enumerate(evalcamctl):
574
+ print('Rendering frame', id)
575
+ cam_f = cam_f * (self.crop_size[1]-1) # So that the view does not depend on the padding
576
+ cam_c = [(self.cam_res[0]-1)/2, (self.cam_res[1]-1)/2]
577
+
578
+ voxel_id, depth2, raydirs = voxlib.ray_voxel_intersection_perspective(
579
+ self.voxel.voxel_t, cam_ori_t, cam_dir_t, cam_up_t, cam_f, cam_c, self.cam_res,
580
+ self.num_blocks_early_stop)
581
+
582
+ voxel_id = voxel_id.unsqueeze(0)
583
+ depth2 = depth2.unsqueeze(0)
584
+ raydirs = raydirs.unsqueeze(0)
585
+ cam_ori_t = cam_ori_t.unsqueeze(0).to(device)
586
+
587
+ voxel_id_all = voxel_id
588
+ depth2_all = depth2
589
+ raydirs_all = raydirs
590
+
591
+ # Evaluate sky in advance to get a consistent sky in the semi-transparent region.
592
+ if self.sky_global_avgpool:
593
+ sky_raydirs_in = raydirs.expand(-1, -1, -1, 1, -1).contiguous()
594
+ sky_raydirs_in = voxlib.positional_encoding(
595
+ sky_raydirs_in, self.pe_params_sky[0], -1, self.pe_params_sky[1])
596
+ skynet_out_c = self.sky_net(sky_raydirs_in, z)
597
+ sky_avg = torch.mean(skynet_out_c, dim=[1, 2], keepdim=True)
598
+ self.sky_avg = sky_avg
599
+
600
+ num_strips_h = (self.cam_res[0]-self.pad+tile_size-1)//tile_size
601
+ num_strips_w = (self.cam_res[1]-self.pad+tile_size-1)//tile_size
602
+
603
+ fake_images_chunks_v = []
604
+ # For each horizontal strip.
605
+ for strip_id_h in range(num_strips_h):
606
+ strip_begin_h = strip_id_h * tile_size
607
+ strip_end_h = np.minimum(strip_id_h * tile_size + tile_size + self.pad, self.cam_res[0])
608
+ # For each vertical strip.
609
+ fake_images_chunks_h = []
610
+ for strip_id_w in range(num_strips_w):
611
+ strip_begin_w = strip_id_w * tile_size
612
+ strip_end_w = np.minimum(strip_id_w * tile_size + tile_size + self.pad, self.cam_res[1])
613
+
614
+ voxel_id = voxel_id_all[:, strip_begin_h:strip_end_h, strip_begin_w:strip_end_w, :, :]
615
+ depth2 = depth2_all[:, :, strip_begin_h:strip_end_h, strip_begin_w:strip_end_w, :, :]
616
+ raydirs = raydirs_all[:, strip_begin_h:strip_end_h, strip_begin_w:strip_end_w, :, :]
617
+
618
+ net_out, new_dists, weights, total_weights_raw, rand_depth, net_out_s, net_out_c, skynet_out_c, \
619
+ nosky_mask, sky_mask, sky_only_mask, new_idx = self._forward_perpix(
620
+ self.blk_feats, voxel_id, depth2, raydirs, cam_ori_t, z, global_enc)
621
+ fake_images, _ = self._forward_global(net_out, z)
622
+
623
+ if self.pad != 0:
624
+ fake_images = fake_images[:, :, self.pad//2:-self.pad//2, self.pad//2:-self.pad//2]
625
+ fake_images_chunks_h.append(fake_images)
626
+ fake_images_h = torch.cat(fake_images_chunks_h, dim=-1)
627
+ fake_images_chunks_v.append(fake_images_h)
628
+ fake_images = torch.cat(fake_images_chunks_v, dim=-2)
629
+ rgb = write_img(os.path.join(output_dir,
630
+ '{:05d}.png'.format(id)), fake_images, rgb_input=True)
631
+ fout.append_data(rgb)
632
+ fout.close()
633
+
634
+
635
+
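# Illustrative sketch (not part of the file above): the tiled evaluation in inference_givenstyle
# splits the padded camera resolution into overlapping tiles and crops pad//2 from each border
# before stitching. Values are illustrative.
pad, tile_size = 30, 128
cam_res = [540 + pad, 960 + pad]                                  # padded [H, W]
num_strips_h = (cam_res[0] - pad + tile_size - 1) // tile_size    # ceil((H - pad) / tile_size)
num_strips_w = (cam_res[1] - pad + tile_size - 1) // tile_size
for i in range(num_strips_h):
    begin_h = i * tile_size
    end_h = min(begin_h + tile_size + pad, cam_res[0])            # each tile keeps `pad` extra rows
    # ... render rows [begin_h, end_h), crop pad//2 on each side, then concatenate.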
636
+ def inference_givenstyle_depth(self, style,
637
+ output_dir,
638
+ camera_mode,
639
+ style_img_path=None,
640
+ seed=1,
641
+ pad=30,
642
+ num_samples=40,
643
+ num_blocks_early_stop=6,
644
+ sample_depth=3,
645
+ tile_size=128,
646
+ resolution_hw=[540, 960],
647
+ cam_ang=72,
648
+ cam_maxstep=10):
649
+ r"""Compute result images according to the provided camera trajectory and save the results in the specified
650
+ folder. The full image is evaluated in multiple tiles to save memory.
651
+
652
+ Args:
653
+ output_dir (str): Where the results should be stored.
654
+ camera_mode (int): Which camera trajectory to use.
655
+ style_img_path (str): Path to the style-conditioning image.
656
+ seed (int): Random seed (controls style when style_img_path is not specified).
657
+ pad (int): Pixels to remove from the image tiles before stitching. Should be equal to or larger than the
658
+ receptive field of the CNN to avoid border artifacts.
659
+ num_samples (int): Number of samples per ray (different from training).
660
+ num_blocks_early_stop (int): Max number of intersected boxes per ray before stopping
661
+ (different from training).
662
+ sample_depth (float): Max distance traveled through boxes before stopping (different from training).
663
+ tile_size (int): Max size of a tile in pixels.
664
+ resolution_hw (list [H, W]): Resolution of the output image.
665
+ cam_ang (float): Horizontal FOV of the camera (may be adjusted by the camera controller).
666
+ cam_maxstep (int): Number of frames sampled from the camera trajectory.
667
+ """
668
+
669
+ def write_img(path, img, rgb_input=False):
670
+ img = ((img*0.5+0.5)*255).detach().cpu().numpy().astype(np.uint8)
671
+ img = img[0].transpose(1, 2, 0)
672
+ if rgb_input:
673
+ img = img[..., [2, 1, 0]]
674
+ cv2.imwrite(path, img, [cv2.IMWRITE_PNG_COMPRESSION, 4])
675
+ return img[..., ::-1]
676
+
677
+ def read_img(path):
678
+ img = cv2.imread(path).astype(np.float32)[..., [2, 1, 0]].transpose(2, 0, 1) / 255
679
+ img = img * 2 - 1
680
+ return torch.from_numpy(img)
681
+
682
+ print('Saving to', output_dir)
683
+
684
+ # Use provided random seed.
685
+ device = torch.device('cuda')
686
+
687
+ global_enc = self.world_encoder(self.voxel.current_height_map, self.voxel.current_semantic_map)
688
+
689
+ biome_colors = torch.Tensor([
690
+ [255, 255, 178],
691
+ [184, 200, 98],
692
+ [188, 161, 53],
693
+ [190, 255, 242],
694
+ [106, 144, 38],
695
+ [33, 77, 41],
696
+ [86, 179, 106],
697
+ [34, 61, 53],
698
+ [35, 114, 94],
699
+ [0, 0, 255],
700
+ [0, 255, 0],
701
+ ]) / 255 * 2 - 1
702
+ print(self.voxel.current_height_map[0].shape)
703
+ semantic_map = torch.argmax(self.voxel.current_semantic_map, dim=1)
704
+ print(torch.unique(semantic_map, return_counts=True))
705
+ print(semantic_map.min())
706
+
707
+ self.pad = pad
708
+ self.num_samples = num_samples
709
+ self.num_blocks_early_stop = num_blocks_early_stop
710
+ self.sample_depth = sample_depth
711
+
712
+ self.coarse_deterministic_sampling = True
713
+ self.crop_size = resolution_hw
714
+ self.cam_res = [self.crop_size[0]+self.pad, self.crop_size[1]+self.pad]
715
+ self.use_label_smooth_pgt = False
716
+
717
+ # Make output dirs.
718
+ gancraft_outputs_dir = os.path.join(output_dir, 'gancraft_outputs')
719
+ os.makedirs(gancraft_outputs_dir, exist_ok=True)
720
+ gancraft_depth_outputs_dir = os.path.join(output_dir, 'depth')
721
+ os.makedirs(gancraft_depth_outputs_dir, exist_ok=True)
722
+ vis_masks_dir = os.path.join(output_dir, 'vis_masks')
723
+ os.makedirs(vis_masks_dir, exist_ok=True)
724
+ fout = imageio.get_writer(gancraft_outputs_dir + '.mp4', fps=10)
725
+ fout_cat = imageio.get_writer(gancraft_outputs_dir + '-vis_masks.mp4', fps=10)
726
+
727
+ write_img(os.path.join(output_dir, 'semantic_map.png'), biome_colors[semantic_map].permute(0, 3, 1, 2), rgb_input=True)
728
+ write_img(os.path.join(output_dir, 'heightmap.png'), self.voxel.current_height_map)
729
+
730
+ evalcamctl = camctl.EvalCameraController(
731
+ self.voxel, maxstep=cam_maxstep, pattern=camera_mode, cam_ang=cam_ang,
732
+ smooth_decay_multiplier=150/cam_maxstep)
733
+
734
+ # import pickle
735
+ # with open(os.path.join(output_dir,'camera.pkl'), 'wb') as f:
736
+ # pickle.dump(evalcamctl, f)
737
+
738
+ # Get output style.
739
+ z = self.style_net(style)
740
+
741
+ # Generate required output images.
742
+ for id, (cam_ori_t, cam_dir_t, cam_up_t, cam_f) in enumerate(evalcamctl):
743
+ # print('Rendering frame', id)
744
+ cam_f = cam_f * (self.crop_size[1]-1) # So that the view does not depend on the padding
745
+ cam_c = [(self.cam_res[0]-1)/2, (self.cam_res[1]-1)/2]
746
+
747
+ voxel_id, depth2, raydirs = voxlib.ray_voxel_intersection_perspective(
748
+ self.voxel.voxel_t, cam_ori_t, cam_dir_t, cam_up_t, cam_f, cam_c, self.cam_res,
749
+ self.num_blocks_early_stop)
750
+
751
+ voxel_id = voxel_id.unsqueeze(0)
752
+ depth2 = depth2.unsqueeze(0)
753
+ raydirs = raydirs.unsqueeze(0)
754
+ cam_ori_t = cam_ori_t.unsqueeze(0).to(device)
755
+
756
+ # Save 3D voxel rendering.
757
+ mc_rgb = self.label_trans.mc_color(voxel_id[0, :, :, 0, 0].cpu().numpy())
758
+ # Diffused shading, co-located light.
759
+ first_intersection_depth = depth2[:, 0, :, :, 0, None, :] # [1, 542, 542, 1, 1].
760
+ first_intersection_point = raydirs * first_intersection_depth + cam_ori_t[:, None, None, None, :]
761
+ fip_local_coords = torch.remainder(first_intersection_point, 1.0)
762
+ fip_wall_proximity = torch.minimum(fip_local_coords, 1.0-fip_local_coords)
763
+ fip_wall_orientation = torch.argmin(fip_wall_proximity, dim=-1, keepdim=False)
764
+ # 0: [1,0,0]; 1: [0,1,0]; 2: [0,0,1]
765
+ lut = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=torch.float32,
766
+ device=fip_wall_orientation.device)
767
+ fip_normal = lut[fip_wall_orientation] # [1, 542, 542, 1, 3]
768
+ diffuse_shade = torch.abs(torch.sum(fip_normal * raydirs, dim=-1))
769
+
770
+ mc_rgb = (mc_rgb.astype(np.float32) / 255) ** 2.2 # use an explicit dtype; np.float is removed in newer NumPy
771
+ mc_rgb = mc_rgb * diffuse_shade[0, :, :, :].cpu().numpy()
772
+ mc_rgb = (mc_rgb ** (1/2.2)) * 255
773
+ mc_rgb = mc_rgb.astype(np.uint8)
774
+ if self.pad > 0:
775
+ mc_rgb = mc_rgb[self.pad//2:-self.pad//2, self.pad//2:-self.pad//2]
776
+ cv2.imwrite(os.path.join(vis_masks_dir, '{:05d}.png'.format(id)), mc_rgb, [cv2.IMWRITE_PNG_COMPRESSION, 4])
777
+
778
+ # Tiled eval of GANcraft.
779
+ voxel_id_all = voxel_id
780
+ depth2_all = depth2
781
+ raydirs_all = raydirs
782
+
783
+ # Evaluate sky in advance to get a consistent sky in the semi-transparent region.
784
+ if self.sky_global_avgpool:
785
+ sky_raydirs_in = raydirs.expand(-1, -1, -1, 1, -1).contiguous()
786
+ sky_raydirs_in = voxlib.positional_encoding(
787
+ sky_raydirs_in, self.pe_params_sky[0], -1, self.pe_params_sky[1])
788
+ skynet_out_c = self.sky_net(sky_raydirs_in, z)
789
+ sky_avg = torch.mean(skynet_out_c, dim=[1, 2], keepdim=True)
790
+ self.sky_avg = sky_avg
791
+
792
+ num_strips_h = (self.cam_res[0]-self.pad+tile_size-1)//tile_size
793
+ num_strips_w = (self.cam_res[1]-self.pad+tile_size-1)//tile_size
794
+
795
+ fake_images_chunks_v = []
796
+ fake_depth_chunks_v = []
797
+ # For each horizontal strip.
798
+ for strip_id_h in range(num_strips_h):
799
+ strip_begin_h = strip_id_h * tile_size
800
+ strip_end_h = np.minimum(strip_id_h * tile_size + tile_size + self.pad, self.cam_res[0])
801
+ # For each vertical strip.
802
+ fake_images_chunks_h = []
803
+ fake_depth_chunks_h = []
804
+ for strip_id_w in range(num_strips_w):
805
+ strip_begin_w = strip_id_w * tile_size
806
+ strip_end_w = np.minimum(strip_id_w * tile_size + tile_size + self.pad, self.cam_res[1])
807
+
808
+ voxel_id = voxel_id_all[:, strip_begin_h:strip_end_h, strip_begin_w:strip_end_w, :, :]
809
+ depth2 = depth2_all[:, :, strip_begin_h:strip_end_h, strip_begin_w:strip_end_w, :, :]
810
+ raydirs = raydirs_all[:, strip_begin_h:strip_end_h, strip_begin_w:strip_end_w, :, :]
811
+
812
+ net_out, new_dists, weights, total_weights_raw, rand_depth, net_out_s, net_out_c, skynet_out_c, \
813
+ nosky_mask, sky_mask, sky_only_mask, new_idx = self._forward_perpix(
814
+ self.blk_feats, voxel_id, depth2, raydirs, cam_ori_t, z, global_enc)
815
+ fake_images, _ = self._forward_global(net_out, z)
816
+ depth_map = torch.sum(weights * rand_depth, -2)
817
+ # disp_map = 1. / torch.max(1e-10 * torch.ones_like(depth_map).to(depth_map), depth_map / torch.sum(weights, -2))
818
+ # depth_map = torch.clip(depth_map, 0, 100.)
819
+ # disp_map = 1. / (depth_map.permute(0, 3, 1, 2))
820
+ disp_map = depth_map.permute(0, 3, 1, 2)
821
+ if self.pad != 0:
822
+ fake_images = fake_images[:, :, self.pad//2:-self.pad//2, self.pad//2:-self.pad//2]
823
+ disp_map = disp_map[:, :, self.pad//2:-self.pad//2, self.pad//2:-self.pad//2]
824
+ fake_images_chunks_h.append(fake_images)
825
+ fake_depth_chunks_h.append(disp_map)
826
+ fake_images_h = torch.cat(fake_images_chunks_h, dim=-1)
827
+ fake_depth_h = torch.cat(fake_depth_chunks_h, dim=-1)
828
+ fake_images_chunks_v.append(fake_images_h)
829
+ fake_depth_chunks_v.append(fake_depth_h)
830
+ fake_images = torch.cat(fake_images_chunks_v, dim=-2)
831
+ fake_depth = torch.cat(fake_depth_chunks_v, dim=-2)
832
+ # fake_depth = ((fake_depth - fake_depth.mean()) / fake_depth.std() + 1) / 2
833
+ # fake_depth = torch.clip(1./ (fake_depth + 1e-4), 0., 1.)
834
+ # fake_depth = ((fake_depth - fake_depth.mean()) / fake_depth.std() + 1) / 2
835
+ mmask = fake_depth > 0
836
+ tmp = fake_depth[mmask]
837
+ # tmp = 1. / (tmp + 1e-4)
838
+ tmp = (tmp - tmp.min()) / (tmp.max() - tmp.min())
839
+ # tmp = ((tmp - tmp.mean()) / tmp.std() + 1) / 2.
840
+ fake_depth[~mmask] = 1
841
+ fake_depth[mmask] = tmp
842
+ # fake_depth = (fake_depth - fake_depth.min()) / (fake_depth.max() - fake_depth.min())
843
+
844
+ cv2.imwrite(os.path.join(gancraft_depth_outputs_dir, '{:05d}.png'.format(id)), fake_depth[0].permute(1, 2, 0).detach().cpu().numpy() * 255)
845
+ rgb = write_img(os.path.join(gancraft_outputs_dir,
846
+ '{:05d}.png'.format(id)), fake_images, rgb_input=True)
847
+ fout.append_data(rgb)
848
+ fout_cat.append_data(np.concatenate((mc_rgb[..., ::-1], rgb), axis=1))
849
+ fout.close()
850
+ fout_cat.close()
851
+
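As a side note on the depth visualization in inference_givenstyle_depth above: valid depths are min-max normalized and sky pixels are painted with the far value. A minimal sketch of that normalization (illustrative only):

import torch

depth = torch.rand(1, 1, 4, 4)
depth[0, 0, 0, 0] = 0.0                       # pretend this pixel only hit the sky
valid = depth > 0
vals = depth[valid]
depth[valid] = (vals - vals.min()) / (vals.max() - vals.min() + 1e-8)
depth[~valid] = 1.0                           # sky rendered as the maximum (white) value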
imaginaire/generators/spade.py ADDED
@@ -0,0 +1,571 @@
1
+ # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, check out LICENSE.md
5
+ import functools
6
+ import math
7
+ import types
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from torch.nn import Upsample as NearestUpsample
14
+
15
+ from imaginaire.layers import Conv2dBlock, LinearBlock, Res2dBlock
16
+ from imaginaire.utils.data import (get_crop_h_w,
17
+ get_paired_input_image_channel_number,
18
+ get_paired_input_label_channel_number)
19
+ from imaginaire.utils.distributed import master_only_print as print
20
+
21
+
22
+ class Generator(nn.Module):
23
+ r"""SPADE generator constructor.
24
+
25
+ Args:
26
+ gen_cfg (obj): Generator definition part of the yaml config file.
27
+ data_cfg (obj): Data definition part of the yaml config file.
28
+ """
29
+
30
+ def __init__(self, gen_cfg, data_cfg):
31
+ super(Generator, self).__init__()
32
+ print('SPADE generator initialization.')
33
+ # We assume the first datum is the ground truth image.
34
+ image_channels = getattr(gen_cfg, 'image_channels', None)
35
+ if image_channels is None:
36
+ image_channels = get_paired_input_image_channel_number(data_cfg)
37
+ num_labels = getattr(gen_cfg, 'num_labels', None)
38
+ if num_labels is None:
39
+ # Calculate number of channels in the input label when not specified.
40
+ num_labels = get_paired_input_label_channel_number(data_cfg)
41
+ crop_h, crop_w = get_crop_h_w(data_cfg.train.augmentations)
42
+ # Build the generator
43
+ out_image_small_side_size = crop_w if crop_w < crop_h else crop_h
44
+ num_filters = getattr(gen_cfg, 'num_filters', 128)
45
+ kernel_size = getattr(gen_cfg, 'kernel_size', 3)
46
+ weight_norm_type = getattr(gen_cfg, 'weight_norm_type', 'spectral')
47
+
48
+ cond_dims = 0
49
+ # Check whether we use the style code.
50
+ style_dims = getattr(gen_cfg, 'style_dims', None)
51
+ self.style_dims = style_dims
52
+ if style_dims is not None:
53
+ print('\tStyle code dimensions: %d' % style_dims)
54
+ cond_dims += style_dims
55
+ self.use_style = True
56
+ else:
57
+ self.use_style = False
58
+ # Check whether we use the attribute code.
59
+ if hasattr(gen_cfg, 'attribute_dims'):
60
+ self.use_attribute = True
61
+ self.attribute_dims = gen_cfg.attribute_dims
62
+ cond_dims += gen_cfg.attribute_dims
63
+ else:
64
+ self.use_attribute = False
65
+
66
+ if not self.use_style and not self.use_attribute:
67
+ self.use_style_encoder = False
68
+ else:
69
+ self.use_style_encoder = True
70
+ print('\tBase filter number: %d' % num_filters)
71
+ print('\tConvolution kernel size: %d' % kernel_size)
72
+ print('\tWeight norm type: %s' % weight_norm_type)
73
+ skip_activation_norm = \
74
+ getattr(gen_cfg, 'skip_activation_norm', True)
75
+ activation_norm_params = getattr(gen_cfg, 'activation_norm_params', None)
76
+ if activation_norm_params is None:
77
+ activation_norm_params = types.SimpleNamespace()
78
+ if not hasattr(activation_norm_params, 'num_filters'):
79
+ setattr(activation_norm_params, 'num_filters', 128)
80
+ if not hasattr(activation_norm_params, 'kernel_size'):
81
+ setattr(activation_norm_params, 'kernel_size', 3)
82
+ if not hasattr(activation_norm_params, 'activation_norm_type'):
83
+ setattr(activation_norm_params, 'activation_norm_type', 'sync_batch')
84
+ if not hasattr(activation_norm_params, 'separate_projection'):
85
+ setattr(activation_norm_params, 'separate_projection', False)
86
+ if not hasattr(activation_norm_params, 'activation_norm_params'):
87
+ activation_norm_params.activation_norm_params = types.SimpleNamespace()
88
+ activation_norm_params.activation_norm_params.affine = True
89
+ setattr(activation_norm_params, 'cond_dims', num_labels)
90
+ if not hasattr(activation_norm_params, 'weight_norm_type'):
91
+ setattr(activation_norm_params, 'weight_norm_type', weight_norm_type)
92
+ global_adaptive_norm_type = getattr(gen_cfg, 'global_adaptive_norm_type', 'sync_batch')
93
+ use_posenc_in_input_layer = getattr(gen_cfg, 'use_posenc_in_input_layer', True)
94
+ output_multiplier = getattr(gen_cfg, 'output_multiplier', 1.0)
95
+ print(activation_norm_params)
96
+ self.spade_generator = SPADEGenerator(num_labels,
97
+ out_image_small_side_size,
98
+ image_channels,
99
+ num_filters,
100
+ kernel_size,
101
+ cond_dims,
102
+ activation_norm_params,
103
+ weight_norm_type,
104
+ global_adaptive_norm_type,
105
+ skip_activation_norm,
106
+ use_posenc_in_input_layer,
107
+ self.use_style_encoder,
108
+ output_multiplier)
109
+ if self.use_style:
110
+ # Build the encoder.
111
+ style_enc_cfg = getattr(gen_cfg, 'style_enc', None)
112
+ if style_enc_cfg is None:
113
+ style_enc_cfg = types.SimpleNamespace()
114
+ if not hasattr(style_enc_cfg, 'num_filters'):
115
+ setattr(style_enc_cfg, 'num_filters', 128)
116
+ if not hasattr(style_enc_cfg, 'kernel_size'):
117
+ setattr(style_enc_cfg, 'kernel_size', 3)
118
+ if not hasattr(style_enc_cfg, 'weight_norm_type'):
119
+ setattr(style_enc_cfg, 'weight_norm_type', weight_norm_type)
120
+ setattr(style_enc_cfg, 'input_image_channels', image_channels)
121
+ setattr(style_enc_cfg, 'style_dims', style_dims)
122
+ self.style_encoder = StyleEncoder(style_enc_cfg)
123
+
124
+ self.z = None
125
+ print('Done with the SPADE generator initialization.')
126
+
127
+ def forward(self, data, random_style=False):
128
+ r"""SPADE Generator forward.
129
+
130
+ Args:
131
+ data (dict):
132
+ - images (N x C1 x H x W tensor) : Ground truth images
133
+ - label (N x C2 x H x W tensor) : Semantic representations
134
+ - z (N x style_dims tensor): Gaussian random noise
135
+ - random_style (bool): Whether to sample a random style vector.
136
+ Returns:
137
+ (dict):
138
+ - fake_images (N x 3 x H x W tensor): fake images
139
+ - mu (N x C1 tensor): mean vectors
140
+ - logvar (N x C1 tensor): log-variance vectors
141
+ """
142
+ if self.use_style_encoder:
143
+ if random_style:
144
+ bs = data['label'].size(0)
145
+ z = torch.randn(
146
+ bs, self.style_dims, dtype=torch.float32).cuda()
147
+ if (data['label'].dtype ==
148
+ data['label'].dtype == torch.float16):
149
+ z = z.half()
150
+ mu = None
151
+ logvar = None
152
+ else:
153
+ mu, logvar, z = self.style_encoder(data['images'])
154
+ if self.use_attribute:
155
+ data['z'] = torch.cat((z, data['attributes'].squeeze(1)), dim=1)
156
+ else:
157
+ data['z'] = z
158
+ output = self.spade_generator(data)
159
+ if self.use_style_encoder:
160
+ output['mu'] = mu
161
+ output['logvar'] = logvar
162
+ return output
163
+
164
+ def inference(self,
165
+ data,
166
+ random_style=False,
167
+ use_fixed_random_style=False,
168
+ keep_original_size=False):
169
+ r"""Compute results images for a batch of input data and save the
170
+ results in the specified folder.
171
+
172
+ Args:
173
+ data (dict):
174
+ - images (N x C1 x H x W tensor) : Ground truth images
175
+ - label (N x C2 x H x W tensor) : Semantic representations
176
+ - z (N x style_dims tensor): Gaussian random noise
177
+ random_style (bool): Whether to sample a random style vector.
178
+ use_fixed_random_style (bool): Sample random style once and use it
179
+ for all the remaining inference.
180
+ keep_original_size (bool): Keep original size of the input.
181
+ Returns:
182
+ (dict):
183
+ - fake_images (N x 3 x H x W tensor): fake images
184
+ - mu (N x C1 tensor): mean vectors
185
+ - logvar (N x C1 tensor): log-variance vectors
186
+ """
187
+ self.eval()
188
+ self.spade_generator.eval()
189
+
190
+ if self.use_style_encoder:
191
+ if random_style and self.use_style_encoder:
192
+ if self.z is None or not use_fixed_random_style:
193
+ bs = data['label'].size(0)
194
+ z = torch.randn(
195
+ bs, self.style_dims, dtype=torch.float32).to('cuda')
196
+ if (data['label'].dtype ==
197
+ data['label'].dtype ==
198
+ torch.float16):
199
+ z = z.half()
200
+ self.z = z
201
+ else:
202
+ z = self.z
203
+ else:
204
+ mu, logvar, z = self.style_encoder(data['images'])
205
+ data['z'] = z
206
+
207
+ output = self.spade_generator(data)
208
+ output_images = output['fake_images']
209
+
210
+ if keep_original_size:
211
+ height = data['original_h_w'][0][0]
212
+ width = data['original_h_w'][0][1]
213
+ output_images = torch.nn.functional.interpolate(
214
+ output_images, size=[height, width])
215
+
216
+ for key in data['key'].keys():
217
+ if 'segmaps' in key or 'seg_maps' in key:
218
+ file_names = data['key'][key][0]
219
+ break
220
+ for key in data['key'].keys():
221
+ if 'edgemaps' in key or 'edge_maps' in key:
222
+ file_names = data['key'][key][0]
223
+ break
224
+
225
+ return output_images, file_names
226
+
227
+
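# Illustrative sketch (not part of the file above): when random_style is set, the SPADE wrapper
# draws a Gaussian style code and matches it to the label precision before feeding the generator.
import torch

labels = torch.zeros(2, 184, 256, 256)                     # N x C x H x W semantic maps
style_dims = 256
z = torch.randn(labels.size(0), style_dims, dtype=torch.float32)
if labels.dtype == torch.float16:
    z = z.half()                                            # keep the style code in fp16 as well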
228
+ class SPADEGenerator(nn.Module):
229
+ r"""SPADE Image Generator constructor.
230
+
231
+ Args:
232
+ num_labels (int): Number of different labels.
233
+ out_image_small_side_size (int): min(width, height)
234
+ image_channels (int): Num. of channels of the output image.
235
+ num_filters (int): Base filter numbers.
236
+ kernel_size (int): Convolution kernel size.
237
+ style_dims (int): Dimensions of the style code.
238
+ activation_norm_params (obj): Spatially adaptive normalization param.
239
+ weight_norm_type (str): Type of weight normalization.
240
+ ``'none'``, ``'spectral'``, or ``'weight'``.
241
+ global_adaptive_norm_type (str): Type of normalization in SPADE.
242
+ skip_activation_norm (bool): If ``True``, applies activation norm to the
243
+ shortcut connection in residual blocks.
244
+ use_style_encoder (bool): Whether to use global adaptive norm
245
+ like conditional batch norm or adaptive instance norm.
246
+ output_multiplier (float): A positive number multiplied to the output
247
+ """
248
+
249
+ def __init__(self,
250
+ num_labels,
251
+ out_image_small_side_size,
252
+ image_channels,
253
+ num_filters,
254
+ kernel_size,
255
+ style_dims,
256
+ activation_norm_params,
257
+ weight_norm_type,
258
+ global_adaptive_norm_type,
259
+ skip_activation_norm,
260
+ use_posenc_in_input_layer,
261
+ use_style_encoder,
262
+ output_multiplier):
263
+ super(SPADEGenerator, self).__init__()
264
+ self.output_multiplier = output_multiplier
265
+ self.use_style_encoder = use_style_encoder
266
+ self.use_posenc_in_input_layer = use_posenc_in_input_layer
267
+ self.out_image_small_side_size = out_image_small_side_size
268
+ self.num_filters = num_filters
269
+ padding = int(np.ceil((kernel_size - 1.0) / 2))
270
+ nonlinearity = 'leakyrelu'
271
+ activation_norm_type = 'spatially_adaptive'
272
+ base_res2d_block = \
273
+ functools.partial(Res2dBlock,
274
+ kernel_size=kernel_size,
275
+ padding=padding,
276
+ bias=[True, True, False],
277
+ weight_norm_type=weight_norm_type,
278
+ activation_norm_type=activation_norm_type,
279
+ activation_norm_params=activation_norm_params,
280
+ skip_activation_norm=skip_activation_norm,
281
+ nonlinearity=nonlinearity,
282
+ order='NACNAC')
283
+ if self.use_style_encoder:
284
+ self.fc_0 = LinearBlock(style_dims, 2 * style_dims,
285
+ weight_norm_type=weight_norm_type,
286
+ nonlinearity='relu',
287
+ order='CAN')
288
+ self.fc_1 = LinearBlock(2 * style_dims, 2 * style_dims,
289
+ weight_norm_type=weight_norm_type,
290
+ nonlinearity='relu',
291
+ order='CAN')
292
+
293
+ adaptive_norm_params = types.SimpleNamespace()
294
+ if not hasattr(adaptive_norm_params, 'cond_dims'):
295
+ setattr(adaptive_norm_params, 'cond_dims', 2 * style_dims)
296
+ if not hasattr(adaptive_norm_params, 'activation_norm_type'):
297
+ setattr(adaptive_norm_params, 'activation_norm_type', global_adaptive_norm_type)
298
+ if not hasattr(adaptive_norm_params, 'weight_norm_type'):
299
+ setattr(adaptive_norm_params, 'weight_norm_type', activation_norm_params.weight_norm_type)
300
+ if not hasattr(adaptive_norm_params, 'separate_projection'):
301
+ setattr(adaptive_norm_params, 'separate_projection', activation_norm_params.separate_projection)
302
+ adaptive_norm_params.activation_norm_params = types.SimpleNamespace()
303
+ setattr(adaptive_norm_params.activation_norm_params, 'affine',
304
+ activation_norm_params.activation_norm_params.affine)
305
+ base_cbn2d_block = \
306
+ functools.partial(Conv2dBlock,
307
+ kernel_size=kernel_size,
308
+ stride=1,
309
+ padding=padding,
310
+ bias=True,
311
+ weight_norm_type=weight_norm_type,
312
+ activation_norm_type='adaptive',
313
+ activation_norm_params=adaptive_norm_params,
314
+ nonlinearity=nonlinearity,
315
+ order='NAC')
316
+ else:
317
+ base_conv2d_block = \
318
+ functools.partial(Conv2dBlock,
319
+ kernel_size=kernel_size,
320
+ stride=1,
321
+ padding=padding,
322
+ bias=True,
323
+ weight_norm_type=weight_norm_type,
324
+ nonlinearity=nonlinearity,
325
+ order='NAC')
326
+ in_num_labels = num_labels
327
+ in_num_labels += 2 if self.use_posenc_in_input_layer else 0
328
+ self.head_0 = Conv2dBlock(in_num_labels, 8 * num_filters,
329
+ kernel_size=kernel_size, stride=1,
330
+ padding=padding,
331
+ weight_norm_type=weight_norm_type,
332
+ activation_norm_type='none',
333
+ nonlinearity=nonlinearity)
334
+ if self.use_style_encoder:
335
+ self.cbn_head_0 = base_cbn2d_block(
336
+ 8 * num_filters, 16 * num_filters)
337
+ else:
338
+ self.conv_head_0 = base_conv2d_block(
339
+ 8 * num_filters, 16 * num_filters)
340
+ self.head_1 = base_res2d_block(16 * num_filters, 16 * num_filters)
341
+ self.head_2 = base_res2d_block(16 * num_filters, 16 * num_filters)
342
+
343
+ self.up_0a = base_res2d_block(16 * num_filters, 8 * num_filters)
344
+ if self.use_style_encoder:
345
+ self.cbn_up_0a = base_cbn2d_block(
346
+ 8 * num_filters, 8 * num_filters)
347
+ else:
348
+ self.conv_up_0a = base_conv2d_block(
349
+ 8 * num_filters, 8 * num_filters)
350
+ self.up_0b = base_res2d_block(8 * num_filters, 8 * num_filters)
351
+
352
+ self.up_1a = base_res2d_block(8 * num_filters, 4 * num_filters)
353
+ if self.use_style_encoder:
354
+ self.cbn_up_1a = base_cbn2d_block(
355
+ 4 * num_filters, 4 * num_filters)
356
+ else:
357
+ self.conv_up_1a = base_conv2d_block(
358
+ 4 * num_filters, 4 * num_filters)
359
+ self.up_1b = base_res2d_block(4 * num_filters, 4 * num_filters)
360
+ self.up_2a = base_res2d_block(4 * num_filters, 4 * num_filters)
361
+ if self.use_style_encoder:
362
+ self.cbn_up_2a = base_cbn2d_block(
363
+ 4 * num_filters, 4 * num_filters)
364
+ else:
365
+ self.conv_up_2a = base_conv2d_block(
366
+ 4 * num_filters, 4 * num_filters)
367
+ self.up_2b = base_res2d_block(4 * num_filters, 2 * num_filters)
368
+ self.conv_img256 = Conv2dBlock(2 * num_filters, image_channels,
369
+ 5, stride=1, padding=2,
370
+ weight_norm_type=weight_norm_type,
371
+ activation_norm_type='none',
372
+ nonlinearity=nonlinearity,
373
+ order='ANC')
374
+ self.base = 16
375
+ if self.out_image_small_side_size == 512:
376
+ self.up_3a = base_res2d_block(2 * num_filters, 1 * num_filters)
377
+ self.up_3b = base_res2d_block(1 * num_filters, 1 * num_filters)
378
+ self.conv_img512 = Conv2dBlock(1 * num_filters, image_channels,
379
+ 5, stride=1, padding=2,
380
+ weight_norm_type=weight_norm_type,
381
+ activation_norm_type='none',
382
+ nonlinearity=nonlinearity,
383
+ order='ANC')
384
+ self.base = 32
385
+ if self.out_image_small_side_size == 1024:
386
+ self.up_3a = base_res2d_block(2 * num_filters, 1 * num_filters)
387
+ self.up_3b = base_res2d_block(1 * num_filters, 1 * num_filters)
388
+ self.conv_img512 = Conv2dBlock(1 * num_filters, image_channels,
389
+ 5, stride=1, padding=2,
390
+ weight_norm_type=weight_norm_type,
391
+ activation_norm_type='none',
392
+ nonlinearity=nonlinearity,
393
+ order='ANC')
394
+ self.up_4a = base_res2d_block(num_filters, num_filters // 2)
395
+ self.up_4b = base_res2d_block(num_filters // 2, num_filters // 2)
396
+ self.conv_img1024 = Conv2dBlock(num_filters // 2, image_channels,
397
+ 5, stride=1, padding=2,
398
+ weight_norm_type=weight_norm_type,
399
+ activation_norm_type='none',
400
+ nonlinearity=nonlinearity,
401
+ order='ANC')
402
+ self.nearest_upsample4x = NearestUpsample(scale_factor=4, mode='nearest')
403
+ self.base = 64
404
+ if self.out_image_small_side_size != 256 and self.out_image_small_side_size != 512 \
405
+ and self.out_image_small_side_size != 1024:
406
+ raise ValueError('Generation image size (%d, %d) not supported' %
407
+ (self.out_image_small_side_size,
408
+ self.out_image_small_side_size))
409
+ self.nearest_upsample2x = NearestUpsample(scale_factor=2, mode='nearest')
410
+
411
+ xv, yv = torch.meshgrid(
412
+ [torch.arange(-1, 1.1, 2. / 15), torch.arange(-1, 1.1, 2. / 15)])
413
+ self.xy = torch.cat((xv.unsqueeze(0), yv.unsqueeze(0)), 0).unsqueeze(0)
414
+ self.xy = self.xy.cuda()
415
+
416
+ def forward(self, data):
417
+ r"""SPADE Generator forward.
418
+
419
+ Args:
420
+ data (dict):
421
+ - data (N x C1 x H x W tensor) : Ground truth images.
422
+ - label (N x C2 x H x W tensor) : Semantic representations.
423
+ - z (N x style_dims tensor): Gaussian random noise.
424
+ Returns:
425
+ output (dict):
426
+ - fake_images (N x 3 x H x W tensor): Fake images.
427
+ """
428
+ seg = data['label']
429
+
430
+ if self.use_style_encoder:
431
+ z = data['z']
432
+ z = self.fc_0(z)
433
+ z = self.fc_1(z)
434
+
435
+ # The code below makes sure that the input to the head is always 16x16.
436
+ sy = math.floor(seg.size()[2] * 1.0 / self.base)
437
+ sx = math.floor(seg.size()[3] * 1.0 / self.base)
438
+
439
+ in_seg = F.interpolate(seg, size=[sy, sx], mode='nearest')
440
+ if self.use_posenc_in_input_layer:
441
+ in_xy = F.interpolate(self.xy, size=[sy, sx], mode='bicubic')
442
+ in_seg_xy = torch.cat(
443
+ (in_seg, in_xy.expand(in_seg.size()[0], 2, sy, sx)), 1)
444
+ else:
445
+ in_seg_xy = in_seg
446
+ # 16x16
447
+ x = self.head_0(in_seg_xy)
448
+ if self.use_style_encoder:
449
+ x = self.cbn_head_0(x, z)
450
+ else:
451
+ x = self.conv_head_0(x)
452
+ x = self.head_1(x, seg)
453
+ x = self.head_2(x, seg)
454
+ x = self.nearest_upsample2x(x)
455
+ # 32x32
456
+ x = self.up_0a(x, seg)
457
+ if self.use_style_encoder:
458
+ x = self.cbn_up_0a(x, z)
459
+ else:
460
+ x = self.conv_up_0a(x)
461
+ x = self.up_0b(x, seg)
462
+ x = self.nearest_upsample2x(x)
463
+ # 64x64
464
+ x = self.up_1a(x, seg)
465
+ if self.use_style_encoder:
466
+ x = self.cbn_up_1a(x, z)
467
+ else:
468
+ x = self.conv_up_1a(x)
469
+ x = self.up_1b(x, seg)
470
+ x = self.nearest_upsample2x(x)
471
+ # 128x128
472
+ x = self.up_2a(x, seg)
473
+ if self.use_style_encoder:
474
+ x = self.cbn_up_2a(x, z)
475
+ else:
476
+ x = self.conv_up_2a(x)
477
+ x = self.up_2b(x, seg)
478
+ x = self.nearest_upsample2x(x)
479
+ # 256x256
480
+ if self.out_image_small_side_size == 256:
481
+ x256 = self.conv_img256(x)
482
+ x = torch.tanh(self.output_multiplier * x256)
483
+ # 512x512
484
+ elif self.out_image_small_side_size == 512:
485
+ x256 = self.conv_img256(x)
486
+ x256 = self.nearest_upsample2x(x256)
487
+ x = self.up_3a(x, seg)
488
+ x = self.up_3b(x, seg)
489
+ x = self.nearest_upsample2x(x)
490
+ x512 = self.conv_img512(x)
491
+ x = torch.tanh(self.output_multiplier * (x256 + x512))
492
+ # 1024x1024
493
+ elif self.out_image_small_side_size == 1024:
494
+ x256 = self.conv_img256(x)
495
+ x256 = self.nearest_upsample4x(x256)
496
+ x = self.up_3a(x, seg)
497
+ x = self.up_3b(x, seg)
498
+ x = self.nearest_upsample2x(x)
499
+ x512 = self.conv_img512(x)
500
+ x512 = self.nearest_upsample2x(x512)
501
+ x = self.up_4a(x, seg)
502
+ x = self.up_4b(x, seg)
503
+ x = self.nearest_upsample2x(x)
504
+ x1024 = self.conv_img1024(x)
505
+ x = torch.tanh(self.output_multiplier * (x256 + x512 + x1024))
506
+ output = dict()
507
+ output['fake_images'] = x
508
+ return output
509
+
510
+
511
+ class StyleEncoder(nn.Module):
512
+ r"""Style Encode constructor.
513
+
514
+ Args:
515
+ style_enc_cfg (obj): Style encoder configuration object.
516
+ """
517
+
518
+ def __init__(self, style_enc_cfg):
519
+ super(StyleEncoder, self).__init__()
520
+ input_image_channels = style_enc_cfg.input_image_channels
521
+ num_filters = style_enc_cfg.num_filters
522
+ kernel_size = style_enc_cfg.kernel_size
523
+ padding = int(np.ceil((kernel_size - 1.0) / 2))
524
+ style_dims = style_enc_cfg.style_dims
525
+ weight_norm_type = style_enc_cfg.weight_norm_type
526
+ activation_norm_type = 'none'
527
+ nonlinearity = 'leakyrelu'
528
+ base_conv2d_block = \
529
+ functools.partial(Conv2dBlock,
530
+ kernel_size=kernel_size,
531
+ stride=2,
532
+ padding=padding,
533
+ weight_norm_type=weight_norm_type,
534
+ activation_norm_type=activation_norm_type,
535
+ # inplace_nonlinearity=True,
536
+ nonlinearity=nonlinearity)
537
+ self.layer1 = base_conv2d_block(input_image_channels, num_filters)
538
+ self.layer2 = base_conv2d_block(num_filters * 1, num_filters * 2)
539
+ self.layer3 = base_conv2d_block(num_filters * 2, num_filters * 4)
540
+ self.layer4 = base_conv2d_block(num_filters * 4, num_filters * 8)
541
+ self.layer5 = base_conv2d_block(num_filters * 8, num_filters * 8)
542
+ self.layer6 = base_conv2d_block(num_filters * 8, num_filters * 8)
543
+ self.fc_mu = LinearBlock(num_filters * 8 * 4 * 4, style_dims)
544
+ self.fc_var = LinearBlock(num_filters * 8 * 4 * 4, style_dims)
545
+
546
+ def forward(self, input_x):
547
+ r"""SPADE Style Encoder forward.
548
+
549
+ Args:
550
+ input_x (N x 3 x H x W tensor): input images.
551
+ Returns:
552
+ (tuple):
553
+ - mu (N x C tensor): Mean vectors.
554
+ - logvar (N x C tensor): Log-variance vectors.
555
+ - z (N x C tensor): Style code vectors.
556
+ """
557
+ if input_x.size(2) != 256 or input_x.size(3) != 256:
558
+ input_x = F.interpolate(input_x, size=(256, 256), mode='bilinear')
559
+ x = self.layer1(input_x)
560
+ x = self.layer2(x)
561
+ x = self.layer3(x)
562
+ x = self.layer4(x)
563
+ x = self.layer5(x)
564
+ x = self.layer6(x)
565
+ x = x.view(x.size(0), -1)
566
+ mu = self.fc_mu(x)
567
+ logvar = self.fc_var(x)
568
+ std = torch.exp(0.5 * logvar)
569
+ eps = torch.randn_like(std)
570
+ z = eps.mul(std) + mu
571
+ return mu, logvar, z
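A minimal usage sketch of the style encoder above (illustration only; the config fields simply mirror the constructor arguments, and the real values come from the model config). The forward pass returns the posterior mean and log-variance together with a reparameterized style code, from which the usual VAE KL penalty can be formed:

from types import SimpleNamespace
import torch

style_enc_cfg = SimpleNamespace(input_image_channels=3, num_filters=64,
                                kernel_size=3, style_dims=256,
                                weight_norm_type='none')
encoder = StyleEncoder(style_enc_cfg)
images = torch.randn(2, 3, 256, 256)             # batch of real images
mu, logvar, z = encoder(images)                  # z = mu + eps * std (reparameterization)
kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())  # standard VAE KL term
# When use_style_encoder is True, this z is what the generator reads from data['z'].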
imaginaire/layers/__init__.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, check out LICENSE.md
5
+ from .conv import LinearBlock, Conv1dBlock, Conv2dBlock, Conv3dBlock, \
6
+ HyperConv2dBlock, MultiOutConv2dBlock, \
7
+ PartialConv2dBlock, PartialConv3dBlock
8
+ from .residual import ResLinearBlock, Res1dBlock, Res2dBlock, Res3dBlock, \
9
+ HyperRes2dBlock, MultiOutRes2dBlock, UpRes2dBlock, DownRes2dBlock, \
10
+ PartialRes2dBlock, PartialRes3dBlock
11
+ from .non_local import NonLocal2dBlock
12
+
13
+ __all__ = ['Conv1dBlock', 'Conv2dBlock', 'Conv3dBlock', 'LinearBlock',
14
+ 'HyperConv2dBlock', 'MultiOutConv2dBlock',
15
+ 'PartialConv2dBlock', 'PartialConv3dBlock',
16
+ 'Res1dBlock', 'Res2dBlock', 'Res3dBlock',
17
+ 'UpRes2dBlock', 'DownRes2dBlock',
18
+ 'ResLinearBlock', 'HyperRes2dBlock', 'MultiOutRes2dBlock',
19
+ 'PartialRes2dBlock', 'PartialRes3dBlock',
20
+ 'NonLocal2dBlock']
21
+
22
+ try:
23
+ from .repvgg import RepVGG1dBlock, RepVGG2dBlock, RepVGG3dBlock
24
+ from .attn import MultiheadAttention
25
+ __all__.extend(['RepVGG1dBlock', 'RepVGG2dBlock', 'RepVGG3dBlock'])
26
+ except: # noqa
27
+ pass
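A quick construction sketch using the re-exported blocks (illustrative values only; the keyword names follow the Conv2dBlock docstring in conv.py below):

import torch
from imaginaire.layers import Conv2dBlock

# Convolution -> instance norm -> LeakyReLU ('CNA'), with spectral weight norm.
block = Conv2dBlock(3, 64, kernel_size=3, stride=1, padding=1,
                    weight_norm_type='spectral',
                    activation_norm_type='instance',
                    nonlinearity='leakyrelu',
                    order='CNA')
y = block(torch.randn(1, 3, 128, 128))  # -> (1, 64, 128, 128)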
imaginaire/layers/activation_norm.py ADDED
@@ -0,0 +1,629 @@
1
+ # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, check out LICENSE.md
5
+ # flake8: noqa E722
6
+ from types import SimpleNamespace
7
+
8
+ import torch
9
+
10
+ try:
11
+ from torch.nn import SyncBatchNorm
12
+ except ImportError:
13
+ from torch.nn import BatchNorm2d as SyncBatchNorm
14
+ from torch import nn
15
+ from torch.nn import functional as F
16
+ from .conv import LinearBlock, Conv2dBlock, HyperConv2d, PartialConv2dBlock
17
+ from .misc import PartialSequential, ApplyNoise
18
+
19
+
20
+ class AdaptiveNorm(nn.Module):
21
+ r"""Adaptive normalization layer. The layer first normalizes the input, then
22
+ performs an affine transformation using parameters computed from the
23
+ conditional inputs.
24
+
25
+ Args:
26
+ num_features (int): Number of channels in the input tensor.
27
+ cond_dims (int): Number of channels in the conditional inputs.
28
+ weight_norm_type (str): Type of weight normalization.
29
+ ``'none'``, ``'spectral'``, ``'weight'``, or ``'weight_demod'``.
30
+ projection (bool): If ``True``, project the conditional input to gamma
31
+ and beta using a fully connected layer, otherwise directly use
32
+ the conditional input as gamma and beta.
33
+ projection_bias (bool): If ``True``, use bias in the fully connected
34
+ projection layer.
35
+ separate_projection (bool): If ``True``, we will use two different
36
+ layers for gamma and beta. Otherwise, we will use one layer. It
37
+ matters only if you apply any weight norms to this layer.
38
+ input_dim (int): Number of dimensions of the input tensor.
39
+ activation_norm_type (str):
40
+ Type of activation normalization.
41
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
42
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
43
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
44
+ activation_norm_params (obj, optional, default=None):
45
+ Parameters of activation normalization.
46
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
47
+ keyword arguments when initializing activation normalization.
48
+ """
49
+
50
+ def __init__(self, num_features, cond_dims, weight_norm_type='',
51
+ projection=True,
52
+ projection_bias=True,
53
+ separate_projection=False,
54
+ input_dim=2,
55
+ activation_norm_type='instance',
56
+ activation_norm_params=None,
57
+ apply_noise=False,
58
+ add_bias=True,
59
+ input_scale=1.0,
60
+ init_gain=1.0):
61
+ super().__init__()
62
+ if activation_norm_params is None:
63
+ activation_norm_params = SimpleNamespace(affine=False)
64
+ self.norm = get_activation_norm_layer(num_features,
65
+ activation_norm_type,
66
+ input_dim,
67
+ **vars(activation_norm_params))
68
+ if apply_noise:
69
+ self.noise_layer = ApplyNoise()
70
+ else:
71
+ self.noise_layer = None
72
+
73
+ if projection:
74
+ if separate_projection:
75
+ self.fc_gamma = \
76
+ LinearBlock(cond_dims, num_features,
77
+ weight_norm_type=weight_norm_type,
78
+ bias=projection_bias)
79
+ self.fc_beta = \
80
+ LinearBlock(cond_dims, num_features,
81
+ weight_norm_type=weight_norm_type,
82
+ bias=projection_bias)
83
+ else:
84
+ self.fc = LinearBlock(cond_dims, num_features * 2,
85
+ weight_norm_type=weight_norm_type,
86
+ bias=projection_bias)
87
+
88
+ self.projection = projection
89
+ self.separate_projection = separate_projection
90
+ self.input_scale = input_scale
91
+ self.add_bias = add_bias
92
+ self.conditional = True
93
+ self.init_gain = init_gain
94
+
95
+ def forward(self, x, y, noise=None, **_kwargs):
96
+ r"""Adaptive Normalization forward.
97
+
98
+ Args:
99
+ x (N x C1 x * tensor): Input tensor.
100
+ y (N x C2 tensor): Conditional information.
101
+ Returns:
102
+ out (N x C1 x * tensor): Output tensor.
103
+ """
104
+ y = y * self.input_scale
105
+ if self.projection:
106
+ if self.separate_projection:
107
+ gamma = self.fc_gamma(y)
108
+ beta = self.fc_beta(y)
109
+ for _ in range(x.dim() - gamma.dim()):
110
+ gamma = gamma.unsqueeze(-1)
111
+ beta = beta.unsqueeze(-1)
112
+ else:
113
+ y = self.fc(y)
114
+ for _ in range(x.dim() - y.dim()):
115
+ y = y.unsqueeze(-1)
116
+ gamma, beta = y.chunk(2, 1)
117
+ else:
118
+ for _ in range(x.dim() - y.dim()):
119
+ y = y.unsqueeze(-1)
120
+ gamma, beta = y.chunk(2, 1)
121
+ if self.norm is not None:
122
+ x = self.norm(x)
123
+ if self.noise_layer is not None:
124
+ x = self.noise_layer(x, noise=noise)
125
+ if self.add_bias:
126
+ x = torch.addcmul(beta, x, 1 + gamma)
127
+ return x
128
+ else:
129
+ return x * (1 + gamma), beta.squeeze(3).squeeze(2)
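# A minimal usage sketch of AdaptiveNorm (illustration only, assuming the
# defaults above): the input is normalized, then modulated per channel with an
# affine transform predicted from the conditioning vector,
#     out = norm(x) * (1 + gamma) + beta,   where gamma, beta = fc(y).chunk(2, 1)
adanorm = AdaptiveNorm(num_features=64, cond_dims=128)
x = torch.randn(4, 64, 32, 32)   # feature map to modulate
y = torch.randn(4, 128)          # conditioning code, e.g. a mapped style vector
out = adanorm(x, y)              # same shape as x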
130
+
131
+
132
+ class SpatiallyAdaptiveNorm(nn.Module):
133
+ r"""Spatially Adaptive Normalization (SPADE) initialization.
134
+
135
+ Args:
136
+ num_features (int) : Number of channels in the input tensor.
137
+ cond_dims (int or list of int) : List of numbers of channels
138
+ in the conditional inputs.
139
+ num_filters (int): Number of filters in SPADE.
140
+ kernel_size (int): Kernel size of the convolutional filters in
141
+ the SPADE layer.
142
+ weight_norm_type (str): Type of weight normalization.
143
+ ``'none'``, ``'spectral'``, or ``'weight'``.
144
+ separate_projection (bool): If ``True``, we will use two different
145
+ layers for gamma and beta. Otherwise, we will use one layer. It
146
+ matters only if you apply any weight norms to this layer.
147
+ activation_norm_type (str):
148
+ Type of activation normalization.
149
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
150
+ ``'layer'``, ``'layer_2d'``, ``'group'``.
151
+ activation_norm_params (obj, optional, default=None):
152
+ Parameters of activation normalization.
153
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
154
+ keyword arguments when initializing activation normalization.
155
+ """
156
+
157
+ def __init__(self,
158
+ num_features,
159
+ cond_dims,
160
+ num_filters=128,
161
+ kernel_size=3,
162
+ weight_norm_type='',
163
+ separate_projection=False,
164
+ activation_norm_type='sync_batch',
165
+ activation_norm_params=None,
166
+ bias_only=False,
167
+ partial=False,
168
+ interpolation='nearest'):
169
+ super().__init__()
170
+ if activation_norm_params is None:
171
+ activation_norm_params = SimpleNamespace(affine=False)
172
+ padding = kernel_size // 2
173
+ self.separate_projection = separate_projection
174
+ self.mlps = nn.ModuleList()
175
+ self.gammas = nn.ModuleList()
176
+ self.betas = nn.ModuleList()
177
+ self.bias_only = bias_only
178
+ self.interpolation = interpolation
179
+
180
+ # Make cond_dims a list.
181
+ if type(cond_dims) != list:
182
+ cond_dims = [cond_dims]
183
+
184
+ # Make num_filters a list.
185
+ if not isinstance(num_filters, list):
186
+ num_filters = [num_filters] * len(cond_dims)
187
+ else:
188
+ assert len(num_filters) >= len(cond_dims)
189
+
190
+ # Make partial a list.
191
+ if not isinstance(partial, list):
192
+ partial = [partial] * len(cond_dims)
193
+ else:
194
+ assert len(partial) >= len(cond_dims)
195
+
196
+ for i, cond_dim in enumerate(cond_dims):
197
+ mlp = []
198
+ conv_block = PartialConv2dBlock if partial[i] else Conv2dBlock
199
+ sequential = PartialSequential if partial[i] else nn.Sequential
200
+
201
+ if num_filters[i] > 0:
202
+ mlp += [conv_block(cond_dim,
203
+ num_filters[i],
204
+ kernel_size,
205
+ padding=padding,
206
+ weight_norm_type=weight_norm_type,
207
+ nonlinearity='relu')]
208
+ mlp_ch = cond_dim if num_filters[i] == 0 else num_filters[i]
209
+
210
+ if self.separate_projection:
211
+ if partial[i]:
212
+ raise NotImplementedError(
213
+ 'Separate projection not yet implemented for ' +
214
+ 'partial conv')
215
+ self.mlps.append(nn.Sequential(*mlp))
216
+ self.gammas.append(
217
+ conv_block(mlp_ch, num_features,
218
+ kernel_size,
219
+ padding=padding,
220
+ weight_norm_type=weight_norm_type))
221
+ self.betas.append(
222
+ conv_block(mlp_ch, num_features,
223
+ kernel_size,
224
+ padding=padding,
225
+ weight_norm_type=weight_norm_type))
226
+ else:
227
+ mlp += [conv_block(mlp_ch, num_features * 2, kernel_size,
228
+ padding=padding,
229
+ weight_norm_type=weight_norm_type)]
230
+ self.mlps.append(sequential(*mlp))
231
+
232
+ self.norm = get_activation_norm_layer(num_features,
233
+ activation_norm_type,
234
+ 2,
235
+ **vars(activation_norm_params))
236
+ self.conditional = True
237
+
238
+ def forward(self, x, *cond_inputs, **_kwargs):
239
+ r"""Spatially Adaptive Normalization (SPADE) forward.
240
+
241
+ Args:
242
+ x (N x C1 x H x W tensor) : Input tensor.
243
+ cond_inputs (list of tensors) : Conditional maps for SPADE.
244
+ Returns:
245
+ output (4D tensor) : Output tensor.
246
+ """
247
+ output = self.norm(x) if self.norm is not None else x
248
+ for i in range(len(cond_inputs)):
249
+ if cond_inputs[i] is None:
250
+ continue
251
+ label_map = F.interpolate(cond_inputs[i], size=x.size()[2:], mode=self.interpolation)
252
+ if self.separate_projection:
253
+ hidden = self.mlps[i](label_map)
254
+ gamma = self.gammas[i](hidden)
255
+ beta = self.betas[i](hidden)
256
+ else:
257
+ affine_params = self.mlps[i](label_map)
258
+ gamma, beta = affine_params.chunk(2, dim=1)
259
+ if self.bias_only:
260
+ output = output + beta
261
+ else:
262
+ output = output * (1 + gamma) + beta
263
+ return output
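# A minimal usage sketch of SPADE (illustration only; 'instance' norm is chosen
# here so the sketch runs without distributed sync_batch): gamma and beta are
# predicted per pixel from a conditioning map that is resized to the feature
# resolution, then out = norm(x) * (1 + gamma) + beta.
spade = SpatiallyAdaptiveNorm(num_features=64, cond_dims=35,
                              activation_norm_type='instance')
x = torch.randn(2, 64, 32, 32)      # generator feature map
seg = torch.randn(2, 35, 256, 256)  # semantic map, interpolated to 32x32 internally
out = spade(x, seg)                 # same shape as x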
264
+
265
+
266
+ class DualAdaptiveNorm(nn.Module):
267
+ def __init__(self,
268
+ num_features,
269
+ cond_dims,
270
+ projection_bias=True,
271
+ weight_norm_type='',
272
+ activation_norm_type='instance',
273
+ activation_norm_params=None,
274
+ apply_noise=False,
275
+ bias_only=False,
276
+ init_gain=1.0,
277
+ fc_scale=None,
278
+ is_spatial=None):
279
+ super().__init__()
280
+ if activation_norm_params is None:
281
+ activation_norm_params = SimpleNamespace(affine=False)
282
+ self.mlps = nn.ModuleList()
283
+ self.gammas = nn.ModuleList()
284
+ self.betas = nn.ModuleList()
285
+ self.bias_only = bias_only
286
+
287
+ # Make cond_dims a list.
288
+ if type(cond_dims) != list:
289
+ cond_dims = [cond_dims]
290
+
291
+ if is_spatial is None:
292
+ is_spatial = [False for _ in range(len(cond_dims))]
293
+ self.is_spatial = is_spatial
294
+
295
+ for cond_dim, this_is_spatial in zip(cond_dims, is_spatial):
296
+ kwargs = dict(weight_norm_type=weight_norm_type,
297
+ bias=projection_bias,
298
+ init_gain=init_gain,
299
+ output_scale=fc_scale)
300
+ if this_is_spatial:
301
+ self.gammas.append(Conv2dBlock(cond_dim, num_features, 1, 1, 0, **kwargs))
302
+ self.betas.append(Conv2dBlock(cond_dim, num_features, 1, 1, 0, **kwargs))
303
+ else:
304
+ self.gammas.append(LinearBlock(cond_dim, num_features, **kwargs))
305
+ self.betas.append(LinearBlock(cond_dim, num_features, **kwargs))
306
+
307
+ self.norm = get_activation_norm_layer(num_features,
308
+ activation_norm_type,
309
+ 2,
310
+ **vars(activation_norm_params))
311
+ self.conditional = True
312
+
313
+ def forward(self, x, *cond_inputs, **_kwargs):
314
+ assert len(cond_inputs) == len(self.gammas)
315
+ output = self.norm(x) if self.norm is not None else x
316
+ for cond, gamma_layer, beta_layer in zip(cond_inputs, self.gammas, self.betas):
317
+ if cond is None:
318
+ continue
319
+ gamma = gamma_layer(cond)
320
+ beta = beta_layer(cond)
321
+ if cond.dim() == 4 and gamma.shape != x.shape:
322
+ gamma = F.interpolate(gamma, size=x.size()[2:], mode='bilinear')
323
+ beta = F.interpolate(beta, size=x.size()[2:], mode='bilinear')
324
+ elif cond.dim() == 2:
325
+ gamma = gamma[:, :, None, None]
326
+ beta = beta[:, :, None, None]
327
+ if self.bias_only:
328
+ output = output + beta
329
+ else:
330
+ output = output * (1 + gamma) + beta
331
+ return output
332
+
333
+
334
+ class HyperSpatiallyAdaptiveNorm(nn.Module):
335
+ r"""Spatially Adaptive Normalization (SPADE) initialization.
336
+
337
+ Args:
338
+ num_features (int) : Number of channels in the input tensor.
339
+ cond_dims (int or list of int) : List of numbers of channels
340
+ in the conditional input.
341
+ num_filters (int): Number of filters in SPADE.
342
+ kernel_size (int): Kernel size of the convolutional filters in
343
+ the SPADE layer.
344
+ weight_norm_type (str): Type of weight normalization.
345
+ ``'none'``, ``'spectral'``, or ``'weight'``.
346
+ activation_norm_type (str):
347
+ Type of activation normalization.
348
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
349
+ ``'layer'``, ``'layer_2d'``, ``'group'``.
350
+ is_hyper (bool): Whether to use hyper SPADE.
351
+ """
352
+
353
+ def __init__(self, num_features, cond_dims,
354
+ num_filters=0, kernel_size=3,
355
+ weight_norm_type='',
356
+ activation_norm_type='sync_batch', is_hyper=True):
357
+ super().__init__()
358
+ padding = kernel_size // 2
359
+ self.mlps = nn.ModuleList()
360
+ if type(cond_dims) != list:
361
+ cond_dims = [cond_dims]
362
+
363
+ for i, cond_dim in enumerate(cond_dims):
364
+ mlp = []
365
+ if not is_hyper or (i != 0):
366
+ if num_filters > 0:
367
+ mlp += [Conv2dBlock(cond_dim, num_filters, kernel_size,
368
+ padding=padding,
369
+ weight_norm_type=weight_norm_type,
370
+ nonlinearity='relu')]
371
+ mlp_ch = cond_dim if num_filters == 0 else num_filters
372
+ mlp += [Conv2dBlock(mlp_ch, num_features * 2, kernel_size,
373
+ padding=padding,
374
+ weight_norm_type=weight_norm_type)]
375
+ mlp = nn.Sequential(*mlp)
376
+ else:
377
+ if num_filters > 0:
378
+ raise ValueError('Multi hyper layer not supported yet.')
379
+ mlp = HyperConv2d(padding=padding)
380
+ self.mlps.append(mlp)
381
+
382
+ self.norm = get_activation_norm_layer(num_features,
383
+ activation_norm_type,
384
+ 2,
385
+ affine=False)
386
+
387
+ self.conditional = True
388
+
389
+ def forward(self, x, *cond_inputs,
390
+ norm_weights=(None, None), **_kwargs):
391
+ r"""Spatially Adaptive Normalization (SPADE) forward.
392
+
393
+ Args:
394
+ x (4D tensor) : Input tensor.
395
+ cond_inputs (list of tensors) : Conditional maps for SPADE.
396
+ norm_weights (5D tensor or list of tensors): conv weights or
397
+ [weights, biases].
398
+ Returns:
399
+ output (4D tensor) : Output tensor.
400
+ """
401
+ output = self.norm(x)
402
+ for i in range(len(cond_inputs)):
403
+ if cond_inputs[i] is None:
404
+ continue
405
+ if type(cond_inputs[i]) == list:
406
+ cond_input, mask = cond_inputs[i]
407
+ mask = F.interpolate(mask, size=x.size()[2:], mode='bilinear', align_corners=False)
408
+ else:
409
+ cond_input = cond_inputs[i]
410
+ mask = None
411
+ label_map = F.interpolate(cond_input, size=x.size()[2:])
412
+ if norm_weights is None or norm_weights[0] is None or i != 0:
413
+ affine_params = self.mlps[i](label_map)
414
+ else:
415
+ affine_params = self.mlps[i](label_map,
416
+ conv_weights=norm_weights)
417
+ gamma, beta = affine_params.chunk(2, dim=1)
418
+ if mask is not None:
419
+ gamma = gamma * (1 - mask)
420
+ beta = beta * (1 - mask)
421
+ output = output * (1 + gamma) + beta
422
+ return output
423
+
424
+
425
+ class LayerNorm2d(nn.Module):
426
+ r"""Layer Normalization as introduced in
427
+ https://arxiv.org/abs/1607.06450.
428
+ This is the usual way to apply layer normalization in CNNs.
429
+ Note that unlike the pytorch implementation which applies per-element
430
+ scale and bias, here it applies per-channel scale and bias, similar to
431
+ batch/instance normalization.
432
+
433
+ Args:
434
+ num_features (int): Number of channels in the input tensor.
435
+ eps (float, optional, default=1e-5): a value added to the
436
+ denominator for numerical stability.
437
+ affine (bool, optional, default=False): If ``True``, performs
438
+ affine transformation after normalization.
439
+ """
440
+
441
+ def __init__(self, num_features, eps=1e-5, channel_only=False, affine=True):
442
+ super(LayerNorm2d, self).__init__()
443
+ self.num_features = num_features
444
+ self.affine = affine
445
+ self.eps = eps
446
+ self.channel_only = channel_only
447
+
448
+ if self.affine:
449
+ self.gamma = nn.Parameter(torch.Tensor(num_features).fill_(1.0))
450
+ self.beta = nn.Parameter(torch.zeros(num_features))
451
+
452
+ def forward(self, x):
453
+ r"""
454
+
455
+ Args:
456
+ x (tensor): Input tensor.
457
+ """
458
+ shape = [-1] + [1] * (x.dim() - 1)
459
+ if self.channel_only:
460
+ mean = x.mean(1, keepdim=True)
461
+ std = x.std(1, keepdim=True)
462
+ else:
463
+ mean = x.view(x.size(0), -1).mean(1).view(*shape)
464
+ std = x.view(x.size(0), -1).std(1).view(*shape)
465
+
466
+ x = (x - mean) / (std + self.eps)
467
+
468
+ if self.affine:
469
+ shape = [1, -1] + [1] * (x.dim() - 2)
470
+ x = x * self.gamma.view(*shape) + self.beta.view(*shape)
471
+ return x
472
+
473
+
474
+ class ScaleNorm(nn.Module):
475
+ r"""Scale normalization:
476
+ "Transformers without Tears: Improving the Normalization of Self-Attention"
477
+ Modified from:
478
+ https://github.com/tnq177/transformers_without_tears
479
+ """
480
+
481
+ def __init__(self, dim=-1, learned_scale=True, eps=1e-5):
482
+ super().__init__()
483
+ # scale = num_features ** 0.5
484
+ if learned_scale:
485
+ self.scale = nn.Parameter(torch.tensor(1.))
486
+ else:
487
+ self.scale = 1.
488
+ # self.num_features = num_features
489
+ self.dim = dim
490
+ self.eps = eps
491
+ self.learned_scale = learned_scale
492
+
493
+ def forward(self, x):
494
+ # noinspection PyArgumentList
495
+ scale = self.scale * torch.rsqrt(torch.mean(x ** 2, dim=self.dim, keepdim=True) + self.eps)
496
+ return x * scale
497
+
498
+ def extra_repr(self):
499
+ s = 'learned_scale={learned_scale}'
500
+ return s.format(**self.__dict__)
501
+
502
+
503
+ class PixelNorm(ScaleNorm):
504
+ def __init__(self, learned_scale=False, eps=1e-5, **_kwargs):
505
+ super().__init__(1, learned_scale, eps)
506
+
507
+
508
+ class SplitMeanStd(nn.Module):
509
+ def __init__(self, num_features, eps=1e-5, **kwargs):
510
+ super().__init__()
511
+ self.num_features = num_features
512
+ self.eps = eps
513
+ self.multiple_outputs = True
514
+
515
+ def forward(self, x):
516
+ b, c, h, w = x.size()
517
+ mean = x.view(b, c, -1).mean(-1)[:, :, None, None]
518
+ var = x.view(b, c, -1).var(-1)[:, :, None, None]
519
+ std = torch.sqrt(var + self.eps)
520
+
521
+ # x = (x - mean) / std
522
+ return x, torch.cat((mean, std), dim=1)
523
+
524
+
525
+ class ScaleNorm(nn.Module):
526
+ r"""Scale normalization:
527
+ "Transformers without Tears: Improving the Normalization of Self-Attention"
528
+ Modified from:
529
+ https://github.com/tnq177/transformers_without_tears
530
+ """
531
+
532
+ def __init__(self, dim=-1, learned_scale=True, eps=1e-5):
533
+ super().__init__()
534
+ # scale = num_features ** 0.5
535
+ if learned_scale:
536
+ self.scale = nn.Parameter(torch.tensor(1.))
537
+ else:
538
+ self.scale = 1.
539
+ # self.num_features = num_features
540
+ self.dim = dim
541
+ self.eps = eps
542
+ self.learned_scale = learned_scale
543
+
544
+ def forward(self, x):
545
+ # noinspection PyArgumentList
546
+ scale = self.scale * torch.rsqrt(
547
+ torch.mean(x ** 2, dim=self.dim, keepdim=True) + self.eps)
548
+ return x * scale
549
+
550
+ def extra_repr(self):
551
+ s = 'learned_scale={learned_scale}'
552
+ return s.format(**self.__dict__)
553
+
554
+
555
+ class PixelLayerNorm(nn.Module):
556
+ def __init__(self, *args, **kwargs):
557
+ super().__init__()
558
+ self.norm = nn.LayerNorm(*args, **kwargs)
559
+
560
+ def forward(self, x):
561
+ if x.dim() == 4:
562
+ b, c, h, w = x.shape
563
+ return self.norm(x.permute(0, 2, 3, 1).view(-1, c).contiguous()).view(b, h, w, c).permute(0, 3, 1, 2).contiguous()
564
+ else:
565
+ return self.norm(x)
566
+
567
+
568
+ def get_activation_norm_layer(num_features, norm_type, input_dim, **norm_params):
569
+ r"""Return an activation normalization layer.
570
+
571
+ Args:
572
+ num_features (int): Number of feature channels.
573
+ norm_type (str):
574
+ Type of activation normalization.
575
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
576
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
577
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
578
+ input_dim (int): Number of input dimensions.
579
+ norm_params: Arbitrary keyword arguments that will be used to
580
+ initialize the activation normalization.
581
+ """
582
+ input_dim = max(input_dim, 1) # Norm1d works with both 0d and 1d inputs
583
+
584
+ if norm_type == 'none' or norm_type == '':
585
+ norm_layer = None
586
+ elif norm_type == 'batch':
587
+ norm = getattr(nn, 'BatchNorm%dd' % input_dim)
588
+ norm_layer = norm(num_features, **norm_params)
589
+ elif norm_type == 'instance':
590
+ affine = norm_params.pop('affine', True) # Use affine=True by default
591
+ norm = getattr(nn, 'InstanceNorm%dd' % input_dim)
592
+ norm_layer = norm(num_features, affine=affine, **norm_params)
593
+ elif norm_type == 'sync_batch':
594
+ norm_layer = SyncBatchNorm(num_features, **norm_params)
595
+ elif norm_type == 'layer':
596
+ norm_layer = nn.LayerNorm(num_features, **norm_params)
597
+ elif norm_type == 'layer_2d':
598
+ norm_layer = LayerNorm2d(num_features, **norm_params)
599
+ elif norm_type == 'pixel_layer':
600
+ elementwise_affine = norm_params.pop('affine', True) # Use affine=True by default
601
+ norm_layer = PixelLayerNorm(num_features, elementwise_affine=elementwise_affine, **norm_params)
602
+ elif norm_type == 'scale':
603
+ norm_layer = ScaleNorm(**norm_params)
604
+ elif norm_type == 'pixel':
605
+ norm_layer = PixelNorm(**norm_params)
606
+ import imaginaire.config
607
+ if imaginaire.config.USE_JIT:
608
+ norm_layer = torch.jit.script(norm_layer)
609
+ elif norm_type == 'group':
610
+ num_groups = norm_params.pop('num_groups', 4)
611
+ norm_layer = nn.GroupNorm(num_channels=num_features, num_groups=num_groups, **norm_params)
612
+ elif norm_type == 'adaptive':
613
+ norm_layer = AdaptiveNorm(num_features, **norm_params)
614
+ elif norm_type == 'dual_adaptive':
615
+ norm_layer = DualAdaptiveNorm(num_features, **norm_params)
616
+ elif norm_type == 'spatially_adaptive':
617
+ if input_dim != 2:
618
+ raise ValueError('Spatially adaptive normalization layers '
619
+ 'only support 2D inputs')
620
+ norm_layer = SpatiallyAdaptiveNorm(num_features, **norm_params)
621
+ elif norm_type == 'hyper_spatially_adaptive':
622
+ if input_dim != 2:
623
+ raise ValueError('Spatially adaptive normalization layers '
624
+ 'only support 2D inputs')
625
+ norm_layer = HyperSpatiallyAdaptiveNorm(num_features, **norm_params)
626
+ else:
627
+ raise ValueError('Activation norm layer %s '
628
+ 'is not recognized' % norm_type)
629
+ return norm_layer
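The factory above is what the conv and residual blocks call internally; a small sketch of resolving a norm layer by name (values are arbitrary):

import torch

norm = get_activation_norm_layer(64, 'instance', input_dim=2)  # affine=True by default
x = torch.randn(1, 64, 16, 16)
assert norm(x).shape == x.shape

# 'none' (or '') returns None, so callers guard with `if norm is not None`.
assert get_activation_norm_layer(64, 'none', input_dim=2) is None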
imaginaire/layers/conv.py ADDED
@@ -0,0 +1,1377 @@
1
+ # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, check out LICENSE.md
5
+ import warnings
6
+ from types import SimpleNamespace
7
+
8
+ import torch
9
+ from torch import nn
10
+ from torch.nn import functional as F
11
+
12
+ from .misc import ApplyNoise
13
+ from imaginaire.third_party.upfirdn2d.upfirdn2d import Blur
14
+
15
+
16
+ class _BaseConvBlock(nn.Module):
17
+ r"""An abstract wrapper class that wraps a torch convolution or linear layer
18
+ with normalization and nonlinearity.
19
+ """
20
+
21
+ def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode,
22
+ weight_norm_type, weight_norm_params, activation_norm_type, activation_norm_params, nonlinearity,
23
+ inplace_nonlinearity, apply_noise, blur, order, input_dim, clamp, blur_kernel, output_scale,
24
+ init_gain):
25
+ super().__init__()
26
+ from .nonlinearity import get_nonlinearity_layer
27
+ from .weight_norm import get_weight_norm_layer
28
+ from .activation_norm import get_activation_norm_layer
29
+ self.weight_norm_type = weight_norm_type
30
+ self.stride = stride
31
+ self.clamp = clamp
32
+ self.init_gain = init_gain
33
+
34
+ # Nonlinearity layer.
35
+ if 'fused' in nonlinearity:
36
+ # Fusing nonlinearity with bias.
37
+ lr_mul = getattr(weight_norm_params, 'lr_mul', 1)
38
+ conv_before_nonlinearity = order.find('C') < order.find('A')
39
+ if conv_before_nonlinearity:
40
+ assert bias is True
41
+ bias = False
42
+ channel = out_channels if conv_before_nonlinearity else in_channels
43
+ nonlinearity_layer = get_nonlinearity_layer(
44
+ nonlinearity, inplace=inplace_nonlinearity,
45
+ num_channels=channel, lr_mul=lr_mul)
46
+ else:
47
+ nonlinearity_layer = get_nonlinearity_layer(
48
+ nonlinearity, inplace=inplace_nonlinearity)
49
+
50
+ # Noise injection layer.
51
+ if apply_noise:
52
+ order = order.replace('C', 'CG')
53
+ noise_layer = ApplyNoise()
54
+ else:
55
+ noise_layer = None
56
+
57
+ # Convolutional layer.
58
+ if blur:
59
+ assert blur_kernel is not None
60
+ if stride == 2:
61
+ # Blur - Conv - Noise - Activate
62
+ p = (len(blur_kernel) - 2) + (kernel_size - 1)
63
+ pad0, pad1 = (p + 1) // 2, p // 2
64
+ padding = 0
65
+ blur_layer = Blur(
66
+ blur_kernel, pad=(pad0, pad1), padding_mode=padding_mode
67
+ )
68
+ order = order.replace('C', 'BC')
69
+ elif stride == 0.5:
70
+ # Conv - Blur - Noise - Activate
71
+ padding = 0
72
+ p = (len(blur_kernel) - 2) - (kernel_size - 1)
73
+ pad0, pad1 = (p + 1) // 2 + 1, p // 2 + 1
74
+ blur_layer = Blur(
75
+ blur_kernel, pad=(pad0, pad1), padding_mode=padding_mode
76
+ )
77
+ order = order.replace('C', 'CB')
78
+ elif stride == 1:
79
+ # No blur for now
80
+ blur_layer = nn.Identity()
81
+ else:
82
+ raise NotImplementedError
83
+ else:
84
+ blur_layer = nn.Identity()
85
+
86
+ if weight_norm_params is None:
87
+ weight_norm_params = SimpleNamespace()
88
+ weight_norm = get_weight_norm_layer(
89
+ weight_norm_type, **vars(weight_norm_params))
90
+ conv_layer = weight_norm(self._get_conv_layer(
91
+ in_channels, out_channels, kernel_size, stride, padding, dilation,
92
+ groups, bias, padding_mode, input_dim))
93
+
94
+ # Normalization layer.
95
+ conv_before_norm = order.find('C') < order.find('N')
96
+ norm_channels = out_channels if conv_before_norm else in_channels
97
+ if activation_norm_params is None:
98
+ activation_norm_params = SimpleNamespace()
99
+ activation_norm_layer = get_activation_norm_layer(
100
+ norm_channels,
101
+ activation_norm_type,
102
+ input_dim,
103
+ **vars(activation_norm_params))
104
+
105
+ # Mapping from operation names to layers.
106
+ mappings = {'C': {'conv': conv_layer},
107
+ 'N': {'norm': activation_norm_layer},
108
+ 'A': {'nonlinearity': nonlinearity_layer}}
109
+ mappings.update({'B': {'blur': blur_layer}})
110
+ mappings.update({'G': {'noise': noise_layer}})
111
+
112
+ # All layers in order.
113
+ self.layers = nn.ModuleDict()
114
+ for op in order:
115
+ if list(mappings[op].values())[0] is not None:
116
+ self.layers.update(mappings[op])
117
+
118
+ # Whether this block expects conditional inputs.
119
+ self.conditional = \
120
+ getattr(conv_layer, 'conditional', False) or \
121
+ getattr(activation_norm_layer, 'conditional', False)
122
+
123
+ # Scale the output by a learnable scaler parameter.
124
+ if output_scale is not None:
125
+ self.output_scale = nn.Parameter(torch.tensor(output_scale))
126
+ else:
127
+ self.register_parameter("output_scale", None)
128
+
129
+ def forward(self, x, *cond_inputs, **kw_cond_inputs):
130
+ r"""
131
+
132
+ Args:
133
+ x (tensor): Input tensor.
134
+ cond_inputs (list of tensors) : Conditional input tensors.
135
+ kw_cond_inputs (dict) : Keyword conditional inputs.
136
+ """
137
+ for key, layer in self.layers.items():
138
+ if getattr(layer, 'conditional', False):
139
+ # Layers that require conditional inputs.
140
+ x = layer(x, *cond_inputs, **kw_cond_inputs)
141
+ else:
142
+ x = layer(x)
143
+ if self.clamp is not None and isinstance(layer, nn.Conv2d):
144
+ x.clamp_(max=self.clamp)
145
+ if key == 'conv':
146
+ if self.output_scale is not None:
147
+ x = x * self.output_scale
148
+ return x
149
+
150
+ def _get_conv_layer(self, in_channels, out_channels, kernel_size, stride,
151
+ padding, dilation, groups, bias, padding_mode,
152
+ input_dim):
153
+ # Returns the convolutional layer.
154
+ if input_dim == 0:
155
+ layer = nn.Linear(in_channels, out_channels, bias)
156
+ else:
157
+ if stride < 1: # Fractionally-strided convolution.
158
+ padding_mode = 'zeros'
159
+ assert padding == 0
160
+ layer_type = getattr(nn, f'ConvTranspose{input_dim}d')
161
+ stride = round(1 / stride)
162
+ else:
163
+ layer_type = getattr(nn, f'Conv{input_dim}d')
164
+ layer = layer_type(
165
+ in_channels, out_channels, kernel_size, stride, padding,
166
+ dilation=dilation, groups=groups, bias=bias,
167
+ padding_mode=padding_mode
168
+ )
169
+
170
+ return layer
171
+
172
+ def __repr__(self):
173
+ main_str = self._get_name() + '('
174
+ child_lines = []
175
+ for name, layer in self.layers.items():
176
+ mod_str = repr(layer)
177
+ if name == 'conv' and self.weight_norm_type != 'none' and \
178
+ self.weight_norm_type != '':
179
+ mod_str = mod_str[:-1] + \
180
+ ', weight_norm={}'.format(self.weight_norm_type) + ')'
181
+ if name == 'conv' and getattr(layer, 'base_lr_mul', 1) != 1:
182
+ mod_str = mod_str[:-1] + \
183
+ ', lr_mul={}'.format(layer.base_lr_mul) + ')'
184
+ mod_str = self._addindent(mod_str, 2)
185
+ child_lines.append(mod_str)
186
+ if len(child_lines) == 1:
187
+ main_str += child_lines[0]
188
+ else:
189
+ main_str += '\n ' + '\n '.join(child_lines) + '\n'
190
+
191
+ main_str += ')'
192
+ return main_str
193
+
194
+ @staticmethod
195
+ def _addindent(s_, numSpaces):
196
+ s = s_.split('\n')
197
+ # don't do anything for single-line stuff
198
+ if len(s) == 1:
199
+ return s_
200
+ first = s.pop(0)
201
+ s = [(numSpaces * ' ') + line for line in s]
202
+ s = '\n'.join(s)
203
+ s = first + '\n' + s
204
+ return s
205
+
206
+
207
+ class ModulatedConv2dBlock(_BaseConvBlock):
208
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
209
+ padding=0, dilation=1, groups=1, bias=True,
210
+ padding_mode='zeros',
211
+ weight_norm_type='none', weight_norm_params=None,
212
+ activation_norm_type='none', activation_norm_params=None,
213
+ nonlinearity='none', inplace_nonlinearity=False,
214
+ apply_noise=True, blur=True, order='CNA', demodulate=True,
215
+ eps=True, style_dim=None, clamp=None, blur_kernel=(1, 3, 3, 1), output_scale=None, init_gain=1.0):
216
+ self.eps = eps
217
+ self.demodulate = demodulate
218
+ assert style_dim is not None
219
+
220
+ super().__init__(in_channels, out_channels, kernel_size, stride,
221
+ padding, dilation, groups, bias, padding_mode,
222
+ weight_norm_type, weight_norm_params,
223
+ activation_norm_type, activation_norm_params,
224
+ nonlinearity, inplace_nonlinearity, apply_noise, blur,
225
+ order, 2, clamp, blur_kernel, output_scale, init_gain)
226
+ self.modulation = LinearBlock(style_dim, in_channels,
227
+ weight_norm_type=weight_norm_type,
228
+ weight_norm_params=weight_norm_params)
229
+
230
+ def _get_conv_layer(self, in_channels, out_channels, kernel_size, stride,
231
+ padding, dilation, groups, bias, padding_mode,
232
+ input_dim):
233
+ assert input_dim == 2
234
+ layer = ModulatedConv2d(
235
+ in_channels, out_channels, kernel_size, stride, padding,
236
+ dilation, groups, bias, padding_mode, self.demodulate, self.eps)
237
+ return layer
238
+
239
+ def forward(self, x, *cond_inputs, **kw_cond_inputs):
240
+ for layer in self.layers.values():
241
+ if getattr(layer, 'conditional', False):
242
+ # Layers that require conditional inputs.
243
+ assert len(cond_inputs) == 1
244
+ style = cond_inputs[0]
245
+ x = layer(
246
+ x, self.modulation(style), **kw_cond_inputs
247
+ )
248
+ else:
249
+ x = layer(x)
250
+ if self.clamp is not None and isinstance(layer, ModulatedConv2d):
251
+ x.clamp_(max=self.clamp)
252
+ return x
253
+
254
+ def __repr__(self):
255
+ main_str = self._get_name() + '('
256
+ child_lines = []
257
+ for name, layer in self.layers.items():
258
+ mod_str = repr(layer)
259
+ if name == 'conv' and self.weight_norm_type != 'none' and \
260
+ self.weight_norm_type != '':
261
+ mod_str = mod_str[:-1] + \
262
+ ', weight_norm={}'.format(self.weight_norm_type) + \
263
+ ', demodulate={}'.format(self.demodulate) + ')'
264
+ mod_str = self._addindent(mod_str, 2)
265
+ child_lines.append(mod_str)
266
+ child_lines.append(
267
+ self._addindent('Modulation(' + repr(self.modulation) + ')', 2)
268
+ )
269
+ if len(child_lines) == 1:
270
+ main_str += child_lines[0]
271
+ else:
272
+ main_str += '\n ' + '\n '.join(child_lines) + '\n'
273
+
274
+ main_str += ')'
275
+ return main_str
276
+
277
+
278
+ class ModulatedConv2d(nn.Module):
279
+ def __init__(self, in_channels, out_channels, kernel_size, stride, padding,
280
+ dilation, groups, bias, padding_mode, demodulate=True,
281
+ eps=1e-8):
282
+ # in_channels, out_channels, kernel_size, stride, padding,
283
+ # dilation, groups, bias, padding_mode
284
+ assert dilation == 1 and groups == 1
285
+
286
+ super().__init__()
287
+
288
+ self.eps = eps
289
+ self.kernel_size = kernel_size
290
+ self.in_channels = in_channels
291
+ self.out_channels = out_channels
292
+ self.padding = padding
293
+ self.stride = stride
294
+ self.padding_mode = padding_mode
295
+ # kernel_size // 2
296
+ # assert self.padding == padding
297
+
298
+ self.weight = nn.Parameter(
299
+ torch.randn(out_channels, in_channels, kernel_size, kernel_size)
300
+ )
301
+
302
+ if bias:
303
+ self.bias = nn.Parameter(torch.Tensor(out_channels))
304
+ else:
305
+ # noinspection PyTypeChecker
306
+ self.register_parameter('bias', None)
307
+
308
+ # self.modulation = LinearBlock(style_dim, in_channels,
309
+ # weight_norm_type=weight_norm_type)
310
+ self.demodulate = demodulate
311
+ self.conditional = True
312
+
313
+ def forward(self, x, style, **_kwargs):
314
+ batch, in_channel, height, width = x.shape
315
+
316
+ # style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
317
+ # We assume the modulation layer is outside this module.
318
+ style = style.view(batch, 1, in_channel, 1, 1)
319
+ weight = self.weight.unsqueeze(0) * style
320
+
321
+ if self.demodulate:
322
+ demod = torch.rsqrt(
323
+ weight.pow(2).sum([2, 3, 4]) + self.eps)
324
+ weight = weight * demod.view(batch, self.out_channels, 1, 1, 1)
325
+
326
+ weight = weight.view(
327
+ batch * self.out_channels,
328
+ in_channel, self.kernel_size, self.kernel_size
329
+ )
330
+ if self.bias is not None:
331
+ bias = self.bias.repeat(batch)
332
+ else:
333
+ bias = self.bias
334
+
335
+ x = x.view(1, batch * in_channel, height, width)
336
+
337
+ if self.padding_mode != 'zeros':
338
+ x = F.pad(x, self._reversed_padding_repeated_twice,
339
+ mode=self.padding_mode)
340
+ padding = (0, 0)
341
+ else:
342
+ padding = self.padding
343
+
344
+ if self.stride == 0.5:
345
+ weight = weight.view(
346
+ batch, self.out_channels, in_channel,
347
+ self.kernel_size, self.kernel_size
348
+ )
349
+ weight = weight.transpose(1, 2).reshape(
350
+ batch * in_channel, self.out_channels,
351
+ self.kernel_size, self.kernel_size
352
+ )
353
+ out = F.conv_transpose2d(
354
+ x, weight, bias, padding=padding, stride=2, groups=batch
355
+ )
356
+
357
+ elif self.stride == 2:
358
+ out = F.conv2d(
359
+ x, weight, bias, padding=padding, stride=2, groups=batch
360
+ )
361
+
362
+ else:
363
+ out = F.conv2d(x, weight, bias, padding=padding, groups=batch)
364
+
365
+ _, _, height, width = out.shape
366
+ out = out.view(batch, self.out_channels, height, width)
367
+
368
+ return out
369
+
370
+ def extra_repr(self):
371
+ s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
372
+ ', stride={stride}')
373
+ if self.bias is None:
374
+ s += ', bias=False'
375
+ if self.padding_mode != 'zeros':
376
+ s += ', padding_mode={padding_mode}'
377
+ return s.format(**self.__dict__)
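# A minimal usage sketch of ModulatedConv2d (illustration only, StyleGAN2-style
# modulated convolution): the shared kernel is scaled per sample by a style
# vector (modulation), optionally rescaled so each output filter has unit norm
# (demodulation), and applied as a grouped conv with groups=batch:
#     w' = w * s,   w'' = w' * rsqrt(sum(w'^2) + eps)
conv = ModulatedConv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1,
                       padding=1, dilation=1, groups=1, bias=True,
                       padding_mode='zeros', demodulate=True)
x = torch.randn(2, 32, 16, 16)
style = torch.randn(2, 32)   # already projected to in_channels (see ModulatedConv2dBlock)
out = conv(x, style)         # -> (2, 64, 16, 16)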
378
+
379
+
380
+ class LinearBlock(_BaseConvBlock):
381
+ r"""A Wrapper class that wraps ``torch.nn.Linear`` with normalization and
382
+ nonlinearity.
383
+
384
+ Args:
385
+ in_features (int): Number of channels in the input tensor.
386
+ out_features (int): Number of channels in the output tensor.
387
+ bias (bool, optional, default=True):
388
+ If ``True``, adds a learnable bias to the output.
389
+ weight_norm_type (str, optional, default='none'):
390
+ Type of weight normalization.
391
+ ``'none'``, ``'spectral'``, ``'weight'``
392
+ or ``'weight_demod'``.
393
+ weight_norm_params (obj, optional, default=None):
394
+ Parameters of weight normalization.
395
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
396
+ keyword arguments when initializing weight normalization.
397
+ activation_norm_type (str, optional, default='none'):
398
+ Type of activation normalization.
399
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
400
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
401
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
402
+ activation_norm_params (obj, optional, default=None):
403
+ Parameters of activation normalization.
404
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
405
+ keyword arguments when initializing activation normalization.
406
+ nonlinearity (str, optional, default='none'):
407
+ Type of nonlinear activation function.
408
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
409
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
410
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
411
+ set ``inplace=True`` when initializing the nonlinearity layer.
412
+ apply_noise (bool, optional, default=False): If ``True``, add
413
+ Gaussian noise with learnable magnitude after the
414
+ fully-connected layer.
415
+ order (str, optional, default='CNA'): Order of operations.
416
+ ``'C'``: fully-connected,
417
+ ``'N'``: normalization,
418
+ ``'A'``: nonlinear activation.
419
+ For example, a block initialized with ``order='CNA'`` will
420
+ do convolution first, then normalization, then nonlinearity.
421
+ """
422
+
423
+ def __init__(self, in_features, out_features, bias=True,
424
+ weight_norm_type='none', weight_norm_params=None,
425
+ activation_norm_type='none', activation_norm_params=None,
426
+ nonlinearity='none', inplace_nonlinearity=False,
427
+ apply_noise=False, order='CNA', clamp=None, blur_kernel=(1, 3, 3, 1), output_scale=None,
428
+ init_gain=1.0, **_kwargs):
429
+ if bool(_kwargs):
430
+ warnings.warn(f"Unused keyword arguments {_kwargs}")
431
+ super().__init__(in_features, out_features, None, None,
432
+ None, None, None, bias,
433
+ None, weight_norm_type, weight_norm_params,
434
+ activation_norm_type, activation_norm_params,
435
+ nonlinearity, inplace_nonlinearity, apply_noise,
436
+ False, order, 0, clamp, blur_kernel, output_scale,
437
+ init_gain)
438
+
439
+
440
+ class EmbeddingBlock(_BaseConvBlock):
441
+ def __init__(self, in_features, out_features, bias=True,
442
+ weight_norm_type='none', weight_norm_params=None,
443
+ activation_norm_type='none', activation_norm_params=None,
444
+ nonlinearity='none', inplace_nonlinearity=False,
445
+ apply_noise=False, order='CNA', clamp=None, output_scale=None,
446
+ init_gain=1.0, **_kwargs):
447
+ if bool(_kwargs):
448
+ warnings.warn(f"Unused keyword arguments {_kwargs}")
449
+ super().__init__(in_features, out_features, None, None,
450
+ None, None, None, bias,
451
+ None, weight_norm_type, weight_norm_params,
452
+ activation_norm_type, activation_norm_params,
453
+ nonlinearity, inplace_nonlinearity, apply_noise,
454
+ False, order, 0, clamp, None, output_scale,
455
+ init_gain)
456
+
457
+ def _get_conv_layer(self, in_channels, out_channels, kernel_size, stride,
458
+ padding, dilation, groups, bias, padding_mode,
459
+ input_dim):
460
+ assert input_dim == 0
461
+ return nn.Embedding(in_channels, out_channels)
462
+
463
+
464
+ class Embedding2dBlock(_BaseConvBlock):
465
+ def __init__(self, in_features, out_features, bias=True,
466
+ weight_norm_type='none', weight_norm_params=None,
467
+ activation_norm_type='none', activation_norm_params=None,
468
+ nonlinearity='none', inplace_nonlinearity=False,
469
+ apply_noise=False, order='CNA', clamp=None, output_scale=None,
470
+ init_gain=1.0, **_kwargs):
471
+ if bool(_kwargs):
472
+ warnings.warn(f"Unused keyword arguments {_kwargs}")
473
+ super().__init__(in_features, out_features, None, None,
474
+ None, None, None, bias,
475
+ None, weight_norm_type, weight_norm_params,
476
+ activation_norm_type, activation_norm_params,
477
+ nonlinearity, inplace_nonlinearity, apply_noise,
478
+ False, order, 0, clamp, None, output_scale,
479
+ init_gain)
480
+
481
+ def _get_conv_layer(self, in_channels, out_channels, kernel_size, stride,
482
+ padding, dilation, groups, bias, padding_mode,
483
+ input_dim):
484
+ assert input_dim == 0
485
+ return Embedding2d(in_channels, out_channels)
486
+
487
+
488
+ class Conv1dBlock(_BaseConvBlock):
489
+ r"""A Wrapper class that wraps ``torch.nn.Conv1d`` with normalization and
490
+ nonlinearity.
491
+
492
+ Args:
493
+ in_channels (int): Number of channels in the input tensor.
494
+ out_channels (int): Number of channels in the output tensor.
495
+ kernel_size (int or tuple): Size of the convolving kernel.
496
+ stride (int or float or tuple, optional, default=1):
497
+ Stride of the convolution.
498
+ padding (int or tuple, optional, default=0):
499
+ Zero-padding added to both sides of the input.
500
+ dilation (int or tuple, optional, default=1):
501
+ Spacing between kernel elements.
502
+ groups (int, optional, default=1): Number of blocked connections
503
+ from input channels to output channels.
504
+ bias (bool, optional, default=True):
505
+ If ``True``, adds a learnable bias to the output.
506
+ padding_mode (string, optional, default='zeros'): Type of padding:
507
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
508
+ weight_norm_type (str, optional, default='none'):
509
+ Type of weight normalization.
510
+ ``'none'``, ``'spectral'``, ``'weight'``
511
+ or ``'weight_demod'``.
512
+ weight_norm_params (obj, optional, default=None):
513
+ Parameters of weight normalization.
514
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
515
+ keyword arguments when initializing weight normalization.
516
+ activation_norm_type (str, optional, default='none'):
517
+ Type of activation normalization.
518
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
519
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
520
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
521
+ activation_norm_params (obj, optional, default=None):
522
+ Parameters of activation normalization.
523
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
524
+ keyword arguments when initializing activation normalization.
525
+ nonlinearity (str, optional, default='none'):
526
+ Type of nonlinear activation function.
527
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
528
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
529
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
530
+ set ``inplace=True`` when initializing the nonlinearity layer.
531
+ apply_noise (bool, optional, default=False): If ``True``, adds
532
+ Gaussian noise with learnable magnitude to the convolution output.
533
+ order (str, optional, default='CNA'): Order of operations.
534
+ ``'C'``: convolution,
535
+ ``'N'``: normalization,
536
+ ``'A'``: nonlinear activation.
537
+ For example, a block initialized with ``order='CNA'`` will
538
+ do convolution first, then normalization, then nonlinearity.
539
+ """
540
+
541
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
542
+ padding=0, dilation=1, groups=1, bias=True,
543
+ padding_mode='zeros',
544
+ weight_norm_type='none', weight_norm_params=None,
545
+ activation_norm_type='none', activation_norm_params=None,
546
+ nonlinearity='none', inplace_nonlinearity=False,
547
+ apply_noise=False, blur=False, order='CNA', clamp=None, output_scale=None, init_gain=1.0, **_kwargs):
548
+ super().__init__(in_channels, out_channels, kernel_size, stride,
549
+ padding, dilation, groups, bias, padding_mode,
550
+ weight_norm_type, weight_norm_params,
551
+ activation_norm_type, activation_norm_params,
552
+ nonlinearity, inplace_nonlinearity, apply_noise,
553
+ blur, order, 1, clamp, None, output_scale, init_gain)
554
+
555
+
556
+ class Conv2dBlock(_BaseConvBlock):
557
+ r"""A Wrapper class that wraps ``torch.nn.Conv2d`` with normalization and
558
+ nonlinearity.
559
+
560
+ Args:
561
+ in_channels (int): Number of channels in the input tensor.
562
+ out_channels (int): Number of channels in the output tensor.
563
+ kernel_size (int or tuple): Size of the convolving kernel.
564
+ stride (int or float or tuple, optional, default=1):
565
+ Stride of the convolution.
566
+ padding (int or tuple, optional, default=0):
567
+ Zero-padding added to both sides of the input.
568
+ dilation (int or tuple, optional, default=1):
569
+ Spacing between kernel elements.
570
+ groups (int, optional, default=1): Number of blocked connections
571
+ from input channels to output channels.
572
+ bias (bool, optional, default=True):
573
+ If ``True``, adds a learnable bias to the output.
574
+ padding_mode (string, optional, default='zeros'): Type of padding:
575
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
576
+ weight_norm_type (str, optional, default='none'):
577
+ Type of weight normalization.
578
+ ``'none'``, ``'spectral'``, ``'weight'``
579
+ or ``'weight_demod'``.
580
+ weight_norm_params (obj, optional, default=None):
581
+ Parameters of weight normalization.
582
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
583
+ keyword arguments when initializing weight normalization.
584
+ activation_norm_type (str, optional, default='none'):
585
+ Type of activation normalization.
586
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
587
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
588
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
589
+ activation_norm_params (obj, optional, default=None):
590
+ Parameters of activation normalization.
591
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
592
+ keyword arguments when initializing activation normalization.
593
+ nonlinearity (str, optional, default='none'):
594
+ Type of nonlinear activation function.
595
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
596
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
597
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
598
+ set ``inplace=True`` when initializing the nonlinearity layer.
599
+ apply_noise (bool, optional, default=False): If ``True``, adds
600
+ Gaussian noise with learnable magnitude to the convolution output.
601
+ order (str, optional, default='CNA'): Order of operations.
602
+ ``'C'``: convolution,
603
+ ``'N'``: normalization,
604
+ ``'A'``: nonlinear activation.
605
+ For example, a block initialized with ``order='CNA'`` will
606
+ do convolution first, then normalization, then nonlinearity.
607
+ """
608
+
609
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
610
+ padding=0, dilation=1, groups=1, bias=True,
611
+ padding_mode='zeros',
612
+ weight_norm_type='none', weight_norm_params=None,
613
+ activation_norm_type='none', activation_norm_params=None,
614
+ nonlinearity='none', inplace_nonlinearity=False,
615
+ apply_noise=False, blur=False, order='CNA', clamp=None, blur_kernel=(1, 3, 3, 1),
616
+ output_scale=None, init_gain=1.0):
617
+ super().__init__(in_channels, out_channels, kernel_size, stride,
618
+ padding, dilation, groups, bias, padding_mode,
619
+ weight_norm_type, weight_norm_params,
620
+ activation_norm_type, activation_norm_params,
621
+ nonlinearity, inplace_nonlinearity,
622
+ apply_noise, blur, order, 2, clamp, blur_kernel, output_scale, init_gain)
623
+
624
+
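Illustrative usage sketch for the wrapper above (not from the committed file; the import path is the one non_local.py below already uses, and the hyperparameters are arbitrary):

import torch
from imaginaire.layers import Conv2dBlock

# Spectral-norm conv -> instance norm -> LeakyReLU, in the default 'CNA' order.
block = Conv2dBlock(3, 64, kernel_size=3, padding=1,
                    weight_norm_type='spectral',
                    activation_norm_type='instance',
                    nonlinearity='leakyrelu')
x = torch.randn(2, 3, 32, 32)
y = block(x)   # expected shape: (2, 64, 32, 32)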
625
+ class Conv3dBlock(_BaseConvBlock):
626
+ r"""A Wrapper class that wraps ``torch.nn.Conv3d`` with normalization and
627
+ nonlinearity.
628
+
629
+ Args:
630
+ in_channels (int): Number of channels in the input tensor.
631
+ out_channels (int): Number of channels in the output tensor.
632
+ kernel_size (int or tuple): Size of the convolving kernel.
633
+ stride (int or float or tuple, optional, default=1):
634
+ Stride of the convolution.
635
+ padding (int or tuple, optional, default=0):
636
+ Zero-padding added to both sides of the input.
637
+ dilation (int or tuple, optional, default=1):
638
+ Spacing between kernel elements.
639
+ groups (int, optional, default=1): Number of blocked connections
640
+ from input channels to output channels.
641
+ bias (bool, optional, default=True):
642
+ If ``True``, adds a learnable bias to the output.
643
+ padding_mode (string, optional, default='zeros'): Type of padding:
644
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
645
+ weight_norm_type (str, optional, default='none'):
646
+ Type of weight normalization.
647
+ ``'none'``, ``'spectral'``, ``'weight'``
648
+ or ``'weight_demod'``.
649
+ weight_norm_params (obj, optional, default=None):
650
+ Parameters of weight normalization.
651
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
652
+ keyword arguments when initializing weight normalization.
653
+ activation_norm_type (str, optional, default='none'):
654
+ Type of activation normalization.
655
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
656
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
657
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
658
+ activation_norm_params (obj, optional, default=None):
659
+ Parameters of activation normalization.
660
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
661
+ keyword arguments when initializing activation normalization.
662
+ nonlinearity (str, optional, default='none'):
663
+ Type of nonlinear activation function.
664
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
665
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
666
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
667
+ set ``inplace=True`` when initializing the nonlinearity layer.
668
+ apply_noise (bool, optional, default=False): If ``True``, adds
669
+ Gaussian noise with learnable magnitude to the convolution output.
670
+ order (str, optional, default='CNA'): Order of operations.
671
+ ``'C'``: convolution,
672
+ ``'N'``: normalization,
673
+ ``'A'``: nonlinear activation.
674
+ For example, a block initialized with ``order='CNA'`` will
675
+ do convolution first, then normalization, then nonlinearity.
676
+ """
677
+
678
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
679
+ padding=0, dilation=1, groups=1, bias=True,
680
+ padding_mode='zeros',
681
+ weight_norm_type='none', weight_norm_params=None,
682
+ activation_norm_type='none', activation_norm_params=None,
683
+ nonlinearity='none', inplace_nonlinearity=False,
684
+ apply_noise=False, blur=False, order='CNA', clamp=None, blur_kernel=(1, 3, 3, 1), output_scale=None,
685
+ init_gain=1.0):
686
+ super().__init__(in_channels, out_channels, kernel_size, stride,
687
+ padding, dilation, groups, bias, padding_mode,
688
+ weight_norm_type, weight_norm_params,
689
+ activation_norm_type, activation_norm_params,
690
+ nonlinearity, inplace_nonlinearity,
691
+ apply_noise, blur, order, 3, clamp, blur_kernel, output_scale, init_gain)
692
+
693
+
694
+ class _BaseHyperConvBlock(_BaseConvBlock):
695
+ r"""An abstract wrapper class that wraps a hyper convolutional layer
696
+ with normalization and nonlinearity.
697
+ """
698
+
699
+ def __init__(self, in_channels, out_channels, kernel_size, stride,
700
+ padding, dilation, groups, bias,
701
+ padding_mode,
702
+ weight_norm_type, weight_norm_params,
703
+ activation_norm_type, activation_norm_params,
704
+ nonlinearity, inplace_nonlinearity, apply_noise, blur,
705
+ is_hyper_conv, is_hyper_norm, order, input_dim, clamp=None, blur_kernel=(1, 3, 3, 1),
706
+ output_scale=None, init_gain=1.0):
707
+ self.is_hyper_conv = is_hyper_conv
708
+ if is_hyper_conv:
709
+ weight_norm_type = 'none'
710
+ if is_hyper_norm:
711
+ activation_norm_type = 'hyper_' + activation_norm_type
712
+ super().__init__(in_channels, out_channels, kernel_size, stride,
713
+ padding, dilation, groups, bias, padding_mode,
714
+ weight_norm_type, weight_norm_params,
715
+ activation_norm_type, activation_norm_params,
716
+ nonlinearity, inplace_nonlinearity, apply_noise, blur,
717
+ order, input_dim, clamp, blur_kernel, output_scale, init_gain)
718
+
719
+ def _get_conv_layer(self, in_channels, out_channels, kernel_size, stride,
720
+ padding, dilation, groups, bias, padding_mode,
721
+ input_dim):
722
+ if input_dim == 0:
723
+ raise ValueError('HyperLinearBlock is not supported.')
724
+ else:
725
+ name = 'HyperConv' if self.is_hyper_conv else 'nn.Conv'
726
+ layer_type = eval(name + '%dd' % input_dim)
727
+ layer = layer_type(
728
+ in_channels, out_channels, kernel_size, stride, padding,
729
+ dilation, groups, bias, padding_mode)
730
+ return layer
731
+
732
+
733
+ class HyperConv2dBlock(_BaseHyperConvBlock):
734
+ r"""A Wrapper class that wraps ``HyperConv2d`` with normalization and
735
+ nonlinearity.
736
+
737
+ Args:
738
+ in_channels (int): Number of channels in the input tensor.
739
+ out_channels (int): Number of channels in the output tensor.
740
+ kernel_size (int or tuple): Size of the convolving kernel.
741
+ stride (int or float or tuple, optional, default=1):
742
+ Stride of the convolution.
743
+ padding (int or tuple, optional, default=0):
744
+ Zero-padding added to both sides of the input.
745
+ dilation (int or tuple, optional, default=1):
746
+ Spacing between kernel elements.
747
+ groups (int, optional, default=1): Number of blocked connections
748
+ from input channels to output channels.
749
+ bias (bool, optional, default=True):
750
+ If ``True``, adds a learnable bias to the output.
751
+ padding_mode (string, optional, default='zeros'): Type of padding:
752
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
753
+ weight_norm_type (str, optional, default='none'):
754
+ Type of weight normalization.
755
+ ``'none'``, ``'spectral'``, ``'weight'``
756
+ or ``'weight_demod'``.
757
+ weight_norm_params (obj, optional, default=None):
758
+ Parameters of weight normalization.
759
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
760
+ keyword arguments when initializing weight normalization.
761
+ activation_norm_type (str, optional, default='none'):
762
+ Type of activation normalization.
763
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
764
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
765
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
766
+ activation_norm_params (obj, optional, default=None):
767
+ Parameters of activation normalization.
768
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
769
+ keyword arguments when initializing activation normalization.
770
+ is_hyper_conv (bool, optional, default=False): If ``True``, use
771
+ ``HyperConv2d``, otherwise use ``torch.nn.Conv2d``.
772
+ is_hyper_norm (bool, optional, default=False): If ``True``, use
773
+ hyper normalizations.
774
+ nonlinearity (str, optional, default='none'):
775
+ Type of nonlinear activation function.
776
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
777
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
778
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
779
+ set ``inplace=True`` when initializing the nonlinearity layer.
780
+ apply_noise (bool, optional, default=False): If ``True``, adds
781
+ Gaussian noise with learnable magnitude to the convolution output.
782
+ order (str, optional, default='CNA'): Order of operations.
783
+ ``'C'``: convolution,
784
+ ``'N'``: normalization,
785
+ ``'A'``: nonlinear activation.
786
+ For example, a block initialized with ``order='CNA'`` will
787
+ do convolution first, then normalization, then nonlinearity.
788
+ """
789
+
790
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
791
+ padding=0, dilation=1, groups=1, bias=True,
792
+ padding_mode='zeros',
793
+ weight_norm_type='none', weight_norm_params=None,
794
+ activation_norm_type='none', activation_norm_params=None,
795
+ is_hyper_conv=False, is_hyper_norm=False,
796
+ nonlinearity='none', inplace_nonlinearity=False,
797
+ apply_noise=False, blur=False, order='CNA', clamp=None):
798
+ super().__init__(in_channels, out_channels, kernel_size, stride,
799
+ padding, dilation, groups, bias, padding_mode,
800
+ weight_norm_type, weight_norm_params,
801
+ activation_norm_type, activation_norm_params,
802
+ nonlinearity, inplace_nonlinearity, apply_noise, blur,
803
+ is_hyper_conv, is_hyper_norm, order, 2, clamp)
804
+
805
+
806
+ class HyperConv2d(nn.Module):
807
+ r"""Hyper Conv2d initialization.
808
+
809
+ Args:
810
+ in_channels (int): Dummy parameter.
811
+ out_channels (int): Dummy parameter.
812
+ kernel_size (int or tuple): Dummy parameter.
813
+ stride (int or float or tuple, optional, default=1):
814
+ Stride of the convolution. Default: 1
815
+ padding (int or tuple, optional, default=0):
816
+ Zero-padding added to both sides of the input.
817
+ padding_mode (string, optional, default='zeros'):
818
+ ``'zeros'``, ``'reflect'``, ``'replicate'``
819
+ or ``'circular'``.
820
+ dilation (int or tuple, optional, default=1):
821
+ Spacing between kernel elements.
822
+ groups (int, optional, default=1): Number of blocked connections
823
+ from input channels to output channels.
824
+ bias (bool, optional, default=True): If ``True``,
825
+ adds a learnable bias to the output.
826
+ """
827
+
828
+ def __init__(self, in_channels=0, out_channels=0, kernel_size=3,
829
+ stride=1, padding=1, dilation=1, groups=1, bias=True,
830
+ padding_mode='zeros'):
831
+ super().__init__()
832
+ self.stride = stride
833
+ self.padding = padding
834
+ self.dilation = dilation
835
+ self.groups = groups
836
+ self.use_bias = bias
837
+ self.padding_mode = padding_mode
838
+ self.conditional = True
839
+
840
+ def forward(self, x, *args, conv_weights=(None, None), **kwargs):
841
+ r"""Hyper Conv2d forward. Convolve x using the provided weight and bias.
842
+
843
+ Args:
844
+ x (N x C x H x W tensor): Input tensor.
845
+ conv_weights (N x C2 x C1 x k x k tensor or list of tensors):
846
+ Convolution weights or [weight, bias].
847
+ Returns:
848
+ y (N x C2 x H x W tensor): Output tensor.
849
+ """
850
+ if conv_weights is None:
851
+ conv_weight, conv_bias = None, None
852
+ elif isinstance(conv_weights, torch.Tensor):
853
+ conv_weight, conv_bias = conv_weights, None
854
+ else:
855
+ conv_weight, conv_bias = conv_weights
856
+
857
+ if conv_weight is None:
858
+ return x
859
+ if conv_bias is None:
860
+ if self.use_bias:
861
+ raise ValueError('bias not provided but set to true during '
862
+ 'initialization')
863
+ conv_bias = [None] * x.size(0)
864
+ if self.padding_mode != 'zeros':
865
+ x = F.pad(x, [self.padding] * 4, mode=self.padding_mode)
866
+ padding = 0
867
+ else:
868
+ padding = self.padding
869
+
870
+ y = None
871
+ # noinspection PyArgumentList
872
+ for i in range(x.size(0)):
873
+ if self.stride >= 1:
874
+ yi = F.conv2d(x[i: i + 1],
875
+ weight=conv_weight[i], bias=conv_bias[i],
876
+ stride=self.stride, padding=padding,
877
+ dilation=self.dilation, groups=self.groups)
878
+ else:
879
+ yi = F.conv_transpose2d(x[i: i + 1], weight=conv_weight[i],
880
+ bias=conv_bias[i], padding=self.padding,
881
+ stride=int(1 / self.stride),
882
+ dilation=self.dilation,
883
+ output_padding=self.padding,
884
+ groups=self.groups)
885
+ y = torch.cat([y, yi]) if y is not None else yi
886
+ return y
887
+
888
+
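Illustrative sketch of the per-sample convolution above (not from the committed file; shapes follow the docstring, N x C2 x C1 x k x k, and the weights would normally be predicted by a hypernetwork):

import torch
from imaginaire.layers.conv import HyperConv2d

layer = HyperConv2d(kernel_size=3, stride=1, padding=1, bias=False)
x = torch.randn(4, 8, 16, 16)            # N x C1 x H x W
weights = torch.randn(4, 32, 8, 3, 3)    # N x C2 x C1 x k x k, one kernel per sample
y = layer(x, conv_weights=weights)       # N x C2 x H x W -> (4, 32, 16, 16)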
889
+ class _BasePartialConvBlock(_BaseConvBlock):
890
+ r"""An abstract wrapper class that wraps a partial convolutional layer
891
+ with normalization and nonlinearity.
892
+ """
893
+
894
+ def __init__(self, in_channels, out_channels, kernel_size, stride,
895
+ padding, dilation, groups, bias, padding_mode,
896
+ weight_norm_type, weight_norm_params,
897
+ activation_norm_type, activation_norm_params,
898
+ nonlinearity, inplace_nonlinearity,
899
+ multi_channel, return_mask,
900
+ apply_noise, order, input_dim, clamp=None, blur_kernel=(1, 3, 3, 1), output_scale=None, init_gain=1.0):
901
+ self.multi_channel = multi_channel
902
+ self.return_mask = return_mask
903
+ self.partial_conv = True
904
+ super().__init__(in_channels, out_channels, kernel_size, stride,
905
+ padding, dilation, groups, bias, padding_mode,
906
+ weight_norm_type, weight_norm_params,
907
+ activation_norm_type, activation_norm_params,
908
+ nonlinearity, inplace_nonlinearity, apply_noise,
909
+ False, order, input_dim, clamp, blur_kernel, output_scale, init_gain)
910
+
911
+ def _get_conv_layer(self, in_channels, out_channels, kernel_size, stride,
912
+ padding, dilation, groups, bias, padding_mode,
913
+ input_dim):
914
+ if input_dim == 2:
915
+ layer_type = PartialConv2d
916
+ elif input_dim == 3:
917
+ layer_type = PartialConv3d
918
+ else:
919
+ raise ValueError('Partial conv only supports 2D and 3D conv now.')
920
+ layer = layer_type(
921
+ in_channels, out_channels, kernel_size, stride, padding,
922
+ dilation, groups, bias, padding_mode,
923
+ multi_channel=self.multi_channel, return_mask=self.return_mask)
924
+ return layer
925
+
926
+ def forward(self, x, *cond_inputs, mask_in=None, **kw_cond_inputs):
927
+ r"""
928
+
929
+ Args:
930
+ x (tensor): Input tensor.
931
+ cond_inputs (list of tensors) : Conditional input tensors.
932
+ mask_in (tensor, optional, default=``None``) If not ``None``,
933
+ it masks the valid input region.
934
+ kw_cond_inputs (dict) : Keyword conditional inputs.
935
+ Returns:
936
+ (tuple):
937
+ - x (tensor): Output tensor.
938
+ - mask_out (tensor, optional): Masks the valid output region.
939
+ """
940
+ mask_out = None
941
+ for layer in self.layers.values():
942
+ if getattr(layer, 'conditional', False):
943
+ x = layer(x, *cond_inputs, **kw_cond_inputs)
944
+ elif getattr(layer, 'partial_conv', False):
945
+ x = layer(x, mask_in=mask_in, **kw_cond_inputs)
946
+ if type(x) == tuple:
947
+ x, mask_out = x
948
+ else:
949
+ x = layer(x)
950
+
951
+ if mask_out is not None:
952
+ return x, mask_out
953
+ return x
954
+
955
+
956
+ class PartialConv2dBlock(_BasePartialConvBlock):
957
+ r"""A Wrapper class that wraps ``PartialConv2d`` with normalization and
958
+ nonlinearity.
959
+
960
+ Args:
961
+ in_channels (int): Number of channels in the input tensor.
962
+ out_channels (int): Number of channels in the output tensor.
963
+ kernel_size (int or tuple): Size of the convolving kernel.
964
+ stride (int or float or tuple, optional, default=1):
965
+ Stride of the convolution.
966
+ padding (int or tuple, optional, default=0):
967
+ Zero-padding added to both sides of the input.
968
+ dilation (int or tuple, optional, default=1):
969
+ Spacing between kernel elements.
970
+ groups (int, optional, default=1): Number of blocked connections
971
+ from input channels to output channels.
972
+ bias (bool, optional, default=True):
973
+ If ``True``, adds a learnable bias to the output.
974
+ padding_mode (string, optional, default='zeros'): Type of padding:
975
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
976
+ weight_norm_type (str, optional, default='none'):
977
+ Type of weight normalization.
978
+ ``'none'``, ``'spectral'``, ``'weight'``
979
+ or ``'weight_demod'``.
980
+ weight_norm_params (obj, optional, default=None):
981
+ Parameters of weight normalization.
982
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
983
+ keyword arguments when initializing weight normalization.
984
+ activation_norm_type (str, optional, default='none'):
985
+ Type of activation normalization.
986
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
987
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
988
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
989
+ activation_norm_params (obj, optional, default=None):
990
+ Parameters of activation normalization.
991
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
992
+ keyword arguments when initializing activation normalization.
993
+ nonlinearity (str, optional, default='none'):
994
+ Type of nonlinear activation function.
995
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
996
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
997
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
998
+ set ``inplace=True`` when initializing the nonlinearity layer.
999
+ apply_noise (bool, optional, default=False): If ``True``, adds
1000
+ Gaussian noise with learnable magnitude to the convolution output.
1001
+ order (str, optional, default='CNA'): Order of operations.
1002
+ ``'C'``: convolution,
1003
+ ``'N'``: normalization,
1004
+ ``'A'``: nonlinear activation.
1005
+ For example, a block initialized with ``order='CNA'`` will
1006
+ do convolution first, then normalization, then nonlinearity.
1007
+ multi_channel (bool, optional, default=False): If ``True``, use
1008
+ different masks for different channels.
1009
+ return_mask (bool, optional, default=True): If ``True``, the
1010
+ forward call also returns a new mask.
1011
+ """
1012
+
1013
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
1014
+ padding=0, dilation=1, groups=1, bias=True,
1015
+ padding_mode='zeros',
1016
+ weight_norm_type='none', weight_norm_params=None,
1017
+ activation_norm_type='none', activation_norm_params=None,
1018
+ nonlinearity='none', inplace_nonlinearity=False,
1019
+ multi_channel=False, return_mask=True,
1020
+ apply_noise=False, order='CNA', clamp=None):
1021
+ super().__init__(in_channels, out_channels, kernel_size, stride,
1022
+ padding, dilation, groups, bias, padding_mode,
1023
+ weight_norm_type, weight_norm_params,
1024
+ activation_norm_type, activation_norm_params,
1025
+ nonlinearity, inplace_nonlinearity,
1026
+ multi_channel, return_mask, apply_noise, order, 2,
1027
+ clamp)
1028
+
1029
+
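Illustrative sketch of the mask-aware forward pass defined by _BasePartialConvBlock above (not from the committed file; sizes are arbitrary):

import torch
from imaginaire.layers.conv import PartialConv2dBlock

block = PartialConv2dBlock(3, 16, kernel_size=3, padding=1, nonlinearity='relu')
x = torch.randn(1, 3, 64, 64)
mask = torch.ones(1, 1, 64, 64)
mask[:, :, 16:48, 16:48] = 0             # a square hole of invalid pixels
y, new_mask = block(x, mask_in=mask)     # y: (1, 16, 64, 64)
# new_mask marks every location whose 3x3 window saw at least one valid pixel,
# so the hole shrinks by one pixel per side after this block.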
1030
+ class PartialConv3dBlock(_BasePartialConvBlock):
1031
+ r"""A Wrapper class that wraps ``PartialConv3d`` with normalization and
1032
+ nonlinearity.
1033
+
1034
+ Args:
1035
+ in_channels (int): Number of channels in the input tensor.
1036
+ out_channels (int): Number of channels in the output tensor.
1037
+ kernel_size (int or tuple): Size of the convolving kernel.
1038
+ stride (int or float or tuple, optional, default=1):
1039
+ Stride of the convolution.
1040
+ padding (int or tuple, optional, default=0):
1041
+ Zero-padding added to both sides of the input.
1042
+ dilation (int or tuple, optional, default=1):
1043
+ Spacing between kernel elements.
1044
+ groups (int, optional, default=1): Number of blocked connections
1045
+ from input channels to output channels.
1046
+ bias (bool, optional, default=True):
1047
+ If ``True``, adds a learnable bias to the output.
1048
+ padding_mode (string, optional, default='zeros'): Type of padding:
1049
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
1050
+ weight_norm_type (str, optional, default='none'):
1051
+ Type of weight normalization.
1052
+ ``'none'``, ``'spectral'``, ``'weight'``
1053
+ or ``'weight_demod'``.
1054
+ weight_norm_params (obj, optional, default=None):
1055
+ Parameters of weight normalization.
1056
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
1057
+ keyword arguments when initializing weight normalization.
1058
+ activation_norm_type (str, optional, default='none'):
1059
+ Type of activation normalization.
1060
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
1061
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
1062
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
1063
+ activation_norm_params (obj, optional, default=None):
1064
+ Parameters of activation normalization.
1065
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
1066
+ keyword arguments when initializing activation normalization.
1067
+ nonlinearity (str, optional, default='none'):
1068
+ Type of nonlinear activation function.
1069
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
1070
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
1071
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
1072
+ set ``inplace=True`` when initializing the nonlinearity layer.
1073
+ apply_noise (bool, optional, default=False): If ``True``, adds
1074
+ Gaussian noise with learnable magnitude to the convolution output.
1075
+ order (str, optional, default='CNA'): Order of operations.
1076
+ ``'C'``: convolution,
1077
+ ``'N'``: normalization,
1078
+ ``'A'``: nonlinear activation.
1079
+ For example, a block initialized with ``order='CNA'`` will
1080
+ do convolution first, then normalization, then nonlinearity.
1081
+ multi_channel (bool, optional, default=False): If ``True``, use
1082
+ different masks for different channels.
1083
+ return_mask (bool, optional, default=True): If ``True``, the
1084
+ forward call also returns a new mask.
1085
+ """
1086
+
1087
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
1088
+ padding=0, dilation=1, groups=1, bias=True,
1089
+ padding_mode='zeros',
1090
+ weight_norm_type='none', weight_norm_params=None,
1091
+ activation_norm_type='none', activation_norm_params=None,
1092
+ nonlinearity='none', inplace_nonlinearity=False,
1093
+ multi_channel=False, return_mask=True,
1094
+ apply_noise=False, order='CNA', clamp=None):
1095
+ super().__init__(in_channels, out_channels, kernel_size, stride,
1096
+ padding, dilation, groups, bias, padding_mode,
1097
+ weight_norm_type, weight_norm_params,
1098
+ activation_norm_type, activation_norm_params,
1099
+ nonlinearity, inplace_nonlinearity,
1100
+ multi_channel, return_mask, apply_noise, order, 3,
1101
+ clamp)
1102
+
1103
+
1104
+ class _MultiOutBaseConvBlock(_BaseConvBlock):
1105
+ r"""An abstract wrapper class that wraps a hyper convolutional layer with
1106
+ normalization and nonlinearity. It can return multiple outputs, if some
1107
+ layers in the block return more than one output.
1108
+ """
1109
+
1110
+ def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode,
1111
+ weight_norm_type, weight_norm_params, activation_norm_type, activation_norm_params, nonlinearity,
1112
+ inplace_nonlinearity, apply_noise, blur, order, input_dim, clamp=None, blur_kernel=(1, 3, 3, 1),
1113
+ output_scale=None, init_gain=1.0):
1114
+ super().__init__(in_channels, out_channels, kernel_size, stride,
1115
+ padding, dilation, groups, bias, padding_mode,
1116
+ weight_norm_type, weight_norm_params,
1117
+ activation_norm_type, activation_norm_params,
1118
+ nonlinearity, inplace_nonlinearity,
1119
+ apply_noise, blur, order, input_dim, clamp, blur_kernel, output_scale, init_gain)
1120
+ self.multiple_outputs = True
1121
+
1122
+ def forward(self, x, *cond_inputs, **kw_cond_inputs):
1123
+ r"""
1124
+
1125
+ Args:
1126
+ x (tensor): Input tensor.
1127
+ cond_inputs (list of tensors) : Conditional input tensors.
1128
+ kw_cond_inputs (dict) : Keyword conditional inputs.
1129
+ Returns:
1130
+ (tuple):
1131
+ - x (tensor): Main output tensor.
1132
+ - other_outputs (list of tensors): Other output tensors.
1133
+ """
1134
+ other_outputs = []
1135
+ for layer in self.layers.values():
1136
+ if getattr(layer, 'conditional', False):
1137
+ x = layer(x, *cond_inputs, **kw_cond_inputs)
1138
+ if getattr(layer, 'multiple_outputs', False):
1139
+ x, other_output = layer(x)
1140
+ other_outputs.append(other_output)
1141
+ else:
1142
+ x = layer(x)
1143
+ return (x, *other_outputs)
1144
+
1145
+
1146
+ class MultiOutConv2dBlock(_MultiOutBaseConvBlock):
1147
+ r"""A Wrapper class that wraps ``torch.nn.Conv2d`` with normalization and
1148
+ nonlinearity. It can return multiple outputs, if some layers in the block
1149
+ return more than one output.
1150
+
1151
+ Args:
1152
+ in_channels (int): Number of channels in the input tensor.
1153
+ out_channels (int): Number of channels in the output tensor.
1154
+ kernel_size (int or tuple): Size of the convolving kernel.
1155
+ stride (int or float or tuple, optional, default=1):
1156
+ Stride of the convolution.
1157
+ padding (int or tuple, optional, default=0):
1158
+ Zero-padding added to both sides of the input.
1159
+ dilation (int or tuple, optional, default=1):
1160
+ Spacing between kernel elements.
1161
+ groups (int, optional, default=1): Number of blocked connections
1162
+ from input channels to output channels.
1163
+ bias (bool, optional, default=True):
1164
+ If ``True``, adds a learnable bias to the output.
1165
+ padding_mode (string, optional, default='zeros'): Type of padding:
1166
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
1167
+ weight_norm_type (str, optional, default='none'):
1168
+ Type of weight normalization.
1169
+ ``'none'``, ``'spectral'``, ``'weight'``
1170
+ or ``'weight_demod'``.
1171
+ weight_norm_params (obj, optional, default=None):
1172
+ Parameters of weight normalization.
1173
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
1174
+ keyword arguments when initializing weight normalization.
1175
+ activation_norm_type (str, optional, default='none'):
1176
+ Type of activation normalization.
1177
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
1178
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
1179
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
1180
+ activation_norm_params (obj, optional, default=None):
1181
+ Parameters of activation normalization.
1182
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
1183
+ keyword arguments when initializing activation normalization.
1184
+ nonlinearity (str, optional, default='none'):
1185
+ Type of nonlinear activation function.
1186
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
1187
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
1188
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
1189
+ set ``inplace=True`` when initializing the nonlinearity layer.
1190
+ apply_noise (bool, optional, default=False): If ``True``, adds
1191
+ Gaussian noise with learnable magnitude to the convolution output.
1192
+ order (str, optional, default='CNA'): Order of operations.
1193
+ ``'C'``: convolution,
1194
+ ``'N'``: normalization,
1195
+ ``'A'``: nonlinear activation.
1196
+ For example, a block initialized with ``order='CNA'`` will
1197
+ do convolution first, then normalization, then nonlinearity.
1198
+ """
1199
+
1200
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
1201
+ padding=0, dilation=1, groups=1, bias=True,
1202
+ padding_mode='zeros',
1203
+ weight_norm_type='none', weight_norm_params=None,
1204
+ activation_norm_type='none', activation_norm_params=None,
1205
+ nonlinearity='none', inplace_nonlinearity=False,
1206
+ apply_noise=False, blur=False, order='CNA', clamp=None):
1207
+ super().__init__(in_channels, out_channels, kernel_size, stride,
1208
+ padding, dilation, groups, bias, padding_mode,
1209
+ weight_norm_type, weight_norm_params,
1210
+ activation_norm_type, activation_norm_params,
1211
+ nonlinearity, inplace_nonlinearity,
1212
+ apply_noise, blur, order, 2, clamp)
1213
+
1214
+
1215
+ ###############################################################################
1216
+ # BSD 3-Clause License
1217
+ #
1218
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
1219
+ #
1220
+ # Author & Contact: Guilin Liu (guilinl@nvidia.com)
1221
+ ###############################################################################
1222
+ class PartialConv2d(nn.Conv2d):
1223
+ r"""Partial 2D convolution in
1224
+ "Image inpainting for irregular holes using partial convolutions."
1225
+ Liu et al., ECCV 2018
1226
+ """
1227
+
1228
+ def __init__(self, *args, multi_channel=False, return_mask=True, **kwargs):
1229
+ # whether the mask is multi-channel or not
1230
+ self.multi_channel = multi_channel
1231
+ self.return_mask = return_mask
1232
+ super(PartialConv2d, self).__init__(*args, **kwargs)
1233
+
1234
+ if self.multi_channel:
1235
+ self.weight_maskUpdater = torch.ones(self.out_channels,
1236
+ self.in_channels,
1237
+ self.kernel_size[0],
1238
+ self.kernel_size[1])
1239
+ else:
1240
+ self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0],
1241
+ self.kernel_size[1])
1242
+
1243
+ shape = self.weight_maskUpdater.shape
1244
+ self.slide_winsize = shape[1] * shape[2] * shape[3]
1245
+
1246
+ self.last_size = (None, None, None, None)
1247
+ self.update_mask = None
1248
+ self.mask_ratio = None
1249
+ self.partial_conv = True
1250
+
1251
+ def forward(self, x, mask_in=None):
1252
+ r"""
1253
+
1254
+ Args:
1255
+ x (tensor): Input tensor.
1256
+ mask_in (tensor, optional, default=``None``) If not ``None``,
1257
+ it masks the valid input region.
1258
+ """
1259
+ assert len(x.shape) == 4
1260
+ if mask_in is not None or self.last_size != tuple(x.shape):
1261
+ self.last_size = tuple(x.shape)
1262
+
1263
+ with torch.no_grad():
1264
+ if self.weight_maskUpdater.type() != x.type():
1265
+ self.weight_maskUpdater = self.weight_maskUpdater.to(x)
1266
+
1267
+ if mask_in is None:
1268
+ # If mask is not provided, create a mask.
1269
+ if self.multi_channel:
1270
+ mask = torch.ones(x.data.shape[0],
1271
+ x.data.shape[1],
1272
+ x.data.shape[2],
1273
+ x.data.shape[3]).to(x)
1274
+ else:
1275
+ mask = torch.ones(1, 1, x.data.shape[2],
1276
+ x.data.shape[3]).to(x)
1277
+ else:
1278
+ mask = mask_in
1279
+
1280
+ self.update_mask = F.conv2d(mask, self.weight_maskUpdater,
1281
+ bias=None, stride=self.stride,
1282
+ padding=self.padding,
1283
+ dilation=self.dilation, groups=1)
1284
+
1285
+ # For mixed-precision training, eps is increased from 1e-8 to 1e-6.
1286
+ eps = 1e-6
1287
+ self.mask_ratio = self.slide_winsize / (self.update_mask + eps)
1288
+ self.update_mask = torch.clamp(self.update_mask, 0, 1)
1289
+ self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask)
1290
+
1291
+ raw_out = super(PartialConv2d, self).forward(
1292
+ torch.mul(x, mask) if mask_in is not None else x)
1293
+
1294
+ if self.bias is not None:
1295
+ bias_view = self.bias.view(1, self.out_channels, 1, 1)
1296
+ output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view
1297
+ output = torch.mul(output, self.update_mask)
1298
+ else:
1299
+ output = torch.mul(raw_out, self.mask_ratio)
1300
+
1301
+ if self.return_mask:
1302
+ return output, self.update_mask
1303
+ else:
1304
+ return output
1305
+
1306
+
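A tiny illustrative sketch of the renormalization performed above (not from the committed file): each output is rescaled by slide_winsize divided by the number of valid pixels under the kernel window.

import torch
from imaginaire.layers.conv import PartialConv2d

conv = PartialConv2d(1, 1, kernel_size=3, padding=1, bias=False)
x = torch.ones(1, 1, 5, 5)
mask = torch.ones(1, 1, 5, 5)
mask[..., 2, 2] = 0                      # one invalid pixel
y, updated = conv(x, mask_in=mask)       # windows covering the hole are rescaled by ~9/8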
1307
+ class PartialConv3d(nn.Conv3d):
1308
+ r"""Partial 3D convolution in
1309
+ "Image inpainting for irregular holes using partial convolutions."
1310
+ Liu et al., ECCV 2018
1311
+ """
1312
+
1313
+ def __init__(self, *args, multi_channel=False, return_mask=True, **kwargs):
1314
+ # whether the mask is multi-channel or not
1315
+ self.multi_channel = multi_channel
1316
+ self.return_mask = return_mask
1317
+ super(PartialConv3d, self).__init__(*args, **kwargs)
1318
+
1319
+ if self.multi_channel:
1320
+ self.weight_maskUpdater = \
1321
+ torch.ones(self.out_channels, self.in_channels,
1322
+ self.kernel_size[0], self.kernel_size[1],
1323
+ self.kernel_size[2])
1324
+ else:
1325
+ self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0],
1326
+ self.kernel_size[1],
1327
+ self.kernel_size[2])
1328
+ self.weight_maskUpdater = self.weight_maskUpdater.to('cuda')
1329
+
1330
+ shape = self.weight_maskUpdater.shape
1331
+ self.slide_winsize = shape[1] * shape[2] * shape[3] * shape[4]
1332
+ self.partial_conv = True
1333
+
1334
+ def forward(self, x, mask_in=None):
1335
+ r"""
1336
+
1337
+ Args:
1338
+ x (tensor): Input tensor.
1339
+ mask_in (tensor, optional, default=``None``) If not ``None``, it
1340
+ masks the valid input region.
1341
+ """
1342
+ assert len(x.shape) == 5
1343
+
1344
+ with torch.no_grad():
1345
+ mask = mask_in
1346
+ update_mask = F.conv3d(mask, self.weight_maskUpdater, bias=None,
1347
+ stride=self.stride, padding=self.padding,
1348
+ dilation=self.dilation, groups=1)
1349
+
1350
+ mask_ratio = self.slide_winsize / (update_mask + 1e-8)
1351
+ update_mask = torch.clamp(update_mask, 0, 1)
1352
+ mask_ratio = torch.mul(mask_ratio, update_mask)
1353
+
1354
+ raw_out = super(PartialConv3d, self).forward(torch.mul(x, mask_in))
1355
+
1356
+ if self.bias is not None:
1357
+ bias_view = self.bias.view(1, self.out_channels, 1, 1, 1)
1358
+ output = torch.mul(raw_out - bias_view, mask_ratio) + bias_view
1359
+ if mask_in is not None:
1360
+ output = torch.mul(output, update_mask)
1361
+ else:
1362
+ output = torch.mul(raw_out, mask_ratio)
1363
+
1364
+ if self.return_mask:
1365
+ return output, update_mask
1366
+ else:
1367
+ return output
1368
+
1369
+
1370
+ class Embedding2d(nn.Embedding):
1371
+ def __init__(self, in_channels, out_channels):
1372
+ super().__init__(in_channels, out_channels)
1373
+
1374
+ def forward(self, x):
1375
+ return F.embedding(
1376
+ x.squeeze(1).long(), self.weight, self.padding_idx, self.max_norm,
1377
+ self.norm_type, self.scale_grad_by_freq, self.sparse).permute(0, 3, 1, 2).contiguous()
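Embedding2d above turns an (N, 1, H, W) integer label map into per-pixel embedding channels. An illustrative sketch (not from the committed file):

import torch
from imaginaire.layers.conv import Embedding2d

embed = Embedding2d(in_channels=10, out_channels=4)   # 10 classes, 4-dim embedding
labels = torch.randint(0, 10, (2, 1, 8, 8))           # N x 1 x H x W label map
feat = embed(labels)                                   # N x 4 x H x W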
imaginaire/layers/misc.py ADDED
@@ -0,0 +1,61 @@
1
+ # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, check out LICENSE.md
5
+ import torch
6
+ from torch import nn
7
+
8
+
9
+ class ApplyNoise(nn.Module):
10
+ r"""Add Gaussian noise to the input tensor."""
11
+
12
+ def __init__(self):
13
+ super().__init__()
14
+ # scale of the noise
15
+ self.scale = nn.Parameter(torch.zeros(1))
16
+ self.conditional = True
17
+
18
+ def forward(self, x, *_args, noise=None, **_kwargs):
19
+ r"""
20
+
21
+ Args:
22
+ x (tensor): Input tensor.
23
+ noise (tensor, optional, default=``None``) : Noise tensor to be
24
+ added to the input.
25
+ """
26
+ if noise is None:
27
+ sz = x.size()
28
+ noise = x.new_empty(sz[0], 1, *sz[2:]).normal_()
29
+
30
+ return x + self.scale * noise
31
+
32
+
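An illustrative sketch of the layer above (not from the committed file): it is an exact identity at initialization because the noise scale starts at zero.

import torch
from imaginaire.layers.misc import ApplyNoise

noise_layer = ApplyNoise()
x = torch.randn(2, 16, 8, 8)
assert torch.equal(noise_layer(x), x)    # scale == 0, so no noise is added yet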
33
+ class PartialSequential(nn.Sequential):
34
+ r"""Sequential block for partial convolutions."""
35
+ def __init__(self, *modules):
36
+ super(PartialSequential, self).__init__(*modules)
37
+
38
+ def forward(self, x):
39
+ r"""
40
+
41
+ Args:
42
+ x (tensor): Input tensor.
43
+ """
44
+ act = x[:, :-1]
45
+ mask = x[:, -1].unsqueeze(1)
46
+ for module in self:
47
+ act, mask = module(act, mask_in=mask)
48
+ return act
49
+
50
+
51
+ class ConstantInput(nn.Module):
52
+ def __init__(self, channel, size=4):
53
+ super().__init__()
54
+ if isinstance(size, int):
55
+ h, w = size, size
56
+ else:
57
+ h, w = size
58
+ self.input = nn.Parameter(torch.randn(1, channel, h, w))
59
+
60
+ def forward(self):
61
+ return self.input
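An illustrative sketch of ConstantInput above (not from the committed file): a learned, input-independent starting tensor, reminiscent of the constant input used in StyleGAN-style generators; forward takes no arguments.

import torch
from imaginaire.layers.misc import ConstantInput

const = ConstantInput(channel=512, size=4)
start = const()                          # shape: (1, 512, 4, 4)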
imaginaire/layers/non_local.py ADDED
@@ -0,0 +1,88 @@
1
+ # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, check out LICENSE.md
5
+ from functools import partial
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ from imaginaire.layers import Conv2dBlock
11
+
12
+
13
+ class NonLocal2dBlock(nn.Module):
14
+ r"""Self attention Layer
15
+
16
+ Args:
17
+ in_channels (int): Number of channels in the input tensor.
18
+ scale (bool, optional, default=True): If ``True``, scale the
19
+ output by a learnable parameter.
20
+ clamp (bool, optional, default=``False``): If ``True``, clamp the
21
+ scaling parameter to (-1, 1).
22
+ weight_norm_type (str, optional, default='none'):
23
+ Type of weight normalization.
24
+ ``'none'``, ``'spectral'``, ``'weight'``.
25
+ weight_norm_params (obj, optional, default=None):
26
+ Parameters of weight normalization.
27
+ If not ``None``, weight_norm_params.__dict__ will be used as
28
+ keyword arguments when initializing weight normalization.
29
+ bias (bool, optional, default=True): If ``True``, adds bias in the
30
+ convolutional blocks.
31
+ """
32
+
33
+ def __init__(self,
34
+ in_channels,
35
+ scale=True,
36
+ clamp=False,
37
+ weight_norm_type='none',
38
+ weight_norm_params=None,
39
+ bias=True):
40
+ super(NonLocal2dBlock, self).__init__()
41
+ self.clamp = clamp
42
+ self.gamma = nn.Parameter(torch.zeros(1)) if scale else 1.0
43
+ self.in_channels = in_channels
44
+ base_conv2d_block = partial(Conv2dBlock,
45
+ kernel_size=1,
46
+ stride=1,
47
+ padding=0,
48
+ weight_norm_type=weight_norm_type,
49
+ weight_norm_params=weight_norm_params,
50
+ bias=bias)
51
+ self.theta = base_conv2d_block(in_channels, in_channels // 8)
52
+ self.phi = base_conv2d_block(in_channels, in_channels // 8)
53
+ self.g = base_conv2d_block(in_channels, in_channels // 2)
54
+ self.out_conv = base_conv2d_block(in_channels // 2, in_channels)
55
+ self.softmax = nn.Softmax(dim=-1)
56
+ self.max_pool = nn.MaxPool2d(2)
57
+
58
+ def forward(self, x):
59
+ r"""
60
+
61
+ Args:
62
+ x (tensor) : input feature maps (B X C X W X H)
63
+ Returns:
64
+ out (tensor) : self attention value + input feature, with the
+ same shape as the input (the attention map is not returned).
67
+ """
68
+ n, c, h, w = x.size()
69
+ theta = self.theta(x).view(n, -1, h * w).permute(0, 2, 1).contiguous()
70
+
71
+ phi = self.phi(x)
72
+ phi = self.max_pool(phi).view(n, -1, h * w // 4)
73
+
74
+ energy = torch.bmm(theta, phi)
75
+ attention = self.softmax(energy)
76
+
77
+ g = self.g(x)
78
+ g = self.max_pool(g).view(n, -1, h * w // 4)
79
+
80
+ out = torch.bmm(g, attention.permute(0, 2, 1).contiguous())
81
+ out = out.view(n, c // 2, h, w)
82
+ out = self.out_conv(out)
83
+
84
+ if self.clamp:
85
+ out = self.gamma.clamp(-1, 1) * out + x
86
+ else:
87
+ out = self.gamma * out + x
88
+ return out
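An illustrative sketch of the block above (not from the committed file): it preserves the input shape, and because gamma is initialized to zero it starts out as an identity mapping.

import torch
from imaginaire.layers.non_local import NonLocal2dBlock

attn = NonLocal2dBlock(in_channels=64)
x = torch.randn(1, 64, 32, 32)
y = attn(x)        # shape: (1, 64, 32, 32); equals x at initialization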
imaginaire/layers/nonlinearity.py ADDED
@@ -0,0 +1,65 @@
1
+ # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, check out LICENSE.md
5
+ import torch
6
+ from torch import nn
7
+ import torch.nn.functional as F
8
+
9
+ from imaginaire.third_party.bias_act.bias_act import FusedNonlinearity
10
+
11
+
12
+ class ScaledLeakyReLU(nn.Module):
13
+ def __init__(self, negative_slope=0.2, scale=2 ** 0.5, inplace=False):
14
+ super().__init__()
15
+
16
+ self.negative_slope = negative_slope
17
+ self.scale = scale
18
+ self.inplace = inplace
19
+
20
+ def forward(self, x):
21
+ return F.leaky_relu(x, self.negative_slope, inplace=self.inplace) * self.scale
22
+ # return _fused_scaled_leakyrelu(x, self.negative_slope, self.inplace, self.scale)
23
+
24
+
25
+ # @torch.jit.script
26
+ # def _fused_scaled_leakyrelu(x: torch.Tensor, negative_slope: float, inplace: bool, scale: float):
27
+ # return F.leaky_relu(x, negative_slope, inplace=inplace) * scale
28
+
29
+
30
+ def get_nonlinearity_layer(nonlinearity_type, inplace, **kwargs):
31
+ r"""Return a nonlinearity layer.
32
+
33
+ Args:
34
+ nonlinearity_type (str):
35
+ Type of nonlinear activation function.
36
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
37
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
38
+ inplace (bool): If ``True``, set ``inplace=True`` when initializing
39
+ the nonlinearity layer.
40
+ """
41
+ if nonlinearity_type.startswith('fused'):
42
+ nonlinearity = FusedNonlinearity(nonlinearity=nonlinearity_type[6:], **kwargs)
43
+ elif nonlinearity_type == 'relu':
44
+ nonlinearity = nn.ReLU(inplace=inplace)
45
+ elif nonlinearity_type == 'leakyrelu':
46
+ nonlinearity = nn.LeakyReLU(0.2, inplace=inplace)
47
+ elif nonlinearity_type == 'scaled_leakyrelu':
48
+ nonlinearity = ScaledLeakyReLU(0.2, inplace=inplace)
49
+ import imaginaire.config
50
+ if imaginaire.config.USE_JIT:
51
+ nonlinearity = torch.jit.script(nonlinearity)
52
+ elif nonlinearity_type == 'prelu':
53
+ nonlinearity = nn.PReLU()
54
+ elif nonlinearity_type == 'tanh':
55
+ nonlinearity = nn.Tanh()
56
+ elif nonlinearity_type == 'sigmoid':
57
+ nonlinearity = nn.Sigmoid()
58
+ elif nonlinearity_type.startswith('softmax'):
59
+ dim = nonlinearity_type.split(',')[1] if ',' in nonlinearity_type else 1
60
+ nonlinearity = nn.Softmax(dim=int(dim))
61
+ elif nonlinearity_type == 'none' or nonlinearity_type == '':
62
+ nonlinearity = None
63
+ else:
64
+ raise ValueError('Nonlinearity %s is not recognized' % nonlinearity_type)
65
+ return nonlinearity
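A short illustrative sketch of the string-to-module mapping above (not from the committed file; it assumes the repo's third_party bias_act op is importable, since this module imports it at the top):

from imaginaire.layers.nonlinearity import get_nonlinearity_layer

act = get_nonlinearity_layer('leakyrelu', inplace=False)    # nn.LeakyReLU(0.2)
soft = get_nonlinearity_layer('softmax,2', inplace=False)   # nn.Softmax(dim=2)
none = get_nonlinearity_layer('none', inplace=False)        # None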
imaginaire/layers/residual.py ADDED
@@ -0,0 +1,1411 @@
1
+ # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, check out LICENSE.md
5
+ import functools
6
+
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn import Upsample as NearestUpsample
10
+ from torch.utils.checkpoint import checkpoint
11
+
12
+ from .conv import (Conv1dBlock, Conv2dBlock, Conv3dBlock, HyperConv2dBlock,
13
+ LinearBlock, MultiOutConv2dBlock, PartialConv2dBlock,
14
+ PartialConv3dBlock, ModulatedConv2dBlock)
15
+ from imaginaire.third_party.upfirdn2d.upfirdn2d import BlurUpsample
16
+
17
+
18
+ class _BaseResBlock(nn.Module):
19
+ r"""An abstract class for residual blocks.
20
+ """
21
+
22
+ def __init__(self, in_channels, out_channels, kernel_size,
23
+ stride, padding, dilation, groups, bias, padding_mode,
24
+ weight_norm_type, weight_norm_params,
25
+ activation_norm_type, activation_norm_params,
26
+ skip_activation_norm, skip_nonlinearity,
27
+ nonlinearity, inplace_nonlinearity, apply_noise,
28
+ hidden_channels_equal_out_channels,
29
+ order, block, learn_shortcut, clamp, output_scale,
30
+ skip_block=None, blur=False, upsample_first=True, skip_weight_norm=True):
31
+ super().__init__()
32
+ self.in_channels = in_channels
33
+ self.out_channels = out_channels
34
+ self.output_scale = output_scale
35
+ self.upsample_first = upsample_first
36
+ self.stride = stride
37
+ self.blur = blur
38
+ if skip_block is None:
39
+ skip_block = block
40
+
41
+ if order == 'pre_act':
42
+ order = 'NACNAC'
43
+ if isinstance(bias, bool):
44
+ # The bias for conv_block_0, conv_block_1, and conv_block_s.
45
+ biases = [bias, bias, bias]
46
+ elif isinstance(bias, list):
47
+ if len(bias) == 3:
48
+ biases = bias
49
+ else:
50
+ raise ValueError('Bias list must have length 3.')
51
+ else:
52
+ raise ValueError('Bias must be either a bool or a list of 3 bools.')
53
+ if learn_shortcut is None:
54
+ self.learn_shortcut = (in_channels != out_channels)
55
+ else:
56
+ self.learn_shortcut = learn_shortcut
57
+ if len(order) > 6 or len(order) < 5:
58
+ raise ValueError('order must be either 5 or 6 characters')
59
+ if hidden_channels_equal_out_channels:
60
+ hidden_channels = out_channels
61
+ else:
62
+ hidden_channels = min(in_channels, out_channels)
63
+
64
+ # Parameters.
65
+ residual_params = {}
66
+ shortcut_params = {}
67
+ base_params = dict(dilation=dilation,
68
+ groups=groups,
69
+ padding_mode=padding_mode,
70
+ clamp=clamp)
71
+ residual_params.update(base_params)
72
+ residual_params.update(
73
+ dict(activation_norm_type=activation_norm_type,
74
+ activation_norm_params=activation_norm_params,
75
+ weight_norm_type=weight_norm_type,
76
+ weight_norm_params=weight_norm_params,
77
+ padding=padding,
78
+ apply_noise=apply_noise))
79
+ shortcut_params.update(base_params)
80
+ shortcut_params.update(dict(kernel_size=1))
81
+ if skip_activation_norm:
82
+ shortcut_params.update(
83
+ dict(activation_norm_type=activation_norm_type,
84
+ activation_norm_params=activation_norm_params,
85
+ apply_noise=False))
86
+ if skip_weight_norm:
87
+ shortcut_params.update(
88
+ dict(weight_norm_type=weight_norm_type,
89
+ weight_norm_params=weight_norm_params))
90
+
91
+ # Residual branch.
92
+ if order.find('A') < order.find('C') and \
93
+ (activation_norm_type == '' or activation_norm_type == 'none'):
94
+ # Nonlinearity is the first operation in the residual path.
95
+ # In-place nonlinearity will modify the input variable and cause
96
+ # backward error.
97
+ first_inplace = False
98
+ else:
99
+ first_inplace = inplace_nonlinearity
100
+
101
+ (first_stride, second_stride, shortcut_stride,
102
+ first_blur, second_blur, shortcut_blur) = self._get_stride_blur()
103
+ self.conv_block_0 = block(
104
+ in_channels, hidden_channels,
105
+ kernel_size=kernel_size,
106
+ bias=biases[0],
107
+ nonlinearity=nonlinearity,
108
+ order=order[0:3],
109
+ inplace_nonlinearity=first_inplace,
110
+ stride=first_stride,
111
+ blur=first_blur,
112
+ **residual_params
113
+ )
114
+ self.conv_block_1 = block(
115
+ hidden_channels, out_channels,
116
+ kernel_size=kernel_size,
117
+ bias=biases[1],
118
+ nonlinearity=nonlinearity,
119
+ order=order[3:],
120
+ inplace_nonlinearity=inplace_nonlinearity,
121
+ stride=second_stride,
122
+ blur=second_blur,
123
+ **residual_params
124
+ )
125
+
126
+ # Shortcut branch.
127
+ if self.learn_shortcut:
128
+ if skip_nonlinearity:
129
+ skip_nonlinearity_type = nonlinearity
130
+ else:
131
+ skip_nonlinearity_type = ''
132
+ self.conv_block_s = skip_block(in_channels, out_channels,
133
+ bias=biases[2],
134
+ nonlinearity=skip_nonlinearity_type,
135
+ order=order[0:3],
136
+ stride=shortcut_stride,
137
+ blur=shortcut_blur,
138
+ **shortcut_params)
139
+ elif in_channels < out_channels:
140
+ if skip_nonlinearity:
141
+ skip_nonlinearity_type = nonlinearity
142
+ else:
143
+ skip_nonlinearity_type = ''
144
+ self.conv_block_s = skip_block(in_channels,
145
+ out_channels - in_channels,
146
+ bias=biases[2],
147
+ nonlinearity=skip_nonlinearity_type,
148
+ order=order[0:3],
149
+ stride=shortcut_stride,
150
+ blur=shortcut_blur,
151
+ **shortcut_params)
152
+
153
+ # Whether this block expects conditional inputs.
154
+ self.conditional = \
155
+ getattr(self.conv_block_0, 'conditional', False) or \
156
+ getattr(self.conv_block_1, 'conditional', False)
157
+
158
+ def _get_stride_blur(self):
159
+ if self.stride > 1:
160
+ # Downsampling.
161
+ first_stride, second_stride = 1, self.stride
162
+ first_blur, second_blur = False, self.blur
163
+ shortcut_stride = self.stride
164
+ shortcut_blur = self.blur
165
+ self.upsample = None
166
+ elif self.stride < 1:
167
+ # Upsampling.
168
+ first_stride, second_stride = self.stride, 1
169
+ first_blur, second_blur = self.blur, False
170
+ shortcut_blur = False
171
+ shortcut_stride = 1
172
+ if self.blur:
173
+ # The shortcut branch uses blur_upsample + stride-1 conv
174
+ self.upsample = BlurUpsample()
175
+ else:
176
+ shortcut_stride = self.stride
177
+ self.upsample = nn.Upsample(scale_factor=2)
178
+ else:
179
+ first_stride = second_stride = 1
180
+ first_blur = second_blur = False
181
+ shortcut_stride = 1
182
+ shortcut_blur = False
183
+ self.upsample = None
184
+ return (first_stride, second_stride, shortcut_stride,
185
+ first_blur, second_blur, shortcut_blur)
186
+
187
+ def conv_blocks(
188
+ self, x, *cond_inputs, separate_cond=False, **kw_cond_inputs
189
+ ):
190
+ r"""Returns the output of the residual branch.
191
+
192
+ Args:
193
+ x (tensor): Input tensor.
194
+ cond_inputs (list of tensors) : Conditional input tensors.
195
+ kw_cond_inputs (dict) : Keyword conditional inputs.
196
+ Returns:
197
+ dx (tensor): Output tensor.
198
+ """
199
+ if separate_cond:
200
+ dx = self.conv_block_0(x, cond_inputs[0],
201
+ **kw_cond_inputs.get('kwargs_0', {}))
202
+ dx = self.conv_block_1(dx, cond_inputs[1],
203
+ **kw_cond_inputs.get('kwargs_1', {}))
204
+ else:
205
+ dx = self.conv_block_0(x, *cond_inputs, **kw_cond_inputs)
206
+ dx = self.conv_block_1(dx, *cond_inputs, **kw_cond_inputs)
207
+ return dx
208
+
209
+ def forward(self, x, *cond_inputs, do_checkpoint=False, separate_cond=False,
210
+ **kw_cond_inputs):
211
+ r"""
212
+
213
+ Args:
214
+ x (tensor): Input tensor.
215
+ cond_inputs (list of tensors) : Conditional input tensors.
216
+ do_checkpoint (bool, optional, default=``False``) If ``True``,
217
+ trade compute for memory by checkpointing the model.
218
+ kw_cond_inputs (dict) : Keyword conditional inputs.
219
+ Returns:
220
+ output (tensor): Output tensor.
221
+ """
222
+ if do_checkpoint:
223
+ dx = checkpoint(self.conv_blocks, x, *cond_inputs,
224
+ separate_cond=separate_cond, **kw_cond_inputs)
225
+ else:
226
+ dx = self.conv_blocks(x, *cond_inputs,
227
+ separate_cond=separate_cond, **kw_cond_inputs)
228
+
229
+ if self.upsample_first and self.upsample is not None:
230
+ x = self.upsample(x)
231
+ if self.learn_shortcut:
232
+ if separate_cond:
233
+ x_shortcut = self.conv_block_s(
234
+ x, cond_inputs[2], **kw_cond_inputs.get('kwargs_2', {})
235
+ )
236
+ else:
237
+ x_shortcut = self.conv_block_s(
238
+ x, *cond_inputs, **kw_cond_inputs
239
+ )
240
+ elif self.in_channels < self.out_channels:
241
+ if separate_cond:
242
+ x_shortcut_pad = self.conv_block_s(
243
+ x, cond_inputs[2], **kw_cond_inputs.get('kwargs_2', {})
244
+ )
245
+ else:
246
+ x_shortcut_pad = self.conv_block_s(
247
+ x, *cond_inputs, **kw_cond_inputs
248
+ )
249
+ x_shortcut = torch.cat((x, x_shortcut_pad), dim=1)
250
+ elif self.in_channels > self.out_channels:
251
+ x_shortcut = x[:, :self.out_channels, :, :]
252
+ else:
253
+ x_shortcut = x
254
+ if not self.upsample_first and self.upsample is not None:
255
+ x_shortcut = self.upsample(x_shortcut)
256
+
257
+ output = x_shortcut + dx
258
+ return self.output_scale * output
259
+
260
+ def extra_repr(self):
261
+ s = 'output_scale={output_scale}'
262
+ return s.format(**self.__dict__)
263
+
264
+
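A minimal sketch of the forward options documented above, using the Res2dBlock subclass defined later in this file (assumes PyTorch and the imaginaire package are importable; tensor shapes are illustrative):
import torch
from imaginaire.layers.residual import Res2dBlock

# Identity shortcut: in_channels == out_channels, so no learned shortcut is needed.
block = Res2dBlock(16, 16, activation_norm_type='instance')
x = torch.randn(1, 16, 32, 32)
y = block(x)  # same shape as x: (1, 16, 32, 32)
# do_checkpoint=True would wrap the residual branch in torch.utils.checkpoint,
# and separate_cond=True would route cond_inputs[0] / cond_inputs[1] to the two
# conv blocks individually (with cond_inputs[2] reserved for a learned shortcut).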
265
+ class ModulatedRes2dBlock(_BaseResBlock):
266
+ def __init__(self, in_channels, out_channels, style_dim, kernel_size=3,
267
+ stride=1, padding=1, dilation=1, groups=1, bias=True,
268
+ padding_mode='zeros',
269
+ weight_norm_type='none', weight_norm_params=None,
270
+ activation_norm_type='none', activation_norm_params=None,
271
+ skip_activation_norm=True, skip_nonlinearity=False,
272
+ nonlinearity='leakyrelu', inplace_nonlinearity=False,
273
+ apply_noise=True, hidden_channels_equal_out_channels=False,
274
+ order='CNACNA', learn_shortcut=None, clamp=None, output_scale=1,
275
+ demodulate=True, eps=1e-8):
276
+ block = functools.partial(ModulatedConv2dBlock,
277
+ style_dim=style_dim,
278
+ demodulate=demodulate, eps=eps)
279
+ skip_block = Conv2dBlock
280
+ super().__init__(in_channels, out_channels, kernel_size, stride,
281
+ padding, dilation, groups, bias, padding_mode,
282
+ weight_norm_type, weight_norm_params,
283
+ activation_norm_type, activation_norm_params,
284
+ skip_activation_norm, skip_nonlinearity, nonlinearity,
285
+ inplace_nonlinearity, apply_noise,
286
+ hidden_channels_equal_out_channels, order, block,
287
+ learn_shortcut, clamp, output_scale, skip_block=skip_block)
288
+
289
+ def conv_blocks(self, x, *cond_inputs, **kw_cond_inputs):
290
+ assert len(list(cond_inputs)) == 2
291
+ dx = self.conv_block_0(x, cond_inputs[0], **kw_cond_inputs)
292
+ dx = self.conv_block_1(dx, cond_inputs[1], **kw_cond_inputs)
293
+ return dx
294
+
295
+
296
+ class ResLinearBlock(_BaseResBlock):
297
+ r"""Residual block with full-connected layers.
298
+
299
+ Args:
300
+ in_channels (int) : Number of channels in the input tensor.
301
+ out_channels (int) : Number of channels in the output tensor.
302
+ weight_norm_type (str, optional, default='none'):
303
+ Type of weight normalization.
304
+ ``'none'``, ``'spectral'``, ``'weight'``
305
+ or ``'weight_demod'``.
306
+ weight_norm_params (obj, optional, default=None):
307
+ Parameters of weight normalization.
308
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
309
+ keyword arguments when initializing weight normalization.
310
+ activation_norm_type (str, optional, default='none'):
311
+ Type of activation normalization.
312
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
313
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
314
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
315
+ activation_norm_params (obj, optional, default=None):
316
+ Parameters of activation normalization.
317
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
318
+ keyword arguments when initializing activation normalization.
319
+ skip_activation_norm (bool, optional, default=True): If ``True`` and
320
+ ``learn_shortcut`` is also ``True``, applies activation norm to the
321
+ learned shortcut connection.
322
+ skip_nonlinearity (bool, optional, default=True): If ``True`` and
323
+ ``learn_shortcut`` is also ``True``, applies nonlinearity to the
324
+ learned shortcut connection.
325
+ nonlinearity (str, optional, default='none'):
326
+ Type of nonlinear activation function in the residual link.
327
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
328
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
329
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
330
+ set ``inplace=True`` when initializing the nonlinearity layers.
331
+ apply_noise (bool, optional, default=False): If ``True``, add
332
+ Gaussian noise with learnable magnitude after the
333
+ fully-connected layer.
334
+ hidden_channels_equal_out_channels (bool, optional, default=False):
335
+ If ``True``, set the hidden channel number to be equal to the
336
+ output channel number. If ``False``, the hidden channel number
337
+ equals to the smaller of the input channel number and the
338
+ output channel number.
339
+ order (str, optional, default='CNACNA'): Order of operations
340
+ in the residual link.
341
+ ``'C'``: fully-connected,
342
+ ``'N'``: normalization,
343
+ ``'A'``: nonlinear activation.
344
+ learn_shortcut (bool, optional, default=False): If ``True``, always use
345
+ a convolutional shortcut instead of an identity one, otherwise only
346
+ use a convolutional one if input and output have different number of
347
+ channels.
348
+ """
349
+
350
+ def __init__(self, in_channels, out_channels, bias=True,
351
+ weight_norm_type='none', weight_norm_params=None,
352
+ activation_norm_type='none', activation_norm_params=None,
353
+ skip_activation_norm=True, skip_nonlinearity=False,
354
+ nonlinearity='leakyrelu', inplace_nonlinearity=False,
355
+ apply_noise=False, hidden_channels_equal_out_channels=False,
356
+ order='CNACNA', learn_shortcut=None, clamp=None,
357
+ output_scale=1):
358
+ super().__init__(in_channels, out_channels, None, 1, None, None,
359
+ None, bias, None, weight_norm_type, weight_norm_params,
360
+ activation_norm_type, activation_norm_params,
361
+ skip_activation_norm, skip_nonlinearity, nonlinearity,
362
+ inplace_nonlinearity, apply_noise,
363
+ hidden_channels_equal_out_channels, order, LinearBlock,
364
+ learn_shortcut, clamp, output_scale)
365
+
366
+
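A brief sketch of ResLinearBlock on flat feature vectors, under the same import assumption:
import torch
from imaginaire.layers.residual import ResLinearBlock

mlp_block = ResLinearBlock(256, 256, nonlinearity='relu')
z = torch.randn(8, 256)
out = mlp_block(z)  # residual fully-connected block, output shape (8, 256)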
367
+ class Res1dBlock(_BaseResBlock):
368
+ r"""Residual block for 1D input.
369
+
370
+ Args:
371
+ in_channels (int) : Number of channels in the input tensor.
372
+ out_channels (int) : Number of channels in the output tensor.
373
+ kernel_size (int, optional, default=3): Kernel size for the
374
+ convolutional filters in the residual link.
375
+ padding (int, optional, default=1): Padding size.
376
+ dilation (int, optional, default=1): Dilation factor.
377
+ groups (int, optional, default=1): Number of convolutional/linear
378
+ groups.
379
+ padding_mode (string, optional, default='zeros'): Type of padding:
380
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
381
+ weight_norm_type (str, optional, default='none'):
382
+ Type of weight normalization.
383
+ ``'none'``, ``'spectral'``, ``'weight'``
384
+ or ``'weight_demod'``.
385
+ weight_norm_params (obj, optional, default=None):
386
+ Parameters of weight normalization.
387
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
388
+ keyword arguments when initializing weight normalization.
389
+ activation_norm_type (str, optional, default='none'):
390
+ Type of activation normalization.
391
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
392
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
393
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
394
+ activation_norm_params (obj, optional, default=None):
395
+ Parameters of activation normalization.
396
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
397
+ keyword arguments when initializing activation normalization.
398
+ skip_activation_norm (bool, optional, default=True): If ``True`` and
399
+ ``learn_shortcut`` is also ``True``, applies activation norm to the
400
+ learned shortcut connection.
401
+ skip_nonlinearity (bool, optional, default=True): If ``True`` and
402
+ ``learn_shortcut`` is also ``True``, applies nonlinearity to the
403
+ learned shortcut connection.
404
+ nonlinearity (str, optional, default='none'):
405
+ Type of nonlinear activation function in the residual link.
406
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
407
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
408
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
409
+ set ``inplace=True`` when initializing the nonlinearity layers.
410
+ apply_noise (bool, optional, default=False): If ``True``, adds
411
+ Gaussian noise with learnable magnitude to the convolution output.
412
+ hidden_channels_equal_out_channels (bool, optional, default=False):
413
+ If ``True``, set the hidden channel number to be equal to the
414
+ output channel number. If ``False``, the hidden channel number
415
+ equals to the smaller of the input channel number and the
416
+ output channel number.
417
+ order (str, optional, default='CNACNA'): Order of operations
418
+ in the residual link.
419
+ ``'C'``: convolution,
420
+ ``'N'``: normalization,
421
+ ``'A'``: nonlinear activation.
422
+ learn_shortcut (bool, optional, default=False): If ``True``, always use
423
+ a convolutional shortcut instead of an identity one, otherwise only
424
+ use a convolutional one if input and output have different number of
425
+ channels.
426
+ """
427
+
428
+ def __init__(self, in_channels, out_channels, kernel_size=3,
429
+ stride=1, padding=1, dilation=1, groups=1, bias=True,
430
+ padding_mode='zeros',
431
+ weight_norm_type='none', weight_norm_params=None,
432
+ activation_norm_type='none', activation_norm_params=None,
433
+ skip_activation_norm=True, skip_nonlinearity=False,
434
+ nonlinearity='leakyrelu', inplace_nonlinearity=False,
435
+ apply_noise=False, hidden_channels_equal_out_channels=False,
436
+ order='CNACNA', learn_shortcut=None, clamp=None,
437
+ output_scale=1):
438
+ super().__init__(in_channels, out_channels, kernel_size, stride,
439
+ padding, dilation, groups, bias, padding_mode,
440
+ weight_norm_type, weight_norm_params,
441
+ activation_norm_type, activation_norm_params,
442
+ skip_activation_norm, skip_nonlinearity, nonlinearity,
443
+ inplace_nonlinearity, apply_noise,
444
+ hidden_channels_equal_out_channels, order, Conv1dBlock,
445
+ learn_shortcut, clamp, output_scale)
446
+
447
+
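Res1dBlock expects (batch, channels, length) tensors; a small sketch:
import torch
from imaginaire.layers.residual import Res1dBlock

block = Res1dBlock(32, 32)
seq = torch.randn(4, 32, 100)  # (batch, channels, length)
out = block(seq)  # kernel_size=3 with padding=1 preserves the length: (4, 32, 100)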
448
+ class Res2dBlock(_BaseResBlock):
449
+ r"""Residual block for 2D input.
450
+
451
+ Args:
452
+ in_channels (int) : Number of channels in the input tensor.
453
+ out_channels (int) : Number of channels in the output tensor.
454
+ kernel_size (int, optional, default=3): Kernel size for the
455
+ convolutional filters in the residual link.
456
+ padding (int, optional, default=1): Padding size.
457
+ dilation (int, optional, default=1): Dilation factor.
458
+ groups (int, optional, default=1): Number of convolutional/linear
459
+ groups.
460
+ padding_mode (string, optional, default='zeros'): Type of padding:
461
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
462
+ weight_norm_type (str, optional, default='none'):
463
+ Type of weight normalization.
464
+ ``'none'``, ``'spectral'``, ``'weight'``
465
+ or ``'weight_demod'``.
466
+ weight_norm_params (obj, optional, default=None):
467
+ Parameters of weight normalization.
468
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
469
+ keyword arguments when initializing weight normalization.
470
+ activation_norm_type (str, optional, default='none'):
471
+ Type of activation normalization.
472
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
473
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
474
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
475
+ activation_norm_params (obj, optional, default=None):
476
+ Parameters of activation normalization.
477
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
478
+ keyword arguments when initializing activation normalization.
479
+ skip_activation_norm (bool, optional, default=True): If ``True`` and
480
+ ``learn_shortcut`` is also ``True``, applies activation norm to the
481
+ learned shortcut connection.
482
+ skip_nonlinearity (bool, optional, default=True): If ``True`` and
483
+ ``learn_shortcut`` is also ``True``, applies nonlinearity to the
484
+ learned shortcut connection.
485
+ nonlinearity (str, optional, default='none'):
486
+ Type of nonlinear activation function in the residual link.
487
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
488
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
489
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
490
+ set ``inplace=True`` when initializing the nonlinearity layers.
491
+ apply_noise (bool, optional, default=False): If ``True``, adds
492
+ Gaussian noise with learnable magnitude to the convolution output.
493
+ hidden_channels_equal_out_channels (bool, optional, default=False):
494
+ If ``True``, set the hidden channel number to be equal to the
495
+ output channel number. If ``False``, the hidden channel number
496
+ equals to the smaller of the input channel number and the
497
+ output channel number.
498
+ order (str, optional, default='CNACNA'): Order of operations
499
+ in the residual link.
500
+ ``'C'``: convolution,
501
+ ``'N'``: normalization,
502
+ ``'A'``: nonlinear activation.
503
+ learn_shortcut (bool, optional, default=False): If ``True``, always use
504
+ a convolutional shortcut instead of an identity one, otherwise only
505
+ use a convolutional one if input and output have different number of
506
+ channels.
507
+ """
508
+
509
+ def __init__(self, in_channels, out_channels, kernel_size=3,
510
+ stride=1, padding=1, dilation=1, groups=1, bias=True,
511
+ padding_mode='zeros',
512
+ weight_norm_type='none', weight_norm_params=None,
513
+ activation_norm_type='none', activation_norm_params=None,
514
+ skip_activation_norm=True, skip_nonlinearity=False,
515
+ skip_weight_norm=True,
516
+ nonlinearity='leakyrelu', inplace_nonlinearity=False,
517
+ apply_noise=False, hidden_channels_equal_out_channels=False,
518
+ order='CNACNA', learn_shortcut=None, clamp=None,
519
+ output_scale=1, blur=False, upsample_first=True):
520
+ super().__init__(in_channels, out_channels, kernel_size, stride,
521
+ padding, dilation, groups, bias, padding_mode,
522
+ weight_norm_type, weight_norm_params,
523
+ activation_norm_type, activation_norm_params,
524
+ skip_activation_norm, skip_nonlinearity, nonlinearity,
525
+ inplace_nonlinearity, apply_noise,
526
+ hidden_channels_equal_out_channels, order, Conv2dBlock,
527
+ learn_shortcut, clamp, output_scale, blur=blur,
528
+ upsample_first=upsample_first,
529
+ skip_weight_norm=skip_weight_norm)
530
+
531
+
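A sketch of a typical Res2dBlock configuration; the pre-activation order and instance normalization below are illustrative choices, not prescribed settings:
import torch
from imaginaire.layers.residual import Res2dBlock

block = Res2dBlock(64, 128,
                   activation_norm_type='instance',
                   nonlinearity='relu',
                   order='NACNAC')  # pre-activation variant of the default 'CNACNA'
x = torch.randn(2, 64, 64, 64)
out = block(x)  # (2, 128, 64, 64); the shortcut pads the extra 64 channels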
532
+ class Res3dBlock(_BaseResBlock):
533
+ r"""Residual block for 3D input.
534
+
535
+ Args:
536
+ in_channels (int) : Number of channels in the input tensor.
537
+ out_channels (int) : Number of channels in the output tensor.
538
+ kernel_size (int, optional, default=3): Kernel size for the
539
+ convolutional filters in the residual link.
540
+ padding (int, optional, default=1): Padding size.
541
+ dilation (int, optional, default=1): Dilation factor.
542
+ groups (int, optional, default=1): Number of convolutional/linear
543
+ groups.
544
+ padding_mode (string, optional, default='zeros'): Type of padding:
545
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
546
+ weight_norm_type (str, optional, default='none'):
547
+ Type of weight normalization.
548
+ ``'none'``, ``'spectral'``, ``'weight'``
549
+ or ``'weight_demod'``.
550
+ weight_norm_params (obj, optional, default=None):
551
+ Parameters of weight normalization.
552
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
553
+ keyword arguments when initializing weight normalization.
554
+ activation_norm_type (str, optional, default='none'):
555
+ Type of activation normalization.
556
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
557
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
558
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
559
+ activation_norm_params (obj, optional, default=None):
560
+ Parameters of activation normalization.
561
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
562
+ keyword arguments when initializing activation normalization.
563
+ skip_activation_norm (bool, optional, default=True): If ``True`` and
564
+ ``learn_shortcut`` is also ``True``, applies activation norm to the
565
+ learned shortcut connection.
566
+ skip_nonlinearity (bool, optional, default=True): If ``True`` and
567
+ ``learn_shortcut`` is also ``True``, applies nonlinearity to the
568
+ learned shortcut connection.
569
+ nonlinearity (str, optional, default='none'):
570
+ Type of nonlinear activation function in the residual link.
571
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
572
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
573
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
574
+ set ``inplace=True`` when initializing the nonlinearity layers.
575
+ apply_noise (bool, optional, default=False): If ``True``, adds
576
+ Gaussian noise with learnable magnitude to the convolution output.
577
+ hidden_channels_equal_out_channels (bool, optional, default=False):
578
+ If ``True``, set the hidden channel number to be equal to the
579
+ output channel number. If ``False``, the hidden channel number
580
+ equals to the smaller of the input channel number and the
581
+ output channel number.
582
+ order (str, optional, default='CNACNA'): Order of operations
583
+ in the residual link.
584
+ ``'C'``: convolution,
585
+ ``'N'``: normalization,
586
+ ``'A'``: nonlinear activation.
587
+ learn_shortcut (bool, optional, default=False): If ``True``, always use
588
+ a convolutional shortcut instead of an identity one, otherwise only
589
+ use a convolutional one if input and output have different number of
590
+ channels.
591
+ """
592
+
593
+ def __init__(self, in_channels, out_channels, kernel_size=3,
594
+ stride=1, padding=1, dilation=1, groups=1, bias=True,
595
+ padding_mode='zeros',
596
+ weight_norm_type='none', weight_norm_params=None,
597
+ activation_norm_type='none', activation_norm_params=None,
598
+ skip_activation_norm=True, skip_nonlinearity=False,
599
+ nonlinearity='leakyrelu', inplace_nonlinearity=False,
600
+ apply_noise=False, hidden_channels_equal_out_channels=False,
601
+ order='CNACNA', learn_shortcut=None, clamp=None,
602
+ output_scale=1):
603
+ super().__init__(in_channels, out_channels, kernel_size, stride,
604
+ padding, dilation, groups, bias, padding_mode,
605
+ weight_norm_type, weight_norm_params,
606
+ activation_norm_type, activation_norm_params,
607
+ skip_activation_norm, skip_nonlinearity, nonlinearity,
608
+ inplace_nonlinearity, apply_noise,
609
+ hidden_channels_equal_out_channels, order, Conv3dBlock,
610
+ learn_shortcut, clamp, output_scale)
611
+
612
+
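Res3dBlock operates on 5D volumetric or spatio-temporal tensors; a quick sketch:
import torch
from imaginaire.layers.residual import Res3dBlock

block = Res3dBlock(16, 16)
video = torch.randn(2, 16, 8, 32, 32)  # (batch, channels, depth, height, width)
out = block(video)  # shape preserved: (2, 16, 8, 32, 32)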
613
+ class _BaseHyperResBlock(_BaseResBlock):
614
+ r"""An abstract class for hyper residual blocks.
615
+ """
616
+
617
+ def __init__(self, in_channels, out_channels, kernel_size,
618
+ stride, padding, dilation, groups, bias, padding_mode,
619
+ weight_norm_type, weight_norm_params,
620
+ activation_norm_type, activation_norm_params,
621
+ skip_activation_norm, skip_nonlinearity,
622
+ nonlinearity, inplace_nonlinearity, apply_noise,
623
+ hidden_channels_equal_out_channels,
624
+ order, is_hyper_conv, is_hyper_norm, block, learn_shortcut,
625
+ clamp=None, output_scale=1):
626
+ block = functools.partial(block,
627
+ is_hyper_conv=is_hyper_conv,
628
+ is_hyper_norm=is_hyper_norm)
629
+ super().__init__(in_channels, out_channels, kernel_size, stride,
630
+ padding, dilation, groups, bias, padding_mode,
631
+ weight_norm_type, weight_norm_params,
632
+ activation_norm_type, activation_norm_params,
633
+ skip_activation_norm, skip_nonlinearity, nonlinearity,
634
+ inplace_nonlinearity, apply_noise,
635
+ hidden_channels_equal_out_channels, order, block,
636
+ learn_shortcut, clamp, output_scale)
637
+
638
+ def forward(self, x, *cond_inputs, conv_weights=(None,) * 3,
639
+ norm_weights=(None,) * 3, **kw_cond_inputs):
640
+ r"""
641
+
642
+ Args:
643
+ x (tensor): Input tensor.
644
+ cond_inputs (list of tensors) : Conditional input tensors.
645
+ conv_weights (list of tensors): Convolution weights for
646
+ three convolutional layers respectively.
647
+ norm_weights (list of tensors): Normalization weights for
648
+ three convolutional layers respectively.
649
+ kw_cond_inputs (dict) : Keyword conditional inputs.
650
+ Returns:
651
+ output (tensor): Output tensor.
652
+ """
653
+ dx = self.conv_block_0(x, *cond_inputs, conv_weights=conv_weights[0],
654
+ norm_weights=norm_weights[0])
655
+ dx = self.conv_block_1(dx, *cond_inputs, conv_weights=conv_weights[1],
656
+ norm_weights=norm_weights[1])
657
+ if self.learn_shortcut:
658
+ x_shortcut = self.conv_block_s(x, *cond_inputs,
659
+ conv_weights=conv_weights[2],
660
+ norm_weights=norm_weights[2])
661
+ else:
662
+ x_shortcut = x
663
+ output = x_shortcut + dx
664
+ return self.output_scale * output
665
+
666
+
667
+ class HyperRes2dBlock(_BaseHyperResBlock):
668
+ r"""Hyper residual block for 2D input.
669
+
670
+ Args:
671
+ in_channels (int) : Number of channels in the input tensor.
672
+ out_channels (int) : Number of channels in the output tensor.
673
+ kernel_size (int, optional, default=3): Kernel size for the
674
+ convolutional filters in the residual link.
675
+ padding (int, optional, default=1): Padding size.
676
+ dilation (int, optional, default=1): Dilation factor.
677
+ groups (int, optional, default=1): Number of convolutional/linear
678
+ groups.
679
+ padding_mode (string, optional, default='zeros'): Type of padding:
680
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
681
+ weight_norm_type (str, optional, default='none'):
682
+ Type of weight normalization.
683
+ ``'none'``, ``'spectral'``, ``'weight'``
684
+ or ``'weight_demod'``.
685
+ weight_norm_params (obj, optional, default=None):
686
+ Parameters of weight normalization.
687
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
688
+ keyword arguments when initializing weight normalization.
689
+ activation_norm_type (str, optional, default='none'):
690
+ Type of activation normalization.
691
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
692
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
693
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
694
+ activation_norm_params (obj, optional, default=None):
695
+ Parameters of activation normalization.
696
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
697
+ keyword arguments when initializing activation normalization.
698
+ skip_activation_norm (bool, optional, default=True): If ``True`` and
699
+ ``learn_shortcut`` is also ``True``, applies activation norm to the
700
+ learned shortcut connection.
701
+ skip_nonlinearity (bool, optional, default=True): If ``True`` and
702
+ ``learn_shortcut`` is also ``True``, applies nonlinearity to the
703
+ learned shortcut connection.
704
+ nonlinearity (str, optional, default='none'):
705
+ Type of nonlinear activation function in the residual link.
706
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
707
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
708
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
709
+ set ``inplace=True`` when initializing the nonlinearity layers.
710
+ apply_noise (bool, optional, default=False): If ``True``, adds
711
+ Gaussian noise with learnable magnitude to the convolution output.
712
+ hidden_channels_equal_out_channels (bool, optional, default=False):
713
+ If ``True``, set the hidden channel number to be equal to the
714
+ output channel number. If ``False``, the hidden channel number
715
+ equals to the smaller of the input channel number and the
716
+ output channel number.
717
+ order (str, optional, default='CNACNA'): Order of operations
718
+ in the residual link.
719
+ ``'C'``: convolution,
720
+ ``'N'``: normalization,
721
+ ``'A'``: nonlinear activation.
722
+ is_hyper_conv (bool, optional, default=False): If ``True``, use
723
+ ``HyperConv2d``, otherwise use ``torch.nn.Conv2d``.
724
+ is_hyper_norm (bool, optional, default=False): If ``True``, use
725
+ hyper normalizations.
726
+ learn_shortcut (bool, optional, default=False): If ``True``, always use
727
+ a convolutional shortcut instead of an identity one, otherwise only
728
+ use a convolutional one if input and output have different number of
729
+ channels.
730
+ """
731
+
732
+ def __init__(self, in_channels, out_channels, kernel_size=3,
733
+ stride=1, padding=1, dilation=1, groups=1, bias=True,
734
+ padding_mode='zeros',
735
+ weight_norm_type='', weight_norm_params=None,
736
+ activation_norm_type='', activation_norm_params=None,
737
+ skip_activation_norm=True, skip_nonlinearity=False,
738
+ nonlinearity='leakyrelu', inplace_nonlinearity=False,
739
+ apply_noise=False, hidden_channels_equal_out_channels=False,
740
+ order='CNACNA', is_hyper_conv=False, is_hyper_norm=False,
741
+ learn_shortcut=None, clamp=None, output_scale=1):
742
+ super().__init__(in_channels, out_channels, kernel_size,
743
+ stride, padding, dilation, groups, bias, padding_mode,
744
+ weight_norm_type, weight_norm_params,
745
+ activation_norm_type, activation_norm_params,
746
+ skip_activation_norm, skip_nonlinearity,
747
+ nonlinearity, inplace_nonlinearity, apply_noise,
748
+ hidden_channels_equal_out_channels,
749
+ order, is_hyper_conv, is_hyper_norm,
750
+ HyperConv2dBlock, learn_shortcut, clamp, output_scale)
751
+
752
+
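When no hypernetwork-predicted weights are passed, HyperRes2dBlock's forward falls back to conv_weights=(None,)*3 and norm_weights=(None,)*3; a hedged sketch of that degenerate use (the hyper path itself needs externally generated weights, not shown here):
import torch
from imaginaire.layers.residual import HyperRes2dBlock

block = HyperRes2dBlock(32, 32, is_hyper_conv=False, is_hyper_norm=False)
x = torch.randn(1, 32, 64, 64)
out = block(x)  # behaves like a plain residual block when all weights are None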
753
+ class _BaseDownResBlock(_BaseResBlock):
754
+ r"""An abstract class for residual blocks with downsampling.
755
+ """
756
+
757
+ def __init__(self, in_channels, out_channels, kernel_size,
758
+ stride, padding, dilation, groups, bias, padding_mode,
759
+ weight_norm_type, weight_norm_params,
760
+ activation_norm_type, activation_norm_params,
761
+ skip_activation_norm, skip_nonlinearity,
762
+ nonlinearity, inplace_nonlinearity,
763
+ apply_noise, hidden_channels_equal_out_channels,
764
+ order, block, pooling, down_factor, learn_shortcut,
765
+ clamp=None, output_scale=1):
766
+ super().__init__(in_channels, out_channels, kernel_size,
767
+ stride, padding, dilation, groups, bias, padding_mode,
768
+ weight_norm_type, weight_norm_params,
769
+ activation_norm_type, activation_norm_params,
770
+ skip_activation_norm, skip_nonlinearity, nonlinearity,
771
+ inplace_nonlinearity, apply_noise,
772
+ hidden_channels_equal_out_channels, order, block,
773
+ learn_shortcut, clamp, output_scale)
774
+ self.pooling = pooling(down_factor)
775
+
776
+ def forward(self, x, *cond_inputs):
777
+ r"""
778
+
779
+ Args:
780
+ x (tensor) : Input tensor.
781
+ cond_inputs (list of tensors) : conditional input.
782
+ Returns:
783
+ output (tensor) : Output tensor.
784
+ """
785
+ dx = self.conv_block_0(x, *cond_inputs)
786
+ dx = self.conv_block_1(dx, *cond_inputs)
787
+ dx = self.pooling(dx)
788
+ if self.learn_shortcut:
789
+ x_shortcut = self.conv_block_s(x, *cond_inputs)
790
+ else:
791
+ x_shortcut = x
792
+ x_shortcut = self.pooling(x_shortcut)
793
+ output = x_shortcut + dx
794
+ return self.output_scale * output
795
+
796
+
797
+ class DownRes2dBlock(_BaseDownResBlock):
798
+ r"""Residual block for 2D input with downsampling.
799
+
800
+ Args:
801
+ in_channels (int) : Number of channels in the input tensor.
802
+ out_channels (int) : Number of channels in the output tensor.
803
+ kernel_size (int, optional, default=3): Kernel size for the
804
+ convolutional filters in the residual link.
805
+ padding (int, optional, default=1): Padding size.
806
+ dilation (int, optional, default=1): Dilation factor.
807
+ groups (int, optional, default=1): Number of convolutional/linear
808
+ groups.
809
+ padding_mode (string, optional, default='zeros'): Type of padding:
810
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
811
+ weight_norm_type (str, optional, default='none'):
812
+ Type of weight normalization.
813
+ ``'none'``, ``'spectral'``, ``'weight'``
814
+ or ``'weight_demod'``.
815
+ weight_norm_params (obj, optional, default=None):
816
+ Parameters of weight normalization.
817
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
818
+ keyword arguments when initializing weight normalization.
819
+ activation_norm_type (str, optional, default='none'):
820
+ Type of activation normalization.
821
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
822
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
823
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
824
+ activation_norm_params (obj, optional, default=None):
825
+ Parameters of activation normalization.
826
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
827
+ keyword arguments when initializing activation normalization.
828
+ skip_activation_norm (bool, optional, default=True): If ``True`` and
829
+ ``learn_shortcut`` is also ``True``, applies activation norm to the
830
+ learned shortcut connection.
831
+ skip_nonlinearity (bool, optional, default=True): If ``True`` and
832
+ ``learn_shortcut`` is also ``True``, applies nonlinearity to the
833
+ learned shortcut connection.
834
+ nonlinearity (str, optional, default='none'):
835
+ Type of nonlinear activation function in the residual link.
836
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
837
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
838
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
839
+ set ``inplace=True`` when initializing the nonlinearity layers.
840
+ apply_noise (bool, optional, default=False): If ``True``, adds
841
+ Gaussian noise with learnable magnitude to the convolution output.
842
+ hidden_channels_equal_out_channels (bool, optional, default=False):
843
+ If ``True``, set the hidden channel number to be equal to the
844
+ output channel number. If ``False``, the hidden channel number
845
+ equals to the smaller of the input channel number and the
846
+ output channel number.
847
+ order (str, optional, default='CNACNA'): Order of operations
848
+ in the residual link.
849
+ ``'C'``: convolution,
850
+ ``'N'``: normalization,
851
+ ``'A'``: nonlinear activation.
852
+ pooling (class, optional, default=nn.AvgPool2d): PyTorch pooling
853
+ layer to be used.
854
+ down_factor (int, optional, default=2): Downsampling factor.
855
+ learn_shortcut (bool, optional, default=False): If ``True``, always use
856
+ a convolutional shortcut instead of an identity one, otherwise only
857
+ use a convolutional one if input and output have different number of
858
+ channels.
859
+ """
860
+
861
+ def __init__(self, in_channels, out_channels, kernel_size=3,
862
+ stride=1, padding=1, dilation=1, groups=1, bias=True,
863
+ padding_mode='zeros',
864
+ weight_norm_type='none', weight_norm_params=None,
865
+ activation_norm_type='none', activation_norm_params=None,
866
+ skip_activation_norm=True, skip_nonlinearity=False,
867
+ nonlinearity='leakyrelu', inplace_nonlinearity=False,
868
+ apply_noise=False, hidden_channels_equal_out_channels=False,
869
+ order='CNACNA', pooling=nn.AvgPool2d, down_factor=2,
870
+ learn_shortcut=None, clamp=None, output_scale=1):
871
+ super().__init__(in_channels, out_channels, kernel_size,
872
+ stride, padding, dilation, groups, bias, padding_mode,
873
+ weight_norm_type, weight_norm_params,
874
+ activation_norm_type, activation_norm_params,
875
+ skip_activation_norm, skip_nonlinearity,
876
+ nonlinearity, inplace_nonlinearity, apply_noise,
877
+ hidden_channels_equal_out_channels,
878
+ order, Conv2dBlock, pooling,
879
+ down_factor, learn_shortcut, clamp, output_scale)
880
+
881
+
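DownRes2dBlock runs the residual branch and then pools both branches; a sketch with matching channel counts so the identity shortcut applies:
import torch
from imaginaire.layers.residual import DownRes2dBlock

block = DownRes2dBlock(32, 32, activation_norm_type='instance')
x = torch.randn(1, 32, 64, 64)
out = block(x)  # pooled by down_factor=2: (1, 32, 32, 32)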
882
+ class _BaseUpResBlock(_BaseResBlock):
883
+ r"""An abstract class for residual blocks with upsampling.
884
+ """
885
+
886
+ def __init__(self, in_channels, out_channels, kernel_size,
887
+ stride, padding, dilation, groups, bias, padding_mode,
888
+ weight_norm_type, weight_norm_params,
889
+ activation_norm_type, activation_norm_params,
890
+ skip_activation_norm, skip_nonlinearity,
891
+ nonlinearity, inplace_nonlinearity,
892
+ apply_noise, hidden_channels_equal_out_channels,
893
+ order, block, upsample, up_factor, learn_shortcut, clamp=None,
894
+ output_scale=1):
895
+ super().__init__(in_channels, out_channels, kernel_size,
896
+ stride, padding, dilation, groups, bias, padding_mode,
897
+ weight_norm_type, weight_norm_params,
898
+ activation_norm_type, activation_norm_params,
899
+ skip_activation_norm, skip_nonlinearity, nonlinearity,
900
+ inplace_nonlinearity, apply_noise,
901
+ hidden_channels_equal_out_channels, order, block,
902
+ learn_shortcut, clamp, output_scale)
903
+ self.order = order
904
+ self.upsample = upsample(scale_factor=up_factor)
905
+
906
+ def _get_stride_blur(self):
907
+ # Upsampling.
908
+ first_stride, second_stride = self.stride, 1
909
+ first_blur, second_blur = self.blur, False
910
+ shortcut_blur = False
911
+ shortcut_stride = 1
912
+ # if self.upsample == 'blur_deconv':
913
+
914
+ if self.blur:
915
+ # The shortcut branch uses blur_upsample + stride-1 conv
916
+ self.upsample = BlurUpsample()
917
+ else:
918
+ shortcut_stride = self.stride
919
+ self.upsample = nn.Upsample(scale_factor=2)
920
+
921
+ return (first_stride, second_stride, shortcut_stride,
922
+ first_blur, second_blur, shortcut_blur)
923
+
924
+ def forward(self, x, *cond_inputs):
925
+ r"""Implementation of the up residual block forward function.
926
+ If the order is 'NAC' for the first residual block, we will first
927
+ do the activation norm and nonlinearity, in the original resolution.
928
+ We will then upsample the activation map to a higher resolution. We
929
+ then do the convolution.
930
+ In other orders, we first do the whole processing and
931
+ then upsample.
932
+
933
+ Args:
934
+ x (tensor) : Input tensor.
935
+ cond_inputs (list of tensors) : Conditional input.
936
+ Returns:
937
+ output (tensor) : Output tensor.
938
+ """
939
+ # In this particular upsample residual block operation, we first
940
+ # upsample the skip connection.
941
+ if self.learn_shortcut:
942
+ x_shortcut = self.upsample(x)
943
+ x_shortcut = self.conv_block_s(x_shortcut, *cond_inputs)
944
+ else:
945
+ x_shortcut = self.upsample(x)
946
+
947
+ if self.order[0:3] == 'NAC':
948
+ for ix, layer in enumerate(self.conv_block_0.layers.values()):
949
+ if getattr(layer, 'conditional', False):
950
+ x = layer(x, *cond_inputs)
951
+ else:
952
+ x = layer(x)
953
+ if ix == 1:
954
+ x = self.upsample(x)
955
+ else:
956
+ x = self.conv_block_0(x, *cond_inputs)
957
+ x = self.upsample(x)
958
+ x = self.conv_block_1(x, *cond_inputs)
959
+
960
+ output = x_shortcut + x
961
+ return self.output_scale * output
962
+
963
+
964
+ class UpRes2dBlock(_BaseUpResBlock):
965
+ r"""Residual block for 2D input with upsampling.
966
+
967
+ Args:
968
+ in_channels (int) : Number of channels in the input tensor.
969
+ out_channels (int) : Number of channels in the output tensor.
970
+ kernel_size (int, optional, default=3): Kernel size for the
971
+ convolutional filters in the residual link.
972
+ padding (int, optional, default=1): Padding size.
973
+ dilation (int, optional, default=1): Dilation factor.
974
+ groups (int, optional, default=1): Number of convolutional/linear
975
+ groups.
976
+ padding_mode (string, optional, default='zeros'): Type of padding:
977
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
978
+ weight_norm_type (str, optional, default='none'):
979
+ Type of weight normalization.
980
+ ``'none'``, ``'spectral'``, ``'weight'``
981
+ or ``'weight_demod'``.
982
+ weight_norm_params (obj, optional, default=None):
983
+ Parameters of weight normalization.
984
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
985
+ keyword arguments when initializing weight normalization.
986
+ activation_norm_type (str, optional, default='none'):
987
+ Type of activation normalization.
988
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
989
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
990
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
991
+ activation_norm_params (obj, optional, default=None):
992
+ Parameters of activation normalization.
993
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
994
+ keyword arguments when initializing activation normalization.
995
+ skip_activation_norm (bool, optional, default=True): If ``True`` and
996
+ ``learn_shortcut`` is also ``True``, applies activation norm to the
997
+ learned shortcut connection.
998
+ skip_nonlinearity (bool, optional, default=True): If ``True`` and
999
+ ``learn_shortcut`` is also ``True``, applies nonlinearity to the
1000
+ learned shortcut connection.
1001
+ nonlinearity (str, optional, default='none'):
1002
+ Type of nonlinear activation function in the residual link.
1003
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
1004
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
1005
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
1006
+ set ``inplace=True`` when initializing the nonlinearity layers.
1007
+ apply_noise (bool, optional, default=False): If ``True``, adds
1008
+ Gaussian noise with learnable magnitude to the convolution output.
1009
+ hidden_channels_equal_out_channels (bool, optional, default=False):
1010
+ If ``True``, set the hidden channel number to be equal to the
1011
+ output channel number. If ``False``, the hidden channel number
1012
+ equals to the smaller of the input channel number and the
1013
+ output channel number.
1014
+ order (str, optional, default='CNACNA'): Order of operations
1015
+ in the residual link.
1016
+ ``'C'``: convolution,
1017
+ ``'N'``: normalization,
1018
+ ``'A'``: nonlinear activation.
1019
+ upsample (class, optional, default=NearestUpsample): PyTorch
1020
+ upsampling layer to be used.
1021
+ up_factor (int, optional, default=2): Upsampling factor.
1022
+ learn_shortcut (bool, optional, default=False): If ``True``, always use
1023
+ a convolutional shortcut instead of an identity one, otherwise only
1024
+ use a convolutional one if input and output have different number of
1025
+ channels.
1026
+ """
1027
+
1028
+ def __init__(self, in_channels, out_channels, kernel_size=3,
1029
+ stride=1, padding=1, dilation=1, groups=1, bias=True,
1030
+ padding_mode='zeros',
1031
+ weight_norm_type='none', weight_norm_params=None,
1032
+ activation_norm_type='none', activation_norm_params=None,
1033
+ skip_activation_norm=True, skip_nonlinearity=False,
1034
+ nonlinearity='leakyrelu', inplace_nonlinearity=False,
1035
+ apply_noise=False, hidden_channels_equal_out_channels=False,
1036
+ order='CNACNA', upsample=NearestUpsample, up_factor=2,
1037
+ learn_shortcut=None, clamp=None, output_scale=1):
1038
+ super().__init__(in_channels, out_channels, kernel_size,
1039
+ stride, padding, dilation, groups, bias, padding_mode,
1040
+ weight_norm_type, weight_norm_params,
1041
+ activation_norm_type, activation_norm_params,
1042
+ skip_activation_norm, skip_nonlinearity,
1043
+ nonlinearity, inplace_nonlinearity,
1044
+ apply_noise, hidden_channels_equal_out_channels,
1045
+ order, Conv2dBlock,
1046
+ upsample, up_factor, learn_shortcut, clamp,
1047
+ output_scale)
1048
+
1049
+
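UpRes2dBlock upsamples the shortcut first and, with the default 'CNACNA' order, upsamples the residual branch between its two conv blocks; a sketch with matching channel counts:
import torch
from imaginaire.layers.residual import UpRes2dBlock

block = UpRes2dBlock(32, 32, activation_norm_type='instance')
x = torch.randn(1, 32, 64, 64)
out = block(x)  # upsampled by up_factor=2: (1, 32, 128, 128)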
1050
+ class _BasePartialResBlock(_BaseResBlock):
1051
+ r"""An abstract class for residual blocks with partial convolution.
1052
+ """
1053
+
1054
+ def __init__(self, in_channels, out_channels, kernel_size,
1055
+ stride, padding, dilation, groups, bias, padding_mode,
1056
+ weight_norm_type, weight_norm_params,
1057
+ activation_norm_type, activation_norm_params,
1058
+ skip_activation_norm, skip_nonlinearity,
1059
+ nonlinearity, inplace_nonlinearity,
1060
+ multi_channel, return_mask,
1061
+ apply_noise, hidden_channels_equal_out_channels,
1062
+ order, block, learn_shortcut, clamp=None, output_scale=1):
1063
+ block = functools.partial(block,
1064
+ multi_channel=multi_channel,
1065
+ return_mask=return_mask)
1066
+ self.partial_conv = True
1067
+ super().__init__(in_channels, out_channels, kernel_size, stride,
1068
+ padding, dilation, groups, bias, padding_mode,
1069
+ weight_norm_type, weight_norm_params,
1070
+ activation_norm_type, activation_norm_params,
1071
+ skip_activation_norm, skip_nonlinearity, nonlinearity,
1072
+ inplace_nonlinearity, apply_noise,
1073
+ hidden_channels_equal_out_channels, order, block,
1074
+ learn_shortcut, clamp, output_scale)
1075
+
1076
+ def forward(self, x, *cond_inputs, mask_in=None, **kw_cond_inputs):
1077
+ r"""
1078
+
1079
+ Args:
1080
+ x (tensor): Input tensor.
1081
+ cond_inputs (list of tensors) : Conditional input tensors.
1082
+ mask_in (tensor, optional, default=``None``): If not ``None``,
1083
+ it masks the valid input region.
1084
+ kw_cond_inputs (dict) : Keyword conditional inputs.
1085
+ Returns:
1086
+ (tuple):
1087
+ - output (tensor): Output tensor.
1088
+ - mask_out (tensor, optional): Masks the valid output region.
1089
+ """
1090
+ if self.conv_block_0.layers.conv.return_mask:
1091
+ dx, mask_out = self.conv_block_0(x, *cond_inputs,
1092
+ mask_in=mask_in, **kw_cond_inputs)
1093
+ dx, mask_out = self.conv_block_1(dx, *cond_inputs,
1094
+ mask_in=mask_out, **kw_cond_inputs)
1095
+ else:
1096
+ dx = self.conv_block_0(x, *cond_inputs,
1097
+ mask_in=mask_in, **kw_cond_inputs)
1098
+ dx = self.conv_block_1(dx, *cond_inputs,
1099
+ mask_in=mask_in, **kw_cond_inputs)
1100
+ mask_out = None
1101
+
1102
+ if self.learn_shortcut:
1103
+ x_shortcut = self.conv_block_s(x, mask_in=mask_in, *cond_inputs,
1104
+ **kw_cond_inputs)
1105
+ if type(x_shortcut) == tuple:
1106
+ x_shortcut, _ = x_shortcut
1107
+ else:
1108
+ x_shortcut = x
1109
+ output = x_shortcut + dx
1110
+
1111
+ if mask_out is not None:
1112
+ return output, mask_out
1113
+ return self.output_scale * output
1114
+
1115
+
1116
+ class PartialRes2dBlock(_BasePartialResBlock):
1117
+ r"""Residual block for 2D input with partial convolution.
1118
+
1119
+ Args:
1120
+ in_channels (int) : Number of channels in the input tensor.
1121
+ out_channels (int) : Number of channels in the output tensor.
1122
+ kernel_size (int, optional, default=3): Kernel size for the
1123
+ convolutional filters in the residual link.
1124
+ padding (int, optional, default=1): Padding size.
1125
+ dilation (int, optional, default=1): Dilation factor.
1126
+ groups (int, optional, default=1): Number of convolutional/linear
1127
+ groups.
1128
+ padding_mode (string, optional, default='zeros'): Type of padding:
1129
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
1130
+ weight_norm_type (str, optional, default='none'):
1131
+ Type of weight normalization.
1132
+ ``'none'``, ``'spectral'``, ``'weight'``
1133
+ or ``'weight_demod'``.
1134
+ weight_norm_params (obj, optional, default=None):
1135
+ Parameters of weight normalization.
1136
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
1137
+ keyword arguments when initializing weight normalization.
1138
+ activation_norm_type (str, optional, default='none'):
1139
+ Type of activation normalization.
1140
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
1141
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
1142
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
1143
+ activation_norm_params (obj, optional, default=None):
1144
+ Parameters of activation normalization.
1145
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
1146
+ keyword arguments when initializing activation normalization.
1147
+ skip_activation_norm (bool, optional, default=True): If ``True`` and
1148
+ ``learn_shortcut`` is also ``True``, applies activation norm to the
1149
+ learned shortcut connection.
1150
+ skip_nonlinearity (bool, optional, default=True): If ``True`` and
1151
+ ``learn_shortcut`` is also ``True``, applies nonlinearity to the
1152
+ learned shortcut connection.
1153
+ nonlinearity (str, optional, default='none'):
1154
+ Type of nonlinear activation function in the residual link.
1155
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
1156
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
1157
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
1158
+ set ``inplace=True`` when initializing the nonlinearity layers.
1159
+ apply_noise (bool, optional, default=False): If ``True``, adds
1160
+ Gaussian noise with learnable magnitude to the convolution output.
1161
+ hidden_channels_equal_out_channels (bool, optional, default=False):
1162
+ If ``True``, set the hidden channel number to be equal to the
1163
+ output channel number. If ``False``, the hidden channel number
1164
+ equals to the smaller of the input channel number and the
1165
+ output channel number.
1166
+ order (str, optional, default='CNACNA'): Order of operations
1167
+ in the residual link.
1168
+ ``'C'``: convolution,
1169
+ ``'N'``: normalization,
1170
+ ``'A'``: nonlinear activation.
1171
+ learn_shortcut (bool, optional, default=False): If ``True``, always use
1172
+ a convolutional shortcut instead of an identity one, otherwise only
1173
+ use a convolutional one if input and output have different number of
1174
+ channels.
1175
+ """
1176
+
1177
+ def __init__(self, in_channels, out_channels, kernel_size=3,
1178
+ stride=1, padding=1, dilation=1, groups=1, bias=True,
1179
+ padding_mode='zeros',
1180
+ weight_norm_type='none', weight_norm_params=None,
1181
+ activation_norm_type='none', activation_norm_params=None,
1182
+ skip_activation_norm=True, skip_nonlinearity=False,
1183
+ nonlinearity='leakyrelu', inplace_nonlinearity=False,
1184
+ multi_channel=False, return_mask=True,
1185
+ apply_noise=False,
1186
+ hidden_channels_equal_out_channels=False,
1187
+ order='CNACNA', learn_shortcut=None, clamp=None,
1188
+ output_scale=1):
1189
+ super().__init__(in_channels, out_channels, kernel_size,
1190
+ stride, padding, dilation, groups, bias,
1191
+ padding_mode, weight_norm_type, weight_norm_params,
1192
+ activation_norm_type, activation_norm_params,
1193
+ skip_activation_norm, skip_nonlinearity, nonlinearity,
1194
+ inplace_nonlinearity, multi_channel, return_mask,
1195
+ apply_noise, hidden_channels_equal_out_channels,
1196
+ order, PartialConv2dBlock, learn_shortcut, clamp,
1197
+ output_scale)
1198
+
1199
+
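PartialRes2dBlock threads a validity mask through both partial convolutions and, with the default return_mask=True, returns the propagated mask alongside the features; a sketch:
import torch
from imaginaire.layers.residual import PartialRes2dBlock

block = PartialRes2dBlock(16, 16)
x = torch.randn(1, 16, 64, 64)
mask = torch.ones(1, 1, 64, 64)
mask[:, :, 16:48, 16:48] = 0  # mark a hole as invalid
out, mask_out = block(x, mask_in=mask)  # features plus the updated validity mask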
1200
+ class PartialRes3dBlock(_BasePartialResBlock):
1201
+ r"""Residual block for 3D input with partial convolution.
1202
+
1203
+ Args:
1204
+ in_channels (int) : Number of channels in the input tensor.
1205
+ out_channels (int) : Number of channels in the output tensor.
1206
+ kernel_size (int, optional, default=3): Kernel size for the
1207
+ convolutional filters in the residual link.
1208
+ padding (int, optional, default=1): Padding size.
1209
+ dilation (int, optional, default=1): Dilation factor.
1210
+ groups (int, optional, default=1): Number of convolutional/linear
1211
+ groups.
1212
+ padding_mode (string, optional, default='zeros'): Type of padding:
1213
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
1214
+ weight_norm_type (str, optional, default='none'):
1215
+ Type of weight normalization.
1216
+ ``'none'``, ``'spectral'``, ``'weight'``
1217
+ or ``'weight_demod'``.
1218
+ weight_norm_params (obj, optional, default=None):
1219
+ Parameters of weight normalization.
1220
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
1221
+ keyword arguments when initializing weight normalization.
1222
+ activation_norm_type (str, optional, default='none'):
1223
+ Type of activation normalization.
1224
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
1225
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
1226
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
1227
+ activation_norm_params (obj, optional, default=None):
1228
+ Parameters of activation normalization.
1229
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
1230
+ keyword arguments when initializing activation normalization.
1231
+ skip_activation_norm (bool, optional, default=True): If ``True`` and
1232
+ ``learn_shortcut`` is also ``True``, applies activation norm to the
1233
+ learned shortcut connection.
1234
+ skip_nonlinearity (bool, optional, default=True): If ``True`` and
1235
+ ``learn_shortcut`` is also ``True``, applies nonlinearity to the
1236
+ learned shortcut connection.
1237
+ nonlinearity (str, optional, default='none'):
1238
+ Type of nonlinear activation function in the residual link.
1239
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
1240
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
1241
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
1242
+ set ``inplace=True`` when initializing the nonlinearity layers.
1243
+ apply_noise (bool, optional, default=False): If ``True``, adds
1244
+ Gaussian noise with learnable magnitude to the convolution output.
1245
+ hidden_channels_equal_out_channels (bool, optional, default=False):
1246
+ If ``True``, set the hidden channel number to be equal to the
1247
+ output channel number. If ``False``, the hidden channel number
1248
+ equals to the smaller of the input channel number and the
1249
+ output channel number.
1250
+ order (str, optional, default='CNACNA'): Order of operations
1251
+ in the residual link.
1252
+ ``'C'``: convolution,
1253
+ ``'N'``: normalization,
1254
+ ``'A'``: nonlinear activation.
1255
+ learn_shortcut (bool, optional, default=False): If ``True``, always use
1256
+ a convolutional shortcut instead of an identity one, otherwise only
1257
+ use a convolutional one if input and output have different number of
1258
+ channels.
1259
+ """
1260
+
1261
+ def __init__(self, in_channels, out_channels, kernel_size=3,
1262
+ stride=1, padding=1, dilation=1, groups=1, bias=True,
1263
+ padding_mode='zeros',
1264
+ weight_norm_type='none', weight_norm_params=None,
1265
+ activation_norm_type='none', activation_norm_params=None,
1266
+ skip_activation_norm=True, skip_nonlinearity=False,
1267
+ nonlinearity='leakyrelu', inplace_nonlinearity=False,
1268
+ multi_channel=False, return_mask=True,
1269
+ apply_noise=False, hidden_channels_equal_out_channels=False,
1270
+ order='CNACNA', learn_shortcut=None, clamp=None,
1271
+ output_scale=1):
1272
+ super().__init__(in_channels, out_channels, kernel_size,
1273
+ stride, padding, dilation, groups, bias,
1274
+ padding_mode, weight_norm_type, weight_norm_params,
1275
+ activation_norm_type, activation_norm_params,
1276
+ skip_activation_norm, skip_nonlinearity,
1277
+ nonlinearity, inplace_nonlinearity, multi_channel,
1278
+ return_mask, apply_noise,
1279
+ hidden_channels_equal_out_channels,
1280
+ order, PartialConv3dBlock, learn_shortcut, clamp,
1281
+ output_scale)
1282
+
1283
+
1284
+ class _BaseMultiOutResBlock(_BaseResBlock):
1285
+ r"""An abstract class for residual blocks that can return multiple outputs.
1286
+ """
1287
+
1288
+ def __init__(self, in_channels, out_channels, kernel_size,
1289
+ stride, padding, dilation, groups, bias, padding_mode,
1290
+ weight_norm_type, weight_norm_params,
1291
+ activation_norm_type, activation_norm_params,
1292
+ skip_activation_norm, skip_nonlinearity,
1293
+ nonlinearity, inplace_nonlinearity,
1294
+ apply_noise, hidden_channels_equal_out_channels,
1295
+ order, block, learn_shortcut, clamp=None, output_scale=1,
1296
+ blur=False, upsample_first=True):
1297
+ self.multiple_outputs = True
1298
+ super().__init__(in_channels, out_channels, kernel_size, stride,
1299
+ padding, dilation, groups, bias, padding_mode,
1300
+ weight_norm_type, weight_norm_params,
1301
+ activation_norm_type, activation_norm_params,
1302
+ skip_activation_norm, skip_nonlinearity, nonlinearity,
1303
+ inplace_nonlinearity, apply_noise,
1304
+ hidden_channels_equal_out_channels, order, block,
1305
+ learn_shortcut, clamp, output_scale, blur=blur,
1306
+ upsample_first=upsample_first)
1307
+
1308
+ def forward(self, x, *cond_inputs):
1309
+ r"""
1310
+
1311
+ Args:
1312
+ x (tensor): Input tensor.
1313
+ cond_inputs (list of tensors) : Conditional input tensors.
1314
+ Returns:
1315
+ (tuple):
1316
+ - output (tensor): Output tensor.
1317
+ - aux_outputs_0 (tensor): Auxiliary output of the first block.
1318
+ - aux_outputs_1 (tensor): Auxiliary output of the second block.
1319
+ """
1320
+ dx, aux_outputs_0 = self.conv_block_0(x, *cond_inputs)
1321
+ dx, aux_outputs_1 = self.conv_block_1(dx, *cond_inputs)
1322
+ if self.learn_shortcut:
1323
+ # We are not using the auxiliary outputs of self.conv_block_s.
1324
+ x_shortcut, _ = self.conv_block_s(x, *cond_inputs)
1325
+ else:
1326
+ x_shortcut = x
1327
+ output = x_shortcut + dx
1328
+ return self.output_scale * output, aux_outputs_0, aux_outputs_1
1329
+
1330
+
1331
+ class MultiOutRes2dBlock(_BaseMultiOutResBlock):
1332
+ r"""Residual block for 2D input. It can return multiple outputs, if some
1333
+ layers in the block return more than one output.
1334
+
1335
+ Args:
1336
+ in_channels (int) : Number of channels in the input tensor.
1337
+ out_channels (int) : Number of channels in the output tensor.
1338
+ kernel_size (int, optional, default=3): Kernel size for the
1339
+ convolutional filters in the residual link.
1340
+ padding (int, optional, default=1): Padding size.
1341
+ dilation (int, optional, default=1): Dilation factor.
1342
+ groups (int, optional, default=1): Number of convolutional/linear
1343
+ groups.
1344
+ padding_mode (string, optional, default='zeros'): Type of padding:
1345
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
1346
+ weight_norm_type (str, optional, default='none'):
1347
+ Type of weight normalization.
1348
+ ``'none'``, ``'spectral'``, ``'weight'``
1349
+ or ``'weight_demod'``.
1350
+ weight_norm_params (obj, optional, default=None):
1351
+ Parameters of weight normalization.
1352
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
1353
+ keyword arguments when initializing weight normalization.
1354
+ activation_norm_type (str, optional, default='none'):
1355
+ Type of activation normalization.
1356
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
1357
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
1358
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
1359
+ activation_norm_params (obj, optional, default=None):
1360
+ Parameters of activation normalization.
1361
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
1362
+ keyword arguments when initializing activation normalization.
1363
+ skip_activation_norm (bool, optional, default=True): If ``True`` and
1364
+ ``learn_shortcut`` is also ``True``, applies activation norm to the
1365
+ learned shortcut connection.
1366
+ skip_nonlinearity (bool, optional, default=True): If ``True`` and
1367
+ ``learn_shortcut`` is also ``True``, applies nonlinearity to the
1368
+ learned shortcut connection.
1369
+ nonlinearity (str, optional, default='none'):
1370
+ Type of nonlinear activation function in the residual link.
1371
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
1372
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
1373
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
1374
+ set ``inplace=True`` when initializing the nonlinearity layers.
1375
+ apply_noise (bool, optional, default=False): If ``True``, adds
1376
+ Gaussian noise with learnable magnitude to the convolution output.
1377
+ hidden_channels_equal_out_channels (bool, optional, default=False):
1378
+ If ``True``, set the hidden channel number to be equal to the
1379
+ output channel number. If ``False``, the hidden channel number
1380
+ equals to the smaller of the input channel number and the
1381
+ output channel number.
1382
+ order (str, optional, default='CNACNA'): Order of operations
1383
+ in the residual link.
1384
+ ``'C'``: convolution,
1385
+ ``'N'``: normalization,
1386
+ ``'A'``: nonlinear activation.
1387
+ learn_shortcut (bool, optional, default=False): If ``True``, always use
1388
+ a convolutional shortcut instead of an identity one, otherwise only
1389
+ use a convolutional one if input and output have different number of
1390
+ channels.
1391
+ """
1392
+
1393
+ def __init__(self, in_channels, out_channels, kernel_size=3,
1394
+ stride=1, padding=1, dilation=1, groups=1, bias=True,
1395
+ padding_mode='zeros',
1396
+ weight_norm_type='none', weight_norm_params=None,
1397
+ activation_norm_type='none', activation_norm_params=None,
1398
+ skip_activation_norm=True, skip_nonlinearity=False,
1399
+ nonlinearity='leakyrelu', inplace_nonlinearity=False,
1400
+ apply_noise=False, hidden_channels_equal_out_channels=False,
1401
+ order='CNACNA', learn_shortcut=None, clamp=None,
1402
+ output_scale=1, blur=False, upsample_first=True):
1403
+ super().__init__(in_channels, out_channels, kernel_size, stride,
1404
+ padding, dilation, groups, bias, padding_mode,
1405
+ weight_norm_type, weight_norm_params,
1406
+ activation_norm_type, activation_norm_params,
1407
+ skip_activation_norm, skip_nonlinearity, nonlinearity,
1408
+ inplace_nonlinearity, apply_noise,
1409
+ hidden_channels_equal_out_channels, order,
1410
+ MultiOutConv2dBlock, learn_shortcut, clamp,
1411
+ output_scale, blur=blur, upsample_first=upsample_first)
imaginaire/layers/residual_deep.py ADDED
@@ -0,0 +1,346 @@
1
+ # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, check out LICENSE.md
5
+ import torch
6
+ from torch import nn
7
+ from torch.utils.checkpoint import checkpoint
8
+
9
+ from imaginaire.third_party.upfirdn2d import BlurDownsample, BlurUpsample
10
+ from .conv import Conv2dBlock
11
+
12
+
13
+ class _BaseDeepResBlock(nn.Module):
14
+ def __init__(self, in_channels, out_channels, kernel_size,
15
+ stride, padding, dilation, groups, bias, padding_mode,
16
+ weight_norm_type, weight_norm_params,
17
+ activation_norm_type, activation_norm_params,
18
+ skip_activation_norm, skip_nonlinearity,
19
+ nonlinearity, inplace_nonlinearity, apply_noise,
20
+ hidden_channels_equal_out_channels,
21
+ order, block, learn_shortcut, output_scale, skip_block=None,
22
+ blur=True, border_free=True, resample_first=True,
23
+ skip_weight_norm=True, hidden_channel_ratio=4):
24
+ super().__init__()
25
+ self.in_channels = in_channels
26
+ self.out_channels = out_channels
27
+ self.output_scale = output_scale
28
+ self.resample_first = resample_first
29
+ self.stride = stride
30
+ self.blur = blur
31
+ self.border_free = border_free
32
+ assert not border_free
33
+ if skip_block is None:
34
+ skip_block = block
35
+
36
+ if order == 'pre_act':
37
+ order = 'NACNAC'
38
+ if isinstance(bias, bool):
39
+ # The bias for conv_block_0, conv_block_1, and conv_block_s.
40
+ biases = [bias, bias, bias]
41
+ elif isinstance(bias, list):
42
+ if len(bias) == 3:
43
+ biases = bias
44
+ else:
45
+ raise ValueError('Bias list must have 3 elements.')
46
+ else:
47
+ raise ValueError('Bias must be either a boolean or a list.')
48
+ self.learn_shortcut = learn_shortcut
49
+ if len(order) > 6 or len(order) < 5:
50
+ raise ValueError('order must be either 5 or 6 characters')
51
+ hidden_channels = in_channels // hidden_channel_ratio
52
+
53
+ # Parameters.
54
+ residual_params = {}
55
+ shortcut_params = {}
56
+ base_params = dict(dilation=dilation,
57
+ groups=groups,
58
+ padding_mode=padding_mode)
59
+ residual_params.update(base_params)
60
+ residual_params.update(
61
+ dict(activation_norm_type=activation_norm_type,
62
+ activation_norm_params=activation_norm_params,
63
+ weight_norm_type=weight_norm_type,
64
+ weight_norm_params=weight_norm_params,
65
+ apply_noise=apply_noise)
66
+ )
67
+ shortcut_params.update(base_params)
68
+ shortcut_params.update(dict(kernel_size=1))
69
+ if skip_activation_norm:
70
+ shortcut_params.update(
71
+ dict(activation_norm_type=activation_norm_type,
72
+ activation_norm_params=activation_norm_params,
73
+ apply_noise=False))
74
+ if skip_weight_norm:
75
+ shortcut_params.update(
76
+ dict(weight_norm_type=weight_norm_type,
77
+ weight_norm_params=weight_norm_params))
78
+
79
+ # Residual branch.
80
+ if order.find('A') < order.find('C') and \
81
+ (activation_norm_type == '' or activation_norm_type == 'none'):
82
+ # Nonlinearity is the first operation in the residual path.
83
+ # In-place nonlinearity will modify the input variable and cause
84
+ # backward error.
85
+ first_inplace = False
86
+ else:
87
+ first_inplace = inplace_nonlinearity
88
+
89
+ (first_stride, second_stride, shortcut_stride,
90
+ first_blur, second_blur, shortcut_blur) = self._get_stride_blur()
91
+
92
+ self.conv_block_1x1_in = block(
93
+ in_channels, hidden_channels,
94
+ 1, 1, 0,
95
+ bias=biases[0],
96
+ nonlinearity=nonlinearity,
97
+ order=order[0:3],
98
+ inplace_nonlinearity=first_inplace,
99
+ **residual_params
100
+ )
101
+
102
+ self.conv_block_0 = block(
103
+ hidden_channels, hidden_channels,
104
+ kernel_size=2 if self.border_free and first_stride < 1 else
105
+ kernel_size,
106
+ padding=padding,
107
+ bias=biases[0],
108
+ nonlinearity=nonlinearity,
109
+ order=order[0:3],
110
+ inplace_nonlinearity=inplace_nonlinearity,
111
+ stride=first_stride,
112
+ blur=first_blur,
113
+ **residual_params
114
+ )
115
+ self.conv_block_1 = block(
116
+ hidden_channels, hidden_channels,
117
+ kernel_size=kernel_size,
118
+ padding=padding,
119
+ bias=biases[1],
120
+ nonlinearity=nonlinearity,
121
+ order=order[3:],
122
+ inplace_nonlinearity=inplace_nonlinearity,
123
+ stride=second_stride,
124
+ blur=second_blur,
125
+ **residual_params
126
+ )
127
+
128
+ self.conv_block_1x1_out = block(
129
+ hidden_channels, out_channels,
130
+ 1, 1, 0,
131
+ bias=biases[1],
132
+ nonlinearity=nonlinearity,
133
+ order=order[0:3],
134
+ inplace_nonlinearity=inplace_nonlinearity,
135
+ **residual_params
136
+ )
137
+
138
+ # Shortcut branch.
139
+ if self.learn_shortcut:
140
+ if skip_nonlinearity:
141
+ skip_nonlinearity_type = nonlinearity
142
+ else:
143
+ skip_nonlinearity_type = ''
144
+ self.conv_block_s = skip_block(in_channels, out_channels,
145
+ bias=biases[2],
146
+ nonlinearity=skip_nonlinearity_type,
147
+ order=order[0:3],
148
+ stride=shortcut_stride,
149
+ blur=shortcut_blur,
150
+ **shortcut_params)
151
+ elif in_channels < out_channels:
152
+ if skip_nonlinearity:
153
+ skip_nonlinearity_type = nonlinearity
154
+ else:
155
+ skip_nonlinearity_type = ''
156
+ self.conv_block_s = skip_block(in_channels,
157
+ out_channels - in_channels,
158
+ bias=biases[2],
159
+ nonlinearity=skip_nonlinearity_type,
160
+ order=order[0:3],
161
+ stride=shortcut_stride,
162
+ blur=shortcut_blur,
163
+ **shortcut_params)
164
+
165
+ # Whether this block expects conditional inputs.
166
+ self.conditional = \
167
+ getattr(self.conv_block_0, 'conditional', False) or \
168
+ getattr(self.conv_block_1, 'conditional', False) or \
169
+ getattr(self.conv_block_1x1_in, 'conditional', False) or \
170
+ getattr(self.conv_block_1x1_out, 'conditional', False)
171
+
172
+ def _get_stride_blur(self):
173
+ if self.stride > 1:
174
+ # Downsampling.
175
+ first_stride, second_stride = 1, self.stride
176
+ first_blur, second_blur = False, self.blur
177
+ shortcut_blur = False
178
+ shortcut_stride = 1
179
+ if self.blur:
180
+ # The shortcut branch uses blur_downsample + stride-1 conv
181
+ if self.border_free:
182
+ self.resample = nn.AvgPool2d(2)
183
+ else:
184
+ self.resample = BlurDownsample()
185
+ else:
186
+ shortcut_stride = self.stride
187
+ self.resample = nn.AvgPool2d(2)
188
+ elif self.stride < 1:
189
+ # Upsampling.
190
+ first_stride, second_stride = self.stride, 1
191
+ first_blur, second_blur = self.blur, False
192
+ shortcut_blur = False
193
+ shortcut_stride = 1
194
+ if self.blur:
195
+ # The shortcut branch uses blur_upsample + stride-1 conv
196
+ if self.border_free:
197
+ self.resample = nn.Upsample(scale_factor=2,
198
+ mode='bilinear')
199
+ else:
200
+ self.resample = BlurUpsample()
201
+ else:
202
+ shortcut_stride = self.stride
203
+ self.resample = nn.Upsample(scale_factor=2)
204
+ else:
205
+ first_stride = second_stride = 1
206
+ first_blur = second_blur = False
207
+ shortcut_stride = 1
208
+ shortcut_blur = False
209
+ self.resample = None
210
+ return (first_stride, second_stride, shortcut_stride,
211
+ first_blur, second_blur, shortcut_blur)
212
+
213
+ def conv_blocks(
214
+ self, x, *cond_inputs, separate_cond=False, **kw_cond_inputs
215
+ ):
216
+ if separate_cond:
217
+ assert len(list(cond_inputs)) == 4
218
+ dx = self.conv_block_1x1_in(x, cond_inputs[0],
219
+ **kw_cond_inputs.get('kwargs_0', {}))
220
+ dx = self.conv_block_0(dx, cond_inputs[1],
221
+ **kw_cond_inputs.get('kwargs_1', {}))
222
+ dx = self.conv_block_1(dx, cond_inputs[2],
223
+ **kw_cond_inputs.get('kwargs_2', {}))
224
+ dx = self.conv_block_1x1_out(dx, cond_inputs[3],
225
+ **kw_cond_inputs.get('kwargs_3', {}))
226
+ else:
227
+ dx = self.conv_block_1x1_in(x, *cond_inputs, **kw_cond_inputs)
228
+ dx = self.conv_block_0(dx, *cond_inputs, **kw_cond_inputs)
229
+ dx = self.conv_block_1(dx, *cond_inputs, **kw_cond_inputs)
230
+ dx = self.conv_block_1x1_out(dx, *cond_inputs, **kw_cond_inputs)
231
+ return dx
232
+
233
+ def forward(self, x, *cond_inputs, do_checkpoint=False, **kw_cond_inputs):
234
+ if do_checkpoint:
235
+ dx = checkpoint(self.conv_blocks, x, *cond_inputs, **kw_cond_inputs)
236
+ else:
237
+ dx = self.conv_blocks(x, *cond_inputs, **kw_cond_inputs)
238
+
239
+ if self.resample_first and self.resample is not None:
240
+ x = self.resample(x)
241
+ if self.learn_shortcut:
242
+ x_shortcut = self.conv_block_s(
243
+ x, *cond_inputs, **kw_cond_inputs
244
+ )
245
+ elif self.in_channels < self.out_channels:
246
+ x_shortcut_pad = self.conv_block_s(
247
+ x, *cond_inputs, **kw_cond_inputs
248
+ )
249
+ x_shortcut = torch.cat((x, x_shortcut_pad), dim=1)
250
+ elif self.in_channels > self.out_channels:
251
+ x_shortcut = x[:, :self.out_channels, :, :]
252
+ else:
253
+ x_shortcut = x
254
+ if not self.resample_first and self.resample is not None:
255
+ x_shortcut = self.resample(x_shortcut)
256
+
257
+ output = x_shortcut + dx
258
+ return self.output_scale * output
259
+
260
+ def extra_repr(self):
261
+ s = 'output_scale={output_scale}'
262
+ return s.format(**self.__dict__)
263
+
264
+
265
+ class DeepRes2dBlock(_BaseDeepResBlock):
266
+ r"""Residual block for 2D input.
267
+
268
+ Args:
269
+ in_channels (int) : Number of channels in the input tensor.
270
+ out_channels (int) : Number of channels in the output tensor.
271
+ kernel_size (int, optional, default=3): Kernel size for the
272
+ convolutional filters in the residual link.
273
+ padding (int, optional, default=1): Padding size.
274
+ dilation (int, optional, default=1): Dilation factor.
275
+ groups (int, optional, default=1): Number of convolutional/linear
276
+ groups.
277
+ padding_mode (string, optional, default='zeros'): Type of padding:
278
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
279
+ weight_norm_type (str, optional, default='none'):
280
+ Type of weight normalization.
281
+ ``'none'``, ``'spectral'``, ``'weight'``
282
+ or ``'weight_demod'``.
283
+ weight_norm_params (obj, optional, default=None):
284
+ Parameters of weight normalization.
285
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
286
+ keyword arguments when initializing weight normalization.
287
+ activation_norm_type (str, optional, default='none'):
288
+ Type of activation normalization.
289
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
290
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
291
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
292
+ activation_norm_params (obj, optional, default=None):
293
+ Parameters of activation normalization.
294
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
295
+ keyword arguments when initializing activation normalization.
296
+ skip_activation_norm (bool, optional, default=True): If ``True`` and
297
+ ``learn_shortcut`` is also ``True``, applies activation norm to the
298
+ learned shortcut connection.
299
+ skip_nonlinearity (bool, optional, default=True): If ``True`` and
300
+ ``learn_shortcut`` is also ``True``, applies nonlinearity to the
301
+ learned shortcut connection.
302
+ nonlinearity (str, optional, default='none'):
303
+ Type of nonlinear activation function in the residual link.
304
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
305
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
306
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
307
+ set ``inplace=True`` when initializing the nonlinearity layers.
308
+ apply_noise (bool, optional, default=False): If ``True``, adds
309
+ Gaussian noise with learnable magnitude to the convolution output.
310
+ hidden_channels_equal_out_channels (bool, optional, default=False):
311
+ If ``True``, set the hidden channel number to be equal to the
312
+ output channel number. If ``False``, the hidden channel number
313
+ equals to the smaller of the input channel number and the
314
+ output channel number.
315
+ order (str, optional, default='CNACNA'): Order of operations
316
+ in the residual link.
317
+ ``'C'``: convolution,
318
+ ``'N'``: normalization,
319
+ ``'A'``: nonlinear activation.
320
+ learn_shortcut (bool, optional, default=False): If ``True``, always use
321
+ a convolutional shortcut instead of an identity one, otherwise only
322
+ use a convolutional one if input and output have different number of
323
+ channels.
324
+ """
325
+
326
+ def __init__(self, in_channels, out_channels, kernel_size=3,
327
+ stride=1, padding=1, dilation=1, groups=1, bias=True,
328
+ padding_mode='zeros',
329
+ weight_norm_type='none', weight_norm_params=None,
330
+ activation_norm_type='none', activation_norm_params=None,
331
+ skip_activation_norm=True, skip_nonlinearity=False,
332
+ skip_weight_norm=True,
333
+ nonlinearity='leakyrelu', inplace_nonlinearity=False,
334
+ apply_noise=False, hidden_channels_equal_out_channels=False,
335
+ order='CNACNA', learn_shortcut=False, output_scale=1,
336
+ blur=True, resample_first=True, border_free=False):
337
+ super().__init__(in_channels, out_channels, kernel_size, stride,
338
+ padding, dilation, groups, bias, padding_mode,
339
+ weight_norm_type, weight_norm_params,
340
+ activation_norm_type, activation_norm_params,
341
+ skip_activation_norm, skip_nonlinearity, nonlinearity,
342
+ inplace_nonlinearity, apply_noise,
343
+ hidden_channels_equal_out_channels, order, Conv2dBlock,
344
+ learn_shortcut, output_scale, blur=blur,
345
+ resample_first=resample_first, border_free=border_free,
346
+ skip_weight_norm=skip_weight_norm)
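A minimal usage sketch for the DeepRes2dBlock above, assuming the bundled imaginaire package (including its third_party upfirdn2d ops) is importable; channel counts and shapes are illustrative:

import torch
from imaginaire.layers.residual_deep import DeepRes2dBlock

# Bottleneck residual block: 1x1 -> 3x3 -> 3x3 -> 1x1 on a hidden width of
# in_channels // 4, with an identity shortcut (in_channels == out_channels).
block = DeepRes2dBlock(256, 256, kernel_size=3,
                       weight_norm_type='spectral',
                       activation_norm_type='instance')
x = torch.randn(2, 256, 64, 64)
y = block(x)            # spatial size preserved (stride=1, no resampling)
print(y.shape)          # torch.Size([2, 256, 64, 64])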
imaginaire/layers/vit.py ADDED
@@ -0,0 +1,204 @@
1
+ # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, check out LICENSE.md
5
+ from types import SimpleNamespace
6
+
7
+ import torch
8
+ from torch import nn
9
+
10
+ from .misc import ApplyNoise
11
+ from imaginaire.third_party.upfirdn2d.upfirdn2d import Blur
12
+
13
+
14
+ class ViT2dBlock(nn.Module):
15
+ r"""An abstract wrapper class that wraps a torch convolution or linear layer
16
+ with normalization and nonlinearity.
17
+ """
18
+
19
+ def __init__(self, in_channels, out_channels, kernel_size, stride,
20
+ padding, dilation, groups, bias, padding_mode,
21
+ weight_norm_type, weight_norm_params,
22
+ activation_norm_type, activation_norm_params,
23
+ nonlinearity, inplace_nonlinearity,
24
+ apply_noise, blur, order, input_dim, clamp,
25
+ blur_kernel=(1, 3, 3, 1), output_scale=None,
26
+ init_gain=1.0):
27
+ super().__init__()
28
+ from .nonlinearity import get_nonlinearity_layer
29
+ from .weight_norm import get_weight_norm_layer
30
+ from .activation_norm import get_activation_norm_layer
31
+ self.weight_norm_type = weight_norm_type
32
+ self.stride = stride
33
+ self.clamp = clamp
34
+ self.init_gain = init_gain
35
+
36
+ # Nonlinearity layer.
37
+ if 'fused' in nonlinearity:
38
+ # Fusing nonlinearity with bias.
39
+ lr_mul = getattr(weight_norm_params, 'lr_mul', 1)
40
+ conv_before_nonlinearity = order.find('C') < order.find('A')
41
+ if conv_before_nonlinearity:
42
+ assert bias
43
+ bias = False
44
+ channel = out_channels if conv_before_nonlinearity else in_channels
45
+ nonlinearity_layer = get_nonlinearity_layer(
46
+ nonlinearity, inplace=inplace_nonlinearity,
47
+ num_channels=channel, lr_mul=lr_mul)
48
+ else:
49
+ nonlinearity_layer = get_nonlinearity_layer(
50
+ nonlinearity, inplace=inplace_nonlinearity)
51
+
52
+ # Noise injection layer.
53
+ if apply_noise:
54
+ order = order.replace('C', 'CG')
55
+ noise_layer = ApplyNoise()
56
+ else:
57
+ noise_layer = None
58
+
59
+ # Convolutional layer.
60
+ if blur:
61
+ if stride == 2:
62
+ # Blur - Conv - Noise - Activate
63
+ p = (len(blur_kernel) - 2) + (kernel_size - 1)
64
+ pad0, pad1 = (p + 1) // 2, p // 2
65
+ padding = 0
66
+ blur_layer = Blur(
67
+ blur_kernel, pad=(pad0, pad1), padding_mode=padding_mode
68
+ )
69
+ order = order.replace('C', 'BC')
70
+ elif stride == 0.5:
71
+ # Conv - Blur - Noise - Activate
72
+ padding = 0
73
+ p = (len(blur_kernel) - 2) - (kernel_size - 1)
74
+ pad0, pad1 = (p + 1) // 2 + 1, p // 2 + 1
75
+ blur_layer = Blur(
76
+ blur_kernel, pad=(pad0, pad1), padding_mode=padding_mode
77
+ )
78
+ order = order.replace('C', 'CB')
79
+ elif stride == 1:
80
+ # No blur for now
81
+ blur_layer = nn.Identity()
82
+ else:
83
+ raise NotImplementedError
84
+ else:
85
+ blur_layer = nn.Identity()
86
+
87
+ if weight_norm_params is None:
88
+ weight_norm_params = SimpleNamespace()
89
+ weight_norm = get_weight_norm_layer(
90
+ weight_norm_type, **vars(weight_norm_params))
91
+ conv_layer = weight_norm(self._get_conv_layer(
92
+ in_channels, out_channels, kernel_size, stride, padding, dilation,
93
+ groups, bias, padding_mode, input_dim))
94
+
95
+ # Normalization layer.
96
+ conv_before_norm = order.find('C') < order.find('N')
97
+ norm_channels = out_channels if conv_before_norm else in_channels
98
+ if activation_norm_params is None:
99
+ activation_norm_params = SimpleNamespace()
100
+ activation_norm_layer = get_activation_norm_layer(
101
+ norm_channels,
102
+ activation_norm_type,
103
+ input_dim,
104
+ **vars(activation_norm_params))
105
+
106
+ # Mapping from operation names to layers.
107
+ mappings = {'C': {'conv': conv_layer},
108
+ 'N': {'norm': activation_norm_layer},
109
+ 'A': {'nonlinearity': nonlinearity_layer}}
110
+ mappings.update({'B': {'blur': blur_layer}})
111
+ mappings.update({'G': {'noise': noise_layer}})
112
+
113
+ # All layers in order.
114
+ self.layers = nn.ModuleDict()
115
+ for op in order:
116
+ if list(mappings[op].values())[0] is not None:
117
+ self.layers.update(mappings[op])
118
+
119
+ # Whether this block expects conditional inputs.
120
+ self.conditional = \
121
+ getattr(conv_layer, 'conditional', False) or \
122
+ getattr(activation_norm_layer, 'conditional', False)
123
+
124
+ if output_scale is not None:
125
+ self.output_scale = nn.Parameter(torch.tensor(output_scale))
126
+ else:
127
+ self.register_parameter("output_scale", None)
128
+
129
+ def forward(self, x, *cond_inputs, **kw_cond_inputs):
130
+ r"""
131
+
132
+ Args:
133
+ x (tensor): Input tensor.
134
+ cond_inputs (list of tensors) : Conditional input tensors.
135
+ kw_cond_inputs (dict) : Keyword conditional inputs.
136
+ """
137
+ for key, layer in self.layers.items():
138
+ if getattr(layer, 'conditional', False):
139
+ # Layers that require conditional inputs.
140
+ x = layer(x, *cond_inputs, **kw_cond_inputs)
141
+ else:
142
+ x = layer(x)
143
+ if self.clamp is not None and isinstance(layer, nn.Conv2d):
144
+ x.clamp_(max=self.clamp)
145
+ if key == 'conv':
146
+ if self.output_scale is not None:
147
+ x = x * self.output_scale
148
+ return x
149
+
150
+ def _get_conv_layer(self, in_channels, out_channels, kernel_size, stride,
151
+ padding, dilation, groups, bias, padding_mode,
152
+ input_dim):
153
+ # Returns the convolutional layer.
154
+ if input_dim == 0:
155
+ layer = nn.Linear(in_channels, out_channels, bias)
156
+ else:
157
+ if stride < 1: # Fractionally-strided convolution.
158
+ padding_mode = 'zeros'
159
+ assert padding == 0
160
+ layer_type = getattr(nn, f'ConvTranspose{input_dim}d')
161
+ stride = round(1 / stride)
162
+ else:
163
+ layer_type = getattr(nn, f'Conv{input_dim}d')
164
+ layer = layer_type(
165
+ in_channels, out_channels, kernel_size, stride, padding,
166
+ dilation=dilation, groups=groups, bias=bias,
167
+ padding_mode=padding_mode
168
+ )
169
+
170
+ return layer
171
+
172
+ def __repr__(self):
173
+ main_str = self._get_name() + '('
174
+ child_lines = []
175
+ for name, layer in self.layers.items():
176
+ mod_str = repr(layer)
177
+ if name == 'conv' and self.weight_norm_type != 'none' and \
178
+ self.weight_norm_type != '':
179
+ mod_str = mod_str[:-1] + \
180
+ ', weight_norm={}'.format(self.weight_norm_type) + ')'
181
+ if name == 'conv' and getattr(layer, 'base_lr_mul', 1) != 1:
182
+ mod_str = mod_str[:-1] + \
183
+ ', lr_mul={}'.format(layer.base_lr_mul) + ')'
184
+ mod_str = self._addindent(mod_str, 2)
185
+ child_lines.append(mod_str)
186
+ if len(child_lines) == 1:
187
+ main_str += child_lines[0]
188
+ else:
189
+ main_str += '\n ' + '\n '.join(child_lines) + '\n'
190
+
191
+ main_str += ')'
192
+ return main_str
193
+
194
+ @staticmethod
195
+ def _addindent(s_, numSpaces):
196
+ s = s_.split('\n')
197
+ # don't do anything for single-line stuff
198
+ if len(s) == 1:
199
+ return s_
200
+ first = s.pop(0)
201
+ s = [(numSpaces * ' ') + line for line in s]
202
+ s = '\n'.join(s)
203
+ s = first + '\n' + s
204
+ return s
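A usage sketch for ViT2dBlock, again assuming the in-repo imports are available; the argument values are illustrative, and the ``'CNA'`` order string expands to convolution, then normalization, then activation:

import torch
from imaginaire.layers.vit import ViT2dBlock

# Conv -> (no norm) -> ReLU, assembled from the order string 'CNA'.
block = ViT2dBlock(
    in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1,
    dilation=1, groups=1, bias=True, padding_mode='zeros',
    weight_norm_type='none', weight_norm_params=None,
    activation_norm_type='none', activation_norm_params=None,
    nonlinearity='relu', inplace_nonlinearity=False,
    apply_noise=False, blur=False, order='CNA', input_dim=2, clamp=None)
y = block(torch.randn(2, 64, 32, 32))   # -> (2, 128, 32, 32)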
imaginaire/layers/weight_norm.py ADDED
@@ -0,0 +1,267 @@
1
+ # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, check out LICENSE.md
5
+ import collections
6
+ import functools
7
+
8
+ import torch
9
+ from torch import nn
10
+ from torch.nn.utils import spectral_norm, weight_norm
11
+ from torch.nn.utils.spectral_norm import SpectralNorm, \
12
+ SpectralNormStateDictHook, SpectralNormLoadStateDictPreHook
13
+
14
+ from .conv import LinearBlock
15
+
16
+
17
+ class WeightDemodulation(nn.Module):
18
+ r"""Weight demodulation in
19
+ "Analyzing and Improving the Image Quality of StyleGAN", Karras et al.
20
+
21
+ Args:
22
+ conv (torch.nn.Modules): Convolutional layer.
23
+ cond_dims (int): The number of channels in the conditional input.
24
+ eps (float, optional, default=1e-8): a value added to the
25
+ denominator for numerical stability.
26
+ adaptive_bias (bool, optional, default=False): If ``True``, adaptively
27
+ predicts bias from the conditional input.
28
+ demod (bool, optional, default=False): If ``True``, performs
29
+ weight demodulation.
30
+ """
31
+
32
+ def __init__(self, conv, cond_dims, eps=1e-8,
33
+ adaptive_bias=False, demod=True):
34
+ super().__init__()
35
+ self.conv = conv
36
+ self.adaptive_bias = adaptive_bias
37
+ if adaptive_bias:
38
+ self.conv.register_parameter('bias', None)
39
+ self.fc_beta = LinearBlock(cond_dims, self.conv.out_channels)
40
+ self.fc_gamma = LinearBlock(cond_dims, self.conv.in_channels)
41
+ self.eps = eps
42
+ self.demod = demod
43
+ self.conditional = True
44
+
45
+ def forward(self, x, y, **_kwargs):
46
+ r"""Weight demodulation forward"""
47
+ b, c, h, w = x.size()
48
+ self.conv.groups = b
49
+ gamma = self.fc_gamma(y)
50
+ gamma = gamma[:, None, :, None, None]
51
+ weight = self.conv.weight[None, :, :, :, :] * gamma
52
+
53
+ if self.demod:
54
+ d = torch.rsqrt(
55
+ (weight ** 2).sum(
56
+ dim=(2, 3, 4), keepdim=True) + self.eps)
57
+ weight = weight * d
58
+
59
+ x = x.reshape(1, -1, h, w)
60
+ _, _, *ws = weight.shape
61
+ weight = weight.reshape(b * self.conv.out_channels, *ws)
62
+ x = self.conv._conv_forward(x, weight)
63
+
64
+ x = x.reshape(-1, self.conv.out_channels, h, w)
65
+ if self.adaptive_bias:
66
+ x += self.fc_beta(y)[:, :, None, None]
67
+ return x
68
+
69
+
70
+ def weight_demod(
71
+ conv, cond_dims=256, eps=1e-8, adaptive_bias=False, demod=True):
72
+ r"""Weight demodulation."""
73
+ return WeightDemodulation(conv, cond_dims, eps, adaptive_bias, demod)
74
+
75
+
76
+ class ScaledLR(object):
77
+ def __init__(self, weight_name, bias_name):
78
+ self.weight_name = weight_name
79
+ self.bias_name = bias_name
80
+
81
+ def compute_weight(self, module):
82
+ weight = getattr(module, self.weight_name + '_ori')
83
+ return weight * module.weight_scale
84
+
85
+ def compute_bias(self, module):
86
+ bias = getattr(module, self.bias_name + '_ori')
87
+ if bias is not None:
88
+ return bias * module.bias_scale
89
+ else:
90
+ return None
91
+
92
+ @staticmethod
93
+ def apply(module, weight_name, bias_name, lr_mul, equalized):
94
+ assert weight_name == 'weight'
95
+ assert bias_name == 'bias'
96
+ fn = ScaledLR(weight_name, bias_name)
97
+ module.register_forward_pre_hook(fn)
98
+
99
+ if hasattr(module, bias_name):
100
+ # module.bias is a parameter (can be None).
101
+ bias = getattr(module, bias_name)
102
+ delattr(module, bias_name)
103
+ module.register_parameter(bias_name + '_ori', bias)
104
+ else:
105
+ # module.bias does not exist.
106
+ bias = None
107
+ setattr(module, bias_name + '_ori', bias)
108
+ if bias is not None:
109
+ setattr(module, bias_name, bias.data)
110
+ else:
111
+ setattr(module, bias_name, None)
112
+ module.register_buffer('bias_scale', torch.tensor(lr_mul))
113
+
114
+ if hasattr(module, weight_name + '_orig'):
115
+ # The module has been wrapped with spectral normalization.
116
+ # We only want to keep a single weight parameter.
117
+ weight = getattr(module, weight_name + '_orig')
118
+ delattr(module, weight_name + '_orig')
119
+ module.register_parameter(weight_name + '_ori', weight)
120
+ setattr(module, weight_name + '_orig', weight.data)
121
+ # Put this hook before the spectral norm hook.
122
+ module._forward_pre_hooks = collections.OrderedDict(
123
+ reversed(list(module._forward_pre_hooks.items()))
124
+ )
125
+ module.use_sn = True
126
+ else:
127
+ weight = getattr(module, weight_name)
128
+ delattr(module, weight_name)
129
+ module.register_parameter(weight_name + '_ori', weight)
130
+ setattr(module, weight_name, weight.data)
131
+ module.use_sn = False
132
+
133
+ # assert weight.dim() == 4 or weight.dim() == 2
134
+ if equalized:
135
+ fan_in = weight.data.size(1) * weight.data[0][0].numel()
136
+ # Theoretically, the gain should be sqrt(2) instead of 1.
137
+ # The official StyleGAN2 uses 1 for some reason.
138
+ module.register_buffer(
139
+ 'weight_scale', torch.tensor(lr_mul * ((1 / fan_in) ** 0.5))
140
+ )
141
+ else:
142
+ module.register_buffer('weight_scale', torch.tensor(lr_mul))
143
+
144
+ module.lr_mul = module.weight_scale
145
+ module.base_lr_mul = lr_mul
146
+
147
+ return fn
148
+
149
+ def remove(self, module):
150
+ with torch.no_grad():
151
+ weight = self.compute_weight(module)
152
+ delattr(module, self.weight_name + '_ori')
153
+
154
+ if module.use_sn:
155
+ setattr(module, self.weight_name + '_orig', weight.detach())
156
+ else:
157
+ delattr(module, self.weight_name)
158
+ module.register_parameter(self.weight_name,
159
+ torch.nn.Parameter(weight.detach()))
160
+
161
+ with torch.no_grad():
162
+ bias = self.compute_bias(module)
163
+ delattr(module, self.bias_name)
164
+ delattr(module, self.bias_name + '_ori')
165
+ if bias is not None:
166
+ module.register_parameter(self.bias_name,
167
+ torch.nn.Parameter(bias.detach()))
168
+ else:
169
+ module.register_parameter(self.bias_name, None)
170
+
171
+ module.lr_mul = 1.0
172
+ module.base_lr_mul = 1.0
173
+
174
+ def __call__(self, module, input):
175
+ weight = self.compute_weight(module)
176
+ if module.use_sn:
177
+ # The following spectral norm hook will compute the SN of
178
+ # "module.weight_orig" and store the normalized weight in
179
+ # "module.weight".
180
+ setattr(module, self.weight_name + '_orig', weight)
181
+ else:
182
+ setattr(module, self.weight_name, weight)
183
+ bias = self.compute_bias(module)
184
+ setattr(module, self.bias_name, bias)
185
+
186
+
187
+ def remove_weight_norms(module, weight_name='weight', bias_name='bias'):
188
+ if hasattr(module, 'weight_ori') or hasattr(module, 'weight_orig'):
189
+ for k in list(module._forward_pre_hooks.keys()):
190
+ hook = module._forward_pre_hooks[k]
191
+ if (isinstance(hook, ScaledLR) or isinstance(hook, SpectralNorm)):
192
+ hook.remove(module)
193
+ del module._forward_pre_hooks[k]
194
+
195
+ for k, hook in module._state_dict_hooks.items():
196
+ if isinstance(hook, SpectralNormStateDictHook) and \
197
+ hook.fn.name == weight_name:
198
+ del module._state_dict_hooks[k]
199
+ break
200
+
201
+ for k, hook in module._load_state_dict_pre_hooks.items():
202
+ if isinstance(hook, SpectralNormLoadStateDictPreHook) and \
203
+ hook.fn.name == weight_name:
204
+ del module._load_state_dict_pre_hooks[k]
205
+ break
206
+
207
+ return module
208
+
209
+
210
+ def remove_equalized_lr(module, weight_name='weight', bias_name='bias'):
211
+ for k, hook in module._forward_pre_hooks.items():
212
+ if isinstance(hook, ScaledLR) and hook.weight_name == weight_name:
213
+ hook.remove(module)
214
+ del module._forward_pre_hooks[k]
215
+ break
216
+ else:
217
+ raise ValueError("Equalized learning rate not found")
218
+
219
+ return module
220
+
221
+
222
+ def scaled_lr(
223
+ module, weight_name='weight', bias_name='bias', lr_mul=1.,
224
+ equalized=False,
225
+ ):
226
+ ScaledLR.apply(module, weight_name, bias_name, lr_mul, equalized)
227
+ return module
228
+
229
+
230
+ def get_weight_norm_layer(norm_type, **norm_params):
231
+ r"""Return weight normalization.
232
+
233
+ Args:
234
+ norm_type (str):
235
+ Type of weight normalization.
236
+ ``'none'``, ``'spectral'``, ``'weight'``
237
+ or ``'weight_demod'``.
238
+ norm_params: Arbitrary keyword arguments that will be used to
239
+ initialize the weight normalization.
240
+ """
241
+ if norm_type == 'none' or norm_type == '': # no normalization
242
+ return lambda x: x
243
+ elif norm_type == 'spectral': # spectral normalization
244
+ return functools.partial(spectral_norm, **norm_params)
245
+ elif norm_type == 'weight': # weight normalization
246
+ return functools.partial(weight_norm, **norm_params)
247
+ elif norm_type == 'weight_demod': # weight demodulation
248
+ return functools.partial(weight_demod, **norm_params)
249
+ elif norm_type == 'equalized_lr': # equalized learning rate
250
+ return functools.partial(scaled_lr, equalized=True, **norm_params)
251
+ elif norm_type == 'scaled_lr':  # scaled learning rate
252
+ return functools.partial(scaled_lr, **norm_params)
253
+ elif norm_type == 'equalized_lr_spectral':
254
+ lr_mul = norm_params.pop('lr_mul', 1.0)
255
+ return lambda x: functools.partial(
256
+ scaled_lr, equalized=True, lr_mul=lr_mul)(
257
+ functools.partial(spectral_norm, **norm_params)(x)
258
+ )
259
+ elif norm_type == 'scaled_lr_spectral':
260
+ lr_mul = norm_params.pop('lr_mul', 1.0)
261
+ return lambda x: functools.partial(
262
+ scaled_lr, lr_mul=lr_mul)(
263
+ functools.partial(spectral_norm, **norm_params)(x)
264
+ )
265
+ else:
266
+ raise ValueError(
267
+ 'Weight norm layer %s is not recognized' % norm_type)
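A sketch of how get_weight_norm_layer can be used as a factory over plain PyTorch modules; the module sizes and lr_mul value are illustrative:

import torch
from torch import nn
from imaginaire.layers.weight_norm import get_weight_norm_layer

# Spectral normalization via the factory.
sn = get_weight_norm_layer('spectral')
conv = sn(nn.Conv2d(3, 64, 3, padding=1))

# StyleGAN-style equalized learning rate with a 0.01 multiplier: the stored
# weight is rescaled by lr_mul / sqrt(fan_in) inside a forward pre-hook.
eq = get_weight_norm_layer('equalized_lr', lr_mul=0.01)
fc = eq(nn.Linear(512, 512))
out = fc(torch.randn(4, 512))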
imaginaire/losses/__init__.py ADDED
@@ -0,0 +1,18 @@
1
+ # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, check out LICENSE.md
5
+ from .gan import GANLoss
6
+ from .perceptual import PerceptualLoss
7
+ from .feature_matching import FeatureMatchingLoss
8
+ from .kl import GaussianKLLoss
9
+
10
+ __all__ = ['GANLoss', 'PerceptualLoss', 'FeatureMatchingLoss', 'GaussianKLLoss',
11
+ 'MaskedL1Loss', 'FlowLoss', 'DictLoss',
12
+ 'WeightedMSELoss']
13
+
14
+ try:
15
+ from .gradient_penalty import GradientPenaltyLoss
16
+ __all__.extend(['GradientPenaltyLoss'])
17
+ except: # noqa
18
+ pass
imaginaire/losses/feature_matching.py ADDED
@@ -0,0 +1,38 @@
1
+ # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, check out LICENSE.md
5
+ import torch.nn as nn
6
+
7
+
8
+ class FeatureMatchingLoss(nn.Module):
9
+ r"""Compute feature matching loss"""
10
+ def __init__(self, criterion='l1'):
11
+ super(FeatureMatchingLoss, self).__init__()
12
+ if criterion == 'l1':
13
+ self.criterion = nn.L1Loss()
14
+ elif criterion == 'l2' or criterion == 'mse':
15
+ self.criterion = nn.MSELoss()
16
+ else:
17
+ raise ValueError('Criterion %s is not recognized' % criterion)
18
+
19
+ def forward(self, fake_features, real_features):
20
+ r"""Compute the feature matching loss between fake and real
21
+ discriminator features.
22
+
23
+ Args:
24
+ fake_features (list of lists): Discriminator features of fake images.
25
+ real_features (list of lists): Discriminator features of real images.
26
+
27
+ Returns:
28
+ (tensor): Loss value.
29
+ """
30
+ num_d = len(fake_features)
31
+ dis_weight = 1.0 / num_d
32
+ loss = fake_features[0][0].new_tensor(0)
33
+ for i in range(num_d):
34
+ for j in range(len(fake_features[i])):
35
+ tmp_loss = self.criterion(fake_features[i][j],
36
+ real_features[i][j].detach())
37
+ loss += dis_weight * tmp_loss
38
+ return loss
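A small sketch of the expected input structure (a list of per-discriminator feature lists), with random tensors standing in for discriminator activations:

import torch
from imaginaire.losses.feature_matching import FeatureMatchingLoss

criterion = FeatureMatchingLoss(criterion='l1')
# Two discriminator scales, each exposing three intermediate feature maps.
fake_feats = [[torch.randn(4, c, 32, 32) for c in (64, 128, 256)] for _ in range(2)]
real_feats = [[torch.randn(4, c, 32, 32) for c in (64, 128, 256)] for _ in range(2)]
loss = criterion(fake_feats, real_feats)   # scalar tensor, averaged over discriminators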
imaginaire/losses/gan.py ADDED
@@ -0,0 +1,173 @@
1
+ # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, check out LICENSE.md
5
+ import math
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from imaginaire.utils.distributed import master_only_print as print
11
+
12
+
13
+ @torch.jit.script
14
+ def fuse_math_min_mean_pos(x):
15
+ r"""Fuse operation min mean for hinge loss computation of positive
16
+ samples"""
17
+ minval = torch.min(x - 1, x * 0)
18
+ loss = -torch.mean(minval)
19
+ return loss
20
+
21
+
22
+ @torch.jit.script
23
+ def fuse_math_min_mean_neg(x):
24
+ r"""Fuse operation min mean for hinge loss computation of negative
25
+ samples"""
26
+ minval = torch.min(-x - 1, x * 0)
27
+ loss = -torch.mean(minval)
28
+ return loss
29
+
30
+
31
+ class GANLoss(nn.Module):
32
+ r"""GAN loss constructor.
33
+
34
+ Args:
35
+ gan_mode (str): Type of GAN loss. ``'hinge'``, ``'least_square'``,
36
+ ``'non_saturated'``, ``'wasserstein'``.
37
+ target_real_label (float): The desired output label for real images.
38
+ target_fake_label (float): The desired output label for fake images.
39
+ decay_k (float): The decay factor per epoch for top-k training.
40
+ min_k (float): The minimum percentage of samples to select.
41
+ separate_topk (bool): If ``True``, selects top-k for each sample
42
+ separately, otherwise selects top-k among all samples.
43
+ """
44
+ def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0,
45
+ decay_k=1., min_k=1., separate_topk=False):
46
+ super(GANLoss, self).__init__()
47
+ self.real_label = target_real_label
48
+ self.fake_label = target_fake_label
49
+ self.real_label_tensor = None
50
+ self.fake_label_tensor = None
51
+ self.gan_mode = gan_mode
52
+ self.decay_k = decay_k
53
+ self.min_k = min_k
54
+ self.separate_topk = separate_topk
55
+ self.register_buffer('k', torch.tensor(1.0))
56
+ print('GAN mode: %s' % gan_mode)
57
+
58
+ def forward(self, dis_output, t_real, dis_update=True, reduce=True):
59
+ r"""GAN loss computation.
60
+
61
+ Args:
62
+ dis_output (tensor or list of tensors): Discriminator outputs.
63
+ t_real (bool): If ``True``, uses the real label as target, otherwise uses the fake label as target.
64
+ dis_update (bool): If ``True``, the loss will be used to update the discriminator, otherwise the generator.
65
+ reduce (bool): If ``True``, when a list of discriminator outputs are provided, it will return the average
66
+ of all losses, otherwise it will return a list of losses.
67
+ Returns:
68
+ loss (tensor): Loss value.
69
+ """
70
+ if isinstance(dis_output, list):
71
+ # For multi-scale discriminators.
72
+ # In this implementation, the loss is first averaged for each scale
73
+ # (batch size and number of locations) then averaged across scales,
74
+ # so that the gradient is not dominated by the discriminator that
75
+ # has the most output values (highest resolution).
76
+ losses = []
77
+ for dis_output_i in dis_output:
78
+ assert isinstance(dis_output_i, torch.Tensor)
79
+ losses.append(self.loss(dis_output_i, t_real, dis_update))
80
+ if reduce:
81
+ return torch.mean(torch.stack(losses))
82
+ else:
83
+ return losses
84
+ else:
85
+ return self.loss(dis_output, t_real, dis_update)
86
+
87
+ def loss(self, dis_output, t_real, dis_update=True):
88
+ r"""GAN loss computation.
89
+
90
+ Args:
91
+ dis_output (tensor): Discriminator outputs.
92
+ t_real (bool): If ``True``, uses the real label as target, otherwise
93
+ uses the fake label as target.
94
+ dis_update (bool): Updating the discriminator or the generator.
95
+ Returns:
96
+ loss (tensor): Loss value.
97
+ """
98
+ if not dis_update:
99
+ assert t_real, \
100
+ "The target should be real when updating the generator."
101
+
102
+ if not dis_update and self.k < 1:
103
+ r"""
104
+ Use top-k training:
105
+ "Top-k Training of GANs: Improving GAN Performance by Throwing
106
+ Away Bad Samples"
107
+ Here, each sample may have multiple discriminator output values
108
+ (patch discriminator). We could either select top-k for each sample
109
+ separately (when ``self.separate_topk=True``), or collect values
110
+ from all samples and then select top-k (default, when
111
+ ``self.separate_topk=False``).
112
+ """
113
+ if self.separate_topk:
114
+ dis_output = dis_output.view(dis_output.size(0), -1)
115
+ else:
116
+ dis_output = dis_output.view(-1)
117
+ k = math.ceil(self.k * dis_output.size(-1))
118
+ dis_output, _ = torch.topk(dis_output, k)
119
+
120
+ if self.gan_mode == 'non_saturated':
121
+ target_tensor = self.get_target_tensor(dis_output, t_real)
122
+ loss = F.binary_cross_entropy_with_logits(dis_output,
123
+ target_tensor)
124
+ elif self.gan_mode == 'least_square':
125
+ target_tensor = self.get_target_tensor(dis_output, t_real)
126
+ loss = 0.5 * F.mse_loss(dis_output, target_tensor)
127
+ elif self.gan_mode == 'hinge':
128
+ if dis_update:
129
+ if t_real:
130
+ loss = fuse_math_min_mean_pos(dis_output)
131
+ else:
132
+ loss = fuse_math_min_mean_neg(dis_output)
133
+ else:
134
+ loss = -torch.mean(dis_output)
135
+ elif self.gan_mode == 'wasserstein':
136
+ if t_real:
137
+ loss = -torch.mean(dis_output)
138
+ else:
139
+ loss = torch.mean(dis_output)
140
+ elif self.gan_mode == 'softplus':
141
+ target_tensor = self.get_target_tensor(dis_output, t_real)
142
+ loss = F.binary_cross_entropy_with_logits(dis_output,
143
+ target_tensor)
144
+ else:
145
+ raise ValueError('Unexpected gan_mode {}'.format(self.gan_mode))
146
+ return loss
147
+
148
+ def get_target_tensor(self, dis_output, t_real):
149
+ r"""Return the target vector for the binary cross entropy loss
150
+ computation.
151
+
152
+ Args:
153
+ dis_output (tensor): Discriminator outputs.
154
+ t_real (bool): If ``True``, uses the real label as target, otherwise
155
+ uses the fake label as target.
156
+ Returns:
157
+ target (tensor): Target tensor vector.
158
+ """
159
+ if t_real:
160
+ if self.real_label_tensor is None:
161
+ self.real_label_tensor = dis_output.new_tensor(self.real_label)
162
+ return self.real_label_tensor.expand_as(dis_output)
163
+ else:
164
+ if self.fake_label_tensor is None:
165
+ self.fake_label_tensor = dis_output.new_tensor(self.fake_label)
166
+ return self.fake_label_tensor.expand_as(dis_output)
167
+
168
+ def topk_anneal(self):
169
+ r"""Anneal k after each epoch."""
170
+ if self.decay_k < 1:
171
+ # noinspection PyAttributeOutsideInit
172
+ self.k.fill_(max(self.decay_k * self.k, self.min_k))
173
+ print("Top-k training: update k to {}.".format(self.k))
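A sketch of the hinge-mode usage for both discriminator and generator updates, with random logits standing in for a patch discriminator's output:

import torch
from imaginaire.losses.gan import GANLoss

gan_loss = GANLoss('hinge')
dis_out_real = torch.randn(8, 1, 30, 30)   # patch logits on real images
dis_out_fake = torch.randn(8, 1, 30, 30)   # patch logits on fake images

# Discriminator step: push real logits above +1 and fake logits below -1.
d_loss = gan_loss(dis_out_real, t_real=True, dis_update=True) + \
         gan_loss(dis_out_fake, t_real=False, dis_update=True)

# Generator step: raise the discriminator score on fakes (target must be real).
g_loss = gan_loss(dis_out_fake, t_real=True, dis_update=False)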
imaginaire/losses/info_nce.py ADDED
@@ -0,0 +1,87 @@
1
+ # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, check out LICENSE.md
5
+ import math
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torch.distributed as dist
10
+
11
+ from imaginaire.utils.distributed import get_world_size, get_rank, \
12
+ dist_all_reduce_tensor
13
+
14
+
15
+ class GatherLayer(torch.autograd.Function):
16
+ @staticmethod
17
+ def forward(ctx, input):
18
+ ctx.save_for_backward(input)
19
+ output = [torch.zeros_like(input) for _ in range(dist.get_world_size())]
20
+ dist.all_gather(output, input)
21
+ return tuple(output)
22
+
23
+ @staticmethod
24
+ def backward(ctx, *grads):
25
+ input, = ctx.saved_tensors
26
+ grad_out = torch.zeros_like(input)
27
+ all_grads = torch.stack(grads)
28
+ all_grads = dist_all_reduce_tensor(all_grads, reduce='sum')
29
+ grad_out[:] = all_grads[get_rank()]
30
+ return grad_out
31
+
32
+
33
+ class InfoNCELoss(nn.Module):
34
+ def __init__(self,
35
+ temperature=0.07,
36
+ gather_distributed=True,
37
+ learn_temperature=True,
38
+ single_direction=False,
39
+ flatten=True):
40
+ super(InfoNCELoss, self).__init__()
41
+ self.logit_scale = nn.Parameter(torch.tensor([math.log(1/temperature)]))
42
+ self.logit_scale.requires_grad = learn_temperature
43
+ self.gather_distributed = gather_distributed
44
+ self.single_direction = single_direction
45
+ self.flatten = flatten
46
+
47
+ def forward(self, features_a, features_b, gather_distributed=None, eps=1e-8):
48
+ if gather_distributed is None:
49
+ gather_distributed = self.gather_distributed
50
+
51
+ if features_a is None or features_b is None:
52
+ return torch.tensor(0, device='cuda'), torch.tensor(0, device='cuda')
53
+
54
+ bs_a, bs_b = features_a.size(0), features_b.size(0)
55
+ if self.flatten:
56
+ features_a, features_b = features_a.reshape(bs_a, -1), features_b.reshape(bs_b, -1)
57
+ else:
58
+ features_a = features_a.reshape(bs_a, features_a.size(1), -1).mean(-1)
59
+ features_b = features_b.reshape(bs_b, features_b.size(1), -1).mean(-1)
60
+
61
+ # Temperature clipping.
62
+ self.logit_scale.data = torch.clamp(self.logit_scale.data, 0, 4.6052)
63
+
64
+ # normalized features
65
+ features_a = features_a / (features_a.norm(dim=1, keepdim=True) + eps)
66
+ features_b = features_b / (features_b.norm(dim=1, keepdim=True) + eps)
67
+
68
+ loss_a = self._forward_single_direction(features_a, features_b, gather_distributed)
69
+ if self.single_direction:
70
+ return loss_a
71
+ else:
72
+ loss_b = self._forward_single_direction(features_b, features_a, gather_distributed)
73
+ return loss_a + loss_b
74
+
75
+ def _forward_single_direction(
76
+ self, features_a, features_b, gather_distributed):
77
+ bs_a = features_a.shape[0]
78
+ logit_scale = self.logit_scale.exp()
79
+ if get_world_size() > 1 and gather_distributed:
80
+ gather_features_b = torch.cat(GatherLayer.apply(features_b))
81
+ gather_labels_a = torch.arange(bs_a, device='cuda') + get_rank() * bs_a
82
+ logits_a = logit_scale * features_a @ gather_features_b.t()
83
+ else:
84
+ gather_labels_a = torch.arange(bs_a, device='cuda')
85
+ logits_a = logit_scale * features_a @ features_b.t()
86
+ loss_a = F.cross_entropy(logits_a, gather_labels_a)
87
+ return loss_a
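A single-GPU sketch of InfoNCELoss; the class creates its target labels on CUDA, so a GPU is assumed, and the feature shapes are illustrative:

import torch
from imaginaire.losses.info_nce import InfoNCELoss

nce = InfoNCELoss(temperature=0.07, gather_distributed=False).cuda()
feat_a = torch.randn(16, 256, device='cuda')   # view A of 16 samples
feat_b = torch.randn(16, 256, device='cuda')   # view B of the same samples
loss = nce(feat_a, feat_b)   # symmetric InfoNCE: a->b plus b->a cross-entropy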
imaginaire/losses/kl.py ADDED
@@ -0,0 +1,23 @@
1
+ # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, check out LICENSE.md
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+
9
+ class GaussianKLLoss(nn.Module):
10
+ r"""Compute KL loss in VAE for Gaussian distributions"""
11
+ def __init__(self):
12
+ super(GaussianKLLoss, self).__init__()
13
+
14
+ def forward(self, mu, logvar=None):
15
+ r"""Compute loss
16
+
17
+ Args:
18
+ mu (tensor): mean
19
+ logvar (tensor): logarithm of variance
20
+ """
21
+ if logvar is None:
22
+ logvar = torch.zeros_like(mu)
23
+ return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
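This is the closed-form KL divergence KL(N(mu, exp(logvar)) || N(0, I)) summed over all elements; a minimal call, with illustrative shapes, looks like:

import torch
from imaginaire.losses.kl import GaussianKLLoss

kl = GaussianKLLoss()
mu, logvar = torch.randn(4, 128), torch.randn(4, 128)
loss = kl(mu, logvar)   # scalar, summed over batch and latent dimensions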
imaginaire/losses/perceptual.py ADDED
@@ -0,0 +1,395 @@
1
+ # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, check out LICENSE.md
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import torchvision
8
+ from torch import nn, distributed as dist
9
+
10
+ from imaginaire.losses.info_nce import InfoNCELoss
11
+ from imaginaire.utils.distributed import master_only_print as print, \
12
+ is_local_master
13
+ from imaginaire.utils.misc import apply_imagenet_normalization, to_float
14
+
15
+
16
+ class PerceptualLoss(nn.Module):
17
+ r"""Perceptual loss initialization.
18
+
19
+ Args:
20
+ network (str) : The name of the loss network: 'vgg16' | 'vgg19' | 'alexnet' | 'inception_v3' | 'resnet50' | 'robust_resnet50' | 'vgg_face_dag'.
21
+ layers (str or list of str) : The layers used to compute the loss.
22
+ weights (float or list of float): The loss weights of each layer.
23
+ criterion (str): The type of distance function: 'l1' | 'l2'.
24
+ resize (bool) : If ``True``, resize the input images to 224x224.
25
+ resize_mode (str): Algorithm used for resizing.
26
+ num_scales (int): The loss will be evaluated at original size and
27
+ this many times downsampled sizes.
28
+ per_sample_weight (bool): Output loss for individual samples in the
29
+ batch instead of mean loss.
30
+ """
31
+
32
+ def __init__(self, network='vgg19', layers='relu_4_1', weights=None,
33
+ criterion='l1', resize=False, resize_mode='bilinear',
34
+ num_scales=1, per_sample_weight=False,
35
+ info_nce_temperature=0.07,
36
+ info_nce_gather_distributed=True,
37
+ info_nce_learn_temperature=True,
38
+ info_nce_flatten=True):
39
+ super().__init__()
40
+ if isinstance(layers, str):
41
+ layers = [layers]
42
+ if weights is None:
43
+ weights = [1.] * len(layers)
44
+ elif isinstance(weights, float) or isinstance(weights, int):
45
+ weights = [weights]
46
+
47
+ if dist.is_initialized() and not is_local_master():
48
+ # Make sure only the first process in distributed training downloads
49
+ # the model, and the others will use the cache
50
+ # noinspection PyUnresolvedReferences
51
+ torch.distributed.barrier()
52
+
53
+ assert len(layers) == len(weights), \
54
+ 'The number of layers (%s) must be equal to ' \
55
+ 'the number of weights (%s).' % (len(layers), len(weights))
56
+ if network == 'vgg19':
57
+ self.model = _vgg19(layers)
58
+ elif network == 'vgg16':
59
+ self.model = _vgg16(layers)
60
+ elif network == 'alexnet':
61
+ self.model = _alexnet(layers)
62
+ elif network == 'inception_v3':
63
+ self.model = _inception_v3(layers)
64
+ elif network == 'resnet50':
65
+ self.model = _resnet50(layers)
66
+ elif network == 'robust_resnet50':
67
+ self.model = _robust_resnet50(layers)
68
+ elif network == 'vgg_face_dag':
69
+ self.model = _vgg_face_dag(layers)
70
+ else:
71
+ raise ValueError('Network %s is not recognized' % network)
72
+
73
+ if dist.is_initialized() and is_local_master():
74
+ # Make sure only the first process in distributed training downloads
75
+ # the model, and the others will use the cache
76
+ # noinspection PyUnresolvedReferences
77
+ torch.distributed.barrier()
78
+
79
+ self.num_scales = num_scales
80
+ self.layers = layers
81
+ self.weights = weights
82
+ reduction = 'mean' if not per_sample_weight else 'none'
83
+ if criterion == 'l1':
84
+ self.criterion = nn.L1Loss(reduction=reduction)
85
+ elif criterion == 'l2' or criterion == 'mse':
86
+ self.criterion = nn.MSELoss(reduction=reduction)
87
+ elif criterion == 'info_nce':
88
+ self.criterion = InfoNCELoss(
89
+ temperature=info_nce_temperature,
90
+ gather_distributed=info_nce_gather_distributed,
91
+ learn_temperature=info_nce_learn_temperature,
92
+ flatten=info_nce_flatten,
93
+ single_direction=True
94
+ )
95
+ else:
96
+ raise ValueError('Criterion %s is not recognized' % criterion)
97
+ self.resize = resize
98
+ self.resize_mode = resize_mode
99
+ print('Perceptual loss:')
100
+ print('\tMode: {}'.format(network))
101
+
102
+ def forward(self, inp, target, per_sample_weights=None):
103
+ r"""Perceptual loss forward.
104
+
105
+ Args:
106
+ inp (4D tensor) : Input tensor.
107
+ target (4D tensor) : Ground truth tensor, same shape as the input.
108
+ per_sample_weights: If not ``None``, returns the loss for individual samples
109
+ in the batch instead of the mean loss.
110
+ Returns:
111
+ (scalar tensor) : The perceptual loss.
112
+ """
113
+ if not torch.is_autocast_enabled():
114
+ inp, target = to_float([inp, target])
115
+
116
+ # Perceptual loss should operate in eval mode by default.
117
+ self.model.eval()
118
+ inp, target = apply_imagenet_normalization(inp), apply_imagenet_normalization(target)
119
+ if self.resize:
120
+ inp = F.interpolate(inp, mode=self.resize_mode, size=(224, 224), align_corners=False)
121
+ target = F.interpolate(target, mode=self.resize_mode, size=(224, 224), align_corners=False)
122
+
123
+ # Evaluate perceptual loss at each scale.
124
+ loss = 0
125
+ for scale in range(self.num_scales):
126
+ input_features, target_features = self.model(inp), self.model(target)
127
+
128
+ for layer, weight in zip(self.layers, self.weights):
129
+ # Example per-layer VGG19 loss values after applying
130
+ # [0.03125, 0.0625, 0.125, 0.25, 1.0] weighting.
131
+ # relu_1_1, 0.014698
132
+ # relu_2_1, 0.085817
133
+ # relu_3_1, 0.349977
134
+ # relu_4_1, 0.544188
135
+ # relu_5_1, 0.906261
136
+ # print('%s, %f' % (
137
+ # layer,
138
+ # weight * self.criterion(
139
+ # input_features[layer],
140
+ # target_features[
141
+ # layer].detach()).item()))
142
+ l_tmp = self.criterion(input_features[layer], target_features[layer].detach())
143
+ if per_sample_weights is not None:
144
+ l_tmp = l_tmp.mean(1).mean(1).mean(1)
145
+ loss += weight * l_tmp
146
+ # Downsample the input and target.
147
+ if scale != self.num_scales - 1:
148
+ inp = F.interpolate(
149
+ inp, mode=self.resize_mode, scale_factor=0.5,
150
+ align_corners=False, recompute_scale_factor=True)
151
+ target = F.interpolate(
152
+ target, mode=self.resize_mode, scale_factor=0.5,
153
+ align_corners=False, recompute_scale_factor=True)
154
+
155
+ return loss.float()
156
+
157
+
158
+ class _PerceptualNetwork(nn.Module):
159
+ r"""The network that extracts features to compute the perceptual loss.
160
+
161
+ Args:
162
+ network (nn.Sequential) : The network that extracts features.
163
+ layer_name_mapping (dict) : The dictionary that
164
+ maps a layer's index to its name.
165
+ layers (list of str): The list of layer names that we are using.
166
+ """
167
+
168
+ def __init__(self, network, layer_name_mapping, layers):
169
+ super().__init__()
170
+ assert isinstance(network, nn.Sequential), \
171
+ 'The network needs to be of type "nn.Sequential".'
172
+ self.network = network
173
+ self.layer_name_mapping = layer_name_mapping
174
+ self.layers = layers
175
+ for param in self.parameters():
176
+ param.requires_grad = False
177
+
178
+ def forward(self, x):
179
+ r"""Extract perceptual features."""
180
+ output = {}
181
+ for i, layer in enumerate(self.network):
182
+ x = layer(x)
183
+ layer_name = self.layer_name_mapping.get(i, None)
184
+ if layer_name in self.layers:
185
+ # If the current layer is used by the perceptual loss.
186
+ output[layer_name] = x
187
+ return output
188
+
189
+
190
+ def _vgg19(layers):
191
+ r"""Get vgg19 layers"""
192
+ vgg = torchvision.models.vgg19(pretrained=True)
193
+ # network = vgg.features
194
+ network = torch.nn.Sequential(*(list(vgg.features) + [vgg.avgpool] + [nn.Flatten()] + list(vgg.classifier)))
195
+ layer_name_mapping = {1: 'relu_1_1',
196
+ 3: 'relu_1_2',
197
+ 6: 'relu_2_1',
198
+ 8: 'relu_2_2',
199
+ 11: 'relu_3_1',
200
+ 13: 'relu_3_2',
201
+ 15: 'relu_3_3',
202
+ 17: 'relu_3_4',
203
+ 20: 'relu_4_1',
204
+ 22: 'relu_4_2',
205
+ 24: 'relu_4_3',
206
+ 26: 'relu_4_4',
207
+ 29: 'relu_5_1',
208
+ 31: 'relu_5_2',
209
+ 33: 'relu_5_3',
210
+ 35: 'relu_5_4',
211
+ 36: 'pool_5',
212
+ 42: 'fc_2'}
213
+ return _PerceptualNetwork(network, layer_name_mapping, layers)
214
+
215
+
216
+ def _vgg16(layers):
217
+ r"""Get vgg16 layers"""
218
+ network = torchvision.models.vgg16(pretrained=True).features
219
+ layer_name_mapping = {1: 'relu_1_1',
220
+ 3: 'relu_1_2',
221
+ 6: 'relu_2_1',
222
+ 8: 'relu_2_2',
223
+ 11: 'relu_3_1',
224
+ 13: 'relu_3_2',
225
+ 15: 'relu_3_3',
226
+ 18: 'relu_4_1',
227
+ 20: 'relu_4_2',
228
+ 22: 'relu_4_3',
229
+ 25: 'relu_5_1'}
230
+ return _PerceptualNetwork(network, layer_name_mapping, layers)
231
+
232
+
233
+ def _alexnet(layers):
234
+ r"""Get alexnet layers"""
235
+ network = torchvision.models.alexnet(pretrained=True).features
236
+ layer_name_mapping = {0: 'conv_1',
237
+ 1: 'relu_1',
238
+ 3: 'conv_2',
239
+ 4: 'relu_2',
240
+ 6: 'conv_3',
241
+ 7: 'relu_3',
242
+ 8: 'conv_4',
243
+ 9: 'relu_4',
244
+ 10: 'conv_5',
245
+ 11: 'relu_5'}
246
+ return _PerceptualNetwork(network, layer_name_mapping, layers)
247
+
248
+
249
+ def _inception_v3(layers):
250
+ r"""Get inception v3 layers"""
251
+ inception = torchvision.models.inception_v3(pretrained=True)
252
+ network = nn.Sequential(inception.Conv2d_1a_3x3,
253
+ inception.Conv2d_2a_3x3,
254
+ inception.Conv2d_2b_3x3,
255
+ nn.MaxPool2d(kernel_size=3, stride=2),
256
+ inception.Conv2d_3b_1x1,
257
+ inception.Conv2d_4a_3x3,
258
+ nn.MaxPool2d(kernel_size=3, stride=2),
259
+ inception.Mixed_5b,
260
+ inception.Mixed_5c,
261
+ inception.Mixed_5d,
262
+ inception.Mixed_6a,
263
+ inception.Mixed_6b,
264
+ inception.Mixed_6c,
265
+ inception.Mixed_6d,
266
+ inception.Mixed_6e,
267
+ inception.Mixed_7a,
268
+ inception.Mixed_7b,
269
+ inception.Mixed_7c,
270
+ nn.AdaptiveAvgPool2d(output_size=(1, 1)))
271
+ layer_name_mapping = {3: 'pool_1',
272
+ 6: 'pool_2',
273
+ 14: 'mixed_6e',
274
+ 18: 'pool_3'}
275
+ return _PerceptualNetwork(network, layer_name_mapping, layers)
276
+
277
+
278
+ def _resnet50(layers):
279
+ r"""Get resnet50 layers"""
280
+ resnet50 = torchvision.models.resnet50(pretrained=True)
281
+ network = nn.Sequential(resnet50.conv1,
282
+ resnet50.bn1,
283
+ resnet50.relu,
284
+ resnet50.maxpool,
285
+ resnet50.layer1,
286
+ resnet50.layer2,
287
+ resnet50.layer3,
288
+ resnet50.layer4,
289
+ resnet50.avgpool)
290
+ layer_name_mapping = {4: 'layer_1',
291
+ 5: 'layer_2',
292
+ 6: 'layer_3',
293
+ 7: 'layer_4'}
294
+ return _PerceptualNetwork(network, layer_name_mapping, layers)
295
+
296
+
297
+ def _robust_resnet50(layers):
298
+ r"""Get robust resnet50 layers"""
299
+ resnet50 = torchvision.models.resnet50(pretrained=False)
300
+ state_dict = torch.utils.model_zoo.load_url(
301
+ 'http://andrewilyas.com/ImageNet.pt')
302
+ new_state_dict = {}
303
+ for k, v in state_dict['model'].items():
304
+ if k.startswith('module.model.'):
305
+ new_state_dict[k[13:]] = v
306
+ resnet50.load_state_dict(new_state_dict)
307
+ network = nn.Sequential(resnet50.conv1,
308
+ resnet50.bn1,
309
+ resnet50.relu,
310
+ resnet50.maxpool,
311
+ resnet50.layer1,
312
+ resnet50.layer2,
313
+ resnet50.layer3,
314
+ resnet50.layer4,
315
+ resnet50.avgpool)
316
+ layer_name_mapping = {4: 'layer_1',
317
+ 5: 'layer_2',
318
+ 6: 'layer_3',
319
+ 7: 'layer_4'}
320
+ return _PerceptualNetwork(network, layer_name_mapping, layers)
321
+
322
+
323
+ def _vgg_face_dag(layers):
324
+ network = torchvision.models.vgg16(num_classes=2622)
325
+ state_dict = torch.utils.model_zoo.load_url(
326
+ 'http://www.robots.ox.ac.uk/~albanie/models/pytorch-mcn/'
327
+ 'vgg_face_dag.pth')
328
+ feature_layer_name_mapping = {
329
+ 0: 'conv1_1',
330
+ 2: 'conv1_2',
331
+ 5: 'conv2_1',
332
+ 7: 'conv2_2',
333
+ 10: 'conv3_1',
334
+ 12: 'conv3_2',
335
+ 14: 'conv3_3',
336
+ 17: 'conv4_1',
337
+ 19: 'conv4_2',
338
+ 21: 'conv4_3',
339
+ 24: 'conv5_1',
340
+ 26: 'conv5_2',
341
+ 28: 'conv5_3'}
342
+ new_state_dict = {}
343
+ for k, v in feature_layer_name_mapping.items():
344
+ new_state_dict['features.' + str(k) + '.weight'] = \
345
+ state_dict[v + '.weight']
346
+ new_state_dict['features.' + str(k) + '.bias'] = \
347
+ state_dict[v + '.bias']
348
+
349
+ classifier_layer_name_mapping = {
350
+ 0: 'fc6',
351
+ 3: 'fc7',
352
+ 6: 'fc8'}
353
+ for k, v in classifier_layer_name_mapping.items():
354
+ new_state_dict['classifier.' + str(k) + '.weight'] = \
355
+ state_dict[v + '.weight']
356
+ new_state_dict['classifier.' + str(k) + '.bias'] = \
357
+ state_dict[v + '.bias']
358
+
359
+ network.load_state_dict(new_state_dict)
360
+
361
+ class Flatten(nn.Module):
362
+ def forward(self, x):
363
+ return x.view(x.shape[0], -1)
364
+
365
+ layer_name_mapping = {
366
+ 0: 'conv_1_1',
367
+ 1: 'relu_1_1',
368
+ 2: 'conv_1_2',
369
+ 5: 'conv_2_1', # 1/2
370
+ 6: 'relu_2_1',
371
+ 7: 'conv_2_2',
372
+ 10: 'conv_3_1', # 1/4
373
+ 11: 'relu_3_1',
374
+ 12: 'conv_3_2',
375
+ 14: 'conv_3_3',
376
+ 17: 'conv_4_1', # 1/8
377
+ 18: 'relu_4_1',
378
+ 19: 'conv_4_2',
379
+ 21: 'conv_4_3',
380
+ 24: 'conv_5_1', # 1/16
381
+ 25: 'relu_5_1',
382
+ 26: 'conv_5_2',
383
+ 28: 'conv_5_3',
384
+ 33: 'fc6',
385
+ 36: 'fc7',
386
+ 39: 'fc8'
387
+ }
388
+ seq_layers = []
389
+ for feature in network.features:
390
+ seq_layers += [feature]
391
+ seq_layers += [network.avgpool, Flatten()]
392
+ for classifier in network.classifier:
393
+ seq_layers += [classifier]
394
+ network = nn.Sequential(*seq_layers)
395
+ return _PerceptualNetwork(network, layer_name_mapping, layers)
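The builders above each wrap a torchvision backbone in _PerceptualNetwork, whose forward pass returns a dict mapping the requested layer names to activations. A minimal usage sketch, assuming the pretrained weights can be downloaded and that the requested names appear in the corresponding layer_name_mapping (the input range and shapes are illustrative, not enforced by the code above):

    import torch

    # Build a VGG-19 feature extractor that only exposes the requested layers.
    extractor = _vgg19(layers=['relu_3_1', 'relu_4_1']).eval()

    # Dummy batch of two 224x224 RGB images.
    images = torch.rand(2, 3, 224, 224)

    with torch.no_grad():
        features = extractor(images)  # dict: layer name -> activation tensor

    for name, feat in features.items():
        print(name, tuple(feat.shape))  # relu_3_1 -> (2, 256, 56, 56), relu_4_1 -> (2, 512, 28, 28)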
imaginaire/losses/weighted_mse.py ADDED
@@ -0,0 +1,28 @@
1
+ # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, check out LICENSE.md
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+
9
+ class WeightedMSELoss(nn.Module):
10
+ r"""Compute Weighted MSE loss"""
11
+ def __init__(self, reduction='mean'):
12
+ super(WeightedMSELoss, self).__init__()
13
+ self.reduction = reduction
14
+
15
+ def forward(self, input, target, weight):
16
+ r"""Return weighted MSE Loss.
17
+ Args:
18
+ input (tensor): Input (prediction) tensor.
19
+ target (tensor): Target tensor.
20
+ weight (tensor): Elementwise weight, broadcastable to the input.
21
+ Returns:
22
+ (tensor): Loss value.
23
+ """
24
+ if self.reduction == 'mean':
25
+ loss = torch.mean(weight * (input - target) ** 2)
26
+ else:
27
+ loss = torch.sum(weight * (input - target) ** 2)
28
+ return loss
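A short sketch of the intended call pattern for WeightedMSELoss; the shapes are illustrative only, and the weight tensor merely needs to broadcast against input and target:

    import torch

    criterion = WeightedMSELoss(reduction='mean')

    pred = torch.randn(4, 3, 64, 64)
    target = torch.randn(4, 3, 64, 64)

    # Per-pixel weights (broadcast over the channel dimension), e.g. to emphasise a region.
    weight = torch.ones(4, 1, 64, 64)
    weight[:, :, 16:48, 16:48] = 2.0

    loss = criterion(pred, target, weight)
    print(loss.item())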
imaginaire/model_utils/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, check out LICENSE.md
imaginaire/model_utils/gancraft/camctl.py ADDED
@@ -0,0 +1,679 @@
1
+ # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, check out LICENSE.md
5
+ import numpy as np
6
+ import torch
7
+
8
+
9
+ class EvalCameraController:
10
+ def __init__(self, voxel, maxstep=128, pattern=0, cam_ang=73, smooth_decay_multiplier=1.0):
11
+ self.voxel = voxel
12
+ self.maxstep = maxstep
13
+ self.camera_poses = [] # ori, dir, up, f
14
+ circle = torch.linspace(0, 2*np.pi, steps=maxstep)
15
+ size = min(voxel.voxel_t.size(1), voxel.voxel_t.size(2)) / 2
16
+ # Shrink the circle a bit.
17
+ shift = size * 0.2
18
+ size = size * 0.8
19
+
20
+ if pattern == 0:
21
+ height_history = []
22
+ # Calculate smooth height.
23
+ for i in range(maxstep):
24
+ farpoint = torch.tensor([
25
+ 70,
26
+ torch.sin(circle[i])*size + voxel.voxel_t.size(1)/2 + shift,
27
+ torch.cos(circle[i])*size + voxel.voxel_t.size(2)/2 + shift])
28
+ height_history.append(self._get_height(farpoint[1], farpoint[2], farpoint[0]))
29
+
30
+ # Filtfilt
31
+ height_history = self.filtfilt(height_history, decay=0.2*smooth_decay_multiplier)
32
+
33
+ for i in range(maxstep):
34
+ farpoint = torch.tensor([
35
+ 70,
36
+ torch.sin(circle[i])*size + voxel.voxel_t.size(1)/2 + shift,
37
+ torch.cos(circle[i])*size + voxel.voxel_t.size(2)/2 + shift])
38
+
39
+ farpoint[0] = height_history[i]
40
+
41
+ nearpoint = torch.tensor([
42
+ 60,
43
+ torch.sin(circle[i]+0.5*np.pi)*size*0.5 + voxel.voxel_t.size(1)/2 + shift,
44
+ torch.cos(circle[i]+0.5*np.pi)*size*0.5 + voxel.voxel_t.size(2)/2 + shift])
45
+ cam_ori = self.voxel.world2local(farpoint)
46
+ cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True)
47
+ cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
48
+ cam_f = 0.5/np.tan(np.deg2rad(cam_ang/2)) # about 24mm fov
49
+
50
+ self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
51
+
52
+ elif pattern == 1:
53
+ zoom = torch.linspace(1.0, 0.25, steps=maxstep)
54
+ height_history = []
55
+ for i in range(maxstep):
56
+ farpoint = torch.tensor([
57
+ 90,
58
+ torch.sin(circle[i])*size + voxel.voxel_t.size(1)/2 + shift,
59
+ torch.cos(circle[i])*size + voxel.voxel_t.size(2)/2 + shift])
60
+
61
+ height_history.append(self._get_height(farpoint[1], farpoint[2], farpoint[0]))
62
+
63
+ height_history = self.filtfilt(height_history, decay=0.2*smooth_decay_multiplier)
64
+
65
+ for i in range(maxstep):
66
+ farpoint = torch.tensor([
67
+ 90,
68
+ torch.sin(circle[i])*size + voxel.voxel_t.size(1)/2 + shift,
69
+ torch.cos(circle[i])*size + voxel.voxel_t.size(2)/2 + shift])
70
+
71
+ farpoint[0] = height_history[i]
72
+
73
+ nearpoint = torch.tensor([
74
+ 60,
75
+ torch.sin(circle[i]-0.3*np.pi)*size*0.3 + voxel.voxel_t.size(1)/2 + shift,
76
+ torch.cos(circle[i]-0.3*np.pi)*size*0.3 + voxel.voxel_t.size(2)/2 + shift])
77
+ cam_ori = self.voxel.world2local(farpoint)
78
+ cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True)
79
+ cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
80
+ cam_f = 0.5/np.tan(np.deg2rad(cam_ang/2)*zoom[i]) # about 24mm fov
81
+
82
+ self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
83
+
84
+ elif pattern == 2:
85
+ move = torch.linspace(1.0, 0.2, steps=maxstep)
86
+ height_history = []
87
+ for i in range(maxstep):
88
+ farpoint = torch.tensor([
89
+ 90,
90
+ torch.sin(circle[i])*size*move[i] + voxel.voxel_t.size(1)/2 + shift,
91
+ torch.cos(circle[i])*size*move[i] + voxel.voxel_t.size(2)/2 + shift])
92
+
93
+ height_history.append(self._get_height(farpoint[1], farpoint[2], farpoint[0]))
94
+
95
+ height_history = self.filtfilt(height_history, decay=0.2*smooth_decay_multiplier)
96
+
97
+ for i in range(maxstep):
98
+ farpoint = torch.tensor([
99
+ 90,
100
+ torch.sin(circle[i])*size*move[i] + voxel.voxel_t.size(1)/2 + shift,
101
+ torch.cos(circle[i])*size*move[i] + voxel.voxel_t.size(2)/2 + shift])
102
+
103
+ farpoint[0] = height_history[i]
104
+
105
+ nearpoint = torch.tensor([
106
+ 60,
107
+ torch.sin(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(1)/2 + shift,
108
+ torch.cos(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(2)/2 + shift])
109
+ cam_ori = self.voxel.world2local(farpoint)
110
+ cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True)
111
+ cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
112
+ cam_f = 0.5/np.tan(np.deg2rad(cam_ang/2)) # about 24mm fov
113
+
114
+ self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
115
+
116
+ elif pattern == 3:
117
+ move = torch.linspace(0.75, 0.2, steps=maxstep)
118
+ height_history = []
119
+ for i in range(maxstep):
120
+ farpoint = torch.tensor([
121
+ 70,
122
+ torch.sin(-circle[i])*size*move[i] + voxel.voxel_t.size(1)/2 + shift,
123
+ torch.cos(-circle[i])*size*move[i] + voxel.voxel_t.size(2)/2 + shift])
124
+
125
+ height_history.append(self._get_height(farpoint[1], farpoint[2], farpoint[0]))
126
+
127
+ height_history = self.filtfilt(height_history, decay=0.2*smooth_decay_multiplier)
128
+
129
+ for i in range(maxstep):
130
+ farpoint = torch.tensor([
131
+ 70,
132
+ torch.sin(-circle[i])*size*move[i] + voxel.voxel_t.size(1)/2 + shift,
133
+ torch.cos(-circle[i])*size*move[i] + voxel.voxel_t.size(2)/2 + shift])
134
+
135
+ farpoint[0] = height_history[i]
136
+
137
+ nearpoint = torch.tensor([
138
+ 60,
139
+ torch.sin(-circle[i]-0.4*np.pi)*size*0.9*move[i] + voxel.voxel_t.size(1)/2 + shift,
140
+ torch.cos(-circle[i]-0.4*np.pi)*size*0.9*move[i] + voxel.voxel_t.size(2)/2 + shift])
141
+ cam_ori = self.voxel.world2local(farpoint)
142
+ cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True)
143
+ cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
144
+ cam_f = 0.5/np.tan(np.deg2rad(cam_ang/2)) # about 24mm fov
145
+
146
+ self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
147
+
148
+ elif pattern == 4:
149
+ move = torch.linspace(1.0, 0.5, steps=maxstep)
150
+ height_history = []
151
+ for i in range(maxstep):
152
+ farpoint = torch.tensor([
153
+ 90,
154
+ torch.sin(circle[i])*size*move[i] + voxel.voxel_t.size(1)/2 + shift,
155
+ torch.cos(circle[i])*size*move[i] + voxel.voxel_t.size(2)/2 + shift])
156
+
157
+ height_history.append(self._get_height(farpoint[1], farpoint[2], farpoint[0]))
158
+
159
+ height_history = self.filtfilt(height_history, decay=0.2*smooth_decay_multiplier)
160
+
161
+ for i in range(maxstep):
162
+ farpoint = torch.tensor([
163
+ 90,
164
+ torch.sin(circle[i])*size*move[i] + voxel.voxel_t.size(1)/2 + shift,
165
+ torch.cos(circle[i])*size*move[i] + voxel.voxel_t.size(2)/2 + shift])
166
+
167
+ farpoint[0] = height_history[i]
168
+
169
+ nearpoint = torch.tensor([
170
+ 60,
171
+ torch.sin(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(1)/2 + shift,
172
+ torch.cos(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(2)/2 + shift])
173
+ cam_ori = self.voxel.world2local(farpoint)
174
+ cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True)
175
+ cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
176
+ cam_f = 0.5/np.tan(np.deg2rad(cam_ang/2)) # about 24mm fov
177
+
178
+ self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
179
+
180
+ # look outward
181
+ elif pattern == 5:
182
+ move = torch.linspace(1.0, 0.5, steps=maxstep)
183
+ height_history = []
184
+ for i in range(maxstep):
185
+ nearpoint = torch.tensor([
186
+ 60,
187
+ torch.sin(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(1)/2 + shift,
188
+ torch.cos(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(2)/2 + shift])
189
+
190
+ height_history.append(self._get_height(nearpoint[1], nearpoint[2], nearpoint[0]))
191
+
192
+ height_history = self.filtfilt(height_history, decay=0.2*smooth_decay_multiplier)
193
+
194
+ for i in range(maxstep):
195
+ nearpoint = torch.tensor([
196
+ 60,
197
+ torch.sin(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(1)/2 + shift,
198
+ torch.cos(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(2)/2 + shift])
199
+
200
+ nearpoint[0] = height_history[i]
201
+
202
+ farpoint = torch.tensor([
203
+ 60,
204
+ torch.sin(circle[i])*size*move[i] + voxel.voxel_t.size(1)/2 + shift,
205
+ torch.cos(circle[i])*size*move[i] + voxel.voxel_t.size(2)/2 + shift])
206
+
207
+ cam_ori = self.voxel.world2local(nearpoint)
208
+ cam_dir = self.voxel.world2local(farpoint - nearpoint, is_vec=True)
209
+ cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
210
+ cam_f = 0.5/np.tan(np.deg2rad(cam_ang/2)) # about 24mm fov
211
+
212
+ self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
213
+ # Rise
214
+ elif pattern == 6:
215
+ shift = 0
216
+ lift = torch.linspace(0.0, 200.0, steps=maxstep)
217
+ zoom = torch.linspace(0.8, 1.6, steps=maxstep)
218
+ for i in range(maxstep):
219
+ farpoint = torch.tensor([
220
+ 80+lift[i],
221
+ torch.sin(circle[i]/4)*size*0.2 + voxel.voxel_t.size(1)/2 + shift,
222
+ torch.cos(circle[i]/4)*size*0.2 + voxel.voxel_t.size(2)/2 + shift])
223
+
224
+ farpoint[0] = self._get_height(farpoint[1], farpoint[2], farpoint[0])
225
+
226
+ nearpoint = torch.tensor([
227
+ 65,
228
+ torch.sin(circle[i]/4+0.5*np.pi)*size*0.1 + voxel.voxel_t.size(1)/2 + shift,
229
+ torch.cos(circle[i]/4+0.5*np.pi)*size*0.1 + voxel.voxel_t.size(2)/2 + shift])
230
+ cam_ori = self.voxel.world2local(farpoint)
231
+ cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True)
232
+ cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
233
+ cam_f = 0.5/np.tan(np.deg2rad(73/2)*zoom[i]) # about 24mm fov
234
+
235
+ self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
236
+ # 45deg
237
+ elif pattern == 7:
238
+ rad = torch.tensor([np.deg2rad(45).astype(np.float32)])
239
+ size = 1536
240
+ for i in range(maxstep):
241
+ farpoint = torch.tensor([
242
+ 61+size,
243
+ torch.sin(rad)*size + voxel.voxel_t.size(1)/2,
244
+ torch.cos(rad)*size + voxel.voxel_t.size(2)/2])
245
+
246
+ nearpoint = torch.tensor([
247
+ 61,
248
+ voxel.voxel_t.size(1)/2,
249
+ voxel.voxel_t.size(2)/2])
250
+ cam_ori = self.voxel.world2local(farpoint)
251
+ cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True)
252
+ cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
253
+ cam_f = 0.5/np.tan(np.deg2rad(19.5/2)) # about 50mm fov
254
+
255
+ self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
256
+
257
+ elif pattern == 8:
258
+ size = self.voxel.voxel_t.size(1) // 2
259
+ for i in range(maxstep):
260
+ farpoint = torch.tensor([
261
+ 300,
262
+ 0*size + voxel.voxel_t.size(1)//2,
263
+ -1*size + voxel.voxel_t.size(2)/2 + size // maxstep * (i - maxstep // 4)])
264
+ nearpoint = torch.tensor([
265
+ 120,
266
+ 0*size*0.5 + voxel.voxel_t.size(1)//2,
267
+ -1*size*0.5 + voxel.voxel_t.size(2)/2 + size // maxstep * (i - maxstep // 4)])
268
+ cam_ori = self.voxel.world2local(farpoint)
269
+ cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True)
270
+ cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
271
+ cam_f = 0.5/np.tan(np.deg2rad(cam_ang/2)) # about 24mm fov
272
+
273
+ self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
274
+
275
+ elif pattern == 9:
276
+ size = self.voxel.voxel_t.size(2) // 2
277
+ for i in range(maxstep):
278
+ farpoint = torch.tensor([
279
+ 140,
280
+ voxel.voxel_t.size(1)//2,
281
+ -size // 4 + size * 8 // maxstep * i]
282
+ , dtype=torch.float32)
283
+ nearpoint = torch.tensor([
284
+ 100,
285
+ voxel.voxel_t.size(1)//2,
286
+ size * 8 // maxstep * i]
287
+ , dtype=torch.float32)
288
+ cam_ori = self.voxel.world2local(farpoint)
289
+ cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True)
290
+ cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
291
+ cam_f = 0.5/np.tan(np.deg2rad(cam_ang/2)) # about 24mm fov
292
+
293
+ self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
294
+
295
+
296
+ def _get_height(self, loc0, loc1, minheight):
297
+ loc0 = int(loc0)
298
+ loc1 = int(loc1)
299
+ height = minheight
300
+ for dx in range(-3, 4):
301
+ for dy in range(-3, 4):
302
+ if (loc0+dx) < 0 or (loc0+dx) >= self.voxel.heightmap.shape[0] or (loc1+dy) < 0 or \
303
+ (loc1+dy) >= self.voxel.heightmap.shape[1]:
304
+ height = max(height, minheight)
305
+ else:
306
+ height = max(height, self.voxel.heightmap[loc0+dx, loc1+dy] + 2)
307
+ return height
308
+
309
+ def filtfilt(self, height_history, decay=0.2):
310
+ # Filtfilt
311
+ height_history2 = []
312
+ maxstep = len(height_history)
313
+ prev_height = height_history[0]
314
+ for i in range(maxstep):
315
+ prev_height = prev_height - decay
316
+ if prev_height < height_history[i]:
317
+ prev_height = height_history[i]
318
+ height_history2.append(prev_height)
319
+ prev_height = height_history[-1]
320
+ for i in range(maxstep-1, -1, -1):
321
+ prev_height = prev_height - decay
322
+ if prev_height < height_history[i]:
323
+ prev_height = height_history[i]
324
+ height_history2[i] = max(prev_height, height_history2[i])
325
+ return height_history2
326
+
327
+ def __len__(self):
328
+ return len(self.camera_poses)
329
+
330
+ def __getitem__(self, idx):
331
+ return self.camera_poses[idx]
332
+
333
+
334
+ class TourCameraController:
335
+ def __init__(self, voxel, maxstep=128):
336
+ self.voxel = voxel
337
+ self.maxstep = maxstep
338
+ self.camera_poses = [] # ori, dir, up, f
339
+ circle = torch.linspace(0, 2*np.pi, steps=maxstep//4)
340
+ size = min(voxel.voxel_t.size(1), voxel.voxel_t.size(2)) / 2
341
+ # Shrink the circle a bit
342
+ shift = size * 0.2
343
+ size = size * 0.8
344
+
345
+ for i in range(maxstep//4):
346
+ farpoint = torch.tensor([
347
+ 70,
348
+ torch.sin(circle[i])*size + voxel.voxel_t.size(1)/2 + shift,
349
+ torch.cos(circle[i])*size + voxel.voxel_t.size(2)/2 + shift])
350
+
351
+ farpoint[0] = self._get_height(farpoint[1], farpoint[2], farpoint[0])
352
+
353
+ nearpoint = torch.tensor([
354
+ 60,
355
+ torch.sin(circle[i]+0.5*np.pi)*size*0.5 + voxel.voxel_t.size(1)/2 + shift,
356
+ torch.cos(circle[i]+0.5*np.pi)*size*0.5 + voxel.voxel_t.size(2)/2 + shift])
357
+ cam_ori = self.voxel.world2local(farpoint)
358
+ cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True)
359
+ cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
360
+ cam_f = 0.5/np.tan(np.deg2rad(73/2)) # about 24mm fov
361
+
362
+ self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
363
+
364
+ zoom = torch.linspace(1.0, 0.25, steps=maxstep//4)
365
+ for i in range(maxstep//4):
366
+ farpoint = torch.tensor([
367
+ 90,
368
+ torch.sin(circle[i])*size + voxel.voxel_t.size(1)/2 + shift,
369
+ torch.cos(circle[i])*size + voxel.voxel_t.size(2)/2 + shift])
370
+
371
+ farpoint[0] = self._get_height(farpoint[1], farpoint[2], farpoint[0])
372
+
373
+ nearpoint = torch.tensor([
374
+ 60,
375
+ torch.sin(circle[i]-0.3*np.pi)*size*0.3 + voxel.voxel_t.size(1)/2 + shift,
376
+ torch.cos(circle[i]-0.3*np.pi)*size*0.3 + voxel.voxel_t.size(2)/2 + shift])
377
+ cam_ori = self.voxel.world2local(farpoint)
378
+ cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True)
379
+ cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
380
+ cam_f = 0.5/np.tan(np.deg2rad(73/2)*zoom[i]) # about 24mm fov
381
+
382
+ self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
383
+
384
+ move = torch.linspace(1.0, 0.2, steps=maxstep//4)
385
+ for i in range(maxstep//4):
386
+ farpoint = torch.tensor([
387
+ 90,
388
+ torch.sin(circle[i])*size*move[i] + voxel.voxel_t.size(1)/2 + shift,
389
+ torch.cos(circle[i])*size*move[i] + voxel.voxel_t.size(2)/2 + shift])
390
+
391
+ farpoint[0] = self._get_height(farpoint[1], farpoint[2], farpoint[0])
392
+
393
+ nearpoint = torch.tensor([
394
+ 60,
395
+ torch.sin(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(1)/2 + shift,
396
+ torch.cos(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(2)/2 + shift])
397
+ cam_ori = self.voxel.world2local(farpoint)
398
+ cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True)
399
+ cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
400
+ cam_f = 0.5/np.tan(np.deg2rad(73/2)) # about 24mm fov
401
+
402
+ self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
403
+
404
+ lift = torch.linspace(0.0, 200.0, steps=maxstep//4)
405
+ zoom = torch.linspace(0.6, 1.2, steps=maxstep//4)
406
+ for i in range(maxstep//4):
407
+ farpoint = torch.tensor([
408
+ 80+lift[i],
409
+ torch.sin(circle[i])*size*0.2 + voxel.voxel_t.size(1)/2 + shift,
410
+ torch.cos(circle[i])*size*0.2 + voxel.voxel_t.size(2)/2 + shift])
411
+
412
+ farpoint[0] = self._get_height(farpoint[1], farpoint[2], farpoint[0])
413
+
414
+ nearpoint = torch.tensor([
415
+ 60,
416
+ torch.sin(circle[i]+0.5*np.pi)*size*0.1 + voxel.voxel_t.size(1)/2 + shift,
417
+ torch.cos(circle[i]+0.5*np.pi)*size*0.1 + voxel.voxel_t.size(2)/2 + shift])
418
+ cam_ori = self.voxel.world2local(farpoint)
419
+ cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True)
420
+ cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
421
+ cam_f = 0.5/np.tan(np.deg2rad(73/2)*zoom[i]) # about 24mm fov
422
+
423
+ self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
424
+
425
+ def _get_height(self, loc0, loc1, minheight):
426
+ loc0 = int(loc0)
427
+ loc1 = int(loc1)
428
+ height = minheight
429
+ for dx in range(-3, 4):
430
+ for dy in range(-3, 4):
431
+ if (loc0+dx) < 0 or (loc0+dx) >= self.voxel.heightmap.shape[0] or (loc1+dy) < 0 or \
432
+ (loc1+dy) >= self.voxel.heightmap.shape[1]:
433
+ height = max(height, minheight)
434
+ else:
435
+ height = max(height, self.voxel.heightmap[loc0+dx, loc1+dy] + 2)
436
+ return height
437
+
438
+ def __len__(self):
439
+ return len(self.camera_poses)
440
+
441
+ def __getitem__(self, idx):
442
+ return self.camera_poses[idx]
443
+
444
+
445
+ def rand_camera_pose_birdseye(voxel, border=128):
446
+ r"""Generate a random camera pose on the upper hemisphere, in origin-direction-up format.
447
+ Assumes a [Y X Z] coordinate system, where Y is the negative gravity direction.
448
+ The camera pose is converted into the voxel coordinate system so that it can be used directly for rendering.
449
+ 1. Uniformly sample a point on the upper hemisphere of a unit sphere, as cam_ori.
450
+ 2. Set cam_dir to point from cam_ori towards the origin.
451
+ 3. cam_up always points towards the sky.
452
+ 4. Move cam_ori to a random place according to the voxel size.
453
+ """
454
+ cam_dir = torch.randn(3, dtype=torch.float32)
455
+ cam_dir = cam_dir / torch.sqrt(torch.sum(cam_dir*cam_dir))
456
+ cam_dir[0] = -torch.abs(cam_dir[0])
457
+ cam_up = torch.tensor([1, 0, 0], dtype=torch.float32)
458
+
459
+ # generate camera lookat target
460
+ r = np.random.rand(2)
461
+ r[0] *= voxel.voxel_t.size(1)-border-border
462
+ r[1] *= voxel.voxel_t.size(2)-border-border
463
+ r = r + border
464
+ y = voxel.heightmap[int(r[0]+0.5), int(r[1]+0.5)] + (np.random.rand(1)-0.5) * 5
465
+ cam_target = torch.tensor([y, r[0], r[1]], dtype=torch.float32)
466
+ cam_ori = cam_target - cam_dir * (np.random.rand(1).item() * 100)
467
+ cam_ori[0] = max(voxel.heightmap[int(cam_ori[1]+0.5), int(cam_ori[2]+0.5)]+2, cam_ori[0])
468
+ # Translate to voxel coordinate
469
+ cam_ori = voxel.world2local(cam_ori)
470
+ cam_dir = voxel.world2local(cam_dir, is_vec=True)
471
+ cam_up = voxel.world2local(cam_up, is_vec=True)
472
+
473
+ return cam_ori, cam_dir, cam_up
474
+
475
+
476
+ def get_neighbor_height(heightmap, loc0, loc1, minheight, neighbor_size=7):
477
+ loc0 = int(loc0)
478
+ loc1 = int(loc1)
479
+ height = 0
480
+ for dx in range(-neighbor_size//2, neighbor_size//2+1):
481
+ for dy in range(-neighbor_size//2, neighbor_size//2+1):
482
+ if (loc0+dx) < 0 or (loc0+dx) >= heightmap.shape[0] or (loc1+dy) < 0 or (loc1+dy) >= heightmap.shape[1]:
483
+ height = max(height, minheight)
484
+ else:
485
+ height = max(minheight, heightmap[loc0+dx, loc1+dy] + 2)
486
+ return height
487
+
488
+
489
+ def rand_camera_pose_firstperson(voxel, border=128):
490
+ r"""Generate a random camera pose in the upper hemisphere, in origin-direction-up format.
491
+ """
492
+ r = np.random.rand(5)
493
+ r[0] *= voxel.voxel_t.size(1)-border-border
494
+ r[1] *= voxel.voxel_t.size(2)-border-border
495
+ r[0] = r[0] + border
496
+ r[1] = r[1] + border
497
+
498
+ y = get_neighbor_height(voxel.heightmap, r[0], r[1], 0) + np.random.rand(1) * 15
499
+
500
+ cam_ori = torch.tensor([y, r[0], r[1]], dtype=torch.float32)
501
+
502
+ rand_ang_h = r[2] * 2 * np.pi
503
+ cam_target = torch.tensor([0, cam_ori[1]+np.sin(rand_ang_h)*border*r[4], cam_ori[2] +
504
+ np.cos(rand_ang_h)*border*r[4]], dtype=torch.float32)
505
+ cam_target[0] = get_neighbor_height(voxel.heightmap, cam_target[1],
506
+ cam_target[2], 0, neighbor_size=1) - 2 + r[3] * 10
507
+
508
+ cam_dir = cam_target - cam_ori
509
+
510
+ cam_up = torch.tensor([1, 0, 0], dtype=torch.float32)
511
+
512
+ cam_ori = voxel.world2local(cam_ori)
513
+ cam_dir = voxel.world2local(cam_dir, is_vec=True)
514
+ cam_up = voxel.world2local(cam_up, is_vec=True)
515
+
516
+ return cam_ori, cam_dir, cam_up
517
+
518
+
519
+ def rand_camera_pose_thridperson(voxel, border=96):
520
+ r = torch.rand(2)
521
+ r[0] *= voxel.voxel_t.size(1)
522
+ r[1] *= voxel.voxel_t.size(2)
523
+ rand_height = 60 + torch.rand(1) * 40
524
+ rand_height = get_neighbor_height(voxel.heightmap, r[0], r[1], rand_height, neighbor_size=5)
525
+ farpoint = torch.tensor([rand_height, r[0], r[1]], dtype=torch.float32)
526
+
527
+ r = torch.rand(2)
528
+ r[0] *= voxel.voxel_t.size(1) - border - border
529
+ r[1] *= voxel.voxel_t.size(2) - border - border
530
+ r[0] = r[0] + border
531
+ r[1] = r[1] + border
532
+ rand_height = get_neighbor_height(voxel.heightmap, r[0], r[1], 65, neighbor_size=1) - 5
533
+ nearpoint = torch.tensor([rand_height, r[0], r[1]], dtype=torch.float32)
534
+
535
+ cam_ori = voxel.world2local(farpoint)
536
+ cam_dir = voxel.world2local(nearpoint - farpoint, is_vec=True)
537
+ cam_up = voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
538
+
539
+ return cam_ori, cam_dir, cam_up
540
+
541
+
542
+ def rand_camera_pose_thridperson2(voxel, border=48):
543
+ r = torch.rand(2)
544
+ r[0] *= voxel.voxel_t.size(1) - border - border
545
+ r[1] *= voxel.voxel_t.size(2) - border - border
546
+ r[0] = r[0] + border
547
+ r[1] = r[1] + border
548
+ rand_height = 60 + torch.rand(1) * 40
549
+ rand_height = get_neighbor_height(voxel.heightmap, r[0], r[1], rand_height, neighbor_size=5)
550
+ farpoint = torch.tensor([rand_height, r[0], r[1]], dtype=torch.float32)
551
+
552
+ r = torch.rand(2)
553
+ r[0] *= voxel.voxel_t.size(1) - border - border
554
+ r[1] *= voxel.voxel_t.size(2) - border - border
555
+ r[0] = r[0] + border
556
+ r[1] = r[1] + border
557
+ rand_height = get_neighbor_height(voxel.heightmap, r[0], r[1], 65, neighbor_size=1) - 5
558
+ nearpoint = torch.tensor([rand_height, r[0], r[1]], dtype=torch.float32)
559
+
560
+ # Random Up vector (tilt a little bit)
561
+ # up = torch.randn(3) * 0.05 # cutoff +-0.1, Tan(10deg) = 0.176
562
+ up = torch.randn(3) * 0.02
563
+ up[0] = 1.0
564
+ up = up / up.norm(p=2)
565
+ cam_ori = voxel.world2local(farpoint)
566
+ cam_dir = voxel.world2local(nearpoint - farpoint, is_vec=True)
567
+ cam_up = voxel.world2local(up, is_vec=True)
568
+
569
+ return cam_ori, cam_dir, cam_up
570
+
571
+
572
+ def rand_camera_pose_thridperson3(voxel, border=64):
573
+ r"""Attempts to avoid cameras that end up too close to walls and to include more aerial poses."""
574
+ r = torch.rand(2)
575
+ r[0] *= voxel.voxel_t.size(1) - border - border
576
+ r[1] *= voxel.voxel_t.size(2) - border - border
577
+ r[0] = r[0] + border
578
+ r[1] = r[1] + border
579
+ rand_height = 60 + torch.rand(1) * 40
580
+ if torch.rand(1) > 0.8:
581
+ rand_height = 60 + torch.rand(1) * 60
582
+ rand_height = get_neighbor_height(voxel.heightmap, r[0], r[1], rand_height, neighbor_size=7)
583
+ farpoint = torch.tensor([rand_height, r[0], r[1]], dtype=torch.float32)
584
+
585
+ r = torch.rand(2)
586
+ r[0] *= voxel.voxel_t.size(1) - border - border
587
+ r[1] *= voxel.voxel_t.size(2) - border - border
588
+ r[0] = r[0] + border
589
+ r[1] = r[1] + border
590
+ rand_height = get_neighbor_height(voxel.heightmap, r[0], r[1], 65, neighbor_size=3) - 5
591
+ nearpoint = torch.tensor([rand_height, r[0], r[1]], dtype=torch.float32)
592
+
593
+ # Random Up vector (tilt a little bit)
594
+ # up = torch.randn(3) * 0.05 # cutoff +-0.1, Tan(10deg) = 0.176
595
+ up = torch.randn(3) * 0.02
596
+ up[0] = 1.0
597
+ up = up / up.norm(p=2)
598
+ # print(up)
599
+ cam_ori = voxel.world2local(farpoint)
600
+ cam_dir = voxel.world2local(nearpoint - farpoint, is_vec=True)
601
+ cam_up = voxel.world2local(up, is_vec=True)
602
+
603
+ return cam_ori, cam_dir, cam_up
604
+
605
+
606
+ def rand_camera_pose_tour(voxel):
607
+ size = min(voxel.voxel_t.size(1), voxel.voxel_t.size(2)) / 2
608
+ center = [voxel.voxel_t.size(1)/2, voxel.voxel_t.size(2)/2]
609
+
610
+ rnd = torch.rand(8)
611
+
612
+ rnd_deg = torch.rand(1) * 2 * np.pi
613
+ far_radius = rnd[0]*0.8+0.2
614
+ far_height = rnd[1]*30 + 60
615
+ farpoint = torch.tensor([
616
+ far_height,
617
+ torch.sin(rnd_deg)*size*far_radius + center[0],
618
+ torch.cos(rnd_deg)*size*far_radius + center[1]])
619
+
620
+ farpoint[0] = get_neighbor_height(voxel.heightmap, farpoint[1], farpoint[2], farpoint[0], neighbor_size=7)
621
+
622
+ near_radius = far_radius * rnd[2]
623
+ near_shift_rad = np.pi*(rnd[3]-0.5)
624
+ near_height = 60 + rnd[4] * 10
625
+ nearpoint = torch.tensor([
626
+ near_height,
627
+ torch.sin(rnd_deg+near_shift_rad)*size*near_radius + center[0],
628
+ torch.cos(rnd_deg+near_shift_rad)*size*near_radius + center[1]])
629
+
630
+ # Random Up vector (tilt a little bit)
631
+ # up = torch.randn(3) * 0.05 # cutoff +-0.1, Tan(10deg) = 0.176
632
+ up = torch.randn(3) * 0.02
633
+ up[0] = 1.0
634
+ up = up / up.norm(p=2)
635
+ cam_ori = voxel.world2local(farpoint)
636
+ cam_dir = voxel.world2local(nearpoint - farpoint, is_vec=True)
637
+ cam_up = voxel.world2local(up, is_vec=True)
638
+ cam_f = 0.5/np.tan(np.deg2rad(73/2)*(rnd[5]*0.75+0.25)) # about 24mm fov
639
+
640
+ return cam_ori, cam_dir, cam_up, cam_f
641
+
642
+ # Look from center to outward
643
+
644
+
645
+ def rand_camera_pose_insideout(voxel):
646
+ size = min(voxel.voxel_t.size(1), voxel.voxel_t.size(2)) / 2
647
+ center = [voxel.voxel_t.size(1)/2, voxel.voxel_t.size(2)/2]
648
+
649
+ rnd = torch.rand(8)
650
+
651
+ rnd_deg = torch.rand(1) * 2 * np.pi
652
+ far_radius = rnd[0]*0.8+0.2
653
+ far_height = rnd[1]*10 + 60
654
+ farpoint = torch.tensor([
655
+ far_height,
656
+ torch.sin(rnd_deg)*size*far_radius + center[0],
657
+ torch.cos(rnd_deg)*size*far_radius + center[1]])
658
+
659
+ near_radius = far_radius * rnd[2]
660
+ near_shift_rad = np.pi*(rnd[3]-0.5)
661
+ near_height = 60 + rnd[4] * 30
662
+ nearpoint = torch.tensor([
663
+ near_height,
664
+ torch.sin(rnd_deg+near_shift_rad)*size*near_radius + center[0],
665
+ torch.cos(rnd_deg+near_shift_rad)*size*near_radius + center[1]])
666
+
667
+ nearpoint[0] = get_neighbor_height(voxel.heightmap, nearpoint[1], nearpoint[2], nearpoint[0], neighbor_size=7)
668
+
669
+ # Random Up vector (tilt a little bit)
670
+ # up = torch.randn(3) * 0.05 # cutoff +-0.1, Tan(10deg) = 0.176
671
+ up = torch.randn(3) * 0.02
672
+ up[0] = 1.0
673
+ up = up / up.norm(p=2)
674
+ cam_ori = voxel.world2local(nearpoint)
675
+ cam_dir = voxel.world2local(farpoint-nearpoint, is_vec=True)
676
+ cam_up = voxel.world2local(up, is_vec=True)
677
+ cam_f = 0.5/np.tan(np.deg2rad(73/2)*(rnd[5]*0.75+0.25)) # about 24mm fov
678
+
679
+ return cam_ori, cam_dir, cam_up, cam_f
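All of the helpers in camctl.py return poses already mapped into the voxel frame through voxel.world2local, so a renderer can consume them directly. A sketch of how a trajectory from EvalCameraController might be consumed; the voxel object is assumed to expose voxel_t, heightmap and world2local() as used above, and render_frame is a hypothetical placeholder for the project's actual rendering call:

    # `voxel` is assumed to provide voxel_t, heightmap and world2local(), as used above.
    cam_ctl = EvalCameraController(voxel, maxstep=128, pattern=0, cam_ang=73)

    frames = []
    for cam_ori, cam_dir, cam_up, cam_f in cam_ctl:
        # render_frame stands in for the repository's renderer, not a real API here.
        frames.append(render_frame(cam_ori, cam_dir, cam_up, cam_f))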
imaginaire/model_utils/gancraft/gaugan_lbl2col.csv ADDED
@@ -0,0 +1,182 @@
1
+ person,#00AC0D
2
+ bicycle,#012F47
3
+ car,#0275B8
4
+ motorcycle,#03C098
5
+ airplane,#04434F
6
+ bus,#05FB29
7
+ train,#06C312
8
+ truck,#076728
9
+ boat,#0809B6
10
+ traffic-light,#09D3CF
11
+ fire-hydrant,#0A150B
12
+ street-sign,#0BF2A6
13
+ stop-sign,#0C246F
14
+ parking-meter,#0D575D
15
+ bench,#0E46F9
16
+ bird,#0FD881
17
+ cat,#1058DF
18
+ dog,#118C76
19
+ horse,#123A2C
20
+ sheep,#13C1D8
21
+ cow,#14E67D
22
+ elephant,#152718
23
+ bear,#165743
24
+ zebra,#17AED2
25
+ giraffe,#1858EF
26
+ hat,#195103
27
+ backpack,#1AA5EA
28
+ umbrella,#1B19CC
29
+ shoe,#1C4DE6
30
+ eye-glasses,#1D4823
31
+ handbag,#1E09D6
32
+ tie,#1F94FE
33
+ suitcase,#2073BD
34
+ frisbee,#21D0C5
35
+ skis,#22F3D7
36
+ snowboard,#23C52B
37
+ sports-ball,#24FE20
38
+ kite,#254F0B
39
+ baseball-bat,#26AF68
40
+ baseball-glove,#27C0D4
41
+ skateboard,#28528A
42
+ surfboard,#2963B6
43
+ tennis-racket,#2AD8EB
44
+ bottle,#2BB1A5
45
+ plate,#2CF37D
46
+ wine-glass,#2D1D9C
47
+ cup,#2E936F
48
+ fork,#2F93E8
49
+ knife,#308E02
50
+ spoon,#31A71B
51
+ bowl,#3220D3
52
+ banana,#33C1D9
53
+ apple,#340997
54
+ sandwich,#35B935
55
+ orange,#367F33
56
+ broccoli,#3720AE
57
+ carrot,#381F94
58
+ hot-dog,#39CAB5
59
+ pizza,#3AF41D
60
+ donut,#3B9743
61
+ cake,#3CA323
62
+ chair,#3DFE27
63
+ couch,#3ECB89
64
+ potted-plant,#3F7249
65
+ bed,#40B729
66
+ mirror,#411C97
67
+ dining-table,#422283
68
+ window,#43802E
69
+ desk,#4480DA
70
+ toilet,#45A4B2
71
+ door,#46356C
72
+ tv,#478503
73
+ laptop,#48261F
74
+ mouse,#49E809
75
+ remote,#4AF48A
76
+ keyboard,#4B111B
77
+ cell-phone,#4C4FAD
78
+ microwave,#4D84C7
79
+ oven,#4E69A7
80
+ toaster,#4F2A3D
81
+ sink,#50BA55
82
+ refrigerator,#511F61
83
+ blender,#52782C
84
+ book,#530122
85
+ clock,#5441A2
86
+ vase,#55E758
87
+ scissors,#56A921
88
+ teddy-bear,#573985
89
+ hair-drier,#5823E8
90
+ toothbrush,#5966FF
91
+ hair-brush,#5A7724
92
+ banner,#5B0B00
93
+ blanket,#5CAECB
94
+ branch,#5D5222
95
+ bridge,#5E5BC5
96
+ building-other,#5F807E
97
+ bush,#606E32
98
+ cabinet,#6163FE
99
+ cage,#623550
100
+ cardboard,#638CBE
101
+ carpet,#647988
102
+ ceiling-other,#65AABD
103
+ ceiling-tile,#665481
104
+ cloth,#67CBD1
105
+ clothes,#684470
106
+ clouds,#696969
107
+ counter,#6AC478
108
+ cupboard,#6B2F5B
109
+ curtain,#6C7FA8
110
+ desk-stuff,#6DF474
111
+ dirt,#6E6E28
112
+ door-stuff,#6FCCB0
113
+ fence,#706419
114
+ floor-marble,#71B443
115
+ floor-other,#72E867
116
+ floor-stone,#734EFC
117
+ floor-tile,#748F23
118
+ floor-wood,#759472
119
+ flower,#760000
120
+ fog,#77BA1D
121
+ food-other,#7817F1
122
+ fruit,#79CF21
123
+ furniture-other,#7A8D92
124
+ grass,#7BC800
125
+ gravel,#7C32C8
126
+ ground-other,#7D3054
127
+ hill,#7EC864
128
+ house,#7F4502
129
+ leaves,#80A945
130
+ light,#81A365
131
+ mat,#82C08C
132
+ metal,#835F2C
133
+ mirror-stuff,#84C575
134
+ moss,#855EFD
135
+ mountain,#869664
136
+ mud,#87716F
137
+ napkin,#88B25B
138
+ net,#892455
139
+ paper,#8AA2A7
140
+ pavement,#8B3027
141
+ pillow,#8C5DCB
142
+ plant,#8DE61E
143
+ plastic,#8E629E
144
+ platform,#8F2A91
145
+ playingfield,#90CDC6
146
+ railing,#9170C7
147
+ railroad,#92E712
148
+ river,#9364C8
149
+ road,#946E28
150
+ rock,#956432
151
+ roof,#9600B1
152
+ rug,#978A29
153
+ salad,#98725D
154
+ sand,#999900
155
+ sea,#9AC6DA
156
+ shelf,#9B7FC9
157
+ sky,#9CEEDD
158
+ skyscraper,#9DBBF2
159
+ snow,#9E9EAA
160
+ solid-other,#9F79DB
161
+ stairs,#A06249
162
+ stone,#A1A164
163
+ straw,#A2A3EB
164
+ structural,#A3DED1
165
+ table,#A47B69
166
+ tent,#A5C3BA
167
+ textile-other,#A65280
168
+ towel,#A7AED6
169
+ tree,#A8C832
170
+ vegetable,#A99410
171
+ wall-brick,#AAD16A
172
+ wall-concrete,#AB32A4
173
+ wall-other,#AC9B5E
174
+ wall-panel,#AD0E18
175
+ wall-stone,#AE2974
176
+ wall-tile,#AF3ABF
177
+ wall-wood,#B0C1C3
178
+ water,#B1C8FF
179
+ waterdrops,#B20A88
180
+ window-blind,#B356B8
181
+ window-other,#B42B5B
182
+ wood,#B57B00
imaginaire/model_utils/gancraft/gaugan_reduction.csv ADDED
@@ -0,0 +1,182 @@
1
+ person,ignore
2
+ bicycle,ignore
3
+ car,ignore
4
+ motorcycle,ignore
5
+ airplane,ignore
6
+ bus,ignore
7
+ train,ignore
8
+ truck,ignore
9
+ boat,ignore
10
+ traffic-light,ignore
11
+ fire-hydrant,ignore
12
+ street-sign,ignore
13
+ stop-sign,ignore
14
+ parking-meter,ignore
15
+ bench,ignore
16
+ bird,ignore
17
+ cat,ignore
18
+ dog,ignore
19
+ horse,ignore
20
+ sheep,ignore
21
+ cow,ignore
22
+ elephant,ignore
23
+ bear,ignore
24
+ zebra,ignore
25
+ giraffe,ignore
26
+ hat,ignore
27
+ backpack,ignore
28
+ umbrella,ignore
29
+ shoe,ignore
30
+ eye-glasses,ignore
31
+ handbag,ignore
32
+ tie,ignore
33
+ suitcase,ignore
34
+ frisbee,ignore
35
+ skis,ignore
36
+ snowboard,ignore
37
+ sports-ball,ignore
38
+ kite,ignore
39
+ baseball-bat,ignore
40
+ baseball-glove,ignore
41
+ skateboard,ignore
42
+ surfboard,ignore
43
+ tennis-racket,ignore
44
+ bottle,ignore
45
+ plate,ignore
46
+ wine-glass,ignore
47
+ cup,ignore
48
+ fork,ignore
49
+ knife,ignore
50
+ spoon,ignore
51
+ bowl,ignore
52
+ banana,ignore
53
+ apple,ignore
54
+ sandwich,ignore
55
+ orange,ignore
56
+ broccoli,ignore
57
+ carrot,ignore
58
+ hot-dog,ignore
59
+ pizza,ignore
60
+ donut,ignore
61
+ cake,ignore
62
+ chair,ignore
63
+ couch,ignore
64
+ potted-plant,ignore
65
+ bed,ignore
66
+ mirror,ignore
67
+ dining-table,ignore
68
+ window,ignore
69
+ desk,ignore
70
+ toilet,ignore
71
+ door,ignore
72
+ tv,ignore
73
+ laptop,ignore
74
+ mouse,ignore
75
+ remote,ignore
76
+ keyboard,ignore
77
+ cell-phone,ignore
78
+ microwave,ignore
79
+ oven,ignore
80
+ toaster,ignore
81
+ sink,ignore
82
+ refrigerator,ignore
83
+ blender,ignore
84
+ book,ignore
85
+ clock,ignore
86
+ vase,ignore
87
+ scissors,ignore
88
+ teddy-bear,ignore
89
+ hair-drier,ignore
90
+ toothbrush,ignore
91
+ hair-brush,ignore
92
+ banner,ignore
93
+ blanket,ignore
94
+ branch,tree
95
+ bridge,ignore
96
+ building-other,ignore
97
+ bush,tree
98
+ cabinet,ignore
99
+ cage,ignore
100
+ cardboard,ignore
101
+ carpet,ignore
102
+ ceiling-other,ignore
103
+ ceiling-tile,ignore
104
+ cloth,ignore
105
+ clothes,ignore
106
+ clouds,sky
107
+ counter,ignore
108
+ cupboard,ignore
109
+ curtain,ignore
110
+ desk-stuff,ignore
111
+ dirt,dirt
112
+ door-stuff,ignore
113
+ fence,ignore
114
+ floor-marble,ignore
115
+ floor-other,ignore
116
+ floor-stone,ignore
117
+ floor-tile,ignore
118
+ floor-wood,ignore
119
+ flower,flower
120
+ fog,sky
121
+ food-other,ignore
122
+ fruit,ignore
123
+ furniture-other,ignore
124
+ grass,grass
125
+ gravel,gravel
126
+ ground-other,ignore
127
+ hill,grass
128
+ house,ignore
129
+ leaves,tree
130
+ light,ignore
131
+ mat,ignore
132
+ metal,ignore
133
+ mirror-stuff,ignore
134
+ moss,grass
135
+ mountain,grass
136
+ mud,dirt
137
+ napkin,ignore
138
+ net,ignore
139
+ paper,ignore
140
+ pavement,ignore
141
+ pillow,ignore
142
+ plant,flower
143
+ plastic,ignore
144
+ platform,ignore
145
+ playingfield,ignore
146
+ railing,ignore
147
+ railroad,ignore
148
+ river,water
149
+ road,ignore
150
+ rock,rock
151
+ roof,ignore
152
+ rug,ignore
153
+ salad,ignore
154
+ sand,sand
155
+ sea,water
156
+ shelf,ignore
157
+ sky,sky
158
+ skyscraper,ignore
159
+ snow,snow
160
+ solid-other,ignore
161
+ stairs,ignore
162
+ stone,stone
163
+ straw,grass
164
+ structural,ignore
165
+ table,ignore
166
+ tent,ignore
167
+ textile-other,ignore
168
+ towel,ignore
169
+ tree,tree
170
+ vegetable,ignore
171
+ wall-brick,ignore
172
+ wall-concrete,ignore
173
+ wall-other,ignore
174
+ wall-panel,ignore
175
+ wall-stone,ignore
176
+ wall-tile,ignore
177
+ wall-wood,ignore
178
+ water,water
179
+ waterdrops,ignore
180
+ window-blind,ignore
181
+ window-other,ignore
182
+ wood,ignore
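gaugan_lbl2col.csv maps each GauGAN class name to its palette colour, and gaugan_reduction.csv maps the same names onto the reduced label set used here ('ignore' drops a class). A small sketch of how the two files could be joined into a reduced-label to colour lookup; the paths and the exact consuming code in the repository are assumptions:

    import csv

    def load_reduced_palette(lbl2col_path, reduction_path):
        # name -> hex colour, straight from gaugan_lbl2col.csv (no header row).
        with open(lbl2col_path) as f:
            lbl2col = {name: col for name, col in csv.reader(f)}
        # reduced label -> list of member-class colours, skipping ignored classes.
        palette = {}
        with open(reduction_path) as f:
            for name, reduced in csv.reader(f):
                if reduced != 'ignore':
                    palette.setdefault(reduced, []).append(lbl2col[name])
        return palette

    palette = load_reduced_palette('gaugan_lbl2col.csv', 'gaugan_reduction.csv')
    print(palette['water'])  # ['#9364C8', '#9AC6DA', '#B1C8FF'] (river, sea, water)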
imaginaire/model_utils/gancraft/id2name_gg.csv ADDED
@@ -0,0 +1,680 @@
1
+ 0,air,0,sky
2
+ 1,stone,7368816,stone
3
+ 2,granite,7368816,rock
4
+ 3,polished_granite,7368816,rock
5
+ 4,diorite,7368816,rock
6
+ 5,polished_diorite,7368816,rock
7
+ 6,andesite,7368816,rock
8
+ 7,polished_andesite,7368816,rock
9
+ 8,grass_block,8368696,grass
10
+ 9,dirt,9923917,dirt
11
+ 10,coarse_dirt,9923917,dirt
12
+ 11,podzol,9923917,dirt
13
+ 12,cobblestone,7368816,stone
14
+ 13,oak_planks,9402184,wood
15
+ 14,spruce_planks,9402184,wood
16
+ 15,birch_planks,9402184,wood
17
+ 16,jungle_planks,9402184,wood
18
+ 17,acacia_planks,9402184,wood
19
+ 18,dark_oak_planks,9402184,wood
20
+ 19,oak_sapling,31744,plant
21
+ 20,spruce_sapling,31744,plant
22
+ 21,birch_sapling,31744,plant
23
+ 22,jungle_sapling,31744,plant
24
+ 23,acacia_sapling,31744,plant
25
+ 24,dark_oak_sapling,31744,plant
26
+ 25,bedrock,7368816,rock
27
+ 26,water,4210943,water
28
+ 27,lava,16711680,
29
+ 28,sand,16247203,sand
30
+ 29,red_sand,16247203,sand
31
+ 30,gravel,16247203,gravel
32
+ 31,gold_ore,7368816,rock
33
+ 32,iron_ore,7368816,rock
34
+ 33,coal_ore,7368816,rock
35
+ 34,oak_log,9402184,tree
36
+ 35,spruce_log,9402184,tree
37
+ 36,birch_log,9402184,tree
38
+ 37,jungle_log,9402184,tree
39
+ 38,acacia_log,9402184,tree
40
+ 39,dark_oak_log,9402184,tree
41
+ 40,stripped_spruce_log,9402184,wood
42
+ 41,stripped_birch_log,9402184,wood
43
+ 42,stripped_jungle_log,9402184,wood
44
+ 43,stripped_acacia_log,9402184,wood
45
+ 44,stripped_dark_oak_log,9402184,wood
46
+ 45,stripped_oak_log,9402184,wood
47
+ 46,oak_wood,9402184,wood
48
+ 47,spruce_wood,9402184,wood
49
+ 48,birch_wood,9402184,wood
50
+ 49,jungle_wood,9402184,wood
51
+ 50,acacia_wood,9402184,wood
52
+ 51,dark_oak_wood,9402184,wood
53
+ 52,stripped_oak_wood,9402184,wood
54
+ 53,stripped_spruce_wood,9402184,wood
55
+ 54,stripped_birch_wood,9402184,wood
56
+ 55,stripped_jungle_wood,9402184,wood
57
+ 56,stripped_acacia_wood,9402184,wood
58
+ 57,stripped_dark_oak_wood,9402184,wood
59
+ 58,oak_leaves,31744,tree
60
+ 59,spruce_leaves,31744,tree
61
+ 60,birch_leaves,31744,tree
62
+ 61,jungle_leaves,31744,tree
63
+ 62,acacia_leaves,31744,tree
64
+ 63,dark_oak_leaves,31744,tree
65
+ 64,sponge,15066419,
66
+ 65,wet_sponge,15066419,
67
+ 66,glass,0,
68
+ 67,lapis_ore,7368816,
69
+ 68,lapis_block,10987431,
70
+ 69,dispenser,7368816,
71
+ 70,sandstone,7368816,sand
72
+ 71,chiseled_sandstone,7368816,sand
73
+ 72,cut_sandstone,7368816,sand
74
+ 73,note_block,9402184,
75
+ 74,white_bed,13092807,
76
+ 75,orange_bed,13092807,
77
+ 76,magenta_bed,13092807,
78
+ 77,light_blue_bed,13092807,
79
+ 78,yellow_bed,13092807,
80
+ 79,lime_bed,13092807,
81
+ 80,pink_bed,13092807,
82
+ 81,gray_bed,13092807,
83
+ 82,light_gray_bed,13092807,
84
+ 83,cyan_bed,13092807,
85
+ 84,purple_bed,13092807,
86
+ 85,blue_bed,13092807,
87
+ 86,brown_bed,13092807,
88
+ 87,green_bed,13092807,
89
+ 88,red_bed,13092807,
90
+ 89,black_bed,13092807,
91
+ 90,powered_rail,0,
92
+ 91,detector_rail,0,
93
+ 92,sticky_piston,7368816,
94
+ 93,cobweb,13092807,
95
+ 94,grass,31744,grass
96
+ 95,fern,31744,grass
97
+ 96,dead_bush,31744,grass
98
+ 97,seagrass,4210943,water
99
+ 98,tall_seagrass,4210943,water
100
+ 99,piston,7368816,
101
+ 100,piston_head,7368816,
102
+ 101,white_wool,13092807,
103
+ 102,orange_wool,13092807,
104
+ 103,magenta_wool,13092807,
105
+ 104,light_blue_wool,13092807,
106
+ 105,yellow_wool,13092807,
107
+ 106,lime_wool,13092807,
108
+ 107,pink_wool,13092807,
109
+ 108,gray_wool,13092807,
110
+ 109,light_gray_wool,13092807,
111
+ 110,cyan_wool,13092807,
112
+ 111,purple_wool,13092807,
113
+ 112,blue_wool,13092807,
114
+ 113,brown_wool,13092807,
115
+ 114,green_wool,13092807,
116
+ 115,red_wool,13092807,
117
+ 116,black_wool,13092807,
118
+ 117,moving_piston,7368816,
119
+ 118,dandelion,31744,flower
120
+ 119,poppy,31744,flower
121
+ 120,blue_orchid,31744,flower
122
+ 121,allium,31744,flower
123
+ 122,azure_bluet,31744,flower
124
+ 123,red_tulip,31744,flower
125
+ 124,orange_tulip,31744,flower
126
+ 125,white_tulip,31744,flower
127
+ 126,pink_tulip,31744,flower
128
+ 127,oxeye_daisy,31744,flower
129
+ 128,cornflower,31744,flower
130
+ 129,wither_rose,31744,flower
131
+ 130,lily_of_the_valley,31744,flower
132
+ 131,brown_mushroom,31744,flower
133
+ 132,red_mushroom,31744,flower
134
+ 133,gold_block,10987431,
135
+ 134,iron_block,10987431,
136
+ 135,bricks,7368816,
137
+ 136,tnt,16711680,
138
+ 137,bookshelf,9402184,
139
+ 138,mossy_cobblestone,7368816,
140
+ 139,obsidian,7368816,
141
+ 140,torch,0,
142
+ 141,wall_torch,0,
143
+ 142,fire,0,
144
+ 143,spawner,7368816,
145
+ 144,oak_stairs,9402184,
146
+ 145,chest,9402184,
147
+ 146,redstone_wire,0,
148
+ 147,diamond_ore,7368816,
149
+ 148,diamond_block,10987431,
150
+ 149,crafting_table,9402184,
151
+ 150,wheat,31744,
152
+ 151,farmland,9923917,
153
+ 152,furnace,7368816,
154
+ 153,oak_sign,9402184,
155
+ 154,spruce_sign,9402184,
156
+ 155,birch_sign,9402184,
157
+ 156,acacia_sign,9402184,
158
+ 157,jungle_sign,9402184,
159
+ 158,dark_oak_sign,9402184,
160
+ 159,oak_door,9402184,
161
+ 160,ladder,0,
162
+ 161,rail,0,
163
+ 162,cobblestone_stairs,7368816,
164
+ 163,oak_wall_sign,9402184,
165
+ 164,spruce_wall_sign,9402184,
166
+ 165,birch_wall_sign,9402184,
167
+ 166,acacia_wall_sign,9402184,
168
+ 167,jungle_wall_sign,9402184,
169
+ 168,dark_oak_wall_sign,9402184,
170
+ 169,lever,0,
171
+ 170,stone_pressure_plate,7368816,
172
+ 171,iron_door,10987431,
173
+ 172,oak_pressure_plate,9402184,
174
+ 173,spruce_pressure_plate,9402184,
175
+ 174,birch_pressure_plate,9402184,
176
+ 175,jungle_pressure_plate,9402184,
177
+ 176,acacia_pressure_plate,9402184,
178
+ 177,dark_oak_pressure_plate,9402184,
179
+ 178,redstone_ore,7368816,
180
+ 179,redstone_torch,0,
181
+ 180,redstone_wall_torch,0,
182
+ 181,stone_button,0,
183
+ 182,snow,16777215,snow
184
+ 183,ice,10526975,snow
185
+ 184,snow_block,16777215,snow
186
+ 185,cactus,31744,plant
187
+ 186,clay,10791096,
188
+ 187,sugar_cane,31744,plant
189
+ 188,jukebox,9402184,
190
+ 189,oak_fence,9402184,
191
+ 190,pumpkin,31744,
192
+ 191,netherrack,7368816,
193
+ 192,soul_sand,16247203,
194
+ 193,glowstone,0,
195
+ 194,nether_portal,0,
196
+ 195,carved_pumpkin,31744,
197
+ 196,jack_o_lantern,31744,
198
+ 197,cake,0,
199
+ 198,repeater,0,
200
+ 199,white_stained_glass,0,
201
+ 200,orange_stained_glass,0,
202
+ 201,magenta_stained_glass,0,
203
+ 202,light_blue_stained_glass,0,
204
+ 203,yellow_stained_glass,0,
205
+ 204,lime_stained_glass,0,
206
+ 205,pink_stained_glass,0,
207
+ 206,gray_stained_glass,0,
208
+ 207,light_gray_stained_glass,0,
209
+ 208,cyan_stained_glass,0,
210
+ 209,purple_stained_glass,0,
211
+ 210,blue_stained_glass,0,
212
+ 211,brown_stained_glass,0,
213
+ 212,green_stained_glass,0,
214
+ 213,red_stained_glass,0,
215
+ 214,black_stained_glass,0,
216
+ 215,oak_trapdoor,9402184,
217
+ 216,spruce_trapdoor,9402184,
218
+ 217,birch_trapdoor,9402184,
219
+ 218,jungle_trapdoor,9402184,
220
+ 219,acacia_trapdoor,9402184,
221
+ 220,dark_oak_trapdoor,9402184,
222
+ 221,stone_bricks,7368816,
223
+ 222,mossy_stone_bricks,7368816,
224
+ 223,cracked_stone_bricks,7368816,
225
+ 224,chiseled_stone_bricks,7368816,
226
+ 225,infested_stone,10791096,
227
+ 226,infested_cobblestone,10791096,
228
+ 227,infested_stone_bricks,10791096,
229
+ 228,infested_mossy_stone_bricks,10791096,
230
+ 229,infested_cracked_stone_bricks,10791096,
231
+ 230,infested_chiseled_stone_bricks,10791096,
232
+ 231,brown_mushroom_block,9402184,tree
233
+ 232,red_mushroom_block,9402184,tree
234
+ 233,mushroom_stem,9402184,tree
235
+ 234,iron_bars,10987431,
236
+ 235,glass_pane,0,
237
+ 236,melon,31744,
238
+ 237,attached_pumpkin_stem,31744,
239
+ 238,attached_melon_stem,31744,
240
+ 239,pumpkin_stem,31744,
241
+ 240,melon_stem,31744,
242
+ 241,vine,31744,plant
243
+ 242,oak_fence_gate,9402184,
244
+ 243,brick_stairs,7368816,
245
+ 244,stone_brick_stairs,7368816,
246
+ 245,mycelium,8368696,
247
+ 246,lily_pad,31744,grass
248
+ 247,nether_bricks,7368816,
249
+ 248,nether_brick_fence,7368816,
250
+ 249,nether_brick_stairs,7368816,
251
+ 250,nether_wart,31744,
252
+ 251,enchanting_table,7368816,
253
+ 252,brewing_stand,10987431,
254
+ 253,cauldron,10987431,
255
+ 254,end_portal,0,
256
+ 255,end_portal_frame,7368816,
257
+ 256,end_stone,7368816,
258
+ 257,dragon_egg,31744,
259
+ 258,redstone_lamp,0,
260
+ 259,cocoa,31744,
261
+ 260,sandstone_stairs,7368816,
262
+ 261,emerald_ore,7368816,
263
+ 262,ender_chest,7368816,
264
+ 263,tripwire_hook,0,
265
+ 264,tripwire,0,
266
+ 265,emerald_block,10987431,
267
+ 266,spruce_stairs,9402184,
268
+ 267,birch_stairs,9402184,
269
+ 268,jungle_stairs,9402184,
270
+ 269,command_block,10987431,
271
+ 270,beacon,0,
272
+ 271,cobblestone_wall,7368816,
273
+ 272,mossy_cobblestone_wall,7368816,
274
+ 273,flower_pot,0,
275
+ 274,potted_oak_sapling,0,
276
+ 275,potted_spruce_sapling,0,
277
+ 276,potted_birch_sapling,0,
278
+ 277,potted_jungle_sapling,0,
279
+ 278,potted_acacia_sapling,0,
280
+ 279,potted_dark_oak_sapling,0,
281
+ 280,potted_fern,0,
282
+ 281,potted_dandelion,0,
283
+ 282,potted_poppy,0,
284
+ 283,potted_blue_orchid,0,
285
+ 284,potted_allium,0,
286
+ 285,potted_azure_bluet,0,
287
+ 286,potted_red_tulip,0,
288
+ 287,potted_orange_tulip,0,
289
+ 288,potted_white_tulip,0,
290
+ 289,potted_pink_tulip,0,
291
+ 290,potted_oxeye_daisy,0,
292
+ 291,potted_cornflower,0,
293
+ 292,potted_lily_of_the_valley,0,
294
+ 293,potted_wither_rose,0,
295
+ 294,potted_red_mushroom,0,
296
+ 295,potted_brown_mushroom,0,
297
+ 296,potted_dead_bush,0,
298
+ 297,potted_cactus,0,
299
+ 298,carrots,31744,
300
+ 299,potatoes,31744,
301
+ 300,oak_button,0,
302
+ 301,spruce_button,0,
303
+ 302,birch_button,0,
304
+ 303,jungle_button,0,
305
+ 304,acacia_button,0,
+ 305,dark_oak_button,0,
+ 306,skeleton_skull,0,
+ 307,skeleton_wall_skull,0,
+ 308,wither_skeleton_skull,0,
+ 309,wither_skeleton_wall_skull,0,
+ 310,zombie_head,0,
+ 311,zombie_wall_head,0,
+ 312,player_head,0,
+ 313,player_wall_head,0,
+ 314,creeper_head,0,
+ 315,creeper_wall_head,0,
+ 316,dragon_head,0,
+ 317,dragon_wall_head,0,
+ 318,anvil,10987431,
+ 319,chipped_anvil,10987431,
+ 320,damaged_anvil,10987431,
+ 321,trapped_chest,9402184,
+ 322,light_weighted_pressure_plate,10987431,
+ 323,heavy_weighted_pressure_plate,10987431,
+ 324,comparator,0,
+ 325,daylight_detector,9402184,
+ 326,redstone_block,10987431,
+ 327,nether_quartz_ore,7368816,
+ 328,hopper,10987431,
+ 329,quartz_block,7368816,
+ 330,chiseled_quartz_block,7368816,
+ 331,quartz_pillar,7368816,
+ 332,quartz_stairs,7368816,
+ 333,activator_rail,0,
+ 334,dropper,7368816,
+ 335,white_terracotta,7368816,
+ 336,orange_terracotta,7368816,
+ 337,magenta_terracotta,7368816,
+ 338,light_blue_terracotta,7368816,
+ 339,yellow_terracotta,7368816,
+ 340,lime_terracotta,7368816,
+ 341,pink_terracotta,7368816,
+ 342,gray_terracotta,7368816,
+ 343,light_gray_terracotta,7368816,
+ 344,cyan_terracotta,7368816,
+ 345,purple_terracotta,7368816,
+ 346,blue_terracotta,7368816,
+ 347,brown_terracotta,7368816,
+ 348,green_terracotta,7368816,
+ 349,red_terracotta,7368816,
+ 350,black_terracotta,7368816,
+ 351,white_stained_glass_pane,0,
+ 352,orange_stained_glass_pane,0,
+ 353,magenta_stained_glass_pane,0,
+ 354,light_blue_stained_glass_pane,0,
+ 355,yellow_stained_glass_pane,0,
+ 356,lime_stained_glass_pane,0,
+ 357,pink_stained_glass_pane,0,
+ 358,gray_stained_glass_pane,0,
+ 359,light_gray_stained_glass_pane,0,
+ 360,cyan_stained_glass_pane,0,
+ 361,purple_stained_glass_pane,0,
+ 362,blue_stained_glass_pane,0,
+ 363,brown_stained_glass_pane,0,
+ 364,green_stained_glass_pane,0,
+ 365,red_stained_glass_pane,0,
+ 366,black_stained_glass_pane,0,
+ 367,acacia_stairs,9402184,
+ 368,dark_oak_stairs,9402184,
+ 369,slime_block,10791096,
+ 370,barrier,0,
+ 371,iron_trapdoor,10987431,
+ 372,prismarine,7368816,
+ 373,prismarine_bricks,7368816,
+ 374,dark_prismarine,7368816,
+ 375,prismarine_stairs,7368816,
+ 376,prismarine_brick_stairs,7368816,
+ 377,dark_prismarine_stairs,7368816,
+ 378,prismarine_slab,7368816,
+ 379,prismarine_brick_slab,7368816,
+ 380,dark_prismarine_slab,7368816,
+ 381,sea_lantern,0,
+ 382,hay_block,8368696,
+ 383,white_carpet,13092807,
+ 384,orange_carpet,13092807,
+ 385,magenta_carpet,13092807,
+ 386,light_blue_carpet,13092807,
+ 387,yellow_carpet,13092807,
+ 388,lime_carpet,13092807,
+ 389,pink_carpet,13092807,
+ 390,gray_carpet,13092807,
+ 391,light_gray_carpet,13092807,
+ 392,cyan_carpet,13092807,
+ 393,purple_carpet,13092807,
+ 394,blue_carpet,13092807,
+ 395,brown_carpet,13092807,
+ 396,green_carpet,13092807,
+ 397,red_carpet,13092807,
+ 398,black_carpet,13092807,
+ 399,terracotta,7368816,
+ 400,coal_block,7368816,
+ 401,packed_ice,10526975,
+ 402,sunflower,31744,flower
+ 403,lilac,31744,flower
+ 404,rose_bush,31744,flower
+ 405,peony,31744,flower
+ 406,tall_grass,31744,plant
+ 407,large_fern,31744,plant
+ 408,white_banner,9402184,
+ 409,orange_banner,9402184,
+ 410,magenta_banner,9402184,
+ 411,light_blue_banner,9402184,
+ 412,yellow_banner,9402184,
+ 413,lime_banner,9402184,
+ 414,pink_banner,9402184,
+ 415,gray_banner,9402184,
+ 416,light_gray_banner,9402184,
+ 417,cyan_banner,9402184,
+ 418,purple_banner,9402184,
+ 419,blue_banner,9402184,
+ 420,brown_banner,9402184,
+ 421,green_banner,9402184,
+ 422,red_banner,9402184,
+ 423,black_banner,9402184,
+ 424,white_wall_banner,9402184,
+ 425,orange_wall_banner,9402184,
+ 426,magenta_wall_banner,9402184,
+ 427,light_blue_wall_banner,9402184,
+ 428,yellow_wall_banner,9402184,
+ 429,lime_wall_banner,9402184,
+ 430,pink_wall_banner,9402184,
+ 431,gray_wall_banner,9402184,
+ 432,light_gray_wall_banner,9402184,
+ 433,cyan_wall_banner,9402184,
+ 434,purple_wall_banner,9402184,
+ 435,blue_wall_banner,9402184,
+ 436,brown_wall_banner,9402184,
+ 437,green_wall_banner,9402184,
+ 438,red_wall_banner,9402184,
+ 439,black_wall_banner,9402184,
+ 440,red_sandstone,7368816,
+ 441,chiseled_red_sandstone,7368816,
+ 442,cut_red_sandstone,7368816,
+ 443,red_sandstone_stairs,7368816,
+ 444,oak_slab,9402184,
+ 445,spruce_slab,9402184,
+ 446,birch_slab,9402184,
+ 447,jungle_slab,9402184,
+ 448,acacia_slab,9402184,
+ 449,dark_oak_slab,9402184,
+ 450,stone_slab,7368816,
+ 451,smooth_stone_slab,7368816,
+ 452,sandstone_slab,7368816,
+ 453,cut_sandstone_slab,7368816,
+ 454,petrified_oak_slab,7368816,
+ 455,cobblestone_slab,7368816,
+ 456,brick_slab,7368816,
+ 457,stone_brick_slab,7368816,
+ 458,nether_brick_slab,7368816,
+ 459,quartz_slab,7368816,
+ 460,red_sandstone_slab,7368816,
+ 461,cut_red_sandstone_slab,7368816,
+ 462,purpur_slab,7368816,
+ 463,smooth_stone,7368816,
+ 464,smooth_sandstone,7368816,
+ 465,smooth_quartz,7368816,
+ 466,smooth_red_sandstone,7368816,
+ 467,spruce_fence_gate,9402184,
+ 468,birch_fence_gate,9402184,
+ 469,jungle_fence_gate,9402184,
+ 470,acacia_fence_gate,9402184,
+ 471,dark_oak_fence_gate,9402184,
+ 472,spruce_fence,9402184,
+ 473,birch_fence,9402184,
+ 474,jungle_fence,9402184,
+ 475,acacia_fence,9402184,
+ 476,dark_oak_fence,9402184,
+ 477,spruce_door,9402184,
+ 478,birch_door,9402184,
+ 479,jungle_door,9402184,
+ 480,acacia_door,9402184,
+ 481,dark_oak_door,9402184,
+ 482,end_rod,0,
+ 483,chorus_plant,31744,
+ 484,chorus_flower,31744,
+ 485,purpur_block,7368816,
+ 486,purpur_pillar,7368816,
+ 487,purpur_stairs,7368816,
+ 488,end_stone_bricks,7368816,
+ 489,beetroots,31744,
+ 490,grass_path,9923917,
+ 491,end_gateway,0,
+ 492,repeating_command_block,10987431,
+ 493,chain_command_block,10987431,
+ 494,frosted_ice,10526975,
+ 495,magma_block,7368816,
+ 496,nether_wart_block,8368696,
+ 497,red_nether_bricks,7368816,
+ 498,bone_block,7368816,
+ 499,structure_void,0,
+ 500,observer,7368816,
+ 501,shulker_box,8339378,
+ 502,white_shulker_box,8339378,
+ 503,orange_shulker_box,8339378,
+ 504,magenta_shulker_box,8339378,
+ 505,light_blue_shulker_box,8339378,
+ 506,yellow_shulker_box,8339378,
+ 507,lime_shulker_box,8339378,
+ 508,pink_shulker_box,8339378,
+ 509,gray_shulker_box,8339378,
+ 510,light_gray_shulker_box,8339378,
+ 511,cyan_shulker_box,8339378,
+ 512,purple_shulker_box,8339378,
+ 513,blue_shulker_box,8339378,
+ 514,brown_shulker_box,8339378,
+ 515,green_shulker_box,8339378,
+ 516,red_shulker_box,8339378,
+ 517,black_shulker_box,8339378,
+ 518,white_glazed_terracotta,7368816,
+ 519,orange_glazed_terracotta,7368816,
+ 520,magenta_glazed_terracotta,7368816,
+ 521,light_blue_glazed_terracotta,7368816,
+ 522,yellow_glazed_terracotta,7368816,
+ 523,lime_glazed_terracotta,7368816,
+ 524,pink_glazed_terracotta,7368816,
+ 525,gray_glazed_terracotta,7368816,
+ 526,light_gray_glazed_terracotta,7368816,
+ 527,cyan_glazed_terracotta,7368816,
+ 528,purple_glazed_terracotta,7368816,
+ 529,blue_glazed_terracotta,7368816,
+ 530,brown_glazed_terracotta,7368816,
+ 531,green_glazed_terracotta,7368816,
+ 532,red_glazed_terracotta,7368816,
+ 533,black_glazed_terracotta,7368816,
+ 534,white_concrete,7368816,
+ 535,orange_concrete,7368816,
+ 536,magenta_concrete,7368816,
+ 537,light_blue_concrete,7368816,
+ 538,yellow_concrete,7368816,
+ 539,lime_concrete,7368816,
+ 540,pink_concrete,7368816,
+ 541,gray_concrete,7368816,
+ 542,light_gray_concrete,7368816,
+ 543,cyan_concrete,7368816,
+ 544,purple_concrete,7368816,
+ 545,blue_concrete,7368816,
+ 546,brown_concrete,7368816,
+ 547,green_concrete,7368816,
+ 548,red_concrete,7368816,
+ 549,black_concrete,7368816,
+ 550,white_concrete_powder,16247203,
+ 551,orange_concrete_powder,16247203,
+ 552,magenta_concrete_powder,16247203,
+ 553,light_blue_concrete_powder,16247203,
+ 554,yellow_concrete_powder,16247203,
+ 555,lime_concrete_powder,16247203,
+ 556,pink_concrete_powder,16247203,
+ 557,gray_concrete_powder,16247203,
+ 558,light_gray_concrete_powder,16247203,
+ 559,cyan_concrete_powder,16247203,
+ 560,purple_concrete_powder,16247203,
+ 561,blue_concrete_powder,16247203,
+ 562,brown_concrete_powder,16247203,
+ 563,green_concrete_powder,16247203,
+ 564,red_concrete_powder,16247203,
+ 565,black_concrete_powder,16247203,
+ 566,kelp,4210943,
+ 567,kelp_plant,4210943,
+ 568,dried_kelp_block,8368696,
+ 569,turtle_egg,31744,
+ 570,dead_tube_coral_block,7368816,
+ 571,dead_brain_coral_block,7368816,
+ 572,dead_bubble_coral_block,7368816,
+ 573,dead_fire_coral_block,7368816,
+ 574,dead_horn_coral_block,7368816,
+ 575,tube_coral_block,7368816,
+ 576,brain_coral_block,7368816,
+ 577,bubble_coral_block,7368816,
+ 578,fire_coral_block,7368816,
+ 579,horn_coral_block,7368816,
+ 580,dead_tube_coral,7368816,
+ 581,dead_brain_coral,7368816,
+ 582,dead_bubble_coral,7368816,
+ 583,dead_fire_coral,7368816,
+ 584,dead_horn_coral,7368816,
+ 585,tube_coral,4210943,
+ 586,brain_coral,4210943,
+ 587,bubble_coral,4210943,
+ 588,fire_coral,4210943,
+ 589,horn_coral,4210943,
+ 590,dead_tube_coral_fan,7368816,
+ 591,dead_brain_coral_fan,7368816,
+ 592,dead_bubble_coral_fan,7368816,
+ 593,dead_fire_coral_fan,7368816,
+ 594,dead_horn_coral_fan,7368816,
+ 595,tube_coral_fan,4210943,
+ 596,brain_coral_fan,4210943,
+ 597,bubble_coral_fan,4210943,
+ 598,fire_coral_fan,4210943,
+ 599,horn_coral_fan,4210943,
+ 600,dead_tube_coral_wall_fan,7368816,
+ 601,dead_brain_coral_wall_fan,7368816,
+ 602,dead_bubble_coral_wall_fan,7368816,
+ 603,dead_fire_coral_wall_fan,7368816,
+ 604,dead_horn_coral_wall_fan,7368816,
+ 605,tube_coral_wall_fan,4210943,
+ 606,brain_coral_wall_fan,4210943,
+ 607,bubble_coral_wall_fan,4210943,
+ 608,fire_coral_wall_fan,4210943,
+ 609,horn_coral_wall_fan,4210943,
+ 610,sea_pickle,4210943,
+ 611,blue_ice,10526975,
+ 612,conduit,0,
+ 613,bamboo_sapling,9402184,plant
+ 614,bamboo,9402184,plant
+ 615,potted_bamboo,0,
+ 616,void_air,0,dirt
+ 617,cave_air,0,dirt
+ 618,bubble_column,4210943,
+ 619,polished_granite_stairs,7368816,
+ 620,smooth_red_sandstone_stairs,7368816,
+ 621,mossy_stone_brick_stairs,7368816,
+ 622,polished_diorite_stairs,7368816,
+ 623,mossy_cobblestone_stairs,7368816,
+ 624,end_stone_brick_stairs,7368816,
+ 625,stone_stairs,7368816,
+ 626,smooth_sandstone_stairs,7368816,
+ 627,smooth_quartz_stairs,7368816,
+ 628,granite_stairs,7368816,
+ 629,andesite_stairs,7368816,
+ 630,red_nether_brick_stairs,7368816,
+ 631,polished_andesite_stairs,7368816,
+ 632,diorite_stairs,7368816,
+ 633,polished_granite_slab,7368816,
+ 634,smooth_red_sandstone_slab,7368816,
+ 635,mossy_stone_brick_slab,7368816,
+ 636,polished_diorite_slab,7368816,
+ 637,mossy_cobblestone_slab,7368816,
+ 638,end_stone_brick_slab,7368816,
+ 639,smooth_sandstone_slab,7368816,
+ 640,smooth_quartz_slab,7368816,
+ 641,granite_slab,7368816,
+ 642,andesite_slab,7368816,
+ 643,red_nether_brick_slab,7368816,
+ 644,polished_andesite_slab,7368816,
+ 645,diorite_slab,7368816,
+ 646,brick_wall,7368816,
+ 647,prismarine_wall,7368816,
+ 648,red_sandstone_wall,7368816,
+ 649,mossy_stone_brick_wall,7368816,
+ 650,granite_wall,7368816,
+ 651,stone_brick_wall,7368816,
+ 652,nether_brick_wall,7368816,
+ 653,andesite_wall,7368816,
+ 654,red_nether_brick_wall,7368816,
+ 655,sandstone_wall,7368816,
+ 656,end_stone_brick_wall,7368816,
+ 657,diorite_wall,7368816,
+ 658,scaffolding,0,
+ 659,loom,9402184,
+ 660,barrel,9402184,
+ 661,smoker,7368816,
+ 662,blast_furnace,7368816,
+ 663,cartography_table,9402184,
+ 664,fletching_table,9402184,
+ 665,grindstone,10987431,
+ 666,lectern,9402184,
+ 667,smithing_table,9402184,
+ 668,stonecutter,7368816,
+ 669,bell,10987431,
+ 670,lantern,10987431,
+ 671,campfire,9402184,
+ 672,sweet_berry_bush,31744,
+ 673,structure_block,10987431,
+ 674,jigsaw,10987431,
+ 675,composter,9402184,
+ 676,bee_nest,9402184,
+ 677,beehive,9402184,
+ 678,honey_block,10791096,
+ 679,honeycomb_block,10791096,
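
The rows above are the tail of imaginaire/model_utils/gancraft/id2name_gg.csv: each entry is id,name,color[,category], where the third column packs an RGB color into one 24-bit integer (for example 7368816 decodes to the gray (112, 112, 112)) and the optional fourth column tags vegetation/terrain classes such as flower, plant, or dirt. A minimal loader sketch follows; the helper name and the column handling are illustrative assumptions, not code from this repository.

import csv

def load_id2color(csv_path='imaginaire/model_utils/gancraft/id2name_gg.csv'):
    # Hypothetical helper: parse 'id,name,color[,category]' rows into {id: (r, g, b)}.
    id2color = {}
    with open(csv_path, newline='') as f:
        for row in csv.reader(f):
            if not row or not row[0].isdigit():
                continue  # skip a header row or blank line, if present
            packed = int(row[2])
            # Unpack the 24-bit integer into an (R, G, B) triple.
            id2color[int(row[0])] = ((packed >> 16) & 255, (packed >> 8) & 255, packed & 255)
    return id2color
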
imaginaire/model_utils/gancraft/loss.py ADDED
@@ -0,0 +1,96 @@
+ # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ #
+ # This work is made available under the Nvidia Source Code License-NC.
+ # To view a copy of this license, check out LICENSE.md
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class GANLoss(nn.Module):
+     def __init__(self, target_real_label=1.0, target_fake_label=0.0):
+         r"""GAN loss constructor.
+
+         Args:
+             target_real_label (float): Desired output label for the real images.
+             target_fake_label (float): Desired output label for the fake images.
+         """
+         super(GANLoss, self).__init__()
+         self.real_label = target_real_label
+         self.fake_label = target_fake_label
+         self.real_label_tensor = None
+         self.fake_label_tensor = None
+
+     def forward(self, input_x, t_real, weight=None,
+                 reduce_dim=True, dis_update=True):
+         r"""GAN loss computation.
+
+         Args:
+             input_x (tensor or list of tensors): Output values.
+             t_real (boolean): Is this output value for real images.
+             reduce_dim (boolean): Whether we reduce the dimensions first. This makes a difference when we use
+                 multi-resolution discriminators.
+             weight (float): Weight to scale the loss value.
+             dis_update (boolean): Updating the discriminator or the generator.
+         Returns:
+             loss (tensor): Loss value.
+         """
+         if isinstance(input_x, list):
+             loss = 0
+             for pred_i in input_x:
+                 if isinstance(pred_i, list):
+                     pred_i = pred_i[-1]
+                 loss_tensor = self.loss(pred_i, t_real, weight,
+                                         reduce_dim, dis_update)
+                 bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0)
+                 new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1)
+                 loss += new_loss
+             return loss / len(input_x)
+         else:
+             return self.loss(input_x, t_real, weight, reduce_dim, dis_update)
+
+     def loss(self, input_x, t_real, weight=None,
+              reduce_dim=True, dis_update=True):
+         r"""N+1 label GAN loss computation.
+
+         Args:
+             input_x (tensor): Output values.
+             t_real (boolean): Is this output value for real images.
+             reduce_dim (boolean): Whether we reduce the dimensions first. This makes a difference when we use
+                 multi-resolution discriminators.
+             weight (float): Weight to scale the loss value.
+             dis_update (boolean): Updating the discriminator or the generator.
+         Returns:
+             loss (tensor): Loss value.
+         """
+         assert reduce_dim is True
+         pred = input_x['pred'].clone()
+         label = input_x['label'].clone()
+         batch_size = pred.size(0)
+
+         # ignore label 0
+         label[:, 0, ...] = 0
+         pred[:, 0, ...] = 0
+         pred = F.log_softmax(pred, dim=1)
+         assert pred.size(1) == (label.size(1) + 1)
+         if dis_update:
+             if t_real:
+                 pred_real = pred[:, :-1, :, :]
+                 loss = - label * pred_real
+                 loss = torch.sum(loss, dim=1, keepdim=True)
+             else:
+                 pred_fake = pred[:, -1, None, :, :]  # N plus 1
+                 loss = - pred_fake
+         else:
+             assert t_real, "GAN loss must be aiming for real."
+             pred_real = pred[:, :-1, :, :]
+             loss = - label * pred_real
+             loss = torch.sum(loss, dim=1, keepdim=True)
+
+         if weight is not None:
+             loss = loss * weight
+         if reduce_dim:
+             loss = torch.mean(loss)
+         else:
+             loss = loss.view(batch_size, -1).mean(dim=1)
+         return loss
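
For reference, a short usage sketch of the N+1-label GANLoss defined above: the loss expects a dict whose 'pred' logits carry one channel per semantic class plus a trailing fake channel, and whose 'label' carries the per-class maps (channel 0 is zeroed out by the loss). The shapes and dummy tensors below are illustrative assumptions, not taken from the training code.

import torch
from imaginaire.model_utils.gancraft.loss import GANLoss

criterion = GANLoss()
pred = torch.randn(2, 6, 8, 8)   # discriminator logits: 5 semantic classes + 1 fake channel
label = torch.zeros(2, 5, 8, 8)  # per-class label maps; channel 0 is ignored by the loss
label[:, 1, ...] = 1.0

# Discriminator update: real samples pushed toward their class, fakes toward the last channel.
d_real = criterion({'pred': pred, 'label': label}, t_real=True, dis_update=True)
d_fake = criterion({'pred': pred, 'label': label}, t_real=False, dis_update=True)

# Generator update: only the real-directed term is valid here (t_real must be True).
g_loss = criterion({'pred': pred, 'label': label}, t_real=True, dis_update=False)
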