import typing as tp
from abc import ABC, abstractmethod

import torch
from torch import nn
import torchaudio
import einops


class AbstractVAE(ABC, nn.Module):
    """Common interface for audio VAEs used as latent codecs."""

    @property
    @abstractmethod
    def frame_rate(self) -> float:
        ...

    @property
    @abstractmethod
    def orig_sample_rate(self) -> int:
        ...

    @property
    @abstractmethod
    def channel_dim(self) -> int:
        ...

    @property
    @abstractmethod
    def split_bands(self) -> int:
        ...

    @property
    @abstractmethod
    def input_channel(self) -> int:
        ...

    @abstractmethod
    def encode(self, wav: torch.Tensor) -> torch.Tensor:
        ...

    @abstractmethod
    def decode(self, latents: torch.Tensor) -> torch.Tensor:
        ...


from .autoencoders import create_autoencoder_from_config, AudioAutoencoder


class StableVAE(AbstractVAE):
    """Wraps a stable-audio-tools AudioAutoencoder behind the AbstractVAE interface."""

    def __init__(self, vae_ckpt, vae_cfg, sr=48000) -> None:
        super().__init__()
        import json
        with open(vae_cfg) as f:
            config = json.load(f)
        self.vae: AudioAutoencoder = create_autoencoder_from_config(config)
        self.vae.load_state_dict(torch.load(vae_ckpt)['state_dict'])
        self.sample_rate = sr
        # Resample incoming audio to the VAE's native sample rate if needed.
        self.rsp48k = torchaudio.transforms.Resample(sr, self.orig_sample_rate) if sr != self.orig_sample_rate else nn.Identity()

    def encode(self, wav: torch.Tensor, sample=True) -> torch.Tensor:
        wav = self.rsp48k(wav)
        if wav.shape[-1] < 2048:
            # Too short to encode: return an empty latent sequence.
            return torch.zeros((wav.shape[0], self.channel_dim, 0), device=wav.device, dtype=wav.dtype)
        if wav.ndim == 2:
            wav = wav.unsqueeze(1)
        if wav.shape[1] == 1:
            # Duplicate mono audio to match the VAE's expected input channels.
            wav = wav.repeat(1, self.vae.in_channels, 1)
        latent = self.vae.encode_audio(wav)  # B, 64, T
        return latent

    def decode(self, latents: torch.Tensor, **kwargs):
        # latents: B, 64, T
        with torch.no_grad():
            audio_recon = self.vae.decode_audio(latents, **kwargs)
        return audio_recon

    @property
    def frame_rate(self) -> float:
        return float(self.vae.sample_rate) / self.vae.downsampling_ratio

    @property
    def orig_sample_rate(self) -> int:
        return self.vae.sample_rate

    @property
    def channel_dim(self) -> int:
        return self.vae.latent_dim

    @property
    def split_bands(self) -> int:
        return 1

    @property
    def input_channel(self) -> int:
        return self.vae.in_channels
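

# Minimal usage sketch (hypothetical paths): load a VAE checkpoint, encode one
# second of audio, and decode it back. "path/to/vae.ckpt" and
# "path/to/vae_config.json" are placeholders, not files shipped with this repo.
if __name__ == "__main__":
    vae = StableVAE(vae_ckpt="path/to/vae.ckpt", vae_cfg="path/to/vae_config.json", sr=48000)
    wav = torch.zeros(1, 1, 48000)  # (batch, channels, samples), 1 s at 48 kHz
    latents = vae.encode(wav)       # (batch, channel_dim, frames)
    recon = vae.decode(latents)     # waveform at the VAE's native sample rate
    print(latents.shape, recon.shape)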