# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Literal, Optional, Union

import pytorch_lightning as pl
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS

if TYPE_CHECKING:
    from pytorch_lightning.tuner.lr_finder import _LRFinder


class Tuner:
"""Tuner class to tune your model."""
def __init__(self, trainer: "pl.Trainer") -> None:
        self._trainer = trainer

    def scale_batch_size(
self,
model: "pl.LightningModule",
train_dataloaders: Optional[Union[TRAIN_DATALOADERS, "pl.LightningDataModule"]] = None,
val_dataloaders: Optional[EVAL_DATALOADERS] = None,
dataloaders: Optional[EVAL_DATALOADERS] = None,
datamodule: Optional["pl.LightningDataModule"] = None,
method: Literal["fit", "validate", "test", "predict"] = "fit",
mode: str = "power",
steps_per_trial: int = 3,
init_val: int = 2,
max_trials: int = 25,
batch_arg_name: str = "batch_size",
) -> Optional[int]:
"""Iteratively try to find the largest batch size for a given model that does not give an out of memory (OOM)
        error.

        Args:
model: Model to tune.
train_dataloaders: A collection of :class:`torch.utils.data.DataLoader` or a
:class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying training samples.
In the case of multiple dataloaders, please see this :ref:`section <multiple-dataloaders>`.
val_dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation samples.
dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying val/test/predict
samples used for running tuner on validation/testing/prediction.
datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.
method: Method to run tuner on. It can be any of ``("fit", "validate", "test", "predict")``.
            mode: Search strategy to update the batch size:

                - ``'power'``: Keep multiplying the batch size by 2, until we get an OOM error.
                - ``'binsearch'``: Initially keep multiplying by 2 and after encountering an OOM error
                  do a binary search between the last successful batch size and the batch size that failed.

            steps_per_trial: Number of steps to run with a given batch size. Ideally, 1 should be enough to test
                whether an OOM error occurs, but in practice a few are needed.
            init_val: Initial batch size to start the search with.
            max_trials: Maximum number of batch size increases before the algorithm is terminated.
            batch_arg_name: Name of the attribute that stores the batch size.
                It is expected that the user has provided a model or datamodule that has a hyperparameter
                with that name. We will look for this attribute name in the following places

                - ``model``
                - ``model.hparams``
                - ``trainer.datamodule`` (the datamodule passed to the tune method)
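
        Example::

            # A minimal usage sketch. MyModel is a hypothetical LightningModule that exposes a
            # batch_size attribute (directly or via self.hparams).
            from pytorch_lightning import Trainer
            from pytorch_lightning.tuner.tuning import Tuner

            model = MyModel(batch_size=2)
            trainer = Trainer(max_epochs=3)
            tuner = Tuner(trainer)

            # Doubles ('power') or binary-searches ('binsearch') the batch size until max_trials is
            # reached or an OOM error occurs; the largest working value is also set on the model.
            new_batch_size = tuner.scale_batch_size(model, mode="binsearch", init_val=2)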
"""
_check_tuner_configuration(train_dataloaders, val_dataloaders, dataloaders, method)
        _check_scale_batch_size_configuration(self._trainer)

        # local import to avoid circular import
        from pytorch_lightning.callbacks.batch_size_finder import BatchSizeFinder

        batch_size_finder: Callback = BatchSizeFinder(
mode=mode,
steps_per_trial=steps_per_trial,
init_val=init_val,
max_trials=max_trials,
batch_arg_name=batch_arg_name,
)
# do not continue with the loop in case Tuner is used
batch_size_finder._early_exit = True
        self._trainer.callbacks = [batch_size_finder] + self._trainer.callbacks

        if method == "fit":
self._trainer.fit(model, train_dataloaders, val_dataloaders, datamodule)
elif method == "validate":
self._trainer.validate(model, dataloaders, datamodule=datamodule)
elif method == "test":
self._trainer.test(model, dataloaders, datamodule=datamodule)
elif method == "predict":
            self._trainer.predict(model, dataloaders, datamodule=datamodule)

        self._trainer.callbacks = [cb for cb in self._trainer.callbacks if cb is not batch_size_finder]
        return batch_size_finder.optimal_batch_size

    def lr_find(
self,
model: "pl.LightningModule",
train_dataloaders: Optional[Union[TRAIN_DATALOADERS, "pl.LightningDataModule"]] = None,
val_dataloaders: Optional[EVAL_DATALOADERS] = None,
dataloaders: Optional[EVAL_DATALOADERS] = None,
datamodule: Optional["pl.LightningDataModule"] = None,
method: Literal["fit", "validate", "test", "predict"] = "fit",
min_lr: float = 1e-8,
max_lr: float = 1,
num_training: int = 100,
mode: str = "exponential",
early_stop_threshold: Optional[float] = 4.0,
update_attr: bool = True,
attr_name: str = "",
) -> Optional["_LRFinder"]:
"""Enables the user to do a range test of good initial learning rates, to reduce the amount of guesswork in
        picking a good starting learning rate.

        Args:
model: Model to tune.
train_dataloaders: A collection of :class:`torch.utils.data.DataLoader` or a
:class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying training samples.
In the case of multiple dataloaders, please see this :ref:`section <multiple-dataloaders>`.
val_dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation samples.
dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying val/test/predict
samples used for running tuner on validation/testing/prediction.
datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.
method: Method to run tuner on. It can be any of ``("fit", "validate", "test", "predict")``.
            min_lr: Minimum learning rate to investigate.
            max_lr: Maximum learning rate to investigate.
            num_training: Number of learning rates to test.
            mode: Search strategy to update learning rate after each batch:

                - ``'exponential'``: Increases the learning rate exponentially.
                - ``'linear'``: Increases the learning rate linearly.

            early_stop_threshold: Threshold for stopping the search. If the loss at any point is larger than
                ``early_stop_threshold * best_loss``, the search is stopped. To disable, set to ``None``.
update_attr: Whether to update the learning rate attribute or not.
attr_name: Name of the attribute which stores the learning rate. The names 'learning_rate' or 'lr' get
                automatically detected. Otherwise, set the name here.

        Raises:
MisconfigurationException:
If learning rate/lr in ``model`` or ``model.hparams`` isn't overridden,
or if you are using more than one optimizer.
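
        Example::

            # A minimal usage sketch. MyModel is a hypothetical LightningModule that exposes a
            # learning_rate (or lr) attribute.
            from pytorch_lightning import Trainer
            from pytorch_lightning.tuner.tuning import Tuner

            model = MyModel(learning_rate=1e-3)
            trainer = Trainer(max_epochs=3)
            tuner = Tuner(trainer)

            lr_finder = tuner.lr_find(model, min_lr=1e-6, max_lr=1e-1, num_training=100)

            # Inspect the range-test curve and apply the suggested learning rate manually
            # (with update_attr=True the attribute is already updated for you).
            fig = lr_finder.plot(suggest=True)
            model.learning_rate = lr_finder.suggestion()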
"""
if method != "fit":
            raise MisconfigurationException("method='fit' is the only valid configuration to run lr finder.")

        _check_tuner_configuration(train_dataloaders, val_dataloaders, dataloaders, method)
        _check_lr_find_configuration(self._trainer)

        # local import to avoid circular import
        from pytorch_lightning.callbacks.lr_finder import LearningRateFinder

        lr_finder_callback: Callback = LearningRateFinder(
min_lr=min_lr,
max_lr=max_lr,
num_training_steps=num_training,
mode=mode,
early_stop_threshold=early_stop_threshold,
update_attr=update_attr,
attr_name=attr_name,
)
        lr_finder_callback._early_exit = True
        self._trainer.callbacks = [lr_finder_callback] + self._trainer.callbacks

        self._trainer.fit(model, train_dataloaders, val_dataloaders, datamodule)

        self._trainer.callbacks = [cb for cb in self._trainer.callbacks if cb is not lr_finder_callback]
        return lr_finder_callback.optimal_lr


def _check_tuner_configuration(
train_dataloaders: Optional[Union[TRAIN_DATALOADERS, "pl.LightningDataModule"]] = None,
val_dataloaders: Optional[EVAL_DATALOADERS] = None,
dataloaders: Optional[EVAL_DATALOADERS] = None,
method: Literal["fit", "validate", "test", "predict"] = "fit",
) -> None:
supported_methods = ("fit", "validate", "test", "predict")
if method not in supported_methods:
raise ValueError(f"method {method!r} is invalid. Should be one of {supported_methods}.")
if method == "fit":
if dataloaders is not None:
raise MisconfigurationException(
f"In tuner with method={method!r}, `dataloaders` argument should be None,"
" please consider setting `train_dataloaders` and `val_dataloaders` instead."
)
else:
if train_dataloaders is not None or val_dataloaders is not None:
raise MisconfigurationException(
f"In tuner with `method`={method!r}, `train_dataloaders` and `val_dataloaders`"
" arguments should be None, please consider setting `dataloaders` instead."
            )


def _check_lr_find_configuration(trainer: "pl.Trainer") -> None:
    # local import to avoid circular import
    from pytorch_lightning.callbacks.lr_finder import LearningRateFinder

    configured_callbacks = [cb for cb in trainer.callbacks if isinstance(cb, LearningRateFinder)]
if configured_callbacks:
raise ValueError(
"Trainer is already configured with a `LearningRateFinder` callback."
"Please remove it if you want to use the Tuner."
        )


def _check_scale_batch_size_configuration(trainer: "pl.Trainer") -> None:
    if trainer._accelerator_connector.is_distributed:
        raise ValueError("Tuning the batch size is currently not supported with distributed strategies.")

    # local import to avoid circular import
    from pytorch_lightning.callbacks.batch_size_finder import BatchSizeFinder

    configured_callbacks = [cb for cb in trainer.callbacks if isinstance(cb, BatchSizeFinder)]
if configured_callbacks:
raise ValueError(
"Trainer is already configured with a `BatchSizeFinder` callback."
"Please remove it if you want to use the Tuner."
)