jamtur01's picture
Upload folder using huggingface_hub
9c6594c verified
# Copyright 2023 MathInf GmbH
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this files from this repository except in compliance
# with the License reproduced below (also at
# http://www.apache.org/licenses/LICENSE-2.0).
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pickle
import warnings
from collections import OrderedDict
from collections.abc import Sequence
from functools import partial
from io import BytesIO
from pathlib import Path
from typing import IO, TYPE_CHECKING, Any, Callable, Optional, Union
import torch
from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor
from torch._C import _TensorMeta
from torch.nn import Parameter
from typing_extensions import override
from lightning_fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3
from lightning_fabric.utilities.types import _PATH, _Stateful
# Name of the extra file Fabric stores inside a distributed checkpoint folder, holding user data kept separate
# from the weights and optimizer states (read back by ``_load_distributed_checkpoint``).
_METADATA_FILENAME = "meta.pt"
if TYPE_CHECKING:
    # Only needed for annotations; the functions that use ``TypedStorage`` at runtime import it locally.
    from torch.storage import TypedStorage
# Modified from https://github.com/lernapparat/torchhacks by Thomas Viehmann
class _NotYetLoadedTensor:
    """Placeholder for a checkpoint tensor whose data has not been read from the archive yet.

    Instances are created by ``_LazyLoadingUnpickler`` while unpickling. Each one holds a tensor rebuilt on a
    ``meta`` storage (``metatensor``), which answers metadata queries such as dtype and shape without touching the
    data. It also holds enough information to read the real storage on demand through :meth:`_load_tensor`.

    Args:
        metatensor: Tensor rebuilt on the ``meta`` storage returned by the unpickler's ``persistent_load``.
        archiveinfo: The unpickler that produced this placeholder; its ``file_reader`` is used to read the data.
        storageinfo: The persistent id of the storage. It is unpacked as (tag, storage class, record key,
            location, number of elements).
        rebuild_args: The arguments after ``storage`` that are forwarded to ``torch._utils._rebuild_tensor_v2``.
    """

    def __init__(
        self,
        metatensor: Tensor,
        archiveinfo: "_LazyLoadingUnpickler",
        storageinfo: tuple,
        rebuild_args: tuple,
    ) -> None:
        self.metatensor = metatensor
        self.archiveinfo = archiveinfo
        self.storageinfo = storageinfo
        self.rebuild_args = rebuild_args

    @classmethod
    def rebuild_from_type_v2(
        cls,
        func: Callable,
        new_type: _TensorMeta,
        args: tuple,
        state: dict,
        *,
        archiveinfo: Optional["_LazyLoadingUnpickler"] = None,
    ) -> Any:
        """Lazy stand-in for ``torch._tensor._rebuild_from_type_v2``.

        If ``func(*args)`` yields a placeholder, the conversion to ``new_type`` (with ``state``) is deferred.
        It is chained onto the placeholder's ``_load_tensor``. Otherwise the regular torch function is used.
        """
        ret = func(*args)
        if isinstance(ret, _NotYetLoadedTensor):
            old_lt = ret._load_tensor

            def _load_tensor() -> Any:
                t = old_lt()
                return torch._tensor._rebuild_from_type_v2(lambda: t, new_type, (), state)

            # Override per instance so the wrapped loader runs when the placeholder is materialized
            ret._load_tensor = _load_tensor  # type: ignore[method-assign]
            return ret
        return torch._tensor._rebuild_from_type_v2(func, new_type, args, state)

    @classmethod
    def rebuild_parameter(
        cls,
        data: Any,
        requires_grad: bool,
        backward_hooks: OrderedDict,
        *,
        archiveinfo: Optional["_LazyLoadingUnpickler"] = None,
    ) -> Union[Tensor, "_NotYetLoadedTensor"]:
        """Lazy stand-in for ``torch._utils._rebuild_parameter``.

        For a placeholder, wrapping into a ``Parameter`` is deferred until materialization. Otherwise the regular
        torch function is used.
        """
        if isinstance(data, _NotYetLoadedTensor):
            old_lt = data._load_tensor

            def _load_tensor() -> Parameter:
                t = old_lt()
                return torch._utils._rebuild_parameter(t, requires_grad, backward_hooks)

            data._load_tensor = _load_tensor  # type: ignore[method-assign]
            return data
        return torch._utils._rebuild_parameter(data, requires_grad, backward_hooks)

    @classmethod
    def rebuild_tensor_v2(
        cls,
        storage: "TypedStorage",
        storage_offset: int,
        size: tuple,
        stride: tuple,
        requires_grad: bool,
        backward_hooks: OrderedDict,
        metadata: Optional[Any] = None,
        *,
        archiveinfo: "_LazyLoadingUnpickler",
    ) -> "_NotYetLoadedTensor":
        """Lazy stand-in for ``torch._utils._rebuild_tensor_v2``.

        Builds the metadata-only tensor from the storage created by ``persistent_load`` and records what is needed
        to rebuild the real tensor later.
        """
        rebuild_args = (storage_offset, size, stride, requires_grad, backward_hooks, metadata)
        metatensor = torch._utils._rebuild_tensor_v2(
            storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata
        )
        # ``archiveinfo`` was attached to the storage by the unpickler's ``persistent_load`` (its persistent id)
        storageinfo = storage.archiveinfo
        return _NotYetLoadedTensor(metatensor, archiveinfo, storageinfo, rebuild_args)

    def _load_tensor(self) -> Tensor:
        """Read this tensor's storage from the archive and rebuild the real tensor."""
        from torch.storage import TypedStorage, UntypedStorage

        # storageinfo layout: (tag, storage class, record key, location, number of elements)
        _, _, fn, _, size = self.storageinfo
        dtype = self.metatensor.dtype
        storage = self.archiveinfo.file_reader.get_storage_from_record(
            f"data/{fn}", size * torch._utils._element_size(dtype), UntypedStorage
        )
        uts = storage._typed_storage()._untyped_storage
        with warnings.catch_warnings():
            # The TypedStorage APIs have heavy deprecations in torch, suppress all these warnings for now
            warnings.simplefilter("ignore")
            storage = TypedStorage(wrap_storage=uts, dtype=dtype, _internal=True)
        return torch._utils._rebuild_tensor_v2(storage, *self.rebuild_args)

    @classmethod
    def __torch_function__(
        cls,
        func: Callable,
        types: Sequence,
        args: Sequence[Any] = (),
        kwargs: Optional[dict] = None,
    ) -> Any:
        """Materialize every placeholder argument, then call ``func`` with the real tensors.

        Placeholders are replaced in positional and keyword arguments. This includes placeholders inside plain
        lists and tuples, such as the tensor sequence passed to ``torch.cat`` or ``torch.stack``.
        """

        def _materialize(obj: Any) -> Any:
            if isinstance(obj, _NotYetLoadedTensor):
                return obj._load_tensor()
            # Only rebuild exact list/tuple containers; subclasses (``torch.Size``, named tuples) are left as-is
            if type(obj) is list or type(obj) is tuple:
                return type(obj)(_materialize(item) for item in obj)
            return obj

        kwargs = kwargs or {}
        loaded_args = [_materialize(arg) for arg in args]
        loaded_kwargs = {key: _materialize(value) for key, value in kwargs.items()}
        return func(*loaded_args, **loaded_kwargs)

    @property
    def device(self) -> torch.device:
        """Device recorded in the checkpoint for this tensor's storage (no materialization needed)."""
        return torch.device(self.storageinfo[3])

    def __getattr__(self, name: str) -> Any:
        # These properties don't require materialization and can be accessed through the meta tensor directly
        if name in {
            "dtype",
            "grad",
            "grad_fn",
            "is_meta",
            "layout",
            "names",
            "ndim",
            "output_nr",
            "requires_grad",
            "retains_grad",
            "size",
            "shape",
            "volatile",
        }:
            return getattr(self.metatensor, name)
        # materializing these is needed for quantization (see lit-gpt)
        if name in {"contiguous", "cuda", "half", "data", "to"}:
            return getattr(self._load_tensor(), name)
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({repr(self.metatensor)})"
# Modified from https://github.com/lernapparat/torchhacks by Thomas Viehmann
class _LazyLoadingUnpickler(pickle.Unpickler):
    """Unpickler that rebuilds tensors as ``_NotYetLoadedTensor`` placeholders instead of reading their data.

    Args:
        file: Binary stream with the pickled object (e.g. the ``data.pkl`` record of a checkpoint archive).
        file_reader: Archive reader, kept on the instance so placeholders can read their storages later.
    """

    def __init__(self, file: IO, file_reader: torch.PyTorchFileReader) -> None:
        super().__init__(file)
        self.file_reader = file_reader

    @override
    def find_class(self, module: str, name: str) -> Any:
        """Route torch's tensor/parameter rebuild functions to their lazy counterparts."""
        lazy_rebuilders = {
            ("torch._utils", "_rebuild_tensor_v2"): _NotYetLoadedTensor.rebuild_tensor_v2,
            ("torch._tensor", "_rebuild_from_type_v2"): _NotYetLoadedTensor.rebuild_from_type_v2,
            ("torch._utils", "_rebuild_parameter"): _NotYetLoadedTensor.rebuild_parameter,
        }
        rebuilder = lazy_rebuilders.get((module, name))
        if rebuilder is None:
            # NOTE: every other global resolves through standard pickle lookup, so this loader is exactly as
            # (un)safe as plain pickle -- only use it on trusted checkpoint files.
            return super().find_class(module, name)
        return partial(rebuilder, archiveinfo=self)

    @override
    def persistent_load(self, pid: tuple) -> "TypedStorage":
        """Return an empty ``meta`` storage tagged with its persistent id instead of loading real data."""
        from torch.storage import TypedStorage

        _tag, storage_cls, _key, _location, _numel = pid
        with warnings.catch_warnings():
            # The TypedStorage APIs have heavy deprecations in torch, suppress all these warnings for now
            warnings.simplefilter("ignore")
            meta_storage = TypedStorage(dtype=storage_cls().dtype, device="meta")
        # Read back by ``_NotYetLoadedTensor.rebuild_tensor_v2`` to locate the record in the archive
        meta_storage.archiveinfo = pid
        return meta_storage
