Spaces:
mashroo
/
Running on Zero

File size: 4,193 Bytes
d72a5f9
d63ad23
d72a5f9
9a3d596
d72a5f9
 
9a3d596
 
d72a5f9
9a3d596
9d0b3b4
9a3d596
 
 
 
 
 
 
9d0b3b4
 
d72a5f9
9a3d596
d72a5f9
9a3d596
 
 
 
 
 
d72a5f9
 
 
 
9a3d596
 
d72a5f9
9a3d596
 
d72a5f9
9a3d596
db2fd1d
9a3d596
d72a5f9
9a3d596
 
 
 
d72a5f9
9a3d596
 
 
 
d72a5f9
 
9a3d596
d65fe1c
d72a5f9
 
6b79750
d72a5f9
 
 
 
9a3d596
d72a5f9
 
9a3d596
d72a5f9
 
5175c4e
 
 
dc9d3c9
5175c4e
 
 
d72a5f9
eb6226d
 
 
 
411e8a2
9a3d596
 
 
411e8a2
9a3d596
 
 
e5c94c9
9a3d596
 
 
e5c94c9
9a3d596
 
 
 
a34ef67
9a3d596
 
 
b913122
3fc7370
 
9a3d596
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import numpy as np
import torch
import time
import nvdiffrast.torch as dr
from util.utils import get_tri
import tempfile
from mesh import Mesh
import zipfile
from util.renderer import Renderer
def generate3d(model, rgb, ccm, device):
    """Build a mesh from a six-view colour strip and its coordinate map.

    The two input images are split into six 256-pixel-wide views, turned into
    a triplane with ``get_tri``, passed through ``model.unet2`` and
    ``model.decode``, remeshed with ``kiui``'s ``clean_mesh`` and finally
    handed to ``model.export_mesh``.

    Side effects: replaces ``model.renderer`` with a fresh ``Renderer``,
    switches ``model`` to eval mode, prints two timing lines and leaves an
    empty placeholder temp file (used to reserve a unique path) on disk.

    Args:
        model: Project model object. This function reads ``tet_grid_size``,
            ``camera_angle_num``, ``input.scale``, ``geo_type``, ``denoising``
            and ``scheduler`` from it, and calls ``unet2``, ``decode`` and
            ``export_mesh``.
        rgb: NumPy image indexed ``[H, W_total, C]`` holding six views side
            by side along the width (presumably 8-bit, since values are
            divided by 255 -- confirm against the caller).
        ccm: NumPy image with the same layout as ``rgb``; its first three
            channels are read in reversed order ``(2, 1, 0)``.
        device: Torch device that the network inputs are moved to.

    Returns:
        The temp-file base path with ``".obj"`` appended. ``export_mesh`` is
        presumably what writes that file -- its implementation is not visible
        here.
    """
    model.renderer = Renderer(
        tet_grid_size=model.tet_grid_size,
        camera_angle_num=model.camera_angle_num,
        scale=model.input.scale,
        geo_type=model.geo_type,
    )

    color = (torch.from_numpy(rgb) / 255).permute(2, 0, 1)
    xyz = (torch.from_numpy(ccm[:, :, (2, 1, 0)]) / 255).permute(2, 0, 1)

    view_size = 256
    # The last view of the strip goes first, followed by views 0..4.
    view_order = (5, 0, 1, 2, 3, 4)

    def get_imgs(strip):
        # strip : [C, H, W*6]
        views = [
            strip[:, :, view_size * i:view_size * (i + 1)] for i in view_order
        ]
        return torch.stack(views, dim=0)  # [6, C, H, W]

    # [1, 6, H, W, C]
    triplane_color = get_imgs(color).permute(0, 2, 3, 1).unsqueeze(0).to(device)

    # Stack the colour views a second time instead of reusing the tensor
    # above, so triplane_color never shares storage with get_tri's input.
    color = get_imgs(color)
    xyz = get_imgs(xyz)

    color = get_tri(color, dim=0, blender=True, scale=1).unsqueeze(0)
    xyz = get_tri(xyz, dim=0, blender=True, scale=1, fix=True).unsqueeze(0)

    triplane = torch.cat([color, xyz], dim=1).to(device)

    model.eval()

    if model.denoising:
        timestep = 20
        # randint over [timestep, timestep + 1) always yields `timestep`; it is
        # kept (rather than torch.full) so the RNG stream is left unchanged.
        tnew = torch.randint(
            timestep,
            timestep + 1,
            [triplane.shape[0]],
            dtype=torch.long,
            device=triplane.device,
        )
        noise_new = torch.randn_like(triplane) * 0.5 + 0.5
        triplane = model.scheduler.add_noise(triplane, noise_new, tnew)
        start_time = time.time()
        with torch.no_grad():
            triplane_feature2 = model.unet2(triplane, tnew)
        elapsed_time = time.time() - start_time
        print(f"unet takes {elapsed_time}s")
    else:
        # Inference only: skip autograd bookkeeping here as well, matching the
        # denoising branch.
        with torch.no_grad():
            triplane_feature2 = model.unet2(triplane)

    with torch.no_grad():
        data_config = {
            'resolution': [1024, 1024],
            "triview_color": triplane_color,
        }

        verts, faces = model.decode(data_config, triplane_feature2)

        data_config['verts'] = verts[0]
        data_config['faces'] = faces

    # Imported lazily, as in the original, so kiui is only needed when a mesh
    # is actually generated.
    from kiui.mesh_utils import clean_mesh
    verts, faces = clean_mesh(
        data_config['verts'].squeeze().cpu().numpy().astype(np.float32),
        data_config['faces'].squeeze().cpu().numpy().astype(np.int32),
        repair=False, remesh=True, remesh_size=0.01, remesh_iters=10
    )
    data_config['verts'] = torch.from_numpy(verts).to(device).contiguous()
    data_config['faces'] = torch.from_numpy(faces).to(device).contiguous()

    start_time = time.time()
    # Reserve a unique base path. The handle is closed right away; the empty
    # placeholder file stays on disk (delete=False) so the name is not reused.
    with tempfile.NamedTemporaryFile(suffix="", delete=False) as placeholder:
        mesh_path_base = placeholder.name
    with torch.no_grad():
        model.export_mesh(data_config, mesh_path_base, tri_fea_2=triplane_feature2)
    elapsed_time = time.time() - start_time
    print(f"uv takes {elapsed_time}s")

    return mesh_path_base + ".obj"