import os
import pickle
import warnings
from collections import OrderedDict
from collections.abc import Sequence
from functools import partial
from io import BytesIO
from pathlib import Path
from typing import IO, TYPE_CHECKING, Any, Callable, Optional, Union

import torch
from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor
from torch._C import _TensorMeta
from torch.nn import Parameter
from typing_extensions import override

from lightning_fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3
from lightning_fabric.utilities.types import _PATH, _Stateful

_METADATA_FILENAME = "meta.pt"


if TYPE_CHECKING:
    from torch.storage import TypedStorage


class _NotYetLoadedTensor:
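    """Placeholder for a tensor that has not been read from the checkpoint file yet.

    Wraps a meta-device tensor so that shape and dtype queries stay cheap, and only reads the real data from the
    archive once the tensor is actually used.
    """
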
    def __init__(
        self,
        metatensor: Tensor,
        archiveinfo: "_LazyLoadingUnpickler",
        storageinfo: tuple,
        rebuild_args: tuple,
    ) -> None:
        self.metatensor = metatensor
        self.archiveinfo = archiveinfo
        self.storageinfo = storageinfo
        self.rebuild_args = rebuild_args

    @classmethod
    def rebuild_from_type_v2(
        cls,
        func: Callable,
        new_type: _TensorMeta,
        args: tuple,
        state: dict,
        *,
        archiveinfo: Optional["_LazyLoadingUnpickler"] = None,
    ) -> Any:
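        # If the inner rebuild produced a lazy placeholder, defer applying the subclass type and state until the
        # tensor is materialized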
        ret = func(*args)
        if isinstance(ret, _NotYetLoadedTensor):
            old_lt = ret._load_tensor

            def _load_tensor() -> Any:
                t = old_lt()
                return torch._tensor._rebuild_from_type_v2(lambda: t, new_type, (), state)

            ret._load_tensor = _load_tensor
            return ret
        return torch._tensor._rebuild_from_type_v2(func, new_type, args, state)

    @classmethod
    def rebuild_parameter(
        cls,
        data: Any,
        requires_grad: bool,
        backward_hooks: OrderedDict,
        *,
        archiveinfo: Optional["_LazyLoadingUnpickler"] = None,
    ) -> Union[Tensor, "_NotYetLoadedTensor"]:
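        # If the data is still lazy, defer turning it into a ``Parameter`` until the tensor is materialized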
        if isinstance(data, _NotYetLoadedTensor):
            old_lt = data._load_tensor

            def _load_tensor() -> Parameter:
                t = old_lt()
                return torch._utils._rebuild_parameter(t, requires_grad, backward_hooks)

            data._load_tensor = _load_tensor
            return data
        return torch._utils._rebuild_parameter(data, requires_grad, backward_hooks)

    @classmethod
    def rebuild_tensor_v2(
        cls,
        storage: "TypedStorage",
        storage_offset: int,
        size: tuple,
        stride: tuple,
        requires_grad: bool,
        backward_hooks: OrderedDict,
        metadata: Optional[Any] = None,
        *,
        archiveinfo: "_LazyLoadingUnpickler",
    ) -> "_NotYetLoadedTensor":
        rebuild_args = (storage_offset, size, stride, requires_grad, backward_hooks, metadata)
        metatensor = torch._utils._rebuild_tensor_v2(
            storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata
        )
        storageinfo = storage.archiveinfo
        return _NotYetLoadedTensor(metatensor, archiveinfo, storageinfo, rebuild_args)

    def _load_tensor(self) -> Tensor:
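        """Read the tensor's storage out of the checkpoint archive and rebuild the actual tensor."""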
        from torch.storage import TypedStorage, UntypedStorage

        _, _, fn, _, size = self.storageinfo
        dtype = self.metatensor.dtype

        storage = self.archiveinfo.file_reader.get_storage_from_record(
            f"data/{fn}", size * torch._utils._element_size(dtype), UntypedStorage
        )
        uts = storage._typed_storage()._untyped_storage

        with warnings.catch_warnings():
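            # Silence the deprecation warning raised when constructing ``TypedStorage`` directly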
            warnings.simplefilter("ignore")
            storage = TypedStorage(wrap_storage=uts, dtype=dtype, _internal=True)
        return torch._utils._rebuild_tensor_v2(storage, *self.rebuild_args)

    @classmethod
    def __torch_function__(
        cls,
        func: Callable,
        types: Sequence,
        args: Sequence[Any] = (),
        kwargs: Optional[dict] = None,
    ) -> Any:
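        # Materialize all lazy placeholders among the arguments before dispatching the actual torch function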
        kwargs = kwargs or {}
        loaded_args = [(arg._load_tensor() if isinstance(arg, _NotYetLoadedTensor) else arg) for arg in args]
        return func(*loaded_args, **kwargs)

    @property
    def device(self) -> torch.device:
        return torch.device(self.storageinfo[3])

    def __getattr__(self, name: str) -> Any:
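        # These attributes can be served by the meta tensor without loading any data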
        if name in {
            "dtype",
            "grad",
            "grad_fn",
            "is_meta",
            "layout",
            "names",
            "ndim",
            "output_nr",
            "requires_grad",
            "retains_grad",
            "size",
            "shape",
            "volatile",
        }:
            return getattr(self.metatensor, name)
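
        # These attributes need the actual data, so the tensor gets loaded from the file first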
        if name in {"contiguous", "cuda", "half", "data", "to"}:
            return getattr(self._load_tensor(), name)

        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({repr(self.metatensor)})"


class _LazyLoadingUnpickler(pickle.Unpickler):
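    """An unpickler that intercepts tensor rebuilds and returns ``_NotYetLoadedTensor`` placeholders instead."""
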
    def __init__(self, file: IO, file_reader: torch.PyTorchFileReader) -> None:
        super().__init__(file)
        self.file_reader = file_reader

    @override
    def find_class(self, module: str, name: str) -> Any:
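        # Substitute PyTorch's tensor-rebuild functions with the lazy variants defined above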
        if module == "torch._utils" and name == "_rebuild_tensor_v2":
            return partial(_NotYetLoadedTensor.rebuild_tensor_v2, archiveinfo=self)
        if module == "torch._tensor" and name == "_rebuild_from_type_v2":
            return partial(_NotYetLoadedTensor.rebuild_from_type_v2, archiveinfo=self)
        if module == "torch._utils" and name == "_rebuild_parameter":
            return partial(_NotYetLoadedTensor.rebuild_parameter, archiveinfo=self)
        return super().find_class(module, name)

    @override
    def persistent_load(self, pid: tuple) -> "TypedStorage":
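        # Return a meta-device storage stub and attach the archive record info so the data can be fetched later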
        from torch.storage import TypedStorage

        _, cls, _, _, _ = pid
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            storage = TypedStorage(dtype=cls().dtype, device="meta")
        storage.archiveinfo = pid
        return storage


def _lazy_load(filename: _PATH) -> Any:
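    """Deserialize a checkpoint without loading any tensor data into memory.

    Tensors in the returned object are ``_NotYetLoadedTensor`` placeholders that read their data from the file on
    first use. A minimal usage sketch (the path and keys are illustrative)::

        checkpoint = _lazy_load("path/to/checkpoint.ckpt")
        weight = checkpoint["state_dict"]["weight"]  # still a placeholder, no data read yet
        weight = weight.to("cpu")  # the tensor data gets read from the file here

    """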
    if not os.path.isfile(filename):
        raise FileNotFoundError(f"Path {str(filename)!r} does not exist or is not a file.")
    file_reader = torch.PyTorchFileReader(str(filename))
    with BytesIO(file_reader.get_record("data.pkl")) as pkl:
        mup = _LazyLoadingUnpickler(pkl, file_reader)
        return mup.load()


def _materialize_tensors(collection: Any) -> Any:
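    """Recursively replace all ``_NotYetLoadedTensor`` instances in a collection with the loaded tensors."""
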
    def _load_tensor(t: _NotYetLoadedTensor) -> Tensor:
        return t._load_tensor()

    return apply_to_collection(collection, dtype=_NotYetLoadedTensor, function=_load_tensor)


def _move_state_into(
    source: dict[str, Any], destination: dict[str, Union[Any, _Stateful]], keys: Optional[set[str]] = None
) -> None:
    """Takes the state from the source dictionary and moves it into the destination dictionary.

    If an object in the destination follows the stateful protocol, it loads the source state via ``load_state_dict``.

    """
    keys = set(source) if keys is None else keys & set(source)
    for key in keys:
        state = source.pop(key)
        if key in destination and isinstance(destination[key], _Stateful):
            destination[key].load_state_dict(state)
        else:
            destination[key] = state


def _load_distributed_checkpoint(checkpoint_folder: Path) -> dict[str, Any]:
    """Loads a sharded checkpoint saved with ``torch.distributed.checkpoint`` into a full state dict.

    The current implementation assumes that the entire checkpoint fits in CPU memory.

    """
    if not _TORCH_GREATER_EQUAL_2_3:
        raise ImportError("Processing distributed checkpoints requires PyTorch >= 2.3.")

    from torch.distributed.checkpoint import FileSystemReader
    from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
    from torch.distributed.checkpoint.state_dict_loader import _load_state_dict

    checkpoint: dict[str, Any] = {}
    _load_state_dict(
        checkpoint,
        storage_reader=FileSystemReader(checkpoint_folder),
        planner=_EmptyStateDictLoadPlanner(),
        no_dist=True,
    )
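
    # Merge in extra, non-sharded metadata saved alongside the shards, if present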
    extra_file = checkpoint_folder / _METADATA_FILENAME
    extra = torch.load(extra_file, map_location="cpu") if extra_file.is_file() else {}
    checkpoint.update(extra)

    return checkpoint