sleeper371's picture
add code
37a9836
"""
Helpful functions to process audio
"""
import numpy as np
import soundfile as sf
from typing_extensions import Annotated, Literal, Optional
import torchaudio
import torch
# Type alias restricting a channel count to the literal values 1 or 2
# (read_audio_file documents 1 as mono and 2 as stereo).
# NOTE(review): the function signatures visible in this file annotate
# `channels` as plain `int`, so this alias is not referenced here.
AudioChannel = Literal[1, 2]
def read_audio_file(
    path: str,
    target_sample_rate: int = 16000,
    channels: int = 1,
    normalize: bool = True,
    max_duration: Optional[float] = None,
) -> np.ndarray:
    """Read an audio file and optionally truncate, downmix, resample and normalize it.

    Processing order: load -> truncate -> downmix -> resample -> normalize.
    Resampling only happens when the file's sample rate differs from
    ``target_sample_rate``; it runs on CUDA when ``torch.cuda.is_available()``
    and on CPU otherwise. The result is always returned as a CPU numpy array.

    Args:
        path: Path to the audio file, handed to ``torchaudio.load`` (the
            supported formats, e.g. WAV/FLAC/OGG, depend on the installed
            torchaudio backend).
        target_sample_rate: Output sample rate in Hz (default: 16000).
        channels: Number of output channels (1 for mono, 2 for stereo).
            Only downmixing is performed: with 1 all channels are averaged,
            with 2 the first two channels are kept. A file with fewer channels
            than requested is returned with its own channel count.
        normalize: If True, peak-normalize so the largest absolute sample is 1
            (i.e. the signal lies in [-1, 1]). Silent or empty audio is left
            untouched.
        max_duration: Maximum duration in seconds; longer files are truncated
            before resampling. Expected to be non-negative. ``None`` keeps the
            full file.

    Returns:
        np.ndarray: Audio samples. Shape ``(samples,)`` when ``channels == 1``
        and the processed audio has a single channel, otherwise
        ``(channels, samples)``.

    Raises:
        RuntimeError: If the file cannot be read or any processing step fails.
            The original exception is attached as ``__cause__``.
    """
    try:
        # Load audio file with torchaudio
        waveform, original_sample_rate = torchaudio.load(path)  # [channels, samples]

        # Truncate at the original rate so we never resample audio we discard
        if max_duration is not None:
            max_samples = int(max_duration * original_sample_rate)
            if waveform.size(1) > max_samples:
                waveform = waveform[:, :max_samples]

        # Downmix to the desired channel count (no upmixing is attempted)
        if waveform.size(0) > channels:
            if channels == 1:
                waveform = waveform.mean(dim=0, keepdim=True)  # Mono: average channels
            elif channels == 2:
                waveform = waveform[:2, :]  # Stereo: take first 2 channels

        # Resample if needed, on GPU when one is available
        if original_sample_rate != target_sample_rate:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            waveform = waveform.to(device)
            resampler = torchaudio.transforms.Resample(
                orig_freq=original_sample_rate,
                new_freq=target_sample_rate,
                resampling_method="sinc_interp_kaiser",
            ).to(device)
            waveform = resampler(waveform)

        # Peak-normalize to [-1, 1]. An empty tensor has no maximum (torch
        # raises on the reduction), so skip it instead of failing the read.
        if normalize and waveform.numel() > 0:
            max_val = waveform.abs().max()
            if max_val > 0:
                waveform = waveform / max_val

        # Move back to CPU and convert to numpy
        data = waveform.cpu().numpy()

        # Drop the channel axis for mono output
        if channels == 1 and data.shape[0] == 1:
            data = data[0, :]
        return data
    except Exception as e:
        # Chain the original exception so the root cause stays in the traceback
        raise RuntimeError(f"Failed to read audio file {path}: {e}") from e
def save_audio_file(
    audio_array: np.ndarray, sample_rate: int, file_path: str, format="WAV"
):
    """
    Save an audio array to a file.

    The file extension is derived from ``format``: if ``file_path`` does not
    already end with it (compared case-insensitively), the extension is
    appended. With the default ``format="WAV"`` a missing ``.wav`` is appended.

    Errors are not raised: any failure is reported with ``print`` and the
    function returns ``None``.

    Parameters:
    - audio_array: numpy array or list containing the audio samples.
      NOTE(review): soundfile expects 2-D data as (frames, channels), whereas
      read_audio_file returns multi-channel audio as (channels, samples) —
      confirm whether callers transpose before saving.
    - sample_rate: int, the sample rate of the audio (e.g., 44100 Hz)
    - file_path: str, path where the file will be saved (e.g., 'output.wav')
    - format: str, audio file format (e.g., 'WAV', 'FLAC', 'OGG'), default is 'WAV'
    """
    try:
        # Match the extension to the requested format so that e.g. FLAC data
        # is not written to "<name>.flac.wav". Fall back to ".wav" when no
        # format is given.
        extension = f".{format.lower()}" if format else ".wav"
        if not file_path.lower().endswith(extension):
            file_path += extension
        sf.write(file_path, audio_array, sample_rate, format=format)
    except Exception as e:
        # Best-effort save: report the failure instead of propagating it
        print(f"Error saving audio file at {file_path}: {e}")