import itertools
from collections.abc import Sequence
from typing import Any, Callable, Optional, Union

import torch
from torch.nn import Module, Parameter
from torch.optim import Optimizer
from torch.overrides import TorchFunctionMode
from typing_extensions import override

from lightning.fabric.utilities.rank_zero import rank_zero_warn
from lightning.fabric.utilities.types import _DEVICE


class _EmptyInit(TorchFunctionMode):
    """Initialize `nn.Module` with empty tensors, i.e., uninitialized memory.

    Example::

        with _EmptyInit():
            model = BigModel()
        model.load_state_dict(torch.load("checkpoint.pt"))

    """

    def __init__(self, enabled: bool = True) -> None:
        super().__init__()
        self.enabled = enabled

    @override
    def __torch_function__(
        self,
        func: Callable,
        types: Sequence,
        args: Sequence[Any] = (),
        kwargs: Optional[dict] = None,
    ) -> Any:
        kwargs = kwargs or {}
        if not self.enabled:
            return func(*args, **kwargs)
        if getattr(func, "__module__", None) == "torch.nn.init":
            # Intercept initializers from `torch.nn.init` and return the input tensor untouched,
            # leaving its memory uninitialized
            if "tensor" in kwargs:
                return kwargs["tensor"]
            return args[0]
        return func(*args, **kwargs)
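
# Illustrative sketch (not part of this module's API): with the mode active, functions from
# `torch.nn.init` become no-ops that simply return their input tensor, so module constructors
# allocate memory but skip the initialization work:
#
#     t = torch.empty(2, 2)
#     with _EmptyInit():
#         torch.nn.init.uniform_(t)  # returns `t` unchanged; no random fill happens
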


def _materialize(module: Module, device: _DEVICE) -> None:
    """Materialize a module."""
    module.to_empty(device=device, recurse=False)
    if not hasattr(module, "reset_parameters"):
        raise TypeError(
            f"Materialization requires that the `{type(module).__name__}.reset_parameters` method is implemented."
            " This method is used to initialize any children parameters or buffers in this module."
        )
    if callable(module.reset_parameters):
        module.reset_parameters()
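
# Hypothetical usage sketch: materialize a single meta-device module in place. This assumes the
# module implements `reset_parameters()`; otherwise `_materialize` raises a `TypeError`:
#
#     with torch.device("meta"):
#         layer = torch.nn.Linear(4, 4)  # no real memory is allocated here
#     _materialize(layer, "cpu")         # allocates storage and re-runs `reset_parameters()`
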


def _materialize_meta_tensors(module: Module, device: _DEVICE) -> None:
    """Materialize all tensors in a given module."""
    for submodule in module.modules():
        if _has_meta_device_parameters_or_buffers(submodule, recurse=False):
            _materialize(submodule, device)
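
# Hypothetical usage sketch: the recursive variant handles container modules by materializing
# every submodule that directly owns meta-device parameters or buffers:
#
#     with torch.device("meta"):
#         model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
#     _materialize_meta_tensors(model, "cpu")
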


def _materialize_distributed_module(module: Module, device: torch.device) -> None:
    """Materialize a module whose parameters or buffers still live on the meta device.

    `to_empty()` allocates real storage, and every submodule that directly owns parameters or buffers is then
    re-initialized through its `reset_parameters()` method.

    """
    if not _has_meta_device_parameters_or_buffers(module):
        return

    # Allocate real (uninitialized) storage for every meta tensor in the module tree in one pass
    module.to_empty(device=device)

    uninitialized_modules = set()
    for submodule in module.modules():
        if all(False for _ in itertools.chain(submodule.parameters(recurse=False), submodule.buffers(recurse=False))):
            # The submodule owns no parameters or buffers directly, so there is nothing to re-initialize
            continue
        if callable(reset_method := getattr(submodule, "reset_parameters", None)):
            reset_method()
        else:
            uninitialized_modules.add(type(submodule).__name__)

    if uninitialized_modules:
        rank_zero_warn(
            "Parameter initialization incomplete. The following modules have parameters or buffers with uninitialized"
            " memory because they don't define a `reset_parameters()` method for re-initialization:"
            f" {', '.join(uninitialized_modules)}"
        )
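
# Hypothetical usage sketch for the distributed path: a model built on the meta device gets wrapped
# with a sharding strategy, then materialized once from the root module. `MyModel` and the wrapping
# step are assumed placeholders, not part of this file:
#
#     with torch.device("meta"):
#         model = MyModel()  # assumed placeholder
#     ...                    # apply FSDP/TP wrapping here
#     _materialize_distributed_module(model, torch.device("cuda", 0))
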


def _has_meta_device_parameters_or_buffers(obj: Union[Module, Optimizer], recurse: bool = True) -> bool:
    """Check whether a module or optimizer has parameters or buffers that still reside on the meta device."""
    if isinstance(obj, Optimizer):
        return any(
            t.is_meta for param_group in obj.param_groups for t in param_group["params"] if isinstance(t, Parameter)
        )
    if isinstance(obj, Module):
        return any(t.is_meta for t in itertools.chain(obj.parameters(recurse=recurse), obj.buffers(recurse=recurse)))
    raise TypeError(f"Expected `torch.nn.Module` or `torch.optim.Optimizer`, got: {type(obj).__name__}")
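
# Illustrative sketch: the check is True for modules created under the meta device and False once
# they have been materialized:
#
#     with torch.device("meta"):
#         model = torch.nn.Linear(2, 2)
#     assert _has_meta_device_parameters_or_buffers(model)
#     _materialize_meta_tensors(model, "cpu")
#     assert not _has_meta_device_parameters_or_buffers(model)
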