import numpy as np
import torch
import time
import nvdiffrast.torch as dr
from util.utils import get_tri
import tempfile
from mesh import Mesh
import zipfile
from util.renderer import Renderer
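# ---------------------------------------------------------------------------
# generate3d turns a stitched six-view RGB image and the matching canonical
# coordinate map (CCM) into triplane features, optionally denoises them with
# model.unet2, decodes vertices/faces, cleans the mesh with kiui, and exports
# an .obj via model.export_mesh. nvdiffrast (dr), Mesh and zipfile are only
# needed by the commented-out UV export path further below.
# ---------------------------------------------------------------------------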
def generate3d(model, rgb, ccm, device):
    model.renderer = Renderer(tet_grid_size=model.tet_grid_size, camera_angle_num=model.camera_angle_num,
                              scale=model.input.scale, geo_type=model.geo_type)

    # Normalize the stitched six-view images to [0, 1] and move channels first;
    # the CCM channel order is reversed before use.
    color_tri = torch.from_numpy(rgb) / 255
    xyz_tri = torch.from_numpy(ccm[:, :, (2, 1, 0)]) / 255
    color = color_tri.permute(2, 0, 1)
    xyz = xyz_tri.permute(2, 0, 1)
    def get_imgs(color):
        # color : [C, H, W*6]
        color_list = []
        color_list.append(color[:, :, 256 * 5:256 * (1 + 5)])
        for i in range(0, 5):
            color_list.append(color[:, :, 256 * i:256 * (1 + i)])
        return torch.stack(color_list, dim=0)  # [6, C, H, W]
    triplane_color = get_imgs(color).permute(0, 2, 3, 1).unsqueeze(0).to(device)  # [1, 6, H, W, C]

    color = get_imgs(color)
    xyz = get_imgs(xyz)

    # Build the triplane input: color planes and coordinate (xyz) planes concatenated along dim=1.
    color = get_tri(color, dim=0, blender=True, scale=1).unsqueeze(0)
    xyz = get_tri(xyz, dim=0, blender=True, scale=1, fix=True).unsqueeze(0)

    triplane = torch.cat([color, xyz], dim=1).to(device)
    # 3D visualize
    model.eval()
    if model.denoising:
        # Perturb the triplane at a fixed timestep (t=20) and denoise it with the UNet.
        tnew = 20
        tnew = torch.randint(tnew, tnew + 1, [triplane.shape[0]], dtype=torch.long, device=triplane.device)
        noise_new = torch.randn_like(triplane) * 0.5 + 0.5
        triplane = model.scheduler.add_noise(triplane, noise_new, tnew)
        start_time = time.time()
        with torch.no_grad():
            triplane_feature2 = model.unet2(triplane, tnew)
        end_time = time.time()
        elapsed_time = end_time - start_time
        print(f"unet takes {elapsed_time}s")
    else:
        triplane_feature2 = model.unet2(triplane)
    with torch.no_grad():
        data_config = {
            'resolution': [1024, 1024],
            "triview_color": triplane_color.to(device),
        }

        verts, faces = model.decode(data_config, triplane_feature2)

        data_config['verts'] = verts[0]
        data_config['faces'] = faces
    # Clean and remesh the decoded geometry with kiui before export.
    from kiui.mesh_utils import clean_mesh
    verts, faces = clean_mesh(
        data_config['verts'].squeeze().cpu().numpy().astype(np.float32),
        data_config['faces'].squeeze().cpu().numpy().astype(np.int32),
        repair=False, remesh=True, remesh_size=0.01, remesh_iters=10
    )
    data_config['verts'] = torch.from_numpy(verts).to(device).contiguous()
    data_config['faces'] = torch.from_numpy(faces).to(device).contiguous()
    start_time = time.time()
    with torch.no_grad():
        mesh_path_glb = tempfile.NamedTemporaryFile(suffix="", delete=False).name
        model.export_mesh(data_config, mesh_path_glb, tri_fea_2=triplane_feature2)

        # Legacy UV-textured export path, kept for reference:
        # glctx = dr.RasterizeGLContext()  # dr.RasterizeCudaContext()
        # mesh_path_obj = tempfile.NamedTemporaryFile(suffix="", delete=False).name
        # model.export_mesh_wt_uv(glctx, data_config, mesh_path_obj, "", device, res=(1024, 1024), tri_fea_2=triplane_feature2)
        # mesh = Mesh.load(mesh_path_obj + ".obj", bound=0.9, front_dir="+z")
        # mesh_path_glb = tempfile.NamedTemporaryFile(suffix="", delete=False).name
        # mesh.write(mesh_path_glb + ".glb")
        # # mesh_obj2 = trimesh.load(mesh_path_glb + ".glb", file_type='glb')
        # # mesh_path_obj2 = tempfile.NamedTemporaryFile(suffix="", delete=False).name
        # # mesh_obj2.export(mesh_path_obj2 + ".obj")
        # with zipfile.ZipFile(mesh_path_obj + '.zip', 'w') as myzip:
        #     myzip.write(mesh_path_obj + '.obj', mesh_path_obj.split("/")[-1] + '.obj')
        #     myzip.write(mesh_path_obj + '.png', mesh_path_obj.split("/")[-1] + '.png')
        #     myzip.write(mesh_path_obj + '.mtl', mesh_path_obj.split("/")[-1] + '.mtl')
    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f"mesh export takes {elapsed_time}s")

    # print("Raw color min/max/mean:", colors.min(), colors.max(), colors.mean())
    # print("Scaled color min/max/mean:", colors.min(), colors.max(), colors.mean())
    return mesh_path_glb + ".obj"
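
# ---------------------------------------------------------------------------
# Hypothetical usage sketch (assumptions, not part of this module): how
# generate3d is typically driven. The checkpoint loader name ("load_model"),
# the image file names, and the use of OpenCV for loading are assumptions;
# substitute the repo's actual entry point. The inputs are the stitched
# six-view RGB image and the matching CCM, both [H, 256*6, 3] in 0-255.
# ---------------------------------------------------------------------------
# if __name__ == "__main__":
#     import cv2
#
#     device = "cuda" if torch.cuda.is_available() else "cpu"
#     model = load_model().to(device)                # hypothetical checkpoint loader
#     rgb = cv2.cvtColor(cv2.imread("pixel_images.png"), cv2.COLOR_BGR2RGB)
#     ccm = cv2.imread("ccm_images.png")             # left in BGR; generate3d reorders the channels
#     obj_path = generate3d(model, rgb, ccm, device)
#     print("mesh exported to", obj_path)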