chansung commited on
Commit
dc39a0d
·
1 Parent(s): 2846968

add custom handler

Browse files
__pycache__/handler.cpython-38.pyc ADDED
Binary file (2.31 kB). View file
 
handler.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+
3
+ import tensorflow as tf
4
+ from keras_cv.models.generative.stable_diffusion.text_encoder import TextEncoder
5
+ from keras_cv.models.generative.stable_diffusion.clip_tokenizer import SimpleTokenizer
6
+ from keras_cv.models.generative.stable_diffusion.constants import _UNCONDITIONAL_TOKENS
7
+
8
+ class EndpointHandler():
9
+ def __init__(self, path=""):
10
+ self.MAX_PROMPT_LENGTH = 77
11
+
12
+ self.tokenizer = SimpleTokenizer()
13
+ self.text_encoder = TextEncoder(self.MAX_PROMPT_LENGTH)
14
+ self.pos_ids = tf.convert_to_tensor([list(range(self.MAX_PROMPT_LENGTH))], dtype=tf.int32)
15
+
16
+ def _get_unconditional_context(self):
17
+ unconditional_tokens = tf.convert_to_tensor(
18
+ [_UNCONDITIONAL_TOKENS], dtype=tf.int32
19
+ )
20
+ unconditional_context = self.text_encoder.predict_on_batch(
21
+ [unconditional_tokens, self.pos_ids]
22
+ )
23
+
24
+ return unconditional_context
25
+
26
+ def encode_text(self, prompt):
27
+ # Tokenize prompt (i.e. starting context)
28
+ inputs = self.tokenizer.encode(prompt)
29
+ if len(inputs) > self.MAX_PROMPT_LENGTH:
30
+ raise ValueError(
31
+ f"Prompt is too long (should be <= {self.MAX_PROMPT_LENGTH} tokens)"
32
+ )
33
+ phrase = inputs + [49407] * (self.MAX_PROMPT_LENGTH - len(inputs))
34
+ phrase = tf.convert_to_tensor([phrase], dtype=tf.int32)
35
+
36
+ context = self.text_encoder.predict_on_batch([phrase, self.pos_ids])
37
+
38
+ return context
39
+
40
+ def get_contexts(self, encoded_text, batch_size):
41
+ encoded_text = tf.squeeze(encoded_text)
42
+ if encoded_text.shape.rank == 2:
43
+ encoded_text = tf.repeat(
44
+ tf.expand_dims(encoded_text, axis=0), batch_size, axis=0
45
+ )
46
+
47
+ context = encoded_text
48
+
49
+ unconditional_context = tf.repeat(
50
+ _get_unconditional_context(), batch_size, axis=0
51
+ )
52
+
53
+ return context, unconditional_context
54
+
55
+ def __call__(self, data: Dict[str, Any]) -> str:
56
+ # get inputs
57
+ prompt = data.pop("inputs", data)
58
+ batch_size = data.pop("batch_size", 1)
59
+
60
+ encoded_text = encode_text(prompt)
61
+ context, unconditional_context = get_contexts(encoded_text, batch_size)
62
+
63
+ return context, unconditional_context
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ keras-cv
2
+ tensorflow
3
+ tensorflow_datasets