diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..b3d9d65228df70b310aff5795d1c1dd1602ad30b 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +assets/sample_traj.gif filter=lfs diff=lfs merge=lfs -text +assets/teaser.gif filter=lfs diff=lfs merge=lfs -text diff --git a/activation.py b/activation.py new file mode 100644 index 0000000000000000000000000000000000000000..17e8187e55d74e9c73fb7f7698d111f8b204fd35 --- /dev/null +++ b/activation.py @@ -0,0 +1,18 @@ +import torch +from torch.autograd import Function +from torch.cuda.amp import custom_bwd, custom_fwd + +class _trunc_exp(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) # cast to float32 + def forward(ctx, x): + ctx.save_for_backward(x) + return torch.exp(x) + + @staticmethod + @custom_bwd + def backward(ctx, g): + x = ctx.saved_tensors[0] + return g * torch.exp(x.clamp(-15, 15)) + +trunc_exp = _trunc_exp.apply \ No newline at end of file diff --git a/app_gradio.py b/app_gradio.py new file mode 100644 index 0000000000000000000000000000000000000000..a915d8002bd615372cae7b580fdcce8ded2d37de --- /dev/null +++ b/app_gradio.py @@ -0,0 +1,137 @@ +import os +import torch +import torch.nn as nn +import importlib +import argparse +from imaginaire.config import Config +from imaginaire.utils.cudnn import init_cudnn +import gradio as gr +from PIL import Image + + +class WrappedModel(nn.Module): + r"""Dummy wrapping the module. + """ + + def __init__(self, module): + super(WrappedModel, self).__init__() + self.module = module + + def forward(self, *args, **kwargs): + r"""PyTorch module forward function overload.""" + return self.module(*args, **kwargs) + +def parse_args(): + parser = argparse.ArgumentParser(description='Training') + parser.add_argument('--config', type=str, default='./configs/scenedreamer_inference.yaml', help='Path to the training config file.') + parser.add_argument('--checkpoint', default='./scenedreamer_released.pt', + help='Checkpoint path.') + parser.add_argument('--output_dir', type=str, default='./test/', + help='Location to save the image outputs') + parser.add_argument('--seed', type=int, default=8888, + help='Random seed.') + args = parser.parse_args() + return args + + +args = parse_args() +cfg = Config(args.config) + +# Initialize cudnn. +init_cudnn(cfg.cudnn.deterministic, cfg.cudnn.benchmark) + +# Initialize data loaders and models. + +lib_G = importlib.import_module(cfg.gen.type) +net_G = lib_G.Generator(cfg.gen, cfg.data) +net_G = net_G.to('cuda') +net_G = WrappedModel(net_G) + +if args.checkpoint == '': + raise NotImplementedError("No checkpoint is provided for inference!") + +# Load checkpoint. +# trainer.load_checkpoint(cfg, args.checkpoint) +checkpoint = torch.load(args.checkpoint, map_location='cpu') +net_G.load_state_dict(checkpoint['net_G']) + +# Do inference. 
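+# Prepare the generator for inference: unwrap it from the WrappedModel shim,
+# switch to eval mode, freeze all parameters, and create the output directory
+# that will hold the generated world.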
+net_G = net_G.module +net_G.eval() +for name, param in net_G.named_parameters(): + param.requires_grad = False +torch.cuda.empty_cache() +world_dir = os.path.join(args.output_dir) +os.makedirs(world_dir, exist_ok=True) + + + +def get_bev(seed): + print('[PCGGenerator] Generating BEV scene representation...') + os.system('python terrain_generator.py --size {} --seed {} --outdir {}'.format(net_G.voxel.sample_size, seed, world_dir)) + heightmap_path = os.path.join(world_dir, 'heightmap.png') + semantic_path = os.path.join(world_dir, 'colormap.png') + heightmap = Image.open(heightmap_path) + semantic = Image.open(semantic_path) + return semantic, heightmap + +def get_video(seed, num_frames, reso_h, reso_w): + device = torch.device('cuda') + rng_cuda = torch.Generator(device=device) + rng_cuda = rng_cuda.manual_seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + net_G.voxel.next_world(device, world_dir, checkpoint) + cam_mode = cfg.inference_args.camera_mode + cfg.inference_args.cam_maxstep = num_frames + cfg.inference_args.resolution_hw = [reso_h, reso_w] + current_outdir = os.path.join(world_dir, 'camera_{:02d}'.format(cam_mode)) + os.makedirs(current_outdir, exist_ok=True) + z = torch.empty(1, net_G.style_dims, dtype=torch.float32, device=device) + z.normal_(generator=rng_cuda) + net_G.inference_givenstyle(z, current_outdir, **vars(cfg.inference_args)) + return os.path.join(current_outdir, 'rgb_render.mp4') + +markdown=f''' + # SceneDreamer: Unbounded 3D Scene Generation from 2D Image Collections + + Authored by Zhaoxi Chen, Guangcong Wang, Ziwei Liu + ### Useful links: + - [Official Github Repo](https://github.com/FrozenBurning/SceneDreamer) + - [Project Page](https://scene-dreamer.github.io/) + - [arXiv Link](https://arxiv.org/abs/2302.01330) + Licensed under the S-Lab License. + We offer a sampled scene whose BEVs are shown on the right. You can also use the button "Generate BEV" to randomly sample a new 3D world represented by a height map and a semantic map. But it requires a long time. + + To render video, push the button "Render" to generate a camera trajectory flying through the world. You can specify rendering options as shown below! 
+''' + +with gr.Blocks() as demo: + with gr.Row(): + with gr.Column(): + gr.Markdown(markdown) + with gr.Column(): + with gr.Row(): + with gr.Column(): + semantic = gr.Image(value='./test/colormap.png',type="pil", shape=(512, 512)) + with gr.Column(): + height = gr.Image(value='./test/heightmap.png', type="pil", shape=(512, 512)) + with gr.Row(): + # with gr.Column(): + # image = gr.Image(type='pil', shape(540, 960)) + with gr.Column(): + video = gr.Video() + with gr.Row(): + num_frames = gr.Slider(minimum=10, maximum=200, value=20, step=1, label='Number of rendered frames') + user_seed = gr.Slider(minimum=0, maximum=999999, value=8888, step=1, label='Random seed') + resolution_h = gr.Slider(minimum=256, maximum=2160, value=270, step=1, label='Height of rendered image') + resolution_w = gr.Slider(minimum=256, maximum=3840, value=480, step=1, label='Width of rendered image') + + with gr.Row(): + btn = gr.Button(value="Generate BEV") + btn_2=gr.Button(value="Render") + + btn.click(get_bev,[user_seed],[semantic, height]) + btn_2.click(get_video,[user_seed, num_frames, resolution_h, resolution_w], [video]) + +demo.launch(debug=True) \ No newline at end of file diff --git a/assets/biome_image.png b/assets/biome_image.png new file mode 100644 index 0000000000000000000000000000000000000000..5af56fb4c9c18e19b97b686f3f4ab22fbaf097ce Binary files /dev/null and b/assets/biome_image.png differ diff --git a/assets/sample_traj.gif b/assets/sample_traj.gif new file mode 100644 index 0000000000000000000000000000000000000000..a09c8f5ec00615b447d09a377a446f6f5b054b85 --- /dev/null +++ b/assets/sample_traj.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9bff3115871f1a78fbe237b44ba51e1bf3eb0578c2850815b0b71dd012c3be6a +size 6577470 diff --git a/assets/teaser.gif b/assets/teaser.gif new file mode 100644 index 0000000000000000000000000000000000000000..d9a72154989f44f85d4987b116f5d984a12065a6 --- /dev/null +++ b/assets/teaser.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:12c4e75a99a9a5dc17b89fc1c272fa994068bdd83e0636202f001e875f203a05 +size 24540474 diff --git a/configs/img2lmdb.yaml b/configs/img2lmdb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7dd57b078cd42de7c3ec5637f5a52b2bc61f1602 --- /dev/null +++ b/configs/img2lmdb.yaml @@ -0,0 +1,177 @@ +inference_args: + random_style: True + use_fixed_random_style: False + keep_original_size: True + +image_save_iter: 5000 +snapshot_save_epoch: 5 +max_epoch: 400 +logging_iter: 100 +trainer: + type: imaginaire.trainers.spade + model_average_config: + enabled: True + beta: 0.9999 + start_iteration: 1000 + num_batch_norm_estimation_iterations: 30 + amp_config: + enabled: True + gan_mode: hinge + gan_relativistic: False + perceptual_loss: + mode: 'vgg19' + layers: ['relu_1_1', 'relu_2_1', 'relu_3_1', 'relu_4_1', 'relu_5_1'] + weights: [0.03125, 0.0625, 0.125, 0.25, 1.0] + fp16: True + loss_weight: + gan: 1.0 + perceptual: 10.0 + feature_matching: 10.0 + kl: 0.05 + init: + type: xavier + gain: 0.02 +gen_opt: + type: adam + lr: 0.0001 + adam_beta1: 0. + adam_beta2: 0.999 + lr_policy: + iteration_mode: False + type: step + step_size: 400 + gamma: 0.1 +dis_opt: + type: adam + lr: 0.0004 + adam_beta1: 0. 
+ adam_beta2: 0.999 + lr_policy: + iteration_mode: False + type: step + step_size: 400 + gamma: 0.1 +gen: + type: imaginaire.generators.spade + version: v20 + style_dims: 256 + num_filters: 128 + kernel_size: 3 + weight_norm_type: 'spectral' + use_posenc_in_input_layer: False + global_adaptive_norm_type: 'sync_batch' + activation_norm_params: + num_filters: 128 + kernel_size: 5 + separate_projection: True + activation_norm_type: 'sync_batch' + style_enc: + num_filters: 64 + kernel_size: 3 +dis: + type: imaginaire.discriminators.spade + kernel_size: 4 + num_filters: 128 + max_num_filters: 512 + num_discriminators: 2 + num_layers: 5 + activation_norm_type: 'none' + weight_norm_type: 'spectral' + +# Data options. +data: + type: imaginaire.datasets.paired_images + # How many data loading workers per GPU? + num_workers: 8 + input_types: + - images: + ext: jpg + num_channels: 3 + normalize: True + use_dont_care: False + - seg_maps: + ext: jpg + num_channels: 1 + is_mask: True + normalize: False + # - edge_maps: + # ext: png + # num_channels: 1 + # normalize: False + + full_data_ops: imaginaire.model_utils.label::make_one_hot, imaginaire.model_utils.label::concat_labels + use_dont_care: True + one_hot_num_classes: + seg_maps: 183 + input_labels: + - seg_maps + # - edge_maps + + # Which lmdb contains the ground truth image. + input_image: + - images + + # Train dataset details. + train: + # Input LMDBs. + roots: + - ./data/lhq/train + # Batch size per GPU. + batch_size: 4 + # Data augmentations to be performed in given order. + augmentations: + resize_smallest_side: 256 + # Rotate in (-rotate, rotate) in degrees. + rotate: 0 + # Scale image by factor \in [1, 1+random_scale_limit]. + random_scale_limit: 0.2 + # Horizontal flip? + horizontal_flip: True + # Crop size. + random_crop_h_w: 256, 256 + # Train dataset details. + val: + # Input LMDBs. + roots: + - ./data/lhq/val + # Batch size per GPU. + batch_size: 4 + # Data augmentations to be performed in given order. + augmentations: + # Crop size. + resize_h_w: 256, 256 + +test_data: + type: imaginaire.datasets.paired_images + num_workers: 8 + input_types: + - seg_maps: + ext: jpg + num_channels: 1 + is_mask: True + normalize: False + # - edge_maps: + # ext: png + # num_channels: 1 + # normalize: False + + full_data_ops: imaginaire.model_utils.label::make_one_hot, imaginaire.model_utils.label::concat_labels + use_dont_care: True + one_hot_num_classes: + seg_maps: 183 + input_labels: + - seg_maps + # - edge_maps + + paired: True + # Validation dataset details. + test: + is_lmdb: False + roots: + - ./data/lhq/train + # Batch size per GPU. + batch_size: 1 + # If resize_h_w is not given, then it is assumed to be same as crop_h_w. 
+ augmentations: + resize_h_w: 256, 256 + horizontal_flip: False \ No newline at end of file diff --git a/configs/landscape1m.yaml b/configs/landscape1m.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4f232e3f9356da1bf10747fd72ea7ec2e4897fa3 --- /dev/null +++ b/configs/landscape1m.yaml @@ -0,0 +1,175 @@ +pretrained_weight: ./landscape1m-segformer.pt + +inference_args: + random_style: True + use_fixed_random_style: False + keep_original_size: True + +image_save_iter: 5000 +snapshot_save_epoch: 5 +snapshot_save_iter: 30000 +max_epoch: 400 +logging_iter: 100 +trainer: + type: imaginaire.trainers.spade + model_average_config: + enabled: True + beta: 0.9999 + start_iteration: 1000 + num_batch_norm_estimation_iterations: 30 + amp_config: + enabled: True + gan_mode: hinge + gan_relativistic: False + perceptual_loss: + mode: 'vgg19' + layers: ['relu_1_1', 'relu_2_1', 'relu_3_1', 'relu_4_1', 'relu_5_1'] + weights: [0.03125, 0.0625, 0.125, 0.25, 1.0] + fp16: True + loss_weight: + gan: 1.0 + perceptual: 10.0 + feature_matching: 10.0 + kl: 0.05 + init: + type: xavier + gain: 0.02 +gen_opt: + type: adam + lr: 0.0001 + adam_beta1: 0. + adam_beta2: 0.999 + lr_policy: + iteration_mode: False + type: step + step_size: 400 + gamma: 0.1 +dis_opt: + type: adam + lr: 0.0004 + adam_beta1: 0. + adam_beta2: 0.999 + lr_policy: + iteration_mode: False + type: step + step_size: 400 + gamma: 0.1 +gen: + type: imaginaire.generators.spade + version: v20 + output_multiplier: 0.5 + image_channels: 3 + num_labels: 184 + style_dims: 256 + num_filters: 128 + kernel_size: 3 + weight_norm_type: 'spectral' + use_posenc_in_input_layer: False + global_adaptive_norm_type: 'sync_batch' + activation_norm_params: + num_filters: 128 + kernel_size: 5 + separate_projection: True + activation_norm_type: 'sync_batch' + style_enc: + num_filters: 64 + kernel_size: 3 +dis: + type: imaginaire.discriminators.spade + kernel_size: 4 + num_filters: 128 + max_num_filters: 512 + num_discriminators: 2 + num_layers: 5 + activation_norm_type: 'none' + weight_norm_type: 'spectral' + +# Data options. +data: + type: imaginaire.datasets.paired_images + # How many data loading workers per GPU? + num_workers: 8 + input_types: + - images: + ext: jpg + num_channels: 3 + normalize: True + use_dont_care: False + - seg_maps: + ext: jpg + num_channels: 1 + is_mask: True + normalize: False + + full_data_ops: imaginaire.model_utils.label::make_one_hot, imaginaire.model_utils.label::concat_labels + use_dont_care: True + one_hot_num_classes: + seg_maps: 183 + input_labels: + - seg_maps + + # Which lmdb contains the ground truth image. + input_image: + - images + + # Train dataset details. + train: + # Input LMDBs. + dataset_type: lmdb + roots: + - ./data/lhq_lmdb/train + # Batch size per GPU. + batch_size: 4 + # Data augmentations to be performed in given order. + augmentations: + resize_smallest_side: 512 + # Rotate in (-rotate, rotate) in degrees. + rotate: 0 + # Scale image by factor \in [1, 1+random_scale_limit]. + random_scale_limit: 0.2 + # Horizontal flip? + horizontal_flip: True + # Crop size. + random_crop_h_w: 512, 512 + # Train dataset details. + val: + dataset_type: lmdb + # Input LMDBs. + roots: + - ./data/lhq_lmdb/val + # Batch size per GPU. + batch_size: 4 + # Data augmentations to be performed in given order. + augmentations: + # Crop size. 
+ resize_h_w: 512, 512 + +test_data: + type: imaginaire.datasets.paired_images + num_workers: 8 + input_types: + - seg_maps: + ext: jpg + num_channels: 1 + is_mask: True + normalize: False + + full_data_ops: imaginaire.model_utils.label::make_one_hot, imaginaire.model_utils.label::concat_labels + use_dont_care: True + one_hot_num_classes: + seg_maps: 183 + input_labels: + - seg_maps + + paired: True + # Validation dataset details. + test: + is_lmdb: True + roots: + - ./data/lhq_lmdb/val + # Batch size per GPU. + batch_size: 1 + # If resize_h_w is not given, then it is assumed to be same as crop_h_w. + augmentations: + resize_h_w: 256, 256 + horizontal_flip: False \ No newline at end of file diff --git a/configs/scenedreamer_inference.yaml b/configs/scenedreamer_inference.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e3f20b77f0155820564b84f81d2f3660703e1713 --- /dev/null +++ b/configs/scenedreamer_inference.yaml @@ -0,0 +1,93 @@ +inference_args: + # 0: Camera orbiting the scene & looking at the center + # 1: Camera orbiting the scene & zooming in + # 2: Camera orbiting the scene & coming closer and closer to the center + # 3: Similar to 2, camera orbiting at the opposite direction + # 4: Simliar to 2, camera stays further away from the center + # 5: Camera sits at the center and look outwards + # 6: Camera rises while looking down + # 7: Camera really far away looking down at a 45deg angle + # 8: Camera for perpetual view generation, non-sliding window + # 9: Camera for infinite world generation, sliding window + camera_mode: 4 + + cam_maxstep: 40 + resolution_hw: [540, 960] + num_samples: 40 + cam_ang: 72 + +gen: + type: imaginaire.generators.scenedreamer + pcg_dataset_path: None + pcg_cache: False + scene_size: 2048 + + blk_feat_dim: 64 + + pe_lvl_feat: 4 + pe_incl_orig_feat: False + pe_no_pe_feat_dim: 40 + pe_lvl_raydir: 0 + pe_incl_orig_raydir: False + style_dims: 128 # Set to 0 to disable style. + interm_style_dims: 256 + final_feat_dim: 64 + + # Number of pixels removed from each edge to reduce boundary artifact of CNN + # both sides combined (8 -> 4 on left and 4 on right). + pad: 6 + + # ======== Sky network ======== + pe_lvl_raydir_sky: 5 + pe_incl_orig_raydir_sky: True + + # ======== Style Encoder ========= + # Comment out to disable style encoder. + style_enc: + num_filters: 64 + kernel_size: 3 + weight_norm_type: 'none' + + stylenet_model: StyleMLP + stylenet_model_kwargs: + normalize_input: True + num_layers: 5 + + mlp_model: RenderMLP + mlp_model_kwargs: + use_seg: True + + # ======== Ray Casting Params ======== + num_blocks_early_stop: 6 + num_samples: 24 # Original model uses 24. Reduced to 4 to allow training on 12GB GPUs (with significant performance penalty) + sample_depth: 3 # Stop the ray after certain depth + coarse_deterministic_sampling: False + sample_use_box_boundaries: False # Including voxel boundaries into the sample + + # ======== Blender ======== + raw_noise_std: 0.0 + dists_scale: 0.25 + clip_feat_map: True + # Prevent sky from leaking to the foreground. + keep_sky_out: True + keep_sky_out_avgpool: True + sky_global_avgpool: True + + # ======== Label translator ======== + reduced_label_set: True + use_label_smooth: True + use_label_smooth_real: True + use_label_smooth_pgt: True + label_smooth_dia: 11 + + # ======== Camera sampler ======== + camera_sampler_type: 'traditional' + cam_res: [360, 640] # Camera resolution before cropping. + crop_size: [256, 256] # Actual crop size is crop_size+pad. 
It should generally match random_crop_h_w in dataloader. + + # Threshold for rejecting camera poses that will result in a seg mask with low entropy. + # Generally, 0.5 min, 0.8 max. + camera_min_entropy: 0.75 + + # Threshold for rejecting camera poses that are too close to the objects. + camera_rej_avg_depth: 2.0 diff --git a/configs/scenedreamer_train.yaml b/configs/scenedreamer_train.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9458ebf0a207e3db08ed0a625b831fb5f6c41f0a --- /dev/null +++ b/configs/scenedreamer_train.yaml @@ -0,0 +1,223 @@ +image_save_iter: 5000 +snapshot_save_epoch: 5 +snapshot_save_iter: 10000 +max_epoch: 400 +logging_iter: 10 + +trainer: + type: imaginaire.trainers.gancraft + model_average_config: + enabled: False + amp_config: + enabled: False + perceptual_loss: + mode: 'vgg19' + layers: ['relu_3_1', 'relu_4_1', 'relu_5_1'] + weights: [0.125, 0.25, 1.0] + loss_weight: + l2: 10.0 + gan: 0.5 + pseudo_gan: 0.5 + perceptual: 10.0 + kl: 0.05 + init: + type: xavier + gain: 0.02 + + # SPADE/GauGAN model for pseudo-GT generation. + gaugan_loader: + config: configs/landscape1m.yaml + + image_to_tensorboard: True + distributed_data_parallel_params: + find_unused_parameters: False + broadcast_buffers: False + +gen_opt: + type: adam + lr: 0.0001 + eps: 1.e-7 + adam_beta1: 0. + adam_beta2: 0.999 + lr_policy: + iteration_mode: False + type: step + step_size: 400 + gamma: 0.1 + param_groups: + world_encoder: + lr: 0.0005 + hash_encoder: + lr: 0.0001 + render_net: + lr: 0.0001 + sky_net: + lr: 0.0001 + style_net: + lr: 0.0001 + style_encoder: + lr: 0.0001 + denoiser: + lr: 0.0001 + +dis_opt: + type: adam + lr: 0.0004 + eps: 1.e-7 + adam_beta1: 0. + adam_beta2: 0.999 + lr_policy: + iteration_mode: False + type: step + step_size: 400 + gamma: 0.1 + +gen: + type: imaginaire.generators.scenedreamer + pcg_dataset_path: ./data/terrain_cache + pcg_cache: True + scene_size: 2048 + + blk_feat_dim: 64 + + pe_lvl_feat: 4 + pe_incl_orig_feat: False + pe_no_pe_feat_dim: 40 + pe_lvl_raydir: 0 + pe_incl_orig_raydir: False + style_dims: 128 # Set to 0 to disable style. + interm_style_dims: 256 + final_feat_dim: 64 + + # Number of pixels removed from each edge to reduce boundary artifact of CNN + # both sides combined (8 -> 4 on left and 4 on right). + pad: 6 + + # ======== Sky network ======== + pe_lvl_raydir_sky: 5 + pe_incl_orig_raydir_sky: True + + # ======== Style Encoder ========= + # Comment out to disable style encoder. + style_enc: + num_filters: 64 + kernel_size: 3 + weight_norm_type: 'none' + + stylenet_model: StyleMLP + stylenet_model_kwargs: + normalize_input: True + num_layers: 5 + + mlp_model: RenderMLP + mlp_model_kwargs: + use_seg: True + + # ======== Ray Casting Params ======== + num_blocks_early_stop: 6 + num_samples: 24 # Decrease it if you got OOM on lowend GPU + sample_depth: 3 # Stop the ray after certain depth + coarse_deterministic_sampling: False + sample_use_box_boundaries: False # Including voxel boundaries into the sample + + # ======== Blender ======== + raw_noise_std: 0.0 + dists_scale: 0.25 + clip_feat_map: True + # Prevent sky from leaking to the foreground. + keep_sky_out: True + keep_sky_out_avgpool: True + sky_global_avgpool: True + + # ======== Label translator ======== + reduced_label_set: True + use_label_smooth: True + use_label_smooth_real: True + use_label_smooth_pgt: True + label_smooth_dia: 11 + + # ======== Camera sampler ======== + camera_sampler_type: 'traditional' + cam_res: [360, 640] # Camera resolution before cropping. 
+ crop_size: [256, 256] # Actual crop size is crop_size+pad. It should generally match random_crop_h_w in dataloader. + + # Threshold for rejecting camera poses that will result in a seg mask with low entropy. + # Generally, 0.5 min, 0.8 max. + camera_min_entropy: 0.75 + + # Threshold for rejecting camera poses that are too close to the objects. + camera_rej_avg_depth: 2.0 + +dis: + type: imaginaire.discriminators.gancraft + image_channels: 3 + num_labels: 12 # Same as num_reduced_lbls. + use_label: True + num_filters: 128 + fpse_kernel_size: 3 + activation_norm_type: 'none' + weight_norm_type: spectral + smooth_resample: True + +# Data options. +data: + type: imaginaire.datasets.paired_images + num_workers: 8 + input_types: + - images: + ext: jpg + num_channels: 3 + normalize: True + use_dont_care: False + - seg_maps: + ext: png + num_channels: 1 + is_mask: True + normalize: False + + full_data_ops: imaginaire.model_utils.label::make_one_hot, imaginaire.model_utils.label::concat_labels + use_dont_care: False + one_hot_num_classes: + seg_maps: 184 + input_labels: + - seg_maps + + # Which lmdb contains the ground truth image. + input_image: + - images + + # Train dataset details. + train: + dataset_type: lmdb + # Input LMDBs. + roots: + - ./data/lhq_lmdb/train + # Batch size per GPU. + batch_size: 1 + # Data augmentations to be performed in given order. + augmentations: + resize_smallest_side: 256 + # Rotate in (-rotate, rotate) in degrees. + rotate: 0 + # Scale image by factor \in [1, 1+random_scale_limit]. + random_scale_limit: 0.2 + # Horizontal flip? + horizontal_flip: True + # Crop size. + random_crop_h_w: 256, 256 + # Train dataset details. + val: + dataset_type: lmdb + # Input LMDBs. + roots: + - ./data/lhq_lmdb/val + # Batch size per GPU. + batch_size: 1 + # Data augmentations to be performed in given order. + augmentations: + # Crop size. + resize_h_w: 256, 256 + +test_data: + type: imaginaire.datasets.dummy + num_workers: 0 diff --git a/encoding.py b/encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..a5484eed8ae6bc7407852783129c76405239ee72 --- /dev/null +++ b/encoding.py @@ -0,0 +1,67 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class FreqEncoder(nn.Module): + def __init__(self, input_dim, max_freq_log2, N_freqs, + log_sampling=True, include_input=True, + periodic_fns=(torch.sin, torch.cos)): + + super().__init__() + + self.input_dim = input_dim + self.include_input = include_input + self.periodic_fns = periodic_fns + + self.output_dim = 0 + if self.include_input: + self.output_dim += self.input_dim + + self.output_dim += self.input_dim * N_freqs * len(self.periodic_fns) + + if log_sampling: + self.freq_bands = 2. ** torch.linspace(0., max_freq_log2, N_freqs) + else: + self.freq_bands = torch.linspace(2. ** 0., 2. 
** max_freq_log2, N_freqs) + + self.freq_bands = self.freq_bands.numpy().tolist() + + def forward(self, input, **kwargs): + + out = [] + if self.include_input: + out.append(input) + + for i in range(len(self.freq_bands)): + freq = self.freq_bands[i] + for p_fn in self.periodic_fns: + out.append(p_fn(input * freq)) + + out = torch.cat(out, dim=-1) + + + return out + +def get_encoder(encoding, input_dim=3, + multires=6, + degree=4, + num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=2048, align_corners=False, + **kwargs): + + if encoding == 'None': + return lambda x, **kwargs: x, input_dim + + elif encoding == 'hashgrid': + from gridencoder import GridEncoder + encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='hash', align_corners=align_corners) + + elif encoding == 'tiledgrid': + from gridencoder import GridEncoder + encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='tiled', align_corners=align_corners) + elif encoding == 'varhashgrid': + from gridencoder.grid import VarGridEncoder + encoder = VarGridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='tiled', align_corners=align_corners, hash_entries = kwargs['hash_feat_dim']) + else: + raise NotImplementedError('Unknown encoding mode, choose from [None, frequency, sphere_harmonics, hashgrid, tiledgrid]') + + return encoder, encoder.output_dim \ No newline at end of file diff --git a/environment.yaml b/environment.yaml new file mode 100644 index 0000000000000000000000000000000000000000..19bda7a4b9f71f77599f98873c3689681fd30752 --- /dev/null +++ b/environment.yaml @@ -0,0 +1,44 @@ +name: scenedreamer +channels: + - pytorch + - nvidia +dependencies: + - python=3.9 + - pytorch=1.12.0 + - cudatoolkit=11.3 + - torchvision + - pip + - numpy + - scipy + - scikit-image + - pip: + - einops + - noise + - opencv-python + - cmake + - pynvml + - Pillow>=8.3.2 + - tqdm==4.35.0 + - wget + - cython + - lmdb + - av + - opencv-python + - opencv-contrib-python + - imutils + - imageio-ffmpeg + - qimage2ndarray + - albumentations + - requests==2.25.1 + - nvidia-ml-py3==7.352.0 + - pyglet + - timm + - diskcache + - boto3 + - awscli_plugin_endpoint + - awscli + - rsa + - wandb + - tensorboard + - lpips + - matplotlib \ No newline at end of file diff --git a/gridencoder/__init__.py b/gridencoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f1476cef5314e0918b963d1ac64ee0613a7743d5 --- /dev/null +++ b/gridencoder/__init__.py @@ -0,0 +1 @@ +from .grid import GridEncoder \ No newline at end of file diff --git a/gridencoder/backend.py b/gridencoder/backend.py new file mode 100644 index 0000000000000000000000000000000000000000..fa7b1f5a8f89831fa98db547ed665c49479178bc --- /dev/null +++ b/gridencoder/backend.py @@ -0,0 +1,40 @@ +import os +from torch.utils.cpp_extension import load + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++17', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', +] + +if os.name == "posix": + c_flags = ['-O3', '-std=c++17'] +elif os.name 
== "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. + if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +_backend = load(name='_grid_encoder', + extra_cflags=c_flags, + extra_cuda_cflags=nvcc_flags, + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'gridencoder.cu', + 'bindings.cpp', + ]], + ) + +__all__ = ['_backend'] \ No newline at end of file diff --git a/gridencoder/grid.py b/gridencoder/grid.py new file mode 100644 index 0000000000000000000000000000000000000000..d3c5a6a281b49d6c2707fd00de4149ce3b183294 --- /dev/null +++ b/gridencoder/grid.py @@ -0,0 +1,224 @@ +import numpy as np + +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.cuda.amp import custom_bwd, custom_fwd + +try: + import _gridencoder as _backend +except ImportError: + from .backend import _backend + +_gridtype_to_id = { + 'hash': 0, + 'tiled': 1, +} + +class _grid_encode(Function): + @staticmethod + @custom_fwd + def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False, gridtype=0, align_corners=False): + # inputs: [B, D], float in [0, 1] + # embeddings: [sO, C], float + # offsets: [L + 1], int + # RETURN: [B, F], float + + inputs = inputs.contiguous() + + B, D = inputs.shape # batch size, coord dim + L = offsets.shape[0] - 1 # level + C = embeddings.shape[1] # embedding dim for each level + S = np.log2(per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f + H = base_resolution # base resolution + + # manually handle autocast (only use half precision embeddings, inputs must be float for enough precision) + # if C % 2 != 0, force float, since half for atomicAdd is very slow. + if torch.is_autocast_enabled() and C % 2 == 0: + embeddings = embeddings.to(torch.half) + + # L first, optimize cache for cuda kernel, but needs an extra permute later + outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype) + + if calc_grad_inputs: + dy_dx = torch.empty(B, L * D * C, device=inputs.device, dtype=embeddings.dtype) + else: + dy_dx = torch.empty(1, device=inputs.device, dtype=embeddings.dtype) # placeholder... TODO: a better way? 
+ + _backend.grid_encode_forward(inputs, embeddings, offsets, outputs, B, D, C, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners) + + # permute back to [B, L * C] + outputs = outputs.permute(1, 0, 2).reshape(B, L * C) + + ctx.save_for_backward(inputs, embeddings, offsets, dy_dx) + ctx.dims = [B, D, C, L, S, H, gridtype] + ctx.calc_grad_inputs = calc_grad_inputs + ctx.align_corners = align_corners + + return outputs + + @staticmethod + #@once_differentiable + @custom_bwd + def backward(ctx, grad): + + inputs, embeddings, offsets, dy_dx = ctx.saved_tensors + B, D, C, L, S, H, gridtype = ctx.dims + calc_grad_inputs = ctx.calc_grad_inputs + align_corners = ctx.align_corners + + # grad: [B, L * C] --> [L, B, C] + grad = grad.view(B, L, C).permute(1, 0, 2).contiguous() + + grad_embeddings = torch.zeros_like(embeddings) + + if calc_grad_inputs: + grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype) + else: + grad_inputs = torch.zeros(1, device=inputs.device, dtype=embeddings.dtype) + + _backend.grid_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings, B, D, C, L, S, H, calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners) + + if calc_grad_inputs: + grad_inputs = grad_inputs.to(inputs.dtype) + return grad_inputs, grad_embeddings, None, None, None, None, None, None + else: + return None, grad_embeddings, None, None, None, None, None, None + + +grid_encode = _grid_encode.apply + + +class GridEncoder(nn.Module): + def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=None, gridtype='hash', align_corners=False): + super().__init__() + + # the finest resolution desired at the last level, if provided, overridee per_level_scale + if desired_resolution is not None: + per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1)) + + self.input_dim = input_dim # coord dims, 2 or 3 + self.num_levels = num_levels # num levels, each level multiply resolution by 2 + self.level_dim = level_dim # encode channels per level + self.per_level_scale = per_level_scale # multiply resolution by this scale at each level. 
+ self.log2_hashmap_size = log2_hashmap_size + self.base_resolution = base_resolution + self.output_dim = num_levels * level_dim + self.gridtype = gridtype + self.gridtype_id = _gridtype_to_id[gridtype] # "tiled" or "hash" + self.align_corners = align_corners + + # allocate parameters + offsets = [] + offset = 0 + self.max_params = 2 ** log2_hashmap_size + for i in range(num_levels): + resolution = int(np.ceil(base_resolution * per_level_scale ** i)) + params_in_level = min(self.max_params, (resolution if align_corners else resolution + 1) ** input_dim) # limit max number + params_in_level = int(np.ceil(params_in_level / 8) * 8) # make divisible + offsets.append(offset) + offset += params_in_level + offsets.append(offset) + offsets = torch.from_numpy(np.array(offsets, dtype=np.int32)) + self.register_buffer('offsets', offsets) + + self.n_params = offsets[-1] * level_dim + + # parameters + self.embeddings = nn.Parameter(torch.empty(offset, level_dim)) + + self.reset_parameters() + + def reset_parameters(self): + std = 1e-4 + self.embeddings.data.uniform_(-std, std) + + def __repr__(self): + return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners}" + + def forward(self, inputs, bound=1): + # inputs: [..., input_dim], normalized real world positions in [-bound, bound] + # return: [..., num_levels * level_dim] + + inputs = (inputs + bound) / (2 * bound) # map to [0, 1] + + #print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item()) + + prefix_shape = list(inputs.shape[:-1]) + inputs = inputs.view(-1, self.input_dim) + + outputs = grid_encode(inputs, self.embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad, self.gridtype_id, self.align_corners) + outputs = outputs.view(prefix_shape + [self.output_dim]) + + #print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item()) + + return outputs + +class VarGridEncoder(nn.Module): + def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=None, gridtype='hash', align_corners=False, hash_entries=None): + super().__init__() + + # the finest resolution desired at the last level, if provided, overridee per_level_scale + if desired_resolution is not None: + per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1)) + + self.input_dim = input_dim # coord dims, 2 or 3 + self.num_levels = num_levels # num levels, each level multiply resolution by 2 + self.level_dim = level_dim # encode channels per level + self.per_level_scale = per_level_scale # multiply resolution by this scale at each level. 
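+        # NOTE: unlike GridEncoder, this variant allocates only (offset - hash_entries)
+        # embedding rows locally; the remaining rows are passed in at forward time and
+        # concatenated in front of the local table (see forward below).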
+ self.log2_hashmap_size = log2_hashmap_size + self.base_resolution = base_resolution + self.output_dim = num_levels * level_dim + self.gridtype = gridtype + self.gridtype_id = _gridtype_to_id[gridtype] # "tiled" or "hash" + self.align_corners = align_corners + + # allocate parameters + offsets = [] + offset = 0 + self.max_params = 2 ** log2_hashmap_size + for i in range(num_levels): + resolution = int(np.ceil(base_resolution * per_level_scale ** i)) + params_in_level = min(self.max_params, (resolution if align_corners else resolution + 1) ** input_dim) # limit max number + params_in_level = int(np.ceil(params_in_level / 8) * 8) # make divisible + offsets.append(offset) + offset += params_in_level + offsets.append(offset) + offsets = torch.from_numpy(np.array(offsets, dtype=np.int32)) + self.register_buffer('offsets', offsets) + + self.n_params = offsets[-1] * level_dim + self.level_dim = level_dim + self.offset = offset + + # parameters + self.embeddings = nn.Parameter(torch.empty(offset - hash_entries, level_dim)) + + self.reset_parameters() + + def reset_parameters(self): + std = 1e-4 + self.embeddings.data.uniform_(-std, std) + + def __repr__(self): + return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners}" + + def forward(self, inputs, embeddings, bound=1): + # inputs: [..., input_dim], normalized real world positions in [-bound, bound] + # return: [..., num_levels * level_dim] + input_embeddings = torch.cat([embeddings, self.embeddings], dim=0) + + inputs = (inputs + bound) / (2 * bound) # map to [0, 1] + + #print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item()) + + prefix_shape = list(inputs.shape[:-1]) + inputs = inputs.view(-1, self.input_dim) + + outputs = grid_encode(inputs, input_embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad, self.gridtype_id, self.align_corners) + outputs = outputs.view(prefix_shape + [self.output_dim]) + + #print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item()) + + return outputs \ No newline at end of file diff --git a/gridencoder/setup.py b/gridencoder/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..69028a7396c8ae5c6a2eeaceac586d01f1a78f6b --- /dev/null +++ b/gridencoder/setup.py @@ -0,0 +1,50 @@ +import os +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++17', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', +] + +if os.name == "posix": + c_flags = ['-O3', '-std=c++17'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. 
+ if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +setup( + name='gridencoder', # package name, import this to use python API + ext_modules=[ + CUDAExtension( + name='_gridencoder', # extension name, import this to use CUDA API + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'gridencoder.cu', + 'bindings.cpp', + ]], + extra_compile_args={ + 'cxx': c_flags, + 'nvcc': nvcc_flags, + } + ), + ], + cmdclass={ + 'build_ext': BuildExtension, + } +) \ No newline at end of file diff --git a/gridencoder/src/bindings.cpp b/gridencoder/src/bindings.cpp new file mode 100644 index 0000000000000000000000000000000000000000..afa6f64fd7d7d1efc0380e279aad3ae3ca66c36f --- /dev/null +++ b/gridencoder/src/bindings.cpp @@ -0,0 +1,8 @@ +#include + +#include "gridencoder.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)"); + m.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)"); +} \ No newline at end of file diff --git a/gridencoder/src/gridencoder.cu b/gridencoder/src/gridencoder.cu new file mode 100644 index 0000000000000000000000000000000000000000..2f4ff4c7cc7eb8e6bb473a1fb710698a8c51b369 --- /dev/null +++ b/gridencoder/src/gridencoder.cu @@ -0,0 +1,478 @@ +#include +#include +#include + +#include +#include + +#include +#include + +#include +#include + + +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor") +#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor") +#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor") + + +// just for compatability of half precision in AT_DISPATCH_FLOATING_TYPES_AND_HALF... +static inline __device__ at::Half atomicAdd(at::Half *address, at::Half val) { + // requires CUDA >= 10 and ARCH >= 70 + // this is very slow compared to float or __half2, and never used. + //return atomicAdd(reinterpret_cast<__half*>(address), val); +} + + +template +static inline __host__ __device__ T div_round_up(T val, T divisor) { + return (val + divisor - 1) / divisor; +} + + +template +__device__ uint32_t fast_hash(const uint32_t pos_grid[D]) { + static_assert(D <= 7, "fast_hash can only hash up to 7 dimensions."); + + // While 1 is technically not a good prime for hashing (or a prime at all), it helps memory coherence + // and is sufficient for our use case of obtaining a uniformly colliding index from high-dimensional + // coordinates. + constexpr uint32_t primes[7] = { 1, 2654435761, 805459861, 3674653429, 2097192037, 1434869437, 2165219737 }; + + uint32_t result = 0; + #pragma unroll + for (uint32_t i = 0; i < D; ++i) { + result ^= pos_grid[i] * primes[i]; + } + + return result; +} + + +template +__device__ uint32_t get_grid_index(const uint32_t gridtype, const bool align_corners, const uint32_t ch, const uint32_t hashmap_size, const uint32_t resolution, const uint32_t pos_grid[D]) { + uint32_t stride = 1; + uint32_t index = 0; + + #pragma unroll + for (uint32_t d = 0; d < D && stride <= hashmap_size; d++) { + index += pos_grid[d] * stride; + stride *= align_corners ? 
resolution: (resolution + 1); + } + + // NOTE: for NeRF, the hash is in fact not necessary. Check https://github.com/NVlabs/instant-ngp/issues/97. + // gridtype: 0 == hash, 1 == tiled + if (gridtype == 0 && stride > hashmap_size) { + index = fast_hash(pos_grid); + } + + return (index % hashmap_size) * C + ch; +} + + +template +__global__ void kernel_grid( + const float * __restrict__ inputs, + const scalar_t * __restrict__ grid, + const int * __restrict__ offsets, + scalar_t * __restrict__ outputs, + const uint32_t B, const uint32_t L, const float S, const uint32_t H, + const bool calc_grad_inputs, + scalar_t * __restrict__ dy_dx, + const uint32_t gridtype, + const bool align_corners +) { + const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x; + + if (b >= B) return; + + const uint32_t level = blockIdx.y; + + // locate + grid += (uint32_t)offsets[level] * C; + inputs += b * D; + outputs += level * B * C + b * C; + + // check input range (should be in [0, 1]) + bool flag_oob = false; + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + if (inputs[d] < 0 || inputs[d] > 1) { + flag_oob = true; + } + } + // if input out of bound, just set output to 0 + if (flag_oob) { + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + outputs[ch] = 0; + } + if (calc_grad_inputs) { + dy_dx += b * D * L * C + level * D * C; // B L D C + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + dy_dx[d * C + ch] = 0; + } + } + } + return; + } + + const uint32_t hashmap_size = offsets[level + 1] - offsets[level]; + const float scale = exp2f(level * S) * H - 1.0f; + const uint32_t resolution = (uint32_t)ceil(scale) + 1; + + // calculate coordinate + float pos[D]; + uint32_t pos_grid[D]; + + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f); + pos_grid[d] = floorf(pos[d]); + pos[d] -= (float)pos_grid[d]; + } + + //printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]); + + // interpolate + scalar_t results[C] = {0}; // temp results in register + + #pragma unroll + for (uint32_t idx = 0; idx < (1 << D); idx++) { + float w = 1; + uint32_t pos_grid_local[D]; + + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + if ((idx & (1 << d)) == 0) { + w *= 1 - pos[d]; + pos_grid_local[d] = pos_grid[d]; + } else { + w *= pos[d]; + pos_grid_local[d] = pos_grid[d] + 1; + } + } + + uint32_t index = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local); + + // writing to register (fast) + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + results[ch] += w * grid[index + ch]; + } + + //printf("[b=%d, l=%d] int %d, idx %d, w %f, val %f\n", b, level, idx, index, w, grid[index]); + } + + // writing to global memory (slow) + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + outputs[ch] = results[ch]; + } + + // prepare dy_dx for calc_grad_inputs + // differentiable (soft) indexing: https://discuss.pytorch.org/t/differentiable-indexing/17647/9 + if (calc_grad_inputs) { + + dy_dx += b * D * L * C + level * D * C; // B L D C + + #pragma unroll + for (uint32_t gd = 0; gd < D; gd++) { + + scalar_t results_grad[C] = {0}; + + #pragma unroll + for (uint32_t idx = 0; idx < (1 << (D - 1)); idx++) { + float w = scale; + uint32_t pos_grid_local[D]; + + #pragma unroll + for (uint32_t nd = 0; nd < D - 1; nd++) { + const uint32_t d = (nd >= gd) ? 
(nd + 1) : nd; + + if ((idx & (1 << nd)) == 0) { + w *= 1 - pos[d]; + pos_grid_local[d] = pos_grid[d]; + } else { + w *= pos[d]; + pos_grid_local[d] = pos_grid[d] + 1; + } + } + + pos_grid_local[gd] = pos_grid[gd]; + uint32_t index_left = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local); + pos_grid_local[gd] = pos_grid[gd] + 1; + uint32_t index_right = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local); + + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + results_grad[ch] += w * (grid[index_right + ch] - grid[index_left + ch]); + } + } + + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + dy_dx[gd * C + ch] = results_grad[ch]; + } + } + } +} + + +template +__global__ void kernel_grid_backward( + const scalar_t * __restrict__ grad, + const float * __restrict__ inputs, + const scalar_t * __restrict__ grid, + const int * __restrict__ offsets, + scalar_t * __restrict__ grad_grid, + const uint32_t B, const uint32_t L, const float S, const uint32_t H, + const uint32_t gridtype, + const bool align_corners +) { + const uint32_t b = (blockIdx.x * blockDim.x + threadIdx.x) * N_C / C; + if (b >= B) return; + + const uint32_t level = blockIdx.y; + const uint32_t ch = (blockIdx.x * blockDim.x + threadIdx.x) * N_C - b * C; + + // locate + grad_grid += offsets[level] * C; + inputs += b * D; + grad += level * B * C + b * C + ch; // L, B, C + + const uint32_t hashmap_size = offsets[level + 1] - offsets[level]; + const float scale = exp2f(level * S) * H - 1.0f; + const uint32_t resolution = (uint32_t)ceil(scale) + 1; + + // check input range (should be in [0, 1]) + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + if (inputs[d] < 0 || inputs[d] > 1) { + return; // grad is init as 0, so we simply return. + } + } + + // calculate coordinate + float pos[D]; + uint32_t pos_grid[D]; + + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + pos[d] = inputs[d] * scale + (align_corners ? 
0.0f : 0.5f); + pos_grid[d] = floorf(pos[d]); + pos[d] -= (float)pos_grid[d]; + } + + scalar_t grad_cur[N_C] = {0}; // fetch to register + #pragma unroll + for (uint32_t c = 0; c < N_C; c++) { + grad_cur[c] = grad[c]; + } + + // interpolate + #pragma unroll + for (uint32_t idx = 0; idx < (1 << D); idx++) { + float w = 1; + uint32_t pos_grid_local[D]; + + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + if ((idx & (1 << d)) == 0) { + w *= 1 - pos[d]; + pos_grid_local[d] = pos_grid[d]; + } else { + w *= pos[d]; + pos_grid_local[d] = pos_grid[d] + 1; + } + } + + uint32_t index = get_grid_index(gridtype, align_corners, ch, hashmap_size, resolution, pos_grid_local); + + // atomicAdd for __half is slow (especially for large values), so we use __half2 if N_C % 2 == 0 + // TODO: use float which is better than __half, if N_C % 2 != 0 + if (std::is_same::value && N_C % 2 == 0) { + #pragma unroll + for (uint32_t c = 0; c < N_C; c += 2) { + // process two __half at once (by interpreting as a __half2) + __half2 v = {(__half)(w * grad_cur[c]), (__half)(w * grad_cur[c + 1])}; + atomicAdd((__half2*)&grad_grid[index + c], v); + } + // float, or __half when N_C % 2 != 0 (which means C == 1) + } else { + #pragma unroll + for (uint32_t c = 0; c < N_C; c++) { + atomicAdd(&grad_grid[index + c], w * grad_cur[c]); + } + } + } +} + + +template +__global__ void kernel_input_backward( + const scalar_t * __restrict__ grad, + const scalar_t * __restrict__ dy_dx, + scalar_t * __restrict__ grad_inputs, + uint32_t B, uint32_t L +) { + const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x; + if (t >= B * D) return; + + const uint32_t b = t / D; + const uint32_t d = t - b * D; + + dy_dx += b * L * D * C; + + scalar_t result = 0; + + # pragma unroll + for (int l = 0; l < L; l++) { + # pragma unroll + for (int ch = 0; ch < C; ch++) { + result += grad[l * B * C + b * C + ch] * dy_dx[l * D * C + d * C + ch]; + } + } + + grad_inputs[t] = result; +} + + +template +void kernel_grid_wrapper(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners) { + static constexpr uint32_t N_THREAD = 512; + const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), L, 1 }; + switch (C) { + case 1: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break; + case 2: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break; + case 4: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break; + case 8: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break; + default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; + } +} + +// inputs: [B, D], float, in [0, 1] +// embeddings: [sO, C], float +// offsets: [L + 1], uint32_t +// outputs: [L, B, C], float (L first, so only one level of hashmap needs to fit into cache at a time.) 
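+// S: log2 of the per-level resolution scale (the kernels recover each level's scale via exp2f(level * S))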
+// H: base resolution +// dy_dx: [B, L * D * C] +template +void grid_encode_forward_cuda(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners) { + switch (D) { + case 2: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break; + case 3: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break; + case 4: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break; + case 5: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break; + default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; + } + +} + +template +void kernel_grid_backward_wrapper(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners) { + static constexpr uint32_t N_THREAD = 256; + const uint32_t N_C = std::min(2u, C); // n_features_per_thread + const dim3 blocks_hashgrid = { div_round_up(B * C / N_C, N_THREAD), L, 1 }; + switch (C) { + case 1: + kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners); + if (calc_grad_inputs) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L); + break; + case 2: + kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners); + if (calc_grad_inputs) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L); + break; + case 4: + kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners); + if (calc_grad_inputs) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L); + break; + case 8: + kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners); + if (calc_grad_inputs) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L); + break; + default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; + } +} + + +// grad: [L, B, C], float +// inputs: [B, D], float, in [0, 1] +// embeddings: [sO, C], float +// offsets: [L + 1], uint32_t +// grad_embeddings: [sO, C] +// H: base resolution +template +void grid_encode_backward_cuda(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners) { + switch (D) { + case 2: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners); break; + case 3: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners); break; + case 4: 
kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners); break; + case 5: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners); break; + default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; + } +} + + + +void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, at::Tensor dy_dx, const uint32_t gridtype, const bool align_corners) { + CHECK_CUDA(inputs); + CHECK_CUDA(embeddings); + CHECK_CUDA(offsets); + CHECK_CUDA(outputs); + CHECK_CUDA(dy_dx); + + CHECK_CONTIGUOUS(inputs); + CHECK_CONTIGUOUS(embeddings); + CHECK_CONTIGUOUS(offsets); + CHECK_CONTIGUOUS(outputs); + CHECK_CONTIGUOUS(dy_dx); + + CHECK_IS_FLOATING(inputs); + CHECK_IS_FLOATING(embeddings); + CHECK_IS_INT(offsets); + CHECK_IS_FLOATING(outputs); + CHECK_IS_FLOATING(dy_dx); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + embeddings.scalar_type(), "grid_encode_forward", ([&] { + grid_encode_forward_cuda(inputs.data_ptr(), embeddings.data_ptr(), offsets.data_ptr(), outputs.data_ptr(), B, D, C, L, S, H, calc_grad_inputs, dy_dx.data_ptr(), gridtype, align_corners); + })); +} + +void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, const at::Tensor dy_dx, at::Tensor grad_inputs, const uint32_t gridtype, const bool align_corners) { + CHECK_CUDA(grad); + CHECK_CUDA(inputs); + CHECK_CUDA(embeddings); + CHECK_CUDA(offsets); + CHECK_CUDA(grad_embeddings); + CHECK_CUDA(dy_dx); + CHECK_CUDA(grad_inputs); + + CHECK_CONTIGUOUS(grad); + CHECK_CONTIGUOUS(inputs); + CHECK_CONTIGUOUS(embeddings); + CHECK_CONTIGUOUS(offsets); + CHECK_CONTIGUOUS(grad_embeddings); + CHECK_CONTIGUOUS(dy_dx); + CHECK_CONTIGUOUS(grad_inputs); + + CHECK_IS_FLOATING(grad); + CHECK_IS_FLOATING(inputs); + CHECK_IS_FLOATING(embeddings); + CHECK_IS_INT(offsets); + CHECK_IS_FLOATING(grad_embeddings); + CHECK_IS_FLOATING(dy_dx); + CHECK_IS_FLOATING(grad_inputs); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad.scalar_type(), "grid_encode_backward", ([&] { + grid_encode_backward_cuda(grad.data_ptr(), inputs.data_ptr(), embeddings.data_ptr(), offsets.data_ptr(), grad_embeddings.data_ptr(), B, D, C, L, S, H, calc_grad_inputs, dy_dx.data_ptr(), grad_inputs.data_ptr(), gridtype, align_corners); + })); + +} diff --git a/gridencoder/src/gridencoder.h b/gridencoder/src/gridencoder.h new file mode 100644 index 0000000000000000000000000000000000000000..b093e78272ee55a309fadfa08db6a2943c247700 --- /dev/null +++ b/gridencoder/src/gridencoder.h @@ -0,0 +1,15 @@ +#ifndef _HASH_ENCODE_H +#define _HASH_ENCODE_H + +#include +#include + +// inputs: [B, D], float, in [0, 1] +// embeddings: [sO, C], float +// offsets: [L + 1], uint32_t +// outputs: [B, L * C], float +// H: base resolution +void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, 
at::Tensor dy_dx, const uint32_t gridtype, const bool align_corners); +void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, const at::Tensor dy_dx, at::Tensor grad_inputs, const uint32_t gridtype, const bool align_corners); + +#endif \ No newline at end of file diff --git a/imaginaire/__init__.py b/imaginaire/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..13acefe2181136b1629ec31f9d122fb46bf26780 --- /dev/null +++ b/imaginaire/__init__.py @@ -0,0 +1,4 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md diff --git a/imaginaire/config.py b/imaginaire/config.py new file mode 100644 index 0000000000000000000000000000000000000000..3a728a5aaee8d040288ff9ffd17a4fa83a7e2ca7 --- /dev/null +++ b/imaginaire/config.py @@ -0,0 +1,238 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +"""Config utilities for yml file.""" + +import collections +import functools +import os +import re + +import yaml +from imaginaire.utils.distributed import master_only_print as print + +DEBUG = False +USE_JIT = False + + +class AttrDict(dict): + """Dict as attribute trick.""" + + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + for key, value in self.__dict__.items(): + if isinstance(value, dict): + self.__dict__[key] = AttrDict(value) + elif isinstance(value, (list, tuple)): + if isinstance(value[0], dict): + self.__dict__[key] = [AttrDict(item) for item in value] + else: + self.__dict__[key] = value + + def yaml(self): + """Convert object to yaml dict and return.""" + yaml_dict = {} + for key, value in self.__dict__.items(): + if isinstance(value, AttrDict): + yaml_dict[key] = value.yaml() + elif isinstance(value, list): + if isinstance(value[0], AttrDict): + new_l = [] + for item in value: + new_l.append(item.yaml()) + yaml_dict[key] = new_l + else: + yaml_dict[key] = value + else: + yaml_dict[key] = value + return yaml_dict + + def __repr__(self): + """Print all variables.""" + ret_str = [] + for key, value in self.__dict__.items(): + if isinstance(value, AttrDict): + ret_str.append('{}:'.format(key)) + child_ret_str = value.__repr__().split('\n') + for item in child_ret_str: + ret_str.append(' ' + item) + elif isinstance(value, list): + if isinstance(value[0], AttrDict): + ret_str.append('{}:'.format(key)) + for item in value: + # Treat as AttrDict above. + child_ret_str = item.__repr__().split('\n') + for item in child_ret_str: + ret_str.append(' ' + item) + else: + ret_str.append('{}: {}'.format(key, value)) + else: + ret_str.append('{}: {}'.format(key, value)) + return '\n'.join(ret_str) + + +class Config(AttrDict): + r"""Configuration class. This should include every human specifiable + hyperparameter values for your training.""" + + def __init__(self, filename=None, verbose=False): + super(Config, self).__init__() + self.source_filename = filename + # Set default parameters. + # Logging. 
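+        # NOTE: `large_number` below is a sentinel: any periodic action left at this
+        # default (snapshot saving, image dumps, max_iter, ...) is effectively disabled
+        # until a YAML config overrides it. A minimal override might look like
+        # (hypothetical values, not taken from this repo's configs):
+        #   max_iter: 200000
+        #   snapshot_save_iter: 10000
+        #   logging_iter: 100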
+ large_number = 1000000000 + self.snapshot_save_iter = large_number + self.snapshot_save_epoch = large_number + self.metrics_iter = None + self.metrics_epoch = None + self.snapshot_save_start_iter = 0 + self.snapshot_save_start_epoch = 0 + self.image_save_iter = large_number + self.image_display_iter = large_number + self.max_epoch = large_number + self.max_iter = large_number + self.logging_iter = 100 + self.speed_benchmark = False + + # Trainer. + self.trainer = AttrDict( + model_average_config=AttrDict(enabled=False, + beta=0.9999, + start_iteration=1000, + num_batch_norm_estimation_iterations=30, + remove_sn=True), + # model_average=False, + # model_average_beta=0.9999, + # model_average_start_iteration=1000, + # model_average_batch_norm_estimation_iteration=30, + # model_average_remove_sn=True, + image_to_tensorboard=False, + hparam_to_tensorboard=False, + distributed_data_parallel='pytorch', + distributed_data_parallel_params=AttrDict( + find_unused_parameters=False), + delay_allreduce=True, + gan_relativistic=False, + gen_step=1, + dis_step=1, + gan_decay_k=1., + gan_min_k=1., + gan_separate_topk=False, + aug_policy='', + channels_last=False, + strict_resume=True, + amp_gp=False, + amp_config=AttrDict(init_scale=65536.0, + growth_factor=2.0, + backoff_factor=0.5, + growth_interval=2000, + enabled=False)) + + # Networks. + self.gen = AttrDict(type='imaginaire.generators.dummy') + self.dis = AttrDict(type='imaginaire.discriminators.dummy') + + # Optimizers. + self.gen_opt = AttrDict(type='adam', + fused_opt=False, + lr=0.0001, + adam_beta1=0.0, + adam_beta2=0.999, + eps=1e-8, + lr_policy=AttrDict(iteration_mode=False, + type='step', + step_size=large_number, + gamma=1)) + self.dis_opt = AttrDict(type='adam', + fused_opt=False, + lr=0.0001, + adam_beta1=0.0, + adam_beta2=0.999, + eps=1e-8, + lr_policy=AttrDict(iteration_mode=False, + type='step', + step_size=large_number, + gamma=1)) + # Data. + self.data = AttrDict(name='dummy', + type='imaginaire.datasets.images', + num_workers=0) + self.test_data = AttrDict(name='dummy', + type='imaginaire.datasets.images', + num_workers=0, + test=AttrDict(is_lmdb=False, + roots='', + batch_size=1)) + + +# Cudnn. + self.cudnn = AttrDict(deterministic=False, + benchmark=True) + + # Others. + self.pretrained_weight = '' + self.inference_args = AttrDict() + + # Update with given configurations. + assert os.path.exists(filename), 'File {} not exist.'.format(filename) + loader = yaml.SafeLoader + loader.add_implicit_resolver( + u'tag:yaml.org,2002:float', + re.compile(u'''^(?: + [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)? + |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+) + |\\.[0-9_]+(?:[eE][-+][0-9]+)? + |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]* + |[-+]?\\.(?:inf|Inf|INF) + |\\.(?:nan|NaN|NAN))$''', re.X), + list(u'-+0123456789.')) + try: + with open(filename, 'r') as f: + cfg_dict = yaml.load(f, Loader=loader) + except EnvironmentError: + print('Please check the file with name of "%s"', filename) + recursive_update(self, cfg_dict) + + # Put common opts in both gen and dis. 
+ if 'common' in cfg_dict: + self.common = AttrDict(**cfg_dict['common']) + self.gen.common = self.common + self.dis.common = self.common + + if verbose: + print(' imaginaire config '.center(80, '-')) + print(self.__repr__()) + print(''.center(80, '-')) + + +def rsetattr(obj, attr, val): + """Recursively find object and set value""" + pre, _, post = attr.rpartition('.') + return setattr(rgetattr(obj, pre) if pre else obj, post, val) + + +def rgetattr(obj, attr, *args): + """Recursively find object and return value""" + + def _getattr(obj, attr): + r"""Get attribute.""" + return getattr(obj, attr, *args) + + return functools.reduce(_getattr, [obj] + attr.split('.')) + + +def recursive_update(d, u): + """Recursively update AttrDict d with AttrDict u""" + for key, value in u.items(): + if isinstance(value, collections.abc.Mapping): + d.__dict__[key] = recursive_update(d.get(key, AttrDict({})), value) + elif isinstance(value, (list, tuple)): + if isinstance(value[0], dict): + d.__dict__[key] = [AttrDict(item) for item in value] + else: + d.__dict__[key] = value + else: + d.__dict__[key] = value + return d diff --git a/imaginaire/discriminators/__init__.py b/imaginaire/discriminators/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imaginaire/discriminators/gancraft.py b/imaginaire/discriminators/gancraft.py new file mode 100644 index 0000000000000000000000000000000000000000..0bc070cb46ac5c6ae287231ddd0144bedd6d55a2 --- /dev/null +++ b/imaginaire/discriminators/gancraft.py @@ -0,0 +1,278 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import functools +from imaginaire.layers import Conv2dBlock + +from imaginaire.utils.data import get_paired_input_label_channel_number, get_paired_input_image_channel_number +from imaginaire.utils.distributed import master_only_print as print + + +class Discriminator(nn.Module): + r"""Multi-resolution patch discriminator. Based on FPSE discriminator but with N+1 labels. + + Args: + dis_cfg (obj): Discriminator definition part of the yaml config file. + data_cfg (obj): Data definition part of the yaml config file. + """ + + def __init__(self, dis_cfg, data_cfg): + super(Discriminator, self).__init__() + # We assume the first datum is the ground truth image. + image_channels = get_paired_input_image_channel_number(data_cfg) + # Calculate number of channels in the input label. + num_labels = get_paired_input_label_channel_number(data_cfg) + + self.use_label = getattr(dis_cfg, 'use_label', True) + # Override number of input channels + if hasattr(dis_cfg, 'image_channels'): + image_channels = dis_cfg.image_channels + if hasattr(dis_cfg, 'num_labels'): + num_labels = dis_cfg.num_labels + else: + # We assume the first datum is the ground truth image. + image_channels = get_paired_input_image_channel_number(data_cfg) + # Calculate number of channels in the input label. + num_labels = get_paired_input_label_channel_number(data_cfg) + + if not self.use_label: + num_labels = 2 # ignore + true + + # Build the discriminator. 
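+        # All of the getattr() lookups below read optional keys from the `dis` section
+        # of the YAML config. For orientation, a hypothetical (illustrative only) block:
+        #   dis:
+        #     type: imaginaire.discriminators.gancraft
+        #     num_filters: 128
+        #     weight_norm_type: spectral
+        #     fpse_kernel_size: 3
+        #     do_multiscale: False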
+ num_filters = getattr(dis_cfg, 'num_filters', 128) + weight_norm_type = getattr(dis_cfg, 'weight_norm_type', 'spectral') + + fpse_kernel_size = getattr(dis_cfg, 'fpse_kernel_size', 3) + fpse_activation_norm_type = getattr(dis_cfg, + 'fpse_activation_norm_type', + 'none') + do_multiscale = getattr(dis_cfg, 'do_multiscale', False) + smooth_resample = getattr(dis_cfg, 'smooth_resample', False) + no_label_except_largest_scale = getattr(dis_cfg, 'no_label_except_largest_scale', False) + + self.fpse_discriminator = FPSEDiscriminator( + image_channels, + num_labels, + num_filters, + fpse_kernel_size, + weight_norm_type, + fpse_activation_norm_type, + do_multiscale, + smooth_resample, + no_label_except_largest_scale) + + def _single_forward(self, input_label, input_image, weights): + output_list, features_list = self.fpse_discriminator(input_image, input_label, weights) + return output_list, [features_list] + + def forward(self, data, net_G_output, weights=None, incl_real=False, incl_pseudo_real=False): + r"""GANcraft discriminator forward. + + Args: + data (dict): + - data (N x C1 x H x W tensor) : Ground truth images. + - label (N x C2 x H x W tensor) : Semantic representations. + - z (N x style_dims tensor): Gaussian random noise. + net_G_output (dict): + - fake_images (N x C1 x H x W tensor) : Fake images. + Returns: + output_x (dict): + - real_outputs (list): list of output tensors produced by + individual patch discriminators for real images. + - real_features (list): list of lists of features produced by + individual patch discriminators for real images. + - fake_outputs (list): list of output tensors produced by + individual patch discriminators for fake images. + - fake_features (list): list of lists of features produced by + individual patch discriminators for fake images. + """ + output_x = dict() + + # Fake. + fake_images = net_G_output['fake_images'] + if self.use_label: + fake_labels = data['fake_masks'] + else: + fake_labels = torch.zeros([fake_images.size(0), 2, fake_images.size( + 2), fake_images.size(3)], device=fake_images.device, dtype=fake_images.dtype) + fake_labels[:, 1, :, :] = 1 + output_x['fake_outputs'], output_x['fake_features'] = \ + self._single_forward(fake_labels, fake_images, None) + + # Real. + if incl_real: + real_images = data['images'] + if self.use_label: + real_labels = data['real_masks'] + else: + real_labels = torch.zeros([real_images.size(0), 2, real_images.size( + 2), real_images.size(3)], device=real_images.device, dtype=real_images.dtype) + real_labels[:, 1, :, :] = 1 + output_x['real_outputs'], output_x['real_features'] = \ + self._single_forward(real_labels, real_images, None) + + # pseudo-Real. 
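+        # Pseudo-real images are rendered by a frozen img2img network (see
+        # Generator.get_pseudo_gt, which uses a GauGAN-style generator) from the same
+        # fake semantic masks, so they serve as paired supervision for the sampled views.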
+ if incl_pseudo_real: + preal_images = data['pseudo_real_img'] + preal_labels = data['fake_masks'] + if not self.use_label: + preal_labels = torch.zeros([preal_images.size(0), 2, preal_images.size( + 2), preal_images.size(3)], device=preal_images.device, dtype=preal_images.dtype) + preal_labels[:, 1, :, :] = 1 + output_x['pseudo_real_outputs'], output_x['pseudo_real_features'] = \ + self._single_forward(preal_labels, preal_images, None) + + return output_x + + +class FPSEDiscriminator(nn.Module): + def __init__(self, + num_input_channels, + num_labels, + num_filters, + kernel_size, + weight_norm_type, + activation_norm_type, + do_multiscale, + smooth_resample, + no_label_except_largest_scale): + super().__init__() + + self.do_multiscale = do_multiscale + self.no_label_except_largest_scale = no_label_except_largest_scale + + padding = int(np.ceil((kernel_size - 1.0) / 2)) + nonlinearity = 'leakyrelu' + stride1_conv2d_block = \ + functools.partial(Conv2dBlock, + kernel_size=kernel_size, + stride=1, + padding=padding, + weight_norm_type=weight_norm_type, + activation_norm_type=activation_norm_type, + nonlinearity=nonlinearity, + # inplace_nonlinearity=True, + order='CNA') + down_conv2d_block = \ + functools.partial(Conv2dBlock, + kernel_size=kernel_size, + stride=2, + padding=padding, + weight_norm_type=weight_norm_type, + activation_norm_type=activation_norm_type, + nonlinearity=nonlinearity, + # inplace_nonlinearity=True, + order='CNA') + latent_conv2d_block = \ + functools.partial(Conv2dBlock, + kernel_size=1, + stride=1, + weight_norm_type=weight_norm_type, + activation_norm_type=activation_norm_type, + nonlinearity=nonlinearity, + # inplace_nonlinearity=True, + order='CNA') + # bottom-up pathway + self.enc1 = down_conv2d_block(num_input_channels, num_filters) # 3 + self.enc2 = down_conv2d_block(1 * num_filters, 2 * num_filters) # 7 + self.enc3 = down_conv2d_block(2 * num_filters, 4 * num_filters) # 15 + self.enc4 = down_conv2d_block(4 * num_filters, 8 * num_filters) # 31 + self.enc5 = down_conv2d_block(8 * num_filters, 8 * num_filters) # 63 + + # top-down pathway + # self.lat1 = latent_conv2d_block(num_filters, 2 * num_filters) # Zekun + self.lat2 = latent_conv2d_block(2 * num_filters, 4 * num_filters) + self.lat3 = latent_conv2d_block(4 * num_filters, 4 * num_filters) + self.lat4 = latent_conv2d_block(8 * num_filters, 4 * num_filters) + self.lat5 = latent_conv2d_block(8 * num_filters, 4 * num_filters) + + # upsampling + self.upsample2x = nn.Upsample(scale_factor=2, mode='bilinear', + align_corners=False) + + # final layers + self.final2 = stride1_conv2d_block(4 * num_filters, 2 * num_filters) + self.output = Conv2dBlock(num_filters * 2, num_labels+1, kernel_size=1) + + if self.do_multiscale: + self.final3 = stride1_conv2d_block(4 * num_filters, 2 * num_filters) + self.final4 = stride1_conv2d_block(4 * num_filters, 2 * num_filters) + if self.no_label_except_largest_scale: + self.output3 = Conv2dBlock(num_filters * 2, 2, kernel_size=1) + self.output4 = Conv2dBlock(num_filters * 2, 2, kernel_size=1) + else: + self.output3 = Conv2dBlock(num_filters * 2, num_labels+1, kernel_size=1) + self.output4 = Conv2dBlock(num_filters * 2, num_labels+1, kernel_size=1) + + self.interpolator = functools.partial(F.interpolate, mode='nearest') + if smooth_resample: + self.interpolator = self.smooth_interp + + @staticmethod + def smooth_interp(x, size): + r"""Smooth interpolation of segmentation maps. + + Args: + x (4D tensor): Segmentation maps. + size(2D list): Target size (H, W). 
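+        Returns:
+            x (4D tensor): One-hot segmentation maps resized to `size` (area-downsampled,
+                then re-binarized via per-pixel argmax).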
+ """ + x = F.interpolate(x, size=size, mode='area') + onehot_idx = torch.argmax(x, dim=-3, keepdims=True) + x.fill_(0.0) + x.scatter_(1, onehot_idx, 1.0) + return x + + # Weights: [N C] + def forward(self, images, segmaps, weights=None): + # Assume images 256x256 + # bottom-up pathway + feat11 = self.enc1(images) # 128 + feat12 = self.enc2(feat11) # 64 + feat13 = self.enc3(feat12) # 32 + feat14 = self.enc4(feat13) # 16 + feat15 = self.enc5(feat14) # 8 + # top-down pathway and lateral connections + feat25 = self.lat5(feat15) # 8 + feat24 = self.upsample2x(feat25) + self.lat4(feat14) # 16 + feat23 = self.upsample2x(feat24) + self.lat3(feat13) # 32 + feat22 = self.upsample2x(feat23) + self.lat2(feat12) # 64 + + # final prediction layers + feat32 = self.final2(feat22) + + results = [] + label_map = self.interpolator(segmaps, size=feat32.size()[2:]) + pred2 = self.output(feat32) # N, num_labels+1, H//4, W//4 + + features = [feat11, feat12, feat13, feat14, feat15, feat25, feat24, feat23, feat22] + if weights is not None: + label_map = label_map * weights[..., None, None] + results.append({'pred': pred2, 'label': label_map}) + + if self.do_multiscale: + feat33 = self.final3(feat23) + pred3 = self.output3(feat33) + + feat34 = self.final4(feat24) + pred4 = self.output4(feat34) + + if self.no_label_except_largest_scale: + label_map3 = torch.ones([pred3.size(0), 1, pred3.size(2), pred3.size(3)], device=pred3.device) + label_map4 = torch.ones([pred4.size(0), 1, pred4.size(2), pred4.size(3)], device=pred4.device) + else: + label_map3 = self.interpolator(segmaps, size=pred3.size()[2:]) + label_map4 = self.interpolator(segmaps, size=pred4.size()[2:]) + + if weights is not None: + label_map3 = label_map3 * weights[..., None, None] + label_map4 = label_map4 * weights[..., None, None] + + results.append({'pred': pred3, 'label': label_map3}) + results.append({'pred': pred4, 'label': label_map4}) + + return results, features diff --git a/imaginaire/generators/__init__.py b/imaginaire/generators/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..13acefe2181136b1629ec31f9d122fb46bf26780 --- /dev/null +++ b/imaginaire/generators/__init__.py @@ -0,0 +1,4 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md diff --git a/imaginaire/generators/gancraft_base.py b/imaginaire/generators/gancraft_base.py new file mode 100644 index 0000000000000000000000000000000000000000..eb6fc84b104e1695bc51d555c0c0b3307daa0b19 --- /dev/null +++ b/imaginaire/generators/gancraft_base.py @@ -0,0 +1,603 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. 
+# To view a copy of this license, check out LICENSE.md +import functools +import re + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from imaginaire.layers import Conv2dBlock, LinearBlock +from imaginaire.model_utils.layers import AffineMod, ModLinear +import imaginaire.model_utils.gancraft.mc_utils as mc_utils +import imaginaire.model_utils.gancraft.voxlib as voxlib +from imaginaire.utils.distributed import master_only_print as print + + +class RenderMLP(nn.Module): + r""" MLP with affine modulation.""" + + def __init__(self, in_channels, style_dim, viewdir_dim, mask_dim=680, + out_channels_s=1, out_channels_c=3, hidden_channels=256, + use_seg=True): + super(RenderMLP, self).__init__() + + self.use_seg = use_seg + if self.use_seg: + self.fc_m_a = nn.Linear(mask_dim, hidden_channels, bias=False) + + self.fc_viewdir = None + if viewdir_dim > 0: + self.fc_viewdir = nn.Linear(viewdir_dim, hidden_channels, bias=False) + + self.fc_1 = nn.Linear(in_channels, hidden_channels) + + self.fc_2 = ModLinear(hidden_channels, hidden_channels, style_dim, bias=False, mod_bias=True, output_mode=True) + self.fc_3 = ModLinear(hidden_channels, hidden_channels, style_dim, bias=False, mod_bias=True, output_mode=True) + self.fc_4 = ModLinear(hidden_channels, hidden_channels, style_dim, bias=False, mod_bias=True, output_mode=True) + + self.fc_sigma = nn.Linear(hidden_channels, out_channels_s) + + if viewdir_dim > 0: + self.fc_5 = nn.Linear(hidden_channels, hidden_channels, bias=False) + self.mod_5 = AffineMod(hidden_channels, style_dim, mod_bias=True) + else: + self.fc_5 = ModLinear(hidden_channels, hidden_channels, style_dim, + bias=False, mod_bias=True, output_mode=True) + self.fc_6 = ModLinear(hidden_channels, hidden_channels, style_dim, bias=False, mod_bias=True, output_mode=True) + self.fc_out_c = nn.Linear(hidden_channels, out_channels_c) + + self.act = nn.LeakyReLU(negative_slope=0.2) + + def forward(self, x, raydir, z, m): + r""" Forward network + + Args: + x (N x H x W x M x in_channels tensor): Projected features. + raydir (N x H x W x 1 x viewdir_dim tensor): Ray directions. + z (N x style_dim tensor): Style codes. + m (N x H x W x M x mask_dim tensor): One-hot segmentation maps. 
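+        Returns:
+            sigma (N x H x W x M x out_channels_s tensor): Raw per-sample densities.
+            c (N x H x W x M x out_channels_c tensor): Per-sample color features.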
+ """ + b, h, w, n, _ = x.size() + z = z[:, None, None, None, :] + + f = self.fc_1(x) + if self.use_seg: + f = f + self.fc_m_a(m) + # Common MLP + f = self.act(f) + f = self.act(self.fc_2(f, z)) + f = self.act(self.fc_3(f, z)) + f = self.act(self.fc_4(f, z)) + + # Sigma MLP + sigma = self.fc_sigma(f) + + # Color MLP + if self.fc_viewdir is not None: + f = self.fc_5(f) + f = f + self.fc_viewdir(raydir) + f = self.act(self.mod_5(f, z)) + else: + f = self.act(self.fc_5(f, z)) + f = self.act(self.fc_6(f, z)) + c = self.fc_out_c(f) + return sigma, c + + +class StyleMLP(nn.Module): + r"""MLP converting style code to intermediate style representation.""" + + def __init__(self, style_dim, out_dim, hidden_channels=256, leaky_relu=True, num_layers=5, normalize_input=True, + output_act=True): + super(StyleMLP, self).__init__() + + self.normalize_input = normalize_input + self.output_act = output_act + fc_layers = [] + fc_layers.append(nn.Linear(style_dim, hidden_channels, bias=True)) + for i in range(num_layers-1): + fc_layers.append(nn.Linear(hidden_channels, hidden_channels, bias=True)) + self.fc_layers = nn.ModuleList(fc_layers) + + self.fc_out = nn.Linear(hidden_channels, out_dim, bias=True) + + if leaky_relu: + self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + self.act = functools.partial(F.relu, inplace=True) + + def forward(self, z): + r""" Forward network + + Args: + z (N x style_dim tensor): Style codes. + """ + if self.normalize_input: + z = F.normalize(z, p=2, dim=-1) + for fc_layer in self.fc_layers: + z = self.act(fc_layer(z)) + z = self.fc_out(z) + if self.output_act: + z = self.act(z) + return z + + +class SKYMLP(nn.Module): + r"""MLP converting ray directions to sky features.""" + + def __init__(self, in_channels, style_dim, out_channels_c=3, + hidden_channels=256, leaky_relu=True): + super(SKYMLP, self).__init__() + self.fc_z_a = nn.Linear(style_dim, hidden_channels, bias=False) + + self.fc1 = nn.Linear(in_channels, hidden_channels) + self.fc2 = nn.Linear(hidden_channels, hidden_channels) + self.fc3 = nn.Linear(hidden_channels, hidden_channels) + self.fc4 = nn.Linear(hidden_channels, hidden_channels) + self.fc5 = nn.Linear(hidden_channels, hidden_channels) + + self.fc_out_c = nn.Linear(hidden_channels, out_channels_c) + + if leaky_relu: + self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + self.act = functools.partial(F.relu, inplace=True) + + def forward(self, x, z): + r"""Forward network + + Args: + x (... x in_channels tensor): Ray direction embeddings. + z (... x style_dim tensor): Style codes. 
+ """ + + z = self.fc_z_a(z) + while z.dim() < x.dim(): + z = z.unsqueeze(1) + + y = self.act(self.fc1(x) + z) + y = self.act(self.fc2(y)) + y = self.act(self.fc3(y)) + y = self.act(self.fc4(y)) + y = self.act(self.fc5(y)) + c = self.fc_out_c(y) + + return c + + +class RenderCNN(nn.Module): + r"""CNN converting intermediate feature map to final image.""" + + def __init__(self, in_channels, style_dim, hidden_channels=256, + leaky_relu=True): + super(RenderCNN, self).__init__() + self.fc_z_cond = nn.Linear(style_dim, 2 * 2 * hidden_channels) + + self.conv1 = nn.Conv2d(in_channels, hidden_channels, 1, stride=1, padding=0) + self.conv2a = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1) + self.conv2b = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1, bias=False) + + self.conv3a = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1) + self.conv3b = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1, bias=False) + + self.conv4a = nn.Conv2d(hidden_channels, hidden_channels, 1, stride=1, padding=0) + self.conv4b = nn.Conv2d(hidden_channels, hidden_channels, 1, stride=1, padding=0) + + self.conv4 = nn.Conv2d(hidden_channels, 3, 1, stride=1, padding=0) + + if leaky_relu: + self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + self.act = functools.partial(F.relu, inplace=True) + + def modulate(self, x, w, b): + w = w[..., None, None] + b = b[..., None, None] + return x * (w+1) + b + + def forward(self, x, z): + r"""Forward network. + + Args: + x (N x in_channels x H x W tensor): Intermediate feature map + z (N x style_dim tensor): Style codes. + """ + z = self.fc_z_cond(z) + adapt = torch.chunk(z, 2 * 2, dim=-1) + + y = self.act(self.conv1(x)) + + y = y + self.conv2b(self.act(self.conv2a(y))) + y = self.act(self.modulate(y, adapt[0], adapt[1])) + + y = y + self.conv3b(self.act(self.conv3a(y))) + y = self.act(self.modulate(y, adapt[2], adapt[3])) + + y = y + self.conv4b(self.act(self.conv4a(y))) + y = self.act(y) + + y = self.conv4(y) + + return y + + +class StyleEncoder(nn.Module): + r"""Style Encoder constructor. + + Args: + style_enc_cfg (obj): Style encoder definition file. 
+ """ + + def __init__(self, style_enc_cfg): + super(StyleEncoder, self).__init__() + input_image_channels = style_enc_cfg.input_image_channels + num_filters = style_enc_cfg.num_filters + kernel_size = style_enc_cfg.kernel_size + padding = int(np.ceil((kernel_size - 1.0) / 2)) + style_dims = style_enc_cfg.style_dims + weight_norm_type = style_enc_cfg.weight_norm_type + self.no_vae = getattr(style_enc_cfg, 'no_vae', False) + activation_norm_type = 'none' + nonlinearity = 'leakyrelu' + base_conv2d_block = \ + functools.partial(Conv2dBlock, + kernel_size=kernel_size, + stride=2, + padding=padding, + weight_norm_type=weight_norm_type, + activation_norm_type=activation_norm_type, + # inplace_nonlinearity=True, + nonlinearity=nonlinearity) + self.layer1 = base_conv2d_block(input_image_channels, num_filters) + self.layer2 = base_conv2d_block(num_filters * 1, num_filters * 2) + self.layer3 = base_conv2d_block(num_filters * 2, num_filters * 4) + self.layer4 = base_conv2d_block(num_filters * 4, num_filters * 8) + self.layer5 = base_conv2d_block(num_filters * 8, num_filters * 8) + self.layer6 = base_conv2d_block(num_filters * 8, num_filters * 8) + self.fc_mu = LinearBlock(num_filters * 8 * 4 * 4, style_dims) + if not self.no_vae: + self.fc_var = LinearBlock(num_filters * 8 * 4 * 4, style_dims) + + def forward(self, input_x): + r"""SPADE Style Encoder forward. + + Args: + input_x (N x 3 x H x W tensor): input images. + Returns: + mu (N x C tensor): Mean vectors. + logvar (N x C tensor): Log-variance vectors. + z (N x C tensor): Style code vectors. + """ + if input_x.size(2) != 256 or input_x.size(3) != 256: + input_x = F.interpolate(input_x, size=(256, 256), mode='bilinear') + x = self.layer1(input_x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.layer5(x) + x = self.layer6(x) + x = x.view(x.size(0), -1) + mu = self.fc_mu(x) + if not self.no_vae: + logvar = self.fc_var(x) + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + z = eps.mul(std) + mu + else: + z = mu + logvar = torch.zeros_like(mu) + return mu, logvar, z + + +class Base3DGenerator(nn.Module): + r"""Minecraft 3D generator constructor. + + Args: + gen_cfg (obj): Generator definition part of the yaml config file. + data_cfg (obj): Data definition part of the yaml config file. 
+ """ + + def __init__(self, gen_cfg, data_cfg): + super(Base3DGenerator, self).__init__() + print('Base3DGenerator initialization.') + + # ---------------------- Main Network ------------------------ + # Exclude some of the features from positional encoding + self.pe_no_pe_feat_dim = getattr(gen_cfg, 'pe_no_pe_feat_dim', 0) + + # blk_feat passes through PE + input_dim = (gen_cfg.blk_feat_dim-self.pe_no_pe_feat_dim)*(gen_cfg.pe_lvl_feat*2) + self.pe_no_pe_feat_dim + if (gen_cfg.pe_incl_orig_feat): + input_dim += (gen_cfg.blk_feat_dim-self.pe_no_pe_feat_dim) + print('[Base3DGenerator] Expected input dimensions: ', input_dim) + self.input_dim = input_dim + + self.mlp_model_kwargs = gen_cfg.mlp_model_kwargs + self.pe_lvl_localcoords = getattr(gen_cfg, 'pe_lvl_localcoords', 0) + if self.pe_lvl_localcoords > 0: + self.mlp_model_kwargs['poscode_dim'] = self.pe_lvl_localcoords * 2 * 3 + + # Set pe_lvl_raydir=0 and pe_incl_orig_raydir=False to disable view direction input + input_dim_viewdir = 3*(gen_cfg.pe_lvl_raydir*2) + if (gen_cfg.pe_incl_orig_raydir): + input_dim_viewdir += 3 + print('[Base3DGenerator] Expected viewdir input dimensions: ', input_dim_viewdir) + self.input_dim_viewdir = input_dim_viewdir + + self.pe_params = [gen_cfg.pe_lvl_feat, gen_cfg.pe_incl_orig_feat, + gen_cfg.pe_lvl_raydir, gen_cfg.pe_incl_orig_raydir] + + # Style input dimension + style_dims = gen_cfg.style_dims + self.style_dims = style_dims + interm_style_dims = getattr(gen_cfg, 'interm_style_dims', style_dims) + self.interm_style_dims = interm_style_dims + # ---------------------- Style MLP -------------------------- + self.style_net = globals()[gen_cfg.stylenet_model]( + style_dims, interm_style_dims, **gen_cfg.stylenet_model_kwargs) + + # number of output channels for MLP (before blending) + final_feat_dim = getattr(gen_cfg, 'final_feat_dim', 16) + self.final_feat_dim = final_feat_dim + + # ----------------------- Sky Network ------------------------- + sky_input_dim_base = 3 + # Dedicated sky network input dimensions + sky_input_dim = sky_input_dim_base*(gen_cfg.pe_lvl_raydir_sky*2) + if (gen_cfg.pe_incl_orig_raydir_sky): + sky_input_dim += sky_input_dim_base + print('[Base3DGenerator] Expected sky input dimensions: ', sky_input_dim) + self.pe_params_sky = [gen_cfg.pe_lvl_raydir_sky, gen_cfg.pe_incl_orig_raydir_sky] + self.sky_net = SKYMLP(sky_input_dim, style_dim=interm_style_dims, out_channels_c=final_feat_dim) + + # ----------------------- Style Encoder ------------------------- + style_enc_cfg = getattr(gen_cfg, 'style_enc', None) + setattr(style_enc_cfg, 'input_image_channels', 3) + setattr(style_enc_cfg, 'style_dims', gen_cfg.style_dims) + self.style_encoder = StyleEncoder(style_enc_cfg) + + # ---------------------- Ray Caster ------------------------- + self.num_blocks_early_stop = gen_cfg.num_blocks_early_stop + self.num_samples = gen_cfg.num_samples + self.sample_depth = gen_cfg.sample_depth + self.coarse_deterministic_sampling = getattr(gen_cfg, 'coarse_deterministic_sampling', True) + self.sample_use_box_boundaries = getattr(gen_cfg, 'sample_use_box_boundaries', True) + + # ---------------------- Blender ------------------------- + self.raw_noise_std = getattr(gen_cfg, 'raw_noise_std', 0.0) + self.dists_scale = getattr(gen_cfg, 'dists_scale', 0.25) + self.clip_feat_map = getattr(gen_cfg, 'clip_feat_map', True) + self.keep_sky_out = getattr(gen_cfg, 'keep_sky_out', False) + self.keep_sky_out_avgpool = getattr(gen_cfg, 'keep_sky_out_avgpool', False) + keep_sky_out_learnbg = getattr(gen_cfg, 
'keep_sky_out_learnbg', False) + self.sky_global_avgpool = getattr(gen_cfg, 'sky_global_avgpool', False) + if self.keep_sky_out: + self.sky_replace_color = None + if keep_sky_out_learnbg: + sky_replace_color = torch.zeros([final_feat_dim]) + sky_replace_color.requires_grad = True + self.sky_replace_color = torch.nn.Parameter(sky_replace_color) + # ---------------------- render_cnn ------------------------- + self.denoiser = RenderCNN(final_feat_dim, style_dim=interm_style_dims) + self.pad = gen_cfg.pad + + def get_param_groups(self, cfg_opt): + print('[Generator] get_param_groups') + + if hasattr(cfg_opt, 'ignore_parameters'): + print('[Generator::get_param_groups] [x]: ignored.') + optimize_parameters = [] + for k, x in self.named_parameters(): + match = False + for m in cfg_opt.ignore_parameters: + if re.match(m, k) is not None: + match = True + print(' [x]', k) + break + if match is False: + print(' [v]', k) + optimize_parameters.append(x) + else: + optimize_parameters = self.parameters() + + param_groups = [] + param_groups.append({'params': optimize_parameters}) + + if hasattr(cfg_opt, 'param_groups'): + optimized_param_names = [] + all_param_names = [k for k, v in self.named_parameters()] + param_groups = [] + for k, v in cfg_opt.param_groups.items(): + print('[Generator::get_param_groups] Adding param group from config:', k, v) + params = getattr(self, k) + named_parameters = [k] + if issubclass(type(params), nn.Module): + named_parameters = [k+'.'+pname for pname, _ in params.named_parameters()] + params = params.parameters() + param_groups.append({'params': params, **v}) + optimized_param_names.extend(named_parameters) + + print('[Generator::get_param_groups] UNOPTIMIZED PARAMETERS:\n ', + set(all_param_names) - set(optimized_param_names)) + + return param_groups + + def _forward_perpix_sub(self, blk_feats, worldcoord2, raydirs_in, z, mc_masks_onehot=None): + r"""Forwarding the MLP. + + Args: + blk_feats (K x C1 tensor): Sparse block features. + worldcoord2 (N x H x W x L x 3 tensor): 3D world coordinates of sampled points. L is number of samples; N is batch size, always 1. + raydirs_in (N x H x W x 1 x C2 tensor or None): ray direction embeddings. + z (N x C3 tensor): Intermediate style vectors. + mc_masks_onehot (N x H x W x L x C4): One-hot segmentation maps. + Returns: + net_out_s (N x H x W x L x 1 tensor): Opacities. + net_out_c (N x H x W x L x C5 tensor): Color embeddings. 
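+
+        The features are gathered by sparse trilinear interpolation of `blk_feats` at
+        `worldcoord2`, optionally passed through sinusoidal positional encoding
+        (controlled by pe_lvl_feat / pe_incl_orig_feat), and then fed to `render_net`
+        together with the ray-direction embedding and the style code.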
+ """ + proj_feature = voxlib.sparse_trilinear_interp_worldcoord( + blk_feats, self.voxel.corner_t, worldcoord2, ign_zero=True) + + render_net_extra_kwargs = {} + if self.pe_lvl_localcoords > 0: + local_coords = torch.remainder(worldcoord2, 1.0) * 2.0 + # Scale to [0, 2], as the positional encoding function doesn't have internal x2 + local_coords[torch.isnan(local_coords)] = 0.0 + local_coords = local_coords.contiguous() + poscode = voxlib.positional_encoding(local_coords, self.pe_lvl_localcoords, -1, False) + render_net_extra_kwargs['poscode'] = poscode + + if self.pe_params[0] == 0 and self.pe_params[1] is True: # no PE shortcut, saves ~400MB + feature_in = proj_feature + else: + if self.pe_no_pe_feat_dim > 0: + feature_in = voxlib.positional_encoding( + proj_feature[..., :-self.pe_no_pe_feat_dim].contiguous(), self.pe_params[0], -1, self.pe_params[1]) + feature_in = torch.cat([feature_in, proj_feature[..., -self.pe_no_pe_feat_dim:]], dim=-1) + else: + feature_in = voxlib.positional_encoding( + proj_feature.contiguous(), self.pe_params[0], -1, self.pe_params[1]) + + net_out_s, net_out_c = self.render_net(feature_in, raydirs_in, z, mc_masks_onehot, **render_net_extra_kwargs) + + if self.raw_noise_std > 0.: + noise = torch.randn_like(net_out_s) * self.raw_noise_std + net_out_s = net_out_s + noise + + return net_out_s, net_out_c + + def _forward_perpix(self, blk_feats, voxel_id, depth2, raydirs, cam_ori_t, z): + r"""Sample points along rays, forwarding the per-point MLP and aggregate pixel features + + Args: + blk_feats (K x C1 tensor): Sparse block features. + voxel_id (N x H x W x M x 1 tensor): Voxel ids from ray-voxel intersection test. M: num intersected voxels, why always 6? + depth2 (N x 2 x H x W x M x 1 tensor): Depths of entrance and exit points for each ray-voxel intersection. + raydirs (N x H x W x 1 x 3 tensor): The direction of each ray. + cam_ori_t (N x 3 tensor): Camera origins. + z (N x C3 tensor): Intermediate style vectors. + """ + # Generate sky_mask; PE transform on ray direction. 
+ with torch.no_grad(): + raydirs_in = raydirs.expand(-1, -1, -1, 1, -1).contiguous() + if self.pe_params[2] == 0 and self.pe_params[3] is True: + raydirs_in = raydirs_in + elif self.pe_params[2] == 0 and self.pe_params[3] is False: # Not using raydir at all + raydirs_in = None + else: + raydirs_in = voxlib.positional_encoding(raydirs_in, self.pe_params[2], -1, self.pe_params[3]) + + # sky_mask: when True, ray finally hits sky + sky_mask = voxel_id[:, :, :, [-1], :] == 0 + # sky_only_mask: when True, ray hits nothing but sky + sky_only_mask = voxel_id[:, :, :, [0], :] == 0 + + with torch.no_grad(): + # Random sample points along the ray + num_samples = self.num_samples + 1 + if self.sample_use_box_boundaries: + num_samples = self.num_samples - self.num_blocks_early_stop + + # 10 samples per ray + 4 intersections - 2 + rand_depth, new_dists, new_idx = mc_utils.sample_depth_batched( + depth2, num_samples, deterministic=self.coarse_deterministic_sampling, + use_box_boundaries=self.sample_use_box_boundaries, sample_depth=self.sample_depth) + + worldcoord2 = raydirs * rand_depth + cam_ori_t[:, None, None, None, :] + + # Generate per-sample segmentation label + voxel_id_reduced = self.label_trans.mc2reduced(voxel_id, ign2dirt=True) + mc_masks = torch.gather(voxel_id_reduced, -2, new_idx) # B 256 256 N 1 + mc_masks = mc_masks.long() + mc_masks_onehot = torch.zeros([mc_masks.size(0), mc_masks.size(1), mc_masks.size( + 2), mc_masks.size(3), self.num_reduced_labels], dtype=torch.float, device=voxel_id.device) + # mc_masks_onehot: [B H W Nlayer 680] + mc_masks_onehot.scatter_(-1, mc_masks, 1.0) + + net_out_s, net_out_c = self._forward_perpix_sub(blk_feats, worldcoord2, raydirs_in, z, mc_masks_onehot) + + # Handle sky + sky_raydirs_in = raydirs.expand(-1, -1, -1, 1, -1).contiguous() + sky_raydirs_in = voxlib.positional_encoding(sky_raydirs_in, self.pe_params_sky[0], -1, self.pe_params_sky[1]) + skynet_out_c = self.sky_net(sky_raydirs_in, z) + + # Blending + weights = mc_utils.volum_rendering_relu(net_out_s, new_dists * self.dists_scale, dim=-2) + + # If a ray exclusively hits the sky (no intersection with the voxels), set its weight to zero. 
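+        # For reference, volum_rendering_relu is assumed to implement standard alpha
+        # compositing along the sample dimension (a sketch, not the helper verbatim):
+        #   alpha_i  = 1 - exp(-relu(sigma_i) * delta_i)
+        #   weight_i = alpha_i * prod_{j<i} (1 - alpha_j)
+        # with delta_i taken from new_dists * self.dists_scale. The masking below then
+        # removes any contribution from rays that only hit sky.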
+ weights = weights * torch.logical_not(sky_only_mask).float() + total_weights_raw = torch.sum(weights, dim=-2, keepdim=True) # 256 256 1 1 + total_weights = total_weights_raw + + is_gnd = worldcoord2[..., [0]] <= 1.0 # Y X Z, [256, 256, 4, 3], nan < 1.0 == False + is_gnd = is_gnd.any(dim=-2, keepdim=True) + nosky_mask = torch.logical_or(torch.logical_not(sky_mask), is_gnd) + nosky_mask = nosky_mask.float() + + # Avoid sky leakage + sky_weight = 1.0-total_weights + if self.keep_sky_out: + # keep_sky_out_avgpool overrides sky_replace_color + if self.sky_replace_color is None or self.keep_sky_out_avgpool: + if self.keep_sky_out_avgpool: + if hasattr(self, 'sky_avg'): + sky_avg = self.sky_avg + else: + if self.sky_global_avgpool: + sky_avg = torch.mean(skynet_out_c, dim=[1, 2], keepdim=True) + else: + skynet_out_c_nchw = skynet_out_c.permute(0, 4, 1, 2, 3).squeeze(-1).contiguous() + sky_avg = F.avg_pool2d(skynet_out_c_nchw, 31, stride=1, padding=15, count_include_pad=False) + sky_avg = sky_avg.permute(0, 2, 3, 1).unsqueeze(-2).contiguous() + # print(sky_avg.shape) + skynet_out_c = skynet_out_c * (1.0-nosky_mask) + sky_avg*(nosky_mask) + else: + sky_weight = sky_weight * (1.0-nosky_mask) + else: + skynet_out_c = skynet_out_c * (1.0-nosky_mask) + self.sky_replace_color*(nosky_mask) + + if self.clip_feat_map is True: # intermediate feature before blending & CNN + rgbs = torch.clamp(net_out_c, -1, 1) + 1 + rgbs_sky = torch.clamp(skynet_out_c, -1, 1) + 1 + net_out = torch.sum(weights*rgbs, dim=-2, keepdim=True) + sky_weight * \ + rgbs_sky # 576, 768, 4, 3 -> 576, 768, 3 + net_out = net_out.squeeze(-2) + net_out = net_out - 1 + elif self.clip_feat_map is False: + rgbs = net_out_c + rgbs_sky = skynet_out_c + net_out = torch.sum(weights*rgbs, dim=-2, keepdim=True) + sky_weight * \ + rgbs_sky # 576, 768, 4, 3 -> 576, 768, 3 + net_out = net_out.squeeze(-2) + elif self.clip_feat_map == 'tanh': + rgbs = torch.tanh(net_out_c) + rgbs_sky = torch.tanh(skynet_out_c) + net_out = torch.sum(weights*rgbs, dim=-2, keepdim=True) + sky_weight * \ + rgbs_sky # 576, 768, 4, 3 -> 576, 768, 3 + net_out = net_out.squeeze(-2) + else: + raise NotImplementedError + + return net_out, new_dists, weights, total_weights_raw, rand_depth, net_out_s, net_out_c, skynet_out_c, \ + nosky_mask, sky_mask, sky_only_mask, new_idx + + def _forward_global(self, net_out, z): + r"""Forward the CNN + + Args: + net_out (N x C5 x H x W tensor): Intermediate feature maps. + z (N x C3 tensor): Intermediate style vectors. + + Returns: + fake_images (N x 3 x H x W tensor): Output image. + fake_images_raw (N x 3 x H x W tensor): Output image before TanH. 
+ """ + fake_images = net_out.permute(0, 3, 1, 2).contiguous() + fake_images_raw = self.denoiser(fake_images, z) + fake_images = torch.tanh(fake_images_raw) + + return fake_images, fake_images_raw diff --git a/imaginaire/generators/scenedreamer.py b/imaginaire/generators/scenedreamer.py new file mode 100644 index 0000000000000000000000000000000000000000..ed32b6b7906b659161b2c57e6292ed94761b309b --- /dev/null +++ b/imaginaire/generators/scenedreamer.py @@ -0,0 +1,851 @@ +# Using Hashgrid as backbone representation + +import os +import cv2 +import imageio +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +import imaginaire.model_utils.gancraft.camctl as camctl +import imaginaire.model_utils.gancraft.mc_utils as mc_utils +import imaginaire.model_utils.gancraft.voxlib as voxlib +from imaginaire.model_utils.pcg_gen import PCGVoxelGenerator, PCGCache +from imaginaire.utils.distributed import master_only_print as print +from imaginaire.generators.gancraft_base import Base3DGenerator +from encoding import get_encoder + +from imaginaire.model_utils.layers import LightningMLP, ConditionalHashGrid + +class Generator(Base3DGenerator): + r"""SceneDreamer generator constructor. + + Args: + gen_cfg (obj): Generator definition part of the yaml config file. + data_cfg (obj): Data definition part of the yaml config file. + """ + + def __init__(self, gen_cfg, data_cfg): + super(Generator, self).__init__(gen_cfg, data_cfg) + print('SceneDreamer[Hash] on ALL Scenes generator initialization.') + + # here should be a list of height maps and semantic maps + if gen_cfg.pcg_cache: + print('[Generator] Loading PCG dataset: ', gen_cfg.pcg_dataset_path) + self.voxel = PCGCache(gen_cfg.pcg_dataset_path) + print('[Generator] Loaded PCG dataset.') + else: + self.voxel = PCGVoxelGenerator(gen_cfg.scene_size) + self.blk_feats = None + # Minecraft -> SPADE label translator. + self.label_trans = mc_utils.MCLabelTranslator() + self.num_reduced_labels = self.label_trans.get_num_reduced_lbls() + self.reduced_label_set = getattr(gen_cfg, 'reduced_label_set', False) + self.use_label_smooth = getattr(gen_cfg, 'use_label_smooth', False) + self.use_label_smooth_real = getattr(gen_cfg, 'use_label_smooth_real', self.use_label_smooth) + self.use_label_smooth_pgt = getattr(gen_cfg, 'use_label_smooth_pgt', False) + self.label_smooth_dia = getattr(gen_cfg, 'label_smooth_dia', 11) + + # Load MLP model. + self.hash_encoder, self.hash_in_dim = get_encoder(encoding='hashgrid', input_dim=5, desired_resolution=2048 * 1, level_dim=8) + self.render_net = LightningMLP(self.hash_in_dim, viewdir_dim=self.input_dim_viewdir, style_dim=self.interm_style_dims, mask_dim=self.num_reduced_labels, out_channels_s=1, out_channels_c=self.final_feat_dim, **self.mlp_model_kwargs) + print(self.hash_encoder) + self.world_encoder = ConditionalHashGrid() + + # Camera sampler. 
+ self.camera_sampler_type = getattr(gen_cfg, 'camera_sampler_type', "random") + assert self.camera_sampler_type in ['random', 'traditional'] + self.camera_min_entropy = getattr(gen_cfg, 'camera_min_entropy', -1) + self.camera_rej_avg_depth = getattr(gen_cfg, 'camera_rej_avg_depth', -1) + self.cam_res = gen_cfg.cam_res + self.crop_size = gen_cfg.crop_size + + print('Done with the SceneDreamer initialization.') + + def custom_init(self): + r"""Weight initialization.""" + + def init_func(m): + if hasattr(m, 'weight'): + try: + nn.init.kaiming_normal_(m.weight.data, a=0.2, nonlinearity='leaky_relu') + except: + print(m.name) + m.weight.data *= 0.5 + if hasattr(m, 'bias') and m.bias is not None: + m.bias.data.fill_(0.0) + self.apply(init_func) + + def _get_batch(self, batch_size, device): + r"""Sample camera poses and perform ray-voxel intersection. + + Args: + batch_size (int): Expected batch size of the current batch + device (torch.device): Device on which the tensors should be stored + """ + with torch.no_grad(): + self.voxel.sample_world(device) + voxel_id_batch = [] + depth2_batch = [] + raydirs_batch = [] + cam_ori_t_batch = [] + for b in range(batch_size): + while True: # Rejection sampling. + # Sample camera pose. + if self.camera_sampler_type == 'random': + cam_res = self.cam_res + cam_ori_t, cam_dir_t, cam_up_t = camctl.rand_camera_pose_thridperson2(self.voxel) + # ~24mm fov horizontal. + cam_f = 0.5/np.tan(np.deg2rad(73/2) * (np.random.rand(1)*0.5+0.5)) * (cam_res[1]-1) + cam_c = [(cam_res[0]-1)/2, (cam_res[1]-1)/2] + cam_res_crop = [self.crop_size[0] + self.pad, self.crop_size[1] + self.pad] + cam_c = mc_utils.rand_crop(cam_c, cam_res, cam_res_crop) + elif self.camera_sampler_type == 'traditional': + cam_res = self.cam_res + cam_c = [(cam_res[0]-1)/2, (cam_res[1]-1)/2] + dice = torch.rand(1).item() + if dice > 0.5: + cam_ori_t, cam_dir_t, cam_up_t, cam_f = \ + camctl.rand_camera_pose_tour(self.voxel) + cam_f = cam_f * (cam_res[1]-1) + else: + cam_ori_t, cam_dir_t, cam_up_t = \ + camctl.rand_camera_pose_thridperson2(self.voxel) + # ~24mm fov horizontal. + cam_f = 0.5 / np.tan(np.deg2rad(73/2) * (np.random.rand(1)*0.5+0.5)) * (cam_res[1]-1) + + cam_res_crop = [self.crop_size[0] + self.pad, self.crop_size[1] + self.pad] + cam_c = mc_utils.rand_crop(cam_c, cam_res, cam_res_crop) + else: + raise NotImplementedError( + 'Unknown self.camera_sampler_type: {}'.format(self.camera_sampler_type)) + + # Run ray-voxel intersection test + voxel_id, depth2, raydirs = voxlib.ray_voxel_intersection_perspective( + self.voxel.voxel_t, cam_ori_t, cam_dir_t, cam_up_t, cam_f, cam_c, cam_res_crop, + self.num_blocks_early_stop) + + if self.camera_rej_avg_depth > 0: + depth_map = depth2[0, :, :, 0, :] + avg_depth = torch.mean(depth_map[~torch.isnan(depth_map)]) + if avg_depth < self.camera_rej_avg_depth: + continue + + # Reject low entropy. + if self.camera_min_entropy > 0: + # Check entropy. 
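+                        # The bincount below histograms the first-intersection voxel labels
+                        # over the crop; frames whose label entropy is below
+                        # camera_min_entropy (e.g. a view filled by a single material) are
+                        # rejected and resampled.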
+ maskcnt = torch.bincount( + torch.flatten(voxel_id[:, :, 0, 0]), weights=None, minlength=680).float() / \ + (voxel_id.size(0)*voxel_id.size(1)) + maskentropy = -torch.sum(maskcnt * torch.log(maskcnt+1e-10)) + if maskentropy < self.camera_min_entropy: + continue + break + + voxel_id_batch.append(voxel_id) + depth2_batch.append(depth2) + raydirs_batch.append(raydirs) + cam_ori_t_batch.append(cam_ori_t) + voxel_id = torch.stack(voxel_id_batch, dim=0) + depth2 = torch.stack(depth2_batch, dim=0) + raydirs = torch.stack(raydirs_batch, dim=0) + cam_ori_t = torch.stack(cam_ori_t_batch, dim=0).to(device) + cam_poses = None + return voxel_id, depth2, raydirs, cam_ori_t, cam_poses + + + def get_pseudo_gt(self, pseudo_gen, voxel_id, z=None, style_img=None, resize_512=True, deterministic=False): + r"""Evaluating img2img network to obtain pseudo-ground truth images. + + Args: + pseudo_gen (callable): Function converting mask to image using img2img network. + voxel_id (N x img_dims[0] x img_dims[1] x max_samples x 1 tensor): IDs of intersected tensors along + each ray. + z (N x C tensor): Optional style code passed to pseudo_gen. + style_img (N x 3 x H x W tensor): Optional style image passed to pseudo_gen. + resize_512 (bool): If True, evaluate pseudo_gen at 512x512 regardless of input resolution. + deterministic (bool): If True, disable stochastic label mapping. + """ + with torch.no_grad(): + mc_mask = voxel_id[:, :, :, 0, :].permute(0, 3, 1, 2).long().contiguous() + coco_mask = self.label_trans.mc2coco(mc_mask) - 1 + coco_mask[coco_mask < 0] = 183 + + if not deterministic: + # Stochastic mapping + dice = torch.rand(1).item() + if dice > 0.5 and dice < 0.9: + coco_mask[coco_mask == self.label_trans.gglbl2ggid('sky')] = self.label_trans.gglbl2ggid('clouds') + elif dice >= 0.9: + coco_mask[coco_mask == self.label_trans.gglbl2ggid('sky')] = self.label_trans.gglbl2ggid('fog') + dice = torch.rand(1).item() + if dice > 0.33 and dice < 0.66: + coco_mask[coco_mask == self.label_trans.gglbl2ggid('water')] = self.label_trans.gglbl2ggid('sea') + elif dice >= 0.66: + coco_mask[coco_mask == self.label_trans.gglbl2ggid('water')] = self.label_trans.gglbl2ggid('river') + + fake_masks = torch.zeros([coco_mask.size(0), 185, coco_mask.size(2), coco_mask.size(3)], + dtype=torch.half, device=voxel_id.device) + fake_masks.scatter_(1, coco_mask, 1.0) + + if self.use_label_smooth_pgt: + fake_masks = mc_utils.segmask_smooth(fake_masks, kernel_size=self.label_smooth_dia) + if self.pad > 0: + fake_masks = fake_masks[:, :, self.pad//2:-self.pad//2, self.pad//2:-self.pad//2] + + # Generate pseudo GT using GauGAN. + if resize_512: + fake_masks_512 = F.interpolate(fake_masks, size=[512, 512], mode='nearest') + else: + fake_masks_512 = fake_masks + pseudo_real_img = pseudo_gen(fake_masks_512, z=z, style_img=style_img) + + # NaN Inf Guard. NaN can occure on Volta GPUs. + nan_mask = torch.isnan(pseudo_real_img) + inf_mask = torch.isinf(pseudo_real_img) + pseudo_real_img[nan_mask | inf_mask] = 0.0 + if resize_512: + pseudo_real_img = F.interpolate( + pseudo_real_img, size=[fake_masks.size(2), fake_masks.size(3)], mode='area') + pseudo_real_img = torch.clamp(pseudo_real_img, -1, 1) + + return pseudo_real_img, fake_masks + + + def sample_camera(self, data, pseudo_gen): + r"""Sample camera randomly and precompute everything used by both Gen and Dis. 
+ + Args: + data (dict): + images (N x 3 x H x W tensor) : Real images + label (N x C2 x H x W tensor) : Segmentation map + pseudo_gen (callable): Function converting mask to image using img2img network. + Returns: + ret (dict): + voxel_id (N x H x W x max_samples x 1 tensor): IDs of intersected tensors along each ray. + depth2 (N x 2 x H x W x max_samples x 1 tensor): Depths of entrance and exit points for each ray-voxel + intersection. + raydirs (N x H x W x 1 x 3 tensor): The direction of each ray. + cam_ori_t (N x 3 tensor): Camera origins. + pseudo_real_img (N x 3 x H x W tensor): Pseudo-ground truth image. + real_masks (N x C3 x H x W tensor): One-hot segmentation map for real images, with translated labels. + fake_masks (N x C3 x H x W tensor): One-hot segmentation map for sampled camera views. + """ + device = torch.device('cuda') + batch_size = data['images'].size(0) + # ================ Assemble a batch ================== + # Requires: voxel_id, depth2, raydirs, cam_ori_t. + voxel_id, depth2, raydirs, cam_ori_t, _ = self._get_batch(batch_size, device) + ret = {'voxel_id': voxel_id, 'depth2': depth2, 'raydirs': raydirs, 'cam_ori_t': cam_ori_t} + + if pseudo_gen is not None: + pseudo_real_img, _ = self.get_pseudo_gt(pseudo_gen, voxel_id) + ret['pseudo_real_img'] = pseudo_real_img.float() + + # =============== Mask translation ================ + real_masks = data['label'] + if self.reduced_label_set: + # Translate fake mask (directly from mcid). + # convert unrecognized labels to 'dirt'. + # N C H W [1 1 80 80] + reduce_fake_mask = self.label_trans.mc2reduced( + voxel_id[:, :, :, 0, :].permute(0, 3, 1, 2).long().contiguous() + , ign2dirt=True) + reduce_fake_mask_onehot = torch.zeros([ + reduce_fake_mask.size(0), self.num_reduced_labels, reduce_fake_mask.size(2), reduce_fake_mask.size(3)], + dtype=torch.float, device=device) + reduce_fake_mask_onehot.scatter_(1, reduce_fake_mask, 1.0) + fake_masks = reduce_fake_mask_onehot + if self.pad != 0: + fake_masks = fake_masks[:, :, self.pad//2:-self.pad//2, self.pad//2:-self.pad//2] + + # Translate real mask (data['label']), which is onehot. + real_masks_idx = torch.argmax(real_masks, dim=1, keepdim=True) + real_masks_idx[real_masks_idx > 182] = 182 + + reduced_real_mask = self.label_trans.coco2reduced(real_masks_idx) + reduced_real_mask_onehot = torch.zeros([ + reduced_real_mask.size(0), self.num_reduced_labels, reduced_real_mask.size(2), + reduced_real_mask.size(3)], dtype=torch.float, device=device) + reduced_real_mask_onehot.scatter_(1, reduced_real_mask, 1.0) + real_masks = reduced_real_mask_onehot + + # Mask smoothing. + if self.use_label_smooth: + fake_masks = mc_utils.segmask_smooth(fake_masks, kernel_size=self.label_smooth_dia) + if self.use_label_smooth_real: + real_masks = mc_utils.segmask_smooth(real_masks, kernel_size=self.label_smooth_dia) + + ret['real_masks'] = real_masks + ret['fake_masks'] = fake_masks + + return ret + + def _forward_perpix_sub(self, blk_feats, worldcoord2, raydirs_in, z, mc_masks_onehot=None, global_enc=None): + r"""Per-pixel rendering forwarding + + Args: + blk_feats: Deprecated + worldcoord2 (N x H x W x L x 3 tensor): 3D world coordinates of sampled points. L is number of samples; N is batch size, always 1. + raydirs_in (N x H x W x 1 x C2 tensor or None): ray direction embeddings. + z (N x C3 tensor): Intermediate style vectors. + mc_masks_onehot (N x H x W x L x C4): One-hot segmentation maps. + Returns: + net_out_s (N x H x W x L x 1 tensor): Opacities. 
+ net_out_c (N x H x W x L x C5 tensor): Color embeddings. + """ + _x, _y, _z = self.voxel.voxel_t.shape + delimeter = torch.Tensor([_x, _y, _z]).to(worldcoord2) + normalized_coord = worldcoord2 / delimeter * 2 - 1 + global_enc = global_enc[:, None, None, None, :].repeat(1, normalized_coord.shape[1], normalized_coord.shape[2], normalized_coord.shape[3], 1) + normalized_coord = torch.cat([normalized_coord, global_enc], dim=-1) + feature_in = self.hash_encoder(normalized_coord) + + net_out_s, net_out_c = self.render_net(feature_in, raydirs_in, z, mc_masks_onehot) + + if self.raw_noise_std > 0.: + noise = torch.randn_like(net_out_s) * self.raw_noise_std + net_out_s = net_out_s + noise + + return net_out_s, net_out_c + + def _forward_perpix(self, blk_feats, voxel_id, depth2, raydirs, cam_ori_t, z, global_enc): + r"""Sample points along rays, forwarding the per-point MLP and aggregate pixel features + + Args: + blk_feats (K x C1 tensor): Deprecated + voxel_id (N x H x W x M x 1 tensor): Voxel ids from ray-voxel intersection test. M: num intersected voxels, why always 6? + depth2 (N x 2 x H x W x M x 1 tensor): Depths of entrance and exit points for each ray-voxel intersection. + raydirs (N x H x W x 1 x 3 tensor): The direction of each ray. + cam_ori_t (N x 3 tensor): Camera origins. + z (N x C3 tensor): Intermediate style vectors. + """ + # Generate sky_mask; PE transform on ray direction. + with torch.no_grad(): + raydirs_in = raydirs.expand(-1, -1, -1, 1, -1).contiguous() + if self.pe_params[2] == 0 and self.pe_params[3] is True: + raydirs_in = raydirs_in + elif self.pe_params[2] == 0 and self.pe_params[3] is False: # Not using raydir at all + raydirs_in = None + else: + raydirs_in = voxlib.positional_encoding(raydirs_in, self.pe_params[2], -1, self.pe_params[3]) + + # sky_mask: when True, ray finally hits sky + sky_mask = voxel_id[:, :, :, [-1], :] == 0 + # sky_only_mask: when True, ray hits nothing but sky + sky_only_mask = voxel_id[:, :, :, [0], :] == 0 + + with torch.no_grad(): + # Random sample points along the ray + num_samples = self.num_samples + 1 + if self.sample_use_box_boundaries: + num_samples = self.num_samples - self.num_blocks_early_stop + + # 10 samples per ray + 4 intersections - 2 + rand_depth, new_dists, new_idx = mc_utils.sample_depth_batched( + depth2, num_samples, deterministic=self.coarse_deterministic_sampling, + use_box_boundaries=self.sample_use_box_boundaries, sample_depth=self.sample_depth) + + nan_mask = torch.isnan(rand_depth) + inf_mask = torch.isinf(rand_depth) + rand_depth[nan_mask | inf_mask] = 0.0 + + worldcoord2 = raydirs * rand_depth + cam_ori_t[:, None, None, None, :] + + # Generate per-sample segmentation label + voxel_id_reduced = self.label_trans.mc2reduced(voxel_id, ign2dirt=True) + mc_masks = torch.gather(voxel_id_reduced, -2, new_idx) # B 256 256 N 1 + mc_masks = mc_masks.long() + mc_masks_onehot = torch.zeros([mc_masks.size(0), mc_masks.size(1), mc_masks.size( + 2), mc_masks.size(3), self.num_reduced_labels], dtype=torch.float, device=voxel_id.device) + # mc_masks_onehot: [B H W Nlayer 680] + mc_masks_onehot.scatter_(-1, mc_masks, 1.0) + + net_out_s, net_out_c = self._forward_perpix_sub(blk_feats, worldcoord2, raydirs_in, z, mc_masks_onehot, global_enc) + + # Handle sky + sky_raydirs_in = raydirs.expand(-1, -1, -1, 1, -1).contiguous() + sky_raydirs_in = voxlib.positional_encoding(sky_raydirs_in, self.pe_params_sky[0], -1, self.pe_params_sky[1]) + skynet_out_c = self.sky_net(sky_raydirs_in, z) + + # Blending + weights = 
mc_utils.volum_rendering_relu(net_out_s, new_dists * self.dists_scale, dim=-2) + + # If a ray exclusively hits the sky (no intersection with the voxels), set its weight to zero. + weights = weights * torch.logical_not(sky_only_mask).float() + total_weights_raw = torch.sum(weights, dim=-2, keepdim=True) # 256 256 1 1 + total_weights = total_weights_raw + + is_gnd = worldcoord2[..., [0]] <= 1.0 # Y X Z, [256, 256, 4, 3], nan < 1.0 == False + is_gnd = is_gnd.any(dim=-2, keepdim=True) + nosky_mask = torch.logical_or(torch.logical_not(sky_mask), is_gnd) + nosky_mask = nosky_mask.float() + + # Avoid sky leakage + sky_weight = 1.0-total_weights + if self.keep_sky_out: + # keep_sky_out_avgpool overrides sky_replace_color + if self.sky_replace_color is None or self.keep_sky_out_avgpool: + if self.keep_sky_out_avgpool: + if hasattr(self, 'sky_avg'): + sky_avg = self.sky_avg + else: + if self.sky_global_avgpool: + sky_avg = torch.mean(skynet_out_c, dim=[1, 2], keepdim=True) + else: + skynet_out_c_nchw = skynet_out_c.permute(0, 4, 1, 2, 3).squeeze(-1).contiguous() + sky_avg = F.avg_pool2d(skynet_out_c_nchw, 31, stride=1, padding=15, count_include_pad=False) + sky_avg = sky_avg.permute(0, 2, 3, 1).unsqueeze(-2).contiguous() + # print(sky_avg.shape) + skynet_out_c = skynet_out_c * (1.0-nosky_mask) + sky_avg*(nosky_mask) + else: + sky_weight = sky_weight * (1.0-nosky_mask) + else: + skynet_out_c = skynet_out_c * (1.0-nosky_mask) + self.sky_replace_color*(nosky_mask) + + if self.clip_feat_map is True: # intermediate feature before blending & CNN + rgbs = torch.clamp(net_out_c, -1, 1) + 1 + rgbs_sky = torch.clamp(skynet_out_c, -1, 1) + 1 + net_out = torch.sum(weights*rgbs, dim=-2, keepdim=True) + sky_weight * \ + rgbs_sky # 576, 768, 4, 3 -> 576, 768, 3 + net_out = net_out.squeeze(-2) + net_out = net_out - 1 + elif self.clip_feat_map is False: + rgbs = net_out_c + rgbs_sky = skynet_out_c + net_out = torch.sum(weights*rgbs, dim=-2, keepdim=True) + sky_weight * \ + rgbs_sky # 576, 768, 4, 3 -> 576, 768, 3 + net_out = net_out.squeeze(-2) + elif self.clip_feat_map == 'tanh': + rgbs = torch.tanh(net_out_c) + rgbs_sky = torch.tanh(skynet_out_c) + net_out = torch.sum(weights*rgbs, dim=-2, keepdim=True) + sky_weight * \ + rgbs_sky # 576, 768, 4, 3 -> 576, 768, 3 + net_out = net_out.squeeze(-2) + else: + raise NotImplementedError + + return net_out, new_dists, weights, total_weights_raw, rand_depth, net_out_s, net_out_c, skynet_out_c, \ + nosky_mask, sky_mask, sky_only_mask, new_idx + + def forward(self, data, random_style=False): + r"""SceneDreamer forward. + """ + device = torch.device('cuda') + batch_size = data['images'].size(0) + # Requires: voxel_id, depth2, raydirs, cam_ori_t. 
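The per-pixel renderer called below turns per-sample opacities into compositing weights via `mc_utils.volum_rendering_relu`, and `1 - sum(weights)` then acts as the sky weight. That helper lives elsewhere in the repo; as a reference, here is a minimal, self-contained sketch of the standard NeRF-style weight computation it is assumed to perform (the function name and the epsilon are illustrative, not taken from the source):

```python
import torch

def composite_weights(sigma, dists, eps=1e-10):
    """NeRF-style alpha compositing along the sample dimension (dim=-2).

    sigma, dists: (..., L, 1) per-sample opacities and interval lengths.
    Returns per-sample blending weights whose sum along L is at most 1.
    """
    alpha = 1.0 - torch.exp(-torch.relu(sigma) * dists)
    ones = torch.ones_like(alpha[..., :1, :])
    trans = torch.cumprod(torch.cat([ones, 1.0 - alpha + eps], dim=-2), dim=-2)[..., :-1, :]
    return alpha * trans  # sky weight is then 1 - weights.sum(dim=-2, keepdim=True)
```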
+ voxel_id, depth2, raydirs, cam_ori_t = data['voxel_id'], data['depth2'], data['raydirs'], data['cam_ori_t'] + if 'pseudo_real_img' in data: + pseudo_real_img = data['pseudo_real_img'] + + global_enc = self.world_encoder(self.voxel.current_height_map, self.voxel.current_semantic_map) + + z, mu, logvar = None, None, None + if random_style: + if self.style_dims > 0: + z = torch.randn(batch_size, self.style_dims, dtype=torch.float32, device=device) + else: + if self.style_encoder is None: + # ================ Get Style Code ================= + if self.style_dims > 0: + z = torch.randn(batch_size, self.style_dims, dtype=torch.float32, device=device) + else: + mu, logvar, z = self.style_encoder(pseudo_real_img) + + # ================ Network Forward ================ + # Forward StyleNet + if self.style_net is not None: + z = self.style_net(z) + + # Forward per-pixel net. + net_out, new_dists, weights, total_weights_raw, rand_depth, net_out_s, net_out_c, skynet_out_c, nosky_mask, \ + sky_mask, sky_only_mask, new_idx = self._forward_perpix( + self.blk_feats, voxel_id, depth2, raydirs, cam_ori_t, z, global_enc) + + # Forward global net. + fake_images, fake_images_raw = self._forward_global(net_out, z) + if self.pad != 0: + fake_images = fake_images[:, :, self.pad//2:-self.pad//2, self.pad//2:-self.pad//2] + + # =============== Arrange Return Values ================ + output = {} + output['fake_images'] = fake_images + output['mu'] = mu + output['logvar'] = logvar + return output + + + def inference_givenstyle(self, style, + output_dir, + camera_mode, + style_img_path=None, + seed=1, + pad=30, + num_samples=40, + num_blocks_early_stop=6, + sample_depth=3, + tile_size=128, + resolution_hw=[540, 960], + cam_ang=72, + cam_maxstep=10): + r"""Compute result images according to the provided camera trajectory and save the results in the specified + folder. The full image is evaluated in multiple tiles to save memory. + + Args: + output_dir (str): Where should the results be stored. + camera_mode (int): Which camera trajectory to use. + style_img_path (str): Path to the style-conditioning image. + seed (int): Random seed (controls style when style_image_path is not specified). + pad (int): Pixels to remove from the image tiles before stitching. Should be equal or larger than the + receptive field of the CNN to avoid border artifact. + num_samples (int): Number of samples per ray (different from training). + num_blocks_early_stop (int): Max number of intersected boxes per ray before stopping + (different from training). + sample_depth (float): Max distance traveled through boxes before stopping (different from training). + tile_size (int): Max size of a tile in pixels. + resolution_hw (list [H, W]): Resolution of the output image. + cam_ang (float): Horizontal FOV of the camera (may be adjusted by the camera controller). + cam_maxstep (int): Number of frames sampled from the camera trajectory. + """ + + def write_img(path, img, rgb_input=False): + img = ((img*0.5+0.5)*255).detach().cpu().numpy().astype(np.uint8) + img = img[0].transpose(1, 2, 0) + if rgb_input: + img = img[..., [2, 1, 0]] + cv2.imwrite(path, img, [cv2.IMWRITE_PNG_COMPRESSION, 4]) + return img[..., ::-1] + + def read_img(path): + img = cv2.imread(path).astype(np.float32)[..., [2, 1, 0]].transpose(2, 0, 1) / 255 + img = img * 2 - 1 + img = torch.from_numpy(img) + + print('Saving to', output_dir) + + # Use provided random seed. 
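Note that in the code shown here the `seed` argument is not referenced after this comment; the style code is expected to be sampled by the caller and passed in as `style`. A hedged sketch of such caller-side usage (the model handle, the style dimensionality of 128, and the output path are placeholders, not values from the repo):

```python
import torch

# Hypothetical caller-side setup (illustrative only).
seed = 1
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
rng = torch.Generator(device='cuda').manual_seed(seed)

# One style code per generated world; 128 stands in for the model's style_dims.
style = torch.empty(1, 128, dtype=torch.float32, device='cuda').normal_(generator=rng)
# model.inference_givenstyle(style, './output/', camera_mode=0)
```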
+ device = torch.device('cuda') + + global_enc = self.world_encoder(self.voxel.current_height_map, self.voxel.current_semantic_map) + + biome_colors = torch.Tensor([ + [255, 255, 178], + [184, 200, 98], + [188, 161, 53], + [190, 255, 242], + [106, 144, 38], + [33, 77, 41], + [86, 179, 106], + [34, 61, 53], + [35, 114, 94], + [0, 0, 255], + [0, 255, 0], + ]).to(device) / 255 * 2 - 1 + semantic_map = torch.argmax(self.voxel.current_semantic_map, dim=1) + + self.pad = pad + self.num_samples = num_samples + self.num_blocks_early_stop = num_blocks_early_stop + self.sample_depth = sample_depth + + self.coarse_deterministic_sampling = True + self.crop_size = resolution_hw + self.cam_res = [self.crop_size[0]+self.pad, self.crop_size[1]+self.pad] + self.use_label_smooth_pgt = False + + # Make output dirs. + output_dir = os.path.join(output_dir, 'rgb_render') + os.makedirs(output_dir, exist_ok=True) + fout = imageio.get_writer(output_dir + '.mp4', fps=10) + + write_img(os.path.join(output_dir, 'semantic_map.png'), biome_colors[semantic_map].permute(0, 3, 1, 2), rgb_input=True) + write_img(os.path.join(output_dir, 'height_map.png'), self.voxel.current_height_map) + np.save(os.path.join(output_dir, 'style.npy'), style.detach().cpu().numpy()) + evalcamctl = camctl.EvalCameraController( + self.voxel, maxstep=cam_maxstep, pattern=camera_mode, cam_ang=cam_ang, + smooth_decay_multiplier=150/cam_maxstep) + + # Get output style. + z = self.style_net(style) + + # Generate required output images. + for id, (cam_ori_t, cam_dir_t, cam_up_t, cam_f) in enumerate(evalcamctl): + print('Rendering frame', id) + cam_f = cam_f * (self.crop_size[1]-1) # So that the view is not depending on the padding + cam_c = [(self.cam_res[0]-1)/2, (self.cam_res[1]-1)/2] + + voxel_id, depth2, raydirs = voxlib.ray_voxel_intersection_perspective( + self.voxel.voxel_t, cam_ori_t, cam_dir_t, cam_up_t, cam_f, cam_c, self.cam_res, + self.num_blocks_early_stop) + + voxel_id = voxel_id.unsqueeze(0) + depth2 = depth2.unsqueeze(0) + raydirs = raydirs.unsqueeze(0) + cam_ori_t = cam_ori_t.unsqueeze(0).to(device) + + voxel_id_all = voxel_id + depth2_all = depth2 + raydirs_all = raydirs + + # Evaluate sky in advance to get a consistent sky in the semi-transparent region. + if self.sky_global_avgpool: + sky_raydirs_in = raydirs.expand(-1, -1, -1, 1, -1).contiguous() + sky_raydirs_in = voxlib.positional_encoding( + sky_raydirs_in, self.pe_params_sky[0], -1, self.pe_params_sky[1]) + skynet_out_c = self.sky_net(sky_raydirs_in, z) + sky_avg = torch.mean(skynet_out_c, dim=[1, 2], keepdim=True) + self.sky_avg = sky_avg + + num_strips_h = (self.cam_res[0]-self.pad+tile_size-1)//tile_size + num_strips_w = (self.cam_res[1]-self.pad+tile_size-1)//tile_size + + fake_images_chunks_v = [] + # For each horizontal strip. + for strip_id_h in range(num_strips_h): + strip_begin_h = strip_id_h * tile_size + strip_end_h = np.minimum(strip_id_h * tile_size + tile_size + self.pad, self.cam_res[0]) + # For each vertical strip. 
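The nested loops below evaluate each frame in overlapping tiles: every tile is rendered with `pad` extra pixels, and `pad // 2` is later trimmed from each side of the CNN output, so tile seams land in regions the network saw with full context. A small standalone sketch of the index arithmetic (with the defaults `resolution_hw=[540, 960]` and `pad=30`, the padded height is 570):

```python
import numpy as np

def tile_bounds(extent, tile_size, pad):
    """Start/end indices of the overlapping tiles covering a padded axis."""
    n = (extent - pad + tile_size - 1) // tile_size
    return [(i * tile_size, int(np.minimum(i * tile_size + tile_size + pad, extent)))
            for i in range(n)]

# tile_bounds(570, 128, 30)
# -> [(0, 158), (128, 286), (256, 414), (384, 542), (512, 570)]
```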
+ fake_images_chunks_h = [] + for strip_id_w in range(num_strips_w): + strip_begin_w = strip_id_w * tile_size + strip_end_w = np.minimum(strip_id_w * tile_size + tile_size + self.pad, self.cam_res[1]) + + voxel_id = voxel_id_all[:, strip_begin_h:strip_end_h, strip_begin_w:strip_end_w, :, :] + depth2 = depth2_all[:, :, strip_begin_h:strip_end_h, strip_begin_w:strip_end_w, :, :] + raydirs = raydirs_all[:, strip_begin_h:strip_end_h, strip_begin_w:strip_end_w, :, :] + + net_out, new_dists, weights, total_weights_raw, rand_depth, net_out_s, net_out_c, skynet_out_c, \ + nosky_mask, sky_mask, sky_only_mask, new_idx = self._forward_perpix( + self.blk_feats, voxel_id, depth2, raydirs, cam_ori_t, z, global_enc) + fake_images, _ = self._forward_global(net_out, z) + + if self.pad != 0: + fake_images = fake_images[:, :, self.pad//2:-self.pad//2, self.pad//2:-self.pad//2] + fake_images_chunks_h.append(fake_images) + fake_images_h = torch.cat(fake_images_chunks_h, dim=-1) + fake_images_chunks_v.append(fake_images_h) + fake_images = torch.cat(fake_images_chunks_v, dim=-2) + rgb = write_img(os.path.join(output_dir, + '{:05d}.png'.format(id)), fake_images, rgb_input=True) + fout.append_data(rgb) + fout.close() + + + + def inference_givenstyle_depth(self, style, + output_dir, + camera_mode, + style_img_path=None, + seed=1, + pad=30, + num_samples=40, + num_blocks_early_stop=6, + sample_depth=3, + tile_size=128, + resolution_hw=[540, 960], + cam_ang=72, + cam_maxstep=10): + r"""Compute result images according to the provided camera trajectory and save the results in the specified + folder. The full image is evaluated in multiple tiles to save memory. + + Args: + output_dir (str): Where should the results be stored. + camera_mode (int): Which camera trajectory to use. + style_img_path (str): Path to the style-conditioning image. + seed (int): Random seed (controls style when style_image_path is not specified). + pad (int): Pixels to remove from the image tiles before stitching. Should be equal or larger than the + receptive field of the CNN to avoid border artifact. + num_samples (int): Number of samples per ray (different from training). + num_blocks_early_stop (int): Max number of intersected boxes per ray before stopping + (different from training). + sample_depth (float): Max distance traveled through boxes before stopping (different from training). + tile_size (int): Max size of a tile in pixels. + resolution_hw (list [H, W]): Resolution of the output image. + cam_ang (float): Horizontal FOV of the camera (may be adjusted by the camera controller). + cam_maxstep (int): Number of frames sampled from the camera trajectory. + """ + + def write_img(path, img, rgb_input=False): + img = ((img*0.5+0.5)*255).detach().cpu().numpy().astype(np.uint8) + img = img[0].transpose(1, 2, 0) + if rgb_input: + img = img[..., [2, 1, 0]] + cv2.imwrite(path, img, [cv2.IMWRITE_PNG_COMPRESSION, 4]) + return img[..., ::-1] + + def read_img(path): + img = cv2.imread(path).astype(np.float32)[..., [2, 1, 0]].transpose(2, 0, 1) / 255 + img = img * 2 - 1 + img = torch.from_numpy(img) + + print('Saving to', output_dir) + + # Use provided random seed. 
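This `_depth` variant repeats the RGB pipeline of `inference_givenstyle` and additionally writes, per frame, a shaded voxel visualization and a depth image. The depth at each pixel is the expectation of the sample depths under the compositing weights, min-max normalized over the pixels that hit geometry before being saved as an 8-bit PNG. A simplified restatement of that reduction (assuming the sample dimension sits at `-2`, as in `_forward_perpix`):

```python
import torch

def depth_visualization(weights, sample_depths):
    """Expected ray depth under the compositing weights, scaled to [0, 1] for saving."""
    depth = torch.sum(weights * sample_depths, dim=-2)  # per-pixel expected depth
    vis = torch.ones_like(depth)                        # sky / empty pixels stay at 1
    hit = depth > 0
    if hit.any():
        d = depth[hit]
        vis[hit] = (d - d.min()) / (d.max() - d.min() + 1e-8)
    return vis
```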
+ device = torch.device('cuda') + + global_enc = self.world_encoder(self.voxel.current_height_map, self.voxel.current_semantic_map) + + biome_colors = torch.Tensor([ + [255, 255, 178], + [184, 200, 98], + [188, 161, 53], + [190, 255, 242], + [106, 144, 38], + [33, 77, 41], + [86, 179, 106], + [34, 61, 53], + [35, 114, 94], + [0, 0, 255], + [0, 255, 0], + ]) / 255 * 2 - 1 + print(self.voxel.current_height_map[0].shape) + semantic_map = torch.argmax(self.voxel.current_semantic_map, dim=1) + print(torch.unique(semantic_map, return_counts=True)) + print(semantic_map.min()) + + self.pad = pad + self.num_samples = num_samples + self.num_blocks_early_stop = num_blocks_early_stop + self.sample_depth = sample_depth + + self.coarse_deterministic_sampling = True + self.crop_size = resolution_hw + self.cam_res = [self.crop_size[0]+self.pad, self.crop_size[1]+self.pad] + self.use_label_smooth_pgt = False + + # Make output dirs. + gancraft_outputs_dir = os.path.join(output_dir, 'gancraft_outputs') + os.makedirs(gancraft_outputs_dir, exist_ok=True) + gancraft_depth_outputs_dir = os.path.join(output_dir, 'depth') + os.makedirs(gancraft_depth_outputs_dir, exist_ok=True) + vis_masks_dir = os.path.join(output_dir, 'vis_masks') + os.makedirs(vis_masks_dir, exist_ok=True) + fout = imageio.get_writer(gancraft_outputs_dir + '.mp4', fps=10) + fout_cat = imageio.get_writer(gancraft_outputs_dir + '-vis_masks.mp4', fps=10) + + write_img(os.path.join(output_dir, 'semantic_map.png'), biome_colors[semantic_map].permute(0, 3, 1, 2), rgb_input=True) + write_img(os.path.join(output_dir, 'heightmap.png'), self.voxel.current_height_map) + + evalcamctl = camctl.EvalCameraController( + self.voxel, maxstep=cam_maxstep, pattern=camera_mode, cam_ang=cam_ang, + smooth_decay_multiplier=150/cam_maxstep) + + # import pickle + # with open(os.path.join(output_dir,'camera.pkl'), 'wb') as f: + # pickle.dump(evalcamctl, f) + + # Get output style. + z = self.style_net(style) + + # Generate required output images. + for id, (cam_ori_t, cam_dir_t, cam_up_t, cam_f) in enumerate(evalcamctl): + # print('Rendering frame', id) + cam_f = cam_f * (self.crop_size[1]-1) # So that the view is not depending on the padding + cam_c = [(self.cam_res[0]-1)/2, (self.cam_res[1]-1)/2] + + voxel_id, depth2, raydirs = voxlib.ray_voxel_intersection_perspective( + self.voxel.voxel_t, cam_ori_t, cam_dir_t, cam_up_t, cam_f, cam_c, self.cam_res, + self.num_blocks_early_stop) + + voxel_id = voxel_id.unsqueeze(0) + depth2 = depth2.unsqueeze(0) + raydirs = raydirs.unsqueeze(0) + cam_ori_t = cam_ori_t.unsqueeze(0).to(device) + + # Save 3D voxel rendering. + mc_rgb = self.label_trans.mc_color(voxel_id[0, :, :, 0, 0].cpu().numpy()) + # Diffused shading, co-located light. + first_intersection_depth = depth2[:, 0, :, :, 0, None, :] # [1, 542, 542, 1, 1]. 
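The block below shades the raw voxel visualization with a light co-located with the camera: since voxel faces are axis-aligned, the face normal is simply the axis along which the first hit point is closest to a voxel boundary, and the Lambert term reduces to `|n · d|`. The colour is decoded from sRGB, scaled in linear light, then re-encoded. A small helper sketching that gamma-aware pattern (note that `np.float`, used below, has been removed in recent NumPy; `np.float32` is the safe spelling):

```python
import numpy as np

def shade_srgb(rgb_u8, lambert, gamma=2.2):
    """Scale an sRGB uint8 image by a per-pixel Lambert term in linear light."""
    linear = (rgb_u8.astype(np.float32) / 255.0) ** gamma   # decode
    shaded = linear * lambert[..., None]                     # lambert: (H, W) in [0, 1]
    encoded = np.clip(shaded ** (1.0 / gamma) * 255.0, 0, 255)
    return encoded.astype(np.uint8)
```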
+ first_intersection_point = raydirs * first_intersection_depth + cam_ori_t[:, None, None, None, :] + fip_local_coords = torch.remainder(first_intersection_point, 1.0) + fip_wall_proximity = torch.minimum(fip_local_coords, 1.0-fip_local_coords) + fip_wall_orientation = torch.argmin(fip_wall_proximity, dim=-1, keepdim=False) + # 0: [1,0,0]; 1: [0,1,0]; 2: [0,0,1] + lut = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=torch.float32, + device=fip_wall_orientation.device) + fip_normal = lut[fip_wall_orientation] # [1, 542, 542, 1, 3] + diffuse_shade = torch.abs(torch.sum(fip_normal * raydirs, dim=-1)) + + mc_rgb = (mc_rgb.astype(np.float) / 255) ** 2.2 + mc_rgb = mc_rgb * diffuse_shade[0, :, :, :].cpu().numpy() + mc_rgb = (mc_rgb ** (1/2.2)) * 255 + mc_rgb = mc_rgb.astype(np.uint8) + if self.pad > 0: + mc_rgb = mc_rgb[self.pad//2:-self.pad//2, self.pad//2:-self.pad//2] + cv2.imwrite(os.path.join(vis_masks_dir, '{:05d}.png'.format(id)), mc_rgb, [cv2.IMWRITE_PNG_COMPRESSION, 4]) + + # Tiled eval of GANcraft. + voxel_id_all = voxel_id + depth2_all = depth2 + raydirs_all = raydirs + + # Evaluate sky in advance to get a consistent sky in the semi-transparent region. + if self.sky_global_avgpool: + sky_raydirs_in = raydirs.expand(-1, -1, -1, 1, -1).contiguous() + sky_raydirs_in = voxlib.positional_encoding( + sky_raydirs_in, self.pe_params_sky[0], -1, self.pe_params_sky[1]) + skynet_out_c = self.sky_net(sky_raydirs_in, z) + sky_avg = torch.mean(skynet_out_c, dim=[1, 2], keepdim=True) + self.sky_avg = sky_avg + + num_strips_h = (self.cam_res[0]-self.pad+tile_size-1)//tile_size + num_strips_w = (self.cam_res[1]-self.pad+tile_size-1)//tile_size + + fake_images_chunks_v = [] + fake_depth_chunks_v = [] + # For each horizontal strip. + for strip_id_h in range(num_strips_h): + strip_begin_h = strip_id_h * tile_size + strip_end_h = np.minimum(strip_id_h * tile_size + tile_size + self.pad, self.cam_res[0]) + # For each vertical strip. + fake_images_chunks_h = [] + fake_depth_chunks_h = [] + for strip_id_w in range(num_strips_w): + strip_begin_w = strip_id_w * tile_size + strip_end_w = np.minimum(strip_id_w * tile_size + tile_size + self.pad, self.cam_res[1]) + + voxel_id = voxel_id_all[:, strip_begin_h:strip_end_h, strip_begin_w:strip_end_w, :, :] + depth2 = depth2_all[:, :, strip_begin_h:strip_end_h, strip_begin_w:strip_end_w, :, :] + raydirs = raydirs_all[:, strip_begin_h:strip_end_h, strip_begin_w:strip_end_w, :, :] + + net_out, new_dists, weights, total_weights_raw, rand_depth, net_out_s, net_out_c, skynet_out_c, \ + nosky_mask, sky_mask, sky_only_mask, new_idx = self._forward_perpix( + self.blk_feats, voxel_id, depth2, raydirs, cam_ori_t, z, global_enc) + fake_images, _ = self._forward_global(net_out, z) + depth_map = torch.sum(weights * rand_depth, -2) + # disp_map = 1. / torch.max(1e-10 * torch.ones_like(depth_map).to(depth_map), depth_map / torch.sum(weights, -2)) + # depth_map = torch.clip(depth_map, 0, 100.) + # disp_map = 1. 
/ (depth_map.permute(0, 3, 1, 2)) + disp_map = depth_map.permute(0, 3, 1, 2) + if self.pad != 0: + fake_images = fake_images[:, :, self.pad//2:-self.pad//2, self.pad//2:-self.pad//2] + disp_map = disp_map[:, :, self.pad//2:-self.pad//2, self.pad//2:-self.pad//2] + fake_images_chunks_h.append(fake_images) + fake_depth_chunks_h.append(disp_map) + fake_images_h = torch.cat(fake_images_chunks_h, dim=-1) + fake_depth_h = torch.cat(fake_depth_chunks_h, dim=-1) + fake_images_chunks_v.append(fake_images_h) + fake_depth_chunks_v.append(fake_depth_h) + fake_images = torch.cat(fake_images_chunks_v, dim=-2) + fake_depth = torch.cat(fake_depth_chunks_v, dim=-2) + # fake_depth = ((fake_depth - fake_depth.mean()) / fake_depth.std() + 1) / 2 + # fake_depth = torch.clip(1./ (fake_depth + 1e-4), 0., 1.) + # fake_depth = ((fake_depth - fake_depth.mean()) / fake_depth.std() + 1) / 2 + mmask = fake_depth > 0 + tmp = fake_depth[mmask] + # tmp = 1. / (tmp + 1e-4) + tmp = (tmp - tmp.min()) / (tmp.max() - tmp.min()) + # tmp = ((tmp - tmp.mean()) / tmp.std() + 1) / 2. + fake_depth[~mmask] = 1 + fake_depth[mmask] = tmp + # fake_depth = (fake_depth - fake_depth.min()) / (fake_depth.max() - fake_depth.min()) + + cv2.imwrite(os.path.join(gancraft_depth_outputs_dir, '{:05d}.png'.format(id)), fake_depth[0].permute(1, 2, 0).detach().cpu().numpy() * 255) + rgb = write_img(os.path.join(gancraft_outputs_dir, + '{:05d}.png'.format(id)), fake_images, rgb_input=True) + fout.append_data(rgb) + fout_cat.append_data(np.concatenate((mc_rgb[..., ::-1], rgb), axis=1)) + fout.close() + fout_cat.close() + diff --git a/imaginaire/generators/spade.py b/imaginaire/generators/spade.py new file mode 100644 index 0000000000000000000000000000000000000000..dc69630304ccb2ce3fab707ca2e7de5f7aeec55a --- /dev/null +++ b/imaginaire/generators/spade.py @@ -0,0 +1,571 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import functools +import math +import types + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Upsample as NearestUpsample + +from imaginaire.layers import Conv2dBlock, LinearBlock, Res2dBlock +from imaginaire.utils.data import (get_crop_h_w, + get_paired_input_image_channel_number, + get_paired_input_label_channel_number) +from imaginaire.utils.distributed import master_only_print as print + + +class Generator(nn.Module): + r"""SPADE generator constructor. + + Args: + gen_cfg (obj): Generator definition part of the yaml config file. + data_cfg (obj): Data definition part of the yaml config file. + """ + + def __init__(self, gen_cfg, data_cfg): + super(Generator, self).__init__() + print('SPADE generator initialization.') + # We assume the first datum is the ground truth image. + image_channels = getattr(gen_cfg, 'image_channels', None) + if image_channels is None: + image_channels = get_paired_input_image_channel_number(data_cfg) + num_labels = getattr(gen_cfg, 'num_labels', None) + if num_labels is None: + # Calculate number of channels in the input label when not specified. 
+ num_labels = get_paired_input_label_channel_number(data_cfg) + crop_h, crop_w = get_crop_h_w(data_cfg.train.augmentations) + # Build the generator + out_image_small_side_size = crop_w if crop_w < crop_h else crop_h + num_filters = getattr(gen_cfg, 'num_filters', 128) + kernel_size = getattr(gen_cfg, 'kernel_size', 3) + weight_norm_type = getattr(gen_cfg, 'weight_norm_type', 'spectral') + + cond_dims = 0 + # Check whether we use the style code. + style_dims = getattr(gen_cfg, 'style_dims', None) + self.style_dims = style_dims + if style_dims is not None: + print('\tStyle code dimensions: %d' % style_dims) + cond_dims += style_dims + self.use_style = True + else: + self.use_style = False + # Check whether we use the attribute code. + if hasattr(gen_cfg, 'attribute_dims'): + self.use_attribute = True + self.attribute_dims = gen_cfg.attribute_dims + cond_dims += gen_cfg.attribute_dims + else: + self.use_attribute = False + + if not self.use_style and not self.use_attribute: + self.use_style_encoder = False + else: + self.use_style_encoder = True + print('\tBase filter number: %d' % num_filters) + print('\tConvolution kernel size: %d' % kernel_size) + print('\tWeight norm type: %s' % weight_norm_type) + skip_activation_norm = \ + getattr(gen_cfg, 'skip_activation_norm', True) + activation_norm_params = getattr(gen_cfg, 'activation_norm_params', None) + if activation_norm_params is None: + activation_norm_params = types.SimpleNamespace() + if not hasattr(activation_norm_params, 'num_filters'): + setattr(activation_norm_params, 'num_filters', 128) + if not hasattr(activation_norm_params, 'kernel_size'): + setattr(activation_norm_params, 'kernel_size', 3) + if not hasattr(activation_norm_params, 'activation_norm_type'): + setattr(activation_norm_params, 'activation_norm_type', 'sync_batch') + if not hasattr(activation_norm_params, 'separate_projection'): + setattr(activation_norm_params, 'separate_projection', False) + if not hasattr(activation_norm_params, 'activation_norm_params'): + activation_norm_params.activation_norm_params = types.SimpleNamespace() + activation_norm_params.activation_norm_params.affine = True + setattr(activation_norm_params, 'cond_dims', num_labels) + if not hasattr(activation_norm_params, 'weight_norm_type'): + setattr(activation_norm_params, 'weight_norm_type', weight_norm_type) + global_adaptive_norm_type = getattr(gen_cfg, 'global_adaptive_norm_type', 'sync_batch') + use_posenc_in_input_layer = getattr(gen_cfg, 'use_posenc_in_input_layer', True) + output_multiplier = getattr(gen_cfg, 'output_multiplier', 1.0) + print(activation_norm_params) + self.spade_generator = SPADEGenerator(num_labels, + out_image_small_side_size, + image_channels, + num_filters, + kernel_size, + cond_dims, + activation_norm_params, + weight_norm_type, + global_adaptive_norm_type, + skip_activation_norm, + use_posenc_in_input_layer, + self.use_style_encoder, + output_multiplier) + if self.use_style: + # Build the encoder. 
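The encoder configured below is a VAE-style style encoder: it maps a real image to `(mu, logvar, z)`, with `z` drawn via the reparameterization trick so the sampling stays differentiable (this is what `StyleEncoder.forward` further down implements). A minimal sketch of that step:

```python
import torch

def reparameterize(mu, logvar):
    """Sample z = mu + sigma * eps with sigma = exp(0.5 * logvar), differentiably."""
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std
```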
+ style_enc_cfg = getattr(gen_cfg, 'style_enc', None) + if style_enc_cfg is None: + style_enc_cfg = types.SimpleNamespace() + if not hasattr(style_enc_cfg, 'num_filters'): + setattr(style_enc_cfg, 'num_filters', 128) + if not hasattr(style_enc_cfg, 'kernel_size'): + setattr(style_enc_cfg, 'kernel_size', 3) + if not hasattr(style_enc_cfg, 'weight_norm_type'): + setattr(style_enc_cfg, 'weight_norm_type', weight_norm_type) + setattr(style_enc_cfg, 'input_image_channels', image_channels) + setattr(style_enc_cfg, 'style_dims', style_dims) + self.style_encoder = StyleEncoder(style_enc_cfg) + + self.z = None + print('Done with the SPADE generator initialization.') + + def forward(self, data, random_style=False): + r"""SPADE Generator forward. + + Args: + data (dict): + - images (N x C1 x H x W tensor) : Ground truth images + - label (N x C2 x H x W tensor) : Semantic representations + - z (N x style_dims tensor): Gaussian random noise + - random_style (bool): Whether to sample a random style vector. + Returns: + (dict): + - fake_images (N x 3 x H x W tensor): fake images + - mu (N x C1 tensor): mean vectors + - logvar (N x C1 tensor): log-variance vectors + """ + if self.use_style_encoder: + if random_style: + bs = data['label'].size(0) + z = torch.randn( + bs, self.style_dims, dtype=torch.float32).cuda() + if (data['label'].dtype == + data['label'].dtype == torch.float16): + z = z.half() + mu = None + logvar = None + else: + mu, logvar, z = self.style_encoder(data['images']) + if self.use_attribute: + data['z'] = torch.cat((z, data['attributes'].squeeze(1)), dim=1) + else: + data['z'] = z + output = self.spade_generator(data) + if self.use_style_encoder: + output['mu'] = mu + output['logvar'] = logvar + return output + + def inference(self, + data, + random_style=False, + use_fixed_random_style=False, + keep_original_size=False): + r"""Compute results images for a batch of input data and save the + results in the specified folder. + + Args: + data (dict): + - images (N x C1 x H x W tensor) : Ground truth images + - label (N x C2 x H x W tensor) : Semantic representations + - z (N x style_dims tensor): Gaussian random noise + random_style (bool): Whether to sample a random style vector. + use_fixed_random_style (bool): Sample random style once and use it + for all the remaining inference. + keep_original_size (bool): Keep original size of the input. 
+ Returns: + (dict): + - fake_images (N x 3 x H x W tensor): fake images + - mu (N x C1 tensor): mean vectors + - logvar (N x C1 tensor): log-variance vectors + """ + self.eval() + self.spade_generator.eval() + + if self.use_style_encoder: + if random_style and self.use_style_encoder: + if self.z is None or not use_fixed_random_style: + bs = data['label'].size(0) + z = torch.randn( + bs, self.style_dims, dtype=torch.float32).to('cuda') + if (data['label'].dtype == + data['label'].dtype == + torch.float16): + z = z.half() + self.z = z + else: + z = self.z + else: + mu, logvar, z = self.style_encoder(data['images']) + data['z'] = z + + output = self.spade_generator(data) + output_images = output['fake_images'] + + if keep_original_size: + height = data['original_h_w'][0][0] + width = data['original_h_w'][0][1] + output_images = torch.nn.functional.interpolate( + output_images, size=[height, width]) + + for key in data['key'].keys(): + if 'segmaps' in key or 'seg_maps' in key: + file_names = data['key'][key][0] + break + for key in data['key'].keys(): + if 'edgemaps' in key or 'edge_maps' in key: + file_names = data['key'][key][0] + break + + return output_images, file_names + + +class SPADEGenerator(nn.Module): + r"""SPADE Image Generator constructor. + + Args: + num_labels (int): Number of different labels. + out_image_small_side_size (int): min(width, height) + image_channels (int): Num. of channels of the output image. + num_filters (int): Base filter numbers. + kernel_size (int): Convolution kernel size. + style_dims (int): Dimensions of the style code. + activation_norm_params (obj): Spatially adaptive normalization param. + weight_norm_type (str): Type of weight normalization. + ``'none'``, ``'spectral'``, or ``'weight'``. + global_adaptive_norm_type (str): Type of normalization in SPADE. + skip_activation_norm (bool): If ``True``, applies activation norm to the + shortcut connection in residual blocks. + use_style_encoder (bool): Whether to use global adaptive norm + like conditional batch norm or adaptive instance norm. 
+ output_multiplier (float): A positive number multiplied to the output + """ + + def __init__(self, + num_labels, + out_image_small_side_size, + image_channels, + num_filters, + kernel_size, + style_dims, + activation_norm_params, + weight_norm_type, + global_adaptive_norm_type, + skip_activation_norm, + use_posenc_in_input_layer, + use_style_encoder, + output_multiplier): + super(SPADEGenerator, self).__init__() + self.output_multiplier = output_multiplier + self.use_style_encoder = use_style_encoder + self.use_posenc_in_input_layer = use_posenc_in_input_layer + self.out_image_small_side_size = out_image_small_side_size + self.num_filters = num_filters + padding = int(np.ceil((kernel_size - 1.0) / 2)) + nonlinearity = 'leakyrelu' + activation_norm_type = 'spatially_adaptive' + base_res2d_block = \ + functools.partial(Res2dBlock, + kernel_size=kernel_size, + padding=padding, + bias=[True, True, False], + weight_norm_type=weight_norm_type, + activation_norm_type=activation_norm_type, + activation_norm_params=activation_norm_params, + skip_activation_norm=skip_activation_norm, + nonlinearity=nonlinearity, + order='NACNAC') + if self.use_style_encoder: + self.fc_0 = LinearBlock(style_dims, 2 * style_dims, + weight_norm_type=weight_norm_type, + nonlinearity='relu', + order='CAN') + self.fc_1 = LinearBlock(2 * style_dims, 2 * style_dims, + weight_norm_type=weight_norm_type, + nonlinearity='relu', + order='CAN') + + adaptive_norm_params = types.SimpleNamespace() + if not hasattr(adaptive_norm_params, 'cond_dims'): + setattr(adaptive_norm_params, 'cond_dims', 2 * style_dims) + if not hasattr(adaptive_norm_params, 'activation_norm_type'): + setattr(adaptive_norm_params, 'activation_norm_type', global_adaptive_norm_type) + if not hasattr(adaptive_norm_params, 'weight_norm_type'): + setattr(adaptive_norm_params, 'weight_norm_type', activation_norm_params.weight_norm_type) + if not hasattr(adaptive_norm_params, 'separate_projection'): + setattr(adaptive_norm_params, 'separate_projection', activation_norm_params.separate_projection) + adaptive_norm_params.activation_norm_params = types.SimpleNamespace() + setattr(adaptive_norm_params.activation_norm_params, 'affine', + activation_norm_params.activation_norm_params.affine) + base_cbn2d_block = \ + functools.partial(Conv2dBlock, + kernel_size=kernel_size, + stride=1, + padding=padding, + bias=True, + weight_norm_type=weight_norm_type, + activation_norm_type='adaptive', + activation_norm_params=adaptive_norm_params, + nonlinearity=nonlinearity, + order='NAC') + else: + base_conv2d_block = \ + functools.partial(Conv2dBlock, + kernel_size=kernel_size, + stride=1, + padding=padding, + bias=True, + weight_norm_type=weight_norm_type, + nonlinearity=nonlinearity, + order='NAC') + in_num_labels = num_labels + in_num_labels += 2 if self.use_posenc_in_input_layer else 0 + self.head_0 = Conv2dBlock(in_num_labels, 8 * num_filters, + kernel_size=kernel_size, stride=1, + padding=padding, + weight_norm_type=weight_norm_type, + activation_norm_type='none', + nonlinearity=nonlinearity) + if self.use_style_encoder: + self.cbn_head_0 = base_cbn2d_block( + 8 * num_filters, 16 * num_filters) + else: + self.conv_head_0 = base_conv2d_block( + 8 * num_filters, 16 * num_filters) + self.head_1 = base_res2d_block(16 * num_filters, 16 * num_filters) + self.head_2 = base_res2d_block(16 * num_filters, 16 * num_filters) + + self.up_0a = base_res2d_block(16 * num_filters, 8 * num_filters) + if self.use_style_encoder: + self.cbn_up_0a = base_cbn2d_block( + 8 * num_filters, 8 * 
num_filters) + else: + self.conv_up_0a = base_conv2d_block( + 8 * num_filters, 8 * num_filters) + self.up_0b = base_res2d_block(8 * num_filters, 8 * num_filters) + + self.up_1a = base_res2d_block(8 * num_filters, 4 * num_filters) + if self.use_style_encoder: + self.cbn_up_1a = base_cbn2d_block( + 4 * num_filters, 4 * num_filters) + else: + self.conv_up_1a = base_conv2d_block( + 4 * num_filters, 4 * num_filters) + self.up_1b = base_res2d_block(4 * num_filters, 4 * num_filters) + self.up_2a = base_res2d_block(4 * num_filters, 4 * num_filters) + if self.use_style_encoder: + self.cbn_up_2a = base_cbn2d_block( + 4 * num_filters, 4 * num_filters) + else: + self.conv_up_2a = base_conv2d_block( + 4 * num_filters, 4 * num_filters) + self.up_2b = base_res2d_block(4 * num_filters, 2 * num_filters) + self.conv_img256 = Conv2dBlock(2 * num_filters, image_channels, + 5, stride=1, padding=2, + weight_norm_type=weight_norm_type, + activation_norm_type='none', + nonlinearity=nonlinearity, + order='ANC') + self.base = 16 + if self.out_image_small_side_size == 512: + self.up_3a = base_res2d_block(2 * num_filters, 1 * num_filters) + self.up_3b = base_res2d_block(1 * num_filters, 1 * num_filters) + self.conv_img512 = Conv2dBlock(1 * num_filters, image_channels, + 5, stride=1, padding=2, + weight_norm_type=weight_norm_type, + activation_norm_type='none', + nonlinearity=nonlinearity, + order='ANC') + self.base = 32 + if self.out_image_small_side_size == 1024: + self.up_3a = base_res2d_block(2 * num_filters, 1 * num_filters) + self.up_3b = base_res2d_block(1 * num_filters, 1 * num_filters) + self.conv_img512 = Conv2dBlock(1 * num_filters, image_channels, + 5, stride=1, padding=2, + weight_norm_type=weight_norm_type, + activation_norm_type='none', + nonlinearity=nonlinearity, + order='ANC') + self.up_4a = base_res2d_block(num_filters, num_filters // 2) + self.up_4b = base_res2d_block(num_filters // 2, num_filters // 2) + self.conv_img1024 = Conv2dBlock(num_filters // 2, image_channels, + 5, stride=1, padding=2, + weight_norm_type=weight_norm_type, + activation_norm_type='none', + nonlinearity=nonlinearity, + order='ANC') + self.nearest_upsample4x = NearestUpsample(scale_factor=4, mode='nearest') + self.base = 64 + if self.out_image_small_side_size != 256 and self.out_image_small_side_size != 512 \ + and self.out_image_small_side_size != 1024: + raise ValueError('Generation image size (%d, %d) not supported' % + (self.out_image_small_side_size, + self.out_image_small_side_size)) + self.nearest_upsample2x = NearestUpsample(scale_factor=2, mode='nearest') + + xv, yv = torch.meshgrid( + [torch.arange(-1, 1.1, 2. / 15), torch.arange(-1, 1.1, 2. / 15)]) + self.xy = torch.cat((xv.unsqueeze(0), yv.unsqueeze(0)), 0).unsqueeze(0) + self.xy = self.xy.cuda() + + def forward(self, data): + r"""SPADE Generator forward. + + Args: + data (dict): + - data (N x C1 x H x W tensor) : Ground truth images. + - label (N x C2 x H x W tensor) : Semantic representations. + - z (N x style_dims tensor): Gaussian random noise. + Returns: + output (dict): + - fake_images (N x 3 x H x W tensor): Fake images. 
+ """ + seg = data['label'] + + if self.use_style_encoder: + z = data['z'] + z = self.fc_0(z) + z = self.fc_1(z) + + # The code piece below makes sure that the input size is always 16x16 + sy = math.floor(seg.size()[2] * 1.0 / self.base) + sx = math.floor(seg.size()[3] * 1.0 / self.base) + + in_seg = F.interpolate(seg, size=[sy, sx], mode='nearest') + if self.use_posenc_in_input_layer: + in_xy = F.interpolate(self.xy, size=[sy, sx], mode='bicubic') + in_seg_xy = torch.cat( + (in_seg, in_xy.expand(in_seg.size()[0], 2, sy, sx)), 1) + else: + in_seg_xy = in_seg + # 16x16 + x = self.head_0(in_seg_xy) + if self.use_style_encoder: + x = self.cbn_head_0(x, z) + else: + x = self.conv_head_0(x) + x = self.head_1(x, seg) + x = self.head_2(x, seg) + x = self.nearest_upsample2x(x) + # 32x32 + x = self.up_0a(x, seg) + if self.use_style_encoder: + x = self.cbn_up_0a(x, z) + else: + x = self.conv_up_0a(x) + x = self.up_0b(x, seg) + x = self.nearest_upsample2x(x) + # 64x64 + x = self.up_1a(x, seg) + if self.use_style_encoder: + x = self.cbn_up_1a(x, z) + else: + x = self.conv_up_1a(x) + x = self.up_1b(x, seg) + x = self.nearest_upsample2x(x) + # 128x128 + x = self.up_2a(x, seg) + if self.use_style_encoder: + x = self.cbn_up_2a(x, z) + else: + x = self.conv_up_2a(x) + x = self.up_2b(x, seg) + x = self.nearest_upsample2x(x) + # 256x256 + if self.out_image_small_side_size == 256: + x256 = self.conv_img256(x) + x = torch.tanh(self.output_multiplier * x256) + # 512x512 + elif self.out_image_small_side_size == 512: + x256 = self.conv_img256(x) + x256 = self.nearest_upsample2x(x256) + x = self.up_3a(x, seg) + x = self.up_3b(x, seg) + x = self.nearest_upsample2x(x) + x512 = self.conv_img512(x) + x = torch.tanh(self.output_multiplier * (x256 + x512)) + # 1024x1024 + elif self.out_image_small_side_size == 1024: + x256 = self.conv_img256(x) + x256 = self.nearest_upsample4x(x256) + x = self.up_3a(x, seg) + x = self.up_3b(x, seg) + x = self.nearest_upsample2x(x) + x512 = self.conv_img512(x) + x512 = self.nearest_upsample2x(x512) + x = self.up_4a(x, seg) + x = self.up_4b(x, seg) + x = self.nearest_upsample2x(x) + x1024 = self.conv_img1024(x) + x = torch.tanh(self.output_multiplier * (x256 + x512 + x1024)) + output = dict() + output['fake_images'] = x + return output + + +class StyleEncoder(nn.Module): + r"""Style Encode constructor. + + Args: + style_enc_cfg (obj): Style encoder definition file. 
+ """ + + def __init__(self, style_enc_cfg): + super(StyleEncoder, self).__init__() + input_image_channels = style_enc_cfg.input_image_channels + num_filters = style_enc_cfg.num_filters + kernel_size = style_enc_cfg.kernel_size + padding = int(np.ceil((kernel_size - 1.0) / 2)) + style_dims = style_enc_cfg.style_dims + weight_norm_type = style_enc_cfg.weight_norm_type + activation_norm_type = 'none' + nonlinearity = 'leakyrelu' + base_conv2d_block = \ + functools.partial(Conv2dBlock, + kernel_size=kernel_size, + stride=2, + padding=padding, + weight_norm_type=weight_norm_type, + activation_norm_type=activation_norm_type, + # inplace_nonlinearity=True, + nonlinearity=nonlinearity) + self.layer1 = base_conv2d_block(input_image_channels, num_filters) + self.layer2 = base_conv2d_block(num_filters * 1, num_filters * 2) + self.layer3 = base_conv2d_block(num_filters * 2, num_filters * 4) + self.layer4 = base_conv2d_block(num_filters * 4, num_filters * 8) + self.layer5 = base_conv2d_block(num_filters * 8, num_filters * 8) + self.layer6 = base_conv2d_block(num_filters * 8, num_filters * 8) + self.fc_mu = LinearBlock(num_filters * 8 * 4 * 4, style_dims) + self.fc_var = LinearBlock(num_filters * 8 * 4 * 4, style_dims) + + def forward(self, input_x): + r"""SPADE Style Encoder forward. + + Args: + input_x (N x 3 x H x W tensor): input images. + Returns: + (tuple): + - mu (N x C tensor): Mean vectors. + - logvar (N x C tensor): Log-variance vectors. + - z (N x C tensor): Style code vectors. + """ + if input_x.size(2) != 256 or input_x.size(3) != 256: + input_x = F.interpolate(input_x, size=(256, 256), mode='bilinear') + x = self.layer1(input_x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.layer5(x) + x = self.layer6(x) + x = x.view(x.size(0), -1) + mu = self.fc_mu(x) + logvar = self.fc_var(x) + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + z = eps.mul(std) + mu + return mu, logvar, z diff --git a/imaginaire/layers/__init__.py b/imaginaire/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9e3f93c154678b630de93ab3d6b199204c4fd8fb --- /dev/null +++ b/imaginaire/layers/__init__.py @@ -0,0 +1,27 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. 
+# To view a copy of this license, check out LICENSE.md +from .conv import LinearBlock, Conv1dBlock, Conv2dBlock, Conv3dBlock, \ + HyperConv2dBlock, MultiOutConv2dBlock, \ + PartialConv2dBlock, PartialConv3dBlock +from .residual import ResLinearBlock, Res1dBlock, Res2dBlock, Res3dBlock, \ + HyperRes2dBlock, MultiOutRes2dBlock, UpRes2dBlock, DownRes2dBlock, \ + PartialRes2dBlock, PartialRes3dBlock +from .non_local import NonLocal2dBlock + +__all__ = ['Conv1dBlock', 'Conv2dBlock', 'Conv3dBlock', 'LinearBlock', + 'HyperConv2dBlock', 'MultiOutConv2dBlock', + 'PartialConv2dBlock', 'PartialConv3dBlock', + 'Res1dBlock', 'Res2dBlock', 'Res3dBlock', + 'UpRes2dBlock', 'DownRes2dBlock', + 'ResLinearBlock', 'HyperRes2dBlock', 'MultiOutRes2dBlock', + 'PartialRes2dBlock', 'PartialRes3dBlock', + 'NonLocal2dBlock'] + +try: + from .repvgg import RepVGG1dBlock, RepVGG2dBlock, RepVGG3dBlock + from .attn import MultiheadAttention + __all__.extend(['RepVGG1dBlock', 'RepVGG2dBlock', 'RepVGG3dBlock']) +except: # noqa + pass diff --git a/imaginaire/layers/activation_norm.py b/imaginaire/layers/activation_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..ebf44616e947f61678ab5bdc4daa4508fb3a857a --- /dev/null +++ b/imaginaire/layers/activation_norm.py @@ -0,0 +1,629 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +# flake8: noqa E722 +from types import SimpleNamespace + +import torch + +try: + from torch.nn import SyncBatchNorm +except ImportError: + from torch.nn import BatchNorm2d as SyncBatchNorm +from torch import nn +from torch.nn import functional as F +from .conv import LinearBlock, Conv2dBlock, HyperConv2d, PartialConv2dBlock +from .misc import PartialSequential, ApplyNoise + + +class AdaptiveNorm(nn.Module): + r"""Adaptive normalization layer. The layer first normalizes the input, then + performs an affine transformation using parameters computed from the + conditional inputs. + + Args: + num_features (int): Number of channels in the input tensor. + cond_dims (int): Number of channels in the conditional inputs. + weight_norm_type (str): Type of weight normalization. + ``'none'``, ``'spectral'``, ``'weight'``, or ``'weight_demod'``. + projection (bool): If ``True``, project the conditional input to gamma + and beta using a fully connected layer, otherwise directly use + the conditional input as gamma and beta. + projection_bias (bool) If ``True``, use bias in the fully connected + projection layer. + separate_projection (bool): If ``True``, we will use two different + layers for gamma and beta. Otherwise, we will use one layer. It + matters only if you apply any weight norms to this layer. + input_dim (int): Number of dimensions of the input tensor. + activation_norm_type (str): + Type of activation normalization. + ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, + ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, + ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. + activation_norm_params (obj, optional, default=None): + Parameters of activation normalization. + If not ``None``, ``activation_norm_params.__dict__`` will be used as + keyword arguments when initializing activation normalization. 
+ """ + + def __init__(self, num_features, cond_dims, weight_norm_type='', + projection=True, + projection_bias=True, + separate_projection=False, + input_dim=2, + activation_norm_type='instance', + activation_norm_params=None, + apply_noise=False, + add_bias=True, + input_scale=1.0, + init_gain=1.0): + super().__init__() + if activation_norm_params is None: + activation_norm_params = SimpleNamespace(affine=False) + self.norm = get_activation_norm_layer(num_features, + activation_norm_type, + input_dim, + **vars(activation_norm_params)) + if apply_noise: + self.noise_layer = ApplyNoise() + else: + self.noise_layer = None + + if projection: + if separate_projection: + self.fc_gamma = \ + LinearBlock(cond_dims, num_features, + weight_norm_type=weight_norm_type, + bias=projection_bias) + self.fc_beta = \ + LinearBlock(cond_dims, num_features, + weight_norm_type=weight_norm_type, + bias=projection_bias) + else: + self.fc = LinearBlock(cond_dims, num_features * 2, + weight_norm_type=weight_norm_type, + bias=projection_bias) + + self.projection = projection + self.separate_projection = separate_projection + self.input_scale = input_scale + self.add_bias = add_bias + self.conditional = True + self.init_gain = init_gain + + def forward(self, x, y, noise=None, **_kwargs): + r"""Adaptive Normalization forward. + + Args: + x (N x C1 x * tensor): Input tensor. + y (N x C2 tensor): Conditional information. + Returns: + out (N x C1 x * tensor): Output tensor. + """ + y = y * self.input_scale + if self.projection: + if self.separate_projection: + gamma = self.fc_gamma(y) + beta = self.fc_beta(y) + for _ in range(x.dim() - gamma.dim()): + gamma = gamma.unsqueeze(-1) + beta = beta.unsqueeze(-1) + else: + y = self.fc(y) + for _ in range(x.dim() - y.dim()): + y = y.unsqueeze(-1) + gamma, beta = y.chunk(2, 1) + else: + for _ in range(x.dim() - y.dim()): + y = y.unsqueeze(-1) + gamma, beta = y.chunk(2, 1) + if self.norm is not None: + x = self.norm(x) + if self.noise_layer is not None: + x = self.noise_layer(x, noise=noise) + if self.add_bias: + x = torch.addcmul(beta, x, 1 + gamma) + return x + else: + return x * (1 + gamma), beta.squeeze(3).squeeze(2) + + +class SpatiallyAdaptiveNorm(nn.Module): + r"""Spatially Adaptive Normalization (SPADE) initialization. + + Args: + num_features (int) : Number of channels in the input tensor. + cond_dims (int or list of int) : List of numbers of channels + in the input. + num_filters (int): Number of filters in SPADE. + kernel_size (int): Kernel size of the convolutional filters in + the SPADE layer. + weight_norm_type (str): Type of weight normalization. + ``'none'``, ``'spectral'``, or ``'weight'``. + separate_projection (bool): If ``True``, we will use two different + layers for gamma and beta. Otherwise, we will use one layer. It + matters only if you apply any weight norms to this layer. + activation_norm_type (str): + Type of activation normalization. + ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, + ``'layer'``, ``'layer_2d'``, ``'group'``. + activation_norm_params (obj, optional, default=None): + Parameters of activation normalization. + If not ``None``, ``activation_norm_params.__dict__`` will be used as + keyword arguments when initializing activation normalization. 
+ """ + + def __init__(self, + num_features, + cond_dims, + num_filters=128, + kernel_size=3, + weight_norm_type='', + separate_projection=False, + activation_norm_type='sync_batch', + activation_norm_params=None, + bias_only=False, + partial=False, + interpolation='nearest'): + super().__init__() + if activation_norm_params is None: + activation_norm_params = SimpleNamespace(affine=False) + padding = kernel_size // 2 + self.separate_projection = separate_projection + self.mlps = nn.ModuleList() + self.gammas = nn.ModuleList() + self.betas = nn.ModuleList() + self.bias_only = bias_only + self.interpolation = interpolation + + # Make cond_dims a list. + if type(cond_dims) != list: + cond_dims = [cond_dims] + + # Make num_filters a list. + if not isinstance(num_filters, list): + num_filters = [num_filters] * len(cond_dims) + else: + assert len(num_filters) >= len(cond_dims) + + # Make partial a list. + if not isinstance(partial, list): + partial = [partial] * len(cond_dims) + else: + assert len(partial) >= len(cond_dims) + + for i, cond_dim in enumerate(cond_dims): + mlp = [] + conv_block = PartialConv2dBlock if partial[i] else Conv2dBlock + sequential = PartialSequential if partial[i] else nn.Sequential + + if num_filters[i] > 0: + mlp += [conv_block(cond_dim, + num_filters[i], + kernel_size, + padding=padding, + weight_norm_type=weight_norm_type, + nonlinearity='relu')] + mlp_ch = cond_dim if num_filters[i] == 0 else num_filters[i] + + if self.separate_projection: + if partial[i]: + raise NotImplementedError( + 'Separate projection not yet implemented for ' + + 'partial conv') + self.mlps.append(nn.Sequential(*mlp)) + self.gammas.append( + conv_block(mlp_ch, num_features, + kernel_size, + padding=padding, + weight_norm_type=weight_norm_type)) + self.betas.append( + conv_block(mlp_ch, num_features, + kernel_size, + padding=padding, + weight_norm_type=weight_norm_type)) + else: + mlp += [conv_block(mlp_ch, num_features * 2, kernel_size, + padding=padding, + weight_norm_type=weight_norm_type)] + self.mlps.append(sequential(*mlp)) + + self.norm = get_activation_norm_layer(num_features, + activation_norm_type, + 2, + **vars(activation_norm_params)) + self.conditional = True + + def forward(self, x, *cond_inputs, **_kwargs): + r"""Spatially Adaptive Normalization (SPADE) forward. + + Args: + x (N x C1 x H x W tensor) : Input tensor. + cond_inputs (list of tensors) : Conditional maps for SPADE. + Returns: + output (4D tensor) : Output tensor. 
+ """ + output = self.norm(x) if self.norm is not None else x + for i in range(len(cond_inputs)): + if cond_inputs[i] is None: + continue + label_map = F.interpolate(cond_inputs[i], size=x.size()[2:], mode=self.interpolation) + if self.separate_projection: + hidden = self.mlps[i](label_map) + gamma = self.gammas[i](hidden) + beta = self.betas[i](hidden) + else: + affine_params = self.mlps[i](label_map) + gamma, beta = affine_params.chunk(2, dim=1) + if self.bias_only: + output = output + beta + else: + output = output * (1 + gamma) + beta + return output + + +class DualAdaptiveNorm(nn.Module): + def __init__(self, + num_features, + cond_dims, + projection_bias=True, + weight_norm_type='', + activation_norm_type='instance', + activation_norm_params=None, + apply_noise=False, + bias_only=False, + init_gain=1.0, + fc_scale=None, + is_spatial=None): + super().__init__() + if activation_norm_params is None: + activation_norm_params = SimpleNamespace(affine=False) + self.mlps = nn.ModuleList() + self.gammas = nn.ModuleList() + self.betas = nn.ModuleList() + self.bias_only = bias_only + + # Make cond_dims a list. + if type(cond_dims) != list: + cond_dims = [cond_dims] + + if is_spatial is None: + is_spatial = [False for _ in range(len(cond_dims))] + self.is_spatial = is_spatial + + for cond_dim, this_is_spatial in zip(cond_dims, is_spatial): + kwargs = dict(weight_norm_type=weight_norm_type, + bias=projection_bias, + init_gain=init_gain, + output_scale=fc_scale) + if this_is_spatial: + self.gammas.append(Conv2dBlock(cond_dim, num_features, 1, 1, 0, **kwargs)) + self.betas.append(Conv2dBlock(cond_dim, num_features, 1, 1, 0, **kwargs)) + else: + self.gammas.append(LinearBlock(cond_dim, num_features, **kwargs)) + self.betas.append(LinearBlock(cond_dim, num_features, **kwargs)) + + self.norm = get_activation_norm_layer(num_features, + activation_norm_type, + 2, + **vars(activation_norm_params)) + self.conditional = True + + def forward(self, x, *cond_inputs, **_kwargs): + assert len(cond_inputs) == len(self.gammas) + output = self.norm(x) if self.norm is not None else x + for cond, gamma_layer, beta_layer in zip(cond_inputs, self.gammas, self.betas): + if cond is None: + continue + gamma = gamma_layer(cond) + beta = beta_layer(cond) + if cond.dim() == 4 and gamma.shape != x.shape: + gamma = F.interpolate(gamma, size=x.size()[2:], mode='bilinear') + beta = F.interpolate(beta, size=x.size()[2:], mode='bilinear') + elif cond.dim() == 2: + gamma = gamma[:, :, None, None] + beta = beta[:, :, None, None] + if self.bias_only: + output = output + beta + else: + output = output * (1 + gamma) + beta + return output + + +class HyperSpatiallyAdaptiveNorm(nn.Module): + r"""Spatially Adaptive Normalization (SPADE) initialization. + + Args: + num_features (int) : Number of channels in the input tensor. + cond_dims (int or list of int) : List of numbers of channels + in the conditional input. + num_filters (int): Number of filters in SPADE. + kernel_size (int): Kernel size of the convolutional filters in + the SPADE layer. + weight_norm_type (str): Type of weight normalization. + ``'none'``, ``'spectral'``, or ``'weight'``. + activation_norm_type (str): + Type of activation normalization. + ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, + ``'layer'``, ``'layer_2d'``, ``'group'``. + is_hyper (bool): Whether to use hyper SPADE. 
+ """ + + def __init__(self, num_features, cond_dims, + num_filters=0, kernel_size=3, + weight_norm_type='', + activation_norm_type='sync_batch', is_hyper=True): + super().__init__() + padding = kernel_size // 2 + self.mlps = nn.ModuleList() + if type(cond_dims) != list: + cond_dims = [cond_dims] + + for i, cond_dim in enumerate(cond_dims): + mlp = [] + if not is_hyper or (i != 0): + if num_filters > 0: + mlp += [Conv2dBlock(cond_dim, num_filters, kernel_size, + padding=padding, + weight_norm_type=weight_norm_type, + nonlinearity='relu')] + mlp_ch = cond_dim if num_filters == 0 else num_filters + mlp += [Conv2dBlock(mlp_ch, num_features * 2, kernel_size, + padding=padding, + weight_norm_type=weight_norm_type)] + mlp = nn.Sequential(*mlp) + else: + if num_filters > 0: + raise ValueError('Multi hyper layer not supported yet.') + mlp = HyperConv2d(padding=padding) + self.mlps.append(mlp) + + self.norm = get_activation_norm_layer(num_features, + activation_norm_type, + 2, + affine=False) + + self.conditional = True + + def forward(self, x, *cond_inputs, + norm_weights=(None, None), **_kwargs): + r"""Spatially Adaptive Normalization (SPADE) forward. + + Args: + x (4D tensor) : Input tensor. + cond_inputs (list of tensors) : Conditional maps for SPADE. + norm_weights (5D tensor or list of tensors): conv weights or + [weights, biases]. + Returns: + output (4D tensor) : Output tensor. + """ + output = self.norm(x) + for i in range(len(cond_inputs)): + if cond_inputs[i] is None: + continue + if type(cond_inputs[i]) == list: + cond_input, mask = cond_inputs[i] + mask = F.interpolate(mask, size=x.size()[2:], mode='bilinear', align_corners=False) + else: + cond_input = cond_inputs[i] + mask = None + label_map = F.interpolate(cond_input, size=x.size()[2:]) + if norm_weights is None or norm_weights[0] is None or i != 0: + affine_params = self.mlps[i](label_map) + else: + affine_params = self.mlps[i](label_map, + conv_weights=norm_weights) + gamma, beta = affine_params.chunk(2, dim=1) + if mask is not None: + gamma = gamma * (1 - mask) + beta = beta * (1 - mask) + output = output * (1 + gamma) + beta + return output + + +class LayerNorm2d(nn.Module): + r"""Layer Normalization as introduced in + https://arxiv.org/abs/1607.06450. + This is the usual way to apply layer normalization in CNNs. + Note that unlike the pytorch implementation which applies per-element + scale and bias, here it applies per-channel scale and bias, similar to + batch/instance normalization. + + Args: + num_features (int): Number of channels in the input tensor. + eps (float, optional, default=1e-5): a value added to the + denominator for numerical stability. + affine (bool, optional, default=False): If ``True``, performs + affine transformation after normalization. + """ + + def __init__(self, num_features, eps=1e-5, channel_only=False, affine=True): + super(LayerNorm2d, self).__init__() + self.num_features = num_features + self.affine = affine + self.eps = eps + self.channel_only = channel_only + + if self.affine: + self.gamma = nn.Parameter(torch.Tensor(num_features).fill_(1.0)) + self.beta = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + r""" + + Args: + x (tensor): Input tensor. 
+ """ + shape = [-1] + [1] * (x.dim() - 1) + if self.channel_only: + mean = x.mean(1, keepdim=True) + std = x.std(1, keepdim=True) + else: + mean = x.view(x.size(0), -1).mean(1).view(*shape) + std = x.view(x.size(0), -1).std(1).view(*shape) + + x = (x - mean) / (std + self.eps) + + if self.affine: + shape = [1, -1] + [1] * (x.dim() - 2) + x = x * self.gamma.view(*shape) + self.beta.view(*shape) + return x + + +class ScaleNorm(nn.Module): + r"""Scale normalization: + "Transformers without Tears: Improving the Normalization of Self-Attention" + Modified from: + https://github.com/tnq177/transformers_without_tears + """ + + def __init__(self, dim=-1, learned_scale=True, eps=1e-5): + super().__init__() + # scale = num_features ** 0.5 + if learned_scale: + self.scale = nn.Parameter(torch.tensor(1.)) + else: + self.scale = 1. + # self.num_features = num_features + self.dim = dim + self.eps = eps + self.learned_scale = learned_scale + + def forward(self, x): + # noinspection PyArgumentList + scale = self.scale * torch.rsqrt(torch.mean(x ** 2, dim=self.dim, keepdim=True) + self.eps) + return x * scale + + def extra_repr(self): + s = 'learned_scale={learned_scale}' + return s.format(**self.__dict__) + + +class PixelNorm(ScaleNorm): + def __init__(self, learned_scale=False, eps=1e-5, **_kwargs): + super().__init__(1, learned_scale, eps) + + +class SplitMeanStd(nn.Module): + def __init__(self, num_features, eps=1e-5, **kwargs): + super().__init__() + self.num_features = num_features + self.eps = eps + self.multiple_outputs = True + + def forward(self, x): + b, c, h, w = x.size() + mean = x.view(b, c, -1).mean(-1)[:, :, None, None] + var = x.view(b, c, -1).var(-1)[:, :, None, None] + std = torch.sqrt(var + self.eps) + + # x = (x - mean) / std + return x, torch.cat((mean, std), dim=1) + + +class ScaleNorm(nn.Module): + r"""Scale normalization: + "Transformers without Tears: Improving the Normalization of Self-Attention" + Modified from: + https://github.com/tnq177/transformers_without_tears + """ + + def __init__(self, dim=-1, learned_scale=True, eps=1e-5): + super().__init__() + # scale = num_features ** 0.5 + if learned_scale: + self.scale = nn.Parameter(torch.tensor(1.)) + else: + self.scale = 1. + # self.num_features = num_features + self.dim = dim + self.eps = eps + self.learned_scale = learned_scale + + def forward(self, x): + # noinspection PyArgumentList + scale = self.scale * torch.rsqrt( + torch.mean(x ** 2, dim=self.dim, keepdim=True) + self.eps) + return x * scale + + def extra_repr(self): + s = 'learned_scale={learned_scale}' + return s.format(**self.__dict__) + + +class PixelLayerNorm(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + self.norm = nn.LayerNorm(*args, **kwargs) + + def forward(self, x): + if x.dim() == 4: + b, c, h, w = x.shape + return self.norm(x.permute(0, 2, 3, 1).view(-1, c).contiguous()).view(b, h, w, c).permute(0, 3, 1, 2).contiguous() + else: + return self.norm(x) + + +def get_activation_norm_layer(num_features, norm_type, input_dim, **norm_params): + r"""Return an activation normalization layer. + + Args: + num_features (int): Number of feature channels. + norm_type (str): + Type of activation normalization. + ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, + ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, + ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. + input_dim (int): Number of input dimensions. + norm_params: Arbitrary keyword arguments that will be used to + initialize the activation normalization. 
+ """ + input_dim = max(input_dim, 1) # Norm1d works with both 0d and 1d inputs + + if norm_type == 'none' or norm_type == '': + norm_layer = None + elif norm_type == 'batch': + norm = getattr(nn, 'BatchNorm%dd' % input_dim) + norm_layer = norm(num_features, **norm_params) + elif norm_type == 'instance': + affine = norm_params.pop('affine', True) # Use affine=True by default + norm = getattr(nn, 'InstanceNorm%dd' % input_dim) + norm_layer = norm(num_features, affine=affine, **norm_params) + elif norm_type == 'sync_batch': + norm_layer = SyncBatchNorm(num_features, **norm_params) + elif norm_type == 'layer': + norm_layer = nn.LayerNorm(num_features, **norm_params) + elif norm_type == 'layer_2d': + norm_layer = LayerNorm2d(num_features, **norm_params) + elif norm_type == 'pixel_layer': + elementwise_affine = norm_params.pop('affine', True) # Use affine=True by default + norm_layer = PixelLayerNorm(num_features, elementwise_affine=elementwise_affine, **norm_params) + elif norm_type == 'scale': + norm_layer = ScaleNorm(**norm_params) + elif norm_type == 'pixel': + norm_layer = PixelNorm(**norm_params) + import imaginaire.config + if imaginaire.config.USE_JIT: + norm_layer = torch.jit.script(norm_layer) + elif norm_type == 'group': + num_groups = norm_params.pop('num_groups', 4) + norm_layer = nn.GroupNorm(num_channels=num_features, num_groups=num_groups, **norm_params) + elif norm_type == 'adaptive': + norm_layer = AdaptiveNorm(num_features, **norm_params) + elif norm_type == 'dual_adaptive': + norm_layer = DualAdaptiveNorm(num_features, **norm_params) + elif norm_type == 'spatially_adaptive': + if input_dim != 2: + raise ValueError('Spatially adaptive normalization layers ' + 'only supports 2D input') + norm_layer = SpatiallyAdaptiveNorm(num_features, **norm_params) + elif norm_type == 'hyper_spatially_adaptive': + if input_dim != 2: + raise ValueError('Spatially adaptive normalization layers ' + 'only supports 2D input') + norm_layer = HyperSpatiallyAdaptiveNorm(num_features, **norm_params) + else: + raise ValueError('Activation norm layer %s ' + 'is not recognized' % norm_type) + return norm_layer diff --git a/imaginaire/layers/conv.py b/imaginaire/layers/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..499fc0442b77e3183225c3529a4e3590dab0bc57 --- /dev/null +++ b/imaginaire/layers/conv.py @@ -0,0 +1,1377 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import warnings +from types import SimpleNamespace + +import torch +from torch import nn +from torch.nn import functional as F + +from .misc import ApplyNoise +from imaginaire.third_party.upfirdn2d.upfirdn2d import Blur + + +class _BaseConvBlock(nn.Module): + r"""An abstract wrapper class that wraps a torch convolution or linear layer + with normalization and nonlinearity. 
+ """ + + def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, activation_norm_type, activation_norm_params, nonlinearity, + inplace_nonlinearity, apply_noise, blur, order, input_dim, clamp, blur_kernel, output_scale, + init_gain): + super().__init__() + from .nonlinearity import get_nonlinearity_layer + from .weight_norm import get_weight_norm_layer + from .activation_norm import get_activation_norm_layer + self.weight_norm_type = weight_norm_type + self.stride = stride + self.clamp = clamp + self.init_gain = init_gain + + # Nonlinearity layer. + if 'fused' in nonlinearity: + # Fusing nonlinearity with bias. + lr_mul = getattr(weight_norm_params, 'lr_mul', 1) + conv_before_nonlinearity = order.find('C') < order.find('A') + if conv_before_nonlinearity: + assert bias is True + bias = False + channel = out_channels if conv_before_nonlinearity else in_channels + nonlinearity_layer = get_nonlinearity_layer( + nonlinearity, inplace=inplace_nonlinearity, + num_channels=channel, lr_mul=lr_mul) + else: + nonlinearity_layer = get_nonlinearity_layer( + nonlinearity, inplace=inplace_nonlinearity) + + # Noise injection layer. + if apply_noise: + order = order.replace('C', 'CG') + noise_layer = ApplyNoise() + else: + noise_layer = None + + # Convolutional layer. + if blur: + assert blur_kernel is not None + if stride == 2: + # Blur - Conv - Noise - Activate + p = (len(blur_kernel) - 2) + (kernel_size - 1) + pad0, pad1 = (p + 1) // 2, p // 2 + padding = 0 + blur_layer = Blur( + blur_kernel, pad=(pad0, pad1), padding_mode=padding_mode + ) + order = order.replace('C', 'BC') + elif stride == 0.5: + # Conv - Blur - Noise - Activate + padding = 0 + p = (len(blur_kernel) - 2) - (kernel_size - 1) + pad0, pad1 = (p + 1) // 2 + 1, p // 2 + 1 + blur_layer = Blur( + blur_kernel, pad=(pad0, pad1), padding_mode=padding_mode + ) + order = order.replace('C', 'CB') + elif stride == 1: + # No blur for now + blur_layer = nn.Identity() + else: + raise NotImplementedError + else: + blur_layer = nn.Identity() + + if weight_norm_params is None: + weight_norm_params = SimpleNamespace() + weight_norm = get_weight_norm_layer( + weight_norm_type, **vars(weight_norm_params)) + conv_layer = weight_norm(self._get_conv_layer( + in_channels, out_channels, kernel_size, stride, padding, dilation, + groups, bias, padding_mode, input_dim)) + + # Normalization layer. + conv_before_norm = order.find('C') < order.find('N') + norm_channels = out_channels if conv_before_norm else in_channels + if activation_norm_params is None: + activation_norm_params = SimpleNamespace() + activation_norm_layer = get_activation_norm_layer( + norm_channels, + activation_norm_type, + input_dim, + **vars(activation_norm_params)) + + # Mapping from operation names to layers. + mappings = {'C': {'conv': conv_layer}, + 'N': {'norm': activation_norm_layer}, + 'A': {'nonlinearity': nonlinearity_layer}} + mappings.update({'B': {'blur': blur_layer}}) + mappings.update({'G': {'noise': noise_layer}}) + + # All layers in order. + self.layers = nn.ModuleDict() + for op in order: + if list(mappings[op].values())[0] is not None: + self.layers.update(mappings[op]) + + # Whether this block expects conditional inputs. + self.conditional = \ + getattr(conv_layer, 'conditional', False) or \ + getattr(activation_norm_layer, 'conditional', False) + + # Scale the output by a learnable scaler parameter. 
+ if output_scale is not None: + self.output_scale = nn.Parameter(torch.tensor(output_scale)) + else: + self.register_parameter("output_scale", None) + + def forward(self, x, *cond_inputs, **kw_cond_inputs): + r""" + + Args: + x (tensor): Input tensor. + cond_inputs (list of tensors) : Conditional input tensors. + kw_cond_inputs (dict) : Keyword conditional inputs. + """ + for key, layer in self.layers.items(): + if getattr(layer, 'conditional', False): + # Layers that require conditional inputs. + x = layer(x, *cond_inputs, **kw_cond_inputs) + else: + x = layer(x) + if self.clamp is not None and isinstance(layer, nn.Conv2d): + x.clamp_(max=self.clamp) + if key == 'conv': + if self.output_scale is not None: + x = x * self.output_scale + return x + + def _get_conv_layer(self, in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + input_dim): + # Returns the convolutional layer. + if input_dim == 0: + layer = nn.Linear(in_channels, out_channels, bias) + else: + if stride < 1: # Fractionally-strided convolution. + padding_mode = 'zeros' + assert padding == 0 + layer_type = getattr(nn, f'ConvTranspose{input_dim}d') + stride = round(1 / stride) + else: + layer_type = getattr(nn, f'Conv{input_dim}d') + layer = layer_type( + in_channels, out_channels, kernel_size, stride, padding, + dilation=dilation, groups=groups, bias=bias, + padding_mode=padding_mode + ) + + return layer + + def __repr__(self): + main_str = self._get_name() + '(' + child_lines = [] + for name, layer in self.layers.items(): + mod_str = repr(layer) + if name == 'conv' and self.weight_norm_type != 'none' and \ + self.weight_norm_type != '': + mod_str = mod_str[:-1] + \ + ', weight_norm={}'.format(self.weight_norm_type) + ')' + if name == 'conv' and getattr(layer, 'base_lr_mul', 1) != 1: + mod_str = mod_str[:-1] + \ + ', lr_mul={}'.format(layer.base_lr_mul) + ')' + mod_str = self._addindent(mod_str, 2) + child_lines.append(mod_str) + if len(child_lines) == 1: + main_str += child_lines[0] + else: + main_str += '\n ' + '\n '.join(child_lines) + '\n' + + main_str += ')' + return main_str + + @staticmethod + def _addindent(s_, numSpaces): + s = s_.split('\n') + # don't do anything for single-line stuff + if len(s) == 1: + return s_ + first = s.pop(0) + s = [(numSpaces * ' ') + line for line in s] + s = '\n'.join(s) + s = first + '\n' + s + return s + + +class ModulatedConv2dBlock(_BaseConvBlock): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True, + padding_mode='zeros', + weight_norm_type='none', weight_norm_params=None, + activation_norm_type='none', activation_norm_params=None, + nonlinearity='none', inplace_nonlinearity=False, + apply_noise=True, blur=True, order='CNA', demodulate=True, + eps=True, style_dim=None, clamp=None, blur_kernel=(1, 3, 3, 1), output_scale=None, init_gain=1.0): + self.eps = eps + self.demodulate = demodulate + assert style_dim is not None + + super().__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + nonlinearity, inplace_nonlinearity, apply_noise, blur, + order, 2, clamp, blur_kernel, output_scale, init_gain) + self.modulation = LinearBlock(style_dim, in_channels, + weight_norm_type=weight_norm_type, + weight_norm_params=weight_norm_params) + + def _get_conv_layer(self, in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, 
padding_mode, + input_dim): + assert input_dim == 2 + layer = ModulatedConv2d( + in_channels, out_channels, kernel_size, stride, padding, + dilation, groups, bias, padding_mode, self.demodulate, self.eps) + return layer + + def forward(self, x, *cond_inputs, **kw_cond_inputs): + for layer in self.layers.values(): + if getattr(layer, 'conditional', False): + # Layers that require conditional inputs. + assert len(cond_inputs) == 1 + style = cond_inputs[0] + x = layer( + x, self.modulation(style), **kw_cond_inputs + ) + else: + x = layer(x) + if self.clamp is not None and isinstance(layer, ModulatedConv2d): + x.clamp_(max=self.clamp) + return x + + def __repr__(self): + main_str = self._get_name() + '(' + child_lines = [] + for name, layer in self.layers.items(): + mod_str = repr(layer) + if name == 'conv' and self.weight_norm_type != 'none' and \ + self.weight_norm_type != '': + mod_str = mod_str[:-1] + \ + ', weight_norm={}'.format(self.weight_norm_type) + \ + ', demodulate={}'.format(self.demodulate) + ')' + mod_str = self._addindent(mod_str, 2) + child_lines.append(mod_str) + child_lines.append( + self._addindent('Modulation(' + repr(self.modulation) + ')', 2) + ) + if len(child_lines) == 1: + main_str += child_lines[0] + else: + main_str += '\n ' + '\n '.join(child_lines) + '\n' + + main_str += ')' + return main_str + + +class ModulatedConv2d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride, padding, + dilation, groups, bias, padding_mode, demodulate=True, + eps=1e-8): + # in_channels, out_channels, kernel_size, stride, padding, + # dilation, groups, bias, padding_mode + assert dilation == 1 and groups == 1 + + super().__init__() + + self.eps = eps + self.kernel_size = kernel_size + self.in_channels = in_channels + self.out_channels = out_channels + self.padding = padding + self.stride = stride + self.padding_mode = padding_mode + # kernel_size // 2 + # assert self.padding == padding + + self.weight = nn.Parameter( + torch.randn(out_channels, in_channels, kernel_size, kernel_size) + ) + + if bias: + self.bias = nn.Parameter(torch.Tensor(out_channels)) + else: + # noinspection PyTypeChecker + self.register_parameter('bias', None) + + # self.modulation = LinearBlock(style_dim, in_channels, + # weight_norm_type=weight_norm_type) + self.demodulate = demodulate + self.conditional = True + + def forward(self, x, style, **_kwargs): + batch, in_channel, height, width = x.shape + + # style = self.modulation(style).view(batch, 1, in_channel, 1, 1) + # We assume the modulation layer is outside this module. 
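+        # Per-sample weight modulation (StyleGAN2 style): every input channel of
+        # the shared kernel is scaled by the style vector below, and demodulation
+        # then rescales each output filter to roughly unit norm before the
+        # grouped convolution that applies a different kernel to each sample.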
+ style = style.view(batch, 1, in_channel, 1, 1) + weight = self.weight.unsqueeze(0) * style + + if self.demodulate: + demod = torch.rsqrt( + weight.pow(2).sum([2, 3, 4]) + self.eps) + weight = weight * demod.view(batch, self.out_channels, 1, 1, 1) + + weight = weight.view( + batch * self.out_channels, + in_channel, self.kernel_size, self.kernel_size + ) + if self.bias is not None: + bias = self.bias.repeat(batch) + else: + bias = self.bias + + x = x.view(1, batch * in_channel, height, width) + + if self.padding_mode != 'zeros': + x = F.pad(x, self._reversed_padding_repeated_twice, + mode=self.padding_mode) + padding = (0, 0) + else: + padding = self.padding + + if self.stride == 0.5: + weight = weight.view( + batch, self.out_channels, in_channel, + self.kernel_size, self.kernel_size + ) + weight = weight.transpose(1, 2).reshape( + batch * in_channel, self.out_channels, + self.kernel_size, self.kernel_size + ) + out = F.conv_transpose2d( + x, weight, bias, padding=padding, stride=2, groups=batch + ) + + elif self.stride == 2: + out = F.conv2d( + x, weight, bias, padding=padding, stride=2, groups=batch + ) + + else: + out = F.conv2d(x, weight, bias, padding=padding, groups=batch) + + _, _, height, width = out.shape + out = out.view(batch, self.out_channels, height, width) + + return out + + def extra_repr(self): + s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}' + ', stride={stride}') + if self.bias is None: + s += ', bias=False' + if self.padding_mode != 'zeros': + s += ', padding_mode={padding_mode}' + return s.format(**self.__dict__) + + +class LinearBlock(_BaseConvBlock): + r"""A Wrapper class that wraps ``torch.nn.Linear`` with normalization and + nonlinearity. + + Args: + in_features (int): Number of channels in the input tensor. + out_features (int): Number of channels in the output tensor. + bias (bool, optional, default=True): + If ``True``, adds a learnable bias to the output. + weight_norm_type (str, optional, default='none'): + Type of weight normalization. + ``'none'``, ``'spectral'``, ``'weight'`` + or ``'weight_demod'``. + weight_norm_params (obj, optional, default=None): + Parameters of weight normalization. + If not ``None``, ``weight_norm_params.__dict__`` will be used as + keyword arguments when initializing weight normalization. + activation_norm_type (str, optional, default='none'): + Type of activation normalization. + ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, + ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, + ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. + activation_norm_params (obj, optional, default=None): + Parameters of activation normalization. + If not ``None``, ``activation_norm_params.__dict__`` will be used as + keyword arguments when initializing activation normalization. + nonlinearity (str, optional, default='none'): + Type of nonlinear activation function. + ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, + ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. + inplace_nonlinearity (bool, optional, default=False): If ``True``, + set ``inplace=True`` when initializing the nonlinearity layer. + apply_noise (bool, optional, default=False): If ``True``, add + Gaussian noise with learnable magnitude after the + fully-connected layer. + order (str, optional, default='CNA'): Order of operations. + ``'C'``: fully-connected, + ``'N'``: normalization, + ``'A'``: nonlinear activation. + For example, a block initialized with ``order='CNA'`` will + do convolution first, then normalization, then nonlinearity. 
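Returning to ModulatedConv2d above: the reshapes can obscure the core trick, which is folding the batch dimension into the channel dimension so that a single grouped convolution applies a different modulated kernel to every sample. A minimal functional sketch of the same computation (no bias, no strided or transposed variants, illustrative names) follows.

```python
import torch
import torch.nn.functional as F

def modulated_conv2d(x, weight, style, demodulate=True, eps=1e-8):
    """x: (B, Cin, H, W), weight: (Cout, Cin, k, k), style: (B, Cin)."""
    b, cin, h, w = x.shape
    cout, _, k, _ = weight.shape
    w_mod = weight.unsqueeze(0) * style.view(b, 1, cin, 1, 1)        # (B, Cout, Cin, k, k)
    if demodulate:
        demod = torch.rsqrt(w_mod.pow(2).sum(dim=[2, 3, 4]) + eps)   # (B, Cout)
        w_mod = w_mod * demod.view(b, cout, 1, 1, 1)
    # Fold the batch into groups so each sample is convolved with its own kernel.
    out = F.conv2d(x.reshape(1, b * cin, h, w),
                   w_mod.reshape(b * cout, cin, k, k),
                   padding=k // 2, groups=b)
    return out.reshape(b, cout, h, w)

x = torch.randn(4, 8, 16, 16)
weight = torch.randn(16, 8, 3, 3)
style = torch.randn(4, 8)
print(modulated_conv2d(x, weight, style).shape)   # torch.Size([4, 16, 16, 16])
```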
+ """ + + def __init__(self, in_features, out_features, bias=True, + weight_norm_type='none', weight_norm_params=None, + activation_norm_type='none', activation_norm_params=None, + nonlinearity='none', inplace_nonlinearity=False, + apply_noise=False, order='CNA', clamp=None, blur_kernel=(1, 3, 3, 1), output_scale=None, + init_gain=1.0, **_kwargs): + if bool(_kwargs): + warnings.warn(f"Unused keyword arguments {_kwargs}") + super().__init__(in_features, out_features, None, None, + None, None, None, bias, + None, weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + nonlinearity, inplace_nonlinearity, apply_noise, + False, order, 0, clamp, blur_kernel, output_scale, + init_gain) + + +class EmbeddingBlock(_BaseConvBlock): + def __init__(self, in_features, out_features, bias=True, + weight_norm_type='none', weight_norm_params=None, + activation_norm_type='none', activation_norm_params=None, + nonlinearity='none', inplace_nonlinearity=False, + apply_noise=False, order='CNA', clamp=None, output_scale=None, + init_gain=1.0, **_kwargs): + if bool(_kwargs): + warnings.warn(f"Unused keyword arguments {_kwargs}") + super().__init__(in_features, out_features, None, None, + None, None, None, bias, + None, weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + nonlinearity, inplace_nonlinearity, apply_noise, + False, order, 0, clamp, None, output_scale, + init_gain) + + def _get_conv_layer(self, in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + input_dim): + assert input_dim == 0 + return nn.Embedding(in_channels, out_channels) + + +class Embedding2dBlock(_BaseConvBlock): + def __init__(self, in_features, out_features, bias=True, + weight_norm_type='none', weight_norm_params=None, + activation_norm_type='none', activation_norm_params=None, + nonlinearity='none', inplace_nonlinearity=False, + apply_noise=False, order='CNA', clamp=None, output_scale=None, + init_gain=1.0, **_kwargs): + if bool(_kwargs): + warnings.warn(f"Unused keyword arguments {_kwargs}") + super().__init__(in_features, out_features, None, None, + None, None, None, bias, + None, weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + nonlinearity, inplace_nonlinearity, apply_noise, + False, order, 0, clamp, None, output_scale, + init_gain) + + def _get_conv_layer(self, in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + input_dim): + assert input_dim == 0 + return Embedding2d(in_channels, out_channels) + + +class Conv1dBlock(_BaseConvBlock): + r"""A Wrapper class that wraps ``torch.nn.Conv1d`` with normalization and + nonlinearity. + + Args: + in_channels (int): Number of channels in the input tensor. + out_channels (int): Number of channels in the output tensor. + kernel_size (int or tuple): Size of the convolving kernel. + stride (int or float or tuple, optional, default=1): + Stride of the convolution. + padding (int or tuple, optional, default=0): + Zero-padding added to both sides of the input. + dilation (int or tuple, optional, default=1): + Spacing between kernel elements. + groups (int, optional, default=1): Number of blocked connections + from input channels to output channels. + bias (bool, optional, default=True): + If ``True``, adds a learnable bias to the output. + padding_mode (string, optional, default='zeros'): Type of padding: + ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. 
+ weight_norm_type (str, optional, default='none'): + Type of weight normalization. + ``'none'``, ``'spectral'``, ``'weight'`` + or ``'weight_demod'``. + weight_norm_params (obj, optional, default=None): + Parameters of weight normalization. + If not ``None``, ``weight_norm_params.__dict__`` will be used as + keyword arguments when initializing weight normalization. + activation_norm_type (str, optional, default='none'): + Type of activation normalization. + ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, + ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, + ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. + activation_norm_params (obj, optional, default=None): + Parameters of activation normalization. + If not ``None``, ``activation_norm_params.__dict__`` will be used as + keyword arguments when initializing activation normalization. + nonlinearity (str, optional, default='none'): + Type of nonlinear activation function. + ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, + ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. + inplace_nonlinearity (bool, optional, default=False): If ``True``, + set ``inplace=True`` when initializing the nonlinearity layer. + apply_noise (bool, optional, default=False): If ``True``, adds + Gaussian noise with learnable magnitude to the convolution output. + order (str, optional, default='CNA'): Order of operations. + ``'C'``: convolution, + ``'N'``: normalization, + ``'A'``: nonlinear activation. + For example, a block initialized with ``order='CNA'`` will + do convolution first, then normalization, then nonlinearity. + """ + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True, + padding_mode='zeros', + weight_norm_type='none', weight_norm_params=None, + activation_norm_type='none', activation_norm_params=None, + nonlinearity='none', inplace_nonlinearity=False, + apply_noise=False, blur=False, order='CNA', clamp=None, output_scale=None, init_gain=1.0, **_kwargs): + super().__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + nonlinearity, inplace_nonlinearity, apply_noise, + blur, order, 1, clamp, None, output_scale, init_gain) + + +class Conv2dBlock(_BaseConvBlock): + r"""A Wrapper class that wraps ``torch.nn.Conv2d`` with normalization and + nonlinearity. + + Args: + in_channels (int): Number of channels in the input tensor. + out_channels (int): Number of channels in the output tensor. + kernel_size (int or tuple): Size of the convolving kernel. + stride (int or float or tuple, optional, default=1): + Stride of the convolution. + padding (int or tuple, optional, default=0): + Zero-padding added to both sides of the input. + dilation (int or tuple, optional, default=1): + Spacing between kernel elements. + groups (int, optional, default=1): Number of blocked connections + from input channels to output channels. + bias (bool, optional, default=True): + If ``True``, adds a learnable bias to the output. + padding_mode (string, optional, default='zeros'): Type of padding: + ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. + weight_norm_type (str, optional, default='none'): + Type of weight normalization. + ``'none'``, ``'spectral'``, ``'weight'`` + or ``'weight_demod'``. + weight_norm_params (obj, optional, default=None): + Parameters of weight normalization. 
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as + keyword arguments when initializing weight normalization. + activation_norm_type (str, optional, default='none'): + Type of activation normalization. + ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, + ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, + ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. + activation_norm_params (obj, optional, default=None): + Parameters of activation normalization. + If not ``None``, ``activation_norm_params.__dict__`` will be used as + keyword arguments when initializing activation normalization. + nonlinearity (str, optional, default='none'): + Type of nonlinear activation function. + ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, + ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. + inplace_nonlinearity (bool, optional, default=False): If ``True``, + set ``inplace=True`` when initializing the nonlinearity layer. + apply_noise (bool, optional, default=False): If ``True``, adds + Gaussian noise with learnable magnitude to the convolution output. + order (str, optional, default='CNA'): Order of operations. + ``'C'``: convolution, + ``'N'``: normalization, + ``'A'``: nonlinear activation. + For example, a block initialized with ``order='CNA'`` will + do convolution first, then normalization, then nonlinearity. + """ + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True, + padding_mode='zeros', + weight_norm_type='none', weight_norm_params=None, + activation_norm_type='none', activation_norm_params=None, + nonlinearity='none', inplace_nonlinearity=False, + apply_noise=False, blur=False, order='CNA', clamp=None, blur_kernel=(1, 3, 3, 1), + output_scale=None, init_gain=1.0): + super().__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + nonlinearity, inplace_nonlinearity, + apply_noise, blur, order, 2, clamp, blur_kernel, output_scale, init_gain) + + +class Conv3dBlock(_BaseConvBlock): + r"""A Wrapper class that wraps ``torch.nn.Conv3d`` with normalization and + nonlinearity. + + Args: + in_channels (int): Number of channels in the input tensor. + out_channels (int): Number of channels in the output tensor. + kernel_size (int or tuple): Size of the convolving kernel. + stride (int or float or tuple, optional, default=1): + Stride of the convolution. + padding (int or tuple, optional, default=0): + Zero-padding added to both sides of the input. + dilation (int or tuple, optional, default=1): + Spacing between kernel elements. + groups (int, optional, default=1): Number of blocked connections + from input channels to output channels. + bias (bool, optional, default=True): + If ``True``, adds a learnable bias to the output. + padding_mode (string, optional, default='zeros'): Type of padding: + ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. + weight_norm_type (str, optional, default='none'): + Type of weight normalization. + ``'none'``, ``'spectral'``, ``'weight'`` + or ``'weight_demod'``. + weight_norm_params (obj, optional, default=None): + Parameters of weight normalization. + If not ``None``, ``weight_norm_params.__dict__`` will be used as + keyword arguments when initializing weight normalization. + activation_norm_type (str, optional, default='none'): + Type of activation normalization. 
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, + ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, + ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. + activation_norm_params (obj, optional, default=None): + Parameters of activation normalization. + If not ``None``, ``activation_norm_params.__dict__`` will be used as + keyword arguments when initializing activation normalization. + nonlinearity (str, optional, default='none'): + Type of nonlinear activation function. + ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, + ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. + inplace_nonlinearity (bool, optional, default=False): If ``True``, + set ``inplace=True`` when initializing the nonlinearity layer. + apply_noise (bool, optional, default=False): If ``True``, adds + Gaussian noise with learnable magnitude to the convolution output. + order (str, optional, default='CNA'): Order of operations. + ``'C'``: convolution, + ``'N'``: normalization, + ``'A'``: nonlinear activation. + For example, a block initialized with ``order='CNA'`` will + do convolution first, then normalization, then nonlinearity. + """ + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True, + padding_mode='zeros', + weight_norm_type='none', weight_norm_params=None, + activation_norm_type='none', activation_norm_params=None, + nonlinearity='none', inplace_nonlinearity=False, + apply_noise=False, blur=False, order='CNA', clamp=None, blur_kernel=(1, 3, 3, 1), output_scale=None, + init_gain=1.0): + super().__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + nonlinearity, inplace_nonlinearity, + apply_noise, blur, order, 3, clamp, blur_kernel, output_scale, init_gain) + + +class _BaseHyperConvBlock(_BaseConvBlock): + r"""An abstract wrapper class that wraps a hyper convolutional layer + with normalization and nonlinearity. + """ + + def __init__(self, in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, + padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + nonlinearity, inplace_nonlinearity, apply_noise, blur, + is_hyper_conv, is_hyper_norm, order, input_dim, clamp=None, blur_kernel=(1, 3, 3, 1), + output_scale=None, init_gain=1.0): + self.is_hyper_conv = is_hyper_conv + if is_hyper_conv: + weight_norm_type = 'none' + if is_hyper_norm: + activation_norm_type = 'hyper_' + activation_norm_type + super().__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + nonlinearity, inplace_nonlinearity, apply_noise, blur, + order, input_dim, clamp, blur_kernel, output_scale, init_gain) + + def _get_conv_layer(self, in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + input_dim): + if input_dim == 0: + raise ValueError('HyperLinearBlock is not supported.') + else: + name = 'HyperConv' if self.is_hyper_conv else 'nn.Conv' + layer_type = eval(name + '%dd' % input_dim) + layer = layer_type( + in_channels, out_channels, kernel_size, stride, padding, + dilation, groups, bias, padding_mode) + return layer + + +class HyperConv2dBlock(_BaseHyperConvBlock): + r"""A Wrapper class that wraps ``HyperConv2d`` with normalization and + nonlinearity. 
+ + Args: + in_channels (int): Number of channels in the input tensor. + out_channels (int): Number of channels in the output tensor. + kernel_size (int or tuple): Size of the convolving kernel. + stride (int or float or tuple, optional, default=1): + Stride of the convolution. + padding (int or tuple, optional, default=0): + Zero-padding added to both sides of the input. + dilation (int or tuple, optional, default=1): + Spacing between kernel elements. + groups (int, optional, default=1): Number of blocked connections + from input channels to output channels. + bias (bool, optional, default=True): + If ``True``, adds a learnable bias to the output. + padding_mode (string, optional, default='zeros'): Type of padding: + ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. + weight_norm_type (str, optional, default='none'): + Type of weight normalization. + ``'none'``, ``'spectral'``, ``'weight'`` + or ``'weight_demod'``. + weight_norm_params (obj, optional, default=None): + Parameters of weight normalization. + If not ``None``, ``weight_norm_params.__dict__`` will be used as + keyword arguments when initializing weight normalization. + activation_norm_type (str, optional, default='none'): + Type of activation normalization. + ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, + ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, + ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. + activation_norm_params (obj, optional, default=None): + Parameters of activation normalization. + If not ``None``, ``activation_norm_params.__dict__`` will be used as + keyword arguments when initializing activation normalization. + is_hyper_conv (bool, optional, default=False): If ``True``, use + ``HyperConv2d``, otherwise use ``torch.nn.Conv2d``. + is_hyper_norm (bool, optional, default=False): If ``True``, use + hyper normalizations. + nonlinearity (str, optional, default='none'): + Type of nonlinear activation function. + ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, + ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. + inplace_nonlinearity (bool, optional, default=False): If ``True``, + set ``inplace=True`` when initializing the nonlinearity layer. + apply_noise (bool, optional, default=False): If ``True``, adds + Gaussian noise with learnable magnitude to the convolution output. + order (str, optional, default='CNA'): Order of operations. + ``'C'``: convolution, + ``'N'``: normalization, + ``'A'``: nonlinear activation. + For example, a block initialized with ``order='CNA'`` will + do convolution first, then normalization, then nonlinearity. + """ + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True, + padding_mode='zeros', + weight_norm_type='none', weight_norm_params=None, + activation_norm_type='none', activation_norm_params=None, + is_hyper_conv=False, is_hyper_norm=False, + nonlinearity='none', inplace_nonlinearity=False, + apply_noise=False, blur=False, order='CNA', clamp=None): + super().__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + nonlinearity, inplace_nonlinearity, apply_noise, blur, + is_hyper_conv, is_hyper_norm, order, 2, clamp) + + +class HyperConv2d(nn.Module): + r"""Hyper Conv2d initialization. + + Args: + in_channels (int): Dummy parameter. + out_channels (int): Dummy parameter. + kernel_size (int or tuple): Dummy parameter. 
+ stride (int or float or tuple, optional, default=1): + Stride of the convolution. Default: 1 + padding (int or tuple, optional, default=0): + Zero-padding added to both sides of the input. + padding_mode (string, optional, default='zeros'): + ``'zeros'``, ``'reflect'``, ``'replicate'`` + or ``'circular'``. + dilation (int or tuple, optional, default=1): + Spacing between kernel elements. + groups (int, optional, default=1): Number of blocked connections + from input channels to output channels. + bias (bool, optional, default=True): If ``True``, + adds a learnable bias to the output. + """ + + def __init__(self, in_channels=0, out_channels=0, kernel_size=3, + stride=1, padding=1, dilation=1, groups=1, bias=True, + padding_mode='zeros'): + super().__init__() + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + self.use_bias = bias + self.padding_mode = padding_mode + self.conditional = True + + def forward(self, x, *args, conv_weights=(None, None), **kwargs): + r"""Hyper Conv2d forward. Convolve x using the provided weight and bias. + + Args: + x (N x C x H x W tensor): Input tensor. + conv_weights (N x C2 x C1 x k x k tensor or list of tensors): + Convolution weights or [weight, bias]. + Returns: + y (N x C2 x H x W tensor): Output tensor. + """ + if conv_weights is None: + conv_weight, conv_bias = None, None + elif isinstance(conv_weights, torch.Tensor): + conv_weight, conv_bias = conv_weights, None + else: + conv_weight, conv_bias = conv_weights + + if conv_weight is None: + return x + if conv_bias is None: + if self.use_bias: + raise ValueError('bias not provided but set to true during ' + 'initialization') + conv_bias = [None] * x.size(0) + if self.padding_mode != 'zeros': + x = F.pad(x, [self.padding] * 4, mode=self.padding_mode) + padding = 0 + else: + padding = self.padding + + y = None + # noinspection PyArgumentList + for i in range(x.size(0)): + if self.stride >= 1: + yi = F.conv2d(x[i: i + 1], + weight=conv_weight[i], bias=conv_bias[i], + stride=self.stride, padding=padding, + dilation=self.dilation, groups=self.groups) + else: + yi = F.conv_transpose2d(x[i: i + 1], weight=conv_weight[i], + bias=conv_bias[i], padding=self.padding, + stride=int(1 / self.stride), + dilation=self.dilation, + output_padding=self.padding, + groups=self.groups) + y = torch.cat([y, yi]) if y is not None else yi + return y + + +class _BasePartialConvBlock(_BaseConvBlock): + r"""An abstract wrapper class that wraps a partial convolutional layer + with normalization and nonlinearity. 
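HyperConv2d above owns no weights of its own: a hypernetwork elsewhere in the model predicts a kernel (and optionally a bias) for every sample, and the layer loops over the batch applying them. A stripped-down sketch of that per-sample convolution, with the assumed tensor shapes spelled out, is shown below; `per_sample_conv2d` is an illustrative name, not part of the repo.

```python
import torch
import torch.nn.functional as F

def per_sample_conv2d(x, conv_weight, conv_bias=None, padding=1):
    """x: (B, C1, H, W); conv_weight: (B, C2, C1, k, k); conv_bias: (B, C2) or None."""
    outputs = []
    for i in range(x.size(0)):
        bias_i = conv_bias[i] if conv_bias is not None else None
        outputs.append(F.conv2d(x[i:i + 1], conv_weight[i], bias_i, padding=padding))
    return torch.cat(outputs, dim=0)

x = torch.randn(2, 4, 8, 8)
w = torch.randn(2, 6, 4, 3, 3)          # one 3x3 kernel set per sample
print(per_sample_conv2d(x, w).shape)    # torch.Size([2, 6, 8, 8])
```

The same effect can be obtained with the grouped-convolution folding used by ModulatedConv2d, which avoids the Python loop at the cost of building one large reshaped weight tensor.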
+ """ + + def __init__(self, in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + nonlinearity, inplace_nonlinearity, + multi_channel, return_mask, + apply_noise, order, input_dim, clamp=None, blur_kernel=(1, 3, 3, 1), output_scale=None, init_gain=1.0): + self.multi_channel = multi_channel + self.return_mask = return_mask + self.partial_conv = True + super().__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + nonlinearity, inplace_nonlinearity, apply_noise, + False, order, input_dim, clamp, blur_kernel, output_scale, init_gain) + + def _get_conv_layer(self, in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + input_dim): + if input_dim == 2: + layer_type = PartialConv2d + elif input_dim == 3: + layer_type = PartialConv3d + else: + raise ValueError('Partial conv only supports 2D and 3D conv now.') + layer = layer_type( + in_channels, out_channels, kernel_size, stride, padding, + dilation, groups, bias, padding_mode, + multi_channel=self.multi_channel, return_mask=self.return_mask) + return layer + + def forward(self, x, *cond_inputs, mask_in=None, **kw_cond_inputs): + r""" + + Args: + x (tensor): Input tensor. + cond_inputs (list of tensors) : Conditional input tensors. + mask_in (tensor, optional, default=``None``) If not ``None``, + it masks the valid input region. + kw_cond_inputs (dict) : Keyword conditional inputs. + Returns: + (tuple): + - x (tensor): Output tensor. + - mask_out (tensor, optional): Masks the valid output region. + """ + mask_out = None + for layer in self.layers.values(): + if getattr(layer, 'conditional', False): + x = layer(x, *cond_inputs, **kw_cond_inputs) + elif getattr(layer, 'partial_conv', False): + x = layer(x, mask_in=mask_in, **kw_cond_inputs) + if type(x) == tuple: + x, mask_out = x + else: + x = layer(x) + + if mask_out is not None: + return x, mask_out + return x + + +class PartialConv2dBlock(_BasePartialConvBlock): + r"""A Wrapper class that wraps ``PartialConv2d`` with normalization and + nonlinearity. + + Args: + in_channels (int): Number of channels in the input tensor. + out_channels (int): Number of channels in the output tensor. + kernel_size (int or tuple): Size of the convolving kernel. + stride (int or float or tuple, optional, default=1): + Stride of the convolution. + padding (int or tuple, optional, default=0): + Zero-padding added to both sides of the input. + dilation (int or tuple, optional, default=1): + Spacing between kernel elements. + groups (int, optional, default=1): Number of blocked connections + from input channels to output channels. + bias (bool, optional, default=True): + If ``True``, adds a learnable bias to the output. + padding_mode (string, optional, default='zeros'): Type of padding: + ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. + weight_norm_type (str, optional, default='none'): + Type of weight normalization. + ``'none'``, ``'spectral'``, ``'weight'`` + or ``'weight_demod'``. + weight_norm_params (obj, optional, default=None): + Parameters of weight normalization. + If not ``None``, ``weight_norm_params.__dict__`` will be used as + keyword arguments when initializing weight normalization. + activation_norm_type (str, optional, default='none'): + Type of activation normalization. 
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, + ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, + ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. + activation_norm_params (obj, optional, default=None): + Parameters of activation normalization. + If not ``None``, ``activation_norm_params.__dict__`` will be used as + keyword arguments when initializing activation normalization. + nonlinearity (str, optional, default='none'): + Type of nonlinear activation function. + ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, + ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. + inplace_nonlinearity (bool, optional, default=False): If ``True``, + set ``inplace=True`` when initializing the nonlinearity layer. + apply_noise (bool, optional, default=False): If ``True``, adds + Gaussian noise with learnable magnitude to the convolution output. + order (str, optional, default='CNA'): Order of operations. + ``'C'``: convolution, + ``'N'``: normalization, + ``'A'``: nonlinear activation. + For example, a block initialized with ``order='CNA'`` will + do convolution first, then normalization, then nonlinearity. + multi_channel (bool, optional, default=False): If ``True``, use + different masks for different channels. + return_mask (bool, optional, default=True): If ``True``, the + forward call also returns a new mask. + """ + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True, + padding_mode='zeros', + weight_norm_type='none', weight_norm_params=None, + activation_norm_type='none', activation_norm_params=None, + nonlinearity='none', inplace_nonlinearity=False, + multi_channel=False, return_mask=True, + apply_noise=False, order='CNA', clamp=None): + super().__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + nonlinearity, inplace_nonlinearity, + multi_channel, return_mask, apply_noise, order, 2, + clamp) + + +class PartialConv3dBlock(_BasePartialConvBlock): + r"""A Wrapper class that wraps ``PartialConv3d`` with normalization and + nonlinearity. + + Args: + in_channels (int): Number of channels in the input tensor. + out_channels (int): Number of channels in the output tensor. + kernel_size (int or tuple): Size of the convolving kernel. + stride (int or float or tuple, optional, default=1): + Stride of the convolution. + padding (int or tuple, optional, default=0): + Zero-padding added to both sides of the input. + dilation (int or tuple, optional, default=1): + Spacing between kernel elements. + groups (int, optional, default=1): Number of blocked connections + from input channels to output channels. + bias (bool, optional, default=True): + If ``True``, adds a learnable bias to the output. + padding_mode (string, optional, default='zeros'): Type of padding: + ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. + weight_norm_type (str, optional, default='none'): + Type of weight normalization. + ``'none'``, ``'spectral'``, ``'weight'`` + or ``'weight_demod'``. + weight_norm_params (obj, optional, default=None): + Parameters of weight normalization. + If not ``None``, ``weight_norm_params.__dict__`` will be used as + keyword arguments when initializing weight normalization. + activation_norm_type (str, optional, default='none'): + Type of activation normalization. 
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, + ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, + ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. + activation_norm_params (obj, optional, default=None): + Parameters of activation normalization. + If not ``None``, ``activation_norm_params.__dict__`` will be used as + keyword arguments when initializing activation normalization. + nonlinearity (str, optional, default='none'): + Type of nonlinear activation function. + ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, + ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. + inplace_nonlinearity (bool, optional, default=False): If ``True``, + set ``inplace=True`` when initializing the nonlinearity layer. + apply_noise (bool, optional, default=False): If ``True``, adds + Gaussian noise with learnable magnitude to the convolution output. + order (str, optional, default='CNA'): Order of operations. + ``'C'``: convolution, + ``'N'``: normalization, + ``'A'``: nonlinear activation. + For example, a block initialized with ``order='CNA'`` will + do convolution first, then normalization, then nonlinearity. + multi_channel (bool, optional, default=False): If ``True``, use + different masks for different channels. + return_mask (bool, optional, default=True): If ``True``, the + forward call also returns a new mask. + """ + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True, + padding_mode='zeros', + weight_norm_type='none', weight_norm_params=None, + activation_norm_type='none', activation_norm_params=None, + nonlinearity='none', inplace_nonlinearity=False, + multi_channel=False, return_mask=True, + apply_noise=False, order='CNA', clamp=None): + super().__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + nonlinearity, inplace_nonlinearity, + multi_channel, return_mask, apply_noise, order, 3, + clamp) + + +class _MultiOutBaseConvBlock(_BaseConvBlock): + r"""An abstract wrapper class that wraps a hyper convolutional layer with + normalization and nonlinearity. It can return multiple outputs, if some + layers in the block return more than one output. + """ + + def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, activation_norm_type, activation_norm_params, nonlinearity, + inplace_nonlinearity, apply_noise, blur, order, input_dim, clamp=None, blur_kernel=(1, 3, 3, 1), + output_scale=None, init_gain=1.0): + super().__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + nonlinearity, inplace_nonlinearity, + apply_noise, blur, order, input_dim, clamp, blur_kernel, output_scale, init_gain) + self.multiple_outputs = True + + def forward(self, x, *cond_inputs, **kw_cond_inputs): + r""" + + Args: + x (tensor): Input tensor. + cond_inputs (list of tensors) : Conditional input tensors. + kw_cond_inputs (dict) : Keyword conditional inputs. + Returns: + (tuple): + - x (tensor): Main output tensor. + - other_outputs (list of tensors): Other output tensors. 
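Before moving on from the partial-convolution blocks above, it may help to see the mask arithmetic in isolation: the convolution is evaluated on the masked input, the result is rescaled by how much of each sliding window was valid, and an updated mask is returned alongside the features (the full implementation appears later in this file as PartialConv2d). The sketch below covers only the single-channel-mask case with fixed stride and no weight normalization; `partial_conv2d` is an illustrative name.

```python
import torch
import torch.nn.functional as F

def partial_conv2d(x, mask, weight, bias=None, padding=1, eps=1e-6):
    """x: (B, C, H, W); mask: (B, 1, H, W) with 1 = valid pixel."""
    k = weight.shape[2] * weight.shape[3]
    with torch.no_grad():
        ones = torch.ones(1, 1, weight.shape[2], weight.shape[3], device=x.device)
        valid = F.conv2d(mask, ones, padding=padding)      # valid pixels per window
        ratio = k / (valid + eps)                          # re-weighting factor
        new_mask = torch.clamp(valid, 0, 1)                # 1 if the window saw any valid pixel
        ratio = ratio * new_mask
    out = F.conv2d(x * mask, weight, bias=None, padding=padding) * ratio
    if bias is not None:
        out = out + bias.view(1, -1, 1, 1)
    return out * new_mask, new_mask

x = torch.randn(1, 3, 8, 8)
mask = (torch.rand(1, 1, 8, 8) > 0.5).float()
w = torch.randn(5, 3, 3, 3)
out, new_mask = partial_conv2d(x, mask, w)
print(out.shape, new_mask.shape)   # torch.Size([1, 5, 8, 8]) torch.Size([1, 1, 8, 8])
```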
+ """ + other_outputs = [] + for layer in self.layers.values(): + if getattr(layer, 'conditional', False): + x = layer(x, *cond_inputs, **kw_cond_inputs) + if getattr(layer, 'multiple_outputs', False): + x, other_output = layer(x) + other_outputs.append(other_output) + else: + x = layer(x) + return (x, *other_outputs) + + +class MultiOutConv2dBlock(_MultiOutBaseConvBlock): + r"""A Wrapper class that wraps ``torch.nn.Conv2d`` with normalization and + nonlinearity. It can return multiple outputs, if some layers in the block + return more than one output. + + Args: + in_channels (int): Number of channels in the input tensor. + out_channels (int): Number of channels in the output tensor. + kernel_size (int or tuple): Size of the convolving kernel. + stride (int or float or tuple, optional, default=1): + Stride of the convolution. + padding (int or tuple, optional, default=0): + Zero-padding added to both sides of the input. + dilation (int or tuple, optional, default=1): + Spacing between kernel elements. + groups (int, optional, default=1): Number of blocked connections + from input channels to output channels. + bias (bool, optional, default=True): + If ``True``, adds a learnable bias to the output. + padding_mode (string, optional, default='zeros'): Type of padding: + ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. + weight_norm_type (str, optional, default='none'): + Type of weight normalization. + ``'none'``, ``'spectral'``, ``'weight'`` + or ``'weight_demod'``. + weight_norm_params (obj, optional, default=None): + Parameters of weight normalization. + If not ``None``, ``weight_norm_params.__dict__`` will be used as + keyword arguments when initializing weight normalization. + activation_norm_type (str, optional, default='none'): + Type of activation normalization. + ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, + ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, + ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. + activation_norm_params (obj, optional, default=None): + Parameters of activation normalization. + If not ``None``, ``activation_norm_params.__dict__`` will be used as + keyword arguments when initializing activation normalization. + nonlinearity (str, optional, default='none'): + Type of nonlinear activation function. + ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, + ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. + inplace_nonlinearity (bool, optional, default=False): If ``True``, + set ``inplace=True`` when initializing the nonlinearity layer. + apply_noise (bool, optional, default=False): If ``True``, adds + Gaussian noise with learnable magnitude to the convolution output. + order (str, optional, default='CNA'): Order of operations. + ``'C'``: convolution, + ``'N'``: normalization, + ``'A'``: nonlinear activation. + For example, a block initialized with ``order='CNA'`` will + do convolution first, then normalization, then nonlinearity. 
+ """ + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True, + padding_mode='zeros', + weight_norm_type='none', weight_norm_params=None, + activation_norm_type='none', activation_norm_params=None, + nonlinearity='none', inplace_nonlinearity=False, + apply_noise=False, blur=False, order='CNA', clamp=None): + super().__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + nonlinearity, inplace_nonlinearity, + apply_noise, blur, order, 2, clamp) + + +############################################################################### +# BSD 3-Clause License +# +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Author & Contact: Guilin Liu (guilinl@nvidia.com) +############################################################################### +class PartialConv2d(nn.Conv2d): + r"""Partial 2D convolution in + "Image inpainting for irregular holes using partial convolutions." + Liu et al., ECCV 2018 + """ + + def __init__(self, *args, multi_channel=False, return_mask=True, **kwargs): + # whether the mask is multi-channel or not + self.multi_channel = multi_channel + self.return_mask = return_mask + super(PartialConv2d, self).__init__(*args, **kwargs) + + if self.multi_channel: + self.weight_maskUpdater = torch.ones(self.out_channels, + self.in_channels, + self.kernel_size[0], + self.kernel_size[1]) + else: + self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0], + self.kernel_size[1]) + + shape = self.weight_maskUpdater.shape + self.slide_winsize = shape[1] * shape[2] * shape[3] + + self.last_size = (None, None, None, None) + self.update_mask = None + self.mask_ratio = None + self.partial_conv = True + + def forward(self, x, mask_in=None): + r""" + + Args: + x (tensor): Input tensor. + mask_in (tensor, optional, default=``None``) If not ``None``, + it masks the valid input region. + """ + assert len(x.shape) == 4 + if mask_in is not None or self.last_size != tuple(x.shape): + self.last_size = tuple(x.shape) + + with torch.no_grad(): + if self.weight_maskUpdater.type() != x.type(): + self.weight_maskUpdater = self.weight_maskUpdater.to(x) + + if mask_in is None: + # If mask is not provided, create a mask. + if self.multi_channel: + mask = torch.ones(x.data.shape[0], + x.data.shape[1], + x.data.shape[2], + x.data.shape[3]).to(x) + else: + mask = torch.ones(1, 1, x.data.shape[2], + x.data.shape[3]).to(x) + else: + mask = mask_in + + self.update_mask = F.conv2d(mask, self.weight_maskUpdater, + bias=None, stride=self.stride, + padding=self.padding, + dilation=self.dilation, groups=1) + + # For mixed precision training, eps from 1e-8 to 1e-6. 
+ eps = 1e-6 + self.mask_ratio = self.slide_winsize / (self.update_mask + eps) + self.update_mask = torch.clamp(self.update_mask, 0, 1) + self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask) + + raw_out = super(PartialConv2d, self).forward( + torch.mul(x, mask) if mask_in is not None else x) + + if self.bias is not None: + bias_view = self.bias.view(1, self.out_channels, 1, 1) + output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view + output = torch.mul(output, self.update_mask) + else: + output = torch.mul(raw_out, self.mask_ratio) + + if self.return_mask: + return output, self.update_mask + else: + return output + + +class PartialConv3d(nn.Conv3d): + r"""Partial 3D convolution in + "Image inpainting for irregular holes using partial convolutions." + Liu et al., ECCV 2018 + """ + + def __init__(self, *args, multi_channel=False, return_mask=True, **kwargs): + # whether the mask is multi-channel or not + self.multi_channel = multi_channel + self.return_mask = return_mask + super(PartialConv3d, self).__init__(*args, **kwargs) + + if self.multi_channel: + self.weight_maskUpdater = \ + torch.ones(self.out_channels, self.in_channels, + self.kernel_size[0], self.kernel_size[1], + self.kernel_size[2]) + else: + self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0], + self.kernel_size[1], + self.kernel_size[2]) + self.weight_maskUpdater = self.weight_maskUpdater.to('cuda') + + shape = self.weight_maskUpdater.shape + self.slide_winsize = shape[1] * shape[2] * shape[3] * shape[4] + self.partial_conv = True + + def forward(self, x, mask_in=None): + r""" + + Args: + x (tensor): Input tensor. + mask_in (tensor, optional, default=``None``) If not ``None``, it + masks the valid input region. + """ + assert len(x.shape) == 5 + + with torch.no_grad(): + mask = mask_in + update_mask = F.conv3d(mask, self.weight_maskUpdater, bias=None, + stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=1) + + mask_ratio = self.slide_winsize / (update_mask + 1e-8) + update_mask = torch.clamp(update_mask, 0, 1) + mask_ratio = torch.mul(mask_ratio, update_mask) + + raw_out = super(PartialConv3d, self).forward(torch.mul(x, mask_in)) + + if self.bias is not None: + bias_view = self.bias.view(1, self.out_channels, 1, 1, 1) + output = torch.mul(raw_out - bias_view, mask_ratio) + bias_view + if mask_in is not None: + output = torch.mul(output, update_mask) + else: + output = torch.mul(raw_out, mask_ratio) + + if self.return_mask: + return output, update_mask + else: + return output + + +class Embedding2d(nn.Embedding): + def __init__(self, in_channels, out_channels): + super().__init__(in_channels, out_channels) + + def forward(self, x): + return F.embedding( + x.squeeze(1).long(), self.weight, self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, self.sparse).permute(0, 3, 1, 2).contiguous() diff --git a/imaginaire/layers/misc.py b/imaginaire/layers/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..7731bd2fa855939a2866b211c33eb3ccce00c480 --- /dev/null +++ b/imaginaire/layers/misc.py @@ -0,0 +1,61 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. 
+# To view a copy of this license, check out LICENSE.md +import torch +from torch import nn + + +class ApplyNoise(nn.Module): + r"""Add Gaussian noise to the input tensor.""" + + def __init__(self): + super().__init__() + # scale of the noise + self.scale = nn.Parameter(torch.zeros(1)) + self.conditional = True + + def forward(self, x, *_args, noise=None, **_kwargs): + r""" + + Args: + x (tensor): Input tensor. + noise (tensor, optional, default=``None``) : Noise tensor to be + added to the input. + """ + if noise is None: + sz = x.size() + noise = x.new_empty(sz[0], 1, *sz[2:]).normal_() + + return x + self.scale * noise + + +class PartialSequential(nn.Sequential): + r"""Sequential block for partial convolutions.""" + def __init__(self, *modules): + super(PartialSequential, self).__init__(*modules) + + def forward(self, x): + r""" + + Args: + x (tensor): Input tensor. + """ + act = x[:, :-1] + mask = x[:, -1].unsqueeze(1) + for module in self: + act, mask = module(act, mask_in=mask) + return act + + +class ConstantInput(nn.Module): + def __init__(self, channel, size=4): + super().__init__() + if isinstance(size, int): + h, w = size, size + else: + h, w = size + self.input = nn.Parameter(torch.randn(1, channel, h, w)) + + def forward(self): + return self.input diff --git a/imaginaire/layers/non_local.py b/imaginaire/layers/non_local.py new file mode 100644 index 0000000000000000000000000000000000000000..adb9676d8e55ed1dcc2f80cef3e2ca195d91ddb9 --- /dev/null +++ b/imaginaire/layers/non_local.py @@ -0,0 +1,88 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +from functools import partial + +import torch +import torch.nn as nn + +from imaginaire.layers import Conv2dBlock + + +class NonLocal2dBlock(nn.Module): + r"""Self attention Layer + + Args: + in_channels (int): Number of channels in the input tensor. + scale (bool, optional, default=True): If ``True``, scale the + output by a learnable parameter. + clamp (bool, optional, default=``False``): If ``True``, clamp the + scaling parameter to (-1, 1). + weight_norm_type (str, optional, default='none'): + Type of weight normalization. + ``'none'``, ``'spectral'``, ``'weight'``. + weight_norm_params (obj, optional, default=None): + Parameters of weight normalization. + If not ``None``, weight_norm_params.__dict__ will be used as + keyword arguments when initializing weight normalization. + bias (bool, optional, default=True): If ``True``, adds bias in the + convolutional blocks. 
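Looking back at ApplyNoise in misc.py above: it is the StyleGAN-style noise injection, where a single spatial noise map shared across channels is added with a learnable magnitude that starts at zero, so the layer is an exact no-op at initialization. A self-contained version of the same idea, with an illustrative class name:

```python
import torch
import torch.nn as nn

class NoiseInjection(nn.Module):
    # Same idea as ApplyNoise: learnable scalar scale, initialised to zero.
    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.zeros(1))

    def forward(self, x, noise=None):
        if noise is None:
            b, _, h, w = x.shape
            noise = x.new_empty(b, 1, h, w).normal_()
        return x + self.scale * noise

x = torch.randn(2, 16, 8, 8)
layer = NoiseInjection()
print(torch.allclose(layer(x), x))   # True at init, since scale == 0
```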
+ """ + + def __init__(self, + in_channels, + scale=True, + clamp=False, + weight_norm_type='none', + weight_norm_params=None, + bias=True): + super(NonLocal2dBlock, self).__init__() + self.clamp = clamp + self.gamma = nn.Parameter(torch.zeros(1)) if scale else 1.0 + self.in_channels = in_channels + base_conv2d_block = partial(Conv2dBlock, + kernel_size=1, + stride=1, + padding=0, + weight_norm_type=weight_norm_type, + weight_norm_params=weight_norm_params, + bias=bias) + self.theta = base_conv2d_block(in_channels, in_channels // 8) + self.phi = base_conv2d_block(in_channels, in_channels // 8) + self.g = base_conv2d_block(in_channels, in_channels // 2) + self.out_conv = base_conv2d_block(in_channels // 2, in_channels) + self.softmax = nn.Softmax(dim=-1) + self.max_pool = nn.MaxPool2d(2) + + def forward(self, x): + r""" + + Args: + x (tensor) : input feature maps (B X C X W X H) + Returns: + (tuple): + - out (tensor) : self attention value + input feature + - attention (tensor): B x N x N (N is Width*Height) + """ + n, c, h, w = x.size() + theta = self.theta(x).view(n, -1, h * w).permute(0, 2, 1).contiguous() + + phi = self.phi(x) + phi = self.max_pool(phi).view(n, -1, h * w // 4) + + energy = torch.bmm(theta, phi) + attention = self.softmax(energy) + + g = self.g(x) + g = self.max_pool(g).view(n, -1, h * w // 4) + + out = torch.bmm(g, attention.permute(0, 2, 1).contiguous()) + out = out.view(n, c // 2, h, w) + out = self.out_conv(out) + + if self.clamp: + out = self.gamma.clamp(-1, 1) * out + x + else: + out = self.gamma * out + x + return out diff --git a/imaginaire/layers/nonlinearity.py b/imaginaire/layers/nonlinearity.py new file mode 100644 index 0000000000000000000000000000000000000000..5fc172c74323e707e5a19f94e466c1bf0dae4418 --- /dev/null +++ b/imaginaire/layers/nonlinearity.py @@ -0,0 +1,65 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import torch +from torch import nn +import torch.nn.functional as F + +from imaginaire.third_party.bias_act.bias_act import FusedNonlinearity + + +class ScaledLeakyReLU(nn.Module): + def __init__(self, negative_slope=0.2, scale=2 ** 0.5, inplace=False): + super().__init__() + + self.negative_slope = negative_slope + self.scale = scale + self.inplace = inplace + + def forward(self, x): + return F.leaky_relu(x, self.negative_slope, inplace=self.inplace) * self.scale + # return _fused_scaled_leakyrelu(x, self.negative_slope, self.inplace, self.scale) + + +# @torch.jit.script +# def _fused_scaled_leakyrelu(x: torch.Tensor, negative_slope: float, inplace: bool, scale: float): +# return F.leaky_relu(x, negative_slope, inplace=inplace) * scale + + +def get_nonlinearity_layer(nonlinearity_type, inplace, **kwargs): + r"""Return a nonlinearity layer. + + Args: + nonlinearity_type (str): + Type of nonlinear activation function. + ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, + ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. + inplace (bool): If ``True``, set ``inplace=True`` when initializing + the nonlinearity layer. 
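A small sketch of how the factory above is typically called, assuming the module imports cleanly (it pulls in the fused bias-act extension at import time). The 'softmax' variant accepts an optional dimension after a comma, and 'none' yields no layer at all.

import torch
from imaginaire.layers.nonlinearity import get_nonlinearity_layer

act = get_nonlinearity_layer('leakyrelu', inplace=False)   # nn.LeakyReLU(0.2)
soft = get_nonlinearity_layer('softmax,2', inplace=False)  # nn.Softmax(dim=2)
none = get_nonlinearity_layer('none', inplace=False)       # returns None

x = torch.randn(4, 8, 16)
y = act(x)
p = soft(x)                                                # sums to 1 along dim 2
assert none is None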
+ """ + if nonlinearity_type.startswith('fused'): + nonlinearity = FusedNonlinearity(nonlinearity=nonlinearity_type[6:], **kwargs) + elif nonlinearity_type == 'relu': + nonlinearity = nn.ReLU(inplace=inplace) + elif nonlinearity_type == 'leakyrelu': + nonlinearity = nn.LeakyReLU(0.2, inplace=inplace) + elif nonlinearity_type == 'scaled_leakyrelu': + nonlinearity = ScaledLeakyReLU(0.2, inplace=inplace) + import imaginaire.config + if imaginaire.config.USE_JIT: + nonlinearity = torch.jit.script(nonlinearity) + elif nonlinearity_type == 'prelu': + nonlinearity = nn.PReLU() + elif nonlinearity_type == 'tanh': + nonlinearity = nn.Tanh() + elif nonlinearity_type == 'sigmoid': + nonlinearity = nn.Sigmoid() + elif nonlinearity_type.startswith('softmax'): + dim = nonlinearity_type.split(',')[1] if ',' in nonlinearity_type else 1 + nonlinearity = nn.Softmax(dim=int(dim)) + elif nonlinearity_type == 'none' or nonlinearity_type == '': + nonlinearity = None + else: + raise ValueError('Nonlinearity %s is not recognized' % nonlinearity_type) + return nonlinearity diff --git a/imaginaire/layers/residual.py b/imaginaire/layers/residual.py new file mode 100644 index 0000000000000000000000000000000000000000..5e1bda4dd30922f694302803b7d606af7f3c0c21 --- /dev/null +++ b/imaginaire/layers/residual.py @@ -0,0 +1,1411 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import functools + +import torch +from torch import nn +from torch.nn import Upsample as NearestUpsample +from torch.utils.checkpoint import checkpoint + +from .conv import (Conv1dBlock, Conv2dBlock, Conv3dBlock, HyperConv2dBlock, + LinearBlock, MultiOutConv2dBlock, PartialConv2dBlock, + PartialConv3dBlock, ModulatedConv2dBlock) +from imaginaire.third_party.upfirdn2d.upfirdn2d import BlurUpsample + + +class _BaseResBlock(nn.Module): + r"""An abstract class for residual blocks. + """ + + def __init__(self, in_channels, out_channels, kernel_size, + stride, padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + skip_activation_norm, skip_nonlinearity, + nonlinearity, inplace_nonlinearity, apply_noise, + hidden_channels_equal_out_channels, + order, block, learn_shortcut, clamp, output_scale, + skip_block=None, blur=False, upsample_first=True, skip_weight_norm=True): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.output_scale = output_scale + self.upsample_first = upsample_first + self.stride = stride + self.blur = blur + if skip_block is None: + skip_block = block + + if order == 'pre_act': + order = 'NACNAC' + if isinstance(bias, bool): + # The bias for conv_block_0, conv_block_1, and conv_block_s. + biases = [bias, bias, bias] + elif isinstance(bias, list): + if len(bias) == 3: + biases = bias + else: + raise ValueError('Bias list must be 3.') + else: + raise ValueError('Bias must be either an integer or s list.') + if learn_shortcut is None: + self.learn_shortcut = (in_channels != out_channels) + else: + self.learn_shortcut = learn_shortcut + if len(order) > 6 or len(order) < 5: + raise ValueError('order must be either 5 or 6 characters') + if hidden_channels_equal_out_channels: + hidden_channels = out_channels + else: + hidden_channels = min(in_channels, out_channels) + + # Parameters. 
+ residual_params = {} + shortcut_params = {} + base_params = dict(dilation=dilation, + groups=groups, + padding_mode=padding_mode, + clamp=clamp) + residual_params.update(base_params) + residual_params.update( + dict(activation_norm_type=activation_norm_type, + activation_norm_params=activation_norm_params, + weight_norm_type=weight_norm_type, + weight_norm_params=weight_norm_params, + padding=padding, + apply_noise=apply_noise)) + shortcut_params.update(base_params) + shortcut_params.update(dict(kernel_size=1)) + if skip_activation_norm: + shortcut_params.update( + dict(activation_norm_type=activation_norm_type, + activation_norm_params=activation_norm_params, + apply_noise=False)) + if skip_weight_norm: + shortcut_params.update( + dict(weight_norm_type=weight_norm_type, + weight_norm_params=weight_norm_params)) + + # Residual branch. + if order.find('A') < order.find('C') and \ + (activation_norm_type == '' or activation_norm_type == 'none'): + # Nonlinearity is the first operation in the residual path. + # In-place nonlinearity will modify the input variable and cause + # backward error. + first_inplace = False + else: + first_inplace = inplace_nonlinearity + + (first_stride, second_stride, shortcut_stride, + first_blur, second_blur, shortcut_blur) = self._get_stride_blur() + self.conv_block_0 = block( + in_channels, hidden_channels, + kernel_size=kernel_size, + bias=biases[0], + nonlinearity=nonlinearity, + order=order[0:3], + inplace_nonlinearity=first_inplace, + stride=first_stride, + blur=first_blur, + **residual_params + ) + self.conv_block_1 = block( + hidden_channels, out_channels, + kernel_size=kernel_size, + bias=biases[1], + nonlinearity=nonlinearity, + order=order[3:], + inplace_nonlinearity=inplace_nonlinearity, + stride=second_stride, + blur=second_blur, + **residual_params + ) + + # Shortcut branch. + if self.learn_shortcut: + if skip_nonlinearity: + skip_nonlinearity_type = nonlinearity + else: + skip_nonlinearity_type = '' + self.conv_block_s = skip_block(in_channels, out_channels, + bias=biases[2], + nonlinearity=skip_nonlinearity_type, + order=order[0:3], + stride=shortcut_stride, + blur=shortcut_blur, + **shortcut_params) + elif in_channels < out_channels: + if skip_nonlinearity: + skip_nonlinearity_type = nonlinearity + else: + skip_nonlinearity_type = '' + self.conv_block_s = skip_block(in_channels, + out_channels - in_channels, + bias=biases[2], + nonlinearity=skip_nonlinearity_type, + order=order[0:3], + stride=shortcut_stride, + blur=shortcut_blur, + **shortcut_params) + + # Whether this block expects conditional inputs. + self.conditional = \ + getattr(self.conv_block_0, 'conditional', False) or \ + getattr(self.conv_block_1, 'conditional', False) + + def _get_stride_blur(self): + if self.stride > 1: + # Downsampling. + first_stride, second_stride = 1, self.stride + first_blur, second_blur = False, self.blur + shortcut_stride = self.stride + shortcut_blur = self.blur + self.upsample = None + elif self.stride < 1: + # Upsampling. 
+ first_stride, second_stride = self.stride, 1 + first_blur, second_blur = self.blur, False + shortcut_blur = False + shortcut_stride = 1 + if self.blur: + # The shortcut branch uses blur_upsample + stride-1 conv + self.upsample = BlurUpsample() + else: + shortcut_stride = self.stride + self.upsample = nn.Upsample(scale_factor=2) + else: + first_stride = second_stride = 1 + first_blur = second_blur = False + shortcut_stride = 1 + shortcut_blur = False + self.upsample = None + return (first_stride, second_stride, shortcut_stride, + first_blur, second_blur, shortcut_blur) + + def conv_blocks( + self, x, *cond_inputs, separate_cond=False, **kw_cond_inputs + ): + r"""Returns the output of the residual branch. + + Args: + x (tensor): Input tensor. + cond_inputs (list of tensors) : Conditional input tensors. + kw_cond_inputs (dict) : Keyword conditional inputs. + Returns: + dx (tensor): Output tensor. + """ + if separate_cond: + dx = self.conv_block_0(x, cond_inputs[0], + **kw_cond_inputs.get('kwargs_0', {})) + dx = self.conv_block_1(dx, cond_inputs[1], + **kw_cond_inputs.get('kwargs_1', {})) + else: + dx = self.conv_block_0(x, *cond_inputs, **kw_cond_inputs) + dx = self.conv_block_1(dx, *cond_inputs, **kw_cond_inputs) + return dx + + def forward(self, x, *cond_inputs, do_checkpoint=False, separate_cond=False, + **kw_cond_inputs): + r""" + + Args: + x (tensor): Input tensor. + cond_inputs (list of tensors) : Conditional input tensors. + do_checkpoint (bool, optional, default=``False``) If ``True``, + trade compute for memory by checkpointing the model. + kw_cond_inputs (dict) : Keyword conditional inputs. + Returns: + output (tensor): Output tensor. + """ + if do_checkpoint: + dx = checkpoint(self.conv_blocks, x, *cond_inputs, + separate_cond=separate_cond, **kw_cond_inputs) + else: + dx = self.conv_blocks(x, *cond_inputs, + separate_cond=separate_cond, **kw_cond_inputs) + + if self.upsample_first and self.upsample is not None: + x = self.upsample(x) + if self.learn_shortcut: + if separate_cond: + x_shortcut = self.conv_block_s( + x, cond_inputs[2], **kw_cond_inputs.get('kwargs_2', {}) + ) + else: + x_shortcut = self.conv_block_s( + x, *cond_inputs, **kw_cond_inputs + ) + elif self.in_channels < self.out_channels: + if separate_cond: + x_shortcut_pad = self.conv_block_s( + x, cond_inputs[2], **kw_cond_inputs.get('kwargs_2', {}) + ) + else: + x_shortcut_pad = self.conv_block_s( + x, *cond_inputs, **kw_cond_inputs + ) + x_shortcut = torch.cat((x, x_shortcut_pad), dim=1) + elif self.in_channels > self.out_channels: + x_shortcut = x[:, :self.out_channels, :, :] + else: + x_shortcut = x + if not self.upsample_first and self.upsample is not None: + x_shortcut = self.upsample(x_shortcut) + + output = x_shortcut + dx + return self.output_scale * output + + def extra_repr(self): + s = 'output_scale={output_scale}' + return s.format(**self.__dict__) + + +class ModulatedRes2dBlock(_BaseResBlock): + def __init__(self, in_channels, out_channels, style_dim, kernel_size=3, + stride=1, padding=1, dilation=1, groups=1, bias=True, + padding_mode='zeros', + weight_norm_type='none', weight_norm_params=None, + activation_norm_type='none', activation_norm_params=None, + skip_activation_norm=True, skip_nonlinearity=False, + nonlinearity='leakyrelu', inplace_nonlinearity=False, + apply_noise=True, hidden_channels_equal_out_channels=False, + order='CNACNA', learn_shortcut=None, clamp=None, output_scale=1, + demodulate=True, eps=1e-8): + block = functools.partial(ModulatedConv2dBlock, + style_dim=style_dim, + 
demodulate=demodulate, eps=eps) + skip_block = Conv2dBlock + super().__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + skip_activation_norm, skip_nonlinearity, nonlinearity, + inplace_nonlinearity, apply_noise, + hidden_channels_equal_out_channels, order, block, + learn_shortcut, clamp, output_scale, skip_block=skip_block) + + def conv_blocks(self, x, *cond_inputs, **kw_cond_inputs): + assert len(list(cond_inputs)) == 2 + dx = self.conv_block_0(x, cond_inputs[0], **kw_cond_inputs) + dx = self.conv_block_1(dx, cond_inputs[1], **kw_cond_inputs) + return dx + + +class ResLinearBlock(_BaseResBlock): + r"""Residual block with full-connected layers. + + Args: + in_channels (int) : Number of channels in the input tensor. + out_channels (int) : Number of channels in the output tensor. + weight_norm_type (str, optional, default='none'): + Type of weight normalization. + ``'none'``, ``'spectral'``, ``'weight'`` + or ``'weight_demod'``. + weight_norm_params (obj, optional, default=None): + Parameters of weight normalization. + If not ``None``, ``weight_norm_params.__dict__`` will be used as + keyword arguments when initializing weight normalization. + activation_norm_type (str, optional, default='none'): + Type of activation normalization. + ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, + ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, + ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. + activation_norm_params (obj, optional, default=None): + Parameters of activation normalization. + If not ``None``, ``activation_norm_params.__dict__`` will be used as + keyword arguments when initializing activation normalization. + skip_activation_norm (bool, optional, default=True): If ``True`` and + ``learn_shortcut`` is also ``True``, applies activation norm to the + learned shortcut connection. + skip_nonlinearity (bool, optional, default=True): If ``True`` and + ``learn_shortcut`` is also ``True``, applies nonlinearity to the + learned shortcut connection. + nonlinearity (str, optional, default='none'): + Type of nonlinear activation function in the residual link. + ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, + ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. + inplace_nonlinearity (bool, optional, default=False): If ``True``, + set ``inplace=True`` when initializing the nonlinearity layers. + apply_noise (bool, optional, default=False): If ``True``, add + Gaussian noise with learnable magnitude after the + fully-connected layer. + hidden_channels_equal_out_channels (bool, optional, default=False): + If ``True``, set the hidden channel number to be equal to the + output channel number. If ``False``, the hidden channel number + equals to the smaller of the input channel number and the + output channel number. + order (str, optional, default='CNACNA'): Order of operations + in the residual link. + ``'C'``: fully-connected, + ``'N'``: normalization, + ``'A'``: nonlinear activation. + learn_shortcut (bool, optional, default=False): If ``True``, always use + a convolutional shortcut instead of an identity one, otherwise only + use a convolutional one if input and output have different number of + channels. 
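ResLinearBlock wires the residual skeleton described above around fully-connected LinearBlock layers instead of convolutions. A minimal sketch, assuming the imaginaire layer package is importable; the feature sizes are illustrative.

import torch
from imaginaire.layers.residual import ResLinearBlock

block = ResLinearBlock(in_channels=256, out_channels=256,
                       nonlinearity='relu', order='CNACNA')
x = torch.randn(8, 256)          # (batch, features)
y = block(x)                     # residual output, same shape as the input
assert y.shape == (8, 256)

With equal input and output widths the shortcut is the identity; with differing widths a learned linear shortcut is created automatically, following the learn_shortcut logic of the base class.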
+ """ + + def __init__(self, in_channels, out_channels, bias=True, + weight_norm_type='none', weight_norm_params=None, + activation_norm_type='none', activation_norm_params=None, + skip_activation_norm=True, skip_nonlinearity=False, + nonlinearity='leakyrelu', inplace_nonlinearity=False, + apply_noise=False, hidden_channels_equal_out_channels=False, + order='CNACNA', learn_shortcut=None, clamp=None, + output_scale=1): + super().__init__(in_channels, out_channels, None, 1, None, None, + None, bias, None, weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + skip_activation_norm, skip_nonlinearity, nonlinearity, + inplace_nonlinearity, apply_noise, + hidden_channels_equal_out_channels, order, LinearBlock, + learn_shortcut, clamp, output_scale) + + +class Res1dBlock(_BaseResBlock): + r"""Residual block for 1D input. + + Args: + in_channels (int) : Number of channels in the input tensor. + out_channels (int) : Number of channels in the output tensor. + kernel_size (int, optional, default=3): Kernel size for the + convolutional filters in the residual link. + padding (int, optional, default=1): Padding size. + dilation (int, optional, default=1): Dilation factor. + groups (int, optional, default=1): Number of convolutional/linear + groups. + padding_mode (string, optional, default='zeros'): Type of padding: + ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. + weight_norm_type (str, optional, default='none'): + Type of weight normalization. + ``'none'``, ``'spectral'``, ``'weight'`` + or ``'weight_demod'``. + weight_norm_params (obj, optional, default=None): + Parameters of weight normalization. + If not ``None``, ``weight_norm_params.__dict__`` will be used as + keyword arguments when initializing weight normalization. + activation_norm_type (str, optional, default='none'): + Type of activation normalization. + ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, + ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, + ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. + activation_norm_params (obj, optional, default=None): + Parameters of activation normalization. + If not ``None``, ``activation_norm_params.__dict__`` will be used as + keyword arguments when initializing activation normalization. + skip_activation_norm (bool, optional, default=True): If ``True`` and + ``learn_shortcut`` is also ``True``, applies activation norm to the + learned shortcut connection. + skip_nonlinearity (bool, optional, default=True): If ``True`` and + ``learn_shortcut`` is also ``True``, applies nonlinearity to the + learned shortcut connection. + nonlinearity (str, optional, default='none'): + Type of nonlinear activation function in the residual link. + ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, + ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. + inplace_nonlinearity (bool, optional, default=False): If ``True``, + set ``inplace=True`` when initializing the nonlinearity layers. + apply_noise (bool, optional, default=False): If ``True``, adds + Gaussian noise with learnable magnitude to the convolution output. + hidden_channels_equal_out_channels (bool, optional, default=False): + If ``True``, set the hidden channel number to be equal to the + output channel number. If ``False``, the hidden channel number + equals to the smaller of the input channel number and the + output channel number. + order (str, optional, default='CNACNA'): Order of operations + in the residual link. 
+ ``'C'``: convolution, + ``'N'``: normalization, + ``'A'``: nonlinear activation. + learn_shortcut (bool, optional, default=False): If ``True``, always use + a convolutional shortcut instead of an identity one, otherwise only + use a convolutional one if input and output have different number of + channels. + """ + + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding=1, dilation=1, groups=1, bias=True, + padding_mode='zeros', + weight_norm_type='none', weight_norm_params=None, + activation_norm_type='none', activation_norm_params=None, + skip_activation_norm=True, skip_nonlinearity=False, + nonlinearity='leakyrelu', inplace_nonlinearity=False, + apply_noise=False, hidden_channels_equal_out_channels=False, + order='CNACNA', learn_shortcut=None, clamp=None, + output_scale=1): + super().__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + skip_activation_norm, skip_nonlinearity, nonlinearity, + inplace_nonlinearity, apply_noise, + hidden_channels_equal_out_channels, order, Conv1dBlock, + learn_shortcut, clamp, output_scale) + + +class Res2dBlock(_BaseResBlock): + r"""Residual block for 2D input. + + Args: + in_channels (int) : Number of channels in the input tensor. + out_channels (int) : Number of channels in the output tensor. + kernel_size (int, optional, default=3): Kernel size for the + convolutional filters in the residual link. + padding (int, optional, default=1): Padding size. + dilation (int, optional, default=1): Dilation factor. + groups (int, optional, default=1): Number of convolutional/linear + groups. + padding_mode (string, optional, default='zeros'): Type of padding: + ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. + weight_norm_type (str, optional, default='none'): + Type of weight normalization. + ``'none'``, ``'spectral'``, ``'weight'`` + or ``'weight_demod'``. + weight_norm_params (obj, optional, default=None): + Parameters of weight normalization. + If not ``None``, ``weight_norm_params.__dict__`` will be used as + keyword arguments when initializing weight normalization. + activation_norm_type (str, optional, default='none'): + Type of activation normalization. + ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, + ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, + ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. + activation_norm_params (obj, optional, default=None): + Parameters of activation normalization. + If not ``None``, ``activation_norm_params.__dict__`` will be used as + keyword arguments when initializing activation normalization. + skip_activation_norm (bool, optional, default=True): If ``True`` and + ``learn_shortcut`` is also ``True``, applies activation norm to the + learned shortcut connection. + skip_nonlinearity (bool, optional, default=True): If ``True`` and + ``learn_shortcut`` is also ``True``, applies nonlinearity to the + learned shortcut connection. + nonlinearity (str, optional, default='none'): + Type of nonlinear activation function in the residual link. + ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, + ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. + inplace_nonlinearity (bool, optional, default=False): If ``True``, + set ``inplace=True`` when initializing the nonlinearity layers. + apply_noise (bool, optional, default=False): If ``True``, adds + Gaussian noise with learnable magnitude to the convolution output. 
+ hidden_channels_equal_out_channels (bool, optional, default=False): + If ``True``, set the hidden channel number to be equal to the + output channel number. If ``False``, the hidden channel number + equals to the smaller of the input channel number and the + output channel number. + order (str, optional, default='CNACNA'): Order of operations + in the residual link. + ``'C'``: convolution, + ``'N'``: normalization, + ``'A'``: nonlinear activation. + learn_shortcut (bool, optional, default=False): If ``True``, always use + a convolutional shortcut instead of an identity one, otherwise only + use a convolutional one if input and output have different number of + channels. + """ + + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding=1, dilation=1, groups=1, bias=True, + padding_mode='zeros', + weight_norm_type='none', weight_norm_params=None, + activation_norm_type='none', activation_norm_params=None, + skip_activation_norm=True, skip_nonlinearity=False, + skip_weight_norm=True, + nonlinearity='leakyrelu', inplace_nonlinearity=False, + apply_noise=False, hidden_channels_equal_out_channels=False, + order='CNACNA', learn_shortcut=None, clamp=None, + output_scale=1, blur=False, upsample_first=True): + super().__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + skip_activation_norm, skip_nonlinearity, nonlinearity, + inplace_nonlinearity, apply_noise, + hidden_channels_equal_out_channels, order, Conv2dBlock, + learn_shortcut, clamp, output_scale, blur=blur, + upsample_first=upsample_first, + skip_weight_norm=skip_weight_norm) + + +class Res3dBlock(_BaseResBlock): + r"""Residual block for 3D input. + + Args: + in_channels (int) : Number of channels in the input tensor. + out_channels (int) : Number of channels in the output tensor. + kernel_size (int, optional, default=3): Kernel size for the + convolutional filters in the residual link. + padding (int, optional, default=1): Padding size. + dilation (int, optional, default=1): Dilation factor. + groups (int, optional, default=1): Number of convolutional/linear + groups. + padding_mode (string, optional, default='zeros'): Type of padding: + ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. + weight_norm_type (str, optional, default='none'): + Type of weight normalization. + ``'none'``, ``'spectral'``, ``'weight'`` + or ``'weight_demod'``. + weight_norm_params (obj, optional, default=None): + Parameters of weight normalization. + If not ``None``, ``weight_norm_params.__dict__`` will be used as + keyword arguments when initializing weight normalization. + activation_norm_type (str, optional, default='none'): + Type of activation normalization. + ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, + ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, + ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. + activation_norm_params (obj, optional, default=None): + Parameters of activation normalization. + If not ``None``, ``activation_norm_params.__dict__`` will be used as + keyword arguments when initializing activation normalization. + skip_activation_norm (bool, optional, default=True): If ``True`` and + ``learn_shortcut`` is also ``True``, applies activation norm to the + learned shortcut connection. 
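Res2dBlock, defined just above, is the workhorse 2-D variant; Res3dBlock, documented here, follows the same pattern for 5-D inputs. A brief sketch, assuming the package is importable; the normalization and channel choices are illustrative.

import torch
from imaginaire.layers.residual import Res2dBlock

block = Res2dBlock(in_channels=64, out_channels=128,
                   weight_norm_type='spectral',
                   activation_norm_type='instance',
                   nonlinearity='leakyrelu')
x = torch.randn(2, 64, 32, 32)
y = block(x)                     # a learned 1x1 shortcut handles 64 -> 128
assert y.shape == (2, 128, 32, 32)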
+ skip_nonlinearity (bool, optional, default=True): If ``True`` and + ``learn_shortcut`` is also ``True``, applies nonlinearity to the + learned shortcut connection. + nonlinearity (str, optional, default='none'): + Type of nonlinear activation function in the residual link. + ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, + ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. + inplace_nonlinearity (bool, optional, default=False): If ``True``, + set ``inplace=True`` when initializing the nonlinearity layers. + apply_noise (bool, optional, default=False): If ``True``, adds + Gaussian noise with learnable magnitude to the convolution output. + hidden_channels_equal_out_channels (bool, optional, default=False): + If ``True``, set the hidden channel number to be equal to the + output channel number. If ``False``, the hidden channel number + equals to the smaller of the input channel number and the + output channel number. + order (str, optional, default='CNACNA'): Order of operations + in the residual link. + ``'C'``: convolution, + ``'N'``: normalization, + ``'A'``: nonlinear activation. + learn_shortcut (bool, optional, default=False): If ``True``, always use + a convolutional shortcut instead of an identity one, otherwise only + use a convolutional one if input and output have different number of + channels. + """ + + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding=1, dilation=1, groups=1, bias=True, + padding_mode='zeros', + weight_norm_type='none', weight_norm_params=None, + activation_norm_type='none', activation_norm_params=None, + skip_activation_norm=True, skip_nonlinearity=False, + nonlinearity='leakyrelu', inplace_nonlinearity=False, + apply_noise=False, hidden_channels_equal_out_channels=False, + order='CNACNA', learn_shortcut=None, clamp=None, + output_scale=1): + super().__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + skip_activation_norm, skip_nonlinearity, nonlinearity, + inplace_nonlinearity, apply_noise, + hidden_channels_equal_out_channels, order, Conv3dBlock, + learn_shortcut, clamp, output_scale) + + +class _BaseHyperResBlock(_BaseResBlock): + r"""An abstract class for hyper residual blocks. + """ + + def __init__(self, in_channels, out_channels, kernel_size, + stride, padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + skip_activation_norm, skip_nonlinearity, + nonlinearity, inplace_nonlinearity, apply_noise, + hidden_channels_equal_out_channels, + order, is_hyper_conv, is_hyper_norm, block, learn_shortcut, + clamp=None, output_scale=1): + block = functools.partial(block, + is_hyper_conv=is_hyper_conv, + is_hyper_norm=is_hyper_norm) + super().__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + skip_activation_norm, skip_nonlinearity, nonlinearity, + inplace_nonlinearity, apply_noise, + hidden_channels_equal_out_channels, order, block, + learn_shortcut, clamp, output_scale) + + def forward(self, x, *cond_inputs, conv_weights=(None,) * 3, + norm_weights=(None,) * 3, **kw_cond_inputs): + r""" + + Args: + x (tensor): Input tensor. + cond_inputs (list of tensors) : Conditional input tensors. 
+ conv_weights (list of tensors): Convolution weights for + three convolutional layers respectively. + norm_weights (list of tensors): Normalization weights for + three convolutional layers respectively. + kw_cond_inputs (dict) : Keyword conditional inputs. + Returns: + output (tensor): Output tensor. + """ + dx = self.conv_block_0(x, *cond_inputs, conv_weights=conv_weights[0], + norm_weights=norm_weights[0]) + dx = self.conv_block_1(dx, *cond_inputs, conv_weights=conv_weights[1], + norm_weights=norm_weights[1]) + if self.learn_shortcut: + x_shortcut = self.conv_block_s(x, *cond_inputs, + conv_weights=conv_weights[2], + norm_weights=norm_weights[2]) + else: + x_shortcut = x + output = x_shortcut + dx + return self.output_scale * output + + +class HyperRes2dBlock(_BaseHyperResBlock): + r"""Hyper residual block for 2D input. + + Args: + in_channels (int) : Number of channels in the input tensor. + out_channels (int) : Number of channels in the output tensor. + kernel_size (int, optional, default=3): Kernel size for the + convolutional filters in the residual link. + padding (int, optional, default=1): Padding size. + dilation (int, optional, default=1): Dilation factor. + groups (int, optional, default=1): Number of convolutional/linear + groups. + padding_mode (string, optional, default='zeros'): Type of padding: + ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. + weight_norm_type (str, optional, default='none'): + Type of weight normalization. + ``'none'``, ``'spectral'``, ``'weight'`` + or ``'weight_demod'``. + weight_norm_params (obj, optional, default=None): + Parameters of weight normalization. + If not ``None``, ``weight_norm_params.__dict__`` will be used as + keyword arguments when initializing weight normalization. + activation_norm_type (str, optional, default='none'): + Type of activation normalization. + ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, + ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, + ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. + activation_norm_params (obj, optional, default=None): + Parameters of activation normalization. + If not ``None``, ``activation_norm_params.__dict__`` will be used as + keyword arguments when initializing activation normalization. + skip_activation_norm (bool, optional, default=True): If ``True`` and + ``learn_shortcut`` is also ``True``, applies activation norm to the + learned shortcut connection. + skip_nonlinearity (bool, optional, default=True): If ``True`` and + ``learn_shortcut`` is also ``True``, applies nonlinearity to the + learned shortcut connection. + nonlinearity (str, optional, default='none'): + Type of nonlinear activation function in the residual link. + ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, + ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. + inplace_nonlinearity (bool, optional, default=False): If ``True``, + set ``inplace=True`` when initializing the nonlinearity layers. + apply_noise (bool, optional, default=False): If ``True``, adds + Gaussian noise with learnable magnitude to the convolution output. + hidden_channels_equal_out_channels (bool, optional, default=False): + If ``True``, set the hidden channel number to be equal to the + output channel number. If ``False``, the hidden channel number + equals to the smaller of the input channel number and the + output channel number. + order (str, optional, default='CNACNA'): Order of operations + in the residual link. + ``'C'``: convolution, + ``'N'``: normalization, + ``'A'``: nonlinear activation. 
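HyperRes2dBlock keeps the same two-convolution residual layout but can consume externally predicted convolution and normalization weights, following the forward signature above. A sketch of the calling convention only: HyperConv2dBlock lives in conv.py (not shown here), so treat the weight tensors below as placeholders from a hypothetical weight generator.

import torch
from imaginaire.layers.residual import HyperRes2dBlock

block = HyperRes2dBlock(in_channels=64, out_channels=64,
                        activation_norm_type='instance',
                        is_hyper_conv=False, is_hyper_norm=False)
x = torch.randn(1, 64, 32, 32)

# With the defaults no external weights are supplied and the block behaves
# like a plain residual block.
y = block(x)

# When a hypernetwork predicts per-layer weights, one entry is passed per
# convolution (residual conv 0, residual conv 1, shortcut):
# w0, w1, w2 = hypernetwork(...)              # hypothetical weight generator
# y = block(x, conv_weights=(w0, w1, w2))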
+ is_hyper_conv (bool, optional, default=False): If ``True``, use + ``HyperConv2d``, otherwise use ``torch.nn.Conv2d``. + is_hyper_norm (bool, optional, default=False): If ``True``, use + hyper normalizations. + learn_shortcut (bool, optional, default=False): If ``True``, always use + a convolutional shortcut instead of an identity one, otherwise only + use a convolutional one if input and output have different number of + channels. + """ + + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding=1, dilation=1, groups=1, bias=True, + padding_mode='zeros', + weight_norm_type='', weight_norm_params=None, + activation_norm_type='', activation_norm_params=None, + skip_activation_norm=True, skip_nonlinearity=False, + nonlinearity='leakyrelu', inplace_nonlinearity=False, + apply_noise=False, hidden_channels_equal_out_channels=False, + order='CNACNA', is_hyper_conv=False, is_hyper_norm=False, + learn_shortcut=None, clamp=None, output_scale=1): + super().__init__(in_channels, out_channels, kernel_size, + stride, padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + skip_activation_norm, skip_nonlinearity, + nonlinearity, inplace_nonlinearity, apply_noise, + hidden_channels_equal_out_channels, + order, is_hyper_conv, is_hyper_norm, + HyperConv2dBlock, learn_shortcut, clamp, output_scale) + + +class _BaseDownResBlock(_BaseResBlock): + r"""An abstract class for residual blocks with downsampling. + """ + + def __init__(self, in_channels, out_channels, kernel_size, + stride, padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + skip_activation_norm, skip_nonlinearity, + nonlinearity, inplace_nonlinearity, + apply_noise, hidden_channels_equal_out_channels, + order, block, pooling, down_factor, learn_shortcut, + clamp=None, output_scale=1): + super().__init__(in_channels, out_channels, kernel_size, + stride, padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + skip_activation_norm, skip_nonlinearity, nonlinearity, + inplace_nonlinearity, apply_noise, + hidden_channels_equal_out_channels, order, block, + learn_shortcut, clamp, output_scale) + self.pooling = pooling(down_factor) + + def forward(self, x, *cond_inputs): + r""" + + Args: + x (tensor) : Input tensor. + cond_inputs (list of tensors) : conditional input. + Returns: + output (tensor) : Output tensor. + """ + dx = self.conv_block_0(x, *cond_inputs) + dx = self.conv_block_1(dx, *cond_inputs) + dx = self.pooling(dx) + if self.learn_shortcut: + x_shortcut = self.conv_block_s(x, *cond_inputs) + else: + x_shortcut = x + x_shortcut = self.pooling(x_shortcut) + output = x_shortcut + dx + return self.output_scale * output + + +class DownRes2dBlock(_BaseDownResBlock): + r"""Residual block for 2D input with downsampling. + + Args: + in_channels (int) : Number of channels in the input tensor. + out_channels (int) : Number of channels in the output tensor. + kernel_size (int, optional, default=3): Kernel size for the + convolutional filters in the residual link. + padding (int, optional, default=1): Padding size. + dilation (int, optional, default=1): Dilation factor. + groups (int, optional, default=1): Number of convolutional/linear + groups. + padding_mode (string, optional, default='zeros'): Type of padding: + ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. 
+ weight_norm_type (str, optional, default='none'): + Type of weight normalization. + ``'none'``, ``'spectral'``, ``'weight'`` + or ``'weight_demod'``. + weight_norm_params (obj, optional, default=None): + Parameters of weight normalization. + If not ``None``, ``weight_norm_params.__dict__`` will be used as + keyword arguments when initializing weight normalization. + activation_norm_type (str, optional, default='none'): + Type of activation normalization. + ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, + ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, + ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. + activation_norm_params (obj, optional, default=None): + Parameters of activation normalization. + If not ``None``, ``activation_norm_params.__dict__`` will be used as + keyword arguments when initializing activation normalization. + skip_activation_norm (bool, optional, default=True): If ``True`` and + ``learn_shortcut`` is also ``True``, applies activation norm to the + learned shortcut connection. + skip_nonlinearity (bool, optional, default=True): If ``True`` and + ``learn_shortcut`` is also ``True``, applies nonlinearity to the + learned shortcut connection. + nonlinearity (str, optional, default='none'): + Type of nonlinear activation function in the residual link. + ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, + ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. + inplace_nonlinearity (bool, optional, default=False): If ``True``, + set ``inplace=True`` when initializing the nonlinearity layers. + apply_noise (bool, optional, default=False): If ``True``, adds + Gaussian noise with learnable magnitude to the convolution output. + hidden_channels_equal_out_channels (bool, optional, default=False): + If ``True``, set the hidden channel number to be equal to the + output channel number. If ``False``, the hidden channel number + equals to the smaller of the input channel number and the + output channel number. + order (str, optional, default='CNACNA'): Order of operations + in the residual link. + ``'C'``: convolution, + ``'N'``: normalization, + ``'A'``: nonlinear activation. + pooling (class, optional, default=nn.AvgPool2d): Pytorch pooling + layer to be used. + down_factor (int, optional, default=2): Downsampling factor. + learn_shortcut (bool, optional, default=False): If ``True``, always use + a convolutional shortcut instead of an identity one, otherwise only + use a convolutional one if input and output have different number of + channels. 
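DownRes2dBlock, documented above, runs the residual computation and then pools both branches, so the spatial size shrinks by down_factor. A short sketch under the usual assumption that the package is importable.

import torch
from torch import nn
from imaginaire.layers.residual import DownRes2dBlock

block = DownRes2dBlock(in_channels=64, out_channels=128,
                       activation_norm_type='instance',
                       pooling=nn.AvgPool2d, down_factor=2)
x = torch.randn(2, 64, 32, 32)
y = block(x)                      # pooled after the residual sum
assert y.shape == (2, 128, 16, 16)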
+ """ + + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding=1, dilation=1, groups=1, bias=True, + padding_mode='zeros', + weight_norm_type='none', weight_norm_params=None, + activation_norm_type='none', activation_norm_params=None, + skip_activation_norm=True, skip_nonlinearity=False, + nonlinearity='leakyrelu', inplace_nonlinearity=False, + apply_noise=False, hidden_channels_equal_out_channels=False, + order='CNACNA', pooling=nn.AvgPool2d, down_factor=2, + learn_shortcut=None, clamp=None, output_scale=1): + super().__init__(in_channels, out_channels, kernel_size, + stride, padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + skip_activation_norm, skip_nonlinearity, + nonlinearity, inplace_nonlinearity, apply_noise, + hidden_channels_equal_out_channels, + order, Conv2dBlock, pooling, + down_factor, learn_shortcut, clamp, output_scale) + + +class _BaseUpResBlock(_BaseResBlock): + r"""An abstract class for residual blocks with upsampling. + """ + + def __init__(self, in_channels, out_channels, kernel_size, + stride, padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + skip_activation_norm, skip_nonlinearity, + nonlinearity, inplace_nonlinearity, + apply_noise, hidden_channels_equal_out_channels, + order, block, upsample, up_factor, learn_shortcut, clamp=None, + output_scale=1): + super().__init__(in_channels, out_channels, kernel_size, + stride, padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + skip_activation_norm, skip_nonlinearity, nonlinearity, + inplace_nonlinearity, apply_noise, + hidden_channels_equal_out_channels, order, block, + learn_shortcut, clamp, output_scale) + self.order = order + self.upsample = upsample(scale_factor=up_factor) + + def _get_stride_blur(self): + # Upsampling. + first_stride, second_stride = self.stride, 1 + first_blur, second_blur = self.blur, False + shortcut_blur = False + shortcut_stride = 1 + # if self.upsample == 'blur_deconv': + + if self.blur: + # The shortcut branch uses blur_upsample + stride-1 conv + self.upsample = BlurUpsample() + else: + shortcut_stride = self.stride + self.upsample = nn.Upsample(scale_factor=2) + + return (first_stride, second_stride, shortcut_stride, + first_blur, second_blur, shortcut_blur) + + def forward(self, x, *cond_inputs): + r"""Implementation of the up residual block forward function. + If the order is 'NAC' for the first residual block, we will first + do the activation norm and nonlinearity, in the original resolution. + We will then upsample the activation map to a higher resolution. We + then do the convolution. + It is is other orders, then we first do the whole processing and + then upsample. + + Args: + x (tensor) : Input tensor. + cond_inputs (list of tensors) : Conditional input. + Returns: + output (tensor) : Output tensor. + """ + # In this particular upsample residual block operation, we first + # upsample the skip connection. 
+ if self.learn_shortcut: + x_shortcut = self.upsample(x) + x_shortcut = self.conv_block_s(x_shortcut, *cond_inputs) + else: + x_shortcut = self.upsample(x) + + if self.order[0:3] == 'NAC': + for ix, layer in enumerate(self.conv_block_0.layers.values()): + if getattr(layer, 'conditional', False): + x = layer(x, *cond_inputs) + else: + x = layer(x) + if ix == 1: + x = self.upsample(x) + else: + x = self.conv_block_0(x, *cond_inputs) + x = self.upsample(x) + x = self.conv_block_1(x, *cond_inputs) + + output = x_shortcut + x + return self.output_scale * output + + +class UpRes2dBlock(_BaseUpResBlock): + r"""Residual block for 2D input with downsampling. + + Args: + in_channels (int) : Number of channels in the input tensor. + out_channels (int) : Number of channels in the output tensor. + kernel_size (int, optional, default=3): Kernel size for the + convolutional filters in the residual link. + padding (int, optional, default=1): Padding size. + dilation (int, optional, default=1): Dilation factor. + groups (int, optional, default=1): Number of convolutional/linear + groups. + padding_mode (string, optional, default='zeros'): Type of padding: + ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. + weight_norm_type (str, optional, default='none'): + Type of weight normalization. + ``'none'``, ``'spectral'``, ``'weight'`` + or ``'weight_demod'``. + weight_norm_params (obj, optional, default=None): + Parameters of weight normalization. + If not ``None``, ``weight_norm_params.__dict__`` will be used as + keyword arguments when initializing weight normalization. + activation_norm_type (str, optional, default='none'): + Type of activation normalization. + ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, + ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, + ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. + activation_norm_params (obj, optional, default=None): + Parameters of activation normalization. + If not ``None``, ``activation_norm_params.__dict__`` will be used as + keyword arguments when initializing activation normalization. + skip_activation_norm (bool, optional, default=True): If ``True`` and + ``learn_shortcut`` is also ``True``, applies activation norm to the + learned shortcut connection. + skip_nonlinearity (bool, optional, default=True): If ``True`` and + ``learn_shortcut`` is also ``True``, applies nonlinearity to the + learned shortcut connection. + nonlinearity (str, optional, default='none'): + Type of nonlinear activation function in the residual link. + ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, + ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. + inplace_nonlinearity (bool, optional, default=False): If ``True``, + set ``inplace=True`` when initializing the nonlinearity layers. + apply_noise (bool, optional, default=False): If ``True``, adds + Gaussian noise with learnable magnitude to the convolution output. + hidden_channels_equal_out_channels (bool, optional, default=False): + If ``True``, set the hidden channel number to be equal to the + output channel number. If ``False``, the hidden channel number + equals to the smaller of the input channel number and the + output channel number. + order (str, optional, default='CNACNA'): Order of operations + in the residual link. + ``'C'``: convolution, + ``'N'``: normalization, + ``'A'``: nonlinear activation. + upsample (class, optional, default=NearestUpsample): PPytorch + upsampling layer to be used. + up_factor (int, optional, default=2): Upsampling factor. 
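UpRes2dBlock upsamples both branches by up_factor; with the default 'CNACNA' order the residual path applies its first convolution at the input resolution, upsamples, and then applies the second convolution, while the shortcut is upsampled before its 1x1 projection. A minimal sketch with illustrative shapes, assuming the package is importable.

import torch
from imaginaire.layers.residual import UpRes2dBlock

block = UpRes2dBlock(in_channels=128, out_channels=64,
                     activation_norm_type='instance',
                     up_factor=2)
x = torch.randn(2, 128, 16, 16)
y = block(x)                      # nearest-neighbour upsampling by default
assert y.shape == (2, 64, 32, 32)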
+ learn_shortcut (bool, optional, default=False): If ``True``, always use + a convolutional shortcut instead of an identity one, otherwise only + use a convolutional one if input and output have different number of + channels. + """ + + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding=1, dilation=1, groups=1, bias=True, + padding_mode='zeros', + weight_norm_type='none', weight_norm_params=None, + activation_norm_type='none', activation_norm_params=None, + skip_activation_norm=True, skip_nonlinearity=False, + nonlinearity='leakyrelu', inplace_nonlinearity=False, + apply_noise=False, hidden_channels_equal_out_channels=False, + order='CNACNA', upsample=NearestUpsample, up_factor=2, + learn_shortcut=None, clamp=None, output_scale=1): + super().__init__(in_channels, out_channels, kernel_size, + stride, padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + skip_activation_norm, skip_nonlinearity, + nonlinearity, inplace_nonlinearity, + apply_noise, hidden_channels_equal_out_channels, + order, Conv2dBlock, + upsample, up_factor, learn_shortcut, clamp, + output_scale) + + +class _BasePartialResBlock(_BaseResBlock): + r"""An abstract class for residual blocks with partial convolution. + """ + + def __init__(self, in_channels, out_channels, kernel_size, + stride, padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + skip_activation_norm, skip_nonlinearity, + nonlinearity, inplace_nonlinearity, + multi_channel, return_mask, + apply_noise, hidden_channels_equal_out_channels, + order, block, learn_shortcut, clamp=None, output_scale=1): + block = functools.partial(block, + multi_channel=multi_channel, + return_mask=return_mask) + self.partial_conv = True + super().__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + skip_activation_norm, skip_nonlinearity, nonlinearity, + inplace_nonlinearity, apply_noise, + hidden_channels_equal_out_channels, order, block, + learn_shortcut, clamp, output_scale) + + def forward(self, x, *cond_inputs, mask_in=None, **kw_cond_inputs): + r""" + + Args: + x (tensor): Input tensor. + cond_inputs (list of tensors) : Conditional input tensors. + mask_in (tensor, optional, default=``None``) If not ``None``, + it masks the valid input region. + kw_cond_inputs (dict) : Keyword conditional inputs. + Returns: + (tuple): + - output (tensor): Output tensor. + - mask_out (tensor, optional): Masks the valid output region. 
+ """ + if self.conv_block_0.layers.conv.return_mask: + dx, mask_out = self.conv_block_0(x, *cond_inputs, + mask_in=mask_in, **kw_cond_inputs) + dx, mask_out = self.conv_block_1(dx, *cond_inputs, + mask_in=mask_out, **kw_cond_inputs) + else: + dx = self.conv_block_0(x, *cond_inputs, + mask_in=mask_in, **kw_cond_inputs) + dx = self.conv_block_1(dx, *cond_inputs, + mask_in=mask_in, **kw_cond_inputs) + mask_out = None + + if self.learn_shortcut: + x_shortcut = self.conv_block_s(x, mask_in=mask_in, *cond_inputs, + **kw_cond_inputs) + if type(x_shortcut) == tuple: + x_shortcut, _ = x_shortcut + else: + x_shortcut = x + output = x_shortcut + dx + + if mask_out is not None: + return output, mask_out + return self.output_scale * output + + +class PartialRes2dBlock(_BasePartialResBlock): + r"""Residual block for 2D input with partial convolution. + + Args: + in_channels (int) : Number of channels in the input tensor. + out_channels (int) : Number of channels in the output tensor. + kernel_size (int, optional, default=3): Kernel size for the + convolutional filters in the residual link. + padding (int, optional, default=1): Padding size. + dilation (int, optional, default=1): Dilation factor. + groups (int, optional, default=1): Number of convolutional/linear + groups. + padding_mode (string, optional, default='zeros'): Type of padding: + ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. + weight_norm_type (str, optional, default='none'): + Type of weight normalization. + ``'none'``, ``'spectral'``, ``'weight'`` + or ``'weight_demod'``. + weight_norm_params (obj, optional, default=None): + Parameters of weight normalization. + If not ``None``, ``weight_norm_params.__dict__`` will be used as + keyword arguments when initializing weight normalization. + activation_norm_type (str, optional, default='none'): + Type of activation normalization. + ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, + ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, + ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. + activation_norm_params (obj, optional, default=None): + Parameters of activation normalization. + If not ``None``, ``activation_norm_params.__dict__`` will be used as + keyword arguments when initializing activation normalization. + skip_activation_norm (bool, optional, default=True): If ``True`` and + ``learn_shortcut`` is also ``True``, applies activation norm to the + learned shortcut connection. + skip_nonlinearity (bool, optional, default=True): If ``True`` and + ``learn_shortcut`` is also ``True``, applies nonlinearity to the + learned shortcut connection. + nonlinearity (str, optional, default='none'): + Type of nonlinear activation function in the residual link. + ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, + ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. + inplace_nonlinearity (bool, optional, default=False): If ``True``, + set ``inplace=True`` when initializing the nonlinearity layers. + apply_noise (bool, optional, default=False): If ``True``, adds + Gaussian noise with learnable magnitude to the convolution output. + hidden_channels_equal_out_channels (bool, optional, default=False): + If ``True``, set the hidden channel number to be equal to the + output channel number. If ``False``, the hidden channel number + equals to the smaller of the input channel number and the + output channel number. + order (str, optional, default='CNACNA'): Order of operations + in the residual link. 
+ ``'C'``: convolution, + ``'N'``: normalization, + ``'A'``: nonlinear activation. + learn_shortcut (bool, optional, default=False): If ``True``, always use + a convolutional shortcut instead of an identity one, otherwise only + use a convolutional one if input and output have different number of + channels. + """ + + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding=1, dilation=1, groups=1, bias=True, + padding_mode='zeros', + weight_norm_type='none', weight_norm_params=None, + activation_norm_type='none', activation_norm_params=None, + skip_activation_norm=True, skip_nonlinearity=False, + nonlinearity='leakyrelu', inplace_nonlinearity=False, + multi_channel=False, return_mask=True, + apply_noise=False, + hidden_channels_equal_out_channels=False, + order='CNACNA', learn_shortcut=None, clamp=None, + output_scale=1): + super().__init__(in_channels, out_channels, kernel_size, + stride, padding, dilation, groups, bias, + padding_mode, weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + skip_activation_norm, skip_nonlinearity, nonlinearity, + inplace_nonlinearity, multi_channel, return_mask, + apply_noise, hidden_channels_equal_out_channels, + order, PartialConv2dBlock, learn_shortcut, clamp, + output_scale) + + +class PartialRes3dBlock(_BasePartialResBlock): + r"""Residual block for 3D input with partial convolution. + + Args: + in_channels (int) : Number of channels in the input tensor. + out_channels (int) : Number of channels in the output tensor. + kernel_size (int, optional, default=3): Kernel size for the + convolutional filters in the residual link. + padding (int, optional, default=1): Padding size. + dilation (int, optional, default=1): Dilation factor. + groups (int, optional, default=1): Number of convolutional/linear + groups. + padding_mode (string, optional, default='zeros'): Type of padding: + ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. + weight_norm_type (str, optional, default='none'): + Type of weight normalization. + ``'none'``, ``'spectral'``, ``'weight'`` + or ``'weight_demod'``. + weight_norm_params (obj, optional, default=None): + Parameters of weight normalization. + If not ``None``, ``weight_norm_params.__dict__`` will be used as + keyword arguments when initializing weight normalization. + activation_norm_type (str, optional, default='none'): + Type of activation normalization. + ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, + ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, + ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. + activation_norm_params (obj, optional, default=None): + Parameters of activation normalization. + If not ``None``, ``activation_norm_params.__dict__`` will be used as + keyword arguments when initializing activation normalization. + skip_activation_norm (bool, optional, default=True): If ``True`` and + ``learn_shortcut`` is also ``True``, applies activation norm to the + learned shortcut connection. + skip_nonlinearity (bool, optional, default=True): If ``True`` and + ``learn_shortcut`` is also ``True``, applies nonlinearity to the + learned shortcut connection. + nonlinearity (str, optional, default='none'): + Type of nonlinear activation function in the residual link. + ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, + ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. + inplace_nonlinearity (bool, optional, default=False): If ``True``, + set ``inplace=True`` when initializing the nonlinearity layers. 
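PartialRes2dBlock, defined above, threads a validity mask through both partial convolutions and, when return_mask=True, returns the updated mask alongside the features; PartialRes3dBlock, documented here, does the same for 5-D inputs. A sketch assuming a CUDA device, since the partial-convolution layers in this code base create their mask-update kernels on 'cuda'; shapes are illustrative.

import torch
from imaginaire.layers.residual import PartialRes2dBlock

block = PartialRes2dBlock(in_channels=32, out_channels=32,
                          multi_channel=False, return_mask=True).to('cuda')
x = torch.randn(1, 32, 64, 64, device='cuda')
mask = torch.ones(1, 1, 64, 64, device='cuda')    # 1 = valid, 0 = hole
mask[:, :, 16:48, 16:48] = 0                      # carve out a hole

out, mask_out = block(x, mask_in=mask)            # the updated mask is returned too
assert out.shape == x.shape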
+ apply_noise (bool, optional, default=False): If ``True``, adds + Gaussian noise with learnable magnitude to the convolution output. + hidden_channels_equal_out_channels (bool, optional, default=False): + If ``True``, set the hidden channel number to be equal to the + output channel number. If ``False``, the hidden channel number + equals to the smaller of the input channel number and the + output channel number. + order (str, optional, default='CNACNA'): Order of operations + in the residual link. + ``'C'``: convolution, + ``'N'``: normalization, + ``'A'``: nonlinear activation. + learn_shortcut (bool, optional, default=False): If ``True``, always use + a convolutional shortcut instead of an identity one, otherwise only + use a convolutional one if input and output have different number of + channels. + """ + + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding=1, dilation=1, groups=1, bias=True, + padding_mode='zeros', + weight_norm_type='none', weight_norm_params=None, + activation_norm_type='none', activation_norm_params=None, + skip_activation_norm=True, skip_nonlinearity=False, + nonlinearity='leakyrelu', inplace_nonlinearity=False, + multi_channel=False, return_mask=True, + apply_noise=False, hidden_channels_equal_out_channels=False, + order='CNACNA', learn_shortcut=None, clamp=None, + output_scale=1): + super().__init__(in_channels, out_channels, kernel_size, + stride, padding, dilation, groups, bias, + padding_mode, weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + skip_activation_norm, skip_nonlinearity, + nonlinearity, inplace_nonlinearity, multi_channel, + return_mask, apply_noise, + hidden_channels_equal_out_channels, + order, PartialConv3dBlock, learn_shortcut, clamp, + output_scale) + + +class _BaseMultiOutResBlock(_BaseResBlock): + r"""An abstract class for residual blocks that can returns multiple outputs. + """ + + def __init__(self, in_channels, out_channels, kernel_size, + stride, padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + skip_activation_norm, skip_nonlinearity, + nonlinearity, inplace_nonlinearity, + apply_noise, hidden_channels_equal_out_channels, + order, block, learn_shortcut, clamp=None, output_scale=1, + blur=False, upsample_first=True): + self.multiple_outputs = True + super().__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + skip_activation_norm, skip_nonlinearity, nonlinearity, + inplace_nonlinearity, apply_noise, + hidden_channels_equal_out_channels, order, block, + learn_shortcut, clamp, output_scale, blur=blur, + upsample_first=upsample_first) + + def forward(self, x, *cond_inputs): + r""" + + Args: + x (tensor): Input tensor. + cond_inputs (list of tensors) : Conditional input tensors. + Returns: + (tuple): + - output (tensor): Output tensor. + - aux_outputs_0 (tensor): Auxiliary output of the first block. + - aux_outputs_1 (tensor): Auxiliary output of the second block. + """ + dx, aux_outputs_0 = self.conv_block_0(x, *cond_inputs) + dx, aux_outputs_1 = self.conv_block_1(dx, *cond_inputs) + if self.learn_shortcut: + # We are not using the auxiliary outputs of self.conv_block_s. 
+ x_shortcut, _ = self.conv_block_s(x, *cond_inputs) + else: + x_shortcut = x + output = x_shortcut + dx + return self.output_scale * output, aux_outputs_0, aux_outputs_1 + + +class MultiOutRes2dBlock(_BaseMultiOutResBlock): + r"""Residual block for 2D input. It can return multiple outputs, if some + layers in the block return more than one output. + + Args: + in_channels (int) : Number of channels in the input tensor. + out_channels (int) : Number of channels in the output tensor. + kernel_size (int, optional, default=3): Kernel size for the + convolutional filters in the residual link. + padding (int, optional, default=1): Padding size. + dilation (int, optional, default=1): Dilation factor. + groups (int, optional, default=1): Number of convolutional/linear + groups. + padding_mode (string, optional, default='zeros'): Type of padding: + ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. + weight_norm_type (str, optional, default='none'): + Type of weight normalization. + ``'none'``, ``'spectral'``, ``'weight'`` + or ``'weight_demod'``. + weight_norm_params (obj, optional, default=None): + Parameters of weight normalization. + If not ``None``, ``weight_norm_params.__dict__`` will be used as + keyword arguments when initializing weight normalization. + activation_norm_type (str, optional, default='none'): + Type of activation normalization. + ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, + ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, + ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. + activation_norm_params (obj, optional, default=None): + Parameters of activation normalization. + If not ``None``, ``activation_norm_params.__dict__`` will be used as + keyword arguments when initializing activation normalization. + skip_activation_norm (bool, optional, default=True): If ``True`` and + ``learn_shortcut`` is also ``True``, applies activation norm to the + learned shortcut connection. + skip_nonlinearity (bool, optional, default=True): If ``True`` and + ``learn_shortcut`` is also ``True``, applies nonlinearity to the + learned shortcut connection. + nonlinearity (str, optional, default='none'): + Type of nonlinear activation function in the residual link. + ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, + ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. + inplace_nonlinearity (bool, optional, default=False): If ``True``, + set ``inplace=True`` when initializing the nonlinearity layers. + apply_noise (bool, optional, default=False): If ``True``, adds + Gaussian noise with learnable magnitude to the convolution output. + hidden_channels_equal_out_channels (bool, optional, default=False): + If ``True``, set the hidden channel number to be equal to the + output channel number. If ``False``, the hidden channel number + equals to the smaller of the input channel number and the + output channel number. + order (str, optional, default='CNACNA'): Order of operations + in the residual link. + ``'C'``: convolution, + ``'N'``: normalization, + ``'A'``: nonlinear activation. + learn_shortcut (bool, optional, default=False): If ``True``, always use + a convolutional shortcut instead of an identity one, otherwise only + use a convolutional one if input and output have different number of + channels. 
+ """ + + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding=1, dilation=1, groups=1, bias=True, + padding_mode='zeros', + weight_norm_type='none', weight_norm_params=None, + activation_norm_type='none', activation_norm_params=None, + skip_activation_norm=True, skip_nonlinearity=False, + nonlinearity='leakyrelu', inplace_nonlinearity=False, + apply_noise=False, hidden_channels_equal_out_channels=False, + order='CNACNA', learn_shortcut=None, clamp=None, + output_scale=1, blur=False, upsample_first=True): + super().__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + skip_activation_norm, skip_nonlinearity, nonlinearity, + inplace_nonlinearity, apply_noise, + hidden_channels_equal_out_channels, order, + MultiOutConv2dBlock, learn_shortcut, clamp, + output_scale, blur=blur, upsample_first=upsample_first) diff --git a/imaginaire/layers/residual_deep.py b/imaginaire/layers/residual_deep.py new file mode 100644 index 0000000000000000000000000000000000000000..b0bbcd497f4689bed4faf20e8e47c0fc4e282812 --- /dev/null +++ b/imaginaire/layers/residual_deep.py @@ -0,0 +1,346 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import torch +from torch import nn +from torch.utils.checkpoint import checkpoint + +from imaginaire.third_party.upfirdn2d import BlurDownsample, BlurUpsample +from .conv import Conv2dBlock + + +class _BaseDeepResBlock(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, + stride, padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + skip_activation_norm, skip_nonlinearity, + nonlinearity, inplace_nonlinearity, apply_noise, + hidden_channels_equal_out_channels, + order, block, learn_shortcut, output_scale, skip_block=None, + blur=True, border_free=True, resample_first=True, + skip_weight_norm=True, hidden_channel_ratio=4): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.output_scale = output_scale + self.resample_first = resample_first + self.stride = stride + self.blur = blur + self.border_free = border_free + assert not border_free + if skip_block is None: + skip_block = block + + if order == 'pre_act': + order = 'NACNAC' + if isinstance(bias, bool): + # The bias for conv_block_0, conv_block_1, and conv_block_s. + biases = [bias, bias, bias] + elif isinstance(bias, list): + if len(bias) == 3: + biases = bias + else: + raise ValueError('Bias list must be 3.') + else: + raise ValueError('Bias must be either an integer or s list.') + self.learn_shortcut = learn_shortcut + if len(order) > 6 or len(order) < 5: + raise ValueError('order must be either 5 or 6 characters') + hidden_channels = in_channels // hidden_channel_ratio + + # Parameters. 
+ residual_params = {} + shortcut_params = {} + base_params = dict(dilation=dilation, + groups=groups, + padding_mode=padding_mode) + residual_params.update(base_params) + residual_params.update( + dict(activation_norm_type=activation_norm_type, + activation_norm_params=activation_norm_params, + weight_norm_type=weight_norm_type, + weight_norm_params=weight_norm_params, + apply_noise=apply_noise) + ) + shortcut_params.update(base_params) + shortcut_params.update(dict(kernel_size=1)) + if skip_activation_norm: + shortcut_params.update( + dict(activation_norm_type=activation_norm_type, + activation_norm_params=activation_norm_params, + apply_noise=False)) + if skip_weight_norm: + shortcut_params.update( + dict(weight_norm_type=weight_norm_type, + weight_norm_params=weight_norm_params)) + + # Residual branch. + if order.find('A') < order.find('C') and \ + (activation_norm_type == '' or activation_norm_type == 'none'): + # Nonlinearity is the first operation in the residual path. + # In-place nonlinearity will modify the input variable and cause + # backward error. + first_inplace = False + else: + first_inplace = inplace_nonlinearity + + (first_stride, second_stride, shortcut_stride, + first_blur, second_blur, shortcut_blur) = self._get_stride_blur() + + self.conv_block_1x1_in = block( + in_channels, hidden_channels, + 1, 1, 0, + bias=biases[0], + nonlinearity=nonlinearity, + order=order[0:3], + inplace_nonlinearity=first_inplace, + **residual_params + ) + + self.conv_block_0 = block( + hidden_channels, hidden_channels, + kernel_size=2 if self.border_free and first_stride < 1 else + kernel_size, + padding=padding, + bias=biases[0], + nonlinearity=nonlinearity, + order=order[0:3], + inplace_nonlinearity=inplace_nonlinearity, + stride=first_stride, + blur=first_blur, + **residual_params + ) + self.conv_block_1 = block( + hidden_channels, hidden_channels, + kernel_size=kernel_size, + padding=padding, + bias=biases[1], + nonlinearity=nonlinearity, + order=order[3:], + inplace_nonlinearity=inplace_nonlinearity, + stride=second_stride, + blur=second_blur, + **residual_params + ) + + self.conv_block_1x1_out = block( + hidden_channels, out_channels, + 1, 1, 0, + bias=biases[1], + nonlinearity=nonlinearity, + order=order[0:3], + inplace_nonlinearity=inplace_nonlinearity, + **residual_params + ) + + # Shortcut branch. + if self.learn_shortcut: + if skip_nonlinearity: + skip_nonlinearity_type = nonlinearity + else: + skip_nonlinearity_type = '' + self.conv_block_s = skip_block(in_channels, out_channels, + bias=biases[2], + nonlinearity=skip_nonlinearity_type, + order=order[0:3], + stride=shortcut_stride, + blur=shortcut_blur, + **shortcut_params) + elif in_channels < out_channels: + if skip_nonlinearity: + skip_nonlinearity_type = nonlinearity + else: + skip_nonlinearity_type = '' + self.conv_block_s = skip_block(in_channels, + out_channels - in_channels, + bias=biases[2], + nonlinearity=skip_nonlinearity_type, + order=order[0:3], + stride=shortcut_stride, + blur=shortcut_blur, + **shortcut_params) + + # Whether this block expects conditional inputs. + self.conditional = \ + getattr(self.conv_block_0, 'conditional', False) or \ + getattr(self.conv_block_1, 'conditional', False) or \ + getattr(self.conv_block_1x1_in, 'conditional', False) or \ + getattr(self.conv_block_1x1_out, 'conditional', False) + + def _get_stride_blur(self): + if self.stride > 1: + # Downsampling. 
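+            # When downsampling, the first conv keeps stride 1 and the second
+            # conv carries the stride (optionally with blur); the shortcut is
+            # resampled separately via BlurDownsample or AvgPool2d.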
+ first_stride, second_stride = 1, self.stride + first_blur, second_blur = False, self.blur + shortcut_blur = False + shortcut_stride = 1 + if self.blur: + # The shortcut branch uses blur_downsample + stride-1 conv + if self.border_free: + self.resample = nn.AvgPool2d(2) + else: + self.resample = BlurDownsample() + else: + shortcut_stride = self.stride + self.resample = nn.AvgPool2d(2) + elif self.stride < 1: + # Upsampling. + first_stride, second_stride = self.stride, 1 + first_blur, second_blur = self.blur, False + shortcut_blur = False + shortcut_stride = 1 + if self.blur: + # The shortcut branch uses blur_upsample + stride-1 conv + if self.border_free: + self.resample = nn.Upsample(scale_factor=2, + mode='bilinear') + else: + self.resample = BlurUpsample() + else: + shortcut_stride = self.stride + self.resample = nn.Upsample(scale_factor=2) + else: + first_stride = second_stride = 1 + first_blur = second_blur = False + shortcut_stride = 1 + shortcut_blur = False + self.resample = None + return (first_stride, second_stride, shortcut_stride, + first_blur, second_blur, shortcut_blur) + + def conv_blocks( + self, x, *cond_inputs, separate_cond=False, **kw_cond_inputs + ): + if separate_cond: + assert len(list(cond_inputs)) == 4 + dx = self.conv_block_1x1_in(x, cond_inputs[0], + **kw_cond_inputs.get('kwargs_0', {})) + dx = self.conv_block_0(dx, cond_inputs[1], + **kw_cond_inputs.get('kwargs_1', {})) + dx = self.conv_block_1(dx, cond_inputs[2], + **kw_cond_inputs.get('kwargs_2', {})) + dx = self.conv_block_1x1_out(dx, cond_inputs[3], + **kw_cond_inputs.get('kwargs_3', {})) + else: + dx = self.conv_block_1x1_in(x, *cond_inputs, **kw_cond_inputs) + dx = self.conv_block_0(dx, *cond_inputs, **kw_cond_inputs) + dx = self.conv_block_1(dx, *cond_inputs, **kw_cond_inputs) + dx = self.conv_block_1x1_out(dx, *cond_inputs, **kw_cond_inputs) + return dx + + def forward(self, x, *cond_inputs, do_checkpoint=False, **kw_cond_inputs): + if do_checkpoint: + dx = checkpoint(self.conv_blocks, x, *cond_inputs, **kw_cond_inputs) + else: + dx = self.conv_blocks(x, *cond_inputs, **kw_cond_inputs) + + if self.resample_first and self.resample is not None: + x = self.resample(x) + if self.learn_shortcut: + x_shortcut = self.conv_block_s( + x, *cond_inputs, **kw_cond_inputs + ) + elif self.in_channels < self.out_channels: + x_shortcut_pad = self.conv_block_s( + x, *cond_inputs, **kw_cond_inputs + ) + x_shortcut = torch.cat((x, x_shortcut_pad), dim=1) + elif self.in_channels > self.out_channels: + x_shortcut = x[:, :self.out_channels, :, :] + else: + x_shortcut = x + if not self.resample_first and self.resample is not None: + x_shortcut = self.resample(x_shortcut) + + output = x_shortcut + dx + return self.output_scale * output + + def extra_repr(self): + s = 'output_scale={output_scale}' + return s.format(**self.__dict__) + + +class DeepRes2dBlock(_BaseDeepResBlock): + r"""Residual block for 2D input. + + Args: + in_channels (int) : Number of channels in the input tensor. + out_channels (int) : Number of channels in the output tensor. + kernel_size (int, optional, default=3): Kernel size for the + convolutional filters in the residual link. + padding (int, optional, default=1): Padding size. + dilation (int, optional, default=1): Dilation factor. + groups (int, optional, default=1): Number of convolutional/linear + groups. + padding_mode (string, optional, default='zeros'): Type of padding: + ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. 
+ weight_norm_type (str, optional, default='none'): + Type of weight normalization. + ``'none'``, ``'spectral'``, ``'weight'`` + or ``'weight_demod'``. + weight_norm_params (obj, optional, default=None): + Parameters of weight normalization. + If not ``None``, ``weight_norm_params.__dict__`` will be used as + keyword arguments when initializing weight normalization. + activation_norm_type (str, optional, default='none'): + Type of activation normalization. + ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, + ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, + ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. + activation_norm_params (obj, optional, default=None): + Parameters of activation normalization. + If not ``None``, ``activation_norm_params.__dict__`` will be used as + keyword arguments when initializing activation normalization. + skip_activation_norm (bool, optional, default=True): If ``True`` and + ``learn_shortcut`` is also ``True``, applies activation norm to the + learned shortcut connection. + skip_nonlinearity (bool, optional, default=True): If ``True`` and + ``learn_shortcut`` is also ``True``, applies nonlinearity to the + learned shortcut connection. + nonlinearity (str, optional, default='none'): + Type of nonlinear activation function in the residual link. + ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, + ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. + inplace_nonlinearity (bool, optional, default=False): If ``True``, + set ``inplace=True`` when initializing the nonlinearity layers. + apply_noise (bool, optional, default=False): If ``True``, adds + Gaussian noise with learnable magnitude to the convolution output. + hidden_channels_equal_out_channels (bool, optional, default=False): + If ``True``, set the hidden channel number to be equal to the + output channel number. If ``False``, the hidden channel number + equals to the smaller of the input channel number and the + output channel number. + order (str, optional, default='CNACNA'): Order of operations + in the residual link. + ``'C'``: convolution, + ``'N'``: normalization, + ``'A'``: nonlinear activation. + learn_shortcut (bool, optional, default=False): If ``True``, always use + a convolutional shortcut instead of an identity one, otherwise only + use a convolutional one if input and output have different number of + channels. 
+ """ + + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding=1, dilation=1, groups=1, bias=True, + padding_mode='zeros', + weight_norm_type='none', weight_norm_params=None, + activation_norm_type='none', activation_norm_params=None, + skip_activation_norm=True, skip_nonlinearity=False, + skip_weight_norm=True, + nonlinearity='leakyrelu', inplace_nonlinearity=False, + apply_noise=False, hidden_channels_equal_out_channels=False, + order='CNACNA', learn_shortcut=False, output_scale=1, + blur=True, resample_first=True, border_free=False): + super().__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + skip_activation_norm, skip_nonlinearity, nonlinearity, + inplace_nonlinearity, apply_noise, + hidden_channels_equal_out_channels, order, Conv2dBlock, + learn_shortcut, output_scale, blur=blur, + resample_first=resample_first, border_free=border_free, + skip_weight_norm=skip_weight_norm) diff --git a/imaginaire/layers/vit.py b/imaginaire/layers/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..abd0d039d715444efc3c5a1e9889330bfa5b4c4f --- /dev/null +++ b/imaginaire/layers/vit.py @@ -0,0 +1,204 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +from types import SimpleNamespace + +import torch +from torch import nn + +from .misc import ApplyNoise +from imaginaire.third_party.upfirdn2d.upfirdn2d import Blur + + +class ViT2dBlock(nn.Module): + r"""An abstract wrapper class that wraps a torch convolution or linear layer + with normalization and nonlinearity. + """ + + def __init__(self, in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + nonlinearity, inplace_nonlinearity, + apply_noise, blur, order, input_dim, clamp, + blur_kernel=(1, 3, 3, 1), output_scale=None, + init_gain=1.0): + super().__init__() + from .nonlinearity import get_nonlinearity_layer + from .weight_norm import get_weight_norm_layer + from .activation_norm import get_activation_norm_layer + self.weight_norm_type = weight_norm_type + self.stride = stride + self.clamp = clamp + self.init_gain = init_gain + + # Nonlinearity layer. + if 'fused' in nonlinearity: + # Fusing nonlinearity with bias. + lr_mul = getattr(weight_norm_params, 'lr_mul', 1) + conv_before_nonlinearity = order.find('C') < order.find('A') + if conv_before_nonlinearity: + assert bias + bias = False + channel = out_channels if conv_before_nonlinearity else in_channels + nonlinearity_layer = get_nonlinearity_layer( + nonlinearity, inplace=inplace_nonlinearity, + num_channels=channel, lr_mul=lr_mul) + else: + nonlinearity_layer = get_nonlinearity_layer( + nonlinearity, inplace=inplace_nonlinearity) + + # Noise injection layer. + if apply_noise: + order = order.replace('C', 'CG') + noise_layer = ApplyNoise() + else: + noise_layer = None + + # Convolutional layer. 
+ if blur: + if stride == 2: + # Blur - Conv - Noise - Activate + p = (len(blur_kernel) - 2) + (kernel_size - 1) + pad0, pad1 = (p + 1) // 2, p // 2 + padding = 0 + blur_layer = Blur( + blur_kernel, pad=(pad0, pad1), padding_mode=padding_mode + ) + order = order.replace('C', 'BC') + elif stride == 0.5: + # Conv - Blur - Noise - Activate + padding = 0 + p = (len(blur_kernel) - 2) - (kernel_size - 1) + pad0, pad1 = (p + 1) // 2 + 1, p // 2 + 1 + blur_layer = Blur( + blur_kernel, pad=(pad0, pad1), padding_mode=padding_mode + ) + order = order.replace('C', 'CB') + elif stride == 1: + # No blur for now + blur_layer = nn.Identity() + else: + raise NotImplementedError + else: + blur_layer = nn.Identity() + + if weight_norm_params is None: + weight_norm_params = SimpleNamespace() + weight_norm = get_weight_norm_layer( + weight_norm_type, **vars(weight_norm_params)) + conv_layer = weight_norm(self._get_conv_layer( + in_channels, out_channels, kernel_size, stride, padding, dilation, + groups, bias, padding_mode, input_dim)) + + # Normalization layer. + conv_before_norm = order.find('C') < order.find('N') + norm_channels = out_channels if conv_before_norm else in_channels + if activation_norm_params is None: + activation_norm_params = SimpleNamespace() + activation_norm_layer = get_activation_norm_layer( + norm_channels, + activation_norm_type, + input_dim, + **vars(activation_norm_params)) + + # Mapping from operation names to layers. + mappings = {'C': {'conv': conv_layer}, + 'N': {'norm': activation_norm_layer}, + 'A': {'nonlinearity': nonlinearity_layer}} + mappings.update({'B': {'blur': blur_layer}}) + mappings.update({'G': {'noise': noise_layer}}) + + # All layers in order. + self.layers = nn.ModuleDict() + for op in order: + if list(mappings[op].values())[0] is not None: + self.layers.update(mappings[op]) + + # Whether this block expects conditional inputs. + self.conditional = \ + getattr(conv_layer, 'conditional', False) or \ + getattr(activation_norm_layer, 'conditional', False) + + if output_scale is not None: + self.output_scale = nn.Parameter(torch.tensor(output_scale)) + else: + self.register_parameter("output_scale", None) + + def forward(self, x, *cond_inputs, **kw_cond_inputs): + r""" + + Args: + x (tensor): Input tensor. + cond_inputs (list of tensors) : Conditional input tensors. + kw_cond_inputs (dict) : Keyword conditional inputs. + """ + for key, layer in self.layers.items(): + if getattr(layer, 'conditional', False): + # Layers that require conditional inputs. + x = layer(x, *cond_inputs, **kw_cond_inputs) + else: + x = layer(x) + if self.clamp is not None and isinstance(layer, nn.Conv2d): + x.clamp_(max=self.clamp) + if key == 'conv': + if self.output_scale is not None: + x = x * self.output_scale + return x + + def _get_conv_layer(self, in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + input_dim): + # Returns the convolutional layer. + if input_dim == 0: + layer = nn.Linear(in_channels, out_channels, bias) + else: + if stride < 1: # Fractionally-strided convolution. 
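+                # A fractional stride (e.g. 0.5) is realized as a transposed
+                # conv with stride round(1/stride); only zero padding is
+                # supported in that case.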
+ padding_mode = 'zeros' + assert padding == 0 + layer_type = getattr(nn, f'ConvTranspose{input_dim}d') + stride = round(1 / stride) + else: + layer_type = getattr(nn, f'Conv{input_dim}d') + layer = layer_type( + in_channels, out_channels, kernel_size, stride, padding, + dilation=dilation, groups=groups, bias=bias, + padding_mode=padding_mode + ) + + return layer + + def __repr__(self): + main_str = self._get_name() + '(' + child_lines = [] + for name, layer in self.layers.items(): + mod_str = repr(layer) + if name == 'conv' and self.weight_norm_type != 'none' and \ + self.weight_norm_type != '': + mod_str = mod_str[:-1] + \ + ', weight_norm={}'.format(self.weight_norm_type) + ')' + if name == 'conv' and getattr(layer, 'base_lr_mul', 1) != 1: + mod_str = mod_str[:-1] + \ + ', lr_mul={}'.format(layer.base_lr_mul) + ')' + mod_str = self._addindent(mod_str, 2) + child_lines.append(mod_str) + if len(child_lines) == 1: + main_str += child_lines[0] + else: + main_str += '\n ' + '\n '.join(child_lines) + '\n' + + main_str += ')' + return main_str + + @staticmethod + def _addindent(s_, numSpaces): + s = s_.split('\n') + # don't do anything for single-line stuff + if len(s) == 1: + return s_ + first = s.pop(0) + s = [(numSpaces * ' ') + line for line in s] + s = '\n'.join(s) + s = first + '\n' + s + return s diff --git a/imaginaire/layers/weight_norm.py b/imaginaire/layers/weight_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..e15ca2d21cea70062fa24ffdcd5adab51c8dcb25 --- /dev/null +++ b/imaginaire/layers/weight_norm.py @@ -0,0 +1,267 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import collections +import functools + +import torch +from torch import nn +from torch.nn.utils import spectral_norm, weight_norm +from torch.nn.utils.spectral_norm import SpectralNorm, \ + SpectralNormStateDictHook, SpectralNormLoadStateDictPreHook + +from .conv import LinearBlock + + +class WeightDemodulation(nn.Module): + r"""Weight demodulation in + "Analyzing and Improving the Image Quality of StyleGAN", Karras et al. + + Args: + conv (torch.nn.Modules): Convolutional layer. + cond_dims (int): The number of channels in the conditional input. + eps (float, optional, default=1e-8): a value added to the + denominator for numerical stability. + adaptive_bias (bool, optional, default=False): If ``True``, adaptively + predicts bias from the conditional input. + demod (bool, optional, default=False): If ``True``, performs + weight demodulation. 
+ """ + + def __init__(self, conv, cond_dims, eps=1e-8, + adaptive_bias=False, demod=True): + super().__init__() + self.conv = conv + self.adaptive_bias = adaptive_bias + if adaptive_bias: + self.conv.register_parameter('bias', None) + self.fc_beta = LinearBlock(cond_dims, self.conv.out_channels) + self.fc_gamma = LinearBlock(cond_dims, self.conv.in_channels) + self.eps = eps + self.demod = demod + self.conditional = True + + def forward(self, x, y, **_kwargs): + r"""Weight demodulation forward""" + b, c, h, w = x.size() + self.conv.groups = b + gamma = self.fc_gamma(y) + gamma = gamma[:, None, :, None, None] + weight = self.conv.weight[None, :, :, :, :] * gamma + + if self.demod: + d = torch.rsqrt( + (weight ** 2).sum( + dim=(2, 3, 4), keepdim=True) + self.eps) + weight = weight * d + + x = x.reshape(1, -1, h, w) + _, _, *ws = weight.shape + weight = weight.reshape(b * self.conv.out_channels, *ws) + x = self.conv._conv_forward(x, weight) + + x = x.reshape(-1, self.conv.out_channels, h, w) + if self.adaptive_bias: + x += self.fc_beta(y)[:, :, None, None] + return x + + +def weight_demod( + conv, cond_dims=256, eps=1e-8, adaptive_bias=False, demod=True): + r"""Weight demodulation.""" + return WeightDemodulation(conv, cond_dims, eps, adaptive_bias, demod) + + +class ScaledLR(object): + def __init__(self, weight_name, bias_name): + self.weight_name = weight_name + self.bias_name = bias_name + + def compute_weight(self, module): + weight = getattr(module, self.weight_name + '_ori') + return weight * module.weight_scale + + def compute_bias(self, module): + bias = getattr(module, self.bias_name + '_ori') + if bias is not None: + return bias * module.bias_scale + else: + return None + + @staticmethod + def apply(module, weight_name, bias_name, lr_mul, equalized): + assert weight_name == 'weight' + assert bias_name == 'bias' + fn = ScaledLR(weight_name, bias_name) + module.register_forward_pre_hook(fn) + + if hasattr(module, bias_name): + # module.bias is a parameter (can be None). + bias = getattr(module, bias_name) + delattr(module, bias_name) + module.register_parameter(bias_name + '_ori', bias) + else: + # module.bias does not exist. + bias = None + setattr(module, bias_name + '_ori', bias) + if bias is not None: + setattr(module, bias_name, bias.data) + else: + setattr(module, bias_name, None) + module.register_buffer('bias_scale', torch.tensor(lr_mul)) + + if hasattr(module, weight_name + '_orig'): + # The module has been wrapped with spectral normalization. + # We only want to keep a single weight parameter. + weight = getattr(module, weight_name + '_orig') + delattr(module, weight_name + '_orig') + module.register_parameter(weight_name + '_ori', weight) + setattr(module, weight_name + '_orig', weight.data) + # Put this hook before the spectral norm hook. + module._forward_pre_hooks = collections.OrderedDict( + reversed(list(module._forward_pre_hooks.items())) + ) + module.use_sn = True + else: + weight = getattr(module, weight_name) + delattr(module, weight_name) + module.register_parameter(weight_name + '_ori', weight) + setattr(module, weight_name, weight.data) + module.use_sn = False + + # assert weight.dim() == 4 or weight.dim() == 2 + if equalized: + fan_in = weight.data.size(1) * weight.data[0][0].numel() + # Theoretically, the gain should be sqrt(2) instead of 1. + # The official StyleGAN2 uses 1 for some reason. 
+ module.register_buffer( + 'weight_scale', torch.tensor(lr_mul * ((1 / fan_in) ** 0.5)) + ) + else: + module.register_buffer('weight_scale', torch.tensor(lr_mul)) + + module.lr_mul = module.weight_scale + module.base_lr_mul = lr_mul + + return fn + + def remove(self, module): + with torch.no_grad(): + weight = self.compute_weight(module) + delattr(module, self.weight_name + '_ori') + + if module.use_sn: + setattr(module, self.weight_name + '_orig', weight.detach()) + else: + delattr(module, self.weight_name) + module.register_parameter(self.weight_name, + torch.nn.Parameter(weight.detach())) + + with torch.no_grad(): + bias = self.compute_bias(module) + delattr(module, self.bias_name) + delattr(module, self.bias_name + '_ori') + if bias is not None: + module.register_parameter(self.bias_name, + torch.nn.Parameter(bias.detach())) + else: + module.register_parameter(self.bias_name, None) + + module.lr_mul = 1.0 + module.base_lr_mul = 1.0 + + def __call__(self, module, input): + weight = self.compute_weight(module) + if module.use_sn: + # The following spectral norm hook will compute the SN of + # "module.weight_orig" and store the normalized weight in + # "module.weight". + setattr(module, self.weight_name + '_orig', weight) + else: + setattr(module, self.weight_name, weight) + bias = self.compute_bias(module) + setattr(module, self.bias_name, bias) + + +def remove_weight_norms(module, weight_name='weight', bias_name='bias'): + if hasattr(module, 'weight_ori') or hasattr(module, 'weight_orig'): + for k in list(module._forward_pre_hooks.keys()): + hook = module._forward_pre_hooks[k] + if (isinstance(hook, ScaledLR) or isinstance(hook, SpectralNorm)): + hook.remove(module) + del module._forward_pre_hooks[k] + + for k, hook in module._state_dict_hooks.items(): + if isinstance(hook, SpectralNormStateDictHook) and \ + hook.fn.name == weight_name: + del module._state_dict_hooks[k] + break + + for k, hook in module._load_state_dict_pre_hooks.items(): + if isinstance(hook, SpectralNormLoadStateDictPreHook) and \ + hook.fn.name == weight_name: + del module._load_state_dict_pre_hooks[k] + break + + return module + + +def remove_equalized_lr(module, weight_name='weight', bias_name='bias'): + for k, hook in module._forward_pre_hooks.items(): + if isinstance(hook, ScaledLR) and hook.weight_name == weight_name: + hook.remove(module) + del module._forward_pre_hooks[k] + break + else: + raise ValueError("Equalized learning rate not found") + + return module + + +def scaled_lr( + module, weight_name='weight', bias_name='bias', lr_mul=1., + equalized=False, +): + ScaledLR.apply(module, weight_name, bias_name, lr_mul, equalized) + return module + + +def get_weight_norm_layer(norm_type, **norm_params): + r"""Return weight normalization. + + Args: + norm_type (str): + Type of weight normalization. + ``'none'``, ``'spectral'``, ``'weight'`` + or ``'weight_demod'``. + norm_params: Arbitrary keyword arguments that will be used to + initialize the weight normalization. 
+ """ + if norm_type == 'none' or norm_type == '': # no normalization + return lambda x: x + elif norm_type == 'spectral': # spectral normalization + return functools.partial(spectral_norm, **norm_params) + elif norm_type == 'weight': # weight normalization + return functools.partial(weight_norm, **norm_params) + elif norm_type == 'weight_demod': # weight demodulation + return functools.partial(weight_demod, **norm_params) + elif norm_type == 'equalized_lr': # equalized learning rate + return functools.partial(scaled_lr, equalized=True, **norm_params) + elif norm_type == 'scaled_lr': # equalized learning rate + return functools.partial(scaled_lr, **norm_params) + elif norm_type == 'equalized_lr_spectral': + lr_mul = norm_params.pop('lr_mul', 1.0) + return lambda x: functools.partial( + scaled_lr, equalized=True, lr_mul=lr_mul)( + functools.partial(spectral_norm, **norm_params)(x) + ) + elif norm_type == 'scaled_lr_spectral': + lr_mul = norm_params.pop('lr_mul', 1.0) + return lambda x: functools.partial( + scaled_lr, lr_mul=lr_mul)( + functools.partial(spectral_norm, **norm_params)(x) + ) + else: + raise ValueError( + 'Weight norm layer %s is not recognized' % norm_type) diff --git a/imaginaire/losses/__init__.py b/imaginaire/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c217e44c1e311ca4886a4a7de968c11d7ee64652 --- /dev/null +++ b/imaginaire/losses/__init__.py @@ -0,0 +1,18 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +from .gan import GANLoss +from .perceptual import PerceptualLoss +from .feature_matching import FeatureMatchingLoss +from .kl import GaussianKLLoss + +__all__ = ['GANLoss', 'PerceptualLoss', 'FeatureMatchingLoss', 'GaussianKLLoss', + 'MaskedL1Loss', 'FlowLoss', 'DictLoss', + 'WeightedMSELoss'] + +try: + from .gradient_penalty import GradientPenaltyLoss + __all__.extend(['GradientPenaltyLoss']) +except: # noqa + pass diff --git a/imaginaire/losses/feature_matching.py b/imaginaire/losses/feature_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..f70034b3c3afba5a914261b55cf0abeab832391c --- /dev/null +++ b/imaginaire/losses/feature_matching.py @@ -0,0 +1,38 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import torch.nn as nn + + +class FeatureMatchingLoss(nn.Module): + r"""Compute feature matching loss""" + def __init__(self, criterion='l1'): + super(FeatureMatchingLoss, self).__init__() + if criterion == 'l1': + self.criterion = nn.L1Loss() + elif criterion == 'l2' or criterion == 'mse': + self.criterion = nn.MSELoss() + else: + raise ValueError('Criterion %s is not recognized' % criterion) + + def forward(self, fake_features, real_features): + r"""Return the target vector for the binary cross entropy loss + computation. + + Args: + fake_features (list of lists): Discriminator features of fake images. + real_features (list of lists): Discriminator features of real images. + + Returns: + (tensor): Loss value. 
+ """ + num_d = len(fake_features) + dis_weight = 1.0 / num_d + loss = fake_features[0][0].new_tensor(0) + for i in range(num_d): + for j in range(len(fake_features[i])): + tmp_loss = self.criterion(fake_features[i][j], + real_features[i][j].detach()) + loss += dis_weight * tmp_loss + return loss diff --git a/imaginaire/losses/gan.py b/imaginaire/losses/gan.py new file mode 100644 index 0000000000000000000000000000000000000000..aaa9c30dd51887b25b439b90fa728e94fe2b03a9 --- /dev/null +++ b/imaginaire/losses/gan.py @@ -0,0 +1,173 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +from imaginaire.utils.distributed import master_only_print as print + + +@torch.jit.script +def fuse_math_min_mean_pos(x): + r"""Fuse operation min mean for hinge loss computation of positive + samples""" + minval = torch.min(x - 1, x * 0) + loss = -torch.mean(minval) + return loss + + +@torch.jit.script +def fuse_math_min_mean_neg(x): + r"""Fuse operation min mean for hinge loss computation of negative + samples""" + minval = torch.min(-x - 1, x * 0) + loss = -torch.mean(minval) + return loss + + +class GANLoss(nn.Module): + r"""GAN loss constructor. + + Args: + gan_mode (str): Type of GAN loss. ``'hinge'``, ``'least_square'``, + ``'non_saturated'``, ``'wasserstein'``. + target_real_label (float): The desired output label for real images. + target_fake_label (float): The desired output label for fake images. + decay_k (float): The decay factor per epoch for top-k training. + min_k (float): The minimum percentage of samples to select. + separate_topk (bool): If ``True``, selects top-k for each sample + separately, otherwise selects top-k among all samples. + """ + def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0, + decay_k=1., min_k=1., separate_topk=False): + super(GANLoss, self).__init__() + self.real_label = target_real_label + self.fake_label = target_fake_label + self.real_label_tensor = None + self.fake_label_tensor = None + self.gan_mode = gan_mode + self.decay_k = decay_k + self.min_k = min_k + self.separate_topk = separate_topk + self.register_buffer('k', torch.tensor(1.0)) + print('GAN mode: %s' % gan_mode) + + def forward(self, dis_output, t_real, dis_update=True, reduce=True): + r"""GAN loss computation. + + Args: + dis_output (tensor or list of tensors): Discriminator outputs. + t_real (bool): If ``True``, uses the real label as target, otherwise uses the fake label as target. + dis_update (bool): If ``True``, the loss will be used to update the discriminator, otherwise the generator. + reduce (bool): If ``True``, when a list of discriminator outputs are provided, it will return the average + of all losses, otherwise it will return a list of losses. + Returns: + loss (tensor): Loss value. + """ + if isinstance(dis_output, list): + # For multi-scale discriminators. + # In this implementation, the loss is first averaged for each scale + # (batch size and number of locations) then averaged across scales, + # so that the gradient is not dominated by the discriminator that + # has the most output values (highest resolution). 
+ losses = [] + for dis_output_i in dis_output: + assert isinstance(dis_output_i, torch.Tensor) + losses.append(self.loss(dis_output_i, t_real, dis_update)) + if reduce: + return torch.mean(torch.stack(losses)) + else: + return losses + else: + return self.loss(dis_output, t_real, dis_update) + + def loss(self, dis_output, t_real, dis_update=True): + r"""GAN loss computation. + + Args: + dis_output (tensor): Discriminator outputs. + t_real (bool): If ``True``, uses the real label as target, otherwise + uses the fake label as target. + dis_update (bool): Updating the discriminator or the generator. + Returns: + loss (tensor): Loss value. + """ + if not dis_update: + assert t_real, \ + "The target should be real when updating the generator." + + if not dis_update and self.k < 1: + r""" + Use top-k training: + "Top-k Training of GANs: Improving GAN Performance by Throwing + Away Bad Samples" + Here, each sample may have multiple discriminator output values + (patch discriminator). We could either select top-k for each sample + separately (when ``self.separate_topk=True``), or collect values + from all samples and then select top-k (default, when + ``self.separate_topk=False``). + """ + if self.separate_topk: + dis_output = dis_output.view(dis_output.size(0), -1) + else: + dis_output = dis_output.view(-1) + k = math.ceil(self.k * dis_output.size(-1)) + dis_output, _ = torch.topk(dis_output, k) + + if self.gan_mode == 'non_saturated': + target_tensor = self.get_target_tensor(dis_output, t_real) + loss = F.binary_cross_entropy_with_logits(dis_output, + target_tensor) + elif self.gan_mode == 'least_square': + target_tensor = self.get_target_tensor(dis_output, t_real) + loss = 0.5 * F.mse_loss(dis_output, target_tensor) + elif self.gan_mode == 'hinge': + if dis_update: + if t_real: + loss = fuse_math_min_mean_pos(dis_output) + else: + loss = fuse_math_min_mean_neg(dis_output) + else: + loss = -torch.mean(dis_output) + elif self.gan_mode == 'wasserstein': + if t_real: + loss = -torch.mean(dis_output) + else: + loss = torch.mean(dis_output) + elif self.gan_mode == 'softplus': + target_tensor = self.get_target_tensor(dis_output, t_real) + loss = F.binary_cross_entropy_with_logits(dis_output, + target_tensor) + else: + raise ValueError('Unexpected gan_mode {}'.format(self.gan_mode)) + return loss + + def get_target_tensor(self, dis_output, t_real): + r"""Return the target vector for the binary cross entropy loss + computation. + + Args: + dis_output (tensor): Discriminator outputs. + t_real (bool): If ``True``, uses the real label as target, otherwise + uses the fake label as target. + Returns: + target (tensor): Target tensor vector. 
+ """ + if t_real: + if self.real_label_tensor is None: + self.real_label_tensor = dis_output.new_tensor(self.real_label) + return self.real_label_tensor.expand_as(dis_output) + else: + if self.fake_label_tensor is None: + self.fake_label_tensor = dis_output.new_tensor(self.fake_label) + return self.fake_label_tensor.expand_as(dis_output) + + def topk_anneal(self): + r"""Anneal k after each epoch.""" + if self.decay_k < 1: + # noinspection PyAttributeOutsideInit + self.k.fill_(max(self.decay_k * self.k, self.min_k)) + print("Top-k training: update k to {}.".format(self.k)) diff --git a/imaginaire/losses/info_nce.py b/imaginaire/losses/info_nce.py new file mode 100644 index 0000000000000000000000000000000000000000..8033e828f0b99d12d6e8f8b71811982d0ab568f6 --- /dev/null +++ b/imaginaire/losses/info_nce.py @@ -0,0 +1,87 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist + +from imaginaire.utils.distributed import get_world_size, get_rank, \ + dist_all_reduce_tensor + + +class GatherLayer(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + ctx.save_for_backward(input) + output = [torch.zeros_like(input) for _ in range(dist.get_world_size())] + dist.all_gather(output, input) + return tuple(output) + + @staticmethod + def backward(ctx, *grads): + input, = ctx.saved_tensors + grad_out = torch.zeros_like(input) + all_grads = torch.stack(grads) + all_grads = dist_all_reduce_tensor(all_grads, reduce='sum') + grad_out[:] = all_grads[get_rank()] + return grad_out + + +class InfoNCELoss(nn.Module): + def __init__(self, + temperature=0.07, + gather_distributed=True, + learn_temperature=True, + single_direction=False, + flatten=True): + super(InfoNCELoss, self).__init__() + self.logit_scale = nn.Parameter(torch.tensor([math.log(1/temperature)])) + self.logit_scale.requires_grad = learn_temperature + self.gather_distributed = gather_distributed + self.single_direction = single_direction + self.flatten = flatten + + def forward(self, features_a, features_b, gather_distributed=None, eps=1e-8): + if gather_distributed is None: + gather_distributed = self.gather_distributed + + if features_a is None or features_b is None: + return torch.tensor(0, device='cuda'), torch.tensor(0, device='cuda') + + bs_a, bs_b = features_a.size(0), features_b.size(0) + if self.flatten: + features_a, features_b = features_a.reshape(bs_a, -1), features_b.reshape(bs_b, -1) + else: + features_a = features_a.reshape(bs_a, features_a.size(1), -1).mean(-1) + features_b = features_b.reshape(bs_b, features_b.size(1), -1).mean(-1) + + # Temperature clipping. 
+ self.logit_scale.data = torch.clamp(self.logit_scale.data, 0, 4.6052) + + # normalized features + features_a = features_a / (features_a.norm(dim=1, keepdim=True) + eps) + features_b = features_b / (features_b.norm(dim=1, keepdim=True) + eps) + + loss_a = self._forward_single_direction(features_a, features_b, gather_distributed) + if self.single_direction: + return loss_a + else: + loss_b = self._forward_single_direction(features_b, features_a, gather_distributed) + return loss_a + loss_b + + def _forward_single_direction( + self, features_a, features_b, gather_distributed): + bs_a = features_a.shape[0] + logit_scale = self.logit_scale.exp() + if get_world_size() > 1 and gather_distributed: + gather_features_b = torch.cat(GatherLayer.apply(features_b)) + gather_labels_a = torch.arange(bs_a, device='cuda') + get_rank() * bs_a + logits_a = logit_scale * features_a @ gather_features_b.t() + else: + gather_labels_a = torch.arange(bs_a, device='cuda') + logits_a = logit_scale * features_a @ features_b.t() + loss_a = F.cross_entropy(logits_a, gather_labels_a) + return loss_a diff --git a/imaginaire/losses/kl.py b/imaginaire/losses/kl.py new file mode 100644 index 0000000000000000000000000000000000000000..d73db46d16c752a1041bf37ce830af9f76f740e7 --- /dev/null +++ b/imaginaire/losses/kl.py @@ -0,0 +1,23 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import torch +import torch.nn as nn + + +class GaussianKLLoss(nn.Module): + r"""Compute KL loss in VAE for Gaussian distributions""" + def __init__(self): + super(GaussianKLLoss, self).__init__() + + def forward(self, mu, logvar=None): + r"""Compute loss + + Args: + mu (tensor): mean + logvar (tensor): logarithm of variance + """ + if logvar is None: + logvar = torch.zeros_like(mu) + return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) diff --git a/imaginaire/losses/perceptual.py b/imaginaire/losses/perceptual.py new file mode 100644 index 0000000000000000000000000000000000000000..424656fa09b65333e4fa28cae2de7114de69ebfa --- /dev/null +++ b/imaginaire/losses/perceptual.py @@ -0,0 +1,395 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import torch +import torch.nn.functional as F +import torchvision +from torch import nn, distributed as dist + +from imaginaire.losses.info_nce import InfoNCELoss +from imaginaire.utils.distributed import master_only_print as print, \ + is_local_master +from imaginaire.utils.misc import apply_imagenet_normalization, to_float + + +class PerceptualLoss(nn.Module): + r"""Perceptual loss initialization. + + Args: + network (str) : The name of the loss network: 'vgg16' | 'vgg19'. + layers (str or list of str) : The layers used to compute the loss. + weights (float or list of float : The loss weights of each layer. + criterion (str): The type of distance function: 'l1' | 'l2'. + resize (bool) : If ``True``, resize the input images to 224x224. + resize_mode (str): Algorithm used for resizing. + num_scales (int): The loss will be evaluated at original size and + this many times downsampled sizes. + per_sample_weight (bool): Output loss for individual samples in the + batch instead of mean loss. 
+ """ + + def __init__(self, network='vgg19', layers='relu_4_1', weights=None, + criterion='l1', resize=False, resize_mode='bilinear', + num_scales=1, per_sample_weight=False, + info_nce_temperature=0.07, + info_nce_gather_distributed=True, + info_nce_learn_temperature=True, + info_nce_flatten=True): + super().__init__() + if isinstance(layers, str): + layers = [layers] + if weights is None: + weights = [1.] * len(layers) + elif isinstance(layers, float) or isinstance(layers, int): + weights = [weights] + + if dist.is_initialized() and not is_local_master(): + # Make sure only the first process in distributed training downloads + # the model, and the others will use the cache + # noinspection PyUnresolvedReferences + torch.distributed.barrier() + + assert len(layers) == len(weights), \ + 'The number of layers (%s) must be equal to ' \ + 'the number of weights (%s).' % (len(layers), len(weights)) + if network == 'vgg19': + self.model = _vgg19(layers) + elif network == 'vgg16': + self.model = _vgg16(layers) + elif network == 'alexnet': + self.model = _alexnet(layers) + elif network == 'inception_v3': + self.model = _inception_v3(layers) + elif network == 'resnet50': + self.model = _resnet50(layers) + elif network == 'robust_resnet50': + self.model = _robust_resnet50(layers) + elif network == 'vgg_face_dag': + self.model = _vgg_face_dag(layers) + else: + raise ValueError('Network %s is not recognized' % network) + + if dist.is_initialized() and is_local_master(): + # Make sure only the first process in distributed training downloads + # the model, and the others will use the cache + # noinspection PyUnresolvedReferences + torch.distributed.barrier() + + self.num_scales = num_scales + self.layers = layers + self.weights = weights + reduction = 'mean' if not per_sample_weight else 'none' + if criterion == 'l1': + self.criterion = nn.L1Loss(reduction=reduction) + elif criterion == 'l2' or criterion == 'mse': + self.criterion = nn.MSELoss(reduction=reduction) + elif criterion == 'info_nce': + self.criterion = InfoNCELoss( + temperature=info_nce_temperature, + gather_distributed=info_nce_gather_distributed, + learn_temperature=info_nce_learn_temperature, + flatten=info_nce_flatten, + single_direction=True + ) + else: + raise ValueError('Criterion %s is not recognized' % criterion) + self.resize = resize + self.resize_mode = resize_mode + print('Perceptual loss:') + print('\tMode: {}'.format(network)) + + def forward(self, inp, target, per_sample_weights=None): + r"""Perceptual loss forward. + + Args: + inp (4D tensor) : Input tensor. + target (4D tensor) : Ground truth tensor, same shape as the input. + per_sample_weight (bool): Output loss for individual samples in the + batch instead of mean loss. + Returns: + (scalar tensor) : The perceptual loss. + """ + if not torch.is_autocast_enabled(): + inp, target = to_float([inp, target]) + + # Perceptual loss should operate in eval mode by default. + self.model.eval() + inp, target = apply_imagenet_normalization(inp), apply_imagenet_normalization(target) + if self.resize: + inp = F.interpolate(inp, mode=self.resize_mode, size=(224, 224), align_corners=False) + target = F.interpolate(target, mode=self.resize_mode, size=(224, 224), align_corners=False) + + # Evaluate perceptual loss at each scale. 
+ loss = 0 + for scale in range(self.num_scales): + input_features, target_features = self.model(inp), self.model(target) + + for layer, weight in zip(self.layers, self.weights): + # Example per-layer VGG19 loss values after applying + # [0.03125, 0.0625, 0.125, 0.25, 1.0] weighting. + # relu_1_1, 0.014698 + # relu_2_1, 0.085817 + # relu_3_1, 0.349977 + # relu_4_1, 0.544188 + # relu_5_1, 0.906261 + # print('%s, %f' % ( + # layer, + # weight * self.criterion( + # input_features[layer], + # target_features[ + # layer].detach()).item())) + l_tmp = self.criterion(input_features[layer], target_features[layer].detach()) + if per_sample_weights is not None: + l_tmp = l_tmp.mean(1).mean(1).mean(1) + loss += weight * l_tmp + # Downsample the input and target. + if scale != self.num_scales - 1: + inp = F.interpolate( + inp, mode=self.resize_mode, scale_factor=0.5, + align_corners=False, recompute_scale_factor=True) + target = F.interpolate( + target, mode=self.resize_mode, scale_factor=0.5, + align_corners=False, recompute_scale_factor=True) + + return loss.float() + + +class _PerceptualNetwork(nn.Module): + r"""The network that extracts features to compute the perceptual loss. + + Args: + network (nn.Sequential) : The network that extracts features. + layer_name_mapping (dict) : The dictionary that + maps a layer's index to its name. + layers (list of str): The list of layer names that we are using. + """ + + def __init__(self, network, layer_name_mapping, layers): + super().__init__() + assert isinstance(network, nn.Sequential), \ + 'The network needs to be of type "nn.Sequential".' + self.network = network + self.layer_name_mapping = layer_name_mapping + self.layers = layers + for param in self.parameters(): + param.requires_grad = False + + def forward(self, x): + r"""Extract perceptual features.""" + output = {} + for i, layer in enumerate(self.network): + x = layer(x) + layer_name = self.layer_name_mapping.get(i, None) + if layer_name in self.layers: + # If the current layer is used by the perceptual loss. 
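+                # Store the activation under its readable name (e.g. 'relu_4_1')
+                # so the loss can index features by layer name.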
+ output[layer_name] = x + return output + + +def _vgg19(layers): + r"""Get vgg19 layers""" + vgg = torchvision.models.vgg19(pretrained=True) + # network = vgg.features + network = torch.nn.Sequential(*(list(vgg.features) + [vgg.avgpool] + [nn.Flatten()] + list(vgg.classifier))) + layer_name_mapping = {1: 'relu_1_1', + 3: 'relu_1_2', + 6: 'relu_2_1', + 8: 'relu_2_2', + 11: 'relu_3_1', + 13: 'relu_3_2', + 15: 'relu_3_3', + 17: 'relu_3_4', + 20: 'relu_4_1', + 22: 'relu_4_2', + 24: 'relu_4_3', + 26: 'relu_4_4', + 29: 'relu_5_1', + 31: 'relu_5_2', + 33: 'relu_5_3', + 35: 'relu_5_4', + 36: 'pool_5', + 42: 'fc_2'} + return _PerceptualNetwork(network, layer_name_mapping, layers) + + +def _vgg16(layers): + r"""Get vgg16 layers""" + network = torchvision.models.vgg16(pretrained=True).features + layer_name_mapping = {1: 'relu_1_1', + 3: 'relu_1_2', + 6: 'relu_2_1', + 8: 'relu_2_2', + 11: 'relu_3_1', + 13: 'relu_3_2', + 15: 'relu_3_3', + 18: 'relu_4_1', + 20: 'relu_4_2', + 22: 'relu_4_3', + 25: 'relu_5_1'} + return _PerceptualNetwork(network, layer_name_mapping, layers) + + +def _alexnet(layers): + r"""Get alexnet layers""" + network = torchvision.models.alexnet(pretrained=True).features + layer_name_mapping = {0: 'conv_1', + 1: 'relu_1', + 3: 'conv_2', + 4: 'relu_2', + 6: 'conv_3', + 7: 'relu_3', + 8: 'conv_4', + 9: 'relu_4', + 10: 'conv_5', + 11: 'relu_5'} + return _PerceptualNetwork(network, layer_name_mapping, layers) + + +def _inception_v3(layers): + r"""Get inception v3 layers""" + inception = torchvision.models.inception_v3(pretrained=True) + network = nn.Sequential(inception.Conv2d_1a_3x3, + inception.Conv2d_2a_3x3, + inception.Conv2d_2b_3x3, + nn.MaxPool2d(kernel_size=3, stride=2), + inception.Conv2d_3b_1x1, + inception.Conv2d_4a_3x3, + nn.MaxPool2d(kernel_size=3, stride=2), + inception.Mixed_5b, + inception.Mixed_5c, + inception.Mixed_5d, + inception.Mixed_6a, + inception.Mixed_6b, + inception.Mixed_6c, + inception.Mixed_6d, + inception.Mixed_6e, + inception.Mixed_7a, + inception.Mixed_7b, + inception.Mixed_7c, + nn.AdaptiveAvgPool2d(output_size=(1, 1))) + layer_name_mapping = {3: 'pool_1', + 6: 'pool_2', + 14: 'mixed_6e', + 18: 'pool_3'} + return _PerceptualNetwork(network, layer_name_mapping, layers) + + +def _resnet50(layers): + r"""Get resnet50 layers""" + resnet50 = torchvision.models.resnet50(pretrained=True) + network = nn.Sequential(resnet50.conv1, + resnet50.bn1, + resnet50.relu, + resnet50.maxpool, + resnet50.layer1, + resnet50.layer2, + resnet50.layer3, + resnet50.layer4, + resnet50.avgpool) + layer_name_mapping = {4: 'layer_1', + 5: 'layer_2', + 6: 'layer_3', + 7: 'layer_4'} + return _PerceptualNetwork(network, layer_name_mapping, layers) + + +def _robust_resnet50(layers): + r"""Get robust resnet50 layers""" + resnet50 = torchvision.models.resnet50(pretrained=False) + state_dict = torch.utils.model_zoo.load_url( + 'http://andrewilyas.com/ImageNet.pt') + new_state_dict = {} + for k, v in state_dict['model'].items(): + if k.startswith('module.model.'): + new_state_dict[k[13:]] = v + resnet50.load_state_dict(new_state_dict) + network = nn.Sequential(resnet50.conv1, + resnet50.bn1, + resnet50.relu, + resnet50.maxpool, + resnet50.layer1, + resnet50.layer2, + resnet50.layer3, + resnet50.layer4, + resnet50.avgpool) + layer_name_mapping = {4: 'layer_1', + 5: 'layer_2', + 6: 'layer_3', + 7: 'layer_4'} + return _PerceptualNetwork(network, layer_name_mapping, layers) + + +def _vgg_face_dag(layers): + network = torchvision.models.vgg16(num_classes=2622) + state_dict = 
torch.utils.model_zoo.load_url( + 'http://www.robots.ox.ac.uk/~albanie/models/pytorch-mcn/' + 'vgg_face_dag.pth') + feature_layer_name_mapping = { + 0: 'conv1_1', + 2: 'conv1_2', + 5: 'conv2_1', + 7: 'conv2_2', + 10: 'conv3_1', + 12: 'conv3_2', + 14: 'conv3_3', + 17: 'conv4_1', + 19: 'conv4_2', + 21: 'conv4_3', + 24: 'conv5_1', + 26: 'conv5_2', + 28: 'conv5_3'} + new_state_dict = {} + for k, v in feature_layer_name_mapping.items(): + new_state_dict['features.' + str(k) + '.weight'] = \ + state_dict[v + '.weight'] + new_state_dict['features.' + str(k) + '.bias'] = \ + state_dict[v + '.bias'] + + classifier_layer_name_mapping = { + 0: 'fc6', + 3: 'fc7', + 6: 'fc8'} + for k, v in classifier_layer_name_mapping.items(): + new_state_dict['classifier.' + str(k) + '.weight'] = \ + state_dict[v + '.weight'] + new_state_dict['classifier.' + str(k) + '.bias'] = \ + state_dict[v + '.bias'] + + network.load_state_dict(new_state_dict) + + class Flatten(nn.Module): + def forward(self, x): + return x.view(x.shape[0], -1) + + layer_name_mapping = { + 0: 'conv_1_1', + 1: 'relu_1_1', + 2: 'conv_1_2', + 5: 'conv_2_1', # 1/2 + 6: 'relu_2_1', + 7: 'conv_2_2', + 10: 'conv_3_1', # 1/4 + 11: 'relu_3_1', + 12: 'conv_3_2', + 14: 'conv_3_3', + 17: 'conv_4_1', # 1/8 + 18: 'relu_4_1', + 19: 'conv_4_2', + 21: 'conv_4_3', + 24: 'conv_5_1', # 1/16 + 25: 'relu_5_1', + 26: 'conv_5_2', + 28: 'conv_5_3', + 33: 'fc6', + 36: 'fc7', + 39: 'fc8' + } + seq_layers = [] + for feature in network.features: + seq_layers += [feature] + seq_layers += [network.avgpool, Flatten()] + for classifier in network.classifier: + seq_layers += [classifier] + network = nn.Sequential(*seq_layers) + return _PerceptualNetwork(network, layer_name_mapping, layers) diff --git a/imaginaire/losses/weighted_mse.py b/imaginaire/losses/weighted_mse.py new file mode 100644 index 0000000000000000000000000000000000000000..b4e49989a5c3ee8576dcf4dea8a98a16c1911cc9 --- /dev/null +++ b/imaginaire/losses/weighted_mse.py @@ -0,0 +1,28 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import torch +import torch.nn as nn + + +class WeightedMSELoss(nn.Module): + r"""Compute Weighted MSE loss""" + def __init__(self, reduction='mean'): + super(WeightedMSELoss, self).__init__() + self.reduction = reduction + + def forward(self, input, target, weight): + r"""Return weighted MSE Loss. + Args: + input (tensor): + target (tensor): + weight (tensor): + Returns: + (tensor): Loss value. + """ + if self.reduction == 'mean': + loss = torch.mean(weight * (input - target) ** 2) + else: + loss = torch.sum(weight * (input - target) ** 2) + return loss diff --git a/imaginaire/model_utils/__init__.py b/imaginaire/model_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..13acefe2181136b1629ec31f9d122fb46bf26780 --- /dev/null +++ b/imaginaire/model_utils/__init__.py @@ -0,0 +1,4 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. 
+# To view a copy of this license, check out LICENSE.md diff --git a/imaginaire/model_utils/gancraft/camctl.py b/imaginaire/model_utils/gancraft/camctl.py new file mode 100644 index 0000000000000000000000000000000000000000..120ab38672509e5c0c42a3308418831717ed1937 --- /dev/null +++ b/imaginaire/model_utils/gancraft/camctl.py @@ -0,0 +1,679 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import numpy as np +import torch + + +class EvalCameraController: + def __init__(self, voxel, maxstep=128, pattern=0, cam_ang=73, smooth_decay_multiplier=1.0): + self.voxel = voxel + self.maxstep = maxstep + self.camera_poses = [] # ori, dir, up, f + circle = torch.linspace(0, 2*np.pi, steps=maxstep) + size = min(voxel.voxel_t.size(1), voxel.voxel_t.size(2)) / 2 + # Shrink the circle a bit. + shift = size * 0.2 + size = size * 0.8 + + if pattern == 0: + height_history = [] + # Calculate smooth height. + for i in range(maxstep): + farpoint = torch.tensor([ + 70, + torch.sin(circle[i])*size + voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i])*size + voxel.voxel_t.size(2)/2 + shift]) + height_history.append(self._get_height(farpoint[1], farpoint[2], farpoint[0])) + + # Filtfilt + height_history = self.filtfilt(height_history, decay=0.2*smooth_decay_multiplier) + + for i in range(maxstep): + farpoint = torch.tensor([ + 70, + torch.sin(circle[i])*size + voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i])*size + voxel.voxel_t.size(2)/2 + shift]) + + farpoint[0] = height_history[i] + + nearpoint = torch.tensor([ + 60, + torch.sin(circle[i]+0.5*np.pi)*size*0.5 + voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i]+0.5*np.pi)*size*0.5 + voxel.voxel_t.size(2)/2 + shift]) + cam_ori = self.voxel.world2local(farpoint) + cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True) + cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True) + cam_f = 0.5/np.tan(np.deg2rad(cam_ang/2)) # about 24mm fov + + self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f)) + + elif pattern == 1: + zoom = torch.linspace(1.0, 0.25, steps=maxstep) + height_history = [] + for i in range(maxstep): + farpoint = torch.tensor([ + 90, + torch.sin(circle[i])*size + voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i])*size + voxel.voxel_t.size(2)/2 + shift]) + + height_history.append(self._get_height(farpoint[1], farpoint[2], farpoint[0])) + + height_history = self.filtfilt(height_history, decay=0.2*smooth_decay_multiplier) + + for i in range(maxstep): + farpoint = torch.tensor([ + 90, + torch.sin(circle[i])*size + voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i])*size + voxel.voxel_t.size(2)/2 + shift]) + + farpoint[0] = height_history[i] + + nearpoint = torch.tensor([ + 60, + torch.sin(circle[i]-0.3*np.pi)*size*0.3 + voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i]-0.3*np.pi)*size*0.3 + voxel.voxel_t.size(2)/2 + shift]) + cam_ori = self.voxel.world2local(farpoint) + cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True) + cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True) + cam_f = 0.5/np.tan(np.deg2rad(cam_ang/2)*zoom[i]) # about 24mm fov + + self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f)) + + elif pattern == 2: + move = torch.linspace(1.0, 0.2, steps=maxstep) + height_history = [] + for i in range(maxstep): + farpoint = torch.tensor([ + 90, + 
torch.sin(circle[i])*size*move[i] + voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i])*size*move[i] + voxel.voxel_t.size(2)/2 + shift]) + + height_history.append(self._get_height(farpoint[1], farpoint[2], farpoint[0])) + + height_history = self.filtfilt(height_history, decay=0.2*smooth_decay_multiplier) + + for i in range(maxstep): + farpoint = torch.tensor([ + 90, + torch.sin(circle[i])*size*move[i] + voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i])*size*move[i] + voxel.voxel_t.size(2)/2 + shift]) + + farpoint[0] = height_history[i] + + nearpoint = torch.tensor([ + 60, + torch.sin(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(2)/2 + shift]) + cam_ori = self.voxel.world2local(farpoint) + cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True) + cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True) + cam_f = 0.5/np.tan(np.deg2rad(cam_ang/2)) # about 24mm fov + + self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f)) + + elif pattern == 3: + move = torch.linspace(0.75, 0.2, steps=maxstep) + height_history = [] + for i in range(maxstep): + farpoint = torch.tensor([ + 70, + torch.sin(-circle[i])*size*move[i] + voxel.voxel_t.size(1)/2 + shift, + torch.cos(-circle[i])*size*move[i] + voxel.voxel_t.size(2)/2 + shift]) + + height_history.append(self._get_height(farpoint[1], farpoint[2], farpoint[0])) + + height_history = self.filtfilt(height_history, decay=0.2*smooth_decay_multiplier) + + for i in range(maxstep): + farpoint = torch.tensor([ + 70, + torch.sin(-circle[i])*size*move[i] + voxel.voxel_t.size(1)/2 + shift, + torch.cos(-circle[i])*size*move[i] + voxel.voxel_t.size(2)/2 + shift]) + + farpoint[0] = height_history[i] + + nearpoint = torch.tensor([ + 60, + torch.sin(-circle[i]-0.4*np.pi)*size*0.9*move[i] + voxel.voxel_t.size(1)/2 + shift, + torch.cos(-circle[i]-0.4*np.pi)*size*0.9*move[i] + voxel.voxel_t.size(2)/2 + shift]) + cam_ori = self.voxel.world2local(farpoint) + cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True) + cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True) + cam_f = 0.5/np.tan(np.deg2rad(cam_ang/2)) # about 24mm fov + + self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f)) + + elif pattern == 4: + move = torch.linspace(1.0, 0.5, steps=maxstep) + height_history = [] + for i in range(maxstep): + farpoint = torch.tensor([ + 90, + torch.sin(circle[i])*size*move[i] + voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i])*size*move[i] + voxel.voxel_t.size(2)/2 + shift]) + + height_history.append(self._get_height(farpoint[1], farpoint[2], farpoint[0])) + + height_history = self.filtfilt(height_history, decay=0.2*smooth_decay_multiplier) + + for i in range(maxstep): + farpoint = torch.tensor([ + 90, + torch.sin(circle[i])*size*move[i] + voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i])*size*move[i] + voxel.voxel_t.size(2)/2 + shift]) + + farpoint[0] = height_history[i] + + nearpoint = torch.tensor([ + 60, + torch.sin(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(2)/2 + shift]) + cam_ori = self.voxel.world2local(farpoint) + cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True) + cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True) + cam_f = 0.5/np.tan(np.deg2rad(cam_ang/2)) # about 24mm fov + + 
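# Note: each entry appended below is a (cam_ori, cam_dir, cam_up, cam_f) tuple; the origin, view
+                # direction and up vector are expressed in the voxel-local frame via world2local, and
+                # cam_f encodes the field of view as 0.5 / tan(cam_ang / 2).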
self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f)) + + # look outward + elif pattern == 5: + move = torch.linspace(1.0, 0.5, steps=maxstep) + height_history = [] + for i in range(maxstep): + nearpoint = torch.tensor([ + 60, + torch.sin(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(2)/2 + shift]) + + height_history.append(self._get_height(nearpoint[1], nearpoint[2], nearpoint[0])) + + height_history = self.filtfilt(height_history, decay=0.2*smooth_decay_multiplier) + + for i in range(maxstep): + nearpoint = torch.tensor([ + 60, + torch.sin(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(2)/2 + shift]) + + nearpoint[0] = height_history[i] + + farpoint = torch.tensor([ + 60, + torch.sin(circle[i])*size*move[i] + voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i])*size*move[i] + voxel.voxel_t.size(2)/2 + shift]) + + cam_ori = self.voxel.world2local(nearpoint) + cam_dir = self.voxel.world2local(farpoint - nearpoint, is_vec=True) + cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True) + cam_f = 0.5/np.tan(np.deg2rad(cam_ang/2)) # about 24mm fov + + self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f)) + # Rise + elif pattern == 6: + shift = 0 + lift = torch.linspace(0.0, 200.0, steps=maxstep) + zoom = torch.linspace(0.8, 1.6, steps=maxstep) + for i in range(maxstep): + farpoint = torch.tensor([ + 80+lift[i], + torch.sin(circle[i]/4)*size*0.2 + voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i]/4)*size*0.2 + voxel.voxel_t.size(2)/2 + shift]) + + farpoint[0] = self._get_height(farpoint[1], farpoint[2], farpoint[0]) + + nearpoint = torch.tensor([ + 65, + torch.sin(circle[i]/4+0.5*np.pi)*size*0.1 + voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i]/4+0.5*np.pi)*size*0.1 + voxel.voxel_t.size(2)/2 + shift]) + cam_ori = self.voxel.world2local(farpoint) + cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True) + cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True) + cam_f = 0.5/np.tan(np.deg2rad(73/2)*zoom[i]) # about 24mm fov + + self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f)) + # 45deg + elif pattern == 7: + rad = torch.tensor([np.deg2rad(45).astype(np.float32)]) + size = 1536 + for i in range(maxstep): + farpoint = torch.tensor([ + 61+size, + torch.sin(rad)*size + voxel.voxel_t.size(1)/2, + torch.cos(rad)*size + voxel.voxel_t.size(2)/2]) + + nearpoint = torch.tensor([ + 61, + voxel.voxel_t.size(1)/2, + voxel.voxel_t.size(2)/2]) + cam_ori = self.voxel.world2local(farpoint) + cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True) + cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True) + cam_f = 0.5/np.tan(np.deg2rad(19.5/2)) # about 50mm fov + + self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f)) + + elif pattern == 8: + size = self.voxel.voxel_t.size(1) // 2 + for i in range(maxstep): + farpoint = torch.tensor([ + 300, + 0*size + voxel.voxel_t.size(1)//2, + -1*size + voxel.voxel_t.size(2)/2 + size // maxstep * (i - maxstep // 4)]) + nearpoint = torch.tensor([ + 120, + 0*size*0.5 + voxel.voxel_t.size(1)//2, + -1*size*0.5 + voxel.voxel_t.size(2)/2 + size // maxstep * (i - maxstep // 4)]) + cam_ori = self.voxel.world2local(farpoint) + cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True) + cam_up = 
self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True) + cam_f = 0.5/np.tan(np.deg2rad(cam_ang/2)) # about 24mm fov + + self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f)) + + elif pattern == 9: + size = self.voxel.voxel_t.size(2) // 2 + for i in range(maxstep): + farpoint = torch.tensor([ + 140, + voxel.voxel_t.size(1)//2, + -size // 4 + size * 8 // maxstep * i] + , dtype=torch.float32) + nearpoint = torch.tensor([ + 100, + voxel.voxel_t.size(1)//2, + size * 8 // maxstep * i] + , dtype=torch.float32) + cam_ori = self.voxel.world2local(farpoint) + cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True) + cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True) + cam_f = 0.5/np.tan(np.deg2rad(cam_ang/2)) # about 24mm fov + + self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f)) + + + def _get_height(self, loc0, loc1, minheight): + loc0 = int(loc0) + loc1 = int(loc1) + height = minheight + for dx in range(-3, 4): + for dy in range(-3, 4): + if (loc0+dx) < 0 or (loc0+dx) >= self.voxel.heightmap.shape[0] or (loc1+dy) < 0 or \ + (loc1+dy) >= self.voxel.heightmap.shape[1]: + height = max(height, minheight) + else: + height = max(height, self.voxel.heightmap[loc0+dx, loc1+dy] + 2) + return height + + def filtfilt(self, height_history, decay=0.2): + # Filtfilt + height_history2 = [] + maxstep = len(height_history) + prev_height = height_history[0] + for i in range(maxstep): + prev_height = prev_height - decay + if prev_height < height_history[i]: + prev_height = height_history[i] + height_history2.append(prev_height) + prev_height = height_history[-1] + for i in range(maxstep-1, -1, -1): + prev_height = prev_height - decay + if prev_height < height_history[i]: + prev_height = height_history[i] + height_history2[i] = max(prev_height, height_history2[i]) + return height_history2 + + def __len__(self): + return len(self.camera_poses) + + def __getitem__(self, idx): + return self.camera_poses[idx] + + +class TourCameraController: + def __init__(self, voxel, maxstep=128): + self.voxel = voxel + self.maxstep = maxstep + self.camera_poses = [] # ori, dir, up, f + circle = torch.linspace(0, 2*np.pi, steps=maxstep//4) + size = min(voxel.voxel_t.size(1), voxel.voxel_t.size(2)) / 2 + # Shrink the circle a bit + shift = size * 0.2 + size = size * 0.8 + + for i in range(maxstep//4): + farpoint = torch.tensor([ + 70, + torch.sin(circle[i])*size + voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i])*size + voxel.voxel_t.size(2)/2 + shift]) + + farpoint[0] = self._get_height(farpoint[1], farpoint[2], farpoint[0]) + + nearpoint = torch.tensor([ + 60, + torch.sin(circle[i]+0.5*np.pi)*size*0.5 + voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i]+0.5*np.pi)*size*0.5 + voxel.voxel_t.size(2)/2 + shift]) + cam_ori = self.voxel.world2local(farpoint) + cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True) + cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True) + cam_f = 0.5/np.tan(np.deg2rad(73/2)) # about 24mm fov + + self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f)) + + zoom = torch.linspace(1.0, 0.25, steps=maxstep//4) + for i in range(maxstep//4): + farpoint = torch.tensor([ + 90, + torch.sin(circle[i])*size + voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i])*size + voxel.voxel_t.size(2)/2 + shift]) + + farpoint[0] = self._get_height(farpoint[1], farpoint[2], farpoint[0]) + + nearpoint = torch.tensor([ + 60, + torch.sin(circle[i]-0.3*np.pi)*size*0.3 + 
voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i]-0.3*np.pi)*size*0.3 + voxel.voxel_t.size(2)/2 + shift]) + cam_ori = self.voxel.world2local(farpoint) + cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True) + cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True) + cam_f = 0.5/np.tan(np.deg2rad(73/2)*zoom[i]) # about 24mm fov + + self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f)) + + move = torch.linspace(1.0, 0.2, steps=maxstep//4) + for i in range(maxstep//4): + farpoint = torch.tensor([ + 90, + torch.sin(circle[i])*size*move[i] + voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i])*size*move[i] + voxel.voxel_t.size(2)/2 + shift]) + + farpoint[0] = self._get_height(farpoint[1], farpoint[2], farpoint[0]) + + nearpoint = torch.tensor([ + 60, + torch.sin(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(2)/2 + shift]) + cam_ori = self.voxel.world2local(farpoint) + cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True) + cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True) + cam_f = 0.5/np.tan(np.deg2rad(73/2)) # about 24mm fov + + self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f)) + + lift = torch.linspace(0.0, 200.0, steps=maxstep//4) + zoom = torch.linspace(0.6, 1.2, steps=maxstep//4) + for i in range(maxstep//4): + farpoint = torch.tensor([ + 80+lift[i], + torch.sin(circle[i])*size*0.2 + voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i])*size*0.2 + voxel.voxel_t.size(2)/2 + shift]) + + farpoint[0] = self._get_height(farpoint[1], farpoint[2], farpoint[0]) + + nearpoint = torch.tensor([ + 60, + torch.sin(circle[i]+0.5*np.pi)*size*0.1 + voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i]+0.5*np.pi)*size*0.1 + voxel.voxel_t.size(2)/2 + shift]) + cam_ori = self.voxel.world2local(farpoint) + cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True) + cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True) + cam_f = 0.5/np.tan(np.deg2rad(73/2)*zoom[i]) # about 24mm fov + + self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f)) + + def _get_height(self, loc0, loc1, minheight): + loc0 = int(loc0) + loc1 = int(loc1) + height = minheight + for dx in range(-3, 4): + for dy in range(-3, 4): + if (loc0+dx) < 0 or (loc0+dx) >= self.voxel.heightmap.shape[0] or (loc1+dy) < 0 or \ + (loc1+dy) >= self.voxel.heightmap.shape[1]: + height = max(height, minheight) + else: + height = max(height, self.voxel.heightmap[loc0+dx, loc1+dy] + 2) + return height + + def __len__(self): + return len(self.camera_poses) + + def __getitem__(self, idx): + return self.camera_poses[idx] + + +def rand_camera_pose_birdseye(voxel, border=128): + r"""Generating random camera pose in the upper hemisphere, in the format of origin-direction-up + Assuming [Y X Z] coordinate. Y is negative gravity direction. + The camera pose is converted into the voxel coordinate system so that it can be used directly for rendering + 1. Uniformly sample a point on the upper hemisphere of a unit sphere, as cam_ori. + 2. Set cam_dir to be from cam_ori to the origin + 3. cam_up is always pointing towards sky + 4. 
move cam_ori to random place according to voxel size + """ + cam_dir = torch.randn(3, dtype=torch.float32) + cam_dir = cam_dir / torch.sqrt(torch.sum(cam_dir*cam_dir)) + cam_dir[0] = -torch.abs(cam_dir[0]) + cam_up = torch.tensor([1, 0, 0], dtype=torch.float32) + + # generate camera lookat target + r = np.random.rand(2) + r[0] *= voxel.voxel_t.size(1)-border-border + r[1] *= voxel.voxel_t.size(2)-border-border + r = r + border + y = voxel.heightmap[int(r[0]+0.5), int(r[1]+0.5)] + (np.random.rand(1)-0.5) * 5 + cam_target = torch.tensor([y, r[0], r[1]], dtype=torch.float32) + cam_ori = cam_target - cam_dir * (np.random.rand(1).item() * 100) + cam_ori[0] = max(voxel.heightmap[int(cam_ori[1]+0.5), int(cam_ori[2]+0.5)]+2, cam_ori[0]) + # Translate to voxel coordinate + cam_ori = voxel.world2local(cam_ori) + cam_dir = voxel.world2local(cam_dir, is_vec=True) + cam_up = voxel.world2local(cam_up, is_vec=True) + + return cam_ori, cam_dir, cam_up + + +def get_neighbor_height(heightmap, loc0, loc1, minheight, neighbor_size=7): + loc0 = int(loc0) + loc1 = int(loc1) + height = 0 + for dx in range(-neighbor_size//2, neighbor_size//2+1): + for dy in range(-neighbor_size//2, neighbor_size//2+1): + if (loc0+dx) < 0 or (loc0+dx) >= heightmap.shape[0] or (loc1+dy) < 0 or (loc1+dy) >= heightmap.shape[1]: + height = max(height, minheight) + else: + height = max(minheight, heightmap[loc0+dx, loc1+dy] + 2) + return height + + +def rand_camera_pose_firstperson(voxel, border=128): + r"""Generating random camera pose in the upper hemisphere, in the format of origin-direction-up + """ + r = np.random.rand(5) + r[0] *= voxel.voxel_t.size(1)-border-border + r[1] *= voxel.voxel_t.size(2)-border-border + r[0] = r[0] + border + r[1] = r[1] + border + + y = get_neighbor_height(voxel.heightmap, r[0], r[1], 0) + np.random.rand(1) * 15 + + cam_ori = torch.tensor([y, r[0], r[1]], dtype=torch.float32) + + rand_ang_h = r[2] * 2 * np.pi + cam_target = torch.tensor([0, cam_ori[1]+np.sin(rand_ang_h)*border*r[4], cam_ori[2] + + np.cos(rand_ang_h)*border*r[4]], dtype=torch.float32) + cam_target[0] = get_neighbor_height(voxel.heightmap, cam_target[1], + cam_target[2], 0, neighbor_size=1) - 2 + r[3] * 10 + + cam_dir = cam_target - cam_ori + + cam_up = torch.tensor([1, 0, 0], dtype=torch.float32) + + cam_ori = voxel.world2local(cam_ori) + cam_dir = voxel.world2local(cam_dir, is_vec=True) + cam_up = voxel.world2local(cam_up, is_vec=True) + + return cam_ori, cam_dir, cam_up + + +def rand_camera_pose_thridperson(voxel, border=96): + r = torch.rand(2) + r[0] *= voxel.voxel_t.size(1) + r[1] *= voxel.voxel_t.size(2) + rand_height = 60 + torch.rand(1) * 40 + rand_height = get_neighbor_height(voxel.heightmap, r[0], r[1], rand_height, neighbor_size=5) + farpoint = torch.tensor([rand_height, r[0], r[1]], dtype=torch.float32) + + r = torch.rand(2) + r[0] *= voxel.voxel_t.size(1) - border - border + r[1] *= voxel.voxel_t.size(2) - border - border + r[0] = r[0] + border + r[1] = r[1] + border + rand_height = get_neighbor_height(voxel.heightmap, r[0], r[1], 65, neighbor_size=1) - 5 + nearpoint = torch.tensor([rand_height, r[0], r[1]], dtype=torch.float32) + + cam_ori = voxel.world2local(farpoint) + cam_dir = voxel.world2local(nearpoint - farpoint, is_vec=True) + cam_up = voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True) + + return cam_ori, cam_dir, cam_up + + +def rand_camera_pose_thridperson2(voxel, border=48): + r = torch.rand(2) + r[0] *= voxel.voxel_t.size(1) - border - border + r[1] *= voxel.voxel_t.size(2) - border - 
border + r[0] = r[0] + border + r[1] = r[1] + border + rand_height = 60 + torch.rand(1) * 40 + rand_height = get_neighbor_height(voxel.heightmap, r[0], r[1], rand_height, neighbor_size=5) + farpoint = torch.tensor([rand_height, r[0], r[1]], dtype=torch.float32) + + r = torch.rand(2) + r[0] *= voxel.voxel_t.size(1) - border - border + r[1] *= voxel.voxel_t.size(2) - border - border + r[0] = r[0] + border + r[1] = r[1] + border + rand_height = get_neighbor_height(voxel.heightmap, r[0], r[1], 65, neighbor_size=1) - 5 + nearpoint = torch.tensor([rand_height, r[0], r[1]], dtype=torch.float32) + + # Random Up vector (tilt a little bit) + # up = torch.randn(3) * 0.05 # cutoff +-0.1, Tan(10deg) = 0.176 + up = torch.randn(3) * 0.02 + up[0] = 1.0 + up = up / up.norm(p=2) + cam_ori = voxel.world2local(farpoint) + cam_dir = voxel.world2local(nearpoint - farpoint, is_vec=True) + cam_up = voxel.world2local(up, is_vec=True) + + return cam_ori, cam_dir, cam_up + + +def rand_camera_pose_thridperson3(voxel, border=64): + r"""Attempting to solve the camera too close to wall problem and the lack of aerial poses.""" + r = torch.rand(2) + r[0] *= voxel.voxel_t.size(1) - border - border + r[1] *= voxel.voxel_t.size(2) - border - border + r[0] = r[0] + border + r[1] = r[1] + border + rand_height = 60 + torch.rand(1) * 40 + if torch.rand(1) > 0.8: + rand_height = 60 + torch.rand(1) * 60 + rand_height = get_neighbor_height(voxel.heightmap, r[0], r[1], rand_height, neighbor_size=7) + farpoint = torch.tensor([rand_height, r[0], r[1]], dtype=torch.float32) + + r = torch.rand(2) + r[0] *= voxel.voxel_t.size(1) - border - border + r[1] *= voxel.voxel_t.size(2) - border - border + r[0] = r[0] + border + r[1] = r[1] + border + rand_height = get_neighbor_height(voxel.heightmap, r[0], r[1], 65, neighbor_size=3) - 5 + nearpoint = torch.tensor([rand_height, r[0], r[1]], dtype=torch.float32) + + # Random Up vector (tilt a little bit) + # up = torch.randn(3) * 0.05 # cutoff +-0.1, Tan(10deg) = 0.176 + up = torch.randn(3) * 0.02 + up[0] = 1.0 + up = up / up.norm(p=2) + # print(up) + cam_ori = voxel.world2local(farpoint) + cam_dir = voxel.world2local(nearpoint - farpoint, is_vec=True) + cam_up = voxel.world2local(up, is_vec=True) + + return cam_ori, cam_dir, cam_up + + +def rand_camera_pose_tour(voxel): + size = min(voxel.voxel_t.size(1), voxel.voxel_t.size(2)) / 2 + center = [voxel.voxel_t.size(1)/2, voxel.voxel_t.size(2)/2] + + rnd = torch.rand(8) + + rnd_deg = torch.rand(1) * 2 * np.pi + far_radius = rnd[0]*0.8+0.2 + far_height = rnd[1]*30 + 60 + farpoint = torch.tensor([ + far_height, + torch.sin(rnd_deg)*size*far_radius + center[0], + torch.cos(rnd_deg)*size*far_radius + center[1]]) + + farpoint[0] = get_neighbor_height(voxel.heightmap, farpoint[1], farpoint[2], farpoint[0], neighbor_size=7) + + near_radius = far_radius * rnd[2] + near_shift_rad = np.pi*(rnd[3]-0.5) + near_height = 60 + rnd[4] * 10 + nearpoint = torch.tensor([ + near_height, + torch.sin(rnd_deg+near_shift_rad)*size*near_radius + center[0], + torch.cos(rnd_deg+near_shift_rad)*size*near_radius + center[1]]) + + # Random Up vector (tilt a little bit) + # up = torch.randn(3) * 0.05 # cutoff +-0.1, Tan(10deg) = 0.176 + up = torch.randn(3) * 0.02 + up[0] = 1.0 + up = up / up.norm(p=2) + cam_ori = voxel.world2local(farpoint) + cam_dir = voxel.world2local(nearpoint - farpoint, is_vec=True) + cam_up = voxel.world2local(up, is_vec=True) + cam_f = 0.5/np.tan(np.deg2rad(73/2)*(rnd[5]*0.75+0.25)) # about 24mm fov + + return cam_ori, cam_dir, cam_up, cam_f + +# Look from 
center to outward + + +def rand_camera_pose_insideout(voxel): + size = min(voxel.voxel_t.size(1), voxel.voxel_t.size(2)) / 2 + center = [voxel.voxel_t.size(1)/2, voxel.voxel_t.size(2)/2] + + rnd = torch.rand(8) + + rnd_deg = torch.rand(1) * 2 * np.pi + far_radius = rnd[0]*0.8+0.2 + far_height = rnd[1]*10 + 60 + farpoint = torch.tensor([ + far_height, + torch.sin(rnd_deg)*size*far_radius + center[0], + torch.cos(rnd_deg)*size*far_radius + center[1]]) + + near_radius = far_radius * rnd[2] + near_shift_rad = np.pi*(rnd[3]-0.5) + near_height = 60 + rnd[4] * 30 + nearpoint = torch.tensor([ + near_height, + torch.sin(rnd_deg+near_shift_rad)*size*near_radius + center[0], + torch.cos(rnd_deg+near_shift_rad)*size*near_radius + center[1]]) + + nearpoint[0] = get_neighbor_height(voxel.heightmap, nearpoint[1], nearpoint[2], nearpoint[0], neighbor_size=7) + + # Random Up vector (tilt a little bit) + # up = torch.randn(3) * 0.05 # cutoff +-0.1, Tan(10deg) = 0.176 + up = torch.randn(3) * 0.02 + up[0] = 1.0 + up = up / up.norm(p=2) + cam_ori = voxel.world2local(nearpoint) + cam_dir = voxel.world2local(farpoint-nearpoint, is_vec=True) + cam_up = voxel.world2local(up, is_vec=True) + cam_f = 0.5/np.tan(np.deg2rad(73/2)*(rnd[5]*0.75+0.25)) # about 24mm fov + + return cam_ori, cam_dir, cam_up, cam_f diff --git a/imaginaire/model_utils/gancraft/gaugan_lbl2col.csv b/imaginaire/model_utils/gancraft/gaugan_lbl2col.csv new file mode 100644 index 0000000000000000000000000000000000000000..ba061b7a5bda98899c8b1f653d5f204fccfa38e4 --- /dev/null +++ b/imaginaire/model_utils/gancraft/gaugan_lbl2col.csv @@ -0,0 +1,182 @@ +person,#00AC0D +bicycle,#012F47 +car,#0275B8 +motorcycle,#03C098 +airplane,#04434F +bus,#05FB29 +train,#06C312 +truck,#076728 +boat,#0809B6 +traffic-light,#09D3CF +fire-hydrant,#0A150B +street-sign,#0BF2A6 +stop-sign,#0C246F +parking-meter,#0D575D +bench,#0E46F9 +bird,#0FD881 +cat,#1058DF +dog,#118C76 +horse,#123A2C +sheep,#13C1D8 +cow,#14E67D +elephant,#152718 +bear,#165743 +zebra,#17AED2 +giraffe,#1858EF +hat,#195103 +backpack,#1AA5EA +umbrella,#1B19CC +shoe,#1C4DE6 +eye-glasses,#1D4823 +handbag,#1E09D6 +tie,#1F94FE +suitcase,#2073BD +frisbee,#21D0C5 +skis,#22F3D7 +snowboard,#23C52B +sports-ball,#24FE20 +kite,#254F0B +baseball-bat,#26AF68 +baseball-glove,#27C0D4 +skateboard,#28528A +surfboard,#2963B6 +tennis-racket,#2AD8EB +bottle,#2BB1A5 +plate,#2CF37D +wine-glass,#2D1D9C +cup,#2E936F +fork,#2F93E8 +knife,#308E02 +spoon,#31A71B +bowl,#3220D3 +banana,#33C1D9 +apple,#340997 +sandwich,#35B935 +orange,#367F33 +broccoli,#3720AE +carrot,#381F94 +hot-dog,#39CAB5 +pizza,#3AF41D +donut,#3B9743 +cake,#3CA323 +chair,#3DFE27 +couch,#3ECB89 +potted-plant,#3F7249 +bed,#40B729 +mirror,#411C97 +dining-table,#422283 +window,#43802E +desk,#4480DA +toilet,#45A4B2 +door,#46356C +tv,#478503 +laptop,#48261F +mouse,#49E809 +remote,#4AF48A +keyboard,#4B111B +cell-phone,#4C4FAD +microwave,#4D84C7 +oven,#4E69A7 +toaster,#4F2A3D +sink,#50BA55 +refrigerator,#511F61 +blender,#52782C +book,#530122 +clock,#5441A2 +vase,#55E758 +scissors,#56A921 +teddy-bear,#573985 +hair-drier,#5823E8 +toothbrush,#5966FF +hair-brush,#5A7724 +banner,#5B0B00 +blanket,#5CAECB +branch,#5D5222 +bridge,#5E5BC5 +building-other,#5F807E +bush,#606E32 +cabinet,#6163FE +cage,#623550 +cardboard,#638CBE +carpet,#647988 +ceiling-other,#65AABD +ceiling-tile,#665481 +cloth,#67CBD1 +clothes,#684470 +clouds,#696969 +counter,#6AC478 +cupboard,#6B2F5B +curtain,#6C7FA8 +desk-stuff,#6DF474 +dirt,#6E6E28 +door-stuff,#6FCCB0 +fence,#706419 +floor-marble,#71B443 
+floor-other,#72E867 +floor-stone,#734EFC +floor-tile,#748F23 +floor-wood,#759472 +flower,#760000 +fog,#77BA1D +food-other,#7817F1 +fruit,#79CF21 +furniture-other,#7A8D92 +grass,#7BC800 +gravel,#7C32C8 +ground-other,#7D3054 +hill,#7EC864 +house,#7F4502 +leaves,#80A945 +light,#81A365 +mat,#82C08C +metal,#835F2C +mirror-stuff,#84C575 +moss,#855EFD +mountain,#869664 +mud,#87716F +napkin,#88B25B +net,#892455 +paper,#8AA2A7 +pavement,#8B3027 +pillow,#8C5DCB +plant,#8DE61E +plastic,#8E629E +platform,#8F2A91 +playingfield,#90CDC6 +railing,#9170C7 +railroad,#92E712 +river,#9364C8 +road,#946E28 +rock,#956432 +roof,#9600B1 +rug,#978A29 +salad,#98725D +sand,#999900 +sea,#9AC6DA +shelf,#9B7FC9 +sky,#9CEEDD +skyscraper,#9DBBF2 +snow,#9E9EAA +solid-other,#9F79DB +stairs,#A06249 +stone,#A1A164 +straw,#A2A3EB +structural,#A3DED1 +table,#A47B69 +tent,#A5C3BA +textile-other,#A65280 +towel,#A7AED6 +tree,#A8C832 +vegetable,#A99410 +wall-brick,#AAD16A +wall-concrete,#AB32A4 +wall-other,#AC9B5E +wall-panel,#AD0E18 +wall-stone,#AE2974 +wall-tile,#AF3ABF +wall-wood,#B0C1C3 +water,#B1C8FF +waterdrops,#B20A88 +window-blind,#B356B8 +window-other,#B42B5B +wood,#B57B00 diff --git a/imaginaire/model_utils/gancraft/gaugan_reduction.csv b/imaginaire/model_utils/gancraft/gaugan_reduction.csv new file mode 100644 index 0000000000000000000000000000000000000000..a49ad38c2f8bddd04a12a08a09b1bbdab5944fc2 --- /dev/null +++ b/imaginaire/model_utils/gancraft/gaugan_reduction.csv @@ -0,0 +1,182 @@ +person,ignore +bicycle,ignore +car,ignore +motorcycle,ignore +airplane,ignore +bus,ignore +train,ignore +truck,ignore +boat,ignore +traffic-light,ignore +fire-hydrant,ignore +street-sign,ignore +stop-sign,ignore +parking-meter,ignore +bench,ignore +bird,ignore +cat,ignore +dog,ignore +horse,ignore +sheep,ignore +cow,ignore +elephant,ignore +bear,ignore +zebra,ignore +giraffe,ignore +hat,ignore +backpack,ignore +umbrella,ignore +shoe,ignore +eye-glasses,ignore +handbag,ignore +tie,ignore +suitcase,ignore +frisbee,ignore +skis,ignore +snowboard,ignore +sports-ball,ignore +kite,ignore +baseball-bat,ignore +baseball-glove,ignore +skateboard,ignore +surfboard,ignore +tennis-racket,ignore +bottle,ignore +plate,ignore +wine-glass,ignore +cup,ignore +fork,ignore +knife,ignore +spoon,ignore +bowl,ignore +banana,ignore +apple,ignore +sandwich,ignore +orange,ignore +broccoli,ignore +carrot,ignore +hot-dog,ignore +pizza,ignore +donut,ignore +cake,ignore +chair,ignore +couch,ignore +potted-plant,ignore +bed,ignore +mirror,ignore +dining-table,ignore +window,ignore +desk,ignore +toilet,ignore +door,ignore +tv,ignore +laptop,ignore +mouse,ignore +remote,ignore +keyboard,ignore +cell-phone,ignore +microwave,ignore +oven,ignore +toaster,ignore +sink,ignore +refrigerator,ignore +blender,ignore +book,ignore +clock,ignore +vase,ignore +scissors,ignore +teddy-bear,ignore +hair-drier,ignore +toothbrush,ignore +hair-brush,ignore +banner,ignore +blanket,ignore +branch,tree +bridge,ignore +building-other,ignore +bush,tree +cabinet,ignore +cage,ignore +cardboard,ignore +carpet,ignore +ceiling-other,ignore +ceiling-tile,ignore +cloth,ignore +clothes,ignore +clouds,sky +counter,ignore +cupboard,ignore +curtain,ignore +desk-stuff,ignore +dirt,dirt +door-stuff,ignore +fence,ignore +floor-marble,ignore +floor-other,ignore +floor-stone,ignore +floor-tile,ignore +floor-wood,ignore +flower,flower +fog,sky +food-other,ignore +fruit,ignore +furniture-other,ignore +grass,grass +gravel,gravel +ground-other,ignore +hill,grass +house,ignore +leaves,tree +light,ignore 
+mat,ignore +metal,ignore +mirror-stuff,ignore +moss,grass +mountain,grass +mud,dirt +napkin,ignore +net,ignore +paper,ignore +pavement,ignore +pillow,ignore +plant,flower +plastic,ignore +platform,ignore +playingfield,ignore +railing,ignore +railroad,ignore +river,water +road,ignore +rock,rock +roof,ignore +rug,ignore +salad,ignore +sand,sand +sea,water +shelf,ignore +sky,sky +skyscraper,ignore +snow,snow +solid-other,ignore +stairs,ignore +stone,stone +straw,grass +structural,ignore +table,ignore +tent,ignore +textile-other,ignore +towel,ignore +tree,tree +vegetable,ignore +wall-brick,ignore +wall-concrete,ignore +wall-other,ignore +wall-panel,ignore +wall-stone,ignore +wall-tile,ignore +wall-wood,ignore +water,water +waterdrops,ignore +window-blind,ignore +window-other,ignore +wood,ignore diff --git a/imaginaire/model_utils/gancraft/id2name_gg.csv b/imaginaire/model_utils/gancraft/id2name_gg.csv new file mode 100644 index 0000000000000000000000000000000000000000..bb52afe4132cdae36494c08dab6ac4982f572386 --- /dev/null +++ b/imaginaire/model_utils/gancraft/id2name_gg.csv @@ -0,0 +1,680 @@ +0,air,0,sky +1,stone,7368816,stone +2,granite,7368816,rock +3,polished_granite,7368816,rock +4,diorite,7368816,rock +5,polished_diorite,7368816,rock +6,andesite,7368816,rock +7,polished_andesite,7368816,rock +8,grass_block,8368696,grass +9,dirt,9923917,dirt +10,coarse_dirt,9923917,dirt +11,podzol,9923917,dirt +12,cobblestone,7368816,stone +13,oak_planks,9402184,wood +14,spruce_planks,9402184,wood +15,birch_planks,9402184,wood +16,jungle_planks,9402184,wood +17,acacia_planks,9402184,wood +18,dark_oak_planks,9402184,wood +19,oak_sapling,31744,plant +20,spruce_sapling,31744,plant +21,birch_sapling,31744,plant +22,jungle_sapling,31744,plant +23,acacia_sapling,31744,plant +24,dark_oak_sapling,31744,plant +25,bedrock,7368816,rock +26,water,4210943,water +27,lava,16711680, +28,sand,16247203,sand +29,red_sand,16247203,sand +30,gravel,16247203,gravel +31,gold_ore,7368816,rock +32,iron_ore,7368816,rock +33,coal_ore,7368816,rock +34,oak_log,9402184,tree +35,spruce_log,9402184,tree +36,birch_log,9402184,tree +37,jungle_log,9402184,tree +38,acacia_log,9402184,tree +39,dark_oak_log,9402184,tree +40,stripped_spruce_log,9402184,wood +41,stripped_birch_log,9402184,wood +42,stripped_jungle_log,9402184,wood +43,stripped_acacia_log,9402184,wood +44,stripped_dark_oak_log,9402184,wood +45,stripped_oak_log,9402184,wood +46,oak_wood,9402184,wood +47,spruce_wood,9402184,wood +48,birch_wood,9402184,wood +49,jungle_wood,9402184,wood +50,acacia_wood,9402184,wood +51,dark_oak_wood,9402184,wood +52,stripped_oak_wood,9402184,wood +53,stripped_spruce_wood,9402184,wood +54,stripped_birch_wood,9402184,wood +55,stripped_jungle_wood,9402184,wood +56,stripped_acacia_wood,9402184,wood +57,stripped_dark_oak_wood,9402184,wood +58,oak_leaves,31744,tree +59,spruce_leaves,31744,tree +60,birch_leaves,31744,tree +61,jungle_leaves,31744,tree +62,acacia_leaves,31744,tree +63,dark_oak_leaves,31744,tree +64,sponge,15066419, +65,wet_sponge,15066419, +66,glass,0, +67,lapis_ore,7368816, +68,lapis_block,10987431, +69,dispenser,7368816, +70,sandstone,7368816,sand +71,chiseled_sandstone,7368816,sand +72,cut_sandstone,7368816,sand +73,note_block,9402184, +74,white_bed,13092807, +75,orange_bed,13092807, +76,magenta_bed,13092807, +77,light_blue_bed,13092807, +78,yellow_bed,13092807, +79,lime_bed,13092807, +80,pink_bed,13092807, +81,gray_bed,13092807, +82,light_gray_bed,13092807, +83,cyan_bed,13092807, +84,purple_bed,13092807, +85,blue_bed,13092807, 
+86,brown_bed,13092807, +87,green_bed,13092807, +88,red_bed,13092807, +89,black_bed,13092807, +90,powered_rail,0, +91,detector_rail,0, +92,sticky_piston,7368816, +93,cobweb,13092807, +94,grass,31744,grass +95,fern,31744,grass +96,dead_bush,31744,grass +97,seagrass,4210943,water +98,tall_seagrass,4210943,water +99,piston,7368816, +100,piston_head,7368816, +101,white_wool,13092807, +102,orange_wool,13092807, +103,magenta_wool,13092807, +104,light_blue_wool,13092807, +105,yellow_wool,13092807, +106,lime_wool,13092807, +107,pink_wool,13092807, +108,gray_wool,13092807, +109,light_gray_wool,13092807, +110,cyan_wool,13092807, +111,purple_wool,13092807, +112,blue_wool,13092807, +113,brown_wool,13092807, +114,green_wool,13092807, +115,red_wool,13092807, +116,black_wool,13092807, +117,moving_piston,7368816, +118,dandelion,31744,flower +119,poppy,31744,flower +120,blue_orchid,31744,flower +121,allium,31744,flower +122,azure_bluet,31744,flower +123,red_tulip,31744,flower +124,orange_tulip,31744,flower +125,white_tulip,31744,flower +126,pink_tulip,31744,flower +127,oxeye_daisy,31744,flower +128,cornflower,31744,flower +129,wither_rose,31744,flower +130,lily_of_the_valley,31744,flower +131,brown_mushroom,31744,flower +132,red_mushroom,31744,flower +133,gold_block,10987431, +134,iron_block,10987431, +135,bricks,7368816, +136,tnt,16711680, +137,bookshelf,9402184, +138,mossy_cobblestone,7368816, +139,obsidian,7368816, +140,torch,0, +141,wall_torch,0, +142,fire,0, +143,spawner,7368816, +144,oak_stairs,9402184, +145,chest,9402184, +146,redstone_wire,0, +147,diamond_ore,7368816, +148,diamond_block,10987431, +149,crafting_table,9402184, +150,wheat,31744, +151,farmland,9923917, +152,furnace,7368816, +153,oak_sign,9402184, +154,spruce_sign,9402184, +155,birch_sign,9402184, +156,acacia_sign,9402184, +157,jungle_sign,9402184, +158,dark_oak_sign,9402184, +159,oak_door,9402184, +160,ladder,0, +161,rail,0, +162,cobblestone_stairs,7368816, +163,oak_wall_sign,9402184, +164,spruce_wall_sign,9402184, +165,birch_wall_sign,9402184, +166,acacia_wall_sign,9402184, +167,jungle_wall_sign,9402184, +168,dark_oak_wall_sign,9402184, +169,lever,0, +170,stone_pressure_plate,7368816, +171,iron_door,10987431, +172,oak_pressure_plate,9402184, +173,spruce_pressure_plate,9402184, +174,birch_pressure_plate,9402184, +175,jungle_pressure_plate,9402184, +176,acacia_pressure_plate,9402184, +177,dark_oak_pressure_plate,9402184, +178,redstone_ore,7368816, +179,redstone_torch,0, +180,redstone_wall_torch,0, +181,stone_button,0, +182,snow,16777215,snow +183,ice,10526975,snow +184,snow_block,16777215,snow +185,cactus,31744,plant +186,clay,10791096, +187,sugar_cane,31744,plant +188,jukebox,9402184, +189,oak_fence,9402184, +190,pumpkin,31744, +191,netherrack,7368816, +192,soul_sand,16247203, +193,glowstone,0, +194,nether_portal,0, +195,carved_pumpkin,31744, +196,jack_o_lantern,31744, +197,cake,0, +198,repeater,0, +199,white_stained_glass,0, +200,orange_stained_glass,0, +201,magenta_stained_glass,0, +202,light_blue_stained_glass,0, +203,yellow_stained_glass,0, +204,lime_stained_glass,0, +205,pink_stained_glass,0, +206,gray_stained_glass,0, +207,light_gray_stained_glass,0, +208,cyan_stained_glass,0, +209,purple_stained_glass,0, +210,blue_stained_glass,0, +211,brown_stained_glass,0, +212,green_stained_glass,0, +213,red_stained_glass,0, +214,black_stained_glass,0, +215,oak_trapdoor,9402184, +216,spruce_trapdoor,9402184, +217,birch_trapdoor,9402184, +218,jungle_trapdoor,9402184, +219,acacia_trapdoor,9402184, +220,dark_oak_trapdoor,9402184, 
+221,stone_bricks,7368816, +222,mossy_stone_bricks,7368816, +223,cracked_stone_bricks,7368816, +224,chiseled_stone_bricks,7368816, +225,infested_stone,10791096, +226,infested_cobblestone,10791096, +227,infested_stone_bricks,10791096, +228,infested_mossy_stone_bricks,10791096, +229,infested_cracked_stone_bricks,10791096, +230,infested_chiseled_stone_bricks,10791096, +231,brown_mushroom_block,9402184,tree +232,red_mushroom_block,9402184,tree +233,mushroom_stem,9402184,tree +234,iron_bars,10987431, +235,glass_pane,0, +236,melon,31744, +237,attached_pumpkin_stem,31744, +238,attached_melon_stem,31744, +239,pumpkin_stem,31744, +240,melon_stem,31744, +241,vine,31744,plant +242,oak_fence_gate,9402184, +243,brick_stairs,7368816, +244,stone_brick_stairs,7368816, +245,mycelium,8368696, +246,lily_pad,31744,grass +247,nether_bricks,7368816, +248,nether_brick_fence,7368816, +249,nether_brick_stairs,7368816, +250,nether_wart,31744, +251,enchanting_table,7368816, +252,brewing_stand,10987431, +253,cauldron,10987431, +254,end_portal,0, +255,end_portal_frame,7368816, +256,end_stone,7368816, +257,dragon_egg,31744, +258,redstone_lamp,0, +259,cocoa,31744, +260,sandstone_stairs,7368816, +261,emerald_ore,7368816, +262,ender_chest,7368816, +263,tripwire_hook,0, +264,tripwire,0, +265,emerald_block,10987431, +266,spruce_stairs,9402184, +267,birch_stairs,9402184, +268,jungle_stairs,9402184, +269,command_block,10987431, +270,beacon,0, +271,cobblestone_wall,7368816, +272,mossy_cobblestone_wall,7368816, +273,flower_pot,0, +274,potted_oak_sapling,0, +275,potted_spruce_sapling,0, +276,potted_birch_sapling,0, +277,potted_jungle_sapling,0, +278,potted_acacia_sapling,0, +279,potted_dark_oak_sapling,0, +280,potted_fern,0, +281,potted_dandelion,0, +282,potted_poppy,0, +283,potted_blue_orchid,0, +284,potted_allium,0, +285,potted_azure_bluet,0, +286,potted_red_tulip,0, +287,potted_orange_tulip,0, +288,potted_white_tulip,0, +289,potted_pink_tulip,0, +290,potted_oxeye_daisy,0, +291,potted_cornflower,0, +292,potted_lily_of_the_valley,0, +293,potted_wither_rose,0, +294,potted_red_mushroom,0, +295,potted_brown_mushroom,0, +296,potted_dead_bush,0, +297,potted_cactus,0, +298,carrots,31744, +299,potatoes,31744, +300,oak_button,0, +301,spruce_button,0, +302,birch_button,0, +303,jungle_button,0, +304,acacia_button,0, +305,dark_oak_button,0, +306,skeleton_skull,0, +307,skeleton_wall_skull,0, +308,wither_skeleton_skull,0, +309,wither_skeleton_wall_skull,0, +310,zombie_head,0, +311,zombie_wall_head,0, +312,player_head,0, +313,player_wall_head,0, +314,creeper_head,0, +315,creeper_wall_head,0, +316,dragon_head,0, +317,dragon_wall_head,0, +318,anvil,10987431, +319,chipped_anvil,10987431, +320,damaged_anvil,10987431, +321,trapped_chest,9402184, +322,light_weighted_pressure_plate,10987431, +323,heavy_weighted_pressure_plate,10987431, +324,comparator,0, +325,daylight_detector,9402184, +326,redstone_block,10987431, +327,nether_quartz_ore,7368816, +328,hopper,10987431, +329,quartz_block,7368816, +330,chiseled_quartz_block,7368816, +331,quartz_pillar,7368816, +332,quartz_stairs,7368816, +333,activator_rail,0, +334,dropper,7368816, +335,white_terracotta,7368816, +336,orange_terracotta,7368816, +337,magenta_terracotta,7368816, +338,light_blue_terracotta,7368816, +339,yellow_terracotta,7368816, +340,lime_terracotta,7368816, +341,pink_terracotta,7368816, +342,gray_terracotta,7368816, +343,light_gray_terracotta,7368816, +344,cyan_terracotta,7368816, +345,purple_terracotta,7368816, +346,blue_terracotta,7368816, +347,brown_terracotta,7368816, 
+348,green_terracotta,7368816, +349,red_terracotta,7368816, +350,black_terracotta,7368816, +351,white_stained_glass_pane,0, +352,orange_stained_glass_pane,0, +353,magenta_stained_glass_pane,0, +354,light_blue_stained_glass_pane,0, +355,yellow_stained_glass_pane,0, +356,lime_stained_glass_pane,0, +357,pink_stained_glass_pane,0, +358,gray_stained_glass_pane,0, +359,light_gray_stained_glass_pane,0, +360,cyan_stained_glass_pane,0, +361,purple_stained_glass_pane,0, +362,blue_stained_glass_pane,0, +363,brown_stained_glass_pane,0, +364,green_stained_glass_pane,0, +365,red_stained_glass_pane,0, +366,black_stained_glass_pane,0, +367,acacia_stairs,9402184, +368,dark_oak_stairs,9402184, +369,slime_block,10791096, +370,barrier,0, +371,iron_trapdoor,10987431, +372,prismarine,7368816, +373,prismarine_bricks,7368816, +374,dark_prismarine,7368816, +375,prismarine_stairs,7368816, +376,prismarine_brick_stairs,7368816, +377,dark_prismarine_stairs,7368816, +378,prismarine_slab,7368816, +379,prismarine_brick_slab,7368816, +380,dark_prismarine_slab,7368816, +381,sea_lantern,0, +382,hay_block,8368696, +383,white_carpet,13092807, +384,orange_carpet,13092807, +385,magenta_carpet,13092807, +386,light_blue_carpet,13092807, +387,yellow_carpet,13092807, +388,lime_carpet,13092807, +389,pink_carpet,13092807, +390,gray_carpet,13092807, +391,light_gray_carpet,13092807, +392,cyan_carpet,13092807, +393,purple_carpet,13092807, +394,blue_carpet,13092807, +395,brown_carpet,13092807, +396,green_carpet,13092807, +397,red_carpet,13092807, +398,black_carpet,13092807, +399,terracotta,7368816, +400,coal_block,7368816, +401,packed_ice,10526975, +402,sunflower,31744,flower +403,lilac,31744,flower +404,rose_bush,31744,flower +405,peony,31744,flower +406,tall_grass,31744,plant +407,large_fern,31744,plant +408,white_banner,9402184, +409,orange_banner,9402184, +410,magenta_banner,9402184, +411,light_blue_banner,9402184, +412,yellow_banner,9402184, +413,lime_banner,9402184, +414,pink_banner,9402184, +415,gray_banner,9402184, +416,light_gray_banner,9402184, +417,cyan_banner,9402184, +418,purple_banner,9402184, +419,blue_banner,9402184, +420,brown_banner,9402184, +421,green_banner,9402184, +422,red_banner,9402184, +423,black_banner,9402184, +424,white_wall_banner,9402184, +425,orange_wall_banner,9402184, +426,magenta_wall_banner,9402184, +427,light_blue_wall_banner,9402184, +428,yellow_wall_banner,9402184, +429,lime_wall_banner,9402184, +430,pink_wall_banner,9402184, +431,gray_wall_banner,9402184, +432,light_gray_wall_banner,9402184, +433,cyan_wall_banner,9402184, +434,purple_wall_banner,9402184, +435,blue_wall_banner,9402184, +436,brown_wall_banner,9402184, +437,green_wall_banner,9402184, +438,red_wall_banner,9402184, +439,black_wall_banner,9402184, +440,red_sandstone,7368816, +441,chiseled_red_sandstone,7368816, +442,cut_red_sandstone,7368816, +443,red_sandstone_stairs,7368816, +444,oak_slab,9402184, +445,spruce_slab,9402184, +446,birch_slab,9402184, +447,jungle_slab,9402184, +448,acacia_slab,9402184, +449,dark_oak_slab,9402184, +450,stone_slab,7368816, +451,smooth_stone_slab,7368816, +452,sandstone_slab,7368816, +453,cut_sandstone_slab,7368816, +454,petrified_oak_slab,7368816, +455,cobblestone_slab,7368816, +456,brick_slab,7368816, +457,stone_brick_slab,7368816, +458,nether_brick_slab,7368816, +459,quartz_slab,7368816, +460,red_sandstone_slab,7368816, +461,cut_red_sandstone_slab,7368816, +462,purpur_slab,7368816, +463,smooth_stone,7368816, +464,smooth_sandstone,7368816, +465,smooth_quartz,7368816, +466,smooth_red_sandstone,7368816, 
+467,spruce_fence_gate,9402184, +468,birch_fence_gate,9402184, +469,jungle_fence_gate,9402184, +470,acacia_fence_gate,9402184, +471,dark_oak_fence_gate,9402184, +472,spruce_fence,9402184, +473,birch_fence,9402184, +474,jungle_fence,9402184, +475,acacia_fence,9402184, +476,dark_oak_fence,9402184, +477,spruce_door,9402184, +478,birch_door,9402184, +479,jungle_door,9402184, +480,acacia_door,9402184, +481,dark_oak_door,9402184, +482,end_rod,0, +483,chorus_plant,31744, +484,chorus_flower,31744, +485,purpur_block,7368816, +486,purpur_pillar,7368816, +487,purpur_stairs,7368816, +488,end_stone_bricks,7368816, +489,beetroots,31744, +490,grass_path,9923917, +491,end_gateway,0, +492,repeating_command_block,10987431, +493,chain_command_block,10987431, +494,frosted_ice,10526975, +495,magma_block,7368816, +496,nether_wart_block,8368696, +497,red_nether_bricks,7368816, +498,bone_block,7368816, +499,structure_void,0, +500,observer,7368816, +501,shulker_box,8339378, +502,white_shulker_box,8339378, +503,orange_shulker_box,8339378, +504,magenta_shulker_box,8339378, +505,light_blue_shulker_box,8339378, +506,yellow_shulker_box,8339378, +507,lime_shulker_box,8339378, +508,pink_shulker_box,8339378, +509,gray_shulker_box,8339378, +510,light_gray_shulker_box,8339378, +511,cyan_shulker_box,8339378, +512,purple_shulker_box,8339378, +513,blue_shulker_box,8339378, +514,brown_shulker_box,8339378, +515,green_shulker_box,8339378, +516,red_shulker_box,8339378, +517,black_shulker_box,8339378, +518,white_glazed_terracotta,7368816, +519,orange_glazed_terracotta,7368816, +520,magenta_glazed_terracotta,7368816, +521,light_blue_glazed_terracotta,7368816, +522,yellow_glazed_terracotta,7368816, +523,lime_glazed_terracotta,7368816, +524,pink_glazed_terracotta,7368816, +525,gray_glazed_terracotta,7368816, +526,light_gray_glazed_terracotta,7368816, +527,cyan_glazed_terracotta,7368816, +528,purple_glazed_terracotta,7368816, +529,blue_glazed_terracotta,7368816, +530,brown_glazed_terracotta,7368816, +531,green_glazed_terracotta,7368816, +532,red_glazed_terracotta,7368816, +533,black_glazed_terracotta,7368816, +534,white_concrete,7368816, +535,orange_concrete,7368816, +536,magenta_concrete,7368816, +537,light_blue_concrete,7368816, +538,yellow_concrete,7368816, +539,lime_concrete,7368816, +540,pink_concrete,7368816, +541,gray_concrete,7368816, +542,light_gray_concrete,7368816, +543,cyan_concrete,7368816, +544,purple_concrete,7368816, +545,blue_concrete,7368816, +546,brown_concrete,7368816, +547,green_concrete,7368816, +548,red_concrete,7368816, +549,black_concrete,7368816, +550,white_concrete_powder,16247203, +551,orange_concrete_powder,16247203, +552,magenta_concrete_powder,16247203, +553,light_blue_concrete_powder,16247203, +554,yellow_concrete_powder,16247203, +555,lime_concrete_powder,16247203, +556,pink_concrete_powder,16247203, +557,gray_concrete_powder,16247203, +558,light_gray_concrete_powder,16247203, +559,cyan_concrete_powder,16247203, +560,purple_concrete_powder,16247203, +561,blue_concrete_powder,16247203, +562,brown_concrete_powder,16247203, +563,green_concrete_powder,16247203, +564,red_concrete_powder,16247203, +565,black_concrete_powder,16247203, +566,kelp,4210943, +567,kelp_plant,4210943, +568,dried_kelp_block,8368696, +569,turtle_egg,31744, +570,dead_tube_coral_block,7368816, +571,dead_brain_coral_block,7368816, +572,dead_bubble_coral_block,7368816, +573,dead_fire_coral_block,7368816, +574,dead_horn_coral_block,7368816, +575,tube_coral_block,7368816, +576,brain_coral_block,7368816, +577,bubble_coral_block,7368816, 
+578,fire_coral_block,7368816, +579,horn_coral_block,7368816, +580,dead_tube_coral,7368816, +581,dead_brain_coral,7368816, +582,dead_bubble_coral,7368816, +583,dead_fire_coral,7368816, +584,dead_horn_coral,7368816, +585,tube_coral,4210943, +586,brain_coral,4210943, +587,bubble_coral,4210943, +588,fire_coral,4210943, +589,horn_coral,4210943, +590,dead_tube_coral_fan,7368816, +591,dead_brain_coral_fan,7368816, +592,dead_bubble_coral_fan,7368816, +593,dead_fire_coral_fan,7368816, +594,dead_horn_coral_fan,7368816, +595,tube_coral_fan,4210943, +596,brain_coral_fan,4210943, +597,bubble_coral_fan,4210943, +598,fire_coral_fan,4210943, +599,horn_coral_fan,4210943, +600,dead_tube_coral_wall_fan,7368816, +601,dead_brain_coral_wall_fan,7368816, +602,dead_bubble_coral_wall_fan,7368816, +603,dead_fire_coral_wall_fan,7368816, +604,dead_horn_coral_wall_fan,7368816, +605,tube_coral_wall_fan,4210943, +606,brain_coral_wall_fan,4210943, +607,bubble_coral_wall_fan,4210943, +608,fire_coral_wall_fan,4210943, +609,horn_coral_wall_fan,4210943, +610,sea_pickle,4210943, +611,blue_ice,10526975, +612,conduit,0, +613,bamboo_sapling,9402184,plant +614,bamboo,9402184,plant +615,potted_bamboo,0, +616,void_air,0,dirt +617,cave_air,0,dirt +618,bubble_column,4210943, +619,polished_granite_stairs,7368816, +620,smooth_red_sandstone_stairs,7368816, +621,mossy_stone_brick_stairs,7368816, +622,polished_diorite_stairs,7368816, +623,mossy_cobblestone_stairs,7368816, +624,end_stone_brick_stairs,7368816, +625,stone_stairs,7368816, +626,smooth_sandstone_stairs,7368816, +627,smooth_quartz_stairs,7368816, +628,granite_stairs,7368816, +629,andesite_stairs,7368816, +630,red_nether_brick_stairs,7368816, +631,polished_andesite_stairs,7368816, +632,diorite_stairs,7368816, +633,polished_granite_slab,7368816, +634,smooth_red_sandstone_slab,7368816, +635,mossy_stone_brick_slab,7368816, +636,polished_diorite_slab,7368816, +637,mossy_cobblestone_slab,7368816, +638,end_stone_brick_slab,7368816, +639,smooth_sandstone_slab,7368816, +640,smooth_quartz_slab,7368816, +641,granite_slab,7368816, +642,andesite_slab,7368816, +643,red_nether_brick_slab,7368816, +644,polished_andesite_slab,7368816, +645,diorite_slab,7368816, +646,brick_wall,7368816, +647,prismarine_wall,7368816, +648,red_sandstone_wall,7368816, +649,mossy_stone_brick_wall,7368816, +650,granite_wall,7368816, +651,stone_brick_wall,7368816, +652,nether_brick_wall,7368816, +653,andesite_wall,7368816, +654,red_nether_brick_wall,7368816, +655,sandstone_wall,7368816, +656,end_stone_brick_wall,7368816, +657,diorite_wall,7368816, +658,scaffolding,0, +659,loom,9402184, +660,barrel,9402184, +661,smoker,7368816, +662,blast_furnace,7368816, +663,cartography_table,9402184, +664,fletching_table,9402184, +665,grindstone,10987431, +666,lectern,9402184, +667,smithing_table,9402184, +668,stonecutter,7368816, +669,bell,10987431, +670,lantern,10987431, +671,campfire,9402184, +672,sweet_berry_bush,31744, +673,structure_block,10987431, +674,jigsaw,10987431, +675,composter,9402184, +676,bee_nest,9402184, +677,beehive,9402184, +678,honey_block,10791096, +679,honeycomb_block,10791096, diff --git a/imaginaire/model_utils/gancraft/loss.py b/imaginaire/model_utils/gancraft/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..b1811de5307535167f645b4ea8a889a468b41780 --- /dev/null +++ b/imaginaire/model_utils/gancraft/loss.py @@ -0,0 +1,96 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. 
+# To view a copy of this license, check out LICENSE.md +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class GANLoss(nn.Module): + def __init__(self, target_real_label=1.0, target_fake_label=0.0): + r"""GAN loss constructor. + + Args: + target_real_label (float): Desired output label for the real images. + target_fake_label (float): Desired output label for the fake images. + """ + super(GANLoss, self).__init__() + self.real_label = target_real_label + self.fake_label = target_fake_label + self.real_label_tensor = None + self.fake_label_tensor = None + + def forward(self, input_x, t_real, weight=None, + reduce_dim=True, dis_update=True): + r"""GAN loss computation. + + Args: + input_x (tensor or list of tensors): Output values. + t_real (boolean): Is this output value for real images. + reduce_dim (boolean): Whether we reduce the dimensions first. This makes a difference when we use + multi-resolution discriminators. + weight (float): Weight to scale the loss value. + dis_update (boolean): Updating the discriminator or the generator. + Returns: + loss (tensor): Loss value. + """ + if isinstance(input_x, list): + loss = 0 + for pred_i in input_x: + if isinstance(pred_i, list): + pred_i = pred_i[-1] + loss_tensor = self.loss(pred_i, t_real, weight, + reduce_dim, dis_update) + bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0) + new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1) + loss += new_loss + return loss / len(input_x) + else: + return self.loss(input_x, t_real, weight, reduce_dim, dis_update) + + def loss(self, input_x, t_real, weight=None, + reduce_dim=True, dis_update=True): + r"""N+1 label GAN loss computation. + + Args: + input_x (tensor): Output values. + t_real (boolean): Is this output value for real images. + reduce_dim (boolean): Whether we reduce the dimensions first. This makes a difference when we use + multi-resolution discriminators. + weight (float): Weight to scale the loss value. + dis_update (boolean): Updating the discriminator or the generator. + Returns: + loss (tensor): Loss value. + """ + assert reduce_dim is True + pred = input_x['pred'].clone() + label = input_x['label'].clone() + batch_size = pred.size(0) + + # ignore label 0 + label[:, 0, ...] = 0 + pred[:, 0, ...] = 0 + pred = F.log_softmax(pred, dim=1) + assert pred.size(1) == (label.size(1) + 1) + if dis_update: + if t_real: + pred_real = pred[:, :-1, :, :] + loss = - label * pred_real + loss = torch.sum(loss, dim=1, keepdim=True) + else: + pred_fake = pred[:, -1, None, :, :] # N plus 1 + loss = - pred_fake + else: + assert t_real, "GAN loss must be aiming for real." + pred_real = pred[:, :-1, :, :] + loss = - label * pred_real + loss = torch.sum(loss, dim=1, keepdim=True) + + if weight is not None: + loss = loss * weight + if reduce_dim: + loss = torch.mean(loss) + else: + loss = loss.view(batch_size, -1).mean(dim=1) + return loss diff --git a/imaginaire/model_utils/gancraft/mc_lbl_reduction.py b/imaginaire/model_utils/gancraft/mc_lbl_reduction.py new file mode 100644 index 0000000000000000000000000000000000000000..03fec1d3b600cfd31358cf480924da5232e0104a --- /dev/null +++ b/imaginaire/model_utils/gancraft/mc_lbl_reduction.py @@ -0,0 +1,83 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. 
+# To view a copy of this license, check out LICENSE.md +import os +import csv + + +class ReducedLabelMapper: + def __init__(self): + this_path = os.path.dirname(os.path.abspath(__file__)) + print('[ReducedLabelMapper] Loading from {}'.format(this_path)) + + # Load Minecraft LUT + mcid2rdlbl_lut = {} + mcid2mclbl_lut = {} + with open(os.path.join(this_path, 'mc_reduction.csv'), newline='') as csvfile: + csvreader = csv.reader(csvfile, delimiter=',') + for row in csvreader: + mcid = int(row[0]) + mcid2rdlbl_lut[mcid] = row[3] + mcid2mclbl_lut[mcid] = row[1] + + # Load reduced label set + reduced_lbls = [] + rdlbl2rdid = {} + with open(os.path.join(this_path, 'reduced_coco_lbls.csv'), newline='') as csvfile: + csvreader = csv.reader(csvfile, delimiter=',') + for idx, row in enumerate(csvreader): + rdlbl2rdid[row[0]] = idx + reduced_lbls.append(row[0]) + print(['{}: {}'.format(rdid, rdlbl) for rdid, rdlbl in enumerate(reduced_lbls)]) + # The first label should always be 'ignore' + assert reduced_lbls[0] == 'ignore' + + # Generate Minecraft ID to Reduced ID LUT + mcid2rdid_lut = [] + for mcid in range(len(mcid2rdlbl_lut)): + rdlbl = mcid2rdlbl_lut[mcid] + if rdlbl == '': + rdlbl = 'ignore' + rdid = rdlbl2rdid[rdlbl] + mcid2rdid_lut.append(rdid) + + # ================= coco part ================== + gg_label_list = [] + gglbl2ggid = {} + with open(os.path.join(this_path, 'gaugan_lbl2col.csv'), newline='') as csvfile: + csvreader = csv.reader(csvfile, delimiter=',') + for idx, row in enumerate(csvreader): + gg_label_list.append(row[0]) + gglbl2ggid[row[0]] = idx + + # Load coco -> reduced mapping table + gglbl2rdid = {} + with open(os.path.join(this_path, 'gaugan_reduction.csv'), newline='') as csvfile: + csvreader = csv.reader(csvfile, delimiter=',') + for idx, row in enumerate(csvreader): + gglbl = row[0] + target_rdlbl = row[1] + ggid = gglbl2ggid[gglbl] + target_rdid = rdlbl2rdid[target_rdlbl] + gglbl2rdid[ggid] = target_rdid + ggid2rdid = [gglbl2rdid[i] for i in range(len(gglbl2rdid))] + + print('[ReducedLabelMapper] #Reduced Labels: {}'.format(len(reduced_lbls))) + + self.mcid2rdid_lut = mcid2rdid_lut + self.ggid2rdid = ggid2rdid + self.reduced_lbls = reduced_lbls + + self.ignore_id = rdlbl2rdid['ignore'] + self.dirt_id = rdlbl2rdid['dirt'] + self.water_id = rdlbl2rdid['water'] + + self.gglbl2ggid = gglbl2ggid + + def gglbl2ggid(self, gglbl): + return self.gglbl2ggid[gglbl] + + +if __name__ == '__main__': + mapper = ReducedLabelMapper() diff --git a/imaginaire/model_utils/gancraft/mc_reduction.csv b/imaginaire/model_utils/gancraft/mc_reduction.csv new file mode 100644 index 0000000000000000000000000000000000000000..254af7255d67be76b5d41cdbd162173e55ec0b9c --- /dev/null +++ b/imaginaire/model_utils/gancraft/mc_reduction.csv @@ -0,0 +1,680 @@ +0,air,0,sky +1,stone,7368816,stone +2,granite,7368816,rock +3,polished_granite,7368816,rock +4,diorite,7368816,rock +5,polished_diorite,7368816,rock +6,andesite,7368816,rock +7,polished_andesite,7368816,rock +8,grass_block,8368696,grass +9,dirt,9923917,dirt +10,coarse_dirt,9923917,dirt +11,podzol,9923917,dirt +12,cobblestone,7368816,stone +13,oak_planks,9402184, +14,spruce_planks,9402184, +15,birch_planks,9402184, +16,jungle_planks,9402184, +17,acacia_planks,9402184, +18,dark_oak_planks,9402184, +19,oak_sapling,31744,grass +20,spruce_sapling,31744,grass +21,birch_sapling,31744,grass +22,jungle_sapling,31744,grass +23,acacia_sapling,31744,grass +24,dark_oak_sapling,31744,grass +25,bedrock,7368816,rock +26,water,4210943,water +27,lava,16711680, 
+28,sand,16247203,sand +29,red_sand,16247203,sand +30,gravel,16247203,gravel +31,gold_ore,7368816,rock +32,iron_ore,7368816,rock +33,coal_ore,7368816,rock +34,oak_log,9402184,tree +35,spruce_log,9402184,tree +36,birch_log,9402184,tree +37,jungle_log,9402184,tree +38,acacia_log,9402184,tree +39,dark_oak_log,9402184,tree +40,stripped_spruce_log,9402184, +41,stripped_birch_log,9402184, +42,stripped_jungle_log,9402184, +43,stripped_acacia_log,9402184, +44,stripped_dark_oak_log,9402184, +45,stripped_oak_log,9402184, +46,oak_wood,9402184, +47,spruce_wood,9402184, +48,birch_wood,9402184, +49,jungle_wood,9402184, +50,acacia_wood,9402184, +51,dark_oak_wood,9402184, +52,stripped_oak_wood,9402184, +53,stripped_spruce_wood,9402184, +54,stripped_birch_wood,9402184, +55,stripped_jungle_wood,9402184, +56,stripped_acacia_wood,9402184, +57,stripped_dark_oak_wood,9402184, +58,oak_leaves,31744,tree +59,spruce_leaves,31744,tree +60,birch_leaves,31744,tree +61,jungle_leaves,31744,tree +62,acacia_leaves,31744,tree +63,dark_oak_leaves,31744,tree +64,sponge,15066419, +65,wet_sponge,15066419, +66,glass,0, +67,lapis_ore,7368816, +68,lapis_block,10987431, +69,dispenser,7368816, +70,sandstone,7368816,sand +71,chiseled_sandstone,7368816,sand +72,cut_sandstone,7368816,sand +73,note_block,9402184, +74,white_bed,13092807, +75,orange_bed,13092807, +76,magenta_bed,13092807, +77,light_blue_bed,13092807, +78,yellow_bed,13092807, +79,lime_bed,13092807, +80,pink_bed,13092807, +81,gray_bed,13092807, +82,light_gray_bed,13092807, +83,cyan_bed,13092807, +84,purple_bed,13092807, +85,blue_bed,13092807, +86,brown_bed,13092807, +87,green_bed,13092807, +88,red_bed,13092807, +89,black_bed,13092807, +90,powered_rail,0, +91,detector_rail,0, +92,sticky_piston,7368816, +93,cobweb,13092807, +94,grass,31744,grass +95,fern,31744,grass +96,dead_bush,31744,grass +97,seagrass,4210943,water +98,tall_seagrass,4210943,water +99,piston,7368816, +100,piston_head,7368816, +101,white_wool,13092807, +102,orange_wool,13092807, +103,magenta_wool,13092807, +104,light_blue_wool,13092807, +105,yellow_wool,13092807, +106,lime_wool,13092807, +107,pink_wool,13092807, +108,gray_wool,13092807, +109,light_gray_wool,13092807, +110,cyan_wool,13092807, +111,purple_wool,13092807, +112,blue_wool,13092807, +113,brown_wool,13092807, +114,green_wool,13092807, +115,red_wool,13092807, +116,black_wool,13092807, +117,moving_piston,7368816, +118,dandelion,31744,flower +119,poppy,31744,flower +120,blue_orchid,31744,flower +121,allium,31744,flower +122,azure_bluet,31744,flower +123,red_tulip,31744,flower +124,orange_tulip,31744,flower +125,white_tulip,31744,flower +126,pink_tulip,31744,flower +127,oxeye_daisy,31744,flower +128,cornflower,31744,flower +129,wither_rose,31744,flower +130,lily_of_the_valley,31744,flower +131,brown_mushroom,31744,flower +132,red_mushroom,31744,flower +133,gold_block,10987431, +134,iron_block,10987431, +135,bricks,7368816, +136,tnt,16711680, +137,bookshelf,9402184, +138,mossy_cobblestone,7368816, +139,obsidian,7368816, +140,torch,0, +141,wall_torch,0, +142,fire,0, +143,spawner,7368816, +144,oak_stairs,9402184, +145,chest,9402184, +146,redstone_wire,0, +147,diamond_ore,7368816, +148,diamond_block,10987431, +149,crafting_table,9402184, +150,wheat,31744, +151,farmland,9923917, +152,furnace,7368816, +153,oak_sign,9402184, +154,spruce_sign,9402184, +155,birch_sign,9402184, +156,acacia_sign,9402184, +157,jungle_sign,9402184, +158,dark_oak_sign,9402184, +159,oak_door,9402184, +160,ladder,0, +161,rail,0, +162,cobblestone_stairs,7368816, 
+163,oak_wall_sign,9402184, +164,spruce_wall_sign,9402184, +165,birch_wall_sign,9402184, +166,acacia_wall_sign,9402184, +167,jungle_wall_sign,9402184, +168,dark_oak_wall_sign,9402184, +169,lever,0, +170,stone_pressure_plate,7368816, +171,iron_door,10987431, +172,oak_pressure_plate,9402184, +173,spruce_pressure_plate,9402184, +174,birch_pressure_plate,9402184, +175,jungle_pressure_plate,9402184, +176,acacia_pressure_plate,9402184, +177,dark_oak_pressure_plate,9402184, +178,redstone_ore,7368816, +179,redstone_torch,0, +180,redstone_wall_torch,0, +181,stone_button,0, +182,snow,16777215,snow +183,ice,10526975,snow +184,snow_block,16777215,snow +185,cactus,31744,flower +186,clay,10791096,dirt +187,sugar_cane,31744,flower +188,jukebox,9402184, +189,oak_fence,9402184, +190,pumpkin,31744, +191,netherrack,7368816, +192,soul_sand,16247203, +193,glowstone,0, +194,nether_portal,0, +195,carved_pumpkin,31744, +196,jack_o_lantern,31744, +197,cake,0, +198,repeater,0, +199,white_stained_glass,0, +200,orange_stained_glass,0, +201,magenta_stained_glass,0, +202,light_blue_stained_glass,0, +203,yellow_stained_glass,0, +204,lime_stained_glass,0, +205,pink_stained_glass,0, +206,gray_stained_glass,0, +207,light_gray_stained_glass,0, +208,cyan_stained_glass,0, +209,purple_stained_glass,0, +210,blue_stained_glass,0, +211,brown_stained_glass,0, +212,green_stained_glass,0, +213,red_stained_glass,0, +214,black_stained_glass,0, +215,oak_trapdoor,9402184, +216,spruce_trapdoor,9402184, +217,birch_trapdoor,9402184, +218,jungle_trapdoor,9402184, +219,acacia_trapdoor,9402184, +220,dark_oak_trapdoor,9402184, +221,stone_bricks,7368816, +222,mossy_stone_bricks,7368816, +223,cracked_stone_bricks,7368816, +224,chiseled_stone_bricks,7368816, +225,infested_stone,10791096, +226,infested_cobblestone,10791096, +227,infested_stone_bricks,10791096, +228,infested_mossy_stone_bricks,10791096, +229,infested_cracked_stone_bricks,10791096, +230,infested_chiseled_stone_bricks,10791096, +231,brown_mushroom_block,9402184,tree +232,red_mushroom_block,9402184,tree +233,mushroom_stem,9402184,tree +234,iron_bars,10987431, +235,glass_pane,0, +236,melon,31744, +237,attached_pumpkin_stem,31744, +238,attached_melon_stem,31744, +239,pumpkin_stem,31744, +240,melon_stem,31744, +241,vine,31744,tree +242,oak_fence_gate,9402184, +243,brick_stairs,7368816, +244,stone_brick_stairs,7368816, +245,mycelium,8368696, +246,lily_pad,31744,grass +247,nether_bricks,7368816, +248,nether_brick_fence,7368816, +249,nether_brick_stairs,7368816, +250,nether_wart,31744, +251,enchanting_table,7368816, +252,brewing_stand,10987431, +253,cauldron,10987431, +254,end_portal,0, +255,end_portal_frame,7368816, +256,end_stone,7368816, +257,dragon_egg,31744, +258,redstone_lamp,0, +259,cocoa,31744, +260,sandstone_stairs,7368816, +261,emerald_ore,7368816, +262,ender_chest,7368816, +263,tripwire_hook,0, +264,tripwire,0, +265,emerald_block,10987431, +266,spruce_stairs,9402184, +267,birch_stairs,9402184, +268,jungle_stairs,9402184, +269,command_block,10987431, +270,beacon,0, +271,cobblestone_wall,7368816, +272,mossy_cobblestone_wall,7368816, +273,flower_pot,0, +274,potted_oak_sapling,0, +275,potted_spruce_sapling,0, +276,potted_birch_sapling,0, +277,potted_jungle_sapling,0, +278,potted_acacia_sapling,0, +279,potted_dark_oak_sapling,0, +280,potted_fern,0, +281,potted_dandelion,0, +282,potted_poppy,0, +283,potted_blue_orchid,0, +284,potted_allium,0, +285,potted_azure_bluet,0, +286,potted_red_tulip,0, +287,potted_orange_tulip,0, +288,potted_white_tulip,0, +289,potted_pink_tulip,0, 
+290,potted_oxeye_daisy,0, +291,potted_cornflower,0, +292,potted_lily_of_the_valley,0, +293,potted_wither_rose,0, +294,potted_red_mushroom,0, +295,potted_brown_mushroom,0, +296,potted_dead_bush,0, +297,potted_cactus,0, +298,carrots,31744, +299,potatoes,31744, +300,oak_button,0, +301,spruce_button,0, +302,birch_button,0, +303,jungle_button,0, +304,acacia_button,0, +305,dark_oak_button,0, +306,skeleton_skull,0, +307,skeleton_wall_skull,0, +308,wither_skeleton_skull,0, +309,wither_skeleton_wall_skull,0, +310,zombie_head,0, +311,zombie_wall_head,0, +312,player_head,0, +313,player_wall_head,0, +314,creeper_head,0, +315,creeper_wall_head,0, +316,dragon_head,0, +317,dragon_wall_head,0, +318,anvil,10987431, +319,chipped_anvil,10987431, +320,damaged_anvil,10987431, +321,trapped_chest,9402184, +322,light_weighted_pressure_plate,10987431, +323,heavy_weighted_pressure_plate,10987431, +324,comparator,0, +325,daylight_detector,9402184, +326,redstone_block,10987431, +327,nether_quartz_ore,7368816, +328,hopper,10987431, +329,quartz_block,7368816, +330,chiseled_quartz_block,7368816, +331,quartz_pillar,7368816, +332,quartz_stairs,7368816, +333,activator_rail,0, +334,dropper,7368816, +335,white_terracotta,7368816, +336,orange_terracotta,7368816, +337,magenta_terracotta,7368816, +338,light_blue_terracotta,7368816, +339,yellow_terracotta,7368816, +340,lime_terracotta,7368816, +341,pink_terracotta,7368816, +342,gray_terracotta,7368816, +343,light_gray_terracotta,7368816, +344,cyan_terracotta,7368816, +345,purple_terracotta,7368816, +346,blue_terracotta,7368816, +347,brown_terracotta,7368816, +348,green_terracotta,7368816, +349,red_terracotta,7368816, +350,black_terracotta,7368816, +351,white_stained_glass_pane,0, +352,orange_stained_glass_pane,0, +353,magenta_stained_glass_pane,0, +354,light_blue_stained_glass_pane,0, +355,yellow_stained_glass_pane,0, +356,lime_stained_glass_pane,0, +357,pink_stained_glass_pane,0, +358,gray_stained_glass_pane,0, +359,light_gray_stained_glass_pane,0, +360,cyan_stained_glass_pane,0, +361,purple_stained_glass_pane,0, +362,blue_stained_glass_pane,0, +363,brown_stained_glass_pane,0, +364,green_stained_glass_pane,0, +365,red_stained_glass_pane,0, +366,black_stained_glass_pane,0, +367,acacia_stairs,9402184, +368,dark_oak_stairs,9402184, +369,slime_block,10791096, +370,barrier,0, +371,iron_trapdoor,10987431, +372,prismarine,7368816, +373,prismarine_bricks,7368816, +374,dark_prismarine,7368816, +375,prismarine_stairs,7368816, +376,prismarine_brick_stairs,7368816, +377,dark_prismarine_stairs,7368816, +378,prismarine_slab,7368816, +379,prismarine_brick_slab,7368816, +380,dark_prismarine_slab,7368816, +381,sea_lantern,0, +382,hay_block,8368696, +383,white_carpet,13092807, +384,orange_carpet,13092807, +385,magenta_carpet,13092807, +386,light_blue_carpet,13092807, +387,yellow_carpet,13092807, +388,lime_carpet,13092807, +389,pink_carpet,13092807, +390,gray_carpet,13092807, +391,light_gray_carpet,13092807, +392,cyan_carpet,13092807, +393,purple_carpet,13092807, +394,blue_carpet,13092807, +395,brown_carpet,13092807, +396,green_carpet,13092807, +397,red_carpet,13092807, +398,black_carpet,13092807, +399,terracotta,7368816, +400,coal_block,7368816, +401,packed_ice,10526975,snow +402,sunflower,31744,flower +403,lilac,31744,flower +404,rose_bush,31744,flower +405,peony,31744,flower +406,tall_grass,31744,flower +407,large_fern,31744,flower +408,white_banner,9402184, +409,orange_banner,9402184, +410,magenta_banner,9402184, +411,light_blue_banner,9402184, +412,yellow_banner,9402184, 
+413,lime_banner,9402184, +414,pink_banner,9402184, +415,gray_banner,9402184, +416,light_gray_banner,9402184, +417,cyan_banner,9402184, +418,purple_banner,9402184, +419,blue_banner,9402184, +420,brown_banner,9402184, +421,green_banner,9402184, +422,red_banner,9402184, +423,black_banner,9402184, +424,white_wall_banner,9402184, +425,orange_wall_banner,9402184, +426,magenta_wall_banner,9402184, +427,light_blue_wall_banner,9402184, +428,yellow_wall_banner,9402184, +429,lime_wall_banner,9402184, +430,pink_wall_banner,9402184, +431,gray_wall_banner,9402184, +432,light_gray_wall_banner,9402184, +433,cyan_wall_banner,9402184, +434,purple_wall_banner,9402184, +435,blue_wall_banner,9402184, +436,brown_wall_banner,9402184, +437,green_wall_banner,9402184, +438,red_wall_banner,9402184, +439,black_wall_banner,9402184, +440,red_sandstone,7368816, +441,chiseled_red_sandstone,7368816, +442,cut_red_sandstone,7368816, +443,red_sandstone_stairs,7368816, +444,oak_slab,9402184, +445,spruce_slab,9402184, +446,birch_slab,9402184, +447,jungle_slab,9402184, +448,acacia_slab,9402184, +449,dark_oak_slab,9402184, +450,stone_slab,7368816, +451,smooth_stone_slab,7368816, +452,sandstone_slab,7368816, +453,cut_sandstone_slab,7368816, +454,petrified_oak_slab,7368816, +455,cobblestone_slab,7368816, +456,brick_slab,7368816, +457,stone_brick_slab,7368816, +458,nether_brick_slab,7368816, +459,quartz_slab,7368816, +460,red_sandstone_slab,7368816, +461,cut_red_sandstone_slab,7368816, +462,purpur_slab,7368816, +463,smooth_stone,7368816, +464,smooth_sandstone,7368816, +465,smooth_quartz,7368816, +466,smooth_red_sandstone,7368816, +467,spruce_fence_gate,9402184, +468,birch_fence_gate,9402184, +469,jungle_fence_gate,9402184, +470,acacia_fence_gate,9402184, +471,dark_oak_fence_gate,9402184, +472,spruce_fence,9402184, +473,birch_fence,9402184, +474,jungle_fence,9402184, +475,acacia_fence,9402184, +476,dark_oak_fence,9402184, +477,spruce_door,9402184, +478,birch_door,9402184, +479,jungle_door,9402184, +480,acacia_door,9402184, +481,dark_oak_door,9402184, +482,end_rod,0, +483,chorus_plant,31744, +484,chorus_flower,31744, +485,purpur_block,7368816, +486,purpur_pillar,7368816, +487,purpur_stairs,7368816, +488,end_stone_bricks,7368816, +489,beetroots,31744, +490,grass_path,9923917, +491,end_gateway,0, +492,repeating_command_block,10987431, +493,chain_command_block,10987431, +494,frosted_ice,10526975,snow +495,magma_block,7368816, +496,nether_wart_block,8368696, +497,red_nether_bricks,7368816, +498,bone_block,7368816, +499,structure_void,0, +500,observer,7368816, +501,shulker_box,8339378, +502,white_shulker_box,8339378, +503,orange_shulker_box,8339378, +504,magenta_shulker_box,8339378, +505,light_blue_shulker_box,8339378, +506,yellow_shulker_box,8339378, +507,lime_shulker_box,8339378, +508,pink_shulker_box,8339378, +509,gray_shulker_box,8339378, +510,light_gray_shulker_box,8339378, +511,cyan_shulker_box,8339378, +512,purple_shulker_box,8339378, +513,blue_shulker_box,8339378, +514,brown_shulker_box,8339378, +515,green_shulker_box,8339378, +516,red_shulker_box,8339378, +517,black_shulker_box,8339378, +518,white_glazed_terracotta,7368816, +519,orange_glazed_terracotta,7368816, +520,magenta_glazed_terracotta,7368816, +521,light_blue_glazed_terracotta,7368816, +522,yellow_glazed_terracotta,7368816, +523,lime_glazed_terracotta,7368816, +524,pink_glazed_terracotta,7368816, +525,gray_glazed_terracotta,7368816, +526,light_gray_glazed_terracotta,7368816, +527,cyan_glazed_terracotta,7368816, +528,purple_glazed_terracotta,7368816, 
+529,blue_glazed_terracotta,7368816, +530,brown_glazed_terracotta,7368816, +531,green_glazed_terracotta,7368816, +532,red_glazed_terracotta,7368816, +533,black_glazed_terracotta,7368816, +534,white_concrete,7368816, +535,orange_concrete,7368816, +536,magenta_concrete,7368816, +537,light_blue_concrete,7368816, +538,yellow_concrete,7368816, +539,lime_concrete,7368816, +540,pink_concrete,7368816, +541,gray_concrete,7368816, +542,light_gray_concrete,7368816, +543,cyan_concrete,7368816, +544,purple_concrete,7368816, +545,blue_concrete,7368816, +546,brown_concrete,7368816, +547,green_concrete,7368816, +548,red_concrete,7368816, +549,black_concrete,7368816, +550,white_concrete_powder,16247203, +551,orange_concrete_powder,16247203, +552,magenta_concrete_powder,16247203, +553,light_blue_concrete_powder,16247203, +554,yellow_concrete_powder,16247203, +555,lime_concrete_powder,16247203, +556,pink_concrete_powder,16247203, +557,gray_concrete_powder,16247203, +558,light_gray_concrete_powder,16247203, +559,cyan_concrete_powder,16247203, +560,purple_concrete_powder,16247203, +561,blue_concrete_powder,16247203, +562,brown_concrete_powder,16247203, +563,green_concrete_powder,16247203, +564,red_concrete_powder,16247203, +565,black_concrete_powder,16247203, +566,kelp,4210943, +567,kelp_plant,4210943, +568,dried_kelp_block,8368696, +569,turtle_egg,31744, +570,dead_tube_coral_block,7368816, +571,dead_brain_coral_block,7368816, +572,dead_bubble_coral_block,7368816, +573,dead_fire_coral_block,7368816, +574,dead_horn_coral_block,7368816, +575,tube_coral_block,7368816, +576,brain_coral_block,7368816, +577,bubble_coral_block,7368816, +578,fire_coral_block,7368816, +579,horn_coral_block,7368816, +580,dead_tube_coral,7368816, +581,dead_brain_coral,7368816, +582,dead_bubble_coral,7368816, +583,dead_fire_coral,7368816, +584,dead_horn_coral,7368816, +585,tube_coral,4210943, +586,brain_coral,4210943, +587,bubble_coral,4210943, +588,fire_coral,4210943, +589,horn_coral,4210943, +590,dead_tube_coral_fan,7368816, +591,dead_brain_coral_fan,7368816, +592,dead_bubble_coral_fan,7368816, +593,dead_fire_coral_fan,7368816, +594,dead_horn_coral_fan,7368816, +595,tube_coral_fan,4210943, +596,brain_coral_fan,4210943, +597,bubble_coral_fan,4210943, +598,fire_coral_fan,4210943, +599,horn_coral_fan,4210943, +600,dead_tube_coral_wall_fan,7368816, +601,dead_brain_coral_wall_fan,7368816, +602,dead_bubble_coral_wall_fan,7368816, +603,dead_fire_coral_wall_fan,7368816, +604,dead_horn_coral_wall_fan,7368816, +605,tube_coral_wall_fan,4210943, +606,brain_coral_wall_fan,4210943, +607,bubble_coral_wall_fan,4210943, +608,fire_coral_wall_fan,4210943, +609,horn_coral_wall_fan,4210943, +610,sea_pickle,4210943, +611,blue_ice,10526975,snow +612,conduit,0, +613,bamboo_sapling,9402184,flower +614,bamboo,9402184,tree +615,potted_bamboo,0, +616,void_air,0,dirt +617,cave_air,0,dirt +618,bubble_column,4210943, +619,polished_granite_stairs,7368816, +620,smooth_red_sandstone_stairs,7368816, +621,mossy_stone_brick_stairs,7368816, +622,polished_diorite_stairs,7368816, +623,mossy_cobblestone_stairs,7368816, +624,end_stone_brick_stairs,7368816, +625,stone_stairs,7368816, +626,smooth_sandstone_stairs,7368816, +627,smooth_quartz_stairs,7368816, +628,granite_stairs,7368816, +629,andesite_stairs,7368816, +630,red_nether_brick_stairs,7368816, +631,polished_andesite_stairs,7368816, +632,diorite_stairs,7368816, +633,polished_granite_slab,7368816, +634,smooth_red_sandstone_slab,7368816, +635,mossy_stone_brick_slab,7368816, +636,polished_diorite_slab,7368816, 
+637,mossy_cobblestone_slab,7368816, +638,end_stone_brick_slab,7368816, +639,smooth_sandstone_slab,7368816, +640,smooth_quartz_slab,7368816, +641,granite_slab,7368816, +642,andesite_slab,7368816, +643,red_nether_brick_slab,7368816, +644,polished_andesite_slab,7368816, +645,diorite_slab,7368816, +646,brick_wall,7368816, +647,prismarine_wall,7368816, +648,red_sandstone_wall,7368816, +649,mossy_stone_brick_wall,7368816, +650,granite_wall,7368816, +651,stone_brick_wall,7368816, +652,nether_brick_wall,7368816, +653,andesite_wall,7368816, +654,red_nether_brick_wall,7368816, +655,sandstone_wall,7368816, +656,end_stone_brick_wall,7368816, +657,diorite_wall,7368816, +658,scaffolding,0, +659,loom,9402184, +660,barrel,9402184, +661,smoker,7368816, +662,blast_furnace,7368816, +663,cartography_table,9402184, +664,fletching_table,9402184, +665,grindstone,10987431, +666,lectern,9402184, +667,smithing_table,9402184, +668,stonecutter,7368816, +669,bell,10987431, +670,lantern,10987431, +671,campfire,9402184, +672,sweet_berry_bush,31744, +673,structure_block,10987431, +674,jigsaw,10987431, +675,composter,9402184, +676,bee_nest,9402184, +677,beehive,9402184, +678,honey_block,10791096, +679,honeycomb_block,10791096, diff --git a/imaginaire/model_utils/gancraft/mc_utils.py b/imaginaire/model_utils/gancraft/mc_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ade0868fd5b48291980561d5bdfd8c0f21a46fab --- /dev/null +++ b/imaginaire/model_utils/gancraft/mc_utils.py @@ -0,0 +1,300 @@ +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.nn.functional as F +import csv +import time +from scipy import ndimage +import os + +from imaginaire.model_utils.gancraft.mc_lbl_reduction import ReducedLabelMapper + + +def load_voxel_new(voxel_path, shape=[256, 512, 512]): + voxel_world = np.fromfile(voxel_path, dtype='int32') + voxel_world = voxel_world.reshape( + shape[1]//16, shape[2]//16, 16, 16, shape[0]) + voxel_world = voxel_world.transpose(4, 0, 2, 1, 3) + voxel_world = voxel_world.reshape(shape[0], shape[1], shape[2]) + voxel_world = np.ascontiguousarray(voxel_world) + voxel_world = torch.from_numpy(voxel_world.astype(np.int32)) + return voxel_world + + +def gen_corner_voxel(voxel): + r"""Converting voxel center array to voxel corner array. The size of the + produced array grows by 1 on every dimension. + + Args: + voxel (torch.IntTensor, CPU): Input voxel of three dimensions + """ + structure = np.zeros([3, 3, 3], dtype=np.bool) + structure[1:, 1:, 1:] = True + voxel_p = F.pad(voxel, (0, 1, 0, 1, 0, 1)) + corners = ndimage.binary_dilation(voxel_p.numpy(), structure) + corners = torch.tensor(corners, dtype=torch.int32) + return corners + + +def calc_height_map(voxel_t): + r"""Calculate height map given a voxel grid [Y, X, Z] as input. + The height is defined as the Y index of the surface (non-air) block + + Args: + voxel (Y x X x Z torch.IntTensor, CPU): Input voxel of three dimensions + Output: + heightmap (X x Z torch.IntTensor) + """ + m, h = torch.max((torch.flip(voxel_t, [0]) != 0).int(), dim=0, keepdim=False) + heightmap = voxel_t.shape[0] - 1 - h + heightmap[m == 0] = 0 # Special case when the whole vertical column is empty + return heightmap + + +def trans_vec_homo(m, v, is_vec=False): + r"""3-dimensional Homogeneous matrix and regular vector multiplication + Convert v to homogeneous vector, perform M-V multiplication, and convert back + Note that this function does not support autograd. 
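+    For illustration: a point [x, y, z] is promoted to [x, y, z, 1] so the
+    translation part of m applies (and the result is divided by w afterwards),
+    while a direction is promoted to [x, y, z, 0] so translation is ignored.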
+ + Args: + m (4 x 4 tensor): a homogeneous matrix + v (3 tensor): a 3-d vector + vec (bool): if true, v is direction. Otherwise v is point + """ + if is_vec: + v = torch.tensor([v[0], v[1], v[2], 0], dtype=v.dtype) + else: + v = torch.tensor([v[0], v[1], v[2], 1], dtype=v.dtype) + v = torch.mv(m, v) + if not is_vec: + v = v / v[3] + v = v[:3] + return v + + +def cumsum_exclusive(tensor, dim): + cumsum = torch.cumsum(tensor, dim) + cumsum = torch.roll(cumsum, 1, dim) + cumsum.index_fill_(dim, torch.tensor([0], dtype=torch.long, device=tensor.device), 0) + return cumsum + + +def sample_depth_batched(depth2, nsamples, deterministic=False, use_box_boundaries=True, sample_depth=4): + r""" Make best effort to sample points within the same distance for every ray. + Exception: When there is not enough voxel. + + Args: + depth2 (N x 2 x 256 x 256 x 4 x 1 tensor): + - N: Batch. + - 2: Entrance / exit depth for each intersected box. + - 256, 256: Height, Width. + - 4: Number of intersected boxes along the ray. + - 1: One extra dim for consistent tensor dims. + depth2 can include NaNs. + deterministic (bool): Whether to use equal-distance sampling instead of random stratified sampling. + use_box_boundaries (bool): Whether to add the entrance / exit points into the sample. + sample_depth (float): Truncate the ray when it travels further than sample_depth inside voxels. + """ + + bs = depth2.size(0) + dim0 = depth2.size(2) + dim1 = depth2.size(3) + dists = depth2[:, 1] - depth2[:, 0] + dists[torch.isnan(dists)] = 0 # N, 256, 256, 4, 1 + accu_depth = torch.cumsum(dists, dim=-2) # N, 256, 256, 4, 1 + total_depth = accu_depth[..., [-1], :] # N, 256, 256, 1, 1 + + total_depth = torch.clamp(total_depth, None, sample_depth) + + # Ignore out of range box boundaries. Fill with random samples. + if use_box_boundaries: + boundary_samples = accu_depth.clone().detach() + boundary_samples_filler = torch.rand_like(boundary_samples) * total_depth + bad_mask = (accu_depth > sample_depth) | (dists == 0) + boundary_samples[bad_mask] = boundary_samples_filler[bad_mask] + + rand_shape = [bs, dim0, dim1, nsamples, 1] + # 256, 256, N, 1 + if deterministic: + rand_samples = torch.empty(rand_shape, dtype=total_depth.dtype, device=total_depth.device) + rand_samples[..., :, 0] = torch.linspace(0, 1, nsamples+2)[1:-1] + else: + rand_samples = torch.rand(rand_shape, dtype=total_depth.dtype, device=total_depth.device) # 256, 256, N, 1 + # Stratified sampling as in NeRF + rand_samples = rand_samples / nsamples + rand_samples[..., :, 0] += torch.linspace(0, 1, nsamples+1, device=rand_samples.device)[:-1] + rand_samples = rand_samples * total_depth # 256, 256, N, 1 + + # Can also include boundaries + if use_box_boundaries: + rand_samples = torch.cat([rand_samples, boundary_samples, torch.zeros( + [bs, dim0, dim1, 1, 1], dtype=total_depth.dtype, device=total_depth.device)], dim=-2) + rand_samples, _ = torch.sort(rand_samples, dim=-2, descending=False) + + midpoints = (rand_samples[..., 1:, :] + rand_samples[..., :-1, :]) / 2 + new_dists = rand_samples[..., 1:, :] - rand_samples[..., :-1, :] + + # Scatter the random samples back + # 256, 256, 1, M, 1 > 256, 256, N, 1, 1 + idx = torch.sum(midpoints.unsqueeze(-3) > accu_depth.unsqueeze(-2), dim=-3) # 256, 256, M, 1 + # print(idx.shape, idx.max(), idx.min()) # max 3, min 0 + + depth_deltas = depth2[:, 0, :, :, 1:, :] - depth2[:, 1, :, :, :-1, :] # There might be NaNs! 
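+    # Added note: `midpoints` are cumulative distances measured only through
+    # occupied voxels, and `idx` identifies which intersected box each sample
+    # falls into. The cumsum/cat below builds, for box k, the depth of the first
+    # entry point plus all empty gaps skipped before box k, so `heads + midpoints`
+    # converts in-voxel distance back into true depth along the ray.
+    # E.g. with boxes spanning depths [1.0, 2.0] and [3.0, 3.5], a sample at
+    # in-voxel distance 1.2 lies in the second box (idx=1) and maps to depth
+    # 1.0 + 1.0 + 1.2 = 3.2.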
+ depth_deltas = torch.cumsum(depth_deltas, dim=-2) + depth_deltas = torch.cat([depth2[:, 0, :, :, [0], :], depth_deltas+depth2[:, 0, :, :, [0], :]], dim=-2) + heads = torch.gather(depth_deltas, -2, idx) # 256 256 M 1 + # heads = torch.gather(depth2[0], -2, idx) # 256 256 M 1 + + # print(torch.any(torch.isnan(heads))) + rand_depth = heads + midpoints # 256 256 N 1 + + return rand_depth, new_dists, idx + + +def volum_rendering_relu(sigma, dists, dim=2): + free_energy = F.relu(sigma) * dists + + a = 1 - torch.exp(-free_energy.float()) # probability of it is not empty here + b = torch.exp(-cumsum_exclusive(free_energy, dim=dim)) # probability of everything is empty up to now + probs = a * b # probability of the ray hits something here + + return probs + +class MCLabelTranslator: + r"""Resolving mapping across Minecraft voxel, coco-stuff label and reduced label set.""" + + def __init__(self): + this_path = os.path.dirname(os.path.abspath(__file__)) + # Load voxel name lut + id2name_lut = {} + id2color_lut = {} + id2glbl_lut = {} + with open(os.path.join(this_path, 'id2name_gg.csv'), newline='') as csvfile: + csvreader = csv.reader(csvfile, delimiter=',') + for row in csvreader: + id2name_lut[int(row[0])] = row[1] + id2color_lut[int(row[0])] = int(row[2]) + id2glbl_lut[int(row[0])] = row[3] + + # Load GauGAN color lut + glbl2color_lut = {} + glbl2cocoidx_lut = {} + with open(os.path.join(this_path, 'gaugan_lbl2col.csv'), newline='') as csvfile: + csvreader = csv.reader(csvfile, delimiter=',') + cocoidx = 1 # 0 is "Others" + for row in csvreader: + color = int(row[1].lstrip('#'), 16) + glbl2color_lut[row[0]] = color + glbl2cocoidx_lut[row[0]] = cocoidx + cocoidx += 1 + + # Generate id2ggcolor lut + id2ggcolor_lut = {} + for k, v in id2glbl_lut.items(): + if v: + id2ggcolor_lut[k] = glbl2color_lut[v] + else: + id2ggcolor_lut[k] = 0 + + # Generate id2cocoidx + id2cocoidx_lut = {} + for k, v in id2glbl_lut.items(): + if v: + id2cocoidx_lut[k] = glbl2cocoidx_lut[v] + else: + id2cocoidx_lut[k] = 0 + + self.id2color_lut = id2color_lut + self.id2name_lut = id2name_lut + self.id2glbl_lut = id2glbl_lut + self.id2ggcolor_lut = id2ggcolor_lut + self.id2cocoidx_lut = id2cocoidx_lut + + if True: + mapper = ReducedLabelMapper() + mcid2rdid_lut = mapper.mcid2rdid_lut + mcid2rdid_lut = torch.tensor(mcid2rdid_lut, dtype=torch.long) + self.mcid2rdid_lut = mcid2rdid_lut + self.num_reduced_lbls = len(mapper.reduced_lbls) + self.ignore_id = mapper.ignore_id + self.dirt_id = mapper.dirt_id + self.water_id = mapper.water_id + + self.mapper = mapper + + ggid2rdid_lut = mapper.ggid2rdid + [0] # Last index is ignore + ggid2rdid_lut = torch.tensor(ggid2rdid_lut, dtype=torch.long) + self.ggid2rdid_lut = ggid2rdid_lut + if True: + mc2coco_lut = list(zip(*sorted([(k, v) for k, v in self.id2cocoidx_lut.items()])))[1] + mc2coco_lut = torch.tensor(mc2coco_lut, dtype=torch.long) + self.mc2coco_lut = mc2coco_lut + + def gglbl2ggid(self, gglbl): + return self.mapper.gglbl2ggid[gglbl] + + def mc2coco(self, mc): + self.mc2coco_lut = self.mc2coco_lut.to(mc.device) + coco = self.mc2coco_lut[mc.long()] + return coco + + def mc2reduced(self, mc, ign2dirt=False): + self.mcid2rdid_lut = self.mcid2rdid_lut.to(mc.device) + reduced = self.mcid2rdid_lut[mc.long()] + if ign2dirt: + reduced[reduced == self.ignore_id] = self.dirt_id + return reduced + + def coco2reduced(self, coco): + self.ggid2rdid_lut = self.ggid2rdid_lut.to(coco.device) + reduced = self.ggid2rdid_lut[coco.long()] + return reduced + + def get_num_reduced_lbls(self): + return 
self.num_reduced_lbls + + @staticmethod + def uint32_to_4uint8(x): + dt1 = np.dtype(('i4', [('bytes', 'u1', 4)])) + color = x.view(dtype=dt1)['bytes'] + return color + + def mc_color(self, img): + r"""Obtaining Minecraft default color. + + Args: + img (H x W x 1 int32 numpy tensor): Segmentation map. + """ + lut = self.id2color_lut + lut = list(zip(*sorted([(k, v) for k, v in lut.items()])))[1] + lut = np.array(lut, dtype=np.uint32) + rgb = lut[img] + rgb = self.uint32_to_4uint8(rgb)[..., :3] + + return rgb + + +def rand_crop(cam_c, cam_res, target_res): + r"""Produces a new cam_c so that the effect of rendering with the new cam_c and target_res is the same as rendering + with the old parameters and then crop out target_res. + """ + d0 = np.random.randint(cam_res[0] - target_res[0] + 1) + d1 = np.random.randint(cam_res[1] - target_res[1] + 1) + cam_c = [cam_c[0]-d0, cam_c[1]-d1] + return cam_c + + +def segmask_smooth(seg_mask, kernel_size=7): + labels = F.avg_pool2d(seg_mask, kernel_size, 1, kernel_size//2) + onehot_idx = torch.argmax(labels, dim=1, keepdims=True) + labels.fill_(0.0) + labels.scatter_(1, onehot_idx, 1.0) + return labels + + +def colormap(x, cmap='viridis'): + x = np.nan_to_num(x, np.nan, np.nan, np.nan) + x = x - np.nanmin(x) + x = x / np.nanmax(x) + rgb = plt.get_cmap(cmap)(x)[..., :3] + return rgb diff --git a/imaginaire/model_utils/gancraft/reduced_coco_lbls.csv b/imaginaire/model_utils/gancraft/reduced_coco_lbls.csv new file mode 100644 index 0000000000000000000000000000000000000000..c82cc05572bbace78643911e9789f4f2cfd15f0e --- /dev/null +++ b/imaginaire/model_utils/gancraft/reduced_coco_lbls.csv @@ -0,0 +1,12 @@ +ignore +sky +tree +dirt +flower +grass +gravel +water +rock +stone +sand +snow \ No newline at end of file diff --git a/imaginaire/model_utils/gancraft/voxlib/Makefile b/imaginaire/model_utils/gancraft/voxlib/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..de903af09d2feda89118edc743048d70ce963d24 --- /dev/null +++ b/imaginaire/model_utils/gancraft/voxlib/Makefile @@ -0,0 +1,11 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md + +all: + python setup.py build_ext --inplace + python setup.py install + +clean: + rm -rf *.o *.a *.so test build diff --git a/imaginaire/model_utils/gancraft/voxlib/__init__.py b/imaginaire/model_utils/gancraft/voxlib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5fce15c92b99ef0ad9feaeb31d75664a0971b385 --- /dev/null +++ b/imaginaire/model_utils/gancraft/voxlib/__init__.py @@ -0,0 +1,7 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +from .positional_encoding import positional_encoding +from .sp_trilinear import sparse_trilinear_interp_worldcoord +from voxlib import ray_voxel_intersection_perspective diff --git a/imaginaire/model_utils/gancraft/voxlib/positional_encoding.py b/imaginaire/model_utils/gancraft/voxlib/positional_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..ef95d0bd47103233e1c4f32c70dd5b463c9eac1d --- /dev/null +++ b/imaginaire/model_utils/gancraft/voxlib/positional_encoding.py @@ -0,0 +1,63 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import torch +from torch.autograd import Function +import voxlib + +# Cheatsheet: +# mark_dirty() must be used to mark any input that is modified inplace by the forward function. +# mark_non_differentiable() + + +class PositionalEncodingFunction(Function): + @staticmethod + def forward(ctx, in_feature, pe_degrees, dim, incl_orig): + out_feature = voxlib.positional_encoding(in_feature, pe_degrees, dim, incl_orig) + + ctx.save_for_backward(out_feature) + ctx.pe_degrees = pe_degrees + ctx.dim = dim + ctx.incl_orig = incl_orig + + return out_feature + + @staticmethod + def backward(ctx, out_feature_grad): + out_feature, = ctx.saved_tensors + + # torch::Tensor positional_encoding_backward(const torch::Tensor& out_feature_grad, + # const torch::Tensor& out_feature, int ndegrees, int dim, bool incl_orig) { + in_feature_grad = voxlib.positional_encoding_backward( + out_feature_grad, out_feature, ctx.pe_degrees, ctx.dim, ctx.incl_orig) + + return in_feature_grad, None, None, None + + +def positional_encoding(in_feature, pe_degrees, dim=-1, incl_orig=False): + return PositionalEncodingFunction.apply(in_feature, pe_degrees, dim, incl_orig) + +# input: N, C +# output: N, pe_degrees*C + + +def positional_encoding_pt(pts, pe_degrees, dim=-1, incl_orig=False): + import numpy as np + pe_stor = [] + for i in range(pe_degrees): + pe_stor.append(torch.sin(pts * np.pi * 2 ** i)) + pe_stor.append(torch.cos(pts * np.pi * 2 ** i)) + if incl_orig: + pe_stor.append(pts) + pe = torch.cat(pe_stor, dim=dim) + return pe + + +if __name__ == '__main__': + x = torch.rand(384, 512, 5, 48).cuda() * 1024 + y = positional_encoding_pt(x, 4, incl_orig=True) + y2 = positional_encoding(x, 4, incl_orig=True) + + print(torch.abs(y - y2)) + print(torch.allclose(y, y2, rtol=1e-05, atol=1e-05)) diff --git a/imaginaire/model_utils/gancraft/voxlib/positional_encoding_kernel.cu b/imaginaire/model_utils/gancraft/voxlib/positional_encoding_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..af97ef1ce53aec55d248443f0ad17f7d545787a4 --- /dev/null +++ b/imaginaire/model_utils/gancraft/voxlib/positional_encoding_kernel.cu @@ -0,0 +1,285 @@ +// Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. 
+// To view a copy of this license, check out LICENSE.md + +#include + +#include +#include +#include +#include + +#include +#include +#include +#include + + +#include +#include +#include + + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + + +struct PE_Params { + int ndegrees; + int pre_size; + int post_size; + bool incl_orig; +}; + +// const int TILE_DIM_X = 16; // channel dim +// const int TILE_DIM_Y = 64; // entry dim +// dim3 dimGrid((p.post_size+TILE_DIM_X-1)/TILE_DIM_X, (p.pre_size+TILE_DIM_Y-1)/TILE_DIM_Y, 1); +// dim3 dimBlock(TILE_DIM_X, TILE_DIM_Y, 1); +template +__global__ void positional_encoding_kernel( + float* __restrict__ out_feature, + const float* __restrict__ in_feature, const PE_Params p) { + + const int idx_feat = blockIdx.x * TILE_DIM_X + threadIdx.x; + const int idx_entry_base = blockIdx.y * TILE_DIM_Y * DUP_Y + threadIdx.y * DUP_Y; + if (idx_feat >= p.post_size) { + return; + } + + int stride = p.ndegrees*2; + if (p.incl_orig) { + stride += 1; + } + + for (int j=0; j= p.pre_size) { + return; + } + float data = in_feature[idx_entry*p.post_size + idx_feat]; + + for (int i=0; i +__global__ void positional_encoding_backward_kernel( + float* __restrict__ in_feature_grad, + const float* __restrict__ out_feature_grad, const float* __restrict__ out_feature, const PE_Params p) { + + int idx_feat = blockIdx.x * TILE_DIM_X + threadIdx.x; + const int idx_entry_base = blockIdx.y * TILE_DIM_Y * DUP_Y + threadIdx.y * DUP_Y; + + if (idx_feat >= p.post_size) { + return; + } + + int stride = p.ndegrees*2; + if (p.incl_orig) { + stride += 1; + } + + for (int j=0; j= p.pre_size) { + return; + } + + float grad = 0.0f; + for (int i=0; i +torch::Tensor positional_encoding_cuda(const torch::Tensor& in_feature, int ndegrees, int dim, bool incl_orig) { + CHECK_CUDA(in_feature); + + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + torch::Device device = in_feature.device(); + + assert(in_feature.dtype() == torch::kFloat32); + + // Handle negative index + if (dim < 0) { + dim = in_feature.dim() + dim; + } + assert(dim >= 0 && dim < in_feature.dim()); + + // No need to be contiguous. Input and output has the same memory layout. + CHECK_CONTIGUOUS(in_feature); + + PE_Params p; + p.ndegrees = ndegrees; + p.incl_orig = incl_orig; + + // This only works for contiguous tensors... + int pre_size = 1; + int post_size = 1; + for (int i=0; i out_feature_shape; + for (int i=0; i Each thread handle a single post_size + // Case 2: Concat at the middle (post_size > pre_size) --> Each thread handle + const int TILE_DIM_X = 16; // channel dim + const int TILE_DIM_Y = 64; // entry dim + //const int DUP_Y = 4; // Each thread handle multiple entries to save threads + const int DUP_Y = 8; // DGXA 64 samples per ray @ 256x256 + dim3 dimGrid((p.post_size+TILE_DIM_X-1)/TILE_DIM_X, (p.pre_size+(TILE_DIM_Y*DUP_Y)-1)/(TILE_DIM_Y*DUP_Y), 1); + dim3 dimBlock(TILE_DIM_X, TILE_DIM_Y, 1); + positional_encoding_kernel<<>>( + out_feature.data_ptr(), + in_feature.data_ptr(), p + ); + + C10_CUDA_CHECK(cudaGetLastError()); + return out_feature; +} + +//in_feature_grad = voxrender_op.positional_encoding_backward(out_feature_grad, out_feature, ctx.pe_degrees, ctx.dim, ctx.incl_orig); +// Input: +// out_feature_grad: float32 [..., N*ndegree*2+incl_orig, ...] 
+// out_feature: float32 [..., N*ndegree*2+incl_orig, ...] +// ndegrees: int32 Degrees of PE encoding +// dim: int32 Dimension to concatenate +// incl_orig: bool Whether to include original feature vector or not +// Output: +// in_feature_grad: float32 [..., N, ...] +// std::vector +torch::Tensor positional_encoding_backward_cuda(const torch::Tensor& out_feature_grad_, const torch::Tensor& out_feature, int ndegrees, int dim, bool incl_orig) { + CHECK_CUDA(out_feature_grad_); + CHECK_CUDA(out_feature); + + const torch::Tensor out_feature_grad = out_feature_grad_.contiguous(); + + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + torch::Device device = out_feature_grad.device(); + + assert(out_feature_grad.dtype() == torch::kFloat32); + assert(out_feature.dtype() == torch::kFloat32); + assert(out_feature_grad.sizes() == out_feature.sizes()); + + // Handle negative index + if (dim < 0) { + dim = out_feature.dim() + dim; + } + assert(dim >= 0 && dim < out_feature.dim()); + + CHECK_CONTIGUOUS(out_feature_grad); + CHECK_CONTIGUOUS(out_feature); + + PE_Params p; + p.ndegrees = ndegrees; + p.incl_orig = incl_orig; + + int expansion_factor = ndegrees*2; + if (incl_orig) { + expansion_factor += 1; + } + // This only works for contiguous tensors... + int pre_size = 1; + int post_size = 1; + for (int i=0; i out_feature_shape; + for (int i=0; i Each thread handle a single post_size + // Case 2: Concat at the middle (post_size > pre_size) --> Each thread handle + const int TILE_DIM_X = 16; // channel dim + const int TILE_DIM_Y = 64; // entry dim + //const int DUP_Y = 4; // Nothing to amortize + const int DUP_Y = 8; // DGXA + dim3 dimGrid((p.post_size+TILE_DIM_X-1)/TILE_DIM_X, (p.pre_size+(TILE_DIM_Y*DUP_Y)-1)/(TILE_DIM_Y*DUP_Y), 1); + dim3 dimBlock(TILE_DIM_X, TILE_DIM_Y, 1); + positional_encoding_backward_kernel<<>>( + in_feature_grad.data_ptr(), + out_feature_grad.data_ptr(), out_feature.data_ptr(), p + ); + + C10_CUDA_CHECK(cudaGetLastError()); + + return in_feature_grad; +} diff --git a/imaginaire/model_utils/gancraft/voxlib/ray_voxel_intersection.cu b/imaginaire/model_utils/gancraft/voxlib/ray_voxel_intersection.cu new file mode 100644 index 0000000000000000000000000000000000000000..7ef22dc309e2eb6d944c50d917235f0c62219cb6 --- /dev/null +++ b/imaginaire/model_utils/gancraft/voxlib/ray_voxel_intersection.cu @@ -0,0 +1,325 @@ +// Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. 
+// To view a copy of this license, check out LICENSE.md +// +// The ray marching algorithm used in this file is a variety of modified Bresenham method: +// http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.42.3443&rep=rep1&type=pdf +// Search for "voxel traversal algorithm" for related information + +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +//#include +#include +#include +#include + +#include "voxlib_common.h" + +struct RVIP_Params { + int voxel_dims[3]; + int voxel_strides[3]; + int max_samples; + int img_dims[2]; + // Camera parameters + float cam_ori[3]; + float cam_fwd[3]; + float cam_side[3]; + float cam_up[3]; + float cam_c[2]; + float cam_f; + //unsigned long seed; +}; + +/* + out_voxel_id: torch CUDA int32 [ img_dims[0], img_dims[1], max_samples, 1] + out_depth: torch CUDA float [2, img_dims[0], img_dims[1], max_samples, 1] + out_raydirs: torch CUDA float [ img_dims[0], img_dims[1], 1, 3] + Image coordinates refer to the center of the pixel + [0, 0, 0] at voxel coordinate is at the corner of the corner block (instead of at the center) +*/ +template +static __global__ void ray_voxel_intersection_perspective_kernel(int32_t* __restrict__ out_voxel_id, float* __restrict__ out_depth, float* __restrict__ out_raydirs, +const int32_t* __restrict__ in_voxel, const RVIP_Params p) { + + int img_coords[2]; + img_coords[1] = blockIdx.x*TILE_DIM+threadIdx.x; + img_coords[0] = blockIdx.y*TILE_DIM+threadIdx.y; + if (img_coords[0] >= p.img_dims[0] || img_coords[1] >= p.img_dims[1]) { + return; + } + int pix_index = img_coords[0] * p.img_dims[1] + img_coords[1]; + + // Calculate ray origin and direction + float rayori[3], raydir[3]; + rayori[0] = p.cam_ori[0]; + rayori[1] = p.cam_ori[1]; + rayori[2] = p.cam_ori[2]; + + // Camera intrinsics + float ndc_imcoords[2]; + ndc_imcoords[0] = p.cam_c[0] - (float)img_coords[0]; // Flip height + ndc_imcoords[1] = (float)img_coords[1] - p.cam_c[1]; + + raydir[0] = p.cam_up[0] * ndc_imcoords[0] + p.cam_side[0] * ndc_imcoords[1] + p.cam_fwd[0] * p.cam_f; + raydir[1] = p.cam_up[1] * ndc_imcoords[0] + p.cam_side[1] * ndc_imcoords[1] + p.cam_fwd[1] * p.cam_f; + raydir[2] = p.cam_up[2] * ndc_imcoords[0] + p.cam_side[2] * ndc_imcoords[1] + p.cam_fwd[2] * p.cam_f; + normalize(raydir); + + // Save out_raydirs + out_raydirs[pix_index*3] = raydir[0]; + out_raydirs[pix_index*3+1] = raydir[1]; + out_raydirs[pix_index*3+2] = raydir[2]; + + float axis_t[3]; + int axis_int[3]; + //int axis_intbound[3]; + + // Current voxel + axis_int[0] = floorf(rayori[0]); + axis_int[1] = floorf(rayori[1]); + axis_int[2] = floorf(rayori[2]); + + #pragma unroll + for (int i=0; i<3; i++) { + if (raydir[i] > 0) { + // Initial t value + // Handle boundary case where rayori[i] is a whole number. 
Always round Up for the next block + //axis_t[i] = (ceilf(nextafterf(rayori[i], HUGE_VALF)) - rayori[i]) / raydir[i]; + axis_t[i] = ((float)(axis_int[i]+1) - rayori[i]) / raydir[i]; + } else if (raydir[i] < 0) { + axis_t[i] = ((float)axis_int[i] - rayori[i]) / raydir[i]; + } else { + axis_t[i] = HUGE_VALF; + } + } + + // Fused raymarching and sampling + bool quit = false; + for (int cur_plane=0; cur_plane < p.max_samples; cur_plane++) { // Last cycle is for calculating p2 + float t = nanf("0"); + float t2 = nanf("0"); + int32_t blk_id = 0; + // Find the next intersection + while (!quit) { + // Find the next smallest t + float tnow; + /* + #pragma unroll + for (int i=0; i<3; i++) { + if (axis_t[i] <= axis_t[(i+1)%3] && axis_t[i] <= axis_t[(i+2)%3]) { + // Update current t + tnow = axis_t[i]; + // Update t candidates + if (raydir[i] > 0) { + axis_int[i] += 1; + if (axis_int[i] >= p.voxel_dims[i]) { + quit = true; + } + axis_t[i] = ((float)(axis_int[i]+1) - rayori[i]) / raydir[i]; + } else { + axis_int[i] -= 1; + if (axis_int[i] < 0) { + quit = true; + } + axis_t[i] = ((float)axis_int[i] - rayori[i]) / raydir[i]; + } + break; // Avoid advancing multiple steps as axis_t is updated + } + } + */ + // Hand unroll + if (axis_t[0] <= axis_t[1] && axis_t[0] <= axis_t[2]) { + // Update current t + tnow = axis_t[0]; + // Update t candidates + if (raydir[0] > 0) { + axis_int[0] += 1; + if (axis_int[0] >= p.voxel_dims[0]) { + quit = true; + } + axis_t[0] = ((float)(axis_int[0]+1) - rayori[0]) / raydir[0]; + } else { + axis_int[0] -= 1; + if (axis_int[0] < 0) { + quit = true; + } + axis_t[0] = ((float)axis_int[0] - rayori[0]) / raydir[0]; + } + } else if (axis_t[1] <= axis_t[2]) { + tnow = axis_t[1]; + if (raydir[1] > 0) { + axis_int[1] += 1; + if (axis_int[1] >= p.voxel_dims[1]) { + quit = true; + } + axis_t[1] = ((float)(axis_int[1]+1) - rayori[1]) / raydir[1]; + } else { + axis_int[1] -= 1; + if (axis_int[1] < 0) { + quit = true; + } + axis_t[1] = ((float)axis_int[1] - rayori[1]) / raydir[1]; + } + } else { + tnow = axis_t[2]; + if (raydir[2] > 0) { + axis_int[2] += 1; + if (axis_int[2] >= p.voxel_dims[2]) { + quit = true; + } + axis_t[2] = ((float)(axis_int[2]+1) - rayori[2]) / raydir[2]; + } else { + axis_int[2] -= 1; + if (axis_int[2] < 0) { + quit = true; + } + axis_t[2] = ((float)axis_int[2] - rayori[2]) / raydir[2]; + } + } + + if (quit) { + break; + } + + // Skip empty space + // Could there be deadlock if the ray direction is away from the world? 
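+            // Added note: a deadlock should not occur here. Every iteration
+            // advances axis_int along the axis with the smallest pending t,
+            // and `quit` is set as soon as that index steps outside
+            // [0, voxel_dims). Since raydir is normalized, at least one
+            // component is non-zero, so some axis keeps stepping and the
+            // march always reaches a boundary and terminates.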
+ if (axis_int[0] < 0 || axis_int[0] >= p.voxel_dims[0] || axis_int[1] < 0 || axis_int[1] >= p.voxel_dims[1] || axis_int[2] < 0 || axis_int[2] >= p.voxel_dims[2]) { + continue; + } + + // Test intersection using voxel grid + blk_id = in_voxel[axis_int[0]*p.voxel_strides[0] + axis_int[1]*p.voxel_strides[1] + axis_int[2]*p.voxel_strides[2]]; + if (blk_id == 0) { + continue; + } + + // Now that there is an intersection + t = tnow; + // Calculate t2 + /* + #pragma unroll + for (int i=0; i<3; i++) { + if (axis_t[i] <= axis_t[(i+1)%3] && axis_t[i] <= axis_t[(i+2)%3]) { + t2 = axis_t[i]; + break; + } + } + */ + // Hand unroll + if (axis_t[0] <= axis_t[1] && axis_t[0] <= axis_t[2]) { + t2 = axis_t[0]; + } else if (axis_t[1] <= axis_t[2]) { + t2 = axis_t[1]; + } else { + t2 = axis_t[2]; + } + break; + } // while !quit (ray marching loop) + + out_depth[pix_index*p.max_samples+cur_plane] = t; + out_depth[p.img_dims[0]*p.img_dims[1]*p.max_samples + pix_index*p.max_samples+cur_plane] = t2; + out_voxel_id[pix_index*p.max_samples+cur_plane] = blk_id; + } // cur_plane +} + + +/* + out: + out_voxel_id: torch CUDA int32 [ img_dims[0], img_dims[1], max_samples, 1] + out_depth: torch CUDA float [2, img_dims[0], img_dims[1], max_samples, 1] + out_raydirs: torch CUDA float [ img_dims[0], img_dims[1], 1, 3] + in: + in_voxel: torch CUDA int32 [X, Y, Z] [40, 512, 512] + cam_ori: torch float [3] + cam_dir: torch float [3] + cam_up: torch float [3] + cam_f: float + cam_c: int [2] + img_dims: int [2] + max_samples: int +*/ +std::vector ray_voxel_intersection_perspective_cuda(const torch::Tensor& in_voxel, const torch::Tensor& cam_ori, const torch::Tensor& cam_dir, const torch::Tensor& cam_up, float cam_f, const std::vector& cam_c, const std::vector& img_dims, int max_samples) { + CHECK_CUDA(in_voxel); + + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + torch::Device device = in_voxel.device(); + + //assert(in_voxel.dtype() == torch::kU8); + assert(in_voxel.dtype() == torch::kInt32); // Minecraft compatibility + assert(in_voxel.dim() == 3); + assert(cam_ori.dtype() == torch::kFloat32); + assert(cam_ori.numel() == 3); + assert(cam_dir.dtype() == torch::kFloat32); + assert(cam_dir.numel() == 3); + assert(cam_up.dtype() == torch::kFloat32); + assert(cam_up.numel() == 3); + assert(img_dims.size() == 2); + + RVIP_Params p; + + // Calculate camera rays + const torch::Tensor cam_ori_c = cam_ori.cpu(); + const torch::Tensor cam_dir_c = cam_dir.cpu(); + const torch::Tensor cam_up_c = cam_up.cpu(); + + // Get the coordinate frame of camera space in world space + normalize(p.cam_fwd, cam_dir_c.data_ptr()); + cross(p.cam_side, p.cam_fwd, cam_up_c.data_ptr()); + normalize(p.cam_side); + cross(p.cam_up, p.cam_side, p.cam_fwd); + normalize(p.cam_up); // Not absolutely necessary as both vectors are normalized. But just in case... 
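+    // Added note: cam_c is the principal point in (row, col) pixel coordinates
+    // and cam_f the focal length in pixels; the kernel above combines them with
+    // this orthonormal (side, up, fwd) basis to build one ray per pixel.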
+ + copyarr(p.cam_ori, cam_ori_c.data_ptr()); + + p.cam_f = cam_f; + p.cam_c[0] = cam_c[0]; + p.cam_c[1] = cam_c[1]; + p.max_samples = max_samples; + //printf("[Renderer] max_dist: %ld\n", max_dist); + + p.voxel_dims[0] = in_voxel.size(0); + p.voxel_dims[1] = in_voxel.size(1); + p.voxel_dims[2] = in_voxel.size(2); + p.voxel_strides[0] = in_voxel.stride(0); + p.voxel_strides[1] = in_voxel.stride(1); + p.voxel_strides[2] = in_voxel.stride(2); + + //printf("[Renderer] Voxel resolution: %ld, %ld, %ld\n", p.voxel_dims[0], p.voxel_dims[1], p.voxel_dims[2]); + + p.img_dims[0] = img_dims[0]; + p.img_dims[1] = img_dims[1]; + + // Create output tensors + // For Minecraft Seg Mask + torch::Tensor out_voxel_id = torch::empty({p.img_dims[0], p.img_dims[1], p.max_samples, 1}, torch::TensorOptions().dtype(torch::kInt32).device(device)); + + torch::Tensor out_depth; + // Produce two sets of localcoords, one for entry point, the other one for exit point. They share the same corner_ids. + out_depth = torch::empty({2, p.img_dims[0], p.img_dims[1], p.max_samples, 1}, torch::TensorOptions().dtype(torch::kFloat32).device(device)); + + torch::Tensor out_raydirs = torch::empty({p.img_dims[0], p.img_dims[1], 1, 3}, torch::TensorOptions().dtype(torch::kFloat32).device(device).requires_grad(false)); + + const int TILE_DIM = 8; + dim3 dimGrid((p.img_dims[1]+TILE_DIM-1)/TILE_DIM, (p.img_dims[0]+TILE_DIM-1)/TILE_DIM, 1); + dim3 dimBlock(TILE_DIM, TILE_DIM, 1); + + ray_voxel_intersection_perspective_kernel<<>>( + out_voxel_id.data_ptr(), out_depth.data_ptr(), out_raydirs.data_ptr(), in_voxel.data_ptr(), p + ); + + return {out_voxel_id, out_depth, out_raydirs}; +} diff --git a/imaginaire/model_utils/gancraft/voxlib/setup.py b/imaginaire/model_utils/gancraft/voxlib/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..1eca848211370c8ddf9dda55d5c67804a73061e9 --- /dev/null +++ b/imaginaire/model_utils/gancraft/voxlib/setup.py @@ -0,0 +1,25 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +cxx_args = ['-fopenmp'] +nvcc_args = [] + +setup( + name='voxrender', + ext_modules=[ + CUDAExtension('voxlib', [ + 'voxlib.cpp', + 'ray_voxel_intersection.cu', + 'sp_trilinear_worldcoord_kernel.cu', + 'positional_encoding_kernel.cu' + ], + extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args} + ) + ], + cmdclass={ + 'build_ext': BuildExtension + }) diff --git a/imaginaire/model_utils/gancraft/voxlib/sp_trilinear.py b/imaginaire/model_utils/gancraft/voxlib/sp_trilinear.py new file mode 100644 index 0000000000000000000000000000000000000000..1bad56fb23f6b8e2a8e41573f8b6b85f9a5693f1 --- /dev/null +++ b/imaginaire/model_utils/gancraft/voxlib/sp_trilinear.py @@ -0,0 +1,35 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +from torch.autograd import Function +import voxlib + +""" +It takes world coordinate as input instead of block-local coordinate. Corner IDs are looked up on-the-fly to +save memory. 
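+For a rough picture: in_worldcoord holds continuous world-space sample positions,
+corner_lut_t maps every integer voxel corner to a row of in_feature, and when
+ign_zero is set a LUT entry of 0 means "no feature" and contributes nothing to
+the interpolated result.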
+""" + + +class SparseTrilinearWorldCoordFunction(Function): + @staticmethod + def forward(ctx, in_feature, corner_lut_t, in_worldcoord, ign_zero): + + out_feature = voxlib.sp_trilinear_worldcoord(in_feature, corner_lut_t, in_worldcoord, ign_zero, -1) + ctx.ign_zero = ign_zero + ctx.save_for_backward(in_feature, corner_lut_t, in_worldcoord) + + return out_feature + + @staticmethod + def backward(ctx, out_feature_grad): + in_feature, corner_lut_t, in_worldcoord = ctx.saved_tensors + + assert ctx.needs_input_grad[2] is False + in_feature_grad, = voxlib.sp_trilinear_worldcoord_backward( + out_feature_grad, in_feature, corner_lut_t, in_worldcoord, ctx.ign_zero, False) + return in_feature_grad, None, None, None, None + + +def sparse_trilinear_interp_worldcoord(in_feature, corner_lut_t, in_worldcoord, ign_zero=False): + return SparseTrilinearWorldCoordFunction.apply(in_feature, corner_lut_t, in_worldcoord, ign_zero) diff --git a/imaginaire/model_utils/gancraft/voxlib/sp_trilinear_worldcoord_kernel.cu b/imaginaire/model_utils/gancraft/voxlib/sp_trilinear_worldcoord_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..92f0378a746b74f7527c16c97d4c3d14feaf5881 --- /dev/null +++ b/imaginaire/model_utils/gancraft/voxlib/sp_trilinear_worldcoord_kernel.cu @@ -0,0 +1,527 @@ +// Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. +// To view a copy of this license, check out LICENSE.md +// +// Fast routine for sparse tri-linear interpolation of high dimensional features. +// Ignore label is supported. + + +#include + +#include +#include +#include +#include + +#include +#include +#include + + +#include +#include +#include + + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + + +struct SpTrilinear_wc_Params { + int in_feature_dim; + int in_feature_numentries; + int corner_lut_dims[3]; + int corner_lut_strides[3]; + int in_worldcoord_dims[8]; + int in_worldcoord_strides[8]; + int in_worldcoord_ndim; + int out_feature_dims[8]; + int out_feature_strides[8]; + bool ign_zero; +}; + + +// out_feature.data_ptr(), +// in_feature.data_ptr(), corner_lut_t.data_ptr(), in_worldcoord.data_ptr(), p +template +__global__ void sp_trilinear_worldcoord_kernel( + float* __restrict__ out_feature, + const float* __restrict__ in_feature, const int32_t* __restrict__ corner_lut_t, const float* __restrict__ in_worldcoord, SpTrilinear_wc_Params p) { + + const int GRID_X = gridDim.y; + int idx_entry = blockIdx.x * TILE_DIM_Y + threadIdx.y; + + // Index processing + //int index[7]; + int t = idx_entry; + int idx_in_worldcoord = 0; + int idx_out_feature = 0; + for (int i=p.in_worldcoord_ndim-2; i>=0; i--) { + int idx_t = t % p.in_worldcoord_dims[i]; + t = t / p.in_worldcoord_dims[i]; + idx_in_worldcoord += p.in_worldcoord_strides[i] * idx_t; + idx_out_feature += p.out_feature_strides[i] * idx_t; + } + if (t > 0) { + return; + } + int stride_in_worldcoord = p.in_worldcoord_strides[p.in_worldcoord_ndim-1]; + int stride_out_feature = p.out_feature_strides[p.in_worldcoord_ndim-1]; + + + float world_coords[3]; + world_coords[0] = in_worldcoord[idx_in_worldcoord]; + world_coords[1] = in_worldcoord[idx_in_worldcoord+stride_in_worldcoord]; + world_coords[2] = in_worldcoord[idx_in_worldcoord+stride_in_worldcoord*2]; + + float local_coords[3]; + int 
vox_coords[3]; + local_coords[0] = world_coords[0] - floorf(world_coords[0]); + vox_coords[0] = (int)floorf(world_coords[0]); + local_coords[1] = world_coords[1] - floorf(world_coords[1]); + vox_coords[1] = (int)floorf(world_coords[1]); + local_coords[2] = world_coords[2] - floorf(world_coords[2]); + vox_coords[2] = (int)floorf(world_coords[2]); + + float interp_weight[8]; + // 0,0,0 + interp_weight[0] = (1.0f-local_coords[0])*(1.0f-local_coords[1])*(1.0f-local_coords[2]); + // 0,0,1 + interp_weight[1] = (1.0f-local_coords[0])*(1.0f-local_coords[1])*(local_coords[2]); + // 0,1,0 + interp_weight[2] = (1.0f-local_coords[0])*(local_coords[1])*(1.0f-local_coords[2]); + // 0,1,1 + interp_weight[3] = (1.0f-local_coords[0])*(local_coords[1])*(local_coords[2]); + // 1,0,0 + interp_weight[4] = (local_coords[0])*(1.0f-local_coords[1])*(1.0f-local_coords[2]); + // 1,0,1 + interp_weight[5] = (local_coords[0])*(1.0f-local_coords[1])*(local_coords[2]); + // 1,1,0 + interp_weight[6] = (local_coords[0])*(local_coords[1])*(1.0f-local_coords[2]); + // 1,1,1 + interp_weight[7] = (local_coords[0])*(local_coords[1])*(local_coords[2]); + + int indices[8]; + // Hard boundary check (zero padding) + if (isnan(world_coords[0]) || isnan(world_coords[1]) || isnan(world_coords[2])) { + indices[0] = -1;indices[1] = -1;indices[2] = -1;indices[3] = -1; + indices[4] = -1;indices[5] = -1;indices[6] = -1;indices[7] = -1; + } else { + // Clamp to boundaries + int vox_coords_1[3]; + vox_coords_1[0] = min(max(vox_coords[0]+1, 0), p.corner_lut_dims[0]-1); + vox_coords_1[1] = min(max(vox_coords[1]+1, 0), p.corner_lut_dims[1]-1); + vox_coords_1[2] = min(max(vox_coords[2]+1, 0), p.corner_lut_dims[2]-1); + vox_coords[0] = min(max(vox_coords[0], 0), p.corner_lut_dims[0]-1); + vox_coords[1] = min(max(vox_coords[1], 0), p.corner_lut_dims[1]-1); + vox_coords[2] = min(max(vox_coords[2], 0), p.corner_lut_dims[2]-1); + int idx_corner_lut; + // 000 + idx_corner_lut = p.corner_lut_strides[0] * vox_coords[0] + + p.corner_lut_strides[1] * vox_coords[1] + + p.corner_lut_strides[2] * vox_coords[2]; + indices[0] = corner_lut_t[idx_corner_lut]; + // 001 + idx_corner_lut = p.corner_lut_strides[0] * vox_coords[0] + + p.corner_lut_strides[1] * vox_coords[1] + + p.corner_lut_strides[2] * vox_coords_1[2]; + indices[1] = corner_lut_t[idx_corner_lut]; + // 010 + idx_corner_lut = p.corner_lut_strides[0] * vox_coords[0] + + p.corner_lut_strides[1] * vox_coords_1[1] + + p.corner_lut_strides[2] * vox_coords[2]; + indices[2] = corner_lut_t[idx_corner_lut]; + // 011 + idx_corner_lut = p.corner_lut_strides[0] * vox_coords[0] + + p.corner_lut_strides[1] * vox_coords_1[1] + + p.corner_lut_strides[2] * vox_coords_1[2]; + indices[3] = corner_lut_t[idx_corner_lut]; + // 100 + idx_corner_lut = p.corner_lut_strides[0] * vox_coords_1[0] + + p.corner_lut_strides[1] * vox_coords[1] + + p.corner_lut_strides[2] * vox_coords[2]; + indices[4] = corner_lut_t[idx_corner_lut]; + // 101 + idx_corner_lut = p.corner_lut_strides[0] * vox_coords_1[0] + + p.corner_lut_strides[1] * vox_coords[1] + + p.corner_lut_strides[2] * vox_coords_1[2]; + indices[5] = corner_lut_t[idx_corner_lut]; + // 110 + idx_corner_lut = p.corner_lut_strides[0] * vox_coords_1[0] + + p.corner_lut_strides[1] * vox_coords_1[1] + + p.corner_lut_strides[2] * vox_coords[2]; + indices[6] = corner_lut_t[idx_corner_lut]; + // 111 + idx_corner_lut = p.corner_lut_strides[0] * vox_coords_1[0] + + p.corner_lut_strides[1] * vox_coords_1[1] + + p.corner_lut_strides[2] * vox_coords_1[2]; + indices[7] = 
corner_lut_t[idx_corner_lut]; + } + + if (p.ign_zero) { + // Zero indices are to be ignored +#pragma unroll + for (int i=0; i<8; i++) { + indices[i] -= 1; + } + } + + //int idx_feat = blockIdx.x * TILE_DIM_X * DUP_X + threadIdx.x; + int idx_feat = blockIdx.y * TILE_DIM_X + threadIdx.x; + for (int i=0; i= p.in_feature_dim) { + return; + } + float interp_feat = 0.0f; +#pragma unroll + for (int j=0; j<8; j++) { + if (indices[j] >= 0) { + interp_feat = fmaf(in_feature[indices[j]*p.in_feature_dim+idx_feat], interp_weight[j], interp_feat); + } + } + //out_feature[idx_entry*p.in_feature_dim+idx_feat] = interp_feat; + out_feature[idx_out_feature+stride_out_feature*idx_feat] = interp_feat; + //idx_feat += TILE_DIM_X; + idx_feat += TILE_DIM_X * GRID_X; + } +} + + +//sp_trilinear_worldcoord_backward2feature_kernel<<>>( +// in_feature_grad.data_ptr(), +// out_feature_grad.data_ptr(), in_feature.data_ptr(), in_corner_lut.data_ptr(), in_worldcoord.data_ptr(), p +// Backward to feature +template +__global__ void sp_trilinear_worldcoord_backward2feature_kernel( + float* __restrict__ in_feature_grad, + const float* __restrict__ out_feature_grad, const int32_t* __restrict__ corner_lut_t, const float* __restrict__ in_worldcoord, SpTrilinear_wc_Params p) { + + const int GRID_X = gridDim.x; + int idx_entry = blockIdx.y * TILE_DIM_Y + threadIdx.y; + + // Index processing + //int index[7]; + int t = idx_entry; + int idx_in_worldcoord = 0; + int idx_out_feature = 0; + for (int i=p.in_worldcoord_ndim-2; i>=0; i--) { + int idx_t = t % p.in_worldcoord_dims[i]; + t = t / p.in_worldcoord_dims[i]; + //index[i] = idx_t; + idx_in_worldcoord += p.in_worldcoord_strides[i] * idx_t; + idx_out_feature += p.out_feature_strides[i] * idx_t; + } + if (t > 0) { + return; + } + int stride_in_worldcoord = p.in_worldcoord_strides[p.in_worldcoord_ndim-1]; + int stride_out_feature = p.out_feature_strides[p.in_worldcoord_ndim-1]; + + float world_coords[3]; + world_coords[0] = in_worldcoord[idx_in_worldcoord]; + world_coords[1] = in_worldcoord[idx_in_worldcoord+stride_in_worldcoord]; + world_coords[2] = in_worldcoord[idx_in_worldcoord+stride_in_worldcoord*2]; + + float local_coords[3]; + int vox_coords[3]; + local_coords[0] = world_coords[0] - floorf(world_coords[0]); + vox_coords[0] = (int)floorf(world_coords[0]); + local_coords[1] = world_coords[1] - floorf(world_coords[1]); + vox_coords[1] = (int)floorf(world_coords[1]); + local_coords[2] = world_coords[2] - floorf(world_coords[2]); + vox_coords[2] = (int)floorf(world_coords[2]); + + float interp_weight[8]; + // 0,0,0 + interp_weight[0] = (1.0f-local_coords[0])*(1.0f-local_coords[1])*(1.0f-local_coords[2]); + // 0,0,1 + interp_weight[1] = (1.0f-local_coords[0])*(1.0f-local_coords[1])*(local_coords[2]); + // 0,1,0 + interp_weight[2] = (1.0f-local_coords[0])*(local_coords[1])*(1.0f-local_coords[2]); + // 0,1,1 + interp_weight[3] = (1.0f-local_coords[0])*(local_coords[1])*(local_coords[2]); + // 1,0,0 + interp_weight[4] = (local_coords[0])*(1.0f-local_coords[1])*(1.0f-local_coords[2]); + // 1,0,1 + interp_weight[5] = (local_coords[0])*(1.0f-local_coords[1])*(local_coords[2]); + // 1,1,0 + interp_weight[6] = (local_coords[0])*(local_coords[1])*(1.0f-local_coords[2]); + // 1,1,1 + interp_weight[7] = (local_coords[0])*(local_coords[1])*(local_coords[2]); + + int indices[8]; + // Hard boundary check (zero padding) + if (isnan(world_coords[0]) || isnan(world_coords[1]) || isnan(world_coords[2])) {// || + //vox_coords[0] < 0 || vox_coords[0] >= (p.corner_lut_dims[0]-1) || + //vox_coords[1] < 0 
|| vox_coords[1] >= (p.corner_lut_dims[1]-1) || + //vox_coords[2] < 0 || vox_coords[2] >= (p.corner_lut_dims[2]-1)) { + indices[0] = -1;indices[1] = -1;indices[2] = -1;indices[3] = -1; + indices[4] = -1;indices[5] = -1;indices[6] = -1;indices[7] = -1; + } else { + // Clamp to boundaries + int vox_coords_1[3]; + vox_coords_1[0] = min(max(vox_coords[0]+1, 0), p.corner_lut_dims[0]-1); + vox_coords_1[1] = min(max(vox_coords[1]+1, 0), p.corner_lut_dims[1]-1); + vox_coords_1[2] = min(max(vox_coords[2]+1, 0), p.corner_lut_dims[2]-1); + vox_coords[0] = min(max(vox_coords[0], 0), p.corner_lut_dims[0]-1); + vox_coords[1] = min(max(vox_coords[1], 0), p.corner_lut_dims[1]-1); + vox_coords[2] = min(max(vox_coords[2], 0), p.corner_lut_dims[2]-1); + int idx_corner_lut; + // 000 + idx_corner_lut = p.corner_lut_strides[0] * vox_coords[0] + + p.corner_lut_strides[1] * vox_coords[1] + + p.corner_lut_strides[2] * vox_coords[2]; + indices[0] = corner_lut_t[idx_corner_lut]; + // 001 + idx_corner_lut = p.corner_lut_strides[0] * vox_coords[0] + + p.corner_lut_strides[1] * vox_coords[1] + + p.corner_lut_strides[2] * vox_coords_1[2]; + indices[1] = corner_lut_t[idx_corner_lut]; + // 010 + idx_corner_lut = p.corner_lut_strides[0] * vox_coords[0] + + p.corner_lut_strides[1] * vox_coords_1[1] + + p.corner_lut_strides[2] * vox_coords[2]; + indices[2] = corner_lut_t[idx_corner_lut]; + // 011 + idx_corner_lut = p.corner_lut_strides[0] * vox_coords[0] + + p.corner_lut_strides[1] * vox_coords_1[1] + + p.corner_lut_strides[2] * vox_coords_1[2]; + indices[3] = corner_lut_t[idx_corner_lut]; + // 100 + idx_corner_lut = p.corner_lut_strides[0] * vox_coords_1[0] + + p.corner_lut_strides[1] * vox_coords[1] + + p.corner_lut_strides[2] * vox_coords[2]; + indices[4] = corner_lut_t[idx_corner_lut]; + // 101 + idx_corner_lut = p.corner_lut_strides[0] * vox_coords_1[0] + + p.corner_lut_strides[1] * vox_coords[1] + + p.corner_lut_strides[2] * vox_coords_1[2]; + indices[5] = corner_lut_t[idx_corner_lut]; + // 110 + idx_corner_lut = p.corner_lut_strides[0] * vox_coords_1[0] + + p.corner_lut_strides[1] * vox_coords_1[1] + + p.corner_lut_strides[2] * vox_coords[2]; + indices[6] = corner_lut_t[idx_corner_lut]; + // 111 + idx_corner_lut = p.corner_lut_strides[0] * vox_coords_1[0] + + p.corner_lut_strides[1] * vox_coords_1[1] + + p.corner_lut_strides[2] * vox_coords_1[2]; + indices[7] = corner_lut_t[idx_corner_lut]; + } + + if (p.ign_zero) { +#pragma unroll + for (int i=0; i<8; i++) { + indices[i] -= 1; + } + } + + //int idx_feat = blockIdx.x * TILE_DIM_X * DUP_X + threadIdx.x; + int idx_feat = blockIdx.x * TILE_DIM_X + threadIdx.x; + for (int i=0; i= p.in_feature_dim) { + return; + } + float grad = out_feature_grad[idx_out_feature+stride_out_feature*idx_feat]; +#pragma unroll + for (int j=0; j<8; j++) { + if (indices[j] >= 0) { + //indices[j]*p.in_feature_dim+idx_feat + atomicAdd(&in_feature_grad[indices[j]*p.in_feature_dim+idx_feat], grad * interp_weight[j]); + } + } + //idx_feat += TILE_DIM_X; + idx_feat += TILE_DIM_X * GRID_X; + } +} + +// in_feature, corner_lut_t, in_world_coord, ign_zero=False +// Input: +// in_feature: float32 [M C] +// in_corner_lut: int32 [X Y Z] +// in_worldcoord: float32 [..., 3] +// ---Index: int32 [..., 8], containing [0, M]. 0 is ignore label. +// ---Coord: float32 [..., 3] +// Output: +// Interp. 
Feat: float32 [..., C] +// std::vector +torch::Tensor sp_trilinear_worldcoord_cuda(const torch::Tensor& in_feature, const torch::Tensor& in_corner_lut, const torch::Tensor& in_worldcoord, bool ign_zero, int channel_pos) { + CHECK_CUDA(in_feature); + CHECK_CUDA(in_corner_lut); + CHECK_CUDA(in_worldcoord); + + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + torch::Device device = in_feature.device(); + + // assert(tensor.sizes() == std::vector{3, 4, 5}); + assert(in_feature.dtype() == torch::kFloat32); + assert(in_feature.dim() == 2); + assert(in_corner_lut.dtype() == torch::kInt32); + assert(in_corner_lut.dim() == 3); + assert(in_worldcoord.dtype() == torch::kFloat32); + assert(in_worldcoord.size(-1) == 3); + assert(in_worldcoord.dim() <= 8); + + CHECK_CONTIGUOUS(in_feature); + //CHECK_CONTIGUOUS(in_corner_lut); // Will still run correctly, but performance will suffer. + //CHECK_CONTIGUOUS(in_worldcoord); + + //int channel_pos = -1; // -1 for HWC, -3 for CHW + if (channel_pos < 0) { + channel_pos += in_worldcoord.dim(); + } + assert(channel_pos >= 0 && channel_pos < in_worldcoord.dim()); + + SpTrilinear_wc_Params p; + p.in_feature_dim = in_feature.size(1); + p.in_feature_numentries = in_feature.size(0); + p.in_worldcoord_ndim = in_worldcoord.dim(); + for (int i=0; i out_feature_shape; + //if (channel_first) { // Channel First format, suitable for 2D convolution + // //assert(false); + for (int i=0; i<<>>( + out_feature.data_ptr(), + in_feature.data_ptr(), in_corner_lut.data_ptr(), in_worldcoord.data_ptr(), p + ); + C10_CUDA_CHECK(cudaGetLastError()); + return out_feature; +} + + +// Backward function for sparse trilinear interpolation +// Input: +// out_feature_grad: float32 [..., C] +// in_feature: float32 [M, C] +// in_corner_lut: int32 [X Y Z] +// ---in_index: int32 [..., 8], containing [0, M]. 0 is ignore label. +// in_worldcoord: float32 [..., 3] +// ign_zero: bool +// need_coord_grad: bool +// Output: +// in_feature_grad: float32 [M, C] +// in_coord_grad: float32 [..., 3] +std::vector sp_trilinear_worldcoord_backward_cuda(const torch::Tensor& out_feature_grad , const torch::Tensor& in_feature, const torch::Tensor& in_corner_lut, const torch::Tensor& in_worldcoord, bool ign_zero, bool need_coord_grad) { + assert(need_coord_grad == false); + CHECK_CUDA(out_feature_grad); + CHECK_CUDA(in_feature); + CHECK_CUDA(in_corner_lut); + CHECK_CUDA(in_worldcoord); + + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + torch::Device device = out_feature_grad.device(); + + //for (int i=0; i{3, 4, 5}); + assert(out_feature_grad.dtype() == torch::kFloat32); + for (int i=0; i<<>>( + in_feature_grad.data_ptr(), + out_feature_grad.data_ptr(), in_corner_lut.data_ptr(), in_worldcoord.data_ptr(), p + ); + } + + C10_CUDA_CHECK(cudaGetLastError()); + return {in_feature_grad}; +} diff --git a/imaginaire/model_utils/gancraft/voxlib/voxlib.cpp b/imaginaire/model_utils/gancraft/voxlib/voxlib.cpp new file mode 100644 index 0000000000000000000000000000000000000000..70095052d71f53e5e519a5f57b3c0848998a1b22 --- /dev/null +++ b/imaginaire/model_utils/gancraft/voxlib/voxlib.cpp @@ -0,0 +1,31 @@ +// Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. 
+// To view a copy of this license, check out LICENSE.md
+#include <torch/extension.h>
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+#include <vector>
+
+// Fast voxel traversal along rays
+std::vector<torch::Tensor> ray_voxel_intersection_perspective_cuda(const torch::Tensor& in_voxel, const torch::Tensor& cam_ori, const torch::Tensor& cam_dir, const torch::Tensor& cam_up, float cam_f, const std::vector<float>& cam_c, const std::vector<int>& img_dims, int max_samples);
+
+
+// World Coordinate Sparse Trilinear Interpolation
+torch::Tensor sp_trilinear_worldcoord_cuda(const torch::Tensor& in_feature, const torch::Tensor& in_corner_lut, const torch::Tensor& in_worldcoord, bool ign_zero, int channel_pos);
+
+std::vector<torch::Tensor> sp_trilinear_worldcoord_backward_cuda(const torch::Tensor& out_feature_grad, const torch::Tensor& in_feature, const torch::Tensor& in_corner_lut, const torch::Tensor& in_worldcoord, bool ign_zero, bool need_coord_grad);
+
+// Fast & Memory Efficient Positional Encoding
+torch::Tensor positional_encoding_cuda(const torch::Tensor& in_feature, int ndegrees, int dim, bool incl_orig);
+
+torch::Tensor positional_encoding_backward_cuda(const torch::Tensor& out_feature_grad, const torch::Tensor& out_feature, int ndegrees, int dim, bool incl_orig);
+
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("ray_voxel_intersection_perspective", &ray_voxel_intersection_perspective_cuda, "Ray-voxel intersections given perspective camera parameters (CUDA)");
+    m.def("sp_trilinear_worldcoord", &sp_trilinear_worldcoord_cuda, "Sparse Trilinear interpolation, world coordinate [forward] (CUDA)");
+    m.def("sp_trilinear_worldcoord_backward", &sp_trilinear_worldcoord_backward_cuda, "Sparse Trilinear interpolation, world coordinate [backward] (CUDA)");
+    m.def("positional_encoding", &positional_encoding_cuda, "Fused Positional Encoding [forward] (CUDA)");
+    m.def("positional_encoding_backward", &positional_encoding_backward_cuda, "Fused Positional Encoding [backward] (CUDA)");
+}
\ No newline at end of file
diff --git a/imaginaire/model_utils/gancraft/voxlib/voxlib_common.h b/imaginaire/model_utils/gancraft/voxlib/voxlib_common.h
new file mode 100644
index 0000000000000000000000000000000000000000..46b47fc80ecf802347607395ff04565732a4ee87
--- /dev/null
+++ b/imaginaire/model_utils/gancraft/voxlib/voxlib_common.h
@@ -0,0 +1,76 @@
+// Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// This work is made available under the Nvidia Source Code License-NC.
+// To view a copy of this license, check out LICENSE.md +#ifndef VOXLIB_COMMON_H +#define VOXLIB_COMMON_H + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) +#define CHECK_CPU(x) TORCH_CHECK(x.device().is_cpu(), #x " must be a CPU tensor") + +#include +#include +// CUDA vector math functions +__host__ __device__ __forceinline__ int floor_div(int a, int b) { + int c = a / b; + + if (c * b > a) { + c--; + } + + return c; +} + +template +__host__ __forceinline__ void cross(scalar_t* r, const scalar_t* a, const scalar_t* b) { + r[0] = a[1]*b[2] - a[2]*b[1]; + r[1] = a[2]*b[0] - a[0]*b[2]; + r[2] = a[0]*b[1] - a[1]*b[0]; +} + +__device__ __host__ __forceinline__ float dot(const float* a, const float* b) { + return a[0] * b[0] + a[1] * b[1] + a[2] * b[2]; +} + +template +__device__ __host__ __forceinline__ void copyarr(scalar_t* r, const scalar_t* a) { + #pragma unroll + for (int i=0; i +__device__ __host__ __forceinline__ void normalize(scalar_t* a) { + scalar_t vec_len=0.0f; + #pragma unroll + for (int i=0; i +__device__ __host__ __forceinline__ void normalize(scalar_t* r, const scalar_t* a) { + scalar_t vec_len=0.0f; + #pragma unroll + for (int i=0; i= data_type_num_classes] = data_type_num_classes + data[data_type] = label_map / 255.0 + return data + + +def _encode_onehot(label_map, num_classes, use_dont_care): + r"""Make input one-hot. + + Args: + label_map (torch.Tensor): (C, H, W) tensor containing indices. + num_classes (int): Number of labels to expand tensor to. + use_dont_care (bool): Use the dont care label or not? + Returns: + output (torch.Tensor): (num_classes, H, W) one-hot tensor. + """ + # All labels lie in [0. num_classes - 1]. + # Encode dont care as num_classes. + label_map[label_map < 0] = num_classes + label_map[label_map >= num_classes] = num_classes + + size = label_map.size() + output_size = (num_classes + 1, size[-2], size[-1]) + output = torch.zeros(*output_size) + if label_map.dim() == 4: + output = output.unsqueeze(0).repeat(label_map.size(0), 1, 1, 1) + output = output.scatter_(1, label_map.data.long(), 1.0) + if not use_dont_care: + output = output[:, :num_classes, ...] + else: + output = output.scatter_(0, label_map.data.long(), 1.0) + if not use_dont_care: + output = output[:num_classes, ...] 
+ return output diff --git a/imaginaire/model_utils/layers.py b/imaginaire/model_utils/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..c8f0ccaddf514de7303a41115c45c5482af7c04d --- /dev/null +++ b/imaginaire/model_utils/layers.py @@ -0,0 +1,271 @@ +import torch +import torch.nn as nn +import numpy as np + + +class SRTConvBlock(nn.Module): + def __init__(self, idim, hdim=None, odim=None): + super().__init__() + if hdim is None: + hdim = idim + + if odim is None: + odim = 2 * hdim + + conv_kwargs = {'bias': False, 'kernel_size': 3, 'padding': 1} + self.layers = nn.Sequential( + nn.Conv2d(idim, hdim, stride=1, **conv_kwargs), + nn.ReLU(), + nn.Conv2d(hdim, odim, stride=2, **conv_kwargs), + nn.ReLU()) + + def forward(self, x): + return self.layers(x) + +class ConditionalHashGrid(nn.Module): + def __init__(self, num_conv_blocks = 6): + super(ConditionalHashGrid, self).__init__() + self.sconv_head = nn.Conv2d(11, 8, kernel_size=3, stride=2, padding=1) + self.hconv_head = nn.Conv2d(1, 8, kernel_size=3, stride=2, padding=1) + conv_blocks = [] + cur_hdim = 16 + for i in range(1, num_conv_blocks): + conv_blocks.append(SRTConvBlock(idim=cur_hdim, odim=None)) + cur_hdim *= 2 + self.conv_blocks = nn.Sequential(*conv_blocks) + self.fc1 = nn.Linear(cur_hdim, 16) + self.fc2 = nn.Linear(16, 2) + self.act = nn.LeakyReLU(0.2) + + def forward(self, height_map, semantic_map): + h = self.act(self.hconv_head(height_map)) + s = self.act(self.sconv_head(semantic_map)) + joint = torch.cat([h, s], dim=1) + # interm = [] + # interm.append(joint.permute(0, 2, 3, 1).reshape(-1, 8)) + for layer in self.conv_blocks: + out = self.act(layer(joint)) + # interm.append(out.permute(0, 2, 3, 1).reshape(-1, 8)) + joint = out + + out = out.permute(0, 2, 3, 1) + out = torch.mean(out.reshape(out.shape[0], -1, out.shape[-1]), dim=1) + cond = self.act(self.fc1(out)) + cond = torch.tanh(self.fc2(cond)) + return cond + +class LightningMLP(nn.Module): + r""" MLP with affine modulation.""" + + def __init__(self, in_channels, style_dim, viewdir_dim, mask_dim=680, + out_channels_s=1, out_channels_c=3, hidden_channels=256, + use_seg=True): + super(LightningMLP, self).__init__() + + self.use_seg = use_seg + if self.use_seg: + self.fc_m_a = nn.Linear(mask_dim, hidden_channels, bias=False) + + self.fc_viewdir = None + if viewdir_dim > 0: + self.fc_viewdir = nn.Linear(viewdir_dim, hidden_channels, bias=False) + + self.fc_1 = nn.Linear(in_channels, hidden_channels) + + self.fc_2 = ModLinear(hidden_channels, hidden_channels, style_dim, bias=False, mod_bias=True, output_mode=True) + self.fc_3 = ModLinear(hidden_channels, hidden_channels, style_dim, bias=False, mod_bias=True, output_mode=True) + self.fc_4 = ModLinear(hidden_channels, hidden_channels, style_dim, bias=False, mod_bias=True, output_mode=True) + + self.fc_sigma = nn.Linear(hidden_channels, out_channels_s) + + if viewdir_dim > 0: + self.fc_5 = nn.Linear(hidden_channels, hidden_channels, bias=False) + self.mod_5 = AffineMod(hidden_channels, style_dim, mod_bias=True) + else: + self.fc_5 = ModLinear(hidden_channels, hidden_channels, style_dim, + bias=False, mod_bias=True, output_mode=True) + self.fc_6 = ModLinear(hidden_channels, hidden_channels, style_dim, bias=False, mod_bias=True, output_mode=True) + self.fc_out_c = nn.Linear(hidden_channels, out_channels_c) + + self.act = nn.LeakyReLU(negative_slope=0.2) + + def forward(self, x, raydir, z, m): + r""" Forward network + + Args: + x (N x H x W x M x in_channels tensor): Projected features. 
+ raydir (N x H x W x 1 x viewdir_dim tensor): Ray directions. + z (N x style_dim tensor): Style codes. + m (N x H x W x M x mask_dim tensor): One-hot segmentation maps. + """ + b, h, w, n, _ = x.size() + z = z[:, None, None, None, :] + # print('style z', z.shape) + # print('global enc:', global_enc.shape) + f = self.fc_1(x) + if self.use_seg: + f = f + self.fc_m_a(m) + # Common MLP + f = self.act(f) + f = self.act(self.fc_2(f, z)) + f = self.act(self.fc_3(f, z)) + f = self.act(self.fc_4(f, z)) + + # Sigma MLP + sigma = self.fc_sigma(f) + + # Color MLP + if self.fc_viewdir is not None: + f = self.fc_5(f) + f = f + self.fc_viewdir(raydir) + f = self.act(self.mod_5(f, z)) + else: + f = self.act(self.fc_5(f, z)) + f = self.act(self.fc_6(f, z)) + c = self.fc_out_c(f) + return sigma, c + +class AffineMod(nn.Module): + r"""Learning affine modulation of activation. + + Args: + in_features (int): Number of input features. + style_features (int): Number of style features. + mod_bias (bool): Whether to modulate bias. + """ + + def __init__(self, + in_features, + style_features, + mod_bias=True + ): + super().__init__() + self.weight_alpha = nn.Parameter(torch.randn([in_features, style_features]) / np.sqrt(style_features)) + self.bias_alpha = nn.Parameter(torch.full([in_features], 1, dtype=torch.float)) # init to 1 + self.weight_beta = None + self.bias_beta = None + self.mod_bias = mod_bias + if mod_bias: + self.weight_beta = nn.Parameter(torch.randn([in_features, style_features]) / np.sqrt(style_features)) + self.bias_beta = nn.Parameter(torch.full([in_features], 0, dtype=torch.float)) + + @staticmethod + def _linear_f(x, w, b): + w = w.to(x.dtype) + x_shape = x.shape + x = x.reshape(-1, x_shape[-1]) + if b is not None: + b = b.to(x.dtype) + x = torch.addmm(b.unsqueeze(0), x, w.t()) + else: + x = x.matmul(w.t()) + x = x.reshape(*x_shape[:-1], -1) + return x + + # x: B, ... , Cin + # z: B, 1, 1, , Cz + def forward(self, x, z): + x_shape = x.shape + z_shape = z.shape + x = x.reshape(x_shape[0], -1, x_shape[-1]) + z = z.reshape(z_shape[0], 1, z_shape[-1]) + + alpha = self._linear_f(z, self.weight_alpha, self.bias_alpha) # [B, ..., I] + x = x * alpha + + if self.mod_bias: + beta = self._linear_f(z, self.weight_beta, self.bias_beta) # [B, ..., I] + x = x + beta + + x = x.reshape(*x_shape[:-1], x.shape[-1]) + return x + + +class ModLinear(nn.Module): + r"""Linear layer with affine modulation (Based on StyleGAN2 mod demod). + Equivalent to affine modulation following linear, but faster when the same modulation parameters are shared across + multiple inputs. + Args: + in_features (int): Number of input features. + out_features (int): Number of output features. + style_features (int): Number of style features. + bias (bool): Apply additive bias before the activation function? + mod_bias (bool): Whether to modulate bias. + output_mode (bool): If True, modulate output instead of input. 
+ weight_gain (float): Initialization gain + """ + + def __init__(self, + in_features, + out_features, + style_features, + bias=True, + mod_bias=True, + output_mode=False, + weight_gain=1, + bias_init=0 + ): + super().__init__() + weight_gain = weight_gain / np.sqrt(in_features) + self.weight = nn.Parameter(torch.randn([out_features, in_features]) * weight_gain) + self.bias = nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None + self.weight_alpha = nn.Parameter(torch.randn([in_features, style_features]) / np.sqrt(style_features)) + self.bias_alpha = nn.Parameter(torch.full([in_features], 1, dtype=torch.float)) # init to 1 + self.weight_beta = None + self.bias_beta = None + self.mod_bias = mod_bias + self.output_mode = output_mode + if mod_bias: + if output_mode: + mod_bias_dims = out_features + else: + mod_bias_dims = in_features + self.weight_beta = nn.Parameter(torch.randn([mod_bias_dims, style_features]) / np.sqrt(style_features)) + self.bias_beta = nn.Parameter(torch.full([mod_bias_dims], 0, dtype=torch.float)) + + @staticmethod + def _linear_f(x, w, b): + w = w.to(x.dtype) + x_shape = x.shape + x = x.reshape(-1, x_shape[-1]) + if b is not None: + b = b.to(x.dtype) + x = torch.addmm(b.unsqueeze(0), x, w.t()) + else: + x = x.matmul(w.t()) + x = x.reshape(*x_shape[:-1], -1) + return x + + # x: B, ... , Cin + # z: B, 1, 1, , Cz + def forward(self, x, z): + x_shape = x.shape + z_shape = z.shape + x = x.reshape(x_shape[0], -1, x_shape[-1]) + z = z.reshape(z_shape[0], 1, z_shape[-1]) + + alpha = self._linear_f(z, self.weight_alpha, self.bias_alpha) # [B, ..., I] + w = self.weight.to(x.dtype) # [O I] + w = w.unsqueeze(0) * alpha # [1 O I] * [B 1 I] = [B O I] + + if self.mod_bias: + beta = self._linear_f(z, self.weight_beta, self.bias_beta) # [B, ..., I] + if not self.output_mode: + x = x + beta + + b = self.bias + if b is not None: + b = b.to(x.dtype)[None, None, :] + if self.mod_bias and self.output_mode: + if b is None: + b = beta + else: + b = b + beta + + # [B ? I] @ [B I O] = [B ? 
O] + if b is not None: + x = torch.baddbmm(b, x, w.transpose(1, 2)) + else: + x = x.bmm(w.transpose(1, 2)) + x = x.reshape(*x_shape[:-1], x.shape[-1]) + return x diff --git a/imaginaire/model_utils/pcg_gen.py b/imaginaire/model_utils/pcg_gen.py new file mode 100644 index 0000000000000000000000000000000000000000..0312c24700be3bcd93057887bb32810d148d16c3 --- /dev/null +++ b/imaginaire/model_utils/pcg_gen.py @@ -0,0 +1,214 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import random +import cv2 +import os + + +class PCGCache(nn.Module): + r"""PCG Datasets""" + def __init__(self, pcg_dataset_path): + super(PCGCache, self).__init__() + ''' + height_map: [size, size] array, in [-1, 1] range where < 0 indicates water + semantic_map: [size, size] array, in {0, 1, ..., 9} range, where 9 indicates water + ''' + self.sample_size = 1024 + self.sample_height = 256 + pcg_world_list = sorted(os.listdir(pcg_dataset_path)) + self.pcg_world_path = [] + for p in pcg_world_list: + self.pcg_world_path.append(os.path.join(pcg_dataset_path, p)) + self.n = len(self.pcg_world_path) + + def sample_world(self, device): + idx = random.randint(0, self.n - 1) + world_path = self.pcg_world_path[idx] + voxel_sparse = np.load(os.path.join(world_path, 'voxel_sparse.npy')) + current_height_map = np.load(os.path.join(world_path, 'height_map.npy')) + current_semantic_map = np.load(os.path.join(world_path, 'semantic_map.npy')) + heightmap = np.load(os.path.join(world_path, 'hmap_mc.npy')) + voxel_sparse = torch.from_numpy(voxel_sparse).to(device) + voxel_1 = voxel_sparse[0, :].to(torch.int64) + voxel_2 = voxel_sparse[1, :].to(torch.int64) + voxel_3 = voxel_sparse[2, :].to(torch.int64) + self.voxel_t = torch.zeros(self.sample_height, self.sample_size, self.sample_size, device=device, dtype=torch.int32) + self.voxel_t[voxel_1, voxel_2, voxel_3] = voxel_sparse[3, :].to(torch.int32) + self.current_height_map = torch.from_numpy(current_height_map).to(device) + self.current_semantic_map = torch.from_numpy(current_semantic_map).to(device) + self.heightmap = torch.from_numpy(heightmap) + self.trans_mat = torch.eye(4) + gnd_level = heightmap.min() + sky_level = heightmap.max() + 1 + self.voxel_t = self.voxel_t[gnd_level:sky_level, :, :] + self.trans_mat[0, 3] += gnd_level + + def world2local(self, v, is_vec=False): + mat_world2local = torch.inverse(self.trans_mat) + return trans_vec_homo(mat_world2local, v, is_vec) + + def _truncate_voxel(self): + gnd_level = self.heightmap.min() + sky_level = self.heightmap.max() + 1 + self.voxel_t = self.voxel_t[gnd_level:sky_level, :, :] + self.trans_mat[0, 3] += gnd_level + print('[GANcraft-utils] Voxel truncated. 
Gnd: {}; Sky: {}.'.format(gnd_level.item(), sky_level.item())) + + def is_sea(self, loc): + r"""loc: [2]: x, z.""" + x = int(loc[1]) + z = int(loc[2]) + if x < 0 or x > self.heightmap.size(0) or z < 0 or z > self.heightmap.size(1): + print('[McVoxel] is_sea(): Index out of bound.') + return True + y = self.heightmap[x, z] - self.trans_mat[0, 3] + y = int(y) + if self.voxel_t[y, x, z] == 26: + print('[McVoxel] is_sea(): Get a sea.') + print(self.voxel_t[y, x, z], self.voxel_t[y+1, x, z]) + return True + else: + return False + + +class PCGVoxelGenerator(nn.Module): + def __init__(self, sample_size = 2048): + super(PCGVoxelGenerator, self).__init__() + self.sample_height = 256 + self.sample_size = sample_size + self.voxel_t = None + + def next_world(self, device, world_dir, pcg_asset): + # Generate BEV representation + print('[PCGGenerator] Loading BEV scene representation...') + heightmap_path = os.path.join(world_dir, 'heightmap.npy') + semanticmap_path = os.path.join(world_dir, 'semanticmap.png') + treemap_path = os.path.join(world_dir, 'treemap.png') + height_map = np.load(heightmap_path) + semantic_map = cv2.imread(semanticmap_path, 0) + tree_map = cv2.imread(treemap_path, 0) + + print('[PCGGenerator] Creating scene windows...') + height_map[height_map < 0] = 0 + height_map = ((height_map - height_map.min()) / (1 - height_map.min()) * (self.sample_height - 1)).astype(np.int16) + + self.total_size = height_map.shape + + + org_semantic_map = torch.from_numpy(semantic_map.copy()) + org_semantic_map[tree_map != 255] = 10 + chunk_trees_map = tree_map + + biome_trees_dict = { + 'desert': [], + 'savanna': [5], + 'twoodland': [1, 7], + 'tundra': [], + 'seasonal forest': [1, 2], + 'rainforest': [1, 2, 3], + 'temp forest': [4], + 'temp rainforest': [0, 3], + 'boreal': [5,6,7], + 'water': [], + } + biome2mclabels = torch.tensor([28, 9, 8, 1, 9, 8, 9, 8, 30, 26], dtype=torch.int32) + biome_names = list(biome_trees_dict.keys()) + chunk_grid_x, chunk_grid_y = torch.meshgrid(torch.arange(self.total_size[0]), torch.arange(self.total_size[1])) + world_voxel_t = torch.zeros(self.sample_height, self.total_size[0], self.total_size[1]).to(torch.int32) + + chunk_height_map = torch.from_numpy(height_map.astype(int))[None, ...] 
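+        # Voxelize the BEV maps: the scatter along dim 0 (the height axis) writes each
+        # column's block label (via biome2mclabels) at its terrain height, and the padding
+        # loop below stamps the same label into the pad_num voxels above it so every column
+        # has solid thickness before trees are planted on top of the raised height map.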
+ chunk_semantic_map = torch.from_numpy(semantic_map) + chunk_semantic_map = biome2mclabels[chunk_semantic_map[None, ...].long().contiguous()] + world_voxel_t = world_voxel_t.scatter_(0, chunk_height_map, chunk_semantic_map) + pad_num = 16 + for preproc_step in range(pad_num): + world_voxel_t = world_voxel_t.scatter(0, torch.clip(chunk_height_map + preproc_step + 1, 0, self.sample_height - 1), chunk_semantic_map) + chunk_height_map = chunk_height_map + pad_num + chunk_height_map = chunk_height_map[0] + boundary_detect = 50 + + trees_models = pcg_asset['assets'] + + for biome_id in range(biome2mclabels.shape[0]): + tree_pos_mask = (chunk_trees_map == biome_id) + tree_pos_x = chunk_grid_x[tree_pos_mask] + tree_pos_y = chunk_grid_y[tree_pos_mask] + tree_pos_h = chunk_height_map[tree_pos_mask] + assert len(tree_pos_x) == len(tree_pos_y) + selected_trees = biome_trees_dict[biome_names[biome_id]] + if len(selected_trees) == 0: + continue + for idx in range(len(tree_pos_x)): + if tree_pos_x[idx] < boundary_detect or tree_pos_x[idx] > self.total_size[0] - boundary_detect or tree_pos_y[idx] < boundary_detect or tree_pos_y[idx] > self.total_size[1] - boundary_detect or tree_pos_h[idx] > self.sample_height - boundary_detect: + # hack, to avoid out of index near the boundary + continue + tree_id = random.choice(selected_trees) + tmp = world_voxel_t[tree_pos_h[idx]: tree_pos_h[idx] + trees_models[tree_id].shape[0], tree_pos_x[idx]: tree_pos_x[idx] + trees_models[tree_id].shape[1], tree_pos_y[idx]: tree_pos_y[idx] + trees_models[tree_id].shape[2]] + tmp_mask = (tmp == 0) + try: + world_voxel_t[tree_pos_h[idx]: tree_pos_h[idx] + trees_models[tree_id].shape[0], tree_pos_x[idx]: tree_pos_x[idx] + trees_models[tree_id].shape[1], tree_pos_y[idx]: tree_pos_y[idx] + trees_models[tree_id].shape[2]][tmp_mask] = trees_models[tree_id][tmp_mask] + except: + print('height?', tree_pos_h[idx]) + print(tmp_mask.shape) + print(tmp.shape) + print(trees_models[tree_id].shape) + print(world_voxel_t.shape) + print(tree_id) + raise NotImplementedError + self.trans_mat = torch.eye(4) # Transform voxel to world + # Generate heightmap for camera trajectory generation + m, h = torch.max((torch.flip(world_voxel_t, [0]) != 0).int(), dim=0, keepdim=False) + heightmap = world_voxel_t.shape[0] - 1 - h + heightmap[m == 0] = 0 # Special case when the whole vertical column is empty + gnd_level = heightmap.min() + sky_level = heightmap.max() + 1 + current_height_map = (chunk_height_map / (self.sample_height - 1))[None, None, ...] + current_semantic_map = F.one_hot(org_semantic_map.to(torch.int64)).to(torch.float).permute(2, 0, 1)[None, ...] 
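+        # BEV conditions kept alongside the voxel grid: heights rescaled to roughly [0, 1]
+        # with shape [1, 1, H, W], and semantics as a one-hot tensor of shape
+        # [1, num_classes, H, W] (label 10 marks tree locations taken from the tree map).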
+ + self.current_height_map = current_height_map.to(device) + self.current_semantic_map = current_semantic_map.to(device) + self.heightmap = heightmap + self.voxel_t = world_voxel_t[gnd_level:sky_level, :, :].to(device) + self.trans_mat[0, 3] += gnd_level + + def world2local(self, v, is_vec=False): + mat_world2local = torch.inverse(self.trans_mat) + return trans_vec_homo(mat_world2local, v, is_vec) + + def is_sea(self, loc): + r"""loc: [2]: x, z.""" + x = int(loc[1]) + z = int(loc[2]) + if x < 0 or x > self.heightmap.size(0) or z < 0 or z > self.heightmap.size(1): + print('[McVoxel] is_sea(): Index out of bound.') + return True + y = self.heightmap[x, z] - self.trans_mat[0, 3] + y = int(y) + if self.voxel_t[y, x, z] == 26: + print('[McVoxel] is_sea(): Get a sea.') + print(self.voxel_t[y, x, z], self.voxel_t[y+1, x, z]) + return True + else: + return False + +def trans_vec_homo(m, v, is_vec=False): + r"""3-dimensional Homogeneous matrix and regular vector multiplication + Convert v to homogeneous vector, perform M-V multiplication, and convert back + Note that this function does not support autograd. + + Args: + m (4 x 4 tensor): a homogeneous matrix + v (3 tensor): a 3-d vector + vec (bool): if true, v is direction. Otherwise v is point + """ + if is_vec: + v = torch.tensor([v[0], v[1], v[2], 0], dtype=v.dtype) + else: + v = torch.tensor([v[0], v[1], v[2], 1], dtype=v.dtype) + v = torch.mv(m, v) + if not is_vec: + v = v / v[3] + v = v[:3] + return v \ No newline at end of file diff --git a/imaginaire/optimizers/__init__.py b/imaginaire/optimizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..69bedc71589e178b96730174c9860b8cc8430e55 --- /dev/null +++ b/imaginaire/optimizers/__init__.py @@ -0,0 +1,8 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +from .fromage import Fromage +from .madam import Madam + +__all__ = ['Fromage', 'Madam'] diff --git a/imaginaire/optimizers/fromage.py b/imaginaire/optimizers/fromage.py new file mode 100644 index 0000000000000000000000000000000000000000..d00203de89f55fd122f71b7de8718ed7ef681ec8 --- /dev/null +++ b/imaginaire/optimizers/fromage.py @@ -0,0 +1,44 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +# import torch +import math + +from torch.optim.optimizer import Optimizer, required + + +class Fromage(Optimizer): + r"""Fromage optimizer implementation (https://arxiv.org/abs/2002.03432)""" + + def __init__(self, params, lr=required, momentum=0): + if lr is not required and lr < 0.0: + raise ValueError("Invalid learning rate: {}".format(lr)) + defaults = dict(lr=lr, momentum=momentum) + super(Fromage, self).__init__(params, defaults) + + def step(self, closure=None): + r"""Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + d_p = p.grad.data + d_p_norm = p.grad.norm() + p_norm = p.norm() + if p_norm > 0.0 and d_p_norm > 0.0: + p.data.add_(-group['lr'], d_p * (p_norm / d_p_norm)) + else: + p.data.add_(-group['lr'], d_p) + p.data /= math.sqrt(1 + group['lr'] ** 2) + + return loss diff --git a/imaginaire/optimizers/madam.py b/imaginaire/optimizers/madam.py new file mode 100644 index 0000000000000000000000000000000000000000..11bf71d049d9323e9ba646713413578ae5eb4503 --- /dev/null +++ b/imaginaire/optimizers/madam.py @@ -0,0 +1,54 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import torch +from torch.optim.optimizer import Optimizer, required + + +class Madam(Optimizer): + r"""MADAM optimizer implementation (https://arxiv.org/abs/2006.14560)""" + def __init__(self, params, lr=required, scale=3.0, + g_bound=None, momentum=0): + self.scale = scale + self.g_bound = g_bound + defaults = dict(lr=lr, momentum=momentum) + super(Madam, self).__init__(params, defaults) + + def step(self, closure=None): + r"""Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + + state = self.state[p] + if len(state) == 0: + state['max'] = self.scale * (p * p).mean().sqrt().item() + state['step'] = 0 + state['exp_avg_sq'] = torch.zeros_like(p) + + state['step'] += 1 + bias_correction = 1 - 0.999 ** state['step'] + state['exp_avg_sq'] = 0.999 * state[ + 'exp_avg_sq'] + 0.001 * p.grad.data ** 2 + g_normed = \ + p.grad.data / (state['exp_avg_sq'] / bias_correction).sqrt() + g_normed[torch.isnan(g_normed)] = 0 + if self.g_bound is not None: + g_normed.clamp_(-self.g_bound, self.g_bound) + + p.data *= torch.exp( + -group['lr'] * g_normed * torch.sign(p.data)) + p.data.clamp_(-state['max'], state['max']) + + return loss diff --git a/imaginaire/third_party/__init__.py b/imaginaire/third_party/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imaginaire/third_party/bias_act/__init__.py b/imaginaire/third_party/bias_act/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9dfe0aec6e5fd4a1538ed959abff7c5106784c9b --- /dev/null +++ b/imaginaire/third_party/bias_act/__init__.py @@ -0,0 +1,3 @@ +from .bias_act import FusedNonlinearity + +__all__ = ['FusedNonlinearity'] diff --git a/imaginaire/third_party/bias_act/bias_act.py b/imaginaire/third_party/bias_act/bias_act.py new file mode 100644 index 0000000000000000000000000000000000000000..29b01dc97884036aec1c42feb184c510c5ad0870 --- /dev/null +++ b/imaginaire/third_party/bias_act/bias_act.py @@ -0,0 +1,219 @@ +# flake8: noqa +import numpy as np +from types import SimpleNamespace + +import torch +from torch import nn + +import bias_act_cuda + +# ---------------------------------------------------------------------------- + +activation_funcs = { + 'linear': SimpleNamespace(func=lambda x, **_: x, def_alpha=0, def_gain=1, + cuda_idx=1, ref='', has_2nd_grad=False), + 'relu': SimpleNamespace(func=lambda x, **_: 
torch.nn.functional.relu(x), + def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, + ref='y', has_2nd_grad=False), + 'leakyrelu': SimpleNamespace( + func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), + def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', + has_2nd_grad=False), + 'tanh': SimpleNamespace(func=lambda x, **_: torch.tanh(x), def_alpha=0, + def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True), + 'sigmoid': SimpleNamespace(func=lambda x, **_: torch.sigmoid(x), + def_alpha=0, def_gain=1, cuda_idx=5, ref='y', + has_2nd_grad=True), + 'elu': SimpleNamespace(func=lambda x, **_: torch.nn.functional.elu(x), + def_alpha=0, def_gain=1, cuda_idx=6, ref='y', + has_2nd_grad=True), + 'selu': SimpleNamespace(func=lambda x, **_: torch.nn.functional.selu(x), + def_alpha=0, def_gain=1, cuda_idx=7, ref='y', + has_2nd_grad=True), + 'softplus': SimpleNamespace( + func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, + def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True), + 'swish': SimpleNamespace(func=lambda x, **_: torch.sigmoid(x) * x, + def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, + ref='x', has_2nd_grad=True), +} + +# ---------------------------------------------------------------------------- + +_null_tensor = torch.empty([0]) + + +def _bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, + impl='cuda'): + assert isinstance(x, torch.Tensor) + assert impl in ['ref', 'cuda'] + if impl == 'cuda' and x.device.type == 'cuda': + return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, + clamp=clamp).apply(x, b) + return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, + clamp=clamp) + + +# ---------------------------------------------------------------------------- + +def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None): + assert isinstance(x, torch.Tensor) + assert clamp is None or clamp >= 0 + spec = activation_funcs[act] + alpha = float(alpha if alpha is not None else spec.def_alpha) + gain = float(gain if gain is not None else spec.def_gain) + clamp = float(clamp if clamp is not None else -1) + + # Add bias. + if b is not None: + assert isinstance(b, torch.Tensor) and b.ndim == 1 + assert 0 <= dim < x.ndim + assert b.shape[0] == x.shape[dim] + x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)]) + + # Evaluate activation function. + alpha = float(alpha) + x = spec.func(x, alpha=alpha) + + # Scale by gain. + gain = float(gain) + if gain != 1: + x = x * gain + + # Clamp. + if clamp >= 0: + x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type + return x + + +# ---------------------------------------------------------------------------- + +_bias_act_cuda_cache = dict() + + +def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None): + """Fast CUDA implementation of `bias_act()` using custom ops. + """ + # Parse arguments. + assert clamp is None or clamp >= 0 + spec = activation_funcs[act] + alpha = float(alpha if alpha is not None else spec.def_alpha) + gain = float(gain if gain is not None else spec.def_gain) + clamp = float(clamp if clamp is not None else -1) + + # Lookup from cache. + key = (dim, act, alpha, gain, clamp) + if key in _bias_act_cuda_cache: + return _bias_act_cuda_cache[key] + + # Forward op. 
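+    # A dedicated pair of autograd Functions is generated per (dim, act, alpha, gain, clamp)
+    # configuration and memoized in _bias_act_cuda_cache, so the specialization cost is paid
+    # only once per configuration.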
+ class BiasActCuda(torch.autograd.Function): + @staticmethod + def forward(ctx, x, b): # pylint: disable=arguments-differ + if x.ndim > 2 and x.stride()[1] == 1: + ctx.memory_format = torch.channels_last + else: + ctx.memory_format = torch.contiguous_format + x = x.contiguous(memory_format=ctx.memory_format) + b = b.contiguous() if b is not None else _null_tensor + y = x + if act != 'linear' or gain != 1 or clamp >= 0 or b is not \ + _null_tensor: + y = bias_act_cuda.bias_act_cuda(x, b, _null_tensor, _null_tensor, + _null_tensor, 0, dim, spec.cuda_idx, alpha, + gain, clamp) + ctx.save_for_backward( + x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, + b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, + y if 'y' in spec.ref else _null_tensor) + return y + + @staticmethod + def backward(ctx, dy): # pylint: disable=arguments-differ + dy = dy.contiguous(memory_format=ctx.memory_format) + x, b, y = ctx.saved_tensors + dx = None + db = None + + if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: + dx = dy + if act != 'linear' or gain != 1 or clamp >= 0: + dx = BiasActCudaGrad.apply(dy, x, b, y) + + if ctx.needs_input_grad[1]: + db = dx.sum([i for i in range(dx.ndim) if i != dim]) + + return dx, db + + # Backward op. + class BiasActCudaGrad(torch.autograd.Function): + @staticmethod + def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ + if x.ndim > 2 and x.stride()[1] == 1: + ctx.memory_format = torch.channels_last + else: + ctx.memory_format = torch.contiguous_format + dx = bias_act_cuda.bias_act_cuda(dy, b, x, y, _null_tensor, 1, dim, + spec.cuda_idx, alpha, gain, clamp) + ctx.save_for_backward( + dy if spec.has_2nd_grad else _null_tensor, + x, b, y) + return dx + + @staticmethod + def backward(ctx, d_dx): # pylint: disable=arguments-differ + d_dx = d_dx.contiguous(memory_format=ctx.memory_format) + dy, x, b, y = ctx.saved_tensors + d_dy = None + d_x = None + d_b = None + d_y = None + + if ctx.needs_input_grad[0]: + d_dy = BiasActCudaGrad.apply(d_dx, x, b, y) + + if spec.has_2nd_grad and ( + ctx.needs_input_grad[1] or ctx.needs_input_grad[2]): + d_x = bias_act_cuda.bias_act_cuda(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, + alpha, gain, clamp) + + if spec.has_2nd_grad and ctx.needs_input_grad[2]: + d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim]) + + return d_dy, d_x, d_b, d_y + + # Add to cache. 
+ _bias_act_cuda_cache[key] = BiasActCuda + return BiasActCuda + + +class FusedNonlinearity(nn.Module): + def __init__(self, nonlinearity, num_channels=None, lr_mul=1.0, alpha=None, impl='cuda', gain=None): + super().__init__() + if num_channels is not None: + self.bias = nn.Parameter(torch.zeros(num_channels)) + else: + self.register_parameter('bias', None) + self.nonlinearity = nonlinearity + self.gain = gain + self.alpha = alpha + self.lr_mul = lr_mul + self.impl = impl + + def forward(self, x): + bias = self.bias.type_as(x) * self.lr_mul if self.bias is not None else None + return _bias_act( + x, b=bias, dim=1, act=self.nonlinearity, + alpha=self.alpha, gain=self.gain, clamp=None, impl=self.impl + ) + + def __repr__(self): + mod_str = f'{self.__class__.__name__}(type={self.nonlinearity}' + if self.gain is not None: + mod_str += f', gain={self.gain}' + if self.alpha is not None: + mod_str += f', alpha={self.alpha}' + if self.lr_mul != 1: + mod_str += f', lr_mul={self.lr_mul}' + mod_str += ')' + return mod_str diff --git a/imaginaire/third_party/bias_act/setup.py b/imaginaire/third_party/bias_act/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..175331446b628a1c912f42d6b01d51c8fd5fe5b7 --- /dev/null +++ b/imaginaire/third_party/bias_act/setup.py @@ -0,0 +1,43 @@ +# flake8: noqa +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension +import os + + +cuda_version = os.getenv('CUDA_VERSION') +print('CUDA_VERSION: {}'.format(cuda_version)) + +nvcc_args = list() +# nvcc_args.append('-gencode') +# nvcc_args.append('arch=compute_50,code=sm_50') +# nvcc_args.append('-gencode') +# nvcc_args.append('arch=compute_52,code=sm_52') +# nvcc_args.append('-gencode') +# nvcc_args.append('arch=compute_60,code=sm_60') +# nvcc_args.append('-gencode') +# nvcc_args.append('arch=compute_61,code=sm_61') +nvcc_args.append('-gencode') +nvcc_args.append('arch=compute_70,code=sm_70') +nvcc_args.append('-gencode') +nvcc_args.append('arch=compute_75,code=sm_75') +if cuda_version is not None: + if cuda_version >= '11.0': + nvcc_args.append('-gencode') + nvcc_args.append('arch=compute_80,code=sm_80') +nvcc_args.append('-Xcompiler') +nvcc_args.append('-Wall') +nvcc_args.append('-std=c++17') + +setup( + name='bias_act_cuda', + py_modules=['bias_act'], + ext_modules=[ + CUDAExtension('bias_act_cuda', [ + './src/bias_act_cuda.cc', + './src/bias_act_cuda_kernel.cu' + ], extra_compile_args={'cxx': ['-Wall', '-std=c++17'], + 'nvcc': nvcc_args}) + ], + cmdclass={ + 'build_ext': BuildExtension + }) diff --git a/imaginaire/third_party/bias_act/src/bias_act_cuda.cc b/imaginaire/third_party/bias_act/src/bias_act_cuda.cc new file mode 100644 index 0000000000000000000000000000000000000000..cf975dbe6784e89cfa056574da8780d1e5f5b97d --- /dev/null +++ b/imaginaire/third_party/bias_act/src/bias_act_cuda.cc @@ -0,0 +1,103 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. 
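+//
+// Host-side entry point: validates tensor shapes/layouts, packs bias_act_kernel_params and
+// launches the activation kernel returned by choose_bias_act_kernel().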
+ +#include +#include +#include +#include +#include +#include + +#include "bias_act_cuda.h" + +//------------------------------------------------------------------------ + +static bool has_same_layout(torch::Tensor x, torch::Tensor y) +{ + if (x.dim() != y.dim()) + return false; + for (int64_t i = 0; i < x.dim(); i++) + { + if (x.size(i) != y.size(i)) + return false; + if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) + return false; + } + return true; +} + +//------------------------------------------------------------------------ + +static torch::Tensor bias_act_cuda(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) +{ + // Validate arguments. + TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); + TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); + TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); + TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); + TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x"); + TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); + TORCH_CHECK(b.dim() == 1, "b must have rank 1"); + TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); + TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); + TORCH_CHECK(grad >= 0, "grad must be non-negative"); + + // Validate layout. + TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); + TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); + TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); + TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); + TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); + + // Create output tensor. + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + torch::Tensor y = torch::empty_like(x); + TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); + + // Initialize CUDA kernel parameters. + bias_act_kernel_params p; + p.x = x.data_ptr(); + p.b = (b.numel()) ? b.data_ptr() : NULL; + p.xref = (xref.numel()) ? xref.data_ptr() : NULL; + p.yref = (yref.numel()) ? yref.data_ptr() : NULL; + p.dy = (dy.numel()) ? dy.data_ptr() : NULL; + p.y = y.data_ptr(); + p.grad = grad; + p.act = act; + p.alpha = alpha; + p.gain = gain; + p.clamp = clamp; + p.sizeX = (int)x.numel(); + p.sizeB = (int)b.numel(); + p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; + + // Choose CUDA kernel. + void* kernel; + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda_kernel", [&] + { + kernel = choose_bias_act_kernel(p); + }); + TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); + + // Launch CUDA kernel. 
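+    // Each thread strides through p.loopX elements inside the kernel, so one block covers
+    // p.loopX * blockSize elements and the 1-D grid is sized to ceil(sizeX / (loopX * blockSize)).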
+ p.loopX = 4; + int blockSize = 4 * 32; + int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; + void* args[] = {&p}; + AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); + return y; +} + +//------------------------------------------------------------------------ + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("bias_act_cuda", &bias_act_cuda); +} + +//------------------------------------------------------------------------ diff --git a/imaginaire/third_party/bias_act/src/bias_act_cuda.h b/imaginaire/third_party/bias_act/src/bias_act_cuda.h new file mode 100644 index 0000000000000000000000000000000000000000..a32187e1fb7e3bae509d4eceaf900866866875a4 --- /dev/null +++ b/imaginaire/third_party/bias_act/src/bias_act_cuda.h @@ -0,0 +1,38 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +//------------------------------------------------------------------------ +// CUDA kernel parameters. + +struct bias_act_kernel_params +{ + const void* x; // [sizeX] + const void* b; // [sizeB] or NULL + const void* xref; // [sizeX] or NULL + const void* yref; // [sizeX] or NULL + const void* dy; // [sizeX] or NULL + void* y; // [sizeX] + + int grad; + int act; + float alpha; + float gain; + float clamp; + + int sizeX; + int sizeB; + int stepB; + int loopX; +}; + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template void* choose_bias_act_kernel(const bias_act_kernel_params& p); + +//------------------------------------------------------------------------ diff --git a/imaginaire/third_party/bias_act/src/bias_act_cuda_kernel.cu b/imaginaire/third_party/bias_act/src/bias_act_cuda_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..9adbb942b5ce5740a5527449995e1887cda12816 --- /dev/null +++ b/imaginaire/third_party/bias_act/src/bias_act_cuda_kernel.cu @@ -0,0 +1,173 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include +#include "bias_act_cuda.h" + +//------------------------------------------------------------------------ +// Helpers. + +template struct InternalType; +template <> struct InternalType { typedef double scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; + +//------------------------------------------------------------------------ +// CUDA kernel. 
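+// A single templated kernel covers the forward pass and both gradient passes: T is the
+// tensor dtype (mapped to an internal compute type above), A selects the activation at
+// compile time (spec.cuda_idx on the Python side), and p.grad picks forward (0), first
+// derivative (1) or second derivative (2) at run time. For the gradient passes, p.x holds
+// the incoming gradient while p.xref / p.yref hold the saved activation input / output.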
+ +template +__global__ void bias_act_kernel(bias_act_kernel_params p) +{ + typedef typename InternalType::scalar_t scalar_t; + int G = p.grad; + scalar_t alpha = (scalar_t)p.alpha; + scalar_t gain = (scalar_t)p.gain; + scalar_t clamp = (scalar_t)p.clamp; + scalar_t one = (scalar_t)1; + scalar_t two = (scalar_t)2; + scalar_t expRange = (scalar_t)80; + scalar_t halfExpRange = (scalar_t)40; + scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946; + scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717; + + // Loop over elements. + int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; + for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) + { + // Load. + scalar_t x = (scalar_t)((const T*)p.x)[xi]; + scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0; + scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0; + scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0; + scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one; + scalar_t yy = (gain != 0) ? yref / gain : 0; + scalar_t y = 0; + + // Apply bias. + ((G == 0) ? x : xref) += b; + + // linear + if (A == 1) + { + if (G == 0) y = x; + if (G == 1) y = x; + } + + // relu + if (A == 2) + { + if (G == 0) y = (x > 0) ? x : 0; + if (G == 1) y = (yy > 0) ? x : 0; + } + + // lrelu + if (A == 3) + { + if (G == 0) y = (x > 0) ? x : x * alpha; + if (G == 1) y = (yy > 0) ? x : x * alpha; + } + + // tanh + if (A == 4) + { + if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); } + if (G == 1) y = x * (one - yy * yy); + if (G == 2) y = x * (one - yy * yy) * (-two * yy); + } + + // sigmoid + if (A == 5) + { + if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one); + if (G == 1) y = x * yy * (one - yy); + if (G == 2) y = x * yy * (one - yy) * (one - two * yy); + } + + // elu + if (A == 6) + { + if (G == 0) y = (x >= 0) ? x : exp(x) - one; + if (G == 1) y = (yy >= 0) ? x : x * (yy + one); + if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one); + } + + // selu + if (A == 7) + { + if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one); + if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha); + if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha); + } + + // softplus + if (A == 8) + { + if (G == 0) y = (x > expRange) ? x : log(exp(x) + one); + if (G == 1) y = x * (one - exp(-yy)); + if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); } + } + + // swish + if (A == 9) + { + if (G == 0) + y = (x < -expRange) ? 0 : x / (exp(-x) + one); + else + { + scalar_t c = exp(xref); + scalar_t d = c + one; + if (G == 1) + y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d); + else + y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d); + yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain; + } + } + + // Apply gain. + y *= gain * dy; + + // Clamp. + if (clamp >= 0) + { + if (G == 0) + y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp; + else + y = (yref > -clamp & yref < clamp) ? y : 0; + } + + // Store. + ((T*)p.y)[xi] = (T)y; + } +} + +//------------------------------------------------------------------------ +// CUDA kernel selection. 
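+// Maps the runtime activation index (p.act) to the statically specialized kernel
+// instantiation for the requested dtype.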
+ +template void* choose_bias_act_kernel(const bias_act_kernel_params& p) +{ + if (p.act == 1) return (void*)bias_act_kernel; + if (p.act == 2) return (void*)bias_act_kernel; + if (p.act == 3) return (void*)bias_act_kernel; + if (p.act == 4) return (void*)bias_act_kernel; + if (p.act == 5) return (void*)bias_act_kernel; + if (p.act == 6) return (void*)bias_act_kernel; + if (p.act == 7) return (void*)bias_act_kernel; + if (p.act == 8) return (void*)bias_act_kernel; + if (p.act == 9) return (void*)bias_act_kernel; + return NULL; +} + +//------------------------------------------------------------------------ +// Template specializations. + +template void* choose_bias_act_kernel (const bias_act_kernel_params& p); +template void* choose_bias_act_kernel (const bias_act_kernel_params& p); +template void* choose_bias_act_kernel (const bias_act_kernel_params& p); + +//------------------------------------------------------------------------ diff --git a/imaginaire/third_party/channelnorm/channelnorm.py b/imaginaire/third_party/channelnorm/channelnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..fdd46711ca0bf2b6bb650112fa364100f6d4c927 --- /dev/null +++ b/imaginaire/third_party/channelnorm/channelnorm.py @@ -0,0 +1,39 @@ +# flake8: noqa +from torch.autograd import Function, Variable +from torch.nn.modules.module import Module +import channelnorm_cuda + + +class ChannelNormFunction(Function): + @staticmethod + def forward(ctx, input1, norm_deg=2): + assert input1.is_contiguous() + b, _, h, w = input1.size() + output = input1.new(b, 1, h, w).zero_() + + channelnorm_cuda.forward(input1, output, norm_deg) + ctx.save_for_backward(input1, output) + ctx.norm_deg = norm_deg + + return output + + @staticmethod + def backward(ctx, grad_output): + input1, output = ctx.saved_tensors + + grad_input1 = Variable(input1.new(input1.size()).zero_()) + + channelnorm_cuda.backward(input1, output, grad_output.data, + grad_input1.data, ctx.norm_deg) + + return grad_input1, None + + +class ChannelNorm(Module): + + def __init__(self, norm_deg=2): + super(ChannelNorm, self).__init__() + self.norm_deg = norm_deg + + def forward(self, input1): + return ChannelNormFunction.apply(input1, self.norm_deg) diff --git a/imaginaire/third_party/channelnorm/setup.py b/imaginaire/third_party/channelnorm/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..bae552bdba5ede3e706893bba93e3b21c6b7d3ea --- /dev/null +++ b/imaginaire/third_party/channelnorm/setup.py @@ -0,0 +1,43 @@ +# flake8: noqa +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension +import os + + +cuda_version = os.getenv('CUDA_VERSION') +print('CUDA_VERSION: {}'.format(cuda_version)) + +nvcc_args = list() +# nvcc_args.append('-gencode') +# nvcc_args.append('arch=compute_50,code=sm_50') +# nvcc_args.append('-gencode') +# nvcc_args.append('arch=compute_52,code=sm_52') +# nvcc_args.append('-gencode') +# nvcc_args.append('arch=compute_60,code=sm_60') +# nvcc_args.append('-gencode') +# nvcc_args.append('arch=compute_61,code=sm_61') +nvcc_args.append('-gencode') +nvcc_args.append('arch=compute_70,code=sm_70') +nvcc_args.append('-gencode') +nvcc_args.append('arch=compute_75,code=sm_75') +if cuda_version is not None: + if cuda_version >= '11.0': + nvcc_args.append('-gencode') + nvcc_args.append('arch=compute_80,code=sm_80') +nvcc_args.append('-Xcompiler') +nvcc_args.append('-Wall') +nvcc_args.append('-std=c++17') + +setup( + name='channelnorm_cuda', + py_modules=['channelnorm'], + 
ext_modules=[ + CUDAExtension('channelnorm_cuda', [ + './src/channelnorm_cuda.cc', + './src/channelnorm_kernel.cu' + ], extra_compile_args={'cxx': ['-Wall', '-std=c++17'], + 'nvcc': nvcc_args}) + ], + cmdclass={ + 'build_ext': BuildExtension + }) diff --git a/imaginaire/third_party/channelnorm/src/channelnorm_cuda.cc b/imaginaire/third_party/channelnorm/src/channelnorm_cuda.cc new file mode 100644 index 0000000000000000000000000000000000000000..69d82eb184e97b2eefa9810ad156d1104cf84745 --- /dev/null +++ b/imaginaire/third_party/channelnorm/src/channelnorm_cuda.cc @@ -0,0 +1,31 @@ +#include +#include + +#include "channelnorm_kernel.cuh" + +int channelnorm_cuda_forward( + at::Tensor& input1, + at::Tensor& output, + int norm_deg) { + + channelnorm_kernel_forward(input1, output, norm_deg); + return 1; +} + + +int channelnorm_cuda_backward( + at::Tensor& input1, + at::Tensor& output, + at::Tensor& gradOutput, + at::Tensor& gradInput1, + int norm_deg) { + + channelnorm_kernel_backward(input1, output, gradOutput, gradInput1, norm_deg); + return 1; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &channelnorm_cuda_forward, "Channel norm forward (CUDA)"); + m.def("backward", &channelnorm_cuda_backward, "Channel norm backward (CUDA)"); +} + diff --git a/imaginaire/third_party/channelnorm/src/channelnorm_kernel.cu b/imaginaire/third_party/channelnorm/src/channelnorm_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..99ace6855a61373443a6ddff7c7858eb474e9e48 --- /dev/null +++ b/imaginaire/third_party/channelnorm/src/channelnorm_kernel.cu @@ -0,0 +1,177 @@ +#include +#include +#include + +#include "channelnorm_kernel.cuh" + +#define CUDA_NUM_THREADS 512 + +#define DIM0(TENSOR) ((TENSOR).x) +#define DIM1(TENSOR) ((TENSOR).y) +#define DIM2(TENSOR) ((TENSOR).z) +#define DIM3(TENSOR) ((TENSOR).w) + +#define DIM3_INDEX(TENSOR, xx, yy, zz, ww) ((TENSOR)[((xx) * (TENSOR##_stride.x)) + ((yy) * (TENSOR##_stride.y)) + ((zz) * (TENSOR##_stride.z)) + ((ww) * (TENSOR##_stride.w))]) + +using at::Half; + +template +__global__ void kernel_channelnorm_update_output( + const int n, + const scalar_t* __restrict__ input1, + const long4 input1_size, + const long4 input1_stride, + scalar_t* __restrict__ output, + const long4 output_size, + const long4 output_stride, + int norm_deg) { + + int index = blockIdx.x * blockDim.x + threadIdx.x; + + if (index >= n) { + return; + } + + int dim_b = DIM0(output_size); + int dim_c = DIM1(output_size); + int dim_h = DIM2(output_size); + int dim_w = DIM3(output_size); + int dim_chw = dim_c * dim_h * dim_w; + + int b = ( index / dim_chw ) % dim_b; + int y = ( index / dim_w ) % dim_h; + int x = ( index ) % dim_w; + + int i1dim_c = DIM1(input1_size); + int i1dim_h = DIM2(input1_size); + int i1dim_w = DIM3(input1_size); + int i1dim_chw = i1dim_c * i1dim_h * i1dim_w; + int i1dim_hw = i1dim_h * i1dim_w; + + float result = 0.0; + + for (int c = 0; c < i1dim_c; ++c) { + int i1Index = b * i1dim_chw + c * i1dim_hw + y * i1dim_w + x; + scalar_t val = input1[i1Index]; + result += static_cast(val * val); + } + result = sqrt(result); + output[index] = static_cast(result); +} + + +template +__global__ void kernel_channelnorm_backward_input1( + const int n, + const scalar_t* __restrict__ input1, const long4 input1_size, const long4 input1_stride, + const scalar_t* __restrict__ output, const long4 output_size, const long4 output_stride, + const scalar_t* __restrict__ gradOutput, const long4 gradOutput_size, const long4 gradOutput_stride, + scalar_t* 
__restrict__ gradInput, const long4 gradInput_size, const long4 gradInput_stride, + int norm_deg) { + + int index = blockIdx.x * blockDim.x + threadIdx.x; + + if (index >= n) { + return; + } + + float val = 0.0; + + int dim_b = DIM0(gradInput_size); + int dim_c = DIM1(gradInput_size); + int dim_h = DIM2(gradInput_size); + int dim_w = DIM3(gradInput_size); + int dim_chw = dim_c * dim_h * dim_w; + int dim_hw = dim_h * dim_w; + + int b = ( index / dim_chw ) % dim_b; + int y = ( index / dim_w ) % dim_h; + int x = ( index ) % dim_w; + + + int outIndex = b * dim_hw + y * dim_w + x; + val = static_cast(gradOutput[outIndex]) * static_cast(input1[index]) / (static_cast(output[outIndex])+1e-9); + gradInput[index] = static_cast(val); + +} + +void channelnorm_kernel_forward( + at::Tensor& input1, + at::Tensor& output, + int norm_deg) { + + const long4 input1_size = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3)); + const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3)); + + const long4 output_size = make_long4(output.size(0), output.size(1), output.size(2), output.size(3)); + const long4 output_stride = make_long4(output.stride(0), output.stride(1), output.stride(2), output.stride(3)); + + int n = output.numel(); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channelnorm_forward", ([&] { + + kernel_channelnorm_update_output<<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream() >>>( +//at::globalContext().getCurrentCUDAStream() >>>( + n, + input1.data(), + input1_size, + input1_stride, + output.data(), + output_size, + output_stride, + norm_deg); + + })); + + // TODO: ATen-equivalent check + + // THCudaCheck(cudaGetLastError()); +} + +void channelnorm_kernel_backward( + at::Tensor& input1, + at::Tensor& output, + at::Tensor& gradOutput, + at::Tensor& gradInput1, + int norm_deg) { + + const long4 input1_size = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3)); + const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3)); + + const long4 output_size = make_long4(output.size(0), output.size(1), output.size(2), output.size(3)); + const long4 output_stride = make_long4(output.stride(0), output.stride(1), output.stride(2), output.stride(3)); + + const long4 gradOutput_size = make_long4(gradOutput.size(0), gradOutput.size(1), gradOutput.size(2), gradOutput.size(3)); + const long4 gradOutput_stride = make_long4(gradOutput.stride(0), gradOutput.stride(1), gradOutput.stride(2), gradOutput.stride(3)); + + const long4 gradInput1_size = make_long4(gradInput1.size(0), gradInput1.size(1), gradInput1.size(2), gradInput1.size(3)); + const long4 gradInput1_stride = make_long4(gradInput1.stride(0), gradInput1.stride(1), gradInput1.stride(2), gradInput1.stride(3)); + + int n = gradInput1.numel(); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channelnorm_backward_input1", ([&] { + + kernel_channelnorm_backward_input1<<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream() >>>( +//at::globalContext().getCurrentCUDAStream() >>>( + n, + input1.data(), + input1_size, + input1_stride, + output.data(), + output_size, + output_stride, + gradOutput.data(), + gradOutput_size, + gradOutput_stride, + gradInput1.data(), + gradInput1_size, + gradInput1_stride, + norm_deg + ); + + })); + + // TODO: Add ATen-equivalent check + +// THCudaCheck(cudaGetLastError()); +} 
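For orientation, here is a minimal sketch of how the channelnorm extension above is typically exercised once built; the build step and import path are assumptions about the usual PyTorch-extension workflow, while the module name, contiguity requirement, and shapes follow channelnorm.py and the kernels above.

# Minimal usage sketch (assumes `python setup.py install` has been run in
# imaginaire/third_party/channelnorm and that channelnorm.py is on PYTHONPATH).
import torch
from channelnorm import ChannelNorm

norm = ChannelNorm(norm_deg=2)
x = torch.randn(2, 3, 64, 64, device='cuda').contiguous()  # forward asserts contiguity
y = norm(x)
# The kernel accumulates the sum of squares over channels and takes the square
# root, so y has shape (2, 1, 64, 64) and holds the per-pixel L2 norm.
print(y.shape)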
diff --git a/imaginaire/third_party/channelnorm/src/channelnorm_kernel.cuh b/imaginaire/third_party/channelnorm/src/channelnorm_kernel.cuh new file mode 100644 index 0000000000000000000000000000000000000000..3e6223f7fe60feb4bf9e4f66c3d849b84c89dcda --- /dev/null +++ b/imaginaire/third_party/channelnorm/src/channelnorm_kernel.cuh @@ -0,0 +1,16 @@ +#pragma once + +#include + +void channelnorm_kernel_forward( + at::Tensor& input1, + at::Tensor& output, + int norm_deg); + + +void channelnorm_kernel_backward( + at::Tensor& input1, + at::Tensor& output, + at::Tensor& gradOutput, + at::Tensor& gradInput1, + int norm_deg); diff --git a/imaginaire/third_party/correlation/correlation.py b/imaginaire/third_party/correlation/correlation.py new file mode 100644 index 0000000000000000000000000000000000000000..e47739dff7475c0f29bff32bc2dc9f097161d144 --- /dev/null +++ b/imaginaire/third_party/correlation/correlation.py @@ -0,0 +1,105 @@ +# flake8: noqa +import torch +from torch.nn.modules.module import Module +from torch.autograd import Function +import correlation_cuda + + +class CorrelationFunction(Function): + + @staticmethod + def forward(ctx, + pad_size, + kernel_size, + max_displacement, + stride1, + stride2, + corr_multiply, + input1, + input2): + ctx.save_for_backward(input1, input2) + ctx.pad_size = pad_size + ctx.kernel_size = kernel_size + ctx.max_displacement = max_displacement + ctx.stride1 = stride1 + ctx.stride2 = stride2 + ctx.corr_multiply = corr_multiply + + with torch.cuda.device_of(input1): + rbot1 = input1.new() + rbot2 = input2.new() + output = input1.new() + + correlation_cuda.forward( + input1, + input2, + rbot1, + rbot2, + output, + ctx.pad_size, + ctx.kernel_size, + ctx.max_displacement, + ctx.stride1, + ctx.stride2, + ctx.corr_multiply) + + return output + + @staticmethod + def backward(ctx, grad_output): + input1, input2 = ctx.saved_tensors + + with torch.cuda.device_of(input1): + rbot1 = input1.new() + rbot2 = input2.new() + + grad_input1 = input1.new() + grad_input2 = input2.new() + + correlation_cuda.backward( + input1, + input2, + rbot1, + rbot2, + grad_output, + grad_input1, + grad_input2, + ctx.pad_size, + ctx.kernel_size, + ctx.max_displacement, + ctx.stride1, + ctx.stride2, + ctx.corr_multiply) + + return grad_input1, grad_input2 + +class Correlation(Module): + def __init__( + self, + pad_size=0, + kernel_size=0, + max_displacement=0, + stride1=1, + stride2=2, + corr_multiply=1): + super(Correlation, self).__init__() + self.pad_size = pad_size + self.kernel_size = kernel_size + self.max_displacement = max_displacement + self.stride1 = stride1 + self.stride2 = stride2 + self.corr_multiply = corr_multiply + + def forward(self, input1, input2): + + result = CorrelationFunction.apply( + self.pad_size, + self.kernel_size, + self.max_displacement, + self.stride1, + self.stride2, + self.corr_multiply, + input1, + input2) + + return result diff --git a/imaginaire/third_party/correlation/setup.py b/imaginaire/third_party/correlation/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..770dff31f054e05b6542642d81734c908b198a91 --- /dev/null +++ b/imaginaire/third_party/correlation/setup.py @@ -0,0 +1,43 @@ +# flake8: noqa +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension +import os + + +cuda_version = os.getenv('CUDA_VERSION') +print('CUDA_VERSION: {}'.format(cuda_version)) + +nvcc_args = list() +# nvcc_args.append('-gencode') +# nvcc_args.append('arch=compute_50,code=sm_50') +# nvcc_args.append('-gencode') 
+# nvcc_args.append('arch=compute_52,code=sm_52') +# nvcc_args.append('-gencode') +# nvcc_args.append('arch=compute_60,code=sm_60') +# nvcc_args.append('-gencode') +# nvcc_args.append('arch=compute_61,code=sm_61') +nvcc_args.append('-gencode') +nvcc_args.append('arch=compute_70,code=sm_70') +nvcc_args.append('-gencode') +nvcc_args.append('arch=compute_75,code=sm_75') +if cuda_version is not None: + if cuda_version >= '11.0': + nvcc_args.append('-gencode') + nvcc_args.append('arch=compute_80,code=sm_80') +nvcc_args.append('-Xcompiler') +nvcc_args.append('-Wall') +nvcc_args.append('-std=c++17') + +setup( + name='correlation_cuda', + py_modules=['correlation'], + ext_modules=[ + CUDAExtension('correlation_cuda', [ + './src/correlation_cuda.cc', + './src/correlation_cuda_kernel.cu' + ], extra_compile_args={'cxx': ['-Wall', '-std=c++17'], + 'nvcc': nvcc_args}) + ], + cmdclass={ + 'build_ext': BuildExtension + }) diff --git a/imaginaire/third_party/correlation/src/correlation_cuda.cc b/imaginaire/third_party/correlation/src/correlation_cuda.cc new file mode 100644 index 0000000000000000000000000000000000000000..feccd65295fa90a22564b08fc80464a76361a1aa --- /dev/null +++ b/imaginaire/third_party/correlation/src/correlation_cuda.cc @@ -0,0 +1,173 @@ +#include +#include +#include +#include +#include +#include + +#include "correlation_cuda_kernel.cuh" + +int correlation_forward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& output, + int pad_size, + int kernel_size, + int max_displacement, + int stride1, + int stride2, + int corr_type_multiply) +{ + + int batchSize = input1.size(0); + + int nInputChannels = input1.size(1); + int inputHeight = input1.size(2); + int inputWidth = input1.size(3); + + int kernel_radius = (kernel_size - 1) / 2; + int border_radius = kernel_radius + max_displacement; + + int paddedInputHeight = inputHeight + 2 * pad_size; + int paddedInputWidth = inputWidth + 2 * pad_size; + + int nOutputChannels = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1); + + int outputHeight = ceil(static_cast(paddedInputHeight - 2 * border_radius) / static_cast(stride1)); + int outputwidth = ceil(static_cast(paddedInputWidth - 2 * border_radius) / static_cast(stride1)); + + rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); + rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); + output.resize_({batchSize, nOutputChannels, outputHeight, outputwidth}); + + rInput1.fill_(0); + rInput2.fill_(0); + output.fill_(0); + + int success = correlation_forward_cuda_kernel( + output, + output.size(0), + output.size(1), + output.size(2), + output.size(3), + output.stride(0), + output.stride(1), + output.stride(2), + output.stride(3), + input1, + input1.size(1), + input1.size(2), + input1.size(3), + input1.stride(0), + input1.stride(1), + input1.stride(2), + input1.stride(3), + input2, + input2.size(1), + input2.stride(0), + input2.stride(1), + input2.stride(2), + input2.stride(3), + rInput1, + rInput2, + pad_size, + kernel_size, + max_displacement, + stride1, + stride2, + corr_type_multiply, + at::cuda::getCurrentCUDAStream() + //at::globalContext().getCurrentCUDAStream() + ); + + //check for errors + if (!success) { + AT_ERROR("CUDA call failed"); + } + + return 1; + +} + +int correlation_backward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& gradOutput, + at::Tensor& gradInput1, at::Tensor& gradInput2, + int pad_size, + 
int kernel_size, + int max_displacement, + int stride1, + int stride2, + int corr_type_multiply) +{ + + int batchSize = input1.size(0); + int nInputChannels = input1.size(1); + int paddedInputHeight = input1.size(2)+ 2 * pad_size; + int paddedInputWidth = input1.size(3)+ 2 * pad_size; + + int height = input1.size(2); + int width = input1.size(3); + + rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); + rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); + gradInput1.resize_({batchSize, nInputChannels, height, width}); + gradInput2.resize_({batchSize, nInputChannels, height, width}); + + rInput1.fill_(0); + rInput2.fill_(0); + gradInput1.fill_(0); + gradInput2.fill_(0); + + int success = correlation_backward_cuda_kernel(gradOutput, + gradOutput.size(0), + gradOutput.size(1), + gradOutput.size(2), + gradOutput.size(3), + gradOutput.stride(0), + gradOutput.stride(1), + gradOutput.stride(2), + gradOutput.stride(3), + input1, + input1.size(1), + input1.size(2), + input1.size(3), + input1.stride(0), + input1.stride(1), + input1.stride(2), + input1.stride(3), + input2, + input2.stride(0), + input2.stride(1), + input2.stride(2), + input2.stride(3), + gradInput1, + gradInput1.stride(0), + gradInput1.stride(1), + gradInput1.stride(2), + gradInput1.stride(3), + gradInput2, + gradInput2.size(1), + gradInput2.stride(0), + gradInput2.stride(1), + gradInput2.stride(2), + gradInput2.stride(3), + rInput1, + rInput2, + pad_size, + kernel_size, + max_displacement, + stride1, + stride2, + corr_type_multiply, + at::cuda::getCurrentCUDAStream() + //at::globalContext().getCurrentCUDAStream() + ); + + if (!success) { + AT_ERROR("CUDA call failed"); + } + + return 1; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &correlation_forward_cuda, "Correlation forward (CUDA)"); + m.def("backward", &correlation_backward_cuda, "Correlation backward (CUDA)"); +} + diff --git a/imaginaire/third_party/correlation/src/correlation_cuda_kernel.cu b/imaginaire/third_party/correlation/src/correlation_cuda_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..eaf86fc129137d055de7400916567c6669b45c19 --- /dev/null +++ b/imaginaire/third_party/correlation/src/correlation_cuda_kernel.cu @@ -0,0 +1,564 @@ +#include + +#include "correlation_cuda_kernel.cuh" + +#define CUDA_NUM_THREADS 1024 +#define THREADS_PER_BLOCK 32 +#define FULL_MASK 0xffffffff + +#include +#include +#include +#include + +using at::Half; + +template +__forceinline__ __device__ scalar_t warpReduceSum(scalar_t val) { + for (int offset = 16; offset > 0; offset /= 2) + val += __shfl_down_sync(FULL_MASK, val, offset); + return val; +} + +template +__forceinline__ __device__ scalar_t blockReduceSum(scalar_t val) { + + static __shared__ scalar_t shared[32]; + int lane = threadIdx.x % warpSize; + int wid = threadIdx.x / warpSize; + + val = warpReduceSum(val); + + if (lane == 0) + shared[wid] = val; + + __syncthreads(); + + val = (threadIdx.x < blockDim.x / warpSize) ? 
shared[lane] : 0; + + if (wid == 0) + val = warpReduceSum(val); + + return val; +} + + +template +__global__ void channels_first(const scalar_t* __restrict__ input, scalar_t* rinput, int channels, int height, int width, int pad_size) +{ + + // n (batch size), c (num of channels), y (height), x (width) + int n = blockIdx.x; + int y = blockIdx.y; + int x = blockIdx.z; + + int ch_off = threadIdx.x; + scalar_t value; + + int dimcyx = channels * height * width; + int dimyx = height * width; + + int p_dimx = (width + 2 * pad_size); + int p_dimy = (height + 2 * pad_size); + int p_dimyxc = channels * p_dimy * p_dimx; + int p_dimxc = p_dimx * channels; + + for (int c = ch_off; c < channels; c += THREADS_PER_BLOCK) { + value = input[n * dimcyx + c * dimyx + y * width + x]; + rinput[n * p_dimyxc + (y + pad_size) * p_dimxc + (x + pad_size) * channels + c] = value; + } +} + + +template +__global__ void correlation_forward(scalar_t* __restrict__ output, const int nOutputChannels, + const int outputHeight, const int outputWidth, const scalar_t* __restrict__ rInput1, + const int nInputChannels, const int inputHeight, const int inputWidth, + const scalar_t* __restrict__ rInput2, const int pad_size, const int kernel_size, + const int max_displacement, const int stride1, const int stride2) { + + int32_t pInputWidth = inputWidth + 2 * pad_size; + int32_t pInputHeight = inputHeight + 2 * pad_size; + + int32_t kernel_rad = (kernel_size - 1) / 2; + + int32_t displacement_rad = max_displacement / stride2; + + int32_t displacement_size = 2 * displacement_rad + 1; + + int32_t n = blockIdx.x; + int32_t y1 = blockIdx.y * stride1 + max_displacement; + int32_t x1 = blockIdx.z * stride1 + max_displacement; + int32_t c = threadIdx.x; + + int32_t pdimyxc = pInputHeight * pInputWidth * nInputChannels; + + int32_t pdimxc = pInputWidth * nInputChannels; + + int32_t pdimc = nInputChannels; + + int32_t tdimcyx = nOutputChannels * outputHeight * outputWidth; + int32_t tdimyx = outputHeight * outputWidth; + int32_t tdimx = outputWidth; + + int32_t nelems = kernel_size * kernel_size * pdimc; + + // element-wise product along channel axis + for (int tj = -displacement_rad; tj <= displacement_rad; ++tj) { + for (int ti = -displacement_rad; ti <= displacement_rad; ++ti) { + int x2 = x1 + ti * stride2; + int y2 = y1 + tj * stride2; + + float acc0 = 0.0f; + + for (int j = -kernel_rad; j <= kernel_rad; ++j) { + for (int i = -kernel_rad; i <= kernel_rad; ++i) { + // THREADS_PER_BLOCK + #pragma unroll + for (int ch = c; ch < pdimc; ch += blockDim.x) { + + int indx1 = n * pdimyxc + (y1 + j) * pdimxc + + (x1 + i) * pdimc + ch; + int indx2 = n * pdimyxc + (y2 + j) * pdimxc + + (x2 + i) * pdimc + ch; + acc0 += static_cast(rInput1[indx1] * rInput2[indx2]); + } + } + } + + if (blockDim.x == warpSize) { + __syncwarp(); + acc0 = warpReduceSum(acc0); + } else { + __syncthreads(); + acc0 = blockReduceSum(acc0); + } + + if (threadIdx.x == 0) { + + int tc = (tj + displacement_rad) * displacement_size + + (ti + displacement_rad); + const int tindx = n * tdimcyx + tc * tdimyx + blockIdx.y * tdimx + + blockIdx.z; + output[tindx] = static_cast(acc0 / nelems); + } + } + } +} + + +template +__global__ void correlation_backward_input1(int item, scalar_t* gradInput1, int nInputChannels, int inputHeight, int inputWidth, + const scalar_t* __restrict__ gradOutput, int nOutputChannels, int outputHeight, int outputWidth, + const scalar_t* __restrict__ rInput2, + int pad_size, + int kernel_size, + int max_displacement, + int stride1, + int stride2) + { + // n 
(batch size), c (num of channels), y (height), x (width) + + int n = item; + int y = blockIdx.x * stride1 + pad_size; + int x = blockIdx.y * stride1 + pad_size; + int c = blockIdx.z; + int tch_off = threadIdx.x; + + int kernel_rad = (kernel_size - 1) / 2; + int displacement_rad = max_displacement / stride2; + int displacement_size = 2 * displacement_rad + 1; + + int xmin = (x - kernel_rad - max_displacement) / stride1; + int ymin = (y - kernel_rad - max_displacement) / stride1; + + int xmax = (x + kernel_rad - max_displacement) / stride1; + int ymax = (y + kernel_rad - max_displacement) / stride1; + + if (xmax < 0 || ymax < 0 || xmin >= outputWidth || ymin >= outputHeight) { + // assumes gradInput1 is pre-allocated and zero filled + return; + } + + if (xmin > xmax || ymin > ymax) { + // assumes gradInput1 is pre-allocated and zero filled + return; + } + + xmin = max(0,xmin); + xmax = min(outputWidth-1,xmax); + + ymin = max(0,ymin); + ymax = min(outputHeight-1,ymax); + + int pInputWidth = inputWidth + 2 * pad_size; + int pInputHeight = inputHeight + 2 * pad_size; + + int pdimyxc = pInputHeight * pInputWidth * nInputChannels; + int pdimxc = pInputWidth * nInputChannels; + int pdimc = nInputChannels; + + int tdimcyx = nOutputChannels * outputHeight * outputWidth; + int tdimyx = outputHeight * outputWidth; + int tdimx = outputWidth; + + int odimcyx = nInputChannels * inputHeight* inputWidth; + int odimyx = inputHeight * inputWidth; + int odimx = inputWidth; + + scalar_t nelems = kernel_size * kernel_size * nInputChannels; + + __shared__ scalar_t prod_sum[THREADS_PER_BLOCK]; + prod_sum[tch_off] = 0; + + for (int tc = tch_off; tc < nOutputChannels; tc += THREADS_PER_BLOCK) { + + int i2 = (tc % displacement_size - displacement_rad) * stride2; + int j2 = (tc / displacement_size - displacement_rad) * stride2; + + int indx2 = n * pdimyxc + (y + j2)* pdimxc + (x + i2) * pdimc + c; + + scalar_t val2 = rInput2[indx2]; + + for (int j = ymin; j <= ymax; ++j) { + for (int i = xmin; i <= xmax; ++i) { + int tindx = n * tdimcyx + tc * tdimyx + j * tdimx + i; + prod_sum[tch_off] += gradOutput[tindx] * val2; + } + } + } + __syncthreads(); + + if(tch_off == 0) { + scalar_t reduce_sum = 0; + for(int idx = 0; idx < THREADS_PER_BLOCK; idx++) { + reduce_sum += prod_sum[idx]; + } + const int indx1 = n * odimcyx + c * odimyx + (y - pad_size) * odimx + (x - pad_size); + gradInput1[indx1] = reduce_sum / nelems; + } + +} + +template +__global__ void correlation_backward_input2(int item, scalar_t* gradInput2, int nInputChannels, int inputHeight, int inputWidth, + const scalar_t* __restrict__ gradOutput, int nOutputChannels, int outputHeight, int outputWidth, + const scalar_t* __restrict__ rInput1, + int pad_size, + int kernel_size, + int max_displacement, + int stride1, + int stride2) +{ + // n (batch size), c (num of channels), y (height), x (width) + + int n = item; + int y = blockIdx.x * stride1 + pad_size; + int x = blockIdx.y * stride1 + pad_size; + int c = blockIdx.z; + + int tch_off = threadIdx.x; + + int kernel_rad = (kernel_size - 1) / 2; + int displacement_rad = max_displacement / stride2; + int displacement_size = 2 * displacement_rad + 1; + + int pInputWidth = inputWidth + 2 * pad_size; + int pInputHeight = inputHeight + 2 * pad_size; + + int pdimyxc = pInputHeight * pInputWidth * nInputChannels; + int pdimxc = pInputWidth * nInputChannels; + int pdimc = nInputChannels; + + int tdimcyx = nOutputChannels * outputHeight * outputWidth; + int tdimyx = outputHeight * outputWidth; + int tdimx = outputWidth; + + int 
odimcyx = nInputChannels * inputHeight* inputWidth; + int odimyx = inputHeight * inputWidth; + int odimx = inputWidth; + + scalar_t nelems = kernel_size * kernel_size * nInputChannels; + + __shared__ scalar_t prod_sum[THREADS_PER_BLOCK]; + prod_sum[tch_off] = 0; + + for (int tc = tch_off; tc < nOutputChannels; tc += THREADS_PER_BLOCK) { + int i2 = (tc % displacement_size - displacement_rad) * stride2; + int j2 = (tc / displacement_size - displacement_rad) * stride2; + + int xmin = (x - kernel_rad - max_displacement - i2) / stride1; + int ymin = (y - kernel_rad - max_displacement - j2) / stride1; + + int xmax = (x + kernel_rad - max_displacement - i2) / stride1; + int ymax = (y + kernel_rad - max_displacement - j2) / stride1; + + if (xmax < 0 || ymax < 0 || xmin >= outputWidth || ymin >= outputHeight) { + // assumes gradInput2 is pre-allocated and zero filled + continue; + } + + if (xmin > xmax || ymin > ymax) { + // assumes gradInput2 is pre-allocated and zero filled + continue; + } + + xmin = max(0,xmin); + xmax = min(outputWidth-1,xmax); + + ymin = max(0,ymin); + ymax = min(outputHeight-1,ymax); + + int indx1 = n * pdimyxc + (y - j2)* pdimxc + (x - i2) * pdimc + c; + scalar_t val1 = rInput1[indx1]; + + for (int j = ymin; j <= ymax; ++j) { + for (int i = xmin; i <= xmax; ++i) { + int tindx = n * tdimcyx + tc * tdimyx + j * tdimx + i; + prod_sum[tch_off] += gradOutput[tindx] * val1; + } + } + } + + __syncthreads(); + + if(tch_off == 0) { + scalar_t reduce_sum = 0; + for(int idx = 0; idx < THREADS_PER_BLOCK; idx++) { + reduce_sum += prod_sum[idx]; + } + const int indx2 = n * odimcyx + c * odimyx + (y - pad_size) * odimx + (x - pad_size); + gradInput2[indx2] = reduce_sum / nelems; + } + +} + +int correlation_forward_cuda_kernel(at::Tensor& output, + int ob, + int oc, + int oh, + int ow, + int osb, + int osc, + int osh, + int osw, + + at::Tensor& input1, + int ic, + int ih, + int iw, + int isb, + int isc, + int ish, + int isw, + + at::Tensor& input2, + int gc, + int gsb, + int gsc, + int gsh, + int gsw, + + at::Tensor& rInput1, + at::Tensor& rInput2, + int pad_size, + int kernel_size, + int max_displacement, + int stride1, + int stride2, + int corr_type_multiply, + cudaStream_t stream) +{ + + int batchSize = ob; + + int nInputChannels = ic; + int inputWidth = iw; + int inputHeight = ih; + + int nOutputChannels = oc; + int outputWidth = ow; + int outputHeight = oh; + + dim3 blocks_grid(batchSize, inputHeight, inputWidth); + dim3 threads_block(THREADS_PER_BLOCK); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channels_first_fwd_1", ([&] { + + channels_first<<>>( + input1.data(), rInput1.data(), nInputChannels, inputHeight, inputWidth, pad_size); + + })); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "channels_first_fwd_2", ([&] { + + channels_first<<>> ( + input2.data(), rInput2.data(), nInputChannels, inputHeight, inputWidth, pad_size); + + })); + + dim3 threadsPerBlock(THREADS_PER_BLOCK); + dim3 totalBlocksCorr(batchSize, outputHeight, outputWidth); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "correlation_forward", ([&] { + + correlation_forward<<>> + (output.data(), nOutputChannels, outputHeight, outputWidth, + rInput1.data(), nInputChannels, inputHeight, inputWidth, + rInput2.data(), + pad_size, + kernel_size, + max_displacement, + stride1, + stride2); + + })); + + cudaError_t err = cudaGetLastError(); + + + // check for errors + if (err != cudaSuccess) { + printf("error in correlation_forward_cuda_kernel: %s\n", cudaGetErrorString(err)); + return 0; + } + + 
return 1; +} + + +int correlation_backward_cuda_kernel( + at::Tensor& gradOutput, + int gob, + int goc, + int goh, + int gow, + int gosb, + int gosc, + int gosh, + int gosw, + + at::Tensor& input1, + int ic, + int ih, + int iw, + int isb, + int isc, + int ish, + int isw, + + at::Tensor& input2, + int gsb, + int gsc, + int gsh, + int gsw, + + at::Tensor& gradInput1, + int gisb, + int gisc, + int gish, + int gisw, + + at::Tensor& gradInput2, + int ggc, + int ggsb, + int ggsc, + int ggsh, + int ggsw, + + at::Tensor& rInput1, + at::Tensor& rInput2, + int pad_size, + int kernel_size, + int max_displacement, + int stride1, + int stride2, + int corr_type_multiply, + cudaStream_t stream) +{ + + int batchSize = gob; + int num = batchSize; + + int nInputChannels = ic; + int inputWidth = iw; + int inputHeight = ih; + + int nOutputChannels = goc; + int outputWidth = gow; + int outputHeight = goh; + + dim3 blocks_grid(batchSize, inputHeight, inputWidth); + dim3 threads_block(THREADS_PER_BLOCK); + + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "lltm_forward_cuda", ([&] { + + channels_first<<>>( + input1.data(), + rInput1.data(), + nInputChannels, + inputHeight, + inputWidth, + pad_size + ); + })); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "lltm_forward_cuda", ([&] { + + channels_first<<>>( + input2.data(), + rInput2.data(), + nInputChannels, + inputHeight, + inputWidth, + pad_size + ); + })); + + dim3 threadsPerBlock(THREADS_PER_BLOCK); + dim3 totalBlocksCorr(inputHeight, inputWidth, nInputChannels); + + for (int n = 0; n < num; ++n) { + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "lltm_forward_cuda", ([&] { + + + correlation_backward_input1<<>> ( + n, gradInput1.data(), nInputChannels, inputHeight, inputWidth, + gradOutput.data(), nOutputChannels, outputHeight, outputWidth, + rInput2.data(), + pad_size, + kernel_size, + max_displacement, + stride1, + stride2); + })); + } + + for(int n = 0; n < batchSize; n++) { + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(rInput1.type(), "lltm_forward_cuda", ([&] { + + correlation_backward_input2<<>>( + n, gradInput2.data(), nInputChannels, inputHeight, inputWidth, + gradOutput.data(), nOutputChannels, outputHeight, outputWidth, + rInput1.data(), + pad_size, + kernel_size, + max_displacement, + stride1, + stride2); + + })); + } + + // check for errors + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("error in correlation_backward_cuda_kernel: %s\n", cudaGetErrorString(err)); + return 0; + } + + return 1; +} diff --git a/imaginaire/third_party/correlation/src/correlation_cuda_kernel.cuh b/imaginaire/third_party/correlation/src/correlation_cuda_kernel.cuh new file mode 100644 index 0000000000000000000000000000000000000000..1586d3af6bc184bfea8482a991a6625f865f02b3 --- /dev/null +++ b/imaginaire/third_party/correlation/src/correlation_cuda_kernel.cuh @@ -0,0 +1,91 @@ +#pragma once + +#include +#include +#include + +int correlation_forward_cuda_kernel(at::Tensor& output, + int ob, + int oc, + int oh, + int ow, + int osb, + int osc, + int osh, + int osw, + + at::Tensor& input1, + int ic, + int ih, + int iw, + int isb, + int isc, + int ish, + int isw, + + at::Tensor& input2, + int gc, + int gsb, + int gsc, + int gsh, + int gsw, + + at::Tensor& rInput1, + at::Tensor& rInput2, + int pad_size, + int kernel_size, + int max_displacement, + int stride1, + int stride2, + int corr_type_multiply, + cudaStream_t stream); + + +int correlation_backward_cuda_kernel( + at::Tensor& gradOutput, + int gob, + int goc, + int goh, + int gow, + 
int gosb, + int gosc, + int gosh, + int gosw, + + at::Tensor& input1, + int ic, + int ih, + int iw, + int isb, + int isc, + int ish, + int isw, + + at::Tensor& input2, + int gsb, + int gsc, + int gsh, + int gsw, + + at::Tensor& gradInput1, + int gisb, + int gisc, + int gish, + int gisw, + + at::Tensor& gradInput2, + int ggc, + int ggsb, + int ggsc, + int ggsh, + int ggsw, + + at::Tensor& rInput1, + at::Tensor& rInput2, + int pad_size, + int kernel_size, + int max_displacement, + int stride1, + int stride2, + int corr_type_multiply, + cudaStream_t stream); diff --git a/imaginaire/third_party/resample2d/resample2d.py b/imaginaire/third_party/resample2d/resample2d.py new file mode 100644 index 0000000000000000000000000000000000000000..cbdea3fa9941894090aa124adda1b62e1ea5e012 --- /dev/null +++ b/imaginaire/third_party/resample2d/resample2d.py @@ -0,0 +1,62 @@ +# flake8: noqa +from torch.nn.modules.module import Module +from torch.autograd import Function, Variable +from torch.cuda.amp import autocast +import resample2d_cuda + + +class Resample2dFunction(Function): + + @staticmethod + # def forward(ctx, input1, input2, kernel_size=1, bilinear=True): + def forward(ctx, input1, input2, kernel_size=1): + assert input1.is_contiguous() + assert input2.is_contiguous() + + ctx.save_for_backward(input1, input2) + ctx.kernel_size = kernel_size + ctx.bilinear = True + + _, d, _, _ = input1.size() + b, _, h, w = input2.size() + output = input1.new(b, d, h, w).zero_() + + resample2d_cuda.forward(input1, input2, output, kernel_size) + + return output + + @staticmethod + def backward(ctx, grad_output): + grad_output = grad_output.contiguous() + assert grad_output.is_contiguous() + + input1, input2 = ctx.saved_tensors + + grad_input1 = Variable(input1.new(input1.size()).zero_()) + grad_input2 = Variable(input1.new(input2.size()).zero_()) + + # resample2d_cuda.backward(input1, input2, grad_output.data, + # grad_input1.data, grad_input2.data, + # ctx.kernel_size, ctx.bilinear) + resample2d_cuda.backward(input1, input2, grad_output.data, + grad_input1.data, grad_input2.data, + ctx.kernel_size) + + return grad_input1, grad_input2, None, None + + +class Resample2d(Module): + + def __init__(self, kernel_size=1, bilinear=True): + super(Resample2d, self).__init__() + self.kernel_size = kernel_size + self.bilinear = bilinear + + @autocast(False) + def forward(self, input1, input2): + input1, input2 = input1.float(), input2.float() + input1_c = input1.contiguous() + # return Resample2dFunction.apply( + # input1_c, input2, self.kernel_size, self.bilinear) + return Resample2dFunction.apply( + input1_c, input2, self.kernel_size) \ No newline at end of file diff --git a/imaginaire/third_party/resample2d/setup.py b/imaginaire/third_party/resample2d/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..71b5219322ac6c76cf36ac772c48f6c88d623c5e --- /dev/null +++ b/imaginaire/third_party/resample2d/setup.py @@ -0,0 +1,43 @@ +# flake8: noqa +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension +import os + + +cuda_version = os.getenv('CUDA_VERSION') +print('CUDA_VERSION: {}'.format(cuda_version)) + +nvcc_args = list() +# nvcc_args.append('-gencode') +# nvcc_args.append('arch=compute_50,code=sm_50') +# nvcc_args.append('-gencode') +# nvcc_args.append('arch=compute_52,code=sm_52') +# nvcc_args.append('-gencode') +# nvcc_args.append('arch=compute_60,code=sm_60') +# nvcc_args.append('-gencode') +# nvcc_args.append('arch=compute_61,code=sm_61') 
+nvcc_args.append('-gencode') +nvcc_args.append('arch=compute_70,code=sm_70') +nvcc_args.append('-gencode') +nvcc_args.append('arch=compute_75,code=sm_75') +if cuda_version is not None: + if cuda_version >= '11.0': + nvcc_args.append('-gencode') + nvcc_args.append('arch=compute_80,code=sm_80') +nvcc_args.append('-Xcompiler') +nvcc_args.append('-Wall') +nvcc_args.append('-std=c++17') + +setup( + name='resample2d_cuda', + py_modules=['resample2d'], + ext_modules=[ + CUDAExtension('resample2d_cuda', [ + './src/resample2d_cuda.cc', + './src/resample2d_kernel.cu' + ], extra_compile_args={'cxx': ['-Wall', '-std=c++17'], + 'nvcc': nvcc_args}) + ], + cmdclass={ + 'build_ext': BuildExtension + }) diff --git a/imaginaire/third_party/resample2d/src/resample2d_cuda.cc b/imaginaire/third_party/resample2d/src/resample2d_cuda.cc new file mode 100644 index 0000000000000000000000000000000000000000..b330a06bc0f20fe82c275e9a784f7ed91faf7717 --- /dev/null +++ b/imaginaire/third_party/resample2d/src/resample2d_cuda.cc @@ -0,0 +1,34 @@ +#include +#include + +#include "resample2d_kernel.cuh" + +int resample2d_cuda_forward( + at::Tensor& input1, + at::Tensor& input2, + at::Tensor& output, + int kernel_size/*, bool bilinear*/) { + resample2d_kernel_forward(input1, input2, output, kernel_size/*, + bilinear*/); + return 1; +} + +int resample2d_cuda_backward( + at::Tensor& input1, + at::Tensor& input2, + at::Tensor& gradOutput, + at::Tensor& gradInput1, + at::Tensor& gradInput2, + int kernel_size/*, bool bilinear*/) { + resample2d_kernel_backward(input1, input2, gradOutput, gradInput1, + gradInput2, kernel_size/*, bilinear*/); + return 1; +} + + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &resample2d_cuda_forward, "Resample2D forward (CUDA)"); + m.def("backward", &resample2d_cuda_backward, "Resample2D backward (CUDA)"); +} + diff --git a/imaginaire/third_party/resample2d/src/resample2d_kernel.cu b/imaginaire/third_party/resample2d/src/resample2d_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..654ca8e417b6d22ff623d72d19e24997a2db284c --- /dev/null +++ b/imaginaire/third_party/resample2d/src/resample2d_kernel.cu @@ -0,0 +1,328 @@ +#include +#include +#include + +#define CUDA_NUM_THREADS 512 +#define THREADS_PER_BLOCK 64 + +#define DIM0(TENSOR) ((TENSOR).x) +#define DIM1(TENSOR) ((TENSOR).y) +#define DIM2(TENSOR) ((TENSOR).z) +#define DIM3(TENSOR) ((TENSOR).w) + +#define DIM3_INDEX(TENSOR, xx, yy, zz, ww) ((TENSOR)[((xx) * (TENSOR##_stride.x)) + ((yy) * (TENSOR##_stride.y)) + ((zz) * (TENSOR##_stride.z)) + ((ww) * (TENSOR##_stride.w))]) + +template +__global__ void kernel_resample2d_update_output(const int n, + const scalar_t* __restrict__ input1, const long4 input1_size, const long4 input1_stride, + const scalar_t* __restrict__ input2, const long4 input2_size, const long4 input2_stride, + scalar_t* __restrict__ output, + const long4 output_size, const + long4 output_stride, int + kernel_size/*, bool bilinear*/) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + bool bilinear = true; + if (index >= n) { + return; + } + + scalar_t val = 0.0f; + + int dim_b = DIM0(output_size); + int dim_c = DIM1(output_size); + int dim_h = DIM2(output_size); + int dim_w = DIM3(output_size); + int dim_chw = dim_c * dim_h * dim_w; + int dim_hw = dim_h * dim_w; + + int b = ( index / dim_chw ) % dim_b; + int c = ( index / dim_hw ) % dim_c; + int y = ( index / dim_w ) % dim_h; + int x = ( index ) % dim_w; + + scalar_t dx = DIM3_INDEX(input2, b, 0, y, x); + scalar_t dy = 
DIM3_INDEX(input2, b, 1, y, x); + + scalar_t xf = static_cast(x) + dx; + scalar_t yf = static_cast(y) + dy; + scalar_t alpha = xf - floor(xf); // alpha + scalar_t beta = yf - floor(yf); // beta + + if (bilinear) { + int xL = max(min( int (floor(xf)), dim_w-1), 0); + int xR = max(min( int (floor(xf)+1), dim_w -1), 0); + int yT = max(min( int (floor(yf)), dim_h-1), 0); + int yB = max(min( int (floor(yf)+1), dim_h-1), 0); + + for (int fy = 0; fy < kernel_size; fy += 1) { + for (int fx = 0; fx < kernel_size; fx += 1) { + val += static_cast((1. - alpha)*(1. - beta) * DIM3_INDEX(input1, b, c, yT + fy, xL + fx)); + val += static_cast((alpha)*(1. - beta) * DIM3_INDEX(input1, b, c, yT + fy, xR + fx)); + val += static_cast((1. - alpha)*(beta) * DIM3_INDEX(input1, b, c, yB + fy, xL + fx)); + val += static_cast((alpha)*(beta) * DIM3_INDEX(input1, b, c, yB + fy, xR + fx)); + } + } + + output[index] = val; + } + else { + int xN = max(min( int (floor(xf + 0.5)), dim_w - 1), 0); + int yN = max(min( int (floor(yf + 0.5)), dim_h - 1), 0); + + output[index] = static_cast ( DIM3_INDEX(input1, b, c, yN, xN) ); + } + +} + + +template +__global__ void kernel_resample2d_backward_input1( + const int n, const scalar_t* __restrict__ input1, const long4 input1_size, const long4 input1_stride, + const scalar_t* __restrict__ input2, const long4 input2_size, const long4 input2_stride, + const scalar_t* __restrict__ gradOutput, const long4 gradOutput_size, const long4 gradOutput_stride, + scalar_t* __restrict__ gradInput, const long4 gradInput_size, const long4 + gradInput_stride, int kernel_size/*, bool bilinear*/) { + + int index = blockIdx.x * blockDim.x + threadIdx.x; + bool bilinear = true; + if (index >= n) { + return; + } + + int dim_b = DIM0(gradOutput_size); + int dim_c = DIM1(gradOutput_size); + int dim_h = DIM2(gradOutput_size); + int dim_w = DIM3(gradOutput_size); + int dim_chw = dim_c * dim_h * dim_w; + int dim_hw = dim_h * dim_w; + + int b = ( index / dim_chw ) % dim_b; + int c = ( index / dim_hw ) % dim_c; + int y = ( index / dim_w ) % dim_h; + int x = ( index ) % dim_w; + + scalar_t dx = DIM3_INDEX(input2, b, 0, y, x); + scalar_t dy = DIM3_INDEX(input2, b, 1, y, x); + + scalar_t xf = static_cast(x) + dx; + scalar_t yf = static_cast(y) + dy; + scalar_t alpha = xf - int(xf); // alpha + scalar_t beta = yf - int(yf); // beta + + int idim_h = DIM2(input1_size); + int idim_w = DIM3(input1_size); + + int xL = max(min( int (floor(xf)), idim_w-1), 0); + int xR = max(min( int (floor(xf)+1), idim_w -1), 0); + int yT = max(min( int (floor(yf)), idim_h-1), 0); + int yB = max(min( int (floor(yf)+1), idim_h-1), 0); + + for (int fy = 0; fy < kernel_size; fy += 1) { + for (int fx = 0; fx < kernel_size; fx += 1) { + atomicAdd(&DIM3_INDEX(gradInput, b, c, (yT + fy), (xL + fx)), (1-alpha)*(1-beta) * DIM3_INDEX(gradOutput, b, c, y, x)); + atomicAdd(&DIM3_INDEX(gradInput, b, c, (yT + fy), (xR + fx)), (alpha)*(1-beta) * DIM3_INDEX(gradOutput, b, c, y, x)); + atomicAdd(&DIM3_INDEX(gradInput, b, c, (yB + fy), (xL + fx)), (1-alpha)*(beta) * DIM3_INDEX(gradOutput, b, c, y, x)); + atomicAdd(&DIM3_INDEX(gradInput, b, c, (yB + fy), (xR + fx)), (alpha)*(beta) * DIM3_INDEX(gradOutput, b, c, y, x)); + } + } + +} + +template +__global__ void kernel_resample2d_backward_input2( + const int n, const scalar_t* __restrict__ input1, const long4 input1_size, const long4 input1_stride, + const scalar_t* __restrict__ input2, const long4 input2_size, const long4 input2_stride, + const scalar_t* __restrict__ gradOutput, const long4 gradOutput_size, 
const long4 gradOutput_stride, + scalar_t* __restrict__ gradInput, const long4 gradInput_size, const long4 + gradInput_stride, int kernel_size/*, bool bilinear*/) { + + int index = blockIdx.x * blockDim.x + threadIdx.x; + bool bilinear = true; + if (index >= n) { + return; + } + + scalar_t output = 0.0; + int kernel_rad = (kernel_size - 1)/2; + + int dim_b = DIM0(gradInput_size); + int dim_c = DIM1(gradInput_size); + int dim_h = DIM2(gradInput_size); + int dim_w = DIM3(gradInput_size); + int dim_chw = dim_c * dim_h * dim_w; + int dim_hw = dim_h * dim_w; + + int b = ( index / dim_chw ) % dim_b; + int c = ( index / dim_hw ) % dim_c; + int y = ( index / dim_w ) % dim_h; + int x = ( index ) % dim_w; + + int odim_c = DIM1(gradOutput_size); + + scalar_t dx = DIM3_INDEX(input2, b, 0, y, x); + scalar_t dy = DIM3_INDEX(input2, b, 1, y, x); + + scalar_t xf = static_cast(x) + dx; + scalar_t yf = static_cast(y) + dy; + + int xL = max(min( int (floor(xf)), dim_w-1), 0); + int xR = max(min( int (floor(xf)+1), dim_w -1), 0); + int yT = max(min( int (floor(yf)), dim_h-1), 0); + int yB = max(min( int (floor(yf)+1), dim_h-1), 0); + + if (c % 2) { + float gamma = 1 - (xf - floor(xf)); // alpha + for (int i = 0; i <= 2*kernel_rad; ++i) { + for (int j = 0; j <= 2*kernel_rad; ++j) { + for (int ch = 0; ch < odim_c; ++ch) { + output += (gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yB + j), (xL + i)); + output -= (gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yT + j), (xL + i)); + output += (1-gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yB + j), (xR + i)); + output -= (1-gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yT + j), (xR + i)); + } + } + } + } + else { + float gamma = 1 - (yf - floor(yf)); // alpha + for (int i = 0; i <= 2*kernel_rad; ++i) { + for (int j = 0; j <= 2*kernel_rad; ++j) { + for (int ch = 0; ch < odim_c; ++ch) { + output += (gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yT + j), (xR + i)); + output -= (gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yT + j), (xL + i)); + output += (1-gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yB + j), (xR + i)); + output -= (1-gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yB + j), (xL + i)); + } + } + } + + } + + gradInput[index] = output; + +} + +void resample2d_kernel_forward( + at::Tensor& input1, + at::Tensor& input2, + at::Tensor& output, + int kernel_size/*, + bool bilinear*/) { + + int n = output.numel(); + + const long4 input1_size = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3)); + const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3)); + + const long4 input2_size = make_long4(input2.size(0), input2.size(1), input2.size(2), input2.size(3)); + const long4 input2_stride = make_long4(input2.stride(0), input2.stride(1), input2.stride(2), input2.stride(3)); + + const long4 output_size = make_long4(output.size(0), output.size(1), output.size(2), output.size(3)); + const long4 output_stride = make_long4(output.stride(0), output.stride(1), output.stride(2), output.stride(3)); + + // TODO: when atomicAdd gets resolved, change to AT_DISPATCH_FLOATING_TYPES_AND_HALF +// AT_DISPATCH_FLOATING_TYPES(input1.type(), "resample_forward_kernel", ([&] { + + kernel_resample2d_update_output<<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, 
at::cuda::getCurrentCUDAStream() >>>( +//at::globalContext().getCurrentCUDAStream() >>>( + n, + input1.data(), + input1_size, + input1_stride, + input2.data(), + input2_size, + input2_stride, + output.data(), + output_size, + output_stride, + kernel_size/*, + bilinear*/); + +// })); + + // TODO: ATen-equivalent check + + // THCudaCheck(cudaGetLastError()); + +} + +void resample2d_kernel_backward( + at::Tensor& input1, + at::Tensor& input2, + at::Tensor& gradOutput, + at::Tensor& gradInput1, + at::Tensor& gradInput2, + int kernel_size/*, + bool bilinear*/) { + + int n = gradOutput.numel(); + + const long4 input1_size = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3)); + const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3)); + + const long4 input2_size = make_long4(input2.size(0), input2.size(1), input2.size(2), input2.size(3)); + const long4 input2_stride = make_long4(input2.stride(0), input2.stride(1), input2.stride(2), input2.stride(3)); + + const long4 gradOutput_size = make_long4(gradOutput.size(0), gradOutput.size(1), gradOutput.size(2), gradOutput.size(3)); + const long4 gradOutput_stride = make_long4(gradOutput.stride(0), gradOutput.stride(1), gradOutput.stride(2), gradOutput.stride(3)); + + const long4 gradInput1_size = make_long4(gradInput1.size(0), gradInput1.size(1), gradInput1.size(2), gradInput1.size(3)); + const long4 gradInput1_stride = make_long4(gradInput1.stride(0), gradInput1.stride(1), gradInput1.stride(2), gradInput1.stride(3)); + +// AT_DISPATCH_FLOATING_TYPES(input1.type(), "resample_backward_input1", ([&] { + + kernel_resample2d_backward_input1<<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream() >>>( +//at::globalContext().getCurrentCUDAStream() >>>( + n, + input1.data(), + input1_size, + input1_stride, + input2.data(), + input2_size, + input2_stride, + gradOutput.data(), + gradOutput_size, + gradOutput_stride, + gradInput1.data(), + gradInput1_size, + gradInput1_stride, + kernel_size/*, + bilinear*/ + ); + +// })); + + const long4 gradInput2_size = make_long4(gradInput2.size(0), gradInput2.size(1), gradInput2.size(2), gradInput2.size(3)); + const long4 gradInput2_stride = make_long4(gradInput2.stride(0), gradInput2.stride(1), gradInput2.stride(2), gradInput2.stride(3)); + + n = gradInput2.numel(); + +// AT_DISPATCH_FLOATING_TYPES(gradInput2.type(), "resample_backward_input2", ([&] { + + + kernel_resample2d_backward_input2<<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream() >>>( +//at::globalContext().getCurrentCUDAStream() >>>( + n, + input1.data(), + input1_size, + input1_stride, + input2.data(), + input2_size, + input2_stride, + gradOutput.data(), + gradOutput_size, + gradOutput_stride, + gradInput2.data(), + gradInput2_size, + gradInput2_stride, + kernel_size/*, + bilinear*/ + ); + +// })); + + // TODO: Use the ATen equivalent to get last error + + // THCudaCheck(cudaGetLastError()); + +} diff --git a/imaginaire/third_party/resample2d/src/resample2d_kernel.cuh b/imaginaire/third_party/resample2d/src/resample2d_kernel.cuh new file mode 100644 index 0000000000000000000000000000000000000000..3a815269a562e762cd7bd0c73af21d468d4eb2fd --- /dev/null +++ b/imaginaire/third_party/resample2d/src/resample2d_kernel.cuh @@ -0,0 +1,19 @@ +#pragma once + +#include + +void resample2d_kernel_forward( + at::Tensor& input1, + at::Tensor& input2, + at::Tensor& output, + int kernel_size/*, + bool 
bilinear*/); + +void resample2d_kernel_backward( + at::Tensor& input1, + at::Tensor& input2, + at::Tensor& gradOutput, + at::Tensor& gradInput1, + at::Tensor& gradInput2, + int kernel_size/*, + bool bilinear*/); \ No newline at end of file diff --git a/imaginaire/third_party/upfirdn2d/__init__.py b/imaginaire/third_party/upfirdn2d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4c92bf9c45932a8578f64ef83dc0b067ebd27ca0 --- /dev/null +++ b/imaginaire/third_party/upfirdn2d/__init__.py @@ -0,0 +1,3 @@ +from .upfirdn2d import BlurUpsample, BlurDownsample, Blur + +__all__ = ['BlurUpsample', 'BlurDownsample', 'Blur'] diff --git a/imaginaire/third_party/upfirdn2d/setup.py b/imaginaire/third_party/upfirdn2d/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..c927de0928f842d91a5ddb88a8824dac69afea00 --- /dev/null +++ b/imaginaire/third_party/upfirdn2d/setup.py @@ -0,0 +1,43 @@ +# flake8: noqa +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension +import os + + +cuda_version = os.getenv('CUDA_VERSION') +print('CUDA_VERSION: {}'.format(cuda_version)) + +nvcc_args = list() +# nvcc_args.append('-gencode') +# nvcc_args.append('arch=compute_50,code=sm_50') +# nvcc_args.append('-gencode') +# nvcc_args.append('arch=compute_52,code=sm_52') +# nvcc_args.append('-gencode') +# nvcc_args.append('arch=compute_60,code=sm_60') +# nvcc_args.append('-gencode') +# nvcc_args.append('arch=compute_61,code=sm_61') +nvcc_args.append('-gencode') +nvcc_args.append('arch=compute_70,code=sm_70') +nvcc_args.append('-gencode') +nvcc_args.append('arch=compute_75,code=sm_75') +if cuda_version is not None: + if cuda_version >= '11.0': + nvcc_args.append('-gencode') + nvcc_args.append('arch=compute_80,code=sm_80') +nvcc_args.append('-Xcompiler') +nvcc_args.append('-Wall') +nvcc_args.append('-std=c++17') + +setup( + name='upfirdn2d_cuda', + py_modules=['upfirdn2d'], + ext_modules=[ + CUDAExtension('upfirdn2d_cuda', [ + './src/upfirdn2d_cuda.cc', + './src/upfirdn2d_cuda_kernel.cu' + ], extra_compile_args={'cxx': ['-Wall', '-std=c++17'], + 'nvcc': nvcc_args}) + ], + cmdclass={ + 'build_ext': BuildExtension + }) diff --git a/imaginaire/third_party/upfirdn2d/src/upfirdn2d_cuda.cc b/imaginaire/third_party/upfirdn2d/src/upfirdn2d_cuda.cc new file mode 100644 index 0000000000000000000000000000000000000000..65df7a9ad78e4f6f7560feed79048983f60e8add --- /dev/null +++ b/imaginaire/third_party/upfirdn2d/src/upfirdn2d_cuda.cc @@ -0,0 +1,103 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include +#include +#include +#include "upfirdn2d_cuda.h" + +//------------------------------------------------------------------------ + +static torch::Tensor upfirdn2d_cuda(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain) +{ + // Validate arguments. 
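+    // The checks below spell out the call contract: x is a rank-4 (NCHW) CUDA
+    // tensor, f is a rank-2 float32 filter on the same device, both small
+    // enough to index with int, and the up/down factors are at least 1.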
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); + TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); + TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); + TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); + TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); + TORCH_CHECK(x.dim() == 4, "x must be rank 4"); + TORCH_CHECK(f.dim() == 2, "f must be rank 2"); + TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); + TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); + TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); + + // Create output tensor. + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; + int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; + TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); + torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); + TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); + + // Initialize CUDA kernel parameters. + upfirdn2d_kernel_params p; + p.x = x.data_ptr(); + p.f = f.data_ptr(); + p.y = y.data_ptr(); + p.up = make_int2(upx, upy); + p.down = make_int2(downx, downy); + p.pad0 = make_int2(padx0, pady0); + p.flip = (flip) ? 1 : 0; + p.gain = gain; + p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); + p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); + p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); + p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); + p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); + p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); + p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; + p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; + + // Choose CUDA kernel. + upfirdn2d_kernel_spec spec; + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda_kernel", [&] + { + spec = choose_upfirdn2d_kernel(p); + }); + + // Set looping options. + p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; + p.loopMinor = spec.loopMinor; + p.loopX = spec.loopX; + p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; + p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; + + // Compute grid size. + dim3 blockSize, gridSize; + if (spec.tileOutW < 0) // large + { + blockSize = dim3(4, 32, 1); + gridSize = dim3( + ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, + (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, + p.launchMajor); + } + else // small + { + blockSize = dim3(256, 1, 1); + gridSize = dim3( + ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, + (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, + p.launchMajor); + } + + // Launch CUDA kernel. 
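+    // gridSize/blockSize were chosen above: the generic path (spec.tileOutW < 0)
+    // uses 4x32 threads and loops over output rows, while the tiled path uses
+    // 256 threads per output tile; the selected spec.kernel is launched with the
+    // single params struct as its only argument.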
+ void* args[] = {&p}; + AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); + return y; +} + +//------------------------------------------------------------------------ + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("upfirdn2d_cuda", &upfirdn2d_cuda); +} + +//------------------------------------------------------------------------ diff --git a/imaginaire/third_party/upfirdn2d/src/upfirdn2d_cuda.h b/imaginaire/third_party/upfirdn2d/src/upfirdn2d_cuda.h new file mode 100644 index 0000000000000000000000000000000000000000..c9e2032bcac9d2abde7a75eea4d812da348afadd --- /dev/null +++ b/imaginaire/third_party/upfirdn2d/src/upfirdn2d_cuda.h @@ -0,0 +1,59 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include + +//------------------------------------------------------------------------ +// CUDA kernel parameters. + +struct upfirdn2d_kernel_params +{ + const void* x; + const float* f; + void* y; + + int2 up; + int2 down; + int2 pad0; + int flip; + float gain; + + int4 inSize; // [width, height, channel, batch] + int4 inStride; + int2 filterSize; // [width, height] + int2 filterStride; + int4 outSize; // [width, height, channel, batch] + int4 outStride; + int sizeMinor; + int sizeMajor; + + int loopMinor; + int loopMajor; + int loopX; + int launchMinor; + int launchMajor; +}; + +//------------------------------------------------------------------------ +// CUDA kernel specialization. + +struct upfirdn2d_kernel_spec +{ + void* kernel; + int tileOutW; + int tileOutH; + int loopMinor; + int loopX; +}; + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); + +//------------------------------------------------------------------------ diff --git a/imaginaire/third_party/upfirdn2d/src/upfirdn2d_cuda_kernel.cu b/imaginaire/third_party/upfirdn2d/src/upfirdn2d_cuda_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..d7f8938f7ac1220d934fe6a357de543a452445e4 --- /dev/null +++ b/imaginaire/third_party/upfirdn2d/src/upfirdn2d_cuda_kernel.cu @@ -0,0 +1,350 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include +#include "upfirdn2d_cuda.h" + +//------------------------------------------------------------------------ +// Helpers. 
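+// InternalType picks the accumulator type per input dtype (double stays double,
+// float and half both accumulate in float), and floor_div is a floor division
+// that remains correct for negative numerators, which the padded up/downsampling
+// index math below depends on.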
+ +template struct InternalType; +template <> struct InternalType { typedef double scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; + +static __device__ __forceinline__ int floor_div(int a, int b) +{ + int t = 1 - a / b; + return (a + t * b) / b - t; +} + +//------------------------------------------------------------------------ +// Generic CUDA implementation for large filters. + +template static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p) +{ + typedef typename InternalType::scalar_t scalar_t; + + // Calculate thread index. + int minorBase = blockIdx.x * blockDim.x + threadIdx.x; + int outY = minorBase / p.launchMinor; + minorBase -= outY * p.launchMinor; + int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y; + int majorBase = blockIdx.z * p.loopMajor; + if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor) + return; + + // Setup Y receptive field. + int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y; + int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y); + int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY; + int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y; + if (p.flip) + filterY = p.filterSize.y - 1 - filterY; + + // Loop over major, minor, and X. + for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++) + for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor) + { + int nc = major * p.sizeMinor + minor; + int n = nc / p.inSize.z; + int c = nc - n * p.inSize.z; + for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y) + { + // Setup X receptive field. + int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x; + int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x); + int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX; + int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x; + if (p.flip) + filterX = p.filterSize.x - 1 - filterX; + + // Initialize pointers. + const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w]; + const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y]; + int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x; + int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y; + + // Inner loop. + scalar_t v = 0; + for (int y = 0; y < h; y++) + { + for (int x = 0; x < w; x++) + { + v += (scalar_t)(*xp) * (scalar_t)(*fp); + xp += p.inStride.x; + fp += filterStepX; + } + xp += p.inStride.y - w * p.inStride.x; + fp += filterStepY - w * filterStepX; + } + + // Store result. + v *= p.gain; + ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v; + } + } +} + +//------------------------------------------------------------------------ +// Specialized CUDA implementation for small filters. + +template +static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p) +{ + typedef typename InternalType::scalar_t scalar_t; + const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1; + const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1; + __shared__ volatile scalar_t sf[filterH][filterW]; + __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor]; + + // Calculate tile index. 
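+    // Each thread block produces one tileOutH x tileOutW output tile for loopMinor
+    // consecutive channels, staging the pre-flipped filter in sf[] and the required
+    // input window in sx[] (shared memory) before the per-pixel loops below.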
+ int minorBase = blockIdx.x; + int tileOutY = minorBase / p.launchMinor; + minorBase -= tileOutY * p.launchMinor; + minorBase *= loopMinor; + tileOutY *= tileOutH; + int tileOutXBase = blockIdx.y * p.loopX * tileOutW; + int majorBase = blockIdx.z * p.loopMajor; + if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor) + return; + + // Load filter (flipped). + for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x) + { + int fy = tapIdx / filterW; + int fx = tapIdx - fy * filterW; + scalar_t v = 0; + if (fx < p.filterSize.x & fy < p.filterSize.y) + { + int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx; + int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy; + v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y]; + } + sf[fy][fx] = v; + } + + // Loop over major and X. + for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++) + { + int baseNC = major * p.sizeMinor + minorBase; + int n = baseNC / p.inSize.z; + int baseC = baseNC - n * p.inSize.z; + for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW) + { + // Load input pixels. + int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x; + int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y; + int tileInX = floor_div(tileMidX, upx); + int tileInY = floor_div(tileMidY, upy); + __syncthreads(); + for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x) + { + int relC = inIdx; + int relInX = relC / loopMinor; + int relInY = relInX / tileInW; + relC -= relInX * loopMinor; + relInX -= relInY * tileInW; + int c = baseC + relC; + int inX = tileInX + relInX; + int inY = tileInY + relInY; + scalar_t v = 0; + if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z) + v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w]; + sx[relInY][relInX][relC] = v; + } + + // Loop over output pixels. + __syncthreads(); + for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x) + { + int relC = outIdx; + int relOutX = relC / loopMinor; + int relOutY = relOutX / tileOutW; + relC -= relOutX * loopMinor; + relOutX -= relOutY * tileOutW; + int c = baseC + relC; + int outX = tileOutX + relOutX; + int outY = tileOutY + relOutY; + + // Setup receptive field. + int midX = tileMidX + relOutX * downx; + int midY = tileMidY + relOutY * downy; + int inX = floor_div(midX, upx); + int inY = floor_div(midY, upy); + int relInX = inX - tileInX; + int relInY = inY - tileInY; + int filterX = (inX + 1) * upx - midX - 1; // flipped + int filterY = (inY + 1) * upy - midY - 1; // flipped + + // Inner loop. + if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z) + { + scalar_t v = 0; + #pragma unroll + for (int y = 0; y < filterH / upy; y++) + #pragma unroll + for (int x = 0; x < filterW / upx; x++) + v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx]; + v *= p.gain; + ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v; + } + } + } + } +} + +//------------------------------------------------------------------------ +// CUDA kernel selection. 
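+//
+// choose_upfirdn2d_kernel() starts from the generic "large" kernel and substitutes a
+// specialized small-filter kernel whenever the up/down factors and filter extents match
+// one of the precompiled variants below; s == 1 (unit channel stride) indicates a
+// channels_last layout.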
+ +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p) +{ + int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y; + + upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large, -1,-1,1, 4}; // contiguous + if (s == 1) spec = {(void*)upfirdn2d_kernel_large, -1,-1,4, 1}; // channels_last + + if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous + { + if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + } + if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last + { + if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + } + if (s != 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous + { + if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + } + if (s == 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last + { + if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (fx <= 6 && fy <= 6 ) spec 
= {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + } + if (s != 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous + { + if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + } + if (s == 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last + { + if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + } + if (s != 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous + { + if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + } + if (s == 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last + { + if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + } + if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // contiguous + { + if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + } + if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // channels_last + { + if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; + if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; + if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; + if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; + } + if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // contiguous + { + if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; + if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; + if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; + if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; + if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; + } + if (s == 1 && p.up.x 
== 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // channels_last + { + if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; + if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; + if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; + if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; + if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; + } + if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // contiguous + { + if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + } + if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // channels_last + { + if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; + if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; + if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; + if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; + if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; + } + return spec; +} + +//------------------------------------------------------------------------ +// Template specializations. + +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel (const upfirdn2d_kernel_params& p); +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel (const upfirdn2d_kernel_params& p); +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); + +//------------------------------------------------------------------------ diff --git a/imaginaire/third_party/upfirdn2d/upfirdn2d.py b/imaginaire/third_party/upfirdn2d/upfirdn2d.py new file mode 100644 index 0000000000000000000000000000000000000000..8548efe56653d6d8083f68d6e6617ba84b398d1e --- /dev/null +++ b/imaginaire/third_party/upfirdn2d/upfirdn2d.py @@ -0,0 +1,471 @@ +# flake8: noqa +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. 
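+#
+# Note: importing this module requires the compiled `upfirdn2d_cuda` extension (see the
+# import below), built from the CUDA/C++ sources added above; `impl='ref'` routes calls
+# through the pure-PyTorch reference implementation instead.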
+ +"""Custom PyTorch ops for efficient resampling of 2D images.""" + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +import upfirdn2d_cuda + + +def _parse_scaling(scaling): + if isinstance(scaling, int): + scaling = [scaling, scaling] + assert isinstance(scaling, (list, tuple)) + assert all(isinstance(x, int) for x in scaling) + sx, sy = scaling + assert sx >= 1 and sy >= 1 + return sx, sy + + +def _parse_padding(padding): + if isinstance(padding, int): + padding = [padding, padding] + assert isinstance(padding, (list, tuple)) + assert all(isinstance(x, int) for x in padding) + if len(padding) == 2: + padx, pady = padding + padding = [padx, padx, pady, pady] + padx0, padx1, pady0, pady1 = padding + return padx0, padx1, pady0, pady1 + + +def _get_filter_size(f): + if f is None: + return 1, 1 + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + fw = f.shape[-1] + fh = f.shape[0] + assert fw >= 1 and fh >= 1 + return fw, fh + + +class BlurUpsample(nn.Module): + def __init__(self, + kernel=(1, 3, 3, 1), + factor=2, + padding_mode='zeros'): + super().__init__() + p = len(kernel) + px0 = (p + factor - 1) // 2 + px1 = (p - factor) // 2 + py0 = (p + factor - 1) // 2 + py1 = (p - factor) // 2 + + self.pad = [px0, px1, py0, py1] + self.factor = factor + self.register_buffer('kernel', setup_filter(kernel)) + self.kernel_1d = kernel + self.padding_mode = padding_mode + + def forward(self, x): + if self.padding_mode != 'zeros': + x = F.pad(x, list(self.pad) * 2, mode=self.padding_mode) + out = upfirdn2d( + x, self.kernel, up=self.factor, gain=self.factor ** 2) + else: + out = upfirdn2d( + x, self.kernel, up=self.factor, padding=self.pad, + gain=self.factor ** 2) + return out + + def extra_repr(self): + s = 'kernel={kernel_1d}, ' \ + 'padding_mode={padding_mode}, pad={pad}' + return s.format(**self.__dict__) + + +class BlurDownsample(nn.Module): + def __init__(self, kernel=(1, 3, 3, 1), factor=2, padding_mode='zeros'): + super().__init__() + p = len(kernel) + px0 = (p - factor + 1) // 2 + px1 = (p - factor) // 2 + py0 = (p - factor + 1) // 2 + py1 = (p - factor) // 2 + + self.pad = [px0, px1, py0, py1] + self.factor = factor + self.register_buffer('kernel', setup_filter(kernel)) + self.kernel_1d = kernel + self.padding_mode = padding_mode + + def forward(self, x): + if self.padding_mode != 'zeros': + x = F.pad(x, list(self.pad) * 2, mode=self.padding_mode) + out = upfirdn2d(x, self.kernel, down=self.factor) + else: + out = upfirdn2d(x, self.kernel, down=self.factor, padding=self.pad) + return out + + def extra_repr(self): + s = 'kernel={kernel_1d}, ' \ + 'padding_mode={padding_mode}, pad={pad}' + return s.format(**self.__dict__) + + +class Blur(nn.Module): + def __init__(self, + kernel=(1, 3, 3, 1), + pad=0, + padding_mode='zeros'): + super().__init__() + self.register_buffer('kernel', setup_filter(kernel)) + self.kernel_1d = kernel + self.padding_mode = padding_mode + self.pad = pad + + def forward(self, x): + if self.padding_mode != 'zeros': + x = F.pad(x, list(self.pad) * 2, mode=self.padding_mode) + out = upfirdn2d(x, self.kernel) + else: + out = upfirdn2d(x, self.kernel, padding=self.pad) + return out + + def extra_repr(self): + s = 'kernel={kernel_1d}, ' \ + 'padding_mode={padding_mode}, pad={pad}' + return s.format(**self.__dict__) + + +# ---------------------------------------------------------------------------- + +def setup_filter(f, device=torch.device('cpu'), normalize=True, + flip_filter=False, gain=1, separable=None): + r"""Convenience 
function to setup 2D FIR filter for `upfirdn2d()`. + + Args: + f: Torch tensor, numpy array, or python list of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), + `[]` (impulse), or + `None` (identity). + device: Result device (default: cpu). + normalize: Normalize the filter so that it retains the magnitude + for constant input signal (DC)? (default: True). + flip_filter: Flip the filter? (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + separable: Return a separable filter? (default: select automatically). + + Returns: + Float32 tensor of the shape + `[filter_height, filter_width]` (non-separable) or + `[filter_taps]` (separable). + """ + # Validate. + if f is None: + f = 1 + f = torch.as_tensor(f, dtype=torch.float32) + assert f.ndim in [0, 1, 2] + assert f.numel() > 0 + if f.ndim == 0: + f = f[np.newaxis] + + # Separable? + if separable is None: + separable = (f.ndim == 1 and f.numel() >= 8) + if f.ndim == 1 and not separable: + f = f.ger(f) + assert f.ndim == (1 if separable else 2) + + # Apply normalize, flip, gain, and device. + if normalize: + f /= f.sum() + if flip_filter: + f = f.flip(list(range(f.ndim))) + f = f * (gain ** (f.ndim / 2)) + f = f.to(device=device) + return f + + +# ---------------------------------------------------------------------------- + +def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Pad, upsample, filter, and downsample a batch of 2D images. + + Performs the following sequence of operations for each channel: + + 1. Upsample the image by inserting N-1 zeros after each pixel (`up`). + + 2. Pad the image with the specified number of zeros on each side (`padding`). + Negative padding corresponds to cropping the image. + + 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it + so that the footprint of all output pixels lies within the input image. + + 4. Downsample the image by keeping every Nth pixel (`down`). + + This sequence of operations bears close resemblance to scipy.signal.upfirdn(). + The fused op is considerably more efficient than performing the same calculation + using standard PyTorch ops. It supports gradients of arbitrary order. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + up: Integer upsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + down: Integer downsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the upsampled image. Can be a single number + or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 
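+
+    Example:
+        An illustrative shape check (a minimal sketch; any small filter behaves analogously):
+
+        >>> x = torch.randn(1, 3, 16, 16)
+        >>> f = setup_filter([1, 3, 3, 1])            # expands to a normalized 4x4 filter
+        >>> y = upfirdn2d(x, f, up=2, padding=2, impl='ref')
+        >>> list(y.shape)                             # (16*2 + 2 + 2 - 4 + 1) = 33 per side
+        [1, 3, 33, 33]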
+ """ + assert isinstance(x, torch.Tensor) + assert impl in ['ref', 'cuda'] + if impl == 'cuda' and x.device.type == 'cuda': + return _upfirdn2d_cuda(up=up, down=down, padding=padding, + flip_filter=flip_filter, gain=gain).apply(x, f) + return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain) + + +# ---------------------------------------------------------------------------- + +def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): + """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops. + """ + # Validate arguments. + assert isinstance(x, torch.Tensor) and x.ndim == 4 + if f is None: + f = torch.ones([1, 1], dtype=torch.float32, device=x.device) + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + assert f.dtype == torch.float32 and not f.requires_grad + batch_size, num_channels, in_height, in_width = x.shape + upx, upy = _parse_scaling(up) + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + + # Upsample by inserting zeros. + x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1]) + x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1]) + x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx]) + + # Pad or crop. + x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), + max(pady1, 0)]) + x = x[:, :, max(-pady0, 0): x.shape[2] - max(-pady1, 0), + max(-padx0, 0): x.shape[3] - max(-padx1, 0)] + + # Setup filter. + f = f * (gain ** (f.ndim / 2)) + f = f.to(x.dtype) + if not flip_filter: + f = f.flip(list(range(f.ndim))) + + # Convolve with the filter. + f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim) + if f.ndim == 4: + x = F.conv2d(input=x, weight=f, groups=num_channels) + else: + x = F.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels) + x = F.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels) + + # Downsample by throwing away pixels. + x = x[:, :, ::downy, ::downx] + return x + + +# ---------------------------------------------------------------------------- + +_upfirdn2d_cuda_cache = dict() + + +def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1): + """Fast CUDA implementation of `upfirdn2d()` using custom ops. + """ + # Parse arguments. + upx, upy = _parse_scaling(up) + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + + # Lookup from cache. + key = ( + upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) + if key in _upfirdn2d_cuda_cache: + return _upfirdn2d_cuda_cache[key] + + # Forward op. 
+ class Upfirdn2dCuda(torch.autograd.Function): + @staticmethod + def forward(ctx, x, f): # pylint: disable=arguments-differ + assert isinstance(x, torch.Tensor) and x.ndim == 4 + if f is None: + f = torch.ones([1, 1], dtype=torch.float32, device=x.device) + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + y = x + if f.ndim == 2: + y = upfirdn2d_cuda.upfirdn2d_cuda(y, f, upx, upy, downx, downy, padx0, + padx1, pady0, pady1, flip_filter, gain) + else: + y = upfirdn2d_cuda.upfirdn2d_cuda(y, f.unsqueeze(0), upx, 1, downx, 1, + padx0, padx1, 0, 0, flip_filter, + np.sqrt(gain)) + y = upfirdn2d_cuda.upfirdn2d_cuda(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, + pady0, pady1, flip_filter, np.sqrt(gain)) + ctx.save_for_backward(f) + ctx.x_shape = x.shape + return y + + @staticmethod + def backward(ctx, dy): # pylint: disable=arguments-differ + f, = ctx.saved_tensors + _, _, ih, iw = ctx.x_shape + _, _, oh, ow = dy.shape + fw, fh = _get_filter_size(f) + p = [ + fw - padx0 - 1, + iw * upx - ow * downx + padx0 - upx + 1, + fh - pady0 - 1, + ih * upy - oh * downy + pady0 - upy + 1, + ] + dx = None + df = None + + if ctx.needs_input_grad[0]: + dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f) + + assert not ctx.needs_input_grad[1] + return dx, df + + # Add to cache. + _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda + return Upfirdn2dCuda + + +# ---------------------------------------------------------------------------- + +def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Filter a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape matches the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + padding: Padding with respect to the output. Can be a single number or a + list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + padx0, padx1, pady0, pady1 = _parse_padding(padding) + fw, fh = _get_filter_size(f) + p = [ + padx0 + fw // 2, + padx1 + (fw - 1) // 2, + pady0 + fh // 2, + pady1 + (fh - 1) // 2, + ] + return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) + + +# ---------------------------------------------------------------------------- + +def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Upsample a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape is a multiple of the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. 
+ f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + up: Integer upsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the output. Can be a single number or a + list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + upx, upy = _parse_scaling(up) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + fw, fh = _get_filter_size(f) + p = [ + padx0 + (fw + upx - 1) // 2, + padx1 + (fw - upx) // 2, + pady0 + (fh + upy - 1) // 2, + pady1 + (fh - upy) // 2, + ] + return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain * upx * upy, impl=impl) + + +# ---------------------------------------------------------------------------- + +def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, + impl='cuda'): + r"""Downsample a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape is a fraction of the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + down: Integer downsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the input. Can be a single number or a + list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + fw, fh = _get_filter_size(f) + p = [ + padx0 + (fw - downx + 1) // 2, + padx1 + (fw - downx) // 2, + pady0 + (fh - downy + 1) // 2, + pady1 + (fh - downy) // 2, + ] + return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) + +# ---------------------------------------------------------------------------- diff --git a/imaginaire/trainers/__init__.py b/imaginaire/trainers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..13acefe2181136b1629ec31f9d122fb46bf26780 --- /dev/null +++ b/imaginaire/trainers/__init__.py @@ -0,0 +1,4 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md diff --git a/imaginaire/trainers/base.py b/imaginaire/trainers/base.py new file mode 100644 index 0000000000000000000000000000000000000000..4053dd6199f03f1944bfafffbef1c9e56bed8991 --- /dev/null +++ b/imaginaire/trainers/base.py @@ -0,0 +1,982 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import json +import os +import time + +import torch +import torchvision +import wandb +from torch.cuda.amp import GradScaler, autocast +from tqdm import tqdm + +from imaginaire.utils.distributed import is_master, master_only +from imaginaire.utils.distributed import master_only_print as print +from imaginaire.utils.io import save_pilimage_in_jpeg +from imaginaire.utils.meters import Meter +from imaginaire.utils.misc import to_cuda, to_device, requires_grad, to_channels_last +from imaginaire.utils.model_average import (calibrate_batch_norm_momentum, + reset_batch_norm) +from imaginaire.utils.visualization import tensor2pilimage + + +class BaseTrainer(object): + r"""Base trainer. We expect that all trainers inherit this class. + + Args: + cfg (obj): Global configuration. + net_G (obj): Generator network. + net_D (obj): Discriminator network. + opt_G (obj): Optimizer for the generator network. + opt_D (obj): Optimizer for the discriminator network. + sch_G (obj): Scheduler for the generator optimizer. + sch_D (obj): Scheduler for the discriminator optimizer. + train_data_loader (obj): Train data loader. + val_data_loader (obj): Validation data loader. + """ + + def __init__(self, + cfg, + net_G, + net_D, + opt_G, + opt_D, + sch_G, + sch_D, + train_data_loader, + val_data_loader): + super(BaseTrainer, self).__init__() + print('Setup trainer.') + + # Initialize models and data loaders. + self.cfg = cfg + self.net_G = net_G + if cfg.trainer.model_average_config.enabled: + # Two wrappers (DDP + model average). + self.net_G_module = self.net_G.module.module + else: + # One wrapper (DDP) + self.net_G_module = self.net_G.module + self.val_data_loader = val_data_loader + self.is_inference = train_data_loader is None + self.net_D = net_D + self.opt_G = opt_G + self.opt_D = opt_D + self.sch_G = sch_G + self.sch_D = sch_D + self.train_data_loader = train_data_loader + if self.cfg.trainer.channels_last: + self.net_G = self.net_G.to(memory_format=torch.channels_last) + self.net_D = self.net_D.to(memory_format=torch.channels_last) + + # Initialize amp. + if self.cfg.trainer.amp_config.enabled: + print("Using automatic mixed precision training.") + self.scaler_G = GradScaler(**vars(self.cfg.trainer.amp_config)) + self.scaler_D = GradScaler(**vars(self.cfg.trainer.amp_config)) + # In order to check whether the discriminator/generator has + # skipped the last parameter update due to gradient overflow. + self.last_step_count_G = 0 + self.last_step_count_D = 0 + self.skipped_G = False + self.skipped_D = False + + # Initialize data augmentation policy. + self.aug_policy = cfg.trainer.aug_policy + print("Augmentation policy: {}".format(self.aug_policy)) + + # Initialize loss functions. + # All loss names have weights. Some have criterion modules. + # Mapping from loss names to criterion modules. + self.criteria = torch.nn.ModuleDict() + # Mapping from loss names to loss weights. 
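+        # (Both mappings are populated by the subclass's _init_loss() implementation,
+        # which is called below.)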
+ self.weights = dict() + self.losses = dict(gen_update=dict(), dis_update=dict()) + self.gen_losses = self.losses['gen_update'] + self.dis_losses = self.losses['dis_update'] + self._init_loss(cfg) + for loss_name, loss_weight in self.weights.items(): + print("Loss {:<20} Weight {}".format(loss_name, loss_weight)) + if loss_name in self.criteria.keys() and \ + self.criteria[loss_name] is not None: + self.criteria[loss_name].to('cuda') + + if self.is_inference: + # The initialization steps below can be skipped during inference. + return + + # Initialize logging attributes. + self.current_iteration = 0 + self.current_epoch = 0 + self.start_iteration_time = None + self.start_epoch_time = None + self.elapsed_iteration_time = 0 + self.time_iteration = None + self.time_epoch = None + self.best_fid = None + if self.cfg.speed_benchmark: + self.accu_gen_forw_iter_time = 0 + self.accu_gen_loss_iter_time = 0 + self.accu_gen_back_iter_time = 0 + self.accu_gen_step_iter_time = 0 + self.accu_gen_avg_iter_time = 0 + self.accu_dis_forw_iter_time = 0 + self.accu_dis_loss_iter_time = 0 + self.accu_dis_back_iter_time = 0 + self.accu_dis_step_iter_time = 0 + + # Initialize tensorboard and hparams. + self._init_tensorboard() + self._init_hparams() + + # Initialize validation parameters. + self.val_sample_size = getattr(cfg.trainer, 'val_sample_size', 50000) + self.kid_num_subsets = getattr(cfg.trainer, 'kid_num_subsets', 10) + self.kid_subset_size = self.val_sample_size // self.kid_num_subsets + self.metrics_path = os.path.join(torch.hub.get_dir(), 'metrics') + self.best_metrics = {} + self.eval_networks = getattr(cfg.trainer, 'eval_network', ['clean_inception']) + if self.cfg.metrics_iter is None: + self.cfg.metrics_iter = self.cfg.snapshot_save_iter + if self.cfg.metrics_epoch is None: + self.cfg.metrics_epoch = self.cfg.snapshot_save_epoch + + # AWS credentials. + if hasattr(cfg, 'aws_credentials_file'): + with open(cfg.aws_credentials_file) as fin: + self.credentials = json.load(fin) + else: + self.credentials = None + + if 'TORCH_HOME' not in os.environ: + os.environ['TORCH_HOME'] = os.path.join( + os.environ['HOME'], ".cache") + + def _init_tensorboard(self): + r"""Initialize the tensorboard. Different algorithms might require + different performance metrics. Hence, custom tensorboard + initialization might be necessary. + """ + # Logging frequency: self.cfg.logging_iter + self.meters = {} + + # Logging frequency: self.cfg.snapshot_save_iter + self.metric_meters = {} + + # Logging frequency: self.cfg.image_display_iter + self.image_meter = Meter('images', reduce=False) + + def _init_hparams(self): + r"""Initialize a dictionary of hyperparameters that we want to monitor + in the HParams dashboard in tensorBoard. + """ + self.hparam_dict = {} + + def _write_tensorboard(self): + r"""Write values to tensorboard. By default, we will log the time used + per iteration, time used per epoch, generator learning rate, and + discriminator learning rate. We will log all the losses as well as + custom meters. + """ + # Logs that are shared by all models. + self._write_to_meters({'time/iteration': self.time_iteration, + 'time/epoch': self.time_epoch, + 'optim/gen_lr': self.sch_G.get_last_lr()[0], + 'optim/dis_lr': self.sch_D.get_last_lr()[0]}, + self.meters, + reduce=False) + # Logs for loss values. Different models have different losses. + self._write_loss_meters() + # Other custom logs. 
+ self._write_custom_meters() + + def _write_loss_meters(self): + r"""Write all loss values to tensorboard.""" + for update, losses in self.losses.items(): + # update is 'gen_update' or 'dis_update'. + assert update == 'gen_update' or update == 'dis_update' + for loss_name, loss in losses.items(): + if loss is not None: + full_loss_name = update + '/' + loss_name + if full_loss_name not in self.meters.keys(): + # Create a new meter if it doesn't exist. + self.meters[full_loss_name] = Meter( + full_loss_name, reduce=True) + self.meters[full_loss_name].write(loss.item()) + + def _write_custom_meters(self): + r"""Dummy member function to be overloaded by the child class. + In the child class, you can write down whatever you want to track. + """ + pass + + @staticmethod + def _write_to_meters(data, meters, reduce=True): + r"""Write values to meters.""" + if reduce or is_master(): + for key, value in data.items(): + if key not in meters: + meters[key] = Meter(key, reduce=reduce) + meters[key].write(value) + + def _flush_meters(self, meters): + r"""Flush all meters using the current iteration.""" + for meter in meters.values(): + meter.flush(self.current_iteration) + + def _pre_save_checkpoint(self): + r"""Implement the things you want to do before saving a checkpoint. + For example, you can compute the K-mean features (pix2pixHD) before + saving the model weights to a checkpoint. + """ + pass + + def save_checkpoint(self, current_epoch, current_iteration): + r"""Save network weights, optimizer parameters, scheduler parameters + to a checkpoint. + """ + self._pre_save_checkpoint() + _save_checkpoint(self.cfg, + self.net_G, self.net_D, + self.opt_G, self.opt_D, + self.sch_G, self.sch_D, + current_epoch, current_iteration) + + def load_checkpoint(self, cfg, checkpoint_path, resume=None, load_sch=True): + r"""Load network weights, optimizer parameters, scheduler parameters + from a checkpoint. + + Args: + cfg (obj): Global configuration. + checkpoint_path (str): Path to the checkpoint. + resume (bool or None): If not ``None``, will determine whether or + not to load optimizers in addition to network weights. + """ + if os.path.exists(checkpoint_path): + # If checkpoint_path exists, we will load its weights to + # initialize our network. + if resume is None: + resume = False + elif os.path.exists(os.path.join(cfg.logdir, 'latest_checkpoint.txt')): + # This is for resuming the training from the previously saved + # checkpoint. + fn = os.path.join(cfg.logdir, 'latest_checkpoint.txt') + with open(fn, 'r') as f: + line = f.read().splitlines() + checkpoint_path = os.path.join(cfg.logdir, line[0].split(' ')[-1]) + if resume is None: + resume = True + else: + # checkpoint not found and not specified. We will train + # everything from scratch. 
+ current_epoch = 0 + current_iteration = 0 + print('No checkpoint found.') + resume = False + return resume, current_epoch, current_iteration + # Load checkpoint + checkpoint = torch.load( + checkpoint_path, map_location=lambda storage, loc: storage) + current_epoch = 0 + current_iteration = 0 + if resume: + self.net_G.load_state_dict(checkpoint['net_G'], strict=self.cfg.trainer.strict_resume) + if not self.is_inference: + self.net_D.load_state_dict(checkpoint['net_D'], strict=self.cfg.trainer.strict_resume) + if 'opt_G' in checkpoint: + current_epoch = checkpoint['current_epoch'] + current_iteration = checkpoint['current_iteration'] + self.opt_G.load_state_dict(checkpoint['opt_G']) + self.opt_D.load_state_dict(checkpoint['opt_D']) + if load_sch: + self.sch_G.load_state_dict(checkpoint['sch_G']) + self.sch_D.load_state_dict(checkpoint['sch_D']) + else: + if self.cfg.gen_opt.lr_policy.iteration_mode: + self.sch_G.last_epoch = current_iteration + else: + self.sch_G.last_epoch = current_epoch + if self.cfg.dis_opt.lr_policy.iteration_mode: + self.sch_D.last_epoch = current_iteration + else: + self.sch_D.last_epoch = current_epoch + print('Load from: {}'.format(checkpoint_path)) + else: + print('Load network weights only.') + else: + try: + self.net_G.load_state_dict(checkpoint['net_G'], strict=self.cfg.trainer.strict_resume) + if 'net_D' in checkpoint: + self.net_D.load_state_dict(checkpoint['net_D'], strict=self.cfg.trainer.strict_resume) + except Exception: + if self.cfg.trainer.model_average_config.enabled: + net_G_module = self.net_G.module.module + else: + net_G_module = self.net_G.module + if hasattr(net_G_module, 'load_pretrained_network'): + net_G_module.load_pretrained_network(self.net_G, checkpoint['net_G']) + print('Load generator weights only.') + else: + raise ValueError('Checkpoint cannot be loaded.') + + print('Done with loading the checkpoint.') + return resume, current_epoch, current_iteration + + def start_of_epoch(self, current_epoch): + r"""Things to do before an epoch. + + Args: + current_epoch (int): Current number of epoch. + """ + self._start_of_epoch(current_epoch) + self.current_epoch = current_epoch + self.start_epoch_time = time.time() + + def start_of_iteration(self, data, current_iteration): + r"""Things to do before an iteration. + + Args: + data (dict): Data used for the current iteration. + current_iteration (int): Current number of iteration. + """ + data = self._start_of_iteration(data, current_iteration) + data = to_cuda(data) + if self.cfg.trainer.channels_last: + data = to_channels_last(data) + self.current_iteration = current_iteration + if not self.is_inference: + self.net_D.train() + self.net_G.train() + # torch.cuda.synchronize() + self.start_iteration_time = time.time() + return data + + def end_of_iteration(self, data, current_epoch, current_iteration): + r"""Things to do after an iteration. + + Args: + data (dict): Data used for the current iteration. + current_epoch (int): Current number of epoch. + current_iteration (int): Current number of iteration. + """ + self.current_iteration = current_iteration + self.current_epoch = current_epoch + # Update the learning rate policy for the generator if operating in the + # iteration mode. + if self.cfg.gen_opt.lr_policy.iteration_mode: + self.sch_G.step() + # Update the learning rate policy for the discriminator if operating in + # the iteration mode. 
+ if self.cfg.dis_opt.lr_policy.iteration_mode: + self.sch_D.step() + + # Accumulate time + # torch.cuda.synchronize() + self.elapsed_iteration_time += time.time() - self.start_iteration_time + # Logging. + if current_iteration % self.cfg.logging_iter == 0: + ave_t = self.elapsed_iteration_time / self.cfg.logging_iter + self.time_iteration = ave_t + print('Iteration: {}, average iter time: ' + '{:6f}.'.format(current_iteration, ave_t)) + self.elapsed_iteration_time = 0 + + if self.cfg.speed_benchmark: + # Below code block only needed when analyzing computation + # bottleneck. + print('\tGenerator FWD time {:6f}'.format( + self.accu_gen_forw_iter_time / self.cfg.logging_iter)) + print('\tGenerator LOS time {:6f}'.format( + self.accu_gen_loss_iter_time / self.cfg.logging_iter)) + print('\tGenerator BCK time {:6f}'.format( + self.accu_gen_back_iter_time / self.cfg.logging_iter)) + print('\tGenerator STP time {:6f}'.format( + self.accu_gen_step_iter_time / self.cfg.logging_iter)) + print('\tGenerator AVG time {:6f}'.format( + self.accu_gen_avg_iter_time / self.cfg.logging_iter)) + + print('\tDiscriminator FWD time {:6f}'.format( + self.accu_dis_forw_iter_time / self.cfg.logging_iter)) + print('\tDiscriminator LOS time {:6f}'.format( + self.accu_dis_loss_iter_time / self.cfg.logging_iter)) + print('\tDiscriminator BCK time {:6f}'.format( + self.accu_dis_back_iter_time / self.cfg.logging_iter)) + print('\tDiscriminator STP time {:6f}'.format( + self.accu_dis_step_iter_time / self.cfg.logging_iter)) + + print('{:6f}'.format(ave_t)) + + self.accu_gen_forw_iter_time = 0 + self.accu_gen_loss_iter_time = 0 + self.accu_gen_back_iter_time = 0 + self.accu_gen_step_iter_time = 0 + self.accu_gen_avg_iter_time = 0 + self.accu_dis_forw_iter_time = 0 + self.accu_dis_loss_iter_time = 0 + self.accu_dis_back_iter_time = 0 + self.accu_dis_step_iter_time = 0 + + self._end_of_iteration(data, current_epoch, current_iteration) + + # Save everything to the checkpoint. + if current_iteration % self.cfg.snapshot_save_iter == 0: + if current_iteration >= self.cfg.snapshot_save_start_iter: + self.save_checkpoint(current_epoch, current_iteration) + + # Compute metrics. + if current_iteration % self.cfg.metrics_iter == 0: + self.save_image(self._get_save_path('images', 'jpg'), data) + self.write_metrics() + + # Compute image to be saved. + elif current_iteration % self.cfg.image_save_iter == 0: + self.save_image(self._get_save_path('images', 'jpg'), data) + elif current_iteration % self.cfg.image_display_iter == 0: + image_path = os.path.join(self.cfg.logdir, 'images', 'current.jpg') + self.save_image(image_path, data) + + # Logging. + self._write_tensorboard() + if current_iteration % self.cfg.logging_iter == 0: + # Write all logs to tensorboard. + self._flush_meters(self.meters) + + from torch.distributed import barrier + import torch.distributed as dist + if dist.is_initialized(): + barrier() + + def end_of_epoch(self, data, current_epoch, current_iteration): + r"""Things to do after an epoch. + + Args: + data (dict): Data used for the current iteration. + + current_epoch (int): Current number of epoch. + current_iteration (int): Current number of iteration. + """ + # Update the learning rate policy for the generator if operating in the + # epoch mode. + self.current_iteration = current_iteration + self.current_epoch = current_epoch + if not self.cfg.gen_opt.lr_policy.iteration_mode: + self.sch_G.step() + # Update the learning rate policy for the discriminator if operating + # in the epoch mode. 
+ if not self.cfg.dis_opt.lr_policy.iteration_mode: + self.sch_D.step() + elapsed_epoch_time = time.time() - self.start_epoch_time + # Logging. + print('Epoch: {}, total time: {:6f}.'.format(current_epoch, + elapsed_epoch_time)) + self.time_epoch = elapsed_epoch_time + self._end_of_epoch(data, current_epoch, current_iteration) + + # Save everything to the checkpoint. + if current_iteration % self.cfg.snapshot_save_iter == 0: + if current_epoch >= self.cfg.snapshot_save_start_epoch: + self.save_checkpoint(current_epoch, current_iteration) + + # Compute metrics. + if current_iteration % self.cfg.metrics_iter == 0: + self.save_image(self._get_save_path('images', 'jpg'), data) + self.write_metrics() + + def pre_process(self, data): + r"""Custom data pre-processing function. Utilize this function if you + need to preprocess your data before sending it to the generator and + discriminator. + + Args: + data (dict): Data used for the current iteration. + """ + + def recalculate_batch_norm_statistics(self, data_loader, averaged=True): + r"""Update the statistics in the moving average model. + + Args: + data_loader (torch.utils.data.DataLoader): Data loader for + estimating the statistics. + averaged (Boolean): True/False, we recalculate batch norm statistics for EMA/regular + """ + if not self.cfg.trainer.model_average_config.enabled: + return + if averaged: + net_G = self.net_G.module.averaged_model + else: + net_G = self.net_G_module + model_average_iteration = \ + self.cfg.trainer.model_average_config.num_batch_norm_estimation_iterations + if model_average_iteration == 0: + return + with torch.no_grad(): + # Accumulate bn stats.. + net_G.train() + # Reset running stats. + net_G.apply(reset_batch_norm) + for cal_it, cal_data in enumerate(data_loader): + if cal_it >= model_average_iteration: + print('Done with {} iterations of updating batch norm ' + 'statistics'.format(model_average_iteration)) + break + cal_data = to_device(cal_data, 'cuda') + cal_data = self.pre_process(cal_data) + # Averaging over all batches + net_G.apply(calibrate_batch_norm_momentum) + net_G(cal_data) + + def save_image(self, path, data): + r"""Compute visualization images and save them to the disk. + + Args: + path (str): Location of the file. + data (dict): Data used for the current iteration. + """ + self.net_G.eval() + vis_images = self._get_visualizations(data) + if is_master() and vis_images is not None: + vis_images = torch.cat( + [img for img in vis_images if img is not None], dim=3).float() + vis_images = (vis_images + 1) / 2 + print('Save output images to {}'.format(path)) + vis_images.clamp_(0, 1) + os.makedirs(os.path.dirname(path), exist_ok=True) + image_grid = torchvision.utils.make_grid( + vis_images, nrow=1, padding=0, normalize=False) + if self.cfg.trainer.image_to_tensorboard: + self.image_meter.write_image(image_grid, self.current_iteration) + torchvision.utils.save_image(image_grid, path, nrow=1) + wandb.log({os.path.splitext(os.path.basename(path))[0]: [wandb.Image(path)]}) + + def write_metrics(self): + r"""Write metrics to the tensorboard.""" + cur_fid = self._compute_fid() + if cur_fid is not None: + if self.best_fid is not None: + self.best_fid = min(self.best_fid, cur_fid) + else: + self.best_fid = cur_fid + metric_dict = {'FID': cur_fid, 'best_FID': self.best_fid} + self._write_to_meters(metric_dict, self.metric_meters, reduce=False) + self._flush_meters(self.metric_meters) + + def _get_save_path(self, subdir, ext): + r"""Get the image save path. 
+ + Args: + subdir (str): Sub-directory under the main directory for saving + the outputs. + ext (str): Filename extension for the image (e.g., jpg, png, ...). + Return: + (str): image filename to be used to save the visualization results. + """ + subdir_path = os.path.join(self.cfg.logdir, subdir) + if not os.path.exists(subdir_path): + os.makedirs(subdir_path, exist_ok=True) + return os.path.join( + subdir_path, 'epoch_{:05}_iteration_{:09}.{}'.format( + self.current_epoch, self.current_iteration, ext)) + + def _get_outputs(self, net_D_output, real=True): + r"""Return output values. Note that when the gan mode is relativistic. + It will do the difference before returning. + + Args: + net_D_output (dict): + real_outputs (tensor): Real output values. + fake_outputs (tensor): Fake output values. + real (bool): Return real or fake. + """ + + def _get_difference(a, b): + r"""Get difference between two lists of tensors or two tensors. + + Args: + a: list of tensors or tensor + b: list of tensors or tensor + """ + out = list() + for x, y in zip(a, b): + if isinstance(x, list): + res = _get_difference(x, y) + else: + res = x - y + out.append(res) + return out + + if real: + if self.cfg.trainer.gan_relativistic: + return _get_difference(net_D_output['real_outputs'], net_D_output['fake_outputs']) + else: + return net_D_output['real_outputs'] + else: + if self.cfg.trainer.gan_relativistic: + return _get_difference(net_D_output['fake_outputs'], net_D_output['real_outputs']) + else: + return net_D_output['fake_outputs'] + + def _start_of_epoch(self, current_epoch): + r"""Operations to do before starting an epoch. + + Args: + current_epoch (int): Current number of epoch. + """ + pass + + def _start_of_iteration(self, data, current_iteration): + r"""Operations to do before starting an iteration. + + Args: + data (dict): Data used for the current iteration. + current_iteration (int): Current epoch number. + Returns: + (dict): Data used for the current iteration. They might be + processed by the custom _start_of_iteration function. + """ + return data + + def _end_of_iteration(self, data, current_epoch, current_iteration): + r"""Operations to do after an iteration. + + Args: + data (dict): Data used for the current iteration. + current_epoch (int): Current number of epoch. + current_iteration (int): Current epoch number. + """ + pass + + def _end_of_epoch(self, data, current_epoch, current_iteration): + r"""Operations to do after an epoch. + + Args: + data (dict): Data used for the current iteration. + current_epoch (int): Current number of epoch. + current_iteration (int): Current epoch number. + """ + pass + + def _get_visualizations(self, data): + r"""Compute visualization outputs. + + Args: + data (dict): Data used for the current iteration. + """ + return None + + def _compute_fid(self): + r"""FID computation function to be overloaded.""" + return None + + def _init_loss(self, cfg): + r"""Every trainer should implement its own init loss function.""" + raise NotImplementedError + + def gen_update(self, data): + r"""Update the generator. + + Args: + data (dict): Data used for the current iteration. + """ + update_finished = False + while not update_finished: + # Set requires_grad flags. + requires_grad(self.net_G_module, True) + requires_grad(self.net_D, False) + + # Compute the loss. + self._time_before_forward() + with autocast(enabled=self.cfg.trainer.amp_config.enabled): + total_loss = self.gen_forward(data) + if total_loss is None: + return + + # Zero-grad and backpropagate the loss. 
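+            # Under AMP, scaler_G scales the loss before backward and skips the optimizer
+            # step when inf/NaN gradients are found; a skipped step is detected further
+            # down via opt_G._step_count not advancing.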
+ self.opt_G.zero_grad(set_to_none=True) + self._time_before_backward() + self.scaler_G.scale(total_loss).backward() + + # Optionally clip gradient norm. + if hasattr(self.cfg.gen_opt, 'clip_grad_norm'): + self.scaler_G.unscale_(self.opt_G) + total_norm = torch.nn.utils.clip_grad_norm_( + self.net_G_module.parameters(), + self.cfg.gen_opt.clip_grad_norm + ) + self.gen_grad_norm = total_norm + if torch.isfinite(total_norm) and \ + total_norm > self.cfg.gen_opt.clip_grad_norm: + # print(f"Gradient norm of the generator ({total_norm}) " + # f"too large.") + if getattr(self.cfg.gen_opt, 'skip_grad', False): + print(f"Skip gradient update.") + self.opt_G.zero_grad(set_to_none=True) + self.scaler_G.step(self.opt_G) + self.scaler_G.update() + break + # else: + # print(f"Clip gradient norm to " + # f"{self.cfg.gen_opt.clip_grad_norm}.") + + # Perform an optimizer step. + self._time_before_step() + self.scaler_G.step(self.opt_G) + self.scaler_G.update() + # Whether the step above was skipped. + if self.last_step_count_G == self.opt_G._step_count: + print("Generator overflowed! with step count [{}], total loss [{}]".format(self.opt_G._step_count, total_loss)) + if not torch.isfinite(total_loss): + print("Generator loss is not finite. Skip this iteration!") + update_finished = True + else: + self.last_step_count_G = self.opt_G._step_count + update_finished = True + + self._extra_gen_step(data) + + # Update model average. + self._time_before_model_avg() + if self.cfg.trainer.model_average_config.enabled: + self.net_G.module.update_average() + + self._detach_losses() + self._time_before_leave_gen() + + def gen_forward(self, data): + r"""Every trainer should implement its own generator forward.""" + raise NotImplementedError + + def _extra_gen_step(self, data): + pass + + def dis_update(self, data): + r"""Update the discriminator. + + Args: + data (dict): Data used for the current iteration. + """ + update_finished = False + while not update_finished: + # Set requires_grad flags. + requires_grad(self.net_G_module, False) + requires_grad(self.net_D, True) + + # Compute the loss. + self._time_before_forward() + with autocast(enabled=self.cfg.trainer.amp_config.enabled): + total_loss = self.dis_forward(data) + if total_loss is None: + return + + # Zero-grad and backpropagate the loss. + self.opt_D.zero_grad(set_to_none=True) + self._time_before_backward() + self.scaler_D.scale(total_loss).backward() + + # Optionally clip gradient norm. + if hasattr(self.cfg.dis_opt, 'clip_grad_norm'): + self.scaler_D.unscale_(self.opt_D) + total_norm = torch.nn.utils.clip_grad_norm_( + self.net_D.parameters(), self.cfg.dis_opt.clip_grad_norm + ) + self.dis_grad_norm = total_norm + if torch.isfinite(total_norm) and \ + total_norm > self.cfg.dis_opt.clip_grad_norm: + print(f"Gradient norm of the discriminator ({total_norm}) " + f"too large.") + if getattr(self.cfg.dis_opt, 'skip_grad', False): + print(f"Skip gradient update.") + self.opt_D.zero_grad(set_to_none=True) + self.scaler_D.step(self.opt_D) + self.scaler_D.update() + continue + else: + print(f"Clip gradient norm to " + f"{self.cfg.dis_opt.clip_grad_norm}.") + + # Perform an optimizer step. + self._time_before_step() + self.scaler_D.step(self.opt_D) + self.scaler_D.update() + # Whether the step above was skipped. + if self.last_step_count_D == self.opt_D._step_count: + print("Discriminator overflowed! with step count [{}], total loss [{}]".format(self.opt_D._step_count, total_loss)) + if not torch.isfinite(total_loss): + print("Discriminator loss is not finite. 
" + "Skip this iteration!") + update_finished = True + else: + self.last_step_count_D = self.opt_D._step_count + update_finished = True + + self._extra_dis_step(data) + + self._detach_losses() + self._time_before_leave_dis() + + def dis_forward(self, data): + r"""Every trainer should implement its own discriminator forward.""" + raise NotImplementedError + + def _extra_dis_step(self, data): + pass + + def test(self, data_loader, output_dir, inference_args): + r"""Compute results images for a batch of input data and save the + results in the specified folder. + + Args: + data_loader (torch.utils.data.DataLoader): PyTorch dataloader. + output_dir (str): Target location for saving the output image. + """ + if self.cfg.trainer.model_average_config.enabled: + net_G = self.net_G.module.averaged_model + else: + net_G = self.net_G.module + net_G.eval() + + print('# of samples %d' % len(data_loader)) + for it, data in enumerate(tqdm(data_loader)): + data = self.start_of_iteration(data, current_iteration=-1) + with torch.no_grad(): + output_images, file_names = \ + net_G.inference(data, **vars(inference_args)) + for output_image, file_name in zip(output_images, file_names): + fullname = os.path.join(output_dir, file_name + '.jpg') + output_image = tensor2pilimage(output_image.clamp_(-1, 1), + minus1to1_normalized=True) + save_pilimage_in_jpeg(fullname, output_image) + + def _get_total_loss(self, gen_forward): + r"""Return the total loss to be backpropagated. + Args: + gen_forward (bool): If ``True``, backpropagates the generator loss, + otherwise the discriminator loss. + """ + losses = self.gen_losses if gen_forward else self.dis_losses + total_loss = torch.tensor(0., device=torch.device('cuda')) + # Iterates over all possible losses. + for loss_name in self.weights: + # If it is for the current model (gen/dis). + if loss_name in losses: + # Multiply it with the corresponding weight + # and add it to the total loss. + total_loss += losses[loss_name] * self.weights[loss_name] + losses['total'] = total_loss # logging purpose + return total_loss + + def _detach_losses(self): + r"""Detach all logging variables to prevent potential memory leak.""" + for loss_name in self.gen_losses: + self.gen_losses[loss_name] = self.gen_losses[loss_name].detach() + for loss_name in self.dis_losses: + self.dis_losses[loss_name] = self.dis_losses[loss_name].detach() + + def _time_before_forward(self): + r""" + Record time before applying forward. + """ + if self.cfg.speed_benchmark: + torch.cuda.synchronize() + self.forw_time = time.time() + + def _time_before_loss(self): + r""" + Record time before computing loss. + """ + if self.cfg.speed_benchmark: + torch.cuda.synchronize() + self.loss_time = time.time() + + def _time_before_backward(self): + r""" + Record time before applying backward. + """ + if self.cfg.speed_benchmark: + torch.cuda.synchronize() + self.back_time = time.time() + + def _time_before_step(self): + r""" + Record time before updating the weights + """ + if self.cfg.speed_benchmark: + torch.cuda.synchronize() + self.step_time = time.time() + + def _time_before_model_avg(self): + r""" + Record time before applying model average. + """ + if self.cfg.speed_benchmark: + torch.cuda.synchronize() + self.avg_time = time.time() + + def _time_before_leave_gen(self): + r""" + Record forward, backward, loss, and model average time for the + generator update. 
+ """ + if self.cfg.speed_benchmark: + torch.cuda.synchronize() + end_time = time.time() + self.accu_gen_forw_iter_time += self.loss_time - self.forw_time + self.accu_gen_loss_iter_time += self.back_time - self.loss_time + self.accu_gen_back_iter_time += self.step_time - self.back_time + self.accu_gen_step_iter_time += self.avg_time - self.step_time + self.accu_gen_avg_iter_time += end_time - self.avg_time + + def _time_before_leave_dis(self): + r""" + Record forward, backward, loss time for the discriminator update. + """ + if self.cfg.speed_benchmark: + torch.cuda.synchronize() + end_time = time.time() + self.accu_dis_forw_iter_time += self.loss_time - self.forw_time + self.accu_dis_loss_iter_time += self.back_time - self.loss_time + self.accu_dis_back_iter_time += self.step_time - self.back_time + self.accu_dis_step_iter_time += end_time - self.step_time + + +@master_only +def _save_checkpoint(cfg, + net_G, net_D, + opt_G, opt_D, + sch_G, sch_D, + current_epoch, current_iteration): + r"""Save network weights, optimizer parameters, scheduler parameters + in the checkpoint. + + Args: + cfg (obj): Global configuration. + net_D (obj): Discriminator network. + opt_G (obj): Optimizer for the generator network. + opt_D (obj): Optimizer for the discriminator network. + sch_G (obj): Scheduler for the generator optimizer. + sch_D (obj): Scheduler for the discriminator optimizer. + current_epoch (int): Current epoch. + current_iteration (int): Current iteration. + """ + latest_checkpoint_path = 'epoch_{:05}_iteration_{:09}_checkpoint.pt'.format( + current_epoch, current_iteration) + save_path = os.path.join(cfg.logdir, latest_checkpoint_path) + torch.save( + { + 'net_G': net_G.state_dict(), + 'net_D': net_D.state_dict(), + 'opt_G': opt_G.state_dict(), + 'opt_D': opt_D.state_dict(), + 'sch_G': sch_G.state_dict(), + 'sch_D': sch_D.state_dict(), + 'current_epoch': current_epoch, + 'current_iteration': current_iteration, + }, + save_path, + ) + fn = os.path.join(cfg.logdir, 'latest_checkpoint.txt') + with open(fn, 'wt') as f: + f.write('latest_checkpoint: %s' % latest_checkpoint_path) + print('Save checkpoint to {}'.format(save_path)) + return save_path diff --git a/imaginaire/trainers/gancraft.py b/imaginaire/trainers/gancraft.py new file mode 100644 index 0000000000000000000000000000000000000000..7ebc4ca9017e03d18fa8d38faa2a7fd4dd26bd76 --- /dev/null +++ b/imaginaire/trainers/gancraft.py @@ -0,0 +1,325 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import collections +import os + +import torch +import torch.nn as nn + +from imaginaire.config import Config +from imaginaire.generators.spade import Generator as SPADEGenerator +from imaginaire.losses import (FeatureMatchingLoss, GaussianKLLoss, PerceptualLoss) +from imaginaire.model_utils.gancraft.loss import GANLoss +from imaginaire.trainers.base import BaseTrainer +from imaginaire.utils.distributed import master_only_print as print +from imaginaire.utils.io import get_checkpoint +from imaginaire.utils.misc import split_labels, to_device +from imaginaire.utils.trainer import ModelAverage, WrappedModel +from imaginaire.utils.visualization import tensor2label + + +class GauGANLoader(object): + r"""Manages the SPADE/GauGAN model used to generate pseudo-GTs for training GANcraft. + + Args: + gaugan_cfg (Config): SPADE configuration. 
+ """ + + def __init__(self, gaugan_cfg): + print('[GauGANLoader] Loading GauGAN model.') + cfg = Config(gaugan_cfg.config) + # checkpoint = get_checkpoint(default_checkpoint_path, cfg.pretrained_weight) + checkpoint = cfg.pretrained_weight + ckpt = torch.load(checkpoint, map_location='cpu') + + net_G = WrappedModel(ModelAverage(SPADEGenerator(cfg.gen, cfg.data).to('cuda'))) + net_G.load_state_dict(ckpt['net_G']) + self.net_GG = net_G.module.averaged_model + self.net_GG.eval() + self.net_GG.half() + print('[GauGANLoader] GauGAN loading complete.') + + def eval(self, label, z=None, style_img=None): + r"""Produce output given segmentation and other conditioning inputs. + random style will be used if neither z nor style_img is provided. + + Args: + label (N x C x H x W tensor): One-hot segmentation mask of shape. + z: Style vector. + style_img: Style image. + """ + inputs = {'label': label[:, :-1].detach().half()} + random_style = True + + if z is not None: + random_style = False + inputs['z'] = z.detach().half() + elif style_img is not None: + random_style = False + inputs['images'] = style_img.detach().half() + + net_GG_output = self.net_GG(inputs, random_style=random_style) + + return net_GG_output['fake_images'] + + +class Trainer(BaseTrainer): + r"""Initialize GANcraft trainer. + + Args: + cfg (Config): Global configuration. + net_G (obj): Generator network. + net_D (obj): Discriminator network. + opt_G (obj): Optimizer for the generator network. + opt_D (obj): Optimizer for the discriminator network. + sch_G (obj): Scheduler for the generator optimizer. + sch_D (obj): Scheduler for the discriminator optimizer. + train_data_loader (obj): Train data loader. + val_data_loader (obj): Validation data loader. + """ + + def __init__(self, + cfg, + net_G, + net_D, + opt_G, + opt_D, + sch_G, + sch_D, + train_data_loader, + val_data_loader): + super(Trainer, self).__init__(cfg, net_G, net_D, opt_G, + opt_D, sch_G, sch_D, + train_data_loader, val_data_loader) + + # Load the pseudo-GT network only if in training mode, else not needed. + if not self.is_inference: + self.gaugan_model = GauGANLoader(cfg.trainer.gaugan_loader) + + def _init_loss(self, cfg): + r"""Initialize loss terms. + + Args: + cfg (obj): Global configuration. + """ + if hasattr(cfg.trainer.loss_weight, 'gan'): + self.criteria['GAN'] = GANLoss() + self.weights['GAN'] = cfg.trainer.loss_weight.gan + if hasattr(cfg.trainer.loss_weight, 'pseudo_gan'): + self.criteria['PGAN'] = GANLoss() + self.weights['PGAN'] = cfg.trainer.loss_weight.pseudo_gan + if hasattr(cfg.trainer.loss_weight, 'l2'): + self.criteria['L2'] = nn.MSELoss() + self.weights['L2'] = cfg.trainer.loss_weight.l2 + if hasattr(cfg.trainer.loss_weight, 'l1'): + self.criteria['L1'] = nn.L1Loss() + self.weights['L1'] = cfg.trainer.loss_weight.l1 + if hasattr(cfg.trainer, 'perceptual_loss'): + self.criteria['Perceptual'] = \ + PerceptualLoss( + network=cfg.trainer.perceptual_loss.mode, + layers=cfg.trainer.perceptual_loss.layers, + weights=cfg.trainer.perceptual_loss.weights) + self.weights['Perceptual'] = cfg.trainer.loss_weight.perceptual + # Setup the feature matching loss. + if hasattr(cfg.trainer.loss_weight, 'feature_matching'): + self.criteria['FeatureMatching'] = FeatureMatchingLoss() + self.weights['FeatureMatching'] = \ + cfg.trainer.loss_weight.feature_matching + # Setup the Gaussian KL divergence loss. 
+ if hasattr(cfg.trainer.loss_weight, 'kl'): + self.criteria['GaussianKL'] = GaussianKLLoss() + self.weights['GaussianKL'] = cfg.trainer.loss_weight.kl + + def _start_of_epoch(self, current_epoch): + torch.cuda.empty_cache() # Prevent the first iteration from running OOM. + + def _start_of_iteration(self, data, current_iteration): + r"""Model specific custom start of iteration process. We will do two + things. First, put all the data to GPU. Second, we will resize the + input so that it becomes multiple of the factor for bug-free + convolutional operations. This factor is given by the yaml file. + E.g., base = getattr(self.net_G, 'base', 32) + + Args: + data (dict): The current batch. + current_iteration (int): The iteration number of the current batch. + """ + data = to_device(data, 'cuda') + + # Sample camera poses and pseudo-GTs. + with torch.no_grad(): + samples = self.net_G.module.sample_camera(data, self.gaugan_model.eval) + + return {**data, **samples} + + def gen_forward(self, data): + r"""Compute the loss for SPADE generator. + + Args: + data (dict): Training data at the current iteration. + """ + net_G_output = self.net_G(data, random_style=False) + + self._time_before_loss() + + if 'GAN' in self.criteria or 'PGAN' in self.criteria: + incl_pseudo_real = False + if 'FeatureMatching' in self.criteria: + incl_pseudo_real = True + net_D_output = self.net_D(data, net_G_output, incl_real=False, incl_pseudo_real=incl_pseudo_real) + output_fake = net_D_output['fake_outputs'] # Choose from real_outputs and fake_outputs. + + gan_loss = self.criteria['GAN'](output_fake, True, dis_update=False) + if 'GAN' in self.criteria: + self.gen_losses['GAN'] = gan_loss + if 'PGAN' in self.criteria: + self.gen_losses['PGAN'] = gan_loss + + if 'FeatureMatching' in self.criteria: + self.gen_losses['FeatureMatching'] = self.criteria['FeatureMatching']( + net_D_output['fake_features'], net_D_output['pseudo_real_features']) + + if 'GaussianKL' in self.criteria: + self.gen_losses['GaussianKL'] = self.criteria['GaussianKL'](net_G_output['mu'], net_G_output['logvar']) + + # Perceptual loss is always between fake image and pseudo real image. + if 'Perceptual' in self.criteria: + self.gen_losses['Perceptual'] = self.criteria['Perceptual']( + net_G_output['fake_images'], data['pseudo_real_img']) + + # Reconstruction loss between fake and pseudo real. + if 'L2' in self.criteria: + self.gen_losses['L2'] = self.criteria['L2'](net_G_output['fake_images'], data['pseudo_real_img']) + if 'L1' in self.criteria: + self.gen_losses['L1'] = self.criteria['L1'](net_G_output['fake_images'], data['pseudo_real_img']) + + total_loss = 0 + for key in self.criteria: + total_loss = total_loss + self.gen_losses[key] * self.weights[key] + + self.gen_losses['total'] = total_loss + return total_loss + + def dis_forward(self, data): + r"""Compute the loss for GANcraft discriminator. + + Args: + data (dict): Training data at the current iteration. 
+ """ + if 'GAN' not in self.criteria and 'PGAN' not in self.criteria: + return + + with torch.no_grad(): + net_G_output = self.net_G(data, random_style=False) + net_G_output['fake_images'] = net_G_output['fake_images'].detach() + + incl_real = False + incl_pseudo_real = False + if 'GAN' in self.criteria: + incl_real = True + if 'PGAN' in self.criteria: + incl_pseudo_real = True + net_D_output = self.net_D(data, net_G_output, incl_real=incl_real, incl_pseudo_real=incl_pseudo_real) + + self._time_before_loss() + total_loss = 0 + if 'GAN' in self.criteria: + output_fake = net_D_output['fake_outputs'] + output_real = net_D_output['real_outputs'] + + fake_loss = self.criteria['GAN'](output_fake, False, dis_update=True) + true_loss = self.criteria['GAN'](output_real, True, dis_update=True) + self.dis_losses['GAN/fake'] = fake_loss + self.dis_losses['GAN/true'] = true_loss + self.dis_losses['GAN'] = fake_loss + true_loss + total_loss = total_loss + self.dis_losses['GAN'] * self.weights['GAN'] + if 'PGAN' in self.criteria: + output_fake = net_D_output['fake_outputs'] + output_pseudo_real = net_D_output['pseudo_real_outputs'] + + fake_loss = self.criteria['PGAN'](output_fake, False, dis_update=True) + true_loss = self.criteria['PGAN'](output_pseudo_real, True, dis_update=True) + self.dis_losses['PGAN/fake'] = fake_loss + self.dis_losses['PGAN/true'] = true_loss + self.dis_losses['PGAN'] = fake_loss + true_loss + total_loss = total_loss + self.dis_losses['PGAN'] * self.weights['PGAN'] + + self.dis_losses['total'] = total_loss + return total_loss + + def _get_visualizations(self, data): + r"""Compute visualization image. + + Args: + data (dict): The current batch. + """ + with torch.no_grad(): + label_lengths = self.train_data_loader.dataset.get_label_lengths() + labels = split_labels(data['label'], label_lengths) + + # Get visualization of the real image and segmentation mask. + segmap = tensor2label(labels['seg_maps'], label_lengths['seg_maps'], output_normalized_tensor=True) + segmap = torch.cat([x.unsqueeze(0) for x in segmap], 0) + + # Get output from GANcraft model + net_G_output_randstyle = self.net_G(data, random_style=True) + net_G_output = self.net_G(data, random_style=False) + + vis_images = [data['images'], segmap, net_G_output_randstyle['fake_images'], net_G_output['fake_images']] + + if 'fake_masks' in data: + # Get pseudo-GT. + labels = split_labels(data['fake_masks'], label_lengths) + segmap = tensor2label(labels['seg_maps'], label_lengths['seg_maps'], output_normalized_tensor=True) + segmap = torch.cat([x.unsqueeze(0) for x in segmap], 0) + vis_images.append(segmap) + + if 'pseudo_real_img' in data: + vis_images.append(data['pseudo_real_img']) + + if self.cfg.trainer.model_average_config.enabled: + net_G_model_average_output = self.net_G.module.averaged_model(data, random_style=True) + vis_images.append(net_G_model_average_output['fake_images']) + return vis_images + + def load_checkpoint(self, cfg, checkpoint_path, resume=None, load_sch=True): + r"""Load network weights, optimizer parameters, scheduler parameters + from a checkpoint. + + Args: + cfg (obj): Global configuration. + checkpoint_path (str): Path to the checkpoint. + resume (bool or None): If not ``None``, will determine whether or + not to load optimizers in addition to network weights. 
+ """ + ret = super().load_checkpoint(cfg, checkpoint_path, resume, load_sch) + + if getattr(cfg.trainer, 'reset_opt_g_on_resume', False): + self.opt_G.state = collections.defaultdict(dict) + print('[GANcraft::load_checkpoint] Resetting opt_G.state') + if getattr(cfg.trainer, 'reset_opt_d_on_resume', False): + self.opt_D.state = collections.defaultdict(dict) + print('[GANcraft::load_checkpoint] Resetting opt_D.state') + + return ret + + def test(self, data_loader, output_dir, inference_args): + r"""Compute results images for a batch of input data and save the + results in the specified folder. + + Args: + data_loader (torch.utils.data.DataLoader): PyTorch dataloader. + output_dir (str): Target location for saving the output image. + """ + if self.cfg.trainer.model_average_config.enabled: + net_G = self.net_G.module.averaged_model + else: + net_G = self.net_G.module + net_G.eval() + + torch.cuda.empty_cache() + with torch.no_grad(): + net_G.inference(output_dir, **vars(inference_args)) diff --git a/imaginaire/utils/__init__.py b/imaginaire/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..13acefe2181136b1629ec31f9d122fb46bf26780 --- /dev/null +++ b/imaginaire/utils/__init__.py @@ -0,0 +1,4 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md diff --git a/imaginaire/utils/cudnn.py b/imaginaire/utils/cudnn.py new file mode 100644 index 0000000000000000000000000000000000000000..c7a5cc3b5607c56e997a6c38c184e4b3f4e302f8 --- /dev/null +++ b/imaginaire/utils/cudnn.py @@ -0,0 +1,22 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import torch.backends.cudnn as cudnn + +from imaginaire.utils.distributed import master_only_print as print + + +def init_cudnn(deterministic, benchmark): + r"""Initialize the cudnn module. The two things to consider is whether to + use cudnn benchmark and whether to use cudnn deterministic. If cudnn + benchmark is set, then the cudnn deterministic is automatically false. + + Args: + deterministic (bool): Whether to use cudnn deterministic. + benchmark (bool): Whether to use cudnn benchmark. + """ + cudnn.deterministic = deterministic + cudnn.benchmark = benchmark + print('cudnn benchmark: {}'.format(benchmark)) + print('cudnn deterministic: {}'.format(deterministic)) diff --git a/imaginaire/utils/data.py b/imaginaire/utils/data.py new file mode 100644 index 0000000000000000000000000000000000000000..22268c955ff7df4e5933e8ce2fb0b38c9c0e2f4a --- /dev/null +++ b/imaginaire/utils/data.py @@ -0,0 +1,612 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. 
+# To view a copy of this license, check out LICENSE.md +# flake8: noqa: E712 +"""Utils for handling datasets.""" + +import time +import numpy as np +from PIL import Image + +# https://github.com/albumentations-team/albumentations#comments +import cv2 +# from imaginaire.utils.distributed import master_only_print as print +import albumentations as alb # noqa nopep8 + +cv2.setNumThreads(0) +cv2.ocl.setUseOpenCL(False) + +IMG_EXTENSIONS = ('jpg', 'jpeg', 'png', 'ppm', 'bmp', + 'pgm', 'tif', 'tiff', 'webp', + 'JPG', 'JPEG', 'PNG', 'PPM', 'BMP', + 'PGM', 'TIF', 'TIFF', 'WEBP') +HDR_IMG_EXTENSIONS = ('hdr',) +VIDEO_EXTENSIONS = 'mp4' + + +class Augmentor(object): + r"""Handles data augmentation using albumentations library.""" + + def __init__(self, aug_list, individual_video_frame_aug_list, image_data_types, is_mask, + keypoint_data_types, interpolator): + r"""Initializes augmentation pipeline. + + Args: + aug_list (list): List of augmentation operations in sequence. + individual_video_frame_aug_list (list): List of augmentation operations in sequence that will be applied + to individual frames of videos independently. + image_data_types (list): List of keys in expected inputs. + is_mask (dict): Whether this data type is discrete masks? + keypoint_data_types (list): List of keys which are keypoints. + """ + + self.aug_list = aug_list + self.individual_video_frame_aug_list = individual_video_frame_aug_list + self.image_data_types = image_data_types + self.is_mask = is_mask + self.crop_h, self.crop_w = None, None + self.resize_h, self.resize_w = None, None + self.resize_smallest_side = None + self.max_time_step = 1 + self.keypoint_data_types = keypoint_data_types + self.interpolator = interpolator + + self.augment_ops = self._build_augmentation_ops() + self.individual_video_frame_augmentation_ops = self._build_individual_video_frame_augmentation_ops() + # Both crop and resize can't be none at the same time. + if self.crop_h is None and self.resize_smallest_side is None and \ + self.resize_h is None: + raise ValueError('resize_smallest_side, resize_h_w, ' + 'and crop_h_w cannot all be missing.') + # If resize_smallest_side is given, resize_h_w should not be give. + if self.resize_smallest_side is not None: + assert self.resize_h is None, \ + 'Cannot have both `resize_smallest_side` and `resize_h_w` set.' + if self.resize_smallest_side is None and self.resize_h is None: + self.resize_h, self.resize_w = self.crop_h, self.crop_w + + def _build_individual_video_frame_augmentation_ops(self): + r"""Builds sequence of augmentation ops that will be applied to each frame in the video independently. + Returns: + (list of alb.ops): List of augmentation ops. + """ + augs = [] + for key, value in self.individual_video_frame_aug_list.items(): + if key == 'random_scale_limit': + if type(value) == float: + scale_limit_lb = scale_limit_ub = value + p = 1 + else: + scale_limit_lb = value['scale_limit_lb'] + scale_limit_ub = value['scale_limit_ub'] + p = value['p'] + augs.append(alb.RandomScale(scale_limit=(-scale_limit_lb, scale_limit_ub), p=p)) + elif key == 'random_crop_h_w': + h, w = value.split(',') + h, w = int(h), int(w) + self.crop_h, self.crop_w = h, w + augs.append(alb.PadIfNeeded(min_height=h, min_width=w)) + augs.append(alb.RandomCrop(h, w, always_apply=True, p=1)) + return augs + + def _build_augmentation_ops(self): + r"""Builds sequence of augmentation ops. + Returns: + (list of alb.ops): List of augmentation ops. 
+ """ + augs = [] + for key, value in self.aug_list.items(): + if key == 'resize_smallest_side': + if isinstance(value, int): + self.resize_smallest_side = value + else: + h, w = value.split(',') + h, w = int(h), int(w) + self.resize_smallest_side = (h, w) + elif key == 'resize_h_w': + h, w = value.split(',') + h, w = int(h), int(w) + self.resize_h, self.resize_w = h, w + elif key == 'random_resize_h_w_aspect': + aspect_start, aspect_end = value.find('('), value.find(')') + aspect = value[aspect_start+1:aspect_end] + aspect_min, aspect_max = aspect.split(',') + h, w = value[:aspect_start].split(',')[:2] + h, w = int(h), int(w) + aspect_min, aspect_max = float(aspect_min), float(aspect_max) + augs.append(alb.RandomResizedCrop( + h, w, scale=(1, 1), + ratio=(aspect_min, aspect_max), always_apply=True, p=1)) + self.resize_h, self.resize_w = h, w + elif key == 'rotate': + augs.append(alb.Rotate( + limit=value, always_apply=True, p=1)) + elif key == 'random_rotate_90': + augs.append(alb.RandomRotate90(always_apply=False, p=0.5)) + elif key == 'random_scale_limit': + augs.append(alb.RandomScale(scale_limit=(0, value), p=1)) + elif key == 'random_crop_h_w': + h, w = value.split(',') + h, w = int(h), int(w) + self.crop_h, self.crop_w = h, w + augs.append(alb.RandomCrop(h, w, always_apply=True, p=1)) + elif key == 'center_crop_h_w': + h, w = value.split(',') + h, w = int(h), int(w) + self.crop_h, self.crop_w = h, w + augs.append(alb.CenterCrop(h, w, always_apply=True, p=1)) + elif key == 'horizontal_flip': + # This is handled separately as we need to keep track if this + # was applied in order to correctly modify keypoint data. + if value: + augs.append(alb.HorizontalFlip(always_apply=False, p=0.5)) + # The options below including contrast, blur, motion_blur, compression, gamma + # were used during developing face-vid2vid. + elif key == 'contrast': + brightness_limit = value['brightness_limit'] + contrast_limit = value['contrast_limit'] + p = value['p'] + augs.append(alb.RandomBrightnessContrast( + brightness_limit=brightness_limit, contrast_limit=contrast_limit, p=p)) + elif key == 'blur': + blur_limit = value['blur_limit'] + p = value['p'] + augs.append(alb.Blur(blur_limit=blur_limit, p=p)) + elif key == 'motion_blur': + blur_limit = value['blur_limit'] + p = value['p'] + augs.append(alb.MotionBlur(blur_limit=blur_limit, p=p)) + elif key == 'compression': + quality_lower = value['quality_lower'] + p = value['p'] + augs.append(alb.ImageCompression(quality_lower=quality_lower, p=p)) + elif key == 'gamma': + gamma_limit_lb = value['gamma_limit_lb'] + gamma_limit_ub = value['gamma_limit_ub'] + p = value['p'] + augs.append(alb.RandomGamma(gamma_limit=(gamma_limit_lb, gamma_limit_ub), p=p)) + elif key == 'max_time_step': + self.max_time_step = value + assert self.max_time_step >= 1, \ + 'max_time_step has to be at least 1' + else: + raise ValueError('Unknown augmentation %s' % (key)) + return augs + + def _choose_image_key(self, inputs): + r"""Choose key to replace with 'image' for input to albumentations. + + Returns: + key (str): Chosen key to be replace with 'image' + """ + if 'image' in inputs: + return 'image' + for data_type in inputs: + if data_type in self.image_data_types: + return data_type + + def _choose_keypoint_key(self, inputs): + r"""Choose key to replace with 'keypoints' for input to albumentations. 
+ Returns: + key (str): Chosen key to be replace with 'keypoints' + """ + if not self.keypoint_data_types: + return None + if 'keypoints' in inputs: + return 'keypoints' + for data_type in inputs: + if data_type in self.keypoint_data_types: + return data_type + + def _create_augmentation_targets(self, inputs): + r"""Create additional targets as required by the albumentation library. + + Args: + inputs (dict): Keys are from self.augmentable_data_types. Values can + be numpy.ndarray or list of numpy.ndarray + (image or list of images). + Returns: + (dict): + - targets (dict): Dict containing mapping of keys to image/mask types. + - new_inputs (dict): Dict containing mapping of keys to data. + """ + # Get additional target list. + targets, new_inputs = {}, {} + for data_type in inputs: + if data_type in self.keypoint_data_types: + # Keypoint-type. + target_type = 'keypoints' + elif data_type in self.image_data_types: + # Image-type. + # Find the target type (image/mask) based on interpolation + # method. + if self.is_mask[data_type]: + target_type = 'mask' + else: + target_type = 'image' + else: + raise ValueError( + 'Data type: %s is not image or keypoint' % (data_type)) + + current_data_type_inputs = inputs[data_type] + if not isinstance(current_data_type_inputs, list): + current_data_type_inputs = [current_data_type_inputs] + + # Create additional_targets and inputs when there are multiples. + for idx, new_input in enumerate(current_data_type_inputs): + key = data_type + if idx > 0: + key = '%s::%05d' % (key, idx) + targets[key] = target_type + new_inputs[key] = new_input + + return targets, new_inputs + + def _collate_augmented(self, augmented): + r"""Collate separated images back into sequence, grouped by keys. + + Args: + augmented (dict): Dict containing frames with keys of the form + 'key', 'key::00001', 'key::00002', ..., 'key::N'. + Returns: + (dict): + - outputs (dict): Dict with list of collated inputs, i.e. frames of + - same key are arranged in order ['key', 'key::00001', ..., 'key::N']. + """ + full_keys = sorted(augmented.keys()) + outputs = {} + for full_key in full_keys: + if '::' not in full_key: + # First occurrence of this key. + key = full_key + outputs[key] = [] + else: + key = full_key.split('::')[0] + outputs[key].append(augmented[full_key]) + return outputs + + def _get_resize_h_w(self, height, width): + r"""Get height and width to resize to, given smallest side. + + Args: + height (int): Input image height. + width (int): Input image width. + Returns: + (dict): + - height (int): Height to resize image to. + - width (int): Width to resize image to. + """ + if self.resize_smallest_side is None: + return self.resize_h, self.resize_w + + if isinstance(self.resize_smallest_side, int): + resize_smallest_height, resize_smallest_width = self.resize_smallest_side, self.resize_smallest_side + else: + resize_smallest_height, resize_smallest_width = self.resize_smallest_side + + if height * resize_smallest_width <= width * resize_smallest_height: + new_height = resize_smallest_height + new_width = int(np.round(new_height * width / float(height))) + else: + new_width = resize_smallest_width + new_height = int(np.round(new_width * height / float(width))) + return new_height, new_width + + def _perform_unpaired_augmentation(self, inputs, augment_ops): + r"""Perform different data augmentation on different image inputs. Note that this operation only works + + Args: + inputs (dict): Keys are from self.image_data_types. Values are list + of numpy.ndarray (list of images). 
+ augment_ops (list): The augmentation operations. + Returns: + (dict): + - augmented (dict): Augmented inputs, with same keys as inputs. + - is_flipped (dict): Flag which tells if images have been LR flipped. + """ + # Process each data type separately as this is unpaired augmentation. + is_flipped = {} + for data_type in inputs: + assert data_type in self.image_data_types + augmented, flipped_flag = self._perform_paired_augmentation( + {data_type: inputs[data_type]}, augment_ops) + inputs[data_type] = augmented[data_type] + is_flipped[data_type] = flipped_flag + return inputs, is_flipped + + def _perform_paired_augmentation(self, inputs, augment_ops): + r"""Perform same data augmentation on all inputs. + + Args: + inputs (dict): Keys are from self.augmentable_data_types. Values are + list of numpy.ndarray (list of images). + augment_ops (list): The augmentation operations. + + Returns: + (dict): + - augmented (dict): Augmented inputs, with same keys as inputs. + - is_flipped (bool): Flag which tells if images have been LR flipped. + """ + # Different data types may have different sizes and we use the largest one as the original size. + # Convert PIL images to numpy array. + self.original_h, self.original_w = 0, 0 + for data_type in inputs: + if data_type in self.keypoint_data_types or \ + data_type not in self.image_data_types: + continue + for idx in range(len(inputs[data_type])): + value = inputs[data_type][idx] + # Get resize h, w. + w, h = get_image_size(value) + self.original_h, self.original_w = max(self.original_h, h), max(self.original_w, w) + # self.original_h, self.original_w = h, w + # self.resize_h, self.resize_w = self._get_resize_h_w(h, w) + # Convert to numpy array with 3 dims (H, W, C). + value = np.array(value) + if value.ndim == 2: + value = value[..., np.newaxis] + inputs[data_type][idx] = value + self.resize_h, self.resize_w = self._get_resize_h_w(self.original_h, self.original_w) + + # Add resize op to augmentation ops. + aug_ops_with_resize = [alb.Resize( + self.resize_h, self.resize_w, interpolation=getattr(cv2, self.interpolator), always_apply=1, p=1 + )] + augment_ops + + # Create targets. + targets, new_inputs = self._create_augmentation_targets(inputs) + extra_params = {} + + # Albumentation requires a key called 'image' and + # a key called 'keypoints', if any keypoints are being passed in. + # Arbitrarily choose one key of image type to be 'image'. + chosen_image_key = self._choose_image_key(inputs) + new_inputs['image'] = new_inputs.pop(chosen_image_key) + targets['image'] = targets.pop(chosen_image_key) + # Arbitrarily choose one key of keypoint type to be 'keypoints'. + chosen_keypoint_key = self._choose_keypoint_key(inputs) + if chosen_keypoint_key is not None: + new_inputs['keypoints'] = new_inputs.pop(chosen_keypoint_key) + targets['keypoints'] = targets.pop(chosen_keypoint_key) + extra_params['keypoint_params'] = alb.KeypointParams( + format='xy', remove_invisible=False) + + # Do augmentation. + augmented = alb.ReplayCompose( + aug_ops_with_resize, additional_targets=targets, + **extra_params)(**new_inputs) + augmentation_params = augmented.pop('replay') + + # Check if flipping has occurred. + is_flipped = False + for augmentation_param in augmentation_params['transforms']: + if 'HorizontalFlip' in augmentation_param['__class_fullname__']: + is_flipped = augmentation_param['applied'] + self.is_flipped = is_flipped + + # Replace the key 'image' with chosen_image_key, same for 'keypoints'. 
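+ # The pop/re-insert below reverses the aliasing done earlier: albumentations
+ # requires exactly one target named 'image' (and 'keypoints' when keypoint
+ # data is present), so the arbitrarily chosen keys are mapped back to their
+ # original names after augmentation.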
+ augmented[chosen_image_key] = augmented.pop('image') + if chosen_keypoint_key is not None: + augmented[chosen_keypoint_key] = augmented.pop('keypoints') + + # Pack images back into a sequence. + augmented = self._collate_augmented(augmented) + + # Convert keypoint types to np.array from list. + for data_type in self.keypoint_data_types: + augmented[data_type] = np.array(augmented[data_type]) + + return augmented, is_flipped + + def perform_augmentation(self, inputs, paired, augment_ops): + r"""Entry point for augmentation. + + Args: + inputs (dict): Keys are from self.augmentable_data_types. Values are + list of numpy.ndarray (list of images). + paired (bool): Apply same augmentation to all input keys? + augment_ops (list): The augmentation operations. + """ + # Make sure that all inputs are of same size, else trouble will + # ensue. This is because different images might have different + # aspect ratios. + # Check within data type. + for data_type in inputs: + if data_type in self.keypoint_data_types or \ + data_type not in self.image_data_types: + continue + for idx in range(len(inputs[data_type])): + if idx == 0: + w, h = get_image_size(inputs[data_type][idx]) + else: + this_w, this_h = get_image_size(inputs[data_type][idx]) + # assert this_w == w and this_h == h + # assert this_w / (1.0 * this_h) == w / (1.0 * h) + # Check across data types. + if paired and self.resize_smallest_side is not None: + for idx, data_type in enumerate(inputs): + if data_type in self.keypoint_data_types or \ + data_type not in self.image_data_types: + continue + if paired: + return self._perform_paired_augmentation(inputs, augment_ops) + else: + return self._perform_unpaired_augmentation(inputs, augment_ops) + + +def load_from_lmdb(keys, lmdbs): + r"""Load keys from lmdb handles. + + Args: + keys (dict): This has data_type as key, and a list of paths into LMDB as + values. + lmdbs (dict): This has data_type as key, and LMDB handle as value. + Returns: + data (dict): This has data_type as key, and a list of decoded items from + LMDBs as value. + """ + data = {} + for data_type in keys: + if data_type not in data: + data[data_type] = [] + data_type_keys = keys[data_type] + if not isinstance(data_type_keys, list): + data_type_keys = [data_type_keys] + for key in data_type_keys: + data[data_type].append(lmdbs[data_type].getitem_by_path( + key.encode(), data_type)) + return data + + +def load_from_folder(keys, handles): + r"""Load keys from lmdb handles. + + Args: + keys (dict): This has data_type as key, and a list of paths as + values. + handles (dict): This has data_type as key, and Folder handle as value. + Returns: + data (dict): This has data_type as key, and a list of decoded items from + folders as value. + """ + data = {} + for data_type in keys: + if data_type not in data: + data[data_type] = [] + data_type_keys = keys[data_type] + if not isinstance(data_type_keys, list): + data_type_keys = [data_type_keys] + for key in data_type_keys: + data[data_type].append(handles[data_type].getitem_by_path( + key.encode(), data_type)) + return data + + +def load_from_object_store(keys, handles): + r"""Load keys from AWS S3 handles. + + Args: + keys (dict): This has data_type as key, and a list of paths as + values. + handles (dict): This has data_type as key, and Folder handle as value. + Returns: + data (dict): This has data_type as key, and a list of decoded items from + folders as value. 
+ """ + data = {} + for data_type in keys: + if data_type not in data: + data[data_type] = [] + data_type_keys = keys[data_type] + if not isinstance(data_type_keys, list): + data_type_keys = [data_type_keys] + for key in data_type_keys: + while True: + try: + data[data_type].append(handles[data_type].getitem_by_path(key, data_type)) + except Exception as e: + print(e) + print(key, data_type) + print('Retrying in 30 seconds') + time.sleep(30) + continue + break + return data + + +def get_paired_input_image_channel_number(data_cfg): + r"""Get number of channels for the input image. + + Args: + data_cfg (obj): Data configuration structure. + Returns: + num_channels (int): Number of input image channels. + """ + num_channels = 0 + for ix, data_type in enumerate(data_cfg.input_types): + for k in data_type: + if k in data_cfg.input_image: + num_channels += data_type[k].num_channels + print('Concatenate %s for input.' % data_type) + print('\tNum. of channels in the input image: %d' % num_channels) + return num_channels + + +def get_paired_input_label_channel_number(data_cfg, video=False): + r"""Get number of channels for the input label map. + + Args: + data_cfg (obj): Data configuration structure. + video (bool): Whether we are dealing with video data. + Returns: + num_channels (int): Number of input label map channels. + """ + num_labels = 0 + if not hasattr(data_cfg, 'input_labels'): + return num_labels + for ix, data_type in enumerate(data_cfg.input_types): + for k in data_type: + if k in data_cfg.input_labels: + if hasattr(data_cfg, 'one_hot_num_classes') and k in data_cfg.one_hot_num_classes: + num_labels += data_cfg.one_hot_num_classes[k] + if getattr(data_cfg, 'use_dont_care', False): + num_labels += 1 + else: + num_labels += data_type[k].num_channels + print('Concatenate %s for input.' % data_type) + + if video: + num_time_steps = getattr(data_cfg.train, 'initial_sequence_length', + None) + num_labels *= num_time_steps + num_labels += get_paired_input_image_channel_number(data_cfg) * ( + num_time_steps - 1) + + print('\tNum. of channels in the input label: %d' % num_labels) + return num_labels + + +def get_class_number(data_cfg): + r"""Get number of classes for class-conditional GAN model + + Args: + data_cfg (obj): Data configuration structure. + + Returns: + (int): Number of classes. + """ + return data_cfg.num_classes + + +def get_crop_h_w(augmentation): + r"""Get height and width of crop. + + Args: + augmentation (dict): Dict of applied augmentations. + + Returns: + (dict): + - crop_h (int): Height of the image crop. + - crop_w (int): Width of the image crop. + """ + print(augmentation.__dict__.keys()) + for k in augmentation.__dict__.keys(): + if 'crop_h_w' in k: + filed = augmentation[k] + crop_h, crop_w = filed.split(',') + crop_h = int(crop_h) + crop_w = int(crop_w) + # assert crop_w == crop_h, 'This implementation only ' \ + # 'supports square-shaped images.' + print('\tCrop size: (%d, %d)' % (crop_h, crop_w)) + return crop_h, crop_w + raise AttributeError + + +def get_image_size(x): + try: + w, h = x.size + except Exception: + h, w, _ = x.shape + return w, h diff --git a/imaginaire/utils/dataset.py b/imaginaire/utils/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..605bb2e20d2a22a70e38254bea46bd73177c8c5a --- /dev/null +++ b/imaginaire/utils/dataset.py @@ -0,0 +1,120 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. 
+# To view a copy of this license, check out LICENSE.md +import importlib + +import torch +import torch.distributed as dist + +from imaginaire.utils.distributed import master_only_print as print + + +def _get_train_and_val_dataset_objects(cfg): + r"""Return dataset objects for the training and validation sets. + + Args: + cfg (obj): Global configuration file. + + Returns: + (dict): + - train_dataset (obj): PyTorch training dataset object. + - val_dataset (obj): PyTorch validation dataset object. + """ + dataset_module = importlib.import_module(cfg.data.type) + train_dataset = dataset_module.Dataset(cfg, is_inference=False) + if hasattr(cfg.data.val, 'type'): + for key in ['type', 'input_types', 'input_image']: + setattr(cfg.data, key, getattr(cfg.data.val, key)) + dataset_module = importlib.import_module(cfg.data.type) + val_dataset = dataset_module.Dataset(cfg, is_inference=True) + print('Train dataset length:', len(train_dataset)) + print('Val dataset length:', len(val_dataset)) + return train_dataset, val_dataset + + +def _get_data_loader(cfg, dataset, batch_size, not_distributed=False, + shuffle=True, drop_last=True, seed=0): + r"""Return data loader . + + Args: + cfg (obj): Global configuration file. + dataset (obj): PyTorch dataset object. + batch_size (int): Batch size. + not_distributed (bool): Do not use distributed samplers. + + Return: + (obj): Data loader. + """ + not_distributed = not_distributed or not dist.is_initialized() + if not_distributed: + sampler = None + else: + sampler = torch.utils.data.distributed.DistributedSampler(dataset, seed=seed) + num_workers = getattr(cfg.data, 'num_workers', 8) + persistent_workers = getattr(cfg.data, 'persistent_workers', False) + data_loader = torch.utils.data.DataLoader( + dataset, + batch_size=batch_size, + shuffle=shuffle and (sampler is None), + sampler=sampler, + pin_memory=True, + num_workers=num_workers, + drop_last=drop_last, + persistent_workers=persistent_workers if num_workers > 0 else False + ) + return data_loader + + +def get_train_and_val_dataloader(cfg, seed=0): + r"""Return dataset objects for the training and validation sets. + + Args: + cfg (obj): Global configuration file. + + Returns: + (dict): + - train_data_loader (obj): Train data loader. + - val_data_loader (obj): Val data loader. + """ + train_dataset, val_dataset = _get_train_and_val_dataset_objects(cfg) + train_data_loader = _get_data_loader(cfg, train_dataset, cfg.data.train.batch_size, drop_last=True, seed=seed) + not_distributed = getattr(cfg.data, 'val_data_loader_not_distributed', False) + not_distributed = 'video' in cfg.data.type or not_distributed + val_data_loader = _get_data_loader( + cfg, val_dataset, cfg.data.val.batch_size, not_distributed, + shuffle=False, drop_last=getattr(cfg.data.val, 'drop_last', False), seed=seed) + return train_data_loader, val_data_loader + + +def _get_test_dataset_object(cfg): + r"""Return dataset object for the test set + + Args: + cfg (obj): Global configuration file. + + Returns: + (obj): PyTorch dataset object. + """ + dataset_module = importlib.import_module(cfg.test_data.type) + test_dataset = dataset_module.Dataset(cfg, is_inference=True, is_test=True) + return test_dataset + + +def get_test_dataloader(cfg): + r"""Return dataset objects for testing + + Args: + cfg (obj): Global configuration file. + + Returns: + (obj): Val data loader. It may not contain the ground truth. 
+ """ + test_dataset = _get_test_dataset_object(cfg) + not_distributed = getattr( + cfg.test_data, 'val_data_loader_not_distributed', False) + not_distributed = 'video' in cfg.test_data.type or not_distributed + test_data_loader = _get_data_loader( + cfg, test_dataset, cfg.test_data.test.batch_size, not_distributed, + shuffle=False) + return test_data_loader diff --git a/imaginaire/utils/diff_aug.py b/imaginaire/utils/diff_aug.py new file mode 100644 index 0000000000000000000000000000000000000000..afd481cbbcfab8446b445001d5b0776c56b75b54 --- /dev/null +++ b/imaginaire/utils/diff_aug.py @@ -0,0 +1,142 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md + +# Differentiable Augmentation for Data-Efficient GAN Training +# Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han +# https://arxiv.org/pdf/2006.10738 +# Modified from https://github.com/mit-han-lab/data-efficient-gans +import torch +import torch.nn.functional as F + + +def apply_diff_aug(data, keys, aug_policy, inplace=False, **kwargs): + r"""Applies differentiable augmentation. + Args: + data (dict): Input data. + keys (list of str): Keys to the data values that we want to apply + differentiable augmentation to. + aug_policy (str): Type of augmentation(s), ``'color'``, + ``'translation'``, or ``'cutout'`` separated by ``','``. + """ + if aug_policy == '': + return data + data_aug = data if inplace else {} + for key, value in data.items(): + if key in keys: + data_aug[key] = diff_aug(data[key], aug_policy, **kwargs) + else: + data_aug[key] = data[key] + return data_aug + + +def diff_aug(x, policy='', channels_first=True, **kwargs): + if policy: + if not channels_first: + x = x.permute(0, 3, 1, 2) + for p in policy.split(','): + for f in AUGMENT_FNS[p]: + x = f(x, **kwargs) + if not channels_first: + x = x.permute(0, 2, 3, 1) + x = x.contiguous() + return x + + +def rand_brightness(x, **kwargs): + x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, + device=x.device) - 0.5) + return x + + +def rand_saturation(x, **kwargs): + x_mean = x.mean(dim=1, keepdim=True) + x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, + device=x.device) * 2) + x_mean + return x + + +def rand_contrast(x, **kwargs): + x_mean = x.mean(dim=[1, 2, 3], keepdim=True) + x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, + device=x.device) + 0.5) + x_mean + return x + + +def rand_translation(x, ratio=0.125, **kwargs): + shift_x, shift_y = int(x.size(2) * ratio + 0.5), int( + x.size(3) * ratio + 0.5) + translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], + device=x.device) + translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], + device=x.device) + # noinspection PyTypeChecker + grid_batch, grid_x, grid_y = torch.meshgrid( + torch.arange(x.size(0), dtype=torch.long, device=x.device), + torch.arange(x.size(2), dtype=torch.long, device=x.device), + torch.arange(x.size(3), dtype=torch.long, device=x.device), + ) + grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1) + grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1) + x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0]) + x = x_pad.permute(0, 2, 3, 1).contiguous()[ + grid_batch, grid_x, grid_y].permute(0, 3, 1, 2).contiguous() + return x + + +def rand_cutout(x, ratio=0.5, **kwargs): + cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) + 
offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), + size=[x.size(0), 1, 1], device=x.device) + offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), + size=[x.size(0), 1, 1], device=x.device) + # noinspection PyTypeChecker + grid_batch, grid_x, grid_y = torch.meshgrid( + torch.arange(x.size(0), dtype=torch.long, device=x.device), + torch.arange(cutout_size[0], dtype=torch.long, device=x.device), + torch.arange(cutout_size[1], dtype=torch.long, device=x.device), + ) + grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, + max=x.size(2) - 1) + grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, + max=x.size(3) - 1) + mask = torch.ones(x.size(0), x.size(2), x.size(3), + dtype=x.dtype, device=x.device) + mask[grid_batch, grid_x, grid_y] = 0 + x = x * mask.unsqueeze(1) + return x + + +def rand_translation_scale(x, trans_r=0.125, scale_r=0.125, + mode='bilinear', padding_mode='reflection', + **kwargs): + assert x.dim() == 4, "Input must be a 4D tensor." + batch_size = x.size(0) + + # Identity transformation. + theta = torch.eye(2, 3, device=x.device).unsqueeze(0).repeat( + batch_size, 1, 1) + + # Translation, uniformly sampled from (-trans_r, trans_r). + translate = \ + 2 * trans_r * torch.rand(batch_size, 2, device=x.device) - trans_r + theta[:, :, 2] += translate + + # Scaling, uniformly sampled from (1-scale_r, 1+scale_r). + scale = \ + 2 * scale_r * torch.rand(batch_size, 2, device=x.device) - scale_r + theta[:, :, :2] += torch.diag_embed(scale) + + grid = F.affine_grid(theta, x.size()) + x = F.grid_sample( + x.float(), grid.float(), mode=mode, padding_mode=padding_mode) + return x + + +AUGMENT_FNS = { + 'color': [rand_brightness, rand_saturation, rand_contrast], + 'translation': [rand_translation], + 'translation_scale': [rand_translation_scale], + 'cutout': [rand_cutout], +} diff --git a/imaginaire/utils/distributed.py b/imaginaire/utils/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..d7ec9d1099684e58a80a61107fe828e292352002 --- /dev/null +++ b/imaginaire/utils/distributed.py @@ -0,0 +1,117 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import functools +import ctypes + +import torch +import torch.distributed as dist + + +def init_dist(local_rank, backend='nccl', **kwargs): + r"""Initialize distributed training""" + if dist.is_available(): + if dist.is_initialized(): + return torch.cuda.current_device() + torch.cuda.set_device(local_rank) + dist.init_process_group(backend=backend, init_method='env://', **kwargs) + + # Increase the L2 fetch granularity for faster speed. + _libcudart = ctypes.CDLL('libcudart.so') + # Set device limit on the current device + # cudaLimitMaxL2FetchGranularity = 0x05 + pValue = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int)) + _libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128)) + _libcudart.cudaDeviceGetLimit(pValue, ctypes.c_int(0x05)) + # assert pValue.contents.value == 128 + + +def get_rank(): + r"""Get rank of the thread.""" + rank = 0 + if dist.is_available(): + if dist.is_initialized(): + rank = dist.get_rank() + return rank + + +def get_world_size(): + r"""Get world size. 
How many GPUs are available in this job.""" + world_size = 1 + if dist.is_available(): + if dist.is_initialized(): + world_size = dist.get_world_size() + return world_size + + +def master_only(func): + r"""Apply this function only to the master GPU.""" + @functools.wraps(func) + def wrapper(*args, **kwargs): + r"""Simple function wrapper for the master function""" + if get_rank() == 0: + return func(*args, **kwargs) + else: + return None + return wrapper + + +def is_master(): + r"""check if current process is the master""" + return get_rank() == 0 + + +def is_local_master(): + return torch.cuda.current_device() == 0 + + +@master_only +def master_only_print(*args): + r"""master-only print""" + print(*args) + + +def dist_reduce_tensor(tensor, rank=0, reduce='mean'): + r""" Reduce to rank 0 """ + world_size = get_world_size() + if world_size < 2: + return tensor + with torch.no_grad(): + dist.reduce(tensor, dst=rank) + if get_rank() == rank: + if reduce == 'mean': + tensor /= world_size + elif reduce == 'sum': + pass + else: + raise NotImplementedError + return tensor + + +def dist_all_reduce_tensor(tensor, reduce='mean'): + r""" Reduce to all ranks """ + world_size = get_world_size() + if world_size < 2: + return tensor + with torch.no_grad(): + dist.all_reduce(tensor) + if reduce == 'mean': + tensor /= world_size + elif reduce == 'sum': + pass + else: + raise NotImplementedError + return tensor + + +def dist_all_gather_tensor(tensor): + r""" gather to all ranks """ + world_size = get_world_size() + if world_size < 2: + return [tensor] + tensor_list = [ + torch.ones_like(tensor) for _ in range(dist.get_world_size())] + with torch.no_grad(): + dist.all_gather(tensor_list, tensor) + return tensor_list diff --git a/imaginaire/utils/gpu_affinity.py b/imaginaire/utils/gpu_affinity.py new file mode 100644 index 0000000000000000000000000000000000000000..3f4e9cb40a5a5f9185e903af55694b5952cfe0ff --- /dev/null +++ b/imaginaire/utils/gpu_affinity.py @@ -0,0 +1,61 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import math +import os +import pynvml + +pynvml.nvmlInit() + + +def systemGetDriverVersion(): + r"""Get Driver Version""" + return pynvml.nvmlSystemGetDriverVersion() + + +def deviceGetCount(): + r"""Get number of devices""" + return pynvml.nvmlDeviceGetCount() + + +class device(object): + r"""Device used for nvml.""" + _nvml_affinity_elements = math.ceil(os.cpu_count() / 64) + + def __init__(self, device_idx): + super().__init__() + self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx) + + def getName(self): + r"""Get obect name""" + return pynvml.nvmlDeviceGetName(self.handle) + + def getCpuAffinity(self): + r"""Get CPU affinity""" + affinity_string = '' + for j in pynvml.nvmlDeviceGetCpuAffinity( + self.handle, device._nvml_affinity_elements): + # assume nvml returns list of 64 bit ints + affinity_string = '{:064b}'.format(j) + affinity_string + affinity_list = [int(x) for x in affinity_string] + affinity_list.reverse() # so core 0 is in 0th element of list + + return [i for i, e in enumerate(affinity_list) if e != 0] + + +def set_affinity(gpu_id=None): + r"""Set GPU affinity + + Args: + gpu_id (int): Which gpu device. 
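+ Returns:
+ (set of int): Logical CPU cores this process is now affinitized with.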
+ """ + if gpu_id is None: + gpu_id = int(os.getenv('LOCAL_RANK', 0)) + + dev = device(gpu_id) + os.sched_setaffinity(0, dev.getCpuAffinity()) + + # list of ints + # representing the logical cores this process is now affinitied with + return os.sched_getaffinity(0) diff --git a/imaginaire/utils/init_weight.py b/imaginaire/utils/init_weight.py new file mode 100644 index 0000000000000000000000000000000000000000..80d826c27d7fe1ab75bfe565b40531acd02abd2b --- /dev/null +++ b/imaginaire/utils/init_weight.py @@ -0,0 +1,84 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import torch +from torch.nn import init + + +def weights_init(init_type='normal', gain=0.02, bias=None): + r"""Initialize weights in the network. + + Args: + init_type (str): The name of the initialization scheme. + gain (float): The parameter that is required for the initialization + scheme. + bias (object): If not ``None``, specifies the initialization parameter + for bias. + + Returns: + (obj): init function to be applied. + """ + + def init_func(m): + r"""Init function + + Args: + m: module to be weight initialized. + """ + class_name = m.__class__.__name__ + if hasattr(m, 'weight') and ( + class_name.find('Conv') != -1 or + class_name.find('Linear') != -1 or + class_name.find('Embedding') != -1): + lr_mul = getattr(m, 'lr_mul', 1.) + gain_final = gain / lr_mul + if init_type == 'normal': + init.normal_(m.weight.data, 0.0, gain_final) + elif init_type == 'xavier': + init.xavier_normal_(m.weight.data, gain=gain_final) + elif init_type == 'xavier_uniform': + init.xavier_uniform_(m.weight.data, gain=gain_final) + elif init_type == 'kaiming': + init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + with torch.no_grad(): + m.weight.data *= gain_final + elif init_type == 'kaiming_linear': + init.kaiming_normal_( + m.weight.data, a=0, mode='fan_in', nonlinearity='linear' + ) + with torch.no_grad(): + m.weight.data *= gain_final + elif init_type == 'orthogonal': + init.orthogonal_(m.weight.data, gain=gain_final) + elif init_type == 'none': + pass + # m.reset_parameters() + else: + raise NotImplementedError( + 'initialization method [%s] is ' + 'not implemented' % init_type) + if hasattr(m, 'bias') and m.bias is not None: + if init_type == 'none': + pass + elif bias is not None: + bias_type = getattr(bias, 'type', 'normal') + if bias_type == 'normal': + bias_gain = getattr(bias, 'gain', 0.5) + init.normal_(m.bias.data, 0.0, bias_gain) + else: + raise NotImplementedError( + 'initialization method [%s] is ' + 'not implemented' % bias_type) + else: + init.constant_(m.bias.data, 0.0) + return init_func + + +def weights_rescale(): + def init_func(m): + if hasattr(m, 'init_gain'): + for name, p in m.named_parameters(): + if 'output_scale' not in name: + p.data.mul_(m.init_gain) + return init_func diff --git a/imaginaire/utils/io.py b/imaginaire/utils/io.py new file mode 100644 index 0000000000000000000000000000000000000000..3744b1ffe8c5c832f273241701cb1c9b2b2f139e --- /dev/null +++ b/imaginaire/utils/io.py @@ -0,0 +1,138 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. 
+# To view a copy of this license, check out LICENSE.md +import os + +import requests +import torch.distributed as dist +import torchvision.utils + +from imaginaire.utils.distributed import is_master + + +def save_pilimage_in_jpeg(fullname, output_img): + r"""Save PIL Image to JPEG. + + Args: + fullname (str): Full save path. + output_img (PIL Image): Image to be saved. + """ + dirname = os.path.dirname(fullname) + os.makedirs(dirname, exist_ok=True) + output_img.save(fullname, 'JPEG', quality=99) + + +def save_intermediate_training_results( + visualization_images, logdir, current_epoch, current_iteration): + r"""Save intermediate training results for debugging purpose. + + Args: + visualization_images (tensor): Image where pixel values are in [-1, 1]. + logdir (str): Where to save the image. + current_epoch (int): Current training epoch. + current_iteration (int): Current training iteration. + """ + visualization_images = (visualization_images + 1) / 2 + output_filename = os.path.join( + logdir, 'images', + 'epoch_{:05}iteration{:09}.jpg'.format( + current_epoch, current_iteration)) + print('Save output images to {}'.format(output_filename)) + os.makedirs(os.path.dirname(output_filename), exist_ok=True) + image_grid = torchvision.utils.make_grid( + visualization_images.data, nrow=1, padding=0, normalize=False) + torchvision.utils.save_image(image_grid, output_filename, nrow=1) + + +def download_file_from_google_drive(URL, destination): + r"""Download a file from google drive. + + Args: + URL: GDrive file ID. + destination: Path to save the file. + + Returns: + + """ + download_file(f"https://docs.google.com/uc?export=download&id={URL}", destination) + + +def download_file(URL, destination): + r"""Download a file from google drive or pbss by using the url. + + Args: + URL: GDrive URL or PBSS pre-signed URL for the checkpoint. + destination: Path to save the file. + + Returns: + + """ + session = requests.Session() + response = session.get(URL, stream=True) + token = get_confirm_token(response) + if token: + params = {'confirm': token} + response = session.get(URL, params=params, stream=True) + save_response_content(response, destination) + + +def get_confirm_token(response): + r"""Get confirm token + + Args: + response: Check if the file exists. + + Returns: + + """ + for key, value in response.cookies.items(): + if key.startswith('download_warning'): + return value + return None + + +def save_response_content(response, destination): + r"""Save response content + + Args: + response: + destination: Path to save the file. + + Returns: + + """ + chunk_size = 32768 + with open(destination, "wb") as f: + for chunk in response.iter_content(chunk_size): + if chunk: + f.write(chunk) + + +def get_checkpoint(checkpoint_path, url=''): + r"""Get the checkpoint path. If it does not exist yet, download it from + the url. + + Args: + checkpoint_path (str): Checkpoint path. + url (str): URL to download checkpoint. + Returns: + (str): Full checkpoint path. 
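+
+    Example (illustrative only; the Google Drive file id below is a
+    placeholder, not a real checkpoint id):
+        >>> path = get_checkpoint('scenedreamer_released.pt',
+        ...                       url='<gdrive-file-id>')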
+ """ + if 'TORCH_HOME' not in os.environ: + os.environ['TORCH_HOME'] = os.getcwd() + save_dir = os.path.join(os.environ['TORCH_HOME'], 'checkpoints') + os.makedirs(save_dir, exist_ok=True) + full_checkpoint_path = os.path.join(save_dir, checkpoint_path) + if not os.path.exists(full_checkpoint_path): + os.makedirs(os.path.dirname(full_checkpoint_path), exist_ok=True) + if is_master(): + print('Downloading {}'.format(url)) + if 'pbss.s8k.io' not in url: + url = f"https://docs.google.com/uc?export=download&id={url}" + # download_file(url, full_checkpoint_path) + import gdown + gdown.download(url, full_checkpoint_path, quiet=False) + if dist.is_available() and dist.is_initialized(): + dist.barrier() + return full_checkpoint_path diff --git a/imaginaire/utils/lmdb.py b/imaginaire/utils/lmdb.py new file mode 100644 index 0000000000000000000000000000000000000000..df40c146b73295598cde04fd94a6869c6a5e69d2 --- /dev/null +++ b/imaginaire/utils/lmdb.py @@ -0,0 +1,216 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import glob +import os + +import lmdb +from tqdm import tqdm + +from imaginaire.utils import path + + +def construct_file_path(root, data_type, sequence, filename, ext): + """Get file path for our dataset structure.""" + return '%s/%s/%s/%s.%s' % (root, data_type, sequence, filename, ext) + + +def check_and_add(filepath, key, filepaths, keys, remove_missing=False): + r"""Add filepath and key to list of filepaths and keys. + + Args: + filepath (str): Filepath to add. + key (str): LMDB key for this filepath. + filepaths (list): List of filepaths added so far. + keys (list): List of keys added so far. + remove_missing (bool): If ``True``, removes missing files, otherwise + raises an error. + Returns: + (int): Size of file at filepath. + """ + if not os.path.exists(filepath): + print(filepath + ' does not exist.') + if remove_missing: + return -1 + else: + raise FileNotFoundError(filepath + ' does not exist.') + filepaths.append(filepath) + keys.append(key) + return os.path.getsize(filepath) + + +def write_entry(txn, key, filepath): + r"""Dump binary contents of file associated with key to LMDB. + + Args: + txn: handle to LMDB. + key (str): LMDB key for this filepath. + filepath (str): Filepath to add. + """ + with open(filepath, 'rb') as f: + data = f.read() + txn.put(key.encode('ascii'), data) + + +def build_lmdb(filepaths, keys, output_filepath, map_size, large): + r"""Write out lmdb containing (key, contents of filepath) to file. + + Args: + filepaths (list): List of filepath strings. + keys (list): List of key strings associated with filepaths. + output_filepath (str): Location to write LMDB to. + map_size (int): Size of LMDB. + large (bool): Is the dataset large? + """ + if large: + db = lmdb.open(output_filepath, map_size=map_size, writemap=True) + else: + db = lmdb.open(output_filepath, map_size=map_size) + txn = db.begin(write=True) + print('Writing LMDB to:', output_filepath) + for filepath, key in tqdm(zip(filepaths, keys), total=len(keys)): + write_entry(txn, key, filepath) + txn.commit() + + +def get_all_filenames_from_list(list_name): + r"""Get all filenames from list. + + Args: + list_name (str): Path to filename list. + Returns: + all_filenames (dict): Folder name for key, and filename for values. 
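+
+    Example (illustrative; assumes a plain-text list whose lines look like
+    'seq_0001/frame_000001.jpg'):
+        >>> names = get_all_filenames_from_list('train_images.txt')
+        >>> names['seq_0001']
+        ['frame_000001']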
+ """ + with open(list_name, 'rt') as f: + lines = f.readlines() + lines = [line.strip() for line in lines] + all_filenames = dict() + for line in lines: + if '/' in line: + file_str = line.split('/')[0:-1] + folder_name = os.path.join(*file_str) + image_name = line.split('/')[-1].replace('.jpg', '') + else: + folder_name = '.' + image_name = line.replace('.jpg', '') + if folder_name in all_filenames: + all_filenames[folder_name].append(image_name) + else: + all_filenames[folder_name] = [image_name] + return all_filenames + + +def get_lmdb_data_types(cfg): + r"""Get the data types which should be put in LMDB. + + Args: + cfg: Configuration object. + """ + data_types, extensions = [], [] + for data_type in cfg.data.input_types: + name = list(data_type.keys()) + assert len(name) == 1 + name = name[0] + info = data_type[name] + + if 'computed_on_the_fly' not in info: + info['computed_on_the_fly'] = False + is_lmdb = not info['computed_on_the_fly'] + if not is_lmdb: + continue + + ext = info['ext'] + data_types.append(name) + extensions.append(ext) + + cfg.data.data_types = data_types + cfg.data.extensions = extensions + return cfg + + +def create_metadata(data_root=None, cfg=None, paired=None, input_list=''): + r"""Main function. + + Args: + data_root (str): Location of dataset root. + cfg (object): Loaded config object. + paired (bool): Paired or unpaired data. + input_list (str): Path to filename containing list of inputs. + Returns: + (tuple): + - all_filenames (dict): Key of data type, values with sequences. + - extensions (dict): Extension of each data type. + """ + cfg = get_lmdb_data_types(cfg) + + # Get list of all data_types in the dataset. + available_data_types = path.get_immediate_subdirectories(data_root) + print(available_data_types) + required_data_types = cfg.data.data_types + data_exts = cfg.data.extensions + + # Find filenames. + assert set(required_data_types).issubset(set(available_data_types)), \ + print(set(required_data_types) - set(available_data_types), 'missing') + + # Find extensions for each data type. + extensions = {} + for data_type, data_ext in zip(required_data_types, data_exts): + extensions[data_type] = data_ext + print('Data file extensions:', extensions) + + if paired: + if input_list != '': + all_filenames = get_all_filenames_from_list(input_list) + else: + # Get list of all sequences in the dataset. + if 'data_keypoint' in required_data_types: + search_dir = 'data_keypoint' + elif 'data_segmaps' in required_data_types: + search_dir = 'data_segmaps' + else: + search_dir = required_data_types[0] + print('Searching in dir: %s' % search_dir) + sequences = path.get_recursive_subdirectories( + os.path.join(data_root, search_dir), + extensions[search_dir]) + print('Found %d sequences' % (len(sequences))) + + # Get filenames in each sequence. + all_filenames = {} + for sequence in sequences: + folder = '%s/%s/%s/*.%s' % ( + data_root, search_dir, sequence, + extensions[search_dir]) + filenames = sorted(glob.glob(folder)) + filenames = [ + os.path.splitext(os.path.basename(filename))[0] for + filename in filenames] + all_filenames[sequence] = filenames + total_filenames = [len(filenames) + for _, filenames in all_filenames.items()] + print('Found %d files' % (sum(total_filenames))) + else: + # Get sequences in each data type. + all_filenames = {} + for data_type in required_data_types: + all_filenames[data_type] = {} + sequences = path.get_recursive_subdirectories( + os.path.join(data_root, data_type), extensions[data_type]) + + # Get filenames in each sequence. 
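+            # Unlike the paired branch above, results are grouped per data type
+            # first and then per sequence, so callers index
+            # all_filenames[data_type][sequence].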
+ total_filenames = 0 + for sequence in sequences: + folder = '%s/%s/%s/*.%s' % ( + data_root, data_type, sequence, extensions[data_type]) + filenames = sorted(glob.glob(folder)) + filenames = [ + os.path.splitext(os.path.basename(filename))[0] for + filename in filenames] + all_filenames[data_type][sequence] = filenames + total_filenames += len(filenames) + print('Data type: %s, Found %d sequences, Found %d files' % + (data_type, len(sequences), total_filenames)) + + return all_filenames, extensions diff --git a/imaginaire/utils/logging.py b/imaginaire/utils/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..449e0b7b892d0e11baccc5c2c2333afec8501422 --- /dev/null +++ b/imaginaire/utils/logging.py @@ -0,0 +1,51 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import datetime +import os + +from imaginaire.utils.distributed import master_only +from imaginaire.utils.distributed import master_only_print as print +from imaginaire.utils.meters import set_summary_writer + + +def get_date_uid(): + """Generate a unique id based on date. + Returns: + str: Return uid string, e.g. '20171122171307111552'. + """ + return str(datetime.datetime.now().strftime("%Y_%m%d_%H%M_%S")) + + +def init_logging(config_path, logdir): + r"""Create log directory for storing checkpoints and output images. + + Args: + config_path (str): Path to the configuration file. + logdir (str): Log directory name + Returns: + str: Return log dir + """ + config_file = os.path.basename(config_path) + root_dir = 'logs' + date_uid = get_date_uid() + # example: logs/2019_0125_1047_58_spade_cocostuff + log_file = '_'.join([date_uid, os.path.splitext(config_file)[0]]) + if logdir is None: + logdir = os.path.join(root_dir, log_file) + return date_uid, logdir + + +@master_only +def make_logging_dir(logdir): + r"""Create the logging directory + + Args: + logdir (str): Log directory name + """ + print('Make folder {}'.format(logdir)) + os.makedirs(logdir, exist_ok=True) + tensorboard_dir = os.path.join(logdir, 'tensorboard') + os.makedirs(tensorboard_dir, exist_ok=True) + set_summary_writer(tensorboard_dir) diff --git a/imaginaire/utils/meters.py b/imaginaire/utils/meters.py new file mode 100644 index 0000000000000000000000000000000000000000..3befb7b1e5fc44c00d3fe29092e75777afa64caa --- /dev/null +++ b/imaginaire/utils/meters.py @@ -0,0 +1,149 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import math +from datetime import timedelta + +import torch +import wandb +from wandb import AlertLevel +from torch.utils.tensorboard import SummaryWriter + +from imaginaire.utils.distributed import master_only, dist_all_reduce_tensor, \ + is_master, get_rank + +from imaginaire.utils.distributed import master_only_print as print + +LOG_WRITER = None +LOG_DIR = None + + +@torch.no_grad() +def sn_reshape_weight_to_matrix(weight): + r"""Reshape weight to obtain the matrix form. + + Args: + weight (Parameters): pytorch layer parameter tensor. 
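+
+    Returns:
+        (tensor): 2D matrix form of ``weight`` with the output dimension first
+        and all remaining dimensions flattened into the second axis.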
+ """ + weight_mat = weight + height = weight_mat.size(0) + return weight_mat.reshape(height, -1) + + +@torch.no_grad() +def get_weight_stats(mod): + r"""Get weight state + + Args: + mod: Pytorch module + """ + if mod.weight_orig.grad is not None: + grad_norm = mod.weight_orig.grad.data.norm().item() + else: + grad_norm = 0. + weight_norm = mod.weight_orig.data.norm().item() + weight_mat = sn_reshape_weight_to_matrix(mod.weight_orig) + sigma = torch.sum(mod.weight_u * torch.mv(weight_mat, mod.weight_v)) + return grad_norm, weight_norm, sigma + + +@master_only +def set_summary_writer(log_dir): + r"""Set summary writer + + Args: + log_dir (str): Log directory. + """ + global LOG_DIR, LOG_WRITER + LOG_DIR = log_dir + LOG_WRITER = SummaryWriter(log_dir=log_dir) + + +def write_summary(name, summary, step, hist=False): + """Utility function for write summary to log_writer. + """ + global LOG_WRITER + lw = LOG_WRITER + if lw is None: + raise Exception("Log writer not set.") + if hist: + lw.add_histogram(name, summary, step) + else: + lw.add_scalar(name, summary, step) + + +class Meter(object): + """Meter is to keep track of statistics along steps. + Meters write values for purpose like printing average values. + Meters can be flushed to log files (i.e. TensorBoard for now) + regularly. + + Args: + name (str): the name of meter + reduce (bool): If ``True``, perform a distributed reduce for the log + values across all GPUs. + """ + + def __init__(self, name, reduce=True): + self.name = name + self.reduce = reduce + self.values = [] + + def reset(self): + r"""Reset the meter values""" + if not self.reduce and get_rank() != 0: + return + self.values = [] + + def write(self, value): + r"""Record the value""" + if not self.reduce and get_rank() != 0: + return + if value is not None: + self.values.append(value) + + def flush(self, step): + r"""Write the value in the tensorboard. + + Args: + step (int): Epoch or iteration number. + """ + if not self.reduce and get_rank() != 0: + return + values = torch.tensor(self.values, device="cuda") + if self.reduce: + values = dist_all_reduce_tensor(values) + + if not all(math.isfinite(x) for x in values): + print("meter {} contained a nan or inf.".format(self.name)) + if is_master(): + wandb.alert( + title='NaN', + text=f'Meter {self.name} contained a nan or inf.', + level=AlertLevel.WARN, + wait_duration=timedelta(minutes=120) + ) + filtered_values = list(filter(lambda x: math.isfinite(x), self.values)) + if float(len(filtered_values)) != 0: + value = float(sum(filtered_values)) / float(len(filtered_values)) + if is_master(): + write_summary(self.name, value, step) + wandb.log({self.name: value}, step=step) + self.reset() + + @master_only + def write_image(self, img_grid, step): + r"""Write the value in the tensorboard. + + Args: + img_grid: + step (int): Epoch or iteration number. + """ + if not self.reduce and get_rank() != 0: + return + global LOG_WRITER + lw = LOG_WRITER + if lw is None: + raise Exception("Log writer not set.") + lw.add_image("Visualizations", img_grid, step) diff --git a/imaginaire/utils/misc.py b/imaginaire/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..11ae68652975a2c7e2396c4d7eda2fa2f61fe5a5 --- /dev/null +++ b/imaginaire/utils/misc.py @@ -0,0 +1,269 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. 
+# To view a copy of this license, check out LICENSE.md +"""Miscellaneous utils.""" +import collections +from collections import OrderedDict + +import torch +import torch.nn.functional as F +string_classes = (str, bytes) + + +def split_labels(labels, label_lengths): + r"""Split concatenated labels into their parts. + + Args: + labels (torch.Tensor): Labels obtained through concatenation. + label_lengths (OrderedDict): Containing order of labels & their lengths. + + Returns: + + """ + assert isinstance(label_lengths, OrderedDict) + start = 0 + outputs = {} + for data_type, length in label_lengths.items(): + end = start + length + if labels.dim() == 5: + outputs[data_type] = labels[:, :, start:end] + elif labels.dim() == 4: + outputs[data_type] = labels[:, start:end] + elif labels.dim() == 3: + outputs[data_type] = labels[start:end] + start = end + return outputs + + +def requires_grad(model, require=True): + r""" Set a model to require gradient or not. + + Args: + model (nn.Module): Neural network model. + require (bool): Whether the network requires gradient or not. + + Returns: + + """ + for p in model.parameters(): + p.requires_grad = require + + +def to_device(data, device): + r"""Move all tensors inside data to device. + + Args: + data (dict, list, or tensor): Input data. + device (str): 'cpu' or 'cuda'. + """ + assert device in ['cpu', 'cuda'] + if isinstance(data, torch.Tensor): + data = data.to(torch.device(device)) + return data + elif isinstance(data, collections.abc.Mapping): + return {key: to_device(data[key], device) for key in data} + elif isinstance(data, collections.abc.Sequence) and \ + not isinstance(data, string_classes): + return [to_device(d, device) for d in data] + else: + return data + + +def to_cuda(data): + r"""Move all tensors inside data to gpu. + + Args: + data (dict, list, or tensor): Input data. + """ + return to_device(data, 'cuda') + + +def to_cpu(data): + r"""Move all tensors inside data to cpu. + + Args: + data (dict, list, or tensor): Input data. + """ + return to_device(data, 'cpu') + + +def to_half(data): + r"""Move all floats to half. + + Args: + data (dict, list or tensor): Input data. + """ + if isinstance(data, torch.Tensor) and torch.is_floating_point(data): + data = data.half() + return data + elif isinstance(data, collections.abc.Mapping): + return {key: to_half(data[key]) for key in data} + elif isinstance(data, collections.abc.Sequence) and \ + not isinstance(data, string_classes): + return [to_half(d) for d in data] + else: + return data + + +def to_float(data): + r"""Move all halfs to float. + + Args: + data (dict, list or tensor): Input data. + """ + if isinstance(data, torch.Tensor) and torch.is_floating_point(data): + data = data.float() + return data + elif isinstance(data, collections.abc.Mapping): + return {key: to_float(data[key]) for key in data} + elif isinstance(data, collections.abc.Sequence) and \ + not isinstance(data, string_classes): + return [to_float(d) for d in data] + else: + return data + + +def to_channels_last(data): + r"""Move all data to ``channels_last`` format. + + Args: + data (dict, list or tensor): Input data. 
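+
+    Example (illustrative):
+        >>> batch = {'images': torch.randn(2, 3, 64, 64)}
+        >>> batch = to_channels_last(batch)
+        >>> batch['images'].is_contiguous(memory_format=torch.channels_last)
+        True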
+ """ + if isinstance(data, torch.Tensor): + if data.dim() == 4: + data = data.to(memory_format=torch.channels_last) + return data + elif isinstance(data, collections.abc.Mapping): + return {key: to_channels_last(data[key]) for key in data} + elif isinstance(data, collections.abc.Sequence) and \ + not isinstance(data, string_classes): + return [to_channels_last(d) for d in data] + else: + return data + + +def slice_tensor(data, start, end): + r"""Slice all tensors from start to end. + Args: + data (dict, list or tensor): Input data. + """ + if isinstance(data, torch.Tensor): + data = data[start:end] + return data + elif isinstance(data, collections.abc.Mapping): + return {key: slice_tensor(data[key], start, end) for key in data} + elif isinstance(data, collections.abc.Sequence) and \ + not isinstance(data, string_classes): + return [slice_tensor(d, start, end) for d in data] + else: + return data + + +def get_and_setattr(cfg, name, default): + r"""Get attribute with default choice. If attribute does not exist, set it + using the default value. + + Args: + cfg (obj) : Config options. + name (str) : Attribute name. + default (obj) : Default attribute. + + Returns: + (obj) : Desired attribute. + """ + if not hasattr(cfg, name) or name not in cfg.__dict__: + setattr(cfg, name, default) + return getattr(cfg, name) + + +def get_nested_attr(cfg, attr_name, default): + r"""Iteratively try to get the attribute from cfg. If not found, return + default. + + Args: + cfg (obj): Config file. + attr_name (str): Attribute name (e.g. XXX.YYY.ZZZ). + default (obj): Default return value for the attribute. + + Returns: + (obj): Attribute value. + """ + names = attr_name.split('.') + atr = cfg + for name in names: + if not hasattr(atr, name): + return default + atr = getattr(atr, name) + return atr + + +def gradient_norm(model): + r"""Return the gradient norm of model. + + Args: + model (PyTorch module): Your network. + + """ + total_norm = 0 + for p in model.parameters(): + if p.grad is not None: + param_norm = p.grad.norm(2) + total_norm += param_norm.item() ** 2 + return total_norm ** (1. / 2) + + +def random_shift(x, offset=0.05, mode='bilinear', padding_mode='reflection'): + r"""Randomly shift the input tensor. + + Args: + x (4D tensor): The input batch of images. + offset (int): The maximum offset ratio that is between [0, 1]. + The maximum shift is offset * image_size for each direction. + mode (str): The resample mode for 'F.grid_sample'. + padding_mode (str): The padding mode for 'F.grid_sample'. + + Returns: + x (4D tensor) : The randomly shifted image. + """ + assert x.dim() == 4, "Input must be a 4D tensor." + batch_size = x.size(0) + theta = torch.eye(2, 3, device=x.device).unsqueeze(0).repeat( + batch_size, 1, 1) + theta[:, :, 2] = 2 * offset * torch.rand(batch_size, 2) - offset + grid = F.affine_grid(theta, x.size()) + x = F.grid_sample(x, grid, mode=mode, padding_mode=padding_mode) + return x + + +# def truncated_gaussian(threshold, size, seed=None, device=None): +# r"""Apply the truncated gaussian trick to trade diversity for quality +# +# Args: +# threshold (float): Truncation threshold. +# size (list of integer): Tensor size. +# seed (int): Random seed. +# device: +# """ +# state = None if seed is None else np.random.RandomState(seed) +# values = truncnorm.rvs(-threshold, threshold, +# size=size, random_state=state) +# return torch.tensor(values, device=device).float() + + +def apply_imagenet_normalization(input): + r"""Normalize using ImageNet mean and std. 
+ + Args: + input (4D tensor NxCxHxW): The input images, assuming to be [-1, 1]. + + Returns: + Normalized inputs using the ImageNet normalization. + """ + # normalize the input back to [0, 1] + normalized_input = (input + 1) / 2 + # normalize the input using the ImageNet mean and std + mean = normalized_input.new_tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) + std = normalized_input.new_tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) + output = (normalized_input - mean) / std + return output diff --git a/imaginaire/utils/model_average.py b/imaginaire/utils/model_average.py new file mode 100644 index 0000000000000000000000000000000000000000..d5015bcc45b42aea06d8f396a726bc43174098d6 --- /dev/null +++ b/imaginaire/utils/model_average.py @@ -0,0 +1,215 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import copy + +import torch +from torch import nn +from imaginaire.layers.weight_norm import remove_weight_norms +from imaginaire.utils.misc import requires_grad + + +def reset_batch_norm(m): + r"""Reset batch norm statistics + + Args: + m: Pytorch module + """ + if hasattr(m, 'reset_running_stats'): + m.reset_running_stats() + + +def calibrate_batch_norm_momentum(m): + r"""Calibrate batch norm momentum + + Args: + m: Pytorch module + """ + if hasattr(m, 'reset_running_stats'): + # if m._get_name() == 'SyncBatchNorm': + if 'BatchNorm' in m._get_name(): + m.momentum = 1.0 / float(m.num_batches_tracked + 1) + + +class ModelAverage(nn.Module): + r"""In this model average implementation, the spectral layers are + absorbed in the model parameter by default. If such options are + turned on, be careful with how you do the training. Remember to + re-estimate the batch norm parameters before using the model. + + Args: + module (torch nn module): Torch network. + beta (float): Moving average weights. How much we weight the past. + start_iteration (int): From which iteration, we start the update. + remove_sn (bool): Whether we remove the spectral norm when we it. + """ + def __init__( + self, module, beta=0.9999, start_iteration=1000, + remove_wn_wrapper=True + ): + super(ModelAverage, self).__init__() + self.module = module + # A shallow copy creates a new object which stores the reference of + # the original elements. + # A deep copy creates a new object and recursively adds the copies of + # nested objects present in the original elements. + self.averaged_model = copy.deepcopy(self.module).to('cuda') + self.beta = beta + self.remove_wn_wrapper = remove_wn_wrapper + self.start_iteration = start_iteration + # This buffer is to track how many iterations has the model been + # trained for. We will ignore the first $(start_iterations) and start + # the averaging after. + self.register_buffer('num_updates_tracked', + torch.tensor(0, dtype=torch.long)) + self.num_updates_tracked = self.num_updates_tracked.to('cuda') + # if self.remove_sn: + # # If we want to remove the spectral norm, we first copy the + # # weights to the moving average model. + # self.copy_s2t() + # + # def fn_remove_sn(m): + # r"""Remove spectral norm.""" + # if hasattr(m, 'weight_orig'): + # remove_spectral_norm(m) + # + # self.averaged_model.apply(fn_remove_sn) + # self.dim = 0 + if self.remove_wn_wrapper: + self.copy_s2t() + + self.averaged_model.apply(remove_weight_norms) + self.dim = 0 + else: + self.averaged_model.eval() + + # Averaged model does not require grad. 
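+        # (Only self.module receives gradients; the averaged copy is refreshed
+        # explicitly through update_average() during training.)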
+ requires_grad(self.averaged_model, False) + + def forward(self, *inputs, **kwargs): + r"""PyTorch module forward function overload.""" + return self.module(*inputs, **kwargs) + + @torch.no_grad() + def update_average(self): + r"""Update the moving average.""" + self.num_updates_tracked += 1 + if self.num_updates_tracked <= self.start_iteration: + beta = 0. + else: + beta = self.beta + source_dict = self.module.state_dict() + target_dict = self.averaged_model.state_dict() + for key in target_dict: + if 'num_batches_tracked' in key: + continue + if self.remove_wn_wrapper: + if key.endswith('weight'): + # This is a weight parameter. + if key + '_ori' in source_dict: + # This parameter has scaled lr. + source_param = \ + source_dict[key + '_ori'] * \ + source_dict[key + '_scale'] + elif key + '_orig' in source_dict: + # This parameter has spectral norm + # but not scaled lr. + source_param = source_dict[key + '_orig'] + elif key in source_dict: + # This parameter does not have + # weight normalization wrappers. + source_param = source_dict[key] + else: + raise ValueError( + f"{key} required in the averaged model but not " + f"found in the regular model." + ) + source_param = source_param.detach() + + if key + '_orig' in source_dict: + # This parameter has spectral norm. + source_param = self.sn_compute_weight( + source_param, + source_dict[key + '_u'], + source_dict[key + '_v'], + ) + elif key.endswith('bias') and key + '_ori' in source_dict: + # This is a bias parameter and has scaled lr. + source_param = source_dict[key + '_ori'] * \ + source_dict[key + '_scale'] + else: + # This is a normal parameter. + source_param = source_dict[key] + target_dict[key].data.mul_(beta).add_( + source_param.data, alpha=1 - beta + ) + else: + target_dict[key].data.mul_(beta).add_( + source_dict[key].data, alpha=1 - beta + ) + + @torch.no_grad() + def copy_t2s(self): + r"""Copy the original weights to the moving average weights.""" + target_dict = self.module.state_dict() + source_dict = self.averaged_model.state_dict() + beta = 0. + for key in source_dict: + target_dict[key].data.copy_( + target_dict[key].data * beta + + source_dict[key].data * (1 - beta)) + + @torch.no_grad() + def copy_s2t(self): + r""" Copy state_dictionary from source to target. + Here source is the regular module and the target is the moving + average module. Basically, we will copy weights in the regular module + to the moving average module. + """ + source_dict = self.module.state_dict() + target_dict = self.averaged_model.state_dict() + beta = 0. + for key in source_dict: + target_dict[key].data.copy_( + target_dict[key].data * beta + + source_dict[key].data * (1 - beta)) + + def __repr__(self): + r"""Returns a string that holds a printable representation of an + object""" + return self.module.__repr__() + + def sn_reshape_weight_to_matrix(self, weight): + r"""Reshape weight to obtain the matrix form. + + Args: + weight (Parameters): pytorch layer parameter tensor. + + Returns: + (Parameters): Reshaped weight matrix + """ + weight_mat = weight + if self.dim != 0: + # permute dim to front + weight_mat = weight_mat.permute( + self.dim, + *[d for d in range(weight_mat.dim()) if d != self.dim]).contiguous() + height = weight_mat.size(0) + return weight_mat.reshape(height, -1) + + def sn_compute_weight(self, weight, u, v): + r"""Compute the spectral norm normalized matrix. + + Args: + weight (Parameters): pytorch layer parameter tensor. + u (tensor): left singular vectors. 
+ v (tensor) right singular vectors + + Returns: + (Parameters): weight parameter object. + """ + weight_mat = self.sn_reshape_weight_to_matrix(weight) + sigma = torch.sum(u * torch.mv(weight_mat, v)) + weight = weight / sigma + return weight diff --git a/imaginaire/utils/path.py b/imaginaire/utils/path.py new file mode 100644 index 0000000000000000000000000000000000000000..e576fc91e66d7c1931b0fb3f349363b49f62c8d5 --- /dev/null +++ b/imaginaire/utils/path.py @@ -0,0 +1,36 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +"""Utils to deal with directories and paths.""" + +import glob +import os + + +def get_immediate_subdirectories(input_dir): + """List dirs immediately under input_dir. + + Args: + input_dir (str): Directory to list children of. + Returns: + (list): List of directory paths relative to input_dir. + """ + return sorted([name for name in os.listdir(input_dir) + if os.path.isdir(os.path.join(input_dir, name))]) + + +def get_recursive_subdirectories(input_dir, ext): + """List dirs recursively under input_dir. + + Args: + input_dir (str): Directory to list children of. + ext (str): Extension of files expected in this directory. + Returns: + (list): List of directory paths relative to input_dir. + """ + lines = glob.glob('%s/**/*.%s' % (input_dir, ext), recursive=True) + dirpaths = [os.path.dirname(item) for item in lines] + dirpaths = [os.path.relpath(item, input_dir) for item in dirpaths] + dirpaths = sorted(list(set(dirpaths))) + return dirpaths diff --git a/imaginaire/utils/trainer.py b/imaginaire/utils/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..ed8b5b13cbbcd09031f03b55c0e66b295028f0c1 --- /dev/null +++ b/imaginaire/utils/trainer.py @@ -0,0 +1,348 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import importlib +import random +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.optim import SGD, Adam, RMSprop, lr_scheduler + +from imaginaire.optimizers import Fromage, Madam +from imaginaire.utils.distributed import get_rank, get_world_size +from imaginaire.utils.distributed import master_only_print as print +from imaginaire.utils.init_weight import weights_init, weights_rescale +from imaginaire.utils.model_average import ModelAverage + + +def set_random_seed(seed, by_rank=False): + r"""Set random seeds for everything. + + Args: + seed (int): Random seed. + by_rank (bool): + """ + if by_rank: + seed += get_rank() + print(f"Using random seed {seed}") + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def get_trainer(cfg, net_G, net_D=None, + opt_G=None, opt_D=None, + sch_G=None, sch_D=None, + train_data_loader=None, + val_data_loader=None): + """Return the trainer object. + + Args: + cfg (Config): Loaded config object. + net_G (obj): Generator network object. + net_D (obj): Discriminator network object. + opt_G (obj): Generator optimizer object. + opt_D (obj): Discriminator optimizer object. + sch_G (obj): Generator optimizer scheduler object. + sch_D (obj): Discriminator optimizer scheduler object. + train_data_loader (obj): Train data loader. 
+ val_data_loader (obj): Validation data loader. + + Returns: + (obj): Trainer object. + """ + trainer_lib = importlib.import_module(cfg.trainer.type) + trainer = trainer_lib.Trainer(cfg, net_G, net_D, + opt_G, opt_D, + sch_G, sch_D, + train_data_loader, val_data_loader) + return trainer + + +def get_model_optimizer_and_scheduler(cfg, seed=0, generator_only=False): + r"""Return the networks, the optimizers, and the schedulers. We will + first set the random seed to a fixed value so that each GPU copy will be + initialized to have the same network weights. We will then use different + random seeds for different GPUs. After this we will wrap the generator + with a moving average model if applicable. It is followed by getting the + optimizers and data distributed data parallel wrapping. + + Args: + cfg (obj): Global configuration. + seed (int): Random seed. + + Returns: + (dict): + - net_G (obj): Generator network object. + - net_D (obj): Discriminator network object. + - opt_G (obj): Generator optimizer object. + - opt_D (obj): Discriminator optimizer object. + - sch_G (obj): Generator optimizer scheduler object. + - sch_D (obj): Discriminator optimizer scheduler object. + """ + # We first set the random seed to be the same so that we initialize each + # copy of the network in exactly the same way so that they have the same + # weights and other parameters. The true seed will be the seed. + set_random_seed(seed, by_rank=False) + + if generator_only: + # used for inference + lib_G = importlib.import_module(cfg.gen.type) + net_G = lib_G.Generator(cfg.gen, cfg.data) + net_G = net_G.to('cuda') + set_random_seed(seed, by_rank=True) + net_G = _wrap_model(cfg, net_G) + return net_G + else: + # Construct networks + lib_G = importlib.import_module(cfg.gen.type) + lib_D = importlib.import_module(cfg.dis.type) + net_G = lib_G.Generator(cfg.gen, cfg.data) + net_D = lib_D.Discriminator(cfg.dis, cfg.data) + print('Initialize net_G and net_D weights using ' + 'type: {} gain: {}'.format(cfg.trainer.init.type, + cfg.trainer.init.gain)) + init_bias = getattr(cfg.trainer.init, 'bias', None) + net_G.apply(weights_init( + cfg.trainer.init.type, cfg.trainer.init.gain, init_bias)) + net_D.apply(weights_init( + cfg.trainer.init.type, cfg.trainer.init.gain, init_bias)) + net_G.apply(weights_rescale()) + net_D.apply(weights_rescale()) + net_G = net_G.to('cuda') + net_D = net_D.to('cuda') + # Different GPU copies of the same model will receive noises + # initialized with different random seeds (if applicable) thanks to the + # set_random_seed command (GPU #K has random seed = args.seed + K). + set_random_seed(seed, by_rank=True) + print('net_G parameter count: {:,}'.format(_calculate_model_size(net_G))) + print('net_D parameter count: {:,}'.format(_calculate_model_size(net_D))) + + # Optimizer + opt_G = get_optimizer(cfg.gen_opt, net_G) + opt_D = get_optimizer(cfg.dis_opt, net_D) + + net_G, net_D, opt_G, opt_D = \ + wrap_model_and_optimizer(cfg, net_G, net_D, opt_G, opt_D) + + # Scheduler + sch_G = get_scheduler(cfg.gen_opt, opt_G) + sch_D = get_scheduler(cfg.dis_opt, opt_D) + + return net_G, net_D, opt_G, opt_D, sch_G, sch_D + + +def wrap_model_and_optimizer(cfg, net_G, net_D, opt_G, opt_D): + r"""Wrap the networks and the optimizers with AMP DDP and (optionally) + model average. + + Args: + cfg (obj): Global configuration. + net_G (obj): Generator network object. + net_D (obj): Discriminator network object. + opt_G (obj): Generator optimizer object. + opt_D (obj): Discriminator optimizer object. 
+ + Returns: + (dict): + - net_G (obj): Generator network object. + - net_D (obj): Discriminator network object. + - opt_G (obj): Generator optimizer object. + - opt_D (obj): Discriminator optimizer object. + """ + # Apply model average wrapper. + if cfg.trainer.model_average_config.enabled: + if hasattr(cfg.trainer.model_average_config, 'g_smooth_img'): + # Specifies half-life of the running average of generator weights. + cfg.trainer.model_average_config.beta = \ + 0.5 ** (cfg.data.train.batch_size * + get_world_size() / cfg.trainer.model_average_config.g_smooth_img) + print(f"EMA Decay Factor: {cfg.trainer.model_average_config.beta}") + net_G = ModelAverage(net_G, cfg.trainer.model_average_config.beta, + cfg.trainer.model_average_config.start_iteration, + cfg.trainer.model_average_config.remove_sn) + if cfg.trainer.model_average_config.enabled: + net_G_module = net_G.module + else: + net_G_module = net_G + if hasattr(net_G_module, 'custom_init'): + net_G_module.custom_init() + + net_G = _wrap_model(cfg, net_G) + net_D = _wrap_model(cfg, net_D) + return net_G, net_D, opt_G, opt_D + + +def _calculate_model_size(model): + r"""Calculate number of parameters in a PyTorch network. + + Args: + model (obj): PyTorch network. + + Returns: + (int): Number of parameters. + """ + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +class WrappedModel(nn.Module): + r"""Dummy wrapping the module. + """ + + def __init__(self, module): + super(WrappedModel, self).__init__() + self.module = module + + def forward(self, *args, **kwargs): + r"""PyTorch module forward function overload.""" + return self.module(*args, **kwargs) + + +def _wrap_model(cfg, model): + r"""Wrap a model for distributed data parallel training. + + Args: + model (obj): PyTorch network model. + + Returns: + (obj): Wrapped PyTorch network model. + """ + if torch.distributed.is_available() and dist.is_initialized(): + # ddp = cfg.trainer.distributed_data_parallel + find_unused_parameters = cfg.trainer.distributed_data_parallel_params.find_unused_parameters + return torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[cfg.local_rank], + output_device=cfg.local_rank, + find_unused_parameters=find_unused_parameters, + broadcast_buffers=False + ) + # if ddp == 'pytorch': + # return torch.nn.parallel.DistributedDataParallel( + # model, + # device_ids=[cfg.local_rank], + # output_device=cfg.local_rank, + # find_unused_parameters=find_unused_parameters, + # broadcast_buffers=False) + # else: + # delay_allreduce = cfg.trainer.delay_allreduce + # return apex.parallel.DistributedDataParallel( + # model, delay_allreduce=delay_allreduce) + else: + return WrappedModel(model) + + +def get_scheduler(cfg_opt, opt): + """Return the scheduler object. + + Args: + cfg_opt (obj): Config for the specific optimization module (gen/dis). + opt (obj): PyTorch optimizer object. + + Returns: + (obj): Scheduler + """ + if cfg_opt.lr_policy.type == 'step': + scheduler = lr_scheduler.StepLR( + opt, + step_size=cfg_opt.lr_policy.step_size, + gamma=cfg_opt.lr_policy.gamma) + elif cfg_opt.lr_policy.type == 'constant': + scheduler = lr_scheduler.LambdaLR(opt, lambda x: 1) + elif cfg_opt.lr_policy.type == 'linear': + # Start linear decay from here. + decay_start = cfg_opt.lr_policy.decay_start + # End linear decay here. + # Continue to train using the lowest learning rate till the end. + decay_end = cfg_opt.lr_policy.decay_end + # Lowest learning rate multiplier. 
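+            # The multiplier stays at 1 until decay_start, decreases linearly to
+            # decay_target at decay_end, and is held at decay_target afterwards.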
+ decay_target = cfg_opt.lr_policy.decay_target + + def sch(x): + return min( + max(((x - decay_start) * decay_target + decay_end - x) / ( + decay_end - decay_start + ), decay_target), 1. + ) + scheduler = lr_scheduler.LambdaLR(opt, lambda x: sch(x)) + else: + return NotImplementedError('Learning rate policy {} not implemented.'. + format(cfg_opt.lr_policy.type)) + return scheduler + + +def get_optimizer(cfg_opt, net): + r"""Return the scheduler object. + + Args: + cfg_opt (obj): Config for the specific optimization module (gen/dis). + net (obj): PyTorch network object. + + Returns: + (obj): Pytorch optimizer + """ + if hasattr(net, 'get_param_groups'): + # Allow the network to use different hyper-parameters (e.g., learning + # rate) for different parameters. + params = net.get_param_groups(cfg_opt) + else: + params = net.parameters() + return get_optimizer_for_params(cfg_opt, params) + + +def get_optimizer_for_params(cfg_opt, params): + r"""Return the scheduler object. + + Args: + cfg_opt (obj): Config for the specific optimization module (gen/dis). + params (obj): Parameters to be trained by the parameters. + + Returns: + (obj): Optimizer + """ + # We will use fuse optimizers by default. + fused_opt = cfg_opt.fused_opt + try: + from apex.optimizers import FusedAdam + except: # noqa + fused_opt = False + + if cfg_opt.type == 'adam': + if fused_opt: + opt = FusedAdam(params, + lr=cfg_opt.lr, eps=cfg_opt.eps, + betas=(cfg_opt.adam_beta1, cfg_opt.adam_beta2)) + else: + opt = Adam(params, + lr=cfg_opt.lr, eps=cfg_opt.eps, + betas=(cfg_opt.adam_beta1, cfg_opt.adam_beta2)) + + elif cfg_opt.type == 'madam': + g_bound = getattr(cfg_opt, 'g_bound', None) + opt = Madam(params, lr=cfg_opt.lr, + scale=cfg_opt.scale, g_bound=g_bound) + elif cfg_opt.type == 'fromage': + opt = Fromage(params, lr=cfg_opt.lr) + elif cfg_opt.type == 'rmsprop': + opt = RMSprop(params, lr=cfg_opt.lr, + eps=cfg_opt.eps, weight_decay=cfg_opt.weight_decay) + elif cfg_opt.type == 'sgd': + if fused_opt: + from apex.optimizers import FusedSGD + opt = FusedSGD(params, + lr=cfg_opt.lr, + momentum=cfg_opt.momentum, + weight_decay=cfg_opt.weight_decay) + else: + opt = SGD(params, + lr=cfg_opt.lr, + momentum=cfg_opt.momentum, + weight_decay=cfg_opt.weight_decay) + else: + raise NotImplementedError( + 'Optimizer {} is not yet implemented.'.format(cfg_opt.type)) + return opt diff --git a/imaginaire/utils/visualization/__init__.py b/imaginaire/utils/visualization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..27e3e7c0383a7e5f032593a50930e9d48bd0292b --- /dev/null +++ b/imaginaire/utils/visualization/__init__.py @@ -0,0 +1,9 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +from .common import tensor2im, tensor2flow, tensor2label, tensor2pilimage +from .common import save_tensor_image + +__all__ = ['tensor2im', 'tensor2flow', 'tensor2label', 'tensor2pilimage', + 'save_tensor_image'] diff --git a/imaginaire/utils/visualization/common.py b/imaginaire/utils/visualization/common.py new file mode 100644 index 0000000000000000000000000000000000000000..b4b68c5b670c386fe9bef916db13c92682b81bd2 --- /dev/null +++ b/imaginaire/utils/visualization/common.py @@ -0,0 +1,314 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. 
+# To view a copy of this license, check out LICENSE.md +import cv2 +import numpy as np +import PIL +from PIL import Image +import torch +import torchvision +import os + + +def save_tensor_image( + filename, image, minus1to1_normalized=False): + r"""Convert a 3 dimensional torch tensor to a PIL image with the desired + width and height. + + Args: + filename (str): Image filename to be saved to. + image (3 x W1 x H1 tensor): Image tensor + minus1to1_normalized (bool): True if the tensor values are in [-1, + 1]. Otherwise, we assume the values are in [0, 1]. + + Returns: + (PIL image): The resulting PIL image. + """ + if len(image.size()) != 3: + raise ValueError('Image tensor dimension does not equal = 3.') + if image.size(0) != 3: + raise ValueError('Image has more than 3 channels.') + if minus1to1_normalized: + # Normalize back to [0, 1] + image = (image + 1) * 0.5 + dirname = os.path.dirname(filename) + os.makedirs(dirname, exist_ok=True) + image_grid = torchvision.utils.make_grid( + image, nrow=1, padding=0, normalize=False) + torchvision.utils.save_image(image_grid, filename, nrow=1) + return + + +def tensor2pilimage(image, width=None, height=None, minus1to1_normalized=False): + r"""Convert a 3 dimensional torch tensor to a PIL image with the desired + width and height. + + Args: + image (3 x W1 x H1 tensor): Image tensor + width (int): Desired width for the result PIL image. + height (int): Desired height for the result PIL image. + minus1to1_normalized (bool): True if the tensor values are in [-1, + 1]. Otherwise, we assume the values are in [0, 1]. + + Returns: + (PIL image): The resulting PIL image. + """ + if len(image.size()) != 3: + raise ValueError('Image tensor dimension does not equal = 3.') + if image.size(0) != 3: + raise ValueError('Image has more than 3 channels.') + if minus1to1_normalized: + # Normalize back to [0, 1] + image = (image + 1) * 0.5 + image = image.detach().cpu().squeeze().numpy() + image = np.transpose(image, (1, 2, 0)) * 255 + output_img = Image.fromarray(np.uint8(image)) + if width is not None and height is not None: + output_img = output_img.resize((width, height), Image.BICUBIC) + return output_img + + +def tensor2im(image_tensor, imtype=np.uint8, normalize=True, + three_channel_output=True): + r"""Convert tensor to image. + + Args: + image_tensor (torch.tensor or list of torch.tensor): If tensor then + (NxCxHxW) or (NxTxCxHxW) or (CxHxW). + imtype (np.dtype): Type of output image. + normalize (bool): Is the input image normalized or not? + three_channel_output (bool): Should single channel images be made 3 + channel in output? + + Returns: + (numpy.ndarray, list if case 1, 2 above). 
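+
+    Example (illustrative):
+        >>> img = tensor2im(torch.rand(3, 64, 64) * 2 - 1)
+        >>> img.shape, img.dtype
+        ((64, 64, 3), dtype('uint8'))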
+ """ + if image_tensor is None: + return None + if isinstance(image_tensor, list): + return [tensor2im(x, imtype, normalize) for x in image_tensor] + if image_tensor.dim() == 5 or image_tensor.dim() == 4: + return [tensor2im(image_tensor[idx], imtype, normalize) + for idx in range(image_tensor.size(0))] + + if image_tensor.dim() == 3: + image_numpy = image_tensor.cpu().float().numpy() + if normalize: + image_numpy = (np.transpose( + image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 + else: + image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 + image_numpy = np.clip(image_numpy, 0, 255) + if image_numpy.shape[2] == 1 and three_channel_output: + image_numpy = np.repeat(image_numpy, 3, axis=2) + elif image_numpy.shape[2] > 3: + image_numpy = image_numpy[:, :, :3] + return image_numpy.astype(imtype) + + +def tensor2label(segmap, n_label=None, imtype=np.uint8, + colorize=True, output_normalized_tensor=False): + r"""Convert segmentation mask tensor to color image. + Args: + segmap (tensor) of + If tensor then (NxCxHxW) or (NxTxCxHxW) or (CxHxW). + n_label (int): If None, then segmap.size(0). + imtype (np.dtype): Type of output image. + colorize (bool): Put colors in. + + Returns: + (numpy.ndarray or normalized torch image). + """ + if segmap is None: + return None + if isinstance(segmap, list): + return [tensor2label(x, n_label, + imtype, colorize, + output_normalized_tensor) for x in segmap] + if segmap.dim() == 5 or segmap.dim() == 4: + return [tensor2label(segmap[idx], n_label, + imtype, colorize, + output_normalized_tensor) + for idx in range(segmap.size(0))] + + segmap = segmap.float() + if not output_normalized_tensor: + segmap = segmap.cpu() + if n_label is None: + n_label = segmap.size(0) + if n_label > 1: + segmap = segmap.max(0, keepdim=True)[1] + + if output_normalized_tensor: + if n_label == 0: + segmap = Colorize(256)(segmap).to('cuda') + else: + segmap = Colorize(n_label)(segmap).to('cuda') + return 2 * (segmap.float() / 255) - 1 + else: + if colorize: + segmap = Colorize(n_label)(segmap) + segmap = np.transpose(segmap.numpy(), (1, 2, 0)) + else: + segmap = segmap.cpu().numpy() + return segmap.astype(imtype) + + +def tensor2flow(tensor, imtype=np.uint8): + r"""Convert flow tensor to color image. + + Args: + tensor (tensor) of + If tensor then (NxCxHxW) or (NxTxCxHxW) or (CxHxW). + imtype (np.dtype): Type of output image. + + Returns: + (numpy.ndarray or normalized torch image). + """ + if tensor is None: + return None + if isinstance(tensor, list): + tensor = [t for t in tensor if t is not None] + if not tensor: + return None + return [tensor2flow(t, imtype) for t in tensor] + if tensor.dim() == 5 or tensor.dim() == 4: + return [tensor2flow(tensor[b]) for b in range(tensor.size(0))] + + tensor = tensor.detach().cpu().float().numpy() + tensor = np.transpose(tensor, (1, 2, 0)) + + hsv = np.zeros((tensor.shape[0], tensor.shape[1], 3), dtype=imtype) + hsv[:, :, 0] = 255 + hsv[:, :, 1] = 255 + mag, ang = cv2.cartToPolar(tensor[..., 0], tensor[..., 1]) + hsv[..., 0] = ang * 180 / np.pi / 2 + hsv[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX) + rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB) + return rgb + + +def plot_keypoints(image, keypoints, normalize=True): + r"""Plot keypoints on image. + + Args: + image (PIL.Image, or numpy.ndarray, or torch.Tensor): Input image. + keypoints (np.ndarray or torch.Tensor, Nx2): Keypoint locations. + normalize (bool): Whether to normalize the image or not. 
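+
+    Returns:
+        (np.ndarray): Copy of the input image with every keypoint drawn as a
+        filled green circle.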
+ """ + if isinstance(image, PIL.Image.Image): + image = np.array(image) + if isinstance(image, torch.Tensor): + image = tensor2im(image, normalize=normalize) + if isinstance(image, np.ndarray): + assert image.ndim == 3 + assert image.shape[-1] == 1 or image.shape[-1] == 3 + if isinstance(keypoints, torch.Tensor): + keypoints = keypoints.cpu().numpy() + assert keypoints.ndim == 2 and keypoints.shape[1] == 2 + + cv2_image = np.ascontiguousarray(image[:, :, ::-1]) # RGB to BGR. + for idx in range(keypoints.shape[0]): + keypoint = np.round(keypoints[idx]).astype(np.int) + cv2_image = cv2.circle(cv2_image, tuple(keypoint), + 5, (0, 255, 0), -1) + image = np.ascontiguousarray(cv2_image[:, :, ::-1]) + return image + + +def labelcolormap(N): + r"""Create colors for segmentation label ids. + + Args: + N (int): Number of labels. + """ + if N == 35: # GTA/cityscape train + cmap = np.array([(0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0), + (111, 74, 0), (81, 0, 81), (128, 64, 128), + (244, 35, 232), (250, 170, 160), (230, 150, 140), + (70, 70, 70), (102, 102, 156), (190, 153, 153), + (180, 165, 180), (150, 100, 100), (150, 120, 90), + (153, 153, 153), (153, 153, 153), (250, 170, 30), + (220, 220, 0), (107, 142, 35), (152, 251, 152), + (70, 130, 180), (220, 20, 60), (255, 0, 0), + (0, 0, 142), (0, 0, 70), (0, 60, 100), (0, 0, 90), + (0, 0, 110), (0, 80, 100), (0, 0, 230), (119, 11, 32), + (0, 0, 142)], + dtype=np.uint8) + elif N == 20: # GTA/cityscape eval + cmap = np.array([(128, 64, 128), (244, 35, 232), (70, 70, 70), + (102, 102, 156), (190, 153, 153), (153, 153, 153), + (250, 170, 30), (220, 220, 0), (107, 142, 35), + (152, 251, 152), (220, 20, 60), (255, 0, 0), + (0, 0, 142), (0, 0, 70), (0, 60, 100), (0, 80, 100), + (0, 0, 230), (119, 11, 32), (70, 130, 180), (0, 0, 0)], + dtype=np.uint8) + else: + cmap = np.zeros([N, 3]).astype(np.uint8) + for i in range(N): + r, g, b = np.zeros(3) + for j in range(8): + r = r + (1 << (7 - j)) * ((i & (1 << (3 * j))) >> (3 * j)) + g = g + (1 << (7 - j)) * \ + ((i & (1 << (3 * j + 1))) >> (3 * j + 1)) + b = b + (1 << (7 - j)) * \ + ((i & (1 << (3 * j + 2))) >> (3 * j + 2)) + cmap[i, :] = np.array([r, g, b]) + return cmap + + +class Colorize(object): + """Class to colorize segmentation maps.""" + + def __init__(self, n=35): + self.cmap = labelcolormap(n) + self.cmap = torch.from_numpy(self.cmap[:n]) + + def __call__(self, seg_map): + r""" + + Args: + seg_map (tensor): Input Segmentation maps to be colorized. + """ + size = seg_map.size() + color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0) + for label in range(0, len(self.cmap)): + mask = (label == seg_map[0]).cpu() + color_image[0][mask] = self.cmap[label][0] + color_image[1][mask] = self.cmap[label][1] + color_image[2][mask] = self.cmap[label][2] + return color_image + + +def plot_keypoints_on_black(resize_h, resize_w, crop_h, crop_w, is_flipped, + cfgdata, keypoints): + r"""Plot keypoints on black image. + + Args: + resize_h (int): Height to be resized to. + resize_w (int): Width to be resized to. + crop_h (int): Height of the cropping. + crop_w (int): Width of the cropping. + is_flipped (bool): If image is a flipped version. + cfgdata (obj): Data configuration object. + keypoints (np.ndarray): Keypoint locations. Shape of + (Nx2) or (TxNx2). + + Returns: + (list of np.ndarray): List of images (output_h, output_w, 3). + """ + if keypoints.ndim == 2 and keypoints.shape[1] == 2: + keypoints = keypoints[np.newaxis, ...] 
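+    # keypoints now has shape (T x N x 2); each time step is drawn on its own
+    # black canvas below.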
+ + outputs = [] + for t_idx in range(keypoints.shape[0]): + cv2_image = np.zeros((crop_h, crop_w, 3)).astype(np.uint8) + for idx in range(keypoints[t_idx].shape[0]): + keypoint = np.round(keypoints[t_idx][idx]).astype(np.int) + cv2_image = cv2.circle(cv2_image, tuple(keypoint), + 5, (0, 255, 0), -1) + image = np.ascontiguousarray(cv2_image[:, :, ::-1]) # BGR to RGB. + outputs.append(image) + + return outputs diff --git a/inference.py b/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..6aaafcc5fe531c20b024526c439a11351fb7dd7d --- /dev/null +++ b/inference.py @@ -0,0 +1,86 @@ +import argparse + +import os +import torch + +from imaginaire.config import Config +from imaginaire.utils.cudnn import init_cudnn +from imaginaire.utils.dataset import get_test_dataloader +from imaginaire.utils.distributed import init_dist +from imaginaire.utils.gpu_affinity import set_affinity +from imaginaire.utils.io import get_checkpoint as get_checkpoint +from imaginaire.utils.logging import init_logging +from imaginaire.utils.trainer import \ + (get_model_optimizer_and_scheduler, set_random_seed) +import imaginaire.config + + +def parse_args(): + parser = argparse.ArgumentParser(description='Training') + parser.add_argument('--config', required=True, + help='Path to the training config file.') + parser.add_argument('--checkpoint', default='', + help='Checkpoint path.') + parser.add_argument('--output_dir', required=True, + help='Location to save the image outputs') + parser.add_argument('--logdir', + help='Dir for saving logs and models.') + parser.add_argument('--seed', type=int, default=0, + help='Random seed.') + parser.add_argument('--debug', action='store_true') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + cfg = Config(args.config) + imaginaire.config.DEBUG = args.debug + + if not hasattr(cfg, 'inference_args'): + cfg.inference_args = None + + # Create log directory for storing training results. + cfg.date_uid, cfg.logdir = init_logging(args.config, args.logdir) + + # Initialize cudnn. + init_cudnn(cfg.cudnn.deterministic, cfg.cudnn.benchmark) + + # Initialize data loaders and models. + net_G = get_model_optimizer_and_scheduler(cfg, seed=args.seed, generator_only=True) + + if args.checkpoint == '': + raise NotImplementedError("No checkpoint is provided for inference!") + + # Load checkpoint. + # trainer.load_checkpoint(cfg, args.checkpoint) + checkpoint = torch.load(args.checkpoint, map_location='cpu') + net_G.load_state_dict(checkpoint['net_G']) + + # Do inference. 
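+    # Unwrap the model container, switch to eval mode and freeze all
+    # parameters so no gradients are tracked during rendering.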
+ net_G = net_G.module + net_G.eval() + for name, param in net_G.named_parameters(): + param.requires_grad = False + torch.cuda.empty_cache() + device = torch.device('cuda') + rng_cuda = torch.Generator(device=device) + rng_cuda = rng_cuda.manual_seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed(args.seed) + + world_dir = os.path.join(args.output_dir) + os.makedirs(world_dir, exist_ok=True) + print('[PCGGenerator] Generating BEV scene representation...') + os.system('python terrain_generator.py --size {} --seed {} --outdir {}'.format(net_G.voxel.sample_size, args.seed, world_dir)) + net_G.voxel.next_world(device, world_dir, checkpoint) + cam_mode = cfg.inference_args.camera_mode + current_outdir = os.path.join(world_dir, 'camera_{:02d}'.format(cam_mode)) + os.makedirs(current_outdir, exist_ok=True) + os.makedirs(current_outdir, exist_ok=True) + z = torch.empty(1, net_G.style_dims, dtype=torch.float32, device=device) + z.normal_(generator=rng_cuda) + net_G.inference_givenstyle(z, current_outdir, **vars(cfg.inference_args)) + +if __name__ == "__main__": + main() diff --git a/scripts/batch_terrain_gen.py b/scripts/batch_terrain_gen.py new file mode 100644 index 0000000000000000000000000000000000000000..63a69a1635f01218f777a715f02b9f189b077435 --- /dev/null +++ b/scripts/batch_terrain_gen.py @@ -0,0 +1,48 @@ +import os +import time +import subprocess +from tqdm import tqdm +import argparse +from multiprocessing import Pool + +parser = argparse.ArgumentParser() +parser.add_argument('--size', type=int, default=1024) +parser.add_argument('--nbins', type=int, default=256) +parser.add_argument('--seed', type=int, default=762345) +parser.add_argument('--outdir', type=str, required=True) +parser.add_argument('--bs', type=int, default=1) +parser.add_argument("--num_workers", type=int, default=16) +parser.add_argument("--dry_run", action="store_true") +parser.add_argument("--parallel", action="store_true") +args = parser.parse_args() + + +def par_job(command): + if args.dry_run: + print(command) + else: + subprocess.call(command, shell=True) + + +if __name__ == "__main__": + t0 = time.time() + + cmd_list = [] + for idx in tqdm(range(args.bs)): + current_dir = os.path.join(args.outdir, 'world_{:04d}'.format(idx)) + cmd = 'python scripts/single_terrain_gen.py --size {} --seed {} --outdir {}'.format(args.size, args.seed + idx, current_dir) + if not args.parallel: + if args.dry_run: + print(cmd) + else: + subprocess.call(cmd, shell=True) + cmd_list.append(cmd) + + if args.parallel: + with Pool(processes=args.num_workers) as pool: + with tqdm(total=len(cmd_list)) as pbar: + for _ in tqdm(pool.imap_unordered(par_job, cmd_list)): + pbar.update() + t1 = time.time() + print("Finished in %.4f seconds" % (t1 - t0)) + os.system("stty sane") \ No newline at end of file diff --git a/scripts/build_lmdb.py b/scripts/build_lmdb.py new file mode 100644 index 0000000000000000000000000000000000000000..544afaa7b2b77f78f64c9e82ad60278437b00d21 --- /dev/null +++ b/scripts/build_lmdb.py @@ -0,0 +1,125 @@ +import copy +import shutil +import argparse +import json +import sys +import os +from tqdm import tqdm + +sys.path.append('.') +from imaginaire.utils.lmdb import create_metadata, \ + construct_file_path, check_and_add, build_lmdb # noqa: E402 +from imaginaire.config import Config # noqa: E402 + + +def parse_args(): + r"""Parse user input arguments""" + parser = argparse.ArgumentParser(description='Folder -> LMDB conversion') + parser.add_argument('--data_root', type=str, required=True, + help='Input 
data location.') + parser.add_argument('--config', type=str, required=True, + help='Config with label info.') + parser.add_argument('--output_root', type=str, required=True, + help='Output LMDB location') + parser.add_argument('--input_list', type=str, default='', + help='list of images that will be used.') + parser.add_argument('--metadata_factor', type=float, default=0.75, + help='Factor of filesize to allocate for metadata?') + parser.add_argument('--overwrite', default=False, action='store_true', + help='Overwrite output file if exists') + parser.add_argument('--paired', default=False, action='store_true', + help='Is the input data paired?') + parser.add_argument('--large', default=False, action='store_true', + help='Is the dataset large?') + parser.add_argument('--remove_missing', default=False, action='store_true', + help='Remove missing files from paired datasets?') + args = parser.parse_args() + return args + + +def main(): + r""" Build lmdb for training/testing. + Usage: + python scripts/build_lmdb.py \ + --config configs/data_image.yaml \ + --data_root /mnt/bigdata01/datasets/test_image \ + --output_root /mnt/bigdata01/datasets/test_image/lmdb_0/ \ + --overwrite + """ + args = parse_args() + cfg = Config(args.config) + + # Check if output file already exists. + if os.path.exists(args.output_root): + if args.overwrite: + print('Deleting existing output LMDB.') + shutil.rmtree(args.output_root) + else: + print('Output root LMDB already exists. Use --overwrite. ' + + 'Exiting...') + return + + all_filenames, extensions = \ + create_metadata(data_root=args.data_root, + cfg=cfg, + paired=args.paired, + input_list=args.input_list) + required_data_types = cfg.data.data_types + + # Build LMDB. + os.makedirs(args.output_root) + for data_type in required_data_types: + data_size = 0 + print('Data type:', data_type) + filepaths, keys = [], [] + print('>> Building file list.') + + # Get appropriate list of files. + if args.paired: + filenames = all_filenames + else: + filenames = all_filenames[data_type] + + for sequence in tqdm(filenames): + for filename in copy.deepcopy(filenames[sequence]): + filepath = construct_file_path( + args.data_root, data_type, sequence, filename, + extensions[data_type]) + key = '%s/%s' % (sequence, filename) + filesize = check_and_add(filepath, key, filepaths, keys, + remove_missing=args.remove_missing) + + # Remove file from list, if missing. + if filesize == -1 and args.paired and args.remove_missing: + print('Removing %s from list' % (filename)) + filenames[sequence].remove(filename) + data_size += filesize + + # Remove empty sequences. + if args.paired and args.remove_missing: + for sequence in copy.deepcopy(all_filenames): + if not all_filenames[sequence]: + all_filenames.pop(sequence) + + # Allocate size. + data_size = max(int((1 + args.metadata_factor) * data_size), 1e9) + print('Reserved size: %s, %dGB' % (data_type, data_size // 1e9)) + + # Write LMDB to file. + output_filepath = os.path.join(args.output_root, data_type) + build_lmdb(filepaths, keys, output_filepath, data_size, args.large) + + # Output list of all filenames. + if args.output_root: + with open(args.output_root + '/all_filenames.json', 'w') as fout: + json.dump(all_filenames, fout, indent=4) + + # Output metadata. 
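The size reservation here matters because LMDB needs its map size declared up front: the script sums the file sizes per data type, inflates the total by `--metadata_factor` to leave room for keys and bookkeeping, and floors the result at 1 GB before calling `build_lmdb`. A minimal sketch of the same idea using the `lmdb` package directly (the paths, keys, and output location below are illustrative; imaginaire's `build_lmdb` helper handles this internally):

```python
import os
import lmdb

filepaths = ["images/0001.jpg", "images/0002.jpg"]   # illustrative input files
keys = ["seq0/0001", "seq0/0002"]

# Reserve the map size: total payload inflated by a metadata factor, floored at 1 GB.
metadata_factor = 0.75
payload = sum(os.path.getsize(p) for p in filepaths)
map_size = max(int((1 + metadata_factor) * payload), int(1e9))

env = lmdb.open("output/images", map_size=map_size)
with env.begin(write=True) as txn:
    for key, path in zip(keys, filepaths):
        with open(path, "rb") as f:
            txn.put(key.encode("ascii"), f.read())   # store the raw encoded bytes per key
env.close()
```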
+ with open(args.output_root + '/metadata.json', 'w') as fout: + json.dump(extensions, fout, indent=4) + else: + return all_filenames, extensions + + +if __name__ == "__main__": + main() diff --git a/scripts/pcg_cache.py b/scripts/pcg_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..d50e76340298e33da7461163f77024e354ed866a --- /dev/null +++ b/scripts/pcg_cache.py @@ -0,0 +1,127 @@ +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import csv +import time +import random +import cv2 +import os +import torch +from tqdm import tqdm +from argparse import ArgumentParser + +if __name__ == '__main__': + parser = ArgumentParser() + parser.add_argument('--terrain', type=str, required=True, help='directory path to terrain dataset') + parser.add_argument('--outdir', type=str, required=True) + assert os.path.exists("./scenedreamer_released.pt") + pcg_asset = torch.load("./scenedreamer_released.pt", map_location='cpu') + args = parser.parse_args() + terrain_dir = args.terrain + outdir = args.outdir + sample_height = 256 + sample_size = 1024 + os.makedirs(outdir, exist_ok=True) + + trees_models = pcg_asset['assets'] + + # can be customized + biome_trees_dict = { + 'desert': [], + 'savanna': [5], + 'twoodland': [1, 7], + 'tundra': [], + 'seasonal forest': [1, 2], + 'rainforest': [1, 2, 3], + 'temp forest': [4], + 'temp rainforest': [0, 3], + 'boreal': [5,6,7], + 'water': [], + } + + biome2mclabels = torch.tensor([28, 9, 8, 1, 9, 8, 9, 8, 30, 26], dtype=torch.int32) + biome_names = list(biome_trees_dict.keys()) + chunk_grid_x, chunk_grid_y = torch.meshgrid(torch.arange(sample_size), torch.arange(sample_size)) + + terrain_list = os.listdir(terrain_dir) + for world in tqdm(terrain_list): + voxel_t = torch.zeros(sample_height, sample_size, sample_size).to(torch.int32) + current_dir = os.path.join(terrain_dir, world) + height_map = np.load(os.path.join(current_dir, 'biome_rivers_height.npy')) + height_map[height_map < 0] = 0 + height_map = ((height_map - height_map.min()) / (1 - height_map.min()) * (sample_height - 1)).astype(np.int16) + semantic_map = cv2.imread(os.path.join(current_dir, 'biome_rivers_labels.png'), 0) + tree_map = cv2.imread(os.path.join(current_dir, 'biome_trees_dist.png'), 0) + total_size = height_map.shape[0] + crop_pos_x, crop_pos_y = np.random.randint(0, total_size - sample_size, size=2) + org_height_map = height_map[crop_pos_x: crop_pos_x + sample_size, crop_pos_y: crop_pos_y + sample_size].astype(int) + chunk_height_map = torch.from_numpy(org_height_map)[None, ...] 
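The caching loop that follows converts each generated world from a 2.5D representation (a per-column terrain height plus a per-column semantic label) into a dense 3D label volume. The core step is a `scatter_` along the vertical axis, repeated over a few adjacent layers so the terrain crust has thickness. A toy-sized sketch of that scatter; the shapes are illustrative, the script itself works at 256x1024x1024:

```python
import torch

H, S = 64, 8                                                  # toy voxel height and map size
height = torch.randint(0, H, (1, S, S))                       # per-column terrain height (int64 index)
labels = torch.randint(1, 5, (1, S, S), dtype=torch.int32)    # per-column semantic label

voxel = torch.zeros(H, S, S, dtype=torch.int32)
voxel.scatter_(0, height, labels)                             # write the surface layer

# Give the crust thickness: repeat the label over a few adjacent layers along the height axis.
for step in range(1, 4):
    voxel.scatter_(0, (height + step).clamp(0, H - 1), labels)
```

Index and source tensors broadcast against the full volume along the height dimension, which is what lets a single scatter place one label per column.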
+ chunk_semantic_map = semantic_map[crop_pos_x: crop_pos_x + sample_size, crop_pos_y: crop_pos_y + sample_size] + chunk_trees_map = tree_map[crop_pos_x: crop_pos_x + sample_size, crop_pos_y: crop_pos_y + sample_size] + org_semantic_map = torch.from_numpy(chunk_semantic_map.copy()) + org_semantic_map[chunk_trees_map != 255] = 10 + chunk_semantic_map = biome2mclabels[torch.from_numpy(chunk_semantic_map)[None, ...].long().contiguous()] + voxel_t = voxel_t.scatter_(0, chunk_height_map, chunk_semantic_map) + for preproc_step in range(8): + voxel_t = voxel_t.scatter(0, torch.clip(chunk_height_map + preproc_step + 1, 0, sample_height - 1), chunk_semantic_map) + + chunk_height_map = chunk_height_map + 8 + chunk_height_map = chunk_height_map[0] + boundary_detect = 50 + for biome_id in range(biome2mclabels.shape[0]): + tree_pos_mask = (chunk_trees_map == biome_id) + tree_pos_x = chunk_grid_x[tree_pos_mask] + tree_pos_y = chunk_grid_y[tree_pos_mask] + tree_pos_h = chunk_height_map[tree_pos_mask] + assert len(tree_pos_x) == len(tree_pos_y) + selected_trees = biome_trees_dict[biome_names[biome_id]] + if len(selected_trees) == 0: + continue + for idx in range(len(tree_pos_x)): + if tree_pos_x[idx] < boundary_detect or tree_pos_x[idx] > sample_size - boundary_detect or tree_pos_y[idx] < boundary_detect or tree_pos_y[idx] > sample_size - boundary_detect or tree_pos_h[idx] > sample_height - boundary_detect: + # FIXME: hack, to avoid out of index near the boundary + continue + tree_id = random.choice(selected_trees) + tmp = voxel_t[tree_pos_h[idx]: tree_pos_h[idx] + trees_models[tree_id].shape[0], tree_pos_x[idx]: tree_pos_x[idx] + trees_models[tree_id].shape[1], tree_pos_y[idx]: tree_pos_y[idx] + trees_models[tree_id].shape[2]] + tmp_mask = (tmp == 0) + try: + voxel_t[tree_pos_h[idx]: tree_pos_h[idx] + trees_models[tree_id].shape[0], tree_pos_x[idx]: tree_pos_x[idx] + trees_models[tree_id].shape[1], tree_pos_y[idx]: tree_pos_y[idx] + trees_models[tree_id].shape[2]][tmp_mask] = trees_models[tree_id][tmp_mask] + except: + print('height?', tree_pos_h[idx]) + print(tmp_mask.shape) + print(tmp.shape) + print(trees_models[tree_id].shape) + print(voxel_t.shape) + print(tree_id) + raise NotImplementedError + trans_mat = torch.eye(4) # Transform voxel to world + # Generate heightmap for camera trajectory generation + m, h = torch.max((torch.flip(voxel_t, [0]) != 0).int(), dim=0, keepdim=False) + heightmap = voxel_t.shape[0] - 1 - h + heightmap[m == 0] = 0 # Special case when the whole vertical column is empty + voxel_t = voxel_t.numpy() + voxel_value = voxel_t[voxel_t != 0] + voxel_x, voxel_y, voxel_z = np.where(voxel_t != 0) + current_height_map = (chunk_height_map / (sample_height - 1))[None, None, ...] + current_semantic_map = F.one_hot(org_semantic_map.to(torch.int64)).to(torch.float).permute(2, 0, 1)[None, ...] 
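Right after tree placement, the script derives a camera-trajectory height map from the voxel grid itself. The trick is that `torch.max` returns the index of the *first* maximal element along a dimension, so flipping the occupancy mask along the height axis makes that first hit land on the topmost non-empty voxel. A small sketch with toy sizes:

```python
import torch

H, S = 16, 4
voxel = torch.zeros(H, S, S, dtype=torch.int32)
voxel[3, 0, 0] = 1    # one column whose top surface sits at height 3
voxel[7, 1, 2] = 2    # another at height 7

# Flip the height axis so the first maximal element is the topmost occupied voxel.
occupied, idx = torch.max((torch.flip(voxel, [0]) != 0).int(), dim=0)
heightmap = voxel.shape[0] - 1 - idx
heightmap[occupied == 0] = 0      # columns with no occupied voxel get height 0
```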
+ semantic_map = torch.argmax(current_semantic_map, dim=1) + print('semantic map after one hot and argmax', torch.unique(semantic_map, return_counts=True)) + print(current_height_map.shape) + print(current_semantic_map.shape) + print(heightmap.shape) + print(voxel_t.shape) + print(voxel_value.shape) + print(voxel_x.shape) + print(voxel_y.shape) + print(voxel_z.shape) + print(voxel_z.dtype) + voxel_sparse = np.stack([voxel_x, voxel_y, voxel_z, voxel_value]) + print(voxel_sparse.shape) + current_outdir = os.path.join(outdir, world) + os.makedirs(current_outdir, exist_ok=True) + np.save(os.path.join(current_outdir, 'voxel_sparse.npy'), voxel_sparse.astype(np.int16)) + np.save(os.path.join(current_outdir, 'height_map.npy'), current_height_map.numpy()) + np.save(os.path.join(current_outdir, 'semantic_map.npy'), current_semantic_map.numpy()) + np.save(os.path.join(current_outdir, 'hmap_mc.npy'), heightmap.numpy()) diff --git a/scripts/single_terrain_gen.py b/scripts/single_terrain_gen.py new file mode 100644 index 0000000000000000000000000000000000000000..28e175fa03699740b1bfa578b5e12ff387346c51 --- /dev/null +++ b/scripts/single_terrain_gen.py @@ -0,0 +1,468 @@ +import numpy as np +from matplotlib import pyplot as plt +from scipy.spatial import Voronoi +from skimage.draw import polygon +from PIL import Image +from noise import snoise3 +from skimage import exposure +from scipy.interpolate import interp1d +from scipy import ndimage +import cv2 +from scipy.ndimage import gaussian_filter +from scipy.ndimage import binary_dilation + +from argparse import ArgumentParser +import os + +def save_height_map(height_map, file_name): + #input height map should be float, raw output of noise map + normalized_height_map = (((height_map - height_map.min()) / (height_map.max() - height_map.min()))*255).astype(np.uint8) + cv2.imwrite(file_name, normalized_height_map) + np.save(file_name[:-4] + '.npy', height_map) + +def get_boundary(vor_map, kernel=1): + boundary_map = np.zeros_like(vor_map, dtype=bool) + n, m = vor_map.shape + + clip = lambda x: max(0, min(size-1, x)) + def check_for_mult(a): + b = a[0] + for i in range(len(a)-1): + if a[i] != b: return 1 + return 0 + + for i in range(n): + for j in range(m): + boundary_map[i, j] = check_for_mult(vor_map[ + clip(i-kernel):clip(i+kernel+1), + clip(j-kernel):clip(j+kernel+1), + ].flatten()) + + return boundary_map + +def gradient(im_smooth): + gradient_x = im_smooth.astype(float) + gradient_y = im_smooth.astype(float) + + kernel = np.arange(-1,2).astype(float) + kernel = - kernel / 2 + + gradient_x = ndimage.convolve(gradient_x, kernel[np.newaxis]) + gradient_y = ndimage.convolve(gradient_y, kernel[np.newaxis].T) + + return gradient_x, gradient_y + +def sobel(im_smooth): + gradient_x = im_smooth.astype(float) + gradient_y = im_smooth.astype(float) + + kernel = np.array([[-1,0,1],[-2,0,2],[-1,0,1]]) + + gradient_x = ndimage.convolve(gradient_x, kernel) + gradient_y = ndimage.convolve(gradient_y, kernel.T) + + return gradient_x, gradient_y + +def compute_normal_map(gradient_x, gradient_y, intensity=1): + width = gradient_x.shape[1] + height = gradient_x.shape[0] + max_x = np.max(gradient_x) + max_y = np.max(gradient_y) + + max_value = max_x + + if max_y > max_x: + max_value = max_y + + normal_map = np.zeros((height, width, 3), dtype=np.float32) + + intensity = 1 / intensity + + strength = max_value / (max_value * intensity) + + normal_map[..., 0] = gradient_x / max_value + normal_map[..., 1] = gradient_y / max_value + normal_map[..., 2] = 1 / strength + + norm = 
np.sqrt(np.power(normal_map[..., 0], 2) + np.power(normal_map[..., 1], 2) + np.power(normal_map[..., 2], 2)) + + normal_map[..., 0] /= norm + normal_map[..., 1] /= norm + normal_map[..., 2] /= norm + + normal_map *= 0.5 + normal_map += 0.5 + + return normal_map + + +def get_normal_map(im, intensity=1.0): + sobel_x, sobel_y = sobel(im) + normal_map = compute_normal_map(sobel_x, sobel_y, intensity) + return normal_map + +def get_normal_light(height_map_): + normal_map_ = get_normal_map(height_map_)[:,:,0:2].mean(axis=2) + normal_map_ = np.interp(normal_map_, (0, 1), (-1, 1)) + return normal_map_ + +def apply_height_map(im_map, smooth_map, height_map, land_mask): + normal_map = get_normal_light(height_map) + normal_map = normal_map*land_mask + smooth_map/2*(~land_mask) + + normal_map = np.interp(normal_map, (-1, 1), (-192, 192)) + + normal_map_color = np.repeat(normal_map[:, :, np.newaxis], 3, axis=-1) + normal_map_color = normal_map_color.astype(int) + + out_map = im_map + normal_map_color + return out_map, normal_map + + +def histeq(img, alpha=1): + img_cdf, bin_centers = exposure.cumulative_distribution(img) + img_eq = np.interp(img, bin_centers, img_cdf) + img_eq = np.interp(img_eq, (0, 1), (-1, 1)) + return alpha * img_eq + (1 - alpha) * img + +def voronoi(points, size): + # Add points at edges to eliminate infinite ridges + edge_points = size*np.array([[-1, -1], [-1, 2], [2, -1], [2, 2]]) + new_points = np.vstack([points, edge_points]) + + # Calculate Voronoi tessellation + vor = Voronoi(new_points) + + return vor + +def voronoi_map(vor, size): + # Calculate Voronoi map + vor_map = np.zeros((size, size), dtype=np.uint32) + + for i, region in enumerate(vor.regions): + # Skip empty regions and infinte ridge regions + if len(region) == 0 or -1 in region: continue + # Get polygon vertices + x, y = np.array([vor.vertices[i][::-1] for i in region]).T + # Get pixels inside polygon + rr, cc = polygon(x, y) + # Remove pixels out of image bounds + in_box = np.where((0 <= rr) & (rr < size) & (0 <= cc) & (cc < size)) + rr, cc = rr[in_box], cc[in_box] + # Paint image + vor_map[rr, cc] = i + + return vor_map + +# Lloyd's relaxation +def relax(points, size, k=10): + new_points = points.copy() + for _ in range(k): + vor = voronoi(new_points, size) + new_points = [] + for i, region in enumerate(vor.regions): + if len(region) == 0 or -1 in region: continue + poly = np.array([vor.vertices[i] for i in region]) + center = poly.mean(axis=0) + new_points.append(center) + new_points = np.array(new_points).clip(0, size) + return new_points + +def noise_map(size, res, seed, octaves=1, persistence=0.5, lacunarity=2.0): + scale = size/res + return np.array([[ + snoise3( + (x+0.1)/scale, + y/scale, + seed+map_seed, + octaves=octaves, + persistence=persistence, + lacunarity=lacunarity + ) + for x in range(size)] + for y in range(size) + ]) + +def average_cells(vor, data): + """Returns the average value of data inside every voronoi cell""" + size = vor.shape[0] + count = np.max(vor)+1 + + sum_ = np.zeros(count) + count = np.zeros(count) + + for i in range(size): + for j in range(size): + p = vor[i, j] + count[p] += 1 + sum_[p] += data[i, j] + + average = sum_/ (count + 1e-3) + average[count==0] = 0 + + return average + +def fill_cells(vor, data): + size = vor.shape[0] + image = np.zeros((size, size)) + + for i in range(size): + for j in range(size): + p = vor[i, j] + image[i, j] = data[p] + + return image + +def color_cells(vor, data, dtype=int): + size = vor.shape[0] + image = np.zeros((size, size, 3)) + + for i in 
range(size): + for j in range(size): + p = vor[i, j] + image[i, j] = data[p] + + return image.astype(dtype) + +def quantize(data, n): + bins = np.linspace(-1, 1, n+1) + return (np.digitize(data, bins) - 1).clip(0, n-1) + + +def bezier(x1, y1, x2, y2, a): + p1 = np.array([0, 0]) + p2 = np.array([x1, y1]) + p3 = np.array([x2, y2]) + p4 = np.array([1, a]) + + return lambda t: ((1-t)**3 * p1 + 3*(1-t)**2*t * p2 + 3*(1-t)*t**2 * p3 + t**3 * p4) + +def bezier_lut(x1, y1, x2, y2, a): + t = np.linspace(0, 1, 256) + f = bezier(x1, y1, x2, y2, a) + curve = np.array([f(t_) for t_ in t]) + + return interp1d(*curve.T) + +def filter_map(h_map, smooth_h_map, x1, y1, x2, y2, a, b): + f = bezier_lut(x1, y1, x2, y2, a) + output_map = b*h_map + (1-b)*smooth_h_map + output_map = f(output_map.clip(0, 1)) + return output_map + +def filter_inbox(pts): + inidx = np.all(pts < size, axis=1) + return pts[inidx] + +def generate_trees(n): + trees = np.random.randint(0, size-1, (n, 2)) + trees = relax(trees, size, k=10).astype(np.uint32) + trees = filter_inbox(trees) + return trees + +def place_trees(n, mask, a=0.5): + trees= generate_trees(n) + rr, cc = trees.T + + output_trees = np.zeros((size, size), dtype=bool) + output_trees[rr, cc] = True + output_trees = output_trees*(mask>a)*river_land_mask*(adjusted_height_river_map<0.5) + + output_trees = np.array(np.where(output_trees == 1))[::-1].T + return output_trees + +if __name__ == '__main__': + # for label conditioned discriminator in SceneDreamer + # ignore, sky, tree, dirt, flower, grass, gravel, water, rock, stone, sand, and snow + biome_names = [ + # sand and rock + "desert", + # grass gravel rock stone + "savanna", # mixed woodland and grassland + # trees flower + "tropical_woodland", # rainforest + # dirt grass gravel rock stone + "tundra", # no trees + # trees flower + "seasonal_forest", + # trees + "rainforest", + # trees + "temperate_forest", + # trees + "temperate_rainforest", + # snow rock tree + "boreal_forest" # taiga, snow forest + ] + biome_colors = [ + [255, 255, 178], + [184, 200, 98], + [188, 161, 53], + [190, 255, 242], + [106, 144, 38], + [33, 77, 41], + [86, 179, 106], + [34, 61, 53], + [35, 114, 94] + ] + + + parser = ArgumentParser() + parser.add_argument('--size', type=int, default=1024) + parser.add_argument('--nbins', type=int, default=256) + parser.add_argument('--seed', type=int, default=762345) + parser.add_argument('--outdir', type=str, required=True) + args = parser.parse_args() + size = args.size + n = args.nbins + map_seed = args.seed + outdir = args.outdir + os.makedirs(outdir, exist_ok=True) + np.random.seed(map_seed) + + # start generation + points = np.random.randint(0, size, (514, 2)) + points = relax(points, size, k=100) + vor = voronoi(points, size) + vor_map = voronoi_map(vor, size) + + boundary_displacement = 8 + boundary_noise = np.dstack([noise_map(size, 32, 200, octaves=8), noise_map(size, 32, 250, octaves=8)]) + boundary_noise = np.indices((size, size)).T + boundary_displacement*boundary_noise + boundary_noise = boundary_noise.clip(0, size-1).astype(np.uint32) + + blurred_vor_map = np.zeros_like(vor_map) + + for x in range(size): + for y in range(size): + j, i = boundary_noise[x, y] + blurred_vor_map[x, y] = vor_map[i, j] + + vor_map = blurred_vor_map + temperature_map = noise_map(size, 4, 10) + precipitation_map = noise_map(size, 4, 20) + + uniform_temperature_map = histeq(temperature_map, alpha=0.33) + uniform_precipitation_map = histeq(precipitation_map, alpha=0.33) + temperature_map = uniform_temperature_map + 
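With the temperature and precipitation maps equalized, each Voronoi cell's climate pair is quantized and used to index a biome lookup table derived from `assets/biome_image.png` (a Whittaker-diagram-style image). A toy sketch of that lookup; the 4x4 table below is made up for illustration, while the real table is 256x256 and built from the image's colors:

```python
import numpy as np

def quantize(data, n):
    # Map values in [-1, 1] onto integer bins 0 .. n-1.
    bins = np.linspace(-1, 1, n + 1)
    return (np.digitize(data, bins) - 1).clip(0, n - 1)

# Illustrative 4x4 biome table indexed by (temperature bin, precipitation bin).
biomes = np.array([
    [3, 3, 3, 3],     # coldest row: tundra everywhere
    [1, 1, 4, 4],
    [1, 6, 6, 7],
    [0, 1, 2, 5],     # hottest row: desert -> rainforest as precipitation rises
])

temperature_cells = np.array([0.9, -0.8, 0.1])     # one climate value per Voronoi cell
precipitation_cells = np.array([0.7, -0.2, 0.95])

t = quantize(temperature_cells, 4)
p = quantize(precipitation_cells, 4)
print(biomes[t, p])                                # one biome label per cell
```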
precipitation_map = uniform_precipitation_map + + temperature_cells = average_cells(vor_map, temperature_map) + precipitation_cells = average_cells(vor_map, precipitation_map) + + quantize_temperature_cells = quantize(temperature_cells, n) + quantize_precipitation_cells = quantize(precipitation_cells, n) + + quantize_temperature_map = fill_cells(vor_map, quantize_temperature_cells) + quantize_precipitation_map = fill_cells(vor_map, quantize_precipitation_cells) + + temperature_cells = quantize_temperature_cells + precipitation_cells = quantize_precipitation_cells + + temperature_map = quantize_temperature_map + precipitation_map = quantize_precipitation_map + + im = np.array(Image.open("./assets/biome_image.png"))[:, :, :3] + im = cv2.resize(im, (256, 256)) + biomes = np.zeros((256, 256)) + + for i, color in enumerate(biome_colors): + indices = np.where(np.all(im == color, axis=-1)) + biomes[indices] = i + biomes = np.flip(biomes, axis=0).T + + + n = len(temperature_cells) + biome_cells = np.zeros(n, dtype=np.uint32) + + for i in range(n): + temp, precip = temperature_cells[i], precipitation_cells[i] + biome_cells[i] = biomes[temp, precip] + + biome_map = fill_cells(vor_map, biome_cells).astype(np.uint32) + biome_color_map = color_cells(biome_map, biome_colors) + fig = plt.figure(figsize=(5, 5), dpi=150) + plt.imsave(os.path.join(outdir, 'biome_map.png'), biome_color_map.astype(np.uint8)) + + height_map = noise_map(size, 4, 0, octaves=6, persistence=0.5, lacunarity=2) + land_mask = height_map > 0 + smooth_height_map = noise_map(size, 4, 0, octaves=1, persistence=0.5, lacunarity=2) + + biome_height_maps = [ + # Desert + filter_map(height_map, smooth_height_map, 0.75, 0.2, 0.95, 0.2, 0.2, 0.5), + # Savanna + filter_map(height_map, smooth_height_map, 0.5, 0.1, 0.95, 0.1, 0.1, 0.2), + # Tropical Woodland + filter_map(height_map, smooth_height_map, 0.33, 0.33, 0.95, 0.1, 0.1, 0.75), + # Tundra + filter_map(height_map, smooth_height_map, 0.5, 1, 0.25, 1, 1, 1), + # Seasonal Forest + filter_map(height_map, smooth_height_map, 0.75, 0.5, 0.4, 0.4, 0.33, 0.2), + # Rainforest + filter_map(height_map, smooth_height_map, 0.5, 0.25, 0.66, 1, 1, 0.5), + # Temperate forest + filter_map(height_map, smooth_height_map, 0.75, 0.5, 0.4, 0.4, 0.33, 0.33), + # Temperate Rainforest + filter_map(height_map, smooth_height_map, 0.75, 0.5, 0.4, 0.4, 0.33, 0.33), + # Boreal + filter_map(height_map, smooth_height_map, 0.8, 0.1, 0.9, 0.05, 0.05, 0.1) + ] + + + biome_count = len(biome_names) + biome_masks = np.zeros((biome_count, size, size)) + + for i in range(biome_count): + biome_masks[i, biome_map==i] = 1 + biome_masks[i] = gaussian_filter(biome_masks[i], sigma=16) + + # Remove ocean from masks + blurred_land_mask = land_mask + blurred_land_mask = binary_dilation(land_mask, iterations=32).astype(np.float64) + blurred_land_mask = gaussian_filter(blurred_land_mask, sigma=16) + + # biome mask - [9, size, size] + biome_masks = biome_masks*blurred_land_mask + + for biome_id in range(biome_masks.shape[0]): + cv2.imwrite(os.path.join(outdir, 'biome_mask_{:02d}.png'.format(biome_id)), biome_masks[biome_id]*255) + + adjusted_height_map = height_map.copy() + + for i in range(len(biome_height_maps)): + adjusted_height_map = (1-biome_masks[i])*adjusted_height_map + biome_masks[i]*biome_height_maps[i] + + # add rivers + biome_bound = get_boundary(biome_map, kernel=5) + cell_bound = get_boundary(vor_map, kernel=2) + river_mask = noise_map(size, 4, 4353, octaves=6, persistence=0.5, lacunarity=2) > 0 + + new_biome_bound = 
biome_bound*(adjusted_height_map<0.5)*land_mask + new_cell_bound = cell_bound*(adjusted_height_map<0.05)*land_mask + + rivers = np.logical_or(new_biome_bound, new_cell_bound)*river_mask + loose_river_mask = binary_dilation(rivers, iterations=8) + rivers_height = gaussian_filter(rivers.astype(np.float64), sigma=2)*loose_river_mask + adjusted_height_river_map = adjusted_height_map*(1-rivers_height) - 0.05*rivers + + sea_color = np.array([12, 14, 255]) + river_land_mask = adjusted_height_river_map >= 0 + land_mask_color = np.repeat(river_land_mask[:, :, np.newaxis], 3, axis=-1) + rivers_biome_color_map = land_mask_color*biome_color_map + (1-land_mask_color)*sea_color + rivers_biome_map = river_land_mask * biome_map + (1 - river_land_mask) * biome_count # use biome count=9 as water indicator + cv2.imwrite(os.path.join(outdir, 'biome_rivers_labels.png'), rivers_biome_map.astype(np.uint8)) + plt.imsave(os.path.join(outdir, 'biome_rivers.png'), rivers_biome_color_map.astype(np.uint8)) + biome_rivers_normal, _ = apply_height_map(rivers_biome_color_map, adjusted_height_map, adjusted_height_map, land_mask) + plt.imsave(os.path.join(outdir, 'biome_rivers_normal.png'), np.clip(biome_rivers_normal, 0, 255).astype(np.uint8)) + save_height_map(adjusted_height_river_map, os.path.join(outdir, 'biome_rivers_height.png')) + save_height_map(adjusted_height_map, os.path.join(outdir, 'biome_height.png')) + + tree_densities = [4000, 1500, 8000, 1000, 10000, 25000, 10000, 20000, 5000] + trees = [np.array(place_trees(tree_densities[i]*4, biome_masks[i])) for i in range(len(biome_names))] + + canvas = np.ones((size, size)) * 255 + plt.figure(dpi=150, figsize=(5, 5)) + for k in range(len(biome_names)): + canvas[trees[k][:, 1], trees[k][:, 0]] = k + cv2.imwrite(os.path.join(outdir, 'biome_trees_dist.png'), canvas) \ No newline at end of file diff --git a/terrain_generator.py b/terrain_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..18fb40869999149ffc8d0785c2d9a80ae0e0e0ff --- /dev/null +++ b/terrain_generator.py @@ -0,0 +1,383 @@ +import numpy as np +from scipy.spatial import Voronoi +from skimage.draw import polygon +from PIL import Image +from noise import snoise3 +from skimage import exposure +from scipy.interpolate import interp1d +import cv2 +from scipy.ndimage import gaussian_filter +from scipy.ndimage import binary_dilation + +from argparse import ArgumentParser + +def save_height_map(height_map, file_name): + #input height map should be float, raw output of noise map + normalized_height_map = (((height_map - height_map.min()) / (height_map.max() - height_map.min()))*255).astype(np.uint8) + cv2.imwrite(file_name, normalized_height_map) + np.save(file_name[:-4] + '.npy', height_map) + +def get_boundary(vor_map, size, kernel=1): + boundary_map = np.zeros_like(vor_map, dtype=bool) + n, m = vor_map.shape + + clip = lambda x: max(0, min(size-1, x)) + def check_for_mult(a): + b = a[0] + for i in range(len(a)-1): + if a[i] != b: return 1 + return 0 + + for i in range(n): + for j in range(m): + boundary_map[i, j] = check_for_mult(vor_map[ + clip(i-kernel):clip(i+kernel+1), + clip(j-kernel):clip(j+kernel+1), + ].flatten()) + + return boundary_map + +def histeq(img, alpha=1): + img_cdf, bin_centers = exposure.cumulative_distribution(img) + img_eq = np.interp(img, bin_centers, img_cdf) + img_eq = np.interp(img_eq, (0, 1), (-1, 1)) + return alpha * img_eq + (1 - alpha) * img + +def voronoi(points, size): + # Add points at edges to eliminate infinite ridges + edge_points = 
size*np.array([[-1, -1], [-1, 2], [2, -1], [2, 2]]) + new_points = np.vstack([points, edge_points]) + + # Calculate Voronoi tessellation + vor = Voronoi(new_points) + + return vor + +def voronoi_map(vor, size): + # Calculate Voronoi map + vor_map = np.zeros((size, size), dtype=np.uint32) + + for i, region in enumerate(vor.regions): + # Skip empty regions and infinte ridge regions + if len(region) == 0 or -1 in region: continue + # Get polygon vertices + x, y = np.array([vor.vertices[i][::-1] for i in region]).T + # Get pixels inside polygon + rr, cc = polygon(x, y) + # Remove pixels out of image bounds + in_box = np.where((0 <= rr) & (rr < size) & (0 <= cc) & (cc < size)) + rr, cc = rr[in_box], cc[in_box] + # Paint image + vor_map[rr, cc] = i + + return vor_map + +# Lloyd's relaxation +def relax(points, size, k=10): + new_points = points.copy() + for _ in range(k): + vor = voronoi(new_points, size) + new_points = [] + for i, region in enumerate(vor.regions): + if len(region) == 0 or -1 in region: continue + poly = np.array([vor.vertices[i] for i in region]) + center = poly.mean(axis=0) + new_points.append(center) + new_points = np.array(new_points).clip(0, size) + return new_points + +def noise_map(size, res, seed, octaves=1, persistence=0.5, lacunarity=2.0): + scale = size/res + return np.array([[ + snoise3( + (x+0.1)/scale, + y/scale, + seed, + octaves=octaves, + persistence=persistence, + lacunarity=lacunarity + ) + for x in range(size)] + for y in range(size) + ]) + +def average_cells(vor, data): + """Returns the average value of data inside every voronoi cell""" + size = vor.shape[0] + count = np.max(vor)+1 + + sum_ = np.zeros(count) + count = np.zeros(count) + + for i in range(size): + for j in range(size): + p = vor[i, j] + count[p] += 1 + sum_[p] += data[i, j] + + average = sum_/ (count + 1e-3) + average[count==0] = 0 + + return average + +def fill_cells(vor, data): + size = vor.shape[0] + image = np.zeros((size, size)) + + for i in range(size): + for j in range(size): + p = vor[i, j] + image[i, j] = data[p] + + return image + +def color_cells(vor, data, dtype=int): + size = vor.shape[0] + image = np.zeros((size, size, 3)) + + for i in range(size): + for j in range(size): + p = vor[i, j] + image[i, j] = data[p] + + return image.astype(dtype) + +def quantize(data, n): + bins = np.linspace(-1, 1, n+1) + return (np.digitize(data, bins) - 1).clip(0, n-1) + + +def bezier(x1, y1, x2, y2, a): + p1 = np.array([0, 0]) + p2 = np.array([x1, y1]) + p3 = np.array([x2, y2]) + p4 = np.array([1, a]) + + return lambda t: ((1-t)**3 * p1 + 3*(1-t)**2*t * p2 + 3*(1-t)*t**2 * p3 + t**3 * p4) + +def bezier_lut(x1, y1, x2, y2, a): + t = np.linspace(0, 1, 256) + f = bezier(x1, y1, x2, y2, a) + curve = np.array([f(t_) for t_ in t]) + + return interp1d(*curve.T) + +def filter_map(h_map, smooth_h_map, x1, y1, x2, y2, a, b): + f = bezier_lut(x1, y1, x2, y2, a) + output_map = b*h_map + (1-b)*smooth_h_map + output_map = f(output_map.clip(0, 1)) + return output_map + +def filter_inbox(pts, size): + inidx = np.all(pts < size, axis=1) + return pts[inidx] + +def generate_trees(n, size): + trees = np.random.randint(0, size-1, (n, 2)) + trees = relax(trees, size, k=10).astype(np.uint32) + trees = filter_inbox(trees, size) + return trees + +def place_trees(river_land_mask, adjusted_height_river_map, n, mask, size, a=0.5): + trees= generate_trees(n, size) + rr, cc = trees.T + + output_trees = np.zeros((size, size), dtype=bool) + output_trees[rr, cc] = True + output_trees = 
output_trees*(mask>a)*river_land_mask*(adjusted_height_river_map<0.5) + + output_trees = np.array(np.where(output_trees == 1))[::-1].T + return output_trees + + +def PCGGen(map_size, nbins = 256, seed = 3407): + biome_names = [ + # sand and rock + "desert", + # grass gravel rock stone + "savanna", # mixed woodland and grassland + # trees flower + "tropical_woodland", # rainforest + # dirt grass gravel rock stone + "tundra", # no trees + # trees flower + "seasonal_forest", + # trees + "rainforest", + # trees + "temperate_forest", + # trees + "temperate_rainforest", + # snow rock tree + "boreal_forest" # taiga, snow forest + ] + biome_colors = [ + [255, 255, 178], + [184, 200, 98], + [188, 161, 53], + [190, 255, 242], + [106, 144, 38], + [33, 77, 41], + [86, 179, 106], + [34, 61, 53], + [35, 114, 94] + ] + + size = map_size + n = nbins + map_seed = seed + + # start generation + points = np.random.randint(0, size, (514, 2)) + points = relax(points, size, k=100) + vor = voronoi(points, size) + vor_map = voronoi_map(vor, size) + + boundary_displacement = 8 + boundary_noise = np.dstack([noise_map(size, 32, 200 + map_seed, octaves=8), noise_map(size, 32, 250 + map_seed, octaves=8)]) + boundary_noise = np.indices((size, size)).T + boundary_displacement*boundary_noise + boundary_noise = boundary_noise.clip(0, size-1).astype(np.uint32) + + blurred_vor_map = np.zeros_like(vor_map) + + for x in range(size): + for y in range(size): + j, i = boundary_noise[x, y] + blurred_vor_map[x, y] = vor_map[i, j] + + vor_map = blurred_vor_map + temperature_map = noise_map(size, 2, 10 + map_seed) + precipitation_map = noise_map(size, 2, 20 + map_seed) + + uniform_temperature_map = histeq(temperature_map, alpha=0.33) + uniform_precipitation_map = histeq(precipitation_map, alpha=0.33) + temperature_map = uniform_temperature_map + precipitation_map = uniform_precipitation_map + + temperature_cells = average_cells(vor_map, temperature_map) + precipitation_cells = average_cells(vor_map, precipitation_map) + + quantize_temperature_cells = quantize(temperature_cells, n) + quantize_precipitation_cells = quantize(precipitation_cells, n) + + quantize_temperature_map = fill_cells(vor_map, quantize_temperature_cells) + quantize_precipitation_map = fill_cells(vor_map, quantize_precipitation_cells) + + temperature_cells = quantize_temperature_cells + precipitation_cells = quantize_precipitation_cells + + temperature_map = quantize_temperature_map + precipitation_map = quantize_precipitation_map + + im = np.array(Image.open("./assets/biome_image.png"))[:, :, :3] + im = cv2.resize(im, (256, 256)) + biomes = np.zeros((256, 256)) + + for i, color in enumerate(biome_colors): + indices = np.where(np.all(im == color, axis=-1)) + biomes[indices] = i + biomes = np.flip(biomes, axis=0).T + + + n = len(temperature_cells) + biome_cells = np.zeros(n, dtype=np.uint32) + + for i in range(n): + temp, precip = temperature_cells[i], precipitation_cells[i] + biome_cells[i] = biomes[temp, precip] + + biome_map = fill_cells(vor_map, biome_cells).astype(np.uint32) + biome_color_map = color_cells(biome_map, biome_colors) + height_map = noise_map(size, 4, 0 + map_seed, octaves=6, persistence=0.5, lacunarity=2) + land_mask = height_map > 0 + smooth_height_map = noise_map(size, 4, 0 + map_seed, octaves=1, persistence=0.5, lacunarity=2) + + biome_height_maps = [ + # Desert + filter_map(height_map, smooth_height_map, 0.75, 0.2, 0.95, 0.2, 0.2, 0.5), + # Savanna + filter_map(height_map, smooth_height_map, 0.5, 0.1, 0.95, 0.1, 0.1, 0.2), + # Tropical Woodland 
+ filter_map(height_map, smooth_height_map, 0.33, 0.33, 0.95, 0.1, 0.1, 0.75), + # Tundra + filter_map(height_map, smooth_height_map, 0.5, 1, 0.25, 1, 1, 1), + # Seasonal Forest + filter_map(height_map, smooth_height_map, 0.75, 0.5, 0.4, 0.4, 0.33, 0.2), + # Rainforest + filter_map(height_map, smooth_height_map, 0.5, 0.25, 0.66, 1, 1, 0.5), + # Temperate forest + filter_map(height_map, smooth_height_map, 0.75, 0.5, 0.4, 0.4, 0.33, 0.33), + # Temperate Rainforest + filter_map(height_map, smooth_height_map, 0.75, 0.5, 0.4, 0.4, 0.33, 0.33), + # Boreal + filter_map(height_map, smooth_height_map, 0.8, 0.1, 0.9, 0.05, 0.05, 0.1) + ] + + + biome_count = len(biome_names) + biome_masks = np.zeros((biome_count, size, size)) + + for i in range(biome_count): + biome_masks[i, biome_map==i] = 1 + biome_masks[i] = gaussian_filter(biome_masks[i], sigma=16) + + # Remove ocean from masks + blurred_land_mask = land_mask + blurred_land_mask = binary_dilation(land_mask, iterations=32).astype(np.float64) + blurred_land_mask = gaussian_filter(blurred_land_mask, sigma=16) + + # biome mask - [9, size, size] + biome_masks = biome_masks*blurred_land_mask + + adjusted_height_map = height_map.copy() + + for i in range(len(biome_height_maps)): + adjusted_height_map = (1-biome_masks[i])*adjusted_height_map + biome_masks[i]*biome_height_maps[i] + + # add rivers + biome_bound = get_boundary(biome_map, size, kernel=5) + cell_bound = get_boundary(vor_map, size, kernel=2) + river_mask = noise_map(size, 4, 4353 + map_seed, octaves=6, persistence=0.5, lacunarity=2) > 0 + + new_biome_bound = biome_bound*(adjusted_height_map<0.5)*land_mask + new_cell_bound = cell_bound*(adjusted_height_map<0.05)*land_mask + + rivers = np.logical_or(new_biome_bound, new_cell_bound)*river_mask + loose_river_mask = binary_dilation(rivers, iterations=8) + rivers_height = gaussian_filter(rivers.astype(np.float64), sigma=2)*loose_river_mask + adjusted_height_river_map = adjusted_height_map*(1-rivers_height) - 0.05*rivers + + sea_color = np.array([12, 14, 255]) + river_land_mask = adjusted_height_river_map >= 0 + land_mask_color = np.repeat(river_land_mask[:, :, np.newaxis], 3, axis=-1) + rivers_biome_color_map = land_mask_color*biome_color_map + (1-land_mask_color)*sea_color + rivers_biome_map = river_land_mask * biome_map + (1 - river_land_mask) * biome_count # use biome count=9 as water indicator + + semantic_map = rivers_biome_map + semantic_map_color = rivers_biome_color_map + height_map = adjusted_height_river_map + + tree_densities = [4000, 1500, 8000, 1000, 10000, 25000, 10000, 20000, 5000] + trees = [np.array(place_trees(river_land_mask, adjusted_height_river_map, tree_densities[i], biome_masks[i], size)) for i in range(len(biome_names))] + + canvas = np.ones((size, size)) * 255 + for k in range(len(biome_names)): + canvas[trees[k][:, 1], trees[k][:, 0]] = k + tree_map = canvas + + return height_map, semantic_map, tree_map, semantic_map_color + +if __name__ == '__main__': + import os + parser = ArgumentParser() + parser.add_argument('--size', type=int, required=True) + parser.add_argument('--nbins', type=int, default=256) + parser.add_argument('--seed', type=int, default=3407) + parser.add_argument('--outdir', type=str, required=True) + args = parser.parse_args() + outdir = args.outdir + heightmap, semanticmap, treemap, colormap = PCGGen(args.size, args.nbins, args.seed) + save_height_map(heightmap, os.path.join(outdir, 'heightmap.png')) + cv2.imwrite(os.path.join(outdir, 'semanticmap.png'), semanticmap.astype(np.uint8)) + 
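The final terrain in `PCGGen` is a per-pixel blend: each biome contributes its own `filter_map` height profile, weighted by a Gaussian-smoothed biome mask, and the blend is confined to land through a dilated-and-blurred land mask. A toy sketch of that blending step under stated assumptions (the random base height and the stand-in per-biome profiles below are illustrative only):

```python
import numpy as np
from scipy.ndimage import gaussian_filter, binary_dilation

size, n_biomes = 128, 3
rng = np.random.default_rng(0)

base_height = rng.uniform(-1, 1, (size, size))
biome_map = rng.integers(0, n_biomes, (size, size))
# Stand-ins for the per-biome filter_map outputs.
biome_heights = [base_height * (0.3 + 0.3 * i) for i in range(n_biomes)]

# Soft per-biome masks: hard membership blurred with a Gaussian.
masks = np.stack([gaussian_filter((biome_map == i).astype(float), sigma=16)
                  for i in range(n_biomes)])

# Keep the blend on land only: dilate the land mask, then blur its edge.
land = base_height > 0
land_soft = gaussian_filter(binary_dilation(land, iterations=32).astype(float), sigma=16)
masks *= land_soft

# Each biome pulls the terrain toward its own profile where its mask is strong.
adjusted = base_height.copy()
for i in range(n_biomes):
    adjusted = (1 - masks[i]) * adjusted + masks[i] * biome_heights[i]
```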
cv2.imwrite(os.path.join(outdir, 'colormap.png'), colormap[..., [2, 1, 0]].astype(np.uint8)) + cv2.imwrite(os.path.join(outdir, 'treemap.png'), treemap) diff --git a/train.py b/train.py new file mode 100644 index 0000000000000000000000000000000000000000..c62171e9a2884167298c01159eb79f1cc9e9e7dd --- /dev/null +++ b/train.py @@ -0,0 +1,168 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import argparse +import os +import sys +import random + +import torch.autograd.profiler as profiler +import wandb + +import imaginaire.config +from imaginaire.config import Config +from imaginaire.utils.cudnn import init_cudnn +from imaginaire.utils.dataset import get_train_and_val_dataloader +from imaginaire.utils.distributed import init_dist, is_master, get_world_size +from imaginaire.utils.distributed import master_only_print as print +from imaginaire.utils.gpu_affinity import set_affinity +from imaginaire.utils.misc import slice_tensor +from imaginaire.utils.logging import init_logging, make_logging_dir +from imaginaire.utils.trainer import (get_model_optimizer_and_scheduler, + get_trainer, set_random_seed) + +sys.path.append(os.environ.get('SUBMIT_SCRIPTS', '.')) + + +def parse_args(): + parser = argparse.ArgumentParser(description='Training') + parser.add_argument('--config', + help='Path to the training config file.', required=True) + parser.add_argument('--logdir', help='Dir for saving logs and models.') + parser.add_argument('--checkpoint', default='', help='Checkpoint path.') + parser.add_argument('--seed', type=int, default=2, help='Random seed.') + parser.add_argument('--randomized_seed', action='store_true', help='Use a random seed between 0-10000.') + parser.add_argument('--local_rank', type=int, default=os.getenv('LOCAL_RANK', 0)) + parser.add_argument('--single_gpu', action='store_true') + parser.add_argument('--debug', action='store_true') + parser.add_argument('--use_jit', action='store_true') + parser.add_argument('--profile', action='store_true') + parser.add_argument('--wandb', action='store_true') + parser.add_argument('--wandb_name', default='default', type=str) + parser.add_argument('--wandb_id', type=str) + parser.add_argument('--resume', type=int) + parser.add_argument('--num_workers', type=int) + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + set_affinity(args.local_rank) + if args.randomized_seed: + args.seed = random.randint(0, 10000) + set_random_seed(args.seed, by_rank=True) + cfg = Config(args.config) + try: + from userlib.auto_resume import AutoResume + AutoResume.init() + except: # noqa + pass + + # If args.single_gpu is set to True, + # we will disable distributed data parallel + if not args.single_gpu: + cfg.local_rank = args.local_rank + init_dist(cfg.local_rank) + print(f"Training with {get_world_size()} GPUs.") + + # Global arguments. + imaginaire.config.DEBUG = args.debug + imaginaire.config.USE_JIT = args.use_jit + + # Override the number of data loading workers if necessary + if args.num_workers is not None: + cfg.data.num_workers = args.num_workers + + # Create log directory for storing training results. + cfg.date_uid, cfg.logdir = init_logging(args.config, args.logdir) + make_logging_dir(cfg.logdir) + + # Initialize cudnn. + init_cudnn(cfg.cudnn.deterministic, cfg.cudnn.benchmark) + + # Initialize data loaders and models. 
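One detail in `train.py`: `--local_rank` defaults to the `LOCAL_RANK` environment variable, and `init_dist` is only called when `--single_gpu` is not set. A minimal sketch of what such a one-process-per-GPU setup typically looks like with raw `torch.distributed`; imaginaire's `init_dist` wraps something along these lines, though the exact implementation may differ:

```python
import os
import argparse
import torch
import torch.distributed as dist

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=int(os.getenv('LOCAL_RANK', 0)))
parser.add_argument('--single_gpu', action='store_true')
args = parser.parse_args()

if not args.single_gpu:
    # One process per GPU, launched e.g. via `torchrun --nproc_per_node=8 train.py ...`,
    # which sets LOCAL_RANK, MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE.
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')
    print(f'Training with {dist.get_world_size()} GPUs.')
```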
+ batch_size = cfg.data.train.batch_size + total_step = max(cfg.trainer.dis_step, cfg.trainer.gen_step) + cfg.data.train.batch_size *= total_step + train_data_loader, val_data_loader = get_train_and_val_dataloader(cfg, args.seed) + net_G, net_D, opt_G, opt_D, sch_G, sch_D = \ + get_model_optimizer_and_scheduler(cfg, seed=args.seed) + trainer = get_trainer(cfg, net_G, net_D, + opt_G, opt_D, + sch_G, sch_D, + train_data_loader, val_data_loader) + resumed, current_epoch, current_iteration = trainer.load_checkpoint(cfg, args.checkpoint, args.resume) + + # Initialize Wandb. + if is_master(): + if args.wandb_id is not None: + wandb_id = args.wandb_id + else: + if resumed and os.path.exists(os.path.join(cfg.logdir, 'wandb_id.txt')): + with open(os.path.join(cfg.logdir, 'wandb_id.txt'), 'r+') as f: + wandb_id = f.read() + else: + wandb_id = wandb.util.generate_id() + with open(os.path.join(cfg.logdir, 'wandb_id.txt'), 'w+') as f: + f.write(wandb_id) + wandb_mode = "disabled" if (args.debug or not args.wandb) else "online" + wandb.init(id=wandb_id, + project=args.wandb_name, + config=cfg, + name=os.path.basename(cfg.logdir), + resume="allow", + settings=wandb.Settings(start_method="fork"), + mode=wandb_mode) + wandb.config.update({'dataset': cfg.data.name}) + wandb.watch(trainer.net_G_module) + wandb.watch(trainer.net_D.module) + + # Start training. + for epoch in range(current_epoch, cfg.max_epoch): + print('Epoch {} ...'.format(epoch)) + if not args.single_gpu: + train_data_loader.sampler.set_epoch(current_epoch) + trainer.start_of_epoch(current_epoch) + for it, data in enumerate(train_data_loader): + with profiler.profile(enabled=args.profile, + use_cuda=True, + profile_memory=True, + record_shapes=True) as prof: + data = trainer.start_of_iteration(data, current_iteration) + + for i in range(cfg.trainer.dis_step): + trainer.dis_update( + slice_tensor(data, i * batch_size, + (i + 1) * batch_size)) + for i in range(cfg.trainer.gen_step): + trainer.gen_update( + slice_tensor(data, i * batch_size, + (i + 1) * batch_size)) + + current_iteration += 1 + trainer.end_of_iteration(data, current_epoch, current_iteration) + if current_iteration >= cfg.max_iter: + print('Done with training!!!') + return + if args.profile: + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20)) + prof.export_chrome_trace(os.path.join(cfg.logdir, "trace.json")) + try: + if AutoResume.termination_requested(): + trainer.save_checkpoint(current_epoch, current_iteration) + AutoResume.request_resume() + print("Training terminated. Returning") + return 0 + except: # noqa + pass + + current_epoch += 1 + trainer.end_of_epoch(data, current_epoch, current_iteration) + print('Done with training!!!') + return + + +if __name__ == "__main__": + main()
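A point worth calling out in the training loop above: the dataloader batch size is multiplied by `max(dis_step, gen_step)`, and each iteration the oversized batch is cut into `batch_size`-sized slices so every discriminator and generator update sees fresh samples. A simplified sketch of that slicing scheme; the `slice_tensor` below is a stand-in for `imaginaire.utils.misc.slice_tensor`, and the trainer calls are shown as comments:

```python
import torch

def slice_tensor(data, start, end):
    """Slice every tensor in a (possibly nested) batch dict along dim 0."""
    if isinstance(data, torch.Tensor):
        return data[start:end]
    if isinstance(data, dict):
        return {k: slice_tensor(v, start, end) for k, v in data.items()}
    return data

batch_size, dis_step, gen_step = 4, 2, 1
total_step = max(dis_step, gen_step)          # the loader delivers 4 * 2 = 8 samples per iteration

data = {'images': torch.randn(batch_size * total_step, 3, 64, 64)}

for i in range(dis_step):                     # each discriminator step gets its own slice
    sub = slice_tensor(data, i * batch_size, (i + 1) * batch_size)
    # trainer.dis_update(sub)
for i in range(gen_step):                     # generator steps reuse the leading slices
    sub = slice_tensor(data, i * batch_size, (i + 1) * batch_size)
    # trainer.gen_update(sub)
```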