# gemini-prompt-expander / prompt_expander.py
# Author: sayakpaul (HF Staff)
# Last commit: "Update prompt_expander.py" (d84a752, verified)
import logging
import os
from typing import Any, List, Tuple

import google.generativeai as genai
from diffusers.modular_pipelines import (
    PipelineState,
    ModularPipelineBlocks,
    InputParam,
    OutputParam,
)
# System instruction given to the Gemini model by default. It asks the model to
# turn a short user description into a detailed image-generation prompt and to
# reply with only that prompt text (no prefix or suffix). The sentences are
# joined with single spaces into one string.
SYSTEM_PROMPT = " ".join(
    [
        "You are an expert image generation assistant.",
        "Take the user's short description and expand it into a vivid, detailed, and clear image generation prompt.",
        "Ensure rich colors, depth, realistic lighting, and an imaginative composition.",
        "Avoid vague terms — be specific about style, perspective, and mood.",
        "Try to keep the output under 512 tokens.",
        "Please don't return any prefix or suffix tokens, just the expanded user description.",
    ]
)
logger = logging.getLogger(__name__)


class GeminiPromptExpander(ModularPipelineBlocks):
    """Pipeline block that rewrites the user's ``prompt`` with a Gemini model.

    The block reads ``prompt`` from the pipeline state and sends it to a Gemini
    ``GenerativeModel`` configured with ``system_prompt`` as its system
    instruction. It writes the model's text reply back as ``prompt`` and keeps
    the user's original text under ``old_prompt``.
    """

    model_name = "flux"

    def __init__(self, model_id="gemini-2.5-flash-lite", system_prompt=SYSTEM_PROMPT):
        """Configure the Gemini client and create the generative model.

        Args:
            model_id: Name of the Gemini model to call.
            system_prompt: System instruction passed to the model.

        Raises:
            ValueError: If the ``GOOGLE_API_KEY`` environment variable is unset
                or empty.
        """
        super().__init__()
        api_key = os.getenv("GOOGLE_API_KEY")
        # `not api_key` also rejects an empty string. Otherwise an empty key
        # would be handed to the client instead of failing fast here.
        if not api_key:
            raise ValueError("Must provide an API key for Gemini through the `GOOGLE_API_KEY` env variable.")
        # NOTE(review): genai.configure is a module-level call, so it presumably
        # affects every google.generativeai user in the process.
        genai.configure(api_key=api_key)
        self.model = genai.GenerativeModel(model_name=model_id, system_instruction=system_prompt)

    @property
    def expected_components(self):
        """Return an empty list: this block loads no pipeline components."""
        return []

    @property
    def inputs(self) -> List[InputParam]:
        """Declare the single required string input, ``prompt``."""
        return [
            InputParam(
                "prompt",
                type_hint=str,
                required=True,
                description="Prompt to use",
            )
        ]

    @property
    def intermediate_inputs(self) -> List[InputParam]:
        """Return an empty list: this block consumes no intermediate values."""
        return []

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        """Declare the expanded ``prompt`` and the original ``old_prompt`` outputs."""
        return [
            OutputParam(
                "prompt",
                type_hint=str,
                description="Expanded prompt by the LLM",
            ),
            OutputParam(
                "old_prompt",
                type_hint=str,
                description="Old prompt provided by the user",
            ),
        ]

    def __call__(self, components: Any, state: PipelineState) -> Tuple[Any, PipelineState]:
        """Replace the prompt in ``state`` with a Gemini-expanded version.

        Args:
            components: Pipeline components; returned unchanged.
            state: Pipeline state holding the user's ``prompt``.

        Returns:
            The ``(components, state)`` pair. ``state`` now holds the expanded
            text as ``prompt`` and the user's original text as ``old_prompt``.
        """
        block_state = self.get_block_state(state)
        old_prompt = block_state.prompt
        # Use a module logger instead of print() so host applications control
        # verbosity. Lazy %-args avoid formatting when INFO is disabled.
        logger.info("Actual prompt: %s", old_prompt)
        block_state.prompt = self.model.generate_content(old_prompt).text
        block_state.old_prompt = old_prompt
        # %r mirrors the repr() output of the original f"{block_state.prompt=}".
        logger.info("block_state.prompt=%r", block_state.prompt)
        self.set_block_state(state, block_state)
        return components, state