File size: 9,922 Bytes
b72e09b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import cv2
import os


class PCGCache(nn.Module):
    r"""PCG Datasets.

    Loads pre-generated PCG worlds from disk. Each world is a sub-directory of
    ``pcg_dataset_path`` containing ``voxel_sparse.npy``, ``height_map.npy``,
    ``semantic_map.npy`` and ``hmap_mc.npy`` (read by ``sample_world``).

    Map conventions (from the original author's notes):
        height_map: [size, size] array, in [-1, 1] range where < 0 indicates water
        semantic_map: [size, size] array, in {0, 1, ..., 9} range, where 9 indicates water
    """

    def __init__(self, pcg_dataset_path):
        super(PCGCache, self).__init__()
        self.sample_size = 1024    # x / z extent of the dense voxel grid
        self.sample_height = 256   # number of layers along the height axis
        # Only sub-directories can be worlds: sample_world() joins file names
        # onto each entry, so stray files in the dataset root would crash it.
        self.pcg_world_path = [
            os.path.join(pcg_dataset_path, p)
            for p in sorted(os.listdir(pcg_dataset_path))
            if os.path.isdir(os.path.join(pcg_dataset_path, p))
        ]
        self.n = len(self.pcg_world_path)

    def sample_world(self, device):
        """Pick a random world and load it as the current voxel state.

        Sets ``voxel_t`` (dense int32 labels, truncated to the occupied height
        range, on ``device``), ``current_height_map``, ``current_semantic_map``
        (both on ``device``), ``heightmap`` (CPU) and ``trans_mat``.
        """
        idx = random.randint(0, self.n - 1)
        world_path = self.pcg_world_path[idx]
        voxel_sparse = np.load(os.path.join(world_path, 'voxel_sparse.npy'))
        current_height_map = np.load(os.path.join(world_path, 'height_map.npy'))
        current_semantic_map = np.load(os.path.join(world_path, 'semantic_map.npy'))
        heightmap = np.load(os.path.join(world_path, 'hmap_mc.npy'))
        voxel_sparse = torch.from_numpy(voxel_sparse).to(device)
        # Rows 0-2 of voxel_sparse are (height, x, z) indices; row 3 is the label.
        voxel_1 = voxel_sparse[0, :].to(torch.int64)
        voxel_2 = voxel_sparse[1, :].to(torch.int64)
        voxel_3 = voxel_sparse[2, :].to(torch.int64)
        self.voxel_t = torch.zeros(self.sample_height, self.sample_size, self.sample_size, device=device, dtype=torch.int32)
        self.voxel_t[voxel_1, voxel_2, voxel_3] = voxel_sparse[3, :].to(torch.int32)
        self.current_height_map = torch.from_numpy(current_height_map).to(device)
        self.current_semantic_map = torch.from_numpy(current_semantic_map).to(device)
        self.heightmap = torch.from_numpy(heightmap)
        self.trans_mat = torch.eye(4)
        # Keep only the layers between the lowest and highest heightmap values
        # and record that offset along the height axis in trans_mat.
        gnd_level = heightmap.min()
        sky_level = heightmap.max() + 1
        self.voxel_t = self.voxel_t[gnd_level:sky_level, :, :]
        self.trans_mat[0, 3] += gnd_level

    def world2local(self, v, is_vec=False):
        """Map a world-space point (or direction if ``is_vec``) into voxel_t space."""
        mat_world2local = torch.inverse(self.trans_mat)
        return trans_vec_homo(mat_world2local, v, is_vec)

    def _truncate_voxel(self):
        """Truncate voxel_t to [heightmap.min(), heightmap.max()] and log it.

        NOTE(review): sample_world() already performs this truncation; calling
        this afterwards would slice the already-truncated volume again.
        """
        gnd_level = self.heightmap.min()
        sky_level = self.heightmap.max() + 1
        self.voxel_t = self.voxel_t[gnd_level:sky_level, :, :]
        self.trans_mat[0, 3] += gnd_level
        print('[GANcraft-utils] Voxel truncated. Gnd: {}; Sky: {}.'.format(gnd_level.item(), sky_level.item()))

    def is_sea(self, loc):
        r"""Return True if the voxel at the heightmap level under ``loc`` is sea.

        Args:
            loc: indexable with at least 3 entries; ``loc[1]`` and ``loc[2]``
                are the x and z coordinates (``loc[0]`` is not read).

        Returns:
            bool: True when (x, z) lies outside the height map, or when the
            voxel at height ``heightmap[x, z]`` has label 26; False otherwise.
        """
        x = int(loc[1])
        z = int(loc[2])
        # Valid indices are 0 .. size - 1, so size itself is out of bounds.
        if x < 0 or x >= self.heightmap.size(0) or z < 0 or z >= self.heightmap.size(1):
            print('[McVoxel] is_sea(): Index out of bound.')
            return True
        # heightmap is in world coordinates; shift into the truncated voxel_t.
        y = int(self.heightmap[x, z] - self.trans_mat[0, 3])
        if self.voxel_t[y, x, z] == 26:
            print('[McVoxel] is_sea(): Get a sea.')
            # The voxel above may lie past the top of the truncated volume.
            above = self.voxel_t[y + 1, x, z] if y + 1 < self.voxel_t.size(0) else None
            print(self.voxel_t[y, x, z], above)
            return True
        return False


