# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Literal, Optional, Union

import pytorch_lightning as pl
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS

if TYPE_CHECKING:
    from pytorch_lightning.tuner.lr_finder import _LRFinder


class Tuner:
    """Tuner class to tune your model."""

    def __init__(self, trainer: "pl.Trainer") -> None:
        self._trainer = trainer

    def scale_batch_size(
        self,
        model: "pl.LightningModule",
        train_dataloaders: Optional[Union[TRAIN_DATALOADERS, "pl.LightningDataModule"]] = None,
        val_dataloaders: Optional[EVAL_DATALOADERS] = None,
        dataloaders: Optional[EVAL_DATALOADERS] = None,
        datamodule: Optional["pl.LightningDataModule"] = None,
        method: Literal["fit", "validate", "test", "predict"] = "fit",
        mode: str = "power",
        steps_per_trial: int = 3,
        init_val: int = 2,
        max_trials: int = 25,
        batch_arg_name: str = "batch_size",
    ) -> Optional[int]:
        """Iteratively try to find the largest batch size for a given model that does not give an out of memory (OOM)
        error.

        Args:
            model: Model to tune.
            train_dataloaders: A collection of :class:`torch.utils.data.DataLoader` or a
                :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying training samples.
                In the case of multiple dataloaders, please see this :ref:`section <multiple-dataloaders>`.
            val_dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation samples.
            dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying val/test/predict
                samples used for running the tuner on validation/testing/prediction.
            datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.
            method: Method to run tuner on. It can be any of ``("fit", "validate", "test", "predict")``.
            mode: Search strategy to update the batch size:

                - ``'power'``: Keep multiplying the batch size by 2, until we get an OOM error.
                - ``'binsearch'``: Initially keep multiplying by 2 and after encountering an OOM error
                    do a binary search between the last successful batch size and the batch size that failed.

            steps_per_trial: number of steps to run with a given batch size.
                Ideally 1 should be enough to test if an OOM error occurs;
                however, in practice a few are needed.
            init_val: initial batch size to start the search with
            max_trials: max number of increases in batch size done before
                the algorithm is terminated
            batch_arg_name: name of the attribute that stores the batch size.
                It is expected that the user has provided a model or datamodule that has a hyperparameter
                with that name. We will look for this attribute name in the following places:

                - ``model``
                - ``model.hparams``
                - ``trainer.datamodule`` (the datamodule passed to the tune method)
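
        Example:
            A minimal usage sketch; ``MyModel`` is a hypothetical user-defined
            ``LightningModule`` (not part of this module) that exposes a
            ``batch_size`` attribute and builds its dataloaders from it::

                from pytorch_lightning import Trainer
                from pytorch_lightning.tuner.tuning import Tuner

                model = MyModel(batch_size=2)  # hypothetical user model
                tuner = Tuner(Trainer())

                # run short trials with growing batch sizes; the largest size that
                # fits in memory is returned and written back to ``model.batch_size``
                new_batch_size = tuner.scale_batch_size(model, mode="binsearch")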

        """
        _check_tuner_configuration(train_dataloaders, val_dataloaders, dataloaders, method)
        _check_scale_batch_size_configuration(self._trainer)

        # local import to avoid circular import
        from pytorch_lightning.callbacks.batch_size_finder import BatchSizeFinder

        batch_size_finder: Callback = BatchSizeFinder(
            mode=mode,
            steps_per_trial=steps_per_trial,
            init_val=init_val,
            max_trials=max_trials,
            batch_arg_name=batch_arg_name,
        )
        # when invoked through the Tuner, stop after the search instead of continuing the loop
        batch_size_finder._early_exit = True
        self._trainer.callbacks = [batch_size_finder] + self._trainer.callbacks

        if method == "fit":
            self._trainer.fit(model, train_dataloaders, val_dataloaders, datamodule)
        elif method == "validate":
            self._trainer.validate(model, dataloaders, datamodule=datamodule)
        elif method == "test":
            self._trainer.test(model, dataloaders, datamodule=datamodule)
        elif method == "predict":
            self._trainer.predict(model, dataloaders, datamodule=datamodule)

        self._trainer.callbacks = [cb for cb in self._trainer.callbacks if cb is not batch_size_finder]
        return batch_size_finder.optimal_batch_size

    def lr_find(
        self,
        model: "pl.LightningModule",
        train_dataloaders: Optional[Union[TRAIN_DATALOADERS, "pl.LightningDataModule"]] = None,
        val_dataloaders: Optional[EVAL_DATALOADERS] = None,
        dataloaders: Optional[EVAL_DATALOADERS] = None,
        datamodule: Optional["pl.LightningDataModule"] = None,
        method: Literal["fit", "validate", "test", "predict"] = "fit",
        min_lr: float = 1e-8,
        max_lr: float = 1,
        num_training: int = 100,
        mode: str = "exponential",
        early_stop_threshold: Optional[float] = 4.0,
        update_attr: bool = True,
        attr_name: str = "",
    ) -> Optional["_LRFinder"]:
        """Enables the user to do a range test of good initial learning rates, to reduce the amount of guesswork in
        picking a good starting learning rate.

        Args:
            model: Model to tune.
            train_dataloaders: A collection of :class:`torch.utils.data.DataLoader` or a
                :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying training samples.
                In the case of multiple dataloaders, please see this :ref:`section <multiple-dataloaders>`.
            val_dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation samples.
            dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying val/test/predict
                samples used for running the tuner on validation/testing/prediction.
            datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.
            method: Method to run tuner on. It can be any of ``("fit", "validate", "test", "predict")``.
            min_lr: minimum learning rate to investigate
            max_lr: maximum learning rate to investigate
            num_training: number of learning rates to test
            mode: Search strategy to update learning rate after each batch:

                - ``'exponential'``: Increases the learning rate exponentially.
                - ``'linear'``: Increases the learning rate linearly.

            early_stop_threshold: Threshold for stopping the search. If the
                loss at any point is larger than ``early_stop_threshold * best_loss``,
                the search is stopped. To disable, set to ``None``.
            update_attr: Whether to update the learning rate attribute or not.
            attr_name: Name of the attribute which stores the learning rate. The names 'learning_rate' or 'lr' get
                automatically detected. Otherwise, set the name here.

        Raises:
            MisconfigurationException:
                If learning rate/lr in ``model`` or ``model.hparams`` isn't overridden,
                or if you are using more than one optimizer.
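
        Example:
            A minimal usage sketch; ``MyModel`` is a hypothetical user-defined
            ``LightningModule`` (not part of this module) whose optimizer reads
            ``self.learning_rate`` in ``configure_optimizers``::

                from pytorch_lightning import Trainer
                from pytorch_lightning.tuner.tuning import Tuner

                model = MyModel()  # hypothetical user model
                tuner = Tuner(Trainer())

                lr_finder = tuner.lr_find(model, min_lr=1e-6, max_lr=1e-1, num_training=100)
                if lr_finder is not None:
                    print(lr_finder.suggestion())  # suggested starting learning rate
                    fig = lr_finder.plot(suggest=True)  # optional plot of the sweep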

        """
        if method != "fit":
            raise MisconfigurationException("method='fit' is the only valid configuration to run lr finder.")

        _check_tuner_configuration(train_dataloaders, val_dataloaders, dataloaders, method)
        _check_lr_find_configuration(self._trainer)

        # local import to avoid circular import
        from pytorch_lightning.callbacks.lr_finder import LearningRateFinder

        lr_finder_callback: Callback = LearningRateFinder(
            min_lr=min_lr,
            max_lr=max_lr,
            num_training_steps=num_training,
            mode=mode,
            early_stop_threshold=early_stop_threshold,
            update_attr=update_attr,
            attr_name=attr_name,
        )

        # when invoked through the Tuner, stop after the search instead of continuing the loop
        lr_finder_callback._early_exit = True
        self._trainer.callbacks = [lr_finder_callback] + self._trainer.callbacks

        self._trainer.fit(model, train_dataloaders, val_dataloaders, datamodule)

        self._trainer.callbacks = [cb for cb in self._trainer.callbacks if cb is not lr_finder_callback]

        return lr_finder_callback.optimal_lr


def _check_tuner_configuration(
    train_dataloaders: Optional[Union[TRAIN_DATALOADERS, "pl.LightningDataModule"]] = None,
    val_dataloaders: Optional[EVAL_DATALOADERS] = None,
    dataloaders: Optional[EVAL_DATALOADERS] = None,
    method: Literal["fit", "validate", "test", "predict"] = "fit",
) -> None:
    supported_methods = ("fit", "validate", "test", "predict")
    if method not in supported_methods:
        raise ValueError(f"method {method!r} is invalid. Should be one of {supported_methods}.")

    if method == "fit":
        if dataloaders is not None:
            raise MisconfigurationException(
                f"In tuner with method={method!r}, `dataloaders` argument should be None,"
                " please consider setting `train_dataloaders` and `val_dataloaders` instead."
            )
    else:
        if train_dataloaders is not None or val_dataloaders is not None:
            raise MisconfigurationException(
                f"In tuner with `method`={method!r}, `train_dataloaders` and `val_dataloaders`"
                " arguments should be None, please consider setting `dataloaders` instead."
            )


def _check_lr_find_configuration(trainer: "pl.Trainer") -> None:
    # local import to avoid circular import
    from pytorch_lightning.callbacks.lr_finder import LearningRateFinder

    configured_callbacks = [cb for cb in trainer.callbacks if isinstance(cb, LearningRateFinder)]
    if configured_callbacks:
        raise ValueError(
            "Trainer is already configured with a `LearningRateFinder` callback."
            "Please remove it if you want to use the Tuner."
        )


def _check_scale_batch_size_configuration(trainer: "pl.Trainer") -> None:
    if trainer._accelerator_connector.is_distributed:
        raise ValueError("Tuning the batch size is currently not supported with distributed strategies.")

    # local import to avoid circular import
    from pytorch_lightning.callbacks.batch_size_finder import BatchSizeFinder

    configured_callbacks = [cb for cb in trainer.callbacks if isinstance(cb, BatchSizeFinder)]
    if configured_callbacks:
        raise ValueError(
            "Trainer is already configured with a `BatchSizeFinder` callback."
            "Please remove it if you want to use the Tuner."
        )