markhristov's picture
No Nvidia GPU...
ac3f08e
raw
history blame
2.94 kB
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers import LMSDiscreteScheduler
import torch
from tqdm.auto import tqdm
from PIL import Image
import gradio as gr
#from IPython.display import display
# CLIP tokenizer and text encoder, both loaded from the same checkpoint.
# NOTE(review): every model below is loaded with float16 weights and kept on
# the CPU; some PyTorch builds lack half-precision CPU kernels -- confirm this
# runs on the target environment.
tokenizer = CLIPTokenizer.from_pretrained(
    "openai/clip-vit-large-patch14",
    torch_dtype=torch.float16,
)
text_encoder = CLIPTextModel.from_pretrained(
    "openai/clip-vit-large-patch14",
    torch_dtype=torch.float16,
).to("cpu")

# Here we use a different VAE to the original release, which has been
# fine-tuned for more steps.
vae = AutoencoderKL.from_pretrained(
    "stabilityai/sd-vae-ft-ema",
    torch_dtype=torch.float16,
).to("cpu")
unet = UNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    subfolder="unet",
    torch_dtype=torch.float16,
).to("cpu")

# Noise-schedule endpoints handed to the scheduler below.
beta_start = 0.00085
beta_end = 0.012

# Output image size in pixels.
height = 512
width = 512

# Sampling settings kept at module level.
num_inference_steps = 70
guidance_scale = 7.5
batch_size = 1

scheduler = LMSDiscreteScheduler(
    beta_start=beta_start,
    beta_end=beta_end,
    beta_schedule="scaled_linear",
    num_train_timesteps=1000,
)
#prompt = ["a photograph of an astronaut riding a horse"]
def text_enc(prompts, maxlen=None):
    """Tokenize ``prompts`` and return their text-encoder embeddings.

    Args:
        prompts: Text passed straight to the module-level ``tokenizer``.
        maxlen: Length every sequence is padded/truncated to. When ``None``,
            ``tokenizer.model_max_length`` is used.

    Returns:
        The first element of the ``text_encoder`` output, converted with
        ``.half()``.
    """
    max_length = tokenizer.model_max_length if maxlen is None else maxlen
    tokens = tokenizer(
        prompts,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    )
    input_ids = tokens.input_ids.to("cpu")
    encoder_output = text_encoder(input_ids)
    return encoder_output[0].half()
def do_both(prompts):
    """Generate one image for a single text prompt.

    Args:
        prompts: The prompt text. It is wrapped in a one-element list before
            sampling, so exactly one image is produced.

    Returns:
        A PIL image built from the first decoded sample.
    """
    # Factor applied to the latents right before VAE decoding (the latent
    # scaling constant used with Stable Diffusion v1 checkpoints).
    latent_scale = 1 / 0.18215

    def mk_img(t):
        """Convert one decoded image tensor into an 8-bit PIL image."""
        # Rescale with t/2 + 0.5 and clamp to [0, 1]; permute(1, 2, 0) moves
        # the leading axis last (assumes t is channels-first -- TODO confirm).
        image = (t / 2 + 0.5).clamp(0, 1).detach().cpu().permute(1, 2, 0).numpy()
        return Image.fromarray((image * 255).round().astype("uint8"))

    def mk_samples(prompts, g=7.5, seed=100, steps=70):
        """Run the guided denoising loop and decode the final latents.

        Args:
            prompts: List of prompt strings; its length is the batch size.
            g: Guidance weight in ``pred = u + g * (t - u)``, where ``u`` is
                the prediction for the empty prompt and ``t`` the prediction
                for the real prompt.
            seed: Value passed to ``torch.manual_seed``. ``None`` skips
                seeding; ``0`` is a valid seed.
            steps: Number of timesteps given to ``scheduler.set_timesteps``.

        Returns:
            The ``.sample`` tensor returned by ``vae.decode``.
        """
        # Pure inference: keep autograd off for the text encoder, the UNet,
        # the scheduler arithmetic and the VAE decoder alike.
        with torch.no_grad():
            bs = len(prompts)
            text = text_enc(prompts)
            # The empty-prompt embeddings are padded to the same length.
            uncond = text_enc([""] * bs, text.shape[1])
            emb = torch.cat([uncond, text])

            # Explicit None check so that seed=0 is honoured.
            if seed is not None:
                torch.manual_seed(seed)
            latents = torch.randn(
                (bs, unet.config.in_channels, height // 8, width // 8)
            )
            scheduler.set_timesteps(steps)
            latents = latents.to("cpu").half() * scheduler.init_noise_sigma

            for ts in tqdm(scheduler.timesteps):
                # Duplicate the latents so one UNet call covers both the
                # empty-prompt half and the prompt half of `emb`.
                inp = scheduler.scale_model_input(torch.cat([latents] * 2), ts)
                u, t = unet(inp, ts, encoder_hidden_states=emb).sample.chunk(2)
                pred = u + g * (t - u)
                latents = scheduler.step(pred, ts, latents).prev_sample

            return vae.decode(latent_scale * latents).sample

    images = mk_samples([prompts])
    # One prompt in, one sample out: convert the first decoded image.
    return mk_img(images[0])
# do_both(prompt)
# images = mk_samples(prompt)
#iface = gr.Interface(fn=do_both, inputs=gr.inputs.Textbox(lines=2, label="Enter text prompt"), outputs=gr.outputs.Image(type="numpy", label="Generated Image")).launch()
# Build the Gradio UI (text box in, image out) around do_both and start it.
gr.Interface(
    fn=do_both,
    inputs=gr.Text(),
    outputs=gr.Image(),
    title='Stable Diffusion model from scratch',
).launch(share=True, debug=True)
# for img in images: display(mk_img(img))