File size: 3,613 Bytes
37a9836
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""
Helpful functions to process audio
"""

import numpy as np
import soundfile as sf

from typing_extensions import Annotated, Literal, Optional
import torchaudio
import torch

# Type alias for a channel count restricted to the literal values 1 or 2.
# NOTE(review): not referenced anywhere in the visible part of this file
# (`read_audio_file` annotates `channels` as plain `int`) — confirm whether
# it is used elsewhere before relying on or removing it.
AudioChannel = Literal[1, 2]


def read_audio_file(
    path: str,
    target_sample_rate: int = 16000,
    channels: int = 1,
    normalize: bool = True,
    max_duration: Optional[float] = None,
) -> np.ndarray:
    """Read an audio file and optionally truncate, downmix, resample and normalize it.

    Processing order: load -> truncate -> downmix -> resample -> normalize.
    If the file's sample rate differs from ``target_sample_rate`` the audio is
    resampled; the resampling runs on CUDA when ``torch.cuda.is_available()``
    and on CPU otherwise. The result is always moved back to CPU.

    Args:
        path: Path to the audio file (anything ``torchaudio.load`` can decode).
        target_sample_rate: Target sample rate in Hz (default: 16000).
        channels: Number of output channels (1 for mono, 2 for stereo).
            Only downmixing is performed: 1 averages all channels, 2 keeps the
            first two. A file with fewer channels than requested is NOT
            upmixed, and values other than 1 or 2 leave the channels as loaded.
        normalize: Whether to peak-normalize the audio to [-1, 1]. The peak is
            taken over all channels jointly; all-zero or empty audio is left
            unchanged.
        max_duration: Maximum duration in seconds (truncates longer files).
            Must be non-negative when given.

    Returns:
        np.ndarray: Processed samples. A 1-D array of shape ``(samples,)`` when
        ``channels == 1`` and the result has a single channel, otherwise a 2-D
        array of shape ``(channels, samples)``.

    Raises:
        RuntimeError: If the arguments are invalid, the file cannot be read, or
            processing fails. The original exception is attached as
            ``__cause__``.
    """
    try:
        # Validate before touching the file: a negative duration would otherwise
        # turn into a negative slice bound and silently trim the END of the audio.
        if max_duration is not None and max_duration < 0:
            raise ValueError(
                f"max_duration must be non-negative, got {max_duration}"
            )

        # Load audio file with torchaudio
        waveform, original_sample_rate = torchaudio.load(path)  # [channels, samples]

        # Truncate to max_duration before resampling (cheaper than resampling
        # audio that is about to be discarded)
        if max_duration is not None:
            max_samples = int(max_duration * original_sample_rate)
            if waveform.size(1) > max_samples:
                waveform = waveform[:, :max_samples]

        # Downmix to desired channels
        if waveform.size(0) > channels:
            if channels == 1:
                waveform = waveform.mean(dim=0, keepdim=True)  # Mono: average channels
            elif channels == 2:
                waveform = waveform[:2, :]  # Stereo: take first 2 channels

        # Zero-length audio (e.g. max_duration=0) has nothing to resample or
        # normalize; skipping avoids reducing over / convolving an empty tensor.
        has_samples = waveform.numel() > 0

        # Resample if needed
        if has_samples and original_sample_rate != target_sample_rate:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            waveform = waveform.to(device)
            resampler = torchaudio.transforms.Resample(
                orig_freq=original_sample_rate,
                new_freq=target_sample_rate,
                resampling_method="sinc_interp_kaiser",
            ).to(device)
            waveform = resampler(waveform)

        # Peak-normalize to [-1, 1] if requested (skip silence to avoid 0/0)
        if normalize and has_samples:
            peak = waveform.abs().max()
            if peak > 0:
                waveform = waveform / peak

        # Move back to CPU and convert to numpy
        data = waveform.cpu().numpy()

        # Ensure correct shape (remove extra dim if mono)
        if channels == 1 and data.shape[0] == 1:
            data = data[0, :]

        return data

    except Exception as e:
        # Chain the original exception so the root cause stays in the traceback.
        raise RuntimeError(f"Failed to read audio file {path}: {e}") from e


def save_audio_file(
    audio_array: np.ndarray, sample_rate: int, file_path: str, format: str = "WAV"
) -> None:
    """
    Save an audio array to a file.

    The file extension is kept consistent with ``format``: if ``file_path``
    does not already end with the extension for the requested format
    (compared case-insensitively), that extension is appended. With the default
    ``format="WAV"`` this means ``".wav"`` is appended when missing.

    Saving is best-effort: any error is printed to stdout and NOT raised, so a
    failed write does not interrupt the caller.

    Parameters:
    - audio_array: numpy array or list containing the audio samples
    - sample_rate: int, the sample rate of the audio (e.g., 44100 Hz)
    - file_path: str, path where the file will be saved (e.g., 'output.wav')
    - format: str, audio file format (e.g., 'WAV', 'FLAC', 'OGG'), default is 'WAV'
    """
    try:
        # Derive the extension from the requested format instead of always
        # forcing ".wav"; otherwise e.g. FLAC data would land in "x.flac.wav".
        # A falsy format falls back to ".wav" to keep the historical behaviour.
        extension = f".{format.lower()}" if format else ".wav"
        if not file_path.lower().endswith(extension):
            file_path += extension
        sf.write(file_path, audio_array, sample_rate, format=format)
    except Exception as e:
        print(f"Error saving audio file at {file_path}: {e}")