from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers import LMSDiscreteScheduler
import torch
from tqdm.auto import tqdm
from PIL import Image
import gradio as gr
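# Load the pretrained building blocks of Stable Diffusion v1-4: the CLIP
# tokenizer and text encoder for the prompt, the fine-tuned VAE for decoding
# latents to pixels, and the U-Net that predicts noise at each step.
# Note: everything stays on CPU here; move the models (and the latents below)
# to a GPU for practical generation speed.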
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema")
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
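# Sampling hyperparameters; the beta range matches the SD v1 training schedule.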
beta_start, beta_end = 0.00085, 0.012
height = 512
width = 512
num_inference_steps = 70
guidance_scale = 7.5
batch_size = 1
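# LMS scheduler configured with the same scaled-linear beta schedule the model was trained with.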
scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear", num_train_timesteps=1000)
def text_enc(prompts, maxlen=None):
    """Tokenize a list of prompts and return the CLIP text embeddings."""
    if maxlen is None:
        maxlen = tokenizer.model_max_length
    inp = tokenizer(prompts, padding="max_length", max_length=maxlen,
                    truncation=True, return_tensors="pt")
    with torch.no_grad():
        return text_encoder(inp.input_ids)[0]
def do_both(prompts):
    def mk_img(t):
        # Map a decoded image tensor from [-1, 1] to a PIL image.
        image = (t / 2 + 0.5).clamp(0, 1).detach().cpu().permute(1, 2, 0).numpy()
        return Image.fromarray((image * 255).round().astype("uint8"))

    def mk_samples(prompts, g=guidance_scale, seed=100, steps=num_inference_steps):
        bs = len(prompts)
        # Conditional and unconditional (empty-prompt) embeddings for classifier-free guidance.
        text = text_enc(prompts)
        uncond = text_enc([""] * bs, text.shape[1])
        emb = torch.cat([uncond, text])
        if seed:
            torch.manual_seed(seed)
        # Start from pure Gaussian noise in the latent space (image size / 8 per side).
        latents = torch.randn((bs, unet.config.in_channels, height // 8, width // 8))
        scheduler.set_timesteps(steps)
        latents = latents.float() * scheduler.init_noise_sigma
        for i, ts in enumerate(tqdm(scheduler.timesteps)):
            # Run the U-Net on the doubled batch, then combine the unconditional
            # and conditional noise predictions with the guidance scale.
            inp = scheduler.scale_model_input(torch.cat([latents] * 2), ts)
            with torch.no_grad():
                u, t = unet(inp, ts, encoder_hidden_states=emb).sample.chunk(2)
            pred = u + g * (t - u)
            latents = scheduler.step(pred, ts, latents).prev_sample
        # Rescale out of latent space and decode to pixels with the VAE.
        with torch.no_grad():
            return vae.decode(1 / 0.18215 * latents).sample

    images = mk_samples([prompts])
    return mk_img(images[0])
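# Wire the generator into a simple Gradio demo: text prompt in, image out.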
gr.Interface(do_both, gr.Text(), gr.Image(), title='Stable Diffusion model from scratch').launch(share=True, debug=True)