|
|
|
import os |
|
from collections.abc import Sequence |
|
from dataclasses import dataclass, field |
|
from enum import Enum |
|
from typing import Any, Optional, Union |
|
|
|
import torch |
|
from torch.distributed.checkpoint.stateful import StatefulT |
|
|
|
|
|
# Names exported by ``from <this module> import *``; each one is a class
# defined further down in this file.
__all__ = [
    "ChunkStorageMetadata",
    "TensorStorageMetadata",
    "BytesStorageMetadata",
    "Metadata",
    "MetadataIndex",
    "TensorProperties",
    "StorageMeta",
]
|
|
|
|
|
@dataclass
class ChunkStorageMetadata:
    """
    One entry of a ``TensorStorageMetadata.chunks`` list.

    A chunk is expected to share the properties of the TensorStorageMetadata
    that contains it.
    """

    # Presumably the chunk's starting position in the full tensor, one entry
    # per dimension -- TODO confirm against the planner code.
    offsets: torch.Size
    # Presumably the chunk's extent along each dimension -- TODO confirm.
    sizes: torch.Size
|
|
|
|
|
class _MEM_FORMAT_ENCODING(Enum):
    """Describe the memory format of a tensor."""

    # Stand-ins for torch.memory_format values: TensorProperties.__getstate__
    # emits one of these members in place of the memory format, and
    # TensorProperties.__setstate__ maps the member back. The numeric values
    # end up in that state tuple, so they should not be renumbered.
    TORCH_CONTIGUOUS_FORMAT = 0  # torch.contiguous_format
    TORCH_CHANNELS_LAST = 1  # torch.channels_last
    TORCH_PRESERVE_FORMAT = 2  # torch.preserve_format
|
|
|
|
|
@dataclass
class TensorProperties:
    """Tensor-creation properties: dtype, layout, grad flag, memory format, pinning."""

    # Looked up through torch.get_default_dtype each time an instance is
    # created without an explicit dtype.
    dtype: torch.dtype = field(default_factory=torch.get_default_dtype)

    layout: torch.layout = torch.strided

    requires_grad: bool = False

    memory_format: torch.memory_format = torch.contiguous_format

    pin_memory: bool = False

    def __getstate__(self):
        """
        Return the state as ``(dtype, layout, requires_grad, encoding, pin_memory)``.

        ``encoding`` is the ``_MEM_FORMAT_ENCODING`` member standing in for
        ``self.memory_format`` (presumably because torch.memory_format objects
        cannot be pickled directly -- TODO confirm).

        Raises:
            RuntimeError: if ``memory_format`` is not one of
                ``torch.contiguous_format``, ``torch.channels_last`` or
                ``torch.preserve_format``.
        """
        memory_format = self.memory_format
        known_formats = (
            (torch.contiguous_format, _MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT),
            (torch.channels_last, _MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST),
            (torch.preserve_format, _MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT),
        )
        for torch_format, encoded in known_formats:
            if memory_format == torch_format:
                mem_format_encoding = encoded
                break
        else:
            # No table entry matched.
            raise RuntimeError(f"Invalid torch.memory_format: {memory_format}")

        return (
            self.dtype,
            self.layout,
            self.requires_grad,
            mem_format_encoding,
            self.pin_memory,
        )

    def __setstate__(
        self,
        state,
    ):
        """
        Restore the instance from the 5-tuple produced by ``__getstate__``.

        Raises:
            RuntimeError: if the fourth element is not a known
                ``_MEM_FORMAT_ENCODING`` member.
        """
        dtype, layout, requires_grad, mem_format_encoding, pin_memory = state

        # The plain fields are stored before the encoding is validated, so
        # they are already set if the RuntimeError below fires.
        self.dtype = dtype
        self.layout = layout
        self.requires_grad = requires_grad
        self.pin_memory = pin_memory

        known_encodings = (
            (_MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT, torch.contiguous_format),
            (_MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST, torch.channels_last),
            (_MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT, torch.preserve_format),
        )
        for encoded, torch_format in known_encodings:
            if mem_format_encoding == encoded:
                self.memory_format = torch_format
                break
        else:
            raise RuntimeError(
                f"Invalid torch.memory_format encoding: {mem_format_encoding}"
            )

    @staticmethod
    def create_from_tensor(tensor: torch.Tensor) -> "TensorProperties":
        """
        Build a ``TensorProperties`` from ``tensor``.

        dtype, layout, requires_grad and the pinned flag are read from the
        tensor; memory_format is always recorded as ``torch.contiguous_format``.
        """
        captured = {
            "dtype": tensor.dtype,
            "layout": tensor.layout,
            "requires_grad": tensor.requires_grad,
            "memory_format": torch.contiguous_format,
            "pin_memory": tensor.is_pinned(),
        }
        return TensorProperties(**captured)
