File size: 3,519 Bytes
96f495f
d266dc0
 
 
 
 
96f495f
1638eb1
d266dc0
 
 
 
 
 
 
96f495f
d266dc0
 
 
 
 
 
 
 
1638eb1
d266dc0
6f77195
d266dc0
 
 
 
 
 
 
 
 
 
96f495f
d266dc0
 
 
 
 
96f495f
d266dc0
 
 
6f77195
d266dc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d74b25
d266dc0
 
 
 
 
 
 
 
 
 
 
20fc52e
d266dc0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
"""
RealCanvas-MJ4K
A 16-GB-friendly Gradio Space that
1. streams the prompt dataset MohamedRashad/midjourney-detailed-prompts
2. generates realistic images using SDXL-Lightning
3. optionally displays random images from opendiffusionai/cc12m-4mp-realistic
"""

import json
import os
import random
from io import BytesIO

import gradio as gr
import requests
import torch
from datasets import load_dataset
from diffusers import (
    DPMSolverMultistepScheduler,
    EulerDiscreteScheduler,
    StableDiffusionXLPipeline,
)
from huggingface_hub import hf_hub_download
from PIL import Image

# -------------------------------------------------
# 1. Load the prompt dataset (lazy streaming, then materialized)
# -------------------------------------------------
# Number of records pulled from the stream into memory at startup. The stream
# itself is lazy, but list(...) below materializes every record with all of
# its columns (not just the prompt), so startup time and RAM grow with this
# number and with record size.
# NOTE(review): the old "≈ 5 MB RAM" figure was not backed by anything;
# measure actual RSS before relying on a budget.
PROMPT_POOL_SIZE = 500_000

print("🔍 Streaming prompt dataset …")
ds_prompts = load_dataset(
    "MohamedRashad/midjourney-detailed-prompts",
    split="train",
    streaming=True,
)
# Seeded shuffle so the sampled pool is reproducible across restarts; take()
# caps how many records are read from the stream.
prompt_pool = list(ds_prompts.shuffle(seed=42).take(PROMPT_POOL_SIZE))

# -------------------------------------------------
# 2. Load SDXL base + SDXL-Lightning 4-step LoRA
#    (fp16 on CUDA, fp32 on the CPU fallback)
# -------------------------------------------------
MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
LIGHTNING_REPO = "ByteDance/SDXL-Lightning"
LIGHTNING_LORA = "sdxl_lightning_4step_lora.safetensors"

# Half precision is only a good fit on CUDA; many CPU kernels lack (or are
# very slow in) fp16, so the CPU fallback runs in fp32.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

print("πŸ€– Loading SDXL-Lightning …")
pipe = StableDiffusionXLPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=DTYPE,
    variant="fp16",  # fp16 checkpoint files; upcast when DTYPE is fp32
    use_safetensors=True,
)
# SDXL-Lightning is distilled for an Euler sampler with "trailing" timestep
# spacing. The base model's default spacing picks the wrong timesteps for a
# 4-step schedule and degrades images noticeably.
pipe.scheduler = EulerDiscreteScheduler.from_config(
    pipe.scheduler.config,
    timestep_spacing="trailing",
)
# Move to the target device first so the LoRA merge below is computed there
# (and in the inference dtype) rather than on the CPU in fp16.
pipe = pipe.to(DEVICE)
lora_path = hf_hub_download(repo_id=LIGHTNING_REPO, filename=LIGHTNING_LORA)
pipe.load_lora_weights(lora_path)
# Fold the LoRA deltas into the base weights so each step avoids the extra
# adapter computation.
pipe.fuse_lora()
pipe.enable_attention_slicing()

# -------------------------------------------------
# 3. Random CC12M-4MP image helper (optional demo)
# -------------------------------------------------
# Number of samples materialized from the image stream at startup. Every
# sample is kept in memory in whatever form the dataset yields it. If samples
# carry photos under "image" (as random_cc12m_image assumes), memory is
# dominated by those images.
# NOTE(review): the old "≈ 10 MB RAM" figure is unsupported; measure actual
# RSS.
CC12M_POOL_SIZE = 1_000

print("📸 Streaming CC12M-4MP-realistic …")
ds_images = load_dataset(
    "opendiffusionai/cc12m-4mp-realistic",
    split="train",
    streaming=True,
)
# Seeded shuffle for a reproducible pool; take() caps how many samples are read.
img_pool = list(ds_images.shuffle(seed=42).take(CC12M_POOL_SIZE))

def random_cc12m_image():
    """Pick one sample from ``img_pool`` and return its image at 512x512.

    The image is stretched to exactly 512x512 (aspect ratio is not kept).
    Assumes each pooled sample stores a PIL image under the "image" key —
    NOTE(review): confirm against the dataset schema.
    """
    target_size = (512, 512)
    chosen = random.choice(img_pool)
    photo = chosen["image"]
    return photo.resize(target_size)

# -------------------------------------------------
# 4. Gradio UI
# -------------------------------------------------
def generate(prompt: str, steps: int = 4, guidance: float = 0.0):
    """Generate one image from a text prompt with the module-level ``pipe``.

    Args:
        prompt: Text prompt. If it is None, empty, or whitespace-only, a random
            record's "prompt" field from ``prompt_pool`` is used instead.
        steps: Number of denoising steps. Coerced to ``int`` because UI
            widgets and API clients may deliver numbers as floats.
        guidance: Classifier-free guidance scale, coerced to ``float``.

    Returns:
        The first generated image, resized to 768x768.
    """
    # Guard against None (possible via API calls) before calling .strip().
    if prompt is None or not prompt.strip():
        prompt = random.choice(prompt_pool)["prompt"]
    image = pipe(
        prompt,
        num_inference_steps=int(steps),
        guidance_scale=float(guidance),
    ).images[0]
    # Fixed display size for the output component.
    return image.resize((768, 768))

with gr.Blocks(title="RealCanvas-MJ4K") as demo:
    gr.Markdown("# 🎨 RealCanvas-MJ4K  |  Midjourney-level realism under 16 GB")
    with gr.Row():
        prompt_in = gr.Textbox(
            label="Prompt (leave empty for random Midjourney-style prompt)",
            lines=2
        )
    with gr.Row():
        steps = gr.Slider(1, 8, value=4, step=1, label="Inference steps (SDXL-Lightning)")
        guidance = gr.Slider(0.0, 2.0, value=0.0, step=0.1, label="Guidance scale")
    btn = gr.Button("Generate", variant="primary")
    # Single-image output component (a gr.Image, despite the variable name).
    gallery = gr.Image(type="pil", label="Generated image")
    with gr.Accordion("📸 Random CC12M-4MP sample", open=False):
        cc_btn = gr.Button("Show random CC12M-4MP image")
        cc_out = gr.Image(type="pil", label="Real photo from dataset")

    # Input order must match generate(prompt, steps, guidance).
    btn.click(generate, [prompt_in, steps, guidance], gallery)
    cc_btn.click(random_cc12m_image, outputs=cc_out)

# Enable request queueing (max_size caps the number of queued events) and start
# the server.
demo.queue(max_size=8).launch()