import os
import random

import gradio as gr
import torch
from diffusers import DiffusionPipeline
from transformers import CLIPTextModel, CLIPTokenizer
# Configuration - Using Flux Model
MODEL_ID = "CompVis/Flux-Pro"
MODEL_CACHE = "model_cache"
os.makedirs(MODEL_CACHE, exist_ok=True)
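# NOTE: "CompVis/Flux-Pro" does not appear to be a public model repo on the
# Hugging Face Hub, so the from_pretrained() calls below would fail at startup.
# A public FLUX checkpoint such as "black-forest-labs/FLUX.1-schnell" may be
# what was intended; that substitution is an assumption, so MODEL_ID is left
# unchanged here.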
def get_pipeline():
    # Load Flux model components
    text_encoder = CLIPTextModel.from_pretrained(
        MODEL_ID,
        subfolder="text_encoder",
        cache_dir=MODEL_CACHE
    )
    tokenizer = CLIPTokenizer.from_pretrained(
        MODEL_ID,
        subfolder="tokenizer",
        cache_dir=MODEL_CACHE
    )

    # Create pipeline
    pipe = DiffusionPipeline.from_pretrained(
        MODEL_ID,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        cache_dir=MODEL_CACHE,
        torch_dtype=torch.float32,
        safety_checker=None
    )

    # CPU optimizations
    pipe = pipe.to("cpu")
    pipe.enable_attention_slicing()
    return pipe
# Load model once at startup
pipeline = get_pipeline()
def generate_image(
    prompt: str,
    negative_prompt: str = "",
    width: int = 768,
    height: int = 768,
    seed: int = -1,
    guidance_scale: float = 7.5,
    num_inference_steps: int = 25
):
    # gr.Number delivers floats; torch.Generator.manual_seed() needs an int
    seed = int(seed)
    if seed == -1:
        seed = random.randint(0, 2147483647)
    generator = torch.Generator(device="cpu").manual_seed(seed)

    with torch.no_grad():
        image = pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator
        ).images[0]

    return image, seed
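# Optional quick check (illustrative sketch, not part of the app): generate_image()
# can be called directly, e.g. to verify the pipeline before wiring up the UI.
# The prompt and file name below are arbitrary examples; uncomment to run, and
# note that full-pipeline inference on CPU is very slow.
# test_image, test_seed = generate_image(
#     "a watercolor painting of a fox", width=512, height=512, num_inference_steps=15
# )
# test_image.save("smoke_test.png")
# print(f"Used seed: {test_seed}")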
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# FLUX-Pro Image Generator")

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", lines=3)
            negative_prompt = gr.Textbox(label="Negative Prompt", value="blurry, low quality")
            generate_btn = gr.Button("Generate", variant="primary")

            with gr.Accordion("Advanced", open=False):
                width = gr.Slider(512, 1024, value=768, step=64, label="Width")
                height = gr.Slider(512, 1024, value=768, step=64, label="Height")
                guidance = gr.Slider(1.0, 15.0, value=7.5, step=0.5, label="Guidance")
                steps = gr.Slider(15, 50, value=25, step=1, label="Steps")
                seed = gr.Number(label="Seed", value=-1)

        with gr.Column():
            output_image = gr.Image(label="Result", type="pil")
            used_seed = gr.Textbox(label="Used Seed")

    generate_btn.click(
        generate_image,
        inputs=[prompt, negative_prompt, width, height, seed, guidance, steps],
        outputs=[output_image, used_seed]
    )
if __name__ == "__main__":
    demo.launch()