# Copyright The Lightning AI team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os from collections import Counter from collections.abc import Iterable from typing import Any, Optional, Union, cast import torch from typing_extensions import get_args from lightning_fabric.accelerators import ACCELERATOR_REGISTRY from lightning_fabric.accelerators.accelerator import Accelerator from lightning_fabric.accelerators.cuda import CUDAAccelerator from lightning_fabric.accelerators.mps import MPSAccelerator from lightning_fabric.accelerators.xla import XLAAccelerator from lightning_fabric.plugins import ( BitsandbytesPrecision, CheckpointIO, DeepSpeedPrecision, HalfPrecision, MixedPrecision, Precision, TransformerEnginePrecision, XLAPrecision, ) from lightning_fabric.plugins.environments import ( ClusterEnvironment, LightningEnvironment, LSFEnvironment, MPIEnvironment, SLURMEnvironment, TorchElasticEnvironment, ) from lightning_fabric.plugins.precision.double import DoublePrecision from lightning_fabric.plugins.precision.fsdp import FSDPPrecision from lightning_fabric.plugins.precision.precision import ( _PRECISION_INPUT, _PRECISION_INPUT_INT, _PRECISION_INPUT_STR, _PRECISION_INPUT_STR_ALIAS, _PRECISION_INPUT_STR_ALIAS_CONVERSION, ) from lightning_fabric.strategies import ( STRATEGY_REGISTRY, DeepSpeedStrategy, ParallelStrategy, SingleDeviceStrategy, SingleDeviceXLAStrategy, Strategy, XLAFSDPStrategy, XLAStrategy, ) from lightning_fabric.strategies.ddp import _DDP_FORK_ALIASES from lightning_fabric.strategies.fsdp import _FSDP_ALIASES, FSDPStrategy from lightning_fabric.strategies.model_parallel import ModelParallelStrategy from lightning_fabric.utilities import rank_zero_info, rank_zero_warn from lightning_fabric.utilities.device_parser import _determine_root_gpu_device from lightning_fabric.utilities.imports import _IS_INTERACTIVE _PLUGIN_INPUT = Union[Precision, ClusterEnvironment, CheckpointIO] class _Connector: """The Connector parses several Fabric arguments and instantiates the Strategy including its owned components. A. accelerator flag could be: 1. accelerator class 2. accelerator str 3. accelerator auto B. strategy flag could be: 1. strategy class 2. strategy str registered with STRATEGY_REGISTRY 3. strategy str in _strategy_type enum which listed in each strategy as backend (registered these too, and _strategy_type could be deprecated) C. plugins flag could be: 1. precision class (should be removed, and precision flag should allow user pass classes) 2. checkpoint_io class 3. cluster_environment class priorities which to take when: A. Class > str B. Strategy > Accelerator/precision/plugins """ def __init__( self, accelerator: Union[str, Accelerator] = "auto", strategy: Union[str, Strategy] = "auto", devices: Union[list[int], str, int] = "auto", num_nodes: int = 1, precision: Optional[_PRECISION_INPUT] = None, plugins: Optional[Union[_PLUGIN_INPUT, Iterable[_PLUGIN_INPUT]]] = None, ) -> None: # These arguments can be set through environment variables set by the CLI accelerator = self._argument_from_env("accelerator", accelerator, default="auto") strategy = self._argument_from_env("strategy", strategy, default="auto") devices = self._argument_from_env("devices", devices, default="auto") num_nodes = int(self._argument_from_env("num_nodes", num_nodes, default=1)) precision = self._argument_from_env("precision", precision, default=None) # 1. Parsing flags # Get registered strategies, built-in accelerators and precision plugins self._registered_strategies = STRATEGY_REGISTRY.available_strategies() self._registered_accelerators = ACCELERATOR_REGISTRY.available_accelerators() # Raise an exception if there are conflicts between flags # Set each valid flag to `self._x_flag` after validation # For devices: Assign gpus, etc. to the accelerator flag and devices flag self._strategy_flag: Union[Strategy, str] = "auto" self._accelerator_flag: Union[Accelerator, str] = "auto" self._precision_input: _PRECISION_INPUT_STR = "32-true" self._precision_instance: Optional[Precision] = None self._cluster_environment_flag: Optional[Union[ClusterEnvironment, str]] = None self._parallel_devices: list[Union[int, torch.device, str]] = [] self.checkpoint_io: Optional[CheckpointIO] = None self._check_config_and_set_final_flags( strategy=strategy, accelerator=accelerator, precision=precision, plugins=plugins, ) self._check_device_config_and_set_final_flags(devices=devices, num_nodes=num_nodes) # 2. Instantiate Accelerator # handle `auto`, `None` and `gpu` if self._accelerator_flag == "auto": self._accelerator_flag = self._choose_auto_accelerator() elif self._accelerator_flag == "gpu": self._accelerator_flag = self._choose_gpu_accelerator_backend() self._set_parallel_devices_and_init_accelerator() # 3. Instantiate ClusterEnvironment self.cluster_environment: ClusterEnvironment = self._choose_and_init_cluster_environment() # 4. Instantiate Strategy - Part 1 if self._strategy_flag == "auto": self._strategy_flag = self._choose_strategy() # In specific cases, ignore user selection and fall back to a different strategy self._check_strategy_and_fallback() self._init_strategy() # 5. Instantiate Precision Plugin self.precision = self._check_and_init_precision() # 6. Instantiate Strategy - Part 2 self._lazy_init_strategy() def _check_config_and_set_final_flags( self, strategy: Union[str, Strategy], accelerator: Union[str, Accelerator], precision: Optional[_PRECISION_INPUT], plugins: Optional[Union[_PLUGIN_INPUT, Iterable[_PLUGIN_INPUT]]], ) -> None: """This method checks: 1. strategy: whether the strategy name is valid, and sets the internal flags if it is. 2. accelerator: if the value of the accelerator argument is a type of accelerator (instance or string), set self._accelerator_flag accordingly. 3. precision: The final value of the precision flag may be determined either by the precision argument or by a plugin instance. 4. plugins: The list of plugins may contain a Precision plugin, CheckpointIO, ClusterEnvironment and others. Additionally, other flags such as `precision` can populate the list with the corresponding plugin instances. """ if plugins is not None: plugins = [plugins] if not isinstance(plugins, Iterable) else plugins if isinstance(strategy, str): strategy = strategy.lower() self._strategy_flag = strategy if strategy != "auto" and strategy not in self._registered_strategies and not isinstance(strategy, Strategy): raise ValueError( f"You selected an invalid strategy name: `strategy={strategy!r}`." " It must be either a string or an instance of `lightning_fabric.strategies.Strategy`." " Example choices: auto, ddp, ddp_spawn, deepspeed, dp, ..." " Find a complete list of options in our documentation at https://lightning.ai" ) if ( accelerator not in self._registered_accelerators and accelerator not in ("auto", "gpu") and not isinstance(accelerator, Accelerator) ): raise ValueError( f"You selected an invalid accelerator name: `accelerator={accelerator!r}`." f" Available names are: auto, {', '.join(self._registered_accelerators)}." ) # MPS accelerator is incompatible with DDP family of strategies. It supports single-device operation only. is_ddp_str = isinstance(strategy, str) and "ddp" in strategy is_dp_str = isinstance(strategy, str) and "dp" in strategy is_deepspeed_str = isinstance(strategy, str) and "deepspeed" in strategy is_parallel_strategy = isinstance(strategy, ParallelStrategy) or is_ddp_str or is_dp_str or is_deepspeed_str is_mps_accelerator = MPSAccelerator.is_available() and ( accelerator in ("mps", "auto", "gpu", None) or isinstance(accelerator, MPSAccelerator) ) if is_mps_accelerator and is_parallel_strategy: raise ValueError( f"You set `strategy={strategy}` but strategies from the DDP family are not supported on the" f" MPS accelerator. Either explicitly set `accelerator='cpu'` or change the strategy." ) self._accelerator_flag = accelerator precision_input = _convert_precision_to_unified_args(precision) if plugins: plugins_flags_types: dict[str, int] = Counter() for plugin in plugins: if isinstance(plugin, Precision): self._precision_instance = plugin plugins_flags_types[Precision.__name__] += 1 elif isinstance(plugin, CheckpointIO): self.checkpoint_io = plugin plugins_flags_types[CheckpointIO.__name__] += 1 elif isinstance(plugin, ClusterEnvironment): self._cluster_environment_flag = plugin plugins_flags_types[ClusterEnvironment.__name__] += 1 else: raise TypeError( f"Found invalid type for plugin {plugin}. Expected one of: Precision, " "CheckpointIO, ClusterEnvironment." ) duplicated_plugin_key = [k for k, v in plugins_flags_types.items() if v > 1] if duplicated_plugin_key: raise ValueError( f"Received multiple values for {', '.join(duplicated_plugin_key)} flags in `plugins`." " Expected one value for each type at most." ) if plugins_flags_types.get(Precision.__name__) and precision_input is not None: raise ValueError( f"Received both `precision={precision_input}` and `plugins={self._precision_instance}`. Choose one." ) self._precision_input = "32-true" if precision_input is None else precision_input # handle the case when the user passes in a strategy instance which has an accelerator, precision, # checkpoint io or cluster env set up # TODO: improve the error messages below if isinstance(self._strategy_flag, Strategy): if self._strategy_flag._accelerator: if self._accelerator_flag != "auto": raise ValueError("accelerator set through both strategy class and accelerator flag, choose one") self._accelerator_flag = self._strategy_flag._accelerator if self._strategy_flag._precision: # [RFC] handle precision plugin set up conflict? if self._precision_instance: raise ValueError("precision set through both strategy class and plugins, choose one") self._precision_instance = self._strategy_flag._precision if self._strategy_flag._checkpoint_io: if self.checkpoint_io: raise ValueError("checkpoint_io set through both strategy class and plugins, choose one") self.checkpoint_io = self._strategy_flag._checkpoint_io if getattr(self._strategy_flag, "cluster_environment", None): if self._cluster_environment_flag: raise ValueError("cluster_environment set through both strategy class and plugins, choose one") self._cluster_environment_flag = getattr(self._strategy_flag, "cluster_environment") if hasattr(self._strategy_flag, "parallel_devices") and self._strategy_flag.parallel_devices: if self._strategy_flag.parallel_devices[0].type == "cpu": if self._accelerator_flag and self._accelerator_flag not in ("auto", "cpu"): raise ValueError( f"CPU parallel_devices set through {self._strategy_flag.__class__.__name__} class," f" but accelerator set to {self._accelerator_flag}, please choose one device type" ) self._accelerator_flag = "cpu" if self._strategy_flag.parallel_devices[0].type == "cuda": if self._accelerator_flag and self._accelerator_flag not in ("auto", "cuda", "gpu"): raise ValueError( f"GPU parallel_devices set through {self._strategy_flag.__class__.__name__} class," f" but accelerator set to {self._accelerator_flag}, please choose one device type" ) self._accelerator_flag = "cuda" self._parallel_devices = self._strategy_flag.parallel_devices def _check_device_config_and_set_final_flags(self, devices: Union[list[int], str, int], num_nodes: int) -> None: if not isinstance(num_nodes, int) or num_nodes < 1: raise ValueError(f"`num_nodes` must be a positive integer, but got {num_nodes}.") self._num_nodes_flag = num_nodes self._devices_flag = devices if self._devices_flag in ([], 0, "0"): accelerator_name = ( self._accelerator_flag.__class__.__qualname__ if isinstance(self._accelerator_flag, Accelerator) else self._accelerator_flag ) raise ValueError( f"`Fabric(devices={self._devices_flag!r})` value is not a valid input" f" using {accelerator_name} accelerator." ) @staticmethod def _choose_auto_accelerator() -> str: """Choose the accelerator type (str) based on availability when ``accelerator='auto'``.""" if XLAAccelerator.is_available(): return "tpu" if MPSAccelerator.is_available(): return "mps" if CUDAAccelerator.is_available(): return "cuda" return "cpu" @staticmethod def _choose_gpu_accelerator_backend() -> str: if MPSAccelerator.is_available(): return "mps" if CUDAAccelerator.is_available(): return "cuda" raise RuntimeError("No supported gpu backend found!") def _set_parallel_devices_and_init_accelerator(self) -> None: if isinstance(self._accelerator_flag, Accelerator): self.accelerator: Accelerator = self._accelerator_flag else: assert self._accelerator_flag is not None self.accelerator = ACCELERATOR_REGISTRY.get(self._accelerator_flag) accelerator_cls = self.accelerator.__class__ if not accelerator_cls.is_available(): available_accelerator = [ acc_str for acc_str in self._registered_accelerators if ACCELERATOR_REGISTRY[acc_str]["accelerator"].is_available() ] raise RuntimeError( f"`{accelerator_cls.__qualname__}` can not run on your system" " since the accelerator is not available. The following accelerator(s)" " is available and can be passed into `accelerator` argument of" f" `Fabric`: {available_accelerator}." ) self._set_devices_flag_if_auto_passed() self._devices_flag = accelerator_cls.parse_devices(self._devices_flag) if not self._parallel_devices: self._parallel_devices = accelerator_cls.get_parallel_devices(self._devices_flag) def _set_devices_flag_if_auto_passed(self) -> None: if self._devices_flag != "auto": return if ( _IS_INTERACTIVE and isinstance(self.accelerator, CUDAAccelerator) and self.accelerator.auto_device_count() > 1 ): self._devices_flag = 1 rank_zero_info( f"Fabric will use only 1 of {self.accelerator.auto_device_count()} GPUs because it is running inside" " an interactive / notebook environment. You may try to set `Fabric(devices=" f"{self.accelerator.auto_device_count()})` but please note that multi-GPU inside interactive /" " notebook environments is considered experimental and unstable. Your mileage may vary." ) else: self._devices_flag = self.accelerator.auto_device_count() def _choose_and_init_cluster_environment(self) -> ClusterEnvironment: if isinstance(self._cluster_environment_flag, ClusterEnvironment): return self._cluster_environment_flag for env_type in ( # TorchElastic has the highest priority since it can also be used inside SLURM TorchElasticEnvironment, SLURMEnvironment, LSFEnvironment, MPIEnvironment, ): if env_type.detect(): return env_type() return LightningEnvironment() def _choose_strategy(self) -> Union[Strategy, str]: if self._accelerator_flag == "tpu" or isinstance(self._accelerator_flag, XLAAccelerator): if self._parallel_devices and len(self._parallel_devices) > 1: return "xla" # TODO: lazy initialized device, then here could be self._strategy_flag = "single_xla" return SingleDeviceXLAStrategy(device=self._parallel_devices[0]) if self._num_nodes_flag > 1: return "ddp" if len(self._parallel_devices) <= 1: if isinstance(self._accelerator_flag, (CUDAAccelerator, MPSAccelerator)) or ( isinstance(self._accelerator_flag, str) and self._accelerator_flag in ("cuda", "gpu", "mps") ): device = _determine_root_gpu_device(self._parallel_devices) else: device = "cpu" # TODO: lazy initialized device, then here could be self._strategy_flag = "single_device" return SingleDeviceStrategy(device=device) # type: ignore if len(self._parallel_devices) > 1 and _IS_INTERACTIVE: return "ddp_fork" return "ddp" def _check_strategy_and_fallback(self) -> None: """Checks edge cases when the strategy selection was a string input, and we need to fall back to a different choice depending on other parameters or the environment.""" # current fallback and check logic only apply to user pass in str config and object config # TODO this logic should apply to both str and object config strategy_flag = "" if isinstance(self._strategy_flag, Strategy) else self._strategy_flag # Change fsdp to xla_fsdp if using TPU if strategy_flag == "fsdp" and self._accelerator_flag == "tpu": strategy_flag = "xla_fsdp" if strategy_flag == "dp" and self._accelerator_flag == "cpu": rank_zero_warn(f"{strategy_flag!r} is not supported on CPUs, hence setting `strategy='ddp'`.") strategy_flag = "ddp" if strategy_flag in _DDP_FORK_ALIASES and "fork" not in torch.multiprocessing.get_all_start_methods(): raise ValueError( f"You selected `Fabric(strategy='{strategy_flag}')` but process forking is not supported on this" f" platform. We recommend `Fabric(strategy='ddp_spawn')` instead." ) if ( strategy_flag in _FSDP_ALIASES or type(self._strategy_flag) is FSDPStrategy ) and self._accelerator_flag not in ("cuda", "gpu"): raise ValueError( "You selected the FSDP strategy but FSDP is only available on GPU. Set `Fabric(accelerator='gpu', ...)`" " to continue or select a different strategy." ) if strategy_flag: self._strategy_flag = strategy_flag def _init_strategy(self) -> None: """Instantiate the Strategy given depending on the setting of ``_strategy_flag``.""" # The validation of `_strategy_flag` already happened earlier on in the connector assert isinstance(self._strategy_flag, (str, Strategy)) if isinstance(self._strategy_flag, str): self.strategy = STRATEGY_REGISTRY.get(self._strategy_flag) else: self.strategy = self._strategy_flag def _check_and_init_precision(self) -> Precision: if isinstance(self._precision_instance, Precision): if isinstance(self._precision_instance, BitsandbytesPrecision) and not isinstance( self.accelerator, CUDAAccelerator ): raise RuntimeError("Bitsandbytes is only supported on CUDA GPUs.") return self._precision_instance if isinstance(self.strategy, (SingleDeviceXLAStrategy, XLAStrategy, XLAFSDPStrategy)): return XLAPrecision(self._precision_input) # type: ignore if isinstance(self.strategy, DeepSpeedStrategy): return DeepSpeedPrecision(self._precision_input) # type: ignore if isinstance(self.strategy, FSDPStrategy): return FSDPPrecision(precision=self._precision_input) # type: ignore[arg-type] mp_precision_supported = ("32-true", "bf16-mixed", "bf16-true", "16-true") if isinstance(self.strategy, ModelParallelStrategy) and self._precision_input not in mp_precision_supported: raise ValueError( f"The `ModelParallelStrategy` does not support `Fabric(..., precision={self._precision_input!r})`." f" Choose a different precision among: {', '.join(mp_precision_supported)}." ) if self._precision_input in ("16-true", "bf16-true"): return HalfPrecision(self._precision_input) # type: ignore if self._precision_input == "32-true": return Precision() if self._precision_input == "64-true": return DoublePrecision() if self._precision_input == "transformer-engine": return TransformerEnginePrecision(weights_dtype=torch.bfloat16) if self._precision_input == "transformer-engine-float16": return TransformerEnginePrecision(weights_dtype=torch.float16) if self._precision_input == "16-mixed" and self._accelerator_flag == "cpu": rank_zero_warn( "You passed `Fabric(accelerator='cpu', precision='16-mixed')` but AMP with fp16 is not supported on " "CPU. Using `precision='bf16-mixed'` instead." ) self._precision_input = "bf16-mixed" if self._precision_input in ("16-mixed", "bf16-mixed"): rank_zero_info( "Using 16-bit Automatic Mixed Precision (AMP)" if self._precision_input == "16-mixed" else "Using bfloat16 Automatic Mixed Precision (AMP)" ) device = self._accelerator_flag if self._accelerator_flag in ("cpu", "mps") else "cuda" return MixedPrecision(precision=self._precision_input, device=device) # type: ignore[arg-type] raise RuntimeError("No precision set") def _lazy_init_strategy(self) -> None: """Lazily set missing attributes on the previously instantiated strategy.""" self.strategy.accelerator = self.accelerator if self.precision: self.strategy.precision = self.precision if self.checkpoint_io: self.strategy.checkpoint_io = self.checkpoint_io if hasattr(self.strategy, "cluster_environment"): if self.strategy.cluster_environment is None: self.strategy.cluster_environment = self.cluster_environment self.cluster_environment = self.strategy.cluster_environment if hasattr(self.strategy, "parallel_devices"): if self.strategy.parallel_devices: self._parallel_devices = self.strategy.parallel_devices else: self.strategy.parallel_devices = self._parallel_devices if hasattr(self.strategy, "num_nodes"): self.strategy._num_nodes = self._num_nodes_flag if hasattr(self.strategy, "_set_world_ranks"): self.strategy._set_world_ranks() self.strategy._configure_launcher() if _IS_INTERACTIVE and self.strategy.launcher and not self.strategy.launcher.is_interactive_compatible: raise RuntimeError( f"`Fabric(strategy={self._strategy_flag!r})` is not compatible with an interactive" " environment. Run your code as a script, or choose one of the compatible strategies:" f" `Fabric(strategy='dp'|'ddp_notebook')`." " In case you are spawning processes yourself, make sure to include the Fabric" " creation inside the worker function." ) # TODO: should be moved to _check_strategy_and_fallback(). # Current test check precision first, so keep this check here to meet error order if isinstance(self.accelerator, XLAAccelerator) and not isinstance( self.strategy, (SingleDeviceXLAStrategy, XLAStrategy, XLAFSDPStrategy) ): raise ValueError( "The `XLAAccelerator` can only be used with a `SingleDeviceXLAStrategy`, `XLAStrategy`, or" f" `XLAFSDPStrategy`. Found {self.strategy.__class__.__name__}." ) @staticmethod def _argument_from_env(name: str, current: Any, default: Any) -> Any: env_value: Optional[str] = os.environ.get("LT_" + name.upper()) if env_value is None: return current if env_value is not None and env_value != str(current) and str(current) != str(default) and _is_using_cli(): raise ValueError( f"Your code has `Fabric({name}={current!r}, ...)` but it conflicts with the value " f"`--{name}={env_value}` set through the CLI. " " Remove it either from the CLI or from the Lightning Fabric object." ) return env_value def _convert_precision_to_unified_args(precision: Optional[_PRECISION_INPUT]) -> Optional[_PRECISION_INPUT_STR]: if precision is None: return None supported_precision = ( get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT) + get_args(_PRECISION_INPUT_STR_ALIAS) ) if precision not in supported_precision: raise ValueError(f"Precision {repr(precision)} is invalid. Allowed precision values: {supported_precision}") precision = str(precision) # convert int flags to str here to enable the legacy-conversion below if precision in get_args(_PRECISION_INPUT_STR_ALIAS): if str(precision)[:2] not in ("32", "64"): rank_zero_warn( f"`precision={precision}` is supported for historical reasons but its usage is discouraged. " f"Please set your precision to {_PRECISION_INPUT_STR_ALIAS_CONVERSION[precision]} instead!" ) precision = _PRECISION_INPUT_STR_ALIAS_CONVERSION[precision] return cast(_PRECISION_INPUT_STR, precision) def _is_using_cli() -> bool: return bool(int(os.environ.get("LT_CLI_USED", "0")))