xinjie.wang committed
Commit f219113 · 1 Parent(s): dd1f1fd
Files changed (2)
  1. app.py +170 -488
  2. embodied_gen/utils/gpt_clients.py +1 -0
app.py CHANGED
@@ -1,501 +1,183 @@
1
- # Project EmbodiedGen
2
- #
3
- # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
- # implied. See the License for the specific language governing
15
- # permissions and limitations under the License.
16
-
17
-
18
- import os
19
-
20
- os.environ["GRADIO_APP"] = "imageto3d"
21
- from glob import glob
22
-
23
  import gradio as gr
24
- from common import (
25
- MAX_SEED,
26
- VERSION,
27
- active_btn_by_content,
28
- custom_theme,
29
- end_session,
30
- extract_3d_representations_v2,
31
- extract_urdf,
32
- get_seed,
33
- image_css,
34
- image_to_3d,
35
- lighting_css,
36
- preprocess_image_fn,
37
- preprocess_sam_image_fn,
38
- select_point,
39
- start_session,
40
  )
41
 
42
- with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
43
- gr.Markdown(
44
- """
45
- ## ***EmbodiedGen***: Image-to-3D Asset
46
- **🔖 Version**: {VERSION}
47
- <p style="display: flex; gap: 10px; flex-wrap: nowrap;">
48
- <a href="https://horizonrobotics.github.io/robot_lab/embodied_gen/index.html">
49
- <img alt="🌐 Project Page" src="https://img.shields.io/badge/🌐-Project_Page-blue">
50
- </a>
51
- <a href="https://arxiv.org/abs/xxxx.xxxxx">
52
- <img alt="📄 arXiv" src="https://img.shields.io/badge/📄-arXiv-b31b1b">
53
- </a>
54
- <a href="https://github.com/HorizonRobotics/EmbodiedGen">
55
- <img alt="💻 GitHub" src="https://img.shields.io/badge/GitHub-000000?logo=github">
56
- </a>
57
- <a href="https://www.youtube.com/watch?v=SnHhzHeb_aI">
58
- <img alt="🎥 Video" src="https://img.shields.io/badge/🎥-Video-red">
59
- </a>
60
- </p>
61
-
62
- 🖼️ Generate physically plausible 3D asset from single input image.
63
-
64
- """.format(
65
- VERSION=VERSION
66
- ),
67
- elem_classes=["header"],
68
- )
69
-
70
- gr.HTML(image_css)
71
- # gr.HTML(lighting_css)
72
- with gr.Row():
73
- with gr.Column(scale=2):
74
- with gr.Tabs() as input_tabs:
75
- with gr.Tab(
76
- label="Image(auto seg)", id=0
77
- ) as single_image_input_tab:
78
- raw_image_cache = gr.Image(
79
- format="png",
80
- image_mode="RGB",
81
- type="pil",
82
- visible=False,
83
- )
84
- image_prompt = gr.Image(
85
- label="Input Image",
86
- format="png",
87
- image_mode="RGBA",
88
- type="pil",
89
- height=400,
90
- elem_classes=["image_fit"],
91
- )
92
- gr.Markdown(
93
- """
94
- If you are not satisfied with the auto segmentation
95
- result, please switch to the `Image(SAM seg)` tab."""
96
- )
97
- with gr.Tab(
98
- label="Image(SAM seg)", id=1
99
- ) as samimage_input_tab:
100
- with gr.Row():
101
- with gr.Column(scale=1):
102
- image_prompt_sam = gr.Image(
103
- label="Input Image",
104
- type="numpy",
105
- height=400,
106
- elem_classes=["image_fit"],
107
- )
108
- image_seg_sam = gr.Image(
109
- label="SAM Seg Image",
110
- image_mode="RGBA",
111
- type="pil",
112
- height=400,
113
- visible=False,
114
- )
115
- with gr.Column(scale=1):
116
- image_mask_sam = gr.AnnotatedImage(
117
- elem_classes=["image_fit"]
118
- )
119
-
120
- fg_bg_radio = gr.Radio(
121
- ["foreground_point", "background_point"],
122
- label="Select foreground(green) or background(red) points, by default foreground", # noqa
123
- value="foreground_point",
124
- )
125
- gr.Markdown(
126
- """ Click the `Input Image` to select SAM points,
127
- after get the satisified segmentation, click `Generate`
128
- button to generate the 3D asset. \n
129
- Note: If the segmented foreground is too small relative
130
- to the entire image area, the generation will fail.
131
- """
132
- )
133
-
134
- with gr.Accordion(label="Generation Settings", open=False):
135
- with gr.Row():
136
- seed = gr.Slider(
137
- 0, MAX_SEED, label="Seed", value=0, step=1
138
- )
139
- texture_size = gr.Slider(
140
- 1024,
141
- 4096,
142
- label="UV texture size",
143
- value=2048,
144
- step=256,
145
- )
146
- rmbg_tag = gr.Radio(
147
- choices=["rembg", "rmbg14"],
148
- value="rembg",
149
- label="Background Removal Model",
150
- )
151
- with gr.Row():
152
- randomize_seed = gr.Checkbox(
153
- label="Randomize Seed", value=False
154
- )
155
- project_delight = gr.Checkbox(
156
- label="Backproject delighting",
157
- value=False,
158
- )
159
- gr.Markdown("Geo Structure Generation")
160
- with gr.Row():
161
- ss_guidance_strength = gr.Slider(
162
- 0.0,
163
- 10.0,
164
- label="Guidance Strength",
165
- value=7.5,
166
- step=0.1,
167
- )
168
- ss_sampling_steps = gr.Slider(
169
- 1, 50, label="Sampling Steps", value=12, step=1
170
- )
171
- gr.Markdown("Visual Appearance Generation")
172
- with gr.Row():
173
- slat_guidance_strength = gr.Slider(
174
- 0.0,
175
- 10.0,
176
- label="Guidance Strength",
177
- value=3.0,
178
- step=0.1,
179
- )
180
- slat_sampling_steps = gr.Slider(
181
- 1, 50, label="Sampling Steps", value=12, step=1
182
- )
183
-
184
- generate_btn = gr.Button(
185
- "🚀 1. Generate(~0.5 mins)",
186
- variant="primary",
187
- interactive=False,
188
  )
