"""Houses the methods used to set up the Trainer.""" |

from typing import Optional, Union

import pytorch_lightning as pl
from lightning_fabric.utilities.warnings import PossibleUserWarning
from pytorch_lightning.accelerators import CUDAAccelerator, MPSAccelerator, XLAAccelerator
from pytorch_lightning.loggers.logger import DummyLogger
from pytorch_lightning.profilers import (
    AdvancedProfiler,
    PassThroughProfiler,
    Profiler,
    PyTorchProfiler,
    SimpleProfiler,
    XLAProfiler,
)
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _habana_available_and_importable
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn


def _init_debugging_flags(
    trainer: "pl.Trainer",
    limit_train_batches: Optional[Union[int, float]],
    limit_val_batches: Optional[Union[int, float]],
    limit_test_batches: Optional[Union[int, float]],
    limit_predict_batches: Optional[Union[int, float]],
    fast_dev_run: Union[int, bool],
    overfit_batches: Union[int, float],
    val_check_interval: Optional[Union[int, float]],
    num_sanity_val_steps: int,
) -> None:
    if isinstance(fast_dev_run, int) and (fast_dev_run < 0):
        raise MisconfigurationException(
            f"fast_dev_run={fast_dev_run!r} is not a valid configuration. It should be >= 0."
        )
    trainer.fast_dev_run = fast_dev_run

    # normalize `fast_dev_run=1` to `True` so downstream checks and logging see a boolean flag
    if fast_dev_run == 1:
        trainer.fast_dev_run = True

    trainer.overfit_batches = _determine_batch_limits(overfit_batches, "overfit_batches")
    overfit_batches_enabled = overfit_batches > 0

    if fast_dev_run:
        num_batches = int(fast_dev_run)
        if not overfit_batches_enabled:
            trainer.limit_train_batches = num_batches
            trainer.limit_val_batches = num_batches

        trainer.limit_test_batches = num_batches
        trainer.limit_predict_batches = num_batches
        trainer.fit_loop.epoch_loop.max_steps = num_batches
        trainer.num_sanity_val_steps = 0
        trainer.fit_loop.max_epochs = 1
        trainer.val_check_interval = 1.0
        trainer.check_val_every_n_epoch = 1
        trainer.loggers = [DummyLogger()] if trainer.loggers else []
        rank_zero_info(
            f"Running in `fast_dev_run` mode: will run the requested loop using {num_batches} batch(es). "
            "Logging and checkpointing are suppressed."
        )
    else:
        if not overfit_batches_enabled:
            trainer.limit_train_batches = _determine_batch_limits(limit_train_batches, "limit_train_batches")
            trainer.limit_val_batches = _determine_batch_limits(limit_val_batches, "limit_val_batches")
        trainer.limit_test_batches = _determine_batch_limits(limit_test_batches, "limit_test_batches")
        trainer.limit_predict_batches = _determine_batch_limits(limit_predict_batches, "limit_predict_batches")
        trainer.num_sanity_val_steps = float("inf") if num_sanity_val_steps == -1 else num_sanity_val_steps
        trainer.val_check_interval = _determine_batch_limits(val_check_interval, "val_check_interval")

    if overfit_batches_enabled:
        trainer.limit_train_batches = overfit_batches
        trainer.limit_val_batches = overfit_batches
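
# A minimal usage sketch (illustrative only, not executed by this module): with
# `Trainer(fast_dev_run=5)` the assignments above cap every loop at 5 batches, e.g.
#
#     trainer = pl.Trainer(fast_dev_run=5)
#     assert trainer.limit_train_batches == 5
#     assert trainer.num_sanity_val_steps == 0
#     assert trainer.fit_loop.max_epochs == 1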


def _determine_batch_limits(batches: Optional[Union[int, float]], name: str) -> Union[int, float]:
    if batches is None:
        # `None` means the user did not set the argument: fall back to the default of using all batches
        return 1.0

    if isinstance(batches, int) and batches == 1:
        if name == "limit_train_batches":
            message = "1 batch per epoch will be used."
        elif name == "val_check_interval":
            message = "validation will run after every batch."
        else:
            message = "1 batch will be used."
        rank_zero_info(f"`Trainer({name}=1)` was configured so {message}")
    elif isinstance(batches, float) and batches == 1.0:
        if name == "limit_train_batches":
            message = "100% of the batches per epoch will be used."
        elif name == "val_check_interval":
            message = "validation will run at the end of the training epoch."
        else:
            message = "100% of the batches will be used."
        rank_zero_info(f"`Trainer({name}=1.0)` was configured so {message}")

    if 0 <= batches <= 1:
        return batches
    if batches > 1 and batches % 1.0 == 0:
        return int(batches)
    raise MisconfigurationException(
        f"You have passed invalid value {batches} for {name}. It has to be in [0.0, 1.0] or an int."
    )
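
# A minimal sketch of the value mapping implemented above (illustrative only):
#
#     _determine_batch_limits(None, "limit_val_batches")    # -> 1.0 (use all batches)
#     _determine_batch_limits(0.25, "limit_train_batches")  # -> 0.25 (fraction of the batches)
#     _determine_batch_limits(5.0, "limit_test_batches")    # -> 5 (whole floats become ints)
#     _determine_batch_limits(1.5, "limit_test_batches")    # raises MisconfigurationException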


def _init_profiler(trainer: "pl.Trainer", profiler: Optional[Union[Profiler, str]]) -> None:
    if isinstance(profiler, str):
        PROFILERS = {
            "simple": SimpleProfiler,
            "advanced": AdvancedProfiler,
            "pytorch": PyTorchProfiler,
            "xla": XLAProfiler,
        }
        profiler = profiler.lower()
        if profiler not in PROFILERS:
            raise MisconfigurationException(
                "When passing a string value for the `profiler` parameter of `Trainer`,"
                f" it can only be one of {list(PROFILERS.keys())}."
            )
        profiler_class = PROFILERS[profiler]
        profiler = profiler_class()
    trainer.profiler = profiler or PassThroughProfiler()
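
# A minimal sketch of how the string shorthand resolves (illustrative only, default constructor args assumed):
#
#     _init_profiler(trainer, "simple")  # trainer.profiler = SimpleProfiler()
#     _init_profiler(trainer, None)      # trainer.profiler = PassThroughProfiler()
#     _init_profiler(trainer, "nsight")  # raises MisconfigurationException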


def _log_device_info(trainer: "pl.Trainer") -> None:
    if CUDAAccelerator.is_available():
        gpu_available = True
        gpu_type = " (cuda)"
    elif MPSAccelerator.is_available():
        gpu_available = True
        gpu_type = " (mps)"
    else:
        gpu_available = False
        gpu_type = ""

    gpu_used = isinstance(trainer.accelerator, (CUDAAccelerator, MPSAccelerator))
    rank_zero_info(f"GPU available: {gpu_available}{gpu_type}, used: {gpu_used}")

    num_tpu_cores = trainer.num_devices if isinstance(trainer.accelerator, XLAAccelerator) else 0
    rank_zero_info(f"TPU available: {XLAAccelerator.is_available()}, using: {num_tpu_cores} TPU cores")

    if _habana_available_and_importable():
        from lightning_habana import HPUAccelerator

        num_hpus = trainer.num_devices if isinstance(trainer.accelerator, HPUAccelerator) else 0
        hpu_available = HPUAccelerator.is_available()
    else:
        num_hpus = 0
        hpu_available = False
    rank_zero_info(f"HPU available: {hpu_available}, using: {num_hpus} HPUs")

    # `and` binds tighter than `or`; the parentheses make the two GPU checks explicit
    if (CUDAAccelerator.is_available() and not isinstance(trainer.accelerator, CUDAAccelerator)) or (
        MPSAccelerator.is_available() and not isinstance(trainer.accelerator, MPSAccelerator)
    ):
        rank_zero_warn(
            "GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.",
            category=PossibleUserWarning,
        )

    if XLAAccelerator.is_available() and not isinstance(trainer.accelerator, XLAAccelerator):
        rank_zero_warn("TPU available but not used. You can set it by doing `Trainer(accelerator='tpu')`.")

    if _habana_available_and_importable():
        from lightning_habana import HPUAccelerator

        if HPUAccelerator.is_available() and not isinstance(trainer.accelerator, HPUAccelerator):
            rank_zero_warn("HPU available but not used. You can set it by doing `Trainer(accelerator='hpu')`.")
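
# Example of the lines emitted above on a CUDA machine where the user kept the CPU accelerator
# (a sketch; the exact availability flags depend on the host):
#
#     GPU available: True (cuda), used: False
#     TPU available: False, using: 0 TPU cores
#     HPU available: False, using: 0 HPUs
#     PossibleUserWarning: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.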