# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from collections.abc import Generator, Iterator, Mapping
from copy import deepcopy
from functools import partial, wraps
from types import MethodType
from typing import (
Any,
Callable,
Optional,
TypeVar,
Union,
overload,
)
import torch
from lightning_utilities import is_overridden
from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor
from torch import nn as nn
from torch._dynamo import OptimizedModule
from torch.nn.modules.module import _IncompatibleKeys
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from typing_extensions import override
from lightning_fabric.plugins import Precision
from lightning_fabric.strategies import Strategy
from lightning_fabric.utilities import move_data_to_device
from lightning_fabric.utilities.data import _set_sampler_epoch
from lightning_fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from lightning_fabric.utilities.types import Optimizable
T_destination = TypeVar("T_destination", bound=dict[str, Any])
_LIGHTNING_MODULE_STEP_METHODS = ("training_step", "validation_step", "test_step", "predict_step")
_in_fabric_backward: bool = False
class _FabricOptimizer:
def __init__(self, optimizer: Optimizer, strategy: Strategy, callbacks: Optional[list[Callable]] = None) -> None:
"""FabricOptimizer is a thin wrapper around the :class:`~torch.optim.Optimizer` that delegates the optimizer
step calls to the strategy.
The underlying wrapped optimizer object can be accessed via the property :attr:`optimizer`.
Args:
optimizer: The optimizer to wrap
strategy: Reference to the strategy for handling the optimizer step
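callbacks: Optional list of callback objects; each callback's ``on_after_optimizer_step`` hook
    (if defined) is called after every optimizer step.

Example (a minimal usage sketch; ``MyModel`` stands in for a user-defined ``nn.Module``)::

    model = MyModel()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    fabric = Fabric()
    model, optimizer = fabric.setup(model, optimizer)  # returns _FabricModule, _FabricOptimizer
    loss = model(torch.randn(2, 32)).sum()
    fabric.backward(loss)
    optimizer.step()  # delegated to the strategy, then `on_after_optimizer_step` hooks run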
"""
self._optimizer = optimizer
self._strategy = strategy
self._callbacks = callbacks or []
# imitate the class of the wrapped object to make isinstance checks work
self.__class__ = type("Fabric" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {})
@property
def optimizer(self) -> Optimizer:
return self._optimizer
def state_dict(self) -> dict[str, Tensor]:
return self._strategy.get_optimizer_state(self.optimizer)
def load_state_dict(self, state_dict: dict[str, Tensor]) -> None:
self.optimizer.load_state_dict(state_dict)
def step(self, closure: Optional[Callable] = None) -> Any:
kwargs = {"closure": closure} if closure is not None else {}
if hasattr(self._strategy, "model") and isinstance(self._strategy.model, Optimizable):
# only DeepSpeed defines this
optimizer = self._strategy.model
else:
optimizer = self.optimizer
output = self._strategy.optimizer_step(
optimizer,
**kwargs,
)
for callback in self._callbacks:
hook = getattr(callback, "on_after_optimizer_step", None)
if callable(hook):
hook(strategy=self._strategy, optimizer=optimizer)
return output
def __getattr__(self, item: Any) -> Any:
return getattr(self._optimizer, item)
class _FabricModule(_DeviceDtypeModuleMixin):
def __init__(
self, forward_module: nn.Module, strategy: Strategy, original_module: Optional[nn.Module] = None
) -> None:
"""The FabricModule is a thin wrapper around the :class:`torch.nn.Module` and handles precision / autocast
automatically for the forward pass.
The underlying wrapped module can be accessed via the property :attr:`module`.
Args:
forward_module: The module to wrap the ``forward`` method on.
strategy: Reference to the strategy for handling precision etc.
original_module: The original, unmodified module as passed into the
:meth:`lightning_fabric.fabric.Fabric.setup` method. This is needed when attribute lookup
on this wrapper should pass through to the original module.
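
Example (illustrative sketch; ``MyModel`` is a placeholder for a user-defined ``nn.Module``)::

    fabric = Fabric(precision="bf16-mixed")
    model = fabric.setup(MyModel())    # returns a _FabricModule
    out = model(torch.randn(4, 32))    # inputs get converted and forward runs in the precision context
    model.module                       # access the original, unwrapped module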
"""
super().__init__()
self._forward_module = forward_module
self._original_module = original_module or forward_module
self._strategy = strategy
self._forward_methods = set(_LIGHTNING_MODULE_STEP_METHODS)
self._fabric_module_initialized = True
@property
def module(self) -> nn.Module:
return self._original_module or self._forward_module
@override
def forward(self, *args: Any, **kwargs: Any) -> Any:
"""Casts all inputs to the right precision and handles autocast for operations in the module forward method."""
precision = self._strategy.precision
args, kwargs = precision.convert_input((args, kwargs))
with precision.forward_context():
output = self._forward_module(*args, **kwargs)
output = precision.convert_output(output)
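# Register a hook on every output tensor that requires grad so that a direct `loss.backward()` call can be
# detected (and rejected) when the selected strategy/precision requires `fabric.backward(loss)` instead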
apply_to_collection(output, dtype=Tensor, function=self._register_backward_hook)
return output
@overload
def state_dict(self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ...) -> T_destination: ...
@overload
def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> dict[str, Any]: ...
@override
def state_dict(
self, destination: Optional[T_destination] = None, prefix: str = "", keep_vars: bool = False
) -> Optional[dict[str, Any]]:
return self._original_module.state_dict(
destination=destination, # type: ignore[type-var]
prefix=prefix,
keep_vars=keep_vars,
)
@override
def load_state_dict( # type: ignore[override]
self, state_dict: Mapping[str, Any], strict: bool = True, **kwargs: Any
) -> _IncompatibleKeys:
return self._original_module.load_state_dict(state_dict=state_dict, strict=strict, **kwargs)
def mark_forward_method(self, method: Union[MethodType, str]) -> None:
"""Mark a method as a 'forward' method to prevent it bypassing the strategy wrapper (e.g., DDP)."""
if not isinstance(method, (MethodType, str)):
raise TypeError(f"Expected a method or a string, but got: {type(method).__name__}")
name = method if isinstance(method, str) else method.__name__
if name == "forward":
raise ValueError("You cannot mark the forward method itself as a forward method.")
if not isinstance(getattr(self._original_module, name, None), MethodType):
raise AttributeError(
f"You marked '{name}' as a forward method, but `{type(self._original_module).__name__}.{name}` does not"
f" exist or is not a method."
)
self._forward_methods.add(name)
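# Example (illustrative; `MyModel`, `generate`, and `prompt` are hypothetical user-defined names):
#
#     model = fabric.setup(MyModel())
#     model.mark_forward_method("generate")  # route `generate` through the strategy-wrapped forward
#     output = model.generate(prompt)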
def _redirection_through_forward(self, method_name: str) -> Callable:
assert method_name != "forward"
original_forward = self._original_module.forward
def wrapped_forward(*args: Any, **kwargs: Any) -> Any:
# Unpatch ourselves immediately before calling the method `method_name`,
# because the method itself may want to call the real `forward`
self._original_module.forward = original_forward
# Call the actual method e.g. `.training_step(...)`
method = getattr(self._original_module, method_name)
return method(*args, **kwargs)
# We make the caller "unknowingly" send their arguments through the forward_module's `__call__`.
# We expect that the `forward_module` will eventually call `original_module.forward`, which we
# have patched to redirect back to `original_module.method_name()`.
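# The resulting call chain for e.g. `fabric_model.training_step(batch)` looks like this:
#   call_forward_module(batch)
#     -> _FabricModule.forward(batch)                # precision conversion + forward context
#       -> forward_module(batch), e.g. DDP.__call__  # strategy wrapper
#         -> original_module.forward(batch), which is now wrapped_forward(batch)
#           -> original_module.training_step(batch)  # the real method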
def call_forward_module(*args: Any, **kwargs: Any) -> Any:
# Patch the original_module's forward, so we can redirect the arguments back to the real method
self._original_module.forward = wrapped_forward
return self.forward(*args, **kwargs)
return call_forward_module
def _wrap_method_with_module_call_tracker(self, method: Callable, name: str) -> Callable:
"""Tracks whether any submodule in ``self._original_module`` was called during the execution of ``method`` by
registering forward hooks on all submodules."""
module_called = False
def hook(*_: Any, **__: Any) -> None:
nonlocal module_called
module_called = True
@wraps(method)
def _wrapped_method(*args: Any, **kwargs: Any) -> Any:
handles = []
for module in self._original_module.modules():
handles.append(module.register_forward_hook(hook))
output = method(*args, **kwargs)
if module_called:
raise RuntimeError(
f"You are calling the method `{type(self._original_module).__name__}.{name}()` from outside the"
" model. To avoid issues with the currently selected strategy, explicitly mark it as a"
f" forward method with `fabric_model.mark_forward_method({name!r})` after `fabric.setup()`."
)
for handle in handles:
handle.remove()
return output
return _wrapped_method
def _register_backward_hook(self, tensor: Tensor) -> Tensor:
if not tensor.requires_grad:
return tensor
strategy_requires = is_overridden("backward", self._strategy, parent=Strategy)
precision_requires = any(
is_overridden(method, self._strategy.precision, parent=Precision)
for method in ("pre_backward", "backward", "post_backward")
)
hook = partial(_backward_hook, (strategy_requires or precision_requires))
tensor.register_hook(hook)
return tensor
@override
def __getattr__(self, item: Any) -> Any:
if (
item != "_forward_methods"
and item in self._forward_methods
and self._forward_module != self._original_module
):
# Special support for methods marked by `mark_forward_method` to prevent bypassing DDP's forward
return self._redirection_through_forward(item)
try:
# __getattr__ gets called as a last resort if the attribute does not exist
# call nn.Module's implementation first
return super().__getattr__(item)
except AttributeError:
# If the attribute is not available on the _FabricModule wrapper, redirect to the wrapped nn.Module
original_module = super().__getattr__("_original_module")
attr = getattr(original_module, item)
if inspect.ismethod(attr) and self._forward_module != self._original_module:
attr = self._wrap_method_with_module_call_tracker(attr, item)
return attr
@override
def __setattr__(self, name: str, value: Any) -> None:
if not getattr(self, "_fabric_module_initialized", False):
super().__setattr__(name, value)
return
# Get the _original_module attribute
original_module = self._original_module
original_has_attr = hasattr(original_module, name)
# Can't use super().__getattr__ because nn.Module only checks _parameters, _buffers, and _modules
# Can't use self.__getattr__ because it would pass through to the original module
fabric_has_attr = name in dir(self)
if not (original_has_attr or fabric_has_attr):
setattr(original_module, name, value)
return
# The original module can also inherit from _DeviceDtypeModuleMixin.
# In that case, both the Fabric wrapper and the original module have attributes like `_dtype`,
# so set the attribute on both.
if original_has_attr:
setattr(original_module, name, value)
if fabric_has_attr:
super().__setattr__(name, value)
class _FabricDataLoader:
def __init__(self, dataloader: DataLoader, device: Optional[torch.device] = None) -> None:
"""The FabricDataLoader is a wrapper for the :class:`~torch.utils.data.DataLoader`. It moves the data to the
device automatically if the device is specified.
Args:
dataloader: The dataloader to wrap
device: The device to which the data should be moved. By default, the device is `None` and no data
transfers will be made (identical behavior to a plain :class:`~torch.utils.data.DataLoader`).
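
Example (a minimal sketch; ``dataset`` is a placeholder for any PyTorch dataset)::

    fabric = Fabric(accelerator="cuda", devices=1)
    dataloader = fabric.setup_dataloaders(DataLoader(dataset, batch_size=8, shuffle=True))
    for batch in dataloader:
        ...  # batches arrive already moved to the selected device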
"""
self.__dict__.update(dataloader.__dict__)
self._dataloader = dataloader
self._device = device
self._num_iter_calls = 0
@property
def device(self) -> Optional[torch.device]:
return self._device
def __len__(self) -> int:
return len(self._dataloader)
def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]:
# Without setting the epoch, the distributed sampler would return the same indices every time, even when
# shuffling is enabled. In PyTorch, the user would normally have to call `.set_epoch()` on the sampler.
# In Fabric, we take care of this boilerplate code.
_set_sampler_epoch(self._dataloader, self._num_iter_calls)
self._num_iter_calls += 1
if self._device is None:
yield from iter(self._dataloader)
else:
for item in self._dataloader:
yield move_data_to_device(item, self._device)
def _unwrap_objects(collection: Any) -> Any:
def _unwrap(
obj: Union[_FabricModule, _FabricOptimizer, _FabricDataLoader],
) -> Union[nn.Module, Optimizer, DataLoader]:
if isinstance(unwrapped := _unwrap_compiled(obj)[0], _FabricModule):
return _unwrap_compiled(unwrapped._forward_module)[0]
if isinstance(obj, _FabricOptimizer):
return obj.optimizer
if isinstance(obj, _FabricDataLoader):
return obj._dataloader
return obj
types = [_FabricModule, _FabricOptimizer, _FabricDataLoader]
types.append(OptimizedModule)
return apply_to_collection(collection, dtype=tuple(types), function=_unwrap)
def _unwrap_compiled(obj: Union[Any, OptimizedModule]) -> tuple[Union[Any, nn.Module], Optional[dict[str, Any]]]:
"""Removes the :class:`torch._dynamo.OptimizedModule` around the object if it is wrapped.
Use this function before instance checks against e.g. :class:`_FabricModule`.
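
Example (illustrative)::

    unwrapped, compile_kwargs = _unwrap_compiled(obj)
    if isinstance(unwrapped, _FabricModule):
        ...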
"""
if isinstance(obj, OptimizedModule):
if (compile_kwargs := getattr(obj, "_compile_kwargs", None)) is None:
raise RuntimeError(
"Failed to determine the arguments that were used to compile the module. Make sure to import"
" lightning before `torch.compile` is used."
)
return obj._orig_mod, compile_kwargs
return obj, None
def _to_compiled(module: nn.Module, compile_kwargs: dict[str, Any]) -> OptimizedModule:
return torch.compile(module, **compile_kwargs) # type: ignore[return-value]
def _backward_hook(requires_backward: bool, *_: Any) -> None:
if requires_backward and not _in_fabric_backward:
raise RuntimeError(
"The current strategy and precision selection requires you to call `fabric.backward(loss)`"
" instead of `loss.backward()`."
)
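# Expected usage pattern when the hook is armed (illustrative sketch):
#
#     loss = model(batch).sum()
#     fabric.backward(loss)  # Fabric sets `_in_fabric_backward` around the call, so the hook stays silent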
def is_wrapped(obj: object) -> bool:
"""Checks if an object was set up by Fabric.
A :class:`~torch.nn.Module` may be wrapped by a :class:`_FabricModule`, a :class:`~torch.optim.Optimizer`
may be wrapped by a :class:`_FabricOptimizer`, or a :class:`~torch.utils.data.DataLoader` may be wrapped by a
:class:`_FabricDataLoader`.
Args:
obj: The object to test.
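
Example (a minimal sketch)::

    fabric = Fabric()
    model = fabric.setup(torch.nn.Linear(4, 4))
    is_wrapped(model)                    # True
    is_wrapped(torch.nn.Linear(4, 4))    # False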
"""
obj, _ = _unwrap_compiled(obj)
return isinstance(obj, (_FabricModule, _FabricOptimizer, _FabricDataLoader))
def _capture_compile_kwargs(compile_fn: Callable) -> Callable:
"""Wraps the ``torch.compile`` function and captures the compile arguments.
We extract the compile arguments so that we can reapply ``torch.compile`` in ``Fabric.setup()`` with the
same arguments as the user passed to the original call. The arguments get stored in a dictionary
``_compile_kwargs`` on the returned compiled module.
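
Example (illustrative; assumes ``lightning`` was imported before ``torch.compile`` was called)::

    compiled = torch.compile(MyModel(), mode="reduce-overhead")
    compiled._compile_kwargs            # {"mode": "reduce-overhead"}
    model = fabric.setup(compiled)      # Fabric can unwrap and re-apply `torch.compile` with these kwargs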
"""
# Limitation: Currently, the global compile config does not get captured on a per-model basis.
# PyTorch will resolve this in the future: https://github.com/pytorch/pytorch/issues/116575
@wraps(compile_fn)
def _capture(*args: Any, **kwargs: Any) -> Any:
if not args or not isinstance(args[0], nn.Module):
# either torch.compile is being applied as a decorator or we're compiling something else
return compile_fn(*args, **kwargs)
model = args[0]
compiled_model = compile_fn(model, **kwargs)
compiled_model._compile_kwargs = deepcopy(kwargs)
return compiled_model
return _capture
torch.compile = _capture_compile_kwargs(torch.compile)