# Copyright 2025 LY Corporation

import copy
from abc import abstractmethod
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.common_types import _size_2_t
from torch.nn.modules.utils import _pair

__all__ = [
    "AudioSpectrogramTransformer",
    "PositionalPatchEmbedding",
    "Aggregator",
    "HeadTokenAggregator",
    "HeadTokensAggregator",
]


class BaseAudioSpectrogramTransformer(nn.Module):
    """Base class of audio spectrogram transformer."""

    def __init__(
        self,
        embedding: "PositionalPatchEmbedding",
        backbone: nn.TransformerEncoder,
    ) -> None:
        super().__init__()

        self.embedding = embedding
        self.backbone = backbone

    def pad_by_length(
        self, input: torch.Tensor, length: Optional[torch.LongTensor] = None
    ) -> torch.Tensor:
        """Pad feature by length.

        Args:
            input (torch.Tensor): Spectrogram-like feature of shape
                (batch_size, n_bins, n_frames).
            length (torch.LongTensor, optional): Length of each sample in batch of
                shape (batch_size,).

        Returns:
            torch.Tensor: Padded feature of shape (batch_size, n_bins, n_frames).

        """
        if length is None:
            output = input
        else:
            factory_kwargs = {
                "device": input.device,
                "dtype": torch.long,
            }
            max_length = input.size(-1)
            padding_mask = torch.arange(
                max_length, **factory_kwargs
            ) >= length.unsqueeze(dim=-1)
            output = input.masked_fill(padding_mask.unsqueeze(dim=-2), 0)

        return output

    def compute_patch_embedding(self, input: torch.Tensor) -> torch.Tensor:
        """Compute patch embeddings of input feature via ``self.embedding``."""
        output = self.embedding.compute_patch_embedding(input)

        return output

    def apply_positional_embedding(
        self,
        input: torch.Tensor,
        n_bins: int,
        n_frames: int,
    ) -> torch.Tensor:
        """Apply positional embedding.

        Args:
            input (torch.Tensor): Patches of shape
                (batch_size, embedding_dim, height, width).
            n_bins (int): Number of bins, not height.
            n_frames (int): Number of frames, not width.

        Returns:
            torch.Tensor: Patches with resampled positional embedding added, of shape
                (batch_size, embedding_dim, height, width).

        """
        positional_embedding = self.embedding.positional_embedding

        output = input + self.embedding.resample_positional_embedding(
            positional_embedding,
            n_bins,
            n_frames,
        )

        return output

    def dropout_embedding(self, input: torch.Tensor) -> torch.Tensor:
        """Apply dropout to embedding.

        Args:
            input (torch.Tensor): Sequence of shape
                (batch_size, height * width + num_head_tokens, embedding_dim).

        Returns:
            torch.Tensor: Sequence of shape
                (batch_size, height * width + num_head_tokens, embedding_dim).

        """
        output = self.embedding.dropout(input)

        return output

    def compute_padding_mask(
        self,
        input: torch.Tensor,
        length: Optional[torch.LongTensor] = None,
    ) -> Optional[torch.BoolTensor]:
        """Compute padding mask.

        Args:
            input (torch.Tensor): Input feature of shape
                (batch_size, n_bins, max_frames).
            length (torch.LongTensor, optional): Length of input of shape
                (batch_size,).

        Returns:
            torch.BoolTensor: Padding mask of shape
                (batch_size, height * max_width + num_head_tokens).
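
        Example:
            A minimal sketch, assuming ``model`` is an ``AudioSpectrogramTransformer``
            whose embedding uses 16x16 patches with stride 16 (illustrative values)::

                spectrogram = torch.randn(2, 128, 96)  # 2 clips, 128 bins, 96 frames
                length = torch.tensor([96, 80])
                # (2, 8 * 6 + num_head_tokens); True at patch columns beyond each length
                padding_mask = model.compute_padding_mask(spectrogram, length=length)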
""" if length is None: padding_mask = None else: factory_kwargs = { "dtype": torch.long, "device": length.device, } _, n_bins, max_frames = input.size() width = [] for _length in length: n_frames = _length.item() _, _width = self.embedding.compute_output_shape(n_bins, n_frames) width.append(_width) width = torch.tensor(width, **factory_kwargs) max_height, max_width = self.embedding.compute_output_shape( n_bins, max_frames ) padding_mask = torch.arange(max_width, **factory_kwargs) >= width.unsqueeze( dim=-1 ) padding_mask = padding_mask.unsqueeze(dim=-2) padding_mask = padding_mask.repeat((1, max_height, 1)) padding_mask = self.patches_to_sequence(padding_mask) num_head_tokens = 0 if self.embedding.insert_cls_token: num_head_tokens += 1 if self.embedding.insert_dist_token: num_head_tokens += 1 padding_mask = F.pad(padding_mask, (num_head_tokens, 0), value=False) return padding_mask def patch_transformer_forward( self, input: torch.Tensor, padding_mask: Optional[torch.BoolTensor] = None, ) -> torch.Tensor: """Transformer with patch inputs. Args: input (torch.Tensor): Patch feature of shape (batch_size, embedding_dim, height, width). padding_mask (torch.BoolTensor): Padding mask of shape (batch_size, height, width). Returns: torch.Tensor: Estimated patches of shape (batch_size, embedding_dim, height, width). """ _, _, height, width = input.size() x = self.patches_to_sequence(input) if padding_mask is not None: padding_mask = self.patches_to_sequence(padding_mask) x = self.transformer_forward(x, padding_mask=padding_mask) output = self.sequence_to_patches(x, height=height, width=width) return output def transformer_forward( self, input: torch.Tensor, padding_mask: Optional[torch.BoolTensor] = None, ) -> torch.Tensor: """Run forward pass of backbone. Args: input (torch.Tensor): Sequence of shape (batch_size, length, embedding_dim). padding_mask (torch.BoolTensor, optional): Padding mask of shape (batch_size, length). Returns: torch.Tensor: Estimated sequence of shape (batch_size, length, embedding_dim). """ if padding_mask is None: kwargs = {} else: if isinstance(self.backbone, nn.TransformerEncoder): kwargs = { "src_key_padding_mask": padding_mask, } else: kwargs = { "padding_mask": padding_mask, } output = self.backbone(input, **kwargs) return output def spectrogram_to_patches(self, input: torch.Tensor) -> torch.Tensor: """Convert spectrogram to patches. Actual implementation depends on ``self.embedding.spectrogram_to_patches``. """ return self.embedding.spectrogram_to_patches(input) def patches_to_sequence( self, input: Union[torch.Tensor, torch.BoolTensor] ) -> torch.Tensor: r"""Convert 3D (batch_size, height, width) or 4D (batch_size, embedding_dim, height, width) tensor to shape (batch_size, length, \*) for input of Transformer. Args: input (torch.Tensor): Patches of shape (batch_size, height, width) or (batch_size, embedding_dim, height, width). Returns: torch.Tensor: Sequence of shape (batch_size, length) or (batch_size, length, embedding_dim). 
""" n_dims = input.dim() if n_dims == 3: batch_size, height, width = input.size() output = input.view(batch_size, height * width) elif n_dims == 4: batch_size, embedding_dim, height, width = input.size() x = input.view(batch_size, embedding_dim, height * width) output = x.permute(0, 2, 1).contiguous() else: raise ValueError("Only 3D and 4D tensors are supported.") return output def sequence_to_patches( self, input: Union[torch.Tensor, torch.BoolTensor], height: int, width: int ) -> torch.Tensor: r"""Convert (batch_size, max_length, \*) tensor to 3D (batch_size, height, width) or 4D (batch_size, embedding_dim, height, width) one. This method corresponds to inversion of ``patches_to_sequence``. """ n_dims = input.dim() if n_dims == 2: batch_size, _ = input.size() output = input.view(batch_size, height, width) elif n_dims == 3: batch_size, _, embedding_dim = input.size() x = input.view(batch_size, height, width, embedding_dim) output = x.permute(0, 3, 1, 2).contiguous() else: raise ValueError("Only 2D and 3D tensors are supported.") return output def split_sequence( self, sequence: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: """Split sequence to head tokens and content tokens. Args: sequence (torch.Tensor): Sequence containing head tokens, i.e. class and distillation tokens or corresponding mask. If the tokens are given, the shape should be (batch_size, length, embedding_dim). Otherwise (mask is given), the shape should be (batch_size, length). Returns: tuple: Tuple of tensors containing - torch.Tensor: Head tokens of shape (batch_size, num_head_tokens, embedding_dim) or (batch_size, num_head_tokens). - torch.Tensor: Sequence of shape (batch_size, length - num_head_tokens, embedding_dim) or (batch_size, length - num_head_tokens). .. note:: This method is applicable even when sequence does not contain head tokens. In that case, an empty sequnce is returened as the first item of returned tensors. """ n_dims = sequence.dim() if n_dims == 2: sequence = sequence.unsqueeze(dim=-1) head_tokens, sequence = self.embedding.split_sequence(sequence) if n_dims == 2: sequence = sequence.squeeze(dim=-1) return head_tokens, sequence def prepend_head_tokens(self, sequence: torch.Tensor) -> torch.Tensor: return self.embedding.prepend_head_tokens(sequence) def prepend_tokens( self, sequence: torch.Tensor, tokens: Optional[torch.Tensor] = None ) -> torch.Tensor: """Prepaned tokens to sequence. This method is inversion of ``split_sequence``. Args: sequence (torch.Tensor): Sequence of shape (batch_size, length, embedding_dim) or (batch_size, length). tokens (torch.Tensor, optional): Tokens of shape (batch_size, num_tokens, embedding_dim) or (batch_size, num_tokens). Returns: torch.Tensor: Concatenated sequence of shape (batch_size, length + num_tokens, embedding_dim) or (batch_size, length + num_tokens). """ if tokens is None: return sequence else: if sequence.dim() == 2: # assume (batch_size, length) and (batch_size, num_tokens) return torch.cat([tokens, sequence], dim=-1) else: return torch.cat([tokens, sequence], dim=-2) class AudioSpectrogramTransformer(BaseAudioSpectrogramTransformer): """Audio spectrogram transformer. Args: embedding (PositionalPatchEmbedding): Patch embedding followed by positional embedding. backbone (nn.TransformerEncoder): Transformer (encoder). 
""" def __init__( self, embedding: "PositionalPatchEmbedding", backbone: nn.TransformerEncoder, aggregator: Optional["Aggregator"] = None, ) -> None: super().__init__(embedding=embedding, backbone=backbone) self.aggregator = aggregator def forward( self, input: torch.Tensor, length: Optional[torch.LongTensor] = None, ) -> torch.Tensor: """Forward pass of AudioSpectrogramTransformer. Args: input (torch.Tensor): Spectrogram of shape (batch_size, n_bins, n_frames). length (torch.LongTensor, optional): Length of input of shape (batch_size,). Returns: torch.Tensor: Estimated patches. The shape is one of - (batch_size, height * width + num_head_tokens, embedding_dim). - (batch_size, height * width + num_head_tokens, out_channels). - (batch_size, embedding_dim). - (batch_size, out_channels). """ input = self.pad_by_length(input, length=length) x = self.embedding(input) padding_mask = self.compute_padding_mask(input, length=length) output = self.transformer_forward(x, padding_mask=padding_mask) if self.aggregator is not None: output = self.aggregator(output, padding_mask=padding_mask) return output class Aggregator(nn.Module): """Base class of module to aggregate features.""" @abstractmethod def forward( self, input: torch.Tensor, padding_mask: Optional[torch.BoolTensor] = None, ) -> torch.Tensor: """Forward pass of Aggregator. Args: input (torch.Tensor): Sequence of shape (batch_size, length, embedding_dim). padding_mask (torch.BoolTensor, optional): Padding mask of shape (batch_size, length). Returns: torch.Tensor: Aggregated feature of shape (batch_size, embedding_dim). """ pass class HeadTokensAggregator(Aggregator): """Module of aggregation by extraction of head tokens. Args: insert_cls_token (bool): Given sequence is assumed to contain [CLS] token. insert_dist_token (bool): Given sequence is assumed to contain [DIST] token. """ def __init__( self, insert_cls_token: bool = True, insert_dist_token: bool = True, ) -> None: super().__init__() if not insert_cls_token and not insert_dist_token: raise ValueError( "At least one of insert_cls_token and insert_dist_token should be True." ) self.insert_cls_token = insert_cls_token self.insert_dist_token = insert_dist_token def forward( self, input: torch.Tensor, padding_mask: Optional[torch.BoolTensor] = None, ) -> torch.Tensor: """Forward pass of HeadTokensAggregator. Args: input (torch.Tensor): Sequence of shape (batch_size, length, embedding_dim). padding_mask (torch.BoolTensor, optional): Padding mask of shape (batch_size, length). Returns: torch.Tensor: Aggregated feature of shape (batch_size, embedding_dim). .. note:: padding_mask is ignored. """ num_head_tokens = 0 if self.insert_cls_token: num_head_tokens += 1 if self.insert_dist_token: num_head_tokens += 1 head_tokens, _ = torch.split( input, [num_head_tokens, input.size(-2) - num_head_tokens], dim=-2 ) output = torch.mean(head_tokens, dim=-2) return output def extra_repr(self) -> str: s = [] if self.insert_cls_token: s.append("cls_token=True") if self.insert_dist_token: s.append("dist_token=True") s = ", ".join(s) return s class HeadTokenAggregator(Aggregator): """Module of aggregation by extraction of head token. Args: position (int): Position odf head token to be extracted. If [CLS] and [DIST] tokens are used by backbone, ``0`` typically corresponds to [CLS] token and ``1`` to [DIST] token. 
""" def __init__( self, position: int = 0, ) -> None: super().__init__() self.position = position def forward( self, input: torch.Tensor, padding_mask: Optional[torch.BoolTensor] = None, ) -> torch.Tensor: """Forward pass of HeadTokensAggregator. Args: input (torch.Tensor): Sequence of shape (batch_size, length, embedding_dim). padding_mask (torch.BoolTensor, optional): Padding mask of shape (batch_size, length). Returns: torch.Tensor: Extracted feature of shape (batch_size, embedding_dim). .. note:: padding_mask is ignored. """ position = self.position _, head_token, _ = torch.split( input, [position, 1, input.size(-2) - position - 1], dim=-2 ) output = head_token.squeeze(dim=-2) return output def extra_repr(self) -> str: s = f"position={self.position}" return s class _PatchEmbedding(nn.Module): def __init__( self, embedding_dim: int, insert_cls_token: bool = False, insert_dist_token: bool = False, device: torch.device = None, dtype: torch.dtype = None, ) -> None: factory_kwargs = { "device": device, "dtype": dtype, } super().__init__() if insert_dist_token and not insert_cls_token: raise ValueError( "When insert_dist_token=True, insert_cls_token should be True." ) if insert_cls_token: cls_token = torch.empty( (embedding_dim,), **factory_kwargs, ) self.cls_token = nn.Parameter(cls_token) else: self.register_parameter("cls_token", None) if insert_dist_token: dist_token = torch.empty( (embedding_dim,), **factory_kwargs, ) self.dist_token = nn.Parameter(dist_token) else: self.register_parameter("dist_token", None) def reset_head_tokens(self) -> None: if self.cls_token is not None: nn.init.trunc_normal_(self.cls_token.data, std=0.02) if self.dist_token is not None: nn.init.trunc_normal_(self.dist_token.data, std=0.02) @property def insert_cls_token(self) -> bool: return self.cls_token is not None @property def insert_dist_token(self) -> bool: return self.dist_token is not None @abstractmethod def compute_patch_embedding(self, input: torch.Tensor) -> torch.Tensor: """Compute patch embeddings of input feature. Args: input (torch.Tensor): Spectrogram-like feature of shape (batch_size, n_bins, n_frames). Returns: torch.Tensor: Embedded features of shape (batch_size, embedding_dim, height, width). """ @abstractmethod def spectrogram_to_patches(self, input: torch.Tensor) -> torch.Tensor: """Convert spectrogram to patches.""" def patches_to_sequence( self, input: Union[torch.Tensor, torch.BoolTensor] ) -> torch.Tensor: """Convert 3D (batch_size, height, width) or 4D (batch_size, embedding_dim, height, width) tensor to shape (batch_size, length, *) for input of Transformer. Args: input (torch.Tensor): Patches of shape (batch_size, height, width) or (batch_size, embedding_dim, height, width). Returns: torch.Tensor: Sequence of shape (batch_size, length) or (batch_size, length, embedding_dim). """ n_dims = input.dim() if n_dims == 3: batch_size, height, width = input.size() output = input.view(batch_size, height * width) elif n_dims == 4: batch_size, embedding_dim, height, width = input.size() x = input.view(batch_size, embedding_dim, height * width) output = x.permute(0, 2, 1).contiguous() else: raise ValueError("Only 3D and 4D tensors are supported.") return output def prepend_head_tokens(self, sequence: torch.Tensor) -> torch.Tensor: """Prepend [CLS] and [DIST] tokens to sequence. Args: sequence (torch.Tensor): Sequence of shape (batch_size, height * width, embedding_dim). 

        """
        batch_size = sequence.size(0)

        if self.insert_dist_token:
            dist_token = self.dist_token.expand((batch_size, 1, -1))
            sequence = torch.cat([dist_token, sequence], dim=-2)

        if self.insert_cls_token:
            cls_token = self.cls_token.expand((batch_size, 1, -1))
            sequence = torch.cat([cls_token, sequence], dim=-2)

        return sequence

    def split_sequence(
        self, sequence: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Split sequence into head tokens and content tokens.

        Args:
            sequence (torch.Tensor): Sequence containing head tokens, i.e. class and
                distillation tokens. The shape is (batch_size, length, embedding_dim).

        Returns:
            tuple: Tuple of tensors containing

                - torch.Tensor: Head tokens of shape
                  (batch_size, num_head_tokens, embedding_dim).
                - torch.Tensor: Sequence of shape
                  (batch_size, length - num_head_tokens, embedding_dim).

        .. note::

            This method is applicable even when the sequence does not contain head
            tokens. In that case, an empty sequence is returned as the first item of
            the returned tensors.

        """
        length = sequence.size(-2)
        num_head_tokens = 0

        if self.cls_token is not None:
            num_head_tokens += 1

        if self.dist_token is not None:
            num_head_tokens += 1

        head_tokens, sequence = torch.split(
            sequence, [num_head_tokens, length - num_head_tokens], dim=-2
        )

        return head_tokens, sequence

    @abstractmethod
    def compute_output_shape(self, n_bins: int, n_frames: int) -> Tuple[int, int]:
        """Compute output shape from shape of spectrogram."""


class PositionalPatchEmbedding(_PatchEmbedding):
    """Patch embedding + trainable positional embedding.

    Args:
        embedding_dim (int): Embedding dimension.
        kernel_size (_size_2_t): Kernel size that corresponds to patch.
        stride (_size_2_t): Stride.
        insert_cls_token (bool): If ``True``, class token is inserted at the beginning
            of the sequence.
        insert_dist_token (bool): If ``True``, distillation token is inserted at the
            beginning of the sequence.
        dropout (float): Dropout rate.
        n_bins (int): Number of input bins.
        n_frames (int): Number of input frames.

    .. note::

        Unlike the official implementation (of AST), trainable positional embeddings
        for the [CLS] (and [DIST]) token(s) are omitted because they would be
        redundant.
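
    Example:
        A minimal sketch; the values below are illustrative::

            embedding = PositionalPatchEmbedding(
                768,
                kernel_size=16,
                stride=10,
                insert_cls_token=True,
                n_bins=128,
                n_frames=100,
            )
            spectrogram = torch.randn(2, 128, 100)
            sequence = embedding(spectrogram)  # (2, 12 * 9 + 1, 768)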
""" def __init__( self, embedding_dim: int, kernel_size: _size_2_t, stride: Optional[_size_2_t] = None, insert_cls_token: bool = False, insert_dist_token: bool = False, dropout: float = 0, n_bins: int = None, n_frames: int = None, device: torch.device = None, dtype: torch.dtype = None, ) -> None: factory_kwargs = { "device": device, "dtype": dtype, } super().__init__( embedding_dim, insert_cls_token=insert_cls_token, insert_dist_token=insert_dist_token, **factory_kwargs, ) if n_bins is None: raise ValueError("n_bins is required.") if n_frames is None: raise ValueError("n_frames is required.") kernel_size = _pair(kernel_size) if stride is None: stride = kernel_size stride = _pair(stride) self.embedding_dim = embedding_dim self.kernel_size = kernel_size self.stride = stride self.n_bins = n_bins self.n_frames = n_frames self.conv2d = nn.Conv2d( 1, embedding_dim, kernel_size=kernel_size, stride=stride, ) height, width = self.compute_output_shape(n_bins, n_frames) positional_embedding = torch.empty( (embedding_dim, height, width), **factory_kwargs ) self.positional_embedding = nn.Parameter(positional_embedding) self.dropout = nn.Dropout(dropout) self._reset_parameters() def _reset_parameters(self) -> None: self.reset_head_tokens() # based on official implementation nn.init.trunc_normal_(self.positional_embedding.data, std=0.02) def forward(self, input: torch.Tensor) -> torch.Tensor: """Forward pass of PositionalPatchEmbedding. Args: input (torch.Tensor): Spectrogram of shape (batch_size, n_bins, n_frames). Returns: torch.Tensor: (batch_size, height * width + num_head_tokens, embedding_dim), where `num_head_tokens` represents number of tokens for [CLS] and [DIST]. """ positional_embedding = self.positional_embedding _, n_bins, n_frames = input.size() x = self.compute_patch_embedding(input) x = x + self.resample_positional_embedding( positional_embedding, n_bins, n_frames, ) x = self.patches_to_sequence(x) x = self.prepend_head_tokens(x) output = self.dropout(x) return output def compute_patch_embedding(self, input: torch.Tensor) -> torch.Tensor: """Compute patch embeddings of input feature. Args: input (torch.Tensor): Spectrogram-like feature of shape (batch_size, n_bins, n_frames). Returns: torch.Tensor: Embedded features of shape (batch_size, embedding_dim, height, width). """ x = input.unsqueeze(dim=-3) output = self.conv2d(x) return output def spectrogram_to_patches(self, input: torch.Tensor) -> torch.Tensor: """Convert spectrogram to patches.""" conv2d = self.conv2d batch_size, n_bins, n_frames = input.size() x = input.view(batch_size, 1, n_bins, n_frames) x = F.unfold( x, kernel_size=conv2d.kernel_size, dilation=conv2d.dilation, padding=conv2d.padding, stride=conv2d.stride, ) height, width = self.compute_output_shape(n_bins, n_frames) output = x.view(batch_size, -1, height, width) return output def resample_positional_embedding( self, positional_embedding: Union[torch.Tensor], n_bins: int, n_frames: int, mode: str = "bilinear", ) -> torch.Tensor: """Resample positional embedding. Args: positional_embedding (torch.Tensor): Positional embedding of shape (embedding_dim, height, width). n_bins (int): Number of bins, not height. n_frames (int): Number of frames, not width. mode (str): Interpolation mode. Default: ``bilinear``. Returns: torch.Tensor: Resampled positional embedding of shape (embedding_dim, height', width'). 
""" _, height_org, width_org = positional_embedding.size() height, width = self.compute_output_shape(n_bins, n_frames) if width_org > width: start_idx = 0 _, positional_embedding, _ = torch.split( positional_embedding, [start_idx, width, width_org - width - start_idx], dim=-1, ) elif width > width_org: positional_embedding = positional_embedding.unsqueeze(dim=0) positional_embedding = F.interpolate( positional_embedding, size=(height_org, width), mode=mode ) positional_embedding = positional_embedding.squeeze(dim=0) if height_org > height: start_idx = height_org // 2 - height // 2 _, positional_embedding, _ = torch.split( positional_embedding, [start_idx, height, height_org - height - start_idx], dim=-1, ) elif height > height_org: positional_embedding = positional_embedding.unsqueeze(dim=0) positional_embedding = F.interpolate( positional_embedding, size=(height, width), mode=mode ) positional_embedding = positional_embedding.squeeze(dim=0) output = positional_embedding return output def compute_output_shape(self, n_bins: int, n_frames: int) -> Tuple[int, int]: Kh, Kw = self.conv2d.kernel_size Sh, Sw = self.conv2d.stride height = (n_bins - Kh) // Sh + 1 width = (n_frames - Kw) // Sw + 1 return height, width def align_patch_embedding( orig_patch_embedding: PositionalPatchEmbedding, stride: Optional[_size_2_t] = None, n_bins: Optional[int] = None, n_frames: Optional[int] = None, ) -> PositionalPatchEmbedding: pretrained_embedding_dim = orig_patch_embedding.embedding_dim pretrained_kernel_size = orig_patch_embedding.kernel_size pretrained_stride = orig_patch_embedding.stride pretrained_insert_cls_token = orig_patch_embedding.insert_cls_token pretrained_insert_dist_token = orig_patch_embedding.insert_dist_token pretrained_n_bins = orig_patch_embedding.n_bins pretrained_n_frames = orig_patch_embedding.n_frames pretrained_conv2d = orig_patch_embedding.conv2d pretrained_positional_embedding = orig_patch_embedding.positional_embedding pretrained_cls_token = orig_patch_embedding.cls_token pretrained_dist_token = orig_patch_embedding.dist_token if stride is None: stride = pretrained_stride if n_bins is None: n_bins = pretrained_n_bins if n_frames is None: n_frames = pretrained_n_frames new_patch_embedding = PositionalPatchEmbedding( pretrained_embedding_dim, kernel_size=pretrained_kernel_size, stride=stride, insert_cls_token=pretrained_insert_cls_token, insert_dist_token=pretrained_insert_dist_token, n_bins=n_bins, n_frames=n_frames, ) conv2d_state_dict = copy.deepcopy(pretrained_conv2d.state_dict()) new_patch_embedding.conv2d.load_state_dict(conv2d_state_dict) pretrained_positional_embedding = new_patch_embedding.resample_positional_embedding( pretrained_positional_embedding, n_bins, n_frames ) new_patch_embedding.positional_embedding.data.copy_(pretrained_positional_embedding) if pretrained_insert_cls_token: new_patch_embedding.cls_token.data.copy_(pretrained_cls_token) if pretrained_insert_dist_token: new_patch_embedding.dist_token.data.copy_(pretrained_dist_token) return new_patch_embedding