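# Stable Diffusion v1 sampling loop assembled from its individual Hugging Face
# components (CLIP text encoder, VAE, UNet, LMS scheduler) and served via Gradio.
# Everything runs on CPU as written; moving the models and latents to a GPU with
# .to("cuda") is the usual way to make sampling practical.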
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers import LMSDiscreteScheduler
import torch
from tqdm.auto import tqdm
from PIL import Image
import gradio as gr
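# Load the pipeline components individually: CLIP handles text conditioning,
# the EMA-finetuned VAE decodes latents back to pixels, and the v1-4 UNet
# predicts the noise at each denoising step.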
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema")
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
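# Sampling hyperparameters; the beta range below matches the "scaled_linear"
# schedule Stable Diffusion v1 was trained with.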
beta_start, beta_end = 0.00085, 0.012
height = 512
width = 512
num_inference_steps = 70
guidance_scale = 7.5
batch_size = 1
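# LMS (linear multistep) scheduler, configured with the model's training
# schedule of 1000 timesteps.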
scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear", num_train_timesteps=1000)
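# Encode a batch of prompts into CLIP hidden states, the UNet's conditioning input.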
def text_enc(prompts, maxlen=None):
    if maxlen is None:
        maxlen = tokenizer.model_max_length
    inp = tokenizer(prompts, padding="max_length", max_length=maxlen, truncation=True, return_tensors="pt")
    # the tokenizer already returns int64 token ids, which is what the embedding layer expects
    return text_encoder(inp.input_ids)[0]
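# Gradio callback: run the full classifier-free-guidance sampling loop for one
# prompt and return a PIL image.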
def do_both(prompts):
    def mk_img(t):
        # map the decoded tensor from [-1, 1] to [0, 1], then to a uint8 PIL image
        image = (t/2+0.5).clamp(0,1).detach().cpu().permute(1, 2, 0).numpy()
        return Image.fromarray((image*255).round().astype("uint8"))

    def mk_samples(prompts, g=guidance_scale, seed=100, steps=num_inference_steps):
        bs = len(prompts)
        # classifier-free guidance needs two embeddings: the prompt and an empty ("unconditional") prompt
        text = text_enc(prompts)
        uncond = text_enc([""] * bs, text.shape[1])
        emb = torch.cat([uncond, text])
        if seed:
            torch.manual_seed(seed)
        # start from pure Gaussian noise in the 8x-downsampled latent space
        latents = torch.randn((bs, unet.config.in_channels, height//8, width//8))
        scheduler.set_timesteps(steps)
        latents = latents.float() * scheduler.init_noise_sigma
        for ts in tqdm(scheduler.timesteps):
            # duplicate the latents so the unconditional and conditional passes run in one batch
            inp = scheduler.scale_model_input(torch.cat([latents] * 2), ts)
            with torch.no_grad():
                u, t = unet(inp, ts, encoder_hidden_states=emb).sample.chunk(2)
            # push the noise prediction away from the unconditional branch by the guidance scale
            pred = u + g*(t-u)
            latents = scheduler.step(pred, ts, latents).prev_sample
        # undo the SD v1 latent scaling factor (0.18215) before decoding with the VAE
        with torch.no_grad():
            return vae.decode(1 / 0.18215 * latents).sample

    # the Gradio callback receives a single prompt string, so sample a batch of one
    images = mk_samples([prompts])
    return mk_img(images[0])
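# share=True serves the app over a temporary public gradio.live URL;
# debug=True blocks and prints errors to the console (handy in notebooks).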
gr.Interface(do_both, gr.Text(), gr.Image(), title='Stable Diffusion model from scratch').launch(share=True, debug=True)