import os
from typing import List

import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from nested_attention_processor import AttnProcessor, NestedAttnProcessor
from utils import get_generator
from resampler import Resampler
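
# Inference-time utilities for a nested-attention personalization adapter on top of an
# SDXL-style diffusers pipeline (the pipeline is expected to expose tokenizer/tokenizer_2,
# text_encoder/text_encoder_2, a UNet, and encode_prompt):
#   * add_special_token_to_tokenizer: registers the personalization placeholder token and
#     initializes its embedding from an existing single token.
#   * NestedAdapterInference: wires NestedAttnProcessor modules into the UNet, encodes a
#     reference image with CLIP, maps it through a Resampler, and runs generation.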


def add_special_token_to_tokenizer(
    pipe,
    placeholder_token,
    initializer_token
):
    # Register the placeholder token with both SDXL tokenizers.
    num_added_tokens1 = pipe.tokenizer.add_tokens([placeholder_token])
    num_added_tokens2 = pipe.tokenizer_2.add_tokens([placeholder_token])
    if num_added_tokens1 != 1 or num_added_tokens2 != 1:
        raise ValueError("Failed to add placeholder token to tokenizer")

    # The initializer must map to exactly one token in each tokenizer.
    token_ids1 = pipe.tokenizer.encode(initializer_token, add_special_tokens=False)
    token_ids2 = pipe.tokenizer_2.encode(initializer_token, add_special_tokens=False)
    if len(token_ids1) > 1 or len(token_ids2) > 1:
        raise ValueError("The initializer token must be a single token.")
    initializer_token_id1 = token_ids1[0]
    initializer_token_id2 = token_ids2[0]

    placeholder_token_ids1 = pipe.tokenizer.convert_tokens_to_ids([placeholder_token])
    placeholder_token_ids2 = pipe.tokenizer_2.convert_tokens_to_ids([placeholder_token])

    # Grow the embedding tables and initialize the new rows from the initializer token.
    pipe.text_encoder.resize_token_embeddings(len(pipe.tokenizer))
    pipe.text_encoder_2.resize_token_embeddings(len(pipe.tokenizer_2))
    token_embeds1 = pipe.text_encoder.get_input_embeddings().weight.data
    token_embeds2 = pipe.text_encoder_2.get_input_embeddings().weight.data
    with torch.no_grad():
        for token_id in placeholder_token_ids1:
            token_embeds1[token_id] = token_embeds1[initializer_token_id1].clone()
        for token_id in placeholder_token_ids2:
            token_embeds2[token_id] = token_embeds2[initializer_token_id2].clone()
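

# Example call (hypothetical token strings; the real placeholder/initializer pair is
# whatever the surrounding application registers, e.g. a rare token such as "<person>"
# initialized from the single token "person"):
#
#   add_special_token_to_tokenizer(pipe, placeholder_token="<person>", initializer_token="person")
#   placeholder_token_ids = pipe.tokenizer.convert_tokens_to_ids(["<person>"])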


class NestedAdapterInference:
    def __init__(
        self,
        sd_pipe,
        image_encoder_path,
        adapter_ckpt,
        resampler_num_queries,
        vq_normalize_factor,
        device,
    ):
        self.device = device
        self.image_encoder_path = image_encoder_path
        self.adapter_ckpt = adapter_ckpt
        self.vq_normalize_factor = vq_normalize_factor

        self.pipe = sd_pipe.to(self.device)
        self.set_nested_adapter()

        # load image encoder
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            self.image_encoder_path
        ).to(self.device, dtype=torch.float16)
        self.clip_image_processor = CLIPImageProcessor()

        # spatial features model
        self.qformer = Resampler(
            dim=self.pipe.unet.config.cross_attention_dim,
            depth=4,
            dim_head=64,
            heads=12,
            num_queries=resampler_num_queries,
            embedding_dim=self.image_encoder.config.hidden_size,
            output_dim=self.pipe.unet.config.cross_attention_dim,
            ff_mult=4,
        ).to(self.device, dtype=torch.float16)

        if adapter_ckpt is not None:
            self.load_nested_adapter()

    def set_nested_adapter(self):
        # Swap the UNet attention processors: self-attention ("attn1") keeps the default
        # processor, cross-attention gets the nested-attention processor.
        unet = self.pipe.unet
        attn_procs = {}
        for name in unet.attn_processors.keys():
            cross_attention_dim = (
                None
                if name.endswith("attn1.processor")
                else unet.config.cross_attention_dim
            )
            if name.startswith("mid_block"):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = unet.config.block_out_channels[block_id]
            if cross_attention_dim is None:
                attn_procs[name] = AttnProcessor()
            else:
                attn_procs[name] = NestedAttnProcessor(
                    hidden_size=hidden_size,
                    cross_attention_dim=cross_attention_dim,
                    normalize_factor=self.vq_normalize_factor,
                ).to(self.device, dtype=torch.float16)
        unet.set_attn_processor(attn_procs)

    def load_nested_adapter(self):
        # The checkpoint stores the attention-adapter weights under "adapter_modules." and
        # the resampler weights under "spatial_features_model."; split them by prefix.
        state_dict = {"adapter_modules": {}, "qformer": {}}
        f = torch.load(self.adapter_ckpt, map_location="cpu")
        for key in f.keys():
            if key.startswith("adapter_modules."):
                state_dict["adapter_modules"][key.replace("adapter_modules.", "")] = f[key]
            elif key.startswith("spatial_features_model."):
                state_dict["qformer"][key.replace("spatial_features_model.", "")] = f[key]
        self.qformer.load_state_dict(state_dict["qformer"])
        adapter_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
        adapter_layers.load_state_dict(state_dict["adapter_modules"])

    def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
        if isinstance(pil_image, Image.Image):
            pil_image = [pil_image]
        clip_image = self.clip_image_processor(
            images=pil_image, return_tensors="pt"
        ).pixel_values
        clip_image_embeds = self.image_encoder(
            clip_image.to(self.device, dtype=torch.float16)
        )
        spatial_clip_image_embeds = clip_image_embeds.last_hidden_state
        spatial_clip_image_embeds = spatial_clip_image_embeds[:, 1:]  # remove CLS token
        return spatial_clip_image_embeds

    def generate(
        self,
        pil_image=None,
        clip_image_embeds=None,
        prompt=None,
        placeholder_token_ids=None,
        negative_prompt=None,
        scale=1.0,
        num_samples=4,
        seed=None,
        guidance_scale=5.0,
        num_inference_steps=30,
        multiple_images=False,
        special_token_weight=1.0,
        **kwargs,
    ):
        if pil_image is not None:
            num_prompts = (
                1
                if isinstance(pil_image, Image.Image) or multiple_images
                else len(pil_image)
            )
        else:
            num_prompts = clip_image_embeds.size(0)

        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = (
                "monochrome, lowres, bad anatomy, worst quality, low quality"
            )
        if not isinstance(prompt, List):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, List):
            negative_prompt = [negative_prompt] * num_prompts

        # Locate the placeholder token inside the tokenized prompt so the nested attention
        # processors know which text position(s) the image features attach to.
        text_input_ids = self.pipe.tokenizer(
            prompt,
            max_length=self.pipe.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).input_ids
        special_token_indices = (text_input_ids == placeholder_token_ids[0]).nonzero()[
            :, 1
        ]

        spatial_clip_image_embeds = self.get_image_embeds(
            pil_image=pil_image, clip_image_embeds=clip_image_embeds
        )  # (bs, 256, 1280)

        with torch.no_grad():
            (
                prompt_embeds,
                negative_prompt_embeds,
                pooled_prompt_embeds,
                negative_pooled_prompt_embeds,
            ) = self.pipe.encode_prompt(
                prompt,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )

        with torch.no_grad():
            # Map the spatial CLIP features to cross-attention tokens via the resampler.
            qformer_tokens_out = self.qformer(spatial_clip_image_embeds)
            if multiple_images:
                # Concatenate the tokens of all reference images into one sequence.
                b, num_tokens, d = qformer_tokens_out.shape
                qformer_tokens_out = qformer_tokens_out.reshape(1, num_tokens * b, d)
            # Duplicate the image tokens per generated sample, then once more for the
            # unconditional/conditional pair used by classifier-free guidance.
            bs_embed, num_tokens, _ = qformer_tokens_out.shape
            qformer_tokens_out = qformer_tokens_out.repeat(1, num_samples, 1, 1)
            qformer_tokens_out = qformer_tokens_out.view(
                bs_embed * num_samples, num_tokens, -1
            )
            qformer_tokens_out = qformer_tokens_out.repeat_interleave(2, dim=0)

        cross_attention_kwargs = {
            "qformer_tokens_out": qformer_tokens_out,
            "special_token_indices": special_token_indices,
            "special_token_weight": special_token_weight,
            "inference_mode": True,
        }

        generator = get_generator(seed, self.device)
        images = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            cross_attention_kwargs=cross_attention_kwargs,
            **kwargs,
        ).images
        return images
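

# Minimal usage sketch (not part of the module API). The model ids, file paths, token
# strings, and hyperparameter values below are illustrative placeholders; substitute
# whatever the surrounding application actually uses.
if __name__ == "__main__":
    from diffusers import StableDiffusionXLPipeline

    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",  # placeholder SDXL base model
        torch_dtype=torch.float16,
    )
    add_special_token_to_tokenizer(
        pipe, placeholder_token="<person>", initializer_token="person"
    )
    placeholder_token_ids = pipe.tokenizer.convert_tokens_to_ids(["<person>"])

    model = NestedAdapterInference(
        sd_pipe=pipe,
        image_encoder_path="path/to/clip_image_encoder",  # placeholder path
        adapter_ckpt="path/to/nested_adapter.bin",  # placeholder path
        resampler_num_queries=16,  # placeholder value
        vq_normalize_factor=1.0,  # placeholder value
        device="cuda",
    )

    reference = Image.open("reference.jpg").convert("RGB")
    images = model.generate(
        pil_image=reference,
        prompt="a photo of <person> at the beach",
        placeholder_token_ids=placeholder_token_ids,
        num_samples=1,
        seed=42,
    )
    images[0].save("output.png")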