# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import base64
import logging
import os
from io import BytesIO
from typing import Optional

import yaml
from openai import AzureOpenAI, OpenAI  # pip install openai
from PIL import Image
from tenacity import (
    retry,
    stop_after_attempt,
    stop_after_delay,
    wait_random_exponential,
)

from embodied_gen.utils.process_media import combine_images_to_base64

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class GPTclient:
    """A client to interact with the GPT model via OpenAI or Azure API."""

    def __init__(
        self,
        endpoint: str,
        api_key: str,
        model_name: str = "yfb-gpt-4o",
        api_version: Optional[str] = None,
        verbose: bool = False,
    ):
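        # An explicit api_version indicates an Azure OpenAI deployment;
        # otherwise the endpoint is treated as an OpenAI-compatible base URL.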
        if api_version is not None:
            self.client = AzureOpenAI(
                azure_endpoint=endpoint,
                api_key=api_key,
                api_version=api_version,
            )
        else:
            self.client = OpenAI(
                base_url=endpoint,
                api_key=api_key,
            )

        self.endpoint = endpoint
        self.model_name = model_name
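        # File extensions recognized as local image paths in `query`.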
        self.image_formats = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif"}
        self.verbose = verbose
        logger.info(f"Using GPT model: {self.model_name}.")

    # Retry transient API failures with random exponential backoff (tenacity).
    # The attempt/delay bounds here are illustrative defaults, not values
    # mandated by the project.
    @retry(
        wait=wait_random_exponential(min=1, max=20),
        stop=stop_after_attempt(6) | stop_after_delay(60),
    )
    def completion_with_backoff(self, **kwargs):
        return self.client.chat.completions.create(**kwargs)

    def query(
        self,
        text_prompt: str,
        image_base64: Optional[list[str | Image.Image]] = None,
        system_role: Optional[str] = None,
    ) -> Optional[str]:
"""Queries the GPT model with a text and optional image prompts. | |
Args: | |
text_prompt (str): The main text input that the model responds to. | |
image_base64 (Optional[List[str]]): A list of image base64 strings | |
or local image paths or PIL.Image to accompany the text prompt. | |
system_role (Optional[str]): Optional system-level instructions | |
that specify the behavior of the assistant. | |
Returns: | |
Optional[str]: The response content generated by the model based on | |
the prompt. Returns `None` if an error occurs. | |
""" | |
        if system_role is None:
            system_role = "You are a highly knowledgeable assistant specializing in physics, engineering, and object properties."  # noqa

        content_user = [
            {
                "type": "text",
                "text": text_prompt,
            },
        ]

        # Process images if provided.
        if image_base64 is not None:
            image_base64 = (
                image_base64
                if isinstance(image_base64, list)
                else [image_base64]
            )
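            # Each entry may be a PIL image, a local image path, or an
            # already base64-encoded string; normalize everything to base64.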
            for img in image_base64:
                if isinstance(img, Image.Image):
                    buffer = BytesIO()
                    img.save(buffer, format=img.format or "PNG")
                    buffer.seek(0)
                    image_binary = buffer.read()
                    img = base64.b64encode(image_binary).decode("utf-8")
                elif os.path.splitext(img)[-1].lower() in self.image_formats:
                    if not os.path.exists(img):
                        raise FileNotFoundError(f"Image file not found: {img}")
                    with open(img, "rb") as f:
                        img = base64.b64encode(f.read()).decode("utf-8")
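                # Strings without a known image extension are assumed to be
                # base64-encoded already.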
                content_user.append(
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{img}"},
                    }
                )
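
        # Conservative sampling settings keep answers short and reproducible.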
        payload = {
            "model": self.model_name,
            "messages": [
                {"role": "system", "content": system_role},
                {"role": "user", "content": content_user},
            ],
            "temperature": 0.1,
            "max_tokens": 500,
            "top_p": 0.1,
            "frequency_penalty": 0,
            "presence_penalty": 0,
            "stop": None,
        }

        response = None
        try:
            response = self.completion_with_backoff(**payload)
            response = response.choices[0].message.content
        except Exception as e:
            logger.error(f"GPTclient API call to {self.endpoint} failed: {e}")
            response = None

        if self.verbose:
            logger.info(f"Prompt: {text_prompt}")
            logger.info(f"Response: {response}")

        return response
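

# Example (a minimal sketch, not shipped configuration): instantiating the
# client directly against any OpenAI-compatible endpoint. The URL, env var,
# and model name below are placeholders chosen for illustration only.
#
#   client = GPTclient(
#       endpoint="https://api.openai.com/v1",
#       api_key=os.environ["OPENAI_API_KEY"],
#       model_name="gpt-4o",
#   )
#   print(client.query(text_prompt="Describe a coffee mug in one sentence."))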

with open("embodied_gen/utils/gpt_config.yaml", "r") as f:
    config = yaml.safe_load(f)

agent_type = config["agent_type"]
agent_config = config.get(agent_type, {})
# Prefer environment variables, falling back to the YAML config.
endpoint = os.environ.get("ENDPOINT", agent_config.get("endpoint"))
api_key = os.environ.get("API_KEY", agent_config.get("api_key"))
api_version = os.environ.get("API_VERSION", agent_config.get("api_version"))
model_name = os.environ.get("MODEL_NAME", agent_config.get("model_name"))

GPT_CLIENT = GPTclient(
    endpoint=endpoint,
    api_key=api_key,
    api_version=api_version,
    model_name=model_name,
)


if __name__ == "__main__":
    # Test 1: text prompt with images.
    if "openrouter" in GPT_CLIENT.endpoint:
        response = GPT_CLIENT.query(
            text_prompt="What is the content in each image?",
            image_base64=combine_images_to_base64(
                [
                    "apps/assets/example_image/sample_02.jpg",
                    "apps/assets/example_image/sample_03.jpg",
                ]
            ),  # Pass a single raw image path if there is only one image.
        )
        print(response)
    else:
        response = GPT_CLIENT.query(
            text_prompt="What is the content in the images?",
            image_base64=[
                Image.open("apps/assets/example_image/sample_02.jpg"),
                Image.open("apps/assets/example_image/sample_03.jpg"),
            ],
        )
        print(response)

    # Test 2: text-only prompt.
    response = GPT_CLIENT.query(text_prompt="What is the capital of China?")
    print(response)