# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import AbstractContextManager, nullcontext
from datetime import timedelta
from typing import Any, Literal, Optional, Union
import torch
import torch.distributed
from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only
from torch import Tensor
from torch.nn import Module
from torch.nn.parallel.distributed import DistributedDataParallel
from typing_extensions import override
from lightning_fabric.accelerators.accelerator import Accelerator
from lightning_fabric.plugins.collectives.torch_collective import default_pg_timeout
from lightning_fabric.plugins.environments.cluster_environment import ClusterEnvironment
from lightning_fabric.plugins.io.checkpoint_io import CheckpointIO
from lightning_fabric.plugins.precision import Precision
from lightning_fabric.strategies.launchers.multiprocessing import _MultiProcessingLauncher
from lightning_fabric.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
from lightning_fabric.strategies.parallel import ParallelStrategy
from lightning_fabric.strategies.registry import _StrategyRegistry
from lightning_fabric.strategies.strategy import TBroadcast, _BackwardSyncControl
from lightning_fabric.utilities.distributed import (
ReduceOp,
_distributed_is_initialized,
_get_default_process_group_backend_for_device,
_init_dist_connection,
_sync_ddp_if_available,
)
from lightning_fabric.utilities.distributed import group as _group
from lightning_fabric.utilities.rank_zero import rank_zero_only
_DDP_FORK_ALIASES = (
"ddp_fork",
"ddp_notebook",
)
class DDPStrategy(ParallelStrategy):
"""Strategy for multi-process single-device training on one or multiple nodes."""
def __init__(
self,
accelerator: Optional[Accelerator] = None,
parallel_devices: Optional[list[torch.device]] = None,
cluster_environment: Optional[ClusterEnvironment] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision: Optional[Precision] = None,
process_group_backend: Optional[str] = None,
timeout: Optional[timedelta] = default_pg_timeout,
start_method: Literal["popen", "spawn", "fork", "forkserver"] = "popen",
**kwargs: Any,
) -> None:
super().__init__(
accelerator=accelerator,
parallel_devices=parallel_devices,
cluster_environment=cluster_environment,
checkpoint_io=checkpoint_io,
precision=precision,
)
self._num_nodes = 1
self._process_group_backend: Optional[str] = process_group_backend
self._timeout: Optional[timedelta] = timeout
self._start_method = start_method
self._backward_sync_control = _DDPBackwardSyncControl()
self._ddp_kwargs = kwargs
@property
@override
def root_device(self) -> torch.device:
assert self.parallel_devices is not None
return self.parallel_devices[self.local_rank]
@property
def num_nodes(self) -> int:
return self._num_nodes
@num_nodes.setter
def num_nodes(self, num_nodes: int) -> None:
        # Note: the world ranks depend on num_nodes, so they need to be recomputed whenever it is reset
self._num_nodes = num_nodes
@property
def num_processes(self) -> int:
return len(self.parallel_devices) if self.parallel_devices is not None else 0
@property
@override
def distributed_sampler_kwargs(self) -> dict[str, Any]:
return {"num_replicas": (self.num_nodes * self.num_processes), "rank": self.global_rank}
@property
def process_group_backend(self) -> Optional[str]:
return self._process_group_backend
@override
def _configure_launcher(self) -> None:
assert self.cluster_environment is not None
if self._start_method == "popen":
self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes)
else:
self._launcher = _MultiProcessingLauncher(self, start_method=self._start_method)
@override
def setup_environment(self) -> None:
super().setup_environment()
self._setup_distributed()
@override
def setup_module(self, module: Module) -> DistributedDataParallel:
"""Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
device_ids = self._determine_ddp_device_ids()
# https://pytorch.org/docs/stable/notes/cuda.html#id5
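        # DDP broadcasts the module's parameters and buffers from rank 0 to all other ranks during
        # construction; wrapping the construction in a side CUDA stream follows the note linked above.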
ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
with ctx:
return DistributedDataParallel(module=module, device_ids=device_ids, **self._ddp_kwargs)
@override
def module_to_device(self, module: Module) -> None:
module.to(self.root_device)
@override
def all_reduce(
self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"
) -> Tensor:
"""Reduces a tensor from several distributed processes to one aggregated tensor.
Args:
tensor: the tensor to sync and reduce
group: the process group to gather results from. Defaults to all processes (world)
reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
Can also be a string 'sum' to calculate the sum during reduction.
Return:
            The reduced value, or the passed-in object unchanged if it was not a tensor.
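
        Example::

            # a usage sketch; assumes ``strategy`` is a set-up ``DDPStrategy`` with an initialized process group
            mean_loss = strategy.all_reduce(loss)                       # 'mean' is the default
            total = strategy.all_reduce(sample_count, reduce_op="sum")
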
"""
if isinstance(tensor, Tensor):
return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
return tensor
@override
def barrier(self, *args: Any, **kwargs: Any) -> None:
if not _distributed_is_initialized():
return
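        # With the NCCL backend, pass the local device index so the barrier is performed on this
        # process' GPU (``device_ids`` is only supported by NCCL).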
if torch.distributed.get_backend() == "nccl":
torch.distributed.barrier(device_ids=self._determine_ddp_device_ids())
else:
torch.distributed.barrier()
@override
def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
if not _distributed_is_initialized():
return obj
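        # ``broadcast_object_list`` operates in place on a list of objects, so wrap and unwrap the object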
obj = [obj]
torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD)
return obj[0]
@override
def get_module_state_dict(self, module: Module) -> dict[str, Union[Any, Tensor]]:
if isinstance(module, DistributedDataParallel):
module = module.module
return super().get_module_state_dict(module)
@override
def load_module_state_dict(
self, module: Module, state_dict: dict[str, Union[Any, Tensor]], strict: bool = True
) -> None:
if isinstance(module, DistributedDataParallel):
module = module.module
super().load_module_state_dict(module=module, state_dict=state_dict, strict=strict)
@classmethod
@override
def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
entries = (
("ddp", "popen"),
("ddp_spawn", "spawn"),
("ddp_fork", "fork"),
("ddp_notebook", "fork"),
)
for name, start_method in entries:
strategy_registry.register(
name,
cls,
description=f"DDP strategy with `start_method={start_method!r}`",
start_method=start_method,
)
strategy_registry.register(
"ddp_find_unused_parameters_true",
cls,
description="Alias for `find_unused_parameters_true` and `start_method='popen'`",
find_unused_parameters=True,
start_method="popen",
)
def _setup_distributed(self) -> None:
self._set_world_ranks()
self._process_group_backend = self._get_process_group_backend()
assert self.cluster_environment is not None
_init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
def _get_process_group_backend(self) -> str:
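        # An explicitly configured backend takes precedence; otherwise a default is derived from the
        # device type (typically "nccl" for CUDA devices and "gloo" for CPU).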
return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
def _set_world_ranks(self) -> None:
if self.cluster_environment is not None:
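            # The global rank is derived from the node and local rank, e.g. with 4 processes per node,
            # the process with node_rank=1 and local_rank=2 gets global rank 1 * 4 + 2 = 6.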
self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
# `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail
# additionally, for some implementations, the setter is a no-op, so it's safer to access the getter
rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank
def _determine_ddp_device_ids(self) -> Optional[list[int]]:
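        # DDP expects ``device_ids=None`` for CPU modules and a single-element list containing the
        # local device index for single-GPU modules.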
return None if self.root_device.type == "cpu" else [self.root_device.index]
class _DDPBackwardSyncControl(_BackwardSyncControl):
@override
def no_backward_sync(self, module: Module, enabled: bool) -> AbstractContextManager:
"""Blocks gradient synchronization inside the :class:`~torch.nn.parallel.distributed.DistributedDataParallel`
wrapper."""
if not enabled:
return nullcontext()
if not isinstance(module, DistributedDataParallel):
raise TypeError(
"Blocking backward sync is only possible if the module passed to"
f" `{self.__class__.__name__}.no_backward_sync` is wrapped in `DistributedDataParallel`."
f" Got: {module.__class__.__name__}."
)
return module.no_sync()
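
# A minimal gradient-accumulation sketch (an illustration; assumes the public ``Fabric.no_backward_sync``
# API, which delegates to the control above): skip the gradient all-reduce on every step except the last
# one of each accumulation window.
#
#     for i, batch in enumerate(dataloader):
#         is_accumulating = (i + 1) % accumulation_steps != 0
#         with fabric.no_backward_sync(model, enabled=is_accumulating):
#             loss = model(batch).sum()
#             fabric.backward(loss)
#         if not is_accumulating:
#             optimizer.step()
#             optimizer.zero_grad()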