File size: 2,489 Bytes
555f59f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bbf66ad
 
555f59f
 
 
d84a752
 
555f59f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52a0eac
 
 
 
555f59f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bbf66ad
d84a752
bbf66ad
 
d84a752
555f59f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78

import logging
import os
from typing import Any, List, Tuple

import google.generativeai as genai
from diffusers.modular_pipelines import (
    InputParam,
    ModularPipelineBlocks,
    OutputParam,
    PipelineState,
)

# Default system instruction for Gemini: rewrite the user's short description
# as a detailed image-generation prompt and return only that prompt text
# (no added prefix or suffix).
SYSTEM_PROMPT = "".join(
    [
        "You are an expert image generation assistant. ",
        "Take the user's short description and expand it into a vivid, detailed, and clear image generation prompt. ",
        "Ensure rich colors, depth, realistic lighting, and an imaginative composition. ",
        "Avoid vague terms — be specific about style, perspective, and mood. ",
        "Try to keep the output under 512 tokens. ",
        "Please don't return any prefix or suffix tokens, just the expanded user description.",
    ]
)

class GeminiPromptExpander(ModularPipelineBlocks):
    """Modular-pipeline block that expands the user's ``prompt`` with Gemini.

    The block reads ``prompt`` from the pipeline state and sends it to a
    Gemini ``GenerativeModel`` whose system instruction is ``system_prompt``.
    It writes the model's (whitespace-trimmed) reply back into ``prompt`` and
    keeps the user's original text in ``old_prompt``.
    """

    # Presumably tags the pipeline family this block targets (Flux) — confirm
    # against the diffusers ModularPipelineBlocks conventions.
    model_name = "flux"

    def __init__(self, model_id="gemini-2.5-flash-lite", system_prompt=SYSTEM_PROMPT, api_key=None):
        """Configure the Gemini client and build the generative model.

        Args:
            model_id: Gemini model name passed to ``genai.GenerativeModel``.
            system_prompt: System instruction given to the model.
            api_key: Optional Gemini API key. When omitted (or empty), the
                ``GOOGLE_API_KEY`` environment variable is used.

        Raises:
            ValueError: If neither ``api_key`` nor ``GOOGLE_API_KEY`` provides
                a non-empty key.
        """
        super().__init__()
        api_key = api_key or os.getenv("GOOGLE_API_KEY")
        # Treat an empty string like a missing key. An `is None` check would let
        # `GOOGLE_API_KEY=""` through and defer the failure to the first request.
        if not api_key:
            raise ValueError(
                "Must provide an API key for Gemini via the `api_key` argument "
                "or the `GOOGLE_API_KEY` env variable."
            )
        # NOTE: `genai.configure` is a module-level call, so this configuration
        # presumably applies to every Gemini model used in this process.
        genai.configure(api_key=api_key)
        self.model = genai.GenerativeModel(model_name=model_id, system_instruction=system_prompt)

    @property
    def expected_components(self):
        """No pipeline components are needed; the Gemini model lives on ``self.model``."""
        return []

    @property
    def inputs(self) -> List[InputParam]:
        """Declare the single required user input: the text ``prompt``."""
        return [
            InputParam(
                "prompt",
                type_hint=str,
                required=True,
                description="Prompt to use",
            )
        ]

    @property
    def intermediate_inputs(self) -> List[InputParam]:
        """This block consumes no intermediate values from earlier blocks."""
        return []

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        """Declare the expanded ``prompt`` and the preserved ``old_prompt``."""
        return [
            OutputParam(
                "prompt",
                type_hint=str,
                description="Expanded prompt by the LLM",
            ),
            OutputParam(
                "old_prompt",
                type_hint=str,
                description="Old prompt provided by the user",
            ),
        ]

    def __call__(self, components, state: PipelineState) -> Tuple[Any, PipelineState]:
        """Expand ``prompt`` in ``state`` using the Gemini model.

        Args:
            components: Pipeline components container. This block does not use
                it; it is returned unchanged.
            state: Pipeline state holding the user ``prompt``.

        Returns:
            The ``(components, state)`` pair. In ``state``, ``prompt`` now holds
            the expanded text and ``old_prompt`` holds the original text.
        """
        logger = logging.getLogger(__name__)
        block_state = self.get_block_state(state)

        old_prompt = block_state.prompt
        logger.debug("Actual prompt: %s", old_prompt)
        if isinstance(old_prompt, str) and not old_prompt.strip():
            # Nothing to expand, so skip the network round-trip. The API is
            # expected to reject empty content anyway.
            logger.warning("Empty prompt; skipping Gemini prompt expansion.")
            expanded = old_prompt
        else:
            # Model replies may carry leading/trailing whitespace (e.g. a
            # trailing newline); trim it before it reaches the text encoder.
            # NOTE(review): `.text` raises when the response has no text part
            # (e.g. blocked by safety filters). Consider falling back to
            # `old_prompt` instead of failing the whole pipeline.
            expanded = self.model.generate_content(old_prompt).text.strip()

        block_state.prompt = expanded
        block_state.old_prompt = old_prompt
        logger.debug("Expanded prompt: %s", expanded)
        self.set_block_state(state, block_state)

        return components, state