Spaces:
Sleeping
Sleeping
import os | |
from dotenv import load_dotenv | |
from langchain_core.prompts import PromptTemplate | |
from openai import OpenAI | |
from log_util import logger | |
from time_it import time_it | |
from util import load_prompt | |
# Load variables from a .env file (python-dotenv) before reading settings below.
load_dotenv()


def _require_positive_int_env(name: str) -> int:
    """Return environment variable ``name`` parsed as a positive integer.

    Fails fast at import time with a descriptive message instead of the opaque
    ``TypeError`` that ``int(None)`` produces when the variable is unset.

    Raises:
        RuntimeError: If the variable is unset, is not an integer, or is <= 0.
    """
    raw = os.getenv(name)
    if raw is None:
        raise RuntimeError(f'Environment variable {name} must be set.')
    try:
        value = int(raw)
    except ValueError as err:
        raise RuntimeError(
            f'Environment variable {name} must be an integer, got {raw!r}.'
        ) from err
    # 0 would truncate every prompt to '' and a negative value would slice
    # characters off the end, so reject both as misconfiguration.
    if value <= 0:
        raise RuntimeError(
            f'Environment variable {name} must be a positive integer, got {value}.'
        )
    return value


# Connection settings for the OpenAI-compatible image-generation endpoint.
IMAGE_GEN_API_BASE_URL = os.getenv('IMAGE_GEN_API_BASE_URL')
IMAGE_GEN_API_KEY = os.getenv('IMAGE_GEN_API_KEY')
IMAGE_GEN_MODEL = os.getenv('IMAGE_GEN_MODEL')
# Maximum number of characters of prompt text sent to the API (required).
IMAGE_GEN_MAX_PROMPT_LEN = _require_positive_int_env('IMAGE_GEN_MAX_PROMPT_LEN')

# Extra, provider-specific request fields; passed verbatim as ``extra_body``
# on every images.generate call.
IMAGE_GEN_OPTIONS = {
    'response_extension': 'png',
    'width': 1024,
    'height': 1024,
    'num_inference_steps': int(os.getenv('NUM_INFERENCE_STEPS', '16')),
    'negative_prompt': '',
    # NOTE(review): -1 presumably requests a random seed — confirm with provider.
    'seed': -1,
}
def generate_image(prompt_file: str, input: dict) -> str:
    """Render a prompt template and request one image from the image API.

    The template loaded via ``load_prompt(prompt_file)`` is filled with
    ``input`` through a LangChain ``PromptTemplate``. The *rendered* text is
    truncated to ``IMAGE_GEN_MAX_PROMPT_LEN`` characters and sent to the
    OpenAI-compatible images endpoint configured by the ``IMAGE_GEN_*``
    settings, with ``IMAGE_GEN_OPTIONS`` forwarded as ``extra_body``.

    Args:
        prompt_file: Prompt identifier passed to ``load_prompt``.
        input: Values for the template's placeholders. (The name shadows the
            builtin but is kept so keyword callers keep working.)

    Returns:
        The URL of the first generated image.

    Raises:
        RuntimeError: If the API response contains no image URL.
    """
    template = load_prompt(prompt_file)
    # Substitute placeholders *before* enforcing the length limit: truncating
    # the raw template can cut a ``{variable}`` in half (breaking template
    # parsing), and substituted values can push the final prompt past the limit.
    prompt = PromptTemplate.from_template(template).invoke(input).to_string()
    if len(prompt) > IMAGE_GEN_MAX_PROMPT_LEN:
        logger.info(f'Prompt length {len(prompt)} exceeds {IMAGE_GEN_MAX_PROMPT_LEN} characters, will be truncated.')
        prompt = prompt[:IMAGE_GEN_MAX_PROMPT_LEN]

    # Context manager closes the client's HTTP connection pool after the
    # request instead of leaking one pool per call.
    with OpenAI(base_url=IMAGE_GEN_API_BASE_URL, api_key=IMAGE_GEN_API_KEY) as client:
        response = client.images.generate(
            model=IMAGE_GEN_MODEL,
            prompt=prompt,
            response_format='url',
            extra_body=IMAGE_GEN_OPTIONS,
        )

    # Guard against an empty ``data`` list or a missing URL rather than
    # raising IndexError or returning None despite the ``-> str`` contract.
    image_url = response.data[0].url if response.data else None
    if not image_url:
        raise RuntimeError('Image generation API returned no image URL.')
    logger.info(f'{image_url=}')
    return image_url