import os
from typing import List

import google.generativeai as genai
from diffusers.modular_pipelines import (
    InputParam,
    ModularPipelineBlocks,
    OutputParam,
    PipelineState,
)


SYSTEM_PROMPT = (
    "You are an expert image generation assistant. "
    "Take the user's short description and expand it into a vivid, detailed, and clear image generation prompt. "
    "Ensure rich colors, depth, realistic lighting, and an imaginative composition. "
    "Avoid vague terms — be specific about style, perspective, and mood. "
    "Try to keep the output under 512 tokens. "
    "Please don't return any prefix or suffix tokens, just the expanded user description."
)
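
# Quick standalone check of the raw Gemini call the block below wraps; a
# minimal sketch, assuming `GOOGLE_API_KEY` is set in the environment:
#
#   genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
#   model = genai.GenerativeModel("gemini-2.5-flash-lite", system_instruction=SYSTEM_PROMPT)
#   print(model.generate_content("a cat napping on a windowsill").text)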
|
|
|
class GeminiPromptExpander(ModularPipelineBlocks):
    # Tags this block as part of the Flux family of modular blocks.
    model_name = "flux"

    def __init__(self, model_id: str = "gemini-2.5-flash-lite", system_prompt: str = SYSTEM_PROMPT):
        super().__init__()
        api_key = os.getenv("GOOGLE_API_KEY")
        if api_key is None:
            raise ValueError("Must provide an API key for Gemini through the `GOOGLE_API_KEY` env variable.")
        genai.configure(api_key=api_key)
        # The system instruction steers Gemini to act purely as a prompt expander,
        # so each request only needs to carry the user's short description.
        self.model = genai.GenerativeModel(model_name=model_id, system_instruction=system_prompt)
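
    # Optional addition, not part of the original block: modular blocks can
    # expose a `description` property that diffusers surfaces when the
    # assembled pipeline is inspected. The wording here is an assumption.
    @property
    def description(self) -> str:
        return "Expands a short user prompt into a detailed image prompt via the Gemini API."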
|
|
|
    @property
    def expected_components(self):
        # This block calls an external API and loads no model components itself.
        return []

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "prompt",
                type_hint=str,
                required=True,
                description="Short user prompt to expand",
            )
        ]

    @property
    def intermediate_inputs(self) -> List[InputParam]:
        return []
|
|
|
    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "prompt",
                type_hint=str,
                description="Prompt expanded by the LLM",
            ),
            OutputParam(
                "old_prompt",
                type_hint=str,
                description="Original prompt provided by the user",
            ),
        ]
|
|
|
|
|
    def __call__(self, components, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        # Expand the prompt with Gemini, keeping the original around under `old_prompt`
        # so downstream blocks (or the caller) can still see what the user typed.
        old_prompt = block_state.prompt
        print(f"Original prompt: {old_prompt}")
        block_state.prompt = self.model.generate_content(old_prompt).text
        block_state.old_prompt = old_prompt
        print(f"{block_state.prompt=}")
        self.set_block_state(state, block_state)

        return components, state
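
# A minimal end-to-end sketch, not a verified recipe: it assumes your diffusers
# version ships the Flux modular preset `TEXT2IMAGE_BLOCKS` along with the
# `init_pipeline` / `load_default_components` helpers from the modular
# pipelines docs, and that `GOOGLE_API_KEY` is set in the environment.
if __name__ == "__main__":
    import torch

    from diffusers.modular_pipelines import SequentialPipelineBlocks
    from diffusers.modular_pipelines.flux import TEXT2IMAGE_BLOCKS  # assumed preset

    # Run the expander first so downstream text-encoding blocks see the enriched prompt.
    blocks = TEXT2IMAGE_BLOCKS.copy()
    blocks.insert("prompt_expander", GeminiPromptExpander, 0)

    pipeline = SequentialPipelineBlocks.from_blocks_dict(blocks).init_pipeline(
        "black-forest-labs/FLUX.1-dev"
    )
    pipeline.load_default_components(torch_dtype=torch.bfloat16)
    pipeline.to("cuda")

    image = pipeline(prompt="a cozy cabin in a snowy forest", output="images")[0]
    image.save("expanded_prompt_result.png")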
|
|