# NOTE(review): the three lines below are blob-viewer extraction residue
# (file size, git blame hashes, line-number gutter), not source code.
# Commented out so the module parses as Python.
# File size: 1,400 Bytes
# d69a021 81ec701 d69a021 282e8e5 d69a021 28a20cc abe0476 d69a021 d9a39c6 d69a021 d9a39c6 d69a021 756d72b |
# 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
from typing import Dict, List, Any
import base64
import numpy as np
import tensorflow as tf
from tensorflow import keras
from keras_cv.models.stable_diffusion.decoder import Decoder
class EndpointHandler:
    """Inference-endpoint handler that decodes Stable Diffusion latents.

    Loads the keras_cv Stable Diffusion ``Decoder`` with pretrained weights
    once at construction, then on each call turns a base64-encoded float32
    latent tensor into the base64-encoded bytes of a uint8 image array.
    """

    # Per-sample latent shape expected from the SD UNet (see reshape below).
    LATENT_SHAPE = (64, 64, 4)

    def __init__(self, path="", img_height=512, img_width=512):
        """Build the decoder and download its pretrained weights.

        Args:
            path: Unused; kept for the inference-endpoint handler contract.
            img_height: Target image height in pixels. Rounded to the
                nearest multiple of 128 before building the decoder.
            img_width: Target image width in pixels, rounded the same way.
        """
        # The decoder requires dimensions that are multiples of 128;
        # round rather than reject so close values still work.
        img_height = round(img_height / 128) * 128
        img_width = round(img_width / 128) * 128
        self.decoder = Decoder(img_height, img_width)
        # get_file caches the download and verifies it against the sha256
        # hash, so repeated construction does not re-fetch the weights.
        decoder_weights_fpath = keras.utils.get_file(
            origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/kcv_decoder.h5",
            file_hash="ad350a65cc8bc4a80c8103367e039a3329b4231c2469a1093869a345f55b1962",
        )
        self.decoder.load_weights(decoder_weights_fpath)

    def __call__(self, data: Dict[str, Any]) -> str:
        """Decode a base64 latent payload into a base64 image buffer.

        Args:
            data: Request payload. ``data["inputs"]`` holds the base64
                encoding of the raw bytes of a float32 array of shape
                ``(batch_size, 64, 64, 4)``; ``data["batch_size"]``
                (default 1) gives the leading dimension.

        Returns:
            Base64-encoded (then ASCII-decoded) bytes of the decoded
            uint8 image array.

        Raises:
            ValueError: if the decoded buffer does not contain exactly
                ``batch_size * 64 * 64 * 4`` float32 values.
        """
        latent_b64 = data.pop("inputs", data)
        batch_size = data.pop("batch_size", 1)
        latent_bytes = base64.b64decode(latent_b64)
        latent = np.frombuffer(latent_bytes, dtype="float32")
        # Fail early with an actionable message instead of letting
        # np.reshape raise an opaque size-mismatch error.
        expected = batch_size * int(np.prod(self.LATENT_SHAPE))
        if latent.size != expected:
            raise ValueError(
                f"Expected {expected} float32 values for batch_size="
                f"{batch_size}, got {latent.size}"
            )
        latent = np.reshape(latent, (batch_size,) + self.LATENT_SHAPE)
        decoded = self.decoder.predict_on_batch(latent)
        # Decoder output is assumed to lie in [-1, 1] (see the clip below);
        # map it to [0, 255] and truncate to uint8 pixels.
        decoded = ((decoded + 1) / 2) * 255
        images = np.clip(decoded, 0, 255).astype("uint8")
        return base64.b64encode(images.tobytes()).decode()
# |  (trailing blob-viewer residue, commented out)