Spaces:
Sleeping
Sleeping
File size: 2,513 Bytes
a0e2cb7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
import torch
from torch import nn
import typing as tp
import torchaudio
import einops
from abc import ABC, abstractmethod
class AbstractVAE(ABC, nn.Module):
    """Common interface for audio VAEs that map waveforms to latents and back.

    Subclasses must implement the five read-only properties below and must
    override ``encode`` and ``decode``; the base versions of those two raise
    ``NotImplementedError`` instead of silently returning ``None``.
    """

    @property
    @abstractmethod
    def frame_rate(self) -> float:
        """Number of latent frames produced per second of audio."""

    @property
    @abstractmethod
    def orig_sample_rate(self) -> int:
        """Sample rate (Hz) the underlying model operates at."""

    @property
    @abstractmethod
    def channel_dim(self) -> int:
        """Number of latent channels produced by ``encode``."""

    @property
    @abstractmethod
    def split_bands(self) -> int:
        """Number of bands the model splits audio into.

        NOTE(review): semantics inferred from the name only -- confirm.
        """

    @property
    @abstractmethod
    def input_channel(self) -> int:
        """Number of audio channels the model expects as input."""

    def encode(self, wav) -> torch.Tensor:
        """Encode waveform ``wav`` into latents. Must be overridden."""
        raise NotImplementedError(f"{type(self).__name__} must implement encode()")

    def decode(self, latents) -> torch.Tensor:
        """Decode ``latents`` back into a waveform. Must be overridden."""
        raise NotImplementedError(f"{type(self).__name__} must implement decode()")
from .autoencoders import create_autoencoder_from_config, AudioAutoencoder
class StableVAE(AbstractVAE):
    """Expose an ``AudioAutoencoder`` (built from a JSON config and a
    checkpoint) through the ``AbstractVAE`` interface."""

    def __init__(self, vae_ckpt, vae_cfg, sr=48000) -> None:
        """Build the autoencoder from ``vae_cfg`` and load weights from ``vae_ckpt``.

        Args:
            vae_ckpt: Path to a checkpoint readable by ``torch.load`` whose
                ``'state_dict'`` entry holds the autoencoder weights.
            vae_cfg: Path to the JSON config passed to
                ``create_autoencoder_from_config``.
            sr: Sample rate of waveforms passed to ``encode``; they are
                resampled to ``orig_sample_rate`` when the two differ.
        """
        super().__init__()
        import json

        with open(vae_cfg, encoding="utf-8") as f:
            config = json.load(f)
        self.vae: AudioAutoencoder = create_autoencoder_from_config(config)
        # load_state_dict copies values into the existing parameters, so
        # mapping the checkpoint to CPU does not change the result; it only
        # avoids requiring (and filling) the device the checkpoint was saved on.
        checkpoint = torch.load(vae_ckpt, map_location="cpu")
        self.vae.load_state_dict(checkpoint['state_dict'])
        self.sample_rate = sr
        # Resampler from the caller's rate to the model's native rate
        # (identity when they already match).
        self.rsp48k = torchaudio.transforms.Resample(sr, self.orig_sample_rate) if sr != self.orig_sample_rate else nn.Identity()

    @torch.no_grad()
    def encode(self, wav: torch.Tensor, sample=True) -> torch.Tensor:
        """Encode a batch of waveforms into latents (gradients disabled).

        Args:
            wav: Audio at ``self.sample_rate``, shaped (B, T) or (B, C, T).
                Single-channel input is repeated to ``self.vae.in_channels``.
            sample: Unused; kept for interface compatibility.

        Returns:
            The output of ``self.vae.encode_audio``. If the resampled input
            has fewer than 2048 samples, an empty tensor of shape
            (B, channel_dim, 0) with the input's device and dtype instead.
        """
        wav = self.rsp48k(wav)
        # Too short to encode; presumably below the encoder's minimum input
        # length -- TODO confirm the origin of 2048.
        if wav.shape[-1] < 2048:
            return torch.zeros((wav.shape[0], self.channel_dim, 0), device=wav.device, dtype=wav.dtype)
        if wav.ndim == 2:
            wav = wav.unsqueeze(1)  # (B, T) -> (B, 1, T)
        if wav.shape[1] == 1:
            wav = wav.repeat(1, self.vae.in_channels, 1)
        # Presumably (B, latent channels, frames); 64 channels in the original note.
        latent = self.vae.encode_audio(wav)
        return latent

    @torch.no_grad()
    def decode(self, latents: torch.Tensor, **kwargs) -> torch.Tensor:
        """Decode latents with ``self.vae.decode_audio`` (gradients disabled).

        Extra keyword arguments are forwarded to ``decode_audio``.

        NOTE: the output is at the model's native rate (``orig_sample_rate``);
        it is NOT resampled back to the ``sr`` given to the constructor.
        """
        return self.vae.decode_audio(latents, **kwargs)

    @property
    def frame_rate(self) -> float:
        """Latent frames per second: model sample rate / downsampling ratio."""
        return float(self.vae.sample_rate) / self.vae.downsampling_ratio

    @property
    def orig_sample_rate(self) -> int:
        """The wrapped model's native sample rate."""
        return self.vae.sample_rate

    @property
    def channel_dim(self) -> int:
        """Latent channel count (``self.vae.latent_dim``)."""
        return self.vae.latent_dim

    @property
    def split_bands(self) -> int:
        """Always 1 for this wrapper."""
        return 1

    @property
    def input_channel(self) -> int:
        """Audio channels the wrapped model expects (``self.vae.in_channels``)."""
        return self.vae.in_channels