Spaces:
Runtime error
Runtime error
File size: 4,166 Bytes
0c8d55e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
import torch
from typing import List
def _encode_prompt_with_t5(
    text_encoder,
    tokenizer,
    max_sequence_length,
    prompt=None,
    num_images_per_prompt=1,
    device=None,
    text_input_ids=None,
):
    """Encode prompts with the T5 text encoder.

    Args:
        text_encoder: Encoder called with the token ids; element ``[0]`` of its
            output is taken as the per-token embeddings. Its ``dtype`` attribute
            is used to cast the result.
        tokenizer: Tokenizer for ``text_encoder``. May be ``None`` when
            ``text_input_ids`` is supplied instead.
        max_sequence_length: Length every prompt is padded/truncated to by the
            tokenizer.
        prompt: A single prompt string or a list of prompts. Required only when
            ``tokenizer`` is used.
        num_images_per_prompt: Number of copies of each prompt's embeddings to
            emit.
        device: Device the ids are moved to and the result is placed on.
        text_input_ids: Pre-tokenized ids, used only when ``tokenizer`` is None.

    Returns:
        Embeddings with ``batch * num_images_per_prompt`` rows; the copies of a
        given prompt are adjacent.

    Raises:
        ValueError: if ``tokenizer`` is given without ``prompt``, or if neither
            ``tokenizer`` nor ``text_input_ids`` is given.
    """
    if tokenizer is not None:
        if prompt is None:
            raise ValueError("prompt must be provided when the tokenizer is specified")
        prompt = [prompt] if isinstance(prompt, str) else prompt
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
    elif text_input_ids is None:
        raise ValueError(
            "text_input_ids must be provided when the tokenizer is not specified"
        )

    prompt_embeds = text_encoder(text_input_ids.to(device))[0]
    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)

    # Take the batch size from the encoder output rather than len(prompt), so the
    # view below can never silently fold extra rows into the feature axis.
    batch_size, seq_len, _ = prompt_embeds.shape
    # Duplicate text embeddings for each generation per prompt, using the
    # mps-friendly repeat + view formulation.
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
    return prompt_embeds
def _encode_prompt_with_clip(
text_encoder,
tokenizer,
prompt: str,
device=None,
text_input_ids=None,
num_images_per_prompt: int = 1,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
if tokenizer is not None:
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=77,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
else:
if text_input_ids is None:
raise ValueError(
"text_input_ids must be provided when the tokenizer is not specified"
)
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2]
prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
_, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
return prompt_embeds, pooled_prompt_embeds
def encode_prompt(
    text_encoders,
    tokenizers,
    prompt: str|List,
    max_sequence_length,
    device=None,
    num_images_per_prompt: int = 1,
    text_input_ids_list=None,
    only_positive_t5=False,
):
    """Encode ``prompt`` with two CLIP encoders followed by one T5 encoder.

    The first two entries of ``tokenizers``/``text_encoders`` go through the
    CLIP helper; the last entry goes through the T5 helper.

    Args:
        text_encoders: Encoders; ``[:2]`` are used as CLIP, ``[-1]`` as T5.
        tokenizers: Tokenizers laid out the same way as ``text_encoders``.
        prompt: A single prompt or a list of prompts.
        max_sequence_length: Token length passed to the T5 helper.
        device: Target device; when ``None`` each encoder's own ``device`` is used.
        num_images_per_prompt: Copies of each prompt's embeddings to produce.
        text_input_ids_list: Optional pre-tokenized ids, indexed per CLIP
            encoder and forwarded to the CLIP helper.
        only_positive_t5: When true, the CLIP encoders receive empty strings and
            only T5 receives the real prompt.

    Returns:
        ``(clip_embeds, joint_embeds, pooled_embeds)``. ``clip_embeds`` holds the
        CLIP token embeddings joined on the last axis and padded with zeros (or
        truncated) to the T5 width. ``joint_embeds`` is ``clip_embeds`` followed
        by the T5 embeddings along the sequence axis (``dim=-2``).
        ``pooled_embeds`` holds the CLIP pooled embeddings joined on the last axis.
    """
    prompts = [prompt] if isinstance(prompt, str) else prompt
    # Prompt actually shown to the CLIP encoders.
    clip_prompts = [""] * len(prompts) if only_positive_t5 else prompts

    token_chunks = []
    pooled_chunks = []
    clip_pairs = zip(tokenizers[:2], text_encoders[:2])
    for idx, (clip_tokenizer, clip_encoder) in enumerate(clip_pairs):
        ids = text_input_ids_list[idx] if text_input_ids_list else None
        target_device = clip_encoder.device if device is None else device
        token_embeds, pooled = _encode_prompt_with_clip(
            text_encoder=clip_encoder,
            tokenizer=clip_tokenizer,
            prompt=clip_prompts,
            device=target_device,
            num_images_per_prompt=num_images_per_prompt,
            text_input_ids=ids,
        )
        token_chunks.append(token_embeds)
        pooled_chunks.append(pooled)

    clip_embeds = torch.cat(token_chunks, dim=-1)
    pooled_embeds = torch.cat(pooled_chunks, dim=-1)

    t5_encoder = text_encoders[-1]
    t5_device = t5_encoder.device if device is None else device
    t5_embeds = _encode_prompt_with_t5(
        t5_encoder,
        tokenizers[-1],
        max_sequence_length,
        prompt=prompts,
        num_images_per_prompt=num_images_per_prompt,
        device=t5_device,
    )

    # Match the CLIP feature width to T5 so both can share one sequence axis.
    width_gap = t5_embeds.shape[-1] - clip_embeds.shape[-1]
    clip_embeds = torch.nn.functional.pad(clip_embeds, (0, width_gap))
    joint_embeds = torch.cat([clip_embeds, t5_embeds], dim=-2)
    return clip_embeds, joint_embeds, pooled_embeds
|