import atexit
import contextlib
import logging
import os
import signal
import time
from collections.abc import Iterable, Iterator, Sized
from contextlib import nullcontext
from datetime import timedelta
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional, Union

import torch
import torch.nn.functional as F
from lightning_utilities.core.imports import package_available
from torch import Tensor
from torch.utils.data import Dataset, DistributedSampler, Sampler
from typing_extensions import Self, TypeGuard, override

from lightning.fabric.utilities.cloud_io import _is_local_file_protocol
from lightning.fabric.utilities.data import _num_cpus_available
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.fabric.utilities.rank_zero import rank_zero_info
from lightning.fabric.utilities.types import _PATH, ReduceOp

if torch.distributed.is_available():
    from torch.distributed import group
else:

    class group:  # type: ignore
        WORLD = None


if TYPE_CHECKING:
    from torch.distributed._tensor import DTensor

    from lightning.fabric.plugins import ClusterEnvironment
    from lightning.fabric.strategies import Strategy


log = logging.getLogger(__name__)


def is_shared_filesystem(strategy: "Strategy", path: Optional[_PATH] = None, timeout: int = 3) -> bool:
    """Checks whether the filesystem under the given path is shared across all processes.

    This function should only be used in a context where distributed is initialized.

    Args:
        strategy: The strategy being used, either from Fabric (``fabric.strategy``) or from Trainer
            (``trainer.strategy``).
        path: The path to check. Defaults to the current working directory. The user must have permissions to write
            to this path or the parent folder, and the filesystem must be writable.
        timeout: If any of the processes can't list the file created by rank 0 within this many seconds, the
            filesystem is determined to be not shared.

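    Example::

        # A minimal sketch, assuming a launched Fabric run with a distributed strategy and that
        # every rank can write under ``/mnt/checkpoints`` (a hypothetical path):
        from lightning.fabric import Fabric

        fabric = Fabric(accelerator="cpu", strategy="ddp", devices=2)
        fabric.launch()
        shared = is_shared_filesystem(fabric.strategy, path="/mnt/checkpoints")
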
    """
    # Fast path: Any non-local filesystem is considered shared (e.g., S3)
    if path is not None and not _is_local_file_protocol(path):
        return True

    path = Path(Path.cwd() if path is None else path).resolve()

    # Fast path: Only distributed strategies can detect shared filesystems
    if not hasattr(strategy, "world_size") or strategy.world_size == 1:
        return True

    # Fast path: If the path is not the same on all ranks we know it's not a shared filesystem
    rank_zero_path = strategy.broadcast(path)
    if not strategy.reduce_boolean_decision(rank_zero_path == path, all=True):
        return False

    if not strategy.reduce_boolean_decision(path.exists(), all=True):
        raise FileNotFoundError(
            f"Unable to determine if the path belongs to a shared filesystem. The path does not exist: {path}"
        )

    path = path.parent if path.is_file() else path
    check_file = path / ".lightning_shared_fs_check"
    check_file.unlink(missing_ok=True)

    strategy.barrier()
    if strategy.is_global_zero:
        # Rank 0 creates the file
        check_file.touch()
        found = True
    else:
        # All other ranks will wait until they find the file or timeout
        start = time.perf_counter()
        found = False
        while not found and (time.perf_counter() - start) < timeout:
            found = check_file.exists()
    strategy.barrier()

    all_found = strategy.reduce_boolean_decision(found, all=True)

    with contextlib.suppress(OSError):  # handle race condition on deletion
        check_file.unlink()

    return all_found


def _gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> list[Tensor]:
    """Function to gather all tensors from several DDP processes onto a list that is broadcasted to all processes.

    Works on tensors that have the same number of dimensions, but where each dimension may differ. In this case
    tensors are padded, gathered and then trimmed to secure equal workload for all processes.

    Args:
        result: The value to sync
        group: The process group to gather results from. Defaults to all processes (world)

    Return:
        gathered_result: List with size equal to the process group where
            gathered_result[i] corresponds to result tensor from process i

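    Example::

        # Sketch only: requires an initialized default process group (here, 2 ranks).
        # Each rank contributes a tensor whose first dimension differs by rank.
        rank = torch.distributed.get_rank()
        local = torch.ones(rank + 1, 3)
        gathered = _gather_all_tensors(local)
        # With 2 ranks: gathered[0].shape == (1, 3) and gathered[1].shape == (2, 3)
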
    """
    if group is None:
        group = torch.distributed.group.WORLD

    # Convert tensors to contiguous format
    result = result.contiguous()

    world_size = torch.distributed.get_world_size(group)
    torch.distributed.barrier(group=group)

    # If the tensor is scalar, things are easy
    if result.ndim == 0:
        return _simple_gather_all_tensors(result, group, world_size)

    # 1. Gather sizes of all tensors
    local_size = torch.tensor(result.shape, device=result.device)
    local_sizes = [torch.zeros_like(local_size) for _ in range(world_size)]
    torch.distributed.all_gather(local_sizes, local_size, group=group)
    max_size = torch.stack(local_sizes).max(dim=0).values
    all_sizes_equal = all(all(ls == max_size) for ls in local_sizes)

    # 2. If shapes are all the same, then do a simple gather:
    if all_sizes_equal:
        return _simple_gather_all_tensors(result, group, world_size)

    # 3. If not, we need to pad each local tensor to maximum size, gather and then truncate
    pad_dims = []
    pad_by = (max_size - local_size).detach().cpu()
    for val in reversed(pad_by):
        pad_dims.append(0)
        pad_dims.append(val.item())
    result_padded = F.pad(result, pad_dims)
    gathered_result = [torch.zeros_like(result_padded) for _ in range(world_size)]
    torch.distributed.all_gather(gathered_result, result_padded, group)
    for idx, item_size in enumerate(local_sizes):
        slice_param = [slice(dim_size) for dim_size in item_size]
        gathered_result[idx] = gathered_result[idx][slice_param]
    return gathered_result


def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> list[Tensor]:
    gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
    torch.distributed.all_gather(gathered_result, result, group)
    return gathered_result


def _sync_ddp_if_available(
    result: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
) -> Tensor:
    """Function to reduce a tensor across worker processes during distributed training.

    Args:
        result: The value to sync and reduce (typically tensor or number)
        group: The process group to gather results from. Defaults to all processes (world)
        reduce_op: The reduction operation. Defaults to sum.
            Can also be the string 'avg' or 'mean' to compute the mean during the reduction.

    Return:
        reduced value

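    Example::

        # Sketch: if torch.distributed is initialized, the value is reduced across all ranks;
        # otherwise the tensor is returned unchanged.
        loss = torch.tensor(2.0)
        mean_loss = _sync_ddp_if_available(loss, reduce_op="mean")
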
    """
    if _distributed_is_initialized():
        return _sync_ddp(result, group=group, reduce_op=reduce_op)
    return result


def _sync_ddp(result: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None) -> Tensor:
    """Reduces a tensor across several distributed processes.

    This operation is performed in-place, meaning the result will be placed back into the input tensor on all processes.

    Args:
        result: The value to sync and reduce (typically tensor or number)
        group: The process group to gather results from. Defaults to all processes (world)
        reduce_op: The reduction operation. Defaults to sum.
            Can also be the string 'avg' or 'mean' to compute the mean during the reduction.

    Return:
        The reduced value.

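    Example::

        # Sketch (requires an initialized process group): the reduction is in-place, so the
        # returned tensor is the same object as the input.
        result = torch.tensor([1.0, 2.0])
        reduced = _sync_ddp(result, reduce_op="sum")
        assert reduced is result
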
    """
    divide_by_world_size = False
    group = torch.distributed.group.WORLD if group is None else group

    op: Optional[ReduceOp]
    if isinstance(reduce_op, str):
        reduce_op = "avg" if reduce_op == "mean" else reduce_op
        if reduce_op.lower() == "avg" and torch.distributed.get_backend(group) == "gloo":
            # The GLOO backend does not support the `ReduceOp.AVG` operation
            op = ReduceOp.SUM  # type: ignore[assignment]
            divide_by_world_size = True
        else:
            op = getattr(ReduceOp, reduce_op.upper())
    else:
        op = reduce_op

    # HPU doesn't support Long types; forcefully cast the tensor to float
    # TODO: move this to the `lightning_habana` package
    if (
        package_available("habana_frameworks")
        and os.environ.get("HCCL_DISTRIBUTED_BACKEND") == "1"
        and result.type()
        in (
            "torch.LongTensor",
            "torch.hpu.LongTensor",
        )
    ):
        rank_zero_info("Long tensor unsupported on HPU, casting to float")
        result = result.float()

    # Sync all processes before reduction
    torch.distributed.barrier(group=group)
    torch.distributed.all_reduce(result, op=op, group=group, async_op=False)
    world_size = torch.distributed.get_world_size(group)

    if not divide_by_world_size:
        return result
    # `torch.distributed.all_reduce` is in-place, so we should do the division in-place to leave the modified tensors
    # with the expected value
    if not torch.is_floating_point(result):
        return result.copy_(result / world_size)
    return result.div_(world_size)


def _all_gather_ddp_if_available(
    tensor: Tensor, group: Optional["torch.distributed.ProcessGroup"] = None, sync_grads: bool = False
) -> Tensor:
    """Function to gather a tensor from several distributed processes.

    Args:
        tensor: Tensor of shape (batch, ...)
        group: The process group to gather results from. Defaults to all processes (world)
        sync_grads: Flag that allows users to synchronize gradients for all_gather op

    Return:
        A tensor of shape (world_size, batch, ...)

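    Example::

        # Sketch, assuming an initialized process group. With ``sync_grads=True`` the gather
        # is differentiable, so gradients flow back to the local ``tensor``.
        local = torch.randn(4, 8, requires_grad=True)
        gathered = _all_gather_ddp_if_available(local, sync_grads=True)
        # gathered.shape == (world_size, 4, 8)
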
    """
    if not _distributed_is_initialized():
        return tensor

    from torch.distributed.nn.functional import all_gather

    tensor = tensor.contiguous()  # https://github.com/pytorch/pytorch/issues/73515
    with nullcontext() if sync_grads else torch.no_grad():
        gathered_tensors = all_gather(tensor, group)
    return torch.stack(gathered_tensors)


def _init_dist_connection(
    cluster_environment: "ClusterEnvironment",
    torch_distributed_backend: str,
    global_rank: Optional[int] = None,
    world_size: Optional[int] = None,
    **kwargs: Any,
) -> None:
    """Utility function to initialize distributed connection by setting env variables and initializing the distributed
    process group.

    Args:
        cluster_environment: ``ClusterEnvironment`` instance
        torch_distributed_backend: Backend to use (e.g., `nccl` or `gloo`)
        global_rank: Rank of the current process
        world_size: Number of processes in the group
        kwargs: Kwargs for ``init_process_group``

    Raises:
        RuntimeError:
            If ``torch.distributed`` is not available

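    Example::

        # Sketch, assuming the usual environment variables (MASTER_ADDR, MASTER_PORT, RANK,
        # WORLD_SIZE) can be resolved by ``LightningEnvironment``:
        from lightning.fabric.plugins.environments import LightningEnvironment

        _init_dist_connection(LightningEnvironment(), "gloo")
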
    """
    if not torch.distributed.is_available():
        raise RuntimeError("torch.distributed is not available. Cannot initialize distributed process group")
    if torch.distributed.is_initialized():
        log.debug("torch.distributed is already initialized. Exiting early")
        return
    global_rank = global_rank if global_rank is not None else cluster_environment.global_rank()
    world_size = world_size if world_size is not None else cluster_environment.world_size()
    os.environ["MASTER_ADDR"] = cluster_environment.main_address
    os.environ["MASTER_PORT"] = str(cluster_environment.main_port)
    log.info(f"Initializing distributed: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
    torch.distributed.init_process_group(torch_distributed_backend, rank=global_rank, world_size=world_size, **kwargs)

    if torch_distributed_backend == "nccl":
        # PyTorch >= 2.4 warns if the NCCL process group is not explicitly destroyed, so do it at program exit
        atexit.register(_destroy_dist_connection)

    # On rank=0 let everyone know training is starting
    rank_zero_info(
        f"{'-' * 100}\n"
        f"distributed_backend={torch_distributed_backend}\n"
        f"All distributed processes registered. Starting with {world_size} processes\n"
        f"{'-' * 100}\n"
    )


def _destroy_dist_connection() -> None:
    # Don't allow Ctrl+C to interrupt this handler
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    if _distributed_is_initialized():
        torch.distributed.destroy_process_group()
    signal.signal(signal.SIGINT, signal.SIG_DFL)


def _get_default_process_group_backend_for_device(device: torch.device) -> str:
    return "nccl" if device.type == "cuda" else "gloo"


class _DatasetSamplerWrapper(Dataset):
    """Dataset to create indexes from `Sampler` or `Iterable`"""

    def __init__(self, sampler: Union[Sampler, Iterable]) -> None:
        if not isinstance(sampler, Sized):
            raise TypeError(
                "You seem to have configured a sampler in your DataLoader which"
                " does not provide `__len__` method. The sampler was about to be"
                " replaced by `DistributedSamplerWrapper` since `use_distributed_sampler`"
                " is True and you are using distributed training. Either provide `__len__`"
                " method in your sampler, remove it from DataLoader or set `use_distributed_sampler=False`"
                " if you want to handle distributed sampling yourself."
            )
        if len(sampler) == float("inf"):
            raise TypeError(
                "You seem to have configured a sampler in your DataLoader which"
                " does not provide finite `__len__` method. The sampler was about to be"
                " replaced by `DistributedSamplerWrapper` since `use_distributed_sampler`"
                " is True and you are using distributed training. Either provide `__len__`"
                " method in your sampler which returns a finite number, remove it from DataLoader"
                " or set `use_distributed_sampler=False` if you want to handle distributed sampling yourself."
            )
        self._sampler = sampler
        # defer materializing an iterator until it is necessary
        self._sampler_list: Optional[list[Any]] = None

    @override
    def __getitem__(self, index: int) -> Any:
        if self._sampler_list is None:
            self._sampler_list = list(self._sampler)
        return self._sampler_list[index]

    def __len__(self) -> int:
        return len(self._sampler)

    def reset(self) -> None:
        """Reset the sampler list in order to get new sampling."""
        self._sampler_list = list(self._sampler)


class DistributedSamplerWrapper(DistributedSampler):
    """Wrapper over ``Sampler`` for distributed training.

    Allows you to use any sampler in distributed mode. Lightning applies it automatically when distributed training
    is active and sampler replacement (`use_distributed_sampler`) is enabled.

    Note:
        The purpose of this wrapper is to take care of sharding the sampler indices. It is up to the underlying
        sampler to handle randomness and shuffling. The ``shuffle`` and ``seed`` arguments on this wrapper won't
        have any effect.

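    Example::

        # Sketch: shard a custom ``WeightedRandomSampler`` across two ranks. ``num_replicas`` and
        # ``rank`` are passed explicitly here; in practice they come from the strategy.
        from torch.utils.data import WeightedRandomSampler

        sampler = WeightedRandomSampler(weights=[0.1, 0.9, 0.5, 0.5], num_samples=4)
        dist_sampler = DistributedSamplerWrapper(sampler, num_replicas=2, rank=0, shuffle=False)
        indices = list(dist_sampler)  # this rank's share of the wrapped sampler's draws
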
    """

    def __init__(self, sampler: Union[Sampler, Iterable], *args: Any, **kwargs: Any) -> None:
        super().__init__(_DatasetSamplerWrapper(sampler), *args, **kwargs)

    @override
    def __iter__(self) -> Iterator:
        self.dataset.reset()
        return (self.dataset[index] for index in super().__iter__())


def _suggested_max_num_threads(num_processes: int = 1) -> int:
    if num_processes < 1:
        raise ValueError(f"`num_processes` should be >= 1, got {num_processes}.")
    return max(1, _num_cpus_available() // num_processes)


def _set_num_threads_if_needed(num_processes: int = 1) -> None:
    if "OMP_NUM_THREADS" not in os.environ:
        num_threads = _suggested_max_num_threads(num_processes)
        torch.set_num_threads(num_threads)
        os.environ["OMP_NUM_THREADS"] = str(num_threads)


def _distributed_is_initialized() -> bool:
    # `is_initialized` is only defined conditionally
    # https://github.com/pytorch/pytorch/blob/v2.1.0/torch/distributed/__init__.py#L25
    # this can happen with macOS builds from source (the default there) or any build from source that sets
    # `USE_DISTRIBUTED=0`
    return torch.distributed.is_available() and torch.distributed.is_initialized()


class _InfiniteBarrier:
    """A barrier with an infinite timeout.

    Creates a new process group with the GLOO backend using a very high timeout that makes the barrier effectively
    wait forever. This is useful when a subset of ranks executes a long-running operation while the remaining ranks
    wait at the barrier, without being subject to the regular collective timeout.

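    Example::

        # Sketch: rank 0 runs a long operation (``convert_checkpoint`` is a hypothetical
        # placeholder) while the other ranks wait at the barrier when the context exits.
        with _InfiniteBarrier():
            if torch.distributed.get_rank() == 0:
                convert_checkpoint()  # hypothetical long-running call
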
    """

    def __init__(self) -> None:
        self.group = None
        self.barrier = lambda: None

    def __call__(self) -> None:
        self.barrier()

    def __enter__(self) -> Self:
        if _distributed_is_initialized():
            # Create a barrier with an 'infinite' timeout (only reliably possible over the GLOO backend)
            self.group = torch.distributed.new_group(backend="gloo", timeout=timedelta(days=10000))
            self.barrier = self.group.monitored_barrier
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        self.barrier()
        if self.group is not None:
            torch.distributed.destroy_process_group(self.group)


def _is_dtensor(tensor: Tensor) -> TypeGuard["DTensor"]:
    if _TORCH_GREATER_EQUAL_2_4:
        from torch.distributed._tensor import DTensor

        return isinstance(tensor, DTensor)
    return False