from dataclasses import dataclass
from typing import Literal, Optional
import torch
import torch.nn as nn
from transformers.modeling_outputs import BaseModelOutput
from transformers import HubertModel, AutoConfig, AutoModel
@dataclass
class CustomHubertConfig:
"""Configuration class for CustomHubert model."""
# e.g., "facebook/hubert-base-ls960" or "facebook/hubert-large-ll60k"
checkpoint_name: str
    # Index into hidden_states: 0 is the embedding output, i is the output of the
    # i-th transformer layer
feature_layer: int = 11
# Target audio sample rate in Hz
target_sample_rate: int = 16000
# Optional length multiple for audio trimming
seq_len_multiple_of: Optional[int] = None
@dataclass
class HubertForBarkSemanticConfig:
"""Configuration for HuBERTForBarkSemantic."""
    # HuBERT checkpoint to use for the feature extractor
    checkpoint_name: Literal["facebook/hubert-base-ls960", "facebook/hubert-large-ls960-ft"]
vocab_size: int
# Layer to extract features from
feature_layer: int = 11
    # maximum target sequence length
    max_target_length: int = 2000
    num_decoder_layer: int = 12
    # ids of the special tokens appended after the semantic vocabulary
    sos_token_id: int = 10000
    eos_token_id: int = 10001
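
# Illustrative configuration (vocab_size here is an assumption; it must exceed
# the special token ids above, e.g. 10_000 semantic tokens + SOS + EOS):
#
#   config = HubertForBarkSemanticConfig(
#       checkpoint_name="facebook/hubert-base-ls960",
#       vocab_size=10_002,
#   )
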
class HubertFeatureExtractor(nn.Module):
"""
A custom HuBERT model that loads a pretrained model from transformers and extracts
features from a specified layer. Processes raw audio waveforms and returns hidden states.
    Args:
        config (CustomHubertConfig): Configuration specifying checkpoint, layer, and audio settings.
        load_pretrained_weights (bool): If True, download and load the pretrained
            checkpoint weights; if False, initialize the architecture from the config alone.
        device (torch.device, optional): Device to run the model on (e.g., "cuda" or "cpu").
"""
def __init__(
self,
config: CustomHubertConfig,
load_pretrained_weights: bool,
device: Optional[torch.device] = None,
):
super().__init__()
self.config = config
self.target_sample_rate = config.target_sample_rate
# Load pretrained HuBERT model from transformers
self.hubert_config = AutoConfig.from_pretrained(config.checkpoint_name)
if load_pretrained_weights:
self.model = HubertModel.from_pretrained(config.checkpoint_name)
else:
# don't download the pretrained weights, init the model from the config
self.model = AutoModel.from_config(self.hubert_config)
# Validate feature_layer
# e.g., 12 for BASE, 24 for LARGE
num_layers = self.model.config.num_hidden_layers
        if not (0 <= config.feature_layer <= num_layers):
            raise ValueError(
                f"feature_layer must be between 0 and {num_layers}, got {config.feature_layer}"
            )
self.feature_layer = config.feature_layer
# Move to device if specified
if device is not None:
self.to(device)
@property
def hidden_size(self) -> int:
"""Returns the hidden size of the HuBERT model (e.g., 768 for BASE, 1024 for LARGE)."""
return self.model.config.hidden_size
def forward(
self,
wav_input: torch.Tensor,
) -> torch.Tensor:
"""
Processes raw audio waveforms through HuBERT and extracts features from the specified layer.
Input audio sample rate expected 16k
Args:
wav_input (torch.Tensor): Raw audio waveforms, shape [batch_size, audio_length].
return_shape (Tuple[int, int], optional): If provided, reshapes output to [batch_size, seq_length, hidden_size].
Returns:
torch.Tensor: Features from the specified layer. Shape depends on return_shape:
- If None: [batch_size * seq_length, hidden_size] (flattened).
- If provided: [batch_size, seq_length, hidden_size].
"""
# Forward pass through HuBERT
# output_hidden_states=True returns all layer outputs
outputs: BaseModelOutput = self.model(
input_values=wav_input, output_hidden_states=True, return_dict=True
)
        # hidden_states has num_hidden_layers + 1 entries: index 0 is the embedding
        # output and index i is the i-th transformer layer's output, each of shape
        # [batch_size, seq_length, hidden_size]
        features = outputs.hidden_states[self.feature_layer]  # e.g., [2, 500, 768]
features = features.contiguous()
return features
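
# Usage sketch (illustrative, assuming a batch of 16 kHz mono waveforms):
#
#   extractor = HubertFeatureExtractor(
#       CustomHubertConfig(checkpoint_name="facebook/hubert-base-ls960"),
#       load_pretrained_weights=True,
#   )
#   wav = torch.randn(2, 160_000)  # two 10 s clips at 16 kHz
#   feats = extractor(wav)         # [2, ~500, 768]: HuBERT emits ~50 frames/s
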
class HuBERTForBarkSemantic(nn.Module):
def __init__(
self,
config: HubertForBarkSemanticConfig,
load_hubert_pretrained_weights: bool = True,
device: Optional[torch.device] = None,
):
super().__init__()
self.config = config
# HuBERT feature extractor
hubert_config = CustomHubertConfig(
checkpoint_name=config.checkpoint_name,
feature_layer=config.feature_layer,
)
self.hubert = HubertFeatureExtractor(
config=hubert_config,
load_pretrained_weights=load_hubert_pretrained_weights,
device=device,
)
# e.g., 768 for BASE
        input_size = self.hubert.hidden_size
# Transformer Decoder
self.decoder_embedding = nn.Embedding(config.vocab_size, input_size)
self.pos_embedding = nn.Parameter(
torch.zeros(1, config.max_target_length, input_size)
)
self.decoder = nn.TransformerDecoder(
nn.TransformerDecoderLayer(
d_model=input_size,
nhead=8,
dim_feedforward=2048,
dropout=0.1,
batch_first=True,
),
num_layers=config.num_decoder_layer, # Adjust as needed
)
self.fc = nn.Linear(input_size, config.vocab_size)
if device is not None:
self.to(device)
def save_state_dict(self, save_path: str):
torch.save(self.state_dict(), save_path)
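
    # Loading sketch (illustrative; save_path is a placeholder): restore weights
    # saved by save_state_dict. Passing load_hubert_pretrained_weights=False skips
    # downloading HuBERT weights that the checkpoint would overwrite anyway:
    #
    #   model = HuBERTForBarkSemantic(config, load_hubert_pretrained_weights=False)
    #   model.load_state_dict(torch.load(save_path, map_location="cpu"))
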
def forward(self, wav_input: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
"""
Forward pass: Extracts HuBERT features and predicts semantic token probabilities.
Args:
wav_input: [batch_size, audio_length] (e.g., [2, 160000])
tgt: the target sequence
Returns:
[batch_size, seq_length, vocab_size + 1] (e.g., [2, 500, VOCAB_SIZE])
"""
memory: torch.Tensor = self.hubert(wav_input) # [B, T, 768]
B, T_tgt = tgt.shape
tgt_emb = self.decoder_embedding(tgt) + self.pos_embedding[:, :T_tgt, :]
tgt_mask = nn.Transformer.generate_square_subsequent_mask(T_tgt).to(tgt.device)
output: torch.Tensor = self.decoder(tgt_emb, memory, tgt_mask=tgt_mask)
logits = self.fc(output)
return logits
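
    # Training sketch (illustrative; semantic_tokens stands in for a batch of
    # target sequences that start with SOS and end with EOS). Teacher forcing
    # shifts the targets by one position for next-token prediction:
    #
    #   decoder_in = semantic_tokens[:, :-1]
    #   labels = semantic_tokens[:, 1:]
    #   logits = model(wav_batch, decoder_in)  # [B, T, vocab_size]
    #   loss = nn.functional.cross_entropy(
    #       logits.reshape(-1, logits.size(-1)), labels.reshape(-1)
    #   )
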
    @torch.no_grad()
def generate(
self,
wav_input: torch.Tensor,
temperature: Optional[float] = 0.8,
eos_p: Optional[float] = 0.5,
max_length: int = 600,
) -> torch.Tensor:
"""
Inference: autoregressive generation.
assuming wav_input audio is at 16000 sample rate"""
self.eval()
memory = self.hubert(wav_input)
B = wav_input.shape[0]
tgt = torch.full(
size=(B, 1), fill_value=self.config.sos_token_id, device=wav_input.device
)
for _ in range(max_length):
tgt_emb = (
self.decoder_embedding(tgt) + self.pos_embedding[:, : tgt.shape[1], :]
)
tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.shape[1]).to(
tgt.device
)
output = self.decoder(tgt_emb, memory, tgt_mask=tgt_mask)
            # project only the last time step; logits shape (B, vocab_size)
            logits: torch.Tensor = self.fc(output[:, -1, :])
if temperature is not None and temperature > 0:
probs = torch.softmax(input=logits / temperature, dim=-1)
next_token = torch.multinomial(input=probs, num_samples=1)
else:
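                # greedy decoding; probs is still computed for the EOS check below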
probs = torch.softmax(input=logits, dim=-1)
next_token = logits.argmax(dim=-1, keepdim=True)
            # stop if every sample's EOS probability exceeds the provided eos_p
            if eos_p is not None and eos_p > 0:
                if torch.all(probs[:, self.config.eos_token_id] > eos_p):
                    break
            # early stopping: every sample produced EOS; the token itself is not
            # appended, so the output contains only semantic tokens
            if torch.all(next_token == self.config.eos_token_id):
                break
            tgt = torch.cat([tgt, next_token], dim=1)
# remove the [SOS] token from the generated semantic sequences
return tgt[:, 1:]
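
# Runnable end-to-end sketch (illustrative): build the model, generate semantic
# tokens from dummy audio, and print the output shape. vocab_size is an assumed
# value covering the default SOS/EOS ids; the first run downloads the HuBERT
# BASE checkpoint.
if __name__ == "__main__":
    config = HubertForBarkSemanticConfig(
        checkpoint_name="facebook/hubert-base-ls960",
        vocab_size=10_002,  # assumed: 10_000 semantic tokens + SOS + EOS
    )
    model = HuBERTForBarkSemantic(config, load_hubert_pretrained_weights=True)
    wav = torch.randn(1, 160_000)  # 10 s of dummy audio at 16 kHz
    tokens = model.generate(wav, temperature=0.8, eos_p=0.5, max_length=20)
    print(tokens.shape)  # at most [1, 20]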