chansung commited on
Commit
299f5c7
·
1 Parent(s): d0f58da

update custom handler

Browse files
__pycache__/handler.cpython-38.pyc CHANGED
Binary files a/__pycache__/handler.cpython-38.pyc and b/__pycache__/handler.cpython-38.pyc differ
 
handler.py CHANGED
@@ -1,27 +1,41 @@
1
  from typing import Dict, List, Any
2
  import base64
3
 
4
- import logging
5
-
6
  import tensorflow as tf
7
  from tensorflow import keras
8
- from keras_cv.models.generative.stable_diffusion.text_encoder import TextEncoder
9
- from keras_cv.models.generative.stable_diffusion.clip_tokenizer import SimpleTokenizer
10
- from keras_cv.models.generative.stable_diffusion.constants import _UNCONDITIONAL_TOKENS
 
11
 
12
  class EndpointHandler():
13
- def __init__(self, path=""):
14
  self.MAX_PROMPT_LENGTH = 77
15
 
16
  self.tokenizer = SimpleTokenizer()
17
- self.text_encoder = TextEncoder(self.MAX_PROMPT_LENGTH)
18
- text_encoder_weights_fpath = keras.utils.get_file(
19
- origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/kcv_encoder.h5",
20
- file_hash="4789e63e07c0e54d6a34a29b45ce81ece27060c499a709d556c7755b42bb0dc4",
21
- )
22
- self.text_encoder.load_weights(text_encoder_weights_fpath)
23
  self.pos_ids = tf.convert_to_tensor([list(range(self.MAX_PROMPT_LENGTH))], dtype=tf.int32)
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  def _get_unconditional_context(self):
26
  unconditional_tokens = tf.convert_to_tensor(
27
  [_UNCONDITIONAL_TOKENS], dtype=tf.int32
 
1
  from typing import Dict, List, Any
2
  import base64
3
 
 
 
4
  import tensorflow as tf
5
  from tensorflow import keras
6
+ from keras_cv.models.stable_diffusion.text_encoder import TextEncoder
7
+ from keras_cv.models.stable_diffusion.text_encoder import TextEncoderV2
8
+ from keras_cv.models.stable_diffusion.clip_tokenizer import SimpleTokenizer
9
+ from keras_cv.models.stable_diffusion.constants import _UNCONDITIONAL_TOKENS
10
 
11
  class EndpointHandler():
12
+ def __init__(self, path="", version="2"):
13
  self.MAX_PROMPT_LENGTH = 77
14
 
15
  self.tokenizer = SimpleTokenizer()
16
+ self.text_encoder = self._instantiate_text_encoder(version)
 
 
 
 
 
17
  self.pos_ids = tf.convert_to_tensor([list(range(self.MAX_PROMPT_LENGTH))], dtype=tf.int32)
18
 
19
+ def _instantiate_text_encoder(self, version: str):
20
+ if version == "1.4":
21
+ text_encoder_weights_fpath = keras.utils.get_file(
22
+ origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/kcv_encoder.h5",
23
+ file_hash="4789e63e07c0e54d6a34a29b45ce81ece27060c499a709d556c7755b42bb0dc4",
24
+ )
25
+ text_encoder = TextEncoder(self.MAX_PROMPT_LENGTH)
26
+ text_encoder.load_weights(text_encoder_weights_fpath)
27
+ return text_encoder
28
+ elif version == "2":
29
+ text_encoder_weights_fpath = keras.utils.get_file(
30
+ origin="https://huggingface.co/ianstenbit/keras-sd2.1/resolve/main/text_encoder_v2_1.h5",
31
+ file_hash="985002e68704e1c5c3549de332218e99c5b9b745db7171d5f31fcd9a6089f25b",
32
+ )
33
+ text_encoder = TextEncoderV2(self.MAX_PROMPT_LENGTH)
34
+ text_encoder.load_weights(text_encoder_weights_fpath)
35
+ return text_encoder
36
+ else:
37
+ return f"v{version} is not supported"
38
+
39
  def _get_unconditional_context(self):
40
  unconditional_tokens = tf.convert_to_tensor(
41
  [_UNCONDITIONAL_TOKENS], dtype=tf.int32
requirements.txt CHANGED
@@ -1,3 +1,3 @@
1
- keras-cv
2
- tensorflow
3
  tensorflow_datasets
 
1
+ keras-cv==0.4
2
+ tensorflow==2.11
3
  tensorflow_datasets