# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
from collections.abc import Sequence
from typing import Any, Callable, Optional, Union

import torch
from torch.nn import Module, Parameter
from torch.optim import Optimizer
from torch.overrides import TorchFunctionMode
from typing_extensions import override

from lightning_fabric.utilities.rank_zero import rank_zero_warn
from lightning_fabric.utilities.types import _DEVICE


# From https://lernapparat.de/faster-model-init by Thomas Viehmann
class _EmptyInit(TorchFunctionMode):
    """Initialize `nn.Module` with empty tensors, i.e., uninitialized memory.

    Example::

        with _EmptyInit():
            model = BigModel()
        model.load_state_dict(torch.load("checkpoint.pt"))
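
    Passing ``enabled=False`` makes the mode a no-op, so the optimization can be toggled behind a flag.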

    """

    def __init__(self, enabled: bool = True) -> None:
        super().__init__()
        self.enabled = enabled

    @override
    def __torch_function__(
        self,
        func: Callable,
        types: Sequence,
        args: Sequence[Any] = (),
        kwargs: Optional[dict] = None,
    ) -> Any:
        kwargs = kwargs or {}
        if not self.enabled:
            return func(*args, **kwargs)
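        # Calls originating from `torch.nn.init` are intercepted: the target tensor is returned untouched,
        # leaving its memory uninitialized.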
        if getattr(func, "__module__", None) == "torch.nn.init":
            if "tensor" in kwargs:
                return kwargs["tensor"]
            return args[0]
        return func(*args, **kwargs)


def _materialize(module: Module, device: _DEVICE) -> None:
    """Materialize a module."""
    module.to_empty(device=device, recurse=False)
    if not hasattr(module, "reset_parameters"):
        raise TypeError(
            f"Materialization requires that the `{type(module).__name__}.reset_parameters` method is implemented."
            " This method is used to initialize any children parameters or buffers in this module."
        )
    if callable(module.reset_parameters):
        module.reset_parameters()


def _materialize_meta_tensors(module: Module, device: _DEVICE) -> None:
    """Materialize all tensors in a given module."""
    for submodule in module.modules():
        if _has_meta_device_parameters_or_buffers(submodule, recurse=False):
            _materialize(submodule, device)


def _materialize_distributed_module(module: Module, device: torch.device) -> None:
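    """Materialize a distributed or wrapped module whose parameters or buffers still live on the meta device.

    The root module is moved to ``device`` via ``to_empty()``, then ``reset_parameters()`` is called on every
    submodule that defines it; submodules without a ``reset_parameters()`` method are reported in a rank-zero
    warning because their memory stays uninitialized.
    """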
    # Reference: https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md#meta-device-initialization
    # TODO: Introduce `Fabric.materialize(module)` to give user control when materialization should happen
    # TODO: Make `torchmetrics.Metric` compatible with the `to_empty()` + `reset_parameters()` semantics
    if not _has_meta_device_parameters_or_buffers(module):
        return

    module.to_empty(device=device)  # has to be called on the root module

    uninitialized_modules = set()
    for submodule in module.modules():
        if all(False for _ in itertools.chain(submodule.parameters(recurse=False), submodule.buffers(recurse=False))):
            # the submodule has no parameters or buffers of its own
            continue
        if callable(reset_method := getattr(submodule, "reset_parameters", None)):
            reset_method()
        else:
            uninitialized_modules.add(type(submodule).__name__)

    if uninitialized_modules:
        rank_zero_warn(
            "Parameter initialization incomplete. The following modules have parameters or buffers with uninitialized"
            " memory because they don't define a `reset_parameters()` method for re-initialization:"
            f" {', '.join(uninitialized_modules)}"
        )


def _has_meta_device_parameters_or_buffers(obj: Union[Module, Optimizer], recurse: bool = True) -> bool:
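    """Check whether a module or optimizer still references parameters or buffers on the meta device.

    Example::

        layer = torch.nn.Linear(2, 2, device="meta")
        assert _has_meta_device_parameters_or_buffers(layer)

    """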
    if isinstance(obj, Optimizer):
        return any(
            t.is_meta for param_group in obj.param_groups for t in param_group["params"] if isinstance(t, Parameter)
        )
    if isinstance(obj, Module):
        return any(t.is_meta for t in itertools.chain(obj.parameters(recurse=recurse), obj.buffers(recurse=recurse)))
    raise TypeError(f"Expected `torch.nn.Module` or `torch.optim.Optimizer`, got: {type(obj).__name__}")