# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from collections.abc import Generator, Iterator, Mapping
from copy import deepcopy
from functools import partial, wraps
from types import MethodType
from typing import (
Any,
Callable,
Optional,
TypeVar,
Union,
overload,
)

import torch
from lightning_utilities import is_overridden
from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor
from torch import nn as nn
from torch._dynamo import OptimizedModule
from torch.nn.modules.module import _IncompatibleKeys
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from typing_extensions import override

from lightning_fabric.plugins import Precision
from lightning_fabric.strategies import Strategy
from lightning_fabric.utilities import move_data_to_device
from lightning_fabric.utilities.data import _set_sampler_epoch
from lightning_fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from lightning_fabric.utilities.types import Optimizable

T_destination = TypeVar("T_destination", bound=dict[str, Any])
_LIGHTNING_MODULE_STEP_METHODS = ("training_step", "validation_step", "test_step", "predict_step")
_in_fabric_backward: bool = False


class _FabricOptimizer:
def __init__(self, optimizer: Optimizer, strategy: Strategy, callbacks: Optional[list[Callable]] = None) -> None:
"""FabricOptimizer is a thin wrapper around the :class:`~torch.optim.Optimizer` that delegates the optimizer
step calls to the strategy.
The underlying wrapped optimizer object can be accessed via the property :attr:`optimizer`.
Args:
optimizer: The optimizer to wrap
strategy: Reference to the strategy for handling the optimizer step
"""
self._optimizer = optimizer
self._strategy = strategy
self._callbacks = callbacks or []
# imitate the class of the wrapped object to make isinstance checks work
self.__class__ = type("Fabric" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {})
@property
def optimizer(self) -> Optimizer:
return self._optimizer
def state_dict(self) -> dict[str, Tensor]:
return self._strategy.get_optimizer_state(self.optimizer)
def load_state_dict(self, state_dict: dict[str, Tensor]) -> None:
self.optimizer.load_state_dict(state_dict)
def step(self, closure: Optional[Callable] = None) -> Any:
kwargs = {"closure": closure} if closure is not None else {}
if hasattr(self._strategy, "model") and isinstance(self._strategy.model, Optimizable):
# only DeepSpeed defines this
optimizer = self._strategy.model
else:
optimizer = self.optimizer
output = self._strategy.optimizer_step(
optimizer,
**kwargs,
)
for callback in self._callbacks:
hook = getattr(callback, "on_after_optimizer_step", None)
if callable(hook):
hook(strategy=self._strategy, optimizer=optimizer)
return output
def __getattr__(self, item: Any) -> Any:
return getattr(self._optimizer, item)
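# Illustrative sketch (comments only, not executed): users normally obtain a `_FabricOptimizer`
# from `Fabric.setup()` rather than constructing it directly. The wrapper keeps `isinstance`
# checks against the original optimizer class working while routing `step()` through the strategy.
# `model` below is a placeholder `nn.Module`.
#
#   fabric = Fabric(accelerator="cpu")
#   optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#   model, optimizer = fabric.setup(model, optimizer)
#   isinstance(optimizer, torch.optim.SGD)  # True, thanks to the dynamic subclass created above
#   optimizer.step()                        # delegates to `strategy.optimizer_step(...)`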


class _FabricModule(_DeviceDtypeModuleMixin):
def __init__(
self, forward_module: nn.Module, strategy: Strategy, original_module: Optional[nn.Module] = None
) -> None:
"""The FabricModule is a thin wrapper around the :class:`torch.nn.Module` and handles precision / autocast
automatically for the forward pass.
The underlying wrapped module can be accessed via the property :attr:`module`.
Args:
forward_module: The module to wrap the ``forward`` method on.
strategy: Reference to the strategy for handling precision etc.
original_module: The original, unmodified module as passed into the
:meth:`lightning_fabric.fabric.Fabric.setup` method. This is needed when attribute lookup
on this wrapper should pass through to the original module.
"""
super().__init__()
self._forward_module = forward_module
self._original_module = original_module or forward_module
self._strategy = strategy
self._forward_methods = set(_LIGHTNING_MODULE_STEP_METHODS)
self._fabric_module_initialized = True
@property
def module(self) -> nn.Module:
return self._original_module or self._forward_module
@override
def forward(self, *args: Any, **kwargs: Any) -> Any:
"""Casts all inputs to the right precision and handles autocast for operations in the module forward method."""
precision = self._strategy.precision
args, kwargs = precision.convert_input((args, kwargs))
with precision.forward_context():
output = self._forward_module(*args, **kwargs)
output = precision.convert_output(output)
apply_to_collection(output, dtype=Tensor, function=self._register_backward_hook)
return output
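    # Illustrative sketch (comments only, not executed): with a "true" half-precision setting, the
    # wrapper converts the inputs, runs the wrapped module inside the precision plugin's forward
    # context, and converts the output back. `MyModel` below is a placeholder `nn.Module`.
    #
    #   fabric = Fabric(accelerator="cpu", precision="bf16-true")
    #   model = fabric.setup(MyModel())
    #   out = model(torch.randn(2, 32))   # inputs cast to bfloat16, forward runs in the context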
@overload
def state_dict(self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ...) -> T_destination: ...
@overload
def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> dict[str, Any]: ...
@override
def state_dict(
self, destination: Optional[T_destination] = None, prefix: str = "", keep_vars: bool = False
) -> Optional[dict[str, Any]]:
return self._original_module.state_dict(
destination=destination, # type: ignore[type-var]
prefix=prefix,
keep_vars=keep_vars,
)
@override
def load_state_dict( # type: ignore[override]
self, state_dict: Mapping[str, Any], strict: bool = True, **kwargs: Any
) -> _IncompatibleKeys:
return self._original_module.load_state_dict(state_dict=state_dict, strict=strict, **kwargs)
def mark_forward_method(self, method: Union[MethodType, str]) -> None:
"""Mark a method as a 'forward' method to prevent it bypassing the strategy wrapper (e.g., DDP)."""
if not isinstance(method, (MethodType, str)):
raise TypeError(f"Expected a method or a string, but got: {type(method).__name__}")
name = method if isinstance(method, str) else method.__name__
if name == "forward":
raise ValueError("You cannot mark the forward method itself as a forward method.")
if not isinstance(getattr(self._original_module, name, None), MethodType):
raise AttributeError(
f"You marked '{name}' as a forward method, but `{type(self._original_module).__name__}.{name}` does not"
f" exist or is not a method."
)
self._forward_methods.add(name)
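    # Illustrative sketch (comments only, not executed): a custom inference method such as a
    # hypothetical `generate` would normally bypass the strategy wrapper (e.g. DDP), so it has to
    # be marked explicitly after setup:
    #
    #   model = fabric.setup(MyModel())
    #   model.mark_forward_method("generate")
    #   model.generate(prompt)   # now redirected through the wrapper's `forward`
    #
    # Without marking it, calling `generate` from outside the model raises a RuntimeError if it
    # ran any submodule (see `_wrap_method_with_module_call_tracker` below).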
def _redirection_through_forward(self, method_name: str) -> Callable:
assert method_name != "forward"
original_forward = self._original_module.forward
def wrapped_forward(*args: Any, **kwargs: Any) -> Any:
            # Unpatch ourselves immediately before calling `method_name`, because the method itself
            # may want to call the real `forward`
self._original_module.forward = original_forward
# Call the actual method e.g. `.training_step(...)`
method = getattr(self._original_module, method_name)
return method(*args, **kwargs)
# We make the caller "unknowingly" send their arguments through the forward_module's `__call__`.
# We expect that the `forward_module` will eventually call `original_module.forward`, which we
# have patched to redirect back to `original_module.method_name()`.
def call_forward_module(*args: Any, **kwargs: Any) -> Any:
# Patch the original_module's forward, so we can redirect the arguments back to the real method
self._original_module.forward = wrapped_forward
return self.forward(*args, **kwargs)
return call_forward_module
def _wrap_method_with_module_call_tracker(self, method: Callable, name: str) -> Callable:
"""Tracks whether any submodule in ``self._original_module`` was called during the execution of ``method`` by
registering forward hooks on all submodules."""
module_called = False
def hook(*_: Any, **__: Any) -> None:
nonlocal module_called
module_called = True
@wraps(method)
def _wrapped_method(*args: Any, **kwargs: Any) -> Any:
handles = []
for module in self._original_module.modules():
handles.append(module.register_forward_hook(hook))
output = method(*args, **kwargs)
if module_called:
raise RuntimeError(
f"You are calling the method `{type(self._original_module).__name__}.{name}()` from outside the"
" model. To avoid issues with the currently selected strategy, explicitly mark it as a"
f" forward method with `fabric_model.mark_forward_method({name!r})` after `fabric.setup()`."
)
for handle in handles:
handle.remove()
return output
return _wrapped_method
def _register_backward_hook(self, tensor: Tensor) -> Tensor:
if not tensor.requires_grad:
return tensor
strategy_requires = is_overridden("backward", self._strategy, parent=Strategy)
precision_requires = any(
is_overridden(method, self._strategy.precision, parent=Precision)
for method in ("pre_backward", "backward", "post_backward")
)
hook = partial(_backward_hook, (strategy_requires or precision_requires))
tensor.register_hook(hook)
return tensor
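    # Illustrative sketch (comments only, not executed): when the strategy or precision plugin
    # customizes the backward pass (e.g. DeepSpeed), the hook registered above makes a plain
    # `loss.backward()` fail loudly and points users to the supported call:
    #
    #   loss = model(batch).sum()
    #   loss.backward()        # RuntimeError: "... call `fabric.backward(loss)` instead ..."
    #   fabric.backward(loss)  # supported path: the backward is routed through the strategy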
@override
def __getattr__(self, item: Any) -> Any:
if (
item != "_forward_methods"
and item in self._forward_methods
and self._forward_module != self._original_module
):
# Special support for methods marked by `mark_forward_method` to prevent bypassing DDP's forward
return self._redirection_through_forward(item)
try:
# __getattr__ gets called as a last resort if the attribute does not exist
# call nn.Module's implementation first
return super().__getattr__(item)
except AttributeError:
# If the attribute is not available on the _FabricModule wrapper, redirect to the wrapped nn.Module
original_module = super().__getattr__("_original_module")
attr = getattr(original_module, item)
if inspect.ismethod(attr) and self._forward_module != self._original_module:
attr = self._wrap_method_with_module_call_tracker(attr, item)
return attr
@override
def __setattr__(self, name: str, value: Any) -> None:
if not getattr(self, "_fabric_module_initialized", False):
super().__setattr__(name, value)
return
# Get the _original_module attribute
original_module = self._original_module
original_has_attr = hasattr(original_module, name)
# Can't use super().__getattr__ because nn.Module only checks _parameters, _buffers, and _modules
# Can't use self.__getattr__ because it would pass through to the original module
fabric_has_attr = name in dir(self)
if not (original_has_attr or fabric_has_attr):
setattr(original_module, name, value)
return
# The original module can also inherit from _DeviceDtypeModuleMixin,
# in this case, both the Fabric module and original module have attributes like _dtype
# set attribute on both
if original_has_attr:
setattr(original_module, name, value)
if fabric_has_attr:
super().__setattr__(name, value)
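# Illustrative sketch (comments only, not executed) of the attribute passthrough implemented by
# `__getattr__` and `__setattr__` above, assuming `model` has an attribute `hidden_dim` that the
# wrapper itself does not define:
#
#   wrapped = fabric.setup(model)
#   wrapped.hidden_dim          # read is forwarded to `model.hidden_dim`
#   wrapped.hidden_dim = 256    # write goes to the original module only
#   wrapped.new_flag = True     # unknown to both wrapper and module -> created on the module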


class _FabricDataLoader:
def __init__(self, dataloader: DataLoader, device: Optional[torch.device] = None) -> None:
"""The FabricDataLoader is a wrapper for the :class:`~torch.utils.data.DataLoader`. It moves the data to the
device automatically if the device is specified.
Args:
dataloader: The dataloader to wrap
            device: The device to which the data should be moved. By default, the device is `None` and no data
                transfers will be made (identical behavior to :class:`~torch.utils.data.DataLoader`).
"""
self.__dict__.update(dataloader.__dict__)
self._dataloader = dataloader
self._device = device
self._num_iter_calls = 0
@property
def device(self) -> Optional[torch.device]:
return self._device
def __len__(self) -> int:
return len(self._dataloader)
def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]:
# Without setting the epoch, the distributed sampler would return the same indices every time, even when
# shuffling is enabled. In PyTorch, the user would normally have to call `.set_epoch()` on the sampler.
# In Fabric, we take care of this boilerplate code.
_set_sampler_epoch(self._dataloader, self._num_iter_calls)
self._num_iter_calls += 1
if self._device is None:
yield from iter(self._dataloader)
else:
for item in self._dataloader:
yield move_data_to_device(item, self._device)
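# Illustrative sketch (comments only, not executed): dataloaders returned by
# `Fabric.setup_dataloaders()` are `_FabricDataLoader` instances, so batches arrive on the target
# device and `set_epoch()` is handled for distributed samplers. Assumes an existing `train_loader`:
#
#   train_loader = fabric.setup_dataloaders(train_loader)
#   for epoch in range(num_epochs):     # each `iter()` call advances the sampler epoch
#       for batch in train_loader:      # batch tensors are already on `fabric.device`
#           ...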


def _unwrap_objects(collection: Any) -> Any:
def _unwrap(
obj: Union[_FabricModule, _FabricOptimizer, _FabricDataLoader],
) -> Union[nn.Module, Optimizer, DataLoader]:
if isinstance(unwrapped := _unwrap_compiled(obj)[0], _FabricModule):
return _unwrap_compiled(unwrapped._forward_module)[0]
if isinstance(obj, _FabricOptimizer):
return obj.optimizer
if isinstance(obj, _FabricDataLoader):
return obj._dataloader
return obj
types = [_FabricModule, _FabricOptimizer, _FabricDataLoader]
types.append(OptimizedModule)
return apply_to_collection(collection, dtype=tuple(types), function=_unwrap)


def _unwrap_compiled(obj: Union[Any, OptimizedModule]) -> tuple[Union[Any, nn.Module], Optional[dict[str, Any]]]:
"""Removes the :class:`torch._dynamo.OptimizedModule` around the object if it is wrapped.
Use this function before instance checks against e.g. :class:`_FabricModule`.
"""
if isinstance(obj, OptimizedModule):
if (compile_kwargs := getattr(obj, "_compile_kwargs", None)) is None:
raise RuntimeError(
"Failed to determine the arguments that were used to compile the module. Make sure to import"
" lightning before `torch.compile` is used."
)
return obj._orig_mod, compile_kwargs
return obj, None
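# Illustrative sketch (comments only, not executed): `_unwrap_compiled` peels off the
# `OptimizedModule` layer so that isinstance checks against Fabric's own wrappers keep working,
# and returns the captured compile arguments alongside the inner module:
#
#   compiled = torch.compile(model, mode="reduce-overhead")
#   inner, kwargs = _unwrap_compiled(compiled)  # inner is `model`, kwargs == {"mode": "reduce-overhead"}
#   _unwrap_compiled(model)                     # not compiled -> (model, None)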


def _to_compiled(module: nn.Module, compile_kwargs: dict[str, Any]) -> OptimizedModule:
return torch.compile(module, **compile_kwargs) # type: ignore[return-value]


def _backward_hook(requires_backward: bool, *_: Any) -> None:
if requires_backward and not _in_fabric_backward:
raise RuntimeError(
"The current strategy and precision selection requires you to call `fabric.backward(loss)`"
" instead of `loss.backward()`."
)


def is_wrapped(obj: object) -> bool:
"""Checks if an object was set up by Fabric.
A :class:`~torch.nn.Module` may be wrapped by a :class:`_FabricModule`, a :class:`~torch.optim.Optimizer`
may be wrapped by a :class:`_FabricOptimizer`, or a :class:`~torch.utils.data.DataLoader` may be wrapped by
:class:`_FabricDataLoader`.
Args:
obj: The object to test.
"""
obj, _ = _unwrap_compiled(obj)
return isinstance(obj, (_FabricModule, _FabricOptimizer, _FabricDataLoader))
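# Illustrative sketch (comments only, not executed), assuming `model` and `optimizer` were created
# by the user beforehand:
#
#   model, optimizer = fabric.setup(model, optimizer)
#   is_wrapped(model)                  # True
#   is_wrapped(optimizer)              # True
#   is_wrapped(torch.nn.Linear(2, 2))  # False, it never went through `fabric.setup`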


def _capture_compile_kwargs(compile_fn: Callable) -> Callable:
"""Wraps the ``torch.compile`` function and captures the compile arguments.
We extract the compile arguments so that we can reapply ``torch.compile`` in ``Fabric.setup()`` with the
same arguments as the user passed to the original call. The arguments get stored in a dictionary
``_compile_kwargs`` on the returned compiled module.
"""
# Limitation: Currently, the global compile config does not get captured on a per-model basis.
# PyTorch will resolve this in the future: https://github.com/pytorch/pytorch/issues/116575
@wraps(compile_fn)
def _capture(*args: Any, **kwargs: Any) -> Any:
if not args or not isinstance(args[0], nn.Module):
# either torch.compile is being applied as a decorator or we're compiling something else
return compile_fn(*args, **kwargs)
model = args[0]
compiled_model = compile_fn(model, **kwargs)
compiled_model._compile_kwargs = deepcopy(kwargs)
return compiled_model
return _capture


torch.compile = _capture_compile_kwargs(torch.compile)
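# Illustrative sketch (comments only, not executed): because of the patch above, importing
# Lightning before calling `torch.compile` makes the compile arguments recoverable later, which
# `Fabric.setup()` relies on to re-apply compilation to the strategy-wrapped module:
#
#   import lightning_fabric            # importing installs the patched `torch.compile`
#   compiled = torch.compile(model, dynamic=True)
#   compiled._compile_kwargs           # {"dynamic": True}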