File size: 3,919 Bytes
dc39a0d
4c8e231
e0c0641
dc39a0d
 
e06cb5f
299f5c7
 
 
 
dc39a0d
 
299f5c7
dc39a0d
 
299f5c7
4c8e231
 
 
 
dc39a0d
 
299f5c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc39a0d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e0c0641
dc39a0d
 
 
 
d063207
dc39a0d
 
 
 
e0c0641
 
dc39a0d
e0c0641
 
 
 
 
6f23e96
a57f986
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from typing import Dict, List, Any
import sys
import base64

import tensorflow as tf
from tensorflow import keras
from keras_cv.models.stable_diffusion.text_encoder import TextEncoder
from keras_cv.models.stable_diffusion.text_encoder import TextEncoderV2
from keras_cv.models.stable_diffusion.clip_tokenizer import SimpleTokenizer
from keras_cv.models.stable_diffusion.constants import _UNCONDITIONAL_TOKENS

class EndpointHandler:
    """Endpoint handler that encodes a text prompt with a Stable Diffusion
    text encoder and returns the resulting contexts as base64 strings.

    Only the text-encoder stage runs here; the returned conditional and
    unconditional contexts are presumably decoded and consumed by a separate
    diffusion step elsewhere (confirm with the client of this endpoint).
    """

    # Id used to right-pad tokenized prompts up to MAX_PROMPT_LENGTH.
    # Presumably CLIP's end-of-text token (keras_cv's StableDiffusion pads with
    # the same id) -- confirm against SimpleTokenizer's vocabulary.
    _PAD_TOKEN_ID = 49407

    def __init__(self, path="", version="2"):
        """Build the text encoder, tokenizer and position ids.

        Args:
            path: Accepted but not used by this handler.
            version: Text-encoder version, "1.4" or "2". Non-string values
                such as ``2`` are converted with ``str()`` before matching.

        Raises:
            ValueError: If ``version`` is not a supported version.
        """
        self.MAX_PROMPT_LENGTH = 77

        # Unsupported versions raise ValueError rather than calling
        # sys.exit(), so a hosting server can report the failure instead of
        # the whole interpreter terminating.
        self.text_encoder = self._instantiate_text_encoder(version)

        self.tokenizer = SimpleTokenizer()
        # Position ids 0..MAX_PROMPT_LENGTH-1 with a leading batch axis of 1.
        self.pos_ids = tf.convert_to_tensor(
            [list(range(self.MAX_PROMPT_LENGTH))], dtype=tf.int32
        )
        # Filled lazily by _get_unconditional_context().
        self._unconditional_context = None

    def _instantiate_text_encoder(self, version: str):
        """Fetch the weights for ``version`` and return the loaded encoder.

        Weights are downloaded via ``keras.utils.get_file`` and checked
        against a fixed SHA-256 hash.

        Args:
            version: "1.4" selects ``TextEncoder``; "2" selects
                ``TextEncoderV2`` with SD 2.1 weights. Converted with
                ``str()`` first, so ``2`` is accepted as well.

        Raises:
            ValueError: If ``version`` is not supported.
        """
        version = str(version)
        if version == "1.4":
            origin = "https://huggingface.co/fchollet/stable-diffusion/resolve/main/kcv_encoder.h5"
            file_hash = "4789e63e07c0e54d6a34a29b45ce81ece27060c499a709d556c7755b42bb0dc4"
            encoder_cls = TextEncoder
        elif version == "2":
            origin = "https://huggingface.co/ianstenbit/keras-sd2.1/resolve/main/text_encoder_v2_1.h5"
            file_hash = "985002e68704e1c5c3549de332218e99c5b9b745db7171d5f31fcd9a6089f25b"
            encoder_cls = TextEncoderV2
        else:
            raise ValueError(f"v{version} is not supported")

        weights_fpath = keras.utils.get_file(origin=origin, file_hash=file_hash)
        text_encoder = encoder_cls(self.MAX_PROMPT_LENGTH)
        text_encoder.load_weights(weights_fpath)
        return text_encoder

    def _get_unconditional_context(self):
        """Return the encoder output for keras_cv's ``_UNCONDITIONAL_TOKENS``.

        The input never changes, so the forward pass runs once on first use
        and the result is reused for every later request.
        """
        cached = getattr(self, "_unconditional_context", None)
        if cached is None:
            unconditional_tokens = tf.convert_to_tensor(
                [_UNCONDITIONAL_TOKENS], dtype=tf.int32
            )
            cached = self.text_encoder.predict_on_batch(
                [unconditional_tokens, self.pos_ids]
            )
            self._unconditional_context = cached
        return cached

    def encode_text(self, prompt):
        """Tokenize ``prompt``, right-pad it and run it through the encoder.

        Args:
            prompt: Text to encode.

        Returns:
            The text encoder's ``predict_on_batch`` output for a batch of one.

        Raises:
            ValueError: If the prompt tokenizes to more than
                ``MAX_PROMPT_LENGTH`` tokens.
        """
        tokens = self.tokenizer.encode(prompt)
        if len(tokens) > self.MAX_PROMPT_LENGTH:
            raise ValueError(
                f"Prompt is too long (should be <= {self.MAX_PROMPT_LENGTH} tokens)"
            )
        padding = [self._PAD_TOKEN_ID] * (self.MAX_PROMPT_LENGTH - len(tokens))
        phrase = tf.convert_to_tensor([tokens + padding], dtype=tf.int32)
        return self.text_encoder.predict_on_batch([phrase, self.pos_ids])

    def get_contexts(self, encoded_text, batch_size):
        """Tile the prompt context and the unconditional context.

        ``encoded_text`` is squeezed first. If that leaves a rank-2 tensor (a
        single encoding), it is repeated ``batch_size`` times along a new
        leading axis; otherwise it is returned squeezed but not tiled.

        Returns:
            ``(context, unconditional_context)``, the latter repeated
            ``batch_size`` times along axis 0.
        """
        context = tf.squeeze(encoded_text)
        if context.shape.rank == 2:
            context = tf.repeat(tf.expand_dims(context, axis=0), batch_size, axis=0)

        unconditional_context = tf.repeat(
            self._get_unconditional_context(), batch_size, axis=0
        )
        return context, unconditional_context

    @staticmethod
    def _parse_batch_size(value) -> int:
        """Coerce a request's ``batch_size`` to an int >= 1.

        Accepts ints, integer-valued floats and numeric strings such as "4"
        (JSON clients sometimes send numbers as strings).

        Raises:
            ValueError: For booleans, non-numeric values, fractional floats
                or values < 1.
        """
        message = f"batch_size must be a positive integer, got {value!r}"
        if isinstance(value, bool):
            raise ValueError(message)
        try:
            batch_size = int(value)
        except (TypeError, ValueError, OverflowError) as err:
            raise ValueError(message) from err
        # int() would silently truncate e.g. 2.5 to 2.
        if isinstance(value, float) and value != batch_size:
            raise ValueError(message)
        if batch_size < 1:
            raise ValueError(message)
        return batch_size

    @staticmethod
    def _to_b64str(tensor):
        """Return base64 text of ``tensor.numpy().tobytes()`` (raw bytes only)."""
        return base64.b64encode(tensor.numpy().tobytes()).decode()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Encode the prompt in ``data`` and return base64-encoded contexts.

        Args:
            data: Request payload. ``"inputs"`` holds the prompt (if absent,
                the whole payload is passed to the tokenizer, as before);
                optional ``"batch_size"`` defaults to 1. The payload is not
                modified.

        Returns:
            A dict with ``"context_b64str"`` and
            ``"unconditional_context_b64str"``. Each is the base64 text of the
            tensor's raw bytes. No shape or dtype is included, so the client
            must already know them (presumably (batch_size, 77, hidden_dim)
            float32 -- confirm).

        Raises:
            ValueError: For an invalid ``batch_size`` or a prompt that is too
                long.
        """
        # .get() instead of .pop(): leave the caller's payload untouched.
        prompt = data.get("inputs", data)
        batch_size = self._parse_batch_size(data.get("batch_size", 1))

        encoded_text = self.encode_text(prompt)
        context, unconditional_context = self.get_contexts(encoded_text, batch_size)

        return {
            "context_b64str": self._to_b64str(context),
            "unconditional_context_b64str": self._to_b64str(unconditional_context),
        }