xinjie.wang committed
Commit c6c24ac · 1 Parent(s): 22e4e0c
app.py CHANGED
@@ -25,6 +25,7 @@ from common import (
     MAX_SEED,
     VERSION,
     active_btn_by_content,
+    custom_theme,
     end_session,
     extract_3d_representations_v2,
     extract_urdf,
@@ -37,17 +38,33 @@ from common import (
     select_point,
     start_session,
 )
-from gradio.themes import Default
-from gradio.themes.utils.colors import slate
 
-with gr.Blocks(
-    delete_cache=(43200, 43200), theme=Default(primary_hue=slate)
-) as demo:
+with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
     gr.Markdown(
-        f"""
-        ## ***EmbodiedGen***: Image-to-3D Asset \n
-        version: {VERSION}
         """
+        ## ***EmbodiedGen***: Image-to-3D Asset
+        **πŸ”– Version**: {VERSION}
+        <p style="display: flex; gap: 10px; flex-wrap: nowrap;">
+            <a href="https://horizonrobotics.github.io/robot_lab/embodied_gen/index.html">
+                <img alt="🌐 Project Page" src="https://img.shields.io/badge/🌐-Project_Page-blue">
+            </a>
+            <a href="https://arxiv.org/abs/xxxx.xxxxx">
+                <img alt="πŸ“„ arXiv" src="https://img.shields.io/badge/πŸ“„-arXiv-b31b1b">
+            </a>
+            <a href="https://github.com/horizon-research/EmbodiedGen">
+                <img alt="πŸ’» GitHub" src="https://img.shields.io/badge/GitHub-000000?logo=github">
+            </a>
+            <a href="https://www.youtube.com/watch?v=SnHhzHeb_aI">
+                <img alt="πŸŽ₯ Video" src="https://img.shields.io/badge/πŸŽ₯-Video-red">
+            </a>
+        </p>
+
+        πŸ–ΌοΈ Generate physically plausible 3D asset from single input image.
+
+        """.format(
+            VERSION=VERSION
+        ),
+        elem_classes=["header"],
     )
 
     gr.HTML(image_css)
@@ -165,12 +182,14 @@ with gr.Blocks(
             )
 
             generate_btn = gr.Button(
-                "Generate(~0.5 mins)", variant="primary", interactive=False
+                "πŸš€ 1. Generate(~0.5 mins)",
+                variant="primary",
+                interactive=False,
             )
             model_output_obj = gr.Textbox(label="raw mesh .obj", visible=False)
             with gr.Row():
                 extract_rep3d_btn = gr.Button(
-                    "Extract 3D Representation(~2 mins)",
+                    "πŸ” 2. Extract 3D Representation(~2 mins)",
                     variant="primary",
                     interactive=False,
                 )
@@ -191,7 +210,7 @@ with gr.Blocks(
             )
             with gr.Row():
                 extract_urdf_btn = gr.Button(
-                    "Extract URDF with physics(~1 mins)",
+                    "🧩 3. Extract URDF with physics(~1 mins)",
                     variant="primary",
                     interactive=False,
                 )
@@ -214,7 +233,9 @@ with gr.Blocks(
             )
             with gr.Row():
                 download_urdf = gr.DownloadButton(
-                    label="Download URDF", variant="primary", interactive=False
+                    label="⬇️ 4. Download URDF",
+                    variant="primary",
+                    interactive=False,
                 )
 
     gr.Markdown(
@@ -477,4 +498,4 @@ with gr.Blocks(
 
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch(server_name="10.34.8.82", server_port=8085)
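Two notes on this file. The header switches from an f-string to `str.format` so the literal `{VERSION}` placeholder can sit next to raw HTML, and `launch()` now binds to a fixed LAN address. As a minimal sketch only (not part of this commit), the hardcoded host and port could instead come from the environment; the fallback values below simply reuse the commit's own, and Gradio also honors `GRADIO_SERVER_NAME`/`GRADIO_SERVER_PORT` natively when these arguments are left unset:

```python
import os

# Sketch, not part of this commit: read host/port from the environment
# so the LAN IP isn't baked into the source, falling back to the
# values the commit uses.
if __name__ == "__main__":
    demo.launch(
        server_name=os.getenv("GRADIO_SERVER_NAME", "10.34.8.82"),
        server_port=int(os.getenv("GRADIO_SERVER_PORT", "8085")),
    )
```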
common.py CHANGED
@@ -30,6 +30,8 @@ import torch
 import torch.nn.functional as F
 import trimesh
 from easydict import EasyDict as edict
+from gradio.themes import Soft
+from gradio.themes.utils.colors import gray, neutral, slate, stone, teal, zinc
 from PIL import Image
 from embodied_gen.data.backproject_v2 import entrypoint as backproject_api
 from embodied_gen.data.differentiable_render import entrypoint as render_api
@@ -233,6 +235,14 @@ height: 100% !important;
 </style>
 """
 
+custom_theme = Soft(
+    primary_hue=stone,
+    secondary_hue=gray,
+    radius_size="md",
+    text_size="sm",
+    spacing_size="sm",
+)
+
 
 def start_session(req: gr.Request) -> None:
     user_dir = os.path.join(TMP_DIR, str(req.session_hash))
@@ -670,8 +680,8 @@ def text2image_fn(
     ip_adapt_scale: float = 0.3,
     image_wh: int | tuple[int, int] = [1024, 1024],
     rmbg_tag: str = "rembg",
-    n_sample: int = 3,
     seed: int = None,
+    n_sample: int = 3,
     req: gr.Request = None,
 ):
     if isinstance(image_wh, int):
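Of the six imported hues, only `stone` and `gray` are actually used by `custom_theme`; the others (`neutral`, `slate`, `teal`, `zinc`) are unused. The theme can be exercised standalone; this sketch reuses the exact constructor arguments from the hunk above, while the demo widgets are placeholders and not part of the commit:

```python
import gradio as gr
from gradio.themes import Soft
from gradio.themes.utils.colors import gray, stone

# Same theme as common.py: Soft base, stone primary / gray secondary,
# compact radius, text size, and spacing.
custom_theme = Soft(
    primary_hue=stone,
    secondary_hue=gray,
    radius_size="md",
    text_size="sm",
    spacing_size="sm",
)

with gr.Blocks(theme=custom_theme) as demo:
    gr.Markdown("## Theme preview")  # placeholder content
    gr.Button("Primary action", variant="primary")

if __name__ == "__main__":
    demo.launch()
```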
embodied_gen/models/text_model.py CHANGED
@@ -16,10 +16,10 @@
 
 
 import logging
+import random
 
-import torch
 import numpy as np
-import random
+import torch
 from diffusers import (
     AutoencoderKL,
     EulerDiscreteScheduler,
@@ -143,9 +143,10 @@ def text2img_gen(
     seed: int = None,
 ) -> list[Image.Image]:
     prompt = "Single " + prompt + ", in the center of the image"
-    prompt += ", high quality, high resolution, best quality, white background, 3D style,"  # noqa
+    prompt += ", high quality, high resolution, best quality, white background, 3D style"  # noqa
     logger.info(f"Processing prompt: {prompt}")
 
+    generator = None
     if seed is not None:
         generator = torch.Generator(pipeline.device).manual_seed(seed)
         torch.manual_seed(seed)
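The added `generator = None` matters because `generator` is presumably passed to the diffusers pipeline further down: without the default, the unseeded path would reference an undefined name. A self-contained sketch of the pattern; `make_generator` is a hypothetical helper, not in the repo:

```python
import torch

def make_generator(device, seed: int | None = None) -> torch.Generator | None:
    # Mirrors the seeding logic above: return a seeded generator when a
    # seed is given, else None so diffusers samples nondeterministically.
    if seed is None:
        return None
    torch.manual_seed(seed)  # also seed the global RNG, as the commit does
    return torch.Generator(device).manual_seed(seed)
```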
embodied_gen/scripts/imageto3d.py CHANGED
@@ -70,7 +70,9 @@ IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
 RBG_REMOVER = RembgRemover()
 RBG14_REMOVER = BMGG14Remover()
 SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
-PIPELINE = TrellisImageTo3DPipeline.from_pretrained("microsoft/TRELLIS-image-large")
+PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
+    "microsoft/TRELLIS-image-large"
+)
 PIPELINE.cuda()
 SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
 GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
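This last hunk only re-wraps the `from_pretrained` call for line-length compliance; behavior is unchanged. One hedged hardening idea, not in this commit: guard the unconditional `PIPELINE.cuda()` so the module still imports on CPU-only machines:

```python
import torch

# Sketch, not part of this commit: move the pipeline to GPU only when
# CUDA is actually available (PIPELINE as defined above).
if torch.cuda.is_available():
    PIPELINE.cuda()
```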