import torch
from torch import nn
from einops import rearrange
import numpy as np
from typing import List
from models.id_embedding.helpers import get_rep_pos, shift_tensor_dim0
from models.id_embedding.meta_net import StyleVectorizer
from models.celeb_embeddings import _get_celeb_embeddings_basis
from functools import partial
import torch.nn.functional as F
import torch.nn.init as init

DEFAULT_PLACEHOLDER_TOKEN = ["*"]
PROGRESSIVE_SCALE = 2000

def get_clip_token_for_string(tokenizer, string):
    # Tokenize a string (or list of strings) and return the raw input ids,
    # including the BOS/EOS tokens added by the tokenizer.
    batch_encoding = tokenizer(string, return_length=True, padding=True,
                               truncation=True, return_overflowing_tokens=False,
                               return_tensors="pt")
    tokens = batch_encoding["input_ids"]
    return tokens


def get_embedding_for_clip_token(embedder, token):
    # Look up the (pre-transformer) token embeddings for a tensor of token ids.
    return embedder(token.unsqueeze(0))
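
# Illustrative note (a sketch, not part of the original module): the placeholder
# handling below assumes the CLIP tokenizer convention [BOS, token_1, ..., EOS],
# e.g. with a transformers CLIPTokenizer / CLIPTextModel pair (an assumption here):
#
#   tokens = get_clip_token_for_string(tokenizer, "*")    # e.g. tensor([[BOS, id("*"), EOS]])
#   placeholder_id = tokens[0][1]                          # the id of "*" itself
#   embs = get_embedding_for_clip_token(
#       text_encoder.text_model.embeddings, tokens[0])     # (1, seq_len, token_dim) embeddings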

class EmbeddingManagerId_adain(nn.Module):
    def __init__(
        self,
        tokenizer,
        text_encoder,
        device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
        experiment_name="normal_GAN",
        num_embeds_per_token: int = 2,
        loss_type: str = None,
        mlp_depth: int = 2,
        token_dim: int = 1024,
        input_dim: int = 1024,
        **kwargs
    ):
        super().__init__()
        self.device = device
        self.num_es = num_embeds_per_token
        self.get_token_for_string = partial(get_clip_token_for_string, tokenizer)
        self.get_embedding_for_tkn = partial(get_embedding_for_clip_token, text_encoder.text_model.embeddings)
        self.token_dim = token_dim

        ''' 1. Placeholder mapping dicts '''
        # Token id of the "*" placeholder (position 1, i.e. right after BOS).
        self.placeholder_token = self.get_token_for_string("*")[0][1]

        # Mean/std of celebrity name embeddings, used as the AdaIN basis.
        if experiment_name == "normal_GAN":
            self.celeb_embeddings_mean, self.celeb_embeddings_std = _get_celeb_embeddings_basis(
                tokenizer, text_encoder, "datasets_face/good_names.txt")
        elif experiment_name == "man_GAN":
            self.celeb_embeddings_mean, self.celeb_embeddings_std = _get_celeb_embeddings_basis(
                tokenizer, text_encoder, "datasets_face/good_names_man.txt")
        elif experiment_name == "woman_GAN":
            self.celeb_embeddings_mean, self.celeb_embeddings_std = _get_celeb_embeddings_basis(
                tokenizer, text_encoder, "datasets_face/good_names_woman.txt")
        else:
            raise ValueError(f"Unknown experiment_name: {experiment_name}")
        print("now experiment_name:", experiment_name)

        self.celeb_embeddings_mean = self.celeb_embeddings_mean.to(device)
        self.celeb_embeddings_std = self.celeb_embeddings_std.to(device)

        # MLP that maps a random identity vector to num_es pseudo-word embeddings.
        self.name_projection_layer = StyleVectorizer(input_dim, self.token_dim * self.num_es,
                                                     depth=mlp_depth, lr_mul=0.1)
        self.embedding_discriminator = Embedding_discriminator(self.token_dim * self.num_es, dropout_rate=0.2)
        self.adain_mode = 0

    def forward(
        self,
        tokenized_text,
        embedded_text,
        name_batch,
        random_embeddings=None,
        timesteps=None,
    ):
        if tokenized_text is not None:
            batch_size, n, device = *tokenized_text.shape, tokenized_text.device

        other_return_dict = {}

        if random_embeddings is not None:
            # Project the random identity vector to num_es token embeddings and
            # renormalise them with the celebrity-embedding statistics (AdaIN).
            mlp_output_embedding = self.name_projection_layer(random_embeddings)
            total_embedding = mlp_output_embedding.view(mlp_output_embedding.shape[0], self.num_es, self.token_dim)
            if self.adain_mode == 0:
                adained_total_embedding = total_embedding * self.celeb_embeddings_std + self.celeb_embeddings_mean
            else:
                adained_total_embedding = total_embedding
            other_return_dict["total_embedding"] = total_embedding
            other_return_dict["adained_total_embedding"] = adained_total_embedding

        if name_batch is not None:
            if isinstance(name_batch, list):
                # Reference embeddings of real names: keep the two tokens after BOS.
                name_tokens = self.get_token_for_string(name_batch)[:, 1:3]
                name_embeddings = self.get_embedding_for_tkn(name_tokens.to(random_embeddings.device))[0]
                other_return_dict["name_embeddings"] = name_embeddings
            else:
                raise TypeError("name_batch must be a list of name strings")

        if tokenized_text is not None:
            # Replace the "*" placeholder (and the slot right after it) with the
            # two generated identity embeddings.
            placeholder_pos = get_rep_pos(tokenized_text, [self.placeholder_token])
            placeholder_pos = np.array(placeholder_pos)
            if len(placeholder_pos) != 0:
                batch_size = adained_total_embedding.shape[0]
                end_index = min(batch_size, placeholder_pos.shape[0])
                embedded_text[placeholder_pos[:, 0], placeholder_pos[:, 1]] = adained_total_embedding[:end_index, 0, :]
                embedded_text[placeholder_pos[:, 0], placeholder_pos[:, 1] + 1] = adained_total_embedding[:end_index, 1, :]

        return embedded_text, other_return_dict

    def load(self, ckpt_path):
        ckpt = torch.load(ckpt_path, map_location='cuda')
        if ckpt.get("name_projection_layer") is not None:
            self.name_projection_layer = ckpt.get("name_projection_layer").float()
        print('[Embedding Manager] weights loaded.')

    def save(self, ckpt_path):
        save_dict = {}
        save_dict["name_projection_layer"] = self.name_projection_layer
        torch.save(save_dict, ckpt_path)

    def trainable_projection_parameters(self):
        trainable_list = []
        trainable_list.extend(list(self.name_projection_layer.parameters()))
        return trainable_list
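
# A minimal usage sketch (added for illustration, not part of the original module).
# It assumes a transformers CLIP tokenizer/text encoder pair whose hidden size
# matches token_dim, and a prompt containing the "*" placeholder; those names and
# the prompt wording are assumptions, not confirmed by this file:
#
#   manager = EmbeddingManagerId_adain(tokenizer, text_encoder, experiment_name="normal_GAN")
#   tokenized_text = get_clip_token_for_string(tokenizer, ["a photo of * person"])
#   embedded_text = text_encoder.text_model.embeddings(tokenized_text)   # token embeddings
#   random_embeddings = torch.randn(1, 1024)                             # one identity vector
#   embedded_text, extras = manager(tokenized_text, embedded_text, None,
#                                   random_embeddings=random_embeddings)
#   # The embedding of "*" and of the token right after it are overwritten with the
#   # two generated pseudo-word embeddings; extras["adained_total_embedding"] holds them.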

class Embedding_discriminator(nn.Module):
    # A small MLP discriminator over the concatenated pseudo-word embeddings
    # (input_size = token_dim * num_embeds_per_token) that outputs one real/fake logit.
    def __init__(self, input_size, dropout_rate):
        super(Embedding_discriminator, self).__init__()
        self.input_size = input_size
        self.fc1 = nn.Linear(input_size, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 1)
        # Note: these LayerNorm modules are constructed but not applied in forward().
        self.LayerNorm1 = nn.LayerNorm(512)
        self.LayerNorm2 = nn.LayerNorm(256)
        self.leaky_relu = nn.LeakyReLU(0.2)
        self.dropout_rate = dropout_rate
        if self.dropout_rate > 0:
            self.dropout1 = nn.Dropout(dropout_rate)
            self.dropout2 = nn.Dropout(dropout_rate)

    def forward(self, input):
        x = input.view(-1, self.input_size)
        if self.dropout_rate > 0:
            x = self.leaky_relu(self.dropout1(self.fc1(x)))
        else:
            x = self.leaky_relu(self.fc1(x))
        if self.dropout_rate > 0:
            x = self.leaky_relu(self.dropout2(self.fc2(x)))
        else:
            x = self.leaky_relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def save(self, ckpt_path):
        save_dict = {}
        save_dict["fc1"] = self.fc1
        save_dict["fc2"] = self.fc2
        save_dict["fc3"] = self.fc3
        save_dict["LayerNorm1"] = self.LayerNorm1
        save_dict["LayerNorm2"] = self.LayerNorm2
        save_dict["leaky_relu"] = self.leaky_relu
        if self.dropout_rate > 0:
            save_dict["dropout1"] = self.dropout1
            save_dict["dropout2"] = self.dropout2
        torch.save(save_dict, ckpt_path)

    def load(self, ckpt_path):
        ckpt = torch.load(ckpt_path, map_location='cuda')
        if ckpt.get("fc1") is not None:
            self.fc1 = ckpt.get("fc1").float()
            self.fc2 = ckpt.get("fc2").float()
            self.fc3 = ckpt.get("fc3").float()
            self.LayerNorm1 = ckpt.get("LayerNorm1").float()
            self.LayerNorm2 = ckpt.get("LayerNorm2").float()
            self.leaky_relu = ckpt.get("leaky_relu").float()
            if ckpt.get("dropout1") is not None:
                self.dropout1 = ckpt.get("dropout1").float()
                self.dropout2 = ckpt.get("dropout2").float()
        print('[Embedding D] weights loaded.')
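

# A minimal smoke test for the discriminator (an illustrative sketch added here, not
# part of the original training code). With the defaults above (token_dim=1024,
# num_embeds_per_token=2) the discriminator input size is 2 * 1024 = 2048. Run it from
# the repository root so the models.* imports at the top of this file resolve.
if __name__ == "__main__":
    disc = Embedding_discriminator(input_size=2 * 1024, dropout_rate=0.2)
    fake_embeddings = torch.randn(4, 2, 1024)   # a batch of 4 generated (2 x 1024) embeddings
    logits = disc(fake_embeddings)              # flattened internally to (4, 2048)
    print(logits.shape)                         # torch.Size([4, 1])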