Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1 +1,151 @@
|
|
1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import html
|
4 |
+
import glob
|
5 |
+
import uuid
|
6 |
+
import hashlib
|
7 |
+
import requests
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
# Fetch the SceneDreamer sources, copy them next to this script and run the
# project's install script. Presumably the project imports further down rely
# on this checkout being in place -- TODO confirm against install.sh.
_BOOTSTRAP_COMMANDS = (
    "git clone https://github.com/FrozenBurning/SceneDreamer.git",
    "cp -r SceneDreamer/* ./",
    "bash install.sh",
)
for _command in _BOOTSTRAP_COMMANDS:
    _status = os.system(_command)
    if _status != 0:
        # Still best-effort (e.g. the clone fails when the directory already
        # exists), but make the failure visible instead of dropping it.
        print('[bootstrap] command failed (status {}): {}'.format(_status, _command),
              file=sys.stderr)
|
13 |
+
|
14 |
+
|
15 |
+
import os
|
16 |
+
import torch
|
17 |
+
import torch.nn as nn
|
18 |
+
import importlib
|
19 |
+
import argparse
|
20 |
+
from imaginaire.config import Config
|
21 |
+
from imaginaire.utils.cudnn import init_cudnn
|
22 |
+
import gradio as gr
|
23 |
+
from PIL import Image
|
24 |
+
|
25 |
+
|
26 |
+
class WrappedModel(nn.Module):
    r"""Thin wrapper that stores a network under the ``module`` attribute.

    Because the wrapped network is registered as the submodule ``module``,
    its parameters and buffers show up in this wrapper's ``state_dict`` with
    a ``module.`` key prefix.
    """

    def __init__(self, module):
        r"""Register ``module`` as the only submodule.

        Args:
            module (nn.Module): Network to wrap.
        """
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        r"""Pass every argument through to the wrapped module and return its output."""
        return self.module(*args, **kwargs)
|
37 |
+
|
38 |
+
def parse_args(argv=None):
    """Parse the command-line options of the demo.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is what the original
            zero-argument call did.

    Returns:
        argparse.Namespace with the attributes ``config``, ``checkpoint``,
        ``output_dir`` and ``seed``.
    """
    parser = argparse.ArgumentParser(description='Training')
    parser.add_argument('--config', type=str,
                        default='./configs/scenedreamer_inference.yaml',
                        help='Path to the training config file.')
    parser.add_argument('--checkpoint', default='./scenedreamer_released.pt',
                        help='Checkpoint path.')
    parser.add_argument('--output_dir', type=str, default='./test/',
                        help='Location to save the image outputs')
    parser.add_argument('--seed', type=int, default=8888,
                        help='Random seed.')
    return parser.parse_args(argv)
|
49 |
+
|
50 |
+
|
51 |
+
args = parse_args()

# Fail fast: without a checkpoint there is nothing to run, so check before
# the generator is built and moved to the GPU.
if args.checkpoint == '':
    raise NotImplementedError("No checkpoint is provided for inference!")

cfg = Config(args.config)

# Initialize cudnn.
init_cudnn(cfg.cudnn.deterministic, cfg.cudnn.benchmark)

# Build the generator from the module named in the config and move it to the GPU.
lib_G = importlib.import_module(cfg.gen.type)
net_G = lib_G.Generator(cfg.gen, cfg.data)
net_G = net_G.to('cuda')

# Wrapping prefixes every state-dict key with 'module.'; presumably the
# released checkpoint was saved from a wrapped model -- TODO confirm.
net_G = WrappedModel(net_G)

# Load checkpoint (deserialized on the CPU, then copied into the GPU model).
# `checkpoint` stays at module level because get_video() passes it on.
checkpoint = torch.load(args.checkpoint, map_location='cpu')
net_G.load_state_dict(checkpoint['net_G'])

# Inference only: unwrap, switch to eval mode and freeze all parameters.
net_G = net_G.module
net_G.eval()
for param in net_G.parameters():
    param.requires_grad = False
torch.cuda.empty_cache()

# Directory that receives the generated maps and the rendered videos.
world_dir = args.output_dir
os.makedirs(world_dir, exist_ok=True)
|
80 |
+
|
81 |
+
|
82 |
+
|
83 |
+
def get_bev(seed):
    """Generate a new bird's-eye-view scene and return its two maps.

    Runs ``terrain_generator.py`` as a subprocess with the voxel sample size
    of the loaded generator, the given seed and ``world_dir`` as output
    directory, then opens ``colormap.png`` and ``heightmap.png`` from that
    directory (presumably written by the generator script -- TODO confirm).

    Args:
        seed: Random seed for the terrain generator. Converted to ``int`` so
            a float coming from the UI slider is passed as a plain integer.

    Returns:
        Tuple ``(semantic, heightmap)`` of PIL images.

    Raises:
        subprocess.CalledProcessError: If the generator script exits with a
            non-zero status. The previous ``os.system`` call ignored this and
            went on to open whatever images were already on disk.
    """
    import subprocess
    import sys

    print('[PCGGenerator] Generating BEV scene representation...')
    # Argument list instead of a shell string: no quoting problems with
    # world_dir; sys.executable keeps the child in the same interpreter.
    subprocess.run(
        [
            sys.executable, 'terrain_generator.py',
            '--size', str(net_G.voxel.sample_size),
            '--seed', str(int(seed)),
            '--outdir', str(world_dir),
        ],
        check=True,
    )
    semantic = Image.open(os.path.join(world_dir, 'colormap.png'))
    heightmap = Image.open(os.path.join(world_dir, 'heightmap.png'))
    return semantic, heightmap