189
- model_output_obj = gr.Textbox(label="raw mesh .obj", visible=False)
190
- with gr.Row():
191
- extract_rep3d_btn = gr.Button(
192
- "🔍 2. Extract 3D Representation(~2 mins)",
193
- variant="primary",
194
- interactive=False,
195
- )
196
- with gr.Accordion(
197
- label="Enter Asset Attributes(optional)", open=False
198
- ):
199
- asset_cat_text = gr.Textbox(
200
- label="Enter Asset Category (e.g., chair)"
201
- )
202
- height_range_text = gr.Textbox(
203
- label="Enter **Height Range** in meter (e.g., 0.5-0.6)"
204
- )
205
- mass_range_text = gr.Textbox(
206
- label="Enter **Mass Range** in kg (e.g., 1.1-1.2)"
207
- )
208
- asset_version_text = gr.Textbox(
209
- label=f"Enter version (e.g., {VERSION})"
210
- )
211
- with gr.Row():
212
- extract_urdf_btn = gr.Button(
213
- "🧩 3. Extract URDF with physics(~1 mins)",
214
- variant="primary",
215
- interactive=False,
216
- )
217
- with gr.Row():
218
- gr.Markdown(
219
- "#### Estimated Asset 3D Attributes(No input required)"
220
- )
221
- with gr.Row():
222
- est_type_text = gr.Textbox(
223
- label="Asset category", interactive=False
224
- )
225
- est_height_text = gr.Textbox(
226
- label="Real height(.m)", interactive=False
227
- )
228
- est_mass_text = gr.Textbox(
229
- label="Mass(.kg)", interactive=False
230
- )
231
- est_mu_text = gr.Textbox(
232
- label="Friction coefficient", interactive=False
233
- )
234
- with gr.Row():
235
- download_urdf = gr.DownloadButton(
236
- label="⬇️ 4. Download URDF",
237
- variant="primary",
238
- interactive=False,
239
- )
240
-
241
- gr.Markdown(
242
- """ NOTE: If `Asset Attributes` are provided, the provided
243
- properties will be used; otherwise, the GPT-preset properties
244
- will be applied. \n
245
- The `Download URDF` file is restored to the real scale and
246
- has quality inspection, open with an editor to view details.
247
- """
248
  )
249
 
250
- with gr.Row() as single_image_example:
251
- examples = gr.Examples(
252
- label="Image Gallery",
253
- examples=[
254
- [image_path]
255
- for image_path in sorted(
256
- glob("assets/example_image/*")
257
- )
258
- ],
259
- inputs=[image_prompt, rmbg_tag],
260
- fn=preprocess_image_fn,
261
- outputs=[image_prompt, raw_image_cache],
262
- run_on_click=True,
263
- examples_per_page=10,
264
- )
265
 
