# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import AbstractContextManager, nullcontext
from typing import Any, Literal, Optional, Union

from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer

from lightning.fabric.utilities.types import _PARAMETERS, Optimizable

_PRECISION_INPUT_INT = Literal[64, 32, 16]
_PRECISION_INPUT_STR_ALIAS_CONVERSION = {"64": "64-true", "32": "32-true", "16": "16-mixed", "bf16": "bf16-mixed"}
_PRECISION_INPUT_STR_ALIAS = Literal["64", "32", "16", "bf16"]
_PRECISION_INPUT_STR = Literal[
    "transformer-engine",
    "transformer-engine-float16",
    "16-true",
    "16-mixed",
    "bf16-true",
    "bf16-mixed",
    "32-true",
    "64-true",
]
_PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR, _PRECISION_INPUT_STR_ALIAS]
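

# NOTE: the helper below is an editor's illustrative sketch, not part of this module. It only shows
# how the alias table above maps the accepted spellings onto the canonical strings: integers are
# stringified first, so ``16``, ``"16"`` and ``"16-mixed"`` all resolve to ``"16-mixed"``.
def _normalize_precision_sketch(precision: _PRECISION_INPUT) -> str:
    value = str(precision)
    return _PRECISION_INPUT_STR_ALIAS_CONVERSION.get(value, value)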


class Precision:
    """Base class for all plugins handling the precision-specific parts of the training.

    The class attribute precision must be overwritten in child classes. The default value reflects fp32 training.

    """

    precision: _PRECISION_INPUT_STR = "32-true"

    def convert_module(self, module: Module) -> Module:
        """Convert the module parameters to the precision type this plugin handles.

        This is optional and depends on the limitations the precision type imposes during optimization.

        """
        return module

    def tensor_init_context(self) -> AbstractContextManager:
        """Controls how tensors get created (device, dtype)."""
        return nullcontext()

    def module_init_context(self) -> AbstractContextManager:
        """Instantiate module parameters or tensors in the precision type this plugin handles.

        This is optional and depends on the limitations the precision type imposes during optimization.

        """
        return nullcontext()

    def forward_context(self) -> AbstractContextManager:
        """A contextmanager for managing model forward/training_step/evaluation_step/predict_step."""
        return nullcontext()

    def convert_input(self, data: Any) -> Any:
        """Convert model inputs (forward) to the floating point precision type of this plugin.

        This is a no-op in the base precision plugin, since we assume the data already has the desired type (default is
        torch.float32).

        """
        return data

    def convert_output(self, data: Any) -> Any:
        """Convert outputs to the floating point precision type expected after model's forward.

        This is a no-op in the base precision plugin, since we assume the data already has the desired type (default is
        torch.float32).

        """
        return data

    def pre_backward(self, tensor: Tensor, module: Optional[Module]) -> Any:
        """Runs before precision plugin executes backward.

        Args:
            tensor: The tensor that will be used for backpropagation
            module: The module that was involved in producing the tensor and whose parameters need the gradients

        """

    def backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **kwargs: Any) -> None:
        """Performs the actual backpropagation.

        Args:
            tensor: The tensor that will be used for backpropagation
            model: The module that was involved in producing the tensor and whose parameters need the gradients

        """
        tensor.backward(*args, **kwargs)

    def post_backward(self, tensor: Tensor, module: Optional[Module]) -> Any:
        """Runs after precision plugin executes backward.

        Args:
            tensor: The tensor that will be used for backpropagation
            module: The module that was involved in producing the tensor and whose parameters need the gradients

        """

    def optimizer_step(
        self,
        optimizer: Optimizable,
        **kwargs: Any,
    ) -> Any:
        """Hook to run the optimizer step."""
        return optimizer.step(**kwargs)

    def main_params(self, optimizer: Optimizer) -> _PARAMETERS:
        """The main params of the model.

        Returns the plain model params here. Maybe different in other precision plugins.

        """
        for group in optimizer.param_groups:
            yield from group["params"]

    def unscale_gradients(self, optimizer: Optimizer) -> None:
        """Unscale gradients that were scaled during ``backward``.

        This is a no-op in the base plugin; plugins that scale the loss override it.

        """
        return

    def state_dict(self) -> dict[str, Any]:
        """Called when saving a checkpoint, implement to generate precision plugin state_dict.

        Returns:
            A dictionary containing precision plugin state.

        """
        return {}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """Called when loading a checkpoint, implement to reload precision plugin state given precision plugin
        state_dict.

        Args:
            state_dict: the precision plugin state returned by ``state_dict``.

        """
        pass

    def teardown(self) -> None:
        """This method is called to teardown the training process.

        It is the right place to release memory and free other resources.

        """