multimodalart HF Staff commited on
Commit
f782ff0
·
verified ·
1 Parent(s): 4670c79

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -67
app.py CHANGED
@@ -33,13 +33,17 @@ def end_session(req: gr.Request):
33
 
34
  def preprocess_image(image: Image.Image) -> Image.Image:
35
  """
36
- Preprocess the input image.
 
 
 
 
37
 
38
  Args:
39
- image (Image.Image): The input image.
40
 
41
  Returns:
42
- Image.Image: The preprocessed image.
43
  """
44
  processed_image = pipeline.preprocess_image(image)
45
  return processed_image
@@ -47,13 +51,16 @@ def preprocess_image(image: Image.Image) -> Image.Image:
47
 
48
  def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
49
  """
50
- Preprocess a list of input images.
 
 
 
51
 
52
  Args:
53
- images (List[Tuple[Image.Image, str]]): The input images.
54
 
55
  Returns:
56
- List[Image.Image]: The preprocessed images.
57
  """
58
  images = [image[0] for image in images]
59
  processed_images = [pipeline.preprocess_image(image) for image in images]
@@ -102,13 +109,23 @@ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
102
 
103
  def get_seed(randomize_seed: bool, seed: int) -> int:
104
  """
105
- Get the random seed.
 
 
 
 
 
 
 
 
 
 
106
  """
107
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
108
 
109
 
110
- @spaces.GPU
111
- def image_to_3d(
112
  image: Image.Image,
113
  multiimages: List[Tuple[Image.Image, str]],
114
  is_multiimage: bool,
@@ -118,10 +135,12 @@ def image_to_3d(
118
  slat_guidance_strength: float,
119
  slat_sampling_steps: int,
120
  multiimage_algo: Literal["multidiffusion", "stochastic"],
 
 
121
  req: gr.Request,
122
- ) -> Tuple[dict, str]:
123
  """
124
- Convert an image to a 3D model.
125
 
126
  Args:
127
  image (Image.Image): The input image.
@@ -133,12 +152,18 @@ def image_to_3d(
133
  slat_guidance_strength (float): The guidance strength for structured latent generation.
134
  slat_sampling_steps (int): The number of sampling steps for structured latent generation.
135
  multiimage_algo (Literal["multidiffusion", "stochastic"]): The algorithm for multi-image generation.
 
 
136
 
137
  Returns:
138
  dict: The information of the generated 3D model.
139
  str: The path to the video of the 3D model.
 
 
140
  """
141
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
 
 
142
  if not is_multiimage:
143
  outputs = pipeline.run(
144
  image,
@@ -170,53 +195,43 @@ def image_to_3d(
170
  },
171
  mode=multiimage_algo,
172
  )
 
 
173
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
174
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
175
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
176
  video_path = os.path.join(user_dir, 'sample.mp4')
177
  imageio.mimsave(video_path, video, fps=15)
178
- state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
179
- torch.cuda.empty_cache()
180
- return state, video_path
181
-
182
-
183
- @spaces.GPU(duration=90)
184
- def extract_glb(
185
- state: dict,
186
- mesh_simplify: float,
187
- texture_size: int,
188
- req: gr.Request,
189
- ) -> Tuple[str, str]:
190
- """
191
- Extract a GLB file from the 3D model.
192
-
193
- Args:
194
- state (dict): The state of the generated 3D model.
195
- mesh_simplify (float): The mesh simplification factor.
196
- texture_size (int): The texture resolution.
197
-
198
- Returns:
199
- str: The path to the extracted GLB file.
200
- """
201
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
202
- gs, mesh = unpack_state(state)
203
  glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
204
  glb_path = os.path.join(user_dir, 'sample.glb')
205
  glb.export(glb_path)
 
 
 
 
206
  torch.cuda.empty_cache()
207
- return glb_path, glb_path
208
 
209
 
210
  @spaces.GPU
211
  def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
212
  """
213
- Extract a Gaussian file from the 3D model.
 
 
 
 
214
 
215
  Args:
216
- state (dict): The state of the generated 3D model.
 
217
 
218
  Returns:
219
- str: The path to the extracted Gaussian file.
220
  """
221
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
222
  gs, _ = unpack_state(state)
@@ -242,7 +257,17 @@ def prepare_multi_example() -> List[Image.Image]:
242
 
243
  def split_image(image: Image.Image) -> List[Image.Image]:
244
  """
245
- Split an image into multiple views.
 
 
 
 
 
 
 
 
 
 
246
  """
247
  image = np.array(image)
248
  alpha = image[..., 3]
@@ -258,8 +283,9 @@ def split_image(image: Image.Image) -> List[Image.Image]:
258
  with gr.Blocks(delete_cache=(600, 600)) as demo:
259
  gr.Markdown("""
260
  ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
261
- * Upload an image and click "Generate" to create a 3D asset. If the image has alpha channel, it be used as the mask. Otherwise, we use `rembg` to remove the background.
262
- * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
 
263
 
264
  ✨New: 1) Experimental multi-image support. 2) Gaussian file extraction.
265
  """)
@@ -289,16 +315,13 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
289
  slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
290
  slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
291
  multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Multi-image Algorithm", value="stochastic")
292
-
293
- generate_btn = gr.Button("Generate")
294
 
295
  with gr.Accordion(label="GLB Extraction Settings", open=False):
296
  mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
297
  texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
298
-
299
- with gr.Row():
300
- extract_glb_btn = gr.Button("Extract GLB", interactive=False)
301
- extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
302
  gr.Markdown("""
303
  *NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
304
  """)
@@ -366,26 +389,17 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
366
  inputs=[randomize_seed, seed],
367
  outputs=[seed],
368
  ).then(
369
- image_to_3d,
370
- inputs=[image_prompt, multiimage_prompt, is_multiimage, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo],
371
- outputs=[output_buf, video_output],
372
  ).then(
373
  lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
374
- outputs=[extract_glb_btn, extract_gs_btn],
375
  )