266
- with gr.Row(visible=False) as single_sam_image_example:
267
- examples = gr.Examples(
268
- label="Image Gallery",
269
- examples=[
270
- [image_path]
271
- for image_path in sorted(
272
- glob("assets/example_image/*")
273
- )
274
- ],
275
- inputs=[image_prompt_sam],
276
- fn=preprocess_sam_image_fn,
277
- outputs=[image_prompt_sam, raw_image_cache],
278
- run_on_click=True,
279
- examples_per_page=10,
280
- )
281
- with gr.Column(scale=1):
282
- video_output = gr.Video(
283
- label="Generated 3D Asset",
284
- autoplay=True,
285
- loop=True,
286
- height=300,
287
- )
288
- model_output_gs = gr.Model3D(
289
- label="Gaussian Representation", height=300, interactive=False
290
- )
291
- aligned_gs = gr.Textbox(visible=False)
292
- gr.Markdown(
293
- """ The rendering of `Gaussian Representation` takes additional 10s. """ # noqa
294
  )
295
- with gr.Row():
296
- model_output_mesh = gr.Model3D(
297
- label="Mesh Representation",
298
- height=300,
299
- interactive=False,
300
- clear_color=[0.8, 0.8, 0.8, 1],
301
- elem_id="lighter_mesh",
302
  )
303
 
304
- is_samimage = gr.State(False)
305
- output_buf = gr.State()
306
- selected_points = gr.State(value=[])
307
-
308
- demo.load(start_session)
309
- demo.unload(end_session)
310
-
311
- single_image_input_tab.select(
312
- lambda: tuple(
313
- [False, gr.Row.update(visible=True), gr.Row.update(visible=False)]
314
- ),
315
- outputs=[is_samimage, single_image_example, single_sam_image_example],
316
- )
317
- samimage_input_tab.select(
318
- lambda: tuple(
319
- [True, gr.Row.update(visible=True), gr.Row.update(visible=False)]
320
- ),
321
- outputs=[is_samimage, single_sam_image_example, single_image_example],
322
- )
323
-
324
- image_prompt.upload(
325
- preprocess_image_fn,
326
- inputs=[image_prompt, rmbg_tag],
327
- outputs=[image_prompt, raw_image_cache],
328
- )
329
- image_prompt.change(
330
- lambda: tuple(
331
- [
332
- gr.Button(interactive=False),
333
- gr.Button(interactive=False),
334
- gr.Button(interactive=False),
335
- None,
336
- "",
337
- None,
338
- None,
339
- "",
340
- "",
341
- "",
342
- "",
343
- "",
344
- "",
345
- "",
346
- "",
347
- ]
348
- ),
349
- outputs=[
350
- extract_rep3d_btn,
351
- extract_urdf_btn,
352
- download_urdf,
353
- model_output_gs,
354
- aligned_gs,
355
- model_output_mesh,
356
- video_output,
357
- asset_cat_text,
358
- height_range_text,
359
- mass_range_text,
360
- asset_version_text,
361
- est_type_text,
362
- est_height_text,
363
- est_mass_text,
364
- est_mu_text,
365
- ],
366
- )
367
- image_prompt.change(
368
- active_btn_by_content,
369
- inputs=image_prompt,
370
- outputs=generate_btn,
371
- )
372
-
373
- image_prompt_sam.upload(
374
- preprocess_sam_image_fn,
375
- inputs=[image_prompt_sam],
376
- outputs=[image_prompt_sam, raw_image_cache],
377
- )
378
- image_prompt_sam.change(
379
- lambda: tuple(
380
- [
381
- gr.Button(interactive=False),
382
- gr.Button(interactive=False),
383
- gr.Button(interactive=False),
384
- None,
385
- None,
386
- None,
387
- "",
388
- "",
389
- "",
390
- "",
391
- "",
392
- "",
393
- "",
394
- "",
395
- None,
396
- [],
397
- ]
398
- ),
399
- outputs=[
400
- extract_rep3d_btn,
401
- extract_urdf_btn,
402
- download_urdf,
403
- model_output_gs,
404
- model_output_mesh,
405
- video_output,
406
- asset_cat_text,
407
- height_range_text,
408
- mass_range_text,
409
- asset_version_text,
410
- est_type_text,
411
- est_height_text,
412
- est_mass_text,
413
- est_mu_text,
414
- image_mask_sam,
415
- selected_points,
416
- ],
417
- )
418
-
419
- image_prompt_sam.select(
420
- select_point,
421
- [
422
- image_prompt_sam,
423
- selected_points,
424
- fg_bg_radio,
425
- ],
426
- [image_mask_sam, image_seg_sam],
427
- )
428
- image_seg_sam.change(
429
- active_btn_by_content,
430
- inputs=image_seg_sam,
431
- outputs=generate_btn,
432
- )
433
-
434
- generate_btn.click(
435
- get_seed,
436
- inputs=[randomize_seed, seed],
437
- outputs=[seed],
438
- ).success(
439
- image_to_3d,
440
- inputs=[
441
- image_prompt,
442
- seed,
443
- ss_guidance_strength,
444
- ss_sampling_steps,
445
- slat_guidance_strength,
446
- slat_sampling_steps,
447
- raw_image_cache,
448
- image_seg_sam,
449
- is_samimage,
450
- ],
451
- outputs=[output_buf, video_output],
452
- ).success(
453
- lambda: gr.Button(interactive=True),
454
- outputs=[extract_rep3d_btn],
455
- )
456
-
457
- extract_rep3d_btn.click(
458
- extract_3d_representations_v2,
459
- inputs=[
460
- output_buf,
461
- project_delight,
462
- texture_size,
463
- ],
464
- outputs=[
465
- model_output_mesh,
466
- model_output_gs,
467
- model_output_obj,
468
- aligned_gs,
469
- ],
470
- ).success(
471
- lambda: gr.Button(interactive=True),
472
- outputs=[extract_urdf_btn],
473
- )
474
-
475
- extract_urdf_btn.click(
476
- extract_urdf,
477
- inputs=[
478
- aligned_gs,
479
- model_output_obj,
480
- asset_cat_text,
481
- height_range_text,
482
- mass_range_text,
483
- asset_version_text,
484
- ],
485
- outputs=[
486
- download_urdf,
487
- est_type_text,
488
- est_height_text,
489
- est_mass_text,
490
- est_mu_text,
491
- ],
492
- queue=True,
493
- show_progress="full",
494
- ).success(
495
- lambda: gr.Button(interactive=True),
496
- outputs=[download_urdf],
497
- )
498
 
