kemuriririn committed
Commit 60ea83f · 1 parent: ad2f564
This view is limited to 50 files because the commit contains too many changes; see the raw diff for the complete change set.
Files changed (50)
  1. .gitattributes +8 -0
  2. .gitignore +3 -0
  3. AR/__init__.py +0 -0
  4. AR/data/__init__.py +0 -0
  5. AR/data/bucket_sampler.py +149 -0
  6. AR/data/data_module.py +81 -0
  7. AR/data/dataset.py +320 -0
  8. AR/models/__init__.py +0 -0
  9. AR/models/t2s_lightning_module.py +146 -0
  10. AR/models/t2s_lightning_module_onnx.py +110 -0
  11. AR/models/t2s_model.py +935 -0
  12. AR/models/t2s_model_onnx.py +394 -0
  13. AR/models/utils.py +282 -0
  14. AR/modules/__init__.py +0 -0
  15. AR/modules/activation.py +413 -0
  16. AR/modules/activation_onnx.py +188 -0
  17. AR/modules/embedding.py +78 -0
  18. AR/modules/embedding_onnx.py +63 -0
  19. AR/modules/lr_schedulers.py +85 -0
  20. AR/modules/optim.py +593 -0
  21. AR/modules/patched_mha_with_cache.py +428 -0
  22. AR/modules/patched_mha_with_cache_onnx.py +85 -0
  23. AR/modules/scaling.py +320 -0
  24. AR/modules/transformer.py +362 -0
  25. AR/modules/transformer_onnx.py +281 -0
  26. AR/text_processing/__init__.py +0 -0
  27. AR/text_processing/phonemizer.py +72 -0
  28. AR/text_processing/symbols.py +12 -0
  29. AR/utils/__init__.py +36 -0
  30. AR/utils/initialize.py +39 -0
  31. AR/utils/io.py +30 -0
  32. BigVGAN/LICENSE +21 -0
  33. BigVGAN/README.md +266 -0
  34. BigVGAN/activations.py +122 -0
  35. BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
  36. BigVGAN/alias_free_activation/cuda/activation1d.py +69 -0
  37. BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  38. BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
  39. BigVGAN/alias_free_activation/cuda/build/_ +1 -0
  40. BigVGAN/alias_free_activation/cuda/compat.h +29 -0
  41. BigVGAN/alias_free_activation/cuda/load.py +82 -0
  42. BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
  43. BigVGAN/alias_free_activation/torch/__init__.py +6 -0
  44. BigVGAN/alias_free_activation/torch/act.py +30 -0
  45. BigVGAN/alias_free_activation/torch/filter.py +99 -0
  46. BigVGAN/alias_free_activation/torch/resample.py +48 -0
  47. BigVGAN/bigvgan.py +461 -0
  48. BigVGAN/configs/bigvgan_22khz_80band.json +45 -0
  49. BigVGAN/configs/bigvgan_24khz_100band.json +45 -0
  50. BigVGAN/configs/bigvgan_base_22khz_80band.json +45 -0
.gitattributes CHANGED
@@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+text/ja_userdic/userdict.csv filter=lfs diff=lfs merge=lfs -text
+text/g2pw/polyphonic-fix.rep filter=lfs diff=lfs merge=lfs -text
+text/g2pw/polyphonic.pickle filter=lfs diff=lfs merge=lfs -text
+text/g2pw/polyphonic.rep filter=lfs diff=lfs merge=lfs -text
+text/G2PWModel/char_bopomofo_dict.json filter=lfs diff=lfs merge=lfs -text
+text/cmudict-fast.rep filter=lfs diff=lfs merge=lfs -text
+text/cmudict.rep filter=lfs diff=lfs merge=lfs -text
+text/engdict-hot.rep filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,3 @@
+*/.DS_Store
+.DS_Store
+.idea/
AR/__init__.py ADDED
File without changes
AR/data/__init__.py ADDED
File without changes
AR/data/bucket_sampler.py ADDED
@@ -0,0 +1,149 @@
+# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/bucket_sampler.py
+# reference: https://github.com/lifeiteng/vall-e
+import itertools
+import math
+import random
+from random import shuffle
+from typing import Iterator, Optional, TypeVar
+
+import torch
+import torch.distributed as dist
+from torch.utils.data import Dataset, Sampler
+
+__all__ = [
+    "DistributedBucketSampler",
+]
+
+T_co = TypeVar("T_co", covariant=True)
+
+
+class DistributedBucketSampler(Sampler[T_co]):
+    r"""
+    sort the dataset wrt. input length
+    divide samples into buckets
+    sort within buckets
+    divide buckets into batches
+    sort batches
+    """
+
+    def __init__(
+        self,
+        dataset: Dataset,
+        num_replicas: Optional[int] = None,
+        rank: Optional[int] = None,
+        shuffle: bool = True,
+        seed: int = 0,
+        drop_last: bool = False,
+        batch_size: int = 32,
+    ) -> None:
+        if num_replicas is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            num_replicas = dist.get_world_size() if torch.cuda.is_available() else 1
+        if rank is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            rank = dist.get_rank() if torch.cuda.is_available() else 0
+        if torch.cuda.is_available():
+            torch.cuda.set_device(rank)
+        if rank >= num_replicas or rank < 0:
+            raise ValueError("Invalid rank {}, rank should be in the interval [0, {}]".format(rank, num_replicas - 1))
+        self.dataset = dataset
+        self.num_replicas = num_replicas
+        self.rank = rank
+        self.epoch = 0
+        self.drop_last = drop_last
+        # If the dataset length is evenly divisible by # of replicas, then there
+        # is no need to drop any data, since the dataset will be split equally.
+        if self.drop_last and len(self.dataset) % self.num_replicas != 0:  # type: ignore[arg-type]
+            # Split to nearest available length that is evenly divisible.
+            # This is to ensure each rank receives the same amount of data when
+            # using this Sampler.
+            self.num_samples = math.ceil(
+                (len(self.dataset) - self.num_replicas) / self.num_replicas,  # type: ignore[arg-type]
+            )
+        else:
+            self.num_samples = math.ceil(
+                len(self.dataset) / self.num_replicas,
+            )  # type: ignore[arg-type]
+        self.total_size = self.num_samples * self.num_replicas
+        self.shuffle = shuffle
+        self.seed = seed
+        self.batch_size = batch_size
+        self.id_with_length = self._get_sample_lengths()
+        self.id_buckets = self.make_buckets(bucket_width=2.0)
+
+    def _get_sample_lengths(self):
+        id_with_lengths = []
+        for i in range(len(self.dataset)):
+            id_with_lengths.append((i, self.dataset.get_sample_length(i)))
+        id_with_lengths.sort(key=lambda x: x[1])
+        return id_with_lengths
+
+    def make_buckets(self, bucket_width: float = 2.0):
+        buckets = []
+        cur = []
+        max_sec = bucket_width
+        for id, sec in self.id_with_length:
+            if sec < max_sec:
+                cur.append(id)
+            else:
+                buckets.append(cur)
+                cur = [id]
+                max_sec += bucket_width
+        if len(cur) > 0:
+            buckets.append(cur)
+        return buckets
+
+    def __iter__(self) -> Iterator[T_co]:
+        if self.shuffle:
+            # deterministically shuffle based on epoch and seed
+            g = torch.Generator()
+            g.manual_seed(self.seed + self.epoch)
+            random.seed(self.epoch + self.seed)
+            shuffled_bucket = []
+            for buc in self.id_buckets:
+                buc_copy = buc.copy()
+                shuffle(buc_copy)
+                shuffled_bucket.append(buc_copy)
+            grouped_batch_size = self.batch_size * self.num_replicas
+            shuffled_bucket = list(itertools.chain(*shuffled_bucket))
+            n_batch = int(math.ceil(len(shuffled_bucket) / grouped_batch_size))
+            batches = [shuffled_bucket[b * grouped_batch_size : (b + 1) * grouped_batch_size] for b in range(n_batch)]
+            shuffle(batches)
+            indices = list(itertools.chain(*batches))
+        else:
+            # type: ignore[arg-type]
+            indices = list(range(len(self.dataset)))
+
+        if not self.drop_last:
+            # add extra samples to make it evenly divisible
+            padding_size = self.total_size - len(indices)
+            if padding_size <= len(indices):
+                indices += indices[:padding_size]
+            else:
+                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
+        else:
+            # remove tail of data to make it evenly divisible.
+            indices = indices[: self.total_size]
+        assert len(indices) == self.total_size
+
+        # subsample
+        indices = indices[self.rank : self.total_size : self.num_replicas]
+        assert len(indices) == self.num_samples
+
+        return iter(indices)
+
+    def __len__(self) -> int:
+        return self.num_samples
+
+    def set_epoch(self, epoch: int) -> None:
+        r"""
+        Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
+        use a different random ordering for each epoch. Otherwise, the next iteration of this
+        sampler will yield the same ordering.
+
+        Args:
+            epoch (int): Epoch number.
+        """
+        self.epoch = epoch
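As a quick usage sketch (not part of the commit): the sampler only needs a dataset exposing __len__ and get_sample_length. The ToyDataset below and the single-replica arguments are made up for illustration and assume a run without an initialized torch.distributed process group.

# Hypothetical single-process demo of DistributedBucketSampler; ToyDataset is made up.
import torch
from torch.utils.data import Dataset

from AR.data.bucket_sampler import DistributedBucketSampler


class ToyDataset(Dataset):
    """Fake dataset whose only meaningful attribute is a per-sample duration in seconds."""

    def __init__(self, seconds):
        self.seconds = seconds

    def __len__(self):
        return len(self.seconds)

    def __getitem__(self, idx):
        return torch.zeros(int(self.seconds[idx] * 25))  # dummy payload

    def get_sample_length(self, idx):
        return self.seconds[idx]  # used by the sampler to build length buckets


ds = ToyDataset([0.5, 3.2, 1.1, 7.8, 2.0, 4.4, 0.9, 5.5])
# Pass num_replicas/rank explicitly so no torch.distributed setup is required.
sampler = DistributedBucketSampler(ds, num_replicas=1, rank=0, batch_size=4)
print(list(iter(sampler)))  # indices grouped so each batch has similar durations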
AR/data/data_module.py ADDED
@@ -0,0 +1,81 @@
+# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/data_module.py
+# reference: https://github.com/lifeiteng/vall-e
+from pytorch_lightning import LightningDataModule
+from torch.utils.data import DataLoader
+
+from AR.data.bucket_sampler import DistributedBucketSampler
+from AR.data.dataset import Text2SemanticDataset
+
+
+class Text2SemanticDataModule(LightningDataModule):
+    def __init__(
+        self,
+        config,
+        train_semantic_path,
+        train_phoneme_path,
+        dev_semantic_path=None,
+        dev_phoneme_path=None,
+    ):
+        super().__init__()
+        self.config = config
+        self.train_semantic_path = train_semantic_path
+        self.train_phoneme_path = train_phoneme_path
+        self.dev_semantic_path = dev_semantic_path
+        self.dev_phoneme_path = dev_phoneme_path
+        self.num_workers = self.config["data"]["num_workers"]
+
+    def prepare_data(self):
+        pass
+
+    def setup(self, stage=None, output_logs=False):
+        self._train_dataset = Text2SemanticDataset(
+            phoneme_path=self.train_phoneme_path,
+            semantic_path=self.train_semantic_path,
+            max_sec=self.config["data"]["max_sec"],
+            pad_val=self.config["data"]["pad_val"],
+        )
+        self._dev_dataset = self._train_dataset
+        # self._dev_dataset = Text2SemanticDataset(
+        #     phoneme_path=self.dev_phoneme_path,
+        #     semantic_path=self.dev_semantic_path,
+        #     max_sample=self.config['data']['max_eval_sample'],
+        #     max_sec=self.config['data']['max_sec'],
+        #     pad_val=self.config['data']['pad_val'])
+
+    def train_dataloader(self):
+        batch_size = (
+            self.config["train"]["batch_size"] // 2
+            if self.config["train"].get("if_dpo", False) is True
+            else self.config["train"]["batch_size"]
+        )
+        batch_size = max(min(batch_size, len(self._train_dataset) // 4), 1)  # keep the batch size small enough that checkpoints still get saved
+        sampler = DistributedBucketSampler(self._train_dataset, batch_size=batch_size)
+        return DataLoader(
+            self._train_dataset,
+            batch_size=batch_size,
+            sampler=sampler,
+            collate_fn=self._train_dataset.collate,
+            num_workers=self.num_workers,
+            persistent_workers=True,
+            prefetch_factor=16,
+        )
+
+    def val_dataloader(self):
+        return DataLoader(
+            self._dev_dataset,
+            batch_size=1,
+            shuffle=False,
+            collate_fn=self._train_dataset.collate,
+            num_workers=max(self.num_workers, 12),
+            persistent_workers=True,
+            prefetch_factor=16,
+        )
+
+    # is this ever used?
+    def test_dataloader(self):
+        return DataLoader(
+            self._dev_dataset,
+            batch_size=1,
+            shuffle=False,
+            collate_fn=self._train_dataset.collate,
+        )
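For context, a wiring sketch of how this DataModule is typically constructed. The config keys mirror the ones read above; the paths are placeholders for the 2-name2text.txt / 6-name2semantic.tsv files produced by the preprocessing stage and are not part of this commit.

# Sketch only: placeholder config and paths, not a runnable end-to-end script.
from AR.data.data_module import Text2SemanticDataModule

config = {
    "data": {"num_workers": 4, "max_sec": 54, "pad_val": 1024},
    "train": {"batch_size": 8, "if_dpo": False},
}

dm = Text2SemanticDataModule(
    config,
    train_semantic_path="logs/exp/6-name2semantic.tsv",  # placeholder path
    train_phoneme_path="logs/exp/2-name2text.txt",       # placeholder path
)
dm.setup()                      # builds the Text2SemanticDataset
loader = dm.train_dataloader()  # bucketed batches via DistributedBucketSampler
# Note: on a CUDA machine the sampler asks torch.distributed for the world size,
# so a process group must be initialized before calling train_dataloader().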
AR/data/dataset.py ADDED
@@ -0,0 +1,320 @@
+# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/dataset.py
+# reference: https://github.com/lifeiteng/vall-e
+
+# sys.path.append("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert")
+import os
+import traceback
+from typing import Dict, List
+
+import numpy as np
+import pandas as pd
+import torch
+from torch.utils.data import DataLoader, Dataset
+
+version = os.environ.get("version", None)
+
+from text import cleaned_text_to_sequence
+
+# from config import exp_dir
+
+
+def batch_sequences(sequences: List[np.array], axis: int = 0, pad_value: int = 0):
+    seq = sequences[0]
+    ndim = seq.ndim
+    if axis < 0:
+        axis += ndim
+    dtype = seq.dtype
+    pad_value = dtype.type(pad_value)
+    seq_lengths = [seq.shape[axis] for seq in sequences]
+    max_length = np.max(seq_lengths)
+
+    padded_sequences = []
+    for seq, length in zip(sequences, seq_lengths):
+        padding = [(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (ndim - axis - 1)
+        padded_seq = np.pad(seq, padding, mode="constant", constant_values=pad_value)
+        padded_sequences.append(padded_seq)
+    batch = np.stack(padded_sequences)
+    return batch
+
+
+class Text2SemanticDataset(Dataset):
+    """dataset class for text tokens to semantic model training."""
+
+    def __init__(
+        self,
+        phoneme_path: str,
+        semantic_path: str,
+        max_sample: int = None,
+        max_sec: int = 100,
+        pad_val: int = 1024,
+        # min value of phoneme/sec
+        min_ps_ratio: int = 3,
+        # max value of phoneme/sec
+        max_ps_ratio: int = 25,
+    ) -> None:
+        super().__init__()
+
+        self.semantic_data = pd.read_csv(
+            semantic_path,
+            delimiter="\t",
+            encoding="utf-8",
+        )
+        # get dict
+        self.path2 = phoneme_path  # "%s/2-name2text.txt"%exp_dir#phoneme_path
+        self.path3 = "%s/3-bert" % (
+            os.path.dirname(
+                phoneme_path,
+            )
+        )  # "%s/3-bert"%exp_dir#bert_dir
+        self.path6 = semantic_path  # "%s/6-name2semantic.tsv"%exp_dir#semantic_path
+        assert os.path.exists(self.path2)
+        assert os.path.exists(self.path6)
+        self.phoneme_data = {}
+        with open(self.path2, "r", encoding="utf8") as f:
+            lines = f.read().strip("\n").split("\n")
+
+        for line in lines:
+            tmp = line.split("\t")
+            if len(tmp) != 4:
+                continue
+            self.phoneme_data[tmp[0]] = [tmp[1], tmp[2], tmp[3]]
+
+        # self.phoneme_data = np.load(phoneme_path, allow_pickle=True).item()
+        # pad for semantic tokens
+        self.PAD: int = pad_val
+        # self.hz = 25
+        # with open("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert/configs/s2.json", "r") as f:data = f.read()
+        # data=json.loads(data)["model"]["semantic_frame_rate"]#50hz
+        # self.hz=int(data[:-2])#
+        self.hz = int(os.environ.get("hz", "25hz")[:-2])
+
+        # max seconds of semantic token
+        self.max_sec = max_sec
+        self.min_ps_ratio = min_ps_ratio
+        self.max_ps_ratio = max_ps_ratio
+
+        if max_sample is not None:
+            self.semantic_data = self.semantic_data[:max_sample]
+
+        # {idx: (semantic, phoneme)}
+        # semantic list, phoneme list
+        self.semantic_phoneme = []
+        self.item_names = []
+
+        self.inited = False
+
+        if not self.inited:
+            # run the one-time initialization
+            self.init_batch()
+            self.inited = True
+            del self.semantic_data
+            del self.phoneme_data
+        # self.tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext-large")
+        # self.tokenizer = AutoTokenizer.from_pretrained("/data/docker/liujing04/bert-vits2/Bert-VITS2-master20231106/bert/chinese-roberta-wwm-ext-large")
+
+    def init_batch(self):
+        semantic_data_len = len(self.semantic_data)
+        phoneme_data_len = len(self.phoneme_data.keys())
+        print("semantic_data_len:", semantic_data_len)
+        print("phoneme_data_len:", phoneme_data_len)
+        print(self.semantic_data)
+        idx = 0
+        num_not_in = 0
+        num_deleted_bigger = 0
+        num_deleted_ps = 0
+        for i in range(semantic_data_len):
+            # iterate through the rows one by one
+            # get str
+            item_name = self.semantic_data.iloc[i, 0]
+            # print(self.phoneme_data)
+            try:
+                phoneme, word2ph, text = self.phoneme_data[item_name]
+            except Exception:
+                traceback.print_exc()
+                # print(f"{item_name} not in self.phoneme_data !")
+                num_not_in += 1
+                continue
+
+            semantic_str = self.semantic_data.iloc[i, 1]
+            # get token list
+            semantic_ids = [int(idx) for idx in semantic_str.split(" ")]
+            # (T); no need to reshape to (1, T), because len() is needed below
+            # drop samples that are too long
+            if (
+                len(semantic_ids) > self.max_sec * self.hz
+            ):  # 1: estimate the total duration from the token count and drop clips longer than max_sec (60s in the config); 40*25=1k
+                num_deleted_bigger += 1
+                continue
+            # (T,); this is not slow, so it can all be done up front instead of per item in __getitem__
+            phoneme = phoneme.split(" ")
+
+            try:
+                phoneme_ids = cleaned_text_to_sequence(phoneme, version)
+            except:
+                traceback.print_exc()
+                # print(f"{item_name} not in self.phoneme_data !")
+                num_not_in += 1
+                continue
+            # if len(phoneme_ids) > 400:  # 2: cap at a constant semantic/2.5 instead
+            if len(phoneme_ids) > self.max_sec * self.hz / 2.5:  # 2: cap the phoneme length at semantic length / 2.5
+                num_deleted_ps += 1
+                continue
+            # if len(semantic_ids) > 1000:  # 3
+            #     num_deleted_bigger += 1
+            #     continue
+
+            ps_ratio = len(phoneme_ids) / (len(semantic_ids) / self.hz)
+
+            if ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio:  # 4: keep only 3~25 phonemes per second
+                num_deleted_ps += 1
+                # print(item_name)
+                continue
+
+            self.semantic_phoneme.append((semantic_ids, phoneme_ids))
+            idx += 1
+            self.item_names.append(item_name)
+
+        min_num = 100  # with 20 items no padding is done at all; with 30, even padded, no checkpoint gets saved
+        leng = len(self.semantic_phoneme)
+        if leng < min_num:
+            tmp1 = self.semantic_phoneme
+            tmp2 = self.item_names
+            self.semantic_phoneme = []
+            self.item_names = []
+            for _ in range(max(2, int(min_num / leng))):
+                self.semantic_phoneme += tmp1
+                self.item_names += tmp2
+        if num_not_in > 0:
+            print(f"there are {num_not_in} semantic datas not in phoneme datas")
+        if num_deleted_bigger > 0:
+            print(
+                f"deleted {num_deleted_bigger} audios who's duration are bigger than {self.max_sec} seconds",
+            )
+        if num_deleted_ps > 0:
+            # 4702 for LibriTTS; LibriTTS is labelled data, does it need filtering? => yes, there are extreme values of 100
+            print(
+                f"deleted {num_deleted_ps} audios who's phoneme/sec are bigger than {self.max_ps_ratio} or smaller than {self.min_ps_ratio}",
+            )
+        """
+        there are 31 semantic datas not in phoneme datas
+        deleted 34 audios who's duration are bigger than 54 seconds
+        deleted 3190 audios who's phoneme/sec are bigger than 25 or smaller than 3
+        dataset.__len__(): 366463

+        """
+        # 345410 for LibriTTS
+        print("dataset.__len__():", self.__len__())
+
+    def __get_item_names__(self) -> List[str]:
+        return self.item_names
+
+    def __len__(self) -> int:
+        return len(self.semantic_phoneme)
+
+    def __getitem__(self, idx: int) -> Dict:
+        semantic_ids, phoneme_ids = self.semantic_phoneme[idx]
+        item_name = self.item_names[idx]
+        phoneme_ids_len = len(phoneme_ids)
+        # semantic tokens target
+        semantic_ids_len = len(semantic_ids)
+
+        flag = 0
+        path_bert = "%s/%s.pt" % (self.path3, item_name)
+        if os.path.exists(path_bert) == True:
+            bert_feature = torch.load(path_bert, map_location="cpu")
+        else:
+            flag = 1
+        if flag == 1:
+            # bert_feature=torch.zeros_like(phoneme_ids,dtype=torch.float32)
+            bert_feature = None
+        else:
+            assert bert_feature.shape[-1] == len(phoneme_ids)
+        return {
+            "idx": idx,
+            "phoneme_ids": phoneme_ids,
+            "phoneme_ids_len": phoneme_ids_len,
+            "semantic_ids": semantic_ids,
+            "semantic_ids_len": semantic_ids_len,
+            "bert_feature": bert_feature,
+        }
+
+    def get_sample_length(self, idx: int):
+        semantic_ids = self.semantic_phoneme[idx][0]
+        sec = 1.0 * len(semantic_ids) / self.hz
+        return sec
+
+    def collate(self, examples: List[Dict]) -> Dict:
+        sample_index: List[int] = []
+        phoneme_ids: List[torch.Tensor] = []
+        phoneme_ids_lens: List[int] = []
+        semantic_ids: List[torch.Tensor] = []
+        semantic_ids_lens: List[int] = []
+        # return
+
+        for item in examples:
+            sample_index.append(item["idx"])
+            phoneme_ids.append(np.array(item["phoneme_ids"], dtype=np.int64))
+            semantic_ids.append(np.array(item["semantic_ids"], dtype=np.int64))
+            phoneme_ids_lens.append(item["phoneme_ids_len"])
+            semantic_ids_lens.append(item["semantic_ids_len"])
+
+        # pad 0
+        phoneme_ids = batch_sequences(phoneme_ids)
+        semantic_ids = batch_sequences(semantic_ids, pad_value=self.PAD)
+
+        # # convert each batch to torch.tensor
+        phoneme_ids = torch.tensor(phoneme_ids)
+        semantic_ids = torch.tensor(semantic_ids)
+        phoneme_ids_lens = torch.tensor(phoneme_ids_lens)
+        semantic_ids_lens = torch.tensor(semantic_ids_lens)
+        bert_padded = torch.FloatTensor(len(examples), 1024, max(phoneme_ids_lens))
+        bert_padded.zero_()
+
+        for idx, item in enumerate(examples):
+            bert = item["bert_feature"]
+            if bert != None:
+                bert_padded[idx, :, : bert.shape[-1]] = bert
+
+        return {
+            # List[int]
+            "ids": sample_index,
+            # torch.Tensor (B, max_phoneme_length)
+            "phoneme_ids": phoneme_ids,
+            # torch.Tensor (B)
+            "phoneme_ids_len": phoneme_ids_lens,
+            # torch.Tensor (B, max_semantic_ids_length)
+            "semantic_ids": semantic_ids,
+            # torch.Tensor (B)
+            "semantic_ids_len": semantic_ids_lens,
+            # torch.Tensor (B, 1024, max_phoneme_length)
+            "bert_feature": bert_padded,
+        }
+
+
+if __name__ == "__main__":
+    root_dir = "/data/docker/liujing04/gpt-vits/prepare/dump_mix/"
+    dataset = Text2SemanticDataset(
+        phoneme_path=root_dir + "phoneme_train.npy",
+        semantic_path=root_dir + "semantic_train.tsv",
+    )
+
+    batch_size = 12
+    dataloader = DataLoader(
+        dataset,
+        batch_size=batch_size,
+        collate_fn=dataset.collate,
+        shuffle=False,
+    )
+    for i, batch in enumerate(dataloader):
+        if i % 1000 == 0:
+            print(i)
+        # if i == 0:
+        #     print('batch["ids"]:', batch["ids"])
+        #     print('batch["phoneme_ids"]:', batch["phoneme_ids"],
+        #           batch["phoneme_ids"].shape)
+        #     print('batch["phoneme_ids_len"]:', batch["phoneme_ids_len"],
+        #           batch["phoneme_ids_len"].shape)
+        #     print('batch["semantic_ids"]:', batch["semantic_ids"],
+        #           batch["semantic_ids"].shape)
+        #     print('batch["semantic_ids_len"]:', batch["semantic_ids_len"],
+        #           batch["semantic_ids_len"].shape)
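batch_sequences right-pads every array in the batch to the longest one along axis and stacks the result; a small self-contained check of that behaviour (toy arrays, run from the repository root so AR and text are importable):

# Toy check of batch_sequences: right-pad to the longest sequence with pad_value.
import numpy as np

from AR.data.dataset import batch_sequences

seqs = [np.array([1, 2, 3]), np.array([4, 5]), np.array([6])]
batch = batch_sequences(seqs, pad_value=1024)
print(batch.shape)  # (3, 3)
print(batch)
# expected:
# [[   1    2    3]
#  [   4    5 1024]
#  [   6 1024 1024]]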
AR/models/__init__.py ADDED
File without changes
AR/models/t2s_lightning_module.py ADDED
@@ -0,0 +1,146 @@
+# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_lightning_module.py
+# reference: https://github.com/lifeiteng/vall-e
+import os
+import sys
+
+now_dir = os.getcwd()
+sys.path.append(now_dir)
+from typing import Dict
+
+import torch
+from pytorch_lightning import LightningModule
+
+from AR.models.t2s_model import Text2SemanticDecoder
+from AR.modules.lr_schedulers import WarmupCosineLRSchedule
+from AR.modules.optim import ScaledAdam
+
+
+class Text2SemanticLightningModule(LightningModule):
+    def __init__(self, config, output_dir, is_train=True):
+        super().__init__()
+        self.config = config
+        self.top_k = 3
+        self.model = Text2SemanticDecoder(config=config, top_k=self.top_k)
+        pretrained_s1 = config.get("pretrained_s1")
+        if pretrained_s1 and is_train:
+            # print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
+            print(
+                self.load_state_dict(
+                    torch.load(
+                        pretrained_s1,
+                        map_location="cpu",
+                        weights_only=False,
+                    )["weight"],
+                )
+            )
+        if is_train:
+            self.automatic_optimization = False
+            self.save_hyperparameters()
+            self.eval_dir = output_dir / "eval"
+            self.eval_dir.mkdir(parents=True, exist_ok=True)
+
+    def training_step(self, batch: Dict, batch_idx: int):
+        opt = self.optimizers()
+        scheduler = self.lr_schedulers()
+        forward = self.model.forward if self.config["train"].get("if_dpo", False) == True else self.model.forward_old
+        loss, acc = forward(
+            batch["phoneme_ids"],
+            batch["phoneme_ids_len"],
+            batch["semantic_ids"],
+            batch["semantic_ids_len"],
+            batch["bert_feature"],
+        )
+        self.manual_backward(loss)
+        if batch_idx > 0 and batch_idx % 4 == 0:
+            opt.step()
+            opt.zero_grad()
+            scheduler.step()
+
+        self.log(
+            "total_loss",
+            loss,
+            on_step=True,
+            on_epoch=True,
+            prog_bar=True,
+            sync_dist=True,
+        )
+        self.log(
+            "lr",
+            scheduler.get_last_lr()[0],
+            on_epoch=True,
+            prog_bar=True,
+            sync_dist=True,
+        )
+        self.log(
+            f"top_{self.top_k}_acc",
+            acc,
+            on_step=True,
+            on_epoch=True,
+            prog_bar=True,
+            sync_dist=True,
+        )
+
+    def validation_step(self, batch: Dict, batch_idx: int):
+        return
+
+    # # get loss
+    # loss, acc = self.model.forward(
+    #     batch['phoneme_ids'], batch['phoneme_ids_len'],
+    #     batch['semantic_ids'], batch['semantic_ids_len'],
+    #     batch['bert_feature']
+    # )
+    #
+    # self.log(
+    #     "val_total_loss",
+    #     loss,
+    #     on_step=True,
+    #     on_epoch=True,
+    #     prog_bar=True,
+    #     sync_dist=True)
+    # self.log(
+    #     f"val_top_{self.top_k}_acc",
+    #     acc,
+    #     on_step=True,
+    #     on_epoch=True,
+    #     prog_bar=True,
+    #     sync_dist=True)
+    #
+    # # get infer output
+    # semantic_len = batch['semantic_ids'].size(1)
+    # prompt_len = min(int(semantic_len * 0.5), 150)
+    # prompt = batch['semantic_ids'][:, :prompt_len]
+    # pred_semantic = self.model.infer(batch['phoneme_ids'],
+    #                                  batch['phoneme_ids_len'], prompt,
+    #                                  batch['bert_feature']
+    #                                  )
+    # save_name = f'semantic_toks_{batch_idx}.pt'
+    # save_path = os.path.join(self.eval_dir, save_name)
+    # torch.save(pred_semantic.detach().cpu(), save_path)
+
+    def configure_optimizers(self):
+        model_parameters = self.model.parameters()
+        parameters_names = []
+        parameters_names.append([name_param_pair[0] for name_param_pair in self.model.named_parameters()])
+        lm_opt = ScaledAdam(
+            model_parameters,
+            lr=0.01,
+            betas=(0.9, 0.95),
+            clipping_scale=2.0,
+            parameters_names=parameters_names,
+            show_dominant_parameters=False,
+            clipping_update_period=1000,
+        )
+
+        return {
+            "optimizer": lm_opt,
+            "lr_scheduler": {
+                "scheduler": WarmupCosineLRSchedule(
+                    lm_opt,
+                    init_lr=self.config["optimizer"]["lr_init"],
+                    peak_lr=self.config["optimizer"]["lr"],
+                    end_lr=self.config["optimizer"]["lr_end"],
+                    warmup_steps=self.config["optimizer"]["warmup_steps"],
+                    total_steps=self.config["optimizer"]["decay_steps"],
+                )
+            },
+        }
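training_step above disables automatic optimization and only steps the optimizer and scheduler every fourth batch, i.e. 4-step gradient accumulation. A stripped-down sketch of that pattern with a toy model and stock AdamW/StepLR stand-ins (not the project's ScaledAdam/WarmupCosineLRSchedule):

# Pattern sketch only: manual optimization with "step every 4 batches" accumulation.
import torch
from pytorch_lightning import LightningModule


class AccumDemo(LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # we call opt.step() ourselves
        self.layer = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        sch = self.lr_schedulers()
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        self.manual_backward(loss)           # gradients accumulate across batches
        if batch_idx > 0 and batch_idx % 4 == 0:
            opt.step()                       # apply four batches' worth of gradients
            opt.zero_grad()
            sch.step()                       # scheduler advances once per optimizer step
        return loss

    def configure_optimizers(self):
        opt = torch.optim.AdamW(self.parameters(), lr=1e-3)
        sch = torch.optim.lr_scheduler.StepLR(opt, step_size=100)
        return {"optimizer": opt, "lr_scheduler": {"scheduler": sch}}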
AR/models/t2s_lightning_module_onnx.py ADDED
@@ -0,0 +1,110 @@
+# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_lightning_module.py
+# reference: https://github.com/lifeiteng/vall-e
+import os
+import sys
+
+now_dir = os.getcwd()
+sys.path.append(now_dir)
+from typing import Dict
+
+import torch
+from pytorch_lightning import LightningModule
+
+from AR.models.t2s_model_onnx import Text2SemanticDecoder
+from AR.modules.lr_schedulers import WarmupCosineLRSchedule
+from AR.modules.optim import ScaledAdam
+
+
+class Text2SemanticLightningModule(LightningModule):
+    def __init__(self, config, output_dir, is_train=True):
+        super().__init__()
+        self.config = config
+        self.top_k = 3
+        self.model = Text2SemanticDecoder(config=config, top_k=self.top_k)
+        pretrained_s1 = config.get("pretrained_s1")
+        if pretrained_s1 and is_train:
+            # print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
+            print(
+                self.load_state_dict(
+                    torch.load(
+                        pretrained_s1,
+                        map_location="cpu",
+                    )["weight"],
+                ),
+            )
+        if is_train:
+            self.automatic_optimization = False
+            self.save_hyperparameters()
+            self.eval_dir = output_dir / "eval"
+            self.eval_dir.mkdir(parents=True, exist_ok=True)
+
+    def training_step(self, batch: Dict, batch_idx: int):
+        opt = self.optimizers()
+        scheduler = self.lr_schedulers()
+        loss, acc = self.model.forward(
+            batch["phoneme_ids"],
+            batch["phoneme_ids_len"],
+            batch["semantic_ids"],
+            batch["semantic_ids_len"],
+            batch["bert_feature"],
+        )
+        self.manual_backward(loss)
+        if batch_idx > 0 and batch_idx % 4 == 0:
+            opt.step()
+            opt.zero_grad()
+            scheduler.step()
+
+        self.log(
+            "total_loss",
+            loss,
+            on_step=True,
+            on_epoch=True,
+            prog_bar=True,
+            sync_dist=True,
+        )
+        self.log(
+            "lr",
+            scheduler.get_last_lr()[0],
+            on_epoch=True,
+            prog_bar=True,
+            sync_dist=True,
+        )
+        self.log(
+            f"top_{self.top_k}_acc",
+            acc,
+            on_step=True,
+            on_epoch=True,
+            prog_bar=True,
+            sync_dist=True,
+        )
+
+    def validation_step(self, batch: Dict, batch_idx: int):
+        return
+
+    def configure_optimizers(self):
+        model_parameters = self.model.parameters()
+        parameters_names = []
+        parameters_names.append([name_param_pair[0] for name_param_pair in self.model.named_parameters()])
+        lm_opt = ScaledAdam(
+            model_parameters,
+            lr=0.01,
+            betas=(0.9, 0.95),
+            clipping_scale=2.0,
+            parameters_names=parameters_names,
+            show_dominant_parameters=False,
+            clipping_update_period=1000,
+        )
+
+        return {
+            "optimizer": lm_opt,
+            "lr_scheduler": {
+                "scheduler": WarmupCosineLRSchedule(
+                    lm_opt,
+                    init_lr=self.config["optimizer"]["lr_init"],
+                    peak_lr=self.config["optimizer"]["lr"],
+                    end_lr=self.config["optimizer"]["lr_end"],
+                    warmup_steps=self.config["optimizer"]["warmup_steps"],
+                    total_steps=self.config["optimizer"]["decay_steps"],
+                )
+            },
+        }
AR/models/t2s_model.py ADDED
@@ -0,0 +1,935 @@
1
+ # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
2
+ # reference: https://github.com/lifeiteng/vall-e
3
+ import math
4
+ from typing import List, Optional
5
+
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+ from torchmetrics.classification import MulticlassAccuracy
10
+ from tqdm import tqdm
11
+
12
+ from AR.models.utils import (
13
+ dpo_loss,
14
+ get_batch_logps,
15
+ make_pad_mask,
16
+ make_pad_mask_left,
17
+ make_reject_y,
18
+ sample,
19
+ topk_sampling,
20
+ )
21
+ from AR.modules.embedding import SinePositionalEmbedding, TokenEmbedding
22
+ from AR.modules.transformer import LayerNorm, TransformerEncoder, TransformerEncoderLayer
23
+
24
+ default_config = {
25
+ "embedding_dim": 512,
26
+ "hidden_dim": 512,
27
+ "num_head": 8,
28
+ "num_layers": 12,
29
+ "num_codebook": 8,
30
+ "p_dropout": 0.0,
31
+ "vocab_size": 1024 + 1,
32
+ "phoneme_vocab_size": 512,
33
+ "EOS": 1024,
34
+ }
35
+
36
+
37
+ # @torch.jit.script  ## if enabled, the first inference pass becomes very slow and inference speed is unstable
38
+ # Efficient implementation equivalent to the following:
39
+ def scaled_dot_product_attention(
40
+ query: torch.Tensor,
41
+ key: torch.Tensor,
42
+ value: torch.Tensor,
43
+ attn_mask: Optional[torch.Tensor] = None,
44
+ scale: Optional[torch.Tensor] = None,
45
+ ) -> torch.Tensor:
46
+ B, H, L, S = query.size(0), query.size(1), query.size(-2), key.size(-2)
47
+ if scale is None:
48
+ scale_factor = torch.tensor(1 / math.sqrt(query.size(-1)))
49
+ else:
50
+ scale_factor = scale
51
+ attn_bias = torch.zeros(B, H, L, S, dtype=query.dtype, device=query.device)
52
+
53
+ if attn_mask is not None:
54
+ if attn_mask.dtype == torch.bool:
55
+ attn_bias.masked_fill_(attn_mask, float("-inf"))
56
+ else:
57
+ attn_bias += attn_mask
58
+ attn_weight = query @ key.transpose(-2, -1) * scale_factor
59
+ attn_weight += attn_bias
60
+ attn_weight = torch.softmax(attn_weight, dim=-1)
61
+
62
+ if attn_mask is not None:
63
+ if attn_mask.dtype == torch.bool:
64
+ attn_weight.masked_fill_(attn_mask, 0)
65
+ else:
66
+ attn_mask[attn_mask != float("-inf")] = 0
67
+ attn_mask[attn_mask == float("-inf")] = 1
68
+ attn_weight.masked_fill_(attn_mask, 0)
69
+
70
+ return attn_weight @ value
71
+
72
+
73
+ @torch.jit.script
74
+ class T2SMLP:
75
+ def __init__(self, w1, b1, w2, b2):
76
+ self.w1 = w1
77
+ self.b1 = b1
78
+ self.w2 = w2
79
+ self.b2 = b2
80
+
81
+ def forward(self, x):
82
+ x = F.relu(F.linear(x, self.w1, self.b1))
83
+ x = F.linear(x, self.w2, self.b2)
84
+ return x
85
+
86
+
87
+ @torch.jit.script
88
+ class T2SBlock:
89
+ def __init__(
90
+ self,
91
+ num_heads,
92
+ hidden_dim: int,
93
+ mlp: T2SMLP,
94
+ qkv_w,
95
+ qkv_b,
96
+ out_w,
97
+ out_b,
98
+ norm_w1,
99
+ norm_b1,
100
+ norm_eps1,
101
+ norm_w2,
102
+ norm_b2,
103
+ norm_eps2,
104
+ ):
105
+ self.num_heads = num_heads
106
+ self.mlp = mlp
107
+ self.hidden_dim: int = hidden_dim
108
+ self.qkv_w = qkv_w
109
+ self.qkv_b = qkv_b
110
+ self.out_w = out_w
111
+ self.out_b = out_b
112
+ self.norm_w1 = norm_w1
113
+ self.norm_b1 = norm_b1
114
+ self.norm_eps1 = norm_eps1
115
+ self.norm_w2 = norm_w2
116
+ self.norm_b2 = norm_b2
117
+ self.norm_eps2 = norm_eps2
118
+
119
+ self.false = torch.tensor(False, dtype=torch.bool)
120
+
121
+ @torch.jit.ignore
122
+ def to_mask(
123
+ self,
124
+ x: torch.Tensor,
125
+ padding_mask: Optional[torch.Tensor],
126
+ ):
127
+ if padding_mask is None:
128
+ return x
129
+
130
+ if padding_mask.dtype == torch.bool:
131
+ return x.masked_fill(padding_mask, 0)
132
+ else:
133
+ return x * padding_mask
134
+
135
+ def process_prompt(
136
+ self,
137
+ x: torch.Tensor,
138
+ attn_mask: torch.Tensor,
139
+ padding_mask: Optional[torch.Tensor] = None,
140
+ torch_sdpa: bool = True,
141
+ ):
142
+ q, k, v = F.linear(self.to_mask(x, padding_mask), self.qkv_w, self.qkv_b).chunk(3, dim=-1)
143
+
144
+ batch_size = q.shape[0]
145
+ q_len = q.shape[1]
146
+ kv_len = k.shape[1]
147
+
148
+ q = self.to_mask(q, padding_mask)
149
+ k_cache = self.to_mask(k, padding_mask)
150
+ v_cache = self.to_mask(v, padding_mask)
151
+
152
+ q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
153
+ k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
154
+ v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
155
+
156
+ if torch_sdpa:
157
+ attn = F.scaled_dot_product_attention(q, k, v, ~attn_mask)
158
+ else:
159
+ attn = scaled_dot_product_attention(q, k, v, attn_mask)
160
+
161
+ attn = attn.transpose(1, 2).reshape(batch_size, q_len, -1)
162
+ attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)
163
+
164
+ x = x + attn
165
+ x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
166
+ x = x + self.mlp.forward(x)
167
+ x = F.layer_norm(
168
+ x,
169
+ [self.hidden_dim],
170
+ self.norm_w2,
171
+ self.norm_b2,
172
+ self.norm_eps2,
173
+ )
174
+ return x, k_cache, v_cache
175
+
176
+ def decode_next_token(
177
+ self,
178
+ x: torch.Tensor,
179
+ k_cache: torch.Tensor,
180
+ v_cache: torch.Tensor,
181
+ attn_mask: torch.Tensor = None,
182
+ torch_sdpa: bool = True,
183
+ ):
184
+ q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
185
+
186
+ k_cache = torch.cat([k_cache, k], dim=1)
187
+ v_cache = torch.cat([v_cache, v], dim=1)
188
+
189
+ batch_size = q.shape[0]
190
+ q_len = q.shape[1]
191
+ kv_len = k_cache.shape[1]
192
+
193
+ q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
194
+ k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
195
+ v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
196
+
197
+ if torch_sdpa:
198
+ attn = F.scaled_dot_product_attention(q, k, v, (~attn_mask) if attn_mask is not None else None)
199
+ else:
200
+ attn = scaled_dot_product_attention(q, k, v, attn_mask)
201
+
202
+ attn = attn.transpose(1, 2).reshape(batch_size, q_len, -1)
203
+ attn = F.linear(attn, self.out_w, self.out_b)
204
+
205
+ x = x + attn
206
+ x = F.layer_norm(
207
+ x,
208
+ [self.hidden_dim],
209
+ self.norm_w1,
210
+ self.norm_b1,
211
+ self.norm_eps1,
212
+ )
213
+ x = x + self.mlp.forward(x)
214
+ x = F.layer_norm(
215
+ x,
216
+ [self.hidden_dim],
217
+ self.norm_w2,
218
+ self.norm_b2,
219
+ self.norm_eps2,
220
+ )
221
+ return x, k_cache, v_cache
222
+
223
+
224
+ @torch.jit.script
225
+ class T2STransformer:
226
+ def __init__(self, num_blocks: int, blocks: List[T2SBlock]):
227
+ self.num_blocks: int = num_blocks
228
+ self.blocks = blocks
229
+
230
+ def process_prompt(
231
+ self,
232
+ x: torch.Tensor,
233
+ attn_mask: torch.Tensor,
234
+ padding_mask: Optional[torch.Tensor] = None,
235
+ torch_sdpa: bool = True,
236
+ ):
237
+ k_cache: List[torch.Tensor] = []
238
+ v_cache: List[torch.Tensor] = []
239
+ for i in range(self.num_blocks):
240
+ x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask, padding_mask, torch_sdpa)
241
+ k_cache.append(k_cache_)
242
+ v_cache.append(v_cache_)
243
+ return x, k_cache, v_cache
244
+
245
+ def decode_next_token(
246
+ self,
247
+ x: torch.Tensor,
248
+ k_cache: List[torch.Tensor],
249
+ v_cache: List[torch.Tensor],
250
+ attn_mask: torch.Tensor = None,
251
+ torch_sdpa: bool = True,
252
+ ):
253
+ for i in range(self.num_blocks):
254
+ x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(
255
+ x, k_cache[i], v_cache[i], attn_mask, torch_sdpa
256
+ )
257
+ return x, k_cache, v_cache
258
+
259
+
260
+ class Text2SemanticDecoder(nn.Module):
261
+ def __init__(self, config, norm_first=False, top_k=3):
262
+ super(Text2SemanticDecoder, self).__init__()
263
+ self.model_dim = config["model"]["hidden_dim"]
264
+ self.embedding_dim = config["model"]["embedding_dim"]
265
+ self.num_head = config["model"]["head"]
266
+ self.num_layers = config["model"]["n_layer"]
267
+ self.norm_first = norm_first
268
+ self.vocab_size = config["model"]["vocab_size"]
269
+ self.phoneme_vocab_size = config["model"]["phoneme_vocab_size"]
270
+ self.p_dropout = config["model"]["dropout"]
271
+ self.EOS = config["model"]["EOS"]
272
+ self.norm_first = norm_first
273
+ assert self.EOS == self.vocab_size - 1
274
+ # should be same as num of kmeans bin
275
+ # assert self.EOS == 1024
276
+ self.bert_proj = nn.Linear(1024, self.embedding_dim)
277
+ self.ar_text_embedding = TokenEmbedding(
278
+ self.embedding_dim,
279
+ self.phoneme_vocab_size,
280
+ self.p_dropout,
281
+ )
282
+ self.ar_text_position = SinePositionalEmbedding(
283
+ self.embedding_dim,
284
+ dropout=0.1,
285
+ scale=False,
286
+ alpha=True,
287
+ )
288
+ self.ar_audio_embedding = TokenEmbedding(
289
+ self.embedding_dim,
290
+ self.vocab_size,
291
+ self.p_dropout,
292
+ )
293
+ self.ar_audio_position = SinePositionalEmbedding(
294
+ self.embedding_dim,
295
+ dropout=0.1,
296
+ scale=False,
297
+ alpha=True,
298
+ )
299
+
300
+ self.h = TransformerEncoder(
301
+ TransformerEncoderLayer(
302
+ d_model=self.model_dim,
303
+ nhead=self.num_head,
304
+ dim_feedforward=self.model_dim * 4,
305
+ dropout=0.1,
306
+ batch_first=True,
307
+ norm_first=norm_first,
308
+ ),
309
+ num_layers=self.num_layers,
310
+ norm=LayerNorm(self.model_dim) if norm_first else None,
311
+ )
312
+
313
+ self.ar_predict_layer = nn.Linear(self.model_dim, self.vocab_size, bias=False)
314
+ self.loss_fct = nn.CrossEntropyLoss(reduction="sum")
315
+
316
+ self.ar_accuracy_metric = MulticlassAccuracy(
317
+ self.vocab_size,
318
+ top_k=top_k,
319
+ average="micro",
320
+ multidim_average="global",
321
+ ignore_index=self.EOS,
322
+ )
323
+
324
+ blocks = []
325
+
326
+ for i in range(self.num_layers):
327
+ layer = self.h.layers[i]
328
+ t2smlp = T2SMLP(
329
+ layer.linear1.weight,
330
+ layer.linear1.bias,
331
+ layer.linear2.weight,
332
+ layer.linear2.bias,
333
+ )
334
+
335
+ block = T2SBlock(
336
+ self.num_head,
337
+ self.model_dim,
338
+ t2smlp,
339
+ layer.self_attn.in_proj_weight,
340
+ layer.self_attn.in_proj_bias,
341
+ layer.self_attn.out_proj.weight,
342
+ layer.self_attn.out_proj.bias,
343
+ layer.norm1.weight,
344
+ layer.norm1.bias,
345
+ layer.norm1.eps,
346
+ layer.norm2.weight,
347
+ layer.norm2.bias,
348
+ layer.norm2.eps,
349
+ )
350
+
351
+ blocks.append(block)
352
+
353
+ self.t2s_transformer = T2STransformer(self.num_layers, blocks)
354
+
355
+ def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
356
+ x = self.ar_text_embedding(x)
357
+ x = x + self.bert_proj(bert_feature.transpose(1, 2))
358
+ x = self.ar_text_position(x)
359
+ x_mask = make_pad_mask(x_lens)
360
+
361
+ y_mask = make_pad_mask(y_lens)
362
+ y_mask_int = y_mask.type(torch.int64)
363
+ codes = y.type(torch.int64) * (1 - y_mask_int)
364
+
365
+ # Training
366
+ # AR Decoder
367
+ y, targets = self.pad_y_eos(codes, y_mask_int, eos_id=self.EOS)
368
+ x_len = x_lens.max()
369
+ y_len = y_lens.max()
370
+ y_emb = self.ar_audio_embedding(y)
371
+ y_pos = self.ar_audio_position(y_emb)
372
+
373
+ xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
374
+
375
+ ar_xy_padding_mask = xy_padding_mask
376
+
377
+ x_attn_mask = F.pad(
378
+ torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
379
+ (0, y_len),
380
+ value=True,
381
+ )
382
+ # x_attn_mask[:, x_len]=False
383
+ y_attn_mask = F.pad(
384
+ torch.triu(
385
+ torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
386
+ diagonal=1,
387
+ ),
388
+ (x_len, 0),
389
+ value=False,
390
+ )
391
+
392
+ xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
393
+ bsz, src_len = x.shape[0], x_len + y_len
394
+ _xy_padding_mask = (
395
+ ar_xy_padding_mask.view(bsz, 1, 1, src_len)
396
+ .expand(-1, self.num_head, -1, -1)
397
+ .reshape(bsz * self.num_head, 1, src_len)
398
+ )
399
+ xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
400
+ new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
401
+ new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
402
+ xy_attn_mask = new_attn_mask
403
+ # feed x and the complete y into the model in a single pass
404
+ xy_pos = torch.concat([x, y_pos], dim=1)
405
+
406
+ return xy_pos, xy_attn_mask, targets
407
+
408
+ def forward(self, x, x_lens, y, y_lens, bert_feature):
409
+ """
410
+ x: phoneme_ids
411
+ y: semantic_ids
412
+ """
413
+
414
+ reject_y, reject_y_lens = make_reject_y(y, y_lens)
415
+
416
+ xy_pos, xy_attn_mask, targets = self.make_input_data(x, x_lens, y, y_lens, bert_feature)
417
+
418
+ xy_dec, _ = self.h(
419
+ (xy_pos, None),
420
+ mask=xy_attn_mask,
421
+ )
422
+ x_len = x_lens.max()
423
+ logits = self.ar_predict_layer(xy_dec[:, x_len:])
424
+
425
+ ###### DPO #############
426
+ reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(
427
+ x, x_lens, reject_y, reject_y_lens, bert_feature
428
+ )
429
+
430
+ reject_xy_dec, _ = self.h(
431
+ (reject_xy_pos, None),
432
+ mask=reject_xy_attn_mask,
433
+ )
434
+ x_len = x_lens.max()
435
+ reject_logits = self.ar_predict_layer(reject_xy_dec[:, x_len:])
436
+
437
+ # loss
438
+ # from feiteng: the longer the duration, the larger the gradient update should be, hence reduction="sum"
439
+
440
+ loss_1 = F.cross_entropy(logits.permute(0, 2, 1), targets, reduction="sum")
441
+ acc = self.ar_accuracy_metric(logits.permute(0, 2, 1).detach(), targets).item()
442
+
443
+ A_logits, R_logits = get_batch_logps(logits, reject_logits, targets, reject_targets)
444
+ loss_2, _, _ = dpo_loss(A_logits, R_logits, 0, 0, 0.2, reference_free=True)
445
+
446
+ loss = loss_1 + loss_2
447
+
448
+ return loss, acc
449
+
450
+ def forward_old(self, x, x_lens, y, y_lens, bert_feature):
451
+ """
452
+ x: phoneme_ids
453
+ y: semantic_ids
454
+ """
455
+ x = self.ar_text_embedding(x)
456
+ x = x + self.bert_proj(bert_feature.transpose(1, 2))
457
+ x = self.ar_text_position(x)
458
+ x_mask = make_pad_mask(x_lens)
459
+
460
+ y_mask = make_pad_mask(y_lens)
461
+ y_mask_int = y_mask.type(torch.int64)
462
+ codes = y.type(torch.int64) * (1 - y_mask_int)
463
+
464
+ # Training
465
+ # AR Decoder
466
+ y, targets = self.pad_y_eos(codes, y_mask_int, eos_id=self.EOS)
467
+ x_len = x_lens.max()
468
+ y_len = y_lens.max()
469
+ y_emb = self.ar_audio_embedding(y)
470
+ y_pos = self.ar_audio_position(y_emb)
471
+
472
+ xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
473
+ ar_xy_padding_mask = xy_padding_mask
474
+
475
+ x_attn_mask = F.pad(
476
+ torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
477
+ (0, y_len),
478
+ value=True,
479
+ )
480
+ y_attn_mask = F.pad(
481
+ torch.triu(
482
+ torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
483
+ diagonal=1,
484
+ ),
485
+ (x_len, 0),
486
+ value=False,
487
+ )
488
+ xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
489
+ bsz, src_len = x.shape[0], x_len + y_len
490
+ _xy_padding_mask = (
491
+ ar_xy_padding_mask.view(bsz, 1, 1, src_len)
492
+ .expand(-1, self.num_head, -1, -1)
493
+ .reshape(bsz * self.num_head, 1, src_len)
494
+ )
495
+ xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
496
+ new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
497
+ new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
498
+ xy_attn_mask = new_attn_mask
499
+ # feed x and the complete y into the model in a single pass
500
+ xy_pos = torch.concat([x, y_pos], dim=1)
501
+ xy_dec, _ = self.h(
502
+ (xy_pos, None),
503
+ mask=xy_attn_mask,
504
+ )
505
+ logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1)
506
+ # loss
507
+ # from feiteng: the longer the duration, the larger the gradient update should be, hence reduction="sum"
508
+ loss = F.cross_entropy(logits, targets, reduction="sum")
509
+ acc = self.ar_accuracy_metric(logits.detach(), targets).item()
510
+ return loss, acc
511
+
512
+ # TODO: check how this function differs from forward, and what should be fed as prompts when there is no semantic input
513
+ def infer(
514
+ self,
515
+ x,
516
+ x_lens,
517
+ prompts,
518
+ bert_feature,
519
+ top_k: int = -100,
520
+ early_stop_num: int = -1,
521
+ temperature: float = 1.0,
522
+ ):
523
+ x = self.ar_text_embedding(x)
524
+ x = x + self.bert_proj(bert_feature.transpose(1, 2))
525
+ x = self.ar_text_position(x)
526
+
527
+ # AR Decoder
528
+ y = prompts
529
+ prefix_len = y.shape[1]
530
+ x_len = x.shape[1]
531
+ x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
532
+ stop = False
533
+ for _ in tqdm(range(1500)):
534
+ y_emb = self.ar_audio_embedding(y)
535
+ y_pos = self.ar_audio_position(y_emb)
536
+ # feed x together with the progressively growing y into the model
537
+ xy_pos = torch.concat([x, y_pos], dim=1)
538
+ y_len = y.shape[1]
539
+ x_attn_mask_pad = F.pad(
540
+ x_attn_mask,
541
+ (0, y_len),
542
+ value=True,
543
+ )
544
+ y_attn_mask = F.pad(
545
+ torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
546
+ (x_len, 0),
547
+ value=False,
548
+ )
549
+ xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(y.device)
550
+
551
+ xy_dec, _ = self.h(
552
+ (xy_pos, None),
553
+ mask=xy_attn_mask,
554
+ )
555
+ logits = self.ar_predict_layer(xy_dec[:, -1])
556
+ samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)
557
+
558
+ if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
559
+ print("use early stop num:", early_stop_num)
560
+ stop = True
561
+
562
+ if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
563
+ # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
564
+ stop = True
565
+ if stop:
566
+ if prompts.shape[1] == y.shape[1]:
567
+ y = torch.concat([y, torch.zeros_like(samples)], dim=1)
568
+ print("bad zero prediction")
569
+ print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
570
+ break
571
+ # the semantic_ids generated in this step are appended to the previous y to form the new y
572
+ # print(samples.shape)#[1,1]#第一个1是bs
573
+ # import os
574
+ # os._exit(2333)
575
+ y = torch.concat([y, samples], dim=1)
576
+ return y
577
+
578
+ def pad_y_eos(self, y, y_mask_int, eos_id):
579
+ targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(y_mask_int, (0, 1), value=1)
580
+ # shift by one position (teacher-forcing offset)
581
+ return targets[:, :-1], targets[:, 1:]
582
+
583
+ def infer_panel_batch_infer(
584
+ self,
585
+ x: List[torch.LongTensor],  # all text tokens
586
+ x_lens: torch.LongTensor,
587
+ prompts: torch.LongTensor,  # reference audio tokens
588
+ bert_feature: List[torch.LongTensor],
589
+ top_k: int = -100,
590
+ top_p: int = 100,
591
+ early_stop_num: int = -1,
592
+ temperature: float = 1.0,
593
+ repetition_penalty: float = 1.35,
594
+ **kwargs,
595
+ ):
596
+ if prompts is None:
597
+ print("Warning: Prompt free is not supported batch_infer! switch to naive_infer")
598
+ return self.infer_panel_naive_batched(
599
+ x,
600
+ x_lens,
601
+ prompts,
602
+ bert_feature,
603
+ top_k=top_k,
604
+ top_p=top_p,
605
+ early_stop_num=early_stop_num,
606
+ temperature=temperature,
607
+ **kwargs,
608
+ )
609
+
610
+ max_len = kwargs.get("max_len", x_lens.max())
611
+ x_list = []
612
+ for x_item, bert_item in zip(x, bert_feature):
613
+ # max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
614
+ x_item = self.ar_text_embedding(x_item.unsqueeze(0))
615
+ x_item = x_item + self.bert_proj(bert_item.transpose(0, 1).unsqueeze(0))
616
+ x_item = self.ar_text_position(x_item).squeeze(0)
617
+ # x_item = F.pad(x_item,(0,0,0,max_len-x_item.shape[0]),value=0) if x_item.shape[0]<max_len else x_item ### padding right
618
+ x_item = (
619
+ F.pad(x_item, (0, 0, max_len - x_item.shape[0], 0), value=0) if x_item.shape[0] < max_len else x_item
620
+ ) ### padding left
621
+ x_list.append(x_item)
622
+ x: torch.Tensor = torch.stack(x_list, dim=0)
623
+
624
+ # AR Decoder
625
+ y = prompts
626
+
627
+ x_len = x.shape[1]
628
+ stop = False
629
+
630
+ k_cache = None
631
+ v_cache = None
632
+ ################### first step ##########################
633
+ assert y is not None, "Error: Prompt free is not supported batch_infer!"
634
+ ref_free = False
635
+
636
+ y_emb = self.ar_audio_embedding(y)
637
+ y_len = y_emb.shape[1]
638
+ prefix_len = y.shape[1]
639
+ y_lens = torch.LongTensor([y_emb.shape[1]] * y_emb.shape[0]).to(x.device)
640
+ y_pos = self.ar_audio_position(y_emb)
641
+ xy_pos = torch.concat([x, y_pos], dim=1)
642
+
643
+ ##### create mask #####
644
+ bsz = x.shape[0]
645
+ src_len = x_len + y_len
646
+ y_paddind_mask = make_pad_mask_left(y_lens, y_len)
647
+ x_paddind_mask = make_pad_mask_left(x_lens, max_len)
648
+
649
+ # (bsz, x_len + y_len)
650
+ padding_mask = torch.concat([x_paddind_mask, y_paddind_mask], dim=1)
651
+
652
+ x_mask = F.pad(
653
+ torch.zeros(x_len, x_len, dtype=torch.bool, device=x.device),
654
+ (0, y_len),
655
+ value=True,
656
+ )
657
+
658
+ y_mask = F.pad(  # the upper-right 1s of the y-y block are extended to the left with 0s over x, giving shape (y, x+y)
659
+ torch.triu(torch.ones(y_len, y_len, dtype=torch.bool, device=x.device), diagonal=1),
660
+ (x_len, 0),
661
+ value=False,
662
+ )
663
+
664
+ causal_mask = torch.concat([x_mask, y_mask], dim=0).view(1, src_len, src_len).repeat(bsz, 1, 1).to(x.device)
665
+ # padding_mask = padding_mask.unsqueeze(1) * padding_mask.unsqueeze(2) ### [b, x+y, x+y]
666
+ ### the line above is wrong: it would let padded tokens be "seen" by attention
667
+
668
+ # the correct padding_mask looks like this:
669
+ # | pad_len | x_len | y_len |
670
+ # [[PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
671
+ # [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
672
+ # [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],  (the first 3 rows should in principle also be masked, but they are kept to avoid NaN in the attention computation; this does not affect the result)
673
+ # [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
674
+ # [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
675
+ # [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
676
+ # [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
677
+ # [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
678
+ # [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6]]
679
+
680
+ padding_mask = padding_mask.view(bsz, 1, src_len).repeat(1, src_len, 1)
681
+
682
+ attn_mask: torch.Tensor = causal_mask.logical_or(padding_mask)
683
+ attn_mask = attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1).bool()
684
+
685
+ # the correct attn_mask looks like this:
686
+ # | pad_len | x_len | y_len |
687
+ # [[PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
688
+ # [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
689
+ # [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],  (the first 3 rows should in principle also be masked, but they are kept to avoid NaN in the attention computation; this does not affect the result)
690
+ # [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
691
+ # [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
692
+ # [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
693
+ # [PAD, PAD, PAD, 1, 2, 3, 4, EOS, EOS],
694
+ # [PAD, PAD, PAD, 1, 2, 3, 4, 5, EOS],
695
+ # [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6]]
696
+
697
+ ###### decode #####
698
+ y_list = [None] * y.shape[0]
699
+ batch_idx_map = list(range(y.shape[0]))
700
+ idx_list = [None] * y.shape[0]
701
+ for idx in tqdm(range(1500)):
702
+ if idx == 0:
703
+ xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, attn_mask, None)
704
+ else:
705
+ xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, attn_mask)
706
+ logits = self.ar_predict_layer(xy_dec[:, -1])
707
+
708
+ if idx == 0:
709
+ attn_mask = F.pad(attn_mask[:, :, -1].unsqueeze(-2), (0, 1), value=False)
710
+ logits = logits[:, :-1]
711
+ else:
712
+ attn_mask = F.pad(attn_mask, (0, 1), value=False)
713
+
714
+ samples = sample(
715
+ logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
716
+ )[0]
717
+
718
+ y = torch.concat([y, samples], dim=1)
719
+
720
+ ####### remove sequences that have finished generating from the batch, to further reduce computation
721
+ tokens = torch.argmax(logits, dim=-1)
722
+ reserved_idx_of_batch_for_y = None
723
+ if (self.EOS in samples[:, 0]) or (self.EOS in tokens):  # stop once an EOS token is generated
724
+ l1 = samples[:, 0] == self.EOS
725
+ l2 = tokens == self.EOS
726
+ l = l1.logical_or(l2)
727
+ removed_idx_of_batch_for_y = torch.where(l == True)[0].tolist()
728
+ reserved_idx_of_batch_for_y = torch.where(l == False)[0]
729
+ # batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
730
+ for i in removed_idx_of_batch_for_y:
731
+ batch_index = batch_idx_map[i]
732
+ idx_list[batch_index] = idx
733
+ y_list[batch_index] = y[i, :-1]
734
+
735
+ batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()]
736
+
737
+ # keep only the sequences in the batch that are not yet finished
738
+ if reserved_idx_of_batch_for_y is not None:
739
+ # index = torch.LongTensor(batch_idx_map).to(y.device)
740
+ y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
741
+ attn_mask = torch.index_select(attn_mask, dim=0, index=reserved_idx_of_batch_for_y)
742
+ if k_cache is not None:
743
+ for i in range(len(k_cache)):
744
+ k_cache[i] = torch.index_select(k_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
745
+ v_cache[i] = torch.index_select(v_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
746
+
747
+ if (early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num) or idx == 1499:
748
+ print("use early stop num:", early_stop_num)
749
+ stop = True
750
+ for i, batch_index in enumerate(batch_idx_map):
751
+ batch_index = batch_idx_map[i]
752
+ idx_list[batch_index] = idx
753
+ y_list[batch_index] = y[i, :-1]
754
+
755
+ if None not in idx_list:
756
+ stop = True
757
+
758
+ if stop:
759
+ if y.shape[1] == 0:
760
+ y = torch.concat([y, torch.zeros_like(samples)], dim=1)
761
+ print("bad zero prediction")
762
+ print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
763
+ break
764
+
765
+ ####################### update next step ###################################
766
+ y_emb = self.ar_audio_embedding(y[:, -1:])
767
+ xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
768
+ :, y_len + idx
769
+ ].to(dtype=y_emb.dtype, device=y_emb.device)
770
+
771
+ if None in idx_list:
772
+ for i in range(x.shape[0]):
773
+ if idx_list[i] is None:
774
+ idx_list[i] = 1500 - 1 ### if EOS was never generated, fall back to the maximum length
775
+
776
+ if ref_free:
777
+ return y_list, [0] * x.shape[0]
778
+ # print(idx_list)
779
+ return y_list, idx_list
780
+
781
+ def infer_panel_naive_batched(
782
+ self,
783
+ x: List[torch.LongTensor], ##### all text tokens
784
+ x_lens: torch.LongTensor,
785
+ prompts: torch.LongTensor, #### reference audio tokens
786
+ bert_feature: List[torch.LongTensor],
787
+ top_k: int = -100,
788
+ top_p: int = 100,
789
+ early_stop_num: int = -1,
790
+ temperature: float = 1.0,
791
+ repetition_penalty: float = 1.35,
792
+ **kwargs,
793
+ ):
794
+ y_list = []
795
+ idx_list = []
796
+ for i in range(len(x)):
797
+ y, idx = self.infer_panel_naive(
798
+ x[i].unsqueeze(0),
799
+ x_lens[i],
800
+ prompts[i].unsqueeze(0) if prompts is not None else None,
801
+ bert_feature[i].unsqueeze(0),
802
+ top_k,
803
+ top_p,
804
+ early_stop_num,
805
+ temperature,
806
+ repetition_penalty,
807
+ **kwargs,
808
+ )
809
+ y_list.append(y[0])
810
+ idx_list.append(idx)
811
+
812
+ return y_list, idx_list
813
+
814
+ def infer_panel_naive(
815
+ self,
816
+ x: torch.LongTensor, ##### all text tokens
817
+ x_lens: torch.LongTensor,
818
+ prompts: torch.LongTensor, #### reference audio tokens
819
+ bert_feature: torch.LongTensor,
820
+ top_k: int = -100,
821
+ top_p: int = 100,
822
+ early_stop_num: int = -1,
823
+ temperature: float = 1.0,
824
+ repetition_penalty: float = 1.35,
825
+ **kwargs,
826
+ ):
827
+ x = self.ar_text_embedding(x)
828
+ x = x + self.bert_proj(bert_feature.transpose(1, 2))
829
+ x = self.ar_text_position(x)
830
+
831
+ # AR Decoder
832
+ y = prompts
833
+
834
+ x_len = x.shape[1]
835
+ x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
836
+ stop = False
837
+ # print(1111111,self.num_layers)
838
+
839
+ k_cache = None
840
+ v_cache = None
841
+ ################### first step ##########################
842
+ if y is not None:
843
+ y_emb = self.ar_audio_embedding(y)
844
+ y_len = y_emb.shape[1]
845
+ prefix_len = y.shape[1]
846
+ y_pos = self.ar_audio_position(y_emb)
847
+ xy_pos = torch.concat([x, y_pos], dim=1)
848
+ ref_free = False
849
+ else:
850
+ y_emb = None
851
+ y_len = 0
852
+ prefix_len = 0
853
+ y_pos = None
854
+ xy_pos = x
855
+ y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device)
856
+ ref_free = True
857
+
858
+ bsz = x.shape[0]
859
+ src_len = x_len + y_len
860
+ x_attn_mask_pad = F.pad(
861
+ x_attn_mask,
862
+ (0, y_len), ### extend the all-zero x-x block with all-one x-y columns, shape (x, x+y)
863
+ value=True,
864
+ )
865
+ y_attn_mask = F.pad( ### extend the upper-right 1s of the y-y block leftward over the x-y zeros, shape (y, x+y)
866
+ torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
867
+ (x_len, 0),
868
+ value=False,
869
+ )
870
+ xy_attn_mask = (
871
+ torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
872
+ .unsqueeze(0)
873
+ .expand(bsz * self.num_head, -1, -1)
874
+ .view(bsz, self.num_head, src_len, src_len)
875
+ .to(device=x.device, dtype=torch.bool)
876
+ )
877
+
878
+ for idx in tqdm(range(1500)):
879
+ if xy_attn_mask is not None:
880
+ xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None)
881
+ else:
882
+ xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
883
+
884
+ logits = self.ar_predict_layer(xy_dec[:, -1])
885
+
886
+ if idx == 0:
887
+ xy_attn_mask = None
888
+ if idx < 11: ### require at least 10 predicted tokens before allowing a stop (about 0.4 s)
889
+ logits = logits[:, :-1]
890
+
891
+ samples = sample(
892
+ logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
893
+ )[0]
894
+
895
+ y = torch.concat([y, samples], dim=1)
896
+
897
+ if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
898
+ print("use early stop num:", early_stop_num)
899
+ stop = True
900
+
901
+ if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
902
+ stop = True
903
+ if stop:
904
+ if y.shape[1] == 0:
905
+ y = torch.concat([y, torch.zeros_like(samples)], dim=1)
906
+ print("bad zero prediction")
907
+ print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
908
+ break
909
+
910
+ ####################### update next step ###################################
911
+ y_emb = self.ar_audio_embedding(y[:, -1:])
912
+ xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
913
+ :, y_len + idx
914
+ ].to(dtype=y_emb.dtype, device=y_emb.device)
915
+
916
+ if ref_free:
917
+ return y[:, :-1], 0
918
+ return y[:, :-1], idx
919
+
920
+ def infer_panel(
921
+ self,
922
+ x: torch.LongTensor, ##### all text tokens
923
+ x_lens: torch.LongTensor,
924
+ prompts: torch.LongTensor, #### reference audio tokens
925
+ bert_feature: torch.LongTensor,
926
+ top_k: int = -100,
927
+ top_p: int = 100,
928
+ early_stop_num: int = -1,
929
+ temperature: float = 1.0,
930
+ repetition_penalty: float = 1.35,
931
+ **kwargs,
932
+ ):
933
+ return self.infer_panel_naive(
934
+ x, x_lens, prompts, bert_feature, top_k, top_p, early_stop_num, temperature, repetition_penalty, **kwargs
935
+ )
AR/models/t2s_model_onnx.py ADDED
@@ -0,0 +1,394 @@
1
+ # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
2
+ # reference: https://github.com/lifeiteng/vall-e
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+ from torchmetrics.classification import MulticlassAccuracy
7
+
8
+ from AR.modules.embedding_onnx import SinePositionalEmbedding, TokenEmbedding
9
+ from AR.modules.transformer_onnx import LayerNorm, TransformerEncoder, TransformerEncoderLayer
10
+
11
+ default_config = {
12
+ "embedding_dim": 512,
13
+ "hidden_dim": 512,
14
+ "num_head": 8,
15
+ "num_layers": 12,
16
+ "num_codebook": 8,
17
+ "p_dropout": 0.0,
18
+ "vocab_size": 1024 + 1,
19
+ "phoneme_vocab_size": 512,
20
+ "EOS": 1024,
21
+ }
22
+
23
+ inf_tensor_value = torch.FloatTensor([-float("Inf")]).float()
24
+
25
+
26
+ def logits_to_probs(
27
+ logits,
28
+ previous_tokens=None,
29
+ temperature: float = 1.0,
30
+ top_k=None,
31
+ top_p=None,
32
+ repetition_penalty: float = 1.0,
33
+ ):
34
+ previous_tokens = previous_tokens.squeeze()
35
+ if previous_tokens is not None and repetition_penalty != 1.0:
36
+ previous_tokens = previous_tokens.long()
37
+ score = torch.gather(logits, dim=0, index=previous_tokens)
38
+ score = torch.where(
39
+ score < 0,
40
+ score * repetition_penalty,
41
+ score / repetition_penalty,
42
+ )
43
+ logits.scatter_(dim=0, index=previous_tokens, src=score)
44
+
45
+ if top_p is not None and top_p < 1.0:
46
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
47
+ cum_probs = torch.cumsum(
48
+ torch.nn.functional.softmax(
49
+ sorted_logits,
50
+ dim=-1,
51
+ ),
52
+ dim=-1,
53
+ )
54
+ sorted_indices_to_remove = cum_probs > top_p
55
+ sorted_indices_to_remove[0] = False # keep at least one option
56
+ indices_to_remove = sorted_indices_to_remove.scatter(
57
+ dim=0,
58
+ index=sorted_indices,
59
+ src=sorted_indices_to_remove,
60
+ )
61
+ logits = logits.masked_fill(indices_to_remove, -float("Inf"))
62
+
63
+ logits = logits / max(temperature, 1e-5)
64
+
65
+ if top_k is not None:
66
+ v, _ = torch.topk(logits, top_k)
67
+ pivot = v.select(-1, -1).unsqueeze(-1)
68
+ logits = torch.where(logits < pivot, inf_tensor_value, logits)
69
+
70
+ probs = torch.nn.functional.softmax(logits, dim=-1)
71
+ return probs
72
+
73
+
74
+ def multinomial_sample_one_no_sync(
75
+ probs_sort,
76
+ ): # Does multinomial sampling without a cuda synchronization
77
+ q = torch.randn_like(probs_sort)
78
+ return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
79
+
80
+
81
+ def sample(
82
+ logits,
83
+ previous_tokens,
84
+ **sampling_kwargs,
85
+ ):
86
+ probs = logits_to_probs(
87
+ logits=logits,
88
+ previous_tokens=previous_tokens,
89
+ **sampling_kwargs,
90
+ )
91
+ idx_next = multinomial_sample_one_no_sync(probs)
92
+ return idx_next, probs
93
+
94
+
95
+ class OnnxEncoder(nn.Module):
96
+ def __init__(self, ar_text_embedding, bert_proj, ar_text_position):
97
+ super().__init__()
98
+ self.ar_text_embedding = ar_text_embedding
99
+ self.bert_proj = bert_proj
100
+ self.ar_text_position = ar_text_position
101
+
102
+ def forward(self, x, bert_feature):
103
+ x = self.ar_text_embedding(x)
104
+ x = x + self.bert_proj(bert_feature.transpose(1, 2))
105
+ return self.ar_text_position(x)
106
+
107
+
108
+ class T2SFirstStageDecoder(nn.Module):
109
+ def __init__(
110
+ self,
111
+ ar_audio_embedding,
112
+ ar_audio_position,
113
+ h,
114
+ ar_predict_layer,
115
+ loss_fct,
116
+ ar_accuracy_metric,
117
+ top_k,
118
+ early_stop_num,
119
+ num_layers,
120
+ ):
121
+ super().__init__()
122
+ self.ar_audio_embedding = ar_audio_embedding
123
+ self.ar_audio_position = ar_audio_position
124
+ self.h = h
125
+ self.ar_predict_layer = ar_predict_layer
126
+ self.loss_fct = loss_fct
127
+ self.ar_accuracy_metric = ar_accuracy_metric
128
+ self.top_k = top_k
129
+ self.early_stop_num = early_stop_num
130
+ self.num_layers = num_layers
131
+
132
+ def forward(self, x, prompt):
133
+ y = prompt
134
+ x_example = x[:, :, 0] * 0.0
135
+ # N, 1, 512
136
+ cache = {
137
+ "all_stage": self.num_layers,
138
+ "k": None,
139
+ "v": None,
140
+ "y_emb": None,
141
+ "first_infer": 1,
142
+ "stage": 0,
143
+ }
144
+
145
+ y_emb = self.ar_audio_embedding(y)
146
+
147
+ cache["y_emb"] = y_emb
148
+ y_pos = self.ar_audio_position(y_emb)
149
+
150
+ xy_pos = torch.concat([x, y_pos], dim=1)
151
+
152
+ y_example = y_pos[:, :, 0] * 0.0
153
+ x_attn_mask = torch.matmul(x_example.transpose(0, 1), x_example).bool()
154
+ y_attn_mask = torch.ones_like(torch.matmul(y_example.transpose(0, 1), y_example), dtype=torch.int64)
155
+ y_attn_mask = torch.cumsum(y_attn_mask, dim=1) - torch.cumsum(
156
+ torch.ones_like(
157
+ y_example.transpose(0, 1),
158
+ dtype=torch.int64,
159
+ ),
160
+ dim=0,
161
+ )
162
+ y_attn_mask = y_attn_mask > 0
163
+
164
+ x_y_pad = torch.matmul(x_example.transpose(0, 1), y_example).bool()
165
+ y_x_pad = torch.matmul(y_example.transpose(0, 1), x_example).bool()
166
+ x_attn_mask_pad = torch.cat([x_attn_mask, torch.ones_like(x_y_pad)], dim=1)
167
+ y_attn_mask = torch.cat([y_x_pad, y_attn_mask], dim=1)
168
+ xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
169
+ cache["k"] = (
170
+ torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))
171
+ .unsqueeze(1)
172
+ .repeat(self.num_layers, 1, 1, 1)
173
+ )
174
+ cache["v"] = (
175
+ torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))
176
+ .unsqueeze(1)
177
+ .repeat(self.num_layers, 1, 1, 1)
178
+ )
179
+
180
+ xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
181
+ logits = self.ar_predict_layer(xy_dec[:, -1])
182
+ samples = sample(logits[0], y, top_k=self.top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
183
+
184
+ y = torch.concat([y, samples], dim=1)
185
+
186
+ return y, cache["k"], cache["v"], cache["y_emb"], x_example
187
+
188
+
189
+ class T2SStageDecoder(nn.Module):
190
+ def __init__(
191
+ self,
192
+ ar_audio_embedding,
193
+ ar_audio_position,
194
+ h,
195
+ ar_predict_layer,
196
+ loss_fct,
197
+ ar_accuracy_metric,
198
+ top_k,
199
+ early_stop_num,
200
+ num_layers,
201
+ ):
202
+ super().__init__()
203
+ self.ar_audio_embedding = ar_audio_embedding
204
+ self.ar_audio_position = ar_audio_position
205
+ self.h = h
206
+ self.ar_predict_layer = ar_predict_layer
207
+ self.loss_fct = loss_fct
208
+ self.ar_accuracy_metric = ar_accuracy_metric
209
+ self.top_k = top_k
210
+ self.early_stop_num = early_stop_num
211
+ self.num_layers = num_layers
212
+
213
+ def forward(self, y, k, v, y_emb, x_example):
214
+ cache = {
215
+ "all_stage": self.num_layers,
216
+ "k": torch.nn.functional.pad(k, (0, 0, 0, 0, 0, 1)),
217
+ "v": torch.nn.functional.pad(v, (0, 0, 0, 0, 0, 1)),
218
+ "y_emb": y_emb,
219
+ "first_infer": 0,
220
+ "stage": 0,
221
+ }
222
+
223
+ y_emb = torch.cat(
224
+ [
225
+ cache["y_emb"],
226
+ self.ar_audio_embedding(y[:, -1:]),
227
+ ],
228
+ 1,
229
+ )
230
+ cache["y_emb"] = y_emb
231
+ y_pos = self.ar_audio_position(y_emb)
232
+
233
+ xy_pos = y_pos[:, -1:]
234
+
235
+ y_example = y_pos[:, :, 0] * 0.0
236
+
237
+ xy_attn_mask = torch.cat([x_example, y_example], dim=1)
238
+ xy_attn_mask = torch.zeros_like(xy_attn_mask, dtype=torch.bool)
239
+
240
+ xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
241
+ logits = self.ar_predict_layer(xy_dec[:, -1])
242
+ samples = sample(logits[0], y, top_k=self.top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
243
+
244
+ y = torch.concat([y, samples], dim=1)
245
+
246
+ return y, cache["k"], cache["v"], cache["y_emb"], logits, samples
247
+
248
+
249
+ class Text2SemanticDecoder(nn.Module):
250
+ def __init__(self, config, norm_first=False, top_k=3):
251
+ super(Text2SemanticDecoder, self).__init__()
252
+ self.model_dim = config["model"]["hidden_dim"]
253
+ self.embedding_dim = config["model"]["embedding_dim"]
254
+ self.num_head = config["model"]["head"]
255
+ self.num_layers = config["model"]["n_layer"]
256
+ self.norm_first = norm_first
257
+ self.vocab_size = config["model"]["vocab_size"]
258
+ self.phoneme_vocab_size = config["model"]["phoneme_vocab_size"]
259
+ self.p_dropout = float(config["model"]["dropout"])
260
+ self.EOS = config["model"]["EOS"]
261
+ self.norm_first = norm_first
262
+ assert self.EOS == self.vocab_size - 1
263
+ self.bert_proj = nn.Linear(1024, self.embedding_dim)
264
+ self.ar_text_embedding = TokenEmbedding(self.embedding_dim, self.phoneme_vocab_size, self.p_dropout)
265
+ self.ar_text_position = SinePositionalEmbedding(self.embedding_dim, dropout=0.1, scale=False, alpha=True)
266
+ self.ar_audio_embedding = TokenEmbedding(self.embedding_dim, self.vocab_size, self.p_dropout)
267
+ self.ar_audio_position = SinePositionalEmbedding(self.embedding_dim, dropout=0.1, scale=False, alpha=True)
268
+ self.h = TransformerEncoder(
269
+ TransformerEncoderLayer(
270
+ d_model=self.model_dim,
271
+ nhead=self.num_head,
272
+ dim_feedforward=self.model_dim * 4,
273
+ dropout=0.1,
274
+ batch_first=True,
275
+ norm_first=norm_first,
276
+ ),
277
+ num_layers=self.num_layers,
278
+ norm=LayerNorm(self.model_dim) if norm_first else None,
279
+ )
280
+ self.ar_predict_layer = nn.Linear(self.model_dim, self.vocab_size, bias=False)
281
+ self.loss_fct = nn.CrossEntropyLoss(reduction="sum")
282
+ self.ar_accuracy_metric = MulticlassAccuracy(
283
+ self.vocab_size,
284
+ top_k=top_k,
285
+ average="micro",
286
+ multidim_average="global",
287
+ ignore_index=self.EOS,
288
+ )
289
+ self.top_k = torch.LongTensor([1])
290
+ self.early_stop_num = torch.LongTensor([-1])
291
+
292
+ def init_onnx(self):
293
+ self.onnx_encoder = OnnxEncoder(self.ar_text_embedding, self.bert_proj, self.ar_text_position)
294
+ self.first_stage_decoder = T2SFirstStageDecoder(
295
+ self.ar_audio_embedding,
296
+ self.ar_audio_position,
297
+ self.h,
298
+ self.ar_predict_layer,
299
+ self.loss_fct,
300
+ self.ar_accuracy_metric,
301
+ self.top_k,
302
+ self.early_stop_num,
303
+ self.num_layers,
304
+ )
305
+ self.stage_decoder = T2SStageDecoder(
306
+ self.ar_audio_embedding,
307
+ self.ar_audio_position,
308
+ self.h,
309
+ self.ar_predict_layer,
310
+ self.loss_fct,
311
+ self.ar_accuracy_metric,
312
+ self.top_k,
313
+ self.early_stop_num,
314
+ self.num_layers,
315
+ )
316
+
317
+ def forward(self, x, prompts, bert_feature):
318
+ early_stop_num = self.early_stop_num
319
+ prefix_len = prompts.shape[1]
320
+
321
+ x = self.onnx_encoder(x, bert_feature)
322
+ y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
323
+
324
+ stop = False
325
+ for idx in range(1, 1500):
326
+ enco = self.stage_decoder(y, k, v, y_emb, x_example)
327
+ y, k, v, y_emb, logits, samples = enco
328
+ if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
329
+ stop = True
330
+ if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
331
+ stop = True
332
+ if stop:
333
+ break
334
+ y[0, -1] = 0
335
+ return y, idx
336
+
337
+ def infer(self, x, prompts, bert_feature):
338
+ top_k = self.top_k
339
+ early_stop_num = self.early_stop_num
340
+
341
+ x = self.onnx_encoder(x, bert_feature)
342
+
343
+ y = prompts
344
+ prefix_len = y.shape[1]
345
+ x_len = x.shape[1]
346
+ x_example = x[:, :, 0] * 0.0
347
+ x_attn_mask = torch.matmul(x_example.transpose(0, 1), x_example)
348
+ x_attn_mask = torch.zeros_like(x_attn_mask, dtype=torch.bool)
349
+
350
+ stop = False
351
+ cache = {
352
+ "all_stage": self.num_layers,
353
+ "k": [None] * self.num_layers,
354
+ "v": [None] * self.num_layers,
355
+ "y_emb": None,
356
+ "first_infer": 1,
357
+ "stage": 0,
358
+ }
359
+ for idx in range(1500):
360
+ if cache["first_infer"] == 1:
361
+ y_emb = self.ar_audio_embedding(y)
362
+ else:
363
+ y_emb = torch.cat([cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1)
364
+ cache["y_emb"] = y_emb
365
+ y_pos = self.ar_audio_position(y_emb)
366
+ if cache["first_infer"] == 1:
367
+ xy_pos = torch.concat([x, y_pos], dim=1)
368
+ else:
369
+ xy_pos = y_pos[:, -1:]
370
+ y_len = y_pos.shape[1]
371
+ if cache["first_infer"] == 1:
372
+ x_attn_mask_pad = F.pad(x_attn_mask, (0, y_len), value=True)
373
+ y_attn_mask = F.pad(
374
+ torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
375
+ (x_len, 0),
376
+ value=False,
377
+ )
378
+ xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
379
+ else:
380
+ xy_attn_mask = torch.zeros((1, x_len + y_len), dtype=torch.bool)
381
+ xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
382
+ logits = self.ar_predict_layer(xy_dec[:, -1])
383
+ samples = sample(logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
384
+ if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
385
+ stop = True
386
+ if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
387
+ stop = True
388
+ if stop:
389
+ if prompts.shape[1] == y.shape[1]:
390
+ y = torch.concat([y, torch.zeros_like(samples)], dim=1)
391
+ break
392
+ y = torch.concat([y, samples], dim=1)
393
+ cache["first_infer"] = 0
394
+ return y, idx
AR/models/utils.py ADDED
@@ -0,0 +1,282 @@
1
+ # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/utils.py
2
+ # reference: https://github.com/lifeiteng/vall-e
3
+ from typing import Tuple
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+
8
+
9
+ def sequence_mask(length, max_length=None):
10
+ if max_length is None:
11
+ max_length = length.max()
12
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
13
+ return x.unsqueeze(0) < length.unsqueeze(1)
14
+
15
+
16
+ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
17
+ """
18
+ Args:
19
+ lengths:
20
+ A 1-D tensor containing sentence lengths.
21
+ max_len:
22
+ The length of masks.
23
+ Returns:
24
+ Return a 2-D bool tensor, where masked positions
25
+ are filled with `True` and non-masked positions are
26
+ filled with `False`.
27
+
28
+ #>>> lengths = torch.tensor([1, 3, 2, 5])
29
+ #>>> make_pad_mask(lengths)
30
+ tensor([[False, True, True, True, True],
31
+ [False, False, False, True, True],
32
+ [False, False, True, True, True],
33
+ [False, False, False, False, False]])
34
+ """
35
+ assert lengths.ndim == 1, lengths.ndim
36
+ max_len = max(max_len, lengths.max())
37
+ n = lengths.size(0)
38
+ seq_range = torch.arange(0, max_len, device=lengths.device)
39
+ expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len)
40
+
41
+ return expaned_lengths >= lengths.unsqueeze(-1)
42
+
43
+
44
+ def make_pad_mask_left(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
45
+ """
46
+ Args:
47
+ lengths:
48
+ A 1-D tensor containing sentence lengths.
49
+ max_len:
50
+ The length of masks.
51
+ Returns:
52
+ Return a 2-D bool tensor, where masked positions
53
+ are filled with `True` and non-masked positions are
54
+ filled with `False`.
55
+
56
+ #>>> lengths = torch.tensor([1, 3, 2, 5])
57
+ #>>> make_pad_mask(lengths)
58
+ tensor(
59
+ [
60
+ [True, True, False],
61
+ [True, False, False],
62
+ [True, True, False],
63
+ ...
64
+ ]
65
+ )
66
+ """
67
+ assert lengths.ndim == 1, lengths.ndim
68
+ max_len = max(max_len, lengths.max())
69
+ n = lengths.size(0)
70
+ seq_range = torch.arange(0, max_len, device=lengths.device)
71
+ expaned_lengths = seq_range.unsqueeze(0).repeat(n, 1)
72
+ expaned_lengths -= (max_len - lengths).unsqueeze(-1)
73
+
74
+ return expaned_lengths < 0
75
+
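A quick sanity check of the two helpers above, assuming the file is importable as AR.models.utils (illustrative sketch, not part of the committed file):

import torch
from AR.models.utils import make_pad_mask, make_pad_mask_left

lengths = torch.tensor([1, 3, 2])
print(make_pad_mask(lengths))       # padding flagged on the right of each row
# tensor([[False,  True,  True],
#         [False, False, False],
#         [False, False,  True]])
print(make_pad_mask_left(lengths))  # padding flagged on the left, as used for batched decoding
# tensor([[ True,  True, False],
#         [False, False, False],
#         [ True, False, False]])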
76
+
77
+ # https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
78
+ def top_k_top_p_filtering(
79
+ logits,
80
+ top_k=0,
81
+ top_p=1.0,
82
+ filter_value=-float("Inf"),
83
+ min_tokens_to_keep=1,
84
+ ):
85
+ """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
86
+ Args:
87
+ logits: logits distribution shape (batch size, vocabulary size)
88
+ if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
89
+ if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
90
+ Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
91
+ Make sure we keep at least min_tokens_to_keep per batch example in the output
92
+ From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
93
+ """
94
+ if top_k > 0:
95
+ top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
96
+ # Remove all tokens with a probability less than the last token of the top-k
97
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
98
+ logits[indices_to_remove] = filter_value
99
+
100
+ if top_p < 1.0:
101
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
102
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
103
+
104
+ # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
105
+ sorted_indices_to_remove = cumulative_probs > top_p
106
+ if min_tokens_to_keep > 1:
107
+ # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
108
+ sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
109
+ # Shift the indices to the right to keep also the first token above the threshold
110
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
111
+ sorted_indices_to_remove[..., 0] = 0
112
+
113
+ # scatter sorted tensors to original indexing
114
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
115
+ logits[indices_to_remove] = filter_value
116
+ return logits
117
+
118
+
119
+ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
120
+ # temperature: (`optional`) float
121
+ # The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
122
+ # top_k: (`optional`) int
123
+ # The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.
124
+ # top_p: (`optional`) float
125
+ # The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1.
126
+
127
+ # Temperature (higher temperature => more likely to sample low probability tokens)
128
+ if temperature != 1.0:
129
+ logits = logits / temperature
130
+ # Top-p/top-k filtering
131
+ logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
132
+ # Sample
133
+ token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
134
+ return token
135
+
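A small usage sketch of topk_sampling, which drives top_k_top_p_filtering above (assumed import path, illustrative only, not part of the committed file):

import torch
from AR.models.utils import topk_sampling

logits = torch.randn(1, 1025)   # (batch, vocab_size); 1025 matches the 1024 + 1 vocab in the configs
token = topk_sampling(logits, top_k=20, top_p=0.9, temperature=0.8)
print(token.shape)              # torch.Size([1, 1]) -- one sampled token id per batch row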
136
+
137
+ from typing import Optional
138
+
139
+
140
+ def multinomial_sample_one_no_sync(
141
+ probs_sort,
142
+ ): # Does multinomial sampling without a cuda synchronization
143
+ q = torch.empty_like(probs_sort).exponential_(1)
144
+ return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
145
+
146
+
147
+ def logits_to_probs(
148
+ logits,
149
+ previous_tokens: Optional[torch.Tensor] = None,
150
+ temperature: float = 1.0,
151
+ top_k: Optional[int] = None,
152
+ top_p: Optional[int] = None,
153
+ repetition_penalty: float = 1.0,
154
+ ):
155
+ # if previous_tokens is not None:
156
+ # previous_tokens = previous_tokens.squeeze()
157
+ # print(logits.shape,previous_tokens.shape)
158
+ # pdb.set_trace()
159
+ if previous_tokens is not None and repetition_penalty != 1.0:
160
+ previous_tokens = previous_tokens.long()
161
+ score = torch.gather(logits, dim=1, index=previous_tokens)
162
+ score = torch.where(
163
+ score < 0,
164
+ score * repetition_penalty,
165
+ score / repetition_penalty,
166
+ )
167
+ logits.scatter_(dim=1, index=previous_tokens, src=score)
168
+
169
+ if top_p is not None and top_p < 1.0:
170
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
171
+ cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
172
+ sorted_indices_to_remove = cum_probs > top_p
173
+ sorted_indices_to_remove[:, 0] = False # keep at least one option
174
+ indices_to_remove = sorted_indices_to_remove.scatter(
175
+ dim=1,
176
+ index=sorted_indices,
177
+ src=sorted_indices_to_remove,
178
+ )
179
+ logits = logits.masked_fill(indices_to_remove, -float("Inf"))
180
+
181
+ logits = logits / max(temperature, 1e-5)
182
+
183
+ if top_k is not None:
184
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
185
+ pivot = v[:, -1].unsqueeze(-1)
186
+ logits = torch.where(logits < pivot, -float("Inf"), logits)
187
+
188
+ probs = torch.nn.functional.softmax(logits, dim=-1)
189
+ return probs
190
+
191
+
192
+ def sample(
193
+ logits,
194
+ previous_tokens: Optional[torch.Tensor] = None,
195
+ **sampling_kwargs,
196
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
197
+ probs = logits_to_probs(logits=logits, previous_tokens=previous_tokens, **sampling_kwargs)
198
+ idx_next = multinomial_sample_one_no_sync(probs)
199
+ return idx_next, probs
200
+
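sample() and logits_to_probs() above expect (batch, vocab) logits plus a (batch, history) tensor of previously generated tokens for the repetition penalty; a minimal sketch (assumed import path, illustrative only):

import torch
from AR.models.utils import sample

logits = torch.randn(2, 1025)              # (batch, vocab_size)
history = torch.randint(0, 1025, (2, 30))  # previously generated token ids
next_token, probs = sample(logits, history, top_k=15, top_p=0.95,
                           temperature=1.0, repetition_penalty=1.35)
print(next_token.shape, probs.shape)       # torch.Size([2, 1]) torch.Size([2, 1025])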
201
+
202
+ def dpo_loss(
203
+ policy_chosen_logps: torch.FloatTensor,
204
+ policy_rejected_logps: torch.FloatTensor,
205
+ reference_chosen_logps: torch.FloatTensor,
206
+ reference_rejected_logps: torch.FloatTensor,
207
+ beta: float,
208
+ reference_free: bool = False,
209
+ ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
210
+ pi_logratios = policy_chosen_logps - policy_rejected_logps
211
+ ref_logratios = reference_chosen_logps - reference_rejected_logps
212
+
213
+ if reference_free:
214
+ ref_logratios = 0
215
+
216
+ logits = pi_logratios - ref_logratios
217
+
218
+ losses = -F.logsigmoid(beta * logits)
219
+ chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps).detach()
220
+ rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps).detach()
221
+
222
+ return losses.mean(), chosen_rewards, rejected_rewards
223
+
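dpo_loss above is the standard DPO objective, -logsigmoid(beta * ((policy_chosen - policy_rejected) - (reference_chosen - reference_rejected))); a toy invocation with made-up log-probabilities (illustrative only, not part of the committed file):

import torch
from AR.models.utils import dpo_loss

policy_chosen   = torch.tensor([-3.5, -4.0])   # made-up sequence log-probs
policy_rejected = torch.tensor([-5.0, -4.5])
ref_chosen      = torch.tensor([-3.8, -4.1])
ref_rejected    = torch.tensor([-4.7, -4.4])
loss, chosen_reward, rejected_reward = dpo_loss(
    policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1
)
print(loss.item(), chosen_reward, rejected_reward)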
224
+
225
+ def get_batch_logps(
226
+ logits_target: torch.FloatTensor,
227
+ logits_reject: torch.FloatTensor,
228
+ labels_target: torch.LongTensor,
229
+ labels_reject: torch.LongTensor,
230
+ average_log_prob: bool = False,
231
+ ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
232
+ # dummy token; we'll ignore the losses on these tokens later
233
+
234
+ per_token_logps_target = torch.gather(
235
+ logits_target.log_softmax(-1), dim=2, index=labels_target.unsqueeze(2)
236
+ ).squeeze(2)
237
+ per_token_logps_reject = torch.gather(
238
+ logits_reject.log_softmax(-1), dim=2, index=labels_reject.unsqueeze(2)
239
+ ).squeeze(2)
240
+
241
+ return per_token_logps_target.sum(-1), per_token_logps_reject.sum(-1)
242
+
243
+
244
+ def make_reject_y(y_o, y_lens):
245
+ def repeat_P(y):
246
+ range_idx, _ = torch.randint(0, len(y), size=(2,)).sort()
247
+ pre = y[: range_idx[0]]
248
+ shf = y[range_idx[1] :]
249
+ range_text = y[range_idx[0] : range_idx[1]]
250
+ new_y = torch.cat([pre, range_text, range_text, shf])
251
+ return new_y
252
+
253
+ def lost_P(y):
254
+ range_idx, _ = torch.randint(0, len(y), size=(2,)).sort()
255
+ pre = y[: range_idx[0]]
256
+ shf = y[range_idx[1] :]
257
+ range_text = y[range_idx[0] : range_idx[1]]
258
+ new_y = torch.cat([pre, shf])
259
+ return new_y
260
+
261
+ bs = len(y_lens)
262
+ reject_y = []
263
+ reject_y_lens = []
264
+ for b in range(bs):
265
+ process_item_idx = torch.randint(0, 1, size=(1,))[0]
266
+ if process_item_idx == 0:
267
+ new_y = repeat_P(y_o[b])
268
+ reject_y.append(new_y)
269
+ reject_y_lens.append(len(new_y))
270
+ elif process_item_idx == 1:
271
+ new_y = lost_P(y_o[b])
272
+ reject_y.append(new_y)
273
+ reject_y_lens.append(len(new_y))
274
+ max_length = max(reject_y_lens)
275
+ for b in range(bs):
276
+ pad_length = max_length - reject_y_lens[b]
277
+ reject_y[b] = torch.cat([reject_y[b], torch.zeros(pad_length, dtype=y_o.dtype, device=y_o.device)], dim=0)
278
+
279
+ reject_y = torch.stack(reject_y, dim=0)
280
+ reject_y_lens = torch.tensor(reject_y_lens, device=y_lens.device)
281
+
282
+ return reject_y, reject_y_lens
AR/modules/__init__.py ADDED
File without changes
AR/modules/activation.py ADDED
@@ -0,0 +1,413 @@
1
+ # modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py
2
+ from typing import Optional, Tuple
3
+
4
+ import torch
5
+ from torch import Tensor
6
+ from torch.nn import Linear, Module
7
+ from torch.nn import functional as F
8
+ from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
9
+ from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
10
+ from torch.nn.parameter import Parameter
11
+
12
+ from AR.modules.patched_mha_with_cache import multi_head_attention_forward_patched
13
+
14
+ F.multi_head_attention_forward = multi_head_attention_forward_patched
15
+
16
+
17
+ class MultiheadAttention(Module):
18
+ r"""Allows the model to jointly attend to information
19
+ from different representation subspaces as described in the paper:
20
+ `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
21
+
22
+ Multi-Head Attention is defined as:
23
+
24
+ .. math::
25
+ \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
26
+
27
+ where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
28
+
29
+ ``forward()`` will use a special optimized implementation if all of the following
30
+ conditions are met:
31
+
32
+ - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This
33
+ restriction will be loosened in the future.)
34
+ - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
35
+ - training is disabled (using ``.eval()``)
36
+ - dropout is 0
37
+ - ``add_bias_kv`` is ``False``
38
+ - ``add_zero_attn`` is ``False``
39
+ - ``batch_first`` is ``True`` and the input is batched
40
+ - ``kdim`` and ``vdim`` are equal to ``embed_dim``
41
+ - at most one of ``key_padding_mask`` or ``attn_mask`` is passed
42
+ - if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
43
+ nor ``attn_mask`` is passed
44
+
45
+ If the optimized implementation is in use, a
46
+ `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
47
+ ``query``/``key``/``value`` to represent padding more efficiently than using a
48
+ padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
49
+ will be returned, and an additional speedup proportional to the fraction of the input
50
+ that is padding can be expected.
51
+
52
+ Args:
53
+ embed_dim: Total dimension of the model.
54
+ num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
55
+ across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
56
+ dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
57
+ bias: If specified, adds bias to input / output projection layers. Default: ``True``.
58
+ add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
59
+ add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
60
+ Default: ``False``.
61
+ kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
62
+ vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
63
+ batch_first: If ``True``, then the input and output tensors are provided
64
+ as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
65
+
66
+ Examples::
67
+
68
+ >>> # xdoctest: +SKIP
69
+ >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
70
+ >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
71
+
72
+ """
73
+
74
+ __constants__ = ["batch_first"]
75
+ bias_k: Optional[torch.Tensor]
76
+ bias_v: Optional[torch.Tensor]
77
+
78
+ def __init__(
79
+ self,
80
+ embed_dim,
81
+ num_heads,
82
+ dropout=0.0,
83
+ bias=True,
84
+ add_bias_kv=False,
85
+ add_zero_attn=False,
86
+ kdim=None,
87
+ vdim=None,
88
+ batch_first=False,
89
+ linear1_cls=Linear,
90
+ linear2_cls=Linear,
91
+ device=None,
92
+ dtype=None,
93
+ ) -> None:
94
+ factory_kwargs = {"device": device, "dtype": dtype}
95
+ super(MultiheadAttention, self).__init__()
96
+ self.embed_dim = embed_dim
97
+ self.kdim = kdim if kdim is not None else embed_dim
98
+ self.vdim = vdim if vdim is not None else embed_dim
99
+ self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
100
+
101
+ self.num_heads = num_heads
102
+ self.dropout = dropout
103
+ self.batch_first = batch_first
104
+ self.head_dim = embed_dim // num_heads
105
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
106
+
107
+ if add_bias_kv:
108
+ self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
109
+ self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
110
+ else:
111
+ self.bias_k = self.bias_v = None
112
+
113
+ if linear1_cls == Linear:
114
+ if not self._qkv_same_embed_dim:
115
+ self.q_proj_weight = Parameter(
116
+ torch.empty((embed_dim, embed_dim), **factory_kwargs),
117
+ )
118
+ self.k_proj_weight = Parameter(
119
+ torch.empty((embed_dim, self.kdim), **factory_kwargs),
120
+ )
121
+ self.v_proj_weight = Parameter(
122
+ torch.empty((embed_dim, self.vdim), **factory_kwargs),
123
+ )
124
+ self.register_parameter("in_proj_weight", None)
125
+ else:
126
+ self.in_proj_weight = Parameter(
127
+ torch.empty((3 * embed_dim, embed_dim), **factory_kwargs),
128
+ )
129
+ self.register_parameter("q_proj_weight", None)
130
+ self.register_parameter("k_proj_weight", None)
131
+ self.register_parameter("v_proj_weight", None)
132
+
133
+ if bias:
134
+ self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
135
+ else:
136
+ self.register_parameter("in_proj_bias", None)
137
+ self.out_proj = NonDynamicallyQuantizableLinear(
138
+ embed_dim,
139
+ embed_dim,
140
+ bias=bias,
141
+ **factory_kwargs,
142
+ )
143
+
144
+ self._reset_parameters()
145
+ else:
146
+ if not self._qkv_same_embed_dim:
147
+ raise NotImplementedError
148
+ else:
149
+ self.in_proj_linear = linear1_cls(
150
+ embed_dim,
151
+ 3 * embed_dim,
152
+ bias=bias,
153
+ **factory_kwargs,
154
+ )
155
+ self.in_proj_weight = self.in_proj_linear.weight
156
+
157
+ self.register_parameter("q_proj_weight", None)
158
+ self.register_parameter("k_proj_weight", None)
159
+ self.register_parameter("v_proj_weight", None)
160
+
161
+ if bias:
162
+ self.in_proj_bias = self.in_proj_linear.bias
163
+ else:
164
+ self.register_parameter("in_proj_bias", None)
165
+
166
+ self.out_proj = linear2_cls(
167
+ embed_dim,
168
+ embed_dim,
169
+ bias=bias,
170
+ **factory_kwargs,
171
+ )
172
+
173
+ if self.bias_k is not None:
174
+ xavier_normal_(self.bias_k)
175
+ if self.bias_v is not None:
176
+ xavier_normal_(self.bias_v)
177
+
178
+ self.add_zero_attn = add_zero_attn
179
+
180
+ def _reset_parameters(self):
181
+ if self._qkv_same_embed_dim:
182
+ xavier_uniform_(self.in_proj_weight)
183
+ else:
184
+ xavier_uniform_(self.q_proj_weight)
185
+ xavier_uniform_(self.k_proj_weight)
186
+ xavier_uniform_(self.v_proj_weight)
187
+
188
+ if self.in_proj_bias is not None:
189
+ constant_(self.in_proj_bias, 0.0)
190
+ constant_(self.out_proj.bias, 0.0)
191
+
192
+ if self.bias_k is not None:
193
+ xavier_normal_(self.bias_k)
194
+ if self.bias_v is not None:
195
+ xavier_normal_(self.bias_v)
196
+
197
+ def __setstate__(self, state):
198
+ # Support loading old MultiheadAttention checkpoints generated by v1.1.0
199
+ if "_qkv_same_embed_dim" not in state:
200
+ state["_qkv_same_embed_dim"] = True
201
+
202
+ super(MultiheadAttention, self).__setstate__(state)
203
+
204
+ def forward(
205
+ self,
206
+ query: Tensor,
207
+ key: Tensor,
208
+ value: Tensor,
209
+ key_padding_mask: Optional[Tensor] = None,
210
+ need_weights: bool = True,
211
+ attn_mask: Optional[Tensor] = None,
212
+ average_attn_weights: bool = True,
213
+ cache=None,
214
+ ) -> Tuple[Tensor, Optional[Tensor]]:
215
+ r"""
216
+ Args:
217
+ query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
218
+ or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
219
+ :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
220
+ Queries are compared against key-value pairs to produce the output.
221
+ See "Attention Is All You Need" for more details.
222
+ key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
223
+ or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
224
+ :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
225
+ See "Attention Is All You Need" for more details.
226
+ value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
227
+ ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
228
+ sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
229
+ See "Attention Is All You Need" for more details.
230
+ key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
231
+ to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
232
+ Binary and byte masks are supported.
233
+ For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
234
+ the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
235
+ need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
236
+ Default: ``True``.
237
+ attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
238
+ :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
239
+ :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
240
+ broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
241
+ Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the
242
+ corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the
243
+ corresponding position is not allowed to attend. For a float mask, the mask values will be added to
244
+ the attention weight.
245
+ average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
246
+ heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
247
+ effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
248
+
249
+ Outputs:
250
+ - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
251
+ :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
252
+ where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
253
+ embedding dimension ``embed_dim``.
254
+ - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
255
+ returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
256
+ :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
257
+ :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
258
+ head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.
259
+
260
+ .. note::
261
+ `batch_first` argument is ignored for unbatched inputs.
262
+ """
263
+ is_batched = query.dim() == 3
264
+ if key_padding_mask is not None:
265
+ _kpm_dtype = key_padding_mask.dtype
266
+ if _kpm_dtype != torch.bool and not torch.is_floating_point(
267
+ key_padding_mask,
268
+ ):
269
+ raise AssertionError("only bool and floating types of key_padding_mask are supported")
270
+ why_not_fast_path = ""
271
+ if not is_batched:
272
+ why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
273
+ elif query is not key or key is not value:
274
+ # When lifting this restriction, don't forget to either
275
+ # enforce that the dtypes all match or test cases where
276
+ # they don't!
277
+ why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
278
+ elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
279
+ why_not_fast_path = (
280
+ f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
281
+ )
282
+ elif self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype:
283
+ # this case will fail anyway, but at least they'll get a useful error message.
284
+ why_not_fast_path = (
285
+ f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
286
+ )
287
+ elif self.training:
288
+ why_not_fast_path = "training is enabled"
289
+ elif not self.batch_first:
290
+ why_not_fast_path = "batch_first was not True"
291
+ elif self.bias_k is not None:
292
+ why_not_fast_path = "self.bias_k was not None"
293
+ elif self.bias_v is not None:
294
+ why_not_fast_path = "self.bias_v was not None"
295
+ elif self.dropout:
296
+ why_not_fast_path = f"dropout was {self.dropout}, required zero"
297
+ elif self.add_zero_attn:
298
+ why_not_fast_path = "add_zero_attn was enabled"
299
+ elif not self._qkv_same_embed_dim:
300
+ why_not_fast_path = "_qkv_same_embed_dim was not True"
301
+ elif attn_mask is not None:
302
+ why_not_fast_path = "attn_mask was not None"
303
+ elif query.is_nested and key_padding_mask is not None:
304
+ why_not_fast_path = "key_padding_mask is not supported with NestedTensor input"
305
+ elif self.num_heads % 2 == 1:
306
+ why_not_fast_path = "num_heads is odd"
307
+ elif torch.is_autocast_enabled():
308
+ why_not_fast_path = "autocast is enabled"
309
+
310
+ if not why_not_fast_path:
311
+ tensor_args = (
312
+ query,
313
+ key,
314
+ value,
315
+ self.in_proj_weight,
316
+ self.in_proj_bias,
317
+ self.out_proj.weight,
318
+ self.out_proj.bias,
319
+ )
320
+ # We have to use list comprehensions below because TorchScript does not support
321
+ # generator expressions.
322
+ if torch.overrides.has_torch_function(tensor_args):
323
+ why_not_fast_path = "some Tensor argument has_torch_function"
324
+ elif not all([(x is None or x.is_cuda or "cpu" in str(x.device)) for x in tensor_args]):
325
+ why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
326
+ elif torch.is_grad_enabled() and any([x is not None and x.requires_grad for x in tensor_args]):
327
+ why_not_fast_path = "grad is enabled and at least one of query or the input/output projection weights or biases requires_grad"
328
+ if not why_not_fast_path:
329
+ return torch._native_multi_head_attention(
330
+ query,
331
+ key,
332
+ value,
333
+ self.embed_dim,
334
+ self.num_heads,
335
+ self.in_proj_weight,
336
+ self.in_proj_bias,
337
+ self.out_proj.weight,
338
+ self.out_proj.bias,
339
+ key_padding_mask if key_padding_mask is not None else attn_mask,
340
+ need_weights,
341
+ average_attn_weights,
342
+ 1 if key_padding_mask is not None else 0 if attn_mask is not None else None,
343
+ )
344
+
345
+ any_nested = query.is_nested or key.is_nested or value.is_nested
346
+ assert not any_nested, (
347
+ "MultiheadAttention does not support NestedTensor outside of its fast path. "
348
+ + f"The fast path was not hit because {why_not_fast_path}"
349
+ )
350
+
351
+ if self.batch_first and is_batched:
352
+ # make sure that the transpose op does not affect the "is" property
353
+ if key is value:
354
+ if query is key:
355
+ query = key = value = query.transpose(1, 0)
356
+ else:
357
+ query, key = [x.transpose(1, 0) for x in (query, key)]
358
+ value = key
359
+ else:
360
+ query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
361
+
362
+ if not self._qkv_same_embed_dim:
363
+ attn_output, attn_output_weights = F.multi_head_attention_forward(
364
+ query,
365
+ key,
366
+ value,
367
+ self.embed_dim,
368
+ self.num_heads,
369
+ self.in_proj_weight,
370
+ self.in_proj_bias,
371
+ self.bias_k,
372
+ self.bias_v,
373
+ self.add_zero_attn,
374
+ self.dropout,
375
+ self.out_proj.weight,
376
+ self.out_proj.bias,
377
+ training=self.training,
378
+ key_padding_mask=key_padding_mask,
379
+ need_weights=need_weights,
380
+ attn_mask=attn_mask,
381
+ use_separate_proj_weight=True,
382
+ q_proj_weight=self.q_proj_weight,
383
+ k_proj_weight=self.k_proj_weight,
384
+ v_proj_weight=self.v_proj_weight,
385
+ average_attn_weights=average_attn_weights,
386
+ cache=cache,
387
+ )
388
+ else:
389
+ attn_output, attn_output_weights = F.multi_head_attention_forward(
390
+ query,
391
+ key,
392
+ value,
393
+ self.embed_dim,
394
+ self.num_heads,
395
+ self.in_proj_weight,
396
+ self.in_proj_bias,
397
+ self.bias_k,
398
+ self.bias_v,
399
+ self.add_zero_attn,
400
+ self.dropout,
401
+ self.out_proj.weight,
402
+ self.out_proj.bias,
403
+ training=self.training,
404
+ key_padding_mask=key_padding_mask,
405
+ need_weights=need_weights,
406
+ attn_mask=attn_mask,
407
+ average_attn_weights=average_attn_weights,
408
+ cache=cache,
409
+ )
410
+ if self.batch_first and is_batched:
411
+ return attn_output.transpose(1, 0), attn_output_weights
412
+ else:
413
+ return attn_output, attn_output_weights
AR/modules/activation_onnx.py ADDED
@@ -0,0 +1,188 @@
1
+ # modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py
2
+ from typing import Optional, Tuple
3
+
4
+ import torch
5
+ from torch import Tensor
6
+ from torch.nn import Linear, Module
7
+ from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
8
+ from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
9
+ from torch.nn.parameter import Parameter
10
+
11
+ from AR.modules.patched_mha_with_cache_onnx import multi_head_attention_forward_patched
12
+
13
+
14
+ class MultiheadAttention(Module):
15
+ __constants__ = ["batch_first"]
16
+ bias_k: Optional[torch.Tensor]
17
+ bias_v: Optional[torch.Tensor]
18
+
19
+ def __init__(
20
+ self,
21
+ embed_dim,
22
+ num_heads,
23
+ dropout=0.0,
24
+ bias=True,
25
+ add_bias_kv=False,
26
+ add_zero_attn=False,
27
+ kdim=None,
28
+ vdim=None,
29
+ batch_first=False,
30
+ linear1_cls=Linear,
31
+ linear2_cls=Linear,
32
+ device=None,
33
+ dtype=None,
34
+ ) -> None:
35
+ factory_kwargs = {"device": device, "dtype": dtype}
36
+ super(MultiheadAttention, self).__init__()
37
+ self.embed_dim = embed_dim
38
+ self.kdim = kdim if kdim is not None else embed_dim
39
+ self.vdim = vdim if vdim is not None else embed_dim
40
+ self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
41
+
42
+ self.num_heads = num_heads
43
+ self.dropout = dropout
44
+ self.batch_first = batch_first
45
+ self.head_dim = embed_dim // num_heads
46
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
47
+
48
+ if add_bias_kv:
49
+ self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
50
+ self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
51
+ else:
52
+ self.bias_k = self.bias_v = None
53
+
54
+ if linear1_cls == Linear:
55
+ if not self._qkv_same_embed_dim:
56
+ self.q_proj_weight = Parameter(
57
+ torch.empty(
58
+ (embed_dim, embed_dim),
59
+ **factory_kwargs,
60
+ )
61
+ )
62
+ self.k_proj_weight = Parameter(
63
+ torch.empty(
64
+ (embed_dim, self.kdim),
65
+ **factory_kwargs,
66
+ )
67
+ )
68
+ self.v_proj_weight = Parameter(
69
+ torch.empty(
70
+ (embed_dim, self.vdim),
71
+ **factory_kwargs,
72
+ )
73
+ )
74
+ self.register_parameter("in_proj_weight", None)
75
+ else:
76
+ self.in_proj_weight = Parameter(
77
+ torch.empty(
78
+ (3 * embed_dim, embed_dim),
79
+ **factory_kwargs,
80
+ )
81
+ )
82
+ self.register_parameter("q_proj_weight", None)
83
+ self.register_parameter("k_proj_weight", None)
84
+ self.register_parameter("v_proj_weight", None)
85
+
86
+ if bias:
87
+ self.in_proj_bias = Parameter(
88
+ torch.empty(3 * embed_dim, **factory_kwargs),
89
+ )
90
+ else:
91
+ self.register_parameter("in_proj_bias", None)
92
+ self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
93
+
94
+ self._reset_parameters()
95
+ else:
96
+ if not self._qkv_same_embed_dim:
97
+ raise NotImplementedError
98
+ else:
99
+ self.in_proj_linear = linear1_cls(
100
+ embed_dim,
101
+ 3 * embed_dim,
102
+ bias=bias,
103
+ **factory_kwargs,
104
+ )
105
+ self.in_proj_weight = self.in_proj_linear.weight
106
+
107
+ self.register_parameter("q_proj_weight", None)
108
+ self.register_parameter("k_proj_weight", None)
109
+ self.register_parameter("v_proj_weight", None)
110
+
111
+ if bias:
112
+ self.in_proj_bias = self.in_proj_linear.bias
113
+ else:
114
+ self.register_parameter("in_proj_bias", None)
115
+
116
+ self.out_proj = linear2_cls(
117
+ embed_dim,
118
+ embed_dim,
119
+ bias=bias,
120
+ **factory_kwargs,
121
+ )
122
+
123
+ if self.bias_k is not None:
124
+ xavier_normal_(self.bias_k)
125
+ if self.bias_v is not None:
126
+ xavier_normal_(self.bias_v)
127
+
128
+ self.add_zero_attn = add_zero_attn
129
+
130
+ def _reset_parameters(self):
131
+ if self._qkv_same_embed_dim:
132
+ xavier_uniform_(self.in_proj_weight)
133
+ else:
134
+ xavier_uniform_(self.q_proj_weight)
135
+ xavier_uniform_(self.k_proj_weight)
136
+ xavier_uniform_(self.v_proj_weight)
137
+
138
+ if self.in_proj_bias is not None:
139
+ constant_(self.in_proj_bias, 0.0)
140
+ constant_(self.out_proj.bias, 0.0)
141
+
142
+ if self.bias_k is not None:
143
+ xavier_normal_(self.bias_k)
144
+ if self.bias_v is not None:
145
+ xavier_normal_(self.bias_v)
146
+
147
+ def __setstate__(self, state):
148
+ # Support loading old MultiheadAttention checkpoints generated by v1.1.0
149
+ if "_qkv_same_embed_dim" not in state:
150
+ state["_qkv_same_embed_dim"] = True
151
+
152
+ super(MultiheadAttention, self).__setstate__(state)
153
+
154
+ def forward(
155
+ self,
156
+ query: Tensor,
157
+ key: Tensor,
158
+ value: Tensor,
159
+ key_padding_mask: Optional[Tensor] = None,
160
+ need_weights: bool = True,
161
+ attn_mask: Optional[Tensor] = None,
162
+ average_attn_weights: bool = True,
163
+ cache=None,
164
+ ) -> Tuple[Tensor, Optional[Tensor]]:
165
+ any_nested = query.is_nested or key.is_nested or value.is_nested
166
+ query = key = value = query.transpose(1, 0)
167
+ attn_output = multi_head_attention_forward_patched(
168
+ query,
169
+ key,
170
+ value,
171
+ self.embed_dim,
172
+ self.num_heads,
173
+ self.in_proj_weight,
174
+ self.in_proj_bias,
175
+ self.bias_k,
176
+ self.bias_v,
177
+ self.add_zero_attn,
178
+ self.dropout,
179
+ self.out_proj.weight,
180
+ self.out_proj.bias,
181
+ training=self.training,
182
+ key_padding_mask=key_padding_mask,
183
+ need_weights=need_weights,
184
+ attn_mask=attn_mask,
185
+ average_attn_weights=average_attn_weights,
186
+ cache=cache,
187
+ )
188
+ return attn_output.transpose(1, 0)
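
For reference, a minimal sketch of what the packed in-projection configured above amounts to, assuming plain PyTorch and illustrative sizes; slicing the single (3*embed_dim, embed_dim) weight into q/k/v chunks is mathematically equivalent to what _in_projection_packed computes later in this diff:

    import torch
    import torch.nn.functional as F

    embed_dim = 512
    in_proj_weight = torch.randn(3 * embed_dim, embed_dim)   # packed layout used when q/k/v share embed_dim
    in_proj_bias = torch.zeros(3 * embed_dim)

    x = torch.randn(50, 2, embed_dim)                        # (time, batch, embed_dim)
    q_w, k_w, v_w = in_proj_weight.chunk(3, dim=0)           # each chunk is (embed_dim, embed_dim)
    q_b, k_b, v_b = in_proj_bias.chunk(3)
    q = F.linear(x, q_w, q_b)                                # same pattern for k and v
    print(q.shape)                                           # torch.Size([50, 2, 512])
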
AR/modules/embedding.py ADDED
@@ -0,0 +1,78 @@
1
+ # modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/embedding.py
2
+ import math
3
+
4
+ import torch
5
+ from torch import nn
6
+
7
+
8
+ class TokenEmbedding(nn.Module):
9
+ def __init__(
10
+ self,
11
+ embedding_dim: int,
12
+ vocab_size: int,
13
+ dropout: float = 0.0,
14
+ ):
15
+ super().__init__()
16
+
17
+ self.vocab_size = vocab_size
18
+ self.embedding_dim = embedding_dim
19
+
20
+ self.dropout = torch.nn.Dropout(p=dropout)
21
+ self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim)
22
+
23
+ @property
24
+ def weight(self) -> torch.Tensor:
25
+ return self.word_embeddings.weight
26
+
27
+ def embedding(self, index: int) -> torch.Tensor:
28
+ return self.word_embeddings.weight[index : index + 1]
29
+
30
+ def forward(self, x: torch.Tensor):
31
+ x = self.word_embeddings(x)
32
+ x = self.dropout(x)
33
+ return x
34
+
35
+
36
+ class SinePositionalEmbedding(nn.Module):
37
+ def __init__(
38
+ self,
39
+ embedding_dim: int,
40
+ dropout: float = 0.0,
41
+ scale: bool = False,
42
+ alpha: bool = False,
43
+ ):
44
+ super().__init__()
45
+ self.embedding_dim = embedding_dim
46
+ self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
47
+ self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
48
+ self.dropout = torch.nn.Dropout(p=dropout)
49
+
50
+ self.reverse = False
51
+ self.pe = None
52
+ self.extend_pe(torch.tensor(0.0).expand(1, 4000))
53
+
54
+ def extend_pe(self, x):
55
+ """Reset the positional encodings."""
56
+ if self.pe is not None:
57
+ if self.pe.size(1) >= x.size(1):
58
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
59
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
60
+ return
61
+ pe = torch.zeros(x.size(1), self.embedding_dim)
62
+ if self.reverse:
63
+ position = torch.arange(x.size(1) - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
64
+ else:
65
+ position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
66
+ div_term = torch.exp(
67
+ torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) * -(math.log(10000.0) / self.embedding_dim)
68
+ )
69
+ pe[:, 0::2] = torch.sin(position * div_term)
70
+ pe[:, 1::2] = torch.cos(position * div_term)
71
+ pe = pe.unsqueeze(0)
72
+ self.pe = pe.to(device=x.device, dtype=x.dtype).detach()
73
+
74
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
75
+ self.extend_pe(x)
76
+ output = x.unsqueeze(-1) if x.ndim == 2 else x
77
+ output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
78
+ return self.dropout(output)
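
A short usage sketch of the two modules above; the vocabulary size, model width, and sequence length are illustrative, and the import path is assumed from this file's location in the diff:

    import torch
    from AR.modules.embedding import TokenEmbedding, SinePositionalEmbedding

    tok = TokenEmbedding(embedding_dim=512, vocab_size=1025, dropout=0.0)
    pos = SinePositionalEmbedding(embedding_dim=512, dropout=0.0, scale=False, alpha=True)

    ids = torch.randint(0, 1025, (2, 100))   # (batch, time) token ids
    x = tok(ids)                             # (2, 100, 512)
    x = pos(x)                               # adds the sinusoidal table, gated by the learnable alpha
    print(x.shape)                           # torch.Size([2, 100, 512])
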
AR/modules/embedding_onnx.py ADDED
@@ -0,0 +1,63 @@
1
+ # modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/embedding.py
2
+ import math
3
+
4
+ import torch
5
+ from torch import nn
6
+
7
+
8
+ class TokenEmbedding(nn.Module):
9
+ def __init__(
10
+ self,
11
+ embedding_dim: int,
12
+ vocab_size: int,
13
+ dropout: float = 0.0,
14
+ ):
15
+ super().__init__()
16
+
17
+ self.vocab_size = vocab_size
18
+ self.embedding_dim = embedding_dim
19
+
20
+ self.dropout = torch.nn.Dropout(p=dropout)
21
+ self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim)
22
+
23
+ @property
24
+ def weight(self) -> torch.Tensor:
25
+ return self.word_embeddings.weight
26
+
27
+ def embedding(self, index: int) -> torch.Tensor:
28
+ return self.word_embeddings.weight[index : index + 1]
29
+
30
+ def forward(self, x: torch.Tensor):
31
+ x = self.word_embeddings(x)
32
+ x = self.dropout(x)
33
+ return x
34
+
35
+
36
+ class SinePositionalEmbedding(nn.Module):
37
+ def __init__(
38
+ self,
39
+ embedding_dim: int,
40
+ dropout: float = 0.0,
41
+ scale: bool = False,
42
+ alpha: bool = False,
43
+ ):
44
+ super().__init__()
45
+ self.embedding_dim = embedding_dim
46
+ self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
47
+ self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
48
+ self.dropout = torch.nn.Dropout(p=dropout)
49
+ self.reverse = False
50
+ self.div_term = torch.exp(torch.arange(0, self.embedding_dim, 2) * -(math.log(10000.0) / self.embedding_dim))
51
+
52
+ def extend_pe(self, x):
53
+ position = torch.cumsum(torch.ones_like(x[:, :, 0]), dim=1).transpose(0, 1)
54
+ scpe = (position * self.div_term).unsqueeze(0)
55
+ pe = torch.cat([torch.sin(scpe), torch.cos(scpe)]).permute(1, 2, 0)
56
+ pe = pe.contiguous().view(1, -1, self.embedding_dim)
57
+ return pe
58
+
59
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
60
+ pe = self.extend_pe(x)
61
+ output = x.unsqueeze(-1) if x.ndim == 2 else x
62
+ output = output * self.x_scale + self.alpha * pe
63
+ return self.dropout(output)
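
Compared with the version in embedding.py, this ONNX variant rebuilds the positional table from the input itself (via cumsum) on every call, which avoids caching a fixed-length buffer in the exported graph. A quick shape check, assuming batch size 1 as used for export and the import path implied by this diff:

    import torch
    from AR.modules.embedding_onnx import SinePositionalEmbedding

    pos = SinePositionalEmbedding(embedding_dim=512, dropout=0.0, scale=False, alpha=True)
    x = torch.randn(1, 77, 512)        # batch size 1, as in ONNX export
    pe = pos.extend_pe(x)              # (1, 77, 512), built from a cumsum over the time axis
    print(pe.shape, pos(x).shape)      # torch.Size([1, 77, 512]) torch.Size([1, 77, 512])
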
AR/modules/lr_schedulers.py ADDED
@@ -0,0 +1,85 @@
1
+ # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/modules/lr_schedulers.py
2
+ # reference: https://github.com/lifeiteng/vall-e
3
+ import math
4
+
5
+ import torch
6
+ from matplotlib import pyplot as plt
7
+ from torch import nn
8
+ from torch.optim import Adam
9
+
10
+
11
+ class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler):
12
+ """
13
+ Implements a warmup learning-rate schedule: the LR rises from 'init_lr' to 'peak_lr' over 'warmup_steps', then decays towards 'end_lr' over the remaining steps, for one or more optimizers.
14
+ """
15
+
16
+ def __init__(
17
+ self,
18
+ optimizer,
19
+ init_lr,
20
+ peak_lr,
21
+ end_lr,
22
+ warmup_steps=10000,
23
+ total_steps=400000,
24
+ current_step=0,
25
+ ):
26
+ self.init_lr = init_lr
27
+ self.peak_lr = peak_lr
28
+ self.end_lr = end_lr
29
+ self.optimizer = optimizer
30
+ self._warmup_rate = (peak_lr - init_lr) / warmup_steps
31
+ self._decay_rate = (end_lr - peak_lr) / (total_steps - warmup_steps)
32
+ self._current_step = current_step
33
+ self.lr = init_lr
34
+ self.warmup_steps = warmup_steps
35
+ self.total_steps = total_steps
36
+ self._last_lr = [self.lr]
37
+
38
+ def set_lr(self, lr):
39
+ self._last_lr = [g["lr"] for g in self.optimizer.param_groups]
40
+ for g in self.optimizer.param_groups:
41
+ # g['lr'] = lr
42
+ g["lr"] = self.end_lr ###锁定用线性
43
+
44
+ def step(self):
45
+ if self._current_step < self.warmup_steps:
46
+ lr = self.init_lr + self._warmup_rate * self._current_step
47
+
48
+ elif self._current_step > self.total_steps:
49
+ lr = self.end_lr
50
+
51
+ else:
52
+ decay_ratio = (self._current_step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
53
+ if decay_ratio < 0.0 or decay_ratio > 1.0:
54
+ raise RuntimeError("Decay ratio must be in [0.0, 1.0]. Fix LR scheduler settings.")
55
+ coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
56
+ lr = self.end_lr + coeff * (self.peak_lr - self.end_lr)
57
+
58
+ self.lr = lr = self.end_lr = 0.002  # locked to a constant: the schedule was not behaving, so the LR is pinned to 0.002
59
+ self.set_lr(lr)
60
+ self.lr = lr
61
+ self._current_step += 1
62
+ return self.lr
63
+
64
+
65
+ if __name__ == "__main__":
66
+ m = nn.Linear(10, 10)
67
+ opt = Adam(m.parameters(), lr=1e-4)
68
+ s = WarmupCosineLRSchedule(
69
+ opt,
70
+ 1e-6,
71
+ 2e-4,
72
+ 1e-6,
73
+ warmup_steps=2000,
74
+ total_steps=20000,
75
+ current_step=0,
76
+ )
77
+ lrs = []
78
+ for i in range(25000):
79
+ s.step()
80
+ lrs.append(s.lr)
81
+ print(s.lr)
82
+
83
+ plt.plot(lrs)
84
+ plt.plot(range(0, 25000), lrs)
85
+ plt.show()
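
Note that the hard assignment inside step() pins the learning rate to 0.002, so the warmup/cosine arithmetic above is effectively bypassed; a quick check with the same illustrative hyper-parameters as the demo (the import path is assumed from this diff):

    from torch import nn
    from torch.optim import Adam
    from AR.modules.lr_schedulers import WarmupCosineLRSchedule

    opt = Adam(nn.Linear(10, 10).parameters(), lr=1e-4)
    sched = WarmupCosineLRSchedule(opt, 1e-6, 2e-4, 1e-6, warmup_steps=2000, total_steps=20000)
    for _ in range(5):
        sched.step()
    print(sched.lr, opt.param_groups[0]["lr"])   # both 0.002, regardless of init/peak/end settings
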
AR/modules/optim.py ADDED
@@ -0,0 +1,593 @@
1
+ # Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
2
+ #
3
+ # See ../LICENSE for clarification regarding multiple authors
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import contextlib
17
+ import logging
18
+ from collections import defaultdict
19
+ from typing import List, Tuple
20
+
21
+ import torch
22
+ from torch import Tensor
23
+ from torch.optim import Optimizer
24
+
25
+
26
+ class BatchedOptimizer(Optimizer):
27
+ """
28
+ This class adds to class Optimizer the capability to optimize parameters in batches:
29
+ it will stack the parameters and their grads for you so the optimizer can work
30
+ on tensors with an extra leading dimension. This is intended for speed with GPUs,
31
+ as it reduces the number of kernels launched in the optimizer.
32
+
33
+ Args:
34
+ params: the parameters or param_groups to optimize, as for any torch.optim.Optimizer.
35
+ """
36
+
37
+ def __init__(self, params, defaults):
38
+ super(BatchedOptimizer, self).__init__(params, defaults)
39
+
40
+ @contextlib.contextmanager
41
+ def batched_params(self, param_group, group_params_names):
42
+ """
43
+ This function returns (technically, yields) a list
44
+ of tuples (p, state, p_names), where
45
+ p is a `fake` parameter that is stacked (over axis 0) from real parameters
46
+ that share the same shape, and its gradient is also stacked;
47
+ `state` is the state corresponding to this batch of parameters
48
+ (it will be physically located in the "state" for one of the real
49
+ parameters, the last one that has any particular shape and dtype).
50
+
51
+ This function is decorated as a context manager so that it can
52
+ write parameters back to their "real" locations.
53
+
54
+ The idea is, instead of doing:
55
+ <code>
56
+ for p in group["params"]:
57
+ state = self.state[p]
58
+ ...
59
+ </code>
60
+ you can do:
61
+ <code>
62
+ with self.batched_params(group["params"]) as batches:
63
+ for p, state, p_names in batches:
64
+ ...
65
+ </code>
66
+
67
+ Args:
68
+ group: a parameter group, which is a list of parameters; should be
69
+ one of self.param_groups.
70
+ group_params_names: name for each parameter in group,
71
+ which is List[str].
72
+ """
73
+ batches = defaultdict(list) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
74
+ batches_names = defaultdict(list) # `batches_names` maps from tuple (dtype_as_str,*shape) to list of str
75
+
76
+ assert len(param_group) == len(group_params_names)
77
+ for p, named_p in zip(param_group, group_params_names):
78
+ key = (str(p.dtype), *p.shape)
79
+ batches[key].append(p)
80
+ batches_names[key].append(named_p)
81
+
82
+ batches_names_keys = list(batches_names.keys())
83
+ sorted_idx = sorted(range(len(batches_names)), key=lambda i: batches_names_keys[i])
84
+ batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx]
85
+ batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]
86
+
87
+ stacked_params_dict = dict()
88
+
89
+ # turn batches into a list, in deterministic order.
90
+ # tuples will contain tuples of (stacked_param, state, stacked_params_names),
91
+ # one for each batch in `batches`.
92
+ tuples = []
93
+
94
+ for batch, batch_names in zip(batches, batches_names):
95
+ p = batch[0]
96
+ # we arbitrarily store the state in the
97
+ # state corresponding to the 1st parameter in the
98
+ # group. class Optimizer will take care of saving/loading state.
99
+ state = self.state[p]
100
+ p_stacked = torch.stack(batch)
101
+ grad = torch.stack([torch.zeros_like(p) if p.grad is None else p.grad for p in batch])
102
+ p_stacked.grad = grad
103
+ stacked_params_dict[key] = p_stacked
104
+ tuples.append((p_stacked, state, batch_names))
105
+
106
+ yield tuples # <-- calling code will do the actual optimization here!
107
+
108
+ for (stacked_params, _state, _names), batch in zip(tuples, batches):
109
+ for i, p in enumerate(batch): # batch is list of Parameter
110
+ p.copy_(stacked_params[i])
111
+
112
+
113
+ class ScaledAdam(BatchedOptimizer):
114
+ """
115
+ Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
116
+ proportional to the norm of that parameter; and also learn the scale of the parameter,
117
+ in log space, subject to upper and lower limits (as if we had factored each parameter as
118
+ param = underlying_param * log_scale.exp())
119
+
120
+
121
+ Args:
122
+ params: The parameters or param_groups to optimize (like other Optimizer subclasses)
123
+ lr: The learning rate. We will typically use a learning rate schedule that starts
124
+ at 0.03 and decreases over time, i.e. much higher than other common
125
+ optimizers.
126
+ clipping_scale: (e.g. 2.0)
127
+ A scale for gradient-clipping: if specified, the normalized gradients
128
+ over the whole model will be clipped to have 2-norm equal to
129
+ `clipping_scale` times the median 2-norm over the most recent period
130
+ of `clipping_update_period` minibatches. By "normalized gradients",
131
+ we mean after multiplying by the rms parameter value for this tensor
132
+ [for non-scalars]; this is appropriate because our update is scaled
133
+ by this quantity.
134
+ betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad.
135
+ Must satisfy 0 < beta1 <= beta2 < 1.
136
+ scalar_lr_scale: A scaling factor on the learning rate, that we use to update the
137
+ scale of each parameter tensor and scalar parameters of the model.
138
+ If each parameter were decomposed
139
+ as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale
140
+ would be the scaling factor on the learning rate of p_scale.
141
+ eps: A general-purpose epsilon to prevent division by zero
142
+ param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
143
+ learning the scale on the parameters (we'll constrain the rms of each non-scalar
144
+ parameter tensor to be >= this value)
145
+ param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
146
+ learning the scale on the parameters (we'll constrain the rms of each non-scalar
147
+ parameter tensor to be <= this value)
148
+ scalar_max: Maximum absolute value for scalar parameters (applicable if your
149
+ model has any parameters with numel() == 1).
150
+ size_update_period: The periodicity, in steps, with which we update the size (scale)
151
+ of the parameter tensor. This is provided to save a little time
152
+ in the update.
153
+ clipping_update_period: if clipping_scale is specified, this is the period (in steps) with which the gradient-norm median used for clipping is re-estimated.
154
+ """
155
+
156
+ def __init__(
157
+ self,
158
+ params,
159
+ lr=3e-02,
160
+ clipping_scale=None,
161
+ betas=(0.9, 0.98),
162
+ scalar_lr_scale=0.1,
163
+ eps=1.0e-08,
164
+ param_min_rms=1.0e-05,
165
+ param_max_rms=3.0,
166
+ scalar_max=10.0,
167
+ size_update_period=4,
168
+ clipping_update_period=100,
169
+ parameters_names=None,
170
+ show_dominant_parameters=True,
171
+ ):
172
+ assert parameters_names is not None, (
173
+ "Please prepare parameters_names,which is a List[List[str]]. Each List[str] is for a groupand each str is for a parameter"
174
+ )
175
+ defaults = dict(
176
+ lr=lr,
177
+ clipping_scale=clipping_scale,
178
+ betas=betas,
179
+ scalar_lr_scale=scalar_lr_scale,
180
+ eps=eps,
181
+ param_min_rms=param_min_rms,
182
+ param_max_rms=param_max_rms,
183
+ scalar_max=scalar_max,
184
+ size_update_period=size_update_period,
185
+ clipping_update_period=clipping_update_period,
186
+ )
187
+
188
+ super(ScaledAdam, self).__init__(params, defaults)
189
+ assert len(self.param_groups) == len(parameters_names)
190
+ self.parameters_names = parameters_names
191
+ self.show_dominant_parameters = show_dominant_parameters
192
+
193
+ def __setstate__(self, state):
194
+ super(ScaledAdam, self).__setstate__(state)
195
+
196
+ @torch.no_grad()
197
+ def step(self, closure=None):
198
+ """Performs a single optimization step.
199
+
200
+ Arguments:
201
+ closure (callable, optional): A closure that reevaluates the model
202
+ and returns the loss.
203
+ """
204
+ loss = None
205
+ if closure is not None:
206
+ with torch.enable_grad():
207
+ loss = closure()
208
+
209
+ batch = True
210
+
211
+ for group, group_params_names in zip(self.param_groups, self.parameters_names):
212
+ with self.batched_params(group["params"], group_params_names) as batches:
213
+ # batches is list of pairs (stacked_param, state). stacked_param is like
214
+ # a regular parameter, and will have a .grad, but the 1st dim corresponds to
215
+ # a stacking dim, it is not a real dim.
216
+
217
+ if len(batches[0][1]) == 0: # if len(first state) == 0: not yet initialized
218
+ clipping_scale = 1
219
+ else:
220
+ clipping_scale = self._get_clipping_scale(group, batches)
221
+
222
+ for p, state, _ in batches:
223
+ # Perform optimization step.
224
+ # grad is not going to be None, we handled that when creating the batches.
225
+ grad = p.grad
226
+ if grad.is_sparse:
227
+ raise RuntimeError("ScaledAdam optimizer does not support sparse gradients")
228
+ # State initialization
229
+ if len(state) == 0:
230
+ self._init_state(group, p, state)
231
+
232
+ self._step_one_batch(group, p, state, clipping_scale)
233
+
234
+ return loss
235
+
236
+ def _init_state(self, group: dict, p: Tensor, state: dict):
237
+ """
238
+ Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p
239
+ is actually the batch dimension, corresponding to batched-together
240
+ parameters of a given shape.
241
+
242
+
243
+ Args:
244
+ group: Dict to look up configuration values.
245
+ p: The parameter that we are initializing the state for
246
+ state: Dict from string to whatever state we are initializing
247
+ """
248
+ size_update_period = group["size_update_period"]
249
+
250
+ state["step"] = 0
251
+
252
+ kwargs = {"device": p.device, "dtype": p.dtype}
253
+
254
+ # 'delta' implements conventional momentum. There are
255
+ # several different kinds of update going on, so rather than
256
+ # compute "exp_avg" like in Adam, we store and decay a
257
+ # parameter-change "delta", which combines all forms of
258
+ # update. this is equivalent to how it's done in Adam,
259
+ # except for the first few steps.
260
+ state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format)
261
+
262
+ batch_size = p.shape[0]
263
+ numel = p.numel() // batch_size
264
+ numel = p.numel()
265
+
266
+ if numel > 1:
267
+ # "param_rms" just periodically records the scalar root-mean-square value of
268
+ # the parameter tensor.
269
+ # it has a shape like (batch_size, 1, 1, 1, 1)
270
+ param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
271
+ state["param_rms"] = param_rms
272
+
273
+ state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
274
+ state["scale_grads"] = torch.zeros(size_update_period, *param_rms.shape, **kwargs)
275
+
276
+ # exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
277
+ state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
278
+
279
+ def _get_clipping_scale(self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]]) -> float:
280
+ """
281
+ Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
282
+ by this amount before applying the rest of the update.
283
+
284
+ Args:
285
+ group: the parameter group, an item in self.param_groups
286
+ tuples: a list of tuples of (param, state, param_names)
287
+ where param is a batched set of parameters,
288
+ with a .grad (1st dim is batch dim)
289
+ and state is the state-dict where optimization parameters are kept.
290
+ param_names is a List[str] while each str is name for a parameter
291
+ in batched set of parameters "param".
292
+ """
293
+ assert len(tuples) >= 1
294
+ clipping_scale = group["clipping_scale"]
295
+ (first_p, first_state, _) = tuples[0]
296
+ step = first_state["step"]
297
+ if clipping_scale is None or step == 0:
298
+ # no clipping. return early on step == 0 because the other
299
+ # parameters' state won't have been initialized yet.
300
+ return 1.0
301
+ clipping_update_period = group["clipping_update_period"]
302
+
303
+ tot_sumsq = torch.tensor(0.0, device=first_p.device)
304
+ for p, state, param_names in tuples:
305
+ grad = p.grad
306
+ if grad.is_sparse:
307
+ raise RuntimeError("ScaledAdam optimizer does not support sparse gradients")
308
+ if p.numel() == p.shape[0]: # a batch of scalars
309
+ tot_sumsq += (grad**2).sum() # sum() to change shape [1] to []
310
+ else:
311
+ tot_sumsq += ((grad * state["param_rms"]) ** 2).sum()
312
+
313
+ tot_norm = tot_sumsq.sqrt()
314
+ if "model_norms" not in first_state:
315
+ first_state["model_norms"] = torch.zeros(clipping_update_period, device=p.device)
316
+ first_state["model_norms"][step % clipping_update_period] = tot_norm
317
+
318
+ if step % clipping_update_period == 0:
319
+ # Print some stats.
320
+ # We don't reach here if step == 0 because we would have returned
321
+ # above.
322
+ sorted_norms = first_state["model_norms"].sort()[0].to("cpu")
323
+ quartiles = []
324
+ for n in range(0, 5):
325
+ index = min(
326
+ clipping_update_period - 1,
327
+ (clipping_update_period // 4) * n,
328
+ )
329
+ quartiles.append(sorted_norms[index].item())
330
+
331
+ median = quartiles[2]
332
+ threshold = clipping_scale * median
333
+ first_state["model_norm_threshold"] = threshold
334
+ percent_clipped = (
335
+ first_state["num_clipped"] * 100.0 / clipping_update_period if "num_clipped" in first_state else 0.0
336
+ )
337
+ first_state["num_clipped"] = 0
338
+ quartiles = " ".join(["%.3e" % x for x in quartiles])
339
+ logging.info(
340
+ f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
341
+ )
342
+
343
+ if step < clipping_update_period:
344
+ return 1.0 # We have not yet estimated a norm to clip to.
345
+ else:
346
+ try:
347
+ model_norm_threshold = first_state["model_norm_threshold"]
348
+ except KeyError:
349
+ logging.info(
350
+ "Warning: model_norm_threshold not in state: possibly you changed config when restarting, adding clipping_scale option?"
351
+ )
352
+ return 1.0
353
+ ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
354
+ if ans < 1.0:
355
+ first_state["num_clipped"] += 1
356
+ if ans < 0.1:
357
+ logging.warning(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}")
358
+ if self.show_dominant_parameters:
359
+ assert p.shape[0] == len(param_names)
360
+ self._show_gradient_dominating_parameter(tuples, tot_sumsq)
361
+ return ans
362
+
363
+ def _show_gradient_dominating_parameter(self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor):
364
+ """
365
+ Show information about the parameter that dominates tot_sumsq.
366
+
367
+ Args:
368
+ tuples: a list of tuples of (param, state, param_names)
369
+ where param is a batched set of parameters,
370
+ with a .grad (1st dim is batch dim)
371
+ and state is the state-dict where optimization parameters are kept.
372
+ param_names is a List[str] while each str is name for a parameter
373
+ in batched set of parameters "param".
374
+ tot_sumsq: sumsq of all parameters. Though it could be calculated
375
+ from tuples, we still pass it to save some time.
376
+ """
377
+ all_sumsq_orig = {}
378
+ for p, state, batch_param_names in tuples:
379
+ # p is a stacked batch parameters.
380
+ batch_grad = p.grad
381
+ if p.numel() == p.shape[0]: # a batch of scalars
382
+ batch_sumsq_orig = batch_grad**2
383
+ # Dummy values used by the following `zip` statement.
384
+ batch_rms_orig = torch.ones(p.shape[0])
385
+ else:
386
+ batch_rms_orig = state["param_rms"]
387
+ batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum(dim=list(range(1, batch_grad.ndim)))
388
+
389
+ for name, sumsq_orig, rms, grad in zip(
390
+ batch_param_names,
391
+ batch_sumsq_orig,
392
+ batch_rms_orig,
393
+ batch_grad,
394
+ ):
395
+ proportion_orig = sumsq_orig / tot_sumsq
396
+ all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
397
+
398
+ assert torch.isclose(
399
+ sum([value[0] for value in all_sumsq_orig.values()]).cpu(),
400
+ torch.tensor(1.0),
401
+ )
402
+ sorted_by_proportion = {
403
+ k: v
404
+ for k, v in sorted(
405
+ all_sumsq_orig.items(),
406
+ key=lambda item: item[1][0],
407
+ reverse=True,
408
+ )
409
+ }
410
+ dominant_param_name = next(iter(sorted_by_proportion))
411
+ (
412
+ dominant_proportion,
413
+ dominant_sumsq,
414
+ dominant_rms,
415
+ dominant_grad,
416
+ ) = sorted_by_proportion[dominant_param_name]
417
+ logging.info(
418
+ f"Parameter Dominating tot_sumsq {dominant_param_name}"
419
+ f" with proportion {dominant_proportion:.2f},"
420
+ f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
421
+ f"={dominant_sumsq:.3e},"
422
+ f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
423
+ f" orig_rms_sq={(dominant_rms**2).item():.3e}"
424
+ )
425
+
426
+ def _step_one_batch(self, group: dict, p: Tensor, state: dict, clipping_scale: float):
427
+ """
428
+ Do the step for one parameter, which is actually going to be a batch of
429
+ `real` parameters, with dim 0 as the batch dim.
430
+ Args:
431
+ group: dict to look up configuration values
432
+ p: parameter to update (actually multiple parameters stacked together
433
+ as a batch)
434
+ state: state-dict for p, to look up the optimizer state
435
+ """
436
+ lr = group["lr"]
437
+ size_update_period = group["size_update_period"]
438
+ beta1 = group["betas"][0]
439
+
440
+ grad = p.grad
441
+ if clipping_scale != 1.0:
442
+ grad = grad * clipping_scale
443
+ step = state["step"]
444
+ delta = state["delta"]
445
+
446
+ delta.mul_(beta1)
447
+ batch_size = p.shape[0]
448
+ numel = p.numel() // batch_size
449
+ if numel > 1:
450
+ # Update the size/scale of p, and set param_rms
451
+ scale_grads = state["scale_grads"]
452
+ scale_grads[step % size_update_period] = (p * grad).sum(dim=list(range(1, p.ndim)), keepdim=True)
453
+ if step % size_update_period == size_update_period - 1:
454
+ param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..)
455
+ param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt())
456
+ if step > 0:
457
+ # self._size_update() learns the overall scale on the
458
+ # parameter, by shrinking or expanding it.
459
+ self._size_update(group, scale_grads, p, state)
460
+
461
+ if numel == 1:
462
+ # For parameters with 1 element we just use regular Adam.
463
+ # Updates delta.
464
+ self._step_scalar(group, p, state)
465
+ else:
466
+ self._step(group, p, state)
467
+
468
+ state["step"] = step + 1
469
+
470
+ def _size_update(
471
+ self,
472
+ group: dict,
473
+ scale_grads: Tensor,
474
+ p: Tensor,
475
+ state: dict,
476
+ ) -> None:
477
+ """
478
+ Called only where p.numel() > 1, this updates the scale of the parameter.
479
+ If we imagine: p = underlying_param * scale.exp(), and we are doing
480
+ gradient descent on underlying param and on scale, this function does the update
481
+ on `scale`.
482
+
483
+ Args:
484
+ group: dict to look up configuration values
485
+ scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing
486
+ grads w.r.t. the scales.
487
+ p: The parameter to update
488
+ state: The state-dict of p
489
+ """
490
+
491
+ param_rms = state["param_rms"]
492
+ beta1, beta2 = group["betas"]
493
+ size_lr = group["lr"] * group["scalar_lr_scale"]
494
+ param_min_rms = group["param_min_rms"]
495
+ param_max_rms = group["param_max_rms"]
496
+ eps = group["eps"]
497
+ step = state["step"]
498
+ batch_size = p.shape[0]
499
+
500
+ size_update_period = scale_grads.shape[0]
501
+ # correct beta2 for the size update period: we will have
502
+ # faster decay at this level.
503
+ beta2_corr = beta2**size_update_period
504
+
505
+ scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..)
506
+ scale_exp_avg_sq.mul_(beta2_corr).add_(
507
+ (scale_grads**2).mean(dim=0), # mean over dim `size_update_period`
508
+ alpha=1 - beta2_corr,
509
+ ) # shape is (batch_size, 1, 1, ...)
510
+
511
+ # The 1st time we reach here is when size_step == 1.
512
+ size_step = (step + 1) // size_update_period
513
+ bias_correction2 = 1 - beta2_corr**size_step
514
+ # we don't bother with bias_correction1; this will help prevent divergence
515
+ # at the start of training.
516
+
517
+ denom = scale_exp_avg_sq.sqrt() + eps
518
+
519
+ scale_step = -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom
520
+
521
+ is_too_small = param_rms < param_min_rms
522
+ is_too_large = param_rms > param_max_rms
523
+
524
+ # when the param gets too small, just don't shrink it any further.
525
+ scale_step.masked_fill_(is_too_small, 0.0)
526
+ # when it gets too large, stop it from getting any larger.
527
+ scale_step.masked_fill_(is_too_large, -size_lr * size_update_period)
528
+ delta = state["delta"]
529
+ # the factor of (1-beta1) relates to momentum.
530
+ delta.add_(p * scale_step, alpha=(1 - beta1))
531
+
532
+ def _step(self, group: dict, p: Tensor, state: dict):
533
+ """
534
+ This function does the core update of self.step(), in the case where the members of
535
+ the batch have more than 1 element.
536
+
537
+ Args:
538
+ group: A dict which will be used to look up configuration values
539
+ p: The parameter to be updated
540
+ grad: The grad of p
541
+ state: The state-dict corresponding to parameter p
542
+
543
+ This function modifies p.
544
+ """
545
+ grad = p.grad
546
+ lr = group["lr"]
547
+ beta1, beta2 = group["betas"]
548
+ eps = group["eps"]
549
+ param_min_rms = group["param_min_rms"]
550
+ step = state["step"]
551
+
552
+ exp_avg_sq = state["exp_avg_sq"]
553
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))
554
+
555
+ this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0)
556
+ bias_correction2 = 1 - beta2 ** (this_step + 1)
557
+ if bias_correction2 < 0.99:
558
+ # note: not in-place.
559
+ exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
560
+
561
+ denom = exp_avg_sq.sqrt()
562
+ denom += eps
563
+ grad = grad / denom
564
+
565
+ alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms)
566
+
567
+ delta = state["delta"]
568
+ delta.add_(grad * alpha)
569
+ p.add_(delta)
570
+
571
+ def _step_scalar(self, group: dict, p: Tensor, state: dict):
572
+ """
573
+ A simplified form of the core update for scalar tensors, where we cannot get a good
574
+ estimate of the parameter rms.
575
+ """
576
+ beta1, beta2 = group["betas"]
577
+ scalar_max = group["scalar_max"]
578
+ eps = group["eps"]
579
+ lr = group["lr"] * group["scalar_lr_scale"]
580
+ grad = p.grad
581
+
582
+ exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,)
583
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
584
+
585
+ # bias_correction2 is like in Adam. Don't bother with bias_correction1;
586
+ # slower update at the start will help stability anyway.
587
+ bias_correction2 = 1 - beta2 ** (state["step"] + 1)
588
+ denom = (exp_avg_sq / bias_correction2).sqrt() + eps
589
+
590
+ delta = state["delta"]
591
+ delta.add_(grad / denom, alpha=-lr * (1 - beta1))
592
+ p.clamp_(min=-scalar_max, max=scalar_max)
593
+ p.add_(delta)
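
A minimal sketch of constructing ScaledAdam, mainly to show how parameters_names (required by the assert in __init__) can be built from named_parameters(); the model, learning rate, and clipping settings are placeholders, and the import path is assumed from this diff:

    import torch
    from torch import nn
    from AR.modules.optim import ScaledAdam

    model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 1025))

    # A single param group, so parameters_names is a List[List[str]] with one inner list,
    # aligned element by element with the parameters handed to the optimizer.
    names = [[name for name, _ in model.named_parameters()]]
    params = [p for _, p in model.named_parameters()]

    opt = ScaledAdam(
        params,
        lr=0.01,
        betas=(0.9, 0.95),
        clipping_scale=2.0,
        parameters_names=names,
        show_dominant_parameters=False,
        clipping_update_period=1000,
    )

    loss = model(torch.randn(8, 256)).pow(2).mean()
    loss.backward()
    opt.step()
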
AR/modules/patched_mha_with_cache.py ADDED
@@ -0,0 +1,428 @@
1
+ from torch.nn.functional import *
2
+ from torch.nn.functional import (
3
+ _mha_shape_check,
4
+ _canonical_mask,
5
+ _none_or_dtype,
6
+ _in_projection_packed,
7
+ )
8
+ import torch
9
+ # Tensor = torch.Tensor
10
+ # from typing import Callable, List, Optional, Tuple, Union
11
+
12
+
13
+ def multi_head_attention_forward_patched(
14
+ query,
15
+ key,
16
+ value,
17
+ embed_dim_to_check,
18
+ num_heads,
19
+ in_proj_weight,
20
+ in_proj_bias,
21
+ bias_k,
22
+ bias_v,
23
+ add_zero_attn,
24
+ dropout_p: float,
25
+ out_proj_weight,
26
+ out_proj_bias,
27
+ training=True,
28
+ key_padding_mask=None,
29
+ need_weights=True,
30
+ attn_mask=None,
31
+ use_separate_proj_weight=False,
32
+ q_proj_weight=None,
33
+ k_proj_weight=None,
34
+ v_proj_weight=None,
35
+ static_k=None,
36
+ static_v=None,
37
+ average_attn_weights=True,
38
+ is_causal=False,
39
+ cache=None,
40
+ ):
41
+ r"""
42
+ Args:
43
+ query, key, value: map a query and a set of key-value pairs to an output.
44
+ See "Attention Is All You Need" for more details.
45
+ embed_dim_to_check: total dimension of the model.
46
+ num_heads: parallel attention heads.
47
+ in_proj_weight, in_proj_bias: input projection weight and bias.
48
+ bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
49
+ add_zero_attn: add a new batch of zeros to the key and
50
+ value sequences at dim=1.
51
+ dropout_p: probability of an element to be zeroed.
52
+ out_proj_weight, out_proj_bias: the output projection weight and bias.
53
+ training: apply dropout if is ``True``.
54
+ key_padding_mask: if provided, specified padding elements in the key will
55
+ be ignored by the attention. This is a binary mask. When the value is True,
56
+ the corresponding value on the attention layer will be filled with -inf.
57
+ need_weights: output attn_output_weights.
58
+ Default: `True`
59
+ Note: `need_weights` defaults to `True`, but it should be set to `False`
60
+ for best performance when attention weights are not needed.
61
+ *Setting `need_weights` to `True`
62
+ leads to a significant performance degradation.*
63
+ attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
64
+ the batches while a 3D mask allows to specify a different mask for the entries of each batch.
65
+ is_causal: If specified, applies a causal mask as attention mask, and ignores
66
+ attn_mask for computing scaled dot product attention.
67
+ Default: ``False``.
68
+ .. warning::
69
+ is_causal provides a hint that the attn_mask is the
70
+ causal mask. Providing incorrect hints can result in
71
+ incorrect execution, including forward and backward
72
+ compatibility.
73
+ use_separate_proj_weight: the function accepts the proj. weights for query, key,
74
+ and value in different forms. If false, in_proj_weight will be used, which is
75
+ a combination of q_proj_weight, k_proj_weight, v_proj_weight.
76
+ q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
77
+ static_k, static_v: static key and value used for attention operators.
78
+ average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across heads.
79
+ Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an effect
80
+ when ``need_weights=True.``. Default: True
81
+
82
+
83
+ Shape:
84
+ Inputs:
85
+ - query: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
86
+ the embedding dimension.
87
+ - key: :math:`(S, E)` or :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
88
+ the embedding dimension.
89
+ - value: :math:`(S, E)` or :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
90
+ the embedding dimension.
91
+ - key_padding_mask: :math:`(S)` or :math:`(N, S)` where N is the batch size, S is the source sequence length.
92
+ If a FloatTensor is provided, it will be directly added to the value.
93
+ If a BoolTensor is provided, the positions with the
94
+ value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
95
+ - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
96
+ 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
97
+ S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
98
+ positions. If a BoolTensor is provided, positions with ``True``
99
+ are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
100
+ is provided, it will be added to the attention weight.
101
+ - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
102
+ N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
103
+ - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
104
+ N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
105
+
106
+ Outputs:
107
+ - attn_output: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
108
+ E is the embedding dimension.
109
+ - attn_output_weights: Only returned when ``need_weights=True``. If ``average_attn_weights=True``, returns
110
+ attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
111
+ :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
112
+ :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
113
+ head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`.
114
+ """
115
+ tens_ops = (
116
+ query,
117
+ key,
118
+ value,
119
+ in_proj_weight,
120
+ in_proj_bias,
121
+ bias_k,
122
+ bias_v,
123
+ out_proj_weight,
124
+ out_proj_bias,
125
+ )
126
+ if has_torch_function(tens_ops):
127
+ return handle_torch_function(
128
+ multi_head_attention_forward,
129
+ tens_ops,
130
+ query,
131
+ key,
132
+ value,
133
+ embed_dim_to_check,
134
+ num_heads,
135
+ in_proj_weight,
136
+ in_proj_bias,
137
+ bias_k,
138
+ bias_v,
139
+ add_zero_attn,
140
+ dropout_p,
141
+ out_proj_weight,
142
+ out_proj_bias,
143
+ training=training,
144
+ key_padding_mask=key_padding_mask,
145
+ need_weights=need_weights,
146
+ attn_mask=attn_mask,
147
+ is_causal=is_causal,
148
+ use_separate_proj_weight=use_separate_proj_weight,
149
+ q_proj_weight=q_proj_weight,
150
+ k_proj_weight=k_proj_weight,
151
+ v_proj_weight=v_proj_weight,
152
+ static_k=static_k,
153
+ static_v=static_v,
154
+ average_attn_weights=average_attn_weights,
155
+ cache=cache,
156
+ )
157
+
158
+ is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)
159
+
160
+ # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
161
+ # is batched, run the computation and before returning squeeze the
162
+ # batch dimension so that the output doesn't carry this temporary batch dimension.
163
+ if not is_batched:
164
+ # unsqueeze if the input is unbatched
165
+ query = query.unsqueeze(1)
166
+ key = key.unsqueeze(1)
167
+ value = value.unsqueeze(1)
168
+ if key_padding_mask is not None:
169
+ key_padding_mask = key_padding_mask.unsqueeze(0)
170
+
171
+ # set up shape vars
172
+ tgt_len, bsz, embed_dim = query.shape
173
+ src_len, _, _ = key.shape
174
+
175
+ key_padding_mask = _canonical_mask(
176
+ mask=key_padding_mask,
177
+ mask_name="key_padding_mask",
178
+ other_type=_none_or_dtype(attn_mask),
179
+ other_name="attn_mask",
180
+ target_type=query.dtype,
181
+ )
182
+
183
+ if is_causal and attn_mask is None:
184
+ raise RuntimeError(
185
+ "Need attn_mask if specifying the is_causal hint. "
186
+ "You may use the Transformer module method "
187
+ "`generate_square_subsequent_mask` to create this mask."
188
+ )
189
+
190
+ if is_causal and key_padding_mask is None and not need_weights:
191
+ # when we have a kpm or need weights, we need attn_mask
192
+ # Otherwise, we use the is_causal hint go as is_causal
193
+ # indicator to SDPA.
194
+ attn_mask = None
195
+ else:
196
+ attn_mask = _canonical_mask(
197
+ mask=attn_mask,
198
+ mask_name="attn_mask",
199
+ other_type=None,
200
+ other_name="",
201
+ target_type=query.dtype,
202
+ check_other=False,
203
+ )
204
+
205
+ if key_padding_mask is not None:
206
+ # We have the attn_mask, and use that to merge kpm into it.
207
+ # Turn off use of is_causal hint, as the merged mask is no
208
+ # longer causal.
209
+ is_causal = False
210
+
211
+ assert embed_dim == embed_dim_to_check, (
212
+ f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
213
+ )
214
+ if isinstance(embed_dim, torch.Tensor):
215
+ # embed_dim can be a tensor when JIT tracing
216
+ head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
217
+ else:
218
+ head_dim = embed_dim // num_heads
219
+ assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
220
+ if use_separate_proj_weight:
221
+ # allow MHA to have different embedding dimensions when separate projection weights are used
222
+ assert key.shape[:2] == value.shape[:2], (
223
+ f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
224
+ )
225
+ else:
226
+ assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
227
+
228
+ #
229
+ # compute in-projection
230
+ #
231
+ if not use_separate_proj_weight:
232
+ assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None"
233
+ q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
234
+ else:
235
+ assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
236
+ assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
237
+ assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
238
+ if in_proj_bias is None:
239
+ b_q = b_k = b_v = None
240
+ else:
241
+ b_q, b_k, b_v = in_proj_bias.chunk(3)
242
+ q, k, v = _in_projection(
243
+ query,
244
+ key,
245
+ value,
246
+ q_proj_weight,
247
+ k_proj_weight,
248
+ v_proj_weight,
249
+ b_q,
250
+ b_k,
251
+ b_v,
252
+ )
253
+ if cache != None:
254
+ if cache["first_infer"] == 1:
255
+ cache["k"][cache["stage"]] = k
256
+ # print(0,cache["k"].shape)
257
+ cache["v"][cache["stage"]] = v
258
+ else:  # each of the 12 layers keeps its own cached k/v
259
+ # print(1,cache["k"].shape)
260
+ cache["k"][cache["stage"]] = torch.cat(
261
+ [cache["k"][cache["stage"]], k], 0
262
+ )  # the time axis was originally dim 1, but the projection may have transposed it, so time is now dim 0
263
+ cache["v"][cache["stage"]] = torch.cat([cache["v"][cache["stage"]], v], 0)
264
+ # print(2, cache["k"].shape)
265
+ src_len = cache["k"][cache["stage"]].shape[0]
266
+ k = cache["k"][cache["stage"]]
267
+ v = cache["v"][cache["stage"]]
268
+ # if attn_mask is not None:
269
+ # attn_mask=attn_mask[-1:,]
270
+ # print(attn_mask.shape,attn_mask)
271
+ cache["stage"] = (cache["stage"] + 1) % cache["all_stage"]
272
+ # print(2333,cache)
273
+ # prep attention mask
274
+
275
+ attn_mask = _canonical_mask(
276
+ mask=attn_mask,
277
+ mask_name="attn_mask",
278
+ other_type=None,
279
+ other_name="",
280
+ target_type=q.dtype,
281
+ check_other=False,
282
+ )
283
+
284
+ if attn_mask is not None:
285
+ # ensure attn_mask's dim is 3
286
+ if attn_mask.dim() == 2:
287
+ correct_2d_size = (tgt_len, src_len)
288
+ if attn_mask.shape != correct_2d_size:
289
+ raise RuntimeError(
290
+ f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}."
291
+ )
292
+ attn_mask = attn_mask.unsqueeze(0)
293
+ elif attn_mask.dim() == 3:
294
+ correct_3d_size = (bsz * num_heads, tgt_len, src_len)
295
+ if attn_mask.shape != correct_3d_size:
296
+ raise RuntimeError(
297
+ f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
298
+ )
299
+ else:
300
+ raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
301
+
302
+ # add bias along batch dimension (currently second)
303
+ if bias_k is not None and bias_v is not None:
304
+ assert static_k is None, "bias cannot be added to static key."
305
+ assert static_v is None, "bias cannot be added to static value."
306
+ k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
307
+ v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
308
+ if attn_mask is not None:
309
+ attn_mask = pad(attn_mask, (0, 1))
310
+ if key_padding_mask is not None:
311
+ key_padding_mask = pad(key_padding_mask, (0, 1))
312
+ else:
313
+ assert bias_k is None
314
+ assert bias_v is None
315
+
316
+ #
317
+ # reshape q, k, v for multihead attention and make em batch first
318
+ #
319
+ q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
320
+ if static_k is None:
321
+ k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
322
+ else:
323
+ # TODO finish disentangling control flow so we don't do in-projections when statics are passed
324
+ assert static_k.size(0) == bsz * num_heads, (
325
+ f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
326
+ )
327
+ assert static_k.size(2) == head_dim, f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
328
+ k = static_k
329
+ if static_v is None:
330
+ v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
331
+ else:
332
+ # TODO finish disentangling control flow so we don't do in-projections when statics are passed
333
+ assert static_v.size(0) == bsz * num_heads, (
334
+ f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
335
+ )
336
+ assert static_v.size(2) == head_dim, f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
337
+ v = static_v
338
+
339
+ # add zero attention along batch dimension (now first)
340
+ if add_zero_attn:
341
+ zero_attn_shape = (bsz * num_heads, 1, head_dim)
342
+ k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
343
+ v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)
344
+ if attn_mask is not None:
345
+ attn_mask = pad(attn_mask, (0, 1))
346
+ if key_padding_mask is not None:
347
+ key_padding_mask = pad(key_padding_mask, (0, 1))
348
+
349
+ # update source sequence length after adjustments
350
+ src_len = k.size(1)
351
+
352
+ # merge key padding and attention masks
353
+ if key_padding_mask is not None:
354
+ assert key_padding_mask.shape == (
355
+ bsz,
356
+ src_len,
357
+ ), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
358
+ key_padding_mask = (
359
+ key_padding_mask.view(bsz, 1, 1, src_len).expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
360
+ )
361
+ if attn_mask is None:
362
+ attn_mask = key_padding_mask
363
+ else:
364
+ attn_mask = attn_mask + key_padding_mask
365
+
366
+ # adjust dropout probability
367
+ if not training:
368
+ dropout_p = 0.0
369
+
370
+ #
371
+ # (deep breath) calculate attention and out projection
372
+ #
373
+
374
+ if need_weights:
375
+ B, Nt, E = q.shape
376
+ q_scaled = q / math.sqrt(E)
377
+
378
+ assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"
379
+
380
+ if attn_mask is not None:
381
+ attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
382
+ else:
383
+ attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
384
+ attn_output_weights = softmax(attn_output_weights, dim=-1)
385
+ if dropout_p > 0.0:
386
+ attn_output_weights = dropout(attn_output_weights, p=dropout_p)
387
+
388
+ attn_output = torch.bmm(attn_output_weights, v)
389
+
390
+ attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
391
+ attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
392
+ attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
393
+
394
+ # optionally average attention weights over heads
395
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
396
+ if average_attn_weights:
397
+ attn_output_weights = attn_output_weights.mean(dim=1)
398
+
399
+ if not is_batched:
400
+ # squeeze the output if input was unbatched
401
+ attn_output = attn_output.squeeze(1)
402
+ attn_output_weights = attn_output_weights.squeeze(0)
403
+ return attn_output, attn_output_weights
404
+ else:
405
+ # attn_mask can be either (L,S) or (N*num_heads, L, S)
406
+ # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
407
+ # in order to match the input for SDPA of (N, num_heads, L, S)
408
+ if attn_mask is not None:
409
+ if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
410
+ attn_mask = attn_mask.unsqueeze(0)
411
+ else:
412
+ attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)
413
+
414
+ q = q.view(bsz, num_heads, tgt_len, head_dim)
415
+ k = k.view(bsz, num_heads, src_len, head_dim)
416
+ v = v.view(bsz, num_heads, src_len, head_dim)
417
+
418
+ # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
419
+ attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
420
+
421
+ attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
422
+
423
+ attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
424
+ attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
425
+ if not is_batched:
426
+ # squeeze the output if input was unbatched
427
+ attn_output = attn_output.squeeze(1)
428
+ return attn_output, None
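
The `cache` handled above is assumed to be a single dict shared by all of the decoder's attention layers and cycled via "stage"; a minimal sketch of that layout and of the prefill/append behaviour the function implements, with the layer count and tensor sizes as illustrative assumptions:

    import torch

    num_layers, embed_dim = 24, 512
    cache = {
        "all_stage": num_layers,      # number of attention layers sharing this dict
        "stage": 0,                   # slot index for the next attention call
        "first_infer": 1,             # 1 on the prefill pass, 0 on incremental decoding steps
        "k": [None] * num_layers,     # per-layer key cache, time-major: (time, batch, embed_dim)
        "v": [None] * num_layers,
    }

    # prefill: the whole prompt's projected k/v are stored for layer 0
    k = v = torch.randn(50, 1, embed_dim)
    cache["k"][0], cache["v"][0] = k, v
    cache["stage"] = (cache["stage"] + 1) % cache["all_stage"]

    # incremental step: one new position is concatenated onto the cached keys/values
    cache["first_infer"] = 0
    k_new = v_new = torch.randn(1, 1, embed_dim)
    cache["k"][0] = torch.cat([cache["k"][0], k_new], 0)    # now (51, 1, embed_dim)
    cache["v"][0] = torch.cat([cache["v"][0], v_new], 0)
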
AR/modules/patched_mha_with_cache_onnx.py ADDED
@@ -0,0 +1,85 @@
1
+ from torch.nn.functional import *
2
+ from torch.nn.functional import (
3
+ _canonical_mask,
4
+ )
5
+
6
+
7
+ def multi_head_attention_forward_patched(
8
+ query,
9
+ key,
10
+ value,
11
+ embed_dim_to_check: int,
12
+ num_heads: int,
13
+ in_proj_weight,
14
+ in_proj_bias: Optional[Tensor],
15
+ bias_k: Optional[Tensor],
16
+ bias_v: Optional[Tensor],
17
+ add_zero_attn: bool,
18
+ dropout_p: float,
19
+ out_proj_weight: Tensor,
20
+ out_proj_bias: Optional[Tensor],
21
+ training: bool = True,
22
+ key_padding_mask: Optional[Tensor] = None,
23
+ need_weights: bool = True,
24
+ attn_mask: Optional[Tensor] = None,
25
+ use_separate_proj_weight: bool = False,
26
+ q_proj_weight: Optional[Tensor] = None,
27
+ k_proj_weight: Optional[Tensor] = None,
28
+ v_proj_weight: Optional[Tensor] = None,
29
+ static_k: Optional[Tensor] = None,
30
+ static_v: Optional[Tensor] = None,
31
+ average_attn_weights: bool = True,
32
+ is_causal: bool = False,
33
+ cache=None,
34
+ ) -> Tuple[Tensor, Optional[Tensor]]:
35
+ # set up shape vars
36
+ _, _, embed_dim = query.shape
37
+ attn_mask = _canonical_mask(
38
+ mask=attn_mask,
39
+ mask_name="attn_mask",
40
+ other_type=None,
41
+ other_name="",
42
+ target_type=query.dtype,
43
+ check_other=False,
44
+ )
45
+ head_dim = embed_dim // num_heads
46
+
47
+ proj_qkv = linear(query, in_proj_weight, in_proj_bias)
48
+ proj_qkv = proj_qkv.unflatten(-1, (3, query.size(-1))).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
49
+ q, k, v = proj_qkv[0], proj_qkv[1], proj_qkv[2]
50
+
51
+ if cache["first_infer"] == 1:
52
+ cache["k"][cache["stage"]] = k
53
+ cache["v"][cache["stage"]] = v
54
+ else:
55
+ cache["k"][cache["stage"]] = torch.cat([cache["k"][cache["stage"]][:-1], k], 0)
56
+ cache["v"][cache["stage"]] = torch.cat([cache["v"][cache["stage"]][:-1], v], 0)
57
+ k = cache["k"][cache["stage"]]
58
+ v = cache["v"][cache["stage"]]
59
+ cache["stage"] = (cache["stage"] + 1) % cache["all_stage"]
60
+
61
+ attn_mask = _canonical_mask(
62
+ mask=attn_mask,
63
+ mask_name="attn_mask",
64
+ other_type=None,
65
+ other_name="",
66
+ target_type=q.dtype,
67
+ check_other=False,
68
+ )
69
+ attn_mask = attn_mask.unsqueeze(0)
70
+
71
+ q = q.view(-1, num_heads, head_dim).transpose(0, 1)
72
+ k = k.view(-1, num_heads, head_dim).transpose(0, 1)
73
+ v = v.view(-1, num_heads, head_dim).transpose(0, 1)
74
+
75
+ dropout_p = 0.0
76
+ attn_mask = attn_mask.unsqueeze(0)
77
+ q = q.view(num_heads, -1, head_dim).unsqueeze(0)
78
+ k = k.view(num_heads, -1, head_dim).unsqueeze(0)
79
+ v = v.view(num_heads, -1, head_dim).unsqueeze(0)
80
+ attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
81
+ attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(-1, embed_dim)
82
+ attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
83
+ attn_output = attn_output.view(-1, 1, attn_output.size(1))
84
+
85
+ return attn_output
AR/modules/scaling.py ADDED
@@ -0,0 +1,320 @@
1
+ # Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
2
+ #
3
+ # See ../../../../LICENSE for clarification regarding multiple authors
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import random
17
+ from typing import Optional
18
+ from typing import Tuple
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ from torch import Tensor
23
+
24
+
25
+ class DoubleSwishFunction(torch.autograd.Function):
26
+ """
27
+ double_swish(x) = x * torch.sigmoid(x-1)
28
+ This is a definition, originally motivated by its close numerical
29
+ similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).
30
+
31
+ Memory-efficient derivative computation:
32
+ double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
33
+ double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x).
34
+ Now, s'(x) = s(x) * (1-s(x)).
35
+ double_swish'(x) = x * s'(x) + s(x).
36
+ = x * s(x) * (1-s(x)) + s(x).
37
+ = double_swish(x) * (1-s(x)) + s(x)
38
+ ... so we just need to remember s(x) but not x itself.
39
+ """
40
+
41
+ @staticmethod
42
+ def forward(ctx, x: Tensor) -> Tensor:
43
+ requires_grad = x.requires_grad
44
+ x_dtype = x.dtype
45
+ if x.dtype == torch.float16:
46
+ x = x.to(torch.float32)
47
+
48
+ s = torch.sigmoid(x - 1.0)
49
+ y = x * s
50
+
51
+ if requires_grad:
52
+ deriv = y * (1 - s) + s
53
+ # notes on derivative of x * sigmoid(x - 1):
54
+ # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29
55
+ # min \simeq -0.043638. Take floor as -0.043637 so it's a lower bound
56
+ # max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound.
57
+ # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which
58
+ # floors), should be expectation-preserving.
59
+ floor = -0.043637
60
+ ceil = 1.2
61
+ d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(deriv)
62
+ if __name__ == "__main__":
63
+ # for self-testing only.
64
+ assert d_scaled.min() >= 0.0
65
+ assert d_scaled.max() < 256.0
66
+ d_int = d_scaled.to(torch.uint8)
67
+ ctx.save_for_backward(d_int)
68
+ if x.dtype == torch.float16 or torch.is_autocast_enabled():
69
+ y = y.to(torch.float16)
70
+ return y
71
+
72
+ @staticmethod
73
+ def backward(ctx, y_grad: Tensor) -> Tensor:
74
+ (d,) = ctx.saved_tensors
75
+ # the same constants as used in forward pass.
76
+ floor = -0.043637
77
+ ceil = 1.2
78
+ d = d * ((ceil - floor) / 255.0) + floor
79
+ return y_grad * d
80
+
81
+
82
+ class DoubleSwish(torch.nn.Module):
83
+ def forward(self, x: Tensor) -> Tensor:
84
+ """Return double-swish activation function which is an approximation to Swish(Swish(x)),
85
+ that we approximate closely with x * sigmoid(x-1).
86
+ """
87
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
88
+ return x * torch.sigmoid(x - 1.0)
89
+ return DoubleSwishFunction.apply(x)
90
+
91
+
92
+ class ActivationBalancerFunction(torch.autograd.Function):
93
+ @staticmethod
94
+ def forward(
95
+ ctx,
96
+ x: Tensor,
97
+ scale_factor: Tensor,
98
+ sign_factor: Optional[Tensor],
99
+ channel_dim: int,
100
+ ) -> Tensor:
101
+ if channel_dim < 0:
102
+ channel_dim += x.ndim
103
+ ctx.channel_dim = channel_dim
104
+ xgt0 = x > 0
105
+ if sign_factor is None:
106
+ ctx.save_for_backward(xgt0, scale_factor)
107
+ else:
108
+ ctx.save_for_backward(xgt0, scale_factor, sign_factor)
109
+ return x
110
+
111
+ @staticmethod
112
+ def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
113
+ if len(ctx.saved_tensors) == 3:
114
+ xgt0, scale_factor, sign_factor = ctx.saved_tensors
115
+ for _ in range(ctx.channel_dim, x_grad.ndim - 1):
116
+ scale_factor = scale_factor.unsqueeze(-1)
117
+ sign_factor = sign_factor.unsqueeze(-1)
118
+ factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
119
+ else:
120
+ xgt0, scale_factor = ctx.saved_tensors
121
+ for _ in range(ctx.channel_dim, x_grad.ndim - 1):
122
+ scale_factor = scale_factor.unsqueeze(-1)
123
+ factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
124
+ neg_delta_grad = x_grad.abs() * factor
125
+ return (
126
+ x_grad - neg_delta_grad,
127
+ None,
128
+ None,
129
+ None,
130
+ )
131
+
132
+
133
+ def _compute_scale_factor(
134
+ x: Tensor,
135
+ channel_dim: int,
136
+ min_abs: float,
137
+ max_abs: float,
138
+ gain_factor: float,
139
+ max_factor: float,
140
+ ) -> Tensor:
141
+ if channel_dim < 0:
142
+ channel_dim += x.ndim
143
+ sum_dims = [d for d in range(x.ndim) if d != channel_dim]
144
+ x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)
145
+
146
+ if min_abs == 0.0:
147
+ below_threshold = 0.0
148
+ else:
149
+ # below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if
150
+ # x_abs_mean < min_abs.
151
+ below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(min=0, max=max_factor)
152
+
153
+ above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(min=0, max=max_factor)
154
+
155
+ return below_threshold - above_threshold
156
+
157
+
158
+ def _compute_sign_factor(
159
+ x: Tensor,
160
+ channel_dim: int,
161
+ min_positive: float,
162
+ max_positive: float,
163
+ gain_factor: float,
164
+ max_factor: float,
165
+ ) -> Tensor:
166
+ if channel_dim < 0:
167
+ channel_dim += x.ndim
168
+ sum_dims = [d for d in range(x.ndim) if d != channel_dim]
169
+ proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims)
170
+ if min_positive == 0.0:
171
+ factor1 = 0.0
172
+ else:
173
+ # 0 if proportion_positive >= min_positive, else can be
174
+ # as large as max_factor.
175
+ factor1 = ((min_positive - proportion_positive) * (gain_factor / min_positive)).clamp_(min=0, max=max_factor)
176
+
177
+ if max_positive == 1.0:
178
+ factor2 = 0.0
179
+ else:
180
+ # 0 if self.proportion_positive <= max_positive, else can be
181
+ # as large as -max_factor.
182
+ factor2 = ((proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive))).clamp_(
183
+ min=0, max=max_factor
184
+ )
185
+ sign_factor = factor1 - factor2
186
+ # require min_positive != 0 or max_positive != 1:
187
+ assert not isinstance(sign_factor, float)
188
+ return sign_factor
189
+
190
+
191
+ class ActivationBalancer(torch.nn.Module):
192
+ """
193
+ Modifies the backpropped derivatives of a function to try to encourage, for
194
+ each channel, that it is positive at least a proportion `threshold` of the
195
+ time. It does this by multiplying negative derivative values by up to
196
+ (1+max_factor), and positive derivative values by up to (1-max_factor),
197
+ interpolated from 1 at the threshold to those extremal values when none
198
+ of the inputs are positive.
199
+
200
+ Args:
201
+ num_channels: the number of channels
202
+ channel_dim: the dimension/axis corresponding to the channel, e.g.
203
+ -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
204
+ min_positive: the minimum, per channel, of the proportion of the time
205
+ that (x > 0), below which we start to modify the derivatives.
206
+ max_positive: the maximum, per channel, of the proportion of the time
207
+ that (x > 0), above which we start to modify the derivatives.
208
+ max_factor: the maximum factor by which we modify the derivatives for
209
+ either the sign constraint or the magnitude constraint;
210
+ e.g. with max_factor=0.02, the derivatives would be multiplied by
211
+ values in the range [0.98..1.02].
212
+ sign_gain_factor: determines the 'gain' with which we increase the
213
+ change in gradient once the constraints on min_positive and max_positive
214
+ are violated.
215
+ scale_gain_factor: determines the 'gain' with which we increase the
216
+ change in gradient once the constraints on min_abs and max_abs
217
+ are violated.
218
+ min_abs: the minimum average-absolute-value difference from the mean
219
+ value per channel, which we allow, before we start to modify
220
+ the derivatives to prevent this.
221
+ max_abs: the maximum average-absolute-value difference from the mean
222
+ value per channel, which we allow, before we start to modify
223
+ the derivatives to prevent this.
224
+ min_prob: determines the minimum probability with which we modify the
225
+ gradients for the {min,max}_positive and {min,max}_abs constraints,
226
+ on each forward(). This is done randomly to prevent all layers
227
+ from doing it at the same time. Early in training we may use
228
+ higher probabilities than this; it will decay to this value.
229
+ """
230
+
231
+ def __init__(
232
+ self,
233
+ num_channels: int,
234
+ channel_dim: int,
235
+ min_positive: float = 0.05,
236
+ max_positive: float = 0.95,
237
+ max_factor: float = 0.04,
238
+ sign_gain_factor: float = 0.01,
239
+ scale_gain_factor: float = 0.02,
240
+ min_abs: float = 0.2,
241
+ max_abs: float = 100.0,
242
+ min_prob: float = 0.1,
243
+ ):
244
+ super(ActivationBalancer, self).__init__()
245
+ self.num_channels = num_channels
246
+ self.channel_dim = channel_dim
247
+ self.min_positive = min_positive
248
+ self.max_positive = max_positive
249
+ self.max_factor = max_factor
250
+ self.min_abs = min_abs
251
+ self.max_abs = max_abs
252
+ self.min_prob = min_prob
253
+ self.sign_gain_factor = sign_gain_factor
254
+ self.scale_gain_factor = scale_gain_factor
255
+
256
+ # count measures how many times the forward() function has been called.
257
+ # We occasionally sync this to a tensor called `count`, that exists to
258
+ # make sure it is synced to disk when we load and save the model.
259
+ self.cpu_count = 0
260
+ self.register_buffer("count", torch.tensor(0, dtype=torch.int64))
261
+
262
+ def forward(self, x: Tensor) -> Tensor:
263
+ if torch.jit.is_scripting() or not x.requires_grad or torch.jit.is_tracing():
264
+ return _no_op(x)
265
+
266
+ count = self.cpu_count
267
+ self.cpu_count += 1
268
+
269
+ if random.random() < 0.01:
270
+ # Occasionally sync self.cpu_count with self.count.
271
+ # count affects the decay of 'prob'. don't do this on every iter,
272
+ # because syncing with the GPU is slow.
273
+ self.cpu_count = max(self.cpu_count, self.count.item())
274
+ self.count.fill_(self.cpu_count)
275
+
276
+ # the prob of doing some work exponentially decreases from 0.5 till it hits
277
+ # a floor at min_prob (==0.1, by default)
278
+ prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0)))
279
+
280
+ if random.random() < prob:
281
+ sign_gain_factor = 0.5
282
+ if self.min_positive != 0.0 or self.max_positive != 1.0:
283
+ sign_factor = _compute_sign_factor(
284
+ x,
285
+ self.channel_dim,
286
+ self.min_positive,
287
+ self.max_positive,
288
+ gain_factor=self.sign_gain_factor / prob,
289
+ max_factor=self.max_factor,
290
+ )
291
+ else:
292
+ sign_factor = None
293
+
294
+ scale_factor = _compute_scale_factor(
295
+ x.detach(),
296
+ self.channel_dim,
297
+ min_abs=self.min_abs,
298
+ max_abs=self.max_abs,
299
+ gain_factor=self.scale_gain_factor / prob,
300
+ max_factor=self.max_factor,
301
+ )
302
+ return ActivationBalancerFunction.apply(
303
+ x,
304
+ scale_factor,
305
+ sign_factor,
306
+ self.channel_dim,
307
+ )
308
+ else:
309
+ return _no_op(x)
310
+
311
+
312
+ def BalancedDoubleSwish(d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25) -> nn.Sequential:
313
+ """
314
+ ActivationBalancer -> DoubleSwish
315
+ """
316
+ balancer = ActivationBalancer(d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob)
317
+ return nn.Sequential(
318
+ balancer,
319
+ DoubleSwish(),
320
+ )
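To make the intent of the new activation concrete, a small hedged usage sketch (assuming the file is importable as `AR.modules.scaling`, matching the paths in this commit): it compares `DoubleSwish` with its motivating definition `swish(swish(x))` and exercises the uint8-quantized backward pass of `DoubleSwishFunction`.

```python
# Sketch only: compare DoubleSwish with swish(swish(x)) and run its backward pass.
import torch
from AR.modules.scaling import DoubleSwish  # import path assumed from this commit


def swish(t: torch.Tensor) -> torch.Tensor:
    return t * torch.sigmoid(t)


x = torch.linspace(-4.0, 4.0, steps=9, requires_grad=True)
y = DoubleSwish()(x)                     # x * sigmoid(x - 1)

print("max |double_swish - swish(swish(x))|:", (y - swish(swish(x))).abs().max().item())

y.sum().backward()                       # gradient uses the quantized derivative saved in forward()
print("grad sample:", x.grad[:3])
```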
AR/modules/transformer.py ADDED
@@ -0,0 +1,362 @@
1
+ # modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/transformer.py
2
+ import copy
3
+ import numbers
4
+ from functools import partial
5
+ from typing import Any
6
+ from typing import Callable
7
+ from typing import List
8
+ from typing import Optional
9
+ from typing import Tuple
10
+ from typing import Union
11
+
12
+ import torch
13
+ from AR.modules.activation import MultiheadAttention
14
+ from AR.modules.scaling import BalancedDoubleSwish
15
+ from torch import nn
16
+ from torch import Tensor
17
+ from torch.nn import functional as F
18
+
19
+ _shape_t = Union[int, List[int], torch.Size]
20
+
21
+
22
+ class LayerNorm(nn.Module):
23
+ __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
24
+ normalized_shape: Tuple[int, ...]
25
+ eps: float
26
+ elementwise_affine: bool
27
+
28
+ def __init__(
29
+ self,
30
+ normalized_shape: _shape_t,
31
+ eps: float = 1e-5,
32
+ elementwise_affine: bool = True,
33
+ device=None,
34
+ dtype=None,
35
+ ) -> None:
36
+ factory_kwargs = {"device": device, "dtype": dtype}
37
+ super(LayerNorm, self).__init__()
38
+ if isinstance(normalized_shape, numbers.Integral):
39
+ # mypy error: incompatible types in assignment
40
+ normalized_shape = (normalized_shape,) # type: ignore[assignment]
41
+ self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
42
+ self.eps = eps
43
+ self.elementwise_affine = elementwise_affine
44
+ if self.elementwise_affine:
45
+ self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
46
+ self.bias = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
47
+ else:
48
+ self.register_parameter("weight", None)
49
+ self.register_parameter("bias", None)
50
+
51
+ self.reset_parameters()
52
+
53
+ def reset_parameters(self) -> None:
54
+ if self.elementwise_affine:
55
+ nn.init.ones_(self.weight)
56
+ nn.init.zeros_(self.bias)
57
+
58
+ def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
59
+ if isinstance(input, tuple):
60
+ input, embedding = input
61
+ return (
62
+ F.layer_norm(
63
+ input,
64
+ self.normalized_shape,
65
+ self.weight,
66
+ self.bias,
67
+ self.eps,
68
+ ),
69
+ embedding,
70
+ )
71
+
72
+ assert embedding is None
73
+ return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
74
+
75
+ def extra_repr(self) -> str:
76
+ return "{normalized_shape}, eps={eps}, elementwise_affine={elementwise_affine}".format(**self.__dict__)
77
+
78
+
79
+ class IdentityNorm(nn.Module):
80
+ def __init__(
81
+ self,
82
+ d_model: int,
83
+ eps: float = 1e-5,
84
+ device=None,
85
+ dtype=None,
86
+ ) -> None:
87
+ super(IdentityNorm, self).__init__()
88
+
89
+ def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
90
+ if isinstance(input, tuple):
91
+ return input
92
+
93
+ assert embedding is None
94
+ return input
95
+
96
+
97
+ class TransformerEncoder(nn.Module):
98
+ r"""TransformerEncoder is a stack of N encoder layers. Users can build the
99
+ BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.
100
+
101
+ Args:
102
+ encoder_layer: an instance of the TransformerEncoderLayer() class (required).
103
+ num_layers: the number of sub-encoder-layers in the encoder (required).
104
+ norm: the layer normalization component (optional).
105
+ enable_nested_tensor: if True, input will automatically convert to nested tensor
106
+ (and convert back on output). This will improve the overall performance of
107
+ TransformerEncoder when padding rate is high. Default: ``True`` (enabled).
108
+
109
+ Examples::
110
+ >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
111
+ >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
112
+ >>> src = torch.rand(10, 32, 512)
113
+ >>> out = transformer_encoder(src)
114
+ """
115
+
116
+ __constants__ = ["norm"]
117
+
118
+ def __init__(self, encoder_layer, num_layers, norm=None):
119
+ super(TransformerEncoder, self).__init__()
120
+ self.layers = _get_clones(encoder_layer, num_layers)
121
+ self.num_layers = num_layers
122
+ self.norm = norm
123
+
124
+ def forward(
125
+ self,
126
+ src: Tensor,
127
+ mask: Optional[Tensor] = None,
128
+ src_key_padding_mask: Optional[Tensor] = None,
129
+ return_layer_states: bool = False,
130
+ cache=None,
131
+ ) -> Tensor:
132
+ r"""Pass the input through the encoder layers in turn.
133
+
134
+ Args:
135
+ src: the sequence to the encoder (required).
136
+ mask: the mask for the src sequence (optional).
137
+ src_key_padding_mask: the mask for the src keys per batch (optional).
138
+ return_layer_states: return layers' state (optional).
139
+
140
+ Shape:
141
+ see the docs in Transformer class.
142
+ """
143
+ if return_layer_states:
144
+ layer_states = [] # layers' output
145
+ output = src
146
+ for mod in self.layers:
147
+ output = mod(
148
+ output,
149
+ src_mask=mask,
150
+ src_key_padding_mask=src_key_padding_mask,
151
+ cache=cache,
152
+ )
153
+ layer_states.append(output[0])
154
+
155
+ if self.norm is not None:
156
+ output = self.norm(output)
157
+
158
+ return layer_states, output
159
+
160
+ output = src
161
+ for mod in self.layers:
162
+ output = mod(
163
+ output,
164
+ src_mask=mask,
165
+ src_key_padding_mask=src_key_padding_mask,
166
+ cache=cache,
167
+ )
168
+
169
+ if self.norm is not None:
170
+ output = self.norm(output)
171
+
172
+ return output
173
+
174
+
175
+ class TransformerEncoderLayer(nn.Module):
176
+ __constants__ = ["batch_first", "norm_first"]
177
+
178
+ def __init__(
179
+ self,
180
+ d_model: int,
181
+ nhead: int,
182
+ dim_feedforward: int = 2048,
183
+ dropout: float = 0.1,
184
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
185
+ batch_first: bool = False,
186
+ norm_first: bool = False,
187
+ device=None,
188
+ dtype=None,
189
+ linear1_self_attention_cls: nn.Module = nn.Linear,
190
+ linear2_self_attention_cls: nn.Module = nn.Linear,
191
+ linear1_feedforward_cls: nn.Module = nn.Linear,
192
+ linear2_feedforward_cls: nn.Module = nn.Linear,
193
+ layer_norm_cls: nn.Module = LayerNorm,
194
+ layer_norm_eps: float = 1e-5,
195
+ adaptive_layer_norm=False,
196
+ ) -> None:
197
+ factory_kwargs = {"device": device, "dtype": dtype}
198
+ super(TransformerEncoderLayer, self).__init__()
199
+ # print(233333333333,d_model,nhead)
200
+ # import os
201
+ # os._exit(2333333)
202
+ self.self_attn = MultiheadAttention(
203
+ d_model, # 512 16
204
+ nhead,
205
+ dropout=dropout,
206
+ batch_first=batch_first,
207
+ linear1_cls=linear1_self_attention_cls,
208
+ linear2_cls=linear2_self_attention_cls,
209
+ **factory_kwargs,
210
+ )
211
+
212
+ # Implementation of Feedforward model
213
+ self.linear1 = linear1_feedforward_cls(d_model, dim_feedforward, **factory_kwargs)
214
+ self.dropout = nn.Dropout(dropout)
215
+ self.linear2 = linear2_feedforward_cls(dim_feedforward, d_model, **factory_kwargs)
216
+
217
+ self.norm_first = norm_first
218
+ self.dropout1 = nn.Dropout(dropout)
219
+ self.dropout2 = nn.Dropout(dropout)
220
+
221
+ # Legacy string support for activation function.
222
+ if isinstance(activation, str):
223
+ activation = _get_activation_fn(activation)
224
+ elif isinstance(activation, partial):
225
+ activation = activation(d_model)
226
+ elif activation == BalancedDoubleSwish:
227
+ activation = BalancedDoubleSwish(d_model)
228
+
229
+ # # We can't test self.activation in forward() in TorchScript,
230
+ # # so stash some information about it instead.
231
+ # if activation is F.relu or isinstance(activation, torch.nn.ReLU):
232
+ # self.activation_relu_or_gelu = 1
233
+ # elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
234
+ # self.activation_relu_or_gelu = 2
235
+ # else:
236
+ # self.activation_relu_or_gelu = 0
237
+ self.activation = activation
238
+
239
+ norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
240
+ if layer_norm_cls == IdentityNorm:
241
+ norm2 = BalancedBasicNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
242
+ else:
243
+ norm2 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
244
+
245
+ if adaptive_layer_norm:
246
+ self.norm1 = AdaptiveLayerNorm(d_model, norm1)
247
+ self.norm2 = AdaptiveLayerNorm(d_model, norm2)
248
+ else:
249
+ self.norm1 = norm1
250
+ self.norm2 = norm2
251
+
252
+ def __setstate__(self, state):
253
+ super(TransformerEncoderLayer, self).__setstate__(state)
254
+ if not hasattr(self, "activation"):
255
+ self.activation = F.relu
256
+
257
+ def forward(
258
+ self,
259
+ src: Tensor,
260
+ src_mask: Optional[Tensor] = None,
261
+ src_key_padding_mask: Optional[Tensor] = None,
262
+ cache=None,
263
+ ) -> Tensor:
264
+ r"""Pass the input through the encoder layer.
265
+
266
+ Args:
267
+ src: the sequence to the encoder layer (required).
268
+ src_mask: the mask for the src sequence (optional).
269
+ src_key_padding_mask: the mask for the src keys per batch (optional).
270
+
271
+ Shape:
272
+ see the docs in Transformer class.
273
+ """
274
+ x, stage_embedding = src, None
275
+ is_src_tuple = False
276
+ if isinstance(src, tuple):
277
+ x, stage_embedding = src
278
+ is_src_tuple = True
279
+
280
+ if src_key_padding_mask is not None:
281
+ _skpm_dtype = src_key_padding_mask.dtype
282
+ if _skpm_dtype != torch.bool and not torch.is_floating_point(src_key_padding_mask):
283
+ raise AssertionError("only bool and floating types of key_padding_mask are supported")
284
+
285
+ if self.norm_first:
286
+ x = x + self._sa_block(
287
+ self.norm1(x, stage_embedding),
288
+ src_mask,
289
+ src_key_padding_mask,
290
+ cache=cache,
291
+ )
292
+ x = x + self._ff_block(self.norm2(x, stage_embedding))
293
+ else:
294
+ x = self.norm1(
295
+ x + self._sa_block(x, src_mask, src_key_padding_mask, cache=cache),
296
+ stage_embedding,
297
+ )
298
+ x = self.norm2(x + self._ff_block(x), stage_embedding)
299
+
300
+ if is_src_tuple:
301
+ return (x, stage_embedding)
302
+ return x
303
+
304
+ # self-attention block
305
+ def _sa_block(
306
+ self,
307
+ x: Tensor,
308
+ attn_mask: Optional[Tensor],
309
+ key_padding_mask: Optional[Tensor],
310
+ cache=None,
311
+ ) -> Tensor:
312
+ # print(x.shape,attn_mask.shape,key_padding_mask)
313
+ # torch.Size([1, 188, 512]) torch.Size([188, 188]) None
314
+ # import os
315
+ # os._exit(23333)
316
+ x = self.self_attn(
317
+ x,
318
+ x,
319
+ x,
320
+ attn_mask=attn_mask,
321
+ key_padding_mask=key_padding_mask,
322
+ need_weights=False,
323
+ cache=cache,
324
+ )[0]
325
+ return self.dropout1(x)
326
+
327
+ # feed forward block
328
+ def _ff_block(self, x: Tensor) -> Tensor:
329
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
330
+ return self.dropout2(x)
331
+
332
+
333
+ class AdaptiveLayerNorm(nn.Module):
334
+ r"""Adaptive Layer Normalization"""
335
+
336
+ def __init__(self, d_model, norm) -> None:
337
+ super(AdaptiveLayerNorm, self).__init__()
338
+ self.project_layer = nn.Linear(d_model, 2 * d_model)
339
+ self.norm = norm
340
+ self.d_model = d_model
341
+ self.eps = self.norm.eps
342
+
343
+ def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
344
+ if isinstance(input, tuple):
345
+ input, embedding = input
346
+ weight, bias = torch.split(
347
+ self.project_layer(embedding),
348
+ split_size_or_sections=self.d_model,
349
+ dim=-1,
350
+ )
351
+ return (weight * self.norm(input) + bias, embedding)
352
+
353
+ weight, bias = torch.split(
354
+ self.project_layer(embedding),
355
+ split_size_or_sections=self.d_model,
356
+ dim=-1,
357
+ )
358
+ return weight * self.norm(input) + bias
359
+
360
+
361
+ def _get_clones(module, N):
362
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
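A hedged wiring sketch for the classes above (not part of the commit). It keeps the default `F.relu` activation, passes no mask or cache, and assumes the custom `MultiheadAttention` from `AR.modules.activation` behaves like the stock PyTorch module when `cache=None`.

```python
# Sketch only: build and run a small pre-norm encoder from the classes in this file.
import torch
from AR.modules.transformer import LayerNorm, TransformerEncoder, TransformerEncoderLayer

layer = TransformerEncoderLayer(
    d_model=512,
    nhead=8,
    dim_feedforward=2048,
    batch_first=True,
    norm_first=True,            # exercise the pre-norm path
)
encoder = TransformerEncoder(layer, num_layers=2, norm=LayerNorm(512))

src = torch.rand(1, 188, 512)   # (batch, seq, d_model); mirrors the shape noted in _sa_block
out = encoder(src)              # mask and cache left as None
print(out.shape)                # expected: torch.Size([1, 188, 512])
```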
AR/modules/transformer_onnx.py ADDED
@@ -0,0 +1,281 @@
1
+ # modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/transformer.py
2
+ import copy
3
+ import numbers
4
+ from functools import partial
5
+ from typing import Any
6
+ from typing import Callable
7
+ from typing import List
8
+ from typing import Optional
9
+ from typing import Tuple
10
+ from typing import Union
11
+
12
+ import torch
13
+ from AR.modules.activation_onnx import MultiheadAttention
14
+ from AR.modules.scaling import BalancedDoubleSwish
15
+ from torch import nn
16
+ from torch import Tensor
17
+ from torch.nn import functional as F
18
+
19
+ _shape_t = Union[int, List[int], torch.Size]
20
+
21
+
22
+ class LayerNorm(nn.Module):
23
+ __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
24
+ normalized_shape: Tuple[int, ...]
25
+ eps: float
26
+ elementwise_affine: bool
27
+
28
+ def __init__(
29
+ self,
30
+ normalized_shape: _shape_t,
31
+ eps: float = 1e-5,
32
+ elementwise_affine: bool = True,
33
+ device=None,
34
+ dtype=None,
35
+ ) -> None:
36
+ factory_kwargs = {"device": device, "dtype": dtype}
37
+ super(LayerNorm, self).__init__()
38
+ if isinstance(normalized_shape, numbers.Integral):
39
+ # mypy error: incompatible types in assignment
40
+ normalized_shape = (normalized_shape,) # type: ignore[assignment]
41
+ self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
42
+ self.eps = eps
43
+ self.elementwise_affine = elementwise_affine
44
+ if self.elementwise_affine:
45
+ self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
46
+ self.bias = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
47
+ else:
48
+ self.register_parameter("weight", None)
49
+ self.register_parameter("bias", None)
50
+
51
+ self.reset_parameters()
52
+
53
+ def reset_parameters(self) -> None:
54
+ if self.elementwise_affine:
55
+ nn.init.ones_(self.weight)
56
+ nn.init.zeros_(self.bias)
57
+
58
+ def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
59
+ if isinstance(input, tuple):
60
+ input, embedding = input
61
+ return (
62
+ F.layer_norm(
63
+ input,
64
+ self.normalized_shape,
65
+ self.weight,
66
+ self.bias,
67
+ self.eps,
68
+ ),
69
+ embedding,
70
+ )
71
+
72
+ assert embedding is None
73
+ return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
74
+
75
+ def extra_repr(self) -> str:
76
+ return "{normalized_shape}, eps={eps}, elementwise_affine={elementwise_affine}".format(**self.__dict__)
77
+
78
+
79
+ class IdentityNorm(nn.Module):
80
+ def __init__(
81
+ self,
82
+ d_model: int,
83
+ eps: float = 1e-5,
84
+ device=None,
85
+ dtype=None,
86
+ ) -> None:
87
+ super(IdentityNorm, self).__init__()
88
+
89
+ def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
90
+ if isinstance(input, tuple):
91
+ return input
92
+
93
+ assert embedding is None
94
+ return input
95
+
96
+
97
+ class TransformerEncoder(nn.Module):
98
+ r"""TransformerEncoder is a stack of N encoder layers. Users can build the
99
+ BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.
100
+
101
+ Args:
102
+ encoder_layer: an instance of the TransformerEncoderLayer() class (required).
103
+ num_layers: the number of sub-encoder-layers in the encoder (required).
104
+ norm: the layer normalization component (optional).
105
+ enable_nested_tensor: if True, input will automatically convert to nested tensor
106
+ (and convert back on output). This will improve the overall performance of
107
+ TransformerEncoder when padding rate is high. Default: ``True`` (enabled).
108
+
109
+ Examples::
110
+ >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
111
+ >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
112
+ >>> src = torch.rand(10, 32, 512)
113
+ >>> out = transformer_encoder(src)
114
+ """
115
+
116
+ __constants__ = ["norm"]
117
+
118
+ def __init__(self, encoder_layer, num_layers, norm=None):
119
+ super(TransformerEncoder, self).__init__()
120
+ self.layers = _get_clones(encoder_layer, num_layers)
121
+ self.num_layers = num_layers
122
+ self.norm = norm
123
+
124
+ def forward(
125
+ self,
126
+ src: Tensor,
127
+ mask: Optional[Tensor] = None,
128
+ src_key_padding_mask: Optional[Tensor] = None,
129
+ return_layer_states: bool = False,
130
+ cache=None,
131
+ ) -> Tensor:
132
+ output = src
133
+ for mod in self.layers:
134
+ output = mod(
135
+ output,
136
+ src_mask=mask,
137
+ src_key_padding_mask=src_key_padding_mask,
138
+ cache=cache,
139
+ )
140
+
141
+ if self.norm is not None:
142
+ output = self.norm(output)
143
+
144
+ return output
145
+
146
+
147
+ class TransformerEncoderLayer(nn.Module):
148
+ __constants__ = ["batch_first", "norm_first"]
149
+
150
+ def __init__(
151
+ self,
152
+ d_model: int,
153
+ nhead: int,
154
+ dim_feedforward: int = 2048,
155
+ dropout: float = 0.1,
156
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
157
+ batch_first: bool = False,
158
+ norm_first: bool = False,
159
+ device=None,
160
+ dtype=None,
161
+ linear1_self_attention_cls: nn.Module = nn.Linear,
162
+ linear2_self_attention_cls: nn.Module = nn.Linear,
163
+ linear1_feedforward_cls: nn.Module = nn.Linear,
164
+ linear2_feedforward_cls: nn.Module = nn.Linear,
165
+ layer_norm_cls: nn.Module = LayerNorm,
166
+ layer_norm_eps: float = 1e-5,
167
+ adaptive_layer_norm=False,
168
+ ) -> None:
169
+ factory_kwargs = {"device": device, "dtype": dtype}
170
+ super(TransformerEncoderLayer, self).__init__()
171
+ self.self_attn = MultiheadAttention(
172
+ d_model, # 512 16
173
+ nhead,
174
+ dropout=dropout,
175
+ batch_first=batch_first,
176
+ linear1_cls=linear1_self_attention_cls,
177
+ linear2_cls=linear2_self_attention_cls,
178
+ **factory_kwargs,
179
+ )
180
+ self.linear1 = linear1_feedforward_cls(d_model, dim_feedforward, **factory_kwargs)
181
+ self.dropout = nn.Dropout(dropout)
182
+ self.linear2 = linear2_feedforward_cls(dim_feedforward, d_model, **factory_kwargs)
183
+ self.norm_first = norm_first
184
+ self.dropout1 = nn.Dropout(dropout)
185
+ self.dropout2 = nn.Dropout(dropout)
186
+ if isinstance(activation, str):
187
+ activation = _get_activation_fn(activation)
188
+ elif isinstance(activation, partial):
189
+ activation = activation(d_model)
190
+ elif activation == BalancedDoubleSwish:
191
+ activation = BalancedDoubleSwish(d_model)
192
+ self.activation = activation
193
+
194
+ norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
195
+ if layer_norm_cls == IdentityNorm:
196
+ norm2 = BalancedBasicNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
197
+ else:
198
+ norm2 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
199
+
200
+ if adaptive_layer_norm:
201
+ self.norm1 = AdaptiveLayerNorm(d_model, norm1)
202
+ self.norm2 = AdaptiveLayerNorm(d_model, norm2)
203
+ else:
204
+ self.norm1 = norm1
205
+ self.norm2 = norm2
206
+
207
+ def __setstate__(self, state):
208
+ super(TransformerEncoderLayer, self).__setstate__(state)
209
+ if not hasattr(self, "activation"):
210
+ self.activation = F.relu
211
+
212
+ def forward(
213
+ self,
214
+ src: Tensor,
215
+ src_mask: Optional[Tensor] = None,
216
+ src_key_padding_mask: Optional[Tensor] = None,
217
+ cache=None,
218
+ ) -> Tensor:
219
+ x = src
220
+ stage_embedding = None
221
+ x = self.norm1(
222
+ x + self._sa_block(x, src_mask, src_key_padding_mask, cache=cache),
223
+ stage_embedding,
224
+ )
225
+ x = self.norm2(x + self._ff_block(x), stage_embedding)
226
+
227
+ return x
228
+
229
+ def _sa_block(
230
+ self,
231
+ x: Tensor,
232
+ attn_mask: Optional[Tensor],
233
+ key_padding_mask: Optional[Tensor],
234
+ cache=None,
235
+ ) -> Tensor:
236
+ x = self.self_attn(
237
+ x,
238
+ x,
239
+ x,
240
+ attn_mask=attn_mask,
241
+ key_padding_mask=key_padding_mask,
242
+ need_weights=False,
243
+ cache=cache,
244
+ )
245
+ return self.dropout1(x)
246
+
247
+ def _ff_block(self, x: Tensor) -> Tensor:
248
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
249
+ return self.dropout2(x)
250
+
251
+
252
+ class AdaptiveLayerNorm(nn.Module):
253
+ r"""Adaptive Layer Normalization"""
254
+
255
+ def __init__(self, d_model, norm) -> None:
256
+ super(AdaptiveLayerNorm, self).__init__()
257
+ self.project_layer = nn.Linear(d_model, 2 * d_model)
258
+ self.norm = norm
259
+ self.d_model = d_model
260
+ self.eps = self.norm.eps
261
+
262
+ def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
263
+ if isinstance(input, tuple):
264
+ input, embedding = input
265
+ weight, bias = torch.split(
266
+ self.project_layer(embedding),
267
+ split_size_or_sections=self.d_model,
268
+ dim=-1,
269
+ )
270
+ return (weight * self.norm(input) + bias, embedding)
271
+
272
+ weight, bias = torch.split(
273
+ self.project_layer(embedding),
274
+ split_size_or_sections=self.d_model,
275
+ dim=-1,
276
+ )
277
+ return weight * self.norm(input) + bias
278
+
279
+
280
+ def _get_clones(module, N):
281
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
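The ONNX variant keeps the same `AdaptiveLayerNorm` as the regular module; a hedged sketch of it in isolation (not part of the commit) shows how the projection turns an embedding into a per-position scale and bias around the wrapped norm.

```python
# Sketch only: AdaptiveLayerNorm modulates a wrapped LayerNorm with an embedding.
import torch
from AR.modules.transformer_onnx import AdaptiveLayerNorm, LayerNorm

ada = AdaptiveLayerNorm(d_model=512, norm=LayerNorm(512))
x = torch.randn(2, 10, 512)     # features
emb = torch.randn(2, 10, 512)   # stage embedding, projected to (scale, bias)
y = ada(x, emb)                 # weight * norm(x) + bias
print(y.shape)                  # torch.Size([2, 10, 512])
```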
AR/text_processing/__init__.py ADDED
File without changes
AR/text_processing/phonemizer.py ADDED
@@ -0,0 +1,72 @@
1
+ # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/text_processing/phonemizer.py
2
+ # reference: https://github.com/lifeiteng/vall-e
3
+ import itertools
4
+ import re
5
+ from typing import Dict
6
+ from typing import List
7
+
8
+ import regex
9
+ from gruut import sentences
10
+ from gruut.const import Sentence
11
+ from gruut.const import Word
12
+ from AR.text_processing.symbols import SYMBOL_TO_ID
13
+
14
+
15
+ class GruutPhonemizer:
16
+ def __init__(self, language: str):
17
+ self._phonemizer = sentences
18
+ self.lang = language
19
+ self.symbol_to_id = SYMBOL_TO_ID
20
+ self._special_cases_dict: Dict[str, str] = {
21
+ r"\.\.\.": "... ",
22
+ ";": "; ",
23
+ ":": ": ",
24
+ ",": ", ",
25
+ r"\.": ". ",
26
+ "!": "! ",
27
+ r"\?": "? ",
28
+ "—": "—",
29
+ "…": "… ",
30
+ "«": "«",
31
+ "»": "»",
32
+ }
33
+ self._punctuation_regexp: str = rf"([{''.join(self._special_cases_dict.keys())}])"
34
+
35
+ def _normalize_punctuation(self, text: str) -> str:
36
+ text = regex.sub(rf"\pZ+{self._punctuation_regexp}", r"\1", text)
37
+ text = regex.sub(rf"{self._punctuation_regexp}(\pL)", r"\1 \2", text)
38
+ text = regex.sub(r"\pZ+", r" ", text)
39
+ return text.strip()
40
+
41
+ def _convert_punctuation(self, word: Word) -> str:
42
+ if not word.phonemes:
43
+ return ""
44
+ if word.phonemes[0] in ["‖", "|"]:
45
+ return word.text.strip()
46
+
47
+ phonemes = "".join(word.phonemes)
48
+ # remove modifier characters ˈˌː with regex
49
+ phonemes = re.sub(r"[ˈˌː͡]", "", phonemes)
50
+ return phonemes.strip()
51
+
52
+ def phonemize(self, text: str, espeak: bool = False) -> str:
53
+ text_to_phonemize: str = self._normalize_punctuation(text)
54
+ sents: List[Sentence] = [sent for sent in self._phonemizer(text_to_phonemize, lang="en-us", espeak=espeak)]
55
+ words: List[str] = [self._convert_punctuation(word) for word in itertools.chain(*sents)]
56
+ return " ".join(words)
57
+
58
+ def transform(self, phonemes):
59
+ # convert phonemes to ids
60
+ # dictionary is in symbols.py
61
+ return [self.symbol_to_id[p] for p in phonemes if p in self.symbol_to_id.keys()]
62
+
63
+
64
+ if __name__ == "__main__":
65
+ phonemizer = GruutPhonemizer("en-us")
66
+ # text -> IPA
67
+ phonemes = phonemizer.phonemize("Hello, wor-ld ?")
68
+ print("phonemes:", phonemes)
69
+ print("len(phonemes):", len(phonemes))
70
+ phoneme_ids = phonemizer.transform(phonemes)
71
+ print("phoneme_ids:", phoneme_ids)
72
+ print("len(phoneme_ids):", len(phoneme_ids))
AR/text_processing/symbols.py ADDED
@@ -0,0 +1,12 @@
1
+ # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/text_processing/symbols.py
2
+ # reference: https://github.com/lifeiteng/vall-e
3
+ PAD = "_"
4
+ PUNCTUATION = ';:,.!?¡¿—…"«»“” '
5
+ LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
6
+ IPA_LETTERS = (
7
+ "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
8
+ )
9
+ SYMBOLS = [PAD] + list(PUNCTUATION) + list(LETTERS) + list(IPA_LETTERS)
10
+ SPACE_ID = SYMBOLS.index(" ")
11
+ SYMBOL_TO_ID = {s: i for i, s in enumerate(SYMBOLS)}
12
+ ID_TO_SYMBOL = {i: s for i, s in enumerate(SYMBOLS)}
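A quick hedged check of the vocabulary defined above (not part of the commit): the two lookup tables are inverses, so any string made only of known symbols round-trips through them.

```python
# Sketch only: round-trip a string through the symbol <-> id tables.
from AR.text_processing.symbols import ID_TO_SYMBOL, SPACE_ID, SYMBOL_TO_ID, SYMBOLS

ids = [SYMBOL_TO_ID[s] for s in "hello world" if s in SYMBOL_TO_ID]
text = "".join(ID_TO_SYMBOL[i] for i in ids)
print(len(SYMBOLS), SPACE_ID)   # vocabulary size and the id of the space symbol
print(text)                     # hello world
```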
AR/utils/__init__.py ADDED
@@ -0,0 +1,36 @@
1
+ import re
2
+
3
+
4
+ def str2bool(str):
5
+ return True if str.lower() == "true" else False
6
+
7
+
8
+ def get_newest_ckpt(string_list):
9
+ # Define a regex pattern used to match the numbers in each string
10
+ pattern = r"epoch=(\d+)-step=(\d+)\.ckpt"
11
+
12
+ # Use the regex to extract the numeric info from each string and build a list of tuples
13
+ extracted_info = []
14
+ for string in string_list:
15
+ match = re.match(pattern, string)
16
+ if match:
17
+ epoch = int(match.group(1))
18
+ step = int(match.group(2))
19
+ extracted_info.append((epoch, step, string))
20
+ # Sort by the number after "epoch" and then the number after "step"
21
+ sorted_info = sorted(extracted_info, key=lambda x: (x[0], x[1]), reverse=True)
22
+ # Get the newest ckpt filename
23
+ newest_ckpt = sorted_info[0][2]
24
+ return newest_ckpt
25
+
26
+
27
+ # Return the first line when the text file exists and is not empty; otherwise return False
28
+ def check_txt_file(file_path):
29
+ try:
30
+ with open(file_path, "r") as file:
31
+ text = file.readline().strip()
32
+ assert text.strip() != ""
33
+ return text
34
+ except Exception:
35
+ return False
36
+ return False
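A hedged usage sketch for the helpers above (not part of the commit), showing how `get_newest_ckpt` orders Lightning-style checkpoint names by `(epoch, step)`.

```python
# Sketch only: pick the newest checkpoint name by (epoch, step).
from AR.utils import get_newest_ckpt, str2bool

names = [
    "epoch=4-step=1200.ckpt",
    "epoch=10-step=300.ckpt",
    "epoch=10-step=900.ckpt",
]
print(get_newest_ckpt(names))   # epoch=10-step=900.ckpt
print(str2bool("True"))         # True
```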
AR/utils/initialize.py ADDED
@@ -0,0 +1,39 @@
1
+ #!/usr/bin/env python3
2
+ """Initialize modules for espnet2 neural networks."""
3
+
4
+ import torch
5
+ from typeguard import check_argument_types
6
+
7
+
8
+ def initialize(model: torch.nn.Module, init: str):
9
+ """Initialize weights of a neural network module.
10
+
11
+ Parameters are initialized using the given method or distribution.
12
+
13
+ Custom initialization routines can be implemented into submodules
14
+ as function `espnet_initialization_fn` within the custom module.
15
+
16
+ Args:
17
+ model: Target.
18
+ init: Method of initialization.
19
+ """
20
+ assert check_argument_types()
21
+ print("init with", init)
22
+
23
+ # weight init
24
+ for p in model.parameters():
25
+ if p.dim() > 1:
26
+ if init == "xavier_uniform":
27
+ torch.nn.init.xavier_uniform_(p.data)
28
+ elif init == "xavier_normal":
29
+ torch.nn.init.xavier_normal_(p.data)
30
+ elif init == "kaiming_uniform":
31
+ torch.nn.init.kaiming_uniform_(p.data, nonlinearity="relu")
32
+ elif init == "kaiming_normal":
33
+ torch.nn.init.kaiming_normal_(p.data, nonlinearity="relu")
34
+ else:
35
+ raise ValueError("Unknown initialization: " + init)
36
+ # bias init
37
+ for name, p in model.named_parameters():
38
+ if ".bias" in name and p.dim() == 1:
39
+ p.data.zero_()
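A hedged usage sketch for `initialize` (not part of the commit); it assumes a `typeguard` version that still provides `check_argument_types`, which the import above requires.

```python
# Sketch only: re-initialize a toy module with Xavier-uniform weights.
import torch
from AR.utils.initialize import initialize  # needs typeguard with check_argument_types

model = torch.nn.Sequential(
    torch.nn.Linear(8, 16),
    torch.nn.ReLU(),
    torch.nn.Linear(16, 4),
)
initialize(model, "xavier_uniform")  # >1-dim weights get Xavier init, biases are zeroed
print(model[0].bias.abs().sum())     # tensor(0., ...)
```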
AR/utils/io.py ADDED
@@ -0,0 +1,30 @@
1
+ import sys
2
+
3
+ import torch
4
+ import yaml
5
+
6
+
7
+ def load_yaml_config(path):
8
+ with open(path) as f:
9
+ config = yaml.full_load(f)
10
+ return config
11
+
12
+
13
+ def save_config_to_yaml(config, path):
14
+ assert path.endswith(".yaml")
15
+ with open(path, "w") as f:
16
+ f.write(yaml.dump(config))
17
+ f.close()
18
+
19
+
20
+ def write_args(args, path):
21
+ args_dict = dict((name, getattr(args, name)) for name in dir(args) if not name.startswith("_"))
22
+ with open(path, "a") as args_file:
23
+ args_file.write("==> torch version: {}\n".format(torch.__version__))
24
+ args_file.write("==> cudnn version: {}\n".format(torch.backends.cudnn.version()))
25
+ args_file.write("==> Cmd:\n")
26
+ args_file.write(str(sys.argv))
27
+ args_file.write("\n==> args:\n")
28
+ for k, v in sorted(args_dict.items()):
29
+ args_file.write(" %s: %s\n" % (str(k), str(v)))
30
+ args_file.close()
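A hedged round-trip sketch for the YAML helpers above (not part of the commit); the file name is a placeholder.

```python
# Sketch only: write a config dict to YAML and read it back.
from AR.utils.io import load_yaml_config, save_config_to_yaml

cfg = {"train": {"batch_size": 8, "lr": 1e-4}}
save_config_to_yaml(cfg, "demo_config.yaml")  # placeholder path; must end with .yaml
print(load_yaml_config("demo_config.yaml"))   # {'train': {'batch_size': 8, 'lr': 0.0001}}
```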
BigVGAN/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 NVIDIA CORPORATION.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
BigVGAN/README.md ADDED
@@ -0,0 +1,266 @@
1
+ ## BigVGAN: A Universal Neural Vocoder with Large-Scale Training
2
+
3
+ #### Sang-gil Lee, Wei Ping, Boris Ginsburg, Bryan Catanzaro, Sungroh Yoon
4
+
5
+ [[Paper]](https://arxiv.org/abs/2206.04658) - [[Code]](https://github.com/NVIDIA/BigVGAN) - [[Showcase]](https://bigvgan-demo.github.io/) - [[Project Page]](https://research.nvidia.com/labs/adlr/projects/bigvgan/) - [[Weights]](https://huggingface.co/collections/nvidia/bigvgan-66959df3d97fd7d98d97dc9a) - [[Demo]](https://huggingface.co/spaces/nvidia/BigVGAN)
6
+
7
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/bigvgan-a-universal-neural-vocoder-with-large/speech-synthesis-on-libritts)](https://paperswithcode.com/sota/speech-synthesis-on-libritts?p=bigvgan-a-universal-neural-vocoder-with-large)
8
+
9
+ <center><img src="https://user-images.githubusercontent.com/15963413/218609148-881e39df-33af-4af9-ab95-1427c4ebf062.png" width="800"></center>
10
+
11
+ ## News
12
+ - **Sep 2024 (v2.4):**
13
+ - We have updated the pretrained checkpoints trained for 5M steps. This is final release of the BigVGAN-v2 checkpoints.
14
+
15
+ - **Jul 2024 (v2.3):**
16
+ - General refactor and code improvements for improved readability.
17
+ - Fully fused CUDA kernel of anti-aliased activation (upsampling + activation + downsampling) with inference speed benchmark.
18
+
19
+ - **Jul 2024 (v2.2):** The repository now includes an interactive local demo using gradio.
20
+
21
+ - **Jul 2024 (v2.1):** BigVGAN is now integrated with 🤗 Hugging Face Hub with easy access to inference using pretrained checkpoints. We also provide an interactive demo on Hugging Face Spaces.
22
+
23
+ - **Jul 2024 (v2):** We release BigVGAN-v2 along with pretrained checkpoints. Below are the highlights:
24
+ - Custom CUDA kernel for inference: we provide a fused upsampling + activation kernel written in CUDA for accelerated inference speed. Our test shows 1.5 - 3x faster speed on a single A100 GPU.
25
+ - Improved discriminator and loss: BigVGAN-v2 is trained using a [multi-scale sub-band CQT discriminator](https://arxiv.org/abs/2311.14957) and a [multi-scale mel spectrogram loss](https://arxiv.org/abs/2306.06546).
26
+ - Larger training data: BigVGAN-v2 is trained using datasets containing diverse audio types, including speech in multiple languages, environmental sounds, and instruments.
27
+ - We provide pretrained checkpoints of BigVGAN-v2 using diverse audio configurations, supporting up to 44 kHz sampling rate and 512x upsampling ratio.
28
+
29
+ ## Installation
30
+
31
+ The codebase has been tested on Python `3.10` and PyTorch `2.3.1` conda packages with either `pytorch-cuda=12.1` or `pytorch-cuda=11.8`. Below is an example command to create the conda environment:
32
+
33
+ ```shell
34
+ conda create -n bigvgan python=3.10 pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
35
+ conda activate bigvgan
36
+ ```
37
+
38
+ Clone the repository and install dependencies:
39
+
40
+ ```shell
41
+ git clone https://github.com/NVIDIA/BigVGAN
42
+ cd BigVGAN
43
+ pip install -r requirements.txt
44
+ ```
45
+
46
+ ## Inference Quickstart using 🤗 Hugging Face Hub
47
+
48
+ The example below shows how to use BigVGAN: load the pretrained BigVGAN generator from Hugging Face Hub, compute a mel spectrogram from an input waveform, and generate a synthesized waveform using the mel spectrogram as the model's input.
49
+
50
+ ```python
51
+ device = 'cuda'
52
+
53
+ import torch
54
+ import bigvgan
55
+ import librosa
56
+ from meldataset import get_mel_spectrogram
57
+
58
+ # instantiate the model. You can optionally set use_cuda_kernel=True for faster inference.
59
+ model = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_24khz_100band_256x', use_cuda_kernel=False)
60
+
61
+ # remove weight norm in the model and set to eval mode
62
+ model.remove_weight_norm()
63
+ model = model.eval().to(device)
64
+
65
+ # load wav file and compute mel spectrogram
66
+ wav_path = '/path/to/your/audio.wav'
67
+ wav, sr = librosa.load(wav_path, sr=model.h.sampling_rate, mono=True) # wav is np.ndarray with shape [T_time] and values in [-1, 1]
68
+ wav = torch.FloatTensor(wav).unsqueeze(0) # wav is FloatTensor with shape [B(1), T_time]
69
+
70
+ # compute mel spectrogram from the ground truth audio
71
+ mel = get_mel_spectrogram(wav, model.h).to(device) # mel is FloatTensor with shape [B(1), C_mel, T_frame]
72
+
73
+ # generate waveform from mel
74
+ with torch.inference_mode():
75
+ wav_gen = model(mel) # wav_gen is FloatTensor with shape [B(1), 1, T_time] and values in [-1, 1]
76
+ wav_gen_float = wav_gen.squeeze(0).cpu() # wav_gen is FloatTensor with shape [1, T_time]
77
+
78
+ # you can convert the generated waveform to 16 bit linear PCM
79
+ wav_gen_int16 = (wav_gen_float * 32767.0).numpy().astype('int16') # wav_gen is now np.ndarray with shape [1, T_time] and int16 dtype
80
+ ```
81
+
82
+ ## Local gradio demo <a href='https://github.com/gradio-app/gradio'><img src='https://img.shields.io/github/stars/gradio-app/gradio'></a>
83
+
84
+ You can run a local gradio demo using the commands below:
86
+
87
+ ```shell
87
+ pip install -r demo/requirements.txt
88
+ python demo/app.py
89
+ ```
90
+
91
+ ## Training
92
+
93
+ Create symbolic links to the root of the dataset. The codebase uses filelists with paths relative to the dataset root. Below are example commands for the LibriTTS dataset:
94
+
95
+ ```shell
96
+ cd filelists/LibriTTS && \
97
+ ln -s /path/to/your/LibriTTS/train-clean-100 train-clean-100 && \
98
+ ln -s /path/to/your/LibriTTS/train-clean-360 train-clean-360 && \
99
+ ln -s /path/to/your/LibriTTS/train-other-500 train-other-500 && \
100
+ ln -s /path/to/your/LibriTTS/dev-clean dev-clean && \
101
+ ln -s /path/to/your/LibriTTS/dev-other dev-other && \
102
+ ln -s /path/to/your/LibriTTS/test-clean test-clean && \
103
+ ln -s /path/to/your/LibriTTS/test-other test-other && \
104
+ cd ../..
105
+ ```
106
+
107
+ Train the BigVGAN model. Below is an example command for training BigVGAN-v2 using the LibriTTS dataset at 24 kHz with a full 100-band mel spectrogram as input:
108
+
109
+ ```shell
110
+ python train.py \
111
+ --config configs/bigvgan_v2_24khz_100band_256x.json \
112
+ --input_wavs_dir filelists/LibriTTS \
113
+ --input_training_file filelists/LibriTTS/train-full.txt \
114
+ --input_validation_file filelists/LibriTTS/val-full.txt \
115
+ --list_input_unseen_wavs_dir filelists/LibriTTS filelists/LibriTTS \
116
+ --list_input_unseen_validation_file filelists/LibriTTS/dev-clean.txt filelists/LibriTTS/dev-other.txt \
117
+ --checkpoint_path exp/bigvgan_v2_24khz_100band_256x
118
+ ```
119
+
120
+ ## Synthesis
121
+
122
+ Synthesize from BigVGAN model. Below is an example command for generating audio from the model.
123
+ It computes mel spectrograms using wav files from `--input_wavs_dir` and saves the generated audio to `--output_dir`.
124
+
125
+ ```shell
126
+ python inference.py \
127
+ --checkpoint_file /path/to/your/bigvgan_v2_24khz_100band_256x/bigvgan_generator.pt \
128
+ --input_wavs_dir /path/to/your/input_wav \
129
+ --output_dir /path/to/your/output_wav
130
+ ```
131
+
132
+ `inference_e2e.py` supports synthesis directly from the mel spectrogram saved in `.npy` format, with shapes `[1, channel, frame]` or `[channel, frame]`.
133
+ It loads mel spectrograms from `--input_mels_dir` and saves the generated audio to `--output_dir`.
134
+
135
+ Make sure that the STFT hyperparameters for the mel spectrogram match those of the model, as defined in the `config.json` of the corresponding model.
136
+
137
+ ```shell
138
+ python inference_e2e.py \
139
+ --checkpoint_file /path/to/your/bigvgan_v2_24khz_100band_256x/bigvgan_generator.pt \
140
+ --input_mels_dir /path/to/your/input_mel \
141
+ --output_dir /path/to/your/output_wav
142
+ ```
143
+
144
+ ## Using Custom CUDA Kernel for Synthesis
145
+
146
+ You can apply the fast CUDA inference kernel by using a parameter `use_cuda_kernel` when instantiating BigVGAN:
147
+
148
+ ```python
149
+ generator = BigVGAN(h, use_cuda_kernel=True)
150
+ ```
151
+
152
+ You can also pass `--use_cuda_kernel` to `inference.py` and `inference_e2e.py` to enable this feature.
153
+
154
+ When applied for the first time, it builds the kernel using `nvcc` and `ninja`. If the build succeeds, the kernel is saved to `alias_free_activation/cuda/build` and the model automatically loads the kernel. The codebase has been tested using CUDA `12.1`.
155
+
156
+ Please make sure that both are installed on your system and that the installed `nvcc` matches the version your PyTorch build is using.
157
+
158
+ We recommend running `test_cuda_vs_torch_model.py` first to build and check the correctness of the CUDA kernel. See the example command below and its output, where it returns `[Success] test CUDA fused vs. plain torch BigVGAN inference`:
159
+
160
+ ```shell
161
+ python tests/test_cuda_vs_torch_model.py \
162
+ --checkpoint_file /path/to/your/bigvgan_generator.pt
163
+ ```
164
+
165
+ ```shell
166
+ loading plain Pytorch BigVGAN
167
+ ...
168
+ loading CUDA kernel BigVGAN with auto-build
169
+ Detected CUDA files, patching ldflags
170
+ Emitting ninja build file /path/to/your/BigVGAN/alias_free_activation/cuda/build/build.ninja..
171
+ Building extension module anti_alias_activation_cuda...
172
+ ...
173
+ Loading extension module anti_alias_activation_cuda...
174
+ ...
175
+ Loading '/path/to/your/bigvgan_generator.pt'
176
+ ...
177
+ [Success] test CUDA fused vs. plain torch BigVGAN inference
178
+ > mean_difference=0.0007238413265440613
179
+ ...
180
+ ```
181
+
182
+ If you see `[Fail] test CUDA fused vs. plain torch BigVGAN inference`, it means that the CUDA kernel inference is incorrect. Please check if `nvcc` installed in your system is compatible with your PyTorch version.
183
+
184
+ ## Pretrained Models
185
+
186
+ We provide the [pretrained models on Hugging Face Collections](https://huggingface.co/collections/nvidia/bigvgan-66959df3d97fd7d98d97dc9a).
187
+ One can download the checkpoints of the generator weight (named `bigvgan_generator.pt`) and its discriminator/optimizer states (named `bigvgan_discriminator_optimizer.pt`) within the listed model repositories.
188
+
189
+ | Model Name | Sampling Rate | Mel band | fmax | Upsampling Ratio | Params | Dataset | Steps | Fine-Tuned |
190
+ |:--------------------------------------------------------------------------------------------------------:|:-------------:|:--------:|:-----:|:----------------:|:------:|:--------------------------:|:-----:|:----------:|
191
+ | [bigvgan_v2_44khz_128band_512x](https://huggingface.co/nvidia/bigvgan_v2_44khz_128band_512x) | 44 kHz | 128 | 22050 | 512 | 122M | Large-scale Compilation | 5M | No |
192
+ | [bigvgan_v2_44khz_128band_256x](https://huggingface.co/nvidia/bigvgan_v2_44khz_128band_256x) | 44 kHz | 128 | 22050 | 256 | 112M | Large-scale Compilation | 5M | No |
193
+ | [bigvgan_v2_24khz_100band_256x](https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x) | 24 kHz | 100 | 12000 | 256 | 112M | Large-scale Compilation | 5M | No |
194
+ | [bigvgan_v2_22khz_80band_256x](https://huggingface.co/nvidia/bigvgan_v2_22khz_80band_256x) | 22 kHz | 80 | 11025 | 256 | 112M | Large-scale Compilation | 5M | No |
195
+ | [bigvgan_v2_22khz_80band_fmax8k_256x](https://huggingface.co/nvidia/bigvgan_v2_22khz_80band_fmax8k_256x) | 22 kHz | 80 | 8000 | 256 | 112M | Large-scale Compilation | 5M | No |
196
+ | [bigvgan_24khz_100band](https://huggingface.co/nvidia/bigvgan_24khz_100band) | 24 kHz | 100 | 12000 | 256 | 112M | LibriTTS | 5M | No |
197
+ | [bigvgan_base_24khz_100band](https://huggingface.co/nvidia/bigvgan_base_24khz_100band) | 24 kHz | 100 | 12000 | 256 | 14M | LibriTTS | 5M | No |
198
+ | [bigvgan_22khz_80band](https://huggingface.co/nvidia/bigvgan_22khz_80band) | 22 kHz | 80 | 8000 | 256 | 112M | LibriTTS + VCTK + LJSpeech | 5M | No |
199
+ | [bigvgan_base_22khz_80band](https://huggingface.co/nvidia/bigvgan_base_22khz_80band) | 22 kHz | 80 | 8000 | 256 | 14M | LibriTTS + VCTK + LJSpeech | 5M | No |
200
+
201
+ The paper results are based on the original 24kHz BigVGAN models (`bigvgan_24khz_100band` and `bigvgan_base_24khz_100band`) trained on LibriTTS dataset.
202
+ We also provide 22kHz BigVGAN models with band-limited setup (i.e., fmax=8000) for TTS applications.
203
+ Note that the checkpoints use `snakebeta` activation with log scale parameterization, which have the best overall quality.
204
+
205
+ You can fine-tune the models by:
206
+
207
+ 1. downloading the checkpoints (both the generator weight and its discriminator/optimizer states)
208
+ 2. resuming training using your audio dataset by specifying `--checkpoint_path` that includes the checkpoints when launching `train.py`
209
+
210
+ ## Training Details of BigVGAN-v2
211
+
212
+ Compared to the original BigVGAN, the pretrained checkpoints of BigVGAN-v2 use `batch_size=32` with a longer `segment_size=65536` and are trained using 8 A100 GPUs.
213
+
214
+ Note that the BigVGAN-v2 `json` config files in `./configs` use `batch_size=4` as default to fit in a single A100 GPU for training. You can fine-tune the models adjusting `batch_size` depending on your GPUs.
215
+
216
+ When training BigVGAN-v2 from scratch with a small batch size, training can potentially encounter the early divergence problem mentioned in the paper. In such a case, we recommend lowering the `clip_grad_norm` value (e.g. `100`) for the early training iterations (e.g. 20000 steps) and then increasing the value back to the default `500`.
217
+
218
+ ## Evaluation Results of BigVGAN-v2
219
+
220
+ Below are the objective results of the 24kHz model (`bigvgan_v2_24khz_100band_256x`) obtained from the LibriTTS `dev` sets. BigVGAN-v2 shows noticeable improvements in the metrics. The model also exhibits reduced perceptual artifacts, especially for non-speech audio.
221
+
222
+ | Model | Dataset | Steps | PESQ(↑) | M-STFT(↓) | MCD(↓) | Periodicity(↓) | V/UV F1(↑) |
223
+ |:----------:|:-----------------------:|:-----:|:---------:|:----------:|:----------:|:--------------:|:----------:|
224
+ | BigVGAN | LibriTTS | 1M | 4.027 | 0.7997 | 0.3745 | 0.1018 | 0.9598 |
225
+ | BigVGAN | LibriTTS | 5M | 4.256 | 0.7409 | 0.2988 | 0.0809 | 0.9698 |
226
+ | BigVGAN-v2 | Large-scale Compilation | 3M | 4.359 | 0.7134 | 0.3060 | 0.0621 | 0.9777 |
227
+ | BigVGAN-v2 | Large-scale Compilation | 5M | **4.362** | **0.7026** | **0.2903** | **0.0593** | **0.9793** |
228
+
229
+ ## Speed Benchmark
230
+
231
+ Below are the speed and VRAM usage benchmark results of BigVGAN from `tests/test_cuda_vs_torch_model.py`, using `bigvgan_v2_24khz_100band_256x` as a reference model.
232
+
233
+ | GPU | num_mel_frame | use_cuda_kernel | Speed (kHz) | Real-time Factor | VRAM (GB) |
234
+ |:--------------------------:|:-------------:|:---------------:|:-----------:|:----------------:|:---------:|
235
+ | NVIDIA A100 | 256 | False | 1672.1 | 69.7x | 1.3 |
236
+ | | | True | 3916.5 | 163.2x | 1.3 |
237
+ | | 2048 | False | 1899.6 | 79.2x | 1.7 |
238
+ | | | True | 5330.1 | 222.1x | 1.7 |
239
+ | | 16384 | False | 1973.8 | 82.2x | 5.0 |
240
+ | | | True | 5761.7 | 240.1x | 4.4 |
241
+ | NVIDIA GeForce RTX 3080 | 256 | False | 841.1 | 35.0x | 1.3 |
242
+ | | | True | 1598.1 | 66.6x | 1.3 |
243
+ | | 2048 | False | 929.9 | 38.7x | 1.7 |
244
+ | | | True | 1971.3 | 82.1x | 1.6 |
245
+ | | 16384 | False | 943.4 | 39.3x | 5.0 |
246
+ | | | True | 2026.5 | 84.4x | 3.9 |
247
+ | NVIDIA GeForce RTX 2080 Ti | 256 | False | 515.6 | 21.5x | 1.3 |
248
+ | | | True | 811.3 | 33.8x | 1.3 |
249
+ | | 2048 | False | 576.5 | 24.0x | 1.7 |
250
+ | | | True | 1023.0 | 42.6x | 1.5 |
251
+ | | 16384 | False | 589.4 | 24.6x | 5.0 |
252
+ | | | True | 1068.1 | 44.5x | 3.2 |
253
+
254
+ ## Acknowledgements
255
+
256
+ We thank Vijay Anand Korthikanti and Kevin J. Shih for their generous support in implementing the CUDA kernel for inference.
257
+
258
+ ## References
259
+
260
+ - [HiFi-GAN](https://github.com/jik876/hifi-gan) (for generator and multi-period discriminator)
261
+ - [Snake](https://github.com/EdwardDixon/snake) (for periodic activation)
262
+ - [Alias-free-torch](https://github.com/junjun3518/alias-free-torch) (for anti-aliasing)
263
+ - [Julius](https://github.com/adefossez/julius) (for low-pass filter)
264
+ - [UnivNet](https://github.com/mindslab-ai/univnet) (for multi-resolution discriminator)
265
+ - [descript-audio-codec](https://github.com/descriptinc/descript-audio-codec) and [vocos](https://github.com/gemelo-ai/vocos) (for multi-band multi-scale STFT discriminator and multi-scale mel spectrogram loss)
266
+ - [Amphion](https://github.com/open-mmlab/Amphion) (for multi-scale sub-band CQT discriminator)
BigVGAN/activations.py ADDED
@@ -0,0 +1,122 @@
1
+ # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch
5
+ from torch import nn, sin, pow
6
+ from torch.nn import Parameter
7
+
8
+
9
+ class Snake(nn.Module):
10
+ """
11
+ Implementation of a sine-based periodic activation function
12
+ Shape:
13
+ - Input: (B, C, T)
14
+ - Output: (B, C, T), same shape as the input
15
+ Parameters:
16
+ - alpha - trainable parameter
17
+ References:
18
+ - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
19
+ https://arxiv.org/abs/2006.08195
20
+ Examples:
21
+ >>> a1 = Snake(256)
22
+ >>> x = torch.randn(256)
23
+ >>> x = a1(x)
24
+ """
25
+
26
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
27
+ """
28
+ Initialization.
29
+ INPUT:
30
+ - in_features: shape of the input
31
+ - alpha: trainable parameter
32
+ alpha is initialized to 1 by default, higher values = higher-frequency.
33
+ alpha will be trained along with the rest of your model.
34
+ """
35
+ super(Snake, self).__init__()
36
+ self.in_features = in_features
37
+
38
+ # Initialize alpha
39
+ self.alpha_logscale = alpha_logscale
40
+ if self.alpha_logscale: # Log scale alphas initialized to zeros
41
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
42
+ else: # Linear scale alphas initialized to ones
43
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
44
+
45
+ self.alpha.requires_grad = alpha_trainable
46
+
47
+ self.no_div_by_zero = 0.000000001
48
+
49
+ def forward(self, x):
50
+ """
51
+ Forward pass of the function.
52
+ Applies the function to the input elementwise.
53
+ Snake ∶= x + 1/a * sin^2 (xa)
54
+ """
55
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # Line up with x to [B, C, T]
56
+ if self.alpha_logscale:
57
+ alpha = torch.exp(alpha)
58
+ x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
59
+
60
+ return x
61
+
62
+
63
+ class SnakeBeta(nn.Module):
64
+ """
65
+ A modified Snake function which uses separate parameters for the magnitude of the periodic components
66
+ Shape:
67
+ - Input: (B, C, T)
68
+ - Output: (B, C, T), same shape as the input
69
+ Parameters:
70
+ - alpha - trainable parameter that controls frequency
71
+ - beta - trainable parameter that controls magnitude
72
+ References:
73
+ - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
74
+ https://arxiv.org/abs/2006.08195
75
+ Examples:
76
+ >>> a1 = SnakeBeta(256)
77
+ >>> x = torch.randn(256)
78
+ >>> x = a1(x)
79
+ """
80
+
81
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
82
+ """
83
+ Initialization.
84
+ INPUT:
85
+ - in_features: shape of the input
86
+ - alpha - trainable parameter that controls frequency
87
+ - beta - trainable parameter that controls magnitude
88
+ alpha is initialized to 1 by default, higher values = higher-frequency.
89
+ beta is initialized to 1 by default, higher values = higher-magnitude.
90
+ alpha will be trained along with the rest of your model.
91
+ """
92
+ super(SnakeBeta, self).__init__()
93
+ self.in_features = in_features
94
+
95
+ # Initialize alpha
96
+ self.alpha_logscale = alpha_logscale
97
+ if self.alpha_logscale: # Log scale alphas initialized to zeros
98
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
99
+ self.beta = Parameter(torch.zeros(in_features) * alpha)
100
+ else: # Linear scale alphas initialized to ones
101
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
102
+ self.beta = Parameter(torch.ones(in_features) * alpha)
103
+
104
+ self.alpha.requires_grad = alpha_trainable
105
+ self.beta.requires_grad = alpha_trainable
106
+
107
+ self.no_div_by_zero = 0.000000001
108
+
109
+ def forward(self, x):
110
+ """
111
+ Forward pass of the function.
112
+ Applies the function to the input elementwise.
113
+ SnakeBeta ∶= x + 1/b * sin^2 (xa)
114
+ """
115
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # Line up with x to [B, C, T]
116
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
117
+ if self.alpha_logscale:
118
+ alpha = torch.exp(alpha)
119
+ beta = torch.exp(beta)
120
+ x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
121
+
122
+ return x
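A minimal usage sketch of the two activations above (the flat import path is an assumption; inside the package they are imported as `from . import activations`):

```python
# Both activations act elementwise on [B, C, T] tensors with one (alpha, beta)
# pair per channel, so the output shape matches the input shape.
import torch
from activations import Snake, SnakeBeta  # assumed import path

x = torch.randn(2, 8, 100)                     # [B, C, T]
snake = Snake(8, alpha_logscale=False)         # alpha stored on a linear scale, initialized to 1
snakebeta = SnakeBeta(8, alpha_logscale=True)  # alpha/beta stored in log scale, exp'd in forward

y1 = snake(x)      # x + (1/alpha) * sin^2(alpha * x)
y2 = snakebeta(x)  # x + (1/beta)  * sin^2(alpha * x)
assert y1.shape == x.shape and y2.shape == x.shape
```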
BigVGAN/alias_free_activation/cuda/__init__.py ADDED
File without changes
BigVGAN/alias_free_activation/cuda/activation1d.py ADDED
@@ -0,0 +1,69 @@
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from alias_free_activation.torch.resample import UpSample1d, DownSample1d
7
+
8
+ # load fused CUDA kernel: this enables importing anti_alias_activation_cuda
9
+ from alias_free_activation.cuda import load
10
+
11
+ anti_alias_activation_cuda = load.load()
12
+
13
+
14
+ class FusedAntiAliasActivation(torch.autograd.Function):
15
+ """
16
+ Assumes filter size 12, replication padding on upsampling/downsampling, and logscale alpha/beta parameters as inputs.
17
+ The hyperparameters are hard-coded in the kernel to maximize speed.
18
+ NOTE: The fused kernel is incorrect for Activation1d with different hyperparameters.
19
+ """
20
+
21
+ @staticmethod
22
+ def forward(ctx, inputs, up_ftr, down_ftr, alpha, beta):
23
+ activation_results = anti_alias_activation_cuda.forward(inputs, up_ftr, down_ftr, alpha, beta)
24
+
25
+ return activation_results
26
+
27
+ @staticmethod
28
+ def backward(ctx, output_grads):
29
+ raise NotImplementedError
30
+ return output_grads, None, None
31
+
32
+
33
+ class Activation1d(nn.Module):
34
+ def __init__(
35
+ self,
36
+ activation,
37
+ up_ratio: int = 2,
38
+ down_ratio: int = 2,
39
+ up_kernel_size: int = 12,
40
+ down_kernel_size: int = 12,
41
+ fused: bool = True,
42
+ ):
43
+ super().__init__()
44
+ self.up_ratio = up_ratio
45
+ self.down_ratio = down_ratio
46
+ self.act = activation
47
+ self.upsample = UpSample1d(up_ratio, up_kernel_size)
48
+ self.downsample = DownSample1d(down_ratio, down_kernel_size)
49
+
50
+ self.fused = fused # Whether to use fused CUDA kernel or not
51
+
52
+ def forward(self, x):
53
+ if not self.fused:
54
+ x = self.upsample(x)
55
+ x = self.act(x)
56
+ x = self.downsample(x)
57
+ return x
58
+ else:
59
+ if self.act.__class__.__name__ == "Snake":
60
+ beta = self.act.alpha.data # Snake uses same params for alpha and beta
61
+ else:
62
+ beta = self.act.beta.data # Snakebeta uses different params for alpha and beta
63
+ alpha = self.act.alpha.data
64
+ if not self.act.alpha_logscale: # Exp baked into cuda kernel, cancel it out with a log
65
+ alpha = torch.log(alpha)
66
+ beta = torch.log(beta)
67
+
68
+ x = FusedAntiAliasActivation.apply(x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta)
69
+ return x
BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp ADDED
@@ -0,0 +1,23 @@
1
+ /* coding=utf-8
2
+ * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #include <torch/extension.h>
18
+
19
+ extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta);
20
+
21
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
22
+ m.def("forward", &fwd_cuda, "Anti-Alias Activation forward (CUDA)");
23
+ }
BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu ADDED
@@ -0,0 +1,246 @@
1
+ /* coding=utf-8
2
+ * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #include <ATen/ATen.h>
18
+ #include <cuda.h>
19
+ #include <cuda_runtime.h>
20
+ #include <cuda_fp16.h>
21
+ #include <cuda_profiler_api.h>
22
+ #include <ATen/cuda/CUDAContext.h>
23
+ #include <torch/extension.h>
24
+ #include "type_shim.h"
25
+ #include <assert.h>
26
+ #include <cfloat>
27
+ #include <limits>
28
+ #include <stdint.h>
29
+ #include <c10/macros/Macros.h>
30
+
31
+ namespace
32
+ {
33
+ // Hard-coded hyperparameters
34
+ // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
35
+ constexpr int ELEMENTS_PER_LDG_STG = 1; //(WARP_ITERATIONS < 4) ? 1 : 4;
36
+ constexpr int BUFFER_SIZE = 32;
37
+ constexpr int FILTER_SIZE = 12;
38
+ constexpr int HALF_FILTER_SIZE = 6;
39
+ constexpr int UPSAMPLE_REPLICATION_PAD = 5; // 5 on each side, matching torch impl
40
+ constexpr int DOWNSAMPLE_REPLICATION_PAD_LEFT = 5; // matching torch impl
41
+ constexpr int DOWNSAMPLE_REPLICATION_PAD_RIGHT = 6; // matching torch impl
42
+
43
+ template <typename input_t, typename output_t, typename acc_t>
44
+ __global__ void anti_alias_activation_forward(
45
+ output_t *dst,
46
+ const input_t *src,
47
+ const input_t *up_ftr,
48
+ const input_t *down_ftr,
49
+ const input_t *alpha,
50
+ const input_t *beta,
51
+ int batch_size,
52
+ int channels,
53
+ int seq_len)
54
+ {
55
+ // Up and downsample filters
56
+ input_t up_filter[FILTER_SIZE];
57
+ input_t down_filter[FILTER_SIZE];
58
+
59
+ // Load data from global memory including extra indices reserved for replication paddings
60
+ input_t elements[2 * FILTER_SIZE + 2 * BUFFER_SIZE + 2 * UPSAMPLE_REPLICATION_PAD] = {0};
61
+ input_t intermediates[2 * FILTER_SIZE + 2 * BUFFER_SIZE + DOWNSAMPLE_REPLICATION_PAD_LEFT + DOWNSAMPLE_REPLICATION_PAD_RIGHT] = {0};
62
+
63
+ // Output stores downsampled output before writing to dst
64
+ output_t output[BUFFER_SIZE];
65
+
66
+ // blockDim/threadIdx = (128, 1, 1)
67
+ // gridDim/blockIdx = (seq_blocks, channels, batches)
68
+ int block_offset = (blockIdx.x * 128 * BUFFER_SIZE + seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
69
+ int local_offset = threadIdx.x * BUFFER_SIZE;
70
+ int seq_offset = blockIdx.x * 128 * BUFFER_SIZE + local_offset;
71
+
72
+ // intermediate have double the seq_len
73
+ int intermediate_local_offset = threadIdx.x * BUFFER_SIZE * 2;
74
+ int intermediate_seq_offset = blockIdx.x * 128 * BUFFER_SIZE * 2 + intermediate_local_offset;
75
+
76
+ // Get values needed for replication padding before moving pointer
77
+ const input_t *right_most_pntr = src + (seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
78
+ input_t seq_left_most_value = right_most_pntr[0];
79
+ input_t seq_right_most_value = right_most_pntr[seq_len - 1];
80
+
81
+ // Move src and dst pointers
82
+ src += block_offset + local_offset;
83
+ dst += block_offset + local_offset;
84
+
85
+ // Alpha and beta values for snake activations. Applies exp by default
86
+ alpha = alpha + blockIdx.y;
87
+ input_t alpha_val = expf(alpha[0]);
88
+ beta = beta + blockIdx.y;
89
+ input_t beta_val = expf(beta[0]);
90
+
91
+ #pragma unroll
92
+ for (int it = 0; it < FILTER_SIZE; it += 1)
93
+ {
94
+ up_filter[it] = up_ftr[it];
95
+ down_filter[it] = down_ftr[it];
96
+ }
97
+
98
+ // Apply replication padding for upsampling, matching torch impl
99
+ #pragma unroll
100
+ for (int it = -HALF_FILTER_SIZE; it < BUFFER_SIZE + HALF_FILTER_SIZE; it += 1)
101
+ {
102
+ int element_index = seq_offset + it; // index for element
103
+ if ((element_index < 0) && (element_index >= -UPSAMPLE_REPLICATION_PAD))
104
+ {
105
+ elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_left_most_value;
106
+ }
107
+ if ((element_index >= seq_len) && (element_index < seq_len + UPSAMPLE_REPLICATION_PAD))
108
+ {
109
+ elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_right_most_value;
110
+ }
111
+ if ((element_index >= 0) && (element_index < seq_len))
112
+ {
113
+ elements[2 * (HALF_FILTER_SIZE + it)] = 2 * src[it];
114
+ }
115
+ }
116
+
117
+ // Apply upsampling strided convolution and write to intermediates. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT for replication padding of the downsampling conv later
118
+ #pragma unroll
119
+ for (int it = 0; it < (2 * BUFFER_SIZE + 2 * FILTER_SIZE); it += 1)
120
+ {
121
+ input_t acc = 0.0;
122
+ int element_index = intermediate_seq_offset + it; // index for intermediate
123
+ #pragma unroll
124
+ for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
125
+ {
126
+ if ((element_index + f_idx) >= 0)
127
+ {
128
+ acc += up_filter[f_idx] * elements[it + f_idx];
129
+ }
130
+ }
131
+ intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] = acc;
132
+ }
133
+
134
+ // Apply activation function. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT and DOWNSAMPLE_REPLICATION_PAD_RIGHT for replication padding of the downsampling conv later
135
+ double no_div_by_zero = 0.000000001;
136
+ #pragma unroll
137
+ for (int it = 0; it < 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it += 1)
138
+ {
139
+ intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] += (1.0 / (beta_val + no_div_by_zero)) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val);
140
+ }
141
+
142
+ // Apply replication padding before downsampling conv from intermediates
143
+ #pragma unroll
144
+ for (int it = 0; it < DOWNSAMPLE_REPLICATION_PAD_LEFT; it += 1)
145
+ {
146
+ intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT];
147
+ }
148
+ #pragma unroll
149
+ for (int it = DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it < DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE + DOWNSAMPLE_REPLICATION_PAD_RIGHT; it += 1)
150
+ {
151
+ intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE - 1];
152
+ }
153
+
154
+ // Apply downsample strided convolution (assuming stride=2) from intermediates
155
+ #pragma unroll
156
+ for (int it = 0; it < BUFFER_SIZE; it += 1)
157
+ {
158
+ input_t acc = 0.0;
159
+ #pragma unroll
160
+ for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
161
+ {
162
+ // Add constant DOWNSAMPLE_REPLICATION_PAD_RIGHT to match torch implementation
163
+ acc += down_filter[f_idx] * intermediates[it * 2 + f_idx + DOWNSAMPLE_REPLICATION_PAD_RIGHT];
164
+ }
165
+ output[it] = acc;
166
+ }
167
+
168
+ // Write output to dst
169
+ #pragma unroll
170
+ for (int it = 0; it < BUFFER_SIZE; it += ELEMENTS_PER_LDG_STG)
171
+ {
172
+ int element_index = seq_offset + it;
173
+ if (element_index < seq_len)
174
+ {
175
+ dst[it] = output[it];
176
+ }
177
+ }
178
+
179
+ }
180
+
181
+ template <typename input_t, typename output_t, typename acc_t>
182
+ void dispatch_anti_alias_activation_forward(
183
+ output_t *dst,
184
+ const input_t *src,
185
+ const input_t *up_ftr,
186
+ const input_t *down_ftr,
187
+ const input_t *alpha,
188
+ const input_t *beta,
189
+ int batch_size,
190
+ int channels,
191
+ int seq_len)
192
+ {
193
+ if (seq_len == 0)
194
+ {
195
+ return;
196
+ }
197
+ else
198
+ {
199
+ // Use 128 threads per block to maximize GPU utilization
200
+ constexpr int threads_per_block = 128;
201
+ constexpr int seq_len_per_block = 4096;
202
+ int blocks_per_seq_len = (seq_len + seq_len_per_block - 1) / seq_len_per_block;
203
+ dim3 blocks(blocks_per_seq_len, channels, batch_size);
204
+ dim3 threads(threads_per_block, 1, 1);
205
+
206
+ anti_alias_activation_forward<input_t, output_t, acc_t>
207
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, up_ftr, down_ftr, alpha, beta, batch_size, channels, seq_len);
208
+ }
209
+ }
210
+ }
211
+
212
+ extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta)
213
+ {
214
+ // Input is a 3d tensor with dimensions [batches, channels, seq_len]
215
+ const int batches = input.size(0);
216
+ const int channels = input.size(1);
217
+ const int seq_len = input.size(2);
218
+
219
+ // Output
220
+ auto act_options = input.options().requires_grad(false);
221
+
222
+ torch::Tensor anti_alias_activation_results =
223
+ torch::empty({batches, channels, seq_len}, act_options);
224
+
225
+ void *input_ptr = static_cast<void *>(input.data_ptr());
226
+ void *up_filter_ptr = static_cast<void *>(up_filter.data_ptr());
227
+ void *down_filter_ptr = static_cast<void *>(down_filter.data_ptr());
228
+ void *alpha_ptr = static_cast<void *>(alpha.data_ptr());
229
+ void *beta_ptr = static_cast<void *>(beta.data_ptr());
230
+ void *anti_alias_activation_results_ptr = static_cast<void *>(anti_alias_activation_results.data_ptr());
231
+
232
+ DISPATCH_FLOAT_HALF_AND_BFLOAT(
233
+ input.scalar_type(),
234
+ "dispatch anti alias activation_forward",
235
+ dispatch_anti_alias_activation_forward<scalar_t, scalar_t, float>(
236
+ reinterpret_cast<scalar_t *>(anti_alias_activation_results_ptr),
237
+ reinterpret_cast<const scalar_t *>(input_ptr),
238
+ reinterpret_cast<const scalar_t *>(up_filter_ptr),
239
+ reinterpret_cast<const scalar_t *>(down_filter_ptr),
240
+ reinterpret_cast<const scalar_t *>(alpha_ptr),
241
+ reinterpret_cast<const scalar_t *>(beta_ptr),
242
+ batches,
243
+ channels,
244
+ seq_len););
245
+ return anti_alias_activation_results;
246
+ }
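For reference, the computation this kernel fuses corresponds to the non-fused PyTorch path: 2x kaiser-sinc upsampling with a 12-tap filter, the snake nonlinearity with exponentiated log-scale alpha/beta, then 2x downsampling. A rough sketch of that equivalent, assuming the torch modules from `alias_free_activation/torch` are importable:

```python
# Rough PyTorch equivalent of the fused CUDA kernel above (a sketch for clarity,
# not a bit-exact reproduction of the kernel's padding/precision details).
import torch
from alias_free_activation.torch.resample import UpSample1d, DownSample1d  # assumed import path

def reference_anti_alias_snake(x, log_alpha, log_beta):
    # x: [B, C, T]; log_alpha / log_beta: [C], exponentiated here just as the
    # kernel applies expf to the log-scale parameters.
    up, down = UpSample1d(ratio=2, kernel_size=12), DownSample1d(ratio=2, kernel_size=12)
    alpha = torch.exp(log_alpha)[None, :, None]
    beta = torch.exp(log_beta)[None, :, None]
    y = up(x)                                                    # 2x upsample, replication padding
    y = y + (1.0 / (beta + 1e-9)) * torch.sin(y * alpha) ** 2    # snake / snakebeta nonlinearity
    return down(y)                                               # low-pass + 2x decimation

out = reference_anti_alias_snake(torch.randn(1, 4, 256), torch.zeros(4), torch.zeros(4))
print(out.shape)  # torch.Size([1, 4, 256])
```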
BigVGAN/alias_free_activation/cuda/build/_ ADDED
@@ -0,0 +1 @@
1
+
BigVGAN/alias_free_activation/cuda/compat.h ADDED
@@ -0,0 +1,29 @@
1
+ /* coding=utf-8
2
+ * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ /* This code is copied from NVIDIA apex:
18
+ * https://github.com/NVIDIA/apex
19
+ * with minor changes. */
20
+
21
+ #ifndef TORCH_CHECK
22
+ #define TORCH_CHECK AT_CHECK
23
+ #endif
24
+
25
+ #ifdef VERSION_GE_1_3
26
+ #define DATA_PTR data_ptr
27
+ #else
28
+ #define DATA_PTR data
29
+ #endif
BigVGAN/alias_free_activation/cuda/load.py ADDED
@@ -0,0 +1,82 @@
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ import os
5
+ import pathlib
6
+ import subprocess
7
+
8
+ from torch.utils import cpp_extension
9
+
10
+ """
11
+ Setting this param to a list has a problem of generating different compilation commands (with a different order of architectures) and leading to recompilation of fused kernels.
12
+ Set it to an empty string to avoid recompilation and assign arch flags explicitly in extra_cuda_cflags below.
13
+ """
14
+ os.environ["TORCH_CUDA_ARCH_LIST"] = ""
15
+
16
+
17
+ def load():
18
+ # Check if cuda 11 is installed for compute capability 8.0
19
+ cc_flag = []
20
+ _, bare_metal_major, _ = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
21
+ if int(bare_metal_major) >= 11:
22
+ cc_flag.append("-gencode")
23
+ cc_flag.append("arch=compute_80,code=sm_80")
24
+
25
+ # Build path
26
+ srcpath = pathlib.Path(__file__).parent.absolute()
27
+ buildpath = srcpath / "build"
28
+ _create_build_dir(buildpath)
29
+
30
+ # Helper function to build the kernels.
31
+ def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
32
+ return cpp_extension.load(
33
+ name=name,
34
+ sources=sources,
35
+ build_directory=buildpath,
36
+ extra_cflags=[
37
+ "-O3",
38
+ ],
39
+ extra_cuda_cflags=[
40
+ "-O3",
41
+ "-gencode",
42
+ "arch=compute_70,code=sm_70",
43
+ "--use_fast_math",
44
+ ]
45
+ + extra_cuda_flags
46
+ + cc_flag,
47
+ verbose=True,
48
+ )
49
+
50
+ extra_cuda_flags = [
51
+ "-U__CUDA_NO_HALF_OPERATORS__",
52
+ "-U__CUDA_NO_HALF_CONVERSIONS__",
53
+ "--expt-relaxed-constexpr",
54
+ "--expt-extended-lambda",
55
+ ]
56
+
57
+ sources = [
58
+ srcpath / "anti_alias_activation.cpp",
59
+ srcpath / "anti_alias_activation_cuda.cu",
60
+ ]
61
+ anti_alias_activation_cuda = _cpp_extention_load_helper("anti_alias_activation_cuda", sources, extra_cuda_flags)
62
+
63
+ return anti_alias_activation_cuda
64
+
65
+
66
+ def _get_cuda_bare_metal_version(cuda_dir):
67
+ raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
68
+ output = raw_output.split()
69
+ release_idx = output.index("release") + 1
70
+ release = output[release_idx].split(".")
71
+ bare_metal_major = release[0]
72
+ bare_metal_minor = release[1][0]
73
+
74
+ return raw_output, bare_metal_major, bare_metal_minor
75
+
76
+
77
+ def _create_build_dir(buildpath):
78
+ try:
79
+ os.mkdir(buildpath)
80
+ except OSError:
81
+ if not os.path.isdir(buildpath):
82
+ print(f"Creation of the build directory {buildpath} failed")
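A short sketch of how this loader is used (the import path is an assumption; building requires nvcc and ninja matching the local PyTorch CUDA build):

```python
# JIT-compiles the extension into ./build on first call, then exposes the single
# op bound in anti_alias_activation.cpp.
from alias_free_activation.cuda import load  # assumed import path

anti_alias_activation_cuda = load.load()
# The compiled module exposes:
#   anti_alias_activation_cuda.forward(x, up_filter, down_filter, log_alpha, log_beta)
# which is wrapped by FusedAntiAliasActivation in activation1d.py.
```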
BigVGAN/alias_free_activation/cuda/type_shim.h ADDED
@@ -0,0 +1,92 @@
1
+ /* coding=utf-8
2
+ * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #include <ATen/ATen.h>
18
+ #include "compat.h"
19
+
20
+ #define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, ...) \
21
+ switch (TYPE) \
22
+ { \
23
+ case at::ScalarType::Float: \
24
+ { \
25
+ using scalar_t = float; \
26
+ __VA_ARGS__; \
27
+ break; \
28
+ } \
29
+ case at::ScalarType::Half: \
30
+ { \
31
+ using scalar_t = at::Half; \
32
+ __VA_ARGS__; \
33
+ break; \
34
+ } \
35
+ case at::ScalarType::BFloat16: \
36
+ { \
37
+ using scalar_t = at::BFloat16; \
38
+ __VA_ARGS__; \
39
+ break; \
40
+ } \
41
+ default: \
42
+ AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
43
+ }
44
+
45
+ #define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
46
+ switch (TYPEIN) \
47
+ { \
48
+ case at::ScalarType::Float: \
49
+ { \
50
+ using scalar_t_in = float; \
51
+ switch (TYPEOUT) \
52
+ { \
53
+ case at::ScalarType::Float: \
54
+ { \
55
+ using scalar_t_out = float; \
56
+ __VA_ARGS__; \
57
+ break; \
58
+ } \
59
+ case at::ScalarType::Half: \
60
+ { \
61
+ using scalar_t_out = at::Half; \
62
+ __VA_ARGS__; \
63
+ break; \
64
+ } \
65
+ case at::ScalarType::BFloat16: \
66
+ { \
67
+ using scalar_t_out = at::BFloat16; \
68
+ __VA_ARGS__; \
69
+ break; \
70
+ } \
71
+ default: \
72
+ AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
73
+ } \
74
+ break; \
75
+ } \
76
+ case at::ScalarType::Half: \
77
+ { \
78
+ using scalar_t_in = at::Half; \
79
+ using scalar_t_out = at::Half; \
80
+ __VA_ARGS__; \
81
+ break; \
82
+ } \
83
+ case at::ScalarType::BFloat16: \
84
+ { \
85
+ using scalar_t_in = at::BFloat16; \
86
+ using scalar_t_out = at::BFloat16; \
87
+ __VA_ARGS__; \
88
+ break; \
89
+ } \
90
+ default: \
91
+ AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
92
+ }
BigVGAN/alias_free_activation/torch/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ from .filter import *
5
+ from .resample import *
6
+ from .act import *
BigVGAN/alias_free_activation/torch/act.py ADDED
@@ -0,0 +1,30 @@
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch.nn as nn
5
+ from .resample import UpSample1d, DownSample1d
6
+
7
+
8
+ class Activation1d(nn.Module):
9
+ def __init__(
10
+ self,
11
+ activation,
12
+ up_ratio: int = 2,
13
+ down_ratio: int = 2,
14
+ up_kernel_size: int = 12,
15
+ down_kernel_size: int = 12,
16
+ ):
17
+ super().__init__()
18
+ self.up_ratio = up_ratio
19
+ self.down_ratio = down_ratio
20
+ self.act = activation
21
+ self.upsample = UpSample1d(up_ratio, up_kernel_size)
22
+ self.downsample = DownSample1d(down_ratio, down_kernel_size)
23
+
24
+ # x: [B,C,T]
25
+ def forward(self, x):
26
+ x = self.upsample(x)
27
+ x = self.act(x)
28
+ x = self.downsample(x)
29
+
30
+ return x
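A minimal sketch of how this wrapper is used (import paths assumed, with the Snake activation from `activations.py`):

```python
# Anti-aliased activation: 2x upsample -> pointwise nonlinearity -> 2x downsample.
# The sequence length is preserved, but harmonics created by the nonlinearity are
# band-limited before decimation.
import torch
from alias_free_activation.torch.act import Activation1d  # assumed import path
from activations import Snake                             # assumed import path

act = Activation1d(activation=Snake(16, alpha_logscale=True))
x = torch.randn(1, 16, 2048)  # [B, C, T]
y = act(x)
assert y.shape == x.shape
```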
BigVGAN/alias_free_activation/torch/filter.py ADDED
@@ -0,0 +1,99 @@
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import math
8
+
9
+ if "sinc" in dir(torch):
10
+ sinc = torch.sinc
11
+ else:
12
+ # This code is adopted from adefossez's julius.core.sinc under the MIT License
13
+ # https://adefossez.github.io/julius/julius/core.html
14
+ # LICENSE is in incl_licenses directory.
15
+ def sinc(x: torch.Tensor):
16
+ """
17
+ Implementation of sinc, i.e. sin(pi * x) / (pi * x)
18
+ __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
19
+ """
20
+ return torch.where(
21
+ x == 0,
22
+ torch.tensor(1.0, device=x.device, dtype=x.dtype),
23
+ torch.sin(math.pi * x) / math.pi / x,
24
+ )
25
+
26
+
27
+ # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
28
+ # https://adefossez.github.io/julius/julius/lowpass.html
29
+ # LICENSE is in incl_licenses directory.
30
+ def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
31
+ even = kernel_size % 2 == 0
32
+ half_size = kernel_size // 2
33
+
34
+ # For kaiser window
35
+ delta_f = 4 * half_width
36
+ A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
37
+ if A > 50.0:
38
+ beta = 0.1102 * (A - 8.7)
39
+ elif A >= 21.0:
40
+ beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
41
+ else:
42
+ beta = 0.0
43
+ window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
44
+
45
+ # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
46
+ if even:
47
+ time = torch.arange(-half_size, half_size) + 0.5
48
+ else:
49
+ time = torch.arange(kernel_size) - half_size
50
+ if cutoff == 0:
51
+ filter_ = torch.zeros_like(time)
52
+ else:
53
+ filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
54
+ """
55
+ Normalize filter to have sum = 1, otherwise we will have a small leakage of the constant component in the input signal.
56
+ """
57
+ filter_ /= filter_.sum()
58
+ filter = filter_.view(1, 1, kernel_size)
59
+
60
+ return filter
61
+
62
+
63
+ class LowPassFilter1d(nn.Module):
64
+ def __init__(
65
+ self,
66
+ cutoff=0.5,
67
+ half_width=0.6,
68
+ stride: int = 1,
69
+ padding: bool = True,
70
+ padding_mode: str = "replicate",
71
+ kernel_size: int = 12,
72
+ ):
73
+ """
74
+ kernel_size should be an even number for the StyleGAN3 setup; in this implementation, an odd number is also possible.
75
+ """
76
+ super().__init__()
77
+ if cutoff < -0.0:
78
+ raise ValueError("Minimum cutoff must be larger than zero.")
79
+ if cutoff > 0.5:
80
+ raise ValueError("A cutoff above 0.5 does not make sense.")
81
+ self.kernel_size = kernel_size
82
+ self.even = kernel_size % 2 == 0
83
+ self.pad_left = kernel_size // 2 - int(self.even)
84
+ self.pad_right = kernel_size // 2
85
+ self.stride = stride
86
+ self.padding = padding
87
+ self.padding_mode = padding_mode
88
+ filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
89
+ self.register_buffer("filter", filter)
90
+
91
+ # Input [B, C, T]
92
+ def forward(self, x):
93
+ _, C, _ = x.shape
94
+
95
+ if self.padding:
96
+ x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
97
+ out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
98
+
99
+ return out
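A small check of the filter construction above (import path assumed): the returned kernel has shape `[1, 1, kernel_size]` and is normalized to sum to 1, so the DC component of the input passes through unchanged, and `LowPassFilter1d` combines the filtering with striding.

```python
import torch
from alias_free_activation.torch.filter import kaiser_sinc_filter1d, LowPassFilter1d  # assumed path

filt = kaiser_sinc_filter1d(cutoff=0.25, half_width=0.3, kernel_size=12)
print(filt.shape)         # torch.Size([1, 1, 12])
print(float(filt.sum()))  # ~1.0 after normalization

lpf = LowPassFilter1d(cutoff=0.25, half_width=0.3, stride=2, kernel_size=12)
y = lpf(torch.randn(1, 3, 1000))  # [1, 3, 500]: low-pass filter, then decimate by 2
```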
BigVGAN/alias_free_activation/torch/resample.py ADDED
@@ -0,0 +1,48 @@
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch.nn as nn
5
+ from torch.nn import functional as F
6
+ from .filter import LowPassFilter1d
7
+ from .filter import kaiser_sinc_filter1d
8
+
9
+
10
+ class UpSample1d(nn.Module):
11
+ def __init__(self, ratio=2, kernel_size=None):
12
+ super().__init__()
13
+ self.ratio = ratio
14
+ self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
15
+ self.stride = ratio
16
+ self.pad = self.kernel_size // ratio - 1
17
+ self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
18
+ self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
19
+ filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size)
20
+ self.register_buffer("filter", filter)
21
+
22
+ # x: [B, C, T]
23
+ def forward(self, x):
24
+ _, C, _ = x.shape
25
+
26
+ x = F.pad(x, (self.pad, self.pad), mode="replicate")
27
+ x = self.ratio * F.conv_transpose1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
28
+ x = x[..., self.pad_left : -self.pad_right]
29
+
30
+ return x
31
+
32
+
33
+ class DownSample1d(nn.Module):
34
+ def __init__(self, ratio=2, kernel_size=None):
35
+ super().__init__()
36
+ self.ratio = ratio
37
+ self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
38
+ self.lowpass = LowPassFilter1d(
39
+ cutoff=0.5 / ratio,
40
+ half_width=0.6 / ratio,
41
+ stride=ratio,
42
+ kernel_size=self.kernel_size,
43
+ )
44
+
45
+ def forward(self, x):
46
+ xx = self.lowpass(x)
47
+
48
+ return xx
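A shape-level sketch (import path assumed): `UpSample1d` multiplies the time axis by `ratio` and `DownSample1d` divides it by `ratio`, both using the kaiser-windowed sinc filter above with replication padding.

```python
import torch
from alias_free_activation.torch.resample import UpSample1d, DownSample1d  # assumed path

x = torch.randn(1, 4, 1000)
up, down = UpSample1d(ratio=2), DownSample1d(ratio=2)  # default kernel_size is 12 for ratio=2

x_up = up(x)        # [1, 4, 2000]
x_rec = down(x_up)  # [1, 4, 1000]
assert x_up.shape[-1] == 2 * x.shape[-1]
assert x_rec.shape[-1] == x.shape[-1]
```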
BigVGAN/bigvgan.py ADDED
@@ -0,0 +1,461 @@
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ import os
8
+ import json
9
+ from pathlib import Path
10
+ from typing import Optional, Union, Dict
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from torch.nn import Conv1d, ConvTranspose1d
15
+ from torch.nn.utils import weight_norm, remove_weight_norm
16
+
17
+ from . import activations
18
+ from .utils0 import init_weights, get_padding
19
+ from .alias_free_activation.torch.act import Activation1d as TorchActivation1d
20
+ from .env import AttrDict
21
+
22
+ from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
23
+
24
+
25
+ def load_hparams_from_json(path) -> AttrDict:
26
+ with open(path) as f:
27
+ data = f.read()
28
+ return AttrDict(json.loads(data))
29
+
30
+
31
+ class AMPBlock1(torch.nn.Module):
32
+ """
33
+ AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.
34
+ AMPBlock1 has additional self.convs2 that contains additional Conv1d layers with a fixed dilation=1 followed by each layer in self.convs1
35
+
36
+ Args:
37
+ h (AttrDict): Hyperparameters.
38
+ channels (int): Number of convolution channels.
39
+ kernel_size (int): Size of the convolution kernel. Default is 3.
40
+ dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5).
41
+ activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None.
42
+ """
43
+
44
+ def __init__(
45
+ self,
46
+ h: AttrDict,
47
+ channels: int,
48
+ kernel_size: int = 3,
49
+ dilation: tuple = (1, 3, 5),
50
+ activation: str = None,
51
+ ):
52
+ super().__init__()
53
+
54
+ self.h = h
55
+
56
+ self.convs1 = nn.ModuleList(
57
+ [
58
+ weight_norm(
59
+ Conv1d(
60
+ channels,
61
+ channels,
62
+ kernel_size,
63
+ stride=1,
64
+ dilation=d,
65
+ padding=get_padding(kernel_size, d),
66
+ )
67
+ )
68
+ for d in dilation
69
+ ]
70
+ )
71
+ self.convs1.apply(init_weights)
72
+
73
+ self.convs2 = nn.ModuleList(
74
+ [
75
+ weight_norm(
76
+ Conv1d(
77
+ channels,
78
+ channels,
79
+ kernel_size,
80
+ stride=1,
81
+ dilation=1,
82
+ padding=get_padding(kernel_size, 1),
83
+ )
84
+ )
85
+ for _ in range(len(dilation))
86
+ ]
87
+ )
88
+ self.convs2.apply(init_weights)
89
+
90
+ self.num_layers = len(self.convs1) + len(self.convs2) # Total number of conv layers
91
+
92
+ # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
93
+ if self.h.get("use_cuda_kernel", False):
94
+ from .alias_free_activation.cuda.activation1d import (
95
+ Activation1d as CudaActivation1d,
96
+ )
97
+
98
+ Activation1d = CudaActivation1d
99
+ else:
100
+ Activation1d = TorchActivation1d
101
+
102
+ # Activation functions
103
+ if activation == "snake":
104
+ self.activations = nn.ModuleList(
105
+ [
106
+ Activation1d(activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
107
+ for _ in range(self.num_layers)
108
+ ]
109
+ )
110
+ elif activation == "snakebeta":
111
+ self.activations = nn.ModuleList(
112
+ [
113
+ Activation1d(activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
114
+ for _ in range(self.num_layers)
115
+ ]
116
+ )
117
+ else:
118
+ raise NotImplementedError(
119
+ "activation incorrectly specified. check the config file and look for 'activation'."
120
+ )
121
+
122
+ def forward(self, x):
123
+ acts1, acts2 = self.activations[::2], self.activations[1::2]
124
+ for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
125
+ xt = a1(x)
126
+ xt = c1(xt)
127
+ xt = a2(xt)
128
+ xt = c2(xt)
129
+ x = xt + x
130
+
131
+ return x
132
+
133
+ def remove_weight_norm(self):
134
+ for l in self.convs1:
135
+ remove_weight_norm(l)
136
+ for l in self.convs2:
137
+ remove_weight_norm(l)
138
+
139
+
140
+ class AMPBlock2(torch.nn.Module):
141
+ """
142
+ AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.
143
+ Unlike AMPBlock1, AMPBlock2 does not contain extra Conv1d layers with fixed dilation=1
144
+
145
+ Args:
146
+ h (AttrDict): Hyperparameters.
147
+ channels (int): Number of convolution channels.
148
+ kernel_size (int): Size of the convolution kernel. Default is 3.
149
+ dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5).
150
+ activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None.
151
+ """
152
+
153
+ def __init__(
154
+ self,
155
+ h: AttrDict,
156
+ channels: int,
157
+ kernel_size: int = 3,
158
+ dilation: tuple = (1, 3, 5),
159
+ activation: str = None,
160
+ ):
161
+ super().__init__()
162
+
163
+ self.h = h
164
+
165
+ self.convs = nn.ModuleList(
166
+ [
167
+ weight_norm(
168
+ Conv1d(
169
+ channels,
170
+ channels,
171
+ kernel_size,
172
+ stride=1,
173
+ dilation=d,
174
+ padding=get_padding(kernel_size, d),
175
+ )
176
+ )
177
+ for d in dilation
178
+ ]
179
+ )
180
+ self.convs.apply(init_weights)
181
+
182
+ self.num_layers = len(self.convs) # Total number of conv layers
183
+
184
+ # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
185
+ if self.h.get("use_cuda_kernel", False):
186
+ from .alias_free_activation.cuda.activation1d import (
187
+ Activation1d as CudaActivation1d,
188
+ )
189
+
190
+ Activation1d = CudaActivation1d
191
+ else:
192
+ Activation1d = TorchActivation1d
193
+
194
+ # Activation functions
195
+ if activation == "snake":
196
+ self.activations = nn.ModuleList(
197
+ [
198
+ Activation1d(activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
199
+ for _ in range(self.num_layers)
200
+ ]
201
+ )
202
+ elif activation == "snakebeta":
203
+ self.activations = nn.ModuleList(
204
+ [
205
+ Activation1d(activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
206
+ for _ in range(self.num_layers)
207
+ ]
208
+ )
209
+ else:
210
+ raise NotImplementedError(
211
+ "activation incorrectly specified. check the config file and look for 'activation'."
212
+ )
213
+
214
+ def forward(self, x):
215
+ for c, a in zip(self.convs, self.activations):
216
+ xt = a(x)
217
+ xt = c(xt)
218
+ x = xt + x
219
+ return x
220
+
221
+ def remove_weight_norm(self):
222
+ for l in self.convs:
223
+ remove_weight_norm(l)
224
+
225
+
226
+ class BigVGAN(
227
+ torch.nn.Module,
228
+ PyTorchModelHubMixin,
229
+ # library_name="bigvgan",
230
+ # repo_url="https://github.com/NVIDIA/BigVGAN",
231
+ # docs_url="https://github.com/NVIDIA/BigVGAN/blob/main/README.md",
232
+ # pipeline_tag="audio-to-audio",
233
+ # license="mit",
234
+ # tags=["neural-vocoder", "audio-generation", "arxiv:2206.04658"],
235
+ ):
236
+ """
237
+ BigVGAN is a neural vocoder model that applies anti-aliased periodic activation for residual blocks (resblocks).
238
+ New in BigVGAN-v2: it can optionally use optimized CUDA kernels for AMP (anti-aliased multi-periodicity) blocks.
239
+
240
+ Args:
241
+ h (AttrDict): Hyperparameters.
242
+ use_cuda_kernel (bool): If set to True, loads optimized CUDA kernels for AMP. This should be used for inference only, as training is not supported with CUDA kernels.
243
+
244
+ Note:
245
+ - The `use_cuda_kernel` parameter should be used for inference only, as training with CUDA kernels is not supported.
246
+ - Ensure that the activation function is correctly specified in the hyperparameters (h.activation).
247
+ """
248
+
249
+ def __init__(self, h: AttrDict, use_cuda_kernel: bool = False):
250
+ super().__init__()
251
+ self.h = h
252
+ self.h["use_cuda_kernel"] = use_cuda_kernel
253
+
254
+ # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
255
+ if self.h.get("use_cuda_kernel", False):
256
+ from .alias_free_activation.cuda.activation1d import (
257
+ Activation1d as CudaActivation1d,
258
+ )
259
+
260
+ Activation1d = CudaActivation1d
261
+ else:
262
+ Activation1d = TorchActivation1d
263
+
264
+ self.num_kernels = len(h.resblock_kernel_sizes)
265
+ self.num_upsamples = len(h.upsample_rates)
266
+
267
+ # Pre-conv
268
+ self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))
269
+
270
+ # Define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
271
+ if h.resblock == "1":
272
+ resblock_class = AMPBlock1
273
+ elif h.resblock == "2":
274
+ resblock_class = AMPBlock2
275
+ else:
276
+ raise ValueError(f"Incorrect resblock class specified in hyperparameters. Got {h.resblock}")
277
+
278
+ # Transposed conv-based upsamplers. Does not apply anti-aliasing
279
+ self.ups = nn.ModuleList()
280
+ for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
281
+ self.ups.append(
282
+ nn.ModuleList(
283
+ [
284
+ weight_norm(
285
+ ConvTranspose1d(
286
+ h.upsample_initial_channel // (2**i),
287
+ h.upsample_initial_channel // (2 ** (i + 1)),
288
+ k,
289
+ u,
290
+ padding=(k - u) // 2,
291
+ )
292
+ )
293
+ ]
294
+ )
295
+ )
296
+
297
+ # Residual blocks using anti-aliased multi-periodicity composition modules (AMP)
298
+ self.resblocks = nn.ModuleList()
299
+ for i in range(len(self.ups)):
300
+ ch = h.upsample_initial_channel // (2 ** (i + 1))
301
+ for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
302
+ self.resblocks.append(resblock_class(h, ch, k, d, activation=h.activation))
303
+
304
+ # Post-conv
305
+ activation_post = (
306
+ activations.Snake(ch, alpha_logscale=h.snake_logscale)
307
+ if h.activation == "snake"
308
+ else (activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale) if h.activation == "snakebeta" else None)
309
+ )
310
+ if activation_post is None:
311
+ raise NotImplementedError(
312
+ "activation incorrectly specified. check the config file and look for 'activation'."
313
+ )
314
+
315
+ self.activation_post = Activation1d(activation=activation_post)
316
+
317
+ # Whether to use bias for the final conv_post. Default to True for backward compatibility
318
+ self.use_bias_at_final = h.get("use_bias_at_final", True)
319
+ self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final))
320
+
321
+ # Weight initialization
322
+ for i in range(len(self.ups)):
323
+ self.ups[i].apply(init_weights)
324
+ self.conv_post.apply(init_weights)
325
+
326
+ # Final tanh activation. Defaults to True for backward compatibility
327
+ self.use_tanh_at_final = h.get("use_tanh_at_final", True)
328
+
329
+ def forward(self, x):
330
+ # Pre-conv
331
+ x = self.conv_pre(x)
332
+
333
+ for i in range(self.num_upsamples):
334
+ # Upsampling
335
+ for i_up in range(len(self.ups[i])):
336
+ x = self.ups[i][i_up](x)
337
+ # AMP blocks
338
+ xs = None
339
+ for j in range(self.num_kernels):
340
+ if xs is None:
341
+ xs = self.resblocks[i * self.num_kernels + j](x)
342
+ else:
343
+ xs += self.resblocks[i * self.num_kernels + j](x)
344
+ x = xs / self.num_kernels
345
+
346
+ # Post-conv
347
+ x = self.activation_post(x)
348
+ x = self.conv_post(x)
349
+ # Final tanh activation
350
+ if self.use_tanh_at_final:
351
+ x = torch.tanh(x)
352
+ else:
353
+ x = torch.clamp(x, min=-1.0, max=1.0) # Bound the output to [-1, 1]
354
+
355
+ return x
356
+
357
+ def remove_weight_norm(self):
358
+ try:
359
+ # print("Removing weight norm...")
360
+ for l in self.ups:
361
+ for l_i in l:
362
+ remove_weight_norm(l_i)
363
+ for l in self.resblocks:
364
+ l.remove_weight_norm()
365
+ remove_weight_norm(self.conv_pre)
366
+ remove_weight_norm(self.conv_post)
367
+ except ValueError:
368
+ print("[INFO] Model already removed weight norm. Skipping!")
369
+ pass
370
+
371
+ # Additional methods for huggingface_hub support
372
+ def _save_pretrained(self, save_directory: Path) -> None:
373
+ """Save weights and config.json from a Pytorch model to a local directory."""
374
+
375
+ model_path = save_directory / "bigvgan_generator.pt"
376
+ torch.save({"generator": self.state_dict()}, model_path)
377
+
378
+ config_path = save_directory / "config.json"
379
+ with open(config_path, "w") as config_file:
380
+ json.dump(self.h, config_file, indent=4)
381
+
382
+ @classmethod
383
+ def _from_pretrained(
384
+ cls,
385
+ *,
386
+ model_id: str,
387
+ revision: str,
388
+ cache_dir: str,
389
+ force_download: bool,
390
+ proxies: Optional[Dict],
391
+ resume_download: bool,
392
+ local_files_only: bool,
393
+ token: Union[str, bool, None],
394
+ map_location: str = "cpu", # Additional argument
395
+ strict: bool = False, # Additional argument
396
+ use_cuda_kernel: bool = False,
397
+ **model_kwargs,
398
+ ):
399
+ """Load Pytorch pretrained weights and return the loaded model."""
400
+
401
+ # Download and load hyperparameters (h) used by BigVGAN
402
+ if os.path.isdir(model_id):
403
+ # print("Loading config.json from local directory")
404
+ config_file = os.path.join(model_id, "config.json")
405
+ else:
406
+ config_file = hf_hub_download(
407
+ repo_id=model_id,
408
+ filename="config.json",
409
+ revision=revision,
410
+ cache_dir=cache_dir,
411
+ force_download=force_download,
412
+ proxies=proxies,
413
+ resume_download=resume_download,
414
+ token=token,
415
+ local_files_only=local_files_only,
416
+ )
417
+ h = load_hparams_from_json(config_file)
418
+
419
+ # instantiate BigVGAN using h
420
+ if use_cuda_kernel:
421
+ print(
422
+ "[WARNING] You have specified use_cuda_kernel=True during BigVGAN.from_pretrained(). Only inference is supported (training is not implemented)!"
423
+ )
424
+ print(
425
+ "[WARNING] You need nvcc and ninja installed on your system, matching the CUDA toolkit your PyTorch build uses, in order to build the kernel. If not, the model will fail to initialize or will generate incorrect waveforms!"
426
+ )
427
+ print(
428
+ "[WARNING] For detail, see the official GitHub repository: https://github.com/NVIDIA/BigVGAN?tab=readme-ov-file#using-custom-cuda-kernel-for-synthesis"
429
+ )
430
+ model = cls(h, use_cuda_kernel=use_cuda_kernel)
431
+
432
+ # Download and load pretrained generator weight
433
+ if os.path.isdir(model_id):
434
+ # print("Loading weights from local directory")
435
+ model_file = os.path.join(model_id, "bigvgan_generator.pt")
436
+ else:
437
+ # print(f"Loading weights from {model_id}")
438
+ model_file = hf_hub_download(
439
+ repo_id=model_id,
440
+ filename="bigvgan_generator.pt",
441
+ revision=revision,
442
+ cache_dir=cache_dir,
443
+ force_download=force_download,
444
+ proxies=proxies,
445
+ resume_download=resume_download,
446
+ token=token,
447
+ local_files_only=local_files_only,
448
+ )
449
+
450
+ checkpoint_dict = torch.load(model_file, map_location=map_location)
451
+
452
+ try:
453
+ model.load_state_dict(checkpoint_dict["generator"])
454
+ except RuntimeError:
455
+ print(
456
+ "[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
457
+ )
458
+ model.remove_weight_norm()
459
+ model.load_state_dict(checkpoint_dict["generator"])
460
+
461
+ return model
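A minimal local inference sketch, assuming the directory is importable as a package and that a `config.json` / `bigvgan_generator.pt` pair (the filenames used by `_save_pretrained` / `_from_pretrained` above) is available under a placeholder `ckpt_dir`:

```python
import torch
from BigVGAN.bigvgan import BigVGAN, load_hparams_from_json  # assumed package import

ckpt_dir = "path/to/checkpoint_dir"  # placeholder
h = load_hparams_from_json(f"{ckpt_dir}/config.json")

model = BigVGAN(h, use_cuda_kernel=False).eval()
state = torch.load(f"{ckpt_dir}/bigvgan_generator.pt", map_location="cpu")
model.load_state_dict(state["generator"])
model.remove_weight_norm()  # fold weight norm for inference

mel = torch.randn(1, h.num_mels, 100)  # [B, num_mels, frames]; random input for illustration
with torch.no_grad():
    wav = model(mel)                   # [B, 1, frames * prod(upsample_rates)]
```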
BigVGAN/configs/bigvgan_22khz_80band.json ADDED
@@ -0,0 +1,45 @@
1
+ {
2
+ "resblock": "1",
3
+ "num_gpus": 0,
4
+ "batch_size": 32,
5
+ "learning_rate": 0.0001,
6
+ "adam_b1": 0.8,
7
+ "adam_b2": 0.99,
8
+ "lr_decay": 0.9999996,
9
+ "seed": 1234,
10
+
11
+ "upsample_rates": [4,4,2,2,2,2],
12
+ "upsample_kernel_sizes": [8,8,4,4,4,4],
13
+ "upsample_initial_channel": 1536,
14
+ "resblock_kernel_sizes": [3,7,11],
15
+ "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
16
+
17
+ "activation": "snakebeta",
18
+ "snake_logscale": true,
19
+
20
+ "resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]],
21
+ "mpd_reshapes": [2, 3, 5, 7, 11],
22
+ "use_spectral_norm": false,
23
+ "discriminator_channel_mult": 1,
24
+
25
+ "segment_size": 8192,
26
+ "num_mels": 80,
27
+ "num_freq": 1025,
28
+ "n_fft": 1024,
29
+ "hop_size": 256,
30
+ "win_size": 1024,
31
+
32
+ "sampling_rate": 22050,
33
+
34
+ "fmin": 0,
35
+ "fmax": 8000,
36
+ "fmax_for_loss": null,
37
+
38
+ "num_workers": 4,
39
+
40
+ "dist_config": {
41
+ "dist_backend": "nccl",
42
+ "dist_url": "tcp://localhost:54321",
43
+ "world_size": 1
44
+ }
45
+ }
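A small sanity check on the config values (a sketch; the path matches this repo's layout): the product of `upsample_rates` must equal `hop_size`, so each mel frame expands to exactly `hop_size` waveform samples.

```python
import json
import math

with open("BigVGAN/configs/bigvgan_22khz_80band.json") as f:
    h = json.load(f)

assert math.prod(h["upsample_rates"]) == h["hop_size"]           # 4*4*2*2*2*2 == 256
assert len(h["upsample_rates"]) == len(h["upsample_kernel_sizes"])
print(h["num_mels"], h["sampling_rate"], h["fmax"])              # 80 22050 8000
```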
BigVGAN/configs/bigvgan_24khz_100band.json ADDED
@@ -0,0 +1,45 @@
1
+ {
2
+ "resblock": "1",
3
+ "num_gpus": 0,
4
+ "batch_size": 32,
5
+ "learning_rate": 0.0001,
6
+ "adam_b1": 0.8,
7
+ "adam_b2": 0.99,
8
+ "lr_decay": 0.9999996,
9
+ "seed": 1234,
10
+
11
+ "upsample_rates": [4,4,2,2,2,2],
12
+ "upsample_kernel_sizes": [8,8,4,4,4,4],
13
+ "upsample_initial_channel": 1536,
14
+ "resblock_kernel_sizes": [3,7,11],
15
+ "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
16
+
17
+ "activation": "snakebeta",
18
+ "snake_logscale": true,
19
+
20
+ "resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]],
21
+ "mpd_reshapes": [2, 3, 5, 7, 11],
22
+ "use_spectral_norm": false,
23
+ "discriminator_channel_mult": 1,
24
+
25
+ "segment_size": 8192,
26
+ "num_mels": 100,
27
+ "num_freq": 1025,
28
+ "n_fft": 1024,
29
+ "hop_size": 256,
30
+ "win_size": 1024,
31
+
32
+ "sampling_rate": 24000,
33
+
34
+ "fmin": 0,
35
+ "fmax": 12000,
36
+ "fmax_for_loss": null,
37
+
38
+ "num_workers": 4,
39
+
40
+ "dist_config": {
41
+ "dist_backend": "nccl",
42
+ "dist_url": "tcp://localhost:54321",
43
+ "world_size": 1
44
+ }
45
+ }
BigVGAN/configs/bigvgan_base_22khz_80band.json ADDED
@@ -0,0 +1,45 @@
1
+ {
2
+ "resblock": "1",
3
+ "num_gpus": 0,
4
+ "batch_size": 32,
5
+ "learning_rate": 0.0001,
6
+ "adam_b1": 0.8,
7
+ "adam_b2": 0.99,
8
+ "lr_decay": 0.9999996,
9
+ "seed": 1234,
10
+
11
+ "upsample_rates": [8,8,2,2],
12
+ "upsample_kernel_sizes": [16,16,4,4],
13
+ "upsample_initial_channel": 512,
14
+ "resblock_kernel_sizes": [3,7,11],
15
+ "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
16
+
17
+ "activation": "snakebeta",
18
+ "snake_logscale": true,
19
+
20
+ "resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]],
21
+ "mpd_reshapes": [2, 3, 5, 7, 11],
22
+ "use_spectral_norm": false,
23
+ "discriminator_channel_mult": 1,
24
+
25
+ "segment_size": 8192,
26
+ "num_mels": 80,
27
+ "num_freq": 1025,
28
+ "n_fft": 1024,
29
+ "hop_size": 256,
30
+ "win_size": 1024,
31
+
32
+ "sampling_rate": 22050,
33
+
34
+ "fmin": 0,
35
+ "fmax": 8000,
36
+ "fmax_for_loss": null,
37
+
38
+ "num_workers": 4,
39
+
40
+ "dist_config": {
41
+ "dist_backend": "nccl",
42
+ "dist_url": "tcp://localhost:54321",
43
+ "world_size": 1
44
+ }
45
+ }