File size: 3,919 Bytes
dc39a0d 4c8e231 e0c0641 dc39a0d e06cb5f 299f5c7 dc39a0d 299f5c7 dc39a0d 299f5c7 4c8e231 dc39a0d 299f5c7 dc39a0d e0c0641 dc39a0d d063207 dc39a0d e0c0641 dc39a0d e0c0641 6f23e96 a57f986 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
from typing import Dict, List, Any
import sys
import base64
import tensorflow as tf
from tensorflow import keras
from keras_cv.models.stable_diffusion.text_encoder import TextEncoder
from keras_cv.models.stable_diffusion.text_encoder import TextEncoderV2
from keras_cv.models.stable_diffusion.clip_tokenizer import SimpleTokenizer
from keras_cv.models.stable_diffusion.constants import _UNCONDITIONAL_TOKENS
class EndpointHandler():
    """Inference-endpoint handler that turns a text prompt into Stable
    Diffusion conditioning tensors.

    ``__call__`` tokenizes and encodes the prompt with the CLIP text encoder
    for the requested Stable Diffusion version, builds the matching
    unconditional ("empty prompt") context, and returns both as base64-encoded
    raw tensor bytes.
    """

    # CLIP end-of-text token id; prompts are right-padded with it up to
    # MAX_PROMPT_LENGTH (Stable Diffusion convention).
    END_OF_TEXT_TOKEN = 49407

    def __init__(self, path="", version="2"):
        """Build the handler.

        Args:
            path: unused; kept for the Hugging Face custom-handler
                interface, which passes a repository path.
            version: Stable Diffusion version, "1.4" or "2".

        Raises:
            ValueError: if ``version`` is not supported.
        """
        self.MAX_PROMPT_LENGTH = 77
        # Raises ValueError for unsupported versions instead of the old
        # return-error-string + sys.exit() pattern, which would have killed
        # the whole serving process from inside a constructor.
        self.text_encoder = self._instantiate_text_encoder(version)
        self.tokenizer = SimpleTokenizer()
        # Position ids [0, 1, ..., 76]; identical for every encoder call.
        self.pos_ids = tf.convert_to_tensor(
            [list(range(self.MAX_PROMPT_LENGTH))], dtype=tf.int32
        )
        # Cache for the unconditional context: it depends only on the fixed
        # encoder weights, so it is computed at most once per handler.
        self._unconditional_context = None

    def _instantiate_text_encoder(self, version: str):
        """Download the pretrained CLIP weights for ``version`` and return
        the loaded text encoder model.

        Raises:
            ValueError: if ``version`` is neither "1.4" nor "2".
        """
        if version == "1.4":
            text_encoder_weights_fpath = keras.utils.get_file(
                origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/kcv_encoder.h5",
                file_hash="4789e63e07c0e54d6a34a29b45ce81ece27060c499a709d556c7755b42bb0dc4",
            )
            text_encoder = TextEncoder(self.MAX_PROMPT_LENGTH)
        elif version == "2":
            text_encoder_weights_fpath = keras.utils.get_file(
                origin="https://huggingface.co/ianstenbit/keras-sd2.1/resolve/main/text_encoder_v2_1.h5",
                file_hash="985002e68704e1c5c3549de332218e99c5b9b745db7171d5f31fcd9a6089f25b",
            )
            text_encoder = TextEncoderV2(self.MAX_PROMPT_LENGTH)
        else:
            raise ValueError(f"v{version} is not supported")
        # Both supported versions share the same load-and-return tail.
        text_encoder.load_weights(text_encoder_weights_fpath)
        return text_encoder

    def _get_unconditional_context(self):
        """Return the encoder output for the fixed unconditional token
        sequence, computing and caching it on first use."""
        if self._unconditional_context is None:
            unconditional_tokens = tf.convert_to_tensor(
                [_UNCONDITIONAL_TOKENS], dtype=tf.int32
            )
            self._unconditional_context = self.text_encoder.predict_on_batch(
                [unconditional_tokens, self.pos_ids]
            )
        return self._unconditional_context

    def encode_text(self, prompt):
        """Tokenize ``prompt``, pad it to MAX_PROMPT_LENGTH, and run it
        through the text encoder.

        Returns:
            The encoder output for the single padded prompt.

        Raises:
            ValueError: if the tokenized prompt exceeds MAX_PROMPT_LENGTH.
        """
        # Tokenize prompt (i.e. starting context)
        inputs = self.tokenizer.encode(prompt)
        if len(inputs) > self.MAX_PROMPT_LENGTH:
            raise ValueError(
                f"Prompt is too long (should be <= {self.MAX_PROMPT_LENGTH} tokens)"
            )
        # Right-pad with the end-of-text token up to the fixed length.
        phrase = inputs + [self.END_OF_TEXT_TOKEN] * (
            self.MAX_PROMPT_LENGTH - len(inputs)
        )
        phrase = tf.convert_to_tensor([phrase], dtype=tf.int32)
        return self.text_encoder.predict_on_batch([phrase, self.pos_ids])

    def get_contexts(self, encoded_text, batch_size):
        """Tile ``encoded_text`` and the unconditional context along the
        batch axis to ``batch_size`` replicas."""
        encoded_text = tf.squeeze(encoded_text)
        # A rank-2 result is a single (tokens, dim) encoding; replicate it
        # along a new leading batch axis.
        if encoded_text.shape.rank == 2:
            encoded_text = tf.repeat(
                tf.expand_dims(encoded_text, axis=0), batch_size, axis=0
            )
        context = encoded_text
        unconditional_context = tf.repeat(
            self._get_unconditional_context(), batch_size, axis=0
        )
        return context, unconditional_context

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one inference request.

        Args:
            data: request payload; "inputs" holds the prompt string and
                "batch_size" (default 1) the number of batch replicas.

        Returns:
            Dict with "context_b64str" and "unconditional_context_b64str",
            each the base64 encoding of the raw tensor bytes.
        """
        # get inputs
        # NOTE(review): falling back to the whole dict when "inputs" is
        # missing mirrors the original behavior — confirm callers always
        # send an "inputs" key.
        prompt = data.pop("inputs", data)
        batch_size = data.pop("batch_size", 1)
        encoded_text = self.encode_text(prompt)
        context, unconditional_context = self.get_contexts(encoded_text, batch_size)
        context_b64str = base64.b64encode(context.numpy().tobytes()).decode()
        unconditional_context_b64str = base64.b64encode(
            unconditional_context.numpy().tobytes()
        ).decode()
        return {
            "context_b64str": context_b64str,
            "unconditional_context_b64str": unconditional_context_b64str,
        }
|