376
 
377
  video_output.clear(
378
- lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
379
- outputs=[extract_glb_btn, extract_gs_btn],
380
- )
381
-
382
- extract_glb_btn.click(
383
- extract_glb,
384
- inputs=[output_buf, mesh_simplify, texture_size],
385
- outputs=[model_output, download_glb],
386
- ).then(
387
- lambda: gr.Button(interactive=True),
388
- outputs=[download_glb],
389
  )
390
 
391
  extract_gs_btn.click(
@@ -398,8 +412,8 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
398
  )
399
 
400
  model_output.clear(
401
- lambda: gr.Button(interactive=False),
402
- outputs=[download_glb],
403
  )
404
 
405
 
@@ -411,4 +425,4 @@ if __name__ == "__main__":
411
  pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))) # Preload rembg
412
  except:
413
  pass
414
- demo.launch()
 
33
 
34
  def preprocess_image(image: Image.Image) -> Image.Image:
35
  """
36
+ Preprocess the input image for 3D generation.
37
+
38
+ This function is called when a user uploads an image or selects an example.
39
+ It applies background removal and other preprocessing steps necessary for
40
+ optimal 3D model generation.
41
 
42
  Args:
43
+ image (Image.Image): The input image from the user
44
 
45
  Returns:
46
+ Image.Image: The preprocessed image ready for 3D generation
47
  """
48
  processed_image = pipeline.preprocess_image(image)
49
  return processed_image
 
51
 
52
  def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
53
  """
54
+ Preprocess a list of input images for multi-image 3D generation.
55
+
56
+ This function is called when users upload multiple images in the gallery.
57
+ It processes each image to prepare them for the multi-image 3D generation pipeline.
58
 
59
  Args:
60
+ images (List[Tuple[Image.Image, str]]): The input images from the gallery
61
 
62
  Returns:
63
+ List[Image.Image]: The preprocessed images ready for 3D generation
64
  """
65
  images = [image[0] for image in images]
66
  processed_images = [pipeline.preprocess_image(image) for image in images]
 
109
 
110
  def get_seed(randomize_seed: bool, seed: int) -> int:
111
  """
112
+ Get the random seed for generation.
113
+
114
+ This function is called by the generate button to determine whether to use
115
+ a random seed or the user-specified seed value.
116
+
117
+ Args:
118
+ randomize_seed (bool): Whether to generate a random seed
119
+ seed (int): The user-specified seed value
120
+
121
+ Returns:
122
+ int: The seed to use for generation
123
  """
124
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
125
 
126
 
127
+ @spaces.GPU(duration=120)
128
+ def generate_and_extract_glb(
129
  image: Image.Image,
130
  multiimages: List[Tuple[Image.Image, str]],
131
  is_multiimage: bool,
 
135
  slat_guidance_strength: float,
136
  slat_sampling_steps: int,
137
  multiimage_algo: Literal["multidiffusion", "stochastic"],
138
+ mesh_simplify: float,
139
+ texture_size: int,
140
  req: gr.Request,
141
+ ) -> Tuple[dict, str, str, str]:
142
  """
143
+ Convert an image to a 3D model and extract GLB file.
144
 
145
  Args:
146
  image (Image.Image): The input image.
 
152
  slat_guidance_strength (float): The guidance strength for structured latent generation.
153
  slat_sampling_steps (int): The number of sampling steps for structured latent generation.
154
  multiimage_algo (Literal["multidiffusion", "stochastic"]): The algorithm for multi-image generation.
155
+ mesh_simplify (float): The mesh simplification factor.
156
+ texture_size (int): The texture resolution.
157
 
158
  Returns:
159
  dict: The information of the generated 3D model.
160
  str: The path to the video of the 3D model.
161
+ str: The path to the extracted GLB file.
162
+ str: The path to the extracted GLB file (for download).
163
  """
