File size: 3,532 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from abc import ABC, abstractmethod
from typing import Any, Optional

from torch import Tensor
from typing_extensions import Self

from lightning.fabric.utilities.types import CollectibleGroup


class Collective(ABC):
    """Interface for collective operations.

    Supports communications between multiple processes and multiple nodes. A collective owns a group.

    .. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature which is still in development.

    """

    def __init__(self) -> None:
        self._group: Optional[CollectibleGroup] = None

    @property
    @abstractmethod
    def rank(self) -> int:
        """Rank."""

    @property
    @abstractmethod
    def world_size(self) -> int:
        """World size."""

    @property
    def group(self) -> CollectibleGroup:
        if self._group is None:
            raise RuntimeError(
                f"`{type(self).__name__}` does not own a group. HINT: try `collective.create_group().group`"
            )
        return self._group

    @abstractmethod
    def broadcast(self, tensor: Tensor, src: int) -> Tensor: ...

    @abstractmethod
    def all_reduce(self, tensor: Tensor, op: str) -> Tensor: ...

    @abstractmethod
    def reduce(self, tensor: Tensor, dst: int, op: str) -> Tensor: ...

    @abstractmethod
    def all_gather(self, tensor_list: list[Tensor], tensor: Tensor) -> list[Tensor]: ...

    @abstractmethod
    def gather(self, tensor: Tensor, gather_list: list[Tensor], dst: int = 0) -> list[Tensor]: ...

    @abstractmethod
    def scatter(self, tensor: Tensor, scatter_list: list[Tensor], src: int = 0) -> Tensor: ...

    @abstractmethod
    def reduce_scatter(self, output: Tensor, input_list: list[Tensor], op: str) -> Tensor: ...

    @abstractmethod
    def all_to_all(self, output_tensor_list: list[Tensor], input_tensor_list: list[Tensor]) -> list[Tensor]: ...

    @abstractmethod
    def send(self, tensor: Tensor, dst: int, tag: int = 0) -> None: ...

    @abstractmethod
    def recv(self, tensor: Tensor, src: Optional[int] = None, tag: int = 0) -> Tensor: ...

    @abstractmethod
    def barrier(self, device_ids: Optional[list[int]] = None) -> None: ...

    @classmethod
    @abstractmethod
    def is_available(cls) -> bool: ...

    @classmethod
    @abstractmethod
    def is_initialized(cls) -> bool: ...

    @classmethod
    @abstractmethod
    def init_group(cls, **kwargs: Any) -> None: ...

    @classmethod
    @abstractmethod
    def new_group(cls, **kwargs: Any) -> CollectibleGroup: ...

    @classmethod
    @abstractmethod
    def destroy_group(cls, group: CollectibleGroup) -> None: ...

    @classmethod
    @abstractmethod
    def _convert_to_native_op(cls, op: str) -> Any: ...

    def setup(self, **kwargs: Any) -> Self:
        if not self.is_initialized():
            self.init_group(**kwargs)
        return self

    def create_group(self, **kwargs: Any) -> Self:
        """Create a group.

        This assumes that :meth:`~lightning.fabric.plugins.collectives.Collective.init_group` has been
        called already by the user.

        """
        if self._group is not None:
            raise RuntimeError(f"`{type(self).__name__}` already owns a group.")
        self._group = self.new_group(**kwargs)
        return self

    def teardown(self) -> Self:
        if self._group is None:
            raise RuntimeError(f"`{type(self).__name__}` does not own a group to destroy.")
        self.destroy_group(self._group)
        self._group = None
        return self