diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..e2b8299740957bd22e34bf9c8e2fc06e446cc70a 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +text/ja_userdic/userdict.csv filter=lfs diff=lfs merge=lfs -text +text/g2pw/polyphonic-fix.rep filter=lfs diff=lfs merge=lfs -text +text/g2pw/polyphonic.pickle filter=lfs diff=lfs merge=lfs -text +text/g2pw/polyphonic.rep filter=lfs diff=lfs merge=lfs -text +text/G2PWModel/char_bopomofo_dict.json filter=lfs diff=lfs merge=lfs -text +text/cmudict-fast.rep filter=lfs diff=lfs merge=lfs -text +text/cmudict.rep filter=lfs diff=lfs merge=lfs -text +text/engdict-hot.rep filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..9952969eae88032c9fb9201bdee72269d07b7cc7 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +*/.DS_Store +.DS_Store +.idea/ \ No newline at end of file diff --git a/AR/__init__.py b/AR/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/AR/data/__init__.py b/AR/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/AR/data/bucket_sampler.py b/AR/data/bucket_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..d84573340732b916eaa6e9f8e88cb4166d6f1ca5 --- /dev/null +++ b/AR/data/bucket_sampler.py @@ -0,0 +1,149 @@ +# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/bucket_sampler.py +# reference: https://github.com/lifeiteng/vall-e +import itertools +import math +import random +from random import shuffle +from typing import Iterator, Optional, TypeVar + +import torch +import torch.distributed as dist +from torch.utils.data import Dataset, Sampler + +__all__ = [ + "DistributedBucketSampler", +] + +T_co = TypeVar("T_co", covariant=True) + + +class DistributedBucketSampler(Sampler[T_co]): + r""" + sort the dataset wrt. input length + divide samples into buckets + sort within buckets + divide buckets into batches + sort batches + """ + + def __init__( + self, + dataset: Dataset, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = False, + batch_size: int = 32, + ) -> None: + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = dist.get_world_size() if torch.cuda.is_available() else 1 + if rank is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = dist.get_rank() if torch.cuda.is_available() else 0 + if torch.cuda.is_available(): + torch.cuda.set_device(rank) + if rank >= num_replicas or rank < 0: + raise ValueError("Invalid rank {}, rank should be in the interval [0, {}]".format(rank, num_replicas - 1)) + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.drop_last = drop_last + # If the dataset length is evenly divisible by # of replicas, then there + # is no need to drop any data, since the dataset will be split equally. 
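+ # Note: with drop_last the per-rank sample count is effectively rounded down; otherwise it is rounded up and the tail is padded by repeating indices in __iter__.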
+ if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type] + # Split to nearest available length that is evenly divisible. + # This is to ensure each rank receives the same amount of data when + # using this Sampler. + self.num_samples = math.ceil( + (len(self.dataset) - self.num_replicas) / self.num_replicas, # type: ignore[arg-type] + ) + else: + self.num_samples = math.ceil( + len(self.dataset) / self.num_replicas, + ) # type: ignore[arg-type] + self.total_size = self.num_samples * self.num_replicas + self.shuffle = shuffle + self.seed = seed + self.batch_size = batch_size + self.id_with_length = self._get_sample_lengths() + self.id_buckets = self.make_buckets(bucket_width=2.0) + + def _get_sample_lengths(self): + id_with_lengths = [] + for i in range(len(self.dataset)): + id_with_lengths.append((i, self.dataset.get_sample_length(i))) + id_with_lengths.sort(key=lambda x: x[1]) + return id_with_lengths + + def make_buckets(self, bucket_width: float = 2.0): + buckets = [] + cur = [] + max_sec = bucket_width + for id, sec in self.id_with_length: + if sec < max_sec: + cur.append(id) + else: + buckets.append(cur) + cur = [id] + max_sec += bucket_width + if len(cur) > 0: + buckets.append(cur) + return buckets + + def __iter__(self) -> Iterator[T_co]: + if self.shuffle: + # deterministically shuffle based on epoch and seed + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + random.seed(self.epoch + self.seed) + shuffled_bucket = [] + for buc in self.id_buckets: + buc_copy = buc.copy() + shuffle(buc_copy) + shuffled_bucket.append(buc_copy) + grouped_batch_size = self.batch_size * self.num_replicas + shuffled_bucket = list(itertools.chain(*shuffled_bucket)) + n_batch = int(math.ceil(len(shuffled_bucket) / grouped_batch_size)) + batches = [shuffled_bucket[b * grouped_batch_size : (b + 1) * grouped_batch_size] for b in range(n_batch)] + shuffle(batches) + indices = list(itertools.chain(*batches)) + else: + # type: ignore[arg-type] + indices = list(range(len(self.dataset))) + + if not self.drop_last: + # add extra samples to make it evenly divisible + padding_size = self.total_size - len(indices) + if padding_size <= len(indices): + indices += indices[:padding_size] + else: + indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] + else: + # remove tail of data to make it evenly divisible. + indices = indices[: self.total_size] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank : self.total_size : self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self) -> int: + return self.num_samples + + def set_epoch(self, epoch: int) -> None: + r""" + Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas + use a different random ordering for each epoch. Otherwise, the next iteration of this + sampler will yield the same ordering. + + Args: + epoch (int): Epoch number. 
+ """ + self.epoch = epoch diff --git a/AR/data/data_module.py b/AR/data/data_module.py new file mode 100644 index 0000000000000000000000000000000000000000..f360503ba06301fc130ae4afbe27ecae4dae33ef --- /dev/null +++ b/AR/data/data_module.py @@ -0,0 +1,81 @@ +# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/data_module.py +# reference: https://github.com/lifeiteng/vall-e +from pytorch_lightning import LightningDataModule +from torch.utils.data import DataLoader + +from AR.data.bucket_sampler import DistributedBucketSampler +from AR.data.dataset import Text2SemanticDataset + + +class Text2SemanticDataModule(LightningDataModule): + def __init__( + self, + config, + train_semantic_path, + train_phoneme_path, + dev_semantic_path=None, + dev_phoneme_path=None, + ): + super().__init__() + self.config = config + self.train_semantic_path = train_semantic_path + self.train_phoneme_path = train_phoneme_path + self.dev_semantic_path = dev_semantic_path + self.dev_phoneme_path = dev_phoneme_path + self.num_workers = self.config["data"]["num_workers"] + + def prepare_data(self): + pass + + def setup(self, stage=None, output_logs=False): + self._train_dataset = Text2SemanticDataset( + phoneme_path=self.train_phoneme_path, + semantic_path=self.train_semantic_path, + max_sec=self.config["data"]["max_sec"], + pad_val=self.config["data"]["pad_val"], + ) + self._dev_dataset = self._train_dataset + # self._dev_dataset = Text2SemanticDataset( + # phoneme_path=self.dev_phoneme_path, + # semantic_path=self.dev_semantic_path, + # max_sample=self.config['data']['max_eval_sample'], + # max_sec=self.config['data']['max_sec'], + # pad_val=self.config['data']['pad_val']) + + def train_dataloader(self): + batch_size = ( + self.config["train"]["batch_size"] // 2 + if self.config["train"].get("if_dpo", False) is True + else self.config["train"]["batch_size"] + ) + batch_size = max(min(batch_size, len(self._train_dataset) // 4), 1) # 防止不保存 + sampler = DistributedBucketSampler(self._train_dataset, batch_size=batch_size) + return DataLoader( + self._train_dataset, + batch_size=batch_size, + sampler=sampler, + collate_fn=self._train_dataset.collate, + num_workers=self.num_workers, + persistent_workers=True, + prefetch_factor=16, + ) + + def val_dataloader(self): + return DataLoader( + self._dev_dataset, + batch_size=1, + shuffle=False, + collate_fn=self._train_dataset.collate, + num_workers=max(self.num_workers, 12), + persistent_workers=True, + prefetch_factor=16, + ) + + # 这个会使用到嘛? 
+ def test_dataloader(self): + return DataLoader( + self._dev_dataset, + batch_size=1, + shuffle=False, + collate_fn=self._train_dataset.collate, + ) diff --git a/AR/data/dataset.py b/AR/data/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..402483d918bd05b8609fbcde4eb18f87b2242560 --- /dev/null +++ b/AR/data/dataset.py @@ -0,0 +1,320 @@ +# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/dataset.py +# reference: https://github.com/lifeiteng/vall-e + +# sys.path.append("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert") +import os +import traceback +from typing import Dict, List + +import numpy as np +import pandas as pd +import torch +from torch.utils.data import DataLoader, Dataset + +version = os.environ.get("version", None) + +from text import cleaned_text_to_sequence + +# from config import exp_dir + + +def batch_sequences(sequences: List[np.array], axis: int = 0, pad_value: int = 0): + seq = sequences[0] + ndim = seq.ndim + if axis < 0: + axis += ndim + dtype = seq.dtype + pad_value = dtype.type(pad_value) + seq_lengths = [seq.shape[axis] for seq in sequences] + max_length = np.max(seq_lengths) + + padded_sequences = [] + for seq, length in zip(sequences, seq_lengths): + padding = [(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (ndim - axis - 1) + padded_seq = np.pad(seq, padding, mode="constant", constant_values=pad_value) + padded_sequences.append(padded_seq) + batch = np.stack(padded_sequences) + return batch + + +class Text2SemanticDataset(Dataset): + """dataset class for text tokens to semantic model training.""" + + def __init__( + self, + phoneme_path: str, + semantic_path: str, + max_sample: int = None, + max_sec: int = 100, + pad_val: int = 1024, + # min value of phoneme/sec + min_ps_ratio: int = 3, + # max value of phoneme/sec + max_ps_ratio: int = 25, + ) -> None: + super().__init__() + + self.semantic_data = pd.read_csv( + semantic_path, + delimiter="\t", + encoding="utf-8", + ) + # get dict + self.path2 = phoneme_path # "%s/2-name2text.txt"%exp_dir#phoneme_path + self.path3 = "%s/3-bert" % ( + os.path.dirname( + phoneme_path, + ) + ) # "%s/3-bert"%exp_dir#bert_dir + self.path6 = semantic_path # "%s/6-name2semantic.tsv"%exp_dir#semantic_path + assert os.path.exists(self.path2) + assert os.path.exists(self.path6) + self.phoneme_data = {} + with open(self.path2, "r", encoding="utf8") as f: + lines = f.read().strip("\n").split("\n") + + for line in lines: + tmp = line.split("\t") + if len(tmp) != 4: + continue + self.phoneme_data[tmp[0]] = [tmp[1], tmp[2], tmp[3]] + + # self.phoneme_data = np.load(phoneme_path, allow_pickle=True).item() + # pad for semantic tokens + self.PAD: int = pad_val + # self.hz = 25 + # with open("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert/configs/s2.json", "r") as f:data = f.read() + # data=json.loads(data)["model"]["semantic_frame_rate"]#50hz + # self.hz=int(data[:-2])# + self.hz = int(os.environ.get("hz", "25hz")[:-2]) + + # max seconds of semantic token + self.max_sec = max_sec + self.min_ps_ratio = min_ps_ratio + self.max_ps_ratio = max_ps_ratio + + if max_sample is not None: + self.semantic_data = self.semantic_data[:max_sample] + + # {idx: (semantic, phoneme)} + # semantic list, phoneme list + self.semantic_phoneme = [] + self.item_names = [] + + self.inited = False + + if not self.inited: + # 调用初始化函数 + self.init_batch() + self.inited = True + del self.semantic_data + del self.phoneme_data + # self.tokenizer = 
AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext-large") + # self.tokenizer = AutoTokenizer.from_pretrained("/data/docker/liujing04/bert-vits2/Bert-VITS2-master20231106/bert/chinese-roberta-wwm-ext-large") + + def init_batch(self): + semantic_data_len = len(self.semantic_data) + phoneme_data_len = len(self.phoneme_data.keys()) + print("semantic_data_len:", semantic_data_len) + print("phoneme_data_len:", phoneme_data_len) + print(self.semantic_data) + idx = 0 + num_not_in = 0 + num_deleted_bigger = 0 + num_deleted_ps = 0 + for i in range(semantic_data_len): + # 先依次遍历 + # get str + item_name = self.semantic_data.iloc[i, 0] + # print(self.phoneme_data) + try: + phoneme, word2ph, text = self.phoneme_data[item_name] + except Exception: + traceback.print_exc() + # print(f"{item_name} not in self.phoneme_data !") + num_not_in += 1 + continue + + semantic_str = self.semantic_data.iloc[i, 1] + # get token list + semantic_ids = [int(idx) for idx in semantic_str.split(" ")] + # (T), 是否需要变成 (1, T) -> 不需要,因为需要求 len + # 过滤掉太长的样本 + if ( + len(semantic_ids) > self.max_sec * self.hz + ): #########1###根据token个数推测总时长过滤时长60s(config里)#40*25=1k + num_deleted_bigger += 1 + continue + # (T, ), 这个速度不会很慢,所以可以在一开始就处理,无需在 __getitem__ 里面单个处理#### + phoneme = phoneme.split(" ") + + try: + phoneme_ids = cleaned_text_to_sequence(phoneme, version) + except: + traceback.print_exc() + # print(f"{item_name} not in self.phoneme_data !") + num_not_in += 1 + continue + # if len(phoneme_ids) >400:###########2:改为恒定限制为semantic/2.5就行 + if len(phoneme_ids) > self.max_sec * self.hz / 2.5: ###########2:改为恒定限制为semantic/2.5就行 + num_deleted_ps += 1 + continue + # if len(semantic_ids) > 1000:###########3 + # num_deleted_bigger += 1 + # continue + + ps_ratio = len(phoneme_ids) / (len(semantic_ids) / self.hz) + + if ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio: ##########4#3~25#每秒多少个phone + num_deleted_ps += 1 + # print(item_name) + continue + + self.semantic_phoneme.append((semantic_ids, phoneme_ids)) + idx += 1 + self.item_names.append(item_name) + + min_num = 100 # 20直接不补#30补了也不存ckpt + leng = len(self.semantic_phoneme) + if leng < min_num: + tmp1 = self.semantic_phoneme + tmp2 = self.item_names + self.semantic_phoneme = [] + self.item_names = [] + for _ in range(max(2, int(min_num / leng))): + self.semantic_phoneme += tmp1 + self.item_names += tmp2 + if num_not_in > 0: + print(f"there are {num_not_in} semantic datas not in phoneme datas") + if num_deleted_bigger > 0: + print( + f"deleted {num_deleted_bigger} audios who's duration are bigger than {self.max_sec} seconds", + ) + if num_deleted_ps > 0: + # 4702 for LibriTTS, LirbriTTS 是标注数据, 是否需要筛?=> 需要,有值为 100 的极端值 + print( + f"deleted {num_deleted_ps} audios who's phoneme/sec are bigger than {self.max_ps_ratio} or smaller than {self.min_ps_ratio}", + ) + """ + there are 31 semantic datas not in phoneme datas + deleted 34 audios who's duration are bigger than 54 seconds + deleted 3190 audios who's phoneme/sec are bigger than 25 or smaller than 3 + dataset.__len__(): 366463 + + """ + # 345410 for LibriTTS + print("dataset.__len__():", self.__len__()) + + def __get_item_names__(self) -> List[str]: + return self.item_names + + def __len__(self) -> int: + return len(self.semantic_phoneme) + + def __getitem__(self, idx: int) -> Dict: + semantic_ids, phoneme_ids = self.semantic_phoneme[idx] + item_name = self.item_names[idx] + phoneme_ids_len = len(phoneme_ids) + # semantic tokens target + semantic_ids_len = len(semantic_ids) + + flag = 0 + path_bert = "%s/%s.pt" % 
(self.path3, item_name) + if os.path.exists(path_bert) == True: + bert_feature = torch.load(path_bert, map_location="cpu") + else: + flag = 1 + if flag == 1: + # bert_feature=torch.zeros_like(phoneme_ids,dtype=torch.float32) + bert_feature = None + else: + assert bert_feature.shape[-1] == len(phoneme_ids) + return { + "idx": idx, + "phoneme_ids": phoneme_ids, + "phoneme_ids_len": phoneme_ids_len, + "semantic_ids": semantic_ids, + "semantic_ids_len": semantic_ids_len, + "bert_feature": bert_feature, + } + + def get_sample_length(self, idx: int): + semantic_ids = self.semantic_phoneme[idx][0] + sec = 1.0 * len(semantic_ids) / self.hz + return sec + + def collate(self, examples: List[Dict]) -> Dict: + sample_index: List[int] = [] + phoneme_ids: List[torch.Tensor] = [] + phoneme_ids_lens: List[int] = [] + semantic_ids: List[torch.Tensor] = [] + semantic_ids_lens: List[int] = [] + # return + + for item in examples: + sample_index.append(item["idx"]) + phoneme_ids.append(np.array(item["phoneme_ids"], dtype=np.int64)) + semantic_ids.append(np.array(item["semantic_ids"], dtype=np.int64)) + phoneme_ids_lens.append(item["phoneme_ids_len"]) + semantic_ids_lens.append(item["semantic_ids_len"]) + + # pad 0 + phoneme_ids = batch_sequences(phoneme_ids) + semantic_ids = batch_sequences(semantic_ids, pad_value=self.PAD) + + # # convert each batch to torch.tensor + phoneme_ids = torch.tensor(phoneme_ids) + semantic_ids = torch.tensor(semantic_ids) + phoneme_ids_lens = torch.tensor(phoneme_ids_lens) + semantic_ids_lens = torch.tensor(semantic_ids_lens) + bert_padded = torch.FloatTensor(len(examples), 1024, max(phoneme_ids_lens)) + bert_padded.zero_() + + for idx, item in enumerate(examples): + bert = item["bert_feature"] + if bert != None: + bert_padded[idx, :, : bert.shape[-1]] = bert + + return { + # List[int] + "ids": sample_index, + # torch.Tensor (B, max_phoneme_length) + "phoneme_ids": phoneme_ids, + # torch.Tensor (B) + "phoneme_ids_len": phoneme_ids_lens, + # torch.Tensor (B, max_semantic_ids_length) + "semantic_ids": semantic_ids, + # torch.Tensor (B) + "semantic_ids_len": semantic_ids_lens, + # torch.Tensor (B, 1024, max_phoneme_length) + "bert_feature": bert_padded, + } + + +if __name__ == "__main__": + root_dir = "/data/docker/liujing04/gpt-vits/prepare/dump_mix/" + dataset = Text2SemanticDataset( + phoneme_path=root_dir + "phoneme_train.npy", + semantic_path=root_dir + "semantic_train.tsv", + ) + + batch_size = 12 + dataloader = DataLoader( + dataset, + batch_size=batch_size, + collate_fn=dataset.collate, + shuffle=False, + ) + for i, batch in enumerate(dataloader): + if i % 1000 == 0: + print(i) + # if i == 0: + # print('batch["ids"]:', batch["ids"]) + # print('batch["phoneme_ids"]:', batch["phoneme_ids"], + # batch["phoneme_ids"].shape) + # print('batch["phoneme_ids_len"]:', batch["phoneme_ids_len"], + # batch["phoneme_ids_len"].shape) + # print('batch["semantic_ids"]:', batch["semantic_ids"], + # batch["semantic_ids"].shape) + # print('batch["semantic_ids_len"]:', batch["semantic_ids_len"], + # batch["semantic_ids_len"].shape) diff --git a/AR/models/__init__.py b/AR/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/AR/models/t2s_lightning_module.py b/AR/models/t2s_lightning_module.py new file mode 100644 index 0000000000000000000000000000000000000000..fd357b9452b908c61b366957e2db609bdac0734a --- /dev/null +++ b/AR/models/t2s_lightning_module.py @@ -0,0 +1,146 @@ +# modified from 
https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_lightning_module.py +# reference: https://github.com/lifeiteng/vall-e +import os +import sys + +now_dir = os.getcwd() +sys.path.append(now_dir) +from typing import Dict + +import torch +from pytorch_lightning import LightningModule + +from AR.models.t2s_model import Text2SemanticDecoder +from AR.modules.lr_schedulers import WarmupCosineLRSchedule +from AR.modules.optim import ScaledAdam + + +class Text2SemanticLightningModule(LightningModule): + def __init__(self, config, output_dir, is_train=True): + super().__init__() + self.config = config + self.top_k = 3 + self.model = Text2SemanticDecoder(config=config, top_k=self.top_k) + pretrained_s1 = config.get("pretrained_s1") + if pretrained_s1 and is_train: + # print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"])) + print( + self.load_state_dict( + torch.load( + pretrained_s1, + map_location="cpu", + weights_only=False, + )["weight"], + ) + ) + if is_train: + self.automatic_optimization = False + self.save_hyperparameters() + self.eval_dir = output_dir / "eval" + self.eval_dir.mkdir(parents=True, exist_ok=True) + + def training_step(self, batch: Dict, batch_idx: int): + opt = self.optimizers() + scheduler = self.lr_schedulers() + forward = self.model.forward if self.config["train"].get("if_dpo", False) == True else self.model.forward_old + loss, acc = forward( + batch["phoneme_ids"], + batch["phoneme_ids_len"], + batch["semantic_ids"], + batch["semantic_ids_len"], + batch["bert_feature"], + ) + self.manual_backward(loss) + if batch_idx > 0 and batch_idx % 4 == 0: + opt.step() + opt.zero_grad() + scheduler.step() + + self.log( + "total_loss", + loss, + on_step=True, + on_epoch=True, + prog_bar=True, + sync_dist=True, + ) + self.log( + "lr", + scheduler.get_last_lr()[0], + on_epoch=True, + prog_bar=True, + sync_dist=True, + ) + self.log( + f"top_{self.top_k}_acc", + acc, + on_step=True, + on_epoch=True, + prog_bar=True, + sync_dist=True, + ) + + def validation_step(self, batch: Dict, batch_idx: int): + return + + # # get loss + # loss, acc = self.model.forward( + # batch['phoneme_ids'], batch['phoneme_ids_len'], + # batch['semantic_ids'], batch['semantic_ids_len'], + # batch['bert_feature'] + # ) + # + # self.log( + # "val_total_loss", + # loss, + # on_step=True, + # on_epoch=True, + # prog_bar=True, + # sync_dist=True) + # self.log( + # f"val_top_{self.top_k}_acc", + # acc, + # on_step=True, + # on_epoch=True, + # prog_bar=True, + # sync_dist=True) + # + # # get infer output + # semantic_len = batch['semantic_ids'].size(1) + # prompt_len = min(int(semantic_len * 0.5), 150) + # prompt = batch['semantic_ids'][:, :prompt_len] + # pred_semantic = self.model.infer(batch['phoneme_ids'], + # batch['phoneme_ids_len'], prompt, + # batch['bert_feature'] + # ) + # save_name = f'semantic_toks_{batch_idx}.pt' + # save_path = os.path.join(self.eval_dir, save_name) + # torch.save(pred_semantic.detach().cpu(), save_path) + + def configure_optimizers(self): + model_parameters = self.model.parameters() + parameters_names = [] + parameters_names.append([name_param_pair[0] for name_param_pair in self.model.named_parameters()]) + lm_opt = ScaledAdam( + model_parameters, + lr=0.01, + betas=(0.9, 0.95), + clipping_scale=2.0, + parameters_names=parameters_names, + show_dominant_parameters=False, + clipping_update_period=1000, + ) + + return { + "optimizer": lm_opt, + "lr_scheduler": { + "scheduler": WarmupCosineLRSchedule( + lm_opt, + 
init_lr=self.config["optimizer"]["lr_init"], + peak_lr=self.config["optimizer"]["lr"], + end_lr=self.config["optimizer"]["lr_end"], + warmup_steps=self.config["optimizer"]["warmup_steps"], + total_steps=self.config["optimizer"]["decay_steps"], + ) + }, + } diff --git a/AR/models/t2s_lightning_module_onnx.py b/AR/models/t2s_lightning_module_onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..b0ab59c4c80ed2cf0f7b5aa93e27cd60bf8279c7 --- /dev/null +++ b/AR/models/t2s_lightning_module_onnx.py @@ -0,0 +1,110 @@ +# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_lightning_module.py +# reference: https://github.com/lifeiteng/vall-e +import os +import sys + +now_dir = os.getcwd() +sys.path.append(now_dir) +from typing import Dict + +import torch +from pytorch_lightning import LightningModule + +from AR.models.t2s_model_onnx import Text2SemanticDecoder +from AR.modules.lr_schedulers import WarmupCosineLRSchedule +from AR.modules.optim import ScaledAdam + + +class Text2SemanticLightningModule(LightningModule): + def __init__(self, config, output_dir, is_train=True): + super().__init__() + self.config = config + self.top_k = 3 + self.model = Text2SemanticDecoder(config=config, top_k=self.top_k) + pretrained_s1 = config.get("pretrained_s1") + if pretrained_s1 and is_train: + # print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"])) + print( + self.load_state_dict( + torch.load( + pretrained_s1, + map_location="cpu", + )["weight"], + ), + ) + if is_train: + self.automatic_optimization = False + self.save_hyperparameters() + self.eval_dir = output_dir / "eval" + self.eval_dir.mkdir(parents=True, exist_ok=True) + + def training_step(self, batch: Dict, batch_idx: int): + opt = self.optimizers() + scheduler = self.lr_schedulers() + loss, acc = self.model.forward( + batch["phoneme_ids"], + batch["phoneme_ids_len"], + batch["semantic_ids"], + batch["semantic_ids_len"], + batch["bert_feature"], + ) + self.manual_backward(loss) + if batch_idx > 0 and batch_idx % 4 == 0: + opt.step() + opt.zero_grad() + scheduler.step() + + self.log( + "total_loss", + loss, + on_step=True, + on_epoch=True, + prog_bar=True, + sync_dist=True, + ) + self.log( + "lr", + scheduler.get_last_lr()[0], + on_epoch=True, + prog_bar=True, + sync_dist=True, + ) + self.log( + f"top_{self.top_k}_acc", + acc, + on_step=True, + on_epoch=True, + prog_bar=True, + sync_dist=True, + ) + + def validation_step(self, batch: Dict, batch_idx: int): + return + + def configure_optimizers(self): + model_parameters = self.model.parameters() + parameters_names = [] + parameters_names.append([name_param_pair[0] for name_param_pair in self.model.named_parameters()]) + lm_opt = ScaledAdam( + model_parameters, + lr=0.01, + betas=(0.9, 0.95), + clipping_scale=2.0, + parameters_names=parameters_names, + show_dominant_parameters=False, + clipping_update_period=1000, + ) + + return { + "optimizer": lm_opt, + "lr_scheduler": { + "scheduler": WarmupCosineLRSchedule( + lm_opt, + init_lr=self.config["optimizer"]["lr_init"], + peak_lr=self.config["optimizer"]["lr"], + end_lr=self.config["optimizer"]["lr_end"], + warmup_steps=self.config["optimizer"]["warmup_steps"], + total_steps=self.config["optimizer"]["decay_steps"], + ) + }, + } diff --git a/AR/models/t2s_model.py b/AR/models/t2s_model.py new file mode 100644 index 0000000000000000000000000000000000000000..4725b7a3b2657cf1c83a035e7dd796ce51703745 --- /dev/null +++ b/AR/models/t2s_model.py @@ -0,0 +1,935 @@ +# 
modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py +# reference: https://github.com/lifeiteng/vall-e +import math +from typing import List, Optional + +import torch +from torch import nn +from torch.nn import functional as F +from torchmetrics.classification import MulticlassAccuracy +from tqdm import tqdm + +from AR.models.utils import ( + dpo_loss, + get_batch_logps, + make_pad_mask, + make_pad_mask_left, + make_reject_y, + sample, + topk_sampling, +) +from AR.modules.embedding import SinePositionalEmbedding, TokenEmbedding +from AR.modules.transformer import LayerNorm, TransformerEncoder, TransformerEncoderLayer + +default_config = { + "embedding_dim": 512, + "hidden_dim": 512, + "num_head": 8, + "num_layers": 12, + "num_codebook": 8, + "p_dropout": 0.0, + "vocab_size": 1024 + 1, + "phoneme_vocab_size": 512, + "EOS": 1024, +} + + +# @torch.jit.script ## 使用的话首次推理会非常慢,而且推理速度不稳定 +# Efficient implementation equivalent to the following: +def scaled_dot_product_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + scale: Optional[torch.Tensor] = None, +) -> torch.Tensor: + B, H, L, S = query.size(0), query.size(1), query.size(-2), key.size(-2) + if scale is None: + scale_factor = torch.tensor(1 / math.sqrt(query.size(-1))) + else: + scale_factor = scale + attn_bias = torch.zeros(B, H, L, S, dtype=query.dtype, device=query.device) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_bias.masked_fill_(attn_mask, float("-inf")) + else: + attn_bias += attn_mask + attn_weight = query @ key.transpose(-2, -1) * scale_factor + attn_weight += attn_bias + attn_weight = torch.softmax(attn_weight, dim=-1) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_weight.masked_fill_(attn_mask, 0) + else: + attn_mask[attn_mask != float("-inf")] = 0 + attn_mask[attn_mask == float("-inf")] = 1 + attn_weight.masked_fill_(attn_mask, 0) + + return attn_weight @ value + + +@torch.jit.script +class T2SMLP: + def __init__(self, w1, b1, w2, b2): + self.w1 = w1 + self.b1 = b1 + self.w2 = w2 + self.b2 = b2 + + def forward(self, x): + x = F.relu(F.linear(x, self.w1, self.b1)) + x = F.linear(x, self.w2, self.b2) + return x + + +@torch.jit.script +class T2SBlock: + def __init__( + self, + num_heads, + hidden_dim: int, + mlp: T2SMLP, + qkv_w, + qkv_b, + out_w, + out_b, + norm_w1, + norm_b1, + norm_eps1, + norm_w2, + norm_b2, + norm_eps2, + ): + self.num_heads = num_heads + self.mlp = mlp + self.hidden_dim: int = hidden_dim + self.qkv_w = qkv_w + self.qkv_b = qkv_b + self.out_w = out_w + self.out_b = out_b + self.norm_w1 = norm_w1 + self.norm_b1 = norm_b1 + self.norm_eps1 = norm_eps1 + self.norm_w2 = norm_w2 + self.norm_b2 = norm_b2 + self.norm_eps2 = norm_eps2 + + self.false = torch.tensor(False, dtype=torch.bool) + + @torch.jit.ignore + def to_mask( + self, + x: torch.Tensor, + padding_mask: Optional[torch.Tensor], + ): + if padding_mask is None: + return x + + if padding_mask.dtype == torch.bool: + return x.masked_fill(padding_mask, 0) + else: + return x * padding_mask + + def process_prompt( + self, + x: torch.Tensor, + attn_mask: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + torch_sdpa: bool = True, + ): + q, k, v = F.linear(self.to_mask(x, padding_mask), self.qkv_w, self.qkv_b).chunk(3, dim=-1) + + batch_size = q.shape[0] + q_len = q.shape[1] + kv_len = k.shape[1] + + q = self.to_mask(q, padding_mask) + k_cache = self.to_mask(k, padding_mask) 
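+ # The masked k/v projections double as the KV cache returned to the caller; decode_next_token later appends one new key/value per step instead of re-encoding the prompt.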
+ v_cache = self.to_mask(v, padding_mask) + + q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2) + k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2) + v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2) + + if torch_sdpa: + attn = F.scaled_dot_product_attention(q, k, v, ~attn_mask) + else: + attn = scaled_dot_product_attention(q, k, v, attn_mask) + + attn = attn.transpose(1, 2).reshape(batch_size, q_len, -1) + attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b) + + x = x + attn + x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1) + x = x + self.mlp.forward(x) + x = F.layer_norm( + x, + [self.hidden_dim], + self.norm_w2, + self.norm_b2, + self.norm_eps2, + ) + return x, k_cache, v_cache + + def decode_next_token( + self, + x: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + attn_mask: torch.Tensor = None, + torch_sdpa: bool = True, + ): + q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1) + + k_cache = torch.cat([k_cache, k], dim=1) + v_cache = torch.cat([v_cache, v], dim=1) + + batch_size = q.shape[0] + q_len = q.shape[1] + kv_len = k_cache.shape[1] + + q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2) + k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2) + v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2) + + if torch_sdpa: + attn = F.scaled_dot_product_attention(q, k, v, (~attn_mask) if attn_mask is not None else None) + else: + attn = scaled_dot_product_attention(q, k, v, attn_mask) + + attn = attn.transpose(1, 2).reshape(batch_size, q_len, -1) + attn = F.linear(attn, self.out_w, self.out_b) + + x = x + attn + x = F.layer_norm( + x, + [self.hidden_dim], + self.norm_w1, + self.norm_b1, + self.norm_eps1, + ) + x = x + self.mlp.forward(x) + x = F.layer_norm( + x, + [self.hidden_dim], + self.norm_w2, + self.norm_b2, + self.norm_eps2, + ) + return x, k_cache, v_cache + + +@torch.jit.script +class T2STransformer: + def __init__(self, num_blocks: int, blocks: List[T2SBlock]): + self.num_blocks: int = num_blocks + self.blocks = blocks + + def process_prompt( + self, + x: torch.Tensor, + attn_mask: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + torch_sdpa: bool = True, + ): + k_cache: List[torch.Tensor] = [] + v_cache: List[torch.Tensor] = [] + for i in range(self.num_blocks): + x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask, padding_mask, torch_sdpa) + k_cache.append(k_cache_) + v_cache.append(v_cache_) + return x, k_cache, v_cache + + def decode_next_token( + self, + x: torch.Tensor, + k_cache: List[torch.Tensor], + v_cache: List[torch.Tensor], + attn_mask: torch.Tensor = None, + torch_sdpa: bool = True, + ): + for i in range(self.num_blocks): + x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token( + x, k_cache[i], v_cache[i], attn_mask, torch_sdpa + ) + return x, k_cache, v_cache + + +class Text2SemanticDecoder(nn.Module): + def __init__(self, config, norm_first=False, top_k=3): + super(Text2SemanticDecoder, self).__init__() + self.model_dim = config["model"]["hidden_dim"] + self.embedding_dim = config["model"]["embedding_dim"] + self.num_head = config["model"]["head"] + self.num_layers = config["model"]["n_layer"] + self.norm_first = norm_first + self.vocab_size = config["model"]["vocab_size"] + self.phoneme_vocab_size = config["model"]["phoneme_vocab_size"] + self.p_dropout = config["model"]["dropout"] + self.EOS = config["model"]["EOS"] + 
self.norm_first = norm_first + assert self.EOS == self.vocab_size - 1 + # should be same as num of kmeans bin + # assert self.EOS == 1024 + self.bert_proj = nn.Linear(1024, self.embedding_dim) + self.ar_text_embedding = TokenEmbedding( + self.embedding_dim, + self.phoneme_vocab_size, + self.p_dropout, + ) + self.ar_text_position = SinePositionalEmbedding( + self.embedding_dim, + dropout=0.1, + scale=False, + alpha=True, + ) + self.ar_audio_embedding = TokenEmbedding( + self.embedding_dim, + self.vocab_size, + self.p_dropout, + ) + self.ar_audio_position = SinePositionalEmbedding( + self.embedding_dim, + dropout=0.1, + scale=False, + alpha=True, + ) + + self.h = TransformerEncoder( + TransformerEncoderLayer( + d_model=self.model_dim, + nhead=self.num_head, + dim_feedforward=self.model_dim * 4, + dropout=0.1, + batch_first=True, + norm_first=norm_first, + ), + num_layers=self.num_layers, + norm=LayerNorm(self.model_dim) if norm_first else None, + ) + + self.ar_predict_layer = nn.Linear(self.model_dim, self.vocab_size, bias=False) + self.loss_fct = nn.CrossEntropyLoss(reduction="sum") + + self.ar_accuracy_metric = MulticlassAccuracy( + self.vocab_size, + top_k=top_k, + average="micro", + multidim_average="global", + ignore_index=self.EOS, + ) + + blocks = [] + + for i in range(self.num_layers): + layer = self.h.layers[i] + t2smlp = T2SMLP( + layer.linear1.weight, + layer.linear1.bias, + layer.linear2.weight, + layer.linear2.bias, + ) + + block = T2SBlock( + self.num_head, + self.model_dim, + t2smlp, + layer.self_attn.in_proj_weight, + layer.self_attn.in_proj_bias, + layer.self_attn.out_proj.weight, + layer.self_attn.out_proj.bias, + layer.norm1.weight, + layer.norm1.bias, + layer.norm1.eps, + layer.norm2.weight, + layer.norm2.bias, + layer.norm2.eps, + ) + + blocks.append(block) + + self.t2s_transformer = T2STransformer(self.num_layers, blocks) + + def make_input_data(self, x, x_lens, y, y_lens, bert_feature): + x = self.ar_text_embedding(x) + x = x + self.bert_proj(bert_feature.transpose(1, 2)) + x = self.ar_text_position(x) + x_mask = make_pad_mask(x_lens) + + y_mask = make_pad_mask(y_lens) + y_mask_int = y_mask.type(torch.int64) + codes = y.type(torch.int64) * (1 - y_mask_int) + + # Training + # AR Decoder + y, targets = self.pad_y_eos(codes, y_mask_int, eos_id=self.EOS) + x_len = x_lens.max() + y_len = y_lens.max() + y_emb = self.ar_audio_embedding(y) + y_pos = self.ar_audio_position(y_emb) + + xy_padding_mask = torch.concat([x_mask, y_mask], dim=1) + + ar_xy_padding_mask = xy_padding_mask + + x_attn_mask = F.pad( + torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device), + (0, y_len), + value=True, + ) + # x_attn_mask[:, x_len]=False + y_attn_mask = F.pad( + torch.triu( + torch.ones(y_len, y_len, dtype=torch.bool, device=x.device), + diagonal=1, + ), + (x_len, 0), + value=False, + ) + + xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0) + bsz, src_len = x.shape[0], x_len + y_len + _xy_padding_mask = ( + ar_xy_padding_mask.view(bsz, 1, 1, src_len) + .expand(-1, self.num_head, -1, -1) + .reshape(bsz * self.num_head, 1, src_len) + ) + xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask) + new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype) + new_attn_mask.masked_fill_(xy_attn_mask, float("-inf")) + xy_attn_mask = new_attn_mask + # x 和完整的 y 一次性输入模型 + xy_pos = torch.concat([x, y_pos], dim=1) + + return xy_pos, xy_attn_mask, targets + + def forward(self, x, x_lens, y, y_lens, bert_feature): + """ + x: phoneme_ids + y: semantic_ids + """ + + reject_y, 
reject_y_lens = make_reject_y(y, y_lens) + + xy_pos, xy_attn_mask, targets = self.make_input_data(x, x_lens, y, y_lens, bert_feature) + + xy_dec, _ = self.h( + (xy_pos, None), + mask=xy_attn_mask, + ) + x_len = x_lens.max() + logits = self.ar_predict_layer(xy_dec[:, x_len:]) + + ###### DPO ############# + reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data( + x, x_lens, reject_y, reject_y_lens, bert_feature + ) + + reject_xy_dec, _ = self.h( + (reject_xy_pos, None), + mask=reject_xy_attn_mask, + ) + x_len = x_lens.max() + reject_logits = self.ar_predict_layer(reject_xy_dec[:, x_len:]) + + # loss + # from feiteng: 每次 duration 越多, 梯度更新也应该更多, 所以用 sum + + loss_1 = F.cross_entropy(logits.permute(0, 2, 1), targets, reduction="sum") + acc = self.ar_accuracy_metric(logits.permute(0, 2, 1).detach(), targets).item() + + A_logits, R_logits = get_batch_logps(logits, reject_logits, targets, reject_targets) + loss_2, _, _ = dpo_loss(A_logits, R_logits, 0, 0, 0.2, reference_free=True) + + loss = loss_1 + loss_2 + + return loss, acc + + def forward_old(self, x, x_lens, y, y_lens, bert_feature): + """ + x: phoneme_ids + y: semantic_ids + """ + x = self.ar_text_embedding(x) + x = x + self.bert_proj(bert_feature.transpose(1, 2)) + x = self.ar_text_position(x) + x_mask = make_pad_mask(x_lens) + + y_mask = make_pad_mask(y_lens) + y_mask_int = y_mask.type(torch.int64) + codes = y.type(torch.int64) * (1 - y_mask_int) + + # Training + # AR Decoder + y, targets = self.pad_y_eos(codes, y_mask_int, eos_id=self.EOS) + x_len = x_lens.max() + y_len = y_lens.max() + y_emb = self.ar_audio_embedding(y) + y_pos = self.ar_audio_position(y_emb) + + xy_padding_mask = torch.concat([x_mask, y_mask], dim=1) + ar_xy_padding_mask = xy_padding_mask + + x_attn_mask = F.pad( + torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device), + (0, y_len), + value=True, + ) + y_attn_mask = F.pad( + torch.triu( + torch.ones(y_len, y_len, dtype=torch.bool, device=x.device), + diagonal=1, + ), + (x_len, 0), + value=False, + ) + xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0) + bsz, src_len = x.shape[0], x_len + y_len + _xy_padding_mask = ( + ar_xy_padding_mask.view(bsz, 1, 1, src_len) + .expand(-1, self.num_head, -1, -1) + .reshape(bsz * self.num_head, 1, src_len) + ) + xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask) + new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype) + new_attn_mask.masked_fill_(xy_attn_mask, float("-inf")) + xy_attn_mask = new_attn_mask + # x 和完整的 y 一次性输入模型 + xy_pos = torch.concat([x, y_pos], dim=1) + xy_dec, _ = self.h( + (xy_pos, None), + mask=xy_attn_mask, + ) + logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1) + # loss + # from feiteng: 每次 duration 越多, 梯度更新也应该更多, 所以用 sum + loss = F.cross_entropy(logits, targets, reduction="sum") + acc = self.ar_accuracy_metric(logits.detach(), targets).item() + return loss, acc + + # 需要看下这个函数和 forward 的区别以及没有 semantic 的时候 prompts 输入什么 + def infer( + self, + x, + x_lens, + prompts, + bert_feature, + top_k: int = -100, + early_stop_num: int = -1, + temperature: float = 1.0, + ): + x = self.ar_text_embedding(x) + x = x + self.bert_proj(bert_feature.transpose(1, 2)) + x = self.ar_text_position(x) + + # AR Decoder + y = prompts + prefix_len = y.shape[1] + x_len = x.shape[1] + x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool) + stop = False + for _ in tqdm(range(1500)): + y_emb = self.ar_audio_embedding(y) + y_pos = self.ar_audio_position(y_emb) + # x 和逐渐增长的 y 一起输入给模型 + xy_pos = torch.concat([x, y_pos], 
dim=1) + y_len = y.shape[1] + x_attn_mask_pad = F.pad( + x_attn_mask, + (0, y_len), + value=True, + ) + y_attn_mask = F.pad( + torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1), + (x_len, 0), + value=False, + ) + xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(y.device) + + xy_dec, _ = self.h( + (xy_pos, None), + mask=xy_attn_mask, + ) + logits = self.ar_predict_layer(xy_dec[:, -1]) + samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature) + + if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num: + print("use early stop num:", early_stop_num) + stop = True + + if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS: + # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS) + stop = True + if stop: + if prompts.shape[1] == y.shape[1]: + y = torch.concat([y, torch.zeros_like(samples)], dim=1) + print("bad zero prediction") + print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]") + break + # 本次生成的 semantic_ids 和之前的 y 构成新的 y + # print(samples.shape)#[1,1]#第一个1是bs + # import os + # os._exit(2333) + y = torch.concat([y, samples], dim=1) + return y + + def pad_y_eos(self, y, y_mask_int, eos_id): + targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(y_mask_int, (0, 1), value=1) + # 错位 + return targets[:, :-1], targets[:, 1:] + + def infer_panel_batch_infer( + self, + x: List[torch.LongTensor], #####全部文本token + x_lens: torch.LongTensor, + prompts: torch.LongTensor, ####参考音频token + bert_feature: List[torch.LongTensor], + top_k: int = -100, + top_p: int = 100, + early_stop_num: int = -1, + temperature: float = 1.0, + repetition_penalty: float = 1.35, + **kwargs, + ): + if prompts is None: + print("Warning: Prompt free is not supported batch_infer! 
switch to naive_infer") + return self.infer_panel_naive_batched( + x, + x_lens, + prompts, + bert_feature, + top_k=top_k, + top_p=top_p, + early_stop_num=early_stop_num, + temperature=temperature, + **kwargs, + ) + + max_len = kwargs.get("max_len", x_lens.max()) + x_list = [] + for x_item, bert_item in zip(x, bert_feature): + # max_len = max(max_len, x_item.shape[0], bert_item.shape[1]) + x_item = self.ar_text_embedding(x_item.unsqueeze(0)) + x_item = x_item + self.bert_proj(bert_item.transpose(0, 1).unsqueeze(0)) + x_item = self.ar_text_position(x_item).squeeze(0) + # x_item = F.pad(x_item,(0,0,0,max_len-x_item.shape[0]),value=0) if x_item.shape[0] early_stop_num) or idx == 1499: + print("use early stop num:", early_stop_num) + stop = True + for i, batch_index in enumerate(batch_idx_map): + batch_index = batch_idx_map[i] + idx_list[batch_index] = idx + y_list[batch_index] = y[i, :-1] + + if None not in idx_list: + stop = True + + if stop: + if y.shape[1] == 0: + y = torch.concat([y, torch.zeros_like(samples)], dim=1) + print("bad zero prediction") + print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]") + break + + ####################### update next step ################################### + y_emb = self.ar_audio_embedding(y[:, -1:]) + xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[ + :, y_len + idx + ].to(dtype=y_emb.dtype, device=y_emb.device) + + if None in idx_list: + for i in range(x.shape[0]): + if idx_list[i] is None: + idx_list[i] = 1500 - 1 ###如果没有生成到EOS,就用最大长度代替 + + if ref_free: + return y_list, [0] * x.shape[0] + # print(idx_list) + return y_list, idx_list + + def infer_panel_naive_batched( + self, + x: List[torch.LongTensor], #####全部文本token + x_lens: torch.LongTensor, + prompts: torch.LongTensor, ####参考音频token + bert_feature: List[torch.LongTensor], + top_k: int = -100, + top_p: int = 100, + early_stop_num: int = -1, + temperature: float = 1.0, + repetition_penalty: float = 1.35, + **kwargs, + ): + y_list = [] + idx_list = [] + for i in range(len(x)): + y, idx = self.infer_panel_naive( + x[i].unsqueeze(0), + x_lens[i], + prompts[i].unsqueeze(0) if prompts is not None else None, + bert_feature[i].unsqueeze(0), + top_k, + top_p, + early_stop_num, + temperature, + repetition_penalty, + **kwargs, + ) + y_list.append(y[0]) + idx_list.append(idx) + + return y_list, idx_list + + def infer_panel_naive( + self, + x: torch.LongTensor, #####全部文本token + x_lens: torch.LongTensor, + prompts: torch.LongTensor, ####参考音频token + bert_feature: torch.LongTensor, + top_k: int = -100, + top_p: int = 100, + early_stop_num: int = -1, + temperature: float = 1.0, + repetition_penalty: float = 1.35, + **kwargs, + ): + x = self.ar_text_embedding(x) + x = x + self.bert_proj(bert_feature.transpose(1, 2)) + x = self.ar_text_position(x) + + # AR Decoder + y = prompts + + x_len = x.shape[1] + x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool) + stop = False + # print(1111111,self.num_layers) + + k_cache = None + v_cache = None + ################### first step ########################## + if y is not None: + y_emb = self.ar_audio_embedding(y) + y_len = y_emb.shape[1] + prefix_len = y.shape[1] + y_pos = self.ar_audio_position(y_emb) + xy_pos = torch.concat([x, y_pos], dim=1) + ref_free = False + else: + y_emb = None + y_len = 0 + prefix_len = 0 + y_pos = None + xy_pos = x + y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device) + ref_free = True + + bsz = x.shape[0] + src_len = x_len + y_len + x_attn_mask_pad = F.pad( + 
x_attn_mask, + (0, y_len), ###xx的纯0扩展到xx纯0+xy纯1,(x,x+y) + value=True, + ) + y_attn_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y) + torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1), + (x_len, 0), + value=False, + ) + xy_attn_mask = ( + torch.concat([x_attn_mask_pad, y_attn_mask], dim=0) + .unsqueeze(0) + .expand(bsz * self.num_head, -1, -1) + .view(bsz, self.num_head, src_len, src_len) + .to(device=x.device, dtype=torch.bool) + ) + + for idx in tqdm(range(1500)): + if xy_attn_mask is not None: + xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None) + else: + xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache) + + logits = self.ar_predict_layer(xy_dec[:, -1]) + + if idx == 0: + xy_attn_mask = None + if idx < 11: ###至少预测出10个token不然不给停止(0.4s) + logits = logits[:, :-1] + + samples = sample( + logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature + )[0] + + y = torch.concat([y, samples], dim=1) + + if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num: + print("use early stop num:", early_stop_num) + stop = True + + if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS: + stop = True + if stop: + if y.shape[1] == 0: + y = torch.concat([y, torch.zeros_like(samples)], dim=1) + print("bad zero prediction") + print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]") + break + + ####################### update next step ################################### + y_emb = self.ar_audio_embedding(y[:, -1:]) + xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[ + :, y_len + idx + ].to(dtype=y_emb.dtype, device=y_emb.device) + + if ref_free: + return y[:, :-1], 0 + return y[:, :-1], idx + + def infer_panel( + self, + x: torch.LongTensor, #####全部文本token + x_lens: torch.LongTensor, + prompts: torch.LongTensor, ####参考音频token + bert_feature: torch.LongTensor, + top_k: int = -100, + top_p: int = 100, + early_stop_num: int = -1, + temperature: float = 1.0, + repetition_penalty: float = 1.35, + **kwargs, + ): + return self.infer_panel_naive( + x, x_lens, prompts, bert_feature, top_k, top_p, early_stop_num, temperature, repetition_penalty, **kwargs + ) diff --git a/AR/models/t2s_model_onnx.py b/AR/models/t2s_model_onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..4f7b50a3adc0217d8016bfc41507e71eb8dc00b8 --- /dev/null +++ b/AR/models/t2s_model_onnx.py @@ -0,0 +1,394 @@ +# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py +# reference: https://github.com/lifeiteng/vall-e +import torch +from torch import nn +from torch.nn import functional as F +from torchmetrics.classification import MulticlassAccuracy + +from AR.modules.embedding_onnx import SinePositionalEmbedding, TokenEmbedding +from AR.modules.transformer_onnx import LayerNorm, TransformerEncoder, TransformerEncoderLayer + +default_config = { + "embedding_dim": 512, + "hidden_dim": 512, + "num_head": 8, + "num_layers": 12, + "num_codebook": 8, + "p_dropout": 0.0, + "vocab_size": 1024 + 1, + "phoneme_vocab_size": 512, + "EOS": 1024, +} + +inf_tensor_value = torch.FloatTensor([-float("Inf")]).float() + + +def logits_to_probs( + logits, + previous_tokens=None, + temperature: float = 1.0, + top_k=None, + top_p=None, + repetition_penalty: float = 1.0, +): + previous_tokens = previous_tokens.squeeze() + if previous_tokens is not None and repetition_penalty != 1.0: 
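+ # Discourage repetition: scores of previously generated tokens are pushed down (negatives scaled up in magnitude, positives scaled down) before top-p/top-k filtering.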
+ previous_tokens = previous_tokens.long() + score = torch.gather(logits, dim=0, index=previous_tokens) + score = torch.where( + score < 0, + score * repetition_penalty, + score / repetition_penalty, + ) + logits.scatter_(dim=0, index=previous_tokens, src=score) + + if top_p is not None and top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cum_probs = torch.cumsum( + torch.nn.functional.softmax( + sorted_logits, + dim=-1, + ), + dim=-1, + ) + sorted_indices_to_remove = cum_probs > top_p + sorted_indices_to_remove[0] = False # keep at least one option + indices_to_remove = sorted_indices_to_remove.scatter( + dim=0, + index=sorted_indices, + src=sorted_indices_to_remove, + ) + logits = logits.masked_fill(indices_to_remove, -float("Inf")) + + logits = logits / max(temperature, 1e-5) + + if top_k is not None: + v, _ = torch.topk(logits, top_k) + pivot = v.select(-1, -1).unsqueeze(-1) + logits = torch.where(logits < pivot, inf_tensor_value, logits) + + probs = torch.nn.functional.softmax(logits, dim=-1) + return probs + + +def multinomial_sample_one_no_sync( + probs_sort, +): # Does multinomial sampling without a cuda synchronization + q = torch.randn_like(probs_sort) + return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) + + +def sample( + logits, + previous_tokens, + **sampling_kwargs, +): + probs = logits_to_probs( + logits=logits, + previous_tokens=previous_tokens, + **sampling_kwargs, + ) + idx_next = multinomial_sample_one_no_sync(probs) + return idx_next, probs + + +class OnnxEncoder(nn.Module): + def __init__(self, ar_text_embedding, bert_proj, ar_text_position): + super().__init__() + self.ar_text_embedding = ar_text_embedding + self.bert_proj = bert_proj + self.ar_text_position = ar_text_position + + def forward(self, x, bert_feature): + x = self.ar_text_embedding(x) + x = x + self.bert_proj(bert_feature.transpose(1, 2)) + return self.ar_text_position(x) + + +class T2SFirstStageDecoder(nn.Module): + def __init__( + self, + ar_audio_embedding, + ar_audio_position, + h, + ar_predict_layer, + loss_fct, + ar_accuracy_metric, + top_k, + early_stop_num, + num_layers, + ): + super().__init__() + self.ar_audio_embedding = ar_audio_embedding + self.ar_audio_position = ar_audio_position + self.h = h + self.ar_predict_layer = ar_predict_layer + self.loss_fct = loss_fct + self.ar_accuracy_metric = ar_accuracy_metric + self.top_k = top_k + self.early_stop_num = early_stop_num + self.num_layers = num_layers + + def forward(self, x, prompt): + y = prompt + x_example = x[:, :, 0] * 0.0 + # N, 1, 512 + cache = { + "all_stage": self.num_layers, + "k": None, + "v": None, + "y_emb": None, + "first_infer": 1, + "stage": 0, + } + + y_emb = self.ar_audio_embedding(y) + + cache["y_emb"] = y_emb + y_pos = self.ar_audio_position(y_emb) + + xy_pos = torch.concat([x, y_pos], dim=1) + + y_example = y_pos[:, :, 0] * 0.0 + x_attn_mask = torch.matmul(x_example.transpose(0, 1), x_example).bool() + y_attn_mask = torch.ones_like(torch.matmul(y_example.transpose(0, 1), y_example), dtype=torch.int64) + y_attn_mask = torch.cumsum(y_attn_mask, dim=1) - torch.cumsum( + torch.ones_like( + y_example.transpose(0, 1), + dtype=torch.int64, + ), + dim=0, + ) + y_attn_mask = y_attn_mask > 0 + + x_y_pad = torch.matmul(x_example.transpose(0, 1), y_example).bool() + y_x_pad = torch.matmul(y_example.transpose(0, 1), x_example).bool() + x_attn_mask_pad = torch.cat([x_attn_mask, torch.ones_like(x_y_pad)], dim=1) + y_attn_mask = torch.cat([y_x_pad, y_attn_mask], dim=1) + 
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0) + cache["k"] = ( + torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512))) + .unsqueeze(1) + .repeat(self.num_layers, 1, 1, 1) + ) + cache["v"] = ( + torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512))) + .unsqueeze(1) + .repeat(self.num_layers, 1, 1, 1) + ) + + xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache) + logits = self.ar_predict_layer(xy_dec[:, -1]) + samples = sample(logits[0], y, top_k=self.top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0) + + y = torch.concat([y, samples], dim=1) + + return y, cache["k"], cache["v"], cache["y_emb"], x_example + + +class T2SStageDecoder(nn.Module): + def __init__( + self, + ar_audio_embedding, + ar_audio_position, + h, + ar_predict_layer, + loss_fct, + ar_accuracy_metric, + top_k, + early_stop_num, + num_layers, + ): + super().__init__() + self.ar_audio_embedding = ar_audio_embedding + self.ar_audio_position = ar_audio_position + self.h = h + self.ar_predict_layer = ar_predict_layer + self.loss_fct = loss_fct + self.ar_accuracy_metric = ar_accuracy_metric + self.top_k = top_k + self.early_stop_num = early_stop_num + self.num_layers = num_layers + + def forward(self, y, k, v, y_emb, x_example): + cache = { + "all_stage": self.num_layers, + "k": torch.nn.functional.pad(k, (0, 0, 0, 0, 0, 1)), + "v": torch.nn.functional.pad(v, (0, 0, 0, 0, 0, 1)), + "y_emb": y_emb, + "first_infer": 0, + "stage": 0, + } + + y_emb = torch.cat( + [ + cache["y_emb"], + self.ar_audio_embedding(y[:, -1:]), + ], + 1, + ) + cache["y_emb"] = y_emb + y_pos = self.ar_audio_position(y_emb) + + xy_pos = y_pos[:, -1:] + + y_example = y_pos[:, :, 0] * 0.0 + + xy_attn_mask = torch.cat([x_example, y_example], dim=1) + xy_attn_mask = torch.zeros_like(xy_attn_mask, dtype=torch.bool) + + xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache) + logits = self.ar_predict_layer(xy_dec[:, -1]) + samples = sample(logits[0], y, top_k=self.top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0) + + y = torch.concat([y, samples], dim=1) + + return y, cache["k"], cache["v"], cache["y_emb"], logits, samples + + +class Text2SemanticDecoder(nn.Module): + def __init__(self, config, norm_first=False, top_k=3): + super(Text2SemanticDecoder, self).__init__() + self.model_dim = config["model"]["hidden_dim"] + self.embedding_dim = config["model"]["embedding_dim"] + self.num_head = config["model"]["head"] + self.num_layers = config["model"]["n_layer"] + self.norm_first = norm_first + self.vocab_size = config["model"]["vocab_size"] + self.phoneme_vocab_size = config["model"]["phoneme_vocab_size"] + self.p_dropout = float(config["model"]["dropout"]) + self.EOS = config["model"]["EOS"] + self.norm_first = norm_first + assert self.EOS == self.vocab_size - 1 + self.bert_proj = nn.Linear(1024, self.embedding_dim) + self.ar_text_embedding = TokenEmbedding(self.embedding_dim, self.phoneme_vocab_size, self.p_dropout) + self.ar_text_position = SinePositionalEmbedding(self.embedding_dim, dropout=0.1, scale=False, alpha=True) + self.ar_audio_embedding = TokenEmbedding(self.embedding_dim, self.vocab_size, self.p_dropout) + self.ar_audio_position = SinePositionalEmbedding(self.embedding_dim, dropout=0.1, scale=False, alpha=True) + self.h = TransformerEncoder( + TransformerEncoderLayer( + d_model=self.model_dim, + nhead=self.num_head, + dim_feedforward=self.model_dim * 4, + dropout=0.1, + batch_first=True, + norm_first=norm_first, + ), + num_layers=self.num_layers, + 
norm=LayerNorm(self.model_dim) if norm_first else None, + ) + self.ar_predict_layer = nn.Linear(self.model_dim, self.vocab_size, bias=False) + self.loss_fct = nn.CrossEntropyLoss(reduction="sum") + self.ar_accuracy_metric = MulticlassAccuracy( + self.vocab_size, + top_k=top_k, + average="micro", + multidim_average="global", + ignore_index=self.EOS, + ) + self.top_k = torch.LongTensor([1]) + self.early_stop_num = torch.LongTensor([-1]) + + def init_onnx(self): + self.onnx_encoder = OnnxEncoder(self.ar_text_embedding, self.bert_proj, self.ar_text_position) + self.first_stage_decoder = T2SFirstStageDecoder( + self.ar_audio_embedding, + self.ar_audio_position, + self.h, + self.ar_predict_layer, + self.loss_fct, + self.ar_accuracy_metric, + self.top_k, + self.early_stop_num, + self.num_layers, + ) + self.stage_decoder = T2SStageDecoder( + self.ar_audio_embedding, + self.ar_audio_position, + self.h, + self.ar_predict_layer, + self.loss_fct, + self.ar_accuracy_metric, + self.top_k, + self.early_stop_num, + self.num_layers, + ) + + def forward(self, x, prompts, bert_feature): + early_stop_num = self.early_stop_num + prefix_len = prompts.shape[1] + + x = self.onnx_encoder(x, bert_feature) + y, k, v, y_emb, stage, x_example = self.first_stage_decoder(x, prompts) + + stop = False + for idx in range(1, 1500): + enco = self.stage_decoder(y, k, v, y_emb, stage, x_example) + y, k, v, y_emb, stage, logits, samples = enco + if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num: + stop = True + if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS: + stop = True + if stop: + break + y[0, -1] = 0 + return y, idx + + def infer(self, x, prompts, bert_feature): + top_k = self.top_k + early_stop_num = self.early_stop_num + + x = self.onnx_encoder(x, bert_feature) + + y = prompts + prefix_len = y.shape[1] + x_len = x.shape[1] + x_example = x[:, :, 0] * 0.0 + x_attn_mask = torch.matmul(x_example.transpose(0, 1), x_example) + x_attn_mask = torch.zeros_like(x_attn_mask, dtype=torch.bool) + + stop = False + cache = { + "all_stage": self.num_layers, + "k": [None] * self.num_layers, + "v": [None] * self.num_layers, + "y_emb": None, + "first_infer": 1, + "stage": 0, + } + for idx in range(1500): + if cache["first_infer"] == 1: + y_emb = self.ar_audio_embedding(y) + else: + y_emb = torch.cat([cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1) + cache["y_emb"] = y_emb + y_pos = self.ar_audio_position(y_emb) + if cache["first_infer"] == 1: + xy_pos = torch.concat([x, y_pos], dim=1) + else: + xy_pos = y_pos[:, -1:] + y_len = y_pos.shape[1] + if cache["first_infer"] == 1: + x_attn_mask_pad = F.pad(x_attn_mask, (0, y_len), value=True) + y_attn_mask = F.pad( + torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1), + (x_len, 0), + value=False, + ) + xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0) + else: + xy_attn_mask = torch.zeros((1, x_len + y_len), dtype=torch.bool) + xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache) + logits = self.ar_predict_layer(xy_dec[:, -1]) + samples = sample(logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0) + if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num: + stop = True + if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS: + stop = True + if stop: + if prompts.shape[1] == y.shape[1]: + y = torch.concat([y, torch.zeros_like(samples)], dim=1) + break + y = torch.concat([y, samples], dim=1) + cache["first_infer"] = 0 + return y, idx diff --git 
a/AR/models/utils.py b/AR/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cc4f24d89ff7c02a2e35ce8c083f4b028dee85a0 --- /dev/null +++ b/AR/models/utils.py @@ -0,0 +1,282 @@ +# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/utils.py +# reference: https://github.com/lifeiteng/vall-e +from typing import Tuple + +import torch +import torch.nn.functional as F + + +def sequence_mask(length, max_length=None): + if max_length is None: + max_length = length.max() + x = torch.arange(max_length, dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + + +def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: + """ + Args: + lengths: + A 1-D tensor containing sentence lengths. + max_len: + The length of masks. + Returns: + Return a 2-D bool tensor, where masked positions + are filled with `True` and non-masked positions are + filled with `False`. + + #>>> lengths = torch.tensor([1, 3, 2, 5]) + #>>> make_pad_mask(lengths) + tensor([[False, True, True, True, True], + [False, False, False, True, True], + [False, False, True, True, True], + [False, False, False, False, False]]) + """ + assert lengths.ndim == 1, lengths.ndim + max_len = max(max_len, lengths.max()) + n = lengths.size(0) + seq_range = torch.arange(0, max_len, device=lengths.device) + expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len) + + return expaned_lengths >= lengths.unsqueeze(-1) + + +def make_pad_mask_left(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: + """ + Args: + lengths: + A 1-D tensor containing sentence lengths. + max_len: + The length of masks. + Returns: + Return a 2-D bool tensor, where masked positions + are filled with `True` and non-masked positions are + filled with `False`. + + #>>> lengths = torch.tensor([1, 3, 2, 5]) + #>>> make_pad_mask(lengths) + tensor( + [ + [True, True, False], + [True, False, False], + [True, True, False], + ... + ] + ) + """ + assert lengths.ndim == 1, lengths.ndim + max_len = max(max_len, lengths.max()) + n = lengths.size(0) + seq_range = torch.arange(0, max_len, device=lengths.device) + expaned_lengths = seq_range.unsqueeze(0).repeat(n, 1) + expaned_lengths -= (max_len - lengths).unsqueeze(-1) + + return expaned_lengths < 0 + + +# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py +def top_k_top_p_filtering( + logits, + top_k=0, + top_p=1.0, + filter_value=-float("Inf"), + min_tokens_to_keep=1, +): + """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering + Args: + logits: logits distribution shape (batch size, vocabulary size) + if top_k > 0: keep only top k tokens with highest probability (top-k filtering). + if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). + Nucleus filtering is described in Holtzman et al. 
(http://arxiv.org/abs/1904.09751) + Make sure we keep at least min_tokens_to_keep per batch example in the output + From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 + """ + if top_k > 0: + top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits[indices_to_remove] = filter_value + + if top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative probability above the threshold (token with 0 are kept) + sorted_indices_to_remove = cumulative_probs > top_p + if min_tokens_to_keep > 1: + # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) + sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 + # Shift the indices to the right to keep also the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + # scatter sorted tensors to original indexing + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + logits[indices_to_remove] = filter_value + return logits + + +def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0): + # temperature: (`optional`) float + # The value used to module the next token probabilities. Must be strictly positive. Default to 1.0. + # top_k: (`optional`) int + # The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50. + # top_p: (`optional`) float + # The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1. 
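    # Added comment-only example (hypothetical values) of the filtering above:
    #   logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
    #   top_k_top_p_filtering(logits.clone(), top_k=2)
    #       -> [[2.0, 1.0, -inf, -inf]]   (everything below the 2nd-best logit is masked)
    #   top_k_top_p_filtering(logits.clone(), top_p=0.6)
    #       -> only the 2.0 logit survives: its softmax probability (~0.61) already
    #          reaches the 0.6 nucleus on its own; at least one token is always kept.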
+ + # Temperature (higher temperature => more likely to sample low probability tokens) + if temperature != 1.0: + logits = logits / temperature + # Top-p/top-k filtering + logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p) + # Sample + token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1) + return token + + +from typing import Optional + + +def multinomial_sample_one_no_sync( + probs_sort, +): # Does multinomial sampling without a cuda synchronization + q = torch.empty_like(probs_sort).exponential_(1) + return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) + + +def logits_to_probs( + logits, + previous_tokens: Optional[torch.Tensor] = None, + temperature: float = 1.0, + top_k: Optional[int] = None, + top_p: Optional[int] = None, + repetition_penalty: float = 1.0, +): + # if previous_tokens is not None: + # previous_tokens = previous_tokens.squeeze() + # print(logits.shape,previous_tokens.shape) + # pdb.set_trace() + if previous_tokens is not None and repetition_penalty != 1.0: + previous_tokens = previous_tokens.long() + score = torch.gather(logits, dim=1, index=previous_tokens) + score = torch.where( + score < 0, + score * repetition_penalty, + score / repetition_penalty, + ) + logits.scatter_(dim=1, index=previous_tokens, src=score) + + if top_p is not None and top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1) + sorted_indices_to_remove = cum_probs > top_p + sorted_indices_to_remove[:, 0] = False # keep at least one option + indices_to_remove = sorted_indices_to_remove.scatter( + dim=1, + index=sorted_indices, + src=sorted_indices_to_remove, + ) + logits = logits.masked_fill(indices_to_remove, -float("Inf")) + + logits = logits / max(temperature, 1e-5) + + if top_k is not None: + v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + pivot = v[:, -1].unsqueeze(-1) + logits = torch.where(logits < pivot, -float("Inf"), logits) + + probs = torch.nn.functional.softmax(logits, dim=-1) + return probs + + +def sample( + logits, + previous_tokens: Optional[torch.Tensor] = None, + **sampling_kwargs, +) -> Tuple[torch.Tensor, torch.Tensor]: + probs = logits_to_probs(logits=logits, previous_tokens=previous_tokens, **sampling_kwargs) + idx_next = multinomial_sample_one_no_sync(probs) + return idx_next, probs + + +def dpo_loss( + policy_chosen_logps: torch.FloatTensor, + policy_rejected_logps: torch.FloatTensor, + reference_chosen_logps: torch.FloatTensor, + reference_rejected_logps: torch.FloatTensor, + beta: float, + reference_free: bool = False, +) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + pi_logratios = policy_chosen_logps - policy_rejected_logps + ref_logratios = reference_chosen_logps - reference_rejected_logps + + if reference_free: + ref_logratios = 0 + + logits = pi_logratios - ref_logratios + + losses = -F.logsigmoid(beta * logits) + chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps).detach() + rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps).detach() + + return losses.mean(), chosen_rewards, rejected_rewards + + +def get_batch_logps( + logits_target: torch.FloatTensor, + logits_reject: torch.FloatTensor, + labels_target: torch.LongTensor, + labels_reject: torch.LongTensor, + average_log_prob: bool = False, +) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + # dummy token; we'll ignore the losses on these tokens later + + per_token_logps_target 
= torch.gather( + logits_target.log_softmax(-1), dim=2, index=labels_target.unsqueeze(2) + ).squeeze(2) + per_token_logps_reject = torch.gather( + logits_reject.log_softmax(-1), dim=2, index=labels_reject.unsqueeze(2) + ).squeeze(2) + + return per_token_logps_target.sum(-1), per_token_logps_reject.sum(-1) + + +def make_reject_y(y_o, y_lens): + def repeat_P(y): + range_idx, _ = torch.randint(0, len(y), size=(2,)).sort() + pre = y[: range_idx[0]] + shf = y[range_idx[1] :] + range_text = y[range_idx[0] : range_idx[1]] + new_y = torch.cat([pre, range_text, range_text, shf]) + return new_y + + def lost_P(y): + range_idx, _ = torch.randint(0, len(y), size=(2,)).sort() + pre = y[: range_idx[0]] + shf = y[range_idx[1] :] + range_text = y[range_idx[0] : range_idx[1]] + new_y = torch.cat([pre, shf]) + return new_y + + bs = len(y_lens) + reject_y = [] + reject_y_lens = [] + for b in range(bs): + process_item_idx = torch.randint(0, 1, size=(1,))[0] + if process_item_idx == 0: + new_y = repeat_P(y_o[b]) + reject_y.append(new_y) + reject_y_lens.append(len(new_y)) + elif process_item_idx == 1: + new_y = lost_P(y_o[b]) + reject_y.append(new_y) + reject_y_lens.append(len(new_y)) + max_length = max(reject_y_lens) + for b in range(bs): + pad_length = max_length - reject_y_lens[b] + reject_y[b] = torch.cat([reject_y[b], torch.zeros(pad_length, dtype=y_o.dtype, device=y_o.device)], dim=0) + + reject_y = torch.stack(reject_y, dim=0) + reject_y_lens = torch.tensor(reject_y_lens, device=y_lens.device) + + return reject_y, reject_y_lens diff --git a/AR/modules/__init__.py b/AR/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/AR/modules/activation.py b/AR/modules/activation.py new file mode 100644 index 0000000000000000000000000000000000000000..936f9c3fa5c18357eac4d57294167f23d7ce5700 --- /dev/null +++ b/AR/modules/activation.py @@ -0,0 +1,413 @@ +# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py +from typing import Optional, Tuple + +import torch +from torch import Tensor +from torch.nn import Linear, Module +from torch.nn import functional as F +from torch.nn.init import constant_, xavier_normal_, xavier_uniform_ +from torch.nn.modules.linear import NonDynamicallyQuantizableLinear +from torch.nn.parameter import Parameter + +from AR.modules.patched_mha_with_cache import multi_head_attention_forward_patched + +F.multi_head_attention_forward = multi_head_attention_forward_patched + + +class MultiheadAttention(Module): + r"""Allows the model to jointly attend to information + from different representation subspaces as described in the paper: + `Attention Is All You Need `_. + + Multi-Head Attention is defined as: + + .. math:: + \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O + + where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`. + + ``forward()`` will use a special optimized implementation if all of the following + conditions are met: + + - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This + restriction will be loosened in the future.) 
+ - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad`` + - training is disabled (using ``.eval()``) + - dropout is 0 + - ``add_bias_kv`` is ``False`` + - ``add_zero_attn`` is ``False`` + - ``batch_first`` is ``True`` and the input is batched + - ``kdim`` and ``vdim`` are equal to ``embed_dim`` + - at most one of ``key_padding_mask`` or ``attn_mask`` is passed + - if a `NestedTensor `_ is passed, neither ``key_padding_mask`` + nor ``attn_mask`` is passed + + If the optimized implementation is in use, a + `NestedTensor `_ can be passed for + ``query``/``key``/``value`` to represent padding more efficiently than using a + padding mask. In this case, a `NestedTensor `_ + will be returned, and an additional speedup proportional to the fraction of the input + that is padding can be expected. + + Args: + embed_dim: Total dimension of the model. + num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split + across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``). + dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout). + bias: If specified, adds bias to input / output projection layers. Default: ``True``. + add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``. + add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1. + Default: ``False``. + kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``). + vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``). + batch_first: If ``True``, then the input and output tensors are provided + as (batch, seq, feature). Default: ``False`` (seq, batch, feature). 
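        Note (added for this patched copy): unlike the stock ``torch.nn.MultiheadAttention``,
        ``forward`` here also accepts a ``cache`` dict (see the signature below), which is
        passed straight through to ``multi_head_attention_forward_patched`` so that
        keys/values computed on earlier decoding steps can be reused.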
+ + Examples:: + + >>> # xdoctest: +SKIP + >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value) + + """ + + __constants__ = ["batch_first"] + bias_k: Optional[torch.Tensor] + bias_v: Optional[torch.Tensor] + + def __init__( + self, + embed_dim, + num_heads, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + kdim=None, + vdim=None, + batch_first=False, + linear1_cls=Linear, + linear2_cls=Linear, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super(MultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout = dropout + self.batch_first = batch_first + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" + + if add_bias_kv: + self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) + self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) + else: + self.bias_k = self.bias_v = None + + if linear1_cls == Linear: + if not self._qkv_same_embed_dim: + self.q_proj_weight = Parameter( + torch.empty((embed_dim, embed_dim), **factory_kwargs), + ) + self.k_proj_weight = Parameter( + torch.empty((embed_dim, self.kdim), **factory_kwargs), + ) + self.v_proj_weight = Parameter( + torch.empty((embed_dim, self.vdim), **factory_kwargs), + ) + self.register_parameter("in_proj_weight", None) + else: + self.in_proj_weight = Parameter( + torch.empty((3 * embed_dim, embed_dim), **factory_kwargs), + ) + self.register_parameter("q_proj_weight", None) + self.register_parameter("k_proj_weight", None) + self.register_parameter("v_proj_weight", None) + + if bias: + self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs)) + else: + self.register_parameter("in_proj_bias", None) + self.out_proj = NonDynamicallyQuantizableLinear( + embed_dim, + embed_dim, + bias=bias, + **factory_kwargs, + ) + + self._reset_parameters() + else: + if not self._qkv_same_embed_dim: + raise NotImplementedError + else: + self.in_proj_linear = linear1_cls( + embed_dim, + 3 * embed_dim, + bias=bias, + **factory_kwargs, + ) + self.in_proj_weight = self.in_proj_linear.weight + + self.register_parameter("q_proj_weight", None) + self.register_parameter("k_proj_weight", None) + self.register_parameter("v_proj_weight", None) + + if bias: + self.in_proj_bias = self.in_proj_linear.bias + else: + self.register_parameter("in_proj_bias", None) + + self.out_proj = linear2_cls( + embed_dim, + embed_dim, + bias=bias, + **factory_kwargs, + ) + + if self.bias_k is not None: + xavier_normal_(self.bias_k) + if self.bias_v is not None: + xavier_normal_(self.bias_v) + + self.add_zero_attn = add_zero_attn + + def _reset_parameters(self): + if self._qkv_same_embed_dim: + xavier_uniform_(self.in_proj_weight) + else: + xavier_uniform_(self.q_proj_weight) + xavier_uniform_(self.k_proj_weight) + xavier_uniform_(self.v_proj_weight) + + if self.in_proj_bias is not None: + constant_(self.in_proj_bias, 0.0) + constant_(self.out_proj.bias, 0.0) + + if self.bias_k is not None: + xavier_normal_(self.bias_k) + if self.bias_v is not None: + xavier_normal_(self.bias_v) + + def __setstate__(self, state): + # Support loading old MultiheadAttention 
checkpoints generated by v1.1.0 + if "_qkv_same_embed_dim" not in state: + state["_qkv_same_embed_dim"] = True + + super(MultiheadAttention, self).__setstate__(state) + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + average_attn_weights: bool = True, + cache=None, + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False`` + or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length, + :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``. + Queries are compared against key-value pairs to produce the output. + See "Attention Is All You Need" for more details. + key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False`` + or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length, + :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``. + See "Attention Is All You Need" for more details. + value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when + ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source + sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``. + See "Attention Is All You Need" for more details. + key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key`` + to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`. + Binary and byte masks are supported. + For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for + the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value. + need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``. + Default: ``True``. + attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape + :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size, + :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be + broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch. + Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the + corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the + corresponding position is not allowed to attend. For a float mask, the mask values will be added to + the attention weight. + average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across + heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an + effect when ``need_weights=True``. Default: ``True`` (i.e. 
average weights across heads) + + Outputs: + - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched, + :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``, + where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the + embedding dimension ``embed_dim``. + - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``, + returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or + :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and + :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per + head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`. + + .. note:: + `batch_first` argument is ignored for unbatched inputs. + """ + is_batched = query.dim() == 3 + if key_padding_mask is not None: + _kpm_dtype = key_padding_mask.dtype + if _kpm_dtype != torch.bool and not torch.is_floating_point( + key_padding_mask, + ): + raise AssertionError("only bool and floating types of key_padding_mask are supported") + why_not_fast_path = "" + if not is_batched: + why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}" + elif query is not key or key is not value: + # When lifting this restriction, don't forget to either + # enforce that the dtypes all match or test cases where + # they don't! + why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)" + elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype: + why_not_fast_path = ( + f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match" + ) + elif self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype: + # this case will fail anyway, but at least they'll get a useful error message. + why_not_fast_path = ( + f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match" + ) + elif self.training: + why_not_fast_path = "training is enabled" + elif not self.batch_first: + why_not_fast_path = "batch_first was not True" + elif self.bias_k is not None: + why_not_fast_path = "self.bias_k was not None" + elif self.bias_v is not None: + why_not_fast_path = "self.bias_v was not None" + elif self.dropout: + why_not_fast_path = f"dropout was {self.dropout}, required zero" + elif self.add_zero_attn: + why_not_fast_path = "add_zero_attn was enabled" + elif not self._qkv_same_embed_dim: + why_not_fast_path = "_qkv_same_embed_dim was not True" + elif attn_mask is not None: + why_not_fast_path = "attn_mask was not None" + elif query.is_nested and key_padding_mask is not None: + why_not_fast_path = "key_padding_mask is not supported with NestedTensor input" + elif self.num_heads % 2 == 1: + why_not_fast_path = "num_heads is odd" + elif torch.is_autocast_enabled(): + why_not_fast_path = "autocast is enabled" + + if not why_not_fast_path: + tensor_args = ( + query, + key, + value, + self.in_proj_weight, + self.in_proj_bias, + self.out_proj.weight, + self.out_proj.bias, + ) + # We have to use list comprehensions below because TorchScript does not support + # generator expressions. 
+ if torch.overrides.has_torch_function(tensor_args): + why_not_fast_path = "some Tensor argument has_torch_function" + elif not all([(x is None or x.is_cuda or "cpu" in str(x.device)) for x in tensor_args]): + why_not_fast_path = "some Tensor argument is neither CUDA nor CPU" + elif torch.is_grad_enabled() and any([x is not None and x.requires_grad for x in tensor_args]): + why_not_fast_path = "grad is enabled and at least one of query or the input/output projection weights or biases requires_grad" + if not why_not_fast_path: + return torch._native_multi_head_attention( + query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.out_proj.weight, + self.out_proj.bias, + key_padding_mask if key_padding_mask is not None else attn_mask, + need_weights, + average_attn_weights, + 1 if key_padding_mask is not None else 0 if attn_mask is not None else None, + ) + + any_nested = query.is_nested or key.is_nested or value.is_nested + assert not any_nested, ( + "MultiheadAttention does not support NestedTensor outside of its fast path. " + + f"The fast path was not hit because {why_not_fast_path}" + ) + + if self.batch_first and is_batched: + # make sure that the transpose op does not affect the "is" property + if key is value: + if query is key: + query = key = value = query.transpose(1, 0) + else: + query, key = [x.transpose(1, 0) for x in (query, key)] + value = key + else: + query, key, value = [x.transpose(1, 0) for x in (query, key, value)] + + if not self._qkv_same_embed_dim: + attn_output, attn_output_weights = F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj_weight, + k_proj_weight=self.k_proj_weight, + v_proj_weight=self.v_proj_weight, + average_attn_weights=average_attn_weights, + cache=cache, + ) + else: + attn_output, attn_output_weights = F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + average_attn_weights=average_attn_weights, + cache=cache, + ) + if self.batch_first and is_batched: + return attn_output.transpose(1, 0), attn_output_weights + else: + return attn_output, attn_output_weights diff --git a/AR/modules/activation_onnx.py b/AR/modules/activation_onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..c14ce40c9648831a172861951b2dfdc6d5cab0c6 --- /dev/null +++ b/AR/modules/activation_onnx.py @@ -0,0 +1,188 @@ +# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py +from typing import Optional, Tuple + +import torch +from torch import Tensor +from torch.nn import Linear, Module +from torch.nn.init import constant_, xavier_normal_, xavier_uniform_ +from torch.nn.modules.linear import NonDynamicallyQuantizableLinear +from torch.nn.parameter import Parameter + +from AR.modules.patched_mha_with_cache_onnx import multi_head_attention_forward_patched + + +class MultiheadAttention(Module): + 
__constants__ = ["batch_first"] + bias_k: Optional[torch.Tensor] + bias_v: Optional[torch.Tensor] + + def __init__( + self, + embed_dim, + num_heads, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + kdim=None, + vdim=None, + batch_first=False, + linear1_cls=Linear, + linear2_cls=Linear, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super(MultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout = dropout + self.batch_first = batch_first + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" + + if add_bias_kv: + self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) + self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) + else: + self.bias_k = self.bias_v = None + + if linear1_cls == Linear: + if not self._qkv_same_embed_dim: + self.q_proj_weight = Parameter( + torch.empty( + (embed_dim, embed_dim), + **factory_kwargs, + ) + ) + self.k_proj_weight = Parameter( + torch.empty( + (embed_dim, self.kdim), + **factory_kwargs, + ) + ) + self.v_proj_weight = Parameter( + torch.empty( + (embed_dim, self.vdim), + **factory_kwargs, + ) + ) + self.register_parameter("in_proj_weight", None) + else: + self.in_proj_weight = Parameter( + torch.empty( + (3 * embed_dim, embed_dim), + **factory_kwargs, + ) + ) + self.register_parameter("q_proj_weight", None) + self.register_parameter("k_proj_weight", None) + self.register_parameter("v_proj_weight", None) + + if bias: + self.in_proj_bias = Parameter( + torch.empty(3 * embed_dim, **factory_kwargs), + ) + else: + self.register_parameter("in_proj_bias", None) + self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=bias, **factory_kwargs) + + self._reset_parameters() + else: + if not self._qkv_same_embed_dim: + raise NotImplementedError + else: + self.in_proj_linear = linear1_cls( + embed_dim, + 3 * embed_dim, + bias=bias, + **factory_kwargs, + ) + self.in_proj_weight = self.in_proj_linear.weight + + self.register_parameter("q_proj_weight", None) + self.register_parameter("k_proj_weight", None) + self.register_parameter("v_proj_weight", None) + + if bias: + self.in_proj_bias = self.in_proj_linear.bias + else: + self.register_parameter("in_proj_bias", None) + + self.out_proj = linear2_cls( + embed_dim, + embed_dim, + bias=bias, + **factory_kwargs, + ) + + if self.bias_k is not None: + xavier_normal_(self.bias_k) + if self.bias_v is not None: + xavier_normal_(self.bias_v) + + self.add_zero_attn = add_zero_attn + + def _reset_parameters(self): + if self._qkv_same_embed_dim: + xavier_uniform_(self.in_proj_weight) + else: + xavier_uniform_(self.q_proj_weight) + xavier_uniform_(self.k_proj_weight) + xavier_uniform_(self.v_proj_weight) + + if self.in_proj_bias is not None: + constant_(self.in_proj_bias, 0.0) + constant_(self.out_proj.bias, 0.0) + + if self.bias_k is not None: + xavier_normal_(self.bias_k) + if self.bias_v is not None: + xavier_normal_(self.bias_v) + + def __setstate__(self, state): + # Support loading old MultiheadAttention checkpoints generated by v1.1.0 + if "_qkv_same_embed_dim" not in state: + state["_qkv_same_embed_dim"] = True + + super(MultiheadAttention, self).__setstate__(state) + + def 
forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + average_attn_weights: bool = True, + cache=None, + ) -> Tuple[Tensor, Optional[Tensor]]: + any_nested = query.is_nested or key.is_nested or value.is_nested + query = key = value = query.transpose(1, 0) + attn_output = multi_head_attention_forward_patched( + query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + average_attn_weights=average_attn_weights, + cache=cache, + ) + return attn_output.transpose(1, 0) diff --git a/AR/modules/embedding.py b/AR/modules/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..39da560386b954f9b1626cf4a47e3b4b4d4195d1 --- /dev/null +++ b/AR/modules/embedding.py @@ -0,0 +1,78 @@ +# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/embedding.py +import math + +import torch +from torch import nn + + +class TokenEmbedding(nn.Module): + def __init__( + self, + embedding_dim: int, + vocab_size: int, + dropout: float = 0.0, + ): + super().__init__() + + self.vocab_size = vocab_size + self.embedding_dim = embedding_dim + + self.dropout = torch.nn.Dropout(p=dropout) + self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim) + + @property + def weight(self) -> torch.Tensor: + return self.word_embeddings.weight + + def embedding(self, index: int) -> torch.Tensor: + return self.word_embeddings.weight[index : index + 1] + + def forward(self, x: torch.Tensor): + x = self.word_embeddings(x) + x = self.dropout(x) + return x + + +class SinePositionalEmbedding(nn.Module): + def __init__( + self, + embedding_dim: int, + dropout: float = 0.0, + scale: bool = False, + alpha: bool = False, + ): + super().__init__() + self.embedding_dim = embedding_dim + self.x_scale = math.sqrt(embedding_dim) if scale else 1.0 + self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha) + self.dropout = torch.nn.Dropout(p=dropout) + + self.reverse = False + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, 4000)) + + def extend_pe(self, x): + """Reset the positional encodings.""" + if self.pe is not None: + if self.pe.size(1) >= x.size(1): + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + pe = torch.zeros(x.size(1), self.embedding_dim) + if self.reverse: + position = torch.arange(x.size(1) - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1) + else: + position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) * -(math.log(10000.0) / self.embedding_dim) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.pe = pe.to(device=x.device, dtype=x.dtype).detach() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + self.extend_pe(x) + output = x.unsqueeze(-1) if x.ndim == 2 else x + output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)] + return self.dropout(output) diff --git a/AR/modules/embedding_onnx.py b/AR/modules/embedding_onnx.py new file mode 100644 index 
0000000000000000000000000000000000000000..c870013f0e4c1af463790ff3627fab3550fe4c93 --- /dev/null +++ b/AR/modules/embedding_onnx.py @@ -0,0 +1,63 @@ +# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/embedding.py +import math + +import torch +from torch import nn + + +class TokenEmbedding(nn.Module): + def __init__( + self, + embedding_dim: int, + vocab_size: int, + dropout: float = 0.0, + ): + super().__init__() + + self.vocab_size = vocab_size + self.embedding_dim = embedding_dim + + self.dropout = torch.nn.Dropout(p=dropout) + self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim) + + @property + def weight(self) -> torch.Tensor: + return self.word_embeddings.weight + + def embedding(self, index: int) -> torch.Tensor: + return self.word_embeddings.weight[index : index + 1] + + def forward(self, x: torch.Tensor): + x = self.word_embeddings(x) + x = self.dropout(x) + return x + + +class SinePositionalEmbedding(nn.Module): + def __init__( + self, + embedding_dim: int, + dropout: float = 0.0, + scale: bool = False, + alpha: bool = False, + ): + super().__init__() + self.embedding_dim = embedding_dim + self.x_scale = math.sqrt(embedding_dim) if scale else 1.0 + self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha) + self.dropout = torch.nn.Dropout(p=dropout) + self.reverse = False + self.div_term = torch.exp(torch.arange(0, self.embedding_dim, 2) * -(math.log(10000.0) / self.embedding_dim)) + + def extend_pe(self, x): + position = torch.cumsum(torch.ones_like(x[:, :, 0]), dim=1).transpose(0, 1) + scpe = (position * self.div_term).unsqueeze(0) + pe = torch.cat([torch.sin(scpe), torch.cos(scpe)]).permute(1, 2, 0) + pe = pe.contiguous().view(1, -1, self.embedding_dim) + return pe + + def forward(self, x: torch.Tensor) -> torch.Tensor: + pe = self.extend_pe(x) + output = x.unsqueeze(-1) if x.ndim == 2 else x + output = output * self.x_scale + self.alpha * pe + return self.dropout(output) diff --git a/AR/modules/lr_schedulers.py b/AR/modules/lr_schedulers.py new file mode 100644 index 0000000000000000000000000000000000000000..707a911f9b7fbcf74d009143aa0c3e21605c0704 --- /dev/null +++ b/AR/modules/lr_schedulers.py @@ -0,0 +1,85 @@ +# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/modules/lr_schedulers.py +# reference: https://github.com/lifeiteng/vall-e +import math + +import torch +from matplotlib import pyplot as plt +from torch import nn +from torch.optim import Adam + + +class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler): + """ + Implements Warmup learning rate schedule until 'warmup_steps', going from 'init_lr' to 'peak_lr' for multiple optimizers. 
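    The intended shape is a linear ramp from ``init_lr`` to ``peak_lr`` over
    ``warmup_steps``, then a cosine decay towards ``end_lr`` by ``total_steps``.
    Note, however, that ``step()`` below overwrites the computed value and pins the
    learning rate to a constant 0.002 (the inline comments, translated from Chinese,
    say the rate is deliberately "locked"), so the cosine branch is effectively
    disabled in this revision.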
+ """ + + def __init__( + self, + optimizer, + init_lr, + peak_lr, + end_lr, + warmup_steps=10000, + total_steps=400000, + current_step=0, + ): + self.init_lr = init_lr + self.peak_lr = peak_lr + self.end_lr = end_lr + self.optimizer = optimizer + self._warmup_rate = (peak_lr - init_lr) / warmup_steps + self._decay_rate = (end_lr - peak_lr) / (total_steps - warmup_steps) + self._current_step = current_step + self.lr = init_lr + self.warmup_steps = warmup_steps + self.total_steps = total_steps + self._last_lr = [self.lr] + + def set_lr(self, lr): + self._last_lr = [g["lr"] for g in self.optimizer.param_groups] + for g in self.optimizer.param_groups: + # g['lr'] = lr + g["lr"] = self.end_lr ###锁定用线性 + + def step(self): + if self._current_step < self.warmup_steps: + lr = self.init_lr + self._warmup_rate * self._current_step + + elif self._current_step > self.total_steps: + lr = self.end_lr + + else: + decay_ratio = (self._current_step - self.warmup_steps) / (self.total_steps - self.warmup_steps) + if decay_ratio < 0.0 or decay_ratio > 1.0: + raise RuntimeError("Decay ratio must be in [0.0, 1.0]. Fix LR scheduler settings.") + coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) + lr = self.end_lr + coeff * (self.peak_lr - self.end_lr) + + self.lr = lr = self.end_lr = 0.002 ###锁定用线性###不听话,直接锁定! + self.set_lr(lr) + self.lr = lr + self._current_step += 1 + return self.lr + + +if __name__ == "__main__": + m = nn.Linear(10, 10) + opt = Adam(m.parameters(), lr=1e-4) + s = WarmupCosineLRSchedule( + opt, + 1e-6, + 2e-4, + 1e-6, + warmup_steps=2000, + total_steps=20000, + current_step=0, + ) + lrs = [] + for i in range(25000): + s.step() + lrs.append(s.lr) + print(s.lr) + + plt.plot(lrs) + plt.plot(range(0, 25000), lrs) + plt.show() diff --git a/AR/modules/optim.py b/AR/modules/optim.py new file mode 100644 index 0000000000000000000000000000000000000000..fb878485430d47ada1b22210a5347bcd9a11c0eb --- /dev/null +++ b/AR/modules/optim.py @@ -0,0 +1,593 @@ +# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) +# +# See ../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import contextlib +import logging +from collections import defaultdict +from typing import List, Tuple + +import torch +from torch import Tensor +from torch.optim import Optimizer + + +class BatchedOptimizer(Optimizer): + """ + This class adds to class Optimizer the capability to optimize parameters in batches: + it will stack the parameters and their grads for you so the optimizer can work + on tensors with an extra leading dimension. This is intended for speed with GPUs, + as it reduces the number of kernels launched in the optimizer. 
+ + Args: + params: + """ + + def __init__(self, params, defaults): + super(BatchedOptimizer, self).__init__(params, defaults) + + @contextlib.contextmanager + def batched_params(self, param_group, group_params_names): + """ + This function returns (technically, yields) a list of + of tuples (p, state), where + p is a `fake` parameter that is stacked (over axis 0) from real parameters + that share the same shape, and its gradient is also stacked; + `state` is the state corresponding to this batch of parameters + (it will be physically located in the "state" for one of the real + parameters, the last one that has any particular shape and dtype). + + This function is decorated as a context manager so that it can + write parameters back to their "real" locations. + + The idea is, instead of doing: + + for p in group["params"]: + state = self.state[p] + ... + + you can do: + + with self.batched_params(group["params"]) as batches: + for p, state, p_names in batches: + ... + + + Args: + group: a parameter group, which is a list of parameters; should be + one of self.param_groups. + group_params_names: name for each parameter in group, + which is List[str]. + """ + batches = defaultdict(list) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter + batches_names = defaultdict(list) # `batches` maps from tuple (dtype_as_str,*shape) to list of str + + assert len(param_group) == len(group_params_names) + for p, named_p in zip(param_group, group_params_names): + key = (str(p.dtype), *p.shape) + batches[key].append(p) + batches_names[key].append(named_p) + + batches_names_keys = list(batches_names.keys()) + sorted_idx = sorted(range(len(batches_names)), key=lambda i: batches_names_keys[i]) + batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx] + batches = [batches[batches_names_keys[idx]] for idx in sorted_idx] + + stacked_params_dict = dict() + + # turn batches into a list, in deterministic order. + # tuples will contain tuples of (stacked_param, state, stacked_params_names), + # one for each batch in `batches`. + tuples = [] + + for batch, batch_names in zip(batches, batches_names): + p = batch[0] + # we arbitrarily store the state in the + # state corresponding to the 1st parameter in the + # group. class Optimizer will take care of saving/loading state. + state = self.state[p] + p_stacked = torch.stack(batch) + grad = torch.stack([torch.zeros_like(p) if p.grad is None else p.grad for p in batch]) + p_stacked.grad = grad + stacked_params_dict[key] = p_stacked + tuples.append((p_stacked, state, batch_names)) + + yield tuples # <-- calling code will do the actual optimization here! + + for (stacked_params, _state, _names), batch in zip(tuples, batches): + for i, p in enumerate(batch): # batch is list of Parameter + p.copy_(stacked_params[i]) + + +class ScaledAdam(BatchedOptimizer): + """ + Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update + proportional to the norm of that parameter; and also learn the scale of the parameter, + in log space, subject to upper and lower limits (as if we had factored each parameter as + param = underlying_param * log_scale.exp()) + + + Args: + params: The parameters or param_groups to optimize (like other Optimizer subclasses) + lr: The learning rate. We will typically use a learning rate schedule that starts + at 0.03 and decreases over time, i.e. much higher than other common + optimizers. + clipping_scale: (e.g. 
2.0) + A scale for gradient-clipping: if specified, the normalized gradients + over the whole model will be clipped to have 2-norm equal to + `clipping_scale` times the median 2-norm over the most recent period + of `clipping_update_period` minibatches. By "normalized gradients", + we mean after multiplying by the rms parameter value for this tensor + [for non-scalars]; this is appropriate because our update is scaled + by this quantity. + betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad. + Must satisfy 0 < beta <= beta2 < 1. + scalar_lr_scale: A scaling factor on the learning rate, that we use to update the + scale of each parameter tensor and scalar parameters of the mode.. + If each parameter were decomposed + as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale + would be a the scaling factor on the learning rate of p_scale. + eps: A general-purpose epsilon to prevent division by zero + param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of + learning the scale on the parameters (we'll constrain the rms of each non-scalar + parameter tensor to be >= this value) + param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of + learning the scale on the parameters (we'll constrain the rms of each non-scalar + parameter tensor to be <= this value) + scalar_max: Maximum absolute value for scalar parameters (applicable if your + model has any parameters with numel() == 1). + size_update_period: The periodicity, in steps, with which we update the size (scale) + of the parameter tensor. This is provided to save a little time + in the update. + clipping_update_period: if clipping_scale is specified, this is the period + """ + + def __init__( + self, + params, + lr=3e-02, + clipping_scale=None, + betas=(0.9, 0.98), + scalar_lr_scale=0.1, + eps=1.0e-08, + param_min_rms=1.0e-05, + param_max_rms=3.0, + scalar_max=10.0, + size_update_period=4, + clipping_update_period=100, + parameters_names=None, + show_dominant_parameters=True, + ): + assert parameters_names is not None, ( + "Please prepare parameters_names,which is a List[List[str]]. Each List[str] is for a groupand each str is for a parameter" + ) + defaults = dict( + lr=lr, + clipping_scale=clipping_scale, + betas=betas, + scalar_lr_scale=scalar_lr_scale, + eps=eps, + param_min_rms=param_min_rms, + param_max_rms=param_max_rms, + scalar_max=scalar_max, + size_update_period=size_update_period, + clipping_update_period=clipping_update_period, + ) + + super(ScaledAdam, self).__init__(params, defaults) + assert len(self.param_groups) == len(parameters_names) + self.parameters_names = parameters_names + self.show_dominant_parameters = show_dominant_parameters + + def __setstate__(self, state): + super(ScaledAdam, self).__setstate__(state) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + batch = True + + for group, group_params_names in zip(self.param_groups, self.parameters_names): + with self.batched_params(group["params"], group_params_names) as batches: + # batches is list of pairs (stacked_param, state). stacked_param is like + # a regular parameter, and will have a .grad, but the 1st dim corresponds to + # a stacking dim, it is not a real dim. 
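                # Added note: batches[0][1] is the state dict shared by the first
                # stacked batch; it stays empty until _init_state() has run, so on
                # the very first step we skip norm-based clipping (clipping_scale = 1)
                # instead of estimating a threshold from uninitialized state.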
+ + if len(batches[0][1]) == 0: # if len(first state) == 0: not yet initialized + clipping_scale = 1 + else: + clipping_scale = self._get_clipping_scale(group, batches) + + for p, state, _ in batches: + # Perform optimization step. + # grad is not going to be None, we handled that when creating the batches. + grad = p.grad + if grad.is_sparse: + raise RuntimeError("ScaledAdam optimizer does not support sparse gradients") + # State initialization + if len(state) == 0: + self._init_state(group, p, state) + + self._step_one_batch(group, p, state, clipping_scale) + + return loss + + def _init_state(self, group: dict, p: Tensor, state: dict): + """ + Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p + is actually the batch dimension, corresponding to batched-together + parameters of a given shape. + + + Args: + group: Dict to look up configuration values. + p: The parameter that we are initializing the state for + state: Dict from string to whatever state we are initializing + """ + size_update_period = group["size_update_period"] + + state["step"] = 0 + + kwargs = {"device": p.device, "dtype": p.dtype} + + # 'delta' implements conventional momentum. There are + # several different kinds of update going on, so rather than + # compute "exp_avg" like in Adam, we store and decay a + # parameter-change "delta", which combines all forms of + # update. this is equivalent to how it's done in Adam, + # except for the first few steps. + state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format) + + batch_size = p.shape[0] + numel = p.numel() // batch_size + numel = p.numel() + + if numel > 1: + # "param_rms" just periodically records the scalar root-mean-square value of + # the parameter tensor. + # it has a shape like (batch_size, 1, 1, 1, 1) + param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() + state["param_rms"] = param_rms + + state["scale_exp_avg_sq"] = torch.zeros_like(param_rms) + state["scale_grads"] = torch.zeros(size_update_period, *param_rms.shape, **kwargs) + + # exp_avg_sq is the weighted sum of scaled gradients. as in Adam. + state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) + + def _get_clipping_scale(self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]]) -> float: + """ + Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients + by this amount before applying the rest of the update. + + Args: + group: the parameter group, an item in self.param_groups + tuples: a list of tuples of (param, state, param_names) + where param is a batched set of parameters, + with a .grad (1st dim is batch dim) + and state is the state-dict where optimization parameters are kept. + param_names is a List[str] while each str is name for a parameter + in batched set of parameters "param". + """ + assert len(tuples) >= 1 + clipping_scale = group["clipping_scale"] + (first_p, first_state, _) = tuples[0] + step = first_state["step"] + if clipping_scale is None or step == 0: + # no clipping. return early on step == 0 because the other + # parameters' state won't have been initialized yet. 
+ return 1.0 + clipping_update_period = group["clipping_update_period"] + + tot_sumsq = torch.tensor(0.0, device=first_p.device) + for p, state, param_names in tuples: + grad = p.grad + if grad.is_sparse: + raise RuntimeError("ScaledAdam optimizer does not support sparse gradients") + if p.numel() == p.shape[0]: # a batch of scalars + tot_sumsq += (grad**2).sum() # sum() to change shape [1] to [] + else: + tot_sumsq += ((grad * state["param_rms"]) ** 2).sum() + + tot_norm = tot_sumsq.sqrt() + if "model_norms" not in first_state: + first_state["model_norms"] = torch.zeros(clipping_update_period, device=p.device) + first_state["model_norms"][step % clipping_update_period] = tot_norm + + if step % clipping_update_period == 0: + # Print some stats. + # We don't reach here if step == 0 because we would have returned + # above. + sorted_norms = first_state["model_norms"].sort()[0].to("cpu") + quartiles = [] + for n in range(0, 5): + index = min( + clipping_update_period - 1, + (clipping_update_period // 4) * n, + ) + quartiles.append(sorted_norms[index].item()) + + median = quartiles[2] + threshold = clipping_scale * median + first_state["model_norm_threshold"] = threshold + percent_clipped = ( + first_state["num_clipped"] * 100.0 / clipping_update_period if "num_clipped" in first_state else 0.0 + ) + first_state["num_clipped"] = 0 + quartiles = " ".join(["%.3e" % x for x in quartiles]) + logging.info( + f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}" + ) + + if step < clipping_update_period: + return 1.0 # We have not yet estimated a norm to clip to. + else: + try: + model_norm_threshold = first_state["model_norm_threshold"] + except KeyError: + logging.info( + "Warning: model_norm_threshold not in state: possibly you changed config when restarting, adding clipping_scale option?" + ) + return 1.0 + ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item()) + if ans < 1.0: + first_state["num_clipped"] += 1 + if ans < 0.1: + logging.warning(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}") + if self.show_dominant_parameters: + assert p.shape[0] == len(param_names) + self._show_gradient_dominating_parameter(tuples, tot_sumsq) + return ans + + def _show_gradient_dominating_parameter(self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor): + """ + Show information of parameter which dominating tot_sumsq. + + Args: + tuples: a list of tuples of (param, state, param_names) + where param is a batched set of parameters, + with a .grad (1st dim is batch dim) + and state is the state-dict where optimization parameters are kept. + param_names is a List[str] while each str is name for a parameter + in batched set of parameters "param". + tot_sumsq: sumsq of all parameters. Though it's could be calculated + from tuples, we still pass it to save some time. + """ + all_sumsq_orig = {} + for p, state, batch_param_names in tuples: + # p is a stacked batch parameters. + batch_grad = p.grad + if p.numel() == p.shape[0]: # a batch of scalars + batch_sumsq_orig = batch_grad**2 + # Dummpy values used by following `zip` statement. 
+ batch_rms_orig = torch.ones(p.shape[0]) + else: + batch_rms_orig = state["param_rms"] + batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum(dim=list(range(1, batch_grad.ndim))) + + for name, sumsq_orig, rms, grad in zip( + batch_param_names, + batch_sumsq_orig, + batch_rms_orig, + batch_grad, + ): + proportion_orig = sumsq_orig / tot_sumsq + all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad) + + assert torch.isclose( + sum([value[0] for value in all_sumsq_orig.values()]).cpu(), + torch.tensor(1.0), + ) + sorted_by_proportion = { + k: v + for k, v in sorted( + all_sumsq_orig.items(), + key=lambda item: item[1][0], + reverse=True, + ) + } + dominant_param_name = next(iter(sorted_by_proportion)) + ( + dominant_proportion, + dominant_sumsq, + dominant_rms, + dominant_grad, + ) = sorted_by_proportion[dominant_param_name] + logging.info( + f"Parameter Dominating tot_sumsq {dominant_param_name}" + f" with proportion {dominant_proportion:.2f}," + f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)" + f"={dominant_sumsq:.3e}," + f" grad_sumsq = {(dominant_grad**2).sum():.3e}," + f" orig_rms_sq={(dominant_rms**2).item():.3e}" + ) + + def _step_one_batch(self, group: dict, p: Tensor, state: dict, clipping_scale: float): + """ + Do the step for one parameter, which is actually going to be a batch of + `real` parameters, with dim 0 as the batch dim. + Args: + group: dict to look up configuration values + p: parameter to update (actually multiple parameters stacked together + as a batch) + state: state-dict for p, to look up the optimizer state + """ + lr = group["lr"] + size_update_period = group["size_update_period"] + beta1 = group["betas"][0] + + grad = p.grad + if clipping_scale != 1.0: + grad = grad * clipping_scale + step = state["step"] + delta = state["delta"] + + delta.mul_(beta1) + batch_size = p.shape[0] + numel = p.numel() // batch_size + if numel > 1: + # Update the size/scale of p, and set param_rms + scale_grads = state["scale_grads"] + scale_grads[step % size_update_period] = (p * grad).sum(dim=list(range(1, p.ndim)), keepdim=True) + if step % size_update_period == size_update_period - 1: + param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..) + param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()) + if step > 0: + # self._size_update() learns the overall scale on the + # parameter, by shrinking or expanding it. + self._size_update(group, scale_grads, p, state) + + if numel == 1: + # For parameters with 1 element we just use regular Adam. + # Updates delta. + self._step_scalar(group, p, state) + else: + self._step(group, p, state) + + state["step"] = step + 1 + + def _size_update( + self, + group: dict, + scale_grads: Tensor, + p: Tensor, + state: dict, + ) -> None: + """ + Called only where p.numel() > 1, this updates the scale of the parameter. + If we imagine: p = underlying_param * scale.exp(), and we are doing + gradient descent on underlying param and on scale, this function does the update + on `scale`. + + Args: + group: dict to look up configuration values + scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing + grads w.r.t. the scales. 
+ p: The parameter to update + state: The state-dict of p + """ + + param_rms = state["param_rms"] + beta1, beta2 = group["betas"] + size_lr = group["lr"] * group["scalar_lr_scale"] + param_min_rms = group["param_min_rms"] + param_max_rms = group["param_max_rms"] + eps = group["eps"] + step = state["step"] + batch_size = p.shape[0] + + size_update_period = scale_grads.shape[0] + # correct beta2 for the size update period: we will have + # faster decay at this level. + beta2_corr = beta2**size_update_period + + scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..) + scale_exp_avg_sq.mul_(beta2_corr).add_( + (scale_grads**2).mean(dim=0), # mean over dim `size_update_period` + alpha=1 - beta2_corr, + ) # shape is (batch_size, 1, 1, ...) + + # The 1st time we reach here is when size_step == 1. + size_step = (step + 1) // size_update_period + bias_correction2 = 1 - beta2_corr**size_step + # we don't bother with bias_correction1; this will help prevent divergence + # at the start of training. + + denom = scale_exp_avg_sq.sqrt() + eps + + scale_step = -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom + + is_too_small = param_rms < param_min_rms + is_too_large = param_rms > param_max_rms + + # when the param gets too small, just don't shrink it any further. + scale_step.masked_fill_(is_too_small, 0.0) + # when it gets too large, stop it from getting any larger. + scale_step.masked_fill_(is_too_large, -size_lr * size_update_period) + delta = state["delta"] + # the factor of (1-beta1) relates to momentum. + delta.add_(p * scale_step, alpha=(1 - beta1)) + + def _step(self, group: dict, p: Tensor, state: dict): + """ + This function does the core update of self.step(), in the case where the members of + the batch have more than 1 element. + + Args: + group: A dict which will be used to look up configuration values + p: The parameter to be updated + grad: The grad of p + state: The state-dict corresponding to parameter p + + This function modifies p. + """ + grad = p.grad + lr = group["lr"] + beta1, beta2 = group["betas"] + eps = group["eps"] + param_min_rms = group["param_min_rms"] + step = state["step"] + + exp_avg_sq = state["exp_avg_sq"] + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2)) + + this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0) + bias_correction2 = 1 - beta2 ** (this_step + 1) + if bias_correction2 < 0.99: + # note: not in-place. + exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2) + + denom = exp_avg_sq.sqrt() + denom += eps + grad = grad / denom + + alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms) + + delta = state["delta"] + delta.add_(grad * alpha) + p.add_(delta) + + def _step_scalar(self, group: dict, p: Tensor, state: dict): + """ + A simplified form of the core update for scalar tensors, where we cannot get a good + estimate of the parameter rms. + """ + beta1, beta2 = group["betas"] + scalar_max = group["scalar_max"] + eps = group["eps"] + lr = group["lr"] * group["scalar_lr_scale"] + grad = p.grad + + exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # bias_correction2 is like in Adam. Don't bother with bias_correction1; + # slower update at the start will help stability anyway. 
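        # Added worked example: with beta2 = 0.98, after the first update (step == 0)
        # exp_avg_sq is roughly 0.02 * grad**2 and bias_correction2 = 1 - 0.98**1 = 0.02,
        # so exp_avg_sq / bias_correction2 recovers roughly grad**2 -- the denominator
        # is not underestimated early in training.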
+ bias_correction2 = 1 - beta2 ** (state["step"] + 1) + denom = (exp_avg_sq / bias_correction2).sqrt() + eps + + delta = state["delta"] + delta.add_(grad / denom, alpha=-lr * (1 - beta1)) + p.clamp_(min=-scalar_max, max=scalar_max) + p.add_(delta) diff --git a/AR/modules/patched_mha_with_cache.py b/AR/modules/patched_mha_with_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..5bffcea63c7081571defc308b601b748fe4eb797 --- /dev/null +++ b/AR/modules/patched_mha_with_cache.py @@ -0,0 +1,428 @@ +from torch.nn.functional import * +from torch.nn.functional import ( + _mha_shape_check, + _canonical_mask, + _none_or_dtype, + _in_projection_packed, +) +import torch +# Tensor = torch.Tensor +# from typing import Callable, List, Optional, Tuple, Union + + +def multi_head_attention_forward_patched( + query, + key, + value, + embed_dim_to_check, + num_heads, + in_proj_weight, + in_proj_bias, + bias_k, + bias_v, + add_zero_attn, + dropout_p: float, + out_proj_weight, + out_proj_bias, + training=True, + key_padding_mask=None, + need_weights=True, + attn_mask=None, + use_separate_proj_weight=False, + q_proj_weight=None, + k_proj_weight=None, + v_proj_weight=None, + static_k=None, + static_v=None, + average_attn_weights=True, + is_causal=False, + cache=None, +): + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + See "Attention Is All You Need" for more details. + embed_dim_to_check: total dimension of the model. + num_heads: parallel attention heads. + in_proj_weight, in_proj_bias: input projection weight and bias. + bias_k, bias_v: bias of the key and value sequences to be added at dim=0. + add_zero_attn: add a new batch of zeros to the key and + value sequences at dim=1. + dropout_p: probability of an element to be zeroed. + out_proj_weight, out_proj_bias: the output projection weight and bias. + training: apply dropout if is ``True``. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + need_weights: output attn_output_weights. + Default: `True` + Note: `needs_weight` defaults to `True`, but should be set to `False` + For best performance when attention weights are not nedeeded. + *Setting needs_weights to `True` + leads to a significant performance degradation.* + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + is_causal: If specified, applies a causal mask as attention mask, and ignores + attn_mask for computing scaled dot product attention. + Default: ``False``. + .. warning:: + is_causal is provides a hint that the attn_mask is the + causal mask.Providing incorrect hints can result in + incorrect execution, including forward and backward + compatibility. + use_separate_proj_weight: the function accept the proj. weights for query, key, + and value in different forms. If false, in_proj_weight will be used, which is + a combination of q_proj_weight, k_proj_weight, v_proj_weight. + q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias. + static_k, static_v: static key and value used for attention operators. + average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across heads. 
+ Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an effect + when ``need_weights=True.``. Default: True + + + Shape: + Inputs: + - query: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, E)` or :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, E)` or :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(S)` or :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a FloatTensor is provided, it will be directly added to the value. + If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked + positions. If a BoolTensor is provided, positions with ``True`` + are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, + N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. + - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, + N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. + + Outputs: + - attn_output: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: Only returned when ``need_weights=True``. If ``average_attn_weights=True``, returns + attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or + :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and + :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per + head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`. 
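+
+    Note:
+        ``cache`` is the extra argument added by this patched version; it carries the
+        per-layer key/value cache used for incremental decoding. Judging from how it is
+        used below, a plausible layout is (illustrative only, not a documented API)::
+
+            cache = {
+                "all_stage": n_layers,       # number of attention layers sharing this dict
+                "stage": 0,                  # index of the layer currently being executed
+                "first_infer": 1,            # 1 on the first decoding step, 0 afterwards
+                "k": [None] * n_layers,      # cached keys, one entry per layer
+                "v": [None] * n_layers,      # cached values, one entry per layer
+            }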
+ """ + tens_ops = ( + query, + key, + value, + in_proj_weight, + in_proj_bias, + bias_k, + bias_v, + out_proj_weight, + out_proj_bias, + ) + if has_torch_function(tens_ops): + return handle_torch_function( + multi_head_attention_forward, + tens_ops, + query, + key, + value, + embed_dim_to_check, + num_heads, + in_proj_weight, + in_proj_bias, + bias_k, + bias_v, + add_zero_attn, + dropout_p, + out_proj_weight, + out_proj_bias, + training=training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + is_causal=is_causal, + use_separate_proj_weight=use_separate_proj_weight, + q_proj_weight=q_proj_weight, + k_proj_weight=k_proj_weight, + v_proj_weight=v_proj_weight, + static_k=static_k, + static_v=static_v, + average_attn_weights=average_attn_weights, + cache=cache, + ) + + is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads) + + # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input + # is batched, run the computation and before returning squeeze the + # batch dimension so that the output doesn't carry this temporary batch dimension. + if not is_batched: + # unsqueeze if the input is unbatched + query = query.unsqueeze(1) + key = key.unsqueeze(1) + value = value.unsqueeze(1) + if key_padding_mask is not None: + key_padding_mask = key_padding_mask.unsqueeze(0) + + # set up shape vars + tgt_len, bsz, embed_dim = query.shape + src_len, _, _ = key.shape + + key_padding_mask = _canonical_mask( + mask=key_padding_mask, + mask_name="key_padding_mask", + other_type=_none_or_dtype(attn_mask), + other_name="attn_mask", + target_type=query.dtype, + ) + + if is_causal and attn_mask is None: + raise RuntimeError( + "Need attn_mask if specifying the is_causal hint. " + "You may use the Transformer module method " + "`generate_square_subsequent_mask` to create this mask." + ) + + if is_causal and key_padding_mask is None and not need_weights: + # when we have a kpm or need weights, we need attn_mask + # Otherwise, we use the is_causal hint go as is_causal + # indicator to SDPA. + attn_mask = None + else: + attn_mask = _canonical_mask( + mask=attn_mask, + mask_name="attn_mask", + other_type=None, + other_name="", + target_type=query.dtype, + check_other=False, + ) + + if key_padding_mask is not None: + # We have the attn_mask, and use that to merge kpm into it. + # Turn off use of is_causal hint, as the merged mask is no + # longer causal. 
+ is_causal = False + + assert embed_dim == embed_dim_to_check, ( + f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}" + ) + if isinstance(embed_dim, torch.Tensor): + # embed_dim can be a tensor when JIT tracing + head_dim = embed_dim.div(num_heads, rounding_mode="trunc") + else: + head_dim = embed_dim // num_heads + assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}" + if use_separate_proj_weight: + # allow MHA to have different embedding dimensions when separate projection weights are used + assert key.shape[:2] == value.shape[:2], ( + f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}" + ) + else: + assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}" + + # + # compute in-projection + # + if not use_separate_proj_weight: + assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None" + q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias) + else: + assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None" + assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None" + assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None" + if in_proj_bias is None: + b_q = b_k = b_v = None + else: + b_q, b_k, b_v = in_proj_bias.chunk(3) + q, k, v = _in_projection( + query, + key, + value, + q_proj_weight, + k_proj_weight, + v_proj_weight, + b_q, + b_k, + b_v, + ) + if cache != None: + if cache["first_infer"] == 1: + cache["k"][cache["stage"]] = k + # print(0,cache["k"].shape) + cache["v"][cache["stage"]] = v + else: ###12个layer每个都要留自己的cache_kv + # print(1,cache["k"].shape) + cache["k"][cache["stage"]] = torch.cat( + [cache["k"][cache["stage"]], k], 0 + ) ##本来时序是1,但是proj的时候可能transpose了所以时序到0维了 + cache["v"][cache["stage"]] = torch.cat([cache["v"][cache["stage"]], v], 0) + # print(2, cache["k"].shape) + src_len = cache["k"][cache["stage"]].shape[0] + k = cache["k"][cache["stage"]] + v = cache["v"][cache["stage"]] + # if attn_mask is not None: + # attn_mask=attn_mask[-1:,] + # print(attn_mask.shape,attn_mask) + cache["stage"] = (cache["stage"] + 1) % cache["all_stage"] + # print(2333,cache) + # prep attention mask + + attn_mask = _canonical_mask( + mask=attn_mask, + mask_name="attn_mask", + other_type=None, + other_name="", + target_type=q.dtype, + check_other=False, + ) + + if attn_mask is not None: + # ensure attn_mask's dim is 3 + if attn_mask.dim() == 2: + correct_2d_size = (tgt_len, src_len) + if attn_mask.shape != correct_2d_size: + raise RuntimeError( + f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}." + ) + attn_mask = attn_mask.unsqueeze(0) + elif attn_mask.dim() == 3: + correct_3d_size = (bsz * num_heads, tgt_len, src_len) + if attn_mask.shape != correct_3d_size: + raise RuntimeError( + f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}." + ) + else: + raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported") + + # add bias along batch dimension (currently second) + if bias_k is not None and bias_v is not None: + assert static_k is None, "bias cannot be added to static key." + assert static_v is None, "bias cannot be added to static value." 
+ k = torch.cat([k, bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = pad(key_padding_mask, (0, 1)) + else: + assert bias_k is None + assert bias_v is None + + # + # reshape q, k, v for multihead attention and make em batch first + # + q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) + if static_k is None: + k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1) + else: + # TODO finish disentangling control flow so we don't do in-projections when statics are passed + assert static_k.size(0) == bsz * num_heads, ( + f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}" + ) + assert static_k.size(2) == head_dim, f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}" + k = static_k + if static_v is None: + v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1) + else: + # TODO finish disentangling control flow so we don't do in-projections when statics are passed + assert static_v.size(0) == bsz * num_heads, ( + f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}" + ) + assert static_v.size(2) == head_dim, f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}" + v = static_v + + # add zero attention along batch dimension (now first) + if add_zero_attn: + zero_attn_shape = (bsz * num_heads, 1, head_dim) + k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1) + v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1) + if attn_mask is not None: + attn_mask = pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = pad(key_padding_mask, (0, 1)) + + # update source sequence length after adjustments + src_len = k.size(1) + + # merge key padding and attention masks + if key_padding_mask is not None: + assert key_padding_mask.shape == ( + bsz, + src_len, + ), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}" + key_padding_mask = ( + key_padding_mask.view(bsz, 1, 1, src_len).expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len) + ) + if attn_mask is None: + attn_mask = key_padding_mask + else: + attn_mask = attn_mask + key_padding_mask + + # adjust dropout probability + if not training: + dropout_p = 0.0 + + # + # (deep breath) calculate attention and out projection + # + + if need_weights: + B, Nt, E = q.shape + q_scaled = q / math.sqrt(E) + + assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights" + + if attn_mask is not None: + attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1)) + else: + attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1)) + attn_output_weights = softmax(attn_output_weights, dim=-1) + if dropout_p > 0.0: + attn_output_weights = dropout(attn_output_weights, p=dropout_p) + + attn_output = torch.bmm(attn_output_weights, v) + + attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim) + attn_output = linear(attn_output, out_proj_weight, out_proj_bias) + attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1)) + + # optionally average attention weights over heads + attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) + if average_attn_weights: + attn_output_weights = attn_output_weights.mean(dim=1) + + if not is_batched: + # squeeze the output if 
input was unbatched + attn_output = attn_output.squeeze(1) + attn_output_weights = attn_output_weights.squeeze(0) + return attn_output, attn_output_weights + else: + # attn_mask can be either (L,S) or (N*num_heads, L, S) + # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S) + # in order to match the input for SDPA of (N, num_heads, L, S) + if attn_mask is not None: + if attn_mask.size(0) == 1 and attn_mask.dim() == 3: + attn_mask = attn_mask.unsqueeze(0) + else: + attn_mask = attn_mask.view(bsz, num_heads, -1, src_len) + + q = q.view(bsz, num_heads, tgt_len, head_dim) + k = k.view(bsz, num_heads, src_len, head_dim) + v = v.view(bsz, num_heads, src_len, head_dim) + + # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True): + attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal) + + attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim) + + attn_output = linear(attn_output, out_proj_weight, out_proj_bias) + attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1)) + if not is_batched: + # squeeze the output if input was unbatched + attn_output = attn_output.squeeze(1) + return attn_output, None diff --git a/AR/modules/patched_mha_with_cache_onnx.py b/AR/modules/patched_mha_with_cache_onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..8144c9c6b930dc825d56deb6b71229c037efb405 --- /dev/null +++ b/AR/modules/patched_mha_with_cache_onnx.py @@ -0,0 +1,85 @@ +from torch.nn.functional import * +from torch.nn.functional import ( + _canonical_mask, +) + + +def multi_head_attention_forward_patched( + query, + key, + value, + embed_dim_to_check: int, + num_heads: int, + in_proj_weight, + in_proj_bias: Optional[Tensor], + bias_k: Optional[Tensor], + bias_v: Optional[Tensor], + add_zero_attn: bool, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Optional[Tensor], + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + use_separate_proj_weight: bool = False, + q_proj_weight: Optional[Tensor] = None, + k_proj_weight: Optional[Tensor] = None, + v_proj_weight: Optional[Tensor] = None, + static_k: Optional[Tensor] = None, + static_v: Optional[Tensor] = None, + average_attn_weights: bool = True, + is_causal: bool = False, + cache=None, +) -> Tuple[Tensor, Optional[Tensor]]: + # set up shape vars + _, _, embed_dim = query.shape + attn_mask = _canonical_mask( + mask=attn_mask, + mask_name="attn_mask", + other_type=None, + other_name="", + target_type=query.dtype, + check_other=False, + ) + head_dim = embed_dim // num_heads + + proj_qkv = linear(query, in_proj_weight, in_proj_bias) + proj_qkv = proj_qkv.unflatten(-1, (3, query.size(-1))).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous() + q, k, v = proj_qkv[0], proj_qkv[1], proj_qkv[2] + + if cache["first_infer"] == 1: + cache["k"][cache["stage"]] = k + cache["v"][cache["stage"]] = v + else: + cache["k"][cache["stage"]] = torch.cat([cache["k"][cache["stage"]][:-1], k], 0) + cache["v"][cache["stage"]] = torch.cat([cache["v"][cache["stage"]][:-1], v], 0) + k = cache["k"][cache["stage"]] + v = cache["v"][cache["stage"]] + cache["stage"] = (cache["stage"] + 1) % cache["all_stage"] + + attn_mask = _canonical_mask( + mask=attn_mask, + mask_name="attn_mask", + other_type=None, + other_name="", + target_type=q.dtype, + check_other=False, + ) + attn_mask = attn_mask.unsqueeze(0) + + q = q.view(-1, 
num_heads, head_dim).transpose(0, 1) + k = k.view(-1, num_heads, head_dim).transpose(0, 1) + v = v.view(-1, num_heads, head_dim).transpose(0, 1) + + dropout_p = 0.0 + attn_mask = attn_mask.unsqueeze(0) + q = q.view(num_heads, -1, head_dim).unsqueeze(0) + k = k.view(num_heads, -1, head_dim).unsqueeze(0) + v = v.view(num_heads, -1, head_dim).unsqueeze(0) + attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal) + attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(-1, embed_dim) + attn_output = linear(attn_output, out_proj_weight, out_proj_bias) + attn_output = attn_output.view(-1, 1, attn_output.size(1)) + + return attn_output diff --git a/AR/modules/scaling.py b/AR/modules/scaling.py new file mode 100644 index 0000000000000000000000000000000000000000..aae1453316adc42b7ed17b7f0a6c776a78347e6a --- /dev/null +++ b/AR/modules/scaling.py @@ -0,0 +1,320 @@ +# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import random +from typing import Optional +from typing import Tuple + +import torch +import torch.nn as nn +from torch import Tensor + + +class DoubleSwishFunction(torch.autograd.Function): + """ + double_swish(x) = x * torch.sigmoid(x-1) + This is a definition, originally motivated by its close numerical + similarity to swish(swish(x)), where swish(x) = x * sigmoid(x). + + Memory-efficient derivative computation: + double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1) + double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x). + Now, s'(x) = s(x) * (1-s(x)). + double_swish'(x) = x * s'(x) + s(x). + = x * s(x) * (1-s(x)) + s(x). + = double_swish(x) * (1-s(x)) + s(x) + ... so we just need to remember s(x) but not x itself. + """ + + @staticmethod + def forward(ctx, x: Tensor) -> Tensor: + requires_grad = x.requires_grad + x_dtype = x.dtype + if x.dtype == torch.float16: + x = x.to(torch.float32) + + s = torch.sigmoid(x - 1.0) + y = x * s + + if requires_grad: + deriv = y * (1 - s) + s + # notes on derivative of x * sigmoid(x - 1): + # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29 + # min \simeq -0.043638. Take floor as -0.043637 so it's a lower bund + # max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound. + # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which + # floors), should be expectation-preserving. + floor = -0.043637 + ceil = 1.2 + d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(deriv) + if __name__ == "__main__": + # for self-testing only. + assert d_scaled.min() >= 0.0 + assert d_scaled.max() < 256.0 + d_int = d_scaled.to(torch.uint8) + ctx.save_for_backward(d_int) + if x.dtype == torch.float16 or torch.is_autocast_enabled(): + y = y.to(torch.float16) + return y + + @staticmethod + def backward(ctx, y_grad: Tensor) -> Tensor: + (d,) = ctx.saved_tensors + # the same constants as used in forward pass. 
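+        # The two lines below dequantise the uint8 tensor saved in forward():
+        #   d_float = d_uint8 * (ceil - floor) / 255 + floor
+        # which inverts the quantisation
+        #   d_scaled = (deriv - floor) * 255 / (ceil - floor)  (+ dither)
+        # up to the random dithering noise, so E[d_float] ~= deriv.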
+ floor = -0.043637 + ceil = 1.2 + d = d * ((ceil - floor) / 255.0) + floor + return y_grad * d + + +class DoubleSwish(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + """Return double-swish activation function which is an approximation to Swish(Swish(x)), + that we approximate closely with x * sigmoid(x-1). + """ + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return x * torch.sigmoid(x - 1.0) + return DoubleSwishFunction.apply(x) + + +class ActivationBalancerFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + scale_factor: Tensor, + sign_factor: Optional[Tensor], + channel_dim: int, + ) -> Tensor: + if channel_dim < 0: + channel_dim += x.ndim + ctx.channel_dim = channel_dim + xgt0 = x > 0 + if sign_factor is None: + ctx.save_for_backward(xgt0, scale_factor) + else: + ctx.save_for_backward(xgt0, scale_factor, sign_factor) + return x + + @staticmethod + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]: + if len(ctx.saved_tensors) == 3: + xgt0, scale_factor, sign_factor = ctx.saved_tensors + for _ in range(ctx.channel_dim, x_grad.ndim - 1): + scale_factor = scale_factor.unsqueeze(-1) + sign_factor = sign_factor.unsqueeze(-1) + factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5) + else: + xgt0, scale_factor = ctx.saved_tensors + for _ in range(ctx.channel_dim, x_grad.ndim - 1): + scale_factor = scale_factor.unsqueeze(-1) + factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5) + neg_delta_grad = x_grad.abs() * factor + return ( + x_grad - neg_delta_grad, + None, + None, + None, + ) + + +def _compute_scale_factor( + x: Tensor, + channel_dim: int, + min_abs: float, + max_abs: float, + gain_factor: float, + max_factor: float, +) -> Tensor: + if channel_dim < 0: + channel_dim += x.ndim + sum_dims = [d for d in range(x.ndim) if d != channel_dim] + x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32) + + if min_abs == 0.0: + below_threshold = 0.0 + else: + # below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if + # x_abs)_mean , min_abs. + below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(min=0, max=max_factor) + + above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(min=0, max=max_factor) + + return below_threshold - above_threshold + + +def _compute_sign_factor( + x: Tensor, + channel_dim: int, + min_positive: float, + max_positive: float, + gain_factor: float, + max_factor: float, +) -> Tensor: + if channel_dim < 0: + channel_dim += x.ndim + sum_dims = [d for d in range(x.ndim) if d != channel_dim] + proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims) + if min_positive == 0.0: + factor1 = 0.0 + else: + # 0 if proportion_positive >= min_positive, else can be + # as large as max_factor. + factor1 = ((min_positive - proportion_positive) * (gain_factor / min_positive)).clamp_(min=0, max=max_factor) + + if max_positive == 1.0: + factor2 = 0.0 + else: + # 0 if self.proportion_positive <= max_positive, else can be + # as large as -max_factor. 
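+        # ("-max_factor" refers to this term's contribution to sign_factor =
+        # factor1 - factor2 below; factor2 itself is clamped to [0, max_factor].
+        # Worked example: with max_positive=0.95, gain_factor=0.01 and a channel that
+        # is positive 100% of the time,
+        # factor2 = clamp((1.0 - 0.95) * (0.01 / 0.05), 0, max_factor) = 0.01.)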
+ factor2 = ((proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive))).clamp_( + min=0, max=max_factor + ) + sign_factor = factor1 - factor2 + # require min_positive != 0 or max_positive != 1: + assert not isinstance(sign_factor, float) + return sign_factor + + +class ActivationBalancer(torch.nn.Module): + """ + Modifies the backpropped derivatives of a function to try to encourage, for + each channel, that it is positive at least a proportion `threshold` of the + time. It does this by multiplying negative derivative values by up to + (1+max_factor), and positive derivative values by up to (1-max_factor), + interpolated from 1 at the threshold to those extremal values when none + of the inputs are positive. + + Args: + num_channels: the number of channels + channel_dim: the dimension/axis corresponding to the channel, e.g. + -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. + min_positive: the minimum, per channel, of the proportion of the time + that (x > 0), below which we start to modify the derivatives. + max_positive: the maximum, per channel, of the proportion of the time + that (x > 0), above which we start to modify the derivatives. + max_factor: the maximum factor by which we modify the derivatives for + either the sign constraint or the magnitude constraint; + e.g. with max_factor=0.02, the the derivatives would be multiplied by + values in the range [0.98..1.02]. + sign_gain_factor: determines the 'gain' with which we increase the + change in gradient once the constraints on min_positive and max_positive + are violated. + scale_gain_factor: determines the 'gain' with which we increase the + change in gradient once the constraints on min_abs and max_abs + are violated. + min_abs: the minimum average-absolute-value difference from the mean + value per channel, which we allow, before we start to modify + the derivatives to prevent this. + max_abs: the maximum average-absolute-value difference from the mean + value per channel, which we allow, before we start to modify + the derivatives to prevent this. + min_prob: determines the minimum probability with which we modify the + gradients for the {min,max}_positive and {min,max}_abs constraints, + on each forward(). This is done randomly to prevent all layers + from doing it at the same time. Early in training we may use + higher probabilities than this; it will decay to this value. + """ + + def __init__( + self, + num_channels: int, + channel_dim: int, + min_positive: float = 0.05, + max_positive: float = 0.95, + max_factor: float = 0.04, + sign_gain_factor: float = 0.01, + scale_gain_factor: float = 0.02, + min_abs: float = 0.2, + max_abs: float = 100.0, + min_prob: float = 0.1, + ): + super(ActivationBalancer, self).__init__() + self.num_channels = num_channels + self.channel_dim = channel_dim + self.min_positive = min_positive + self.max_positive = max_positive + self.max_factor = max_factor + self.min_abs = min_abs + self.max_abs = max_abs + self.min_prob = min_prob + self.sign_gain_factor = sign_gain_factor + self.scale_gain_factor = scale_gain_factor + + # count measures how many times the forward() function has been called. + # We occasionally sync this to a tensor called `count`, that exists to + # make sure it is synced to disk when we load and save the model. 
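+        # Illustrative usage (a sketch, not taken from the original code): the module is
+        # numerically an identity in forward() and only rescales gradients in the
+        # backward pass, so it can be placed in front of any activation, e.g.
+        #     balancer = ActivationBalancer(num_channels=512, channel_dim=-1)
+        #     y = balancer(x)   # y equals x; gradient constraints apply during backward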
+ self.cpu_count = 0 + self.register_buffer("count", torch.tensor(0, dtype=torch.int64)) + + def forward(self, x: Tensor) -> Tensor: + if torch.jit.is_scripting() or not x.requires_grad or torch.jit.is_tracing(): + return _no_op(x) + + count = self.cpu_count + self.cpu_count += 1 + + if random.random() < 0.01: + # Occasionally sync self.cpu_count with self.count. + # count affects the decay of 'prob'. don't do this on every iter, + # because syncing with the GPU is slow. + self.cpu_count = max(self.cpu_count, self.count.item()) + self.count.fill_(self.cpu_count) + + # the prob of doing some work exponentially decreases from 0.5 till it hits + # a floor at min_prob (==0.1, by default) + prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0))) + + if random.random() < prob: + sign_gain_factor = 0.5 + if self.min_positive != 0.0 or self.max_positive != 1.0: + sign_factor = _compute_sign_factor( + x, + self.channel_dim, + self.min_positive, + self.max_positive, + gain_factor=self.sign_gain_factor / prob, + max_factor=self.max_factor, + ) + else: + sign_factor = None + + scale_factor = _compute_scale_factor( + x.detach(), + self.channel_dim, + min_abs=self.min_abs, + max_abs=self.max_abs, + gain_factor=self.scale_gain_factor / prob, + max_factor=self.max_factor, + ) + return ActivationBalancerFunction.apply( + x, + scale_factor, + sign_factor, + self.channel_dim, + ) + else: + return _no_op(x) + + +def BalancedDoubleSwish(d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25) -> nn.Sequential: + """ + ActivationBalancer -> DoubleSwish + """ + balancer = ActivationBalancer(d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob) + return nn.Sequential( + balancer, + DoubleSwish(), + ) diff --git a/AR/modules/transformer.py b/AR/modules/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..1bf21cdbbc21006af785534d0e528da703dd68d3 --- /dev/null +++ b/AR/modules/transformer.py @@ -0,0 +1,362 @@ +# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/transformer.py +import copy +import numbers +from functools import partial +from typing import Any +from typing import Callable +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +import torch +from AR.modules.activation import MultiheadAttention +from AR.modules.scaling import BalancedDoubleSwish +from torch import nn +from torch import Tensor +from torch.nn import functional as F + +_shape_t = Union[int, List[int], torch.Size] + + +class LayerNorm(nn.Module): + __constants__ = ["normalized_shape", "eps", "elementwise_affine"] + normalized_shape: Tuple[int, ...] 
+ eps: float + elementwise_affine: bool + + def __init__( + self, + normalized_shape: _shape_t, + eps: float = 1e-5, + elementwise_affine: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super(LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + # mypy error: incompatible types in assignment + normalized_shape = (normalized_shape,) # type: ignore[assignment] + self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] + self.eps = eps + self.elementwise_affine = elementwise_affine + if self.elementwise_affine: + self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) + self.bias = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) + else: + self.register_parameter("weight", None) + self.register_parameter("bias", None) + + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.elementwise_affine: + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + + def forward(self, input: Tensor, embedding: Any = None) -> Tensor: + if isinstance(input, tuple): + input, embedding = input + return ( + F.layer_norm( + input, + self.normalized_shape, + self.weight, + self.bias, + self.eps, + ), + embedding, + ) + + assert embedding is None + return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps) + + def extra_repr(self) -> str: + return "{normalized_shape}, eps={eps}, elementwise_affine={elementwise_affine}".format(**self.__dict__) + + +class IdentityNorm(nn.Module): + def __init__( + self, + d_model: int, + eps: float = 1e-5, + device=None, + dtype=None, + ) -> None: + super(IdentityNorm, self).__init__() + + def forward(self, input: Tensor, embedding: Any = None) -> Tensor: + if isinstance(input, tuple): + return input + + assert embedding is None + return input + + +class TransformerEncoder(nn.Module): + r"""TransformerEncoder is a stack of N encoder layers. Users can build the + BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters. + + Args: + encoder_layer: an instance of the TransformerEncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + norm: the layer normalization component (optional). + enable_nested_tensor: if True, input will automatically convert to nested tensor + (and convert back on output). This will improve the overall performance of + TransformerEncoder when padding rate is high. Default: ``True`` (enabled). + + Examples:: + >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) + >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> out = transformer_encoder(src) + """ + + __constants__ = ["norm"] + + def __init__(self, encoder_layer, num_layers, norm=None): + super(TransformerEncoder, self).__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward( + self, + src: Tensor, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + return_layer_states: bool = False, + cache=None, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required). + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + return_layer_states: return layers' state (optional). + + Shape: + see the docs in Transformer class. 
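+
+        Example (illustrative)::
+            >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
+            >>> encoder = TransformerEncoder(encoder_layer, num_layers=6)
+            >>> out = encoder(torch.rand(10, 32, 512))
+            >>> # with return_layer_states=True, a (layer_states, output) tuple is returned
+            >>> layer_states, out = encoder(torch.rand(10, 32, 512), return_layer_states=True)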
+ """ + if return_layer_states: + layer_states = [] # layers' output + output = src + for mod in self.layers: + output = mod( + output, + src_mask=mask, + src_key_padding_mask=src_key_padding_mask, + cache=cache, + ) + layer_states.append(output[0]) + + if self.norm is not None: + output = self.norm(output) + + return layer_states, output + + output = src + for mod in self.layers: + output = mod( + output, + src_mask=mask, + src_key_padding_mask=src_key_padding_mask, + cache=cache, + ) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class TransformerEncoderLayer(nn.Module): + __constants__ = ["batch_first", "norm_first"] + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + batch_first: bool = False, + norm_first: bool = False, + device=None, + dtype=None, + linear1_self_attention_cls: nn.Module = nn.Linear, + linear2_self_attention_cls: nn.Module = nn.Linear, + linear1_feedforward_cls: nn.Module = nn.Linear, + linear2_feedforward_cls: nn.Module = nn.Linear, + layer_norm_cls: nn.Module = LayerNorm, + layer_norm_eps: float = 1e-5, + adaptive_layer_norm=False, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super(TransformerEncoderLayer, self).__init__() + # print(233333333333,d_model,nhead) + # import os + # os._exit(2333333) + self.self_attn = MultiheadAttention( + d_model, # 512 16 + nhead, + dropout=dropout, + batch_first=batch_first, + linear1_cls=linear1_self_attention_cls, + linear2_cls=linear2_self_attention_cls, + **factory_kwargs, + ) + + # Implementation of Feedforward model + self.linear1 = linear1_feedforward_cls(d_model, dim_feedforward, **factory_kwargs) + self.dropout = nn.Dropout(dropout) + self.linear2 = linear2_feedforward_cls(dim_feedforward, d_model, **factory_kwargs) + + self.norm_first = norm_first + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + # Legacy string support for activation function. + if isinstance(activation, str): + activation = _get_activation_fn(activation) + elif isinstance(activation, partial): + activation = activation(d_model) + elif activation == BalancedDoubleSwish: + activation = BalancedDoubleSwish(d_model) + + # # We can't test self.activation in forward() in TorchScript, + # # so stash some information about it instead. + # if activation is F.relu or isinstance(activation, torch.nn.ReLU): + # self.activation_relu_or_gelu = 1 + # elif activation is F.gelu or isinstance(activation, torch.nn.GELU): + # self.activation_relu_or_gelu = 2 + # else: + # self.activation_relu_or_gelu = 0 + self.activation = activation + + norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs) + if layer_norm_cls == IdentityNorm: + norm2 = BalancedBasicNorm(d_model, eps=layer_norm_eps, **factory_kwargs) + else: + norm2 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs) + + if adaptive_layer_norm: + self.norm1 = AdaptiveLayerNorm(d_model, norm1) + self.norm2 = AdaptiveLayerNorm(d_model, norm2) + else: + self.norm1 = norm1 + self.norm2 = norm2 + + def __setstate__(self, state): + super(TransformerEncoderLayer, self).__setstate__(state) + if not hasattr(self, "activation"): + self.activation = F.relu + + def forward( + self, + src: Tensor, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + cache=None, + ) -> Tensor: + r"""Pass the input through the encoder layer. 
+ + Args: + src: the sequence to the encoder layer (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + see the docs in Transformer class. + """ + x, stage_embedding = src, None + is_src_tuple = False + if isinstance(src, tuple): + x, stage_embedding = src + is_src_tuple = True + + if src_key_padding_mask is not None: + _skpm_dtype = src_key_padding_mask.dtype + if _skpm_dtype != torch.bool and not torch.is_floating_point(src_key_padding_mask): + raise AssertionError("only bool and floating types of key_padding_mask are supported") + + if self.norm_first: + x = x + self._sa_block( + self.norm1(x, stage_embedding), + src_mask, + src_key_padding_mask, + cache=cache, + ) + x = x + self._ff_block(self.norm2(x, stage_embedding)) + else: + x = self.norm1( + x + self._sa_block(x, src_mask, src_key_padding_mask, cache=cache), + stage_embedding, + ) + x = self.norm2(x + self._ff_block(x), stage_embedding) + + if is_src_tuple: + return (x, stage_embedding) + return x + + # self-attention block + def _sa_block( + self, + x: Tensor, + attn_mask: Optional[Tensor], + key_padding_mask: Optional[Tensor], + cache=None, + ) -> Tensor: + # print(x.shape,attn_mask.shape,key_padding_mask) + # torch.Size([1, 188, 512]) torch.Size([188, 188]) None + # import os + # os._exit(23333) + x = self.self_attn( + x, + x, + x, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + need_weights=False, + cache=cache, + )[0] + return self.dropout1(x) + + # feed forward block + def _ff_block(self, x: Tensor) -> Tensor: + x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + return self.dropout2(x) + + +class AdaptiveLayerNorm(nn.Module): + r"""Adaptive Layer Normalization""" + + def __init__(self, d_model, norm) -> None: + super(AdaptiveLayerNorm, self).__init__() + self.project_layer = nn.Linear(d_model, 2 * d_model) + self.norm = norm + self.d_model = d_model + self.eps = self.norm.eps + + def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor: + if isinstance(input, tuple): + input, embedding = input + weight, bias = torch.split( + self.project_layer(embedding), + split_size_or_sections=self.d_model, + dim=-1, + ) + return (weight * self.norm(input) + bias, embedding) + + weight, bias = torch.split( + self.project_layer(embedding), + split_size_or_sections=self.d_model, + dim=-1, + ) + return weight * self.norm(input) + bias + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) diff --git a/AR/modules/transformer_onnx.py b/AR/modules/transformer_onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..fa1702548551ccd5166c50ca238a58b136144454 --- /dev/null +++ b/AR/modules/transformer_onnx.py @@ -0,0 +1,281 @@ +# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/transformer.py +import copy +import numbers +from functools import partial +from typing import Any +from typing import Callable +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +import torch +from AR.modules.activation_onnx import MultiheadAttention +from AR.modules.scaling import BalancedDoubleSwish +from torch import nn +from torch import Tensor +from torch.nn import functional as F + +_shape_t = Union[int, List[int], torch.Size] + + +class LayerNorm(nn.Module): + __constants__ = ["normalized_shape", "eps", "elementwise_affine"] + normalized_shape: Tuple[int, ...] 
+ eps: float + elementwise_affine: bool + + def __init__( + self, + normalized_shape: _shape_t, + eps: float = 1e-5, + elementwise_affine: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super(LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + # mypy error: incompatible types in assignment + normalized_shape = (normalized_shape,) # type: ignore[assignment] + self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] + self.eps = eps + self.elementwise_affine = elementwise_affine + if self.elementwise_affine: + self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) + self.bias = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) + else: + self.register_parameter("weight", None) + self.register_parameter("bias", None) + + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.elementwise_affine: + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + + def forward(self, input: Tensor, embedding: Any = None) -> Tensor: + if isinstance(input, tuple): + input, embedding = input + return ( + F.layer_norm( + input, + self.normalized_shape, + self.weight, + self.bias, + self.eps, + ), + embedding, + ) + + assert embedding is None + return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps) + + def extra_repr(self) -> str: + return "{normalized_shape}, eps={eps}, elementwise_affine={elementwise_affine}".format(**self.__dict__) + + +class IdentityNorm(nn.Module): + def __init__( + self, + d_model: int, + eps: float = 1e-5, + device=None, + dtype=None, + ) -> None: + super(IdentityNorm, self).__init__() + + def forward(self, input: Tensor, embedding: Any = None) -> Tensor: + if isinstance(input, tuple): + return input + + assert embedding is None + return input + + +class TransformerEncoder(nn.Module): + r"""TransformerEncoder is a stack of N encoder layers. Users can build the + BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters. + + Args: + encoder_layer: an instance of the TransformerEncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + norm: the layer normalization component (optional). + enable_nested_tensor: if True, input will automatically convert to nested tensor + (and convert back on output). This will improve the overall performance of + TransformerEncoder when padding rate is high. Default: ``True`` (enabled). 
+ + Examples:: + >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) + >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> out = transformer_encoder(src) + """ + + __constants__ = ["norm"] + + def __init__(self, encoder_layer, num_layers, norm=None): + super(TransformerEncoder, self).__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward( + self, + src: Tensor, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + return_layer_states: bool = False, + cache=None, + ) -> Tensor: + output = src + for mod in self.layers: + output = mod( + output, + src_mask=mask, + src_key_padding_mask=src_key_padding_mask, + cache=cache, + ) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class TransformerEncoderLayer(nn.Module): + __constants__ = ["batch_first", "norm_first"] + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + batch_first: bool = False, + norm_first: bool = False, + device=None, + dtype=None, + linear1_self_attention_cls: nn.Module = nn.Linear, + linear2_self_attention_cls: nn.Module = nn.Linear, + linear1_feedforward_cls: nn.Module = nn.Linear, + linear2_feedforward_cls: nn.Module = nn.Linear, + layer_norm_cls: nn.Module = LayerNorm, + layer_norm_eps: float = 1e-5, + adaptive_layer_norm=False, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super(TransformerEncoderLayer, self).__init__() + self.self_attn = MultiheadAttention( + d_model, # 512 16 + nhead, + dropout=dropout, + batch_first=batch_first, + linear1_cls=linear1_self_attention_cls, + linear2_cls=linear2_self_attention_cls, + **factory_kwargs, + ) + self.linear1 = linear1_feedforward_cls(d_model, dim_feedforward, **factory_kwargs) + self.dropout = nn.Dropout(dropout) + self.linear2 = linear2_feedforward_cls(dim_feedforward, d_model, **factory_kwargs) + self.norm_first = norm_first + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + if isinstance(activation, str): + activation = _get_activation_fn(activation) + elif isinstance(activation, partial): + activation = activation(d_model) + elif activation == BalancedDoubleSwish: + activation = BalancedDoubleSwish(d_model) + self.activation = activation + + norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs) + if layer_norm_cls == IdentityNorm: + norm2 = BalancedBasicNorm(d_model, eps=layer_norm_eps, **factory_kwargs) + else: + norm2 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs) + + if adaptive_layer_norm: + self.norm1 = AdaptiveLayerNorm(d_model, norm1) + self.norm2 = AdaptiveLayerNorm(d_model, norm2) + else: + self.norm1 = norm1 + self.norm2 = norm2 + + def __setstate__(self, state): + super(TransformerEncoderLayer, self).__setstate__(state) + if not hasattr(self, "activation"): + self.activation = F.relu + + def forward( + self, + src: Tensor, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + cache=None, + ) -> Tensor: + x = src + stage_embedding = None + x = self.norm1( + x + self._sa_block(x, src_mask, src_key_padding_mask, cache=cache), + stage_embedding, + ) + x = self.norm2(x + self._ff_block(x), stage_embedding) + + return x + + def _sa_block( + self, + x: Tensor, + attn_mask: Optional[Tensor], + key_padding_mask: 
Optional[Tensor], + cache=None, + ) -> Tensor: + x = self.self_attn( + x, + x, + x, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + need_weights=False, + cache=cache, + ) + return self.dropout1(x) + + def _ff_block(self, x: Tensor) -> Tensor: + x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + return self.dropout2(x) + + +class AdaptiveLayerNorm(nn.Module): + r"""Adaptive Layer Normalization""" + + def __init__(self, d_model, norm) -> None: + super(AdaptiveLayerNorm, self).__init__() + self.project_layer = nn.Linear(d_model, 2 * d_model) + self.norm = norm + self.d_model = d_model + self.eps = self.norm.eps + + def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor: + if isinstance(input, tuple): + input, embedding = input + weight, bias = torch.split( + self.project_layer(embedding), + split_size_or_sections=self.d_model, + dim=-1, + ) + return (weight * self.norm(input) + bias, embedding) + + weight, bias = torch.split( + self.project_layer(embedding), + split_size_or_sections=self.d_model, + dim=-1, + ) + return weight * self.norm(input) + bias + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) diff --git a/AR/text_processing/__init__.py b/AR/text_processing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/AR/text_processing/phonemizer.py b/AR/text_processing/phonemizer.py new file mode 100644 index 0000000000000000000000000000000000000000..1003040e282c51e4e240a122bce4f3b87a09b38f --- /dev/null +++ b/AR/text_processing/phonemizer.py @@ -0,0 +1,72 @@ +# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/text_processing/phonemizer.py +# reference: https://github.com/lifeiteng/vall-e +import itertools +import re +from typing import Dict +from typing import List + +import regex +from gruut import sentences +from gruut.const import Sentence +from gruut.const import Word +from AR.text_processing.symbols import SYMBOL_TO_ID + + +class GruutPhonemizer: + def __init__(self, language: str): + self._phonemizer = sentences + self.lang = language + self.symbol_to_id = SYMBOL_TO_ID + self._special_cases_dict: Dict[str] = { + r"\.\.\.": "... ", + ";": "; ", + ":": ": ", + ",": ", ", + r"\.": ". ", + "!": "! ", + r"\?": "? 
", + "—": "—", + "…": "… ", + "«": "«", + "»": "»", + } + self._punctuation_regexp: str = rf"([{''.join(self._special_cases_dict.keys())}])" + + def _normalize_punctuation(self, text: str) -> str: + text = regex.sub(rf"\pZ+{self._punctuation_regexp}", r"\1", text) + text = regex.sub(rf"{self._punctuation_regexp}(\pL)", r"\1 \2", text) + text = regex.sub(r"\pZ+", r" ", text) + return text.strip() + + def _convert_punctuation(self, word: Word) -> str: + if not word.phonemes: + return "" + if word.phonemes[0] in ["‖", "|"]: + return word.text.strip() + + phonemes = "".join(word.phonemes) + # remove modifier characters ˈˌː with regex + phonemes = re.sub(r"[ˈˌː͡]", "", phonemes) + return phonemes.strip() + + def phonemize(self, text: str, espeak: bool = False) -> str: + text_to_phonemize: str = self._normalize_punctuation(text) + sents: List[Sentence] = [sent for sent in self._phonemizer(text_to_phonemize, lang="en-us", espeak=espeak)] + words: List[str] = [self._convert_punctuation(word) for word in itertools.chain(*sents)] + return " ".join(words) + + def transform(self, phonemes): + # convert phonemes to ids + # dictionary is in symbols.py + return [self.symbol_to_id[p] for p in phonemes if p in self.symbol_to_id.keys()] + + +if __name__ == "__main__": + phonemizer = GruutPhonemizer("en-us") + # text -> IPA + phonemes = phonemizer.phonemize("Hello, wor-ld ?") + print("phonemes:", phonemes) + print("len(phonemes):", len(phonemes)) + phoneme_ids = phonemizer.transform(phonemes) + print("phoneme_ids:", phoneme_ids) + print("len(phoneme_ids):", len(phoneme_ids)) diff --git a/AR/text_processing/symbols.py b/AR/text_processing/symbols.py new file mode 100644 index 0000000000000000000000000000000000000000..f7ef57faf5b83cb2417b4f9244244dc9939153aa --- /dev/null +++ b/AR/text_processing/symbols.py @@ -0,0 +1,12 @@ +# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/text_processing/symbols.py +# reference: https://github.com/lifeiteng/vall-e +PAD = "_" +PUNCTUATION = ';:,.!?¡¿—…"«»“” ' +LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" +IPA_LETTERS = ( + "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ" +) +SYMBOLS = [PAD] + list(PUNCTUATION) + list(LETTERS) + list(IPA_LETTERS) +SPACE_ID = SYMBOLS.index(" ") +SYMBOL_TO_ID = {s: i for i, s in enumerate(SYMBOLS)} +ID_TO_SYMBOL = {i: s for i, s in enumerate(SYMBOLS)} diff --git a/AR/utils/__init__.py b/AR/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4a9cb4d58504c17a96634ddd98f22cadad9365de --- /dev/null +++ b/AR/utils/__init__.py @@ -0,0 +1,36 @@ +import re + + +def str2bool(str): + return True if str.lower() == "true" else False + + +def get_newest_ckpt(string_list): + # 定义一个正则表达式模式,用于匹配字符串中的数字 + pattern = r"epoch=(\d+)-step=(\d+)\.ckpt" + + # 使用正则表达式提取每个字符串中的数字信息,并创建一个包含元组的列表 + extracted_info = [] + for string in string_list: + match = re.match(pattern, string) + if match: + epoch = int(match.group(1)) + step = int(match.group(2)) + extracted_info.append((epoch, step, string)) + # 按照 epoch 后面的数字和 step 后面的数字进行排序 + sorted_info = sorted(extracted_info, key=lambda x: (x[0], x[1]), reverse=True) + # 获取最新的 ckpt 文件名 + newest_ckpt = sorted_info[0][2] + return newest_ckpt + + +# 文本存在且不为空时 return True +def check_txt_file(file_path): + try: + with open(file_path, "r") as file: + text = file.readline().strip() + assert text.strip() != "" + return text + except Exception: + return False + return False diff --git 
a/AR/utils/initialize.py b/AR/utils/initialize.py new file mode 100644 index 0000000000000000000000000000000000000000..ee7c713823f57572ab8f7045ceba21e8e2619e4c --- /dev/null +++ b/AR/utils/initialize.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python3 +"""Initialize modules for espnet2 neural networks.""" + +import torch +from typeguard import check_argument_types + + +def initialize(model: torch.nn.Module, init: str): + """Initialize weights of a neural network module. + + Parameters are initialized using the given method or distribution. + + Custom initialization routines can be implemented into submodules + as function `espnet_initialization_fn` within the custom module. + + Args: + model: Target. + init: Method of initialization. + """ + assert check_argument_types() + print("init with", init) + + # weight init + for p in model.parameters(): + if p.dim() > 1: + if init == "xavier_uniform": + torch.nn.init.xavier_uniform_(p.data) + elif init == "xavier_normal": + torch.nn.init.xavier_normal_(p.data) + elif init == "kaiming_uniform": + torch.nn.init.kaiming_uniform_(p.data, nonlinearity="relu") + elif init == "kaiming_normal": + torch.nn.init.kaiming_normal_(p.data, nonlinearity="relu") + else: + raise ValueError("Unknown initialization: " + init) + # bias init + for name, p in model.named_parameters(): + if ".bias" in name and p.dim() == 1: + p.data.zero_() diff --git a/AR/utils/io.py b/AR/utils/io.py new file mode 100644 index 0000000000000000000000000000000000000000..a6475cb6b114787acfde5d73e1552cf58e04997b --- /dev/null +++ b/AR/utils/io.py @@ -0,0 +1,30 @@ +import sys + +import torch +import yaml + + +def load_yaml_config(path): + with open(path) as f: + config = yaml.full_load(f) + return config + + +def save_config_to_yaml(config, path): + assert path.endswith(".yaml") + with open(path, "w") as f: + f.write(yaml.dump(config)) + f.close() + + +def write_args(args, path): + args_dict = dict((name, getattr(args, name)) for name in dir(args) if not name.startswith("_")) + with open(path, "a") as args_file: + args_file.write("==> torch version: {}\n".format(torch.__version__)) + args_file.write("==> cudnn version: {}\n".format(torch.backends.cudnn.version())) + args_file.write("==> Cmd:\n") + args_file.write(str(sys.argv)) + args_file.write("\n==> args:\n") + for k, v in sorted(args_dict.items()): + args_file.write(" %s: %s\n" % (str(k), str(v))) + args_file.close() diff --git a/BigVGAN/LICENSE b/BigVGAN/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..4c78361c86d4f685117d60d6623e2197fcfed706 --- /dev/null +++ b/BigVGAN/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 NVIDIA CORPORATION. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/BigVGAN/README.md b/BigVGAN/README.md new file mode 100644 index 0000000000000000000000000000000000000000..2fa70ceea647053933b913b329041ee8c41526db --- /dev/null +++ b/BigVGAN/README.md @@ -0,0 +1,266 @@ +## BigVGAN: A Universal Neural Vocoder with Large-Scale Training + +#### Sang-gil Lee, Wei Ping, Boris Ginsburg, Bryan Catanzaro, Sungroh Yoon + +[[Paper]](https://arxiv.org/abs/2206.04658) - [[Code]](https://github.com/NVIDIA/BigVGAN) - [[Showcase]](https://bigvgan-demo.github.io/) - [[Project Page]](https://research.nvidia.com/labs/adlr/projects/bigvgan/) - [[Weights]](https://huggingface.co/collections/nvidia/bigvgan-66959df3d97fd7d98d97dc9a) - [[Demo]](https://huggingface.co/spaces/nvidia/BigVGAN) + +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/bigvgan-a-universal-neural-vocoder-with-large/speech-synthesis-on-libritts)](https://paperswithcode.com/sota/speech-synthesis-on-libritts?p=bigvgan-a-universal-neural-vocoder-with-large) + +
+
+## News
+- **Sep 2024 (v2.4):**
+  - We have updated the pretrained checkpoints trained for 5M steps. This is the final release of the BigVGAN-v2 checkpoints.
+
+- **Jul 2024 (v2.3):**
+  - General refactor and code improvements for improved readability.
+  - Fully fused CUDA kernel of anti-aliased activation (upsampling + activation + downsampling) with an inference speed benchmark.
+
+- **Jul 2024 (v2.2):** The repository now includes an interactive local demo using gradio.
+
+- **Jul 2024 (v2.1):** BigVGAN is now integrated with 🤗 Hugging Face Hub with easy access to inference using pretrained checkpoints. We also provide an interactive demo on Hugging Face Spaces.
+
+- **Jul 2024 (v2):** We release BigVGAN-v2 along with pretrained checkpoints. Below are the highlights:
+  - Custom CUDA kernel for inference: we provide a fused upsampling + activation kernel written in CUDA for accelerated inference speed. Our tests show 1.5 - 3x faster speed on a single A100 GPU.
+  - Improved discriminator and loss: BigVGAN-v2 is trained using a [multi-scale sub-band CQT discriminator](https://arxiv.org/abs/2311.14957) and a [multi-scale mel spectrogram loss](https://arxiv.org/abs/2306.06546).
+  - Larger training data: BigVGAN-v2 is trained using datasets containing diverse audio types, including speech in multiple languages, environmental sounds, and instruments.
+  - We provide pretrained checkpoints of BigVGAN-v2 using diverse audio configurations, supporting up to 44 kHz sampling rate and 512x upsampling ratio.
+
+## Installation
+
+The codebase has been tested on Python `3.10` and PyTorch `2.3.1` conda packages with either `pytorch-cuda=12.1` or `pytorch-cuda=11.8`. Below is an example command to create the conda environment:
+
+```shell
+conda create -n bigvgan python=3.10 pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
+conda activate bigvgan
+```
+
+Clone the repository and install dependencies:
+
+```shell
+git clone https://github.com/NVIDIA/BigVGAN
+cd BigVGAN
+pip install -r requirements.txt
+```
+
+## Inference Quickstart using 🤗 Hugging Face Hub
+
+The example below shows how to use BigVGAN: load the pretrained BigVGAN generator from Hugging Face Hub, compute a mel spectrogram from an input waveform, and synthesize a waveform from that mel spectrogram.
+
+```python
+device = 'cuda'
+
+import torch
+import bigvgan
+import librosa
+from meldataset import get_mel_spectrogram
+
+# instantiate the model. You can optionally set use_cuda_kernel=True for faster inference.
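+# note: use_cuda_kernel=True is inference-only and requires nvcc and ninja to build the fused kernel
+# (see "Using Custom CUDA Kernel for Synthesis" below); keeping it False here is the safe default.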
+model = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_24khz_100band_256x', use_cuda_kernel=False)
+
+# remove weight norm in the model and set to eval mode
+model.remove_weight_norm()
+model = model.eval().to(device)
+
+# load wav file and compute mel spectrogram
+wav_path = '/path/to/your/audio.wav'
+wav, sr = librosa.load(wav_path, sr=model.h.sampling_rate, mono=True) # wav is np.ndarray with shape [T_time] and values in [-1, 1]
+wav = torch.FloatTensor(wav).unsqueeze(0) # wav is FloatTensor with shape [B(1), T_time]
+
+# compute mel spectrogram from the ground truth audio
+mel = get_mel_spectrogram(wav, model.h).to(device) # mel is FloatTensor with shape [B(1), C_mel, T_frame]
+
+# generate waveform from mel
+with torch.inference_mode():
+    wav_gen = model(mel) # wav_gen is FloatTensor with shape [B(1), 1, T_time] and values in [-1, 1]
+wav_gen_float = wav_gen.squeeze(0).cpu() # wav_gen_float is FloatTensor with shape [1, T_time]
+
+# you can convert the generated waveform to 16 bit linear PCM
+wav_gen_int16 = (wav_gen_float * 32767.0).numpy().astype('int16') # wav_gen_int16 is np.ndarray with shape [1, T_time] and int16 dtype
+```
+
+## Local gradio demo
+
+You can run a local gradio demo with the commands below:
+
+```shell
+pip install -r demo/requirements.txt
+python demo/app.py
+```
+
+## Training
+
+Create symbolic links to the root of the dataset. The codebase uses filelists with paths relative to the dataset root. Below are example commands for the LibriTTS dataset:
+
+```shell
+cd filelists/LibriTTS && \
+ln -s /path/to/your/LibriTTS/train-clean-100 train-clean-100 && \
+ln -s /path/to/your/LibriTTS/train-clean-360 train-clean-360 && \
+ln -s /path/to/your/LibriTTS/train-other-500 train-other-500 && \
+ln -s /path/to/your/LibriTTS/dev-clean dev-clean && \
+ln -s /path/to/your/LibriTTS/dev-other dev-other && \
+ln -s /path/to/your/LibriTTS/test-clean test-clean && \
+ln -s /path/to/your/LibriTTS/test-other test-other && \
+cd ../..
+```
+
+Train the BigVGAN model. Below is an example command for training BigVGAN-v2 on the LibriTTS dataset at 24 kHz with a full 100-band mel spectrogram as input:
+
+```shell
+python train.py \
+--config configs/bigvgan_v2_24khz_100band_256x.json \
+--input_wavs_dir filelists/LibriTTS \
+--input_training_file filelists/LibriTTS/train-full.txt \
+--input_validation_file filelists/LibriTTS/val-full.txt \
+--list_input_unseen_wavs_dir filelists/LibriTTS filelists/LibriTTS \
+--list_input_unseen_validation_file filelists/LibriTTS/dev-clean.txt filelists/LibriTTS/dev-other.txt \
+--checkpoint_path exp/bigvgan_v2_24khz_100band_256x
+```
+
+## Synthesis
+
+Synthesize from a trained BigVGAN model. Below is an example command for generating audio from the model.
+It computes mel spectrograms using wav files from `--input_wavs_dir` and saves the generated audio to `--output_dir`.
+
+```shell
+python inference.py \
+--checkpoint_file /path/to/your/bigvgan_v2_24khz_100band_256x/bigvgan_generator.pt \
+--input_wavs_dir /path/to/your/input_wav \
+--output_dir /path/to/your/output_wav
+```
+
+`inference_e2e.py` supports synthesis directly from mel spectrograms saved in `.npy` format, with shapes `[1, channel, frame]` or `[channel, frame]`.
+It loads mel spectrograms from `--input_mels_dir` and saves the generated audio to `--output_dir`.
+
+Make sure that the STFT hyperparameters used to compute the mel spectrograms match those of the model, which are defined in `config.json` of the corresponding model.
+
+```shell
+python inference_e2e.py \
+--checkpoint_file /path/to/your/bigvgan_v2_24khz_100band_256x/bigvgan_generator.pt \
+--input_mels_dir /path/to/your/input_mel \
+--output_dir /path/to/your/output_wav
+```
+
+## Using Custom CUDA Kernel for Synthesis
+
+You can enable the fast CUDA inference kernel by passing the `use_cuda_kernel` parameter when instantiating BigVGAN:
+
+```python
+generator = BigVGAN(h, use_cuda_kernel=True)
+```
+
+You can also pass `--use_cuda_kernel` to `inference.py` and `inference_e2e.py` to enable this feature.
+
+When applied for the first time, it builds the kernel using `nvcc` and `ninja`. If the build succeeds, the kernel is saved to `alias_free_activation/cuda/build` and the model automatically loads the kernel. The codebase has been tested using CUDA `12.1`.
+
+Please make sure that both are installed on your system and that the `nvcc` version matches the one your PyTorch build was compiled with.
+
+We recommend running `test_cuda_vs_torch_model.py` first to build and check the correctness of the CUDA kernel. See the example command and its output below, which should end with `[Success] test CUDA fused vs. plain torch BigVGAN inference`:
+
+```shell
+python tests/test_cuda_vs_torch_model.py \
+--checkpoint_file /path/to/your/bigvgan_generator.pt
+```
+
+```shell
+loading plain Pytorch BigVGAN
+...
+loading CUDA kernel BigVGAN with auto-build
+Detected CUDA files, patching ldflags
+Emitting ninja build file /path/to/your/BigVGAN/alias_free_activation/cuda/build/build.ninja..
+Building extension module anti_alias_activation_cuda...
+...
+Loading extension module anti_alias_activation_cuda...
+...
+Loading '/path/to/your/bigvgan_generator.pt'
+...
+[Success] test CUDA fused vs. plain torch BigVGAN inference
+ > mean_difference=0.0007238413265440613
+...
+```
+
+If you see `[Fail] test CUDA fused vs. plain torch BigVGAN inference`, it means that the CUDA kernel inference is incorrect. Please check whether the `nvcc` installed on your system is compatible with your PyTorch version.
+
+## Pretrained Models
+
+We provide the [pretrained models on Hugging Face Collections](https://huggingface.co/collections/nvidia/bigvgan-66959df3d97fd7d98d97dc9a).
+You can download the checkpoints of the generator weight (named `bigvgan_generator.pt`) and its discriminator/optimizer states (named `bigvgan_discriminator_optimizer.pt`) within the listed model repositories.
+ +| Model Name | Sampling Rate | Mel band | fmax | Upsampling Ratio | Params | Dataset | Steps | Fine-Tuned | +|:--------------------------------------------------------------------------------------------------------:|:-------------:|:--------:|:-----:|:----------------:|:------:|:--------------------------:|:-----:|:----------:| +| [bigvgan_v2_44khz_128band_512x](https://huggingface.co/nvidia/bigvgan_v2_44khz_128band_512x) | 44 kHz | 128 | 22050 | 512 | 122M | Large-scale Compilation | 5M | No | +| [bigvgan_v2_44khz_128band_256x](https://huggingface.co/nvidia/bigvgan_v2_44khz_128band_256x) | 44 kHz | 128 | 22050 | 256 | 112M | Large-scale Compilation | 5M | No | +| [bigvgan_v2_24khz_100band_256x](https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x) | 24 kHz | 100 | 12000 | 256 | 112M | Large-scale Compilation | 5M | No | +| [bigvgan_v2_22khz_80band_256x](https://huggingface.co/nvidia/bigvgan_v2_22khz_80band_256x) | 22 kHz | 80 | 11025 | 256 | 112M | Large-scale Compilation | 5M | No | +| [bigvgan_v2_22khz_80band_fmax8k_256x](https://huggingface.co/nvidia/bigvgan_v2_22khz_80band_fmax8k_256x) | 22 kHz | 80 | 8000 | 256 | 112M | Large-scale Compilation | 5M | No | +| [bigvgan_24khz_100band](https://huggingface.co/nvidia/bigvgan_24khz_100band) | 24 kHz | 100 | 12000 | 256 | 112M | LibriTTS | 5M | No | +| [bigvgan_base_24khz_100band](https://huggingface.co/nvidia/bigvgan_base_24khz_100band) | 24 kHz | 100 | 12000 | 256 | 14M | LibriTTS | 5M | No | +| [bigvgan_22khz_80band](https://huggingface.co/nvidia/bigvgan_22khz_80band) | 22 kHz | 80 | 8000 | 256 | 112M | LibriTTS + VCTK + LJSpeech | 5M | No | +| [bigvgan_base_22khz_80band](https://huggingface.co/nvidia/bigvgan_base_22khz_80band) | 22 kHz | 80 | 8000 | 256 | 14M | LibriTTS + VCTK + LJSpeech | 5M | No | + +The paper results are based on the original 24kHz BigVGAN models (`bigvgan_24khz_100band` and `bigvgan_base_24khz_100band`) trained on LibriTTS dataset. +We also provide 22kHz BigVGAN models with band-limited setup (i.e., fmax=8000) for TTS applications. +Note that the checkpoints use `snakebeta` activation with log scale parameterization, which have the best overall quality. + +You can fine-tune the models by: + +1. downloading the checkpoints (both the generator weight and its discriminator/optimizer states) +2. resuming training using your audio dataset by specifying `--checkpoint_path` that includes the checkpoints when launching `train.py` + +## Training Details of BigVGAN-v2 + +Comapred to the original BigVGAN, the pretrained checkpoints of BigVGAN-v2 used `batch_size=32` with a longer `segment_size=65536` and are trained using 8 A100 GPUs. + +Note that the BigVGAN-v2 `json` config files in `./configs` use `batch_size=4` as default to fit in a single A100 GPU for training. You can fine-tune the models adjusting `batch_size` depending on your GPUs. + +When training BigVGAN-v2 from scratch with small batch size, it can potentially encounter the early divergence problem mentioned in the paper. In such case, we recommend lowering the `clip_grad_norm` value (e.g. `100`) for the early training iterations (e.g. 20000 steps) and increase the value to the default `500`. + +## Evaluation Results of BigVGAN-v2 + +Below are the objective results of the 24kHz model (`bigvgan_v2_24khz_100band_256x`) obtained from the LibriTTS `dev` sets. BigVGAN-v2 shows noticeable improvements of the metrics. The model also exhibits reduced perceptual artifacts, especially for non-speech audio. 
+ +| Model | Dataset | Steps | PESQ(↑) | M-STFT(↓) | MCD(↓) | Periodicity(↓) | V/UV F1(↑) | +|:----------:|:-----------------------:|:-----:|:---------:|:----------:|:----------:|:--------------:|:----------:| +| BigVGAN | LibriTTS | 1M | 4.027 | 0.7997 | 0.3745 | 0.1018 | 0.9598 | +| BigVGAN | LibriTTS | 5M | 4.256 | 0.7409 | 0.2988 | 0.0809 | 0.9698 | +| BigVGAN-v2 | Large-scale Compilation | 3M | 4.359 | 0.7134 | 0.3060 | 0.0621 | 0.9777 | +| BigVGAN-v2 | Large-scale Compilation | 5M | **4.362** | **0.7026** | **0.2903** | **0.0593** | **0.9793** | + +## Speed Benchmark + +Below are the speed and VRAM usage benchmark results of BigVGAN from `tests/test_cuda_vs_torch_model.py`, using `bigvgan_v2_24khz_100band_256x` as a reference model. + +| GPU | num_mel_frame | use_cuda_kernel | Speed (kHz) | Real-time Factor | VRAM (GB) | +|:--------------------------:|:-------------:|:---------------:|:-----------:|:----------------:|:---------:| +| NVIDIA A100 | 256 | False | 1672.1 | 69.7x | 1.3 | +| | | True | 3916.5 | 163.2x | 1.3 | +| | 2048 | False | 1899.6 | 79.2x | 1.7 | +| | | True | 5330.1 | 222.1x | 1.7 | +| | 16384 | False | 1973.8 | 82.2x | 5.0 | +| | | True | 5761.7 | 240.1x | 4.4 | +| NVIDIA GeForce RTX 3080 | 256 | False | 841.1 | 35.0x | 1.3 | +| | | True | 1598.1 | 66.6x | 1.3 | +| | 2048 | False | 929.9 | 38.7x | 1.7 | +| | | True | 1971.3 | 82.1x | 1.6 | +| | 16384 | False | 943.4 | 39.3x | 5.0 | +| | | True | 2026.5 | 84.4x | 3.9 | +| NVIDIA GeForce RTX 2080 Ti | 256 | False | 515.6 | 21.5x | 1.3 | +| | | True | 811.3 | 33.8x | 1.3 | +| | 2048 | False | 576.5 | 24.0x | 1.7 | +| | | True | 1023.0 | 42.6x | 1.5 | +| | 16384 | False | 589.4 | 24.6x | 5.0 | +| | | True | 1068.1 | 44.5x | 3.2 | + +## Acknowledgements + +We thank Vijay Anand Korthikanti and Kevin J. Shih for their generous support in implementing the CUDA kernel for inference. + +## References + +- [HiFi-GAN](https://github.com/jik876/hifi-gan) (for generator and multi-period discriminator) +- [Snake](https://github.com/EdwardDixon/snake) (for periodic activation) +- [Alias-free-torch](https://github.com/junjun3518/alias-free-torch) (for anti-aliasing) +- [Julius](https://github.com/adefossez/julius) (for low-pass filter) +- [UnivNet](https://github.com/mindslab-ai/univnet) (for multi-resolution discriminator) +- [descript-audio-codec](https://github.com/descriptinc/descript-audio-codec) and [vocos](https://github.com/gemelo-ai/vocos) (for multi-band multi-scale STFT discriminator and multi-scale mel spectrogram loss) +- [Amphion](https://github.com/open-mmlab/Amphion) (for multi-scale sub-band CQT discriminator) diff --git a/BigVGAN/activations.py b/BigVGAN/activations.py new file mode 100644 index 0000000000000000000000000000000000000000..abe3ad9e25c6ab3d4545c6a8c60e1f85a5a8e98e --- /dev/null +++ b/BigVGAN/activations.py @@ -0,0 +1,122 @@ +# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license. +# LICENSE is in incl_licenses directory. 
+ +import torch +from torch import nn, sin, pow +from torch.nn import Parameter + + +class Snake(nn.Module): + """ + Implementation of a sine-based periodic activation function + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter + References: + - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snake(256) + >>> x = torch.randn(256) + >>> x = a1(x) + """ + + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): + """ + Initialization. + INPUT: + - in_features: shape of the input + - alpha: trainable parameter + alpha is initialized to 1 by default, higher values = higher-frequency. + alpha will be trained along with the rest of your model. + """ + super(Snake, self).__init__() + self.in_features = in_features + + # Initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # Log scale alphas initialized to zeros + self.alpha = Parameter(torch.zeros(in_features) * alpha) + else: # Linear scale alphas initialized to ones + self.alpha = Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + """ + Forward pass of the function. + Applies the function to the input elementwise. + Snake ∶= x + 1/a * sin^2 (xa) + """ + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # Line up with x to [B, C, T] + if self.alpha_logscale: + alpha = torch.exp(alpha) + x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + + return x + + +class SnakeBeta(nn.Module): + """ + A modified Snake function which uses separate parameters for the magnitude of the periodic components + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + References: + - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snakebeta(256) + >>> x = torch.randn(256) + >>> x = a1(x) + """ + + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): + """ + Initialization. + INPUT: + - in_features: shape of the input + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + alpha is initialized to 1 by default, higher values = higher-frequency. + beta is initialized to 1 by default, higher values = higher-magnitude. + alpha will be trained along with the rest of your model. + """ + super(SnakeBeta, self).__init__() + self.in_features = in_features + + # Initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # Log scale alphas initialized to zeros + self.alpha = Parameter(torch.zeros(in_features) * alpha) + self.beta = Parameter(torch.zeros(in_features) * alpha) + else: # Linear scale alphas initialized to ones + self.alpha = Parameter(torch.ones(in_features) * alpha) + self.beta = Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + """ + Forward pass of the function. + Applies the function to the input elementwise. 
+ SnakeBeta ∶= x + 1/b * sin^2 (xa) + """ + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # Line up with x to [B, C, T] + beta = self.beta.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + alpha = torch.exp(alpha) + beta = torch.exp(beta) + x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + + return x diff --git a/BigVGAN/alias_free_activation/cuda/__init__.py b/BigVGAN/alias_free_activation/cuda/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/BigVGAN/alias_free_activation/cuda/activation1d.py b/BigVGAN/alias_free_activation/cuda/activation1d.py new file mode 100644 index 0000000000000000000000000000000000000000..ea333cfa0d5f84de363b7b27739df3bbc457d763 --- /dev/null +++ b/BigVGAN/alias_free_activation/cuda/activation1d.py @@ -0,0 +1,69 @@ +# Copyright (c) 2024 NVIDIA CORPORATION. +# Licensed under the MIT license. + +import torch +import torch.nn as nn +from alias_free_activation.torch.resample import UpSample1d, DownSample1d + +# load fused CUDA kernel: this enables importing anti_alias_activation_cuda +from alias_free_activation.cuda import load + +anti_alias_activation_cuda = load.load() + + +class FusedAntiAliasActivation(torch.autograd.Function): + """ + Assumes filter size 12, replication padding on upsampling/downsampling, and logscale alpha/beta parameters as inputs. + The hyperparameters are hard-coded in the kernel to maximize speed. + NOTE: The fused kenrel is incorrect for Activation1d with different hyperparameters. + """ + + @staticmethod + def forward(ctx, inputs, up_ftr, down_ftr, alpha, beta): + activation_results = anti_alias_activation_cuda.forward(inputs, up_ftr, down_ftr, alpha, beta) + + return activation_results + + @staticmethod + def backward(ctx, output_grads): + raise NotImplementedError + return output_grads, None, None + + +class Activation1d(nn.Module): + def __init__( + self, + activation, + up_ratio: int = 2, + down_ratio: int = 2, + up_kernel_size: int = 12, + down_kernel_size: int = 12, + fused: bool = True, + ): + super().__init__() + self.up_ratio = up_ratio + self.down_ratio = down_ratio + self.act = activation + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + self.fused = fused # Whether to use fused CUDA kernel or not + + def forward(self, x): + if not self.fused: + x = self.upsample(x) + x = self.act(x) + x = self.downsample(x) + return x + else: + if self.act.__class__.__name__ == "Snake": + beta = self.act.alpha.data # Snake uses same params for alpha and beta + else: + beta = self.act.beta.data # Snakebeta uses different params for alpha and beta + alpha = self.act.alpha.data + if not self.act.alpha_logscale: # Exp baked into cuda kernel, cancel it out with a log + alpha = torch.log(alpha) + beta = torch.log(beta) + + x = FusedAntiAliasActivation.apply(x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta) + return x diff --git a/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp b/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c5651f77143bd678169eb11564a7cf7a7969a59e --- /dev/null +++ b/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp @@ -0,0 +1,23 @@ +/* coding=utf-8 + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + #include + +extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &fwd_cuda, "Anti-Alias Activation forward (CUDA)"); +} \ No newline at end of file diff --git a/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu b/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu new file mode 100644 index 0000000000000000000000000000000000000000..8c442334869fe72d639ec203fa4fac07f96a0ee1 --- /dev/null +++ b/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu @@ -0,0 +1,246 @@ +/* coding=utf-8 + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include "type_shim.h" +#include +#include +#include +#include +#include + +namespace +{ + // Hard-coded hyperparameters + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + constexpr int ELEMENTS_PER_LDG_STG = 1; //(WARP_ITERATIONS < 4) ? 
1 : 4; + constexpr int BUFFER_SIZE = 32; + constexpr int FILTER_SIZE = 12; + constexpr int HALF_FILTER_SIZE = 6; + constexpr int UPSAMPLE_REPLICATION_PAD = 5; // 5 on each side, matching torch impl + constexpr int DOWNSAMPLE_REPLICATION_PAD_LEFT = 5; // matching torch impl + constexpr int DOWNSAMPLE_REPLICATION_PAD_RIGHT = 6; // matching torch impl + + template + __global__ void anti_alias_activation_forward( + output_t *dst, + const input_t *src, + const input_t *up_ftr, + const input_t *down_ftr, + const input_t *alpha, + const input_t *beta, + int batch_size, + int channels, + int seq_len) + { + // Up and downsample filters + input_t up_filter[FILTER_SIZE]; + input_t down_filter[FILTER_SIZE]; + + // Load data from global memory including extra indices reserved for replication paddings + input_t elements[2 * FILTER_SIZE + 2 * BUFFER_SIZE + 2 * UPSAMPLE_REPLICATION_PAD] = {0}; + input_t intermediates[2 * FILTER_SIZE + 2 * BUFFER_SIZE + DOWNSAMPLE_REPLICATION_PAD_LEFT + DOWNSAMPLE_REPLICATION_PAD_RIGHT] = {0}; + + // Output stores downsampled output before writing to dst + output_t output[BUFFER_SIZE]; + + // blockDim/threadIdx = (128, 1, 1) + // gridDim/blockIdx = (seq_blocks, channels, batches) + int block_offset = (blockIdx.x * 128 * BUFFER_SIZE + seq_len * (blockIdx.y + gridDim.y * blockIdx.z)); + int local_offset = threadIdx.x * BUFFER_SIZE; + int seq_offset = blockIdx.x * 128 * BUFFER_SIZE + local_offset; + + // intermediate have double the seq_len + int intermediate_local_offset = threadIdx.x * BUFFER_SIZE * 2; + int intermediate_seq_offset = blockIdx.x * 128 * BUFFER_SIZE * 2 + intermediate_local_offset; + + // Get values needed for replication padding before moving pointer + const input_t *right_most_pntr = src + (seq_len * (blockIdx.y + gridDim.y * blockIdx.z)); + input_t seq_left_most_value = right_most_pntr[0]; + input_t seq_right_most_value = right_most_pntr[seq_len - 1]; + + // Move src and dst pointers + src += block_offset + local_offset; + dst += block_offset + local_offset; + + // Alpha and beta values for snake activatons. Applies exp by default + alpha = alpha + blockIdx.y; + input_t alpha_val = expf(alpha[0]); + beta = beta + blockIdx.y; + input_t beta_val = expf(beta[0]); + + #pragma unroll + for (int it = 0; it < FILTER_SIZE; it += 1) + { + up_filter[it] = up_ftr[it]; + down_filter[it] = down_ftr[it]; + } + + // Apply replication padding for upsampling, matching torch impl + #pragma unroll + for (int it = -HALF_FILTER_SIZE; it < BUFFER_SIZE + HALF_FILTER_SIZE; it += 1) + { + int element_index = seq_offset + it; // index for element + if ((element_index < 0) && (element_index >= -UPSAMPLE_REPLICATION_PAD)) + { + elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_left_most_value; + } + if ((element_index >= seq_len) && (element_index < seq_len + UPSAMPLE_REPLICATION_PAD)) + { + elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_right_most_value; + } + if ((element_index >= 0) && (element_index < seq_len)) + { + elements[2 * (HALF_FILTER_SIZE + it)] = 2 * src[it]; + } + } + + // Apply upsampling strided convolution and write to intermediates. 
It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT for replication padding of the downsampilng conv later + #pragma unroll + for (int it = 0; it < (2 * BUFFER_SIZE + 2 * FILTER_SIZE); it += 1) + { + input_t acc = 0.0; + int element_index = intermediate_seq_offset + it; // index for intermediate + #pragma unroll + for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1) + { + if ((element_index + f_idx) >= 0) + { + acc += up_filter[f_idx] * elements[it + f_idx]; + } + } + intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] = acc; + } + + // Apply activation function. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT and DOWNSAMPLE_REPLICATION_PAD_RIGHT for replication padding of the downsampilng conv later + double no_div_by_zero = 0.000000001; + #pragma unroll + for (int it = 0; it < 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it += 1) + { + intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] += (1.0 / (beta_val + no_div_by_zero)) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val); + } + + // Apply replication padding before downsampling conv from intermediates + #pragma unroll + for (int it = 0; it < DOWNSAMPLE_REPLICATION_PAD_LEFT; it += 1) + { + intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT]; + } + #pragma unroll + for (int it = DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it < DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE + DOWNSAMPLE_REPLICATION_PAD_RIGHT; it += 1) + { + intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE - 1]; + } + + // Apply downsample strided convolution (assuming stride=2) from intermediates + #pragma unroll + for (int it = 0; it < BUFFER_SIZE; it += 1) + { + input_t acc = 0.0; + #pragma unroll + for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1) + { + // Add constant DOWNSAMPLE_REPLICATION_PAD_RIGHT to match torch implementation + acc += down_filter[f_idx] * intermediates[it * 2 + f_idx + DOWNSAMPLE_REPLICATION_PAD_RIGHT]; + } + output[it] = acc; + } + + // Write output to dst + #pragma unroll + for (int it = 0; it < BUFFER_SIZE; it += ELEMENTS_PER_LDG_STG) + { + int element_index = seq_offset + it; + if (element_index < seq_len) + { + dst[it] = output[it]; + } + } + + } + + template + void dispatch_anti_alias_activation_forward( + output_t *dst, + const input_t *src, + const input_t *up_ftr, + const input_t *down_ftr, + const input_t *alpha, + const input_t *beta, + int batch_size, + int channels, + int seq_len) + { + if (seq_len == 0) + { + return; + } + else + { + // Use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + constexpr int seq_len_per_block = 4096; + int blocks_per_seq_len = (seq_len + seq_len_per_block - 1) / seq_len_per_block; + dim3 blocks(blocks_per_seq_len, channels, batch_size); + dim3 threads(threads_per_block, 1, 1); + + anti_alias_activation_forward + <<>>(dst, src, up_ftr, down_ftr, alpha, beta, batch_size, channels, seq_len); + } + } +} + +extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta) +{ + // Input is a 3d tensor with dimensions [batches, channels, seq_len] + const int batches = input.size(0); + const int channels = input.size(1); + const int seq_len = input.size(2); + + // Output + auto act_options = input.options().requires_grad(false); + + torch::Tensor 
anti_alias_activation_results = + torch::empty({batches, channels, seq_len}, act_options); + + void *input_ptr = static_cast(input.data_ptr()); + void *up_filter_ptr = static_cast(up_filter.data_ptr()); + void *down_filter_ptr = static_cast(down_filter.data_ptr()); + void *alpha_ptr = static_cast(alpha.data_ptr()); + void *beta_ptr = static_cast(beta.data_ptr()); + void *anti_alias_activation_results_ptr = static_cast(anti_alias_activation_results.data_ptr()); + + DISPATCH_FLOAT_HALF_AND_BFLOAT( + input.scalar_type(), + "dispatch anti alias activation_forward", + dispatch_anti_alias_activation_forward( + reinterpret_cast(anti_alias_activation_results_ptr), + reinterpret_cast(input_ptr), + reinterpret_cast(up_filter_ptr), + reinterpret_cast(down_filter_ptr), + reinterpret_cast(alpha_ptr), + reinterpret_cast(beta_ptr), + batches, + channels, + seq_len);); + return anti_alias_activation_results; +} \ No newline at end of file diff --git a/BigVGAN/alias_free_activation/cuda/build/_ b/BigVGAN/alias_free_activation/cuda/build/_ new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/BigVGAN/alias_free_activation/cuda/build/_ @@ -0,0 +1 @@ + diff --git a/BigVGAN/alias_free_activation/cuda/compat.h b/BigVGAN/alias_free_activation/cuda/compat.h new file mode 100644 index 0000000000000000000000000000000000000000..25818b2edf4cb0dc9130e62c7c4de8d16a01baa5 --- /dev/null +++ b/BigVGAN/alias_free_activation/cuda/compat.h @@ -0,0 +1,29 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*This code is copied fron NVIDIA apex: + * https://github.com/NVIDIA/apex + * with minor changes. */ + +#ifndef TORCH_CHECK +#define TORCH_CHECK AT_CHECK +#endif + +#ifdef VERSION_GE_1_3 +#define DATA_PTR data_ptr +#else +#define DATA_PTR data +#endif diff --git a/BigVGAN/alias_free_activation/cuda/load.py b/BigVGAN/alias_free_activation/cuda/load.py new file mode 100644 index 0000000000000000000000000000000000000000..14fbf0548c84f6e698e18631b59473d7b4d7c736 --- /dev/null +++ b/BigVGAN/alias_free_activation/cuda/load.py @@ -0,0 +1,82 @@ +# Copyright (c) 2024 NVIDIA CORPORATION. +# Licensed under the MIT license. + +import os +import pathlib +import subprocess + +from torch.utils import cpp_extension + +""" +Setting this param to a list has a problem of generating different compilation commands (with diferent order of architectures) and leading to recompilation of fused kernels. 
+Set it to empty stringo avoid recompilation and assign arch flags explicity in extra_cuda_cflags below +""" +os.environ["TORCH_CUDA_ARCH_LIST"] = "" + + +def load(): + # Check if cuda 11 is installed for compute capability 8.0 + cc_flag = [] + _, bare_metal_major, _ = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME) + if int(bare_metal_major) >= 11: + cc_flag.append("-gencode") + cc_flag.append("arch=compute_80,code=sm_80") + + # Build path + srcpath = pathlib.Path(__file__).parent.absolute() + buildpath = srcpath / "build" + _create_build_dir(buildpath) + + # Helper function to build the kernels. + def _cpp_extention_load_helper(name, sources, extra_cuda_flags): + return cpp_extension.load( + name=name, + sources=sources, + build_directory=buildpath, + extra_cflags=[ + "-O3", + ], + extra_cuda_cflags=[ + "-O3", + "-gencode", + "arch=compute_70,code=sm_70", + "--use_fast_math", + ] + + extra_cuda_flags + + cc_flag, + verbose=True, + ) + + extra_cuda_flags = [ + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + ] + + sources = [ + srcpath / "anti_alias_activation.cpp", + srcpath / "anti_alias_activation_cuda.cu", + ] + anti_alias_activation_cuda = _cpp_extention_load_helper("anti_alias_activation_cuda", sources, extra_cuda_flags) + + return anti_alias_activation_cuda + + +def _get_cuda_bare_metal_version(cuda_dir): + raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) + output = raw_output.split() + release_idx = output.index("release") + 1 + release = output[release_idx].split(".") + bare_metal_major = release[0] + bare_metal_minor = release[1][0] + + return raw_output, bare_metal_major, bare_metal_minor + + +def _create_build_dir(buildpath): + try: + os.mkdir(buildpath) + except OSError: + if not os.path.isdir(buildpath): + print(f"Creation of the build directory {buildpath} failed") diff --git a/BigVGAN/alias_free_activation/cuda/type_shim.h b/BigVGAN/alias_free_activation/cuda/type_shim.h new file mode 100644 index 0000000000000000000000000000000000000000..5db7e8a397e982d4d30d16ab6060814b98b7ab83 --- /dev/null +++ b/BigVGAN/alias_free_activation/cuda/type_shim.h @@ -0,0 +1,92 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "compat.h" + +#define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, ...) \ + switch (TYPE) \ + { \ + case at::ScalarType::Float: \ + { \ + using scalar_t = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + +#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) 
\ + switch (TYPEIN) \ + { \ + case at::ScalarType::Float: \ + { \ + using scalar_t_in = float; \ + switch (TYPEOUT) \ + { \ + case at::ScalarType::Float: \ + { \ + using scalar_t_out = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_out = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t_out = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \ + } \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_in = at::Half; \ + using scalar_t_out = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t_in = at::BFloat16; \ + using scalar_t_out = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \ + } diff --git a/BigVGAN/alias_free_activation/torch/__init__.py b/BigVGAN/alias_free_activation/torch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8f756ed83f87f9839e457b240f60469bc187707d --- /dev/null +++ b/BigVGAN/alias_free_activation/torch/__init__.py @@ -0,0 +1,6 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +from .filter import * +from .resample import * +from .act import * diff --git a/BigVGAN/alias_free_activation/torch/act.py b/BigVGAN/alias_free_activation/torch/act.py new file mode 100644 index 0000000000000000000000000000000000000000..a6693aac602d7b331d6149522685dd512a26d277 --- /dev/null +++ b/BigVGAN/alias_free_activation/torch/act.py @@ -0,0 +1,30 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +import torch.nn as nn +from .resample import UpSample1d, DownSample1d + + +class Activation1d(nn.Module): + def __init__( + self, + activation, + up_ratio: int = 2, + down_ratio: int = 2, + up_kernel_size: int = 12, + down_kernel_size: int = 12, + ): + super().__init__() + self.up_ratio = up_ratio + self.down_ratio = down_ratio + self.act = activation + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + # x: [B,C,T] + def forward(self, x): + x = self.upsample(x) + x = self.act(x) + x = self.downsample(x) + + return x diff --git a/BigVGAN/alias_free_activation/torch/filter.py b/BigVGAN/alias_free_activation/torch/filter.py new file mode 100644 index 0000000000000000000000000000000000000000..dc905b204c91a5cea04cd4f8bbf60498fbc7b97f --- /dev/null +++ b/BigVGAN/alias_free_activation/torch/filter.py @@ -0,0 +1,99 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +if "sinc" in dir(torch): + sinc = torch.sinc +else: + # This code is adopted from adefossez's julius.core.sinc under the MIT License + # https://adefossez.github.io/julius/julius/core.html + # LICENSE is in incl_licenses directory. + def sinc(x: torch.Tensor): + """ + Implementation of sinc, i.e. sin(pi * x) / (pi * x) + __Warning__: Different to julius.sinc, the input is multiplied by `pi`! 
+ """ + return torch.where( + x == 0, + torch.tensor(1.0, device=x.device, dtype=x.dtype), + torch.sin(math.pi * x) / math.pi / x, + ) + + +# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License +# https://adefossez.github.io/julius/julius/lowpass.html +# LICENSE is in incl_licenses directory. +def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size] + even = kernel_size % 2 == 0 + half_size = kernel_size // 2 + + # For kaiser window + delta_f = 4 * half_width + A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 + if A > 50.0: + beta = 0.1102 * (A - 8.7) + elif A >= 21.0: + beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0) + else: + beta = 0.0 + window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) + + # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio + if even: + time = torch.arange(-half_size, half_size) + 0.5 + else: + time = torch.arange(kernel_size) - half_size + if cutoff == 0: + filter_ = torch.zeros_like(time) + else: + filter_ = 2 * cutoff * window * sinc(2 * cutoff * time) + """ + Normalize filter to have sum = 1, otherwise we will have a small leakage of the constant component in the input signal. + """ + filter_ /= filter_.sum() + filter = filter_.view(1, 1, kernel_size) + + return filter + + +class LowPassFilter1d(nn.Module): + def __init__( + self, + cutoff=0.5, + half_width=0.6, + stride: int = 1, + padding: bool = True, + padding_mode: str = "replicate", + kernel_size: int = 12, + ): + """ + kernel_size should be even number for stylegan3 setup, in this implementation, odd number is also possible. + """ + super().__init__() + if cutoff < -0.0: + raise ValueError("Minimum cutoff must be larger than zero.") + if cutoff > 0.5: + raise ValueError("A cutoff above 0.5 does not make sense.") + self.kernel_size = kernel_size + self.even = kernel_size % 2 == 0 + self.pad_left = kernel_size // 2 - int(self.even) + self.pad_right = kernel_size // 2 + self.stride = stride + self.padding = padding + self.padding_mode = padding_mode + filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) + self.register_buffer("filter", filter) + + # Input [B, C, T] + def forward(self, x): + _, C, _ = x.shape + + if self.padding: + x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode) + out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) + + return out diff --git a/BigVGAN/alias_free_activation/torch/resample.py b/BigVGAN/alias_free_activation/torch/resample.py new file mode 100644 index 0000000000000000000000000000000000000000..e7928fadbe77d5ff04bdfefe70ab3ceb207c7580 --- /dev/null +++ b/BigVGAN/alias_free_activation/torch/resample.py @@ -0,0 +1,48 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. 
+ +import torch.nn as nn +from torch.nn import functional as F +from .filter import LowPassFilter1d +from .filter import kaiser_sinc_filter1d + + +class UpSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.stride = ratio + self.pad = self.kernel_size // ratio - 1 + self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 + self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 + filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size) + self.register_buffer("filter", filter) + + # x: [B, C, T] + def forward(self, x): + _, C, _ = x.shape + + x = F.pad(x, (self.pad, self.pad), mode="replicate") + x = self.ratio * F.conv_transpose1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) + x = x[..., self.pad_left : -self.pad_right] + + return x + + +class DownSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.lowpass = LowPassFilter1d( + cutoff=0.5 / ratio, + half_width=0.6 / ratio, + stride=ratio, + kernel_size=self.kernel_size, + ) + + def forward(self, x): + xx = self.lowpass(x) + + return xx diff --git a/BigVGAN/bigvgan.py b/BigVGAN/bigvgan.py new file mode 100644 index 0000000000000000000000000000000000000000..febdf165c354b1fa2932f27e4ef8b7b6da10e2a6 --- /dev/null +++ b/BigVGAN/bigvgan.py @@ -0,0 +1,461 @@ +# Copyright (c) 2024 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/jik876/hifi-gan under the MIT license. +# LICENSE is in incl_licenses directory. + +import os +import json +from pathlib import Path +from typing import Optional, Union, Dict + +import torch +import torch.nn as nn +from torch.nn import Conv1d, ConvTranspose1d +from torch.nn.utils import weight_norm, remove_weight_norm + +from . import activations +from .utils0 import init_weights, get_padding +from .alias_free_activation.torch.act import Activation1d as TorchActivation1d +from .env import AttrDict + +from huggingface_hub import PyTorchModelHubMixin, hf_hub_download + + +def load_hparams_from_json(path) -> AttrDict: + with open(path) as f: + data = f.read() + return AttrDict(json.loads(data)) + + +class AMPBlock1(torch.nn.Module): + """ + AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer. + AMPBlock1 has additional self.convs2 that contains additional Conv1d layers with a fixed dilation=1 followed by each layer in self.convs1 + + Args: + h (AttrDict): Hyperparameters. + channels (int): Number of convolution channels. + kernel_size (int): Size of the convolution kernel. Default is 3. + dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5). + activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None. 
+ """ + + def __init__( + self, + h: AttrDict, + channels: int, + kernel_size: int = 3, + dilation: tuple = (1, 3, 5), + activation: str = None, + ): + super().__init__() + + self.h = h + + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + stride=1, + dilation=d, + padding=get_padding(kernel_size, d), + ) + ) + for d in dilation + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + stride=1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ) + for _ in range(len(dilation)) + ] + ) + self.convs2.apply(init_weights) + + self.num_layers = len(self.convs1) + len(self.convs2) # Total number of conv layers + + # Select which Activation1d, lazy-load cuda version to ensure backward compatibility + if self.h.get("use_cuda_kernel", False): + from .alias_free_activation.cuda.activation1d import ( + Activation1d as CudaActivation1d, + ) + + Activation1d = CudaActivation1d + else: + Activation1d = TorchActivation1d + + # Activation functions + if activation == "snake": + self.activations = nn.ModuleList( + [ + Activation1d(activation=activations.Snake(channels, alpha_logscale=h.snake_logscale)) + for _ in range(self.num_layers) + ] + ) + elif activation == "snakebeta": + self.activations = nn.ModuleList( + [ + Activation1d(activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale)) + for _ in range(self.num_layers) + ] + ) + else: + raise NotImplementedError( + "activation incorrectly specified. check the config file and look for 'activation'." + ) + + def forward(self, x): + acts1, acts2 = self.activations[::2], self.activations[1::2] + for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2): + xt = a1(x) + xt = c1(xt) + xt = a2(xt) + xt = c2(xt) + x = xt + x + + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class AMPBlock2(torch.nn.Module): + """ + AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer. + Unlike AMPBlock1, AMPBlock2 does not contain extra Conv1d layers with fixed dilation=1 + + Args: + h (AttrDict): Hyperparameters. + channels (int): Number of convolution channels. + kernel_size (int): Size of the convolution kernel. Default is 3. + dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5). + activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None. 
+ """ + + def __init__( + self, + h: AttrDict, + channels: int, + kernel_size: int = 3, + dilation: tuple = (1, 3, 5), + activation: str = None, + ): + super().__init__() + + self.h = h + + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + stride=1, + dilation=d, + padding=get_padding(kernel_size, d), + ) + ) + for d in dilation + ] + ) + self.convs.apply(init_weights) + + self.num_layers = len(self.convs) # Total number of conv layers + + # Select which Activation1d, lazy-load cuda version to ensure backward compatibility + if self.h.get("use_cuda_kernel", False): + from .alias_free_activation.cuda.activation1d import ( + Activation1d as CudaActivation1d, + ) + + Activation1d = CudaActivation1d + else: + Activation1d = TorchActivation1d + + # Activation functions + if activation == "snake": + self.activations = nn.ModuleList( + [ + Activation1d(activation=activations.Snake(channels, alpha_logscale=h.snake_logscale)) + for _ in range(self.num_layers) + ] + ) + elif activation == "snakebeta": + self.activations = nn.ModuleList( + [ + Activation1d(activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale)) + for _ in range(self.num_layers) + ] + ) + else: + raise NotImplementedError( + "activation incorrectly specified. check the config file and look for 'activation'." + ) + + def forward(self, x): + for c, a in zip(self.convs, self.activations): + xt = a(x) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class BigVGAN( + torch.nn.Module, + PyTorchModelHubMixin, + # library_name="bigvgan", + # repo_url="https://github.com/NVIDIA/BigVGAN", + # docs_url="https://github.com/NVIDIA/BigVGAN/blob/main/README.md", + # pipeline_tag="audio-to-audio", + # license="mit", + # tags=["neural-vocoder", "audio-generation", "arxiv:2206.04658"], +): + """ + BigVGAN is a neural vocoder model that applies anti-aliased periodic activation for residual blocks (resblocks). + New in BigVGAN-v2: it can optionally use optimized CUDA kernels for AMP (anti-aliased multi-periodicity) blocks. + + Args: + h (AttrDict): Hyperparameters. + use_cuda_kernel (bool): If set to True, loads optimized CUDA kernels for AMP. This should be used for inference only, as training is not supported with CUDA kernels. + + Note: + - The `use_cuda_kernel` parameter should be used for inference only, as training with CUDA kernels is not supported. + - Ensure that the activation function is correctly specified in the hyperparameters (h.activation). + """ + + def __init__(self, h: AttrDict, use_cuda_kernel: bool = False): + super().__init__() + self.h = h + self.h["use_cuda_kernel"] = use_cuda_kernel + + # Select which Activation1d, lazy-load cuda version to ensure backward compatibility + if self.h.get("use_cuda_kernel", False): + from .alias_free_activation.cuda.activation1d import ( + Activation1d as CudaActivation1d, + ) + + Activation1d = CudaActivation1d + else: + Activation1d = TorchActivation1d + + self.num_kernels = len(h.resblock_kernel_sizes) + self.num_upsamples = len(h.upsample_rates) + + # Pre-conv + self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)) + + # Define which AMPBlock to use. BigVGAN uses AMPBlock1 as default + if h.resblock == "1": + resblock_class = AMPBlock1 + elif h.resblock == "2": + resblock_class = AMPBlock2 + else: + raise ValueError(f"Incorrect resblock class specified in hyperparameters. 
Got {h.resblock}") + + # Transposed conv-based upsamplers. does not apply anti-aliasing + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): + self.ups.append( + nn.ModuleList( + [ + weight_norm( + ConvTranspose1d( + h.upsample_initial_channel // (2**i), + h.upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ] + ) + ) + + # Residual blocks using anti-aliased multi-periodicity composition modules (AMP) + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = h.upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)): + self.resblocks.append(resblock_class(h, ch, k, d, activation=h.activation)) + + # Post-conv + activation_post = ( + activations.Snake(ch, alpha_logscale=h.snake_logscale) + if h.activation == "snake" + else (activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale) if h.activation == "snakebeta" else None) + ) + if activation_post is None: + raise NotImplementedError( + "activation incorrectly specified. check the config file and look for 'activation'." + ) + + self.activation_post = Activation1d(activation=activation_post) + + # Whether to use bias for the final conv_post. Default to True for backward compatibility + self.use_bias_at_final = h.get("use_bias_at_final", True) + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final)) + + # Weight initialization + for i in range(len(self.ups)): + self.ups[i].apply(init_weights) + self.conv_post.apply(init_weights) + + # Final tanh activation. Defaults to True for backward compatibility + self.use_tanh_at_final = h.get("use_tanh_at_final", True) + + def forward(self, x): + # Pre-conv + x = self.conv_pre(x) + + for i in range(self.num_upsamples): + # Upsampling + for i_up in range(len(self.ups[i])): + x = self.ups[i][i_up](x) + # AMP blocks + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + + # Post-conv + x = self.activation_post(x) + x = self.conv_post(x) + # Final tanh activation + if self.use_tanh_at_final: + x = torch.tanh(x) + else: + x = torch.clamp(x, min=-1.0, max=1.0) # Bound the output to [-1, 1] + + return x + + def remove_weight_norm(self): + try: + # print("Removing weight norm...") + for l in self.ups: + for l_i in l: + remove_weight_norm(l_i) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + except ValueError: + print("[INFO] Model already removed weight norm. 
Skipping!") + pass + + # Additional methods for huggingface_hub support + def _save_pretrained(self, save_directory: Path) -> None: + """Save weights and config.json from a Pytorch model to a local directory.""" + + model_path = save_directory / "bigvgan_generator.pt" + torch.save({"generator": self.state_dict()}, model_path) + + config_path = save_directory / "config.json" + with open(config_path, "w") as config_file: + json.dump(self.h, config_file, indent=4) + + @classmethod + def _from_pretrained( + cls, + *, + model_id: str, + revision: str, + cache_dir: str, + force_download: bool, + proxies: Optional[Dict], + resume_download: bool, + local_files_only: bool, + token: Union[str, bool, None], + map_location: str = "cpu", # Additional argument + strict: bool = False, # Additional argument + use_cuda_kernel: bool = False, + **model_kwargs, + ): + """Load Pytorch pretrained weights and return the loaded model.""" + + # Download and load hyperparameters (h) used by BigVGAN + if os.path.isdir(model_id): + # print("Loading config.json from local directory") + config_file = os.path.join(model_id, "config.json") + else: + config_file = hf_hub_download( + repo_id=model_id, + filename="config.json", + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + token=token, + local_files_only=local_files_only, + ) + h = load_hparams_from_json(config_file) + + # instantiate BigVGAN using h + if use_cuda_kernel: + print( + "[WARNING] You have specified use_cuda_kernel=True during BigVGAN.from_pretrained(). Only inference is supported (training is not implemented)!" + ) + print( + "[WARNING] You need nvcc and ninja installed in your system that matches your PyTorch build is using to build the kernel. If not, the model will fail to initialize or generate incorrect waveform!" + ) + print( + "[WARNING] For detail, see the official GitHub repository: https://github.com/NVIDIA/BigVGAN?tab=readme-ov-file#using-custom-cuda-kernel-for-synthesis" + ) + model = cls(h, use_cuda_kernel=use_cuda_kernel) + + # Download and load pretrained generator weight + if os.path.isdir(model_id): + # print("Loading weights from local directory") + model_file = os.path.join(model_id, "bigvgan_generator.pt") + else: + # print(f"Loading weights from {model_id}") + model_file = hf_hub_download( + repo_id=model_id, + filename="bigvgan_generator.pt", + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + token=token, + local_files_only=local_files_only, + ) + + checkpoint_dict = torch.load(model_file, map_location=map_location) + + try: + model.load_state_dict(checkpoint_dict["generator"]) + except RuntimeError: + print( + "[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!" 
+ ) + model.remove_weight_norm() + model.load_state_dict(checkpoint_dict["generator"]) + + return model diff --git a/BigVGAN/configs/bigvgan_22khz_80band.json b/BigVGAN/configs/bigvgan_22khz_80band.json new file mode 100644 index 0000000000000000000000000000000000000000..64bca7846edb4e86d7ee22d9ca7a1554cf7f1042 --- /dev/null +++ b/BigVGAN/configs/bigvgan_22khz_80band.json @@ -0,0 +1,45 @@ +{ + "resblock": "1", + "num_gpus": 0, + "batch_size": 32, + "learning_rate": 0.0001, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.9999996, + "seed": 1234, + + "upsample_rates": [4,4,2,2,2,2], + "upsample_kernel_sizes": [8,8,4,4,4,4], + "upsample_initial_channel": 1536, + "resblock_kernel_sizes": [3,7,11], + "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], + + "activation": "snakebeta", + "snake_logscale": true, + + "resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]], + "mpd_reshapes": [2, 3, 5, 7, 11], + "use_spectral_norm": false, + "discriminator_channel_mult": 1, + + "segment_size": 8192, + "num_mels": 80, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 256, + "win_size": 1024, + + "sampling_rate": 22050, + + "fmin": 0, + "fmax": 8000, + "fmax_for_loss": null, + + "num_workers": 4, + + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54321", + "world_size": 1 + } +} diff --git a/BigVGAN/configs/bigvgan_24khz_100band.json b/BigVGAN/configs/bigvgan_24khz_100band.json new file mode 100644 index 0000000000000000000000000000000000000000..e7f7ff08f6697a4640d8e28c0b3fe7e62d0c3fc7 --- /dev/null +++ b/BigVGAN/configs/bigvgan_24khz_100band.json @@ -0,0 +1,45 @@ +{ + "resblock": "1", + "num_gpus": 0, + "batch_size": 32, + "learning_rate": 0.0001, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.9999996, + "seed": 1234, + + "upsample_rates": [4,4,2,2,2,2], + "upsample_kernel_sizes": [8,8,4,4,4,4], + "upsample_initial_channel": 1536, + "resblock_kernel_sizes": [3,7,11], + "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], + + "activation": "snakebeta", + "snake_logscale": true, + + "resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]], + "mpd_reshapes": [2, 3, 5, 7, 11], + "use_spectral_norm": false, + "discriminator_channel_mult": 1, + + "segment_size": 8192, + "num_mels": 100, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 256, + "win_size": 1024, + + "sampling_rate": 24000, + + "fmin": 0, + "fmax": 12000, + "fmax_for_loss": null, + + "num_workers": 4, + + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54321", + "world_size": 1 + } +} diff --git a/BigVGAN/configs/bigvgan_base_22khz_80band.json b/BigVGAN/configs/bigvgan_base_22khz_80band.json new file mode 100644 index 0000000000000000000000000000000000000000..fd244848308917f4df7ce49bf6b76530fd04cbc2 --- /dev/null +++ b/BigVGAN/configs/bigvgan_base_22khz_80band.json @@ -0,0 +1,45 @@ +{ + "resblock": "1", + "num_gpus": 0, + "batch_size": 32, + "learning_rate": 0.0001, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.9999996, + "seed": 1234, + + "upsample_rates": [8,8,2,2], + "upsample_kernel_sizes": [16,16,4,4], + "upsample_initial_channel": 512, + "resblock_kernel_sizes": [3,7,11], + "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], + + "activation": "snakebeta", + "snake_logscale": true, + + "resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]], + "mpd_reshapes": [2, 3, 5, 7, 11], + "use_spectral_norm": false, + "discriminator_channel_mult": 1, + + "segment_size": 8192, + "num_mels": 80, + "num_freq": 1025, + "n_fft": 1024, 
+ "hop_size": 256, + "win_size": 1024, + + "sampling_rate": 22050, + + "fmin": 0, + "fmax": 8000, + "fmax_for_loss": null, + + "num_workers": 4, + + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54321", + "world_size": 1 + } +} diff --git a/BigVGAN/configs/bigvgan_base_24khz_100band.json b/BigVGAN/configs/bigvgan_base_24khz_100band.json new file mode 100644 index 0000000000000000000000000000000000000000..0911508cac4a9346ada8c196bfcc228998da6f42 --- /dev/null +++ b/BigVGAN/configs/bigvgan_base_24khz_100band.json @@ -0,0 +1,45 @@ +{ + "resblock": "1", + "num_gpus": 0, + "batch_size": 32, + "learning_rate": 0.0001, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.9999996, + "seed": 1234, + + "upsample_rates": [8,8,2,2], + "upsample_kernel_sizes": [16,16,4,4], + "upsample_initial_channel": 512, + "resblock_kernel_sizes": [3,7,11], + "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], + + "activation": "snakebeta", + "snake_logscale": true, + + "resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]], + "mpd_reshapes": [2, 3, 5, 7, 11], + "use_spectral_norm": false, + "discriminator_channel_mult": 1, + + "segment_size": 8192, + "num_mels": 100, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 256, + "win_size": 1024, + + "sampling_rate": 24000, + + "fmin": 0, + "fmax": 12000, + "fmax_for_loss": null, + + "num_workers": 4, + + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54321", + "world_size": 1 + } +} diff --git a/BigVGAN/configs/bigvgan_v2_22khz_80band_256x.json b/BigVGAN/configs/bigvgan_v2_22khz_80band_256x.json new file mode 100644 index 0000000000000000000000000000000000000000..e96bd5fdd5b99767adba7f13bfcd1f777d5c599a --- /dev/null +++ b/BigVGAN/configs/bigvgan_v2_22khz_80band_256x.json @@ -0,0 +1,61 @@ +{ + "resblock": "1", + "num_gpus": 0, + "batch_size": 4, + "learning_rate": 0.0001, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.9999996, + "seed": 1234, + + "upsample_rates": [4,4,2,2,2,2], + "upsample_kernel_sizes": [8,8,4,4,4,4], + "upsample_initial_channel": 1536, + "resblock_kernel_sizes": [3,7,11], + "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], + + "use_tanh_at_final": false, + "use_bias_at_final": false, + + "activation": "snakebeta", + "snake_logscale": true, + + "use_cqtd_instead_of_mrd": true, + "cqtd_filters": 128, + "cqtd_max_filters": 1024, + "cqtd_filters_scale": 1, + "cqtd_dilations": [1, 2, 4], + "cqtd_hop_lengths": [512, 256, 256], + "cqtd_n_octaves": [9, 9, 9], + "cqtd_bins_per_octaves": [24, 36, 48], + + "mpd_reshapes": [2, 3, 5, 7, 11], + "use_spectral_norm": false, + "discriminator_channel_mult": 1, + + "use_multiscale_melloss": true, + "lambda_melloss": 15, + + "clip_grad_norm": 500, + + "segment_size": 65536, + "num_mels": 80, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 256, + "win_size": 1024, + + "sampling_rate": 22050, + + "fmin": 0, + "fmax": null, + "fmax_for_loss": null, + + "num_workers": 4, + + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54321", + "world_size": 1 + } +} diff --git a/BigVGAN/configs/bigvgan_v2_22khz_80band_fmax8k_256x.json b/BigVGAN/configs/bigvgan_v2_22khz_80band_fmax8k_256x.json new file mode 100644 index 0000000000000000000000000000000000000000..a3c9699fbe11948f4fd7e3434d2e623a00c802dd --- /dev/null +++ b/BigVGAN/configs/bigvgan_v2_22khz_80band_fmax8k_256x.json @@ -0,0 +1,61 @@ +{ + "resblock": "1", + "num_gpus": 0, + "batch_size": 4, + "learning_rate": 0.0001, + "adam_b1": 0.8, + "adam_b2": 0.99, + 
"lr_decay": 0.9999996, + "seed": 1234, + + "upsample_rates": [4,4,2,2,2,2], + "upsample_kernel_sizes": [8,8,4,4,4,4], + "upsample_initial_channel": 1536, + "resblock_kernel_sizes": [3,7,11], + "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], + + "use_tanh_at_final": false, + "use_bias_at_final": false, + + "activation": "snakebeta", + "snake_logscale": true, + + "use_cqtd_instead_of_mrd": true, + "cqtd_filters": 128, + "cqtd_max_filters": 1024, + "cqtd_filters_scale": 1, + "cqtd_dilations": [1, 2, 4], + "cqtd_hop_lengths": [512, 256, 256], + "cqtd_n_octaves": [9, 9, 9], + "cqtd_bins_per_octaves": [24, 36, 48], + + "mpd_reshapes": [2, 3, 5, 7, 11], + "use_spectral_norm": false, + "discriminator_channel_mult": 1, + + "use_multiscale_melloss": true, + "lambda_melloss": 15, + + "clip_grad_norm": 500, + + "segment_size": 65536, + "num_mels": 80, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 256, + "win_size": 1024, + + "sampling_rate": 22050, + + "fmin": 0, + "fmax": 8000, + "fmax_for_loss": null, + + "num_workers": 4, + + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54321", + "world_size": 1 + } +} diff --git a/BigVGAN/configs/bigvgan_v2_24khz_100band_256x.json b/BigVGAN/configs/bigvgan_v2_24khz_100band_256x.json new file mode 100644 index 0000000000000000000000000000000000000000..8057ee267c8ed80615362a41892b923a3ccd27e5 --- /dev/null +++ b/BigVGAN/configs/bigvgan_v2_24khz_100band_256x.json @@ -0,0 +1,61 @@ +{ + "resblock": "1", + "num_gpus": 0, + "batch_size": 4, + "learning_rate": 0.0001, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.9999996, + "seed": 1234, + + "upsample_rates": [4,4,2,2,2,2], + "upsample_kernel_sizes": [8,8,4,4,4,4], + "upsample_initial_channel": 1536, + "resblock_kernel_sizes": [3,7,11], + "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], + + "use_tanh_at_final": false, + "use_bias_at_final": false, + + "activation": "snakebeta", + "snake_logscale": true, + + "use_cqtd_instead_of_mrd": true, + "cqtd_filters": 128, + "cqtd_max_filters": 1024, + "cqtd_filters_scale": 1, + "cqtd_dilations": [1, 2, 4], + "cqtd_hop_lengths": [512, 256, 256], + "cqtd_n_octaves": [9, 9, 9], + "cqtd_bins_per_octaves": [24, 36, 48], + + "mpd_reshapes": [2, 3, 5, 7, 11], + "use_spectral_norm": false, + "discriminator_channel_mult": 1, + + "use_multiscale_melloss": true, + "lambda_melloss": 15, + + "clip_grad_norm": 500, + + "segment_size": 65536, + "num_mels": 100, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 256, + "win_size": 1024, + + "sampling_rate": 24000, + + "fmin": 0, + "fmax": null, + "fmax_for_loss": null, + + "num_workers": 4, + + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54321", + "world_size": 1 + } +} diff --git a/BigVGAN/configs/bigvgan_v2_44khz_128band_256x.json b/BigVGAN/configs/bigvgan_v2_44khz_128band_256x.json new file mode 100644 index 0000000000000000000000000000000000000000..b6999d3028e5d741ec99b16b34f153e763d0cfec --- /dev/null +++ b/BigVGAN/configs/bigvgan_v2_44khz_128band_256x.json @@ -0,0 +1,61 @@ +{ + "resblock": "1", + "num_gpus": 0, + "batch_size": 4, + "learning_rate": 0.0001, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.9999996, + "seed": 1234, + + "upsample_rates": [4,4,2,2,2,2], + "upsample_kernel_sizes": [8,8,4,4,4,4], + "upsample_initial_channel": 1536, + "resblock_kernel_sizes": [3,7,11], + "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], + + "use_tanh_at_final": false, + "use_bias_at_final": false, + + "activation": "snakebeta", + "snake_logscale": true, + 
+ "use_cqtd_instead_of_mrd": true, + "cqtd_filters": 128, + "cqtd_max_filters": 1024, + "cqtd_filters_scale": 1, + "cqtd_dilations": [1, 2, 4], + "cqtd_hop_lengths": [512, 256, 256], + "cqtd_n_octaves": [9, 9, 9], + "cqtd_bins_per_octaves": [24, 36, 48], + + "mpd_reshapes": [2, 3, 5, 7, 11], + "use_spectral_norm": false, + "discriminator_channel_mult": 1, + + "use_multiscale_melloss": true, + "lambda_melloss": 15, + + "clip_grad_norm": 500, + + "segment_size": 65536, + "num_mels": 128, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 256, + "win_size": 1024, + + "sampling_rate": 44100, + + "fmin": 0, + "fmax": null, + "fmax_for_loss": null, + + "num_workers": 4, + + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54321", + "world_size": 1 + } +} diff --git a/BigVGAN/configs/bigvgan_v2_44khz_128band_512x.json b/BigVGAN/configs/bigvgan_v2_44khz_128band_512x.json new file mode 100644 index 0000000000000000000000000000000000000000..2d7176c910ae0969f208f6d28e3f14abca2dbc7f --- /dev/null +++ b/BigVGAN/configs/bigvgan_v2_44khz_128band_512x.json @@ -0,0 +1,61 @@ +{ + "resblock": "1", + "num_gpus": 0, + "batch_size": 4, + "learning_rate": 0.0001, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.9999996, + "seed": 1234, + + "upsample_rates": [8,4,2,2,2,2], + "upsample_kernel_sizes": [16,8,4,4,4,4], + "upsample_initial_channel": 1536, + "resblock_kernel_sizes": [3,7,11], + "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], + + "use_tanh_at_final": false, + "use_bias_at_final": false, + + "activation": "snakebeta", + "snake_logscale": true, + + "use_cqtd_instead_of_mrd": true, + "cqtd_filters": 128, + "cqtd_max_filters": 1024, + "cqtd_filters_scale": 1, + "cqtd_dilations": [1, 2, 4], + "cqtd_hop_lengths": [512, 256, 256], + "cqtd_n_octaves": [9, 9, 9], + "cqtd_bins_per_octaves": [24, 36, 48], + + "mpd_reshapes": [2, 3, 5, 7, 11], + "use_spectral_norm": false, + "discriminator_channel_mult": 1, + + "use_multiscale_melloss": true, + "lambda_melloss": 15, + + "clip_grad_norm": 500, + + "segment_size": 65536, + "num_mels": 128, + "num_freq": 2049, + "n_fft": 2048, + "hop_size": 512, + "win_size": 2048, + + "sampling_rate": 44100, + + "fmin": 0, + "fmax": null, + "fmax_for_loss": null, + + "num_workers": 4, + + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54321", + "world_size": 1 + } +} diff --git a/BigVGAN/discriminators.py b/BigVGAN/discriminators.py new file mode 100644 index 0000000000000000000000000000000000000000..2d44c7983955a1be15a4520f6730de272f799128 --- /dev/null +++ b/BigVGAN/discriminators.py @@ -0,0 +1,625 @@ +# Copyright (c) 2024 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/jik876/hifi-gan under the MIT license. +# LICENSE is in incl_licenses directory. 
+ + +import torch +import torch.nn.functional as F +import torch.nn as nn +from torch.nn import Conv2d +from torch.nn.utils import weight_norm, spectral_norm +from torchaudio.transforms import Spectrogram, Resample + +from env import AttrDict +from utils import get_padding +import typing +from typing import List, Tuple + + +class DiscriminatorP(torch.nn.Module): + def __init__( + self, + h: AttrDict, + period: List[int], + kernel_size: int = 5, + stride: int = 3, + use_spectral_norm: bool = False, + ): + super().__init__() + self.period = period + self.d_mult = h.discriminator_channel_mult + norm_f = weight_norm if not use_spectral_norm else spectral_norm + + self.convs = nn.ModuleList( + [ + norm_f( + Conv2d( + 1, + int(32 * self.d_mult), + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ) + ), + norm_f( + Conv2d( + int(32 * self.d_mult), + int(128 * self.d_mult), + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ) + ), + norm_f( + Conv2d( + int(128 * self.d_mult), + int(512 * self.d_mult), + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ) + ), + norm_f( + Conv2d( + int(512 * self.d_mult), + int(1024 * self.d_mult), + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ) + ), + norm_f( + Conv2d( + int(1024 * self.d_mult), + int(1024 * self.d_mult), + (kernel_size, 1), + 1, + padding=(2, 0), + ) + ), + ] + ) + self.conv_post = norm_f(Conv2d(int(1024 * self.d_mult), 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]: + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, 0.1) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self, h: AttrDict): + super().__init__() + self.mpd_reshapes = h.mpd_reshapes + print(f"mpd_reshapes: {self.mpd_reshapes}") + self.discriminators = nn.ModuleList( + [DiscriminatorP(h, rs, use_spectral_norm=h.use_spectral_norm) for rs in self.mpd_reshapes] + ) + + def forward( + self, y: torch.Tensor, y_hat: torch.Tensor + ) -> Tuple[ + List[torch.Tensor], + List[torch.Tensor], + List[List[torch.Tensor]], + List[List[torch.Tensor]], + ]: + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorR(nn.Module): + def __init__(self, cfg: AttrDict, resolution: List[List[int]]): + super().__init__() + + self.resolution = resolution + assert len(self.resolution) == 3, f"MRD layer requires list with len=3, got {self.resolution}" + self.lrelu_slope = 0.1 + + norm_f = weight_norm if cfg.use_spectral_norm == False else spectral_norm + if hasattr(cfg, "mrd_use_spectral_norm"): + print(f"[INFO] overriding MRD use_spectral_norm as {cfg.mrd_use_spectral_norm}") + norm_f = weight_norm if cfg.mrd_use_spectral_norm == False else spectral_norm + self.d_mult = cfg.discriminator_channel_mult + if hasattr(cfg, "mrd_channel_mult"): + print(f"[INFO] overriding mrd channel multiplier as {cfg.mrd_channel_mult}") + self.d_mult = 
cfg.mrd_channel_mult + + self.convs = nn.ModuleList( + [ + norm_f(nn.Conv2d(1, int(32 * self.d_mult), (3, 9), padding=(1, 4))), + norm_f( + nn.Conv2d( + int(32 * self.d_mult), + int(32 * self.d_mult), + (3, 9), + stride=(1, 2), + padding=(1, 4), + ) + ), + norm_f( + nn.Conv2d( + int(32 * self.d_mult), + int(32 * self.d_mult), + (3, 9), + stride=(1, 2), + padding=(1, 4), + ) + ), + norm_f( + nn.Conv2d( + int(32 * self.d_mult), + int(32 * self.d_mult), + (3, 9), + stride=(1, 2), + padding=(1, 4), + ) + ), + norm_f( + nn.Conv2d( + int(32 * self.d_mult), + int(32 * self.d_mult), + (3, 3), + padding=(1, 1), + ) + ), + ] + ) + self.conv_post = norm_f(nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1))) + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]: + fmap = [] + + x = self.spectrogram(x) + x = x.unsqueeze(1) + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, self.lrelu_slope) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + def spectrogram(self, x: torch.Tensor) -> torch.Tensor: + n_fft, hop_length, win_length = self.resolution + x = F.pad( + x, + (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), + mode="reflect", + ) + x = x.squeeze(1) + x = torch.stft( + x, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + center=False, + return_complex=True, + ) + x = torch.view_as_real(x) # [B, F, TT, 2] + mag = torch.norm(x, p=2, dim=-1) # [B, F, TT] + + return mag + + +class MultiResolutionDiscriminator(nn.Module): + def __init__(self, cfg, debug=False): + super().__init__() + self.resolutions = cfg.resolutions + assert len(self.resolutions) == 3, ( + f"MRD requires list of list with len=3, each element having a list with len=3. Got {self.resolutions}" + ) + self.discriminators = nn.ModuleList([DiscriminatorR(cfg, resolution) for resolution in self.resolutions]) + + def forward( + self, y: torch.Tensor, y_hat: torch.Tensor + ) -> Tuple[ + List[torch.Tensor], + List[torch.Tensor], + List[List[torch.Tensor]], + List[List[torch.Tensor]], + ]: + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(x=y) + y_d_g, fmap_g = d(x=y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +# Method based on descript-audio-codec: https://github.com/descriptinc/descript-audio-codec +# Modified code adapted from https://github.com/gemelo-ai/vocos under the MIT license. +# LICENSE is in incl_licenses directory. +class DiscriminatorB(nn.Module): + def __init__( + self, + window_length: int, + channels: int = 32, + hop_factor: float = 0.25, + bands: Tuple[Tuple[float, float], ...] 
= ( + (0.0, 0.1), + (0.1, 0.25), + (0.25, 0.5), + (0.5, 0.75), + (0.75, 1.0), + ), + ): + super().__init__() + self.window_length = window_length + self.hop_factor = hop_factor + self.spec_fn = Spectrogram( + n_fft=window_length, + hop_length=int(window_length * hop_factor), + win_length=window_length, + power=None, + ) + n_fft = window_length // 2 + 1 + bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands] + self.bands = bands + convs = lambda: nn.ModuleList( + [ + weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))), + weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))), + weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))), + weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))), + weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))), + ] + ) + self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))]) + + self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1))) + + def spectrogram(self, x: torch.Tensor) -> List[torch.Tensor]: + # Remove DC offset + x = x - x.mean(dim=-1, keepdims=True) + # Peak normalize the volume of input audio + x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9) + x = self.spec_fn(x) + x = torch.view_as_real(x) + x = x.permute(0, 3, 2, 1) # [B, F, T, C] -> [B, C, T, F] + # Split into bands + x_bands = [x[..., b[0] : b[1]] for b in self.bands] + return x_bands + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]: + x_bands = self.spectrogram(x.squeeze(1)) + fmap = [] + x = [] + + for band, stack in zip(x_bands, self.band_convs): + for i, layer in enumerate(stack): + band = layer(band) + band = torch.nn.functional.leaky_relu(band, 0.1) + if i > 0: + fmap.append(band) + x.append(band) + + x = torch.cat(x, dim=-1) + x = self.conv_post(x) + fmap.append(x) + + return x, fmap + + +# Method based on descript-audio-codec: https://github.com/descriptinc/descript-audio-codec +# Modified code adapted from https://github.com/gemelo-ai/vocos under the MIT license. +# LICENSE is in incl_licenses directory. +class MultiBandDiscriminator(nn.Module): + def __init__( + self, + h, + ): + """ + Multi-band multi-scale STFT discriminator, with the architecture based on https://github.com/descriptinc/descript-audio-codec. + and the modified code adapted from https://github.com/gemelo-ai/vocos. + """ + super().__init__() + # fft_sizes (list[int]): Tuple of window lengths for FFT. Defaults to [2048, 1024, 512] if not set in h. + self.fft_sizes = h.get("mbd_fft_sizes", [2048, 1024, 512]) + self.discriminators = nn.ModuleList([DiscriminatorB(window_length=w) for w in self.fft_sizes]) + + def forward( + self, y: torch.Tensor, y_hat: torch.Tensor + ) -> Tuple[ + List[torch.Tensor], + List[torch.Tensor], + List[List[torch.Tensor]], + List[List[torch.Tensor]], + ]: + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + + for d in self.discriminators: + y_d_r, fmap_r = d(x=y) + y_d_g, fmap_g = d(x=y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +# Adapted from https://github.com/open-mmlab/Amphion/blob/main/models/vocoders/gan/discriminator/mssbcqtd.py under the MIT license. +# LICENSE is in incl_licenses directory. 
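Before the CQT-based discriminator defined next, a short sketch of how the band-split STFT discriminator above is typically driven: real and generated waveforms pass through the same module, and each FFT size yields a score tensor plus feature maps (the latter are what a feature-matching loss consumes). The empty AttrDict (so mbd_fft_sizes falls back to its default) and the random waveforms are assumptions for illustration:

import torch

from discriminators import MultiBandDiscriminator
from env import AttrDict

h = AttrDict({})  # no mbd_fft_sizes given, so the [2048, 1024, 512] default applies
mbd = MultiBandDiscriminator(h)

y = torch.randn(2, 1, 16384)      # real waveforms, shape [B, 1, T]
y_hat = torch.randn(2, 1, 16384)  # generated waveforms
scores_real, scores_gen, fmaps_real, fmaps_gen = mbd(y, y_hat)
print(len(scores_real), len(fmaps_real))  # 3 and 3: one entry per FFT size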
+class DiscriminatorCQT(nn.Module): + def __init__(self, cfg: AttrDict, hop_length: int, n_octaves: int, bins_per_octave: int): + super().__init__() + self.cfg = cfg + + self.filters = cfg["cqtd_filters"] + self.max_filters = cfg["cqtd_max_filters"] + self.filters_scale = cfg["cqtd_filters_scale"] + self.kernel_size = (3, 9) + self.dilations = cfg["cqtd_dilations"] + self.stride = (1, 2) + + self.in_channels = cfg["cqtd_in_channels"] + self.out_channels = cfg["cqtd_out_channels"] + self.fs = cfg["sampling_rate"] + self.hop_length = hop_length + self.n_octaves = n_octaves + self.bins_per_octave = bins_per_octave + + # Lazy-load + from nnAudio import features + + self.cqt_transform = features.cqt.CQT2010v2( + sr=self.fs * 2, + hop_length=self.hop_length, + n_bins=self.bins_per_octave * self.n_octaves, + bins_per_octave=self.bins_per_octave, + output_format="Complex", + pad_mode="constant", + ) + + self.conv_pres = nn.ModuleList() + for _ in range(self.n_octaves): + self.conv_pres.append( + nn.Conv2d( + self.in_channels * 2, + self.in_channels * 2, + kernel_size=self.kernel_size, + padding=self.get_2d_padding(self.kernel_size), + ) + ) + + self.convs = nn.ModuleList() + + self.convs.append( + nn.Conv2d( + self.in_channels * 2, + self.filters, + kernel_size=self.kernel_size, + padding=self.get_2d_padding(self.kernel_size), + ) + ) + + in_chs = min(self.filters_scale * self.filters, self.max_filters) + for i, dilation in enumerate(self.dilations): + out_chs = min((self.filters_scale ** (i + 1)) * self.filters, self.max_filters) + self.convs.append( + weight_norm( + nn.Conv2d( + in_chs, + out_chs, + kernel_size=self.kernel_size, + stride=self.stride, + dilation=(dilation, 1), + padding=self.get_2d_padding(self.kernel_size, (dilation, 1)), + ) + ) + ) + in_chs = out_chs + out_chs = min( + (self.filters_scale ** (len(self.dilations) + 1)) * self.filters, + self.max_filters, + ) + self.convs.append( + weight_norm( + nn.Conv2d( + in_chs, + out_chs, + kernel_size=(self.kernel_size[0], self.kernel_size[0]), + padding=self.get_2d_padding((self.kernel_size[0], self.kernel_size[0])), + ) + ) + ) + + self.conv_post = weight_norm( + nn.Conv2d( + out_chs, + self.out_channels, + kernel_size=(self.kernel_size[0], self.kernel_size[0]), + padding=self.get_2d_padding((self.kernel_size[0], self.kernel_size[0])), + ) + ) + + self.activation = torch.nn.LeakyReLU(negative_slope=0.1) + self.resample = Resample(orig_freq=self.fs, new_freq=self.fs * 2) + + self.cqtd_normalize_volume = self.cfg.get("cqtd_normalize_volume", False) + if self.cqtd_normalize_volume: + print( + "[INFO] cqtd_normalize_volume set to True. Will apply DC offset removal & peak volume normalization in CQTD!" 
+ ) + + def get_2d_padding( + self, + kernel_size: typing.Tuple[int, int], + dilation: typing.Tuple[int, int] = (1, 1), + ): + return ( + ((kernel_size[0] - 1) * dilation[0]) // 2, + ((kernel_size[1] - 1) * dilation[1]) // 2, + ) + + def forward(self, x: torch.tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]: + fmap = [] + + if self.cqtd_normalize_volume: + # Remove DC offset + x = x - x.mean(dim=-1, keepdims=True) + # Peak normalize the volume of input audio + x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9) + + x = self.resample(x) + + z = self.cqt_transform(x) + + z_amplitude = z[:, :, :, 0].unsqueeze(1) + z_phase = z[:, :, :, 1].unsqueeze(1) + + z = torch.cat([z_amplitude, z_phase], dim=1) + z = torch.permute(z, (0, 1, 3, 2)) # [B, C, W, T] -> [B, C, T, W] + + latent_z = [] + for i in range(self.n_octaves): + latent_z.append( + self.conv_pres[i]( + z[ + :, + :, + :, + i * self.bins_per_octave : (i + 1) * self.bins_per_octave, + ] + ) + ) + latent_z = torch.cat(latent_z, dim=-1) + + for i, l in enumerate(self.convs): + latent_z = l(latent_z) + + latent_z = self.activation(latent_z) + fmap.append(latent_z) + + latent_z = self.conv_post(latent_z) + + return latent_z, fmap + + +class MultiScaleSubbandCQTDiscriminator(nn.Module): + def __init__(self, cfg: AttrDict): + super().__init__() + + self.cfg = cfg + # Using get with defaults + self.cfg["cqtd_filters"] = self.cfg.get("cqtd_filters", 32) + self.cfg["cqtd_max_filters"] = self.cfg.get("cqtd_max_filters", 1024) + self.cfg["cqtd_filters_scale"] = self.cfg.get("cqtd_filters_scale", 1) + self.cfg["cqtd_dilations"] = self.cfg.get("cqtd_dilations", [1, 2, 4]) + self.cfg["cqtd_in_channels"] = self.cfg.get("cqtd_in_channels", 1) + self.cfg["cqtd_out_channels"] = self.cfg.get("cqtd_out_channels", 1) + # Multi-scale params to loop over + self.cfg["cqtd_hop_lengths"] = self.cfg.get("cqtd_hop_lengths", [512, 256, 256]) + self.cfg["cqtd_n_octaves"] = self.cfg.get("cqtd_n_octaves", [9, 9, 9]) + self.cfg["cqtd_bins_per_octaves"] = self.cfg.get("cqtd_bins_per_octaves", [24, 36, 48]) + + self.discriminators = nn.ModuleList( + [ + DiscriminatorCQT( + self.cfg, + hop_length=self.cfg["cqtd_hop_lengths"][i], + n_octaves=self.cfg["cqtd_n_octaves"][i], + bins_per_octave=self.cfg["cqtd_bins_per_octaves"][i], + ) + for i in range(len(self.cfg["cqtd_hop_lengths"])) + ] + ) + + def forward( + self, y: torch.Tensor, y_hat: torch.Tensor + ) -> Tuple[ + List[torch.Tensor], + List[torch.Tensor], + List[List[torch.Tensor]], + List[List[torch.Tensor]], + ]: + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + + for disc in self.discriminators: + y_d_r, fmap_r = disc(y) + y_d_g, fmap_g = disc(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class CombinedDiscriminator(nn.Module): + """ + Wrapper of chaining multiple discrimiantor architectures. 
+ Example: combine mbd and cqtd as a single class + """ + + def __init__(self, list_discriminator: List[nn.Module]): + super().__init__() + self.discrimiantor = nn.ModuleList(list_discriminator) + + def forward( + self, y: torch.Tensor, y_hat: torch.Tensor + ) -> Tuple[ + List[torch.Tensor], + List[torch.Tensor], + List[List[torch.Tensor]], + List[List[torch.Tensor]], + ]: + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + + for disc in self.discrimiantor: + y_d_r, y_d_g, fmap_r, fmap_g = disc(y, y_hat) + y_d_rs.extend(y_d_r) + fmap_rs.extend(fmap_r) + y_d_gs.extend(y_d_g) + fmap_gs.extend(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs diff --git a/BigVGAN/env.py b/BigVGAN/env.py new file mode 100644 index 0000000000000000000000000000000000000000..cf8ac6cea644c78d115dd3902b902993f366ee61 --- /dev/null +++ b/BigVGAN/env.py @@ -0,0 +1,18 @@ +# Adapted from https://github.com/jik876/hifi-gan under the MIT license. +# LICENSE is in incl_licenses directory. + +import os +import shutil + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +def build_env(config, config_name, path): + t_path = os.path.join(path, config_name) + if config != t_path: + os.makedirs(path, exist_ok=True) + shutil.copyfile(config, os.path.join(path, config_name)) diff --git a/BigVGAN/incl_licenses/LICENSE_1 b/BigVGAN/incl_licenses/LICENSE_1 new file mode 100644 index 0000000000000000000000000000000000000000..5afae394d6b37da0e12ba6b290d2512687f421ac --- /dev/null +++ b/BigVGAN/incl_licenses/LICENSE_1 @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2020 Jungil Kong + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
\ No newline at end of file diff --git a/BigVGAN/incl_licenses/LICENSE_2 b/BigVGAN/incl_licenses/LICENSE_2 new file mode 100644 index 0000000000000000000000000000000000000000..322b758863c4219be68291ae3826218baa93cb4c --- /dev/null +++ b/BigVGAN/incl_licenses/LICENSE_2 @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2020 Edward Dixon + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/BigVGAN/incl_licenses/LICENSE_3 b/BigVGAN/incl_licenses/LICENSE_3 new file mode 100644 index 0000000000000000000000000000000000000000..56ee3c8c4cc2b4b32e0975d17258f9ba515fdbcc --- /dev/null +++ b/BigVGAN/incl_licenses/LICENSE_3 @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). 
+ + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/BigVGAN/incl_licenses/LICENSE_4 b/BigVGAN/incl_licenses/LICENSE_4 new file mode 100644 index 0000000000000000000000000000000000000000..48fd1a1ba8d81a94b6c7d1c2ff1a1f307cc5371d --- /dev/null +++ b/BigVGAN/incl_licenses/LICENSE_4 @@ -0,0 +1,29 @@ +BSD 3-Clause License + +Copyright (c) 2019, Seungwon Park 박승원 +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. 
Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/BigVGAN/incl_licenses/LICENSE_5 b/BigVGAN/incl_licenses/LICENSE_5 new file mode 100644 index 0000000000000000000000000000000000000000..01ae5538e6b7c787bb4f5d6f2cd9903520d6e465 --- /dev/null +++ b/BigVGAN/incl_licenses/LICENSE_5 @@ -0,0 +1,16 @@ +Copyright 2020 Alexandre Défossez + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and +associated documentation files (the "Software"), to deal in the Software without restriction, +including without limitation the rights to use, copy, modify, merge, publish, distribute, +sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or +substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT +NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/BigVGAN/incl_licenses/LICENSE_6 b/BigVGAN/incl_licenses/LICENSE_6 new file mode 100644 index 0000000000000000000000000000000000000000..2569ec0b6c85f94f3cd071ba16e9028ccf156be2 --- /dev/null +++ b/BigVGAN/incl_licenses/LICENSE_6 @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023-present, Descript + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/BigVGAN/incl_licenses/LICENSE_7 b/BigVGAN/incl_licenses/LICENSE_7 new file mode 100644 index 0000000000000000000000000000000000000000..c37bdaf99c6921f5849425d546069e972f52d7fa --- /dev/null +++ b/BigVGAN/incl_licenses/LICENSE_7 @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Charactr Inc. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/BigVGAN/incl_licenses/LICENSE_8 b/BigVGAN/incl_licenses/LICENSE_8 new file mode 100644 index 0000000000000000000000000000000000000000..ab3d7ffe795779f54e339078e4e752ad9019aae8 --- /dev/null +++ b/BigVGAN/incl_licenses/LICENSE_8 @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Amphion + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/BigVGAN/inference.py b/BigVGAN/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..5f892a3c807a7020eff7fea35179b0f6e5f991c9 --- /dev/null +++ b/BigVGAN/inference.py @@ -0,0 +1,85 @@ +# Adapted from https://github.com/jik876/hifi-gan under the MIT license. +# LICENSE is in incl_licenses directory. 
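Before the inference scripts, one more note on the discriminators defined earlier: the CombinedDiscriminator docstring mentions chaining MBD and CQTD into a single module, which looks roughly like the sketch below. It assumes nnAudio is installed (DiscriminatorCQT lazy-imports it), that the snippet runs from the BigVGAN/ directory, and that a minimal AttrDict carrying only sampling_rate suffices because MultiScaleSubbandCQTDiscriminator fills in the cqtd_* defaults itself:

import torch

from discriminators import (
    CombinedDiscriminator,
    MultiBandDiscriminator,
    MultiScaleSubbandCQTDiscriminator,
)
from env import AttrDict

h = AttrDict({"sampling_rate": 24000})
disc = CombinedDiscriminator(
    [MultiBandDiscriminator(h), MultiScaleSubbandCQTDiscriminator(h)]
)

y = torch.randn(1, 1, 24000)      # real waveform, shape [B, 1, T]
y_hat = torch.randn(1, 1, 24000)  # generated waveform
scores_real, scores_gen, fmaps_real, fmaps_gen = disc(y, y_hat)
print(len(scores_real))  # 6: three MBD scales followed by three CQTD scales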
+ +from __future__ import absolute_import, division, print_function, unicode_literals + +import os +import argparse +import json +import torch +import librosa +from utils import load_checkpoint +from meldataset import get_mel_spectrogram +from scipy.io.wavfile import write +from env import AttrDict +from meldataset import MAX_WAV_VALUE +from bigvgan import BigVGAN as Generator + +h = None +device = None +torch.backends.cudnn.benchmark = False + + +def inference(a, h): + generator = Generator(h, use_cuda_kernel=a.use_cuda_kernel).to(device) + + state_dict_g = load_checkpoint(a.checkpoint_file, device) + generator.load_state_dict(state_dict_g["generator"]) + + filelist = os.listdir(a.input_wavs_dir) + + os.makedirs(a.output_dir, exist_ok=True) + + generator.eval() + generator.remove_weight_norm() + with torch.no_grad(): + for i, filname in enumerate(filelist): + # Load the ground truth audio and resample if necessary + wav, sr = librosa.load(os.path.join(a.input_wavs_dir, filname), sr=h.sampling_rate, mono=True) + wav = torch.FloatTensor(wav).to(device) + # Compute mel spectrogram from the ground truth audio + x = get_mel_spectrogram(wav.unsqueeze(0), generator.h) + + y_g_hat = generator(x) + + audio = y_g_hat.squeeze() + audio = audio * MAX_WAV_VALUE + audio = audio.cpu().numpy().astype("int16") + + output_file = os.path.join(a.output_dir, os.path.splitext(filname)[0] + "_generated.wav") + write(output_file, h.sampling_rate, audio) + print(output_file) + + +def main(): + print("Initializing Inference Process..") + + parser = argparse.ArgumentParser() + parser.add_argument("--input_wavs_dir", default="test_files") + parser.add_argument("--output_dir", default="generated_files") + parser.add_argument("--checkpoint_file", required=True) + parser.add_argument("--use_cuda_kernel", action="store_true", default=False) + + a = parser.parse_args() + + config_file = os.path.join(os.path.split(a.checkpoint_file)[0], "config.json") + with open(config_file) as f: + data = f.read() + + global h + json_config = json.loads(data) + h = AttrDict(json_config) + + torch.manual_seed(h.seed) + global device + if torch.cuda.is_available(): + torch.cuda.manual_seed(h.seed) + device = torch.device("cuda") + else: + device = torch.device("cpu") + + inference(a, h) + + +if __name__ == "__main__": + main() diff --git a/BigVGAN/inference_e2e.py b/BigVGAN/inference_e2e.py new file mode 100644 index 0000000000000000000000000000000000000000..9c0df77435e91935beaca365dd5fd38d76098a4a --- /dev/null +++ b/BigVGAN/inference_e2e.py @@ -0,0 +1,100 @@ +# Adapted from https://github.com/jik876/hifi-gan under the MIT license. +# LICENSE is in incl_licenses directory. 
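inference_e2e.py below synthesizes audio directly from mel spectrograms stored as .npy files. A hedged sketch of producing such a file with the same helpers inference.py uses above (get_mel_spectrogram from meldataset, AttrDict from env); the config and audio paths are placeholders, and the snippet is assumed to run from the BigVGAN/ directory:

import json
import os

import librosa
import numpy as np
import torch

from env import AttrDict
from meldataset import get_mel_spectrogram

with open("configs/bigvgan_v2_24khz_100band_256x.json") as f:
    h = AttrDict(json.loads(f.read()))

# Load audio at the model's sampling rate, exactly as inference.py does.
wav, _ = librosa.load("example.wav", sr=h.sampling_rate, mono=True)
wav = torch.FloatTensor(wav)

# Mel spectrogram of shape [1, num_mels, frames]; saved 2-D so inference_e2e.py can unsqueeze it back.
mel = get_mel_spectrogram(wav.unsqueeze(0), h)
os.makedirs("test_mel_files", exist_ok=True)  # default --input_mels_dir of inference_e2e.py
np.save("test_mel_files/example.npy", mel.squeeze(0).cpu().numpy())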
+ +from __future__ import absolute_import, division, print_function, unicode_literals + +import glob +import os +import numpy as np +import argparse +import json +import torch +from scipy.io.wavfile import write +from env import AttrDict +from meldataset import MAX_WAV_VALUE +from bigvgan import BigVGAN as Generator + +h = None +device = None +torch.backends.cudnn.benchmark = False + + +def load_checkpoint(filepath, device): + assert os.path.isfile(filepath) + print(f"Loading '{filepath}'") + checkpoint_dict = torch.load(filepath, map_location=device) + print("Complete.") + return checkpoint_dict + + +def scan_checkpoint(cp_dir, prefix): + pattern = os.path.join(cp_dir, prefix + "*") + cp_list = glob.glob(pattern) + if len(cp_list) == 0: + return "" + return sorted(cp_list)[-1] + + +def inference(a, h): + generator = Generator(h, use_cuda_kernel=a.use_cuda_kernel).to(device) + + state_dict_g = load_checkpoint(a.checkpoint_file, device) + generator.load_state_dict(state_dict_g["generator"]) + + filelist = os.listdir(a.input_mels_dir) + + os.makedirs(a.output_dir, exist_ok=True) + + generator.eval() + generator.remove_weight_norm() + with torch.no_grad(): + for i, filname in enumerate(filelist): + # Load the mel spectrogram in .npy format + x = np.load(os.path.join(a.input_mels_dir, filname)) + x = torch.FloatTensor(x).to(device) + if len(x.shape) == 2: + x = x.unsqueeze(0) + + y_g_hat = generator(x) + + audio = y_g_hat.squeeze() + audio = audio * MAX_WAV_VALUE + audio = audio.cpu().numpy().astype("int16") + + output_file = os.path.join(a.output_dir, os.path.splitext(filname)[0] + "_generated_e2e.wav") + write(output_file, h.sampling_rate, audio) + print(output_file) + + +def main(): + print("Initializing Inference Process..") + + parser = argparse.ArgumentParser() + parser.add_argument("--input_mels_dir", default="test_mel_files") + parser.add_argument("--output_dir", default="generated_files_from_mel") + parser.add_argument("--checkpoint_file", required=True) + parser.add_argument("--use_cuda_kernel", action="store_true", default=False) + + a = parser.parse_args() + + config_file = os.path.join(os.path.split(a.checkpoint_file)[0], "config.json") + with open(config_file) as f: + data = f.read() + + global h + json_config = json.loads(data) + h = AttrDict(json_config) + + torch.manual_seed(h.seed) + global device + if torch.cuda.is_available(): + torch.cuda.manual_seed(h.seed) + device = torch.device("cuda") + else: + device = torch.device("cpu") + + inference(a, h) + + +if __name__ == "__main__": + main() diff --git a/BigVGAN/loss.py b/BigVGAN/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..c295a144ff7bcfc0d91d9d4676bedfa7015cdb79 --- /dev/null +++ b/BigVGAN/loss.py @@ -0,0 +1,238 @@ +# Copyright (c) 2024 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/jik876/hifi-gan under the MIT license. +# LICENSE is in incl_licenses directory. + + +import torch +import torch.nn as nn +from librosa.filters import mel as librosa_mel_fn +from scipy import signal + +import typing +from typing import List, Tuple +from collections import namedtuple +import math +import functools + + +# Adapted from https://github.com/descriptinc/descript-audio-codec/blob/main/dac/nn/loss.py under the MIT license. +# LICENSE is in incl_licenses directory. +class MultiScaleMelSpectrogramLoss(nn.Module): + """Compute distance between mel spectrograms. Can be used + in a multi-scale way. 
+ + Parameters + ---------- + n_mels : List[int] + Number of mels per STFT, by default [5, 10, 20, 40, 80, 160, 320], + window_lengths : List[int], optional + Length of each window of each STFT, by default [32, 64, 128, 256, 512, 1024, 2048] + loss_fn : typing.Callable, optional + How to compare each loss, by default nn.L1Loss() + clamp_eps : float, optional + Clamp on the log magnitude, below, by default 1e-5 + mag_weight : float, optional + Weight of raw magnitude portion of loss, by default 0.0 (no ampliciation on mag part) + log_weight : float, optional + Weight of log magnitude portion of loss, by default 1.0 + pow : float, optional + Power to raise magnitude to before taking log, by default 1.0 + weight : float, optional + Weight of this loss, by default 1.0 + match_stride : bool, optional + Whether to match the stride of convolutional layers, by default False + + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py + Additional code copied and modified from https://github.com/descriptinc/audiotools/blob/master/audiotools/core/audio_signal.py + """ + + def __init__( + self, + sampling_rate: int, + n_mels: List[int] = [5, 10, 20, 40, 80, 160, 320], + window_lengths: List[int] = [32, 64, 128, 256, 512, 1024, 2048], + loss_fn: typing.Callable = nn.L1Loss(), + clamp_eps: float = 1e-5, + mag_weight: float = 0.0, + log_weight: float = 1.0, + pow: float = 1.0, + weight: float = 1.0, + match_stride: bool = False, + mel_fmin: List[float] = [0, 0, 0, 0, 0, 0, 0], + mel_fmax: List[float] = [None, None, None, None, None, None, None], + window_type: str = "hann", + ): + super().__init__() + self.sampling_rate = sampling_rate + + STFTParams = namedtuple( + "STFTParams", + ["window_length", "hop_length", "window_type", "match_stride"], + ) + + self.stft_params = [ + STFTParams( + window_length=w, + hop_length=w // 4, + match_stride=match_stride, + window_type=window_type, + ) + for w in window_lengths + ] + self.n_mels = n_mels + self.loss_fn = loss_fn + self.clamp_eps = clamp_eps + self.log_weight = log_weight + self.mag_weight = mag_weight + self.weight = weight + self.mel_fmin = mel_fmin + self.mel_fmax = mel_fmax + self.pow = pow + + @staticmethod + @functools.lru_cache(None) + def get_window( + window_type, + window_length, + ): + return signal.get_window(window_type, window_length) + + @staticmethod + @functools.lru_cache(None) + def get_mel_filters(sr, n_fft, n_mels, fmin, fmax): + return librosa_mel_fn(sr=sr, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax) + + def mel_spectrogram( + self, + wav, + n_mels, + fmin, + fmax, + window_length, + hop_length, + match_stride, + window_type, + ): + """ + Mirrors AudioSignal.mel_spectrogram used by BigVGAN-v2 training from: + https://github.com/descriptinc/audiotools/blob/master/audiotools/core/audio_signal.py + """ + B, C, T = wav.shape + + if match_stride: + assert hop_length == window_length // 4, "For match_stride, hop must equal n_fft // 4" + right_pad = math.ceil(T / hop_length) * hop_length - T + pad = (window_length - hop_length) // 2 + else: + right_pad = 0 + pad = 0 + + wav = torch.nn.functional.pad(wav, (pad, pad + right_pad), mode="reflect") + + window = self.get_window(window_type, window_length) + window = torch.from_numpy(window).to(wav.device).float() + + stft = torch.stft( + wav.reshape(-1, T), + n_fft=window_length, + hop_length=hop_length, + window=window, + return_complex=True, + center=True, + ) + _, nf, nt = stft.shape + stft 
= stft.reshape(B, C, nf, nt) + if match_stride: + """ + Drop first two and last two frames, which are added, because of padding. Now num_frames * hop_length = num_samples. + """ + stft = stft[..., 2:-2] + magnitude = torch.abs(stft) + + nf = magnitude.shape[2] + mel_basis = self.get_mel_filters(self.sampling_rate, 2 * (nf - 1), n_mels, fmin, fmax) + mel_basis = torch.from_numpy(mel_basis).to(wav.device) + mel_spectrogram = magnitude.transpose(2, -1) @ mel_basis.T + mel_spectrogram = mel_spectrogram.transpose(-1, 2) + + return mel_spectrogram + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """Computes mel loss between an estimate and a reference + signal. + + Parameters + ---------- + x : torch.Tensor + Estimate signal + y : torch.Tensor + Reference signal + + Returns + ------- + torch.Tensor + Mel loss. + """ + + loss = 0.0 + for n_mels, fmin, fmax, s in zip(self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params): + kwargs = { + "n_mels": n_mels, + "fmin": fmin, + "fmax": fmax, + "window_length": s.window_length, + "hop_length": s.hop_length, + "match_stride": s.match_stride, + "window_type": s.window_type, + } + + x_mels = self.mel_spectrogram(x, **kwargs) + y_mels = self.mel_spectrogram(y, **kwargs) + x_logmels = torch.log(x_mels.clamp(min=self.clamp_eps).pow(self.pow)) / torch.log(torch.tensor(10.0)) + y_logmels = torch.log(y_mels.clamp(min=self.clamp_eps).pow(self.pow)) / torch.log(torch.tensor(10.0)) + + loss += self.log_weight * self.loss_fn(x_logmels, y_logmels) + loss += self.mag_weight * self.loss_fn(x_logmels, y_logmels) + + return loss + + +# Loss functions +def feature_loss(fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]) -> torch.Tensor: + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss += torch.mean(torch.abs(rl - gl)) + + return loss * 2 # This equates to lambda=2.0 for the feature matching loss + + +def discriminator_loss( + disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor] +) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]: + loss = 0 + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean((1 - dr) ** 2) + g_loss = torch.mean(dg**2) + loss += r_loss + g_loss + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + + +def generator_loss( + disc_outputs: List[torch.Tensor], +) -> Tuple[torch.Tensor, List[torch.Tensor]]: + loss = 0 + gen_losses = [] + for dg in disc_outputs: + l = torch.mean((1 - dg) ** 2) + gen_losses.append(l) + loss += l + + return loss, gen_losses diff --git a/BigVGAN/meldataset.py b/BigVGAN/meldataset.py new file mode 100644 index 0000000000000000000000000000000000000000..dc12c9874cfb9958d6f4842cc067ffda66a390eb --- /dev/null +++ b/BigVGAN/meldataset.py @@ -0,0 +1,370 @@ +# Copyright (c) 2024 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/jik876/hifi-gan under the MIT license. +# LICENSE is in incl_licenses directory. 
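+# meldataset.py: mel-spectrogram utilities (librosa mel filterbank + torch.stft) and the MelDataset
+# class that pairs waveform chunks with mel spectrograms for BigVGAN training and validation.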
+ +import math +import os +import random +import torch +import torch.utils.data +import numpy as np +import librosa +from librosa.filters import mel as librosa_mel_fn +import pathlib +from tqdm import tqdm +from typing import List, Tuple, Optional +from .env import AttrDict + +MAX_WAV_VALUE = 32767.0 # NOTE: 32768.0 -1 to prevent int16 overflow (results in popping sound in corner cases) + + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + + +def dynamic_range_decompression(x, C=1): + return np.exp(x) / C + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + return dynamic_range_compression_torch(magnitudes) + + +def spectral_de_normalize_torch(magnitudes): + return dynamic_range_decompression_torch(magnitudes) + + +mel_basis_cache = {} +hann_window_cache = {} + + +def mel_spectrogram( + y: torch.Tensor, + n_fft: int, + num_mels: int, + sampling_rate: int, + hop_size: int, + win_size: int, + fmin: int, + fmax: int = None, + center: bool = False, +) -> torch.Tensor: + """ + Calculate the mel spectrogram of an input signal. + This function uses slaney norm for the librosa mel filterbank (using librosa.filters.mel) and uses Hann window for STFT (using torch.stft). + + Args: + y (torch.Tensor): Input signal. + n_fft (int): FFT size. + num_mels (int): Number of mel bins. + sampling_rate (int): Sampling rate of the input signal. + hop_size (int): Hop size for STFT. + win_size (int): Window size for STFT. + fmin (int): Minimum frequency for mel filterbank. + fmax (int): Maximum frequency for mel filterbank. If None, defaults to half the sampling rate (fmax = sr / 2.0) inside librosa_mel_fn + center (bool): Whether to pad the input to center the frames. Default is False. + + Returns: + torch.Tensor: Mel spectrogram. + """ + if torch.min(y) < -1.0: + print(f"[WARNING] Min value of input waveform signal is {torch.min(y)}") + if torch.max(y) > 1.0: + print(f"[WARNING] Max value of input waveform signal is {torch.max(y)}") + + device = y.device + key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}" + + if key not in mel_basis_cache: + mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) + mel_basis_cache[key] = torch.from_numpy(mel).float().to(device) + hann_window_cache[key] = torch.hann_window(win_size).to(device) + + mel_basis = mel_basis_cache[key] + hann_window = hann_window_cache[key] + + padding = (n_fft - hop_size) // 2 + y = torch.nn.functional.pad(y.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1) + + spec = torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window, + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9) + + mel_spec = torch.matmul(mel_basis, spec) + mel_spec = spectral_normalize_torch(mel_spec) + + return mel_spec + + +def get_mel_spectrogram(wav, h): + """ + Generate mel spectrogram from a waveform using given hyperparameters. + + Args: + wav (torch.Tensor): Input waveform. + h: Hyperparameters object with attributes n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax. + + Returns: + torch.Tensor: Mel spectrogram. 
+ """ + return mel_spectrogram( + wav, + h.n_fft, + h.num_mels, + h.sampling_rate, + h.hop_size, + h.win_size, + h.fmin, + h.fmax, + ) + + +def get_dataset_filelist(a): + training_files = [] + validation_files = [] + list_unseen_validation_files = [] + + with open(a.input_training_file, "r", encoding="utf-8") as fi: + training_files = [ + os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0 + ] + print(f"first training file: {training_files[0]}") + + with open(a.input_validation_file, "r", encoding="utf-8") as fi: + validation_files = [ + os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0 + ] + print(f"first validation file: {validation_files[0]}") + + for i in range(len(a.list_input_unseen_validation_file)): + with open(a.list_input_unseen_validation_file[i], "r", encoding="utf-8") as fi: + unseen_validation_files = [ + os.path.join(a.list_input_unseen_wavs_dir[i], x.split("|")[0] + ".wav") + for x in fi.read().split("\n") + if len(x) > 0 + ] + print(f"first unseen {i}th validation fileset: {unseen_validation_files[0]}") + list_unseen_validation_files.append(unseen_validation_files) + + return training_files, validation_files, list_unseen_validation_files + + +class MelDataset(torch.utils.data.Dataset): + def __init__( + self, + training_files: List[str], + hparams: AttrDict, + segment_size: int, + n_fft: int, + num_mels: int, + hop_size: int, + win_size: int, + sampling_rate: int, + fmin: int, + fmax: Optional[int], + split: bool = True, + shuffle: bool = True, + device: str = None, + fmax_loss: Optional[int] = None, + fine_tuning: bool = False, + base_mels_path: str = None, + is_seen: bool = True, + ): + self.audio_files = training_files + random.seed(1234) + if shuffle: + random.shuffle(self.audio_files) + self.hparams = hparams + self.is_seen = is_seen + if self.is_seen: + self.name = pathlib.Path(self.audio_files[0]).parts[0] + else: + self.name = "-".join(pathlib.Path(self.audio_files[0]).parts[:2]).strip("/") + + self.segment_size = segment_size + self.sampling_rate = sampling_rate + self.split = split + self.n_fft = n_fft + self.num_mels = num_mels + self.hop_size = hop_size + self.win_size = win_size + self.fmin = fmin + self.fmax = fmax + self.fmax_loss = fmax_loss + self.device = device + self.fine_tuning = fine_tuning + self.base_mels_path = base_mels_path + + print("[INFO] checking dataset integrity...") + for i in tqdm(range(len(self.audio_files))): + assert os.path.exists(self.audio_files[i]), f"{self.audio_files[i]} not found" + + def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, str, torch.Tensor]: + try: + filename = self.audio_files[index] + + # Use librosa.load that ensures loading waveform into mono with [-1, 1] float values + # Audio is ndarray with shape [T_time]. 
Disable auto-resampling here to minimize overhead + # The on-the-fly resampling during training will be done only for the obtained random chunk + audio, source_sampling_rate = librosa.load(filename, sr=None, mono=True) + + # Main logic that uses pair for training BigVGAN + if not self.fine_tuning: + if self.split: # Training step + # Obtain randomized audio chunk + if source_sampling_rate != self.sampling_rate: + # Adjust segment size to crop if the source sr is different + target_segment_size = math.ceil(self.segment_size * (source_sampling_rate / self.sampling_rate)) + else: + target_segment_size = self.segment_size + + # Compute upper bound index for the random chunk + random_chunk_upper_bound = max(0, audio.shape[0] - target_segment_size) + + # Crop or pad audio to obtain random chunk with target_segment_size + if audio.shape[0] >= target_segment_size: + audio_start = random.randint(0, random_chunk_upper_bound) + audio = audio[audio_start : audio_start + target_segment_size] + else: + audio = np.pad( + audio, + (0, target_segment_size - audio.shape[0]), + mode="constant", + ) + + # Resample audio chunk to self.sampling rate + if source_sampling_rate != self.sampling_rate: + audio = librosa.resample( + audio, + orig_sr=source_sampling_rate, + target_sr=self.sampling_rate, + ) + if audio.shape[0] > self.segment_size: + # trim last elements to match self.segment_size (e.g., 16385 for 44khz downsampled to 24khz -> 16384) + audio = audio[: self.segment_size] + + else: # Validation step + # Resample full audio clip to target sampling rate + if source_sampling_rate != self.sampling_rate: + audio = librosa.resample( + audio, + orig_sr=source_sampling_rate, + target_sr=self.sampling_rate, + ) + # Trim last elements to match audio length to self.hop_size * n for evaluation + if (audio.shape[0] % self.hop_size) != 0: + audio = audio[: -(audio.shape[0] % self.hop_size)] + + # BigVGAN is trained using volume-normalized waveform + audio = librosa.util.normalize(audio) * 0.95 + + # Cast ndarray to torch tensor + audio = torch.FloatTensor(audio) + audio = audio.unsqueeze(0) # [B(1), self.segment_size] + + # Compute mel spectrogram corresponding to audio + mel = mel_spectrogram( + audio, + self.n_fft, + self.num_mels, + self.sampling_rate, + self.hop_size, + self.win_size, + self.fmin, + self.fmax, + center=False, + ) # [B(1), self.num_mels, self.segment_size // self.hop_size] + + # Fine-tuning logic that uses pre-computed mel. 
Example: Using TTS model-generated mel as input + else: + # For fine-tuning, assert that the waveform is in the defined sampling_rate + # Fine-tuning won't support on-the-fly resampling to be fool-proof (the dataset should have been prepared properly) + assert source_sampling_rate == self.sampling_rate, ( + f"For fine_tuning, waveform must be in the spcified sampling rate {self.sampling_rate}, got {source_sampling_rate}" + ) + + # Cast ndarray to torch tensor + audio = torch.FloatTensor(audio) + audio = audio.unsqueeze(0) # [B(1), T_time] + + # Load pre-computed mel from disk + mel = np.load( + os.path.join( + self.base_mels_path, + os.path.splitext(os.path.split(filename)[-1])[0] + ".npy", + ) + ) + mel = torch.from_numpy(mel) + + if len(mel.shape) < 3: + mel = mel.unsqueeze(0) # ensure [B, C, T] + + if self.split: + frames_per_seg = math.ceil(self.segment_size / self.hop_size) + + if audio.size(1) >= self.segment_size: + mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1) + mel = mel[:, :, mel_start : mel_start + frames_per_seg] + audio = audio[ + :, + mel_start * self.hop_size : (mel_start + frames_per_seg) * self.hop_size, + ] + + # Pad pre-computed mel and audio to match length to ensuring fine-tuning without error. + # NOTE: this may introduce a single-frame misalignment of the + # To remove possible misalignment, it is recommended to prepare the pair where the audio length is the integer multiple of self.hop_size + mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), "constant") + audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant") + + # Compute mel_loss used by spectral regression objective. Uses self.fmax_loss instead (usually None) + mel_loss = mel_spectrogram( + audio, + self.n_fft, + self.num_mels, + self.sampling_rate, + self.hop_size, + self.win_size, + self.fmin, + self.fmax_loss, + center=False, + ) # [B(1), self.num_mels, self.segment_size // self.hop_size] + + # Shape sanity checks + assert ( + audio.shape[1] == mel.shape[2] * self.hop_size and audio.shape[1] == mel_loss.shape[2] * self.hop_size + ), ( + f"Audio length must be mel frame length * hop_size. Got audio shape {audio.shape} mel shape {mel.shape} mel_loss shape {mel_loss.shape}" + ) + + return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze()) + + # If it encounters error during loading the data, skip this sample and load random other sample to the batch + except Exception as e: + if self.fine_tuning: + raise e # Terminate training if it is fine-tuning. The dataset should have been prepared properly. + else: + print(f"[WARNING] Failed to load waveform, skipping! 
filename: {filename} Error: {e}") + return self[random.randrange(len(self))] + + def __len__(self): + return len(self.audio_files) diff --git a/BigVGAN/nv-modelcard++/.gitkeep b/BigVGAN/nv-modelcard++/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/BigVGAN/nv-modelcard++/.gitkeep @@ -0,0 +1 @@ + diff --git a/BigVGAN/nv-modelcard++/bias.md b/BigVGAN/nv-modelcard++/bias.md new file mode 100644 index 0000000000000000000000000000000000000000..4b388c28d09b8ca3aab5096304c52e1a5dac0e16 --- /dev/null +++ b/BigVGAN/nv-modelcard++/bias.md @@ -0,0 +1,4 @@ +| Field | Response | +| :--------------------------------------------------------------------------------------------------------- | :--------------------------------------------------- | +| Participation considerations from adversely impacted groups protected classes in model design and testing: | None | +| Measures taken to mitigate against unwanted bias: | No measures taken to mitigate against unwanted bias. | diff --git a/BigVGAN/nv-modelcard++/explainability.md b/BigVGAN/nv-modelcard++/explainability.md new file mode 100644 index 0000000000000000000000000000000000000000..6f1a16676e438ba95f9d411a19e04a0f13409e54 --- /dev/null +++ b/BigVGAN/nv-modelcard++/explainability.md @@ -0,0 +1,13 @@ +| Field | Response | +| :---------------------------------------------------------------------------------------------------- | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| Intended Application & Domain: | Generating waveform from mel spectrogram. | +| Model Type: | Convolutional Neural Network (CNN) | +| Intended Users: | This model is intended for developers to synthesize and generate waveforms from the AI-generated mel spectrograms. | +| Output: | Audio Waveform | +| Describe how the model works: | Model generates audio waveform corresponding to the input mel spectrogram. | +| Name the adversely impacted groups this has been tested to deliver comparable outcomes regardless of: | Not Applicable | +| Technical Limitations: | This may not perform well on synthetically-generated mel spectrograms that deviate significantly from the profile of mel spectrograms on which this was trained. | +| Verified to have met prescribed NVIDIA quality standards: | Yes | +| Performance Metrics: | Perceptual Evaluation of Speech Quality (PESQ), Virtual Speech Quality Objective Listener (VISQOL), Multi-resolution STFT (MRSTFT), Mel cepstral distortion (MCD), Periodicity RMSE, Voice/Unvoiced F1 Score (V/UV F1) | +| Potential Known Risks: | This model may generate low-quality or distorted soundwaves. | +| Licensing: | https://github.com/NVIDIA/BigVGAN/blob/main/LICENSE | diff --git a/BigVGAN/nv-modelcard++/overview.md b/BigVGAN/nv-modelcard++/overview.md new file mode 100644 index 0000000000000000000000000000000000000000..a39cba0b49a4a32a37afa90f2baf4630dcd9cadc --- /dev/null +++ b/BigVGAN/nv-modelcard++/overview.md @@ -0,0 +1,126 @@ +# Model Overview + +## Description: + +BigVGAN is a generative AI model specialized in synthesizing audio waveforms using Mel spectrogram as inputs. + +
+ +BigVGAN is a fully convolutional architecture with several upsampling blocks using transposed convolution followed by multiple residual dilated convolution layers. + +BigVGAN consists of a novel module, called anti-aliased multi-periodicity composition (AMP), which is specifically designed for generating waveforms. AMP is specialized in synthesizing high-frequency and periodic soundwaves, drawing inspiration from audio signal processing principles. + +It applies a periodic activation function, called Snake, which provides an inductive bias to the architecture in generating periodic soundwaves. It also applies anti-aliasing filters to reduce undesired artifacts in the generated waveforms.
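+
+For intuition, the Snake activation referenced above can be sketched as follows. This is a simplified, illustrative form; the exact per-channel, parameterized implementation lives in `activations.py` of this repository:
+
+```python
+import torch
+
+def snake(x: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
+    # snake(x) = x + (1/alpha) * sin^2(alpha * x)
+    # The added periodic term biases the network toward generating periodic signals.
+    return x + (1.0 / alpha) * torch.sin(alpha * x) ** 2
+```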
+ +This model is ready for commercial use.
+ +## Reference(s): + +- [BigVGAN: A Universal Neural Vocoder with Large-Scale Training](https://arxiv.org/abs/2206.04658)
+- [Project Page](https://research.nvidia.com/labs/adlr/projects/bigvgan/)
+- [Audio Demo](https://bigvgan-demo.github.io/)
+ +## Model Architecture: + +**Architecture Type:** Convolutional Neural Network (CNN)
+**Network Architecture:** Details of this model are available at https://github.com/NVIDIA/BigVGAN, and the related paper can be found at https://arxiv.org/abs/2206.04658
+**Model Version:** 2.0
+ +## Input: + +**Input Type:** Audio
+**Input Format:** Mel Spectrogram
+**Input Parameters:** None
+**Other Properties Related to Input:** The input mel spectrogram has shape `[batch, channels, frames]`, where `channels` refers to the number of mel bands defined by the model and `frames` refers to the temporal length. The model supports arbitrarily long `frames` that fit into GPU memory. + +## Output: + +**Output Type:** Audio
+**Output Format:** Audio Waveform
+**Output Parameters:** None
+**Other Properties Related to Output:** The output audio waveform has shape `[batch, 1, time]`, where `1` refers to the mono audio channel and `time` refers to the temporal length. `time` is a fixed integer multiple of the input `frames`, determined by the upsampling ratio of the model (`time = upsampling ratio * frames`). The output audio waveform consists of float values in the range `[-1, 1]`. + +## Software Integration: + +**Runtime Engine(s):** PyTorch + +**Supported Hardware Microarchitecture Compatibility:** NVIDIA Ampere, NVIDIA Hopper, NVIDIA Lovelace, NVIDIA Turing, NVIDIA Volta
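+
+As a concrete illustration of the input described above: in this repository the mel-spectrogram input is produced from a waveform with `get_mel_spectrogram` from `meldataset.py`, as done in `inference.py`. A minimal sketch, assuming `h` is the `AttrDict` loaded from the model's `config.json` and the wav path is a placeholder:
+
+```python
+import librosa
+import torch
+from meldataset import get_mel_spectrogram
+
+wav, _ = librosa.load("example.wav", sr=h.sampling_rate, mono=True)  # placeholder path
+wav = torch.FloatTensor(wav).unsqueeze(0)   # [1, T_time], float values in [-1, 1]
+mel = get_mel_spectrogram(wav, h)           # [1, num_mels, frames]
+```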
+ +## Preferred/Supported Operating System(s): + +Linux + +## Model Version(s): + +v2.0 + +## Training, Testing, and Evaluation Datasets: + +### Training Dataset: + +The dataset contains diverse audio types, including speech in multiple languages, environmental sounds, and instruments. + +**Links:** + +- [AAM: Artificial Audio Multitracks Dataset](https://zenodo.org/records/5794629) +- [AudioCaps](https://audiocaps.github.io/) +- [AudioSet](https://research.google.com/audioset/index.html) +- [common-accent](https://huggingface.co/datasets/DTU54DL/common-accent) +- [Crowd Sourced Emotional Multimodal Actors Dataset (CREMA-D)](https://ieeexplore.ieee.org/document/6849440) +- [DCASE2017 Challenge, Task 4: Large-scale weakly supervised sound event detection for smart cars](https://dcase.community/challenge2017/task-large-scale-sound-event-detection) +- [FSDnoisy18k](https://zenodo.org/records/2529934) +- [Free Universal Sound Separation Dataset](https://zenodo.org/records/3694384) +- [Greatest Hits dataset](https://andrewowens.com/vis/) +- [GTZAN](https://ieeexplore.ieee.org/document/1021072) +- [JL corpus](https://www.kaggle.com/datasets/tli725/jl-corpus) +- [Medley-solos-DB: a cross-collection dataset for musical instrument recognition](https://zenodo.org/records/3464194) +- [MUSAN: A Music, Speech, and Noise Corpus](https://www.openslr.org/17/) +- [MusicBench](https://huggingface.co/datasets/amaai-lab/MusicBench) +- [MusicCaps](https://www.kaggle.com/datasets/googleai/musiccaps) +- [MusicNet](https://www.kaggle.com/datasets/imsparsh/musicnet-dataset) +- [NSynth](https://magenta.tensorflow.org/datasets/nsynth) +- [OnAir-Music-Dataset](https://github.com/sevagh/OnAir-Music-Dataset) +- [Audio Piano Triads Dataset](https://zenodo.org/records/4740877) +- [Pitch Audio Dataset (Surge synthesizer)](https://zenodo.org/records/4677097) +- [SONYC Urban Sound Tagging (SONYC-UST): a multilabel dataset from an urban acoustic sensor network](https://zenodo.org/records/3966543) +- [VocalSound: A Dataset for Improving Human Vocal Sounds Recognition](https://arxiv.org/abs/2205.03433) +- [WavText5K](https://github.com/microsoft/WavText5K) +- [CSS10: A Collection of Single Speaker Speech Datasets for 10 Languages](https://github.com/Kyubyong/css10) +- [Hi-Fi Multi-Speaker English TTS Dataset (Hi-Fi TTS)](https://www.openslr.org/109/) +- [IIIT-H Indic Speech Databases](http://festvox.org/databases/iiit_voices/) +- [Libri-Light: A Benchmark for ASR with Limited or No Supervision](https://arxiv.org/abs/1912.07875) +- [LibriTTS: A Corpus Derived from LibriSpeech for Text-to-Speech](https://www.openslr.org/60) +- [LibriTTS-R: A Restored Multi-Speaker Text-to-Speech Corpus](https://www.openslr.org/141/) +- [The SIWIS French Speech Synthesis Database](https://datashare.ed.ac.uk/handle/10283/2353) +- [Crowdsourced high-quality Colombian Spanish speech data set](https://openslr.org/72/) +- [TTS-Portuguese Corpus](https://github.com/Edresson/TTS-Portuguese-Corpus) +- [CSTR VCTK Corpus: English Multi-speaker Corpus for CSTR Voice Cloning Toolkit](https://datashare.ed.ac.uk/handle/10283/3443) + +\*\* Data Collection Method by dataset
+ +- Human
+ +\*\* Labeling Method by dataset (for those with labels)
+ +- Hybrid: Automated, Human, Unknown
+ +### Evaluating Dataset: + +Properties: The audio generation quality of BigVGAN is evaluated using the `dev` splits of the [LibriTTS dataset](https://www.openslr.org/60/) and the [Hi-Fi TTS dataset](https://www.openslr.org/109/). The datasets include English speech with an equal balance of genders. + +\*\* Data Collection Method by dataset
+ +- Human
+ +\*\* Labeling Method by dataset
+ +- Automated
+ +## Inference: + +**Engine:** PyTorch
+**Test Hardware:** NVIDIA A100 GPU
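+
+For reference, a condensed sketch of the inference path implemented by `inference.py` / `inference_e2e.py` above. Checkpoint and config paths are placeholders, and a random mel is used purely as a shape check:
+
+```python
+import json
+import torch
+from env import AttrDict
+from bigvgan import BigVGAN
+
+with open("config.json") as f:                   # config.json ships next to the checkpoint
+    h = AttrDict(json.loads(f.read()))
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+generator = BigVGAN(h).to(device)
+state_dict = torch.load("g_05000000", map_location=device)  # placeholder checkpoint file
+generator.load_state_dict(state_dict["generator"])
+generator.eval()
+generator.remove_weight_norm()
+
+mel = torch.randn(1, h.num_mels, 128, device=device)  # [batch, num_mels, frames]
+with torch.inference_mode():
+    wav = generator(mel)                               # [batch, 1, frames * upsampling ratio]
+```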
+ +## Ethical Considerations: + +NVIDIA believes Trustworthy AI is a shared responsibility and we have established policies and practices to enable development for a wide array of AI applications. When downloaded or used in accordance with our terms of service, developers should work with their internal model team to ensure this model meets requirements for the relevant industry and use case and addresses unforeseen product misuse. For more detailed information on ethical considerations for this model, please see the Model Card++ Explainability, Bias, Safety & Security, and Privacy Subcards. Please report security vulnerabilities or NVIDIA AI Concerns [here](https://www.nvidia.com/en-us/support/submit-security-vulnerability/). diff --git a/BigVGAN/nv-modelcard++/privacy.md b/BigVGAN/nv-modelcard++/privacy.md new file mode 100644 index 0000000000000000000000000000000000000000..73554a998384ca1b1050239ebd51bda46aec1878 --- /dev/null +++ b/BigVGAN/nv-modelcard++/privacy.md @@ -0,0 +1,14 @@ +| Field | Response | +| :------------------------------------------------------------------------------------------------------------------------------------- | :--------------------------------------------- | +| Generatable or reverse engineerable personal information? | None | +| Protected class data used to create this model? | None | +| Was consent obtained for any personal data used? | Not Applicable (No Personal Data) | +| How often is dataset reviewed? | Before Release | +| Is a mechanism in place to honor data subject right of access or deletion of personal data? | Not Applicable | +| If personal collected for the development of the model, was it collected directly by NVIDIA? | Not Applicable | +| If personal collected for the development of the model by NVIDIA, do you maintain or have access to disclosures made to data subjects? | Not Applicable | +| If personal collected for the development of this AI model, was it minimized to only what was required? | Not Applicable | +| Is data in dataset traceable? | Yes | +| Is there provenance for all datasets used in training? | Yes | +| Does data labeling (annotation, metadata) comply with privacy laws? | Yes | +| Is data compliant with data subject requests for data correction or removal, if such a request was made? | No, not possible with externally-sourced data. | diff --git a/BigVGAN/nv-modelcard++/safety.md b/BigVGAN/nv-modelcard++/safety.md new file mode 100644 index 0000000000000000000000000000000000000000..ed30370dfedbbb49748706034a7153d54f1a668f --- /dev/null +++ b/BigVGAN/nv-modelcard++/safety.md @@ -0,0 +1,6 @@ +| Field | Response | +| :---------------------------------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| Model Application(s): | Synethic Audio Generation | +| Describe the life critical impact (if present). | Not Applicable | +| Use Case Restrictions: | None | +| Model and dataset restrictions: | The Principle of least privilege (PoLP) is applied limiting access for dataset generation and model development. Restrictions enforce dataset access during training, and dataset license constraints adhered to. 
| diff --git a/BigVGAN/requirements.txt b/BigVGAN/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..6e61d3203966612e6ad193bbabdef10b1d3fed84 --- /dev/null +++ b/BigVGAN/requirements.txt @@ -0,0 +1,13 @@ +torch +numpy +librosa>=0.8.1 +scipy +tensorboard +soundfile +matplotlib +pesq +auraloss +tqdm +nnAudio +ninja +huggingface_hub>=0.23.4 \ No newline at end of file diff --git a/BigVGAN/tests/test_activation.py b/BigVGAN/tests/test_activation.py new file mode 100644 index 0000000000000000000000000000000000000000..4134883540e472afb9b79972dd5e1cd36bee0e04 --- /dev/null +++ b/BigVGAN/tests/test_activation.py @@ -0,0 +1,62 @@ +# Copyright (c) 2024 NVIDIA CORPORATION. +# Licensed under the MIT license. + +import os +import sys + +# to import modules from parent_dir +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.append(parent_dir) + +import torch +from alias_free_activation.cuda import activation1d +from activations import Snake + + +def test_load_fused_kernels(): + try: + print("[Success] load_fused_kernels") + except ImportError as e: + print("[Fail] load_fused_kernels") + raise e + + +def test_anti_alias_activation(): + data = torch.rand((10, 10, 200), device="cuda") + + # Check activations.Snake cuda vs. torch + fused_anti_alias_activation = activation1d.Activation1d(activation=Snake(10), fused=True).cuda() + fused_activation_output = fused_anti_alias_activation(data) + + torch_anti_alias_activation = activation1d.Activation1d(activation=Snake(10), fused=False).cuda() + torch_activation_output = torch_anti_alias_activation(data) + + test_result = (fused_activation_output - torch_activation_output).abs() + + while test_result.dim() != 1: + test_result = test_result.mean(dim=-1) + + diff = test_result.mean(dim=-1) + + if diff <= 1e-3: + print( + f"\n[Success] test_fused_anti_alias_activation" + f"\n > mean_difference={diff}" + f"\n > fused_values={fused_activation_output[-1][-1][:].tolist()}" + f"\n > torch_values={torch_activation_output[-1][-1][:].tolist()}" + ) + else: + print( + f"\n[Fail] test_fused_anti_alias_activation" + f"\n > mean_difference={diff}, " + f"\n > fused_values={fused_activation_output[-1][-1][:].tolist()}, " + f"\n > torch_values={torch_activation_output[-1][-1][:].tolist()}" + ) + + +if __name__ == "__main__": + from alias_free_activation.cuda import load + + load.load() + test_load_fused_kernels() + test_anti_alias_activation() diff --git a/BigVGAN/tests/test_activation_snake_beta.py b/BigVGAN/tests/test_activation_snake_beta.py new file mode 100644 index 0000000000000000000000000000000000000000..4cc46b98ff0e91ddbaa025aa7e86afa828bde71f --- /dev/null +++ b/BigVGAN/tests/test_activation_snake_beta.py @@ -0,0 +1,62 @@ +# Copyright (c) 2024 NVIDIA CORPORATION. +# Licensed under the MIT license. + +import os +import sys + +# to import modules from parent_dir +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.append(parent_dir) + +import torch +from alias_free_activation.cuda import activation1d +from activations import SnakeBeta + + +def test_load_fused_kernels(): + try: + print("[Success] load_fused_kernels") + except ImportError as e: + print("[Fail] load_fused_kernels") + raise e + + +def test_anti_alias_activation(): + data = torch.rand((10, 10, 200), device="cuda") + + # Check activations, Snake CUDA vs. 
Torch + fused_anti_alias_activation = activation1d.Activation1d(activation=SnakeBeta(10), fused=True).cuda() + fused_activation_output = fused_anti_alias_activation(data) + + torch_anti_alias_activation = activation1d.Activation1d(activation=SnakeBeta(10), fused=False).cuda() + torch_activation_output = torch_anti_alias_activation(data) + + test_result = (fused_activation_output - torch_activation_output).abs() + + while test_result.dim() != 1: + test_result = test_result.mean(dim=-1) + + diff = test_result.mean(dim=-1) + + if diff <= 1e-3: + print( + f"\n[Success] test_fused_anti_alias_activation" + f"\n > mean_difference={diff}" + f"\n > fused_values={fused_activation_output[-1][-1][:].tolist()}" + f"\n > torch_values={torch_activation_output[-1][-1][:].tolist()}" + ) + else: + print( + f"\n[Fail] test_fused_anti_alias_activation" + f"\n > mean_difference={diff}, " + f"\n > fused_values={fused_activation_output[-1][-1][:].tolist()}, " + f"\n > torch_values={torch_activation_output[-1][-1][:].tolist()}" + ) + + +if __name__ == "__main__": + from alias_free_activation.cuda import load + + load.load() + test_load_fused_kernels() + test_anti_alias_activation() diff --git a/BigVGAN/tests/test_cuda_vs_torch_model.py b/BigVGAN/tests/test_cuda_vs_torch_model.py new file mode 100644 index 0000000000000000000000000000000000000000..8ddb29e56c32eb20533a3022949f9487055380d9 --- /dev/null +++ b/BigVGAN/tests/test_cuda_vs_torch_model.py @@ -0,0 +1,215 @@ +# Copyright (c) 2024 NVIDIA CORPORATION. +# Licensed under the MIT license. + +import os +import sys + +# to import modules from parent_dir +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.append(parent_dir) + +import torch +import json +from env import AttrDict +from bigvgan import BigVGAN +from time import time +from tqdm import tqdm +from meldataset import mel_spectrogram, MAX_WAV_VALUE +from scipy.io.wavfile import write +import numpy as np + +import argparse + +torch.backends.cudnn.benchmark = True + +# For easier debugging +torch.set_printoptions(linewidth=200, threshold=10_000) + + +def generate_soundwave(duration=5.0, sr=24000): + t = np.linspace(0, duration, int(sr * duration), False, dtype=np.float32) + + modulation = np.sin(2 * np.pi * t / duration) + + min_freq = 220 + max_freq = 1760 + frequencies = min_freq + (max_freq - min_freq) * (modulation + 1) / 2 + soundwave = np.sin(2 * np.pi * frequencies * t) + + soundwave = soundwave / np.max(np.abs(soundwave)) * 0.95 + + return soundwave, sr + + +def get_mel(x, h): + return mel_spectrogram(x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax) + + +def load_checkpoint(filepath, device): + assert os.path.isfile(filepath) + print(f"Loading '{filepath}'") + checkpoint_dict = torch.load(filepath, map_location=device) + print("Complete.") + return checkpoint_dict + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Test script to check CUDA kernel correctness.") + parser.add_argument( + "--checkpoint_file", + type=str, + required=True, + help="Path to the checkpoint file. 
Assumes config.json exists in the directory.", + ) + + args = parser.parse_args() + + config_file = os.path.join(os.path.split(args.checkpoint_file)[0], "config.json") + with open(config_file) as f: + config = f.read() + json_config = json.loads(config) + h = AttrDict({**json_config}) + + print("loading plain Pytorch BigVGAN") + generator_original = BigVGAN(h).to("cuda") + print("loading CUDA kernel BigVGAN with auto-build") + generator_cuda_kernel = BigVGAN(h, use_cuda_kernel=True).to("cuda") + + state_dict_g = load_checkpoint(args.checkpoint_file, "cuda") + generator_original.load_state_dict(state_dict_g["generator"]) + generator_cuda_kernel.load_state_dict(state_dict_g["generator"]) + + generator_original.remove_weight_norm() + generator_original.eval() + generator_cuda_kernel.remove_weight_norm() + generator_cuda_kernel.eval() + + # define number of samples and length of mel frame to benchmark + num_sample = 10 + num_mel_frame = 16384 + + # CUDA kernel correctness check + diff = 0.0 + for i in tqdm(range(num_sample)): + # Random mel + data = torch.rand((1, h.num_mels, num_mel_frame), device="cuda") + + with torch.inference_mode(): + audio_original = generator_original(data) + + with torch.inference_mode(): + audio_cuda_kernel = generator_cuda_kernel(data) + + # Both outputs should be (almost) the same + test_result = (audio_original - audio_cuda_kernel).abs() + diff += test_result.mean(dim=-1).item() + + diff /= num_sample + if diff <= 2e-3: # We can expect a small difference (~1e-3) which does not affect perceptual quality + print( + f"\n[Success] test CUDA fused vs. plain torch BigVGAN inference" + f"\n > mean_difference={diff}" + f"\n > fused_values={audio_cuda_kernel[-1][-1][-30:].tolist()}" + f"\n > torch_values={audio_original[-1][-1][-30:].tolist()}" + ) + else: + print( + f"\n[Fail] test CUDA fused vs. 
plain torch BigVGAN inference" + f"\n > mean_difference={diff}" + f"\n > fused_values={audio_cuda_kernel[-1][-1][-30:].tolist()}, " + f"\n > torch_values={audio_original[-1][-1][-30:].tolist()}" + ) + + del data, audio_original, audio_cuda_kernel + + # Variables for tracking total time and VRAM usage + toc_total_original = 0 + toc_total_cuda_kernel = 0 + vram_used_original_total = 0 + vram_used_cuda_kernel_total = 0 + audio_length_total = 0 + + # Measure Original inference in isolation + for i in tqdm(range(num_sample)): + torch.cuda.reset_peak_memory_stats(device="cuda") + data = torch.rand((1, h.num_mels, num_mel_frame), device="cuda") + torch.cuda.synchronize() + tic = time() + with torch.inference_mode(): + audio_original = generator_original(data) + torch.cuda.synchronize() + toc = time() - tic + toc_total_original += toc + + vram_used_original_total += torch.cuda.max_memory_allocated(device="cuda") + + del data, audio_original + torch.cuda.empty_cache() + + # Measure CUDA kernel inference in isolation + for i in tqdm(range(num_sample)): + torch.cuda.reset_peak_memory_stats(device="cuda") + data = torch.rand((1, h.num_mels, num_mel_frame), device="cuda") + torch.cuda.synchronize() + tic = time() + with torch.inference_mode(): + audio_cuda_kernel = generator_cuda_kernel(data) + torch.cuda.synchronize() + toc = time() - tic + toc_total_cuda_kernel += toc + + audio_length_total += audio_cuda_kernel.shape[-1] + + vram_used_cuda_kernel_total += torch.cuda.max_memory_allocated(device="cuda") + + del data, audio_cuda_kernel + torch.cuda.empty_cache() + + # Calculate metrics + audio_second = audio_length_total / h.sampling_rate + khz_original = audio_length_total / toc_total_original / 1000 + khz_cuda_kernel = audio_length_total / toc_total_cuda_kernel / 1000 + vram_used_original_gb = vram_used_original_total / num_sample / (1024**3) + vram_used_cuda_kernel_gb = vram_used_cuda_kernel_total / num_sample / (1024**3) + + # Print results + print( + f"Original BigVGAN: took {toc_total_original:.2f} seconds to generate {audio_second:.2f} seconds of audio, {khz_original:.1f}kHz, {audio_second / toc_total_original:.1f} faster than realtime, VRAM used {vram_used_original_gb:.1f} GB" + ) + print( + f"CUDA kernel BigVGAN: took {toc_total_cuda_kernel:.2f} seconds to generate {audio_second:.2f} seconds of audio, {khz_cuda_kernel:.1f}kHz, {audio_second / toc_total_cuda_kernel:.1f} faster than realtime, VRAM used {vram_used_cuda_kernel_gb:.1f} GB" + ) + print(f"speedup of CUDA kernel: {khz_cuda_kernel / khz_original}") + print(f"VRAM saving of CUDA kernel: {vram_used_original_gb / vram_used_cuda_kernel_gb}") + + # Use artificial sine waves for inference test + audio_real, sr = generate_soundwave(duration=5.0, sr=h.sampling_rate) + audio_real = torch.tensor(audio_real).to("cuda") + # Compute mel spectrogram from the ground truth audio + x = get_mel(audio_real.unsqueeze(0), h) + + with torch.inference_mode(): + y_g_hat_original = generator_original(x) + y_g_hat_cuda_kernel = generator_cuda_kernel(x) + + audio_real = audio_real.squeeze() + audio_real = audio_real * MAX_WAV_VALUE + audio_real = audio_real.cpu().numpy().astype("int16") + + audio_original = y_g_hat_original.squeeze() + audio_original = audio_original * MAX_WAV_VALUE + audio_original = audio_original.cpu().numpy().astype("int16") + + audio_cuda_kernel = y_g_hat_cuda_kernel.squeeze() + audio_cuda_kernel = audio_cuda_kernel * MAX_WAV_VALUE + audio_cuda_kernel = audio_cuda_kernel.cpu().numpy().astype("int16") + + os.makedirs("tmp", exist_ok=True) + 
output_file_real = os.path.join("tmp", "audio_real.wav") + output_file_original = os.path.join("tmp", "audio_generated_original.wav") + output_file_cuda_kernel = os.path.join("tmp", "audio_generated_cuda_kernel.wav") + write(output_file_real, h.sampling_rate, audio_real) + write(output_file_original, h.sampling_rate, audio_original) + write(output_file_cuda_kernel, h.sampling_rate, audio_cuda_kernel) + print("Example generated audios of original vs. fused CUDA kernel written to tmp!") + print("Done") diff --git a/BigVGAN/train.py b/BigVGAN/train.py new file mode 100644 index 0000000000000000000000000000000000000000..39718cdb33d2e9a88ec9b98dd2032bdce83a4231 --- /dev/null +++ b/BigVGAN/train.py @@ -0,0 +1,716 @@ +# Copyright (c) 2024 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/jik876/hifi-gan under the MIT license. +# LICENSE is in incl_licenses directory. + + +import warnings + +warnings.simplefilter(action="ignore", category=FutureWarning) +import itertools +import os +import time +import argparse +import json +import torch +import torch.nn.functional as F +from torch.utils.tensorboard import SummaryWriter +from torch.utils.data import DistributedSampler, DataLoader +import torch.multiprocessing as mp +from torch.distributed import init_process_group +from torch.nn.parallel import DistributedDataParallel +from env import AttrDict, build_env +from meldataset import MelDataset, mel_spectrogram, get_dataset_filelist, MAX_WAV_VALUE + +from bigvgan import BigVGAN +from discriminators import ( + MultiPeriodDiscriminator, + MultiResolutionDiscriminator, + MultiBandDiscriminator, + MultiScaleSubbandCQTDiscriminator, +) +from loss import ( + feature_loss, + generator_loss, + discriminator_loss, + MultiScaleMelSpectrogramLoss, +) + +from utils import ( + plot_spectrogram, + plot_spectrogram_clipped, + scan_checkpoint, + load_checkpoint, + save_checkpoint, + save_audio, +) +import torchaudio as ta +from pesq import pesq +from tqdm import tqdm +import auraloss + +torch.backends.cudnn.benchmark = False + + +def train(rank, a, h): + if h.num_gpus > 1: + # initialize distributed + init_process_group( + backend=h.dist_config["dist_backend"], + init_method=h.dist_config["dist_url"], + world_size=h.dist_config["world_size"] * h.num_gpus, + rank=rank, + ) + + # Set seed and device + torch.cuda.manual_seed(h.seed) + torch.cuda.set_device(rank) + device = torch.device(f"cuda:{rank:d}") + + # Define BigVGAN generator + generator = BigVGAN(h).to(device) + + # Define discriminators. MPD is used by default + mpd = MultiPeriodDiscriminator(h).to(device) + + # Define additional discriminators. 
BigVGAN-v1 uses UnivNet's MRD as default + # New in BigVGAN-v2: option to switch to new discriminators: MultiBandDiscriminator / MultiScaleSubbandCQTDiscriminator + if h.get("use_mbd_instead_of_mrd", False): # Switch to MBD + print("[INFO] using MultiBandDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator") + # Variable name is kept as "mrd" for backward compatibility & minimal code change + mrd = MultiBandDiscriminator(h).to(device) + elif h.get("use_cqtd_instead_of_mrd", False): # Switch to CQTD + print("[INFO] using MultiScaleSubbandCQTDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator") + mrd = MultiScaleSubbandCQTDiscriminator(h).to(device) + else: # Fallback to original MRD in BigVGAN-v1 + mrd = MultiResolutionDiscriminator(h).to(device) + + # New in BigVGAN-v2: option to switch to multi-scale L1 mel loss + if h.get("use_multiscale_melloss", False): + print("[INFO] using multi-scale Mel l1 loss of BigVGAN-v2 instead of the original single-scale loss") + fn_mel_loss_multiscale = MultiScaleMelSpectrogramLoss( + sampling_rate=h.sampling_rate + ) # NOTE: accepts waveform as input + else: + fn_mel_loss_singlescale = F.l1_loss + + # Print the model & number of parameters, and create or scan the latest checkpoint from checkpoints directory + if rank == 0: + print(generator) + print(mpd) + print(mrd) + print(f"Generator params: {sum(p.numel() for p in generator.parameters())}") + print(f"Discriminator mpd params: {sum(p.numel() for p in mpd.parameters())}") + print(f"Discriminator mrd params: {sum(p.numel() for p in mrd.parameters())}") + os.makedirs(a.checkpoint_path, exist_ok=True) + print(f"Checkpoints directory: {a.checkpoint_path}") + + if os.path.isdir(a.checkpoint_path): + # New in v2.1: If the step prefix pattern-based checkpoints are not found, also check for renamed files in Hugging Face Hub to resume training + cp_g = scan_checkpoint(a.checkpoint_path, prefix="g_", renamed_file="bigvgan_generator.pt") + cp_do = scan_checkpoint( + a.checkpoint_path, + prefix="do_", + renamed_file="bigvgan_discriminator_optimizer.pt", + ) + + # Load the latest checkpoint if exists + steps = 0 + if cp_g is None or cp_do is None: + state_dict_do = None + last_epoch = -1 + else: + state_dict_g = load_checkpoint(cp_g, device) + state_dict_do = load_checkpoint(cp_do, device) + generator.load_state_dict(state_dict_g["generator"]) + mpd.load_state_dict(state_dict_do["mpd"]) + mrd.load_state_dict(state_dict_do["mrd"]) + steps = state_dict_do["steps"] + 1 + last_epoch = state_dict_do["epoch"] + + # Initialize DDP, optimizers, and schedulers + if h.num_gpus > 1: + generator = DistributedDataParallel(generator, device_ids=[rank]).to(device) + mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device) + mrd = DistributedDataParallel(mrd, device_ids=[rank]).to(device) + + optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2]) + optim_d = torch.optim.AdamW( + itertools.chain(mrd.parameters(), mpd.parameters()), + h.learning_rate, + betas=[h.adam_b1, h.adam_b2], + ) + + if state_dict_do is not None: + optim_g.load_state_dict(state_dict_do["optim_g"]) + optim_d.load_state_dict(state_dict_do["optim_d"]) + + scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch) + scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch) + + # Define training and validation datasets + + """ + unseen_validation_filelist will contain sample filepaths outside the 
seen training & validation dataset + Example: trained on LibriTTS, validate on VCTK + """ + training_filelist, validation_filelist, list_unseen_validation_filelist = get_dataset_filelist(a) + + trainset = MelDataset( + training_filelist, + h, + h.segment_size, + h.n_fft, + h.num_mels, + h.hop_size, + h.win_size, + h.sampling_rate, + h.fmin, + h.fmax, + shuffle=False if h.num_gpus > 1 else True, + fmax_loss=h.fmax_for_loss, + device=device, + fine_tuning=a.fine_tuning, + base_mels_path=a.input_mels_dir, + is_seen=True, + ) + + train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None + + train_loader = DataLoader( + trainset, + num_workers=h.num_workers, + shuffle=False, + sampler=train_sampler, + batch_size=h.batch_size, + pin_memory=True, + drop_last=True, + ) + + if rank == 0: + validset = MelDataset( + validation_filelist, + h, + h.segment_size, + h.n_fft, + h.num_mels, + h.hop_size, + h.win_size, + h.sampling_rate, + h.fmin, + h.fmax, + False, + False, + fmax_loss=h.fmax_for_loss, + device=device, + fine_tuning=a.fine_tuning, + base_mels_path=a.input_mels_dir, + is_seen=True, + ) + validation_loader = DataLoader( + validset, + num_workers=1, + shuffle=False, + sampler=None, + batch_size=1, + pin_memory=True, + drop_last=True, + ) + + list_unseen_validset = [] + list_unseen_validation_loader = [] + for i in range(len(list_unseen_validation_filelist)): + unseen_validset = MelDataset( + list_unseen_validation_filelist[i], + h, + h.segment_size, + h.n_fft, + h.num_mels, + h.hop_size, + h.win_size, + h.sampling_rate, + h.fmin, + h.fmax, + False, + False, + fmax_loss=h.fmax_for_loss, + device=device, + fine_tuning=a.fine_tuning, + base_mels_path=a.input_mels_dir, + is_seen=False, + ) + unseen_validation_loader = DataLoader( + unseen_validset, + num_workers=1, + shuffle=False, + sampler=None, + batch_size=1, + pin_memory=True, + drop_last=True, + ) + list_unseen_validset.append(unseen_validset) + list_unseen_validation_loader.append(unseen_validation_loader) + + # Tensorboard logger + sw = SummaryWriter(os.path.join(a.checkpoint_path, "logs")) + if a.save_audio: # Also save audio to disk if --save_audio is set to True + os.makedirs(os.path.join(a.checkpoint_path, "samples"), exist_ok=True) + + """ + Validation loop, "mode" parameter is automatically defined as (seen or unseen)_(name of the dataset). 
+ If the name of the dataset contains "nonspeech", it skips PESQ calculation to prevent errors + """ + + def validate(rank, a, h, loader, mode="seen"): + assert rank == 0, "validate should only run on rank=0" + generator.eval() + torch.cuda.empty_cache() + + val_err_tot = 0 + val_pesq_tot = 0 + val_mrstft_tot = 0 + + # Modules for evaluation metrics + pesq_resampler = ta.transforms.Resample(h.sampling_rate, 16000).cuda() + loss_mrstft = auraloss.freq.MultiResolutionSTFTLoss(device="cuda") + + if a.save_audio: # Also save audio to disk if --save_audio is set to True + os.makedirs( + os.path.join(a.checkpoint_path, "samples", f"gt_{mode}"), + exist_ok=True, + ) + os.makedirs( + os.path.join(a.checkpoint_path, "samples", f"{mode}_{steps:08d}"), + exist_ok=True, + ) + + with torch.no_grad(): + print(f"step {steps} {mode} speaker validation...") + + # Loop over validation set and compute metrics + for j, batch in enumerate(tqdm(loader)): + x, y, _, y_mel = batch + y = y.to(device) + if hasattr(generator, "module"): + y_g_hat = generator.module(x.to(device)) + else: + y_g_hat = generator(x.to(device)) + y_mel = y_mel.to(device, non_blocking=True) + y_g_hat_mel = mel_spectrogram( + y_g_hat.squeeze(1), + h.n_fft, + h.num_mels, + h.sampling_rate, + h.hop_size, + h.win_size, + h.fmin, + h.fmax_for_loss, + ) + min_t = min(y_mel.size(-1), y_g_hat_mel.size(-1)) + val_err_tot += F.l1_loss(y_mel[..., :min_t], y_g_hat_mel[..., :min_t]).item() + + # PESQ calculation. only evaluate PESQ if it's speech signal (nonspeech PESQ will error out) + if "nonspeech" not in mode: # Skips if the name of dataset (in mode string) contains "nonspeech" + # Resample to 16000 for pesq + y_16k = pesq_resampler(y) + y_g_hat_16k = pesq_resampler(y_g_hat.squeeze(1)) + y_int_16k = (y_16k[0] * MAX_WAV_VALUE).short().cpu().numpy() + y_g_hat_int_16k = (y_g_hat_16k[0] * MAX_WAV_VALUE).short().cpu().numpy() + val_pesq_tot += pesq(16000, y_int_16k, y_g_hat_int_16k, "wb") + + # MRSTFT calculation + min_t = min(y.size(-1), y_g_hat.size(-1)) + val_mrstft_tot += loss_mrstft(y_g_hat[..., :min_t], y[..., :min_t]).item() + + # Log audio and figures to Tensorboard + if j % a.eval_subsample == 0: # Subsample every nth from validation set + if steps >= 0: + sw.add_audio(f"gt_{mode}/y_{j}", y[0], steps, h.sampling_rate) + if a.save_audio: # Also save audio to disk if --save_audio is set to True + save_audio( + y[0], + os.path.join( + a.checkpoint_path, + "samples", + f"gt_{mode}", + f"{j:04d}.wav", + ), + h.sampling_rate, + ) + sw.add_figure( + f"gt_{mode}/y_spec_{j}", + plot_spectrogram(x[0]), + steps, + ) + + sw.add_audio( + f"generated_{mode}/y_hat_{j}", + y_g_hat[0], + steps, + h.sampling_rate, + ) + if a.save_audio: # Also save audio to disk if --save_audio is set to True + save_audio( + y_g_hat[0, 0], + os.path.join( + a.checkpoint_path, + "samples", + f"{mode}_{steps:08d}", + f"{j:04d}.wav", + ), + h.sampling_rate, + ) + # Spectrogram of synthesized audio + y_hat_spec = mel_spectrogram( + y_g_hat.squeeze(1), + h.n_fft, + h.num_mels, + h.sampling_rate, + h.hop_size, + h.win_size, + h.fmin, + h.fmax, + ) + sw.add_figure( + f"generated_{mode}/y_hat_spec_{j}", + plot_spectrogram(y_hat_spec.squeeze(0).cpu().numpy()), + steps, + ) + + """ + Visualization of spectrogram difference between GT and synthesized audio, difference higher than 1 is clipped for better visualization. 
+ """ + spec_delta = torch.clamp( + torch.abs(x[0] - y_hat_spec.squeeze(0).cpu()), + min=1e-6, + max=1.0, + ) + sw.add_figure( + f"delta_dclip1_{mode}/spec_{j}", + plot_spectrogram_clipped(spec_delta.numpy(), clip_max=1.0), + steps, + ) + + val_err = val_err_tot / (j + 1) + val_pesq = val_pesq_tot / (j + 1) + val_mrstft = val_mrstft_tot / (j + 1) + # Log evaluation metrics to Tensorboard + sw.add_scalar(f"validation_{mode}/mel_spec_error", val_err, steps) + sw.add_scalar(f"validation_{mode}/pesq", val_pesq, steps) + sw.add_scalar(f"validation_{mode}/mrstft", val_mrstft, steps) + + generator.train() + + # If the checkpoint is loaded, start with validation loop + if steps != 0 and rank == 0 and not a.debug: + if not a.skip_seen: + validate( + rank, + a, + h, + validation_loader, + mode=f"seen_{train_loader.dataset.name}", + ) + for i in range(len(list_unseen_validation_loader)): + validate( + rank, + a, + h, + list_unseen_validation_loader[i], + mode=f"unseen_{list_unseen_validation_loader[i].dataset.name}", + ) + # Exit the script if --evaluate is set to True + if a.evaluate: + exit() + + # Main training loop + generator.train() + mpd.train() + mrd.train() + for epoch in range(max(0, last_epoch), a.training_epochs): + if rank == 0: + start = time.time() + print(f"Epoch: {epoch + 1}") + + if h.num_gpus > 1: + train_sampler.set_epoch(epoch) + + for i, batch in enumerate(train_loader): + if rank == 0: + start_b = time.time() + x, y, _, y_mel = batch + + x = x.to(device, non_blocking=True) + y = y.to(device, non_blocking=True) + y_mel = y_mel.to(device, non_blocking=True) + y = y.unsqueeze(1) + + y_g_hat = generator(x) + y_g_hat_mel = mel_spectrogram( + y_g_hat.squeeze(1), + h.n_fft, + h.num_mels, + h.sampling_rate, + h.hop_size, + h.win_size, + h.fmin, + h.fmax_for_loss, + ) + + optim_d.zero_grad() + + # MPD + y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach()) + loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g) + + # MRD + y_ds_hat_r, y_ds_hat_g, _, _ = mrd(y, y_g_hat.detach()) + loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g) + + loss_disc_all = loss_disc_s + loss_disc_f + + # Set clip_grad_norm value + clip_grad_norm = h.get("clip_grad_norm", 1000.0) # Default to 1000 + + # Whether to freeze D for initial training steps + if steps >= a.freeze_step: + loss_disc_all.backward() + grad_norm_mpd = torch.nn.utils.clip_grad_norm_(mpd.parameters(), clip_grad_norm) + grad_norm_mrd = torch.nn.utils.clip_grad_norm_(mrd.parameters(), clip_grad_norm) + optim_d.step() + else: + print(f"[WARNING] skipping D training for the first {a.freeze_step} steps") + grad_norm_mpd = 0.0 + grad_norm_mrd = 0.0 + + # Generator + optim_g.zero_grad() + + # L1 Mel-Spectrogram Loss + lambda_melloss = h.get("lambda_melloss", 45.0) # Defaults to 45 in BigVGAN-v1 if not set + if h.get("use_multiscale_melloss", False): # uses wav for loss + loss_mel = fn_mel_loss_multiscale(y, y_g_hat) * lambda_melloss + else: # Uses mel for loss + loss_mel = fn_mel_loss_singlescale(y_mel, y_g_hat_mel) * lambda_melloss + + # MPD loss + y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat) + loss_fm_f = feature_loss(fmap_f_r, fmap_f_g) + loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g) + + # MRD loss + y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = mrd(y, y_g_hat) + loss_fm_s = feature_loss(fmap_s_r, fmap_s_g) + loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g) + + if steps >= a.freeze_step: + loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + 
loss_fm_f + loss_mel + else: + print(f"[WARNING] using regression loss only for G for the first {a.freeze_step} steps") + loss_gen_all = loss_mel + + loss_gen_all.backward() + grad_norm_g = torch.nn.utils.clip_grad_norm_(generator.parameters(), clip_grad_norm) + optim_g.step() + + if rank == 0: + # STDOUT logging + if steps % a.stdout_interval == 0: + mel_error = loss_mel.item() / lambda_melloss # Log training mel regression loss to stdout + print( + f"Steps: {steps:d}, " + f"Gen Loss Total: {loss_gen_all:4.3f}, " + f"Mel Error: {mel_error:4.3f}, " + f"s/b: {time.time() - start_b:4.3f} " + f"lr: {optim_g.param_groups[0]['lr']:4.7f} " + f"grad_norm_g: {grad_norm_g:4.3f}" + ) + + # Checkpointing + if steps % a.checkpoint_interval == 0 and steps != 0: + checkpoint_path = f"{a.checkpoint_path}/g_{steps:08d}" + save_checkpoint( + checkpoint_path, + {"generator": (generator.module if h.num_gpus > 1 else generator).state_dict()}, + ) + checkpoint_path = f"{a.checkpoint_path}/do_{steps:08d}" + save_checkpoint( + checkpoint_path, + { + "mpd": (mpd.module if h.num_gpus > 1 else mpd).state_dict(), + "mrd": (mrd.module if h.num_gpus > 1 else mrd).state_dict(), + "optim_g": optim_g.state_dict(), + "optim_d": optim_d.state_dict(), + "steps": steps, + "epoch": epoch, + }, + ) + + # Tensorboard summary logging + if steps % a.summary_interval == 0: + mel_error = loss_mel.item() / lambda_melloss # Log training mel regression loss to tensorboard + sw.add_scalar("training/gen_loss_total", loss_gen_all.item(), steps) + sw.add_scalar("training/mel_spec_error", mel_error, steps) + sw.add_scalar("training/fm_loss_mpd", loss_fm_f.item(), steps) + sw.add_scalar("training/gen_loss_mpd", loss_gen_f.item(), steps) + sw.add_scalar("training/disc_loss_mpd", loss_disc_f.item(), steps) + sw.add_scalar("training/grad_norm_mpd", grad_norm_mpd, steps) + sw.add_scalar("training/fm_loss_mrd", loss_fm_s.item(), steps) + sw.add_scalar("training/gen_loss_mrd", loss_gen_s.item(), steps) + sw.add_scalar("training/disc_loss_mrd", loss_disc_s.item(), steps) + sw.add_scalar("training/grad_norm_mrd", grad_norm_mrd, steps) + sw.add_scalar("training/grad_norm_g", grad_norm_g, steps) + sw.add_scalar("training/learning_rate_d", scheduler_d.get_last_lr()[0], steps) + sw.add_scalar("training/learning_rate_g", scheduler_g.get_last_lr()[0], steps) + sw.add_scalar("training/epoch", epoch + 1, steps) + + # Validation + if steps % a.validation_interval == 0: + # Plot training input x so far used + for i_x in range(x.shape[0]): + sw.add_figure( + f"training_input/x_{i_x}", + plot_spectrogram(x[i_x].cpu()), + steps, + ) + sw.add_audio( + f"training_input/y_{i_x}", + y[i_x][0], + steps, + h.sampling_rate, + ) + + # Seen and unseen speakers validation loops + if not a.debug and steps != 0: + validate( + rank, + a, + h, + validation_loader, + mode=f"seen_{train_loader.dataset.name}", + ) + for i in range(len(list_unseen_validation_loader)): + validate( + rank, + a, + h, + list_unseen_validation_loader[i], + mode=f"unseen_{list_unseen_validation_loader[i].dataset.name}", + ) + steps += 1 + + # BigVGAN-v2 learning rate scheduler is changed from epoch-level to step-level + scheduler_g.step() + scheduler_d.step() + + if rank == 0: + print(f"Time taken for epoch {epoch + 1} is {int(time.time() - start)} sec\n") + + +def main(): + print("Initializing Training Process..") + + parser = argparse.ArgumentParser() + + parser.add_argument("--group_name", default=None) + + parser.add_argument("--input_wavs_dir", default="LibriTTS") + 
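# ----------------------------------------------------------------------------------------
# Illustrative sketch for the adversarial training step above. discriminator_loss,
# generator_loss and feature_loss are imported from the repo's loss module; the
# least-squares (HiFi-GAN / BigVGAN style) formulation below is an assumption for
# illustration only, not the authoritative implementation.
import torch


def lsgan_discriminator_loss(real_outputs, fake_outputs):
    # Push each sub-discriminator (MPD period / MRD resolution) towards 1 on real audio
    # and towards 0 on generated audio.
    loss = 0.0
    for d_real, d_fake in zip(real_outputs, fake_outputs):
        loss = loss + torch.mean((1.0 - d_real) ** 2) + torch.mean(d_fake**2)
    return loss


def lsgan_generator_loss(fake_outputs):
    # Push the generator to make every sub-discriminator output 1 on generated audio.
    return sum(torch.mean((1.0 - d_fake) ** 2) for d_fake in fake_outputs)


def feature_matching_loss(fmap_real, fmap_fake):
    # L1 distance between real and generated feature maps at every discriminator layer;
    # together with the lambda_melloss-weighted mel term this makes up loss_gen_all above.
    loss = 0.0
    for maps_r, maps_f in zip(fmap_real, fmap_fake):
        for r, f in zip(maps_r, maps_f):
            loss = loss + torch.mean(torch.abs(r - f))
    return loss
# ----------------------------------------------------------------------------------------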
parser.add_argument("--input_mels_dir", default="ft_dataset") + parser.add_argument("--input_training_file", default="tests/LibriTTS/train-full.txt") + parser.add_argument("--input_validation_file", default="tests/LibriTTS/val-full.txt") + + parser.add_argument( + "--list_input_unseen_wavs_dir", + nargs="+", + default=["tests/LibriTTS", "tests/LibriTTS"], + ) + parser.add_argument( + "--list_input_unseen_validation_file", + nargs="+", + default=["tests/LibriTTS/dev-clean.txt", "tests/LibriTTS/dev-other.txt"], + ) + + parser.add_argument("--checkpoint_path", default="exp/bigvgan") + parser.add_argument("--config", default="") + + parser.add_argument("--training_epochs", default=100000, type=int) + parser.add_argument("--stdout_interval", default=5, type=int) + parser.add_argument("--checkpoint_interval", default=50000, type=int) + parser.add_argument("--summary_interval", default=100, type=int) + parser.add_argument("--validation_interval", default=50000, type=int) + + parser.add_argument( + "--freeze_step", + default=0, + type=int, + help="freeze D for the first specified steps. G only uses regression loss for these steps.", + ) + + parser.add_argument("--fine_tuning", default=False, type=bool) + + parser.add_argument( + "--debug", + default=False, + type=bool, + help="debug mode. skips validation loop throughout training", + ) + parser.add_argument( + "--evaluate", + default=False, + type=bool, + help="only run evaluation from checkpoint and exit", + ) + parser.add_argument( + "--eval_subsample", + default=5, + type=int, + help="subsampling during evaluation loop", + ) + parser.add_argument( + "--skip_seen", + default=False, + type=bool, + help="skip seen dataset. useful for test set inference", + ) + parser.add_argument( + "--save_audio", + default=False, + type=bool, + help="save audio of test set inference to disk", + ) + + a = parser.parse_args() + + with open(a.config) as f: + data = f.read() + + json_config = json.loads(data) + h = AttrDict(json_config) + + build_env(a.config, "config.json", a.checkpoint_path) + + torch.manual_seed(h.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(h.seed) + h.num_gpus = torch.cuda.device_count() + h.batch_size = int(h.batch_size / h.num_gpus) + print(f"Batch size per GPU: {h.batch_size}") + else: + pass + + if h.num_gpus > 1: + mp.spawn( + train, + nprocs=h.num_gpus, + args=( + a, + h, + ), + ) + else: + train(0, a, h) + + +if __name__ == "__main__": + main() diff --git a/BigVGAN/utils0.py b/BigVGAN/utils0.py new file mode 100644 index 0000000000000000000000000000000000000000..da98a24cf1447778305563f8e909f30b06e06b26 --- /dev/null +++ b/BigVGAN/utils0.py @@ -0,0 +1,99 @@ +# Adapted from https://github.com/jik876/hifi-gan under the MIT license. +# LICENSE is in incl_licenses directory. 
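# Side note on the boolean CLI flags in the training script above (--fine_tuning, --debug,
# --evaluate, --skip_seen, --save_audio are declared with type=bool): argparse applies
# bool() to the raw string, so any non-empty value -- including "False" -- parses as True.
# A common workaround is sketched below; str2bool is an illustrative helper, not part of
# this repo.
import argparse


def str2bool(v) -> bool:
    if isinstance(v, bool):
        return v
    if str(v).lower() in ("yes", "true", "t", "1"):
        return True
    if str(v).lower() in ("no", "false", "f", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {v!r}")


# Usage: parser.add_argument("--debug", type=str2bool, default=False)
# or, for plain switches: parser.add_argument("--debug", action="store_true")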
+ +import glob +import os +import matplotlib +import torch +from torch.nn.utils import weight_norm + +matplotlib.use("Agg") +import matplotlib.pylab as plt +from .meldataset import MAX_WAV_VALUE +from scipy.io.wavfile import write + + +def plot_spectrogram(spectrogram): + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") + plt.colorbar(im, ax=ax) + + fig.canvas.draw() + plt.close() + + return fig + + +def plot_spectrogram_clipped(spectrogram, clip_max=2.0): + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow( + spectrogram, + aspect="auto", + origin="lower", + interpolation="none", + vmin=1e-6, + vmax=clip_max, + ) + plt.colorbar(im, ax=ax) + + fig.canvas.draw() + plt.close() + + return fig + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def apply_weight_norm(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + weight_norm(m) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +def load_checkpoint(filepath, device): + assert os.path.isfile(filepath) + print(f"Loading '{filepath}'") + checkpoint_dict = torch.load(filepath, map_location=device) + print("Complete.") + return checkpoint_dict + + +def save_checkpoint(filepath, obj): + print(f"Saving checkpoint to {filepath}") + torch.save(obj, filepath) + print("Complete.") + + +def scan_checkpoint(cp_dir, prefix, renamed_file=None): + # Fallback to original scanning logic first + pattern = os.path.join(cp_dir, prefix + "????????") + cp_list = glob.glob(pattern) + + if len(cp_list) > 0: + last_checkpoint_path = sorted(cp_list)[-1] + print(f"[INFO] Resuming from checkpoint: '{last_checkpoint_path}'") + return last_checkpoint_path + + # If no pattern-based checkpoints are found, check for renamed file + if renamed_file: + renamed_path = os.path.join(cp_dir, renamed_file) + if os.path.isfile(renamed_path): + print(f"[INFO] Resuming from renamed checkpoint: '{renamed_file}'") + return renamed_path + + return None + + +def save_audio(audio, path, sr): + # wav: torch with 1d shape + audio = audio * MAX_WAV_VALUE + audio = audio.cpu().numpy().astype("int16") + write(path, sr, audio) diff --git a/TTS_infer_pack/TTS.py b/TTS_infer_pack/TTS.py new file mode 100644 index 0000000000000000000000000000000000000000..795b55dde6e7bf792ee90778096feb6de6a439d0 --- /dev/null +++ b/TTS_infer_pack/TTS.py @@ -0,0 +1,1629 @@ +import gc +import math +import os +import random +import sys +import time +import traceback +from copy import deepcopy + +import torchaudio +from tqdm import tqdm + +now_dir = os.getcwd() +sys.path.append(now_dir) +import os +from typing import List, Tuple, Union + +import ffmpeg +import librosa +import numpy as np +import torch +import torch.nn.functional as F +import yaml +from AR.models.t2s_lightning_module import Text2SemanticLightningModule +from BigVGAN.bigvgan import BigVGAN +from feature_extractor.cnhubert import CNHubert +from module.mel_processing import mel_spectrogram_torch, spectrogram_torch +from module.models import SynthesizerTrn, SynthesizerTrnV3, Generator +from peft import LoraConfig, get_peft_model +from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new +from transformers import AutoModelForMaskedLM, AutoTokenizer + +from tools.audio_sr import AP_BWE +from tools.i18n.i18n import I18nAuto, scan_language_list +from 
TTS_infer_pack.text_segmentation_method import splits +from TTS_infer_pack.TextPreprocessor import TextPreprocessor +from sv import SV + +resample_transform_dict = {} + + +def resample(audio_tensor, sr0, sr1, device): + global resample_transform_dict + key = "%s-%s-%s" % (sr0, sr1, str(device)) + if key not in resample_transform_dict: + resample_transform_dict[key] = torchaudio.transforms.Resample(sr0, sr1).to(device) + return resample_transform_dict[key](audio_tensor) + + +language = os.environ.get("language", "Auto") +language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language +i18n = I18nAuto(language=language) + + +spec_min = -12 +spec_max = 2 + + +def norm_spec(x): + return (x - spec_min) / (spec_max - spec_min) * 2 - 1 + + +def denorm_spec(x): + return (x + 1) / 2 * (spec_max - spec_min) + spec_min + + +mel_fn = lambda x: mel_spectrogram_torch( + x, + **{ + "n_fft": 1024, + "win_size": 1024, + "hop_size": 256, + "num_mels": 100, + "sampling_rate": 24000, + "fmin": 0, + "fmax": None, + "center": False, + }, +) + +mel_fn_v4 = lambda x: mel_spectrogram_torch( + x, + **{ + "n_fft": 1280, + "win_size": 1280, + "hop_size": 320, + "num_mels": 100, + "sampling_rate": 32000, + "fmin": 0, + "fmax": None, + "center": False, + }, +) + + +def speed_change(input_audio: np.ndarray, speed: float, sr: int): + # 将 NumPy 数组转换为原始 PCM 流 + raw_audio = input_audio.astype(np.int16).tobytes() + + # 设置 ffmpeg 输入流 + input_stream = ffmpeg.input("pipe:", format="s16le", acodec="pcm_s16le", ar=str(sr), ac=1) + + # 变速处理 + output_stream = input_stream.filter("atempo", speed) + + # 输出流到管道 + out, _ = output_stream.output("pipe:", format="s16le", acodec="pcm_s16le").run( + input=raw_audio, capture_stdout=True, capture_stderr=True + ) + + # 将管道输出解码为 NumPy 数组 + processed_audio = np.frombuffer(out, np.int16) + + return processed_audio + + +class DictToAttrRecursive(dict): + def __init__(self, input_dict): + super().__init__(input_dict) + for key, value in input_dict.items(): + if isinstance(value, dict): + value = DictToAttrRecursive(value) + self[key] = value + setattr(self, key, value) + + def __getattr__(self, item): + try: + return self[item] + except KeyError: + raise AttributeError(f"Attribute {item} not found") + + def __setattr__(self, key, value): + if isinstance(value, dict): + value = DictToAttrRecursive(value) + super(DictToAttrRecursive, self).__setitem__(key, value) + super().__setattr__(key, value) + + def __delattr__(self, item): + try: + del self[item] + except KeyError: + raise AttributeError(f"Attribute {item} not found") + + +class NO_PROMPT_ERROR(Exception): + pass + + +# configs/tts_infer.yaml +""" +custom: + bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large + cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base + device: cpu + is_half: false + t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt + vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth + version: v2 +v1: + bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large + cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base + device: cpu + is_half: false + t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt + vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth + version: v1 +v2: + bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large + cnhuhbert_base_path: 
GPT_SoVITS/pretrained_models/chinese-hubert-base + device: cpu + is_half: false + t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt + vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth + version: v2 +v3: + bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large + cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base + device: cpu + is_half: false + t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt + vits_weights_path: GPT_SoVITS/pretrained_models/s2Gv3.pth + version: v3 +v4: + bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large + cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base + device: cpu + is_half: false + t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt + version: v4 + vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth +""" + + +def set_seed(seed: int): + seed = int(seed) + seed = seed if seed != -1 else random.randint(0, 2**32 - 1) + print(f"Set seed to {seed}") + os.environ["PYTHONHASHSEED"] = str(seed) + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + try: + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + # torch.backends.cudnn.deterministic = True + # torch.backends.cudnn.benchmark = False + # torch.backends.cudnn.enabled = True + # 开启后会影响精度 + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + except: + pass + return seed + + +class TTS_Config: + default_configs = { + "v1": { + "device": "cpu", + "is_half": False, + "version": "v1", + "t2s_weights_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt", + "vits_weights_path": "GPT_SoVITS/pretrained_models/s2G488k.pth", + "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base", + "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large", + }, + "v2": { + "device": "cpu", + "is_half": False, + "version": "v2", + "t2s_weights_path": "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt", + "vits_weights_path": "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth", + "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base", + "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large", + }, + "v3": { + "device": "cpu", + "is_half": False, + "version": "v3", + "t2s_weights_path": "GPT_SoVITS/pretrained_models/s1v3.ckpt", + "vits_weights_path": "GPT_SoVITS/pretrained_models/s2Gv3.pth", + "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base", + "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large", + }, + "v4": { + "device": "cpu", + "is_half": False, + "version": "v4", + "t2s_weights_path": "GPT_SoVITS/pretrained_models/s1v3.ckpt", + "vits_weights_path": "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth", + "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base", + "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large", + }, + "v2Pro": { + "device": "cpu", + "is_half": False, + "version": "v2Pro", + "t2s_weights_path": "GPT_SoVITS/pretrained_models/s1v3.ckpt", + "vits_weights_path": "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth", + "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base", + "bert_base_path": 
"GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large", + }, + "v2ProPlus": { + "device": "cpu", + "is_half": False, + "version": "v2ProPlus", + "t2s_weights_path": "GPT_SoVITS/pretrained_models/s1v3.ckpt", + "vits_weights_path": "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth", + "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base", + "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large", + }, + } + configs: dict = None + v1_languages: list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"] + v2_languages: list = ["auto", "auto_yue", "en", "zh", "ja", "yue", "ko", "all_zh", "all_ja", "all_yue", "all_ko"] + languages: list = v2_languages + # "all_zh",#全部按中文识别 + # "en",#全部按英文识别#######不变 + # "all_ja",#全部按日文识别 + # "all_yue",#全部按中文识别 + # "all_ko",#全部按韩文识别 + # "zh",#按中英混合识别####不变 + # "ja",#按日英混合识别####不变 + # "yue",#按粤英混合识别####不变 + # "ko",#按韩英混合识别####不变 + # "auto",#多语种启动切分识别语种 + # "auto_yue",#多语种启动切分识别语种 + + def __init__(self, configs: Union[dict, str] = None): + # 设置默认配置文件路径 + configs_base_path: str = "GPT_SoVITS/configs/" + os.makedirs(configs_base_path, exist_ok=True) + self.configs_path: str = os.path.join(configs_base_path, "tts_infer.yaml") + + if configs in ["", None]: + if not os.path.exists(self.configs_path): + self.save_configs() + print(f"Create default config file at {self.configs_path}") + configs: dict = deepcopy(self.default_configs) + + if isinstance(configs, str): + self.configs_path = configs + configs: dict = self._load_configs(self.configs_path) + + assert isinstance(configs, dict) + version = configs.get("version", "v2").lower() + assert version in ["v1", "v2", "v3", "v4", "v2Pro", "v2ProPlus"] + self.default_configs[version] = configs.get(version, self.default_configs[version]) + self.configs: dict = configs.get("custom", deepcopy(self.default_configs[version])) + + self.device = self.configs.get("device", torch.device("cpu")) + if "cuda" in str(self.device) and not torch.cuda.is_available(): + print("Warning: CUDA is not available, set device to CPU.") + self.device = torch.device("cpu") + + self.is_half = self.configs.get("is_half", False) + # if str(self.device) == "cpu" and self.is_half: + # print(f"Warning: Half precision is not supported on CPU, set is_half to False.") + # self.is_half = False + + self.version = version + self.t2s_weights_path = self.configs.get("t2s_weights_path", None) + self.vits_weights_path = self.configs.get("vits_weights_path", None) + self.bert_base_path = self.configs.get("bert_base_path", None) + self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path", None) + self.languages = self.v1_languages if self.version == "v1" else self.v2_languages + + self.use_vocoder: bool = False + + if (self.t2s_weights_path in [None, ""]) or (not os.path.exists(self.t2s_weights_path)): + self.t2s_weights_path = self.default_configs[version]["t2s_weights_path"] + print(f"fall back to default t2s_weights_path: {self.t2s_weights_path}") + if (self.vits_weights_path in [None, ""]) or (not os.path.exists(self.vits_weights_path)): + self.vits_weights_path = self.default_configs[version]["vits_weights_path"] + print(f"fall back to default vits_weights_path: {self.vits_weights_path}") + if (self.bert_base_path in [None, ""]) or (not os.path.exists(self.bert_base_path)): + self.bert_base_path = self.default_configs[version]["bert_base_path"] + print(f"fall back to default bert_base_path: {self.bert_base_path}") + if (self.cnhuhbert_base_path in [None, ""]) or (not os.path.exists(self.cnhuhbert_base_path)): 
+ self.cnhuhbert_base_path = self.default_configs[version]["cnhuhbert_base_path"] + print(f"fall back to default cnhuhbert_base_path: {self.cnhuhbert_base_path}") + self.update_configs() + + self.max_sec = None + self.hz: int = 50 + self.semantic_frame_rate: str = "25hz" + self.segment_size: int = 20480 + self.filter_length: int = 2048 + self.sampling_rate: int = 32000 + self.hop_length: int = 640 + self.win_length: int = 2048 + self.n_speakers: int = 300 + + def _load_configs(self, configs_path: str) -> dict: + if os.path.exists(configs_path): + ... + else: + print(i18n("路径不存在,使用默认配置")) + self.save_configs(configs_path) + with open(configs_path, "r", encoding="utf-8") as f: + configs = yaml.load(f, Loader=yaml.FullLoader) + + return configs + + def save_configs(self, configs_path: str = None) -> None: + configs = deepcopy(self.default_configs) + if self.configs is not None: + configs["custom"] = self.update_configs() + + if configs_path is None: + configs_path = self.configs_path + with open(configs_path, "w") as f: + yaml.dump(configs, f) + + def update_configs(self): + self.config = { + "device": str(self.device), + "is_half": self.is_half, + "version": self.version, + "t2s_weights_path": self.t2s_weights_path, + "vits_weights_path": self.vits_weights_path, + "bert_base_path": self.bert_base_path, + "cnhuhbert_base_path": self.cnhuhbert_base_path, + } + return self.config + + def update_version(self, version: str) -> None: + self.version = version + self.languages = self.v1_languages if self.version == "v1" else self.v2_languages + + def __str__(self): + self.configs = self.update_configs() + string = "TTS Config".center(100, "-") + "\n" + for k, v in self.configs.items(): + string += f"{str(k).ljust(20)}: {str(v)}\n" + string += "-" * 100 + "\n" + return string + + def __repr__(self): + return self.__str__() + + def __hash__(self): + return hash(self.configs_path) + + def __eq__(self, other): + return isinstance(other, TTS_Config) and self.configs_path == other.configs_path + + +class TTS: + def __init__(self, configs: Union[dict, str, TTS_Config]): + if isinstance(configs, TTS_Config): + self.configs = configs + else: + self.configs: TTS_Config = TTS_Config(configs) + + self.t2s_model: Text2SemanticLightningModule = None + self.vits_model: Union[SynthesizerTrn, SynthesizerTrnV3] = None + self.bert_tokenizer: AutoTokenizer = None + self.bert_model: AutoModelForMaskedLM = None + self.cnhuhbert_model: CNHubert = None + self.vocoder = None + self.sr_model: AP_BWE = None + self.sv_model = None + self.sr_model_not_exist: bool = False + + self.vocoder_configs: dict = { + "sr": None, + "T_ref": None, + "T_chunk": None, + "upsample_rate": None, + "overlapped_len": None, + } + + self._init_models() + + self.text_preprocessor: TextPreprocessor = TextPreprocessor( + self.bert_model, self.bert_tokenizer, self.configs.device + ) + + self.prompt_cache: dict = { + "ref_audio_path": None, + "prompt_semantic": None, + "refer_spec": [], + "prompt_text": None, + "prompt_lang": None, + "phones": None, + "bert_features": None, + "norm_text": None, + "aux_ref_audio_paths": [], + } + + self.stop_flag: bool = False + self.precision: torch.dtype = torch.float16 if self.configs.is_half else torch.float32 + + def _init_models( + self, + ): + self.init_t2s_weights(self.configs.t2s_weights_path) + self.init_vits_weights(self.configs.vits_weights_path) + self.init_bert_weights(self.configs.bert_base_path) + self.init_cnhuhbert_weights(self.configs.cnhuhbert_base_path) + # 
self.enable_half_precision(self.configs.is_half) + + def init_cnhuhbert_weights(self, base_path: str): + print(f"Loading CNHuBERT weights from {base_path}") + self.cnhuhbert_model = CNHubert(base_path) + self.cnhuhbert_model = self.cnhuhbert_model.eval() + self.cnhuhbert_model = self.cnhuhbert_model.to(self.configs.device) + if self.configs.is_half and str(self.configs.device) != "cpu": + self.cnhuhbert_model = self.cnhuhbert_model.half() + + def init_bert_weights(self, base_path: str): + print(f"Loading BERT weights from {base_path}") + self.bert_tokenizer = AutoTokenizer.from_pretrained(base_path) + self.bert_model = AutoModelForMaskedLM.from_pretrained(base_path) + self.bert_model = self.bert_model.eval() + self.bert_model = self.bert_model.to(self.configs.device) + if self.configs.is_half and str(self.configs.device) != "cpu": + self.bert_model = self.bert_model.half() + + def init_vits_weights(self, weights_path: str): + self.configs.vits_weights_path = weights_path + version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(weights_path) + if "Pro" in model_version: + self.init_sv_model() + path_sovits = self.configs.default_configs[model_version]["vits_weights_path"] + + if if_lora_v3 == True and os.path.exists(path_sovits) == False: + info = path_sovits + i18n("SoVITS %s 底模缺失,无法加载相应 LoRA 权重" % model_version) + raise FileExistsError(info) + + # dict_s2 = torch.load(weights_path, map_location=self.configs.device,weights_only=False) + dict_s2 = load_sovits_new(weights_path) + hps = dict_s2["config"] + hps["model"]["semantic_frame_rate"] = "25hz" + if "enc_p.text_embedding.weight" not in dict_s2["weight"]: + hps["model"]["version"] = "v2" # v3model,v2sybomls + elif dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322: + hps["model"]["version"] = "v1" + else: + hps["model"]["version"] = "v2" + version = hps["model"]["version"] + v3v4set = {"v3", "v4"} + if model_version not in v3v4set: + if "Pro" not in model_version: + model_version = version + else: + hps["model"]["version"] = model_version + else: + hps["model"]["version"] = model_version + + self.configs.filter_length = hps["data"]["filter_length"] + self.configs.segment_size = hps["train"]["segment_size"] + self.configs.sampling_rate = hps["data"]["sampling_rate"] + self.configs.hop_length = hps["data"]["hop_length"] + self.configs.win_length = hps["data"]["win_length"] + self.configs.n_speakers = hps["data"]["n_speakers"] + self.configs.semantic_frame_rate = hps["model"]["semantic_frame_rate"] + kwargs = hps["model"] + # print(f"self.configs.sampling_rate:{self.configs.sampling_rate}") + + self.configs.update_version(model_version) + + # print(f"model_version:{model_version}") + # print(f'hps["model"]["version"]:{hps["model"]["version"]}') + if model_version not in v3v4set: + vits_model = SynthesizerTrn( + self.configs.filter_length // 2 + 1, + self.configs.segment_size // self.configs.hop_length, + n_speakers=self.configs.n_speakers, + **kwargs, + ) + self.configs.use_vocoder = False + else: + kwargs["version"] = model_version + vits_model = SynthesizerTrnV3( + self.configs.filter_length // 2 + 1, + self.configs.segment_size // self.configs.hop_length, + n_speakers=self.configs.n_speakers, + **kwargs, + ) + self.configs.use_vocoder = True + self.init_vocoder(model_version) + if "pretrained" not in weights_path and hasattr(vits_model, "enc_q"): + del vits_model.enc_q + + self.is_v2pro = model_version in {"v2Pro", "v2ProPlus"} + + if if_lora_v3 == False: + print( + f"Loading VITS weights from 
{weights_path}. {vits_model.load_state_dict(dict_s2['weight'], strict=False)}" + ) + else: + print( + f"Loading VITS pretrained weights from {weights_path}. {vits_model.load_state_dict(load_sovits_new(path_sovits)['weight'], strict=False)}" + ) + lora_rank = dict_s2["lora_rank"] + lora_config = LoraConfig( + target_modules=["to_k", "to_q", "to_v", "to_out.0"], + r=lora_rank, + lora_alpha=lora_rank, + init_lora_weights=True, + ) + vits_model.cfm = get_peft_model(vits_model.cfm, lora_config) + print( + f"Loading LoRA weights from {weights_path}. {vits_model.load_state_dict(dict_s2['weight'], strict=False)}" + ) + + vits_model.cfm = vits_model.cfm.merge_and_unload() + + vits_model = vits_model.to(self.configs.device) + vits_model = vits_model.eval() + + self.vits_model = vits_model + if self.configs.is_half and str(self.configs.device) != "cpu": + self.vits_model = self.vits_model.half() + + def init_t2s_weights(self, weights_path: str): + print(f"Loading Text2Semantic weights from {weights_path}") + self.configs.t2s_weights_path = weights_path + self.configs.save_configs() + self.configs.hz = 50 + dict_s1 = torch.load(weights_path, map_location=self.configs.device, weights_only=False) + config = dict_s1["config"] + self.configs.max_sec = config["data"]["max_sec"] + t2s_model = Text2SemanticLightningModule(config, "****", is_train=False) + t2s_model.load_state_dict(dict_s1["weight"]) + t2s_model = t2s_model.to(self.configs.device) + t2s_model = t2s_model.eval() + self.t2s_model = t2s_model + if self.configs.is_half and str(self.configs.device) != "cpu": + self.t2s_model = self.t2s_model.half() + + def init_vocoder(self, version: str): + if version == "v3": + if self.vocoder is not None and self.vocoder.__class__.__name__ == "BigVGAN": + return + if self.vocoder is not None: + self.vocoder.cpu() + del self.vocoder + self.empty_cache() + + self.vocoder = BigVGAN.from_pretrained( + "%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,), + use_cuda_kernel=False, + ) # if True, RuntimeError: Ninja is required to load C++ extensions + # remove weight norm in the model and set to eval mode + self.vocoder.remove_weight_norm() + + self.vocoder_configs["sr"] = 24000 + self.vocoder_configs["T_ref"] = 468 + self.vocoder_configs["T_chunk"] = 934 + self.vocoder_configs["upsample_rate"] = 256 + self.vocoder_configs["overlapped_len"] = 12 + + elif version == "v4": + if self.vocoder is not None and self.vocoder.__class__.__name__ == "Generator": + return + if self.vocoder is not None: + self.vocoder.cpu() + del self.vocoder + self.empty_cache() + + self.vocoder = Generator( + initial_channel=100, + resblock="1", + resblock_kernel_sizes=[3, 7, 11], + resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], + upsample_rates=[10, 6, 2, 2, 2], + upsample_initial_channel=512, + upsample_kernel_sizes=[20, 12, 4, 4, 4], + gin_channels=0, + is_bias=True, + ) + self.vocoder.remove_weight_norm() + state_dict_g = torch.load( + "%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), + map_location="cpu", + weights_only=False, + ) + print("loading vocoder", self.vocoder.load_state_dict(state_dict_g)) + + self.vocoder_configs["sr"] = 48000 + self.vocoder_configs["T_ref"] = 500 + self.vocoder_configs["T_chunk"] = 1000 + self.vocoder_configs["upsample_rate"] = 480 + self.vocoder_configs["overlapped_len"] = 12 + + self.vocoder = self.vocoder.eval() + if self.configs.is_half == True: + self.vocoder = self.vocoder.half().to(self.configs.device) + else: + self.vocoder 
= self.vocoder.to(self.configs.device) + + def init_sr_model(self): + if self.sr_model is not None: + return + try: + self.sr_model: AP_BWE = AP_BWE(self.configs.device, DictToAttrRecursive) + self.sr_model_not_exist = False + except FileNotFoundError: + print(i18n("你没有下载超分模型的参数,因此不进行超分。如想超分请先参照教程把文件下载好")) + self.sr_model_not_exist = True + + def init_sv_model(self): + if self.sv_model is not None: + return + self.sv_model = SV(self.configs.device, self.configs.is_half) + + def enable_half_precision(self, enable: bool = True, save: bool = True): + """ + To enable half precision for the TTS model. + Args: + enable: bool, whether to enable half precision. + + """ + if str(self.configs.device) == "cpu" and enable: + print("Half precision is not supported on CPU.") + return + + self.configs.is_half = enable + self.precision = torch.float16 if enable else torch.float32 + if save: + self.configs.save_configs() + if enable: + if self.t2s_model is not None: + self.t2s_model = self.t2s_model.half() + if self.vits_model is not None: + self.vits_model = self.vits_model.half() + if self.bert_model is not None: + self.bert_model = self.bert_model.half() + if self.cnhuhbert_model is not None: + self.cnhuhbert_model = self.cnhuhbert_model.half() + if self.vocoder is not None: + self.vocoder = self.vocoder.half() + else: + if self.t2s_model is not None: + self.t2s_model = self.t2s_model.float() + if self.vits_model is not None: + self.vits_model = self.vits_model.float() + if self.bert_model is not None: + self.bert_model = self.bert_model.float() + if self.cnhuhbert_model is not None: + self.cnhuhbert_model = self.cnhuhbert_model.float() + if self.vocoder is not None: + self.vocoder = self.vocoder.float() + + def set_device(self, device: torch.device, save: bool = True): + """ + To set the device for all models. + Args: + device: torch.device, the device to use for all models. + """ + self.configs.device = device + if save: + self.configs.save_configs() + if self.t2s_model is not None: + self.t2s_model = self.t2s_model.to(device) + if self.vits_model is not None: + self.vits_model = self.vits_model.to(device) + if self.bert_model is not None: + self.bert_model = self.bert_model.to(device) + if self.cnhuhbert_model is not None: + self.cnhuhbert_model = self.cnhuhbert_model.to(device) + if self.vocoder is not None: + self.vocoder = self.vocoder.to(device) + if self.sr_model is not None: + self.sr_model = self.sr_model.to(device) + + def set_ref_audio(self, ref_audio_path: str): + """ + To set the reference audio for the TTS model, + including the prompt_semantic and refer_spepc. + Args: + ref_audio_path: str, the path of the reference audio. 
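        Example (illustrative path; the clip should be roughly 3-10 s long to pass the
        length check in _set_prompt_semantic):
            tts.set_ref_audio("path/to/reference_3_to_10s.wav")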
+ """ + self._set_prompt_semantic(ref_audio_path) + self._set_ref_spec(ref_audio_path) + self._set_ref_audio_path(ref_audio_path) + + def _set_ref_audio_path(self, ref_audio_path): + self.prompt_cache["ref_audio_path"] = ref_audio_path + + def _set_ref_spec(self, ref_audio_path): + spec_audio = self._get_ref_spec(ref_audio_path) + if self.prompt_cache["refer_spec"] in [[], None]: + self.prompt_cache["refer_spec"] = [spec_audio] + else: + self.prompt_cache["refer_spec"][0] = spec_audio + + def _get_ref_spec(self, ref_audio_path): + raw_audio, raw_sr = torchaudio.load(ref_audio_path) + raw_audio = raw_audio.to(self.configs.device).float() + self.prompt_cache["raw_audio"] = raw_audio + self.prompt_cache["raw_sr"] = raw_sr + + if raw_sr != self.configs.sampling_rate: + audio = raw_audio.to(self.configs.device) + if audio.shape[0] == 2: + audio = audio.mean(0).unsqueeze(0) + audio = resample(audio, raw_sr, self.configs.sampling_rate, self.configs.device) + else: + audio = raw_audio.to(self.configs.device) + if audio.shape[0] == 2: + audio = audio.mean(0).unsqueeze(0) + + maxx = audio.abs().max() + if maxx > 1: + audio /= min(2, maxx) + spec = spectrogram_torch( + audio, + self.configs.filter_length, + self.configs.sampling_rate, + self.configs.hop_length, + self.configs.win_length, + center=False, + ) + if self.configs.is_half: + spec = spec.half() + if self.is_v2pro == True: + audio = resample(audio, self.configs.sampling_rate, 16000, self.configs.device) + if self.configs.is_half: + audio = audio.half() + else: + audio = None + return spec, audio + + def _set_prompt_semantic(self, ref_wav_path: str): + zero_wav = np.zeros( + int(self.configs.sampling_rate * 0.3), + dtype=np.float16 if self.configs.is_half else np.float32, + ) + with torch.no_grad(): + wav16k, sr = librosa.load(ref_wav_path, sr=16000) + if wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000: + raise OSError(i18n("参考音频在3~10秒范围外,请更换!")) + wav16k = torch.from_numpy(wav16k) + zero_wav_torch = torch.from_numpy(zero_wav) + wav16k = wav16k.to(self.configs.device) + zero_wav_torch = zero_wav_torch.to(self.configs.device) + if self.configs.is_half: + wav16k = wav16k.half() + zero_wav_torch = zero_wav_torch.half() + + wav16k = torch.cat([wav16k, zero_wav_torch]) + hubert_feature = self.cnhuhbert_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose( + 1, 2 + ) # .float() + codes = self.vits_model.extract_latent(hubert_feature) + + prompt_semantic = codes[0, 0].to(self.configs.device) + self.prompt_cache["prompt_semantic"] = prompt_semantic + + def batch_sequences(self, sequences: List[torch.Tensor], axis: int = 0, pad_value: int = 0, max_length: int = None): + seq = sequences[0] + ndim = seq.dim() + if axis < 0: + axis += ndim + dtype: torch.dtype = seq.dtype + pad_value = torch.tensor(pad_value, dtype=dtype) + seq_lengths = [seq.shape[axis] for seq in sequences] + if max_length is None: + max_length = max(seq_lengths) + else: + max_length = max(seq_lengths) if max_length < max(seq_lengths) else max_length + + padded_sequences = [] + for seq, length in zip(sequences, seq_lengths): + padding = [0] * axis + [0, max_length - length] + [0] * (ndim - axis - 1) + padded_seq = torch.nn.functional.pad(seq, padding, value=pad_value) + padded_sequences.append(padded_seq) + batch = torch.stack(padded_sequences) + return batch + + def to_batch( + self, + data: list, + prompt_data: dict = None, + batch_size: int = 5, + threshold: float = 0.75, + split_bucket: bool = True, + device: torch.device = torch.device("cpu"), + precision: 
torch.dtype = torch.float32, + ): + _data: list = [] + index_and_len_list = [] + for idx, item in enumerate(data): + norm_text_len = len(item["norm_text"]) + index_and_len_list.append([idx, norm_text_len]) + + batch_index_list = [] + if split_bucket: + index_and_len_list.sort(key=lambda x: x[1]) + index_and_len_list = np.array(index_and_len_list, dtype=np.int64) + + batch_index_list_len = 0 + pos = 0 + while pos < index_and_len_list.shape[0]: + # batch_index_list.append(index_and_len_list[pos:min(pos+batch_size,len(index_and_len_list))]) + pos_end = min(pos + batch_size, index_and_len_list.shape[0]) + while pos < pos_end: + batch = index_and_len_list[pos:pos_end, 1].astype(np.float32) + score = batch[(pos_end - pos) // 2] / (batch.mean() + 1e-8) + if (score >= threshold) or (pos_end - pos == 1): + batch_index = index_and_len_list[pos:pos_end, 0].tolist() + batch_index_list_len += len(batch_index) + batch_index_list.append(batch_index) + pos = pos_end + break + pos_end = pos_end - 1 + + assert batch_index_list_len == len(data) + + else: + for i in range(len(data)): + if i % batch_size == 0: + batch_index_list.append([]) + batch_index_list[-1].append(i) + + for batch_idx, index_list in enumerate(batch_index_list): + item_list = [data[idx] for idx in index_list] + phones_list = [] + phones_len_list = [] + # bert_features_list = [] + all_phones_list = [] + all_phones_len_list = [] + all_bert_features_list = [] + norm_text_batch = [] + all_bert_max_len = 0 + all_phones_max_len = 0 + for item in item_list: + if prompt_data is not None: + all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1).to( + dtype=precision, device=device + ) + all_phones = torch.LongTensor(prompt_data["phones"] + item["phones"]).to(device) + phones = torch.LongTensor(item["phones"]).to(device) + # norm_text = prompt_data["norm_text"]+item["norm_text"] + else: + all_bert_features = item["bert_features"].to(dtype=precision, device=device) + phones = torch.LongTensor(item["phones"]).to(device) + all_phones = phones + # norm_text = item["norm_text"] + + all_bert_max_len = max(all_bert_max_len, all_bert_features.shape[-1]) + all_phones_max_len = max(all_phones_max_len, all_phones.shape[-1]) + + phones_list.append(phones) + phones_len_list.append(phones.shape[-1]) + all_phones_list.append(all_phones) + all_phones_len_list.append(all_phones.shape[-1]) + all_bert_features_list.append(all_bert_features) + norm_text_batch.append(item["norm_text"]) + + phones_batch = phones_list + all_phones_batch = all_phones_list + all_bert_features_batch = all_bert_features_list + + max_len = max(all_bert_max_len, all_phones_max_len) + # phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len) + #### 直接对phones和bert_features进行pad。(padding策略会影响T2S模型生成的结果,但不直接影响复读概率。影响复读概率的主要因素是mask的策略) + # all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len) + # all_bert_features_batch = all_bert_features_list + # all_bert_features_batch = torch.zeros((len(all_bert_features_list), 1024, max_len), dtype=precision, device=device) + # for idx, item in enumerate(all_bert_features_list): + # all_bert_features_batch[idx, :, : item.shape[-1]] = item + + # #### 先对phones进行embedding、对bert_features进行project,再pad到相同长度,(padding策略会影响T2S模型生成的结果,但不直接影响复读概率。影响复读概率的主要因素是mask的策略) + # all_phones_list = [self.t2s_model.model.ar_text_embedding(item.to(self.t2s_model.device)) for item in all_phones_list] + # all_phones_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) 
for item in all_phones_list] + # all_phones_batch = torch.stack(all_phones_list, dim=0) + + # all_bert_features_list = [self.t2s_model.model.bert_proj(item.to(self.t2s_model.device).transpose(0, 1)) for item in all_bert_features_list] + # all_bert_features_list = [F.pad(item,(0,0,0,max_len-item.shape[0]), value=0) for item in all_bert_features_list] + # all_bert_features_batch = torch.stack(all_bert_features_list, dim=0) + + batch = { + "phones": phones_batch, + "phones_len": torch.LongTensor(phones_len_list).to(device), + "all_phones": all_phones_batch, + "all_phones_len": torch.LongTensor(all_phones_len_list).to(device), + "all_bert_features": all_bert_features_batch, + "norm_text": norm_text_batch, + "max_len": max_len, + } + _data.append(batch) + + return _data, batch_index_list + + def recovery_order(self, data: list, batch_index_list: list) -> list: + """ + Recovery the order of the audio according to the batch_index_list. + + Args: + data (List[list(torch.Tensor)]): the out of order audio . + batch_index_list (List[list[int]]): the batch index list. + + Returns: + list (List[torch.Tensor]): the data in the original order. + """ + length = len(sum(batch_index_list, [])) + _data = [None] * length + for i, index_list in enumerate(batch_index_list): + for j, index in enumerate(index_list): + _data[index] = data[i][j] + return _data + + def stop( + self, + ): + """ + Stop the inference process. + """ + self.stop_flag = True + + @torch.no_grad() + def run(self, inputs: dict): + """ + Text to speech inference. + + Args: + inputs (dict): + { + "text": "", # str.(required) text to be synthesized + "text_lang: "", # str.(required) language of the text to be synthesized + "ref_audio_path": "", # str.(required) reference audio path + "aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker tone fusion + "prompt_text": "", # str.(optional) prompt text for the reference audio + "prompt_lang": "", # str.(required) language of the prompt text for the reference audio + "top_k": 5, # int. top k sampling + "top_p": 1, # float. top p sampling + "temperature": 1, # float. temperature for sampling + "text_split_method": "cut0", # str. text split method, see text_segmentation_method.py for details. + "batch_size": 1, # int. batch size for inference + "batch_threshold": 0.75, # float. threshold for batch splitting. + "split_bucket: True, # bool. whether to split the batch into multiple buckets. + "return_fragment": False, # bool. step by step return the audio fragment. + "speed_factor":1.0, # float. control the speed of the synthesized audio. + "fragment_interval":0.3, # float. to control the interval of the audio fragment. + "seed": -1, # int. random seed for reproducibility. + "parallel_infer": True, # bool. whether to use parallel inference. + "repetition_penalty": 1.35 # float. repetition penalty for T2S model. + "sample_steps": 32, # int. number of sampling steps for VITS model V3. + "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. + } + returns: + Tuple[int, np.ndarray]: sampling rate and audio data. 
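        Example (a sketch; key names and defaults mirror the dict documented above, and
        the file paths are illustrative placeholders rather than assets shipped with the
        repo):

            tts = TTS("GPT_SoVITS/configs/tts_infer.yaml")
            req = {
                "text": "Hello there.",
                "text_lang": "en",
                "ref_audio_path": "my_reference_3_to_10s.wav",
                "prompt_text": "Transcript of the reference clip.",
                "prompt_lang": "en",
            }
            for sr, wav in tts.run(req):  # run() is a generator; with return_fragment=True
                ...                       # it yields one int16 audio chunk per fragment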
+ """ + ########## variables initialization ########### + self.stop_flag: bool = False + text: str = inputs.get("text", "") + text_lang: str = inputs.get("text_lang", "") + ref_audio_path: str = inputs.get("ref_audio_path", "") + aux_ref_audio_paths: list = inputs.get("aux_ref_audio_paths", []) + prompt_text: str = inputs.get("prompt_text", "") + prompt_lang: str = inputs.get("prompt_lang", "") + top_k: int = inputs.get("top_k", 5) + top_p: float = inputs.get("top_p", 1) + temperature: float = inputs.get("temperature", 1) + text_split_method: str = inputs.get("text_split_method", "cut0") + batch_size = inputs.get("batch_size", 1) + batch_threshold = inputs.get("batch_threshold", 0.75) + speed_factor = inputs.get("speed_factor", 1.0) + split_bucket = inputs.get("split_bucket", True) + return_fragment = inputs.get("return_fragment", False) + fragment_interval = inputs.get("fragment_interval", 0.3) + seed = inputs.get("seed", -1) + seed = -1 if seed in ["", None] else seed + actual_seed = set_seed(seed) + parallel_infer = inputs.get("parallel_infer", True) + repetition_penalty = inputs.get("repetition_penalty", 1.35) + sample_steps = inputs.get("sample_steps", 32) + super_sampling = inputs.get("super_sampling", False) + + if parallel_infer: + print(i18n("并行推理模式已开启")) + self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_batch_infer + else: + print(i18n("并行推理模式已关闭")) + self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive_batched + + if return_fragment: + print(i18n("分段返回模式已开启")) + if split_bucket: + split_bucket = False + print(i18n("分段返回模式不支持分桶处理,已自动关闭分桶处理")) + + if split_bucket and speed_factor == 1.0 and not (self.configs.use_vocoder and parallel_infer): + print(i18n("分桶处理模式已开启")) + elif speed_factor != 1.0: + print(i18n("语速调节不支持分桶处理,已自动关闭分桶处理")) + split_bucket = False + elif self.configs.use_vocoder and parallel_infer: + print(i18n("当开启并行推理模式时,SoVits V3/4模型不支持分桶处理,已自动关闭分桶处理")) + split_bucket = False + else: + print(i18n("分桶处理模式已关闭")) + + if fragment_interval < 0.01: + fragment_interval = 0.01 + print(i18n("分段间隔过小,已自动设置为0.01")) + + no_prompt_text = False + if prompt_text in [None, ""]: + no_prompt_text = True + + assert text_lang in self.configs.languages + if not no_prompt_text: + assert prompt_lang in self.configs.languages + + if no_prompt_text and self.configs.use_vocoder: + raise NO_PROMPT_ERROR("prompt_text cannot be empty when using SoVITS_V3") + + if ref_audio_path in [None, ""] and ( + (self.prompt_cache["prompt_semantic"] is None) or (self.prompt_cache["refer_spec"] in [None, []]) + ): + raise ValueError( + "ref_audio_path cannot be empty, when the reference audio is not set using set_ref_audio()" + ) + + ###### setting reference audio and prompt text preprocessing ######## + t0 = time.perf_counter() + if (ref_audio_path is not None) and ( + ref_audio_path != self.prompt_cache["ref_audio_path"] + or (self.is_v2pro and self.prompt_cache["refer_spec"][0][1] is None) + ): + if not os.path.exists(ref_audio_path): + raise ValueError(f"{ref_audio_path} not exists") + self.set_ref_audio(ref_audio_path) + + aux_ref_audio_paths = aux_ref_audio_paths if aux_ref_audio_paths is not None else [] + paths = set(aux_ref_audio_paths) & set(self.prompt_cache["aux_ref_audio_paths"]) + if not (len(list(paths)) == len(aux_ref_audio_paths) == len(self.prompt_cache["aux_ref_audio_paths"])): + self.prompt_cache["aux_ref_audio_paths"] = aux_ref_audio_paths + self.prompt_cache["refer_spec"] = [self.prompt_cache["refer_spec"][0]] + for path in aux_ref_audio_paths: + if 
path in [None, ""]: + continue + if not os.path.exists(path): + print(i18n("音频文件不存在,跳过:"), path) + continue + self.prompt_cache["refer_spec"].append(self._get_ref_spec(path)) + + if not no_prompt_text: + prompt_text = prompt_text.strip("\n") + if prompt_text[-1] not in splits: + prompt_text += "。" if prompt_lang != "en" else "." + print(i18n("实际输入的参考文本:"), prompt_text) + if self.prompt_cache["prompt_text"] != prompt_text: + phones, bert_features, norm_text = self.text_preprocessor.segment_and_extract_feature_for_text( + prompt_text, prompt_lang, self.configs.version + ) + self.prompt_cache["prompt_text"] = prompt_text + self.prompt_cache["prompt_lang"] = prompt_lang + self.prompt_cache["phones"] = phones + self.prompt_cache["bert_features"] = bert_features + self.prompt_cache["norm_text"] = norm_text + + ###### text preprocessing ######## + t1 = time.perf_counter() + data: list = None + if not return_fragment: + data = self.text_preprocessor.preprocess(text, text_lang, text_split_method, self.configs.version) + if len(data) == 0: + yield 16000, np.zeros(int(16000), dtype=np.int16) + return + + batch_index_list: list = None + data, batch_index_list = self.to_batch( + data, + prompt_data=self.prompt_cache if not no_prompt_text else None, + batch_size=batch_size, + threshold=batch_threshold, + split_bucket=split_bucket, + device=self.configs.device, + precision=self.precision, + ) + else: + print(f"############ {i18n('切分文本')} ############") + texts = self.text_preprocessor.pre_seg_text(text, text_lang, text_split_method) + data = [] + for i in range(len(texts)): + if i % batch_size == 0: + data.append([]) + data[-1].append(texts[i]) + + def make_batch(batch_texts): + batch_data = [] + print(f"############ {i18n('提取文本Bert特征')} ############") + for text in tqdm(batch_texts): + phones, bert_features, norm_text = self.text_preprocessor.segment_and_extract_feature_for_text( + text, text_lang, self.configs.version + ) + if phones is None: + continue + res = { + "phones": phones, + "bert_features": bert_features, + "norm_text": norm_text, + } + batch_data.append(res) + if len(batch_data) == 0: + return None + batch, _ = self.to_batch( + batch_data, + prompt_data=self.prompt_cache if not no_prompt_text else None, + batch_size=batch_size, + threshold=batch_threshold, + split_bucket=False, + device=self.configs.device, + precision=self.precision, + ) + return batch[0] + + t2 = time.perf_counter() + try: + print("############ 推理 ############") + ###### inference ###### + t_34 = 0.0 + t_45 = 0.0 + audio = [] + output_sr = self.configs.sampling_rate if not self.configs.use_vocoder else self.vocoder_configs["sr"] + for item in data: + t3 = time.perf_counter() + if return_fragment: + item = make_batch(item) + if item is None: + continue + + batch_phones: List[torch.LongTensor] = item["phones"] + # batch_phones:torch.LongTensor = item["phones"] + batch_phones_len: torch.LongTensor = item["phones_len"] + all_phoneme_ids: torch.LongTensor = item["all_phones"] + all_phoneme_lens: torch.LongTensor = item["all_phones_len"] + all_bert_features: torch.LongTensor = item["all_bert_features"] + norm_text: str = item["norm_text"] + max_len = item["max_len"] + + print(i18n("前端处理后的文本(每句):"), norm_text) + if no_prompt_text: + prompt = None + else: + prompt = ( + self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device) + ) + + print(f"############ {i18n('预测语义Token')} ############") + pred_semantic_list, idx_list = self.t2s_model.model.infer_panel( + all_phoneme_ids, + all_phoneme_lens, 
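                    # prompt: cached reference semantics expanded to the batch (None when no
                    # prompt text was given); early_stop_num below caps generation at
                    # hz * max_sec semantic frames.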
+ prompt, + all_bert_features, + # prompt_phone_len=ph_offset, + top_k=top_k, + top_p=top_p, + temperature=temperature, + early_stop_num=self.configs.hz * self.configs.max_sec, + max_len=max_len, + repetition_penalty=repetition_penalty, + ) + t4 = time.perf_counter() + t_34 += t4 - t3 + + refer_audio_spec = [] + if self.is_v2pro: + sv_emb = [] + for spec, audio_tensor in self.prompt_cache["refer_spec"]: + spec = spec.to(dtype=self.precision, device=self.configs.device) + refer_audio_spec.append(spec) + if self.is_v2pro: + sv_emb.append(self.sv_model.compute_embedding3(audio_tensor)) + + batch_audio_fragment = [] + + # ## vits并行推理 method 1 + # pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)] + # pred_semantic_len = torch.LongTensor([item.shape[0] for item in pred_semantic_list]).to(self.configs.device) + # pred_semantic = self.batch_sequences(pred_semantic_list, axis=0, pad_value=0).unsqueeze(0) + # max_len = 0 + # for i in range(0, len(batch_phones)): + # max_len = max(max_len, batch_phones[i].shape[-1]) + # batch_phones = self.batch_sequences(batch_phones, axis=0, pad_value=0, max_length=max_len) + # batch_phones = batch_phones.to(self.configs.device) + # batch_audio_fragment = (self.vits_model.batched_decode( + # pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spec + # )) + print(f"############ {i18n('合成音频')} ############") + if not self.configs.use_vocoder: + if speed_factor == 1.0: + print(f"{i18n('并行合成中')}...") + # ## vits并行推理 method 2 + pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)] + upsample_rate = math.prod(self.vits_model.upsample_rates) + audio_frag_idx = [ + pred_semantic_list[i].shape[0] * 2 * upsample_rate + for i in range(0, len(pred_semantic_list)) + ] + audio_frag_end_idx = [sum(audio_frag_idx[: i + 1]) for i in range(0, len(audio_frag_idx))] + all_pred_semantic = ( + torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device) + ) + _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device) + if self.is_v2pro != True: + _batch_audio_fragment = self.vits_model.decode( + all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor + ).detach()[0, 0, :] + else: + _batch_audio_fragment = self.vits_model.decode( + all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb + ).detach()[0, 0, :] + audio_frag_end_idx.insert(0, 0) + batch_audio_fragment = [ + _batch_audio_fragment[audio_frag_end_idx[i - 1] : audio_frag_end_idx[i]] + for i in range(1, len(audio_frag_end_idx)) + ] + else: + # ## vits串行推理 + for i, idx in enumerate(tqdm(idx_list)): + phones = batch_phones[i].unsqueeze(0).to(self.configs.device) + _pred_semantic = ( + pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0) + ) # .unsqueeze(0)#mq要多unsqueeze一次 + if self.is_v2pro != True: + audio_fragment = self.vits_model.decode( + _pred_semantic, phones, refer_audio_spec, speed=speed_factor + ).detach()[0, 0, :] + else: + audio_fragment = self.vits_model.decode( + _pred_semantic, phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb + ).detach()[0, 0, :] + batch_audio_fragment.append(audio_fragment) ###试试重建不带上prompt部分 + else: + if parallel_infer: + print(f"{i18n('并行合成中')}...") + audio_fragments = self.using_vocoder_synthesis_batched_infer( + idx_list, pred_semantic_list, batch_phones, speed=speed_factor, sample_steps=sample_steps + ) + batch_audio_fragment.extend(audio_fragments) + else: + for i, idx in enumerate(tqdm(idx_list)): + phones 
= batch_phones[i].unsqueeze(0).to(self.configs.device) + _pred_semantic = ( + pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0) + ) # .unsqueeze(0)#mq要多unsqueeze一次 + audio_fragment = self.using_vocoder_synthesis( + _pred_semantic, phones, speed=speed_factor, sample_steps=sample_steps + ) + batch_audio_fragment.append(audio_fragment) + + t5 = time.perf_counter() + t_45 += t5 - t4 + if return_fragment: + print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4)) + yield self.audio_postprocess( + [batch_audio_fragment], + output_sr, + None, + speed_factor, + False, + fragment_interval, + super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False, + ) + else: + audio.append(batch_audio_fragment) + + if self.stop_flag: + yield 16000, np.zeros(int(16000), dtype=np.int16) + return + + if not return_fragment: + print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45)) + if len(audio) == 0: + yield 16000, np.zeros(int(16000), dtype=np.int16) + return + yield self.audio_postprocess( + audio, + output_sr, + batch_index_list, + speed_factor, + split_bucket, + fragment_interval, + super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False, + ) + + except Exception as e: + traceback.print_exc() + # 必须返回一个空音频, 否则会导致显存不释放。 + yield 16000, np.zeros(int(16000), dtype=np.int16) + # 重置模型, 否则会导致显存释放不完全。 + del self.t2s_model + del self.vits_model + self.t2s_model = None + self.vits_model = None + self.init_t2s_weights(self.configs.t2s_weights_path) + self.init_vits_weights(self.configs.vits_weights_path) + raise e + finally: + self.empty_cache() + + def empty_cache(self): + try: + gc.collect() # 触发gc的垃圾回收。避免内存一直增长。 + if "cuda" in str(self.configs.device): + torch.cuda.empty_cache() + elif str(self.configs.device) == "mps": + torch.mps.empty_cache() + except: + pass + + def audio_postprocess( + self, + audio: List[torch.Tensor], + sr: int, + batch_index_list: list = None, + speed_factor: float = 1.0, + split_bucket: bool = True, + fragment_interval: float = 0.3, + super_sampling: bool = False, + ) -> Tuple[int, np.ndarray]: + zero_wav = torch.zeros( + int(self.configs.sampling_rate * fragment_interval), dtype=self.precision, device=self.configs.device + ) + + for i, batch in enumerate(audio): + for j, audio_fragment in enumerate(batch): + max_audio = torch.abs(audio_fragment).max() # 简单防止16bit爆音 + if max_audio > 1: + audio_fragment /= max_audio + audio_fragment: torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0) + audio[i][j] = audio_fragment + + if split_bucket: + audio = self.recovery_order(audio, batch_index_list) + else: + # audio = [item for batch in audio for item in batch] + audio = sum(audio, []) + + audio = torch.cat(audio, dim=0) + + if super_sampling: + print(f"############ {i18n('音频超采样')} ############") + t1 = time.perf_counter() + self.init_sr_model() + if not self.sr_model_not_exist: + audio, sr = self.sr_model(audio.unsqueeze(0), sr) + max_audio = np.abs(audio).max() + if max_audio > 1: + audio /= max_audio + t2 = time.perf_counter() + print(f"超采样用时:{t2 - t1:.3f}s") + else: + audio = audio.cpu().numpy() + + audio = (audio * 32768).astype(np.int16) + + # try: + # if speed_factor != 1.0: + # audio = speed_change(audio, speed=speed_factor, sr=int(sr)) + # except Exception as e: + # print(f"Failed to change speed of audio: \n{e}") + + return sr, audio + + def using_vocoder_synthesis( + self, semantic_tokens: torch.Tensor, phones: torch.Tensor, speed: float = 1.0, sample_steps: int = 32 + ): + prompt_semantic_tokens 
= self.prompt_cache["prompt_semantic"].unsqueeze(0).unsqueeze(0).to(self.configs.device) + prompt_phones = torch.LongTensor(self.prompt_cache["phones"]).unsqueeze(0).to(self.configs.device) + raw_entry = self.prompt_cache["refer_spec"][0] + if isinstance(raw_entry, tuple): + raw_entry = raw_entry[0] + refer_audio_spec = raw_entry.to(dtype=self.precision, device=self.configs.device) + + fea_ref, ge = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec) + ref_audio: torch.Tensor = self.prompt_cache["raw_audio"] + ref_sr = self.prompt_cache["raw_sr"] + ref_audio = ref_audio.to(self.configs.device).float() + if ref_audio.shape[0] == 2: + ref_audio = ref_audio.mean(0).unsqueeze(0) + + # tgt_sr = self.vocoder_configs["sr"] + tgt_sr = 24000 if self.configs.version == "v3" else 32000 + if ref_sr != tgt_sr: + ref_audio = resample(ref_audio, ref_sr, tgt_sr, self.configs.device) + + mel2 = mel_fn(ref_audio) if self.configs.version == "v3" else mel_fn_v4(ref_audio) + mel2 = norm_spec(mel2) + T_min = min(mel2.shape[2], fea_ref.shape[2]) + mel2 = mel2[:, :, :T_min] + fea_ref = fea_ref[:, :, :T_min] + T_ref = self.vocoder_configs["T_ref"] + T_chunk = self.vocoder_configs["T_chunk"] + if T_min > T_ref: + mel2 = mel2[:, :, -T_ref:] + fea_ref = fea_ref[:, :, -T_ref:] + T_min = T_ref + chunk_len = T_chunk - T_min + + mel2 = mel2.to(self.precision) + fea_todo, ge = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed) + + cfm_resss = [] + idx = 0 + while 1: + fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len] + if fea_todo_chunk.shape[-1] == 0: + break + idx += chunk_len + fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1) + + cfm_res = self.vits_model.cfm.inference( + fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0 + ) + cfm_res = cfm_res[:, :, mel2.shape[2] :] + + mel2 = cfm_res[:, :, -T_min:] + fea_ref = fea_todo_chunk[:, :, -T_min:] + + cfm_resss.append(cfm_res) + cfm_res = torch.cat(cfm_resss, 2) + cfm_res = denorm_spec(cfm_res) + + with torch.inference_mode(): + wav_gen = self.vocoder(cfm_res) + audio = wav_gen[0][0] # .cpu().detach().numpy() + + return audio + + def using_vocoder_synthesis_batched_infer( + self, + idx_list: List[int], + semantic_tokens_list: List[torch.Tensor], + batch_phones: List[torch.Tensor], + speed: float = 1.0, + sample_steps: int = 32, + ) -> List[torch.Tensor]: + prompt_semantic_tokens = self.prompt_cache["prompt_semantic"].unsqueeze(0).unsqueeze(0).to(self.configs.device) + prompt_phones = torch.LongTensor(self.prompt_cache["phones"]).unsqueeze(0).to(self.configs.device) + raw_entry = self.prompt_cache["refer_spec"][0] + if isinstance(raw_entry, tuple): + raw_entry = raw_entry[0] + refer_audio_spec = raw_entry.to(dtype=self.precision, device=self.configs.device) + + fea_ref, ge = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec) + ref_audio: torch.Tensor = self.prompt_cache["raw_audio"] + ref_sr = self.prompt_cache["raw_sr"] + ref_audio = ref_audio.to(self.configs.device).float() + if ref_audio.shape[0] == 2: + ref_audio = ref_audio.mean(0).unsqueeze(0) + + # tgt_sr = self.vocoder_configs["sr"] + tgt_sr = 24000 if self.configs.version == "v3" else 32000 + if ref_sr != tgt_sr: + ref_audio = resample(ref_audio, ref_sr, tgt_sr, self.configs.device) + + mel2 = mel_fn(ref_audio) if self.configs.version == "v3" else mel_fn_v4(ref_audio) + mel2 = norm_spec(mel2) + T_min = min(mel2.shape[2], fea_ref.shape[2]) + mel2 = mel2[:, :, 
:T_min] + fea_ref = fea_ref[:, :, :T_min] + T_ref = self.vocoder_configs["T_ref"] + T_chunk = self.vocoder_configs["T_chunk"] + if T_min > T_ref: + mel2 = mel2[:, :, -T_ref:] + fea_ref = fea_ref[:, :, -T_ref:] + T_min = T_ref + chunk_len = T_chunk - T_min + + mel2 = mel2.to(self.precision) + + # #### batched inference + overlapped_len = self.vocoder_configs["overlapped_len"] + feat_chunks = [] + feat_lens = [] + feat_list = [] + + for i, idx in enumerate(idx_list): + phones = batch_phones[i].unsqueeze(0).to(self.configs.device) + semantic_tokens = ( + semantic_tokens_list[i][-idx:].unsqueeze(0).unsqueeze(0) + ) # .unsqueeze(0)#mq要多unsqueeze一次 + feat, _ = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed) + feat_list.append(feat) + feat_lens.append(feat.shape[2]) + + feats = torch.cat(feat_list, 2) + feats_padded = F.pad(feats, (overlapped_len, 0), "constant", 0) + pos = 0 + padding_len = 0 + while True: + if pos == 0: + chunk = feats_padded[:, :, pos : pos + chunk_len] + else: + pos = pos - overlapped_len + chunk = feats_padded[:, :, pos : pos + chunk_len] + pos += chunk_len + if chunk.shape[-1] == 0: + break + + # padding for the last chunk + padding_len = chunk_len - chunk.shape[2] + if padding_len != 0: + chunk = F.pad(chunk, (0, padding_len), "constant", 0) + feat_chunks.append(chunk) + + feat_chunks = torch.cat(feat_chunks, 0) + bs = feat_chunks.shape[0] + fea_ref = fea_ref.repeat(bs, 1, 1) + fea = torch.cat([fea_ref, feat_chunks], 2).transpose(2, 1) + pred_spec = self.vits_model.cfm.inference( + fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0 + ) + pred_spec = pred_spec[:, :, -chunk_len:] + dd = pred_spec.shape[1] + pred_spec = pred_spec.permute(1, 0, 2).contiguous().view(dd, -1).unsqueeze(0) + # pred_spec = pred_spec[..., :-padding_len] + + pred_spec = denorm_spec(pred_spec) + + with torch.no_grad(): + wav_gen = self.vocoder(pred_spec) + audio = wav_gen[0][0] # .cpu().detach().numpy() + + audio_fragments = [] + upsample_rate = self.vocoder_configs["upsample_rate"] + pos = 0 + + while pos < audio.shape[-1]: + audio_fragment = audio[pos : pos + chunk_len * upsample_rate] + audio_fragments.append(audio_fragment) + pos += chunk_len * upsample_rate + + audio = self.sola_algorithm(audio_fragments, overlapped_len * upsample_rate) + audio = audio[overlapped_len * upsample_rate : -padding_len * upsample_rate] + + audio_fragments = [] + for feat_len in feat_lens: + audio_fragment = audio[: feat_len * upsample_rate] + audio_fragments.append(audio_fragment) + audio = audio[feat_len * upsample_rate :] + + return audio_fragments + + def sola_algorithm( + self, + audio_fragments: List[torch.Tensor], + overlap_len: int, + ): + for i in range(len(audio_fragments) - 1): + f1 = audio_fragments[i] + f2 = audio_fragments[i + 1] + w1 = f1[-overlap_len:] + w2 = f2[:overlap_len] + assert w1.shape == w2.shape + corr = F.conv1d(w1.view(1, 1, -1), w2.view(1, 1, -1), padding=w2.shape[-1] // 2).view(-1)[:-1] + idx = corr.argmax() + f1_ = f1[: -(overlap_len - idx)] + audio_fragments[i] = f1_ + + f2_ = f2[idx:] + window = torch.hann_window((overlap_len - idx) * 2, device=f1.device, dtype=f1.dtype) + f2_[: (overlap_len - idx)] = ( + window[: (overlap_len - idx)] * f2_[: (overlap_len - idx)] + + window[(overlap_len - idx) :] * f1[-(overlap_len - idx) :] + ) + audio_fragments[i + 1] = f2_ + + return torch.cat(audio_fragments, 0) diff --git a/TTS_infer_pack/TextPreprocessor.py b/TTS_infer_pack/TextPreprocessor.py new file mode 100644 index 
0000000000000000000000000000000000000000..9a478d431fabe9b95f665c66ffb762dd6522a7c0 --- /dev/null +++ b/TTS_infer_pack/TextPreprocessor.py @@ -0,0 +1,243 @@ +import os +import sys +import threading + +from tqdm import tqdm + +now_dir = os.getcwd() +sys.path.append(now_dir) + +import re +import torch +from text.LangSegmenter import LangSegmenter +from text import chinese +from typing import Dict, List, Tuple +from text.cleaner import clean_text +from text import cleaned_text_to_sequence +from transformers import AutoModelForMaskedLM, AutoTokenizer +from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method + +from tools.i18n.i18n import I18nAuto, scan_language_list + +language = os.environ.get("language", "Auto") +language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language +i18n = I18nAuto(language=language) +punctuation = set(["!", "?", "…", ",", ".", "-"]) + + +def get_first(text: str) -> str: + pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]" + text = re.split(pattern, text)[0].strip() + return text + + +def merge_short_text_in_array(texts: str, threshold: int) -> list: + if (len(texts)) < 2: + return texts + result = [] + text = "" + for ele in texts: + text += ele + if len(text) >= threshold: + result.append(text) + text = "" + if len(text) > 0: + if len(result) == 0: + result.append(text) + else: + result[len(result) - 1] += text + return result + + +class TextPreprocessor: + def __init__(self, bert_model: AutoModelForMaskedLM, tokenizer: AutoTokenizer, device: torch.device): + self.bert_model = bert_model + self.tokenizer = tokenizer + self.device = device + self.bert_lock = threading.RLock() + + def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]: + print(f"############ {i18n('切分文本')} ############") + text = self.replace_consecutive_punctuation(text) + texts = self.pre_seg_text(text, lang, text_split_method) + result = [] + print(f"############ {i18n('提取文本Bert特征')} ############") + for text in tqdm(texts): + phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version) + if phones is None or norm_text == "": + continue + res = { + "phones": phones, + "bert_features": bert_features, + "norm_text": norm_text, + } + result.append(res) + return result + + def pre_seg_text(self, text: str, lang: str, text_split_method: str): + text = text.strip("\n") + if len(text) == 0: + return [] + if text[0] not in splits and len(get_first(text)) < 4: + text = "。" + text if lang != "en" else "." + text + print(i18n("实际输入的目标文本:")) + print(text) + + seg_method = get_seg_method(text_split_method) + text = seg_method(text) + + while "\n\n" in text: + text = text.replace("\n\n", "\n") + + _texts = text.split("\n") + _texts = self.filter_text(_texts) + _texts = merge_short_text_in_array(_texts, 5) + texts = [] + + for text in _texts: + # 解决输入目标文本的空行导致报错的问题 + if len(text.strip()) == 0: + continue + if not re.sub("\W+", "", text): + # 检测一下,如果是纯符号,就跳过。 + continue + if text[-1] not in splits: + text += "。" if lang != "en" else "." 
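+ # at this point every kept segment ends with a sentence-final punctuation mark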
+ + # 解决句子过长导致Bert报错的问题 + if len(text) > 510: + texts.extend(split_big_text(text)) + else: + texts.append(text) + + print(i18n("实际输入的目标文本(切句后):")) + print(texts) + return texts + + def segment_and_extract_feature_for_text( + self, text: str, language: str, version: str = "v1" + ) -> Tuple[list, torch.Tensor, str]: + return self.get_phones_and_bert(text, language, version) + + def get_phones_and_bert(self, text: str, language: str, version: str, final: bool = False): + with self.bert_lock: + if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}: + # language = language.replace("all_","") + formattext = text + while " " in formattext: + formattext = formattext.replace(" ", " ") + if language == "all_zh": + if re.search(r"[A-Za-z]", formattext): + formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext) + formattext = chinese.mix_text_normalize(formattext) + return self.get_phones_and_bert(formattext, "zh", version) + else: + phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version) + bert = self.get_bert_feature(norm_text, word2ph).to(self.device) + elif language == "all_yue" and re.search(r"[A-Za-z]", formattext): + formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext) + formattext = chinese.mix_text_normalize(formattext) + return self.get_phones_and_bert(formattext, "yue", version) + else: + phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version) + bert = torch.zeros( + (1024, len(phones)), + dtype=torch.float32, + ).to(self.device) + elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}: + textlist = [] + langlist = [] + if language == "auto": + for tmp in LangSegmenter.getTexts(text): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "auto_yue": + for tmp in LangSegmenter.getTexts(text): + if tmp["lang"] == "zh": + tmp["lang"] = "yue" + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + else: + for tmp in LangSegmenter.getTexts(text): + if langlist: + if (tmp["lang"] == "en" and langlist[-1] == "en") or ( + tmp["lang"] != "en" and langlist[-1] != "en" + ): + textlist[-1] += tmp["text"] + continue + if tmp["lang"] == "en": + langlist.append(tmp["lang"]) + else: + # 因无法区别中日韩文汉字,以用户输入为准 + langlist.append(language) + textlist.append(tmp["text"]) + # print(textlist) + # print(langlist) + phones_list = [] + bert_list = [] + norm_text_list = [] + for i in range(len(textlist)): + lang = langlist[i] + phones, word2ph, norm_text = self.clean_text_inf(textlist[i], lang, version) + bert = self.get_bert_inf(phones, word2ph, norm_text, lang) + phones_list.append(phones) + norm_text_list.append(norm_text) + bert_list.append(bert) + bert = torch.cat(bert_list, dim=1) + phones = sum(phones_list, []) + norm_text = "".join(norm_text_list) + + if not final and len(phones) < 6: + return self.get_phones_and_bert("." 
+ text, language, version, final=True) + + return phones, bert, norm_text + + def get_bert_feature(self, text: str, word2ph: list) -> torch.Tensor: + with torch.no_grad(): + inputs = self.tokenizer(text, return_tensors="pt") + for i in inputs: + inputs[i] = inputs[i].to(self.device) + res = self.bert_model(**inputs, output_hidden_states=True) + res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1] + assert len(word2ph) == len(text) + phone_level_feature = [] + for i in range(len(word2ph)): + repeat_feature = res[i].repeat(word2ph[i], 1) + phone_level_feature.append(repeat_feature) + phone_level_feature = torch.cat(phone_level_feature, dim=0) + return phone_level_feature.T + + def clean_text_inf(self, text: str, language: str, version: str = "v2"): + language = language.replace("all_", "") + phones, word2ph, norm_text = clean_text(text, language, version) + phones = cleaned_text_to_sequence(phones, version) + return phones, word2ph, norm_text + + def get_bert_inf(self, phones: list, word2ph: list, norm_text: str, language: str): + language = language.replace("all_", "") + if language == "zh": + feature = self.get_bert_feature(norm_text, word2ph).to(self.device) + else: + feature = torch.zeros( + (1024, len(phones)), + dtype=torch.float32, + ).to(self.device) + + return feature + + def filter_text(self, texts): + _text = [] + if all(text in [None, " ", "\n", ""] for text in texts): + raise ValueError(i18n("请输入有效文本")) + for text in texts: + if text in [None, " ", ""]: + pass + else: + _text.append(text) + return _text + + def replace_consecutive_punctuation(self, text): + punctuations = "".join(re.escape(p) for p in punctuation) + pattern = f"([{punctuations}])([{punctuations}])+" + result = re.sub(pattern, r"\1", text) + return result diff --git a/TTS_infer_pack/__init__.py b/TTS_infer_pack/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8579a63215aba6c2dc5674c4e8711256c7e904a7 --- /dev/null +++ b/TTS_infer_pack/__init__.py @@ -0,0 +1 @@ +from . 
import TTS, text_segmentation_method diff --git a/TTS_infer_pack/text_segmentation_method.py b/TTS_infer_pack/text_segmentation_method.py new file mode 100644 index 0000000000000000000000000000000000000000..fda70a49834ea43c2a3a55154705e111b24fa196 --- /dev/null +++ b/TTS_infer_pack/text_segmentation_method.py @@ -0,0 +1,189 @@ +import re +from typing import Callable + +punctuation = set(["!", "?", "…", ",", ".", "-", " "]) +METHODS = dict() + + +def get_method(name: str) -> Callable: + method = METHODS.get(name, None) + if method is None: + raise ValueError(f"Method {name} not found") + return method + + +def get_method_names() -> list: + return list(METHODS.keys()) + + +def register_method(name): + def decorator(func): + METHODS[name] = func + return func + + return decorator + + +splits = { + ",", + "。", + "?", + "!", + ",", + ".", + "?", + "!", + "~", + ":", + ":", + "—", + "…", +} + + +def split_big_text(text, max_len=510): + # 定义全角和半角标点符号 + punctuation = "".join(splits) + + # 切割文本 + segments = re.split("([" + punctuation + "])", text) + + # 初始化结果列表和当前片段 + result = [] + current_segment = "" + + for segment in segments: + # 如果当前片段加上新的片段长度超过max_len,就将当前片段加入结果列表,并重置当前片段 + if len(current_segment + segment) > max_len: + result.append(current_segment) + current_segment = segment + else: + current_segment += segment + + # 将最后一个片段加入结果列表 + if current_segment: + result.append(current_segment) + + return result + + +def split(todo_text): + todo_text = todo_text.replace("……", "。").replace("——", ",") + if todo_text[-1] not in splits: + todo_text += "。" + i_split_head = i_split_tail = 0 + len_text = len(todo_text) + todo_texts = [] + while 1: + if i_split_head >= len_text: + break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入 + if todo_text[i_split_head] in splits: + i_split_head += 1 + todo_texts.append(todo_text[i_split_tail:i_split_head]) + i_split_tail = i_split_head + else: + i_split_head += 1 + return todo_texts + + +# 不切 +@register_method("cut0") +def cut0(inp): + if not set(inp).issubset(punctuation): + return inp + else: + return "/n" + + +# 凑四句一切 +@register_method("cut1") +def cut1(inp): + inp = inp.strip("\n") + inps = split(inp) + split_idx = list(range(0, len(inps), 4)) + split_idx[-1] = None + if len(split_idx) > 1: + opts = [] + for idx in range(len(split_idx) - 1): + opts.append("".join(inps[split_idx[idx] : split_idx[idx + 1]])) + else: + opts = [inp] + opts = [item for item in opts if not set(item).issubset(punctuation)] + return "\n".join(opts) + + +# 凑50字一切 +@register_method("cut2") +def cut2(inp): + inp = inp.strip("\n") + inps = split(inp) + if len(inps) < 2: + return inp + opts = [] + summ = 0 + tmp_str = "" + for i in range(len(inps)): + summ += len(inps[i]) + tmp_str += inps[i] + if summ > 50: + summ = 0 + opts.append(tmp_str) + tmp_str = "" + if tmp_str != "": + opts.append(tmp_str) + # print(opts) + if len(opts) > 1 and len(opts[-1]) < 50: ##如果最后一个太短了,和前一个合一起 + opts[-2] = opts[-2] + opts[-1] + opts = opts[:-1] + opts = [item for item in opts if not set(item).issubset(punctuation)] + return "\n".join(opts) + + +# 按中文句号。切 +@register_method("cut3") +def cut3(inp): + inp = inp.strip("\n") + opts = ["%s" % item for item in inp.strip("。").split("。")] + opts = [item for item in opts if not set(item).issubset(punctuation)] + return "\n".join(opts) + + +# 按英文句号.切 +@register_method("cut4") +def cut4(inp): + inp = inp.strip("\n") + opts = re.split(r"(? 
<!\d)\.(?!\d)", inp.strip(".")) + opts = [item for item in opts if not set(item).issubset(punctuation)] + return "\n".join(opts) + + +# 按标点符号切 +@register_method("cut5") +def cut5(inp): + inp = inp.strip("\n") + punds = {",", ".", ";", "?", "!", "、", ",", "。", "?", "!", ";", ":", "…"} + mergeitems = [] + items = [] + + for i, char in enumerate(inp): + if char in punds: + if char == "." and i >
0 and i < len(inp) - 1 and inp[i - 1].isdigit() and inp[i + 1].isdigit(): + items.append(char) + else: + items.append(char) + mergeitems.append("".join(items)) + items = [] + else: + items.append(char) + + if items: + mergeitems.append("".join(items)) + + opt = [item for item in mergeitems if not set(item).issubset(punds)] + return "\n".join(opt) + + +if __name__ == "__main__": + method = get_method("cut5") + print(method("你好,我是小明。你好,我是小红。你好,我是小刚。你好,我是小张。")) diff --git a/configs/.gitignore b/configs/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..2a6160516e607039d1737c43421a105e0860bd6c --- /dev/null +++ b/configs/.gitignore @@ -0,0 +1 @@ +*.yaml \ No newline at end of file diff --git a/configs/s2.json b/configs/s2.json new file mode 100644 index 0000000000000000000000000000000000000000..0bd672263e923ce21a0edb3ebdecbc8235c1291f --- /dev/null +++ b/configs/s2.json @@ -0,0 +1,91 @@ +{ + "train": { + "log_interval": 100, + "eval_interval": 500, + "seed": 1234, + "epochs": 100, + "learning_rate": 0.0001, + "betas": [ + 0.8, + 0.99 + ], + "eps": 1e-09, + "batch_size": 32, + "fp16_run": true, + "lr_decay": 0.999875, + "segment_size": 20480, + "init_lr_ratio": 1, + "warmup_epochs": 0, + "c_mel": 45, + "c_kl": 1.0, + "text_low_lr_rate": 0.4, + "grad_ckpt": false + }, + "data": { + "max_wav_value": 32768.0, + "sampling_rate": 32000, + "filter_length": 2048, + "hop_length": 640, + "win_length": 2048, + "n_mel_channels": 128, + "mel_fmin": 0.0, + "mel_fmax": null, + "add_blank": true, + "n_speakers": 300, + "cleaned_text": true + }, + "model": { + "inter_channels": 192, + "hidden_channels": 192, + "filter_channels": 768, + "n_heads": 2, + "n_layers": 6, + "kernel_size": 3, + "p_dropout": 0.1, + "resblock": "1", + "resblock_kernel_sizes": [ + 3, + 7, + 11 + ], + "resblock_dilation_sizes": [ + [ + 1, + 3, + 5 + ], + [ + 1, + 3, + 5 + ], + [ + 1, + 3, + 5 + ] + ], + "upsample_rates": [ + 10, + 8, + 2, + 2, + 2 + ], + "upsample_initial_channel": 512, + "upsample_kernel_sizes": [ + 16, + 16, + 8, + 2, + 2 + ], + "n_layers_q": 3, + "use_spectral_norm": false, + "gin_channels": 512, + "semantic_frame_rate": "25hz", + "freeze_quantizer": true + }, + "s2_ckpt_dir": "logs/s2/big2k1", + "content_module": "cnhubert" +} \ No newline at end of file diff --git a/configs/s2v2Pro.json b/configs/s2v2Pro.json new file mode 100644 index 0000000000000000000000000000000000000000..4eaee8001f7736fee4b4a8280c823847de9736ed --- /dev/null +++ b/configs/s2v2Pro.json @@ -0,0 +1,91 @@ +{ + "train": { + "log_interval": 100, + "eval_interval": 500, + "seed": 1234, + "epochs": 100, + "learning_rate": 0.0001, + "betas": [ + 0.8, + 0.99 + ], + "eps": 1e-09, + "batch_size": 32, + "fp16_run": true, + "lr_decay": 0.999875, + "segment_size": 20480, + "init_lr_ratio": 1, + "warmup_epochs": 0, + "c_mel": 45, + "c_kl": 1.0, + "text_low_lr_rate": 0.4, + "grad_ckpt": false + }, + "data": { + "max_wav_value": 32768.0, + "sampling_rate": 32000, + "filter_length": 2048, + "hop_length": 640, + "win_length": 2048, + "n_mel_channels": 128, + "mel_fmin": 0.0, + "mel_fmax": null, + "add_blank": true, + "n_speakers": 300, + "cleaned_text": true + }, + "model": { + "inter_channels": 192, + "hidden_channels": 192, + "filter_channels": 768, + "n_heads": 2, + "n_layers": 6, + "kernel_size": 3, + "p_dropout": 0.0, + "resblock": "1", + "resblock_kernel_sizes": [ + 3, + 7, + 11 + ], + "resblock_dilation_sizes": [ + [ + 1, + 3, + 5 + ], + [ + 1, + 3, + 5 + ], + [ + 1, + 3, + 5 + ] + ], + "upsample_rates": [ + 10, + 8, + 2, + 2, 
+ 2 + ], + "upsample_initial_channel": 512, + "upsample_kernel_sizes": [ + 16, + 16, + 8, + 2, + 2 + ], + "n_layers_q": 3, + "use_spectral_norm": false, + "gin_channels": 1024, + "semantic_frame_rate": "25hz", + "freeze_quantizer": true + }, + "s2_ckpt_dir": "logs/s2/big2k1", + "content_module": "cnhubert" +} \ No newline at end of file diff --git a/configs/s2v2ProPlus.json b/configs/s2v2ProPlus.json new file mode 100644 index 0000000000000000000000000000000000000000..37d8e16893729ff5dd816b1d86a1b38cd7d03e77 --- /dev/null +++ b/configs/s2v2ProPlus.json @@ -0,0 +1,91 @@ +{ + "train": { + "log_interval": 100, + "eval_interval": 500, + "seed": 1234, + "epochs": 100, + "learning_rate": 0.0001, + "betas": [ + 0.8, + 0.99 + ], + "eps": 1e-09, + "batch_size": 32, + "fp16_run": true, + "lr_decay": 0.999875, + "segment_size": 20480, + "init_lr_ratio": 1, + "warmup_epochs": 0, + "c_mel": 45, + "c_kl": 1.0, + "text_low_lr_rate": 0.4, + "grad_ckpt": false + }, + "data": { + "max_wav_value": 32768.0, + "sampling_rate": 32000, + "filter_length": 2048, + "hop_length": 640, + "win_length": 2048, + "n_mel_channels": 128, + "mel_fmin": 0.0, + "mel_fmax": null, + "add_blank": true, + "n_speakers": 300, + "cleaned_text": true + }, + "model": { + "inter_channels": 192, + "hidden_channels": 192, + "filter_channels": 768, + "n_heads": 2, + "n_layers": 6, + "kernel_size": 3, + "p_dropout": 0.0, + "resblock": "1", + "resblock_kernel_sizes": [ + 3, + 7, + 11 + ], + "resblock_dilation_sizes": [ + [ + 1, + 3, + 5 + ], + [ + 1, + 3, + 5 + ], + [ + 1, + 3, + 5 + ] + ], + "upsample_rates": [ + 10, + 8, + 2, + 2, + 2 + ], + "upsample_initial_channel": 768, + "upsample_kernel_sizes": [ + 20, + 16, + 8, + 2, + 2 + ], + "n_layers_q": 3, + "use_spectral_norm": false, + "gin_channels": 1024, + "semantic_frame_rate": "25hz", + "freeze_quantizer": true + }, + "s2_ckpt_dir": "logs/s2/big2k1", + "content_module": "cnhubert" +} \ No newline at end of file diff --git a/download.py b/download.py new file mode 100644 index 0000000000000000000000000000000000000000..fc4ead63bfe3c15326212a6ebabe2dac166e0ff2 --- /dev/null +++ b/download.py @@ -0,0 +1,13 @@ +import os +import sys + +now_dir = os.getcwd() +sys.path.insert(0, now_dir) +from text.g2pw import G2PWPinyin + +g2pw = G2PWPinyin( + model_dir="GPT_SoVITS/text/G2PWModel", + model_source="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large", + v_to_u=False, + neutral_tone_with_five=True, +) diff --git a/eres2net/ERes2Net.py b/eres2net/ERes2Net.py new file mode 100644 index 0000000000000000000000000000000000000000..1618c8139f5be521835d87af487ac8180120087b --- /dev/null +++ b/eres2net/ERes2Net.py @@ -0,0 +1,264 @@ +# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved. +# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +""" +Res2Net implementation is adapted from https://github.com/wenet-e2e/wespeaker. +ERes2Net incorporates both local and global feature fusion techniques to improve the performance. +The local feature fusion (LFF) fuses the features within one single residual block to extract the local signal. +The global feature fusion (GFF) takes acoustic features of different scales as input to aggregate global signal. 
+""" + +import torch +import math +import torch.nn as nn +import torch.nn.functional as F +import pooling_layers as pooling_layers +from fusion import AFF + + +class ReLU(nn.Hardtanh): + def __init__(self, inplace=False): + super(ReLU, self).__init__(0, 20, inplace) + + def __repr__(self): + inplace_str = "inplace" if self.inplace else "" + return self.__class__.__name__ + " (" + inplace_str + ")" + + +class BasicBlockERes2Net(nn.Module): + expansion = 2 + + def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2): + super(BasicBlockERes2Net, self).__init__() + width = int(math.floor(planes * (baseWidth / 64.0))) + self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False) + self.bn1 = nn.BatchNorm2d(width * scale) + self.nums = scale + + convs = [] + bns = [] + for i in range(self.nums): + convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False)) + bns.append(nn.BatchNorm2d(width)) + self.convs = nn.ModuleList(convs) + self.bns = nn.ModuleList(bns) + self.relu = ReLU(inplace=True) + + self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(self.expansion * planes), + ) + self.stride = stride + self.width = width + self.scale = scale + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + spx = torch.split(out, self.width, 1) + for i in range(self.nums): + if i == 0: + sp = spx[i] + else: + sp = sp + spx[i] + sp = self.convs[i](sp) + sp = self.relu(self.bns[i](sp)) + if i == 0: + out = sp + else: + out = torch.cat((out, sp), 1) + + out = self.conv3(out) + out = self.bn3(out) + + residual = self.shortcut(x) + out += residual + out = self.relu(out) + + return out + + +class BasicBlockERes2Net_diff_AFF(nn.Module): + expansion = 2 + + def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2): + super(BasicBlockERes2Net_diff_AFF, self).__init__() + width = int(math.floor(planes * (baseWidth / 64.0))) + self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False) + self.bn1 = nn.BatchNorm2d(width * scale) + self.nums = scale + + convs = [] + fuse_models = [] + bns = [] + for i in range(self.nums): + convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False)) + bns.append(nn.BatchNorm2d(width)) + for j in range(self.nums - 1): + fuse_models.append(AFF(channels=width)) + + self.convs = nn.ModuleList(convs) + self.bns = nn.ModuleList(bns) + self.fuse_models = nn.ModuleList(fuse_models) + self.relu = ReLU(inplace=True) + + self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(self.expansion * planes), + ) + self.stride = stride + self.width = width + self.scale = scale + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + spx = torch.split(out, self.width, 1) + for i in range(self.nums): + if i == 0: + sp = spx[i] + else: + sp = 
self.fuse_models[i - 1](sp, spx[i]) + + sp = self.convs[i](sp) + sp = self.relu(self.bns[i](sp)) + if i == 0: + out = sp + else: + out = torch.cat((out, sp), 1) + + out = self.conv3(out) + out = self.bn3(out) + + residual = self.shortcut(x) + out += residual + out = self.relu(out) + + return out + + +class ERes2Net(nn.Module): + def __init__( + self, + block=BasicBlockERes2Net, + block_fuse=BasicBlockERes2Net_diff_AFF, + num_blocks=[3, 4, 6, 3], + m_channels=32, + feat_dim=80, + embedding_size=192, + pooling_func="TSTP", + two_emb_layer=False, + ): + super(ERes2Net, self).__init__() + self.in_planes = m_channels + self.feat_dim = feat_dim + self.embedding_size = embedding_size + self.stats_dim = int(feat_dim / 8) * m_channels * 8 + self.two_emb_layer = two_emb_layer + + self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(m_channels) + self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=1) + self.layer2 = self._make_layer(block, m_channels * 2, num_blocks[1], stride=2) + self.layer3 = self._make_layer(block_fuse, m_channels * 4, num_blocks[2], stride=2) + self.layer4 = self._make_layer(block_fuse, m_channels * 8, num_blocks[3], stride=2) + + # Downsampling module for each layer + self.layer1_downsample = nn.Conv2d( + m_channels * 2, m_channels * 4, kernel_size=3, stride=2, padding=1, bias=False + ) + self.layer2_downsample = nn.Conv2d( + m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2, bias=False + ) + self.layer3_downsample = nn.Conv2d( + m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2, bias=False + ) + + # Bottom-up fusion module + self.fuse_mode12 = AFF(channels=m_channels * 4) + self.fuse_mode123 = AFF(channels=m_channels * 8) + self.fuse_mode1234 = AFF(channels=m_channels * 16) + + self.n_stats = 1 if pooling_func == "TAP" or pooling_func == "TSDP" else 2 + self.pool = getattr(pooling_layers, pooling_func)(in_dim=self.stats_dim * block.expansion) + self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats, embedding_size) + if self.two_emb_layer: + self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False) + self.seg_2 = nn.Linear(embedding_size, embedding_size) + else: + self.seg_bn_1 = nn.Identity() + self.seg_2 = nn.Identity() + + def _make_layer(self, block, planes, num_blocks, stride): + strides = [stride] + [1] * (num_blocks - 1) + layers = [] + for stride in strides: + layers.append(block(self.in_planes, planes, stride)) + self.in_planes = planes * block.expansion + return nn.Sequential(*layers) + + def forward(self, x): + x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T) + x = x.unsqueeze_(1) + out = F.relu(self.bn1(self.conv1(x))) + out1 = self.layer1(out) + out2 = self.layer2(out1) + out1_downsample = self.layer1_downsample(out1) + fuse_out12 = self.fuse_mode12(out2, out1_downsample) + out3 = self.layer3(out2) + fuse_out12_downsample = self.layer2_downsample(fuse_out12) + fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample) + out4 = self.layer4(out3) + fuse_out123_downsample = self.layer3_downsample(fuse_out123) + fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample) + stats = self.pool(fuse_out1234) + + embed_a = self.seg_1(stats) + if self.two_emb_layer: + out = F.relu(embed_a) + out = self.seg_bn_1(out) + embed_b = self.seg_2(out) + return embed_b + else: + return embed_a + + def forward3(self, x): + x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T) + x = x.unsqueeze_(1) + out = F.relu(self.bn1(self.conv1(x))) + out1 = 
self.layer1(out) + out2 = self.layer2(out1) + out1_downsample = self.layer1_downsample(out1) + fuse_out12 = self.fuse_mode12(out2, out1_downsample) + out3 = self.layer3(out2) + fuse_out12_downsample = self.layer2_downsample(fuse_out12) + fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample) + out4 = self.layer4(out3) + fuse_out123_downsample = self.layer3_downsample(fuse_out123) + fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1, end_dim=2).mean(-1) + return fuse_out1234 + + +if __name__ == "__main__": + x = torch.zeros(10, 300, 80) + model = ERes2Net(feat_dim=80, embedding_size=192, pooling_func="TSTP") + model.eval() + out = model(x) + print(out.shape) # torch.Size([10, 192]) + + num_params = sum(param.numel() for param in model.parameters()) + print("{} M".format(num_params / 1e6)) # 6.61M diff --git a/eres2net/ERes2NetV2.py b/eres2net/ERes2NetV2.py new file mode 100644 index 0000000000000000000000000000000000000000..2e152a4193c648130911dd1b385e4de24489f907 --- /dev/null +++ b/eres2net/ERes2NetV2.py @@ -0,0 +1,272 @@ +# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved. +# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +""" +To further improve the short-duration feature extraction capability of ERes2Net, we expand the channel dimension +within each stage. However, this modification also increases the number of model parameters and computational complexity. +To alleviate this problem, we propose an improved ERes2NetV2 by pruning redundant structures, ultimately reducing +both the model parameters and its computational cost. +""" + +import torch +import math +import torch.nn as nn +import torch.nn.functional as F +import pooling_layers as pooling_layers +from fusion import AFF + + +class ReLU(nn.Hardtanh): + def __init__(self, inplace=False): + super(ReLU, self).__init__(0, 20, inplace) + + def __repr__(self): + inplace_str = "inplace" if self.inplace else "" + return self.__class__.__name__ + " (" + inplace_str + ")" + + +class BasicBlockERes2NetV2(nn.Module): + def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2, expansion=2): + super(BasicBlockERes2NetV2, self).__init__() + width = int(math.floor(planes * (baseWidth / 64.0))) + self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False) + self.bn1 = nn.BatchNorm2d(width * scale) + self.nums = scale + self.expansion = expansion + + convs = [] + bns = [] + for i in range(self.nums): + convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False)) + bns.append(nn.BatchNorm2d(width)) + self.convs = nn.ModuleList(convs) + self.bns = nn.ModuleList(bns) + self.relu = ReLU(inplace=True) + + self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(self.expansion * planes), + ) + self.stride = stride + self.width = width + self.scale = scale + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + spx = torch.split(out, self.width, 1) + for i in range(self.nums): + if i == 0: + sp = spx[i] + else: + sp = sp + spx[i] + sp = self.convs[i](sp) + sp = self.relu(self.bns[i](sp)) + if i == 0: 
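+ # i == 0 seeds the output with the first scale branch; later branches are concatenated along the channel dimension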
+ out = sp + else: + out = torch.cat((out, sp), 1) + + out = self.conv3(out) + out = self.bn3(out) + + residual = self.shortcut(x) + out += residual + out = self.relu(out) + + return out + + +class BasicBlockERes2NetV2AFF(nn.Module): + def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2, expansion=2): + super(BasicBlockERes2NetV2AFF, self).__init__() + width = int(math.floor(planes * (baseWidth / 64.0))) + self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False) + self.bn1 = nn.BatchNorm2d(width * scale) + self.nums = scale + self.expansion = expansion + + convs = [] + fuse_models = [] + bns = [] + for i in range(self.nums): + convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False)) + bns.append(nn.BatchNorm2d(width)) + for j in range(self.nums - 1): + fuse_models.append(AFF(channels=width, r=4)) + + self.convs = nn.ModuleList(convs) + self.bns = nn.ModuleList(bns) + self.fuse_models = nn.ModuleList(fuse_models) + self.relu = ReLU(inplace=True) + + self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(self.expansion * planes), + ) + self.stride = stride + self.width = width + self.scale = scale + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + spx = torch.split(out, self.width, 1) + for i in range(self.nums): + if i == 0: + sp = spx[i] + else: + sp = self.fuse_models[i - 1](sp, spx[i]) + + sp = self.convs[i](sp) + sp = self.relu(self.bns[i](sp)) + if i == 0: + out = sp + else: + out = torch.cat((out, sp), 1) + + out = self.conv3(out) + out = self.bn3(out) + + residual = self.shortcut(x) + out += residual + out = self.relu(out) + + return out + + +class ERes2NetV2(nn.Module): + def __init__( + self, + block=BasicBlockERes2NetV2, + block_fuse=BasicBlockERes2NetV2AFF, + num_blocks=[3, 4, 6, 3], + m_channels=64, + feat_dim=80, + embedding_size=192, + baseWidth=26, + scale=2, + expansion=2, + pooling_func="TSTP", + two_emb_layer=False, + ): + super(ERes2NetV2, self).__init__() + self.in_planes = m_channels + self.feat_dim = feat_dim + self.embedding_size = embedding_size + self.stats_dim = int(feat_dim / 8) * m_channels * 8 + self.two_emb_layer = two_emb_layer + self.baseWidth = baseWidth + self.scale = scale + self.expansion = expansion + + self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(m_channels) + self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=1) + self.layer2 = self._make_layer(block, m_channels * 2, num_blocks[1], stride=2) + self.layer3 = self._make_layer(block_fuse, m_channels * 4, num_blocks[2], stride=2) + self.layer4 = self._make_layer(block_fuse, m_channels * 8, num_blocks[3], stride=2) + + # Downsampling module + self.layer3_ds = nn.Conv2d( + m_channels * 4 * self.expansion, + m_channels * 8 * self.expansion, + kernel_size=3, + padding=1, + stride=2, + bias=False, + ) + + # Bottom-up fusion module + self.fuse34 = AFF(channels=m_channels * 8 * self.expansion, r=4) + + self.n_stats = 1 if pooling_func == "TAP" or pooling_func == "TSDP" else 2 + self.pool = getattr(pooling_layers, pooling_func)(in_dim=self.stats_dim * self.expansion) + self.seg_1 = 
nn.Linear(self.stats_dim * self.expansion * self.n_stats, embedding_size) + if self.two_emb_layer: + self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False) + self.seg_2 = nn.Linear(embedding_size, embedding_size) + else: + self.seg_bn_1 = nn.Identity() + self.seg_2 = nn.Identity() + + def _make_layer(self, block, planes, num_blocks, stride): + strides = [stride] + [1] * (num_blocks - 1) + layers = [] + for stride in strides: + layers.append( + block( + self.in_planes, planes, stride, baseWidth=self.baseWidth, scale=self.scale, expansion=self.expansion + ) + ) + self.in_planes = planes * self.expansion + return nn.Sequential(*layers) + + def forward(self, x): + x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T) + x = x.unsqueeze_(1) + out = F.relu(self.bn1(self.conv1(x))) + out1 = self.layer1(out) + out2 = self.layer2(out1) + out3 = self.layer3(out2) + out4 = self.layer4(out3) + out3_ds = self.layer3_ds(out3) + fuse_out34 = self.fuse34(out4, out3_ds) + stats = self.pool(fuse_out34) + + embed_a = self.seg_1(stats) + if self.two_emb_layer: + out = F.relu(embed_a) + out = self.seg_bn_1(out) + embed_b = self.seg_2(out) + return embed_b + else: + return embed_a + + def forward3(self, x): + x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T) + x = x.unsqueeze_(1) + out = F.relu(self.bn1(self.conv1(x))) + out1 = self.layer1(out) + out2 = self.layer2(out1) + out3 = self.layer3(out2) + out4 = self.layer4(out3) + out3_ds = self.layer3_ds(out3) + fuse_out34 = self.fuse34(out4, out3_ds) + # print(111111111,fuse_out34.shape)#111111111 torch.Size([16, 2048, 10, 72]) + return fuse_out34.flatten(start_dim=1, end_dim=2).mean(-1) + # stats = self.pool(fuse_out34) + # + # embed_a = self.seg_1(stats) + # if self.two_emb_layer: + # out = F.relu(embed_a) + # out = self.seg_bn_1(out) + # embed_b = self.seg_2(out) + # return embed_b + # else: + # return embed_a + + +if __name__ == "__main__": + x = torch.randn(1, 300, 80) + model = ERes2NetV2(feat_dim=80, embedding_size=192, m_channels=64, baseWidth=26, scale=2, expansion=2) + model.eval() + y = model(x) + print(y.size()) + macs, num_params = profile(model, inputs=(x,)) + print("Params: {} M".format(num_params / 1e6)) # 17.86 M + print("MACs: {} G".format(macs / 1e9)) # 12.69 G diff --git a/eres2net/ERes2Net_huge.py b/eres2net/ERes2Net_huge.py new file mode 100644 index 0000000000000000000000000000000000000000..0f04236b0890b0ae862cf4ff5469d686c9905919 --- /dev/null +++ b/eres2net/ERes2Net_huge.py @@ -0,0 +1,289 @@ +# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved. +# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Res2Net implementation is adapted from https://github.com/wenet-e2e/wespeaker. +ERes2Net incorporates both local and global feature fusion techniques to improve the performance. +The local feature fusion (LFF) fuses the features within one single residual block to extract the local signal. +The global feature fusion (GFF) takes acoustic features of different scales as input to aggregate global signal. +ERes2Net-huge is an upgraded version of ERes2Net that uses a larger number of parameters to achieve better +recognition performance. Parameters expansion, baseWidth, and scale can be modified to obtain optimal performance. 
+""" + +import torch +import math +import torch.nn as nn +import torch.nn.functional as F +import pooling_layers as pooling_layers +from fusion import AFF + + +class ReLU(nn.Hardtanh): + def __init__(self, inplace=False): + super(ReLU, self).__init__(0, 20, inplace) + + def __repr__(self): + inplace_str = "inplace" if self.inplace else "" + return self.__class__.__name__ + " (" + inplace_str + ")" + + +class BasicBlockERes2Net(nn.Module): + expansion = 4 + + def __init__(self, in_planes, planes, stride=1, baseWidth=24, scale=3): + super(BasicBlockERes2Net, self).__init__() + width = int(math.floor(planes * (baseWidth / 64.0))) + self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False) + self.bn1 = nn.BatchNorm2d(width * scale) + self.nums = scale + + convs = [] + bns = [] + for i in range(self.nums): + convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False)) + bns.append(nn.BatchNorm2d(width)) + self.convs = nn.ModuleList(convs) + self.bns = nn.ModuleList(bns) + self.relu = ReLU(inplace=True) + + self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(self.expansion * planes), + ) + self.stride = stride + self.width = width + self.scale = scale + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + spx = torch.split(out, self.width, 1) + for i in range(self.nums): + if i == 0: + sp = spx[i] + else: + sp = sp + spx[i] + sp = self.convs[i](sp) + sp = self.relu(self.bns[i](sp)) + if i == 0: + out = sp + else: + out = torch.cat((out, sp), 1) + + out = self.conv3(out) + out = self.bn3(out) + + residual = self.shortcut(x) + out += residual + out = self.relu(out) + + return out + + +class BasicBlockERes2Net_diff_AFF(nn.Module): + expansion = 4 + + def __init__(self, in_planes, planes, stride=1, baseWidth=24, scale=3): + super(BasicBlockERes2Net_diff_AFF, self).__init__() + width = int(math.floor(planes * (baseWidth / 64.0))) + self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False) + self.bn1 = nn.BatchNorm2d(width * scale) + self.nums = scale + + convs = [] + fuse_models = [] + bns = [] + for i in range(self.nums): + convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False)) + bns.append(nn.BatchNorm2d(width)) + for j in range(self.nums - 1): + fuse_models.append(AFF(channels=width)) + + self.convs = nn.ModuleList(convs) + self.bns = nn.ModuleList(bns) + self.fuse_models = nn.ModuleList(fuse_models) + self.relu = ReLU(inplace=True) + + self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(self.expansion * planes), + ) + self.stride = stride + self.width = width + self.scale = scale + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + spx = torch.split(out, self.width, 1) + for i in range(self.nums): + if i == 0: + sp = spx[i] + else: + sp = 
self.fuse_models[i - 1](sp, spx[i]) + + sp = self.convs[i](sp) + sp = self.relu(self.bns[i](sp)) + if i == 0: + out = sp + else: + out = torch.cat((out, sp), 1) + + out = self.conv3(out) + out = self.bn3(out) + + residual = self.shortcut(x) + out += residual + out = self.relu(out) + + return out + + +class ERes2Net(nn.Module): + def __init__( + self, + block=BasicBlockERes2Net, + block_fuse=BasicBlockERes2Net_diff_AFF, + num_blocks=[3, 4, 6, 3], + m_channels=64, + feat_dim=80, + embedding_size=192, + pooling_func="TSTP", + two_emb_layer=False, + ): + super(ERes2Net, self).__init__() + self.in_planes = m_channels + self.feat_dim = feat_dim + self.embedding_size = embedding_size + self.stats_dim = int(feat_dim / 8) * m_channels * 8 + self.two_emb_layer = two_emb_layer + + self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(m_channels) + + self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=1) + self.layer2 = self._make_layer(block, m_channels * 2, num_blocks[1], stride=2) + self.layer3 = self._make_layer(block_fuse, m_channels * 4, num_blocks[2], stride=2) + self.layer4 = self._make_layer(block_fuse, m_channels * 8, num_blocks[3], stride=2) + + self.layer1_downsample = nn.Conv2d( + m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2, bias=False + ) + self.layer2_downsample = nn.Conv2d( + m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2, bias=False + ) + self.layer3_downsample = nn.Conv2d( + m_channels * 16, m_channels * 32, kernel_size=3, padding=1, stride=2, bias=False + ) + + self.fuse_mode12 = AFF(channels=m_channels * 8) + self.fuse_mode123 = AFF(channels=m_channels * 16) + self.fuse_mode1234 = AFF(channels=m_channels * 32) + + self.n_stats = 1 if pooling_func == "TAP" or pooling_func == "TSDP" else 2 + self.pool = getattr(pooling_layers, pooling_func)(in_dim=self.stats_dim * block.expansion) + self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats, embedding_size) + if self.two_emb_layer: + self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False) + self.seg_2 = nn.Linear(embedding_size, embedding_size) + else: + self.seg_bn_1 = nn.Identity() + self.seg_2 = nn.Identity() + + def _make_layer(self, block, planes, num_blocks, stride): + strides = [stride] + [1] * (num_blocks - 1) + layers = [] + for stride in strides: + layers.append(block(self.in_planes, planes, stride)) + self.in_planes = planes * block.expansion + return nn.Sequential(*layers) + + def forward(self, x): + x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T) + + x = x.unsqueeze_(1) + out = F.relu(self.bn1(self.conv1(x))) + out1 = self.layer1(out) + out2 = self.layer2(out1) + out1_downsample = self.layer1_downsample(out1) + fuse_out12 = self.fuse_mode12(out2, out1_downsample) + out3 = self.layer3(out2) + fuse_out12_downsample = self.layer2_downsample(fuse_out12) + fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample) + out4 = self.layer4(out3) + fuse_out123_downsample = self.layer3_downsample(fuse_out123) + fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample) + stats = self.pool(fuse_out1234) + + embed_a = self.seg_1(stats) + if self.two_emb_layer: + out = F.relu(embed_a) + out = self.seg_bn_1(out) + embed_b = self.seg_2(out) + return embed_b + else: + return embed_a + + def forward2(self, x, if_mean): + x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T) + + x = x.unsqueeze_(1) + out = F.relu(self.bn1(self.conv1(x))) + out1 = self.layer1(out) + out2 = self.layer2(out1) + out1_downsample = 
self.layer1_downsample(out1) + fuse_out12 = self.fuse_mode12(out2, out1_downsample) + out3 = self.layer3(out2) + fuse_out12_downsample = self.layer2_downsample(fuse_out12) + fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample) + out4 = self.layer4(out3) + fuse_out123_downsample = self.layer3_downsample(fuse_out123) + fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1, end_dim=2) # bs,20480,T + if if_mean == False: + mean = fuse_out1234[0].transpose(1, 0) # (T,20480),bs=T + else: + mean = fuse_out1234.mean(2) # bs,20480 + mean_std = torch.cat([mean, torch.zeros_like(mean)], 1) + return self.seg_1(mean_std) # (T,192) + + # stats = self.pool(fuse_out1234) + # if self.two_emb_layer: + # out = F.relu(embed_a) + # out = self.seg_bn_1(out) + # embed_b = self.seg_2(out) + # return embed_b + # else: + # return embed_a + + def forward3(self, x): + x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T) + + x = x.unsqueeze_(1) + out = F.relu(self.bn1(self.conv1(x))) + out1 = self.layer1(out) + out2 = self.layer2(out1) + out1_downsample = self.layer1_downsample(out1) + fuse_out12 = self.fuse_mode12(out2, out1_downsample) + out3 = self.layer3(out2) + fuse_out12_downsample = self.layer2_downsample(fuse_out12) + fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample) + out4 = self.layer4(out3) + fuse_out123_downsample = self.layer3_downsample(fuse_out123) + fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1, end_dim=2).mean(-1) + return fuse_out1234 + # print(fuse_out1234.shape) + # print(fuse_out1234.flatten(start_dim=1,end_dim=2).shape) + # pdb.set_trace() diff --git a/eres2net/fusion.py b/eres2net/fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..d156a55c8c7e5a730d40c2fe30879c5b1c7cba61 --- /dev/null +++ b/eres2net/fusion.py @@ -0,0 +1,27 @@ +# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +import torch +import torch.nn as nn + + +class AFF(nn.Module): + def __init__(self, channels=64, r=4): + super(AFF, self).__init__() + inter_channels = int(channels // r) + + self.local_att = nn.Sequential( + nn.Conv2d(channels * 2, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(inter_channels), + nn.SiLU(inplace=True), + nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(channels), + ) + + def forward(self, x, ds_y): + xa = torch.cat((x, ds_y), dim=1) + x_att = self.local_att(xa) + x_att = 1.0 + torch.tanh(x_att) + xo = torch.mul(x, x_att) + torch.mul(ds_y, 2.0 - x_att) + + return xo diff --git a/eres2net/kaldi.py b/eres2net/kaldi.py new file mode 100644 index 0000000000000000000000000000000000000000..a80e5e6b727d40c81e8caf5121de9966ba8a750b --- /dev/null +++ b/eres2net/kaldi.py @@ -0,0 +1,844 @@ +import math +from typing import Tuple + +import torch +import torchaudio +from torch import Tensor + +__all__ = [ + "get_mel_banks", + "inverse_mel_scale", + "inverse_mel_scale_scalar", + "mel_scale", + "mel_scale_scalar", + "spectrogram", + "fbank", + "mfcc", + "vtln_warp_freq", + "vtln_warp_mel_freq", +] + +# numeric_limits::epsilon() 1.1920928955078125e-07 +EPSILON = torch.tensor(torch.finfo(torch.float).eps) +# 1 milliseconds = 0.001 seconds +MILLISECONDS_TO_SECONDS = 0.001 + +# window types +HAMMING = "hamming" +HANNING = "hanning" +POVEY = "povey" +RECTANGULAR = "rectangular" +BLACKMAN = "blackman" +WINDOWS = [HAMMING, HANNING, POVEY, RECTANGULAR, BLACKMAN] + + +def _get_epsilon(device, dtype): + return EPSILON.to(device=device, dtype=dtype) + + +def _next_power_of_2(x: int) -> int: + r"""Returns the smallest power of 2 that is greater than x""" + return 1 if x == 0 else 2 ** (x - 1).bit_length() + + +def _get_strided(waveform: Tensor, window_size: int, window_shift: int, snip_edges: bool) -> Tensor: + r"""Given a waveform (1D tensor of size ``num_samples``), it returns a 2D tensor (m, ``window_size``) + representing how the window is shifted along the waveform. Each row is a frame. + + Args: + waveform (Tensor): Tensor of size ``num_samples`` + window_size (int): Frame length + window_shift (int): Frame shift + snip_edges (bool): If True, end effects will be handled by outputting only frames that completely fit + in the file, and the number of frames depends on the frame_length. If False, the number of frames + depends only on the frame_shift, and we reflect the data at the ends. 
+ + Returns: + Tensor: 2D tensor of size (m, ``window_size``) where each row is a frame + """ + assert waveform.dim() == 1 + num_samples = waveform.size(0) + strides = (window_shift * waveform.stride(0), waveform.stride(0)) + + if snip_edges: + if num_samples < window_size: + return torch.empty((0, 0), dtype=waveform.dtype, device=waveform.device) + else: + m = 1 + (num_samples - window_size) // window_shift + else: + reversed_waveform = torch.flip(waveform, [0]) + m = (num_samples + (window_shift // 2)) // window_shift + pad = window_size // 2 - window_shift // 2 + pad_right = reversed_waveform + if pad > 0: + # torch.nn.functional.pad returns [2,1,0,1,2] for 'reflect' + # but we want [2, 1, 0, 0, 1, 2] + pad_left = reversed_waveform[-pad:] + waveform = torch.cat((pad_left, waveform, pad_right), dim=0) + else: + # pad is negative so we want to trim the waveform at the front + waveform = torch.cat((waveform[-pad:], pad_right), dim=0) + + sizes = (m, window_size) + return waveform.as_strided(sizes, strides) + + +def _feature_window_function( + window_type: str, + window_size: int, + blackman_coeff: float, + device: torch.device, + dtype: int, +) -> Tensor: + r"""Returns a window function with the given type and size""" + if window_type == HANNING: + return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype) + elif window_type == HAMMING: + return torch.hamming_window(window_size, periodic=False, alpha=0.54, beta=0.46, device=device, dtype=dtype) + elif window_type == POVEY: + # like hanning but goes to zero at edges + return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype).pow(0.85) + elif window_type == RECTANGULAR: + return torch.ones(window_size, device=device, dtype=dtype) + elif window_type == BLACKMAN: + a = 2 * math.pi / (window_size - 1) + window_function = torch.arange(window_size, device=device, dtype=dtype) + # can't use torch.blackman_window as they use different coefficients + return ( + blackman_coeff + - 0.5 * torch.cos(a * window_function) + + (0.5 - blackman_coeff) * torch.cos(2 * a * window_function) + ).to(device=device, dtype=dtype) + else: + raise Exception("Invalid window type " + window_type) + + +def _get_log_energy(strided_input: Tensor, epsilon: Tensor, energy_floor: float) -> Tensor: + r"""Returns the log energy of size (m) for a strided_input (m,*)""" + device, dtype = strided_input.device, strided_input.dtype + log_energy = torch.max(strided_input.pow(2).sum(1), epsilon).log() # size (m) + if energy_floor == 0.0: + return log_energy + return torch.max(log_energy, torch.tensor(math.log(energy_floor), device=device, dtype=dtype)) + + +def _get_waveform_and_window_properties( + waveform: Tensor, + channel: int, + sample_frequency: float, + frame_shift: float, + frame_length: float, + round_to_power_of_two: bool, + preemphasis_coefficient: float, +) -> Tuple[Tensor, int, int, int]: + r"""Gets the waveform and window properties""" + channel = max(channel, 0) + assert channel < waveform.size(0), "Invalid channel {} for size {}".format(channel, waveform.size(0)) + waveform = waveform[channel, :] # size (n) + window_shift = int(sample_frequency * frame_shift * MILLISECONDS_TO_SECONDS) + window_size = int(sample_frequency * frame_length * MILLISECONDS_TO_SECONDS) + padded_window_size = _next_power_of_2(window_size) if round_to_power_of_two else window_size + + assert 2 <= window_size <= len(waveform), "choose a window size {} that is [2, {}]".format( + window_size, len(waveform) + ) + assert 0 < window_shift, 
"`window_shift` must be greater than 0" + assert padded_window_size % 2 == 0, ( + "the padded `window_size` must be divisible by two. use `round_to_power_of_two` or change `frame_length`" + ) + assert 0.0 <= preemphasis_coefficient <= 1.0, "`preemphasis_coefficient` must be between [0,1]" + assert sample_frequency > 0, "`sample_frequency` must be greater than zero" + return waveform, window_shift, window_size, padded_window_size + + +def _get_window( + waveform: Tensor, + padded_window_size: int, + window_size: int, + window_shift: int, + window_type: str, + blackman_coeff: float, + snip_edges: bool, + raw_energy: bool, + energy_floor: float, + dither: float, + remove_dc_offset: bool, + preemphasis_coefficient: float, +) -> Tuple[Tensor, Tensor]: + r"""Gets a window and its log energy + + Returns: + (Tensor, Tensor): strided_input of size (m, ``padded_window_size``) and signal_log_energy of size (m) + """ + device, dtype = waveform.device, waveform.dtype + epsilon = _get_epsilon(device, dtype) + + # size (m, window_size) + strided_input = _get_strided(waveform, window_size, window_shift, snip_edges) + + if dither != 0.0: + rand_gauss = torch.randn(strided_input.shape, device=device, dtype=dtype) + strided_input = strided_input + rand_gauss * dither + + if remove_dc_offset: + # Subtract each row/frame by its mean + row_means = torch.mean(strided_input, dim=1).unsqueeze(1) # size (m, 1) + strided_input = strided_input - row_means + + if raw_energy: + # Compute the log energy of each row/frame before applying preemphasis and + # window function + signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor) # size (m) + + if preemphasis_coefficient != 0.0: + # strided_input[i,j] -= preemphasis_coefficient * strided_input[i, max(0, j-1)] for all i,j + offset_strided_input = torch.nn.functional.pad(strided_input.unsqueeze(0), (1, 0), mode="replicate").squeeze( + 0 + ) # size (m, window_size + 1) + strided_input = strided_input - preemphasis_coefficient * offset_strided_input[:, :-1] + + # Apply window_function to each row/frame + window_function = _feature_window_function(window_type, window_size, blackman_coeff, device, dtype).unsqueeze( + 0 + ) # size (1, window_size) + strided_input = strided_input * window_function # size (m, window_size) + + # Pad columns with zero until we reach size (m, padded_window_size) + if padded_window_size != window_size: + padding_right = padded_window_size - window_size + strided_input = torch.nn.functional.pad( + strided_input.unsqueeze(0), (0, padding_right), mode="constant", value=0 + ).squeeze(0) + + # Compute energy after window function (not the raw one) + if not raw_energy: + signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor) # size (m) + + return strided_input, signal_log_energy + + +def _subtract_column_mean(tensor: Tensor, subtract_mean: bool) -> Tensor: + # subtracts the column mean of the tensor size (m, n) if subtract_mean=True + # it returns size (m, n) + if subtract_mean: + col_means = torch.mean(tensor, dim=0).unsqueeze(0) + tensor = tensor - col_means + return tensor + + +def spectrogram( + waveform: Tensor, + blackman_coeff: float = 0.42, + channel: int = -1, + dither: float = 0.0, + energy_floor: float = 1.0, + frame_length: float = 25.0, + frame_shift: float = 10.0, + min_duration: float = 0.0, + preemphasis_coefficient: float = 0.97, + raw_energy: bool = True, + remove_dc_offset: bool = True, + round_to_power_of_two: bool = True, + sample_frequency: float = 16000.0, + snip_edges: bool = True, + 
subtract_mean: bool = False, + window_type: str = POVEY, +) -> Tensor: + r"""Create a spectrogram from a raw audio signal. This matches the input/output of Kaldi's + compute-spectrogram-feats. + + Args: + waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2) + blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``) + channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``) + dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set + the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``) + energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution: + this floor is applied to the zeroth component, representing the total signal energy. The floor on the + individual spectrogram elements is fixed at std::numeric_limits::epsilon(). (Default: ``1.0``) + frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``) + frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``) + min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``) + preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``) + raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``) + remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``) + round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input + to FFT. (Default: ``True``) + sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if + specified there) (Default: ``16000.0``) + snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit + in the file, and the number of frames depends on the frame_length. If False, the number of frames + depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``) + subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do + it this way. (Default: ``False``) + window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman') + (Default: ``'povey'``) + + Returns: + Tensor: A spectrogram identical to what Kaldi would output. 
The shape is + (m, ``padded_window_size // 2 + 1``) where m is calculated in _get_strided + """ + device, dtype = waveform.device, waveform.dtype + epsilon = _get_epsilon(device, dtype) + + waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties( + waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient + ) + + if len(waveform) < min_duration * sample_frequency: + # signal is too short + return torch.empty(0) + + strided_input, signal_log_energy = _get_window( + waveform, + padded_window_size, + window_size, + window_shift, + window_type, + blackman_coeff, + snip_edges, + raw_energy, + energy_floor, + dither, + remove_dc_offset, + preemphasis_coefficient, + ) + + # size (m, padded_window_size // 2 + 1, 2) + fft = torch.fft.rfft(strided_input) + + # Convert the FFT into a power spectrum + power_spectrum = torch.max(fft.abs().pow(2.0), epsilon).log() # size (m, padded_window_size // 2 + 1) + power_spectrum[:, 0] = signal_log_energy + + power_spectrum = _subtract_column_mean(power_spectrum, subtract_mean) + return power_spectrum + + +def inverse_mel_scale_scalar(mel_freq: float) -> float: + return 700.0 * (math.exp(mel_freq / 1127.0) - 1.0) + + +def inverse_mel_scale(mel_freq: Tensor) -> Tensor: + return 700.0 * ((mel_freq / 1127.0).exp() - 1.0) + + +def mel_scale_scalar(freq: float) -> float: + return 1127.0 * math.log(1.0 + freq / 700.0) + + +def mel_scale(freq: Tensor) -> Tensor: + return 1127.0 * (1.0 + freq / 700.0).log() + + +def vtln_warp_freq( + vtln_low_cutoff: float, + vtln_high_cutoff: float, + low_freq: float, + high_freq: float, + vtln_warp_factor: float, + freq: Tensor, +) -> Tensor: + r"""This computes a VTLN warping function that is not the same as HTK's one, + but has similar inputs (this function has the advantage of never producing + empty bins). + + This function computes a warp function F(freq), defined between low_freq + and high_freq inclusive, with the following properties: + F(low_freq) == low_freq + F(high_freq) == high_freq + The function is continuous and piecewise linear with two inflection + points. + The lower inflection point (measured in terms of the unwarped + frequency) is at frequency l, determined as described below. + The higher inflection point is at a frequency h, determined as + described below. + If l <= f <= h, then F(f) = f/vtln_warp_factor. + If the higher inflection point (measured in terms of the unwarped + frequency) is at h, then max(h, F(h)) == vtln_high_cutoff. + Since (by the last point) F(h) == h/vtln_warp_factor, then + max(h, h/vtln_warp_factor) == vtln_high_cutoff, so + h = vtln_high_cutoff / max(1, 1/vtln_warp_factor). + = vtln_high_cutoff * min(1, vtln_warp_factor). 
+ If the lower inflection point (measured in terms of the unwarped + frequency) is at l, then min(l, F(l)) == vtln_low_cutoff + This implies that l = vtln_low_cutoff / min(1, 1/vtln_warp_factor) + = vtln_low_cutoff * max(1, vtln_warp_factor) + Args: + vtln_low_cutoff (float): Lower frequency cutoffs for VTLN + vtln_high_cutoff (float): Upper frequency cutoffs for VTLN + low_freq (float): Lower frequency cutoffs in mel computation + high_freq (float): Upper frequency cutoffs in mel computation + vtln_warp_factor (float): Vtln warp factor + freq (Tensor): given frequency in Hz + + Returns: + Tensor: Freq after vtln warp + """ + assert vtln_low_cutoff > low_freq, "be sure to set the vtln_low option higher than low_freq" + assert vtln_high_cutoff < high_freq, "be sure to set the vtln_high option lower than high_freq [or negative]" + l = vtln_low_cutoff * max(1.0, vtln_warp_factor) + h = vtln_high_cutoff * min(1.0, vtln_warp_factor) + scale = 1.0 / vtln_warp_factor + Fl = scale * l # F(l) + Fh = scale * h # F(h) + assert l > low_freq and h < high_freq + # slope of left part of the 3-piece linear function + scale_left = (Fl - low_freq) / (l - low_freq) + # [slope of center part is just "scale"] + + # slope of right part of the 3-piece linear function + scale_right = (high_freq - Fh) / (high_freq - h) + + res = torch.empty_like(freq) + + outside_low_high_freq = torch.lt(freq, low_freq) | torch.gt(freq, high_freq) # freq < low_freq || freq > high_freq + before_l = torch.lt(freq, l) # freq < l + before_h = torch.lt(freq, h) # freq < h + after_h = torch.ge(freq, h) # freq >= h + + # order of operations matter here (since there is overlapping frequency regions) + res[after_h] = high_freq + scale_right * (freq[after_h] - high_freq) + res[before_h] = scale * freq[before_h] + res[before_l] = low_freq + scale_left * (freq[before_l] - low_freq) + res[outside_low_high_freq] = freq[outside_low_high_freq] + + return res + + +def vtln_warp_mel_freq( + vtln_low_cutoff: float, + vtln_high_cutoff: float, + low_freq, + high_freq: float, + vtln_warp_factor: float, + mel_freq: Tensor, +) -> Tensor: + r""" + Args: + vtln_low_cutoff (float): Lower frequency cutoffs for VTLN + vtln_high_cutoff (float): Upper frequency cutoffs for VTLN + low_freq (float): Lower frequency cutoffs in mel computation + high_freq (float): Upper frequency cutoffs in mel computation + vtln_warp_factor (float): Vtln warp factor + mel_freq (Tensor): Given frequency in Mel + + Returns: + Tensor: ``mel_freq`` after vtln warp + """ + return mel_scale( + vtln_warp_freq( + vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq, vtln_warp_factor, inverse_mel_scale(mel_freq) + ) + ) + + +def get_mel_banks( + num_bins: int, + window_length_padded: int, + sample_freq: float, + low_freq: float, + high_freq: float, + vtln_low: float, + vtln_high: float, + vtln_warp_factor: float, + device=None, + dtype=None, +) -> Tuple[Tensor, Tensor]: + """ + Returns: + (Tensor, Tensor): The tuple consists of ``bins`` (which is + melbank of size (``num_bins``, ``num_fft_bins``)) and ``center_freqs`` (which is + center frequencies of bins of size (``num_bins``)). + """ + assert num_bins > 3, "Must have at least 3 mel bins" + assert window_length_padded % 2 == 0 + num_fft_bins = window_length_padded / 2 + nyquist = 0.5 * sample_freq + + if high_freq <= 0.0: + high_freq += nyquist + + assert (0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq), ( + "Bad values in options: low-freq {} and high-freq {} vs. 
nyquist {}".format(low_freq, high_freq, nyquist) + ) + + # fft-bin width [think of it as Nyquist-freq / half-window-length] + fft_bin_width = sample_freq / window_length_padded + mel_low_freq = mel_scale_scalar(low_freq) + mel_high_freq = mel_scale_scalar(high_freq) + + # divide by num_bins+1 in next line because of end-effects where the bins + # spread out to the sides. + mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1) + + if vtln_high < 0.0: + vtln_high += nyquist + + assert vtln_warp_factor == 1.0 or ( + (low_freq < vtln_low < high_freq) and (0.0 < vtln_high < high_freq) and (vtln_low < vtln_high) + ), "Bad values in options: vtln-low {} and vtln-high {}, versus low-freq {} and high-freq {}".format( + vtln_low, vtln_high, low_freq, high_freq + ) + + bin = torch.arange(num_bins).unsqueeze(1) + left_mel = mel_low_freq + bin * mel_freq_delta # size(num_bins, 1) + center_mel = mel_low_freq + (bin + 1.0) * mel_freq_delta # size(num_bins, 1) + right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta # size(num_bins, 1) + + if vtln_warp_factor != 1.0: + left_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, left_mel) + center_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, center_mel) + right_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, right_mel) + + # center_freqs = inverse_mel_scale(center_mel) # size (num_bins) + # size(1, num_fft_bins) + mel = mel_scale(fft_bin_width * torch.arange(num_fft_bins)).unsqueeze(0) + + # size (num_bins, num_fft_bins) + up_slope = (mel - left_mel) / (center_mel - left_mel) + down_slope = (right_mel - mel) / (right_mel - center_mel) + + if vtln_warp_factor == 1.0: + # left_mel < center_mel < right_mel so we can min the two slopes and clamp negative values + bins = torch.max(torch.zeros(1), torch.min(up_slope, down_slope)) + else: + # warping can move the order of left_mel, center_mel, right_mel anywhere + bins = torch.zeros_like(up_slope) + up_idx = torch.gt(mel, left_mel) & torch.le(mel, center_mel) # left_mel < mel <= center_mel + down_idx = torch.gt(mel, center_mel) & torch.lt(mel, right_mel) # center_mel < mel < right_mel + bins[up_idx] = up_slope[up_idx] + bins[down_idx] = down_slope[down_idx] + + return bins.to(device=device, dtype=dtype) # , center_freqs + + +cache = {} + + +def fbank( + waveform: Tensor, + blackman_coeff: float = 0.42, + channel: int = -1, + dither: float = 0.0, + energy_floor: float = 1.0, + frame_length: float = 25.0, + frame_shift: float = 10.0, + high_freq: float = 0.0, + htk_compat: bool = False, + low_freq: float = 20.0, + min_duration: float = 0.0, + num_mel_bins: int = 23, + preemphasis_coefficient: float = 0.97, + raw_energy: bool = True, + remove_dc_offset: bool = True, + round_to_power_of_two: bool = True, + sample_frequency: float = 16000.0, + snip_edges: bool = True, + subtract_mean: bool = False, + use_energy: bool = False, + use_log_fbank: bool = True, + use_power: bool = True, + vtln_high: float = -500.0, + vtln_low: float = 100.0, + vtln_warp: float = 1.0, + window_type: str = POVEY, +) -> Tensor: + r"""Create a fbank from a raw audio signal. This matches the input/output of Kaldi's + compute-fbank-feats. + + Args: + waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2) + blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. 
(Default: ``0.42``) + channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``) + dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set + the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``) + energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution: + this floor is applied to the zeroth component, representing the total signal energy. The floor on the + individual spectrogram elements is fixed at std::numeric_limits::epsilon(). (Default: ``1.0``) + frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``) + frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``) + high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist) + (Default: ``0.0``) + htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible features + (need to change other parameters). (Default: ``False``) + low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``) + min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``) + num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``) + preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``) + raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``) + remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``) + round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input + to FFT. (Default: ``True``) + sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if + specified there) (Default: ``16000.0``) + snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit + in the file, and the number of frames depends on the frame_length. If False, the number of frames + depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``) + subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do + it this way. (Default: ``False``) + use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. (Default: ``False``) + use_log_fbank (bool, optional):If true, produce log-filterbank, else produce linear. (Default: ``True``) + use_power (bool, optional): If true, use power, else use magnitude. (Default: ``True``) + vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if + negative, offset from high-mel-freq (Default: ``-500.0``) + vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``) + vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``) + window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman') + (Default: ``'povey'``) + + Returns: + Tensor: A fbank identical to what Kaldi would output. 
The shape is (m, ``num_mel_bins + use_energy``) + where m is calculated in _get_strided + """ + device, dtype = waveform.device, waveform.dtype + + waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties( + waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient + ) + + if len(waveform) < min_duration * sample_frequency: + # signal is too short + return torch.empty(0, device=device, dtype=dtype) + + # strided_input, size (m, padded_window_size) and signal_log_energy, size (m) + strided_input, signal_log_energy = _get_window( + waveform, + padded_window_size, + window_size, + window_shift, + window_type, + blackman_coeff, + snip_edges, + raw_energy, + energy_floor, + dither, + remove_dc_offset, + preemphasis_coefficient, + ) + + # size (m, padded_window_size // 2 + 1) + spectrum = torch.fft.rfft(strided_input).abs() + if use_power: + spectrum = spectrum.pow(2.0) + + # size (num_mel_bins, padded_window_size // 2) + # print(num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp) + + cache_key = "%s-%s-%s-%s-%s-%s-%s-%s-%s-%s" % ( + num_mel_bins, + padded_window_size, + sample_frequency, + low_freq, + high_freq, + vtln_low, + vtln_high, + vtln_warp, + device, + dtype, + ) + if cache_key not in cache: + mel_energies = get_mel_banks( + num_mel_bins, + padded_window_size, + sample_frequency, + low_freq, + high_freq, + vtln_low, + vtln_high, + vtln_warp, + device, + dtype, + ) + cache[cache_key] = mel_energies + else: + mel_energies = cache[cache_key] + + # pad right column with zeros and add dimension, size (num_mel_bins, padded_window_size // 2 + 1) + mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode="constant", value=0) + + # sum with mel fiterbanks over the power spectrum, size (m, num_mel_bins) + mel_energies = torch.mm(spectrum, mel_energies.T) + if use_log_fbank: + # avoid log of zero (which should be prevented anyway by dithering) + mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log() + + # if use_energy then add it as the last column for htk_compat == true else first column + if use_energy: + signal_log_energy = signal_log_energy.unsqueeze(1) # size (m, 1) + # returns size (m, num_mel_bins + 1) + if htk_compat: + mel_energies = torch.cat((mel_energies, signal_log_energy), dim=1) + else: + mel_energies = torch.cat((signal_log_energy, mel_energies), dim=1) + + mel_energies = _subtract_column_mean(mel_energies, subtract_mean) + return mel_energies + + +def _get_dct_matrix(num_ceps: int, num_mel_bins: int) -> Tensor: + # returns a dct matrix of size (num_mel_bins, num_ceps) + # size (num_mel_bins, num_mel_bins) + dct_matrix = torchaudio.functional.create_dct(num_mel_bins, num_mel_bins, "ortho") + # kaldi expects the first cepstral to be weighted sum of factor sqrt(1/num_mel_bins) + # this would be the first column in the dct_matrix for torchaudio as it expects a + # right multiply (which would be the first column of the kaldi's dct_matrix as kaldi + # expects a left multiply e.g. dct_matrix * vector). + dct_matrix[:, 0] = math.sqrt(1 / float(num_mel_bins)) + dct_matrix = dct_matrix[:, :num_ceps] + return dct_matrix + + +def _get_lifter_coeffs(num_ceps: int, cepstral_lifter: float) -> Tensor: + # returns size (num_ceps) + # Compute liftering coefficients (scaling on cepstral coeffs) + # coeffs are numbered slightly differently from HTK: the zeroth index is C0, which is not affected. 
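+    # A short note on the formula used below: the lifter is 1 + (L/2) * sin(pi * n / L)
+    # with L = cepstral_lifter and n the cepstral index, so it leaves C0 untouched (n = 0)
+    # and peaks at 1 + L/2 (= 12 for the default L = 22) around n = L/2.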
+ i = torch.arange(num_ceps) + return 1.0 + 0.5 * cepstral_lifter * torch.sin(math.pi * i / cepstral_lifter) + + +def mfcc( + waveform: Tensor, + blackman_coeff: float = 0.42, + cepstral_lifter: float = 22.0, + channel: int = -1, + dither: float = 0.0, + energy_floor: float = 1.0, + frame_length: float = 25.0, + frame_shift: float = 10.0, + high_freq: float = 0.0, + htk_compat: bool = False, + low_freq: float = 20.0, + num_ceps: int = 13, + min_duration: float = 0.0, + num_mel_bins: int = 23, + preemphasis_coefficient: float = 0.97, + raw_energy: bool = True, + remove_dc_offset: bool = True, + round_to_power_of_two: bool = True, + sample_frequency: float = 16000.0, + snip_edges: bool = True, + subtract_mean: bool = False, + use_energy: bool = False, + vtln_high: float = -500.0, + vtln_low: float = 100.0, + vtln_warp: float = 1.0, + window_type: str = POVEY, +) -> Tensor: + r"""Create a mfcc from a raw audio signal. This matches the input/output of Kaldi's + compute-mfcc-feats. + + Args: + waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2) + blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``) + cepstral_lifter (float, optional): Constant that controls scaling of MFCCs (Default: ``22.0``) + channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``) + dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set + the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``) + energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution: + this floor is applied to the zeroth component, representing the total signal energy. The floor on the + individual spectrogram elements is fixed at std::numeric_limits::epsilon(). (Default: ``1.0``) + frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``) + frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``) + high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist) + (Default: ``0.0``) + htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible + features (need to change other parameters). (Default: ``False``) + low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``) + num_ceps (int, optional): Number of cepstra in MFCC computation (including C0) (Default: ``13``) + min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``) + num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``) + preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``) + raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``) + remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``) + round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input + to FFT. (Default: ``True``) + sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if + specified there) (Default: ``16000.0``) + snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit + in the file, and the number of frames depends on the frame_length. 
If False, the number of frames + depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``) + subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do + it this way. (Default: ``False``) + use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. (Default: ``False``) + vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if + negative, offset from high-mel-freq (Default: ``-500.0``) + vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``) + vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``) + window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman') + (Default: ``"povey"``) + + Returns: + Tensor: A mfcc identical to what Kaldi would output. The shape is (m, ``num_ceps``) + where m is calculated in _get_strided + """ + assert num_ceps <= num_mel_bins, "num_ceps cannot be larger than num_mel_bins: %d vs %d" % (num_ceps, num_mel_bins) + + device, dtype = waveform.device, waveform.dtype + + # The mel_energies should not be squared (use_power=True), not have mean subtracted + # (subtract_mean=False), and use log (use_log_fbank=True). + # size (m, num_mel_bins + use_energy) + feature = fbank( + waveform=waveform, + blackman_coeff=blackman_coeff, + channel=channel, + dither=dither, + energy_floor=energy_floor, + frame_length=frame_length, + frame_shift=frame_shift, + high_freq=high_freq, + htk_compat=htk_compat, + low_freq=low_freq, + min_duration=min_duration, + num_mel_bins=num_mel_bins, + preemphasis_coefficient=preemphasis_coefficient, + raw_energy=raw_energy, + remove_dc_offset=remove_dc_offset, + round_to_power_of_two=round_to_power_of_two, + sample_frequency=sample_frequency, + snip_edges=snip_edges, + subtract_mean=False, + use_energy=use_energy, + use_log_fbank=True, + use_power=True, + vtln_high=vtln_high, + vtln_low=vtln_low, + vtln_warp=vtln_warp, + window_type=window_type, + ) + + if use_energy: + # size (m) + signal_log_energy = feature[:, num_mel_bins if htk_compat else 0] + # offset is 0 if htk_compat==True else 1 + mel_offset = int(not htk_compat) + feature = feature[:, mel_offset : (num_mel_bins + mel_offset)] + + # size (num_mel_bins, num_ceps) + dct_matrix = _get_dct_matrix(num_ceps, num_mel_bins).to(dtype=dtype, device=device) + + # size (m, num_ceps) + feature = feature.matmul(dct_matrix) + + if cepstral_lifter != 0.0: + # size (1, num_ceps) + lifter_coeffs = _get_lifter_coeffs(num_ceps, cepstral_lifter).unsqueeze(0) + feature *= lifter_coeffs.to(device=device, dtype=dtype) + + # if use_energy then replace the last column for htk_compat == true else first column + if use_energy: + feature[:, 0] = signal_log_energy + + if htk_compat: + energy = feature[:, 0].unsqueeze(1) # size (m, 1) + feature = feature[:, 1:] # size (m, num_ceps - 1) + if not use_energy: + # scale on C0 (actually removing a scale we previously added that's + # part of one common definition of the cosine transform.) 
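+            # Concretely: the "ortho" DCT built in _get_dct_matrix scales C0 by
+            # sqrt(1/num_mel_bins) while HTK uses sqrt(2/num_mel_bins) for every
+            # coefficient, so multiplying the C0 column by sqrt(2) below restores
+            # the HTK convention.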
+ energy *= math.sqrt(2) + + feature = torch.cat((feature, energy), dim=1) + + feature = _subtract_column_mean(feature, subtract_mean) + return feature diff --git a/eres2net/pooling_layers.py b/eres2net/pooling_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..c3e0eab659befa56048da550218b2a557c61d8b3 --- /dev/null +++ b/eres2net/pooling_layers.py @@ -0,0 +1,101 @@ +# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved. +# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""This implementation is adapted from https://github.com/wenet-e2e/wespeaker.""" + +import torch +import torch.nn as nn + + +class TAP(nn.Module): + """ + Temporal average pooling, only first-order mean is considered + """ + + def __init__(self, **kwargs): + super(TAP, self).__init__() + + def forward(self, x): + pooling_mean = x.mean(dim=-1) + # To be compatable with 2D input + pooling_mean = pooling_mean.flatten(start_dim=1) + return pooling_mean + + +class TSDP(nn.Module): + """ + Temporal standard deviation pooling, only second-order std is considered + """ + + def __init__(self, **kwargs): + super(TSDP, self).__init__() + + def forward(self, x): + # The last dimension is the temporal axis + pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-8) + pooling_std = pooling_std.flatten(start_dim=1) + return pooling_std + + +class TSTP(nn.Module): + """ + Temporal statistics pooling, concatenate mean and std, which is used in + x-vector + Comment: simple concatenation can not make full use of both statistics + """ + + def __init__(self, **kwargs): + super(TSTP, self).__init__() + + def forward(self, x): + # The last dimension is the temporal axis + pooling_mean = x.mean(dim=-1) + pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-8) + pooling_mean = pooling_mean.flatten(start_dim=1) + pooling_std = pooling_std.flatten(start_dim=1) + + stats = torch.cat((pooling_mean, pooling_std), 1) + return stats + + +class ASTP(nn.Module): + """Attentive statistics pooling: Channel- and context-dependent + statistics pooling, first used in ECAPA_TDNN. + """ + + def __init__(self, in_dim, bottleneck_dim=128, global_context_att=False): + super(ASTP, self).__init__() + self.global_context_att = global_context_att + + # Use Conv1d with stride == 1 rather than Linear, then we don't + # need to transpose inputs. + if global_context_att: + self.linear1 = nn.Conv1d(in_dim * 3, bottleneck_dim, kernel_size=1) # equals W and b in the paper + else: + self.linear1 = nn.Conv1d(in_dim, bottleneck_dim, kernel_size=1) # equals W and b in the paper + self.linear2 = nn.Conv1d(bottleneck_dim, in_dim, kernel_size=1) # equals V and k in the paper + + def forward(self, x): + """ + x: a 3-dimensional tensor in tdnn-based architecture (B,F,T) + or a 4-dimensional tensor in resnet architecture (B,C,F,T) + 0-dim: batch-dimension, last-dim: time-dimension (frame-dimension) + """ + if len(x.shape) == 4: + x = x.reshape(x.shape[0], x.shape[1] * x.shape[2], x.shape[3]) + assert len(x.shape) == 3 + + if self.global_context_att: + context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x) + context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x) + x_in = torch.cat((x, context_mean, context_std), dim=1) + else: + x_in = x + + # DON'T use ReLU here! ReLU may be hard to converge. 
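+        # Attentive statistics pooling, as implemented below:
+        #   alpha = softmax(V * tanh(W * x_in + b) + k)  over the time axis (dim=2)
+        #   mean  = sum_t alpha * x,   std = sqrt(sum_t alpha * x^2 - mean^2)
+        # giving each channel its own attention weights across frames.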
+ alpha = torch.tanh(self.linear1(x_in)) # alpha = F.relu(self.linear1(x_in)) + alpha = torch.softmax(self.linear2(alpha), dim=2) + mean = torch.sum(alpha * x, dim=2) + var = torch.sum(alpha * (x**2), dim=2) - mean**2 + std = torch.sqrt(var.clamp(min=1e-10)) + return torch.cat([mean, std], dim=1) diff --git a/export_torch_script.py b/export_torch_script.py new file mode 100644 index 0000000000000000000000000000000000000000..a3a4827781cdee017d0960cc334d673df423aa25 --- /dev/null +++ b/export_torch_script.py @@ -0,0 +1,1097 @@ +# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py +# reference: https://github.com/lifeiteng/vall-e +import argparse +from io import BytesIO +from typing import Optional +from my_utils import load_audio +import torch +import torchaudio + +from torch import IntTensor, LongTensor, Tensor, nn +from torch.nn import functional as F + +from transformers import AutoModelForMaskedLM, AutoTokenizer +from feature_extractor import cnhubert + +from AR.models.t2s_lightning_module import Text2SemanticLightningModule +from module.models_onnx import SynthesizerTrn + +from inference_webui import get_phones_and_bert + +from sv import SV +import kaldi as Kaldi + +import os +import soundfile + +default_config = { + "embedding_dim": 512, + "hidden_dim": 512, + "num_head": 8, + "num_layers": 12, + "num_codebook": 8, + "p_dropout": 0.0, + "vocab_size": 1024 + 1, + "phoneme_vocab_size": 512, + "EOS": 1024, +} + +sv_cn_model = None + + +def init_sv_cn(device, is_half): + global sv_cn_model + sv_cn_model = SV(device, is_half) + + +def load_sovits_new(sovits_path): + f = open(sovits_path, "rb") + meta = f.read(2) + if meta != b"PK": + data = b"PK" + f.read() + bio = BytesIO() + bio.write(data) + bio.seek(0) + return torch.load(bio, map_location="cpu", weights_only=False) + return torch.load(sovits_path, map_location="cpu", weights_only=False) + + +def get_raw_t2s_model(dict_s1) -> Text2SemanticLightningModule: + config = dict_s1["config"] + config["model"]["dropout"] = float(config["model"]["dropout"]) + t2s_model = Text2SemanticLightningModule(config, "****", is_train=False) + t2s_model.load_state_dict(dict_s1["weight"]) + t2s_model = t2s_model.eval() + return t2s_model + + +@torch.jit.script +def logits_to_probs( + logits, + previous_tokens: Optional[torch.Tensor] = None, + temperature: float = 1.0, + top_k: Optional[int] = None, + top_p: Optional[int] = None, + repetition_penalty: float = 1.0, +): + # if previous_tokens is not None: + # previous_tokens = previous_tokens.squeeze() + # print(logits.shape,previous_tokens.shape) + # pdb.set_trace() + if previous_tokens is not None and repetition_penalty != 1.0: + previous_tokens = previous_tokens.long() + score = torch.gather(logits, dim=1, index=previous_tokens) + score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty) + logits.scatter_(dim=1, index=previous_tokens, src=score) + + if top_p is not None and top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1) + sorted_indices_to_remove = cum_probs > top_p + sorted_indices_to_remove[:, 0] = False # keep at least one option + indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove) + logits = logits.masked_fill(indices_to_remove, -float("Inf")) + + logits = logits / max(temperature, 1e-5) + + if top_k is not None: + v, _ = torch.topk(logits, 
min(top_k, logits.size(-1))) + pivot = v[:, -1].unsqueeze(-1) + logits = torch.where(logits < pivot, -float("Inf"), logits) + + probs = torch.nn.functional.softmax(logits, dim=-1) + return probs + + +@torch.jit.script +def multinomial_sample_one_no_sync(probs_sort): + # Does multinomial sampling without a cuda synchronization + q = torch.empty_like(probs_sort).exponential_(1.0) + return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) + + +@torch.jit.script +def sample( + logits, + previous_tokens, + temperature: float = 1.0, + top_k: Optional[int] = None, + top_p: Optional[int] = None, + repetition_penalty: float = 1.35, +): + probs = logits_to_probs( + logits=logits, + previous_tokens=previous_tokens, + temperature=temperature, + top_k=top_k, + top_p=top_p, + repetition_penalty=repetition_penalty, + ) + idx_next = multinomial_sample_one_no_sync(probs) + return idx_next, probs + + +@torch.jit.script +def spectrogram_torch( + hann_window: Tensor, y: Tensor, n_fft: int, sampling_rate: int, hop_size: int, win_size: int, center: bool = False +): + # hann_window = torch.hann_window(win_size, device=y.device, dtype=y.dtype) + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + spec = torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window, + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=False, + ) + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + return spec + + +class DictToAttrRecursive(dict): + def __init__(self, input_dict): + super().__init__(input_dict) + for key, value in input_dict.items(): + if isinstance(value, dict): + value = DictToAttrRecursive(value) + self[key] = value + setattr(self, key, value) + + def __getattr__(self, item): + try: + return self[item] + except KeyError: + raise AttributeError(f"Attribute {item} not found") + + def __setattr__(self, key, value): + if isinstance(value, dict): + value = DictToAttrRecursive(value) + super(DictToAttrRecursive, self).__setitem__(key, value) + super().__setattr__(key, value) + + def __delattr__(self, item): + try: + del self[item] + except KeyError: + raise AttributeError(f"Attribute {item} not found") + + +@torch.jit.script +class T2SMLP: + def __init__(self, w1, b1, w2, b2): + self.w1 = w1 + self.b1 = b1 + self.w2 = w2 + self.b2 = b2 + + def forward(self, x): + x = F.relu(F.linear(x, self.w1, self.b1)) + x = F.linear(x, self.w2, self.b2) + return x + + +@torch.jit.script +class T2SBlock: + def __init__( + self, + num_heads: int, + hidden_dim: int, + mlp: T2SMLP, + qkv_w, + qkv_b, + out_w, + out_b, + norm_w1, + norm_b1, + norm_eps1: float, + norm_w2, + norm_b2, + norm_eps2: float, + ): + self.num_heads = num_heads + self.mlp = mlp + self.hidden_dim: int = hidden_dim + self.qkv_w = qkv_w + self.qkv_b = qkv_b + self.out_w = out_w + self.out_b = out_b + self.norm_w1 = norm_w1 + self.norm_b1 = norm_b1 + self.norm_eps1 = norm_eps1 + self.norm_w2 = norm_w2 + self.norm_b2 = norm_b2 + self.norm_eps2 = norm_eps2 + + self.false = torch.tensor(False, dtype=torch.bool) + + @torch.jit.ignore + def to_mask(self, x: torch.Tensor, padding_mask: Optional[torch.Tensor]): + if padding_mask is None: + return x + + if padding_mask.dtype == torch.bool: + return x.masked_fill(padding_mask, 0) + else: + return x * padding_mask + + def process_prompt(self, x: torch.Tensor, attn_mask: torch.Tensor, padding_mask: Optional[torch.Tensor] = None): + q, k, v = 
F.linear(self.to_mask(x, padding_mask), self.qkv_w, self.qkv_b).chunk(3, dim=-1) + + batch_size = q.shape[0] + q_len = q.shape[1] + kv_len = k.shape[1] + + q = self.to_mask(q, padding_mask) + k_cache = self.to_mask(k, padding_mask) + v_cache = self.to_mask(v, padding_mask) + + q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2) + k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2) + v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2) + + attn = F.scaled_dot_product_attention(q, k, v, ~attn_mask) + + attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim) + attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0) + attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b) + + if padding_mask is not None: + for i in range(batch_size): + # mask = padding_mask[i,:,0] + if self.false.device != padding_mask.device: + self.false = self.false.to(padding_mask.device) + idx = torch.where(padding_mask[i, :, 0] == self.false)[0] + x_item = x[i, idx, :].unsqueeze(0) + attn_item = attn[i, idx, :].unsqueeze(0) + x_item = x_item + attn_item + x_item = F.layer_norm(x_item, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1) + x_item = x_item + self.mlp.forward(x_item) + x_item = F.layer_norm( + x_item, + [self.hidden_dim], + self.norm_w2, + self.norm_b2, + self.norm_eps2, + ) + x[i, idx, :] = x_item.squeeze(0) + x = self.to_mask(x, padding_mask) + else: + x = x + attn + x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1) + x = x + self.mlp.forward(x) + x = F.layer_norm( + x, + [self.hidden_dim], + self.norm_w2, + self.norm_b2, + self.norm_eps2, + ) + return x, k_cache, v_cache + + def decode_next_token(self, x: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor): + q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1) + + k_cache = torch.cat([k_cache, k], dim=1) + v_cache = torch.cat([v_cache, v], dim=1) + + batch_size = q.shape[0] + q_len = q.shape[1] + kv_len = k_cache.shape[1] + + q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2) + k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2) + v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2) + + attn = F.scaled_dot_product_attention(q, k, v) + + # attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim) + # attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0) + attn = attn.transpose(1, 2).reshape(batch_size, q_len, -1) + attn = F.linear(attn, self.out_w, self.out_b) + + x = x + attn + x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1) + x = x + self.mlp.forward(x) + x = F.layer_norm( + x, + [self.hidden_dim], + self.norm_w2, + self.norm_b2, + self.norm_eps2, + ) + return x, k_cache, v_cache + + +@torch.jit.script +class T2STransformer: + def __init__(self, num_blocks: int, blocks: list[T2SBlock]): + self.num_blocks: int = num_blocks + self.blocks = blocks + + def process_prompt(self, x: torch.Tensor, attn_mask: torch.Tensor, padding_mask: Optional[torch.Tensor] = None): + k_cache: list[torch.Tensor] = [] + v_cache: list[torch.Tensor] = [] + for i in range(self.num_blocks): + x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask, padding_mask) + k_cache.append(k_cache_) + v_cache.append(v_cache_) + return x, k_cache, v_cache + + def decode_next_token(self, x: torch.Tensor, k_cache: list[torch.Tensor], v_cache: list[torch.Tensor]): + for i in 
range(self.num_blocks): + x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i]) + return x, k_cache, v_cache + + +class VitsModel(nn.Module): + def __init__(self, vits_path, version=None, is_half=True, device="cpu"): + super().__init__() + # dict_s2 = torch.load(vits_path,map_location="cpu") + dict_s2 = load_sovits_new(vits_path) + self.hps = dict_s2["config"] + + if version is None: + if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322: + self.hps["model"]["version"] = "v1" + else: + self.hps["model"]["version"] = "v2" + else: + if version in ["v1", "v2", "v3", "v4", "v2Pro", "v2ProPlus"]: + self.hps["model"]["version"] = version + else: + raise ValueError(f"Unsupported version: {version}") + + self.hps = DictToAttrRecursive(self.hps) + self.hps.model.semantic_frame_rate = "25hz" + self.vq_model = SynthesizerTrn( + self.hps.data.filter_length // 2 + 1, + self.hps.train.segment_size // self.hps.data.hop_length, + n_speakers=self.hps.data.n_speakers, + **self.hps.model, + ) + self.vq_model.load_state_dict(dict_s2["weight"], strict=False) + self.vq_model.dec.remove_weight_norm() + if is_half: + self.vq_model = self.vq_model.half() + self.vq_model = self.vq_model.to(device) + self.vq_model.eval() + self.hann_window = torch.hann_window( + self.hps.data.win_length, device=device, dtype=torch.float16 if is_half else torch.float32 + ) + + def forward(self, text_seq, pred_semantic, ref_audio, speed=1.0, sv_emb=None): + refer = spectrogram_torch( + self.hann_window, + ref_audio, + self.hps.data.filter_length, + self.hps.data.sampling_rate, + self.hps.data.hop_length, + self.hps.data.win_length, + center=False, + ) + return self.vq_model(pred_semantic, text_seq, refer, speed=speed, sv_emb=sv_emb)[0, 0] + + +class T2SModel(nn.Module): + def __init__(self, raw_t2s: Text2SemanticLightningModule): + super(T2SModel, self).__init__() + self.model_dim = raw_t2s.model.model_dim + self.embedding_dim = raw_t2s.model.embedding_dim + self.num_head = raw_t2s.model.num_head + self.num_layers = raw_t2s.model.num_layers + self.vocab_size = raw_t2s.model.vocab_size + self.phoneme_vocab_size = raw_t2s.model.phoneme_vocab_size + # self.p_dropout = float(raw_t2s.model.p_dropout) + self.EOS: int = int(raw_t2s.model.EOS) + self.norm_first = raw_t2s.model.norm_first + assert self.EOS == self.vocab_size - 1 + self.hz = 50 + + self.bert_proj = raw_t2s.model.bert_proj + self.ar_text_embedding = raw_t2s.model.ar_text_embedding + self.ar_text_position = raw_t2s.model.ar_text_position + self.ar_audio_embedding = raw_t2s.model.ar_audio_embedding + self.ar_audio_position = raw_t2s.model.ar_audio_position + + # self.t2s_transformer = T2STransformer(self.num_layers, blocks) + # self.t2s_transformer = raw_t2s.model.t2s_transformer + + blocks = [] + h = raw_t2s.model.h + + for i in range(self.num_layers): + layer = h.layers[i] + t2smlp = T2SMLP(layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias) + + block = T2SBlock( + self.num_head, + self.model_dim, + t2smlp, + layer.self_attn.in_proj_weight, + layer.self_attn.in_proj_bias, + layer.self_attn.out_proj.weight, + layer.self_attn.out_proj.bias, + layer.norm1.weight, + layer.norm1.bias, + layer.norm1.eps, + layer.norm2.weight, + layer.norm2.bias, + layer.norm2.eps, + ) + + blocks.append(block) + + self.t2s_transformer = T2STransformer(self.num_layers, blocks) + + # self.ar_predict_layer = nn.Linear(self.model_dim, self.vocab_size, bias=False) + self.ar_predict_layer = raw_t2s.model.ar_predict_layer + # 
self.loss_fct = nn.CrossEntropyLoss(reduction="sum") + self.max_sec = raw_t2s.config["data"]["max_sec"] + self.top_k = int(raw_t2s.config["inference"]["top_k"]) + self.early_stop_num = torch.LongTensor([self.hz * self.max_sec]) + + def forward( + self, + prompts: LongTensor, + ref_seq: LongTensor, + text_seq: LongTensor, + ref_bert: torch.Tensor, + text_bert: torch.Tensor, + top_k: LongTensor, + ): + bert = torch.cat([ref_bert.T, text_bert.T], 1) + all_phoneme_ids = torch.cat([ref_seq, text_seq], 1) + bert = bert.unsqueeze(0) + + x = self.ar_text_embedding(all_phoneme_ids) + x = x + self.bert_proj(bert.transpose(1, 2)) + x: torch.Tensor = self.ar_text_position(x) + + early_stop_num = self.early_stop_num + + # [1,N,512] [1,N] + # y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts) + y = prompts + # x_example = x[:,:,0] * 0.0 + + x_len = x.shape[1] + x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool) + + y_emb = self.ar_audio_embedding(y) + y_len = y_emb.shape[1] + prefix_len = y.shape[1] + y_pos = self.ar_audio_position(y_emb) + xy_pos = torch.concat([x, y_pos], dim=1) + + bsz = x.shape[0] + src_len = x_len + y_len + x_attn_mask_pad = F.pad( + x_attn_mask, + (0, y_len), ###xx的纯0扩展到xx纯0+xy纯1,(x,x+y) + value=True, + ) + y_attn_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y) + torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1), + (x_len, 0), + value=False, + ) + xy_attn_mask = ( + torch.concat([x_attn_mask_pad, y_attn_mask], dim=0) + .unsqueeze(0) + .expand(bsz * self.num_head, -1, -1) + .view(bsz, self.num_head, src_len, src_len) + .to(device=x.device, dtype=torch.bool) + ) + + idx = 0 + top_k = int(top_k) + + xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None) + + logits = self.ar_predict_layer(xy_dec[:, -1]) + logits = logits[:, :-1] + samples = sample(logits, y, top_k=top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0] + y = torch.concat([y, samples], dim=1) + y_emb = self.ar_audio_embedding(y[:, -1:]) + xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[ + :, y_len + idx + ].to(dtype=y_emb.dtype, device=y_emb.device) + + stop = False + # for idx in range(1, 50): + for idx in range(1, 1500): + # [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N] + # y, k, v, y_emb, logits, samples = self.stage_decoder(y, k, v, y_emb, x_example) + xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache) + logits = self.ar_predict_layer(xy_dec[:, -1]) + + if idx < 11: ###至少预测出10个token不然不给停止(0.4s) + logits = logits[:, :-1] + + samples = sample(logits, y, top_k=top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0] + + y = torch.concat([y, samples], dim=1) + + if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num: + stop = True + if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS: + stop = True + if stop: + if y.shape[1] == 0: + y = torch.concat([y, torch.zeros_like(samples)], dim=1) + break + + y_emb = self.ar_audio_embedding(y[:, -1:]) + xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[ + :, y_len + idx + ].to(dtype=y_emb.dtype, device=y_emb.device) + + y[0, -1] = 0 + + return y[:, -idx:].unsqueeze(0) + + +bert_path = os.environ.get("bert_path", "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large") +cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base" 
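+# Note: `bert_path` above may be overridden via the environment variable of the same name,
+# while `cnhubert_base_path` is the fixed repo-relative default assigned to the cnhubert
+# module right below and later used by SSLModel's feature extractor.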
+cnhubert.cnhubert_base_path = cnhubert_base_path + + +@torch.jit.script +def build_phone_level_feature(res: Tensor, word2ph: IntTensor): + phone_level_feature = [] + for i in range(word2ph.shape[0]): + repeat_feature = res[i].repeat(word2ph[i].item(), 1) + phone_level_feature.append(repeat_feature) + phone_level_feature = torch.cat(phone_level_feature, dim=0) + # [sum(word2ph), 1024] + return phone_level_feature + + +class MyBertModel(torch.nn.Module): + def __init__(self, bert_model): + super(MyBertModel, self).__init__() + self.bert = bert_model + + def forward( + self, input_ids: torch.Tensor, attention_mask: torch.Tensor, token_type_ids: torch.Tensor, word2ph: IntTensor + ): + outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids) + # res = torch.cat(outputs["hidden_states"][-3:-2], -1)[0][1:-1] + res = torch.cat(outputs[1][-3:-2], -1)[0][1:-1] + return build_phone_level_feature(res, word2ph) + + +class SSLModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.ssl = cnhubert.get_model().model + + def forward(self, ref_audio_16k) -> torch.Tensor: + ssl_content = self.ssl(ref_audio_16k)["last_hidden_state"].transpose(1, 2) + return ssl_content + + +class ExportSSLModel(torch.nn.Module): + def __init__(self, ssl: SSLModel): + super().__init__() + self.ssl = ssl + + def forward(self, ref_audio: torch.Tensor): + return self.ssl(ref_audio) + + @torch.jit.export + def resample(self, ref_audio: torch.Tensor, src_sr: int, dst_sr: int) -> torch.Tensor: + audio = resamplex(ref_audio, src_sr, dst_sr).float() + return audio + + +def export_bert(output_path): + tokenizer = AutoTokenizer.from_pretrained(bert_path) + + text = "叹息声一声接着一声传出,木兰对着房门织布.听不见织布机织布的声音,只听见木兰在叹息.问木兰在想什么?问木兰在惦记什么?木兰答道,我也没有在想什么,也没有在惦记什么." 
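+    # The sentence above is only a tracing sample: `word2ph` below assumes each
+    # non-punctuation character expands to 2 phones and punctuation to 1, and the
+    # torch._dynamo.mark_dynamic calls further down mark the sequence dimensions as
+    # dynamic so the traced bert_model.pt is not tied to this particular text length.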
+ ref_bert_inputs = tokenizer(text, return_tensors="pt") + word2ph = [] + for c in text: + if c in [",", "。", ":", "?", ",", ".", "?"]: + word2ph.append(1) + else: + word2ph.append(2) + ref_bert_inputs["word2ph"] = torch.Tensor(word2ph).int() + + bert_model = AutoModelForMaskedLM.from_pretrained(bert_path, output_hidden_states=True, torchscript=True) + my_bert_model = MyBertModel(bert_model) + + ref_bert_inputs = { + "input_ids": ref_bert_inputs["input_ids"], + "attention_mask": ref_bert_inputs["attention_mask"], + "token_type_ids": ref_bert_inputs["token_type_ids"], + "word2ph": ref_bert_inputs["word2ph"], + } + + torch._dynamo.mark_dynamic(ref_bert_inputs["input_ids"], 1) + torch._dynamo.mark_dynamic(ref_bert_inputs["attention_mask"], 1) + torch._dynamo.mark_dynamic(ref_bert_inputs["token_type_ids"], 1) + torch._dynamo.mark_dynamic(ref_bert_inputs["word2ph"], 0) + + my_bert_model = torch.jit.trace(my_bert_model, example_kwarg_inputs=ref_bert_inputs) + output_path = os.path.join(output_path, "bert_model.pt") + my_bert_model.save(output_path) + print("#### exported bert ####") + + +def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_bert_and_ssl=False, device="cpu"): + if not os.path.exists(output_path): + os.makedirs(output_path) + print(f"目录已创建: {output_path}") + else: + print(f"目录已存在: {output_path}") + + ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float() + ssl = SSLModel() + if export_bert_and_ssl: + s = ExportSSLModel(torch.jit.trace(ssl, example_inputs=(ref_audio))) + ssl_path = os.path.join(output_path, "ssl_model.pt") + torch.jit.script(s).save(ssl_path) + print("#### exported ssl ####") + export_bert(output_path) + else: + s = ExportSSLModel(ssl) + + print(f"device: {device}") + + ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(ref_text, "all_zh", "v2") + ref_seq = torch.LongTensor([ref_seq_id]).to(device) + ref_bert = ref_bert_T.T.to(ref_seq.device) + text_seq_id, text_bert_T, norm_text = get_phones_and_bert( + "这是一个简单的示例,真没想到这么简单就完成了。The King and His Stories.Once there was a king. He likes to write stories, but his stories were not good. 
As people were afraid of him, they all said his stories were good.After reading them, the writer at once turned to the soldiers and said: Take me back to prison, please.", + "auto", + "v2", + ) + text_seq = torch.LongTensor([text_seq_id]).to(device) + text_bert = text_bert_T.T.to(text_seq.device) + + ssl_content = ssl(ref_audio).to(device) + + # vits_path = "SoVITS_weights_v2/xw_e8_s216.pth" + vits = VitsModel(vits_path, device=device, is_half=False) + vits.eval() + + # gpt_path = "GPT_weights_v2/xw-e15.ckpt" + # dict_s1 = torch.load(gpt_path, map_location=device) + dict_s1 = torch.load(gpt_path, weights_only=False) + raw_t2s = get_raw_t2s_model(dict_s1).to(device) + print("#### get_raw_t2s_model ####") + print(raw_t2s.config) + t2s_m = T2SModel(raw_t2s) + t2s_m.eval() + t2s = torch.jit.script(t2s_m).to(device) + print("#### script t2s_m ####") + + print("vits.hps.data.sampling_rate:", vits.hps.data.sampling_rate) + gpt_sovits = GPT_SoVITS(t2s, vits).to(device) + gpt_sovits.eval() + + ref_audio_sr = s.resample(ref_audio, 16000, 32000).to(device) + + torch._dynamo.mark_dynamic(ssl_content, 2) + torch._dynamo.mark_dynamic(ref_audio_sr, 1) + torch._dynamo.mark_dynamic(ref_seq, 1) + torch._dynamo.mark_dynamic(text_seq, 1) + torch._dynamo.mark_dynamic(ref_bert, 0) + torch._dynamo.mark_dynamic(text_bert, 0) + + top_k = torch.LongTensor([5]).to(device) + + with torch.no_grad(): + gpt_sovits_export = torch.jit.trace( + gpt_sovits, example_inputs=(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert, top_k) + ) + + gpt_sovits_path = os.path.join(output_path, "gpt_sovits_model.pt") + gpt_sovits_export.save(gpt_sovits_path) + print("#### exported gpt_sovits ####") + + +def export_prov2( + gpt_path, + vits_path, + version, + ref_audio_path, + ref_text, + output_path, + export_bert_and_ssl=False, + device="cpu", + is_half=True, +): + if sv_cn_model == None: + init_sv_cn(device, is_half) + + if not os.path.exists(output_path): + os.makedirs(output_path) + print(f"目录已创建: {output_path}") + else: + print(f"目录已存在: {output_path}") + + ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float() + ssl = SSLModel() + if export_bert_and_ssl: + s = ExportSSLModel(torch.jit.trace(ssl, example_inputs=(ref_audio))) + ssl_path = os.path.join(output_path, "ssl_model.pt") + torch.jit.script(s).save(ssl_path) + print("#### exported ssl ####") + export_bert(output_path) + else: + s = ExportSSLModel(ssl) + + print(f"device: {device}") + + ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(ref_text, "all_zh", "v2") + ref_seq = torch.LongTensor([ref_seq_id]).to(device) + ref_bert = ref_bert_T.T + if is_half: + ref_bert = ref_bert.half() + ref_bert = ref_bert.to(ref_seq.device) + + text_seq_id, text_bert_T, norm_text = get_phones_and_bert( + "这是一个简单的示例,真没想到这么简单就完成了。The King and His Stories.Once there was a king. He likes to write stories, but his stories were not good. 
As people were afraid of him, they all said his stories were good.After reading them, the writer at once turned to the soldiers and said: Take me back to prison, please.", + "auto", + "v2", + ) + text_seq = torch.LongTensor([text_seq_id]).to(device) + text_bert = text_bert_T.T + if is_half: + text_bert = text_bert.half() + text_bert = text_bert.to(text_seq.device) + + ssl_content = ssl(ref_audio) + if is_half: + ssl_content = ssl_content.half() + ssl_content = ssl_content.to(device) + + sv_model = ExportERes2NetV2(sv_cn_model) + + # vits_path = "SoVITS_weights_v2/xw_e8_s216.pth" + vits = VitsModel(vits_path, version, is_half=is_half, device=device) + vits.eval() + + # gpt_path = "GPT_weights_v2/xw-e15.ckpt" + # dict_s1 = torch.load(gpt_path, map_location=device) + dict_s1 = torch.load(gpt_path, weights_only=False) + raw_t2s = get_raw_t2s_model(dict_s1).to(device) + print("#### get_raw_t2s_model ####") + print(raw_t2s.config) + if is_half: + raw_t2s = raw_t2s.half() + t2s_m = T2SModel(raw_t2s) + t2s_m.eval() + t2s = torch.jit.script(t2s_m).to(device) + print("#### script t2s_m ####") + + print("vits.hps.data.sampling_rate:", vits.hps.data.sampling_rate) + gpt_sovits = GPT_SoVITS_V2Pro(t2s, vits, sv_model).to(device) + gpt_sovits.eval() + + ref_audio_sr = s.resample(ref_audio, 16000, 32000) + if is_half: + ref_audio_sr = ref_audio_sr.half() + ref_audio_sr = ref_audio_sr.to(device) + + torch._dynamo.mark_dynamic(ssl_content, 2) + torch._dynamo.mark_dynamic(ref_audio_sr, 1) + torch._dynamo.mark_dynamic(ref_seq, 1) + torch._dynamo.mark_dynamic(text_seq, 1) + torch._dynamo.mark_dynamic(ref_bert, 0) + torch._dynamo.mark_dynamic(text_bert, 0) + # torch._dynamo.mark_dynamic(sv_emb, 0) + + top_k = torch.LongTensor([5]).to(device) + # 先跑一遍 sv_model 让它加载 cache,详情见 L880 + gpt_sovits.sv_model(ref_audio_sr) + + with torch.no_grad(): + gpt_sovits_export = torch.jit.trace( + gpt_sovits, + example_inputs=( + ssl_content, + ref_audio_sr, + ref_seq, + text_seq, + ref_bert, + text_bert, + top_k, + ), + ) + + gpt_sovits_path = os.path.join(output_path, "gpt_sovits_model.pt") + gpt_sovits_export.save(gpt_sovits_path) + print("#### exported gpt_sovits ####") + audio = gpt_sovits_export(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert, top_k) + print("start write wav") + soundfile.write("out.wav", audio.float().detach().cpu().numpy(), 32000) + + +@torch.jit.script +def parse_audio(ref_audio): + ref_audio_16k = torchaudio.functional.resample(ref_audio, 48000, 16000).float() # .to(ref_audio.device) + ref_audio_sr = torchaudio.functional.resample(ref_audio, 48000, 32000).float() # .to(ref_audio.device) + return ref_audio_16k, ref_audio_sr + + +@torch.jit.script +def resamplex(ref_audio: torch.Tensor, src_sr: int, dst_sr: int) -> torch.Tensor: + return torchaudio.functional.resample(ref_audio, src_sr, dst_sr).float() + + +class GPT_SoVITS(nn.Module): + def __init__(self, t2s: T2SModel, vits: VitsModel): + super().__init__() + self.t2s = t2s + self.vits = vits + + def forward( + self, + ssl_content: torch.Tensor, + ref_audio_sr: torch.Tensor, + ref_seq: Tensor, + text_seq: Tensor, + ref_bert: Tensor, + text_bert: Tensor, + top_k: LongTensor, + speed=1.0, + ): + codes = self.vits.vq_model.extract_latent(ssl_content) + prompt_semantic = codes[0, 0] + prompts = prompt_semantic.unsqueeze(0) + + pred_semantic = self.t2s(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k) + audio = self.vits(text_seq, pred_semantic, ref_audio_sr, speed) + return audio + + +class ExportERes2NetV2(nn.Module): + def 
__init__(self, sv_cn_model: SV): + super(ExportERes2NetV2, self).__init__() + self.bn1 = sv_cn_model.embedding_model.bn1 + self.conv1 = sv_cn_model.embedding_model.conv1 + self.layer1 = sv_cn_model.embedding_model.layer1 + self.layer2 = sv_cn_model.embedding_model.layer2 + self.layer3 = sv_cn_model.embedding_model.layer3 + self.layer4 = sv_cn_model.embedding_model.layer4 + self.layer3_ds = sv_cn_model.embedding_model.layer3_ds + self.fuse34 = sv_cn_model.embedding_model.fuse34 + + # audio_16k.shape: [1,N] + def forward(self, audio_16k): + # 这个 fbank 函数有一个 cache, 不过不要紧,它跟 audio_16k 的长度无关 + # 只跟 device 和 dtype 有关 + x = Kaldi.fbank(audio_16k, num_mel_bins=80, sample_frequency=16000, dither=0) + x = torch.stack([x]) + + x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T) + x = x.unsqueeze_(1) + out = F.relu(self.bn1(self.conv1(x))) + out1 = self.layer1(out) + out2 = self.layer2(out1) + out3 = self.layer3(out2) + out4 = self.layer4(out3) + out3_ds = self.layer3_ds(out3) + fuse_out34 = self.fuse34(out4, out3_ds) + return fuse_out34.flatten(start_dim=1, end_dim=2).mean(-1) + + +class GPT_SoVITS_V2Pro(nn.Module): + def __init__(self, t2s: T2SModel, vits: VitsModel, sv_model: ExportERes2NetV2): + super().__init__() + self.t2s = t2s + self.vits = vits + self.sv_model = sv_model + + def forward( + self, + ssl_content: torch.Tensor, + ref_audio_sr: torch.Tensor, + ref_seq: Tensor, + text_seq: Tensor, + ref_bert: Tensor, + text_bert: Tensor, + top_k: LongTensor, + speed=1.0, + ): + codes = self.vits.vq_model.extract_latent(ssl_content) + prompt_semantic = codes[0, 0] + prompts = prompt_semantic.unsqueeze(0) + + audio_16k = resamplex(ref_audio_sr, 32000, 16000).to(ref_audio_sr.dtype) + sv_emb = self.sv_model(audio_16k) + + pred_semantic = self.t2s(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k) + audio = self.vits(text_seq, pred_semantic, ref_audio_sr, speed, sv_emb) + return audio + + +def test(): + parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool") + parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file") + parser.add_argument("--sovits_model", required=True, help="Path to the SoVITS model file") + parser.add_argument("--ref_audio", required=True, help="Path to the reference audio file") + parser.add_argument("--ref_text", required=True, help="Path to the reference text file") + parser.add_argument("--output_path", required=True, help="Path to the output directory") + + args = parser.parse_args() + gpt_path = args.gpt_model + vits_path = args.sovits_model + ref_audio_path = args.ref_audio + ref_text = args.ref_text + + tokenizer = AutoTokenizer.from_pretrained(bert_path) + # bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True,torchscript=True) + # bert = MyBertModel(bert_model) + my_bert = torch.jit.load("onnx/bert_model.pt", map_location="cuda") + + # dict_s1 = torch.load(gpt_path, map_location="cuda") + # raw_t2s = get_raw_t2s_model(dict_s1) + # t2s = T2SModel(raw_t2s) + # t2s.eval() + # t2s = torch.jit.load("onnx/xw/t2s_model.pt",map_location='cuda') + + # vits_path = "SoVITS_weights_v2/xw_e8_s216.pth" + # vits = VitsModel(vits_path) + # vits.eval() + + # ssl = ExportSSLModel(SSLModel()).to('cuda') + # ssl.eval() + ssl = torch.jit.load("onnx/by/ssl_model.pt", map_location="cuda") + + # gpt_sovits = GPT_SoVITS(t2s,vits) + gpt_sovits = torch.jit.load("onnx/by/gpt_sovits_model.pt", map_location="cuda") + + ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(ref_text, "all_zh", "v2") + ref_seq = 
torch.LongTensor([ref_seq_id]) + ref_bert = ref_bert_T.T.to(ref_seq.device) + # text_seq_id,text_bert_T,norm_text = get_phones_and_bert("昨天晚上看见征兵文书,知道君主在大规模征兵,那么多卷征兵文册,每一卷上都有父亲的名字.","all_zh",'v2') + text = "昨天晚上看见征兵文书,知道君主在大规模征兵,那么多卷征兵文册,每一卷上都有父亲的名字." + + text_seq_id, text_bert_T, norm_text = get_phones_and_bert(text, "all_zh", "v2") + + test_bert = tokenizer(text, return_tensors="pt") + word2ph = [] + for c in text: + if c in [",", "。", ":", "?", "?", ",", "."]: + word2ph.append(1) + else: + word2ph.append(2) + test_bert["word2ph"] = torch.Tensor(word2ph).int() + + test_bert = my_bert( + test_bert["input_ids"].to("cuda"), + test_bert["attention_mask"].to("cuda"), + test_bert["token_type_ids"].to("cuda"), + test_bert["word2ph"].to("cuda"), + ) + + text_seq = torch.LongTensor([text_seq_id]) + text_bert = text_bert_T.T.to(text_seq.device) + + print("text_bert:", text_bert.shape, text_bert) + print("test_bert:", test_bert.shape, test_bert) + print(torch.allclose(text_bert.to("cuda"), test_bert)) + + print("text_seq:", text_seq.shape) + print("text_bert:", text_bert.shape, text_bert.type()) + + # [1,N] + ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float().to("cuda") + print("ref_audio:", ref_audio.shape) + + ref_audio_sr = ssl.resample(ref_audio, 16000, 32000) + print("start ssl") + ssl_content = ssl(ref_audio) + + print("start gpt_sovits:") + print("ssl_content:", ssl_content.shape) + print("ref_audio_sr:", ref_audio_sr.shape) + print("ref_seq:", ref_seq.shape) + ref_seq = ref_seq.to("cuda") + print("text_seq:", text_seq.shape) + text_seq = text_seq.to("cuda") + print("ref_bert:", ref_bert.shape) + ref_bert = ref_bert.to("cuda") + print("text_bert:", text_bert.shape) + text_bert = text_bert.to("cuda") + + top_k = torch.LongTensor([5]).to("cuda") + + with torch.no_grad(): + audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, test_bert, top_k) + print("start write wav") + soundfile.write("out.wav", audio.detach().cpu().numpy(), 32000) + + +import text +import json + + +def export_symbel(version="v2"): + if version == "v1": + symbols = text._symbol_to_id_v1 + with open("onnx/symbols_v1.json", "w") as file: + json.dump(symbols, file, indent=4) + else: + symbols = text._symbol_to_id_v2 + with open("onnx/symbols_v2.json", "w") as file: + json.dump(symbols, file, indent=4) + + +def main(): + parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool") + parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file") + parser.add_argument("--sovits_model", required=True, help="Path to the SoVITS model file") + parser.add_argument("--ref_audio", required=True, help="Path to the reference audio file") + parser.add_argument("--ref_text", required=True, help="Path to the reference text file") + parser.add_argument("--output_path", required=True, help="Path to the output directory") + parser.add_argument("--export_common_model", action="store_true", help="Export Bert and SSL model") + parser.add_argument("--device", help="Device to use") + parser.add_argument("--version", help="version of the model", default="v2") + parser.add_argument("--no-half", action="store_true", help="Do not use half precision for model weights") + + args = parser.parse_args() + if args.version in ["v2Pro", "v2ProPlus"]: + is_half = not args.no_half + print(f"Using half precision: {is_half}") + export_prov2( + gpt_path=args.gpt_model, + vits_path=args.sovits_model, + version=args.version, + ref_audio_path=args.ref_audio, + ref_text=args.ref_text, + 
output_path=args.output_path, + export_bert_and_ssl=args.export_common_model, + device=args.device, + is_half=is_half, + ) + else: + export( + gpt_path=args.gpt_model, + vits_path=args.sovits_model, + ref_audio_path=args.ref_audio, + ref_text=args.ref_text, + output_path=args.output_path, + device=args.device, + export_bert_and_ssl=args.export_common_model, + ) + + +if __name__ == "__main__": + with torch.no_grad(): + main() + # test() diff --git a/export_torch_script_v3v4.py b/export_torch_script_v3v4.py new file mode 100644 index 0000000000000000000000000000000000000000..b0e4dba507b429455aac903b368a618a22b705c2 --- /dev/null +++ b/export_torch_script_v3v4.py @@ -0,0 +1,1258 @@ +import os +from export_torch_script import ( + T2SModel, + get_raw_t2s_model, + resamplex, + spectrogram_torch, +) +from f5_tts.model.backbones.dit import DiT +from inference_webui import get_phones_and_bert +import librosa +from module import commons +from module.mel_processing import mel_spectrogram_torch +from module.models_onnx import CFM, Generator, SynthesizerTrnV3 +import numpy as np +import torch._dynamo.config +import torchaudio +import logging +import uvicorn +import torch +import soundfile +from librosa.filters import mel as librosa_mel_fn + + +from inference_webui import get_spepc, norm_spec, resample, ssl_model + +logging.config.dictConfig(uvicorn.config.LOGGING_CONFIG) +logger = logging.getLogger("uvicorn") + +is_half = True +device = "cuda" if torch.cuda.is_available() else "cpu" +now_dir = os.getcwd() + + +class MelSpectrgram(torch.nn.Module): + def __init__( + self, + dtype, + device, + n_fft, + num_mels, + sampling_rate, + hop_size, + win_size, + fmin, + fmax, + center=False, + ): + super().__init__() + self.hann_window = torch.hann_window(win_size).to(device=device, dtype=dtype) + mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) + self.mel_basis = torch.from_numpy(mel).to(dtype=dtype, device=device) + self.n_fft: int = n_fft + self.hop_size: int = hop_size + self.win_size: int = win_size + self.center: bool = center + + def forward(self, y): + y = torch.nn.functional.pad( + y.unsqueeze(1), + ( + int((self.n_fft - self.hop_size) / 2), + int((self.n_fft - self.hop_size) / 2), + ), + mode="reflect", + ) + y = y.squeeze(1) + spec = torch.stft( + y, + self.n_fft, + hop_length=self.hop_size, + win_length=self.win_size, + window=self.hann_window, + center=self.center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=False, + ) + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-9) + spec = torch.matmul(self.mel_basis, spec) + # spec = spectral_normalize_torch(spec) + spec = torch.log(torch.clamp(spec, min=1e-5)) + return spec + + +class ExportDitBlocks(torch.nn.Module): + def __init__(self, dit: DiT): + super().__init__() + self.transformer_blocks = dit.transformer_blocks + self.norm_out = dit.norm_out + self.proj_out = dit.proj_out + self.depth = dit.depth + + def forward(self, x, t, mask, rope): + for block in self.transformer_blocks: + x = block(x, t, mask=mask, rope=(rope, 1.0)) + x = self.norm_out(x, t) + output = self.proj_out(x) + return output + + +class ExportDitEmbed(torch.nn.Module): + def __init__(self, dit: DiT): + super().__init__() + self.time_embed = dit.time_embed + self.d_embed = dit.d_embed + self.text_embed = dit.text_embed + self.input_embed = dit.input_embed + self.rotary_embed = dit.rotary_embed + self.rotary_embed.inv_freq.to(device) + + def forward( + self, + x0: torch.Tensor, # nosied input audio # noqa: F722 + 
cond0: torch.Tensor, # masked cond audio # noqa: F722 + x_lens: torch.Tensor, + time: torch.Tensor, # time step # noqa: F821 F722 + dt_base_bootstrap: torch.Tensor, + text0: torch.Tensor, # noqa: F722#####condition feature + ): + x = x0.transpose(2, 1) + cond = cond0.transpose(2, 1) + text = text0.transpose(2, 1) + mask = commons.sequence_mask(x_lens, max_length=x.size(1)).to(x.device) + + t = self.time_embed(time) + self.d_embed(dt_base_bootstrap) + text_embed = self.text_embed(text, x.shape[1]) + rope_t = torch.arange(x.shape[1], device=device) + rope, _ = self.rotary_embed(rope_t) + x = self.input_embed(x, cond, text_embed) + return x, t, mask, rope + + +class ExportDiT(torch.nn.Module): + def __init__(self, dit: DiT): + super().__init__() + if dit != None: + self.embed = ExportDitEmbed(dit) + self.blocks = ExportDitBlocks(dit) + else: + self.embed = None + self.blocks = None + + def forward( # x, prompt_x, x_lens, t, style,cond + self, # d is channel,n is T + x0: torch.Tensor, # nosied input audio # noqa: F722 + cond0: torch.Tensor, # masked cond audio # noqa: F722 + x_lens: torch.Tensor, + time: torch.Tensor, # time step # noqa: F821 F722 + dt_base_bootstrap: torch.Tensor, + text0: torch.Tensor, # noqa: F722#####condition feature + ): + x, t, mask, rope = self.embed(x0, cond0, x_lens, time, dt_base_bootstrap, text0) + output = self.blocks(x, t, mask, rope) + return output + + +class ExportCFM(torch.nn.Module): + def __init__(self, cfm: CFM): + super().__init__() + self.cfm = cfm + + def forward( + self, + fea_ref: torch.Tensor, + fea_todo_chunk: torch.Tensor, + mel2: torch.Tensor, + sample_steps: torch.LongTensor, + ): + T_min = fea_ref.size(2) + fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1) + cfm_res = self.cfm(fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps) + cfm_res = cfm_res[:, :, mel2.shape[2] :] + mel2 = cfm_res[:, :, -T_min:] + fea_ref = fea_todo_chunk[:, :, -T_min:] + return cfm_res, fea_ref, mel2 + + +mel_fn = lambda x: mel_spectrogram_torch( + x, + **{ + "n_fft": 1024, + "win_size": 1024, + "hop_size": 256, + "num_mels": 100, + "sampling_rate": 24000, + "fmin": 0, + "fmax": None, + "center": False, + }, +) +mel_fn_v4 = lambda x: mel_spectrogram_torch( + x, + **{ + "n_fft": 1280, + "win_size": 1280, + "hop_size": 320, + "num_mels": 100, + "sampling_rate": 32000, + "fmin": 0, + "fmax": None, + "center": False, + }, +) + +spec_min = -12 +spec_max = 2 + + +@torch.jit.script +def norm_spec(x): + spec_min = -12 + spec_max = 2 + return (x - spec_min) / (spec_max - spec_min) * 2 - 1 + + +def denorm_spec(x): + spec_min = -12 + spec_max = 2 + return (x + 1) / 2 * (spec_max - spec_min) + spec_min + + +class ExportGPTSovitsHalf(torch.nn.Module): + def __init__(self, hps, t2s_m: T2SModel, vq_model: SynthesizerTrnV3): + super().__init__() + self.hps = hps + self.t2s_m = t2s_m + self.vq_model = vq_model + self.mel2 = MelSpectrgram( + dtype=torch.float32, + device=device, + n_fft=1024, + num_mels=100, + sampling_rate=24000, + hop_size=256, + win_size=1024, + fmin=0, + fmax=None, + center=False, + ) + # self.dtype = dtype + self.filter_length: int = hps.data.filter_length + self.sampling_rate: int = hps.data.sampling_rate + self.hop_length: int = hps.data.hop_length + self.win_length: int = hps.data.win_length + self.hann_window = torch.hann_window(self.win_length, device=device, dtype=torch.float32) + + def forward( + self, + ssl_content, + ref_audio_32k: torch.FloatTensor, + phoneme_ids0, + phoneme_ids1, + bert1, + bert2, + top_k, + ): + refer = 
spectrogram_torch( + self.hann_window, + ref_audio_32k, + self.filter_length, + self.sampling_rate, + self.hop_length, + self.win_length, + center=False, + ).to(ssl_content.dtype) + + codes = self.vq_model.extract_latent(ssl_content) + prompt_semantic = codes[0, 0] + prompt = prompt_semantic.unsqueeze(0) + # print('extract_latent',codes.shape,datetime.now().strftime("%Y-%m-%d %H:%M:%S")) + + pred_semantic = self.t2s_m(prompt, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k) + # print('t2s_m',pred_semantic.shape,datetime.now().strftime("%Y-%m-%d %H:%M:%S")) + + ge = self.vq_model.create_ge(refer) + # print('create_ge',datetime.now().strftime("%Y-%m-%d %H:%M:%S")) + + prompt_ = prompt.unsqueeze(0) + fea_ref = self.vq_model(prompt_, phoneme_ids0, ge) + # print('fea_ref',datetime.now().strftime("%Y-%m-%d %H:%M:%S")) + # print(prompt_.shape, phoneme_ids0.shape, ge.shape) + # print(fea_ref.shape) + + ref_24k = resamplex(ref_audio_32k, 32000, 24000) + mel2 = norm_spec(self.mel2(ref_24k)).to(ssl_content.dtype) + T_min = min(mel2.shape[2], fea_ref.shape[2]) + mel2 = mel2[:, :, :T_min] + fea_ref = fea_ref[:, :, :T_min] + if T_min > 468: + mel2 = mel2[:, :, -468:] + fea_ref = fea_ref[:, :, -468:] + T_min = 468 + + fea_todo = self.vq_model(pred_semantic, phoneme_ids1, ge) + # print('fea_todo',datetime.now().strftime("%Y-%m-%d %H:%M:%S")) + # print(pred_semantic.shape, phoneme_ids1.shape, ge.shape) + # print(fea_todo.shape) + + return fea_ref, fea_todo, mel2 + + +class ExportGPTSovitsV4Half(torch.nn.Module): + def __init__(self, hps, t2s_m: T2SModel, vq_model: SynthesizerTrnV3): + super().__init__() + self.hps = hps + self.t2s_m = t2s_m + self.vq_model = vq_model + self.mel2 = MelSpectrgram( + dtype=torch.float32, + device=device, + n_fft=1280, + num_mels=100, + sampling_rate=32000, + hop_size=320, + win_size=1280, + fmin=0, + fmax=None, + center=False, + ) + # self.dtype = dtype + self.filter_length: int = hps.data.filter_length + self.sampling_rate: int = hps.data.sampling_rate + self.hop_length: int = hps.data.hop_length + self.win_length: int = hps.data.win_length + self.hann_window = torch.hann_window(self.win_length, device=device, dtype=torch.float32) + + def forward( + self, + ssl_content, + ref_audio_32k: torch.FloatTensor, + phoneme_ids0, + phoneme_ids1, + bert1, + bert2, + top_k, + ): + refer = spectrogram_torch( + self.hann_window, + ref_audio_32k, + self.filter_length, + self.sampling_rate, + self.hop_length, + self.win_length, + center=False, + ).to(ssl_content.dtype) + + codes = self.vq_model.extract_latent(ssl_content) + prompt_semantic = codes[0, 0] + prompt = prompt_semantic.unsqueeze(0) + # print('extract_latent',codes.shape,datetime.now().strftime("%Y-%m-%d %H:%M:%S")) + + pred_semantic = self.t2s_m(prompt, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k) + # print('t2s_m',pred_semantic.shape,datetime.now().strftime("%Y-%m-%d %H:%M:%S")) + + ge = self.vq_model.create_ge(refer) + # print('create_ge',datetime.now().strftime("%Y-%m-%d %H:%M:%S")) + + prompt_ = prompt.unsqueeze(0) + fea_ref = self.vq_model(prompt_, phoneme_ids0, ge) + # print('fea_ref',datetime.now().strftime("%Y-%m-%d %H:%M:%S")) + # print(prompt_.shape, phoneme_ids0.shape, ge.shape) + # print(fea_ref.shape) + + ref_32k = ref_audio_32k + mel2 = norm_spec(self.mel2(ref_32k)).to(ssl_content.dtype) + T_min = min(mel2.shape[2], fea_ref.shape[2]) + mel2 = mel2[:, :, :T_min] + fea_ref = fea_ref[:, :, :T_min] + if T_min > 500: + mel2 = mel2[:, :, -500:] + fea_ref = fea_ref[:, :, -500:] + T_min = 500 + + fea_todo = 
self.vq_model(pred_semantic, phoneme_ids1, ge) + # print('fea_todo',datetime.now().strftime("%Y-%m-%d %H:%M:%S")) + # print(pred_semantic.shape, phoneme_ids1.shape, ge.shape) + # print(fea_todo.shape) + + return fea_ref, fea_todo, mel2 + + +class GPTSoVITSV3(torch.nn.Module): + def __init__(self, gpt_sovits_half, cfm, bigvgan): + super().__init__() + self.gpt_sovits_half = gpt_sovits_half + self.cfm = cfm + self.bigvgan = bigvgan + + def forward( + self, + ssl_content, + ref_audio_32k: torch.FloatTensor, + phoneme_ids0: torch.LongTensor, + phoneme_ids1: torch.LongTensor, + bert1, + bert2, + top_k: torch.LongTensor, + sample_steps: torch.LongTensor, + ): + # current_time = datetime.now() + # print("gpt_sovits_half",current_time.strftime("%Y-%m-%d %H:%M:%S")) + fea_ref, fea_todo, mel2 = self.gpt_sovits_half( + ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k + ) + chunk_len = 934 - fea_ref.shape[2] + wav_gen_list = [] + idx = 0 + fea_todo = fea_todo[:, :, :-5] + wav_gen_length = fea_todo.shape[2] * 256 + while 1: + # current_time = datetime.now() + # print("idx:",idx,current_time.strftime("%Y-%m-%d %H:%M:%S")) + fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len] + if fea_todo_chunk.shape[-1] == 0: + break + + # 因为导出的模型在不同shape时会重新编译还是怎么的,会卡顿10s这样, + # 所以在这里补0让他shape维持不变 + # 但是这样会导致生成的音频长度不对,所以在最后截取一下。 + # 经过 bigvgan 之后音频长度就是 fea_todo.shape[2] * 256 + complete_len = chunk_len - fea_todo_chunk.shape[-1] + if complete_len != 0: + fea_todo_chunk = torch.cat( + [ + fea_todo_chunk, + torch.zeros(1, 512, complete_len).to(fea_todo_chunk.device).to(fea_todo_chunk.dtype), + ], + 2, + ) + + cfm_res, fea_ref, mel2 = self.cfm(fea_ref, fea_todo_chunk, mel2, sample_steps) + idx += chunk_len + + cfm_res = denorm_spec(cfm_res) + bigvgan_res = self.bigvgan(cfm_res) + wav_gen_list.append(bigvgan_res) + + wav_gen = torch.cat(wav_gen_list, 2) + return wav_gen[0][0][:wav_gen_length] + + +class GPTSoVITSV4(torch.nn.Module): + def __init__(self, gpt_sovits_half, cfm, hifigan): + super().__init__() + self.gpt_sovits_half = gpt_sovits_half + self.cfm = cfm + self.hifigan = hifigan + + def forward( + self, + ssl_content, + ref_audio_32k: torch.FloatTensor, + phoneme_ids0: torch.LongTensor, + phoneme_ids1: torch.LongTensor, + bert1, + bert2, + top_k: torch.LongTensor, + sample_steps: torch.LongTensor, + ): + # current_time = datetime.now() + # print("gpt_sovits_half",current_time.strftime("%Y-%m-%d %H:%M:%S")) + fea_ref, fea_todo, mel2 = self.gpt_sovits_half( + ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k + ) + chunk_len = 1000 - fea_ref.shape[2] + wav_gen_list = [] + idx = 0 + fea_todo = fea_todo[:, :, :-10] + wav_gen_length = fea_todo.shape[2] * 480 + while 1: + # current_time = datetime.now() + # print("idx:",idx,current_time.strftime("%Y-%m-%d %H:%M:%S")) + fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len] + if fea_todo_chunk.shape[-1] == 0: + break + + # 因为导出的模型在不同shape时会重新编译还是怎么的,会卡顿10s这样, + # 所以在这里补0让他shape维持不变 + # 但是这样会导致生成的音频长度不对,所以在最后截取一下。 + # 经过 hifigan 之后音频长度就是 fea_todo.shape[2] * 480 + complete_len = chunk_len - fea_todo_chunk.shape[-1] + if complete_len != 0: + fea_todo_chunk = torch.cat( + [ + fea_todo_chunk, + torch.zeros(1, 512, complete_len).to(fea_todo_chunk.device).to(fea_todo_chunk.dtype), + ], + 2, + ) + + cfm_res, fea_ref, mel2 = self.cfm(fea_ref, fea_todo_chunk, mel2, sample_steps) + idx += chunk_len + + cfm_res = denorm_spec(cfm_res) + hifigan_res = self.hifigan(cfm_res) + wav_gen_list.append(hifigan_res) + + wav_gen = 
torch.cat(wav_gen_list, 2) + return wav_gen[0][0][:wav_gen_length] + + +def init_bigvgan(): + global bigvgan_model + from BigVGAN import bigvgan + + bigvgan_model = bigvgan.BigVGAN.from_pretrained( + "%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,), + use_cuda_kernel=False, + ) # if True, RuntimeError: Ninja is required to load C++ extensions + # remove weight norm in the model and set to eval mode + bigvgan_model.remove_weight_norm() + bigvgan_model = bigvgan_model.eval() + if is_half == True: + bigvgan_model = bigvgan_model.half().to(device) + else: + bigvgan_model = bigvgan_model.to(device) + + +def init_hifigan(): + global hifigan_model, bigvgan_model + hifigan_model = Generator( + initial_channel=100, + resblock="1", + resblock_kernel_sizes=[3, 7, 11], + resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], + upsample_rates=[10, 6, 2, 2, 2], + upsample_initial_channel=512, + upsample_kernel_sizes=[20, 12, 4, 4, 4], + gin_channels=0, + is_bias=True, + ) + hifigan_model.eval() + hifigan_model.remove_weight_norm() + state_dict_g = torch.load( + "%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu" + ) + print("loading vocoder", hifigan_model.load_state_dict(state_dict_g)) + if is_half == True: + hifigan_model = hifigan_model.half().to(device) + else: + hifigan_model = hifigan_model.to(device) + + +class Sovits: + def __init__(self, vq_model: SynthesizerTrnV3, cfm: CFM, hps): + self.vq_model = vq_model + self.hps = hps + cfm.estimator = ExportDiT(cfm.estimator) + self.cfm = cfm + + +class DictToAttrRecursive(dict): + def __init__(self, input_dict): + super().__init__(input_dict) + for key, value in input_dict.items(): + if isinstance(value, dict): + value = DictToAttrRecursive(value) + self[key] = value + setattr(self, key, value) + + def __getattr__(self, item): + try: + return self[item] + except KeyError: + raise AttributeError(f"Attribute {item} not found") + + def __setattr__(self, key, value): + if isinstance(value, dict): + value = DictToAttrRecursive(value) + super(DictToAttrRecursive, self).__setitem__(key, value) + super().__setattr__(key, value) + + def __delattr__(self, item): + try: + del self[item] + except KeyError: + raise AttributeError(f"Attribute {item} not found") + + +from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new + +v3v4set = {"v3", "v4"} + + +def get_sovits_weights(sovits_path): + path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth" + is_exist_s2gv3 = os.path.exists(path_sovits_v3) + + version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path) + if if_lora_v3 == True and is_exist_s2gv3 == False: + logger.info("SoVITS V3 底模缺失,无法加载相应 LoRA 权重") + + dict_s2 = load_sovits_new(sovits_path) + hps = dict_s2["config"] + hps = DictToAttrRecursive(hps) + hps.model.semantic_frame_rate = "25hz" + if "enc_p.text_embedding.weight" not in dict_s2["weight"]: + hps.model.version = "v2" # v3model,v2sybomls + elif dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322: + hps.model.version = "v1" + else: + hps.model.version = "v2" + + if model_version in v3v4set: + hps.model.version = model_version + + logger.info(f"hps: {hps}") + + vq_model = SynthesizerTrnV3( + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + n_speakers=hps.data.n_speakers, + **hps.model, + ) + # init_bigvgan() + model_version = hps.model.version + logger.info(f"模型版本: {model_version}") + + if is_half == True: + vq_model = 
vq_model.half().to(device) + else: + vq_model = vq_model.to(device) + vq_model.load_state_dict(dict_s2["weight"], strict=False) + vq_model.eval() + + cfm = vq_model.cfm + del vq_model.cfm + + sovits = Sovits(vq_model, cfm, hps) + return sovits + + +logger.info(f"torch version {torch.__version__}") +# ssl_model = cnhubert.get_model() +# if is_half: +# ssl_model = ssl_model.half().to(device) +# else: +# ssl_model = ssl_model.to(device) + + +def export_cfm( + e_cfm: ExportCFM, + mu: torch.Tensor, + x_lens: torch.LongTensor, + prompt: torch.Tensor, + n_timesteps: torch.IntTensor, + temperature=1.0, +): + cfm = e_cfm.cfm + + B, T = mu.size(0), mu.size(1) + x = torch.randn([B, cfm.in_channels, T], device=mu.device, dtype=mu.dtype) * temperature + print("x:", x.shape, x.dtype) + prompt_len = prompt.size(-1) + prompt_x = torch.zeros_like(x, dtype=mu.dtype) + prompt_x[..., :prompt_len] = prompt[..., :prompt_len] + x[..., :prompt_len] = 0.0 + mu = mu.transpose(2, 1) + + ntimestep = int(n_timesteps) + + t = torch.tensor(0.0, dtype=x.dtype, device=x.device) + d = torch.tensor(1.0 / ntimestep, dtype=x.dtype, device=x.device) + + t_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * t + d_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * d + + print( + "cfm input shapes:", + x.shape, + prompt_x.shape, + x_lens.shape, + t_tensor.shape, + d_tensor.shape, + mu.shape, + ) + + print("cfm input dtypes:", x.dtype, prompt_x.dtype, x_lens.dtype, t_tensor.dtype, d_tensor.dtype, mu.dtype) + + estimator: ExportDiT = torch.jit.trace( + cfm.estimator, + optimize=True, + example_inputs=(x, prompt_x, x_lens, t_tensor, d_tensor, mu), + ) + estimator.save("onnx/ad/estimator.pt") + # torch.onnx.export( + # cfm.estimator, + # (x, prompt_x, x_lens, t_tensor, d_tensor, mu), + # "onnx/ad/dit.onnx", + # input_names=["x", "prompt_x", "x_lens", "t", "d", "mu"], + # output_names=["output"], + # dynamic_axes={ + # "x": [2], + # "prompt_x": [2], + # "mu": [2], + # }, + # ) + print("save estimator ok") + cfm.estimator = estimator + export_cfm = torch.jit.script(e_cfm) + export_cfm.save("onnx/ad/cfm.pt") + # sovits.cfm = cfm + # cfm.save("onnx/ad/cfm.pt") + return export_cfm + + +def export_1(ref_wav_path, ref_wav_text, version="v3"): + if version == "v3": + sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth") + init_bigvgan() + else: + sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth") + init_hifigan() + + dict_s1 = torch.load("GPT_SoVITS/pretrained_models/s1v3.ckpt") + raw_t2s = get_raw_t2s_model(dict_s1).to(device) + print("#### get_raw_t2s_model ####") + print(raw_t2s.config) + + if is_half: + raw_t2s = raw_t2s.half().to(device) + + t2s_m = T2SModel(raw_t2s) + t2s_m.eval() + script_t2s = torch.jit.script(t2s_m).to(device) + + hps = sovits.hps + # ref_wav_path = "onnx/ad/ref.wav" + speed = 1.0 + sample_steps = 8 + dtype = torch.float16 if is_half == True else torch.float32 + refer = get_spepc(hps, ref_wav_path).to(device).to(dtype) + zero_wav = np.zeros( + int(hps.data.sampling_rate * 0.3), + dtype=np.float16 if is_half == True else np.float32, + ) + + with torch.no_grad(): + wav16k, sr = librosa.load(ref_wav_path, sr=16000) + wav16k = torch.from_numpy(wav16k) + zero_wav_torch = torch.from_numpy(zero_wav) + + if is_half == True: + wav16k = wav16k.half().to(device) + zero_wav_torch = zero_wav_torch.half().to(device) + else: + wav16k = wav16k.to(device) + zero_wav_torch = zero_wav_torch.to(device) + wav16k = torch.cat([wav16k, zero_wav_torch]) + 
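+ # Descriptive note (added): the zero-padded 16 kHz reference is fed to the SSL (cnhubert) model;
+ # its hidden states are then quantized by vq_model.extract_latent() into the semantic prompt for the T2S model.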
ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float() + codes = sovits.vq_model.extract_latent(ssl_content) + prompt_semantic = codes[0, 0] + prompt = prompt_semantic.unsqueeze(0).to(device) + + # phones1, bert1, norm_text1 = get_phones_and_bert( + # "你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说。", "all_zh", "v3" + # ) + phones1, bert1, norm_text1 = get_phones_and_bert(ref_wav_text, "auto", "v3") + phones2, bert2, norm_text2 = get_phones_and_bert( + "这是一个简单的示例,真没想到这么简单就完成了。The King and His Stories.Once there was a king. He likes to write stories, but his stories were not good. As people were afraid of him, they all said his stories were good.After reading them, the writer at once turned to the soldiers and said: Take me back to prison, please.", + "auto", + "v3", + ) + phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0) + phoneme_ids1 = torch.LongTensor(phones2).to(device).unsqueeze(0) + + # codes = sovits.vq_model.extract_latent(ssl_content) + # prompt_semantic = codes[0, 0] + # prompts = prompt_semantic.unsqueeze(0) + + top_k = torch.LongTensor([15]).to(device) + print("topk", top_k) + + bert1 = bert1.T.to(device) + bert2 = bert2.T.to(device) + print( + prompt.dtype, + phoneme_ids0.dtype, + phoneme_ids1.dtype, + bert1.dtype, + bert2.dtype, + top_k.dtype, + ) + print( + prompt.shape, + phoneme_ids0.shape, + phoneme_ids1.shape, + bert1.shape, + bert2.shape, + top_k.shape, + ) + pred_semantic = t2s_m(prompt, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k) + + ge = sovits.vq_model.create_ge(refer) + prompt_ = prompt.unsqueeze(0) + + torch._dynamo.mark_dynamic(prompt_, 2) + torch._dynamo.mark_dynamic(phoneme_ids0, 1) + + fea_ref = sovits.vq_model(prompt_, phoneme_ids0, ge) + + inputs = { + "forward": (prompt_, phoneme_ids0, ge), + "extract_latent": ssl_content, + "create_ge": refer, + } + + trace_vq_model = torch.jit.trace_module(sovits.vq_model, inputs, optimize=True) + trace_vq_model.save("onnx/ad/vq_model.pt") + + print(fea_ref.shape, fea_ref.dtype, ge.shape) + print(prompt_.shape, phoneme_ids0.shape, ge.shape) + + # vq_model = torch.jit.trace( + # sovits.vq_model, + # optimize=True, + # # strict=False, + # example_inputs=(prompt_, phoneme_ids0, ge), + # ) + # vq_model = sovits.vq_model + vq_model = trace_vq_model + + if version == "v3": + gpt_sovits_half = ExportGPTSovitsHalf(sovits.hps, script_t2s, trace_vq_model) + torch.jit.script(gpt_sovits_half).save("onnx/ad/gpt_sovits_v3_half.pt") + else: + gpt_sovits_half = ExportGPTSovitsV4Half(sovits.hps, script_t2s, trace_vq_model) + torch.jit.script(gpt_sovits_half).save("onnx/ad/gpt_sovits_v4_half.pt") + + ref_audio, sr = torchaudio.load(ref_wav_path) + ref_audio = ref_audio.to(device).float() + if ref_audio.shape[0] == 2: + ref_audio = ref_audio.mean(0).unsqueeze(0) + tgt_sr = 24000 if version == "v3" else 32000 + if sr != tgt_sr: + ref_audio = resample(ref_audio, sr, tgt_sr) + # mel2 = mel_fn(ref_audio) + mel2 = mel_fn(ref_audio) if version == "v3" else mel_fn_v4(ref_audio) + mel2 = norm_spec(mel2) + T_min = min(mel2.shape[2], fea_ref.shape[2]) + fea_ref = fea_ref[:, :, :T_min] + print("fea_ref:", fea_ref.shape, T_min) + Tref = 468 if version == "v3" else 500 + Tchunk = 934 if version == "v3" else 1000 + if T_min > Tref: + mel2 = mel2[:, :, -Tref:] + fea_ref = fea_ref[:, :, -Tref:] + T_min = Tref + chunk_len = Tchunk - T_min + mel2 = mel2.to(dtype) + + # fea_todo, ge = sovits.vq_model(pred_semantic,y_lengths, phoneme_ids1, ge) + fea_todo = vq_model(pred_semantic, phoneme_ids1, ge) + + cfm_resss = [] 
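+ # Descriptive note (added): fea_todo is decoded by the CFM flow-matching model in chunks of chunk_len frames;
+ # the first chunk is also used to trace and save the CFM TorchScript modules under onnx/ad/,
+ # while fea_ref and mel2 carry the rolling reference context between chunks.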
+ idx = 0 + sample_steps = torch.LongTensor([sample_steps]).to(device) + export_cfm_ = ExportCFM(sovits.cfm) + while 1: + print("idx:", idx) + fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len] + if fea_todo_chunk.shape[-1] == 0: + break + + print( + "export_cfm:", + fea_ref.shape, + fea_todo_chunk.shape, + mel2.shape, + sample_steps.shape, + ) + if idx == 0: + fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1) + export_cfm_ = export_cfm( + export_cfm_, + fea, + torch.LongTensor([fea.size(1)]).to(fea.device), + mel2, + sample_steps, + ) + # torch.onnx.export( + # export_cfm_, + # ( + # fea_ref, + # fea_todo_chunk, + # mel2, + # sample_steps, + # ), + # "onnx/ad/cfm.onnx", + # input_names=["fea_ref", "fea_todo_chunk", "mel2", "sample_steps"], + # output_names=["cfm_res", "fea_ref_", "mel2_"], + # dynamic_axes={ + # "fea_ref": [2], + # "fea_todo_chunk": [2], + # "mel2": [2], + # }, + # ) + + idx += chunk_len + + cfm_res, fea_ref, mel2 = export_cfm_(fea_ref, fea_todo_chunk, mel2, sample_steps) + cfm_resss.append(cfm_res) + continue + + cmf_res = torch.cat(cfm_resss, 2) + cmf_res = denorm_spec(cmf_res).to(device) + print("cmf_res:", cmf_res.shape, cmf_res.dtype) + with torch.inference_mode(): + cmf_res_rand = torch.randn(1, 100, 934).to(device).to(dtype) + torch._dynamo.mark_dynamic(cmf_res_rand, 2) + if version == "v3": + bigvgan_model_ = torch.jit.trace(bigvgan_model, optimize=True, example_inputs=(cmf_res_rand,)) + bigvgan_model_.save("onnx/ad/bigvgan_model.pt") + wav_gen = bigvgan_model(cmf_res) + else: + hifigan_model_ = torch.jit.trace(hifigan_model, optimize=True, example_inputs=(cmf_res_rand,)) + hifigan_model_.save("onnx/ad/hifigan_model.pt") + wav_gen = hifigan_model(cmf_res) + + print("wav_gen:", wav_gen.shape, wav_gen.dtype) + audio = wav_gen[0][0].cpu().detach().numpy() + + sr = 24000 if version == "v3" else 48000 + soundfile.write("out.export.wav", (audio * 32768).astype(np.int16), sr) + + +from datetime import datetime + + +def test_export( + todo_text, + gpt_sovits_v3_half, + cfm, + bigvgan, + output, +): + # hps = sovits.hps + ref_wav_path = "onnx/ad/ref.wav" + speed = 1.0 + sample_steps = 8 + + dtype = torch.float16 if is_half == True else torch.float32 + + zero_wav = np.zeros( + int(16000 * 0.3), + dtype=np.float16 if is_half == True else np.float32, + ) + + with torch.no_grad(): + wav16k, sr = librosa.load(ref_wav_path, sr=16000) + wav16k = torch.from_numpy(wav16k) + zero_wav_torch = torch.from_numpy(zero_wav) + + if is_half == True: + wav16k = wav16k.half().to(device) + zero_wav_torch = zero_wav_torch.half().to(device) + else: + wav16k = wav16k.to(device) + zero_wav_torch = zero_wav_torch.to(device) + wav16k = torch.cat([wav16k, zero_wav_torch]) + ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float() + + ref_audio_32k, _ = librosa.load(ref_wav_path, sr=32000) + ref_audio_32k = torch.from_numpy(ref_audio_32k).unsqueeze(0).to(device).float() + + phones1, bert1, norm_text1 = get_phones_and_bert( + "你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说。", "all_zh", "v3" + ) + phones2, bert2, norm_text2 = get_phones_and_bert( + todo_text, + "zh", + "v3", + ) + phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0) + phoneme_ids1 = torch.LongTensor(phones2).to(device).unsqueeze(0) + + bert1 = bert1.T.to(device) + bert2 = bert2.T.to(device) + top_k = torch.LongTensor([15]).to(device) + + current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + logger.info("start inference %s", current_time) + print( + ssl_content.shape, + 
ref_audio_32k.shape, + phoneme_ids0.shape, + phoneme_ids1.shape, + bert1.shape, + bert2.shape, + top_k.shape, + ) + fea_ref, fea_todo, mel2 = gpt_sovits_v3_half( + ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k + ) + chunk_len = 934 - fea_ref.shape[2] + print(fea_ref.shape, fea_todo.shape, mel2.shape) + + cfm_resss = [] + sample_steps = torch.LongTensor([sample_steps]) + idx = 0 + current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + logger.info("start cfm %s", current_time) + wav_gen_length = fea_todo.shape[2] * 256 + + while 1: + current_time = datetime.now() + print("idx:", idx, current_time.strftime("%Y-%m-%d %H:%M:%S")) + fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len] + if fea_todo_chunk.shape[-1] == 0: + break + + complete_len = chunk_len - fea_todo_chunk.shape[-1] + if complete_len != 0: + fea_todo_chunk = torch.cat([fea_todo_chunk, torch.zeros(1, 512, complete_len).to(device).to(dtype)], 2) + + cfm_res, fea_ref, mel2 = cfm(fea_ref, fea_todo_chunk, mel2, sample_steps) + # if complete_len > 0 : + # cfm_res = cfm_res[:, :, :-complete_len] + # fea_ref = fea_ref[:, :, :-complete_len] + # mel2 = mel2[:, :, :-complete_len] + + idx += chunk_len + + current_time = datetime.now() + print("cfm end", current_time.strftime("%Y-%m-%d %H:%M:%S")) + cfm_res = denorm_spec(cfm_res).to(device) + bigvgan_res = bigvgan(cfm_res) + cfm_resss.append(bigvgan_res) + + current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + logger.info("start bigvgan %s", current_time) + wav_gen = torch.cat(cfm_resss, 2) + # cmf_res = denorm_spec(cmf_res) + # cmf_res = cmf_res.to(device) + # print("cmf_res:", cmf_res.shape) + + # cmf_res = torch.cat([cmf_res,torch.zeros([1,100,2000-cmf_res.size(2)],device=device,dtype=cmf_res.dtype)], 2) + + # wav_gen = bigvgan(cmf_res) + print("wav_gen:", wav_gen.shape, wav_gen.dtype) + wav_gen = wav_gen[:, :, :wav_gen_length] + + audio = wav_gen[0][0].cpu().detach().numpy() + logger.info("end bigvgan %s", datetime.now().strftime("%Y-%m-%d %H:%M:%S")) + sr = 24000 + soundfile.write(output, (audio * 32768).astype(np.int16), sr) + + +def test_export( + todo_text, + gpt_sovits_v3v4, + output, + out_sr=24000, +): + # hps = sovits.hps + ref_wav_path = "onnx/ad/ref.wav" + speed = 1.0 + sample_steps = torch.LongTensor([16]) + + dtype = torch.float16 if is_half == True else torch.float32 + + zero_wav = np.zeros( + int(out_sr * 0.3), + dtype=np.float16 if is_half == True else np.float32, + ) + + with torch.no_grad(): + wav16k, sr = librosa.load(ref_wav_path, sr=16000) + wav16k = torch.from_numpy(wav16k) + zero_wav_torch = torch.from_numpy(zero_wav) + + if is_half == True: + wav16k = wav16k.half().to(device) + zero_wav_torch = zero_wav_torch.half().to(device) + else: + wav16k = wav16k.to(device) + zero_wav_torch = zero_wav_torch.to(device) + wav16k = torch.cat([wav16k, zero_wav_torch]) + ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float() + print("ssl_content:", ssl_content.shape, ssl_content.dtype) + + ref_audio_32k, _ = librosa.load(ref_wav_path, sr=32000) + ref_audio_32k = torch.from_numpy(ref_audio_32k).unsqueeze(0).to(device).float() + + phones1, bert1, norm_text1 = get_phones_and_bert( + "你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说。", "all_zh", "v3" + ) + phones2, bert2, norm_text2 = get_phones_and_bert( + todo_text, + "zh", + "v3", + ) + phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0) + phoneme_ids1 = torch.LongTensor(phones2).to(device).unsqueeze(0) + + bert1 = bert1.T.to(device) + bert2 = 
bert2.T.to(device) + top_k = torch.LongTensor([20]).to(device) + + current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + logger.info("start inference %s", current_time) + print( + ssl_content.shape, + ref_audio_32k.shape, + phoneme_ids0.shape, + phoneme_ids1.shape, + bert1.shape, + bert2.shape, + top_k.shape, + ) + wav_gen = gpt_sovits_v3v4(ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k, sample_steps) + print("wav_gen:", wav_gen.shape, wav_gen.dtype) + + wav_gen = torch.cat([wav_gen, zero_wav_torch], 0) + + audio = wav_gen.cpu().detach().numpy() + logger.info("end bigvgan %s", datetime.now().strftime("%Y-%m-%d %H:%M:%S")) + soundfile.write(output, (audio * 32768).astype(np.int16), out_sr) + + +import time + + +def export_2(version="v3"): + if version == "v3": + sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth") + # init_bigvgan() + else: + sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth") + # init_hifigan() + + # cfm = ExportCFM(sovits.cfm) + # cfm.cfm.estimator = dit + sovits.cfm = None + + cfm = torch.jit.load("onnx/ad/cfm.pt", map_location=device) + # cfm = torch.jit.optimize_for_inference(cfm) + cfm = cfm.half().to(device) + + cfm.eval() + + logger.info("cfm ok") + + dict_s1 = torch.load("GPT_SoVITS/pretrained_models/s1v3.ckpt") + # v2 的 gpt 也可以用 + # dict_s1 = torch.load("GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt") + raw_t2s = get_raw_t2s_model(dict_s1).to(device) + print("#### get_raw_t2s_model ####") + print(raw_t2s.config) + if is_half: + raw_t2s = raw_t2s.half().to(device) + t2s_m = T2SModel(raw_t2s).half().to(device) + t2s_m.eval() + t2s_m = torch.jit.script(t2s_m).to(device) + t2s_m.eval() + # t2s_m.top_k = 15 + logger.info("t2s_m ok") + + vq_model: torch.jit.ScriptModule = torch.jit.load("onnx/ad/vq_model.pt", map_location=device) + # vq_model = torch.jit.optimize_for_inference(vq_model) + # vq_model = vq_model.half().to(device) + vq_model.eval() + # vq_model = sovits.vq_model + logger.info("vq_model ok") + + # gpt_sovits_v3_half = torch.jit.load("onnx/ad/gpt_sovits_v3_half.pt") + # gpt_sovits_v3_half = torch.jit.optimize_for_inference(gpt_sovits_v3_half) + # gpt_sovits_v3_half = gpt_sovits_v3_half.half() + # gpt_sovits_v3_half = gpt_sovits_v3_half.cuda() + # gpt_sovits_v3_half.eval() + if version == "v3": + gpt_sovits_v3_half = ExportGPTSovitsHalf(sovits.hps, t2s_m, vq_model) + logger.info("gpt_sovits_v3_half ok") + # init_bigvgan() + # global bigvgan_model + bigvgan_model = torch.jit.load("onnx/ad/bigvgan_model.pt") + # bigvgan_model = torch.jit.optimize_for_inference(bigvgan_model) + bigvgan_model = bigvgan_model.half() + bigvgan_model = bigvgan_model.cuda() + bigvgan_model.eval() + + logger.info("bigvgan ok") + gpt_sovits_v3 = GPTSoVITSV3(gpt_sovits_v3_half, cfm, bigvgan_model) + gpt_sovits_v3 = torch.jit.script(gpt_sovits_v3) + gpt_sovits_v3.save("onnx/ad/gpt_sovits_v3.pt") + gpt_sovits_v3 = gpt_sovits_v3.half().to(device) + gpt_sovits_v3.eval() + print("save gpt_sovits_v3 ok") + else: + gpt_sovits_v4_half = ExportGPTSovitsV4Half(sovits.hps, t2s_m, vq_model) + logger.info("gpt_sovits_v4 ok") + + hifigan_model = torch.jit.load("onnx/ad/hifigan_model.pt") + hifigan_model = hifigan_model.half() + hifigan_model = hifigan_model.cuda() + hifigan_model.eval() + logger.info("hifigan ok") + gpt_sovits_v4 = GPTSoVITSV4(gpt_sovits_v4_half, cfm, hifigan_model) + gpt_sovits_v4 = torch.jit.script(gpt_sovits_v4) + 
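+ # Descriptive note (added): scripting bundles the half-precision front-end (T2S + VITS feature extractor),
+ # the CFM decoder and the HiFi-GAN vocoder into one TorchScript module, saved below as onnx/ad/gpt_sovits_v4.pt.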
gpt_sovits_v4.save("onnx/ad/gpt_sovits_v4.pt") + print("save gpt_sovits_v4 ok") + + gpt_sovits_v3v4 = gpt_sovits_v3 if version == "v3" else gpt_sovits_v4 + sr = 24000 if version == "v3" else 48000 + + time.sleep(5) + # print("thread:", torch.get_num_threads()) + # print("thread:", torch.get_num_interop_threads()) + # torch.set_num_interop_threads(1) + # torch.set_num_threads(1) + + test_export( + "汗流浃背了呀!老弟~ My uncle has two dogs. One is big and the other is small. He likes them very much. He often plays with them. He takes them for a walk every day. He says they are his good friends. He is very happy with them. 最后还是我得了 MVP....", + gpt_sovits_v3v4, + "out.wav", + sr, + ) + + test_export( + "你小子是什么来路.汗流浃背了呀!老弟~ My uncle has two dogs. He is very happy with them. 最后还是我得了 MVP!", + gpt_sovits_v3v4, + "out2.wav", + sr, + ) + + # test_export( + # "汗流浃背了呀!老弟~ My uncle has two dogs. One is big and the other is small. He likes them very much. He often plays with them. He takes them for a walk every day. He says they are his good friends. He is very happy with them. 最后还是我得了 MVP. 哈哈哈...", + # gpt_sovits_v3_half, + # cfm, + # bigvgan_model, + # "out2.wav", + # ) + + +def test_export_gpt_sovits_v3(): + gpt_sovits_v3 = torch.jit.load("onnx/ad/gpt_sovits_v3.pt", map_location=device) + # test_export1( + # "汗流浃背了呀!老弟~ My uncle has two dogs. One is big and the other is small. He likes them very much. He often plays with them. He takes them for a walk every day. He says they are his good friends. He is very happy with them. 最后还是我得了 MVP....", + # gpt_sovits_v3, + # "out3.wav", + # ) + # test_export1( + # "你小子是什么来路.汗流浃背了呀!老弟~ My uncle has two dogs. He is very happy with them. 最后还是我得了 MVP!", + # gpt_sovits_v3, + # "out4.wav", + # ) + test_export( + "风萧萧兮易水寒,壮士一去兮不复还.", + gpt_sovits_v3, + "out5.wav", + ) + + +with torch.no_grad(): + # export_1("onnx/ad/ref.wav","你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说。","v4") + export_2("v4") + # test_export_gpt_sovits_v3() diff --git a/f5_tts/model/__init__.py b/f5_tts/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..50cff994a777a9b7bba2a0f83fd03dd365fb49d2 --- /dev/null +++ b/f5_tts/model/__init__.py @@ -0,0 +1,13 @@ +# from f5_tts.model.cfm import CFM +# +# from f5_tts.model.backbones.unett import UNetT +from GPT_SoVITS.f5_tts.model.backbones.dit import DiT +# from f5_tts.model.backbones.dit import DiTNoCond +# from f5_tts.model.backbones.dit import DiTNoCondNoT +# from f5_tts.model.backbones.mmdit import MMDiT + +# from f5_tts.model.trainer import Trainer + + +# __all__ = ["CFM", "UNetT", "DiT", "MMDiT", "Trainer"] +# __all__ = ["CFM", "UNetT", "DiTNoCond","DiT", "MMDiT"] diff --git a/f5_tts/model/backbones/README.md b/f5_tts/model/backbones/README.md new file mode 100644 index 0000000000000000000000000000000000000000..155671e16fbf128a243ece9033cefd47b957af88 --- /dev/null +++ b/f5_tts/model/backbones/README.md @@ -0,0 +1,20 @@ +## Backbones quick introduction + + +### unett.py +- flat unet transformer +- structure same as in e2-tts & voicebox paper except using rotary pos emb +- update: allow possible abs pos emb & convnextv2 blocks for embedded text before concat + +### dit.py +- adaln-zero dit +- embedded timestep as condition +- concatted noised_input + masked_cond + embedded_text, linear proj in +- possible abs pos emb & convnextv2 blocks for embedded text before concat +- possible long skip connection (first layer to last layer) + +### mmdit.py +- sd3 structure +- timestep as condition +- left stream: text embedded and applied a abs pos emb +- 
right stream: masked_cond & noised_input concatted and with same conv pos emb as unett diff --git a/f5_tts/model/backbones/dit.py b/f5_tts/model/backbones/dit.py new file mode 100644 index 0000000000000000000000000000000000000000..4aa3b9ac0737da9fed529f3682c74ac4dcd20331 --- /dev/null +++ b/f5_tts/model/backbones/dit.py @@ -0,0 +1,194 @@ +""" +ein notation: +b - batch +n - sequence +nt - text sequence +nw - raw wave length +d - dimension +""" + +from __future__ import annotations + +import torch +from torch import nn +from torch.utils.checkpoint import checkpoint + +from x_transformers.x_transformers import RotaryEmbedding + +from GPT_SoVITS.f5_tts.model.modules import ( + TimestepEmbedding, + ConvNeXtV2Block, + ConvPositionEmbedding, + DiTBlock, + AdaLayerNormZero_Final, + precompute_freqs_cis, + get_pos_embed_indices, +) + +from module.commons import sequence_mask + + +class TextEmbedding(nn.Module): + def __init__(self, text_dim, conv_layers=0, conv_mult=2): + super().__init__() + if conv_layers > 0: + self.extra_modeling = True + self.precompute_max_pos = 4096 # ~44s of 24khz audio + self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False) + self.text_blocks = nn.Sequential( + *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)] + ) + else: + self.extra_modeling = False + + def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722 + batch, text_len = text.shape[0], text.shape[1] + + if drop_text: # cfg for text + text = torch.zeros_like(text) + + # possible extra modeling + if self.extra_modeling: + # sinus pos emb + batch_start = torch.zeros((batch,), dtype=torch.long) + pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos) + text_pos_embed = self.freqs_cis[pos_idx] + + # print(23333333,text.shape,text_pos_embed.shape)#torch.Size([7, 465, 256]) torch.Size([7, 465, 256]) + + text = text + text_pos_embed + + # convnextv2 blocks + text = self.text_blocks(text) + + return text + + +# noised input audio and context mixing embedding + + +class InputEmbedding(nn.Module): + def __init__(self, mel_dim, text_dim, out_dim): + super().__init__() + self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim) + self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim) + + def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722 + if drop_audio_cond: # cfg for cond audio + cond = torch.zeros_like(cond) + + x = self.proj(torch.cat((x, cond, text_embed), dim=-1)) + x = self.conv_pos_embed(x) + x + return x + + +# Transformer backbone using DiT blocks + + +class DiT(nn.Module): + def __init__( + self, + *, + dim, + depth=8, + heads=8, + dim_head=64, + dropout=0.1, + ff_mult=4, + mel_dim=100, + text_dim=None, + conv_layers=0, + long_skip_connection=False, + ): + super().__init__() + + self.time_embed = TimestepEmbedding(dim) + self.d_embed = TimestepEmbedding(dim) + if text_dim is None: + text_dim = mel_dim + self.text_embed = TextEmbedding(text_dim, conv_layers=conv_layers) + self.input_embed = InputEmbedding(mel_dim, text_dim, dim) + + self.rotary_embed = RotaryEmbedding(dim_head) + + self.dim = dim + self.depth = depth + + self.transformer_blocks = nn.ModuleList( + [DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)] + ) + self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None + + self.norm_out = 
AdaLayerNormZero_Final(dim) # final modulation + self.proj_out = nn.Linear(dim, mel_dim) + + def ckpt_wrapper(self, module): + # https://github.com/chuanyangjin/fast-DiT/blob/main/models.py + def ckpt_forward(*inputs): + outputs = module(*inputs) + return outputs + + return ckpt_forward + + def forward( # x, prompt_x, x_lens, t, style,cond + self, # d is channel,n is T + x0: float["b n d"], # nosied input audio # noqa: F722 + cond0: float["b n d"], # masked cond audio # noqa: F722 + x_lens, + time: float["b"] | float[""], # time step # noqa: F821 F722 + dt_base_bootstrap, + text0, # : int["b nt"] # noqa: F722#####condition feature + use_grad_ckpt=False, # bool + ###no-use + drop_audio_cond=False, # cfg for cond audio + drop_text=False, # cfg for text + # mask: bool["b n"] | None = None, # noqa: F722 + infer=False, # bool + text_cache=None, # torch tensor as text_embed + dt_cache=None, # torch tensor as dt + ): + x = x0.transpose(2, 1) + cond = cond0.transpose(2, 1) + text = text0.transpose(2, 1) + mask = sequence_mask(x_lens, max_length=x.size(1)).to(x.device) + + batch, seq_len = x.shape[0], x.shape[1] + if time.ndim == 0: + time = time.repeat(batch) + + # t: conditioning time, c: context (text + masked cond audio), x: noised input audio + t = self.time_embed(time) + if infer and dt_cache is not None: + dt = dt_cache + else: + dt = self.d_embed(dt_base_bootstrap) + t += dt + + if infer and text_cache is not None: + text_embed = text_cache + else: + text_embed = self.text_embed(text, seq_len, drop_text=drop_text) ###need to change + + x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond) + + rope = self.rotary_embed.forward_from_seq_len(seq_len) + + if self.long_skip_connection is not None: + residual = x + + for block in self.transformer_blocks: + if use_grad_ckpt: + x = checkpoint(self.ckpt_wrapper(block), x, t, mask, rope, use_reentrant=False) + else: + x = block(x, t, mask=mask, rope=rope) + + if self.long_skip_connection is not None: + x = self.long_skip_connection(torch.cat((x, residual), dim=-1)) + + x = self.norm_out(x, t) + output = self.proj_out(x) + + if infer: + return output, text_embed, dt + else: + return output diff --git a/f5_tts/model/backbones/mmdit.py b/f5_tts/model/backbones/mmdit.py new file mode 100644 index 0000000000000000000000000000000000000000..64c7ef18e1195631f3917af95ca7c8ac12462bf8 --- /dev/null +++ b/f5_tts/model/backbones/mmdit.py @@ -0,0 +1,146 @@ +""" +ein notation: +b - batch +n - sequence +nt - text sequence +nw - raw wave length +d - dimension +""" + +from __future__ import annotations + +import torch +from torch import nn + +from x_transformers.x_transformers import RotaryEmbedding + +from f5_tts.model.modules import ( + TimestepEmbedding, + ConvPositionEmbedding, + MMDiTBlock, + AdaLayerNormZero_Final, + precompute_freqs_cis, + get_pos_embed_indices, +) + + +# text embedding + + +class TextEmbedding(nn.Module): + def __init__(self, out_dim, text_num_embeds): + super().__init__() + self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim) # will use 0 as filler token + + self.precompute_max_pos = 1024 + self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False) + + def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]: # noqa: F722 + text = text + 1 + if drop_text: + text = torch.zeros_like(text) + text = self.text_embed(text) + + # sinus pos emb + batch_start = torch.zeros((text.shape[0],), dtype=torch.long) + batch_text_len = text.shape[1] + pos_idx = 
get_pos_embed_indices(batch_start, batch_text_len, max_pos=self.precompute_max_pos) + text_pos_embed = self.freqs_cis[pos_idx] + + text = text + text_pos_embed + + return text + + +# noised input & masked cond audio embedding + + +class AudioEmbedding(nn.Module): + def __init__(self, in_dim, out_dim): + super().__init__() + self.linear = nn.Linear(2 * in_dim, out_dim) + self.conv_pos_embed = ConvPositionEmbedding(out_dim) + + def forward(self, x: float["b n d"], cond: float["b n d"], drop_audio_cond=False): # noqa: F722 + if drop_audio_cond: + cond = torch.zeros_like(cond) + x = torch.cat((x, cond), dim=-1) + x = self.linear(x) + x = self.conv_pos_embed(x) + x + return x + + +# Transformer backbone using MM-DiT blocks + + +class MMDiT(nn.Module): + def __init__( + self, + *, + dim, + depth=8, + heads=8, + dim_head=64, + dropout=0.1, + ff_mult=4, + text_num_embeds=256, + mel_dim=100, + ): + super().__init__() + + self.time_embed = TimestepEmbedding(dim) + self.text_embed = TextEmbedding(dim, text_num_embeds) + self.audio_embed = AudioEmbedding(mel_dim, dim) + + self.rotary_embed = RotaryEmbedding(dim_head) + + self.dim = dim + self.depth = depth + + self.transformer_blocks = nn.ModuleList( + [ + MMDiTBlock( + dim=dim, + heads=heads, + dim_head=dim_head, + dropout=dropout, + ff_mult=ff_mult, + context_pre_only=i == depth - 1, + ) + for i in range(depth) + ] + ) + self.norm_out = AdaLayerNormZero_Final(dim) # final modulation + self.proj_out = nn.Linear(dim, mel_dim) + + def forward( + self, + x: float["b n d"], # nosied input audio # noqa: F722 + cond: float["b n d"], # masked cond audio # noqa: F722 + text: int["b nt"], # text # noqa: F722 + time: float["b"] | float[""], # time step # noqa: F821 F722 + drop_audio_cond, # cfg for cond audio + drop_text, # cfg for text + mask: bool["b n"] | None = None, # noqa: F722 + ): + batch = x.shape[0] + if time.ndim == 0: + time = time.repeat(batch) + + # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio + t = self.time_embed(time) + c = self.text_embed(text, drop_text=drop_text) + x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond) + + seq_len = x.shape[1] + text_len = text.shape[1] + rope_audio = self.rotary_embed.forward_from_seq_len(seq_len) + rope_text = self.rotary_embed.forward_from_seq_len(text_len) + + for block in self.transformer_blocks: + c, x = block(x, c, t, mask=mask, rope=rope_audio, c_rope=rope_text) + + x = self.norm_out(x, t) + output = self.proj_out(x) + + return output diff --git a/f5_tts/model/backbones/unett.py b/f5_tts/model/backbones/unett.py new file mode 100644 index 0000000000000000000000000000000000000000..acf649a52448e87a34a2af4bc14051caaba74c86 --- /dev/null +++ b/f5_tts/model/backbones/unett.py @@ -0,0 +1,219 @@ +""" +ein notation: +b - batch +n - sequence +nt - text sequence +nw - raw wave length +d - dimension +""" + +from __future__ import annotations +from typing import Literal + +import torch +from torch import nn +import torch.nn.functional as F + +from x_transformers import RMSNorm +from x_transformers.x_transformers import RotaryEmbedding + +from f5_tts.model.modules import ( + TimestepEmbedding, + ConvNeXtV2Block, + ConvPositionEmbedding, + Attention, + AttnProcessor, + FeedForward, + precompute_freqs_cis, + get_pos_embed_indices, +) + + +# Text embedding + + +class TextEmbedding(nn.Module): + def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2): + super().__init__() + self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as 
filler token + + if conv_layers > 0: + self.extra_modeling = True + self.precompute_max_pos = 4096 # ~44s of 24khz audio + self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False) + self.text_blocks = nn.Sequential( + *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)] + ) + else: + self.extra_modeling = False + + def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722 + text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx() + text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens + batch, text_len = text.shape[0], text.shape[1] + text = F.pad(text, (0, seq_len - text_len), value=0) + + if drop_text: # cfg for text + text = torch.zeros_like(text) + + text = self.text_embed(text) # b n -> b n d + + # possible extra modeling + if self.extra_modeling: + # sinus pos emb + batch_start = torch.zeros((batch,), dtype=torch.long) + pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos) + text_pos_embed = self.freqs_cis[pos_idx] + text = text + text_pos_embed + + # convnextv2 blocks + text = self.text_blocks(text) + + return text + + +# noised input audio and context mixing embedding + + +class InputEmbedding(nn.Module): + def __init__(self, mel_dim, text_dim, out_dim): + super().__init__() + self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim) + self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim) + + def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722 + if drop_audio_cond: # cfg for cond audio + cond = torch.zeros_like(cond) + + x = self.proj(torch.cat((x, cond, text_embed), dim=-1)) + x = self.conv_pos_embed(x) + x + return x + + +# Flat UNet Transformer backbone + + +class UNetT(nn.Module): + def __init__( + self, + *, + dim, + depth=8, + heads=8, + dim_head=64, + dropout=0.1, + ff_mult=4, + mel_dim=100, + text_num_embeds=256, + text_dim=None, + conv_layers=0, + skip_connect_type: Literal["add", "concat", "none"] = "concat", + ): + super().__init__() + assert depth % 2 == 0, "UNet-Transformer's depth should be even." 
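(Editor's aside.) The even-depth assertion above exists because UNetT pairs each layer of the first half with one of the second half through a stack of saved hidden states: states are pushed on the way "down" and popped on the way "up", and with the "concat" skip type the pair is merged by a linear projection. The following is a minimal illustrative sketch of that pairing with toy linear blocks and hypothetical dimensions; it is not part of the repository code.

import torch
import torch.nn as nn

def toy_flat_unet(x: torch.Tensor, depth: int = 4, dim: int = 8) -> torch.Tensor:
    # Layer i of the first half pairs with layer depth + 1 - i of the second
    # half via a LIFO stack, which is why depth must be even.
    assert depth % 2 == 0, "depth must be even so the two halves pair up"
    blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(depth))
    skip_projs = nn.ModuleList(nn.Linear(dim * 2, dim, bias=False) for _ in range(depth // 2))
    skips = []
    for idx, block in enumerate(blocks):
        if idx < depth // 2:
            skips.append(x)  # first half: remember the hidden state
        else:
            skip = skips.pop()  # second half: merge with the matching state ("concat" skip)
            x = skip_projs[idx - depth // 2](torch.cat((x, skip), dim=-1))
        x = block(x)
    assert not skips
    return x

# e.g. toy_flat_unet(torch.randn(2, 16, 8)) runs end to end on toy data.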
+ + self.time_embed = TimestepEmbedding(dim) + if text_dim is None: + text_dim = mel_dim + self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers) + self.input_embed = InputEmbedding(mel_dim, text_dim, dim) + + self.rotary_embed = RotaryEmbedding(dim_head) + + # transformer layers & skip connections + + self.dim = dim + self.skip_connect_type = skip_connect_type + needs_skip_proj = skip_connect_type == "concat" + + self.depth = depth + self.layers = nn.ModuleList([]) + + for idx in range(depth): + is_later_half = idx >= (depth // 2) + + attn_norm = RMSNorm(dim) + attn = Attention( + processor=AttnProcessor(), + dim=dim, + heads=heads, + dim_head=dim_head, + dropout=dropout, + ) + + ff_norm = RMSNorm(dim) + ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh") + + skip_proj = nn.Linear(dim * 2, dim, bias=False) if needs_skip_proj and is_later_half else None + + self.layers.append( + nn.ModuleList( + [ + skip_proj, + attn_norm, + attn, + ff_norm, + ff, + ] + ) + ) + + self.norm_out = RMSNorm(dim) + self.proj_out = nn.Linear(dim, mel_dim) + + def forward( + self, + x: float["b n d"], # nosied input audio # noqa: F722 + cond: float["b n d"], # masked cond audio # noqa: F722 + text: int["b nt"], # text # noqa: F722 + time: float["b"] | float[""], # time step # noqa: F821 F722 + drop_audio_cond, # cfg for cond audio + drop_text, # cfg for text + mask: bool["b n"] | None = None, # noqa: F722 + ): + batch, seq_len = x.shape[0], x.shape[1] + if time.ndim == 0: + time = time.repeat(batch) + + # t: conditioning time, c: context (text + masked cond audio), x: noised input audio + t = self.time_embed(time) + text_embed = self.text_embed(text, seq_len, drop_text=drop_text) + x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond) + + # postfix time t to input x, [b n d] -> [b n+1 d] + x = torch.cat([t.unsqueeze(1), x], dim=1) # pack t to x + if mask is not None: + mask = F.pad(mask, (1, 0), value=1) + + rope = self.rotary_embed.forward_from_seq_len(seq_len + 1) + + # flat unet transformer + skip_connect_type = self.skip_connect_type + skips = [] + for idx, (maybe_skip_proj, attn_norm, attn, ff_norm, ff) in enumerate(self.layers): + layer = idx + 1 + + # skip connection logic + is_first_half = layer <= (self.depth // 2) + is_later_half = not is_first_half + + if is_first_half: + skips.append(x) + + if is_later_half: + skip = skips.pop() + if skip_connect_type == "concat": + x = torch.cat((x, skip), dim=-1) + x = maybe_skip_proj(x) + elif skip_connect_type == "add": + x = x + skip + + # attention and feedforward blocks + x = attn(attn_norm(x), rope=rope, mask=mask) + x + x = ff(ff_norm(x)) + x + + assert len(skips) == 0 + + x = self.norm_out(x)[:, 1:, :] # unpack t from x + + return self.proj_out(x) diff --git a/f5_tts/model/modules.py b/f5_tts/model/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..9f030d96112b2d95a3a60241c603b16f8f2efde8 --- /dev/null +++ b/f5_tts/model/modules.py @@ -0,0 +1,666 @@ +""" +ein notation: +b - batch +n - sequence +nt - text sequence +nw - raw wave length +d - dimension +""" + +from __future__ import annotations + +import math +from typing import Optional + +import torch +import torch.nn.functional as F +import torchaudio +from librosa.filters import mel as librosa_mel_fn +from torch import nn +from x_transformers.x_transformers import apply_rotary_pos_emb + + +# raw wav to mel spec + + +mel_basis_cache = {} +hann_window_cache = {} + + +def get_bigvgan_mel_spectrogram( + 
waveform, + n_fft=1024, + n_mel_channels=100, + target_sample_rate=24000, + hop_length=256, + win_length=1024, + fmin=0, + fmax=None, + center=False, +): # Copy from https://github.com/NVIDIA/BigVGAN/tree/main + device = waveform.device + key = f"{n_fft}_{n_mel_channels}_{target_sample_rate}_{hop_length}_{win_length}_{fmin}_{fmax}_{device}" + + if key not in mel_basis_cache: + mel = librosa_mel_fn(sr=target_sample_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=fmin, fmax=fmax) + mel_basis_cache[key] = torch.from_numpy(mel).float().to(device) # TODO: why they need .float()? + hann_window_cache[key] = torch.hann_window(win_length).to(device) + + mel_basis = mel_basis_cache[key] + hann_window = hann_window_cache[key] + + padding = (n_fft - hop_length) // 2 + waveform = torch.nn.functional.pad(waveform.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1) + + spec = torch.stft( + waveform, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=hann_window, + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9) + + mel_spec = torch.matmul(mel_basis, spec) + mel_spec = torch.log(torch.clamp(mel_spec, min=1e-5)) + + return mel_spec + + +def get_vocos_mel_spectrogram( + waveform, + n_fft=1024, + n_mel_channels=100, + target_sample_rate=24000, + hop_length=256, + win_length=1024, +): + mel_stft = torchaudio.transforms.MelSpectrogram( + sample_rate=target_sample_rate, + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + n_mels=n_mel_channels, + power=1, + center=True, + normalized=False, + norm=None, + ).to(waveform.device) + if len(waveform.shape) == 3: + waveform = waveform.squeeze(1) # 'b 1 nw -> b nw' + + assert len(waveform.shape) == 2 + + mel = mel_stft(waveform) + mel = mel.clamp(min=1e-5).log() + return mel + + +class MelSpec(nn.Module): + def __init__( + self, + n_fft=1024, + hop_length=256, + win_length=1024, + n_mel_channels=100, + target_sample_rate=24_000, + mel_spec_type="vocos", + ): + super().__init__() + assert mel_spec_type in ["vocos", "bigvgan"], print("We only support two extract mel backend: vocos or bigvgan") + + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + self.n_mel_channels = n_mel_channels + self.target_sample_rate = target_sample_rate + + if mel_spec_type == "vocos": + self.extractor = get_vocos_mel_spectrogram + elif mel_spec_type == "bigvgan": + self.extractor = get_bigvgan_mel_spectrogram + + self.register_buffer("dummy", torch.tensor(0), persistent=False) + + def forward(self, wav): + if self.dummy.device != wav.device: + self.to(wav.device) + + mel = self.extractor( + waveform=wav, + n_fft=self.n_fft, + n_mel_channels=self.n_mel_channels, + target_sample_rate=self.target_sample_rate, + hop_length=self.hop_length, + win_length=self.win_length, + ) + + return mel + + +# sinusoidal position embedding + + +class SinusPositionEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x, scale=1000): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) + emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +# convolutional position embedding + + +class ConvPositionEmbedding(nn.Module): + def __init__(self, dim, kernel_size=31, groups=16): + super().__init__() + assert 
kernel_size % 2 != 0 + self.conv1d = nn.Sequential( + nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2), + nn.Mish(), + nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2), + nn.Mish(), + ) + + def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722 + if mask is not None: + mask = mask[..., None] + x = x.masked_fill(~mask, 0.0) + + x = x.permute(0, 2, 1) + x = self.conv1d(x) + out = x.permute(0, 2, 1) + + if mask is not None: + out = out.masked_fill(~mask, 0.0) + + return out + + +# rotary positional embedding related + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0): + # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning + # has some connection to NTK literature + # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + # https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py + theta *= theta_rescale_factor ** (dim / (dim - 2)) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) # type: ignore + freqs = torch.outer(t, freqs).float() # type: ignore + freqs_cos = torch.cos(freqs) # real part + freqs_sin = torch.sin(freqs) # imaginary part + return torch.cat([freqs_cos, freqs_sin], dim=-1) + + +def get_pos_embed_indices(start, length, max_pos, scale=1.0): + # length = length if isinstance(length, int) else length.max() + scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar + pos = ( + start.unsqueeze(1) + + (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long() + ) + # avoid extra long error. + pos = torch.where(pos < max_pos, pos, max_pos - 1) + return pos + + +# Global Response Normalization layer (Instance Normalization ?) 
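(Editor's aside.) As an illustration of how the two position helpers defined just above are meant to be combined, which is the same pattern the TextEmbedding modules follow, the sketch below precomputes a cos/sin table once and indexes it with per-sample start offsets. The batch size, length, and dimension are assumptions chosen for the example, not values taken from the repository.

def _demo_pos_embed(batch: int = 2, length: int = 10, dim: int = 64, max_pos: int = 4096):
    # Precompute a [max_pos, dim] table of concatenated cos/sin factors, then
    # gather per-token rows with [batch, length] position indices.
    freqs_cis = precompute_freqs_cis(dim, max_pos)                     # [max_pos, dim]
    start = torch.zeros(batch, dtype=torch.long)                       # no per-sample offset
    pos_idx = get_pos_embed_indices(start, length, max_pos=max_pos)    # [batch, length]
    pos_embed = freqs_cis[pos_idx]                                     # [batch, length, dim]
    return pos_embed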
+ + +class GRN(nn.Module): + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, dim)) + self.beta = nn.Parameter(torch.zeros(1, 1, dim)) + + def forward(self, x): + Gx = torch.norm(x, p=2, dim=1, keepdim=True) + Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (x * Nx) + self.beta + x + + +# ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py +# ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108 + + +class ConvNeXtV2Block(nn.Module): + def __init__( + self, + dim: int, + intermediate_dim: int, + dilation: int = 1, + ): + super().__init__() + padding = (dilation * (7 - 1)) // 2 + self.dwconv = nn.Conv1d( + dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation + ) # depthwise conv + self.norm = nn.LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.grn = GRN(intermediate_dim) + self.pwconv2 = nn.Linear(intermediate_dim, dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + x = x.transpose(1, 2) # b n d -> b d n + x = self.dwconv(x) + x = x.transpose(1, 2) # b d n -> b n d + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.grn(x) + x = self.pwconv2(x) + return residual + x + + +# AdaLayerNormZero +# return with modulated x for attn input, and params for later mlp modulation + + +class AdaLayerNormZero(nn.Module): + def __init__(self, dim): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(dim, dim * 6) + + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + + def forward(self, x, emb=None): + emb = self.linear(self.silu(emb)) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1) + + x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] + return x, gate_msa, shift_mlp, scale_mlp, gate_mlp + + +# AdaLayerNormZero for final layer +# return only with modulated x for attn input, cuz no more mlp modulation + + +class AdaLayerNormZero_Final(nn.Module): + def __init__(self, dim): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(dim, dim * 2) + + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + + def forward(self, x, emb): + emb = self.linear(self.silu(emb)) + scale, shift = torch.chunk(emb, 2, dim=1) + + x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] + return x + + +# FeedForward + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + activation = nn.GELU(approximate=approximate) + project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation) + self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) + + def forward(self, x): + return self.ff(x) + + +# Attention with possible joint part +# modified from diffusers/src/diffusers/models/attention_processor.py + + +class Attention(nn.Module): + def __init__( + self, + processor: JointAttnProcessor | AttnProcessor, + dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + context_dim: Optional[int] = None, # if not None -> joint attention + context_pre_only=None, + ): + super().__init__() + + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("Attention equires PyTorch 2.0, 
to use it, please upgrade PyTorch to 2.0.") + + self.processor = processor + + self.dim = dim + self.heads = heads + self.inner_dim = dim_head * heads + self.dropout = dropout + + self.context_dim = context_dim + self.context_pre_only = context_pre_only + + self.to_q = nn.Linear(dim, self.inner_dim) + self.to_k = nn.Linear(dim, self.inner_dim) + self.to_v = nn.Linear(dim, self.inner_dim) + + if self.context_dim is not None: + self.to_k_c = nn.Linear(context_dim, self.inner_dim) + self.to_v_c = nn.Linear(context_dim, self.inner_dim) + if self.context_pre_only is not None: + self.to_q_c = nn.Linear(context_dim, self.inner_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(self.inner_dim, dim)) + self.to_out.append(nn.Dropout(dropout)) + + if self.context_pre_only is not None and not self.context_pre_only: + self.to_out_c = nn.Linear(self.inner_dim, dim) + + def forward( + self, + x: float["b n d"], # noised input x # noqa: F722 + c: float["b n d"] = None, # context c # noqa: F722 + mask: bool["b n"] | None = None, # noqa: F722 + rope=None, # rotary position embedding for x + c_rope=None, # rotary position embedding for c + ) -> torch.Tensor: + if c is not None: + return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope) + else: + return self.processor(self, x, mask=mask, rope=rope) + + +# Attention processor + + +# from torch.nn.attention import SDPBackend +# torch.backends.cuda.enable_flash_sdp(True) +class AttnProcessor: + def __init__(self): + pass + + def __call__( + self, + attn: Attention, + x: float["b n d"], # noised input x # noqa: F722 + mask: bool["b n"] | None = None, # noqa: F722 + rope=None, # rotary position embedding + ) -> torch.FloatTensor: + batch_size = x.shape[0] + + # `sample` projections. + query = attn.to_q(x) + key = attn.to_k(x) + value = attn.to_v(x) + + # apply rotary position embedding + if rope is not None: + freqs, xpos_scale = rope + q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0) + + query = apply_rotary_pos_emb(query, freqs, q_xpos_scale) + key = apply_rotary_pos_emb(key, freqs, k_xpos_scale) + + # attention + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # mask. e.g. 
inference got a batch with different target durations, mask out the padding + if mask is not None: + attn_mask = mask + attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n' + # print(3433333333,attn_mask.shape) + attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2]) + else: + attn_mask = None + # with torch.nn.attention.sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]): + # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=True): + # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False): + # print(torch.backends.cuda.flash_sdp_enabled()) + # print(torch.backends.cuda.mem_efficient_sdp_enabled()) + # print(torch.backends.cuda.math_sdp_enabled()) + x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False) + x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + x = x.to(query.dtype) + + # linear proj + x = attn.to_out[0](x) + # dropout + x = attn.to_out[1](x) + + if mask is not None: + mask = mask.unsqueeze(-1) + x = x.masked_fill(~mask, 0.0) + + return x + + +# Joint Attention processor for MM-DiT +# modified from diffusers/src/diffusers/models/attention_processor.py + + +class JointAttnProcessor: + def __init__(self): + pass + + def __call__( + self, + attn: Attention, + x: float["b n d"], # noised input x # noqa: F722 + c: float["b nt d"] = None, # context c, here text # noqa: F722 + mask: bool["b n"] | None = None, # noqa: F722 + rope=None, # rotary position embedding for x + c_rope=None, # rotary position embedding for c + ) -> torch.FloatTensor: + residual = x + + batch_size = c.shape[0] + + # `sample` projections. + query = attn.to_q(x) + key = attn.to_k(x) + value = attn.to_v(x) + + # `context` projections. + c_query = attn.to_q_c(c) + c_key = attn.to_k_c(c) + c_value = attn.to_v_c(c) + + # apply rope for context and noised input independently + if rope is not None: + freqs, xpos_scale = rope + q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0) + query = apply_rotary_pos_emb(query, freqs, q_xpos_scale) + key = apply_rotary_pos_emb(key, freqs, k_xpos_scale) + if c_rope is not None: + freqs, xpos_scale = c_rope + q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0) + c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale) + c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale) + + # attention + query = torch.cat([query, c_query], dim=1) + key = torch.cat([key, c_key], dim=1) + value = torch.cat([value, c_value], dim=1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # mask. e.g. 
inference got a batch with different target durations, mask out the padding + if mask is not None: + attn_mask = F.pad(mask, (0, c.shape[1]), value=True) # no mask for c (text) + attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n' + attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2]) + else: + attn_mask = None + + x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False) + x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + x = x.to(query.dtype) + + # Split the attention outputs. + x, c = ( + x[:, : residual.shape[1]], + x[:, residual.shape[1] :], + ) + + # linear proj + x = attn.to_out[0](x) + # dropout + x = attn.to_out[1](x) + if not attn.context_pre_only: + c = attn.to_out_c(c) + + if mask is not None: + mask = mask.unsqueeze(-1) + x = x.masked_fill(~mask, 0.0) + # c = c.masked_fill(~mask, 0.) # no mask for c (text) + + return x, c + + +# DiT Block + + +class DiTBlock(nn.Module): + def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1): + super().__init__() + + self.attn_norm = AdaLayerNormZero(dim) + self.attn = Attention( + processor=AttnProcessor(), + dim=dim, + heads=heads, + dim_head=dim_head, + dropout=dropout, + ) + + self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh") + + def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding + # pre-norm & modulation for attention input + norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t) + + # attention + attn_output = self.attn(x=norm, mask=mask, rope=rope) + + # process attention output for input x + x = x + gate_msa.unsqueeze(1) * attn_output + + norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ff_output = self.ff(norm) + x = x + gate_mlp.unsqueeze(1) * ff_output + + return x + + +# MMDiT Block https://arxiv.org/abs/2403.03206 + + +class MMDiTBlock(nn.Module): + r""" + modified from diffusers/src/diffusers/models/attention.py + + notes. + _c: context related. text, cond, etc. (left part in sd3 fig2.b) + _x: noised input related. 
(right part) + context_pre_only: last layer only do prenorm + modulation cuz no more ffn + """ + + def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False): + super().__init__() + + self.context_pre_only = context_pre_only + + self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim) + self.attn_norm_x = AdaLayerNormZero(dim) + self.attn = Attention( + processor=JointAttnProcessor(), + dim=dim, + heads=heads, + dim_head=dim_head, + dropout=dropout, + context_dim=dim, + context_pre_only=context_pre_only, + ) + + if not context_pre_only: + self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_c = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh") + else: + self.ff_norm_c = None + self.ff_c = None + self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_x = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh") + + def forward(self, x, c, t, mask=None, rope=None, c_rope=None): # x: noised input, c: context, t: time embedding + # pre-norm & modulation for attention input + if self.context_pre_only: + norm_c = self.attn_norm_c(c, t) + else: + norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t) + norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t) + + # attention + x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope) + + # process attention output for context c + if self.context_pre_only: + c = None + else: # if not last layer + c = c + c_gate_msa.unsqueeze(1) * c_attn_output + + norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + c_ff_output = self.ff_c(norm_c) + c = c + c_gate_mlp.unsqueeze(1) * c_ff_output + + # process attention output for input x + x = x + x_gate_msa.unsqueeze(1) * x_attn_output + + norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None] + x_ff_output = self.ff_x(norm_x) + x = x + x_gate_mlp.unsqueeze(1) * x_ff_output + + return c, x + + +# time step conditioning embedding + + +class TimestepEmbedding(nn.Module): + def __init__(self, dim, freq_embed_dim=256): + super().__init__() + self.time_embed = SinusPositionEmbedding(freq_embed_dim) + self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) + + def forward(self, timestep: float["b"]): # noqa: F821 + time_hidden = self.time_embed(timestep) + time_hidden = time_hidden.to(timestep.dtype) + time = self.time_mlp(time_hidden) # b d + return time diff --git a/feature_extractor/__init__.py b/feature_extractor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..01ef5ddf13c867f215ddef009537bf713c10c717 --- /dev/null +++ b/feature_extractor/__init__.py @@ -0,0 +1,3 @@ +from . 
import cnhubert, whisper_enc + +content_module_map = {"cnhubert": cnhubert, "whisper": whisper_enc} diff --git a/feature_extractor/cnhubert.py b/feature_extractor/cnhubert.py new file mode 100644 index 0000000000000000000000000000000000000000..f22b8d09b7f9e8931011c9544c71b9668de3369d --- /dev/null +++ b/feature_extractor/cnhubert.py @@ -0,0 +1,106 @@ +import torch +import os +from transformers import logging as tf_logging + +tf_logging.set_verbosity_error() + +import logging + +logging.getLogger("numba").setLevel(logging.WARNING) + +from transformers import ( + Wav2Vec2FeatureExtractor, + HubertModel, +) + +import utils +import torch.nn as nn + +cnhubert_base_path = None + + +class CNHubert(nn.Module): + def __init__(self, base_path: str = None): + super().__init__() + if base_path is None: + base_path = cnhubert_base_path + if os.path.exists(base_path): + ... + else: + raise FileNotFoundError(base_path) + self.model = HubertModel.from_pretrained(base_path, local_files_only=True) + self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(base_path, local_files_only=True) + + def forward(self, x): + input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device) + feats = self.model(input_values)["last_hidden_state"] + return feats + + +# class CNHubertLarge(nn.Module): +# def __init__(self): +# super().__init__() +# self.model = HubertModel.from_pretrained("/data/docker/liujing04/gpt-vits/chinese-hubert-large") +# self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("/data/docker/liujing04/gpt-vits/chinese-hubert-large") +# def forward(self, x): +# input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device) +# feats = self.model(input_values)["last_hidden_state"] +# return feats +# +# class CVec(nn.Module): +# def __init__(self): +# super().__init__() +# self.model = HubertModel.from_pretrained("/data/docker/liujing04/vc-webui-big/hubert_base") +# self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("/data/docker/liujing04/vc-webui-big/hubert_base") +# def forward(self, x): +# input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device) +# feats = self.model(input_values)["last_hidden_state"] +# return feats +# +# class cnw2v2base(nn.Module): +# def __init__(self): +# super().__init__() +# self.model = Wav2Vec2Model.from_pretrained("/data/docker/liujing04/gpt-vits/chinese-wav2vec2-base") +# self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("/data/docker/liujing04/gpt-vits/chinese-wav2vec2-base") +# def forward(self, x): +# input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device) +# feats = self.model(input_values)["last_hidden_state"] +# return feats + + +def get_model(): + model = CNHubert() + model.eval() + return model + + +# def get_large_model(): +# model = CNHubertLarge() +# model.eval() +# return model +# +# def get_model_cvec(): +# model = CVec() +# model.eval() +# return model +# +# def get_model_cnw2v2base(): +# model = cnw2v2base() +# model.eval() +# return model + + +def get_content(hmodel, wav_16k_tensor): + with torch.no_grad(): + feats = hmodel(wav_16k_tensor) + return feats.transpose(1, 2) + + +if __name__ == "__main__": + model = get_model() + src_path = "/Users/Shared/原音频2.wav" + wav_16k_tensor = utils.load_wav_to_torch_and_resample(src_path, 16000) + model = model + wav_16k_tensor = wav_16k_tensor + feats = get_content(model, 
wav_16k_tensor) + print(feats.shape) diff --git a/feature_extractor/whisper_enc.py b/feature_extractor/whisper_enc.py new file mode 100644 index 0000000000000000000000000000000000000000..260539bc9fbd8a314cc4949d0f746c09e64f4073 --- /dev/null +++ b/feature_extractor/whisper_enc.py @@ -0,0 +1,23 @@ +import torch + + +def get_model(): + import whisper + + model = whisper.load_model("small", device="cpu") + + return model.encoder + + +def get_content(model=None, wav_16k_tensor=None): + from whisper import log_mel_spectrogram, pad_or_trim + + dev = next(model.parameters()).device + mel = log_mel_spectrogram(wav_16k_tensor).to(dev)[:, :3000] + # if torch.cuda.is_available(): + # mel = mel.to(torch.float16) + feature_len = mel.shape[-1] // 2 + assert mel.shape[-1] < 3000, "输入音频过长,只允许输入30以内音频" + with torch.no_grad(): + feature = model(pad_or_trim(mel, 3000).unsqueeze(0))[:1, :feature_len, :].transpose(1, 2) + return feature diff --git a/inference_cli.py b/inference_cli.py new file mode 100644 index 0000000000000000000000000000000000000000..571e5279060c26412d9f77b10cb81beaa8ee2b99 --- /dev/null +++ b/inference_cli.py @@ -0,0 +1,85 @@ +import argparse +import os +import soundfile as sf + +from tools.i18n.i18n import I18nAuto +from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav +import hashlib + +i18n = I18nAuto() +def get_tts_cache_key(model_name, text, prompt_audio_path): + """ + 生成 TTS 缓存 key: md5(模型+文本+md5(prompt音频内容)) + :param model_name: str + :param text: str + :param prompt_audio_path: str or None + :return: str (md5 hash) + """ + prompt_md5 = '' + if prompt_audio_path and os.path.exists(prompt_audio_path): + with open(prompt_audio_path, 'rb') as f: + prompt_content = f.read() + prompt_md5 = hashlib.md5(prompt_content).hexdigest() + key_str = f"{model_name}::{text}::{prompt_md5}" + return hashlib.md5(key_str.encode('utf-8')).hexdigest() + +def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text, ref_language, target_text, target_language, output_path): + + # Change model weights + change_gpt_weights(gpt_path=GPT_model_path) + change_sovits_weights(sovits_path=SoVITS_model_path) + + # Synthesize audio + synthesis_result = get_tts_wav( + ref_wav_path=ref_audio_path, + prompt_text=ref_text, + prompt_language=i18n(ref_language), + text=target_text, + text_language=i18n(target_language), + top_p=1, + temperature=1, + ) + + result_list = list(synthesis_result) + + if result_list: + last_sampling_rate, last_audio_data = result_list[-1] + output_wav_path = os.path.join(output_path, "output.wav") + sf.write(output_wav_path, last_audio_data, last_sampling_rate) + print(f"Audio saved to {output_wav_path}") + + +def main(): + parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool") + parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file") + parser.add_argument("--sovits_model", required=True, help="Path to the SoVITS model file") + parser.add_argument("--ref_audio", required=True, help="Path to the reference audio file") + parser.add_argument("--ref_text", required=True, help="Path to the reference text file") + parser.add_argument( + "--ref_language", required=True, choices=["中文", "英文", "日文"], help="Language of the reference audio" + ) + parser.add_argument("--target_text", required=True, help="Path to the target text file") + parser.add_argument( + "--target_language", + required=True, + choices=["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"], + help="Language of the target text", + ) + 
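(Editor's aside.) For reference, the synthesize() helper defined above can also be called directly from Python; the model and audio paths below are hypothetical placeholders, and the language strings come from the choices accepted by this CLI.

synthesize(
    GPT_model_path="pretrained/gpt.ckpt",        # hypothetical path
    SoVITS_model_path="pretrained/sovits.pth",   # hypothetical path
    ref_audio_path="ref.wav",
    ref_text="参考音频对应的文本",
    ref_language="中文",
    target_text="需要合成的目标文本",
    target_language="中英混合",
    output_path="./output",
)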
parser.add_argument("--output_path", required=True, help="Path to the output directory") + + args = parser.parse_args() + + synthesize( + args.gpt_model, + args.sovits_model, + args.ref_audio, + args.ref_text, + args.ref_language, + args.target_text, + args.target_language, + args.output_path, + ) + + +if __name__ == "__main__": + main() diff --git a/inference_gui.py b/inference_gui.py new file mode 100644 index 0000000000000000000000000000000000000000..379f7fa8cdb32b4b56db8b242717c23bdb51eca0 --- /dev/null +++ b/inference_gui.py @@ -0,0 +1,316 @@ +import os +import sys +from PyQt5.QtCore import QEvent +from PyQt5.QtWidgets import QApplication, QMainWindow, QLabel, QLineEdit, QPushButton, QTextEdit +from PyQt5.QtWidgets import QGridLayout, QVBoxLayout, QWidget, QFileDialog, QStatusBar, QComboBox +import soundfile as sf + +from tools.i18n.i18n import I18nAuto + +i18n = I18nAuto() + +from inference_webui import gpt_path, sovits_path, change_gpt_weights, change_sovits_weights, get_tts_wav + + +class GPTSoVITSGUI(QMainWindow): + GPT_Path = gpt_path + SoVITS_Path = sovits_path + + def __init__(self): + super().__init__() + + self.setWindowTitle("GPT-SoVITS GUI") + self.setGeometry(800, 450, 950, 850) + + self.setStyleSheet(""" + QWidget { + background-color: #a3d3b1; + } + + QTabWidget::pane { + background-color: #a3d3b1; + } + + QTabWidget::tab-bar { + alignment: left; + } + + QTabBar::tab { + background: #8da4bf; + color: #ffffff; + padding: 8px; + } + + QTabBar::tab:selected { + background: #2a3f54; + } + + QLabel { + color: #000000; + } + + QPushButton { + background-color: #4CAF50; + color: white; + padding: 8px; + border: 1px solid #4CAF50; + border-radius: 4px; + } + + QPushButton:hover { + background-color: #45a049; + border: 1px solid #45a049; + box-shadow: 2px 2px 2px rgba(0, 0, 0, 0.1); + } + """) + + license_text = ( + "本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. " + "如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE." 
+ ) + license_label = QLabel(license_text) + license_label.setWordWrap(True) + + self.GPT_model_label = QLabel("选择GPT模型:") + self.GPT_model_input = QLineEdit() + self.GPT_model_input.setPlaceholderText("拖拽或选择文件") + self.GPT_model_input.setText(self.GPT_Path) + self.GPT_model_input.setReadOnly(True) + self.GPT_model_button = QPushButton("选择GPT模型文件") + self.GPT_model_button.clicked.connect(self.select_GPT_model) + + self.SoVITS_model_label = QLabel("选择SoVITS模型:") + self.SoVITS_model_input = QLineEdit() + self.SoVITS_model_input.setPlaceholderText("拖拽或选择文件") + self.SoVITS_model_input.setText(self.SoVITS_Path) + self.SoVITS_model_input.setReadOnly(True) + self.SoVITS_model_button = QPushButton("选择SoVITS模型文件") + self.SoVITS_model_button.clicked.connect(self.select_SoVITS_model) + + self.ref_audio_label = QLabel("上传参考音频:") + self.ref_audio_input = QLineEdit() + self.ref_audio_input.setPlaceholderText("拖拽或选择文件") + self.ref_audio_input.setReadOnly(True) + self.ref_audio_button = QPushButton("选择音频文件") + self.ref_audio_button.clicked.connect(self.select_ref_audio) + + self.ref_text_label = QLabel("参考音频文本:") + self.ref_text_input = QLineEdit() + self.ref_text_input.setPlaceholderText("直接输入文字或上传文本") + self.ref_text_button = QPushButton("上传文本") + self.ref_text_button.clicked.connect(self.upload_ref_text) + + self.ref_language_label = QLabel("参考音频语言:") + self.ref_language_combobox = QComboBox() + self.ref_language_combobox.addItems(["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"]) + self.ref_language_combobox.setCurrentText("多语种混合") + + self.target_text_label = QLabel("合成目标文本:") + self.target_text_input = QLineEdit() + self.target_text_input.setPlaceholderText("直接输入文字或上传文本") + self.target_text_button = QPushButton("上传文本") + self.target_text_button.clicked.connect(self.upload_target_text) + + self.target_language_label = QLabel("合成音频语言:") + self.target_language_combobox = QComboBox() + self.target_language_combobox.addItems(["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"]) + self.target_language_combobox.setCurrentText("多语种混合") + + self.output_label = QLabel("输出音频路径:") + self.output_input = QLineEdit() + self.output_input.setPlaceholderText("拖拽或选择文件") + self.output_input.setReadOnly(True) + self.output_button = QPushButton("选择文件夹") + self.output_button.clicked.connect(self.select_output_path) + + self.output_text = QTextEdit() + self.output_text.setReadOnly(True) + + self.add_drag_drop_events( + [ + self.GPT_model_input, + self.SoVITS_model_input, + self.ref_audio_input, + self.ref_text_input, + self.target_text_input, + self.output_input, + ] + ) + + self.synthesize_button = QPushButton("合成") + self.synthesize_button.clicked.connect(self.synthesize) + + self.clear_output_button = QPushButton("清空输出") + self.clear_output_button.clicked.connect(self.clear_output) + + self.status_bar = QStatusBar() + + main_layout = QVBoxLayout() + + input_layout = QGridLayout(self) + input_layout.setSpacing(10) + + input_layout.addWidget(license_label, 0, 0, 1, 3) + + input_layout.addWidget(self.GPT_model_label, 1, 0) + input_layout.addWidget(self.GPT_model_input, 2, 0, 1, 2) + input_layout.addWidget(self.GPT_model_button, 2, 2) + + input_layout.addWidget(self.SoVITS_model_label, 3, 0) + input_layout.addWidget(self.SoVITS_model_input, 4, 0, 1, 2) + input_layout.addWidget(self.SoVITS_model_button, 4, 2) + + input_layout.addWidget(self.ref_audio_label, 5, 0) + input_layout.addWidget(self.ref_audio_input, 6, 0, 1, 2) + input_layout.addWidget(self.ref_audio_button, 6, 2) + + input_layout.addWidget(self.ref_language_label, 7, 0) + 
input_layout.addWidget(self.ref_language_combobox, 8, 0, 1, 1) + input_layout.addWidget(self.ref_text_label, 9, 0) + input_layout.addWidget(self.ref_text_input, 10, 0, 1, 2) + input_layout.addWidget(self.ref_text_button, 10, 2) + + input_layout.addWidget(self.target_language_label, 11, 0) + input_layout.addWidget(self.target_language_combobox, 12, 0, 1, 1) + input_layout.addWidget(self.target_text_label, 13, 0) + input_layout.addWidget(self.target_text_input, 14, 0, 1, 2) + input_layout.addWidget(self.target_text_button, 14, 2) + + input_layout.addWidget(self.output_label, 15, 0) + input_layout.addWidget(self.output_input, 16, 0, 1, 2) + input_layout.addWidget(self.output_button, 16, 2) + + main_layout.addLayout(input_layout) + + output_layout = QVBoxLayout() + output_layout.addWidget(self.output_text) + main_layout.addLayout(output_layout) + + main_layout.addWidget(self.synthesize_button) + + main_layout.addWidget(self.clear_output_button) + + main_layout.addWidget(self.status_bar) + + self.central_widget = QWidget() + self.central_widget.setLayout(main_layout) + self.setCentralWidget(self.central_widget) + + def dragEnterEvent(self, event): + if event.mimeData().hasUrls(): + event.acceptProposedAction() + + def dropEvent(self, event): + if event.mimeData().hasUrls(): + file_paths = [url.toLocalFile() for url in event.mimeData().urls()] + if len(file_paths) == 1: + self.update_ref_audio(file_paths[0]) + else: + self.update_ref_audio(", ".join(file_paths)) + + def add_drag_drop_events(self, widgets): + for widget in widgets: + widget.setAcceptDrops(True) + widget.installEventFilter(self) + + def eventFilter(self, obj, event): + if event.type() in (QEvent.DragEnter, QEvent.Drop): + mime_data = event.mimeData() + if mime_data.hasUrls(): + event.acceptProposedAction() + + return super().eventFilter(obj, event) + + def select_GPT_model(self): + file_path, _ = QFileDialog.getOpenFileName(self, "选择GPT模型文件", "", "GPT Files (*.ckpt)") + if file_path: + self.GPT_model_input.setText(file_path) + + def select_SoVITS_model(self): + file_path, _ = QFileDialog.getOpenFileName(self, "选择SoVITS模型文件", "", "SoVITS Files (*.pth)") + if file_path: + self.SoVITS_model_input.setText(file_path) + + def select_ref_audio(self): + file_path, _ = QFileDialog.getOpenFileName(self, "选择参考音频文件", "", "Audio Files (*.wav *.mp3)") + if file_path: + self.update_ref_audio(file_path) + + def upload_ref_text(self): + file_path, _ = QFileDialog.getOpenFileName(self, "选择文本文件", "", "Text Files (*.txt)") + if file_path: + with open(file_path, "r", encoding="utf-8") as file: + content = file.read() + self.ref_text_input.setText(content) + + def upload_target_text(self): + file_path, _ = QFileDialog.getOpenFileName(self, "选择文本文件", "", "Text Files (*.txt)") + if file_path: + with open(file_path, "r", encoding="utf-8") as file: + content = file.read() + self.target_text_input.setText(content) + + def select_output_path(self): + options = QFileDialog.Options() + options |= QFileDialog.DontUseNativeDialog + options |= QFileDialog.ShowDirsOnly + + folder_dialog = QFileDialog() + folder_dialog.setOptions(options) + folder_dialog.setFileMode(QFileDialog.Directory) + + if folder_dialog.exec_(): + folder_path = folder_dialog.selectedFiles()[0] + self.output_input.setText(folder_path) + + def update_ref_audio(self, file_path): + self.ref_audio_input.setText(file_path) + + def clear_output(self): + self.output_text.clear() + + def synthesize(self): + GPT_model_path = self.GPT_model_input.text() + SoVITS_model_path = 
self.SoVITS_model_input.text() + ref_audio_path = self.ref_audio_input.text() + language_combobox = self.ref_language_combobox.currentText() + language_combobox = i18n(language_combobox) + ref_text = self.ref_text_input.text() + target_language_combobox = self.target_language_combobox.currentText() + target_language_combobox = i18n(target_language_combobox) + target_text = self.target_text_input.text() + output_path = self.output_input.text() + + if GPT_model_path != self.GPT_Path: + change_gpt_weights(gpt_path=GPT_model_path) + self.GPT_Path = GPT_model_path + if SoVITS_model_path != self.SoVITS_Path: + change_sovits_weights(sovits_path=SoVITS_model_path) + self.SoVITS_Path = SoVITS_model_path + + synthesis_result = get_tts_wav( + ref_wav_path=ref_audio_path, + prompt_text=ref_text, + prompt_language=language_combobox, + text=target_text, + text_language=target_language_combobox, + ) + + result_list = list(synthesis_result) + + if result_list: + last_sampling_rate, last_audio_data = result_list[-1] + output_wav_path = os.path.join(output_path, "output.wav") + sf.write(output_wav_path, last_audio_data, last_sampling_rate) + + result = "Audio saved to " + output_wav_path + + self.status_bar.showMessage("合成完成!输出路径:" + output_wav_path, 5000) + self.output_text.append("处理结果:\n" + result) + + +if __name__ == "__main__": + app = QApplication(sys.argv) + mainWin = GPTSoVITSGUI() + mainWin.show() + sys.exit(app.exec_()) diff --git a/inference_webui.py b/inference_webui.py new file mode 100644 index 0000000000000000000000000000000000000000..46019ffe796195c4779eedae9582649a69d93ee6 --- /dev/null +++ b/inference_webui.py @@ -0,0 +1,1342 @@ +""" +按中英混合识别 +按日英混合识别 +多语种启动切分识别语种 +全部按中文识别 +全部按英文识别 +全部按日文识别 +""" + +import json +import logging +import os +import re +import sys +import traceback +import warnings + +import torch +import torchaudio +from text.LangSegmenter import LangSegmenter + +logging.getLogger("markdown_it").setLevel(logging.ERROR) +logging.getLogger("urllib3").setLevel(logging.ERROR) +logging.getLogger("httpcore").setLevel(logging.ERROR) +logging.getLogger("httpx").setLevel(logging.ERROR) +logging.getLogger("asyncio").setLevel(logging.ERROR) +logging.getLogger("charset_normalizer").setLevel(logging.ERROR) +logging.getLogger("torchaudio._extension").setLevel(logging.ERROR) +logging.getLogger("multipart.multipart").setLevel(logging.ERROR) +warnings.simplefilter(action="ignore", category=FutureWarning) + +version = model_version = os.environ.get("version", "v2") + +from config import change_choices, get_weights_names, name2gpt_path, name2sovits_path + +SoVITS_names, GPT_names = get_weights_names() +from config import pretrained_sovits_name + +path_sovits_v3 = pretrained_sovits_name["v3"] +path_sovits_v4 = pretrained_sovits_name["v4"] +is_exist_s2gv3 = os.path.exists(path_sovits_v3) +is_exist_s2gv4 = os.path.exists(path_sovits_v4) + +if os.path.exists("./weight.json"): + pass +else: + with open("./weight.json", "w", encoding="utf-8") as file: + json.dump({"GPT": {}, "SoVITS": {}}, file) + +with open("./weight.json", "r", encoding="utf-8") as file: + weight_data = file.read() + weight_data = json.loads(weight_data) + gpt_path = os.environ.get("gpt_path", weight_data.get("GPT", {}).get(version, GPT_names[-1])) + sovits_path = os.environ.get("sovits_path", weight_data.get("SoVITS", {}).get(version, SoVITS_names[0])) + if isinstance(gpt_path, list): + gpt_path = gpt_path[0] + if isinstance(sovits_path, list): + sovits_path = sovits_path[0] + +# print(2333333) +# print(os.environ["gpt_path"]) +# 
print(gpt_path) +# print(GPT_names) +# print(weight_data) +# print(weight_data.get("GPT", {})) +# print(version)###GPT version里没有s2的v2pro +# print(weight_data.get("GPT", {}).get(version, GPT_names[-1])) +from huggingface_hub import snapshot_download +snapshot_download(repo_id="lj1995/GPT-SoVITS",local_dir="pretrained_models",repo_type="model") +cnhubert_base_path = os.environ.get("cnhubert_base_path", "pretrained_models/chinese-hubert-base") +bert_path = os.environ.get("bert_path", "pretrained_models/chinese-roberta-wwm-ext-large") +infer_ttswebui = os.environ.get("infer_ttswebui", 9872) +infer_ttswebui = int(infer_ttswebui) +is_share = os.environ.get("is_share", "False") +is_share = eval(is_share) +if "_CUDA_VISIBLE_DEVICES" in os.environ: + os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"] +is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available() +# is_half=False +punctuation = set(["!", "?", "…", ",", ".", "-", " "]) +import gradio as gr +import librosa +import numpy as np +from feature_extractor import cnhubert +from transformers import AutoModelForMaskedLM, AutoTokenizer + +cnhubert.cnhubert_base_path = cnhubert_base_path + +import random + +from GPT_SoVITS.module.models import Generator, SynthesizerTrn, SynthesizerTrnV3 + + +def set_seed(seed): + if seed == -1: + seed = random.randint(0, 1000000) + seed = int(seed) + random.seed(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + +# set_seed(42) + +from time import time as ttime + +from AR.models.t2s_lightning_module import Text2SemanticLightningModule +from peft import LoraConfig, get_peft_model +from text import cleaned_text_to_sequence +from text.cleaner import clean_text + +from tools.assets import css, js, top_html +from tools.i18n.i18n import I18nAuto, scan_language_list + +language = os.environ.get("language", "Auto") +language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language +i18n = I18nAuto(language=language) + +# os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 确保直接启动推理UI时也能够设置。 + +if torch.cuda.is_available(): + device = "cuda" +else: + device = "cpu" + +dict_language_v1 = { + i18n("中文"): "all_zh", # 全部按中文识别 + i18n("英文"): "en", # 全部按英文识别#######不变 + i18n("日文"): "all_ja", # 全部按日文识别 + i18n("中英混合"): "zh", # 按中英混合识别####不变 + i18n("日英混合"): "ja", # 按日英混合识别####不变 + i18n("多语种混合"): "auto", # 多语种启动切分识别语种 +} +dict_language_v2 = { + i18n("中文"): "all_zh", # 全部按中文识别 + i18n("英文"): "en", # 全部按英文识别#######不变 + i18n("日文"): "all_ja", # 全部按日文识别 + i18n("粤语"): "all_yue", # 全部按中文识别 + i18n("韩文"): "all_ko", # 全部按韩文识别 + i18n("中英混合"): "zh", # 按中英混合识别####不变 + i18n("日英混合"): "ja", # 按日英混合识别####不变 + i18n("粤英混合"): "yue", # 按粤英混合识别####不变 + i18n("韩英混合"): "ko", # 按韩英混合识别####不变 + i18n("多语种混合"): "auto", # 多语种启动切分识别语种 + i18n("多语种混合(粤语)"): "auto_yue", # 多语种启动切分识别语种 +} +dict_language = dict_language_v1 if version == "v1" else dict_language_v2 + +tokenizer = AutoTokenizer.from_pretrained(bert_path) +bert_model = AutoModelForMaskedLM.from_pretrained(bert_path) +if is_half == True: + bert_model = bert_model.half().to(device) +else: + bert_model = bert_model.to(device) + + +def get_bert_feature(text, word2ph): + with torch.no_grad(): + inputs = tokenizer(text, return_tensors="pt") + for i in inputs: + inputs[i] = inputs[i].to(device) + res = bert_model(**inputs, output_hidden_states=True) + res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1] + assert len(word2ph) == len(text) + phone_level_feature = [] + for i 
in range(len(word2ph)): + repeat_feature = res[i].repeat(word2ph[i], 1) + phone_level_feature.append(repeat_feature) + phone_level_feature = torch.cat(phone_level_feature, dim=0) + return phone_level_feature.T + + +class DictToAttrRecursive(dict): + def __init__(self, input_dict): + super().__init__(input_dict) + for key, value in input_dict.items(): + if isinstance(value, dict): + value = DictToAttrRecursive(value) + self[key] = value + setattr(self, key, value) + + def __getattr__(self, item): + try: + return self[item] + except KeyError: + raise AttributeError(f"Attribute {item} not found") + + def __setattr__(self, key, value): + if isinstance(value, dict): + value = DictToAttrRecursive(value) + super(DictToAttrRecursive, self).__setitem__(key, value) + super().__setattr__(key, value) + + def __delattr__(self, item): + try: + del self[item] + except KeyError: + raise AttributeError(f"Attribute {item} not found") + + +ssl_model = cnhubert.get_model() +if is_half == True: + ssl_model = ssl_model.half().to(device) +else: + ssl_model = ssl_model.to(device) + + +###todo:put them to process_ckpt and modify my_save func (save sovits weights), gpt save weights use my_save in process_ckpt +# symbol_version-model_version-if_lora_v3 +from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new + +v3v4set = {"v3", "v4"} + + +def change_sovits_weights(sovits_path, prompt_language=None, text_language=None): + if "!" in sovits_path or "!" in sovits_path: + sovits_path = name2sovits_path[sovits_path] + global vq_model, hps, version, model_version, dict_language, if_lora_v3 + version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path) + print(sovits_path, version, model_version, if_lora_v3) + is_exist = is_exist_s2gv3 if model_version == "v3" else is_exist_s2gv4 + path_sovits = path_sovits_v3 if model_version == "v3" else path_sovits_v4 + if if_lora_v3 == True and is_exist == False: + info = path_sovits + "SoVITS %s" % model_version + i18n("底模缺失,无法加载相应 LoRA 权重") + gr.Warning(info) + raise FileExistsError(info) + dict_language = dict_language_v1 if version == "v1" else dict_language_v2 + if prompt_language is not None and text_language is not None: + if prompt_language in list(dict_language.keys()): + prompt_text_update, prompt_language_update = ( + {"__type__": "update"}, + {"__type__": "update", "value": prompt_language}, + ) + else: + prompt_text_update = {"__type__": "update", "value": ""} + prompt_language_update = {"__type__": "update", "value": i18n("中文")} + if text_language in list(dict_language.keys()): + text_update, text_language_update = {"__type__": "update"}, {"__type__": "update", "value": text_language} + else: + text_update = {"__type__": "update", "value": ""} + text_language_update = {"__type__": "update", "value": i18n("中文")} + if model_version in v3v4set: + visible_sample_steps = True + visible_inp_refs = False + else: + visible_sample_steps = False + visible_inp_refs = True + yield ( + {"__type__": "update", "choices": list(dict_language.keys())}, + {"__type__": "update", "choices": list(dict_language.keys())}, + prompt_text_update, + prompt_language_update, + text_update, + text_language_update, + { + "__type__": "update", + "visible": visible_sample_steps, + "value": 32 if model_version == "v3" else 8, + "choices": [4, 8, 16, 32, 64, 128] if model_version == "v3" else [4, 8, 16, 32], + }, + {"__type__": "update", "visible": visible_inp_refs}, + {"__type__": "update", "value": False, "interactive": True if model_version not in v3v4set else 
False}, + {"__type__": "update", "visible": True if model_version == "v3" else False}, + {"__type__": "update", "value": i18n("模型加载中,请等待"), "interactive": False}, + ) + + dict_s2 = load_sovits_new(sovits_path) + hps = dict_s2["config"] + hps = DictToAttrRecursive(hps) + hps.model.semantic_frame_rate = "25hz" + if "enc_p.text_embedding.weight" not in dict_s2["weight"]: + hps.model.version = "v2" # v3model,v2sybomls + elif dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322: + hps.model.version = "v1" + else: + hps.model.version = "v2" + version = hps.model.version + # print("sovits版本:",hps.model.version) + if model_version not in v3v4set: + if "Pro" not in model_version: + model_version = version + else: + hps.model.version = model_version + vq_model = SynthesizerTrn( + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + n_speakers=hps.data.n_speakers, + **hps.model, + ) + else: + hps.model.version = model_version + vq_model = SynthesizerTrnV3( + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + n_speakers=hps.data.n_speakers, + **hps.model, + ) + if "pretrained" not in sovits_path: + try: + del vq_model.enc_q + except: + pass + if is_half == True: + vq_model = vq_model.half().to(device) + else: + vq_model = vq_model.to(device) + vq_model.eval() + if if_lora_v3 == False: + print("loading sovits_%s" % model_version, vq_model.load_state_dict(dict_s2["weight"], strict=False)) + else: + path_sovits = path_sovits_v3 if model_version == "v3" else path_sovits_v4 + print( + "loading sovits_%spretrained_G" % model_version, + vq_model.load_state_dict(load_sovits_new(path_sovits)["weight"], strict=False), + ) + lora_rank = dict_s2["lora_rank"] + lora_config = LoraConfig( + target_modules=["to_k", "to_q", "to_v", "to_out.0"], + r=lora_rank, + lora_alpha=lora_rank, + init_lora_weights=True, + ) + vq_model.cfm = get_peft_model(vq_model.cfm, lora_config) + print("loading sovits_%s_lora%s" % (model_version, lora_rank)) + vq_model.load_state_dict(dict_s2["weight"], strict=False) + vq_model.cfm = vq_model.cfm.merge_and_unload() + # torch.save(vq_model.state_dict(),"merge_win.pth") + vq_model.eval() + + yield ( + {"__type__": "update", "choices": list(dict_language.keys())}, + {"__type__": "update", "choices": list(dict_language.keys())}, + prompt_text_update, + prompt_language_update, + text_update, + text_language_update, + { + "__type__": "update", + "visible": visible_sample_steps, + "value": 32 if model_version == "v3" else 8, + "choices": [4, 8, 16, 32, 64, 128] if model_version == "v3" else [4, 8, 16, 32], + }, + {"__type__": "update", "visible": visible_inp_refs}, + {"__type__": "update", "value": False, "interactive": True if model_version not in v3v4set else False}, + {"__type__": "update", "visible": True if model_version == "v3" else False}, + {"__type__": "update", "value": i18n("合成语音"), "interactive": True}, + ) + with open("./weight.json") as f: + data = f.read() + data = json.loads(data) + data["SoVITS"][version] = sovits_path + with open("./weight.json", "w") as f: + f.write(json.dumps(data)) + + +try: + next(change_sovits_weights(sovits_path)) +except: + pass + + +def change_gpt_weights(gpt_path): + if "!" in gpt_path or "!" 
in gpt_path: + gpt_path = name2gpt_path[gpt_path] + global hz, max_sec, t2s_model, config + hz = 50 + dict_s1 = torch.load(gpt_path, map_location="cpu", weights_only=False) + config = dict_s1["config"] + max_sec = config["data"]["max_sec"] + t2s_model = Text2SemanticLightningModule(config, "****", is_train=False) + t2s_model.load_state_dict(dict_s1["weight"]) + if is_half == True: + t2s_model = t2s_model.half() + t2s_model = t2s_model.to(device) + t2s_model.eval() + # total = sum([param.nelement() for param in t2s_model.parameters()]) + # print("Number of parameter: %.2fM" % (total / 1e6)) + with open("./weight.json") as f: + data = f.read() + data = json.loads(data) + data["GPT"][version] = gpt_path + with open("./weight.json", "w") as f: + f.write(json.dumps(data)) + + +change_gpt_weights(gpt_path) +os.environ["HF_ENDPOINT"] = "https://hf-mirror.com" +import torch + +now_dir = os.getcwd() + + +def clean_hifigan_model(): + global hifigan_model + if hifigan_model: + hifigan_model = hifigan_model.cpu() + hifigan_model = None + try: + torch.cuda.empty_cache() + except: + pass + + +def clean_bigvgan_model(): + global bigvgan_model + if bigvgan_model: + bigvgan_model = bigvgan_model.cpu() + bigvgan_model = None + try: + torch.cuda.empty_cache() + except: + pass + + +def clean_sv_cn_model(): + global sv_cn_model + if sv_cn_model: + sv_cn_model.embedding_model = sv_cn_model.embedding_model.cpu() + sv_cn_model = None + try: + torch.cuda.empty_cache() + except: + pass + + +def init_bigvgan(): + global bigvgan_model, hifigan_model, sv_cn_model + from BigVGAN import bigvgan + + bigvgan_model = bigvgan.BigVGAN.from_pretrained( + "%s/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,), + use_cuda_kernel=False, + ) # if True, RuntimeError: Ninja is required to load C++ extensions + # remove weight norm in the model and set to eval mode + bigvgan_model.remove_weight_norm() + bigvgan_model = bigvgan_model.eval() + clean_hifigan_model() + clean_sv_cn_model() + if is_half == True: + bigvgan_model = bigvgan_model.half().to(device) + else: + bigvgan_model = bigvgan_model.to(device) + + +def init_hifigan(): + global hifigan_model, bigvgan_model, sv_cn_model + hifigan_model = Generator( + initial_channel=100, + resblock="1", + resblock_kernel_sizes=[3, 7, 11], + resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], + upsample_rates=[10, 6, 2, 2, 2], + upsample_initial_channel=512, + upsample_kernel_sizes=[20, 12, 4, 4, 4], + gin_channels=0, + is_bias=True, + ) + hifigan_model.eval() + hifigan_model.remove_weight_norm() + state_dict_g = torch.load( + "%s/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), + map_location="cpu", + weights_only=False, + ) + print("loading vocoder", hifigan_model.load_state_dict(state_dict_g)) + clean_bigvgan_model() + clean_sv_cn_model() + if is_half == True: + hifigan_model = hifigan_model.half().to(device) + else: + hifigan_model = hifigan_model.to(device) + + +from sv import SV + + +def init_sv_cn(): + global hifigan_model, bigvgan_model, sv_cn_model + sv_cn_model = SV(device, is_half) + clean_bigvgan_model() + clean_hifigan_model() + + +bigvgan_model = hifigan_model = sv_cn_model = None +if model_version == "v3": + init_bigvgan() +if model_version == "v4": + init_hifigan() +if model_version in {"v2Pro", "v2ProPlus"}: + init_sv_cn() + +resample_transform_dict = {} + + +def resample(audio_tensor, sr0, sr1, device): + global resample_transform_dict + key = "%s-%s-%s" % (sr0, sr1, str(device)) + if key not in resample_transform_dict: + 
resample_transform_dict[key] = torchaudio.transforms.Resample(sr0, sr1).to(device) + return resample_transform_dict[key](audio_tensor) + + +def get_spepc(hps, filename, dtype, device, is_v2pro=False): + # audio = load_audio(filename, int(hps.data.sampling_rate)) + + # audio, sampling_rate = librosa.load(filename, sr=int(hps.data.sampling_rate)) + # audio = torch.FloatTensor(audio) + + sr1 = int(hps.data.sampling_rate) + audio, sr0 = torchaudio.load(filename) + if sr0 != sr1: + audio = audio.to(device) + if audio.shape[0] == 2: + audio = audio.mean(0).unsqueeze(0) + audio = resample(audio, sr0, sr1, device) + else: + audio = audio.to(device) + if audio.shape[0] == 2: + audio = audio.mean(0).unsqueeze(0) + + maxx = audio.abs().max() + if maxx > 1: + audio /= min(2, maxx) + spec = spectrogram_torch( + audio, + hps.data.filter_length, + hps.data.sampling_rate, + hps.data.hop_length, + hps.data.win_length, + center=False, + ) + spec = spec.to(dtype) + if is_v2pro == True: + audio = resample(audio, sr1, 16000, device).to(dtype) + return spec, audio + + +def clean_text_inf(text, language, version): + language = language.replace("all_", "") + phones, word2ph, norm_text = clean_text(text, language, version) + phones = cleaned_text_to_sequence(phones, version) + return phones, word2ph, norm_text + + +dtype = torch.float16 if is_half == True else torch.float32 + + +def get_bert_inf(phones, word2ph, norm_text, language): + language = language.replace("all_", "") + if language == "zh": + bert = get_bert_feature(norm_text, word2ph).to(device) # .to(dtype) + else: + bert = torch.zeros( + (1024, len(phones)), + dtype=torch.float16 if is_half == True else torch.float32, + ).to(device) + + return bert + + +splits = { + ",", + "。", + "?", + "!", + ",", + ".", + "?", + "!", + "~", + ":", + ":", + "—", + "…", +} + + +def get_first(text): + pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]" + text = re.split(pattern, text)[0].strip() + return text + + +from text import chinese + + +def get_phones_and_bert(text, language, version, final=False): + if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}: + formattext = text + while " " in formattext: + formattext = formattext.replace(" ", " ") + if language == "all_zh": + if re.search(r"[A-Za-z]", formattext): + formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext) + formattext = chinese.mix_text_normalize(formattext) + return get_phones_and_bert(formattext, "zh", version) + else: + phones, word2ph, norm_text = clean_text_inf(formattext, language, version) + bert = get_bert_feature(norm_text, word2ph).to(device) + elif language == "all_yue" and re.search(r"[A-Za-z]", formattext): + formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext) + formattext = chinese.mix_text_normalize(formattext) + return get_phones_and_bert(formattext, "yue", version) + else: + phones, word2ph, norm_text = clean_text_inf(formattext, language, version) + bert = torch.zeros( + (1024, len(phones)), + dtype=torch.float16 if is_half == True else torch.float32, + ).to(device) + elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}: + textlist = [] + langlist = [] + if language == "auto": + for tmp in LangSegmenter.getTexts(text): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "auto_yue": + for tmp in LangSegmenter.getTexts(text): + if tmp["lang"] == "zh": + tmp["lang"] = "yue" + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + else: + for tmp in LangSegmenter.getTexts(text): + if 
langlist: + if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"): + textlist[-1] += tmp["text"] + continue + if tmp["lang"] == "en": + langlist.append(tmp["lang"]) + else: + # 因无法区别中日韩文汉字,以用户输入为准 + langlist.append(language) + textlist.append(tmp["text"]) + print(textlist) + print(langlist) + phones_list = [] + bert_list = [] + norm_text_list = [] + for i in range(len(textlist)): + lang = langlist[i] + phones, word2ph, norm_text = clean_text_inf(textlist[i], lang, version) + bert = get_bert_inf(phones, word2ph, norm_text, lang) + phones_list.append(phones) + norm_text_list.append(norm_text) + bert_list.append(bert) + bert = torch.cat(bert_list, dim=1) + phones = sum(phones_list, []) + norm_text = "".join(norm_text_list) + + if not final and len(phones) < 6: + return get_phones_and_bert("." + text, language, version, final=True) + + return phones, bert.to(dtype), norm_text + + +from module.mel_processing import mel_spectrogram_torch, spectrogram_torch + +spec_min = -12 +spec_max = 2 + + +def norm_spec(x): + return (x - spec_min) / (spec_max - spec_min) * 2 - 1 + + +def denorm_spec(x): + return (x + 1) / 2 * (spec_max - spec_min) + spec_min + + +mel_fn = lambda x: mel_spectrogram_torch( + x, + **{ + "n_fft": 1024, + "win_size": 1024, + "hop_size": 256, + "num_mels": 100, + "sampling_rate": 24000, + "fmin": 0, + "fmax": None, + "center": False, + }, +) +mel_fn_v4 = lambda x: mel_spectrogram_torch( + x, + **{ + "n_fft": 1280, + "win_size": 1280, + "hop_size": 320, + "num_mels": 100, + "sampling_rate": 32000, + "fmin": 0, + "fmax": None, + "center": False, + }, +) + + +def merge_short_text_in_array(texts, threshold): + if (len(texts)) < 2: + return texts + result = [] + text = "" + for ele in texts: + text += ele + if len(text) >= threshold: + result.append(text) + text = "" + if len(text) > 0: + if len(result) == 0: + result.append(text) + else: + result[len(result) - 1] += text + return result + + +sr_model = None + + +def audio_sr(audio, sr): + global sr_model + if sr_model == None: + from tools.audio_sr import AP_BWE + + try: + sr_model = AP_BWE(device, DictToAttrRecursive) + except FileNotFoundError: + gr.Warning(i18n("你没有下载超分模型的参数,因此不进行超分。如想超分请先参照教程把文件下载好")) + return audio.cpu().detach().numpy(), sr + return sr_model(audio, sr) + + +##ref_wav_path+prompt_text+prompt_language+text(单个)+text_language+top_k+top_p+temperature +# cache_tokens={}#暂未实现清理机制 +cache = {} + + +def get_tts_wav( + ref_wav_path, + prompt_text, + prompt_language, + text, + text_language, + how_to_cut=i18n("不切"), + top_k=20, + top_p=0.6, + temperature=0.6, + ref_free=False, + speed=1, + if_freeze=False, + inp_refs=None, + sample_steps=8, + if_sr=False, + pause_second=0.3, +): + global cache + if ref_wav_path: + pass + else: + gr.Warning(i18n("请上传参考音频")) + if text: + pass + else: + gr.Warning(i18n("请填入推理文本")) + t = [] + if prompt_text is None or len(prompt_text) == 0: + ref_free = True + if model_version in v3v4set: + ref_free = False # s2v3暂不支持ref_free + else: + if_sr = False + if model_version not in {"v3", "v4", "v2Pro", "v2ProPlus"}: + clean_bigvgan_model() + clean_hifigan_model() + clean_sv_cn_model() + t0 = ttime() + prompt_language = dict_language[prompt_language] + text_language = dict_language[text_language] + + if not ref_free: + prompt_text = prompt_text.strip("\n") + if prompt_text[-1] not in splits: + prompt_text += "。" if prompt_language != "en" else "." 
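+ # The steps below build a `pause_second`-long silence (also reused between sentences) and,
+ # when a reference prompt is used, load the reference wav at 16 kHz (3~10 s enforced),
+ # run it through the SSL model and vq_model.extract_latent() to obtain the semantic-token
+ # prompt consumed by the GPT (t2s) stage.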
+ print(i18n("实际输入的参考文本:"), prompt_text) + text = text.strip("\n") + # if (text[0] not in splits and len(get_first(text)) < 4): text = "。" + text if text_language != "en" else "." + text + + print(i18n("实际输入的目标文本:"), text) + zero_wav = np.zeros( + int(hps.data.sampling_rate * pause_second), + dtype=np.float16 if is_half == True else np.float32, + ) + zero_wav_torch = torch.from_numpy(zero_wav) + if is_half == True: + zero_wav_torch = zero_wav_torch.half().to(device) + else: + zero_wav_torch = zero_wav_torch.to(device) + if not ref_free: + with torch.no_grad(): + wav16k, sr = librosa.load(ref_wav_path, sr=16000) + if wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000: + gr.Warning(i18n("参考音频在3~10秒范围外,请更换!")) + raise OSError(i18n("参考音频在3~10秒范围外,请更换!")) + wav16k = torch.from_numpy(wav16k) + if is_half == True: + wav16k = wav16k.half().to(device) + else: + wav16k = wav16k.to(device) + wav16k = torch.cat([wav16k, zero_wav_torch]) + ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float() + codes = vq_model.extract_latent(ssl_content) + prompt_semantic = codes[0, 0] + prompt = prompt_semantic.unsqueeze(0).to(device) + + t1 = ttime() + t.append(t1 - t0) + + if how_to_cut == i18n("凑四句一切"): + text = cut1(text) + elif how_to_cut == i18n("凑50字一切"): + text = cut2(text) + elif how_to_cut == i18n("按中文句号。切"): + text = cut3(text) + elif how_to_cut == i18n("按英文句号.切"): + text = cut4(text) + elif how_to_cut == i18n("按标点符号切"): + text = cut5(text) + while "\n\n" in text: + text = text.replace("\n\n", "\n") + print(i18n("实际输入的目标文本(切句后):"), text) + texts = text.split("\n") + texts = process_text(texts) + texts = merge_short_text_in_array(texts, 5) + audio_opt = [] + ###s2v3暂不支持ref_free + if not ref_free: + phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language, version) + + for i_text, text in enumerate(texts): + # 解决输入目标文本的空行导致报错的问题 + if len(text.strip()) == 0: + continue + if text[-1] not in splits: + text += "。" if text_language != "en" else "." 
+ print(i18n("实际输入的目标文本(每句):"), text) + phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language, version) + print(i18n("前端处理后的文本(每句):"), norm_text2) + if not ref_free: + bert = torch.cat([bert1, bert2], 1) + all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0) + else: + bert = bert2 + all_phoneme_ids = torch.LongTensor(phones2).to(device).unsqueeze(0) + + bert = bert.to(device).unsqueeze(0) + all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device) + + t2 = ttime() + # cache_key="%s-%s-%s-%s-%s-%s-%s-%s"%(ref_wav_path,prompt_text,prompt_language,text,text_language,top_k,top_p,temperature) + # print(cache.keys(),if_freeze) + if i_text in cache and if_freeze == True: + pred_semantic = cache[i_text] + else: + with torch.no_grad(): + pred_semantic, idx = t2s_model.model.infer_panel( + all_phoneme_ids, + all_phoneme_len, + None if ref_free else prompt, + bert, + # prompt_phone_len=ph_offset, + top_k=top_k, + top_p=top_p, + temperature=temperature, + early_stop_num=hz * max_sec, + ) + pred_semantic = pred_semantic[:, -idx:].unsqueeze(0) + cache[i_text] = pred_semantic + t3 = ttime() + is_v2pro = model_version in {"v2Pro", "v2ProPlus"} + # print(23333,is_v2pro,model_version) + ###v3不存在以下逻辑和inp_refs + if model_version not in v3v4set: + refers = [] + if is_v2pro: + sv_emb = [] + if sv_cn_model == None: + init_sv_cn() + if inp_refs: + for path in inp_refs: + try: #####这里加上提取sv的逻辑,要么一堆sv一堆refer,要么单个sv单个refer + refer, audio_tensor = get_spepc(hps, path.name, dtype, device, is_v2pro) + refers.append(refer) + if is_v2pro: + sv_emb.append(sv_cn_model.compute_embedding3(audio_tensor)) + except: + traceback.print_exc() + if len(refers) == 0: + refers, audio_tensor = get_spepc(hps, ref_wav_path, dtype, device, is_v2pro) + refers = [refers] + if is_v2pro: + sv_emb = [sv_cn_model.compute_embedding3(audio_tensor)] + if is_v2pro: + audio = vq_model.decode( + pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed, sv_emb=sv_emb + )[0][0] + else: + audio = vq_model.decode( + pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed + )[0][0] + else: + refer, audio_tensor = get_spepc(hps, ref_wav_path, dtype, device) + phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0) + phoneme_ids1 = torch.LongTensor(phones2).to(device).unsqueeze(0) + fea_ref, ge = vq_model.decode_encp(prompt.unsqueeze(0), phoneme_ids0, refer) + ref_audio, sr = torchaudio.load(ref_wav_path) + ref_audio = ref_audio.to(device).float() + if ref_audio.shape[0] == 2: + ref_audio = ref_audio.mean(0).unsqueeze(0) + tgt_sr = 24000 if model_version == "v3" else 32000 + if sr != tgt_sr: + ref_audio = resample(ref_audio, sr, tgt_sr, device) + # print("ref_audio",ref_audio.abs().mean()) + mel2 = mel_fn(ref_audio) if model_version == "v3" else mel_fn_v4(ref_audio) + mel2 = norm_spec(mel2) + T_min = min(mel2.shape[2], fea_ref.shape[2]) + mel2 = mel2[:, :, :T_min] + fea_ref = fea_ref[:, :, :T_min] + Tref = 468 if model_version == "v3" else 500 + Tchunk = 934 if model_version == "v3" else 1000 + if T_min > Tref: + mel2 = mel2[:, :, -Tref:] + fea_ref = fea_ref[:, :, -Tref:] + T_min = Tref + chunk_len = Tchunk - T_min + mel2 = mel2.to(dtype) + fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge, speed) + cfm_resss = [] + idx = 0 + while 1: + fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len] + if fea_todo_chunk.shape[-1] == 0: + break + idx += chunk_len + fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 
1) + cfm_res = vq_model.cfm.inference( + fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0 + ) + cfm_res = cfm_res[:, :, mel2.shape[2] :] + mel2 = cfm_res[:, :, -T_min:] + fea_ref = fea_todo_chunk[:, :, -T_min:] + cfm_resss.append(cfm_res) + cfm_res = torch.cat(cfm_resss, 2) + cfm_res = denorm_spec(cfm_res) + if model_version == "v3": + if bigvgan_model == None: + init_bigvgan() + else: # v4 + if hifigan_model == None: + init_hifigan() + vocoder_model = bigvgan_model if model_version == "v3" else hifigan_model + with torch.inference_mode(): + wav_gen = vocoder_model(cfm_res) + audio = wav_gen[0][0] # .cpu().detach().numpy() + max_audio = torch.abs(audio).max() # 简单防止16bit爆音 + if max_audio > 1: + audio = audio / max_audio + audio_opt.append(audio) + audio_opt.append(zero_wav_torch) # zero_wav + t4 = ttime() + t.extend([t2 - t1, t3 - t2, t4 - t3]) + t1 = ttime() + print("%.3f\t%.3f\t%.3f\t%.3f" % (t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3]))) + audio_opt = torch.cat(audio_opt, 0) # np.concatenate + if model_version in {"v1", "v2", "v2Pro", "v2ProPlus"}: + opt_sr = 32000 + elif model_version == "v3": + opt_sr = 24000 + else: + opt_sr = 48000 # v4 + if if_sr == True and opt_sr == 24000: + print(i18n("音频超分中")) + audio_opt, opt_sr = audio_sr(audio_opt.unsqueeze(0), opt_sr) + max_audio = np.abs(audio_opt).max() + if max_audio > 1: + audio_opt /= max_audio + else: + audio_opt = audio_opt.cpu().detach().numpy() + yield opt_sr, (audio_opt * 32767).astype(np.int16) + + +def split(todo_text): + todo_text = todo_text.replace("……", "。").replace("——", ",") + if todo_text[-1] not in splits: + todo_text += "。" + i_split_head = i_split_tail = 0 + len_text = len(todo_text) + todo_texts = [] + while 1: + if i_split_head >= len_text: + break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入 + if todo_text[i_split_head] in splits: + i_split_head += 1 + todo_texts.append(todo_text[i_split_tail:i_split_head]) + i_split_tail = i_split_head + else: + i_split_head += 1 + return todo_texts + + +def cut1(inp): + inp = inp.strip("\n") + inps = split(inp) + split_idx = list(range(0, len(inps), 4)) + split_idx[-1] = None + if len(split_idx) > 1: + opts = [] + for idx in range(len(split_idx) - 1): + opts.append("".join(inps[split_idx[idx] : split_idx[idx + 1]])) + else: + opts = [inp] + opts = [item for item in opts if not set(item).issubset(punctuation)] + return "\n".join(opts) + + +def cut2(inp): + inp = inp.strip("\n") + inps = split(inp) + if len(inps) < 2: + return inp + opts = [] + summ = 0 + tmp_str = "" + for i in range(len(inps)): + summ += len(inps[i]) + tmp_str += inps[i] + if summ > 50: + summ = 0 + opts.append(tmp_str) + tmp_str = "" + if tmp_str != "": + opts.append(tmp_str) + # print(opts) + if len(opts) > 1 and len(opts[-1]) < 50: ##如果最后一个太短了,和前一个合一起 + opts[-2] = opts[-2] + opts[-1] + opts = opts[:-1] + opts = [item for item in opts if not set(item).issubset(punctuation)] + return "\n".join(opts) + + +def cut3(inp): + inp = inp.strip("\n") + opts = ["%s" % item for item in inp.strip("。").split("。")] + opts = [item for item in opts if not set(item).issubset(punctuation)] + return "\n".join(opts) + + +def cut4(inp): + inp = inp.strip("\n") + opts = re.split(r"(? 
<!\d)\.(?!\d)", inp.strip(".")) + opts = [item for item in opts if not set(item).issubset(punctuation)] + return "\n".join(opts) + + +def cut5(inp): + inp = inp.strip("\n") + punds = {",", ".", ";", "?", "!", "、", ",", "。", "?", "!", ";", ":", "…"} + mergeitems = [] + items = [] + + for i, char in enumerate(inp): + if char in punds: + if char == "." and i >
0 and i < len(inp) - 1 and inp[i - 1].isdigit() and inp[i + 1].isdigit(): + items.append(char) + else: + items.append(char) + mergeitems.append("".join(items)) + items = [] + else: + items.append(char) + + if items: + mergeitems.append("".join(items)) + + opt = [item for item in mergeitems if not set(item).issubset(punds)] + return "\n".join(opt) + + +def custom_sort_key(s): + # 使用正则表达式提取字符串中的数字部分和非数字部分 + parts = re.split("(\d+)", s) + # 将数字部分转换为整数,非数字部分保持不变 + parts = [int(part) if part.isdigit() else part for part in parts] + return parts + + +def process_text(texts): + _text = [] + if all(text in [None, " ", "\n", ""] for text in texts): + raise ValueError(i18n("请输入有效文本")) + for text in texts: + if text in [None, " ", ""]: + pass + else: + _text.append(text) + return _text + + +def html_center(text, label="p"): + return f"""
+ <{label} style="margin: 0; padding: 0;">{text} +
""" + + +def html_left(text, label="p"): + return f"""
+ <{label} style="margin: 0; padding: 0;">{text} +
""" + + +with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css) as app: + gr.HTML( + top_html.format( + i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.") + + i18n("如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.") + ), + elem_classes="markdown", + ) + with gr.Group(): + gr.Markdown(html_center(i18n("模型切换"), "h3")) + with gr.Row(): + GPT_dropdown = gr.Dropdown( + label=i18n("GPT模型列表"), + choices=sorted(GPT_names, key=custom_sort_key), + value=gpt_path, + interactive=True, + scale=14, + ) + SoVITS_dropdown = gr.Dropdown( + label=i18n("SoVITS模型列表"), + choices=sorted(SoVITS_names, key=custom_sort_key), + value=sovits_path, + interactive=True, + scale=14, + ) + refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary", scale=14) + refresh_button.click(fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown]) + gr.Markdown(html_center(i18n("*请上传并填写参考信息"), "h3")) + with gr.Row(): + inp_ref = gr.Audio(label=i18n("请上传3~10秒内参考音频,超过会报错!"), type="filepath", scale=13) + with gr.Column(scale=13): + ref_text_free = gr.Checkbox( + label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。") + + i18n("v3暂不支持该模式,使用了会报错。"), + value=False, + interactive=True if model_version not in v3v4set else False, + show_label=True, + scale=1, + ) + gr.Markdown( + html_left( + i18n("使用无参考文本模式时建议使用微调的GPT") + + "
" + + i18n("听不清参考音频说的啥(不晓得写啥)可以开。开启后无视填写的参考文本。") + ) + ) + prompt_text = gr.Textbox(label=i18n("参考音频的文本"), value="", lines=5, max_lines=5, scale=1) + with gr.Column(scale=14): + prompt_language = gr.Dropdown( + label=i18n("参考音频的语种"), + choices=list(dict_language.keys()), + value=i18n("中文"), + ) + inp_refs = ( + gr.File( + label=i18n( + "可选项:通过拖拽多个文件上传多个参考音频(建议同性),平均融合他们的音色。如不填写此项,音色由左侧单个参考音频控制。如是微调模型,建议参考音频全部在微调训练集音色内,底模不用管。" + ), + file_count="multiple", + ) + if model_version not in v3v4set + else gr.File( + label=i18n( + "可选项:通过拖拽多个文件上传多个参考音频(建议同性),平均融合他们的音色。如不填写此项,音色由左侧单个参考音频控制。如是微调模型,建议参考音频全部在微调训练集音色内,底模不用管。" + ), + file_count="multiple", + visible=False, + ) + ) + sample_steps = ( + gr.Radio( + label=i18n("采样步数,如果觉得电,提高试试,如果觉得慢,降低试试"), + value=32 if model_version == "v3" else 8, + choices=[4, 8, 16, 32, 64, 128] if model_version == "v3" else [4, 8, 16, 32], + visible=True, + ) + if model_version in v3v4set + else gr.Radio( + label=i18n("采样步数,如果觉得电,提高试试,如果觉得慢,降低试试"), + choices=[4, 8, 16, 32, 64, 128] if model_version == "v3" else [4, 8, 16, 32], + visible=False, + value=32 if model_version == "v3" else 8, + ) + ) + if_sr_Checkbox = gr.Checkbox( + label=i18n("v3输出如果觉得闷可以试试开超分"), + value=False, + interactive=True, + show_label=True, + visible=False if model_version != "v3" else True, + ) + gr.Markdown(html_center(i18n("*请填写需要合成的目标文本和语种模式"), "h3")) + with gr.Row(): + with gr.Column(scale=13): + text = gr.Textbox(label=i18n("需要合成的文本"), value="", lines=26, max_lines=26) + with gr.Column(scale=7): + text_language = gr.Dropdown( + label=i18n("需要合成的语种") + i18n(".限制范围越小判别效果越好。"), + choices=list(dict_language.keys()), + value=i18n("中文"), + scale=1, + ) + how_to_cut = gr.Dropdown( + label=i18n("怎么切"), + choices=[ + i18n("不切"), + i18n("凑四句一切"), + i18n("凑50字一切"), + i18n("按中文句号。切"), + i18n("按英文句号.切"), + i18n("按标点符号切"), + ], + value=i18n("凑四句一切"), + interactive=True, + scale=1, + ) + gr.Markdown(value=html_center(i18n("语速调整,高为更快"))) + if_freeze = gr.Checkbox( + label=i18n("是否直接对上次合成结果调整语速和音色。防止随机性。"), + value=False, + interactive=True, + show_label=True, + scale=1, + ) + with gr.Row(): + speed = gr.Slider( + minimum=0.6, maximum=1.65, step=0.05, label=i18n("语速"), value=1, interactive=True, scale=1 + ) + pause_second_slider = gr.Slider( + minimum=0.1, + maximum=0.5, + step=0.01, + label=i18n("句间停顿秒数"), + value=0.3, + interactive=True, + scale=1, + ) + gr.Markdown(html_center(i18n("GPT采样参数(无参考文本时不要太低。不懂就用默认):"))) + top_k = gr.Slider( + minimum=1, maximum=100, step=1, label=i18n("top_k"), value=15, interactive=True, scale=1 + ) + top_p = gr.Slider( + minimum=0, maximum=1, step=0.05, label=i18n("top_p"), value=1, interactive=True, scale=1 + ) + temperature = gr.Slider( + minimum=0, maximum=1, step=0.05, label=i18n("temperature"), value=1, interactive=True, scale=1 + ) + # with gr.Column(): + # gr.Markdown(value=i18n("手工调整音素。当音素框不为空时使用手工音素输入推理,无视目标文本框。")) + # phoneme=gr.Textbox(label=i18n("音素框"), value="") + # get_phoneme_button = gr.Button(i18n("目标文本转音素"), variant="primary") + with gr.Row(): + inference_button = gr.Button(value=i18n("合成语音"), variant="primary", size="lg", scale=25) + output = gr.Audio(label=i18n("输出的语音"), scale=14) + + inference_button.click( + get_tts_wav, + [ + inp_ref, + prompt_text, + prompt_language, + text, + text_language, + how_to_cut, + top_k, + top_p, + temperature, + ref_text_free, + speed, + if_freeze, + inp_refs, + sample_steps, + if_sr_Checkbox, + pause_second_slider, + ], + [output], + ) + SoVITS_dropdown.change( + change_sovits_weights, + [SoVITS_dropdown, 
prompt_language, text_language], + [ + prompt_language, + text_language, + prompt_text, + prompt_language, + text, + text_language, + sample_steps, + inp_refs, + ref_text_free, + if_sr_Checkbox, + inference_button, + ], + ) + GPT_dropdown.change(change_gpt_weights, [GPT_dropdown], []) + + # gr.Markdown(value=i18n("文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。")) + # with gr.Row(): + # text_inp = gr.Textbox(label=i18n("需要合成的切分前文本"), value="") + # button1 = gr.Button(i18n("凑四句一切"), variant="primary") + # button2 = gr.Button(i18n("凑50字一切"), variant="primary") + # button3 = gr.Button(i18n("按中文句号。切"), variant="primary") + # button4 = gr.Button(i18n("按英文句号.切"), variant="primary") + # button5 = gr.Button(i18n("按标点符号切"), variant="primary") + # text_opt = gr.Textbox(label=i18n("切分后文本"), value="") + # button1.click(cut1, [text_inp], [text_opt]) + # button2.click(cut2, [text_inp], [text_opt]) + # button3.click(cut3, [text_inp], [text_opt]) + # button4.click(cut4, [text_inp], [text_opt]) + # button5.click(cut5, [text_inp], [text_opt]) + # gr.Markdown(html_center(i18n("后续将支持转音素、手工修改音素、语音合成分步执行。"))) + +if __name__ == "__main__": + app.queue().launch( # concurrency_count=511, max_size=1022 + server_name="0.0.0.0", + inbrowser=True, + share=is_share, + server_port=infer_ttswebui, + # quiet=True, + ) diff --git a/inference_webui_fast.py b/inference_webui_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..470b7bbd35e6feab47f8ad0a4367a08fef7173ef --- /dev/null +++ b/inference_webui_fast.py @@ -0,0 +1,509 @@ +""" +按中英混合识别 +按日英混合识别 +多语种启动切分识别语种 +全部按中文识别 +全部按英文识别 +全部按日文识别 +""" + +import json +import logging +import os +import random +import re +import sys + +import torch + +now_dir = os.getcwd() +sys.path.append(now_dir) +sys.path.append("%s/GPT_SoVITS" % (now_dir)) + +logging.getLogger("markdown_it").setLevel(logging.ERROR) +logging.getLogger("urllib3").setLevel(logging.ERROR) +logging.getLogger("httpcore").setLevel(logging.ERROR) +logging.getLogger("httpx").setLevel(logging.ERROR) +logging.getLogger("asyncio").setLevel(logging.ERROR) +logging.getLogger("charset_normalizer").setLevel(logging.ERROR) +logging.getLogger("torchaudio._extension").setLevel(logging.ERROR) + + +infer_ttswebui = os.environ.get("infer_ttswebui", 9872) +infer_ttswebui = int(infer_ttswebui) +is_share = os.environ.get("is_share", "False") +is_share = eval(is_share) +if "_CUDA_VISIBLE_DEVICES" in os.environ: + os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"] + +is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available() +gpt_path = os.environ.get("gpt_path", None) +sovits_path = os.environ.get("sovits_path", None) +cnhubert_base_path = os.environ.get("cnhubert_base_path", None) +bert_path = os.environ.get("bert_path", None) +version = model_version = os.environ.get("version", "v2") + +import gradio as gr +from TTS_infer_pack.text_segmentation_method import get_method +from TTS_infer_pack.TTS import NO_PROMPT_ERROR, TTS, TTS_Config + +from tools.assets import css, js, top_html +from tools.i18n.i18n import I18nAuto, scan_language_list + +language = os.environ.get("language", "Auto") +language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language +i18n = I18nAuto(language=language) + + +# os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 确保直接启动推理UI时也能够设置。 + +if torch.cuda.is_available(): + device = "cuda" +# elif torch.backends.mps.is_available(): +# device = "mps" +else: + device = "cpu" + +# is_half = False +# device = "cpu" + +dict_language_v1 = { + 
i18n("中文"): "all_zh", # 全部按中文识别 + i18n("英文"): "en", # 全部按英文识别#######不变 + i18n("日文"): "all_ja", # 全部按日文识别 + i18n("中英混合"): "zh", # 按中英混合识别####不变 + i18n("日英混合"): "ja", # 按日英混合识别####不变 + i18n("多语种混合"): "auto", # 多语种启动切分识别语种 +} +dict_language_v2 = { + i18n("中文"): "all_zh", # 全部按中文识别 + i18n("英文"): "en", # 全部按英文识别#######不变 + i18n("日文"): "all_ja", # 全部按日文识别 + i18n("粤语"): "all_yue", # 全部按中文识别 + i18n("韩文"): "all_ko", # 全部按韩文识别 + i18n("中英混合"): "zh", # 按中英混合识别####不变 + i18n("日英混合"): "ja", # 按日英混合识别####不变 + i18n("粤英混合"): "yue", # 按粤英混合识别####不变 + i18n("韩英混合"): "ko", # 按韩英混合识别####不变 + i18n("多语种混合"): "auto", # 多语种启动切分识别语种 + i18n("多语种混合(粤语)"): "auto_yue", # 多语种启动切分识别语种 +} +dict_language = dict_language_v1 if version == "v1" else dict_language_v2 + +cut_method = { + i18n("不切"): "cut0", + i18n("凑四句一切"): "cut1", + i18n("凑50字一切"): "cut2", + i18n("按中文句号。切"): "cut3", + i18n("按英文句号.切"): "cut4", + i18n("按标点符号切"): "cut5", +} + +from config import change_choices, get_weights_names, name2gpt_path, name2sovits_path + +SoVITS_names, GPT_names = get_weights_names() +from config import pretrained_sovits_name + +path_sovits_v3 = pretrained_sovits_name["v3"] +path_sovits_v4 = pretrained_sovits_name["v4"] +is_exist_s2gv3 = os.path.exists(path_sovits_v3) +is_exist_s2gv4 = os.path.exists(path_sovits_v4) + +tts_config = TTS_Config("GPT_SoVITS/configs/tts_infer.yaml") +tts_config.device = device +tts_config.is_half = is_half +tts_config.version = version +if gpt_path is not None: + if "!" in gpt_path or "!" in gpt_path: + gpt_path = name2gpt_path[gpt_path] + tts_config.t2s_weights_path = gpt_path +if sovits_path is not None: + if "!" in sovits_path or "!" in sovits_path: + sovits_path = name2sovits_path[sovits_path] + tts_config.vits_weights_path = sovits_path +if cnhubert_base_path is not None: + tts_config.cnhuhbert_base_path = cnhubert_base_path +if bert_path is not None: + tts_config.bert_base_path = bert_path + +print(tts_config) +tts_pipeline = TTS(tts_config) +gpt_path = tts_config.t2s_weights_path +sovits_path = tts_config.vits_weights_path +version = tts_config.version + + +def inference( + text, + text_lang, + ref_audio_path, + aux_ref_audio_paths, + prompt_text, + prompt_lang, + top_k, + top_p, + temperature, + text_split_method, + batch_size, + speed_factor, + ref_text_free, + split_bucket, + fragment_interval, + seed, + keep_random, + parallel_infer, + repetition_penalty, + sample_steps, + super_sampling, +): + seed = -1 if keep_random else seed + actual_seed = seed if seed not in [-1, "", None] else random.randint(0, 2**32 - 1) + inputs = { + "text": text, + "text_lang": dict_language[text_lang], + "ref_audio_path": ref_audio_path, + "aux_ref_audio_paths": [item.name for item in aux_ref_audio_paths] if aux_ref_audio_paths is not None else [], + "prompt_text": prompt_text if not ref_text_free else "", + "prompt_lang": dict_language[prompt_lang], + "top_k": top_k, + "top_p": top_p, + "temperature": temperature, + "text_split_method": cut_method[text_split_method], + "batch_size": int(batch_size), + "speed_factor": float(speed_factor), + "split_bucket": split_bucket, + "return_fragment": False, + "fragment_interval": fragment_interval, + "seed": actual_seed, + "parallel_infer": parallel_infer, + "repetition_penalty": repetition_penalty, + "sample_steps": int(sample_steps), + "super_sampling": super_sampling, + } + try: + for item in tts_pipeline.run(inputs): + yield item, actual_seed + except NO_PROMPT_ERROR: + gr.Warning(i18n("V3不支持无参考文本模式,请填写参考文本!")) + + +def custom_sort_key(s): + # 使用正则表达式提取字符串中的数字部分和非数字部分 + parts 
= re.split("(\d+)", s) + # 将数字部分转换为整数,非数字部分保持不变 + parts = [int(part) if part.isdigit() else part for part in parts] + return parts + + +if os.path.exists("./weight.json"): + pass +else: + with open("./weight.json", "w", encoding="utf-8") as file: + json.dump({"GPT": {}, "SoVITS": {}}, file) + +with open("./weight.json", "r", encoding="utf-8") as file: + weight_data = file.read() + weight_data = json.loads(weight_data) + gpt_path = os.environ.get("gpt_path", weight_data.get("GPT", {}).get(version, GPT_names[-1])) + sovits_path = os.environ.get("sovits_path", weight_data.get("SoVITS", {}).get(version, SoVITS_names[0])) + if isinstance(gpt_path, list): + gpt_path = gpt_path[0] + if isinstance(sovits_path, list): + sovits_path = sovits_path[0] + +from process_ckpt import get_sovits_version_from_path_fast + +v3v4set = {"v3", "v4"} + + +def change_sovits_weights(sovits_path, prompt_language=None, text_language=None): + if "!" in sovits_path or "!" in sovits_path: + sovits_path = name2sovits_path[sovits_path] + global version, model_version, dict_language, if_lora_v3 + version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path) + # print(sovits_path,version, model_version, if_lora_v3) + is_exist = is_exist_s2gv3 if model_version == "v3" else is_exist_s2gv4 + path_sovits = path_sovits_v3 if model_version == "v3" else path_sovits_v4 + if if_lora_v3 == True and is_exist == False: + info = path_sovits + "SoVITS %s" % model_version + i18n("底模缺失,无法加载相应 LoRA 权重") + gr.Warning(info) + raise FileExistsError(info) + dict_language = dict_language_v1 if version == "v1" else dict_language_v2 + if prompt_language is not None and text_language is not None: + if prompt_language in list(dict_language.keys()): + prompt_text_update, prompt_language_update = ( + {"__type__": "update"}, + {"__type__": "update", "value": prompt_language}, + ) + else: + prompt_text_update = {"__type__": "update", "value": ""} + prompt_language_update = {"__type__": "update", "value": i18n("中文")} + if text_language in list(dict_language.keys()): + text_update, text_language_update = {"__type__": "update"}, {"__type__": "update", "value": text_language} + else: + text_update = {"__type__": "update", "value": ""} + text_language_update = {"__type__": "update", "value": i18n("中文")} + if model_version in v3v4set: + visible_sample_steps = True + visible_inp_refs = False + else: + visible_sample_steps = False + visible_inp_refs = True + yield ( + {"__type__": "update", "choices": list(dict_language.keys())}, + {"__type__": "update", "choices": list(dict_language.keys())}, + prompt_text_update, + prompt_language_update, + text_update, + text_language_update, + {"__type__": "update", "interactive": visible_sample_steps, "value": 32}, + {"__type__": "update", "visible": visible_inp_refs}, + {"__type__": "update", "interactive": True if model_version not in v3v4set else False}, + {"__type__": "update", "value": i18n("模型加载中,请等待"), "interactive": False}, + ) + + tts_pipeline.init_vits_weights(sovits_path) + yield ( + {"__type__": "update", "choices": list(dict_language.keys())}, + {"__type__": "update", "choices": list(dict_language.keys())}, + prompt_text_update, + prompt_language_update, + text_update, + text_language_update, + {"__type__": "update", "interactive": visible_sample_steps, "value": 32}, + {"__type__": "update", "visible": visible_inp_refs}, + {"__type__": "update", "interactive": True if model_version not in v3v4set else False}, + {"__type__": "update", "value": i18n("合成语音"), "interactive": True}, + ) + with 
open("./weight.json") as f: + data = f.read() + data = json.loads(data) + data["SoVITS"][version] = sovits_path + with open("./weight.json", "w") as f: + f.write(json.dumps(data)) + + +def change_gpt_weights(gpt_path): + if "!" in gpt_path or "!" in gpt_path: + gpt_path = name2gpt_path[gpt_path] + tts_pipeline.init_t2s_weights(gpt_path) + + +with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css) as app: + gr.HTML( + top_html.format( + i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.") + + i18n("如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.") + ), + elem_classes="markdown", + ) + + with gr.Column(): + # with gr.Group(): + gr.Markdown(value=i18n("模型切换")) + with gr.Row(): + GPT_dropdown = gr.Dropdown( + label=i18n("GPT模型列表"), + choices=sorted(GPT_names, key=custom_sort_key), + value=gpt_path, + interactive=True, + ) + SoVITS_dropdown = gr.Dropdown( + label=i18n("SoVITS模型列表"), + choices=sorted(SoVITS_names, key=custom_sort_key), + value=sovits_path, + interactive=True, + ) + refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary") + refresh_button.click(fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown]) + + with gr.Row(): + with gr.Column(): + gr.Markdown(value=i18n("*请上传并填写参考信息")) + with gr.Row(): + inp_ref = gr.Audio(label=i18n("主参考音频(请上传3~10秒内参考音频,超过会报错!)"), type="filepath") + inp_refs = gr.File( + label=i18n("辅参考音频(可选多个,或不选)"), + file_count="multiple", + visible=True if model_version != "v3" else False, + ) + prompt_text = gr.Textbox(label=i18n("主参考音频的文本"), value="", lines=2) + with gr.Row(): + prompt_language = gr.Dropdown( + label=i18n("主参考音频的语种"), choices=list(dict_language.keys()), value=i18n("中文") + ) + with gr.Column(): + ref_text_free = gr.Checkbox( + label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。"), + value=False, + interactive=True if model_version != "v3" else False, + show_label=True, + ) + gr.Markdown( + i18n("使用无参考文本模式时建议使用微调的GPT") + + "
" + + i18n("听不清参考音频说的啥(不晓得写啥)可以开。开启后无视填写的参考文本。") + ) + + with gr.Column(): + gr.Markdown(value=i18n("*请填写需要合成的目标文本和语种模式")) + text = gr.Textbox(label=i18n("需要合成的文本"), value="", lines=20, max_lines=20) + text_language = gr.Dropdown( + label=i18n("需要合成的文本的语种"), choices=list(dict_language.keys()), value=i18n("中文") + ) + + with gr.Group(): + gr.Markdown(value=i18n("推理设置")) + with gr.Row(): + with gr.Column(): + with gr.Row(): + batch_size = gr.Slider( + minimum=1, maximum=200, step=1, label=i18n("batch_size"), value=20, interactive=True + ) + sample_steps = gr.Radio( + label=i18n("采样步数(仅对V3/4生效)"), value=32, choices=[4, 8, 16, 32, 64, 128], visible=True + ) + with gr.Row(): + fragment_interval = gr.Slider( + minimum=0.01, maximum=1, step=0.01, label=i18n("分段间隔(秒)"), value=0.3, interactive=True + ) + speed_factor = gr.Slider( + minimum=0.6, maximum=1.65, step=0.05, label="语速", value=1.0, interactive=True + ) + with gr.Row(): + top_k = gr.Slider(minimum=1, maximum=100, step=1, label=i18n("top_k"), value=5, interactive=True) + top_p = gr.Slider(minimum=0, maximum=1, step=0.05, label=i18n("top_p"), value=1, interactive=True) + with gr.Row(): + temperature = gr.Slider( + minimum=0, maximum=1, step=0.05, label=i18n("temperature"), value=1, interactive=True + ) + repetition_penalty = gr.Slider( + minimum=0, maximum=2, step=0.05, label=i18n("重复惩罚"), value=1.35, interactive=True + ) + + with gr.Column(): + with gr.Row(): + how_to_cut = gr.Dropdown( + label=i18n("怎么切"), + choices=[ + i18n("不切"), + i18n("凑四句一切"), + i18n("凑50字一切"), + i18n("按中文句号。切"), + i18n("按英文句号.切"), + i18n("按标点符号切"), + ], + value=i18n("凑四句一切"), + interactive=True, + scale=1, + ) + super_sampling = gr.Checkbox( + label=i18n("音频超采样(仅对V3生效))"), value=False, interactive=True, show_label=True + ) + + with gr.Row(): + parallel_infer = gr.Checkbox(label=i18n("并行推理"), value=True, interactive=True, show_label=True) + split_bucket = gr.Checkbox( + label=i18n("数据分桶(并行推理时会降低一点计算量)"), + value=True, + interactive=True, + show_label=True, + ) + + with gr.Row(): + seed = gr.Number(label=i18n("随机种子"), value=-1) + keep_random = gr.Checkbox(label=i18n("保持随机"), value=True, interactive=True, show_label=True) + + output = gr.Audio(label=i18n("输出的语音")) + with gr.Row(): + inference_button = gr.Button(i18n("合成语音"), variant="primary") + stop_infer = gr.Button(i18n("终止合成"), variant="primary") + + inference_button.click( + inference, + [ + text, + text_language, + inp_ref, + inp_refs, + prompt_text, + prompt_language, + top_k, + top_p, + temperature, + how_to_cut, + batch_size, + speed_factor, + ref_text_free, + split_bucket, + fragment_interval, + seed, + keep_random, + parallel_infer, + repetition_penalty, + sample_steps, + super_sampling, + ], + [output, seed], + ) + stop_infer.click(tts_pipeline.stop, [], []) + SoVITS_dropdown.change( + change_sovits_weights, + [SoVITS_dropdown, prompt_language, text_language], + [ + prompt_language, + text_language, + prompt_text, + prompt_language, + text, + text_language, + sample_steps, + inp_refs, + ref_text_free, + inference_button, + ], + ) # + GPT_dropdown.change(change_gpt_weights, [GPT_dropdown], []) + + with gr.Group(): + gr.Markdown( + value=i18n( + "文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。" + ) + ) + with gr.Row(): + text_inp = gr.Textbox(label=i18n("需要合成的切分前文本"), value="", lines=4) + with gr.Column(): + _how_to_cut = gr.Radio( + label=i18n("怎么切"), + choices=[ + i18n("不切"), + i18n("凑四句一切"), + i18n("凑50字一切"), + i18n("按中文句号。切"), + i18n("按英文句号.切"), + i18n("按标点符号切"), + ], + value=i18n("凑四句一切"), + 
interactive=True, + ) + cut_text = gr.Button(i18n("切分"), variant="primary") + + def to_cut(text_inp, how_to_cut): + if len(text_inp.strip()) == 0 or text_inp == []: + return "" + method = get_method(cut_method[how_to_cut]) + return method(text_inp) + + text_opt = gr.Textbox(label=i18n("切分后文本"), value="", lines=4) + cut_text.click(to_cut, [text_inp, _how_to_cut], [text_opt]) + gr.Markdown(value=i18n("后续将支持转音素、手工修改音素、语音合成分步执行。")) + +if __name__ == "__main__": + app.queue().launch( # concurrency_count=511, max_size=1022 + server_name="0.0.0.0", + inbrowser=True, + share=is_share, + server_port=infer_ttswebui, + # quiet=True, + ) diff --git a/module/__init__.py b/module/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/module/attentions.py b/module/attentions.py new file mode 100644 index 0000000000000000000000000000000000000000..341de4ace129f713e5499cc8d7862ce3986d7175 --- /dev/null +++ b/module/attentions.py @@ -0,0 +1,659 @@ +import math +import torch +from torch import nn +from torch.nn import functional as F + +from module import commons +from module.modules import LayerNorm + + +class Encoder(nn.Module): + def __init__( + self, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size=1, + p_dropout=0.0, + window_size=4, + isflow=False, + **kwargs, + ): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.window_size = window_size + + self.drop = nn.Dropout(p_dropout) + self.attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + for i in range(self.n_layers): + self.attn_layers.append( + MultiHeadAttention( + hidden_channels, + hidden_channels, + n_heads, + p_dropout=p_dropout, + window_size=window_size, + ) + ) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN( + hidden_channels, + hidden_channels, + filter_channels, + kernel_size, + p_dropout=p_dropout, + ) + ) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + if isflow: + cond_layer = torch.nn.Conv1d(kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1) + self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1) + self.cond_layer = weight_norm_modules(cond_layer, name="weight") + self.gin_channels = kwargs["gin_channels"] + + def forward(self, x, x_mask, g=None): + attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + x = x * x_mask + if g is not None: + g = self.cond_layer(g) + + for i in range(self.n_layers): + if g is not None: + x = self.cond_pre(x) + cond_offset = i * 2 * self.hidden_channels + g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :] + x = commons.fused_add_tanh_sigmoid_multiply(x, g_l, torch.IntTensor([self.hidden_channels])) + y = self.attn_layers[i](x, x, attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + + +class Decoder(nn.Module): + def __init__( + self, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size=1, + p_dropout=0.0, + proximal_bias=False, + proximal_init=True, + **kwargs, + ): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + 
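+ # Each decoder block applies masked (causal) self-attention, encoder-decoder
+ # cross-attention, and a causal FFN, each followed by dropout, a residual add and
+ # LayerNorm; the sub-layers are built in the loop below and used in that order in forward().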
self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.proximal_bias = proximal_bias + self.proximal_init = proximal_init + + self.drop = nn.Dropout(p_dropout) + self.self_attn_layers = nn.ModuleList() + self.norm_layers_0 = nn.ModuleList() + self.encdec_attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + for i in range(self.n_layers): + self.self_attn_layers.append( + MultiHeadAttention( + hidden_channels, + hidden_channels, + n_heads, + p_dropout=p_dropout, + proximal_bias=proximal_bias, + proximal_init=proximal_init, + ) + ) + self.norm_layers_0.append(LayerNorm(hidden_channels)) + self.encdec_attn_layers.append( + MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout) + ) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN( + hidden_channels, + hidden_channels, + filter_channels, + kernel_size, + p_dropout=p_dropout, + causal=True, + ) + ) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask, h, h_mask): + """ + x: decoder input + h: encoder output + """ + self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype) + encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + x = x * x_mask + for i in range(self.n_layers): + y = self.self_attn_layers[i](x, x, self_attn_mask) + y = self.drop(y) + x = self.norm_layers_0[i](x + y) + + y = self.encdec_attn_layers[i](x, h, encdec_attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + + +class MultiHeadAttention(nn.Module): + def __init__( + self, + channels, + out_channels, + n_heads, + p_dropout=0.0, + window_size=None, + heads_share=True, + block_length=None, + proximal_bias=False, + proximal_init=False, + ): + super().__init__() + assert channels % n_heads == 0 + + self.channels = channels + self.out_channels = out_channels + self.n_heads = n_heads + self.p_dropout = p_dropout + self.window_size = window_size + self.heads_share = heads_share + self.block_length = block_length + self.proximal_bias = proximal_bias + self.proximal_init = proximal_init + self.attn = None + + self.k_channels = channels // n_heads + self.conv_q = nn.Conv1d(channels, channels, 1) + self.conv_k = nn.Conv1d(channels, channels, 1) + self.conv_v = nn.Conv1d(channels, channels, 1) + self.conv_o = nn.Conv1d(channels, out_channels, 1) + self.drop = nn.Dropout(p_dropout) + + if window_size is not None: + n_heads_rel = 1 if heads_share else n_heads + rel_stddev = self.k_channels**-0.5 + self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) + self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) + + nn.init.xavier_uniform_(self.conv_q.weight) + nn.init.xavier_uniform_(self.conv_k.weight) + nn.init.xavier_uniform_(self.conv_v.weight) + if proximal_init: + with torch.no_grad(): + self.conv_k.weight.copy_(self.conv_q.weight) + self.conv_k.bias.copy_(self.conv_q.bias) + + def forward(self, x, c, attn_mask=None): + q = self.conv_q(x) + k = self.conv_k(c) + v = self.conv_v(c) + + x, self.attn = self.attention(q, k, v, mask=attn_mask) + + x = self.conv_o(x) + return x + + def attention(self, query, key, value, mask=None): + # reshape [b, d, t] -> [b, n_h, t, d_k] + b, d, t_s, t_t = 
(*key.size(), query.size(2)) + query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) + key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + + scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) + if self.window_size is not None: + assert t_s == t_t, "Relative attention is only available for self-attention." + key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) + rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings) + scores_local = self._relative_position_to_absolute_position(rel_logits) + scores = scores + scores_local + if self.proximal_bias: + assert t_s == t_t, "Proximal bias is only available for self-attention." + scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype) + if mask is not None: + scores = scores.masked_fill(mask == 0, -1e4) + if self.block_length is not None: + assert t_s == t_t, "Local attention is only available for self-attention." + block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length) + scores = scores.masked_fill(block_mask == 0, -1e4) + p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] + p_attn = self.drop(p_attn) + output = torch.matmul(p_attn, value) + if self.window_size is not None: + relative_weights = self._absolute_position_to_relative_position(p_attn) + value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) + output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings) + output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t] + return output, p_attn + + def _matmul_with_relative_values(self, x, y): + """ + x: [b, h, l, m] + y: [h or 1, m, d] + ret: [b, h, l, d] + """ + ret = torch.matmul(x, y.unsqueeze(0)) + return ret + + def _matmul_with_relative_keys(self, x, y): + """ + x: [b, h, l, d] + y: [h or 1, m, d] + ret: [b, h, l, m] + """ + ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) + return ret + + def _get_relative_embeddings(self, relative_embeddings, length): + max_relative_position = 2 * self.window_size + 1 + # Pad first before slice to avoid using cond ops. + pad_length = max(length - (self.window_size + 1), 0) + slice_start_position = max((self.window_size + 1) - length, 0) + slice_end_position = slice_start_position + 2 * length - 1 + if pad_length > 0: + padded_relative_embeddings = F.pad( + relative_embeddings, + commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]), + ) + else: + padded_relative_embeddings = relative_embeddings + used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position] + return used_relative_embeddings + + def _relative_position_to_absolute_position(self, x): + """ + x: [b, h, l, 2*l-1] + ret: [b, h, l, l] + """ + batch, heads, length, _ = x.size() + # Concat columns of pad to shift from relative to absolute indexing. + x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])) + + # Concat extra elements so to add up to shape (len+1, 2*len-1). + x_flat = x.view([batch, heads, length * 2 * length]) + x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])) + + # Reshape and slice out the padded elements. 
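+ # Viewing the flattened tensor as [length + 1, 2*length - 1] shifts each row by one
+ # relative offset, so rows [:length] taken from column length-1 onward give the
+ # absolute-position scores without gather or conditional ops. For example, with
+ # length = 3 the [3, 5] relative logits are padded to 20 elements, viewed as [4, 5],
+ # and rows [:3] from column 2 onward form the [3, 3] absolute matrix.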
+ x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :] + return x_final + + def _absolute_position_to_relative_position(self, x): + """ + x: [b, h, l, l] + ret: [b, h, l, 2*l-1] + """ + batch, heads, length, _ = x.size() + # padd along column + x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])) + x_flat = x.view([batch, heads, length**2 + length * (length - 1)]) + # add 0's in the beginning that will skew the elements after reshape + x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]])) + x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] + return x_final + + def _attention_bias_proximal(self, length): + """Bias for self-attention to encourage attention to close positions. + Args: + length: an integer scalar. + Returns: + a Tensor with shape [1, 1, length, length] + """ + r = torch.arange(length, dtype=torch.float32) + diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) + return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) + + +class FFN(nn.Module): + def __init__( + self, + in_channels, + out_channels, + filter_channels, + kernel_size, + p_dropout=0.0, + activation=None, + causal=False, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.activation = activation + self.causal = causal + + if causal: + self.padding = self._causal_padding + else: + self.padding = self._same_padding + + self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) + self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) + self.drop = nn.Dropout(p_dropout) + + def forward(self, x, x_mask): + x = self.conv_1(self.padding(x * x_mask)) + if self.activation == "gelu": + x = x * torch.sigmoid(1.702 * x) + else: + x = torch.relu(x) + x = self.drop(x) + x = self.conv_2(self.padding(x * x_mask)) + return x * x_mask + + def _causal_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = self.kernel_size - 1 + pad_r = 0 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, commons.convert_pad_shape(padding)) + return x + + def _same_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = (self.kernel_size - 1) // 2 + pad_r = self.kernel_size // 2 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, commons.convert_pad_shape(padding)) + return x + + +import torch.nn as nn +from torch.nn.utils import remove_weight_norm, weight_norm + + +class Depthwise_Separable_Conv1D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + bias=True, + padding_mode="zeros", # TODO: refine this type + device=None, + dtype=None, + ): + super().__init__() + self.depth_conv = nn.Conv1d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=kernel_size, + groups=in_channels, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias, + padding_mode=padding_mode, + device=device, + dtype=dtype, + ) + self.point_conv = nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + bias=bias, + device=device, + dtype=dtype, + ) + + def forward(self, input): + return self.point_conv(self.depth_conv(input)) + + def weight_norm(self): + self.depth_conv = weight_norm(self.depth_conv, name="weight") + self.point_conv = weight_norm(self.point_conv, name="weight") + + def 
remove_weight_norm(self): + self.depth_conv = remove_weight_norm(self.depth_conv, name="weight") + self.point_conv = remove_weight_norm(self.point_conv, name="weight") + + +class Depthwise_Separable_TransposeConv1D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + output_padding=0, + bias=True, + dilation=1, + padding_mode="zeros", # TODO: refine this type + device=None, + dtype=None, + ): + super().__init__() + self.depth_conv = nn.ConvTranspose1d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=kernel_size, + groups=in_channels, + stride=stride, + output_padding=output_padding, + padding=padding, + dilation=dilation, + bias=bias, + padding_mode=padding_mode, + device=device, + dtype=dtype, + ) + self.point_conv = nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + bias=bias, + device=device, + dtype=dtype, + ) + + def forward(self, input): + return self.point_conv(self.depth_conv(input)) + + def weight_norm(self): + self.depth_conv = weight_norm(self.depth_conv, name="weight") + self.point_conv = weight_norm(self.point_conv, name="weight") + + def remove_weight_norm(self): + remove_weight_norm(self.depth_conv, name="weight") + remove_weight_norm(self.point_conv, name="weight") + + +def weight_norm_modules(module, name="weight", dim=0): + if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(module, Depthwise_Separable_TransposeConv1D): + module.weight_norm() + return module + else: + return weight_norm(module, name, dim) + + +def remove_weight_norm_modules(module, name="weight"): + if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(module, Depthwise_Separable_TransposeConv1D): + module.remove_weight_norm() + else: + remove_weight_norm(module, name) + + +class FFT(nn.Module): + def __init__( + self, + hidden_channels, + filter_channels, + n_heads, + n_layers=1, + kernel_size=1, + p_dropout=0.0, + proximal_bias=False, + proximal_init=True, + isflow=False, + **kwargs, + ): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.proximal_bias = proximal_bias + self.proximal_init = proximal_init + if isflow: + cond_layer = torch.nn.Conv1d(kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1) + self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1) + self.cond_layer = weight_norm_modules(cond_layer, name="weight") + self.gin_channels = kwargs["gin_channels"] + self.drop = nn.Dropout(p_dropout) + self.self_attn_layers = nn.ModuleList() + self.norm_layers_0 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + for i in range(self.n_layers): + self.self_attn_layers.append( + MultiHeadAttention( + hidden_channels, + hidden_channels, + n_heads, + p_dropout=p_dropout, + proximal_bias=proximal_bias, + proximal_init=proximal_init, + ) + ) + self.norm_layers_0.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN( + hidden_channels, + hidden_channels, + filter_channels, + kernel_size, + p_dropout=p_dropout, + causal=True, + ) + ) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask, g=None): + """ + x: decoder input + h: encoder output + """ + if g is not None: + g = self.cond_layer(g) + + self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype) + x = x * 
x_mask + for i in range(self.n_layers): + if g is not None: + x = self.cond_pre(x) + cond_offset = i * 2 * self.hidden_channels + g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :] + x = commons.fused_add_tanh_sigmoid_multiply(x, g_l, torch.IntTensor([self.hidden_channels])) + y = self.self_attn_layers[i](x, x, self_attn_mask) + y = self.drop(y) + x = self.norm_layers_0[i](x + y) + + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + x = x * x_mask + return x + + +class TransformerCouplingLayer(nn.Module): + def __init__( + self, + channels, + hidden_channels, + kernel_size, + n_layers, + n_heads, + p_dropout=0, + filter_channels=0, + mean_only=False, + wn_sharing_parameter=None, + gin_channels=0, + ): + assert channels % 2 == 0, "channels should be divisible by 2" + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.half_channels = channels // 2 + self.mean_only = mean_only + + self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) + self.enc = ( + Encoder( + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + isflow=True, + gin_channels=gin_channels, + ) + if wn_sharing_parameter is None + else wn_sharing_parameter + ) + self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) + self.post.weight.data.zero_() + self.post.bias.data.zero_() + + def forward(self, x, x_mask, g=None, reverse=False): + x0, x1 = torch.split(x, [self.half_channels] * 2, 1) + h = self.pre(x0) * x_mask + h = self.enc(h, x_mask, g=g) + stats = self.post(h) * x_mask + if not self.mean_only: + m, logs = torch.split(stats, [self.half_channels] * 2, 1) + else: + m = stats + logs = torch.zeros_like(m) + + if not reverse: + x1 = m + x1 * torch.exp(logs) * x_mask + x = torch.cat([x0, x1], 1) + logdet = torch.sum(logs, [1, 2]) + return x, logdet + else: + x1 = (x1 - m) * torch.exp(-logs) * x_mask + x = torch.cat([x0, x1], 1) + return x diff --git a/module/attentions_onnx.py b/module/attentions_onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..9961f9870259d70a1e5cdb4b0bc1b1a9cb14f8c5 --- /dev/null +++ b/module/attentions_onnx.py @@ -0,0 +1,385 @@ +import math +import torch +from torch import nn +from torch.nn import functional as F + +from module import commons + +from typing import Optional + + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-5): + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = nn.Parameter(torch.ones(channels)) + self.beta = nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + x = x.transpose(1, -1) + x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) + return x.transpose(1, -1) + + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): + n_channels_int = n_channels[0] + in_act = input_a + input_b + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + return acts + + +class Encoder(nn.Module): + def __init__( + self, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size=1, + p_dropout=0.0, + window_size=4, + isflow=True, + **kwargs, + ): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout 
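+ # ONNX-export oriented copy of the Encoder in module/attentions.py: the speaker
+ # conditioning path of forward() is left commented out, attention masks are typed as
+ # Optional[torch.Tensor], and the relative-embedding padding below uses tensor ops
+ # instead of Python conditionals, presumably so tracing does not bake in
+ # data-dependent branches.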
+ self.window_size = window_size + # if isflow: + # cond_layer = torch.nn.Conv1d(256, 2*hidden_channels*n_layers, 1) + # self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1) + # self.cond_layer = weight_norm(cond_layer, name='weight') + # self.gin_channels = 256 + self.cond_layer_idx = self.n_layers + self.spk_emb_linear = nn.Linear(256, self.hidden_channels) + if "gin_channels" in kwargs: + self.gin_channels = kwargs["gin_channels"] + if self.gin_channels != 0: + self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels) + # vits2 says 3rd block, so idx is 2 by default + self.cond_layer_idx = kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2 + logging.debug(self.gin_channels, self.cond_layer_idx) + assert self.cond_layer_idx < self.n_layers, "cond_layer_idx should be less than n_layers" + self.drop = nn.Dropout(p_dropout) + self.attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + for i in range(self.n_layers): + self.attn_layers.append( + MultiHeadAttention( + hidden_channels, + hidden_channels, + n_heads, + p_dropout=p_dropout, + window_size=window_size, + ) + ) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN( + hidden_channels, + hidden_channels, + filter_channels, + kernel_size, + p_dropout=p_dropout, + ) + ) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + # def forward(self, x, x_mask, g=None): + # attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + # x = x * x_mask + # for i in range(self.n_layers): + # if i == self.cond_layer_idx and g is not None: + # g = self.spk_emb_linear(g.transpose(1, 2)) + # g = g.transpose(1, 2) + # x = x + g + # x = x * x_mask + # y = self.attn_layers[i](x, x, attn_mask) + # y = self.drop(y) + # x = self.norm_layers_1[i](x + y) + + # y = self.ffn_layers[i](x, x_mask) + # y = self.drop(y) + # x = self.norm_layers_2[i](x + y) + # x = x * x_mask + # return x + + def forward(self, x, x_mask): + attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + x = x * x_mask + for attn_layers, norm_layers_1, ffn_layers, norm_layers_2 in zip( + self.attn_layers, self.norm_layers_1, self.ffn_layers, self.norm_layers_2 + ): + y = attn_layers(x, x, attn_mask) + y = self.drop(y) + x = norm_layers_1(x + y) + + y = ffn_layers(x, x_mask) + y = self.drop(y) + x = norm_layers_2(x + y) + x = x * x_mask + return x + + +class MultiHeadAttention(nn.Module): + def __init__( + self, + channels, + out_channels, + n_heads, + p_dropout=0.0, + window_size=None, + heads_share=True, + block_length=None, + proximal_bias=False, + proximal_init=False, + ): + super().__init__() + assert channels % n_heads == 0 + + self.channels = channels + self.out_channels = out_channels + self.n_heads = n_heads + self.p_dropout = p_dropout + self.window_size = window_size + self.heads_share = heads_share + self.block_length = block_length + self.proximal_bias = proximal_bias + self.proximal_init = proximal_init + self.attn = None + + self.k_channels = channels // n_heads + self.conv_q = nn.Conv1d(channels, channels, 1) + self.conv_k = nn.Conv1d(channels, channels, 1) + self.conv_v = nn.Conv1d(channels, channels, 1) + self.conv_o = nn.Conv1d(channels, out_channels, 1) + self.drop = nn.Dropout(p_dropout) + + if window_size is not None: + n_heads_rel = 1 if heads_share else n_heads + rel_stddev = self.k_channels**-0.5 + self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * 
rel_stddev) + self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) + + nn.init.xavier_uniform_(self.conv_q.weight) + nn.init.xavier_uniform_(self.conv_k.weight) + nn.init.xavier_uniform_(self.conv_v.weight) + if proximal_init: + with torch.no_grad(): + self.conv_k.weight.copy_(self.conv_q.weight) + self.conv_k.bias.copy_(self.conv_q.bias) + + def forward(self, x, c, attn_mask: Optional[torch.Tensor] = None): + q = self.conv_q(x) + k = self.conv_k(c) + v = self.conv_v(c) + + # x, self.attn = self.attention(q, k, v, mask=attn_mask) + x, _ = self.attention(q, k, v, mask=attn_mask) + + x = self.conv_o(x) + return x + + def attention(self, query, key, value, mask: Optional[torch.Tensor] = None): + # reshape [b, d, t] -> [b, n_h, t, d_k] + b, d, t_s, _ = (*key.size(), query.size(2)) + query = query.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3) + key = key.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3) + value = value.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3) + scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) + + if self.window_size is not None: + key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) + rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings) + scores_local = self._relative_position_to_absolute_position(rel_logits) + scores = scores + scores_local + + if mask is not None: + scores = scores.masked_fill(mask == 0, -1e4) + + p_attn = F.softmax(scores, dim=-1) + p_attn = self.drop(p_attn) + output = torch.matmul(p_attn, value) + + if self.window_size is not None: + relative_weights = self._absolute_position_to_relative_position(p_attn) + value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) + output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings) + + output = output.transpose(2, 3).contiguous().view(b, d, -1) + return output, p_attn + + def _matmul_with_relative_values(self, x, y): + """ + x: [b, h, l, m] + y: [h or 1, m, d] + ret: [b, h, l, d] + """ + ret = torch.matmul(x, y.unsqueeze(0)) + return ret + + def _matmul_with_relative_keys(self, x, y): + """ + x: [b, h, l, d] + y: [h or 1, m, d] + ret: [b, h, l, m] + """ + ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) + return ret + + def _get_relative_embeddings(self, relative_embeddings, length): + max_relative_position = 2 * self.window_size + 1 + # Pad first before slice to avoid using cond ops. + pad_l = torch.zeros((1), dtype=torch.int64) + length - (self.window_size + 1) + pad_s = torch.zeros((1), dtype=torch.int64) + (self.window_size + 1) - length + pad_length = torch.max(pad_l, other=torch.zeros((1), dtype=torch.int64)) + slice_start_position = torch.max(pad_s, other=torch.zeros((1), dtype=torch.int64)) + + slice_end_position = slice_start_position + 2 * length - 1 + padded_relative_embeddings = F.pad( + relative_embeddings, + commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]), + ) + used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position] + return used_relative_embeddings + + def _relative_position_to_absolute_position(self, x): + """ + x: [b, h, l, 2*l-1] + ret: [b, h, l, l] + """ + batch, heads, length, _ = x.size() + # Concat columns of pad to shift from relative to absolute indexing. 
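+        # Padding one extra column, flattening, padding length-1 more entries and
+        # reshaping to (length + 1, 2*length - 1) aligns the relative logits so that
+        # slicing rows [:length] from column length-1 onward yields the absolute
+        # [b, h, l, l] score matrix.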
+ x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])) + + # Concat extra elements so to add up to shape (len+1, 2*len-1). + x_flat = x.view([batch, heads, length * 2 * length]) + x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])) + + # Reshape and slice out the padded elements. + x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :] + return x_final + + def _absolute_position_to_relative_position(self, x): + """ + x: [b, h, l, l] + ret: [b, h, l, 2*l-1] + """ + batch, heads, length, _ = x.size() + # padd along column + x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])) + x_flat = x.view([batch, heads, length**2 + length * (length - 1)]) + # add 0's in the beginning that will skew the elements after reshape + x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]])) + x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] + return x_final + + def _attention_bias_proximal(self, length): + """Bias for self-attention to encourage attention to close positions. + Args: + length: an integer scalar. + Returns: + a Tensor with shape [1, 1, length, length] + """ + r = torch.arange(length, dtype=torch.float32) + diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) + return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) + + +class FFN(nn.Module): + def __init__( + self, + in_channels, + out_channels, + filter_channels, + kernel_size, + p_dropout=0.0, + activation="", + causal=False, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.activation = activation + self.causal = causal + + # 从上下文看这里一定是 False + # if causal: + # self.padding = self._causal_padding + # else: + # self.padding = self._same_padding + + self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) + self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) + self.drop = nn.Dropout(p_dropout) + + def forward(self, x, x_mask): + x = self.conv_1(self.padding(x * x_mask)) + if self.activation == "gelu": + x = x * torch.sigmoid(1.702 * x) + else: + x = torch.relu(x) + x = self.drop(x) + x = self.conv_2(self.padding(x * x_mask)) + return x * x_mask + + def padding(self, x): + return self._same_padding(x) + + def _causal_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = self.kernel_size - 1 + pad_r = 0 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, commons.convert_pad_shape(padding)) + return x + + def _same_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = (self.kernel_size - 1) // 2 + pad_r = self.kernel_size // 2 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, commons.convert_pad_shape(padding)) + return x + + +class MRTE(nn.Module): + def __init__( + self, + content_enc_channels=192, + hidden_size=512, + out_channels=192, + kernel_size=5, + n_heads=4, + ge_layer=2, + ): + super(MRTE, self).__init__() + self.cross_attention = MultiHeadAttention(hidden_size, hidden_size, n_heads) + self.c_pre = nn.Conv1d(content_enc_channels, hidden_size, 1) + self.text_pre = nn.Conv1d(content_enc_channels, hidden_size, 1) + self.c_post = nn.Conv1d(hidden_size, out_channels, 1) + + def forward(self, ssl_enc, ssl_mask, text, text_mask, ge): + attn_mask = text_mask.unsqueeze(2) * ssl_mask.unsqueeze(-1) + + ssl_enc = self.c_pre(ssl_enc * 
ssl_mask) + text_enc = self.text_pre(text * text_mask) + x = self.cross_attention(ssl_enc * ssl_mask, text_enc * text_mask, attn_mask) + ssl_enc + ge + x = self.c_post(x * ssl_mask) + return x diff --git a/module/commons.py b/module/commons.py new file mode 100644 index 0000000000000000000000000000000000000000..20392f91fdd7632ebcbc508222500c37c1de5140 --- /dev/null +++ b/module/commons.py @@ -0,0 +1,185 @@ +import math +import torch +from torch.nn import functional as F + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +# def convert_pad_shape(pad_shape): +# l = pad_shape[::-1] +# pad_shape = [item for sublist in l for item in sublist] +# return pad_shape + + +def intersperse(lst, item): + result = [item] * (len(lst) * 2 + 1) + result[1::2] = lst + return result + + +def kl_divergence(m_p, logs_p, m_q, logs_q): + """KL(P||Q)""" + kl = (logs_q - logs_p) - 0.5 + kl += 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q) + return kl + + +def rand_gumbel(shape): + """Sample from the Gumbel distribution, protect from overflows.""" + uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 + return -torch.log(-torch.log(uniform_samples)) + + +def rand_gumbel_like(x): + g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) + return g + + +def slice_segments(x, ids_str, segment_size=4): + ret = torch.zeros_like(x[:, :, :segment_size]) + for i in range(x.size(0)): + idx_str = ids_str[i] + idx_end = idx_str + segment_size + ret[i] = x[i, :, idx_str:idx_end] + return ret + + +def rand_slice_segments(x, x_lengths=None, segment_size=4): + b, d, t = x.size() + if x_lengths is None: + x_lengths = t + ids_str_max = x_lengths - segment_size + 1 + ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) + ret = slice_segments(x, ids_str, segment_size) + return ret, ids_str + + +def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4): + position = torch.arange(length, dtype=torch.float) + num_timescales = channels // 2 + log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (num_timescales - 1) + inv_timescales = min_timescale * torch.exp( + torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment + ) + scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) + signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) + signal = F.pad(signal, [0, 0, 0, channels % 2]) + signal = signal.view(1, channels, length) + return signal + + +def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return x + signal.to(dtype=x.dtype, device=x.device) + + +def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) + + +def subsequent_mask(length): + mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) + return mask + + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): + n_channels_int = n_channels[0] + in_act = input_a + input_b + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act 
= torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + return acts + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def shift_1d(x): + x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] + return x + + +def sequence_mask(length, max_length=None): + if max_length is None: + max_length = length.max() + x = torch.arange(max_length, dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + + +def generate_path(duration, mask): + """ + duration: [b, 1, t_x] + mask: [b, 1, t_y, t_x] + """ + device = duration.device + + b, _, t_y, t_x = mask.shape + cum_duration = torch.cumsum(duration, -1) + + cum_duration_flat = cum_duration.view(b * t_x) + path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) + path = path.view(b, t_x, t_y) + path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] + path = path.unsqueeze(1).transpose(2, 3) * mask + return path + + +def clip_grad_value_(parameters, clip_value, norm_type=2): + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = list(filter(lambda p: p.grad is not None, parameters)) + norm_type = float(norm_type) + if clip_value is not None: + clip_value = float(clip_value) + + total_norm = 0 + for p in parameters: + param_norm = p.grad.data.norm(norm_type) + total_norm += param_norm.item() ** norm_type + if clip_value is not None: + p.grad.data.clamp_(min=-clip_value, max=clip_value) + total_norm = total_norm ** (1.0 / norm_type) + return total_norm + + +def squeeze(x, x_mask=None, n_sqz=2): + b, c, t = x.size() + + t = (t // n_sqz) * n_sqz + x = x[:, :, :t] + x_sqz = x.view(b, c, t // n_sqz, n_sqz) + x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * n_sqz, t // n_sqz) + + if x_mask is not None: + x_mask = x_mask[:, :, n_sqz - 1 :: n_sqz] + else: + x_mask = torch.ones(b, 1, t // n_sqz).to(device=x.device, dtype=x.dtype) + return x_sqz * x_mask, x_mask + + +def unsqueeze(x, x_mask=None, n_sqz=2): + b, c, t = x.size() + + x_unsqz = x.view(b, n_sqz, c // n_sqz, t) + x_unsqz = x_unsqz.permute(0, 2, 3, 1).contiguous().view(b, c // n_sqz, t * n_sqz) + + if x_mask is not None: + x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, n_sqz).view(b, 1, t * n_sqz) + else: + x_mask = torch.ones(b, 1, t * n_sqz).to(device=x.device, dtype=x.dtype) + return x_unsqz * x_mask, x_mask diff --git a/module/core_vq.py b/module/core_vq.py new file mode 100644 index 0000000000000000000000000000000000000000..b7dab3171814598660b49aff46119f77dd1ab67e --- /dev/null +++ b/module/core_vq.py @@ -0,0 +1,365 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +# This implementation is inspired from +# https://github.com/lucidrains/vector-quantize-pytorch +# which is released under MIT License. 
Hereafter, the original license: +# MIT License +# +# Copyright (c) 2020 Phil Wang +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Core vector quantization implementation.""" + +import typing as tp + +from einops import rearrange, repeat +import torch +from torch import nn +import torch.nn.functional as F +from tqdm import tqdm + + +def default(val: tp.Any, d: tp.Any) -> tp.Any: + return val if val is not None else d + + +def ema_inplace(moving_avg, new, decay: float): + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + + +def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5): + return (x + epsilon) / (x.sum() + n_categories * epsilon) + + +def uniform_init(*shape: int): + t = torch.empty(shape) + nn.init.kaiming_uniform_(t) + return t + + +def sample_vectors(samples, num: int): + num_samples, device = samples.shape[0], samples.device + + if num_samples >= num: + indices = torch.randperm(num_samples, device=device)[:num] + else: + indices = torch.randint(0, num_samples, (num,), device=device) + + return samples[indices] + + +def kmeans(samples, num_clusters: int, num_iters: int = 10): + dim, dtype = samples.shape[-1], samples.dtype + max_kmeans_samples = 500 + samples = samples[:max_kmeans_samples, :] + means = sample_vectors(samples, num_clusters) + + print("kmeans start ... ") + for _ in tqdm(range(num_iters)): + diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d") + dists = -(diffs**2).sum(dim=-1) + + buckets = dists.max(dim=-1).indices + bins = torch.bincount(buckets, minlength=num_clusters) + zero_mask = bins == 0 + bins_min_clamped = bins.masked_fill(zero_mask, 1) + + new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) + new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples) + new_means = new_means / bins_min_clamped[..., None] + + means = torch.where(zero_mask[..., None], means, new_means) + + return means, bins + + +class EuclideanCodebook(nn.Module): + """Codebook with Euclidean distance. + Args: + dim (int): Dimension. + codebook_size (int): Codebook size. + kmeans_init (bool): Whether to use k-means to initialize the codebooks. + If set to true, run the k-means algorithm on the first training batch and use + the learned centroids as initialization. + kmeans_iters (int): Number of iterations used for k-means algorithm at initialization. + decay (float): Decay for exponential moving average over the codebooks. + epsilon (float): Epsilon value for numerical stability. 
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. + """ + + def __init__( + self, + dim: int, + codebook_size: int, + kmeans_init: int = False, + kmeans_iters: int = 10, + decay: float = 0.99, + epsilon: float = 1e-5, + threshold_ema_dead_code: int = 2, + ): + super().__init__() + self.decay = decay + init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros + embed = init_fn(codebook_size, dim) + + self.codebook_size = codebook_size + + self.kmeans_iters = kmeans_iters + self.epsilon = epsilon + self.threshold_ema_dead_code = threshold_ema_dead_code + + self.register_buffer("inited", torch.Tensor([not kmeans_init])) + self.register_buffer("cluster_size", torch.zeros(codebook_size)) + self.register_buffer("embed", embed) + self.register_buffer("embed_avg", embed.clone()) + + @torch.jit.ignore + def init_embed_(self, data): + if self.inited: + return + + embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters) + self.embed.data.copy_(embed) + self.embed_avg.data.copy_(embed.clone()) + self.cluster_size.data.copy_(cluster_size) + self.inited.data.copy_(torch.Tensor([True])) + # Make sure all buffers across workers are in sync after initialization + # broadcast_tensors(self.buffers()) + + def replace_(self, samples, mask): + modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed) + self.embed.data.copy_(modified_codebook) + + def expire_codes_(self, batch_samples): + if self.threshold_ema_dead_code == 0: + return + + expired_codes = self.cluster_size < self.threshold_ema_dead_code + if not torch.any(expired_codes): + return + + batch_samples = rearrange(batch_samples, "... d -> (...) d") + self.replace_(batch_samples, mask=expired_codes) + # broadcast_tensors(self.buffers()) + + def preprocess(self, x): + x = rearrange(x, "... d -> (...) d") + return x + + def quantize(self, x): + embed = self.embed.t() + dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True)) + embed_ind = dist.max(dim=-1).indices + return embed_ind + + def postprocess_emb(self, embed_ind, shape): + return embed_ind.view(*shape[:-1]) + + def dequantize(self, embed_ind): + quantize = F.embedding(embed_ind, self.embed) + return quantize + + def encode(self, x): + shape = x.shape + # pre-process + x = self.preprocess(x) + # quantize + embed_ind = self.quantize(x) + # post-process + embed_ind = self.postprocess_emb(embed_ind, shape) + return embed_ind + + def decode(self, embed_ind): + quantize = self.dequantize(embed_ind) + return quantize + + def forward(self, x): + shape, dtype = x.shape, x.dtype + x = self.preprocess(x) + + self.init_embed_(x) + + embed_ind = self.quantize(x) + embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) + embed_ind = self.postprocess_emb(embed_ind, shape) + quantize = self.dequantize(embed_ind) + + if self.training: + # We do the expiry of code at that point as buffers are in sync + # and all the workers will take the same decision. 
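+            # expire_codes_ resamples any code whose EMA usage fell below the
+            # dead-code threshold; afterwards the running statistics are refreshed:
+            # cluster_size and embed_avg track exponential moving averages of code
+            # usage and of the inputs assigned to each code, with Laplace smoothing
+            # keeping the normalisation stable for rarely used entries.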
+ self.expire_codes_(x) + ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay) + embed_sum = x.t() @ embed_onehot + ema_inplace(self.embed_avg, embed_sum.t(), self.decay) + cluster_size = ( + laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) * self.cluster_size.sum() + ) + embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) + self.embed.data.copy_(embed_normalized) + + return quantize, embed_ind + + +class VectorQuantization(nn.Module): + """Vector quantization implementation. + Currently supports only euclidean distance. + Args: + dim (int): Dimension + codebook_size (int): Codebook size + codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim. + decay (float): Decay for exponential moving average over the codebooks. + epsilon (float): Epsilon value for numerical stability. + kmeans_init (bool): Whether to use kmeans to initialize the codebooks. + kmeans_iters (int): Number of iterations used for kmeans initialization. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. + commitment_weight (float): Weight for commitment loss. + """ + + def __init__( + self, + dim: int, + codebook_size: int, + codebook_dim: tp.Optional[int] = None, + decay: float = 0.99, + epsilon: float = 1e-5, + kmeans_init: bool = True, + kmeans_iters: int = 50, + threshold_ema_dead_code: int = 2, + commitment_weight: float = 1.0, + ): + super().__init__() + _codebook_dim: int = default(codebook_dim, dim) + + requires_projection = _codebook_dim != dim + self.project_in = nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity() + self.project_out = nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity() + + self.epsilon = epsilon + self.commitment_weight = commitment_weight + + self._codebook = EuclideanCodebook( + dim=_codebook_dim, + codebook_size=codebook_size, + kmeans_init=kmeans_init, + kmeans_iters=kmeans_iters, + decay=decay, + epsilon=epsilon, + threshold_ema_dead_code=threshold_ema_dead_code, + ) + self.codebook_size = codebook_size + + @property + def codebook(self): + return self._codebook.embed + + def encode(self, x): + x = rearrange(x, "b d n -> b n d") + x = self.project_in(x) + embed_in = self._codebook.encode(x) + return embed_in + + def decode(self, embed_ind): + quantize = self._codebook.decode(embed_ind) + quantize = self.project_out(quantize) + quantize = rearrange(quantize, "b n d -> b d n") + return quantize + + def forward(self, x): + device = x.device + x = rearrange(x, "b d n -> b n d") + x = self.project_in(x) + + quantize, embed_ind = self._codebook(x) + + if self.training: + quantize = x + (quantize - x).detach() + + loss = torch.tensor([0.0], device=device, requires_grad=self.training) + + if self.training: + if self.commitment_weight > 0: + commit_loss = F.mse_loss(quantize.detach(), x) + loss = loss + commit_loss * self.commitment_weight + + quantize = self.project_out(quantize) + quantize = rearrange(quantize, "b n d -> b d n") + return quantize, embed_ind, loss + + +class ResidualVectorQuantization(nn.Module): + """Residual vector quantization implementation. + Follows Algorithm 1. 
in https://arxiv.org/pdf/2107.03312.pdf + """ + + def __init__(self, *, num_quantizers, **kwargs): + super().__init__() + self.layers = nn.ModuleList([VectorQuantization(**kwargs) for _ in range(num_quantizers)]) + + def forward(self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None): + quantized_out = 0.0 + residual = x + + all_losses = [] + all_indices = [] + out_quantized = [] + + n_q = n_q or len(self.layers) + + for i, layer in enumerate(self.layers[:n_q]): + quantized, indices, loss = layer(residual) + residual = residual - quantized + quantized_out = quantized_out + quantized + + all_indices.append(indices) + all_losses.append(loss) + if layers and i in layers: + out_quantized.append(quantized) + + out_losses, out_indices = map(torch.stack, (all_losses, all_indices)) + return quantized_out, out_indices, out_losses, out_quantized + + def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None) -> torch.Tensor: + residual = x + all_indices = [] + n_q = n_q or len(self.layers) + st = st or 0 + for layer in self.layers[st:n_q]: + indices = layer.encode(residual) + quantized = layer.decode(indices) + residual = residual - quantized + all_indices.append(indices) + out_indices = torch.stack(all_indices) + return out_indices + + def decode(self, q_indices: torch.Tensor, st: int = 0) -> torch.Tensor: + quantized_out = torch.tensor(0.0, device=q_indices.device) + for i, indices in enumerate(q_indices): + layer = self.layers[st + i] + quantized = layer.decode(indices) + quantized_out = quantized_out + quantized + return quantized_out diff --git a/module/data_utils.py b/module/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..46eff5fb9427ced83f6d39f0e891c67867c64e97 --- /dev/null +++ b/module/data_utils.py @@ -0,0 +1,1071 @@ +import os +import random +import traceback +import torch +import torch.utils.data +from tqdm import tqdm + +from module.mel_processing import spectrogram_torch, spec_to_mel_torch +from text import cleaned_text_to_sequence +import torch.nn.functional as F +from tools.my_utils import load_audio + +version = os.environ.get("version", None) + + +# ZeroDivisionError fixed by Tybost (https://github.com/RVC-Boss/GPT-SoVITS/issues/79) +class TextAudioSpeakerLoader(torch.utils.data.Dataset): + """ + 1) loads audio, speaker_id, text pairs + 2) normalizes text and converts them to sequences of integers + 3) computes spectrograms from audio files. 
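+    Expects hparams.exp_dir to contain: 2-name2text.txt (phoneme labels),
+    4-cnhubert/ (SSL features saved as .pt), 5-wav32k/ (32 kHz audio) and, for the
+    v2Pro variants, 7-sv_cn/ (per-utterance sv embedding .pt files). Datasets with
+    fewer than ~100 usable items are repeated so that batches can always be formed.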
+ """ + + def __init__(self, hparams, version=None, val=False): + exp_dir = hparams.exp_dir + self.path2 = "%s/2-name2text.txt" % exp_dir + self.path4 = "%s/4-cnhubert" % exp_dir + self.path5 = "%s/5-wav32k" % exp_dir + assert os.path.exists(self.path2) + assert os.path.exists(self.path4) + assert os.path.exists(self.path5) + self.is_v2Pro = version in {"v2Pro", "v2ProPlus"} + if self.is_v2Pro: + self.path7 = "%s/7-sv_cn" % exp_dir + assert os.path.exists(self.path7) + names4 = set([name[:-3] for name in list(os.listdir(self.path4))]) # 去除.pt后缀 + names5 = set(os.listdir(self.path5)) + if self.is_v2Pro: + names6 = set([name[:-3] for name in list(os.listdir(self.path7))]) # 去除.pt后缀 + self.phoneme_data = {} + with open(self.path2, "r", encoding="utf8") as f: + lines = f.read().strip("\n").split("\n") + + for line in lines: + tmp = line.split("\t") + if len(tmp) != 4: + continue + self.phoneme_data[tmp[0]] = [tmp[1]] + if self.is_v2Pro: + self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5 & names6) + else: + self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5) + tmp = self.audiopaths_sid_text + leng = len(tmp) + min_num = 100 + if leng < min_num: + self.audiopaths_sid_text = [] + for _ in range(max(2, int(min_num / leng))): + self.audiopaths_sid_text += tmp + self.max_wav_value = hparams.max_wav_value + self.sampling_rate = hparams.sampling_rate + self.filter_length = hparams.filter_length + self.hop_length = hparams.hop_length + self.win_length = hparams.win_length + self.sampling_rate = hparams.sampling_rate + self.val = val + + random.seed(1234) + random.shuffle(self.audiopaths_sid_text) + + print("phoneme_data_len:", len(self.phoneme_data.keys())) + print("wav_data_len:", len(self.audiopaths_sid_text)) + + audiopaths_sid_text_new = [] + lengths = [] + skipped_phone = 0 + skipped_dur = 0 + for audiopath in tqdm(self.audiopaths_sid_text): + try: + phoneme = self.phoneme_data[audiopath][0] + phoneme = phoneme.split(" ") + phoneme_ids = cleaned_text_to_sequence(phoneme, version) + except Exception: + print(f"{audiopath} not in self.phoneme_data !") + skipped_phone += 1 + continue + + size = os.path.getsize("%s/%s" % (self.path5, audiopath)) + duration = size / self.sampling_rate / 2 + + if duration == 0: + print(f"Zero duration for {audiopath}, skipping...") + skipped_dur += 1 + continue + + if 54 > duration > 0.6 or self.val: + audiopaths_sid_text_new.append([audiopath, phoneme_ids]) + lengths.append(size // (2 * self.hop_length)) + else: + skipped_dur += 1 + continue + + print("skipped_phone: ", skipped_phone, ", skipped_dur: ", skipped_dur) + print("total left: ", len(audiopaths_sid_text_new)) + assert len(audiopaths_sid_text_new) > 1 # 至少能凑够batch size,这里todo + self.audiopaths_sid_text = audiopaths_sid_text_new + self.lengths = lengths + + def get_audio_text_speaker_pair(self, audiopath_sid_text): + audiopath, phoneme_ids = audiopath_sid_text + text = torch.FloatTensor(phoneme_ids) + try: + spec, wav = self.get_audio("%s/%s" % (self.path5, audiopath)) + with torch.no_grad(): + ssl = torch.load("%s/%s.pt" % (self.path4, audiopath), map_location="cpu") + if ssl.shape[-1] != spec.shape[-1]: + typee = ssl.dtype + ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee) + ssl.requires_grad = False + if self.is_v2Pro: + sv_emb = torch.load("%s/%s.pt" % (self.path7, audiopath), map_location="cpu") + except: + traceback.print_exc() + spec = torch.zeros(1025, 100) + wav = torch.zeros(1, 100 * self.hop_length) + ssl = torch.zeros(1, 768, 100) + text = 
text[-1:] + if self.is_v2Pro: + sv_emb = torch.zeros(1, 20480) + print("load audio or ssl error!!!!!!", audiopath) + if self.is_v2Pro: + return (ssl, spec, wav, text, sv_emb) + else: + return (ssl, spec, wav, text) + + def get_audio(self, filename): + audio_array = load_audio(filename, self.sampling_rate) # load_audio的方法是已经归一化到-1~1之间的,不用再/32768 + audio = torch.FloatTensor(audio_array) # /32768 + audio_norm = audio + audio_norm = audio_norm.unsqueeze(0) + spec = spectrogram_torch( + audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length, center=False + ) + spec = torch.squeeze(spec, 0) + return spec, audio_norm + + def get_sid(self, sid): + sid = torch.LongTensor([int(sid)]) + return sid + + def __getitem__(self, index): + # with torch.no_grad(): + return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index]) + + def __len__(self): + return len(self.audiopaths_sid_text) + + def random_slice(self, ssl, wav, mel): + assert abs(ssl.shape[-1] - wav.shape[-1] // self.hop_length) < 3, ("first", ssl.shape, wav.shape) + + len_mel = mel.shape[1] + if self.val: + reference_mel = mel[:, : len_mel // 3] + return reference_mel, ssl, wav, mel + dir = random.randint(0, 1) + sep_point = random.randint(int(len_mel // 3), int(len_mel // 3 * 2)) + + if dir == 0: + reference_mel = mel[:, :sep_point] + ssl = ssl[:, :, sep_point:] + wav2 = wav[:, sep_point * self.hop_length :] + mel = mel[:, sep_point:] + else: + reference_mel = mel[:, sep_point:] + ssl = ssl[:, :, :sep_point] + wav2 = wav[:, : sep_point * self.hop_length] + mel = mel[:, :sep_point] + + assert abs(ssl.shape[-1] - wav2.shape[-1] // self.hop_length) < 3, ( + ssl.shape, + wav.shape, + wav2.shape, + mel.shape, + sep_point, + self.hop_length, + sep_point * self.hop_length, + dir, + ) + return reference_mel, ssl, wav2, mel + + +class TextAudioSpeakerCollate: + """Zero-pads model inputs and targets""" + + def __init__(self, return_ids=False, version=None): + self.return_ids = return_ids + self.is_v2Pro = version in {"v2Pro", "v2ProPlus"} + + def __call__(self, batch): + """Collate's training batch from normalized text, audio and speaker identities + PARAMS + ------ + batch: [text_normalized, spec_normalized, wav_normalized, sid] + """ + # Right zero-pad all one-hot text sequences to max input length + _, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True) + + max_ssl_len = max([x[0].size(2) for x in batch]) + max_ssl_len = int(2 * ((max_ssl_len // 2) + 1)) + max_spec_len = max([x[1].size(1) for x in batch]) + max_spec_len = int(2 * ((max_spec_len // 2) + 1)) + max_wav_len = max([x[2].size(1) for x in batch]) + max_text_len = max([x[3].size(0) for x in batch]) + + ssl_lengths = torch.LongTensor(len(batch)) + spec_lengths = torch.LongTensor(len(batch)) + wav_lengths = torch.LongTensor(len(batch)) + text_lengths = torch.LongTensor(len(batch)) + + spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len) + wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len) + ssl_padded = torch.FloatTensor(len(batch), batch[0][0].size(1), max_ssl_len) + text_padded = torch.LongTensor(len(batch), max_text_len) + + spec_padded.zero_() + wav_padded.zero_() + ssl_padded.zero_() + text_padded.zero_() + + if self.is_v2Pro: + sv_embs = torch.FloatTensor(len(batch), 20480) + + for i in range(len(ids_sorted_decreasing)): + row = batch[ids_sorted_decreasing[i]] + + ssl = row[0] + ssl_padded[i, :, : ssl.size(2)] = ssl[0, :, :] + ssl_lengths[i] = ssl.size(2) + + 
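+            # spec, wav and text rows are copied into their zero-initialised padded
+            # buffers the same way; entries were sorted by spec length above, so the
+            # longest sample comes first.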
spec = row[1] + spec_padded[i, :, : spec.size(1)] = spec + spec_lengths[i] = spec.size(1) + + wav = row[2] + wav_padded[i, :, : wav.size(1)] = wav + wav_lengths[i] = wav.size(1) + + text = row[3] + text_padded[i, : text.size(0)] = text + text_lengths[i] = text.size(0) + + if self.is_v2Pro: + sv_embs[i] = row[4] + if self.is_v2Pro: + return ( + ssl_padded, + ssl_lengths, + spec_padded, + spec_lengths, + wav_padded, + wav_lengths, + text_padded, + text_lengths, + sv_embs, + ) + else: + return ( + ssl_padded, + ssl_lengths, + spec_padded, + spec_lengths, + wav_padded, + wav_lengths, + text_padded, + text_lengths, + ) + + +class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset): + """ + 1) loads audio, speaker_id, text pairs + 2) normalizes text and converts them to sequences of integers + 3) computes spectrograms from audio files. + """ + + def __init__(self, hparams, val=False): + exp_dir = hparams.exp_dir + self.path2 = "%s/2-name2text.txt" % exp_dir + self.path4 = "%s/4-cnhubert" % exp_dir + self.path5 = "%s/5-wav32k" % exp_dir + assert os.path.exists(self.path2) + assert os.path.exists(self.path4) + assert os.path.exists(self.path5) + names4 = set([name[:-3] for name in list(os.listdir(self.path4))]) # 去除.pt后缀 + names5 = set(os.listdir(self.path5)) + self.phoneme_data = {} + with open(self.path2, "r", encoding="utf8") as f: + lines = f.read().strip("\n").split("\n") + + for line in lines: + tmp = line.split("\t") + if len(tmp) != 4: + continue + self.phoneme_data[tmp[0]] = [tmp[1]] + + self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5) + tmp = self.audiopaths_sid_text + leng = len(tmp) + min_num = 100 + if leng < min_num: + self.audiopaths_sid_text = [] + for _ in range(max(2, int(min_num / leng))): + self.audiopaths_sid_text += tmp + self.max_wav_value = hparams.max_wav_value + self.sampling_rate = hparams.sampling_rate + self.filter_length = hparams.filter_length + self.hop_length = hparams.hop_length + self.win_length = hparams.win_length + self.sampling_rate = hparams.sampling_rate + self.val = val + + random.seed(1234) + random.shuffle(self.audiopaths_sid_text) + + print("phoneme_data_len:", len(self.phoneme_data.keys())) + print("wav_data_len:", len(self.audiopaths_sid_text)) + + audiopaths_sid_text_new = [] + lengths = [] + skipped_phone = 0 + skipped_dur = 0 + for audiopath in tqdm(self.audiopaths_sid_text): + try: + phoneme = self.phoneme_data[audiopath][0] + phoneme = phoneme.split(" ") + phoneme_ids = cleaned_text_to_sequence(phoneme, version) + except Exception: + print(f"{audiopath} not in self.phoneme_data !") + skipped_phone += 1 + continue + + size = os.path.getsize("%s/%s" % (self.path5, audiopath)) + duration = size / self.sampling_rate / 2 + + if duration == 0: + print(f"Zero duration for {audiopath}, skipping...") + skipped_dur += 1 + continue + + if 54 > duration > 0.6 or self.val: + audiopaths_sid_text_new.append([audiopath, phoneme_ids]) + lengths.append(size // (2 * self.hop_length)) + else: + skipped_dur += 1 + continue + + print("skipped_phone: ", skipped_phone, ", skipped_dur: ", skipped_dur) + print("total left: ", len(audiopaths_sid_text_new)) + assert len(audiopaths_sid_text_new) > 1 # 至少能凑够batch size,这里todo + self.audiopaths_sid_text = audiopaths_sid_text_new + self.lengths = lengths + self.spec_min = -12 + self.spec_max = 2 + + self.filter_length_mel = self.win_length_mel = 1024 + self.hop_length_mel = 256 + self.n_mel_channels = 100 + self.sampling_rate_mel = 24000 + self.mel_fmin = 0 + self.mel_fmax = None + + def norm_spec(self, 
x): + return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1 + + def get_audio_text_speaker_pair(self, audiopath_sid_text): + audiopath, phoneme_ids = audiopath_sid_text + text = torch.FloatTensor(phoneme_ids) + try: + spec, mel = self.get_audio("%s/%s" % (self.path5, audiopath)) + with torch.no_grad(): + ssl = torch.load("%s/%s.pt" % (self.path4, audiopath), map_location="cpu") + if ssl.shape[-1] != spec.shape[-1]: + typee = ssl.dtype + ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee) + ssl.requires_grad = False + except: + traceback.print_exc() + mel = torch.zeros(100, 180) + # wav = torch.zeros(1, 96 * self.hop_length) + spec = torch.zeros(1025, 96) + ssl = torch.zeros(1, 768, 96) + text = text[-1:] + print("load audio or ssl error!!!!!!", audiopath) + return (ssl, spec, mel, text) + + def get_audio(self, filename): + audio_array = load_audio(filename, self.sampling_rate) # load_audio的方法是已经归一化到-1~1之间的,不用再/32768 + audio = torch.FloatTensor(audio_array) # /32768 + audio_norm = audio + audio_norm = audio_norm.unsqueeze(0) + audio_array24 = load_audio( + filename, 24000 + ) # load_audio的方法是已经归一化到-1~1之间的,不用再/32768######这里可以用GPU重采样加速 + audio24 = torch.FloatTensor(audio_array24) # /32768 + audio_norm24 = audio24 + audio_norm24 = audio_norm24.unsqueeze(0) + + spec = spectrogram_torch( + audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length, center=False + ) + spec = torch.squeeze(spec, 0) + + spec1 = spectrogram_torch( + audio_norm24, + self.filter_length_mel, + self.sampling_rate_mel, + self.hop_length_mel, + self.win_length_mel, + center=False, + ) + mel = spec_to_mel_torch( + spec1, self.filter_length_mel, self.n_mel_channels, self.sampling_rate_mel, self.mel_fmin, self.mel_fmax + ) + mel = torch.squeeze(mel, 0) + mel = self.norm_spec(mel) + # print(1111111,spec.shape,mel.shape) + return spec, mel + + def get_sid(self, sid): + sid = torch.LongTensor([int(sid)]) + return sid + + def __getitem__(self, index): + # with torch.no_grad(): + return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index]) + + def __len__(self): + return len(self.audiopaths_sid_text) + + +class TextAudioSpeakerCollateV3: + """Zero-pads model inputs and targets""" + + def __init__(self, return_ids=False): + self.return_ids = return_ids + + def __call__(self, batch): + """Collate's training batch from normalized text, audio and speaker identities + PARAMS + ------ + batch: [text_normalized, spec_normalized, wav_normalized, sid] + """ + # ssl, spec, wav,mel, text + # Right zero-pad all one-hot text sequences to max input length + _, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True) + # (ssl, spec,mel, text) + max_ssl_len = max([x[0].size(2) for x in batch]) + + max_ssl_len1 = int(8 * ((max_ssl_len // 8) + 1)) + max_ssl_len = int(2 * ((max_ssl_len // 2) + 1)) + + # max_ssl_len = int(8 * ((max_ssl_len // 8) + 1)) + # max_ssl_len1=max_ssl_len + + max_spec_len = max([x[1].size(1) for x in batch]) + max_spec_len = int(2 * ((max_spec_len // 2) + 1)) + # max_wav_len = max([x[2].size(1) for x in batch]) + + max_text_len = max([x[3].size(0) for x in batch]) + max_mel_len = int(max_ssl_len1 * 1.25 * 1.5) ###24000/256,32000/640=16000/320 + + ssl_lengths = torch.LongTensor(len(batch)) + spec_lengths = torch.LongTensor(len(batch)) + text_lengths = torch.LongTensor(len(batch)) + # wav_lengths = torch.LongTensor(len(batch)) + mel_lengths = torch.LongTensor(len(batch)) + + spec_padded = 
torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len) + mel_padded = torch.FloatTensor(len(batch), batch[0][2].size(0), max_mel_len) + ssl_padded = torch.FloatTensor(len(batch), batch[0][0].size(1), max_ssl_len) + text_padded = torch.LongTensor(len(batch), max_text_len) + # wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len) + + spec_padded.zero_() + mel_padded.zero_() + ssl_padded.zero_() + text_padded.zero_() + # wav_padded.zero_() + + for i in range(len(ids_sorted_decreasing)): + row = batch[ids_sorted_decreasing[i]] + # ssl, spec, wav,mel, text + ssl = row[0] + ssl_padded[i, :, : ssl.size(2)] = ssl[0, :, :] + ssl_lengths[i] = ssl.size(2) + + spec = row[1] + spec_padded[i, :, : spec.size(1)] = spec + spec_lengths[i] = spec.size(1) + + # wav = row[2] + # wav_padded[i, :, :wav.size(1)] = wav + # wav_lengths[i] = wav.size(1) + + mel = row[2] + mel_padded[i, :, : mel.size(1)] = mel + mel_lengths[i] = mel.size(1) + + text = row[3] + text_padded[i, : text.size(0)] = text + text_lengths[i] = text.size(0) + + # return ssl_padded, spec_padded,mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, wav_padded, wav_lengths,mel_lengths + return ssl_padded, spec_padded, mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, mel_lengths + + +class TextAudioSpeakerLoaderV4(torch.utils.data.Dataset): + """ + 1) loads audio, speaker_id, text pairs + 2) normalizes text and converts them to sequences of integers + 3) computes spectrograms from audio files. + """ + + def __init__(self, hparams, val=False): + exp_dir = hparams.exp_dir + self.path2 = "%s/2-name2text.txt" % exp_dir + self.path4 = "%s/4-cnhubert" % exp_dir + self.path5 = "%s/5-wav32k" % exp_dir + assert os.path.exists(self.path2) + assert os.path.exists(self.path4) + assert os.path.exists(self.path5) + names4 = set([name[:-3] for name in list(os.listdir(self.path4))]) # 去除.pt后缀 + names5 = set(os.listdir(self.path5)) + self.phoneme_data = {} + with open(self.path2, "r", encoding="utf8") as f: + lines = f.read().strip("\n").split("\n") + + for line in lines: + tmp = line.split("\t") + if len(tmp) != 4: + continue + self.phoneme_data[tmp[0]] = [tmp[1]] + + self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5) + tmp = self.audiopaths_sid_text + leng = len(tmp) + min_num = 100 + if leng < min_num: + self.audiopaths_sid_text = [] + for _ in range(max(2, int(min_num / leng))): + self.audiopaths_sid_text += tmp + self.max_wav_value = hparams.max_wav_value + self.sampling_rate = hparams.sampling_rate + self.filter_length = hparams.filter_length + self.hop_length = hparams.hop_length + self.win_length = hparams.win_length + self.sampling_rate = hparams.sampling_rate + self.val = val + + random.seed(1234) + random.shuffle(self.audiopaths_sid_text) + + print("phoneme_data_len:", len(self.phoneme_data.keys())) + print("wav_data_len:", len(self.audiopaths_sid_text)) + + audiopaths_sid_text_new = [] + lengths = [] + skipped_phone = 0 + skipped_dur = 0 + for audiopath in tqdm(self.audiopaths_sid_text): + try: + phoneme = self.phoneme_data[audiopath][0] + phoneme = phoneme.split(" ") + phoneme_ids = cleaned_text_to_sequence(phoneme, version) + except Exception: + print(f"{audiopath} not in self.phoneme_data !") + skipped_phone += 1 + continue + + size = os.path.getsize("%s/%s" % (self.path5, audiopath)) + duration = size / self.sampling_rate / 2 + + if duration == 0: + print(f"Zero duration for {audiopath}, skipping...") + skipped_dur += 1 + continue + + if 54 > duration > 0.6 or self.val: + 
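+                # keep only clips between 0.6 s and 54 s; validation data bypasses the filter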
audiopaths_sid_text_new.append([audiopath, phoneme_ids]) + lengths.append(size // (2 * self.hop_length)) + else: + skipped_dur += 1 + continue + + print("skipped_phone: ", skipped_phone, ", skipped_dur: ", skipped_dur) + print("total left: ", len(audiopaths_sid_text_new)) + assert len(audiopaths_sid_text_new) > 1 # 至少能凑够batch size,这里todo + self.audiopaths_sid_text = audiopaths_sid_text_new + self.lengths = lengths + self.spec_min = -12 + self.spec_max = 2 + + self.filter_length_mel = self.win_length_mel = 1280 + self.hop_length_mel = 320 + self.n_mel_channels = 100 + self.sampling_rate_mel = 32000 + self.mel_fmin = 0 + self.mel_fmax = None + + def norm_spec(self, x): + return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1 + + def get_audio_text_speaker_pair(self, audiopath_sid_text): + audiopath, phoneme_ids = audiopath_sid_text + text = torch.FloatTensor(phoneme_ids) + try: + spec, mel = self.get_audio("%s/%s" % (self.path5, audiopath)) + with torch.no_grad(): + ssl = torch.load("%s/%s.pt" % (self.path4, audiopath), map_location="cpu") + if ssl.shape[-1] != spec.shape[-1]: + typee = ssl.dtype + ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee) + ssl.requires_grad = False + except: + traceback.print_exc() + mel = torch.zeros(100, 192) + # wav = torch.zeros(1, 96 * self.hop_length) + spec = torch.zeros(1025, 96) + ssl = torch.zeros(1, 768, 96) + text = text[-1:] + print("load audio or ssl error!!!!!!", audiopath) + return (ssl, spec, mel, text) + + def get_audio(self, filename): + audio_array = load_audio(filename, self.sampling_rate) # load_audio的方法是已经归一化到-1~1之间的,不用再/32768 + audio = torch.FloatTensor(audio_array) # /32768 + audio_norm = audio + audio_norm = audio_norm.unsqueeze(0) + spec = spectrogram_torch( + audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length, center=False + ) + spec = torch.squeeze(spec, 0) + spec1 = spectrogram_torch(audio_norm, 1280, 32000, 320, 1280, center=False) + mel = spec_to_mel_torch(spec1, 1280, 100, 32000, 0, None) + mel = self.norm_spec(torch.squeeze(mel, 0)) + return spec, mel + + def get_sid(self, sid): + sid = torch.LongTensor([int(sid)]) + return sid + + def __getitem__(self, index): + # with torch.no_grad(): + return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index]) + + def __len__(self): + return len(self.audiopaths_sid_text) + + +class TextAudioSpeakerCollateV4: + """Zero-pads model inputs and targets""" + + def __init__(self, return_ids=False): + self.return_ids = return_ids + + def __call__(self, batch): + """Collate's training batch from normalized text, audio and speaker identities + PARAMS + ------ + batch: [text_normalized, spec_normalized, wav_normalized, sid] + """ + # ssl, spec, wav,mel, text + # Right zero-pad all one-hot text sequences to max input length + _, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True) + # (ssl, spec,mel, text) + max_ssl_len = max([x[0].size(2) for x in batch]) + max_ssl_len = int(2 * ((max_ssl_len // 2) + 1)) + max_spec_len = max([x[1].size(1) for x in batch]) + max_spec_len = int(2 * ((max_spec_len // 2) + 1)) + # max_wav_len = max([x[2].size(1) for x in batch]) + max_text_len = max([x[3].size(0) for x in batch]) + + ssl_lengths = torch.LongTensor(len(batch)) + spec_lengths = torch.LongTensor(len(batch)) + text_lengths = torch.LongTensor(len(batch)) + # wav_lengths = torch.LongTensor(len(batch)) + mel_lengths = torch.LongTensor(len(batch)) + + spec_padded = 
torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len) + mel_padded = torch.FloatTensor(len(batch), batch[0][2].size(0), max_spec_len * 2) + ssl_padded = torch.FloatTensor(len(batch), batch[0][0].size(1), max_ssl_len) + text_padded = torch.LongTensor(len(batch), max_text_len) + # wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len) + + spec_padded.zero_() + mel_padded.zero_() + ssl_padded.zero_() + text_padded.zero_() + # wav_padded.zero_() + + for i in range(len(ids_sorted_decreasing)): + row = batch[ids_sorted_decreasing[i]] + # ssl, spec, wav,mel, text + ssl = row[0] + ssl_padded[i, :, : ssl.size(2)] = ssl[0, :, :] + ssl_lengths[i] = ssl.size(2) + + spec = row[1] + spec_padded[i, :, : spec.size(1)] = spec + spec_lengths[i] = spec.size(1) + + # wav = row[2] + # wav_padded[i, :, :wav.size(1)] = wav + # wav_lengths[i] = wav.size(1) + + mel = row[2] + mel_padded[i, :, : mel.size(1)] = mel + mel_lengths[i] = mel.size(1) + + text = row[3] + text_padded[i, : text.size(0)] = text + text_lengths[i] = text.size(0) + + # return ssl_padded, spec_padded,mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, wav_padded, wav_lengths,mel_lengths + return ssl_padded, spec_padded, mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, mel_lengths + + +class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset): + """ + 1) loads audio, speaker_id, text pairs + 2) normalizes text and converts them to sequences of integers + 3) computes spectrograms from audio files. + """ + + def __init__(self, hparams, val=False): + exp_dir = hparams.exp_dir + self.path2 = "%s/2-name2text.txt" % exp_dir + self.path4 = "%s/4-cnhubert" % exp_dir + self.path5 = "%s/5-wav32k" % exp_dir + assert os.path.exists(self.path2) + assert os.path.exists(self.path4) + assert os.path.exists(self.path5) + names4 = set([name[:-3] for name in list(os.listdir(self.path4))]) # 去除.pt后缀 + names5 = set(os.listdir(self.path5)) + self.phoneme_data = {} + with open(self.path2, "r", encoding="utf8") as f: + lines = f.read().strip("\n").split("\n") + + for line in lines: + tmp = line.split("\t") + if len(tmp) != 4: + continue + self.phoneme_data[tmp[0]] = [tmp[1]] + + self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5) + tmp = self.audiopaths_sid_text + leng = len(tmp) + min_num = 100 + if leng < min_num: + self.audiopaths_sid_text = [] + for _ in range(max(2, int(min_num / leng))): + self.audiopaths_sid_text += tmp + self.max_wav_value = hparams.max_wav_value + self.sampling_rate = hparams.sampling_rate + self.filter_length = hparams.filter_length + self.hop_length = hparams.hop_length + self.win_length = hparams.win_length + self.sampling_rate = hparams.sampling_rate + self.val = val + + random.seed(1234) + random.shuffle(self.audiopaths_sid_text) + + print("phoneme_data_len:", len(self.phoneme_data.keys())) + print("wav_data_len:", len(self.audiopaths_sid_text)) + + audiopaths_sid_text_new = [] + lengths = [] + skipped_phone = 0 + skipped_dur = 0 + for audiopath in tqdm(self.audiopaths_sid_text): + try: + phoneme = self.phoneme_data[audiopath][0] + phoneme = phoneme.split(" ") + phoneme_ids = cleaned_text_to_sequence(phoneme, version) + except Exception: + print(f"{audiopath} not in self.phoneme_data !") + skipped_phone += 1 + continue + + size = os.path.getsize("%s/%s" % (self.path5, audiopath)) + duration = size / self.sampling_rate / 2 + + if duration == 0: + print(f"Zero duration for {audiopath}, skipping...") + skipped_dur += 1 + continue + + if 54 > duration > 0.6 or self.val: + 
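+                # same 0.6 s–54 s duration window as the other loaders; validation keeps everything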
audiopaths_sid_text_new.append([audiopath, phoneme_ids]) + lengths.append(size // (2 * self.hop_length)) + else: + skipped_dur += 1 + continue + + print("skipped_phone: ", skipped_phone, ", skipped_dur: ", skipped_dur) + print("total left: ", len(audiopaths_sid_text_new)) + assert len(audiopaths_sid_text_new) > 1 # 至少能凑够batch size,这里todo + self.audiopaths_sid_text = audiopaths_sid_text_new + self.lengths = lengths + self.spec_min = -12 + self.spec_max = 2 + + self.filter_length_mel = self.win_length_mel = 1024 + self.hop_length_mel = 256 + self.n_mel_channels = 100 + self.sampling_rate_mel = 24000 + self.mel_fmin = 0 + self.mel_fmax = None + + def norm_spec(self, x): + return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1 + + def get_audio_text_speaker_pair(self, audiopath_sid_text): + audiopath, phoneme_ids = audiopath_sid_text + text = torch.FloatTensor(phoneme_ids) + try: + spec, mel, wav = self.get_audio("%s/%s" % (self.path5, audiopath)) + with torch.no_grad(): + ssl = torch.load("%s/%s.pt" % (self.path4, audiopath), map_location="cpu") + if ssl.shape[-1] != spec.shape[-1]: + typee = ssl.dtype + ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee) + ssl.requires_grad = False + except: + traceback.print_exc() + mel = torch.zeros(100, 180) + wav = torch.zeros(1, 96 * self.hop_length) + spec = torch.zeros(1025, 96) + ssl = torch.zeros(1, 768, 96) + text = text[-1:] + print("load audio or ssl error!!!!!!", audiopath) + return (ssl, spec, wav, mel, text) + + def get_audio(self, filename): + audio_array = load_audio(filename, self.sampling_rate) # load_audio的方法是已经归一化到-1~1之间的,不用再/32768 + audio = torch.FloatTensor(audio_array) # /32768 + audio_norm = audio + audio_norm = audio_norm.unsqueeze(0) + audio_array24 = load_audio( + filename, 24000 + ) # load_audio的方法是已经归一化到-1~1之间的,不用再/32768######这里可以用GPU重采样加速 + audio24 = torch.FloatTensor(audio_array24) # /32768 + audio_norm24 = audio24 + audio_norm24 = audio_norm24.unsqueeze(0) + + spec = spectrogram_torch( + audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length, center=False + ) + spec = torch.squeeze(spec, 0) + + spec1 = spectrogram_torch( + audio_norm24, + self.filter_length_mel, + self.sampling_rate_mel, + self.hop_length_mel, + self.win_length_mel, + center=False, + ) + mel = spec_to_mel_torch( + spec1, self.filter_length_mel, self.n_mel_channels, self.sampling_rate_mel, self.mel_fmin, self.mel_fmax + ) + mel = torch.squeeze(mel, 0) + mel = self.norm_spec(mel) + # print(1111111,spec.shape,mel.shape) + return spec, mel, audio_norm + + def get_sid(self, sid): + sid = torch.LongTensor([int(sid)]) + return sid + + def __getitem__(self, index): + # with torch.no_grad(): + return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index]) + + def __len__(self): + return len(self.audiopaths_sid_text) + + +class TextAudioSpeakerCollateV3b: + """Zero-pads model inputs and targets""" + + def __init__(self, return_ids=False): + self.return_ids = return_ids + + def __call__(self, batch): + """Collate's training batch from normalized text, audio and speaker identities + PARAMS + ------ + batch: [text_normalized, spec_normalized, wav_normalized, sid] + """ + # ssl, spec, wav,mel, text + # Right zero-pad all one-hot text sequences to max input length + _, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True) + # (ssl, spec,mel, text) + max_ssl_len = max([x[0].size(2) for x in batch]) + + max_ssl_len1 = int(8 * ((max_ssl_len // 8) + 1)) + 
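+        # max_ssl_len1 (rounded up to a multiple of 8) only feeds the mel-length
+        # estimate below; the SSL padding itself just needs an even length.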
max_ssl_len = int(2 * ((max_ssl_len // 2) + 1)) + + # max_ssl_len = int(8 * ((max_ssl_len // 8) + 1)) + # max_ssl_len1=max_ssl_len + + max_spec_len = max([x[1].size(1) for x in batch]) + max_spec_len = int(2 * ((max_spec_len // 2) + 1)) + max_wav_len = max([x[2].size(1) for x in batch]) + max_text_len = max([x[4].size(0) for x in batch]) + max_mel_len = int(max_ssl_len1 * 1.25 * 1.5) ###24000/256,32000/640=16000/320 + + ssl_lengths = torch.LongTensor(len(batch)) + spec_lengths = torch.LongTensor(len(batch)) + text_lengths = torch.LongTensor(len(batch)) + wav_lengths = torch.LongTensor(len(batch)) + mel_lengths = torch.LongTensor(len(batch)) + + spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len) + mel_padded = torch.FloatTensor(len(batch), batch[0][3].size(0), max_mel_len) + ssl_padded = torch.FloatTensor(len(batch), batch[0][0].size(1), max_ssl_len) + text_padded = torch.LongTensor(len(batch), max_text_len) + wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len) + + spec_padded.zero_() + mel_padded.zero_() + ssl_padded.zero_() + text_padded.zero_() + wav_padded.zero_() + + for i in range(len(ids_sorted_decreasing)): + row = batch[ids_sorted_decreasing[i]] + # ssl, spec, wav,mel, text + ssl = row[0] + ssl_padded[i, :, : ssl.size(2)] = ssl[0, :, :] + ssl_lengths[i] = ssl.size(2) + + spec = row[1] + spec_padded[i, :, : spec.size(1)] = spec + spec_lengths[i] = spec.size(1) + + wav = row[2] + wav_padded[i, :, : wav.size(1)] = wav + wav_lengths[i] = wav.size(1) + + mel = row[3] + mel_padded[i, :, : mel.size(1)] = mel + mel_lengths[i] = mel.size(1) + + text = row[4] + text_padded[i, : text.size(0)] = text + text_lengths[i] = text.size(0) + + return ( + ssl_padded, + spec_padded, + mel_padded, + ssl_lengths, + spec_lengths, + text_padded, + text_lengths, + wav_padded, + wav_lengths, + mel_lengths, + ) + # return ssl_padded, spec_padded,mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths,mel_lengths + + +class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler): + """ + Maintain similar input lengths in a batch. + Length groups are specified by boundaries. + Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <=b2} or {x | b2 < length(x) <= b3}. + + It removes samples which are not included in the boundaries. + Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded. 
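+    Ex) boundaries = [32, 300, 500] with sample lengths [10, 40, 350, 600] -> 40 goes to
+    the first group (32 < 40 <= 300), 350 to the second (300 < 350 <= 500), and 10 and
+    600 are discarded.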
+ """ + + def __init__(self, dataset, batch_size, boundaries, num_replicas=None, rank=None, shuffle=True): + super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle) + self.lengths = dataset.lengths + self.batch_size = batch_size + self.boundaries = boundaries + + self.buckets, self.num_samples_per_bucket = self._create_buckets() + self.total_size = sum(self.num_samples_per_bucket) + self.num_samples = self.total_size // self.num_replicas + + def _create_buckets(self): + buckets = [[] for _ in range(len(self.boundaries) - 1)] + for i in range(len(self.lengths)): + length = self.lengths[i] + idx_bucket = self._bisect(length) + if idx_bucket != -1: + buckets[idx_bucket].append(i) + + i = len(buckets) - 1 + while i >= 0: + if len(buckets[i]) == 0: + buckets.pop(i) + self.boundaries.pop(i + 1) + i -= 1 + + num_samples_per_bucket = [] + for i in range(len(buckets)): + len_bucket = len(buckets[i]) + total_batch_size = self.num_replicas * self.batch_size + rem = (total_batch_size - (len_bucket % total_batch_size)) % total_batch_size + num_samples_per_bucket.append(len_bucket + rem) + return buckets, num_samples_per_bucket + + def __iter__(self): + g = torch.Generator() + g.manual_seed(self.epoch) + + indices = [] + if self.shuffle: + for bucket in self.buckets: + indices.append(torch.randperm(len(bucket), generator=g).tolist()) + else: + for bucket in self.buckets: + indices.append(list(range(len(bucket)))) + + batches = [] + for i in range(len(self.buckets)): + bucket = self.buckets[i] + len_bucket = len(bucket) + ids_bucket = indices[i] + num_samples_bucket = self.num_samples_per_bucket[i] + + rem = num_samples_bucket - len_bucket + ids_bucket = ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[: (rem % len_bucket)] + + ids_bucket = ids_bucket[self.rank :: self.num_replicas] + + for j in range(len(ids_bucket) // self.batch_size): + batch = [bucket[idx] for idx in ids_bucket[j * self.batch_size : (j + 1) * self.batch_size]] + batches.append(batch) + + if self.shuffle: + batch_ids = torch.randperm(len(batches), generator=g).tolist() + batches = [batches[i] for i in batch_ids] + self.batches = batches + + assert len(self.batches) * self.batch_size == self.num_samples + return iter(self.batches) + + def _bisect(self, x, lo=0, hi=None): + if hi is None: + hi = len(self.boundaries) - 1 + + if hi > lo: + mid = (hi + lo) // 2 + if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]: + return mid + elif x <= self.boundaries[mid]: + return self._bisect(x, lo, mid) + else: + return self._bisect(x, mid + 1, hi) + else: + return -1 + + def __len__(self): + return self.num_samples // self.batch_size diff --git a/module/losses.py b/module/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..2b642db156bcff5b0d20b5beecd85c511e981287 --- /dev/null +++ b/module/losses.py @@ -0,0 +1,70 @@ +import math + +import torch + + +def feature_loss(fmap_r, fmap_g): + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + rl = rl.float().detach() + gl = gl.float() + loss += torch.mean(torch.abs(rl - gl)) + + return loss * 2 + + +def discriminator_loss(disc_real_outputs, disc_generated_outputs): + loss = 0 + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + dr = dr.float() + dg = dg.float() + r_loss = torch.mean((1 - dr) ** 2) + g_loss = torch.mean(dg**2) + loss += r_loss + g_loss + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + + +def 
generator_loss(disc_outputs): + loss = 0 + gen_losses = [] + for dg in disc_outputs: + dg = dg.float() + l = torch.mean((1 - dg) ** 2) + gen_losses.append(l) + loss += l + + return loss, gen_losses + + +def kl_loss(z_p, logs_q, m_p, logs_p, z_mask): + """ + z_p, logs_q: [b, h, t_t] + m_p, logs_p: [b, h, t_t] + """ + z_p = z_p.float() + logs_q = logs_q.float() + m_p = m_p.float() + logs_p = logs_p.float() + z_mask = z_mask.float() + + kl = logs_p - logs_q - 0.5 + kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p) + kl = torch.sum(kl * z_mask) + l = kl / torch.sum(z_mask) + return l + + +def mle_loss(z, m, logs, logdet, mask): + l = torch.sum(logs) + 0.5 * torch.sum( + torch.exp(-2 * logs) * ((z - m) ** 2) + ) # neg normal likelihood w/o the constant term + l = l - torch.sum(logdet) # log jacobian determinant + l = l / torch.sum(torch.ones_like(z) * mask) # averaging across batch, channel and time axes + l = l + 0.5 * math.log(2 * math.pi) # add the remaining constant term + return l diff --git a/module/mel_processing.py b/module/mel_processing.py new file mode 100644 index 0000000000000000000000000000000000000000..62c7b40e1348316f06c323ef03c8781f4cd2645d --- /dev/null +++ b/module/mel_processing.py @@ -0,0 +1,143 @@ +import torch +import torch.utils.data +from librosa.filters import mel as librosa_mel_fn + +MAX_WAV_VALUE = 32768.0 + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + """ + PARAMS + ------ + C: compression factor + """ + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + """ + PARAMS + ------ + C: compression factor used to compress + """ + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + + +def spectral_de_normalize_torch(magnitudes): + output = dynamic_range_decompression_torch(magnitudes) + return output + + +mel_basis = {} +hann_window = {} + + +def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): + if torch.min(y) < -1.2: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.2: + print("max value is ", torch.max(y)) + + global hann_window + dtype_device = str(y.dtype) + "_" + str(y.device) + # wnsize_dtype_device = str(win_size) + '_' + dtype_device + key = "%s-%s-%s-%s-%s" % (dtype_device, n_fft, sampling_rate, hop_size, win_size) + # if wnsize_dtype_device not in hann_window: + if key not in hann_window: + # hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) + hann_window[key] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect" + ) + y = y.squeeze(1) + # spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], + spec = torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[key], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=False, + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-8) + return spec + + +def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): + global mel_basis + dtype_device = str(spec.dtype) + "_" + str(spec.device) + # fmax_dtype_device = str(fmax) + '_' + dtype_device + key = "%s-%s-%s-%s-%s-%s" % (dtype_device, n_fft, num_mels, sampling_rate, fmin, fmax) + # if fmax_dtype_device not in mel_basis: + if key not in 
mel_basis: + mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) + # mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) + mel_basis[key] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) + # spec = torch.matmul(mel_basis[fmax_dtype_device], spec) + spec = torch.matmul(mel_basis[key], spec) + spec = spectral_normalize_torch(spec) + return spec + + +def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): + if torch.min(y) < -1.2: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.2: + print("max value is ", torch.max(y)) + + global mel_basis, hann_window + dtype_device = str(y.dtype) + "_" + str(y.device) + # fmax_dtype_device = str(fmax) + '_' + dtype_device + fmax_dtype_device = "%s-%s-%s-%s-%s-%s-%s-%s" % ( + dtype_device, + n_fft, + num_mels, + sampling_rate, + hop_size, + win_size, + fmin, + fmax, + ) + # wnsize_dtype_device = str(win_size) + '_' + dtype_device + wnsize_dtype_device = fmax_dtype_device + if fmax_dtype_device not in mel_basis: + mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) + mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect" + ) + y = y.squeeze(1) + + spec = torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=False, + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-8) + + spec = torch.matmul(mel_basis[fmax_dtype_device], spec) + spec = spectral_normalize_torch(spec) + + return spec diff --git a/module/models.py b/module/models.py new file mode 100644 index 0000000000000000000000000000000000000000..1c8e662f4999a61720ca88eb30c42910224ff975 --- /dev/null +++ b/module/models.py @@ -0,0 +1,1433 @@ +import warnings + +warnings.filterwarnings("ignore") +import math + +import torch +from torch import nn +from torch.nn import functional as F + +from module import commons +from module import modules +from module import attentions +from f5_tts.model import DiT +from torch.nn import Conv1d, ConvTranspose1d, Conv2d +from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm +from module.commons import init_weights, get_padding +from module.mrte_model import MRTE +from module.quantize import ResidualVectorQuantizer + +# from text import symbols +from text import symbols as symbols_v1 +from text import symbols2 as symbols_v2 +from torch.cuda.amp import autocast +import contextlib +import random + + +class StochasticDurationPredictor(nn.Module): + def __init__( + self, + in_channels, + filter_channels, + kernel_size, + p_dropout, + n_flows=4, + gin_channels=0, + ): + super().__init__() + filter_channels = in_channels # it needs to be removed from future version. 
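+        # NOTE: the assignment above overrides the filter_channels argument with
+        # in_channels, so the value passed to the constructor is effectively ignored.
+        # In forward(), reverse=False returns the per-utterance duration NLL (nll + logq),
+        # while reverse=True samples noise and runs the flows in reverse to predict
+        # log-durations (logw).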
+ self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.log_flow = modules.Log() + self.flows = nn.ModuleList() + self.flows.append(modules.ElementwiseAffine(2)) + for i in range(n_flows): + self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) + self.flows.append(modules.Flip()) + + self.post_pre = nn.Conv1d(1, filter_channels, 1) + self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1) + self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) + self.post_flows = nn.ModuleList() + self.post_flows.append(modules.ElementwiseAffine(2)) + for i in range(4): + self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) + self.post_flows.append(modules.Flip()) + + self.pre = nn.Conv1d(in_channels, filter_channels, 1) + self.proj = nn.Conv1d(filter_channels, filter_channels, 1) + self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, filter_channels, 1) + + def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0): + x = torch.detach(x) + x = self.pre(x) + if g is not None: + g = torch.detach(g) + x = x + self.cond(g) + x = self.convs(x, x_mask) + x = self.proj(x) * x_mask + + if not reverse: + flows = self.flows + assert w is not None + + logdet_tot_q = 0 + h_w = self.post_pre(w) + h_w = self.post_convs(h_w, x_mask) + h_w = self.post_proj(h_w) * x_mask + e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask + z_q = e_q + for flow in self.post_flows: + z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w)) + logdet_tot_q += logdet_q + z_u, z1 = torch.split(z_q, [1, 1], 1) + u = torch.sigmoid(z_u) * x_mask + z0 = (w - u) * x_mask + logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]) + logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2]) - logdet_tot_q + + logdet_tot = 0 + z0, logdet = self.log_flow(z0, x_mask) + logdet_tot += logdet + z = torch.cat([z0, z1], 1) + for flow in flows: + z, logdet = flow(z, x_mask, g=x, reverse=reverse) + logdet_tot = logdet_tot + logdet + nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) - logdet_tot + return nll + logq # [b] + else: + flows = list(reversed(self.flows)) + flows = flows[:-2] + [flows[-1]] # remove a useless vflow + z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale + for flow in flows: + z = flow(z, x_mask, g=x, reverse=reverse) + z0, z1 = torch.split(z, [1, 1], 1) + logw = z0 + return logw + + +class DurationPredictor(nn.Module): + def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0): + super().__init__() + + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.gin_channels = gin_channels + + self.drop = nn.Dropout(p_dropout) + self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2) + self.norm_1 = modules.LayerNorm(filter_channels) + self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2) + self.norm_2 = modules.LayerNorm(filter_channels) + self.proj = nn.Conv1d(filter_channels, 1, 1) + + if gin_channels != 0: + self.cond = 
nn.Conv1d(gin_channels, in_channels, 1) + + def forward(self, x, x_mask, g=None): + x = torch.detach(x) + if g is not None: + g = torch.detach(g) + x = x + self.cond(g) + x = self.conv_1(x * x_mask) + x = torch.relu(x) + x = self.norm_1(x) + x = self.drop(x) + x = self.conv_2(x * x_mask) + x = torch.relu(x) + x = self.norm_2(x) + x = self.drop(x) + x = self.proj(x * x_mask) + return x * x_mask + + +class TextEncoder(nn.Module): + def __init__( + self, + out_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + latent_channels=192, + version="v2", + ): + super().__init__() + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.latent_channels = latent_channels + self.version = version + + self.ssl_proj = nn.Conv1d(768, hidden_channels, 1) + + self.encoder_ssl = attentions.Encoder( + hidden_channels, + filter_channels, + n_heads, + n_layers // 2, + kernel_size, + p_dropout, + ) + + self.encoder_text = attentions.Encoder( + hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout + ) + + if self.version == "v1": + symbols = symbols_v1.symbols + else: + symbols = symbols_v2.symbols + self.text_embedding = nn.Embedding(len(symbols), hidden_channels) + + self.mrte = MRTE() + + self.encoder2 = attentions.Encoder( + hidden_channels, + filter_channels, + n_heads, + n_layers // 2, + kernel_size, + p_dropout, + ) + + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, y, y_lengths, text, text_lengths, ge, speed=1, test=None): + y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype) + + y = self.ssl_proj(y * y_mask) * y_mask + + y = self.encoder_ssl(y * y_mask, y_mask) + + text_mask = torch.unsqueeze(commons.sequence_mask(text_lengths, text.size(1)), 1).to(y.dtype) + if test == 1: + text[:, :] = 0 + text = self.text_embedding(text).transpose(1, 2) + text = self.encoder_text(text * text_mask, text_mask) + y = self.mrte(y, y_mask, text, text_mask, ge) + y = self.encoder2(y * y_mask, y_mask) + if speed != 1: + y = F.interpolate(y, size=int(y.shape[-1] / speed) + 1, mode="linear") + y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest") + stats = self.proj(y) * y_mask + m, logs = torch.split(stats, self.out_channels, dim=1) + return y, m, logs, y_mask + + def extract_latent(self, x): + x = self.ssl_proj(x) + quantized, codes, commit_loss, quantized_list = self.quantizer(x) + return codes.transpose(0, 1) + + def decode_latent(self, codes, y_mask, refer, refer_mask, ge): + quantized = self.quantizer.decode(codes) + + y = self.vq_proj(quantized) * y_mask + y = self.encoder_ssl(y * y_mask, y_mask) + + y = self.mrte(y, y_mask, refer, refer_mask, ge) + + y = self.encoder2(y * y_mask, y_mask) + + stats = self.proj(y) * y_mask + m, logs = torch.split(stats, self.out_channels, dim=1) + return y, m, logs, y_mask, quantized + + +class ResidualCouplingBlock(nn.Module): + def __init__( + self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + n_flows=4, + gin_channels=0, + ): + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.flows = nn.ModuleList() + for i in range(n_flows): + 
self.flows.append( + modules.ResidualCouplingLayer( + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=gin_channels, + mean_only=True, + ) + ) + self.flows.append(modules.Flip()) + + def forward(self, x, x_mask, g=None, reverse=False): + if not reverse: + for flow in self.flows: + x, _ = flow(x, x_mask, g=g, reverse=reverse) + else: + for flow in reversed(self.flows): + x = flow(x, x_mask, g=g, reverse=reverse) + return x + + +class PosteriorEncoder(nn.Module): + def __init__( + self, + in_channels, + out_channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=0, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + + self.pre = nn.Conv1d(in_channels, hidden_channels, 1) + self.enc = modules.WN( + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=gin_channels, + ) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths, g=None): + if g != None: + g = g.detach() + x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + x = self.pre(x) * x_mask + x = self.enc(x, x_mask, g=g) + stats = self.proj(x) * x_mask + m, logs = torch.split(stats, self.out_channels, dim=1) + z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask + return z, m, logs, x_mask + + +class Encoder(nn.Module): + def __init__( + self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0 + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + + self.pre = nn.Conv1d(in_channels, hidden_channels, 1) + self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels) + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + + def forward(self, x, x_lengths, g=None): + if g != None: + g = g.detach() + x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + x = self.pre(x) * x_mask + x = self.enc(x, x_mask, g=g) + stats = self.proj(x) * x_mask + return stats, x_mask + + +class WNEncoder(nn.Module): + def __init__( + self, + in_channels, + out_channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=0, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + + self.pre = nn.Conv1d(in_channels, hidden_channels, 1) + self.enc = modules.WN( + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=gin_channels, + ) + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + self.norm = modules.LayerNorm(out_channels) + + def forward(self, x, x_lengths, g=None): + x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + x = self.pre(x) * x_mask + x = self.enc(x, x_mask, g=g) + out = self.proj(x) * x_mask + out = self.norm(out) + return out + + +class Generator(torch.nn.Module): + def __init__( + self, + initial_channel, + resblock, + 
resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels=0, + is_bias=False, + ): + super(Generator, self).__init__() + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3) + resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): + self.resblocks.append(resblock(ch, k, d)) + + self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=is_bias) + self.ups.apply(init_weights) + + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) + + def forward(self, x, g=None): + x = self.conv_pre(x) + if g is not None: + x = x + self.cond(g) + + for i in range(self.num_upsamples): + x = F.leaky_relu(x, modules.LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print("Removing weight norm...") + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + self.use_spectral_norm = use_spectral_norm + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f( + Conv2d( + 1, + 32, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ), + norm_f( + Conv2d( + 32, + 128, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ), + norm_f( + Conv2d( + 128, + 512, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ), + norm_f( + Conv2d( + 512, + 1024, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ), + norm_f( + Conv2d( + 1024, + 1024, + (kernel_size, 1), + 1, + padding=(get_padding(kernel_size, 1), 0), + ) + ), + ] + ) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, modules.LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f(Conv1d(1, 16, 15, 1, 
padding=7)), + norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), + norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ] + ) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, modules.LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +v2pro_set = {"v2Pro", "v2ProPlus"} + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self, use_spectral_norm=False, version=None): + super(MultiPeriodDiscriminator, self).__init__() + if version in v2pro_set: + periods = [2, 3, 5, 7, 11, 17, 23] + else: + periods = [2, 3, 5, 7, 11] + + discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)] + discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods] + self.discriminators = nn.ModuleList(discs) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + y_d_gs.append(y_d_g) + fmap_rs.append(fmap_r) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class ReferenceEncoder(nn.Module): + """ + inputs --- [N, Ty/r, n_mels*r] mels + outputs --- [N, ref_enc_gru_size] + """ + + def __init__(self, spec_channels, gin_channels=0): + super().__init__() + self.spec_channels = spec_channels + ref_enc_filters = [32, 32, 64, 64, 128, 128] + K = len(ref_enc_filters) + filters = [1] + ref_enc_filters + convs = [ + weight_norm( + nn.Conv2d( + in_channels=filters[i], + out_channels=filters[i + 1], + kernel_size=(3, 3), + stride=(2, 2), + padding=(1, 1), + ) + ) + for i in range(K) + ] + self.convs = nn.ModuleList(convs) + # self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)]) + + out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K) + self.gru = nn.GRU( + input_size=ref_enc_filters[-1] * out_channels, + hidden_size=256 // 2, + batch_first=True, + ) + self.proj = nn.Linear(128, gin_channels) + + def forward(self, inputs): + N = inputs.size(0) + out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs] + for conv in self.convs: + out = conv(out) + # out = wn(out) + out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K] + + out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K] + T = out.size(1) + N = out.size(0) + out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K] + + self.gru.flatten_parameters() + memory, out = self.gru(out) # out --- [1, N, 128] + + return self.proj(out.squeeze(0)).unsqueeze(-1) + + def calculate_channels(self, L, kernel_size, stride, pad, n_convs): + for i in range(n_convs): + L = (L - kernel_size + 2 * pad) // stride + 1 + return L + + +class Quantizer_module(torch.nn.Module): + def __init__(self, n_e, e_dim): + super(Quantizer_module, self).__init__() + self.embedding = nn.Embedding(n_e, e_dim) + self.embedding.weight.data.uniform_(-1.0 / n_e, 1.0 / n_e) + + def forward(self, x): + d = ( + torch.sum(x**2, 1, keepdim=True) + + torch.sum(self.embedding.weight**2, 1) + - 2 * torch.matmul(x, self.embedding.weight.T) + ) + min_indicies = torch.argmin(d, 1) + z_q = self.embedding(min_indicies) + return z_q, min_indicies + + +class Quantizer(torch.nn.Module): + 
def __init__(self, embed_dim=512, n_code_groups=4, n_codes=160): + super(Quantizer, self).__init__() + assert embed_dim % n_code_groups == 0 + self.quantizer_modules = nn.ModuleList( + [Quantizer_module(n_codes, embed_dim // n_code_groups) for _ in range(n_code_groups)] + ) + self.n_code_groups = n_code_groups + self.embed_dim = embed_dim + + def forward(self, xin): + # B, C, T + B, C, T = xin.shape + xin = xin.transpose(1, 2) + x = xin.reshape(-1, self.embed_dim) + x = torch.split(x, self.embed_dim // self.n_code_groups, dim=-1) + min_indicies = [] + z_q = [] + for _x, m in zip(x, self.quantizer_modules): + _z_q, _min_indicies = m(_x) + z_q.append(_z_q) + min_indicies.append(_min_indicies) # B * T, + z_q = torch.cat(z_q, -1).reshape(xin.shape) + loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean((z_q - xin.detach()) ** 2) + z_q = xin + (z_q - xin).detach() + z_q = z_q.transpose(1, 2) + codes = torch.stack(min_indicies, -1).reshape(B, T, self.n_code_groups) + return z_q, loss, codes.transpose(1, 2) + + def embed(self, x): + # idx: N, 4, T + x = x.transpose(1, 2) + x = torch.split(x, 1, 2) + ret = [] + for q, embed in zip(x, self.quantizer_modules): + q = embed.embedding(q.squeeze(-1)) + ret.append(q) + ret = torch.cat(ret, -1) + return ret.transpose(1, 2) # N, C, T + + +class CodePredictor(nn.Module): + def __init__( + self, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + n_q=8, + dims=1024, + ssl_dim=768, + ): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + + self.vq_proj = nn.Conv1d(ssl_dim, hidden_channels, 1) + self.ref_enc = modules.MelStyleEncoder(ssl_dim, style_vector_dim=hidden_channels) + + self.encoder = attentions.Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout) + + self.out_proj = nn.Conv1d(hidden_channels, (n_q - 1) * dims, 1) + self.n_q = n_q + self.dims = dims + + def forward(self, x, x_mask, refer, codes, infer=False): + x = x.detach() + x = self.vq_proj(x * x_mask) * x_mask + g = self.ref_enc(refer, x_mask) + x = x + g + x = self.encoder(x * x_mask, x_mask) + x = self.out_proj(x * x_mask) * x_mask + logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(2, 3) + target = codes[1:].transpose(0, 1) + if not infer: + logits = logits.reshape(-1, self.dims) + target = target.reshape(-1) + loss = torch.nn.functional.cross_entropy(logits, target) + return loss + else: + _, top10_preds = torch.topk(logits, 10, dim=-1) + correct_top10 = torch.any(top10_preds == target.unsqueeze(-1), dim=-1) + top3_acc = 100 * torch.mean(correct_top10.float()).detach().cpu().item() + + print("Top-10 Accuracy:", top3_acc, "%") + + pred_codes = torch.argmax(logits, dim=-1) + acc = 100 * torch.mean((pred_codes == target).float()).detach().cpu().item() + print("Top-1 Accuracy:", acc, "%") + + return pred_codes.transpose(0, 1) + + +class SynthesizerTrn(nn.Module): + """ + Synthesizer for Training + """ + + def __init__( + self, + spec_channels, + segment_size, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + n_speakers=0, + gin_channels=0, + use_sdp=True, + semantic_frame_rate=None, + freeze_quantizer=None, + version="v2", + 
**kwargs, + ): + super().__init__() + self.spec_channels = spec_channels + self.inter_channels = inter_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.resblock = resblock + self.resblock_kernel_sizes = resblock_kernel_sizes + self.resblock_dilation_sizes = resblock_dilation_sizes + self.upsample_rates = upsample_rates + self.upsample_initial_channel = upsample_initial_channel + self.upsample_kernel_sizes = upsample_kernel_sizes + self.segment_size = segment_size + self.n_speakers = n_speakers + self.gin_channels = gin_channels + self.version = version + + self.use_sdp = use_sdp + self.enc_p = TextEncoder( + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + version=version, + ) + self.dec = Generator( + inter_channels, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels=gin_channels, + ) + self.enc_q = PosteriorEncoder( + spec_channels, + inter_channels, + hidden_channels, + 5, + 1, + 16, + gin_channels=gin_channels, + ) + self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels) + + # self.version=os.environ.get("version","v1") + if self.version == "v1": + self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels) + else: + self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels) + + ssl_dim = 768 + assert semantic_frame_rate in ["25hz", "50hz"] + self.semantic_frame_rate = semantic_frame_rate + if semantic_frame_rate == "25hz": + self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2) + else: + self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1) + + self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024) + self.freeze_quantizer = freeze_quantizer + + self.is_v2pro = self.version in v2pro_set + if self.is_v2pro: + self.sv_emb = nn.Linear(20480, gin_channels) + self.ge_to512 = nn.Linear(gin_channels, 512) + self.prelu = nn.PReLU(num_parameters=gin_channels) + + def forward(self, ssl, y, y_lengths, text, text_lengths, sv_emb=None): + y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype) + if self.version == "v1": + ge = self.ref_enc(y * y_mask, y_mask) + else: + ge = self.ref_enc(y[:, :704] * y_mask, y_mask) + if self.is_v2pro: + sv_emb = self.sv_emb(sv_emb) # B*20480->B*512 + ge += sv_emb.unsqueeze(-1) + ge = self.prelu(ge) + ge512 = self.ge_to512(ge.transpose(2, 1)).transpose(2, 1) + with autocast(enabled=False): + maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext() + with maybe_no_grad: + if self.freeze_quantizer: + self.ssl_proj.eval() + self.quantizer.eval() + ssl = self.ssl_proj(ssl) + quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0]) + + if self.semantic_frame_rate == "25hz": + quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest") + + x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge512 if self.is_v2pro else ge) + z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge) + z_p = self.flow(z, y_mask, g=ge) + + z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size) + o = self.dec(z_slice, g=ge) + return ( + o, + commit_loss, + ids_slice, + y_mask, + y_mask, + (z, z_p, m_p, 
logs_p, m_q, logs_q), + quantized, + ) + + def infer(self, ssl, y, y_lengths, text, text_lengths, test=None, noise_scale=0.5): + y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype) + if self.version == "v1": + ge = self.ref_enc(y * y_mask, y_mask) + else: + ge = self.ref_enc(y[:, :704] * y_mask, y_mask) + + ssl = self.ssl_proj(ssl) + quantized, codes, commit_loss, _ = self.quantizer(ssl, layers=[0]) + if self.semantic_frame_rate == "25hz": + quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest") + + x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, test=test) + z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale + + z = self.flow(z_p, y_mask, g=ge, reverse=True) + + o = self.dec((z * y_mask)[:, :, :], g=ge) + return o, y_mask, (z, z_p, m_p, logs_p) + + @torch.no_grad() + def decode(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None): + def get_ge(refer, sv_emb): + ge = None + if refer is not None: + refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device) + refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype) + if self.version == "v1": + ge = self.ref_enc(refer * refer_mask, refer_mask) + else: + ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask) + if self.is_v2pro: + sv_emb = self.sv_emb(sv_emb) # B*20480->B*512 + ge += sv_emb.unsqueeze(-1) + ge = self.prelu(ge) + return ge + + if type(refer) == list: + ges = [] + for idx, _refer in enumerate(refer): + ge = get_ge(_refer, sv_emb[idx] if self.is_v2pro else None) + ges.append(ge) + ge = torch.stack(ges, 0).mean(0) + else: + ge = get_ge(refer, sv_emb) + + y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device) + text_lengths = torch.LongTensor([text.size(-1)]).to(text.device) + + quantized = self.quantizer.decode(codes) + if self.semantic_frame_rate == "25hz": + quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest") + x, m_p, logs_p, y_mask = self.enc_p( + quantized, + y_lengths, + text, + text_lengths, + self.ge_to512(ge.transpose(2, 1)).transpose(2, 1) if self.is_v2pro else ge, + speed, + ) + z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale + + z = self.flow(z_p, y_mask, g=ge, reverse=True) + + o = self.dec((z * y_mask)[:, :, :], g=ge) + return o + + def extract_latent(self, x): + ssl = self.ssl_proj(x) + quantized, codes, commit_loss, quantized_list = self.quantizer(ssl) + return codes.transpose(0, 1) + + +class CFM(torch.nn.Module): + def __init__(self, in_channels, dit): + super().__init__() + self.sigma_min = 1e-6 + + self.estimator = dit + + self.in_channels = in_channels + + self.criterion = torch.nn.MSELoss() + + self.use_conditioner_cache = True + + @torch.inference_mode() + def inference(self, mu, x_lens, prompt, n_timesteps, temperature=1.0, inference_cfg_rate=0): + """Forward diffusion""" + B, T = mu.size(0), mu.size(1) + x = torch.randn([B, self.in_channels, T], device=mu.device, dtype=mu.dtype) * temperature + prompt_len = prompt.size(-1) + prompt_x = torch.zeros_like(x, dtype=mu.dtype) + prompt_x[..., :prompt_len] = prompt[..., :prompt_len] + x[..., :prompt_len] = 0 + mu = mu.transpose(2, 1) + t = 0 + d = 1 / n_timesteps + text_cache = None + text_cfg_cache = None + dt_cache = None + d_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * d + for j in range(n_timesteps): + t_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * t + # v_pred = 
model(x, t_tensor, d_tensor, **extra_args) + v_pred, text_emb, dt = self.estimator( + x, + prompt_x, + x_lens, + t_tensor, + d_tensor, + mu, + use_grad_ckpt=False, + drop_audio_cond=False, + drop_text=False, + infer=True, + text_cache=text_cache, + dt_cache=dt_cache, + ) + v_pred = v_pred.transpose(2, 1) + if self.use_conditioner_cache: + text_cache = text_emb + dt_cache = dt + if inference_cfg_rate > 1e-5: + neg, text_cfg_emb, _ = self.estimator( + x, + prompt_x, + x_lens, + t_tensor, + d_tensor, + mu, + use_grad_ckpt=False, + drop_audio_cond=True, + drop_text=True, + infer=True, + text_cache=text_cfg_cache, + dt_cache=dt_cache, + ) + neg = neg.transpose(2, 1) + if self.use_conditioner_cache: + text_cfg_cache = text_cfg_emb + v_pred = v_pred + (v_pred - neg) * inference_cfg_rate + x = x + d * v_pred + t = t + d + x[:, :, :prompt_len] = 0 + return x + + def forward(self, x1, x_lens, prompt_lens, mu, use_grad_ckpt): + b, _, t = x1.shape + t = torch.rand([b], device=mu.device, dtype=x1.dtype) + x0 = torch.randn_like(x1, device=mu.device) + vt = x1 - x0 + xt = x0 + t[:, None, None] * vt + dt = torch.zeros_like(t, device=mu.device) + prompt = torch.zeros_like(x1) + for i in range(b): + prompt[i, :, : prompt_lens[i]] = x1[i, :, : prompt_lens[i]] + xt[i, :, : prompt_lens[i]] = 0 + gailv = 0.3 # if ttime()>1736250488 else 0.1 + if random.random() < gailv: + base = torch.randint(2, 8, (t.shape[0],), device=mu.device) + d = 1 / torch.pow(2, base) + d_input = d.clone() + d_input[d_input < 1e-2] = 0 + # with torch.no_grad(): + v_pred_1 = self.estimator(xt, prompt, x_lens, t, d_input, mu, use_grad_ckpt).transpose(2, 1).detach() + # v_pred_1 = self.diffusion(xt, t, d_input, cond=conditioning).detach() + x_mid = xt + d[:, None, None] * v_pred_1 + # v_pred_2 = self.diffusion(x_mid, t+d, d_input, cond=conditioning).detach() + v_pred_2 = self.estimator(x_mid, prompt, x_lens, t + d, d_input, mu, use_grad_ckpt).transpose(2, 1).detach() + vt = (v_pred_1 + v_pred_2) / 2 + vt = vt.detach() + dt = 2 * d + + vt_pred = self.estimator(xt, prompt, x_lens, t, dt, mu, use_grad_ckpt).transpose(2, 1) + loss = 0 + for i in range(b): + loss += self.criterion(vt_pred[i, :, prompt_lens[i] : x_lens[i]], vt[i, :, prompt_lens[i] : x_lens[i]]) + loss /= b + + return loss + + +def set_no_grad(net_g): + for name, param in net_g.named_parameters(): + param.requires_grad = False + + +class SynthesizerTrnV3(nn.Module): + """ + Synthesizer for Training + """ + + def __init__( + self, + spec_channels, + segment_size, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + n_speakers=0, + gin_channels=0, + use_sdp=True, + semantic_frame_rate=None, + freeze_quantizer=None, + version="v3", + **kwargs, + ): + super().__init__() + self.spec_channels = spec_channels + self.inter_channels = inter_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.resblock = resblock + self.resblock_kernel_sizes = resblock_kernel_sizes + self.resblock_dilation_sizes = resblock_dilation_sizes + self.upsample_rates = upsample_rates + self.upsample_initial_channel = upsample_initial_channel + self.upsample_kernel_sizes = upsample_kernel_sizes + self.segment_size = segment_size + self.n_speakers = n_speakers + 
self.gin_channels = gin_channels + self.version = version + + self.model_dim = 512 + self.use_sdp = use_sdp + self.enc_p = TextEncoder( + inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout + ) + # self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)###Rollback + self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels) ###Rollback + # self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, + # upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels) + # self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, + # gin_channels=gin_channels) + # self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels) + + ssl_dim = 768 + assert semantic_frame_rate in ["25hz", "50hz"] + self.semantic_frame_rate = semantic_frame_rate + if semantic_frame_rate == "25hz": + self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2) + else: + self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1) + + self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024) + self.freeze_quantizer = freeze_quantizer + inter_channels2 = 512 + self.bridge = nn.Sequential(nn.Conv1d(inter_channels, inter_channels2, 1, stride=1), nn.LeakyReLU()) + self.wns1 = Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8, gin_channels=gin_channels) + self.linear_mel = nn.Conv1d(inter_channels2, 100, 1, stride=1) + self.cfm = CFM( + 100, + DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)), + ) # text_dim is condition feature dim + if self.freeze_quantizer == True: + set_no_grad(self.ssl_proj) + set_no_grad(self.quantizer) + set_no_grad(self.enc_p) + + def forward( + self, ssl, y, mel, ssl_lengths, y_lengths, text, text_lengths, mel_lengths, use_grad_ckpt + ): # ssl_lengths no need now + with autocast(enabled=False): + y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype) + ge = self.ref_enc(y[:, :704] * y_mask, y_mask) + maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext() + with maybe_no_grad: + if self.freeze_quantizer: + self.ssl_proj.eval() # + self.quantizer.eval() + self.enc_p.eval() + ssl = self.ssl_proj(ssl) + quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0]) + quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT + x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge) + fea = self.bridge(x) + fea = F.interpolate(fea, scale_factor=(1.875 if self.version == "v3" else 2), mode="nearest") ##BCT + fea, y_mask_ = self.wns1( + fea, mel_lengths, ge + ) ##If the 1-minute fine-tuning works fine, no need to manually adjust the learning rate. 
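+        # CFM training step (see CFM.forward): for each item a random prefix of the
+        # target mel, up to 2/3 of its length, is kept as the acoustic prompt; the
+        # flow-matching loss below is computed only on the remaining non-prompt frames,
+        # using fea as the frame-level condition.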
+ B = ssl.shape[0] + prompt_len_max = mel_lengths * 2 / 3 + prompt_len = (torch.rand([B], device=fea.device) * prompt_len_max).floor().to(dtype=torch.long) + minn = min(mel.shape[-1], fea.shape[-1]) + mel = mel[:, :, :minn] + fea = fea[:, :, :minn] + cfm_loss = self.cfm(mel, mel_lengths, prompt_len, fea, use_grad_ckpt) + return cfm_loss + + @torch.no_grad() + def decode_encp(self, codes, text, refer, ge=None, speed=1): + # print(2333333,refer.shape) + # ge=None + if ge == None: + refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device) + refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype) + ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask) + y_lengths = torch.LongTensor([int(codes.size(2) * 2)]).to(codes.device) + if speed == 1: + sizee = int(codes.size(2) * (3.875 if self.version == "v3" else 4)) + else: + sizee = int(codes.size(2) * (3.875 if self.version == "v3" else 4) / speed) + 1 + y_lengths1 = torch.LongTensor([sizee]).to(codes.device) + text_lengths = torch.LongTensor([text.size(-1)]).to(text.device) + + quantized = self.quantizer.decode(codes) + if self.semantic_frame_rate == "25hz": + quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT + x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed) + fea = self.bridge(x) + fea = F.interpolate(fea, scale_factor=(1.875 if self.version == "v3" else 2), mode="nearest") ##BCT + ####more wn paramter to learn mel + fea, y_mask_ = self.wns1(fea, y_lengths1, ge) + return fea, ge + + def extract_latent(self, x): + ssl = self.ssl_proj(x) + quantized, codes, commit_loss, quantized_list = self.quantizer(ssl) + return codes.transpose(0, 1) + + +class SynthesizerTrnV3b(nn.Module): + """ + Synthesizer for Training + """ + + def __init__( + self, + spec_channels, + segment_size, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + n_speakers=0, + gin_channels=0, + use_sdp=True, + semantic_frame_rate=None, + freeze_quantizer=None, + **kwargs, + ): + super().__init__() + self.spec_channels = spec_channels + self.inter_channels = inter_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.resblock = resblock + self.resblock_kernel_sizes = resblock_kernel_sizes + self.resblock_dilation_sizes = resblock_dilation_sizes + self.upsample_rates = upsample_rates + self.upsample_initial_channel = upsample_initial_channel + self.upsample_kernel_sizes = upsample_kernel_sizes + self.segment_size = segment_size + self.n_speakers = n_speakers + self.gin_channels = gin_channels + + self.model_dim = 512 + self.use_sdp = use_sdp + self.enc_p = TextEncoder( + inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout + ) + # self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)###Rollback + self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels) ###Rollback + self.dec = Generator( + inter_channels, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels=gin_channels, + ) + self.enc_q = PosteriorEncoder( + spec_channels, inter_channels, 
hidden_channels, 5, 1, 16, gin_channels=gin_channels + ) + self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels) + + ssl_dim = 768 + assert semantic_frame_rate in ["25hz", "50hz"] + self.semantic_frame_rate = semantic_frame_rate + if semantic_frame_rate == "25hz": + self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2) + else: + self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1) + + self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024) + self.freeze_quantizer = freeze_quantizer + + inter_channels2 = 512 + self.bridge = nn.Sequential(nn.Conv1d(inter_channels, inter_channels2, 1, stride=1), nn.LeakyReLU()) + self.wns1 = Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8, gin_channels=gin_channels) + self.linear_mel = nn.Conv1d(inter_channels2, 100, 1, stride=1) + self.cfm = CFM( + 100, + DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)), + ) # text_dim is condition feature dim + + def forward(self, ssl, y, mel, ssl_lengths, y_lengths, text, text_lengths, mel_lengths): # ssl_lengths no need now + with autocast(enabled=False): + y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype) + ge = self.ref_enc(y[:, :704] * y_mask, y_mask) + # ge = self.ref_enc(y * y_mask, y_mask)#change back, new spec setting is whole 24k + # ge=None + maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext() + with maybe_no_grad: + if self.freeze_quantizer: + self.ssl_proj.eval() + self.quantizer.eval() + ssl = self.ssl_proj(ssl) + quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0]) + quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT + x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge) + z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge) + z_p = self.flow(z, y_mask, g=ge) + z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size) + o = self.dec(z_slice, g=ge) + fea = self.bridge(x) + fea = F.interpolate(fea, scale_factor=1.875, mode="nearest") ##BCT + fea, y_mask_ = self.wns1(fea, mel_lengths, ge) + learned_mel = self.linear_mel(fea) + B = ssl.shape[0] + prompt_len_max = mel_lengths * 2 / 3 + prompt_len = (torch.rand([B], device=fea.device) * prompt_len_max).floor().to(dtype=torch.long) # + minn = min(mel.shape[-1], fea.shape[-1]) + mel = mel[:, :, :minn] + fea = fea[:, :, :minn] + cfm_loss = self.cfm(mel, mel_lengths, prompt_len, fea) # fea==cond,y_lengths==target_mel_lengths#ge not need + return ( + commit_loss, + cfm_loss, + F.mse_loss(learned_mel, mel), + o, + ids_slice, + y_mask, + y_mask, + (z, z_p, m_p, logs_p, m_q, logs_q), + quantized, + ) + + @torch.no_grad() + def decode_encp(self, codes, text, refer, ge=None): + # print(2333333,refer.shape) + # ge=None + if ge == None: + refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device) + refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype) + ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask) + y_lengths = torch.LongTensor([int(codes.size(2) * 2)]).to(codes.device) + y_lengths1 = torch.LongTensor([int(codes.size(2) * 2.5 * 1.5)]).to(codes.device) + text_lengths = torch.LongTensor([text.size(-1)]).to(text.device) + + quantized = self.quantizer.decode(codes) + if self.semantic_frame_rate == "25hz": + quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT + x, m_p, logs_p, 
y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge) + fea = self.bridge(x) + fea = F.interpolate(fea, scale_factor=1.875, mode="nearest") ##BCT + ####more wn paramter to learn mel + fea, y_mask_ = self.wns1(fea, y_lengths1, ge) + return fea, ge + + def extract_latent(self, x): + ssl = self.ssl_proj(x) + quantized, codes, commit_loss, quantized_list = self.quantizer(ssl) + return codes.transpose(0, 1) diff --git a/module/models_onnx.py b/module/models_onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..b62b8b711cdf8759a26e192876f982f324197eb1 --- /dev/null +++ b/module/models_onnx.py @@ -0,0 +1,1087 @@ +import math +from typing import Optional +import torch +from torch import nn +from torch.nn import functional as F + +from module import commons +from module import modules +from module import attentions_onnx as attentions + +from f5_tts.model import DiT + +from torch.nn import Conv1d, ConvTranspose1d, Conv2d +from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm +from module.commons import init_weights, get_padding +from module.quantize import ResidualVectorQuantizer + +# from text import symbols +from text import symbols as symbols_v1 +from text import symbols2 as symbols_v2 + + +class StochasticDurationPredictor(nn.Module): + def __init__( + self, + in_channels, + filter_channels, + kernel_size, + p_dropout, + n_flows=4, + gin_channels=0, + ): + super().__init__() + filter_channels = in_channels # it needs to be removed from future version. + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.log_flow = modules.Log() + self.flows = nn.ModuleList() + self.flows.append(modules.ElementwiseAffine(2)) + for i in range(n_flows): + self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) + self.flows.append(modules.Flip()) + + self.post_pre = nn.Conv1d(1, filter_channels, 1) + self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1) + self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) + self.post_flows = nn.ModuleList() + self.post_flows.append(modules.ElementwiseAffine(2)) + for i in range(4): + self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) + self.post_flows.append(modules.Flip()) + + self.pre = nn.Conv1d(in_channels, filter_channels, 1) + self.proj = nn.Conv1d(filter_channels, filter_channels, 1) + self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, filter_channels, 1) + + def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0): + x = torch.detach(x) + x = self.pre(x) + if g is not None: + g = torch.detach(g) + x = x + self.cond(g) + x = self.convs(x, x_mask) + x = self.proj(x) * x_mask + + if not reverse: + flows = self.flows + assert w is not None + + logdet_tot_q = 0 + h_w = self.post_pre(w) + h_w = self.post_convs(h_w, x_mask) + h_w = self.post_proj(h_w) * x_mask + e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask + z_q = e_q + for flow in self.post_flows: + z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w)) + logdet_tot_q += logdet_q + z_u, z1 = torch.split(z_q, [1, 1], 1) + u = torch.sigmoid(z_u) * x_mask + z0 = (w - u) * x_mask + logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 
2]) + logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2]) - logdet_tot_q + + logdet_tot = 0 + z0, logdet = self.log_flow(z0, x_mask) + logdet_tot += logdet + z = torch.cat([z0, z1], 1) + for flow in flows: + z, logdet = flow(z, x_mask, g=x, reverse=reverse) + logdet_tot = logdet_tot + logdet + nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) - logdet_tot + return nll + logq # [b] + else: + flows = list(reversed(self.flows)) + flows = flows[:-2] + [flows[-1]] # remove a useless vflow + z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale + for flow in flows: + z = flow(z, x_mask, g=x, reverse=reverse) + z0, z1 = torch.split(z, [1, 1], 1) + logw = z0 + return logw + + +class DurationPredictor(nn.Module): + def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0): + super().__init__() + + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.gin_channels = gin_channels + + self.drop = nn.Dropout(p_dropout) + self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2) + self.norm_1 = modules.LayerNorm(filter_channels) + self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2) + self.norm_2 = modules.LayerNorm(filter_channels) + self.proj = nn.Conv1d(filter_channels, 1, 1) + + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, in_channels, 1) + + def forward(self, x, x_mask, g=None): + x = torch.detach(x) + if g is not None: + g = torch.detach(g) + x = x + self.cond(g) + x = self.conv_1(x * x_mask) + x = torch.relu(x) + x = self.norm_1(x) + x = self.drop(x) + x = self.conv_2(x * x_mask) + x = torch.relu(x) + x = self.norm_2(x) + x = self.drop(x) + x = self.proj(x * x_mask) + return x * x_mask + + +class TextEncoder(nn.Module): + def __init__( + self, + out_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + latent_channels=192, + version="v2", + ): + super().__init__() + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.latent_channels = latent_channels + self.version = version + + self.ssl_proj = nn.Conv1d(768, hidden_channels, 1) + + self.encoder_ssl = attentions.Encoder( + hidden_channels, + filter_channels, + n_heads, + n_layers // 2, + kernel_size, + p_dropout, + ) + + self.encoder_text = attentions.Encoder( + hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout + ) + + if self.version == "v1": + symbols = symbols_v1.symbols + else: + symbols = symbols_v2.symbols + self.text_embedding = nn.Embedding(len(symbols), hidden_channels) + + self.mrte = attentions.MRTE() + + self.encoder2 = attentions.Encoder( + hidden_channels, + filter_channels, + n_heads, + n_layers // 2, + kernel_size, + p_dropout, + ) + + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, y, text, ge, speed=1): + y_mask = torch.ones_like(y[:1, :1, :]) + + y = self.ssl_proj(y * y_mask) * y_mask + y = self.encoder_ssl(y * y_mask, y_mask) + + text_mask = torch.ones_like(text).to(y.dtype).unsqueeze(0) + + text = self.text_embedding(text).transpose(1, 2) + text = self.encoder_text(text * text_mask, text_mask) + y = self.mrte(y, y_mask, text, text_mask, ge) + + y = 
self.encoder2(y * y_mask, y_mask) + if speed != 1: + y = F.interpolate(y, size=int(y.shape[-1] / speed) + 1, mode="linear") + y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest") + + stats = self.proj(y) * y_mask + m, logs = torch.split(stats, self.out_channels, dim=1) + return y, m, logs, y_mask + + +class ResidualCouplingBlock(nn.Module): + def __init__( + self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + n_flows=4, + gin_channels=0, + ): + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.flows = nn.ModuleList() + for i in range(n_flows): + self.flows.append( + modules.ResidualCouplingLayer( + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=gin_channels, + mean_only=True, + ) + ) + self.flows.append(modules.Flip()) + + def forward(self, x, x_mask, g=None, reverse=False): + if not reverse: + for flow in self.flows: + x, _ = flow(x, x_mask, g=g, reverse=reverse) + else: + for flow in reversed(self.flows): + x = flow(x, x_mask, g=g, reverse=reverse) + return x + + +class PosteriorEncoder(nn.Module): + def __init__( + self, + in_channels, + out_channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=0, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + + self.pre = nn.Conv1d(in_channels, hidden_channels, 1) + self.enc = modules.WN( + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=gin_channels, + ) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths, g=None): + if g != None: + g = g.detach() + x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + x = self.pre(x) * x_mask + x = self.enc(x, x_mask, g=g) + stats = self.proj(x) * x_mask + m, logs = torch.split(stats, self.out_channels, dim=1) + z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask + return z, m, logs, x_mask + + +class Encoder(nn.Module): + def __init__( + self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0 + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + + self.pre = nn.Conv1d(in_channels, hidden_channels, 1) + self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels) + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + + def forward(self, x, x_lengths, g=None): + if g != None: + g = g.detach() + x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + x = self.pre(x) * x_mask + x = self.enc(x, x_mask, g=g) + stats = self.proj(x) * x_mask + return stats, x_mask + + +class WNEncoder(nn.Module): + def __init__( + self, + in_channels, + out_channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=0, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = 
hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + + self.pre = nn.Conv1d(in_channels, hidden_channels, 1) + self.enc = modules.WN( + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=gin_channels, + ) + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + self.norm = modules.LayerNorm(out_channels) + + def forward(self, x, x_lengths, g=None): + x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + x = self.pre(x) * x_mask + x = self.enc(x, x_mask, g=g) + out = self.proj(x) * x_mask + out = self.norm(out) + return out + + +class Generator(torch.nn.Module): + def __init__( + self, + initial_channel, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels=0, + is_bias=False, + ): + super(Generator, self).__init__() + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3) + resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): + self.resblocks.append(resblock(ch, k, d)) + + self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=is_bias) + self.ups.apply(init_weights) + + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) + + def forward(self, x, g: Optional[torch.Tensor] = None): + x = self.conv_pre(x) + if g is not None: + x = x + self.cond(g) + + for i in range(self.num_upsamples): + x = F.leaky_relu(x, modules.LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print("Removing weight norm...") + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + self.use_spectral_norm = use_spectral_norm + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f( + Conv2d( + 1, + 32, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ), + norm_f( + Conv2d( + 32, + 128, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ), + norm_f( + Conv2d( + 128, + 512, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ), + norm_f( + Conv2d( + 512, + 1024, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ), + norm_f( + Conv2d( + 1024, + 1024, + (kernel_size, 1), + 1, + padding=(get_padding(kernel_size, 
1), 0), + ) + ), + ] + ) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, modules.LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f(Conv1d(1, 16, 15, 1, padding=7)), + norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), + norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ] + ) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, modules.LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(MultiPeriodDiscriminator, self).__init__() + periods = [2, 3, 5, 7, 11] + + discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)] + discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods] + self.discriminators = nn.ModuleList(discs) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + y_d_gs.append(y_d_g) + fmap_rs.append(fmap_r) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class ReferenceEncoder(nn.Module): + """ + inputs --- [N, Ty/r, n_mels*r] mels + outputs --- [N, ref_enc_gru_size] + """ + + def __init__(self, spec_channels, gin_channels=0): + super().__init__() + self.spec_channels = spec_channels + ref_enc_filters = [32, 32, 64, 64, 128, 128] + K = len(ref_enc_filters) + filters = [1] + ref_enc_filters + convs = [ + weight_norm( + nn.Conv2d( + in_channels=filters[i], + out_channels=filters[i + 1], + kernel_size=(3, 3), + stride=(2, 2), + padding=(1, 1), + ) + ) + for i in range(K) + ] + self.convs = nn.ModuleList(convs) + # self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)]) + + out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K) + self.gru = nn.GRU( + input_size=ref_enc_filters[-1] * out_channels, + hidden_size=256 // 2, + batch_first=True, + ) + self.proj = nn.Linear(128, gin_channels) + + def forward(self, inputs): + N = inputs.size(0) + out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs] + for conv in self.convs: + out = conv(out) + # out = wn(out) + out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K] + + out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K] + T = out.size(1) + N = out.size(0) + out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K] + + self.gru.flatten_parameters() + memory, out = self.gru(out) # out --- [1, N, 128] + + return self.proj(out.squeeze(0)).unsqueeze(-1) + + def 
calculate_channels(self, L, kernel_size, stride, pad, n_convs): + for i in range(n_convs): + L = (L - kernel_size + 2 * pad) // stride + 1 + return L + + +class Quantizer_module(torch.nn.Module): + def __init__(self, n_e, e_dim): + super(Quantizer_module, self).__init__() + self.embedding = nn.Embedding(n_e, e_dim) + self.embedding.weight.data.uniform_(-1.0 / n_e, 1.0 / n_e) + + def forward(self, x): + d = ( + torch.sum(x**2, 1, keepdim=True) + + torch.sum(self.embedding.weight**2, 1) + - 2 * torch.matmul(x, self.embedding.weight.T) + ) + min_indicies = torch.argmin(d, 1) + z_q = self.embedding(min_indicies) + return z_q, min_indicies + + +class Quantizer(torch.nn.Module): + def __init__(self, embed_dim=512, n_code_groups=4, n_codes=160): + super(Quantizer, self).__init__() + assert embed_dim % n_code_groups == 0 + self.quantizer_modules = nn.ModuleList( + [Quantizer_module(n_codes, embed_dim // n_code_groups) for _ in range(n_code_groups)] + ) + self.n_code_groups = n_code_groups + self.embed_dim = embed_dim + + def forward(self, xin): + # B, C, T + B, C, T = xin.shape + xin = xin.transpose(1, 2) + x = xin.reshape(-1, self.embed_dim) + x = torch.split(x, self.embed_dim // self.n_code_groups, dim=-1) + min_indicies = [] + z_q = [] + for _x, m in zip(x, self.quantizer_modules): + _z_q, _min_indicies = m(_x) + z_q.append(_z_q) + min_indicies.append(_min_indicies) # B * T, + z_q = torch.cat(z_q, -1).reshape(xin.shape) + loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean((z_q - xin.detach()) ** 2) + z_q = xin + (z_q - xin).detach() + z_q = z_q.transpose(1, 2) + codes = torch.stack(min_indicies, -1).reshape(B, T, self.n_code_groups) + return z_q, loss, codes.transpose(1, 2) + + def embed(self, x): + # idx: N, 4, T + x = x.transpose(1, 2) + x = torch.split(x, 1, 2) + ret = [] + for q, embed in zip(x, self.quantizer_modules): + q = embed.embedding(q.squeeze(-1)) + ret.append(q) + ret = torch.cat(ret, -1) + return ret.transpose(1, 2) # N, C, T + + +class CodePredictor(nn.Module): + def __init__( + self, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + n_q=8, + dims=1024, + ssl_dim=768, + ): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + + self.vq_proj = nn.Conv1d(ssl_dim, hidden_channels, 1) + self.ref_enc = modules.MelStyleEncoder(ssl_dim, style_vector_dim=hidden_channels) + + self.encoder = attentions.Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout) + + self.out_proj = nn.Conv1d(hidden_channels, (n_q - 1) * dims, 1) + self.n_q = n_q + self.dims = dims + + def forward(self, x, x_mask, refer, codes, infer=False): + x = x.detach() + x = self.vq_proj(x * x_mask) * x_mask + g = self.ref_enc(refer, x_mask) + x = x + g + x = self.encoder(x * x_mask, x_mask) + x = self.out_proj(x * x_mask) * x_mask + logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(2, 3) + target = codes[1:].transpose(0, 1) + if not infer: + logits = logits.reshape(-1, self.dims) + target = target.reshape(-1) + loss = torch.nn.functional.cross_entropy(logits, target) + return loss + else: + _, top10_preds = torch.topk(logits, 10, dim=-1) + correct_top10 = torch.any(top10_preds == target.unsqueeze(-1), dim=-1) + top3_acc = 100 * torch.mean(correct_top10.float()).detach().cpu().item() + + print("Top-10 Accuracy:", top3_acc, "%") + + pred_codes = 
torch.argmax(logits, dim=-1) + acc = 100 * torch.mean((pred_codes == target).float()).detach().cpu().item() + print("Top-1 Accuracy:", acc, "%") + + return pred_codes.transpose(0, 1) + + +v2pro_set = {"v2Pro", "v2ProPlus"} + + +class SynthesizerTrn(nn.Module): + """ + Synthesizer for Training + """ + + def __init__( + self, + spec_channels, + segment_size, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + n_speakers=0, + gin_channels=0, + use_sdp=True, + semantic_frame_rate=None, + freeze_quantizer=None, + version="v2", + **kwargs, + ): + super().__init__() + self.spec_channels = spec_channels + self.inter_channels = inter_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.resblock = resblock + self.resblock_kernel_sizes = resblock_kernel_sizes + self.resblock_dilation_sizes = resblock_dilation_sizes + self.upsample_rates = upsample_rates + self.upsample_initial_channel = upsample_initial_channel + self.upsample_kernel_sizes = upsample_kernel_sizes + self.segment_size = segment_size + self.n_speakers = n_speakers + self.gin_channels = gin_channels + self.version = version + + self.use_sdp = use_sdp + self.enc_p = TextEncoder( + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + version=version, + ) + self.dec = Generator( + inter_channels, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels=gin_channels, + ) + # self.enc_q = PosteriorEncoder( + # spec_channels, + # inter_channels, + # hidden_channels, + # 5, + # 1, + # 16, + # gin_channels=gin_channels, + # ) + self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels) + + # self.version=os.environ.get("version","v1") + if self.version == "v1": + self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels) + else: + self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels) + + ssl_dim = 768 + self.ssl_dim = ssl_dim + assert semantic_frame_rate in ["25hz", "50hz"] + self.semantic_frame_rate = semantic_frame_rate + if semantic_frame_rate == "25hz": + self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2) + else: + self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1) + + self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024) + if freeze_quantizer: + self.ssl_proj.requires_grad_(False) + self.quantizer.requires_grad_(False) + # self.enc_p.text_embedding.requires_grad_(False) + # self.enc_p.encoder_text.requires_grad_(False) + # self.enc_p.mrte.requires_grad_(False) + self.is_v2pro = self.version in v2pro_set + if self.is_v2pro: + self.sv_emb = nn.Linear(20480, gin_channels) + self.ge_to512 = nn.Linear(gin_channels, 512) + self.prelu = nn.PReLU(num_parameters=gin_channels) + + def forward(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None): + refer_mask = torch.ones_like(refer[:1, :1, :]) + if self.version == "v1": + ge = self.ref_enc(refer * refer_mask, refer_mask) + else: + ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask) + if self.is_v2pro: + sv_emb = self.sv_emb(sv_emb) + ge += sv_emb.unsqueeze(-1) + ge = 
self.prelu(ge) + + quantized = self.quantizer.decode(codes) + if self.semantic_frame_rate == "25hz": + dquantized = torch.cat([quantized, quantized]).permute(1, 2, 0) + quantized = dquantized.contiguous().view(1, self.ssl_dim, -1) + + if self.is_v2pro: + ge_ = self.ge_to512(ge.transpose(2, 1)).transpose(2, 1) + x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge_, speed) + else: + x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge, speed) + + z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale + + z = self.flow(z_p, y_mask, g=ge, reverse=True) + + o = self.dec((z * y_mask)[:, :, :], g=ge) + return o + + def extract_latent(self, x): + ssl = self.ssl_proj(x) + quantized, codes, commit_loss, quantized_list = self.quantizer(ssl) + return codes.transpose(0, 1) + + +class CFM(torch.nn.Module): + def __init__(self, in_channels, dit): + super().__init__() + # self.sigma_min = 1e-6 + + self.estimator = dit + + self.in_channels = in_channels + + # self.criterion = torch.nn.MSELoss() + + def forward( + self, + mu: torch.Tensor, + x_lens: torch.LongTensor, + prompt: torch.Tensor, + n_timesteps: torch.LongTensor, + temperature: float = 1.0, + ): + """Forward diffusion""" + B, T = mu.size(0), mu.size(1) + x = torch.randn([B, self.in_channels, T], device=mu.device, dtype=mu.dtype) + + ntimesteps = int(n_timesteps) + + prompt_len = prompt.size(-1) + prompt_x = torch.zeros_like(x, dtype=mu.dtype) + prompt_x[..., :prompt_len] = prompt[..., :prompt_len] + x[..., :prompt_len] = 0.0 + mu = mu.transpose(2, 1) + t = torch.tensor(0.0, dtype=x.dtype, device=x.device) + d = torch.tensor(1.0 / ntimesteps, dtype=x.dtype, device=x.device) + d_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * d + + for j in range(ntimesteps): + t_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * t + # d_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * d + # v_pred = model(x, t_tensor, d_tensor, **extra_args) + v_pred = self.estimator(x, prompt_x, x_lens, t_tensor, d_tensor, mu).transpose(2, 1) + # if inference_cfg_rate>1e-5: + # neg = self.estimator(x, prompt_x, x_lens, t_tensor, d_tensor, mu, use_grad_ckpt=False, drop_audio_cond=True, drop_text=True).transpose(2, 1) + # v_pred=v_pred+(v_pred-neg)*inference_cfg_rate + x = x + d * v_pred + t = t + d + x[:, :, :prompt_len] = 0.0 + return x + + +def set_no_grad(net_g): + for name, param in net_g.named_parameters(): + param.requires_grad = False + + +@torch.jit.script_if_tracing +def compile_codes_length(codes): + y_lengths1 = torch.LongTensor([codes.size(2)]).to(codes.device) + return y_lengths1 * 2.5 * 1.5 + + +@torch.jit.script_if_tracing +def compile_ref_length(refer): + refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device) + return refer_lengths + + +class SynthesizerTrnV3(nn.Module): + """ + Synthesizer for Training + """ + + def __init__( + self, + spec_channels, + segment_size, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + n_speakers=0, + gin_channels=0, + use_sdp=True, + semantic_frame_rate=None, + freeze_quantizer=None, + version="v3", + **kwargs, + ): + super().__init__() + self.spec_channels = spec_channels + self.inter_channels = inter_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size 
= kernel_size + self.p_dropout = p_dropout + self.resblock = resblock + self.resblock_kernel_sizes = resblock_kernel_sizes + self.resblock_dilation_sizes = resblock_dilation_sizes + self.upsample_rates = upsample_rates + self.upsample_initial_channel = upsample_initial_channel + self.upsample_kernel_sizes = upsample_kernel_sizes + self.segment_size = segment_size + self.n_speakers = n_speakers + self.gin_channels = gin_channels + self.version = version + + self.model_dim = 512 + self.use_sdp = use_sdp + self.enc_p = TextEncoder( + inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout + ) + # self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)###Rollback + self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels) ###Rollback + # self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, + # upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels) + # self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, + # gin_channels=gin_channels) + # self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels) + + ssl_dim = 768 + assert semantic_frame_rate in ["25hz", "50hz"] + self.semantic_frame_rate = semantic_frame_rate + if semantic_frame_rate == "25hz": + self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2) + else: + self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1) + + self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024) + freeze_quantizer + inter_channels2 = 512 + self.bridge = nn.Sequential(nn.Conv1d(inter_channels, inter_channels2, 1, stride=1), nn.LeakyReLU()) + self.wns1 = Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8, gin_channels=gin_channels) + self.linear_mel = nn.Conv1d(inter_channels2, 100, 1, stride=1) + self.cfm = CFM( + 100, + DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)), + ) # text_dim is condition feature dim + if freeze_quantizer == True: + set_no_grad(self.ssl_proj) + set_no_grad(self.quantizer) + set_no_grad(self.enc_p) + + def create_ge(self, refer): + refer_lengths = compile_ref_length(refer) + refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype) + ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask) + return ge + + def forward(self, codes, text, ge, speed=1): + y_lengths1 = compile_codes_length(codes) + + quantized = self.quantizer.decode(codes) + if self.semantic_frame_rate == "25hz": + quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT + x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge, speed) + fea = self.bridge(x) + fea = F.interpolate(fea, scale_factor=1.875, mode="nearest") ##BCT + ####more wn paramter to learn mel + fea, y_mask_ = self.wns1(fea, y_lengths1, ge) + return fea + + def extract_latent(self, x): + ssl = self.ssl_proj(x) + quantized, codes, commit_loss, quantized_list = self.quantizer(ssl) + return codes.transpose(0, 1) diff --git a/module/modules.py b/module/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..6fa84a43ffd63e05692c55ba424521d23f357d25 --- /dev/null +++ b/module/modules.py @@ -0,0 +1,897 @@ +import math + +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + +from torch.nn import Conv1d +from torch.nn.utils import weight_norm, remove_weight_norm + +from 
module import commons +from module.commons import init_weights, get_padding +from module.transforms import piecewise_rational_quadratic_transform +import torch.distributions as D + + +LRELU_SLOPE = 0.1 + + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-5): + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = nn.Parameter(torch.ones(channels)) + self.beta = nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + x = x.transpose(1, -1) + x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) + return x.transpose(1, -1) + + +class ConvReluNorm(nn.Module): + def __init__( + self, + in_channels, + hidden_channels, + out_channels, + kernel_size, + n_layers, + p_dropout, + ): + super().__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + assert n_layers > 1, "Number of layers should be larger than 0." + + self.conv_layers = nn.ModuleList() + self.norm_layers = nn.ModuleList() + self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2)) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout)) + for _ in range(n_layers - 1): + self.conv_layers.append( + nn.Conv1d( + hidden_channels, + hidden_channels, + kernel_size, + padding=kernel_size // 2, + ) + ) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask): + x_org = x + for i in range(self.n_layers): + x = self.conv_layers[i](x * x_mask) + x = self.norm_layers[i](x) + x = self.relu_drop(x) + x = x_org + self.proj(x) + return x * x_mask + + +class DDSConv(nn.Module): + """ + Dialted and Depth-Separable Convolution + """ + + def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0): + super().__init__() + self.channels = channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + + self.drop = nn.Dropout(p_dropout) + self.convs_sep = nn.ModuleList() + self.convs_1x1 = nn.ModuleList() + self.norms_1 = nn.ModuleList() + self.norms_2 = nn.ModuleList() + for i in range(n_layers): + dilation = kernel_size**i + padding = (kernel_size * dilation - dilation) // 2 + self.convs_sep.append( + nn.Conv1d( + channels, + channels, + kernel_size, + groups=channels, + dilation=dilation, + padding=padding, + ) + ) + self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) + self.norms_1.append(LayerNorm(channels)) + self.norms_2.append(LayerNorm(channels)) + + def forward(self, x, x_mask, g=None): + if g is not None: + x = x + g + for i in range(self.n_layers): + y = self.convs_sep[i](x * x_mask) + y = self.norms_1[i](y) + y = F.gelu(y) + y = self.convs_1x1[i](y) + y = self.norms_2[i](y) + y = F.gelu(y) + y = self.drop(y) + x = x + y + return x * x_mask + + +class WN(torch.nn.Module): + def __init__( + self, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=0, + p_dropout=0, + ): + super(WN, self).__init__() + assert kernel_size % 2 == 1 + self.hidden_channels = hidden_channels + self.kernel_size = (kernel_size,) + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + self.p_dropout = p_dropout + + self.in_layers = torch.nn.ModuleList() + self.res_skip_layers = 
torch.nn.ModuleList() + self.drop = nn.Dropout(p_dropout) + + if gin_channels != 0: + cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1) + self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight") + + for i in range(n_layers): + dilation = dilation_rate**i + padding = int((kernel_size * dilation - dilation) / 2) + in_layer = torch.nn.Conv1d( + hidden_channels, + 2 * hidden_channels, + kernel_size, + dilation=dilation, + padding=padding, + ) + in_layer = torch.nn.utils.weight_norm(in_layer, name="weight") + self.in_layers.append(in_layer) + + # last one is not necessary + if i < n_layers - 1: + res_skip_channels = 2 * hidden_channels + else: + res_skip_channels = hidden_channels + + res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) + res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight") + self.res_skip_layers.append(res_skip_layer) + + def forward(self, x, x_mask, g=None, **kwargs): + output = torch.zeros_like(x) + n_channels_tensor = torch.IntTensor([self.hidden_channels]) + + if g is not None: + g = self.cond_layer(g) + + for i in range(self.n_layers): + x_in = self.in_layers[i](x) + if g is not None: + cond_offset = i * 2 * self.hidden_channels + g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :] + else: + g_l = torch.zeros_like(x_in) + + acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor) + acts = self.drop(acts) + + res_skip_acts = self.res_skip_layers[i](acts) + if i < self.n_layers - 1: + res_acts = res_skip_acts[:, : self.hidden_channels, :] + x = (x + res_acts) * x_mask + output = output + res_skip_acts[:, self.hidden_channels :, :] + else: + output = output + res_skip_acts + return output * x_mask + + def remove_weight_norm(self): + if self.gin_channels != 0: + torch.nn.utils.remove_weight_norm(self.cond_layer) + for l in self.in_layers: + torch.nn.utils.remove_weight_norm(l) + for l in self.res_skip_layers: + torch.nn.utils.remove_weight_norm(l) + + +class ResBlock1(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__() + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + self.convs2.apply(init_weights) + + def forward(self, x, x_mask=None): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + if x_mask is not None: + xt = xt * x_mask + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + if x_mask is not None: + xt = xt * x_mask + xt = c2(xt) + x = xt + x + if x_mask is not None: + x = x * x_mask + return x + + def remove_weight_norm(self): + for l in 
self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock2, self).__init__() + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + ] + ) + self.convs.apply(init_weights) + + def forward(self, x, x_mask=None): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + if x_mask is not None: + xt = xt * x_mask + xt = c(xt) + x = xt + x + if x_mask is not None: + x = x * x_mask + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class Log(nn.Module): + def forward(self, x, x_mask, reverse=False, **kwargs): + if not reverse: + y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask + logdet = torch.sum(-y, [1, 2]) + return y, logdet + else: + x = torch.exp(x) * x_mask + return x + + +class Flip(nn.Module): + def forward(self, x, *args, reverse=False, **kwargs): + x = torch.flip(x, [1]) + if not reverse: + logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) + return x, logdet + else: + return x + + +class ElementwiseAffine(nn.Module): + def __init__(self, channels): + super().__init__() + self.channels = channels + self.m = nn.Parameter(torch.zeros(channels, 1)) + self.logs = nn.Parameter(torch.zeros(channels, 1)) + + def forward(self, x, x_mask, reverse=False, **kwargs): + if not reverse: + y = self.m + torch.exp(self.logs) * x + y = y * x_mask + logdet = torch.sum(self.logs * x_mask, [1, 2]) + return y, logdet + else: + x = (x - self.m) * torch.exp(-self.logs) * x_mask + return x + + +class ResidualCouplingLayer(nn.Module): + def __init__( + self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + p_dropout=0, + gin_channels=0, + mean_only=False, + ): + assert channels % 2 == 0, "channels should be divisible by 2" + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.half_channels = channels // 2 + self.mean_only = mean_only + + self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) + self.enc = WN( + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + p_dropout=p_dropout, + gin_channels=gin_channels, + ) + self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) + self.post.weight.data.zero_() + self.post.bias.data.zero_() + + def forward(self, x, x_mask, g=None, reverse=False): + x0, x1 = torch.split(x, [self.half_channels] * 2, 1) + h = self.pre(x0) * x_mask + h = self.enc(h, x_mask, g=g) + stats = self.post(h) * x_mask + if not self.mean_only: + m, logs = torch.split(stats, [self.half_channels] * 2, 1) + else: + m = stats + logs = torch.zeros_like(m) + + if not reverse: + x1 = m + x1 * torch.exp(logs) * x_mask + x = torch.cat([x0, x1], 1) + logdet = torch.sum(logs, [1, 2]) + return x, logdet + else: + x1 = (x1 - m) * torch.exp(-logs) * x_mask + x = torch.cat([x0, x1], 1) + return x + + +class ConvFlow(nn.Module): + def __init__( + self, + in_channels, + filter_channels, + kernel_size, + n_layers, + num_bins=10, + tail_bound=5.0, + ): + super().__init__() + self.in_channels = in_channels + self.filter_channels = 
filter_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.num_bins = num_bins + self.tail_bound = tail_bound + self.half_channels = in_channels // 2 + + self.pre = nn.Conv1d(self.half_channels, filter_channels, 1) + self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0) + self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask, g=None, reverse=False): + x0, x1 = torch.split(x, [self.half_channels] * 2, 1) + h = self.pre(x0) + h = self.convs(h, x_mask, g=g) + h = self.proj(h) * x_mask + + b, c, t = x0.shape + h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?] + + unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels) + unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(self.filter_channels) + unnormalized_derivatives = h[..., 2 * self.num_bins :] + + x1, logabsdet = piecewise_rational_quadratic_transform( + x1, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=reverse, + tails="linear", + tail_bound=self.tail_bound, + ) + + x = torch.cat([x0, x1], 1) * x_mask + logdet = torch.sum(logabsdet * x_mask, [1, 2]) + if not reverse: + return x, logdet + else: + return x + + +class LinearNorm(nn.Module): + def __init__( + self, + in_channels, + out_channels, + bias=True, + spectral_norm=False, + ): + super(LinearNorm, self).__init__() + self.fc = nn.Linear(in_channels, out_channels, bias) + + if spectral_norm: + self.fc = nn.utils.spectral_norm(self.fc) + + def forward(self, input): + out = self.fc(input) + return out + + +class Mish(nn.Module): + def __init__(self): + super(Mish, self).__init__() + + def forward(self, x): + return x * torch.tanh(F.softplus(x)) + + +class Conv1dGLU(nn.Module): + """ + Conv1d + GLU(Gated Linear Unit) with residual connection. + For GLU refer to https://arxiv.org/abs/1612.08083 paper. 
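+
+    Example (illustrative sketch only; the sizes below are assumptions, not values
+    used elsewhere in this PR). Internally the convolution doubles the channel
+    dimension and one half gates the other through a sigmoid (GLU), so the residual
+    add requires in_channels == out_channels:
+
+        glu = Conv1dGLU(in_channels=128, out_channels=128, kernel_size=5, dropout=0.1)
+        x = torch.randn(1, 128, 100)   # [B, in_channels, T]
+        y = glu(x)                     # [B, out_channels, T], gated output + residual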
+ """ + + def __init__(self, in_channels, out_channels, kernel_size, dropout): + super(Conv1dGLU, self).__init__() + self.out_channels = out_channels + self.conv1 = ConvNorm(in_channels, 2 * out_channels, kernel_size=kernel_size) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + residual = x + x = self.conv1(x) + x1, x2 = torch.split(x, split_size_or_sections=self.out_channels, dim=1) + x = x1 * torch.sigmoid(x2) + x = residual + self.dropout(x) + return x + + +class ConvNorm(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=None, + dilation=1, + bias=True, + spectral_norm=False, + ): + super(ConvNorm, self).__init__() + + if padding is None: + assert kernel_size % 2 == 1 + padding = int(dilation * (kernel_size - 1) / 2) + + self.conv = torch.nn.Conv1d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias, + ) + + if spectral_norm: + self.conv = nn.utils.spectral_norm(self.conv) + + def forward(self, input): + out = self.conv(input) + return out + + +class MultiHeadAttention(nn.Module): + """Multi-Head Attention module""" + + def __init__(self, n_head, d_model, d_k, d_v, dropout=0.0, spectral_norm=False): + super().__init__() + + self.n_head = n_head + self.d_k = d_k + self.d_v = d_v + + self.w_qs = nn.Linear(d_model, n_head * d_k) + self.w_ks = nn.Linear(d_model, n_head * d_k) + self.w_vs = nn.Linear(d_model, n_head * d_v) + + self.attention = ScaledDotProductAttention(temperature=np.power(d_model, 0.5), dropout=dropout) + + self.fc = nn.Linear(n_head * d_v, d_model) + self.dropout = nn.Dropout(dropout) + + if spectral_norm: + self.w_qs = nn.utils.spectral_norm(self.w_qs) + self.w_ks = nn.utils.spectral_norm(self.w_ks) + self.w_vs = nn.utils.spectral_norm(self.w_vs) + self.fc = nn.utils.spectral_norm(self.fc) + + def forward(self, x, mask=None): + d_k, d_v, n_head = self.d_k, self.d_v, self.n_head + sz_b, len_x, _ = x.size() + + residual = x + + q = self.w_qs(x).view(sz_b, len_x, n_head, d_k) + k = self.w_ks(x).view(sz_b, len_x, n_head, d_k) + v = self.w_vs(x).view(sz_b, len_x, n_head, d_v) + q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_k) # (n*b) x lq x dk + k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_k) # (n*b) x lk x dk + v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_v) # (n*b) x lv x dv + + if mask is not None: + slf_mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x .. 
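+            # q/k/v were flattened above from [b, len, n_head, d] to [(n_head*b), len, d],
+            # so the per-sample mask is tiled n_head times here to line up with the
+            # stacked heads before it is passed to ScaledDotProductAttention.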
+ else: + slf_mask = None + output, attn = self.attention(q, k, v, mask=slf_mask) + + output = output.view(n_head, sz_b, len_x, d_v) + output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_x, -1) # b x lq x (n*dv) + + output = self.fc(output) + + output = self.dropout(output) + residual + return output, attn + + +class ScaledDotProductAttention(nn.Module): + """Scaled Dot-Product Attention""" + + def __init__(self, temperature, dropout): + super().__init__() + self.temperature = temperature + self.softmax = nn.Softmax(dim=2) + self.dropout = nn.Dropout(dropout) + + def forward(self, q, k, v, mask=None): + attn = torch.bmm(q, k.transpose(1, 2)) + attn = attn / self.temperature + + if mask is not None: + attn = attn.masked_fill(mask, -np.inf) + + attn = self.softmax(attn) + p_attn = self.dropout(attn) + + output = torch.bmm(p_attn, v) + return output, attn + + +class MelStyleEncoder(nn.Module): + """MelStyleEncoder""" + + def __init__( + self, + n_mel_channels=80, + style_hidden=128, + style_vector_dim=256, + style_kernel_size=5, + style_head=2, + dropout=0.1, + ): + super(MelStyleEncoder, self).__init__() + self.in_dim = n_mel_channels + self.hidden_dim = style_hidden + self.out_dim = style_vector_dim + self.kernel_size = style_kernel_size + self.n_head = style_head + self.dropout = dropout + + self.spectral = nn.Sequential( + LinearNorm(self.in_dim, self.hidden_dim), + Mish(), + nn.Dropout(self.dropout), + LinearNorm(self.hidden_dim, self.hidden_dim), + Mish(), + nn.Dropout(self.dropout), + ) + + self.temporal = nn.Sequential( + Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout), + Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout), + ) + + self.slf_attn = MultiHeadAttention( + self.n_head, + self.hidden_dim, + self.hidden_dim // self.n_head, + self.hidden_dim // self.n_head, + self.dropout, + ) + + self.fc = LinearNorm(self.hidden_dim, self.out_dim) + + def temporal_avg_pool(self, x, mask=None): + if mask is None: + out = torch.mean(x, dim=1) + else: + len_ = (~mask).sum(dim=1).unsqueeze(1) + x = x.masked_fill(mask.unsqueeze(-1), 0) + dtype = x.dtype + x = x.float() + x = torch.div(x, len_.unsqueeze(1)) + out = x.sum(dim=1).to(dtype) + return out + + def forward(self, x, mask=None): + x = x.transpose(1, 2) + if mask is not None: + mask = (mask.int() == 0).squeeze(1) + max_len = x.shape[1] + slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) if mask is not None else None + + # spectral + x = self.spectral(x) + # temporal + x = x.transpose(1, 2) + x = self.temporal(x) + x = x.transpose(1, 2) + # self-attention + if mask is not None: + x = x.masked_fill(mask.unsqueeze(-1), 0) + x, _ = self.slf_attn(x, mask=slf_attn_mask) + # fc + x = self.fc(x) + # temoral average pooling + w = self.temporal_avg_pool(x, mask=mask) + return w.unsqueeze(-1) + + +class MelStyleEncoderVAE(nn.Module): + def __init__(self, spec_channels, z_latent_dim, emb_dim): + super().__init__() + self.ref_encoder = MelStyleEncoder(spec_channels, style_vector_dim=emb_dim) + self.fc1 = nn.Linear(emb_dim, z_latent_dim) + self.fc2 = nn.Linear(emb_dim, z_latent_dim) + self.fc3 = nn.Linear(z_latent_dim, emb_dim) + self.z_latent_dim = z_latent_dim + + def reparameterize(self, mu, logvar): + if self.training: + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + return eps.mul(std).add_(mu) + else: + return mu + + def forward(self, inputs, mask=None): + enc_out = self.ref_encoder(inputs.squeeze(-1), mask).squeeze(-1) + mu = self.fc1(enc_out) + logvar = 
self.fc2(enc_out) + posterior = D.Normal(mu, torch.exp(logvar)) + kl_divergence = D.kl_divergence(posterior, D.Normal(torch.zeros_like(mu), torch.ones_like(logvar))) + loss_kl = kl_divergence.mean() + + z = posterior.rsample() + style_embed = self.fc3(z) + + return style_embed.unsqueeze(-1), loss_kl + + def infer(self, inputs=None, random_sample=False, manual_latent=None): + if manual_latent is None: + if random_sample: + dev = next(self.parameters()).device + posterior = D.Normal( + torch.zeros(1, self.z_latent_dim, device=dev), + torch.ones(1, self.z_latent_dim, device=dev), + ) + z = posterior.rsample() + else: + enc_out = self.ref_encoder(inputs.transpose(1, 2)) + mu = self.fc1(enc_out) + z = mu + else: + z = manual_latent + style_embed = self.fc3(z) + return style_embed.unsqueeze(-1), z + + +class ActNorm(nn.Module): + def __init__(self, channels, ddi=False, **kwargs): + super().__init__() + self.channels = channels + self.initialized = not ddi + + self.logs = nn.Parameter(torch.zeros(1, channels, 1)) + self.bias = nn.Parameter(torch.zeros(1, channels, 1)) + + def forward(self, x, x_mask=None, g=None, reverse=False, **kwargs): + if x_mask is None: + x_mask = torch.ones(x.size(0), 1, x.size(2)).to(device=x.device, dtype=x.dtype) + x_len = torch.sum(x_mask, [1, 2]) + if not self.initialized: + self.initialize(x, x_mask) + self.initialized = True + + if reverse: + z = (x - self.bias) * torch.exp(-self.logs) * x_mask + logdet = None + return z + else: + z = (self.bias + torch.exp(self.logs) * x) * x_mask + logdet = torch.sum(self.logs) * x_len # [b] + return z, logdet + + def store_inverse(self): + pass + + def set_ddi(self, ddi): + self.initialized = not ddi + + def initialize(self, x, x_mask): + with torch.no_grad(): + denom = torch.sum(x_mask, [0, 2]) + m = torch.sum(x * x_mask, [0, 2]) / denom + m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom + v = m_sq - (m**2) + logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6)) + + bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype) + logs_init = (-logs).view(*self.logs.shape).to(dtype=self.logs.dtype) + + self.bias.data.copy_(bias_init) + self.logs.data.copy_(logs_init) + + +class InvConvNear(nn.Module): + def __init__(self, channels, n_split=4, no_jacobian=False, **kwargs): + super().__init__() + assert n_split % 2 == 0 + self.channels = channels + self.n_split = n_split + self.no_jacobian = no_jacobian + + w_init = torch.linalg.qr(torch.FloatTensor(self.n_split, self.n_split).normal_())[0] + if torch.det(w_init) < 0: + w_init[:, 0] = -1 * w_init[:, 0] + self.weight = nn.Parameter(w_init) + + def forward(self, x, x_mask=None, g=None, reverse=False, **kwargs): + b, c, t = x.size() + assert c % self.n_split == 0 + if x_mask is None: + x_mask = 1 + x_len = torch.ones((b,), dtype=x.dtype, device=x.device) * t + else: + x_len = torch.sum(x_mask, [1, 2]) + + x = x.view(b, 2, c // self.n_split, self.n_split // 2, t) + x = x.permute(0, 1, 3, 2, 4).contiguous().view(b, self.n_split, c // self.n_split, t) + + if reverse: + if hasattr(self, "weight_inv"): + weight = self.weight_inv + else: + weight = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype) + logdet = None + else: + weight = self.weight + if self.no_jacobian: + logdet = 0 + else: + logdet = torch.logdet(self.weight) * (c / self.n_split) * x_len # [b] + + weight = weight.view(self.n_split, self.n_split, 1, 1) + z = F.conv2d(x, weight) + + z = z.view(b, 2, self.n_split // 2, c // self.n_split, t) + z = z.permute(0, 1, 3, 2, 4).contiguous().view(b, c, t) * 
x_mask + if reverse: + return z + else: + return z, logdet + + def store_inverse(self): + self.weight_inv = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype) diff --git a/module/mrte_model.py b/module/mrte_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e889b7e91dc26ba1a1cc7de57b0e60a102464d17 --- /dev/null +++ b/module/mrte_model.py @@ -0,0 +1,173 @@ +# This is Multi-reference timbre encoder + +import torch +from torch import nn +from torch.nn.utils import remove_weight_norm, weight_norm +from module.attentions import MultiHeadAttention + + +class MRTE(nn.Module): + def __init__( + self, + content_enc_channels=192, + hidden_size=512, + out_channels=192, + kernel_size=5, + n_heads=4, + ge_layer=2, + ): + super(MRTE, self).__init__() + self.cross_attention = MultiHeadAttention(hidden_size, hidden_size, n_heads) + self.c_pre = nn.Conv1d(content_enc_channels, hidden_size, 1) + self.text_pre = nn.Conv1d(content_enc_channels, hidden_size, 1) + self.c_post = nn.Conv1d(hidden_size, out_channels, 1) + + def forward(self, ssl_enc, ssl_mask, text, text_mask, ge, test=None): + if ge == None: + ge = 0 + attn_mask = text_mask.unsqueeze(2) * ssl_mask.unsqueeze(-1) + + ssl_enc = self.c_pre(ssl_enc * ssl_mask) + text_enc = self.text_pre(text * text_mask) + if test != None: + if test == 0: + x = self.cross_attention(ssl_enc * ssl_mask, text_enc * text_mask, attn_mask) + ssl_enc + ge + elif test == 1: + x = ssl_enc + ge + elif test == 2: + x = self.cross_attention(ssl_enc * 0 * ssl_mask, text_enc * text_mask, attn_mask) + ge + else: + raise ValueError("test should be 0,1,2") + else: + x = self.cross_attention(ssl_enc * ssl_mask, text_enc * text_mask, attn_mask) + ssl_enc + ge + x = self.c_post(x * ssl_mask) + return x + + +class SpeakerEncoder(torch.nn.Module): + def __init__( + self, + mel_n_channels=80, + model_num_layers=2, + model_hidden_size=256, + model_embedding_size=256, + ): + super(SpeakerEncoder, self).__init__() + self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True) + self.linear = nn.Linear(model_hidden_size, model_embedding_size) + self.relu = nn.ReLU() + + def forward(self, mels): + self.lstm.flatten_parameters() + _, (hidden, _) = self.lstm(mels.transpose(-1, -2)) + embeds_raw = self.relu(self.linear(hidden[-1])) + return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True) + + +class MELEncoder(nn.Module): + def __init__( + self, + in_channels, + out_channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + + self.pre = nn.Conv1d(in_channels, hidden_channels, 1) + self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers) + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + + def forward(self, x): + # print(x.shape,x_lengths.shape) + x = self.pre(x) + x = self.enc(x) + x = self.proj(x) + return x + + +class WN(torch.nn.Module): + def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers): + super(WN, self).__init__() + assert kernel_size % 2 == 1 + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + + self.in_layers = torch.nn.ModuleList() + self.res_skip_layers = torch.nn.ModuleList() + + for i in range(n_layers): + dilation = dilation_rate**i + 
padding = int((kernel_size * dilation - dilation) / 2) + in_layer = nn.Conv1d( + hidden_channels, + 2 * hidden_channels, + kernel_size, + dilation=dilation, + padding=padding, + ) + in_layer = weight_norm(in_layer) + self.in_layers.append(in_layer) + + # last one is not necessary + if i < n_layers - 1: + res_skip_channels = 2 * hidden_channels + else: + res_skip_channels = hidden_channels + + res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) + res_skip_layer = weight_norm(res_skip_layer, name="weight") + self.res_skip_layers.append(res_skip_layer) + + def forward(self, x): + output = torch.zeros_like(x) + n_channels_tensor = torch.IntTensor([self.hidden_channels]) + + for i in range(self.n_layers): + x_in = self.in_layers[i](x) + + acts = fused_add_tanh_sigmoid_multiply(x_in, n_channels_tensor) + + res_skip_acts = self.res_skip_layers[i](acts) + if i < self.n_layers - 1: + res_acts = res_skip_acts[:, : self.hidden_channels, :] + x = x + res_acts + output = output + res_skip_acts[:, self.hidden_channels :, :] + else: + output = output + res_skip_acts + return output + + def remove_weight_norm(self): + for l in self.in_layers: + remove_weight_norm(l) + for l in self.res_skip_layers: + remove_weight_norm(l) + + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(input, n_channels): + n_channels_int = n_channels[0] + t_act = torch.tanh(input[:, :n_channels_int, :]) + s_act = torch.sigmoid(input[:, n_channels_int:, :]) + acts = t_act * s_act + return acts + + +if __name__ == "__main__": + content_enc = torch.randn(3, 192, 100) + content_mask = torch.ones(3, 1, 100) + ref_mel = torch.randn(3, 128, 30) + ref_mask = torch.ones(3, 1, 30) + model = MRTE() + out = model(content_enc, content_mask, ref_mel, ref_mask) + print(out.shape) diff --git a/module/quantize.py b/module/quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..0afed835ecd777879db50c8068e571abdd743ae8 --- /dev/null +++ b/module/quantize.py @@ -0,0 +1,114 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Residual vector quantizer implementation.""" + +from dataclasses import dataclass, field +import typing as tp + +import torch +from torch import nn + +from module.core_vq import ResidualVectorQuantization + + +@dataclass +class QuantizedResult: + quantized: torch.Tensor + codes: torch.Tensor + bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item. + penalty: tp.Optional[torch.Tensor] = None + metrics: dict = field(default_factory=dict) + + +class ResidualVectorQuantizer(nn.Module): + """Residual Vector Quantizer. + Args: + dimension (int): Dimension of the codebooks. + n_q (int): Number of residual vector quantizers used. + bins (int): Codebook size. + decay (float): Decay for exponential moving average over the codebooks. + kmeans_init (bool): Whether to use kmeans to initialize the codebooks. + kmeans_iters (int): Number of iterations used for kmeans initialization. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. 
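+
+    Example (illustrative sketch; the shapes and the ``layers=[0]`` argument mirror
+    how this class is constructed and called in the synthesizer code of this PR,
+    but are otherwise assumptions):
+
+        rvq = ResidualVectorQuantizer(dimension=768, n_q=1, bins=1024)
+        x = torch.randn(2, 768, 50)   # [B, dimension, T] feature sequence
+        quantized, codes, commit_loss, quantized_list = rvq(x, layers=[0])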
+ """ + + def __init__( + self, + dimension: int = 256, + n_q: int = 8, + bins: int = 1024, + decay: float = 0.99, + kmeans_init: bool = True, + kmeans_iters: int = 50, + threshold_ema_dead_code: int = 2, + ): + super().__init__() + self.n_q = n_q + self.dimension = dimension + self.bins = bins + self.decay = decay + self.kmeans_init = kmeans_init + self.kmeans_iters = kmeans_iters + self.threshold_ema_dead_code = threshold_ema_dead_code + self.vq = ResidualVectorQuantization( + dim=self.dimension, + codebook_size=self.bins, + num_quantizers=self.n_q, + decay=self.decay, + kmeans_init=self.kmeans_init, + kmeans_iters=self.kmeans_iters, + threshold_ema_dead_code=self.threshold_ema_dead_code, + ) + + def forward( + self, + x: torch.Tensor, + n_q: tp.Optional[int] = None, + layers: tp.Optional[list] = None, + ) -> QuantizedResult: + """Residual vector quantization on the given input tensor. + Args: + x (torch.Tensor): Input tensor. + n_q (int): Number of quantizer used to quantize. Default: All quantizers. + layers (list): Layer that need to return quantized. Defalt: None. + Returns: + QuantizedResult: + The quantized (or approximately quantized) representation with + the associated numbert quantizers and layer quantized required to return. + """ + n_q = n_q if n_q else self.n_q + if layers and max(layers) >= n_q: + raise ValueError( + f"Last layer index in layers: A {max(layers)}. Number of quantizers in RVQ: B {self.n_q}. A must less than B." + ) + quantized, codes, commit_loss, quantized_list = self.vq(x, n_q=n_q, layers=layers) + return quantized, codes, torch.mean(commit_loss), quantized_list + + def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None) -> torch.Tensor: + """Encode a given input tensor with the specified sample rate at the given bandwidth. + The RVQ encode method sets the appropriate number of quantizer to use + and returns indices for each quantizer. + Args: + x (torch.Tensor): Input tensor. + n_q (int): Number of quantizer used to quantize. Default: All quantizers. + st (int): Start to encode input from which layers. Default: 0. + """ + n_q = n_q if n_q else self.n_q + st = st or 0 + codes = self.vq.encode(x, n_q=n_q, st=st) + return codes + + def decode(self, codes: torch.Tensor, st: int = 0) -> torch.Tensor: + """Decode the given codes to the quantized representation. + Args: + codes (torch.Tensor): Input indices for each quantizer. + st (int): Start to decode input codes from which layers. Default: 0. 
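+ Returns:
+ torch.Tensor: The quantized representation reconstructed from the given codes.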
+ """ + quantized = self.vq.decode(codes, st=st) + return quantized diff --git a/module/transforms.py b/module/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..16b549851ad9bb7c84d2628d3dcac9430b409102 --- /dev/null +++ b/module/transforms.py @@ -0,0 +1,205 @@ +import torch +from torch.nn import functional as F + +import numpy as np + + +DEFAULT_MIN_BIN_WIDTH = 1e-3 +DEFAULT_MIN_BIN_HEIGHT = 1e-3 +DEFAULT_MIN_DERIVATIVE = 1e-3 + + +def piecewise_rational_quadratic_transform( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + tails=None, + tail_bound=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + if tails is None: + spline_fn = rational_quadratic_spline + spline_kwargs = {} + else: + spline_fn = unconstrained_rational_quadratic_spline + spline_kwargs = {"tails": tails, "tail_bound": tail_bound} + + outputs, logabsdet = spline_fn( + inputs=inputs, + unnormalized_widths=unnormalized_widths, + unnormalized_heights=unnormalized_heights, + unnormalized_derivatives=unnormalized_derivatives, + inverse=inverse, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative, + **spline_kwargs, + ) + return outputs, logabsdet + + +def searchsorted(bin_locations, inputs, eps=1e-6): + bin_locations[..., -1] += eps + return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1 + + +def unconstrained_rational_quadratic_spline( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + tails="linear", + tail_bound=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) + outside_interval_mask = ~inside_interval_mask + + outputs = torch.zeros_like(inputs) + logabsdet = torch.zeros_like(inputs) + + if tails == "linear": + unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) + constant = np.log(np.exp(1 - min_derivative) - 1) + unnormalized_derivatives[..., 0] = constant + unnormalized_derivatives[..., -1] = constant + + outputs[outside_interval_mask] = inputs[outside_interval_mask] + logabsdet[outside_interval_mask] = 0 + else: + raise RuntimeError("{} tails are not implemented.".format(tails)) + + ( + outputs[inside_interval_mask], + logabsdet[inside_interval_mask], + ) = rational_quadratic_spline( + inputs=inputs[inside_interval_mask], + unnormalized_widths=unnormalized_widths[inside_interval_mask, :], + unnormalized_heights=unnormalized_heights[inside_interval_mask, :], + unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], + inverse=inverse, + left=-tail_bound, + right=tail_bound, + bottom=-tail_bound, + top=tail_bound, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative, + ) + + return outputs, logabsdet + + +def rational_quadratic_spline( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + left=0.0, + right=1.0, + bottom=0.0, + top=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + if torch.min(inputs) < left or torch.max(inputs) > right: + raise ValueError("Input to a transform is not within its domain") + + num_bins = unnormalized_widths.shape[-1] + + if min_bin_width * num_bins > 1.0: + 
raise ValueError("Minimal bin width too large for the number of bins") + if min_bin_height * num_bins > 1.0: + raise ValueError("Minimal bin height too large for the number of bins") + + widths = F.softmax(unnormalized_widths, dim=-1) + widths = min_bin_width + (1 - min_bin_width * num_bins) * widths + cumwidths = torch.cumsum(widths, dim=-1) + cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0) + cumwidths = (right - left) * cumwidths + left + cumwidths[..., 0] = left + cumwidths[..., -1] = right + widths = cumwidths[..., 1:] - cumwidths[..., :-1] + + derivatives = min_derivative + F.softplus(unnormalized_derivatives) + + heights = F.softmax(unnormalized_heights, dim=-1) + heights = min_bin_height + (1 - min_bin_height * num_bins) * heights + cumheights = torch.cumsum(heights, dim=-1) + cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0) + cumheights = (top - bottom) * cumheights + bottom + cumheights[..., 0] = bottom + cumheights[..., -1] = top + heights = cumheights[..., 1:] - cumheights[..., :-1] + + if inverse: + bin_idx = searchsorted(cumheights, inputs)[..., None] + else: + bin_idx = searchsorted(cumwidths, inputs)[..., None] + + input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] + input_bin_widths = widths.gather(-1, bin_idx)[..., 0] + + input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] + delta = heights / widths + input_delta = delta.gather(-1, bin_idx)[..., 0] + + input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] + input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] + + input_heights = heights.gather(-1, bin_idx)[..., 0] + + if inverse: + a = (inputs - input_cumheights) * ( + input_derivatives + input_derivatives_plus_one - 2 * input_delta + ) + input_heights * (input_delta - input_derivatives) + b = input_heights * input_derivatives - (inputs - input_cumheights) * ( + input_derivatives + input_derivatives_plus_one - 2 * input_delta + ) + c = -input_delta * (inputs - input_cumheights) + + discriminant = b.pow(2) - 4 * a * c + assert (discriminant >= 0).all() + + root = (2 * c) / (-b - torch.sqrt(discriminant)) + outputs = root * input_bin_widths + input_cumwidths + + theta_one_minus_theta = root * (1 - root) + denominator = input_delta + ( + (input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta + ) + derivative_numerator = input_delta.pow(2) * ( + input_derivatives_plus_one * root.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - root).pow(2) + ) + logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) + + return outputs, -logabsdet + else: + theta = (inputs - input_cumwidths) / input_bin_widths + theta_one_minus_theta = theta * (1 - theta) + + numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta) + denominator = input_delta + ( + (input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta + ) + outputs = input_cumheights + numerator / denominator + + derivative_numerator = input_delta.pow(2) * ( + input_derivatives_plus_one * theta.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - theta).pow(2) + ) + logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) + + return outputs, logabsdet diff --git a/onnx_export.py b/onnx_export.py new file mode 100644 index 0000000000000000000000000000000000000000..fd680135fb7d71afb4680b05a62b9874c39ad21c --- /dev/null +++ b/onnx_export.py @@ -0,0 +1,398 @@ 
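+# Editor's note -- overview of what this export script writes (names follow the f-strings used below;
+# the call at the end of this comment is an illustrative sketch only, not part of the original file):
+#   onnx/{project}/{project}_t2s_encoder.onnx : phoneme/BERT encoder, outputs (x, prompts)
+#   onnx/{project}/{project}_t2s_fsdec.onnx   : first-stage decoder that builds the initial KV cache
+#   onnx/{project}/{project}_t2s_sdec.onnx    : step decoder, looped autoregressively until EOS / early stop
+#   onnx/{project}/{project}_vits.onnx        : VITS synthesizer mapping semantic tokens + reference audio to a waveform
+#   onnx/{project}.json                       : model description saved alongside the graphs
+# Sketch: export("SoVITS_weights/your_model.pth", "GPT_weights/your_model.ckpt", "your_project", vits_model="v2")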
+import torch +import torchaudio +from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule +from feature_extractor import cnhubert +from module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2 +from torch import nn + +cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base" +cnhubert.cnhubert_base_path = cnhubert_base_path +ssl_model = cnhubert.get_model() +import json +import os + +import soundfile +from text import cleaned_text_to_sequence + + +def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): + hann_window = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + spec = torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window, + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=False, + ) + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + return spec + + +class DictToAttrRecursive(dict): + def __init__(self, input_dict): + super().__init__(input_dict) + for key, value in input_dict.items(): + if isinstance(value, dict): + value = DictToAttrRecursive(value) + self[key] = value + setattr(self, key, value) + + def __getattr__(self, item): + try: + return self[item] + except KeyError: + raise AttributeError(f"Attribute {item} not found") + + def __setattr__(self, key, value): + if isinstance(value, dict): + value = DictToAttrRecursive(value) + super(DictToAttrRecursive, self).__setitem__(key, value) + super().__setattr__(key, value) + + def __delattr__(self, item): + try: + del self[item] + except KeyError: + raise AttributeError(f"Attribute {item} not found") + + +class T2SEncoder(nn.Module): + def __init__(self, t2s, vits): + super().__init__() + self.encoder = t2s.onnx_encoder + self.vits = vits + + def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content): + codes = self.vits.extract_latent(ssl_content) + prompt_semantic = codes[0, 0] + bert = torch.cat([ref_bert.transpose(0, 1), text_bert.transpose(0, 1)], 1) + all_phoneme_ids = torch.cat([ref_seq, text_seq], 1) + bert = bert.unsqueeze(0) + prompt = prompt_semantic.unsqueeze(0) + return self.encoder(all_phoneme_ids, bert), prompt + + +class T2SModel(nn.Module): + def __init__(self, t2s_path, vits_model): + super().__init__() + dict_s1 = torch.load(t2s_path, map_location="cpu") + self.config = dict_s1["config"] + self.t2s_model = Text2SemanticLightningModule(self.config, "ojbk", is_train=False) + self.t2s_model.load_state_dict(dict_s1["weight"]) + self.t2s_model.eval() + self.vits_model = vits_model.vq_model + self.hz = 50 + self.max_sec = self.config["data"]["max_sec"] + self.t2s_model.model.top_k = torch.LongTensor([self.config["inference"]["top_k"]]) + self.t2s_model.model.early_stop_num = torch.LongTensor([self.hz * self.max_sec]) + self.t2s_model = self.t2s_model.model + self.t2s_model.init_onnx() + self.onnx_encoder = T2SEncoder(self.t2s_model, self.vits_model) + self.first_stage_decoder = self.t2s_model.first_stage_decoder + self.stage_decoder = self.t2s_model.stage_decoder + # self.t2s_model = torch.jit.script(self.t2s_model) + + def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content): + early_stop_num = self.t2s_model.early_stop_num + + # [1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N] + x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content) + + prefix_len = 
prompts.shape[1] + + # [1,N,512] [1,N] + y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts) + + stop = False + for idx in range(1, 1500): + # [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N] + enco = self.stage_decoder(y, k, v, y_emb, x_example) + y, k, v, y_emb, logits, samples = enco + if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num: + stop = True + if torch.argmax(logits, dim=-1)[0] == self.t2s_model.EOS or samples[0, 0] == self.t2s_model.EOS: + stop = True + if stop: + break + y[0, -1] = 0 + + return y[:, -idx:].unsqueeze(0) + + def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, dynamo=False): + # self.onnx_encoder = torch.jit.script(self.onnx_encoder) + if dynamo: + export_options = torch.onnx.ExportOptions(dynamic_shapes=True) + onnx_encoder_export_output = torch.onnx.dynamo_export( + self.onnx_encoder, (ref_seq, text_seq, ref_bert, text_bert, ssl_content), export_options=export_options + ) + onnx_encoder_export_output.save(f"onnx/{project_name}/{project_name}_t2s_encoder.onnx") + return + + torch.onnx.export( + self.onnx_encoder, + (ref_seq, text_seq, ref_bert, text_bert, ssl_content), + f"onnx/{project_name}/{project_name}_t2s_encoder.onnx", + input_names=["ref_seq", "text_seq", "ref_bert", "text_bert", "ssl_content"], + output_names=["x", "prompts"], + dynamic_axes={ + "ref_seq": {1: "ref_length"}, + "text_seq": {1: "text_length"}, + "ref_bert": {0: "ref_length"}, + "text_bert": {0: "text_length"}, + "ssl_content": {2: "ssl_length"}, + }, + opset_version=16, + ) + x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content) + + torch.onnx.export( + self.first_stage_decoder, + (x, prompts), + f"onnx/{project_name}/{project_name}_t2s_fsdec.onnx", + input_names=["x", "prompts"], + output_names=["y", "k", "v", "y_emb", "x_example"], + dynamic_axes={ + "x": {1: "x_length"}, + "prompts": {1: "prompts_length"}, + }, + verbose=False, + opset_version=16, + ) + y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts) + + torch.onnx.export( + self.stage_decoder, + (y, k, v, y_emb, x_example), + f"onnx/{project_name}/{project_name}_t2s_sdec.onnx", + input_names=["iy", "ik", "iv", "iy_emb", "ix_example"], + output_names=["y", "k", "v", "y_emb", "logits", "samples"], + dynamic_axes={ + "iy": {1: "iy_length"}, + "ik": {1: "ik_length"}, + "iv": {1: "iv_length"}, + "iy_emb": {1: "iy_emb_length"}, + "ix_example": {1: "ix_example_length"}, + }, + verbose=False, + opset_version=16, + ) + + +class VitsModel(nn.Module): + def __init__(self, vits_path): + super().__init__() + dict_s2 = torch.load(vits_path, map_location="cpu") + self.hps = dict_s2["config"] + if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322: + self.hps["model"]["version"] = "v1" + else: + self.hps["model"]["version"] = "v2" + + self.hps = DictToAttrRecursive(self.hps) + self.hps.model.semantic_frame_rate = "25hz" + self.vq_model = SynthesizerTrn( + self.hps.data.filter_length // 2 + 1, + self.hps.train.segment_size // self.hps.data.hop_length, + n_speakers=self.hps.data.n_speakers, + **self.hps.model, + ) + self.vq_model.eval() + self.vq_model.load_state_dict(dict_s2["weight"], strict=False) + + def forward(self, text_seq, pred_semantic, ref_audio): + refer = spectrogram_torch( + ref_audio, + self.hps.data.filter_length, + self.hps.data.sampling_rate, + self.hps.data.hop_length, + self.hps.data.win_length, + center=False, + ) + return self.vq_model(pred_semantic, text_seq, refer)[0, 
0] + + +class GptSoVits(nn.Module): + def __init__(self, vits, t2s): + super().__init__() + self.vits = vits + self.t2s = t2s + + def forward(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, debug=False): + pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content) + audio = self.vits(text_seq, pred_semantic, ref_audio) + if debug: + import onnxruntime + + sess = onnxruntime.InferenceSession("onnx/koharu/koharu_vits.onnx", providers=["CPU"]) + audio1 = sess.run( + None, + { + "text_seq": text_seq.detach().cpu().numpy(), + "pred_semantic": pred_semantic.detach().cpu().numpy(), + "ref_audio": ref_audio.detach().cpu().numpy(), + }, + ) + return audio, audio1 + return audio + + def export(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, project_name): + self.t2s.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name) + pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content) + torch.onnx.export( + self.vits, + (text_seq, pred_semantic, ref_audio), + f"onnx/{project_name}/{project_name}_vits.onnx", + input_names=["text_seq", "pred_semantic", "ref_audio"], + output_names=["audio"], + dynamic_axes={ + "text_seq": {1: "text_length"}, + "pred_semantic": {2: "pred_length"}, + "ref_audio": {1: "audio_length"}, + }, + opset_version=17, + verbose=False, + ) + + +class SSLModel(nn.Module): + def __init__(self): + super().__init__() + self.ssl = ssl_model + + def forward(self, ref_audio_16k): + return self.ssl.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2) + + +def export(vits_path, gpt_path, project_name, vits_model="v2"): + vits = VitsModel(vits_path) + gpt = T2SModel(gpt_path, vits) + gpt_sovits = GptSoVits(vits, gpt) + ssl = SSLModel() + ref_seq = torch.LongTensor( + [ + cleaned_text_to_sequence( + [ + "n", + "i2", + "h", + "ao3", + ",", + "w", + "o3", + "sh", + "i4", + "b", + "ai2", + "y", + "e4", + ], + version=vits_model, + ) + ] + ) + text_seq = torch.LongTensor( + [ + cleaned_text_to_sequence( + [ + "w", + "o3", + "sh", + "i4", + "b", + "ai2", + "y", + "e4", + "w", + "o3", + "sh", + "i4", + "b", + "ai2", + "y", + "e4", + "w", + "o3", + "sh", + "i4", + "b", + "ai2", + "y", + "e4", + ], + version=vits_model, + ) + ] + ) + ref_bert = torch.randn((ref_seq.shape[1], 1024)).float() + text_bert = torch.randn((text_seq.shape[1], 1024)).float() + ref_audio = torch.randn((1, 48000 * 5)).float() + # ref_audio = torch.tensor([load_audio("rec.wav", 48000)]).float() + ref_audio_16k = torchaudio.functional.resample(ref_audio, 48000, 16000).float() + ref_audio_sr = torchaudio.functional.resample(ref_audio, 48000, vits.hps.data.sampling_rate).float() + + try: + os.mkdir(f"onnx/{project_name}") + except: + pass + + ssl_content = ssl(ref_audio_16k).float() + + # debug = False + debug = True + + # gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name) + + if debug: + a, b = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, debug=debug) + soundfile.write("out1.wav", a.cpu().detach().numpy(), vits.hps.data.sampling_rate) + soundfile.write("out2.wav", b[0], vits.hps.data.sampling_rate) + else: + a = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content).detach().cpu().numpy() + soundfile.write("out.wav", a, vits.hps.data.sampling_rate) + + if vits_model == "v1": + symbols = symbols_v1 + else: + symbols = symbols_v2 + + MoeVSConf = { + "Folder": f"{project_name}", + "Name": f"{project_name}", + "Type": "GPT-SoVits", 
+ "Rate": vits.hps.data.sampling_rate, + "NumLayers": gpt.t2s_model.num_layers, + "EmbeddingDim": gpt.t2s_model.embedding_dim, + "Dict": "BasicDict", + "BertPath": "chinese-roberta-wwm-ext-large", + # "Symbol": symbols, + "AddBlank": False, + } + + MoeVSConfJson = json.dumps(MoeVSConf) + with open(f"onnx/{project_name}.json", "w") as MoeVsConfFile: + json.dump(MoeVSConf, MoeVsConfFile, indent=4) + + +if __name__ == "__main__": + try: + os.mkdir("onnx") + except: + pass + + gpt_path = "GPT_weights/nahida-e25.ckpt" + vits_path = "SoVITS_weights/nahida_e30_s3930.pth" + exp_path = "nahida" + export(vits_path, gpt_path, exp_path) + + # soundfile.write("out.wav", a, vits.hps.data.sampling_rate) diff --git a/packages.txt b/packages.txt new file mode 100644 index 0000000000000000000000000000000000000000..a9f1eea092d5e971b5475b82ee835cec7f196bad --- /dev/null +++ b/packages.txt @@ -0,0 +1 @@ +ffmpeg \ No newline at end of file diff --git a/prepare_datasets/1-get-text.py b/prepare_datasets/1-get-text.py new file mode 100644 index 0000000000000000000000000000000000000000..8d83e79ae6a8459339984dfe6c61d698f7e96746 --- /dev/null +++ b/prepare_datasets/1-get-text.py @@ -0,0 +1,143 @@ +# -*- coding: utf-8 -*- + +import os + +inp_text = os.environ.get("inp_text") +inp_wav_dir = os.environ.get("inp_wav_dir") +exp_name = os.environ.get("exp_name") +i_part = os.environ.get("i_part") +all_parts = os.environ.get("all_parts") +if "_CUDA_VISIBLE_DEVICES" in os.environ: + os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"] +opt_dir = os.environ.get("opt_dir") +bert_pretrained_dir = os.environ.get("bert_pretrained_dir") +import torch + +is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available() +version = os.environ.get("version", None) +import traceback +import os.path +from text.cleaner import clean_text +from transformers import AutoModelForMaskedLM, AutoTokenizer +from tools.my_utils import clean_path + +# inp_text=sys.argv[1] +# inp_wav_dir=sys.argv[2] +# exp_name=sys.argv[3] +# i_part=sys.argv[4] +# all_parts=sys.argv[5] +# os.environ["CUDA_VISIBLE_DEVICES"]=sys.argv[6]#i_gpu +# opt_dir="/data/docker/liujing04/gpt-vits/fine_tune_dataset/%s"%exp_name +# bert_pretrained_dir="/data/docker/liujing04/bert-vits2/Bert-VITS2-master20231106/bert/chinese-roberta-wwm-ext-large" + +from time import time as ttime +import shutil + + +def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path + dir = os.path.dirname(path) + name = os.path.basename(path) + # tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part) + tmp_path = "%s%s.pth" % (ttime(), i_part) + torch.save(fea, tmp_path) + shutil.move(tmp_path, "%s/%s" % (dir, name)) + + +txt_path = "%s/2-name2text-%s.txt" % (opt_dir, i_part) +if os.path.exists(txt_path) == False: + bert_dir = "%s/3-bert" % (opt_dir) + os.makedirs(opt_dir, exist_ok=True) + os.makedirs(bert_dir, exist_ok=True) + if torch.cuda.is_available(): + device = "cuda:0" + # elif torch.backends.mps.is_available(): + # device = "mps" + else: + device = "cpu" + if os.path.exists(bert_pretrained_dir): + ... 
+ else: + raise FileNotFoundError(bert_pretrained_dir) + tokenizer = AutoTokenizer.from_pretrained(bert_pretrained_dir) + bert_model = AutoModelForMaskedLM.from_pretrained(bert_pretrained_dir) + if is_half == True: + bert_model = bert_model.half().to(device) + else: + bert_model = bert_model.to(device) + + def get_bert_feature(text, word2ph): + with torch.no_grad(): + inputs = tokenizer(text, return_tensors="pt") + for i in inputs: + inputs[i] = inputs[i].to(device) + res = bert_model(**inputs, output_hidden_states=True) + res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1] + + assert len(word2ph) == len(text) + phone_level_feature = [] + for i in range(len(word2ph)): + repeat_feature = res[i].repeat(word2ph[i], 1) + phone_level_feature.append(repeat_feature) + + phone_level_feature = torch.cat(phone_level_feature, dim=0) + + return phone_level_feature.T + + def process(data, res): + for name, text, lan in data: + try: + name = clean_path(name) + name = os.path.basename(name) + print(name) + phones, word2ph, norm_text = clean_text(text.replace("%", "-").replace("¥", ","), lan, version) + path_bert = "%s/%s.pt" % (bert_dir, name) + if os.path.exists(path_bert) == False and lan == "zh": + bert_feature = get_bert_feature(norm_text, word2ph) + assert bert_feature.shape[-1] == len(phones) + # torch.save(bert_feature, path_bert) + my_save(bert_feature, path_bert) + phones = " ".join(phones) + # res.append([name,phones]) + res.append([name, phones, word2ph, norm_text]) + except: + print(name, text, traceback.format_exc()) + + todo = [] + res = [] + with open(inp_text, "r", encoding="utf8") as f: + lines = f.read().strip("\n").split("\n") + + language_v1_to_language_v2 = { + "ZH": "zh", + "zh": "zh", + "JP": "ja", + "jp": "ja", + "JA": "ja", + "ja": "ja", + "EN": "en", + "en": "en", + "En": "en", + "KO": "ko", + "Ko": "ko", + "ko": "ko", + "yue": "yue", + "YUE": "yue", + "Yue": "yue", + } + for line in lines[int(i_part) :: int(all_parts)]: + try: + wav_name, spk_name, language, text = line.split("|") + # todo.append([name,text,"zh"]) + if language in language_v1_to_language_v2.keys(): + todo.append([wav_name, text, language_v1_to_language_v2.get(language, language)]) + else: + print(f"\033[33m[Waring] The {language = } of {wav_name} is not supported for training.\033[0m") + except: + print(line, traceback.format_exc()) + + process(todo, res) + opt = [] + for name, phones, word2ph, norm_text in res: + opt.append("%s\t%s\t%s\t%s" % (name, phones, word2ph, norm_text)) + with open(txt_path, "w", encoding="utf8") as f: + f.write("\n".join(opt) + "\n") diff --git a/prepare_datasets/2-get-hubert-wav32k.py b/prepare_datasets/2-get-hubert-wav32k.py new file mode 100644 index 0000000000000000000000000000000000000000..3a84c014aadf435dcce5c6339d060c6742f8633b --- /dev/null +++ b/prepare_datasets/2-get-hubert-wav32k.py @@ -0,0 +1,134 @@ +# -*- coding: utf-8 -*- + +import sys +import os + +inp_text = os.environ.get("inp_text") +inp_wav_dir = os.environ.get("inp_wav_dir") +exp_name = os.environ.get("exp_name") +i_part = os.environ.get("i_part") +all_parts = os.environ.get("all_parts") +if "_CUDA_VISIBLE_DEVICES" in os.environ: + os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"] +from feature_extractor import cnhubert + +opt_dir = os.environ.get("opt_dir") +cnhubert.cnhubert_base_path = os.environ.get("cnhubert_base_dir") +import torch + +is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available() + +import traceback +import numpy as np +from scipy.io import 
wavfile +import librosa + +now_dir = os.getcwd() +sys.path.append(now_dir) +from tools.my_utils import load_audio, clean_path + +# from config import cnhubert_base_path +# cnhubert.cnhubert_base_path=cnhubert_base_path +# inp_text=sys.argv[1] +# inp_wav_dir=sys.argv[2] +# exp_name=sys.argv[3] +# i_part=sys.argv[4] +# all_parts=sys.argv[5] +# os.environ["CUDA_VISIBLE_DEVICES"]=sys.argv[6] +# cnhubert.cnhubert_base_path=sys.argv[7] +# opt_dir="/data/docker/liujing04/gpt-vits/fine_tune_dataset/%s"%exp_name + +from time import time as ttime +import shutil + + +def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path + dir = os.path.dirname(path) + name = os.path.basename(path) + # tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part) + tmp_path = "%s%s.pth" % (ttime(), i_part) + torch.save(fea, tmp_path) + shutil.move(tmp_path, "%s/%s" % (dir, name)) + + +hubert_dir = "%s/4-cnhubert" % (opt_dir) +wav32dir = "%s/5-wav32k" % (opt_dir) +os.makedirs(opt_dir, exist_ok=True) +os.makedirs(hubert_dir, exist_ok=True) +os.makedirs(wav32dir, exist_ok=True) + +maxx = 0.95 +alpha = 0.5 +if torch.cuda.is_available(): + device = "cuda:0" +# elif torch.backends.mps.is_available(): +# device = "mps" +else: + device = "cpu" +model = cnhubert.get_model() +# is_half=False +if is_half == True: + model = model.half().to(device) +else: + model = model.to(device) + +nan_fails = [] + + +def name2go(wav_name, wav_path): + hubert_path = "%s/%s.pt" % (hubert_dir, wav_name) + if os.path.exists(hubert_path): + return + tmp_audio = load_audio(wav_path, 32000) + tmp_max = np.abs(tmp_audio).max() + if tmp_max > 2.2: + print("%s-filtered,%s" % (wav_name, tmp_max)) + return + tmp_audio32 = (tmp_audio / tmp_max * (maxx * alpha * 32768)) + ((1 - alpha) * 32768) * tmp_audio + tmp_audio32b = (tmp_audio / tmp_max * (maxx * alpha * 1145.14)) + ((1 - alpha) * 1145.14) * tmp_audio + tmp_audio = librosa.resample(tmp_audio32b, orig_sr=32000, target_sr=16000) # 不是重采样问题 + tensor_wav16 = torch.from_numpy(tmp_audio) + if is_half == True: + tensor_wav16 = tensor_wav16.half().to(device) + else: + tensor_wav16 = tensor_wav16.to(device) + ssl = model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"].transpose(1, 2).cpu() # torch.Size([1, 768, 215]) + if np.isnan(ssl.detach().numpy()).sum() != 0: + nan_fails.append((wav_name, wav_path)) + print("nan filtered:%s" % wav_name) + return + wavfile.write( + "%s/%s" % (wav32dir, wav_name), + 32000, + tmp_audio32.astype("int16"), + ) + my_save(ssl, hubert_path) + + +with open(inp_text, "r", encoding="utf8") as f: + lines = f.read().strip("\n").split("\n") + +for line in lines[int(i_part) :: int(all_parts)]: + try: + # wav_name,text=line.split("\t") + wav_name, spk_name, language, text = line.split("|") + wav_name = clean_path(wav_name) + if inp_wav_dir != "" and inp_wav_dir != None: + wav_name = os.path.basename(wav_name) + wav_path = "%s/%s" % (inp_wav_dir, wav_name) + + else: + wav_path = wav_name + wav_name = os.path.basename(wav_name) + name2go(wav_name, wav_path) + except: + print(line, traceback.format_exc()) + +if len(nan_fails) > 0 and is_half == True: + is_half = False + model = model.float() + for wav in nan_fails: + try: + name2go(wav[0], wav[1]) + except: + print(wav_name, traceback.format_exc()) diff --git a/prepare_datasets/2-get-sv.py b/prepare_datasets/2-get-sv.py new file mode 100644 index 0000000000000000000000000000000000000000..80b0ad69455024a7e6e318f1cac06218b4aa5f25 --- /dev/null +++ b/prepare_datasets/2-get-sv.py @@ -0,0 +1,115 @@ +# -*- coding: utf-8 -*- + +import 
sys +import os + +inp_text = os.environ.get("inp_text") +inp_wav_dir = os.environ.get("inp_wav_dir") +exp_name = os.environ.get("exp_name") +i_part = os.environ.get("i_part") +all_parts = os.environ.get("all_parts") +if "_CUDA_VISIBLE_DEVICES" in os.environ: + os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"] + +opt_dir = os.environ.get("opt_dir") +sv_path = os.environ.get("sv_path") +import torch + +is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available() + +import traceback +import torchaudio + +now_dir = os.getcwd() +sys.path.append(now_dir) +sys.path.append(f"{now_dir}/GPT_SoVITS/eres2net") +from tools.my_utils import clean_path +from time import time as ttime +import shutil +from ERes2NetV2 import ERes2NetV2 +import kaldi as Kaldi + + +def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path + dir = os.path.dirname(path) + name = os.path.basename(path) + # tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part) + tmp_path = "%s%s.pth" % (ttime(), i_part) + torch.save(fea, tmp_path) + shutil.move(tmp_path, "%s/%s" % (dir, name)) + + +sv_cn_dir = "%s/7-sv_cn" % (opt_dir) +wav32dir = "%s/5-wav32k" % (opt_dir) +os.makedirs(opt_dir, exist_ok=True) +os.makedirs(sv_cn_dir, exist_ok=True) +os.makedirs(wav32dir, exist_ok=True) + +maxx = 0.95 +alpha = 0.5 +if torch.cuda.is_available(): + device = "cuda:0" +# elif torch.backends.mps.is_available(): +# device = "mps" +else: + device = "cpu" + + +class SV: + def __init__(self, device, is_half): + pretrained_state = torch.load(sv_path, map_location="cpu") + embedding_model = ERes2NetV2(baseWidth=24, scale=4, expansion=4) + embedding_model.load_state_dict(pretrained_state) + embedding_model.eval() + self.embedding_model = embedding_model + self.res = torchaudio.transforms.Resample(32000, 16000).to(device) + if is_half == False: + self.embedding_model = self.embedding_model.to(device) + else: + self.embedding_model = self.embedding_model.half().to(device) + self.is_half = is_half + + def compute_embedding3(self, wav): # (1,x)#-1~1 + with torch.no_grad(): + wav = self.res(wav) + if self.is_half == True: + wav = wav.half() + feat = torch.stack( + [Kaldi.fbank(wav0.unsqueeze(0), num_mel_bins=80, sample_frequency=16000, dither=0) for wav0 in wav] + ) + sv_emb = self.embedding_model.forward3(feat) + return sv_emb + + +sv = SV(device, is_half) + + +def name2go(wav_name, wav_path): + sv_cn_path = "%s/%s.pt" % (sv_cn_dir, wav_name) + if os.path.exists(sv_cn_path): + return + wav_path = "%s/%s" % (wav32dir, wav_name) + wav32k, sr0 = torchaudio.load(wav_path) + assert sr0 == 32000 + wav32k = wav32k.to(device) + emb = sv.compute_embedding3(wav32k).cpu() # torch.Size([1, 20480]) + my_save(emb, sv_cn_path) + + +with open(inp_text, "r", encoding="utf8") as f: + lines = f.read().strip("\n").split("\n") + +for line in lines[int(i_part) :: int(all_parts)]: + try: + wav_name, spk_name, language, text = line.split("|") + wav_name = clean_path(wav_name) + if inp_wav_dir != "" and inp_wav_dir != None: + wav_name = os.path.basename(wav_name) + wav_path = "%s/%s" % (inp_wav_dir, wav_name) + + else: + wav_path = wav_name + wav_name = os.path.basename(wav_name) + name2go(wav_name, wav_path) + except: + print(line, traceback.format_exc()) diff --git a/prepare_datasets/3-get-semantic.py b/prepare_datasets/3-get-semantic.py new file mode 100644 index 0000000000000000000000000000000000000000..ddb0607cc75fa03a8458f7ab0641e70dc79cf4f4 --- /dev/null +++ b/prepare_datasets/3-get-semantic.py @@ -0,0 +1,118 @@ +import os + 
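+# Editor's note: stage 3 of dataset preparation. It loads the pretrained SoVITS s2G generator
+# (its version is inferred from the checkpoint file size below), pushes the cnhubert features
+# cached in <opt_dir>/4-cnhubert by stage 2 through vq_model.extract_latent(), and writes one
+# "<wav_name>\t<space-separated semantic token ids>" row per utterance to
+# <opt_dir>/6-name2semantic-<i_part>.tsv, processing only this worker's shard of inp_text.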
+inp_text = os.environ.get("inp_text") +exp_name = os.environ.get("exp_name") +i_part = os.environ.get("i_part") +all_parts = os.environ.get("all_parts") +if "_CUDA_VISIBLE_DEVICES" in os.environ: + os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"] +opt_dir = os.environ.get("opt_dir") +pretrained_s2G = os.environ.get("pretrained_s2G") +s2config_path = os.environ.get("s2config_path") + +if os.path.exists(pretrained_s2G): + ... +else: + raise FileNotFoundError(pretrained_s2G) +# version=os.environ.get("version","v2") +size = os.path.getsize(pretrained_s2G) +if size < 82978 * 1024: + version = "v1" +elif size < 100 * 1024 * 1024: + version = "v2" +elif size < 103520 * 1024: + version = "v1" +elif size < 700 * 1024 * 1024: + version = "v2" +else: + version = "v3" +import torch + +is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available() +import traceback +import sys + +now_dir = os.getcwd() +sys.path.append(now_dir) +import logging +import utils + +if version != "v3": + from module.models import SynthesizerTrn +else: + from module.models import SynthesizerTrnV3 as SynthesizerTrn +from tools.my_utils import clean_path + +logging.getLogger("numba").setLevel(logging.WARNING) +# from config import pretrained_s2G + +# inp_text=sys.argv[1] +# exp_name=sys.argv[2] +# i_part=sys.argv[3] +# all_parts=sys.argv[4] +# os.environ["CUDA_VISIBLE_DEVICES"]=sys.argv[5] +# opt_dir="/data/docker/liujing04/gpt-vits/fine_tune_dataset/%s"%exp_name + + +hubert_dir = "%s/4-cnhubert" % (opt_dir) +semantic_path = "%s/6-name2semantic-%s.tsv" % (opt_dir, i_part) +if os.path.exists(semantic_path) == False: + os.makedirs(opt_dir, exist_ok=True) + + if torch.cuda.is_available(): + device = "cuda" + # elif torch.backends.mps.is_available(): + # device = "mps" + else: + device = "cpu" + hps = utils.get_hparams_from_file(s2config_path) + vq_model = SynthesizerTrn( + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + n_speakers=hps.data.n_speakers, + version=version, + **hps.model, + ) + if is_half == True: + vq_model = vq_model.half().to(device) + else: + vq_model = vq_model.to(device) + vq_model.eval() + # utils.load_checkpoint(utils.latest_checkpoint_path(hps.s2_ckpt_dir, "G_*.pth"), vq_model, None, True) + # utils.load_checkpoint(pretrained_s2G, vq_model, None, True) + print( + vq_model.load_state_dict( + torch.load(pretrained_s2G, map_location="cpu", weights_only=False)["weight"], strict=False + ) + ) + + def name2go(wav_name, lines): + hubert_path = "%s/%s.pt" % (hubert_dir, wav_name) + if os.path.exists(hubert_path) == False: + return + ssl_content = torch.load(hubert_path, map_location="cpu") + if is_half == True: + ssl_content = ssl_content.half().to(device) + else: + ssl_content = ssl_content.to(device) + codes = vq_model.extract_latent(ssl_content) + semantic = " ".join([str(i) for i in codes[0, 0, :].tolist()]) + lines.append("%s\t%s" % (wav_name, semantic)) + + with open(inp_text, "r", encoding="utf8") as f: + lines = f.read().strip("\n").split("\n") + + lines1 = [] + for line in lines[int(i_part) :: int(all_parts)]: + # print(line) + try: + # wav_name,text=line.split("\t") + wav_name, spk_name, language, text = line.split("|") + wav_name = clean_path(wav_name) + wav_name = os.path.basename(wav_name) + # name2go(name,lines1) + name2go(wav_name, lines1) + except: + print(line, traceback.format_exc()) + with open(semantic_path, "w", encoding="utf8") as f: + f.write("\n".join(lines1)) diff --git a/pretrained_models/.gitignore 
b/pretrained_models/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..c96a04f008ee21e260b28f7701595ed59e2839e3 --- /dev/null +++ b/pretrained_models/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore \ No newline at end of file diff --git a/process_ckpt.py b/process_ckpt.py new file mode 100644 index 0000000000000000000000000000000000000000..20db9b19bb05c1780a248297bd4ea64847002bf5 --- /dev/null +++ b/process_ckpt.py @@ -0,0 +1,138 @@ +import traceback +from collections import OrderedDict +from time import time as ttime +import shutil +import os +import torch +from tools.i18n.i18n import I18nAuto + +i18n = I18nAuto() + + +def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path + dir = os.path.dirname(path) + name = os.path.basename(path) + tmp_path = "%s.pth" % (ttime()) + torch.save(fea, tmp_path) + shutil.move(tmp_path, "%s/%s" % (dir, name)) + + +from io import BytesIO + +model_version2byte = { + "v3": b"03", + "v4": b"04", + "v2Pro": b"05", + "v2ProPlus": b"06", +} + + +def my_save2(fea, path, model_version): + bio = BytesIO() + torch.save(fea, bio) + bio.seek(0) + data = bio.getvalue() + byte = model_version2byte[model_version] + data = byte + data[2:] + with open(path, "wb") as f: + f.write(data) + + +def savee(ckpt, name, epoch, steps, hps, model_version=None, lora_rank=None): + try: + opt = OrderedDict() + opt["weight"] = {} + for key in ckpt.keys(): + if "enc_q" in key: + continue + opt["weight"][key] = ckpt[key].half() + opt["config"] = hps + opt["info"] = "%sepoch_%siteration" % (epoch, steps) + if lora_rank: + opt["lora_rank"] = lora_rank + my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name), model_version) + elif model_version != None and "Pro" in model_version: + my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name), model_version) + else: + my_save(opt, "%s/%s.pth" % (hps.save_weight_dir, name)) + return "Success." 
+ except: + return traceback.format_exc() + + +""" +00:v1 +01:v2 +02:v3 +03:v3lora +04:v4lora +05:v2Pro +06:v2ProPlus +""" +head2version = { + b"00": ["v1", "v1", False], + b"01": ["v2", "v2", False], + b"02": ["v2", "v3", False], + b"03": ["v2", "v3", True], + b"04": ["v2", "v4", True], + b"05": ["v2", "v2Pro", False], + b"06": ["v2", "v2ProPlus", False], +} +hash_pretrained_dict = { + "dc3c97e17592963677a4a1681f30c653": ["v2", "v2", False], # s2G488k.pth#sovits_v1_pretrained + "43797be674a37c1c83ee81081941ed0f": ["v2", "v3", False], # s2Gv3.pth#sovits_v3_pretrained + "6642b37f3dbb1f76882b69937c95a5f3": ["v2", "v2", False], # s2G2333K.pth#sovits_v2_pretrained + "4f26b9476d0c5033e04162c486074374": ["v2", "v4", False], # s2Gv4.pth#sovits_v4_pretrained + "c7e9fce2223f3db685cdfa1e6368728a": ["v2", "v2Pro", False], # s2Gv2Pro.pth#sovits_v2Pro_pretrained + "66b313e39455b57ab1b0bc0b239c9d0a": ["v2", "v2ProPlus", False], # s2Gv2ProPlus.pth#sovits_v2ProPlus_pretrained +} +import hashlib + + +def get_hash_from_file(sovits_path): + with open(sovits_path, "rb") as f: + data = f.read(8192) + hash_md5 = hashlib.md5() + hash_md5.update(data) + return hash_md5.hexdigest() + + +def get_sovits_version_from_path_fast(sovits_path): + ###1-if it is pretrained sovits models, by hash + hash = get_hash_from_file(sovits_path) + if hash in hash_pretrained_dict: + return hash_pretrained_dict[hash] + ###2-new weights, by head + with open(sovits_path, "rb") as f: + version = f.read(2) + if version != b"PK": + return head2version[version] + ###3-old weights, by file size + if_lora_v3 = False + size = os.path.getsize(sovits_path) + """ + v1weights:about 82942KB + half thr:82978KB + v2weights:about 83014KB + v3weights:about 750MB + """ + if size < 82978 * 1024: + model_version = version = "v1" + elif size < 700 * 1024 * 1024: + model_version = version = "v2" + else: + version = "v2" + model_version = "v3" + return version, model_version, if_lora_v3 + + +def load_sovits_new(sovits_path): + f = open(sovits_path, "rb") + meta = f.read(2) + if meta != b"PK": + data = b"PK" + f.read() + bio = BytesIO() + bio.write(data) + bio.seek(0) + return torch.load(bio, map_location="cpu", weights_only=False) + return torch.load(sovits_path, map_location="cpu", weights_only=False) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..03e2ad19e112d2e09a47ecda505e5fdef8d4a618 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,34 @@ +numpy<2.0 +scipy>=1.11.3 +tensorboard==2.15.1 +librosa==0.9.2 +numba==0.56.4 +torchaudio +pytorch-lightning>=2.4 +gradio<5 +ffmpeg-python==0.2.0 +onnxruntime-gpu +tqdm==4.66.4 +cn2an==0.5.22 +pypinyin==0.50.0 +pyopenjtalk==0.4.1 +g2p_en==2.1.0 +sentencepiece==0.1.99 +transformers==4.35.0 +chardet==3.0.4 +PyYAML==6.0.1 +psutil==5.9.7 +jieba_fast==0.53 +jieba==0.42.1 +https://hf-mirror.com/lj1995/GPT-SoVITS-windows-package/resolve/main/langsegment-0.3.5-py3-none-any.whl?download=true +wordsegment==1.3.1 +rotary_embedding_torch==0.6.4 +spaces +pyjyutping==1.0.0 +g2pk2==0.0.3 +ko_pron==1.3 +opencc==1.1.0 +python_mecab_ko==1.3.7 +torch==2.4 +pydantic<=2.10.6 +torchmetrics<=1.5 \ No newline at end of file diff --git a/s1_train.py b/s1_train.py new file mode 100644 index 0000000000000000000000000000000000000000..1176f0bcef869d2f66574d47d9a55d0cc7e1ac0c --- /dev/null +++ b/s1_train.py @@ -0,0 +1,171 @@ +# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/train_t2s.py +import os + +if "_CUDA_VISIBLE_DEVICES" in os.environ: + 
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"] +import argparse +import logging +import platform +from pathlib import Path + +import torch +from AR.data.data_module import Text2SemanticDataModule +from AR.models.t2s_lightning_module import Text2SemanticLightningModule +from AR.utils.io import load_yaml_config +from pytorch_lightning import Trainer, seed_everything +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.loggers import TensorBoardLogger # WandbLogger +from pytorch_lightning.strategies import DDPStrategy + +logging.getLogger("numba").setLevel(logging.WARNING) +logging.getLogger("matplotlib").setLevel(logging.WARNING) +torch.set_float32_matmul_precision("high") +from collections import OrderedDict + +from AR.utils import get_newest_ckpt +from process_ckpt import my_save + + +class my_model_ckpt(ModelCheckpoint): + def __init__( + self, + config, + if_save_latest, + if_save_every_weights, + half_weights_save_dir, + exp_name, + **kwargs, + ): + super().__init__(**kwargs) + self.if_save_latest = if_save_latest + self.if_save_every_weights = if_save_every_weights + self.half_weights_save_dir = half_weights_save_dir + self.exp_name = exp_name + self.config = config + + def on_train_epoch_end(self, trainer, pl_module): + # if not self._should_skip_saving_checkpoint(trainer) and self._should_save_on_train_epoch_end(trainer): + if self._should_save_on_train_epoch_end(trainer): + monitor_candidates = self._monitor_candidates(trainer) + if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0: + if ( + self.if_save_latest == True + ): ####如果设置只保存最后一个ckpt,在保存下一个ckpt后要清理掉之前的所有ckpt + to_clean = list(os.listdir(self.dirpath)) + self._save_topk_checkpoint(trainer, monitor_candidates) + if self.if_save_latest == True: + for name in to_clean: + try: + os.remove("%s/%s" % (self.dirpath, name)) + except: + pass + if self.if_save_every_weights == True: + to_save_od = OrderedDict() + to_save_od["weight"] = OrderedDict() + dictt = trainer.strategy._lightning_module.state_dict() + for key in dictt: + to_save_od["weight"][key] = dictt[key].half() + to_save_od["config"] = self.config + to_save_od["info"] = "GPT-e%s" % (trainer.current_epoch + 1) + # torch.save( + # print(os.environ) + if os.environ.get("LOCAL_RANK", "0") == "0": + my_save( + to_save_od, + "%s/%s-e%s.ckpt" + % ( + self.half_weights_save_dir, + self.exp_name, + trainer.current_epoch + 1, + ), + ) + self._save_last_checkpoint(trainer, monitor_candidates) + + +def main(args): + config = load_yaml_config(args.config_file) + + output_dir = Path(config["output_dir"]) + output_dir.mkdir(parents=True, exist_ok=True) + + ckpt_dir = output_dir / "ckpt" + ckpt_dir.mkdir(parents=True, exist_ok=True) + + seed_everything(config["train"]["seed"], workers=True) + ckpt_callback: ModelCheckpoint = my_model_ckpt( + config=config, + if_save_latest=config["train"]["if_save_latest"], + if_save_every_weights=config["train"]["if_save_every_weights"], + half_weights_save_dir=config["train"]["half_weights_save_dir"], + exp_name=config["train"]["exp_name"], + save_top_k=-1, + monitor="top_3_acc", + mode="max", + save_on_train_epoch_end=True, + every_n_epochs=config["train"]["save_every_n_epoch"], + dirpath=ckpt_dir, + ) + logger = TensorBoardLogger(name=output_dir.stem, save_dir=output_dir) + os.environ["MASTER_ADDR"] = "localhost" + os.environ["USE_LIBUV"] = "0" + trainer: Trainer = Trainer( + max_epochs=config["train"]["epochs"], + accelerator="gpu" if torch.cuda.is_available() 
else "cpu", + # val_check_interval=9999999999999999999999,###不要验证 + # check_val_every_n_epoch=None, + limit_val_batches=0, + devices=-1 if torch.cuda.is_available() else 1, + benchmark=False, + fast_dev_run=False, + strategy=DDPStrategy(process_group_backend="nccl" if platform.system() != "Windows" else "gloo") + if torch.cuda.is_available() + else "auto", + precision=config["train"]["precision"], + logger=logger, + num_sanity_val_steps=0, + callbacks=[ckpt_callback], + use_distributed_sampler=False, # 非常简单的修改,但解决了采用自定义的 bucket_sampler 下训练步数不一致的问题! + ) + + model: Text2SemanticLightningModule = Text2SemanticLightningModule(config, output_dir) + + data_module: Text2SemanticDataModule = Text2SemanticDataModule( + config, + train_semantic_path=config["train_semantic_path"], + train_phoneme_path=config["train_phoneme_path"], + # dev_semantic_path=args.dev_semantic_path, + # dev_phoneme_path=args.dev_phoneme_path + ) + + try: + # 使用正则表达式匹配文件名中的数字部分,并按数字大小进行排序 + newest_ckpt_name = get_newest_ckpt(os.listdir(ckpt_dir)) + ckpt_path = ckpt_dir / newest_ckpt_name + except Exception: + ckpt_path = None + print("ckpt_path:", ckpt_path) + trainer.fit(model, data_module, ckpt_path=ckpt_path) + + +# srun --gpus-per-node=1 --ntasks-per-node=1 python train.py --path-to-configuration configurations/default.yaml +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "-c", + "--config_file", + type=str, + default="configs/s1longer.yaml", + help="path of config file", + ) + # args for dataset + # parser.add_argument('--train_semantic_path',type=str,default='/data/docker/liujing04/gpt-vits/fine_tune_dataset/xuangou/6-name2semantic.tsv') + # parser.add_argument('--train_phoneme_path', type=str, default='/data/docker/liujing04/gpt-vits/fine_tune_dataset/xuangou/2-name2text.txt') + + # parser.add_argument('--dev_semantic_path', type=str, default='dump_mix/semantic_dev.tsv') + # parser.add_argument('--dev_phoneme_path', type=str, default='dump_mix/phoneme_dev.npy') + # parser.add_argument('--output_dir',type=str,default='/data/docker/liujing04/gpt-vits/fine_tune_dataset/xuangou/logs_s1',help='directory to save the results') + # parser.add_argument('--output_dir',type=str,default='/liujing04/gpt_logs/s1/xuangou_ft',help='directory to save the results') + + args = parser.parse_args() + logging.info(str(args)) + main(args) diff --git a/s2_train.py b/s2_train.py new file mode 100644 index 0000000000000000000000000000000000000000..4b9f64886ea54304c661ebaee3ad7c5193c9d8cf --- /dev/null +++ b/s2_train.py @@ -0,0 +1,684 @@ +import warnings + +warnings.filterwarnings("ignore") +import os + +import utils + +hps = utils.get_hparams(stage=2) +os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",") +import logging + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from torch.cuda.amp import GradScaler, autocast +from torch.nn import functional as F +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +from tqdm import tqdm + +logging.getLogger("matplotlib").setLevel(logging.INFO) +logging.getLogger("h5py").setLevel(logging.INFO) +logging.getLogger("numba").setLevel(logging.INFO) +from random import randint + +from module import commons +from module.data_utils import ( + DistributedBucketSampler, + TextAudioSpeakerCollate, + TextAudioSpeakerLoader, +) +from module.losses import discriminator_loss, feature_loss, generator_loss, kl_loss +from 
module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch +from module.models import ( + MultiPeriodDiscriminator, + SynthesizerTrn, +) +from process_ckpt import savee + +torch.backends.cudnn.benchmark = False +torch.backends.cudnn.deterministic = False +###反正A100fp32更快,那试试tf32吧 +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True +torch.set_float32_matmul_precision("medium") # 最低精度但最快(也就快一丁点),对于结果造成不了影响 +# from config import pretrained_s2G,pretrained_s2D +global_step = 0 + +device = "cpu" # cuda以外的设备,等mps优化后加入 + + +def main(): + if torch.cuda.is_available(): + n_gpus = torch.cuda.device_count() + else: + n_gpus = 1 + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(randint(20000, 55555)) + + mp.spawn( + run, + nprocs=n_gpus, + args=( + n_gpus, + hps, + ), + ) + + +def run(rank, n_gpus, hps): + global global_step + if rank == 0: + logger = utils.get_logger(hps.data.exp_dir) + logger.info(hps) + # utils.check_git_hash(hps.s2_ckpt_dir) + writer = SummaryWriter(log_dir=hps.s2_ckpt_dir) + writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval")) + + dist.init_process_group( + backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl", + init_method="env://?use_libuv=False", + world_size=n_gpus, + rank=rank, + ) + torch.manual_seed(hps.train.seed) + if torch.cuda.is_available(): + torch.cuda.set_device(rank) + + train_dataset = TextAudioSpeakerLoader(hps.data, version=hps.model.version) + train_sampler = DistributedBucketSampler( + train_dataset, + hps.train.batch_size, + [ + 32, + 300, + 400, + 500, + 600, + 700, + 800, + 900, + 1000, + 1100, + 1200, + 1300, + 1400, + 1500, + 1600, + 1700, + 1800, + 1900, + ], + num_replicas=n_gpus, + rank=rank, + shuffle=True, + ) + collate_fn = TextAudioSpeakerCollate(version=hps.model.version) + train_loader = DataLoader( + train_dataset, + num_workers=5, + shuffle=False, + pin_memory=True, + collate_fn=collate_fn, + batch_sampler=train_sampler, + persistent_workers=True, + prefetch_factor=4, + ) + # if rank == 0: + # eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data, val=True) + # eval_loader = DataLoader(eval_dataset, num_workers=0, shuffle=False, + # batch_size=1, pin_memory=True, + # drop_last=False, collate_fn=collate_fn) + + net_g = ( + SynthesizerTrn( + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + n_speakers=hps.data.n_speakers, + **hps.model, + ).cuda(rank) + if torch.cuda.is_available() + else SynthesizerTrn( + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + n_speakers=hps.data.n_speakers, + **hps.model, + ).to(device) + ) + + net_d = ( + MultiPeriodDiscriminator(hps.model.use_spectral_norm, version=hps.model.version).cuda(rank) + if torch.cuda.is_available() + else MultiPeriodDiscriminator(hps.model.use_spectral_norm, version=hps.model.version).to(device) + ) + for name, param in net_g.named_parameters(): + if not param.requires_grad: + print(name, "not requires_grad") + + te_p = list(map(id, net_g.enc_p.text_embedding.parameters())) + et_p = list(map(id, net_g.enc_p.encoder_text.parameters())) + mrte_p = list(map(id, net_g.enc_p.mrte.parameters())) + base_params = filter( + lambda p: id(p) not in te_p + et_p + mrte_p and p.requires_grad, + net_g.parameters(), + ) + + # te_p=net_g.enc_p.text_embedding.parameters() + # et_p=net_g.enc_p.encoder_text.parameters() + # mrte_p=net_g.enc_p.mrte.parameters() + + optim_g = torch.optim.AdamW( + # 
filter(lambda p: p.requires_grad, net_g.parameters()),###默认所有层lr一致 + [ + {"params": base_params, "lr": hps.train.learning_rate}, + { + "params": net_g.enc_p.text_embedding.parameters(), + "lr": hps.train.learning_rate * hps.train.text_low_lr_rate, + }, + { + "params": net_g.enc_p.encoder_text.parameters(), + "lr": hps.train.learning_rate * hps.train.text_low_lr_rate, + }, + { + "params": net_g.enc_p.mrte.parameters(), + "lr": hps.train.learning_rate * hps.train.text_low_lr_rate, + }, + ], + hps.train.learning_rate, + betas=hps.train.betas, + eps=hps.train.eps, + ) + optim_d = torch.optim.AdamW( + net_d.parameters(), + hps.train.learning_rate, + betas=hps.train.betas, + eps=hps.train.eps, + ) + if torch.cuda.is_available(): + net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True) + net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True) + else: + net_g = net_g.to(device) + net_d = net_d.to(device) + + try: # 如果能加载自动resume + _, _, _, epoch_str = utils.load_checkpoint( + utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), "D_*.pth"), + net_d, + optim_d, + ) # D多半加载没事 + if rank == 0: + logger.info("loaded D") + # _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0) + _, _, _, epoch_str = utils.load_checkpoint( + utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), "G_*.pth"), + net_g, + optim_g, + ) + epoch_str += 1 + global_step = (epoch_str - 1) * len(train_loader) + # epoch_str = 1 + # global_step = 0 + except: # 如果首次不能加载,加载pretrain + # traceback.print_exc() + epoch_str = 1 + global_step = 0 + if ( + hps.train.pretrained_s2G != "" + and hps.train.pretrained_s2G != None + and os.path.exists(hps.train.pretrained_s2G) + ): + if rank == 0: + logger.info("loaded pretrained %s" % hps.train.pretrained_s2G) + print( + "loaded pretrained %s" % hps.train.pretrained_s2G, + net_g.module.load_state_dict( + torch.load(hps.train.pretrained_s2G, map_location="cpu", weights_only=False)["weight"], + strict=False, + ) + if torch.cuda.is_available() + else net_g.load_state_dict( + torch.load(hps.train.pretrained_s2G, map_location="cpu", weights_only=False)["weight"], + strict=False, + ), + ) ##测试不加载优化器 + if ( + hps.train.pretrained_s2D != "" + and hps.train.pretrained_s2D != None + and os.path.exists(hps.train.pretrained_s2D) + ): + if rank == 0: + logger.info("loaded pretrained %s" % hps.train.pretrained_s2D) + print( + "loaded pretrained %s" % hps.train.pretrained_s2D, + net_d.module.load_state_dict( + torch.load(hps.train.pretrained_s2D, map_location="cpu", weights_only=False)["weight"], strict=False + ) + if torch.cuda.is_available() + else net_d.load_state_dict( + torch.load(hps.train.pretrained_s2D, map_location="cpu", weights_only=False)["weight"], + ), + ) + + # scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2) + # scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2) + + scheduler_g = torch.optim.lr_scheduler.ExponentialLR( + optim_g, + gamma=hps.train.lr_decay, + last_epoch=-1, + ) + scheduler_d = torch.optim.lr_scheduler.ExponentialLR( + optim_d, + gamma=hps.train.lr_decay, + last_epoch=-1, + ) + for _ in range(epoch_str): + scheduler_g.step() + scheduler_d.step() + + scaler = GradScaler(enabled=hps.train.fp16_run) + + print("start training from epoch %s" % epoch_str) + for epoch in range(epoch_str, hps.train.epochs + 
1): + if rank == 0: + train_and_evaluate( + rank, + epoch, + hps, + [net_g, net_d], + [optim_g, optim_d], + [scheduler_g, scheduler_d], + scaler, + # [train_loader, eval_loader], logger, [writer, writer_eval]) + [train_loader, None], + logger, + [writer, writer_eval], + ) + else: + train_and_evaluate( + rank, + epoch, + hps, + [net_g, net_d], + [optim_g, optim_d], + [scheduler_g, scheduler_d], + scaler, + [train_loader, None], + None, + None, + ) + scheduler_g.step() + scheduler_d.step() + print("training done") + + +def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers): + net_g, net_d = nets + optim_g, optim_d = optims + # scheduler_g, scheduler_d = schedulers + train_loader, eval_loader = loaders + if writers is not None: + writer, writer_eval = writers + + train_loader.batch_sampler.set_epoch(epoch) + global global_step + + net_g.train() + net_d.train() + for batch_idx, data in enumerate(tqdm(train_loader)): + if hps.model.version in {"v2Pro", "v2ProPlus"}: + ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths, sv_emb = data + else: + ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths = data + if torch.cuda.is_available(): + spec, spec_lengths = ( + spec.cuda( + rank, + non_blocking=True, + ), + spec_lengths.cuda( + rank, + non_blocking=True, + ), + ) + y, y_lengths = ( + y.cuda( + rank, + non_blocking=True, + ), + y_lengths.cuda( + rank, + non_blocking=True, + ), + ) + ssl = ssl.cuda(rank, non_blocking=True) + ssl.requires_grad = False + # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True) + text, text_lengths = ( + text.cuda( + rank, + non_blocking=True, + ), + text_lengths.cuda( + rank, + non_blocking=True, + ), + ) + if hps.model.version in {"v2Pro", "v2ProPlus"}: + sv_emb = sv_emb.cuda(rank, non_blocking=True) + else: + spec, spec_lengths = spec.to(device), spec_lengths.to(device) + y, y_lengths = y.to(device), y_lengths.to(device) + ssl = ssl.to(device) + ssl.requires_grad = False + # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True) + text, text_lengths = text.to(device), text_lengths.to(device) + if hps.model.version in {"v2Pro", "v2ProPlus"}: + sv_emb = sv_emb.to(device) + with autocast(enabled=hps.train.fp16_run): + if hps.model.version in {"v2Pro", "v2ProPlus"}: + (y_hat, kl_ssl, ids_slice, x_mask, z_mask, (z, z_p, m_p, logs_p, m_q, logs_q), stats_ssl) = net_g( + ssl, spec, spec_lengths, text, text_lengths, sv_emb + ) + else: + ( + y_hat, + kl_ssl, + ids_slice, + x_mask, + z_mask, + (z, z_p, m_p, logs_p, m_q, logs_q), + stats_ssl, + ) = net_g(ssl, spec, spec_lengths, text, text_lengths) + + mel = spec_to_mel_torch( + spec, + hps.data.filter_length, + hps.data.n_mel_channels, + hps.data.sampling_rate, + hps.data.mel_fmin, + hps.data.mel_fmax, + ) + y_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length) + y_hat_mel = mel_spectrogram_torch( + y_hat.squeeze(1), + hps.data.filter_length, + hps.data.n_mel_channels, + hps.data.sampling_rate, + hps.data.hop_length, + hps.data.win_length, + hps.data.mel_fmin, + hps.data.mel_fmax, + ) + + y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice + + # Discriminator + y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach()) + with autocast(enabled=False): + loss_disc, losses_disc_r, losses_disc_g = discriminator_loss( + y_d_hat_r, + y_d_hat_g, + ) + loss_disc_all = loss_disc + optim_d.zero_grad() + scaler.scale(loss_disc_all).backward() + scaler.unscale_(optim_d) + 
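+        # with fp16_run enabled, gradients are held in loss-scaled form; the unscale_ call above restores
+        # their true magnitudes so clip_grad_value_ below clips real values before the optimizer step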
grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None) + scaler.step(optim_d) + + with autocast(enabled=hps.train.fp16_run): + # Generator + y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat) + with autocast(enabled=False): + loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel + loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl + + loss_fm = feature_loss(fmap_r, fmap_g) + loss_gen, losses_gen = generator_loss(y_d_hat_g) + loss_gen_all = loss_gen + loss_fm + loss_mel + kl_ssl * 1 + loss_kl + + optim_g.zero_grad() + scaler.scale(loss_gen_all).backward() + scaler.unscale_(optim_g) + grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None) + scaler.step(optim_g) + scaler.update() + + if rank == 0: + if global_step % hps.train.log_interval == 0: + lr = optim_g.param_groups[0]["lr"] + losses = [loss_disc, loss_gen, loss_fm, loss_mel, kl_ssl, loss_kl] + logger.info( + "Train Epoch: {} [{:.0f}%]".format( + epoch, + 100.0 * batch_idx / len(train_loader), + ) + ) + logger.info([x.item() for x in losses] + [global_step, lr]) + + scalar_dict = { + "loss/g/total": loss_gen_all, + "loss/d/total": loss_disc_all, + "learning_rate": lr, + "grad_norm_d": grad_norm_d, + "grad_norm_g": grad_norm_g, + } + scalar_dict.update( + { + "loss/g/fm": loss_fm, + "loss/g/mel": loss_mel, + "loss/g/kl_ssl": kl_ssl, + "loss/g/kl": loss_kl, + } + ) + + # scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}) + # scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)}) + # scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}) + image_dict = None + try: ###Some people installed the wrong version of matplotlib. + image_dict = { + "slice/mel_org": utils.plot_spectrogram_to_numpy( + y_mel[0].data.cpu().numpy(), + ), + "slice/mel_gen": utils.plot_spectrogram_to_numpy( + y_hat_mel[0].data.cpu().numpy(), + ), + "all/mel": utils.plot_spectrogram_to_numpy( + mel[0].data.cpu().numpy(), + ), + "all/stats_ssl": utils.plot_spectrogram_to_numpy( + stats_ssl[0].data.cpu().numpy(), + ), + } + except: + pass + if image_dict: + utils.summarize( + writer=writer, + global_step=global_step, + images=image_dict, + scalars=scalar_dict, + ) + else: + utils.summarize( + writer=writer, + global_step=global_step, + scalars=scalar_dict, + ) + global_step += 1 + if epoch % hps.train.save_every_epoch == 0 and rank == 0: + if hps.train.if_save_latest == 0: + utils.save_checkpoint( + net_g, + optim_g, + hps.train.learning_rate, + epoch, + os.path.join( + "%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), + "G_{}.pth".format(global_step), + ), + ) + utils.save_checkpoint( + net_d, + optim_d, + hps.train.learning_rate, + epoch, + os.path.join( + "%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), + "D_{}.pth".format(global_step), + ), + ) + else: + utils.save_checkpoint( + net_g, + optim_g, + hps.train.learning_rate, + epoch, + os.path.join( + "%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), + "G_{}.pth".format(233333333333), + ), + ) + utils.save_checkpoint( + net_d, + optim_d, + hps.train.learning_rate, + epoch, + os.path.join( + "%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), + "D_{}.pth".format(233333333333), + ), + ) + if rank == 0 and hps.train.if_save_every_weights == True: + if hasattr(net_g, "module"): + ckpt = net_g.module.state_dict() + else: + ckpt = net_g.state_dict() + logger.info( + "saving ckpt %s_e%s:%s" + % ( + hps.name, + epoch, + savee( + ckpt, + hps.name + "_e%s_s%s" % (epoch, 
global_step), + epoch, + global_step, + hps, + model_version=None if hps.model.version not in {"v2Pro", "v2ProPlus"} else hps.model.version, + ), + ) + ) + + if rank == 0: + logger.info("====> Epoch: {}".format(epoch)) + + +def evaluate(hps, generator, eval_loader, writer_eval): + generator.eval() + image_dict = {} + audio_dict = {} + print("Evaluating ...") + with torch.no_grad(): + for batch_idx, ( + ssl, + ssl_lengths, + spec, + spec_lengths, + y, + y_lengths, + text, + text_lengths, + ) in enumerate(eval_loader): + print(111) + if torch.cuda.is_available(): + spec, spec_lengths = spec.cuda(), spec_lengths.cuda() + y, y_lengths = y.cuda(), y_lengths.cuda() + ssl = ssl.cuda() + text, text_lengths = text.cuda(), text_lengths.cuda() + else: + spec, spec_lengths = spec.to(device), spec_lengths.to(device) + y, y_lengths = y.to(device), y_lengths.to(device) + ssl = ssl.to(device) + text, text_lengths = text.to(device), text_lengths.to(device) + for test in [0, 1]: + y_hat, mask, *_ = ( + generator.module.infer( + ssl, + spec, + spec_lengths, + text, + text_lengths, + test=test, + ) + if torch.cuda.is_available() + else generator.infer( + ssl, + spec, + spec_lengths, + text, + text_lengths, + test=test, + ) + ) + y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length + + mel = spec_to_mel_torch( + spec, + hps.data.filter_length, + hps.data.n_mel_channels, + hps.data.sampling_rate, + hps.data.mel_fmin, + hps.data.mel_fmax, + ) + y_hat_mel = mel_spectrogram_torch( + y_hat.squeeze(1).float(), + hps.data.filter_length, + hps.data.n_mel_channels, + hps.data.sampling_rate, + hps.data.hop_length, + hps.data.win_length, + hps.data.mel_fmin, + hps.data.mel_fmax, + ) + image_dict.update( + { + f"gen/mel_{batch_idx}_{test}": utils.plot_spectrogram_to_numpy( + y_hat_mel[0].cpu().numpy(), + ), + } + ) + audio_dict.update( + { + f"gen/audio_{batch_idx}_{test}": y_hat[0, :, : y_hat_lengths[0]], + }, + ) + image_dict.update( + { + f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy()), + }, + ) + audio_dict.update({f"gt/audio_{batch_idx}": y[0, :, : y_lengths[0]]}) + + # y_hat, mask, *_ = generator.module.infer(ssl, spec_lengths, speakers, y=None) + # audio_dict.update({ + # f"gen/audio_{batch_idx}_style_pred": y_hat[0, :, :] + # }) + + utils.summarize( + writer=writer_eval, + global_step=global_step, + images=image_dict, + audios=audio_dict, + audio_sampling_rate=hps.data.sampling_rate, + ) + generator.train() + + +if __name__ == "__main__": + main() diff --git a/s2_train_v3.py b/s2_train_v3.py new file mode 100644 index 0000000000000000000000000000000000000000..aa8dae7f2de991e36b31ae34e2bdf8d0357388ab --- /dev/null +++ b/s2_train_v3.py @@ -0,0 +1,467 @@ +import warnings + +warnings.filterwarnings("ignore") +import os + +import utils + +hps = utils.get_hparams(stage=2) +os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",") +import logging + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from torch.cuda.amp import GradScaler, autocast +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +from tqdm import tqdm + +logging.getLogger("matplotlib").setLevel(logging.INFO) +logging.getLogger("h5py").setLevel(logging.INFO) +logging.getLogger("numba").setLevel(logging.INFO) +from random import randint + +from module import commons +from module.data_utils import ( + DistributedBucketSampler, +) +from module.data_utils import ( + 
TextAudioSpeakerCollateV3 as TextAudioSpeakerCollate, +) +from module.data_utils import ( + TextAudioSpeakerLoaderV3 as TextAudioSpeakerLoader, +) +from module.models import ( + SynthesizerTrnV3 as SynthesizerTrn, +) +from process_ckpt import savee + +torch.backends.cudnn.benchmark = False +torch.backends.cudnn.deterministic = False +###反正A100fp32更快,那试试tf32吧 +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True +torch.set_float32_matmul_precision("medium") # 最低精度但最快(也就快一丁点),对于结果造成不了影响 +# from config import pretrained_s2G,pretrained_s2D +global_step = 0 + +device = "cpu" # cuda以外的设备,等mps优化后加入 + + +def main(): + if torch.cuda.is_available(): + n_gpus = torch.cuda.device_count() + else: + n_gpus = 1 + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(randint(20000, 55555)) + + mp.spawn( + run, + nprocs=n_gpus, + args=( + n_gpus, + hps, + ), + ) + + +def run(rank, n_gpus, hps): + global global_step + if rank == 0: + logger = utils.get_logger(hps.data.exp_dir) + logger.info(hps) + # utils.check_git_hash(hps.s2_ckpt_dir) + writer = SummaryWriter(log_dir=hps.s2_ckpt_dir) + writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval")) + + dist.init_process_group( + backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl", + init_method="env://?use_libuv=False", + world_size=n_gpus, + rank=rank, + ) + torch.manual_seed(hps.train.seed) + if torch.cuda.is_available(): + torch.cuda.set_device(rank) + + train_dataset = TextAudioSpeakerLoader(hps.data) ######## + train_sampler = DistributedBucketSampler( + train_dataset, + hps.train.batch_size, + [ + 32, + 300, + 400, + 500, + 600, + 700, + 800, + 900, + 1000, + # 1100, + # 1200, + # 1300, + # 1400, + # 1500, + # 1600, + # 1700, + # 1800, + # 1900, + ], + num_replicas=n_gpus, + rank=rank, + shuffle=True, + ) + collate_fn = TextAudioSpeakerCollate() + train_loader = DataLoader( + train_dataset, + num_workers=6, + shuffle=False, + pin_memory=True, + collate_fn=collate_fn, + batch_sampler=train_sampler, + persistent_workers=True, + prefetch_factor=4, + ) + # if rank == 0: + # eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data, val=True) + # eval_loader = DataLoader(eval_dataset, num_workers=0, shuffle=False, + # batch_size=1, pin_memory=True, + # drop_last=False, collate_fn=collate_fn) + + net_g = ( + SynthesizerTrn( + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + n_speakers=hps.data.n_speakers, + **hps.model, + ).cuda(rank) + if torch.cuda.is_available() + else SynthesizerTrn( + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + n_speakers=hps.data.n_speakers, + **hps.model, + ).to(device) + ) + + # net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank) if torch.cuda.is_available() else MultiPeriodDiscriminator(hps.model.use_spectral_norm).to(device) + # for name, param in net_g.named_parameters(): + # if not param.requires_grad: + # print(name, "not requires_grad") + + optim_g = torch.optim.AdamW( + filter(lambda p: p.requires_grad, net_g.parameters()), ###默认所有层lr一致 + hps.train.learning_rate, + betas=hps.train.betas, + eps=hps.train.eps, + ) + # optim_d = torch.optim.AdamW( + # net_d.parameters(), + # hps.train.learning_rate, + # betas=hps.train.betas, + # eps=hps.train.eps, + # ) + if torch.cuda.is_available(): + net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True) + # net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True) + 
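+    # s2_train_v3 optimizes only the CFM generator; the GAN discriminator (net_d) is intentionally left
+    # commented out and is set to None before the training loop below.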
else: + net_g = net_g.to(device) + # net_d = net_d.to(device) + + try: # 如果能加载自动resume + # _, _, _, epoch_str = utils.load_checkpoint( + # utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "D_*.pth"), + # net_d, + # optim_d, + # ) # D多半加载没事 + # if rank == 0: + # logger.info("loaded D") + # _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0) + _, _, _, epoch_str = utils.load_checkpoint( + utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), "G_*.pth"), + net_g, + optim_g, + ) + epoch_str += 1 + global_step = (epoch_str - 1) * len(train_loader) + # epoch_str = 1 + # global_step = 0 + except: # 如果首次不能加载,加载pretrain + # traceback.print_exc() + epoch_str = 1 + global_step = 0 + if ( + hps.train.pretrained_s2G != "" + and hps.train.pretrained_s2G != None + and os.path.exists(hps.train.pretrained_s2G) + ): + if rank == 0: + logger.info("loaded pretrained %s" % hps.train.pretrained_s2G) + print( + "loaded pretrained %s" % hps.train.pretrained_s2G, + net_g.module.load_state_dict( + torch.load(hps.train.pretrained_s2G, map_location="cpu", weights_only=False)["weight"], + strict=False, + ) + if torch.cuda.is_available() + else net_g.load_state_dict( + torch.load(hps.train.pretrained_s2G, map_location="cpu", weights_only=False)["weight"], + strict=False, + ), + ) ##测试不加载优化器 + # if hps.train.pretrained_s2D != ""and hps.train.pretrained_s2D != None and os.path.exists(hps.train.pretrained_s2D): + # if rank == 0: + # logger.info("loaded pretrained %s" % hps.train.pretrained_s2D) + # print( + # net_d.module.load_state_dict( + # torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"] + # ) if torch.cuda.is_available() else net_d.load_state_dict( + # torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"] + # ) + # ) + + # scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2) + # scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2) + + scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=-1) + # scheduler_d = torch.optim.lr_scheduler.ExponentialLR( + # optim_d, gamma=hps.train.lr_decay, last_epoch=-1 + # ) + for _ in range(epoch_str): + scheduler_g.step() + # scheduler_d.step() + + scaler = GradScaler(enabled=hps.train.fp16_run) + + net_d = optim_d = scheduler_d = None + print("start training from epoch %s" % epoch_str) + for epoch in range(epoch_str, hps.train.epochs + 1): + if rank == 0: + train_and_evaluate( + rank, + epoch, + hps, + [net_g, net_d], + [optim_g, optim_d], + [scheduler_g, scheduler_d], + scaler, + # [train_loader, eval_loader], logger, [writer, writer_eval]) + [train_loader, None], + logger, + [writer, writer_eval], + ) + else: + train_and_evaluate( + rank, + epoch, + hps, + [net_g, net_d], + [optim_g, optim_d], + [scheduler_g, scheduler_d], + scaler, + [train_loader, None], + None, + None, + ) + scheduler_g.step() + # scheduler_d.step() + print("training done") + + +def train_and_evaluate( + rank, + epoch, + hps, + nets, + optims, + schedulers, + scaler, + loaders, + logger, + writers, +): + net_g, net_d = nets + optim_g, optim_d = optims + # scheduler_g, scheduler_d = schedulers + train_loader, eval_loader = loaders + if writers is not None: + writer, writer_eval = writers + + train_loader.batch_sampler.set_epoch(epoch) + global global_step + + net_g.train() + # 
net_d.train() + # for batch_idx, ( + # ssl, + # ssl_lengths, + # spec, + # spec_lengths, + # y, + # y_lengths, + # text, + # text_lengths, + # ) in enumerate(tqdm(train_loader)): + for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate( + tqdm(train_loader) + ): + if torch.cuda.is_available(): + spec, spec_lengths = ( + spec.cuda( + rank, + non_blocking=True, + ), + spec_lengths.cuda( + rank, + non_blocking=True, + ), + ) + mel, mel_lengths = mel.cuda(rank, non_blocking=True), mel_lengths.cuda(rank, non_blocking=True) + ssl = ssl.cuda(rank, non_blocking=True) + ssl.requires_grad = False + # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True) + text, text_lengths = ( + text.cuda( + rank, + non_blocking=True, + ), + text_lengths.cuda( + rank, + non_blocking=True, + ), + ) + else: + spec, spec_lengths = spec.to(device), spec_lengths.to(device) + mel, mel_lengths = mel.to(device), mel_lengths.to(device) + ssl = ssl.to(device) + ssl.requires_grad = False + # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True) + text, text_lengths = text.to(device), text_lengths.to(device) + + with autocast(enabled=hps.train.fp16_run): + cfm_loss = net_g( + ssl, + spec, + mel, + ssl_lengths, + spec_lengths, + text, + text_lengths, + mel_lengths, + use_grad_ckpt=hps.train.grad_ckpt, + ) + loss_gen_all = cfm_loss + optim_g.zero_grad() + scaler.scale(loss_gen_all).backward() + scaler.unscale_(optim_g) + grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None) + scaler.step(optim_g) + scaler.update() + + if rank == 0: + if global_step % hps.train.log_interval == 0: + lr = optim_g.param_groups[0]["lr"] + # losses = [commit_loss,cfm_loss,mel_loss,loss_disc, loss_gen, loss_fm, loss_mel, loss_kl] + losses = [cfm_loss] + logger.info( + "Train Epoch: {} [{:.0f}%]".format( + epoch, + 100.0 * batch_idx / len(train_loader), + ) + ) + logger.info([x.item() for x in losses] + [global_step, lr]) + + scalar_dict = {"loss/g/total": loss_gen_all, "learning_rate": lr, "grad_norm_g": grad_norm_g} + # image_dict = { + # "slice/mel_org": utils.plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()), + # "slice/mel_gen": utils.plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()), + # "all/mel": utils.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()), + # "all/stats_ssl": utils.plot_spectrogram_to_numpy(stats_ssl[0].data.cpu().numpy()), + # } + utils.summarize( + writer=writer, + global_step=global_step, + # images=image_dict, + scalars=scalar_dict, + ) + + # if global_step % hps.train.eval_interval == 0: + # # evaluate(hps, net_g, eval_loader, writer_eval) + # utils.save_checkpoint(net_g, optim_g, hps.train.learning_rate, epoch,os.path.join(hps.s2_ckpt_dir, "G_{}.pth".format(global_step)),scaler) + # # utils.save_checkpoint(net_d, optim_d, hps.train.learning_rate, epoch,os.path.join(hps.s2_ckpt_dir, "D_{}.pth".format(global_step)),scaler) + # # keep_ckpts = getattr(hps.train, 'keep_ckpts', 3) + # # if keep_ckpts > 0: + # # utils.clean_checkpoints(path_to_models=hps.s2_ckpt_dir, n_ckpts_to_keep=keep_ckpts, sort_by_time=True) + + global_step += 1 + if epoch % hps.train.save_every_epoch == 0 and rank == 0: + if hps.train.if_save_latest == 0: + utils.save_checkpoint( + net_g, + optim_g, + hps.train.learning_rate, + epoch, + os.path.join( + "%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), + "G_{}.pth".format(global_step), + ), + ) + # utils.save_checkpoint( + # net_d, + # optim_d, + # hps.train.learning_rate, + # epoch, + # os.path.join( + # "%s/logs_s2_%s" % 
(hps.data.exp_dir,hps.model.version), "D_{}.pth".format(global_step) + # ), + # ) + else: + utils.save_checkpoint( + net_g, + optim_g, + hps.train.learning_rate, + epoch, + os.path.join( + "%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), + "G_{}.pth".format(233333333333), + ), + ) + # utils.save_checkpoint( + # net_d, + # optim_d, + # hps.train.learning_rate, + # epoch, + # os.path.join( + # "%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "D_{}.pth".format(233333333333) + # ), + # ) + if rank == 0 and hps.train.if_save_every_weights == True: + if hasattr(net_g, "module"): + ckpt = net_g.module.state_dict() + else: + ckpt = net_g.state_dict() + logger.info( + "saving ckpt %s_e%s:%s" + % ( + hps.name, + epoch, + savee( + ckpt, + hps.name + "_e%s_s%s" % (epoch, global_step), + epoch, + global_step, + hps, + ), + ) + ) + + if rank == 0: + logger.info("====> Epoch: {}".format(epoch)) + + +if __name__ == "__main__": + main() diff --git a/s2_train_v3_lora.py b/s2_train_v3_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..ba9e4ed437602cbee8c70c3a976471e73dff09e5 --- /dev/null +++ b/s2_train_v3_lora.py @@ -0,0 +1,379 @@ +import warnings + +warnings.filterwarnings("ignore") +import os + +import utils + +hps = utils.get_hparams(stage=2) +os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",") +import logging + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from torch.cuda.amp import GradScaler, autocast +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +from tqdm import tqdm + +logging.getLogger("matplotlib").setLevel(logging.INFO) +logging.getLogger("h5py").setLevel(logging.INFO) +logging.getLogger("numba").setLevel(logging.INFO) +from collections import OrderedDict as od +from random import randint + +from module import commons +from module.data_utils import ( + DistributedBucketSampler, + TextAudioSpeakerCollateV3, + TextAudioSpeakerLoaderV3, + TextAudioSpeakerCollateV4, + TextAudioSpeakerLoaderV4, +) +from module.models import ( + SynthesizerTrnV3 as SynthesizerTrn, +) +from peft import LoraConfig, get_peft_model +from process_ckpt import savee + +torch.backends.cudnn.benchmark = False +torch.backends.cudnn.deterministic = False +###反正A100fp32更快,那试试tf32吧 +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True +torch.set_float32_matmul_precision("medium") # 最低精度但最快(也就快一丁点),对于结果造成不了影响 +# from config import pretrained_s2G,pretrained_s2D +global_step = 0 + +device = "cpu" # cuda以外的设备,等mps优化后加入 + + +def main(): + if torch.cuda.is_available(): + n_gpus = torch.cuda.device_count() + else: + n_gpus = 1 + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(randint(20000, 55555)) + + mp.spawn( + run, + nprocs=n_gpus, + args=( + n_gpus, + hps, + ), + ) + + +def run(rank, n_gpus, hps): + global global_step, no_grad_names, save_root, lora_rank + if rank == 0: + logger = utils.get_logger(hps.data.exp_dir) + logger.info(hps) + # utils.check_git_hash(hps.s2_ckpt_dir) + writer = SummaryWriter(log_dir=hps.s2_ckpt_dir) + writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval")) + + dist.init_process_group( + backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl", + init_method="env://?use_libuv=False", + world_size=n_gpus, + rank=rank, + ) + torch.manual_seed(hps.train.seed) + if torch.cuda.is_available(): + 
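+        # one process is spawned per GPU; bind this process to its own device (by rank) before
+        # building the loaders and the LoRA-wrapped model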
torch.cuda.set_device(rank) + + TextAudioSpeakerLoader = TextAudioSpeakerLoaderV3 if hps.model.version == "v3" else TextAudioSpeakerLoaderV4 + TextAudioSpeakerCollate = TextAudioSpeakerCollateV3 if hps.model.version == "v3" else TextAudioSpeakerCollateV4 + train_dataset = TextAudioSpeakerLoader(hps.data) ######## + train_sampler = DistributedBucketSampler( + train_dataset, + hps.train.batch_size, + [ + 32, + 300, + 400, + 500, + 600, + 700, + 800, + 900, + 1000, + # 1100, + # 1200, + # 1300, + # 1400, + # 1500, + # 1600, + # 1700, + # 1800, + # 1900, + ], + num_replicas=n_gpus, + rank=rank, + shuffle=True, + ) + collate_fn = TextAudioSpeakerCollate() + train_loader = DataLoader( + train_dataset, + num_workers=6, + shuffle=False, + pin_memory=True, + collate_fn=collate_fn, + batch_sampler=train_sampler, + persistent_workers=True, + prefetch_factor=4, + ) + save_root = "%s/logs_s2_%s_lora_%s" % (hps.data.exp_dir, hps.model.version, hps.train.lora_rank) + os.makedirs(save_root, exist_ok=True) + lora_rank = int(hps.train.lora_rank) + lora_config = LoraConfig( + target_modules=["to_k", "to_q", "to_v", "to_out.0"], + r=lora_rank, + lora_alpha=lora_rank, + init_lora_weights=True, + ) + + def get_model(hps): + return SynthesizerTrn( + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + n_speakers=hps.data.n_speakers, + **hps.model, + ) + + def get_optim(net_g): + return torch.optim.AdamW( + filter(lambda p: p.requires_grad, net_g.parameters()), ###默认所有层lr一致 + hps.train.learning_rate, + betas=hps.train.betas, + eps=hps.train.eps, + ) + + def model2cuda(net_g, rank): + if torch.cuda.is_available(): + net_g = DDP(net_g.cuda(rank), device_ids=[rank], find_unused_parameters=True) + else: + net_g = net_g.to(device) + return net_g + + try: # 如果能加载自动resume + net_g = get_model(hps) + net_g.cfm = get_peft_model(net_g.cfm, lora_config) + net_g = model2cuda(net_g, rank) + optim_g = get_optim(net_g) + # _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0) + _, _, _, epoch_str = utils.load_checkpoint( + utils.latest_checkpoint_path(save_root, "G_*.pth"), + net_g, + optim_g, + ) + epoch_str += 1 + global_step = (epoch_str - 1) * len(train_loader) + except: # 如果首次不能加载,加载pretrain + # traceback.print_exc() + epoch_str = 1 + global_step = 0 + net_g = get_model(hps) + if ( + hps.train.pretrained_s2G != "" + and hps.train.pretrained_s2G != None + and os.path.exists(hps.train.pretrained_s2G) + ): + if rank == 0: + logger.info("loaded pretrained %s" % hps.train.pretrained_s2G) + print( + "loaded pretrained %s" % hps.train.pretrained_s2G, + net_g.load_state_dict( + torch.load(hps.train.pretrained_s2G, map_location="cpu", weights_only=False)["weight"], + strict=False, + ), + ) + net_g.cfm = get_peft_model(net_g.cfm, lora_config) + net_g = model2cuda(net_g, rank) + optim_g = get_optim(net_g) + + no_grad_names = set() + for name, param in net_g.named_parameters(): + if not param.requires_grad: + no_grad_names.add(name.replace("module.", "")) + # print(name, "not requires_grad") + # print(no_grad_names) + # os._exit(233333) + + scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=-1) + for _ in range(epoch_str): + scheduler_g.step() + + scaler = GradScaler(enabled=hps.train.fp16_run) + + net_d = optim_d = scheduler_d = None + print("start training from epoch %s" % epoch_str) + for epoch in range(epoch_str, hps.train.epochs + 1): + if rank == 0: + train_and_evaluate( + rank, + 
epoch, + hps, + [net_g, net_d], + [optim_g, optim_d], + [scheduler_g, scheduler_d], + scaler, + # [train_loader, eval_loader], logger, [writer, writer_eval]) + [train_loader, None], + logger, + [writer, writer_eval], + ) + else: + train_and_evaluate( + rank, + epoch, + hps, + [net_g, net_d], + [optim_g, optim_d], + [scheduler_g, scheduler_d], + scaler, + [train_loader, None], + None, + None, + ) + scheduler_g.step() + print("training done") + + +def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers): + net_g, net_d = nets + optim_g, optim_d = optims + # scheduler_g, scheduler_d = schedulers + train_loader, eval_loader = loaders + if writers is not None: + writer, writer_eval = writers + + train_loader.batch_sampler.set_epoch(epoch) + global global_step + + net_g.train() + for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate( + tqdm(train_loader) + ): + if torch.cuda.is_available(): + spec, spec_lengths = ( + spec.cuda( + rank, + non_blocking=True, + ), + spec_lengths.cuda( + rank, + non_blocking=True, + ), + ) + mel, mel_lengths = mel.cuda(rank, non_blocking=True), mel_lengths.cuda(rank, non_blocking=True) + ssl = ssl.cuda(rank, non_blocking=True) + ssl.requires_grad = False + text, text_lengths = ( + text.cuda( + rank, + non_blocking=True, + ), + text_lengths.cuda( + rank, + non_blocking=True, + ), + ) + else: + spec, spec_lengths = spec.to(device), spec_lengths.to(device) + mel, mel_lengths = mel.to(device), mel_lengths.to(device) + ssl = ssl.to(device) + ssl.requires_grad = False + text, text_lengths = text.to(device), text_lengths.to(device) + + with autocast(enabled=hps.train.fp16_run): + cfm_loss = net_g( + ssl, + spec, + mel, + ssl_lengths, + spec_lengths, + text, + text_lengths, + mel_lengths, + use_grad_ckpt=hps.train.grad_ckpt, + ) + loss_gen_all = cfm_loss + optim_g.zero_grad() + scaler.scale(loss_gen_all).backward() + scaler.unscale_(optim_g) + grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None) + scaler.step(optim_g) + scaler.update() + + if rank == 0: + if global_step % hps.train.log_interval == 0: + lr = optim_g.param_groups[0]["lr"] + losses = [cfm_loss] + logger.info("Train Epoch: {} [{:.0f}%]".format(epoch, 100.0 * batch_idx / len(train_loader))) + logger.info([x.item() for x in losses] + [global_step, lr]) + + scalar_dict = {"loss/g/total": loss_gen_all, "learning_rate": lr, "grad_norm_g": grad_norm_g} + utils.summarize( + writer=writer, + global_step=global_step, + scalars=scalar_dict, + ) + + global_step += 1 + if epoch % hps.train.save_every_epoch == 0 and rank == 0: + if hps.train.if_save_latest == 0: + utils.save_checkpoint( + net_g, + optim_g, + hps.train.learning_rate, + epoch, + os.path.join(save_root, "G_{}.pth".format(global_step)), + ) + else: + utils.save_checkpoint( + net_g, + optim_g, + hps.train.learning_rate, + epoch, + os.path.join(save_root, "G_{}.pth".format(233333333333)), + ) + if rank == 0 and hps.train.if_save_every_weights == True: + if hasattr(net_g, "module"): + ckpt = net_g.module.state_dict() + else: + ckpt = net_g.state_dict() + sim_ckpt = od() + for key in ckpt: + # if "cfm"not in key: + # print(key) + if key not in no_grad_names: + sim_ckpt[key] = ckpt[key].half().cpu() + logger.info( + "saving ckpt %s_e%s:%s" + % ( + hps.name, + epoch, + savee( + sim_ckpt, + hps.name + "_e%s_s%s_l%s" % (epoch, global_step, lora_rank), + epoch, + global_step, + hps, + model_version=hps.model.version, + lora_rank=lora_rank, + ), + ) + ) + + if rank 
== 0: + logger.info("====> Epoch: {}".format(epoch)) + + +if __name__ == "__main__": + main() diff --git a/sv.py b/sv.py new file mode 100644 index 0000000000000000000000000000000000000000..22e703692b47e6affca73e03d3a631bcfe200706 --- /dev/null +++ b/sv.py @@ -0,0 +1,32 @@ +import sys +import os +import torch + +sys.path.append(f"{os.getcwd()}/GPT_SoVITS/eres2net") +sv_path = "GPT_SoVITS/pretrained_models/sv/pretrained_eres2netv2w24s4ep4.ckpt" +from ERes2NetV2 import ERes2NetV2 +import kaldi as Kaldi + + +class SV: + def __init__(self, device, is_half): + pretrained_state = torch.load(sv_path, map_location="cpu", weights_only=False) + embedding_model = ERes2NetV2(baseWidth=24, scale=4, expansion=4) + embedding_model.load_state_dict(pretrained_state) + embedding_model.eval() + self.embedding_model = embedding_model + if is_half == False: + self.embedding_model = self.embedding_model.to(device) + else: + self.embedding_model = self.embedding_model.half().to(device) + self.is_half = is_half + + def compute_embedding3(self, wav): + with torch.no_grad(): + if self.is_half == True: + wav = wav.half() + feat = torch.stack( + [Kaldi.fbank(wav0.unsqueeze(0), num_mel_bins=80, sample_frequency=16000, dither=0) for wav0 in wav] + ) + sv_emb = self.embedding_model.forward3(feat) + return sv_emb diff --git a/text/.gitignore b/text/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..aea3ffddfb1bc3f6be2edba63f75aa593aea2317 --- /dev/null +++ b/text/.gitignore @@ -0,0 +1,3 @@ +G2PWModel +__pycache__ +*.zip \ No newline at end of file diff --git a/text/LangSegmenter/__init__.py b/text/LangSegmenter/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0a7649059c5c723fc85da3b618e00024895cb130 --- /dev/null +++ b/text/LangSegmenter/__init__.py @@ -0,0 +1 @@ +from .langsegmenter import LangSegmenter diff --git a/text/LangSegmenter/langsegmenter.py b/text/LangSegmenter/langsegmenter.py new file mode 100644 index 0000000000000000000000000000000000000000..0187ea697aab94551163bbd6f3ca77c3718da7f4 --- /dev/null +++ b/text/LangSegmenter/langsegmenter.py @@ -0,0 +1,180 @@ +import logging +import re + +# jieba静音 +import jieba + +jieba.setLogLevel(logging.CRITICAL) + +# 更改fast_langdetect大模型位置 +from pathlib import Path +import fast_langdetect + +fast_langdetect.infer._default_detector = fast_langdetect.infer.LangDetector( + fast_langdetect.infer.LangDetectConfig( + cache_dir=Path(__file__).parent.parent.parent / "pretrained_models" / "fast_langdetect" + ) +) + + +from split_lang import LangSplitter + + +def full_en(text): + pattern = r"^(?=.*[A-Za-z])[A-Za-z0-9\s\u0020-\u007E\u2000-\u206F\u3000-\u303F\uFF00-\uFFEF]+$" + return bool(re.match(pattern, text)) + + +def full_cjk(text): + # 来自wiki + cjk_ranges = [ + (0x4E00, 0x9FFF), # CJK Unified Ideographs + (0x3400, 0x4DB5), # CJK Extension A + (0x20000, 0x2A6DD), # CJK Extension B + (0x2A700, 0x2B73F), # CJK Extension C + (0x2B740, 0x2B81F), # CJK Extension D + (0x2B820, 0x2CEAF), # CJK Extension E + (0x2CEB0, 0x2EBEF), # CJK Extension F + (0x30000, 0x3134A), # CJK Extension G + (0x31350, 0x323AF), # CJK Extension H + (0x2EBF0, 0x2EE5D), # CJK Extension H + ] + + pattern = r"[0-9、-〜。!?.!?… /]+$" + + cjk_text = "" + for char in text: + code_point = ord(char) + in_cjk = any(start <= code_point <= end for start, end in cjk_ranges) + if in_cjk or re.match(pattern, char): + cjk_text += char + return cjk_text + + +def split_jako(tag_lang, item): + if tag_lang == "ja": + pattern = 
r"([\u3041-\u3096\u3099\u309A\u30A1-\u30FA\u30FC]+(?:[0-9、-〜。!?.!?… ]+[\u3041-\u3096\u3099\u309A\u30A1-\u30FA\u30FC]*)*)" + else: + pattern = r"([\u1100-\u11FF\u3130-\u318F\uAC00-\uD7AF]+(?:[0-9、-〜。!?.!?… ]+[\u1100-\u11FF\u3130-\u318F\uAC00-\uD7AF]*)*)" + + lang_list: list[dict] = [] + tag = 0 + for match in re.finditer(pattern, item["text"]): + if match.start() > tag: + lang_list.append({"lang": item["lang"], "text": item["text"][tag : match.start()]}) + + tag = match.end() + lang_list.append({"lang": tag_lang, "text": item["text"][match.start() : match.end()]}) + + if tag < len(item["text"]): + lang_list.append({"lang": item["lang"], "text": item["text"][tag : len(item["text"])]}) + + return lang_list + + +def merge_lang(lang_list, item): + if lang_list and item["lang"] == lang_list[-1]["lang"]: + lang_list[-1]["text"] += item["text"] + else: + lang_list.append(item) + return lang_list + + +class LangSegmenter: + # 默认过滤器, 基于gsv目前四种语言 + DEFAULT_LANG_MAP = { + "zh": "zh", + "yue": "zh", # 粤语 + "wuu": "zh", # 吴语 + "zh-cn": "zh", + "zh-tw": "x", # 繁体设置为x + "ko": "ko", + "ja": "ja", + "en": "en", + } + + def getTexts(text): + lang_splitter = LangSplitter(lang_map=LangSegmenter.DEFAULT_LANG_MAP) + substr = lang_splitter.split_by_lang(text=text) + + lang_list: list[dict] = [] + + for _, item in enumerate(substr): + dict_item = {"lang": item.lang, "text": item.text} + + # 处理短英文被识别为其他语言的问题 + if full_en(dict_item["text"]): + dict_item["lang"] = "en" + lang_list = merge_lang(lang_list, dict_item) + continue + + # 处理非日语夹日文的问题(不包含CJK) + ja_list: list[dict] = [] + if dict_item["lang"] != "ja": + ja_list = split_jako("ja", dict_item) + + if not ja_list: + ja_list.append(dict_item) + + # 处理非韩语夹韩语的问题(不包含CJK) + ko_list: list[dict] = [] + temp_list: list[dict] = [] + for _, ko_item in enumerate(ja_list): + if ko_item["lang"] != "ko": + ko_list = split_jako("ko", ko_item) + + if ko_list: + temp_list.extend(ko_list) + else: + temp_list.append(ko_item) + + # 未存在非日韩文夹日韩文 + if len(temp_list) == 1: + # 未知语言检查是否为CJK + if dict_item["lang"] == "x": + cjk_text = full_cjk(dict_item["text"]) + if cjk_text: + dict_item = {"lang": "zh", "text": cjk_text} + lang_list = merge_lang(lang_list, dict_item) + else: + lang_list = merge_lang(lang_list, dict_item) + continue + else: + lang_list = merge_lang(lang_list, dict_item) + continue + + # 存在非日韩文夹日韩文 + for _, temp_item in enumerate(temp_list): + # 未知语言检查是否为CJK + if temp_item["lang"] == "x": + cjk_text = full_cjk(dict_item["text"]) + if cjk_text: + dict_item = {"lang": "zh", "text": cjk_text} + lang_list = merge_lang(lang_list, dict_item) + else: + lang_list = merge_lang(lang_list, dict_item) + else: + lang_list = merge_lang(lang_list, temp_item) + + temp_list = lang_list + lang_list = [] + for _, temp_item in enumerate(temp_list): + if temp_item["lang"] == "x": + if lang_list: + temp_item["lang"] = lang_list[-1]["lang"] + elif len(temp_list) > 1: + temp_item["lang"] = temp_list[1]["lang"] + else: + temp_item["lang"] = "zh" + + lang_list = merge_lang(lang_list, temp_item) + + return lang_list + + +if __name__ == "__main__": + text = "MyGO?,你也喜欢まいご吗?" 
+ print(LangSegmenter.getTexts(text)) + + text = "ねえ、知ってる?最近、僕は天文学を勉強してるんだ。君の瞳が星空みたいにキラキラしてるからさ。" + print(LangSegmenter.getTexts(text)) diff --git a/text/__init__.py b/text/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..82df1fbbe2d3d05b16545a867d8dfd59c963527e --- /dev/null +++ b/text/__init__.py @@ -0,0 +1,28 @@ +import os +# if os.environ.get("version","v1")=="v1": +# from text.symbols import symbols +# else: +# from text.symbols2 import symbols + +from text import symbols as symbols_v1 +from text import symbols2 as symbols_v2 + +_symbol_to_id_v1 = {s: i for i, s in enumerate(symbols_v1.symbols)} +_symbol_to_id_v2 = {s: i for i, s in enumerate(symbols_v2.symbols)} + + +def cleaned_text_to_sequence(cleaned_text, version=None): + """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. + Args: + text: string to convert to a sequence + Returns: + List of integers corresponding to the symbols in the text + """ + if version is None: + version = os.environ.get("version", "v2") + if version == "v1": + phones = [_symbol_to_id_v1[symbol] for symbol in cleaned_text] + else: + phones = [_symbol_to_id_v2[symbol] for symbol in cleaned_text] + + return phones diff --git a/text/cantonese.py b/text/cantonese.py new file mode 100644 index 0000000000000000000000000000000000000000..1f07c4144dbc410eaca04daf9e7d2ae374883b85 --- /dev/null +++ b/text/cantonese.py @@ -0,0 +1,222 @@ +# reference: https://huggingface.co/spaces/Naozumi0512/Bert-VITS2-Cantonese-Yue/blob/main/text/chinese.py + +import re +import cn2an +import ToJyutping + +from text.symbols import punctuation +from text.zh_normalization.text_normlization import TextNormalizer + +normalizer = lambda x: cn2an.transform(x, "an2cn") + +INITIALS = [ + "aa", + "aai", + "aak", + "aap", + "aat", + "aau", + "ai", + "au", + "ap", + "at", + "ak", + "a", + "p", + "b", + "e", + "ts", + "t", + "dz", + "d", + "kw", + "k", + "gw", + "g", + "f", + "h", + "l", + "m", + "ng", + "n", + "s", + "y", + "w", + "c", + "z", + "j", + "ong", + "on", + "ou", + "oi", + "ok", + "o", + "uk", + "ung", +] +INITIALS += ["sp", "spl", "spn", "sil"] + + +rep_map = { + ":": ",", + ";": ",", + ",": ",", + "。": ".", + "!": "!", + "?": "?", + "\n": ".", + "·": ",", + "、": ",", + "...": "…", + "$": ".", + "“": "'", + "”": "'", + '"': "'", + "‘": "'", + "’": "'", + "(": "'", + ")": "'", + "(": "'", + ")": "'", + "《": "'", + "》": "'", + "【": "'", + "】": "'", + "[": "'", + "]": "'", + "—": "-", + "~": "-", + "~": "-", + "「": "'", + "」": "'", +} + + +def replace_punctuation(text): + # text = text.replace("嗯", "恩").replace("呣", "母") + pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys())) + + replaced_text = pattern.sub(lambda x: rep_map[x.group()], text) + + replaced_text = re.sub(r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text) + + return replaced_text + + +def text_normalize(text): + tx = TextNormalizer() + sentences = tx.normalize(text) + dest_text = "" + for sentence in sentences: + dest_text += replace_punctuation(sentence) + return dest_text + + +punctuation_set = set(punctuation) + + +def jyuping_to_initials_finals_tones(jyuping_syllables): + initials_finals = [] + tones = [] + word2ph = [] + + for syllable in jyuping_syllables: + if syllable in punctuation: + initials_finals.append(syllable) + tones.append(0) + word2ph.append(1) # Add 1 for punctuation + elif syllable == "_": + initials_finals.append(syllable) + tones.append(0) + word2ph.append(1) # Add 1 for underscore + else: + try: 
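+                # a jyutping syllable normally ends in a tone digit 1-6; strip it here, falling back to
+                # tone 0 in the except branch when the last character is not a digit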
+ tone = int(syllable[-1]) + syllable_without_tone = syllable[:-1] + except ValueError: + tone = 0 + syllable_without_tone = syllable + + for initial in INITIALS: + if syllable_without_tone.startswith(initial): + if syllable_without_tone.startswith("nga"): + initials_finals.extend( + [ + syllable_without_tone[:2], + syllable_without_tone[2:] or syllable_without_tone[-1], + ] + ) + # tones.extend([tone, tone]) + tones.extend([-1, tone]) + word2ph.append(2) + else: + final = syllable_without_tone[len(initial) :] or initial[-1] + initials_finals.extend([initial, final]) + # tones.extend([tone, tone]) + tones.extend([-1, tone]) + word2ph.append(2) + break + assert len(initials_finals) == len(tones) + + ###魔改为辅音+带音调的元音 + phones = [] + for a, b in zip(initials_finals, tones): + if b not in [-1, 0]: ###防止粤语和普通话重合开头加Y,如果是标点,不加。 + todo = "%s%s" % (a, b) + else: + todo = a + if todo not in punctuation_set: + todo = "Y%s" % todo + phones.append(todo) + + # return initials_finals, tones, word2ph + return phones, word2ph + + +def get_jyutping(text): + jyutping_array = [] + punct_pattern = re.compile(r"^[{}]+$".format(re.escape("".join(punctuation)))) + + syllables = ToJyutping.get_jyutping_list(text) + + for word, syllable in syllables: + if punct_pattern.match(word): + puncts = re.split(r"([{}])".format(re.escape("".join(punctuation))), word) + for punct in puncts: + if len(punct) > 0: + jyutping_array.append(punct) + else: + # match multple jyutping eg: liu4 ge3, or single jyutping eg: liu4 + if not re.search(r"^([a-z]+[1-6]+[ ]?)+$", syllable): + raise ValueError(f"Failed to convert {word} to jyutping: {syllable}") + jyutping_array.append(syllable) + + return jyutping_array + + +def get_bert_feature(text, word2ph): + from text import chinese_bert + + return chinese_bert.get_bert_feature(text, word2ph) + + +def g2p(text): + # word2ph = [] + jyuping = get_jyutping(text) + # print(jyuping) + # phones, tones, word2ph = jyuping_to_initials_finals_tones(jyuping) + phones, word2ph = jyuping_to_initials_finals_tones(jyuping) + # phones = ["_"] + phones + ["_"] + # tones = [0] + tones + [0] + # word2ph = [1] + word2ph + [1] + return phones, word2ph + + +if __name__ == "__main__": + # text = "啊!但是《原神》是由,米哈\游自主, [研发]的一款全.新开放世界.冒险游戏" + text = "佢個鋤頭太短啦。" + text = text_normalize(text) + # phones, tones, word2ph = g2p(text) + phones, word2ph = g2p(text) + # print(phones, tones, word2ph) + print(phones, word2ph) diff --git a/text/chinese.py b/text/chinese.py new file mode 100644 index 0000000000000000000000000000000000000000..ce44215f09a86c36b17abc1380fd836569571301 --- /dev/null +++ b/text/chinese.py @@ -0,0 +1,208 @@ +import os +import re + +import cn2an +from pypinyin import lazy_pinyin, Style + +from text.symbols import punctuation +from text.tone_sandhi import ToneSandhi +from text.zh_normalization.text_normlization import TextNormalizer + +normalizer = lambda x: cn2an.transform(x, "an2cn") + +current_file_path = os.path.dirname(__file__) +pinyin_to_symbol_map = { + line.split("\t")[0]: line.strip().split("\t")[1] + for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines() +} + +import jieba_fast +import logging + +jieba_fast.setLogLevel(logging.CRITICAL) +import jieba_fast.posseg as psg + + +rep_map = { + ":": ",", + ";": ",", + ",": ",", + "。": ".", + "!": "!", + "?": "?", + "\n": ".", + "·": ",", + "、": ",", + "...": "…", + "$": ".", + "/": ",", + "—": "-", + "~": "…", + "~": "…", +} + +tone_modifier = ToneSandhi() + + +def replace_punctuation(text): + text = text.replace("嗯", 
"恩").replace("呣", "母") + pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys())) + + replaced_text = pattern.sub(lambda x: rep_map[x.group()], text) + + replaced_text = re.sub(r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text) + + return replaced_text + + +def replace_punctuation_with_en(text): + text = text.replace("嗯", "恩").replace("呣", "母") + pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys())) + + replaced_text = pattern.sub(lambda x: rep_map[x.group()], text) + + replaced_text = re.sub(r"[^\u4e00-\u9fa5A-Za-z" + "".join(punctuation) + r"]+", "", replaced_text) + + return replaced_text + + +def replace_consecutive_punctuation(text): + punctuations = "".join(re.escape(p) for p in punctuation) + pattern = f"([{punctuations}])([{punctuations}])+" + result = re.sub(pattern, r"\1", text) + return result + + +def g2p(text): + pattern = r"(?<=[{0}])\s*".format("".join(punctuation)) + sentences = [i for i in re.split(pattern, text) if i.strip() != ""] + phones, word2ph = _g2p(sentences) + return phones, word2ph + + +def _get_initials_finals(word): + initials = [] + finals = [] + orig_initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS) + orig_finals = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3) + for c, v in zip(orig_initials, orig_finals): + initials.append(c) + finals.append(v) + return initials, finals + + +def _g2p(segments): + phones_list = [] + word2ph = [] + for seg in segments: + pinyins = [] + # Replace all English words in the sentence + seg = re.sub("[a-zA-Z]+", "", seg) + seg_cut = psg.lcut(seg) + initials = [] + finals = [] + seg_cut = tone_modifier.pre_merge_for_modify(seg_cut) + for word, pos in seg_cut: + if pos == "eng": + continue + sub_initials, sub_finals = _get_initials_finals(word) + sub_finals = tone_modifier.modified_tone(word, pos, sub_finals) + initials.append(sub_initials) + finals.append(sub_finals) + + # assert len(sub_initials) == len(sub_finals) == len(word) + initials = sum(initials, []) + finals = sum(finals, []) + # + for c, v in zip(initials, finals): + raw_pinyin = c + v + # NOTE: post process for pypinyin outputs + # we discriminate i, ii and iii + if c == v: + assert c in punctuation + phone = [c] + word2ph.append(1) + else: + v_without_tone = v[:-1] + tone = v[-1] + + pinyin = c + v_without_tone + assert tone in "12345" + + if c: + # 多音节 + v_rep_map = { + "uei": "ui", + "iou": "iu", + "uen": "un", + } + if v_without_tone in v_rep_map.keys(): + pinyin = c + v_rep_map[v_without_tone] + else: + # 单音节 + pinyin_rep_map = { + "ing": "ying", + "i": "yi", + "in": "yin", + "u": "wu", + } + if pinyin in pinyin_rep_map.keys(): + pinyin = pinyin_rep_map[pinyin] + else: + single_rep_map = { + "v": "yu", + "e": "e", + "i": "y", + "u": "w", + } + if pinyin[0] in single_rep_map.keys(): + pinyin = single_rep_map[pinyin[0]] + pinyin[1:] + + assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, seg, raw_pinyin) + new_c, new_v = pinyin_to_symbol_map[pinyin].split(" ") + new_v = new_v + tone + phone = [new_c, new_v] + word2ph.append(len(phone)) + + phones_list += phone + return phones_list, word2ph + + +def text_normalize(text): + # https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/paddlespeech/t2s/frontend/zh_normalization + tx = TextNormalizer() + sentences = tx.normalize(text) + dest_text = "" + for sentence in sentences: + dest_text += replace_punctuation(sentence) + + # 避免重复标点引起的参考泄露 + dest_text = replace_consecutive_punctuation(dest_text) + return dest_text + 
+ +# 不排除英文的文本格式化 +def mix_text_normalize(text): + # https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/paddlespeech/t2s/frontend/zh_normalization + tx = TextNormalizer() + sentences = tx.normalize(text) + dest_text = "" + for sentence in sentences: + dest_text += replace_punctuation_with_en(sentence) + + # 避免重复标点引起的参考泄露 + dest_text = replace_consecutive_punctuation(dest_text) + return dest_text + + +if __name__ == "__main__": + text = "啊——但是《原神》是由,米哈\游自主,研发的一款全.新开放世界.冒险游戏" + text = "呣呣呣~就是…大人的鼹鼠党吧?" + text = "你好" + text = text_normalize(text) + print(g2p(text)) + + +# # 示例用法 +# text = "这是一个示例文本:,你好!这是一个测试..." +# print(g2p_paddle(text)) # 输出: 这是一个示例文本你好这是一个测试 diff --git a/text/chinese2.py b/text/chinese2.py new file mode 100644 index 0000000000000000000000000000000000000000..612aa3a5f8dd87960e9a3ff96c37dcddca4cdf73 --- /dev/null +++ b/text/chinese2.py @@ -0,0 +1,353 @@ +import os +import re + +import cn2an +from pypinyin import lazy_pinyin, Style +from pypinyin.contrib.tone_convert import to_finals_tone3, to_initials + +from text.symbols import punctuation +from text.tone_sandhi import ToneSandhi +from text.zh_normalization.text_normlization import TextNormalizer + +normalizer = lambda x: cn2an.transform(x, "an2cn") + +current_file_path = os.path.dirname(__file__) +pinyin_to_symbol_map = { + line.split("\t")[0]: line.strip().split("\t")[1] + for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines() +} + +import jieba_fast +import logging + +jieba_fast.setLogLevel(logging.CRITICAL) +import jieba_fast.posseg as psg + +# is_g2pw_str = os.environ.get("is_g2pw", "True")##默认开启 +# is_g2pw = False#True if is_g2pw_str.lower() == 'true' else False +is_g2pw = True # True if is_g2pw_str.lower() == 'true' else False +if is_g2pw: + # print("当前使用g2pw进行拼音推理") + from text.g2pw import G2PWPinyin, correct_pronunciation + + parent_directory = os.path.dirname(current_file_path) + g2pw = G2PWPinyin( + model_dir="GPT_SoVITS/text/G2PWModel", + model_source=os.environ.get("bert_path", "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"), + v_to_u=False, + neutral_tone_with_five=True, + ) + +rep_map = { + ":": ",", + ";": ",", + ",": ",", + "。": ".", + "!": "!", + "?": "?", + "\n": ".", + "·": ",", + "、": ",", + "...": "…", + "$": ".", + "/": ",", + "—": "-", + "~": "…", + "~": "…", +} + +tone_modifier = ToneSandhi() + + +def replace_punctuation(text): + text = text.replace("嗯", "恩").replace("呣", "母") + pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys())) + + replaced_text = pattern.sub(lambda x: rep_map[x.group()], text) + + replaced_text = re.sub(r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text) + + return replaced_text + + +def g2p(text): + pattern = r"(?<=[{0}])\s*".format("".join(punctuation)) + sentences = [i for i in re.split(pattern, text) if i.strip() != ""] + phones, word2ph = _g2p(sentences) + return phones, word2ph + + +def _get_initials_finals(word): + initials = [] + finals = [] + + orig_initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS) + orig_finals = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3) + + for c, v in zip(orig_initials, orig_finals): + initials.append(c) + finals.append(v) + return initials, finals + + +must_erhua = {"小院儿", "胡同儿", "范儿", "老汉儿", "撒欢儿", "寻老礼儿", "妥妥儿", "媳妇儿"} +not_erhua = { + "虐儿", + "为儿", + "护儿", + "瞒儿", + "救儿", + "替儿", + "有儿", + "一儿", + "我儿", + "俺儿", + "妻儿", + "拐儿", + "聋儿", + "乞儿", + "患儿", + "幼儿", + "孤儿", + "婴儿", + "婴幼儿", + "连体儿", + "脑瘫儿", + 
"流浪儿", + "体弱儿", + "混血儿", + "蜜雪儿", + "舫儿", + "祖儿", + "美儿", + "应采儿", + "可儿", + "侄儿", + "孙儿", + "侄孙儿", + "女儿", + "男儿", + "红孩儿", + "花儿", + "虫儿", + "马儿", + "鸟儿", + "猪儿", + "猫儿", + "狗儿", + "少儿", +} + + +def _merge_erhua(initials: list[str], finals: list[str], word: str, pos: str) -> list[list[str]]: + """ + Do erhub. + """ + # fix er1 + for i, phn in enumerate(finals): + if i == len(finals) - 1 and word[i] == "儿" and phn == "er1": + finals[i] = "er2" + + # 发音 + if word not in must_erhua and (word in not_erhua or pos in {"a", "j", "nr"}): + return initials, finals + + # "……" 等情况直接返回 + if len(finals) != len(word): + return initials, finals + + assert len(finals) == len(word) + + # 与前一个字发同音 + new_initials = [] + new_finals = [] + for i, phn in enumerate(finals): + if ( + i == len(finals) - 1 + and word[i] == "儿" + and phn in {"er2", "er5"} + and word[-2:] not in not_erhua + and new_finals + ): + phn = "er" + new_finals[-1][-1] + + new_initials.append(initials[i]) + new_finals.append(phn) + + return new_initials, new_finals + + +def _g2p(segments): + phones_list = [] + word2ph = [] + for seg in segments: + pinyins = [] + # Replace all English words in the sentence + seg = re.sub("[a-zA-Z]+", "", seg) + seg_cut = psg.lcut(seg) + seg_cut = tone_modifier.pre_merge_for_modify(seg_cut) + initials = [] + finals = [] + + if not is_g2pw: + for word, pos in seg_cut: + if pos == "eng": + continue + sub_initials, sub_finals = _get_initials_finals(word) + sub_finals = tone_modifier.modified_tone(word, pos, sub_finals) + # 儿化 + sub_initials, sub_finals = _merge_erhua(sub_initials, sub_finals, word, pos) + initials.append(sub_initials) + finals.append(sub_finals) + # assert len(sub_initials) == len(sub_finals) == len(word) + initials = sum(initials, []) + finals = sum(finals, []) + print("pypinyin结果", initials, finals) + else: + # g2pw采用整句推理 + pinyins = g2pw.lazy_pinyin(seg, neutral_tone_with_five=True, style=Style.TONE3) + + pre_word_length = 0 + for word, pos in seg_cut: + sub_initials = [] + sub_finals = [] + now_word_length = pre_word_length + len(word) + + if pos == "eng": + pre_word_length = now_word_length + continue + + word_pinyins = pinyins[pre_word_length:now_word_length] + + # 多音字消歧 + word_pinyins = correct_pronunciation(word, word_pinyins) + + for pinyin in word_pinyins: + if pinyin[0].isalpha(): + sub_initials.append(to_initials(pinyin)) + sub_finals.append(to_finals_tone3(pinyin, neutral_tone_with_five=True)) + else: + sub_initials.append(pinyin) + sub_finals.append(pinyin) + + pre_word_length = now_word_length + sub_finals = tone_modifier.modified_tone(word, pos, sub_finals) + # 儿化 + sub_initials, sub_finals = _merge_erhua(sub_initials, sub_finals, word, pos) + initials.append(sub_initials) + finals.append(sub_finals) + + initials = sum(initials, []) + finals = sum(finals, []) + # print("g2pw结果",initials,finals) + + for c, v in zip(initials, finals): + raw_pinyin = c + v + # NOTE: post process for pypinyin outputs + # we discriminate i, ii and iii + if c == v: + assert c in punctuation + phone = [c] + word2ph.append(1) + else: + v_without_tone = v[:-1] + tone = v[-1] + + pinyin = c + v_without_tone + assert tone in "12345" + + if c: + # 多音节 + v_rep_map = { + "uei": "ui", + "iou": "iu", + "uen": "un", + } + if v_without_tone in v_rep_map.keys(): + pinyin = c + v_rep_map[v_without_tone] + else: + # 单音节 + pinyin_rep_map = { + "ing": "ying", + "i": "yi", + "in": "yin", + "u": "wu", + } + if pinyin in pinyin_rep_map.keys(): + pinyin = pinyin_rep_map[pinyin] + else: + single_rep_map = { + "v": "yu", + 
"e": "e", + "i": "y", + "u": "w", + } + if pinyin[0] in single_rep_map.keys(): + pinyin = single_rep_map[pinyin[0]] + pinyin[1:] + + assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, seg, raw_pinyin) + new_c, new_v = pinyin_to_symbol_map[pinyin].split(" ") + new_v = new_v + tone + phone = [new_c, new_v] + word2ph.append(len(phone)) + + phones_list += phone + return phones_list, word2ph + + +def replace_punctuation_with_en(text): + text = text.replace("嗯", "恩").replace("呣", "母") + pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys())) + + replaced_text = pattern.sub(lambda x: rep_map[x.group()], text) + + replaced_text = re.sub(r"[^\u4e00-\u9fa5A-Za-z" + "".join(punctuation) + r"]+", "", replaced_text) + + return replaced_text + + +def replace_consecutive_punctuation(text): + punctuations = "".join(re.escape(p) for p in punctuation) + pattern = f"([{punctuations}])([{punctuations}])+" + result = re.sub(pattern, r"\1", text) + return result + + +def text_normalize(text): + # https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/paddlespeech/t2s/frontend/zh_normalization + tx = TextNormalizer() + sentences = tx.normalize(text) + dest_text = "" + for sentence in sentences: + dest_text += replace_punctuation(sentence) + + # 避免重复标点引起的参考泄露 + dest_text = replace_consecutive_punctuation(dest_text) + return dest_text + + +# 不排除英文的文本格式化 +def mix_text_normalize(text): + # https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/paddlespeech/t2s/frontend/zh_normalization + tx = TextNormalizer() + sentences = tx.normalize(text) + dest_text = "" + for sentence in sentences: + dest_text += replace_punctuation_with_en(sentence) + + # 避免重复标点引起的参考泄露 + dest_text = replace_consecutive_punctuation(dest_text) + return dest_text + + +if __name__ == "__main__": + text = "啊——但是《原神》是由,米哈\游自主,研发的一款全.新开放世界.冒险游戏" + text = "呣呣呣~就是…大人的鼹鼠党吧?" + text = "你好" + text = text_normalize(text) + print(g2p(text)) + + +# # 示例用法 +# text = "这是一个示例文本:,你好!这是一个测试..." +# print(g2p_paddle(text)) # 输出: 这是一个示例文本你好这是一个测试 diff --git a/text/cleaner.py b/text/cleaner.py new file mode 100644 index 0000000000000000000000000000000000000000..7ba8f376bd8e17c64907d26740cae0bfd90230e6 --- /dev/null +++ b/text/cleaner.py @@ -0,0 +1,94 @@ +from text import cleaned_text_to_sequence +import os +# if os.environ.get("version","v1")=="v1": +# from text import chinese +# from text.symbols import symbols +# else: +# from text import chinese2 as chinese +# from text.symbols2 import symbols + +from text import symbols as symbols_v1 +from text import symbols2 as symbols_v2 + +special = [ + # ("%", "zh", "SP"), + ("¥", "zh", "SP2"), + ("^", "zh", "SP3"), + # ('@', 'zh', "SP4")#不搞鬼畜了,和第二版保持一致吧 +] + + +def clean_text(text, language, version=None): + if version is None: + version = os.environ.get("version", "v2") + if version == "v1": + symbols = symbols_v1.symbols + language_module_map = {"zh": "chinese", "ja": "japanese", "en": "english"} + else: + symbols = symbols_v2.symbols + language_module_map = {"zh": "chinese2", "ja": "japanese", "en": "english", "ko": "korean", "yue": "cantonese"} + + if language not in language_module_map: + language = "en" + text = " " + for special_s, special_l, target_symbol in special: + if special_s in text and language == special_l: + return clean_special(text, language, special_s, target_symbol, version) + language_module = __import__("text." 
+ language_module_map[language], fromlist=[language_module_map[language]]) + if hasattr(language_module, "text_normalize"): + norm_text = language_module.text_normalize(text) + else: + norm_text = text + if language == "zh" or language == "yue": ########## + phones, word2ph = language_module.g2p(norm_text) + assert len(phones) == sum(word2ph) + assert len(norm_text) == len(word2ph) + elif language == "en": + phones = language_module.g2p(norm_text) + if len(phones) < 4: + phones = [","] + phones + word2ph = None + else: + phones = language_module.g2p(norm_text) + word2ph = None + phones = ["UNK" if ph not in symbols else ph for ph in phones] + return phones, word2ph, norm_text + + +def clean_special(text, language, special_s, target_symbol, version=None): + if version is None: + version = os.environ.get("version", "v2") + if version == "v1": + symbols = symbols_v1.symbols + language_module_map = {"zh": "chinese", "ja": "japanese", "en": "english"} + else: + symbols = symbols_v2.symbols + language_module_map = {"zh": "chinese2", "ja": "japanese", "en": "english", "ko": "korean", "yue": "cantonese"} + + """ + 特殊静音段sp符号处理 + """ + text = text.replace(special_s, ",") + language_module = __import__("text." + language_module_map[language], fromlist=[language_module_map[language]]) + norm_text = language_module.text_normalize(text) + phones = language_module.g2p(norm_text) + new_ph = [] + for ph in phones[0]: + assert ph in symbols + if ph == ",": + new_ph.append(target_symbol) + else: + new_ph.append(ph) + return new_ph, phones[1], norm_text + + +def text_to_sequence(text, language, version=None): + version = os.environ.get("version", version) + if version is None: + version = "v2" + phones = clean_text(text) + return cleaned_text_to_sequence(phones, version) + + +if __name__ == "__main__": + print(clean_text("你好%啊啊啊额、还是到付红四方。", "zh")) diff --git a/text/cmudict-fast.rep b/text/cmudict-fast.rep new file mode 100644 index 0000000000000000000000000000000000000000..b975207d56b9b9c9578d17a190533cc257a594cf --- /dev/null +++ b/text/cmudict-fast.rep @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:53bfef0f27d7dd74d1ba74563d1e076d3e0672ce3596cb2d6c0d52ac9ad01f6d +size 3613898 diff --git a/text/cmudict.rep b/text/cmudict.rep new file mode 100644 index 0000000000000000000000000000000000000000..8a910783573b2f321bbfa6f93721fa5deec23b30 --- /dev/null +++ b/text/cmudict.rep @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0e601d017d6e6f958443d41cd8922b4cd7598b3ba2056253a33f3e5a35f38494 +size 3731285 diff --git a/text/en_normalization/expend.py b/text/en_normalization/expend.py new file mode 100644 index 0000000000000000000000000000000000000000..bbd607cd261d0f5ed1ec7f8bec078321e90f4955 --- /dev/null +++ b/text/en_normalization/expend.py @@ -0,0 +1,283 @@ +# by https://github.com/Cosmo-klara + +from __future__ import print_function + +import re +import inflect +import unicodedata + +# 后缀计量单位替换表 +measurement_map = { + "m": ["meter", "meters"], + "km": ["kilometer", "kilometers"], + "km/h": ["kilometer per hour", "kilometers per hour"], + "ft": ["feet", "feet"], + "L": ["liter", "liters"], + "tbsp": ["tablespoon", "tablespoons"], + "tsp": ["teaspoon", "teaspoons"], + "h": ["hour", "hours"], + "min": ["minute", "minutes"], + "s": ["second", "seconds"], + "°C": ["degree celsius", "degrees celsius"], + "°F": ["degree fahrenheit", "degrees fahrenheit"], +} + + +# 识别 12,000 类型 +_inflect = inflect.engine() + +# 转化数字序数词 +_ordinal_number_re = re.compile(r"\b([0-9]+)\. 
") + +# 我听说好像对于数字正则识别其实用 \d 会好一点 + +_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") + +# 时间识别 +_time_re = re.compile(r"\b([01]?[0-9]|2[0-3]):([0-5][0-9])\b") + +# 后缀计量单位识别 +_measurement_re = re.compile(r"\b([0-9]+(\.[0-9]+)?(m|km|km/h|ft|L|tbsp|tsp|h|min|s|°C|°F))\b") + +# 前后 £ 识别 ( 写了识别两边某一边的,但是不知道为什么失败了┭┮﹏┭┮ ) +_pounds_re_start = re.compile(r"£([0-9\.\,]*[0-9]+)") +_pounds_re_end = re.compile(r"([0-9\.\,]*[0-9]+)£") + +# 前后 $ 识别 +_dollars_re_start = re.compile(r"\$([0-9\.\,]*[0-9]+)") +_dollars_re_end = re.compile(r"([(0-9\.\,]*[0-9]+)\$") + +# 小数的识别 +_decimal_number_re = re.compile(r"([0-9]+\.\s*[0-9]+)") + +# 分数识别 (形式 "3/4" ) +_fraction_re = re.compile(r"([0-9]+/[0-9]+)") + +# 序数词识别 +_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") + +# 数字处理 +_number_re = re.compile(r"[0-9]+") + + +def _convert_ordinal(m): + """ + 标准化序数词, 例如: 1. 2. 3. 4. 5. 6. + Examples: + input: "1. " + output: "1st" + 然后在后面的 _expand_ordinal, 将其转化为 first 这类的 + """ + ordinal = _inflect.ordinal(m.group(1)) + return ordinal + ", " + + +def _remove_commas(m): + return m.group(1).replace(",", "") + + +def _expand_time(m): + """ + 将 24 小时制的时间转换为 12 小时制的时间表示方式。 + + Examples: + input: "13:00 / 4:00 / 13:30" + output: "one o'clock p.m. / four o'clock am. / one thirty p.m." + """ + hours, minutes = map(int, m.group(1, 2)) + period = "a.m." if hours < 12 else "p.m." + if hours > 12: + hours -= 12 + + hour_word = _inflect.number_to_words(hours) + minute_word = _inflect.number_to_words(minutes) if minutes != 0 else "" + + if minutes == 0: + return f"{hour_word} o'clock {period}" + else: + return f"{hour_word} {minute_word} {period}" + + +def _expand_measurement(m): + """ + 处理一些常见的测量单位后缀, 目前支持: m, km, km/h, ft, L, tbsp, tsp, h, min, s, °C, °F + 如果要拓展的话修改: _measurement_re 和 measurement_map + """ + sign = m.group(3) + ptr = 1 + # 想不到怎么方便的取数字,又懒得改正则,诶,1.2 反正也是复数读法,干脆直接去掉 "." 
+ num = int(m.group(1).replace(sign, "").replace(".", "")) + decimal_part = m.group(2) + # 上面判断的漏洞,比如 0.1 的情况,在这里排除了 + if decimal_part == None and num == 1: + ptr = 0 + return m.group(1).replace(sign, " " + measurement_map[sign][ptr]) + + +def _expand_pounds(m): + """ + 没找到特别规范的说明,和美元的处理一样,其实可以把两个合并在一起 + """ + match = m.group(1) + parts = match.split(".") + if len(parts) > 2: + return match + " pounds" # Unexpected format + pounds = int(parts[0]) if parts[0] else 0 + pence = int(parts[1].ljust(2, "0")) if len(parts) > 1 and parts[1] else 0 + if pounds and pence: + pound_unit = "pound" if pounds == 1 else "pounds" + penny_unit = "penny" if pence == 1 else "pence" + return "%s %s and %s %s" % (pounds, pound_unit, pence, penny_unit) + elif pounds: + pound_unit = "pound" if pounds == 1 else "pounds" + return "%s %s" % (pounds, pound_unit) + elif pence: + penny_unit = "penny" if pence == 1 else "pence" + return "%s %s" % (pence, penny_unit) + else: + return "zero pounds" + + +def _expand_dollars(m): + """ + change: 美分是 100 的限值, 应该要做补零的吧 + Example: + input: "32.3$ / $6.24" + output: "thirty-two dollars and thirty cents" / "six dollars and twenty-four cents" + """ + match = m.group(1) + parts = match.split(".") + if len(parts) > 2: + return match + " dollars" # Unexpected format + dollars = int(parts[0]) if parts[0] else 0 + cents = int(parts[1].ljust(2, "0")) if len(parts) > 1 and parts[1] else 0 + if dollars and cents: + dollar_unit = "dollar" if dollars == 1 else "dollars" + cent_unit = "cent" if cents == 1 else "cents" + return "%s %s and %s %s" % (dollars, dollar_unit, cents, cent_unit) + elif dollars: + dollar_unit = "dollar" if dollars == 1 else "dollars" + return "%s %s" % (dollars, dollar_unit) + elif cents: + cent_unit = "cent" if cents == 1 else "cents" + return "%s %s" % (cents, cent_unit) + else: + return "zero dollars" + + +# 小数的处理 +def _expand_decimal_number(m): + """ + Example: + input: "13.234" + output: "thirteen point two three four" + """ + match = m.group(1) + parts = match.split(".") + words = [] + # 遍历字符串中的每个字符 + for char in parts[1]: + if char == ".": + words.append("point") + else: + words.append(char) + return parts[0] + " point " + " ".join(words) + + +# 分数的处理 +def _expend_fraction(m): + """ + 规则1: 分子使用基数词读法, 分母用序数词读法. + 规则2: 如果分子大于 1, 在读分母的时候使用序数词复数读法. + 规则3: 当分母为2的时候, 分母读做 half, 并且当分子大于 1 的时候, half 也要用复数读法, 读为 halves. 
+ Examples: + + | Written | Said | + |:---:|:---:| + | 1/3 | one third | + | 3/4 | three fourths | + | 5/6 | five sixths | + | 1/2 | one half | + | 3/2 | three halves | + """ + match = m.group(0) + numerator, denominator = map(int, match.split("/")) + + numerator_part = _inflect.number_to_words(numerator) + if denominator == 2: + if numerator == 1: + denominator_part = "half" + else: + denominator_part = "halves" + elif denominator == 1: + return f"{numerator_part}" + else: + denominator_part = _inflect.ordinal(_inflect.number_to_words(denominator)) + if numerator > 1: + denominator_part += "s" + + return f"{numerator_part} {denominator_part}" + + +def _expand_ordinal(m): + return _inflect.number_to_words(m.group(0)) + + +def _expand_number(m): + num = int(m.group(0)) + if num > 1000 and num < 3000: + if num == 2000: + return "two thousand" + elif num > 2000 and num < 2010: + return "two thousand " + _inflect.number_to_words(num % 100) + elif num % 100 == 0: + return _inflect.number_to_words(num // 100) + " hundred" + else: + return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ") + else: + return _inflect.number_to_words(num, andword="") + + +def normalize(text): + """ + !!! 所有的处理都需要正确的输入 !!! + 可以添加新的处理,只需要添加正则表达式和对应的处理函数即可 + """ + + text = re.sub(_ordinal_number_re, _convert_ordinal, text) + text = re.sub(r"(?= start_line: + line = line.strip() + word_split = line.split(" ") + word = word_split[0].lower() + + syllable_split = word_split[1].split(" - ") + g2p_dict[word] = [] + for syllable in syllable_split: + phone_split = syllable.split(" ") + g2p_dict[word].append(phone_split) + + line_index = line_index + 1 + line = f.readline() + + return g2p_dict + + +def read_dict_new(): + g2p_dict = {} + with open(CMU_DICT_PATH) as f: + line = f.readline() + line_index = 1 + while line: + if line_index >= 57: + line = line.strip() + word_split = line.split(" ") + word = word_split[0].lower() + g2p_dict[word] = [word_split[1].split(" ")] + + line_index = line_index + 1 + line = f.readline() + + with open(CMU_DICT_FAST_PATH) as f: + line = f.readline() + line_index = 1 + while line: + if line_index >= 0: + line = line.strip() + word_split = line.split(" ") + word = word_split[0].lower() + if word not in g2p_dict: + g2p_dict[word] = [word_split[1:]] + + line_index = line_index + 1 + line = f.readline() + + return g2p_dict + + +def hot_reload_hot(g2p_dict): + with open(CMU_DICT_HOT_PATH) as f: + line = f.readline() + line_index = 1 + while line: + if line_index >= 0: + line = line.strip() + word_split = line.split(" ") + word = word_split[0].lower() + # 自定义发音词直接覆盖字典 + g2p_dict[word] = [word_split[1:]] + + line_index = line_index + 1 + line = f.readline() + + return g2p_dict + + +def cache_dict(g2p_dict, file_path): + with open(file_path, "wb") as pickle_file: + pickle.dump(g2p_dict, pickle_file) + + +def get_dict(): + if os.path.exists(CACHE_PATH): + with open(CACHE_PATH, "rb") as pickle_file: + g2p_dict = pickle.load(pickle_file) + else: + g2p_dict = read_dict_new() + cache_dict(g2p_dict, CACHE_PATH) + + g2p_dict = hot_reload_hot(g2p_dict) + + return g2p_dict + + +def get_namedict(): + if os.path.exists(NAMECACHE_PATH): + with open(NAMECACHE_PATH, "rb") as pickle_file: + name_dict = pickle.load(pickle_file) + else: + name_dict = {} + + return name_dict + + +def text_normalize(text): + # todo: eng text normalize + + # 效果相同,和 chinese.py 保持一致 + pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys())) + text = pattern.sub(lambda x: rep_map[x.group()], text) + + 
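+    # After punctuation is mapped through rep_map, normalize() from
+    # en_normalization/expend.py expands ordinals, times, currency, measurements,
+    # fractions and bare numbers, e.g. "13:30" -> "one thirty p.m.",
+    # "3/4" -> "three fourths", "$6.24" -> "six dollars and twenty-four cents",
+    # before consecutive punctuation is collapsed.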
text = unicode(text) + text = normalize(text) + + # 避免重复标点引起的参考泄露 + text = replace_consecutive_punctuation(text) + return text + + +class en_G2p(G2p): + def __init__(self): + super().__init__() + # 分词初始化 + wordsegment.load() + + # 扩展过时字典, 添加姓名字典 + self.cmu = get_dict() + self.namedict = get_namedict() + + # 剔除读音错误的几个缩写 + for word in ["AE", "AI", "AR", "IOS", "HUD", "OS"]: + del self.cmu[word.lower()] + + # 修正多音字 + self.homograph2features["read"] = (["R", "IY1", "D"], ["R", "EH1", "D"], "VBP") + self.homograph2features["complex"] = ( + ["K", "AH0", "M", "P", "L", "EH1", "K", "S"], + ["K", "AA1", "M", "P", "L", "EH0", "K", "S"], + "JJ", + ) + + def __call__(self, text): + # tokenization + words = word_tokenize(text) + tokens = pos_tag(words) # tuples of (word, tag) + + # steps + prons = [] + for o_word, pos in tokens: + # 还原 g2p_en 小写操作逻辑 + word = o_word.lower() + + if re.search("[a-z]", word) is None: + pron = [word] + # 先把单字母推出去 + elif len(word) == 1: + # 单读 A 发音修正, 这里需要原格式 o_word 判断大写 + if o_word == "A": + pron = ["EY1"] + else: + pron = self.cmu[word][0] + # g2p_en 原版多音字处理 + elif word in self.homograph2features: # Check homograph + pron1, pron2, pos1 = self.homograph2features[word] + if pos.startswith(pos1): + pron = pron1 + # pos1比pos长仅出现在read + elif len(pos) < len(pos1) and pos == pos1[: len(pos)]: + pron = pron1 + else: + pron = pron2 + else: + # 递归查找预测 + pron = self.qryword(o_word) + + prons.extend(pron) + prons.extend([" "]) + + return prons[:-1] + + def qryword(self, o_word): + word = o_word.lower() + + # 查字典, 单字母除外 + if len(word) > 1 and word in self.cmu: # lookup CMU dict + return self.cmu[word][0] + + # 单词仅首字母大写时查找姓名字典 + if o_word.istitle() and word in self.namedict: + return self.namedict[word][0] + + # oov 长度小于等于 3 直接读字母 + if len(word) <= 3: + phones = [] + for w in word: + # 单读 A 发音修正, 此处不存在大写的情况 + if w == "a": + phones.extend(["EY1"]) + elif not w.isalpha(): + phones.extend([w]) + else: + phones.extend(self.cmu[w][0]) + return phones + + # 尝试分离所有格 + if re.match(r"^([a-z]+)('s)$", word): + phones = self.qryword(word[:-2])[:] + # P T K F TH HH 无声辅音结尾 's 发 ['S'] + if phones[-1] in ["P", "T", "K", "F", "TH", "HH"]: + phones.extend(["S"]) + # S Z SH ZH CH JH 擦声结尾 's 发 ['IH1', 'Z'] 或 ['AH0', 'Z'] + elif phones[-1] in ["S", "Z", "SH", "ZH", "CH", "JH"]: + phones.extend(["AH0", "Z"]) + # B D G DH V M N NG L R W Y 有声辅音结尾 's 发 ['Z'] + # AH0 AH1 AH2 EY0 EY1 EY2 AE0 AE1 AE2 EH0 EH1 EH2 OW0 OW1 OW2 UH0 UH1 UH2 IY0 IY1 IY2 AA0 AA1 AA2 AO0 AO1 AO2 + # ER ER0 ER1 ER2 UW0 UW1 UW2 AY0 AY1 AY2 AW0 AW1 AW2 OY0 OY1 OY2 IH IH0 IH1 IH2 元音结尾 's 发 ['Z'] + else: + phones.extend(["Z"]) + return phones + + # 尝试进行分词,应对复合词 + comps = wordsegment.segment(word.lower()) + + # 无法分词的送回去预测 + if len(comps) == 1: + return self.predict(word) + + # 可以分词的递归处理 + return [phone for comp in comps for phone in self.qryword(comp)] + + +_g2p = en_G2p() + + +def g2p(text): + # g2p_en 整段推理,剔除不存在的arpa返回 + phone_list = _g2p(text) + phones = [ph if ph != "" else "UNK" for ph in phone_list if ph not in [" ", "", "UW", "", ""]] + + return replace_phs(phones) + + +if __name__ == "__main__": + print(g2p("hello")) + print(g2p(text_normalize("e.g. 
I used openai's AI tool to draw a picture."))) + print(g2p(text_normalize("In this; paper, we propose 1 DSPGAN, a GAN-based universal vocoder."))) diff --git a/text/g2pw/__init__.py b/text/g2pw/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5ab811d985aa194bfff021c85d25e8130c8e0eb9 --- /dev/null +++ b/text/g2pw/__init__.py @@ -0,0 +1 @@ +from text.g2pw.g2pw import * diff --git a/text/g2pw/dataset.py b/text/g2pw/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..ff09cbc25161c57d671f271a97ee6a7dc57bf154 --- /dev/null +++ b/text/g2pw/dataset.py @@ -0,0 +1,160 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Credits + This code is modified from https://github.com/GitYCC/g2pW +""" + +from typing import Dict +from typing import List +from typing import Tuple + +import numpy as np + +from .utils import tokenize_and_map + +ANCHOR_CHAR = "▁" + + +def prepare_onnx_input( + tokenizer, + labels: List[str], + char2phonemes: Dict[str, List[int]], + chars: List[str], + texts: List[str], + query_ids: List[int], + use_mask: bool = False, + window_size: int = None, + max_len: int = 512, +) -> Dict[str, np.array]: + if window_size is not None: + truncated_texts, truncated_query_ids = _truncate_texts( + window_size=window_size, texts=texts, query_ids=query_ids + ) + input_ids = [] + token_type_ids = [] + attention_masks = [] + phoneme_masks = [] + char_ids = [] + position_ids = [] + + for idx in range(len(texts)): + text = (truncated_texts if window_size else texts)[idx].lower() + query_id = (truncated_query_ids if window_size else query_ids)[idx] + + try: + tokens, text2token, token2text = tokenize_and_map(tokenizer=tokenizer, text=text) + except Exception: + print(f'warning: text "{text}" is invalid') + return {} + + text, query_id, tokens, text2token, token2text = _truncate( + max_len=max_len, text=text, query_id=query_id, tokens=tokens, text2token=text2token, token2text=token2text + ) + + processed_tokens = ["[CLS]"] + tokens + ["[SEP]"] + + input_id = list(np.array(tokenizer.convert_tokens_to_ids(processed_tokens))) + token_type_id = list(np.zeros((len(processed_tokens),), dtype=int)) + attention_mask = list(np.ones((len(processed_tokens),), dtype=int)) + + query_char = text[query_id] + phoneme_mask = ( + [1 if i in char2phonemes[query_char] else 0 for i in range(len(labels))] if use_mask else [1] * len(labels) + ) + char_id = chars.index(query_char) + position_id = text2token[query_id] + 1 # [CLS] token locate at first place + + input_ids.append(input_id) + token_type_ids.append(token_type_id) + attention_masks.append(attention_mask) + phoneme_masks.append(phoneme_mask) + char_ids.append(char_id) + position_ids.append(position_id) + + outputs = { + "input_ids": np.array(input_ids).astype(np.int64), + "token_type_ids": np.array(token_type_ids).astype(np.int64), + "attention_masks": np.array(attention_masks).astype(np.int64), + "phoneme_masks": 
np.array(phoneme_masks).astype(np.float32), + "char_ids": np.array(char_ids).astype(np.int64), + "position_ids": np.array(position_ids).astype(np.int64), + } + return outputs + + +def _truncate_texts(window_size: int, texts: List[str], query_ids: List[int]) -> Tuple[List[str], List[int]]: + truncated_texts = [] + truncated_query_ids = [] + for text, query_id in zip(texts, query_ids): + start = max(0, query_id - window_size // 2) + end = min(len(text), query_id + window_size // 2) + truncated_text = text[start:end] + truncated_texts.append(truncated_text) + + truncated_query_id = query_id - start + truncated_query_ids.append(truncated_query_id) + return truncated_texts, truncated_query_ids + + +def _truncate( + max_len: int, text: str, query_id: int, tokens: List[str], text2token: List[int], token2text: List[Tuple[int]] +): + truncate_len = max_len - 2 + if len(tokens) <= truncate_len: + return (text, query_id, tokens, text2token, token2text) + + token_position = text2token[query_id] + + token_start = token_position - truncate_len // 2 + token_end = token_start + truncate_len + font_exceed_dist = -token_start + back_exceed_dist = token_end - len(tokens) + if font_exceed_dist > 0: + token_start += font_exceed_dist + token_end += font_exceed_dist + elif back_exceed_dist > 0: + token_start -= back_exceed_dist + token_end -= back_exceed_dist + + start = token2text[token_start][0] + end = token2text[token_end - 1][1] + + return ( + text[start:end], + query_id - start, + tokens[token_start:token_end], + [i - token_start if i is not None else None for i in text2token[start:end]], + [(s - start, e - start) for s, e in token2text[token_start:token_end]], + ) + + +def get_phoneme_labels(polyphonic_chars: List[List[str]]) -> Tuple[List[str], Dict[str, List[int]]]: + labels = sorted(list(set([phoneme for char, phoneme in polyphonic_chars]))) + char2phonemes = {} + for char, phoneme in polyphonic_chars: + if char not in char2phonemes: + char2phonemes[char] = [] + char2phonemes[char].append(labels.index(phoneme)) + return labels, char2phonemes + + +def get_char_phoneme_labels(polyphonic_chars: List[List[str]]) -> Tuple[List[str], Dict[str, List[int]]]: + labels = sorted(list(set([f"{char} {phoneme}" for char, phoneme in polyphonic_chars]))) + char2phonemes = {} + for char, phoneme in polyphonic_chars: + if char not in char2phonemes: + char2phonemes[char] = [] + char2phonemes[char].append(labels.index(f"{char} {phoneme}")) + return labels, char2phonemes diff --git a/text/g2pw/g2pw.py b/text/g2pw/g2pw.py new file mode 100644 index 0000000000000000000000000000000000000000..08525e91f5409e9ed044908899146c9f9cb67520 --- /dev/null +++ b/text/g2pw/g2pw.py @@ -0,0 +1,159 @@ +# This code is modified from https://github.com/mozillazg/pypinyin-g2pW + +import pickle +import os + +from pypinyin.constants import RE_HANS +from pypinyin.core import Pinyin, Style +from pypinyin.seg.simpleseg import simple_seg +from pypinyin.converter import UltimateConverter +from pypinyin.contrib.tone_convert import to_tone +from .onnx_api import G2PWOnnxConverter + +current_file_path = os.path.dirname(__file__) +CACHE_PATH = os.path.join(current_file_path, "polyphonic.pickle") +PP_DICT_PATH = os.path.join(current_file_path, "polyphonic.rep") +PP_FIX_DICT_PATH = os.path.join(current_file_path, "polyphonic-fix.rep") + + +class G2PWPinyin(Pinyin): + def __init__( + self, + model_dir="G2PWModel/", + model_source=None, + enable_non_tradional_chinese=True, + v_to_u=False, + neutral_tone_with_five=False, + tone_sandhi=False, + **kwargs, + ): + 
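+        # G2PWPinyin replaces pypinyin's default converter with one backed by the
+        # g2pW ONNX model: the model disambiguates polyphonic characters with
+        # whole-sentence inference, while Converter._to_pinyin below falls back
+        # to plain pypinyin for characters g2pW does not cover.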
self._g2pw = G2PWOnnxConverter( + model_dir=model_dir, + style="pinyin", + model_source=model_source, + enable_non_tradional_chinese=enable_non_tradional_chinese, + ) + self._converter = Converter( + self._g2pw, + v_to_u=v_to_u, + neutral_tone_with_five=neutral_tone_with_five, + tone_sandhi=tone_sandhi, + ) + + def get_seg(self, **kwargs): + return simple_seg + + +class Converter(UltimateConverter): + def __init__(self, g2pw_instance, v_to_u=False, neutral_tone_with_five=False, tone_sandhi=False, **kwargs): + super(Converter, self).__init__( + v_to_u=v_to_u, neutral_tone_with_five=neutral_tone_with_five, tone_sandhi=tone_sandhi, **kwargs + ) + + self._g2pw = g2pw_instance + + def convert(self, words, style, heteronym, errors, strict, **kwargs): + pys = [] + if RE_HANS.match(words): + pys = self._to_pinyin(words, style=style, heteronym=heteronym, errors=errors, strict=strict) + post_data = self.post_pinyin(words, heteronym, pys) + if post_data is not None: + pys = post_data + + pys = self.convert_styles(pys, words, style, heteronym, errors, strict) + + else: + py = self.handle_nopinyin(words, style=style, errors=errors, heteronym=heteronym, strict=strict) + if py: + pys.extend(py) + + return _remove_dup_and_empty(pys) + + def _to_pinyin(self, han, style, heteronym, errors, strict, **kwargs): + pinyins = [] + + g2pw_pinyin = self._g2pw(han) + + if not g2pw_pinyin: # g2pw 不支持的汉字改为使用 pypinyin 原有逻辑 + return super(Converter, self).convert(han, Style.TONE, heteronym, errors, strict, **kwargs) + + for i, item in enumerate(g2pw_pinyin[0]): + if item is None: # g2pw 不支持的汉字改为使用 pypinyin 原有逻辑 + py = super(Converter, self).convert(han[i], Style.TONE, heteronym, errors, strict, **kwargs) + pinyins.extend(py) + else: + pinyins.append([to_tone(item)]) + + return pinyins + + +def _remove_dup_items(lst, remove_empty=False): + new_lst = [] + for item in lst: + if remove_empty and not item: + continue + if item not in new_lst: + new_lst.append(item) + return new_lst + + +def _remove_dup_and_empty(lst_list): + new_lst_list = [] + for lst in lst_list: + lst = _remove_dup_items(lst, remove_empty=True) + if lst: + new_lst_list.append(lst) + else: + new_lst_list.append([""]) + + return new_lst_list + + +def cache_dict(polyphonic_dict, file_path): + with open(file_path, "wb") as pickle_file: + pickle.dump(polyphonic_dict, pickle_file) + + +def get_dict(): + if os.path.exists(CACHE_PATH): + with open(CACHE_PATH, "rb") as pickle_file: + polyphonic_dict = pickle.load(pickle_file) + else: + polyphonic_dict = read_dict() + cache_dict(polyphonic_dict, CACHE_PATH) + + return polyphonic_dict + + +def read_dict(): + polyphonic_dict = {} + with open(PP_DICT_PATH, encoding="utf-8") as f: + line = f.readline() + while line: + key, value_str = line.split(":") + value = eval(value_str.strip()) + polyphonic_dict[key.strip()] = value + line = f.readline() + with open(PP_FIX_DICT_PATH, encoding="utf-8") as f: + line = f.readline() + while line: + key, value_str = line.split(":") + value = eval(value_str.strip()) + polyphonic_dict[key.strip()] = value + line = f.readline() + return polyphonic_dict + + +def correct_pronunciation(word, word_pinyins): + new_pinyins = pp_dict.get(word, "") + if new_pinyins == "": + for idx, w in enumerate(word): + w_pinyin = pp_dict.get(w, "") + if w_pinyin != "": + word_pinyins[idx] = w_pinyin[0] + return word_pinyins + else: + return new_pinyins + + +pp_dict = get_dict() diff --git a/text/g2pw/onnx_api.py b/text/g2pw/onnx_api.py new file mode 100644 index 
0000000000000000000000000000000000000000..52eed4438fccb93f188fc61944071969a90dc40b --- /dev/null +++ b/text/g2pw/onnx_api.py @@ -0,0 +1,247 @@ +# This code is modified from https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/paddlespeech/t2s/frontend/g2pw +# This code is modified from https://github.com/GitYCC/g2pW + +import json +import os +import warnings +import zipfile +from typing import Any, Dict, List, Tuple + +import numpy as np +import onnxruntime +import requests +import torch +from opencc import OpenCC +from pypinyin import Style, pinyin +from transformers.models.auto.tokenization_auto import AutoTokenizer + +from ..zh_normalization.char_convert import tranditional_to_simplified +from .dataset import get_char_phoneme_labels, get_phoneme_labels, prepare_onnx_input +from .utils import load_config + +onnxruntime.set_default_logger_severity(3) +try: + onnxruntime.preload_dlls() +except: + pass + # traceback.print_exc() +warnings.filterwarnings("ignore") + +model_version = "1.1" + + +def predict(session, onnx_input: Dict[str, Any], labels: List[str]) -> Tuple[List[str], List[float]]: + all_preds = [] + all_confidences = [] + probs = session.run( + [], + { + "input_ids": onnx_input["input_ids"], + "token_type_ids": onnx_input["token_type_ids"], + "attention_mask": onnx_input["attention_masks"], + "phoneme_mask": onnx_input["phoneme_masks"], + "char_ids": onnx_input["char_ids"], + "position_ids": onnx_input["position_ids"], + }, + )[0] + + preds = np.argmax(probs, axis=1).tolist() + max_probs = [] + for index, arr in zip(preds, probs.tolist()): + max_probs.append(arr[index]) + all_preds += [labels[pred] for pred in preds] + all_confidences += max_probs + + return all_preds, all_confidences + + +def download_and_decompress(model_dir: str = "G2PWModel/"): + if not os.path.exists(model_dir): + parent_directory = os.path.dirname(model_dir) + zip_dir = os.path.join(parent_directory, "G2PWModel_1.1.zip") + extract_dir = os.path.join(parent_directory, "G2PWModel_1.1") + extract_dir_new = os.path.join(parent_directory, "G2PWModel") + print("Downloading g2pw model...") + modelscope_url = "https://www.modelscope.cn/models/kamiorinn/g2pw/resolve/master/G2PWModel_1.1.zip" # "https://paddlespeech.cdn.bcebos.com/Parakeet/released_models/g2p/G2PWModel_1.1.zip" + with requests.get(modelscope_url, stream=True) as r: + r.raise_for_status() + with open(zip_dir, "wb") as f: + for chunk in r.iter_content(chunk_size=8192): + if chunk: + f.write(chunk) + + print("Extracting g2pw model...") + with zipfile.ZipFile(zip_dir, "r") as zip_ref: + zip_ref.extractall(parent_directory) + + os.rename(extract_dir, extract_dir_new) + + return model_dir + + +class G2PWOnnxConverter: + def __init__( + self, + model_dir: str = "G2PWModel/", + style: str = "bopomofo", + model_source: str = None, + enable_non_tradional_chinese: bool = False, + ): + uncompress_path = download_and_decompress(model_dir) + + sess_options = onnxruntime.SessionOptions() + sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL + sess_options.intra_op_num_threads = 2 if torch.cuda.is_available() else 0 + try: + self.session_g2pW = onnxruntime.InferenceSession( + os.path.join(uncompress_path, "g2pW.onnx"), + sess_options=sess_options, + providers=["CUDAExecutionProvider", "CPUExecutionProvider"], + ) + except: + self.session_g2pW = onnxruntime.InferenceSession( + os.path.join(uncompress_path, "g2pW.onnx"), + sess_options=sess_options, + 
providers=["CPUExecutionProvider"], + ) + self.config = load_config(config_path=os.path.join(uncompress_path, "config.py"), use_default=True) + + self.model_source = model_source if model_source else self.config.model_source + self.enable_opencc = enable_non_tradional_chinese + + self.tokenizer = AutoTokenizer.from_pretrained(self.model_source) + + polyphonic_chars_path = os.path.join(uncompress_path, "POLYPHONIC_CHARS.txt") + monophonic_chars_path = os.path.join(uncompress_path, "MONOPHONIC_CHARS.txt") + self.polyphonic_chars = [ + line.split("\t") for line in open(polyphonic_chars_path, encoding="utf-8").read().strip().split("\n") + ] + self.non_polyphonic = { + "一", + "不", + "和", + "咋", + "嗲", + "剖", + "差", + "攢", + "倒", + "難", + "奔", + "勁", + "拗", + "肖", + "瘙", + "誒", + "泊", + "听", + "噢", + } + self.non_monophonic = {"似", "攢"} + self.monophonic_chars = [ + line.split("\t") for line in open(monophonic_chars_path, encoding="utf-8").read().strip().split("\n") + ] + self.labels, self.char2phonemes = ( + get_char_phoneme_labels(polyphonic_chars=self.polyphonic_chars) + if self.config.use_char_phoneme + else get_phoneme_labels(polyphonic_chars=self.polyphonic_chars) + ) + + self.chars = sorted(list(self.char2phonemes.keys())) + + self.polyphonic_chars_new = set(self.chars) + for char in self.non_polyphonic: + if char in self.polyphonic_chars_new: + self.polyphonic_chars_new.remove(char) + + self.monophonic_chars_dict = {char: phoneme for char, phoneme in self.monophonic_chars} + for char in self.non_monophonic: + if char in self.monophonic_chars_dict: + self.monophonic_chars_dict.pop(char) + + self.pos_tags = ["UNK", "A", "C", "D", "I", "N", "P", "T", "V", "DE", "SHI"] + + with open(os.path.join(uncompress_path, "bopomofo_to_pinyin_wo_tune_dict.json"), "r", encoding="utf-8") as fr: + self.bopomofo_convert_dict = json.load(fr) + self.style_convert_func = { + "bopomofo": lambda x: x, + "pinyin": self._convert_bopomofo_to_pinyin, + }[style] + + with open(os.path.join(uncompress_path, "char_bopomofo_dict.json"), "r", encoding="utf-8") as fr: + self.char_bopomofo_dict = json.load(fr) + + if self.enable_opencc: + self.cc = OpenCC("s2tw") + + def _convert_bopomofo_to_pinyin(self, bopomofo: str) -> str: + tone = bopomofo[-1] + assert tone in "12345" + component = self.bopomofo_convert_dict.get(bopomofo[:-1]) + if component: + return component + tone + else: + print(f'Warning: "{bopomofo}" cannot convert to pinyin') + return None + + def __call__(self, sentences: List[str]) -> List[List[str]]: + if isinstance(sentences, str): + sentences = [sentences] + + if self.enable_opencc: + translated_sentences = [] + for sent in sentences: + translated_sent = self.cc.convert(sent) + assert len(translated_sent) == len(sent) + translated_sentences.append(translated_sent) + sentences = translated_sentences + + texts, query_ids, sent_ids, partial_results = self._prepare_data(sentences=sentences) + if len(texts) == 0: + # sentences no polyphonic words + return partial_results + + onnx_input = prepare_onnx_input( + tokenizer=self.tokenizer, + labels=self.labels, + char2phonemes=self.char2phonemes, + chars=self.chars, + texts=texts, + query_ids=query_ids, + use_mask=self.config.use_mask, + window_size=None, + ) + + preds, confidences = predict(session=self.session_g2pW, onnx_input=onnx_input, labels=self.labels) + if self.config.use_char_phoneme: + preds = [pred.split(" ")[1] for pred in preds] + + results = partial_results + for sent_id, query_id, pred in zip(sent_ids, query_ids, preds): + results[sent_id][query_id] 
= self.style_convert_func(pred) + + return results + + def _prepare_data(self, sentences: List[str]) -> Tuple[List[str], List[int], List[int], List[List[str]]]: + texts, query_ids, sent_ids, partial_results = [], [], [], [] + for sent_id, sent in enumerate(sentences): + # pypinyin works well for Simplified Chinese than Traditional Chinese + sent_s = tranditional_to_simplified(sent) + pypinyin_result = pinyin(sent_s, neutral_tone_with_five=True, style=Style.TONE3) + partial_result = [None] * len(sent) + for i, char in enumerate(sent): + if char in self.polyphonic_chars_new: + texts.append(sent) + query_ids.append(i) + sent_ids.append(sent_id) + elif char in self.monophonic_chars_dict: + partial_result[i] = self.style_convert_func(self.monophonic_chars_dict[char]) + elif char in self.char_bopomofo_dict: + partial_result[i] = pypinyin_result[i][0] + # partial_result[i] = self.style_convert_func(self.char_bopomofo_dict[char][0]) + else: + partial_result[i] = pypinyin_result[i][0] + + partial_results.append(partial_result) + return texts, query_ids, sent_ids, partial_results diff --git a/text/g2pw/polyphonic-fix.rep b/text/g2pw/polyphonic-fix.rep new file mode 100644 index 0000000000000000000000000000000000000000..4566058647b2fd0bb82b840339251a927836a4fb --- /dev/null +++ b/text/g2pw/polyphonic-fix.rep @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6444b2ad4a1070dad9b16c7e47271910a69349ff079e8d8e236c8818209b65f4 +size 1660953 diff --git a/text/g2pw/polyphonic.pickle b/text/g2pw/polyphonic.pickle new file mode 100644 index 0000000000000000000000000000000000000000..0a8912d1b2cb3016eac46dc55ec6e10a7b442f7b --- /dev/null +++ b/text/g2pw/polyphonic.pickle @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f425246160a32c578557cd3151cd0bb97f5f44c3aaf65e718dd2c3213c04fb4b +size 1322387 diff --git a/text/g2pw/polyphonic.rep b/text/g2pw/polyphonic.rep new file mode 100644 index 0000000000000000000000000000000000000000..312871801de4ac858a58df7d02c4b4c7bd175cae --- /dev/null +++ b/text/g2pw/polyphonic.rep @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a8f13e2004702cbbf01e2e3fe8436d9d620a0fe4d57176a844704619ffd5df5e +size 1391 diff --git a/text/g2pw/utils.py b/text/g2pw/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a86b2bc0421a0a4cf9fc57edb1fb3a4f3ae4a9e1 --- /dev/null +++ b/text/g2pw/utils.py @@ -0,0 +1,143 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
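+# wordize_and_map / tokenize_and_map below build index maps between the raw
+# text, its words and the BERT tokenizer's tokens (including "##" subword
+# pieces), so prepare_onnx_input in dataset.py can locate the queried
+# polyphonic character again after tokenization and truncation.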
+""" +Credits + This code is modified from https://github.com/GitYCC/g2pW +""" + +import os +import re + + +def wordize_and_map(text: str): + words = [] + index_map_from_text_to_word = [] + index_map_from_word_to_text = [] + while len(text) > 0: + match_space = re.match(r"^ +", text) + if match_space: + space_str = match_space.group(0) + index_map_from_text_to_word += [None] * len(space_str) + text = text[len(space_str) :] + continue + + match_en = re.match(r"^[a-zA-Z0-9]+", text) + if match_en: + en_word = match_en.group(0) + + word_start_pos = len(index_map_from_text_to_word) + word_end_pos = word_start_pos + len(en_word) + index_map_from_word_to_text.append((word_start_pos, word_end_pos)) + + index_map_from_text_to_word += [len(words)] * len(en_word) + + words.append(en_word) + text = text[len(en_word) :] + else: + word_start_pos = len(index_map_from_text_to_word) + word_end_pos = word_start_pos + 1 + index_map_from_word_to_text.append((word_start_pos, word_end_pos)) + + index_map_from_text_to_word += [len(words)] + + words.append(text[0]) + text = text[1:] + return words, index_map_from_text_to_word, index_map_from_word_to_text + + +def tokenize_and_map(tokenizer, text: str): + words, text2word, word2text = wordize_and_map(text=text) + + tokens = [] + index_map_from_token_to_text = [] + for word, (word_start, word_end) in zip(words, word2text): + word_tokens = tokenizer.tokenize(word) + + if len(word_tokens) == 0 or word_tokens == ["[UNK]"]: + index_map_from_token_to_text.append((word_start, word_end)) + tokens.append("[UNK]") + else: + current_word_start = word_start + for word_token in word_tokens: + word_token_len = len(re.sub(r"^##", "", word_token)) + index_map_from_token_to_text.append((current_word_start, current_word_start + word_token_len)) + current_word_start = current_word_start + word_token_len + tokens.append(word_token) + + index_map_from_text_to_token = text2word + for i, (token_start, token_end) in enumerate(index_map_from_token_to_text): + for token_pos in range(token_start, token_end): + index_map_from_text_to_token[token_pos] = i + + return tokens, index_map_from_text_to_token, index_map_from_token_to_text + + +def _load_config(config_path: os.PathLike): + import importlib.util + + spec = importlib.util.spec_from_file_location("__init__", config_path) + config = importlib.util.module_from_spec(spec) + spec.loader.exec_module(config) + return config + + +default_config_dict = { + "manual_seed": 1313, + "model_source": "bert-base-chinese", + "window_size": 32, + "num_workers": 2, + "use_mask": True, + "use_char_phoneme": False, + "use_conditional": True, + "param_conditional": { + "affect_location": "softmax", + "bias": True, + "char-linear": True, + "pos-linear": False, + "char+pos-second": True, + "char+pos-second_lowrank": False, + "lowrank_size": 0, + "char+pos-second_fm": False, + "fm_size": 0, + "fix_mode": None, + "count_json": "train.count.json", + }, + "lr": 5e-5, + "val_interval": 200, + "num_iter": 10000, + "use_focal": False, + "param_focal": {"alpha": 0.0, "gamma": 0.7}, + "use_pos": True, + "param_pos ": { + "weight": 0.1, + "pos_joint_training": True, + "train_pos_path": "train.pos", + "valid_pos_path": "dev.pos", + "test_pos_path": "test.pos", + }, +} + + +def load_config(config_path: os.PathLike, use_default: bool = False): + config = _load_config(config_path) + if use_default: + for attr, val in default_config_dict.items(): + if not hasattr(config, attr): + setattr(config, attr, val) + elif isinstance(val, dict): + d = getattr(config, attr) + for 
dict_k, dict_v in val.items(): + if dict_k not in d: + d[dict_k] = dict_v + return config diff --git a/text/ja_userdic/userdict.csv b/text/ja_userdic/userdict.csv new file mode 100644 index 0000000000000000000000000000000000000000..d9aa2378c16a27eaa512a2a7642d7241bd050e8b --- /dev/null +++ b/text/ja_userdic/userdict.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d857e443ee48d9641096816a98996669602895411e4330d7d91d1dbe1103389f +size 17180971 diff --git a/text/japanese.py b/text/japanese.py new file mode 100644 index 0000000000000000000000000000000000000000..a54d0cf0bafed09b428fa9dbf0e942ec2ec87fd5 --- /dev/null +++ b/text/japanese.py @@ -0,0 +1,276 @@ +# modified from https://github.com/CjangCjengh/vits/blob/main/text/japanese.py +import re +import os +import hashlib + +try: + import pyopenjtalk + + current_file_path = os.path.dirname(__file__) + + # 防止win下无法读取模型 + if os.name == "nt": + python_dir = os.getcwd() + OPEN_JTALK_DICT_DIR = pyopenjtalk.OPEN_JTALK_DICT_DIR.decode("utf-8") + if not (re.match(r"^[A-Za-z0-9_/\\:.\-]*$", OPEN_JTALK_DICT_DIR)): + if OPEN_JTALK_DICT_DIR[: len(python_dir)].upper() == python_dir.upper(): + OPEN_JTALK_DICT_DIR = os.path.join(os.path.relpath(OPEN_JTALK_DICT_DIR, python_dir)) + else: + import shutil + + if not os.path.exists("TEMP"): + os.mkdir("TEMP") + if not os.path.exists(os.path.join("TEMP", "ja")): + os.mkdir(os.path.join("TEMP", "ja")) + if os.path.exists(os.path.join("TEMP", "ja", "open_jtalk_dic")): + shutil.rmtree(os.path.join("TEMP", "ja", "open_jtalk_dic")) + shutil.copytree( + pyopenjtalk.OPEN_JTALK_DICT_DIR.decode("utf-8"), + os.path.join("TEMP", "ja", "open_jtalk_dic"), + ) + OPEN_JTALK_DICT_DIR = os.path.join("TEMP", "ja", "open_jtalk_dic") + pyopenjtalk.OPEN_JTALK_DICT_DIR = OPEN_JTALK_DICT_DIR.encode("utf-8") + + if not (re.match(r"^[A-Za-z0-9_/\\:.\-]*$", current_file_path)): + if current_file_path[: len(python_dir)].upper() == python_dir.upper(): + current_file_path = os.path.join(os.path.relpath(current_file_path, python_dir)) + else: + if not os.path.exists("TEMP"): + os.mkdir("TEMP") + if not os.path.exists(os.path.join("TEMP", "ja")): + os.mkdir(os.path.join("TEMP", "ja")) + if not os.path.exists(os.path.join("TEMP", "ja", "ja_userdic")): + os.mkdir(os.path.join("TEMP", "ja", "ja_userdic")) + shutil.copyfile( + os.path.join(current_file_path, "ja_userdic", "userdict.csv"), + os.path.join("TEMP", "ja", "ja_userdic", "userdict.csv"), + ) + current_file_path = os.path.join("TEMP", "ja") + + def get_hash(fp: str) -> str: + hash_md5 = hashlib.md5() + with open(fp, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + hash_md5.update(chunk) + return hash_md5.hexdigest() + + USERDIC_CSV_PATH = os.path.join(current_file_path, "ja_userdic", "userdict.csv") + USERDIC_BIN_PATH = os.path.join(current_file_path, "ja_userdic", "user.dict") + USERDIC_HASH_PATH = os.path.join(current_file_path, "ja_userdic", "userdict.md5") + # 如果没有用户词典,就生成一个;如果有,就检查md5,如果不一样,就重新生成 + if os.path.exists(USERDIC_CSV_PATH): + if ( + not os.path.exists(USERDIC_BIN_PATH) + or get_hash(USERDIC_CSV_PATH) != open(USERDIC_HASH_PATH, "r", encoding="utf-8").read() + ): + pyopenjtalk.mecab_dict_index(USERDIC_CSV_PATH, USERDIC_BIN_PATH) + with open(USERDIC_HASH_PATH, "w", encoding="utf-8") as f: + f.write(get_hash(USERDIC_CSV_PATH)) + + if os.path.exists(USERDIC_BIN_PATH): + pyopenjtalk.update_global_jtalk_with_user_dict(USERDIC_BIN_PATH) +except Exception: + # print(e) + import pyopenjtalk + + # failed to load user dictionary, ignore. 
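+    # Any failure in the Windows path fix-ups or the MD5-gated rebuild of
+    # user.dict above is non-fatal: pyopenjtalk is simply used with its
+    # bundled Open JTalk dictionary instead of the user dictionary.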
+ pass + + +from text.symbols import punctuation + +# Regular expression matching Japanese without punctuation marks: +_japanese_characters = re.compile( + r"[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]" +) + +# Regular expression matching non-Japanese characters or punctuation marks: +_japanese_marks = re.compile( + r"[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]" +) + +# List of (symbol, Japanese) pairs for marks: +_symbols_to_japanese = [(re.compile("%s" % x[0]), x[1]) for x in [("%", "パーセント")]] + + +# List of (consonant, sokuon) pairs: +_real_sokuon = [ + (re.compile("%s" % x[0]), x[1]) + for x in [ + (r"Q([↑↓]*[kg])", r"k#\1"), + (r"Q([↑↓]*[tdjʧ])", r"t#\1"), + (r"Q([↑↓]*[sʃ])", r"s\1"), + (r"Q([↑↓]*[pb])", r"p#\1"), + ] +] + +# List of (consonant, hatsuon) pairs: +_real_hatsuon = [ + (re.compile("%s" % x[0]), x[1]) + for x in [ + (r"N([↑↓]*[pbm])", r"m\1"), + (r"N([↑↓]*[ʧʥj])", r"n^\1"), + (r"N([↑↓]*[tdn])", r"n\1"), + (r"N([↑↓]*[kg])", r"ŋ\1"), + ] +] + + +def post_replace_ph(ph): + rep_map = { + ":": ",", + ";": ",", + ",": ",", + "。": ".", + "!": "!", + "?": "?", + "\n": ".", + "·": ",", + "、": ",", + "...": "…", + } + + if ph in rep_map.keys(): + ph = rep_map[ph] + return ph + + +def replace_consecutive_punctuation(text): + punctuations = "".join(re.escape(p) for p in punctuation) + pattern = f"([{punctuations}])([{punctuations}])+" + result = re.sub(pattern, r"\1", text) + return result + + +def symbols_to_japanese(text): + for regex, replacement in _symbols_to_japanese: + text = re.sub(regex, replacement, text) + return text + + +def preprocess_jap(text, with_prosody=False): + """Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html""" + text = symbols_to_japanese(text) + # English words to lower case, should have no influence on japanese words. + text = text.lower() + sentences = re.split(_japanese_marks, text) + marks = re.findall(_japanese_marks, text) + text = [] + for i, sentence in enumerate(sentences): + if re.match(_japanese_characters, sentence): + if with_prosody: + text += pyopenjtalk_g2p_prosody(sentence)[1:-1] + else: + p = pyopenjtalk.g2p(sentence) + text += p.split(" ") + + if i < len(marks): + if marks[i] == " ": # 防止意外的UNK + continue + text += [marks[i].replace(" ", "")] + return text + + +def text_normalize(text): + # todo: jap text normalize + + # 避免重复标点引起的参考泄露 + text = replace_consecutive_punctuation(text) + return text + + +# Copied from espnet https://github.com/espnet/espnet/blob/master/espnet2/text/phoneme_tokenizer.py +def pyopenjtalk_g2p_prosody(text, drop_unvoiced_vowels=True): + """Extract phoneme + prosoody symbol sequence from input full-context labels. + + The algorithm is based on `Prosodic features control by symbols as input of + sequence-to-sequence acoustic modeling for neural TTS`_ with some r9y9's tweaks. + + Args: + text (str): Input text. + drop_unvoiced_vowels (bool): whether to drop unvoiced vowels. + + Returns: + List[str]: List of phoneme + prosody symbols. + + Examples: + >>> from espnet2.text.phoneme_tokenizer import pyopenjtalk_g2p_prosody + >>> pyopenjtalk_g2p_prosody("こんにちは。") + ['^', 'k', 'o', '[', 'N', 'n', 'i', 'ch', 'i', 'w', 'a', '$'] + + .. 
_`Prosodic features control by symbols as input of sequence-to-sequence acoustic + modeling for neural TTS`: https://doi.org/10.1587/transinf.2020EDP7104 + + """ + labels = pyopenjtalk.make_label(pyopenjtalk.run_frontend(text)) + N = len(labels) + + phones = [] + for n in range(N): + lab_curr = labels[n] + + # current phoneme + p3 = re.search(r"\-(.*?)\+", lab_curr).group(1) + # deal unvoiced vowels as normal vowels + if drop_unvoiced_vowels and p3 in "AEIOU": + p3 = p3.lower() + + # deal with sil at the beginning and the end of text + if p3 == "sil": + assert n == 0 or n == N - 1 + if n == 0: + phones.append("^") + elif n == N - 1: + # check question form or not + e3 = _numeric_feature_by_regex(r"!(\d+)_", lab_curr) + if e3 == 0: + phones.append("$") + elif e3 == 1: + phones.append("?") + continue + elif p3 == "pau": + phones.append("_") + continue + else: + phones.append(p3) + + # accent type and position info (forward or backward) + a1 = _numeric_feature_by_regex(r"/A:([0-9\-]+)\+", lab_curr) + a2 = _numeric_feature_by_regex(r"\+(\d+)\+", lab_curr) + a3 = _numeric_feature_by_regex(r"\+(\d+)/", lab_curr) + + # number of mora in accent phrase + f1 = _numeric_feature_by_regex(r"/F:(\d+)_", lab_curr) + + a2_next = _numeric_feature_by_regex(r"\+(\d+)\+", labels[n + 1]) + # accent phrase border + if a3 == 1 and a2_next == 1 and p3 in "aeiouAEIOUNcl": + phones.append("#") + # pitch falling + elif a1 == 0 and a2_next == a2 + 1 and a2 != f1: + phones.append("]") + # pitch rising + elif a2 == 1 and a2_next == 2: + phones.append("[") + + return phones + + +# Copied from espnet https://github.com/espnet/espnet/blob/master/espnet2/text/phoneme_tokenizer.py +def _numeric_feature_by_regex(regex, s): + match = re.search(regex, s) + if match is None: + return -50 + return int(match.group(1)) + + +def g2p(norm_text, with_prosody=True): + phones = preprocess_jap(norm_text, with_prosody) + phones = [post_replace_ph(i) for i in phones] + # todo: implement tones and word2ph + return phones + + +if __name__ == "__main__": + phones = g2p("Hello.こんにちは!今日もNiCe天気ですね!tokyotowerに行きましょう!") + print(phones) diff --git a/text/korean.py b/text/korean.py new file mode 100644 index 0000000000000000000000000000000000000000..254b05cf3d9c8af364fa15a0555e4cd363e07724 --- /dev/null +++ b/text/korean.py @@ -0,0 +1,337 @@ +# reference: https://github.com/ORI-Muchim/MB-iSTFT-VITS-Korean/blob/main/text/korean.py + +import re +from jamo import h2j, j2hcj +import ko_pron +from g2pk2 import G2p + +import importlib +import os + +# 防止win下无法读取模型 +if os.name == "nt": + + class win_G2p(G2p): + def check_mecab(self): + super().check_mecab() + spam_spec = importlib.util.find_spec("eunjeon") + non_found = spam_spec is None + if non_found: + print("you have to install eunjeon. 
install it...") + else: + installpath = spam_spec.submodule_search_locations[0] + if not (re.match(r"^[A-Za-z0-9_/\\:.\-]*$", installpath)): + import sys + from eunjeon import Mecab as _Mecab + + class Mecab(_Mecab): + def get_dicpath(installpath): + if not (re.match(r"^[A-Za-z0-9_/\\:.\-]*$", installpath)): + import shutil + + python_dir = os.getcwd() + if installpath[: len(python_dir)].upper() == python_dir.upper(): + dicpath = os.path.join(os.path.relpath(installpath, python_dir), "data", "mecabrc") + else: + if not os.path.exists("TEMP"): + os.mkdir("TEMP") + if not os.path.exists(os.path.join("TEMP", "ko")): + os.mkdir(os.path.join("TEMP", "ko")) + if os.path.exists(os.path.join("TEMP", "ko", "ko_dict")): + shutil.rmtree(os.path.join("TEMP", "ko", "ko_dict")) + + shutil.copytree( + os.path.join(installpath, "data"), os.path.join("TEMP", "ko", "ko_dict") + ) + dicpath = os.path.join("TEMP", "ko", "ko_dict", "mecabrc") + else: + dicpath = os.path.abspath(os.path.join(installpath, "data/mecabrc")) + return dicpath + + def __init__(self, dicpath=get_dicpath(installpath)): + super().__init__(dicpath=dicpath) + + sys.modules["eunjeon"].Mecab = Mecab + + G2p = win_G2p + + +from text.symbols2 import symbols + +# This is a list of Korean classifiers preceded by pure Korean numerals. +_korean_classifiers = ( + "군데 권 개 그루 닢 대 두 마리 모 모금 뭇 발 발짝 방 번 벌 보루 살 수 술 시 쌈 움큼 정 짝 채 척 첩 축 켤레 톨 통" +) + +# List of (hangul, hangul divided) pairs: +_hangul_divided = [ + (re.compile("%s" % x[0]), x[1]) + for x in [ + # ('ㄳ', 'ㄱㅅ'), # g2pk2, A Syllable-ending Rule + # ('ㄵ', 'ㄴㅈ'), + # ('ㄶ', 'ㄴㅎ'), + # ('ㄺ', 'ㄹㄱ'), + # ('ㄻ', 'ㄹㅁ'), + # ('ㄼ', 'ㄹㅂ'), + # ('ㄽ', 'ㄹㅅ'), + # ('ㄾ', 'ㄹㅌ'), + # ('ㄿ', 'ㄹㅍ'), + # ('ㅀ', 'ㄹㅎ'), + # ('ㅄ', 'ㅂㅅ'), + ("ㅘ", "ㅗㅏ"), + ("ㅙ", "ㅗㅐ"), + ("ㅚ", "ㅗㅣ"), + ("ㅝ", "ㅜㅓ"), + ("ㅞ", "ㅜㅔ"), + ("ㅟ", "ㅜㅣ"), + ("ㅢ", "ㅡㅣ"), + ("ㅑ", "ㅣㅏ"), + ("ㅒ", "ㅣㅐ"), + ("ㅕ", "ㅣㅓ"), + ("ㅖ", "ㅣㅔ"), + ("ㅛ", "ㅣㅗ"), + ("ㅠ", "ㅣㅜ"), + ] +] + +# List of (Latin alphabet, hangul) pairs: +_latin_to_hangul = [ + (re.compile("%s" % x[0], re.IGNORECASE), x[1]) + for x in [ + ("a", "에이"), + ("b", "비"), + ("c", "시"), + ("d", "디"), + ("e", "이"), + ("f", "에프"), + ("g", "지"), + ("h", "에이치"), + ("i", "아이"), + ("j", "제이"), + ("k", "케이"), + ("l", "엘"), + ("m", "엠"), + ("n", "엔"), + ("o", "오"), + ("p", "피"), + ("q", "큐"), + ("r", "아르"), + ("s", "에스"), + ("t", "티"), + ("u", "유"), + ("v", "브이"), + ("w", "더블유"), + ("x", "엑스"), + ("y", "와이"), + ("z", "제트"), + ] +] + +# List of (ipa, lazy ipa) pairs: +_ipa_to_lazy_ipa = [ + (re.compile("%s" % x[0], re.IGNORECASE), x[1]) + for x in [ + ("t͡ɕ", "ʧ"), + ("d͡ʑ", "ʥ"), + ("ɲ", "n^"), + ("ɕ", "ʃ"), + ("ʷ", "w"), + ("ɭ", "l`"), + ("ʎ", "ɾ"), + ("ɣ", "ŋ"), + ("ɰ", "ɯ"), + ("ʝ", "j"), + ("ʌ", "ə"), + ("ɡ", "g"), + ("\u031a", "#"), + ("\u0348", "="), + ("\u031e", ""), + ("\u0320", ""), + ("\u0339", ""), + ] +] + + +def fix_g2pk2_error(text): + new_text = "" + i = 0 + while i < len(text) - 4: + if (text[i : i + 3] == "ㅇㅡㄹ" or text[i : i + 3] == "ㄹㅡㄹ") and text[i + 3] == " " and text[i + 4] == "ㄹ": + new_text += text[i : i + 3] + " " + "ㄴ" + i += 5 + else: + new_text += text[i] + i += 1 + + new_text += text[i:] + return new_text + + +def latin_to_hangul(text): + for regex, replacement in _latin_to_hangul: + text = re.sub(regex, replacement, text) + return text + + +def divide_hangul(text): + text = j2hcj(h2j(text)) + for regex, replacement in _hangul_divided: + text = re.sub(regex, replacement, text) + return text + + +def hangul_number(num, sino=True): + """Reference https://github.com/Kyubyong/g2pK""" + 
num = re.sub(",", "", num) + + if num == "0": + return "영" + if not sino and num == "20": + return "스무" + + digits = "123456789" + names = "일이삼사오육칠팔구" + digit2name = {d: n for d, n in zip(digits, names)} + + modifiers = "한 두 세 네 다섯 여섯 일곱 여덟 아홉" + decimals = "열 스물 서른 마흔 쉰 예순 일흔 여든 아흔" + digit2mod = {d: mod for d, mod in zip(digits, modifiers.split())} + digit2dec = {d: dec for d, dec in zip(digits, decimals.split())} + + spelledout = [] + for i, digit in enumerate(num): + i = len(num) - i - 1 + if sino: + if i == 0: + name = digit2name.get(digit, "") + elif i == 1: + name = digit2name.get(digit, "") + "십" + name = name.replace("일십", "십") + else: + if i == 0: + name = digit2mod.get(digit, "") + elif i == 1: + name = digit2dec.get(digit, "") + if digit == "0": + if i % 4 == 0: + last_three = spelledout[-min(3, len(spelledout)) :] + if "".join(last_three) == "": + spelledout.append("") + continue + else: + spelledout.append("") + continue + if i == 2: + name = digit2name.get(digit, "") + "백" + name = name.replace("일백", "백") + elif i == 3: + name = digit2name.get(digit, "") + "천" + name = name.replace("일천", "천") + elif i == 4: + name = digit2name.get(digit, "") + "만" + name = name.replace("일만", "만") + elif i == 5: + name = digit2name.get(digit, "") + "십" + name = name.replace("일십", "십") + elif i == 6: + name = digit2name.get(digit, "") + "백" + name = name.replace("일백", "백") + elif i == 7: + name = digit2name.get(digit, "") + "천" + name = name.replace("일천", "천") + elif i == 8: + name = digit2name.get(digit, "") + "억" + elif i == 9: + name = digit2name.get(digit, "") + "십" + elif i == 10: + name = digit2name.get(digit, "") + "백" + elif i == 11: + name = digit2name.get(digit, "") + "천" + elif i == 12: + name = digit2name.get(digit, "") + "조" + elif i == 13: + name = digit2name.get(digit, "") + "십" + elif i == 14: + name = digit2name.get(digit, "") + "백" + elif i == 15: + name = digit2name.get(digit, "") + "천" + spelledout.append(name) + return "".join(elem for elem in spelledout) + + +def number_to_hangul(text): + """Reference https://github.com/Kyubyong/g2pK""" + tokens = set(re.findall(r"(\d[\d,]*)([\uac00-\ud71f]+)", text)) + for token in tokens: + num, classifier = token + if classifier[:2] in _korean_classifiers or classifier[0] in _korean_classifiers: + spelledout = hangul_number(num, sino=False) + else: + spelledout = hangul_number(num, sino=True) + text = text.replace(f"{num}{classifier}", f"{spelledout}{classifier}") + # digit by digit for remaining digits + digits = "0123456789" + names = "영일이삼사오육칠팔구" + for d, n in zip(digits, names): + text = text.replace(d, n) + return text + + +def korean_to_lazy_ipa(text): + text = latin_to_hangul(text) + text = number_to_hangul(text) + text = re.sub("[\uac00-\ud7af]+", lambda x: ko_pron.romanise(x.group(0), "ipa").split("] ~ [")[0], text) + for regex, replacement in _ipa_to_lazy_ipa: + text = re.sub(regex, replacement, text) + return text + + +_g2p = G2p() + + +def korean_to_ipa(text): + text = latin_to_hangul(text) + text = number_to_hangul(text) + text = _g2p(text) + text = fix_g2pk2_error(text) + text = korean_to_lazy_ipa(text) + return text.replace("ʧ", "tʃ").replace("ʥ", "dʑ") + + +def post_replace_ph(ph): + rep_map = { + ":": ",", + ";": ",", + ",": ",", + "。": ".", + "!": "!", + "?": "?", + "\n": ".", + "·": ",", + "、": ",", + "...": "…", + " ": "空", + } + if ph in rep_map.keys(): + ph = rep_map[ph] + if ph in symbols: + return ph + if ph not in symbols: + ph = "停" + return ph + + +def g2p(text): + text = latin_to_hangul(text) + text = 
_g2p(text) + text = divide_hangul(text) + text = fix_g2pk2_error(text) + text = re.sub(r"([\u3131-\u3163])$", r"\1.", text) + # text = "".join([post_replace_ph(i) for i in text]) + text = [post_replace_ph(i) for i in text] + return text + + +if __name__ == "__main__": + text = "안녕하세요" + print(g2p(text)) diff --git a/text/namedict_cache.pickle b/text/namedict_cache.pickle new file mode 100644 index 0000000000000000000000000000000000000000..9ad2ed9a6478077f8fe4469192597fa8f6442cf6 --- /dev/null +++ b/text/namedict_cache.pickle @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:559552094c4a6e995213e3fa586330e078ef8cb3a7a95a3109e945111cd2bfc1 +size 760663 diff --git a/text/opencpop-strict.txt b/text/opencpop-strict.txt new file mode 100644 index 0000000000000000000000000000000000000000..1bede89a429e34b1a58781ee570fc313b80c0aee --- /dev/null +++ b/text/opencpop-strict.txt @@ -0,0 +1,429 @@ +a AA a +ai AA ai +an AA an +ang AA ang +ao AA ao +ba b a +bai b ai +ban b an +bang b ang +bao b ao +bei b ei +ben b en +beng b eng +bi b i +bian b ian +biao b iao +bie b ie +bin b in +bing b ing +bo b o +bu b u +ca c a +cai c ai +can c an +cang c ang +cao c ao +ce c e +cei c ei +cen c en +ceng c eng +cha ch a +chai ch ai +chan ch an +chang ch ang +chao ch ao +che ch e +chen ch en +cheng ch eng +chi ch ir +chong ch ong +chou ch ou +chu ch u +chua ch ua +chuai ch uai +chuan ch uan +chuang ch uang +chui ch ui +chun ch un +chuo ch uo +ci c i0 +cong c ong +cou c ou +cu c u +cuan c uan +cui c ui +cun c un +cuo c uo +da d a +dai d ai +dan d an +dang d ang +dao d ao +de d e +dei d ei +den d en +deng d eng +di d i +dia d ia +dian d ian +diao d iao +die d ie +ding d ing +diu d iu +dong d ong +dou d ou +du d u +duan d uan +dui d ui +dun d un +duo d uo +e EE e +ei EE ei +en EE en +eng EE eng +er EE er +fa f a +fan f an +fang f ang +fei f ei +fen f en +feng f eng +fo f o +fou f ou +fu f u +ga g a +gai g ai +gan g an +gang g ang +gao g ao +ge g e +gei g ei +gen g en +geng g eng +gong g ong +gou g ou +gu g u +gua g ua +guai g uai +guan g uan +guang g uang +gui g ui +gun g un +guo g uo +ha h a +hai h ai +han h an +hang h ang +hao h ao +he h e +hei h ei +hen h en +heng h eng +hong h ong +hou h ou +hu h u +hua h ua +huai h uai +huan h uan +huang h uang +hui h ui +hun h un +huo h uo +ji j i +jia j ia +jian j ian +jiang j iang +jiao j iao +jie j ie +jin j in +jing j ing +jiong j iong +jiu j iu +ju j v +jv j v +juan j van +jvan j van +jue j ve +jve j ve +jun j vn +jvn j vn +ka k a +kai k ai +kan k an +kang k ang +kao k ao +ke k e +kei k ei +ken k en +keng k eng +kong k ong +kou k ou +ku k u +kua k ua +kuai k uai +kuan k uan +kuang k uang +kui k ui +kun k un +kuo k uo +la l a +lai l ai +lan l an +lang l ang +lao l ao +le l e +lei l ei +leng l eng +li l i +lia l ia +lian l ian +liang l iang +liao l iao +lie l ie +lin l in +ling l ing +liu l iu +lo l o +long l ong +lou l ou +lu l u +luan l uan +lun l un +luo l uo +lv l v +lve l ve +ma m a +mai m ai +man m an +mang m ang +mao m ao +me m e +mei m ei +men m en +meng m eng +mi m i +mian m ian +miao m iao +mie m ie +min m in +ming m ing +miu m iu +mo m o +mou m ou +mu m u +na n a +nai n ai +nan n an +nang n ang +nao n ao +ne n e +nei n ei +nen n en +neng n eng +ni n i +nian n ian +niang n iang +niao n iao +nie n ie +nin n in +ning n ing +niu n iu +nong n ong +nou n ou +nu n u +nuan n uan +nun n un +nuo n uo +nv n v +nve n ve +o OO o +ou OO ou +pa p a +pai p ai +pan p an +pang p ang +pao p ao +pei p ei +pen p en +peng p eng +pi p i +pian p ian +piao p iao +pie p ie 
+pin p in +ping p ing +po p o +pou p ou +pu p u +qi q i +qia q ia +qian q ian +qiang q iang +qiao q iao +qie q ie +qin q in +qing q ing +qiong q iong +qiu q iu +qu q v +qv q v +quan q van +qvan q van +que q ve +qve q ve +qun q vn +qvn q vn +ran r an +rang r ang +rao r ao +re r e +ren r en +reng r eng +ri r ir +rong r ong +rou r ou +ru r u +rua r ua +ruan r uan +rui r ui +run r un +ruo r uo +sa s a +sai s ai +san s an +sang s ang +sao s ao +se s e +sen s en +seng s eng +sha sh a +shai sh ai +shan sh an +shang sh ang +shao sh ao +she sh e +shei sh ei +shen sh en +sheng sh eng +shi sh ir +shou sh ou +shu sh u +shua sh ua +shuai sh uai +shuan sh uan +shuang sh uang +shui sh ui +shun sh un +shuo sh uo +si s i0 +song s ong +sou s ou +su s u +suan s uan +sui s ui +sun s un +suo s uo +ta t a +tai t ai +tan t an +tang t ang +tao t ao +te t e +tei t ei +teng t eng +ti t i +tian t ian +tiao t iao +tie t ie +ting t ing +tong t ong +tou t ou +tu t u +tuan t uan +tui t ui +tun t un +tuo t uo +wa w a +wai w ai +wan w an +wang w ang +wei w ei +wen w en +weng w eng +wo w o +wu w u +xi x i +xia x ia +xian x ian +xiang x iang +xiao x iao +xie x ie +xin x in +xing x ing +xiong x iong +xiu x iu +xu x v +xv x v +xuan x van +xvan x van +xue x ve +xve x ve +xun x vn +xvn x vn +ya y a +yan y En +yang y ang +yao y ao +ye y E +yi y i +yin y in +ying y ing +yo y o +yong y ong +you y ou +yu y v +yv y v +yuan y van +yvan y van +yue y ve +yve y ve +yun y vn +yvn y vn +za z a +zai z ai +zan z an +zang z ang +zao z ao +ze z e +zei z ei +zen z en +zeng z eng +zha zh a +zhai zh ai +zhan zh an +zhang zh ang +zhao zh ao +zhe zh e +zhei zh ei +zhen zh en +zheng zh eng +zhi zh ir +zhong zh ong +zhou zh ou +zhu zh u +zhua zh ua +zhuai zh uai +zhuan zh uan +zhuang zh uang +zhui zh ui +zhun zh un +zhuo zh uo +zi z i0 +zong z ong +zou z ou +zu z u +zuan z uan +zui z ui +zun z un +zuo z uo diff --git a/text/symbols.py b/text/symbols.py new file mode 100644 index 0000000000000000000000000000000000000000..b012882b053ebcb30b3aa54e9cc695cd569d774b --- /dev/null +++ b/text/symbols.py @@ -0,0 +1,399 @@ +# punctuation = ['!', '?', '…', ",", ".","@"]#@是SP停顿 +punctuation = ["!", "?", "…", ",", "."] # @是SP停顿 +punctuation.append("-") +pu_symbols = punctuation + ["SP", "SP2", "SP3", "UNK"] +# pu_symbols = punctuation + ["SP", 'SP2', 'SP3','SP4', "UNK"] +pad = "_" + +c = [ + "AA", + "EE", + "OO", + "b", + "c", + "ch", + "d", + "f", + "g", + "h", + "j", + "k", + "l", + "m", + "n", + "p", + "q", + "r", + "s", + "sh", + "t", + "w", + "x", + "y", + "z", + "zh", +] +v = [ + "E1", + "En1", + "a1", + "ai1", + "an1", + "ang1", + "ao1", + "e1", + "ei1", + "en1", + "eng1", + "er1", + "i1", + "i01", + "ia1", + "ian1", + "iang1", + "iao1", + "ie1", + "in1", + "ing1", + "iong1", + "ir1", + "iu1", + "o1", + "ong1", + "ou1", + "u1", + "ua1", + "uai1", + "uan1", + "uang1", + "ui1", + "un1", + "uo1", + "v1", + "van1", + "ve1", + "vn1", + "E2", + "En2", + "a2", + "ai2", + "an2", + "ang2", + "ao2", + "e2", + "ei2", + "en2", + "eng2", + "er2", + "i2", + "i02", + "ia2", + "ian2", + "iang2", + "iao2", + "ie2", + "in2", + "ing2", + "iong2", + "ir2", + "iu2", + "o2", + "ong2", + "ou2", + "u2", + "ua2", + "uai2", + "uan2", + "uang2", + "ui2", + "un2", + "uo2", + "v2", + "van2", + "ve2", + "vn2", + "E3", + "En3", + "a3", + "ai3", + "an3", + "ang3", + "ao3", + "e3", + "ei3", + "en3", + "eng3", + "er3", + "i3", + "i03", + "ia3", + "ian3", + "iang3", + "iao3", + "ie3", + "in3", + "ing3", + "iong3", + "ir3", + "iu3", + "o3", + "ong3", + "ou3", + "u3", + "ua3", + "uai3", + 
"uan3", + "uang3", + "ui3", + "un3", + "uo3", + "v3", + "van3", + "ve3", + "vn3", + "E4", + "En4", + "a4", + "ai4", + "an4", + "ang4", + "ao4", + "e4", + "ei4", + "en4", + "eng4", + "er4", + "i4", + "i04", + "ia4", + "ian4", + "iang4", + "iao4", + "ie4", + "in4", + "ing4", + "iong4", + "ir4", + "iu4", + "o4", + "ong4", + "ou4", + "u4", + "ua4", + "uai4", + "uan4", + "uang4", + "ui4", + "un4", + "uo4", + "v4", + "van4", + "ve4", + "vn4", + "E5", + "En5", + "a5", + "ai5", + "an5", + "ang5", + "ao5", + "e5", + "ei5", + "en5", + "eng5", + "er5", + "i5", + "i05", + "ia5", + "ian5", + "iang5", + "iao5", + "ie5", + "in5", + "ing5", + "iong5", + "ir5", + "iu5", + "o5", + "ong5", + "ou5", + "u5", + "ua5", + "uai5", + "uan5", + "uang5", + "ui5", + "un5", + "uo5", + "v5", + "van5", + "ve5", + "vn5", +] + +v_without_tone = [ + "E", + "En", + "a", + "ai", + "an", + "ang", + "ao", + "e", + "ei", + "en", + "eng", + "er", + "i", + "i0", + "ia", + "ian", + "iang", + "iao", + "ie", + "in", + "ing", + "iong", + "ir", + "iu", + "o", + "ong", + "ou", + "u", + "ua", + "uai", + "uan", + "uang", + "ui", + "un", + "uo", + "v", + "van", + "ve", + "vn", +] + +# japanese +ja_symbols = [ + "I", + "N", + "U", + "a", + "b", + "by", + "ch", + "cl", + "d", + "dy", + "e", + "f", + "g", + "gy", + "h", + "hy", + "i", + "j", + "k", + "ky", + "m", + "my", + "n", + "ny", + "o", + "p", + "py", + "r", + "ry", + "s", + "sh", + "t", + "ts", + "u", + "v", + "w", + "y", + "z", + # "[", #上升调型 + # "]", #下降调型 + # "$", #结束符 + # "^", #开始符 +] + +arpa = { + "AH0", + "S", + "AH1", + "EY2", + "AE2", + "EH0", + "OW2", + "UH0", + "NG", + "B", + "G", + "AY0", + "M", + "AA0", + "F", + "AO0", + "ER2", + "UH1", + "IY1", + "AH2", + "DH", + "IY0", + "EY1", + "IH0", + "K", + "N", + "W", + "IY2", + "T", + "AA1", + "ER1", + "EH2", + "OY0", + "UH2", + "UW1", + "Z", + "AW2", + "AW1", + "V", + "UW2", + "AA2", + "ER", + "AW0", + "UW0", + "R", + "OW1", + "EH1", + "ZH", + "AE0", + "IH2", + "IH", + "Y", + "JH", + "P", + "AY1", + "EY0", + "OY2", + "TH", + "HH", + "D", + "ER0", + "CH", + "AO1", + "AE1", + "AO2", + "OY1", + "AY2", + "IH1", + "OW0", + "L", + "SH", +} + +symbols = [pad] + c + v + ja_symbols + pu_symbols + list(arpa) +symbols = sorted(set(symbols)) +if __name__ == "__main__": + print(len(symbols)) diff --git a/text/symbols2.py b/text/symbols2.py new file mode 100644 index 0000000000000000000000000000000000000000..2f159d2bdc017fdf8ab5e001a3cb1ee605ed60d8 --- /dev/null +++ b/text/symbols2.py @@ -0,0 +1,797 @@ +# punctuation = ['!', '?', '…', ",", ".","@"]#@是SP停顿 +punctuation = ["!", "?", "…", ",", "."] # @是SP停顿 +punctuation.append("-") +pu_symbols = punctuation + ["SP", "SP2", "SP3", "UNK"] +# pu_symbols = punctuation + ["SP", 'SP2', 'SP3','SP4', "UNK"] +pad = "_" + +c = [ + "AA", + "EE", + "OO", + "b", + "c", + "ch", + "d", + "f", + "g", + "h", + "j", + "k", + "l", + "m", + "n", + "p", + "q", + "r", + "s", + "sh", + "t", + "w", + "x", + "y", + "z", + "zh", +] +v = [ + "E1", + "En1", + "a1", + "ai1", + "an1", + "ang1", + "ao1", + "e1", + "ei1", + "en1", + "eng1", + "er1", + "i1", + "i01", + "ia1", + "ian1", + "iang1", + "iao1", + "ie1", + "in1", + "ing1", + "iong1", + "ir1", + "iu1", + "o1", + "ong1", + "ou1", + "u1", + "ua1", + "uai1", + "uan1", + "uang1", + "ui1", + "un1", + "uo1", + "v1", + "van1", + "ve1", + "vn1", + "E2", + "En2", + "a2", + "ai2", + "an2", + "ang2", + "ao2", + "e2", + "ei2", + "en2", + "eng2", + "er2", + "i2", + "i02", + "ia2", + "ian2", + "iang2", + "iao2", + "ie2", + "in2", + "ing2", + "iong2", + "ir2", + "iu2", + "o2", + 
"ong2", + "ou2", + "u2", + "ua2", + "uai2", + "uan2", + "uang2", + "ui2", + "un2", + "uo2", + "v2", + "van2", + "ve2", + "vn2", + "E3", + "En3", + "a3", + "ai3", + "an3", + "ang3", + "ao3", + "e3", + "ei3", + "en3", + "eng3", + "er3", + "i3", + "i03", + "ia3", + "ian3", + "iang3", + "iao3", + "ie3", + "in3", + "ing3", + "iong3", + "ir3", + "iu3", + "o3", + "ong3", + "ou3", + "u3", + "ua3", + "uai3", + "uan3", + "uang3", + "ui3", + "un3", + "uo3", + "v3", + "van3", + "ve3", + "vn3", + "E4", + "En4", + "a4", + "ai4", + "an4", + "ang4", + "ao4", + "e4", + "ei4", + "en4", + "eng4", + "er4", + "i4", + "i04", + "ia4", + "ian4", + "iang4", + "iao4", + "ie4", + "in4", + "ing4", + "iong4", + "ir4", + "iu4", + "o4", + "ong4", + "ou4", + "u4", + "ua4", + "uai4", + "uan4", + "uang4", + "ui4", + "un4", + "uo4", + "v4", + "van4", + "ve4", + "vn4", + "E5", + "En5", + "a5", + "ai5", + "an5", + "ang5", + "ao5", + "e5", + "ei5", + "en5", + "eng5", + "er5", + "i5", + "i05", + "ia5", + "ian5", + "iang5", + "iao5", + "ie5", + "in5", + "ing5", + "iong5", + "ir5", + "iu5", + "o5", + "ong5", + "ou5", + "u5", + "ua5", + "uai5", + "uan5", + "uang5", + "ui5", + "un5", + "uo5", + "v5", + "van5", + "ve5", + "vn5", +] + +v_without_tone = [ + "E", + "En", + "a", + "ai", + "an", + "ang", + "ao", + "e", + "ei", + "en", + "eng", + "er", + "i", + "i0", + "ia", + "ian", + "iang", + "iao", + "ie", + "in", + "ing", + "iong", + "ir", + "iu", + "o", + "ong", + "ou", + "u", + "ua", + "uai", + "uan", + "uang", + "ui", + "un", + "uo", + "v", + "van", + "ve", + "vn", +] + +# japanese +ja_symbols = [ + "I", + "N", + "U", + "a", + "b", + "by", + "ch", + "cl", + "d", + "dy", + "e", + "f", + "g", + "gy", + "h", + "hy", + "i", + "j", + "k", + "ky", + "m", + "my", + "n", + "ny", + "o", + "p", + "py", + "r", + "ry", + "s", + "sh", + "t", + "ts", + "u", + "v", + "w", + "y", + "z", + ###楼下2个留到后面加 + # "[", #上升调型 + # "]", #下降调型 + # "$", #结束符 + # "^", #开始符 +] + +arpa = { + "AH0", + "S", + "AH1", + "EY2", + "AE2", + "EH0", + "OW2", + "UH0", + "NG", + "B", + "G", + "AY0", + "M", + "AA0", + "F", + "AO0", + "ER2", + "UH1", + "IY1", + "AH2", + "DH", + "IY0", + "EY1", + "IH0", + "K", + "N", + "W", + "IY2", + "T", + "AA1", + "ER1", + "EH2", + "OY0", + "UH2", + "UW1", + "Z", + "AW2", + "AW1", + "V", + "UW2", + "AA2", + "ER", + "AW0", + "UW0", + "R", + "OW1", + "EH1", + "ZH", + "AE0", + "IH2", + "IH", + "Y", + "JH", + "P", + "AY1", + "EY0", + "OY2", + "TH", + "HH", + "D", + "ER0", + "CH", + "AO1", + "AE1", + "AO2", + "OY1", + "AY2", + "IH1", + "OW0", + "L", + "SH", +} + +ko_symbols = "ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ空停" +# ko_symbols='ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ ' + +yue_symbols = { + "Yeot3", + "Yip1", + "Yyu3", + "Yeng4", + "Yut5", + "Yaan5", + "Ym5", + "Yaan6", + "Yang1", + "Yun4", + "Yon2", + "Yui5", + "Yun2", + "Yat3", + "Ye", + "Yeot1", + "Yoeng5", + "Yoek2", + "Yam2", + "Yeon6", + "Yu6", + "Yiu3", + "Yaang6", + "Yp5", + "Yai4", + "Yoek4", + "Yit6", + "Yam5", + "Yoeng6", + "Yg1", + "Yk3", + "Yoe4", + "Yam3", + "Yc", + "Yyu4", + "Yyut1", + "Yiu4", + "Ying3", + "Yip3", + "Yaap3", + "Yau3", + "Yan4", + "Yau1", + "Yap4", + "Yk6", + "Yok3", + "Yai1", + "Yeot6", + "Yan2", + "Yoek6", + "Yt1", + "Yoi1", + "Yit5", + "Yn4", + "Yaau3", + "Yau4", + "Yuk6", + "Ys", + "Yuk", + "Yin6", + "Yung6", + "Ya", + "You", + "Yaai5", + "Yau5", + "Yoi3", + "Yaak3", + "Yaat3", + "Ying2", + "Yok5", + "Yeng2", + "Yyut3", + "Yam1", + "Yip5", + "You1", + "Yam6", + "Yaa5", + "Yi6", + "Yek4", + "Yyu2", + "Yuk5", + "Yaam1", + "Yang2", + "Yai", + "Yiu6", + "Yin4", + "Yok4", + "Yot3", + 
"Yui2", + "Yeoi5", + "Yyun6", + "Yyu5", + "Yoi5", + "Yeot2", + "Yim4", + "Yeoi2", + "Yaan1", + "Yang6", + "Yong1", + "Yaang4", + "Yung5", + "Yeon1", + "Yin2", + "Ya3", + "Yaang3", + "Yg", + "Yk2", + "Yaau5", + "Yut1", + "Yt5", + "Yip4", + "Yung4", + "Yj", + "Yong3", + "Ya1", + "Yg6", + "Yaau6", + "Yit3", + "Yun3", + "Ying1", + "Yn2", + "Yg4", + "Yl", + "Yp3", + "Yn3", + "Yak1", + "Yang5", + "Yoe6", + "You2", + "Yap2", + "Yak2", + "Yt3", + "Yot5", + "Yim2", + "Yi1", + "Yn6", + "Yaat5", + "Yaam3", + "Yoek5", + "Ye3", + "Yeon4", + "Yaa2", + "Yu3", + "Yim6", + "Ym", + "Yoe3", + "Yaai2", + "Ym2", + "Ya6", + "Yeng6", + "Yik4", + "Yot4", + "Yaai4", + "Yyun3", + "Yu1", + "Yoeng1", + "Yaap2", + "Yuk3", + "Yoek3", + "Yeng5", + "Yeoi1", + "Yiu2", + "Yok1", + "Yo1", + "Yoek1", + "Yoeng2", + "Yeon5", + "Yiu1", + "Yoeng4", + "Yuk2", + "Yat4", + "Yg5", + "Yut4", + "Yan6", + "Yin3", + "Yaa6", + "Yap1", + "Yg2", + "Yoe5", + "Yt4", + "Ya5", + "Yo4", + "Yyu1", + "Yak3", + "Yeon2", + "Yong4", + "Ym1", + "Ye2", + "Yaang5", + "Yoi2", + "Yeng3", + "Yn", + "Yyut4", + "Yau", + "Yaak2", + "Yaan4", + "Yek2", + "Yin1", + "Yi5", + "Yoe2", + "Yei5", + "Yaat6", + "Yak5", + "Yp6", + "Yok6", + "Yei2", + "Yaap1", + "Yyut5", + "Yi4", + "Yim1", + "Yk5", + "Ye4", + "Yok2", + "Yaam6", + "Yat2", + "Yon6", + "Yei3", + "Yyu6", + "Yeot5", + "Yk4", + "Yai6", + "Yd", + "Yg3", + "Yei6", + "Yau2", + "Yok", + "Yau6", + "Yung3", + "Yim5", + "Yut6", + "Yit1", + "Yon3", + "Yat1", + "Yaam2", + "Yyut2", + "Yui6", + "Yt2", + "Yek6", + "Yt", + "Ye6", + "Yang3", + "Ying6", + "Yaau1", + "Yeon3", + "Yng", + "Yh", + "Yang4", + "Ying5", + "Yaap6", + "Yoeng3", + "Yyun4", + "You3", + "Yan5", + "Yat5", + "Yot1", + "Yun1", + "Yi3", + "Yaa1", + "Yaap4", + "You6", + "Yaang2", + "Yaap5", + "Yaa3", + "Yaak6", + "Yeng1", + "Yaak1", + "Yo5", + "Yoi4", + "Yam4", + "Yik1", + "Ye1", + "Yai5", + "Yung1", + "Yp2", + "Yui4", + "Yaak4", + "Yung2", + "Yak4", + "Yaat4", + "Yeoi4", + "Yut2", + "Yin5", + "Yaau4", + "Yap6", + "Yb", + "Yaam4", + "Yw", + "Yut3", + "Yong2", + "Yt6", + "Yaai6", + "Yap5", + "Yik5", + "Yun6", + "Yaam5", + "Yun5", + "Yik3", + "Ya2", + "Yyut6", + "Yon4", + "Yk1", + "Yit4", + "Yak6", + "Yaan2", + "Yuk1", + "Yai2", + "Yik2", + "Yaat2", + "Yo3", + "Ykw", + "Yn5", + "Yaa", + "Ye5", + "Yu4", + "Yei1", + "Yai3", + "Yyun5", + "Yip2", + "Yaau2", + "Yiu5", + "Ym4", + "Yeoi6", + "Yk", + "Ym6", + "Yoe1", + "Yeoi3", + "Yon", + "Yuk4", + "Yaai3", + "Yaa4", + "Yot6", + "Yaang1", + "Yei4", + "Yek1", + "Yo", + "Yp", + "Yo6", + "Yp4", + "Yan3", + "Yoi", + "Yap3", + "Yek3", + "Yim3", + "Yz", + "Yot2", + "Yoi6", + "Yit2", + "Yu5", + "Yaan3", + "Yan1", + "Yon5", + "Yp1", + "Yong5", + "Ygw", + "Yak", + "Yat6", + "Ying4", + "Yu2", + "Yf", + "Ya4", + "Yon1", + "You4", + "Yik6", + "Yui1", + "Yaat1", + "Yeot4", + "Yi2", + "Yaai1", + "Yek5", + "Ym3", + "Yong6", + "You5", + "Yyun1", + "Yn1", + "Yo2", + "Yip6", + "Yui3", + "Yaak5", + "Yyun2", +} + +# symbols = [pad] + c + v + ja_symbols + pu_symbols + list(arpa)+list(ko_symbols)#+list(yue_symbols)###直接这么加yue顺序乱了 +symbols = [pad] + c + v + ja_symbols + pu_symbols + list(arpa) +symbols = sorted(set(symbols)) +# print(len(symbols)) +symbols += ["[", "]"] ##日文新增上升下降调型 +symbols += sorted(list(ko_symbols)) +symbols += sorted(list(yue_symbols)) ##新加的yue统一摆在后头#已查过开头加Y后没有重复,韩文显然不会重复 +# print(len(symbols)) +if __name__ == "__main__": + print(len(symbols)) +""" +粤语: + 732-353=379 +韩文+粤语: + 732-322=410 +""" diff --git a/text/tone_sandhi.py b/text/tone_sandhi.py new file mode 100644 index 
0000000000000000000000000000000000000000..4ed737811a54456adbd5178c6398deb3ff6a12ab --- /dev/null +++ b/text/tone_sandhi.py @@ -0,0 +1,774 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List +from typing import Tuple + +import jieba_fast as jieba +from pypinyin import lazy_pinyin +from pypinyin import Style + + +class ToneSandhi: + def __init__(self): + self.must_neural_tone_words = { + "麻烦", + "麻利", + "鸳鸯", + "高粱", + "骨头", + "骆驼", + "马虎", + "首饰", + "馒头", + "馄饨", + "风筝", + "难为", + "队伍", + "阔气", + "闺女", + "门道", + "锄头", + "铺盖", + "铃铛", + "铁匠", + "钥匙", + "里脊", + "里头", + "部分", + "那么", + "道士", + "造化", + "迷糊", + "连累", + "这么", + "这个", + "运气", + "过去", + "软和", + "转悠", + "踏实", + "跳蚤", + "跟头", + "趔趄", + "财主", + "豆腐", + "讲究", + "记性", + "记号", + "认识", + "规矩", + "见识", + "裁缝", + "补丁", + "衣裳", + "衣服", + "衙门", + "街坊", + "行李", + "行当", + "蛤蟆", + "蘑菇", + "薄荷", + "葫芦", + "葡萄", + "萝卜", + "荸荠", + "苗条", + "苗头", + "苍蝇", + "芝麻", + "舒服", + "舒坦", + "舌头", + "自在", + "膏药", + "脾气", + "脑袋", + "脊梁", + "能耐", + "胳膊", + "胭脂", + "胡萝", + "胡琴", + "胡同", + "聪明", + "耽误", + "耽搁", + "耷拉", + "耳朵", + "老爷", + "老实", + "老婆", + "老头", + "老太", + "翻腾", + "罗嗦", + "罐头", + "编辑", + "结实", + "红火", + "累赘", + "糨糊", + "糊涂", + "精神", + "粮食", + "簸箕", + "篱笆", + "算计", + "算盘", + "答应", + "笤帚", + "笑语", + "笑话", + "窟窿", + "窝囊", + "窗户", + "稳当", + "稀罕", + "称呼", + "秧歌", + "秀气", + "秀才", + "福气", + "祖宗", + "砚台", + "码头", + "石榴", + "石头", + "石匠", + "知识", + "眼睛", + "眯缝", + "眨巴", + "眉毛", + "相声", + "盘算", + "白净", + "痢疾", + "痛快", + "疟疾", + "疙瘩", + "疏忽", + "畜生", + "生意", + "甘蔗", + "琵琶", + "琢磨", + "琉璃", + "玻璃", + "玫瑰", + "玄乎", + "狐狸", + "状元", + "特务", + "牲口", + "牙碜", + "牌楼", + "爽快", + "爱人", + "热闹", + "烧饼", + "烟筒", + "烂糊", + "点心", + "炊帚", + "灯笼", + "火候", + "漂亮", + "滑溜", + "溜达", + "温和", + "清楚", + "消息", + "浪头", + "活泼", + "比方", + "正经", + "欺负", + "模糊", + "槟榔", + "棺材", + "棒槌", + "棉花", + "核桃", + "栅栏", + "柴火", + "架势", + "枕头", + "枇杷", + "机灵", + "本事", + "木头", + "木匠", + "朋友", + "月饼", + "月亮", + "暖和", + "明白", + "时候", + "新鲜", + "故事", + "收拾", + "收成", + "提防", + "挖苦", + "挑剔", + "指甲", + "指头", + "拾掇", + "拳头", + "拨弄", + "招牌", + "招呼", + "抬举", + "护士", + "折腾", + "扫帚", + "打量", + "打算", + "打点", + "打扮", + "打听", + "打发", + "扎实", + "扁担", + "戒指", + "懒得", + "意识", + "意思", + "情形", + "悟性", + "怪物", + "思量", + "怎么", + "念头", + "念叨", + "快活", + "忙活", + "志气", + "心思", + "得罪", + "张罗", + "弟兄", + "开通", + "应酬", + "庄稼", + "干事", + "帮手", + "帐篷", + "希罕", + "师父", + "师傅", + "巴结", + "巴掌", + "差事", + "工夫", + "岁数", + "屁股", + "尾巴", + "少爷", + "小气", + "小伙", + "将就", + "对头", + "对付", + "寡妇", + "家伙", + "客气", + "实在", + "官司", + "学问", + "学生", + "字号", + "嫁妆", + "媳妇", + "媒人", + "婆家", + "娘家", + "委屈", + "姑娘", + "姐夫", + "妯娌", + "妥当", + "妖精", + "奴才", + "女婿", + "头发", + "太阳", + "大爷", + "大方", + "大意", + "大夫", + "多少", + "多么", + "外甥", + "壮实", + "地道", + "地方", + "在乎", + "困难", + "嘴巴", + "嘱咐", + "嘟囔", + "嘀咕", + "喜欢", + "喇嘛", + "喇叭", + "商量", + "唾沫", + "哑巴", + "哈欠", + "哆嗦", + "咳嗽", + "和尚", + "告诉", + "告示", + "含糊", + "吓唬", + "后头", + "名字", + "名堂", + "合同", + "吆喝", + "叫唤", + "口袋", + 
"厚道", + "厉害", + "千斤", + "包袱", + "包涵", + "匀称", + "勤快", + "动静", + "动弹", + "功夫", + "力气", + "前头", + "刺猬", + "刺激", + "别扭", + "利落", + "利索", + "利害", + "分析", + "出息", + "凑合", + "凉快", + "冷战", + "冤枉", + "冒失", + "养活", + "关系", + "先生", + "兄弟", + "便宜", + "使唤", + "佩服", + "作坊", + "体面", + "位置", + "似的", + "伙计", + "休息", + "什么", + "人家", + "亲戚", + "亲家", + "交情", + "云彩", + "事情", + "买卖", + "主意", + "丫头", + "丧气", + "两口", + "东西", + "东家", + "世故", + "不由", + "不在", + "下水", + "下巴", + "上头", + "上司", + "丈夫", + "丈人", + "一辈", + "那个", + "菩萨", + "父亲", + "母亲", + "咕噜", + "邋遢", + "费用", + "冤家", + "甜头", + "介绍", + "荒唐", + "大人", + "泥鳅", + "幸福", + "熟悉", + "计划", + "扑腾", + "蜡烛", + "姥爷", + "照顾", + "喉咙", + "吉他", + "弄堂", + "蚂蚱", + "凤凰", + "拖沓", + "寒碜", + "糟蹋", + "倒腾", + "报复", + "逻辑", + "盘缠", + "喽啰", + "牢骚", + "咖喱", + "扫把", + "惦记", + } + self.must_not_neural_tone_words = { + "男子", + "女子", + "分子", + "原子", + "量子", + "莲子", + "石子", + "瓜子", + "电子", + "人人", + "虎虎", + "幺幺", + "干嘛", + "学子", + "哈哈", + "数数", + "袅袅", + "局地", + "以下", + "娃哈哈", + "花花草草", + "留得", + "耕地", + "想想", + "熙熙", + "攘攘", + "卵子", + "死死", + "冉冉", + "恳恳", + "佼佼", + "吵吵", + "打打", + "考考", + "整整", + "莘莘", + "落地", + "算子", + "家家户户", + "青青", + } + self.punc = ":,;。?!“”‘’':,;.?!" + + # the meaning of jieba pos tag: https://blog.csdn.net/weixin_44174352/article/details/113731041 + # e.g. + # word: "家里" + # pos: "s" + # finals: ['ia1', 'i3'] + def _neural_sandhi(self, word: str, pos: str, finals: List[str]) -> List[str]: + # reduplication words for n. and v. e.g. 奶奶, 试试, 旺旺 + for j, item in enumerate(word): + if ( + j - 1 >= 0 + and item == word[j - 1] + and pos[0] in {"n", "v", "a"} + and word not in self.must_not_neural_tone_words + ): + finals[j] = finals[j][:-1] + "5" + ge_idx = word.find("个") + if len(word) >= 1 and word[-1] in "吧呢哈啊呐噻嘛吖嗨呐哦哒额滴哩哟喽啰耶喔诶": + finals[-1] = finals[-1][:-1] + "5" + elif len(word) >= 1 and word[-1] in "的地得": + finals[-1] = finals[-1][:-1] + "5" + # e.g. 走了, 看着, 去过 + elif len(word) == 1 and word in "了着过" and pos in {"ul", "uz", "ug"}: + finals[-1] = finals[-1][:-1] + "5" + elif len(word) > 1 and word[-1] in "们子" and pos in {"r", "n"} and word not in self.must_not_neural_tone_words: + finals[-1] = finals[-1][:-1] + "5" + # e.g. 桌上, 地下, 家里 + elif len(word) > 1 and word[-1] in "上下里" and pos in {"s", "l", "f"}: + finals[-1] = finals[-1][:-1] + "5" + # e.g. 上来, 下去 + elif len(word) > 1 and word[-1] in "来去" and word[-2] in "上下进出回过起开": + finals[-1] = finals[-1][:-1] + "5" + # 个做量词 + elif ( + ge_idx >= 1 and (word[ge_idx - 1].isnumeric() or word[ge_idx - 1] in "几有两半多各整每做是") + ) or word == "个": + finals[ge_idx] = finals[ge_idx][:-1] + "5" + else: + if word in self.must_neural_tone_words or word[-2:] in self.must_neural_tone_words: + finals[-1] = finals[-1][:-1] + "5" + + word_list = self._split_word(word) + finals_list = [finals[: len(word_list[0])], finals[len(word_list[0]) :]] + for i, word in enumerate(word_list): + # conventional neural in Chinese + if word in self.must_neural_tone_words or word[-2:] in self.must_neural_tone_words: + finals_list[i][-1] = finals_list[i][-1][:-1] + "5" + finals = sum(finals_list, []) + return finals + + def _bu_sandhi(self, word: str, finals: List[str]) -> List[str]: + # e.g. 看不懂 + if len(word) == 3 and word[1] == "不": + finals[1] = finals[1][:-1] + "5" + else: + for i, char in enumerate(word): + # "不" before tone4 should be bu2, e.g. 
不怕 + if char == "不" and i + 1 < len(word) and finals[i + 1][-1] == "4": + finals[i] = finals[i][:-1] + "2" + return finals + + def _yi_sandhi(self, word: str, finals: List[str]) -> List[str]: + # "一" in number sequences, e.g. 一零零, 二一零 + if word.find("一") != -1 and all([item.isnumeric() for item in word if item != "一"]): + return finals + # "一" between reduplication words shold be yi5, e.g. 看一看 + elif len(word) == 3 and word[1] == "一" and word[0] == word[-1]: + finals[1] = finals[1][:-1] + "5" + # when "一" is ordinal word, it should be yi1 + elif word.startswith("第一"): + finals[1] = finals[1][:-1] + "1" + else: + for i, char in enumerate(word): + if char == "一" and i + 1 < len(word): + # "一" before tone4 should be yi2, e.g. 一段 + if finals[i + 1][-1] == "4": + finals[i] = finals[i][:-1] + "2" + # "一" before non-tone4 should be yi4, e.g. 一天 + else: + # "一" 后面如果是标点,还读一声 + if word[i + 1] not in self.punc: + finals[i] = finals[i][:-1] + "4" + return finals + + def _split_word(self, word: str) -> List[str]: + word_list = jieba.cut_for_search(word) + word_list = sorted(word_list, key=lambda i: len(i), reverse=False) + first_subword = word_list[0] + first_begin_idx = word.find(first_subword) + if first_begin_idx == 0: + second_subword = word[len(first_subword) :] + new_word_list = [first_subword, second_subword] + else: + second_subword = word[: -len(first_subword)] + new_word_list = [second_subword, first_subword] + return new_word_list + + def _three_sandhi(self, word: str, finals: List[str]) -> List[str]: + if len(word) == 2 and self._all_tone_three(finals): + finals[0] = finals[0][:-1] + "2" + elif len(word) == 3: + word_list = self._split_word(word) + if self._all_tone_three(finals): + # disyllabic + monosyllabic, e.g. 蒙古/包 + if len(word_list[0]) == 2: + finals[0] = finals[0][:-1] + "2" + finals[1] = finals[1][:-1] + "2" + # monosyllabic + disyllabic, e.g. 纸/老虎 + elif len(word_list[0]) == 1: + finals[1] = finals[1][:-1] + "2" + else: + finals_list = [finals[: len(word_list[0])], finals[len(word_list[0]) :]] + if len(finals_list) == 2: + for i, sub in enumerate(finals_list): + # e.g. 所有/人 + if self._all_tone_three(sub) and len(sub) == 2: + finals_list[i][0] = finals_list[i][0][:-1] + "2" + # e.g. 好/喜欢 + elif ( + i == 1 + and not self._all_tone_three(sub) + and finals_list[i][0][-1] == "3" + and finals_list[0][-1][-1] == "3" + ): + finals_list[0][-1] = finals_list[0][-1][:-1] + "2" + finals = sum(finals_list, []) + # split idiom into two words who's length is 2 + elif len(word) == 4: + finals_list = [finals[:2], finals[2:]] + finals = [] + for sub in finals_list: + if self._all_tone_three(sub): + sub[0] = sub[0][:-1] + "2" + finals += sub + + return finals + + def _all_tone_three(self, finals: List[str]) -> bool: + return all(x[-1] == "3" for x in finals) + + # merge "不" and the word behind it + # if don't merge, "不" sometimes appears alone according to jieba, which may occur sandhi error + def _merge_bu(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]: + new_seg = [] + last_word = "" + for word, pos in seg: + if last_word == "不": + word = last_word + word + if word != "不": + new_seg.append((word, pos)) + last_word = word[:] + if last_word == "不": + new_seg.append((last_word, "d")) + last_word = "" + return new_seg + + # function 1: merge "一" and reduplication words in it's left and right, e.g. "听","一","听" ->"听一听" + # function 2: merge single "一" and the word behind it + # if don't merge, "一" sometimes appears alone according to jieba, which may occur sandhi error + # e.g. 
+ # input seg: [('听', 'v'), ('一', 'm'), ('听', 'v')] + # output seg: [['听一听', 'v']] + def _merge_yi(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]: + new_seg = [] + i = 0 + # function 1 + while i < len(seg): + word, pos = seg[i] + merged = False + if i - 1 >= 0 and word == "一" and i + 1 < len(seg): + last = new_seg[-1] if new_seg else seg[i - 1] + if last[0] == seg[i + 1][0] and last[1] == "v" and seg[i + 1][1] == "v": + combined = last[0] + "一" + seg[i + 1][0] + new_seg[-1] = [combined, last[1]] + i += 2 + merged = True + if not merged: + new_seg.append([word, pos]) + i += 1 + seg = new_seg + new_seg = [] + # function 2 + for word, pos in seg: + if new_seg and new_seg[-1][0] == "一": + new_seg[-1][0] = new_seg[-1][0] + word + else: + new_seg.append([word, pos]) + return new_seg + + # the first and the second words are all_tone_three + def _merge_continuous_three_tones(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]: + new_seg = [] + sub_finals_list = [ + lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3) for (word, pos) in seg + ] + assert len(sub_finals_list) == len(seg) + merge_last = [False] * len(seg) + for i, (word, pos) in enumerate(seg): + if ( + i - 1 >= 0 + and self._all_tone_three(sub_finals_list[i - 1]) + and self._all_tone_three(sub_finals_list[i]) + and not merge_last[i - 1] + ): + # if the last word is reduplication, not merge, because reduplication need to be _neural_sandhi + if not self._is_reduplication(seg[i - 1][0]) and len(seg[i - 1][0]) + len(seg[i][0]) <= 3: + new_seg[-1][0] = new_seg[-1][0] + seg[i][0] + merge_last[i] = True + else: + new_seg.append([word, pos]) + else: + new_seg.append([word, pos]) + + return new_seg + + def _is_reduplication(self, word: str) -> bool: + return len(word) == 2 and word[0] == word[1] + + # the last char of first word and the first char of second word is tone_three + def _merge_continuous_three_tones_2(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]: + new_seg = [] + sub_finals_list = [ + lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3) for (word, pos) in seg + ] + assert len(sub_finals_list) == len(seg) + merge_last = [False] * len(seg) + for i, (word, pos) in enumerate(seg): + if ( + i - 1 >= 0 + and sub_finals_list[i - 1][-1][-1] == "3" + and sub_finals_list[i][0][-1] == "3" + and not merge_last[i - 1] + ): + # if the last word is reduplication, not merge, because reduplication need to be _neural_sandhi + if not self._is_reduplication(seg[i - 1][0]) and len(seg[i - 1][0]) + len(seg[i][0]) <= 3: + new_seg[-1][0] = new_seg[-1][0] + seg[i][0] + merge_last[i] = True + else: + new_seg.append([word, pos]) + else: + new_seg.append([word, pos]) + return new_seg + + def _merge_er(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]: + new_seg = [] + for i, (word, pos) in enumerate(seg): + if i - 1 >= 0 and word == "儿" and seg[i - 1][0] != "#": + new_seg[-1][0] = new_seg[-1][0] + seg[i][0] + else: + new_seg.append([word, pos]) + return new_seg + + def _merge_reduplication(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]: + new_seg = [] + for i, (word, pos) in enumerate(seg): + if new_seg and word == new_seg[-1][0]: + new_seg[-1][0] = new_seg[-1][0] + seg[i][0] + else: + new_seg.append([word, pos]) + return new_seg + + def pre_merge_for_modify(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]: + seg = self._merge_bu(seg) + try: + seg = self._merge_yi(seg) + except: + print("_merge_yi failed") + seg = self._merge_reduplication(seg) + try: + 
seg = self._merge_continuous_three_tones(seg) + except: + print("_merge_continuous_three_tones failed") + try: + seg = self._merge_continuous_three_tones_2(seg) + except: + print("_merge_continuous_three_tones_2 failed") + + seg = self._merge_er(seg) + return seg + + def modified_tone(self, word: str, pos: str, finals: List[str]) -> List[str]: + finals = self._bu_sandhi(word, finals) + finals = self._yi_sandhi(word, finals) + finals = self._neural_sandhi(word, pos, finals) + finals = self._three_sandhi(word, finals) + return finals diff --git a/text/zh_normalization/README.md b/text/zh_normalization/README.md new file mode 100644 index 0000000000000000000000000000000000000000..92eea9f54630dfd41dfc3ce53bc511cc7595062c --- /dev/null +++ b/text/zh_normalization/README.md @@ -0,0 +1,16 @@ +## Supported NSW (Non-Standard-Word) Normalization + +|NSW type|raw|normalized| +|:--|:-|:-| +|serial number|电影中梁朝伟扮演的陈永仁的编号27149|电影中梁朝伟扮演的陈永仁的编号二七一四九| +|cardinal|这块黄金重达324.75克
我们班的最高总分为583分|这块黄金重达三百二十四点七五克
我们班的最高总分为五百八十三分| +|numeric range |12\~23
-1.5\~2|十二到二十三
负一点五到二| +|date|她出生于86年8月18日,她弟弟出生于1995年3月1日|她出生于八六年八月十八日, 她弟弟出生于一九九五年三月一日| +|time|等会请在12:05请通知我|等会请在十二点零五分请通知我 +|temperature|今天的最低气温达到-10°C|今天的最低气温达到零下十度 +|fraction|现场有7/12的观众投出了赞成票|现场有十二分之七的观众投出了赞成票| +|percentage|明天有62%的概率降雨|明天有百分之六十二的概率降雨| +|money|随便来几个价格12块5,34.5元,20.1万|随便来几个价格十二块五,三十四点五元,二十点一万| +|telephone|这是固话0421-33441122
这是手机+86 18544139121|这是固话零四二一三三四四一一二二
这是手机八六一八五四四一三九一二一| +## References +[Pull requests #658 of DeepSpeech](https://github.com/PaddlePaddle/DeepSpeech/pull/658/files) diff --git a/text/zh_normalization/__init__.py b/text/zh_normalization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..46b367a68b074cf02da933f1e2433b86eeffe494 --- /dev/null +++ b/text/zh_normalization/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from text.zh_normalization.text_normlization import * diff --git a/text/zh_normalization/char_convert.py b/text/zh_normalization/char_convert.py new file mode 100644 index 0000000000000000000000000000000000000000..5b57ed973cecdaf100aeababb2665215739bae2f --- /dev/null +++ b/text/zh_normalization/char_convert.py @@ -0,0 +1,44 @@ +# coding=utf-8 +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
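# A minimal usage sketch of the NSW normalization documented in the README above.
# Assumptions (not shown in this part of the diff): text.zh_normalization
# re-exports a TextNormalizer class whose normalize() method returns a list of
# normalized sentences, following the upstream PaddleSpeech frontend this
# package is adapted from.
from text.zh_normalization import TextNormalizer

tn = TextNormalizer()
for sentence in tn.normalize("这块黄金重达324.75克"):
    # Expected, per the cardinal row of the README table: 这块黄金重达三百二十四点七五克
    print(sentence)
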
+"""Traditional and simplified Chinese conversion, a simplified character may correspond to multiple traditional characters.""" + +simplified_charcters = "制咖片型超声盘鉴定仔点他命书歌粉巾字帐恤手指记忆棒形转弯沟光○〇㐄㐅㐆㐌㐖毒㐜㐡㐤㐰㐺㑇㑳㒳㒸㔾㗂㗎㝵㞎㞙㞞以㢲㢴㤅㥁㥯㨗㫺㬎㮎㮚㮸㲋㲱㲾㳮涧㵪㶸㷖㷭㹢㹴犬㺢狓㺵碗㽮㿝䍃䔢䖟䖸䗈䗥䗪䝓射䥯䦉䯝鲃鱼䲔䳗鹅䵹鼄䶑一对应映射丁不识下儿子做二休世丘之貉并中台原则串为甚谓干净了百事无成八变五十些人得道鸡升天代如并来去个国政策劲幽灵在欧洲游荡接样萝卜坑侧化传价元论醇共再准刀两断切分耕耘收获钱货物向看旧就绪险刻千金动劳永逸匙零夜半卡通回复返影踪反常态口咬气句话同吐快吹周味呼诺呜品红锅哄而散起唱和问三知生熟团漆黑火糟堆场空块面塌糊涂尘染壁厢夔已足多情露水大早到晚夫妻当关万莫开失古恨套所料既往孔见提师要家主审寸阴难买斗牛小撮部阵局展身层巴掌帆风顺席地带过年计于春头载四季期被蛇怕井绳度愿式份弹顷深前律径心意念差愁孤行俱全房厅交遮打技长把抓死拿眼泪鼻涕钥锁折段抿拍即合扫排掬挥拨拥上入击洞掷揽改故辙败文值名斑方面旁族日秋餐隔雅里终父旦时晌会霎间晃暴寒曝更月望垠际朝夕本正经利杯羹东西板枝独秀根筋杆进条龙服务概模次函数又性程总付步脚印趋登毛拔呵氧氮碳决雌雄波未平派谎言流清楚白准溜烟潭有获闻是处降琴鹤甲病发可拾沙目然了直以相眨穿睹瞥瞬矢的解石鸟神教秉虔诚秘种窝蜂穷窍笑置笔苟勾销抹杀煞等奖箍节吃箭仇双雕诗筹箩筐系列纸级士官统丝毫挂维网尽线微吭响股脑胎脉承腔臂力致效资源址器举功投般说讲规贸易叶障着慎满皆输号木电池衣倾钟高低视仁觉醒览遗角银币触溃九鼎蔽抄出驷马追重语破贫洗贯走路安蹴至几蹶振跃役胆汗较辈轮辞赞退六连遍递边针血锤音错门思闪真倒项栽雾类保护川先惊乍体哄鳞爪鸣滴泡邻域党专鼓作齐炒丑烯亥克内酯冬加奴卯肝炎基尺梁街裤镐客宠庭巳汝昌烷玲磊糖肇酉醛啷青县韪良香骨鲷丂七集河市弦喜嘴张舌堵区工业姊妹星架构巧彩扭歪拼凑余热曜武州爷浮屠美乡老阶树荤素碎落能魄鳃鳗珠丄丅丆万俟丈尚摸母娘量管群亚虎必我堂令申件装伏位博侠义界表女墟台戏臭皮匠胜诸葛亮赛顶倍催请运算包立叉戟离疫苗土史志演围揭瓦晒夷姑婆帝村宝烂尖杉碱屉桌山岔岛由纪峡坝库镇废从德后拗汤治旬食明昧曹朋友框栏极权幂曲归依猫民氟硼氯磷铁江侗自旅法司洋浦梅园温暖湾焦班幸用田略番叠皇炮捶硝苯酸腺苷棱草镜穗跳远索锦纲聚氰胺联店胚膲爱色堇紫罗兰芝茶饭菱云虫藏藩乱叛苏亲债凳学座恐恋柱测肌腹衩锥系貂企乌跪叩军车农题迭都甘油屯奏键短阿姨陪姐只顾茅庐槽驾魂鲜鹿页其菜单乘任供势午齿汉组织吊调泻唇坡城报坟外夸将尉建筑岸岗公床扬新剑升杭林栗校楼标款汽社浣海商馆剧院钢华港机械广媒环球融第医科证券综财乐育游涨犹岭疏瘾睑确兵领导缴肢膛船艾瑟尔苍蔡虞效衫覆访诉课谕议轨述野钩限敌鞋颌颔颚饶首龈站例修凡划垂届属崽颏厨拜挫摆放旋削棋榻槛礼沉注滑营狱画确仪聘花葬诏员跌辖周达酒锚闸陷陆雨雪飞威丌于丹久乏予理评产亢卑亦乎舞己悲矩圆词害志但住佞佳便俗信票案幅翁倦伦假偏倚斜亏鬼敲停备伤脾胃仅此像俭匮免宜穴焉戴兼容许冻伯仲负彼昼皂轩轾实刊划颠卫战哥比省非好黄饰别拘束掩奶睬选择摇扰烦苦枚写协厌及格受欢迎约只估侵犯割状告或缺抗拒挽撤救药喻磨灭端倪少逆逾越避靠适吉誉吝玉含延咎歹听啻渊善谋均匀堪忍够太惹妙妥妨孕症孝术室完纳推冠积宣疑辩栗碴称屈挠屑干涉衡待很忙恶忿怎么怠急耻恭息悦惑惜惟想愉愧怍慌愤启懂懈怀材才紧招认扣抵拉舍也罢插揣冒搭撞南墙扩核支攻敢雷攀敬里吗需景智暇曾罪遇朽枉止况竞争辱求愈渝溶济左右袒困补爽特寂寞示弱找谢畏强疾徐痛痒冤符眠睦瞅董何厚云措活疲羞者轻玻璃祥兆禁移稂莠稳佛换答简结果盟绝缕途给谈否羁翼耐肖胫毋宁兴舒若菲莱痕迹窠臼虚衰脸兔撒鹰棺范该详讳抬泰让须眉象众赀账费灰赖奇虑训辍辨菽麦辛近送透逞徒速续逮捕遂遑违逊斧钺艰醉锈随观弃显饱脂肪使丏丐帮丒且慢末丕替桃宗王尊凉爵各图屋脊粮署录坛吾禄职胄袭君厦丗北壑桐疹损逢陵鹬丙寅戌氨腈唑纶辰酮脱氢酶醚丞丢现掉纱帽弄扯炮碗丠両丣坐存激肩臻蒂莲悖序驱丨丩丫挺杈髻鬟细介俄伊犁京尼布订普渡央委监察检查剂圈设警队斯督剩震境航舶革防托播促质版蝾螈锋研艺历残消频谱精密制造陲邮候埔坚压坜凹汇执府究邦俘摄寮彬狼岳肺肿庸英讯诊埋粒胞括控码韩暑枪枢砥澳哇牟寿甸钻探篇签缀缝继耳肯照妇埃悬璧轴柜台辣搁浅邪跑纤阮阳私囊魔丮丰姿采丱烧丳丵丶丷丸参寨朗桂瑞砂衷霞貌凤仆舰因嫌宰峰干络牌持旨祭祷簿编罚宾办丼丿乀乂乃乄仰慕盛旷留考验阔乆乇么丑麽乊湖燃乑乒乓乕乖僻忤戾离谬迕乗危肥劫除隙浪婿乙炔肠酰吡咯盐乚乛乜嘢卿玄宫尾狐龟塔嶷兄弟泉章霄钉耙乞扎哀怜恕讨乢乣乤乥乧乨乩童乪乫乭乳晕汁液瑶浆牙癌突窦罩腐胶猪酪蛋糕菌瘤乴乵乶乷乸乹乺乼乾俸冰嘉哕嚎坤妈尸垒旱枯涸俐渴潮涩煸豆燥爹瘦瘪癣瞪袋脆姜贝隆馏乿亀亁叫咕攘扔搞男砸窜蓬麻亃亄亅却亇迟典今临繁累卵奉婚聪躬巨与迁添裂副宿岁怪恶尕仑愣杆硅硫钛铀锰芑杂异钠砷胂磺琥珀舱棍簧胡茬盗浩盆贩郎腿亍洪亐互欠助勉惠操斥诿系户译亓墓碑刑铃卅渠缤纷斗米旗宪钒灯徽瘟祖拳福谷丰脏腑绑肉腌苓蕴桥铺霸颜闹判喷冈底蛙陉矿亖亘亜罕们娜桑那努哈喀弗烈曼松森杜氏杯奥琛敦戊穆圣裔汇薛孙亟亡佚虏羊牢奋释卷卸契媾感额睫缠谊趾塞挤纽阻还配驰庄亨洛祚亪享津沪畿郊慈菴枇杷膏亭阁锃丽亳亶亹诛初责翻疯偶杰丛稠妖拖寰居吸授慧蜗吞壮魅狗矛盾益渣患忧稀描猿梦暂涯畜祸缘沸搜引擎臣横纭谁混援蒸兽狮税剖亻亼亽亡什献刹邡么仂仃仄仆富怨仈仉毕昔晨壳绍仍仏仒仕宦仗欺恃腰叹叹炬梓讫施仙后琼逝仚仝仞仟悔仡佬偿填泊拓扑簇羔购顿钦佩发棻阃驭养亿儆尤借帧赈凌叙帖李柔刚沃眦睚戒讹取飨读仨仫仮著泳卧躺韶夏裁仳仵唯贤凭钓诞仿似宋佛讽伀硕盼鹅伄儅伈伉俪柯始娃迈戈坦堡帕茨萨庙玛莉莎藤霍姆伋伍奢胥廷芳豪伎俩侍汛勒希羲雏伐憩整谟闲闲伕伙伴颐伜伝伢叔恒兹恩翰伱伲侣伶俜悧鼬伸懒缩喇叭伹伺伻伽倻辐伾似佃伫布乔妮墨佉卢佌贷劣廉昂档浓矮伞洼缓耗胸谷迷挡率龋宅沫舍疗佐贰佑占优据铧尝呢须鲁晓佗佘余坪寺瓜铳僧蒙芒陀龛哼呕坊奸孽弊揖祟茧缚誓贼佝偻瞀佟你夺赶佡佢佣佤佧贾佪佫佯佰佱洁绩酿肴佴卷佶佷佸佹佺佻佼佽佾具唤窘坏娱怒慨硬习惯聋膨胀蔓骇贵痹侀侁侂侃侄侅鸿燕侇侈糜靡侉侌妾侏儒仓鼠侐侑侔仑侘侚链侜偎傍钴循柳葫芦附価侮骂蔑侯岩截蚀局贴壶嬛宴捷携桶笺酌俣狭膝狄俅俉俊俏俎俑俓俔谚俚俛黎健呈固墒增守康箱湿祐镖镳杠盒靖膜龄俞豹猎噪孚封札筒托衍鸽剪撰稿炼厂禊练缮葺俯瞰撑冲效俳俴俵俶俷俺备俾伥倂倅储卒惶敷猝逃颉蓄崇隐倌倏忽刺蜡烛噍嚼坍扁抽毙葱楣灌灶粪背薮卖赔闭霉腾倓倔幸倘倜傥倝借箸挹浇阅倡狂倢倣値倥偬倨傲倩匡嗣冲柝珍倬倭寇猩倮倶倷倹勤赞偁偃充伪吏嗓寐惺扮拱芫茜藉虢钞偈伟晶偌宕距析滤殿疼瘫注颇偓偕鸭歇滞偝偟偢忘怡旺偨偩逼偫偭偯偰偱偲侦缉蹄偷减惰漏窥窃偸偺迹傀儡傅傈僳骂篱傎奎琳迪叟芭傒傔傕伧悉荒傜傞傢傣芽逼佣婢傮睨寄檄诵谣颂伛担辜弓惨蒿悼疤傺傻屄臆巢泄箧羡盖轧颓傿㑩僄僇佥僊働僎侨僔僖僚僝伪僣僤侥僦猴偾僩僬僭僮僯僰雇僵殖签静僾僿征陇儁侬儃儇侩朴薄儊儋儌儍傧儓俦侪拟尽儜儞儤儦儩汰哉寡渥裕酷儭儱罐儳儵儹傩俨儽兀臬臲鹫允勋勋宙宵帅憝彝谐嫂阋畅沛溢盈饥赫凶悍狠猛顽愚妣斩秦遣鞭耀敏荣槃泽爆碟磁秃缆辉霁卤朵娄孜烽酱勃汀箕裘钳耶蒙蕾彻兑软遭黜兎児韵媳爸兕觥兖兙兛兜售鍪肚兝兞兟兡兢兣樽殓涅睡禀籍赘泌啡肽奸幕涵涝熵疚眷稃衬讧赴焕椒歼植跏没试误猜栖窗肋袖颊兪卦撇胡岐廓轿疸枫茴珑厕秩募勺吨寓斤历亩迫筷厘最淫螺韬兮宽匪筛襄赢轭复兲诈刃堰戎痞蚁饷它冀铸冂冃円冇冉册嫁厉砺竭醮冏牧冑冓冔冕冖冗冘冞冢窄抑诬冥冫烘菇蛰冷凝坨橇淇淋炭饼砖碛窖醋雕雹霜冱冶炉艳嘲峻滩淡漠煖飕饮冼冽凃凄怆梗凅凇净凊凋敝蒙凔凛遵汞脢凞几凢処凰凯凵凶焰凸折刷纹预丧喽奔巡榜殡芙蓉租笼辑鞘萃凼锯镬刁蛮刂娩崩批拆摊掰蘖骤歧颗秒袂赃勿嘱忌磋琢肤刈羽刎讼戮舂桨艇刓刖霹雳刜创犊刡恙墅帜筵致劫劫刨昏默攸尿欲熏润薰圭删刮痧铲刱刲刳刴刵踏磅戳柏槐绣芹苋猬舟铭鹄鹜劫剁剃辫刭锉履铅克剌姻咽哨廊掠桅沿召瞻翅赵卜渺茫郭剒剔剕沥剚愎毅讷才剜剥啄采剞剟剡剣剤䌽剐肾驶黏剰袍剀紊铲剸剺剽剿劁劂札劈啪柴扳啦刘奭姥夼昫涓熙禅禹锡翔雁鹗刽刿弩柄蜻蛉劒劓劖劘劙澜篑赏矶釜晋甜薪逐劦熔纣虐赤囚劬劭労劵效劻劼劾峭艮勅勇励勍勐腊脖庞漫饲荡粥辄勖勗勘骄馁碌泮雇捐竹骑殊阱绩朴恳谨剿勧勩勯勰劢勋勷劝惩慰诫谏勹芡践阑匁庇拯粟扎袱裹饺匆遽匈匉匊匋匍匐茎匏匕妆痰脓蛹斋苑烤蹈塘羌熊阀螳螂疆碚竿纬荷茵邙魏匚匜匝匟扶稷匣匦拢匸匹耦匽匾匿
卂叮疮禧轸堤棚迢钧炼卄卆遐卉瓷盲瓶当胱腱裸卋卌卍卐怯污贱鄙龌龊陋卓溪唐梯渔陈枣泥漳浔涧梨芬谯赡辕迦郑単驴弈洽鳌卛占筮卝卞卟吩啉屎翠厄卣卨卪卬卮榫袄玺绶钮蚤惧殆笃耸卲帘帙绕恤卼卽厂厎厓厔厖厗奚厘厍厜厝谅厕厤厥厪腻孢厮厰厳厣厹厺粕垢芜菁厼厾叁悟茸薯叄吵笄悌哺讥坫垄弧芯杠潜婴刍袁诘贪谍煽馈驳収岳缔灾贿骗叚叡吻拦蘑蜜诀燧玩砚筝椎蔺铜逗骊另觅叨唠谒杵姓喊嚷嚣咚咛塑寻恼憎擦只泣渗蝠叱吒咄咤喝籀黛舵舷叵叶铎懿昭穰苴辽叻叼吁堑嫖赌瞧爬众抒吅吆夥卺橡涤抱纵摩郡唁坠扇篮膀袜颈吋忾谘酬哭妓媛暗表缰迩妃羿絮蕃浑拐葵暮隅吔吖啶嗪戚吜啬噬咽吟哦咏吠吧唧嗒咐吪隽咀征燐苞茹钙哧吮吰吱嘎吲哚吴栋娇窟孟箫忠晗淞阖闾趼宇呐睛嘘拂捧疵熄竽笛糠吼吽呀吕韦蒙呃呆笨呇贡呉罄呋喃呎呏呔呠呡痴呣呤呦呧瑛眩扒晬淑姬瑜璇鹃呪呫哔嚅嗫呬呯呰呱呲咧噌钝呴呶呷呸呺呻哱咻啸噜吁坎坷逻呿咁咂咆哮咇咈咋蟹煦珅蔼咍咑咒诅咔哒嚓咾哝哩喱咗咠咡咢咣咥咦咨嗟询咩咪咫啮啮咭咮咱咲咳呛嗽咴啕咸咹咺呙喉咿婉恸悯赋矜绿茗蓝哂抢瞒哆嗦啰噻啾滨彗哋哌哎唷哟哏哐哞哢哤哪里哫啼喘哰哲萎蚌哳咩哽哿呗唅唆唈唉唎唏哗尧棣殇璜睿肃唔睇唕吣唞唣喳唪唬唰喏唲唳唵嘛唶唸唹唻唼唾唿啁啃鹦鹉啅埠栈榷祺铺鞅飙啊啍啎啐啓啕啖啗啜哑祈啢衔啤啥啫啱啲啵啺饥啽噶昆沁喁喂喆裙喈咙喋喌喎喑喒喓喔粗喙幛庆滋鹊喟喣喤喥喦喧骚喨喩梆吃葡萄喭驼挑吓碰枞瓣纯疱藻趟铬喵営喹喺喼喿嗀嗃嗄嗅嗈嗉嗊嗍嗐嗑嗔诟嗕嗖嗙嗛嗜痂癖嗝嗡嗤嗥嗨唢嗬嗯嗰嗲嗵叽嗷嗹嗾嗿嘀嘁嘂嘅惋嘈峪禾荫啀嘌嘏嘐嘒啯啧嘚唛嘞嘟囔嘣嘥嘦嘧嘬嘭这谑严敞馋松哓嘶嗥呒虾嘹嘻啴嘿噀噂噅噇噉噎噏噔噗噘噙噚咝噞噢噤蝉皿噩噫噭嗳噱哙噳嚏涌洒欲巫霏噷噼嚃嚄嚆抖哜尝嚔苏嚚嚜嚞嚟呖嚬嚭嚮嚯亸喾饬按竣苛嚵嘤啭冁呓膪谦囍囒囓囗囘萧酚飘溅谛囝溯眸纥銮鹘囟殉囡団囤囥囧囨囱囫囵囬囮囯囲図囶囷囸囹圄圉拟囻囿圀圂圃圊粹蠹赦圌垦圏滚鲱凿枘圕圛圜圞坯埂壤骸炕祠窑豚绅魠鲮鳖圧握圩圪垯圬圮圯炸岬幔毯祇窨菩溉圳圴圻圾坂坆沾坋坌舛壈昆垫墩椅坒坓坩埚坭坰坱坳坴坵坻坼杨挣涎帘垃垈垌垍垓垔垕垗垚垛垝垣垞垟垤垧垮垵垺垾垿埀畔埄埆埇埈埌殃隍埏埒埕埗埜垭埤埦埧埭埯埰埲埳埴埵埶绋埸培怖桩础辅埼埽堀诃侄庑堃堄摧磐贞韧砌堈堉垩堋堌堍堎垴堙堞堠礁堧堨舆堭堮蜓摘堲堳堽堿塁塄塈煤茔棵塍垲埘塓绸塕鸦沽虱塙冢塝缪塡坞埙塥塩塬塱场螨塼塽塾塿墀墁墈墉墐夯増毁墝墠墦渍钵墫墬堕墰墺墙橱壅壆壊壌壎壒榨蒜壔壕壖圹垆壜壝垅壡壬壭壱売壴壹壻壸寝壿夂夅夆変夊夌漱邑夓腕泄甥御骼夗夘夙衮瑙妊娠醣枭珊莺鹭戗幻魇夤蹀秘擂鸫姚宛闺屿庾挞拇賛蛤裨菠氅漓捞湄蚊霆鲨箐篆篷荆肆舅荔鲆巷惭骰辟邱镕镰阪漂烩鲵鲽鳄鸨胪鹏妒峨谭枰晏玑癸祝秤竺牡籁恢罡蝼蝎赐绒御梭夬夭砣榆怙枕夶夹馅奄崛葩谲奈贺祀赠奌奂奓奕䜣詝奘奜奠奡奣陶奨奁魁奫奬奰娲孩贬隶酥宄狡猾她姹嫣妁毡荼皋膻蝇嫔妄妍嫉媚娆妗趣妚妞妤碍妬娅妯娌妲妳妵妺姁姅姉姗姒姘姙姜姝姞姣姤姧姫姮娥姱姸姺姽婀娀诱慑胁娉婷娑娓娟娣娭娯娵娶娸娼婊婐婕婞婤婥溪孺婧婪婬婹婺婼婽媁媄媊媕媞媟媠媢媬媮妫媲媵媸媺媻媪眯媿嫄嫈袅嫏嫕妪嫘嫚嫜嫠嫡嫦嫩嫪毐嫫嫬嫰妩嫺娴嫽嫿妫嬃嬅嬉耍婵痴艳嬔嬖嬗嫱袅嫒嬢嬷嬦嬬嬭幼嬲嬴婶嬹嬾嬿孀娘孅娈孏曰癫屏孑孓雀孖斟篓谜摺孛矻鸠崮轲祜鸾孥邈毓棠膑孬孭孰孱孳孵泛罔衔孻孪宀宁冗拙株薇掣抚琪瓿榴谧弥宊濂祁瑕宍宏碁宓邸谳実潢町宥宧宨宬徵崎骏掖阙臊煮禽蚕宸豫寀寁寥寃檐庶寎暄碜寔寖寘寙寛寠苫寤肘洱滥蒗陕核寪弘绰螽宝擅疙瘩晷対檐専尃尅赎绌缭畴衅尌峙醌襟痲碧屁昊槌淘恵瀑牝畑莓缸羚觑蔻脏躁尔尓锐尗尙尜尟尢尥尨尪尬尭尰擒尲尶尴尸尹潽蠖蛾尻扣梢蚴鳍脬蹲屇屌蚵屐屃挪屖屘屙屛屝屡屣峦嶂岩舄屧屦屩屪屃屮戍驻钾崖嵛巅旮旯楂榄榉芋茱萸靛麓屴屹屺屼岀岊岌岍阜岑彭巩岒岝岢岚岣岧岨岫岱岵岷峁峇峋峒峓峞峠嵋峨峰峱岘峹峿崀崁崆祯崋崌崃岖昆崒崔嵬巍萤颢崚崞崟崠峥巆崤崦崧殂岽崱崳崴崶崿嵂嵇嵊泗嵌嵎嵒嵓岁嵙嵞嵡嵩嵫嵯嵴嵼嵾嵝崭崭晴嶋嶌嶒嶓嵚崂嶙嶝嶞峤嶡嶢峄嶨嶭嶮嶰嶲岙嵘巂巃巇巉岿巌巓巘巛滇芎巟巠弋回巣巤炊擘蜥蟒蛊觋巰蜀彦淖杏茂甫楞巻巽帼巿帛斐鲫蕊帑帔帗帚琉汶帟帡帣帨裙帯帰帷帹暆帏幄帮幋幌幏帻幙帮幞幠幡幢幦幨幩幪帱幭幯幰遥蹉跎馀庚鉴幵幷稚邃庀庁広庄庈庉笠庋跋庖牺庠庤庥鲸庬庱庳庴庵馨衢庹庿廃厩廆廋廌廎廏廐廑廒荫廖廛厮搏锣廞弛袤廥廧廨廪廱绵踵髓廸迫瓯邺廻廼廾廿躔弁皱弇弌弍弎弐弑吊诡憾荐弝弢弣弤弨弭弮弰弪霖繇焘斌旭溥骞弶弸弼弾彀彄别累纠强彔彖彘彟彟陌彤贻彧绘虹彪炳雕蔚鸥彰瘅彲彳彴仿彷徉徨彸彽踩敛旆徂徇徊渭畲铉裼従筌徘徙徜徕膳苏萌渐徬徭醺徯徳徴潘徻徼忀瘁胖燎怦悸颤扉犀澎湃砰恍惚绞隘忉惮挨饿忐忑忒忖応忝忞耿忡忪忭忮忱忸怩忻悠懑怏遏怔怗怚怛怞怼黍讶怫怭懦怱怲恍怵惕怸怹恁恂恇恉恌恏恒恓恔恘恚恛恝恞恟恠恣恧眄恪恫恬澹恰恿悀悁悃悄悆悊悐悒晦悚悛悜悝悤您悩悪悮悰悱凄恻德悴怅惘闷悻悾惄愫钟蒐惆惇惌惎惏惓惔惙惛耄惝疟浊恿惦德恽惴蠢惸拈愀愃愆愈愊愍愐愑愒愓愔愕恪氓蠢騃昵惬赧悫愬愮愯恺愼慁恿慅慆慇霭慉慊愠慝慥怄怂慬慱悭慴慵慷戚焚憀灼郁憃惫憋憍眺捏轼愦憔憖憙憧憬憨憪憭怃憯憷憸憹憺懃懅懆邀懊懋怿懔懐懞懠懤懥恹懫懮懰懱毖懵遁梁雍忏懽戁戄戆戉戋戕戛戝戛戠戡戢戣戤戥戦戬戭戯轰戱披菊牖戸戹戺戻卯戽锹扂楔扃扆扈扊杖牵绢铐镯赉扐搂搅烊盹瞌跟趸镲靶鼾払扗玫腮扛扞扠扡扢盔押扤扦扱罾揄绥鞍郤窾扻扼扽抃抆抈抉抌抏瞎抔缳缢擞抜拗択抨摔歉蹿牾抶抻搐泵菸拃拄拊髀抛拌脯拎拏拑擢秧沓曳挛迂拚拝拠拡拫拭拮踢拴拶拷攒拽掇芥橐簪摹疔挈瓢骥捺蹻挌挍挎挐拣挓挖掘浚挙揍聩挲挶挟挿捂捃捄捅捆捉捋胳膊揎捌捍捎躯蛛捗捘捙捜捥捩扪捭据捱捻捼捽掀掂抡臀膘掊掎掏掐笙掔掗掞棉芍掤搪阐掫掮掯揉掱掲掽掾揃揅揆搓揌诨揕揗揘揜揝揞揠揥揩揪揫橥遒麈揰揲揵揶揸背揺搆搉搊搋搌搎搔搕撼橹捣搘搠搡搢搣搤搥搦搧搨搬楦裢讪赸掏搰搲搳搴揾搷搽搾搿摀摁摂摃摎掴摒摓跤摙摛掼摞摠摦喉羯摭摮挚摰摲抠摴抟摷掺摽撂撃撅稻撊撋挦锏泼撕撙撚㧑挢撢掸撦撅撩撬撱朔揿蚍蜉挝捡擀掳闯擉缶觚擐擕擖擗擡擣擤澡腚擧擨擩擫擭摈拧撷擸撸擽擿攃摅撵攉攥攐攓撄搀撺每攩攫辔澄攮攰攲攴轶攷砭讦攽碘敁敃敇敉叙敎筏敔敕敖闰诲敜煌敧敪敳敹敺敻敿斁衽斄牒绉诌斉斎斓鹑谰驳鳢斒筲斛斝斞斠斡斢斨斫斮晾沂潟颖绛邵斲斸釳於琅斾斿旀旗旃旄涡旌旎旐旒旓旖旛旝旟旡旣浴旰獭魃旴时旻旼旽昀昃昄昇昉晰躲澈熹皎皓矾昑昕昜昝昞昡昤晖笋昦昨是昱昳昴昶昺昻晁蹇隧蔬髦晄晅晒晛晜晞晟晡晢晤晥曦晩萘莹顗晿暁暋暌暍暐暔暕煅旸暝暠暡曚暦暨暪朦胧昵暲殄冯暵暸暹暻暾曀晔昙曈曌曏曐暧曘曙曛叠昽曩骆曱甴肱曷牍禺锟曽沧耽朁朅朆杪栓夸竟粘绦朊膺朏朐朓朕朘朙瞄觐溘饔飧朠朢朣栅椆淀虱朩朮朰朱炆璋钰炽鹮朳槿朵朾朿杅杇杌陧欣钊湛漼楷瀍煜玟缨翱肇舜贽适逵杓杕杗杙荀蘅杝杞脩珓筊杰榔狍閦颦缅莞杲杳眇杴杶杸杻杼枋枌枒枓衾葄翘纾逋枙狸桠枟槁枲枳枴枵枷枸橼枹枻柁柂柃柅柈柊柎某柑橘柒柘柙柚柜柞栎柟柢柣柤柩柬柮柰柲橙柶柷柸柺査柿栃栄栒栔栘栝栟柏栩栫栭栱栲栳栴檀栵栻桀骜桁镁桄桉桋桎梏椹葚桓桔桕桜桟桫椤桭杯桯桲桴桷桹湘溟梃梊梍梐潼栀枧梜梠梡梣梧梩梱梲梳梴梵梹棁棃樱棐棑棕榈簑绷蓑枨棘棜棨棩棪棫棬棯棰棱棳棸棹椁棼碗椄苕椈椊椋椌椐椑椓椗検椤椪椰椳椴椵椷椸椽椿楀匾楅篪楋楍楎楗楘楙楛楝楟楠楢楥桢楩楪楫楬楮楯楰梅楸楹楻楽榀榃榊榎槺榕榖榘榛狉莽搒笞榠榡榤榥榦榧杩榭榰榱梿霰榼榾桤槊闩槎槑槔槖様槜槢槥椠槪槭椮槱槲槻槼槾樆樊樏樑樕樗樘樛樟樠樧樨権樲樴樵猢狲桦樻罍樾樿橁橄橆桡笥龠橕橚橛辆椭橤橧竖膈跨橾橿檩檃檇柽檍檎檑檖檗桧槚檠樯檨檫檬梼槟檴檵柠棹櫆櫌栉櫜椟櫡槠栌枥榇栊櫹棂茄櫽欀欂欃欐欑栾欙棂溴欨欬欱欵欶欷歔欸欹欻欼欿歁歃歆艎歈歊莳蝶歓歕歘歙歛歜欤歠蹦诠镶蹒跚升陟歩歮歯歰歳歴璞歺瞑歾殁夭殈殍殑殗殜殙殛殒殢殣殥殪殚僵殰殳荃殷殸殹蛟殻肴谤殴毈毉喂毎毑蕈毗毘毚茛邓毧毬毳毷毹毽毾毵牦氄氆靴氉氊氇氍氐聊氕氖気氘氙氚氛氜氝氡汹焊痉氤氲氥氦铝锌氪烃氩铵痤汪浒漉痘盂碾菖蒲蕹蛭螅氵冰氹氺氽烫氾氿渚汆汊汋汍汎汏汐汔汕褟汙汚汜蓠沼秽蔑汧汨汩汭汲汳汴堤汾沄沅沆瀣沇沈葆浸沦湎溺痼疴沌沍沏沐沔沕沘浜畹砾沚沢沬沭沮沰沱灢沴沷籽沺烹濡洄泂肛泅泆涌肓泐泑泒泓泔泖泙泚泜泝泠漩馍涛粼泞藓鳅泩泫泭泯铢泱泲洇洊泾琵琶荽蓟箔洌洎洏洑潄濯洙洚洟洢洣洧洨洩痢滔洫洮洳洴洵洸洹洺洼洿淌蜚浄浉浙赣渫浠浡浤浥淼瀚浬浭翩萍浯浰蜃淀苔蛞蝓蜇螵蛸煲鲤浃浼浽溦涂涊涐涑涒涔滂莅涘涙涪涫涬涮涴涶涷涿淄淅淆淊凄黯淓淙涟淜淝淟淠淢淤渌淦淩猥藿亵淬淮淯淰淳诣涞纺淸淹炖癯绮渇済渉渋渓渕涣渟渢滓渤澥渧渨渮渰渲渶渼湅湉湋湍湑湓湔黔湜湝浈湟湢湣湩湫湮麟湱湲湴涅満沩溍溎溏溛舐漭溠溤溧驯溮溱溲溳溵溷溻溼溽溾滁滃滉滊荥滏稽滕滘汇滝滫滮羼耷卤滹浐煎漈漊漎绎漕漖漘漙沤漜漪漾漥漦漯漰溆漶漷濞潀颍潎潏潕潗潚潝潞潠潦祉疡潲
潵滗潸潺潾涠澁澂澃澉澌澍澐澒澔澙渑澣澦澧澨澫澬浍澰澴澶澼熏郁濆濇濈濉濊貊濔疣濜濠濩觞浚濮盥潍濲泺瀁滢渎渖瀌浏瀒瀔濒泸瀛潇潆瀡潴泷濑瀬弥潋瀳瀵瀹瀺瀼沣滠灉灋灒漓灖灏灞灠滦灥灨滟灪蜴灮烬獴灴灸灺炁炅鱿炗炘炙炤炫疽烙钎炯炰炱炲炴炷毁炻烀烋瘴鲳烓烔焙烜烝烳饪烺焃焄耆焌焐焓焗焜焞焠焢焮焯焱焼煁煃煆煇煊熠煍熬煐炜煕暖熏硷霾煚煝煟煠茕矸煨琐炀萁煳煺煻熀熅熇熉罴荧穹炝熘熛熜稔谙烁熤熨熯熰眶蚂颎熳熸熿燀烨燂燄盏燊燋燏燔隼燖焖燠燡灿燨燮燹燻燽燿爇爊爓爚爝爟爨蟾爯爰为爻丬爿牀牁牂牄牋窗牏牓窗釉牚腩蒡虻牠虽蛎牣牤牮牯牲牳牴牷牸牼绊牿靬犂犄犆犇犉犍犎犒荦犗犛犟犠犨犩犪犮犰狳犴犵犺狁甩狃狆狎狒獾狘狙黠狨狩狫狴狷狺狻豕狈蜘猁猇猈猊猋猓猖獗猗猘狰狞犸猞猟獕猭猱猲猳猷猸猹猺玃獀獃獉獍獏獐獒毙獙獚獜獝獞獠獢獣獧鼇蹊狯猃獬豸狝獯鬻獳犷猕猡玁菟玅玆玈珉糁禛郅玍玎玓瓅玔玕玖玗玘玞玠玡玢玤玥玦珏瑰玭玳瑁玶玷玹玼珂珇珈瑚珌馐馔珔珖珙珛珞珡珣珥珧珩珪佩珶珷珺珽琀琁陨玡琇琖琚琠琤琦琨琫琬琭琮琯琰琱琲琅琴珐珲瑀瑂瑄瑉玮瑑瑔瑗瑢瑭瑱瑲瑳瑽瑾瑿璀璨璁璅璆璈琏璊璐璘璚璝璟璠璡璥瑷璩璪璫璯璲玙璸璺璿瓀璎瓖瓘瓒瓛脐瓞瓠瓤瓧瓩瓮瓰瓱瓴瓸瓻瓼甀甁甃甄甇甋甍甎甏甑甒甓甔瓮甖甗饴蔗甙诧钜粱盎锈团甡褥産甪甬甭甮宁铠甹甽甾甿畀畁畇畈畊畋畎畓畚畛畟鄂畤畦畧荻畯畳畵畷畸畽畾疃叠疋疍疎箪疐疒疕疘疝疢疥疧疳疶疿痁痄痊痌痍痏痐痒痔痗瘢痚痠痡痣痦痩痭痯痱痳痵痻痿瘀痖瘃瘈瘉瘊瘌瘏瘐痪瘕瘖瘙瘚瘛疭瘜瘝瘗瘠瘥瘨瘭瘆瘯瘰疬瘳疠瘵瘸瘺瘘瘼癃痨痫癈癎癐癔癙癜癠疖症癞蟆癪瘿痈発踔绀蔫酵皙砬砒翎翳蔹钨镴皑鹎驹暨粤褶皀皁荚皃镈皈皌皋皒朱皕皖皘皜皝皞皤皦皨皪皫皭糙绽皴皲皻皽盅盋碗盍盚盝踞盦盩秋千盬盭眦睁瞤盯盱眙裰盵盻睐眂眅眈眊県眑眕眚眛眞眢眣眭眳眴眵眹瞓眽郛睃睅睆睊睍睎困睒睖睙睟睠睢睥睪睾睯睽睾眯瞈瞋瞍逛瞏瞕瞖眍䁖瞟瞠瞢瞫瞭瞳瞵瞷瞹瞽阇瞿眬矉矍铄矔矗矙瞩矞矟矠矣矧矬矫矰矱硪碇磙罅舫阡、矼矽礓砃砅砆砉砍砑砕砝砟砠砢砦砧砩砫砮砳艏砵砹砼硇硌硍硎硏硐硒硜硖砗磲茚钡硭硻硾碃碉碏碣碓碔碞碡碪碫碬砀碯碲砜碻礴磈磉磎硙磔磕磖磛磟磠磡磤磥蹭磪磬磴磵磹磻硗礀硚礅礌礐礚礜礞礤礧礮砻礲礵礽礿祂祄祅祆禳祊祍祏祓祔祕祗祘祛祧祫祲祻祼饵脔锢禂禇禋祦禔祎隋禖禘禚禜禝禠祃禢禤禥禨禫祢禴禸秆秈秊闱飒秋秏秕笈蘵赁秠秣秪秫秬秭秷秸稊稌稍稑稗稙稛稞稬秸稲稹稼颡稿穂穄穇穈穉穋稣贮穏穜穟秾穑穣穤穧穨穭穮穵穸窿阒窀窂窅窆窈窕窊窋窌窒窗窔窞窣窬黩蹙窑窳窴窵窭窸窗竁竃竈竑竜并竦竖篦篾笆鲛竾笉笊笎笏笐靥笓笤箓笪笫笭笮笰笱笲笳笵笸笻筀筅筇筈筎筑筘筠筤筥筦笕筒筭箸筰筱筳筴宴筸箂个箊箎箑箒箘箙箛箜篌箝箠箬镞箯箴箾篁筼筜篘篙篚篛篜篝篟篠篡篢篥篧篨篭篰篲筚篴篶篹篼箦簁簃簆簉簋簌簏簜簟簠簥簦簨簬簰簸簻籊藤籒籓籔签籚篯箨籣籥籧笾簖籫籯芾麴籵籸籹籼粁秕粋粑粔粝粛粞粢粧粨粲粳稗粻粽辟粿糅糆糈糌糍糒糔萼糗蛆蹋糢糨糬粽糯糱籴粜糸糺紃蹼鲣霉纡纨绔纫闽襻紑纰纮锭鸢鹞纴紞紟扎紩紬绂绁纻紽紾绐絁絃絅経絍绗絏缡褵絓絖絘絜绚絣螯絪絫聒絰絵绝絺絻絿綀绡綅绠绨绣綌綍綎捆綖綘継続缎绻綦綪线綮綯绾罟蝽綷縩绺绫緁绲緅緆缁绯緌緎総緑绱緖缃缄缂绵缗緤褓缌纂緪緰缑缈缏缇縁縃縄萦缙缒縏缣縕缞縚缜缟缛縠縡縢縦绦縯縰骋缧縳纤缦絷缥縻衙縿繄缫繈繊繋繐缯繖繘繙繠缋繣繨缰缲繸繻缱纁纆纇缬缵纩纑纕缵纙纚纛缾罃罆坛罋罂罎罏罖罘罛罝罠罣罥罦罨罫罭锾罳罶罹罻罽罿羂羃羇芈蕉51鸵羑羖羌羜羝羢羣羟羧羭羮羰羱羵羶羸藜鲐翀翃翅翊翌翏翕翛翟翡翣翥翦跹翪翫翚翮翯翱翽翾翿板饕鸹锨耋耇耎耏专耒耜耔耞耡耤耨耩耪耧耰鬓耵聍聃聆聎聝聡聦聱聴聂聼阈聿肄肏肐肕腋肙肜肟肧胛肫肬肭肰肴肵肸肼胊胍胏胑胔胗胙胝胠铨胤胦胩胬胭胯胰胲胴胹胻胼胾脇脘脝脞脡脣脤脥脧脰脲脳腆腊腌臜腍腒腓胨腜腠脶腥腧腬腯踝蹬镣腴腶蠕诽膂腽嗉膇膋膔腘膗膙膟黐膣膦膫膰膴膵膷脍臃臄臇臈臌臐臑臓膘臖臙臛臝臞臧蓐诩臽臾臿舀舁鳑鲏舋舎舔舗馆舝舠舡舢舨舭舲舳舴舸舺艁艄艅艉艋艑艕艖艗艘艚艜艟艣舣艨艩舻艬艭荏艴艳艸艹艻艿芃芄芊萰陂藭芏芔芘芚蕙芟芣芤茉芧芨芩芪芮芰鲢芴芷芸荛豢芼芿苄苒苘苙苜蓿苠苡苣荬苤苎苪镑苶苹苺苻苾茀茁范蠡萣茆茇茈茌茍茖茞茠茢茥茦菰茭茯茳藨茷藘茼荁荄荅荇荈菅蜢鸮荍荑荘豆荵荸荠莆莒莔莕莘莙莚莛莜莝莦莨菪莩莪莭莰莿菀菆菉菎菏菐菑菓菔芲菘菝菡菢菣菥蓂菧菫毂蓥菶菷菹醢菺菻菼菾萅萆苌萋萏萐萑萜萩萱萴莴扁萻葇葍葎葑荭葖葙葠葥苇葧葭药葳葴葶葸葹葽蒄蒎莼茏薹莅蒟蒻蒢蒦蒨蒭藁蒯蒱鉾蒴蒹蒺蒽荪蓁蓆蓇蓊蓌蓍蓏蓓蓖蓧蓪蓫荜跣藕苁蓰蓱莼蓷蓺蓼蔀蔂蔃蔆蔇蔉蔊蔋蔌蔎蔕蔘蔙蒌蔟锷蒋雯茑蔯蔳麻蔵蔸蔾荨蒇蕋蕍荞蕐蕑芸莸蕖蕗蕝蕞蕠蕡蒉蕣蕤蕨蕳蓣蕸蕺蕻薀薁薃薅薆荟薉芗薏薐蔷薖薘剃谔钗薜薠薢薤薧薨薫薬薳薶薷薸薽薾薿藄藇藋荩藐藙藚藟藦藳藴苈藷藾蘀蘁蕲苹蘗蘘蘝蘤蘧蘩蘸蘼虀虆虍蟠虒虓虖虡虣虥虩虬虰蛵蛇虷鳟虺虼蚆蚈蚋蚓蚔蚖蚘蚜蚡蚣蚧蚨蚩蚪蚯蚰蜒蚱蚳蚶蚹蚺蚻蚿蛀蛁蛄蛅蝮蛌蛍蛐蟮蛑蛓蛔蛘蛚蛜蛡蛣蜊蛩蛱蜕螫蜅蚬蜈蝣蜋蜍蜎蜑蠊蜛饯蜞蜣蜨蜩蜮蜱蜷蜺蜾蜿蝀蝃蝋蝌蝍蝎蝏蝗蝘蝙蝝鲼蝡蝤蝥猿蝰虻蝲蝴蝻螃蠏蛳螉螋螒螓螗螘螙螚蟥螟螣螥螬螭䗖螾螀蟀蟅蝈蟊蟋蟑蟓蟛蟜蟟蟢虮蟨蟪蟭蛲蟳蛏蟷蟺蟿蠁蠂蠃虿蠋蛴蠓蚝蠗蠙蠚蠛蠜蠧蟏蠩蜂蠮蠰蠲蠵蠸蠼蠽衁衄衄衇衈衉衋衎衒同衖胡衞裳钩衭衲衵衹衺衿袈裟袗袚袟袢袪袮袲袴袷袺袼褙袽裀裉袅裋夹裍裎裒裛裯裱裲裴裾褀褂褉褊裈褎褐褒褓褔褕袆褚褡褢褦褧褪褫袅褯褰褱裆褛褽褾襁褒襆裥襉襋襌襏襚襛襜裣襞襡襢褴襦襫襬襭襮襕襶襼襽襾覂覃覅霸覉覊覌覗觇覚覜觍觎覧覩觊觏覰観觌觔觕觖觜觽觝觡酲觩觫觭觱觳觯觷觼觾觿言赅讣訇訏訑訒诂讬訧訬訳訹证訾詀詅诋毁詈詊讵詑诒诐詗诎察詨诜詶詸詹詻诙诖誂誃诔锄诓誋诳诶悖誙诮诰誧説読誯谇訚谄谆諆諌诤诹诼諕谂谀諝谝諟喧谥諴諵谌谖誊謆謇歌謍謏謑谡谥謡謦謪谪讴謷謼谩哗譅譆譈譊讹譒撰谮鑫譞噪譩谵譬譱譲谴譸譹谫讅讆詟䜩雠讐谗谶讙谠讟谽豁豉豇岂豊豋豌豏豔豞豖豗豜豝豣豦豨豭豱豳豵豶豷豺豻貅貆狸猊貔貘䝙貜貤餍贳餸贶贲赂賏赊赇赒賝赓赕賨赍斗賮賵賸赚赙赜赟贉赆赑贕赝赬赭赱赳迄趁趂趄趐趑趒趔趡趦趫趮趯趱趴趵趷趹趺趿跁跂跅跆踬跄跐跕跖跗跙跛跦跧跩跫跬跮跱跲跴跺跼跽踅踆踈踉踊踒踖踘踜踟躇蹰踠踡踣踤踥踦踧跷踫踮逾踱踊踶踹踺踼踽躞蹁蹂躏蹎蹐蹓蹔跸蹚蹜蹝迹蹠蹡蹢跶蹧蹩蹪蹯鞠蹽躃躄躅踌跻躐踯跞躘躙躗躝躠蹑躜躧躩躭躰躬躶軃軆辊軏轫軘軜軝腭転軥軨軭軱轱辘軷轵轺軽軿輀輂辇辂辁輈挽輗辄辎辋輠輤輬輭輮辏輴輵輶輹輼辗辒轇轏轑轒辚轕轖轗轘轙轝轞轹轳罪辣辞辵辶辺込辿迅迋迍麿迓迣迤逦迥迨迮迸迺迻迿逄逅逌逍逑逓迳逖逡逭逯逴逶逹遄遅侦遘遛遝遢遨遫遯遰遴绕遹遻邂邅邉邋邎邕邗邘邛邠邢邧邨邯郸邰邲邳邴邶邷邽邾邿郃郄郇郈郔郕郗郙郚郜郝郞郏郠郢郪郫郯郰郲郳郴郷郹郾郿鄀鄄郓鄇鄈鄋鄍鄎鄏鄐鄑邹邬鄕郧鄗鄘鄚鄜鄞鄠鄢鄣鄤鄦鄩鄫鄬鄮鄯鄱郐鄷鄹邝鄻鄾鄿酃酅酆酇郦酊酋酎酏酐酣酔酕醄酖酗酞酡酢酤酩酴酹酺醁醅醆醊醍醐醑醓醖醝酝醡醤醨醪醭醯醰酦醲醴醵醸醹醼醽醾釂酾酽釆釈鲈镏阊钆钇钌钯钋鼢鼹钐钏釪釬釭釱钍釸钕钫鈃钭鈆鈇钚鈊鈌钤钣鈒鈤钬钪鈬铌铈钶铛钹铍钸钿鉄鉆铊铇鉌铋鉏铂钷铆钵鉥钲鉨钼钽鉱鉲鉶铰铒鉼铪銍銎铣銕镂铫铦铑铷銤铱铟銧铥铕铯銭銰焊銶锑锉汞鋂锒鋆鋈鋊铤鋍铗鋐鋑鋕鋘鋙锊锓锔锇铓鋭铖锆锂铽鋳鋹鋺鉴镚钎錀锞锖锫锩錍铔锕錔锱铮锛錞锬锜錤錩錬録铼錼锝钔锴鍉镀鍏鍐铡鍚锻锽锸锲锘鍫鍭鍱鍴锶鍹锗针锺锿镅鎉鎋鎌鎍鎏鎒鎓鎗镉鎚鎞镃鎤铩锼鎭鎯镒镍鎴镓鎸鎹镎镟鏊镆镠镝鏖铿锵鏚镗镘镛鏠鏦錾镤鏸镪鏻鏽鏾铙鐄鐇鐏铹镦镡鐗馗镫镢镨鐡锎镄鐩镌鐬鐱镭鐶鐻鐽镱鑀鑅镔鑐鑕鑚鑛鑢鑤镥鑪镧鑯鑱鑴鑵镊镢钃镻闫闬闶闳閒闵閗閟阂関合閤哄阆閲阉閺阎阏阍阌暗闉阕阗闑闒闿闘闚阚闟闠闤闼阞阢阤阨阬阯阹阼阽陁陑陔陛陜陡陥陬骘陴険陼陾阴隃隈隒隗隞隠隣隤隩隮隰颧隳隷隹雂雈雉雊雎雑雒雗雘雚雝雟雩雰雱驿霂霅霈霊沾霒霓霙霝霢霣霤霨霩霪霫霮靁叇叆靑靓靣腼靪靮靰靳靷靸靺靼靿鞀鞃鞄鞍鞗鞙鞚鞝鞞鞡鞣鞨鞫鞬鞮鞶鞹鞾鞑韅鞯驮韍韎韔韖韘韝韫韡韣韭韭韱韹韺頀刮頄顸顼頍颀颃颁頖頞頠頫頬颅頯頲颕頼悴顋顑颙颛颜顕顚顜颟顣颥颞飐飑台飓颸飏飖颽颾颿飀飂飚飌翻飡飣饲飥饨饫飮飧飶餀餂饸饹餇餈饽哺馂餖餗餚馄馃餟餠餤餧餩餪餫糊餮糇餲饧馎糕饩馈馊馌馒饇馑馓膳饎饐饘饟馕馘馥馝馡馣骝骡馵馹駃駄駅駆駉駋驽駓驵駗骀驸駜骂骈駪駬骃駴骎駹駽駾騂騄骓騆騉騋骒骐麟騑騒験騕骛騠騢騣騤騧骧騵驺骟騺蓦骖骠骢驆驈骅驌骁驎骣驒驔驖驙驦驩驫骺鲠骫骭肮骱骴骶骷髅骾髁髂髄髆膀髇髑髌髋髙髝髞髟髡髣髧髪髫髭髯髲髳髹髺髽髾鬁鬃鬅鬈鬋鬎鬏鬐鬑鬒鬖鬗鬘鬙鬠鬣斗鬫鬬阄鬯鬰鬲鬵鬷魆魈魊魋魍魉魑魖鳔魛魟魣魦魨魬鲂魵魸鮀鲅鮆鲧鲇鲍鲋鮓鲒鲕鮟鱇鮠鮦鮨鲔鲑鮶鮸鮿鲧鯄鯆鲩鯈鲻鯕鲭鲞鯙鯠鲲鯥鲰鲶鳀鯸鳊鲗䲠鹣鳇鰋鳄鳆鰕鰛鰜鲥鰤鳏鰦鳎鳐鳁鳓鰶鲦鲡鰼鰽鱀鱄
鳙鱆鳕鱎鱐鳝鳝鳜鲟鲎鱠鳣鱨鲚鱮鱲鱵鱻鲅鳦凫鳯鳲鳷鳻鴂鴃鴄鸩鴈鴎鸰鴔鴗鸳鸯鸲鹆鸱鴠鴢鸪鴥鸸鹋鴳鸻鴷鴽鵀鵁鸺鹁鵖鵙鹈鹕鹅鵟鵩鹌鵫鵵鵷鵻鹍鶂鶊鶏鶒鹙鶗鶡鶤鶦鶬鶱鹟鶵鶸鶹鹡鶿鹚鷁鷃鷄鷇䴘䴘鷊鷏鹧鷕鹥鸷鷞鷟鸶鹪鹩鷩鷫鷭鹇鹇鸴鷾䴙鸂鸇䴙鸏鸑鸒鸓鸬鹳鸜鹂鹸咸鹾麀麂麃麄麇麋麌麐麑麒麚麛麝麤麸面麫麮麯麰麺麾黁黈黉黢黒黓黕黙黝黟黥黦黧黮黰黱黪黶黹黻黼黾鼋鼂鼃鼅鼈鼍鼏鼐鼒冬鼖鼙鼚鼛鼡鼩鼱鼪鼫鼯鼷鼽齁齆齇齈齉齌赍齑龀齕齗龅齚龇齞龃龉龆齢出齧齩齮齯齰齱齵齾厐龑龒龚龖龘龝龡龢龤" + +traditional_characters = "制咖片型超聲盤鑒定仔點他命書歌粉巾字帳恤手指記憶棒形轉彎溝光○〇㐄㐅㐆㐌㐖毒㐜㐡㐤㐰㐺㑇㑳㒳㒸㔾㗂㗎㝵㞎㞙㞞㠯㢲㢴㤅㥁㥯㨗㫺㬎㮎㮚㮸㲋㲱㲾㳮㵎㵪㶸㷖㷭㹢㹴犬㺢狓㺵㼝㽮㿝䍃䔢䖟䖸䗈䗥䗪䝓䠶䥯䦉䯝䰾魚䲔䳗䳘䵹鼄䶑一對應映射丁不識下兒子做二休世丘之貉並中台原則串為甚謂乾淨了百事無成八變五十些人得道雞升天代如併來去個國政策勁幽靈在歐洲遊蕩接樣蘿蔔坑側化傳價元論醇共再准刀兩斷切分耕耘收穫錢貨物向看舊就緒險刻千金動勞永逸匙零夜半卡通回復返影蹤反常態口咬氣句話同吐快吹周味呼諾嗚品紅鍋哄而散起唱和問三知生熟團漆黑火糟堆場空塊麵塌糊塗塵染壁廂夔已足多情露水大早到晚夫妻當關萬莫開失古恨套所料既往孔見提師要家主審寸陰難買鬥牛小撮部陣局展身層巴掌帆風順席地帶過年計於春頭載四季期被蛇怕井繩度願式份彈頃深前律徑心意念差愁孤行俱全房廳交遮打技長把抓死拿眼淚鼻涕鑰鎖折段抿拍即合掃排掬揮撥擁上入擊洞擲攬改故轍敗文值名斑方面旁族日秋餐隔雅里終父旦時晌會霎間晃暴寒曝更月望垠際朝夕本正經利杯羹東西板枝獨秀根筋桿進條龍服務概模次函數又性程總付步腳印趨登毛拔呵氧氮碳決雌雄波未平派謊言流清楚白準溜煙潭有獲聞是處降琴鶴甲病發可拾沙目然瞭直以相眨穿睹瞥瞬矢的解石鳥神教秉虔誠秘種窩蜂窮竅笑置筆苟勾銷抹殺煞等獎箍節吃箭仇雙鵰詩籌籮筐系列紙級士官統絲毫掛維網盡線微吭響股腦胎脈承腔臂力致效資源址器舉功投般說講規貿易葉障著慎滿皆輸號木電池衣傾鐘高低視仁覺醒覽遺角銀幣觸潰九鼎蔽抄出駟馬追重語破貧洗貫走路安蹴至幾蹶振躍役膽汗較輩輪辭贊退六連遍遞邊針血錘音錯門思閃真倒項栽霧類保護川先驚乍體鬨鱗爪鳴滴泡鄰域黨專鼓作齊炒丑烯亥克內酯冬加奴卯肝炎基尺梁街褲鎬客寵庭巳汝昌烷玲磊糖肇酉醛啷青縣韙良香骨鯛丂七集河市弦喜嘴張舌堵區工業姊妹星架構巧彩扭歪拼湊餘熱曜武州爺浮屠美鄉老階樹葷素碎落能魄鰓鰻珠丄丅丆万俟丈尚摸母娘量管群亞虎必我堂令申件裝伏位博俠義界表女墟臺戲臭皮匠勝諸葛亮賽頂倍催請運算包立叉戟離疫苗土史志演圍揭瓦曬夷姑婆帝村寶爛尖杉鹼屜桌山岔島由紀峽壩庫鎮廢從德後拗湯治旬食明昧曹朋友框欄極權冪曲歸依貓民氟硼氯磷鐵江侗自旅法司洋浦梅園溫暖灣焦班幸用田略番疊皇炮捶硝苯酸腺苷稜草鏡穗跳遠索錦綱聚氰胺聯店胚膲愛色堇紫羅蘭芝茶飯菱雲蟲藏藩亂叛蘇親債凳學座恐戀柱測肌腹衩錐係貂企烏跪叩軍車農題迭都甘油屯奏鍵短阿姨陪姐隻顧茅廬槽駕魂鮮鹿頁其菜單乘任供勢午齒漢組織吊調瀉唇坡城報墳外夸將尉建築岸崗公床揚新劍昇杭林栗校樓標款汽社浣海商館劇院鋼華港機械廣媒環球融第醫科證券綜財樂育游漲猶嶺疏癮瞼確兵領導繳肢膛船艾瑟爾蒼蔡虞傚衫覆訪訴課諭議軌述野鉤限敵鞋頜頷顎饒首齦站例修凡劃垂屆屬崽頦廚拜挫擺放旋削棋榻檻禮沉注滑營獄畫确儀聘花葬詔員跌轄週達酒錨閘陷陸雨雪飛威丌于丹久乏予理評產亢卑亦乎舞己悲矩圓詞害誌但住佞佳便俗信票案幅翁倦倫假偏倚斜虧鬼敲停備傷脾胃僅此像儉匱免宜穴焉戴兼容許凍伯仲負彼晝皂軒輊實刊划顛衛戰哥比省非好黃飾別拘束掩奶睬選擇搖擾煩苦枚寫協厭及格受歡迎約只估侵犯割狀告或缺抗拒挽撤救藥喻磨滅端倪少逆逾越避靠適吉譽吝玉含延咎歹聽啻淵善謀均勻堪忍夠太惹妙妥妨孕症孝術室完納推冠積宣疑辯慄碴稱屈撓屑干涉衡待很忙惡忿怎麼怠急恥恭息悅惑惜惟想愉愧怍慌憤啟懂懈懷材才緊招認扣抵拉捨也罷插揣冒搭撞南牆擴核支攻敢雷攀敬裡嗎需景智暇曾罪遇朽枉止況競爭辱求癒渝溶濟左右袒困補爽特寂寞示弱找謝畏強疾徐痛癢冤符眠睦瞅董何厚云措活疲羞者輕玻璃祥兆禁移稂莠穩佛換答簡結果盟絕縷途給談否羈翼耐肖脛毋寧興舒若菲萊痕跡窠臼虛衰臉兔撒鷹棺範該詳諱抬泰讓鬚眉象眾貲賬費灰賴奇慮訓輟辨菽麥辛近送透逞徒速續逮捕遂遑違遜斧鉞艱醉鏽隨觀棄顯飽脂肪使丏丐幫丒且慢末丕替桃宗王尊涼爵各圖屋脊糧署錄壇吾祿職胄襲君廈丗北壑桐疹損逢陵鷸丙寅戌氨腈唑綸辰酮脫氫酶醚丞丟現掉紗帽弄扯砲碗丠両丣坐存激肩臻蒂蓮悖序驅丨丩丫挺杈髻鬟細介俄伊犁京尼布訂普渡央委監察檢查劑圈設警隊斯督剩震境航舶革防托播促質版蠑螈鋒研藝歷殘消頻譜精密製造陲郵候埔堅壓壢凹匯執府究邦俘攝寮彬狼嶽肺腫庸英訊診埋粒胞括控碼韓暑槍樞砥澳哇牟壽甸鑽探篇簽綴縫繼耳肯照婦埃懸璧軸櫃檯辣擱淺邪跑纖阮陽私囊魔丮丰姿采丱燒丳丵丶丷丸參寨朗桂瑞砂衷霞貌鳳僕艦因嫌宰峰幹絡牌持旨祭禱簿編罰賓辦丼丿乀乂乃乄仰慕盛曠留考驗闊乆乇么醜麼乊湖燃乑乒乓乕乖僻忤戾离謬迕乗危肥劫除隙浪婿乙炔腸酰吡咯鹽乚乛乜嘢卿玄宮尾狐龜塔嶷兄弟泉章霄釘耙乞扎哀憐恕討乢乣乤乥乧乨乩童乪乫乭乳暈汁液瑤漿牙癌突竇罩腐膠豬酪蛋糕菌瘤乴乵乶乷乸乹乺乼乾俸冰嘉噦嚎坤媽屍壘旱枯涸俐渴潮澀煸豆燥爹瘦癟癬瞪袋脆薑貝隆餾乿亀亁叫咕攘扔搞男砸竄蓬麻亃亄亅卻亇遲典今臨繁累卵奉婚聰躬巨與遷添裂副宿歲怪噁尕崙愣杆硅硫鈦鈾錳芑雜異鈉砷胂磺琥珀艙棍簧胡茬盜浩盆販郎腿亍洪亐互欠助勉惠操斥諉繫戶譯亓墓碑刑鈴卅渠繽紛斗米旗憲釩燈徽瘟祖拳福穀豐臟腑綁肉醃苓蘊橋鋪霸顏鬧判噴岡底蛙陘礦亖亙亜罕們娜桑那努哈喀弗烈曼松森杜氏盃奧琛敦戊穆聖裔彙薛孫亟亡佚虜羊牢奮釋卷卸契媾感額睫纏誼趾塞擠紐阻還配馳莊亨洛祚亪享津滬畿郊慈菴枇杷膏亭閣鋥麗亳亶亹誅初責翻瘋偶傑叢稠妖拖寰居吸授慧蝸吞壯魅狗矛盾益渣患憂稀描猿夢暫涯畜禍緣沸搜引擎臣橫紜誰混援蒸獸獅稅剖亻亼亽亾什獻剎邡麽仂仃仄仆富怨仈仉畢昔晨殼紹仍仏仒仕宦仗欺恃腰嘆歎炬梓訖施仙后瓊逝仚仝仞仟悔仡佬償填泊拓撲簇羔購頓欽佩髮棻閫馭養億儆尤藉幀賑凌敘帖李柔剛沃眥睚戒訛取饗讀仨仫仮著泳臥躺韶夏裁仳仵唯賢憑釣誕仿似宋彿諷伀碩盼鵝伄儅伈伉儷柯始娃邁戈坦堡帕茨薩廟瑪莉莎藤霍姆伋伍奢胥廷芳豪伎倆侍汛勒希羲雛伐憩整謨閑閒伕伙伴頤伜伝伢叔恆茲恩翰伱伲侶伶俜悧鼬伸懶縮喇叭伹伺伻伽倻輻伾佀佃佇佈喬妮墨佉盧佌貸劣廉昂檔濃矮傘窪緩耗胸谷迷擋率齲宅沫舍療佐貳佑佔優據鏵嘗呢須魯曉佗佘余坪寺瓜銃僧蒙芒陀龕哼嘔坊姦孽弊揖祟繭縛誓賊佝僂瞀佟你奪趕佡佢佣佤佧賈佪佫佯佰佱潔績釀餚佴捲佶佷佸佹佺佻佼佽佾具喚窘壞娛怒慨硬習慣聾膨脹蔓駭貴痺侀侁侂侃侄侅鴻燕侇侈糜靡侉侌妾侏儒倉鼠侐侑侔侖侘侚鏈侜偎傍鈷循柳葫蘆附価侮罵蔑侯岩截蝕侷貼壺嬛宴捷攜桶箋酌俁狹膝狄俅俉俊俏俎俑俓俔諺俚俛黎健呈固墒增守康箱濕祐鏢鑣槓盒靖膜齡俞豹獵噪孚封札筒託衍鴿剪撰稿煉廠禊練繕葺俯瞰撐衝俲俳俴俵俶俷俺俻俾倀倂倅儲卒惶敷猝逃頡蓄崇隱倌倏忽刺蠟燭噍嚼坍扁抽斃蔥楣灌灶糞背藪賣賠閉霉騰倓倔倖倘倜儻倝借箸挹澆閱倡狂倢倣値倥傯倨傲倩匡嗣沖柝珍倬倭寇猩倮倶倷倹勤讚偁偃充偽吏嗓寐惺扮拱芫茜藉虢鈔偈偉晶偌宕距析濾殿疼癱註頗偓偕鴨歇滯偝偟偢忘怡旺偨偩偪偫偭偯偰偱偲偵緝蹄偷減惰漏窺竊偸偺迹傀儡傅傈僳傌籬傎奎琳迪叟芭傒傔傕傖悉荒傜傞傢傣芽逼傭婢傮睨寄檄誦謠頌傴擔辜弓慘蒿悼疤傺傻屄臆巢洩篋羨蓋軋頹傿儸僄僇僉僊働僎僑僔僖僚僝僞僣僤僥僦猴僨僩僬僭僮僯僰僱僵殖籤靜僾僿征隴儁儂儃儇儈朴薄儊儋儌儍儐儓儔儕儗儘儜儞儤儦儩汰哉寡渥裕酷儭儱罐儳儵儹儺儼儽兀臬臲鷲允勛勳宙宵帥憝彞諧嫂鬩暢沛溢盈飢赫兇悍狠猛頑愚妣斬秦遣鞭耀敏榮槃澤爆碟磁禿纜輝霽鹵朵婁孜烽醬勃汀箕裘鉗耶懞蕾徹兌軟遭黜兎児韻媳爸兕觥兗兙兛兜售鍪肚兝兞兟兡兢兣樽殮涅睡稟籍贅泌啡肽奸幕涵澇熵疚眷稃襯訌赴煥椒殲植跏沒試誤猜棲窗肋袖頰兪卦撇鬍岐廓轎疸楓茴瓏廁秩募勺噸寓斤曆畝迫筷釐最淫螺韜兮寬匪篩襄贏軛複兲詐刃堰戎痞蟻餉它冀鑄冂冃円冇冉冊嫁厲礪竭醮冏牧冑冓冔冕冖冗冘冞冢窄抑誣冥冫烘菇蟄冷凝坨橇淇淋炭餅磚磧窖醋雕雹霜冱冶爐艷嘲峻灘淡漠煖颼飲冼冽凃凄愴梗凅凇凈凊凋敝濛凔凜遵汞脢凞几凢処凰凱凵凶焰凸摺刷紋預喪嘍奔巡榜殯芙蓉租籠輯鞘萃凼鋸鑊刁蠻刂娩崩批拆攤掰櫱驟歧顆秒袂贓勿囑忌磋琢膚刈羽刎訟戮舂槳艇刓刖霹靂刜創犢刡恙墅幟筵緻刦刧刨昏默攸尿慾薰潤薰圭刪刮痧鏟刱刲刳刴刵踏磅戳柏槐繡芹莧蝟舟銘鵠鶩刼剁剃辮剄剉履鉛剋剌姻咽哨廊掠桅沿召瞻翅趙卜渺茫郭剒剔剕瀝剚愎毅訥纔剜剝啄採剞剟剡剣剤綵剮腎駛黏剰袍剴紊剷剸剺剽剿劁劂劄劈啪柴扳啦劉奭姥夼昫涓熙禪禹錫翔雁鶚劊劌弩柄蜻蛉劒劓劖劘劙瀾簣賞磯釜晉甜薪逐劦熔紂虐赤囚劬劭労劵効劻劼劾峭艮
勅勇勵勍勐臘脖龐漫飼盪粥輒勖勗勘驕餒碌泮雇捐竹騎殊阱勣樸懇謹勦勧勩勯勰勱勲勷勸懲慰誡諫勹芡踐闌匁庇拯粟紮袱裹餃匆遽匈匉匊匋匍匐莖匏匕妝痰膿蛹齋苑烤蹈塘羌熊閥螳螂疆碚竿緯荷茵邙魏匚匜匝匟扶稷匣匭攏匸匹耦匽匾匿卂叮瘡禧軫堤棚迢鈞鍊卄卆遐卉瓷盲瓶噹胱腱裸卋卌卍卐怯污賤鄙齷齪陋卓溪唐梯漁陳棗泥漳潯澗梨芬譙贍轅迦鄭単驢弈洽鰲卛占筮卝卞卟吩啉屎翠厄卣卨卪卬卮榫襖璽綬鈕蚤懼殆篤聳卲帘帙繞卹卼卽厂厎厓厔厖厗奚厘厙厜厝諒厠厤厥厪膩孢厮厰厳厴厹厺粕垢蕪菁厼厾叁悟茸薯叄吵笄悌哺譏坫壟弧芯杠潛嬰芻袁詰貪諜煽饋駁収岳締災賄騙叚叡吻攔蘑蜜訣燧玩硯箏椎藺銅逗驪另覓叨嘮謁杵姓喊嚷囂咚嚀塑尋惱憎擦祇泣滲蝠叱吒咄咤喝籀黛舵舷叵叶鐸懿昭穰苴遼叻叼吁塹嫖賭瞧爬衆抒吅吆夥巹橡滌抱縱摩郡唁墜扇籃膀襪頸吋愾諮酬哭妓媛暗錶韁邇妃羿絮蕃渾拐葵暮隅吔吖啶嗪戚吜嗇噬嚥吟哦詠吠吧唧嗒咐吪雋咀徵燐苞茹鈣哧吮吰吱嘎吲哚吳棟嬌窟孟簫忠晗淞闔閭趼宇吶睛噓拂捧疵熄竽笛糠吼吽呀呂韋矇呃呆笨呇貢呉罄呋喃呎呏呔呠呡癡呣呤呦呧瑛眩扒晬淑姬瑜璇鵑呪呫嗶嚅囁呬呯呰呱呲咧噌鈍呴呶呷呸呺呻哱咻嘯嚕籲坎坷邏呿咁咂咆哮咇咈咋蟹煦珅藹咍咑咒詛咔噠嚓咾噥哩喱咗咠咡咢咣咥咦咨嗟詢咩咪咫嚙齧咭咮咱咲咳嗆嗽咴咷咸咹咺咼喉咿婉慟憫賦矜綠茗藍哂搶瞞哆嗦囉噻啾濱彗哋哌哎唷喲哏哐哞哢哤哪裏哫啼喘哰哲萎蚌哳哶哽哿唄唅唆唈唉唎唏嘩堯棣殤璜睿肅唔睇唕唚唞唣喳唪唬唰喏唲唳唵嘛唶唸唹唻唼唾唿啁啃鸚鵡啅埠棧榷祺舖鞅飆啊啍啎啐啓啕啖啗啜啞祈啢啣啤啥啫啱啲啵啺饑啽噶崑沁喁喂喆裙喈嚨喋喌喎喑喒喓喔粗喙幛慶滋鵲喟喣喤喥喦喧騷喨喩梆喫葡萄喭駝挑嚇碰樅瓣純皰藻趟鉻喵営喹喺喼喿嗀嗃嗄嗅嗈嗉嗊嗍嗐嗑嗔詬嗕嗖嗙嗛嗜痂癖嗝嗡嗤嗥嗨嗩嗬嗯嗰嗲嗵嘰嗷嗹嗾嗿嘀嘁嘂嘅惋嘈峪禾蔭嘊嘌嘏嘐嘒嘓嘖嘚嘜嘞嘟囔嘣嘥嘦嘧嘬嘭這謔嚴敞饞鬆嘵嘶嘷嘸蝦嘹嘻嘽嘿噀噂噅噇噉噎噏噔噗噘噙噚噝噞噢噤蟬皿噩噫噭噯噱噲噳嚏涌灑欲巫霏噷噼嚃嚄嚆抖嚌嚐嚔囌嚚嚜嚞嚟嚦嚬嚭嚮嚯嚲嚳飭按竣苛嚵嚶囀囅囈膪謙囍囒囓囗囘蕭酚飄濺諦囝溯眸紇鑾鶻囟殉囡団囤囥囧囨囪囫圇囬囮囯囲図囶囷囸囹圄圉擬囻囿圀圂圃圊粹蠹赦圌墾圏滾鯡鑿枘圕圛圜圞坯埂壤骸炕祠窯豚紳魠鯪鱉圧握圩圪垯圬圮圯炸岬幔毯祇窨菩溉圳圴圻圾坂坆沾坋坌舛壈昆墊墩椅坒坓坩堝坭坰坱坳坴坵坻坼楊掙涎簾垃垈垌垍垓垔垕垗垚垛垝垣垞垟垤垧垮垵垺垾垿埀畔埄埆埇埈埌殃隍埏埒埕埗埜埡埤埦埧埭埯埰埲埳埴埵埶紼埸培怖樁礎輔埼埽堀訶姪廡堃堄摧磐貞韌砌堈堉堊堋堌堍堎堖堙堞堠礁堧堨輿堭堮蜓摘堲堳堽堿塁塄塈煤塋棵塍塏塒塓綢塕鴉沽虱塙塚塝繆塡塢塤塥塩塬塱塲蟎塼塽塾塿墀墁墈墉墐夯増毀墝墠墦漬缽墫墬墮墰墺墻櫥壅壆壊壌壎壒榨蒜壔壕壖壙壚壜壝壠壡壬壭壱売壴壹壻壼寢壿夂夅夆変夊夌漱邑夓腕泄甥禦骼夗夘夙袞瑙妊娠醣梟珊鶯鷺戧幻魘夤蹀祕擂鶇姚宛閨嶼庾撻拇賛蛤裨菠氅漓撈湄蚊霆鯊箐篆篷荊肆舅荔鮃巷慚骰辟邱鎔鐮阪漂燴鯢鰈鱷鴇臚鵬妒峨譚枰晏璣癸祝秤竺牡籟恢罡螻蠍賜絨御梭夬夭砣榆怙枕夶夾餡奄崛葩譎奈賀祀贈奌奐奓奕訢詝奘奜奠奡奣陶奨奩魁奫奬奰媧孩貶隸酥宄狡猾她奼嫣妁氈荼皋膻蠅嬪妄妍嫉媚嬈妗趣妚妞妤礙妬婭妯娌妲妳妵妺姁姅姉姍姒姘姙姜姝姞姣姤姧姫姮娥姱姸姺姽婀娀誘懾脅娉婷娑娓娟娣娭娯娵娶娸娼婊婐婕婞婤婥谿孺婧婪婬婹婺婼婽媁媄媊媕媞媟媠媢媬媮媯媲媵媸媺媻媼眯媿嫄嫈嫋嫏嫕嫗嫘嫚嫜嫠嫡嫦嫩嫪毐嫫嫬嫰嫵嫺嫻嫽嫿嬀嬃嬅嬉耍嬋痴豔嬔嬖嬗嬙嬝嬡嬢嬤嬦嬬嬭幼嬲嬴嬸嬹嬾嬿孀孃孅孌孏曰癲屏孑孓雀孖斟簍謎摺孛矻鳩崮軻祜鸞孥邈毓棠臏孬孭孰孱孳孵泛罔銜孻孿宀宁宂拙株薇掣撫琪瓿榴謐彌宊濂祁瑕宍宏碁宓邸讞実潢町宥宧宨宬徵崎駿掖闕臊煮禽蠶宸豫寀寁寥寃簷庶寎暄磣寔寖寘寙寛寠苫寤肘洱濫蒗陝覈寪弘綽螽寳擅疙瘩晷対檐専尃尅贖絀繚疇釁尌峙醌襟痲碧屁昊槌淘恵瀑牝畑莓缸羚覷蔻髒躁尒尓銳尗尙尜尟尢尥尨尪尬尭尰擒尲尶尷尸尹潽蠖蛾尻釦梢蚴鰭脬蹲屇屌蚵屐屓挪屖屘屙屛屝屢屣巒嶂巖舄屧屨屩屪屭屮戍駐鉀崖嵛巔旮旯楂欖櫸芋茱萸靛麓屴屹屺屼岀岊岌岍阜岑彭鞏岒岝岢嵐岣岧岨岫岱岵岷峁峇峋峒峓峞峠嵋峩峯峱峴峹峿崀崁崆禎崋崌崍嶇崐崒崔嵬巍螢顥崚崞崟崠崢巆崤崦崧殂崬崱崳崴崶崿嵂嵇嵊泗嵌嵎嵒嵓嵗嵙嵞嵡嵩嵫嵯嵴嵼嵾嶁嶃嶄晴嶋嶌嶒嶓嶔嶗嶙嶝嶞嶠嶡嶢嶧嶨嶭嶮嶰嶲嶴嶸巂巃巇巉巋巌巓巘巛滇芎巟巠弋迴巣巤炊擘蜥蟒蠱覡巰蜀彥淖杏茂甫楞巻巽幗巿帛斐鯽蕊帑帔帗帚琉汶帟帡帣帨帬帯帰帷帹暆幃幄幇幋幌幏幘幙幚幞幠幡幢幦幨幩幪幬幭幯幰遙蹉跎餘庚鑑幵幷稚邃庀庁広庄庈庉笠庋跋庖犧庠庤庥鯨庬庱庳庴庵馨衢庹庿廃廄廆廋廌廎廏廐廑廒廕廖廛廝搏鑼廞弛袤廥廧廨廩廱綿踵髓廸廹甌鄴廻廼廾廿躔弁皺弇弌弍弎弐弒弔詭憾薦弝弢弣弤弨弭弮弰弳霖繇燾斌旭溥騫弶弸弼弾彀彄彆纍糾彊彔彖彘彟彠陌彤貽彧繪虹彪炳彫蔚鷗彰癉彲彳彴彷彷徉徨彸彽踩斂旆徂徇徊渭畬鉉裼従筌徘徙徜徠膳甦萌漸徬徭醺徯徳徴潘徻徼忀瘁胖燎怦悸顫扉犀澎湃砰恍惚絞隘忉憚挨餓忐忑忒忖応忝忞耿忡忪忭忮忱忸怩忻悠懣怏遏怔怗怚怛怞懟黍訝怫怭懦怱怲怳怵惕怸怹恁恂恇恉恌恏恒恓恔恘恚恛恝恞恟恠恣恧眄恪恫恬澹恰恿悀悁悃悄悆悊悐悒晦悚悛悜悝悤您悩悪悮悰悱悽惻悳悴悵惘悶悻悾惄愫鍾蒐惆惇惌惎惏惓惔惙惛耄惝瘧濁惥惦惪惲惴惷惸拈愀愃愆愈愊愍愐愑愒愓愔愕愙氓蠢騃昵愜赧愨愬愮愯愷愼慁慂慅慆慇靄慉慊慍慝慥慪慫慬慱慳慴慵慷慼焚憀灼鬱憃憊憋憍眺捏軾憒憔憖憙憧憬憨憪憭憮憯憷憸憹憺懃懅懆邀懊懋懌懍懐懞懠懤懥懨懫懮懰懱毖懵遁樑雍懺懽戁戄戇戉戔戕戛戝戞戠戡戢戣戤戥戦戩戭戯轟戱披菊牖戸戹戺戻戼戽鍬扂楔扃扆扈扊杖牽絹銬鐲賚扐摟攪烊盹瞌跟躉鑔靶鼾払扗玫腮扛扞扠扡扢盔押扤扦扱罾揄綏鞍郤窾扻扼扽抃抆抈抉抌抏瞎抔繯縊擻抜抝択抨摔歉躥牾抶抻搐泵菸拃拄拊髀拋拌脯拎拏拑擢秧沓曳攣迂拚拝拠拡拫拭拮踢拴拶拷攢拽掇芥橐簪摹疔挈瓢驥捺蹻挌挍挎挐揀挓挖掘浚挙揍聵挲挶挾挿捂捃捄捅捆捉捋胳膊揎捌捍捎軀蛛捗捘捙捜捥捩捫捭据捱捻捼捽掀掂掄臀膘掊掎掏掐笙掔掗掞棉芍掤搪闡掫掮掯揉掱掲掽掾揃揅揆搓揌諢揕揗揘揜揝揞揠揥揩揪揫櫫遒麈揰揲揵揶揸揹揺搆搉搊搋搌搎搔搕撼櫓搗搘搠搡搢搣搤搥搦搧搨搬楦褳訕赸搯搰搲搳搴搵搷搽搾搿摀摁摂摃摎摑摒摓跤摙摛摜摞摠摦睺羯摭摮摯摰摲摳摴摶摷摻摽撂撃撅稻撊撋撏鐧潑撕撙撚撝撟撢撣撦撧撩撬撱朔撳蚍蜉撾撿擀擄闖擉缶觚擐擕擖擗擡擣擤澡腚擧擨擩擫擭擯擰擷擸擼擽擿攃攄攆攉攥攐攓攖攙攛每攩攫轡澄攮攰攲攴軼攷砭訐攽碘敁敃敇敉敍敎筏敔敕敖閏誨敜煌敧敪敱敹敺敻敿斁衽斄牒縐謅斉斎斕鶉讕駮鱧斒筲斛斝斞斠斡斢斨斫斮晾沂潟穎絳邵斲斸釳於琅斾斿旀旂旃旄渦旌旎旐旒旓旖旛旝旟旡旣浴旰獺魃旴旹旻旼旽昀昃昄昇昉晰躲澈熹皎皓礬昑昕昜昝昞昡昤暉筍昦昨昰昱昳昴昶昺昻晁蹇隧蔬髦晄晅晒晛晜晞晟晡晢晤晥曦晩萘瑩顗晿暁暋暌暍暐暔暕煅暘暝暠暡曚暦暨暪朦朧暱暲殄馮暵暸暹暻暾曀曄曇曈曌曏曐曖曘曙曛曡曨曩駱曱甴肱曷牘禺錕曽滄耽朁朅朆杪栓誇竟粘絛朊膺朏朐朓朕朘朙瞄覲溘饔飧朠朢朣柵椆澱蝨朩朮朰朱炆璋鈺熾鹮朳槿朶朾朿杅杇杌隉欣釗湛漼楷瀍煜玟纓翱肈舜贄适逵杓杕杗杙荀蘅杝杞脩珓筊杰榔狍閦顰緬莞杲杳眇杴杶杸杻杼枋枌枒枓衾葄翹紓逋枙狸椏枟槁枲枳枴枵枷枸櫞枹枻柁柂柃柅柈柊柎某柑橘柒柘柙柚柜柞櫟柟柢柣柤柩柬柮柰柲橙柶柷柸柺査柿栃栄栒栔栘栝栟栢栩栫栭栱栲栳栴檀栵栻桀驁桁鎂桄桉桋桎梏椹葚桓桔桕桜桟桫欏桭桮桯桲桴桷桹湘溟梃梊梍梐潼梔梘梜梠梡梣梧梩梱梲梳梴梵梹棁棃櫻棐棑棕櫚簑繃蓑棖棘棜棨棩棪棫棬棯棰棱棳棸棹槨棼椀椄苕椈椊椋椌椐椑椓椗検椤椪椰椳椴椵椷椸椽椿楀楄楅篪楋楍楎楗楘楙楛楝楟楠楢楥楨楩楪楫楬楮楯楰楳楸楹楻楽榀榃榊榎槺榕榖榘榛狉莽榜笞榠榡榤榥榦榧榪榭榰榱槤霰榼榾榿槊閂槎槑槔槖様槜槢槥槧槪槭槮槱槲槻槼槾樆樊樏樑樕樗樘樛樟樠樧樨権樲樴樵猢猻樺樻罍樾樿橁橄橆橈笥龠橕橚橛輛橢橤橧豎膈跨橾橿檁檃檇檉檍檎檑檖檗檜檟檠檣檨檫檬檮檳檴檵檸櫂櫆櫌櫛櫜櫝櫡櫧櫨櫪櫬櫳櫹櫺茄櫽欀欂欃欐欑欒欙欞溴欨欬欱欵欶欷歔欸欹欻欼欿歁歃歆艎歈歊蒔蝶歓歕歘歙歛歜歟歠蹦詮鑲蹣跚陞陟歩歮歯歰歳歴璞歺瞑歾歿殀殈殍殑殗殜殙殛殞殢殣殥殪殫殭殰殳荃殷殸殹蛟殻殽謗毆毈毉餵毎毑蕈毗毘毚茛鄧毧毬毳毷毹毽毾毿氂氄氆靴氉氊氌氍氐聊氕氖気氘氙氚氛氜氝氡洶焊痙氤氳氥氦鋁鋅氪烴氬銨痤汪滸漉痘盂碾菖蒲蕹蛭螅氵氷氹氺氽燙氾氿渚汆汊汋汍汎汏汐汔汕褟汙汚汜蘺沼穢衊汧汨汩汭汲汳汴隄汾沄沅沆瀣沇沈葆浸淪湎溺痼痾沌沍沏沐沔沕沘浜畹礫沚沢沬沭沮沰沱灢沴沷籽沺烹濡洄泂肛泅泆湧肓泐泑泒泓泔泖泙泚泜泝泠漩饃濤粼濘蘚鰍泩泫泭泯銖泱泲洇洊涇琵琶荽薊箔洌洎洏洑潄濯洙洚洟洢洣洧洨洩痢滔洫洮洳洴洵洸洹洺洼洿淌蜚浄浉浙贛渫浠浡浤浥淼瀚浬浭翩萍浯浰蜃淀苔蛞蝓蜇螵蛸煲鯉浹浼浽溦涂涊涐涑涒涔滂涖涘涙涪涫涬涮涴涶涷涿淄淅淆淊淒黯淓淙漣淜淝淟淠淢淤淥淦淩猥藿褻淬淮淯淰淳詣淶紡淸淹燉癯綺渇済渉渋渓渕渙渟渢滓渤澥渧
渨渮渰渲渶渼湅湉湋湍湑湓湔黔湜湝湞湟湢湣湩湫湮麟湱湲湴湼満溈溍溎溏溛舐漭溠溤溧馴溮溱溲溳溵溷溻溼溽溾滁滃滉滊滎滏稽滕滘滙滝滫滮羼耷滷滹滻煎漈漊漎繹漕漖漘漙漚漜漪漾漥漦漯漰漵漶漷濞潀潁潎潏潕潗潚潝潞潠潦祉瘍潲潵潷潸潺潾潿澁澂澃澉澌澍澐澒澔澙澠澣澦澧澨澫澬澮澰澴澶澼熏郁濆濇濈濉濊貊濔疣濜濠濩觴濬濮盥濰濲濼瀁瀅瀆瀋瀌瀏瀒瀔瀕瀘瀛瀟瀠瀡瀦瀧瀨瀬瀰瀲瀳瀵瀹瀺瀼灃灄灉灋灒灕灖灝灞灠灤灥灨灩灪蜴灮燼獴灴灸灺炁炅魷炗炘炙炤炫疽烙釺炯炰炱炲炴炷燬炻烀烋瘴鯧烓烔焙烜烝烳飪烺焃焄耆焌焐焓焗焜焞焠焢焮焯焱焼煁煃煆煇煊熠煍熬煐煒煕煗燻礆霾煚煝煟煠煢矸煨瑣煬萁煳煺煻熀熅熇熉羆熒穹熗熘熛熜稔諳爍熤熨熯熰眶螞熲熳熸熿燀燁燂燄盞燊燋燏燔隼燖燜燠燡燦燨燮燹燻燽燿爇爊爓爚爝爟爨蟾爯爰爲爻爿爿牀牁牂牄牋牎牏牓牕釉牚腩蒡虻牠雖蠣牣牤牮牯牲牳牴牷牸牼絆牿靬犂犄犆犇犉犍犎犒犖犗犛犟犠犨犩犪犮犰狳犴犵犺狁甩狃狆狎狒獾狘狙黠狨狩狫狴狷狺狻豕狽蜘猁猇猈猊猋猓猖獗猗猘猙獰獁猞猟獕猭猱猲猳猷猸猹猺玃獀獃獉獍獏獐獒獘獙獚獜獝獞獠獢獣獧鼇蹊獪獫獬豸獮獯鬻獳獷獼玀玁菟玅玆玈珉糝禛郅玍玎玓瓅玔玕玖玗玘玞玠玡玢玤玥玦玨瑰玭玳瑁玶玷玹玼珂珇珈瑚珌饈饌珔珖珙珛珞珡珣珥珧珩珪珮珶珷珺珽琀琁隕琊琇琖琚琠琤琦琨琫琬琭琮琯琰琱琲瑯琹琺琿瑀瑂瑄瑉瑋瑑瑔瑗瑢瑭瑱瑲瑳瑽瑾瑿璀璨璁璅璆璈璉璊璐璘璚璝璟璠璡璥璦璩璪璫璯璲璵璸璺璿瓀瓔瓖瓘瓚瓛臍瓞瓠瓤瓧瓩瓮瓰瓱瓴瓸瓻瓼甀甁甃甄甇甋甍甎甏甑甒甓甔甕甖甗飴蔗甙詫鉅粱盎銹糰甡褥産甪甬甭甮甯鎧甹甽甾甿畀畁畇畈畊畋畎畓畚畛畟鄂畤畦畧荻畯畳畵畷畸畽畾疃疉疋疍疎簞疐疒疕疘疝疢疥疧疳疶疿痁痄痊痌痍痏痐痒痔痗瘢痚痠痡痣痦痩痭痯痱痳痵痻痿瘀瘂瘃瘈瘉瘊瘌瘏瘐瘓瘕瘖瘙瘚瘛瘲瘜瘝瘞瘠瘥瘨瘭瘮瘯瘰癧瘳癘瘵瘸瘺瘻瘼癃癆癇癈癎癐癔癙癜癠癤癥癩蟆癪癭癰発踔紺蔫酵皙砬砒翎翳蘞鎢鑞皚鵯駒鱀粵褶皀皁莢皃鎛皈皌皐皒硃皕皖皘皜皝皞皤皦皨皪皫皭糙綻皴皸皻皽盅盋盌盍盚盝踞盦盩鞦韆盬盭眦睜瞤盯盱眙裰盵盻睞眂眅眈眊県眑眕眚眛眞眢眣眭眳眴眵眹瞓眽郛睃睅睆睊睍睎睏睒睖睙睟睠睢睥睪睪睯睽睾瞇瞈瞋瞍逛瞏瞕瞖瞘瞜瞟瞠瞢瞫瞭瞳瞵瞷瞹瞽闍瞿矓矉矍鑠矔矗矙矚矞矟矠矣矧矬矯矰矱硪碇磙罅舫阡、矼矽礓砃砅砆砉砍砑砕砝砟砠砢砦砧砩砫砮砳艏砵砹砼硇硌硍硎硏硐硒硜硤硨磲茚鋇硭硻硾碃碉碏碣碓碔碞碡碪碫碬碭碯碲碸碻礡磈磉磎磑磔磕磖磛磟磠磡磤磥蹭磪磬磴磵磹磻磽礀礄礅礌礐礚礜礞礤礧礮礱礲礵礽礿祂祄祅祆禳祊祍祏祓祔祕祗祘祛祧祫祲祻祼餌臠錮禂禇禋禑禔禕隋禖禘禚禜禝禠禡禢禤禥禨禫禰禴禸稈秈秊闈颯秌秏秕笈蘵賃秠秣秪秫秬秭秷秸稊稌稍稑稗稙稛稞稬稭稲稹稼顙稾穂穄穇穈穉穋穌貯穏穜穟穠穡穣穤穧穨穭穮穵穸窿闃窀窂窅窆窈窕窊窋窌窒窓窔窞窣窬黷蹙窰窳窴窵窶窸窻竁竃竈竑竜竝竦竪篦篾笆鮫竾笉笊笎笏笐靨笓笤籙笪笫笭笮笰笱笲笳笵笸笻筀筅筇筈筎筑筘筠筤筥筦筧筩筭筯筰筱筳筴讌筸箂箇箊箎箑箒箘箙箛箜篌箝箠箬鏃箯箴箾篁篔簹篘篙篚篛篜篝篟篠篡篢篥篧篨篭篰篲篳篴篶篹篼簀簁簃簆簉簋簌簏簜簟簠簥簦簨簬簰簸簻籊籐籒籓籔籖籚籛籜籣籥籧籩籪籫籯芾麴籵籸籹籼粁粃粋粑粔糲粛粞粢粧粨粲粳粺粻粽闢粿糅糆糈糌糍糒糔萼糗蛆蹋糢糨糬糭糯糱糴糶糸糺紃蹼鰹黴紆紈絝紉閩襻紑紕紘錠鳶鷂紝紞紟紥紩紬紱紲紵紽紾紿絁絃絅経絍絎絏縭褵絓絖絘絜絢絣螯絪絫聒絰絵絶絺絻絿綀綃綅綆綈綉綌綍綎綑綖綘継続緞綣綦綪綫綮綯綰罟蝽綷縩綹綾緁緄緅緆緇緋緌緎総緑緔緖緗緘緙緜緡緤緥緦纂緪緰緱緲緶緹縁縃縄縈縉縋縏縑縕縗縚縝縞縟縠縡縢縦縧縯縰騁縲縳縴縵縶縹縻衙縿繄繅繈繊繋繐繒繖繘繙繠繢繣繨繮繰繸繻繾纁纆纇纈纉纊纑纕纘纙纚纛缾罃罆罈罋罌罎罏罖罘罛罝罠罣罥罦罨罫罭鍰罳罶罹罻罽罿羂羃羇羋蕉51鴕羑羖羗羜羝羢羣羥羧羭羮羰羱羵羶羸藜鮐翀翃翄翊翌翏翕翛翟翡翣翥翦躚翪翫翬翮翯翺翽翾翿闆饕鴰鍁耋耇耎耏耑耒耜耔耞耡耤耨耩耪耬耰鬢耵聹聃聆聎聝聡聦聱聴聶聼閾聿肄肏肐肕腋肙肜肟肧胛肫肬肭肰肴肵肸肼胊胍胏胑胔胗胙胝胠銓胤胦胩胬胭胯胰胲胴胹胻胼胾脇脘脝脞脡脣脤脥脧脰脲脳腆腊腌臢腍腒腓腖腜腠腡腥腧腬腯踝蹬鐐腴腶蠕誹膂膃膆膇膋膔膕膗膙膟黐膣膦膫膰膴膵膷膾臃臄臇臈臌臐臑臓臕臖臙臛臝臞臧蓐詡臽臾臿舀舁鰟鮍舋舎舔舗舘舝舠舡舢舨舭舲舳舴舸舺艁艄艅艉艋艑艕艖艗艘艚艜艟艣艤艨艩艫艬艭荏艴艶艸艹艻艿芃芄芊萰陂藭芏芔芘芚蕙芟芣芤茉芧芨芩芪芮芰鰱芴芷芸蕘豢芼芿苄苒苘苙苜蓿苠苡苣蕒苤苧苪鎊苶苹苺苻苾茀茁范蠡萣茆茇茈茌茍茖茞茠茢茥茦菰茭茯茳藨茷藘茼荁荄荅荇荈菅蜢鴞荍荑荘荳荵荸薺莆莒莔莕莘莙莚莛莜莝莦莨菪莩莪莭莰莿菀菆菉菎菏菐菑菓菔菕菘菝菡菢菣菥蓂菧菫轂鎣菶菷菹醢菺菻菼菾萅萆萇萋萏萐萑萜萩萱萴萵萹萻葇葍葎葑葒葖葙葠葥葦葧葭葯葳葴葶葸葹葽蒄蒎蒓蘢薹蒞蒟蒻蒢蒦蒨蒭藁蒯蒱鉾蒴蒹蒺蒽蓀蓁蓆蓇蓊蓌蓍蓏蓓蓖蓧蓪蓫蓽跣藕蓯蓰蓱蓴蓷蓺蓼蔀蔂蔃蔆蔇蔉蔊蔋蔌蔎蔕蔘蔙蔞蔟鍔蔣雯蔦蔯蔳蔴蔵蔸蔾蕁蕆蕋蕍蕎蕐蕑蕓蕕蕖蕗蕝蕞蕠蕡蕢蕣蕤蕨蕳蕷蕸蕺蕻薀薁薃薅薆薈薉薌薏薐薔薖薘薙諤釵薜薠薢薤薧薨薫薬薳薶薷薸薽薾薿藄藇藋藎藐藙藚藟藦藳藴藶藷藾蘀蘁蘄蘋蘗蘘蘝蘤蘧蘩蘸蘼虀虆虍蟠虒虓虖虡虣虥虩虯虰蛵虵虷鱒虺虼蚆蚈蚋蚓蚔蚖蚘蚜蚡蚣蚧蚨蚩蚪蚯蚰蜒蚱蚳蚶蚹蚺蚻蚿蛀蛁蛄蛅蝮蛌蛍蛐蟮蛑蛓蛔蛘蛚蛜蛡蛣蜊蛩蛺蛻螫蜅蜆蜈蝣蜋蜍蜎蜑蠊蜛餞蜞蜣蜨蜩蜮蜱蜷蜺蜾蜿蝀蝃蝋蝌蝍蝎蝏蝗蝘蝙蝝鱝蝡蝤蝥蝯蝰蝱蝲蝴蝻螃蠏螄螉螋螒螓螗螘螙螚蟥螟螣螥螬螭螮螾螿蟀蟅蟈蟊蟋蟑蟓蟛蟜蟟蟢蟣蟨蟪蟭蟯蟳蟶蟷蟺蟿蠁蠂蠃蠆蠋蠐蠓蠔蠗蠙蠚蠛蠜蠧蠨蠩蠭蠮蠰蠲蠵蠸蠼蠽衁衂衄衇衈衉衋衎衒衕衖衚衞裳鈎衭衲衵衹衺衿袈裟袗袚袟袢袪袮袲袴袷袺袼褙袽裀裉裊裋裌裍裎裒裛裯裱裲裴裾褀褂褉褊褌褎褐褒褓褔褕褘褚褡褢褦褧褪褫褭褯褰褱襠褸褽褾襁襃襆襇襉襋襌襏襚襛襜襝襞襡襢襤襦襫襬襭襮襴襶襼襽襾覂覃覅覇覉覊覌覗覘覚覜覥覦覧覩覬覯覰観覿觔觕觖觜觽觝觡酲觩觫觭觱觳觶觷觼觾觿言賅訃訇訏訑訒詁託訧訬訳訹証訾詀詅詆譭詈詊詎詑詒詖詗詘詧詨詵詶詸詹詻詼詿誂誃誄鋤誆誋誑誒誖誙誚誥誧説読誯誶誾諂諄諆諌諍諏諑諕諗諛諝諞諟諠諡諴諵諶諼謄謆謇謌謍謏謑謖謚謡謦謪謫謳謷謼謾譁譅譆譈譊譌譒譔譖鑫譞譟譩譫譬譱譲譴譸譹譾讅讆讋讌讎讐讒讖讙讜讟谽豁豉豇豈豊豋豌豏豔豞豖豗豜豝豣豦豨豭豱豳豵豶豷豺豻貅貆貍貎貔貘貙貜貤饜貰餸貺賁賂賏賒賕賙賝賡賧賨賫鬭賮賵賸賺賻賾贇贉贐贔贕贗赬赭赱赳迄趁趂趄趐趑趒趔趡趦趫趮趯趲趴趵趷趹趺趿跁跂跅跆躓蹌跐跕跖跗跙跛跦跧跩跫跬跮跱跲跴跺跼跽踅踆踈踉踊踒踖踘踜踟躇躕踠踡踣踤踥踦踧蹺踫踮踰踱踴踶踹踺踼踽躞蹁蹂躪蹎蹐蹓蹔蹕蹚蹜蹝蹟蹠蹡蹢躂蹧蹩蹪蹯鞠蹽躃躄躅躊躋躐躑躒躘躙躛躝躠躡躦躧躩躭躰躳躶軃軆輥軏軔軘軜軝齶転軥軨軭軱軲轆軷軹軺軽軿輀輂輦輅輇輈輓輗輙輜輞輠輤輬輭輮輳輴輵輶輹輼輾轀轇轏轑轒轔轕轖轗轘轙轝轞轢轤辠辢辤辵辶辺込辿迅迋迍麿迓迣迤邐迥迨迮迸迺迻迿逄逅逌逍逑逓逕逖逡逭逯逴逶逹遄遅遉遘遛遝遢遨遫遯遰遴遶遹遻邂邅邉邋邎邕邗邘邛邠邢邧邨邯鄲邰邲邳邴邶邷邽邾邿郃郄郇郈郔郕郗郙郚郜郝郞郟郠郢郪郫郯郰郲郳郴郷郹郾郿鄀鄄鄆鄇鄈鄋鄍鄎鄏鄐鄑鄒鄔鄕鄖鄗鄘鄚鄜鄞鄠鄢鄣鄤鄦鄩鄫鄬鄮鄯鄱鄶鄷鄹鄺鄻鄾鄿酃酅酆酇酈酊酋酎酏酐酣酔酕醄酖酗酞酡酢酤酩酴酹酺醁醅醆醊醍醐醑醓醖醝醞醡醤醨醪醭醯醰醱醲醴醵醸醹醼醽醾釂釃釅釆釈鱸鎦閶釓釔釕鈀釙鼢鼴釤釧釪釬釭釱釷釸釹鈁鈃鈄鈆鈇鈈鈊鈌鈐鈑鈒鈤鈥鈧鈬鈮鈰鈳鐺鈸鈹鈽鈿鉄鉆鉈鉋鉌鉍鉏鉑鉕鉚鉢鉥鉦鉨鉬鉭鉱鉲鉶鉸鉺鉼鉿銍銎銑銕鏤銚銛銠銣銤銥銦銧銩銪銫銭銰銲銶銻銼銾鋂鋃鋆鋈鋊鋌鋍鋏鋐鋑鋕鋘鋙鋝鋟鋦鋨鋩鋭鋮鋯鋰鋱鋳鋹鋺鋻鏰鐱錀錁錆錇錈錍錏錒錔錙錚錛錞錟錡錤錩錬録錸錼鍀鍆鍇鍉鍍鍏鍐鍘鍚鍛鍠鍤鍥鍩鍫鍭鍱鍴鍶鍹鍺鍼鍾鎄鎇鎉鎋鎌鎍鎏鎒鎓鎗鎘鎚鎞鎡鎤鎩鎪鎭鎯鎰鎳鎴鎵鎸鎹鎿鏇鏊鏌鏐鏑鏖鏗鏘鏚鏜鏝鏞鏠鏦鏨鏷鏸鏹鏻鏽鏾鐃鐄鐇鐏鐒鐓鐔鐗馗鐙鐝鐠鐡鐦鐨鐩鐫鐬鐱鐳鐶鐻鐽鐿鑀鑅鑌鑐鑕鑚鑛鑢鑤鑥鑪鑭鑯鑱鑴鑵鑷钁钃镻閆閈閌閎閒閔閗閟閡関閤閤閧閬閲閹閺閻閼閽閿闇闉闋闐闑闒闓闘闚闞闟闠闤闥阞阢阤阨阬阯阹阼阽陁陑陔陛陜陡陥陬騭陴険陼陾隂隃隈隒隗隞隠隣隤隩隮隰顴隳隷隹雂雈雉雊雎雑雒雗雘雚雝雟雩雰雱驛霂霅霈霊霑霒霓霙霝霢霣霤霨霩霪霫霮靁靆靉靑靚靣靦靪靮靰靳靷靸靺靼靿鞀鞃鞄鞌鞗鞙鞚鞝鞞鞡鞣鞨鞫鞬鞮鞶鞹鞾韃韅韉馱韍韎韔韖韘韝韞韡韣韭韮韱韹韺頀颳頄頇頊頍頎頏頒頖頞頠頫頬顱頯頲頴頼顇顋顑顒顓顔顕顚顜顢顣顬顳颭颮颱颶颸颺颻颽颾颿飀飂飈飌飜飡飣飤飥飩飫飮飱飶餀餂餄餎餇餈餑餔餕餖餗餚餛餜餟餠餤餧餩餪餫餬餮餱餲餳餺餻餼餽餿饁饅饇饉饊饍饎饐饘饟饢馘馥馝馡馣騮騾馵馹駃駄駅駆駉駋駑駓駔駗駘駙駜駡駢駪駬駰駴駸駹駽駾騂騄騅騆騉騋騍騏驎騑騒験騕騖騠騢騣騤騧驤騵騶騸騺驀驂驃驄驆驈驊驌驍驎驏驒驔驖驙驦驩驫骺鯁骫骭骯骱骴骶骷髏骾髁髂髄髆髈髐髑髕髖髙髝髞髟髡髣髧髪髫髭髯髲髳髹髺髽髾鬁鬃鬅
鬈鬋鬎鬏鬐鬑鬒鬖鬗鬘鬙鬠鬣鬪鬫鬬鬮鬯鬰鬲鬵鬷魆魈魊魋魍魎魑魖鰾魛魟魣魦魨魬魴魵魸鮀鮁鮆鮌鮎鮑鮒鮓鮚鮞鮟鱇鮠鮦鮨鮪鮭鮶鮸鮿鯀鯄鯆鯇鯈鯔鯕鯖鯗鯙鯠鯤鯥鯫鯰鯷鯸鯿鰂鰆鶼鰉鰋鰐鰒鰕鰛鰜鰣鰤鰥鰦鰨鰩鰮鰳鰶鰷鱺鰼鰽鱀鱄鱅鱆鱈鱎鱐鱓鱔鱖鱘鱟鱠鱣鱨鱭鱮鱲鱵鱻鲅鳦鳧鳯鳲鳷鳻鴂鴃鴄鴆鴈鴎鴒鴔鴗鴛鴦鴝鵒鴟鴠鴢鴣鴥鴯鶓鴳鴴鴷鴽鵀鵁鵂鵓鵖鵙鵜鶘鵞鵟鵩鵪鵫鵵鵷鵻鵾鶂鶊鶏鶒鶖鶗鶡鶤鶦鶬鶱鶲鶵鶸鶹鶺鶿鷀鷁鷃鷄鷇鷈鷉鷊鷏鷓鷕鷖鷙鷞鷟鷥鷦鷯鷩鷫鷭鷳鷴鷽鷾鷿鸂鸇鸊鸏鸑鸒鸓鸕鸛鸜鸝鹸鹹鹺麀麂麃麄麇麋麌麐麑麒麚麛麝麤麩麪麫麮麯麰麺麾黁黈黌黢黒黓黕黙黝黟黥黦黧黮黰黱黲黶黹黻黼黽黿鼂鼃鼅鼈鼉鼏鼐鼒鼕鼖鼙鼚鼛鼡鼩鼱鼪鼫鼯鼷鼽齁齆齇齈齉齌齎齏齔齕齗齙齚齜齞齟齬齠齢齣齧齩齮齯齰齱齵齾龎龑龒龔龖龘龝龡龢龤" + +assert len(simplified_charcters) == len(simplified_charcters) + +s2t_dict = {} +t2s_dict = {} +for i, item in enumerate(simplified_charcters): + s2t_dict[item] = traditional_characters[i] + t2s_dict[traditional_characters[i]] = item + + +def tranditional_to_simplified(text: str) -> str: + return "".join([t2s_dict[item] if item in t2s_dict else item for item in text]) + + +def simplified_to_traditional(text: str) -> str: + return "".join([s2t_dict[item] if item in s2t_dict else item for item in text]) + + +if __name__ == "__main__": + text = "一般是指存取一個應用程式啟動時始終顯示在網站或網頁瀏覽器中的一個或多個初始網頁等畫面存在的站點" + print(text) + text_simple = tranditional_to_simplified(text) + print(text_simple) + text_traditional = simplified_to_traditional(text_simple) + print(text_traditional) diff --git a/text/zh_normalization/chronology.py b/text/zh_normalization/chronology.py new file mode 100644 index 0000000000000000000000000000000000000000..2a6f66c21de80aba717434849a36065f4b885a12 --- /dev/null +++ b/text/zh_normalization/chronology.py @@ -0,0 +1,139 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import re + +from .num import DIGITS +from .num import num2str +from .num import verbalize_cardinal +from .num import verbalize_digit + + +def _time_num2str(num_string: str) -> str: + """A special case for verbalizing number in time.""" + result = num2str(num_string.lstrip("0")) + if num_string.startswith("0"): + result = DIGITS["0"] + result + return result + + +# 时刻表达式 +RE_TIME = re.compile( + r"([0-1]?[0-9]|2[0-3])" + r":([0-5][0-9])" + r"(:([0-5][0-9]))?" +) + +# 时间范围,如8:30-12:30 +RE_TIME_RANGE = re.compile( + r"([0-1]?[0-9]|2[0-3])" + r":([0-5][0-9])" + r"(:([0-5][0-9]))?" + r"(~|-)" + r"([0-1]?[0-9]|2[0-3])" + r":([0-5][0-9])" + r"(:([0-5][0-9]))?" +) + + +def replace_time(match) -> str: + """ + Args: + match (re.Match) + Returns: + str + """ + + is_range = len(match.groups()) > 5 + + hour = match.group(1) + minute = match.group(2) + second = match.group(4) + + if is_range: + hour_2 = match.group(6) + minute_2 = match.group(7) + second_2 = match.group(9) + + result = f"{num2str(hour)}点" + if minute.lstrip("0"): + if int(minute) == 30: + result += "半" + else: + result += f"{_time_num2str(minute)}分" + if second and second.lstrip("0"): + result += f"{_time_num2str(second)}秒" + + if is_range: + result += "至" + result += f"{num2str(hour_2)}点" + if minute_2.lstrip("0"): + if int(minute) == 30: + result += "半" + else: + result += f"{_time_num2str(minute_2)}分" + if second_2 and second_2.lstrip("0"): + result += f"{_time_num2str(second_2)}秒" + + return result + + +RE_DATE = re.compile( + r"(\d{4}|\d{2})年" + r"((0?[1-9]|1[0-2])月)?" 
+ r"(((0?[1-9])|((1|2)[0-9])|30|31)([日号]))?" +) + + +def replace_date(match) -> str: + """ + Args: + match (re.Match) + Returns: + str + """ + year = match.group(1) + month = match.group(3) + day = match.group(5) + result = "" + if year: + result += f"{verbalize_digit(year)}年" + if month: + result += f"{verbalize_cardinal(month)}月" + if day: + result += f"{verbalize_cardinal(day)}{match.group(9)}" + return result + + +# 用 / 或者 - 分隔的 YY/MM/DD 或者 YY-MM-DD 日期 +RE_DATE2 = re.compile(r"(\d{4})([- /.])(0[1-9]|1[012])\2(0[1-9]|[12][0-9]|3[01])") + + +def replace_date2(match) -> str: + """ + Args: + match (re.Match) + Returns: + str + """ + year = match.group(1) + month = match.group(3) + day = match.group(4) + result = "" + if year: + result += f"{verbalize_digit(year)}年" + if month: + result += f"{verbalize_cardinal(month)}月" + if day: + result += f"{verbalize_cardinal(day)}日" + return result diff --git a/text/zh_normalization/constants.py b/text/zh_normalization/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..4218a551dc0425e9d0726220e9859708fddedd89 --- /dev/null +++ b/text/zh_normalization/constants.py @@ -0,0 +1,62 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import re +import string + +from pypinyin.constants import SUPPORT_UCS4 + +# 全角半角转换 +# 英文字符全角 -> 半角映射表 (num: 52) +F2H_ASCII_LETTERS = {ord(char) + 65248: ord(char) for char in string.ascii_letters} + +# 英文字符半角 -> 全角映射表 +H2F_ASCII_LETTERS = {value: key for key, value in F2H_ASCII_LETTERS.items()} + +# 数字字符全角 -> 半角映射表 (num: 10) +F2H_DIGITS = {ord(char) + 65248: ord(char) for char in string.digits} +# 数字字符半角 -> 全角映射表 +H2F_DIGITS = {value: key for key, value in F2H_DIGITS.items()} + +# 标点符号全角 -> 半角映射表 (num: 32) +F2H_PUNCTUATIONS = {ord(char) + 65248: ord(char) for char in string.punctuation} +# 标点符号半角 -> 全角映射表 +H2F_PUNCTUATIONS = {value: key for key, value in F2H_PUNCTUATIONS.items()} + +# 空格 (num: 1) +F2H_SPACE = {"\u3000": " "} +H2F_SPACE = {" ": "\u3000"} + +# 非"有拼音的汉字"的字符串,可用于NSW提取 +if SUPPORT_UCS4: + RE_NSW = re.compile( + r"(?:[^" + r"\u3007" # 〇 + r"\u3400-\u4dbf" # CJK扩展A:[3400-4DBF] + r"\u4e00-\u9fff" # CJK基本:[4E00-9FFF] + r"\uf900-\ufaff" # CJK兼容:[F900-FAFF] + r"\U00020000-\U0002A6DF" # CJK扩展B:[20000-2A6DF] + r"\U0002A703-\U0002B73F" # CJK扩展C:[2A700-2B73F] + r"\U0002B740-\U0002B81D" # CJK扩展D:[2B740-2B81D] + r"\U0002F80A-\U0002FA1F" # CJK兼容扩展:[2F800-2FA1F] + r"])+" + ) +else: + RE_NSW = re.compile( # pragma: no cover + r"(?:[^" + r"\u3007" # 〇 + r"\u3400-\u4dbf" # CJK扩展A:[3400-4DBF] + r"\u4e00-\u9fff" # CJK基本:[4E00-9FFF] + r"\uf900-\ufaff" # CJK兼容:[F900-FAFF] + r"])+" + ) diff --git a/text/zh_normalization/num.py b/text/zh_normalization/num.py new file mode 100644 index 0000000000000000000000000000000000000000..c3af4d6abbdf3f4f9512f1a5eb8eec77c93689e4 --- /dev/null +++ b/text/zh_normalization/num.py @@ -0,0 +1,317 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Rules to verbalize numbers into Chinese characters. +https://zh.wikipedia.org/wiki/中文数字#現代中文 +""" + +import re +from collections import OrderedDict +from typing import List + +DIGITS = {str(i): tran for i, tran in enumerate("零一二三四五六七八九")} +UNITS = OrderedDict( + { + 1: "十", + 2: "百", + 3: "千", + 4: "万", + 8: "亿", + } +) + +COM_QUANTIFIERS = "(处|台|架|枚|趟|幅|平|方|堵|间|床|株|批|项|例|列|篇|栋|注|亩|封|艘|把|目|套|段|人|所|朵|匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|毫|厘|(公)分|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|小时|旬|纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|元|(亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|美|)元|(亿|千万|百万|万|千|百|十|)吨|(亿|千万|百万|万|千|百|)块|角|毛|分)" + +# 分数表达式 +RE_FRAC = re.compile(r"(-?)(\d+)/(\d+)") + + +def replace_frac(match) -> str: + """ + Args: + match (re.Match) + Returns: + str + """ + sign = match.group(1) + nominator = match.group(2) + denominator = match.group(3) + sign: str = "负" if sign else "" + nominator: str = num2str(nominator) + denominator: str = num2str(denominator) + result = f"{sign}{denominator}分之{nominator}" + return result + + +# 百分数表达式 +RE_PERCENTAGE = re.compile(r"(-?)(\d+(\.\d+)?)%") + + +def replace_percentage(match) -> str: + """ + Args: + match (re.Match) + Returns: + str + """ + sign = match.group(1) + percent = match.group(2) + sign: str = "负" if sign else "" + percent: str = num2str(percent) + result = f"{sign}百分之{percent}" + return result + + +# 整数表达式 +# 带负号的整数 -10 +RE_INTEGER = re.compile(r"(-)" r"(\d+)") + + +def replace_negative_num(match) -> str: + """ + Args: + match (re.Match) + Returns: + str + """ + sign = match.group(1) + number = match.group(2) + sign: str = "负" if sign else "" + number: str = num2str(number) + result = f"{sign}{number}" + return result + + +# 编号-无符号整形 +# 00078 +RE_DEFAULT_NUM = re.compile(r"\d{3}\d*") + + +def replace_default_num(match): + """ + Args: + match (re.Match) + Returns: + str + """ + number = match.group(0) + return verbalize_digit(number, alt_one=True) + + +# 加减乘除 +# RE_ASMD = re.compile( +# r'((-?)((\d+)(\.\d+)?)|(\.(\d+)))([\+\-\×÷=])((-?)((\d+)(\.\d+)?)|(\.(\d+)))') +RE_ASMD = re.compile( + r"((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))([\+\-\×÷=])((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))" +) + +asmd_map = {"+": "加", "-": "减", "×": "乘", "÷": "除", "=": "等于"} + + +def replace_asmd(match) -> str: + """ + Args: + match (re.Match) + Returns: + str + """ + result = match.group(1) + asmd_map[match.group(8)] + match.group(9) + return result + + +# 次方专项 +RE_POWER = re.compile(r"[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]+") + +power_map = { + "⁰": "0", + "¹": "1", + "²": "2", + "³": "3", + "⁴": "4", + "⁵": "5", + "⁶": "6", + "⁷": "7", + "⁸": "8", + "⁹": "9", + "ˣ": "x", + "ʸ": "y", + "ⁿ": "n", +} + + +def 
replace_power(match) -> str: + """ + Args: + match (re.Match) + Returns: + str + """ + power_num = "" + for m in match.group(0): + power_num += power_map[m] + result = "的" + power_num + "次方" + return result + + +# 数字表达式 +# 纯小数 +RE_DECIMAL_NUM = re.compile(r"(-?)((\d+)(\.\d+))" r"|(\.(\d+))") +# 正整数 + 量词 +RE_POSITIVE_QUANTIFIERS = re.compile(r"(\d+)([多余几\+])?" + COM_QUANTIFIERS) +RE_NUMBER = re.compile(r"(-?)((\d+)(\.\d+)?)" r"|(\.(\d+))") + + +def replace_positive_quantifier(match) -> str: + """ + Args: + match (re.Match) + Returns: + str + """ + number = match.group(1) + match_2 = match.group(2) + if match_2 == "+": + match_2 = "多" + match_2: str = match_2 if match_2 else "" + quantifiers: str = match.group(3) + number: str = num2str(number) + number = "两" if number == "二" else number + result = f"{number}{match_2}{quantifiers}" + return result + + +def replace_number(match) -> str: + """ + Args: + match (re.Match) + Returns: + str + """ + sign = match.group(1) + number = match.group(2) + pure_decimal = match.group(5) + if pure_decimal: + result = num2str(pure_decimal) + else: + sign: str = "负" if sign else "" + number: str = num2str(number) + result = f"{sign}{number}" + return result + + +# 范围表达式 +# match.group(1) and match.group(8) are copy from RE_NUMBER + +RE_RANGE = re.compile( + r""" + (? str: + """ + Args: + match (re.Match) + Returns: + str + """ + first, second = match.group(1), match.group(6) + first = RE_NUMBER.sub(replace_number, first) + second = RE_NUMBER.sub(replace_number, second) + result = f"{first}到{second}" + return result + + +# ~至表达式 +RE_TO_RANGE = re.compile( + r"((-?)((\d+)(\.\d+)?)|(\.(\d+)))(%|°C|℃|度|摄氏度|cm2|cm²|cm3|cm³|cm|db|ds|kg|km|m2|m²|m³|m3|ml|m|mm|s)[~]((-?)((\d+)(\.\d+)?)|(\.(\d+)))(%|°C|℃|度|摄氏度|cm2|cm²|cm3|cm³|cm|db|ds|kg|km|m2|m²|m³|m3|ml|m|mm|s)" +) + + +def replace_to_range(match) -> str: + """ + Args: + match (re.Match) + Returns: + str + """ + result = match.group(0).replace("~", "至") + return result + + +def _get_value(value_string: str, use_zero: bool = True) -> List[str]: + stripped = value_string.lstrip("0") + if len(stripped) == 0: + return [] + elif len(stripped) == 1: + if use_zero and len(stripped) < len(value_string): + return [DIGITS["0"], DIGITS[stripped]] + else: + return [DIGITS[stripped]] + else: + largest_unit = next(power for power in reversed(UNITS.keys()) if power < len(stripped)) + first_part = value_string[:-largest_unit] + second_part = value_string[-largest_unit:] + return _get_value(first_part) + [UNITS[largest_unit]] + _get_value(second_part) + + +def verbalize_cardinal(value_string: str) -> str: + if not value_string: + return "" + + # 000 -> '零' , 0 -> '零' + value_string = value_string.lstrip("0") + if len(value_string) == 0: + return DIGITS["0"] + + result_symbols = _get_value(value_string) + # verbalized number starting with '一十*' is abbreviated as `十*` + if len(result_symbols) >= 2 and result_symbols[0] == DIGITS["1"] and result_symbols[1] == UNITS[1]: + result_symbols = result_symbols[1:] + return "".join(result_symbols) + + +def verbalize_digit(value_string: str, alt_one=False) -> str: + result_symbols = [DIGITS[digit] for digit in value_string] + result = "".join(result_symbols) + if alt_one: + result = result.replace("一", "幺") + return result + + +def num2str(value_string: str) -> str: + integer_decimal = value_string.split(".") + if len(integer_decimal) == 1: + integer = integer_decimal[0] + decimal = "" + elif len(integer_decimal) == 2: + integer, decimal = integer_decimal + else: + raise ValueError(f"The value 
string: '${value_string}' has more than one point in it.") + + result = verbalize_cardinal(integer) + + decimal = decimal.rstrip("0") + if decimal: + # '.22' is verbalized as '零点二二' + # '3.20' is verbalized as '三点二 + result = result if result else "零" + result += "点" + verbalize_digit(decimal) + return result diff --git a/text/zh_normalization/phonecode.py b/text/zh_normalization/phonecode.py new file mode 100644 index 0000000000000000000000000000000000000000..3560ac2ed265580e5f150da191c21acea7390087 --- /dev/null +++ b/text/zh_normalization/phonecode.py @@ -0,0 +1,59 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import re + +from .num import verbalize_digit + +# 规范化固话/手机号码 +# 手机 +# http://www.jihaoba.com/news/show/13680 +# 移动:139、138、137、136、135、134、159、158、157、150、151、152、188、187、182、183、184、178、198 +# 联通:130、131、132、156、155、186、185、176 +# 电信:133、153、189、180、181、177 +RE_MOBILE_PHONE = re.compile(r"(? str: + if mobile: + sp_parts = phone_string.strip("+").split() + result = ",".join([verbalize_digit(part, alt_one=True) for part in sp_parts]) + return result + else: + sil_parts = phone_string.split("-") + result = ",".join([verbalize_digit(part, alt_one=True) for part in sil_parts]) + return result + + +def replace_phone(match) -> str: + """ + Args: + match (re.Match) + Returns: + str + """ + return phone2str(match.group(0), mobile=False) + + +def replace_mobile(match) -> str: + """ + Args: + match (re.Match) + Returns: + str + """ + return phone2str(match.group(0)) diff --git a/text/zh_normalization/quantifier.py b/text/zh_normalization/quantifier.py new file mode 100644 index 0000000000000000000000000000000000000000..1e7f2aab07d8b16fd1d69a93c70100f969a7ae51 --- /dev/null +++ b/text/zh_normalization/quantifier.py @@ -0,0 +1,63 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
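# --- Illustrative usage sketch, not part of the original diff -----------------
# Exercises the number verbalization helpers defined in num.py above; phone
# numbers in phonecode.py are read digit-by-digit through the same
# verbalize_digit(..., alt_one=True) path. Assumes the package is importable as
# `text.zh_normalization`; expected outputs are shown in the comments.
from text.zh_normalization.num import RE_PERCENTAGE, num2str, replace_percentage, verbalize_digit

print(num2str("15"))                                 # 十五 (leading 一十 is abbreviated)
print(num2str("109"))                                # 一百零九
print(num2str("3.14"))                               # 三点一四
print(verbalize_digit("13812345678", alt_one=True))  # 幺三八幺二三四五六七八
print(RE_PERCENTAGE.sub(replace_percentage, "-5%"))  # 负百分之五
# ------------------------------------------------------------------------------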
+import re + +from .num import num2str + +# 温度表达式,温度会影响负号的读法 +# -3°C 零下三度 +RE_TEMPERATURE = re.compile(r"(-?)(\d+(\.\d+)?)(°C|℃|度|摄氏度)") +measure_dict = { + "cm2": "平方厘米", + "cm²": "平方厘米", + "cm3": "立方厘米", + "cm³": "立方厘米", + "cm": "厘米", + "db": "分贝", + "ds": "毫秒", + "kg": "千克", + "km": "千米", + "m2": "平方米", + "m²": "平方米", + "m³": "立方米", + "m3": "立方米", + "ml": "毫升", + "m": "米", + "mm": "毫米", + "s": "秒", +} + + +def replace_temperature(match) -> str: + """ + Args: + match (re.Match) + Returns: + str + """ + sign = match.group(1) + temperature = match.group(2) + unit = match.group(3) + sign: str = "零下" if sign else "" + temperature: str = num2str(temperature) + unit: str = "摄氏度" if unit == "摄氏度" else "度" + result = f"{sign}{temperature}{unit}" + return result + + +def replace_measure(sentence) -> str: + for q_notation in measure_dict: + if q_notation in sentence: + sentence = sentence.replace(q_notation, measure_dict[q_notation]) + return sentence diff --git a/text/zh_normalization/text_normlization.py b/text/zh_normalization/text_normlization.py new file mode 100644 index 0000000000000000000000000000000000000000..099b01bd6b3ba9b7516e821a7d93239170ca3c68 --- /dev/null +++ b/text/zh_normalization/text_normlization.py @@ -0,0 +1,172 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
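# --- Illustrative usage sketch, not part of the original diff -----------------
# Exercises the temperature and measurement-unit rules from quantifier.py above.
# Assumes the package is importable as `text.zh_normalization`; expected outputs
# are shown in the comments.
from text.zh_normalization.quantifier import RE_TEMPERATURE, replace_measure, replace_temperature

print(RE_TEMPERATURE.sub(replace_temperature, "今天-3℃"))  # 今天零下三度
print(replace_measure("时速120km"))                        # 时速120千米
# ------------------------------------------------------------------------------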
+import re +from typing import List + +from .char_convert import tranditional_to_simplified +from .chronology import RE_DATE +from .chronology import RE_DATE2 +from .chronology import RE_TIME +from .chronology import RE_TIME_RANGE +from .chronology import replace_date +from .chronology import replace_date2 +from .chronology import replace_time +from .constants import F2H_ASCII_LETTERS +from .constants import F2H_DIGITS +from .constants import F2H_SPACE +from .num import RE_DECIMAL_NUM +from .num import RE_DEFAULT_NUM +from .num import RE_FRAC +from .num import RE_INTEGER +from .num import RE_NUMBER +from .num import RE_PERCENTAGE +from .num import RE_POSITIVE_QUANTIFIERS +from .num import RE_RANGE +from .num import RE_TO_RANGE +from .num import RE_ASMD +from .num import RE_POWER +from .num import replace_default_num +from .num import replace_frac +from .num import replace_negative_num +from .num import replace_number +from .num import replace_percentage +from .num import replace_positive_quantifier +from .num import replace_range +from .num import replace_to_range +from .num import replace_asmd +from .num import replace_power +from .phonecode import RE_MOBILE_PHONE +from .phonecode import RE_NATIONAL_UNIFORM_NUMBER +from .phonecode import RE_TELEPHONE +from .phonecode import replace_mobile +from .phonecode import replace_phone +from .quantifier import RE_TEMPERATURE +from .quantifier import replace_measure +from .quantifier import replace_temperature + + +class TextNormalizer: + def __init__(self): + self.SENTENCE_SPLITOR = re.compile(r"([:、,;。?!,;?!][”’]?)") + + def _split(self, text: str, lang="zh") -> List[str]: + """Split long text into sentences with sentence-splitting punctuations. + Args: + text (str): The input text. + Returns: + List[str]: Sentences. 
+ """ + # Only for pure Chinese here + if lang == "zh": + text = text.replace(" ", "") + # 过滤掉特殊字符 + text = re.sub(r"[——《》【】<>{}()()#&@“”^_|\\]", "", text) + text = self.SENTENCE_SPLITOR.sub(r"\1\n", text) + text = text.strip() + sentences = [sentence.strip() for sentence in re.split(r"\n+", text)] + return sentences + + def _post_replace(self, sentence: str) -> str: + sentence = sentence.replace("/", "每") + # sentence = sentence.replace('~', '至') + # sentence = sentence.replace('~', '至') + sentence = sentence.replace("①", "一") + sentence = sentence.replace("②", "二") + sentence = sentence.replace("③", "三") + sentence = sentence.replace("④", "四") + sentence = sentence.replace("⑤", "五") + sentence = sentence.replace("⑥", "六") + sentence = sentence.replace("⑦", "七") + sentence = sentence.replace("⑧", "八") + sentence = sentence.replace("⑨", "九") + sentence = sentence.replace("⑩", "十") + sentence = sentence.replace("α", "阿尔法") + sentence = sentence.replace("β", "贝塔") + sentence = sentence.replace("γ", "伽玛").replace("Γ", "伽玛") + sentence = sentence.replace("δ", "德尔塔").replace("Δ", "德尔塔") + sentence = sentence.replace("ε", "艾普西龙") + sentence = sentence.replace("ζ", "捷塔") + sentence = sentence.replace("η", "依塔") + sentence = sentence.replace("θ", "西塔").replace("Θ", "西塔") + sentence = sentence.replace("ι", "艾欧塔") + sentence = sentence.replace("κ", "喀帕") + sentence = sentence.replace("λ", "拉姆达").replace("Λ", "拉姆达") + sentence = sentence.replace("μ", "缪") + sentence = sentence.replace("ν", "拗") + sentence = sentence.replace("ξ", "克西").replace("Ξ", "克西") + sentence = sentence.replace("ο", "欧米克伦") + sentence = sentence.replace("π", "派").replace("Π", "派") + sentence = sentence.replace("ρ", "肉") + sentence = sentence.replace("ς", "西格玛").replace("Σ", "西格玛").replace("σ", "西格玛") + sentence = sentence.replace("τ", "套") + sentence = sentence.replace("υ", "宇普西龙") + sentence = sentence.replace("φ", "服艾").replace("Φ", "服艾") + sentence = sentence.replace("χ", "器") + sentence = sentence.replace("ψ", "普赛").replace("Ψ", "普赛") + sentence = sentence.replace("ω", "欧米伽").replace("Ω", "欧米伽") + # 兜底数学运算,顺便兼容懒人用语 + sentence = sentence.replace("+", "加") + sentence = sentence.replace("-", "减") + sentence = sentence.replace("×", "乘") + sentence = sentence.replace("÷", "除") + sentence = sentence.replace("=", "等") + # re filter special characters, have one more character "-" than line 68 + sentence = re.sub(r"[-——《》【】<=>{}()()#&@“”^_|\\]", "", sentence) + return sentence + + def normalize_sentence(self, sentence: str) -> str: + # basic character conversions + sentence = tranditional_to_simplified(sentence) + sentence = sentence.translate(F2H_ASCII_LETTERS).translate(F2H_DIGITS).translate(F2H_SPACE) + + # number related NSW verbalization + sentence = RE_DATE.sub(replace_date, sentence) + sentence = RE_DATE2.sub(replace_date2, sentence) + + # range first + sentence = RE_TIME_RANGE.sub(replace_time, sentence) + sentence = RE_TIME.sub(replace_time, sentence) + + # 处理~波浪号作为至的替换 + sentence = RE_TO_RANGE.sub(replace_to_range, sentence) + sentence = RE_TEMPERATURE.sub(replace_temperature, sentence) + sentence = replace_measure(sentence) + + # 处理数学运算 + while RE_ASMD.search(sentence): + sentence = RE_ASMD.sub(replace_asmd, sentence) + sentence = RE_POWER.sub(replace_power, sentence) + + sentence = RE_FRAC.sub(replace_frac, sentence) + sentence = RE_PERCENTAGE.sub(replace_percentage, sentence) + sentence = RE_MOBILE_PHONE.sub(replace_mobile, sentence) + + sentence = RE_TELEPHONE.sub(replace_phone, sentence) + sentence = 
RE_NATIONAL_UNIFORM_NUMBER.sub(replace_phone, sentence) + + sentence = RE_RANGE.sub(replace_range, sentence) + + sentence = RE_INTEGER.sub(replace_negative_num, sentence) + sentence = RE_DECIMAL_NUM.sub(replace_number, sentence) + sentence = RE_POSITIVE_QUANTIFIERS.sub(replace_positive_quantifier, sentence) + sentence = RE_DEFAULT_NUM.sub(replace_default_num, sentence) + sentence = RE_NUMBER.sub(replace_number, sentence) + sentence = self._post_replace(sentence) + + return sentence + + def normalize(self, text: str) -> List[str]: + sentences = self._split(text) + sentences = [self.normalize_sentence(sent) for sent in sentences] + return sentences diff --git a/tools/__init__.py b/tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tools/asr/config.py b/tools/asr/config.py new file mode 100644 index 0000000000000000000000000000000000000000..4b0d37ae64a583cf80b5c244339c3f20b27d1acd --- /dev/null +++ b/tools/asr/config.py @@ -0,0 +1,33 @@ +import os + +def check_fw_local_models(): + ''' + 启动时检查本地是否有 Faster Whisper 模型. + ''' + model_size_list = [ + "tiny", "tiny.en", + "base", "base.en", + "small", "small.en", + "medium", "medium.en", + "large", "large-v1", + "large-v2", "large-v3"] + for i, size in enumerate(model_size_list): + if os.path.exists(f'tools/asr/models/faster-whisper-{size}'): + model_size_list[i] = size + '-local' + return model_size_list + +asr_dict = { + "达摩 ASR (中文)": { + 'lang': ['zh','yue'], + 'size': ['large'], + 'path': 'funasr_asr.py', + 'precision': ['float32'] + }, + "Faster Whisper (多语种)": { + 'lang': ['auto', 'zh', 'en', 'ja', 'ko', 'yue'], + 'size': check_fw_local_models(), + 'path': 'fasterwhisper_asr.py', + 'precision': ['float32', 'float16', 'int8'] + }, +} + diff --git a/tools/asr/fasterwhisper_asr.py b/tools/asr/fasterwhisper_asr.py new file mode 100644 index 0000000000000000000000000000000000000000..da8eadfb10c3fe6e917c25c018703d460aae1564 --- /dev/null +++ b/tools/asr/fasterwhisper_asr.py @@ -0,0 +1,114 @@ +import argparse +import os +import traceback + +os.environ["HF_ENDPOINT"] = "https://hf-mirror.com" +os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" + +import torch +from faster_whisper import WhisperModel +from tqdm import tqdm + +from tools.asr.config import check_fw_local_models + +language_code_list = [ + "af", "am", "ar", "as", "az", + "ba", "be", "bg", "bn", "bo", + "br", "bs", "ca", "cs", "cy", + "da", "de", "el", "en", "es", + "et", "eu", "fa", "fi", "fo", + "fr", "gl", "gu", "ha", "haw", + "he", "hi", "hr", "ht", "hu", + "hy", "id", "is", "it", "ja", + "jw", "ka", "kk", "km", "kn", + "ko", "la", "lb", "ln", "lo", + "lt", "lv", "mg", "mi", "mk", + "ml", "mn", "mr", "ms", "mt", + "my", "ne", "nl", "nn", "no", + "oc", "pa", "pl", "ps", "pt", + "ro", "ru", "sa", "sd", "si", + "sk", "sl", "sn", "so", "sq", + "sr", "su", "sv", "sw", "ta", + "te", "tg", "th", "tk", "tl", + "tr", "tt", "uk", "ur", "uz", + "vi", "yi", "yo", "zh", "yue", + "auto"] + +def execute_asr(input_folder, output_folder, model_size, language, precision): + if '-local' in model_size: + model_size = model_size[:-6] + model_path = f'tools/asr/models/faster-whisper-{model_size}' + else: + model_path = model_size + if language == 'auto': + language = None #不设置语种由模型自动输出概率最高的语种 + print("loading faster whisper model:",model_size,model_path) + device = 'cuda' if torch.cuda.is_available() else 'cpu' + try: + model = WhisperModel(model_path, device=device, compute_type=precision) + except: + return 
print(traceback.format_exc()) + + input_file_names = os.listdir(input_folder) + input_file_names.sort() + + output = [] + output_file_name = os.path.basename(input_folder) + + for file_name in tqdm(input_file_names): + try: + file_path = os.path.join(input_folder, file_name) + segments, info = model.transcribe( + audio = file_path, + beam_size = 5, + vad_filter = True, + vad_parameters = dict(min_silence_duration_ms=700), + language = language) + text = '' + + if info.language == "zh": + print("检测为中文文本, 转 FunASR 处理") + if("only_asr"not in globals()): + from tools.asr.funasr_asr import \ + only_asr # #如果用英文就不需要导入下载模型 + text = only_asr(file_path) + + if text == '': + for segment in segments: + text += segment.text + output.append(f"{file_path}|{output_file_name}|{info.language.upper()}|{text}") + except: + print(traceback.format_exc()) + + output_folder = output_folder or "output/asr_opt" + os.makedirs(output_folder, exist_ok=True) + output_file_path = os.path.abspath(f'{output_folder}/{output_file_name}.list') + + with open(output_file_path, "w", encoding="utf-8") as f: + f.write("\n".join(output)) + print(f"ASR 任务完成->标注文件路径: {output_file_path}\n") + return output_file_path + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("-i", "--input_folder", type=str, required=True, + help="Path to the folder containing WAV files.") + parser.add_argument("-o", "--output_folder", type=str, required=True, + help="Output folder to store transcriptions.") + parser.add_argument("-s", "--model_size", type=str, default='large-v3', + choices=check_fw_local_models(), + help="Model Size of Faster Whisper") + parser.add_argument("-l", "--language", type=str, default='ja', + choices=language_code_list, + help="Language of the audio files.") + parser.add_argument("-p", "--precision", type=str, default='float16', choices=['float16','float32','int8'], + help="fp16, int8 or fp32") + + cmd = parser.parse_args() + output_file_path = execute_asr( + input_folder = cmd.input_folder, + output_folder = cmd.output_folder, + model_size = cmd.model_size, + language = cmd.language, + precision = cmd.precision, + ) diff --git a/tools/asr/funasr_asr.py b/tools/asr/funasr_asr.py new file mode 100644 index 0000000000000000000000000000000000000000..11209ada3de924ca9afc2f84974b620ae0947adf --- /dev/null +++ b/tools/asr/funasr_asr.py @@ -0,0 +1,91 @@ +# -*- coding:utf-8 -*- + +import argparse +import os +import traceback +from tqdm import tqdm +# from funasr.utils import version_checker +# version_checker.check_for_update = lambda: None +from funasr import AutoModel + + +def only_asr(input_file): + try: + text = model.generate(input=input_file)[0]["text"] + except: + text = '' + print(traceback.format_exc()) + return text + +def execute_asr(input_folder, output_folder, model_size, language): + input_file_names = os.listdir(input_folder) + input_file_names.sort() + + output = [] + output_file_name = os.path.basename(input_folder) + + for file_name in tqdm(input_file_names): + try: + print(file_name) + file_path = os.path.join(input_folder, file_name) + text = model.generate(input=file_path)[0]["text"] + output.append(f"{file_path}|{output_file_name}|{language.upper()}|{text}") + except: + print(traceback.format_exc()) + + output_folder = output_folder or "output/asr_opt" + os.makedirs(output_folder, exist_ok=True) + output_file_path = os.path.abspath(f'{output_folder}/{output_file_name}.list') + + with open(output_file_path, "w", encoding="utf-8") as f: + f.write("\n".join(output)) + print(f"ASR 
任务完成->标注文件路径: {output_file_path}\n") + return output_file_path + + +parser = argparse.ArgumentParser() +parser.add_argument("-i", "--input_folder", type=str, required=True, + help="Path to the folder containing WAV files.") +parser.add_argument("-o", "--output_folder", type=str, required=True, + help="Output folder to store transcriptions.") +parser.add_argument("-s", "--model_size", type=str, default='large', + help="Model Size of FunASR is Large") +parser.add_argument("-l", "--language", type=str, default='zh', choices=['zh','yue','auto'], + help="Language of the audio files.") +parser.add_argument("-p", "--precision", type=str, default='float16', choices=['float16','float32'], + help="fp16 or fp32")#还没接入 + +cmd = parser.parse_args() + +path_vad = 'tools/asr/models/speech_fsmn_vad_zh-cn-16k-common-pytorch' +path_punc = 'tools/asr/models/punc_ct-transformer_zh-cn-common-vocab272727-pytorch' +path_vad = path_vad if os.path.exists(path_vad) else "iic/speech_fsmn_vad_zh-cn-16k-common-pytorch" +path_punc = path_punc if os.path.exists(path_punc) else "iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch" +vad_model_revision=punc_model_revision="v2.0.4" + +if(cmd.language=="zh"): + path_asr = 'tools/asr/models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' + path_asr = path_asr if os.path.exists(path_asr) else "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" + model_revision="v2.0.4" +else: + path_asr = 'tools/asr/models/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-online' + path_asr = path_asr if os.path.exists(path_asr) else "iic/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-online" + model_revision="master" + path_vad=path_punc=vad_model_revision=punc_model_revision=None###友情提示:粤语带VAD识别可能会有少量shape不对报错的,但是不带VAD可以.不带vad只能分阶段单独加标点。不过标点模型对粤语效果真的不行… + +model = AutoModel( + model=path_asr, + model_revision=model_revision, + vad_model=path_vad, + vad_model_revision=vad_model_revision, + punc_model=path_punc, + punc_model_revision=punc_model_revision, +) + +if __name__ == '__main__': + execute_asr( + input_folder = cmd.input_folder, + output_folder = cmd.output_folder, + model_size = cmd.model_size, + language = cmd.language, + ) diff --git a/tools/asr/models/.gitignore b/tools/asr/models/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..c96a04f008ee21e260b28f7701595ed59e2839e3 --- /dev/null +++ b/tools/asr/models/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore \ No newline at end of file diff --git a/tools/cmd-denoise.py b/tools/cmd-denoise.py new file mode 100644 index 0000000000000000000000000000000000000000..1fdcab6dc1c8a3727d69faa96349b889b0d76d6d --- /dev/null +++ b/tools/cmd-denoise.py @@ -0,0 +1,33 @@ +import os,argparse +import traceback + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from tqdm import tqdm + +path_denoise = 'tools/denoise-model/speech_frcrn_ans_cirm_16k' +path_denoise = path_denoise if os.path.exists(path_denoise) else "damo/speech_frcrn_ans_cirm_16k" +ans = pipeline(Tasks.acoustic_noise_suppression,model=path_denoise) +def execute_denoise(input_folder,output_folder): + os.makedirs(output_folder,exist_ok=True) + # print(input_folder) + # print(list(os.listdir(input_folder).sort())) + for name in tqdm(os.listdir(input_folder)): + try: + ans("%s/%s"%(input_folder,name),output_path='%s/%s'%(output_folder,name)) + except: + traceback.print_exc() + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + 
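# Illustrative CLI usage, not part of the original diff. The denoise tool above
# reads every WAV file in the input folder and writes a denoised copy with the
# same file name to the output folder; the ASR scripts (funasr_asr.py /
# fasterwhisper_asr.py) follow the same -i/-o convention. The folder paths below
# are assumptions for the example:
#   python tools/cmd-denoise.py -i output/slicer_opt -o output/denoise_opt
#   python tools/asr/fasterwhisper_asr.py -i output/denoise_opt -o output/asr_opt -s large-v3 -l ja -p float16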
parser.add_argument("-i", "--input_folder", type=str, required=True, + help="Path to the folder containing WAV files.") + parser.add_argument("-o", "--output_folder", type=str, required=True, + help="Output folder to store transcriptions.") + parser.add_argument("-p", "--precision", type=str, default='float16', choices=['float16','float32'], + help="fp16 or fp32")#还没接入 + cmd = parser.parse_args() + execute_denoise( + input_folder = cmd.input_folder, + output_folder = cmd.output_folder, + ) \ No newline at end of file diff --git a/tools/denoise-model/.gitignore b/tools/denoise-model/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..d6b7ef32c8478a48c3994dcadc86837f4371184d --- /dev/null +++ b/tools/denoise-model/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore diff --git a/tools/i18n/__pycache__/i18n.cpython-39.pyc b/tools/i18n/__pycache__/i18n.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1c99b5a151f8c15ad2338946baf760c3b77daa6 Binary files /dev/null and b/tools/i18n/__pycache__/i18n.cpython-39.pyc differ diff --git a/tools/i18n/i18n.py b/tools/i18n/i18n.py new file mode 100644 index 0000000000000000000000000000000000000000..e256941a64b67cc97fa9c1aa9317eeb6389305ba --- /dev/null +++ b/tools/i18n/i18n.py @@ -0,0 +1,36 @@ +import json +import locale +import os + +I18N_JSON_DIR : os.PathLike = os.path.join(os.path.dirname(os.path.relpath(__file__)), 'locale') + +def load_language_list(language): + with open(os.path.join(I18N_JSON_DIR, f"{language}.json"), "r", encoding="utf-8") as f: + language_list = json.load(f) + return language_list + +def scan_language_list(): + language_list = [] + for name in os.listdir(I18N_JSON_DIR): + if name.endswith(".json"):language_list.append(name.split('.')[0]) + return language_list + +class I18nAuto: + def __init__(self, language=None): + if language in ["Auto", None]: + language = locale.getdefaultlocale()[0] + # getlocale can't identify the system's language ((None, None)) + if not os.path.exists(os.path.join(I18N_JSON_DIR, f"{language}.json")): + language = "en_US" + self.language = language + self.language_map = load_language_list(language) + + def __call__(self, key): + return self.language_map.get(key, key) + + def __repr__(self): + return "Use Language: " + self.language + +if __name__ == "__main__": + i18n = I18nAuto(language='en_US') + print(i18n) \ No newline at end of file diff --git a/tools/i18n/locale/en_US.json b/tools/i18n/locale/en_US.json new file mode 100644 index 0000000000000000000000000000000000000000..267d715773e56e34e72b1277d2da799757315a9a --- /dev/null +++ b/tools/i18n/locale/en_US.json @@ -0,0 +1,180 @@ +{ + "(1)MDX-Net(onnx_dereverb):对于双通道混响是最好的选择,不能去除单通道混响;": "(1)MDX-Net(onnx_dereverb): Best choice for dual-channel reverberation, cannot remove single-channel reverberation;", + "(234)DeEcho:去除延迟效果。Aggressive比Normal去除得更彻底,DeReverb额外去除混响,可去除单声道混响,但是对高频重的板式混响去不干净。": "(234)DeEcho: Removes delay effects. Aggressive mode removes more thoroughly than Normal mode. 
DeReverb additionally removes reverberation, can remove mono reverberation, but does not clean heavily high-frequency plate reverberation.", + "*GPT模型列表": "*GPT models list", + "*SoVITS模型列表": "*SoVITS models list", + "*实验/模型名": "*Experiment/model name", + "*文本标注文件": "*Text labelling file", + "*训练集音频文件目录": "*Audio dataset folder", + "*请上传并填写参考信息": "*Please upload and fill reference information", + "*请填写需要合成的目标文本和语种模式": "*Please fill in the target text and language mode for synthesis", + ".list标注文件的路径": ".list annotation file path", + "0-前置数据集获取工具": "0-Fetch dataset", + "0a-UVR5人声伴奏分离&去混响去延迟工具": "0a-UVR5 webui (for vocal separation, deecho, dereverb and denoise)", + "0b-语音切分工具": "0b-Audio slicer", + "0bb-语音降噪工具": "0bb-Voice denoiser", + "0c-中文批量离线ASR工具": "0c-Chinese ASR tool", + "0d-语音文本校对标注工具": "0d-Speech to text proofreading tool", + "1-GPT-SoVITS-TTS": "1-GPT-SOVITS-TTS", + "1A-训练集格式化工具": "1A-Dataset formatting", + "1Aa-文本内容": "1Aa-Text", + "1Aabc-训练集格式化一键三连": "1Aabc-One-click formatting", + "1Ab-SSL自监督特征提取": "1Ab-SSL self-supervised feature extraction", + "1Ac-语义token提取": "1Ac-semantics token extraction", + "1B-微调训练": "1B-Fine-tuned training", + "1Ba-SoVITS训练。用于分享的模型文件输出在SoVITS_weights下。": "1Ba-SoVITS training. The model is located in SoVITS_weights.", + "1Bb-GPT训练。用于分享的模型文件输出在GPT_weights下。": "1Bb-GPT training. The model is located in GPT_weights.", + "1C-推理": "1C-inference", + "1、DeEcho-DeReverb模型的耗时是另外2个DeEcho模型的接近2倍;": "1. The DeEcho-DeReverb model's processing time is nearly twice that of the other two DeEcho models.", + "1、保留人声:不带和声的音频选这个,对主人声保留比HP5更好。内置HP2和HP3两个模型,HP3可能轻微漏伴奏但对主人声保留比HP2稍微好一丁点;": "1. Preserve Vocals: Choose this option for audio without harmonies, as it better retains the main vocal compared to the HP5 model. This option includes two built-in models, HP2 and HP3. HP3 may slightly let through some accompaniment but retains the main vocal slightly better than HP2.", + "2-GPT-SoVITS-变声": "2-GPT-SoVITS-Voice Changer", + "2、MDX-Net-Dereverb模型挺慢的;": "2、MDX-Net-Dereverb Model is slow;", + "2、仅保留主人声:带和声的音频选这个,对主人声可能有削弱。内置HP5一个模型;": "2. Keep Only Main Vocal: Choose this option for audio with harmonies, as it may slightly reduce the main vocal. Includes one built-in HP5 model;", + "3、个人推荐的最干净的配置是先MDX-Net再DeEcho-Aggressive。": "3. Personal Recommendation for the cleanest configuration: First use MDX-Net followed by DeEcho-Aggressive", + "3、去混响、去延迟模型(by FoxJoy):": "3. Reverberation and delay removal model(by FoxJoy):", + "ASR 模型": "ASR model", + "ASR 模型尺寸": "ASR model size", + "ASR 语言设置": "ASR language", + "ASR进程输出信息": "ASR output log", + "GPT模型列表": "GPT weight list", + "GPT训练进程输出信息": "GPT training output log", + "GPT采样参数(无参考文本时不要太低。不懂就用默认):": "GPT sampling parameters (not too low when there's no reference text. 
Use default if unsure):", + "GPU卡号,只能填1个整数": "GPU number, can only input ONE integer", + "GPU卡号以-分割,每个卡号一个进程": "GPU number is separated by -, each GPU will run one process ", + "SSL进程输出信息": "SSL output log", + "SoVITS模型列表": "SoVITS weight list", + "SoVITS训练进程输出信息": "SoVITS training output log", + "TTS推理WebUI进程输出信息": "TTS inference webui output log", + "TTS推理进程已关闭": "TTS inference process closed", + "TTS推理进程已开启": "TTS inference process is opened", + "UVR5已关闭": "UVR5 closed", + "UVR5已开启": "UVR5 opened ", + "UVR5进程输出信息": "UVR5 process output log", + "alpha_mix:混多少比例归一化后音频进来": "alpha_mix: proportion of normalized audio merged into dataset", + "hop_size:怎么算音量曲线,越小精度越大计算量越高(不是精度越大效果越好)": "hop_size: FO hop size, the smaller the value, the higher the accuracy)", + "max:归一化后最大值多少": "Loudness multiplier after normalized", + "max_sil_kept:切完后静音最多留多长": "Maximum length for silence to be kept", + "min_interval:最短切割间隔": "Minumum interval for audio cutting", + "min_length:每段最小多长,如果第一段太短一直和后面段连起来直到超过这个值": "min_length: the minimum length of each segment. If the first segment is too short, it will be concatenated with the next segment until it exceeds this value", + "temperature": "temperature", + "threshold:音量小于这个值视作静音的备选切割点": "Noise gate threshold (loudness below this value will be treated as noise", + "top_k": "top_k", + "top_p": "top_p", + "一键三连进程输出信息": "One-click formatting output", + "不切": "No slice", + "中文": "Chinese", + "中文教程文档:https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e": "Chinese Tutorial:https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e", + "中英混合": "Chinese-English Mixed", + "也可批量输入音频文件, 二选一, 优先读文件夹": "Multiple audio files can also be imported. If a folder path exists, this input is ignored.", + "人声伴奏分离批量处理, 使用UVR5模型。": "Batch processing for vocal and instrumental separation, using the UVR5 model.", + "人声提取激进程度": "Vocal extraction aggressiveness", + "以下文件或文件夹不存在:": "No Such File or Folder:", + "以下模型不存在:": "No Such Model:", + "伴奏人声分离&去混响&去回声": "Vocals/Accompaniment Separation & Reverberation Removal", + "使用无参考文本模式时建议使用微调的GPT,听不清参考音频说的啥(不晓得写啥)可以开。
开启后无视填写的参考文本。": "When using the no-reference text mode, it is recommended to use a fine-tuned GPT. If the reference audio is unclear and you don't know what to write, you can enable this feature, which will ignore the reference text you've entered.", + "保存频率save_every_epoch": "Save frequency (save_every_epoch):", + "凑50字一切": "Slice per 50 characters", + "凑四句一切": "Slice once every 4 sentences", + "切分后的子音频的输出根目录": "Audio slicer output folder", + "切割使用的进程数": "CPU threads used for audio slicing", + "刷新模型路径": "refreshing model paths", + "前端处理后的文本(每句):": "Processed text from the frontend (per sentence):", + "去混响/去延迟,附:": "Dereverberation/Delay Removal, including:", + "参考音频在3~10秒范围外,请更换!": "Reference audio is outside the 3-10 second range, please choose another one!", + "参考音频的文本": "Text for reference audio", + "参考音频的语种": "Language for reference audio", + "合成语音": "Start inference", + "合格的文件夹路径格式举例: E:\\codes\\py39\\vits_vc_gpu\\白鹭霜华测试样例(去文件管理器地址栏拷就行了)。": "An example of a valid folder path format: E:\\codes\\py39\\vits_vc_gpu\\白鹭霜华测试样例 (simply copy the address from the file manager's address bar).", + "填切割后音频所在目录!读取的音频文件完整路径=该目录-拼接-list文件里波形对应的文件名(不是全路径)。如果留空则使用.list文件里的绝对全路径。": "Please fill in the segmented audio files' directory! The full path of the audio file = the directory concatenated with the filename corresponding to the waveform in the list file (not the full path). If left blank, the absolute full path in the .list file will be used.", + "多语种混合": "Multilingual Mixed", + "多语种混合(粤语)": "Multilingual Mixed(Yue)", + "实际输入的参考文本:": "Actual Input Reference Text:", + "实际输入的目标文本(切句后):": "Actual Input Target Text (after sentence segmentation):", + "实际输入的目标文本(每句):": "Actual Input Target Text (per sentence):", + "实际输入的目标文本:": "Actual Input Target Text:", + "导出文件格式": "Export file format", + "开启GPT训练": "Start GPT training", + "开启SSL提取": "Start SSL extracting", + "开启SoVITS训练": "Start SoVITS training", + "开启一键三连": "Start one-click formatting", + "开启文本获取": "Start speech-to-text", + "开启无参考文本模式。不填参考文本亦相当于开启。": "Enable no reference mode. 
If you don't fill 'Text for reference audio', no reference mode will be enabled.", + "开启离线批量ASR": "Start batch ASR", + "开启语义token提取": "Start semantics token extraction", + "开启语音切割": "Start audio slicer", + "开启语音降噪": "Start voice denoiser", + "怎么切": "How to slice the sentence", + "总训练轮数total_epoch": "Total training epochs (total_epoch):", + "总训练轮数total_epoch,不建议太高": "Total epochs, do not increase to a value that is too high", + "打标工具WebUI已关闭": "proofreading tool webui is closed", + "打标工具WebUI已开启": "proofreading tool webui is opened", + "打标工具进程输出信息": "Proofreading tool output log", + "指定输出主人声文件夹": "Specify the output folder for vocals:", + "指定输出非主人声文件夹": "Specify the output folder for accompaniment:", + "按中文句号。切": "Slice by Chinese punct", + "按标点符号切": "Slice by every punct", + "按英文句号.切": "Slice by English punct", + "数据类型精度": "Computing precision", + "文本模块学习率权重": "Text model learning rate weighting", + "文本进程输出信息": "Text processing output", + "施工中,请静候佳音": "In construction, please wait", + "日文": "Japanese", + "日英混合": "Japanese-English Mixed", + "是否仅保存最新的ckpt文件以节省硬盘空间": "Save only the latest '.ckpt' file to save disk space:", + "是否在每次保存时间点将最终小模型保存至weights文件夹": "Save a small final model to the 'weights' folder at each save point:", + "是否开启TTS推理WebUI": "Open TTS inference WebUI", + "是否开启UVR5-WebUI": "Open UVR5-WebUI", + "是否开启dpo训练选项(实验性)": "Enable DPO training (experimental feature)", + "是否开启打标WebUI": "Open labelling WebUI", + "是否直接对上次合成结果调整语速和音色。防止随机性。": "Whether to directly adjust the speech rate and timebre of the last synthesis result to prevent randomness.", + "显卡信息": "GPU Information", + ".限制范围越小判别效果越好。": "The smaller the range, the better the performance.", + "可选项:通过拖拽多个文件上传多个参考音频(建议同性),平均融合他们的音色。如不填写此项,音色由左侧单个参考音频控制。": "Optional: Upload multiple reference audio files by dragging and dropping them (same gender recommended), and average their timebre. If this field is left blank, the timebre will be controlled by the single reference audio file on the left.", + "本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.
如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.": "This software is open source under the MIT license. The author does not have any control over the software. Users who use the software and distribute the sounds exported by the software are solely responsible.
If you do not agree with this clause, you cannot use or reference any codes and files within the software package. See the root directory Agreement-LICENSE for details.", + "模型": "Model", + "模型分为三类:": "Models are categorized into three types:", + "模型切换": "Model switch", + "每张显卡的batch_size": "Batch size per GPU:", + "版本": "Version", + "粤英混合": "Yue-English Mixed", + "粤语": "Yue", + "终止ASR进程": "Stop ASR task", + "终止GPT训练": "Stop GPT training", + "终止SSL提取进程": "Stop SSL extraction", + "终止SoVITS训练": "Stop SoVITS training", + "终止一键三连": "Stop one-click formatting", + "终止文本获取进程": "Stop speech-to-text", + "终止语义token提取进程": "Stop semantics token extraction", + "终止语音切割": "Stop audio cutting", + "终止语音降噪进程": "Stop voice denoising", + "英文": "English", + "语义token提取进程输出信息": "Sematics token extraction output log", + "语速": "Speech rate", + "语速调整,高为更快": "Adjust speech rate, higher for faster", + "语音切割进程输出信息": "Audio slicer output log", + "语音降噪进程输出信息": "Voice Denoiser Process Output Information", + "请上传3~10秒内参考音频,超过会报错!": "Please upload a reference audio within the 3-10 second range; if it exceeds this duration, it will raise errors.", + "请上传参考音频": "Please Upload the Reference Audio", + "请填入推理文本": "Please Fill in the Terget Text", + "请输入有效文本": "Please enter valid text.", + "转换": "Convert", + "输入待处理音频文件夹路径": "Enter the path of the audio folder to be processed:", + "输入文件夹路径": "Input folder path", + "输出logs/实验名目录下应有23456开头的文件和文件夹": "output folder (logs/{experiment name}) should have files and folders starts with 23456.", + "输出信息": "Output information", + "输出文件夹路径": "Output folder path", + "输出的语音": "Inference Result", + "选择训练完存放在SoVITS_weights和GPT_weights下的模型。默认的一个是底模,体验5秒Zero Shot TTS用。": "Choose the models from SoVITS_weights and GPT_weights. The default one is a pretrain, so you can experience zero shot TTS.", + "降噪结果输出文件夹": "Denoised Results Output Folder", + "降噪音频文件输入文件夹": "Denoising Audio File Input Folder", + "需要合成的文本": "Inference text", + "需要合成的语种": "Inference text language", + "韩文": "Korean", + "韩英混合": "Korean-English Mixed", + "音频自动切分输入路径,可文件可文件夹": "Audio slicer input (file or folder)", + "预训练的GPT模型路径": "Pretrained GPT model path", + "预训练的SSL模型路径": "Pretrained SSL model path", + "预训练的SoVITS-D模型路径": "Pretrained SoVITS-D model path", + "预训练的SoVITS-G模型路径": "Pretrained SoVITS-G model path", + "预训练的中文BERT模型路径": " Pretrained BERT model path" +} diff --git a/tools/i18n/locale/es_ES.json b/tools/i18n/locale/es_ES.json new file mode 100644 index 0000000000000000000000000000000000000000..7d3e64c84f83a7f397ddd529a2abf00ef6debb95 --- /dev/null +++ b/tools/i18n/locale/es_ES.json @@ -0,0 +1,178 @@ +{ + "(1)MDX-Net(onnx_dereverb):对于双通道混响是最好的选择,不能去除单通道混响;": "(1)MDX-Net (onnx_dereverb): reverberación estéreo, la mejor opción; no puede eliminar reverberación mono", + "(234)DeEcho:去除延迟效果。Aggressive比Normal去除得更彻底,DeReverb额外去除混响,可去除单声道混响,但是对高频重的板式混响去不干净。": "(234)DeEcho: Eliminar el efecto de retardo. 
Aggressive elimina más que Normal, DeReverb elimina reverberación adicional, puede eliminar reverberación mono, pero no limpia bien la reverberación de placa de alta frecuencia", + "*GPT模型列表": "*Lista de modelos GPT", + "*SoVITS模型列表": "*Lista de modelos SoVITS", + "*实验/模型名": "*Nombre del experimento/modelo", + "*文本标注文件": "*Archivo de etiquetado de texto", + "*训练集音频文件目录": "*Directorio de archivos de audio de entrenamiento", + "*请上传并填写参考信息": "*Por favor, suba y complete la información de referencia", + "*请填写需要合成的目标文本和语种模式": "*Por favor, complete el texto objetivo a sintetizar y el modo de idioma", + ".list标注文件的路径": "Ruta del archivo de anotación .list", + "0-前置数据集获取工具": "0-Herramienta de obtención de conjunto de datos previo", + "0a-UVR5人声伴奏分离&去混响去延迟工具": "0a-Herramienta de separación de voz y acompañamiento UVR5 y eliminación de reverberación y retardo", + "0b-语音切分工具": "0b-Herramienta de división de voz", + "0bb-语音降噪工具": "0bb-Herramienta de reducción de ruido de voz", + "0c-中文批量离线ASR工具": "0c-Herramienta de ASR en lote fuera de línea en chino", + "0d-语音文本校对标注工具": "0d-Herramienta de corrección y etiquetado de texto de voz", + "1-GPT-SoVITS-TTS": "1-GPT-SoVITS-TTS", + "1A-训练集格式化工具": "1A-Herramienta de formateo del conjunto de datos de entrenamiento", + "1Aa-文本内容": "1Aa-Contenido del texto", + "1Aabc-训练集格式化一键三连": "1Aabc-Formateo del conjunto de datos de entrenamiento en un solo paso", + "1Ab-SSL自监督特征提取": "1Ab-Extracción de características auto-supervisada SSL", + "1Ac-语义token提取": "1Ac-Extracción de tokens semánticos", + "1B-微调训练": "1B-Entrenamiento de ajuste fino", + "1Ba-SoVITS训练。用于分享的模型文件输出在SoVITS_weights下。": "1Ba-Entrenamiento de SoVITS. Los archivos de modelo para compartir se encuentran en SoVITS_weights.", + "1Bb-GPT训练。用于分享的模型文件输出在GPT_weights下。": "1Bb-Entrenamiento de GPT. Los archivos de modelo para compartir se encuentran en GPT_weights.", + "1C-推理": "1C-Inferencia", + "1、DeEcho-DeReverb模型的耗时是另外2个DeEcho模型的接近2倍;": "1. El modelo DeEcho-DeReverb tarda casi el doble que los otros dos modelos DeEcho", + "1、保留人声:不带和声的音频选这个,对主人声保留比HP5更好。内置HP2和HP3两个模型,HP3可能轻微漏伴奏但对主人声保留比HP2稍微好一丁点;": "1. Retener voz principal: seleccione este para audio sin coros, retiene mejor la voz principal que HP5. Incluye dos modelos, HP2 y HP3; HP3 puede filtrar ligeramente el acompañamiento pero retiene mejor la voz principal que HP2", + "2-GPT-SoVITS-变声": "2-GPT-SoVITS-Cambio de voz", + "2、MDX-Net-Dereverb模型挺慢的;": "2. El modelo MDX-Net-Dereverb es bastante lento", + "2、仅保留主人声:带和声的音频选这个,对主人声可能有削弱。内置HP5一个模型;": "2. Solo retener voz principal: seleccione este para audio con coros, puede debilitar la voz principal. Incluye un modelo HP5", + "3、个人推荐的最干净的配置是先MDX-Net再DeEcho-Aggressive。": "3. La configuración más limpia recomendada es primero MDX-Net, luego DeEcho-Aggressive", + "3、去混响、去延迟模型(by FoxJoy):": "3. Modelos de eliminación de reverberación y retardo (por FoxJoy)", + "ASR 模型": "Modelo ASR", + "ASR 模型尺寸": "Tamaño del modelo ASR", + "ASR 语言设置": "Configuración del idioma ASR", + "ASR进程输出信息": "Información de salida del proceso ASR", + "GPT模型列表": "Lista de modelos GPT", + "GPT训练进程输出信息": "Información de salida del proceso de entrenamiento de GPT", + "GPT采样参数(无参考文本时不要太低。不懂就用默认):": "Parámetros de muestreo de GPT (no demasiado bajos cuando no hay texto de referencia. 
Use los valores por defecto si no está seguro):", + "GPU卡号,只能填1个整数": "Número de tarjeta GPU, solo se puede ingresar un número entero", + "GPU卡号以-分割,每个卡号一个进程": "Número de tarjeta GPU separado por '-', cada número de tarjeta es un proceso", + "SSL进程输出信息": "Información de salida del proceso SSL", + "SoVITS模型列表": "Lista de modelos SoVITS", + "SoVITS训练进程输出信息": "Información de salida del proceso de entrenamiento de SoVITS", + "TTS推理WebUI进程输出信息": "Información de salida del proceso de interfaz web de inferencia TTS", + "TTS推理进程已关闭": "Proceso de inferencia TTS cerrado", + "TTS推理进程已开启": "Proceso de inferencia TTS iniciado", + "UVR5已关闭": "UVR5 está deshabilitado", + "UVR5已开启": "UVR5 está habilitado", + "UVR5进程输出信息": "Información de salida del proceso UVR5", + "alpha_mix:混多少比例归一化后音频进来": "alpha_mix: proporción de mezcla de audio normalizado que entra", + "hop_size:怎么算音量曲线,越小精度越大计算量越高(不是精度越大效果越好)": "hop_size: cómo calcular la curva de volumen, cuanto más pequeño, mayor precisión pero mayor carga computacional (mayor precisión no significa mejor rendimiento)", + "max:归一化后最大值多少": "max: valor máximo después de la normalización", + "max_sil_kept:切完后静音最多留多长": "max_sil_kept: duración máxima del silencio después del corte", + "min_interval:最短切割间隔": "min_interval: intervalo mínimo de corte", + "min_length:每段最小多长,如果第一段太短一直和后面段连起来直到超过这个值": "min_length: longitud mínima de cada segmento; si el primer segmento es demasiado corto, se une al siguiente hasta superar este valor", + "temperature": "temperatura", + "threshold:音量小于这个值视作静音的备选切割点": "umbral: puntos de corte alternativos considerados como silencio si el volumen es menor que este valor", + "top_k": "top_k", + "top_p": "top_p", + "一键三连进程输出信息": "Información de salida del proceso de triple acción", + "不切": "No cortar", + "中文": "Chino", + "中文教程文档:https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e": "Documentación del tutorial en chino: https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e", + "中英混合": "Chino e inglés mezclados", + "也可批量输入音频文件, 二选一, 优先读文件夹": "También se pueden ingresar archivos de audio por lotes, seleccionar uno, prioridad para leer carpetas", + "人声伴奏分离批量处理, 使用UVR5模型。": "Procesamiento por lotes de separación de voz y acompañamiento utilizando el modelo UVR5", + "人声提取激进程度": "Nivel de agresividad en la extracción de voz", + "以下文件或文件夹不存在:": "No Existe Tal Archivo o Carpeta:", + "以下模型不存在:": "No Existe tal Modelo:", + "伴奏人声分离&去混响&去回声": "Separación de acompañamiento y voz principal y eliminación de reverberación y eco", + "使用无参考文本模式时建议使用微调的GPT,听不清参考音频说的啥(不晓得写啥)可以开。
开启后无视填写的参考文本。": "Se recomienda usar un GPT ajustado en modo sin texto de referencia; habilítelo si no puede entender el audio de referencia (si no sabe qué escribir). Una vez habilitado, ignorará el texto de referencia ingresado.", + "保存频率save_every_epoch": "Frecuencia de guardado (cada epoch)", + "凑50字一切": "Todo para alcanzar las 50 palabras", + "凑四句一切": "Completa cuatro oraciones para rellenar todo", + "切分后的子音频的输出根目录": "Directorio raíz de salida de los sub-audios después de la división", + "切割使用的进程数": "Número de procesos utilizados para la división", + "刷新模型路径": "Actualizar la ruta del modelo", + "前端处理后的文本(每句):": "Texto después del procesamiento previo (por frase):", + "去混响/去延迟,附:": "Eliminación de reverberación/retardo, incluye:", + "参考音频在3~10秒范围外,请更换!": "El audio de referencia está fuera del rango de 3 a 10 segundos, ¡por favor cámbielo!", + "参考音频的文本": "Texto de referencia del audio", + "参考音频的语种": "Idioma del audio de referencia", + "合成语音": "Síntesis de voz", + "合格的文件夹路径格式举例: E:\\codes\\py39\\vits_vc_gpu\\白鹭霜华测试样例(去文件管理器地址栏拷就行了)。": "Ejemplo de formato de ruta de carpeta válida: E:\\codes\\py39\\vits_vc_gpu\\白鹭霜华测试样例 (simplemente copie desde la barra de direcciones del administrador de archivos).", + "填切割后音频所在目录!读取的音频文件完整路径=该目录-拼接-list文件里波形对应的文件名(不是全路径)。如果留空则使用.list文件里的绝对全路径。": "Ingrese el directorio donde se encuentran los audios después de la división. La ruta completa de los archivos de audio leídos = este directorio + nombre de archivo correspondiente en el archivo .list (no la ruta completa). Si se deja en blanco, se utilizará la ruta completa del archivo .list.", + "多语种混合": "Mezcla de varios idiomas", + "多语种混合(粤语)": "Mezcla Multilingüe (Cantonés)", + "实际输入的参考文本:": "Texto de referencia realmente ingresado:", + "实际输入的目标文本(切句后):": "Texto objetivo realmente ingresado (después de dividir en frases):", + "实际输入的目标文本(每句):": "Texto objetivo realmente ingresado (por frase):", + "实际输入的目标文本:": "Texto objetivo realmente ingresado:", + "导出文件格式": "Formato de archivo de exportación", + "开启GPT训练": "Iniciar entrenamiento de GPT", + "开启SSL提取": "Habilitar la extracción SSL", + "开启SoVITS训练": "Iniciar entrenamiento de SoVITS", + "开启一键三连": "Habilitar un solo paso de formateo", + "开启文本获取": "Habilitar la obtención de texto", + "开启无参考文本模式。不填参考文本亦相当于开启。": "Habilitar el modo sin texto de referencia. 
No llenar el texto de referencia también lo habilita.", + "开启离线批量ASR": "Habilitar ASR en lote fuera de línea", + "开启语义token提取": "Habilitar la extracción de tokens semánticos", + "开启语音切割": "Habilitar la división de voz", + "开启语音降噪": "Habilitar la reducción de ruido de voz", + "怎么切": "Cómo cortar", + "总训练轮数total_epoch": "Número total de épocas de entrenamiento", + "总训练轮数total_epoch,不建议太高": "Número total de épocas de entrenamiento, no se recomienda demasiado alto", + "打标工具WebUI已关闭": "Interfaz web de la herramienta de etiquetado cerrada", + "打标工具WebUI已开启": "Interfaz web de la herramienta de etiquetado iniciada", + "打标工具进程输出信息": "Información de salida del proceso de la herramienta de etiquetado", + "指定输出主人声文件夹": "Especificar carpeta de salida de voz principal", + "指定输出非主人声文件夹": "Especificar carpeta de salida de no voz principal", + "按中文句号。切": "Cortar según puntos en chino", + "按标点符号切": "Cortar según los signos de puntuación", + "按英文句号.切": "Cortar por puntos en inglés.", + "数据类型精度": "precisión del tipo de datos", + "文本模块学习率权重": "Peso de la tasa de aprendizaje del módulo de texto", + "文本进程输出信息": "Información de salida del proceso de obtención de texto", + "施工中,请静候佳音": "En construcción, por favor espere pacientemente", + "日文": "Japonés", + "日英混合": "Mezcla de japonés e inglés", + "是否仅保存最新的ckpt文件以节省硬盘空间": "¿Guardar solo el último archivo ckpt para ahorrar espacio en disco?", + "是否在每次保存时间点将最终小模型保存至weights文件夹": "¿Guardar el modelo final pequeño en la carpeta de pesos en cada punto de guardado?", + "是否开启TTS推理WebUI": "¿Habilitar la interfaz web de inferencia TTS?", + "是否开启UVR5-WebUI": "¿Habilitar UVR5-WebUI?", + "是否开启dpo训练选项(实验性)": "¿Habilitar la opción de entrenamiento dpo (experimental)?", + "是否开启打标WebUI": "¿Habilitar la interfaz web de etiquetado?", + "是否直接对上次合成结果调整语速。防止随机性。": "¿Si se ajusta directamente la velocidad de habla del último resultado de síntesis para evitar aleatoriedad?", + "显卡信息": "Información de la tarjeta gráfica", + "本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.
如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.": "Este software es de código abierto bajo la licencia MIT. El autor no tiene control sobre el software. El usuario que lo utilice o distribuya, y el que genere sonidos a partir del software, asume toda la responsabilidad.
Si no acepta estos términos, no puede utilizar ni hacer referencia a ningún código o archivo dentro del paquete de software. Consulte el archivo LICENSE en el directorio raíz para obtener más detalles.", + "模型": "Modelo", + "模型分为三类:": "Los modelos se dividen en tres categorías:", + "模型切换": "Cambio de modelo", + "每张显卡的batch_size": "Tamaño de lote por tarjeta gráfica", + "版本": "Versión", + "粤英混合": "Mezcla Cantonés-Inglés", + "粤语": "Cantonés", + "终止ASR进程": "Terminar el proceso ASR", + "终止GPT训练": "Detener entrenamiento de GPT", + "终止SSL提取进程": "Terminar el proceso de extracción SSL", + "终止SoVITS训练": "Detener entrenamiento de SoVITS", + "终止一键三连": "Terminar el proceso de un solo paso de formateo", + "终止文本获取进程": "Terminar el proceso de obtención de texto", + "终止语义token提取进程": "Terminar el proceso de extracción de tokens semánticos", + "终止语音切割": "Terminar la división de voz", + "终止语音降噪进程": "Terminar el proceso de reducción de ruido de voz", + "英文": "Inglés", + "语义token提取进程输出信息": "Información de salida del proceso de extracción de tokens semánticos", + "语速": "Velocidad de habla", + "语速调整,高为更快": "Ajustar la velocidad de habla, más alta para más rápido", + "语音切割进程输出信息": "Información de salida del proceso de división de voz", + "语音降噪进程输出信息": "Información de salida del proceso de reducción de ruido de voz", + "请上传3~10秒内参考音频,超过会报错!": "Por favor, suba un audio de referencia de entre 3 y 10 segundos, ¡más de eso causará un error!", + "请上传参考音频": "Por Favor, Suba el Audio de Referencia", + "请填入推理文本": "Por Favor, Ingrese el Texto Objetivo", + "请输入有效文本": "Por favor, introduzca un texto válido", + "转换": "Convertir", + "输入待处理音频文件夹路径": "Ingrese la ruta de la carpeta de audio a procesar", + "输入文件夹路径": "Ingrese la ruta de la carpeta", + "输出logs/实验名目录下应有23456开头的文件和文件夹": "Debe haber archivos y carpetas que comiencen con 23456 en el directorio logs/nombre del experimento", + "输出信息": "Información de salida", + "输出文件夹路径": "Ruta de la carpeta de salida", + "输出的语音": "Audio de salida", + "选择训练完存放在SoVITS_weights和GPT_weights下的模型。默认的一个是底模,体验5秒Zero Shot TTS用。": "Seleccione el modelo almacenado en SoVITS_weights y GPT_weights después del entrenamiento. Uno de ellos es el modelo base, útil para experimentar con TTS de 5 segundos sin entrenamiento.", + "降噪结果输出文件夹": "Carpeta de salida de los resultados de reducción de ruido", + "降噪音频文件输入文件夹": "Carpeta de entrada de archivos de audio para reducción de ruido", + "需要合成的文本": "Texto a sintetizar", + "需要合成的语种": "Idioma para la síntesis", + "韩文": "Coreano", + "韩英混合": "Mezcla Coreano-Inglés", + "音频自动切分输入路径,可文件可文件夹": "Ruta de entrada para la división automática de audio, puede ser un archivo o una carpeta", + "预训练的GPT模型路径": "Ruta del modelo GPT preentrenado", + "预训练的SSL模型路径": "Ruta del modelo SSL preentrenado", + "预训练的SoVITS-D模型路径": "Ruta del modelo SoVITS-D preentrenado", + "预训练的SoVITS-G模型路径": "Ruta del modelo SoVITS-G preentrenado", + "预训练的中文BERT模型路径": "Ruta del modelo BERT en chino preentrenado" +} diff --git a/tools/i18n/locale/fr_FR.json b/tools/i18n/locale/fr_FR.json new file mode 100644 index 0000000000000000000000000000000000000000..8763a25c58b5d93195d8d321600afdf08cb6ac0c --- /dev/null +++ b/tools/i18n/locale/fr_FR.json @@ -0,0 +1,178 @@ +{ + "(1)MDX-Net(onnx_dereverb):对于双通道混响是最好的选择,不能去除单通道混响;": "(1) MDX-Net (onnx_dereverb) : C'est le meilleur choix pour la réverbération à deux canaux, mais il ne peut pas éliminer la réverbération à un seul canal;", + "(234)DeEcho:去除延迟效果。Aggressive比Normal去除得更彻底,DeReverb额外去除混响,可去除单声道混响,但是对高频重的板式混响去不干净。": "(234)DeEcho : Supprime les effets de délai. 
Aggressive est plus exhaustif que Normal dans la suppression, DeReverb élimine également la réverbération, peut supprimer la réverbération monocanal, mais n'élimine pas complètement la réverbération de plaque à haute fréquence.", + "*GPT模型列表": "*Liste des modèles GPT", + "*SoVITS模型列表": "*Liste des modèles SoVITS", + "*实验/模型名": "*Nom de l'expérience/modèle", + "*文本标注文件": "*Fichier d'annotation de texte", + "*训练集音频文件目录": "*Répertoire des fichiers audio d'entraînement", + "*请上传并填写参考信息": "*Veuillez télécharger et remplir les informations de référence", + "*请填写需要合成的目标文本和语种模式": "*Veuillez saisir le texte cible à synthétiser et le mode de langue.", + ".list标注文件的路径": "Chemin du fichier d'annotation .list", + "0-前置数据集获取工具": "0-Outil de récupération de jeu de données préalable", + "0a-UVR5人声伴奏分离&去混响去延迟工具": "0a-Outil de séparation de la voix humaine et de l'accompagnement UVR5 & suppression de la réverbération et du retard", + "0b-语音切分工具": "0b-Outil de découpage vocal", + "0bb-语音降噪工具": "0bb-Outil de réduction du bruit vocal", + "0c-中文批量离线ASR工具": "0c-Outil chinois de transcription automatique hors ligne en masse", + "0d-语音文本校对标注工具": "0d-Outil de correction et d'annotation de texte vocal", + "1-GPT-SoVITS-TTS": "1-GPT-SoVITS-TTS", + "1A-训练集格式化工具": "1A-Outil de formatage du jeu de données d'entraînement", + "1Aa-文本内容": "1Aa-Contenu du texte", + "1Aabc-训练集格式化一键三连": "1Aabc-Formatage en un clic du jeu de données d'entraînement", + "1Ab-SSL自监督特征提取": "1Ab-Extraction de caractéristiques auto-supervisée SSL", + "1Ac-语义token提取": "1Ac-Extraction de jetons sémantiques", + "1B-微调训练": "1B-Entraînement fin", + "1Ba-SoVITS训练。用于分享的模型文件输出在SoVITS_weights下。": "1Ba-Entraînement SoVITS. Les fichiers de modèle destinés au partage sont enregistrés sous SoVITS_weights.", + "1Bb-GPT训练。用于分享的模型文件输出在GPT_weights下。": "1Bb-Entraînement GPT. Les fichiers de modèle destinés au partage sont enregistrés sous GPT_weights.", + "1C-推理": "1C-Inférence", + "1、DeEcho-DeReverb模型的耗时是另外2个DeEcho模型的接近2倍;": "1. Le temps de traitement du modèle DeEcho-DeReverb est presque le double de celui des deux autres modèles DeEcho;", + "1、保留人声:不带和声的音频选这个,对主人声保留比HP5更好。内置HP2和HP3两个模型,HP3可能轻微漏伴奏但对主人声保留比HP2稍微好一丁点;": "1. Préserver les voix : Choisissez cette option pour les audio sans harmonie, car elle conserve mieux la voix principale par rapport au modèle HP5. Deux modèles intégrés, HP2 et HP3, sont disponibles. HP3 peut légèrement laisser passer l'accompagnement mais conserve la voix principale un peu mieux que HP2;", + "2-GPT-SoVITS-变声": "2-GPT-SoVITS-Modification de la voix", + "2、MDX-Net-Dereverb模型挺慢的;": "2. Le modèle MDX-Net-Dereverb est assez lent;", + "2、仅保留主人声:带和声的音频选这个,对主人声可能有削弱。内置HP5一个模型;": "2. Conserver uniquement la voix principale : Choisissez cette option pour les audio avec harmonie, car elle peut affaiblir la voix principale. Un modèle HP5 intégré est disponible;", + "3、个人推荐的最干净的配置是先MDX-Net再DeEcho-Aggressive。": "3. La configuration la plus propre que je recommande est d'utiliser d'abord MDX-Net, puis DeEcho-Aggressive.", + "3、去混响、去延迟模型(by FoxJoy):": "3. Modèle de suppression de réverbération et de retard (par FoxJoy) :", + "ASR 模型": "Modèle ASR", + "ASR 模型尺寸": "Taille du modèle ASR", + "ASR 语言设置": "Paramètres de langue ASR", + "ASR进程输出信息": "Informations de processus ASR", + "GPT模型列表": "Liste des modèles GPT", + "GPT训练进程输出信息": "Informations de processus d'entraînement GPT", + "GPT采样参数(无参考文本时不要太低。不懂就用默认):": "Paramètres d'échantillonnage de GPT (ne pas mettre trop bas lorsqu'il n'y a pas de texte de référence. 
Utilisez les valeurs par défaut si vous n'êtes pas sûr):", + "GPU卡号,只能填1个整数": "Numéro de carte GPU, ne peut contenir qu'un seul entier", + "GPU卡号以-分割,每个卡号一个进程": "Numéro de carte GPU séparé par des tirets, un processus par numéro de carte", + "SSL进程输出信息": "Informations de processus SSL", + "SoVITS模型列表": "Liste des modèles SoVITS", + "SoVITS训练进程输出信息": "Informations de processus d'entraînement SoVITS", + "TTS推理WebUI进程输出信息": "Informations de processus de l'interface Web d'inférence TTS", + "TTS推理进程已关闭": "Le processus d'inférence TTS est terminé", + "TTS推理进程已开启": "Le processus d'inférence TTS est en cours", + "UVR5已关闭": "UVR5 est désactivé", + "UVR5已开启": "UVR5 est activé", + "UVR5进程输出信息": "Informations de processus UVR5", + "alpha_mix:混多少比例归一化后音频进来": "alpha_mix: proportion d'audio normalisé mélangé", + "hop_size:怎么算音量曲线,越小精度越大计算量越高(不是精度越大效果越好)": "hop_size: comment calculer la courbe de volume, plus petit pour une précision plus élevée mais une charge de calcul plus élevée (ce n'est pas une meilleure précision)", + "max:归一化后最大值多少": "max: valeur maximale après normalisation", + "max_sil_kept:切完后静音最多留多长": "max_sil_kept: durée maximale de silence après la coupe", + "min_interval:最短切割间隔": "min_interval: intervalle de coupe minimum", + "min_length:每段最小多长,如果第一段太短一直和后面段连起来直到超过这个值": "min_length:longueur minimale de chaque segment ; si le premier segment est trop court, il est concaténé avec les segments suivants jusqu'à ce que la longueur dépasse cette valeur", + "temperature": "température", + "threshold:音量小于这个值视作静音的备选切割点": "seuil: le volume inférieur à cette valeur est considéré comme un point de coupe silencieux alternatif", + "top_k": "top_k", + "top_p": "top_p", + "一键三连进程输出信息": "Informations de processus de l'un clic trois connexions", + "不切": "Pas de découpe", + "中文": "Chinois", + "中文教程文档:https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e": "Documentation du tutoriel en chinois:https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e", + "中英混合": "Mélange de chinois et d'anglais", + "也可批量输入音频文件, 二选一, 优先读文件夹": "Également possible d'entrer en lot des fichiers audio, au choix, privilégiez la lecture du dossier", + "人声伴奏分离批量处理, 使用UVR5模型。": "Traitement par lot de séparation voix-accompagnement en utilisant le modèle UVR5.", + "人声提取激进程度": "Degré d'extraction des voix", + "以下文件或文件夹不存在:": "Aucun fichier ou dossier de ce type:", + "以下模型不存在:": "Aucun Modèle de ce Type:", + "伴奏人声分离&去混响&去回声": "Séparation de la voix et de l'accompagnement, suppression de la réverbération et de l'écho", + "使用无参考文本模式时建议使用微调的GPT,听不清参考音频说的啥(不晓得写啥)可以开。
开启后无视填写的参考文本。": "Il est recommandé d'utiliser GPT finement ajusté en mode sans texte de référence. Si vous ne comprenez pas ce que dit l'audio de référence (vous ne savez pas quoi écrire), vous pouvez l'activer ; une fois activé, ignorez le texte de référence saisi.", + "保存频率save_every_epoch": "Fréquence de sauvegarde (sauvegarder à chaque époque)", + "凑50字一切": "Assembler 50 mots tout", + "凑四句一切": "Composez quatre phrases pour tout remplir", + "切分后的子音频的输出根目录": "Répertoire racine de sortie des sous-audios après découpage", + "切割使用的进程数": "Nombre de processus utilisés pour le découpage", + "刷新模型路径": "Actualiser le chemin du modèle", + "前端处理后的文本(每句):": "Texte après traitement frontal (par phrase):", + "去混响/去延迟,附:": "Suppression de la réverbération / suppression du retard, ci-joint:", + "参考音频在3~10秒范围外,请更换!": "Veuillez remplacer l'audio de référence si sa durée est en dehors de la plage de 3 à 10 secondes!", + "参考音频的文本": "Texte de l'audio de référence", + "参考音频的语种": "Langue de l'audio de référence", + "合成语音": "Synthèse vocale", + "合格的文件夹路径格式举例: E:\\codes\\py39\\vits_vc_gpu\\白鹭霜华测试样例(去文件管理器地址栏拷就行了)。": "Exemple de format de chemin de dossier valide : E:\\codes\\py39\\vits_vc_gpu\\白鹭霜华测试样例 (copiez-le depuis la barre d'adresse de l'explorateur de fichiers).", + "填切割后音频所在目录!读取的音频文件完整路径=该目录-拼接-list文件里波形对应的文件名(不是全路径)。如果留空则使用.list文件里的绝对全路径。": "Veuillez indiquer le répertoire contenant les audio découpés ! Le chemin complet du fichier audio à lire = ce répertoire - nom du fichier correspondant à l'onde dans le fichier .list (pas le chemin complet). Si laissé vide, le chemin absolu dans le fichier .list sera utilisé.", + "多语种混合": "Mélange multilingue", + "多语种混合(粤语)": "Mélange Multilingue (Cantonais)", + "实际输入的参考文本:": "Texte de référence réellement saisi:", + "实际输入的目标文本(切句后):": "Texte cible réellement saisi (après découpage):", + "实际输入的目标文本(每句):": "Texte cible réellement saisi (par phrase):", + "实际输入的目标文本:": "Texte cible réellement saisi:", + "导出文件格式": "Format d'exportation du fichier", + "开启GPT训练": "Activer l'entraînement GPT", + "开启SSL提取": "Activer l'extraction SSL", + "开启SoVITS训练": "Activer l'entraînement SoVITS", + "开启一键三连": "Activer l'un clic trois connexions", + "开启文本获取": "Activer l'extraction de texte", + "开启无参考文本模式。不填参考文本亦相当于开启。": "Activer le mode sans texte de référence. 
Laisser le texte de référence vide équivaut également à activer le mode.", + "开启离线批量ASR": "Activer la transcription automatique hors ligne en masse", + "开启语义token提取": "Activer l'extraction de jetons sémantiques", + "开启语音切割": "Activer le découpage vocal", + "开启语音降噪": "Activer la réduction de bruit vocal", + "怎么切": "Comment découper", + "总训练轮数total_epoch": "Nombre total d'époques d'entraînement", + "总训练轮数total_epoch,不建议太高": "Nombre total d'époques d'entraînement, pas recommandé d'être trop élevé", + "打标工具WebUI已关闭": "L'interface Web de l'outil d'annotation est terminée", + "打标工具WebUI已开启": "L'interface Web de l'outil d'annotation est en cours", + "打标工具进程输出信息": "Informations de processus de l'outil d'annotation", + "指定输出主人声文件夹": "Spécifier le dossier de sortie pour la voix principale", + "指定输出非主人声文件夹": "Spécifier le dossier de sortie pour la non-voix principale", + "按中文句号。切": "Couper selon les points en chinois.", + "按标点符号切": "Couper selon les signes de ponctuation", + "按英文句号.切": "Découpez par des points en anglais", + "数据类型精度": "précision du type de données", + "文本模块学习率权重": "Poids du taux d'apprentissage du module de texte", + "文本进程输出信息": "Informations de processus de texte", + "施工中,请静候佳音": "En construction, veuillez attendre patiemment", + "日文": "Japonais", + "日英混合": "Mélange Japonais-Anglais", + "是否仅保存最新的ckpt文件以节省硬盘空间": "Sauvegarder uniquement le dernier fichier ckpt pour économiser de l'espace disque", + "是否在每次保存时间点将最终小模型保存至weights文件夹": "Sauvegarder le petit modèle final dans le dossier weights à chaque point de sauvegarde", + "是否开启TTS推理WebUI": "Activer l'interface Web d'inférence TTS", + "是否开启UVR5-WebUI": "Activer UVR5-WebUI", + "是否开启dpo训练选项(实验性)": "Activer l'option d'entraînement DPO (expérimental)", + "是否开启打标WebUI": "Activer l'interface Web d'annotation", + "是否直接对上次合成结果调整语速。防止随机性。": "Est-ce qu'on ajuste directement la vitesse de parole du dernier résultat de synthèse pour éviter l'aléatoire ?", + "显卡信息": "Informations sur la carte graphique", + "本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.
如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.": "Ce logiciel est open source sous la licence MIT. L'auteur n'a aucun contrôle sur le logiciel. Les utilisateurs et les diffuseurs du son exporté par le logiciel en assument l'entière responsabilité.
Si vous n'acceptez pas ces termes, vous ne pouvez ni utiliser ni citer aucun code ou fichier à l'intérieur du package. Voir LICENSE dans le répertoire racine pour plus de détails.", + "模型": "Modèle", + "模型分为三类:": "Les modèles sont classés en trois catégories:", + "模型切换": "Changement de modèle", + "每张显卡的batch_size": "Taille de lot par carte graphique", + "版本": "Version", + "粤英混合": "Mélange Cantonais-Anglais", + "粤语": "Cantonais", + "终止ASR进程": "Arrêter le processus ASR", + "终止GPT训练": "Arrêter l'entraînement GPT", + "终止SSL提取进程": "Arrêter le processus d'extraction SSL", + "终止SoVITS训练": "Arrêter l'entraînement SoVITS", + "终止一键三连": "Arrêter l'un clic trois connexions", + "终止文本获取进程": "Arrêter le processus d'extraction de texte", + "终止语义token提取进程": "Arrêter le processus d'extraction de jetons sémantiques", + "终止语音切割": "Arrêter le découpage vocal", + "终止语音降噪进程": "Arrêter le processus de réduction du bruit vocal", + "英文": "Anglais", + "语义token提取进程输出信息": "Informations de processus d'extraction de jetons sémantiques", + "语速": "Débit de parole", + "语速调整,高为更快": "Ajuster la vitesse de parole, plus élevée pour plus rapide", + "语音切割进程输出信息": "Informations de processus de découpage vocal", + "语音降噪进程输出信息": "Informations de sortie du processus de réduction du bruit vocal", + "请上传3~10秒内参考音频,超过会报错!": "Veuillez télécharger une référence audio de 3 à 10 secondes ; les fichiers plus longs généreront une erreur!", + "请上传参考音频": "Veuillez télécharger l'audio de référence", + "请填入推理文本": "Veuillez remplir le texte cible", + "请输入有效文本": "Veuillez entrer un texte valide", + "转换": "Conversion", + "输入待处理音频文件夹路径": "Entrez le chemin du dossier audio à traiter", + "输入文件夹路径": "Chemin du dossier à entrer", + "输出logs/实验名目录下应有23456开头的文件和文件夹": "Les fichiers et dossiers commençant par 23456 devraient être présents dans le répertoire logs/nom de l'expérience", + "输出信息": "Sortie d'information", + "输出文件夹路径": "Chemin du dossier de sortie", + "输出的语音": "Audio de sortie", + "选择训练完存放在SoVITS_weights和GPT_weights下的模型。默认的一个是底模,体验5秒Zero Shot TTS用。": "Choisissez le modèle entraîné stocké sous SoVITS_weights et GPT_weights. Par défaut, l'un d'eux est un modèle de base pour l'expérience de TTS Zero Shot de 5 secondes.", + "降噪结果输出文件夹": "Dossier de sortie des résultats de réduction du bruit", + "降噪音频文件输入文件夹": "Dossier d'entrée des fichiers audio de réduction du bruit", + "需要合成的文本": "Texte à synthétiser", + "需要合成的语种": "Langue de synthèse requise", + "韩文": "Coréen", + "韩英混合": "Mélange Coréen-Anglais", + "音频自动切分输入路径,可文件可文件夹": "Chemin d'entrée automatique de découpage audio, peut être un fichier ou un dossier", + "预训练的GPT模型路径": "Chemin du modèle GPT pré-entraîné", + "预训练的SSL模型路径": "Chemin du modèle SSL pré-entraîné", + "预训练的SoVITS-D模型路径": "Chemin du modèle SoVITS-D pré-entraîné", + "预训练的SoVITS-G模型路径": "Chemin du modèle SoVITS-G pré-entraîné", + "预训练的中文BERT模型路径": "Chemin du modèle BERT chinois pré-entraîné" +} diff --git a/tools/i18n/locale/it_IT.json b/tools/i18n/locale/it_IT.json new file mode 100644 index 0000000000000000000000000000000000000000..5203e44d5d7d92f8d6990fb5460e08f5e23f1c16 --- /dev/null +++ b/tools/i18n/locale/it_IT.json @@ -0,0 +1,178 @@ +{ + "(1)MDX-Net(onnx_dereverb):对于双通道混响是最好的选择,不能去除单通道混响;": "(1)MDX-Net (onnx_dereverb): È la scelta migliore per la riverberazione a due canali, ma non può rimuovere la riverberazione a canale singolo;", + "(234)DeEcho:去除延迟效果。Aggressive比Normal去除得更彻底,DeReverb额外去除混响,可去除单声道混响,但是对高频重的板式混响去不干净。": "(234)DeEcho: Rimuove gli effetti di ritardo. 
Aggressive è più completo di Normal nella rimozione, DeReverb rimuove ulteriormente la riverberazione, può rimuovere la riverberazione a canale singolo, ma non rimuove completamente la riverberazione a piastra ad alta frequenza.", + "*GPT模型列表": "*Lista dei modelli GPT", + "*SoVITS模型列表": "*Lista dei modelli SoVITS", + "*实验/模型名": "*Nome dell'esperimento/modello", + "*文本标注文件": "*File di annotazione del testo", + "*训练集音频文件目录": "*Directory dei file audio del set di addestramento", + "*请上传并填写参考信息": "*Carica e compila le informazioni di riferimento", + "*请填写需要合成的目标文本和语种模式": "*Si prega di inserire il testo di destinazione da sintetizzare e la modalità lingua", + ".list标注文件的路径": "Percorso del file di annotazione .list", + "0-前置数据集获取工具": "0-Strumento di acquisizione del dataset preliminare", + "0a-UVR5人声伴奏分离&去混响去延迟工具": "0a-Strumento di separazione voce e accompagnamento UVR5 & Rimozione riverbero e ritardo", + "0b-语音切分工具": "0b-Strumento di segmentazione vocale", + "0bb-语音降噪工具": "0bb-Strumento di riduzione del rumore vocale", + "0c-中文批量离线ASR工具": "0c-Strumento di ASR offline batch in cinese", + "0d-语音文本校对标注工具": "0d-Strumento di correzione e annotazione testo vocale", + "1-GPT-SoVITS-TTS": "1-GPT-SoVITS-TTS", + "1A-训练集格式化工具": "1A-Strumento di formattazione del set di addestramento", + "1Aa-文本内容": "1Aa-Contenuto del testo", + "1Aabc-训练集格式化一键三连": "1Aabc-Strumento di formattazione del set di addestramento con tre passaggi", + "1Ab-SSL自监督特征提取": "1Ab-Estrazione di caratteristiche auto-supervisionata SSL", + "1Ac-语义token提取": "1Ac-Estrazione del token semantico", + "1B-微调训练": "1B-Allenamento di affinamento", + "1Ba-SoVITS训练。用于分享的模型文件输出在SoVITS_weights下。": "1Ba-Allenamento di SoVITS. I file del modello destinati alla condivisione sono salvati in SoVITS_weights.", + "1Bb-GPT训练。用于分享的模型文件输出在GPT_weights下。": "1Bb-Allenamento di GPT. I file del modello destinati alla condivisione sono salvati in GPT_weights.", + "1C-推理": "1C-Inferenza", + "1、DeEcho-DeReverb模型的耗时是另外2个DeEcho模型的接近2倍;": "1. Il tempo di elaborazione del modello DeEcho-DeReverb è quasi il doppio di quello degli altri due modelli DeEcho;", + "1、保留人声:不带和声的音频选这个,对主人声保留比HP5更好。内置HP2和HP3两个模型,HP3可能轻微漏伴奏但对主人声保留比HP2稍微好一丁点;": "1. Conserva la voce principale: scegli questa opzione per audio senza armonie, poiché conserva meglio la voce principale rispetto al modello HP5. Include due modelli integrati, HP2 e HP3. HP3 potrebbe far passare leggermente l'accompagnamento ma conserva meglio la voce principale rispetto a HP2;", + "2-GPT-SoVITS-变声": "2-GPT-SoVITS-Voce modificata", + "2、MDX-Net-Dereverb模型挺慢的;": "2. Il modello MDX-Net-Dereverb è piuttosto lento;", + "2、仅保留主人声:带和声的音频选这个,对主人声可能有削弱。内置HP5一个模型;": "2. Solo conserva la voce principale: scegli questa opzione per audio con armonie, poiché potrebbe indebolire la voce principale. Include un modello HP5;", + "3、个人推荐的最干净的配置是先MDX-Net再DeEcho-Aggressive。": "3. La configurazione più pulita consigliata è MDX-Net seguito da DeEcho-Aggressive.", + "3、去混响、去延迟模型(by FoxJoy):": "3. Modello per rimuovere la riverberazione e il ritardo (by FoxJoy):", + "ASR 模型": "Modello ASR", + "ASR 模型尺寸": "Dimensioni del modello ASR", + "ASR 语言设置": "Impostazioni linguistiche ASR", + "ASR进程输出信息": "Informazioni sull'output del processo ASR", + "GPT模型列表": "Elenco dei modelli GPT", + "GPT训练进程输出信息": "Informazioni sull'output del processo di allenamento di GPT", + "GPT采样参数(无参考文本时不要太低。不懂就用默认):": "Parametri di campionamento di GPT (non troppo bassi quando non c'è testo di riferimento. 
Utilizzare i valori predefiniti in caso di incertezza):", + "GPU卡号,只能填1个整数": "Numero della scheda grafica, può essere inserito solo un numero intero", + "GPU卡号以-分割,每个卡号一个进程": "Numero di GPU separati da '-'; ogni numero corrisponde a un processo", + "SSL进程输出信息": "Informazioni sull'output del processo SSL", + "SoVITS模型列表": "Elenco dei modelli SoVITS", + "SoVITS训练进程输出信息": "Informazioni sull'output del processo di allenamento di SoVITS", + "TTS推理WebUI进程输出信息": "Informazioni sull'output del processo dell'interfaccia utente Web per l'inferenza TTS", + "TTS推理进程已关闭": "Il processo di inferenza TTS è stato chiuso", + "TTS推理进程已开启": "Il processo di inferenza TTS è stato avviato", + "UVR5已关闭": "UVR5 è disattivato", + "UVR5已开启": "UVR5 è attivato", + "UVR5进程输出信息": "Informazioni sull'output del processo UVR5", + "alpha_mix:混多少比例归一化后音频进来": "alpha_mix: Quanta proporzione dell'audio normalizzato deve essere miscelata", + "hop_size:怎么算音量曲线,越小精度越大计算量越高(不是精度越大效果越好)": "hop_size: Come calcolare la curva del volume. Più piccolo è, maggiore è la precisione ma aumenta la complessità computazionale (non significa che una maggiore precisione dà risultati migliori)", + "max:归一化后最大值多少": "max: Massimo valore dopo la normalizzazione", + "max_sil_kept:切完后静音最多留多长": "max_sil_kept: Massima durata del silenzio dopo il taglio", + "min_interval:最短切割间隔": "min_interval: Intervallo minimo di taglio", + "min_length:每段最小多长,如果第一段太短一直和后面段连起来直到超过这个值": "min_length: Lunghezza minima per segmento; se il primo segmento è troppo corto, sarà unito ai segmenti successivi fino a superare questo valore", + "temperature": "temperatura", + "threshold:音量小于这个值视作静音的备选切割点": "threshold: Punto di taglio alternativo considerato silenzioso se il volume è inferiore a questo valore", + "top_k": "top_k", + "top_p": "top_p", + "一键三连进程输出信息": "Informazioni sull'output del processo di 'One Click Three Connect'", + "不切": "Nessuna suddivisione", + "中文": "Cinese", + "中文教程文档:https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e": "Documentazione del tutorial in cinese:https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e", + "中英混合": "Cinese e inglese misti", + "也可批量输入音频文件, 二选一, 优先读文件夹": "È possibile anche inserire file audio in batch, una delle due opzioni, con priorità alla lettura della cartella", + "人声伴奏分离批量处理, 使用UVR5模型。": "Separazione voce-accompagnamento in batch, utilizza il modello UVR5.", + "人声提取激进程度": "Grado di aggressività dell'estrazione vocale", + "以下文件或文件夹不存在:": "Nessun file o cartella trovati:", + "以下模型不存在:": "Nessun Modello del Genere:", + "伴奏人声分离&去混响&去回声": "Separazione tra accompagnamento e voce & Rimozione del riverbero & Rimozione dell'eco", + "使用无参考文本模式时建议使用微调的GPT,听不清参考音频说的啥(不晓得写啥)可以开。
开启后无视填写的参考文本。": "Si consiglia di utilizzare GPT fine-tuned quando si utilizza la modalità senza testo di riferimento. Se non si riesce a capire cosa dice l'audio di riferimento (e non si sa cosa scrivere), è possibile abilitare questa opzione, ignorando il testo di riferimento inserito.", + "保存频率save_every_epoch": "Frequenza di salvataggio ogni epoca", + "凑50字一切": "Riempire con 50 caratteri per tutto", + "凑四句一切": "Riempire con quattro frasi per tutto", + "切分后的子音频的输出根目录": "Directory radice di output per gli audio segmentati", + "切割使用的进程数": "Numero di processi utilizzati per il taglio", + "刷新模型路径": "Aggiorna il percorso del modello", + "前端处理后的文本(每句):": "Testo elaborato dal front-end (per frase):", + "去混响/去延迟,附:": "Rimozione della riverberazione/ritardo, allegato:", + "参考音频在3~10秒范围外,请更换!": "L'audio di riferimento è al di fuori dell'intervallo di 3-10 secondi. Si prega di cambiarlo!", + "参考音频的文本": "Testo dell'audio di riferimento", + "参考音频的语种": "Lingua dell'audio di riferimento", + "合成语音": "Sintesi vocale", + "合格的文件夹路径格式举例: E:\\codes\\py39\\vits_vc_gpu\\白鹭霜华测试样例(去文件管理器地址栏拷就行了)。": "Formato di percorso della cartella valido: E:\\codes\\py39\\vits_vc_gpu\\Esempio di test di BaiLuShuangHua (copiare direttamente dalla barra degli indirizzi del gestore file).", + "填切割后音频所在目录!读取的音频文件完整路径=该目录-拼接-list文件里波形对应的文件名(不是全路径)。如果留空则使用.list文件里的绝对全路径。": "Inserisci la directory dell'audio segmentato! Il percorso completo del file audio letto = questa directory - unione del nome del file corrispondente alle forme d'onda nel file .list (non il percorso completo). Se lasciato vuoto, verrà utilizzato il percorso assoluto nel file .list.", + "多语种混合": "Mix multilingue", + "多语种混合(粤语)": "Misto Multilingue (Cantonese)", + "实际输入的参考文本:": "Testo di riferimento effettivamente inserito:", + "实际输入的目标文本(切句后):": "Testo di destinazione effettivamente inserito (dopo il taglio delle frasi):", + "实际输入的目标文本(每句):": "Testo di destinazione effettivamente inserito (per frase):", + "实际输入的目标文本:": "Testo di destinazione effettivamente inserito:", + "导出文件格式": "Formato di esportazione del file", + "开启GPT训练": "Attivare l'allenamento di GPT", + "开启SSL提取": "Attivare l'estrazione SSL", + "开启SoVITS训练": "Attivare l'allenamento di SoVITS", + "开启一键三连": "Attivare la formattazione con tre passaggi", + "开启文本获取": "Attivare l'estrazione del testo", + "开启无参考文本模式。不填参考文本亦相当于开启。": "Attivare la modalità senza testo di riferimento. 
Anche se non inserisci un testo di riferimento, la modalità verrà attivata.", + "开启离线批量ASR": "Attivare ASR offline batch", + "开启语义token提取": "Attivare l'estrazione del token semantico", + "开启语音切割": "Attivare la segmentazione vocale", + "开启语音降噪": "Attivare la riduzione del rumore vocale", + "怎么切": "Come tagliare", + "总训练轮数total_epoch": "Numero totale di epoche di addestramento", + "总训练轮数total_epoch,不建议太高": "Numero totale di epoche di addestramento, non raccomandato troppo alto", + "打标工具WebUI已关闭": "L'interfaccia utente Web dello strumento di annotazione è stata chiusa", + "打标工具WebUI已开启": "L'interfaccia utente Web dello strumento di annotazione è stata avviata", + "打标工具进程输出信息": "Informazioni sull'output del processo di annotazione", + "指定输出主人声文件夹": "Specifica la cartella di output per la voce principale", + "指定输出非主人声文件夹": "Specifica la cartella di output per la non voce principale", + "按中文句号。切": "Taglia secondo il punto cinese.", + "按标点符号切": "Taglia secondo i segni di punteggiatura", + "按英文句号.切": "Taglia secondo il punto inglese", + "数据类型精度": "precisione del tipo di dati", + "文本模块学习率权重": "Peso del tasso di apprendimento del modulo di testo", + "文本进程输出信息": "Informazioni sull'output del processo di estrazione del testo", + "施工中,请静候佳音": "In costruzione, attendi pazientemente le buone notizie", + "日文": "Giapponese", + "日英混合": "Mix giapponese e inglese", + "是否仅保存最新的ckpt文件以节省硬盘空间": "Salvare solo il file ckpt più recente per risparmiare spazio su disco", + "是否在每次保存时间点将最终小模型保存至weights文件夹": "Salvare il modello finale più piccolo nella cartella weights ad ogni punto di salvataggio", + "是否开启TTS推理WebUI": "Attivare l'interfaccia utente Web per l'inferenza TTS", + "是否开启UVR5-WebUI": "Attivare UVR5-WebUI", + "是否开启dpo训练选项(实验性)": "Attivare l'opzione di addestramento DPO (sperimentale)", + "是否开启打标WebUI": "Attivare l'interfaccia utente Web di annotazione", + "是否直接对上次合成结果调整语速。防止随机性。": "Se regolare direttamente la velocità della voce dell'ultimo risultato di sintesi per evitare casualità.", + "显卡信息": "Informazioni sulla scheda grafica", + "本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.
如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.": "Questo software è open source con licenza MIT. L'autore non ha alcun controllo sul software. L'utente che utilizza il software o diffonde i suoni derivati dal software ne è responsabile.
Se non accetti questi termini, non puoi utilizzare o citare alcun codice o file all'interno del pacchetto software. Vedi LICENSE nella cartella principale per i dettagli.", + "模型": "Modello", + "模型分为三类:": "I modelli sono divisi in tre categorie:", + "模型切换": "Cambio del modello", + "每张显卡的batch_size": "Batch size per ogni scheda grafica", + "版本": "Versione", + "粤英混合": "Misto Cantonese-Inglese", + "粤语": "Cantonese", + "终止ASR进程": "Terminare il processo ASR", + "终止GPT训练": "Terminare l'allenamento di GPT", + "终止SSL提取进程": "Terminare il processo di estrazione SSL", + "终止SoVITS训练": "Terminare l'allenamento di SoVITS", + "终止一键三连": "Terminare la formattazione con tre passaggi", + "终止文本获取进程": "Terminare il processo di estrazione del testo", + "终止语义token提取进程": "Terminare il processo di estrazione del token semantico", + "终止语音切割": "Terminare la segmentazione vocale", + "终止语音降噪进程": "Termina il processo di riduzione del rumore vocale", + "英文": "Inglese", + "语义token提取进程输出信息": "Informazioni sull'output del processo di estrazione del token semantico", + "语速": "Velocità della voce", + "语速调整,高为更快": "Regolare la velocità della voce, più alta per più veloce", + "语音切割进程输出信息": "Informazioni sull'output del processo di segmentazione vocale", + "语音降噪进程输出信息": "Informazioni sull'output del processo di riduzione del rumore vocale", + "请上传3~10秒内参考音频,超过会报错!": "Carica un audio di riferimento della durata compresa tra 3 e 10 secondi. Superiore a questo, verrà generato un errore!", + "请上传参考音频": "Si prega di caricare l'audio di riferimento", + "请填入推理文本": "Si prega di inserire il testo di destinazione", + "请输入有效文本": "Inserisci un testo valido", + "转换": "Converti", + "输入待处理音频文件夹路径": "Inserisci il percorso della cartella dei file audio da elaborare", + "输入文件夹路径": "Inserisci il percorso della cartella", + "输出logs/实验名目录下应有23456开头的文件和文件夹": "Nella cartella logs/nome dell'esperimento dovrebbero esserci file e cartelle che iniziano con 23456", + "输出信息": "Informazioni di output", + "输出文件夹路径": "Percorso della cartella di output", + "输出的语音": "Audio di output", + "选择训练完存放在SoVITS_weights和GPT_weights下的模型。默认的一个是底模,体验5秒Zero Shot TTS用。": "Scegli il modello salvato in SoVITS_weights e GPT_weights dopo l'addestramento. 
Uno di default è il modello di base, utilizzato per l'esperienza di Zero Shot TTS in 5 secondi.", + "降噪结果输出文件夹": "Cartella di output dei risultati di riduzione del rumore", + "降噪音频文件输入文件夹": "Cartella di input dei file audio per la riduzione del rumore", + "需要合成的文本": "Testo da sintetizzare", + "需要合成的语种": "Lingua da sintetizzare", + "韩文": "Coreano", + "韩英混合": "Misto Coreano-Inglese", + "音频自动切分输入路径,可文件可文件夹": "Percorso di input per la segmentazione automatica dell'audio, può essere un file o una cartella", + "预训练的GPT模型路径": "Percorso del modello preaddestrato GPT", + "预训练的SSL模型路径": "Percorso del modello SSL preaddestrato", + "预训练的SoVITS-D模型路径": "Percorso del modello preaddestrato SoVITS-D", + "预训练的SoVITS-G模型路径": "Percorso del modello preaddestrato SoVITS-G", + "预训练的中文BERT模型路径": "Percorso del modello BERT cinese preaddestrato" +} diff --git a/tools/i18n/locale/ja_JP.json b/tools/i18n/locale/ja_JP.json new file mode 100644 index 0000000000000000000000000000000000000000..719ae07b920b427725998d9ea7c6d64f75734ef3 --- /dev/null +++ b/tools/i18n/locale/ja_JP.json @@ -0,0 +1,178 @@ +{ + "(1)MDX-Net(onnx_dereverb):对于双通道混响是最好的选择,不能去除单通道混响;": "(1)MDX-Net(onnx_dereverb):二重チャンネルのリバーブに最適な選択ですが、単一チャンネルのリバーブは除去できません;", + "(234)DeEcho:去除延迟效果。Aggressive比Normal去除得更彻底,DeReverb额外去除混响,可去除单声道混响,但是对高频重的板式混响去不干净。": "(234)DeEcho:遅延効果を除去します。AggressiveはNormalよりも徹底的に除去し、DeReverbは追加でリバーブを除去し、モノラルリバーブを除去できますが、高周波数のプレートリバーブは完全には除去できません。", + "*GPT模型列表": "*GPTモデルリスト", + "*SoVITS模型列表": "*SoVITSモデルリスト", + "*实验/模型名": "*実験/モデル名", + "*文本标注文件": "*テキスト注釈ファイル", + "*训练集音频文件目录": "*トレーニングデータのオーディオファイルディレクトリ", + "*请上传并填写参考信息": "*参照情報をアップロードして記入してください", + "*请填写需要合成的目标文本和语种模式": "*合成対象テキストと言語モードを入力してください", + ".list标注文件的路径": ".listアノテーションファイルのパス", + "0-前置数据集获取工具": "0-データセット取得ツールの事前処理", + "0a-UVR5人声伴奏分离&去混响去延迟工具": "0a-UVR5ボーカルアカンパニメント分離&リバーブおよびディレイ除去ツール", + "0b-语音切分工具": "0b-音声分割ツール", + "0bb-语音降噪工具": "0bb-音声ノイズ除去ツール", + "0c-中文批量离线ASR工具": "0c-中国語バッチオフラインASRツール", + "0d-语音文本校对标注工具": "0d-音声テキストの校正アノテーションツール", + "1-GPT-SoVITS-TTS": "1-GPT-SoVITS-TTS", + "1A-训练集格式化工具": "1A-トレーニングデータのフォーマットツール", + "1Aa-文本内容": "1Aa-テキストの内容", + "1Aabc-训练集格式化一键三连": "1Aabc-トレーニングデータのフォーマットワンクリック三連", + "1Ab-SSL自监督特征提取": "1Ab-SSLセルフスーパーバイズ特徴抽出", + "1Ac-语义token提取": "1Ac-セマンティックトークン抽出", + "1B-微调训练": "1B-ファインチューニングトレーニング", + "1Ba-SoVITS训练。用于分享的模型文件输出在SoVITS_weights下。": "1Ba-SoVITSトレーニング。共有用のモデルファイルはSoVITS_weightsディレクトリに出力されます。", + "1Bb-GPT训练。用于分享的模型文件输出在GPT_weights下。": "1Bb-GPTトレーニング。共有用のモデルファイルはGPT_weightsディレクトリに出力されます。", + "1C-推理": "1C-推論", + "1、DeEcho-DeReverb模型的耗时是另外2个DeEcho模型的接近2倍;": "1、DeEcho-DeReverbモデルの処理時間は、他の2つのDeEchoモデルのほぼ2倍です;", + "1、保留人声:不带和声的音频选这个,对主人声保留比HP5更好。内置HP2和HP3两个模型,HP3可能轻微漏伴奏但对主人声保留比HP2稍微好一丁点;": "1、主音を保持: ハーモニーなしの音声にはこのオプションを選択し、HP5よりも主音の保持が優れています。HP2とHP3の2つのモデルが内蔵されており、HP3はわずかに伴奏を漏らす可能性がありますが、HP2よりも主音の保持がわずかに良いです;", + "2-GPT-SoVITS-变声": "2-GPT-SoVITS-ボイスチェンジャー", + "2、MDX-Net-Dereverb模型挺慢的;": "2、MDX-Net-Dereverbモデルはかなり遅いです;", + "2、仅保留主人声:带和声的音频选这个,对主人声可能有削弱。内置HP5一个模型;": "2、主音のみを保持: ハーモニー付きの音声にはこのオプションを選択し、主音が弱くなる可能性があります。HP5モデルが1つ内蔵されています;", + "3、个人推荐的最干净的配置是先MDX-Net再DeEcho-Aggressive。": "3、最もクリーンな設定は、MDX-Netの後にDeEcho-Aggressiveを使用することをお勧めします。", + "3、去混响、去延迟模型(by FoxJoy):": "3、リバーブ除去と遅延除去モデル(by FoxJoy):", + "ASR 模型": "ASR モデル", + "ASR 模型尺寸": "ASRモデルサイズ", + "ASR 语言设置": "ASR 言語設定", + "ASR进程输出信息": "ASRプロセスの出力情報", + "GPT模型列表": "GPTモデルリスト", + "GPT训练进程输出信息": "GPTトレーニングプロセスの出力情報", + "GPT采样参数(无参考文本时不要太低。不懂就用默认):": "GPT サンプリングパラメーター(参照テキストがない場合はあまり低くしないでください。わからない場合はデフォルトを使用してください):", + "GPU卡号,只能填1个整数": "GPU番号、1つの整数しか入力できません", + "GPU卡号以-分割,每个卡号一个进程": 
"GPUカード番号はハイフンで区切り、各カード番号ごとに1つのプロセスが実行されます", + "SSL进程输出信息": "SSLプロセスの出力情報", + "SoVITS模型列表": "SoVITSモデルリスト", + "SoVITS训练进程输出信息": "SoVITSトレーニングプロセスの出力情報", + "TTS推理WebUI进程输出信息": "TTS推論WebUIプロセスの出力情報", + "TTS推理进程已关闭": "TTS推論プロセスが終了しました", + "TTS推理进程已开启": "TTS推論プロセスが開始されました", + "UVR5已关闭": "UVR5がオフになっています", + "UVR5已开启": "UVR5がオンになっています", + "UVR5进程输出信息": "UVR5プロセスの出力情報", + "alpha_mix:混多少比例归一化后音频进来": "alpha_mix:正規化後のオーディオが入る割合", + "hop_size:怎么算音量曲线,越小精度越大计算量越高(不是精度越大效果越好)": "hop_size: 音量曲線の計算方法、小さいほど精度が高くなりますが、計算量が増加します(精度が高いほど必ずしも効果が良いわけではありません)", + "max:归一化后最大值多少": "max:正規化後の最大値", + "max_sil_kept:切完后静音最多留多长": "max_sil_kept:切り終えた後、最大でどれだけ静かにするか", + "min_interval:最短切割间隔": "min_interval:最短カット間隔", + "min_length:每段最小多长,如果第一段太短一直和后面段连起来直到超过这个值": "min_length:各セグメントの最小長さ。最初のセグメントが短すぎる場合、連続して後続のセグメントに接続され、この値を超えるまで続きます。", + "temperature": "temperature", + "threshold:音量小于这个值视作静音的备选切割点": "閾値:この値未満の音量は静音と見なされ、代替のカットポイントとして扱われます", + "top_k": "top_k", + "top_p": "top_p", + "一键三连进程输出信息": "ワンクリック三連プロセスの出力情報", + "不切": "切らない", + "中文": "中国語", + "中文教程文档:https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e": "中国語チュートリアルドキュメント:https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e", + "中英混合": "中英混合", + "也可批量输入音频文件, 二选一, 优先读文件夹": "複数のオーディオファイルもインポートできます。フォルダパスが存在する場合、この入力は無視されます。", + "人声伴奏分离批量处理, 使用UVR5模型。": "人声と伴奏の分離をバッチ処理で行い、UVR5モデルを使用します。", + "人声提取激进程度": "人声抽出の積極性", + "以下文件或文件夹不存在:": "そのようなファイルやフォルダは存在しません:", + "以下模型不存在:": "モデルが存在しません:", + "伴奏人声分离&去混响&去回声": "ボーカル/伴奏の分離と残響の除去", + "使用无参考文本模式时建议使用微调的GPT,听不清参考音频说的啥(不晓得写啥)可以开。
开启后无视填写的参考文本。": "参考テキストなしモードを使用する場合は、微調整されたGPTの使用をお勧めします。参考音声が聞き取れない場合(何を書けば良いかわからない場合)は、有効にすると、入力した参考テキストを無視します。", + "保存频率save_every_epoch": "保存頻度save_every_epoch", + "凑50字一切": "50文字ずつカット", + "凑四句一切": "4つの文で埋める", + "切分后的子音频的输出根目录": "分割後のサブオーディオの出力ルートディレクトリ", + "切割使用的进程数": "分割に使用されるプロセス数", + "刷新模型路径": "モデルのパスを更新", + "前端处理后的文本(每句):": "フロントエンド処理後のテキスト(文ごと):", + "去混响/去延迟,附:": "残響除去/遅延除去、附:", + "参考音频在3~10秒范围外,请更换!": "参照音声が3~10秒の範囲外です。別の音声に変更してください!", + "参考音频的文本": "参照オーディオのテキスト", + "参考音频的语种": "参照オーディオの言語", + "合成语音": "推論を開始", + "合格的文件夹路径格式举例: E:\\codes\\py39\\vits_vc_gpu\\白鹭霜华测试样例(去文件管理器地址栏拷就行了)。": "適切なフォルダパスの例: E:\\codes\\py39\\vits_vc_gpu\\白鹭霜华テストサンプル(ファイルマネージャのアドレスバーからコピーしてください)。", + "填切割后音频所在目录!读取的音频文件完整路径=该目录-拼接-list文件里波形对应的文件名(不是全路径)。如果留空则使用.list文件里的绝对全路径。": "切断後の音声ファイルが格納されているディレクトリを入力してください!読み取り対象の音声ファイルの完全パス = このディレクトリ - 結合 - listファイル内の波形に対応するファイル名(完全パスではありません)。空白の場合、.listファイル内の絶対完全パスを使用します。", + "多语种混合": "多言語混合", + "多语种混合(粤语)": "多言語混合(粤語)", + "实际输入的参考文本:": "実際に入力された参照テキスト:", + "实际输入的目标文本(切句后):": "実際に入力された目標テキスト(文分割後):", + "实际输入的目标文本(每句):": "実際に入力された目標テキスト(文ごと):", + "实际输入的目标文本:": "実際に入力された目標テキスト:", + "导出文件格式": "エクスポートファイル形式", + "开启GPT训练": "GPTトレーニングを開始", + "开启SSL提取": "SSL抽出を開始", + "开启SoVITS训练": "SoVITSトレーニングを開始", + "开启一键三连": "ワンクリック三連を開始", + "开启文本获取": "テキストの取得を開始", + "开启无参考文本模式。不填参考文本亦相当于开启。": "参照テキストなしモードを有効にします。参照テキストを入力しない場合も同様に有効になります。", + "开启离线批量ASR": "オフラインバッチASRを開始", + "开启语义token提取": "セマンティックトークン抽出を開始", + "开启语音切割": "音声の分割を開始", + "开启语音降噪": "音声ノイズ除去を有効にする", + "怎么切": "どうやって切るか", + "总训练轮数total_epoch": "総トレーニングエポック数total_epoch", + "总训练轮数total_epoch,不建议太高": "総トレーニングエポック数total_epoch、高すぎないようにお勧めします", + "打标工具WebUI已关闭": "校正ツールWebUIが終了しました", + "打标工具WebUI已开启": "校正ツールWebUIが開始されました", + "打标工具进程输出信息": "アノテーションツールプロセスの出力情報", + "指定输出主人声文件夹": "ボーカルの出力フォルダを指定:", + "指定输出非主人声文件夹": "伴奏の出力フォルダを指定:", + "按中文句号。切": "中国語の句点でカット", + "按标点符号切": "句読点で分割", + "按英文句号.切": "英文のピリオドで切ってください", + "数据类型精度": "データ型の精度", + "文本模块学习率权重": "テキストモジュールの学習率の重み", + "文本进程输出信息": "テキストプロセスの出力情報", + "施工中,请静候佳音": "施工中、お待ちください", + "日文": "日本語", + "日英混合": "日英混合", + "是否仅保存最新的ckpt文件以节省硬盘空间": "最新のckptファイルのみを保存してディスクスペースを節約するかどうか", + "是否在每次保存时间点将最终小模型保存至weights文件夹": "各保存時間点で最終的な小さなモデルをweightsフォルダに保存するかどうか", + "是否开启TTS推理WebUI": "TTS推論WebUIを開く", + "是否开启UVR5-WebUI": "UVR5-WebUIをオンにしますか", + "是否开启dpo训练选项(实验性)": "DPOトレーニングオプションを有効にするかどうか(実験的)", + "是否开启打标WebUI": "WebUIを使用したアノテーションを開始しますか", + "是否直接对上次合成结果调整语速。防止随机性。": "直前の合成結果の話速を直接調整して、ランダム性を防ぐか。", + "显卡信息": "グラフィックカード情報", + "本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.
如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.": "このソフトウェアはMITライセンスでオープンソース化されており、作者はソフトウェアに対して一切の制御権を持っていません。ソフトウェアを使用する者、ソフトウェアから導出される音声を広める者は、自己責任で行ってください。
この条件を認めない場合、ソフトウェアパッケージ内の任意のコードやファイルを使用または引用することはできません。詳細はルートディレクトリのLICENSEを参照してください。", + "模型": "モデル", + "模型分为三类:": "モデルは3種類に分かれています:", + "模型切换": "モデル切り替え", + "每张显卡的batch_size": "各グラフィックカードのバッチサイズ", + "版本": "バージョン", + "粤英混合": "粤英混合", + "粤语": "粤語", + "终止ASR进程": "ASRプロセスを停止", + "终止GPT训练": "GPTトレーニングを停止", + "终止SSL提取进程": "SSL抽出プロセスを停止", + "终止SoVITS训练": "SoVITSトレーニングを停止", + "终止一键三连": "ワンクリック三連を停止", + "终止文本获取进程": "テキスト取得プロセスを停止", + "终止语义token提取进程": "セマンティックトークン抽出プロセスを停止", + "终止语音切割": "音声の分割を停止", + "终止语音降噪进程": "音声ノイズ除去プロセスを終了する", + "英文": "英語", + "语义token提取进程输出信息": "セマンティックトークン抽出プロセスの出力情報", + "语速": "話速", + "语速调整,高为更快": "話速調整、高いほど速く", + "语音切割进程输出信息": "音声分割プロセスの出力情報", + "语音降噪进程输出信息": "音声ノイズ除去プロセスの出力情報", + "请上传3~10秒内参考音频,超过会报错!": "3~10秒以内の参照音声をアップロードしてください。それを超えるとエラーが発生します!", + "请上传参考音频": "リファレンスオーディオをアップロードしてください", + "请填入推理文本": "ターゲットテキストを入力してください", + "请输入有效文本": "有効なテキストを入力してください", + "转换": "変換", + "输入待处理音频文件夹路径": "処理するオーディオフォルダのパスを入力してください:", + "输入文件夹路径": "入力フォルダのパス", + "输出logs/实验名目录下应有23456开头的文件和文件夹": "logs/実験名ディレクトリには23456で始まるファイルとフォルダが含まれている必要があります", + "输出信息": "出力情報", + "输出文件夹路径": "出力フォルダのパス", + "输出的语音": "推論結果", + "选择训练完存放在SoVITS_weights和GPT_weights下的模型。默认的一个是底模,体验5秒Zero Shot TTS用。": "SoVITS_weightsおよびGPT_weightsに保存されたモデルを選択します。デフォルトのものはプレトレインであり、ゼロショットTTSを体験できます。", + "降噪结果输出文件夹": "ノイズ除去結果出力フォルダ", + "降噪音频文件输入文件夹": "ノイズ除去音声ファイル入力フォルダ", + "需要合成的文本": "推論テキスト", + "需要合成的语种": "推論テキストの言語", + "韩文": "韓国語", + "韩英混合": "韓英混合", + "音频自动切分输入路径,可文件可文件夹": "オーディオの自動分割入力パス、ファイルまたはフォルダを指定できます", + "预训练的GPT模型路径": "事前にトレーニングされたGPTモデルのパス", + "预训练的SSL模型路径": "事前にトレーニングされたSSLモデルのパス", + "预训练的SoVITS-D模型路径": "事前にトレーニングされたSoVITS-Dモデルのパス", + "预训练的SoVITS-G模型路径": "事前にトレーニングされたSoVITS-Gモデルのパス", + "预训练的中文BERT模型路径": "事前にトレーニングされた中文BERTモデルのパス" +} diff --git a/tools/i18n/locale/ko_KR.json b/tools/i18n/locale/ko_KR.json new file mode 100644 index 0000000000000000000000000000000000000000..ba3604692c5c3932e74676a96cd54766c64bdf81 --- /dev/null +++ b/tools/i18n/locale/ko_KR.json @@ -0,0 +1,178 @@ +{ + "(1)MDX-Net(onnx_dereverb):对于双通道混响是最好的选择,不能去除单通道混响;": "(1)MDX-Net (onnx_dereverb): 듀얼 채널 리버브에는 가장 적합하지만, 싱글 채널 리버브는 제거할 수 없습니다", + "(234)DeEcho:去除延迟效果。Aggressive比Normal去除得更彻底,DeReverb额外去除混响,可去除单声道混响,但是对高频重的板式混响去不干净。": "(234)DeEcho:지연 효과를 제거합니다. Aggressive는 Normal보다 더 철저하게 제거하며, DeReverb는 추가로 리버브를 제거하여 단일 채널 리버브를 제거할 수 있지만 고주파 리버브는 완전히 제거하지 못합니다.", + "*GPT模型列表": "*GPT 모델 목록", + "*SoVITS模型列表": "*SoVITS 모델 목록", + "*实验/模型名": "*실험/모델 이름", + "*文本标注文件": "*텍스트 주석 파일", + "*训练集音频文件目录": "*훈련 세트 오디오 파일 디렉터리", + "*请上传并填写参考信息": "*참고 정보를 업로드하고 입력하십시오", + "*请填写需要合成的目标文本和语种模式": "*합성할 목표 텍스트와 언어 모드를 입력하세요", + ".list标注文件的路径": ".list 주석 파일 경로", + "0-前置数据集获取工具": "0-전방 데이터 세트 수집 도구", + "0a-UVR5人声伴奏分离&去混响去延迟工具": "0a-UVR5 보컬 및 반주 분리 및 에코 및 지연 제거 도구", + "0b-语音切分工具": "0b-음성 분리 도구", + "0bb-语音降噪工具": "0bb-음성 노이즈 제거 도구", + "0c-中文批量离线ASR工具": "0c-중국어 대량 오프라인 ASR 도구", + "0d-语音文本校对标注工具": "0d-음성 텍스트 교정 주석 도구", + "1-GPT-SoVITS-TTS": "1-GPT-SoVITS-TTS", + "1A-训练集格式化工具": "1A-훈련 세트 형식 지정 도구", + "1Aa-文本内容": "1Aa-텍스트 내용", + "1Aabc-训练集格式化一键三连": "1Aabc-훈련 세트 형식 지정 일괄 처리", + "1Ab-SSL自监督特征提取": "1Ab-SSL 자기 지도 특징 추출", + "1Ac-语义token提取": "1Ac-의미 토큰 추출", + "1B-微调训练": "1B-미세 조정 훈련", + "1Ba-SoVITS训练。用于分享的模型文件输出在SoVITS_weights下。": "1Ba-SoVITS 훈련. 공유 용 모델 파일은 SoVITS_weights 하위에 출력됩니다.", + "1Bb-GPT训练。用于分享的模型文件输出在GPT_weights下。": "1Bb-GPT 훈련. 공유 용 모델 파일은 GPT_weights 하위에 출력됩니다.", + "1C-推理": "1C-추론", + "1、DeEcho-DeReverb模型的耗时是另外2个DeEcho模型的接近2倍;": "1. DeEcho-DeReverb 모델의 처리 시간은 다른 두 DeEcho 모델의 거의 두 배입니다;", + "1、保留人声:不带和声的音频选这个,对主人声保留比HP5更好。内置HP2和HP3两个模型,HP3可能轻微漏伴奏但对主人声保留比HP2稍微好一丁点;": "1. 
사람 목소리를 유지: 화음이 없는 오디오를 선택하면 HP5보다 사람 목소리를 더 잘 유지할 수 있습니다. 내장된 HP2와 HP3 모델이 있으며, HP3는 화음을 약간 놓칠 수 있지만 HP2보다 사람 목소리를 조금 더 잘 유지합니다;", + "2-GPT-SoVITS-变声": "2-GPT-SoVITS-음성 변환", + "2、MDX-Net-Dereverb模型挺慢的;": "2. MDX-Net-Dereverb 모델은 꽤 느립니다;", + "2、仅保留主人声:带和声的音频选这个,对主人声可能有削弱。内置HP5一个模型;": "2. 주 목소리만 유지: 화음이 있는 오디오에 이 모델을 선택하면 주 목소리가 약해질 수 있습니다. 내장된 HP5 모델이 있습니다;", + "3、个人推荐的最干净的配置是先MDX-Net再DeEcho-Aggressive。": "3. 개인적으로 가장 깨끗한 설정은 먼저 MDX-Net을 사용하고 그 다음에 DeEcho-Aggressive를 사용하는 것입니다;", + "3、去混响、去延迟模型(by FoxJoy):": "3. 잔향 제거 및 지연 제거 모델 (by FoxJoy):", + "ASR 模型": "ASR 모델", + "ASR 模型尺寸": "ASR 모델 크기", + "ASR 语言设置": "ASR 언어 설정", + "ASR进程输出信息": "ASR 프로세스 출력 정보", + "GPT模型列表": "GPT 모델 목록", + "GPT训练进程输出信息": "GPT 훈련 프로세스 출력 정보", + "GPT采样参数(无参考文本时不要太低。不懂就用默认):": "GPT 샘플링 매개변수 (참조 텍스트가 없을 때 너무 낮게 설정하지 마십시오. 확실하지 않으면 기본값을 사용하십시오):", + "GPU卡号,只能填1个整数": "GPU 카드 번호, 1개의 정수만 입력 가능", + "GPU卡号以-分割,每个卡号一个进程": "GPU 카드 번호는 -로 구분되며 각 카드 번호에 하나의 프로세스가 있어야 함", + "SSL进程输出信息": "SSL 프로세스 출력 정보", + "SoVITS模型列表": "SoVITS 모델 목록", + "SoVITS训练进程输出信息": "SoVITS 훈련 프로세스 출력 정보", + "TTS推理WebUI进程输出信息": "TTS 추론 WebUI 프로세스 출력 정보", + "TTS推理进程已关闭": "TTS 추론 프로세스가 닫혔습니다", + "TTS推理进程已开启": "TTS 추론 프로세스가 열렸습니다", + "UVR5已关闭": "UVR5가 비활성화되었습니다", + "UVR5已开启": "UVR5가 활성화되었습니다", + "UVR5进程输出信息": "UVR5 프로세스 출력 정보", + "alpha_mix:混多少比例归一化后音频进来": "알파 믹스: 정규화된 오디오가 들어오는 비율", + "hop_size:怎么算音量曲线,越小精度越大计算量越高(不是精度越大效果越好)": "hop 크기: 볼륨 곡선을 계산하는 방법. 작을수록 정확도가 높아지지만 계산량이 높아집니다 (정확도가 높다고 효과가 좋아지지 않음)", + "max:归一化后最大值多少": "최대 값 (정규화 후)", + "max_sil_kept:切完后静音最多留多长": "최대 유지되는 정적 길이 (분리 후)", + "min_interval:最短切割间隔": "최소 분리 간격", + "min_length:每段最小多长,如果第一段太短一直和后面段连起来直到超过这个值": "min_length:각 부분의 최소 길이, 첫 번째 부분이 너무 짧으면 다음 부분과 계속 연결하여 이 값을 초과할 때까지", + "temperature": "temperature", + "threshold:音量小于这个值视作静音的备选切割点": "임계 값: 이 값보다 작은 볼륨은 대체 분리 지점으로 간주됩니다.", + "top_k": "top_k", + "top_p": "top_p", + "一键三连进程输出信息": "일괄 처리 프로세스 출력 정보", + "不切": "자르지 않음", + "中文": "중국어", + "中文教程文档:https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e": "중국어 튜토리얼 문서:https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e", + "中英混合": "중영 혼합", + "也可批量输入音频文件, 二选一, 优先读文件夹": "오디오 파일을 일괄로 입력할 수도 있습니다. 둘 중 하나를 선택하고 폴더를 읽기를 우선합니다.", + "人声伴奏分离批量处理, 使用UVR5模型。": "보컬과 반주 분리 배치 처리, UVR5 모델 사용.", + "人声提取激进程度": "보컬 추출의 공격성", + "以下文件或文件夹不存在:": "해당 파일 또는 폴더가 존재하지 않습니다:", + "以下模型不存在:": "해당 모델이 존재하지 않습니다:", + "伴奏人声分离&去混响&去回声": "반주 및 보컬 분리 & 리버브 제거 & 에코 제거", + "使用无参考文本模式时建议使用微调的GPT,听不清参考音频说的啥(不晓得写啥)可以开。
开启后无视填写的参考文本。": "참고 텍스트가 없을 때는 미세 조정된 GPT를 사용하는 것이 좋습니다. 참고 오디오에서 무엇을 말하는지 잘 들리지 않으면 이 모드를 켜서 입력한 참고 텍스트를 무시할 수 있습니다.", + "保存频率save_every_epoch": "저장 빈도 (각 라운드마다)", + "凑50字一切": "50자를 채우십시오", + "凑四句一切": "네 문장의 세트를 완성하세요.", + "切分后的子音频的输出根目录": "분리된 하위 오디오의 출력 기본 디렉터리", + "切割使用的进程数": "사용되는 프로세스 수로 자르기", + "刷新模型路径": "모델 경로 새로 고침", + "前端处理后的文本(每句):": "프론트엔드 처리 후 텍스트(문장별):", + "去混响/去延迟,附:": "리버브 제거/지연 제거, 부록:", + "参考音频在3~10秒范围外,请更换!": "참고 오디오가 3~10초 범위를 벗어났습니다. 다른 것으로 바꾸십시오!", + "参考音频的文本": "참고 오디오의 텍스트", + "参考音频的语种": "참고 오디오의 언어", + "合成语音": "합성 음성", + "合格的文件夹路径格式举例: E:\\codes\\py39\\vits_vc_gpu\\白鹭霜华测试样例(去文件管理器地址栏拷就行了)。": "적절한 폴더 경로 형식 예: E:\\codes\\py39\\vits_vc_gpu\\백로서리 테스트 샘플 (파일 관리자 주소 표시줄에서 복사하면 됩니다).", + "填切割后音频所在目录!读取的音频文件完整路径=该目录-拼接-list文件里波形对应的文件名(不是全路径)。如果留空则使用.list文件里的绝对全路径。": "분리된 오디오가 위치한 디렉터리를 입력하세요! 읽어들인 오디오 파일의 전체 경로 = 이 디렉터리 - list 파일에서 파형에 해당하는 파일명(전체 경로가 아님). 비워 두면 .list 파일의 절대 전체 경로를 사용합니다.", + "多语种混合": "다국어 혼합", + "多语种混合(粤语)": "다국어 혼합(粤語)", + "实际输入的参考文本:": "실제 입력된 참고 텍스트:", + "实际输入的目标文本(切句后):": "실제 입력된 목표 텍스트(문장 분리 후):", + "实际输入的目标文本(每句):": "실제 입력된 목표 텍스트(문장별):", + "实际输入的目标文本:": "실제 입력된 목표 텍스트:", + "导出文件格式": "내보내기 파일 형식", + "开启GPT训练": "GPT 훈련 활성화", + "开启SSL提取": "SSL 추출 활성화", + "开启SoVITS训练": "SoVITS 훈련 활성화", + "开启一键三连": "일괄 처리 활성화", + "开启文本获取": "텍스트 추출 활성화", + "开启无参考文本模式。不填参考文本亦相当于开启。": "참고 텍스트 없이 모드를 활성화합니다. 참고 텍스트를 입력하지 않으면 자동으로 활성화됩니다.", + "开启离线批量ASR": "오프라인 대량 ASR 활성화", + "开启语义token提取": "의미 토큰 추출 활성화", + "开启语音切割": "음성 분리 활성화", + "开启语音降噪": "음성 노이즈 제거 활성화", + "怎么切": "자르기 옵션", + "总训练轮数total_epoch": "총 훈련 라운드 수 (total_epoch)", + "总训练轮数total_epoch,不建议太高": "총 훈련 라운드 수 (total_epoch), 너무 높지 않게 권장됨", + "打标工具WebUI已关闭": "주석 도구 WebUI가 닫혔습니다", + "打标工具WebUI已开启": "주석 도구 WebUI가 열렸습니다", + "打标工具进程输出信息": "주석 도구 프로세스 출력 정보", + "指定输出主人声文件夹": "지정된 주인 목소리 출력 폴더", + "指定输出非主人声文件夹": "지정된 비주인 목소리 출력 폴더", + "按中文句号。切": "중국어 문장으로 분리하십시오.", + "按标点符号切": "구두점을 기준으로 자르기", + "按英文句号.切": "영어 문장으로 분리하기", + "数据类型精度": "데이터 유형 정밀도", + "文本模块学习率权重": "텍스트 모듈 학습률 가중치", + "文本进程输出信息": "텍스트 프로세스 출력 정보", + "施工中,请静候佳音": "공사 중입니다. 기다려주십시오.", + "日文": "일본어", + "日英混合": "일본어와 영어 혼합", + "是否仅保存最新的ckpt文件以节省硬盘空间": "디스크 공간을 절약하기 위해 최신 ckpt 파일만 저장할지 여부", + "是否在每次保存时间点将最终小模型保存至weights文件夹": "각 저장 시간에 최종 작은 모델을 weights 폴더에 저장할지 여부", + "是否开启TTS推理WebUI": "TTS 추론 WebUI 활성화 여부", + "是否开启UVR5-WebUI": "UVR5-WebUI를 여시겠습니까?", + "是否开启dpo训练选项(实验性)": "dpo 훈련 옵션(실험적) 활성화 여부", + "是否开启打标WebUI": "웹 기반 주석 활성화 여부", + "是否直接对上次合成结果调整语速。防止随机性。": "직전 합성 결과의 언어 속도를 직접 조정하여 무작위성을 방지할까요?", + "显卡信息": "그래픽 카드 정보", + "本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.
如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.": "본 소프트웨어는 MIT 라이선스로 오픈 소스로 제공되며, 제작자는 소프트웨어에 대해 어떠한 제어력도 가지지 않습니다. 소프트웨어 사용자 및 소프트웨어에서 내보낸 소리를 전파하는 자는 전적으로 책임져야 합니다.
이 조항을 인정하지 않으면 소프트웨어의 코드 및 파일을 사용하거나 인용할 수 없습니다. 루트 디렉터리의 LICENSE를 참조하십시오.", + "模型": "모델", + "模型分为三类:": "모델은 3가지로 나뉩니다:", + "模型切换": "모델 전환", + "每张显卡的batch_size": "각 그래픽 카드의 배치 크기", + "版本": "버전", + "粤英混合": "粤영 혼합", + "粤语": "粤語", + "终止ASR进程": "ASR 프로세스 종료", + "终止GPT训练": "GPT 훈련 종료", + "终止SSL提取进程": "SSL 추출 프로세스 종료", + "终止SoVITS训练": "SoVITS 훈련 종료", + "终止一键三连": "일괄 처리 종료", + "终止文本获取进程": "텍스트 추출 프로세스 종료", + "终止语义token提取进程": "의미 토큰 추출 프로세스 종료", + "终止语音切割": "음성 분리 종료", + "终止语音降噪进程": "음성 노이즈 제거 프로세스 종료", + "英文": "영어", + "语义token提取进程输出信息": "의미 토큰 추출 프로세스 출력 정보", + "语速": "언어 속도", + "语速调整,高为更快": "언어 속도 조정, 높을수록 빠름", + "语音切割进程输出信息": "음성 분리 프로세스 출력 정보", + "语音降噪进程输出信息": "음성 노이즈 제거 프로세스 출력 정보", + "请上传3~10秒内参考音频,超过会报错!": "3~10초 이내의 참고 오디오를 업로드하십시오. 초과하면 오류가 발생합니다!", + "请上传参考音频": "참고 오디오를 업로드하세요", + "请填入推理文本": "목표 텍스트를 입력하세요", + "请输入有效文本": "유효한 텍스트를 입력하세요", + "转换": "변환", + "输入待处理音频文件夹路径": "처리 대기 중인 오디오 폴더 경로 입력", + "输入文件夹路径": "폴더 경로 입력", + "输出logs/实验名目录下应有23456开头的文件和文件夹": "logs/실험 이름 디렉터리에는 23456으로 시작하는 파일과 폴더가 있어야 함", + "输出信息": "출력 정보", + "输出文件夹路径": "출력 폴더 경로", + "输出的语音": "출력 음성", + "选择训练完存放在SoVITS_weights和GPT_weights下的模型。默认的一个是底模,体验5秒Zero Shot TTS用。": "SoVITS_weights 및 GPT_weights에 저장된 훈련 완료된 모델 중 선택. 기본적으로 하나는 기본 모델이며 5초 Zero Shot TTS를 체험할 수 있습니다.", + "降噪结果输出文件夹": "노이즈 제거 결과 출력 폴더", + "降噪音频文件输入文件夹": "노이즈 제거 오디오 파일 입력 폴더", + "需要合成的文本": "합성해야 할 텍스트", + "需要合成的语种": "합성해야 할 언어", + "韩文": "한국어", + "韩英混合": "한영 혼합", + "音频自动切分输入路径,可文件可文件夹": "오디오 자동 분리 입력 경로, 파일 또는 폴더 가능", + "预训练的GPT模型路径": "사전 훈련된 GPT 모델 경로", + "预训练的SSL模型路径": "사전 훈련된 SSL 모델 경로", + "预训练的SoVITS-D模型路径": "사전 훈련된 SoVITS-D 모델 경로", + "预训练的SoVITS-G模型路径": "사전 훈련된 SoVITS-G 모델 경로", + "预训练的中文BERT模型路径": "사전 훈련된 중국어 BERT 모델 경로" +} diff --git a/tools/i18n/locale/pt_BR.json b/tools/i18n/locale/pt_BR.json new file mode 100644 index 0000000000000000000000000000000000000000..5b2ac5fef715bd0ff71aff4b7cfd2f97bd5fee33 --- /dev/null +++ b/tools/i18n/locale/pt_BR.json @@ -0,0 +1,178 @@ +{ + "(1)MDX-Net(onnx_dereverb):对于双通道混响是最好的选择,不能去除单通道混响;": "(1)MDX-Net (onnx_dereverb): É a melhor opção para reverberação de dois canais, mas não pode remover a reverberação de um único canal;", + "(234)DeEcho:去除延迟效果。Aggressive比Normal去除得更彻底,DeReverb额外去除混响,可去除单声道混响,但是对高频重的板式混响去不干净。": "(234)DeEcho:Remove os efeitos de atraso. 
Aggressive é mais completo que Normal na remoção, DeReverb remove adicionalmente a reverberação, pode remover a reverberação de um canal único, mas não remove completamente a reverberação de placa de alta frequência.", + "*GPT模型列表": "*Lista de modelos GPT", + "*SoVITS模型列表": "*Lista de modelos Sovits", + "*实验/模型名": "*Nome do experimento/modelo", + "*文本标注文件": "*Arquivo de marcação de texto", + "*训练集音频文件目录": "*Diretório de arquivos de áudio do conjunto de treinamento", + "*请上传并填写参考信息": "Por favor, faça o upload e preencha as informações de referência", + "*请填写需要合成的目标文本和语种模式": "*Por favor, insira o texto alvo a ser sintetizado e o modo de idioma.", + ".list标注文件的路径": "Caminho do arquivo de anotação .list", + "0-前置数据集获取工具": "0- Ferramenta de aquisição de conjunto de dados pré-frontal", + "0a-UVR5人声伴奏分离&去混响去延迟工具": "0A-UVR5 separação de voz e acompanhamento instrumental & ferramenta para remover reverberação e atraso", + "0b-语音切分工具": "0b- Ferramenta de corte de voz", + "0bb-语音降噪工具": "0bb- Ferramenta de redução de ruído de voz", + "0c-中文批量离线ASR工具": "0c- Ferramenta chinesa de ASR offline em lote", + "0d-语音文本校对标注工具": "0d- Ferramenta de correção e marcação de texto de voz", + "1-GPT-SoVITS-TTS": "1-GPT-SOVITS-TTS", + "1A-训练集格式化工具": "1A-Ferramenta de formatação de conjunto de dados de treinamento", + "1Aa-文本内容": "1AA-Conteúdo do texto", + "1Aabc-训练集格式化一键三连": "1AABC-Formatação de conjunto de treinamento em um clique", + "1Ab-SSL自监督特征提取": "1AB-Extração de características auto-supervisionadas SSL", + "1Ac-语义token提取": "1AC-Extração de token semântico", + "1B-微调训练": "1B-Treinamento de ajuste fino", + "1Ba-SoVITS训练。用于分享的模型文件输出在SoVITS_weights下。": "1ba-Treinamento SoVITS. O arquivo de modelo para compartilhamento é gerado em SOVITS_WEIGHTS", + "1Bb-GPT训练。用于分享的模型文件输出在GPT_weights下。": "1BB-Treinamento GPT. O arquivo de modelo para compartilhamento é gerado em GPT_WEIGHTS", + "1C-推理": "1C-raciocínio", + "1、DeEcho-DeReverb模型的耗时是另外2个DeEcho模型的接近2倍;": "1. O tempo de processamento do modelo DeEcho-DeReverb é quase o dobro dos outros dois modelos DeEcho;", + "1、保留人声:不带和声的音频选这个,对主人声保留比HP5更好。内置HP2和HP3两个模型,HP3可能轻微漏伴奏但对主人声保留比HP2稍微好一丁点;": "1. Manter a voz: selecione isso para áudio sem harmonia, que preserva melhor a voz principal do que o HP5. Inclui dois modelos, HP2 e HP3; o HP3 pode permitir um pequeno vazamento de acompanhamento, mas preserva a voz principal um pouco melhor do que o HP2;", + "2-GPT-SoVITS-变声": "2-gpt-sovits-mudança de voz", + "2、MDX-Net-Dereverb模型挺慢的;": "2. O modelo MDX-Net-Dereverb é bastante lento;", + "2、仅保留主人声:带和声的音频选这个,对主人声可能有削弱。内置HP5一个模型;": "2. Manter apenas a voz principal: selecione isso para áudio com harmonia, pode haver uma redução na voz principal. Inclui um modelo HP5;", + "3、个人推荐的最干净的配置是先MDX-Net再DeEcho-Aggressive。": "3. A configuração mais limpa recomendada é usar primeiro o MDX-Net e depois o DeEcho-Aggressive.", + "3、去混响、去延迟模型(by FoxJoy):": "3. Modelo de remoção de reverberação e atraso (por FoxJoy):", + "ASR 模型": "Modelo ASR", + "ASR 模型尺寸": "Tamanho do modelo ASR", + "ASR 语言设置": "Configurações de idioma do ASR", + "ASR进程输出信息": "Informações de saída do processo ASR", + "GPT模型列表": "Lista de modelos GPT", + "GPT训练进程输出信息": "Informações de saída do processo de treinamento GPT", + "GPT采样参数(无参考文本时不要太低。不懂就用默认):": "Parâmetros de amostragem do GPT (não muito baixos quando não houver texto de referência. 
Use o padrão se não tiver certeza):", + "GPU卡号,只能填1个整数": "Número da placa de vídeo, só é possível preencher com um número inteiro", + "GPU卡号以-分割,每个卡号一个进程": "Número da placa de vídeo dividido por-, cada número de placa é um processo", + "SSL进程输出信息": "Informações de saída do processo SSL", + "SoVITS模型列表": "Lista de modelos SoVITS", + "SoVITS训练进程输出信息": "Informações de saída do processo de treinamento SoVITS", + "TTS推理WebUI进程输出信息": "Informações de saída do processo webui de raciocínio TTS", + "TTS推理进程已关闭": "O processo de inferência TTS foi desativado", + "TTS推理进程已开启": "O processo de inferência TTS foi iniciado", + "UVR5已关闭": "UVR5 está desativado", + "UVR5已开启": "UVR5 está ativado", + "UVR5进程输出信息": "Informações de saída do processo UVR5", + "alpha_mix:混多少比例归一化后音频进来": "alpha_mix: Em que proporção o áudio normalizado é misturado de volta", + "hop_size:怎么算音量曲线,越小精度越大计算量越高(不是精度越大效果越好)": "HOP_SIZE: Como calcular a curva de volume, quanto menor a precisão, maior a quantidade de cálculos (não significa que quanto maior a precisão, melhor o efeito)", + "max:归一化后最大值多少": "MAX: Qual é o valor máximo após a normalização?", + "max_sil_kept:切完后静音最多留多长": "max_sil_kept: Depois de cortar, por quanto tempo no máximo o silêncio é mantido", + "min_interval:最短切割间隔": "min_interval: O intervalo de corte mínimo", + "min_length:每段最小多长,如果第一段太短一直和后面段连起来直到超过这个值": "min_length: Comprimento mínimo de cada segmento. Se o primeiro segmento for muito curto, ele será unido aos segmentos seguintes até exceder este valor", + "temperature": "temperatura", + "threshold:音量小于这个值视作静音的备选切割点": "Limiar: O volume menor que este valor é considerado como um ponto de corte mudo alternativo", + "top_k": "top_k", + "top_p": "top_p", + "一键三连进程输出信息": "Informações de saída do processo de um clique", + "不切": "Não dividir", + "中文": "Chinês", + "中文教程文档:https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e": "Documentação do tutorial em chinês:https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e", + "中英混合": "Mistura de Chinês e Inglês", + "也可批量输入音频文件, 二选一, 优先读文件夹": "Também é possível inserir arquivos de áudio em lote; escolha uma opção, preferencialmente leia a pasta.", + "人声伴奏分离批量处理, 使用UVR5模型。": "Processamento em lote de separação de voz e acompanhamento, usando o modelo UVR5.", + "人声提取激进程度": "Grau de agressividade da extração de voz", + "以下文件或文件夹不存在:": "Nenhum Arquivo ou Pasta Encontrado:", + "以下模型不存在:": "Nenhum Modelo Tal:", + "伴奏人声分离&去混响&去回声": "Separação de acompanhamento e voz & remoção de reverberação & remoção de eco", + "使用无参考文本模式时建议使用微调的GPT,听不清参考音频说的啥(不晓得写啥)可以开。
开启后无视填写的参考文本。": "Ao usar o modo sem texto de referência, recomenda-se usar um GPT ajustado. Se não conseguir ouvir claramente o áudio de referência (não sabe o que escrever), você pode ativar o modo e ignorar o texto de referência fornecido.", + "保存频率save_every_epoch": "Frequência de salvamento save_every_epoch", + "凑50字一切": "Complete com 50 caracteres", + "凑四句一切": "Complete com quatro frases", + "切分后的子音频的输出根目录": "Diretório raiz de saída do sub-áudio após o corte", + "切割使用的进程数": "Número de processos para corte", + "刷新模型路径": "Atualizar caminho do modelo", + "前端处理后的文本(每句):": "Texto após processamento front-end (por frase):", + "去混响/去延迟,附:": "Remoção de reverberação/remoção de atraso, anexo:", + "参考音频在3~10秒范围外,请更换!": "O áudio de referência está fora do intervalo de 3 a 10 segundos. Por favor, substitua!", + "参考音频的文本": "Texto do áudio de referência", + "参考音频的语种": "Idioma do áudio de referência", + "合成语音": "Voz sintetizada", + "合格的文件夹路径格式举例: E:\\codes\\py39\\vits_vc_gpu\\白鹭霜华测试样例(去文件管理器地址栏拷就行了)。": "Exemplo de formato de caminho de pasta válido: E:\\codes\\py39\\vits_vc_gpu\\白鹭霜华测试样例 (copie do endereço da barra do gerenciador de arquivos).", + "填切割后音频所在目录!读取的音频文件完整路径=该目录-拼接-list文件里波形对应的文件名(不是全路径)。如果留空则使用.list文件里的绝对全路径。": "Preencha o diretório onde os áudios cortados estão localizados! O caminho completo dos arquivos de áudio lidos = este diretório - concatenação com o nome do arquivo de forma correspondente no arquivo .list (não o caminho completo). Se deixar em branco, use o caminho absoluto no arquivo .list.", + "多语种混合": "Mistura de múltiplos idiomas", + "多语种混合(粤语)": "Mistura Multilíngue (Yue)", + "实际输入的参考文本:": "Texto de referência realmente inserido:", + "实际输入的目标文本(切句后):": "Texto alvo realmente inserido (após divisão de frases):", + "实际输入的目标文本(每句):": "Texto alvo realmente inserido (por frase):", + "实际输入的目标文本:": "Texto alvo realmente inserido:", + "导出文件格式": "Formato de arquivo de exportação", + "开启GPT训练": "Ativar treinamento GPT", + "开启SSL提取": "Ativar extração SSL", + "开启SoVITS训练": "Ativar treinamento SoVITS", + "开启一键三连": "Ativar um clique", + "开启文本获取": "Ativar obtenção de texto", + "开启无参考文本模式。不填参考文本亦相当于开启。": "Ativar o modo sem texto de referência. 
Não preencher o texto de referência também equivale a ativar.", + "开启离线批量ASR": "Ativar ASR offline em lote", + "开启语义token提取": "Ativar extração de token semântico", + "开启语音切割": "Ativar corte de voz", + "开启语音降噪": "Ativar redução de ruído de voz", + "怎么切": "Como cortar", + "总训练轮数total_epoch": "Total de epoch de treinamento", + "总训练轮数total_epoch,不建议太高": "Total de epoch de treinamento, não é recomendável um valor muito alto", + "打标工具WebUI已关闭": "A ferramenta de marcação WebUI foi desativado", + "打标工具WebUI已开启": "A ferramenta de marcação WebUI está ativada", + "打标工具进程输出信息": "Informações de saída do processo da ferramenta de marcação", + "指定输出主人声文件夹": "Especificar a pasta de saída da voz principal", + "指定输出非主人声文件夹": "Especificar a pasta de saída da voz secundária", + "按中文句号。切": "Dividir por ponto final chinês", + "按标点符号切": "Dividir por sinais de pontuação", + "按英文句号.切": "Dividir por ponto final em inglês", + "数据类型精度": "precisão do tipo de dado", + "文本模块学习率权重": "Weight da taxa de aprendizado do módulo de texto", + "文本进程输出信息": "Informações de saída do processo de texto", + "施工中,请静候佳音": "Em construção, por favor, aguarde por um bom som", + "日文": "Japonês", + "日英混合": "Mistura de Japonês e Inglês", + "是否仅保存最新的ckpt文件以节省硬盘空间": "Se deve salvar apenas o último arquivo CKPT para economizar espaço em disco", + "是否在每次保存时间点将最终小模型保存至weights文件夹": "Se deve salvar o modelo pequeno final na pasta Weights em cada ponto de salvamento de tempo", + "是否开启TTS推理WebUI": "Se deseja ativar o webui de raciocínio TTS", + "是否开启UVR5-WebUI": "Se deseja ativar a UVR5-WebUI", + "是否开启dpo训练选项(实验性)": "Se deseja ativar a opção de treinamento DPO (experimental)", + "是否开启打标WebUI": "Se deseja abrir o webui de marcação", + "是否直接对上次合成结果调整语速。防止随机性。": "Se deve ajustar diretamente a velocidade da fala do último resultado de síntese para evitar aleatoriedade.", + "显卡信息": "Informações da placa de vídeo", + "本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.
如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.": "Este software é de código aberto sob a licença MIT. O autor não tem controle sobre o software. Aqueles que usam o software e difundem os sons exportados pelo software são totalmente responsáveis.
Se você não concorda com esta cláusula, não pode usar ou citar nenhum código e arquivo dentro do pacote de software. Consulte o diretório raiz LICENSE para mais detalhes.

Traduzido por Rafael Godoy Ebert", + "模型": "Modelo", + "模型分为三类:": "Modelos dividem-se em três categorias:", + "模型切换": "Troca de modelo", + "每张显卡的batch_size": "Tamanho do lote de cada placa de vídeo", + "版本": "Versão", + "粤英混合": "Mistura Yue-Inglês", + "粤语": "Yue", + "终止ASR进程": "Encerrar processo ASR", + "终止GPT训练": "Encerrar treinamento GPT", + "终止SSL提取进程": "Encerrar processo de extração SSL", + "终止SoVITS训练": "Encerrar treinamento SoVITS", + "终止一键三连": "Encerrar um clique", + "终止文本获取进程": "Encerrar processo de obtenção de texto", + "终止语义token提取进程": "Encerrar processo de extração de token semântico", + "终止语音切割": "Encerrar corte de voz", + "终止语音降噪进程": "Encerrar o processo de redução de ruído de voz", + "英文": "Inglês", + "语义token提取进程输出信息": "Informações de saída do processo de extração de token semântico", + "语速": "Velocidade da fala", + "语速调整,高为更快": "Ajustar a velocidade da fala, mais alta para mais rápido", + "语音切割进程输出信息": "Informações de saída do processo de corte de voz", + "语音降噪进程输出信息": "Informações de saída do processo de redução de ruído de voz", + "请上传3~10秒内参考音频,超过会报错!": "Por favor, faça upload de um áudio de referência com duração entre 3 e 10 segundos. Áudios fora dessa faixa causarão erro!", + "请上传参考音频": "Por Favor, Carregue o Áudio de Referência", + "请填入推理文本": "Por Favor, Preencha o Texto de Inferência", + "请输入有效文本": "Por favor, insira um texto válido", + "转换": "Converter", + "输入待处理音频文件夹路径": "Caminho da pasta de arquivos de áudio a ser processados", + "输入文件夹路径": "Caminho da pasta de entrada", + "输出logs/实验名目录下应有23456开头的文件和文件夹": "Logs de saída/deve haver arquivos e pastas começando com 23456 no diretório do nome do experimento", + "输出信息": "Informações de saída", + "输出文件夹路径": "Caminho da pasta de saída", + "输出的语音": "Áudio de saída", + "选择训练完存放在SoVITS_weights和GPT_weights下的模型。默认的一个是底模,体验5秒Zero Shot TTS用。": "Selecione os modelos armazenados em Sovits_weights e GPT_WEIGHTS. O padrão é o modelo inferior, experiência para 5 segundos de Zero Shot TTS", + "降噪结果输出文件夹": "Pasta de saída dos resultados de redução de ruído", + "降噪音频文件输入文件夹": "Pasta de entrada dos arquivos de áudio para redução de ruído", + "需要合成的文本": "Texto a ser sintetizado", + "需要合成的语种": "Idioma a ser sintetizado", + "韩文": "Coreano", + "韩英混合": "Mistura Coreano-Inglês", + "音频自动切分输入路径,可文件可文件夹": "Caminho de entrada automático de corte de áudio, pode ser um arquivo ou uma pasta", + "预训练的GPT模型路径": "Caminho do modelo GPT pre-train", + "预训练的SSL模型路径": "Caminho do modelo SSL pre-train", + "预训练的SoVITS-D模型路径": "Caminho do modelo SoVITS-D pre-train", + "预训练的SoVITS-G模型路径": "Caminho do modelo SoVITS-G pre-train", + "预训练的中文BERT模型路径": "Caminho do modelo BERT chinês pre-train" +} diff --git a/tools/i18n/locale/ru_RU.json b/tools/i18n/locale/ru_RU.json new file mode 100644 index 0000000000000000000000000000000000000000..ea4460647b21f93ddeba7212158b42288bf50ed5 --- /dev/null +++ b/tools/i18n/locale/ru_RU.json @@ -0,0 +1,178 @@ +{ + "(1)MDX-Net(onnx_dereverb):对于双通道混响是最好的选择,不能去除单通道混响;": "(1)MDX-Net(onnx_dereverb):Это лучший выбор для реверберации с двумя каналами, но он не может устранить реверберацию с одним каналом;", + "(234)DeEcho:去除延迟效果。Aggressive比Normal去除得更彻底,DeReverb额外去除混响,可去除单声道混响,但是对高频重的板式混响去不干净。": "(234)DeEcho:Устраняет эффект задержки. 
Aggressive устраняет более тщательно, чем Normal, DeReverb дополнительно устраняет реверберацию, может устранить реверберацию с одного канала, но не полностью устраняет высокочастотную реверберацию.", + "*GPT模型列表": "*Список моделей GPT", + "*SoVITS模型列表": "*Список моделей SoVITS", + "*实验/模型名": "*Название эксперимента/модели", + "*文本标注文件": "*Файл текстовой аннотации", + "*训练集音频文件目录": "*Директория аудиофайлов обучающего набора", + "*请上传并填写参考信息": "*Пожалуйста, загрузите и заполните референтные данные", + "*请填写需要合成的目标文本和语种模式": "*Пожалуйста, введите целевой текст для синтеза и режим языка", + ".list标注文件的路径": "Путь к файлу аннотации .list", + "0-前置数据集获取工具": "0-Инструмент для получения предварительного набора данных", + "0a-UVR5人声伴奏分离&去混响去延迟工具": "0a-Инструмент для разделения вокала и аккомпанемента UVR5 & устранения реверберации и задержек", + "0b-语音切分工具": "0b-Инструмент для разделения речи", + "0bb-语音降噪工具": "0bb-Инструмент для подавления шумов в голосе", + "0c-中文批量离线ASR工具": "0c-Инструмент для пакетной офлайн ASR на китайском", + "0d-语音文本校对标注工具": "0d-Инструмент для коррекции и аннотации текста речи", + "1-GPT-SoVITS-TTS": "1-GPT-SoVITS-TTS", + "1A-训练集格式化工具": "1A-Инструмент для форматирования обучающего набора", + "1Aa-文本内容": "1Aa-Содержание текста", + "1Aabc-训练集格式化一键三连": "1Aabc-Форматирование обучающего набора одним нажатием", + "1Ab-SSL自监督特征提取": "1Ab-Самоконтролируемое извлечение признаков SSL", + "1Ac-语义token提取": "1Ac-Извлечение семантических токенов", + "1B-微调训练": "1B-Дообучение", + "1Ba-SoVITS训练。用于分享的模型文件输出在SoVITS_weights下。": "1Ba-Обучение SoVITS. Файлы моделей для распространения находятся в SoVITS_weights.", + "1Bb-GPT训练。用于分享的模型文件输出在GPT_weights下。": "1Bb-Обучение GPT. Файлы моделей для распространения находятся в GPT_weights.", + "1C-推理": "1C-Инференс", + "1、DeEcho-DeReverb模型的耗时是另外2个DeEcho模型的接近2倍;": "1. Время обработки модели DeEcho-DeReverb почти вдвое больше, чем у двух других моделей DeEcho;", + "1、保留人声:不带和声的音频选这个,对主人声保留比HP5更好。内置HP2和HP3两个模型,HP3可能轻微漏伴奏但对主人声保留比HP2稍微好一丁点;": "1. Сохранение голоса: выберите этот для аудио без гармоний, сохранение голоса будет лучше, чем HP5. Встроенные модели HP2 и HP3, HP3 может немного пропускать сопровождение, но сохраняет голос немного лучше, чем HP2;", + "2-GPT-SoVITS-变声": "2-GPT-SoVITS-переозвучивание", + "2、MDX-Net-Dereverb模型挺慢的;": "2. Модель MDX-Net-Dereverb довольно медленная;", + "2、仅保留主人声:带和声的音频选这个,对主人声可能有削弱。内置HP5一个模型;": "2. Сохранение только основного голоса: выберите это для аудио с гармониями, может ослабить основной голос. Встроенная модель HP5;", + "3、个人推荐的最干净的配置是先MDX-Net再DeEcho-Aggressive。": "3. Лично рекомендованная самая чистая конфигурация — сначала MDX-Net, затем DeEcho-Aggressive.", + "3、去混响、去延迟模型(by FoxJoy):": "3. Модель удаления реверберации и задержек (от FoxJoy):", + "ASR 模型": "Модель ASR", + "ASR 模型尺寸": "Размер модели ASR", + "ASR 语言设置": "Настройки языка ASR", + "ASR进程输出信息": "Информация о процессе ASR", + "GPT模型列表": "Список моделей GPT", + "GPT训练进程输出信息": "Информация о процессе обучения GPT", + "GPT采样参数(无参考文本时不要太低。不懂就用默认):": "Параметры выборки GPT (не устанавливайте слишком низкие значения, если нет ссылочного текста. 
Используйте значения по умолчанию, если не уверены):", + "GPU卡号,只能填1个整数": "Номер GPU, можно указать только одно целое число", + "GPU卡号以-分割,每个卡号一个进程": "Номера GPU разделяются дефисом, на каждый номер отдельный процесс", + "SSL进程输出信息": "Информация о процессе SSL", + "SoVITS模型列表": "Список моделей SoVITS", + "SoVITS训练进程输出信息": "Информация о процессе обучения SoVITS", + "TTS推理WebUI进程输出信息": "Информация о процессе TTS инференса WebUI", + "TTS推理进程已关闭": "Процесс TTS-инференции остановлен", + "TTS推理进程已开启": "Процесс TTS-инференции запущен", + "UVR5已关闭": "UVR5 выключен", + "UVR5已开启": "UVR5 включен", + "UVR5进程输出信息": "Вывод информации процесса UVR5", + "alpha_mix:混多少比例归一化后音频进来": "alpha_mix:Какая доля нормализованного аудио смешивается", + "hop_size:怎么算音量曲线,越小精度越大计算量越高(不是精度越大效果越好)": "hop_size:Как рассчитывается кривая громкости, чем меньше, тем выше точность и больше вычислительная нагрузка (большая точность не всегда означает лучший результат)", + "max:归一化后最大值多少": "max:Максимальное значение после нормализации", + "max_sil_kept:切完后静音最多留多长": "max_sil_kept:Максимальная длительность тишины после разреза", + "min_interval:最短切割间隔": "min_interval:Минимальный интервал разреза", + "min_length:每段最小多长,如果第一段太短一直和后面段连起来直到超过这个值": "min_length:Минимальная длина каждого отрезка; если первый отрезок слишком короткий, он будет соединен с последующими до достижения этого значения", + "temperature": "temperature", + "threshold:音量小于这个值视作静音的备选切割点": "threshold:Значение громкости ниже этого считается тишиной для альтернативной точки разреза", + "top_k": "top_k", + "top_p": "top_p", + "一键三连进程输出信息": "Информация о процессе одного нажатия", + "不切": "Не разрезать", + "中文": "Китайский", + "中文教程文档:https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e": "Документация на китайском языке:https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e", + "中英混合": "Китайский и английский", + "也可批量输入音频文件, 二选一, 优先读文件夹": "Можно также импортировать несколько аудиофайлов. Если путь к папке существует, то этот ввод игнорируется.", + "人声伴奏分离批量处理, 使用UVR5模型。": "Обработка разделения вокала и аккомпанемента пакетно с использованием модели UVR5.", + "人声提取激进程度": "Степень агрессивности извлечения вокала", + "以下文件或文件夹不存在:": "Нет такого файла или папки:", + "以下模型不存在:": "Этот модель не существует", + "伴奏人声分离&去混响&去回声": "Разделение вокала/аккомпанемента и удаление эхо", + "使用无参考文本模式时建议使用微调的GPT,听不清参考音频说的啥(不晓得写啥)可以开。
开启后无视填写的参考文本。": "При использовании режима без референсного текста рекомендуется использовать настроенную модель GPT. Если не удается разобрать, что говорит референсное аудио (не знаете, что писать), можете включить этот режим, и он проигнорирует введенный референсный текст.", + "保存频率save_every_epoch": "Частота сохранения save_every_epoch", + "凑50字一切": "Соберите все в 50 символов", + "凑四句一切": "Собрать четыре предложения и разрезать", + "切分后的子音频的输出根目录": "Корневой каталог вывода для подаудио после разделения", + "切割使用的进程数": "Количество процессов, используемых для разрезания", + "刷新模型路径": "Обновить путь к модели", + "前端处理后的文本(每句):": "Текст после предварительной обработки (каждое предложение):", + "去混响/去延迟,附:": "Удаление реверберации/удаление задержки, примечание:", + "参考音频在3~10秒范围外,请更换!": "Референтное аудио вне диапазона 3~10 секунд, пожалуйста, замените!", + "参考音频的文本": "Текст референтного аудио", + "参考音频的语种": "Язык референтного аудио", + "合成语音": "Синтезированный голос", + "合格的文件夹路径格式举例: E:\\codes\\py39\\vits_vc_gpu\\白鹭霜华测试样例(去文件管理器地址栏拷就行了)。": "Пример допустимого формата пути к папке: E:\\codes\\py39\\vits_vc_gpu\\白鹭霜华测试样例 (просто скопируйте из адресной строки файлового менеджера).", + "填切割后音频所在目录!读取的音频文件完整路径=该目录-拼接-list文件里波形对应的文件名(不是全路径)。如果留空则使用.list文件里的绝对全路径。": "Заполните каталог, где находятся аудиофайлы после разрезания! Полный путь к читаемым аудиофайлам = каталог - файл .list, имя файла соответствует волне (не полный путь). Если оставить пустым, будет использоваться абсолютный путь из файла .list.", + "多语种混合": "Смешанные языки", + "多语种混合(粤语)": "Многоязычная смесь (кантонский)", + "实际输入的参考文本:": "Фактически введенный референсный текст:", + "实际输入的目标文本(切句后):": "Фактически введенный целевой текст (после разбиения на предложения):", + "实际输入的目标文本(每句):": "Фактически введенный целевой текст (каждое предложение):", + "实际输入的目标文本:": "Фактически введенный целевой текст:", + "导出文件格式": "Формат выходных файлов", + "开启GPT训练": "Включить обучение GPT", + "开启SSL提取": "Включить извлечение SSL", + "开启SoVITS训练": "Включить обучение SoVITS", + "开启一键三连": "Включить одно нажатие", + "开启文本获取": "Включить получение текста", + "开启无参考文本模式。不填参考文本亦相当于开启。": "Включить режим без референтного текста. 
Не заполняя референтный текст, вы также включаете этот режим.", + "开启离线批量ASR": "Включить пакетную офлайн ASR", + "开启语义token提取": "Включить извлечение семантических токенов", + "开启语音切割": "Включить разрезание речи", + "开启语音降噪": "Включить шумоподавление", + "怎么切": "Как разрезать", + "总训练轮数total_epoch": "Общее количество эпох обучения total_epoch", + "总训练轮数total_epoch,不建议太高": "Общее количество эпох обучения total_epoch, не рекомендуется слишком высокое", + "打标工具WebUI已关闭": "WebUI инструмента маркировки остановлен", + "打标工具WebUI已开启": "WebUI инструмента маркировки запущен", + "打标工具进程输出信息": "Информация о процессе аннотации", + "指定输出主人声文件夹": "Путь к папке для сохранения вокала:", + "指定输出非主人声文件夹": "Путь к папке для сохранения аккомпанемента:", + "按中文句号。切": "Разделение по китайским точкам.", + "按标点符号切": "Разрезать по пунктуационным знакам", + "按英文句号.切": "Разрезать по английской точке.", + "数据类型精度": "точность типа данных", + "文本模块学习率权重": "Веса скорости обучения текстового модуля", + "文本进程输出信息": "Информация о процессе обработки текста", + "施工中,请静候佳音": "В разработке, ожидайте хороших новостей", + "日文": "Японский", + "日英混合": "Японский и английский", + "是否仅保存最新的ckpt文件以节省硬盘空间": "Сохранять только последние файлы ckpt для экономии места на диске?", + "是否在每次保存时间点将最终小模型保存至weights文件夹": "Сохранять финальную версию модели в папке weights на каждом этапе сохранения?", + "是否开启TTS推理WebUI": "Включить TTS инференс WebUI", + "是否开启UVR5-WebUI": "Включить UVR5-WebUI", + "是否开启dpo训练选项(实验性)": "Включить опцию тренировки dpo (экспериментально)", + "是否开启打标WebUI": "Включить интерфейс веб-аннотации", + "是否直接对上次合成结果调整语速。防止随机性。": "Следует ли непосредственно регулировать скорость речи последнего синтезированного результата, чтобы избежать случайности?", + "显卡信息": "Информация о видеокарте", + "本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.
如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.": "Это программное обеспечение открыто по лицензии MIT, автор не имеет никакого контроля над программным обеспечением, пользователи программного обеспечения и те, кто распространяет звуки, экспортированные программным обеспечением, несут полную ответственность.
Если вы не согласны с этими условиями, вы не можете использовать или ссылаться на любой код и файлы в пакете программного обеспечения. Смотрите LICENSE в корневом каталоге.", + "模型": "Модели", + "模型分为三类:": "Модели делятся на три типа:", + "模型切换": "Переключение модели", + "每张显卡的batch_size": "Размер пакета для каждой видеокарты", + "版本": "Версия", + "粤英混合": "Кантоно-английская смесь", + "粤语": "Кантонийский", + "终止ASR进程": "Прекратить процесс ASR", + "终止GPT训练": "Прекратить обучение GPT", + "终止SSL提取进程": "Прекратить процесс извлечения SSL", + "终止SoVITS训练": "Прекратить обучение SoVITS", + "终止一键三连": "Прекратить одно нажатие", + "终止文本获取进程": "Прекратить процесс получения текста", + "终止语义token提取进程": "Прекратить процесс извлечения семантических токенов", + "终止语音切割": "Прекратить разрезание речи", + "终止语音降噪进程": "Прекратить процесс шумоподавления", + "英文": "Английский", + "语义token提取进程输出信息": "Информация о процессе извлечения семантических токенов", + "语速": "Скорость речи", + "语速调整,高为更快": "Регулировка скорости речи, чем выше, тем быстрее", + "语音切割进程输出信息": "Информация о процессе разрезания речи", + "语音降噪进程输出信息": "Информация о процессе шумоподавления", + "请上传3~10秒内参考音频,超过会报错!": "Пожалуйста, загрузите референтное аудио длительностью от 3 до 10 секунд, иначе будет ошибка!", + "请上传参考音频": "Пожалуйста, загрузите эталонное аудио", + "请填入推理文本": "Пожалуйста, введите целевой текст", + "请输入有效文本": "Введите действительный текст", + "转换": "Преобразовать", + "输入待处理音频文件夹路径": "Путь к папке с аудиофайлами для обработки:", + "输入文件夹路径": "Введите путь к папке", + "输出logs/实验名目录下应有23456开头的文件和文件夹": "В директории logs/имя_эксперимента должны быть файлы и папки, начинающиеся с 23456", + "输出信息": "Статистика", + "输出文件夹路径": "Путь к папке для вывода", + "输出的语音": "Выводимый звук", + "选择训练完存放在SoVITS_weights和GPT_weights下的模型。默认的一个是底模,体验5秒Zero Shot TTS用。": "Выберите модель, сохраненную в SoVITS_weights и GPT_weights после обучения. По умолчанию используется базовая модель для 5-секундного Zero Shot TTS.", + "降噪结果输出文件夹": "Папка для вывода результатов шумоподавления", + "降噪音频文件输入文件夹": "Папка для ввода аудиофайлов для шумоподавления", + "需要合成的文本": "Текст для синтеза", + "需要合成的语种": "Язык для синтеза", + "韩文": "Корейский", + "韩英混合": "Корейско-английская смесь", + "音频自动切分输入路径,可文件可文件夹": "Путь ввода для автоматического разделения аудио, может быть файлом или папкой", + "预训练的GPT模型路径": "Путь к предварительно обученной модели GPT", + "预训练的SSL模型路径": "Путь к предварительно обученной модели SSL", + "预训练的SoVITS-D模型路径": "Путь к предварительно обученной модели SoVITS-D", + "预训练的SoVITS-G模型路径": "Путь к предварительно обученной модели SoVITS-G", + "预训练的中文BERT模型路径": "Путь к предварительно обученной китайской модели BERT" +} diff --git a/tools/i18n/locale/tr_TR.json b/tools/i18n/locale/tr_TR.json new file mode 100644 index 0000000000000000000000000000000000000000..c9e0cc282fb5418af0a2dfd0fd2ac2ca5525fe63 --- /dev/null +++ b/tools/i18n/locale/tr_TR.json @@ -0,0 +1,178 @@ +{ + "(1)MDX-Net(onnx_dereverb):对于双通道混响是最好的选择,不能去除单通道混响;": "(1)MDX-Net(onnx_dereverb):İki kanallı yankılar için en iyi seçimdir, ancak tek kanallı yankıları ortadan kaldıramaz;", + "(234)DeEcho:去除延迟效果。Aggressive比Normal去除得更彻底,DeReverb额外去除混响,可去除单声道混响,但是对高频重的板式混响去不干净。": "(234)DeEcho:Gecikme etkilerini giderir. 
Aggressive, Normal'dan daha kapsamlı bir şekilde giderir, DeReverb ek olarak yankıyı giderir, tek kanallı yankıyı giderebilir, ancak yüksek frekanslı plaka yankısını tamamen gideremez.", + "*GPT模型列表": "*GPT model listesi", + "*SoVITS模型列表": "*SoVITS model listesi", + "*实验/模型名": "*Deney/model adı", + "*文本标注文件": "*Metin etiketleme dosyası", + "*训练集音频文件目录": "*Eğitim seti ses dosyası dizini", + "*请上传并填写参考信息": "*Lütfen referans bilgilerini yükleyin ve doldurun", + "*请填写需要合成的目标文本和语种模式": "*Lütfen sentezlenecek hedef metni ve dil modunu giriniz.", + ".list标注文件的路径": ".list etiketleme dosyasının yolu", + "0-前置数据集获取工具": "0-Ön veri seti alma aracı", + "0a-UVR5人声伴奏分离&去混响去延迟工具": "0a-UVR5 vokal eşlik ayırma & yankıyı giderme gecikme aracı", + "0b-语音切分工具": "0b-Ses bölme aracı", + "0bb-语音降噪工具": "0bb-Ses gürültü azaltma aracı", + "0c-中文批量离线ASR工具": "0c-Çince toplu offline ASR aracı", + "0d-语音文本校对标注工具": "0d-Ses ve metin düzeltme etiketleme aracı", + "1-GPT-SoVITS-TTS": "1-GPT-SoVITS-TTS", + "1A-训练集格式化工具": "1A-Eğitim seti formatlama aracı", + "1Aa-文本内容": "1Aa-Metin içeriği", + "1Aabc-训练集格式化一键三连": "1Aabc-Eğitim seti formatlama tek tuşla üçleme", + "1Ab-SSL自监督特征提取": "1Ab-SSL kendi kendine denetimli özellik çıkarma", + "1Ac-语义token提取": "1Ac-Anlamsal token çıkarma", + "1B-微调训练": "1B-Fine-tuning eğitimi", + "1Ba-SoVITS训练。用于分享的模型文件输出在SoVITS_weights下。": "1Ba-SoVITS eğitimi. Paylaşım için model dosyaları SoVITS_weights altında çıkarılır.", + "1Bb-GPT训练。用于分享的模型文件输出在GPT_weights下。": "1Bb-GPT eğitimi. Paylaşım için model dosyaları GPT_weights altında çıkarılır.", + "1C-推理": "1C-Çıkarım", + "1、DeEcho-DeReverb模型的耗时是另外2个DeEcho模型的接近2倍;": "1. DeEcho-DeReverb modelinin işleme süresi, diğer iki DeEcho modelinin neredeyse iki katıdır;", + "1、保留人声:不带和声的音频选这个,对主人声保留比HP5更好。内置HP2和HP3两个模型,HP3可能轻微漏伴奏但对主人声保留比HP2稍微好一丁点;": "1. Ses koruma: Arka vokal içermeyen sesler için bu seçeneği kullanın, ana sesi HP5'ten daha iyi korur. HP2 ve HP3 adlı iki model içerir; HP3, arka vokali biraz kaçırabilir ancak ana sesi HP2'ye göre biraz daha iyi korur;", + "2-GPT-SoVITS-变声": "2-GPT-SoVITS-Ses Değiştirme", + "2、MDX-Net-Dereverb模型挺慢的;": "2. MDX-Net-Dereverb modeli oldukça yavaştır;", + "2、仅保留主人声:带和声的音频选这个,对主人声可能有削弱。内置HP5一个模型;": "2. Sadece ana sesi koruma: Arka vokalleri içeren sesler için bu seçeneği kullanın, ana sesi zayıflatabilir. İçinde HP5 modeli var;", + "3、个人推荐的最干净的配置是先MDX-Net再DeEcho-Aggressive。": "3. Kişisel olarak en temiz konfigürasyon MDX-Net'in ardından DeEcho-Aggressive'dir.", + "3、去混响、去延迟模型(by FoxJoy):": "3. Yankı ve gecikme giderme modeli (FoxJoy tarafından):", + "ASR 模型": "ASR modeli", + "ASR 模型尺寸": "ASR model boyutu", + "ASR 语言设置": "ASR dil ayarları", + "ASR进程输出信息": "ASR işlemi çıktı bilgisi", + "GPT模型列表": "GPT model listesi", + "GPT训练进程输出信息": "GPT eğitimi işlemi çıktı bilgisi", + "GPT采样参数(无参考文本时不要太低。不懂就用默认):": "GPT örnekleme parametreleri (referans metin olmadığında çok düşük olmamalıdır. 
Emin değilseniz varsayılanı kullanın):", + "GPU卡号,只能填1个整数": "GPU kart numarası, sadece bir tamsayı girilebilir", + "GPU卡号以-分割,每个卡号一个进程": "GPU kart numaraları - ile ayrılır, her kart numarası için bir işlem", + "SSL进程输出信息": "SSL işlemi çıktı bilgisi", + "SoVITS模型列表": "SoVITS model listesi", + "SoVITS训练进程输出信息": "SoVITS eğitimi işlemi çıktı bilgisi", + "TTS推理WebUI进程输出信息": "TTS çıkarımı WebUI işlemi çıktı bilgisi", + "TTS推理进程已关闭": "TTS çıkarım işlemi kapatıldı", + "TTS推理进程已开启": "TTS çıkarım işlemi başlatıldı", + "UVR5已关闭": "UVR5 kapandı", + "UVR5已开启": "UVR5 açıldı", + "UVR5进程输出信息": "UVR5 işlem çıktı bilgisi", + "alpha_mix:混多少比例归一化后音频进来": "alpha_mix:Normalizasyondan sonraki sesin ne kadarlık bir oranı karıştırılsın", + "hop_size:怎么算音量曲线,越小精度越大计算量越高(不是精度越大效果越好)": "hop_size:Ses seviyesi eğrisi nasıl hesaplanır, ne kadar küçükse hassasiyet o kadar yüksek ve hesaplama yükü o kadar artar (hassasiyet arttıkça etki mutlaka daha iyi olmaz)", + "max:归一化后最大值多少": "max:Normalizasyondan sonra maksimum değer ne kadar", + "max_sil_kept:切完后静音最多留多长": "max_sil_kept:Kesimden sonra en fazla ne kadar sessizlik bırakılır", + "min_interval:最短切割间隔": "min_interval:Minimum kesim aralığı", + "min_length:每段最小多长,如果第一段太短一直和后面段连起来直到超过这个值": "min_length: bölümün minimum uzunluğu, ilk bölüm çok kısa ise, bu değeri aşana kadar sonraki bölümlerle birleştirilir", + "temperature": "temperature", + "threshold:音量小于这个值视作静音的备选切割点": "threshold:Ses bu değerden düşükse sessiz olarak kabul edilen alternatif kesim noktası", + "top_k": "top_k", + "top_p": "top_p", + "一键三连进程输出信息": "Tek tuşla üçleme işlemi çıktı bilgisi", + "不切": "Kesme", + "中文": "Çince", + "中文教程文档:https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e": "Çince öğretici belge:https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e", + "中英混合": "Çince ve İngilizce karışık", + "也可批量输入音频文件, 二选一, 优先读文件夹": "Ses dosyaları ayrıca toplu olarak, iki seçimle, öncelikli okuma klasörüyle içe aktarılabilir", + "人声伴奏分离批量处理, 使用UVR5模型。": "Vokal ve akor ayırma toplu işleme, UVR5 modelini kullanarak.", + "人声提取激进程度": "Vokal çıkarma agresiflik derecesi", + "以下文件或文件夹不存在:": "Böyle Bir Dosya veya Klasör Yok:", + "以下模型不存在:": "Böyle bir model yok:", + "伴奏人声分离&去混响&去回声": "Vokal/Müzik Ayrıştırma ve Yankı Giderme", + "使用无参考文本模式时建议使用微调的GPT,听不清参考音频说的啥(不晓得写啥)可以开。
开启后无视填写的参考文本。": "Referans metin modu olmadan kullanıldığında, referans sesi net duyulmadığında (ne yazılacağı bilinmiyorsa) açık bırakılması önerilir, bu durumda girilen referans metni göz ardı edilir.", + "保存频率save_every_epoch": "Kayıt sıklığı save_every_epoch", + "凑50字一切": "50 kelime birleştir ve kes", + "凑四句一切": "Dört cümleyi bir araya getirip kes", + "切分后的子音频的输出根目录": "Bölündükten sonra alt ses dosyalarının çıktı kök dizini", + "切割使用的进程数": "Kesim için kullanılan işlem sayısı", + "刷新模型路径": "Model yolu yenile", + "前端处理后的文本(每句):": "Ön işleme tabi tutulan metin (her cümle):", + "去混响/去延迟,附:": "Yankı giderme/Geçikme giderme, ek:", + "参考音频在3~10秒范围外,请更换!": "Referans ses dosyası 3~10 saniye aralığının dışında, lütfen değiştirin!", + "参考音频的文本": "Referans ses dosyasının metni", + "参考音频的语种": "Referans ses dosyasının dili", + "合成语音": "Ses sentezi", + "合格的文件夹路径格式举例: E:\\codes\\py39\\vits_vc_gpu\\白鹭霜华测试样例(去文件管理器地址栏拷就行了)。": "Geçerli klasör yolu formatı örneği: E:\\codes\\py39\\vits_vc_gpu\\白鹭霜华测试样例 (dosya yöneticisi adres çubuğundan kopyalayabilirsiniz).", + "填切割后音频所在目录!读取的音频文件完整路径=该目录-拼接-list文件里波形对应的文件名(不是全路径)。如果留空则使用.list文件里的绝对全路径。": "Kesmeye uygun ses dosyalarının bulunduğu dizini doldurun! Okunan ses dosyasının tam yolu = bu dizin + list dosyasındaki dalga biçimiyle eşleşen dosya adı (tam yol değil). Boş bırakılırsa, .list dosyasındaki tam yol kullanılır.", + "多语种混合": "Çok dilli karışım", + "多语种混合(粤语)": "Çok dilli karışık (Yue)", + "实际输入的参考文本:": "Gerçekten girilen referans metin:", + "实际输入的目标文本(切句后):": "Gerçekten girilen hedef metin (cümleler kesildikten sonra):", + "实际输入的目标文本(每句):": "Gerçekten girilen hedef metin (her cümle):", + "实际输入的目标文本:": "Gerçekten girilen hedef metin:", + "导出文件格式": "Dışa aktarma dosya formatı", + "开启GPT训练": "GPT eğitimini başlat", + "开启SSL提取": "SSL çıkarmayı başlat", + "开启SoVITS训练": "SoVITS eğitimini başlat", + "开启一键三连": "Tek tuşla üçlemeyi başlat", + "开启文本获取": "Metin alma başlat", + "开启无参考文本模式。不填参考文本亦相当于开启。": "Referans metni olmayan mod açık. Referans metni doldurulmazsa bu mod otomatik olarak açılır.", + "开启离线批量ASR": "Offline toplu ASR başlat", + "开启语义token提取": "Anlamsal token çıkarmayı başlat", + "开启语音切割": "Ses kesimi başlat", + "开启语音降噪": "Ses gürültü azaltmayı başlat", + "怎么切": "Nasıl kesilir", + "总训练轮数total_epoch": "Toplam eğitim turu sayısı total_epoch", + "总训练轮数total_epoch,不建议太高": "Toplam eğitim turu sayısı total_epoch, çok yüksek önerilmez", + "打标工具WebUI已关闭": "Etiketleme aracı WebUI'si kapatıldı", + "打标工具WebUI已开启": "Etiketleme aracı WebUI'si açıldı", + "打标工具进程输出信息": "Etiketleme aracı işlemi çıktı bilgisi", + "指定输出主人声文件夹": "Vokal için çıkış klasörünü belirtin:", + "指定输出非主人声文件夹": "Müzik ve diğer sesler için çıkış klasörünü belirtin:", + "按中文句号。切": "Çince dönem işaretine göre kes", + "按标点符号切": "Noktalama işaretlerine göre kes", + "按英文句号.切": "İngilizce nokta işaretine göre kes", + "数据类型精度": "veri türü doğruluğu", + "文本模块学习率权重": "Metin modülü öğrenme oranı ağırlığı", + "文本进程输出信息": "Metin işlemi çıktı bilgisi", + "施工中,请静候佳音": "Yapım aşamasında, lütfen iyi haberler için bekleyin", + "日文": "Japonca", + "日英混合": "Japonca ve İngilizce karışımı", + "是否仅保存最新的ckpt文件以节省硬盘空间": "Sadece en yeni ckpt dosyasını kaydederek disk alanından tasarruf edilsin mi", + "是否在每次保存时间点将最终小模型保存至weights文件夹": "Her kayıt zamanında son küçük modelin weights klasörüne kaydedilmesi gerekiyor mu", + "是否开启TTS推理WebUI": "TTS çıkarımı WebUI'si başlatılsın mı", + "是否开启UVR5-WebUI": "UVR5-WebUI açılsın mı", + "是否开启dpo训练选项(实验性)": "dpo eğitim seçeneği açılsın mı? 
(deneysel)", + "是否开启打标WebUI": "Etiketleme WebUI'si başlatılsın mı", + "是否直接对上次合成结果调整语速。防止随机性。": "Son sentez sonucunun konuşma hızını doğrudan ayarlamak, rastlantısallığı önlemek için mi?", + "显卡信息": "Ekran kartı bilgisi", + "本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.
如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.": "Bu yazılım MIT lisansı ile açık kaynaktır, yazar yazılım üzerinde herhangi bir kontrol gücüne sahip değildir, yazılımı kullanıcılar ve yazılım tarafından üretilen sesleri yayınlayanlar tüm sorumluluğu üstlenir.
Eğer bu şartları kabul etmiyorsanız, yazılım paketindeki hiçbir kodu veya dosyayı kullanamaz veya atıfta bulunamazsınız. Ayrıntılar için ana dizindeki LICENSE'ı görün.", + "模型": "Model", + "模型分为三类:": "Modeller üç türdedir:", + "模型切换": "Model değiştirme", + "每张显卡的batch_size": "Her bir ekran kartı için batch_size", + "版本": "Versiyon", + "粤英混合": "Yue-İngilizce Karışık", + "粤语": "Yue", + "终止ASR进程": "ASR işlemini durdur", + "终止GPT训练": "GPT eğitimini durdur", + "终止SSL提取进程": "SSL çıkarma işlemini durdur", + "终止SoVITS训练": "SoVITS eğitimini durdur", + "终止一键三连": "Tek tuşla üçlemeyi durdur", + "终止文本获取进程": "Metin alma işlemini durdur", + "终止语义token提取进程": "Anlamsal token çıkarma işlemini durdur", + "终止语音切割": "Ses kesimini durdur", + "终止语音降噪进程": "Gürültü azaltma işlemini durdur", + "英文": "İngilizce", + "语义token提取进程输出信息": "Anlamsal token çıkarma işlemi çıktı bilgisi", + "语速": "Konuşma hızı", + "语速调整,高为更快": "Konuşma hızını ayarla, yüksek daha hızlı", + "语音切割进程输出信息": "Ses kesim işlemi çıktı bilgisi", + "语音降噪进程输出信息": "Gürültü azaltma işlemi çıktı bilgisi", + "请上传3~10秒内参考音频,超过会报错!": "Lütfen 3~10 saniye arasında bir referans ses dosyası yükleyin, aşım durumunda hata verilecektir!", + "请上传参考音频": "Lütfen Referans Sesi Yükleyin", + "请填入推理文本": "Lütfen Hedef Metni Girin", + "请输入有效文本": "Geçerli metin girin", + "转换": "Dönüştür", + "输入待处理音频文件夹路径": "İşlenecek ses klasörünün yolunu girin:", + "输入文件夹路径": "Dosya klasörü yolu girin", + "输出logs/实验名目录下应有23456开头的文件和文件夹": "Çıktı logs/deney adı dizininde 23456 ile başlayan dosya ve klasörler olmalı", + "输出信息": "Çıkış bilgisi", + "输出文件夹路径": "Çıktı klasörü yolu", + "输出的语音": "Çıktı sesi", + "选择训练完存放在SoVITS_weights和GPT_weights下的模型。默认的一个是底模,体验5秒Zero Shot TTS用。": "Eğitimi tamamlanmış ve SoVITS_weights ile GPT_weights altına kaydedilmiş modeli seçin. 
Varsayılan bir temel modeldir, 5 saniyelik Zero Shot TTS deneyimi için kullanılır.", + "降噪结果输出文件夹": "Gürültü azaltma sonuçları çıktı klasörü", + "降噪音频文件输入文件夹": "Gürültü azaltma ses dosyaları giriş klasörü", + "需要合成的文本": "Sentezlenmesi gereken metin", + "需要合成的语种": "Sentezlenmesi gereken dil", + "韩文": "Korece", + "韩英混合": "Korece-İngilizce Karışık", + "音频自动切分输入路径,可文件可文件夹": "Ses otomatik bölme giriş yolu, dosya veya klasör olabilir", + "预训练的GPT模型路径": "Ön eğitilmiş GPT model yolu", + "预训练的SSL模型路径": "Ön eğitilmiş SSL model yolu", + "预训练的SoVITS-D模型路径": "Ön eğitilmiş SoVITS-D model yolu", + "预训练的SoVITS-G模型路径": "Ön eğitilmiş SoVITS-G model yolu", + "预训练的中文BERT模型路径": "Ön eğitilmiş Çince BERT model yolu" +} diff --git a/tools/i18n/locale/zh_CN.json b/tools/i18n/locale/zh_CN.json new file mode 100644 index 0000000000000000000000000000000000000000..49eccf9c720dfac6d65ab01aa845570b33d624e5 --- /dev/null +++ b/tools/i18n/locale/zh_CN.json @@ -0,0 +1,178 @@ +{ + "(1)MDX-Net(onnx_dereverb):对于双通道混响是最好的选择,不能去除单通道混响;": "(1)MDX-Net(onnx_dereverb):对于双通道混响是最好的选择,不能去除单通道混响;", + "(234)DeEcho:去除延迟效果。Aggressive比Normal去除得更彻底,DeReverb额外去除混响,可去除单声道混响,但是对高频重的板式混响去不干净。": "(234)DeEcho:去除延迟效果。Aggressive 比 Normal 去除得更彻底,DeReverb 额外去除混响,可去除单声道混响,但是对高频重的板式混响去不干净。", + "*GPT模型列表": "*GPT模型列表", + "*SoVITS模型列表": "*SoVITS模型列表", + "*实验/模型名": "*实验/模型名", + "*文本标注文件": "*文本标注文件", + "*训练集音频文件目录": "*训练集音频文件目录", + "*请上传并填写参考信息": "*请上传并填写参考信息", + "*请填写需要合成的目标文本和语种模式": "*请填写需要合成的目标文本和语种模式", + ".list标注文件的路径": ".list标注文件的路径", + "0-前置数据集获取工具": "0-前置数据集获取工具", + "0a-UVR5人声伴奏分离&去混响去延迟工具": "0a-UVR5人声伴奏分离&去混响去延迟工具", + "0b-语音切分工具": "0b-语音切分工具", + "0bb-语音降噪工具": "0bb-语音降噪工具", + "0c-中文批量离线ASR工具": "0c-中文批量离线ASR工具", + "0d-语音文本校对标注工具": "0d-语音文本校对标注工具", + "1-GPT-SoVITS-TTS": "1-GPT-SoVITS-TTS", + "1A-训练集格式化工具": "1A-训练集格式化工具", + "1Aa-文本内容": "1Aa-文本内容", + "1Aabc-训练集格式化一键三连": "1Aabc-训练集格式化一键三连", + "1Ab-SSL自监督特征提取": "1Ab-SSL自监督特征提取", + "1Ac-语义token提取": "1Ac-语义token提取", + "1B-微调训练": "1B-微调训练", + "1Ba-SoVITS训练。用于分享的模型文件输出在SoVITS_weights下。": "1Ba-SoVITS训练。用于分享的模型文件输出在SoVITS_weights下。", + "1Bb-GPT训练。用于分享的模型文件输出在GPT_weights下。": "1Bb-GPT训练。用于分享的模型文件输出在GPT_weights下。", + "1C-推理": "1C-推理", + "1、DeEcho-DeReverb模型的耗时是另外2个DeEcho模型的接近2倍;": "1、DeEcho-DeReverb模型的耗时是另外2个DeEcho模型的接近2倍;", + "1、保留人声:不带和声的音频选这个,对主人声保留比HP5更好。内置HP2和HP3两个模型,HP3可能轻微漏伴奏但对主人声保留比HP2稍微好一丁点;": "1、保留人声:不带和声的音频选这个,对主人声保留比HP5更好。内置HP2和HP3两个模型,HP3可能轻微漏伴奏但对主人声保留比HP2稍微好一丁点;", + "2-GPT-SoVITS-变声": "2-GPT-SoVITS-变声", + "2、MDX-Net-Dereverb模型挺慢的;": "2、MDX-Net-Dereverb模型挺慢的;", + "2、仅保留主人声:带和声的音频选这个,对主人声可能有削弱。内置HP5一个模型;": "2、仅保留主人声:带和声的音频选这个,对主人声可能有削弱。内置HP5一个模型;", + "3、个人推荐的最干净的配置是先MDX-Net再DeEcho-Aggressive。": "3、个人推荐的最干净的配置是先MDX-Net再DeEcho-Aggressive。", + "3、去混响、去延迟模型(by FoxJoy):": "3、去混响、去延迟模型(by FoxJoy):", + "ASR 模型": "ASR 模型", + "ASR 模型尺寸": "ASR 模型尺寸", + "ASR 语言设置": "ASR 语言设置", + "ASR进程输出信息": "ASR进程输出信息", + "GPT模型列表": "GPT模型列表", + "GPT训练进程输出信息": "GPT训练进程输出信息", + "GPT采样参数(无参考文本时不要太低。不懂就用默认):": "GPT采样参数(无参考文本时不要太低。不懂就用默认):", + "GPU卡号,只能填1个整数": "GPU卡号,只能填1个整数", + "GPU卡号以-分割,每个卡号一个进程": "GPU卡号以-分割,每个卡号一个进程", + "SSL进程输出信息": "SSL进程输出信息", + "SoVITS模型列表": "SoVITS模型列表", + "SoVITS训练进程输出信息": "SoVITS训练进程输出信息", + "TTS推理WebUI进程输出信息": "TTS推理WebUI进程输出信息", + "TTS推理进程已关闭": "TTS推理进程已关闭", + "TTS推理进程已开启": "TTS推理进程已开启", + "UVR5已关闭": "UVR5已关闭", + "UVR5已开启": "UVR5已开启", + "UVR5进程输出信息": "UVR5进程输出信息", + "alpha_mix:混多少比例归一化后音频进来": "alpha_mix:混多少比例归一化后音频进来", + "hop_size:怎么算音量曲线,越小精度越大计算量越高(不是精度越大效果越好)": "hop_size:怎么算音量曲线,越小精度越大计算量越高(不是精度越大效果越好)", + "max:归一化后最大值多少": "max:归一化后最大值多少", + "max_sil_kept:切完后静音最多留多长": "max_sil_kept:切完后静音最多留多长", + 
"min_interval:最短切割间隔": "min_interval:最短切割间隔", + "min_length:每段最小多长,如果第一段太短一直和后面段连起来直到超过这个值": "min_length:每段最小多长,如果第一段太短一直和后面段连起来直到超过这个值", + "temperature": "temperature", + "threshold:音量小于这个值视作静音的备选切割点": "threshold:音量小于这个值视作静音的备选切割点", + "top_k": "top_k", + "top_p": "top_p", + "一键三连进程输出信息": "一键三连进程输出信息", + "不切": "不切", + "中文": "中文", + "中文教程文档:https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e": "中文教程文档:https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e", + "中英混合": "中英混合", + "也可批量输入音频文件, 二选一, 优先读文件夹": "也可批量输入音频文件, 二选一, 优先读文件夹", + "人声伴奏分离批量处理, 使用UVR5模型。": "人声伴奏分离批量处理, 使用UVR5模型。", + "人声提取激进程度": "人声提取激进程度", + "以下文件或文件夹不存在:": "以下文件或文件夹不存在:", + "以下模型不存在:": "以下模型不存在:", + "伴奏人声分离&去混响&去回声": "伴奏人声分离&去混响&去回声", + "使用无参考文本模式时建议使用微调的GPT,听不清参考音频说的啥(不晓得写啥)可以开。
开启后无视填写的参考文本。": "使用无参考文本模式时建议使用微调的GPT,听不清参考音频说的啥(不晓得写啥)可以开。
开启后无视填写的参考文本。", + "保存频率save_every_epoch": "保存频率save_every_epoch", + "凑50字一切": "凑50字一切", + "凑四句一切": "凑四句一切", + "切分后的子音频的输出根目录": "切分后的子音频的输出根目录", + "切割使用的进程数": "切割使用的进程数", + "刷新模型路径": "刷新模型路径", + "前端处理后的文本(每句):": "前端处理后的文本(每句):", + "去混响/去延迟,附:": "去混响/去延迟,附:", + "参考音频在3~10秒范围外,请更换!": "参考音频在3~10秒范围外,请更换!", + "参考音频的文本": "参考音频的文本", + "参考音频的语种": "参考音频的语种", + "合成语音": "合成语音", + "合格的文件夹路径格式举例: E:\\codes\\py39\\vits_vc_gpu\\白鹭霜华测试样例(去文件管理器地址栏拷就行了)。": "合格的文件夹路径格式举例: E:\\codes\\py39\\vits_vc_gpu\\白鹭霜华测试样例(去文件管理器地址栏拷就行了)。", + "填切割后音频所在目录!读取的音频文件完整路径=该目录-拼接-list文件里波形对应的文件名(不是全路径)。如果留空则使用.list文件里的绝对全路径。": "填切割后音频所在目录!读取的音频文件完整路径=该目录-拼接-list文件里波形对应的文件名(不是全路径)。如果留空则使用.list文件里的绝对全路径。", + "多语种混合": "多语种混合", + "多语种混合(粤语)": "多语种混合(粤语)", + "实际输入的参考文本:": "实际输入的参考文本:", + "实际输入的目标文本(切句后):": "实际输入的目标文本(切句后):", + "实际输入的目标文本(每句):": "实际输入的目标文本(每句):", + "实际输入的目标文本:": "实际输入的目标文本:", + "导出文件格式": "导出文件格式", + "开启GPT训练": "开启GPT训练", + "开启SSL提取": "开启SSL提取", + "开启SoVITS训练": "开启SoVITS训练", + "开启一键三连": "开启一键三连", + "开启文本获取": "开启文本获取", + "开启无参考文本模式。不填参考文本亦相当于开启。": "开启无参考文本模式。不填参考文本亦相当于开启。", + "开启离线批量ASR": "开启离线批量ASR", + "开启语义token提取": "开启语义token提取", + "开启语音切割": "开启语音切割", + "开启语音降噪": "开启语音降噪", + "怎么切": "怎么切", + "总训练轮数total_epoch": "总训练轮数total_epoch", + "总训练轮数total_epoch,不建议太高": "总训练轮数total_epoch,不建议太高", + "打标工具WebUI已关闭": "打标工具WebUI已关闭", + "打标工具WebUI已开启": "打标工具WebUI已开启", + "打标工具进程输出信息": "打标工具进程输出信息", + "指定输出主人声文件夹": "指定输出主人声文件夹", + "指定输出非主人声文件夹": "指定输出非主人声文件夹", + "按中文句号。切": "按中文句号。切", + "按标点符号切": "按标点符号切", + "按英文句号.切": "按英文句号.切", + "数据类型精度": "数据类型精度", + "文本模块学习率权重": "文本模块学习率权重", + "文本进程输出信息": "文本进程输出信息", + "施工中,请静候佳音": "施工中,请静候佳音", + "日文": "日文", + "日英混合": "日英混合", + "是否仅保存最新的ckpt文件以节省硬盘空间": "是否仅保存最新的ckpt文件以节省硬盘空间", + "是否在每次保存时间点将最终小模型保存至weights文件夹": "是否在每次保存时间点将最终小模型保存至weights文件夹", + "是否开启TTS推理WebUI": "是否开启TTS推理WebUI", + "是否开启UVR5-WebUI": "是否开启UVR5-WebUI", + "是否开启dpo训练选项(实验性)": "是否开启dpo训练选项(实验性)", + "是否开启打标WebUI": "是否开启打标WebUI", + "是否直接对上次合成结果调整语速。防止随机性。": "是否直接对上次合成结果调整语速。防止随机性。", + "显卡信息": "显卡信息", + "本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.
如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.": "本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.
如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.", + "模型": "模型", + "模型分为三类:": "模型分为三类:", + "模型切换": "模型切换", + "每张显卡的batch_size": "每张显卡的batch_size", + "版本": "版本", + "粤英混合": "粤英混合", + "粤语": "粤语", + "终止ASR进程": "终止ASR进程", + "终止GPT训练": "终止GPT训练", + "终止SSL提取进程": "终止SSL提取进程", + "终止SoVITS训练": "终止SoVITS训练", + "终止一键三连": "终止一键三连", + "终止文本获取进程": "终止文本获取进程", + "终止语义token提取进程": "终止语义token提取进程", + "终止语音切割": "终止语音切割", + "终止语音降噪进程": "终止语音降噪进程", + "英文": "英文", + "语义token提取进程输出信息": "语义token提取进程输出信息", + "语速": "语速", + "语速调整,高为更快": "语速调整,高为更快", + "语音切割进程输出信息": "语音切割进程输出信息", + "语音降噪进程输出信息": "语音降噪进程输出信息", + "请上传3~10秒内参考音频,超过会报错!": "请上传3~10秒内参考音频,超过会报错!", + "请上传参考音频": "请上传参考音频", + "请填入推理文本": "请填入推理文本", + "请输入有效文本": "请输入有效文本", + "转换": "转换", + "输入待处理音频文件夹路径": "输入待处理音频文件夹路径", + "输入文件夹路径": "输入文件夹路径", + "输出logs/实验名目录下应有23456开头的文件和文件夹": "输出logs/实验名目录下应有23456开头的文件和文件夹", + "输出信息": "输出信息", + "输出文件夹路径": "输出文件夹路径", + "输出的语音": "输出的语音", + "选择训练完存放在SoVITS_weights和GPT_weights下的模型。默认的一个是底模,体验5秒Zero Shot TTS用。": "选择训练完存放在SoVITS_weights和GPT_weights下的模型。默认的一个是底模,体验5秒Zero Shot TTS用。", + "降噪结果输出文件夹": "降噪结果输出文件夹", + "降噪音频文件输入文件夹": "降噪音频文件输入文件夹", + "需要合成的文本": "需要合成的文本", + "需要合成的语种": "需要合成的语种", + "韩文": "韩文", + "韩英混合": "韩英混合", + "音频自动切分输入路径,可文件可文件夹": "音频自动切分输入路径,可文件可文件夹", + "预训练的GPT模型路径": "预训练的GPT模型路径", + "预训练的SSL模型路径": "预训练的SSL模型路径", + "预训练的SoVITS-D模型路径": "预训练的SoVITS-D模型路径", + "预训练的SoVITS-G模型路径": "预训练的SoVITS-G模型路径", + "预训练的中文BERT模型路径": "预训练的中文BERT模型路径" +} diff --git a/tools/i18n/locale/zh_HK.json b/tools/i18n/locale/zh_HK.json new file mode 100644 index 0000000000000000000000000000000000000000..c70f6eed163385b86c49c15290d9d6be3cbf3235 --- /dev/null +++ b/tools/i18n/locale/zh_HK.json @@ -0,0 +1,178 @@ +{ + "(1)MDX-Net(onnx_dereverb):对于双通道混响是最好的选择,不能去除单通道混响;": "(1)MDX-Net(onnx_dereverb):對於雙通道混響是最佳選擇,但不能去除單通道混響;", + "(234)DeEcho:去除延迟效果。Aggressive比Normal去除得更彻底,DeReverb额外去除混响,可去除单声道混响,但是对高频重的板式混响去不干净。": "(234)DeEcho: 去除延遲效果。Aggressive 比 Normal 去除得更徹底,DeReverb 額外去除混響,可去除單聲道混響,但對高頻重的板式混響去不乾淨。", + "*GPT模型列表": "*GPT模型列表", + "*SoVITS模型列表": "*SoVITS模型列表", + "*实验/模型名": "*實驗/模型名", + "*文本标注文件": "*文本標注文件", + "*训练集音频文件目录": "*訓練集音頻文件目錄", + "*请上传并填写参考信息": "*請上傳並填寫參考信息", + "*请填写需要合成的目标文本和语种模式": "請填寫需要合成的目標文本和語言模式", + ".list标注文件的路径": ".list標註文件的路徑", + "0-前置数据集获取工具": "0-前置數據集獲取工具", + "0a-UVR5人声伴奏分离&去混响去延迟工具": "0a-UVR5人聲伴奏分離&去混響去延遲工具", + "0b-语音切分工具": "0b-語音切分工具", + "0bb-语音降噪工具": "0bb-語音降噪工具", + "0c-中文批量离线ASR工具": "0c-中文批量離線ASR工具", + "0d-语音文本校对标注工具": "0d-語音文本校對標注工具", + "1-GPT-SoVITS-TTS": "1-GPT-SoVITS-TTS", + "1A-训练集格式化工具": "1A-訓練集格式化工具", + "1Aa-文本内容": "1Aa-文本內容", + "1Aabc-训练集格式化一键三连": "1Aabc-訓練集格式化一鍵三連", + "1Ab-SSL自监督特征提取": "1Ab-SSL自監督特徵提取", + "1Ac-语义token提取": "1Ac-語義token提取", + "1B-微调训练": "1B-微調訓練", + "1Ba-SoVITS训练。用于分享的模型文件输出在SoVITS_weights下。": "1Ba-SoVITS訓練。用於分享的模型文件輸出在SoVITS_weights下。", + "1Bb-GPT训练。用于分享的模型文件输出在GPT_weights下。": "1Bb-GPT訓練。用於分享的模型文件輸出在GPT_weights下。", + "1C-推理": "1C-推理", + "1、DeEcho-DeReverb模型的耗时是另外2个DeEcho模型的接近2倍;": "1、DeEcho-DeReverb 模型的處理時間是另外兩個 DeEcho 模型的接近兩倍;", + "1、保留人声:不带和声的音频选这个,对主人声保留比HP5更好。内置HP2和HP3两个模型,HP3可能轻微漏伴奏但对主人声保留比HP2稍微好一丁点;": "1、保留人聲:不帶和聲的音頻選這個,對主人聲保留比HP5更好。內置HP2和HP3兩個模型,HP3可能輕微漏伴奏但對主人聲保留比HP2稍微好一點點;", + "2-GPT-SoVITS-变声": "2-GPT-SoVITS-變聲", + "2、MDX-Net-Dereverb模型挺慢的;": "2、MDX-Net-Dereverb 模型的處理時間挺慢的;", + "2、仅保留主人声:带和声的音频选这个,对主人声可能有削弱。内置HP5一个模型;": "2、僅保留主人聲:帶和聲的音頻選這個,對主人聲可能有削弱。內置HP5一個模型;", + "3、个人推荐的最干净的配置是先MDX-Net再DeEcho-Aggressive。": "3、個人推薦的最乾淨的配置是先 MDX-Net 再 DeEcho-Aggressive。", + "3、去混响、去延迟模型(by FoxJoy):": "3、去混響、去延遲模型(by FoxJoy):", + "ASR 模型": "ASR 模型", + "ASR 模型尺寸": "ASR 模型尺寸", + "ASR 语言设置": "ASR 語言設置", + "ASR进程输出信息": 
"ASR進程輸出信息", + "GPT模型列表": "GPT模型列表", + "GPT训练进程输出信息": "GPT訓練進程輸出信息", + "GPT采样参数(无参考文本时不要太低。不懂就用默认):": "GPT 采样参数(无参考文本时不要太低。不懂就用默认):", + "GPU卡号,只能填1个整数": "GPU卡號,只能填1個整數", + "GPU卡号以-分割,每个卡号一个进程": "GPU卡號以-分割,每個卡號一個進程", + "SSL进程输出信息": "SSL進程輸出信息", + "SoVITS模型列表": "SoVITS模型列表", + "SoVITS训练进程输出信息": "SoVITS訓練進程輸出信息", + "TTS推理WebUI进程输出信息": "TTS推理WebUI進程輸出信息", + "TTS推理进程已关闭": "TTS推理進程已關閉", + "TTS推理进程已开启": "TTS推理進程已開啟", + "UVR5已关闭": "UVR5已關閉", + "UVR5已开启": "UVR5已開啟", + "UVR5进程输出信息": "UVR5進程輸出信息", + "alpha_mix:混多少比例归一化后音频进来": "alpha_mix:混多少比例歸一化後音頻進來", + "hop_size:怎么算音量曲线,越小精度越大计算量越高(不是精度越大效果越好)": "hop_size:怎麼算音量曲線,越小精度越大計算量越高(不是精度越大效果越好)", + "max:归一化后最大值多少": "max:歸一化後最大值多少", + "max_sil_kept:切完后静音最多留多长": "max_sil_kept:切完後靜音最多留多長", + "min_interval:最短切割间隔": "min_interval:最短切割間隔", + "min_length:每段最小多长,如果第一段太短一直和后面段连起来直到超过这个值": "min_length:每段最小多長,如果第一段太短一直和後面段連起來直到超過這個值", + "temperature": "temperature", + "threshold:音量小于这个值视作静音的备选切割点": "threshold:音量小於這個值視作靜音的備選切割點", + "top_k": "top_k", + "top_p": "top_p", + "一键三连进程输出信息": "一鍵三連進程輸出信息", + "不切": "不切", + "中文": "中文", + "中文教程文档:https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e": "中文教程文檔:https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e", + "中英混合": "中英混合", + "也可批量输入音频文件, 二选一, 优先读文件夹": "也可批量输入音频文件, 二选一, 优先读文件夹", + "人声伴奏分离批量处理, 使用UVR5模型。": "人聲伴奏分離批量處理, 使用UVR5模型。", + "人声提取激进程度": "人聲提取激進程度", + "以下文件或文件夹不存在:": "沒有這樣的檔案或文件夾:", + "以下模型不存在:": "以下模型不存在:", + "伴奏人声分离&去混响&去回声": "伴奏人聲分離&去混響&去回聲", + "使用无参考文本模式时建议使用微调的GPT,听不清参考音频说的啥(不晓得写啥)可以开。
开启后无视填写的参考文本。": "使用無參考文本模式時建議使用微調的GPT,聽不清參考音頻說的是啥(不知道寫啥)可以開啟,開啟後無視填寫的參考文本。", + "保存频率save_every_epoch": "保存頻率save_every_epoch", + "凑50字一切": "湊50字一切", + "凑四句一切": "湊四句一切", + "切分后的子音频的输出根目录": "切分後的子音頻的輸出根目錄", + "切割使用的进程数": "切割使用的進程數", + "刷新模型路径": "刷新模型路徑", + "前端处理后的文本(每句):": "前端處理後的文本(每句):", + "去混响/去延迟,附:": "去混響/去延遲,附", + "参考音频在3~10秒范围外,请更换!": "參考音頻在3~10秒範圍外,請更換!", + "参考音频的文本": "參考音頻的文本", + "参考音频的语种": "參考音頻的語種", + "合成语音": "合成語音", + "合格的文件夹路径格式举例: E:\\codes\\py39\\vits_vc_gpu\\白鹭霜华测试样例(去文件管理器地址栏拷就行了)。": "合格的文件夾路徑格式舉例: E:\\codes\\py39\\vits_vc_gpu\\白鷺霜華測試樣例(去文件管理器地址欄拷就行了)。", + "填切割后音频所在目录!读取的音频文件完整路径=该目录-拼接-list文件里波形对应的文件名(不是全路径)。如果留空则使用.list文件里的绝对全路径。": "填切割後音頻所在目錄!讀取的音頻文件完整路徑=該目錄-拼接-list文件裡波形對應的文件名(不是全路徑)。如果留空則使用.list文件裡的絕對全路徑。", + "多语种混合": "多語種混合", + "多语种混合(粤语)": "多語種混合 (粵語)", + "实际输入的参考文本:": "實際輸入的參考文本:", + "实际输入的目标文本(切句后):": "實際輸入的目標文本(切句後):", + "实际输入的目标文本(每句):": "實際輸入的目標文本(每句):", + "实际输入的目标文本:": "實際輸入的目標文本:", + "导出文件格式": "導出檔格式", + "开启GPT训练": "開啟GPT訓練", + "开启SSL提取": "開啟SSL提取", + "开启SoVITS训练": "開啟SoVITS訓練", + "开启一键三连": "開啟一鍵三連", + "开启文本获取": "開啟文本獲取", + "开启无参考文本模式。不填参考文本亦相当于开启。": "開啟無參考文本模式。不填參考文本亦相當於開啟。", + "开启离线批量ASR": "開啟離線批量ASR", + "开启语义token提取": "開啟語義token提取", + "开启语音切割": "開啟語音切割", + "开启语音降噪": "開啟語音降噪", + "怎么切": "怎麼切", + "总训练轮数total_epoch": "總訓練輪數total_epoch", + "总训练轮数total_epoch,不建议太高": "總訓練輪數total_epoch,不建議太高", + "打标工具WebUI已关闭": "打標工具WebUI已關閉", + "打标工具WebUI已开启": "打標工具WebUI已開啟", + "打标工具进程输出信息": "打標工具進程輸出信息", + "指定输出主人声文件夹": "指定输出主人声文件夹", + "指定输出非主人声文件夹": "指定输出非主人声文件夹", + "按中文句号。切": "按中文句號。切", + "按标点符号切": "按標點符號切", + "按英文句号.切": "按英文句號.切", + "数据类型精度": "數據類型精度", + "文本模块学习率权重": "文本模塊學習率權重", + "文本进程输出信息": "文本進程輸出信息", + "施工中,请静候佳音": "施工中,請靜候佳音", + "日文": "日文", + "日英混合": "日英混合", + "是否仅保存最新的ckpt文件以节省硬盘空间": "是否僅保存最新的ckpt文件以節省硬碟空間", + "是否在每次保存时间点将最终小模型保存至weights文件夹": "是否在每次保存時間點將最終小模型保存至weights文件夾", + "是否开启TTS推理WebUI": "是否開啟TTS推理WebUI", + "是否开启UVR5-WebUI": "是否開啟UVR5-WebUI", + "是否开启dpo训练选项(实验性)": "是否開啟dpo訓練選項(實驗性)", + "是否开启打标WebUI": "是否開啟打標WebUI", + "是否直接对上次合成结果调整语速。防止随机性。": "是否直接對上次合成結果調整語速。防止隨機性。", + "显卡信息": "顯卡信息", + "本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.
如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.": "本軟件以MIT協議開源, 作者不對軟件具備任何控制力, 使用軟件者、傳播軟件導出的聲音者自負全責.
如不認可該條款, 則不能使用或引用軟件包內任何代碼和文件. 詳見根目錄LICENSE.", + "模型": "模型", + "模型分为三类:": "模型分為三類:", + "模型切换": "模型切換", + "每张显卡的batch_size": "每張顯卡的batch_size", + "版本": "版本", + "粤英混合": "粵英混合", + "粤语": "粵語", + "终止ASR进程": "終止ASR進程", + "终止GPT训练": "終止GPT訓練", + "终止SSL提取进程": "終止SSL提取進程", + "终止SoVITS训练": "終止SoVITS訓練", + "终止一键三连": "終止一鍵三連", + "终止文本获取进程": "終止文本獲取進程", + "终止语义token提取进程": "終止語義token提取進程", + "终止语音切割": "終止語音切割", + "终止语音降噪进程": "終止語音降噪進程", + "英文": "英文", + "语义token提取进程输出信息": "語義token提取進程輸出信息", + "语速": "語速", + "语速调整,高为更快": "調整語速,高為更快", + "语音切割进程输出信息": "語音切割進程輸出信息", + "语音降噪进程输出信息": "語音降噪進程輸出信息", + "请上传3~10秒内参考音频,超过会报错!": "請上傳3~10秒內參考音頻,超過會報錯!", + "请上传参考音频": "請上傳參考音頻", + "请填入推理文本": "請填入推理文本", + "请输入有效文本": "請輸入有效文本", + "转换": "轉換", + "输入待处理音频文件夹路径": "輸入待處理音頻資料夾路徑", + "输入文件夹路径": "輸入文件夾路徑", + "输出logs/实验名目录下应有23456开头的文件和文件夹": "輸出logs/實驗名目錄下應有23456開頭的文件和文件夾", + "输出信息": "輸出訊息", + "输出文件夹路径": "輸出文件夾路徑", + "输出的语音": "輸出的語音", + "选择训练完存放在SoVITS_weights和GPT_weights下的模型。默认的一个是底模,体验5秒Zero Shot TTS用。": "選擇訓練完存放在SoVITS_weights和GPT_weights下的模型。默認的一個是底模,體驗5秒Zero Shot TTS用。", + "降噪结果输出文件夹": "降噪結果輸出文件夾", + "降噪音频文件输入文件夹": "降噪音頻文件輸入文件夾", + "需要合成的文本": "需要合成的文本", + "需要合成的语种": "需要合成的語種", + "韩文": "韓文", + "韩英混合": "韓英混合", + "音频自动切分输入路径,可文件可文件夹": "音頻自動切分輸入路徑,可文件可文件夾", + "预训练的GPT模型路径": "預訓練的GPT模型路徑", + "预训练的SSL模型路径": "預訓練的SSL模型路徑", + "预训练的SoVITS-D模型路径": "預訓練的SoVITS-D模型路徑", + "预训练的SoVITS-G模型路径": "預訓練的SoVITS-G模型路徑", + "预训练的中文BERT模型路径": "預訓練的中文BERT模型路徑" +} diff --git a/tools/i18n/locale/zh_SG.json b/tools/i18n/locale/zh_SG.json new file mode 100644 index 0000000000000000000000000000000000000000..3ed0801887ece8de2d102198b84d4f6ca066a050 --- /dev/null +++ b/tools/i18n/locale/zh_SG.json @@ -0,0 +1,178 @@ +{ + "(1)MDX-Net(onnx_dereverb):对于双通道混响是最好的选择,不能去除单通道混响;": "(1)MDX-Net(onnx_dereverb):對於雙通道混響是最好的選擇,不能去除單通道混響;", + "(234)DeEcho:去除延迟效果。Aggressive比Normal去除得更彻底,DeReverb额外去除混响,可去除单声道混响,但是对高频重的板式混响去不干净。": "(234)DeEcho: Aggressive 比 Normal 去除得更徹底,DeReverb 額外去除混響,可去除單聲道混響,但是對高頻重的板式混響去不乾淨。", + "*GPT模型列表": "*GPT模型列表", + "*SoVITS模型列表": "*SoVITS模型列表", + "*实验/模型名": "*實驗/模型名", + "*文本标注文件": "*文本標註文件", + "*训练集音频文件目录": "*訓練集音頻文件目錄", + "*请上传并填写参考信息": "*請上傳並填寫參考信息", + "*请填写需要合成的目标文本和语种模式": "請填寫需要合成的目標文本和語言模式", + ".list标注文件的路径": ".list標註文件的路徑", + "0-前置数据集获取工具": "0-前置數據集獲取工具", + "0a-UVR5人声伴奏分离&去混响去延迟工具": "0a-UVR5人聲伴奏分離&去混響去延遲工具", + "0b-语音切分工具": "0b-語音切分工具", + "0bb-语音降噪工具": "0bb-語音降噪工具", + "0c-中文批量离线ASR工具": "0c-中文批量離線ASR工具", + "0d-语音文本校对标注工具": "0d-語音文本校對標註工具", + "1-GPT-SoVITS-TTS": "1-GPT-SoVITS-TTS", + "1A-训练集格式化工具": "1A-訓練集格式化工具", + "1Aa-文本内容": "1Aa-文本內容", + "1Aabc-训练集格式化一键三连": "1Aabc-訓練集格式化一鍵三連", + "1Ab-SSL自监督特征提取": "1Ab-SSL自監督特徵提取", + "1Ac-语义token提取": "1Ac-語義token提取", + "1B-微调训练": "1B-微調訓練", + "1Ba-SoVITS训练。用于分享的模型文件输出在SoVITS_weights下。": "1Ba-SoVITS訓練。用於分享的模型文件輸出在SoVITS_weights下。", + "1Bb-GPT训练。用于分享的模型文件输出在GPT_weights下。": "1Bb-GPT訓練。用於分享的模型文件輸出在GPT_weights下。", + "1C-推理": "1C-推理", + "1、DeEcho-DeReverb模型的耗时是另外2个DeEcho模型的接近2倍;": "1、DeEcho-DeReverb 模型的耗時是另外兩個 DeEcho 模型的接近兩倍;", + "1、保留人声:不带和声的音频选这个,对主人声保留比HP5更好。内置HP2和HP3两个模型,HP3可能轻微漏伴奏但对主人声保留比HP2稍微好一丁点;": "1、保留人聲:不帶和聲的音頻選這個,對主人聲保留比HP5更好。內置HP2和HP3兩個模型,HP3可能輕微漏伴奏但對主人聲保留比HP2稍微好一丁點;", + "2-GPT-SoVITS-变声": "2-GPT-SoVITS-變聲", + "2、MDX-Net-Dereverb模型挺慢的;": "2、MDX-Net-Dereverb 模型的處理時間挺慢的;", + "2、仅保留主人声:带和声的音频选这个,对主人声可能有削弱。内置HP5一个模型;": "2、僅保留主人聲:帶和聲的音頻選這個,對主人聲可能有削弱。內置HP5一個模型;", + "3、个人推荐的最干净的配置是先MDX-Net再DeEcho-Aggressive。": "3、個人推薦的最乾淨的配置是先 MDX-Net 再 DeEcho-Aggressive。", + "3、去混响、去延迟模型(by FoxJoy):": "3、去混響、去延遲模型(by FoxJoy):", + "ASR 模型": "ASR 模型", + "ASR 模型尺寸": "ASR 模型尺寸", + "ASR 语言设置": "ASR 語言設定", + "ASR进程输出信息": "ASR進程輸出資訊", 
+ "GPT模型列表": "GPT模型列表", + "GPT训练进程输出信息": "GPT訓練進程輸出資訊", + "GPT采样参数(无参考文本时不要太低。不懂就用默认):": "GPT 采样参数(无参考文本时不要太低。不懂就用默认):", + "GPU卡号,只能填1个整数": "GPU卡號,只能填1個整數", + "GPU卡号以-分割,每个卡号一个进程": "GPU卡號以-分割,每個卡號一個進程", + "SSL进程输出信息": "SSL進程輸出資訊", + "SoVITS模型列表": "SoVITS模型列表", + "SoVITS训练进程输出信息": "SoVITS訓練進程輸出資訊", + "TTS推理WebUI进程输出信息": "TTS推理WebUI進程輸出資訊", + "TTS推理进程已关闭": "TTS推理進程已關閉", + "TTS推理进程已开启": "TTS推理進程已開啟", + "UVR5已关闭": "UVR5已關閉", + "UVR5已开启": "UVR5已開啟", + "UVR5进程输出信息": "UVR5進程輸出資訊", + "alpha_mix:混多少比例归一化后音频进来": "alpha_mix:混多少比例歸一化後音頻進來", + "hop_size:怎么算音量曲线,越小精度越大计算量越高(不是精度越大效果越好)": "hop_size:怎麼算音量曲線,越小精度越大計算量越高(不是精度越大效果越好)", + "max:归一化后最大值多少": "max:歸一化後最大值多少", + "max_sil_kept:切完后静音最多留多长": "max_sil_kept:切完後靜音最多留多長", + "min_interval:最短切割间隔": "min_interval:最短切割間隔", + "min_length:每段最小多长,如果第一段太短一直和后面段连起来直到超过这个值": "min_length:每段最小多長,如果第一段太短一直和後面段連起來直到超過這個值", + "temperature": "temperature", + "threshold:音量小于这个值视作静音的备选切割点": "threshold:音量小於這個值視作靜音的備選切割點", + "top_k": "top_k", + "top_p": "top_p", + "一键三连进程输出信息": "一鍵三連進程輸出資訊", + "不切": "不切", + "中文": "中文", + "中文教程文档:https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e": "中文教程文檔:https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e", + "中英混合": "中英混合", + "也可批量输入音频文件, 二选一, 优先读文件夹": "也可批量输入音频文件, 二选一, 优先读文件夹", + "人声伴奏分离批量处理, 使用UVR5模型。": "人聲伴奏分離批量處理, 使用UVR5模型。", + "人声提取激进程度": "人聲提取激進程度", + "以下文件或文件夹不存在:": "沒有這樣的檔案或文件夾:", + "以下模型不存在:": "以下模型不存在", + "伴奏人声分离&去混响&去回声": "伴奏人聲分離&去混響&去回聲", + "使用无参考文本模式时建议使用微调的GPT,听不清参考音频说的啥(不晓得写啥)可以开。
开启后无视填写的参考文本。": "使用無參考文本模式時建議使用微調的GPT,聽不清參考音頻說的啥(不曉得寫啥)可以開,開啟後無視填寫的參考文本。", + "保存频率save_every_epoch": "保存頻率save_every_epoch", + "凑50字一切": "湊50字一切", + "凑四句一切": "湊四句一切", + "切分后的子音频的输出根目录": "切分後的子音頻的輸出根目錄", + "切割使用的进程数": "切割使用的進程數", + "刷新模型路径": "刷新模型路徑", + "前端处理后的文本(每句):": "前端處理後的文本(每句):", + "去混响/去延迟,附:": "去混響/去延遲,附:", + "参考音频在3~10秒范围外,请更换!": "參考音頻在3~10秒範圍外,請更換!", + "参考音频的文本": "參考音頻的文本", + "参考音频的语种": "參考音頻的語種", + "合成语音": "合成語音", + "合格的文件夹路径格式举例: E:\\codes\\py39\\vits_vc_gpu\\白鹭霜华测试样例(去文件管理器地址栏拷就行了)。": "合格的資料夾路徑格式舉例: E:\\codes\\py39\\vits_vc_gpu\\白鷺霜華測試範例(去文件管理器地址欄拷就行了)。", + "填切割后音频所在目录!读取的音频文件完整路径=该目录-拼接-list文件里波形对应的文件名(不是全路径)。如果留空则使用.list文件里的绝对全路径。": "填切割後音頻所在目錄!讀取的音頻檔案完整路徑=該目錄-拼接-list檔案裡波形對應的檔案名(不是全路徑)。如果留空則使用.list檔案裡的絕對全路徑。", + "多语种混合": "多語種混合", + "多语种混合(粤语)": "多語種混合 (粵語)", + "实际输入的参考文本:": "實際輸入的參考文本:", + "实际输入的目标文本(切句后):": "實際輸入的目標文本(切句後):", + "实际输入的目标文本(每句):": "實際輸入的目標文本(每句):", + "实际输入的目标文本:": "實際輸入的目標文本:", + "导出文件格式": "導出檔格式", + "开启GPT训练": "開啟GPT訓練", + "开启SSL提取": "開啟SSL提取", + "开启SoVITS训练": "開啟SoVITS訓練", + "开启一键三连": "開啟一鍵三連", + "开启文本获取": "開啟文本獲取", + "开启无参考文本模式。不填参考文本亦相当于开启。": "開啟無參考文本模式。不填參考文本亦相當於開啟。", + "开启离线批量ASR": "開啟離線批量ASR", + "开启语义token提取": "開啟語義token提取", + "开启语音切割": "開啟語音切割", + "开启语音降噪": "開啟語音降噪", + "怎么切": "怎麼切", + "总训练轮数total_epoch": "總訓練輪數total_epoch", + "总训练轮数total_epoch,不建议太高": "總訓練輪數total_epoch,不建議太高", + "打标工具WebUI已关闭": "打標工具WebUI已關閉", + "打标工具WebUI已开启": "打標工具WebUI已開啟", + "打标工具进程输出信息": "打標工具進程輸出資訊", + "指定输出主人声文件夹": "指定输出主人声文件夹", + "指定输出非主人声文件夹": "指定输出非主人声文件夹", + "按中文句号。切": "按中文句號。切", + "按标点符号切": "按標點符號切", + "按英文句号.切": "按英文句號.切", + "数据类型精度": "數據類型精度", + "文本模块学习率权重": "文本模塊學習率權重", + "文本进程输出信息": "文本進程輸出資訊", + "施工中,请静候佳音": "施工中,請靜候佳音", + "日文": "日文", + "日英混合": "日英混合", + "是否仅保存最新的ckpt文件以节省硬盘空间": "是否僅保存最新的ckpt文件以節省硬盤空間", + "是否在每次保存时间点将最终小模型保存至weights文件夹": "是否在每次保存時間點將最終小模型保存至weights文件夾", + "是否开启TTS推理WebUI": "是否開啟TTS推理WebUI", + "是否开启UVR5-WebUI": "是否開啟UVR5-WebUI", + "是否开启dpo训练选项(实验性)": "是否開啟dpo訓練選項(實驗性)", + "是否开启打标WebUI": "是否開啟打標WebUI", + "是否直接对上次合成结果调整语速。防止随机性。": "是否直接對上次合成結果調整語速。防止隨機性。", + "显卡信息": "顯卡資訊", + "本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.
如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.": "本軟體以MIT協議開源,作者不對軟體具備任何控制力,使用軟體者、傳播軟體導出的聲音者自負全責。
如不認可該條款,則不能使用或引用軟體包內任何代碼和文件。詳見根目錄LICENSE。", + "模型": "模型", + "模型分为三类:": "模型分為三類:", + "模型切换": "模型切換", + "每张显卡的batch_size": "每張顯卡的batch_size", + "版本": "版本", + "粤英混合": "粵英混合", + "粤语": "粵語", + "终止ASR进程": "終止ASR進程", + "终止GPT训练": "終止GPT訓練", + "终止SSL提取进程": "終止SSL提取進程", + "终止SoVITS训练": "終止SoVITS訓練", + "终止一键三连": "終止一鍵三連", + "终止文本获取进程": "終止文本獲取進程", + "终止语义token提取进程": "終止語義token提取進程", + "终止语音切割": "終止語音切割", + "终止语音降噪进程": "終止語音降噪進程", + "英文": "英文", + "语义token提取进程输出信息": "語義token提取進程輸出資訊", + "语速": "語速", + "语速调整,高为更快": "調整語速,高為更快", + "语音切割进程输出信息": "語音切割進程輸出資訊", + "语音降噪进程输出信息": "語音降噪進程輸出資訊", + "请上传3~10秒内参考音频,超过会报错!": "請上傳3~10秒內參考音頻,超過會報錯!", + "请上传参考音频": "請上傳參考音頻", + "请填入推理文本": "請填入推理文本", + "请输入有效文本": "請輸入有效文本", + "转换": "轉換", + "输入待处理音频文件夹路径": "輸入待處理音頻資料夾路徑", + "输入文件夹路径": "輸入文件夾路徑", + "输出logs/实验名目录下应有23456开头的文件和文件夹": "輸出logs/實驗名目錄下應有23456開頭的文件和文件夾", + "输出信息": "輸出訊息", + "输出文件夹路径": "輸出文件夾路徑", + "输出的语音": "輸出的語音", + "选择训练完存放在SoVITS_weights和GPT_weights下的模型。默认的一个是底模,体验5秒Zero Shot TTS用。": "選擇訓練完存放在SoVITS_weights和GPT_weights下的模型。默認的一個是底模,體驗5秒Zero Shot TTS用。", + "降噪结果输出文件夹": "降噪結果輸出文件夾", + "降噪音频文件输入文件夹": "降噪音頻文件輸入文件夾", + "需要合成的文本": "需要合成的文本", + "需要合成的语种": "需要合成的語種", + "韩文": "韓文", + "韩英混合": "韓英混合", + "音频自动切分输入路径,可文件可文件夹": "音頻自動切分輸入路徑,可文件可文件夾", + "预训练的GPT模型路径": "預訓練的GPT模型路徑", + "预训练的SSL模型路径": "預訓練的SSL模型路徑", + "预训练的SoVITS-D模型路径": "預訓練的SoVITS-D模型路徑", + "预训练的SoVITS-G模型路径": "預訓練的SoVITS-G模型路徑", + "预训练的中文BERT模型路径": "預訓練的中文BERT模型路徑" +} diff --git a/tools/i18n/locale/zh_TW.json b/tools/i18n/locale/zh_TW.json new file mode 100644 index 0000000000000000000000000000000000000000..9e63896e02349241fe079ef927dab9faa4c92bf7 --- /dev/null +++ b/tools/i18n/locale/zh_TW.json @@ -0,0 +1,178 @@ +{ + "(1)MDX-Net(onnx_dereverb):对于双通道混响是最好的选择,不能去除单通道混响;": "(1)MDX-Net(onnx_dereverb):對於雙通道混響是最好的選擇,不能去除單通道混響;", + "(234)DeEcho:去除延迟效果。Aggressive比Normal去除得更彻底,DeReverb额外去除混响,可去除单声道混响,但是对高频重的板式混响去不干净。": "(234)DeEcho:去除延遲效果。Aggressive 比 Normal 去除得更徹底,DeReverb 額外去除混響,可去除單聲道混響,但是對高頻重的板式混響去不乾淨。", + "*GPT模型列表": "*GPT模型列表", + "*SoVITS模型列表": "*SoVITS模型列表", + "*实验/模型名": "*實驗/模型名", + "*文本标注文件": "*文本標注文件", + "*训练集音频文件目录": "*訓練集音頻文件目錄", + "*请上传并填写参考信息": "*請上傳並填寫參考資訊", + "*请填写需要合成的目标文本和语种模式": "請填寫需要合成的目標文本和語言模式", + ".list标注文件的路径": ".list標注文件的路徑", + "0-前置数据集获取工具": "0-前置數據集獲取工具", + "0a-UVR5人声伴奏分离&去混响去延迟工具": "0a-UVR5人聲伴奏分離&去混響去延遲工具", + "0b-语音切分工具": "0b-語音切分工具", + "0bb-语音降噪工具": "0bb-語音降噪工具", + "0c-中文批量离线ASR工具": "0c-中文批量離線ASR工具", + "0d-语音文本校对标注工具": "0d-語音文本校對標注工具", + "1-GPT-SoVITS-TTS": "1-GPT-SoVITS-TTS", + "1A-训练集格式化工具": "1A-訓練集格式化工具", + "1Aa-文本内容": "1Aa-文本內容", + "1Aabc-训练集格式化一键三连": "1Aabc-訓練集格式化一鍵三連", + "1Ab-SSL自监督特征提取": "1Ab-SSL自監督特徵提取", + "1Ac-语义token提取": "1Ac-語義token提取", + "1B-微调训练": "1B-微調訓練", + "1Ba-SoVITS训练。用于分享的模型文件输出在SoVITS_weights下。": "1Ba-SoVITS訓練。用於分享的模型文件輸出在SoVITS_weights下。", + "1Bb-GPT训练。用于分享的模型文件输出在GPT_weights下。": "1Bb-GPT訓練。用於分享的模型文件輸出在GPT_weights下。", + "1C-推理": "1C-推理", + "1、DeEcho-DeReverb模型的耗时是另外2个DeEcho模型的接近2倍;": "1、DeEcho-DeReverb 模型的耗時是另外兩個 DeEcho 模型的接近兩倍;", + "1、保留人声:不带和声的音频选这个,对主人声保留比HP5更好。内置HP2和HP3两个模型,HP3可能轻微漏伴奏但对主人声保留比HP2稍微好一丁点;": "1、保留人聲:不帶和聲的音頻選這個,對主人聲保留比HP5更好。內置HP2和HP3兩個模型,HP3可能輕微漏伴奏但對主人聲保留比HP2稍微好一丁點;", + "2-GPT-SoVITS-变声": "2-GPT-SoVITS-變聲", + "2、MDX-Net-Dereverb模型挺慢的;": "2、MDX-Net-Dereverb模型挺慢的;", + "2、仅保留主人声:带和声的音频选这个,对主人声可能有削弱。内置HP5一个模型;": "2、僅保留主人聲:帶和聲的音頻選這個,對主人聲可能有削弱。內置HP5一個模型;", + "3、个人推荐的最干净的配置是先MDX-Net再DeEcho-Aggressive。": "3、個人推薦的最乾淨的配置是先 MDX-Net 再 DeEcho-Aggressive。", + "3、去混响、去延迟模型(by FoxJoy):": "3、去混響、去延遲模型(by FoxJoy):", + "ASR 模型": "ASR 模型", + "ASR 模型尺寸": "ASR 模型尺寸", + "ASR 语言设置": "ASR 語言設置", + "ASR进程输出信息": "ASR進程輸出資訊", + 
"GPT模型列表": "GPT模型列表", + "GPT训练进程输出信息": "GPT訓練進程輸出資訊", + "GPT采样参数(无参考文本时不要太低。不懂就用默认):": "GPT 采样参数(无参考文本时不要太低。不懂就用默认):", + "GPU卡号,只能填1个整数": "GPU卡號,只能填1個整數", + "GPU卡号以-分割,每个卡号一个进程": "GPU卡號以-分割,每個卡號一個進程", + "SSL进程输出信息": "SSL進程輸出資訊", + "SoVITS模型列表": "SoVITS模型列表", + "SoVITS训练进程输出信息": "SoVITS訓練進程輸出資訊", + "TTS推理WebUI进程输出信息": "TTS推理WebUI進程輸出資訊", + "TTS推理进程已关闭": "TTS推理進程已關閉", + "TTS推理进程已开启": "TTS推理進程已開啟", + "UVR5已关闭": "UVR5已關閉", + "UVR5已开启": "UVR5已開啟", + "UVR5进程输出信息": "UVR5進程輸出資訊", + "alpha_mix:混多少比例归一化后音频进来": "alpha_mix:混多少比例歸一化後音頻進來", + "hop_size:怎么算音量曲线,越小精度越大计算量越高(不是精度越大效果越好)": "hop_size:怎麼算音量曲線,越小精度越大計算量越高(不是精度越大效果越好)", + "max:归一化后最大值多少": "max:歸一化後最大值多少", + "max_sil_kept:切完后静音最多留多长": "max_sil_kept:切完後靜音最多留多長", + "min_interval:最短切割间隔": "min_interval:最短切割間隔", + "min_length:每段最小多长,如果第一段太短一直和后面段连起来直到超过这个值": "min_length:每段最小多長,如果第一段太短一直和後面段連起來直到超過這個值", + "temperature": "temperature", + "threshold:音量小于这个值视作静音的备选切割点": "threshold:音量小於這個值視作靜音的備選切割點", + "top_k": "top_k", + "top_p": "top_p", + "一键三连进程输出信息": "一鍵三連進程輸出資訊", + "不切": "不切", + "中文": "中文", + "中文教程文档:https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e": "中文教程文檔:https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e", + "中英混合": "中英混合", + "也可批量输入音频文件, 二选一, 优先读文件夹": "也可批量输入音频文件, 二选一, 优先读文件夹", + "人声伴奏分离批量处理, 使用UVR5模型。": "人聲伴奏分離批量處理, 使用UVR5模型。", + "人声提取激进程度": "人聲提取激進程度", + "以下文件或文件夹不存在:": "沒有這樣的檔案或文件夾:", + "以下模型不存在:": "#以下模型不存在", + "伴奏人声分离&去混响&去回声": "伴奏人聲分離&去混響&去回聲", + "使用无参考文本模式时建议使用微调的GPT,听不清参考音频说的啥(不晓得写啥)可以开。
开启后无视填写的参考文本。": "使用無參考文本模式時建議使用微調的GPT,聽不清參考音頻說的啥(不曉得寫啥)可以開,開啟後無視填寫的參考文本。", + "保存频率save_every_epoch": "保存頻率save_every_epoch", + "凑50字一切": "湊50字一切", + "凑四句一切": "湊四句一切", + "切分后的子音频的输出根目录": "切分後的子音頻的輸出根目錄", + "切割使用的进程数": "切割使用的進程數", + "刷新模型路径": "刷新模型路徑", + "前端处理后的文本(每句):": "前端處理後的文本(每句):", + "去混响/去延迟,附:": "去混響/去延遲,附:", + "参考音频在3~10秒范围外,请更换!": "參考音頻在3~10秒範圍外,請更換!", + "参考音频的文本": "參考音頻的文本", + "参考音频的语种": "參考音頻的語種", + "合成语音": "合成語音", + "合格的文件夹路径格式举例: E:\\codes\\py39\\vits_vc_gpu\\白鹭霜华测试样例(去文件管理器地址栏拷就行了)。": "合格的資料夾路徑格式舉例: E:\\codes\\py39\\vits_vc_gpu\\白鷺霜華測試範例(去文件管理器地址欄拷就行了)。", + "填切割后音频所在目录!读取的音频文件完整路径=该目录-拼接-list文件里波形对应的文件名(不是全路径)。如果留空则使用.list文件里的绝对全路径。": "填切割後音頻所在目錄!讀取的音頻檔案完整路徑=該目錄-拼接-list檔案裡波形對應的檔案名(不是全路徑)。如果留空則使用.list檔案裡的絕對全路徑。", + "多语种混合": "多語種混合", + "多语种混合(粤语)": "多語種混合 (粵語)", + "实际输入的参考文本:": "實際輸入的參考文本:", + "实际输入的目标文本(切句后):": "實際輸入的目標文本(切句後):", + "实际输入的目标文本(每句):": "實際輸入的目標文本(每句):", + "实际输入的目标文本:": "實際輸入的目標文本:", + "导出文件格式": "導出檔格式", + "开启GPT训练": "開啟GPT訓練", + "开启SSL提取": "開啟SSL提取", + "开启SoVITS训练": "開啟SoVITS訓練", + "开启一键三连": "開啟一鍵三連", + "开启文本获取": "開啟文本獲取", + "开启无参考文本模式。不填参考文本亦相当于开启。": "開啟無參考文本模式。不填參考文本亦相當於開啟。", + "开启离线批量ASR": "開啟離線批量ASR", + "开启语义token提取": "開啟語義token提取", + "开启语音切割": "開啟語音切割", + "开启语音降噪": "開啟語音降噪", + "怎么切": "怎麼切", + "总训练轮数total_epoch": "總訓練輪數total_epoch", + "总训练轮数total_epoch,不建议太高": "總訓練輪數total_epoch,不建議太高", + "打标工具WebUI已关闭": "打標工具WebUI已關閉", + "打标工具WebUI已开启": "打標工具WebUI已開啟", + "打标工具进程输出信息": "打標工具進程輸出資訊", + "指定输出主人声文件夹": "指定输出主人声文件夹", + "指定输出非主人声文件夹": "指定输出非主人声文件夹", + "按中文句号。切": "按中文句號。切", + "按标点符号切": "按標點符號切", + "按英文句号.切": "按英文句號.切", + "数据类型精度": "數據類型精度", + "文本模块学习率权重": "文本模塊學習率權重", + "文本进程输出信息": "文本進程輸出資訊", + "施工中,请静候佳音": "施工中,請靜候佳音", + "日文": "日文", + "日英混合": "日英混合", + "是否仅保存最新的ckpt文件以节省硬盘空间": "是否僅保存最新的ckpt文件以節省硬盤空間", + "是否在每次保存时间点将最终小模型保存至weights文件夹": "是否在每次保存時間點將最終小模型保存至weights文件夾", + "是否开启TTS推理WebUI": "是否開啟TTS推理WebUI", + "是否开启UVR5-WebUI": "是否開啟UVR5-WebUI", + "是否开启dpo训练选项(实验性)": "是否開啟dpo訓練選項(實驗性)", + "是否开启打标WebUI": "是否開啟打標WebUI", + "是否直接对上次合成结果调整语速。防止随机性。": "是否直接對上次合成結果調整語速。防止隨機性。", + "显卡信息": "顯卡資訊", + "本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.
如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.": "本軟體以MIT協議開源,作者不對軟體具備任何控制力,使用軟體者、傳播軟體導出的聲音者自負全責。
如不認可該條款,則不能使用或引用軟體包內任何代碼和文件。詳見根目錄LICENSE。", + "模型": "模型", + "模型分为三类:": "模型分為三類:", + "模型切换": "模型切換", + "每张显卡的batch_size": "每張顯卡的batch_size", + "版本": "版本", + "粤英混合": "粵英混合", + "粤语": "粵語", + "终止ASR进程": "終止ASR進程", + "终止GPT训练": "終止GPT訓練", + "终止SSL提取进程": "終止SSL提取進程", + "终止SoVITS训练": "終止SoVITS訓練", + "终止一键三连": "終止一鍵三連", + "终止文本获取进程": "終止文本獲取進程", + "终止语义token提取进程": "終止語義token提取進程", + "终止语音切割": "終止語音切割", + "终止语音降噪进程": "終止語音降噪進程", + "英文": "英文", + "语义token提取进程输出信息": "語義token提取進程輸出資訊", + "语速": "語速", + "语速调整,高为更快": "調整語速,高為更快", + "语音切割进程输出信息": "語音切割進程輸出資訊", + "语音降噪进程输出信息": "語音降噪進程輸出資訊", + "请上传3~10秒内参考音频,超过会报错!": "請上傳3~10秒內參考音頻,超過會報錯!", + "请上传参考音频": "請上傳參考音頻", + "请填入推理文本": "請填入推理文本", + "请输入有效文本": "請輸入有效文本", + "转换": "轉換", + "输入待处理音频文件夹路径": "輸入待處理音頻資料夾路徑", + "输入文件夹路径": "輸入文件夾路徑", + "输出logs/实验名目录下应有23456开头的文件和文件夹": "輸出logs/實驗名目錄下應有23456開頭的文件和文件夾", + "输出信息": "輸出訊息", + "输出文件夹路径": "輸出文件夾路徑", + "输出的语音": "輸出的語音", + "选择训练完存放在SoVITS_weights和GPT_weights下的模型。默认的一个是底模,体验5秒Zero Shot TTS用。": "選擇訓練完存放在SoVITS_weights和GPT_weights下的模型。默認的一個是底模,體驗5秒Zero Shot TTS用。", + "降噪结果输出文件夹": "降噪結果輸出文件夾", + "降噪音频文件输入文件夹": "降噪音頻文件輸入文件夾", + "需要合成的文本": "需要合成的文本", + "需要合成的语种": "需要合成的語種", + "韩文": "韓文", + "韩英混合": "韓英混合", + "音频自动切分输入路径,可文件可文件夹": "音頻自動切分輸入路徑,可文件可文件夾", + "预训练的GPT模型路径": "預訓練的GPT模型路徑", + "预训练的SSL模型路径": "預訓練的SSL模型路徑", + "预训练的SoVITS-D模型路径": "預訓練的SoVITS-D模型路徑", + "预训练的SoVITS-G模型路径": "預訓練的SoVITS-G模型路徑", + "预训练的中文BERT模型路径": "預訓練的中文BERT模型路徑" +} diff --git a/tools/i18n/scan_i18n.py b/tools/i18n/scan_i18n.py new file mode 100644 index 0000000000000000000000000000000000000000..98bea6a8e957ab632e2499ee3ffaf9f00783ca88 --- /dev/null +++ b/tools/i18n/scan_i18n.py @@ -0,0 +1,119 @@ +import ast +import glob +import json +import os +from collections import OrderedDict + +I18N_JSON_DIR : os.PathLike = os.path.join(os.path.dirname(os.path.relpath(__file__)), 'locale') +DEFAULT_LANGUAGE: str = "zh_CN" # 默认语言 +TITLE_LEN : int = 60 # 标题显示长度 +KEY_LEN : int = 30 # 键名显示长度 +SHOW_KEYS : bool = False # 是否显示键信息 + +def extract_i18n_strings(node): + i18n_strings = [] + + if ( + isinstance(node, ast.Call) + and isinstance(node.func, ast.Name) + and node.func.id == "i18n" + ): + for arg in node.args: + if isinstance(arg, ast.Str): + i18n_strings.append(arg.s) + + for child_node in ast.iter_child_nodes(node): + i18n_strings.extend(extract_i18n_strings(child_node)) + + return i18n_strings + +def scan_i18n_strings(): + """ + scan the directory for all .py files (recursively) + for each file, parse the code into an AST + for each AST, extract the i18n strings + """ + strings = [] + print(" Scanning Files and Extracting i18n Strings ".center(TITLE_LEN, "=")) + for filename in glob.iglob("**/*.py", recursive=True): + with open(filename, "r", encoding="utf-8") as f: + code = f.read() + if "I18nAuto" in code: + tree = ast.parse(code) + i18n_strings = extract_i18n_strings(tree) + print(f"{filename.ljust(30)}: {len(i18n_strings)}") + strings.extend(i18n_strings) + + code_keys = set(strings) + print(f"{'Total Unique'.ljust(30)}: {len(code_keys)}") + return code_keys + +def update_i18n_json(json_file, standard_keys): + print(f" Process {json_file} ".center(TITLE_LEN, "=")) + # 读取 JSON 文件 + with open(json_file, "r", encoding="utf-8") as f: + json_data = json.load(f, object_pairs_hook=OrderedDict) + # 打印处理前的 JSON 条目数 + len_before = len(json_data) + print(f"{'Total Keys'.ljust(KEY_LEN)}: {len_before}") + # 识别缺失的键并补全 + miss_keys = set(standard_keys) - set(json_data.keys()) + if len(miss_keys) > 0: + print(f"{'Missing Keys (+)'.ljust(KEY_LEN)}: {len(miss_keys)}") + for key in 
miss_keys: + if DEFAULT_LANGUAGE in json_file: + # 默认语言的键值相同. + json_data[key] = key + else: + # 其他语言的值设置为 #! + 键名以标注未被翻译. + json_data[key] = "#!" + key + if SHOW_KEYS: + print(f"{'Added Missing Key'.ljust(KEY_LEN)}: {key}") + # 识别多余的键并删除 + diff_keys = set(json_data.keys()) - set(standard_keys) + if len(diff_keys) > 0: + print(f"{'Unused Keys (-)'.ljust(KEY_LEN)}: {len(diff_keys)}") + for key in diff_keys: + del json_data[key] + if SHOW_KEYS: + print(f"{'Removed Unused Key'.ljust(KEY_LEN)}: {key}") + # 按键顺序排序 + json_data = OrderedDict( + sorted(json_data.items(), + key=lambda x: list(standard_keys).index(x[0]))) + # 打印处理后的 JSON 条目数 + if len(miss_keys) != 0 or len(diff_keys) != 0: + print(f"{'Total Keys (After)'.ljust(KEY_LEN)}: {len(json_data)}") + # 识别有待翻译的键 + num_miss_translation = 0 + duplicate_items = {} + for key, value in json_data.items(): + if value.startswith("#!"): + num_miss_translation += 1 + if SHOW_KEYS: + print(f"{'Missing Translation'.ljust(KEY_LEN)}: {key}") + if value in duplicate_items: + duplicate_items[value].append(key) + else: + duplicate_items[value] = [key] + # 打印是否有重复的值 + for value, keys in duplicate_items.items(): + if len(keys) > 1: + print("\n".join([f"\033[31m{'[Failed] Duplicate Value'.ljust(KEY_LEN)}: {key} -> {value}\033[0m" for key in keys])) + + if num_miss_translation > 0: + print(f"\033[31m{'[Failed] Missing Translation'.ljust(KEY_LEN)}: {num_miss_translation}\033[0m") + else: + print(f"\033[32m[Passed] All Keys Translated\033[0m") + # 将处理后的结果写入 JSON 文件 + with open(json_file, "w", encoding="utf-8") as f: + json.dump(json_data, f, ensure_ascii=False, indent=4, sort_keys=True) + f.write("\n") + print(f" Updated {json_file} ".center(TITLE_LEN, "=") + '\n') + +if __name__ == "__main__": + code_keys = scan_i18n_strings() + for json_file in os.listdir(I18N_JSON_DIR): + if json_file.endswith(r".json"): + json_file = os.path.join(I18N_JSON_DIR, json_file) + update_i18n_json(json_file, code_keys) \ No newline at end of file diff --git a/tools/my_utils.py b/tools/my_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..53544f8d3c84e2022e0c7154f992bbb881dfc2f7 --- /dev/null +++ b/tools/my_utils.py @@ -0,0 +1,32 @@ +import platform,os,traceback +import ffmpeg +import numpy as np + + +def load_audio(file, sr): + try: + # https://github.com/openai/whisper/blob/main/whisper/audio.py#L26 + # This launches a subprocess to decode audio while down-mixing and resampling as necessary. + # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed. + file = clean_path(file) # 防止小白拷路径头尾带了空格和"和回车 + if os.path.exists(file) == False: + raise RuntimeError( + "You input a wrong audio path that does not exists, please fix it!" 
+ ) + out, _ = ( + ffmpeg.input(file, threads=0) + .output("-", format="f32le", acodec="pcm_f32le", ac=1, ar=sr) + .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True) + ) + except Exception as e: + traceback.print_exc() + raise RuntimeError(f"Failed to load audio: {e}") + + return np.frombuffer(out, np.float32).flatten() + + +def clean_path(path_str:str): + if path_str.endswith(('\\','/')): + return clean_path(path_str[0:-1]) + path_str = path_str.replace('/', os.sep).replace('\\', os.sep) + return path_str.strip(" ").strip('\'').strip("\n").strip('"').strip(" ").strip("\u202a") diff --git a/tools/slice_audio.py b/tools/slice_audio.py new file mode 100644 index 0000000000000000000000000000000000000000..8a06292d993825ca49d57f1274865c029c0b2bb4 --- /dev/null +++ b/tools/slice_audio.py @@ -0,0 +1,48 @@ +import os,sys,numpy as np +import traceback +from scipy.io import wavfile +# parent_directory = os.path.dirname(os.path.abspath(__file__)) +# sys.path.append(parent_directory) +from tools.my_utils import load_audio +from slicer2 import Slicer + +def slice(inp,opt_root,threshold,min_length,min_interval,hop_size,max_sil_kept,_max,alpha,i_part,all_part): + os.makedirs(opt_root,exist_ok=True) + if os.path.isfile(inp): + input=[inp] + elif os.path.isdir(inp): + input=[os.path.join(inp, name) for name in sorted(list(os.listdir(inp)))] + else: + return "输入路径存在但既不是文件也不是文件夹" + slicer = Slicer( + sr=32000, # 长音频采样率 + threshold= int(threshold), # 音量小于这个值视作静音的备选切割点 + min_length= int(min_length), # 每段最小多长,如果第一段太短一直和后面段连起来直到超过这个值 + min_interval= int(min_interval), # 最短切割间隔 + hop_size= int(hop_size), # 怎么算音量曲线,越小精度越大计算量越高(不是精度越大效果越好) + max_sil_kept= int(max_sil_kept), # 切完后静音最多留多长 + ) + _max=float(_max) + alpha=float(alpha) + for inp_path in input[int(i_part)::int(all_part)]: + # print(inp_path) + try: + name = os.path.basename(inp_path) + audio = load_audio(inp_path, 32000) + # print(audio.shape) + for chunk, start, end in slicer.slice(audio): # start和end是帧数 + tmp_max = np.abs(chunk).max() + if(tmp_max>1):chunk/=tmp_max + chunk = (chunk / tmp_max * (_max * alpha)) + (1 - alpha) * chunk + wavfile.write( + "%s/%s_%010d_%010d.wav" % (opt_root, name, start, end), + 32000, + # chunk.astype(np.float32), + (chunk * 32767).astype(np.int16), + ) + except: + print(inp_path,"->fail->",traceback.format_exc()) + return "执行完毕,请检查输出文件" + +print(slice(*sys.argv[1:])) + diff --git a/tools/slicer2.py b/tools/slicer2.py new file mode 100644 index 0000000000000000000000000000000000000000..ba6794b6335fc50a494ba1b1cfb375536ab7a1aa --- /dev/null +++ b/tools/slicer2.py @@ -0,0 +1,261 @@ +import numpy as np + + +# This function is obtained from librosa. 
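+# It computes a framewise root-mean-square loudness curve: the input is
+# centre-padded, viewed as overlapping frames of `frame_length` samples taken
+# every `hop_length` samples, and the RMS of each frame is returned with shape
+# (1, n_frames); Slicer.slice() squeezes this to a 1-D volume curve.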
+def get_rms( + y, + frame_length=2048, + hop_length=512, + pad_mode="constant", +): + padding = (int(frame_length // 2), int(frame_length // 2)) + y = np.pad(y, padding, mode=pad_mode) + + axis = -1 + # put our new within-frame axis at the end for now + out_strides = y.strides + tuple([y.strides[axis]]) + # Reduce the shape on the framing axis + x_shape_trimmed = list(y.shape) + x_shape_trimmed[axis] -= frame_length - 1 + out_shape = tuple(x_shape_trimmed) + tuple([frame_length]) + xw = np.lib.stride_tricks.as_strided(y, shape=out_shape, strides=out_strides) + if axis < 0: + target_axis = axis - 1 + else: + target_axis = axis + 1 + xw = np.moveaxis(xw, -1, target_axis) + # Downsample along the target axis + slices = [slice(None)] * xw.ndim + slices[axis] = slice(0, None, hop_length) + x = xw[tuple(slices)] + + # Calculate power + power = np.mean(np.abs(x) ** 2, axis=-2, keepdims=True) + + return np.sqrt(power) + + +class Slicer: + def __init__( + self, + sr: int, + threshold: float = -40.0, + min_length: int = 5000, + min_interval: int = 300, + hop_size: int = 20, + max_sil_kept: int = 5000, + ): + if not min_length >= min_interval >= hop_size: + raise ValueError( + "The following condition must be satisfied: min_length >= min_interval >= hop_size" + ) + if not max_sil_kept >= hop_size: + raise ValueError( + "The following condition must be satisfied: max_sil_kept >= hop_size" + ) + min_interval = sr * min_interval / 1000 + self.threshold = 10 ** (threshold / 20.0) + self.hop_size = round(sr * hop_size / 1000) + self.win_size = min(round(min_interval), 4 * self.hop_size) + self.min_length = round(sr * min_length / 1000 / self.hop_size) + self.min_interval = round(min_interval / self.hop_size) + self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size) + + def _apply_slice(self, waveform, begin, end): + if len(waveform.shape) > 1: + return waveform[ + :, begin * self.hop_size : min(waveform.shape[1], end * self.hop_size) + ] + else: + return waveform[ + begin * self.hop_size : min(waveform.shape[0], end * self.hop_size) + ] + + # @timeit + def slice(self, waveform): + if len(waveform.shape) > 1: + samples = waveform.mean(axis=0) + else: + samples = waveform + if samples.shape[0] <= self.min_length: + return [waveform] + rms_list = get_rms( + y=samples, frame_length=self.win_size, hop_length=self.hop_size + ).squeeze(0) + sil_tags = [] + silence_start = None + clip_start = 0 + for i, rms in enumerate(rms_list): + # Keep looping while frame is silent. + if rms < self.threshold: + # Record start of silent frames. + if silence_start is None: + silence_start = i + continue + # Keep looping while frame is not silent and silence start has not been recorded. + if silence_start is None: + continue + # Clear recorded silence start if interval is not enough or clip is too short + is_leading_silence = silence_start == 0 and i > self.max_sil_kept + need_slice_middle = ( + i - silence_start >= self.min_interval + and i - clip_start >= self.min_length + ) + if not is_leading_silence and not need_slice_middle: + silence_start = None + continue + # Need slicing. Record the range of silent frames to be removed. 
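+ # Three cases follow: (1) silence no longer than max_sil_kept -> keep a single
+ # cut point at the quietest frame; (2) silence up to 2 * max_sil_kept -> remove
+ # the middle around the quietest frames, keeping at most max_sil_kept silent
+ # frames on each side; (3) longer silence -> drop the middle entirely, keeping
+ # at most max_sil_kept frames of silence at both edges.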
+ if i - silence_start <= self.max_sil_kept: + pos = rms_list[silence_start : i + 1].argmin() + silence_start + if silence_start == 0: + sil_tags.append((0, pos)) + else: + sil_tags.append((pos, pos)) + clip_start = pos + elif i - silence_start <= self.max_sil_kept * 2: + pos = rms_list[ + i - self.max_sil_kept : silence_start + self.max_sil_kept + 1 + ].argmin() + pos += i - self.max_sil_kept + pos_l = ( + rms_list[ + silence_start : silence_start + self.max_sil_kept + 1 + ].argmin() + + silence_start + ) + pos_r = ( + rms_list[i - self.max_sil_kept : i + 1].argmin() + + i + - self.max_sil_kept + ) + if silence_start == 0: + sil_tags.append((0, pos_r)) + clip_start = pos_r + else: + sil_tags.append((min(pos_l, pos), max(pos_r, pos))) + clip_start = max(pos_r, pos) + else: + pos_l = ( + rms_list[ + silence_start : silence_start + self.max_sil_kept + 1 + ].argmin() + + silence_start + ) + pos_r = ( + rms_list[i - self.max_sil_kept : i + 1].argmin() + + i + - self.max_sil_kept + ) + if silence_start == 0: + sil_tags.append((0, pos_r)) + else: + sil_tags.append((pos_l, pos_r)) + clip_start = pos_r + silence_start = None + # Deal with trailing silence. + total_frames = rms_list.shape[0] + if ( + silence_start is not None + and total_frames - silence_start >= self.min_interval + ): + silence_end = min(total_frames, silence_start + self.max_sil_kept) + pos = rms_list[silence_start : silence_end + 1].argmin() + silence_start + sil_tags.append((pos, total_frames + 1)) + # Apply and return slices. + ####音频+起始时间+终止时间 + if len(sil_tags) == 0: + return [[waveform,0,int(total_frames*self.hop_size)]] + else: + chunks = [] + if sil_tags[0][0] > 0: + chunks.append([self._apply_slice(waveform, 0, sil_tags[0][0]),0,int(sil_tags[0][0]*self.hop_size)]) + for i in range(len(sil_tags) - 1): + chunks.append( + [self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0]),int(sil_tags[i][1]*self.hop_size),int(sil_tags[i + 1][0]*self.hop_size)] + ) + if sil_tags[-1][1] < total_frames: + chunks.append( + [self._apply_slice(waveform, sil_tags[-1][1], total_frames),int(sil_tags[-1][1]*self.hop_size),int(total_frames*self.hop_size)] + ) + return chunks + + +def main(): + import os.path + from argparse import ArgumentParser + + import librosa + import soundfile + + parser = ArgumentParser() + parser.add_argument("audio", type=str, help="The audio to be sliced") + parser.add_argument( + "--out", type=str, help="Output directory of the sliced audio clips" + ) + parser.add_argument( + "--db_thresh", + type=float, + required=False, + default=-40, + help="The dB threshold for silence detection", + ) + parser.add_argument( + "--min_length", + type=int, + required=False, + default=5000, + help="The minimum milliseconds required for each sliced audio clip", + ) + parser.add_argument( + "--min_interval", + type=int, + required=False, + default=300, + help="The minimum milliseconds for a silence part to be sliced", + ) + parser.add_argument( + "--hop_size", + type=int, + required=False, + default=10, + help="Frame length in milliseconds", + ) + parser.add_argument( + "--max_sil_kept", + type=int, + required=False, + default=500, + help="The maximum silence length kept around the sliced clip, presented in milliseconds", + ) + args = parser.parse_args() + out = args.out + if out is None: + out = os.path.dirname(os.path.abspath(args.audio)) + audio, sr = librosa.load(args.audio, sr=None, mono=False) + slicer = Slicer( + sr=sr, + threshold=args.db_thresh, + min_length=args.min_length, + min_interval=args.min_interval, + 
hop_size=args.hop_size, + max_sil_kept=args.max_sil_kept, + ) + chunks = slicer.slice(audio) + if not os.path.exists(out): + os.makedirs(out) + for i, chunk in enumerate(chunks): + if len(chunk.shape) > 1: + chunk = chunk.T + soundfile.write( + os.path.join( + out, + f"%s_%d.wav" + % (os.path.basename(args.audio).rsplit(".", maxsplit=1)[0], i), + ), + chunk, + sr, + ) + + +if __name__ == "__main__": + main() diff --git a/tools/subfix_webui.py b/tools/subfix_webui.py new file mode 100644 index 0000000000000000000000000000000000000000..d6624d03601bbfd6b1c4b2c3627b777c6e59cf27 --- /dev/null +++ b/tools/subfix_webui.py @@ -0,0 +1,498 @@ +import argparse,os +import copy +import json +import os +import uuid + +import librosa +import gradio as gr +import numpy as np +import soundfile + +g_json_key_text = "" +g_json_key_path = "" +g_load_file = "" +g_load_format = "" + +g_max_json_index = 0 +g_index = 0 +g_batch = 10 +g_text_list = [] +g_audio_list = [] +g_checkbox_list = [] +g_data_json = [] + + +def reload_data(index, batch): + global g_index + g_index = index + global g_batch + g_batch = batch + datas = g_data_json[index:index+batch] + output = [] + for d in datas: + output.append( + { + g_json_key_text: d[g_json_key_text], + g_json_key_path: d[g_json_key_path] + } + ) + return output + + +def b_change_index(index, batch): + global g_index, g_batch + g_index, g_batch = index, batch + datas = reload_data(index, batch) + output = [] + for i , _ in enumerate(datas): + output.append( + # gr.Textbox( + # label=f"Text {i+index}", + # value=_[g_json_key_text]#text + # ) + { + "__type__":"update", + "label":f"Text {i+index}", + "value":_[g_json_key_text] + } + ) + for _ in range(g_batch - len(datas)): + output.append( + # gr.Textbox( + # label=f"Text", + # value="" + # ) + { + "__type__": "update", + "label": f"Text", + "value": "" + } + ) + for _ in datas: + output.append(_[g_json_key_path]) + for _ in range(g_batch - len(datas)): + output.append(None) + for _ in range(g_batch): + output.append(False) + return output + + +def b_next_index(index, batch): + b_save_file() + if (index + batch) <= g_max_json_index: + return index + batch , *b_change_index(index + batch, batch) + else: + return index, *b_change_index(index, batch) + + +def b_previous_index(index, batch): + b_save_file() + if (index - batch) >= 0: + return index - batch , *b_change_index(index - batch, batch) + else: + return 0, *b_change_index(0, batch) + + +def b_submit_change(*text_list): + global g_data_json + change = False + for i, new_text in enumerate(text_list): + if g_index + i <= g_max_json_index: + new_text = new_text.strip()+' ' + if (g_data_json[g_index + i][g_json_key_text] != new_text): + g_data_json[g_index + i][g_json_key_text] = new_text + change = True + if change: + b_save_file() + return g_index, *b_change_index(g_index, g_batch) + + +def b_delete_audio(*checkbox_list): + global g_data_json, g_index, g_max_json_index + b_save_file() + change = False + for i, checkbox in reversed(list(enumerate(checkbox_list))): + if g_index + i < len(g_data_json): + if (checkbox == True): + g_data_json.pop(g_index + i) + change = True + + g_max_json_index = len(g_data_json)-1 + if g_index > g_max_json_index: + g_index = g_max_json_index + g_index = g_index if g_index >= 0 else 0 + if change: + b_save_file() + # return gr.Slider(value=g_index, maximum=(g_max_json_index if g_max_json_index>=0 else 0)), *b_change_index(g_index, g_batch) + return {"value":g_index,"__type__":"update","maximum":(g_max_json_index if g_max_json_index>=0 else 
0)},*b_change_index(g_index, g_batch) + + +def b_invert_selection(*checkbox_list): + new_list = [not item if item is True else True for item in checkbox_list] + return new_list + + +def get_next_path(filename): + base_dir = os.path.dirname(filename) + base_name = os.path.splitext(os.path.basename(filename))[0] + for i in range(100): + new_path = os.path.join(base_dir, f"{base_name}_{str(i).zfill(2)}.wav") + if not os.path.exists(new_path) : + return new_path + return os.path.join(base_dir, f'{str(uuid.uuid4())}.wav') + + +def b_audio_split(audio_breakpoint, *checkbox_list): + global g_data_json , g_max_json_index + checked_index = [] + for i, checkbox in enumerate(checkbox_list): + if (checkbox == True and g_index+i < len(g_data_json)): + checked_index.append(g_index + i) + if len(checked_index) == 1 : + index = checked_index[0] + audio_json = copy.deepcopy(g_data_json[index]) + path = audio_json[g_json_key_path] + data, sample_rate = librosa.load(path, sr=None, mono=True) + audio_maxframe = len(data) + break_frame = int(audio_breakpoint * sample_rate) + + if (break_frame >= 1 and break_frame < audio_maxframe): + audio_first = data[0:break_frame] + audio_second = data[break_frame:] + nextpath = get_next_path(path) + soundfile.write(nextpath, audio_second, sample_rate) + soundfile.write(path, audio_first, sample_rate) + g_data_json.insert(index + 1, audio_json) + g_data_json[index + 1][g_json_key_path] = nextpath + b_save_file() + + g_max_json_index = len(g_data_json) - 1 + # return gr.Slider(value=g_index, maximum=g_max_json_index), *b_change_index(g_index, g_batch) + return {"value":g_index,"maximum":g_max_json_index,"__type__":"update"}, *b_change_index(g_index, g_batch) + +def b_merge_audio(interval_r, *checkbox_list): + global g_data_json , g_max_json_index + b_save_file() + checked_index = [] + audios_path = [] + audios_text = [] + for i, checkbox in enumerate(checkbox_list): + if (checkbox == True and g_index+i < len(g_data_json)): + checked_index.append(g_index + i) + + if (len(checked_index)>1): + for i in checked_index: + audios_path.append(g_data_json[i][g_json_key_path]) + audios_text.append(g_data_json[i][g_json_key_text]) + for i in reversed(checked_index[1:]): + g_data_json.pop(i) + + base_index = checked_index[0] + base_path = audios_path[0] + g_data_json[base_index][g_json_key_text] = "".join(audios_text) + + audio_list = [] + l_sample_rate = None + for i, path in enumerate(audios_path): + data, sample_rate = librosa.load(path, sr=l_sample_rate, mono=True) + l_sample_rate = sample_rate + if (i > 0): + silence = np.zeros(int(l_sample_rate * interval_r)) + audio_list.append(silence) + + audio_list.append(data) + + audio_concat = np.concatenate(audio_list) + + soundfile.write(base_path, audio_concat, l_sample_rate) + + b_save_file() + + g_max_json_index = len(g_data_json) - 1 + + # return gr.Slider(value=g_index, maximum=g_max_json_index), *b_change_index(g_index, g_batch) + return {"value":g_index,"maximum":g_max_json_index,"__type__":"update"}, *b_change_index(g_index, g_batch) + + +def b_save_json(): + with open(g_load_file,'w', encoding="utf-8") as file: + for data in g_data_json: + file.write(f'{json.dumps(data, ensure_ascii = False)}\n') + + +def b_save_list(): + with open(g_load_file,'w', encoding="utf-8") as file: + for data in g_data_json: + wav_path = data["wav_path"] + speaker_name = data["speaker_name"] + language = data["language"] + text = data["text"] + file.write(f"{wav_path}|{speaker_name}|{language}|{text}".strip()+'\n') + + +def b_load_json(): + global 
g_data_json, g_max_json_index + with open(g_load_file, 'r', encoding="utf-8") as file: + g_data_json = file.readlines() + g_data_json = [json.loads(line) for line in g_data_json] + g_max_json_index = len(g_data_json) - 1 + + +def b_load_list(): + global g_data_json, g_max_json_index + with open(g_load_file, 'r', encoding="utf-8") as source: + data_list = source.readlines() + for _ in data_list: + data = _.split('|') + if (len(data) == 4): + wav_path, speaker_name, language, text = data + g_data_json.append( + { + 'wav_path':wav_path, + 'speaker_name':speaker_name, + 'language':language, + 'text':text.strip() + } + ) + else: + print("error line:", data) + g_max_json_index = len(g_data_json) - 1 + + +def b_save_file(): + if g_load_format == "json": + b_save_json() + elif g_load_format == "list": + b_save_list() + + +def b_load_file(): + if g_load_format == "json": + b_load_json() + elif g_load_format == "list": + b_load_list() + + +def set_global(load_json, load_list, json_key_text, json_key_path, batch): + global g_json_key_text, g_json_key_path, g_load_file, g_load_format, g_batch + + g_batch = int(batch) + + if (load_json != "None"): + g_load_format = "json" + g_load_file = load_json + elif (load_list != "None"): + g_load_format = "list" + g_load_file = load_list + else: + g_load_format = "list" + g_load_file = "demo.list" + + g_json_key_text = json_key_text + g_json_key_path = json_key_path + + b_load_file() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Process some integers.') + parser.add_argument('--load_json', default="None", help='source file, like demo.json') + parser.add_argument('--is_share', default="False", help='whether webui is_share=True') + parser.add_argument('--load_list', default="None", help='source file, like demo.list') + parser.add_argument('--webui_port_subfix', default=9871, help='source file, like demo.list') + parser.add_argument('--json_key_text', default="text", help='the text key name in json, Default: text') + parser.add_argument('--json_key_path', default="wav_path", help='the path key name in json, Default: wav_path') + parser.add_argument('--g_batch', default=10, help='max number g_batch wav to display, Default: 10') + + args = parser.parse_args() + + set_global(args.load_json, args.load_list, args.json_key_text, args.json_key_path, args.g_batch) + + with gr.Blocks() as demo: + + with gr.Row(): + btn_change_index = gr.Button("Change Index") + btn_submit_change = gr.Button("Submit Text") + btn_merge_audio = gr.Button("Merge Audio") + btn_delete_audio = gr.Button("Delete Audio") + btn_previous_index = gr.Button("Previous Index") + btn_next_index = gr.Button("Next Index") + + with gr.Row(): + index_slider = gr.Slider( + minimum=0, maximum=g_max_json_index, value=g_index, step=1, label="Index", scale=3 + ) + splitpoint_slider = gr.Slider( + minimum=0, maximum=120.0, value=0, step=0.1, label="Audio Split Point(s)", scale=3 + ) + btn_audio_split = gr.Button("Split Audio", scale=1) + btn_save_json = gr.Button("Save File", visible=True, scale=1) + btn_invert_selection = gr.Button("Invert Selection", scale=1) + + with gr.Row(): + with gr.Column(): + for _ in range(0,g_batch): + with gr.Row(): + text = gr.Textbox( + label = "Text", + visible = True, + scale=5 + ) + audio_output = gr.Audio( + label="Output Audio", + visible = True, + scale=5 + ) + audio_check = gr.Checkbox( + label="Yes", + show_label = True, + info = "Choose Audio", + scale=1 + ) + g_text_list.append(text) + g_audio_list.append(audio_output) + 
g_checkbox_list.append(audio_check) + + + + with gr.Row(): + batchsize_slider = gr.Slider( + minimum=1, maximum=g_batch, value=g_batch, step=1, label="Batch Size", scale=3, interactive=False + ) + interval_slider = gr.Slider( + minimum=0, maximum=2, value=0, step=0.01, label="Interval", scale=3 + ) + btn_theme_dark = gr.Button("Light Theme", link="?__theme=light", scale=1) + btn_theme_light = gr.Button("Dark Theme", link="?__theme=dark", scale=1) + + btn_change_index.click( + b_change_index, + inputs=[ + index_slider, + batchsize_slider, + ], + outputs=[ + *g_text_list, + *g_audio_list, + *g_checkbox_list + ], + ) + + + btn_submit_change.click( + b_submit_change, + inputs=[ + *g_text_list, + ], + outputs=[ + index_slider, + *g_text_list, + *g_audio_list, + *g_checkbox_list + ], + ) + + btn_previous_index.click( + b_previous_index, + inputs=[ + index_slider, + batchsize_slider, + ], + outputs=[ + index_slider, + *g_text_list, + *g_audio_list, + *g_checkbox_list + ], + ) + + btn_next_index.click( + b_next_index, + inputs=[ + index_slider, + batchsize_slider, + ], + outputs=[ + index_slider, + *g_text_list, + *g_audio_list, + *g_checkbox_list + ], + ) + + btn_delete_audio.click( + b_delete_audio, + inputs=[ + *g_checkbox_list + ], + outputs=[ + index_slider, + *g_text_list, + *g_audio_list, + *g_checkbox_list + ] + ) + + btn_merge_audio.click( + b_merge_audio, + inputs=[ + interval_slider, + *g_checkbox_list + ], + outputs=[ + index_slider, + *g_text_list, + *g_audio_list, + *g_checkbox_list + ] + ) + + btn_audio_split.click( + b_audio_split, + inputs=[ + splitpoint_slider, + *g_checkbox_list + ], + outputs=[ + index_slider, + *g_text_list, + *g_audio_list, + *g_checkbox_list + ] + ) + + btn_invert_selection.click( + b_invert_selection, + inputs=[ + *g_checkbox_list + ], + outputs=[ + *g_checkbox_list + ] + ) + + btn_save_json.click( + b_save_file + ) + + demo.load( + b_change_index, + inputs=[ + index_slider, + batchsize_slider, + ], + outputs=[ + *g_text_list, + *g_audio_list, + *g_checkbox_list + ], + ) + + demo.launch( + server_name="0.0.0.0", + inbrowser=True, + quiet=True, + share=eval(args.is_share), + server_port=int(args.webui_port_subfix) + ) \ No newline at end of file diff --git a/tools/uvr5/bs_roformer/__init__.py b/tools/uvr5/bs_roformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tools/uvr5/bs_roformer/attend.py b/tools/uvr5/bs_roformer/attend.py new file mode 100644 index 0000000000000000000000000000000000000000..34476c181629652e10ca866679abbbe4868927e6 --- /dev/null +++ b/tools/uvr5/bs_roformer/attend.py @@ -0,0 +1,120 @@ +from functools import wraps +from packaging import version +from collections import namedtuple + +import torch +from torch import nn, einsum +import torch.nn.functional as F + +from einops import rearrange, reduce + +# constants + +FlashAttentionConfig = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient']) + +# helpers + +def exists(val): + return val is not None + +def default(v, d): + return v if exists(v) else d + +def once(fn): + called = False + @wraps(fn) + def inner(x): + nonlocal called + if called: + return + called = True + return fn(x) + return inner + +print_once = once(print) + +# main class + +class Attend(nn.Module): + def __init__( + self, + dropout = 0., + flash = False, + scale = None + ): + super().__init__() + self.scale = scale + self.dropout = dropout + self.attn_dropout = nn.Dropout(dropout) + + self.flash 
= flash + assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above' + + # determine efficient attention configs for cuda and cpu + + self.cpu_config = FlashAttentionConfig(True, True, True) + self.cuda_config = None + + if not torch.cuda.is_available() or not flash: + return + + device_properties = torch.cuda.get_device_properties(torch.device('cuda')) + + if device_properties.major == 8 and device_properties.minor == 0: + print_once('A100 GPU detected, using flash attention if input tensor is on cuda') + self.cuda_config = FlashAttentionConfig(True, False, False) + else: + print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda') + self.cuda_config = FlashAttentionConfig(False, True, True) + + def flash_attn(self, q, k, v): + _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device + + if exists(self.scale): + default_scale = q.shape[-1] ** -0.5 + q = q * (self.scale / default_scale) + + # Check if there is a compatible device for flash attention + + config = self.cuda_config if is_cuda else self.cpu_config + + # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale + + with torch.backends.cuda.sdp_kernel(**config._asdict()): + out = F.scaled_dot_product_attention( + q, k, v, + dropout_p = self.dropout if self.training else 0. + ) + + return out + + def forward(self, q, k, v): + """ + einstein notation + b - batch + h - heads + n, i, j - sequence length (base sequence length, source, target) + d - feature dimension + """ + + q_len, k_len, device = q.shape[-2], k.shape[-2], q.device + + scale = default(self.scale, q.shape[-1] ** -0.5) + + if self.flash: + return self.flash_attn(q, k, v) + + # similarity + + sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale + + # attention + + attn = sim.softmax(dim=-1) + attn = self.attn_dropout(attn) + + # aggregate values + + out = einsum(f"b h i j, b h j d -> b h i d", attn, v) + + return out diff --git a/tools/uvr5/bs_roformer/bs_roformer.py b/tools/uvr5/bs_roformer/bs_roformer.py new file mode 100644 index 0000000000000000000000000000000000000000..88af3caa06369f2b815fd6cea532f8ba6e974aa2 --- /dev/null +++ b/tools/uvr5/bs_roformer/bs_roformer.py @@ -0,0 +1,583 @@ +from functools import partial + +import torch +from torch import nn, einsum, Tensor +from torch.nn import Module, ModuleList +import torch.nn.functional as F + +from bs_roformer.attend import Attend + +from typing import Tuple, Optional, List, Callable +# from beartype.typing import Tuple, Optional, List, Callable +# from beartype import beartype + +from rotary_embedding_torch import RotaryEmbedding + +from einops import rearrange, pack, unpack +from einops.layers.torch import Rearrange + +# helper functions + +def exists(val): + return val is not None + + +def default(v, d): + return v if exists(v) else d + + +def pack_one(t, pattern): + return pack([t], pattern) + + +def unpack_one(t, ps, pattern): + return unpack(t, ps, pattern)[0] + + +# norm + +def l2norm(t): + return F.normalize(t, dim = -1, p = 2) + + +class RMSNorm(Module): + def __init__(self, dim): + super().__init__() + self.scale = dim ** 0.5 + self.gamma = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + return F.normalize(x, dim=-1) * self.scale * self.gamma + + +# attention + +class FeedForward(Module): + def __init__( + self, + dim, + mult=4, + dropout=0. 
+ ): + super().__init__() + dim_inner = int(dim * mult) + self.net = nn.Sequential( + RMSNorm(dim), + nn.Linear(dim, dim_inner), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(dim_inner, dim), + nn.Dropout(dropout) + ) + + def forward(self, x): + return self.net(x) + + +class Attention(Module): + def __init__( + self, + dim, + heads=8, + dim_head=64, + dropout=0., + rotary_embed=None, + flash=True + ): + super().__init__() + self.heads = heads + self.scale = dim_head ** -0.5 + dim_inner = heads * dim_head + + self.rotary_embed = rotary_embed + + self.attend = Attend(flash=flash, dropout=dropout) + + self.norm = RMSNorm(dim) + self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False) + + self.to_gates = nn.Linear(dim, heads) + + self.to_out = nn.Sequential( + nn.Linear(dim_inner, dim, bias=False), + nn.Dropout(dropout) + ) + + def forward(self, x): + x = self.norm(x) + + q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads) + + if exists(self.rotary_embed): + q = self.rotary_embed.rotate_queries_or_keys(q) + k = self.rotary_embed.rotate_queries_or_keys(k) + + out = self.attend(q, k, v) + + gates = self.to_gates(x) + out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid() + + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + + +class LinearAttention(Module): + """ + this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al. + """ + + # @beartype + def __init__( + self, + *, + dim, + dim_head=32, + heads=8, + scale=8, + flash=False, + dropout=0. + ): + super().__init__() + dim_inner = dim_head * heads + self.norm = RMSNorm(dim) + + self.to_qkv = nn.Sequential( + nn.Linear(dim, dim_inner * 3, bias=False), + Rearrange('b n (qkv h d) -> qkv b h d n', qkv=3, h=heads) + ) + + self.temperature = nn.Parameter(torch.ones(heads, 1, 1)) + + self.attend = Attend( + scale=scale, + dropout=dropout, + flash=flash + ) + + self.to_out = nn.Sequential( + Rearrange('b h d n -> b n (h d)'), + nn.Linear(dim_inner, dim, bias=False) + ) + + def forward( + self, + x + ): + x = self.norm(x) + + q, k, v = self.to_qkv(x) + + q, k = map(l2norm, (q, k)) + q = q * self.temperature.exp() + + out = self.attend(q, k, v) + + return self.to_out(out) + + +class Transformer(Module): + def __init__( + self, + *, + dim, + depth, + dim_head=64, + heads=8, + attn_dropout=0., + ff_dropout=0., + ff_mult=4, + norm_output=True, + rotary_embed=None, + flash_attn=True, + linear_attn=False + ): + super().__init__() + self.layers = ModuleList([]) + + for _ in range(depth): + if linear_attn: + attn = LinearAttention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, flash=flash_attn) + else: + attn = Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, + rotary_embed=rotary_embed, flash=flash_attn) + + self.layers.append(ModuleList([ + attn, + FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout) + ])) + + self.norm = RMSNorm(dim) if norm_output else nn.Identity() + + def forward(self, x): + + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + + return self.norm(x) + + +# bandsplit module + +class BandSplit(Module): + # @beartype + def __init__( + self, + dim, + dim_inputs: Tuple[int, ...] 
+ ): + super().__init__() + self.dim_inputs = dim_inputs + self.to_features = ModuleList([]) + + for dim_in in dim_inputs: + net = nn.Sequential( + RMSNorm(dim_in), + nn.Linear(dim_in, dim) + ) + + self.to_features.append(net) + + def forward(self, x): + x = x.split(self.dim_inputs, dim=-1) + + outs = [] + for split_input, to_feature in zip(x, self.to_features): + split_output = to_feature(split_input) + outs.append(split_output) + + return torch.stack(outs, dim=-2) + + +def MLP( + dim_in, + dim_out, + dim_hidden=None, + depth=1, + activation=nn.Tanh +): + dim_hidden = default(dim_hidden, dim_in) + + net = [] + dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out) + + for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])): + is_last = ind == (len(dims) - 2) + + net.append(nn.Linear(layer_dim_in, layer_dim_out)) + + if is_last: + continue + + net.append(activation()) + + return nn.Sequential(*net) + + +class MaskEstimator(Module): + # @beartype + def __init__( + self, + dim, + dim_inputs: Tuple[int, ...], + depth, + mlp_expansion_factor=4 + ): + super().__init__() + self.dim_inputs = dim_inputs + self.to_freqs = ModuleList([]) + dim_hidden = dim * mlp_expansion_factor + + for dim_in in dim_inputs: + net = [] + + mlp = nn.Sequential( + MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), + nn.GLU(dim=-1) + ) + + self.to_freqs.append(mlp) + + def forward(self, x): + x = x.unbind(dim=-2) + + outs = [] + + for band_features, mlp in zip(x, self.to_freqs): + freq_out = mlp(band_features) + outs.append(freq_out) + + return torch.cat(outs, dim=-1) + + +# main class + +DEFAULT_FREQS_PER_BANDS = ( + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, + 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, + 12, 12, 12, 12, 12, 12, 12, 12, + 24, 24, 24, 24, 24, 24, 24, 24, + 48, 48, 48, 48, 48, 48, 48, 48, + 128, 129, +) + + +class BSRoformer(Module): + + # @beartype + def __init__( + self, + dim, + *, + depth, + stereo=False, + num_stems=1, + time_transformer_depth=2, + freq_transformer_depth=2, + linear_transformer_depth=0, + freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS, + # in the paper, they divide into ~60 bands, test with 1 for starters + dim_head=64, + heads=8, + attn_dropout=0., + ff_dropout=0., + flash_attn=True, + dim_freqs_in=1025, + stft_n_fft=2048, + stft_hop_length=512, + # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction + stft_win_length=2048, + stft_normalized=False, + stft_window_fn: Optional[Callable] = None, + mask_estimator_depth=2, + multi_stft_resolution_loss_weight=1., + multi_stft_resolutions_window_sizes: Tuple[int, ...] 
= (4096, 2048, 1024, 512, 256), + multi_stft_hop_size=147, + multi_stft_normalized=False, + multi_stft_window_fn: Callable = torch.hann_window + ): + super().__init__() + + self.stereo = stereo + self.audio_channels = 2 if stereo else 1 + self.num_stems = num_stems + + self.layers = ModuleList([]) + + transformer_kwargs = dict( + dim=dim, + heads=heads, + dim_head=dim_head, + attn_dropout=attn_dropout, + ff_dropout=ff_dropout, + flash_attn=flash_attn, + norm_output=False + ) + + time_rotary_embed = RotaryEmbedding(dim=dim_head) + freq_rotary_embed = RotaryEmbedding(dim=dim_head) + + for _ in range(depth): + tran_modules = [] + if linear_transformer_depth > 0: + tran_modules.append(Transformer(depth=linear_transformer_depth, linear_attn=True, **transformer_kwargs)) + tran_modules.append( + Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs) + ) + tran_modules.append( + Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs) + ) + self.layers.append(nn.ModuleList(tran_modules)) + + self.final_norm = RMSNorm(dim) + + self.stft_kwargs = dict( + n_fft=stft_n_fft, + hop_length=stft_hop_length, + win_length=stft_win_length, + normalized=stft_normalized + ) + + self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length) + + freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, return_complex=True).shape[1] + + assert len(freqs_per_bands) > 1 + assert sum( + freqs_per_bands) == freqs, f'the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}' + + freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in freqs_per_bands) + + self.band_split = BandSplit( + dim=dim, + dim_inputs=freqs_per_bands_with_complex + ) + + self.mask_estimators = nn.ModuleList([]) + + for _ in range(num_stems): + mask_estimator = MaskEstimator( + dim=dim, + dim_inputs=freqs_per_bands_with_complex, + depth=mask_estimator_depth + ) + + self.mask_estimators.append(mask_estimator) + + # for the multi-resolution stft loss + + self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight + self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes + self.multi_stft_n_fft = stft_n_fft + self.multi_stft_window_fn = multi_stft_window_fn + + self.multi_stft_kwargs = dict( + hop_length=multi_stft_hop_size, + normalized=multi_stft_normalized + ) + + def forward( + self, + raw_audio, + target=None, + return_loss_breakdown=False + ): + """ + einops + + b - batch + f - freq + t - time + s - audio channel (1 for mono, 2 for stereo) + n - number of 'stems' + c - complex (2) + d - feature dimension + """ + + device = raw_audio.device + + if raw_audio.ndim == 2: + raw_audio = rearrange(raw_audio, 'b t -> b 1 t') + + channels = raw_audio.shape[1] + assert (not self.stereo and channels == 1) or ( + self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). 
also need to be False if mono (channel dimension of 1)' + + # to stft + + raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t') + + stft_window = self.stft_window_fn(device=device) + + stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True) + stft_repr = torch.view_as_real(stft_repr) + + stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c') + stft_repr = rearrange(stft_repr, + 'b s f t c -> b (f s) t c') # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting + + x = rearrange(stft_repr, 'b f t c -> b t (f c)') + # print("460:", x.dtype)#fp32 + x = self.band_split(x) + + # axial / hierarchical attention + + # print("487:",x.dtype)#fp16 + for transformer_block in self.layers: + + if len(transformer_block) == 3: + linear_transformer, time_transformer, freq_transformer = transformer_block + + x, ft_ps = pack([x], 'b * d') + # print("494:", x.dtype)#fp16 + x = linear_transformer(x) + # print("496:", x.dtype)#fp16 + x, = unpack(x, ft_ps, 'b * d') + else: + time_transformer, freq_transformer = transformer_block + + # print("501:", x.dtype)#fp16 + x = rearrange(x, 'b t f d -> b f t d') + x, ps = pack([x], '* t d') + + x = time_transformer(x) + # print("505:", x.dtype)#fp16 + x, = unpack(x, ps, '* t d') + x = rearrange(x, 'b f t d -> b t f d') + x, ps = pack([x], '* f d') + + x = freq_transformer(x) + + x, = unpack(x, ps, '* f d') + + # print("515:", x.dtype)######fp16 + x = self.final_norm(x) + + num_stems = len(self.mask_estimators) + # print("519:", x.dtype)#fp32 + mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1) + mask = rearrange(mask, 'b n t (f c) -> b n f t c', c=2) + + # modulate frequency representation + + stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c') + + # complex number multiplication + + stft_repr = torch.view_as_complex(stft_repr) + mask = torch.view_as_complex(mask) + + stft_repr = stft_repr * mask + + # istft + + stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels) + + recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False) + + recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', s=self.audio_channels, n=num_stems) + + if num_stems == 1: + recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t') + + # if a target is passed in, calculate loss for learning + + if not exists(target): + return recon_audio + + if self.num_stems > 1: + assert target.ndim == 4 and target.shape[1] == self.num_stems + + if target.ndim == 2: + target = rearrange(target, '... t -> ... 1 t') + + target = target[..., :recon_audio.shape[-1]] # protect against lost length on istft + + loss = F.l1_loss(recon_audio, target) + + multi_stft_resolution_loss = 0. + + for window_size in self.multi_stft_resolutions_window_sizes: + res_stft_kwargs = dict( + n_fft=max(window_size, self.multi_stft_n_fft), # not sure what n_fft is across multi resolution stft + win_length=window_size, + return_complex=True, + window=self.multi_stft_window_fn(window_size, device=device), + **self.multi_stft_kwargs, + ) + + recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs) + target_Y = torch.stft(rearrange(target, '... s t -> (... 
s) t'), **res_stft_kwargs) + + multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y) + + weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight + + total_loss = loss + weighted_multi_resolution_loss + + if not return_loss_breakdown: + return total_loss + + return total_loss, (loss, multi_stft_resolution_loss) \ No newline at end of file diff --git a/tools/uvr5/bsroformer.py b/tools/uvr5/bsroformer.py new file mode 100644 index 0000000000000000000000000000000000000000..d162032686d8b9478010bd4e0bd154f56200069e --- /dev/null +++ b/tools/uvr5/bsroformer.py @@ -0,0 +1,216 @@ +# This code is modified from https://github.com/ZFTurbo/ +import pdb + +import librosa +from tqdm import tqdm +import os +import torch +import numpy as np +import soundfile as sf +import torch.nn as nn + +import warnings +warnings.filterwarnings("ignore") +from bs_roformer.bs_roformer import BSRoformer + +class BsRoformer_Loader: + def get_model_from_config(self): + config = { + "attn_dropout": 0.1, + "depth": 12, + "dim": 512, + "dim_freqs_in": 1025, + "dim_head": 64, + "ff_dropout": 0.1, + "flash_attn": True, + "freq_transformer_depth": 1, + "freqs_per_bands":(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 12, 12, 12, 12, 12, 12, 12, 12, 24, 24, 24, 24, 24, 24, 24, 24, 48, 48, 48, 48, 48, 48, 48, 48, 128, 129), + "heads": 8, + "linear_transformer_depth": 0, + "mask_estimator_depth": 2, + "multi_stft_hop_size": 147, + "multi_stft_normalized": False, + "multi_stft_resolution_loss_weight": 1.0, + "multi_stft_resolutions_window_sizes":(4096, 2048, 1024, 512, 256), + "num_stems": 1, + "stereo": True, + "stft_hop_length": 441, + "stft_n_fft": 2048, + "stft_normalized": False, + "stft_win_length": 2048, + "time_transformer_depth": 1, + + } + + + model = BSRoformer( + **dict(config) + ) + + return model + + + def demix_track(self, model, mix, device): + C = 352800 + # num_overlap + N = 1 + fade_size = C // 10 + step = int(C // N) + border = C - step + batch_size = 4 + + length_init = mix.shape[-1] + + progress_bar = tqdm(total=length_init // step + 1) + progress_bar.set_description("Processing") + + # Do pad from the beginning and end to account floating window results better + if length_init > 2 * border and (border > 0): + mix = nn.functional.pad(mix, (border, border), mode='reflect') + + # Prepare windows arrays (do 1 time for speed up). 
This trick repairs click problems on the edges of segment + window_size = C + fadein = torch.linspace(0, 1, fade_size) + fadeout = torch.linspace(1, 0, fade_size) + window_start = torch.ones(window_size) + window_middle = torch.ones(window_size) + window_finish = torch.ones(window_size) + window_start[-fade_size:] *= fadeout # First audio chunk, no fadein + window_finish[:fade_size] *= fadein # Last audio chunk, no fadeout + window_middle[-fade_size:] *= fadeout + window_middle[:fade_size] *= fadein + + with torch.amp.autocast('cuda'): + with torch.inference_mode(): + req_shape = (1, ) + tuple(mix.shape) + + result = torch.zeros(req_shape, dtype=torch.float32) + counter = torch.zeros(req_shape, dtype=torch.float32) + i = 0 + batch_data = [] + batch_locations = [] + while i < mix.shape[1]: + part = mix[:, i:i + C].to(device) + length = part.shape[-1] + if length < C: + if length > C // 2 + 1: + part = nn.functional.pad(input=part, pad=(0, C - length), mode='reflect') + else: + part = nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0) + if(self.is_half==True): + part=part.half() + batch_data.append(part) + batch_locations.append((i, length)) + i += step + progress_bar.update(1) + + if len(batch_data) >= batch_size or (i >= mix.shape[1]): + arr = torch.stack(batch_data, dim=0) + # print(23333333,arr.dtype) + x = model(arr) + + window = window_middle + if i - step == 0: # First audio chunk, no fadein + window = window_start + elif i >= mix.shape[1]: # Last audio chunk, no fadeout + window = window_finish + + for j in range(len(batch_locations)): + start, l = batch_locations[j] + result[..., start:start+l] += x[j][..., :l].cpu() * window[..., :l] + counter[..., start:start+l] += window[..., :l] + + batch_data = [] + batch_locations = [] + + estimated_sources = result / counter + estimated_sources = estimated_sources.cpu().numpy() + np.nan_to_num(estimated_sources, copy=False, nan=0.0) + + if length_init > 2 * border and (border > 0): + # Remove pad + estimated_sources = estimated_sources[..., border:-border] + + progress_bar.close() + + return {k: v for k, v in zip(['vocals', 'other'], estimated_sources)} + + + def run_folder(self,input, vocal_root, others_root, format): + # start_time = time.time() + self.model.eval() + path = input + + if not os.path.isdir(vocal_root): + os.mkdir(vocal_root) + + if not os.path.isdir(others_root): + os.mkdir(others_root) + + try: + mix, sr = librosa.load(path, sr=44100, mono=False) + except Exception as e: + print('Can read track: {}'.format(path)) + print('Error message: {}'.format(str(e))) + return + + # Convert mono to stereo if needed + if len(mix.shape) == 1: + mix = np.stack([mix, mix], axis=0) + + mix_orig = mix.copy() + + mixture = torch.tensor(mix, dtype=torch.float32) + res = self.demix_track(self.model, mixture, self.device) + + estimates = res['vocals'].T + + if format in ["wav", "flac"]: + sf.write("{}/{}_{}.{}".format(vocal_root, os.path.basename(path)[:-4], 'vocals', format), estimates, sr) + sf.write("{}/{}_{}.{}".format(others_root, os.path.basename(path)[:-4], 'instrumental', format), mix_orig.T - estimates, sr) + else: + path_vocal = "%s/%s_vocals.wav" % (vocal_root, os.path.basename(path)[:-4]) + path_other = "%s/%s_instrumental.wav" % (others_root, os.path.basename(path)[:-4]) + sf.write(path_vocal, estimates, sr) + sf.write(path_other, mix_orig.T - estimates, sr) + opt_path_vocal = path_vocal[:-4] + ".%s" % format + opt_path_other = path_other[:-4] + ".%s" % format + if os.path.exists(path_vocal): + 
os.system( + "ffmpeg -i '%s' -vn '%s' -q:a 2 -y" % (path_vocal, opt_path_vocal) + ) + if os.path.exists(opt_path_vocal): + try: + os.remove(path_vocal) + except: + pass + if os.path.exists(path_other): + os.system( + "ffmpeg -i '%s' -vn '%s' -q:a 2 -y" % (path_other, opt_path_other) + ) + if os.path.exists(opt_path_other): + try: + os.remove(path_other) + except: + pass + + # print("Elapsed time: {:.2f} sec".format(time.time() - start_time)) + + + def __init__(self, model_path, device,is_half): + self.device = device + self.extract_instrumental=True + + model = self.get_model_from_config() + state_dict = torch.load(model_path,map_location="cpu") + model.load_state_dict(state_dict) + self.is_half=is_half + if(is_half==False): + self.model = model.to(device) + else: + self.model = model.half().to(device) + + + def _path_audio_(self, input, others_root, vocal_root, format, is_hp3=False): + self.run_folder(input, vocal_root, others_root, format) + diff --git a/tools/uvr5/lib/lib_v5/dataset.py b/tools/uvr5/lib/lib_v5/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..cfd01a174978d97180a897e40cb59ecadec1d12e --- /dev/null +++ b/tools/uvr5/lib/lib_v5/dataset.py @@ -0,0 +1,183 @@ +import os +import random + +import numpy as np +import torch +import torch.utils.data +from tqdm import tqdm + +from . import spec_utils + + +class VocalRemoverValidationSet(torch.utils.data.Dataset): + def __init__(self, patch_list): + self.patch_list = patch_list + + def __len__(self): + return len(self.patch_list) + + def __getitem__(self, idx): + path = self.patch_list[idx] + data = np.load(path) + + X, y = data["X"], data["y"] + + X_mag = np.abs(X) + y_mag = np.abs(y) + + return X_mag, y_mag + + +def make_pair(mix_dir, inst_dir): + input_exts = [".wav", ".m4a", ".mp3", ".mp4", ".flac"] + + X_list = sorted( + [ + os.path.join(mix_dir, fname) + for fname in os.listdir(mix_dir) + if os.path.splitext(fname)[1] in input_exts + ] + ) + y_list = sorted( + [ + os.path.join(inst_dir, fname) + for fname in os.listdir(inst_dir) + if os.path.splitext(fname)[1] in input_exts + ] + ) + + filelist = list(zip(X_list, y_list)) + + return filelist + + +def train_val_split(dataset_dir, split_mode, val_rate, val_filelist): + if split_mode == "random": + filelist = make_pair( + os.path.join(dataset_dir, "mixtures"), + os.path.join(dataset_dir, "instruments"), + ) + + random.shuffle(filelist) + + if len(val_filelist) == 0: + val_size = int(len(filelist) * val_rate) + train_filelist = filelist[:-val_size] + val_filelist = filelist[-val_size:] + else: + train_filelist = [ + pair for pair in filelist if list(pair) not in val_filelist + ] + elif split_mode == "subdirs": + if len(val_filelist) != 0: + raise ValueError( + "The `val_filelist` option is not available in `subdirs` mode" + ) + + train_filelist = make_pair( + os.path.join(dataset_dir, "training/mixtures"), + os.path.join(dataset_dir, "training/instruments"), + ) + + val_filelist = make_pair( + os.path.join(dataset_dir, "validation/mixtures"), + os.path.join(dataset_dir, "validation/instruments"), + ) + + return train_filelist, val_filelist + + +def augment(X, y, reduction_rate, reduction_mask, mixup_rate, mixup_alpha): + perm = np.random.permutation(len(X)) + for i, idx in enumerate(tqdm(perm)): + if np.random.uniform() < reduction_rate: + y[idx] = spec_utils.reduce_vocal_aggressively( + X[idx], y[idx], reduction_mask + ) + + if np.random.uniform() < 0.5: + # swap channel + X[idx] = X[idx, ::-1] + y[idx] = y[idx, ::-1] + if np.random.uniform() < 0.02: + 
# mono + X[idx] = X[idx].mean(axis=0, keepdims=True) + y[idx] = y[idx].mean(axis=0, keepdims=True) + if np.random.uniform() < 0.02: + # inst + X[idx] = y[idx] + + if np.random.uniform() < mixup_rate and i < len(perm) - 1: + lam = np.random.beta(mixup_alpha, mixup_alpha) + X[idx] = lam * X[idx] + (1 - lam) * X[perm[i + 1]] + y[idx] = lam * y[idx] + (1 - lam) * y[perm[i + 1]] + + return X, y + + +def make_padding(width, cropsize, offset): + left = offset + roi_size = cropsize - left * 2 + if roi_size == 0: + roi_size = cropsize + right = roi_size - (width % roi_size) + left + + return left, right, roi_size + + +def make_training_set(filelist, cropsize, patches, sr, hop_length, n_fft, offset): + len_dataset = patches * len(filelist) + + X_dataset = np.zeros((len_dataset, 2, n_fft // 2 + 1, cropsize), dtype=np.complex64) + y_dataset = np.zeros((len_dataset, 2, n_fft // 2 + 1, cropsize), dtype=np.complex64) + + for i, (X_path, y_path) in enumerate(tqdm(filelist)): + X, y = spec_utils.cache_or_load(X_path, y_path, sr, hop_length, n_fft) + coef = np.max([np.abs(X).max(), np.abs(y).max()]) + X, y = X / coef, y / coef + + l, r, roi_size = make_padding(X.shape[2], cropsize, offset) + X_pad = np.pad(X, ((0, 0), (0, 0), (l, r)), mode="constant") + y_pad = np.pad(y, ((0, 0), (0, 0), (l, r)), mode="constant") + + starts = np.random.randint(0, X_pad.shape[2] - cropsize, patches) + ends = starts + cropsize + for j in range(patches): + idx = i * patches + j + X_dataset[idx] = X_pad[:, :, starts[j] : ends[j]] + y_dataset[idx] = y_pad[:, :, starts[j] : ends[j]] + + return X_dataset, y_dataset + + +def make_validation_set(filelist, cropsize, sr, hop_length, n_fft, offset): + patch_list = [] + patch_dir = "cs{}_sr{}_hl{}_nf{}_of{}".format( + cropsize, sr, hop_length, n_fft, offset + ) + os.makedirs(patch_dir, exist_ok=True) + + for i, (X_path, y_path) in enumerate(tqdm(filelist)): + basename = os.path.splitext(os.path.basename(X_path))[0] + + X, y = spec_utils.cache_or_load(X_path, y_path, sr, hop_length, n_fft) + coef = np.max([np.abs(X).max(), np.abs(y).max()]) + X, y = X / coef, y / coef + + l, r, roi_size = make_padding(X.shape[2], cropsize, offset) + X_pad = np.pad(X, ((0, 0), (0, 0), (l, r)), mode="constant") + y_pad = np.pad(y, ((0, 0), (0, 0), (l, r)), mode="constant") + + len_dataset = int(np.ceil(X.shape[2] / roi_size)) + for j in range(len_dataset): + outpath = os.path.join(patch_dir, "{}_p{}.npz".format(basename, j)) + start = j * roi_size + if not os.path.exists(outpath): + np.savez( + outpath, + X=X_pad[:, :, start : start + cropsize], + y=y_pad[:, :, start : start + cropsize], + ) + patch_list.append(outpath) + + return VocalRemoverValidationSet(patch_list) diff --git a/tools/uvr5/lib/lib_v5/layers.py b/tools/uvr5/lib/lib_v5/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..4fc1b5cb85a3327f60cbb9f5deffbeeaaac516ad --- /dev/null +++ b/tools/uvr5/lib/lib_v5/layers.py @@ -0,0 +1,118 @@ +import torch +import torch.nn.functional as F +from torch import nn + +from . 
import spec_utils + + +class Conv2DBNActiv(nn.Module): + def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU): + super(Conv2DBNActiv, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d( + nin, + nout, + kernel_size=ksize, + stride=stride, + padding=pad, + dilation=dilation, + bias=False, + ), + nn.BatchNorm2d(nout), + activ(), + ) + + def __call__(self, x): + return self.conv(x) + + +class SeperableConv2DBNActiv(nn.Module): + def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU): + super(SeperableConv2DBNActiv, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d( + nin, + nin, + kernel_size=ksize, + stride=stride, + padding=pad, + dilation=dilation, + groups=nin, + bias=False, + ), + nn.Conv2d(nin, nout, kernel_size=1, bias=False), + nn.BatchNorm2d(nout), + activ(), + ) + + def __call__(self, x): + return self.conv(x) + + +class Encoder(nn.Module): + def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.LeakyReLU): + super(Encoder, self).__init__() + self.conv1 = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ) + self.conv2 = Conv2DBNActiv(nout, nout, ksize, stride, pad, activ=activ) + + def __call__(self, x): + skip = self.conv1(x) + h = self.conv2(skip) + + return h, skip + + +class Decoder(nn.Module): + def __init__( + self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False + ): + super(Decoder, self).__init__() + self.conv = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ) + self.dropout = nn.Dropout2d(0.1) if dropout else None + + def __call__(self, x, skip=None): + x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True) + if skip is not None: + skip = spec_utils.crop_center(skip, x) + x = torch.cat([x, skip], dim=1) + h = self.conv(x) + + if self.dropout is not None: + h = self.dropout(h) + + return h + + +class ASPPModule(nn.Module): + def __init__(self, nin, nout, dilations=(4, 8, 16), activ=nn.ReLU): + super(ASPPModule, self).__init__() + self.conv1 = nn.Sequential( + nn.AdaptiveAvgPool2d((1, None)), + Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ), + ) + self.conv2 = Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ) + self.conv3 = SeperableConv2DBNActiv( + nin, nin, 3, 1, dilations[0], dilations[0], activ=activ + ) + self.conv4 = SeperableConv2DBNActiv( + nin, nin, 3, 1, dilations[1], dilations[1], activ=activ + ) + self.conv5 = SeperableConv2DBNActiv( + nin, nin, 3, 1, dilations[2], dilations[2], activ=activ + ) + self.bottleneck = nn.Sequential( + Conv2DBNActiv(nin * 5, nout, 1, 1, 0, activ=activ), nn.Dropout2d(0.1) + ) + + def forward(self, x): + _, _, h, w = x.size() + feat1 = F.interpolate( + self.conv1(x), size=(h, w), mode="bilinear", align_corners=True + ) + feat2 = self.conv2(x) + feat3 = self.conv3(x) + feat4 = self.conv4(x) + feat5 = self.conv5(x) + out = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1) + bottle = self.bottleneck(out) + return bottle diff --git a/tools/uvr5/lib/lib_v5/layers_123812KB.py b/tools/uvr5/lib/lib_v5/layers_123812KB.py new file mode 100644 index 0000000000000000000000000000000000000000..4fc1b5cb85a3327f60cbb9f5deffbeeaaac516ad --- /dev/null +++ b/tools/uvr5/lib/lib_v5/layers_123812KB.py @@ -0,0 +1,118 @@ +import torch +import torch.nn.functional as F +from torch import nn + +from . 
import spec_utils + + +class Conv2DBNActiv(nn.Module): + def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU): + super(Conv2DBNActiv, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d( + nin, + nout, + kernel_size=ksize, + stride=stride, + padding=pad, + dilation=dilation, + bias=False, + ), + nn.BatchNorm2d(nout), + activ(), + ) + + def __call__(self, x): + return self.conv(x) + + +class SeperableConv2DBNActiv(nn.Module): + def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU): + super(SeperableConv2DBNActiv, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d( + nin, + nin, + kernel_size=ksize, + stride=stride, + padding=pad, + dilation=dilation, + groups=nin, + bias=False, + ), + nn.Conv2d(nin, nout, kernel_size=1, bias=False), + nn.BatchNorm2d(nout), + activ(), + ) + + def __call__(self, x): + return self.conv(x) + + +class Encoder(nn.Module): + def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.LeakyReLU): + super(Encoder, self).__init__() + self.conv1 = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ) + self.conv2 = Conv2DBNActiv(nout, nout, ksize, stride, pad, activ=activ) + + def __call__(self, x): + skip = self.conv1(x) + h = self.conv2(skip) + + return h, skip + + +class Decoder(nn.Module): + def __init__( + self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False + ): + super(Decoder, self).__init__() + self.conv = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ) + self.dropout = nn.Dropout2d(0.1) if dropout else None + + def __call__(self, x, skip=None): + x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True) + if skip is not None: + skip = spec_utils.crop_center(skip, x) + x = torch.cat([x, skip], dim=1) + h = self.conv(x) + + if self.dropout is not None: + h = self.dropout(h) + + return h + + +class ASPPModule(nn.Module): + def __init__(self, nin, nout, dilations=(4, 8, 16), activ=nn.ReLU): + super(ASPPModule, self).__init__() + self.conv1 = nn.Sequential( + nn.AdaptiveAvgPool2d((1, None)), + Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ), + ) + self.conv2 = Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ) + self.conv3 = SeperableConv2DBNActiv( + nin, nin, 3, 1, dilations[0], dilations[0], activ=activ + ) + self.conv4 = SeperableConv2DBNActiv( + nin, nin, 3, 1, dilations[1], dilations[1], activ=activ + ) + self.conv5 = SeperableConv2DBNActiv( + nin, nin, 3, 1, dilations[2], dilations[2], activ=activ + ) + self.bottleneck = nn.Sequential( + Conv2DBNActiv(nin * 5, nout, 1, 1, 0, activ=activ), nn.Dropout2d(0.1) + ) + + def forward(self, x): + _, _, h, w = x.size() + feat1 = F.interpolate( + self.conv1(x), size=(h, w), mode="bilinear", align_corners=True + ) + feat2 = self.conv2(x) + feat3 = self.conv3(x) + feat4 = self.conv4(x) + feat5 = self.conv5(x) + out = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1) + bottle = self.bottleneck(out) + return bottle diff --git a/tools/uvr5/lib/lib_v5/layers_123821KB.py b/tools/uvr5/lib/lib_v5/layers_123821KB.py new file mode 100644 index 0000000000000000000000000000000000000000..4fc1b5cb85a3327f60cbb9f5deffbeeaaac516ad --- /dev/null +++ b/tools/uvr5/lib/lib_v5/layers_123821KB.py @@ -0,0 +1,118 @@ +import torch +import torch.nn.functional as F +from torch import nn + +from . 
import spec_utils + + +class Conv2DBNActiv(nn.Module): + def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU): + super(Conv2DBNActiv, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d( + nin, + nout, + kernel_size=ksize, + stride=stride, + padding=pad, + dilation=dilation, + bias=False, + ), + nn.BatchNorm2d(nout), + activ(), + ) + + def __call__(self, x): + return self.conv(x) + + +class SeperableConv2DBNActiv(nn.Module): + def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU): + super(SeperableConv2DBNActiv, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d( + nin, + nin, + kernel_size=ksize, + stride=stride, + padding=pad, + dilation=dilation, + groups=nin, + bias=False, + ), + nn.Conv2d(nin, nout, kernel_size=1, bias=False), + nn.BatchNorm2d(nout), + activ(), + ) + + def __call__(self, x): + return self.conv(x) + + +class Encoder(nn.Module): + def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.LeakyReLU): + super(Encoder, self).__init__() + self.conv1 = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ) + self.conv2 = Conv2DBNActiv(nout, nout, ksize, stride, pad, activ=activ) + + def __call__(self, x): + skip = self.conv1(x) + h = self.conv2(skip) + + return h, skip + + +class Decoder(nn.Module): + def __init__( + self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False + ): + super(Decoder, self).__init__() + self.conv = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ) + self.dropout = nn.Dropout2d(0.1) if dropout else None + + def __call__(self, x, skip=None): + x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True) + if skip is not None: + skip = spec_utils.crop_center(skip, x) + x = torch.cat([x, skip], dim=1) + h = self.conv(x) + + if self.dropout is not None: + h = self.dropout(h) + + return h + + +class ASPPModule(nn.Module): + def __init__(self, nin, nout, dilations=(4, 8, 16), activ=nn.ReLU): + super(ASPPModule, self).__init__() + self.conv1 = nn.Sequential( + nn.AdaptiveAvgPool2d((1, None)), + Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ), + ) + self.conv2 = Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ) + self.conv3 = SeperableConv2DBNActiv( + nin, nin, 3, 1, dilations[0], dilations[0], activ=activ + ) + self.conv4 = SeperableConv2DBNActiv( + nin, nin, 3, 1, dilations[1], dilations[1], activ=activ + ) + self.conv5 = SeperableConv2DBNActiv( + nin, nin, 3, 1, dilations[2], dilations[2], activ=activ + ) + self.bottleneck = nn.Sequential( + Conv2DBNActiv(nin * 5, nout, 1, 1, 0, activ=activ), nn.Dropout2d(0.1) + ) + + def forward(self, x): + _, _, h, w = x.size() + feat1 = F.interpolate( + self.conv1(x), size=(h, w), mode="bilinear", align_corners=True + ) + feat2 = self.conv2(x) + feat3 = self.conv3(x) + feat4 = self.conv4(x) + feat5 = self.conv5(x) + out = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1) + bottle = self.bottleneck(out) + return bottle diff --git a/tools/uvr5/lib/lib_v5/layers_33966KB.py b/tools/uvr5/lib/lib_v5/layers_33966KB.py new file mode 100644 index 0000000000000000000000000000000000000000..9b127bc6427f5c60c8cf85603a3d8a093c3501c4 --- /dev/null +++ b/tools/uvr5/lib/lib_v5/layers_33966KB.py @@ -0,0 +1,126 @@ +import torch +import torch.nn.functional as F +from torch import nn + +from . 
import spec_utils + + +class Conv2DBNActiv(nn.Module): + def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU): + super(Conv2DBNActiv, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d( + nin, + nout, + kernel_size=ksize, + stride=stride, + padding=pad, + dilation=dilation, + bias=False, + ), + nn.BatchNorm2d(nout), + activ(), + ) + + def __call__(self, x): + return self.conv(x) + + +class SeperableConv2DBNActiv(nn.Module): + def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU): + super(SeperableConv2DBNActiv, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d( + nin, + nin, + kernel_size=ksize, + stride=stride, + padding=pad, + dilation=dilation, + groups=nin, + bias=False, + ), + nn.Conv2d(nin, nout, kernel_size=1, bias=False), + nn.BatchNorm2d(nout), + activ(), + ) + + def __call__(self, x): + return self.conv(x) + + +class Encoder(nn.Module): + def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.LeakyReLU): + super(Encoder, self).__init__() + self.conv1 = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ) + self.conv2 = Conv2DBNActiv(nout, nout, ksize, stride, pad, activ=activ) + + def __call__(self, x): + skip = self.conv1(x) + h = self.conv2(skip) + + return h, skip + + +class Decoder(nn.Module): + def __init__( + self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False + ): + super(Decoder, self).__init__() + self.conv = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ) + self.dropout = nn.Dropout2d(0.1) if dropout else None + + def __call__(self, x, skip=None): + x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True) + if skip is not None: + skip = spec_utils.crop_center(skip, x) + x = torch.cat([x, skip], dim=1) + h = self.conv(x) + + if self.dropout is not None: + h = self.dropout(h) + + return h + + +class ASPPModule(nn.Module): + def __init__(self, nin, nout, dilations=(4, 8, 16, 32, 64), activ=nn.ReLU): + super(ASPPModule, self).__init__() + self.conv1 = nn.Sequential( + nn.AdaptiveAvgPool2d((1, None)), + Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ), + ) + self.conv2 = Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ) + self.conv3 = SeperableConv2DBNActiv( + nin, nin, 3, 1, dilations[0], dilations[0], activ=activ + ) + self.conv4 = SeperableConv2DBNActiv( + nin, nin, 3, 1, dilations[1], dilations[1], activ=activ + ) + self.conv5 = SeperableConv2DBNActiv( + nin, nin, 3, 1, dilations[2], dilations[2], activ=activ + ) + self.conv6 = SeperableConv2DBNActiv( + nin, nin, 3, 1, dilations[2], dilations[2], activ=activ + ) + self.conv7 = SeperableConv2DBNActiv( + nin, nin, 3, 1, dilations[2], dilations[2], activ=activ + ) + self.bottleneck = nn.Sequential( + Conv2DBNActiv(nin * 7, nout, 1, 1, 0, activ=activ), nn.Dropout2d(0.1) + ) + + def forward(self, x): + _, _, h, w = x.size() + feat1 = F.interpolate( + self.conv1(x), size=(h, w), mode="bilinear", align_corners=True + ) + feat2 = self.conv2(x) + feat3 = self.conv3(x) + feat4 = self.conv4(x) + feat5 = self.conv5(x) + feat6 = self.conv6(x) + feat7 = self.conv7(x) + out = torch.cat((feat1, feat2, feat3, feat4, feat5, feat6, feat7), dim=1) + bottle = self.bottleneck(out) + return bottle diff --git a/tools/uvr5/lib/lib_v5/layers_537227KB.py b/tools/uvr5/lib/lib_v5/layers_537227KB.py new file mode 100644 index 0000000000000000000000000000000000000000..9b127bc6427f5c60c8cf85603a3d8a093c3501c4 --- /dev/null +++ b/tools/uvr5/lib/lib_v5/layers_537227KB.py @@ -0,0 +1,126 @@ +import torch +import torch.nn.functional as F 
+from torch import nn + +from . import spec_utils + + +class Conv2DBNActiv(nn.Module): + def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU): + super(Conv2DBNActiv, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d( + nin, + nout, + kernel_size=ksize, + stride=stride, + padding=pad, + dilation=dilation, + bias=False, + ), + nn.BatchNorm2d(nout), + activ(), + ) + + def __call__(self, x): + return self.conv(x) + + +class SeperableConv2DBNActiv(nn.Module): + def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU): + super(SeperableConv2DBNActiv, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d( + nin, + nin, + kernel_size=ksize, + stride=stride, + padding=pad, + dilation=dilation, + groups=nin, + bias=False, + ), + nn.Conv2d(nin, nout, kernel_size=1, bias=False), + nn.BatchNorm2d(nout), + activ(), + ) + + def __call__(self, x): + return self.conv(x) + + +class Encoder(nn.Module): + def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.LeakyReLU): + super(Encoder, self).__init__() + self.conv1 = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ) + self.conv2 = Conv2DBNActiv(nout, nout, ksize, stride, pad, activ=activ) + + def __call__(self, x): + skip = self.conv1(x) + h = self.conv2(skip) + + return h, skip + + +class Decoder(nn.Module): + def __init__( + self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False + ): + super(Decoder, self).__init__() + self.conv = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ) + self.dropout = nn.Dropout2d(0.1) if dropout else None + + def __call__(self, x, skip=None): + x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True) + if skip is not None: + skip = spec_utils.crop_center(skip, x) + x = torch.cat([x, skip], dim=1) + h = self.conv(x) + + if self.dropout is not None: + h = self.dropout(h) + + return h + + +class ASPPModule(nn.Module): + def __init__(self, nin, nout, dilations=(4, 8, 16, 32, 64), activ=nn.ReLU): + super(ASPPModule, self).__init__() + self.conv1 = nn.Sequential( + nn.AdaptiveAvgPool2d((1, None)), + Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ), + ) + self.conv2 = Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ) + self.conv3 = SeperableConv2DBNActiv( + nin, nin, 3, 1, dilations[0], dilations[0], activ=activ + ) + self.conv4 = SeperableConv2DBNActiv( + nin, nin, 3, 1, dilations[1], dilations[1], activ=activ + ) + self.conv5 = SeperableConv2DBNActiv( + nin, nin, 3, 1, dilations[2], dilations[2], activ=activ + ) + self.conv6 = SeperableConv2DBNActiv( + nin, nin, 3, 1, dilations[2], dilations[2], activ=activ + ) + self.conv7 = SeperableConv2DBNActiv( + nin, nin, 3, 1, dilations[2], dilations[2], activ=activ + ) + self.bottleneck = nn.Sequential( + Conv2DBNActiv(nin * 7, nout, 1, 1, 0, activ=activ), nn.Dropout2d(0.1) + ) + + def forward(self, x): + _, _, h, w = x.size() + feat1 = F.interpolate( + self.conv1(x), size=(h, w), mode="bilinear", align_corners=True + ) + feat2 = self.conv2(x) + feat3 = self.conv3(x) + feat4 = self.conv4(x) + feat5 = self.conv5(x) + feat6 = self.conv6(x) + feat7 = self.conv7(x) + out = torch.cat((feat1, feat2, feat3, feat4, feat5, feat6, feat7), dim=1) + bottle = self.bottleneck(out) + return bottle diff --git a/tools/uvr5/lib/lib_v5/layers_537238KB.py b/tools/uvr5/lib/lib_v5/layers_537238KB.py new file mode 100644 index 0000000000000000000000000000000000000000..9b127bc6427f5c60c8cf85603a3d8a093c3501c4 --- /dev/null +++ b/tools/uvr5/lib/lib_v5/layers_537238KB.py @@ -0,0 +1,126 @@ +import torch 
+import torch.nn.functional as F +from torch import nn + +from . import spec_utils + + +class Conv2DBNActiv(nn.Module): + def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU): + super(Conv2DBNActiv, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d( + nin, + nout, + kernel_size=ksize, + stride=stride, + padding=pad, + dilation=dilation, + bias=False, + ), + nn.BatchNorm2d(nout), + activ(), + ) + + def __call__(self, x): + return self.conv(x) + + +class SeperableConv2DBNActiv(nn.Module): + def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU): + super(SeperableConv2DBNActiv, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d( + nin, + nin, + kernel_size=ksize, + stride=stride, + padding=pad, + dilation=dilation, + groups=nin, + bias=False, + ), + nn.Conv2d(nin, nout, kernel_size=1, bias=False), + nn.BatchNorm2d(nout), + activ(), + ) + + def __call__(self, x): + return self.conv(x) + + +class Encoder(nn.Module): + def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.LeakyReLU): + super(Encoder, self).__init__() + self.conv1 = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ) + self.conv2 = Conv2DBNActiv(nout, nout, ksize, stride, pad, activ=activ) + + def __call__(self, x): + skip = self.conv1(x) + h = self.conv2(skip) + + return h, skip + + +class Decoder(nn.Module): + def __init__( + self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False + ): + super(Decoder, self).__init__() + self.conv = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ) + self.dropout = nn.Dropout2d(0.1) if dropout else None + + def __call__(self, x, skip=None): + x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True) + if skip is not None: + skip = spec_utils.crop_center(skip, x) + x = torch.cat([x, skip], dim=1) + h = self.conv(x) + + if self.dropout is not None: + h = self.dropout(h) + + return h + + +class ASPPModule(nn.Module): + def __init__(self, nin, nout, dilations=(4, 8, 16, 32, 64), activ=nn.ReLU): + super(ASPPModule, self).__init__() + self.conv1 = nn.Sequential( + nn.AdaptiveAvgPool2d((1, None)), + Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ), + ) + self.conv2 = Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ) + self.conv3 = SeperableConv2DBNActiv( + nin, nin, 3, 1, dilations[0], dilations[0], activ=activ + ) + self.conv4 = SeperableConv2DBNActiv( + nin, nin, 3, 1, dilations[1], dilations[1], activ=activ + ) + self.conv5 = SeperableConv2DBNActiv( + nin, nin, 3, 1, dilations[2], dilations[2], activ=activ + ) + self.conv6 = SeperableConv2DBNActiv( + nin, nin, 3, 1, dilations[2], dilations[2], activ=activ + ) + self.conv7 = SeperableConv2DBNActiv( + nin, nin, 3, 1, dilations[2], dilations[2], activ=activ + ) + self.bottleneck = nn.Sequential( + Conv2DBNActiv(nin * 7, nout, 1, 1, 0, activ=activ), nn.Dropout2d(0.1) + ) + + def forward(self, x): + _, _, h, w = x.size() + feat1 = F.interpolate( + self.conv1(x), size=(h, w), mode="bilinear", align_corners=True + ) + feat2 = self.conv2(x) + feat3 = self.conv3(x) + feat4 = self.conv4(x) + feat5 = self.conv5(x) + feat6 = self.conv6(x) + feat7 = self.conv7(x) + out = torch.cat((feat1, feat2, feat3, feat4, feat5, feat6, feat7), dim=1) + bottle = self.bottleneck(out) + return bottle diff --git a/tools/uvr5/lib/lib_v5/layers_new.py b/tools/uvr5/lib/lib_v5/layers_new.py new file mode 100644 index 0000000000000000000000000000000000000000..44153b6a23399c6938affc61c71919eaa172bcee --- /dev/null +++ b/tools/uvr5/lib/lib_v5/layers_new.py @@ -0,0 +1,125 @@ 
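Not part of the diff itself — a minimal shape-check sketch of how the building blocks repeated across these layers_* modules (Conv2DBNActiv, Encoder, Decoder, ASPPModule) are meant to compose. It assumes the GPT-SoVITS repo root is on sys.path so tools.uvr5.lib.lib_v5 resolves as a package (the module's relative import of spec_utils, added elsewhere in this diff, must succeed), and the tensor sizes are arbitrary illustrative choices, not values used by the released models.

import torch

# Illustrative only: run one encoder block, the ASPP pyramid, and one decoder
# block from layers_123821KB on a dummy 2-channel magnitude spectrogram.
from tools.uvr5.lib.lib_v5 import layers_123821KB as layers

x = torch.randn(1, 2, 256, 64)            # (batch, channels, freq_bins, frames)

enc = layers.Encoder(2, 32, 3, 2, 1)      # conv1 keeps size; stride-2 conv2 halves it
h, skip = enc(x)                          # h: (1, 32, 128, 32), skip: (1, 32, 256, 64)

aspp = layers.ASPPModule(32, 64)          # dilated convs preserve spatial size
h = aspp(h)                               # (1, 64, 128, 32)

dec = layers.Decoder(64 + 32, 32, 3, 1, 1)
out = dec(h, skip)                        # upsample x2, crop/concat skip, 3x3 conv
print(out.shape)                          # torch.Size([1, 32, 256, 64])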
+import torch +import torch.nn.functional as F +from torch import nn + +from . import spec_utils + + +class Conv2DBNActiv(nn.Module): + def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU): + super(Conv2DBNActiv, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d( + nin, + nout, + kernel_size=ksize, + stride=stride, + padding=pad, + dilation=dilation, + bias=False, + ), + nn.BatchNorm2d(nout), + activ(), + ) + + def __call__(self, x): + return self.conv(x) + + +class Encoder(nn.Module): + def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.LeakyReLU): + super(Encoder, self).__init__() + self.conv1 = Conv2DBNActiv(nin, nout, ksize, stride, pad, activ=activ) + self.conv2 = Conv2DBNActiv(nout, nout, ksize, 1, pad, activ=activ) + + def __call__(self, x): + h = self.conv1(x) + h = self.conv2(h) + + return h + + +class Decoder(nn.Module): + def __init__( + self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False + ): + super(Decoder, self).__init__() + self.conv1 = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ) + # self.conv2 = Conv2DBNActiv(nout, nout, ksize, 1, pad, activ=activ) + self.dropout = nn.Dropout2d(0.1) if dropout else None + + def __call__(self, x, skip=None): + x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True) + + if skip is not None: + skip = spec_utils.crop_center(skip, x) + x = torch.cat([x, skip], dim=1) + + h = self.conv1(x) + # h = self.conv2(h) + + if self.dropout is not None: + h = self.dropout(h) + + return h + + +class ASPPModule(nn.Module): + def __init__(self, nin, nout, dilations=(4, 8, 12), activ=nn.ReLU, dropout=False): + super(ASPPModule, self).__init__() + self.conv1 = nn.Sequential( + nn.AdaptiveAvgPool2d((1, None)), + Conv2DBNActiv(nin, nout, 1, 1, 0, activ=activ), + ) + self.conv2 = Conv2DBNActiv(nin, nout, 1, 1, 0, activ=activ) + self.conv3 = Conv2DBNActiv( + nin, nout, 3, 1, dilations[0], dilations[0], activ=activ + ) + self.conv4 = Conv2DBNActiv( + nin, nout, 3, 1, dilations[1], dilations[1], activ=activ + ) + self.conv5 = Conv2DBNActiv( + nin, nout, 3, 1, dilations[2], dilations[2], activ=activ + ) + self.bottleneck = Conv2DBNActiv(nout * 5, nout, 1, 1, 0, activ=activ) + self.dropout = nn.Dropout2d(0.1) if dropout else None + + def forward(self, x): + _, _, h, w = x.size() + feat1 = F.interpolate( + self.conv1(x), size=(h, w), mode="bilinear", align_corners=True + ) + feat2 = self.conv2(x) + feat3 = self.conv3(x) + feat4 = self.conv4(x) + feat5 = self.conv5(x) + out = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1) + out = self.bottleneck(out) + + if self.dropout is not None: + out = self.dropout(out) + + return out + + +class LSTMModule(nn.Module): + def __init__(self, nin_conv, nin_lstm, nout_lstm): + super(LSTMModule, self).__init__() + self.conv = Conv2DBNActiv(nin_conv, 1, 1, 1, 0) + self.lstm = nn.LSTM( + input_size=nin_lstm, hidden_size=nout_lstm // 2, bidirectional=True + ) + self.dense = nn.Sequential( + nn.Linear(nout_lstm, nin_lstm), nn.BatchNorm1d(nin_lstm), nn.ReLU() + ) + + def forward(self, x): + N, _, nbins, nframes = x.size() + h = self.conv(x)[:, 0] # N, nbins, nframes + h = h.permute(2, 0, 1) # nframes, N, nbins + h, _ = self.lstm(h) + h = self.dense(h.reshape(-1, h.size()[-1])) # nframes * N, nbins + h = h.reshape(nframes, N, 1, nbins) + h = h.permute(1, 2, 3, 0) + + return h diff --git a/tools/uvr5/lib/lib_v5/model_param_init.py b/tools/uvr5/lib/lib_v5/model_param_init.py new file mode 100644 index 
0000000000000000000000000000000000000000..b995c0bfb1194746187692e2ab1c2a6dbaaaec6c --- /dev/null +++ b/tools/uvr5/lib/lib_v5/model_param_init.py @@ -0,0 +1,69 @@ +import json +import os +import pathlib + +default_param = {} +default_param["bins"] = 768 +default_param["unstable_bins"] = 9 # training only +default_param["reduction_bins"] = 762 # training only +default_param["sr"] = 44100 +default_param["pre_filter_start"] = 757 +default_param["pre_filter_stop"] = 768 +default_param["band"] = {} + + +default_param["band"][1] = { + "sr": 11025, + "hl": 128, + "n_fft": 960, + "crop_start": 0, + "crop_stop": 245, + "lpf_start": 61, # inference only + "res_type": "polyphase", +} + +default_param["band"][2] = { + "sr": 44100, + "hl": 512, + "n_fft": 1536, + "crop_start": 24, + "crop_stop": 547, + "hpf_start": 81, # inference only + "res_type": "sinc_best", +} + + +def int_keys(d): + r = {} + for k, v in d: + if k.isdigit(): + k = int(k) + r[k] = v + return r + + +class ModelParameters(object): + def __init__(self, config_path=""): + if ".pth" == pathlib.Path(config_path).suffix: + import zipfile + + with zipfile.ZipFile(config_path, "r") as zip: + self.param = json.loads( + zip.read("param.json"), object_pairs_hook=int_keys + ) + elif ".json" == pathlib.Path(config_path).suffix: + with open(config_path, "r") as f: + self.param = json.loads(f.read(), object_pairs_hook=int_keys) + else: + self.param = default_param + + for k in [ + "mid_side", + "mid_side_b", + "mid_side_b2", + "stereo_w", + "stereo_n", + "reverse", + ]: + if not k in self.param: + self.param[k] = False diff --git a/tools/uvr5/lib/lib_v5/modelparams/1band_sr16000_hl512.json b/tools/uvr5/lib/lib_v5/modelparams/1band_sr16000_hl512.json new file mode 100644 index 0000000000000000000000000000000000000000..72cb4499867ad2827185e85687f06fb73d33eced --- /dev/null +++ b/tools/uvr5/lib/lib_v5/modelparams/1band_sr16000_hl512.json @@ -0,0 +1,19 @@ +{ + "bins": 1024, + "unstable_bins": 0, + "reduction_bins": 0, + "band": { + "1": { + "sr": 16000, + "hl": 512, + "n_fft": 2048, + "crop_start": 0, + "crop_stop": 1024, + "hpf_start": -1, + "res_type": "sinc_best" + } + }, + "sr": 16000, + "pre_filter_start": 1023, + "pre_filter_stop": 1024 +} \ No newline at end of file diff --git a/tools/uvr5/lib/lib_v5/modelparams/1band_sr32000_hl512.json b/tools/uvr5/lib/lib_v5/modelparams/1band_sr32000_hl512.json new file mode 100644 index 0000000000000000000000000000000000000000..3c00ecf0a105e55a6a86a3c32db301a2635b5b41 --- /dev/null +++ b/tools/uvr5/lib/lib_v5/modelparams/1band_sr32000_hl512.json @@ -0,0 +1,19 @@ +{ + "bins": 1024, + "unstable_bins": 0, + "reduction_bins": 0, + "band": { + "1": { + "sr": 32000, + "hl": 512, + "n_fft": 2048, + "crop_start": 0, + "crop_stop": 1024, + "hpf_start": -1, + "res_type": "kaiser_fast" + } + }, + "sr": 32000, + "pre_filter_start": 1000, + "pre_filter_stop": 1021 +} \ No newline at end of file diff --git a/tools/uvr5/lib/lib_v5/modelparams/1band_sr33075_hl384.json b/tools/uvr5/lib/lib_v5/modelparams/1band_sr33075_hl384.json new file mode 100644 index 0000000000000000000000000000000000000000..55666ac9a8d0547751fb4b4d3bffb1ee2c956913 --- /dev/null +++ b/tools/uvr5/lib/lib_v5/modelparams/1band_sr33075_hl384.json @@ -0,0 +1,19 @@ +{ + "bins": 1024, + "unstable_bins": 0, + "reduction_bins": 0, + "band": { + "1": { + "sr": 33075, + "hl": 384, + "n_fft": 2048, + "crop_start": 0, + "crop_stop": 1024, + "hpf_start": -1, + "res_type": "sinc_best" + } + }, + "sr": 33075, + "pre_filter_start": 1000, + "pre_filter_stop": 1021 +} \ No 
newline at end of file diff --git a/tools/uvr5/lib/lib_v5/modelparams/1band_sr44100_hl1024.json b/tools/uvr5/lib/lib_v5/modelparams/1band_sr44100_hl1024.json new file mode 100644 index 0000000000000000000000000000000000000000..665abe20eb3cc39fe0f8493dad8f25f6ef634a14 --- /dev/null +++ b/tools/uvr5/lib/lib_v5/modelparams/1band_sr44100_hl1024.json @@ -0,0 +1,19 @@ +{ + "bins": 1024, + "unstable_bins": 0, + "reduction_bins": 0, + "band": { + "1": { + "sr": 44100, + "hl": 1024, + "n_fft": 2048, + "crop_start": 0, + "crop_stop": 1024, + "hpf_start": -1, + "res_type": "sinc_best" + } + }, + "sr": 44100, + "pre_filter_start": 1023, + "pre_filter_stop": 1024 +} \ No newline at end of file diff --git a/tools/uvr5/lib/lib_v5/modelparams/1band_sr44100_hl256.json b/tools/uvr5/lib/lib_v5/modelparams/1band_sr44100_hl256.json new file mode 100644 index 0000000000000000000000000000000000000000..0e8b16f89b0231d06eabe8d2f7c2670c7caa2272 --- /dev/null +++ b/tools/uvr5/lib/lib_v5/modelparams/1band_sr44100_hl256.json @@ -0,0 +1,19 @@ +{ + "bins": 256, + "unstable_bins": 0, + "reduction_bins": 0, + "band": { + "1": { + "sr": 44100, + "hl": 256, + "n_fft": 512, + "crop_start": 0, + "crop_stop": 256, + "hpf_start": -1, + "res_type": "sinc_best" + } + }, + "sr": 44100, + "pre_filter_start": 256, + "pre_filter_stop": 256 +} \ No newline at end of file diff --git a/tools/uvr5/lib/lib_v5/modelparams/1band_sr44100_hl512.json b/tools/uvr5/lib/lib_v5/modelparams/1band_sr44100_hl512.json new file mode 100644 index 0000000000000000000000000000000000000000..3b38fcaf60ba204e03a47f5bd3f5bcfe75e1983a --- /dev/null +++ b/tools/uvr5/lib/lib_v5/modelparams/1band_sr44100_hl512.json @@ -0,0 +1,19 @@ +{ + "bins": 1024, + "unstable_bins": 0, + "reduction_bins": 0, + "band": { + "1": { + "sr": 44100, + "hl": 512, + "n_fft": 2048, + "crop_start": 0, + "crop_stop": 1024, + "hpf_start": -1, + "res_type": "sinc_best" + } + }, + "sr": 44100, + "pre_filter_start": 1023, + "pre_filter_stop": 1024 +} \ No newline at end of file diff --git a/tools/uvr5/lib/lib_v5/modelparams/1band_sr44100_hl512_cut.json b/tools/uvr5/lib/lib_v5/modelparams/1band_sr44100_hl512_cut.json new file mode 100644 index 0000000000000000000000000000000000000000..630df3524e340f43a1ddb7b33ff02cc91fc1cb47 --- /dev/null +++ b/tools/uvr5/lib/lib_v5/modelparams/1band_sr44100_hl512_cut.json @@ -0,0 +1,19 @@ +{ + "bins": 1024, + "unstable_bins": 0, + "reduction_bins": 0, + "band": { + "1": { + "sr": 44100, + "hl": 512, + "n_fft": 2048, + "crop_start": 0, + "crop_stop": 700, + "hpf_start": -1, + "res_type": "sinc_best" + } + }, + "sr": 44100, + "pre_filter_start": 1023, + "pre_filter_stop": 700 +} \ No newline at end of file diff --git a/tools/uvr5/lib/lib_v5/modelparams/2band_32000.json b/tools/uvr5/lib/lib_v5/modelparams/2band_32000.json new file mode 100644 index 0000000000000000000000000000000000000000..ab9cf1150a818eb6252105408311be0a40d423b3 --- /dev/null +++ b/tools/uvr5/lib/lib_v5/modelparams/2band_32000.json @@ -0,0 +1,30 @@ +{ + "bins": 768, + "unstable_bins": 7, + "reduction_bins": 705, + "band": { + "1": { + "sr": 6000, + "hl": 66, + "n_fft": 512, + "crop_start": 0, + "crop_stop": 240, + "lpf_start": 60, + "lpf_stop": 118, + "res_type": "sinc_fastest" + }, + "2": { + "sr": 32000, + "hl": 352, + "n_fft": 1024, + "crop_start": 22, + "crop_stop": 505, + "hpf_start": 44, + "hpf_stop": 23, + "res_type": "sinc_medium" + } + }, + "sr": 32000, + "pre_filter_start": 710, + "pre_filter_stop": 731 +} diff --git a/tools/uvr5/lib/lib_v5/modelparams/2band_44100_lofi.json 
b/tools/uvr5/lib/lib_v5/modelparams/2band_44100_lofi.json new file mode 100644 index 0000000000000000000000000000000000000000..7faa216d7b49aeece24123dbdd868847a1dbc03c --- /dev/null +++ b/tools/uvr5/lib/lib_v5/modelparams/2band_44100_lofi.json @@ -0,0 +1,30 @@ +{ + "bins": 512, + "unstable_bins": 7, + "reduction_bins": 510, + "band": { + "1": { + "sr": 11025, + "hl": 160, + "n_fft": 768, + "crop_start": 0, + "crop_stop": 192, + "lpf_start": 41, + "lpf_stop": 139, + "res_type": "sinc_fastest" + }, + "2": { + "sr": 44100, + "hl": 640, + "n_fft": 1024, + "crop_start": 10, + "crop_stop": 320, + "hpf_start": 47, + "hpf_stop": 15, + "res_type": "sinc_medium" + } + }, + "sr": 44100, + "pre_filter_start": 510, + "pre_filter_stop": 512 +} diff --git a/tools/uvr5/lib/lib_v5/modelparams/2band_48000.json b/tools/uvr5/lib/lib_v5/modelparams/2band_48000.json new file mode 100644 index 0000000000000000000000000000000000000000..7e78175052b09cb1a32345e54006475992712f9a --- /dev/null +++ b/tools/uvr5/lib/lib_v5/modelparams/2band_48000.json @@ -0,0 +1,30 @@ +{ + "bins": 768, + "unstable_bins": 7, + "reduction_bins": 705, + "band": { + "1": { + "sr": 6000, + "hl": 66, + "n_fft": 512, + "crop_start": 0, + "crop_stop": 240, + "lpf_start": 60, + "lpf_stop": 240, + "res_type": "sinc_fastest" + }, + "2": { + "sr": 48000, + "hl": 528, + "n_fft": 1536, + "crop_start": 22, + "crop_stop": 505, + "hpf_start": 82, + "hpf_stop": 22, + "res_type": "sinc_medium" + } + }, + "sr": 48000, + "pre_filter_start": 710, + "pre_filter_stop": 731 +} \ No newline at end of file diff --git a/tools/uvr5/lib/lib_v5/modelparams/3band_44100.json b/tools/uvr5/lib/lib_v5/modelparams/3band_44100.json new file mode 100644 index 0000000000000000000000000000000000000000..d881d767ff83fbac0e18dfe2587ef16925b29b3c --- /dev/null +++ b/tools/uvr5/lib/lib_v5/modelparams/3band_44100.json @@ -0,0 +1,42 @@ +{ + "bins": 768, + "unstable_bins": 5, + "reduction_bins": 733, + "band": { + "1": { + "sr": 11025, + "hl": 128, + "n_fft": 768, + "crop_start": 0, + "crop_stop": 278, + "lpf_start": 28, + "lpf_stop": 140, + "res_type": "polyphase" + }, + "2": { + "sr": 22050, + "hl": 256, + "n_fft": 768, + "crop_start": 14, + "crop_stop": 322, + "hpf_start": 70, + "hpf_stop": 14, + "lpf_start": 283, + "lpf_stop": 314, + "res_type": "polyphase" + }, + "3": { + "sr": 44100, + "hl": 512, + "n_fft": 768, + "crop_start": 131, + "crop_stop": 313, + "hpf_start": 154, + "hpf_stop": 141, + "res_type": "sinc_medium" + } + }, + "sr": 44100, + "pre_filter_start": 757, + "pre_filter_stop": 768 +} diff --git a/tools/uvr5/lib/lib_v5/modelparams/3band_44100_mid.json b/tools/uvr5/lib/lib_v5/modelparams/3band_44100_mid.json new file mode 100644 index 0000000000000000000000000000000000000000..77ec198573b19f36519a028a509767d30764c0e2 --- /dev/null +++ b/tools/uvr5/lib/lib_v5/modelparams/3band_44100_mid.json @@ -0,0 +1,43 @@ +{ + "mid_side": true, + "bins": 768, + "unstable_bins": 5, + "reduction_bins": 733, + "band": { + "1": { + "sr": 11025, + "hl": 128, + "n_fft": 768, + "crop_start": 0, + "crop_stop": 278, + "lpf_start": 28, + "lpf_stop": 140, + "res_type": "polyphase" + }, + "2": { + "sr": 22050, + "hl": 256, + "n_fft": 768, + "crop_start": 14, + "crop_stop": 322, + "hpf_start": 70, + "hpf_stop": 14, + "lpf_start": 283, + "lpf_stop": 314, + "res_type": "polyphase" + }, + "3": { + "sr": 44100, + "hl": 512, + "n_fft": 768, + "crop_start": 131, + "crop_stop": 313, + "hpf_start": 154, + "hpf_stop": 141, + "res_type": "sinc_medium" + } + }, + "sr": 44100, + "pre_filter_start": 757, + 
"pre_filter_stop": 768 +} diff --git a/tools/uvr5/lib/lib_v5/modelparams/3band_44100_msb2.json b/tools/uvr5/lib/lib_v5/modelparams/3band_44100_msb2.json new file mode 100644 index 0000000000000000000000000000000000000000..85ee8a7d44541c9176e85ea3dce8728d34990938 --- /dev/null +++ b/tools/uvr5/lib/lib_v5/modelparams/3band_44100_msb2.json @@ -0,0 +1,43 @@ +{ + "mid_side_b2": true, + "bins": 640, + "unstable_bins": 7, + "reduction_bins": 565, + "band": { + "1": { + "sr": 11025, + "hl": 108, + "n_fft": 1024, + "crop_start": 0, + "crop_stop": 187, + "lpf_start": 92, + "lpf_stop": 186, + "res_type": "polyphase" + }, + "2": { + "sr": 22050, + "hl": 216, + "n_fft": 768, + "crop_start": 0, + "crop_stop": 212, + "hpf_start": 68, + "hpf_stop": 34, + "lpf_start": 174, + "lpf_stop": 209, + "res_type": "polyphase" + }, + "3": { + "sr": 44100, + "hl": 432, + "n_fft": 640, + "crop_start": 66, + "crop_stop": 307, + "hpf_start": 86, + "hpf_stop": 72, + "res_type": "kaiser_fast" + } + }, + "sr": 44100, + "pre_filter_start": 639, + "pre_filter_stop": 640 +} diff --git a/tools/uvr5/lib/lib_v5/modelparams/4band_44100.json b/tools/uvr5/lib/lib_v5/modelparams/4band_44100.json new file mode 100644 index 0000000000000000000000000000000000000000..df123754204372aa50d464fbe9102a401f48cc73 --- /dev/null +++ b/tools/uvr5/lib/lib_v5/modelparams/4band_44100.json @@ -0,0 +1,54 @@ +{ + "bins": 768, + "unstable_bins": 7, + "reduction_bins": 668, + "band": { + "1": { + "sr": 11025, + "hl": 128, + "n_fft": 1024, + "crop_start": 0, + "crop_stop": 186, + "lpf_start": 37, + "lpf_stop": 73, + "res_type": "polyphase" + }, + "2": { + "sr": 11025, + "hl": 128, + "n_fft": 512, + "crop_start": 4, + "crop_stop": 185, + "hpf_start": 36, + "hpf_stop": 18, + "lpf_start": 93, + "lpf_stop": 185, + "res_type": "polyphase" + }, + "3": { + "sr": 22050, + "hl": 256, + "n_fft": 512, + "crop_start": 46, + "crop_stop": 186, + "hpf_start": 93, + "hpf_stop": 46, + "lpf_start": 164, + "lpf_stop": 186, + "res_type": "polyphase" + }, + "4": { + "sr": 44100, + "hl": 512, + "n_fft": 768, + "crop_start": 121, + "crop_stop": 382, + "hpf_start": 138, + "hpf_stop": 123, + "res_type": "sinc_medium" + } + }, + "sr": 44100, + "pre_filter_start": 740, + "pre_filter_stop": 768 +} diff --git a/tools/uvr5/lib/lib_v5/modelparams/4band_44100_mid.json b/tools/uvr5/lib/lib_v5/modelparams/4band_44100_mid.json new file mode 100644 index 0000000000000000000000000000000000000000..e91b699eb63d3382c3b9e9edf46d40ed91d6122b --- /dev/null +++ b/tools/uvr5/lib/lib_v5/modelparams/4band_44100_mid.json @@ -0,0 +1,55 @@ +{ + "bins": 768, + "unstable_bins": 7, + "mid_side": true, + "reduction_bins": 668, + "band": { + "1": { + "sr": 11025, + "hl": 128, + "n_fft": 1024, + "crop_start": 0, + "crop_stop": 186, + "lpf_start": 37, + "lpf_stop": 73, + "res_type": "polyphase" + }, + "2": { + "sr": 11025, + "hl": 128, + "n_fft": 512, + "crop_start": 4, + "crop_stop": 185, + "hpf_start": 36, + "hpf_stop": 18, + "lpf_start": 93, + "lpf_stop": 185, + "res_type": "polyphase" + }, + "3": { + "sr": 22050, + "hl": 256, + "n_fft": 512, + "crop_start": 46, + "crop_stop": 186, + "hpf_start": 93, + "hpf_stop": 46, + "lpf_start": 164, + "lpf_stop": 186, + "res_type": "polyphase" + }, + "4": { + "sr": 44100, + "hl": 512, + "n_fft": 768, + "crop_start": 121, + "crop_stop": 382, + "hpf_start": 138, + "hpf_stop": 123, + "res_type": "sinc_medium" + } + }, + "sr": 44100, + "pre_filter_start": 740, + "pre_filter_stop": 768 +} diff --git a/tools/uvr5/lib/lib_v5/modelparams/4band_44100_msb.json 
b/tools/uvr5/lib/lib_v5/modelparams/4band_44100_msb.json new file mode 100644 index 0000000000000000000000000000000000000000..f852f280ec9d98fc1b65cec688290eaafec61b84 --- /dev/null +++ b/tools/uvr5/lib/lib_v5/modelparams/4band_44100_msb.json @@ -0,0 +1,55 @@ +{ + "mid_side_b": true, + "bins": 768, + "unstable_bins": 7, + "reduction_bins": 668, + "band": { + "1": { + "sr": 11025, + "hl": 128, + "n_fft": 1024, + "crop_start": 0, + "crop_stop": 186, + "lpf_start": 37, + "lpf_stop": 73, + "res_type": "polyphase" + }, + "2": { + "sr": 11025, + "hl": 128, + "n_fft": 512, + "crop_start": 4, + "crop_stop": 185, + "hpf_start": 36, + "hpf_stop": 18, + "lpf_start": 93, + "lpf_stop": 185, + "res_type": "polyphase" + }, + "3": { + "sr": 22050, + "hl": 256, + "n_fft": 512, + "crop_start": 46, + "crop_stop": 186, + "hpf_start": 93, + "hpf_stop": 46, + "lpf_start": 164, + "lpf_stop": 186, + "res_type": "polyphase" + }, + "4": { + "sr": 44100, + "hl": 512, + "n_fft": 768, + "crop_start": 121, + "crop_stop": 382, + "hpf_start": 138, + "hpf_stop": 123, + "res_type": "sinc_medium" + } + }, + "sr": 44100, + "pre_filter_start": 740, + "pre_filter_stop": 768 +} \ No newline at end of file diff --git a/tools/uvr5/lib/lib_v5/modelparams/4band_44100_msb2.json b/tools/uvr5/lib/lib_v5/modelparams/4band_44100_msb2.json new file mode 100644 index 0000000000000000000000000000000000000000..f852f280ec9d98fc1b65cec688290eaafec61b84 --- /dev/null +++ b/tools/uvr5/lib/lib_v5/modelparams/4band_44100_msb2.json @@ -0,0 +1,55 @@ +{ + "mid_side_b": true, + "bins": 768, + "unstable_bins": 7, + "reduction_bins": 668, + "band": { + "1": { + "sr": 11025, + "hl": 128, + "n_fft": 1024, + "crop_start": 0, + "crop_stop": 186, + "lpf_start": 37, + "lpf_stop": 73, + "res_type": "polyphase" + }, + "2": { + "sr": 11025, + "hl": 128, + "n_fft": 512, + "crop_start": 4, + "crop_stop": 185, + "hpf_start": 36, + "hpf_stop": 18, + "lpf_start": 93, + "lpf_stop": 185, + "res_type": "polyphase" + }, + "3": { + "sr": 22050, + "hl": 256, + "n_fft": 512, + "crop_start": 46, + "crop_stop": 186, + "hpf_start": 93, + "hpf_stop": 46, + "lpf_start": 164, + "lpf_stop": 186, + "res_type": "polyphase" + }, + "4": { + "sr": 44100, + "hl": 512, + "n_fft": 768, + "crop_start": 121, + "crop_stop": 382, + "hpf_start": 138, + "hpf_stop": 123, + "res_type": "sinc_medium" + } + }, + "sr": 44100, + "pre_filter_start": 740, + "pre_filter_stop": 768 +} \ No newline at end of file diff --git a/tools/uvr5/lib/lib_v5/modelparams/4band_44100_reverse.json b/tools/uvr5/lib/lib_v5/modelparams/4band_44100_reverse.json new file mode 100644 index 0000000000000000000000000000000000000000..7a07d5541bd83dc1caa20b531c3b43a2ffccac88 --- /dev/null +++ b/tools/uvr5/lib/lib_v5/modelparams/4band_44100_reverse.json @@ -0,0 +1,55 @@ +{ + "reverse": true, + "bins": 768, + "unstable_bins": 7, + "reduction_bins": 668, + "band": { + "1": { + "sr": 11025, + "hl": 128, + "n_fft": 1024, + "crop_start": 0, + "crop_stop": 186, + "lpf_start": 37, + "lpf_stop": 73, + "res_type": "polyphase" + }, + "2": { + "sr": 11025, + "hl": 128, + "n_fft": 512, + "crop_start": 4, + "crop_stop": 185, + "hpf_start": 36, + "hpf_stop": 18, + "lpf_start": 93, + "lpf_stop": 185, + "res_type": "polyphase" + }, + "3": { + "sr": 22050, + "hl": 256, + "n_fft": 512, + "crop_start": 46, + "crop_stop": 186, + "hpf_start": 93, + "hpf_stop": 46, + "lpf_start": 164, + "lpf_stop": 186, + "res_type": "polyphase" + }, + "4": { + "sr": 44100, + "hl": 512, + "n_fft": 768, + "crop_start": 121, + "crop_stop": 382, + "hpf_start": 138, + 
"hpf_stop": 123, + "res_type": "sinc_medium" + } + }, + "sr": 44100, + "pre_filter_start": 740, + "pre_filter_stop": 768 +} \ No newline at end of file diff --git a/tools/uvr5/lib/lib_v5/modelparams/4band_44100_sw.json b/tools/uvr5/lib/lib_v5/modelparams/4band_44100_sw.json new file mode 100644 index 0000000000000000000000000000000000000000..ba0cf342106de793e6ec3e876854c7fd451fbf76 --- /dev/null +++ b/tools/uvr5/lib/lib_v5/modelparams/4band_44100_sw.json @@ -0,0 +1,55 @@ +{ + "stereo_w": true, + "bins": 768, + "unstable_bins": 7, + "reduction_bins": 668, + "band": { + "1": { + "sr": 11025, + "hl": 128, + "n_fft": 1024, + "crop_start": 0, + "crop_stop": 186, + "lpf_start": 37, + "lpf_stop": 73, + "res_type": "polyphase" + }, + "2": { + "sr": 11025, + "hl": 128, + "n_fft": 512, + "crop_start": 4, + "crop_stop": 185, + "hpf_start": 36, + "hpf_stop": 18, + "lpf_start": 93, + "lpf_stop": 185, + "res_type": "polyphase" + }, + "3": { + "sr": 22050, + "hl": 256, + "n_fft": 512, + "crop_start": 46, + "crop_stop": 186, + "hpf_start": 93, + "hpf_stop": 46, + "lpf_start": 164, + "lpf_stop": 186, + "res_type": "polyphase" + }, + "4": { + "sr": 44100, + "hl": 512, + "n_fft": 768, + "crop_start": 121, + "crop_stop": 382, + "hpf_start": 138, + "hpf_stop": 123, + "res_type": "sinc_medium" + } + }, + "sr": 44100, + "pre_filter_start": 740, + "pre_filter_stop": 768 +} \ No newline at end of file diff --git a/tools/uvr5/lib/lib_v5/modelparams/4band_v2.json b/tools/uvr5/lib/lib_v5/modelparams/4band_v2.json new file mode 100644 index 0000000000000000000000000000000000000000..33281a0cf9916fc33558ddfda7a0287a2547faf4 --- /dev/null +++ b/tools/uvr5/lib/lib_v5/modelparams/4band_v2.json @@ -0,0 +1,54 @@ +{ + "bins": 672, + "unstable_bins": 8, + "reduction_bins": 637, + "band": { + "1": { + "sr": 7350, + "hl": 80, + "n_fft": 640, + "crop_start": 0, + "crop_stop": 85, + "lpf_start": 25, + "lpf_stop": 53, + "res_type": "polyphase" + }, + "2": { + "sr": 7350, + "hl": 80, + "n_fft": 320, + "crop_start": 4, + "crop_stop": 87, + "hpf_start": 25, + "hpf_stop": 12, + "lpf_start": 31, + "lpf_stop": 62, + "res_type": "polyphase" + }, + "3": { + "sr": 14700, + "hl": 160, + "n_fft": 512, + "crop_start": 17, + "crop_stop": 216, + "hpf_start": 48, + "hpf_stop": 24, + "lpf_start": 139, + "lpf_stop": 210, + "res_type": "polyphase" + }, + "4": { + "sr": 44100, + "hl": 480, + "n_fft": 960, + "crop_start": 78, + "crop_stop": 383, + "hpf_start": 130, + "hpf_stop": 86, + "res_type": "kaiser_fast" + } + }, + "sr": 44100, + "pre_filter_start": 668, + "pre_filter_stop": 672 +} \ No newline at end of file diff --git a/tools/uvr5/lib/lib_v5/modelparams/4band_v2_sn.json b/tools/uvr5/lib/lib_v5/modelparams/4band_v2_sn.json new file mode 100644 index 0000000000000000000000000000000000000000..2e5c770fe188779bf6b0873190b7a324d6a867b2 --- /dev/null +++ b/tools/uvr5/lib/lib_v5/modelparams/4band_v2_sn.json @@ -0,0 +1,55 @@ +{ + "bins": 672, + "unstable_bins": 8, + "reduction_bins": 637, + "band": { + "1": { + "sr": 7350, + "hl": 80, + "n_fft": 640, + "crop_start": 0, + "crop_stop": 85, + "lpf_start": 25, + "lpf_stop": 53, + "res_type": "polyphase" + }, + "2": { + "sr": 7350, + "hl": 80, + "n_fft": 320, + "crop_start": 4, + "crop_stop": 87, + "hpf_start": 25, + "hpf_stop": 12, + "lpf_start": 31, + "lpf_stop": 62, + "res_type": "polyphase" + }, + "3": { + "sr": 14700, + "hl": 160, + "n_fft": 512, + "crop_start": 17, + "crop_stop": 216, + "hpf_start": 48, + "hpf_stop": 24, + "lpf_start": 139, + "lpf_stop": 210, + "res_type": "polyphase" + }, + "4": { + 
"sr": 44100, + "hl": 480, + "n_fft": 960, + "crop_start": 78, + "crop_stop": 383, + "hpf_start": 130, + "hpf_stop": 86, + "convert_channels": "stereo_n", + "res_type": "kaiser_fast" + } + }, + "sr": 44100, + "pre_filter_start": 668, + "pre_filter_stop": 672 +} \ No newline at end of file diff --git a/tools/uvr5/lib/lib_v5/modelparams/4band_v3.json b/tools/uvr5/lib/lib_v5/modelparams/4band_v3.json new file mode 100644 index 0000000000000000000000000000000000000000..2a73bc97ac545145a75bdca7addc5d59f5b8574b --- /dev/null +++ b/tools/uvr5/lib/lib_v5/modelparams/4band_v3.json @@ -0,0 +1,54 @@ +{ + "bins": 672, + "unstable_bins": 8, + "reduction_bins": 530, + "band": { + "1": { + "sr": 7350, + "hl": 80, + "n_fft": 640, + "crop_start": 0, + "crop_stop": 85, + "lpf_start": 25, + "lpf_stop": 53, + "res_type": "polyphase" + }, + "2": { + "sr": 7350, + "hl": 80, + "n_fft": 320, + "crop_start": 4, + "crop_stop": 87, + "hpf_start": 25, + "hpf_stop": 12, + "lpf_start": 31, + "lpf_stop": 62, + "res_type": "polyphase" + }, + "3": { + "sr": 14700, + "hl": 160, + "n_fft": 512, + "crop_start": 17, + "crop_stop": 216, + "hpf_start": 48, + "hpf_stop": 24, + "lpf_start": 139, + "lpf_stop": 210, + "res_type": "polyphase" + }, + "4": { + "sr": 44100, + "hl": 480, + "n_fft": 960, + "crop_start": 78, + "crop_stop": 383, + "hpf_start": 130, + "hpf_stop": 86, + "res_type": "kaiser_fast" + } + }, + "sr": 44100, + "pre_filter_start": 668, + "pre_filter_stop": 672 +} \ No newline at end of file diff --git a/tools/uvr5/lib/lib_v5/modelparams/ensemble.json b/tools/uvr5/lib/lib_v5/modelparams/ensemble.json new file mode 100644 index 0000000000000000000000000000000000000000..ee69beb46fc82f34619c5e48761e329fcabbbd00 --- /dev/null +++ b/tools/uvr5/lib/lib_v5/modelparams/ensemble.json @@ -0,0 +1,43 @@ +{ + "mid_side_b2": true, + "bins": 1280, + "unstable_bins": 7, + "reduction_bins": 565, + "band": { + "1": { + "sr": 11025, + "hl": 108, + "n_fft": 2048, + "crop_start": 0, + "crop_stop": 374, + "lpf_start": 92, + "lpf_stop": 186, + "res_type": "polyphase" + }, + "2": { + "sr": 22050, + "hl": 216, + "n_fft": 1536, + "crop_start": 0, + "crop_stop": 424, + "hpf_start": 68, + "hpf_stop": 34, + "lpf_start": 348, + "lpf_stop": 418, + "res_type": "polyphase" + }, + "3": { + "sr": 44100, + "hl": 432, + "n_fft": 1280, + "crop_start": 132, + "crop_stop": 614, + "hpf_start": 172, + "hpf_stop": 144, + "res_type": "polyphase" + } + }, + "sr": 44100, + "pre_filter_start": 1280, + "pre_filter_stop": 1280 +} \ No newline at end of file diff --git a/tools/uvr5/lib/lib_v5/nets.py b/tools/uvr5/lib/lib_v5/nets.py new file mode 100644 index 0000000000000000000000000000000000000000..5da3948c2f2e9edcc3cdac49bdf9f738e403de40 --- /dev/null +++ b/tools/uvr5/lib/lib_v5/nets.py @@ -0,0 +1,123 @@ +import layers +import torch +import torch.nn.functional as F +from torch import nn + +from . 
import spec_utils + + +class BaseASPPNet(nn.Module): + def __init__(self, nin, ch, dilations=(4, 8, 16)): + super(BaseASPPNet, self).__init__() + self.enc1 = layers.Encoder(nin, ch, 3, 2, 1) + self.enc2 = layers.Encoder(ch, ch * 2, 3, 2, 1) + self.enc3 = layers.Encoder(ch * 2, ch * 4, 3, 2, 1) + self.enc4 = layers.Encoder(ch * 4, ch * 8, 3, 2, 1) + + self.aspp = layers.ASPPModule(ch * 8, ch * 16, dilations) + + self.dec4 = layers.Decoder(ch * (8 + 16), ch * 8, 3, 1, 1) + self.dec3 = layers.Decoder(ch * (4 + 8), ch * 4, 3, 1, 1) + self.dec2 = layers.Decoder(ch * (2 + 4), ch * 2, 3, 1, 1) + self.dec1 = layers.Decoder(ch * (1 + 2), ch, 3, 1, 1) + + def __call__(self, x): + h, e1 = self.enc1(x) + h, e2 = self.enc2(h) + h, e3 = self.enc3(h) + h, e4 = self.enc4(h) + + h = self.aspp(h) + + h = self.dec4(h, e4) + h = self.dec3(h, e3) + h = self.dec2(h, e2) + h = self.dec1(h, e1) + + return h + + +class CascadedASPPNet(nn.Module): + def __init__(self, n_fft): + super(CascadedASPPNet, self).__init__() + self.stg1_low_band_net = BaseASPPNet(2, 16) + self.stg1_high_band_net = BaseASPPNet(2, 16) + + self.stg2_bridge = layers.Conv2DBNActiv(18, 8, 1, 1, 0) + self.stg2_full_band_net = BaseASPPNet(8, 16) + + self.stg3_bridge = layers.Conv2DBNActiv(34, 16, 1, 1, 0) + self.stg3_full_band_net = BaseASPPNet(16, 32) + + self.out = nn.Conv2d(32, 2, 1, bias=False) + self.aux1_out = nn.Conv2d(16, 2, 1, bias=False) + self.aux2_out = nn.Conv2d(16, 2, 1, bias=False) + + self.max_bin = n_fft // 2 + self.output_bin = n_fft // 2 + 1 + + self.offset = 128 + + def forward(self, x, aggressiveness=None): + mix = x.detach() + x = x.clone() + + x = x[:, :, : self.max_bin] + + bandw = x.size()[2] // 2 + aux1 = torch.cat( + [ + self.stg1_low_band_net(x[:, :, :bandw]), + self.stg1_high_band_net(x[:, :, bandw:]), + ], + dim=2, + ) + + h = torch.cat([x, aux1], dim=1) + aux2 = self.stg2_full_band_net(self.stg2_bridge(h)) + + h = torch.cat([x, aux1, aux2], dim=1) + h = self.stg3_full_band_net(self.stg3_bridge(h)) + + mask = torch.sigmoid(self.out(h)) + mask = F.pad( + input=mask, + pad=(0, 0, 0, self.output_bin - mask.size()[2]), + mode="replicate", + ) + + if self.training: + aux1 = torch.sigmoid(self.aux1_out(aux1)) + aux1 = F.pad( + input=aux1, + pad=(0, 0, 0, self.output_bin - aux1.size()[2]), + mode="replicate", + ) + aux2 = torch.sigmoid(self.aux2_out(aux2)) + aux2 = F.pad( + input=aux2, + pad=(0, 0, 0, self.output_bin - aux2.size()[2]), + mode="replicate", + ) + return mask * mix, aux1 * mix, aux2 * mix + else: + if aggressiveness: + mask[:, :, : aggressiveness["split_bin"]] = torch.pow( + mask[:, :, : aggressiveness["split_bin"]], + 1 + aggressiveness["value"] / 3, + ) + mask[:, :, aggressiveness["split_bin"] :] = torch.pow( + mask[:, :, aggressiveness["split_bin"] :], + 1 + aggressiveness["value"], + ) + + return mask * mix + + def predict(self, x_mag, aggressiveness=None): + h = self.forward(x_mag, aggressiveness) + + if self.offset > 0: + h = h[:, :, :, self.offset : -self.offset] + assert h.size()[3] > 0 + + return h diff --git a/tools/uvr5/lib/lib_v5/nets_123812KB.py b/tools/uvr5/lib/lib_v5/nets_123812KB.py new file mode 100644 index 0000000000000000000000000000000000000000..167d4cb2198863cf43e93440f7e63c5342fc7605 --- /dev/null +++ b/tools/uvr5/lib/lib_v5/nets_123812KB.py @@ -0,0 +1,122 @@ +import torch +import torch.nn.functional as F +from torch import nn + +from . 
import layers_123821KB as layers + + +class BaseASPPNet(nn.Module): + def __init__(self, nin, ch, dilations=(4, 8, 16)): + super(BaseASPPNet, self).__init__() + self.enc1 = layers.Encoder(nin, ch, 3, 2, 1) + self.enc2 = layers.Encoder(ch, ch * 2, 3, 2, 1) + self.enc3 = layers.Encoder(ch * 2, ch * 4, 3, 2, 1) + self.enc4 = layers.Encoder(ch * 4, ch * 8, 3, 2, 1) + + self.aspp = layers.ASPPModule(ch * 8, ch * 16, dilations) + + self.dec4 = layers.Decoder(ch * (8 + 16), ch * 8, 3, 1, 1) + self.dec3 = layers.Decoder(ch * (4 + 8), ch * 4, 3, 1, 1) + self.dec2 = layers.Decoder(ch * (2 + 4), ch * 2, 3, 1, 1) + self.dec1 = layers.Decoder(ch * (1 + 2), ch, 3, 1, 1) + + def __call__(self, x): + h, e1 = self.enc1(x) + h, e2 = self.enc2(h) + h, e3 = self.enc3(h) + h, e4 = self.enc4(h) + + h = self.aspp(h) + + h = self.dec4(h, e4) + h = self.dec3(h, e3) + h = self.dec2(h, e2) + h = self.dec1(h, e1) + + return h + + +class CascadedASPPNet(nn.Module): + def __init__(self, n_fft): + super(CascadedASPPNet, self).__init__() + self.stg1_low_band_net = BaseASPPNet(2, 32) + self.stg1_high_band_net = BaseASPPNet(2, 32) + + self.stg2_bridge = layers.Conv2DBNActiv(34, 16, 1, 1, 0) + self.stg2_full_band_net = BaseASPPNet(16, 32) + + self.stg3_bridge = layers.Conv2DBNActiv(66, 32, 1, 1, 0) + self.stg3_full_band_net = BaseASPPNet(32, 64) + + self.out = nn.Conv2d(64, 2, 1, bias=False) + self.aux1_out = nn.Conv2d(32, 2, 1, bias=False) + self.aux2_out = nn.Conv2d(32, 2, 1, bias=False) + + self.max_bin = n_fft // 2 + self.output_bin = n_fft // 2 + 1 + + self.offset = 128 + + def forward(self, x, aggressiveness=None): + mix = x.detach() + x = x.clone() + + x = x[:, :, : self.max_bin] + + bandw = x.size()[2] // 2 + aux1 = torch.cat( + [ + self.stg1_low_band_net(x[:, :, :bandw]), + self.stg1_high_band_net(x[:, :, bandw:]), + ], + dim=2, + ) + + h = torch.cat([x, aux1], dim=1) + aux2 = self.stg2_full_band_net(self.stg2_bridge(h)) + + h = torch.cat([x, aux1, aux2], dim=1) + h = self.stg3_full_band_net(self.stg3_bridge(h)) + + mask = torch.sigmoid(self.out(h)) + mask = F.pad( + input=mask, + pad=(0, 0, 0, self.output_bin - mask.size()[2]), + mode="replicate", + ) + + if self.training: + aux1 = torch.sigmoid(self.aux1_out(aux1)) + aux1 = F.pad( + input=aux1, + pad=(0, 0, 0, self.output_bin - aux1.size()[2]), + mode="replicate", + ) + aux2 = torch.sigmoid(self.aux2_out(aux2)) + aux2 = F.pad( + input=aux2, + pad=(0, 0, 0, self.output_bin - aux2.size()[2]), + mode="replicate", + ) + return mask * mix, aux1 * mix, aux2 * mix + else: + if aggressiveness: + mask[:, :, : aggressiveness["split_bin"]] = torch.pow( + mask[:, :, : aggressiveness["split_bin"]], + 1 + aggressiveness["value"] / 3, + ) + mask[:, :, aggressiveness["split_bin"] :] = torch.pow( + mask[:, :, aggressiveness["split_bin"] :], + 1 + aggressiveness["value"], + ) + + return mask * mix + + def predict(self, x_mag, aggressiveness=None): + h = self.forward(x_mag, aggressiveness) + + if self.offset > 0: + h = h[:, :, :, self.offset : -self.offset] + assert h.size()[3] > 0 + + return h diff --git a/tools/uvr5/lib/lib_v5/nets_123821KB.py b/tools/uvr5/lib/lib_v5/nets_123821KB.py new file mode 100644 index 0000000000000000000000000000000000000000..167d4cb2198863cf43e93440f7e63c5342fc7605 --- /dev/null +++ b/tools/uvr5/lib/lib_v5/nets_123821KB.py @@ -0,0 +1,122 @@ +import torch +import torch.nn.functional as F +from torch import nn + +from . 
import layers_123821KB as layers + + +class BaseASPPNet(nn.Module): + def __init__(self, nin, ch, dilations=(4, 8, 16)): + super(BaseASPPNet, self).__init__() + self.enc1 = layers.Encoder(nin, ch, 3, 2, 1) + self.enc2 = layers.Encoder(ch, ch * 2, 3, 2, 1) + self.enc3 = layers.Encoder(ch * 2, ch * 4, 3, 2, 1) + self.enc4 = layers.Encoder(ch * 4, ch * 8, 3, 2, 1) + + self.aspp = layers.ASPPModule(ch * 8, ch * 16, dilations) + + self.dec4 = layers.Decoder(ch * (8 + 16), ch * 8, 3, 1, 1) + self.dec3 = layers.Decoder(ch * (4 + 8), ch * 4, 3, 1, 1) + self.dec2 = layers.Decoder(ch * (2 + 4), ch * 2, 3, 1, 1) + self.dec1 = layers.Decoder(ch * (1 + 2), ch, 3, 1, 1) + + def __call__(self, x): + h, e1 = self.enc1(x) + h, e2 = self.enc2(h) + h, e3 = self.enc3(h) + h, e4 = self.enc4(h) + + h = self.aspp(h) + + h = self.dec4(h, e4) + h = self.dec3(h, e3) + h = self.dec2(h, e2) + h = self.dec1(h, e1) + + return h + + +class CascadedASPPNet(nn.Module): + def __init__(self, n_fft): + super(CascadedASPPNet, self).__init__() + self.stg1_low_band_net = BaseASPPNet(2, 32) + self.stg1_high_band_net = BaseASPPNet(2, 32) + + self.stg2_bridge = layers.Conv2DBNActiv(34, 16, 1, 1, 0) + self.stg2_full_band_net = BaseASPPNet(16, 32) + + self.stg3_bridge = layers.Conv2DBNActiv(66, 32, 1, 1, 0) + self.stg3_full_band_net = BaseASPPNet(32, 64) + + self.out = nn.Conv2d(64, 2, 1, bias=False) + self.aux1_out = nn.Conv2d(32, 2, 1, bias=False) + self.aux2_out = nn.Conv2d(32, 2, 1, bias=False) + + self.max_bin = n_fft // 2 + self.output_bin = n_fft // 2 + 1 + + self.offset = 128 + + def forward(self, x, aggressiveness=None): + mix = x.detach() + x = x.clone() + + x = x[:, :, : self.max_bin] + + bandw = x.size()[2] // 2 + aux1 = torch.cat( + [ + self.stg1_low_band_net(x[:, :, :bandw]), + self.stg1_high_band_net(x[:, :, bandw:]), + ], + dim=2, + ) + + h = torch.cat([x, aux1], dim=1) + aux2 = self.stg2_full_band_net(self.stg2_bridge(h)) + + h = torch.cat([x, aux1, aux2], dim=1) + h = self.stg3_full_band_net(self.stg3_bridge(h)) + + mask = torch.sigmoid(self.out(h)) + mask = F.pad( + input=mask, + pad=(0, 0, 0, self.output_bin - mask.size()[2]), + mode="replicate", + ) + + if self.training: + aux1 = torch.sigmoid(self.aux1_out(aux1)) + aux1 = F.pad( + input=aux1, + pad=(0, 0, 0, self.output_bin - aux1.size()[2]), + mode="replicate", + ) + aux2 = torch.sigmoid(self.aux2_out(aux2)) + aux2 = F.pad( + input=aux2, + pad=(0, 0, 0, self.output_bin - aux2.size()[2]), + mode="replicate", + ) + return mask * mix, aux1 * mix, aux2 * mix + else: + if aggressiveness: + mask[:, :, : aggressiveness["split_bin"]] = torch.pow( + mask[:, :, : aggressiveness["split_bin"]], + 1 + aggressiveness["value"] / 3, + ) + mask[:, :, aggressiveness["split_bin"] :] = torch.pow( + mask[:, :, aggressiveness["split_bin"] :], + 1 + aggressiveness["value"], + ) + + return mask * mix + + def predict(self, x_mag, aggressiveness=None): + h = self.forward(x_mag, aggressiveness) + + if self.offset > 0: + h = h[:, :, :, self.offset : -self.offset] + assert h.size()[3] > 0 + + return h diff --git a/tools/uvr5/lib/lib_v5/nets_33966KB.py b/tools/uvr5/lib/lib_v5/nets_33966KB.py new file mode 100644 index 0000000000000000000000000000000000000000..73a5b836177b706c306e27875f8391c1aed4b948 --- /dev/null +++ b/tools/uvr5/lib/lib_v5/nets_33966KB.py @@ -0,0 +1,122 @@ +import torch +import torch.nn.functional as F +from torch import nn + +from . 
import layers_33966KB as layers + + +class BaseASPPNet(nn.Module): + def __init__(self, nin, ch, dilations=(4, 8, 16, 32)): + super(BaseASPPNet, self).__init__() + self.enc1 = layers.Encoder(nin, ch, 3, 2, 1) + self.enc2 = layers.Encoder(ch, ch * 2, 3, 2, 1) + self.enc3 = layers.Encoder(ch * 2, ch * 4, 3, 2, 1) + self.enc4 = layers.Encoder(ch * 4, ch * 8, 3, 2, 1) + + self.aspp = layers.ASPPModule(ch * 8, ch * 16, dilations) + + self.dec4 = layers.Decoder(ch * (8 + 16), ch * 8, 3, 1, 1) + self.dec3 = layers.Decoder(ch * (4 + 8), ch * 4, 3, 1, 1) + self.dec2 = layers.Decoder(ch * (2 + 4), ch * 2, 3, 1, 1) + self.dec1 = layers.Decoder(ch * (1 + 2), ch, 3, 1, 1) + + def __call__(self, x): + h, e1 = self.enc1(x) + h, e2 = self.enc2(h) + h, e3 = self.enc3(h) + h, e4 = self.enc4(h) + + h = self.aspp(h) + + h = self.dec4(h, e4) + h = self.dec3(h, e3) + h = self.dec2(h, e2) + h = self.dec1(h, e1) + + return h + + +class CascadedASPPNet(nn.Module): + def __init__(self, n_fft): + super(CascadedASPPNet, self).__init__() + self.stg1_low_band_net = BaseASPPNet(2, 16) + self.stg1_high_band_net = BaseASPPNet(2, 16) + + self.stg2_bridge = layers.Conv2DBNActiv(18, 8, 1, 1, 0) + self.stg2_full_band_net = BaseASPPNet(8, 16) + + self.stg3_bridge = layers.Conv2DBNActiv(34, 16, 1, 1, 0) + self.stg3_full_band_net = BaseASPPNet(16, 32) + + self.out = nn.Conv2d(32, 2, 1, bias=False) + self.aux1_out = nn.Conv2d(16, 2, 1, bias=False) + self.aux2_out = nn.Conv2d(16, 2, 1, bias=False) + + self.max_bin = n_fft // 2 + self.output_bin = n_fft // 2 + 1 + + self.offset = 128 + + def forward(self, x, aggressiveness=None): + mix = x.detach() + x = x.clone() + + x = x[:, :, : self.max_bin] + + bandw = x.size()[2] // 2 + aux1 = torch.cat( + [ + self.stg1_low_band_net(x[:, :, :bandw]), + self.stg1_high_band_net(x[:, :, bandw:]), + ], + dim=2, + ) + + h = torch.cat([x, aux1], dim=1) + aux2 = self.stg2_full_band_net(self.stg2_bridge(h)) + + h = torch.cat([x, aux1, aux2], dim=1) + h = self.stg3_full_band_net(self.stg3_bridge(h)) + + mask = torch.sigmoid(self.out(h)) + mask = F.pad( + input=mask, + pad=(0, 0, 0, self.output_bin - mask.size()[2]), + mode="replicate", + ) + + if self.training: + aux1 = torch.sigmoid(self.aux1_out(aux1)) + aux1 = F.pad( + input=aux1, + pad=(0, 0, 0, self.output_bin - aux1.size()[2]), + mode="replicate", + ) + aux2 = torch.sigmoid(self.aux2_out(aux2)) + aux2 = F.pad( + input=aux2, + pad=(0, 0, 0, self.output_bin - aux2.size()[2]), + mode="replicate", + ) + return mask * mix, aux1 * mix, aux2 * mix + else: + if aggressiveness: + mask[:, :, : aggressiveness["split_bin"]] = torch.pow( + mask[:, :, : aggressiveness["split_bin"]], + 1 + aggressiveness["value"] / 3, + ) + mask[:, :, aggressiveness["split_bin"] :] = torch.pow( + mask[:, :, aggressiveness["split_bin"] :], + 1 + aggressiveness["value"], + ) + + return mask * mix + + def predict(self, x_mag, aggressiveness=None): + h = self.forward(x_mag, aggressiveness) + + if self.offset > 0: + h = h[:, :, :, self.offset : -self.offset] + assert h.size()[3] > 0 + + return h diff --git a/tools/uvr5/lib/lib_v5/nets_537227KB.py b/tools/uvr5/lib/lib_v5/nets_537227KB.py new file mode 100644 index 0000000000000000000000000000000000000000..823b44fb64898e8dcbb12180ba45d1718f9b03f7 --- /dev/null +++ b/tools/uvr5/lib/lib_v5/nets_537227KB.py @@ -0,0 +1,123 @@ +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from . 
import layers_537238KB as layers + + +class BaseASPPNet(nn.Module): + def __init__(self, nin, ch, dilations=(4, 8, 16)): + super(BaseASPPNet, self).__init__() + self.enc1 = layers.Encoder(nin, ch, 3, 2, 1) + self.enc2 = layers.Encoder(ch, ch * 2, 3, 2, 1) + self.enc3 = layers.Encoder(ch * 2, ch * 4, 3, 2, 1) + self.enc4 = layers.Encoder(ch * 4, ch * 8, 3, 2, 1) + + self.aspp = layers.ASPPModule(ch * 8, ch * 16, dilations) + + self.dec4 = layers.Decoder(ch * (8 + 16), ch * 8, 3, 1, 1) + self.dec3 = layers.Decoder(ch * (4 + 8), ch * 4, 3, 1, 1) + self.dec2 = layers.Decoder(ch * (2 + 4), ch * 2, 3, 1, 1) + self.dec1 = layers.Decoder(ch * (1 + 2), ch, 3, 1, 1) + + def __call__(self, x): + h, e1 = self.enc1(x) + h, e2 = self.enc2(h) + h, e3 = self.enc3(h) + h, e4 = self.enc4(h) + + h = self.aspp(h) + + h = self.dec4(h, e4) + h = self.dec3(h, e3) + h = self.dec2(h, e2) + h = self.dec1(h, e1) + + return h + + +class CascadedASPPNet(nn.Module): + def __init__(self, n_fft): + super(CascadedASPPNet, self).__init__() + self.stg1_low_band_net = BaseASPPNet(2, 64) + self.stg1_high_band_net = BaseASPPNet(2, 64) + + self.stg2_bridge = layers.Conv2DBNActiv(66, 32, 1, 1, 0) + self.stg2_full_band_net = BaseASPPNet(32, 64) + + self.stg3_bridge = layers.Conv2DBNActiv(130, 64, 1, 1, 0) + self.stg3_full_band_net = BaseASPPNet(64, 128) + + self.out = nn.Conv2d(128, 2, 1, bias=False) + self.aux1_out = nn.Conv2d(64, 2, 1, bias=False) + self.aux2_out = nn.Conv2d(64, 2, 1, bias=False) + + self.max_bin = n_fft // 2 + self.output_bin = n_fft // 2 + 1 + + self.offset = 128 + + def forward(self, x, aggressiveness=None): + mix = x.detach() + x = x.clone() + + x = x[:, :, : self.max_bin] + + bandw = x.size()[2] // 2 + aux1 = torch.cat( + [ + self.stg1_low_band_net(x[:, :, :bandw]), + self.stg1_high_band_net(x[:, :, bandw:]), + ], + dim=2, + ) + + h = torch.cat([x, aux1], dim=1) + aux2 = self.stg2_full_band_net(self.stg2_bridge(h)) + + h = torch.cat([x, aux1, aux2], dim=1) + h = self.stg3_full_band_net(self.stg3_bridge(h)) + + mask = torch.sigmoid(self.out(h)) + mask = F.pad( + input=mask, + pad=(0, 0, 0, self.output_bin - mask.size()[2]), + mode="replicate", + ) + + if self.training: + aux1 = torch.sigmoid(self.aux1_out(aux1)) + aux1 = F.pad( + input=aux1, + pad=(0, 0, 0, self.output_bin - aux1.size()[2]), + mode="replicate", + ) + aux2 = torch.sigmoid(self.aux2_out(aux2)) + aux2 = F.pad( + input=aux2, + pad=(0, 0, 0, self.output_bin - aux2.size()[2]), + mode="replicate", + ) + return mask * mix, aux1 * mix, aux2 * mix + else: + if aggressiveness: + mask[:, :, : aggressiveness["split_bin"]] = torch.pow( + mask[:, :, : aggressiveness["split_bin"]], + 1 + aggressiveness["value"] / 3, + ) + mask[:, :, aggressiveness["split_bin"] :] = torch.pow( + mask[:, :, aggressiveness["split_bin"] :], + 1 + aggressiveness["value"], + ) + + return mask * mix + + def predict(self, x_mag, aggressiveness=None): + h = self.forward(x_mag, aggressiveness) + + if self.offset > 0: + h = h[:, :, :, self.offset : -self.offset] + assert h.size()[3] > 0 + + return h diff --git a/tools/uvr5/lib/lib_v5/nets_537238KB.py b/tools/uvr5/lib/lib_v5/nets_537238KB.py new file mode 100644 index 0000000000000000000000000000000000000000..823b44fb64898e8dcbb12180ba45d1718f9b03f7 --- /dev/null +++ b/tools/uvr5/lib/lib_v5/nets_537238KB.py @@ -0,0 +1,123 @@ +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from . 
import layers_537238KB as layers + + +class BaseASPPNet(nn.Module): + def __init__(self, nin, ch, dilations=(4, 8, 16)): + super(BaseASPPNet, self).__init__() + self.enc1 = layers.Encoder(nin, ch, 3, 2, 1) + self.enc2 = layers.Encoder(ch, ch * 2, 3, 2, 1) + self.enc3 = layers.Encoder(ch * 2, ch * 4, 3, 2, 1) + self.enc4 = layers.Encoder(ch * 4, ch * 8, 3, 2, 1) + + self.aspp = layers.ASPPModule(ch * 8, ch * 16, dilations) + + self.dec4 = layers.Decoder(ch * (8 + 16), ch * 8, 3, 1, 1) + self.dec3 = layers.Decoder(ch * (4 + 8), ch * 4, 3, 1, 1) + self.dec2 = layers.Decoder(ch * (2 + 4), ch * 2, 3, 1, 1) + self.dec1 = layers.Decoder(ch * (1 + 2), ch, 3, 1, 1) + + def __call__(self, x): + h, e1 = self.enc1(x) + h, e2 = self.enc2(h) + h, e3 = self.enc3(h) + h, e4 = self.enc4(h) + + h = self.aspp(h) + + h = self.dec4(h, e4) + h = self.dec3(h, e3) + h = self.dec2(h, e2) + h = self.dec1(h, e1) + + return h + + +class CascadedASPPNet(nn.Module): + def __init__(self, n_fft): + super(CascadedASPPNet, self).__init__() + self.stg1_low_band_net = BaseASPPNet(2, 64) + self.stg1_high_band_net = BaseASPPNet(2, 64) + + self.stg2_bridge = layers.Conv2DBNActiv(66, 32, 1, 1, 0) + self.stg2_full_band_net = BaseASPPNet(32, 64) + + self.stg3_bridge = layers.Conv2DBNActiv(130, 64, 1, 1, 0) + self.stg3_full_band_net = BaseASPPNet(64, 128) + + self.out = nn.Conv2d(128, 2, 1, bias=False) + self.aux1_out = nn.Conv2d(64, 2, 1, bias=False) + self.aux2_out = nn.Conv2d(64, 2, 1, bias=False) + + self.max_bin = n_fft // 2 + self.output_bin = n_fft // 2 + 1 + + self.offset = 128 + + def forward(self, x, aggressiveness=None): + mix = x.detach() + x = x.clone() + + x = x[:, :, : self.max_bin] + + bandw = x.size()[2] // 2 + aux1 = torch.cat( + [ + self.stg1_low_band_net(x[:, :, :bandw]), + self.stg1_high_band_net(x[:, :, bandw:]), + ], + dim=2, + ) + + h = torch.cat([x, aux1], dim=1) + aux2 = self.stg2_full_band_net(self.stg2_bridge(h)) + + h = torch.cat([x, aux1, aux2], dim=1) + h = self.stg3_full_band_net(self.stg3_bridge(h)) + + mask = torch.sigmoid(self.out(h)) + mask = F.pad( + input=mask, + pad=(0, 0, 0, self.output_bin - mask.size()[2]), + mode="replicate", + ) + + if self.training: + aux1 = torch.sigmoid(self.aux1_out(aux1)) + aux1 = F.pad( + input=aux1, + pad=(0, 0, 0, self.output_bin - aux1.size()[2]), + mode="replicate", + ) + aux2 = torch.sigmoid(self.aux2_out(aux2)) + aux2 = F.pad( + input=aux2, + pad=(0, 0, 0, self.output_bin - aux2.size()[2]), + mode="replicate", + ) + return mask * mix, aux1 * mix, aux2 * mix + else: + if aggressiveness: + mask[:, :, : aggressiveness["split_bin"]] = torch.pow( + mask[:, :, : aggressiveness["split_bin"]], + 1 + aggressiveness["value"] / 3, + ) + mask[:, :, aggressiveness["split_bin"] :] = torch.pow( + mask[:, :, aggressiveness["split_bin"] :], + 1 + aggressiveness["value"], + ) + + return mask * mix + + def predict(self, x_mag, aggressiveness=None): + h = self.forward(x_mag, aggressiveness) + + if self.offset > 0: + h = h[:, :, :, self.offset : -self.offset] + assert h.size()[3] > 0 + + return h diff --git a/tools/uvr5/lib/lib_v5/nets_61968KB.py b/tools/uvr5/lib/lib_v5/nets_61968KB.py new file mode 100644 index 0000000000000000000000000000000000000000..167d4cb2198863cf43e93440f7e63c5342fc7605 --- /dev/null +++ b/tools/uvr5/lib/lib_v5/nets_61968KB.py @@ -0,0 +1,122 @@ +import torch +import torch.nn.functional as F +from torch import nn + +from . 
import layers_123821KB as layers + + +class BaseASPPNet(nn.Module): + def __init__(self, nin, ch, dilations=(4, 8, 16)): + super(BaseASPPNet, self).__init__() + self.enc1 = layers.Encoder(nin, ch, 3, 2, 1) + self.enc2 = layers.Encoder(ch, ch * 2, 3, 2, 1) + self.enc3 = layers.Encoder(ch * 2, ch * 4, 3, 2, 1) + self.enc4 = layers.Encoder(ch * 4, ch * 8, 3, 2, 1) + + self.aspp = layers.ASPPModule(ch * 8, ch * 16, dilations) + + self.dec4 = layers.Decoder(ch * (8 + 16), ch * 8, 3, 1, 1) + self.dec3 = layers.Decoder(ch * (4 + 8), ch * 4, 3, 1, 1) + self.dec2 = layers.Decoder(ch * (2 + 4), ch * 2, 3, 1, 1) + self.dec1 = layers.Decoder(ch * (1 + 2), ch, 3, 1, 1) + + def __call__(self, x): + h, e1 = self.enc1(x) + h, e2 = self.enc2(h) + h, e3 = self.enc3(h) + h, e4 = self.enc4(h) + + h = self.aspp(h) + + h = self.dec4(h, e4) + h = self.dec3(h, e3) + h = self.dec2(h, e2) + h = self.dec1(h, e1) + + return h + + +class CascadedASPPNet(nn.Module): + def __init__(self, n_fft): + super(CascadedASPPNet, self).__init__() + self.stg1_low_band_net = BaseASPPNet(2, 32) + self.stg1_high_band_net = BaseASPPNet(2, 32) + + self.stg2_bridge = layers.Conv2DBNActiv(34, 16, 1, 1, 0) + self.stg2_full_band_net = BaseASPPNet(16, 32) + + self.stg3_bridge = layers.Conv2DBNActiv(66, 32, 1, 1, 0) + self.stg3_full_band_net = BaseASPPNet(32, 64) + + self.out = nn.Conv2d(64, 2, 1, bias=False) + self.aux1_out = nn.Conv2d(32, 2, 1, bias=False) + self.aux2_out = nn.Conv2d(32, 2, 1, bias=False) + + self.max_bin = n_fft // 2 + self.output_bin = n_fft // 2 + 1 + + self.offset = 128 + + def forward(self, x, aggressiveness=None): + mix = x.detach() + x = x.clone() + + x = x[:, :, : self.max_bin] + + bandw = x.size()[2] // 2 + aux1 = torch.cat( + [ + self.stg1_low_band_net(x[:, :, :bandw]), + self.stg1_high_band_net(x[:, :, bandw:]), + ], + dim=2, + ) + + h = torch.cat([x, aux1], dim=1) + aux2 = self.stg2_full_band_net(self.stg2_bridge(h)) + + h = torch.cat([x, aux1, aux2], dim=1) + h = self.stg3_full_band_net(self.stg3_bridge(h)) + + mask = torch.sigmoid(self.out(h)) + mask = F.pad( + input=mask, + pad=(0, 0, 0, self.output_bin - mask.size()[2]), + mode="replicate", + ) + + if self.training: + aux1 = torch.sigmoid(self.aux1_out(aux1)) + aux1 = F.pad( + input=aux1, + pad=(0, 0, 0, self.output_bin - aux1.size()[2]), + mode="replicate", + ) + aux2 = torch.sigmoid(self.aux2_out(aux2)) + aux2 = F.pad( + input=aux2, + pad=(0, 0, 0, self.output_bin - aux2.size()[2]), + mode="replicate", + ) + return mask * mix, aux1 * mix, aux2 * mix + else: + if aggressiveness: + mask[:, :, : aggressiveness["split_bin"]] = torch.pow( + mask[:, :, : aggressiveness["split_bin"]], + 1 + aggressiveness["value"] / 3, + ) + mask[:, :, aggressiveness["split_bin"] :] = torch.pow( + mask[:, :, aggressiveness["split_bin"] :], + 1 + aggressiveness["value"], + ) + + return mask * mix + + def predict(self, x_mag, aggressiveness=None): + h = self.forward(x_mag, aggressiveness) + + if self.offset > 0: + h = h[:, :, :, self.offset : -self.offset] + assert h.size()[3] > 0 + + return h diff --git a/tools/uvr5/lib/lib_v5/nets_new.py b/tools/uvr5/lib/lib_v5/nets_new.py new file mode 100644 index 0000000000000000000000000000000000000000..1c0f4fa96d921e979fe31bd4151701b7783fbcea --- /dev/null +++ b/tools/uvr5/lib/lib_v5/nets_new.py @@ -0,0 +1,133 @@ +import torch +import torch.nn.functional as F +from torch import nn + +from . 
import layers_new + + +class BaseNet(nn.Module): + def __init__( + self, nin, nout, nin_lstm, nout_lstm, dilations=((4, 2), (8, 4), (12, 6)) + ): + super(BaseNet, self).__init__() + self.enc1 = layers_new.Conv2DBNActiv(nin, nout, 3, 1, 1) + self.enc2 = layers_new.Encoder(nout, nout * 2, 3, 2, 1) + self.enc3 = layers_new.Encoder(nout * 2, nout * 4, 3, 2, 1) + self.enc4 = layers_new.Encoder(nout * 4, nout * 6, 3, 2, 1) + self.enc5 = layers_new.Encoder(nout * 6, nout * 8, 3, 2, 1) + + self.aspp = layers_new.ASPPModule(nout * 8, nout * 8, dilations, dropout=True) + + self.dec4 = layers_new.Decoder(nout * (6 + 8), nout * 6, 3, 1, 1) + self.dec3 = layers_new.Decoder(nout * (4 + 6), nout * 4, 3, 1, 1) + self.dec2 = layers_new.Decoder(nout * (2 + 4), nout * 2, 3, 1, 1) + self.lstm_dec2 = layers_new.LSTMModule(nout * 2, nin_lstm, nout_lstm) + self.dec1 = layers_new.Decoder(nout * (1 + 2) + 1, nout * 1, 3, 1, 1) + + def __call__(self, x): + e1 = self.enc1(x) + e2 = self.enc2(e1) + e3 = self.enc3(e2) + e4 = self.enc4(e3) + e5 = self.enc5(e4) + + h = self.aspp(e5) + + h = self.dec4(h, e4) + h = self.dec3(h, e3) + h = self.dec2(h, e2) + h = torch.cat([h, self.lstm_dec2(h)], dim=1) + h = self.dec1(h, e1) + + return h + + +class CascadedNet(nn.Module): + def __init__(self, n_fft, nout=32, nout_lstm=128): + super(CascadedNet, self).__init__() + + self.max_bin = n_fft // 2 + self.output_bin = n_fft // 2 + 1 + self.nin_lstm = self.max_bin // 2 + self.offset = 64 + + self.stg1_low_band_net = nn.Sequential( + BaseNet(2, nout // 2, self.nin_lstm // 2, nout_lstm), + layers_new.Conv2DBNActiv(nout // 2, nout // 4, 1, 1, 0), + ) + + self.stg1_high_band_net = BaseNet( + 2, nout // 4, self.nin_lstm // 2, nout_lstm // 2 + ) + + self.stg2_low_band_net = nn.Sequential( + BaseNet(nout // 4 + 2, nout, self.nin_lstm // 2, nout_lstm), + layers_new.Conv2DBNActiv(nout, nout // 2, 1, 1, 0), + ) + self.stg2_high_band_net = BaseNet( + nout // 4 + 2, nout // 2, self.nin_lstm // 2, nout_lstm // 2 + ) + + self.stg3_full_band_net = BaseNet( + 3 * nout // 4 + 2, nout, self.nin_lstm, nout_lstm + ) + + self.out = nn.Conv2d(nout, 2, 1, bias=False) + self.aux_out = nn.Conv2d(3 * nout // 4, 2, 1, bias=False) + + def forward(self, x): + x = x[:, :, : self.max_bin] + + bandw = x.size()[2] // 2 + l1_in = x[:, :, :bandw] + h1_in = x[:, :, bandw:] + l1 = self.stg1_low_band_net(l1_in) + h1 = self.stg1_high_band_net(h1_in) + aux1 = torch.cat([l1, h1], dim=2) + + l2_in = torch.cat([l1_in, l1], dim=1) + h2_in = torch.cat([h1_in, h1], dim=1) + l2 = self.stg2_low_band_net(l2_in) + h2 = self.stg2_high_band_net(h2_in) + aux2 = torch.cat([l2, h2], dim=2) + + f3_in = torch.cat([x, aux1, aux2], dim=1) + f3 = self.stg3_full_band_net(f3_in) + + mask = torch.sigmoid(self.out(f3)) + mask = F.pad( + input=mask, + pad=(0, 0, 0, self.output_bin - mask.size()[2]), + mode="replicate", + ) + + if self.training: + aux = torch.cat([aux1, aux2], dim=1) + aux = torch.sigmoid(self.aux_out(aux)) + aux = F.pad( + input=aux, + pad=(0, 0, 0, self.output_bin - aux.size()[2]), + mode="replicate", + ) + return mask, aux + else: + return mask + + def predict_mask(self, x): + mask = self.forward(x) + + if self.offset > 0: + mask = mask[:, :, :, self.offset : -self.offset] + assert mask.size()[3] > 0 + + return mask + + def predict(self, x, aggressiveness=None): + mask = self.forward(x) + pred_mag = x * mask + + if self.offset > 0: + pred_mag = pred_mag[:, :, :, self.offset : -self.offset] + assert pred_mag.size()[3] > 0 + + return pred_mag diff --git 
a/tools/uvr5/lib/lib_v5/spec_utils.py b/tools/uvr5/lib/lib_v5/spec_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..da072e4b2dd59b5382d3ebde818df286f9153f38 --- /dev/null +++ b/tools/uvr5/lib/lib_v5/spec_utils.py @@ -0,0 +1,676 @@ +import hashlib +import json +import math +import os + +import librosa +import numpy as np +import soundfile as sf +from tqdm import tqdm + + +def crop_center(h1, h2): + h1_shape = h1.size() + h2_shape = h2.size() + + if h1_shape[3] == h2_shape[3]: + return h1 + elif h1_shape[3] < h2_shape[3]: + raise ValueError("h1_shape[3] must be greater than h2_shape[3]") + + # s_freq = (h2_shape[2] - h1_shape[2]) // 2 + # e_freq = s_freq + h1_shape[2] + s_time = (h1_shape[3] - h2_shape[3]) // 2 + e_time = s_time + h2_shape[3] + h1 = h1[:, :, :, s_time:e_time] + + return h1 + + +def wave_to_spectrogram( + wave, hop_length, n_fft, mid_side=False, mid_side_b2=False, reverse=False +): + if reverse: + wave_left = np.flip(np.asfortranarray(wave[0])) + wave_right = np.flip(np.asfortranarray(wave[1])) + elif mid_side: + wave_left = np.asfortranarray(np.add(wave[0], wave[1]) / 2) + wave_right = np.asfortranarray(np.subtract(wave[0], wave[1])) + elif mid_side_b2: + wave_left = np.asfortranarray(np.add(wave[1], wave[0] * 0.5)) + wave_right = np.asfortranarray(np.subtract(wave[0], wave[1] * 0.5)) + else: + wave_left = np.asfortranarray(wave[0]) + wave_right = np.asfortranarray(wave[1]) + + spec_left = librosa.stft(wave_left, n_fft=n_fft, hop_length=hop_length) + spec_right = librosa.stft(wave_right, n_fft=n_fft, hop_length=hop_length) + + spec = np.asfortranarray([spec_left, spec_right]) + + return spec + + +def wave_to_spectrogram_mt( + wave, hop_length, n_fft, mid_side=False, mid_side_b2=False, reverse=False +): + import threading + + if reverse: + wave_left = np.flip(np.asfortranarray(wave[0])) + wave_right = np.flip(np.asfortranarray(wave[1])) + elif mid_side: + wave_left = np.asfortranarray(np.add(wave[0], wave[1]) / 2) + wave_right = np.asfortranarray(np.subtract(wave[0], wave[1])) + elif mid_side_b2: + wave_left = np.asfortranarray(np.add(wave[1], wave[0] * 0.5)) + wave_right = np.asfortranarray(np.subtract(wave[0], wave[1] * 0.5)) + else: + wave_left = np.asfortranarray(wave[0]) + wave_right = np.asfortranarray(wave[1]) + + def run_thread(**kwargs): + global spec_left + spec_left = librosa.stft(**kwargs) + + thread = threading.Thread( + target=run_thread, + kwargs={"y": wave_left, "n_fft": n_fft, "hop_length": hop_length}, + ) + thread.start() + spec_right = librosa.stft(wave_right, n_fft=n_fft, hop_length=hop_length) + thread.join() + + spec = np.asfortranarray([spec_left, spec_right]) + + return spec + + +def combine_spectrograms(specs, mp): + l = min([specs[i].shape[2] for i in specs]) + spec_c = np.zeros(shape=(2, mp.param["bins"] + 1, l), dtype=np.complex64) + offset = 0 + bands_n = len(mp.param["band"]) + + for d in range(1, bands_n + 1): + h = mp.param["band"][d]["crop_stop"] - mp.param["band"][d]["crop_start"] + spec_c[:, offset : offset + h, :l] = specs[d][ + :, mp.param["band"][d]["crop_start"] : mp.param["band"][d]["crop_stop"], :l + ] + offset += h + + if offset > mp.param["bins"]: + raise ValueError("Too much bins") + + # lowpass fiter + if ( + mp.param["pre_filter_start"] > 0 + ): # and mp.param['band'][bands_n]['res_type'] in ['scipy', 'polyphase']: + if bands_n == 1: + spec_c = fft_lp_filter( + spec_c, mp.param["pre_filter_start"], mp.param["pre_filter_stop"] + ) + else: + gp = 1 + for b in range( + mp.param["pre_filter_start"] + 1, 
mp.param["pre_filter_stop"] + ): + g = math.pow( + 10, -(b - mp.param["pre_filter_start"]) * (3.5 - gp) / 20.0 + ) + gp = g + spec_c[:, b, :] *= g + + return np.asfortranarray(spec_c) + + +def spectrogram_to_image(spec, mode="magnitude"): + if mode == "magnitude": + if np.iscomplexobj(spec): + y = np.abs(spec) + else: + y = spec + y = np.log10(y**2 + 1e-8) + elif mode == "phase": + if np.iscomplexobj(spec): + y = np.angle(spec) + else: + y = spec + + y -= y.min() + y *= 255 / y.max() + img = np.uint8(y) + + if y.ndim == 3: + img = img.transpose(1, 2, 0) + img = np.concatenate([np.max(img, axis=2, keepdims=True), img], axis=2) + + return img + + +def reduce_vocal_aggressively(X, y, softmask): + v = X - y + y_mag_tmp = np.abs(y) + v_mag_tmp = np.abs(v) + + v_mask = v_mag_tmp > y_mag_tmp + y_mag = np.clip(y_mag_tmp - v_mag_tmp * v_mask * softmask, 0, np.inf) + + return y_mag * np.exp(1.0j * np.angle(y)) + + +def mask_silence(mag, ref, thres=0.2, min_range=64, fade_size=32): + if min_range < fade_size * 2: + raise ValueError("min_range must be >= fade_area * 2") + + mag = mag.copy() + + idx = np.where(ref.mean(axis=(0, 1)) < thres)[0] + starts = np.insert(idx[np.where(np.diff(idx) != 1)[0] + 1], 0, idx[0]) + ends = np.append(idx[np.where(np.diff(idx) != 1)[0]], idx[-1]) + uninformative = np.where(ends - starts > min_range)[0] + if len(uninformative) > 0: + starts = starts[uninformative] + ends = ends[uninformative] + old_e = None + for s, e in zip(starts, ends): + if old_e is not None and s - old_e < fade_size: + s = old_e - fade_size * 2 + + if s != 0: + weight = np.linspace(0, 1, fade_size) + mag[:, :, s : s + fade_size] += weight * ref[:, :, s : s + fade_size] + else: + s -= fade_size + + if e != mag.shape[2]: + weight = np.linspace(1, 0, fade_size) + mag[:, :, e - fade_size : e] += weight * ref[:, :, e - fade_size : e] + else: + e += fade_size + + mag[:, :, s + fade_size : e - fade_size] += ref[ + :, :, s + fade_size : e - fade_size + ] + old_e = e + + return mag + + +def align_wave_head_and_tail(a, b): + l = min([a[0].size, b[0].size]) + + return a[:l, :l], b[:l, :l] + + +def cache_or_load(mix_path, inst_path, mp): + mix_basename = os.path.splitext(os.path.basename(mix_path))[0] + inst_basename = os.path.splitext(os.path.basename(inst_path))[0] + + cache_dir = "mph{}".format( + hashlib.sha1(json.dumps(mp.param, sort_keys=True).encode("utf-8")).hexdigest() + ) + mix_cache_dir = os.path.join("cache", cache_dir) + inst_cache_dir = os.path.join("cache", cache_dir) + + os.makedirs(mix_cache_dir, exist_ok=True) + os.makedirs(inst_cache_dir, exist_ok=True) + + mix_cache_path = os.path.join(mix_cache_dir, mix_basename + ".npy") + inst_cache_path = os.path.join(inst_cache_dir, inst_basename + ".npy") + + if os.path.exists(mix_cache_path) and os.path.exists(inst_cache_path): + X_spec_m = np.load(mix_cache_path) + y_spec_m = np.load(inst_cache_path) + else: + X_wave, y_wave, X_spec_s, y_spec_s = {}, {}, {}, {} + + for d in range(len(mp.param["band"]), 0, -1): + bp = mp.param["band"][d] + + if d == len(mp.param["band"]): # high-end band + X_wave[d], _ = librosa.load( + mix_path, + sr = bp["sr"], + mono = False, + dtype = np.float32, + res_type = bp["res_type"] + ) + y_wave[d], _ = librosa.load( + inst_path, + sr = bp["sr"], + mono = False, + dtype = np.float32, + res_type = bp["res_type"], + ) + else: # lower bands + X_wave[d] = librosa.resample( + X_wave[d + 1], + orig_sr = mp.param["band"][d + 1]["sr"], + target_sr = bp["sr"], + res_type = bp["res_type"], + ) + y_wave[d] = librosa.resample( + 
y_wave[d + 1], + orig_sr = mp.param["band"][d + 1]["sr"], + target_sr = bp["sr"], + res_type = bp["res_type"], + ) + + X_wave[d], y_wave[d] = align_wave_head_and_tail(X_wave[d], y_wave[d]) + + X_spec_s[d] = wave_to_spectrogram( + X_wave[d], + bp["hl"], + bp["n_fft"], + mp.param["mid_side"], + mp.param["mid_side_b2"], + mp.param["reverse"], + ) + y_spec_s[d] = wave_to_spectrogram( + y_wave[d], + bp["hl"], + bp["n_fft"], + mp.param["mid_side"], + mp.param["mid_side_b2"], + mp.param["reverse"], + ) + + del X_wave, y_wave + + X_spec_m = combine_spectrograms(X_spec_s, mp) + y_spec_m = combine_spectrograms(y_spec_s, mp) + + if X_spec_m.shape != y_spec_m.shape: + raise ValueError("The combined spectrograms are different: " + mix_path) + + _, ext = os.path.splitext(mix_path) + + np.save(mix_cache_path, X_spec_m) + np.save(inst_cache_path, y_spec_m) + + return X_spec_m, y_spec_m + + +def spectrogram_to_wave(spec, hop_length, mid_side, mid_side_b2, reverse): + spec_left = np.asfortranarray(spec[0]) + spec_right = np.asfortranarray(spec[1]) + + wave_left = librosa.istft(spec_left, hop_length=hop_length) + wave_right = librosa.istft(spec_right, hop_length=hop_length) + + if reverse: + return np.asfortranarray([np.flip(wave_left), np.flip(wave_right)]) + elif mid_side: + return np.asfortranarray( + [np.add(wave_left, wave_right / 2), np.subtract(wave_left, wave_right / 2)] + ) + elif mid_side_b2: + return np.asfortranarray( + [ + np.add(wave_right / 1.25, 0.4 * wave_left), + np.subtract(wave_left / 1.25, 0.4 * wave_right), + ] + ) + else: + return np.asfortranarray([wave_left, wave_right]) + + +def spectrogram_to_wave_mt(spec, hop_length, mid_side, reverse, mid_side_b2): + import threading + + spec_left = np.asfortranarray(spec[0]) + spec_right = np.asfortranarray(spec[1]) + + def run_thread(**kwargs): + global wave_left + wave_left = librosa.istft(**kwargs) + + thread = threading.Thread( + target=run_thread, kwargs={"stft_matrix": spec_left, "hop_length": hop_length} + ) + thread.start() + wave_right = librosa.istft(spec_right, hop_length=hop_length) + thread.join() + + if reverse: + return np.asfortranarray([np.flip(wave_left), np.flip(wave_right)]) + elif mid_side: + return np.asfortranarray( + [np.add(wave_left, wave_right / 2), np.subtract(wave_left, wave_right / 2)] + ) + elif mid_side_b2: + return np.asfortranarray( + [ + np.add(wave_right / 1.25, 0.4 * wave_left), + np.subtract(wave_left / 1.25, 0.4 * wave_right), + ] + ) + else: + return np.asfortranarray([wave_left, wave_right]) + + +def cmb_spectrogram_to_wave(spec_m, mp, extra_bins_h=None, extra_bins=None): + wave_band = {} + bands_n = len(mp.param["band"]) + offset = 0 + + for d in range(1, bands_n + 1): + bp = mp.param["band"][d] + spec_s = np.ndarray( + shape=(2, bp["n_fft"] // 2 + 1, spec_m.shape[2]), dtype=complex + ) + h = bp["crop_stop"] - bp["crop_start"] + spec_s[:, bp["crop_start"] : bp["crop_stop"], :] = spec_m[ + :, offset : offset + h, : + ] + + offset += h + if d == bands_n: # higher + if extra_bins_h: # if --high_end_process bypass + max_bin = bp["n_fft"] // 2 + spec_s[:, max_bin - extra_bins_h : max_bin, :] = extra_bins[ + :, :extra_bins_h, : + ] + if bp["hpf_start"] > 0: + spec_s = fft_hp_filter(spec_s, bp["hpf_start"], bp["hpf_stop"] - 1) + if bands_n == 1: + wave = spectrogram_to_wave( + spec_s, + bp["hl"], + mp.param["mid_side"], + mp.param["mid_side_b2"], + mp.param["reverse"], + ) + else: + wave = np.add( + wave, + spectrogram_to_wave( + spec_s, + bp["hl"], + mp.param["mid_side"], + mp.param["mid_side_b2"], + 
mp.param["reverse"], + ), + ) + else: + sr = mp.param["band"][d + 1]["sr"] + if d == 1: # lower + spec_s = fft_lp_filter(spec_s, bp["lpf_start"], bp["lpf_stop"]) + wave = librosa.resample( + spectrogram_to_wave( + spec_s, + bp["hl"], + mp.param["mid_side"], + mp.param["mid_side_b2"], + mp.param["reverse"], + ), + orig_sr = bp["sr"], + target_sr = sr, + res_type = "sinc_fastest", + ) + else: # mid + spec_s = fft_hp_filter(spec_s, bp["hpf_start"], bp["hpf_stop"] - 1) + spec_s = fft_lp_filter(spec_s, bp["lpf_start"], bp["lpf_stop"]) + wave2 = np.add( + wave, + spectrogram_to_wave( + spec_s, + bp["hl"], + mp.param["mid_side"], + mp.param["mid_side_b2"], + mp.param["reverse"], + ), + ) + # wave = librosa.core.resample(wave2, orig_sr=bp['sr'], target_sr=sr, res_type="sinc_fastest") + wave = librosa.core.resample(wave2, orig_sr=bp["sr"], target_sr=sr, res_type="scipy") + + return wave.T + + +def fft_lp_filter(spec, bin_start, bin_stop): + g = 1.0 + for b in range(bin_start, bin_stop): + g -= 1 / (bin_stop - bin_start) + spec[:, b, :] = g * spec[:, b, :] + + spec[:, bin_stop:, :] *= 0 + + return spec + + +def fft_hp_filter(spec, bin_start, bin_stop): + g = 1.0 + for b in range(bin_start, bin_stop, -1): + g -= 1 / (bin_start - bin_stop) + spec[:, b, :] = g * spec[:, b, :] + + spec[:, 0 : bin_stop + 1, :] *= 0 + + return spec + + +def mirroring(a, spec_m, input_high_end, mp): + if "mirroring" == a: + mirror = np.flip( + np.abs( + spec_m[ + :, + mp.param["pre_filter_start"] + - 10 + - input_high_end.shape[1] : mp.param["pre_filter_start"] + - 10, + :, + ] + ), + 1, + ) + mirror = mirror * np.exp(1.0j * np.angle(input_high_end)) + + return np.where( + np.abs(input_high_end) <= np.abs(mirror), input_high_end, mirror + ) + + if "mirroring2" == a: + mirror = np.flip( + np.abs( + spec_m[ + :, + mp.param["pre_filter_start"] + - 10 + - input_high_end.shape[1] : mp.param["pre_filter_start"] + - 10, + :, + ] + ), + 1, + ) + mi = np.multiply(mirror, input_high_end * 1.7) + + return np.where(np.abs(input_high_end) <= np.abs(mi), input_high_end, mi) + + +def ensembling(a, specs): + for i in range(1, len(specs)): + if i == 1: + spec = specs[0] + + ln = min([spec.shape[2], specs[i].shape[2]]) + spec = spec[:, :, :ln] + specs[i] = specs[i][:, :, :ln] + + if "min_mag" == a: + spec = np.where(np.abs(specs[i]) <= np.abs(spec), specs[i], spec) + if "max_mag" == a: + spec = np.where(np.abs(specs[i]) >= np.abs(spec), specs[i], spec) + + return spec + + +def stft(wave, nfft, hl): + wave_left = np.asfortranarray(wave[0]) + wave_right = np.asfortranarray(wave[1]) + spec_left = librosa.stft(wave_left, n_fft=nfft, hop_length=hl) + spec_right = librosa.stft(wave_right, n_fft=nfft, hop_length=hl) + spec = np.asfortranarray([spec_left, spec_right]) + + return spec + + +def istft(spec, hl): + spec_left = np.asfortranarray(spec[0]) + spec_right = np.asfortranarray(spec[1]) + + wave_left = librosa.istft(spec_left, hop_length=hl) + wave_right = librosa.istft(spec_right, hop_length=hl) + wave = np.asfortranarray([wave_left, wave_right]) + + +if __name__ == "__main__": + import argparse + import sys + import time + + import cv2 + from model_param_init import ModelParameters + + p = argparse.ArgumentParser() + p.add_argument( + "--algorithm", + "-a", + type=str, + choices=["invert", "invert_p", "min_mag", "max_mag", "deep", "align"], + default="min_mag", + ) + p.add_argument( + "--model_params", + "-m", + type=str, + default=os.path.join("modelparams", "1band_sr44100_hl512.json"), + ) + p.add_argument("--output_name", "-o", type=str, 
default="output") + p.add_argument("--vocals_only", "-v", action="store_true") + p.add_argument("input", nargs="+") + args = p.parse_args() + + start_time = time.time() + + if args.algorithm.startswith("invert") and len(args.input) != 2: + raise ValueError("There should be two input files.") + + if not args.algorithm.startswith("invert") and len(args.input) < 2: + raise ValueError("There must be at least two input files.") + + wave, specs = {}, {} + mp = ModelParameters(args.model_params) + + for i in range(len(args.input)): + spec = {} + + for d in range(len(mp.param["band"]), 0, -1): + bp = mp.param["band"][d] + + if d == len(mp.param["band"]): # high-end band + wave[d], _ = librosa.load( + args.input[i], + sr = bp["sr"], + mono = False, + dtype = np.float32, + res_type = bp["res_type"], + ) + + if len(wave[d].shape) == 1: # mono to stereo + wave[d] = np.array([wave[d], wave[d]]) + else: # lower bands + wave[d] = librosa.resample( + wave[d + 1], + orig_sr = mp.param["band"][d + 1]["sr"], + target_sr = bp["sr"], + res_type = bp["res_type"], + ) + + spec[d] = wave_to_spectrogram( + wave[d], + bp["hl"], + bp["n_fft"], + mp.param["mid_side"], + mp.param["mid_side_b2"], + mp.param["reverse"], + ) + + specs[i] = combine_spectrograms(spec, mp) + + del wave + + if args.algorithm == "deep": + d_spec = np.where(np.abs(specs[0]) <= np.abs(spec[1]), specs[0], spec[1]) + v_spec = d_spec - specs[1] + sf.write( + os.path.join("{}.wav".format(args.output_name)), + cmb_spectrogram_to_wave(v_spec, mp), + mp.param["sr"], + ) + + if args.algorithm.startswith("invert"): + ln = min([specs[0].shape[2], specs[1].shape[2]]) + specs[0] = specs[0][:, :, :ln] + specs[1] = specs[1][:, :, :ln] + + if "invert_p" == args.algorithm: + X_mag = np.abs(specs[0]) + y_mag = np.abs(specs[1]) + max_mag = np.where(X_mag >= y_mag, X_mag, y_mag) + v_spec = specs[1] - max_mag * np.exp(1.0j * np.angle(specs[0])) + else: + specs[1] = reduce_vocal_aggressively(specs[0], specs[1], 0.2) + v_spec = specs[0] - specs[1] + + if not args.vocals_only: + X_mag = np.abs(specs[0]) + y_mag = np.abs(specs[1]) + v_mag = np.abs(v_spec) + + X_image = spectrogram_to_image(X_mag) + y_image = spectrogram_to_image(y_mag) + v_image = spectrogram_to_image(v_mag) + + cv2.imwrite("{}_X.png".format(args.output_name), X_image) + cv2.imwrite("{}_y.png".format(args.output_name), y_image) + cv2.imwrite("{}_v.png".format(args.output_name), v_image) + + sf.write( + "{}_X.wav".format(args.output_name), + cmb_spectrogram_to_wave(specs[0], mp), + mp.param["sr"], + ) + sf.write( + "{}_y.wav".format(args.output_name), + cmb_spectrogram_to_wave(specs[1], mp), + mp.param["sr"], + ) + + sf.write( + "{}_v.wav".format(args.output_name), + cmb_spectrogram_to_wave(v_spec, mp), + mp.param["sr"], + ) + else: + if not args.algorithm == "deep": + sf.write( + os.path.join("ensembled", "{}.wav".format(args.output_name)), + cmb_spectrogram_to_wave(ensembling(args.algorithm, specs), mp), + mp.param["sr"], + ) + + if args.algorithm == "align": + trackalignment = [ + { + "file1": '"{}"'.format(args.input[0]), + "file2": '"{}"'.format(args.input[1]), + } + ] + + for i, e in tqdm(enumerate(trackalignment), desc="Performing Alignment..."): + os.system(f"python lib/align_tracks.py {e['file1']} {e['file2']}") + + # print('Total time: {0:.{1}f}s'.format(time.time() - start_time, 1)) diff --git a/tools/uvr5/lib/name_params.json b/tools/uvr5/lib/name_params.json new file mode 100644 index 0000000000000000000000000000000000000000..4e5ee7bec45de4740f8402c42537c9a98681c95e --- /dev/null +++ 
b/tools/uvr5/lib/name_params.json @@ -0,0 +1,263 @@ +{ + "equivalent" : [ + { + "model_hash_name" : [ + { + "hash_name": "47939caf0cfe52a0e81442b85b971dfd", + "model_params": "lib/lib_v5/modelparams/4band_44100.json", + "param_name": "4band_44100" + }, + { + "hash_name": "4e4ecb9764c50a8c414fee6e10395bbe", + "model_params": "lib/lib_v5/modelparams/4band_v2.json", + "param_name": "4band_v2" + }, + { + "hash_name": "ca106edd563e034bde0bdec4bb7a4b36", + "model_params": "lib/lib_v5/modelparams/4band_v2.json", + "param_name": "4band_v2" + }, + { + "hash_name": "e60a1e84803ce4efc0a6551206cc4b71", + "model_params": "lib/lib_v5/modelparams/4band_44100.json", + "param_name": "4band_44100" + }, + { + "hash_name": "a82f14e75892e55e994376edbf0c8435", + "model_params": "lib/lib_v5/modelparams/4band_44100.json", + "param_name": "4band_44100" + }, + { + "hash_name": "6dd9eaa6f0420af9f1d403aaafa4cc06", + "model_params": "lib/lib_v5/modelparams/4band_v2_sn.json", + "param_name": "4band_v2_sn" + }, + { + "hash_name": "08611fb99bd59eaa79ad27c58d137727", + "model_params": "lib/lib_v5/modelparams/4band_v2_sn.json", + "param_name": "4band_v2_sn" + }, + { + "hash_name": "5c7bbca45a187e81abbbd351606164e5", + "model_params": "lib/lib_v5/modelparams/3band_44100_msb2.json", + "param_name": "3band_44100_msb2" + }, + { + "hash_name": "d6b2cb685a058a091e5e7098192d3233", + "model_params": "lib/lib_v5/modelparams/3band_44100_msb2.json", + "param_name": "3band_44100_msb2" + }, + { + "hash_name": "c1b9f38170a7c90e96f027992eb7c62b", + "model_params": "lib/lib_v5/modelparams/4band_44100.json", + "param_name": "4band_44100" + }, + { + "hash_name": "c3448ec923fa0edf3d03a19e633faa53", + "model_params": "lib/lib_v5/modelparams/4band_44100.json", + "param_name": "4band_44100" + }, + { + "hash_name": "68aa2c8093d0080704b200d140f59e54", + "model_params": "lib/lib_v5/modelparams/3band_44100.json", + "param_name": "3band_44100" + }, + { + "hash_name": "fdc83be5b798e4bd29fe00fe6600e147", + "model_params": "lib/lib_v5/modelparams/3band_44100_mid.json", + "param_name": "3band_44100_mid.json" + }, + { + "hash_name": "2ce34bc92fd57f55db16b7a4def3d745", + "model_params": "lib/lib_v5/modelparams/3band_44100_mid.json", + "param_name": "3band_44100_mid.json" + }, + { + "hash_name": "52fdca89576f06cf4340b74a4730ee5f", + "model_params": "lib/lib_v5/modelparams/4band_44100.json", + "param_name": "4band_44100.json" + }, + { + "hash_name": "41191165b05d38fc77f072fa9e8e8a30", + "model_params": "lib/lib_v5/modelparams/4band_44100.json", + "param_name": "4band_44100.json" + }, + { + "hash_name": "89e83b511ad474592689e562d5b1f80e", + "model_params": "lib/lib_v5/modelparams/2band_32000.json", + "param_name": "2band_32000.json" + }, + { + "hash_name": "0b954da81d453b716b114d6d7c95177f", + "model_params": "lib/lib_v5/modelparams/2band_32000.json", + "param_name": "2band_32000.json" + } + + ], + "v4 Models": [ + { + "hash_name": "6a00461c51c2920fd68937d4609ed6c8", + "model_params": "lib/lib_v5/modelparams/1band_sr16000_hl512.json", + "param_name": "1band_sr16000_hl512" + }, + { + "hash_name": "0ab504864d20f1bd378fe9c81ef37140", + "model_params": "lib/lib_v5/modelparams/1band_sr32000_hl512.json", + "param_name": "1band_sr32000_hl512" + }, + { + "hash_name": "7dd21065bf91c10f7fccb57d7d83b07f", + "model_params": "lib/lib_v5/modelparams/1band_sr32000_hl512.json", + "param_name": "1band_sr32000_hl512" + }, + { + "hash_name": "80ab74d65e515caa3622728d2de07d23", + "model_params": "lib/lib_v5/modelparams/1band_sr32000_hl512.json", + "param_name": 
"1band_sr32000_hl512" + }, + { + "hash_name": "edc115e7fc523245062200c00caa847f", + "model_params": "lib/lib_v5/modelparams/1band_sr33075_hl384.json", + "param_name": "1band_sr33075_hl384" + }, + { + "hash_name": "28063e9f6ab5b341c5f6d3c67f2045b7", + "model_params": "lib/lib_v5/modelparams/1band_sr33075_hl384.json", + "param_name": "1band_sr33075_hl384" + }, + { + "hash_name": "b58090534c52cbc3e9b5104bad666ef2", + "model_params": "lib/lib_v5/modelparams/1band_sr44100_hl512.json", + "param_name": "1band_sr44100_hl512" + }, + { + "hash_name": "0cdab9947f1b0928705f518f3c78ea8f", + "model_params": "lib/lib_v5/modelparams/1band_sr44100_hl512.json", + "param_name": "1band_sr44100_hl512" + }, + { + "hash_name": "ae702fed0238afb5346db8356fe25f13", + "model_params": "lib/lib_v5/modelparams/1band_sr44100_hl1024.json", + "param_name": "1band_sr44100_hl1024" + } + ] + } + ], + "User Models" : [ + { + "1 Band": [ + { + "hash_name": "1band_sr16000_hl512", + "model_params": "lib/lib_v5/modelparams/1band_sr16000_hl512.json", + "param_name": "1band_sr16000_hl512" + }, + { + "hash_name": "1band_sr32000_hl512", + "model_params": "lib/lib_v5/modelparams/1band_sr32000_hl512.json", + "param_name": "1band_sr16000_hl512" + }, + { + "hash_name": "1band_sr33075_hl384", + "model_params": "lib/lib_v5/modelparams/1band_sr33075_hl384.json", + "param_name": "1band_sr33075_hl384" + }, + { + "hash_name": "1band_sr44100_hl256", + "model_params": "lib/lib_v5/modelparams/1band_sr44100_hl256.json", + "param_name": "1band_sr44100_hl256" + }, + { + "hash_name": "1band_sr44100_hl512", + "model_params": "lib/lib_v5/modelparams/1band_sr44100_hl512.json", + "param_name": "1band_sr44100_hl512" + }, + { + "hash_name": "1band_sr44100_hl1024", + "model_params": "lib/lib_v5/modelparams/1band_sr44100_hl1024.json", + "param_name": "1band_sr44100_hl1024" + } + ], + "2 Band": [ + { + "hash_name": "2band_44100_lofi", + "model_params": "lib/lib_v5/modelparams/2band_44100_lofi.json", + "param_name": "2band_44100_lofi" + }, + { + "hash_name": "2band_32000", + "model_params": "lib/lib_v5/modelparams/2band_32000.json", + "param_name": "2band_32000" + }, + { + "hash_name": "2band_48000", + "model_params": "lib/lib_v5/modelparams/2band_48000.json", + "param_name": "2band_48000" + } + ], + "3 Band": [ + { + "hash_name": "3band_44100", + "model_params": "lib/lib_v5/modelparams/3band_44100.json", + "param_name": "3band_44100" + }, + { + "hash_name": "3band_44100_mid", + "model_params": "lib/lib_v5/modelparams/3band_44100_mid.json", + "param_name": "3band_44100_mid" + }, + { + "hash_name": "3band_44100_msb2", + "model_params": "lib/lib_v5/modelparams/3band_44100_msb2.json", + "param_name": "3band_44100_msb2" + } + ], + "4 Band": [ + { + "hash_name": "4band_44100", + "model_params": "lib/lib_v5/modelparams/4band_44100.json", + "param_name": "4band_44100" + }, + { + "hash_name": "4band_44100_mid", + "model_params": "lib/lib_v5/modelparams/4band_44100_mid.json", + "param_name": "4band_44100_mid" + }, + { + "hash_name": "4band_44100_msb", + "model_params": "lib/lib_v5/modelparams/4band_44100_msb.json", + "param_name": "4band_44100_msb" + }, + { + "hash_name": "4band_44100_msb2", + "model_params": "lib/lib_v5/modelparams/4band_44100_msb2.json", + "param_name": "4band_44100_msb2" + }, + { + "hash_name": "4band_44100_reverse", + "model_params": "lib/lib_v5/modelparams/4band_44100_reverse.json", + "param_name": "4band_44100_reverse" + }, + { + "hash_name": "4band_44100_sw", + "model_params": "lib/lib_v5/modelparams/4band_44100_sw.json", + "param_name": 
"4band_44100_sw" + }, + { + "hash_name": "4band_v2", + "model_params": "lib/lib_v5/modelparams/4band_v2.json", + "param_name": "4band_v2" + }, + { + "hash_name": "4band_v2_sn", + "model_params": "lib/lib_v5/modelparams/4band_v2_sn.json", + "param_name": "4band_v2_sn" + }, + { + "hash_name": "tmodelparam", + "model_params": "lib/lib_v5/modelparams/tmodelparam.json", + "param_name": "User Model Param Set" + } + ] + } + ] +} \ No newline at end of file diff --git a/tools/uvr5/lib/utils.py b/tools/uvr5/lib/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5e8cd22fad3d26d89a3c9c09e9c569eae73d7275 --- /dev/null +++ b/tools/uvr5/lib/utils.py @@ -0,0 +1,121 @@ +import json + +import numpy as np +import torch +from tqdm import tqdm + + +def load_data(file_name: str = "./lib/name_params.json") -> dict: + with open(file_name, "r") as f: + data = json.load(f) + + return data + + +def make_padding(width, cropsize, offset): + left = offset + roi_size = cropsize - left * 2 + if roi_size == 0: + roi_size = cropsize + right = roi_size - (width % roi_size) + left + + return left, right, roi_size + + +def inference(X_spec, device, model, aggressiveness, data): + """ + data : dic configs + """ + + def _execute( + X_mag_pad, roi_size, n_window, device, model, aggressiveness, is_half=True + ): + model.eval() + with torch.no_grad(): + preds = [] + + iterations = [n_window] + + total_iterations = sum(iterations) + for i in tqdm(range(n_window)): + start = i * roi_size + X_mag_window = X_mag_pad[ + None, :, :, start : start + data["window_size"] + ] + X_mag_window = torch.from_numpy(X_mag_window) + if is_half: + X_mag_window = X_mag_window.half() + X_mag_window = X_mag_window.to(device) + + pred = model.predict(X_mag_window, aggressiveness) + + pred = pred.detach().cpu().numpy() + preds.append(pred[0]) + + pred = np.concatenate(preds, axis=2) + return pred + + def preprocess(X_spec): + X_mag = np.abs(X_spec) + X_phase = np.angle(X_spec) + + return X_mag, X_phase + + X_mag, X_phase = preprocess(X_spec) + + coef = X_mag.max() + X_mag_pre = X_mag / coef + + n_frame = X_mag_pre.shape[2] + pad_l, pad_r, roi_size = make_padding(n_frame, data["window_size"], model.offset) + n_window = int(np.ceil(n_frame / roi_size)) + + X_mag_pad = np.pad(X_mag_pre, ((0, 0), (0, 0), (pad_l, pad_r)), mode="constant") + + if list(model.state_dict().values())[0].dtype == torch.float16: + is_half = True + else: + is_half = False + pred = _execute( + X_mag_pad, roi_size, n_window, device, model, aggressiveness, is_half + ) + pred = pred[:, :, :n_frame] + + if data["tta"]: + pad_l += roi_size // 2 + pad_r += roi_size // 2 + n_window += 1 + + X_mag_pad = np.pad(X_mag_pre, ((0, 0), (0, 0), (pad_l, pad_r)), mode="constant") + + pred_tta = _execute( + X_mag_pad, roi_size, n_window, device, model, aggressiveness, is_half + ) + pred_tta = pred_tta[:, :, roi_size // 2 :] + pred_tta = pred_tta[:, :, :n_frame] + + return (pred + pred_tta) * 0.5 * coef, X_mag, np.exp(1.0j * X_phase) + else: + return pred * coef, X_mag, np.exp(1.0j * X_phase) + + +def _get_name_params(model_path, model_hash): + data = load_data() + flag = False + ModelName = model_path + for type in list(data): + for model in list(data[type][0]): + for i in range(len(data[type][0][model])): + if str(data[type][0][model][i]["hash_name"]) == model_hash: + flag = True + elif str(data[type][0][model][i]["hash_name"]) in ModelName: + flag = True + + if flag: + model_params_auto = data[type][0][model][i]["model_params"] + param_name_auto = 
data[type][0][model][i]["param_name"] + if type == "equivalent": + return param_name_auto, model_params_auto + else: + flag = False + return param_name_auto, model_params_auto diff --git a/tools/uvr5/mdxnet.py b/tools/uvr5/mdxnet.py new file mode 100644 index 0000000000000000000000000000000000000000..372db25b2e169e1821696608676838b3d3207e2e --- /dev/null +++ b/tools/uvr5/mdxnet.py @@ -0,0 +1,256 @@ +import os +import logging + +logger = logging.getLogger(__name__) + +import librosa +import numpy as np +import soundfile as sf +import torch +from tqdm import tqdm + +cpu = torch.device("cpu") + + +class ConvTDFNetTrim: + def __init__( + self, device, model_name, target_name, L, dim_f, dim_t, n_fft, hop=1024 + ): + super(ConvTDFNetTrim, self).__init__() + + self.dim_f = dim_f + self.dim_t = 2**dim_t + self.n_fft = n_fft + self.hop = hop + self.n_bins = self.n_fft // 2 + 1 + self.chunk_size = hop * (self.dim_t - 1) + self.window = torch.hann_window(window_length=self.n_fft, periodic=True).to( + device + ) + self.target_name = target_name + self.blender = "blender" in model_name + + self.dim_c = 4 + out_c = self.dim_c * 4 if target_name == "*" else self.dim_c + self.freq_pad = torch.zeros( + [1, out_c, self.n_bins - self.dim_f, self.dim_t] + ).to(device) + + self.n = L // 2 + + def stft(self, x): + x = x.reshape([-1, self.chunk_size]) + x = torch.stft( + x, + n_fft=self.n_fft, + hop_length=self.hop, + window=self.window, + center=True, + return_complex=True, + ) + x = torch.view_as_real(x) + x = x.permute([0, 3, 1, 2]) + x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape( + [-1, self.dim_c, self.n_bins, self.dim_t] + ) + return x[:, :, : self.dim_f] + + def istft(self, x, freq_pad=None): + freq_pad = ( + self.freq_pad.repeat([x.shape[0], 1, 1, 1]) + if freq_pad is None + else freq_pad + ) + x = torch.cat([x, freq_pad], -2) + c = 4 * 2 if self.target_name == "*" else 2 + x = x.reshape([-1, c, 2, self.n_bins, self.dim_t]).reshape( + [-1, 2, self.n_bins, self.dim_t] + ) + x = x.permute([0, 2, 3, 1]) + x = x.contiguous() + x = torch.view_as_complex(x) + x = torch.istft( + x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True + ) + return x.reshape([-1, c, self.chunk_size]) + + +def get_models(device, dim_f, dim_t, n_fft): + return ConvTDFNetTrim( + device=device, + model_name="Conv-TDF", + target_name="vocals", + L=11, + dim_f=dim_f, + dim_t=dim_t, + n_fft=n_fft, + ) + + +class Predictor: + def __init__(self, args): + import onnxruntime as ort + + logger.info(ort.get_available_providers()) + self.args = args + self.model_ = get_models( + device=cpu, dim_f=args.dim_f, dim_t=args.dim_t, n_fft=args.n_fft + ) + self.model = ort.InferenceSession( + os.path.join(args.onnx, self.model_.target_name + ".onnx"), + providers=[ + "CUDAExecutionProvider", + "DmlExecutionProvider", + "CPUExecutionProvider", + ], + ) + logger.info("ONNX load done") + + def demix(self, mix): + samples = mix.shape[-1] + margin = self.args.margin + chunk_size = self.args.chunks * 44100 + assert not margin == 0, "margin cannot be zero!" 
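+ # Descriptive note on the chunking below: the mix is cut into fixed-size chunks
+ # (args.chunks seconds at 44100 Hz), and an extra `margin` of samples is kept on
+ # each side of every chunk; demix_base() later trims those margins again so that
+ # chunk boundaries do not leave audible seams in the separated output.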
+ if margin > chunk_size: + margin = chunk_size + + segmented_mix = {} + + if self.args.chunks == 0 or samples < chunk_size: + chunk_size = samples + + counter = -1 + for skip in range(0, samples, chunk_size): + counter += 1 + + s_margin = 0 if counter == 0 else margin + end = min(skip + chunk_size + margin, samples) + + start = skip - s_margin + + segmented_mix[skip] = mix[:, start:end].copy() + if end == samples: + break + + sources = self.demix_base(segmented_mix, margin_size=margin) + """ + mix:(2,big_sample) + segmented_mix:offset->(2,small_sample) + sources:(1,2,big_sample) + """ + return sources + + def demix_base(self, mixes, margin_size): + chunked_sources = [] + progress_bar = tqdm(total=len(mixes)) + progress_bar.set_description("Processing") + for mix in mixes: + cmix = mixes[mix] + sources = [] + n_sample = cmix.shape[1] + model = self.model_ + trim = model.n_fft // 2 + gen_size = model.chunk_size - 2 * trim + pad = gen_size - n_sample % gen_size + mix_p = np.concatenate( + (np.zeros((2, trim)), cmix, np.zeros((2, pad)), np.zeros((2, trim))), 1 + ) + mix_waves = [] + i = 0 + while i < n_sample + pad: + waves = np.array(mix_p[:, i : i + model.chunk_size]) + mix_waves.append(waves) + i += gen_size + mix_waves = torch.tensor(mix_waves, dtype=torch.float32).to(cpu) + with torch.no_grad(): + _ort = self.model + spek = model.stft(mix_waves) + if self.args.denoise: + spec_pred = ( + -_ort.run(None, {"input": -spek.cpu().numpy()})[0] * 0.5 + + _ort.run(None, {"input": spek.cpu().numpy()})[0] * 0.5 + ) + tar_waves = model.istft(torch.tensor(spec_pred)) + else: + tar_waves = model.istft( + torch.tensor(_ort.run(None, {"input": spek.cpu().numpy()})[0]) + ) + tar_signal = ( + tar_waves[:, :, trim:-trim] + .transpose(0, 1) + .reshape(2, -1) + .numpy()[:, :-pad] + ) + + start = 0 if mix == 0 else margin_size + end = None if mix == list(mixes.keys())[::-1][0] else -margin_size + if margin_size == 0: + end = None + sources.append(tar_signal[:, start:end]) + + progress_bar.update(1) + + chunked_sources.append(sources) + _sources = np.concatenate(chunked_sources, axis=-1) + # del self.model + progress_bar.close() + return _sources + + def prediction(self, m, vocal_root, others_root, format): + os.makedirs(vocal_root, exist_ok=True) + os.makedirs(others_root, exist_ok=True) + basename = os.path.basename(m) + mix, rate = librosa.load(m, mono=False, sr=44100) + if mix.ndim == 1: + mix = np.asfortranarray([mix, mix]) + mix = mix.T + sources = self.demix(mix.T) + opt = sources[0].T + if format in ["wav", "flac"]: + sf.write( + "%s/%s_main_vocal.%s" % (vocal_root, basename, format), mix - opt, rate + ) + sf.write("%s/%s_others.%s" % (others_root, basename, format), opt, rate) + else: + path_vocal = "%s/%s_main_vocal.wav" % (vocal_root, basename) + path_other = "%s/%s_others.wav" % (others_root, basename) + sf.write(path_vocal, mix - opt, rate) + sf.write(path_other, opt, rate) + opt_path_vocal = path_vocal[:-4] + ".%s" % format + opt_path_other = path_other[:-4] + ".%s" % format + if os.path.exists(path_vocal): + os.system( + "ffmpeg -i '%s' -vn '%s' -q:a 2 -y" % (path_vocal, opt_path_vocal) + ) + if os.path.exists(opt_path_vocal): + try: + os.remove(path_vocal) + except: + pass + if os.path.exists(path_other): + os.system( + "ffmpeg -i '%s' -vn '%s' -q:a 2 -y" % (path_other, opt_path_other) + ) + if os.path.exists(opt_path_other): + try: + os.remove(path_other) + except: + pass + + +class MDXNetDereverb: + def __init__(self, chunks): + self.onnx = 
"%s/uvr5_weights/onnx_dereverb_By_FoxJoy"%os.path.dirname(os.path.abspath(__file__)) + self.shifts = 10 # 'Predict with randomised equivariant stabilisation' + self.mixing = "min_mag" # ['default','min_mag','max_mag'] + self.chunks = chunks + self.margin = 44100 + self.dim_t = 9 + self.dim_f = 3072 + self.n_fft = 6144 + self.denoise = True + self.pred = Predictor(self) + self.device = cpu + + def _path_audio_(self, input, others_root, vocal_root, format, is_hp3=False): + self.pred.prediction(input, vocal_root, others_root, format) diff --git a/tools/uvr5/uvr5_weights/.gitignore b/tools/uvr5/uvr5_weights/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..d6b7ef32c8478a48c3994dcadc86837f4371184d --- /dev/null +++ b/tools/uvr5/uvr5_weights/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore diff --git a/tools/uvr5/vr.py b/tools/uvr5/vr.py new file mode 100644 index 0000000000000000000000000000000000000000..640392a4723ac94d076aa4765546bcac240d2717 --- /dev/null +++ b/tools/uvr5/vr.py @@ -0,0 +1,370 @@ +import os,sys +parent_directory = os.path.dirname(os.path.abspath(__file__)) +import logging,pdb +logger = logging.getLogger(__name__) + +import librosa +import numpy as np +import soundfile as sf +import torch +from lib.lib_v5 import nets_61968KB as Nets +from lib.lib_v5 import spec_utils +from lib.lib_v5.model_param_init import ModelParameters +from lib.lib_v5.nets_new import CascadedNet +from lib.utils import inference + + +class AudioPre: + def __init__(self, agg, model_path, device, is_half, tta=False): + self.model_path = model_path + self.device = device + self.data = { + # Processing Options + "postprocess": False, + "tta": tta, + # Constants + "window_size": 512, + "agg": agg, + "high_end_process": "mirroring", + } + mp = ModelParameters("%s/lib/lib_v5/modelparams/4band_v2.json"%parent_directory) + model = Nets.CascadedASPPNet(mp.param["bins"] * 2) + cpk = torch.load(model_path, map_location="cpu") + model.load_state_dict(cpk) + model.eval() + if is_half: + model = model.half().to(device) + else: + model = model.to(device) + + self.mp = mp + self.model = model + + def _path_audio_( + self, music_file, ins_root=None, vocal_root=None, format="flac", is_hp3=False + ): + if ins_root is None and vocal_root is None: + return "No save root." 
+ name = os.path.basename(music_file) + if ins_root is not None: + os.makedirs(ins_root, exist_ok=True) + if vocal_root is not None: + os.makedirs(vocal_root, exist_ok=True) + X_wave, y_wave, X_spec_s, y_spec_s = {}, {}, {}, {} + bands_n = len(self.mp.param["band"]) + # print(bands_n) + for d in range(bands_n, 0, -1): + bp = self.mp.param["band"][d] + if d == bands_n: # high-end band + ( + X_wave[d], + _, + ) = librosa.core.load( # 理论上librosa读取可能对某些音频有bug,应该上ffmpeg读取,但是太麻烦了弃坑 + music_file, + sr = bp["sr"], + mono = False, + dtype = np.float32, + res_type = bp["res_type"], + ) + if X_wave[d].ndim == 1: + X_wave[d] = np.asfortranarray([X_wave[d], X_wave[d]]) + else: # lower bands + X_wave[d] = librosa.core.resample( + X_wave[d + 1], + orig_sr = self.mp.param["band"][d + 1]["sr"], + target_sr = bp["sr"], + res_type = bp["res_type"], + ) + # Stft of wave source + X_spec_s[d] = spec_utils.wave_to_spectrogram_mt( + X_wave[d], + bp["hl"], + bp["n_fft"], + self.mp.param["mid_side"], + self.mp.param["mid_side_b2"], + self.mp.param["reverse"], + ) + # pdb.set_trace() + if d == bands_n and self.data["high_end_process"] != "none": + input_high_end_h = (bp["n_fft"] // 2 - bp["crop_stop"]) + ( + self.mp.param["pre_filter_stop"] - self.mp.param["pre_filter_start"] + ) + input_high_end = X_spec_s[d][ + :, bp["n_fft"] // 2 - input_high_end_h : bp["n_fft"] // 2, : + ] + + X_spec_m = spec_utils.combine_spectrograms(X_spec_s, self.mp) + aggresive_set = float(self.data["agg"] / 100) + aggressiveness = { + "value": aggresive_set, + "split_bin": self.mp.param["band"][1]["crop_stop"], + } + with torch.no_grad(): + pred, X_mag, X_phase = inference( + X_spec_m, self.device, self.model, aggressiveness, self.data + ) + # Postprocess + if self.data["postprocess"]: + pred_inv = np.clip(X_mag - pred, 0, np.inf) + pred = spec_utils.mask_silence(pred, pred_inv) + y_spec_m = pred * X_phase + v_spec_m = X_spec_m - y_spec_m + + if is_hp3 == True: + ins_root,vocal_root = vocal_root,ins_root + + if ins_root is not None: + if self.data["high_end_process"].startswith("mirroring"): + input_high_end_ = spec_utils.mirroring( + self.data["high_end_process"], y_spec_m, input_high_end, self.mp + ) + wav_instrument = spec_utils.cmb_spectrogram_to_wave( + y_spec_m, self.mp, input_high_end_h, input_high_end_ + ) + else: + wav_instrument = spec_utils.cmb_spectrogram_to_wave(y_spec_m, self.mp) + logger.info("%s instruments done" % name) + if is_hp3 == True: + head = "vocal_" + else: + head = "instrument_" + if format in ["wav", "flac"]: + sf.write( + os.path.join( + ins_root, + head + "{}_{}.{}".format(name, self.data["agg"], format), + ), + (np.array(wav_instrument) * 32768).astype("int16"), + self.mp.param["sr"], + ) # + else: + path = os.path.join( + ins_root, head + "{}_{}.wav".format(name, self.data["agg"]) + ) + sf.write( + path, + (np.array(wav_instrument) * 32768).astype("int16"), + self.mp.param["sr"], + ) + if os.path.exists(path): + opt_format_path = path[:-4] + ".%s" % format + os.system("ffmpeg -i %s -vn %s -q:a 2 -y" % (path, opt_format_path)) + if os.path.exists(opt_format_path): + try: + os.remove(path) + except: + pass + if vocal_root is not None: + if is_hp3 == True: + head = "instrument_" + else: + head = "vocal_" + if self.data["high_end_process"].startswith("mirroring"): + input_high_end_ = spec_utils.mirroring( + self.data["high_end_process"], v_spec_m, input_high_end, self.mp + ) + wav_vocals = spec_utils.cmb_spectrogram_to_wave( + v_spec_m, self.mp, input_high_end_h, input_high_end_ + ) + else: + wav_vocals = 
spec_utils.cmb_spectrogram_to_wave(v_spec_m, self.mp) + logger.info("%s vocals done" % name) + if format in ["wav", "flac"]: + sf.write( + os.path.join( + vocal_root, + head + "{}_{}.{}".format(name, self.data["agg"], format), + ), + (np.array(wav_vocals) * 32768).astype("int16"), + self.mp.param["sr"], + ) + else: + path = os.path.join( + vocal_root, head + "{}_{}.wav".format(name, self.data["agg"]) + ) + sf.write( + path, + (np.array(wav_vocals) * 32768).astype("int16"), + self.mp.param["sr"], + ) + if os.path.exists(path): + opt_format_path = path[:-4] + ".%s" % format + os.system("ffmpeg -i %s -vn %s -q:a 2 -y" % (path, opt_format_path)) + if os.path.exists(opt_format_path): + try: + os.remove(path) + except: + pass + + +class AudioPreDeEcho: + def __init__(self, agg, model_path, device, is_half, tta=False): + self.model_path = model_path + self.device = device + self.data = { + # Processing Options + "postprocess": False, + "tta": tta, + # Constants + "window_size": 512, + "agg": agg, + "high_end_process": "mirroring", + } + mp = ModelParameters("%s/lib/lib_v5/modelparams/4band_v3.json"%parent_directory) + nout = 64 if "DeReverb" in model_path else 48 + model = CascadedNet(mp.param["bins"] * 2, nout) + cpk = torch.load(model_path, map_location="cpu") + model.load_state_dict(cpk) + model.eval() + if is_half: + model = model.half().to(device) + else: + model = model.to(device) + + self.mp = mp + self.model = model + + def _path_audio_( + self, music_file, vocal_root=None, ins_root=None, format="flac", is_hp3=False + ): # 3个VR模型vocal和ins是反的 + if ins_root is None and vocal_root is None: + return "No save root." + name = os.path.basename(music_file) + if ins_root is not None: + os.makedirs(ins_root, exist_ok=True) + if vocal_root is not None: + os.makedirs(vocal_root, exist_ok=True) + X_wave, y_wave, X_spec_s, y_spec_s = {}, {}, {}, {} + bands_n = len(self.mp.param["band"]) + # print(bands_n) + for d in range(bands_n, 0, -1): + bp = self.mp.param["band"][d] + if d == bands_n: # high-end band + ( + X_wave[d], + _, + ) = librosa.core.load( # 理论上librosa读取可能对某些音频有bug,应该上ffmpeg读取,但是太麻烦了弃坑 + music_file, + sr = bp["sr"], + mono = False, + dtype = np.float32, + res_type = bp["res_type"], + ) + if X_wave[d].ndim == 1: + X_wave[d] = np.asfortranarray([X_wave[d], X_wave[d]]) + else: # lower bands + X_wave[d] = librosa.core.resample( + X_wave[d + 1], + orig_sr = self.mp.param["band"][d + 1]["sr"], + target_sr = bp["sr"], + res_type = bp["res_type"], + ) + # Stft of wave source + X_spec_s[d] = spec_utils.wave_to_spectrogram_mt( + X_wave[d], + bp["hl"], + bp["n_fft"], + self.mp.param["mid_side"], + self.mp.param["mid_side_b2"], + self.mp.param["reverse"], + ) + # pdb.set_trace() + if d == bands_n and self.data["high_end_process"] != "none": + input_high_end_h = (bp["n_fft"] // 2 - bp["crop_stop"]) + ( + self.mp.param["pre_filter_stop"] - self.mp.param["pre_filter_start"] + ) + input_high_end = X_spec_s[d][ + :, bp["n_fft"] // 2 - input_high_end_h : bp["n_fft"] // 2, : + ] + + X_spec_m = spec_utils.combine_spectrograms(X_spec_s, self.mp) + aggresive_set = float(self.data["agg"] / 100) + aggressiveness = { + "value": aggresive_set, + "split_bin": self.mp.param["band"][1]["crop_stop"], + } + with torch.no_grad(): + pred, X_mag, X_phase = inference( + X_spec_m, self.device, self.model, aggressiveness, self.data + ) + # Postprocess + if self.data["postprocess"]: + pred_inv = np.clip(X_mag - pred, 0, np.inf) + pred = spec_utils.mask_silence(pred, pred_inv) + y_spec_m = pred * X_phase + v_spec_m = X_spec_m - 
y_spec_m + + if ins_root is not None: + if self.data["high_end_process"].startswith("mirroring"): + input_high_end_ = spec_utils.mirroring( + self.data["high_end_process"], y_spec_m, input_high_end, self.mp + ) + wav_instrument = spec_utils.cmb_spectrogram_to_wave( + y_spec_m, self.mp, input_high_end_h, input_high_end_ + ) + else: + wav_instrument = spec_utils.cmb_spectrogram_to_wave(y_spec_m, self.mp) + logger.info("%s instruments done" % name) + if format in ["wav", "flac"]: + sf.write( + os.path.join( + ins_root, + "vocal_{}_{}.{}".format(name, self.data["agg"], format), + ), + (np.array(wav_instrument) * 32768).astype("int16"), + self.mp.param["sr"], + ) # + else: + path = os.path.join( + ins_root, "vocal_{}_{}.wav".format(name, self.data["agg"]) + ) + sf.write( + path, + (np.array(wav_instrument) * 32768).astype("int16"), + self.mp.param["sr"], + ) + if os.path.exists(path): + opt_format_path = path[:-4] + ".%s" % format + os.system("ffmpeg -i %s -vn %s -q:a 2 -y" % (path, opt_format_path)) + if os.path.exists(opt_format_path): + try: + os.remove(path) + except: + pass + if vocal_root is not None: + if self.data["high_end_process"].startswith("mirroring"): + input_high_end_ = spec_utils.mirroring( + self.data["high_end_process"], v_spec_m, input_high_end, self.mp + ) + wav_vocals = spec_utils.cmb_spectrogram_to_wave( + v_spec_m, self.mp, input_high_end_h, input_high_end_ + ) + else: + wav_vocals = spec_utils.cmb_spectrogram_to_wave(v_spec_m, self.mp) + logger.info("%s vocals done" % name) + if format in ["wav", "flac"]: + sf.write( + os.path.join( + vocal_root, + "instrument_{}_{}.{}".format(name, self.data["agg"], format), + ), + (np.array(wav_vocals) * 32768).astype("int16"), + self.mp.param["sr"], + ) + else: + path = os.path.join( + vocal_root, "instrument_{}_{}.wav".format(name, self.data["agg"]) + ) + sf.write( + path, + (np.array(wav_vocals) * 32768).astype("int16"), + self.mp.param["sr"], + ) + if os.path.exists(path): + opt_format_path = path[:-4] + ".%s" % format + os.system("ffmpeg -i %s -vn %s -q:a 2 -y" % (path, opt_format_path)) + if os.path.exists(opt_format_path): + try: + os.remove(path) + except: + pass diff --git a/tools/uvr5/webui.py b/tools/uvr5/webui.py new file mode 100644 index 0000000000000000000000000000000000000000..60dfdaa7979d3472454611b63619cc9096c9e630 --- /dev/null +++ b/tools/uvr5/webui.py @@ -0,0 +1,190 @@ +import os +import traceback,gradio as gr +import logging +from tools.i18n.i18n import I18nAuto +from tools.my_utils import clean_path +i18n = I18nAuto() + +logger = logging.getLogger(__name__) +import librosa,ffmpeg +import soundfile as sf +import torch +import sys +from mdxnet import MDXNetDereverb +from vr import AudioPre, AudioPreDeEcho +from bsroformer import BsRoformer_Loader + +weight_uvr5_root = "tools/uvr5/uvr5_weights" +uvr5_names = [] +for name in os.listdir(weight_uvr5_root): + if name.endswith(".pth") or name.endswith(".ckpt") or "onnx" in name: + uvr5_names.append(name.replace(".pth", "").replace(".ckpt", "")) + +device=sys.argv[1] +is_half=eval(sys.argv[2]) +webui_port_uvr5=int(sys.argv[3]) +is_share=eval(sys.argv[4]) + +def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format0): + infos = [] + try: + inp_root = clean_path(inp_root) + save_root_vocal = clean_path(save_root_vocal) + save_root_ins = clean_path(save_root_ins) + is_hp3 = "HP3" in model_name + if model_name == "onnx_dereverb_By_FoxJoy": + pre_fun = MDXNetDereverb(15) + elif model_name == "Bs_Roformer" or "bs_roformer" in model_name.lower(): + func = 
BsRoformer_Loader + pre_fun = func( + model_path = os.path.join(weight_uvr5_root, model_name + ".ckpt"), + device = device, + is_half=is_half + ) + else: + func = AudioPre if "DeEcho" not in model_name else AudioPreDeEcho + pre_fun = func( + agg=int(agg), + model_path=os.path.join(weight_uvr5_root, model_name + ".pth"), + device=device, + is_half=is_half, + ) + if inp_root != "": + paths = [os.path.join(inp_root, name) for name in os.listdir(inp_root)] + else: + paths = [path.name for path in paths] + for path in paths: + inp_path = os.path.join(inp_root, path) + if(os.path.isfile(inp_path)==False):continue + need_reformat = 1 + done = 0 + try: + info = ffmpeg.probe(inp_path, cmd="ffprobe") + if ( + info["streams"][0]["channels"] == 2 + and info["streams"][0]["sample_rate"] == "44100" + ): + need_reformat = 0 + pre_fun._path_audio_( + inp_path, save_root_ins, save_root_vocal, format0,is_hp3 + ) + done = 1 + except: + need_reformat = 1 + traceback.print_exc() + if need_reformat == 1: + tmp_path = "%s/%s.reformatted.wav" % ( + os.path.join(os.environ["TEMP"]), + os.path.basename(inp_path), + ) + os.system( + f'ffmpeg -i "{inp_path}" -vn -acodec pcm_s16le -ac 2 -ar 44100 "{tmp_path}" -y' + ) + inp_path = tmp_path + try: + if done == 0: + pre_fun._path_audio_( + inp_path, save_root_ins, save_root_vocal, format0,is_hp3 + ) + infos.append("%s->Success" % (os.path.basename(inp_path))) + yield "\n".join(infos) + except: + infos.append( + "%s->%s" % (os.path.basename(inp_path), traceback.format_exc()) + ) + yield "\n".join(infos) + except: + infos.append(traceback.format_exc()) + yield "\n".join(infos) + finally: + try: + if model_name == "onnx_dereverb_By_FoxJoy": + del pre_fun.pred.model + del pre_fun.pred.model_ + else: + del pre_fun.model + del pre_fun + except: + traceback.print_exc() + print("clean_empty_cache") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + yield "\n".join(infos) + +with gr.Blocks(title="UVR5 WebUI") as app: + gr.Markdown( + value= + i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.
<br>如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
+    )
+    with gr.Tabs():
+        with gr.TabItem(i18n("伴奏人声分离&去混响&去回声")):
+            with gr.Group():
+                gr.Markdown(
+                    value=i18n("人声伴奏分离批量处理, 使用UVR5模型。") + "<br>" + \
+                        i18n("合格的文件夹路径格式举例: E:\\codes\\py39\\vits_vc_gpu\\白鹭霜华测试样例(去文件管理器地址栏拷就行了)。") + "<br>" + \
+                        i18n("模型分为三类:") + "<br>" + \
+                        i18n("1、保留人声:不带和声的音频选这个,对主人声保留比HP5更好。内置HP2和HP3两个模型,HP3可能轻微漏伴奏但对主人声保留比HP2稍微好一丁点;") + "<br>" + \
+                        i18n("2、仅保留主人声:带和声的音频选这个,对主人声可能有削弱。内置HP5一个模型;") + "<br>" + \
+                        i18n("3、去混响、去延迟模型(by FoxJoy):") + "<br>  " + \
+                        i18n("(1)MDX-Net(onnx_dereverb):对于双通道混响是最好的选择,不能去除单通道混响;") + "<br> " + \
+                        i18n("(234)DeEcho:去除延迟效果。Aggressive比Normal去除得更彻底,DeReverb额外去除混响,可去除单声道混响,但是对高频重的板式混响去不干净。") + "<br>" + \
+                        i18n("去混响/去延迟,附:") + "<br>" + \
+                        i18n("1、DeEcho-DeReverb模型的耗时是另外2个DeEcho模型的接近2倍;") + "<br>" + \
+                        i18n("2、MDX-Net-Dereverb模型挺慢的;") + "<br>
" + \ + i18n("3、个人推荐的最干净的配置是先MDX-Net再DeEcho-Aggressive。") + ) + with gr.Row(): + with gr.Column(): + dir_wav_input = gr.Textbox( + label=i18n("输入待处理音频文件夹路径"), + placeholder="C:\\Users\\Desktop\\todo-songs", + ) + wav_inputs = gr.File( + file_count="multiple", label=i18n("也可批量输入音频文件, 二选一, 优先读文件夹") + ) + with gr.Column(): + model_choose = gr.Dropdown(label=i18n("模型"), choices=uvr5_names) + agg = gr.Slider( + minimum=0, + maximum=20, + step=1, + label=i18n("人声提取激进程度"), + value=10, + interactive=True, + visible=False, # 先不开放调整 + ) + opt_vocal_root = gr.Textbox( + label=i18n("指定输出主人声文件夹"), value="output/uvr5_opt" + ) + opt_ins_root = gr.Textbox( + label=i18n("指定输出非主人声文件夹"), value="output/uvr5_opt" + ) + format0 = gr.Radio( + label=i18n("导出文件格式"), + choices=["wav", "flac", "mp3", "m4a"], + value="flac", + interactive=True, + ) + but2 = gr.Button(i18n("转换"), variant="primary") + vc_output4 = gr.Textbox(label=i18n("输出信息")) + but2.click( + uvr, + [ + model_choose, + dir_wav_input, + opt_vocal_root, + wav_inputs, + opt_ins_root, + agg, + format0, + ], + [vc_output4], + api_name="uvr_convert", + ) +app.queue(concurrency_count=511, max_size=1022).launch( + server_name="0.0.0.0", + inbrowser=True, + share=is_share, + server_port=webui_port_uvr5, + quiet=True, +) diff --git a/utils.py b/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..08e1838420c399ec76266abedf75f9fdf076050b --- /dev/null +++ b/utils.py @@ -0,0 +1,361 @@ +import argparse +import glob +import json +import logging +import os +import subprocess +import sys +import traceback + +import librosa +import numpy as np +import torch + +logging.getLogger("numba").setLevel(logging.ERROR) +logging.getLogger("matplotlib").setLevel(logging.ERROR) +logging.getLogger("httpx").setLevel(logging.ERROR) +MATPLOTLIB_FLAG = False + +logging.basicConfig(stream=sys.stdout, level=logging.INFO) +logger = logging + + +def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False): + assert os.path.isfile(checkpoint_path) + checkpoint_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=False) + iteration = checkpoint_dict["iteration"] + learning_rate = checkpoint_dict["learning_rate"] + if optimizer is not None and not skip_optimizer and checkpoint_dict["optimizer"] is not None: + optimizer.load_state_dict(checkpoint_dict["optimizer"]) + saved_state_dict = checkpoint_dict["model"] + if hasattr(model, "module"): + state_dict = model.module.state_dict() + else: + state_dict = model.state_dict() + new_state_dict = {} + for k, v in state_dict.items(): + try: + # assert "quantizer" not in k + # print("load", k) + new_state_dict[k] = saved_state_dict[k] + assert saved_state_dict[k].shape == v.shape, ( + saved_state_dict[k].shape, + v.shape, + ) + except: + traceback.print_exc() + print("error, %s is not in the checkpoint" % k) # shape不对也会,比如text_embedding当cleaner修改时 + new_state_dict[k] = v + if hasattr(model, "module"): + model.module.load_state_dict(new_state_dict) + else: + model.load_state_dict(new_state_dict) + print("load ") + logger.info( + "Loaded checkpoint '{}' (iteration {})".format( + checkpoint_path, + iteration, + ) + ) + return model, optimizer, learning_rate, iteration + + +import shutil +from time import time as ttime + + +def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path + dir = os.path.dirname(path) + name = os.path.basename(path) + tmp_path = "%s.pth" % (ttime()) + torch.save(fea, tmp_path) + shutil.move(tmp_path, "%s/%s" % (dir, name)) + + +def 
save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path): + logger.info("Saving model and optimizer state at iteration {} to {}".format(iteration, checkpoint_path)) + if hasattr(model, "module"): + state_dict = model.module.state_dict() + else: + state_dict = model.state_dict() + # torch.save( + my_save( + { + "model": state_dict, + "iteration": iteration, + "optimizer": optimizer.state_dict(), + "learning_rate": learning_rate, + }, + checkpoint_path, + ) + + +def summarize( + writer, + global_step, + scalars={}, + histograms={}, + images={}, + audios={}, + audio_sampling_rate=22050, +): + for k, v in scalars.items(): + writer.add_scalar(k, v, global_step) + for k, v in histograms.items(): + writer.add_histogram(k, v, global_step) + for k, v in images.items(): + writer.add_image(k, v, global_step, dataformats="HWC") + for k, v in audios.items(): + writer.add_audio(k, v, global_step, audio_sampling_rate) + + +def latest_checkpoint_path(dir_path, regex="G_*.pth"): + f_list = glob.glob(os.path.join(dir_path, regex)) + f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f)))) + x = f_list[-1] + print(x) + return x + + +def plot_spectrogram_to_numpy(spectrogram): + global MATPLOTLIB_FLAG + if not MATPLOTLIB_FLAG: + import matplotlib + + matplotlib.use("Agg") + MATPLOTLIB_FLAG = True + mpl_logger = logging.getLogger("matplotlib") + mpl_logger.setLevel(logging.WARNING) + import matplotlib.pylab as plt + + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") + plt.colorbar(im, ax=ax) + plt.xlabel("Frames") + plt.ylabel("Channels") + plt.tight_layout() + + fig.canvas.draw() + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + plt.close() + return data + + +def plot_alignment_to_numpy(alignment, info=None): + global MATPLOTLIB_FLAG + if not MATPLOTLIB_FLAG: + import matplotlib + + matplotlib.use("Agg") + MATPLOTLIB_FLAG = True + mpl_logger = logging.getLogger("matplotlib") + mpl_logger.setLevel(logging.WARNING) + import matplotlib.pylab as plt + + fig, ax = plt.subplots(figsize=(6, 4)) + im = ax.imshow( + alignment.transpose(), + aspect="auto", + origin="lower", + interpolation="none", + ) + fig.colorbar(im, ax=ax) + xlabel = "Decoder timestep" + if info is not None: + xlabel += "\n\n" + info + plt.xlabel(xlabel) + plt.ylabel("Encoder timestep") + plt.tight_layout() + + fig.canvas.draw() + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + plt.close() + return data + + +def load_wav_to_torch(full_path): + data, sampling_rate = librosa.load(full_path, sr=None) + return torch.FloatTensor(data), sampling_rate + + +def load_filepaths_and_text(filename, split="|"): + with open(filename, encoding="utf-8") as f: + filepaths_and_text = [line.strip().split(split) for line in f] + return filepaths_and_text + + +def get_hparams(init=True, stage=1): + parser = argparse.ArgumentParser() + parser.add_argument( + "-c", + "--config", + type=str, + default="./configs/s2.json", + help="JSON file for configuration", + ) + parser.add_argument("-p", "--pretrain", type=str, required=False, default=None, help="pretrain dir") + parser.add_argument( + "-rs", + "--resume_step", + type=int, + required=False, + default=None, + help="resume step", + ) + # parser.add_argument('-e', '--exp_dir', type=str, required=False,default=None,help='experiment directory') 
+ # parser.add_argument('-g', '--pretrained_s2G', type=str, required=False,default=None,help='pretrained sovits gererator weights') + # parser.add_argument('-d', '--pretrained_s2D', type=str, required=False,default=None,help='pretrained sovits discriminator weights') + + args = parser.parse_args() + + config_path = args.config + with open(config_path, "r") as f: + data = f.read() + config = json.loads(data) + + hparams = HParams(**config) + hparams.pretrain = args.pretrain + hparams.resume_step = args.resume_step + # hparams.data.exp_dir = args.exp_dir + if stage == 1: + model_dir = hparams.s1_ckpt_dir + else: + model_dir = hparams.s2_ckpt_dir + config_save_path = os.path.join(model_dir, "config.json") + + if not os.path.exists(model_dir): + os.makedirs(model_dir) + + with open(config_save_path, "w") as f: + f.write(data) + return hparams + + +def clean_checkpoints(path_to_models="logs/44k/", n_ckpts_to_keep=2, sort_by_time=True): + """Freeing up space by deleting saved ckpts + + Arguments: + path_to_models -- Path to the model directory + n_ckpts_to_keep -- Number of ckpts to keep, excluding G_0.pth and D_0.pth + sort_by_time -- True -> chronologically delete ckpts + False -> lexicographically delete ckpts + """ + import re + + ckpts_files = [f for f in os.listdir(path_to_models) if os.path.isfile(os.path.join(path_to_models, f))] + name_key = lambda _f: int(re.compile("._(\d+)\.pth").match(_f).group(1)) + time_key = lambda _f: os.path.getmtime(os.path.join(path_to_models, _f)) + sort_key = time_key if sort_by_time else name_key + x_sorted = lambda _x: sorted( + [f for f in ckpts_files if f.startswith(_x) and not f.endswith("_0.pth")], + key=sort_key, + ) + to_del = [ + os.path.join(path_to_models, fn) for fn in (x_sorted("G")[:-n_ckpts_to_keep] + x_sorted("D")[:-n_ckpts_to_keep]) + ] + del_info = lambda fn: logger.info(f".. Free up space by deleting ckpt {fn}") + del_routine = lambda x: [os.remove(x), del_info(x)] + rs = [del_routine(fn) for fn in to_del] + + +def get_hparams_from_dir(model_dir): + config_save_path = os.path.join(model_dir, "config.json") + with open(config_save_path, "r") as f: + data = f.read() + config = json.loads(data) + + hparams = HParams(**config) + hparams.model_dir = model_dir + return hparams + + +def get_hparams_from_file(config_path): + with open(config_path, "r") as f: + data = f.read() + config = json.loads(data) + + hparams = HParams(**config) + return hparams + + +def check_git_hash(model_dir): + source_dir = os.path.dirname(os.path.realpath(__file__)) + if not os.path.exists(os.path.join(source_dir, ".git")): + logger.warning( + "{} is not a git repository, therefore hash value comparison will be ignored.".format( + source_dir, + ) + ) + return + + cur_hash = subprocess.getoutput("git rev-parse HEAD") + + path = os.path.join(model_dir, "githash") + if os.path.exists(path): + saved_hash = open(path).read() + if saved_hash != cur_hash: + logger.warning( + "git hash values are different. 
{}(saved) != {}(current)".format( + saved_hash[:8], + cur_hash[:8], + ) + ) + else: + open(path, "w").write(cur_hash) + + +def get_logger(model_dir, filename="train.log"): + global logger + logger = logging.getLogger(os.path.basename(model_dir)) + logger.setLevel(logging.INFO) + + formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s") + if not os.path.exists(model_dir): + os.makedirs(model_dir) + h = logging.FileHandler(os.path.join(model_dir, filename)) + h.setLevel(logging.INFO) + h.setFormatter(formatter) + logger.addHandler(h) + return logger + + +class HParams: + def __init__(self, **kwargs): + for k, v in kwargs.items(): + if type(v) == dict: + v = HParams(**v) + self[k] = v + + def keys(self): + return self.__dict__.keys() + + def items(self): + return self.__dict__.items() + + def values(self): + return self.__dict__.values() + + def __len__(self): + return len(self.__dict__) + + def __getitem__(self, key): + return getattr(self, key) + + def __setitem__(self, key, value): + return setattr(self, key, value) + + def __contains__(self, key): + return key in self.__dict__ + + def __repr__(self): + return self.__dict__.__repr__() + + +if __name__ == "__main__": + print( + load_wav_to_torch( + "/home/fish/wenetspeech/dataset_vq/Y0000022499_wHFSeHEx9CM/S00261.flac", + ) + )
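
The checkpoint helpers in utils.py are meant to round-trip through my_save(), which writes to a temporary file and then moves it into place so non-ASCII output paths survive torch.save. Below is a minimal usage sketch, not part of the diff: the toy Linear model, the ./logs_demo directory, and the G_100.pth name are illustrative assumptions, and utils.py is assumed importable as `utils`.

import os

import torch

from utils import latest_checkpoint_path, load_checkpoint, save_checkpoint

# Toy model/optimizer pair standing in for the real generator.
model = torch.nn.Linear(4, 4)
optim = torch.optim.AdamW(model.parameters(), lr=1e-4)

ckpt_dir = "./logs_demo"
os.makedirs(ckpt_dir, exist_ok=True)

# save_checkpoint() stores model/optimizer state plus lr and iteration via my_save().
save_checkpoint(model, optim, 1e-4, 100, os.path.join(ckpt_dir, "G_100.pth"))

# latest_checkpoint_path() returns the matching file with the highest numeric suffix.
ckpt = latest_checkpoint_path(ckpt_dir, "G_*.pth")

# load_checkpoint() copies matching keys and falls back to the model's own
# weights when a key is missing or its shape no longer matches.
model, optim, lr, it = load_checkpoint(ckpt, model, optim)
print(lr, it)  # 0.0001 100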
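
get_hparams_from_file() and the HParams container above expose the JSON config through nested attribute-style access. A small sketch follows, not part of the diff: the config keys and the temporary file are made-up examples, and utils.py is again assumed importable as `utils`.

import json
import tempfile

from utils import get_hparams_from_file

# Made-up config that only loosely mirrors the shape of configs/s2.json.
config = {
    "train": {"batch_size": 16, "learning_rate": 1e-4},
    "data": {"sampling_rate": 32000},
}

with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump(config, f)
    config_path = f.name

hps = get_hparams_from_file(config_path)
print(hps.train.batch_size)           # 16 -- nested dicts are wrapped in HParams
print(hps["data"]["sampling_rate"])   # 32000 -- item access also works
print("train" in hps, len(hps))       # True 2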