"""
Helpful functions to process audio
"""

import numpy as np
import soundfile as sf
from typing_extensions import Annotated, Literal, Optional
import torchaudio
import torch

# Valid values for the number of output channels: 1 (mono) or 2 (stereo).
AudioChannel = Literal[1, 2]


def read_audio_file(
    path: str,
    target_sample_rate: int = 16000,
    channels: AudioChannel = 1,
    normalize: bool = True,
    max_duration: Optional[float] = None,
) -> np.ndarray:
    """Read an audio file and optionally truncate, remix, resample and normalize it.

    Resampling only happens when ``target_sample_rate`` differs from the file's
    sample rate; it runs on CUDA when ``torch.cuda.is_available()``, otherwise on
    CPU.

    Args:
        path: Path to the audio file. Any format ``torchaudio.load`` can decode
            with the installed backend (e.g. WAV, FLAC, OGG).
        target_sample_rate: Output sample rate in Hz (default: 16000).
        channels: Number of output channels. ``1`` averages all input channels
            into mono. ``2`` keeps the first two channels of a multi-channel
            file, or duplicates the single channel of a mono file.
        normalize: If True, scale the audio so its peak absolute sample is 1.0.
            All-zero (silent) or empty audio is left unchanged.
        max_duration: If given, keep only the first ``max_duration`` seconds.
            Truncation is applied before resampling.

    Returns:
        np.ndarray: Audio samples with shape ``(samples,)`` for mono output, or
        ``(2, samples)`` for stereo output.

    Raises:
        ValueError: If ``channels`` is not 1 or 2.
        RuntimeError: If the file cannot be read or processing fails. The
            underlying exception is chained as ``__cause__``.
    """
    if channels not in (1, 2):
        raise ValueError(f"channels must be 1 or 2, got {channels!r}")

    try:
        # torchaudio returns a [channels, samples] tensor plus the file's rate.
        waveform, original_sample_rate = torchaudio.load(path)

        # Truncate before resampling so no work is spent on discarded audio.
        if max_duration is not None:
            max_samples = int(max_duration * original_sample_rate)
            if waveform.size(1) > max_samples:
                waveform = waveform[:, :max_samples]

        # Bring the channel count to the requested value.
        num_channels = waveform.size(0)
        if channels == 1 and num_channels > 1:
            waveform = waveform.mean(dim=0, keepdim=True)  # average to mono
        elif channels == 2:
            if num_channels > 2:
                waveform = waveform[:2, :]  # keep the first two channels
            elif num_channels == 1:
                # Upmix mono to stereo by duplicating the single channel.
                waveform = waveform.repeat(2, 1)

        if original_sample_rate != target_sample_rate:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            resampler = torchaudio.transforms.Resample(
                orig_freq=original_sample_rate,
                new_freq=target_sample_rate,
                resampling_method="sinc_interp_kaiser",  # Kaiser-windowed sinc
            ).to(device)
            waveform = resampler(waveform.to(device))

        # Peak-normalize to [-1, 1]. .max() raises on an empty tensor, so
        # skip empty input (e.g. max_duration=0).
        if normalize and waveform.numel() > 0:
            peak = waveform.abs().max()
            if peak > 0:
                waveform = waveform / peak

        data = waveform.cpu().numpy()

        # Drop the channel axis for mono output.
        if channels == 1 and data.shape[0] == 1:
            data = data[0, :]

        return data
    except Exception as e:
        raise RuntimeError(f"Failed to read audio file {path}: {str(e)}") from e


def save_audio_file(
    audio_array: np.ndarray, sample_rate: int, file_path: str, format="WAV"
):
    """Save an audio array to a file (best effort: errors are printed, not raised).

    If ``file_path`` does not already end with the extension matching
    ``format`` (compared case-insensitively), that extension is appended. For
    example, ``format="FLAC"`` turns ``"out"`` into ``"out.flac"``. When
    ``format`` is falsy, ``.wav`` is used as the extension.

    Args:
        audio_array: Audio samples as a numpy array or list, in any layout
            accepted by ``soundfile.write``.
        sample_rate: Sample rate of the audio in Hz (e.g. 44100).
        file_path: Destination path (e.g. ``'output.wav'``).
        format: Audio container format passed to ``soundfile.write``
            (e.g. ``'WAV'``, ``'FLAC'``, ``'OGG'``). Default is ``'WAV'``.
    """
    try:
        # Derive the extension from the requested format so the file name
        # matches its contents (previously ".wav" was always appended).
        extension = f".{format.lower()}" if format else ".wav"
        if not file_path.lower().endswith(extension):
            file_path += extension
        sf.write(file_path, audio_array, sample_rate, format=format)
    except Exception as e:
        # Deliberately best-effort: report the failure without raising.
        print(f"Error saving audio file at {file_path}: {e}")