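"""FLUX GIF Generator.

A Gradio app that renders a horizontal strip of consecutive stills with
FLUX.1-dev, slices the strip into square frames, and assembles them into a
looping GIF. Korean prompts are detected and translated to English first.
"""
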
import random

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import FluxPipeline
from diffusers.utils import export_to_gif
from PIL import Image
from transformers import pipeline
# -------------------------
# Configuration constants
# -------------------------
FRAMES = 4            # number of stills laid out horizontally
DEFAULT_HEIGHT = 256  # per-frame size (px)
DEFAULT_FPS = 8       # smoother playback than the original 4 fps
MAX_SEED = np.iinfo(np.int32).max
# -------------------------
# Model initialisation
# -------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.float16,  # fp16 keeps more mantissa bits than bfloat16, at the cost of dynamic range
).to(device)
# English is the primary UI language, but Korean prompts are still accepted & translated.
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
# -------------------------
# Helper functions
# -------------------------
def split_image(input_image: Image.Image, frame_size: int) -> list[Image.Image]:
    """Cut a wide horizontal strip into FRAMES equal square frames."""
    return [
        input_image.crop((i * frame_size, 0, (i + 1) * frame_size, frame_size))
        for i in range(FRAMES)
    ]


def translate_to_english(text: str) -> str:
    """Translate Korean text to English; callers decide when this is needed."""
    return translator(text)[0]["translation_text"]
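
# Quick illustration of the crop arithmetic (kept as a comment so it does not
# run at import time; `strip` is a hypothetical example image):
#
#     strip = Image.new("RGB", (FRAMES * 64, 64))
#     frames = split_image(strip, 64)
#     assert len(frames) == FRAMES and frames[0].size == (64, 64)
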
@spaces.GPU()
def predict(
    prompt: str,
    seed: int = 42,
    randomize_seed: bool = False,
    guidance_scale: float = 7.0,
    num_inference_steps: int = 40,
    height: int = DEFAULT_HEIGHT,
    fps: int = DEFAULT_FPS,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
):
    # 1) Language handling: detect Hangul (Compatibility Jamo U+3131-U+318E,
    #    Syllables U+AC00-U+D7A3) and translate Korean prompts to English.
    if any("\u3131" <= ch <= "\u318E" or "\uAC00" <= ch <= "\uD7A3" for ch in prompt):
        prompt = translate_to_english(prompt)

    # 2) Prompt template: ask the model for a horizontal strip of consecutive stills.
    prompt_template = (
        f"A side-by-side {FRAMES} frame image showing consecutive stills from a looped gif moving left to right. "
        f"The gif is of {prompt}."
    )

    # 3) Seed control
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    width = FRAMES * height  # strip width keeps each frame square

    # 4) Generation
    image = pipe(
        prompt=prompt_template,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        num_images_per_prompt=1,
        generator=torch.Generator(device=device).manual_seed(seed),
        height=height,
        width=width,
    ).images[0]

    # 5) Slice the strip into frames and assemble the GIF.
    gif_path = export_to_gif(split_image(image, height), "flux.gif", fps=fps)
    return gif_path, image, seed
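
# Sketch of calling predict() outside the UI (hypothetical usage; requires the
# GPU-backed pipeline above to be initialised):
#
#     gif_path, strip, used_seed = predict("panda shaking its hips", randomize_seed=True)
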
# -------------------------
# Interface
# -------------------------
css = """
#col-container {max-width: 820px; margin: 0 auto;}
footer {visibility: hidden;}
"""
examples = [
    "cat lazily swinging its paws in mid-air",
    "panda shaking its hips",
    "flower blooming in timelapse",
]
with gr.Blocks(theme="soft", css=css) as demo:
    gr.Markdown("<h1 style='text-align:center'>FLUX GIF Generator</h1>")
    with gr.Column(elem_id="col-container"):
        # Prompt row
        with gr.Row():
            prompt = gr.Text(
                label="", show_label=False, max_lines=1, placeholder="Enter your prompt here…"
            )
            submit = gr.Button("Generate", scale=0)

        # Outputs
        output_gif = gr.Image(label="", show_label=False)
        output_stills = gr.Image(label="", show_label=False, elem_id="stills")

        # Advanced controls
        with gr.Accordion("Advanced settings", open=False):
            seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale", minimum=1, maximum=15, step=0.1, value=7.0
                )
                num_inference_steps = gr.Slider(
                    label="Inference steps", minimum=10, maximum=60, step=1, value=40
                )
            with gr.Row():
                height = gr.Slider(
                    label="Frame size (px)", minimum=256, maximum=512, step=64, value=DEFAULT_HEIGHT
                )
                fps = gr.Slider(
                    label="GIF FPS", minimum=4, maximum=20, step=1, value=DEFAULT_FPS
                )

        # Example prompts
        gr.Examples(
            examples=examples,
            fn=predict,
            inputs=[prompt],
            outputs=[output_gif, output_stills, seed],
            cache_examples="lazy",
        )

    # Event wiring: both the Generate button and pressing Enter in the textbox
    # trigger generation.
    gr.on(
        triggers=[submit.click, prompt.submit],
        fn=predict,
        inputs=[
            prompt,
            seed,
            randomize_seed,
            guidance_scale,
            num_inference_steps,
            height,
            fps,
        ],
        outputs=[output_gif, output_stills, seed],
    )

if __name__ == "__main__":
    demo.launch()