File size: 3,248 Bytes
3251b02
 
 
8074036
5e76037
8074036
 
90f1e41
3251b02
8074036
90f1e41
 
 
3251b02
 
 
 
90f1e41
8074036
 
 
 
 
 
 
90f1e41
3067775
5e76037
033fe26
3cbcaed
8074036
 
3cbcaed
033fe26
8074036
90f1e41
 
 
 
 
 
 
 
 
2235066
8074036
2235066
3251b02
90f1e41
3251b02
 
8074036
2235066
8074036
3251b02
8074036
2a2202a
8074036
90f1e41
2235066
2a2202a
 
 
 
033fe26
8074036
 
5e76037
3251b02
8074036
3067775
90f1e41
3251b02
3067775
8074036
 
3067775
8074036
3067775
5e76037
 
 
 
2235066
8074036
3067775
 
033fe26
8074036
3251b02
033fe26
2235066
5e76037
2073989
3251b02
 
8074036
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os
import random
import torch
import gradio as gr
from diffusers import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from transformers import CLIPFeatureExtractor
from huggingface_hub import hf_hub_download

# Configuration
# Hub id of the base checkpoint handed to StableDiffusionPipeline.from_pretrained.
BASE_MODEL = "stabilityai/stable-diffusion-2-1-base"
# Hub repo and file name of the LoRA weights fetched via hf_hub_download.
LORA_REPO = "Norod78/Flux-LoRA"
LORA_FILENAME = "flux_lora.safetensors"
# Local directory passed as cache_dir when fetching the base pipeline and the
# LoRA file; it is created at import time if it does not exist yet.
MODEL_CACHE = "model_cache"
os.makedirs(MODEL_CACHE, exist_ok=True)

def get_pipeline():
    """Build the CPU Stable Diffusion pipeline with the LoRA weights applied.

    Loads the safety checker, its feature extractor, the base model and the
    LoRA file, routing every download through ``MODEL_CACHE`` so all
    artifacts end up in the same local cache directory.

    Returns:
        The configured ``StableDiffusionPipeline``, moved to the CPU with
        attention slicing enabled.
    """
    # Load safety checker + feature extractor (same cache dir as the other
    # downloads, so nothing is fetched into the default hub cache instead).
    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker",
        cache_dir=MODEL_CACHE,
    )
    feature_extractor = CLIPFeatureExtractor.from_pretrained(
        "openai/clip-vit-base-patch32",
        cache_dir=MODEL_CACHE,
    )

    # Load base pipeline
    pipe = StableDiffusionPipeline.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.float32,
        cache_dir=MODEL_CACHE,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        use_safetensors=True,
    )

    # Download the LoRA file (hf_hub_download returns its local path) and
    # apply it on top of the base weights.
    lora_path = hf_hub_download(
        repo_id=LORA_REPO,
        filename=LORA_FILENAME,
        cache_dir=MODEL_CACHE,
    )
    pipe.load_lora_weights(lora_path)

    pipe.to("cpu")
    pipe.enable_attention_slicing()

    return pipe

# Load model once: this runs at import time, and the resulting module-level
# `pipeline` is the object generate_image calls for every request.
pipeline = get_pipeline()

def generate_image(prompt, negative_prompt="", width=768, height=768, seed=-1, guidance_scale=7.5, num_inference_steps=25):
    """Generate one image for *prompt* and report the seed that was used.

    Numeric arguments may arrive from UI widgets as floats (or ``None`` for
    an empty seed field), so they are normalised to ``int`` before being
    handed to torch and the pipeline.

    Args:
        prompt: User prompt; it is sent to the pipeline prefixed with
            ``"flux style, "``.
        negative_prompt: Negative prompt forwarded unchanged.
        width: Output width in pixels.
        height: Output height in pixels.
        seed: Seed for the CPU generator. ``-1`` or ``None`` means "draw a
            random seed in ``[0, 2**31 - 1]``".
        guidance_scale: Guidance scale forwarded unchanged.
        num_inference_steps: Number of denoising steps.

    Returns:
        A ``(image, seed)`` tuple: the first image of the pipeline output and
        the integer seed that produced it.
    """
    # None covers a cleared seed field; -1 stays the explicit "random" marker.
    if seed is None or seed == -1:
        seed = random.randint(0, 2**31 - 1)
    else:
        # manual_seed only accepts integers, so drop any float representation.
        seed = int(seed)
    generator = torch.Generator(device="cpu").manual_seed(seed)

    with torch.no_grad():
        output = pipeline(
            prompt=f"flux style, {prompt}",  # Stylized prompt
            negative_prompt=negative_prompt,
            width=int(width),
            height=int(height),
            guidance_scale=guidance_scale,
            num_inference_steps=int(num_inference_steps),
            generator=generator,
        )
        image = output.images[0]
    return image, seed

# Gradio UI: inputs on the left, generated image and the seed used on the right.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🌀 Flux-LoRA Anime Image Generator (CPU only)")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", lines=3)
            negative_prompt = gr.Textbox(label="Negative Prompt", value="blurry, lowres, bad anatomy")
            generate_btn = gr.Button("Generate", variant="primary")

            with gr.Accordion("Advanced", open=False):
                width = gr.Slider(512, 1024, value=768, step=64, label="Width")
                height = gr.Slider(512, 1024, value=768, step=64, label="Height")
                guidance = gr.Slider(1.0, 15.0, value=7.5, step=0.5, label="Guidance")
                steps = gr.Slider(15, 50, value=25, step=1, label="Steps")
                # precision=0 makes the component deliver an int, which is
                # what the seed is ultimately used as; -1 requests a random seed.
                seed = gr.Number(label="Seed", value=-1, precision=0)

        with gr.Column():
            output_image = gr.Image(label="Result", type="pil")
            used_seed = gr.Textbox(label="Used Seed")

    # The input order must match generate_image's positional parameters:
    # (prompt, negative_prompt, width, height, seed, guidance_scale, num_inference_steps).
    generate_btn.click(
        generate_image,
        inputs=[prompt, negative_prompt, width, height, seed, guidance, steps],
        outputs=[output_image, used_seed],
    )

if __name__ == "__main__":
    demo.launch()