hysts HF Staff commited on
Commit
4c58984
·
1 Parent(s): e669631
Files changed (2) hide show
  1. app.py +39 -83
  2. pyproject.toml +41 -40
app.py CHANGED
@@ -2,18 +2,16 @@ import os
2
  import shlex
3
  import shutil
4
  import subprocess
5
- from typing import *
6
 
7
  os.environ["SPCONV_ALGO"] = "native"
8
 
9
  if os.getenv("SPACE_ID"):
10
- subprocess.run(
11
- shlex.split(
12
- "pip install wheels/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl"
13
- ),
14
  check=True,
15
  )
16
- subprocess.run(
17
  shlex.split("pip install wheels/nvdiffrast-0.3.3-cp310-cp310-linux_x86_64.whl"),
18
  check=True,
19
  )
@@ -46,8 +44,7 @@ def end_session(req: gr.Request):
46
 
47
 
48
  def preprocess_image(image: Image.Image) -> Image.Image:
49
- """
50
- Preprocess the input image.
51
 
52
  Args:
53
  image (Image.Image): The input image.
@@ -59,9 +56,8 @@ def preprocess_image(image: Image.Image) -> Image.Image:
59
  return processed_image
60
 
61
 
62
- def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
63
- """
64
- Preprocess a list of input images.
65
 
66
  Args:
67
  images (List[Tuple[Image.Image, str]]): The input images.
@@ -91,7 +87,7 @@ def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
91
  }
92
 
93
 
