import torch
from typing import List


def _encode_prompt_with_t5(
    text_encoder,
    tokenizer,
    max_sequence_length,
    prompt=None,
    num_images_per_prompt=1,
    device=None,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    prompt_embeds = text_encoder(text_input_ids.to(device))[0]

    dtype = text_encoder.dtype
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    _, seq_len, _ = prompt_embeds.shape

    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    return prompt_embeds


def _encode_prompt_with_clip(
    text_encoder,
    tokenizer,
    prompt: str | List[str],
    device=None,
    text_input_ids=None,
    num_images_per_prompt: int = 1,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    if tokenizer is not None:
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=77,  # CLIP's fixed context length
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
    elif text_input_ids is None:
        raise ValueError(
            "text_input_ids must be provided when the tokenizer is not specified"
        )

    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)

    # the projected pooled output is the first element of the encoder output;
    # the per-token embeddings come from the penultimate hidden layer
    pooled_prompt_embeds = prompt_embeds[0]
    prompt_embeds = prompt_embeds.hidden_states[-2]
    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)

    _, seq_len, _ = prompt_embeds.shape

    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    return prompt_embeds, pooled_prompt_embeds


def encode_prompt(
    text_encoders,
    tokenizers,
    prompt: str | List[str],
    max_sequence_length,
    device=None,
    num_images_per_prompt: int = 1,
    text_input_ids_list=None,
    only_positive_t5=False,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt

    # the first two encoders are CLIP; the last is T5
    clip_tokenizers = tokenizers[:2]
    clip_text_encoders = text_encoders[:2]

    clip_prompt_embeds_list = []
    clip_pooled_prompt_embeds_list = []
    for i, (tokenizer, text_encoder) in enumerate(
        zip(clip_tokenizers, clip_text_encoders)
    ):
        prompt_embeds, pooled_prompt_embeds = _encode_prompt_with_clip(
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            # with only_positive_t5, the prompt is routed to T5 alone and the
            # CLIP encoders receive empty strings
            prompt=prompt if not only_positive_t5 else [""] * len(prompt),
            device=device if device is not None else text_encoder.device,
            num_images_per_prompt=num_images_per_prompt,
            text_input_ids=text_input_ids_list[i] if text_input_ids_list else None,
        )
        clip_prompt_embeds_list.append(prompt_embeds)
        clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds)

    clip_prompt_embeds = torch.cat(clip_prompt_embeds_list, dim=-1)
    pooled_prompt_embeds = torch.cat(clip_pooled_prompt_embeds_list, dim=-1)

    t5_prompt_embed = _encode_prompt_with_t5(
        text_encoders[-1],
        tokenizers[-1],
        max_sequence_length,
        prompt=prompt,
        num_images_per_prompt=num_images_per_prompt,
        device=device if device is not None else text_encoders[-1].device,
    )

    # zero-pad the channel-concatenated CLIP embeddings up to the T5 hidden
    # size, then stack both token sequences along the sequence dimension
    clip_prompt_embeds = torch.nn.functional.pad(
        clip_prompt_embeds,
        (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]),
    )
    t5_prompt_embed = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)

    # returns the padded CLIP embeddings, the joint CLIP+T5 token sequence,
    # and the pooled CLIP embeddings
    return clip_prompt_embeds, t5_prompt_embed, pooled_prompt_embeds
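

# -----------------------------------------------------------------------------
# Usage sketch (not part of the module above): a minimal, hedged example of
# wiring three text encoders into encode_prompt. It assumes the checkpoint
# follows the Stable Diffusion 3 diffusers layout — two CLIP encoders plus a
# T5 encoder under the tokenizer/tokenizer_2/tokenizer_3 and
# text_encoder/text_encoder_2/text_encoder_3 subfolders. The repo id below is
# an assumption; any checkpoint with the same layout should work.
if __name__ == "__main__":
    from transformers import (
        CLIPTextModelWithProjection,
        CLIPTokenizer,
        T5EncoderModel,
        T5TokenizerFast,
    )

    repo = "stabilityai/stable-diffusion-3-medium-diffusers"  # assumed checkpoint
    tokenizers = [
        CLIPTokenizer.from_pretrained(repo, subfolder="tokenizer"),
        CLIPTokenizer.from_pretrained(repo, subfolder="tokenizer_2"),
        T5TokenizerFast.from_pretrained(repo, subfolder="tokenizer_3"),
    ]
    text_encoders = [
        CLIPTextModelWithProjection.from_pretrained(repo, subfolder="text_encoder"),
        CLIPTextModelWithProjection.from_pretrained(repo, subfolder="text_encoder_2"),
        T5EncoderModel.from_pretrained(repo, subfolder="text_encoder_3"),
    ]

    with torch.no_grad():
        clip_embeds, joint_embeds, pooled_embeds = encode_prompt(
            text_encoders,
            tokenizers,
            prompt="a photo of an astronaut riding a horse",
            max_sequence_length=256,
            device="cpu",
        )

    # joint_embeds has shape (batch, 77 + max_sequence_length, t5_dim): the two
    # CLIP token sequences are concatenated channel-wise, zero-padded to the T5
    # width, then prepended to the T5 token sequence.
    print(clip_embeds.shape, joint_embeds.shape, pooled_embeds.shape)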