xinjie.wang committed
Commit d31a703 · 1 Parent(s): f219113
Files changed (2)
  1. app.py +2 -141
  2. embodied_gen/validators/urdf_convertor.py +3 -0
app.py CHANGED
@@ -1,149 +1,10 @@
-import gradio as gr
-import os
-import yaml
-import base64
-import logging
-import os
-from io import BytesIO
-from typing import Optional
-
-import yaml
-from openai import AzureOpenAI, OpenAI  # pip install openai
-from PIL import Image
-from tenacity import (
-    retry,
-    stop_after_attempt,
-    stop_after_delay,
-    wait_random_exponential,
-)
+import gradio as gr
+import logging
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
 
-class GPTclient:
-    """A client to interact with the GPT model via OpenAI or Azure API."""
-
-    def __init__(
-        self,
-        endpoint: str,
-        api_key: str,
-        model_name: str = "yfb-gpt-4o",
-        api_version: str = None,
-        verbose: bool = False,
-    ):
-        if api_version is not None:
-            self.client = AzureOpenAI(
-                azure_endpoint=endpoint,
-                api_key=api_key,
-                api_version=api_version,
-            )
-        else:
-            self.client = OpenAI(
-                base_url=endpoint,
-                api_key=api_key,
-            )
-
-        self.endpoint = endpoint
-        self.model_name = model_name
-        self.image_formats = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif"}
-        self.verbose = verbose
-        logger.info(f"Using GPT model: {self.model_name}.")
-
-    @retry(
-        wait=wait_random_exponential(min=1, max=20),
-        stop=(stop_after_attempt(10) | stop_after_delay(30)),
-    )
-    def completion_with_backoff(self, **kwargs):
-        return self.client.chat.completions.create(**kwargs)
-
-    def query(
-        self,
-        text_prompt: str,
-        image_base64: Optional[list[str | Image.Image]] = None,
-        system_role: Optional[str] = None,
-    ) -> Optional[str]:
-        """Queries the GPT model with a text and optional image prompts.
-
-        Args:
-            text_prompt (str): The main text input that the model responds to.
-            image_base64 (Optional[List[str]]): A list of image base64 strings
-                or local image paths or PIL.Image to accompany the text prompt.
-            system_role (Optional[str]): Optional system-level instructions
-                that specify the behavior of the assistant.
-
-        Returns:
-            Optional[str]: The response content generated by the model based on
-                the prompt. Returns `None` if an error occurs.
-        """
-        if system_role is None:
-            system_role = "You are a highly knowledgeable assistant specializing in physics, engineering, and object properties."  # noqa
-
-        content_user = [
-            {
-                "type": "text",
-                "text": text_prompt,
-            },
-        ]
-
-        # Process images if provided
-        if image_base64 is not None:
-            image_base64 = (
-                image_base64
-                if isinstance(image_base64, list)
-                else [image_base64]
-            )
-            for img in image_base64:
-                if isinstance(img, Image.Image):
-                    buffer = BytesIO()
-                    img.save(buffer, format=img.format or "PNG")
-                    buffer.seek(0)
-                    image_binary = buffer.read()
-                    img = base64.b64encode(image_binary).decode("utf-8")
-                elif (
-                    len(os.path.splitext(img)) > 1
-                    and os.path.splitext(img)[-1].lower() in self.image_formats
-                ):
-                    if not os.path.exists(img):
-                        raise FileNotFoundError(f"Image file not found: {img}")
-                    with open(img, "rb") as f:
-                        img = base64.b64encode(f.read()).decode("utf-8")
-
-                content_user.append(
-                    {
-                        "type": "image_url",
-                        "image_url": {"url": f"data:image/png;base64,{img}"},
-                    }
-                )
-
-        payload = {
-            "messages": [
-                {"role": "system", "content": system_role},
-                {"role": "user", "content": content_user},
-            ],
-            "temperature": 0.1,
-            "max_tokens": 500,
-            "top_p": 0.1,
-            "frequency_penalty": 0,
-            "presence_penalty": 0,
-            "stop": None,
-        }
-        payload.update({"model": self.model_name})
-
-        response = None
-        try:
-            response = self.completion_with_backoff(**payload)
-            response = response.choices[0].message.content
-        except Exception as e:
-            logger.error(f"Error GPTclint {self.endpoint} API call: {e}")
-            response = None
-
-        if self.verbose:
-            logger.info(f"Prompt: {text_prompt}")
-            logger.info(f"Response: {response}")
-
-        return response
-
 from embodied_gen.utils.gpt_clients import GPT_CLIENT
 
 print(GPT_CLIENT.api_version, GPT_CLIENT.model_name, GPT_CLIENT.endpoint)
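In short, this hunk drops the inline GPTclient from app.py and relies on the shared client in embodied_gen.utils.gpt_clients instead. A minimal usage sketch, assuming GPT_CLIENT exposes the same query(text_prompt, image_base64=None, system_role=None) interface as the removed class; the prompt text and image path below are hypothetical:

import logging

from embodied_gen.utils.gpt_clients import GPT_CLIENT

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Query the shared client the same way the deleted GPTclient was queried:
# a text prompt plus optional base64 strings, local image paths, or PIL images.
response = GPT_CLIENT.query(
    text_prompt="Estimate the real-world height of this object in meters.",
    image_base64=["asset_views.png"],  # hypothetical local image path
)
if response is None:
    # The removed GPTclient returned None on API errors; we assume the
    # shared client keeps that contract.
    logger.warning("GPT query failed.")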
embodied_gen/validators/urdf_convertor.py CHANGED
@@ -366,6 +366,9 @@ class URDFGenerator(object):
         image_path = combine_images_to_base64(image_path)
 
         response = self.gpt_client.query(text_prompt, image_path)
+        print("text_prompt: ", text_prompt)
+        print("image_path: ", image_path)
+        print("response: ", response)
         if response is None:
             asset_attrs = {
                 "category": category.lower(),