from typing import Dict, List, Any

import base64
import numpy as np
import tensorflow as tf
from tensorflow import keras
from keras_cv.models.stable_diffusion.decoder import Decoder

class EndpointHandler:
    def __init__(self, path=""):
        # Round the target resolution to a multiple of 128
        # (512 already is, so the rounding is a no-op here).
        img_height = 512
        img_width = 512
        img_height = round(img_height / 128) * 128
        img_width = round(img_width / 128) * 128

        # Build the Stable Diffusion VAE decoder and load pretrained weights
        # downloaded from the Hugging Face Hub, verified against the file hash.
        self.decoder = Decoder(img_height, img_width)
        decoder_weights_fpath = keras.utils.get_file(
            origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/kcv_decoder.h5",
            file_hash="ad350a65cc8bc4a80c8103367e039a3329b4231c2469a1093869a345f55b1962",
        )
        self.decoder.load_weights(decoder_weights_fpath)

    def __call__(self, data: Dict[str, Any]) -> str:
        # Get inputs: a base64-encoded float32 latent buffer and an optional batch size.
        latent = data.pop("inputs", data)
        batch_size = data.pop("batch_size", 1)

        # Rebuild the latent tensor with shape (batch_size, 64, 64, 4).
        latent = base64.b64decode(latent)
        latent = np.frombuffer(latent, dtype="float32")
        latent = np.reshape(latent, (batch_size, 64, 64, 4))

        # Decode the latents to images and rescale from [-1, 1] to [0, 255].
        decoded = self.decoder.predict_on_batch(latent)
        decoded = ((decoded + 1) / 2) * 255
        images = np.clip(decoded, 0, 255).astype("uint8")

        # Return the raw uint8 image bytes as a base64 string.
        images_b64 = base64.b64encode(images.tobytes())
        images_b64str = images_b64.decode()

        return images_b64str
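
# --- Usage sketch (illustrative only, not part of the handler contract) ---
# A minimal local test, assuming the endpoint is called with a payload of the
# form {"inputs": <base64 float32 latent buffer>, "batch_size": <int>} and that
# the decoder emits 512x512 RGB images. The random latent below is a stand-in
# for the output of an actual diffusion loop.
if __name__ == "__main__":
    handler = EndpointHandler()

    # Fake (1, 64, 64, 4) latent batch, base64-encoded the same way a client would.
    fake_latent = np.random.randn(1, 64, 64, 4).astype("float32")
    payload = {
        "inputs": base64.b64encode(fake_latent.tobytes()).decode(),
        "batch_size": 1,
    }

    images_b64str = handler(payload)

    # Recover the decoded batch from the response for a quick sanity check.
    images = np.frombuffer(base64.b64decode(images_b64str), dtype="uint8")
    images = images.reshape(1, 512, 512, 3)
    print(images.shape, images.dtype)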