|
91 |
+
|
92 |
+
def get_video(seed, num_frames, reso_h, reso_w):
    """Render a camera fly-through of the current world and return the video path.

    Seeds the torch RNGs, loads the world found in ``world_dir`` into the
    generator's voxel module, samples one style code and calls
    ``net_G.inference_givenstyle`` with the inference arguments from the
    config.

    Args:
        seed: Seed for the CPU/CUDA RNGs and for the style-code generator.
        num_frames: Written to ``cfg.inference_args.cam_maxstep``.
        reso_h: Height, first entry of ``cfg.inference_args.resolution_hw``.
        reso_w: Width, second entry of ``cfg.inference_args.resolution_hw``.

    Returns:
        Path ``<world_dir>/camera_XX/rgb_render.mp4``, where ``XX`` is the
        configured camera mode (presumably the file written by
        ``inference_givenstyle`` -- TODO confirm).
    """
    # UI sliders may hand over floats; torch.Generator.manual_seed needs an
    # int, and the config values are counts/pixel sizes.
    seed = int(seed)
    num_frames = int(num_frames)
    reso_h, reso_w = int(reso_h), int(reso_w)

    device = torch.device('cuda')
    rng_cuda = torch.Generator(device=device)
    rng_cuda.manual_seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    net_G.voxel.next_world(device, world_dir, checkpoint)

    # NOTE: this mutates the module-level cfg, so the values persist across
    # calls and are shared by every request.
    inference_args = cfg.inference_args
    cam_mode = inference_args.camera_mode
    inference_args.cam_maxstep = num_frames
    inference_args.resolution_hw = [reso_h, reso_w]

    current_outdir = os.path.join(world_dir, 'camera_{:02d}'.format(cam_mode))
    os.makedirs(current_outdir, exist_ok=True)

    # One style code drawn from a standard normal with the seeded generator.
    z = torch.empty(1, net_G.style_dims, dtype=torch.float32, device=device)
    z.normal_(generator=rng_cuda)
    net_G.inference_givenstyle(z, current_outdir, **vars(inference_args))
    return os.path.join(current_outdir, 'rgb_render.mp4')
|
108 |
+
|
109 |
+
# Introductory text shown in the left column of the demo page. A plain string
# is enough: the text contains no placeholders, so no f-string is needed.
markdown = '''
# SceneDreamer: Unbounded 3D Scene Generation from 2D Image Collections

Authored by Zhaoxi Chen, Guangcong Wang, Ziwei Liu
### Useful links:
- [Official Github Repo](https://github.com/FrozenBurning/SceneDreamer)
- [Project Page](https://scene-dreamer.github.io/)
- [arXiv Link](https://arxiv.org/abs/2302.01330)
Licensed under the S-Lab License.
We offer a sampled scene whose BEVs are shown on the right. You can also use the button "Generate BEV" to randomly sample a new 3D world represented by a height map and a semantic map. But it requires a long time.

To render video, push the button "Render" to generate a camera trajectory flying through the world. You can specify rendering options as shown below!
'''
|
122 |
+
|
123 |
+
# Page layout. NOTE(review): the nesting below is reconstructed from the
# statement order and from the intro text ("BEVs are shown on the right");
# verify against the deployed page.
with gr.Blocks() as demo:
    # Top row: intro text on the left, the two BEV maps on the right.
    with gr.Row():
        with gr.Column():
            gr.Markdown(markdown)
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    semantic = gr.Image(value='./test/colormap.png',type="pil", height=512, width=512)
                with gr.Column():
                    height = gr.Image(value='./test/heightmap.png', type="pil", height=512, width=512)

    # Rendered fly-through video.
    with gr.Row():
        with gr.Column():
            video = gr.Video()

    # Rendering options passed to get_video / get_bev.
    with gr.Row():
        num_frames = gr.Slider(minimum=10, maximum=200, value=20, step=1, label='Number of rendered frames')
        user_seed = gr.Slider(minimum=0, maximum=999999, value=8888, step=1, label='Random seed')
        resolution_h = gr.Slider(minimum=256, maximum=2160, value=270, step=1, label='Height of rendered image')
        resolution_w = gr.Slider(minimum=256, maximum=3840, value=480, step=1, label='Width of rendered image')

    with gr.Row():
        btn = gr.Button(value="Generate BEV")
        btn_2 = gr.Button(value="Render")

    # "Generate BEV" refreshes both maps; "Render" fills the video player.
    btn.click(get_bev, [user_seed], [semantic, height])
    btn_2.click(get_video, [user_seed, num_frames, resolution_h, resolution_w], [video])

demo.launch(debug=True)
|