Spaces:
Runtime error
Runtime error
"""
RealCanvas-MJ4K

A 16-GB-friendly Gradio Space that:
1. streams the prompt dataset MohamedRashad/midjourney-detailed-prompts
2. generates realistic images using SDXL-Lightning
3. optionally displays random images from opendiffusionai/cc12m-4mp-realistic
"""
import gradio as gr | |
import torch, os, random, json, requests | |
from io import BytesIO | |
from PIL import Image | |
from datasets import load_dataset | |
from huggingface_hub import hf_hub_download | |
from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler | |
# -------------------------------------------------
# 1. Load the prompt dataset (lazy streaming)
# -------------------------------------------------
# Number of rows copied into memory at start-up.  The default preserves the
# original hard-coded value; set the PROMPT_POOL_SIZE environment variable to
# a smaller number to shorten start-up and lower RAM use.  At least one row is
# always requested so that `random.choice(prompt_pool)` has something to pick.
PROMPT_POOL_SIZE = max(1, int(os.environ.get("PROMPT_POOL_SIZE", "500000")))

# NOTE(review): the leading/trailing symbols in this message look mis-encoded;
# the literal is left untouched on purpose.
print("π Streaming prompt dataset β¦")
ds_prompts = load_dataset(
    "MohamedRashad/midjourney-detailed-prompts",
    split="train",
    streaming=True,
)

# `list(...)` consumes the stream, so the selected rows are downloaded and
# held in RAM for the lifetime of the process.
# NOTE(review): the original comment estimated ~5 MB for 500,000 rows, which
# would be ~10 bytes per row -- almost certainly too low; measure before
# relying on it.  Shuffling a streaming dataset is buffer-based, so the sample
# is presumably only approximately random -- confirm against the datasets docs.
prompt_pool = list(ds_prompts.shuffle(seed=42).take(PROMPT_POOL_SIZE))
# -------------------------------------------------
# 2. Load SDXL-Lightning (SDXL base + 4-step Lightning LoRA)
# -------------------------------------------------
MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"

# Pick device and precision together: half precision is only used on GPU.
# On the CPU fallback the fp16 checkpoint is still downloaded (variant="fp16")
# but the weights are up-cast to float32, because half-precision inference on
# CPU is slow and, depending on the torch version, unsupported for some ops.
_device = "cuda" if torch.cuda.is_available() else "cpu"
_dtype = torch.float16 if _device == "cuda" else torch.float32

# NOTE(review): the symbols in this message look mis-encoded; the literal is
# left untouched on purpose.
print("π€ Loading SDXL-Lightning β¦")
pipe = StableDiffusionXLPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=_dtype,
    variant="fp16",
    use_safetensors=True,
)
pipe = pipe.to(_device)

# SDXL-Lightning is distilled for very few steps and its model card requires
# "trailing" timestep spacing; without it the few-step schedule does not end
# on the timesteps the LoRA was trained for.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe.scheduler.config,
    timestep_spacing="trailing",
)

# lightning LoRA
lora_path = hf_hub_download(
    repo_id="ByteDance/SDXL-Lightning",
    filename="sdxl_lightning_4step_lora.safetensors",
)
pipe.load_lora_weights(lora_path)
# Merge the LoRA into the base weights once so it adds no per-step overhead.
pipe.fuse_lora()

# Trades some speed for a lower peak memory footprint during attention.
pipe.enable_attention_slicing()
# -------------------------------------------------
# 3. Random CC12M-4MP image helper (optional demo)
# -------------------------------------------------
print("πΈ Streaming CC12M-4MP-realistic β¦")

# Open the dataset as a stream so nothing is downloaded until it is iterated.
ds_images = load_dataset(
    path="opendiffusionai/cc12m-4mp-realistic",
    split="train",
    streaming=True,
)

# Shuffle with a fixed seed, then materialise the first 1,000 rows in memory.
# NOTE(review): the original comment estimated ~10 MB RAM for this pool; that
# figure is not verified here.
_shuffled_images = ds_images.shuffle(seed=42)
img_pool = list(_shuffled_images.take(1_000))
def random_cc12m_image():
    """Return one randomly chosen image from ``img_pool``, resized to 512x512.

    A row is drawn with ``random.choice`` and the object stored under its
    ``"image"`` key is resized.

    NOTE(review): assumes every row holds a PIL-style image (something with a
    ``.resize`` method) under ``"image"`` -- confirm against the dataset schema.
    """
    picked_row = random.choice(img_pool)
    picture = picked_row["image"]
    return picture.resize((512, 512))
# -------------------------------------------------
# 4. Gradio UI
# -------------------------------------------------
def generate(prompt: str, steps: int = 4, guidance: float = 0.0) -> "Image.Image":
    """Generate one image for ``prompt`` and return it resized to 768x768.

    Args:
        prompt: Text prompt. If it is ``None``, empty, or only whitespace, a
            random entry from ``prompt_pool`` is used instead (its ``"prompt"``
            field).
        steps: Number of inference steps passed to the pipeline. Coerced to an
            integer and clamped to at least 1.
        guidance: Guidance scale passed to the pipeline, coerced to ``float``.

    Returns:
        The first image produced by the pipeline, resized to 768x768.
    """
    # Callers other than the textbox (e.g. API clients) may pass None, which
    # would otherwise crash on `.strip()`.
    if not (prompt or "").strip():
        prompt = random.choice(prompt_pool)["prompt"]

    # UI/API values can arrive as floats (e.g. 4.0); the pipeline needs a
    # positive integer step count.
    step_count = max(1, int(steps))

    result = pipe(
        prompt,
        num_inference_steps=step_count,
        guidance_scale=float(guidance),
    )
    return result.images[0].resize((768, 768))
# Build the page layout and wire the two buttons to their callbacks.
# NOTE(review): the nesting below is reconstructed from the `with` statements;
# confirm it matches the intended layout.
with gr.Blocks(title="RealCanvas-MJ4K") as demo:
    gr.Markdown("# π¨ RealCanvas-MJ4K | Midjourney-level realism under 16 GB")

    # Row 1: free-text prompt.
    with gr.Row():
        prompt_in = gr.Textbox(
            lines=2,
            label="Prompt (leave empty for random Midjourney-style prompt)",
        )

    # Row 2: sampling controls.
    with gr.Row():
        steps = gr.Slider(
            minimum=1,
            maximum=8,
            step=1,
            value=4,
            label="Inference steps (SDXL-Lightning)",
        )
        guidance = gr.Slider(
            minimum=0.0,
            maximum=2.0,
            step=0.1,
            value=0.0,
            label="Guidance scale",
        )

    btn = gr.Button("Generate", variant="primary")
    gallery = gr.Image(type="pil", label="Generated image")

    # Collapsed-by-default panel showing a random dataset photo.
    with gr.Accordion("πΈ Random CC12M-4MP sample", open=False):
        cc_btn = gr.Button("Show random CC12M-4MP image")
        cc_out = gr.Image(type="pil", label="Real photo from dataset")

    # Event wiring must happen inside the Blocks context.
    btn.click(
        fn=generate,
        inputs=[prompt_in, steps, guidance],
        outputs=gallery,
    )
    cc_btn.click(fn=random_cc12m_image, outputs=cc_out)

# Allow at most 8 queued requests, then start the server.
demo.queue(max_size=8)
demo.launch()