import os
import torch
import torch.nn as nn
from utils.utils_bnorm import merge_bn, tidy_sequential
from torch.nn.parallel import DataParallel, DistributedDataParallel


class ModelBase():
    def __init__(self, opt):
        self.opt = opt                         # opt
        self.save_dir = opt['path']['models']  # save models
        self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu')
        self.is_train = opt['is_train']        # training or not
        self.schedulers = []                   # schedulers
| """ | |
| # ---------------------------------------- | |
| # Preparation before training with data | |
| # Save model during training | |
| # ---------------------------------------- | |
| """ | |
| def init_train(self): | |
| pass | |
| def load(self): | |
| pass | |
| def save(self, label): | |
| pass | |
| def define_loss(self): | |
| pass | |
| def define_optimizer(self): | |
| pass | |
| def define_scheduler(self): | |
| pass | |
| """ | |
| # ---------------------------------------- | |
| # Optimization during training with data | |
| # Testing/evaluation | |
| # ---------------------------------------- | |
| """ | |
| def feed_data(self, data): | |
| pass | |
| def optimize_parameters(self): | |
| pass | |
| def current_visuals(self): | |
| pass | |
| def current_losses(self): | |
| pass | |
    def update_learning_rate(self, n):
        # n is the current step index; note that passing a step to
        # scheduler.step() is deprecated in recent PyTorch versions
        for scheduler in self.schedulers:
            scheduler.step(n)

    def current_learning_rate(self):
        # get_lr() is deprecated outside of step() in recent PyTorch;
        # get_last_lr() is the preferred replacement there
        return self.schedulers[0].get_lr()[0]

    def requires_grad(self, model, flag=True):
        # enable/disable gradient computation for all parameters of model
        for p in model.parameters():
            p.requires_grad = flag
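
    # ------------------------------------------------------------------
    # Usage sketch (illustrative, not part of the original training code):
    # in a GAN-style subclass, requires_grad() is typically used to freeze
    # the discriminator while the generator is optimized. `self.netD` and
    # `self.G_optimizer` below are hypothetical subclass attributes.
    #
    #   self.requires_grad(self.netD, False)  # freeze D for the G step
    #   self.G_optimizer.zero_grad()
    #   # ... generator forward / backward ...
    #   self.G_optimizer.step()
    #   self.requires_grad(self.netD, True)   # unfreeze D for its own step
    # ------------------------------------------------------------------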
| """ | |
| # ---------------------------------------- | |
| # Information of net | |
| # ---------------------------------------- | |
| """ | |
| def print_network(self): | |
| pass | |
| def info_network(self): | |
| pass | |
| def print_params(self): | |
| pass | |
| def info_params(self): | |
| pass | |
| def get_bare_model(self, network): | |
| """Get bare model, especially under wrapping with | |
| DistributedDataParallel or DataParallel. | |
| """ | |
| if isinstance(network, (DataParallel, DistributedDataParallel)): | |
| network = network.module | |
| return network | |
    def model_to_device(self, network):
        """Model to device. It also wraps models with DistributedDataParallel
        or DataParallel.

        Args:
            network (nn.Module)
        """
        network = network.to(self.device)
        if self.opt['dist']:
            find_unused_parameters = self.opt.get('find_unused_parameters', True)
            use_static_graph = self.opt.get('use_static_graph', False)
            network = DistributedDataParallel(network, device_ids=[torch.cuda.current_device()], find_unused_parameters=find_unused_parameters)
            if use_static_graph:
                print('Using static graph. Make sure that "unused parameters" will not change during training loop.')
                network._set_static_graph()
        else:
            network = DataParallel(network)
        return network
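
    # ------------------------------------------------------------------
    # Note (illustrative): with opt['dist'] enabled, model_to_device()
    # assumes one process per GPU, with the process group initialized and
    # the current CUDA device bound to the local rank beforehand, e.g.
    # (sketch; `local_rank` is a hypothetical value from the launcher):
    #
    #   torch.distributed.init_process_group(backend='nccl')
    #   torch.cuda.set_device(local_rank)
    # ------------------------------------------------------------------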
    # ----------------------------------------
    # network name and number of parameters
    # ----------------------------------------
    def describe_network(self, network):
        network = self.get_bare_model(network)
        msg = '\n'
        msg += 'Network name: {}'.format(network.__class__.__name__) + '\n'
        msg += 'Params number: {}'.format(sum(map(lambda x: x.numel(), network.parameters()))) + '\n'
        msg += 'Net structure:\n{}'.format(str(network)) + '\n'
        return msg
    # ----------------------------------------
    # parameters description
    # ----------------------------------------
    def describe_params(self, network):
        network = self.get_bare_model(network)
        msg = '\n'
        msg += ' | {:^6s} | {:^6s} | {:^6s} | {:^6s} | {:^6s} || {:<20s}'.format('mean', 'min', 'max', 'std', 'shape', 'param_name') + '\n'
        for name, param in network.state_dict().items():
            if 'num_batches_tracked' not in name:
                v = param.data.clone().float()
                msg += ' | {:>6.3f} | {:>6.3f} | {:>6.3f} | {:>6.3f} | {} || {:s}'.format(v.mean(), v.min(), v.max(), v.std(), v.shape, name) + '\n'
        return msg
| """ | |
| # ---------------------------------------- | |
| # Save prameters | |
| # Load prameters | |
| # ---------------------------------------- | |
| """ | |
    # ----------------------------------------
    # save the state_dict of the network
    # ----------------------------------------
    def save_network(self, save_dir, network, network_label, iter_label):
        save_filename = '{}_{}.pth'.format(iter_label, network_label)
        save_path = os.path.join(save_dir, save_filename)
        network = self.get_bare_model(network)
        state_dict = network.state_dict()
        for key, param in state_dict.items():
            state_dict[key] = param.cpu()  # move to CPU so the checkpoint loads on any device
        torch.save(state_dict, save_path)
    # ----------------------------------------
    # load the state_dict of the network
    # ----------------------------------------
    def load_network(self, load_path, network, strict=True, param_key='params'):
        network = self.get_bare_model(network)
        if strict:
            state_dict = torch.load(load_path)
            if param_key in state_dict.keys():
                state_dict = state_dict[param_key]
            network.load_state_dict(state_dict, strict=strict)
        else:
            state_dict_old = torch.load(load_path)
            if param_key in state_dict_old.keys():
                state_dict_old = state_dict_old[param_key]
            state_dict = network.state_dict()
            # copy parameters by position rather than by name, so checkpoints
            # whose keys were renamed can still be loaded
            for ((key_old, param_old), (key, param)) in zip(state_dict_old.items(), state_dict.items()):
                state_dict[key] = param_old
            network.load_state_dict(state_dict, strict=True)
            del state_dict_old, state_dict
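
    # ------------------------------------------------------------------
    # Usage sketch (illustrative): checkpoints written by save_network()
    # are raw state_dicts, while some pretrained releases nest the weights
    # under a key such as 'params'; both cases are handled above. The path
    # below is hypothetical.
    #
    #   self.load_network('model_zoo/pretrained_G.pth', self.netG,
    #                     strict=True, param_key='params')
    # ------------------------------------------------------------------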
    # ----------------------------------------
    # save the state_dict of the optimizer
    # ----------------------------------------
    def save_optimizer(self, save_dir, optimizer, optimizer_label, iter_label):
        save_filename = '{}_{}.pth'.format(iter_label, optimizer_label)
        save_path = os.path.join(save_dir, save_filename)
        torch.save(optimizer.state_dict(), save_path)

    # ----------------------------------------
    # load the state_dict of the optimizer
    # ----------------------------------------
    def load_optimizer(self, load_path, optimizer):
        # map the saved optimizer state onto the current CUDA device
        optimizer.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage.cuda(torch.cuda.current_device())))
    # ----------------------------------------
    # update E's parameters as an exponential
    # moving average of G's:
    # netE = decay * netE + (1 - decay) * netG
    # ----------------------------------------
    def update_E(self, decay=0.999):
        netG = self.get_bare_model(self.netG)
        netG_params = dict(netG.named_parameters())
        netE_params = dict(self.netE.named_parameters())
        for k in netG_params.keys():
            netE_params[k].data.mul_(decay).add_(netG_params[k].data, alpha=1-decay)
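
    # ------------------------------------------------------------------
    # Sketch (illustrative): netE is assumed to start as a frozen copy of
    # netG before update_E() is called each iteration, e.g.
    #
    #   import copy
    #   self.netE = copy.deepcopy(self.get_bare_model(self.netG)).eval()
    #   for p in self.netE.parameters():
    #       p.requires_grad = False
    #
    # This mirrors the usual EMA setup; the exact initialization lives in
    # the subclasses, not in this base class.
    # ------------------------------------------------------------------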
| """ | |
| # ---------------------------------------- | |
| # Merge Batch Normalization for training | |
| # Merge Batch Normalization for testing | |
| # ---------------------------------------- | |
| """ | |
| # ---------------------------------------- | |
| # merge bn during training | |
| # ---------------------------------------- | |
| def merge_bnorm_train(self): | |
| merge_bn(self.netG) | |
| tidy_sequential(self.netG) | |
| self.define_optimizer() | |
| self.define_scheduler() | |
| # ---------------------------------------- | |
| # merge bn before testing | |
| # ---------------------------------------- | |
| def merge_bnorm_test(self): | |
| merge_bn(self.netG) | |
| tidy_sequential(self.netG) | |
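

# ----------------------------------------
# minimal usage sketch (illustrative only)
# ----------------------------------------
# A hedged sketch of how a subclass might wire ModelBase together, using
# only the opt keys this file actually reads ('gpu_ids', 'is_train',
# 'dist', 'path'/'models'). ToyModel and its attributes are hypothetical.
if __name__ == '__main__':
    opt = {'gpu_ids': None,                # None -> run on CPU
           'is_train': True,
           'dist': False,                  # no DistributedDataParallel
           'path': {'models': './checkpoints'}}

    class ToyModel(ModelBase):
        def __init__(self, opt):
            super(ToyModel, self).__init__(opt)
            self.netG = self.model_to_device(nn.Linear(8, 8))

        def define_optimizer(self):
            self.G_optimizer = torch.optim.Adam(
                self.get_bare_model(self.netG).parameters(), lr=1e-4)

    model = ToyModel(opt)
    model.define_optimizer()
    print(model.describe_network(model.netG))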