from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel, LMSDiscreteScheduler
import torch
from tqdm.auto import tqdm
from PIL import Image
import gradio as gr

# CLIP tokenizer and text encoder turn the prompt into conditioning embeddings
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# The VAE decodes denoised latents to pixels; the U-Net predicts the noise at each step
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema")
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")

# Sampling hyperparameters (Stable Diffusion v1 defaults)
beta_start, beta_end = 0.00085, 0.012
height = 512
width = 512
num_inference_steps = 70
guidance_scale = 7.5
batch_size = 1

# K-LMS scheduler with the scaled-linear beta schedule Stable Diffusion was trained on
scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear", num_train_timesteps=1000)


def text_enc(prompts, maxlen=None):
    # Tokenize the prompts (padded/truncated to CLIP's max length) and return
    # the text encoder's last hidden state, used to condition the U-Net.
    if maxlen is None:
        maxlen = tokenizer.model_max_length
    inp = tokenizer(prompts, padding="max_length", max_length=maxlen, truncation=True, return_tensors="pt")
    # The tokenizer already returns int64 token ids, which the embedding layer expects
    with torch.no_grad():
        return text_encoder(inp.input_ids)[0]


def do_both(prompts):
    def mk_img(t):
        # Map the decoded image from [-1, 1] to [0, 1], then to a uint8 PIL image
        image = (t / 2 + 0.5).clamp(0, 1).detach().cpu().permute(1, 2, 0).numpy()
        return Image.fromarray((image * 255).round().astype("uint8"))

    def mk_samples(prompts, g=guidance_scale, seed=100, steps=num_inference_steps):
        bs = len(prompts)
        # Classifier-free guidance needs embeddings for the prompt and for an
        # empty ("unconditional") prompt of the same sequence length
        text = text_enc(prompts)
        uncond = text_enc([""] * bs, text.shape[1])
        emb = torch.cat([uncond, text])
        if seed:
            torch.manual_seed(seed)

        # Start from Gaussian noise in latent space (1/8 of the pixel resolution)
        latents = torch.randn((bs, unet.config.in_channels, height // 8, width // 8))
        scheduler.set_timesteps(steps)
        latents = latents.float() * scheduler.init_noise_sigma

        for ts in tqdm(scheduler.timesteps):
            # Run the unconditional and conditional passes through the U-Net as one batch
            inp = scheduler.scale_model_input(torch.cat([latents] * 2), ts)
            with torch.no_grad():
                u, t = unet(inp, ts, encoder_hidden_states=emb).sample.chunk(2)
            # Guidance: push the noise prediction toward the text-conditioned one
            pred = u + g * (t - u)
            latents = scheduler.step(pred, ts, latents).prev_sample

        # Undo the VAE's latent scaling factor (0.18215) and decode to pixel space
        with torch.no_grad():
            return vae.decode(1 / 0.18215 * latents).sample

    images = mk_samples([prompts])
    return mk_img(images[0])
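
# Optional sanity check before launching the UI, as a minimal sketch: the prompt
# below is only an illustrative example, not part of the original script.
# img = do_both("a watercolor painting of a lighthouse at dawn")
# img.save("sample.png")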


# Minimal Gradio demo: type a prompt, get the generated image back
gr.Interface(do_both, gr.Text(), gr.Image(), title='Stable Diffusion model from scratch').launch(share=True, debug=True)