|
|
|
|
|
@dataclass
class TensorStorageMetadata:
    """Storage metadata for a tensor: its properties, its size, and its chunks."""

    # Creation properties recorded for the tensor.
    properties: TensorProperties
    # Presumably the size of the full tensor, not of one chunk -- TODO confirm.
    size: torch.Size
    # Chunk records belonging to this tensor.
    chunks: list[ChunkStorageMetadata]
|
|
|
|
|
@dataclass
class BytesStorageMetadata:
    """Field-less metadata record; the non-tensor member of ``STORAGE_TYPES``."""
|
|
|
|
|
# Either kind of per-entry storage metadata defined in this module.
STORAGE_TYPES = Union[TensorStorageMetadata, BytesStorageMetadata]
# A state dict: string keys mapped to StatefulT objects or arbitrary values.
STATE_DICT_TYPE = dict[str, Union[StatefulT, Any]]
|
|
|
|
|
@dataclass
class StorageMeta:
    """Optional storage-level details; every field has a default."""

    # Identifier of the checkpoint: a string, a path-like object, or None.
    checkpoint_id: Union[str, os.PathLike, None] = None
    # Presumably identifies the save operation -- TODO confirm with the writer.
    save_id: Optional[str] = None
    # Presumably identifies the load operation -- TODO confirm with the reader.
    load_id: Optional[str] = None
    # Each instance gets its own fresh list.
    modules: list[str] = field(default_factory=list)
|
|
|
|
|
@dataclass
class Metadata:
    """Metadata describing a checkpoint."""

    # Storage metadata per string key; the keys presumably match those of the
    # saved state dict -- TODO confirm.
    state_dict_metadata: dict[str, STORAGE_TYPES]

    # Untyped payloads, presumably owned by the planner and the storage layer
    # respectively -- TODO confirm what each side stores here.
    planner_data: Any = None
    storage_data: Any = None
    # Optional storage-level details.
    storage_meta: Optional[StorageMeta] = None
|
|
|
|
|
@dataclass(frozen=True)
class MetadataIndex:
    """Lookup key identifying an item in a state dict or in Metadata."""

    # Fully Qualified Name of the object.
    fqn: str

    # When the object is a tensor: the offset into the tensor being looked up.
    offset: Optional[torch.Size] = None

    # Optional index hint that speeds up the search for a tensor chunk.
    #
    # A sharded tensor is commonly represented as a list of chunks, so finding
    # a chunk's position in that list needs a linear search.
    #
    # When building a MetadataIndex that points into such a list, the index
    # can be supplied as a hint; it is probed first, before the linear search,
    # which makes the lookup significantly faster.
    #
    # The hint takes no part in equality or hashing.
    index: Optional[int] = field(hash=False, compare=False, default=None)

    def __init__(
        self,
        fqn: str,
        offset: Optional[Sequence[int]] = None,
        index: Optional[int] = None,
    ):
        """
        Build the key, converting ``offset`` to ``torch.Size`` when it is given.

        Args:
            fqn: fully qualified name of the object.
            offset: optional sequence of ints; stored as ``torch.Size(offset)``.
            index: optional position hint.
        """
        # The dataclass is frozen, so normal attribute assignment is blocked;
        # go through object.__setattr__ instead.
        assign = object.__setattr__
        assign(self, "fqn", fqn)
        assign(self, "index", index)
        if offset is None:
            # Nothing is stored on the instance; the class-level default
            # (None) stays visible.
            return
        assign(self, "offset", torch.Size(offset))
|
|