from dataclasses import dataclass from pathlib import Path from typing import Optional, Tuple, Union, Literal import torch import torch.nn as nn from transformers.modeling_outputs import BaseModelOutput from transformers import HubertModel, AutoConfig, AutoModel @dataclass class CustomHubertConfig: """Configuration class for CustomHubert model.""" # e.g., "facebook/hubert-base-ls960" or "facebook/hubert-large-ll60k" checkpoint_name: str # Layer to extract features from (0-indexed, e.g., 9 for 10th layer) feature_layer: int = 11 # Target audio sample rate in Hz target_sample_rate: int = 16000 # Optional length multiple for audio trimming seq_len_multiple_of: Optional[int] = None @dataclass class HubertForBarkSemanticConfig: """Configuration for HuBERTForBarkSemantic.""" # # HuBERT model checkpoint for feature extractor layer checkpoint_name: Literal["facebook/hubert-base-ls960", "hubert-large-ls960-ft"] vocab_size: int # Layer to extract features from feature_layer: int = 11 # last three tokens for SOS, EOS and PAD tokens # maximum target sequence length max_target_length: int = 2000 num_decoder_layer: int = 12 sos_token_id: int = 10000 eos_token_id: int = 10001 class HubertFeatureExtractor(nn.Module): """ A custom HuBERT model that loads a pretrained model from transformers and extracts features from a specified layer. Processes raw audio waveforms and returns hidden states. Args: config (CustomHubertConfig): Configuration specifying checkpoint, layer, and audio settings. device (torch.device, optional): Device to run the model on (e.g., "cuda" or "cpu"). """ def __init__( self, config: CustomHubertConfig, load_pretrained_weights: bool, device: Optional[torch.device] = None, ): super().__init__() self.config = config self.target_sample_rate = config.target_sample_rate # Load pretrained HuBERT model from transformers self.hubert_config = AutoConfig.from_pretrained(config.checkpoint_name) if load_pretrained_weights: self.model = HubertModel.from_pretrained(config.checkpoint_name) else: # don't download the pretrained weights, init the model from the config self.model = AutoModel.from_config(self.hubert_config) # Validate feature_layer # e.g., 12 for BASE, 24 for LARGE num_layers = self.model.config.num_hidden_layers if not (0 <= config.feature_layer < num_layers): raise ValueError( f"feature_layer must be between 0 and {num_layers - 1}, got {config.feature_layer}" ) self.feature_layer = config.feature_layer # Move to device if specified if device is not None: self.to(device) @property def hidden_size(self) -> int: """Returns the hidden size of the HuBERT model (e.g., 768 for BASE, 1024 for LARGE).""" return self.model.config.hidden_size def forward( self, wav_input: torch.Tensor, ) -> torch.Tensor: """ Processes raw audio waveforms through HuBERT and extracts features from the specified layer. Input audio sample rate expected 16k Args: wav_input (torch.Tensor): Raw audio waveforms, shape [batch_size, audio_length]. return_shape (Tuple[int, int], optional): If provided, reshapes output to [batch_size, seq_length, hidden_size]. Returns: torch.Tensor: Features from the specified layer. Shape depends on return_shape: - If None: [batch_size * seq_length, hidden_size] (flattened). - If provided: [batch_size, seq_length, hidden_size]. """ # Forward pass through HuBERT # output_hidden_states=True returns all layer outputs outputs: BaseModelOutput = self.model( input_values=wav_input, output_hidden_states=True, return_dict=True ) # Extract features from the specified layer (0-indexed) # hidden_states is a tuple of [batch_size, seq_length, hidden_size] for each layer features = outputs.hidden_states[self.feature_layer] # e.g., [2, 500, 768] features = features.contiguous() return features class HuBERTForBarkSemantic(nn.Module): def __init__( self, config: HubertForBarkSemanticConfig, load_hubert_pretrained_weights: bool = True, device: Optional[torch.device] = None, ): super().__init__() self.config = config # HuBERT feature extractor hubert_config = CustomHubertConfig( checkpoint_name=config.checkpoint_name, feature_layer=config.feature_layer, ) self.hubert = HubertFeatureExtractor( config=hubert_config, load_pretrained_weights=load_hubert_pretrained_weights, device=device, ) # e.g., 768 for BASE input_size = self.hubert.model.config.hidden_size # Transformer Decoder self.decoder_embedding = nn.Embedding(config.vocab_size, input_size) self.pos_embedding = nn.Parameter( torch.zeros(1, config.max_target_length, input_size) ) self.decoder = nn.TransformerDecoder( nn.TransformerDecoderLayer( d_model=input_size, nhead=8, dim_feedforward=2048, dropout=0.1, batch_first=True, ), num_layers=config.num_decoder_layer, # Adjust as needed ) self.fc = nn.Linear(input_size, config.vocab_size) if device is not None: self.to(device) def save_state_dict(self, save_path: str): torch.save(self.state_dict(), save_path) def forward(self, wav_input: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor: """ Forward pass: Extracts HuBERT features and predicts semantic token probabilities. Args: wav_input: [batch_size, audio_length] (e.g., [2, 160000]) tgt: the target sequence Returns: [batch_size, seq_length, vocab_size + 1] (e.g., [2, 500, VOCAB_SIZE]) """ memory: torch.Tensor = self.hubert(wav_input) # [B, T, 768] B, T_tgt = tgt.shape tgt_emb = self.decoder_embedding(tgt) + self.pos_embedding[:, :T_tgt, :] tgt_mask = nn.Transformer.generate_square_subsequent_mask(T_tgt).to(tgt.device) output: torch.Tensor = self.decoder(tgt_emb, memory, tgt_mask=tgt_mask) logits = self.fc(output) return logits @torch.no_grad def generate( self, wav_input: torch.Tensor, temperature: Optional[float] = 0.8, eos_p: Optional[float] = 0.5, max_length: int = 600, ) -> torch.Tensor: """ Inference: autoregressive generation. assuming wav_input audio is at 16000 sample rate""" self.eval() memory = self.hubert(wav_input) B = wav_input.shape[0] tgt = torch.full( size=(B, 1), fill_value=self.config.sos_token_id, device=wav_input.device ) for _ in range(max_length): tgt_emb = ( self.decoder_embedding(tgt) + self.pos_embedding[:, : tgt.shape[1], :] ) tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.shape[1]).to( tgt.device ) output = self.decoder(tgt_emb, memory, tgt_mask=tgt_mask) # logits shape (B, T', vocab_size) logits: torch.Tensor = self.fc(output[:, -1, :]) if temperature is not None and temperature > 0: probs = torch.softmax(input=logits / temperature, dim=-1) next_token = torch.multinomial(input=probs, num_samples=1) else: probs = torch.softmax(input=logits, dim=-1) next_token = logits.argmax(dim=-1, keepdim=True) # stop if the EOS token probabilities are higher than the provided eos_p if eos_p is not None and eos_p > 0: if torch.all(probs[:, self.config.eos_token_id] > eos_p): break # early stopping if torch.all(next_token == self.config.eos_token_id): break tgt = torch.cat([tgt, next_token], dim=1) if (next_token == self.config.eos_token_id).all(): break # remove the [SOS] token from the generated semantic sequences return tgt[:, 1:]