Commit
·
d493b2e
1
Parent(s):
d72a5f9
Enhance Mesh class documentation by adding missing line breaks in docstrings for improved readability. Update device handling in FlexiCubes and FlexiCubesGeometry classes to default to 'cuda', ensuring consistent device usage across the application. Refactor ImageDreamDiffusion class to assert mode validity and streamline camera matrix pre-computation.
Browse files- libs/sample.py +17 -12
- mesh.py +23 -1
- util/flexicubes.py +4 -4
- util/flexicubes_geometry.py +2 -3
libs/sample.py
CHANGED
@@ -11,32 +11,36 @@ class ImageDreamDiffusion:
|
|
11 |
def __init__(
|
12 |
self,
|
13 |
model,
|
14 |
-
device
|
15 |
-
dtype
|
16 |
-
mode
|
17 |
-
num_frames
|
18 |
-
camera_views
|
19 |
-
ref_position
|
20 |
random_background=False,
|
21 |
offset_noise=False,
|
22 |
resize_rate=1,
|
23 |
image_size=256,
|
24 |
seed=1234,
|
25 |
) -> None:
|
26 |
-
|
27 |
-
|
28 |
self.seed = seed
|
29 |
batch_size = max(4, num_frames)
|
|
|
30 |
neg_texts = "uniform low no texture ugly, boring, bad anatomy, blurry, pixelated, obscure, unnatural colors, poor lighting, dull, and unclear."
|
31 |
uc = model.get_learned_conditioning([neg_texts]).to(device)
|
32 |
sampler = DDIMSampler(model)
|
|
|
|
|
33 |
camera = [get_camera_for_index(i).squeeze() for i in camera_views]
|
34 |
-
camera[ref_position] = torch.zeros_like(camera[ref_position])
|
35 |
camera = torch.stack(camera)
|
36 |
camera = camera.repeat(batch_size // num_frames, 1).to(device)
|
|
|
37 |
self.image_transform = T.Compose(
|
38 |
[
|
39 |
-
T.Resize((
|
40 |
T.ToTensor(),
|
41 |
T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
42 |
]
|
@@ -47,7 +51,8 @@ class ImageDreamDiffusion:
|
|
47 |
self.random_background = random_background
|
48 |
self.resize_rate = resize_rate
|
49 |
self.num_frames = num_frames
|
50 |
-
self.size =
|
|
|
51 |
self.batch_size = batch_size
|
52 |
self.model = model
|
53 |
self.sampler = sampler
|
@@ -372,4 +377,4 @@ class ImageDreamDiffusionStage2:
|
|
372 |
)
|
373 |
images.append(img)
|
374 |
set_seed() # unset random and numpy seed
|
375 |
-
return images
|
|
|
11 |
def __init__(
|
12 |
self,
|
13 |
model,
|
14 |
+
device,
|
15 |
+
dtype,
|
16 |
+
mode,
|
17 |
+
num_frames,
|
18 |
+
camera_views,
|
19 |
+
ref_position,
|
20 |
random_background=False,
|
21 |
offset_noise=False,
|
22 |
resize_rate=1,
|
23 |
image_size=256,
|
24 |
seed=1234,
|
25 |
) -> None:
|
26 |
+
assert mode in ["pixel", "local"]
|
27 |
+
size = image_size
|
28 |
self.seed = seed
|
29 |
batch_size = max(4, num_frames)
|
30 |
+
|
31 |
neg_texts = "uniform low no texture ugly, boring, bad anatomy, blurry, pixelated, obscure, unnatural colors, poor lighting, dull, and unclear."
|
32 |
uc = model.get_learned_conditioning([neg_texts]).to(device)
|
33 |
sampler = DDIMSampler(model)
|
34 |
+
|
35 |
+
# pre-compute camera matrices
|
36 |
camera = [get_camera_for_index(i).squeeze() for i in camera_views]
|
37 |
+
camera[ref_position] = torch.zeros_like(camera[ref_position]) # set ref camera to zero
|
38 |
camera = torch.stack(camera)
|
39 |
camera = camera.repeat(batch_size // num_frames, 1).to(device)
|
40 |
+
|
41 |
self.image_transform = T.Compose(
|
42 |
[
|
43 |
+
T.Resize((size, size)),
|
44 |
T.ToTensor(),
|
45 |
T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
46 |
]
|
|
|
51 |
self.random_background = random_background
|
52 |
self.resize_rate = resize_rate
|
53 |
self.num_frames = num_frames
|
54 |
+
self.size = size
|
55 |
+
self.device = device
|
56 |
self.batch_size = batch_size
|
57 |
self.model = model
|
58 |
self.sampler = sampler
|
|
|
377 |
)
|
378 |
images.append(img)
|
379 |
set_seed() # unset random and numpy seed
|
380 |
+
return images
|
mesh.py
CHANGED
@@ -10,6 +10,7 @@ from kiui.typing import *
|
|
10 |
class Mesh:
|
11 |
"""
|
12 |
A torch-native trimesh class, with support for ``ply/obj/glb`` formats.
|
|
|
13 |
Note:
|
14 |
This class only supports one mesh with a single texture image (an albedo texture and a metallic-roughness texture).
|
15 |
"""
|
@@ -27,6 +28,7 @@ class Mesh:
|
|
27 |
device: Optional[torch.device] = None,
|
28 |
):
|
29 |
"""Init a mesh directly using all attributes.
|
|
|
30 |
Args:
|
31 |
v (Optional[Tensor]): vertices, float [N, 3]. Defaults to None.
|
32 |
f (Optional[Tensor]): faces, int [M, 3]. Defaults to None.
|
@@ -60,6 +62,7 @@ class Mesh:
|
|
60 |
@classmethod
|
61 |
def load(cls, path, resize=True, clean=False, renormal=True, retex=False, bound=0.9, front_dir='+z', **kwargs):
|
62 |
"""load mesh from path.
|
|
|
63 |
Args:
|
64 |
path (str): path to mesh file, supports ply, obj, glb.
|
65 |
clean (bool, optional): perform mesh cleaning at load (e.g., merge close vertices). Defaults to False.
|
@@ -73,6 +76,7 @@ class Mesh:
|
|
73 |
Note:
|
74 |
a ``device`` keyword argument can be provided to specify the torch device.
|
75 |
If it's not provided, we will try to use ``'cuda'`` as the device if it's available.
|
|
|
76 |
Returns:
|
77 |
Mesh: the loaded Mesh object.
|
78 |
"""
|
@@ -136,6 +140,7 @@ class Mesh:
|
|
136 |
@classmethod
|
137 |
def load_obj(cls, path, albedo_path=None, device=None):
|
138 |
"""load an ``obj`` mesh.
|
|
|
139 |
Args:
|
140 |
path (str): path to mesh.
|
141 |
albedo_path (str, optional): path to the albedo texture image, will overwrite the existing texture path if specified in mtl. Defaults to None.
|
@@ -144,6 +149,7 @@ class Mesh:
|
|
144 |
Note:
|
145 |
We will try to read `mtl` path from `obj`, else we assume the file name is the same as `obj` but with `mtl` extension.
|
146 |
The `usemtl` statement is ignored, and we only use the last material path in `mtl` file.
|
|
|
147 |
Returns:
|
148 |
Mesh: the loaded Mesh object.
|
149 |
"""
|
@@ -307,13 +313,17 @@ class Mesh:
|
|
307 |
@classmethod
|
308 |
def load_trimesh(cls, path, device=None):
|
309 |
"""load a mesh using ``trimesh.load()``.
|
|
|
310 |
Can load various formats like ``glb`` and serves as a fallback.
|
|
|
311 |
Note:
|
312 |
We will try to merge all meshes if the glb contains more than one,
|
313 |
but **this may cause the texture to lose**, since we only support one texture image!
|
|
|
314 |
Args:
|
315 |
path (str): path to the mesh file.
|
316 |
device (torch.device, optional): torch device. Defaults to None.
|
|
|
317 |
Returns:
|
318 |
Mesh: the loaded Mesh object.
|
319 |
"""
|
@@ -413,8 +423,10 @@ class Mesh:
|
|
413 |
# sample surface (using trimesh)
|
414 |
def sample_surface(self, count: int):
|
415 |
"""sample points on the surface of the mesh.
|
|
|
416 |
Args:
|
417 |
count (int): number of points to sample.
|
|
|
418 |
Returns:
|
419 |
torch.Tensor: the sampled points, float [count, 3].
|
420 |
"""
|
@@ -426,6 +438,7 @@ class Mesh:
|
|
426 |
# aabb
|
427 |
def aabb(self):
|
428 |
"""get the axis-aligned bounding box of the mesh.
|
|
|
429 |
Returns:
|
430 |
Tuple[torch.Tensor]: the min xyz and max xyz of the mesh.
|
431 |
"""
|
@@ -435,6 +448,7 @@ class Mesh:
|
|
435 |
@torch.no_grad()
|
436 |
def auto_size(self, bound=0.9):
|
437 |
"""auto resize the mesh.
|
|
|
438 |
Args:
|
439 |
bound (float, optional): resizing into ``[-bound, bound]^3``. Defaults to 0.9.
|
440 |
"""
|
@@ -470,6 +484,7 @@ class Mesh:
|
|
470 |
|
471 |
def auto_uv(self, cache_path=None, vmap=True):
|
472 |
"""auto calculate the uv coordinates.
|
|
|
473 |
Args:
|
474 |
cache_path (str, optional): path to save/load the uv cache as a npz file, this can avoid calculating uv every time when loading the same mesh, which is time-consuming. Defaults to None.
|
475 |
vmap (bool, optional): remap vertices based on uv coordinates, so each v correspond to a unique vt (necessary for formats like gltf).
|
@@ -508,6 +523,7 @@ class Mesh:
|
|
508 |
|
509 |
def align_v_to_vt(self, vmapping=None):
|
510 |
""" remap v/f and vn/fn to vt/ft.
|
|
|
511 |
Args:
|
512 |
vmapping (np.ndarray, optional): the mapping relationship from f to ft. Defaults to None.
|
513 |
"""
|
@@ -526,8 +542,10 @@ class Mesh:
|
|
526 |
|
527 |
def to(self, device):
|
528 |
"""move all tensor attributes to device.
|
|
|
529 |
Args:
|
530 |
device (torch.device): target device.
|
|
|
531 |
Returns:
|
532 |
Mesh: self.
|
533 |
"""
|
@@ -540,6 +558,7 @@ class Mesh:
|
|
540 |
|
541 |
def write(self, path):
|
542 |
"""write the mesh to a path.
|
|
|
543 |
Args:
|
544 |
path (str): path to write, supports ply, obj and glb.
|
545 |
"""
|
@@ -554,6 +573,7 @@ class Mesh:
|
|
554 |
|
555 |
def write_ply(self, path):
|
556 |
"""write the mesh in ply format. Only for geometry!
|
|
|
557 |
Args:
|
558 |
path (str): path to write.
|
559 |
"""
|
@@ -571,6 +591,7 @@ class Mesh:
|
|
571 |
def write_glb(self, path):
|
572 |
"""write the mesh in glb/gltf format.
|
573 |
This will create a scene with a single mesh.
|
|
|
574 |
Args:
|
575 |
path (str): path to write.
|
576 |
"""
|
@@ -757,6 +778,7 @@ class Mesh:
|
|
757 |
|
758 |
def write_obj(self, path):
|
759 |
"""write the mesh in obj format. Will also write the texture and mtl files.
|
|
|
760 |
Args:
|
761 |
path (str): path to write.
|
762 |
"""
|
@@ -819,4 +841,4 @@ class Mesh:
|
|
819 |
metallicRoughness = self.metallicRoughness.detach().cpu().numpy()
|
820 |
metallicRoughness = (metallicRoughness * 255).astype(np.uint8)
|
821 |
cv2.imwrite(metallic_path, metallicRoughness[..., 2])
|
822 |
-
cv2.imwrite(roughness_path, metallicRoughness[..., 1])
|
|
|
10 |
class Mesh:
|
11 |
"""
|
12 |
A torch-native trimesh class, with support for ``ply/obj/glb`` formats.
|
13 |
+
|
14 |
Note:
|
15 |
This class only supports one mesh with a single texture image (an albedo texture and a metallic-roughness texture).
|
16 |
"""
|
|
|
28 |
device: Optional[torch.device] = None,
|
29 |
):
|
30 |
"""Init a mesh directly using all attributes.
|
31 |
+
|
32 |
Args:
|
33 |
v (Optional[Tensor]): vertices, float [N, 3]. Defaults to None.
|
34 |
f (Optional[Tensor]): faces, int [M, 3]. Defaults to None.
|
|
|
62 |
@classmethod
|
63 |
def load(cls, path, resize=True, clean=False, renormal=True, retex=False, bound=0.9, front_dir='+z', **kwargs):
|
64 |
"""load mesh from path.
|
65 |
+
|
66 |
Args:
|
67 |
path (str): path to mesh file, supports ply, obj, glb.
|
68 |
clean (bool, optional): perform mesh cleaning at load (e.g., merge close vertices). Defaults to False.
|
|
|
76 |
Note:
|
77 |
a ``device`` keyword argument can be provided to specify the torch device.
|
78 |
If it's not provided, we will try to use ``'cuda'`` as the device if it's available.
|
79 |
+
|
80 |
Returns:
|
81 |
Mesh: the loaded Mesh object.
|
82 |
"""
|
|
|
140 |
@classmethod
|
141 |
def load_obj(cls, path, albedo_path=None, device=None):
|
142 |
"""load an ``obj`` mesh.
|
143 |
+
|
144 |
Args:
|
145 |
path (str): path to mesh.
|
146 |
albedo_path (str, optional): path to the albedo texture image, will overwrite the existing texture path if specified in mtl. Defaults to None.
|
|
|
149 |
Note:
|
150 |
We will try to read `mtl` path from `obj`, else we assume the file name is the same as `obj` but with `mtl` extension.
|
151 |
The `usemtl` statement is ignored, and we only use the last material path in `mtl` file.
|
152 |
+
|
153 |
Returns:
|
154 |
Mesh: the loaded Mesh object.
|
155 |
"""
|
|
|
313 |
@classmethod
|
314 |
def load_trimesh(cls, path, device=None):
|
315 |
"""load a mesh using ``trimesh.load()``.
|
316 |
+
|
317 |
Can load various formats like ``glb`` and serves as a fallback.
|
318 |
+
|
319 |
Note:
|
320 |
We will try to merge all meshes if the glb contains more than one,
|
321 |
but **this may cause the texture to lose**, since we only support one texture image!
|
322 |
+
|
323 |
Args:
|
324 |
path (str): path to the mesh file.
|
325 |
device (torch.device, optional): torch device. Defaults to None.
|
326 |
+
|
327 |
Returns:
|
328 |
Mesh: the loaded Mesh object.
|
329 |
"""
|
|
|
423 |
# sample surface (using trimesh)
|
424 |
def sample_surface(self, count: int):
|
425 |
"""sample points on the surface of the mesh.
|
426 |
+
|
427 |
Args:
|
428 |
count (int): number of points to sample.
|
429 |
+
|
430 |
Returns:
|
431 |
torch.Tensor: the sampled points, float [count, 3].
|
432 |
"""
|
|
|
438 |
# aabb
|
439 |
def aabb(self):
|
440 |
"""get the axis-aligned bounding box of the mesh.
|
441 |
+
|
442 |
Returns:
|
443 |
Tuple[torch.Tensor]: the min xyz and max xyz of the mesh.
|
444 |
"""
|
|
|
448 |
@torch.no_grad()
|
449 |
def auto_size(self, bound=0.9):
|
450 |
"""auto resize the mesh.
|
451 |
+
|
452 |
Args:
|
453 |
bound (float, optional): resizing into ``[-bound, bound]^3``. Defaults to 0.9.
|
454 |
"""
|
|
|
484 |
|
485 |
def auto_uv(self, cache_path=None, vmap=True):
|
486 |
"""auto calculate the uv coordinates.
|
487 |
+
|
488 |
Args:
|
489 |
cache_path (str, optional): path to save/load the uv cache as a npz file, this can avoid calculating uv every time when loading the same mesh, which is time-consuming. Defaults to None.
|
490 |
vmap (bool, optional): remap vertices based on uv coordinates, so each v correspond to a unique vt (necessary for formats like gltf).
|
|
|
523 |
|
524 |
def align_v_to_vt(self, vmapping=None):
|
525 |
""" remap v/f and vn/fn to vt/ft.
|
526 |
+
|
527 |
Args:
|
528 |
vmapping (np.ndarray, optional): the mapping relationship from f to ft. Defaults to None.
|
529 |
"""
|
|
|
542 |
|
543 |
def to(self, device):
|
544 |
"""move all tensor attributes to device.
|
545 |
+
|
546 |
Args:
|
547 |
device (torch.device): target device.
|
548 |
+
|
549 |
Returns:
|
550 |
Mesh: self.
|
551 |
"""
|
|
|
558 |
|
559 |
def write(self, path):
|
560 |
"""write the mesh to a path.
|
561 |
+
|
562 |
Args:
|
563 |
path (str): path to write, supports ply, obj and glb.
|
564 |
"""
|
|
|
573 |
|
574 |
def write_ply(self, path):
|
575 |
"""write the mesh in ply format. Only for geometry!
|
576 |
+
|
577 |
Args:
|
578 |
path (str): path to write.
|
579 |
"""
|
|
|
591 |
def write_glb(self, path):
|
592 |
"""write the mesh in glb/gltf format.
|
593 |
This will create a scene with a single mesh.
|
594 |
+
|
595 |
Args:
|
596 |
path (str): path to write.
|
597 |
"""
|
|
|
778 |
|
779 |
def write_obj(self, path):
|
780 |
"""write the mesh in obj format. Will also write the texture and mtl files.
|
781 |
+
|
782 |
Args:
|
783 |
path (str): path to write.
|
784 |
"""
|
|
|
841 |
metallicRoughness = self.metallicRoughness.detach().cpu().numpy()
|
842 |
metallicRoughness = (metallicRoughness * 255).astype(np.uint8)
|
843 |
cv2.imwrite(metallic_path, metallicRoughness[..., 2])
|
844 |
+
cv2.imwrite(roughness_path, metallicRoughness[..., 1])
|
util/flexicubes.py
CHANGED
@@ -25,7 +25,7 @@ class FlexiCubes:
|
|
25 |
PyTorch tensors on the specified device.
|
26 |
|
27 |
Attributes:
|
28 |
-
device (str): Specifies the computational device (default is "
|
29 |
dmc_table (torch.Tensor): Dual Marching Cubes (DMC) table that encodes the edges
|
30 |
associated with each dual vertex in 256 Marching Cubes (MC) configurations.
|
31 |
num_vd_table (torch.Tensor): Table holding the number of dual vertices in each of
|
@@ -64,8 +64,8 @@ class FlexiCubes:
|
|
64 |
The scale of weights in FlexiCubes. Should be between 0 and 1.
|
65 |
"""
|
66 |
|
67 |
-
def __init__(self, device=
|
68 |
-
|
69 |
self.device = device
|
70 |
self.dmc_table = torch.tensor(dmc_table, dtype=torch.long, device=device, requires_grad=False)
|
71 |
self.num_vd_table = torch.tensor(num_vd_table,
|
@@ -576,4 +576,4 @@ class FlexiCubes:
|
|
576 |
|
577 |
tets = torch.cat([tets_surface, tets_inside])
|
578 |
vertices = torch.cat([vertices, inside_verts, inside_cubes_center])
|
579 |
-
return vertices, tets
|
|
|
25 |
PyTorch tensors on the specified device.
|
26 |
|
27 |
Attributes:
|
28 |
+
device (str): Specifies the computational device (default is "cuda").
|
29 |
dmc_table (torch.Tensor): Dual Marching Cubes (DMC) table that encodes the edges
|
30 |
associated with each dual vertex in 256 Marching Cubes (MC) configurations.
|
31 |
num_vd_table (torch.Tensor): Table holding the number of dual vertices in each of
|
|
|
64 |
The scale of weights in FlexiCubes. Should be between 0 and 1.
|
65 |
"""
|
66 |
|
67 |
+
def __init__(self, device="cuda", qef_reg_scale=1e-3, weight_scale=0.99):
|
68 |
+
|
69 |
self.device = device
|
70 |
self.dmc_table = torch.tensor(dmc_table, dtype=torch.long, device=device, requires_grad=False)
|
71 |
self.num_vd_table = torch.tensor(num_vd_table,
|
|
|
576 |
|
577 |
tets = torch.cat([tets_surface, tets_inside])
|
578 |
vertices = torch.cat([vertices, inside_verts, inside_cubes_center])
|
579 |
+
return vertices, tets
|
util/flexicubes_geometry.py
CHANGED
@@ -31,9 +31,8 @@ def get_center_boundary_index(grid_res, device):
|
|
31 |
###############################################################################
|
32 |
class FlexiCubesGeometry(object):
|
33 |
def __init__(
|
34 |
-
self, grid_res=64, scale=2.0, device=
|
35 |
render_type='neural_render', args=None):
|
36 |
-
device = torch.device("cuda")
|
37 |
super(FlexiCubesGeometry, self).__init__()
|
38 |
self.grid_res = grid_res
|
39 |
self.device = device
|
@@ -114,4 +113,4 @@ class FlexiCubesGeometry(object):
|
|
114 |
value = [v[k] for v in all_render_output]
|
115 |
return_value[k] = value
|
116 |
# We can do concatenation outside of the render
|
117 |
-
return return_value
|
|
|
31 |
###############################################################################
|
32 |
class FlexiCubesGeometry(object):
|
33 |
def __init__(
|
34 |
+
self, grid_res=64, scale=2.0, device='cuda', renderer=None,
|
35 |
render_type='neural_render', args=None):
|
|
|
36 |
super(FlexiCubesGeometry, self).__init__()
|
37 |
self.grid_res = grid_res
|
38 |
self.device = device
|
|
|
113 |
value = [v[k] for v in all_render_output]
|
114 |
return_value[k] = value
|
115 |
# We can do concatenation outside of the render
|
116 |
+
return return_value
|