def _lazy_load(filename: _PATH) -> Any:
    """Unpickle a ``torch.save`` checkpoint file without reading any tensor data.

    Tensors come back as ``_NotYetLoadedTensor`` placeholders that read their data on first use.

    Args:
        filename: Path to the checkpoint file.

    Raises:
        FileNotFoundError: If ``filename`` does not point to an existing file.
    """
    if not os.path.isfile(filename):
        raise FileNotFoundError(f"Path {str(filename)!r} does not exist or is not a file.")
    reader = torch.PyTorchFileReader(str(filename))
    pickled = reader.get_record("data.pkl")
    with BytesIO(pickled) as stream:
        # The reader must outlive this call: placeholders keep a reference to it through the unpickler
        return _LazyLoadingUnpickler(stream, reader).load()
def _materialize_tensors(collection: Any) -> Any:
    """Replace every ``_NotYetLoadedTensor`` found in ``collection`` with the real tensor read from the archive."""
    # Call through the instance: ``rebuild_parameter``/``rebuild_from_type_v2`` install per-instance
    # ``_load_tensor`` overrides, which ``_NotYetLoadedTensor._load_tensor`` (unbound) would bypass.
    return apply_to_collection(collection, dtype=_NotYetLoadedTensor, function=lambda lazy: lazy._load_tensor())