164
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
165
+
166
+ # Generate 3D model
167
  if not is_multiimage:
168
  outputs = pipeline.run(
169
  image,
 
195
  },
196
  mode=multiimage_algo,
197
  )
198
+
199
+ # Render video
200
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
201
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
202
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
203
  video_path = os.path.join(user_dir, 'sample.mp4')
204
  imageio.mimsave(video_path, video, fps=15)
205
+
206
+ # Extract GLB
207
+ gs = outputs['gaussian'][0]
208
+ mesh = outputs['mesh'][0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
  glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
210
  glb_path = os.path.join(user_dir, 'sample.glb')
211
  glb.export(glb_path)
212
+
213
+ # Pack state for optional Gaussian extraction
214
+ state = pack_state(gs, mesh)
215
+
216
  torch.cuda.empty_cache()
217
+ return state, video_path, glb_path, glb_path
218
 
219
 
220
  @spaces.GPU
221
  def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
222
  """
223
+ Extract a Gaussian splatting file from the generated 3D model.
224
+
225
+ This function is called when the user clicks "Extract Gaussian" button.
226
+ It converts the 3D model state into a .ply file format containing
227
+ Gaussian splatting data for advanced 3D applications.
228
 
229
  Args:
230
+ state (dict): The state of the generated 3D model containing Gaussian data
231
+ req (gr.Request): Gradio request object for session management
232
 
233
  Returns:
234
+ Tuple[str, str]: Paths to the extracted Gaussian file (for display and download)
235
  """
236
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
237
  gs, _ = unpack_state(state)
 
257
 
258
  def split_image(image: Image.Image) -> List[Image.Image]:
259
  """
260
+ Split a multi-view image into separate view images.
261
+
262
+ This function is called when users select multi-image examples that contain
263
+ multiple views in a single concatenated image. It automatically splits them
264
+ based on alpha channel boundaries and preprocesses each view.
265
+
266
+ Args:
267
+ image (Image.Image): A concatenated image containing multiple views
268
+
269
+ Returns:
270
+ List[Image.Image]: List of individual preprocessed view images
271
  """
272
  image = np.array(image)
273
  alpha = image[..., 3]
 
283
  with gr.Blocks(delete_cache=(600, 600)) as demo:
284
  gr.Markdown("""
285
  ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
286
+ * Upload an image and click "Generate & Extract GLB" to create a 3D asset and automatically extract the GLB file.
287
+ * If you want the Gaussian file as well, click "Extract Gaussian" after generation.
288
+ * If the image has alpha channel, it will be used as the mask. Otherwise, we use `rembg` to remove the background.
289
 
290
  ✨New: 1) Experimental multi-image support. 2) Gaussian file extraction.
291
  """)
 
315
  slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
316
  slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
317
  multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Multi-image Algorithm", value="stochastic")
 
 
318
 
319
  with gr.Accordion(label="GLB Extraction Settings", open=False):
320
  mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
321
  texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
322
+
323
+ generate_btn = gr.Button("Generate & Extract GLB", variant="primary")
324
+ extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
 
325
  gr.Markdown("""
326
  *NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
327
  """)
 
389
  inputs=[randomize_seed, seed],
390
  outputs=[seed],
391
  ).then(
392
+ generate_and_extract_glb,
393
+ inputs=[image_prompt, multiimage_prompt, is_multiimage, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo, mesh_simplify, texture_size],
394
+ outputs=[output_buf, video_output, model_output, download_glb],
395
  ).then(
396
  lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
397
+ outputs=[extract_gs_btn, download_glb],
398
  )
399
 
400
  video_output.clear(
401
+ lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False), gr.Button(interactive=False)]),
402
+ outputs=[extract_gs_btn, download_glb, download_gs],
 
 
 
 
 
 
 
 
 
403
  )
404
 
405
  extract_gs_btn.click(
 
412
  )
413
 
414
  model_output.clear(
415
+ lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
416
+ outputs=[download_glb, download_gs],
417
  )
418
 
419
 
 
425
  pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))) # Preload rembg
426
  except:
427
  pass
428
+ demo.launch(mcp_server=True)