File size: 3,166 Bytes
d63ad23 d454202 d63ad23 1fa688f d63ad23 1fa688f 5777f44 1fa688f f4e8cf6 dfab55e 1fa688f a14c9ce 1fa688f dfab55e cb29219 a14c9ce 1fa688f f4e8cf6 1fa688f f4e8cf6 1fa688f f4e8cf6 1fa688f a14c9ce 1fa688f a14c9ce 1fa688f a14c9ce 76eeb7d 1fa688f f4e8cf6 1fa688f a14c9ce 1fa688f cb29219 1fa688f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import numpy as np
import torch
import time
import nvdiffrast.torch as dr
from util.utils import get_tri
import tempfile
from mesh import Mesh
import zipfile
from util.renderer import Renderer
import trimesh
def generate3d(model, rgb, ccm, device):
    """Reconstruct a mesh from two six-view image strips and export it as GLB.

    Args:
        model: Reconstruction model. This function reads ``tet_grid_size``,
            ``camera_angle_num``, ``input.scale``, ``geo_type``,
            ``denoising`` and ``scheduler`` from it, calls ``eval``,
            ``unet2``, ``decode`` and ``export_mesh`` on it, and overwrites
            ``model.renderer`` with a freshly built ``Renderer``.
        rgb: NumPy array indexed as [H, W, C] that holds six views side by
            side along the width, each 256 pixels wide. Values are divided
            by 255 here, so they are presumably in 0-255 -- confirm with
            the caller.
        ccm: NumPy array laid out like ``rgb``. Channels 2, 1, 0 of its last
            axis are selected (i.e. the first three channels, reversed)
            before use. Presumably a coordinate map -- confirm with the
            caller.
        device: Torch device (or device string) that every tensor handed to
            the model is moved to.

    Returns:
        Path (``str``) of the exported ``.glb`` file. It lives next to an
        empty temporary base file and the ``.obj`` that ``model.export_mesh``
        is expected to write; none of these files is deleted here.
    """
    view_width = 256  # pixel width of a single view inside the strip
    # The last view of the strip is moved to the front; the rest keep order.
    view_order = (5, 0, 1, 2, 3, 4)
    noise_timestep = 20  # fixed timestep used when the model denoises

    model.renderer = Renderer(
        tet_grid_size=model.tet_grid_size,
        camera_angle_num=model.camera_angle_num,
        scale=model.input.scale,
        geo_type=model.geo_type,
    )

    color = (torch.from_numpy(rgb) / 255).permute(2, 0, 1)
    xyz = (torch.from_numpy(ccm[:, :, (2, 1, 0)]) / 255).permute(2, 0, 1)

    def get_imgs(strip):
        # strip: [C, H, W*6] -> [6, C, H, W], reordered by ``view_order``.
        return torch.stack(
            [strip[:, :, view_width * i:view_width * (i + 1)] for i in view_order],
            dim=0,
        )

    # Built from its own get_imgs() call so it never shares storage with the
    # tensor passed to get_tri below.
    triplane_color = get_imgs(color).permute(0, 2, 3, 1).unsqueeze(0).to(device)  # [1, 6, H, W, C]

    color = get_tri(get_imgs(color), dim=0, blender=True, scale=1).unsqueeze(0)
    xyz = get_tri(get_imgs(xyz), dim=0, blender=True, scale=1, fix=True).unsqueeze(0)
    triplane = torch.cat([color, xyz], dim=1).to(device)

    # 3D visualize
    model.eval()

    if model.denoising:
        # randint over [t, t + 1) always yields t; it is kept (rather than a
        # constant fill) so the random-number stream matches earlier versions.
        timesteps = torch.randint(
            noise_timestep,
            noise_timestep + 1,
            [triplane.shape[0]],
            dtype=torch.long,
            device=triplane.device,
        )
        noise_new = torch.randn_like(triplane) * 0.5 + 0.5
        triplane = model.scheduler.add_noise(triplane, noise_new, timesteps)
        start_time = time.perf_counter()
        with torch.no_grad():
            triplane_feature2 = model.unet2(triplane, timesteps)
        elapsed_time = time.perf_counter() - start_time
        print(f"unet takes {elapsed_time}s")
    else:
        # Inference only: no autograd graph is needed in this branch either.
        with torch.no_grad():
            triplane_feature2 = model.unet2(triplane)

    with torch.no_grad():
        data_config = {
            'resolution': [1024, 1024],
            "triview_color": triplane_color,
        }
        verts, faces = model.decode(data_config, triplane_feature2)
        data_config['verts'] = verts[0]
        data_config['faces'] = faces

    # Imported here, as in the original code, so kiui is only required when
    # this function actually runs.
    from kiui.mesh_utils import clean_mesh
    verts, faces = clean_mesh(
        data_config['verts'].squeeze().cpu().numpy().astype(np.float32),
        data_config['faces'].squeeze().cpu().numpy().astype(np.int32),
        repair=False,
        remesh=True,
        remesh_size=0.005,
        remesh_iters=1,
    )
    # Move the cleaned mesh to the requested device instead of whichever
    # CUDA device happens to be current.
    data_config['verts'] = torch.from_numpy(verts).to(device).contiguous()
    data_config['faces'] = torch.from_numpy(faces).to(device).contiguous()

    start_time = time.perf_counter()
    # Only the unique name is needed; close the handle right away. The empty
    # base file is left on disk (delete=False), as before.
    with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
        mesh_path_glb = tmp_file.name
    with torch.no_grad():
        model.export_mesh(data_config, mesh_path_glb, tri_fea_2=triplane_feature2)
    elapsed_time = time.perf_counter() - start_time
    print(f"uv takes {elapsed_time}s")

    # Convert .obj (with vertex colors) to .glb
    obj_path = mesh_path_glb + ".obj"
    glb_path = mesh_path_glb + ".glb"
    mesh = trimesh.load(obj_path, process=False)
    mesh.export(glb_path)
    return glb_path