import os
import random
import gradio as gr
import torch
from diffusers import DiffusionPipeline, UNet2DConditionModel
from huggingface_hub import hf_hub_download  # needed for the LoRA download below
from transformers import CLIPTextModel, CLIPTokenizer

# Configuration. NOTE: neither repo ID below appears to resolve to a published
# Hugging Face repo; treat both as placeholders for a repo that exposes the
# Stable-Diffusion-style layout (unet/, text_encoder/, tokenizer/) loaded below.
MODEL_ID = "CompVis/Flux-Pro"
LORA_ID = "flux/lora-weights"
MODEL_CACHE = "model_cache"
os.makedirs(MODEL_CACHE, exist_ok=True)

def get_pipeline():
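    """Assemble the CPU diffusion pipeline and apply the LoRA weights."""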
    # Load the pipeline components individually so dtype and cache dir can be pinned
    unet = UNet2DConditionModel.from_pretrained(
        MODEL_ID,
        subfolder="unet",
        cache_dir=MODEL_CACHE,
        torch_dtype=torch.float32
    )
    
    text_encoder = CLIPTextModel.from_pretrained(
        MODEL_ID,
        subfolder="text_encoder",
        cache_dir=MODEL_CACHE
    )
    
    tokenizer = CLIPTokenizer.from_pretrained(
        MODEL_ID,
        subfolder="tokenizer",
        cache_dir=MODEL_CACHE
    )
    
    # Create pipeline
    pipe = DiffusionPipeline.from_pretrained(
        MODEL_ID,
        unet=unet,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        cache_dir=MODEL_CACHE,
        torch_dtype=torch.float32,
        safety_checker=None
    )
    
    # Load LoRA weights
    lora_path = hf_hub_download(
        LORA_ID,
        "flux_lora.safetensors",
        cache_dir=MODEL_CACHE
    )
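    # load_attn_procs is the legacy attention-processor LoRA loader; on recent
    # diffusers releases, pipe.load_lora_weights(lora_path) is generally preferred.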
    pipe.unet.load_attn_procs(lora_path)
    
    # CPU optimizations
    pipe = pipe.to("cpu")
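    # Attention slicing computes attention in sequential chunks, trading a little
    # speed for a large reduction in peak memory use.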
    pipe.enable_attention_slicing()
    
    return pipe

# Build the pipeline once at startup so it is shared across requests
pipeline = get_pipeline()

def generate_image(
    prompt: str,
    negative_prompt: str = "",
    width: int = 768,
    height: int = 768,
    seed: int = -1,
    guidance_scale: float = 7.5,
    num_inference_steps: int = 25
):
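    """Generate one image and return it together with the seed actually used."""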
    # Defensive cast: gr.Number may deliver a float, and
    # torch.Generator.manual_seed requires an int
    seed = int(seed)
    if seed == -1:
        seed = random.randint(0, 2147483647)
    generator = torch.Generator(device="cpu").manual_seed(seed)
    
    with torch.no_grad():
        image = pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator
        ).images[0]
    
    return image, seed

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# πŸŒ€ FLUX-Pro Image Generator")
    
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", lines=3)
            negative_prompt = gr.Textbox(label="Negative Prompt", value="blurry, low quality")
            generate_btn = gr.Button("Generate", variant="primary")
            
            with gr.Accordion("Advanced", open=False):
                width = gr.Slider(512, 1024, value=768, step=64, label="Width")
                height = gr.Slider(512, 1024, value=768, step=64, label="Height")
                guidance = gr.Slider(1.0, 15.0, value=7.5, step=0.5, label="Guidance")
                steps = gr.Slider(15, 50, value=25, step=1, label="Steps")
                seed = gr.Number(label="Seed", value=-1, precision=0)
        
        with gr.Column():
            output_image = gr.Image(label="Result", type="pil")
            used_seed = gr.Textbox(label="Used Seed")

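    # The inputs list must match generate_image's parameter order
    # (prompt, negative_prompt, width, height, seed, guidance_scale, num_inference_steps)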
    generate_btn.click(
        generate_image,
        inputs=[prompt, negative_prompt, width, height, seed, guidance, steps],
        outputs=[output_image, used_seed]
    )

if __name__ == "__main__":
    # CPU inference can take minutes per image; the request queue prevents
    # client-side timeouts while a generation is running
    demo.queue().launch()