import gradio as gr
from dataclasses import dataclass
import spaces
import torch
from huggingface_hub import hf_hub_download
from diffusers import StableDiffusionXLPipeline, FluxPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
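

# Demo configuration. Defaults mirror the EcoDiff evaluation settings; only
# prompt, seed, and num_intervention_steps are wired up to the UI below.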
@dataclass
class GradioArgs:
    seed: list = None
    prompt: str = None
    mix_precision: str = "bf16"
    num_intervention_steps: int = 50
    model: str = "sdxl"
    binary: bool = False
    masking: str = "binary"
    scope: str = "global"
    ratio: list = None
    width: int = None
    height: int = None
    epsilon: float = 0.0
    lambda_threshold: float = 0.001

    def __post_init__(self):
        if self.seed is None:
            self.seed = [44]
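

# Load the pruned pipeline from the EcoDiff checkpoints on the Hub, then reload
# the original pipeline so the demo can compare the two side by side. Both stay
# on CPU until generation time.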
def binary_mask_eval(args, model):
    model = model.lower()
    # Load the pruned pipeline. The EcoDiff checkpoint is a pickled full module,
    # so weights_only=False is needed on newer PyTorch, where the default flipped.
    if model == "sdxl":
        pruned_pipe = StableDiffusionXLPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
        ).to("cpu")
        pruned_pipe.unet = torch.load(
            hf_hub_download("zhangyang-0123/EcoDiffPrunedModels", "model/sdxl/sdxl.pkl"),
            map_location="cpu",
            weights_only=False,
        )
    elif model == "flux":
        pruned_pipe = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
        ).to("cpu")
        pruned_pipe.transformer = torch.load(
            hf_hub_download("zhangyang-0123/EcoDiffPrunedModels", "model/flux/flux.pkl"),
            map_location="cpu",
            weights_only=False,
        )
    # Reload the original, unpruned pipeline for comparison
    if model == "sdxl":
        pipe = StableDiffusionXLPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
        ).to("cpu")
    elif model == "flux":
        pipe = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
        ).to("cpu")
    print("prune complete")
    return pipe, pruned_pipe
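

# Generate with both pipelines using the same prompt and seed so the outputs
# are directly comparable. The @spaces.GPU decorator is an assumption: it is
# how ZeroGPU Spaces request a GPU for the duration of a call, and it would
# explain the otherwise unused `import spaces` above.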
@spaces.GPU
def generate_images(prompt, seed, steps, pipe, pruned_pipe):
    pipe.to("cuda")
    pruned_pipe.to("cuda")
    # Re-seed before each run so both models start from the same noise
    generator = torch.Generator("cuda").manual_seed(seed)
    original_image = pipe(prompt=prompt, generator=generator, num_inference_steps=steps).images[0]
    generator = torch.Generator("cuda").manual_seed(seed)
    ecodiff_image = pruned_pipe(prompt=prompt, generator=generator, num_inference_steps=steps).images[0]
    return original_image, ecodiff_image
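

# Gradio callbacks. Initialization and generation are split so the slow
# pipeline loading runs once, while generation can be repeated on the same state.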
def on_prune_click(prompt, seed, steps, model):
    args = GradioArgs(prompt=prompt, seed=[seed], num_intervention_steps=steps)
    pipe, pruned_pipe = binary_mask_eval(args, model)
    return pipe, pruned_pipe, [("Model Initialized", "green")]


def on_generate_click(prompt, seed, steps, pipe, pruned_pipe):
    original_image, ecodiff_image = generate_images(prompt, seed, steps, pipe, pruned_pipe)
    return original_image, ecodiff_image
| header = """ | |
| # 🌱 Text-to-Image Generation with EcoDiff Pruned Models | |
| """ | |
| header_2 = """ | |
| <div style="text-align: center; display: flex; justify-content: left; gap: 5px;"> | |
| <a href="https://arxiv.org/abs/2412.02852"><img src="https://img.shields.io/badge/ariXv-Paper-A42C25.svg" alt="arXiv"></a> | |
| <a href="https://huggingface.co/zhangyang-0123/EcoDiffPrunedModels"><img src="https://img.shields.io/badge/🤗-Model-ffbd45.svg" alt="HuggingFace"></a> | |
| </div> | |
| """ | |
| header_3 = """ | |
| <div style="text-align: center; display: flex; justify-content: left; gap: 5px;"> | |
| For âš¡ <b>faster</b> âš¡ DEMO on one model only, please visit | |
| <a href="https://huggingface.co/spaces/zhangyang-0123/EcoDiff-SD-XL"><img alt="Static Badge" src="https://img.shields.io/badge/SDXL-fedcba.svg"></a> | |
| <a href="https://huggingface.co/spaces/zhangyang-0123/EcoDiff-FLUX-Schnell"><img alt="Static Badge" src="https://img.shields.io/badge/FLUX-fgdfba"></a> | |
| </div> | |
| """ | |
def create_demo():
    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown(header)
        with gr.Row():
            gr.HTML(header_2)
        with gr.Row():
            gr.HTML(header_3)
        with gr.Row():
            gr.Markdown(
                """
                **Note: please initialize the models before generating images. Loading may take a while.**
                """
            )
        with gr.Row():
            model_choice = gr.Radio(choices=["SDXL", "FLUX"], value="SDXL", label="Model", scale=2)
            pruning_ratio = gr.Text("20% pruning ratio for SDXL and FLUX", label="Pruning Ratio", scale=2)
            status_label = gr.HighlightedText(label="Model Status", value=[("Model Not Initialized", "red")], scale=1)
            prune_btn = gr.Button("Initialize Original and Pruned Models", variant="primary", scale=1)
        with gr.Row():
            gr.Markdown(
                """
                **Generate images with the original model and the pruned model. This may take up to one minute because the GPU is allocated dynamically.**

                **Note: FLUX is pruned from the step-distilled FLUX.1-schnell, so use 5 steps (instead of 50) for FLUX generation.**
                """
            )
        with gr.Row():
            prompt = gr.Textbox(
                label="Prompt",
                value="A clock tower floating in a sea of clouds",
                scale=3,
            )
            seed = gr.Number(label="Seed", value=44, precision=0, scale=1)
            steps = gr.Slider(
                label="Number of Steps",
                minimum=1,
                maximum=100,
                value=50,
                step=1,
                scale=1,
            )
            generate_btn = gr.Button("Generate Images")
        gr.Examples(
            examples=[
                "A clock tower floating in a sea of clouds",
                "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
                "An astronaut riding a green horse",
                "A delicious ceviche cheesecake slice",
                "A sprawling cyberpunk metropolis at night, with towering skyscrapers emitting neon lights of every color, holographic billboards advertising alien languages",
            ],
            inputs=[prompt],
        )
        with gr.Row():
            original_output = gr.Image(label="Original Output")
            ecodiff_output = gr.Image(label="EcoDiff Output")

        # Pipelines live in session state so they persist across callbacks
        pipe_state = gr.State(None)
        pruned_pipe_state = gr.State(None)

        prompt.submit(
            fn=on_generate_click,
            inputs=[prompt, seed, steps, pipe_state, pruned_pipe_state],
            outputs=[original_output, ecodiff_output],
        )
        prune_btn.click(
            fn=on_prune_click,
            inputs=[prompt, seed, steps, model_choice],
            outputs=[pipe_state, pruned_pipe_state, status_label],
        )
        generate_btn.click(
            fn=on_generate_click,
            inputs=[prompt, seed, steps, pipe_state, pruned_pipe_state],
            outputs=[original_output, ecodiff_output],
        )
    return demo


if __name__ == "__main__":
    demo = create_demo()
    demo.launch(share=True)