499
 
500
  if __name__ == "__main__":
501
- demo.launch()
1
  import gradio as gr
2
+ import os
3
+ import yaml
4
+ import base64
5
+ import logging
6
+ import os
7
+ from io import BytesIO
8
+ from typing import Optional
9
+
10
+ import yaml
11
+ from openai import AzureOpenAI, OpenAI # pip install openai
12
+ from PIL import Image
13
+ from tenacity import (
14
+ retry,
15
+ stop_after_attempt,
16
+ stop_after_delay,
17
+ wait_random_exponential,
18
  )
19
 
20
+ logging.basicConfig(level=logging.INFO)
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
+ class GPTclient:
25
+ """A client to interact with the GPT model via OpenAI or Azure API."""
26
+
27
+ def __init__(
28
+ self,
29
+ endpoint: str,
30
+ api_key: str,
31
+ model_name: str = "yfb-gpt-4o",
32
+ api_version: str = None,
33
+ verbose: bool = False,
34
+ ):
35
+ if api_version is not None:
36
+ self.client = AzureOpenAI(
37
+ azure_endpoint=endpoint,
38
+ api_key=api_key,
39
+ api_version=api_version,
40
  )
41
+ else:
42
+ self.client = OpenAI(
43
+ base_url=endpoint,
44
+ api_key=api_key,
45
  )
46
 
47
+ self.endpoint = endpoint
48
+ self.model_name = model_name
49
+ self.image_formats = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif"}
50
+ self.verbose = verbose
51
+ logger.info(f"Using GPT model: {self.model_name}.")
52
 
53
+ @retry(
54
+ wait=wait_random_exponential(min=1, max=20),
55
+ stop=(stop_after_attempt(10) | stop_after_delay(30)),
56
+ )
57
+ def completion_with_backoff(self, **kwargs):
58
+ return self.client.chat.completions.create(**kwargs)
59
+
60
+ def query(
61
+ self,
62
+ text_prompt: str,
63
+ image_base64: Optional[list[str | Image.Image]] = None,
64
+ system_role: Optional[str] = None,
65
+ ) -> Optional[str]:
66
+ """Queries the GPT model with a text and optional image prompts.
67
+
68
+ Args:
69
+ text_prompt (str): The main text input that the model responds to.
70
+ image_base64 (Optional[List[str]]): A list of image base64 strings
71
+ or local image paths or PIL.Image to accompany the text prompt.
72
+ system_role (Optional[str]): Optional system-level instructions
73
+ that specify the behavior of the assistant.
74
+
75
+ Returns:
76
+ Optional[str]: The response content generated by the model based on
77
+ the prompt. Returns `None` if an error occurs.
78
+ """
79
+ if system_role is None:
80
+ system_role = "You are a highly knowledgeable assistant specializing in physics, engineering, and object properties." # noqa
81
+
82
+ content_user = [
83
+ {
84
+ "type": "text",
85
+ "text": text_prompt,
86
+ },
87
+ ]
88
+
89
+ # Process images if provided
90
+ if image_base64 is not None:
91
+ image_base64 = (
92
+ image_base64
93
+ if isinstance(image_base64, list)
94
+ else [image_base64]
95
  )