94
- def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
95
  gs = Gaussian(
96
  aabb=state["gaussian"]["aabb"],
97
  sh_degree=state["gaussian"]["sh_degree"],
@@ -115,16 +111,14 @@ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
115
 
116
 
117
  def get_seed(randomize_seed: bool, seed: int) -> int:
118
- """
119
- Get the random seed.
120
- """
121
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
122
 
123
 
124
  @spaces.GPU
125
  def image_to_3d(
126
  image: Image.Image,
127
- multiimages: List[Tuple[Image.Image, str]],
128
  is_multiimage: bool,
129
  seed: int,
130
  ss_guidance_strength: float,
@@ -133,9 +127,8 @@ def image_to_3d(
133
  slat_sampling_steps: int,
134
  multiimage_algo: Literal["multidiffusion", "stochastic"],
135
  req: gr.Request,
136
- ) -> Tuple[dict, str]:
137
- """
138
- Convert an image to a 3D model.
139
 
140
  Args:
141
  image (Image.Image): The input image.
@@ -186,9 +179,7 @@ def image_to_3d(
186
  )
187
  video = render_utils.render_video(outputs["gaussian"][0], num_frames=120)["color"]
188
  video_geo = render_utils.render_video(outputs["mesh"][0], num_frames=120)["normal"]
189
- video = [
190
- np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))
191
- ]
192
  video_path = os.path.join(user_dir, "sample.mp4")
193
  imageio.mimsave(video_path, video, fps=15)
194
  state = pack_state(outputs["gaussian"][0], outputs["mesh"][0])
@@ -202,9 +193,8 @@ def extract_glb(
202
  mesh_simplify: float,
203
  texture_size: int,
204
  req: gr.Request,
205
- ) -> Tuple[str, str]:
206
- """
207
- Extract a GLB file from the 3D model.
208
 
209
  Args:
210
  state (dict): The state of the generated 3D model.
@@ -216,9 +206,7 @@ def extract_glb(
216
  """
217
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
218
  gs, mesh = unpack_state(state)
219
- glb = postprocessing_utils.to_glb(
220
- gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False
221
- )
222
  glb_path = os.path.join(user_dir, "sample.glb")
223
  glb.export(glb_path)
224
  torch.cuda.empty_cache()
@@ -226,9 +214,8 @@ def extract_glb(
226
 
227
 
228
  @spaces.GPU
229
- def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
230
- """
231
- Extract a Gaussian file from the 3D model.
232
 
233
  Args:
234
  state (dict): The state of the generated 3D model.
@@ -244,10 +231,8 @@ def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
244
  return gaussian_path, gaussian_path
245
 
246
 
247
- def prepare_multi_example() -> List[Image.Image]:
248
- multi_case = list(
249
- set([i.split("_")[0] for i in os.listdir("assets/example_multi_image")])
250
- )
251
  images = []
252
  for case in multi_case:
253
  _images = []
@@ -260,17 +245,15 @@ def prepare_multi_example() -> List[Image.Image]:
260
  return images
261
 
262
 
263
- def split_image(image: Image.Image) -> List[Image.Image]:
264
- """
265
- Split an image into multiple views.
266
- """
267
  image = np.array(image)
268
  alpha = image[..., 3]
269
  alpha = np.any(alpha > 0, axis=0)
270
  start_pos = np.where(~alpha[:-1] & alpha[1:])[0].tolist()
271
  end_pos = np.where(alpha[:-1] & ~alpha[1:])[0].tolist()
272
  images = []
273
- for s, e in zip(start_pos, end_pos):
274
  images.append(Image.fromarray(image[:, s : e + 1]))
275
  return [preprocess_image(image) for image in images]
276
 
@@ -280,7 +263,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
280
  ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
281
  * Upload an image and click "Generate" to create a 3D asset. If the image has alpha channel, it be used as the mask. Otherwise, we use `rembg` to remove the background.
282
  * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
283
-
284
  ✨New: 1) Experimental multi-image support. 2) Gaussian file extraction.
285
  """)
286
 
@@ -304,8 +287,8 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
304
  columns=3,
305
  )
306
  gr.Markdown("""
307
- Input different views of the object in separate images.
308
-
309
  *NOTE: this is an experimental algorithm without training a specialized model. It may not produce the best results for all images, especially those having different poses or inconsistent details.*
310
  """)
311
 
@@ -314,20 +297,12 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
314
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
315
  gr.Markdown("Stage 1: Sparse Structure Generation")
316
  with gr.Row():
317
- ss_guidance_strength = gr.Slider(
318
- 0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1
319
- )
320
- ss_sampling_steps = gr.Slider(
321
- 1, 50, label="Sampling Steps", value=12, step=1
322
- )
323
  gr.Markdown("Stage 2: Structured Latent Generation")
324
  with gr.Row():
325
- slat_guidance_strength = gr.Slider(
326
- 0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1
327
- )
328
- slat_sampling_steps = gr.Slider(
329
- 1, 50, label="Sampling Steps", value=12, step=1
330
- )
331
  multiimage_algo = gr.Radio(
332
  ["stochastic", "multidiffusion"],
333
  label="Multi-image Algorithm",
@@ -337,12 +312,8 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
337
  generate_btn = gr.Button("Generate")
338
 
339
  with gr.Accordion(label="GLB Extraction Settings", open=False):
340
- mesh_simplify = gr.Slider(
341
- 0.9, 0.98, label="Simplify", value=0.95, step=0.01
342
- )
343
- texture_size = gr.Slider(
344
- 512, 2048, label="Texture Size", value=1024, step=512
345
- )
346
 
347
  with gr.Row():
348
  extract_glb_btn = gr.Button("Extract GLB", interactive=False)
@@ -352,18 +323,12 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
352
  """)
353
 
354
  with gr.Column():
355
- video_output = gr.Video(
356
- label="Generated 3D Asset", autoplay=True, loop=True, height=300
357
- )
358
  model_output = gr.Model3D(label="Extracted GLB/Gaussian", height=300)
359
 
360
  with gr.Row():
361
- download_glb = gr.DownloadButton(
362
- label="Download GLB", interactive=False
363
- )
364
- download_gs = gr.DownloadButton(
365
- label="Download Gaussian", interactive=False
366
- )
367
 
368
  is_multiimage = gr.State(False)
369
  output_buf = gr.State()
@@ -371,10 +336,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
371
  # Example images at the bottom of the page
372
  with gr.Row() as single_image_example:
