xinjie.wang committed on
Commit
6d38e38
·
1 Parent(s): 1043d26
app.py CHANGED
@@ -25,6 +25,7 @@ from common import (
25
  MAX_SEED,
26
  VERSION,
27
  active_btn_by_text_content,
 
28
  end_session,
29
  extract_3d_representations_v2,
30
  extract_urdf,
@@ -37,17 +38,33 @@ from common import (
37
  start_session,
38
  text2image_fn,
39
  )
40
- from gradio.themes import Default
41
- from gradio.themes.utils.colors import slate
42
 
43
- with gr.Blocks(
44
- delete_cache=(43200, 43200), theme=Default(primary_hue=slate)
45
- ) as demo:
46
  gr.Markdown(
47
- f"""
48
- ## ***EmbodiedGen***: Text-to-3D Asset \n
49
- version: {VERSION} \n
50
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  )
52
  gr.HTML(image_css)
53
  gr.HTML(lighting_css)
@@ -107,7 +124,7 @@ with gr.Blocks(
107
  )
108
 
109
  generate_img_btn = gr.Button(
110
- "Generate Images(~1min)",
111
  variant="primary",
112
  interactive=False,
113
  )
@@ -163,12 +180,14 @@ with gr.Blocks(
163
  )
164
 
165
  generate_btn = gr.Button(
166
- "Generate 3D(~0.5 mins)", variant="primary", interactive=False
 
 
167
  )
168
  model_output_obj = gr.Textbox(label="raw mesh .obj", visible=False)
169
  with gr.Row():
170
  extract_rep3d_btn = gr.Button(
171
- "Extract 3D Representation(~1 mins)",
172
  variant="primary",
173
  interactive=False,
174
  )
@@ -189,13 +208,15 @@ with gr.Blocks(
189
  )
190
  with gr.Row():
191
  extract_urdf_btn = gr.Button(
192
- "Extract URDF with physics(~1 mins)",
193
  variant="primary",
194
  interactive=False,
195
  )
196
  with gr.Row():
197
  download_urdf = gr.DownloadButton(
198
- label="Download URDF", variant="primary", interactive=False
 
 
199
  )
200
 
201
  with gr.Column(scale=3):
@@ -286,12 +307,12 @@ with gr.Blocks(
286
  est_mu_text = gr.Textbox(
287
  label="Friction coefficient", interactive=False
288
  )
289
-
290
  prompt_examples = [
291
- "satin gold tea cup with saucer",
292
- "small brown leather bag",
293
  "Miniature cup with floral design",
294
- "带木质底座, 具有经纬线的地球仪",
295
  "橙色电动手钻, 有磨损细节",
296
  "手工制作的皮革笔记本",
297
  "写实风格机甲3D全身模型, 主体色调为深灰色和荧光黄",
 
25
  MAX_SEED,
26
  VERSION,
27
  active_btn_by_text_content,
28
+ custom_theme,
29
  end_session,
30
  extract_3d_representations_v2,
31
  extract_urdf,
 
38
  start_session,
39
  text2image_fn,
40
  )
 
 
41
 
42
+ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
 
 
43
  gr.Markdown(
44
+ """
45
+ ## ***EmbodiedGen***: Text-to-3D Asset
46
+ **🔖 Version**: {VERSION}
47
+ <p style="display: flex; gap: 10px; flex-wrap: nowrap;">
48
+ <a href="https://horizonrobotics.github.io/robot_lab/embodied_gen/index.html">
49
+ <img alt="🌐 Project Page" src="https://img.shields.io/badge/🌐-Project_Page-blue">
50
+ </a>
51
+ <a href="https://arxiv.org/abs/xxxx.xxxxx">
52
+ <img alt="📄 arXiv" src="https://img.shields.io/badge/📄-arXiv-b31b1b">
53
+ </a>
54
+ <a href="https://github.com/horizon-research/EmbodiedGen">
55
+ <img alt="💻 GitHub" src="https://img.shields.io/badge/GitHub-000000?logo=github">
56
+ </a>
57
+ <a href="https://www.youtube.com/watch?v=SnHhzHeb_aI">
58
+ <img alt="🎥 Video" src="https://img.shields.io/badge/🎥-Video-red">
59
+ </a>
60
+ </p>
61
+
62
+ 📝 Create 3D assets from text descriptions for a wide range of geometry and styles.
63
+
64
+ """.format(
65
+ VERSION=VERSION
66
+ ),
67
+ elem_classes=["header"],
68
  )
69
  gr.HTML(image_css)
70
  gr.HTML(lighting_css)
 
124
  )
125
 
126
  generate_img_btn = gr.Button(
127
+ "🎨 1. Generate Images(~1min)",
128
  variant="primary",
129
  interactive=False,
130
  )
 
180
  )
181
 
182
  generate_btn = gr.Button(
183
+ "🚀 2. Generate 3D(~0.5 mins)",
184
+ variant="primary",
185
+ interactive=False,
186
  )
187
  model_output_obj = gr.Textbox(label="raw mesh .obj", visible=False)
188
  with gr.Row():
189
  extract_rep3d_btn = gr.Button(
190
+ "🔍 3. Extract 3D Representation(~1 mins)",
191
  variant="primary",
192
  interactive=False,
193
  )
 
208
  )
209
  with gr.Row():
210
  extract_urdf_btn = gr.Button(
211
+ "🧩 4. Extract URDF with physics(~1 mins)",
212
  variant="primary",
213
  interactive=False,
214
  )
215
  with gr.Row():
216
  download_urdf = gr.DownloadButton(
217
+ label="⬇️ 5. Download URDF",
218
+ variant="primary",
219
+ interactive=False,
220
  )
221
 
222
  with gr.Column(scale=3):
 
307
  est_mu_text = gr.Textbox(
308
  label="Friction coefficient", interactive=False
309
  )
310
+
311
  prompt_examples = [
312
+ "satin gold tea cup with saucer",
313
+ "brown leather bag",
314
  "Miniature cup with floral design",
315
+ "带木质底座, 具有经纬线的地球仪",
316
  "橙色电动手钻, 有磨损细节",
317
  "手工制作的皮革笔记本",
318
  "写实风格机甲3D全身模型, 主体色调为深灰色和荧光黄",
common.py CHANGED
@@ -30,6 +30,8 @@ import torch
30
  import torch.nn.functional as F
31
  import trimesh
32
  from easydict import EasyDict as edict
 
 
33
  from PIL import Image
34
  from embodied_gen.data.backproject_v2 import entrypoint as backproject_api
35
  from embodied_gen.data.differentiable_render import entrypoint as render_api
@@ -233,6 +235,14 @@ height: 100% !important;
233
  </style>
234
  """
235
 
 
 
 
 
 
 
 
 
236
 
237
  def start_session(req: gr.Request) -> None:
238
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
@@ -670,8 +680,8 @@ def text2image_fn(
670
  ip_adapt_scale: float = 0.3,
671
  image_wh: int | tuple[int, int] = [1024, 1024],
672
  rmbg_tag: str = "rembg",
673
- n_sample: int = 3,
674
  seed: int = None,
 
675
  req: gr.Request = None,
676
  ):
677
  if isinstance(image_wh, int):
 
30
  import torch.nn.functional as F
31
  import trimesh
32
  from easydict import EasyDict as edict
33
+ from gradio.themes import Soft
34
+ from gradio.themes.utils.colors import gray, neutral, slate, stone, teal, zinc
35
  from PIL import Image
36
  from embodied_gen.data.backproject_v2 import entrypoint as backproject_api
37
  from embodied_gen.data.differentiable_render import entrypoint as render_api
 
235
  </style>
236
  """
237
 
238
+ custom_theme = Soft(
239
+ primary_hue=stone,
240
+ secondary_hue=gray,
241
+ radius_size="md",
242
+ text_size="sm",
243
+ spacing_size="sm",
244
+ )
245
+
246
 
247
  def start_session(req: gr.Request) -> None:
248
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
 
680
  ip_adapt_scale: float = 0.3,
681
  image_wh: int | tuple[int, int] = [1024, 1024],
682
  rmbg_tag: str = "rembg",
 
683
  seed: int = None,
684
+ n_sample: int = 3,
685
  req: gr.Request = None,
686
  ):
687
  if isinstance(image_wh, int):
embodied_gen/models/text_model.py CHANGED
@@ -16,10 +16,10 @@
16
 
17
 
18
  import logging
 
19
 
20
- import torch
21
  import numpy as np
22
- import random
23
  from diffusers import (
24
  AutoencoderKL,
25
  EulerDiscreteScheduler,
@@ -143,9 +143,10 @@ def text2img_gen(
143
  seed: int = None,
144
  ) -> list[Image.Image]:
145
  prompt = "Single " + prompt + ", in the center of the image"
146
- prompt += ", high quality, high resolution, best quality, white background, 3D style," # noqa
147
  logger.info(f"Processing prompt: {prompt}")
148
 
 
149
  if seed is not None:
150
  generator = torch.Generator(pipeline.device).manual_seed(seed)
151
  torch.manual_seed(seed)
 
16
 
17
 
18
  import logging
19
+ import random
20
 
 
21
  import numpy as np
22
+ import torch
23
  from diffusers import (
24
  AutoencoderKL,
25
  EulerDiscreteScheduler,
 
143
  seed: int = None,
144
  ) -> list[Image.Image]:
145
  prompt = "Single " + prompt + ", in the center of the image"
146
+ prompt += ", high quality, high resolution, best quality, white background, 3D style" # noqa
147
  logger.info(f"Processing prompt: {prompt}")
148
 
149
+ generator = None
150
  if seed is not None:
151
  generator = torch.Generator(pipeline.device).manual_seed(seed)
152
  torch.manual_seed(seed)
embodied_gen/scripts/imageto3d.py CHANGED
@@ -70,7 +70,9 @@ IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
70
  RBG_REMOVER = RembgRemover()
71
  RBG14_REMOVER = BMGG14Remover()
72
  SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
73
- PIPELINE = TrellisImageTo3DPipeline.from_pretrained("microsoft/TRELLIS-image-large")
 
 
74
  PIPELINE.cuda()
75
  SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
76
  GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
 
70
  RBG_REMOVER = RembgRemover()
71
  RBG14_REMOVER = BMGG14Remover()
72
  SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
73
+ PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
74
+ "microsoft/TRELLIS-image-large"
75
+ )
76
  PIPELINE.cuda()
77
  SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
78
  GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)