import copy
import logging
import os
from collections import OrderedDict

import torch
import torch.nn as nn
from diffusers import (
    AutoencoderDC,
    DPMSolverMultistepScheduler,
    FlowMatchEulerDiscreteScheduler,
    SanaTransformer2DModel,
)
from safetensors.torch import load_file

from .pipeline_sana import SanaPipeline

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ToClipMLP(nn.Module):
    """Two-layer MLP that projects vision features to the text-embedding
    width expected by the Sana transformer."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 2048)
        self.layer_norm1 = nn.LayerNorm(2048)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(2048, output_dim)
        self.layer_norm2 = nn.LayerNorm(output_dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.layer_norm1(hidden_states)
        hidden_states = self.relu(hidden_states)
        hidden_states = self.fc2(hidden_states)
        hidden_states = self.layer_norm2(hidden_states)
        return hidden_states
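
# Illustrative shape check (kept as a comment so nothing runs at import
# time); the 1152 -> 2304 projection matches the defaults used by
# SanaModel_withMLP below:
#
#     mlp = ToClipMLP(1152, 2304)
#     out = mlp(torch.randn(2, 77, 1152))  # (batch, tokens, vision_dim)
#     assert out.shape == (2, 77, 2304)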


class SanaModel_withMLP(nn.Module):
    """Wraps a SanaTransformer2DModel with a projection MLP so that vision
    features can be passed directly as `encoder_hidden_states`."""

    def __init__(self, sana, vision_dim=1152):
        super().__init__()
        self.sana = sana
        self.dtype = torch.bfloat16
        self.mlp = ToClipMLP(vision_dim, 2304)
        # Expose the wrapped transformer's config so this wrapper can stand
        # in wherever a SanaTransformer2DModel is expected (e.g. pipelines).
        self.config = self.sana.config

    def forward(self, hidden_states,
                timestep,
                encoder_hidden_states,
                return_dict=False,
                encoder_attention_mask=None,
                **kwargs):
        # Project the conditioning features into Sana's embedding space
        # before running the transformer; the underlying call always uses
        # return_dict=False and therefore returns a tuple.
        encoder_hidden_states = self.mlp(encoder_hidden_states)
        hidden_states = self.sana(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            timestep=timestep,
            return_dict=False,
            **kwargs,
        )
        return hidden_states

    def enable_gradient_checkpointing(self):
        self.sana.enable_gradient_checkpointing()
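
# Sketch of how the wrapper is typically driven during training
# (illustrative; `latents`, `timesteps`, and `vision_feats` come from the
# surrounding training loop and are not defined in this module):
#
#     model = SanaModel_withMLP(sana_transformer, vision_dim=1152)
#     pred = model(
#         hidden_states=latents,               # (B, C, H, W) latents
#         timestep=timesteps,                  # (B,) sampled timesteps
#         encoder_hidden_states=vision_feats,  # (B, T, vision_dim)
#     )[0]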


def inference_load_denoising_pretrained_weights(
    net,
    weights_path,
    names=None,
    prefix_to_remove=None,
):
    # Inference-time loading from a .safetensors checkpoint. `names` and
    # `prefix_to_remove` are accepted for signature parity with
    # load_denoising_pretrained_weights but are not used here.
    state_dict = load_file(weights_path)
    net.load_state_dict(state_dict, strict=False)


def load_denoising_pretrained_weights(
    net,
    weights_path,
    names=None,
    prefix_to_remove=None,
):
    state_dict = torch.load(weights_path, map_location="cpu")
    # Training checkpoints may nest the weights under a "model" or "net" key.
    if "model" in state_dict:
        state_dict = state_dict["model"]
    elif "net" in state_dict:
        state_dict = state_dict["net"]

    # Optionally keep only the parameters listed in `names`, stripping a
    # common prefix (e.g. "module." left behind by DistributedDataParallel).
    if names is not None:
        selected_state_dict = OrderedDict()
        for ori_name in names:
            name = ori_name[len(prefix_to_remove):] if prefix_to_remove is not None else ori_name
            selected_state_dict[name] = state_dict[ori_name]
        state_dict = selected_state_dict

    net.load_state_dict(state_dict, strict=True)
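
# Illustrative call: load a subset of weights from a DDP checkpoint whose
# parameter names carry a "module." prefix (`wanted_names` is hypothetical):
#
#     load_denoising_pretrained_weights(
#         net, "checkpoint.pt",
#         names=wanted_names, prefix_to_remove="module.",
#     )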


class SANALoss(torch.nn.Module):
    def __init__(
        self,
        model_path,
        scheduler_path,
        vision_dim=3584,
        diffusion_type='flow_matching',
        convert_vpred_to_xpred=True,
        checkpoint_path=None,
        mlp_state_dict=None,
        trainable_params='all',
        device='cpu',
        guidance_scale=3.5,
        revision=None,
        variant=None,
        repa_loss=False,
        mid_layer_idx=10,
        mid_loss_weight=1.0,
    ):
        # NOTE: `diffusion_type`, `convert_vpred_to_xpred`, `device`, and
        # `guidance_scale` are kept for API compatibility but are not used
        # below.
        super().__init__()
        self.torch_type = torch.bfloat16
        self.base_model_path = model_path
        self.use_mid_loss = repa_loss
        self.mid_loss_weight = mid_loss_weight
        self.mid_layer_idx = mid_layer_idx

        self.scheduler = DPMSolverMultistepScheduler.from_pretrained(model_path, subfolder="scheduler")

        self.device = torch.device(torch.cuda.current_device())
        self.scheduler_path = scheduler_path
        self.vae = AutoencoderDC.from_pretrained(
            model_path,
            subfolder="vae",
            revision=revision,
            variant=variant,
        )
        self.vae.requires_grad_(False)

        self.train_model = SanaTransformer2DModel.from_pretrained(
            model_path, subfolder="transformer", revision=revision, variant=variant
        )

        if checkpoint_path is not None:
            assert os.path.exists(checkpoint_path)
            load_denoising_pretrained_weights(self.train_model, checkpoint_path)

        self.train_model = SanaModel_withMLP(self.train_model, vision_dim=vision_dim)

        assert mlp_state_dict is not None
        self.train_model.mlp.load_state_dict(mlp_state_dict, strict=True)

        # Optional REPA-style mid-layer alignment head: projects transformer
        # hidden states (width 2240) down to a 32-dim feature space.
        hidden_dim = 2240
        self.mid_layer_mlp = None
        if self.use_mid_loss:
            self.mid_layer_mlp = torch.nn.Sequential(
                torch.nn.Linear(hidden_dim, hidden_dim * 2),
                torch.nn.GELU(),
                torch.nn.Linear(hidden_dim * 2, 32),
                torch.nn.LayerNorm(32),
            )
            for m in self.mid_layer_mlp.modules():
                if isinstance(m, torch.nn.Linear):
                    torch.nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
                    if m.bias is not None:
                        torch.nn.init.zeros_(m.bias)

        self.train_model.enable_gradient_checkpointing()
        self.set_trainable_params(trainable_params)

        num_parameters_trainable = 0
        num_parameters = 0
        name_parameters_trainable = []
        for n, p in self.train_model.named_parameters():
            num_parameters += p.data.nelement()
            if not p.requires_grad:
                continue
            name_parameters_trainable.append(n)
            num_parameters_trainable += p.data.nelement()
        logger.info(f"Trainable parameters: {num_parameters_trainable:,} / {num_parameters:,}")

        self.noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
            self.scheduler_path, subfolder="scheduler"
        )
        self.noise_scheduler_copy = copy.deepcopy(self.noise_scheduler)

        logger.info("Preparation done. Starting training diffusion ...")

    def get_sigmas(self, timesteps, n_dim=4, dtype=torch.float32):
        # Look up the sigma that corresponds to each timestep and broadcast
        # it to `n_dim` dimensions so it can scale e.g. (B, C, H, W) latents.
        sigmas = self.noise_scheduler_copy.sigmas.to(device=timesteps.device, dtype=dtype)
        schedule_timesteps = self.noise_scheduler_copy.timesteps.to(device=timesteps.device)
        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < n_dim:
            sigma = sigma.unsqueeze(-1)
        return sigma
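
    # In flow matching, sigma interpolates linearly between data and noise;
    # a typical training step uses the broadcast sigmas like this
    # (illustrative; `latents` and `noise` come from the training loop):
    #
    #     sigmas = self.get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype)
    #     noisy_latents = (1.0 - sigmas) * latents + sigmas * noise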

    def compute_text_embeddings(self, prompt, text_encoders, tokenizers):
        # NOTE: `encode_prompt` and `local_rank` are expected to be provided
        # by the surrounding training script; they are not defined in this
        # module.
        with torch.no_grad():
            prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
                [text_encoders], [tokenizers], prompt, 77
            )
            pooled_prompt_embeds = pooled_prompt_embeds.to(local_rank)

        return prompt_embeds, pooled_prompt_embeds, text_ids

    def set_trainable_params(self, trainable_params):
        # The VAE is always frozen; only the transformer/MLP can train.
        self.vae.requires_grad_(False)

        if trainable_params == 'all':
            self.train_model.requires_grad_(True)
        else:
            # Freeze everything, then unfreeze every module whose qualified
            # name contains one of the given substrings (see the example
            # after this method).
            self.train_model.requires_grad_(False)
            for name, module in self.train_model.named_modules():
                for trainable_param in trainable_params:
                    if trainable_param in name:
                        for params in module.parameters():
                            params.requires_grad = True

        num_parameters_trainable = 0
        num_parameters = 0
        name_parameters_trainable = []
        for n, p in self.train_model.named_parameters():
            num_parameters += p.data.nelement()
            if not p.requires_grad:
                continue
            name_parameters_trainable.append(n)
            num_parameters_trainable += p.data.nelement()
        logger.info(f"Trainable parameters: {num_parameters_trainable:,} / {num_parameters:,}")
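
    # Illustrative selective-training call (substring match against module
    # names; 'mlp' matches the projection head added by SanaModel_withMLP):
    #
    #     loss_module.set_trainable_params(['mlp'])  # train only the MLP
    #     loss_module.set_trainable_params('all')    # train everything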

    def sample(self, encoder_hidden_states, steps=20, cfg=7.0, seed=42, height=512, width=512):
        # Build a text-encoder-free pipeline around the current weights;
        # prompt embeddings are passed in directly instead of raw text.
        self.pipelines = SanaPipeline(
            vae=self.vae,
            transformer=self.train_model,
            text_encoder=None,
            tokenizer=None,
            scheduler=self.noise_scheduler,
        ).to(self.device)

        prompt_attention_mask = torch.ones(encoder_hidden_states.shape[:2]).to(self.device)
        negative_attention_mask = torch.ones(encoder_hidden_states.shape[:2]).to(self.device)

        image = self.pipelines(
            prompt_embeds=encoder_hidden_states,
            prompt_attention_mask=prompt_attention_mask,
            # Zeroed embeddings act as the unconditional input for
            # classifier-free guidance.
            negative_prompt_embeds=encoder_hidden_states * 0,
            negative_prompt_attention_mask=negative_attention_mask,
            guidance_scale=cfg,
            generator=torch.manual_seed(seed),
            num_inference_steps=steps,
            device=self.device,
            height=height,
            width=width,
            max_sequence_length=300,
        ).images[0]

        return image
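

# End-to-end sampling sketch (illustrative; the paths are placeholders and
# `vision_feats` must come from the caller's vision encoder):
#
#     loss_module = SANALoss(
#         model_path="/path/to/sana_diffusers_checkpoint",
#         scheduler_path="/path/to/sana_diffusers_checkpoint",
#         mlp_state_dict=torch.load("/path/to/mlp.pt", map_location="cpu"),
#     )
#     image = loss_module.sample(vision_feats, steps=20, cfg=7.0)
#     image.save("sample.png")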