373
  examples = gr.Examples(
374
- examples=[
375
- f"assets/example_image/{image}"
376
- for image in os.listdir("assets/example_image")
377
- ],
378
  inputs=[image_prompt],
379
  fn=preprocess_image,
380
  outputs=[image_prompt],
@@ -396,15 +358,11 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
396
  demo.unload(end_session)
397
 
398
  single_image_input_tab.select(
399
- lambda: tuple(
400
- [False, gr.Row.update(visible=True), gr.Row.update(visible=False)]
401
- ),
402
  outputs=[is_multiimage, single_image_example, multiimage_example],
403
  )
404
  multiimage_input_tab.select(
405
- lambda: tuple(
406
- [True, gr.Row.update(visible=False), gr.Row.update(visible=True)]
407
- ),
408
  outputs=[is_multiimage, single_image_example, multiimage_example],
409
  )
410
 
@@ -476,9 +434,7 @@ if __name__ == "__main__":
476
  pipeline = TrellisImageTo3DPipeline.from_pretrained("microsoft/TRELLIS-image-large")
477
  pipeline.cuda()
478
  try:
479
- pipeline.preprocess_image(
480
- Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))
481
- ) # Preload rembg
482
  except:
483
  pass
484
  demo.launch(mcp_server=True)
 
2
  import shlex
3
  import shutil
4
  import subprocess
5
+ from typing import Literal
6
 
7
  os.environ["SPCONV_ALGO"] = "native"
8
 
9
  if os.getenv("SPACE_ID"):
10
+ subprocess.run( # noqa: S603
11
+ shlex.split("pip install wheels/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl"),
 
 
12
  check=True,
13
  )
14
+ subprocess.run( # noqa: S603
15
  shlex.split("pip install wheels/nvdiffrast-0.3.3-cp310-cp310-linux_x86_64.whl"),
16
  check=True,
17
  )
 
44
 
45
 
46
  def preprocess_image(image: Image.Image) -> Image.Image:
47
+ """Preprocess the input image.
 
48
 
49
  Args:
50
  image (Image.Image): The input image.
 
56
  return processed_image
57
 
58
 
59
+ def preprocess_images(images: list[tuple[Image.Image, str]]) -> list[Image.Image]:
60
+ """Preprocess a list of input images.
 
61
 
62
  Args:
63
  images (List[Tuple[Image.Image, str]]): The input images.
 
87
  }
88
 
89
 
