""" RealCanvas-MJ4K A 16-GB-friendly Gradio Space that 1. streams the prompt dataset MohamedRashad/midjourney-detailed-prompts 2. generates realistic images using SDXL-Lightning 3. optionally displays random images from opendiffusionai/cc12m-4mp-realistic """ import gradio as gr import torch, os, random, json, requests from io import BytesIO from PIL import Image from datasets import load_dataset from huggingface_hub import hf_hub_download from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler # ------------------------------------------------- # 1. Load the prompt dataset (lazy streaming) # ------------------------------------------------- print("🔍 Streaming prompt dataset …") ds_prompts = load_dataset( "MohamedRashad/midjourney-detailed-prompts", split="train", streaming=True ) prompt_pool = list(ds_prompts.shuffle(seed=42).take(500_000)) # ≈ 5 MB RAM # ------------------------------------------------- # 2. Load SDXL-Lightning (fp16, 4-step, 4 GB VRAM) # ------------------------------------------------- MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0" print("🤖 Loading SDXL-Lightning …") pipe = StableDiffusionXLPipeline.from_pretrained( MODEL_ID, torch_dtype=torch.float16, variant="fp16", use_safetensors=True ) pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) # lightning LoRA lora_path = hf_hub_download( repo_id="ByteDance/SDXL-Lightning", filename="sdxl_lightning_4step_lora.safetensors" ) pipe.load_lora_weights(lora_path) pipe = pipe.to("cuda") if torch.cuda.is_available() else pipe.to("cpu") pipe.enable_attention_slicing() # ------------------------------------------------- # 3. Random CC12M-4MP image helper (optional demo) # ------------------------------------------------- print("📸 Streaming CC12M-4MP-realistic …") ds_images = load_dataset( "opendiffusionai/cc12m-4mp-realistic", split="train", streaming=True ) img_pool = list(ds_images.shuffle(seed=42).take(1_000)) # ≈ 10 MB RAM def random_cc12m_image(): sample = random.choice(img_pool) return sample["image"].resize((512, 512)) # ------------------------------------------------- # 4. Gradio UI # ------------------------------------------------- def generate(prompt: str, steps: int = 4, guidance: float = 0.0): if not prompt.strip(): prompt = random.choice(prompt_pool)["prompt"] image = pipe( prompt, num_inference_steps=steps, guidance_scale=guidance ).images[0] return image.resize((768, 768)) with gr.Blocks(title="RealCanvas-MJ4K") as demo: gr.Markdown("# 🎨 RealCanvas-MJ4K | Midjourney-level realism under 16 GB") with gr.Row(): prompt_in = gr.Textbox( label="Prompt (leave empty for random Midjourney-style prompt)", lines=2 ) with gr.Row(): steps = gr.Slider(1, 8, value=4, step=1, label="Inference steps (SDXL-Lightning)") guidance = gr.Slider(0.0, 2.0, value=0.0, step=0.1, label="Guidance scale") btn = gr.Button("Generate", variant="primary") gallery = gr.Image(type="pil", label="Generated image") with gr.Accordion("📸 Random CC12M-4MP sample", open=False): cc_btn = gr.Button("Show random CC12M-4MP image") cc_out = gr.Image(type="pil", label="Real photo from dataset") btn.click(generate, [prompt_in, steps, guidance], gallery) cc_btn.click(random_cc12m_image, outputs=cc_out) demo.queue(max_size=8).launch()