Commit 1fa688f
1 Parent(s): 37c1f6f
Refactor app.py and inference.py to streamline the image-generation and mesh-export flow. Remove unnecessary CPU transfers and temporary-file handling, returning the generated GLB path directly. Update the mesh export in the CRM model to support vertex colors and improve the efficiency of texture mapping.
Files changed:
- app.py (+4 -33)
- inference.py (+63 -69)
- model/crm/model.py (+3 -3)
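The core contract after this refactor: generate3d writes the mesh itself and returns a .glb path that the Gradio outputs consume directly, so gen_image no longer copies results through tempfile/shutil. A minimal sketch of that contract, with a stub standing in for the real model call (generate3d_stub, gen_image_stub, and the placeholder bytes are hypothetical, not part of the commit):

    import tempfile

    import numpy as np
    from PIL import Image

    def generate3d_stub(np_imgs, np_xyzs):
        # Stand-in for inference.generate3d: write a .glb and hand back its path.
        glb_path = tempfile.NamedTemporaryFile(suffix=".glb", delete=False).name
        with open(glb_path, "wb") as f:
            f.write(b"glTF")  # placeholder bytes; the real code exports a mesh here
        return glb_path

    def gen_image_stub(np_imgs, np_xyzs):
        # Arrays are assumed uint8 HxWx3. No temp-file copying: the returned
        # path is what a gr.Model3D output component expects.
        glb_path = generate3d_stub(np_imgs, np_xyzs)
        return Image.fromarray(np_imgs), Image.fromarray(np_xyzs), glb_path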
app.py
CHANGED
@@ -6,7 +6,6 @@ from omegaconf import OmegaConf
 import torch
 from PIL import Image
 import PIL
-import base64
 from pipelines import TwoStagePipeline
 from huggingface_hub import hf_hub_download
 import os
@@ -94,9 +93,7 @@ def preprocess_image(image, background_choice, foreground_ratio, backgroud_color
     image = add_background(image, backgroud_color)
     return image.convert("RGB")
 
-
 @spaces.GPU
-
 def gen_image(input_image, seed, scale, step):
     global pipeline, model, args
     pipeline.set_seed(seed)
@@ -105,26 +102,12 @@ def gen_image(input_image, seed, scale, step):
     stage2_images = rt_dict["stage2_images"]
     np_imgs = np.concatenate(stage1_images, 1)
    np_xyzs = np.concatenate(stage2_images, 1)
-
+
     glb_path = generate3d(model, np_imgs, np_xyzs, args.device)
-
-    # Create a temporary file with a proper name for the GLB data
-    import tempfile
-    import shutil
-
-    # Create a temporary file with a proper extension
-    temp_glb = tempfile.NamedTemporaryFile(suffix='.glb', delete=False)
-    temp_glb.close()
-
-    # Copy the generated GLB file to our temporary file
-    shutil.copy2(glb_path, temp_glb.name)
-
-    # Return images and the path to the temporary GLB file
-    return Image.fromarray(np_imgs), Image.fromarray(np_xyzs), temp_glb.name
+    return Image.fromarray(np_imgs), Image.fromarray(np_xyzs), glb_path#, obj_path
 
 
 parser = argparse.ArgumentParser()
-
 parser.add_argument(
     "--stage1_config",
     type=str,
@@ -137,6 +120,7 @@ parser.add_argument(
     default="configs/stage2-v2-snr.yaml",
     help="config for stage2",
 )
+
 parser.add_argument("--device", type=str, default="cuda")
 args = parser.parse_args()
 
@@ -146,19 +130,6 @@ model = CRM(specs)
 model.load_state_dict(torch.load(crm_path, map_location="cpu"), strict=False)
 model = model.to(args.device)
 
-# After loading or instantiating the model, ensure everything is on CPU
-model = model.cpu()
-if hasattr(model, 'rgbMlp'):
-    model.rgbMlp = model.rgbMlp.cpu()
-if hasattr(model, 'decoder'):
-    model.decoder = model.decoder.cpu()
-if hasattr(model, 'unet2'):
-    model.unet2 = model.unet2.cpu()
-    if hasattr(model.unet2, 'unet'):
-        model.unet2.unet = model.unet2.unet.cpu()
-if hasattr(model, 'lora'):
-    model.lora = model.lora.cpu()
-
 stage1_config = OmegaConf.load(args.stage1_config).config
 stage2_config = OmegaConf.load(args.stage2_config).config
 stage2_sampler_config = stage2_config.sampler
@@ -262,4 +233,4 @@ with gr.Blocks() as demo:
     inputs=inputs,
     outputs=outputs,
 )
-demo.queue().launch()
+demo.queue().launch()
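The deleted block in app.py pinned the model and several submodules back to the CPU immediately after model = model.to(args.device), undoing the placement. On a ZeroGPU Space the usual pattern is what remains after this commit: a single .to(device) call, with the @spaces.GPU-decorated handler running while a GPU is attached. A hedged sketch of that pattern (load_model is a hypothetical loader, not from this repo):

    import spaces
    import torch

    model = load_model()        # hypothetical loader; weights arrive on CPU
    model = model.to("cuda")    # single device placement, as app.py now does

    @spaces.GPU
    def infer(x):
        # Runs with the GPU attached; no manual .cpu() pinning of submodules needed.
        with torch.no_grad():
            return model(x)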
inference.py
CHANGED
@@ -1,96 +1,90 @@
 import numpy as np
 import torch
 import time
-import tempfile
-import zipfile
 import nvdiffrast.torch as dr
-import xatlas
-import cv2
-import trimesh
-
 from util.utils import get_tri
+import tempfile
 from mesh import Mesh
+import zipfile
 from util.renderer import Renderer
-
-
+import trimesh
 
 def generate3d(model, rgb, ccm, device):
+
     model.renderer = Renderer(tet_grid_size=model.tet_grid_size, camera_angle_num=model.camera_angle_num,
-                              scale=model.input.scale, geo_type=model.geo_type)
+                              scale=model.input.scale, geo_type = model.geo_type)
+
+    color_tri = torch.from_numpy(rgb)/255
+    xyz_tri = torch.from_numpy(ccm[:,:,(2,1,0)])/255
+    color = color_tri.permute(2,0,1)
+    xyz = xyz_tri.permute(2,0,1)
 
-    color_tri = torch.from_numpy(rgb) / 255
-    xyz_tri = torch.from_numpy(ccm[:, :, (2, 1, 0)]) / 255
-    color = color_tri.permute(2, 0, 1)
-    xyz = xyz_tri.permute(2, 0, 1)
 
     def get_imgs(color):
-
+        # color : [C, H, W*6]
+        color_list = []
+        color_list.append(color[:,:,256*5:256*(1+5)])
+        for i in range(0,5):
+            color_list.append(color[:,:,256*i:256*(1+i)])
+        return torch.stack(color_list, dim=0)# [6, C, H, W]
+
+    triplane_color = get_imgs(color).permute(0,2,3,1).unsqueeze(0).to(device)# [1, 6, H, W, C]
 
-    triplane_color = get_imgs(color).permute(0, 2, 3, 1).unsqueeze(0).to(device)
     color = get_imgs(color)
     xyz = get_imgs(xyz)
 
-    color = get_tri(color, dim=0, blender=True, scale=1).unsqueeze(0)
-    xyz = get_tri(xyz, dim=0, blender=True, scale=1, fix=True).unsqueeze(0)
-    triplane = torch.cat([color, xyz], dim=1).to(device)
+    color = get_tri(color, dim=0, blender= True, scale = 1).unsqueeze(0)
+    xyz = get_tri(xyz, dim=0, blender= True, scale = 1, fix= True).unsqueeze(0)
 
+    triplane = torch.cat([color,xyz],dim=1).to(device)
+    # 3D visualize
     model.eval()
-
-
-
-
+
+
+    if model.denoising == True:
+        tnew = 20
+        tnew = torch.randint(tnew, tnew+1, [triplane.shape[0]], dtype=torch.long, device=triplane.device)
+        noise_new = torch.randn_like(triplane) *0.5+0.5
+        triplane = model.scheduler.add_noise(triplane, noise_new, tnew)
+        start_time = time.time()
         with torch.no_grad():
-            triplane_feature2 = model.unet2(triplane, tnew)
+            triplane_feature2 = model.unet2(triplane,tnew)
+        end_time = time.time()
+        elapsed_time = end_time - start_time
+        print(f"unet takes {elapsed_time}s")
     else:
-
-
-
-    data_config = {
-        'resolution': [1024, 1024],
-        "triview_color": triplane_color.to(device),
-    }
+        triplane_feature2 = model.unet2(triplane)
+
 
     with torch.no_grad():
+        data_config = {
+            'resolution': [1024, 1024],
+            "triview_color": triplane_color.to(device),
+        }
+
         verts, faces = model.decode(data_config, triplane_feature2)
+
         data_config['verts'] = verts[0]
         data_config['faces'] = faces
+
+
+        from kiui.mesh_utils import clean_mesh
+        verts, faces = clean_mesh(data_config['verts'].squeeze().cpu().numpy().astype(np.float32), data_config['faces'].squeeze().cpu().numpy().astype(np.int32), repair = False, remesh=True, remesh_size=0.005, remesh_iters=1)
+        data_config['verts'] = torch.from_numpy(verts).cuda().contiguous()
+        data_config['faces'] = torch.from_numpy(faces).cuda().contiguous()
+
+    start_time = time.time()
+    with torch.no_grad():
+        mesh_path_glb = tempfile.NamedTemporaryFile(suffix=f"", delete=False).name
+        model.export_mesh(data_config, mesh_path_glb, tri_fea_2 = triplane_feature2)
+
+    end_time = time.time()
+    elapsed_time = end_time - start_time
+    print(f"uv takes {elapsed_time}s")
 
-
-
-
-
-    )
-
-    data_config['faces'] = torch.from_numpy(faces).contiguous()
-
-    # CPU-only UV unwrapping with xatlas
-    mesh_v = data_config['verts'].cpu().numpy()
-    mesh_f = data_config['faces'].cpu().numpy()
-    vmapping, ft, vt = xatlas.parametrize(mesh_v, mesh_f)
-
-    # Use per-vertex colors if available, else fallback to white
-    vertex_colors = np.ones((mesh_v.shape[0], 3), dtype=np.float32)  # fallback: white
-    # If you have per-vertex color, you can assign here, e.g.:
-    # vertex_colors = ...
-
-    # Bake vertex colors to texture in UV space
-    tex_res = (1024, 1024)
-    texture = np.zeros((tex_res[1], tex_res[0], 3), dtype=np.float32)
-    vt_img = (vt * np.array(tex_res)).astype(np.int32)
-    for face, uv_idx in zip(mesh_f, ft):
-        pts = vt_img[uv_idx]
-        color = vertex_colors[face].mean(axis=0)
-        cv2.fillPoly(texture, [pts], color.tolist())
-    texture = np.clip(texture, 0, 1)
-
-    # Create Mesh and export .glb
-    mesh = Mesh(
-        v=torch.from_numpy(mesh_v).float(),
-        f=torch.from_numpy(mesh_f).int(),
-        vt=torch.from_numpy(vt).float(),
-        ft=torch.from_numpy(ft).int(),
-        albedo=torch.from_numpy(texture).float()
-    )
-    temp_path = tempfile.NamedTemporaryFile(suffix=".glb", delete=False).name
-    mesh.write(temp_path)
-    return temp_path
+    # Convert .obj (with vertex colors) to .glb
+    obj_path = mesh_path_glb + ".obj"
+    glb_path = mesh_path_glb + ".glb"
+    mesh = trimesh.load(obj_path, process=False)
+    mesh.export(glb_path)
+    return glb_path
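The new tail of generate3d delegates the format conversion to trimesh: model.export_mesh writes an .obj carrying per-vertex colors, and trimesh.load(..., process=False) re-reads it without merging or reordering vertices before exporting the .glb. A standalone sketch of that round trip on a toy mesh (the box and random colors are illustrative only, not from the commit):

    import tempfile

    import numpy as np
    import trimesh

    # Toy mesh with per-vertex RGBA colors standing in for the model's output.
    mesh = trimesh.creation.box()
    mesh.visual.vertex_colors = (np.random.rand(len(mesh.vertices), 4) * 255).astype(np.uint8)

    base = tempfile.NamedTemporaryFile(suffix="", delete=False).name
    obj_path, glb_path = base + ".obj", base + ".glb"

    mesh.export(obj_path)                             # OBJ stores colors as extra floats on 'v' lines
    reloaded = trimesh.load(obj_path, process=False)  # process=False: keep vertices as written
    reloaded.export(glb_path)                         # GLB carries the colors as a COLOR_0 attribute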
model/crm/model.py
CHANGED
@@ -107,8 +107,8 @@ class CRM(nn.Module):
 
         # export the final mesh
         with torch.no_grad():
-            mesh = trimesh.Trimesh(
-            mesh.export(f
+            mesh = trimesh.Trimesh(verts, faces, vertex_colors=colors, process=False) # important, process=True leads to seg fault...
+            mesh.export(f'{out_dir}.obj')
 
     def export_mesh_wt_uv(self, ctx, data, out_dir, ind, device, res, tri_fea_2=None):
 
@@ -214,4 +214,4 @@ class CRM(nn.Module):
         img = img.clip(0, 255).astype(np.uint8)
 
         cv2.imwrite(f'{out_dir}.png', img[..., [2, 1, 0]])
-        # cv2.imwrite(f'{out_dir}/{ind}.png', img[..., [2, 1, 0]])
+        # cv2.imwrite(f'{out_dir}/{ind}.png', img[..., [2, 1, 0]])
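For reference, the restored export line in model/crm/model.py amounts to the following construction, with a toy box standing in for the decoded geometry (the array shapes in the comments are assumptions):

    import numpy as np
    import trimesh

    box = trimesh.creation.box()
    verts = np.asarray(box.vertices, dtype=np.float32)  # (N, 3) positions
    faces = np.asarray(box.faces, dtype=np.int64)       # (M, 3) triangle indices
    colors = np.random.rand(len(verts), 3)              # (N, 3) RGB in [0, 1]

    # process=False mirrors the in-diff comment: trimesh's mesh "processing"
    # (vertex merging/reordering) reportedly segfaults on this data, so it is skipped.
    mesh = trimesh.Trimesh(verts, faces, vertex_colors=colors, process=False)
    mesh.export("mesh.obj")  # the .obj keeps per-vertex colors for the later .glb conversion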