import argparse
import logging
import os
import pathlib
from functools import partial
from typing import NoReturn, Tuple

import pytorch_lightning as pl
from pytorch_lightning.plugins import DDPPlugin

from bytesep.callbacks import get_callbacks
from bytesep.data.augmentors import Augmentor
from bytesep.data.batch_data_preprocessors import (
    get_batch_data_preprocessor_class,
)
from bytesep.data.data_modules import DataModule, Dataset
from bytesep.data.samplers import SegmentSampler
from bytesep.losses import get_loss_function
from bytesep.models.lightning_modules import (
    LitSourceSeparation,
    get_model_class,
)
from bytesep.optimizers.lr_schedulers import get_lr_lambda
from bytesep.utils import (
    create_logging,
    get_pitch_shift_factor,
    read_yaml,
    check_configs_gramma,
)


def get_dirs(
    workspace: str, task_name: str, filename: str, config_yaml: str, gpus: int
) -> Tuple[str, str, pl.loggers.TensorBoardLogger, str]:
    r"""Get directories for checkpoints, logs, and statistics.

    Args:
        workspace: str
        task_name: str, e.g., 'musdb18'
        filename: str
        config_yaml: str
        gpus: int, e.g., 0 for CPU-only and 8 for training with 8 GPU cards

    Returns:
        checkpoints_dir: str
        logs_dir: str
        logger: pl.loggers.TensorBoardLogger
        statistics_path: str
    """
    # checkpoints directory
    checkpoints_dir = os.path.join(
        workspace,
        "checkpoints",
        task_name,
        filename,
        "config={},gpus={}".format(pathlib.Path(config_yaml).stem, gpus),
    )
    os.makedirs(checkpoints_dir, exist_ok=True)

    # logs directory
    logs_dir = os.path.join(
        workspace,
        "logs",
        task_name,
        filename,
        "config={},gpus={}".format(pathlib.Path(config_yaml).stem, gpus),
    )
    os.makedirs(logs_dir, exist_ok=True)

    # logging
    create_logging(logs_dir, filemode='w')
    logging.info(
        "workspace={}, task_name={}, config_yaml={}, gpus={}".format(
            workspace, task_name, config_yaml, gpus
        )
    )

    # tensorboard logs directory
    tb_logs_dir = os.path.join(workspace, "tensorboard_logs")
    os.makedirs(tb_logs_dir, exist_ok=True)

    experiment_name = os.path.join(task_name, filename, pathlib.Path(config_yaml).stem)
    logger = pl.loggers.TensorBoardLogger(save_dir=tb_logs_dir, name=experiment_name)

    # statistics path
    statistics_path = os.path.join(
        workspace,
        "statistics",
        task_name,
        filename,
        "config={},gpus={}".format(pathlib.Path(config_yaml).stem, gpus),
        "statistics.pkl",
    )
    os.makedirs(os.path.dirname(statistics_path), exist_ok=True)

    return checkpoints_dir, logs_dir, logger, statistics_path
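
# Illustration (hypothetical values): calling get_dirs with
# workspace="./workspaces/bytesep", task_name="musdb18", filename="train",
# config_yaml="configs/vocals.yaml", and gpus=1 creates and returns, e.g.,
# checkpoints_dir="./workspaces/bytesep/checkpoints/musdb18/train/config=vocals,gpus=1",
# with logs, tensorboard_logs, and statistics laid out the same way.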


def _get_data_module(
    workspace: str, config_yaml: str, num_workers: int, distributed: bool
) -> DataModule:
    r"""Create data_module. Mini-batch data can be obtained by:

    .. code-block:: python

        data_module.setup()

        for batch_data_dict in data_module.train_dataloader():
            print(batch_data_dict.keys())
            break

    Args:
        workspace: str
        config_yaml: str
        num_workers: int, e.g., 0 for non-parallel loading and 8 for preparing
            data in parallel with 8 CPU workers
        distributed: bool

    Returns:
        data_module: DataModule
    """
    configs = read_yaml(config_yaml)
    input_source_types = configs['train']['input_source_types']
    indexes_path = os.path.join(workspace, configs['train']['indexes_dict'])
    sample_rate = configs['train']['sample_rate']
    segment_seconds = configs['train']['segment_seconds']
    mixaudio_dict = configs['train']['augmentations']['mixaudio']
    augmentations = configs['train']['augmentations']
    max_pitch_shift = max(
        [
            augmentations['pitch_shift'][source_type]
            for source_type in input_source_types
        ]
    )
    batch_size = configs['train']['batch_size']
    steps_per_epoch = configs['train']['steps_per_epoch']

    segment_samples = int(segment_seconds * sample_rate)
    ex_segment_samples = int(segment_samples * get_pitch_shift_factor(max_pitch_shift))
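    # Longer segments are sampled so that pitch-shift augmentation (which
    # resamples audio) still leaves at least segment_samples of material.
    # Assuming get_pitch_shift_factor(s) = 2 ** (s / 12), i.e. the resampling
    # ratio for a shift of s semitones, a max shift of 3 semitones gives a
    # factor of 2 ** (3 / 12) ≈ 1.19, so segments are ~19% longer.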
    # sampler
    train_sampler = SegmentSampler(
        indexes_path=indexes_path,
        segment_samples=ex_segment_samples,
        mixaudio_dict=mixaudio_dict,
        batch_size=batch_size,
        steps_per_epoch=steps_per_epoch,
    )

    # augmentor
    augmentor = Augmentor(augmentations=augmentations)

    # dataset
    train_dataset = Dataset(augmentor, segment_samples)

    # data module
    data_module = DataModule(
        train_sampler=train_sampler,
        train_dataset=train_dataset,
        num_workers=num_workers,
        distributed=distributed,
    )

    return data_module
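
# Usage sketch (hypothetical paths; assumes a valid training config YAML):
#
#     data_module = _get_data_module(
#         workspace="./workspaces/bytesep",
#         config_yaml="configs/vocals.yaml",
#         num_workers=8,
#         distributed=False,
#     )
#     data_module.setup()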


def train(args) -> NoReturn:
    r"""Train, evaluate, and save checkpoints.

    Args:
        args: argparse.Namespace with attributes:
            workspace: str, directory of workspace
            gpus: int
            config_yaml: str, path of config file for training
    """
    # arguments & parameters
    workspace = args.workspace
    gpus = args.gpus
    config_yaml = args.config_yaml
    filename = args.filename

    num_workers = 8
    distributed = gpus > 1
    evaluate_device = "cuda" if gpus > 0 else "cpu"

    # Read config file.
    configs = read_yaml(config_yaml)
    check_configs_gramma(configs)
    task_name = configs['task_name']
    target_source_types = configs['train']['target_source_types']
    target_sources_num = len(target_source_types)
    channels = configs['train']['channels']
    batch_data_preprocessor_type = configs['train']['batch_data_preprocessor']
    model_type = configs['train']['model_type']
    loss_type = configs['train']['loss_type']
    optimizer_type = configs['train']['optimizer_type']
    learning_rate = float(configs['train']['learning_rate'])
    precision = configs['train']['precision']
    early_stop_steps = configs['train']['early_stop_steps']
    warm_up_steps = configs['train']['warm_up_steps']
    reduce_lr_steps = configs['train']['reduce_lr_steps']

    # paths
    checkpoints_dir, logs_dir, logger, statistics_path = get_dirs(
        workspace, task_name, filename, config_yaml, gpus
    )

    # training data module
    data_module = _get_data_module(
        workspace=workspace,
        config_yaml=config_yaml,
        num_workers=num_workers,
        distributed=distributed,
    )

    # batch data preprocessor
    BatchDataPreprocessor = get_batch_data_preprocessor_class(
        batch_data_preprocessor_type=batch_data_preprocessor_type
    )

    batch_data_preprocessor = BatchDataPreprocessor(
        target_source_types=target_source_types
    )

    # model
    Model = get_model_class(model_type=model_type)
    model = Model(input_channels=channels, target_sources_num=target_sources_num)

    # loss function
    loss_function = get_loss_function(loss_type=loss_type)

    # callbacks
    callbacks = get_callbacks(
        task_name=task_name,
        config_yaml=config_yaml,
        workspace=workspace,
        checkpoints_dir=checkpoints_dir,
        statistics_path=statistics_path,
        logger=logger,
        model=model,
        evaluate_device=evaluate_device,
    )

    # learning rate lambda function
    lr_lambda = partial(
        get_lr_lambda, warm_up_steps=warm_up_steps, reduce_lr_steps=reduce_lr_steps
    )
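    # lr_lambda maps the global optimizer step to a multiplier on learning_rate;
    # the exact schedule lives in bytesep.optimizers.lr_schedulers.get_lr_lambda.
    # A typical scheme (assumed here) is a linear ramp over the first
    # warm_up_steps steps followed by a stepwise decay every reduce_lr_steps steps.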
    # pytorch-lightning model
    pl_model = LitSourceSeparation(
        batch_data_preprocessor=batch_data_preprocessor,
        model=model,
        optimizer_type=optimizer_type,
        loss_function=loss_function,
        learning_rate=learning_rate,
        lr_lambda=lr_lambda,
    )

    # trainer
    trainer = pl.Trainer(
        checkpoint_callback=False,
        gpus=gpus,
        callbacks=callbacks,
        max_steps=early_stop_steps,
        accelerator="ddp",
        sync_batchnorm=True,
        precision=precision,
        replace_sampler_ddp=False,
        plugins=[DDPPlugin(find_unused_parameters=True)],
        profiler='simple',
    )

    # Fit, evaluate, and save checkpoints.
    trainer.fit(pl_model, data_module)


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="")
    subparsers = parser.add_subparsers(dest="mode")

    parser_train = subparsers.add_parser("train")
    parser_train.add_argument(
        "--workspace", type=str, required=True, help="Directory of workspace."
    )
    parser_train.add_argument(
        "--gpus", type=int, required=True, help="Number of GPUs; 0 for CPU-only."
    )
    parser_train.add_argument(
        "--config_yaml",
        type=str,
        required=True,
        help="Path of config file for training.",
    )

    args = parser.parse_args()
    args.filename = pathlib.Path(__file__).stem

    if args.mode == "train":
        train(args)
    else:
        raise Exception("Incorrect mode: {}!".format(args.mode))
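
# Usage sketch (hypothetical paths; adapt workspace and config to your setup):
#
#     CUDA_VISIBLE_DEVICES=0 python3 train.py train \
#         --workspace="./workspaces/bytesep" \
#         --gpus=1 \
#         --config_yaml="configs/vocals.yaml"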