xinjie.wang committed
Commit 3075458 · 1 Parent(s): 9e9a83d
common.py CHANGED
@@ -503,7 +503,17 @@ def extract_3d_representations_v2(
         device="cpu",
     )
     color_path = os.path.join(user_dir, "color.png")
-    render_gs_api(aligned_gs_path, color_path)
+    render_gs_api(
+        input_gs=aligned_gs_path,
+        output_path=color_path,
+        elevation=[20, -10],
+    )
+    color_path2 = os.path.join(user_dir, "color2.png")
+    render_gs_api(
+        input_gs=aligned_gs_path,
+        output_path=color_path2,
+        elevation=[60, -50],
+    )
 
     mesh = trimesh.Trimesh(
         vertices=mesh_model.vertices.cpu().numpy(),
@@ -518,12 +528,14 @@ def extract_3d_representations_v2(
     mesh = backproject_api(
         delight_model=DELIGHT,
         imagesr_model=IMAGESR_MODEL,
-        color_path=color_path,
+        color_path=[color_path, color_path2],
         mesh_path=mesh_obj_path,
         output_path=mesh_obj_path,
         skip_fix_mesh=False,
         delight=enable_delight,
         texture_wh=[texture_size, texture_size],
+        elevation=[20, -10, 60, -50],
+        num_images=12,
     )
 
     mesh_glb_path = os.path.join(user_dir, f"{filename}.glb")
embodied_gen/data/backproject_v2.py CHANGED
@@ -33,6 +33,7 @@ from embodied_gen.data.mesh_operator import MeshFixer
 from embodied_gen.data.utils import (
     CameraSetting,
     DiffrastRender,
+    as_list,
     get_images_from_grid,
     init_kal_camera,
     normalize_vertices_array,
@@ -41,6 +42,7 @@ from embodied_gen.data.utils import (
 )
 from embodied_gen.models.delight_model import DelightingModel
 from embodied_gen.models.sr_model import ImageRealESRGAN
+from embodied_gen.utils.process_media import vcat_pil_images
 
 logging.basicConfig(
     format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
@@ -541,8 +543,9 @@ def parse_args():
     parser = argparse.ArgumentParser(description="Backproject texture")
     parser.add_argument(
         "--color_path",
+        nargs="+",
         type=str,
-        help="Multiview color image in 6x512x512 file path",
+        help="Multiview color image in 6x512x512 file paths",
     )
     parser.add_argument(
         "--mesh_path",
@@ -559,7 +562,7 @@ def parse_args():
     )
     parser.add_argument(
         "--elevation",
-        nargs=2,
+        nargs="+",
         type=float,
         default=[20.0, -10.0],
         help="Elevation angles for the camera (default: [20.0, -10.0])",
@@ -647,19 +650,23 @@ def entrypoint(
         fov=math.radians(args.fov),
         device=args.device,
     )
-    view_weights = [1, 0.1, 0.02, 0.1, 1, 0.02]
 
-    color_grid = Image.open(args.color_path)
+    args.color_path = as_list(args.color_path)
+    if args.delight and delight_model is None:
+        delight_model = DelightingModel()
+
+    color_grid = [Image.open(color_path) for color_path in args.color_path]
+    color_grid = vcat_pil_images(color_grid, image_mode="RGBA")
     if args.delight:
-        if delight_model is None:
-            delight_model = DelightingModel()
-        save_dir = os.path.dirname(args.output_path)
-        os.makedirs(save_dir, exist_ok=True)
         color_grid = delight_model(color_grid)
         if not args.no_save_delight_img:
-            color_grid.save(f"{save_dir}/color_grid_delight.png")
+            save_dir = os.path.dirname(args.output_path)
+            os.makedirs(save_dir, exist_ok=True)
+            color_grid.save(f"{save_dir}/color_delight.png")
 
     multiviews = get_images_from_grid(color_grid, img_size=512)
+    view_weights = [1, 0.1, 0.02, 0.1, 1, 0.02]
+    view_weights += [0.01] * (len(multiviews) - len(view_weights))
 
     # Use RealESRGAN_x4plus for x4 (512->2048) image super resolution.
     if imagesr_model is None:
@@ -688,7 +695,7 @@ def entrypoint(
     texture_backer = TextureBacker(
         camera_params=camera_params,
         view_weights=view_weights,
-        render_wh=camera_params.resolution_hw,
+        render_wh=args.resolution_hw,
         texture_wh=args.texture_wh,
         smooth_texture=not args.no_smooth_texture,
     )
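
Note on how the multi-grid path composes (a rough sketch; the file names are placeholders and the 2x3-views-per-grid layout is an assumption based on the render_gs output):

    from PIL import Image
    from embodied_gen.data.utils import get_images_from_grid
    from embodied_gen.utils.process_media import vcat_pil_images

    # Two 1536x1024 grids (3 cols x 2 rows of 512px views) stacked vertically.
    grids = [Image.open(p) for p in ["color.png", "color2.png"]]
    stacked = vcat_pil_images(grids, image_mode="RGBA")   # 1536x2048
    views = get_images_from_grid(stacked, img_size=512)   # 12 views, row-major
    view_weights = [1, 0.1, 0.02, 0.1, 1, 0.02]
    view_weights += [0.01] * (len(views) - len(view_weights))  # pad to 12 views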
embodied_gen/data/differentiable_render.py CHANGED
@@ -503,7 +503,7 @@ def parse_args():
         help="Whether to generate global normal .mp4 rendering file.",
     )
     parser.add_argument(
-        "--prompts",
+        "--video_prompts",
         type=str,
         nargs="+",
         default=None,
@@ -579,7 +579,7 @@ def entrypoint(**kwargs) -> None:
         mesh_path=args.mesh_path,
         output_root=args.output_root,
         uuid=args.uuid,
-        prompts=args.prompts,
+        prompts=args.video_prompts,
     )
 
     return
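
Note: with the flag renamed, a drender-cli call that previously passed --prompts would look roughly like this (flags other than --video_prompts follow the existing texture_gen.sh usage; values are illustrative):

    drender-cli --mesh_path outputs/demo/mesh.obj \
        --output_root outputs/demo/condition \
        --video_prompts "a ceramic mug with a glossy finish"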
embodied_gen/data/utils.py CHANGED
@@ -28,7 +28,7 @@ import numpy as np
 import nvdiffrast.torch as dr
 import torch
 import torch.nn.functional as F
-from PIL import Image
+from PIL import Image, ImageEnhance
 
 try:
     from kolors.models.modeling_chatglm import ChatGLMModel
@@ -698,6 +698,8 @@ def as_list(obj):
         return obj
     elif isinstance(obj, set):
         return list(obj)
+    elif obj is None:
+        return obj
     else:
         return [obj]
 
@@ -742,6 +744,8 @@ def _compute_az_el_by_camera_params(
 ):
     num_view = camera_params.num_images // len(camera_params.elevation)
     view_interval = 2 * np.pi / num_view / 2
+    if num_view == 1:
+        view_interval = np.pi / 2
     azimuths = []
     elevations = []
     for idx, el in enumerate(camera_params.elevation):
@@ -758,8 +762,13 @@
     return azimuths, elevations
 
 
-def init_kal_camera(camera_params: CameraSetting) -> Camera:
-    azimuths, elevations = _compute_az_el_by_camera_params(camera_params)
+def init_kal_camera(
+    camera_params: CameraSetting,
+    flip_az: bool = False,
+) -> Camera:
+    azimuths, elevations = _compute_az_el_by_camera_params(
+        camera_params, flip_az
+    )
     cam_pts = _compute_cam_pts_by_az_el(
         azimuths, elevations, camera_params.distance
     )
@@ -856,13 +865,38 @@ def get_images_from_grid(
     image = Image.open(image)
 
     view_images = np.array(image)
-    view_images = np.concatenate(
-        [view_images[:img_size, ...], view_images[img_size:, ...]], axis=1
-    )
-    images = np.split(view_images, view_images.shape[1] // img_size, axis=1)
-    images = [Image.fromarray(img) for img in images]
-
-    return images
+    height, width, _ = view_images.shape
+    rows = height // img_size
+    cols = width // img_size
+    blocks = []
+    for i in range(rows):
+        for j in range(cols):
+            block = view_images[
+                i * img_size : (i + 1) * img_size,
+                j * img_size : (j + 1) * img_size,
+                :,
+            ]
+            blocks.append(Image.fromarray(block))
+
+    return blocks
+
+
+def enhance_image(
+    image: Image.Image,
+    contrast_factor: float = 1.3,
+    color_factor: float = 1.2,
+    brightness_factor: float = 0.95,
+) -> Image.Image:
+    enhancer_contrast = ImageEnhance.Contrast(image)
+    img_contrasted = enhancer_contrast.enhance(contrast_factor)
+
+    enhancer_color = ImageEnhance.Color(img_contrasted)
+    img_colored = enhancer_color.enhance(color_factor)
+
+    enhancer_brightness = ImageEnhance.Brightness(img_colored)
+    enhanced_image = enhancer_brightness.enhance(brightness_factor)
+
+    return enhanced_image
 
 
 def post_process_texture(texture: np.ndarray, iter: int = 1) -> np.ndarray:
@@ -872,7 +906,14 @@ def post_process_texture(texture: np.ndarray, iter: int = 1) -> np.ndarray:
         texture, d=5, sigmaColor=20, sigmaSpace=20
     )
 
-    return texture
+    texture = enhance_image(
+        image=Image.fromarray(texture),
+        contrast_factor=1.3,
+        color_factor=1.2,
+        brightness_factor=0.95,
+    )
+
+    return np.array(texture)
 
 
 def quat_mult(q1, q2):
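
Usage sketch for the two helpers added above (the grid path is hypothetical; any grid whose height and width are multiples of img_size now splits row-major into square views):

    import numpy as np
    from embodied_gen.data.utils import enhance_image, get_images_from_grid

    views = get_images_from_grid("color_grid.png", img_size=512)
    # Mild contrast/saturation boost with slightly reduced brightness,
    # the same defaults post_process_texture now applies.
    enhanced = enhance_image(views[0])
    texture = np.asarray(enhanced)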
embodied_gen/models/delight_model.py CHANGED
@@ -29,6 +29,7 @@ from diffusers import (
 from huggingface_hub import snapshot_download
 from PIL import Image
 from embodied_gen.models.segment_model import RembgRemover
+from embodied_gen.utils.log import logger
 
 __all__ = [
     "DelightingModel",
@@ -84,6 +85,7 @@ class DelightingModel(object):
 
     def _lazy_init_pipeline(self):
         if self.pipeline is None:
+            logger.info("Loading Delighting Model...")
             pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
                 self.model_path,
                 torch_dtype=torch.float16,
embodied_gen/models/texture_model.py CHANGED
@@ -29,6 +29,7 @@ from kolors.pipelines.pipeline_controlnet_xl_kolors_img2img import (
 )
 from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
 from embodied_gen.models.text_model import download_kolors_weights
+from embodied_gen.utils.log import logger
 
 __all__ = [
     "build_texture_gen_pipe",
@@ -42,7 +43,7 @@ def build_texture_gen_pipe(
     device: str = "cuda",
 ) -> DiffusionPipeline:
     download_kolors_weights(f"{base_ckpt_dir}/Kolors")
-
+    logger.info(f"Load Kolors weights...")
     tokenizer = ChatGLMTokenizer.from_pretrained(
         f"{base_ckpt_dir}/Kolors/text_encoder"
     )
embodied_gen/scripts/gen_texture.py ADDED
@@ -0,0 +1,123 @@
+import os
+import shutil
+from dataclasses import dataclass
+
+import tyro
+from embodied_gen.data.backproject_v2 import entrypoint as backproject_api
+from embodied_gen.data.differentiable_render import entrypoint as drender_api
+from embodied_gen.data.utils import as_list
+from embodied_gen.models.delight_model import DelightingModel
+from embodied_gen.models.sr_model import ImageRealESRGAN
+from embodied_gen.scripts.render_mv import (
+    build_texture_gen_pipe,
+)
+from embodied_gen.scripts.render_mv import infer_pipe as render_mv_api
+from embodied_gen.utils.log import logger
+
+
+@dataclass
+class TextureGenConfig:
+    mesh_path: str | list[str]
+    prompt: str | list[str]
+    output_root: str
+    controlnet_cond_scale: float = 0.7
+    guidance_scale: float = 9
+    strength: float = 0.9
+    num_inference_steps: int = 40
+    delight: bool = True
+    seed: int = 0
+    base_ckpt_dir: str = "./weights"
+    texture_size: int = 2048
+    ip_adapt_scale: float = 0.0
+    ip_img_path: str | list[str] | None = None
+
+
+def entrypoint() -> None:
+    cfg = tyro.cli(TextureGenConfig)
+    cfg.mesh_path = as_list(cfg.mesh_path)
+    cfg.prompt = as_list(cfg.prompt)
+    cfg.ip_img_path = as_list(cfg.ip_img_path)
+    assert len(cfg.mesh_path) == len(cfg.prompt)
+
+    # Pre-load models.
+    if cfg.ip_adapt_scale > 0:
+        PIPELINE = build_texture_gen_pipe(
+            base_ckpt_dir="./weights",
+            ip_adapt_scale=cfg.ip_adapt_scale,
+            device="cuda",
+        )
+    else:
+        PIPELINE = build_texture_gen_pipe(
+            base_ckpt_dir="./weights",
+            ip_adapt_scale=0,
+            device="cuda",
+        )
+    DELIGHT = None
+    if cfg.delight:
+        DELIGHT = DelightingModel()
+    IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
+
+    for idx in range(len(cfg.mesh_path)):
+        mesh_path = cfg.mesh_path[idx]
+        prompt = cfg.prompt[idx]
+        uuid = os.path.splitext(os.path.basename(mesh_path))[0]
+        output_root = os.path.join(cfg.output_root, uuid)
+        drender_api(
+            mesh_path=mesh_path,
+            output_root=f"{output_root}/condition",
+            uuid=uuid,
+        )
+        render_mv_api(
+            index_file=f"{output_root}/condition/index.json",
+            controlnet_cond_scale=cfg.controlnet_cond_scale,
+            guidance_scale=cfg.guidance_scale,
+            strength=cfg.strength,
+            num_inference_steps=cfg.num_inference_steps,
+            ip_adapt_scale=cfg.ip_adapt_scale,
+            ip_img_path=(
+                None if cfg.ip_img_path is None else cfg.ip_img_path[idx]
+            ),
+            prompt=prompt,
+            save_dir=f"{output_root}/multi_view",
+            sub_idxs=[[0, 1, 2], [3, 4, 5]],
+            pipeline=PIPELINE,
+            seed=cfg.seed,
+        )
+        textured_mesh = backproject_api(
+            delight_model=DELIGHT,
+            imagesr_model=IMAGESR_MODEL,
+            mesh_path=mesh_path,
+            color_path=f"{output_root}/multi_view/color_sample0.png",
+            output_path=f"{output_root}/texture_mesh/{uuid}.obj",
+            save_glb_path=f"{output_root}/texture_mesh/{uuid}.glb",
+            skip_fix_mesh=True,
+            delight=cfg.delight,
+            no_save_delight_img=True,
+            texture_wh=[cfg.texture_size, cfg.texture_size],
+        )
+        drender_api(
+            mesh_path=f"{output_root}/texture_mesh/{uuid}.obj",
+            output_root=f"{output_root}/texture_mesh",
+            uuid=uuid,
+            num_images=90,
+            elevation=[20],
+            with_mtl=True,
+            gen_color_mp4=True,
+            pbr_light_factor=1.2,
+        )
+
+        # Re-organize folders
+        shutil.rmtree(f"{output_root}/condition")
+        shutil.copy(
+            f"{output_root}/texture_mesh/{uuid}/color.mp4",
+            f"{output_root}/color.mp4",
+        )
+        shutil.rmtree(f"{output_root}/texture_mesh/{uuid}")
+
+        logger.info(
+            f"Successfully generate textured mesh in {output_root}/texture_mesh"
+        )
+
+
+if __name__ == "__main__":
+    entrypoint()
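
Assuming this entrypoint is wired to the 'texture-cli' command mentioned in texture_gen.sh, an invocation might look like the following (flag names mirror the TextureGenConfig fields via tyro; paths and prompt are illustrative):

    texture-cli --mesh_path outputs/demo/mesh.obj \
        --prompt "a weathered wooden crate" \
        --output_root outputs/texture_demo \
        --texture_size 2048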
embodied_gen/scripts/imageto3d.py CHANGED
@@ -108,6 +108,9 @@ def parse_args():
         default=2,
     )
     parser.add_argument("--disable_decompose_convex", action="store_true")
+    parser.add_argument(
+        "--texture_wh", type=int, nargs=2, default=[2048, 2048]
+    )
     args, unknown = parser.parse_known_args()
 
     return args
@@ -209,7 +212,17 @@ def entrypoint(**kwargs):
         device="cpu",
     )
     color_path = os.path.join(output_root, "color.png")
-    render_gs_api(aligned_gs_path, color_path)
+    render_gs_api(
+        input_gs=aligned_gs_path,
+        output_path=color_path,
+        elevation=[20, -10],
+    )
+    color_path2 = os.path.join(output_root, "color2.png")
+    render_gs_api(
+        input_gs=aligned_gs_path,
+        output_path=color_path2,
+        elevation=[60, -50],
+    )
 
     geo_flag, geo_result = GEO_CHECKER(
         [color_path], text=asset_node
@@ -241,12 +254,14 @@ def entrypoint(**kwargs):
     mesh = backproject_api(
         delight_model=DELIGHT,
         imagesr_model=IMAGESR_MODEL,
-        color_path=color_path,
+        color_path=[color_path, color_path2],
         mesh_path=mesh_obj_path,
         output_path=mesh_obj_path,
         skip_fix_mesh=False,
         delight=True,
-        texture_wh=[2048, 2048],
+        texture_wh=args.texture_wh,
+        elevation=[20, -10, 60, -50],
+        num_images=12,
     )
 
     mesh_glb_path = os.path.join(output_root, f"{filename}.glb")
embodied_gen/scripts/render_gs.py CHANGED
@@ -18,12 +18,11 @@
 import argparse
 import logging
 import math
-import os
 
 import cv2
-import numpy as np
 import spaces
 import torch
+from PIL import Image
 from tqdm import tqdm
 from embodied_gen.data.utils import (
     CameraSetting,
@@ -31,6 +30,7 @@ from embodied_gen.data.utils import (
     normalize_vertices_array,
 )
 from embodied_gen.models.gs_model import GaussianOperator
+from embodied_gen.utils.process_media import combine_images_to_grid
 
 logging.basicConfig(
     format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
@@ -113,12 +113,11 @@ def load_gs_model(
 
 
 @spaces.GPU
-def entrypoint(input_gs: str = None, output_path: str = None) -> None:
+def entrypoint(**kwargs) -> None:
     args = parse_args()
-    if isinstance(input_gs, str):
-        args.input_gs = input_gs
-    if isinstance(output_path, str):
-        args.output_path = output_path
+    for k, v in kwargs.items():
+        if hasattr(args, k) and v is not None:
+            setattr(args, k, v)
 
     # Setup camera parameters
     camera_params = CameraSetting(
@@ -129,7 +128,7 @@ def entrypoint(input_gs: str = None, output_path: str = None) -> None:
         fov=math.radians(args.fov),
         device=args.device,
     )
-    camera = init_kal_camera(camera_params)
+    camera = init_kal_camera(camera_params, flip_az=True)
     matrix_mv = camera.view_matrix()  # (n_cam 4 4) world2cam
     matrix_mv[:, :3, 3] = -matrix_mv[:, :3, 3]
     w2cs = matrix_mv.to(camera_params.device)
@@ -153,21 +152,11 @@ def entrypoint(input_gs: str = None, output_path: str = None) -> None:
             (args.image_size, args.image_size),
             interpolation=cv2.INTER_AREA,
         )
-        images.append(color)
-
-    # Cat color images into grid image and save.
-    select_idxs = [[0, 2, 1], [5, 4, 3]]  # fix order for 6 views
-    grid_image = []
-    for row_idxs in select_idxs:
-        row_image = []
-        for row_idx in row_idxs:
-            row_image.append(images[row_idx])
-        row_image = np.concatenate(row_image, axis=1)
-        grid_image.append(row_image)
-
-    grid_image = np.concatenate(grid_image, axis=0)
-    os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
-    cv2.imwrite(args.output_path, grid_image)
+        color = cv2.cvtColor(color, cv2.COLOR_BGRA2RGBA)
+        images.append(Image.fromarray(color))
+
+    combine_images_to_grid(images, image_mode="RGBA")[0].save(args.output_path)
+
     logger.info(f"Saved grid image to {args.output_path}")
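
The kwargs-based entrypoint now mirrors how common.py and imageto3d.py call it; a minimal sketch (the import alias and the .ply path are assumptions):

    from embodied_gen.scripts.render_gs import entrypoint as render_gs_api

    # Any parse_args() attribute can be overridden by keyword; None values are ignored.
    render_gs_api(
        input_gs="outputs/demo/aligned_gs.ply",
        output_path="outputs/demo/color.png",
        elevation=[20, -10],
    )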
embodied_gen/scripts/texture_gen.sh CHANGED
@@ -28,6 +28,7 @@ if [[ -z "$mesh_path" || -z "$prompt" || -z "$output_root" ]]; then
     exit 1
 fi
 
+echo "Will be deprecated, recommended to use 'texture-cli' instead."
 uuid=$(basename "$output_root")
 # Step 1: drender-cli for condition rendering
 drender-cli --mesh_path ${mesh_path} \
embodied_gen/utils/process_media.py CHANGED
@@ -49,6 +49,7 @@ __all__ = [
     "is_image_file",
     "parse_text_prompts",
     "check_object_edge_truncated",
+    "vcat_pil_images",
 ]
 
 
@@ -166,6 +167,7 @@ def combine_images_to_grid(
     images: list[str | Image.Image],
     cat_row_col: tuple[int, int] = None,
     target_wh: tuple[int, int] = (512, 512),
+    image_mode: str = "RGB",
 ) -> list[Image.Image]:
     n_images = len(images)
     if n_images == 1:
@@ -178,13 +180,13 @@ def combine_images_to_grid(
         n_row, n_col = cat_row_col
 
     images = [
-        Image.open(p).convert("RGB") if isinstance(p, str) else p
+        Image.open(p).convert(image_mode) if isinstance(p, str) else p
         for p in images
     ]
     images = [img.resize(target_wh) for img in images]
 
     grid_w, grid_h = n_col * target_wh[0], n_row * target_wh[1]
-    grid = Image.new("RGB", (grid_w, grid_h), (0, 0, 0))
+    grid = Image.new(image_mode, (grid_w, grid_h), (0, 0, 0))
 
     for idx, img in enumerate(images):
         row, col = divmod(idx, n_col)
@@ -435,6 +437,21 @@ def check_object_edge_truncated(
     return not (top or bottom or left or right)
 
 
+def vcat_pil_images(
+    images: list[Image.Image], image_mode: str = "RGB"
+) -> Image.Image:
+    widths, heights = zip(*(img.size for img in images))
+    total_height = sum(heights)
+    max_width = max(widths)
+    new_image = Image.new(image_mode, (max_width, total_height))
+    y_offset = 0
+    for image in images:
+        new_image.paste(image, (0, y_offset))
+        y_offset += image.size[1]
+
+    return new_image
+
+
 if __name__ == "__main__":
     image_paths = [
         "outputs/layouts_sim/task_0000/images/pen.png",
embodied_gen/validators/urdf_convertor.py CHANGED
@@ -266,7 +266,7 @@ class URDFGenerator(object):
         if self.decompose_convex:
             try:
                 d_params = dict(
-                    threshold=0.05, max_convex_hull=64, verbose=False
+                    threshold=0.05, max_convex_hull=100, verbose=False
                 )
                 filename = f"{os.path.splitext(obj_name)[0]}_collision.ply"
                 output_path = os.path.join(mesh_folder, filename)