90
+ def unpack_state(state: dict) -> tuple[Gaussian, edict, str]:
91
  gs = Gaussian(
92
  aabb=state["gaussian"]["aabb"],
93
  sh_degree=state["gaussian"]["sh_degree"],
 
111
 
112
 
113
  def get_seed(randomize_seed: bool, seed: int) -> int:
114
+ """Get the random seed."""
 
 
115
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
116
 
117
 
118
  @spaces.GPU
119
  def image_to_3d(
120
  image: Image.Image,
121
+ multiimages: list[tuple[Image.Image, str]],
122
  is_multiimage: bool,
123
  seed: int,
124
  ss_guidance_strength: float,
 
127
  slat_sampling_steps: int,
128
  multiimage_algo: Literal["multidiffusion", "stochastic"],
129
  req: gr.Request,
130
+ ) -> tuple[dict, str]:
131
+ """Convert an image to a 3D model.
 
132
 
133
  Args:
134
  image (Image.Image): The input image.
 
179
  )
180
  video = render_utils.render_video(outputs["gaussian"][0], num_frames=120)["color"]
181
  video_geo = render_utils.render_video(outputs["mesh"][0], num_frames=120)["normal"]
182
+ video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
 
 
183
  video_path = os.path.join(user_dir, "sample.mp4")
184
  imageio.mimsave(video_path, video, fps=15)
185
  state = pack_state(outputs["gaussian"][0], outputs["mesh"][0])
 
193
  mesh_simplify: float,
194
  texture_size: int,
195
  req: gr.Request,
196
+ ) -> tuple[str, str]:
197
+ """Extract a GLB file from the 3D model.
 
198
 
199
  Args:
200
  state (dict): The state of the generated 3D model.
 
206
  """
207
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
208
  gs, mesh = unpack_state(state)
209
+ glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
 
 
210
  glb_path = os.path.join(user_dir, "sample.glb")
211
  glb.export(glb_path)
212
  torch.cuda.empty_cache()
 
214
 
215
 
216
  @spaces.GPU
217
+ def extract_gaussian(state: dict, req: gr.Request) -> tuple[str, str]:
218
+ """Extract a Gaussian file from the 3D model.
 
219
 
220
  Args:
221
  state (dict): The state of the generated 3D model.
 
231
  return gaussian_path, gaussian_path
232
 
233
 
234
+ def prepare_multi_example() -> list[Image.Image]:
235
+ multi_case = list(set([i.split("_")[0] for i in os.listdir("assets/example_multi_image")]))
 
 
236
  images = []
237
  for case in multi_case:
238
  _images = []
 
245
  return images
246
 
247
 
248
+ def split_image(image: Image.Image) -> list[Image.Image]:
249
+ """Split an image into multiple views."""
 
 
250
  image = np.array(image)
251
  alpha = image[..., 3]
252
  alpha = np.any(alpha > 0, axis=0)
253
  start_pos = np.where(~alpha[:-1] & alpha[1:])[0].tolist()
254
  end_pos = np.where(alpha[:-1] & ~alpha[1:])[0].tolist()
255
  images = []
256
+ for s, e in zip(start_pos, end_pos, strict=False):
257
  images.append(Image.fromarray(image[:, s : e + 1]))
258
  return [preprocess_image(image) for image in images]
259
 
 
263
  ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
264
  * Upload an image and click "Generate" to create a 3D asset. If the image has alpha channel, it be used as the mask. Otherwise, we use `rembg` to remove the background.
265
  * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
266
+
267
  ✨New: 1) Experimental multi-image support. 2) Gaussian file extraction.
268
  """)
269
 
 
287
  columns=3,
288
  )
289
  gr.Markdown("""
290
+ Input different views of the object in separate images.
291
+
292
  *NOTE: this is an experimental algorithm without training a specialized model. It may not produce the best results for all images, especially those having different poses or inconsistent details.*
293
  """)
294
 
 
297
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
298
  gr.Markdown("Stage 1: Sparse Structure Generation")
299
  with gr.Row():
300
+ ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
301
+ ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
 
 
 
 
302
  gr.Markdown("Stage 2: Structured Latent Generation")
303
  with gr.Row():
304
+ slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
305
+ slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
 
 
 
 
306
  multiimage_algo = gr.Radio(
307
  ["stochastic", "multidiffusion"],
308
  label="Multi-image Algorithm",
 
312
  generate_btn = gr.Button("Generate")
313
 
314
  with gr.Accordion(label="GLB Extraction Settings", open=False):
315
+ mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
316
+ texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
 
 
 
 
317
 
318
  with gr.Row():
319
  extract_glb_btn = gr.Button("Extract GLB", interactive=False)
 
323
  """)
324
 
325
  with gr.Column():
326
+ video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
 
 
327
  model_output = gr.Model3D(label="Extracted GLB/Gaussian", height=300)
328
 
329
  with gr.Row():
330
+ download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
331
+ download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
 
 
 
 
332
 
333
  is_multiimage = gr.State(False)
334
  output_buf = gr.State()
 
336
  # Example images at the bottom of the page
337
  with gr.Row() as single_image_example:
338
  examples = gr.Examples(
339
+ examples=[f"assets/example_image/{image}" for image in os.listdir("assets/example_image")],
 
 
 
340
  inputs=[image_prompt],
341
  fn=preprocess_image,
342
  outputs=[image_prompt],
 
358
  demo.unload(end_session)
359
 
360
  single_image_input_tab.select(
361
+ lambda: tuple([False, gr.Row.update(visible=True), gr.Row.update(visible=False)]),
 
 
362
  outputs=[is_multiimage, single_image_example, multiimage_example],
363
  )
364
  multiimage_input_tab.select(
365
+ lambda: tuple([True, gr.Row.update(visible=False), gr.Row.update(visible=True)]),
 
 
366
  outputs=[is_multiimage, single_image_example, multiimage_example],
367
  )
368
 
 
434
  pipeline = TrellisImageTo3DPipeline.from_pretrained("microsoft/TRELLIS-image-large")
435
  pipeline.cuda()
436
  try:
437
+ pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))) # Preload rembg
 
 
438
  except:
