# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities that can be used for calling functions on a particular rank."""

import logging
import os
from typing import Optional

import lightning_utilities.core.rank_zero as rank_zero_module

# note: we want to keep these indirections so the `rank_zero_only.rank` is set on import
from lightning_utilities.core.rank_zero import (  # noqa: F401
    WarningCache,
    rank_prefixed_message,
    rank_zero_debug,
    rank_zero_deprecation,
    rank_zero_info,
    rank_zero_warn,
)

# Replace the `log` attribute of lightning_utilities' rank_zero module with a logger named
# after this module (presumably the logger used by the rank_zero_* helpers — confirm upstream).
rank_zero_module.log = logging.getLogger(__name__)


def _get_rank() -> Optional[int]:
    """Return the process rank from the first rank environment variable that is set.

    The variables are consulted in this order: ``RANK``, ``LOCAL_RANK``, ``SLURM_PROCID``,
    ``JSM_NAMESPACE_RANK``. The first one present in ``os.environ`` wins.

    Returns:
        The rank parsed with ``int``, or ``None`` if none of the variables is set.

    Raises:
        ValueError: If the first variable found is set but cannot be parsed by ``int``.
    """
    # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing,
    # therefore RANK and LOCAL_RANK need to be checked before it
    rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK")
    for key in rank_keys:
        rank = os.environ.get(key)
        if rank is not None:
            return int(rank)
    # None to differentiate whether an environment variable was set at all
    return None


rank_zero_only = rank_zero_module.rank_zero_only

# Add the attribute to the function but don't overwrite it in case Trainer has already set it.
# The environment is only consulted when the attribute is missing, so a malformed rank variable
# cannot raise at import time when the rank has already been provided.
if not hasattr(rank_zero_only, "rank"):
    rank_zero_only.rank = _get_rank() or 0


class LightningDeprecationWarning(DeprecationWarning):
    """Deprecation warnings raised by Lightning."""


# Point lightning_utilities' deprecation category at Lightning's own warning class
# (presumably the category emitted by `rank_zero_deprecation` — confirm upstream).
rank_zero_module.rank_zero_deprecation_category = LightningDeprecationWarning