# CodeNyx / app.py
# AryanRathod3097's picture
# Update app.py
# d266dc0 verified
"""
RealCanvas-MJ4K
A 16-GB-friendly Gradio Space that
1. streams the prompt dataset MohamedRashad/midjourney-detailed-prompts
2. generates realistic images using SDXL-Lightning
3. optionally displays random images from opendiffusionai/cc12m-4mp-realistic
"""
import json
import os
import random
from io import BytesIO

import gradio as gr
import requests
import torch
from datasets import load_dataset
from diffusers import (
    DPMSolverMultistepScheduler,
    EulerDiscreteScheduler,
    StableDiffusionXLPipeline,
)
from huggingface_hub import hf_hub_download
from PIL import Image
# -------------------------------------------------
# 1. Load the prompt dataset (lazy streaming)
# -------------------------------------------------
print("🔍 Streaming prompt dataset …")
ds_prompts = load_dataset(
    "MohamedRashad/midjourney-detailed-prompts",
    split="train",
    streaming=True,
)
# The stream is shuffled with a fixed seed and up to 500,000 rows are
# materialised into a list so that generate() can draw from it with
# random.choice().
# NOTE(review): the original comment estimated "≈ 5 MB RAM" for this pool;
# that looks low for 500,000 rows — measure before relying on it.
prompt_pool = list(ds_prompts.shuffle(seed=42).take(500_000))
# -------------------------------------------------
# 2. Load SDXL-Lightning (SDXL base + 4-step Lightning LoRA)
# -------------------------------------------------
MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"

# Half precision is only used when a GPU is present; the CPU fallback runs in
# float32 because float16 inference on CPU is unsupported or impractically
# slow for many operations.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TORCH_DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

print("🤖 Loading SDXL-Lightning …")
pipe = StableDiffusionXLPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=TORCH_DTYPE,
    variant="fp16",  # fetch the fp16 weight files; they are cast to TORCH_DTYPE
    use_safetensors=True,
)

# lightning LoRA
lora_path = hf_hub_download(
    repo_id="ByteDance/SDXL-Lightning",
    filename="sdxl_lightning_4step_lora.safetensors",
)
pipe.load_lora_weights(lora_path)

# The SDXL-Lightning model card prescribes an Euler sampler with "trailing"
# timestep spacing; other schedulers/spacings do not match how the few-step
# checkpoint was distilled.
pipe.scheduler = EulerDiscreteScheduler.from_config(
    pipe.scheduler.config,
    timestep_spacing="trailing",
)

pipe = pipe.to(DEVICE)
pipe.enable_attention_slicing()
# -------------------------------------------------
# 3. Random CC12M-4MP image helper (optional demo)
# -------------------------------------------------
print("📸 Streaming CC12M-4MP-realistic …")
ds_images = load_dataset(
    "opendiffusionai/cc12m-4mp-realistic",
    split="train",
    streaming=True,
)
# Up to 1,000 shuffled samples (fixed seed) are kept in memory for the
# "random image" button.
# NOTE(review): the original comment estimated "≈ 10 MB RAM"; the real figure
# depends on what each row holds. random_cc12m_image() reads sample["image"],
# so confirm this dataset actually exposes an "image" column.
img_pool = list(ds_images.shuffle(seed=42).take(1_000))
def random_cc12m_image():
    """Return the image of one randomly chosen pooled sample at 512x512.

    An entry is drawn from the module-level ``img_pool`` with
    ``random.choice``; its ``"image"`` field is resized to a fixed square
    (aspect ratio is not preserved) and returned.
    """
    target_size = (512, 512)
    picked = random.choice(img_pool)
    photo = picked["image"]
    return photo.resize(target_size)
# -------------------------------------------------
# 4. Gradio UI
# -------------------------------------------------
def generate(prompt: str, steps: int = 4, guidance: float = 0.0):
    """Generate one image for *prompt* with the module-level pipeline.

    Args:
        prompt: Text prompt. When it is ``None``, empty, or only whitespace,
            a random entry's ``"prompt"`` field from ``prompt_pool`` is used
            instead.
        steps: Number of inference steps passed to the pipeline.
        guidance: Guidance scale passed to the pipeline.

    Returns:
        The first generated image, resized to 768x768.
    """
    # `(prompt or "")` keeps a None prompt from crashing on .strip().
    if not (prompt or "").strip():
        prompt = random.choice(prompt_pool)["prompt"]

    result = pipe(
        prompt,
        # Coerce defensively: a caller may hand over 4.0 rather than 4, and
        # the step count has to be an integer.
        num_inference_steps=int(steps),
        guidance_scale=float(guidance),
    )
    return result.images[0].resize((768, 768))
# Build the UI: a prompt box, two sliders, a generate button with its output
# image, and a collapsed accordion that shows a random dataset photo.
with gr.Blocks(title="RealCanvas-MJ4K") as demo:
    gr.Markdown("# 🎨 RealCanvas-MJ4K | Midjourney-level realism under 16 GB")

    with gr.Row():
        prompt_in = gr.Textbox(
            label="Prompt (leave empty for random Midjourney-style prompt)",
            lines=2,
        )

    with gr.Row():
        steps = gr.Slider(1, 8, value=4, step=1, label="Inference steps (SDXL-Lightning)")
        guidance = gr.Slider(0.0, 2.0, value=0.0, step=0.1, label="Guidance scale")

    btn = gr.Button("Generate", variant="primary")
    gallery = gr.Image(type="pil", label="Generated image")

    with gr.Accordion("📸 Random CC12M-4MP sample", open=False):
        cc_btn = gr.Button("Show random CC12M-4MP image")
        cc_out = gr.Image(type="pil", label="Real photo from dataset")

    # Event wiring has to happen while the Blocks context is still open.
    btn.click(generate, [prompt_in, steps, guidance], gallery)
    cc_btn.click(random_cc12m_image, outputs=cc_out)

# Queue requests (at most 8 waiting) and start the server.
demo.queue(max_size=8).launch()