import tempfile
from pathlib import Path
from typing import Tuple

import numpy as np
import soundfile as sf
import torch

from model import convNet
from preprocess import Audio, fft_and_melscale
from synthesize import create_tja, detect, synthesize


def trim_silence(data: np.ndarray, sr: int) -> np.ndarray:
    """Trim near-silent leading and trailing audio, keeping 3 seconds of padding."""
    start = 0
    end = len(data) - 1
    # Advance past leading samples below the amplitude threshold.
    while start < len(data) and np.abs(data[start]) < 0.2:
        start += 1
    # Walk back past trailing samples below the amplitude threshold.
    while end > 0 and np.abs(data[end]) < 0.1:
        end -= 1
    # Keep 3 seconds of padding on either side of the detected region.
    start = max(start - sr * 3, 0)
    end = min(end + sr * 3, len(data))
    print(
        f"Trimming {start / sr:.2f} seconds from the start and "
        f"{(len(data) - end) / sr:.2f} seconds from the end"
    )
    return data[start:end]


class ODCNN:
    """Onset-detection CNN wrapper holding separate networks for don and ka notes."""

    def __init__(self, don_model: str, ka_model: str, device: torch.device = torch.device("cpu")):
        # Network for "don" note onsets.
        donNet = convNet()
        donNet = donNet.to(device)
        donNet.load_state_dict(torch.load(don_model, map_location="cpu"))
        self.donNet = donNet

        # Network for "ka" note onsets.
        kaNet = convNet()
        kaNet = kaNet.to(device)
        kaNet.load_state_dict(torch.load(ka_model, map_location="cpu"))
        self.kaNet = kaNet

        self.device = device

    def run(self, file: str, delta: float = 0.05, trim: bool = True) -> Tuple[str, str]:
        """Detect onsets in `file` and return (synthesized audio path, TJA chart)."""
        data, sr = sf.read(file, always_2d=True)
        song = Audio(data, sr)
        # Downmix to mono before feature extraction.
        song.data = song.data.mean(axis=1)
        if trim:
            song.data = trim_silence(song.data, sr)
        # Multi-resolution STFT + mel-scale features used as network input.
        song.feats = fft_and_melscale(
            song,
            nhop=512,
            nffts=[1024, 2048, 4096],
            mel_nband=80,
            mel_freqlo=27.5,
            mel_freqhi=16000.0,
        )

        # Per-frame onset activations for don and ka notes, flattened to 1-D.
        don_inference = self.donNet.infer(song.feats, self.device, minibatch=4192)
        don_inference = np.reshape(don_inference, (-1))

        ka_inference = self.kaNet.infer(song.feats, self.device, minibatch=4192)
        ka_inference = np.reshape(ka_inference, (-1))

        # One detection pass per difficulty; harder difficulties use a lower threshold.
        easy_detection = detect(don_inference, ka_inference, delta=0.25)
        normal_detection = detect(don_inference, ka_inference, delta=0.2)
        hard_detection = detect(don_inference, ka_inference, delta=0.15)
        oni_detection = detect(don_inference, ka_inference, delta=0.075)
        ura_detection = detect(don_inference, ka_inference, delta=delta)

        # Render the Hard-difficulty detections over the song into a temporary MP3.
        synthesized_path = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False).name
        synthesize(*hard_detection, song, synthesized_path)

        # Build a TJA chart covering all five difficulties, titled after the input file.
        file = Path(file)
        tja = create_tja(
            song,
            timestamps=[
                easy_detection,
                normal_detection,
                hard_detection,
                oni_detection,
                ura_detection,
            ],
            title=file.stem,
            wave=file.name,
        )

        return synthesized_path, tja
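

# Example usage sketch (the checkpoint paths "models/don_model.pth" and
# "models/ka_model.pth" and the input "song.ogg" below are placeholders,
# not files provided by this module):
#
#     odcnn = ODCNN("models/don_model.pth", "models/ka_model.pth", torch.device("cpu"))
#     audio_path, tja = odcnn.run("song.ogg", delta=0.05, trim=True)
#     Path("song.tja").write_text(tja)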