gabrielsemiceki9 committed
Commit 5c73806 · verified · 1 Parent(s): b0b8615

Update app.py

Files changed (1)
  1. app.py +151 -1
app.py CHANGED
@@ -1 +1,151 @@
- SceneDreamer
+ import os
+ import sys
+ import html
+ import glob
+ import uuid
+ import hashlib
+ import requests
+ from tqdm import tqdm
+
+ # Fetch the SceneDreamer source and install its dependencies at startup.
+ os.system("git clone https://github.com/FrozenBurning/SceneDreamer.git")
+ os.system("cp -r SceneDreamer/* ./")
+ os.system("bash install.sh")
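+
+ # Hedged aside (not in the original code): os.system ignores failures. A
+ # stricter setup sketch would surface a failed clone or install immediately:
+ #     import subprocess
+ #     subprocess.run(["bash", "install.sh"], check=True)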
+
+ # These imports rely on packages installed by install.sh above.
+ import torch
+ import torch.nn as nn
+ import importlib
+ import argparse
+ from imaginaire.config import Config
+ from imaginaire.utils.cudnn import init_cudnn
+ import gradio as gr
+ from PIL import Image
+
+
+ class WrappedModel(nn.Module):
+     r"""Dummy wrapper around the module."""
+
+     def __init__(self, module):
+         super(WrappedModel, self).__init__()
+         self.module = module
+
+     def forward(self, *args, **kwargs):
+         r"""PyTorch module forward function overload."""
+         return self.module(*args, **kwargs)
+
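+ # Context note (assumption): state_dicts saved from a (Distributed)DataParallel
+ # model prefix every key with 'module.'; WrappedModel reproduces that layout so
+ # the released checkpoint loads further below without renaming keys.
+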
+ def parse_args():
+     parser = argparse.ArgumentParser(description='Inference')
+     parser.add_argument('--config', type=str, default='./configs/scenedreamer_inference.yaml',
+                         help='Path to the inference config file.')
+     parser.add_argument('--checkpoint', default='./scenedreamer_released.pt',
+                         help='Checkpoint path.')
+     parser.add_argument('--output_dir', type=str, default='./test/',
+                         help='Location to save the image outputs.')
+     parser.add_argument('--seed', type=int, default=8888,
+                         help='Random seed.')
+     args = parser.parse_args()
+     return args
+
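+ # Note: on a Hugging Face Space the script runs as `python app.py` with no CLI
+ # arguments, so the defaults above are what actually take effect.
+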
+
+ args = parse_args()
+ cfg = Config(args.config)
+
+ # Initialize cudnn.
+ init_cudnn(cfg.cudnn.deterministic, cfg.cudnn.benchmark)
+
+ # Initialize the generator.
+ lib_G = importlib.import_module(cfg.gen.type)
+ net_G = lib_G.Generator(cfg.gen, cfg.data)
+ net_G = net_G.to('cuda')
+ net_G = WrappedModel(net_G)
+
+ if args.checkpoint == '':
+     raise NotImplementedError("No checkpoint is provided for inference!")
+
+ # Load checkpoint.
+ checkpoint = torch.load(args.checkpoint, map_location='cpu')
+ net_G.load_state_dict(checkpoint['net_G'])
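+ # Hedged alternative (not in the original): the wrapper could be skipped by
+ # stripping the 'module.' prefix and loading into the bare generator:
+ #     state_dict = {k.replace('module.', '', 1): v
+ #                   for k, v in checkpoint['net_G'].items()}
+ #     net_G.module.load_state_dict(state_dict)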
+
+ # Prepare for inference: unwrap the model, switch to eval mode, freeze weights.
+ net_G = net_G.module
+ net_G.eval()
+ for param in net_G.parameters():
+     param.requires_grad = False
+ torch.cuda.empty_cache()
+ world_dir = args.output_dir
+ os.makedirs(world_dir, exist_ok=True)
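+ # Hedged aside (not in the original): the freeze loop is equivalent to
+ # net_G.requires_grad_(False); wrapping the render calls in torch.no_grad()
+ # would additionally avoid building autograd graphs at inference time.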
+
+
+ def get_bev(seed):
+     """Sample a new world layout and return its semantic map and height map."""
+     print('[PCGGenerator] Generating BEV scene representation...')
+     os.system('python terrain_generator.py --size {} --seed {} --outdir {}'.format(
+         net_G.voxel.sample_size, seed, world_dir))
+     heightmap_path = os.path.join(world_dir, 'heightmap.png')
+     semantic_path = os.path.join(world_dir, 'colormap.png')
+     heightmap = Image.open(heightmap_path)
+     semantic = Image.open(semantic_path)
+     return semantic, heightmap
+
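+ # Note (assumption): terrain_generator.py ships with the cloned SceneDreamer
+ # repo; it is expected to write heightmap.png and colormap.png into world_dir,
+ # which is inferred here from the paths read back above.
+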
+ def get_video(seed, num_frames, reso_h, reso_w):
+     """Render a camera fly-through of the current world as an mp4."""
+     device = torch.device('cuda')
+     # Seed both the CPU and CUDA RNGs so the sampled style code is reproducible.
+     rng_cuda = torch.Generator(device=device)
+     rng_cuda = rng_cuda.manual_seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+     net_G.voxel.next_world(device, world_dir, checkpoint)
+     cam_mode = cfg.inference_args.camera_mode
+     cfg.inference_args.cam_maxstep = num_frames
+     cfg.inference_args.resolution_hw = [reso_h, reso_w]
+     current_outdir = os.path.join(world_dir, 'camera_{:02d}'.format(cam_mode))
+     os.makedirs(current_outdir, exist_ok=True)
+     # Sample a global style code z ~ N(0, I) and render along the camera path.
+     z = torch.empty(1, net_G.style_dims, dtype=torch.float32, device=device)
+     z.normal_(generator=rng_cuda)
+     net_G.inference_givenstyle(z, current_outdir, **vars(cfg.inference_args))
+     return os.path.join(current_outdir, 'rgb_render.mp4')
+
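+ # Hedged equivalent (not in the original): the empty()/normal_() pair above
+ # can be written as a single seeded draw:
+ #     z = torch.randn(1, net_G.style_dims, device=device, generator=rng_cuda)
+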
+ markdown = '''
+ # SceneDreamer: Unbounded 3D Scene Generation from 2D Image Collections
+
+ Authored by Zhaoxi Chen, Guangcong Wang, Ziwei Liu
+ ### Useful links:
+ - [Official Github Repo](https://github.com/FrozenBurning/SceneDreamer)
+ - [Project Page](https://scene-dreamer.github.io/)
+ - [arXiv Link](https://arxiv.org/abs/2302.01330)
+
+ Licensed under the S-Lab License.
+ We provide a pre-sampled scene whose BEV maps are shown on the right. You can also press the "Generate BEV" button to randomly sample a new 3D world, represented by a height map and a semantic map, though this takes a while.
+
+ To render a video, press the "Render" button to generate a camera trajectory flying through the world. You can adjust the rendering options using the sliders below!
+ '''
+
+ with gr.Blocks() as demo:
+     with gr.Row():
+         with gr.Column():
+             gr.Markdown(markdown)
+         with gr.Column():
+             with gr.Row():
+                 with gr.Column():
+                     semantic = gr.Image(value='./test/colormap.png', type="pil", height=512, width=512)
+                 with gr.Column():
+                     height = gr.Image(value='./test/heightmap.png', type="pil", height=512, width=512)
+     with gr.Row():
+         with gr.Column():
+             video = gr.Video()
+     with gr.Row():
+         num_frames = gr.Slider(minimum=10, maximum=200, value=20, step=1, label='Number of rendered frames')
+         user_seed = gr.Slider(minimum=0, maximum=999999, value=8888, step=1, label='Random seed')
+         resolution_h = gr.Slider(minimum=256, maximum=2160, value=270, step=1, label='Height of rendered image')
+         resolution_w = gr.Slider(minimum=256, maximum=3840, value=480, step=1, label='Width of rendered image')
+
+     with gr.Row():
+         btn = gr.Button(value="Generate BEV")
+         btn_2 = gr.Button(value="Render")
+
+     # "Generate BEV" resamples the world layout; "Render" produces the fly-through video.
+     btn.click(get_bev, [user_seed], [semantic, height])
+     btn_2.click(get_video, [user_seed, num_frames, resolution_h, resolution_w], [video])
+
+ demo.launch(debug=True)
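+ # Hedged note (not in the original): long renders can hit HTTP timeouts on
+ # Spaces; enabling Gradio's request queue mitigates this:
+ #     demo.queue().launch(debug=True)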