from typing import Dict, List, Optional

import h5py
import librosa
import numpy as np
import torch
from pytorch_lightning.core.datamodule import LightningDataModule

from bytesep.data.samplers import DistributedSamplerWrapper
from bytesep.utils import int16_to_float32
class DataModule(LightningDataModule):
    def __init__(
        self,
        train_sampler: object,
        train_dataset: object,
        num_workers: int,
        distributed: bool,
    ):
        r"""Data module.

        Args:
            train_sampler: Sampler object
            train_dataset: Dataset object
            num_workers: int
            distributed: bool
        """
        super().__init__()
        self._train_sampler = train_sampler
        self.train_dataset = train_dataset
        self.num_workers = num_workers
        self.distributed = distributed
    def setup(self, stage: Optional[str] = None) -> None:
        r"""Called on every device."""
        # SegmentSampler is used for selecting segments for training.
        # On multiple devices, each SegmentSampler samples a part of the
        # mini-batch data.
        if self.distributed:
            self.train_sampler = DistributedSamplerWrapper(self._train_sampler)
        else:
            self.train_sampler = self._train_sampler
    def train_dataloader(self) -> torch.utils.data.DataLoader:
        r"""Get train loader."""
        train_loader = torch.utils.data.DataLoader(
            dataset=self.train_dataset,
            batch_sampler=self.train_sampler,
            collate_fn=collate_fn,
            num_workers=self.num_workers,
            pin_memory=True,
        )
        return train_loader
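
# A minimal usage sketch (not part of the module): wiring a sampler and a
# dataset into DataModule and fetching the train loader. `SegmentSampler`
# comes from elsewhere in bytesep; its constructor arguments and the
# segment length below are illustrative assumptions, not the exact API.
#
#     train_sampler = SegmentSampler(...)  # yields lists of metas per mini-batch
#     train_dataset = Dataset(augmentor=None, segment_samples=132300)
#     data_module = DataModule(
#         train_sampler=train_sampler,
#         train_dataset=train_dataset,
#         num_workers=8,
#         distributed=False,
#     )
#     data_module.setup()
#     train_loader = data_module.train_dataloader()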
class Dataset:
    def __init__(self, augmentor: object, segment_samples: int):
        r"""Used for getting data according to a meta.

        Args:
            augmentor: Augmentor class
            segment_samples: int
        """
        self.augmentor = augmentor
        self.segment_samples = segment_samples
    def __getitem__(self, meta: Dict) -> Dict:
        r"""Return data according to a meta. E.g., an input meta looks like: {
            'vocals': [
                {'hdf5_path': 'song_A.h5', 'key_in_hdf5': 'vocals',
                 'begin_sample': 6332760, 'end_sample': 6465060},
                {'hdf5_path': 'song_B.h5', 'key_in_hdf5': 'vocals',
                 'begin_sample': 198450, 'end_sample': 330750},
            ],
            'accompaniment': [
                {'hdf5_path': 'song_C.h5', 'key_in_hdf5': 'accompaniment',
                 'begin_sample': 24232920, 'end_sample': 24365250},
                {'hdf5_path': 'song_D.h5', 'key_in_hdf5': 'accompaniment',
                 'begin_sample': 1569960, 'end_sample': 1702260},
            ],
        }

        Then, the vocals segments of song_A and song_B are mixed (mix-audio
        augmentation), and the accompaniment segments of song_C and song_D are
        mixed (mix-audio augmentation). Finally, the mixture is created by
        summing the vocals and the accompaniment.

        Args:
            meta: dict, see the example above.

        Returns:
            data_dict: dict, e.g., {
                'vocals': (channels_num, segment_samples),
                'accompaniment': (channels_num, segment_samples),
                'mixture': (channels_num, segment_samples),
            }
        """
        source_types = meta.keys()
        data_dict = {}

        for source_type in source_types:
            # E.g., ['vocals', 'bass', ...]

            waveforms = []  # Audio segments to be mix-audio augmented.

            for m in meta[source_type]:
                # E.g., {
                #     'hdf5_path': '.../song_A.h5',
                #     'key_in_hdf5': 'vocals',
                #     'begin_sample': 13406400,
                #     'end_sample': 13538700,
                # }

                hdf5_path = m['hdf5_path']
                key_in_hdf5 = m['key_in_hdf5']
                bgn_sample = m['begin_sample']
                end_sample = m['end_sample']

                with h5py.File(hdf5_path, 'r') as hf:

                    if source_type == 'audioset':
                        index_in_hdf5 = m['index_in_hdf5']
                        waveform = int16_to_float32(
                            hf['waveform'][index_in_hdf5][bgn_sample:end_sample]
                        )
                        waveform = waveform[None, :]
                    else:
                        waveform = int16_to_float32(
                            hf[key_in_hdf5][:, bgn_sample:end_sample]
                        )

                if self.augmentor:
                    waveform = self.augmentor(waveform, source_type)

                waveform = librosa.util.fix_length(
                    waveform, size=self.segment_samples, axis=1
                )
                # (channels_num, segment_samples)

                waveforms.append(waveform)
            # E.g., waveforms: [(channels_num, segment_samples), (channels_num, segment_samples)]

            # Mix-audio augmentation.
            data_dict[source_type] = np.sum(waveforms, axis=0)
            # data_dict[source_type]: (channels_num, segment_samples)

        # data_dict looks like: {
        #     'vocals': (channels_num, segment_samples),
        #     'accompaniment': (channels_num, segment_samples),
        # }

        # Mix segments from different sources.
        mixture = np.sum(
            [data_dict[source_type] for source_type in source_types], axis=0
        )
        data_dict['mixture'] = mixture
        # shape: (channels_num, segment_samples)

        return data_dict
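
# A minimal sketch of how Dataset is driven (illustrative only; the HDF5
# paths, keys, and sample indices below are assumptions, not real files):
#
#     dataset = Dataset(augmentor=None, segment_samples=132300)
#     meta = {
#         'vocals': [
#             {'hdf5_path': 'song_A.h5', 'key_in_hdf5': 'vocals',
#              'begin_sample': 0, 'end_sample': 132300},
#         ],
#         'accompaniment': [
#             {'hdf5_path': 'song_C.h5', 'key_in_hdf5': 'accompaniment',
#              'begin_sample': 0, 'end_sample': 132300},
#         ],
#     }
#     data_dict = dataset[meta]
#     # data_dict['mixture'].shape == (channels_num, 132300)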
def collate_fn(list_data_dict: List[Dict]) -> Dict:
    r"""Collate mini-batch data to inputs and targets for training.

    Args:
        list_data_dict: e.g., [
            {'vocals': (channels_num, segment_samples),
             'accompaniment': (channels_num, segment_samples),
             'mixture': (channels_num, segment_samples),
            },
            {'vocals': (channels_num, segment_samples),
             'accompaniment': (channels_num, segment_samples),
             'mixture': (channels_num, segment_samples),
            },
            ...]

    Returns:
        data_dict: e.g., {
            'vocals': (batch_size, channels_num, segment_samples),
            'accompaniment': (batch_size, channels_num, segment_samples),
            'mixture': (batch_size, channels_num, segment_samples),
        }
    """
    data_dict = {}

    for key in list_data_dict[0].keys():
        # Stack per-example arrays along a new batch dimension.
        data_dict[key] = torch.Tensor(
            np.array([single_data_dict[key] for single_data_dict in list_data_dict])
        )

    return data_dict
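
# A minimal sketch of collate_fn behaviour (the shapes are assumptions for
# illustration): per-example dicts of (channels_num, segment_samples) arrays
# are stacked into (batch_size, channels_num, segment_samples) tensors.
#
#     example = {
#         'vocals': np.zeros((2, 132300), dtype=np.float32),
#         'accompaniment': np.zeros((2, 132300), dtype=np.float32),
#         'mixture': np.zeros((2, 132300), dtype=np.float32),
#     }
#     batch = collate_fn([example, example])
#     # batch['mixture'].shape == torch.Size([2, 2, 132300])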