439
  pass
440
  demo.launch(mcp_server=True)
pyproject.toml CHANGED
@@ -42,43 +42,44 @@ dev = [
42
  "setuptools>=80.8.0",
43
  ]
44
 
45
- #[tool.ruff]
46
- #line-length = 119
47
- #
48
- #[tool.ruff.lint]
49
- #select = ["ALL"]
50
- #ignore = [
51
- # "COM812", # missing-trailing-comma
52
- # "D203", # one-blank-line-before-class
53
- # "D213", # multi-line-summary-second-line
54
- # "E501", # line-too-long
55
- # "SIM117", # multiple-with-statements
56
- # #
57
- # "D100", # undocumented-public-module
58
- # "D101", # undocumented-public-class
59
- # "D102", # undocumented-public-method
60
- # "D103", # undocumented-public-function
61
- # "D104", # undocumented-public-package
62
- # "D105", # undocumented-magic-method
63
- # "D107", # undocumented-public-init
64
- # "EM101", # raw-string-in-exception
65
- # "FBT001", # boolean-type-hint-positional-argument
66
- # "FBT002", # boolean-default-value-positional-argument
67
- # "PD901", # pandas-df-variable-name
68
- # "PGH003", # blanket-type-ignore
69
- # "PLR0913", # too-many-arguments
70
- # "PLR0915", # too-many-statements
71
- # "TRY003", # raise-vanilla-args
72
- #]
73
- #unfixable = [
74
- # "F401", # unused-import
75
- #]
76
- #
77
- #[tool.ruff.lint.pydocstyle]
78
- #convention = "google"
79
- #
80
- #[tool.ruff.lint.per-file-ignores]
81
- #"*.ipynb" = ["T201", "T203"]
82
- #
83
- #[tool.ruff.format]
84
- #docstring-code-format = true
 
 
42
  "setuptools>=80.8.0",
43
  ]
44
 
45
+ [tool.ruff]
46
+ line-length = 119
47
+ exclude = ["trellis", "extensions"]
48
+
49
+ [tool.ruff.lint]
50
+ select = ["ALL"]
51
+ ignore = [
52
+ "COM812", # missing-trailing-comma
53
+ "D203", # one-blank-line-before-class
54
+ "D213", # multi-line-summary-second-line
55
+ "E501", # line-too-long
56
+ "SIM117", # multiple-with-statements
57
+ #
58
+ "D100", # undocumented-public-module
59
+ "D101", # undocumented-public-class
60
+ "D102", # undocumented-public-method
61
+ "D103", # undocumented-public-function
62
+ "D104", # undocumented-public-package
63
+ "D105", # undocumented-magic-method
64
+ "D107", # undocumented-public-init
65
+ "EM101", # raw-string-in-exception
66
+ "FBT001", # boolean-type-hint-positional-argument
67
+ "FBT002", # boolean-default-value-positional-argument
68
+ "PD901", # pandas-df-variable-name
69
+ "PGH003", # blanket-type-ignore
70
+ "PLR0913", # too-many-arguments
71
+ "PLR0915", # too-many-statements
72
+ "TRY003", # raise-vanilla-args
73
+ ]
74
+ unfixable = [
75
+ "F401", # unused-import
76
+ ]
77
+
78
+ [tool.ruff.lint.pydocstyle]
79
+ convention = "google"
80
+
81
+ [tool.ruff.lint.per-file-ignores]
82
+ "*.ipynb" = ["T201", "T203"]
83
+
84
+ [tool.ruff.format]
85
+ docstring-code-format = true