Spaces:
Runtime error
Runtime error
from typing import Dict, List, Tuple

import torch
class BasicBatchDataPreprocessor:
    r"""Unconditional batch data preprocessor.

    Passes the mixture through unchanged as the model input and concatenates
    the waveforms of all target sources along the channel dimension (dim=1)
    to form the training target.
    """

    def __init__(self, target_source_types: List[str]):
        r"""Batch data preprocessor. Used for preparing mixtures and targets for
        training. If there are multiple target source types, the waveforms of
        those sources are concatenated along the channel dimension.

        Args:
            target_source_types: List[str], e.g., ['vocals', 'bass', ...]

        Raises:
            ValueError: if ``target_source_types`` is empty; otherwise the
                failure would only surface later as a ``torch.cat`` error on
                an empty list.
        """
        if not target_source_types:
            raise ValueError("target_source_types must not be empty.")
        self.target_source_types = target_source_types

    def __call__(self, batch_data_dict: Dict) -> Tuple[Dict, Dict]:
        r"""Format waveforms and targets for training.

        Args:
            batch_data_dict: dict, e.g., {
                'mixture': (batch_size, channels_num, segment_samples),
                'vocals': (batch_size, channels_num, segment_samples),
                'bass': (batch_size, channels_num, segment_samples),
                ...,
            }
            It must contain 'mixture' and every key in
            ``self.target_source_types``.

        Returns:
            A tuple ``(input_dict, target_dict)``:
            input_dict: dict, e.g., {
                'waveform': (batch_size, channels_num, segment_samples),
            }
            target_dict: dict, e.g., {
                'waveform': (batch_size, target_sources_num * channels_num, segment_samples)
            }
        """
        mixtures = batch_data_dict['mixture']
        # mixtures: (batch_size, channels_num, segment_samples)

        # Concatenate waveforms of multiple targets along the channel axis,
        # in the order given by self.target_source_types.
        targets = torch.cat(
            [batch_data_dict[source_type] for source_type in self.target_source_types],
            dim=1,
        )
        # targets: (batch_size, target_sources_num * channels_num, segment_samples)

        input_dict = {'waveform': mixtures}
        target_dict = {'waveform': targets}

        return input_dict, target_dict
class ConditionalSisoBatchDataPreprocessor:
    r"""Conditional single-input single-output (SISO) batch data preprocessor.

    Example ``n`` of a batch is paired with the target source at index
    ``n % target_sources_num`` (round-robin). Its condition is the one-hot
    vector of that index.
    """

    def __init__(self, target_source_types: List[str]):
        r"""Conditional single input single output (SISO) batch data
        preprocessor. Select one target source from several target sources as
        training target and prepare the corresponding conditional vector.

        Args:
            target_source_types: List[str], e.g., ['vocals', 'bass', ...]

        Raises:
            ValueError: if ``target_source_types`` is empty; otherwise every
                call would fail with a division by zero.
        """
        if not target_source_types:
            raise ValueError("target_source_types must not be empty.")
        self.target_source_types = target_source_types

    def __call__(self, batch_data_dict: Dict) -> Tuple[Dict, Dict]:
        r"""Format waveforms, conditions and targets for training.

        Args:
            batch_data_dict: dict, e.g., {
                'mixture': (batch_size, channels_num, segment_samples),
                'vocals': (batch_size, channels_num, segment_samples),
                'bass': (batch_size, channels_num, segment_samples),
                ...,
            }
            It must contain 'mixture' and every key in
            ``self.target_source_types``.

        Returns:
            A tuple ``(input_dict, target_dict)``:
            input_dict: dict, e.g., {
                'waveform': (batch_size, channels_num, segment_samples),
                'condition': (batch_size, target_sources_num),
            }
            target_dict: dict, e.g., {
                'waveform': (batch_size, channels_num, segment_samples)
            }

        Raises:
            ValueError: if the batch size is not a multiple of the number of
                target source types.
        """
        mixtures = batch_data_dict['mixture']
        # mixtures: (batch_size, channels_num, segment_samples)

        batch_size = len(mixtures)
        target_sources_num = len(self.target_source_types)

        # Explicit check instead of `assert` so it survives `python -O`.
        # Divisibility keeps every target source equally represented.
        if batch_size % target_sources_num != 0:
            raise ValueError(
                "Batch size should be evenly divided by target sources number. "
                f"Got batch_size={batch_size}, "
                f"target_sources_num={target_sources_num}."
            )

        # Round-robin: row n gets a 1 in column n % target_sources_num.
        # Allocate directly on the mixtures' device (no host alloc + copy).
        rows = torch.arange(batch_size, device=mixtures.device)
        conditions = torch.zeros(
            batch_size, target_sources_num, device=mixtures.device
        )
        conditions[rows, rows % target_sources_num] = 1
        # conditions: (batch_size, target_sources_num), e.g. for 4 sources:
        # [[1, 0, 0, 0],
        #  [0, 1, 0, 0],
        #  [0, 0, 1, 0],
        #  [0, 0, 0, 1],
        #  [1, 0, 0, 0],
        #  [0, 1, 0, 0],
        #  ...,
        # ]

        # Pick example n from the waveform of its assigned source.
        source_types = self.target_source_types
        targets = torch.stack(
            [
                batch_data_dict[source_types[n % target_sources_num]][n]
                for n in range(batch_size)
            ],
            dim=0,
        )
        # targets: (batch_size, channels_num, segment_samples)

        input_dict = {
            'waveform': mixtures,
            'condition': conditions,
        }
        target_dict = {'waveform': targets}

        return input_dict, target_dict
def get_batch_data_preprocessor_class(batch_data_preprocessor_type: str) -> type:
    r"""Get batch data preprocessor class by name.

    Args:
        batch_data_preprocessor_type: str, either 'BasicBatchDataPreprocessor'
            or 'ConditionalSisoBatchDataPreprocessor'.

    Returns:
        The (uninstantiated) preprocessor class.

    Raises:
        NotImplementedError: if ``batch_data_preprocessor_type`` is not one of
            the supported names.
    """
    if batch_data_preprocessor_type == 'BasicBatchDataPreprocessor':
        return BasicBatchDataPreprocessor

    if batch_data_preprocessor_type == 'ConditionalSisoBatchDataPreprocessor':
        return ConditionalSisoBatchDataPreprocessor

    # Name the offending value so a typo in a config is easy to diagnose.
    raise NotImplementedError(
        f"Unknown batch_data_preprocessor_type: {batch_data_preprocessor_type!r}"
    )