import numpy as np
import torch
import time
import nvdiffrast.torch as dr
from util.utils import get_tri
import tempfile
from mesh import Mesh
import zipfile
from util.renderer import Renderer
import trimesh

def generate3d(model, rgb, ccm, device):
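    """Reconstruct a textured 3D mesh from a stitched six-view RGB image and CCM
    (canonical coordinate map), and return the path of the exported .glb file.

    Shapes are inferred from the 256-pixel slicing in get_imgs below:
        rgb: H x (6*256) x 3 array in [0, 255], six RGB views stitched along the width.
        ccm: H x (6*256) x 3 array in [0, 255], the matching coordinate maps.
    `model` must expose tet_grid_size, camera_angle_num, input.scale, geo_type,
    denoising, scheduler, unet2, decode and export_mesh.
    """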

    model.renderer = Renderer(tet_grid_size=model.tet_grid_size, camera_angle_num=model.camera_angle_num,
                              scale=model.input.scale, geo_type=model.geo_type)

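    # Normalize the stitched multi-view images to [0, 1]; the CCM's channel order is
    # reversed before use. Both are moved to channel-first (C, H, W) layout.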
    color_tri = torch.from_numpy(rgb)/255
    xyz_tri = torch.from_numpy(ccm[:,:,(2,1,0)])/255
    color = color_tri.permute(2,0,1)
    xyz = xyz_tri.permute(2,0,1)


    def get_imgs(img):
        # img: [C, H, W*6] -- six 256-pixel-wide views stitched along the width.
        # Split into individual views, placing the sixth view first, then views 0-4.
        img_list = []
        img_list.append(img[:, :, 256 * 5:256 * (1 + 5)])
        for i in range(5):
            img_list.append(img[:, :, 256 * i:256 * (1 + i)])
        return torch.stack(img_list, dim=0)  # [6, C, H, W]
    
    triplane_color = get_imgs(color).permute(0, 2, 3, 1).unsqueeze(0).to(device)  # [1, 6, H, W, C]

    color = get_imgs(color)
    xyz = get_imgs(xyz)

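    # Roll the six views out into triplane-style feature maps (get_tri); fix=True is
    # applied only to the coordinate maps.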
    color = get_tri(color, dim=0, blender=True, scale=1).unsqueeze(0)
    xyz = get_tri(xyz, dim=0, blender=True, scale=1, fix=True).unsqueeze(0)

    triplane = torch.cat([color,xyz],dim=1).to(device)
    # Run the model in eval mode for inference.
    model.eval()
    

    if model.denoising:
        # Perturb the triplane at a fixed timestep (t = 20) and let the second UNet
        # denoise it.
        tnew = 20
        tnew = torch.randint(tnew, tnew + 1, [triplane.shape[0]], dtype=torch.long, device=triplane.device)
        noise_new = torch.randn_like(triplane) * 0.5 + 0.5
        triplane = model.scheduler.add_noise(triplane, noise_new, tnew)
        start_time = time.time()
        with torch.no_grad():
            triplane_feature2 = model.unet2(triplane, tnew)
        elapsed_time = time.time() - start_time
        print(f"unet takes {elapsed_time}s")
    else:
        with torch.no_grad():
            triplane_feature2 = model.unet2(triplane)
        

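    # Decode the (optionally denoised) triplane features into an explicit mesh.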
    with torch.no_grad():
        data_config = {
            'resolution': [1024, 1024],
            "triview_color": triplane_color.to(device),
        }

        verts, faces = model.decode(data_config, triplane_feature2)

        data_config['verts'] = verts[0]
        data_config['faces'] = faces
        

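    # Clean up and remesh the raw extraction before UV unwrapping and export.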
    from kiui.mesh_utils import clean_mesh
    verts, faces = clean_mesh(
        data_config['verts'].squeeze().cpu().numpy().astype(np.float32),
        data_config['faces'].squeeze().cpu().numpy().astype(np.int32),
        repair=False, remesh=True, remesh_size=0.005, remesh_iters=1)
    data_config['verts'] = torch.from_numpy(verts).to(device).contiguous()
    data_config['faces'] = torch.from_numpy(faces).to(device).contiguous()

    start_time = time.time()
    with torch.no_grad():
        # export_mesh writes the textured mesh to "<base path>.obj"; it is converted
        # to .glb below.
        mesh_path_base = tempfile.NamedTemporaryFile(suffix="", delete=False).name
        model.export_mesh(data_config, mesh_path_base, tri_fea_2=triplane_feature2)

    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f"uv takes {elapsed_time}s")

    # Convert .obj (with vertex colors) to .glb
    obj_path = mesh_path_base + ".obj"
    glb_path = mesh_path_base + ".glb"
    mesh = trimesh.load(obj_path, process=False)
    mesh.export(glb_path)
    return glb_path
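
# Minimal usage sketch (hypothetical file names; the real caller lives elsewhere in
# this Space). `model` must provide the attributes listed in the docstring above, and
# the inputs are the stitched six-view images as numpy arrays:
#
#   import numpy as np
#   from PIL import Image
#
#   rgb = np.array(Image.open("pixel_images.png").convert("RGB"))
#   ccm = np.array(Image.open("xyz_images.png").convert("RGB"))
#   glb_path = generate3d(model, rgb, ccm, "cuda")
#   print("Exported:", glb_path)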