import base64
import math
import sys
from typing import Any, Dict

import numpy as np
import tensorflow as tf
from tensorflow import keras

from keras_cv.models.stable_diffusion.constants import _ALPHAS_CUMPROD
from keras_cv.models.stable_diffusion.diffusion_model import DiffusionModel
from keras_cv.models.stable_diffusion.diffusion_model import DiffusionModelV2
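
# Custom handler for a Hugging Face Inference Endpoint. It runs only the
# reverse-diffusion (denoising) stage of KerasCV's Stable Diffusion: the
# caller sends base64-encoded text-encoder outputs, and the handler returns
# the final base64-encoded latent, which still has to be decoded to pixels
# (e.g. with KerasCV's image decoder) on the caller's side.
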
class EndpointHandler:
    def __init__(self, path="", version="2"):
        self.seed = None
        # KerasCV's Stable Diffusion supports only image sizes that are
        # multiples of 128, so snap the requested resolution accordingly.
        img_height = 512
        img_width = 512
        self.img_height = round(img_height / 128) * 128
        self.img_width = round(img_width / 128) * 128
        self.MAX_PROMPT_LENGTH = 77
        self.version = version
        self.diffusion_model = self._instantiate_diffusion_model(version)
        # _instantiate_diffusion_model returns an error string for an
        # unsupported version; exit immediately rather than serving a
        # half-initialized endpoint.
        if isinstance(self.diffusion_model, str):
            sys.exit(self.diffusion_model)

    def _instantiate_diffusion_model(self, version: str):
        # Download (and cache) the pretrained UNet weights that match the
        # requested Stable Diffusion version, then load them into the
        # corresponding KerasCV diffusion model.
        if version == "1.4":
            diffusion_model_weights_fpath = keras.utils.get_file(
                origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/kcv_diffusion_model.h5",
                file_hash="8799ff9763de13d7f30a683d653018e114ed24a6a819667da4f5ee10f9e805fe",
            )
            diffusion_model = DiffusionModel(
                self.img_height, self.img_width, self.MAX_PROMPT_LENGTH
            )
            diffusion_model.load_weights(diffusion_model_weights_fpath)
            return diffusion_model
        elif version == "2":
            diffusion_model_weights_fpath = keras.utils.get_file(
                origin="https://huggingface.co/ianstenbit/keras-sd2.1/resolve/main/diffusion_model_v2_1.h5",
                file_hash="c31730e91111f98fe0e2dbde4475d381b5287ebb9672b1821796146a25c5132d",
            )
            diffusion_model = DiffusionModelV2(
                self.img_height, self.img_width, self.MAX_PROMPT_LENGTH
            )
            diffusion_model.load_weights(diffusion_model_weights_fpath)
            return diffusion_model
        else:
            return f"v{version} is not supported"

    def _get_initial_diffusion_noise(self, batch_size, seed):
        # Initial latent noise: one (H/8, W/8, 4) tensor per batch element.
        # A seed gives reproducible (stateless) noise; otherwise sample fresh.
        if seed is not None:
            return tf.random.stateless_normal(
                (batch_size, self.img_height // 8, self.img_width // 8, 4),
                seed=[seed, seed],
            )
        else:
            return tf.random.normal(
                (batch_size, self.img_height // 8, self.img_width // 8, 4)
            )

    def _get_initial_alphas(self, timesteps):
        # Cumulative-product alphas for the sampled timesteps, plus the
        # "previous" alpha for each step (1.0 for the very first step).
        alphas = [_ALPHAS_CUMPROD[t] for t in timesteps]
        alphas_prev = [1.0] + alphas[:-1]
        return alphas, alphas_prev

    def _get_timestep_embedding(self, timestep, batch_size, dim=320, max_period=10000):
        # Standard sinusoidal timestep embedding: frequencies decay
        # geometrically from 1 down to 1/max_period, and the embedding is the
        # concatenation of cos and sin of the timestep at each frequency.
        half = dim // 2
        freqs = tf.math.exp(
            -math.log(max_period) * tf.range(0, half, dtype=tf.float32) / half
        )
        args = tf.convert_to_tensor([timestep], dtype=tf.float32) * freqs
        embedding = tf.concat([tf.math.cos(args), tf.math.sin(args)], 0)
        embedding = tf.reshape(embedding, [1, -1])
        return tf.repeat(embedding, batch_size, axis=0)

    def __call__(self, data: Dict[str, Any]) -> str:
        # Parse the request payload. "inputs" carries two base64-encoded
        # float32 tensors: the text-conditioned context and the unconditional
        # (empty-prompt) context produced by the matching text encoder.
        contexts = data.pop("inputs", data)
        batch_size = data.pop("batch_size", 1)

        # v1.4 uses 768-dim CLIP embeddings; v2 uses 1024-dim OpenCLIP ones.
        context = base64.b64decode(contexts[0])
        context = np.frombuffer(context, dtype="float32")
        if self.version == "1.4":
            context = np.reshape(context, (batch_size, 77, 768))
        else:
            context = np.reshape(context, (batch_size, 77, 1024))

        unconditional_context = base64.b64decode(contexts[1])
        unconditional_context = np.frombuffer(unconditional_context, dtype="float32")
        if self.version == "1.4":
            unconditional_context = np.reshape(unconditional_context, (batch_size, 77, 768))
        else:
            unconditional_context = np.reshape(unconditional_context, (batch_size, 77, 1024))

        num_steps = data.pop("num_steps", 25)
        unconditional_guidance_scale = data.pop("unconditional_guidance_scale", 7.5)
        latent = self._get_initial_diffusion_noise(batch_size, self.seed)

        # Iterative reverse diffusion stage
        timesteps = tf.range(1, 1000, 1000 // num_steps)
        alphas, alphas_prev = self._get_initial_alphas(timesteps)
        progbar = keras.utils.Progbar(len(timesteps))
        iteration = 0
        for index, timestep in list(enumerate(timesteps))[::-1]:
            latent_prev = latent  # Set aside the previous latent vector
            t_emb = self._get_timestep_embedding(timestep, batch_size)
            # Classifier-free guidance: predict noise with and without the
            # text conditioning, then push the conditioned prediction away
            # from the unconditional one by the guidance scale.
            unconditional_latent = self.diffusion_model.predict_on_batch(
                [latent, t_emb, unconditional_context]
            )
            latent = self.diffusion_model.predict_on_batch([latent, t_emb, context])
            latent = unconditional_latent + unconditional_guidance_scale * (
                latent - unconditional_latent
            )
            # DDIM-style update: estimate x0 from the guided noise
            # prediction, then step to the previous timestep.
            a_t, a_prev = alphas[index], alphas_prev[index]
            pred_x0 = (latent_prev - math.sqrt(1 - a_t) * latent) / math.sqrt(a_t)
            latent = latent * math.sqrt(1.0 - a_prev) + math.sqrt(a_prev) * pred_x0
            iteration += 1
            progbar.update(iteration)

        # Return the final latent as a base64 string; the caller is
        # responsible for decoding it into an image.
        latent_b64 = base64.b64encode(latent.numpy().tobytes())
        latent_b64str = latent_b64.decode()
        return latent_b64str
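

# ---------------------------------------------------------------------------
# Minimal local smoke test (a sketch, not part of the endpoint contract).
# It assumes a v2 handler and fakes the text-encoder outputs with random
# float32 arrays of the expected (batch, 77, 1024) shape, so the resulting
# latent is meaningless; the point is only to exercise the request/response
# encoding. A real client would send genuine encoder outputs and then decode
# the returned latent to pixels (e.g. with KerasCV's decoder, not shown).
if __name__ == "__main__":
    batch_size = 1
    rng = np.random.default_rng(0)
    fake_context = rng.standard_normal((batch_size, 77, 1024)).astype("float32")
    fake_uncond = rng.standard_normal((batch_size, 77, 1024)).astype("float32")
    payload = {
        "inputs": [
            base64.b64encode(fake_context.tobytes()).decode(),
            base64.b64encode(fake_uncond.tobytes()).decode(),
        ],
        "batch_size": batch_size,
        "num_steps": 5,  # keep the loop short for a smoke test
    }
    handler = EndpointHandler(version="2")
    latent_b64str = handler(payload)
    latent = np.frombuffer(base64.b64decode(latent_b64str), dtype="float32")
    # 512x512 images give 64x64x4 latents (spatial dims downsampled by 8).
    print("latent shape:", latent.reshape(batch_size, 64, 64, 4).shape)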