markhristov commited on
Commit
210400d
·
1 Parent(s): 2d82e67

Convert input_ids to float32

Browse files
Files changed (1) hide show
  1. app.py +3 -1
app.py CHANGED
@@ -27,7 +27,9 @@ scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_
27
  def text_enc(prompts, maxlen=None):
28
  if maxlen is None: maxlen = tokenizer.model_max_length
29
  inp = tokenizer(prompts, padding="max_length", max_length=maxlen, truncation=True, return_tensors="pt")
30
- return text_encoder(inp.input_ids.long())[0]
 
 
31
 
32
  def do_both(prompts):
33
  def mk_img(t):
 
27
  def text_enc(prompts, maxlen=None):
28
  if maxlen is None: maxlen = tokenizer.model_max_length
29
  inp = tokenizer(prompts, padding="max_length", max_length=maxlen, truncation=True, return_tensors="pt")
30
+ input_ids = inp.input_ids.long() # Convert input_ids to Long data type
31
+ input_ids = input_ids.float() # Convert input_ids to float32
32
+ return text_encoder(input_ids)[0]
33
 
34
  def do_both(prompts):
35
  def mk_img(t):