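"""ODCNN inference pipeline: loads two convNet onset-detection models (don and
ka), extracts mel features from an audio file, and produces a synthesized .mp3
plus a multi-difficulty TJA chart."""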
import tempfile
from pathlib import Path
from typing import Tuple

import numpy as np
import soundfile as sf
import torch

from model import convNet
from preprocess import Audio, fft_and_melscale
from synthesize import create_tja, detect, synthesize
def trim_silence(data: np.ndarray, sr: int) -> np.ndarray:
    """Cut leading and trailing silence, keeping a 3-second margin on each side."""
    start = 0
    end = len(data) - 1
    # Scan forward past quiet samples at the start, backward at the end.
    while start < len(data) and np.abs(data[start]) < 0.2:
        start += 1
    while end > 0 and np.abs(data[end]) < 0.1:
        end -= 1
    # Keep three seconds of padding around the detected audio.
    start = max(start - sr * 3, 0)
    end = min(end + sr * 3, len(data))
    print(
        f"Trimming {start / sr} seconds from the start "
        f"and {(len(data) - end) / sr} seconds from the end"
    )
    return data[start:end]
class ODCNN:
    def __init__(self, don_model: str, ka_model: str, device: str = "cpu"):
        # Load the two onset-detection networks from their checkpoints.
        donNet = convNet().to(device)
        donNet.load_state_dict(torch.load(don_model, map_location="cpu"))
        self.donNet = donNet

        kaNet = convNet().to(device)
        kaNet.load_state_dict(torch.load(ka_model, map_location="cpu"))
        self.kaNet = kaNet

        self.device = device
    def run(self, file: str, delta: float = 0.05, trim: bool = True) -> Tuple[str, str]:
        data, sr = sf.read(file, always_2d=True)
        song = Audio(data, sr)
        # Downmix to mono before feature extraction.
        song.data = song.data.mean(axis=1)
        if trim:
            song.data = trim_silence(song.data, sr)
        # Multi-resolution mel spectrograms: 80 bands, 27.5 Hz to 16 kHz.
        song.feats = fft_and_melscale(
            song,
            nhop=512,
            nffts=[1024, 2048, 4096],
            mel_nband=80,
            mel_freqlo=27.5,
            mel_freqhi=16000.0,
        )
        # Flatten the frame-wise onset activations from each network.
        don_inference = self.donNet.infer(song.feats, self.device, minibatch=4192)
        don_inference = np.reshape(don_inference, -1)
        ka_inference = self.kaNet.infer(song.feats, self.device, minibatch=4192)
        ka_inference = np.reshape(ka_inference, -1)

        # Lower detection thresholds keep more onsets, giving denser charts
        # for the harder difficulties.
        easy_detection = detect(don_inference, ka_inference, delta=0.25)
        normal_detection = detect(don_inference, ka_inference, delta=0.2)
        hard_detection = detect(don_inference, ka_inference, delta=0.15)
        oni_detection = detect(don_inference, ka_inference, delta=0.075)
        ura_detection = detect(don_inference, ka_inference, delta=delta)
        # Write a synthesized rendition of the hard-difficulty detections
        # to a temporary .mp3, then build the TJA chart from all five.
        synthesized_path = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False).name
        synthesize(*hard_detection, song, synthesized_path)

        file = Path(file)
        tja = create_tja(
            song,
            timestamps=[
                easy_detection,
                normal_detection,
                hard_detection,
                oni_detection,
                ura_detection,
            ],
            title=file.stem,
            wave=file.name,
        )
        return synthesized_path, tja
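
# A minimal usage sketch, assuming trained don/ka checkpoints are available;
# the filenames below ("don_model.pth", "ka_model.pth", "song.ogg") are
# hypothetical placeholders, not files shipped with this module.
if __name__ == "__main__":
    odcnn = ODCNN("don_model.pth", "ka_model.pth", device="cpu")
    preview_path, tja = odcnn.run("song.ogg", delta=0.05, trim=True)
    print(f"Preview audio: {preview_path}")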