class PCGVoxelGenerator(nn.Module):
    r"""Builds a voxel world from a procedurally generated BEV scene.

    ``next_world`` reads a height map, a semantic (biome) map and a tree map
    from a world directory. It extrudes them into a dense int32 label volume,
    stamps tree assets on top, and truncates the volume to the occupied height
    range. ``trans_mat`` records the resulting voxel-to-world offset along the
    height axis.
    """

    def __init__(self, sample_size = 2048):
        super(PCGVoxelGenerator, self).__init__()
        self.sample_height = 256  # number of voxel layers along the height axis
        # NOTE(review): sample_size is stored but not read by the methods in
        # this class; next_world sizes the grid from the loaded height map.
        self.sample_size = sample_size
        self.voxel_t = None

    def next_world(self, device, world_dir, pcg_asset):
        """Load the world in ``world_dir`` and rebuild all voxel state.

        Args:
            device: torch device for ``voxel_t`` and the BEV conditioning maps.
            world_dir (str): directory containing ``heightmap.npy``,
                ``semanticmap.png`` and ``treemap.png``.
            pcg_asset (dict): must provide ``'assets'``, indexable by tree id.
                Each entry is a 3-D label block ordered (height, x, z);
                presumably a torch tensor — TODO confirm.

        Sets ``total_size``, ``trans_mat``, ``current_height_map``,
        ``current_semantic_map``, ``heightmap`` and ``voxel_t``.

        Raises:
            NotImplementedError: if a tree asset cannot be written into the
                volume. The underlying exception is chained as the cause.
        """
        # Generate BEV representation
        print('[PCGGenerator] Loading BEV scene representation...')
        heightmap_path = os.path.join(world_dir, 'heightmap.npy')
        semanticmap_path = os.path.join(world_dir, 'semanticmap.png')
        treemap_path = os.path.join(world_dir, 'treemap.png')
        height_map = np.load(heightmap_path)
        semantic_map = cv2.imread(semanticmap_path, 0)  # flag 0: grayscale
        tree_map = cv2.imread(treemap_path, 0)

        print('[PCGGenerator] Creating scene windows...')
        # Clamp negative heights to 0, then rescale to integer voxel layers in
        # [0, sample_height - 1]. Assumes heights are at most 1 — TODO confirm.
        # This divides by zero if the clamped minimum equals 1.
        height_map[height_map < 0] = 0
        height_map = ((height_map - height_map.min()) / (1 - height_map.min()) * (self.sample_height - 1)).astype(np.int16)

        self.total_size = height_map.shape

        # Conditioning semantic map: pixels carrying a tree (tree_map != 255)
        # get the extra class id 10.
        org_semantic_map = torch.from_numpy(semantic_map.copy())
        org_semantic_map[tree_map != 255] = 10
        chunk_trees_map = tree_map

        # Tree asset ids allowed in each biome. The key order defines the
        # biome id, which indexes biome2mclabels below.
        biome_trees_dict = {
            'desert': [],
            'savanna': [5],
            'twoodland': [1, 7],
            'tundra': [],
            'seasonal forest': [1, 2],
            'rainforest': [1, 2, 3],
            'temp forest': [4],
            'temp rainforest': [0, 3],
            'boreal': [5,6,7],
            'water': [],
        }
        # Voxel label written for each biome id (e.g. 'water' -> 26).
        biome2mclabels = torch.tensor([28, 9, 8, 1, 9, 8, 9, 8, 30, 26], dtype=torch.int32)
        biome_names = list(biome_trees_dict.keys())
        chunk_grid_x, chunk_grid_y = torch.meshgrid(torch.arange(self.total_size[0]), torch.arange(self.total_size[1]))
        world_voxel_t = torch.zeros(self.sample_height, self.total_size[0], self.total_size[1]).to(torch.int32)

        # Write each column's biome label at its surface height...
        chunk_height_map = torch.from_numpy(height_map.astype(int))[None, ...]
        chunk_semantic_map = torch.from_numpy(semantic_map)
        chunk_semantic_map = biome2mclabels[chunk_semantic_map[None, ...].long().contiguous()]
        world_voxel_t = world_voxel_t.scatter_(0, chunk_height_map, chunk_semantic_map)
        # ...and into the pad_num layers above it (clipped to the volume).
        # Then move the surface up so trees are planted on the padded top.
        pad_num = 16
        for preproc_step in range(pad_num):
            world_voxel_t = world_voxel_t.scatter(0, torch.clip(chunk_height_map + preproc_step + 1, 0, self.sample_height - 1), chunk_semantic_map)
        chunk_height_map = chunk_height_map + pad_num
        chunk_height_map = chunk_height_map[0]
        boundary_detect = 50  # margin (in voxels) inside which no tree is placed

        trees_models = pcg_asset['assets']

        for biome_id in range(biome2mclabels.shape[0]):
            selected_trees = biome_trees_dict[biome_names[biome_id]]
            if len(selected_trees) == 0:
                # No assets for this biome: skip before building position masks.
                continue
            # Tree anchors of this biome are the pixels whose tree_map value
            # equals the biome id.
            tree_pos_mask = (chunk_trees_map == biome_id)
            tree_pos_x = chunk_grid_x[tree_pos_mask]
            tree_pos_y = chunk_grid_y[tree_pos_mask]
            tree_pos_h = chunk_height_map[tree_pos_mask]
            assert len(tree_pos_x) == len(tree_pos_y)
            for idx in range(len(tree_pos_x)):
                if (tree_pos_x[idx] < boundary_detect
                        or tree_pos_x[idx] > self.total_size[0] - boundary_detect
                        or tree_pos_y[idx] < boundary_detect
                        or tree_pos_y[idx] > self.total_size[1] - boundary_detect
                        or tree_pos_h[idx] > self.sample_height - boundary_detect):
                    # hack, to avoid out of index near the boundary
                    continue
                tree_id = random.choice(selected_trees)
                tree_model = trees_models[tree_id]
                region = (
                    slice(tree_pos_h[idx], tree_pos_h[idx] + tree_model.shape[0]),
                    slice(tree_pos_x[idx], tree_pos_x[idx] + tree_model.shape[1]),
                    slice(tree_pos_y[idx], tree_pos_y[idx] + tree_model.shape[2]),
                )
                tmp = world_voxel_t[region]
                tmp_mask = (tmp == 0)  # only fill voxels that are still empty
                try:
                    world_voxel_t[region][tmp_mask] = tree_model[tmp_mask]
                except Exception as err:
                    print('height?', tree_pos_h[idx])
                    print(tmp_mask.shape)
                    print(tmp.shape)
                    print(tree_model.shape)
                    print(world_voxel_t.shape)
                    print(tree_id)
                    raise NotImplementedError('[PCGGenerator] Failed to place tree asset.') from err
        self.trans_mat = torch.eye(4)  # Transform voxel to world
        # Generate heightmap for camera trajectory generation: index of the
        # topmost non-empty voxel in each (x, z) column.
        m, h = torch.max((torch.flip(world_voxel_t, [0]) != 0).int(), dim=0, keepdim=False)
        heightmap = world_voxel_t.shape[0] - 1 - h
        heightmap[m == 0] = 0  # Special case when the whole vertical column is empty
        gnd_level = heightmap.min()
        sky_level = heightmap.max() + 1
        # BEV conditioning maps with a leading batch dim: normalized height
        # [1, 1, H, W] and one-hot semantics [1, C, H, W].
        current_height_map = (chunk_height_map / (self.sample_height - 1))[None, None, ...]
        current_semantic_map = F.one_hot(org_semantic_map.to(torch.int64)).to(torch.float).permute(2, 0, 1)[None, ...]

        self.current_height_map = current_height_map.to(device)
        self.current_semantic_map = current_semantic_map.to(device)
        self.heightmap = heightmap
        # Keep only the occupied height range and record the offset.
        self.voxel_t = world_voxel_t[gnd_level:sky_level, :, :].to(device)
        self.trans_mat[0, 3] += gnd_level

    def world2local(self, v, is_vec=False):
        """Map a world-space point (or direction if ``is_vec``) into voxel_t space."""
        mat_world2local = torch.inverse(self.trans_mat)
        return trans_vec_homo(mat_world2local, v, is_vec)

    def is_sea(self, loc):
        r"""Return True if the topmost voxel of the column under ``loc`` is sea.

        Args:
            loc: indexable with at least 3 entries; ``loc[1]`` and ``loc[2]``
                are the x and z coordinates (``loc[0]`` is not read).

        Returns:
            bool: True when (x, z) lies outside the height map, or when the
            voxel at height ``heightmap[x, z]`` has label 26 (the label
            ``next_world`` assigns to the 'water' biome); False otherwise.
        """
        x = int(loc[1])
        z = int(loc[2])
        # Valid indices are 0 .. size - 1, so size itself is out of bounds.
        if x < 0 or x >= self.heightmap.size(0) or z < 0 or z >= self.heightmap.size(1):
            print('[McVoxel] is_sea(): Index out of bound.')
            return True
        # heightmap is in world coordinates; shift into the truncated voxel_t.
        y = int(self.heightmap[x, z] - self.trans_mat[0, 3])
        if self.voxel_t[y, x, z] == 26:
            print('[McVoxel] is_sea(): Get a sea.')
            # The voxel above may lie past the top of the truncated volume.
            above = self.voxel_t[y + 1, x, z] if y + 1 < self.voxel_t.size(0) else None
            print(self.voxel_t[y, x, z], above)
            return True
        return False
        
def trans_vec_homo(m, v, is_vec=False):
    r"""3-dimensional homogeneous matrix and regular vector multiplication.

    Convert ``v`` to a homogeneous 4-vector and multiply it by ``m``. Then
    convert the result back to a 3-d vector.

    Args:
        m (4 x 4 tensor): a homogeneous matrix
        v (3 tensor or sequence): a 3-d vector; only the first three entries
            are used. It is cast to ``m``'s dtype and device first, so integer
            or mismatched-precision inputs are accepted.
        is_vec (bool): if true, v is a direction (w = 0, so translation is
            ignored and no division is performed). Otherwise v is a point
            (w = 1, divided by the resulting w).

    Returns:
        tensor: the transformed 3-d vector, with ``m``'s dtype and device.
    """
    # torch.mv requires both operands to share dtype and device.
    v = torch.as_tensor(v, dtype=m.dtype, device=m.device)
    w = 0.0 if is_vec else 1.0
    v_homo = torch.cat((v[:3], v.new_tensor([w])))
    out = torch.mv(m, v_homo)
    if not is_vec:
        out = out / out[3]
    return out[:3]