# Copyright Lightning AI.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Houses the methods used to set up the Trainer."""

from typing import Optional, Union

import pytorch_lightning as pl
from lightning_fabric.utilities.warnings import PossibleUserWarning
from pytorch_lightning.accelerators import CUDAAccelerator, MPSAccelerator, XLAAccelerator
from pytorch_lightning.loggers.logger import DummyLogger
from pytorch_lightning.profilers import (
    AdvancedProfiler,
    PassThroughProfiler,
    Profiler,
    PyTorchProfiler,
    SimpleProfiler,
    XLAProfiler,
)
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _habana_available_and_importable
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn


def _init_debugging_flags(
    trainer: "pl.Trainer",
    limit_train_batches: Optional[Union[int, float]],
    limit_val_batches: Optional[Union[int, float]],
    limit_test_batches: Optional[Union[int, float]],
    limit_predict_batches: Optional[Union[int, float]],
    fast_dev_run: Union[int, bool],
    overfit_batches: Union[int, float],
    val_check_interval: Optional[Union[int, float]],
    num_sanity_val_steps: int,
) -> None:
    # init debugging flags
    if isinstance(fast_dev_run, int) and (fast_dev_run < 0):
        raise MisconfigurationException(
            f"fast_dev_run={fast_dev_run!r} is not a valid configuration. It should be >= 0."
        )
    trainer.fast_dev_run = fast_dev_run

    # set fast_dev_run=True when it is 1, used while logging
    if fast_dev_run == 1:
        trainer.fast_dev_run = True

    trainer.overfit_batches = _determine_batch_limits(overfit_batches, "overfit_batches")
    overfit_batches_enabled = overfit_batches > 0

    if fast_dev_run:
        num_batches = int(fast_dev_run)
        if not overfit_batches_enabled:
            trainer.limit_train_batches = num_batches
            trainer.limit_val_batches = num_batches

        trainer.limit_test_batches = num_batches
        trainer.limit_predict_batches = num_batches
        trainer.fit_loop.epoch_loop.max_steps = num_batches
        trainer.num_sanity_val_steps = 0
        trainer.fit_loop.max_epochs = 1
        trainer.val_check_interval = 1.0
        trainer.check_val_every_n_epoch = 1
        trainer.loggers = [DummyLogger()] if trainer.loggers else []
        rank_zero_info(
            f"Running in `fast_dev_run` mode: will run the requested loop using {num_batches} batch(es). "
            "Logging and checkpointing are suppressed."
        )
    else:
        if not overfit_batches_enabled:
            trainer.limit_train_batches = _determine_batch_limits(limit_train_batches, "limit_train_batches")
            trainer.limit_val_batches = _determine_batch_limits(limit_val_batches, "limit_val_batches")
        trainer.limit_test_batches = _determine_batch_limits(limit_test_batches, "limit_test_batches")
        trainer.limit_predict_batches = _determine_batch_limits(limit_predict_batches, "limit_predict_batches")
        trainer.num_sanity_val_steps = float("inf") if num_sanity_val_steps == -1 else num_sanity_val_steps
        trainer.val_check_interval = _determine_batch_limits(val_check_interval, "val_check_interval")

    if overfit_batches_enabled:
        trainer.limit_train_batches = overfit_batches
        trainer.limit_val_batches = overfit_batches
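
# A minimal sketch of the net effect of `_init_debugging_flags` when `fast_dev_run`
# is set (illustrative only; the `Trainer` call below is user code, and the values
# simply mirror the assignments above):
#
#     trainer = pl.Trainer(fast_dev_run=3)
#     # -> limit_{train,val,test,predict}_batches = 3, max_steps = 3, max_epochs = 1,
#     #    num_sanity_val_steps = 0, and any configured loggers replaced by a
#     #    `DummyLogger`, so logging and checkpointing are effectively disabled.
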
def _determine_batch_limits(batches: Optional[Union[int, float]], name: str) -> Union[int, float]:
    if batches is None:
        # `batches` is optional so we know whether the user passed a value explicitly;
        # the info messages below are shown only to users who set one
        return 1.0

    # differentiating based on the type can be error-prone for users. show a message describing the chosen behaviour
    if isinstance(batches, int) and batches == 1:
        if name == "limit_train_batches":
            message = "1 batch per epoch will be used."
        elif name == "val_check_interval":
            message = "validation will run after every batch."
        else:
            message = "1 batch will be used."
        rank_zero_info(f"`Trainer({name}=1)` was configured so {message}")
    elif isinstance(batches, float) and batches == 1.0:
        if name == "limit_train_batches":
            message = "100% of the batches per epoch will be used."
        elif name == "val_check_interval":
            message = "validation will run at the end of the training epoch."
        else:
            message = "100% of the batches will be used."
        rank_zero_info(f"`Trainer({name}=1.0)` was configured so {message}")

    if 0 <= batches <= 1:
        return batches
    if batches > 1 and batches % 1.0 == 0:
        return int(batches)

    raise MisconfigurationException(
        f"You have passed invalid value {batches} for {name}. It has to be in [0.0, 1.0] or an int."
    )
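
# Illustrative mapping of `_determine_batch_limits` (hypothetical values, chosen to
# cover each branch above):
#
#     _determine_batch_limits(None, "limit_val_batches")  # -> 1.0 (default; no message)
#     _determine_batch_limits(0.25, "limit_val_batches")  # -> 0.25 (fraction of available batches)
#     _determine_batch_limits(100, "limit_val_batches")   # -> 100 (absolute number of batches)
#     _determine_batch_limits(2.5, "limit_val_batches")   # raises MisconfigurationException
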
def _init_profiler(trainer: "pl.Trainer", profiler: Optional[Union[Profiler, str]]) -> None:
    if isinstance(profiler, str):
        PROFILERS = {
            "simple": SimpleProfiler,
            "advanced": AdvancedProfiler,
            "pytorch": PyTorchProfiler,
            "xla": XLAProfiler,
        }
        profiler = profiler.lower()
        if profiler not in PROFILERS:
            raise MisconfigurationException(
                "When passing string value for the `profiler` parameter of `Trainer`,"
                f" it can only be one of {list(PROFILERS.keys())}"
            )
        profiler_class = PROFILERS[profiler]
        profiler = profiler_class()
    trainer.profiler = profiler or PassThroughProfiler()


def _log_device_info(trainer: "pl.Trainer") -> None:
    if CUDAAccelerator.is_available():
        gpu_available = True
        gpu_type = " (cuda)"
    elif MPSAccelerator.is_available():
        gpu_available = True
        gpu_type = " (mps)"
    else:
        gpu_available = False
        gpu_type = ""

    gpu_used = isinstance(trainer.accelerator, (CUDAAccelerator, MPSAccelerator))
    rank_zero_info(f"GPU available: {gpu_available}{gpu_type}, used: {gpu_used}")

    num_tpu_cores = trainer.num_devices if isinstance(trainer.accelerator, XLAAccelerator) else 0
    rank_zero_info(f"TPU available: {XLAAccelerator.is_available()}, using: {num_tpu_cores} TPU cores")

    if _habana_available_and_importable():
        from lightning_habana import HPUAccelerator

        num_hpus = trainer.num_devices if isinstance(trainer.accelerator, HPUAccelerator) else 0
        hpu_available = HPUAccelerator.is_available()
    else:
        num_hpus = 0
        hpu_available = False
    rank_zero_info(f"HPU available: {hpu_available}, using: {num_hpus} HPUs")

    if (
        CUDAAccelerator.is_available()
        and not isinstance(trainer.accelerator, CUDAAccelerator)
        or MPSAccelerator.is_available()
        and not isinstance(trainer.accelerator, MPSAccelerator)
    ):
        rank_zero_warn(
            "GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.",
            category=PossibleUserWarning,
        )

    if XLAAccelerator.is_available() and not isinstance(trainer.accelerator, XLAAccelerator):
        rank_zero_warn("TPU available but not used. You can set it by doing `Trainer(accelerator='tpu')`.")

    if _habana_available_and_importable():
        from lightning_habana import HPUAccelerator

        if HPUAccelerator.is_available() and not isinstance(trainer.accelerator, HPUAccelerator):
            rank_zero_warn("HPU available but not used. You can set it by doing `Trainer(accelerator='hpu')`.")
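
# A minimal usage sketch of the profiler resolution above (illustrative only; these
# `Trainer` calls are user code, not part of this module):
#
#     pl.Trainer(profiler="simple")    # -> SimpleProfiler() via _init_profiler
#     pl.Trainer(profiler="advanced")  # -> AdvancedProfiler()
#     pl.Trainer(profiler=None)        # -> PassThroughProfiler() fallback
#     pl.Trainer(profiler="nvidia")    # raises MisconfigurationException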