def _move_state_into(
    source: dict[str, Any], destination: dict[str, Union[Any, _Stateful]], keys: Optional[set[str]] = None
) -> None:
    """Takes the state from the source dictionary and moves it into the destination dictionary.

    Moved entries are popped from ``source``. If an object in the destination follows the stateful protocol, it
    loads the source state via ``load_state_dict``. Otherwise the destination entry is set to the source value.

    Args:
        source: Dictionary to take the state from. Moved keys are removed from it.
        destination: Dictionary to move the state into.
        keys: Optional subset of keys to move; keys missing from ``source`` are ignored. If ``None``, every key of
            ``source`` is moved.
    """
    # Follow ``source`` order instead of iterating a set, so ``load_state_dict`` calls happen in a deterministic
    # order (set order of strings varies between interpreter runs due to hash randomization).
    # The list is built up front because ``source`` is mutated inside the loop.
    selected = list(source) if keys is None else [key for key in source if key in keys]
    for key in selected:
        state = source.pop(key)
        if key in destination and isinstance(destination[key], _Stateful):
            destination[key].load_state_dict(state)
        else:
            destination[key] = state
def _load_distributed_checkpoint(checkpoint_folder: _PATH) -> dict[str, Any]:
    """Loads a sharded checkpoint saved with `torch.distributed.checkpoint` into a full state dict.

    The current implementation assumes that the entire checkpoint fits in CPU memory.

    Args:
        checkpoint_folder: Folder containing the sharded checkpoint, as a ``Path`` or ``str``.

    Returns:
        The full state dict, updated with the contents of Fabric's metadata file if that file exists.

    Raises:
        ImportError: If the installed PyTorch is older than 2.3.
    """
    if not _TORCH_GREATER_EQUAL_2_3:
        raise ImportError("Processing distributed checkpoints requires PyTorch >= 2.3.")

    from torch.distributed.checkpoint import FileSystemReader
    from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
    from torch.distributed.checkpoint.state_dict_loader import _load_state_dict

    # Accept plain strings as well: the ``/`` join below requires a ``Path``
    folder = Path(checkpoint_folder)

    checkpoint: dict[str, Any] = {}
    _load_state_dict(
        checkpoint,
        storage_reader=FileSystemReader(folder),
        planner=_EmptyStateDictLoadPlanner(),
        no_dist=True,
    )

    # This is the extra file saved by Fabric, with user data separate from weights and optimizer states
    extra_file = folder / _METADATA_FILENAME
    # NOTE(review): relies on torch.load's default ``weights_only``; on PyTorch versions where that default is
    # True, arbitrary user objects stored in this file may fail to load -- confirm against the save side.
    extra = torch.load(extra_file, map_location="cpu") if extra_file.is_file() else {}
    checkpoint.update(extra)
    return checkpoint