import collections.abc
import math
import sys
from collections.abc import Iterable  # import directly from collections for Python < 3.3
from functools import partial
from itertools import repeat
from typing import Tuple

import matplotlib.pyplot as plt
import numpy as np
import timm
import torch
from torch import nn
from torchvision.models.vision_transformer import Encoder


def plot_fbank(fbank, title=None, save_path=None, **kwargs):
    fig, axs = plt.subplots(min(4, fbank.shape[0]), 1, sharex=True, sharey=True)
    # With a single subplot, plt.subplots returns a bare Axes object.
    if not isinstance(axs, Iterable):
        axs = np.array([axs])
    vmin, vmax = kwargs.get("vmin", None), kwargs.get("vmax", None)
    # Plot at most four channels.
    for channel in range(min(4, fbank.shape[0])):
        axs[channel].set_title(f"Filter bank channel {channel}, {title}")
        im = axs[channel].imshow(fbank[channel].T, aspect="auto", vmin=vmin, vmax=vmax)
        axs[channel].set_ylabel("mel")
        axs[channel].set_xlabel("time")
    plt.gca().invert_yaxis()
    plt.tight_layout()
    fig.colorbar(im, ax=axs.ravel().tolist())
    plt.show()
    if save_path:
        fig.savefig(save_path)
    plt.close()
    return fig


# From PyTorch internals: build a parser that turns a value into an n-tuple.
def _ntuple(n):
    def parse(x):
        # If x is already an iterable (and not a string), make a tuple of it.
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return tuple(x)
        # Otherwise repeat x n times and make a tuple of that.
        return tuple(repeat(x, n))

    return parse


class PatchEmbed(nn.Module):
    """Image to Patch Embedding."""

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = _ntuple(2)(img_size)
        patch_size = _ntuple(2)(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        # A strided convolution computes all patch projections in one pass.
        self.proj = nn.Conv2d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    # We need to override these.
    def forward(self, x):
        # (B, C, H, W) -> (B, embed_dim, H/P, W/P) -> (B, num_patches, embed_dim)
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


def get_sinusoid_encoding(n_position, d_hid):
    """Sinusoid position encoding table."""

    def get_position_angle_vec(position):
        return [
            position / np.power(10000, 2 * (hid_j // 2) / d_hid)
            for hid_j in range(d_hid)
        ]

    sinusoid_table = np.array(
        [get_position_angle_vec(pos_i) for pos_i in range(n_position)]
    )
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
    return torch.FloatTensor(sinusoid_table).unsqueeze(0)
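
# Illustrative shape check for PatchEmbed and get_sinusoid_encoding. The
# single-channel 128x1024 spectrogram-like input here is an assumption made
# for this sketch, not a value taken from any training configuration.
def _patch_embed_example():
    patch_embed = PatchEmbed(img_size=(128, 1024), patch_size=16, in_chans=1, embed_dim=768)
    x = torch.randn(2, 1, 128, 1024)  # (batch, channels, mel bins, time frames)
    tokens = patch_embed(x)  # (2, 512, 768): (128 // 16) * (1024 // 16) = 512 patches
    pos = get_sinusoid_encoding(tokens.shape[1], 768)  # (1, 512, 768)
    return tokens + pos  # position information is added by broadcasting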
def create_pretrained_model(
    model_size,
    encoder_num_layers=12,
    encoder_num_heads=12,
    encoder_hidden_dim=768,
    encoder_mlp_dim=3072,
    encoder_dropout=0.0,
    encoder_attention_dropout=0.0,
    encoder_norm_layer_eps=1e-6,
):
    if model_size == "tiny":
        v = timm.create_model("deit_tiny_distilled_patch16_224", pretrained=False)
        hidden_dim = 192
    elif model_size == "small":
        v = timm.create_model("deit_small_distilled_patch16_224", pretrained=False)
        hidden_dim = 384
    elif model_size == "base":
        v = Encoder(
            seq_length=0,  # Only used for pos_embeddings and we set them later!
            num_layers=encoder_num_layers,
            num_heads=encoder_num_heads,
            hidden_dim=encoder_hidden_dim,
            mlp_dim=encoder_mlp_dim,
            dropout=encoder_dropout,
            attention_dropout=encoder_attention_dropout,
            norm_layer=partial(nn.LayerNorm, eps=encoder_norm_layer_eps),
        )
        hidden_dim = encoder_hidden_dim
    elif model_size == "base_nokd":
        v = timm.create_model("deit_base_patch16_384", pretrained=False)
        hidden_dim = 768
    else:
        print("Wrong model size!")
        sys.exit(1)
    return v, hidden_dim


def _trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes the standard normal cumulative distribution function.
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    # Values are generated by using a truncated uniform distribution and
    # then using the inverse CDF for the normal distribution.
    # Get the lower and upper cdf values.
    left = norm_cdf((a - mean) / std)
    up = norm_cdf((b - mean) / std)

    # Uniformly fill the tensor with values from [left, up], then translate
    # to [2 * left - 1, 2 * up - 1].
    tensor.uniform_(2 * left - 1, 2 * up - 1)

    # Use the inverse CDF transform for the normal distribution to get
    # truncated standard normal values.
    tensor.erfinv_()

    # Transform to the proper mean and std.
    tensor.mul_(std * math.sqrt(2.0))
    tensor.add_(mean)

    # Clamp to ensure the values are in the proper range.
    tensor.clamp_(min=a, max=b)
    return tensor


def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    # type: (Tensor, float, float, float, float) -> Tensor
    r"""Fills the input Tensor with values drawn from a truncated normal
    distribution. The values are effectively drawn from the normal distribution
    :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values outside
    :math:`[a, b]` redrawn until they are within the bounds. The method used
    for generating the random values works best when
    :math:`a \leq \text{mean} \leq b`.

    NOTE: this impl is similar to the PyTorch trunc_normal_; the bounds [a, b]
    are applied while sampling the normal with mean/std applied, so the a and b
    args should be adjusted to match the range of the mean and std args.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    with torch.no_grad():
        return _trunc_normal_(tensor, mean, std, a, b)


def expand_index_like(index: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    """Expands the index along the last dimension of the input tokens.

    Args:
        index:
            Index tensor with shape (batch_size, idx_length) where each entry
            is an index in [0, sequence_length).
        tokens:
            Tokens tensor with shape (batch_size, sequence_length, dim).

    Returns:
        Index tensor with shape (batch_size, idx_length, dim) where the
        original indices are repeated dim times along the last dimension.
    """
    dim = tokens.shape[-1]
    index = index.unsqueeze(-1).expand(-1, -1, dim)
    return index
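
# Sketch of the intended usage for the "base" branch: the torchvision Encoder
# is built with seq_length=0, so (as the comment in create_pretrained_model
# notes) the positional embedding must be replaced before the encoder is used.
# The 512-token sequence length and the 0.02 std are assumptions for this
# example, not values from the original configuration.
def _init_base_encoder_example(seq_length=512):
    encoder, hidden_dim = create_pretrained_model("base")
    pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim))
    trunc_normal_(pos_embedding, std=0.02)
    encoder.pos_embedding = pos_embedding
    return encoder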
""" index = expand_index_like(index, tokens) return torch.scatter(tokens, 1, index, value) def repeat_token(token: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor: """Repeats a token size times. Args: token: Token tensor with shape (1, 1, dim). size: (batch_size, sequence_length) tuple. Returns: Tensor with shape (batch_size, sequence_length, dim) containing copies of the input token. """ batch_size, sequence_length = size return token.repeat(batch_size, sequence_length, 1)