""" ein notation: b - batch n - sequence nt - text sequence nw - raw wave length d - dimension """ from __future__ import annotations from random import random from typing import Callable import numpy as np import torch import torch.nn.functional as F from torch import nn from ctcmodel import ConformerCTC from discriminator_conformer import ConformerDiscirminator from ecapa_tdnn import ECAPA_TDNN from f5_tts.model import DiT from f5_tts.model.utils import (default, exists, lens_to_mask, list_str_to_idx, list_str_to_tensor, mask_from_frac_lengths) class NoOpContext: def __enter__(self): pass def __exit__(self, *args): pass def predict_flow( transformer, # flow model x, # noisy input cond, # mask (prompt mask + length mask) text, # text input time, # time step second_time=None, cfg_strength=1.0, ): pred = transformer( x=x, cond=cond, text=text, time=time, second_time=second_time, drop_audio_cond=False, drop_text=False, ) if cfg_strength < 1e-5: return pred null_pred = transformer( x=x, cond=cond, text=text, time=time, second_time=second_time, drop_audio_cond=True, drop_text=True, ) return pred + (pred - null_pred) * cfg_strength def _kl_dist_func(x, y): log_probs = F.log_softmax(x, dim=2) target_probs = F.log_softmax(y, dim=2) return torch.nn.functional.kl_div( log_probs, target_probs, reduction="batchmean", log_target=True ) class Guidance(nn.Module): def __init__( self, real_unet: DiT, # teacher flow model fake_unet: DiT, # student flow model use_fp16: bool = True, real_guidance_scale: float = 0.0, fake_guidance_scale: float = 0.0, gen_cls_loss: bool = False, sv_path_en: str = "", sv_path_zh: str = "", ctc_path: str = "", sway_coeff: float = 0.0, scale: float = 1.0, ): super().__init__() self.vocab_size = real_unet.vocab_size if ctc_path != "": model = ConformerCTC( vocab_size=real_unet.vocab_size, mel_dim=real_unet.mel_dim, num_heads=8, d_hid=512, nlayers=6, ) self.ctc_model = model.eval() self.ctc_model.requires_grad_(False) self.ctc_model.load_state_dict( torch.load(ctc_path, weights_only=True, map_location="cpu")[ "model_state_dict" ] ) if sv_path_en != "": model = ECAPA_TDNN() self.sv_model_en = model.eval() self.sv_model_en.requires_grad_(False) self.sv_model_en.load_state_dict( torch.load(sv_path, weights_only=True, map_location="cpu")[ "model_state_dict" ] ) if sv_path_zh != "": model = ECAPA_TDNN() self.sv_model_zh = model.eval() self.sv_model_zh.requires_grad_(False) self.sv_model_zh.load_state_dict( torch.load(sv_path_zh, weights_only=True, map_location="cpu")[ "model_state_dict" ] ) self.scale = scale self.real_unet = real_unet self.real_unet.requires_grad_(False) # no update on the teacher model self.fake_unet = fake_unet self.fake_unet.requires_grad_(True) # update the student model self.real_guidance_scale = real_guidance_scale self.fake_guidance_scale = fake_guidance_scale assert self.fake_guidance_scale == 0, "no guidance for fake" self.use_fp16 = use_fp16 self.gen_cls_loss = gen_cls_loss self.sway_coeff = sway_coeff if self.gen_cls_loss: self.cls_pred_branch = ConformerDiscirminator( input_dim=(self.fake_unet.depth + 1) * self.fake_unet.dim + 3 * 512, # 3 is the number of layers from the CTC model num_layers=3, channels=self.fake_unet.dim // 2, ) self.cls_pred_branch.requires_grad_(True) self.network_context_manager = ( torch.autocast(device_type="cuda", dtype=torch.float16) if self.use_fp16 else NoOpContext() ) from torch.utils.data import DataLoader, Dataset, SequentialSampler from f5_tts.model.dataset import (DynamicBatchSampler, collate_fn, load_dataset) from f5_tts.model.utils import get_tokenizer bsz = 16 tokenizer = "pinyin" # 'pinyin', 'char', or 'custom' tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt) dataset_name = "Emilia_ZH_EN" if tokenizer == "custom": tokenizer_path = tokenizer_path else: tokenizer_path = dataset_name vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer) self.vocab_char_map = vocab_char_map def compute_distribution_matching_loss( self, inp: float["b n d"] | float["b nw"], # mel or raw wave, ground truth latent text: int["b nt"] | list[str], # text input *, second_time: torch.Tensor | None = None, # second time step for flow prediction rand_span_mask: ( bool["b n d"] | bool["b nw"] | None ) = None, # combined mask (prompt mask + padding mask) ): """ Compute DMD loss (L_DMD) between the student distribution and teacher distribution. Following the DMDSpeech logic: - Sample time t - Construct noisy input phi = (1 - t)*x0 + t*x1, where x0 is noise and x1 is inp - Predict flows with teacher (f_phi) and student (G_theta) - Compute gradient that aligns student distribution with teacher distribution The code is adapted from F5-TTS but conceptualized per DMD: L_DMD encourages p_theta to match p_data via the difference between teacher and student predictions. """ original_inp = inp with torch.no_grad(): batch, seq_len, dtype, device = *inp.shape[:2], inp.dtype, inp.device # mel is x1 x1 = inp # x0 is gaussian noise x0 = torch.randn_like(x1) # time step time = torch.rand((batch,), dtype=dtype, device=device) # get flow t = time.unsqueeze(-1).unsqueeze(-1) # t = t + self.sway_coeff * (torch.cos(torch.pi / 2 * t) - 1 + t) sigma_t, alpha_t = (1 - t), t phi = (1 - t) * x0 + t * x1 # noisy x flow = x1 - x0 # flow target # only predict what is within the random mask span for infilling cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1) # run at full precision as autocast and no_grad doesn't work well together with self.network_context_manager: pred_fake = predict_flow( self.fake_unet, phi, cond, # mask (prompt mask + length mask) text, # text input time, # time step second_time=second_time, cfg_strength=self.fake_guidance_scale, ) # pred = (x1 - x0), thus phi + (1-t) * pred = (1 - t) * x0 + t * x1 + (1 - t) * (x1 - x0) = (1 - t) * x1 + t * x1 = x1 pred_fake_image = phi + (1 - t) * pred_fake pred_fake_image[~rand_span_mask] = inp[~rand_span_mask] with self.network_context_manager: pred_real = predict_flow( self.real_unet, phi, cond, text, time, cfg_strength=self.real_guidance_scale, ) pred_real_image = phi + (1 - t) * pred_real pred_real_image[~rand_span_mask] = inp[~rand_span_mask] p_real = inp - pred_real_image p_fake = inp - pred_fake_image grad = (p_real - p_fake) / torch.abs(p_real).mean(dim=[1, 2], keepdim=True) grad = torch.nan_to_num(grad) # grad = grad / sigma_t # pred_fake - pred_real # grad = grad * (1 + sigma_t / alpha_t) # grad = grad / (1 + sigma_t / alpha_t) # noise # grad = grad / sigma_t # score difference # grad = grad * alpha_t # grad = grad * (sigma_t ** 2 / alpha_t) # grad = grad * (alpha_t + sigma_t ** 2 / alpha_t) # The DMD loss: MSE to move student distribution closer to teacher distribution # Only optimize over the masked region loss = ( 0.5 * F.mse_loss( original_inp.float(), (original_inp - grad).detach().float(), reduction="none", ) * rand_span_mask.unsqueeze(-1) ) loss = loss.sum() / (rand_span_mask.sum() * grad.size(-1)) loss_dict = {"loss_dm": loss} dm_log_dict = { "dmtrain_time": time.detach().float(), "dmtrain_noisy_inp": phi.detach().float(), "dmtrain_pred_real_image": pred_real_image.detach().float(), "dmtrain_pred_fake_image": pred_fake_image.detach().float(), "dmtrain_grad": grad.detach().float(), "dmtrain_gradient_norm": torch.norm(grad).item(), } return loss_dict, dm_log_dict def compute_ctc_sv_loss( self, real_inp: torch.Tensor, # real data latent fake_inp: torch.Tensor, # student-generated data latent text: torch.Tensor, text_lens: torch.Tensor, rand_span_mask: torch.Tensor, second_time: torch.Tensor | None = None, ): """ Compute CTC + SV loss for direct metric optimization, as described in DMDSpeech. - CTC loss reduces WER - SV loss improves speaker similarity Both CTC and SV models operate on latents. """ # compute CTC loss out, layer, ctc_loss = self.ctc_model( fake_inp * self.scale, text, text_lens ) # lengths from rand_span_mask or known with torch.no_grad(): real_out, real_layers, ctc_loss_test = self.ctc_model( real_inp * self.scale, text, text_lens ) real_logits = real_out.log_softmax(dim=2) # emb_real = self.sv_model(real_inp * self.scale) # snippet from prompt region fake_logits = out.log_softmax(dim=2) kl_loss = F.kl_div(fake_logits, real_logits, reduction="mean", log_target=True) # For SV: # Extract speaker embeddings from real (prompt) and fake: # emb_fake = self.sv_model(fake_inp * self.scale) # sv_loss = 1 - F.cosine_similarity(emb_real, emb_fake, dim=-1).mean() input_lengths = rand_span_mask.sum(axis=-1).cpu().numpy() prompt_lengths = real_inp.size(1) - rand_span_mask.sum(axis=-1).cpu().numpy() chunks_real = [] chunks_fake = [] mel_len = min([int(input_lengths.min().item() - 1), 300]) for bib in range(len(input_lengths)): prompt_length = int(prompt_lengths[bib].item()) mel_length = int(input_lengths[bib].item()) mask = rand_span_mask[bib] mask = torch.where(mask)[0] prompt_start = mask[0].cpu().numpy() prompt_end = mask[-1].cpu().numpy() if prompt_end - mel_len <= prompt_start: random_start = np.random.randint(0, mel_length - mel_len) else: random_start = np.random.randint(prompt_start, prompt_end - mel_len) chunks_fake.append(fake_inp[bib, random_start : random_start + mel_len, :]) chunks_real.append(real_inp[bib, :mel_len, :]) chunks_real = torch.stack(chunks_real, dim=0) chunks_fake = torch.stack(chunks_fake, dim=0) with torch.no_grad(): emb_real_en = self.sv_model_en(chunks_real * self.scale) emb_fake_en = self.sv_model_en(chunks_fake * self.scale) sv_loss_en = 1 - F.cosine_similarity(emb_real_en, emb_fake_en, dim=-1).mean() with torch.no_grad(): emb_real_zh = self.sv_model_zh(chunks_real * self.scale) emb_fake_zh = self.sv_model_zh(chunks_fake * self.scale) sv_loss_zh = 1 - F.cosine_similarity(emb_real_zh, emb_fake_zh, dim=-1).mean() sv_loss = (sv_loss_en + sv_loss_zh) / 2 return ( {"loss_ctc": ctc_loss, "loss_kl": kl_loss, "loss_sim": sv_loss}, layer, real_layers, ) def compute_loss_fake( self, inp: torch.Tensor, # student generator output text: torch.Tensor | list[str], rand_span_mask: torch.Tensor, second_time: torch.Tensor | None = None, ): """ Compute flow loss for the fake flow model, which is trained to estimate the flow (score) of the student distribution. This is the same as L_diff in the paper. """ # Similar to distribution matching, but only train fake to predict flow directly batch, seq_len, dtype, device = *inp.shape[:2], inp.dtype, inp.device if isinstance(text, list): if exists(self.vocab_char_map): text = list_str_to_idx(text, self.vocab_char_map).to(device) else: text = list_str_to_tensor(text).to(device) assert text.shape[0] == batch # Sample a time time = torch.rand((batch,), dtype=dtype, device=device) x1 = inp x0 = torch.randn_like(x1) t = time.unsqueeze(-1).unsqueeze(-1) phi = (1 - t) * x0 + t * x1 flow = x1 - x0 cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1) with self.network_context_manager: pred = self.fake_unet( x=phi, cond=cond, text=text, time=time, second_time=second_time, drop_audio_cond=False, drop_text=False, # make sure the cfg=1 ) # Compute MSE between predicted flow and actual flow, masked by rand_span_mask loss = F.mse_loss(pred, flow, reduction="none") loss = loss[rand_span_mask].mean() loss_dict = {"loss_fake_mean": loss} log_dict = { "faketrain_noisy_inp": phi.detach().float(), "faketrain_x1": x1.detach().float(), "faketrain_pred_flow": pred.detach().float(), } return loss_dict, log_dict def compute_cls_logits( self, inp: torch.Tensor, # student generator output layer: torch.Tensor, text: torch.Tensor, rand_span_mask: torch.Tensor, second_time: torch.Tensor | None = None, guidance: bool = False, ): """ Compute adversarial loss logits for the generator. This is used to compute L_adv in the paper. """ context_no_grad = torch.no_grad if guidance else NoOpContext with context_no_grad(): # If we are not doing generator classification loss, return zeros if not self.gen_cls_loss: return torch.zeros_like(inp[..., 0]) # shape (b, n) # For classification, we need some representation: # We'll mimic the logic from compute_loss_fake batch, seq_len, dtype, device = *inp.shape[:2], inp.dtype, inp.device if isinstance(text, list): if exists(self.vocab_char_map): text = list_str_to_idx(text, self.vocab_char_map).to(device) else: text = list_str_to_tensor(text).to(device) assert text.shape[0] == batch # Sample a time time = torch.rand((batch,), dtype=dtype, device=device) x1 = inp x0 = torch.randn_like(x1) t = time.unsqueeze(-1).unsqueeze(-1) phi = (1 - t) * x0 + t * x1 cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1) with self.network_context_manager: layers = self.fake_unet( x=phi, cond=cond, text=text, time=time, second_time=second_time, drop_audio_cond=False, drop_text=False, # make sure the cfg=1 classify_mode=True, ) # layers = torch.stack(layers, dim=0) if guidance: layers = [layer.detach() for layer in layers] layer = layer[-3:] # only use the last 3 layers layer = [l.transpose(-1, -2) for l in layer] # layer = [F.interpolate(l, mode='nearest', scale_factor=4).transpose(-1, -2) for l in layer] if layer[0].size(1) < layers[0].size(1): layer = [F.pad(l, (0, 0, 0, layers[0].size(1) - l.size(1))) for l in layer] layers = layer + layers # logits: (b, 1) logits = self.cls_pred_branch(layers) return logits, layers def compute_generator_cls_loss( self, inp: torch.Tensor, # student generator output layer: torch.Tensor, real_layers: torch.Tensor, text: torch.Tensor, rand_span_mask: torch.Tensor, second_time: torch.Tensor | None = None, mse_loss: bool = False, mse_inp: torch.Tensor | None = None, ): """ Compute the adversarial loss for the generator. """ # Compute classification loss for generator: if not self.gen_cls_loss: return {"gen_cls_loss": 0} logits, fake_layers = self.compute_cls_logits( inp, layer, text, rand_span_mask, second_time, guidance=False ) loss = ((1 - logits) ** 2).mean() return {"gen_cls_loss": loss, "loss_mse": 0} def compute_guidance_cls_loss( self, fake_inp: torch.Tensor, text: torch.Tensor, rand_span_mask: torch.Tensor, real_data: dict, second_time: torch.Tensor | None = None, ): """ This function computes the adversarial loss for the discirminator. The discriminator is trained to classify the generator output as real or fake. """ with torch.no_grad(): # get layers from CTC model _, layer = self.ctc_model(fake_inp * self.scale) logits_fake, _ = self.compute_cls_logits( fake_inp.detach(), layer, text, rand_span_mask, second_time, guidance=True ) loss_fake = (logits_fake**2).mean() real_inp = real_data["inp"] with torch.no_grad(): # get layers from CTC model _, layer = self.ctc_model(real_inp * self.scale) logits_real, _ = self.compute_cls_logits( real_inp.detach(), layer, text, rand_span_mask, second_time, guidance=True ) loss_real = ((1 - logits_real) ** 2).mean() classification_loss = loss_real + loss_fake loss_dict = {"guidance_cls_loss": classification_loss} log_dict = { "pred_realism_on_real": loss_real.detach().item(), "pred_realism_on_fake": loss_fake.detach().item(), } return loss_dict, log_dict def generator_forward( self, inp: torch.Tensor, text: torch.Tensor, text_lens: torch.Tensor, text_normalized: torch.Tensor, text_normalized_lens: torch.Tensor, rand_span_mask: torch.Tensor, real_data: ( dict | None ) = None, # ground truth data (primarily prompt) to compute SV loss second_time: torch.Tensor | None = None, mse_loss: bool = False, ): """ Forward pass for the generator. This function computes the loss for the generator, which includes: - Distribution matching loss (L_DMD) - Adversarial generator loss (L_adv(G; D)) - CTC/SV loss (L_ctc + L_sv) """ # 1. Compute DM loss dm_loss_dict, dm_log_dict = self.compute_distribution_matching_loss( inp, text, rand_span_mask=rand_span_mask, second_time=second_time ) ctc_sv_loss_dict = {} cls_loss_dict = {} # 2. Compute optional CTC/SV loss if real_data provided if real_data is not None: real_inp = real_data["inp"] ctc_sv_loss_dict, layer, real_layers = self.compute_ctc_sv_loss( real_inp, inp, text_normalized, text_normalized_lens, rand_span_mask, second_time=second_time, ) # 3. Compute optional classification loss if self.gen_cls_loss: cls_loss_dict = self.compute_generator_cls_loss( inp, layer, real_layers, text, rand_span_mask=rand_span_mask, second_time=second_time, mse_inp=real_data["inp"] if real_data is not None else None, mse_loss=mse_loss, ) loss_dict = {**dm_loss_dict, **cls_loss_dict, **ctc_sv_loss_dict} log_dict = {**dm_log_dict} return loss_dict, log_dict def guidance_forward( self, fake_inp: torch.Tensor, text: torch.Tensor, text_lens: torch.Tensor, rand_span_mask: torch.Tensor, real_data: dict | None = None, second_time: torch.Tensor | None = None, ): """ Forward pass for the guidnce module (discriminator + fake flow function). This function computes the loss for the guidance module, which includes: - Flow matching loss (L_diff) - Adversarial discrminator loss (L_adv(D; G)) """ # Compute fake loss (like epsilon prediction loss in Guidance) fake_loss_dict, fake_log_dict = self.compute_loss_fake( fake_inp, text, rand_span_mask=rand_span_mask, second_time=second_time ) # If gen_cls_loss, compute guidance cls loss cls_loss_dict = {} cls_log_dict = {} if self.gen_cls_loss and real_data is not None: cls_loss_dict, cls_log_dict = self.compute_guidance_cls_loss( fake_inp, text, rand_span_mask, real_data, second_time=second_time ) loss_dict = {**fake_loss_dict, **cls_loss_dict} log_dict = {**fake_log_dict, **cls_log_dict} return loss_dict, log_dict def forward( self, generator_turn=False, guidance_turn=False, generator_data_dict=None, guidance_data_dict=None, ): if generator_turn: loss_dict, log_dict = self.generator_forward( inp=generator_data_dict["inp"], text=generator_data_dict["text"], text_lens=generator_data_dict["text_lens"], text_normalized=generator_data_dict["text_normalized"], text_normalized_lens=generator_data_dict["text_normalized_lens"], rand_span_mask=generator_data_dict["rand_span_mask"], real_data=generator_data_dict.get("real_data", None), second_time=generator_data_dict.get("second_time", None), mse_loss=generator_data_dict.get("mse_loss", False), ) elif guidance_turn: loss_dict, log_dict = self.guidance_forward( fake_inp=guidance_data_dict["inp"], text=guidance_data_dict["text"], text_lens=guidance_data_dict["text_lens"], rand_span_mask=guidance_data_dict["rand_span_mask"], real_data=guidance_data_dict.get("real_data", None), second_time=guidance_data_dict.get("second_time", None), ) else: raise NotImplementedError( "Must specify either generator_turn or guidance_turn" ) return loss_dict, log_dict if __name__ == "__main__": from f5_tts.model.utils import get_tokenizer bsz = 16 tokenizer = "pinyin" # 'pinyin', 'char', or 'custom' tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt) dataset_name = "Emilia_ZH_EN" if tokenizer == "custom": tokenizer_path = tokenizer_path else: tokenizer_path = dataset_name vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer) real_unet = DiT( dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4, text_num_embeds=vocab_size, mel_dim=100, ) fake_unet = DiT( dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4, text_num_embeds=vocab_size, mel_dim=100, ) guidance = Guidance( real_unet, fake_unet, real_guidance_scale=1.0, fake_guidance_scale=0.0, use_fp16=True, gen_cls_loss=True, ).cuda() text = ["hello world"] * bsz lens = torch.randint(1, 1000, (bsz,)).cuda() inp = torch.randn(bsz, lens.max(), 80).cuda() batch, seq_len, dtype, device = *inp.shape[:2], inp.dtype, inp.device # handle text as string if isinstance(text, list): if exists(vocab_char_map): text = list_str_to_idx(text, vocab_char_map).to(device) else: text = list_str_to_tensor(text).to(device) assert text.shape[0] == batch # lens and mask if not exists(lens): lens = torch.full((batch,), seq_len, device=device) mask = lens_to_mask( lens, length=seq_len ) # useless here, as collate_fn will pad to max length in batch frac_lengths_mask = (0.7, 1.0) # get a random span to mask out for training conditionally frac_lengths = ( torch.zeros((batch,), device=device).float().uniform_(*frac_lengths_mask) ) rand_span_mask = mask_from_frac_lengths(lens, frac_lengths) if exists(mask): rand_span_mask &= mask # Construct data dicts for generator and guidance phases # For flow, `real_data` can just be the ground truth if available; here we simulate it real_data_dict = { "inp": torch.zeros_like(inp), # simulating real data } generator_data_dict = { "inp": inp, "text": text, "rand_span_mask": rand_span_mask, "real_data": real_data_dict, } guidance_data_dict = { "inp": inp, "text": text, "rand_span_mask": rand_span_mask, "real_data": real_data_dict, } # Generator forward pass loss_dict, log_dict = guidance( generator_turn=True, generator_data_dict=generator_data_dict ) print("Generator turn losses:", loss_dict) # Guidance forward pass loss_dict, log_dict = guidance( guidance_turn=True, guidance_data_dict=guidance_data_dict ) print("Guidance turn losses:", loss_dict)