File size: 7,729 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import socket

from typing_extensions import override

from lightning_fabric.plugins.environments.cluster_environment import ClusterEnvironment
from lightning_fabric.utilities.cloud_io import get_filesystem

log = logging.getLogger(__name__)


class LSFEnvironment(ClusterEnvironment):
    """An environment for running on clusters managed by the LSF resource manager.

    It is expected that any execution using this ClusterEnvironment was executed
    using the Job Step Manager i.e. ``jsrun``.

    This plugin expects the following environment variables:

    ``LSB_JOBID``
      The LSF assigned job ID

    ``LSB_DJOB_RANKFILE``
      The OpenMPI compatible rank file for the LSF job

    ``JSM_NAMESPACE_LOCAL_RANK``
      The node local rank for the task. This environment variable is set by ``jsrun``

    ``JSM_NAMESPACE_SIZE``
      The world size for the task. This environment variable is set by ``jsrun``

    ``JSM_NAMESPACE_RANK``
      The global rank for the task. This environment variable is set by ``jsrun``

    """

    def __init__(self) -> None:
        """Resolve the main address, main port and node rank once, then export ``MASTER_ADDR`` and ``MASTER_PORT``."""
        super().__init__()
        self._main_address = self._get_main_address()
        # resolved before the export below, so an already-set ``MASTER_PORT`` is picked up by ``_get_main_port``
        self._main_port = self._get_main_port()
        self._node_rank = self._get_node_rank()
        # writes the values cached above into ``os.environ``
        self._set_init_progress_group_env_vars()

    def _set_init_progress_group_env_vars(self) -> None:
        """Export the values needed to initialize the torch distributed process group.

        Writes ``str(self._main_address)`` to ``MASTER_ADDR`` and ``str(self._main_port)`` to ``MASTER_PORT`` in
        ``os.environ`` and emits one debug log line after each assignment.

        """
        environ = os.environ

        environ["MASTER_ADDR"] = str(self._main_address)
        log.debug(f"MASTER_ADDR: {environ['MASTER_ADDR']}")

        environ["MASTER_PORT"] = str(self._main_port)
        log.debug(f"MASTER_PORT: {environ['MASTER_PORT']}")

    @property
    @override
    def creates_processes_externally(self) -> bool:
        """Always ``True``: LSF creates subprocesses, i.e., PyTorch Lightning does not need to spawn them."""
        return True

    @property
    @override
    def main_address(self) -> str:
        """The main address, as cached by ``__init__``.

        It is read from the OpenMPI host rank file named by the environment variable ``LSB_DJOB_RANKFILE``.

        """
        return self._main_address

    @property
    @override
    def main_port(self) -> int:
        """The main port, as cached by ``__init__``.

        It is the integer value of ``MASTER_PORT`` if that variable was already set when the environment was created,
        otherwise a port calculated from the LSF job ID (``LSB_JOBID``).

        """
        return self._main_port

    @staticmethod
    @override
    def detect() -> bool:
        """Returns ``True`` if the current process was launched using the ``jsrun`` command."""
        required_env_vars = {"LSB_JOBID", "LSB_DJOB_RANKFILE", "JSM_NAMESPACE_LOCAL_RANK", "JSM_NAMESPACE_SIZE"}
        return required_env_vars.issubset(os.environ.keys())

    @override
    def world_size(self) -> int:
        """The world size is read from the environment variable ``JSM_NAMESPACE_SIZE``."""
        world_size = os.environ.get("JSM_NAMESPACE_SIZE")
        if world_size is None:
            raise ValueError(
                "Cannot determine world size. Environment variable `JSM_NAMESPACE_SIZE` not found."
                " Make sure you run your executable with `jsrun`."
            )
        return int(world_size)

    @override
    def set_world_size(self, size: int) -> None:
        """Do nothing except log a debug message; ``size`` is ignored.

        ``world_size`` always reads ``JSM_NAMESPACE_SIZE`` from the environment, so no value is stored here.

        """
        log.debug("LSFEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")

    @override
    def global_rank(self) -> int:
        """The world size is read from the environment variable ``JSM_NAMESPACE_RANK``."""
        global_rank = os.environ.get("JSM_NAMESPACE_RANK")
        if global_rank is None:
            raise ValueError(
                "Cannot determine global rank. Environment variable `JSM_NAMESPACE_RANK` not found."
                " Make sure you run your executable with `jsrun`."
            )
        return int(global_rank)

    @override
    def set_global_rank(self, rank: int) -> None:
        """Do nothing except log a debug message; ``rank`` is ignored.

        ``global_rank`` always reads ``JSM_NAMESPACE_RANK`` from the environment, so no value is stored here.

        """
        log.debug("LSFEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")

    @override
    def local_rank(self) -> int:
        """The local rank is read from the environment variable `JSM_NAMESPACE_LOCAL_RANK`."""
        local_rank = os.environ.get("JSM_NAMESPACE_LOCAL_RANK")
        if local_rank is None:
            raise ValueError(
                "Cannot determine local rank. Environment variable `JSM_NAMESPACE_LOCAL_RANK` not found."
                " Make sure you run your executable with `jsrun`."
            )
        return int(local_rank)

    @override
    def node_rank(self) -> int:
        """The node rank, as cached by ``__init__``.

        It is determined by the position of the current hostname in the OpenMPI host rank file stored in
        ``LSB_DJOB_RANKFILE``.

        """
        return self._node_rank

    def _get_node_rank(self) -> int:
        """A helper method for getting the node rank.

        The node rank is determined by the position of the current node in the list of hosts used in the job. This is
        calculated by reading all hosts from ``LSB_DJOB_RANKFILE`` and finding this node's hostname in the list.

        """
        hosts = self._read_hosts()
        count: dict[str, int] = {}
        for host in hosts:
            if host not in count:
                count[host] = len(count)
        return count[socket.gethostname()]

    @staticmethod
    def _read_hosts() -> list[str]:
        """Read compute hosts that are a part of the compute job.

        LSF uses the Job Step Manager (JSM) to manage job steps. Job steps are executed by the JSM from "launch" nodes.
        Each job is assigned a launch node. This launch node will be the first node in the list contained in
        ``LSB_DJOB_RANKFILE``.

        Returns:
            The stripped, non-blank lines of the rank file in file order, without the first (launch node) entry.
            Repeated host names are kept. The list is empty if the file has at most one non-blank line.

        Raises:
            ValueError: If ``LSB_DJOB_RANKFILE`` is unset or set to an empty string.

        """
        rankfile = os.environ.get("LSB_DJOB_RANKFILE")
        if rankfile is None:
            raise ValueError("Did not find the environment variable `LSB_DJOB_RANKFILE`")
        if not rankfile:
            raise ValueError("The environment variable `LSB_DJOB_RANKFILE` is empty")

        fs = get_filesystem(rankfile)
        with fs.open(rankfile, "r") as f:
            stripped_lines = (line.strip() for line in f)
            # Skip blank / whitespace-only lines: kept as "" hosts they would shift the node ranks, and a leading
            # blank line would make the slice below drop the blank instead of the launch node.
            hosts = [name for name in stripped_lines if name]
        # remove the launch node (i.e. the first node in LSB_DJOB_RANKFILE) from the list
        return hosts[1:]

    def _get_main_address(self) -> str:
        """A helper for getting the main address.

        The main address is assigned to the first node in the list of nodes used for the job.

        """
        hosts = self._read_hosts()
        return hosts[0]

    @staticmethod
    def _get_main_port() -> int:
        """Determine the main port in a way every rank can reproduce.

        Returns:
            ``int(MASTER_PORT)`` when that environment variable is set; otherwise ``LSB_JOBID % 1000 + 10000``, i.e.
            a port between 10000 and 10999 derived from the LSF job ID.

        Raises:
            ValueError: If neither ``MASTER_PORT`` nor ``LSB_JOBID`` is set.

        """
        # a user-specified main port takes precedence over the calculated one
        if "MASTER_PORT" in os.environ:
            log.debug(f"Using externally specified main port: {os.environ['MASTER_PORT']}")
            return int(os.environ["MASTER_PORT"])

        if "LSB_JOBID" not in os.environ:
            raise ValueError("Could not find job id in environment variable LSB_JOBID")

        job_id = int(os.environ["LSB_JOBID"])
        # all ports should be in the 10k+ range
        port = job_id % 1000 + 10000
        log.debug(f"calculated LSF main port: {port}")
        return port