"""Generate images from prompt templates via an OpenAI client pointed at IMAGE_GEN_API_BASE_URL."""
import os

from dotenv import load_dotenv
from langchain_core.prompts import PromptTemplate
from openai import OpenAI

from log_util import logger
from time_it import time_it
from util import load_prompt

load_dotenv()

IMAGE_GEN_API_BASE_URL = os.getenv('IMAGE_GEN_API_BASE_URL')
IMAGE_GEN_API_KEY = os.getenv('IMAGE_GEN_API_KEY')
IMAGE_GEN_MODEL = os.getenv('IMAGE_GEN_MODEL')
# Required setting: int(None) raises TypeError at import time if it is unset.
IMAGE_GEN_MAX_PROMPT_LEN = int(os.getenv('IMAGE_GEN_MAX_PROMPT_LEN'))

# Extra, non-standard request fields forwarded verbatim through `extra_body`;
# presumably understood by the configured backend — confirm against its API.
IMAGE_GEN_OPTIONS = {
    'response_extension': 'png',
    'width': 1024,
    'height': 1024,
    'num_inference_steps': int(os.getenv('NUM_INFERENCE_STEPS', '16')),
    'negative_prompt': '',
    'seed': -1,
}


@time_it
def generate_image(prompt_file: str, input: dict) -> str:
    """Render a prompt template with *input* and request one image for it.

    Args:
        prompt_file: Passed to ``load_prompt`` to obtain the template text
            (presumably a prompt file name/path — see ``util.load_prompt``).
        input: Values substituted into the template's ``{placeholders}``.

    Returns:
        The URL of the first image in the API response.

    Raises:
        RuntimeError: If the response contains no image URL.
    """
    template_text = load_prompt(prompt_file)
    prompt = PromptTemplate.from_template(template_text).invoke(input).to_string()

    # Enforce the limit on the *rendered* prompt: that is what the API
    # receives, and truncating the raw template could cut a placeholder in
    # half and break template parsing.
    if len(prompt) > IMAGE_GEN_MAX_PROMPT_LEN:
        logger.info(f'Prompt length {len(prompt)} exceeds {IMAGE_GEN_MAX_PROMPT_LEN} characters, will be truncated.')
        prompt = prompt[:IMAGE_GEN_MAX_PROMPT_LEN]

    # Context manager closes the client's HTTP connection pool after the call.
    with OpenAI(base_url=IMAGE_GEN_API_BASE_URL, api_key=IMAGE_GEN_API_KEY) as client:
        response = client.images.generate(
            model=IMAGE_GEN_MODEL,
            prompt=prompt,
            response_format='url',
            extra_body=IMAGE_GEN_OPTIONS,
        )

    # Fail loudly rather than returning None (or an opaque IndexError) when
    # the backend sends no image URL.
    if not response.data or response.data[0].url is None:
        raise RuntimeError('Image generation response did not contain an image URL.')

    image_url = response.data[0].url
    logger.info(f'{image_url=}')
    return image_url