Spaces: mashroo / Running on Zero

YoussefAnso committed
Commit 1fa688f · 1 Parent(s): 37c1f6f

Refactor app.py and inference.py to streamline image generation and mesh export. Remove unnecessary CPU transfers and temporary-file handling, returning the generated GLB path directly. Update mesh export in the CRM model to support vertex colors and improve texture-mapping efficiency.

Files changed (3):
  1. app.py +4 -33
  2. inference.py +63 -69
  3. model/crm/model.py +3 -3
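In outline, the refactor makes generate3d responsible for producing a finished .glb and hands its path straight to the UI. A condensed sketch of the new export path, assembled from the diffs below (not a literal excerpt; export_glb is an illustrative helper name):

    import tempfile
    import trimesh

    def export_glb(model, data_config, triplane_feature2):
        # CRM.export_mesh now writes "<stem>.obj" with per-vertex colors.
        stem = tempfile.NamedTemporaryFile(suffix="", delete=False).name
        model.export_mesh(data_config, stem, tri_fea_2=triplane_feature2)
        # Load without processing so vertex order and colors survive,
        # then re-export as binary glTF for the viewer.
        mesh = trimesh.load(stem + ".obj", process=False)
        mesh.export(stem + ".glb")
        return stem + ".glb"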
app.py CHANGED
@@ -6,7 +6,6 @@ from omegaconf import OmegaConf
 import torch
 from PIL import Image
 import PIL
-import base64
 from pipelines import TwoStagePipeline
 from huggingface_hub import hf_hub_download
 import os
@@ -94,9 +93,7 @@ def preprocess_image(image, background_choice, foreground_ratio, backgroud_color):
     image = add_background(image, backgroud_color)
     return image.convert("RGB")

-
 @spaces.GPU
-
 def gen_image(input_image, seed, scale, step):
     global pipeline, model, args
     pipeline.set_seed(seed)
@@ -105,26 +102,12 @@ def gen_image(input_image, seed, scale, step):
     stage2_images = rt_dict["stage2_images"]
     np_imgs = np.concatenate(stage1_images, 1)
     np_xyzs = np.concatenate(stage2_images, 1)
-
+
     glb_path = generate3d(model, np_imgs, np_xyzs, args.device)
-
-    # Create a temporary file with a proper name for the GLB data
-    import tempfile
-    import shutil
-
-    # Create a temporary file with a proper extension
-    temp_glb = tempfile.NamedTemporaryFile(suffix='.glb', delete=False)
-    temp_glb.close()
-
-    # Copy the generated GLB file to our temporary file
-    shutil.copy2(glb_path, temp_glb.name)
-
-    # Return images and the path to the temporary GLB file
-    return Image.fromarray(np_imgs), Image.fromarray(np_xyzs), temp_glb.name
+    return Image.fromarray(np_imgs), Image.fromarray(np_xyzs), glb_path#, obj_path


 parser = argparse.ArgumentParser()
-
 parser.add_argument(
     "--stage1_config",
     type=str,
@@ -137,6 +120,7 @@ parser.add_argument(
     default="configs/stage2-v2-snr.yaml",
     help="config for stage2",
 )
+
 parser.add_argument("--device", type=str, default="cuda")
 args = parser.parse_args()

@@ -146,19 +130,6 @@ model = CRM(specs)
 model.load_state_dict(torch.load(crm_path, map_location="cpu"), strict=False)
 model = model.to(args.device)

-# After loading or instantiating the model, ensure everything is on CPU
-model = model.cpu()
-if hasattr(model, 'rgbMlp'):
-    model.rgbMlp = model.rgbMlp.cpu()
-if hasattr(model, 'decoder'):
-    model.decoder = model.decoder.cpu()
-if hasattr(model, 'unet2'):
-    model.unet2 = model.unet2.cpu()
-    if hasattr(model.unet2, 'unet'):
-        model.unet2.unet = model.unet2.unet.cpu()
-if hasattr(model, 'lora'):
-    model.lora = model.lora.cpu()
-
 stage1_config = OmegaConf.load(args.stage1_config).config
 stage2_config = OmegaConf.load(args.stage2_config).config
 stage2_sampler_config = stage2_config.sampler
@@ -262,4 +233,4 @@ with gr.Blocks() as demo:
         inputs=inputs,
         outputs=outputs,
     )
-demo.queue().launch()
+demo.queue().launch()
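Downstream, the third value returned by gen_image can feed a gr.Model3D component directly, since that component accepts a filesystem path to a .glb. A minimal wiring sketch, assuming gen_image as defined above (the component names here are illustrative, not taken from this diff):

    import gradio as gr

    with gr.Blocks() as demo:
        input_image = gr.Image(type="pil", label="input")
        seed = gr.Number(value=1234, label="seed", precision=0)
        scale = gr.Number(value=5.5, label="guidance scale")
        step = gr.Number(value=30, label="steps", precision=0)
        img_out = gr.Image(label="stage-1 views")
        xyz_out = gr.Image(label="stage-2 CCMs")
        glb_out = gr.Model3D(label="output GLB")
        gr.Button("Generate").click(
            gen_image,  # defined in app.py above
            inputs=[input_image, seed, scale, step],
            outputs=[img_out, xyz_out, glb_out],
        )
    demo.queue().launch()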
inference.py CHANGED
@@ -1,96 +1,90 @@
 import numpy as np
 import torch
 import time
-import tempfile
-import zipfile
 import nvdiffrast.torch as dr
-import xatlas
-import cv2
-import trimesh
-
 from util.utils import get_tri
+import tempfile
 from mesh import Mesh
+import zipfile
 from util.renderer import Renderer
-from kiui.mesh_utils import clean_mesh
-
+import trimesh

 def generate3d(model, rgb, ccm, device):
+
     model.renderer = Renderer(tet_grid_size=model.tet_grid_size, camera_angle_num=model.camera_angle_num,
-                              scale=model.input.scale, geo_type=model.geo_type)
+                              scale=model.input.scale, geo_type = model.geo_type)

-    color_tri = torch.from_numpy(rgb) / 255
-    xyz_tri = torch.from_numpy(ccm[:, :, (2, 1, 0)]) / 255
-    color = color_tri.permute(2, 0, 1)
-    xyz = xyz_tri.permute(2, 0, 1)
+    color_tri = torch.from_numpy(rgb)/255
+    xyz_tri = torch.from_numpy(ccm[:,:,(2,1,0)])/255
+    color = color_tri.permute(2,0,1)
+    xyz = xyz_tri.permute(2,0,1)

     def get_imgs(color):
-        return torch.stack([color[:, :, 256 * i:256 * (i + 1)] for i in [5, 0, 1, 2, 3, 4]], dim=0)
+        # color : [C, H, W*6]
+        color_list = []
+        color_list.append(color[:,:,256*5:256*(1+5)])
+        for i in range(0,5):
+            color_list.append(color[:,:,256*i:256*(1+i)])
+        return torch.stack(color_list, dim=0)# [6, C, H, W]

-    triplane_color = get_imgs(color).permute(0, 2, 3, 1).unsqueeze(0).to(device)
+    triplane_color = get_imgs(color).permute(0,2,3,1).unsqueeze(0).to(device)# [1, 6, H, W, C]
+
     color = get_imgs(color)
     xyz = get_imgs(xyz)

-    color = get_tri(color, dim=0, blender=True, scale=1).unsqueeze(0)
-    xyz = get_tri(xyz, dim=0, blender=True, scale=1, fix=True).unsqueeze(0)
-    triplane = torch.cat([color, xyz], dim=1).to(device)
+    color = get_tri(color, dim=0, blender= True, scale = 1).unsqueeze(0)
+    xyz = get_tri(xyz, dim=0, blender= True, scale = 1, fix= True).unsqueeze(0)

+    triplane = torch.cat([color,xyz],dim=1).to(device)
+    # 3D visualize
     model.eval()
-    if model.denoising:
-        tnew = torch.randint(20, 21, [triplane.shape[0]], dtype=torch.long, device=triplane.device)
-        noise_new = torch.randn_like(triplane) * 0.5 + 0.5
-        triplane = model.scheduler.add_noise(triplane, noise_new, tnew)
+
+    if model.denoising == True:
+        tnew = 20
+        tnew = torch.randint(tnew, tnew+1, [triplane.shape[0]], dtype=torch.long, device=triplane.device)
+        noise_new = torch.randn_like(triplane) *0.5+0.5
+        triplane = model.scheduler.add_noise(triplane, noise_new, tnew)
+        start_time = time.time()
         with torch.no_grad():
-            triplane_feature2 = model.unet2(triplane, tnew)
+            triplane_feature2 = model.unet2(triplane,tnew)
+        end_time = time.time()
+        elapsed_time = end_time - start_time
+        print(f"unet takes {elapsed_time}s")
     else:
-        with torch.no_grad():
-            triplane_feature2 = model.unet2(triplane)
-
-    data_config = {
-        'resolution': [1024, 1024],
-        "triview_color": triplane_color.to(device),
-    }
+        triplane_feature2 = model.unet2(triplane)

     with torch.no_grad():
+        data_config = {
+            'resolution': [1024, 1024],
+            "triview_color": triplane_color.to(device),
+        }
+
         verts, faces = model.decode(data_config, triplane_feature2)
+
         data_config['verts'] = verts[0]
         data_config['faces'] = faces

-    verts, faces = clean_mesh(
-        data_config['verts'].squeeze().cpu().numpy().astype(np.float32),
-        data_config['faces'].squeeze().cpu().numpy().astype(np.int32),
-        repair=False, remesh=True, remesh_size=0.005, remesh_iters=1
-    )
-    data_config['verts'] = torch.from_numpy(verts).contiguous()
-    data_config['faces'] = torch.from_numpy(faces).contiguous()
-
-    # CPU-only UV unwrapping with xatlas
-    mesh_v = data_config['verts'].cpu().numpy()
-    mesh_f = data_config['faces'].cpu().numpy()
-    vmapping, ft, vt = xatlas.parametrize(mesh_v, mesh_f)
-
-    # Use per-vertex colors if available, else fallback to white
-    vertex_colors = np.ones((mesh_v.shape[0], 3), dtype=np.float32)  # fallback: white
-    # If you have per-vertex color, you can assign here, e.g.:
-    # vertex_colors = ...
-
-    # Bake vertex colors to texture in UV space
-    tex_res = (1024, 1024)
-    texture = np.zeros((tex_res[1], tex_res[0], 3), dtype=np.float32)
-    vt_img = (vt * np.array(tex_res)).astype(np.int32)
-    for face, uv_idx in zip(mesh_f, ft):
-        pts = vt_img[uv_idx]
-        color = vertex_colors[face].mean(axis=0)
-        cv2.fillPoly(texture, [pts], color.tolist())
-    texture = np.clip(texture, 0, 1)
-
-    # Create Mesh and export .glb
-    mesh = Mesh(
-        v=torch.from_numpy(mesh_v).float(),
-        f=torch.from_numpy(mesh_f).int(),
-        vt=torch.from_numpy(vt).float(),
-        ft=torch.from_numpy(ft).int(),
-        albedo=torch.from_numpy(texture).float()
-    )
-    temp_path = tempfile.NamedTemporaryFile(suffix=".glb", delete=False).name
-    mesh.write(temp_path)
-    return temp_path
+
+    from kiui.mesh_utils import clean_mesh
+    verts, faces = clean_mesh(data_config['verts'].squeeze().cpu().numpy().astype(np.float32), data_config['faces'].squeeze().cpu().numpy().astype(np.int32), repair = False, remesh=True, remesh_size=0.005, remesh_iters=1)
+    data_config['verts'] = torch.from_numpy(verts).cuda().contiguous()
+    data_config['faces'] = torch.from_numpy(faces).cuda().contiguous()
+
+    start_time = time.time()
+    with torch.no_grad():
+        mesh_path_glb = tempfile.NamedTemporaryFile(suffix=f"", delete=False).name
+        model.export_mesh(data_config, mesh_path_glb, tri_fea_2 = triplane_feature2)
+
+    end_time = time.time()
+    elapsed_time = end_time - start_time
+    print(f"uv takes {elapsed_time}s")
+
+    # Convert .obj (with vertex colors) to .glb
+    obj_path = mesh_path_glb + ".obj"
+    glb_path = mesh_path_glb + ".glb"
+    mesh = trimesh.load(obj_path, process=False)
+    mesh.export(glb_path)
+    return glb_path
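The final .obj-to-.glb conversion can be exercised in isolation. A self-contained sketch, assuming a vertex-colored sample_mesh.obj on disk (the filename is hypothetical):

    import trimesh

    # process=False keeps vertex order and per-vertex colors intact.
    mesh = trimesh.load("sample_mesh.obj", process=False)
    assert isinstance(mesh, trimesh.Trimesh), "expected a single mesh, not a Scene"
    mesh.export("sample_mesh.glb")  # output format inferred from the extension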
model/crm/model.py CHANGED
@@ -107,8 +107,8 @@ class CRM(nn.Module):

         # export the final mesh
         with torch.no_grad():
-            mesh = trimesh.Trimesh(vertices=verts, faces=faces, process=False) # important, process=True leads to seg fault...
-            mesh.export(f"{out_dir}.glb", file_type="glb")
+            mesh = trimesh.Trimesh(verts, faces, vertex_colors=colors, process=False) # important, process=True leads to seg fault...
+            mesh.export(f'{out_dir}.obj')

     def export_mesh_wt_uv(self, ctx, data, out_dir, ind, device, res, tri_fea_2=None):

@@ -214,4 +214,4 @@ class CRM(nn.Module):
         img = img.clip(0, 255).astype(np.uint8)

         cv2.imwrite(f'{out_dir}.png', img[..., [2, 1, 0]])
-        # cv2.imwrite(f'{out_dir}/{ind}.png', img[..., [2, 1, 0]])
+        # cv2.imwrite(f'{out_dir}/{ind}.png', img[..., [2, 1, 0]])
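For reference, trimesh's vertex-colored export as used above, demonstrated on toy data (the tetrahedron is illustrative, not model output):

    import numpy as np
    import trimesh

    # One RGB color per vertex of a tetrahedron.
    verts = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float32)
    faces = np.array([[0, 2, 1], [0, 1, 3], [0, 3, 2], [1, 2, 3]], dtype=np.int64)
    colors = np.array([[255, 0, 0], [0, 255, 0], [0, 0, 255], [255, 255, 0]], dtype=np.uint8)

    # process=False avoids merging/reordering vertices, which would break the
    # vertex-to-color correspondence (and, per the comment above, can segfault).
    mesh = trimesh.Trimesh(verts, faces, vertex_colors=colors, process=False)
    mesh.export("tetra.obj")  # OBJ vertex lines carry colors: "v x y z r g b"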