{
  "device": "$torch.device('cuda:' + os.environ['LOCAL_RANK'])",
  "use_tensorboard": "$dist.get_rank() == 0",
  "controlnet": {
    "_target_": "torch.nn.parallel.DistributedDataParallel",
    "module": "$@controlnet_def.to(@device)",
    "find_unused_parameters": true,
    "device_ids": [
      "@device"
    ]
  },
  "load_controlnet": "$@controlnet.module.load_state_dict(@checkpoint_controlnet['controlnet_state_dict'], strict=True)",
  "train#sampler": {
    "_target_": "DistributedSampler",
    "dataset": "@train#dataset",
    "even_divisible": true,
    "shuffle": true
  },
  "train#dataloader#sampler": "@train#sampler",
  "train#dataloader#shuffle": false,
  "train#trainer#train_handlers": "$@train#handlers[: -1 if dist.get_rank() > 0 else None]",
  "initialize": [
    "$import torch.distributed as dist",
    "$dist.is_initialized() or dist.init_process_group(backend='nccl')",
    "$torch.cuda.set_device(@device)",
    "$monai.utils.set_determinism(seed=123)"
  ],
  "run": [
    "$@train#trainer.run()"
  ],
  "finalize": [
    "$dist.is_initialized() and dist.destroy_process_group()"
  ]
}