# Copyright 2023 MathInf GmbH
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use the files from this repository except in compliance
# with the License reproduced below (also at
# http://www.apache.org/licenses/LICENSE-2.0).
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pickle
import warnings
from collections import OrderedDict
from collections.abc import Sequence
from functools import partial
from io import BytesIO
from pathlib import Path
from typing import IO, TYPE_CHECKING, Any, Callable, Optional, Union

import torch
from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor
from torch._C import _TensorMeta
from torch.nn import Parameter
from typing_extensions import override

from lightning_fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3
from lightning_fabric.utilities.types import _PATH, _Stateful

_METADATA_FILENAME = "meta.pt"


if TYPE_CHECKING:
    from torch.storage import TypedStorage


# Modified from https://github.com/lernapparat/torchhacks by Thomas Viehmann
class _NotYetLoadedTensor:
    """Placeholder for a tensor in the checkpoint whose data has not been read from the archive yet.

    Wraps a meta tensor together with the storage info and rebuild arguments needed to materialize the actual
    tensor on first use.

    """

    def __init__(
        self,
        metatensor: Tensor,
        archiveinfo: "_LazyLoadingUnpickler",
        storageinfo: tuple,
        rebuild_args: tuple,
    ) -> None:
        self.metatensor = metatensor
        self.archiveinfo = archiveinfo
        self.storageinfo = storageinfo
        self.rebuild_args = rebuild_args

    @classmethod
    def rebuild_from_type_v2(
        cls,
        func: Callable,
        new_type: _TensorMeta,
        args: tuple,
        state: dict,
        *,
        archiveinfo: Optional["_LazyLoadingUnpickler"] = None,
    ) -> Any:
        ret = func(*args)
        if isinstance(ret, _NotYetLoadedTensor):
            old_lt = ret._load_tensor

            def _load_tensor() -> Any:
                t = old_lt()
                return torch._tensor._rebuild_from_type_v2(lambda: t, new_type, (), state)

            ret._load_tensor = _load_tensor  # type: ignore[method-assign]
            return ret
        return torch._tensor._rebuild_from_type_v2(func, new_type, args, state)

    @classmethod
    def rebuild_parameter(
        cls,
        data: Any,
        requires_grad: bool,
        backward_hooks: OrderedDict,
        *,
        archiveinfo: Optional["_LazyLoadingUnpickler"] = None,
    ) -> Union[Tensor, "_NotYetLoadedTensor"]:
        if isinstance(data, _NotYetLoadedTensor):
            old_lt = data._load_tensor

            def _load_tensor() -> Parameter:
                t = old_lt()
                return torch._utils._rebuild_parameter(t, requires_grad, backward_hooks)

            data._load_tensor = _load_tensor  # type: ignore[method-assign]
            return data
        return torch._utils._rebuild_parameter(data, requires_grad, backward_hooks)

    @classmethod
    def rebuild_tensor_v2(
        cls,
        storage: "TypedStorage",
        storage_offset: int,
        size: tuple,
        stride: tuple,
        requires_grad: bool,
        backward_hooks: OrderedDict,
        metadata: Optional[Any] = None,
        *,
        archiveinfo: "_LazyLoadingUnpickler",
    ) -> "_NotYetLoadedTensor":
        rebuild_args = (storage_offset, size, stride, requires_grad, backward_hooks, metadata)
        metatensor = torch._utils._rebuild_tensor_v2(
            storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata
        )
        storageinfo = storage.archiveinfo
        return _NotYetLoadedTensor(metatensor, archiveinfo, storageinfo, rebuild_args)

    def _load_tensor(self) -> Tensor:
        from torch.storage import TypedStorage, UntypedStorage

        _, _, fn, _, size = self.storageinfo
        dtype = self.metatensor.dtype

        storage = self.archiveinfo.file_reader.get_storage_from_record(
            f"data/{fn}", size * torch._utils._element_size(dtype), UntypedStorage
        )
        uts = storage._typed_storage()._untyped_storage

        with warnings.catch_warnings():
            # The TypedStorage APIs are heavily deprecated in torch; suppress these warnings for now
            warnings.simplefilter("ignore")
            storage = TypedStorage(wrap_storage=uts, dtype=dtype, _internal=True)
        return torch._utils._rebuild_tensor_v2(storage, *self.rebuild_args)

    @classmethod
    def __torch_function__(
        cls,
        func: Callable,
        types: Sequence,
        args: Sequence[Any] = (),
        kwargs: Optional[dict] = None,
    ) -> Any:
        kwargs = kwargs or {}
        loaded_args = [(arg._load_tensor() if isinstance(arg, _NotYetLoadedTensor) else arg) for arg in args]
        return func(*loaded_args, **kwargs)

    @property
    def device(self) -> torch.device:
        return torch.device(self.storageinfo[3])

    def __getattr__(self, name: str) -> Any:
        # These properties don't require materialization and can be accessed through the meta tensor directly
        if name in {
            "dtype",
            "grad",
            "grad_fn",
            "is_meta",
            "layout",
            "names",
            "ndim",
            "output_nr",
            "requires_grad",
            "retains_grad",
            "size",
            "shape",
            "volatile",
        }:
            return getattr(self.metatensor, name)

        # materializing these is needed for quantization (see lit-gpt)
        if name in {"contiguous", "cuda", "half", "data", "to"}:
            return getattr(self._load_tensor(), name)

        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({repr(self.metatensor)})"

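
# A rough sketch of the laziness contract above (``state_dict`` and ``"weight"`` are hypothetical names): metadata
# lookups are answered by the meta tensor for free, while torch functions and the methods listed in ``__getattr__``
# trigger a read from the archive.
#
#   lazy = state_dict["weight"]   # a _NotYetLoadedTensor produced by _lazy_load below
#   lazy.shape, lazy.dtype        # cheap: served by the meta tensor, no data is read
#   lazy.to("cpu")                # materializes: __getattr__ routes "to" through _load_tensor()
#   torch.norm(lazy)              # materializes: __torch_function__ loads the tensor first
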

# Modified from https://github.com/lernapparat/torchhacks by Thomas Viehmann
class _LazyLoadingUnpickler(pickle.Unpickler):
    """Unpickler that rebuilds tensors as ``_NotYetLoadedTensor`` placeholders instead of loading their data."""

    def __init__(self, file: IO, file_reader: torch.PyTorchFileReader) -> None:
        super().__init__(file)
        self.file_reader = file_reader

    @override
    def find_class(self, module: str, name: str) -> Any:
        if module == "torch._utils" and name == "_rebuild_tensor_v2":
            return partial(_NotYetLoadedTensor.rebuild_tensor_v2, archiveinfo=self)
        if module == "torch._tensor" and name == "_rebuild_from_type_v2":
            return partial(_NotYetLoadedTensor.rebuild_from_type_v2, archiveinfo=self)
        if module == "torch._utils" and name == "_rebuild_parameter":
            return partial(_NotYetLoadedTensor.rebuild_parameter, archiveinfo=self)
        return super().find_class(module, name)

    @override
    def persistent_load(self, pid: tuple) -> "TypedStorage":
        from torch.storage import TypedStorage

        _, cls, _, _, _ = pid
        with warnings.catch_warnings():
            # The TypedStorage APIs are heavily deprecated in torch; suppress these warnings for now
            warnings.simplefilter("ignore")
            storage = TypedStorage(dtype=cls().dtype, device="meta")
        storage.archiveinfo = pid
        return storage


def _lazy_load(filename: _PATH) -> Any:
    """Loads a ``torch.save`` checkpoint lazily, deferring each tensor read until the tensor is first accessed."""
    if not os.path.isfile(filename):
        raise FileNotFoundError(f"Path {str(filename)!r} does not exist or is not a file.")
    file_reader = torch.PyTorchFileReader(str(filename))
    with BytesIO(file_reader.get_record("data.pkl")) as pkl:
        mup = _LazyLoadingUnpickler(pkl, file_reader)
        return mup.load()

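
# A minimal usage sketch, assuming ``path/to/checkpoint.pt`` was written with ``torch.save`` and the keys below are
# hypothetical: the returned collection mirrors the saved one, but every tensor in it is a ``_NotYetLoadedTensor``
# until first used.
#
#   checkpoint = _lazy_load("path/to/checkpoint.pt")
#   weight = checkpoint["state_dict"]["layer.weight"]  # no tensor data has been read yet
#   print(weight.shape)                                # metadata only, still lazy
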

def _materialize_tensors(collection: Any) -> Any:
    """Materializes all ``_NotYetLoadedTensor`` placeholders in the collection into regular tensors."""
    def _load_tensor(t: _NotYetLoadedTensor) -> Tensor:
        return t._load_tensor()

    return apply_to_collection(collection, dtype=_NotYetLoadedTensor, function=_load_tensor)

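
# Sketch of eagerly forcing a lazily loaded checkpoint into memory (paths hypothetical), for downstream code that
# expects plain tensors instead of the placeholders returned by ``_lazy_load``:
#
#   lazy_checkpoint = _lazy_load("path/to/checkpoint.pt")
#   checkpoint = _materialize_tensors(lazy_checkpoint)  # every _NotYetLoadedTensor becomes a regular Tensor
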

def _move_state_into(
    source: dict[str, Any], destination: dict[str, Union[Any, _Stateful]], keys: Optional[set[str]] = None
) -> None:
    """Takes the state from the source destination and moves it into the destination dictionary.

    If an object in the destination follows the stateful protocol, it loads the source state via ``load_state_dict``.

    """
    keys = set(source) if keys is None else keys & set(source)
    for key in keys:
        state = source.pop(key)
        if key in destination and isinstance(destination[key], _Stateful):
            destination[key].load_state_dict(state)
        else:
            destination[key] = state

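
# Hedged example (``model`` is any object implementing the ``_Stateful`` protocol, i.e. ``state_dict`` and
# ``load_state_dict``): stateful destination entries consume the state, plain entries get overwritten, and
# processed keys are popped from ``source``.
#
#   source = {"model": saved_model_state, "step": 100}
#   destination = {"model": model, "step": 0}
#   _move_state_into(source, destination)
#   # now: model.load_state_dict(saved_model_state) was called and destination["step"] == 100
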

def _load_distributed_checkpoint(checkpoint_folder: Path) -> dict[str, Any]:
    """Loads a sharded checkpoint saved with the `torch.distributed.checkpoint` into a full state dict.

    The current implementation assumes that the entire checkpoint fits in CPU memory.

    """
    if not _TORCH_GREATER_EQUAL_2_3:
        raise ImportError("Processing distributed checkpoints requires PyTorch >= 2.3.")

    from torch.distributed.checkpoint import FileSystemReader
    from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
    from torch.distributed.checkpoint.state_dict_loader import _load_state_dict

    checkpoint: dict[str, Any] = {}
    _load_state_dict(
        checkpoint,
        storage_reader=FileSystemReader(checkpoint_folder),
        planner=_EmptyStateDictLoadPlanner(),
        no_dist=True,
    )

    # This is the extra file saved by Fabric, with user data separate from weights and optimizer states
    extra_file = checkpoint_folder / _METADATA_FILENAME
    extra = torch.load(extra_file, map_location="cpu") if extra_file.is_file() else {}
    checkpoint.update(extra)

    return checkpoint
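

# A possible consolidation flow (paths hypothetical): read a folder written by ``torch.distributed.checkpoint``
# into a single state dict and re-save it as a regular checkpoint file. Requires PyTorch >= 2.3 and enough CPU
# memory to hold the full checkpoint.
#
#   checkpoint = _load_distributed_checkpoint(Path("checkpoints/step-1000"))
#   torch.save(checkpoint, "checkpoints/step-1000.consolidated.pt")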