from typing import List
from diffusers.modular_pipelines import (
PipelineState,
ModularPipelineBlocks,
InputParam,
OutputParam,
)
import google.generativeai as genai
import os
# System instruction handed to the Gemini model at construction time: steers
# the LLM to rewrite a short user description into a rich, specific
# image-generation prompt and to return only the expanded text.
# The adjacent string literals are concatenated into a single string.
SYSTEM_PROMPT = (
    "You are an expert image generation assistant. "
    "Take the user's short description and expand it into a vivid, detailed, and clear image generation prompt. "
    "Ensure rich colors, depth, realistic lighting, and an imaginative composition. "
    "Avoid vague terms — be specific about style, perspective, and mood. "
    "Try to keep the output under 512 tokens. "
    "Please don't return any prefix or suffix tokens, just the expanded user description."
)
class GeminiPromptExpander(ModularPipelineBlocks):
    """Modular pipeline block that expands a short user prompt into a
    detailed image-generation prompt via a Gemini model.

    Reads ``prompt`` from the pipeline state, replaces it with the
    LLM-expanded version, and preserves the original under ``old_prompt``.
    """

    # Pipeline family this block is registered for.
    model_name = "flux"

    def __init__(self, model_id="gemini-2.5-flash-lite", system_prompt=SYSTEM_PROMPT):
        """Configure the Gemini client and build the generative model.

        Args:
            model_id: Gemini model identifier used for prompt expansion.
            system_prompt: System instruction that steers the expansion.

        Raises:
            ValueError: If the ``GOOGLE_API_KEY`` environment variable is unset.
        """
        super().__init__()
        api_key = os.getenv("GOOGLE_API_KEY")
        if api_key is None:
            raise ValueError("Must provide an API key for Gemini through the `GOOGLE_API_KEY` env variable.")
        genai.configure(api_key=api_key)
        self.model = genai.GenerativeModel(model_name=model_id, system_instruction=system_prompt)

    @property
    def expected_components(self):
        # This block needs no pipeline components (models, schedulers, ...).
        return []

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "prompt",
                type_hint=str,
                required=True,
                description="Prompt to use",
            )
        ]

    @property
    def intermediate_inputs(self) -> List[InputParam]:
        # No intermediate values are consumed from earlier blocks.
        return []

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "prompt",
                type_hint=str,
                description="Expanded prompt by the LLM",
            ),
            OutputParam(
                "old_prompt",
                type_hint=str,
                description="Old prompt provided by the user",
            ),
        ]

    def __call__(self, components, state: PipelineState):
        """Expand the user prompt in place on the pipeline state.

        Returns:
            The ``(components, state)`` pair expected by the modular pipeline
            runner.  (The previous ``-> PipelineState`` annotation contradicted
            the actual tuple return value and has been removed.)
        """
        block_state = self.get_block_state(state)
        old_prompt = block_state.prompt
        print(f"Actual prompt: {old_prompt}")
        # NOTE(review): blocking network call; an API failure or a response
        # with no text raises here — TODO confirm desired error handling.
        block_state.prompt = self.model.generate_content(old_prompt).text
        block_state.old_prompt = old_prompt
        print(f"{block_state.prompt=}")
        self.set_block_state(state, block_state)
        return components, state