96
+ for img in image_base64:
97
+ if isinstance(img, Image.Image):
98
+ buffer = BytesIO()
99
+ img.save(buffer, format=img.format or "PNG")
100
+ buffer.seek(0)
101
+ image_binary = buffer.read()
102
+ img = base64.b64encode(image_binary).decode("utf-8")
103
+ elif (
104
+ len(os.path.splitext(img)) > 1
105
+ and os.path.splitext(img)[-1].lower() in self.image_formats
106
+ ):
107
+ if not os.path.exists(img):
108
+ raise FileNotFoundError(f"Image file not found: {img}")
109
+ with open(img, "rb") as f:
110
+ img = base64.b64encode(f.read()).decode("utf-8")
111
+
112
+ content_user.append(
113
+ {
114
+ "type": "image_url",
115
+ "image_url": {"url": f"data:image/png;base64,{img}"},
116
+ }
117
  )
118
 
119
+ payload = {
120
+ "messages": [
121
+ {"role": "system", "content": system_role},
122
+ {"role": "user", "content": content_user},
123
+ ],
124
+ "temperature": 0.1,
125
+ "max_tokens": 500,
126
+ "top_p": 0.1,
127
+ "frequency_penalty": 0,
128
+ "presence_penalty": 0,
129
+ "stop": None,
130
+ }
131
+ payload.update({"model": self.model_name})
132
+
133
+ response = None
134
+ try:
135
+ response = self.completion_with_backoff(**payload)
136
+ response = response.choices[0].message.content
137
+ except Exception as e:
138
+ logger.error(f"Error GPTclint {self.endpoint} API call: {e}")
139
+ response = None
140
+
141
+ if self.verbose:
142
+ logger.info(f"Prompt: {text_prompt}")
143
+ logger.info(f"Response: {response}")
144
+
145
+ return response
146
+
147
+ from embodied_gen.utils.gpt_clients import GPT_CLIENT
148
+
149
+ print(GPT_CLIENT.api_version, GPT_CLIENT.model_name, GPT_CLIENT.endpoint)
150
+
151
+ def debug_gptclient(text_prompt, images, system_role):
152
+ try:
153
+ # Handle image input (Gradio passes images as PIL.Image or file paths)
154
+ image_base64 = images if images else None
155
+ response = GPT_CLIENT.query(
156
+ text_prompt=text_prompt,
157
+ image_base64=image_base64,
158
+ system_role=system_role
159
+ )
160
+ return response if response else "No response received or an error occurred."
161
+ except Exception as e:
162
+ return f"Error: {str(e)}"
163
+
164
+ # Create Gradio interface
165
+ iface = gr.Interface(
166
+ fn=debug_gptclient,
167
+ inputs=[
168
+ gr.Textbox(label="Text Prompt", placeholder="Enter your text prompt here"),
169
+ gr.File(label="Images (Optional)", type="filepath", file_count="multiple"),
170
+ gr.Textbox(
171
+ label="System Role (Optional)",
172
+ placeholder="Enter system role or leave empty for default",
173
+ value="You are a highly knowledgeable assistant specializing in physics, engineering, and object properties."
174
+ )
175
+ ],
176
+ outputs=gr.Textbox(label="Response"),
177
+ title="GPTclient Debug Interface",
178
+ description="A simple interface to debug GPTclient inputs and outputs."
179
+ )
180
 
181
 
182
  if __name__ == "__main__":
183
+ iface.launch()
embodied_gen/utils/gpt_clients.py CHANGED
@@ -61,6 +61,7 @@ class GPTclient:
61
 
62
  self.endpoint = endpoint
63
  self.model_name = model_name
64
+ self.api_version = api_version
65
  self.image_formats = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif"}
66
  self.verbose = verbose
67
  logger.info(f"Using GPT model: {self.model_name}.")