|
|
|
|
|
|
|
|
|
|
|
|
|
import contextlib |
|
import copy |
|
from abc import ABC, abstractmethod |
|
from collections.abc import Generator, Iterable, Sequence |
|
from typing import Any, Callable, cast, Optional, Union |
|
|
|
import torch.nn as nn |
|
|
|
|
|
__all__ = [ |
|
"always_wrap_policy", |
|
"lambda_auto_wrap_policy", |
|
"transformer_auto_wrap_policy", |
|
"size_based_auto_wrap_policy", |
|
"enable_wrap", |
|
"wrap", |
|
"CustomPolicy", |
|
"ModuleWrapPolicy", |
|
] |
|
|
|
|
|
|
|
|
|
|
|
def _post_order_apply( |
|
root_module: nn.Module, |
|
fn: Callable[[nn.Module], Optional[nn.Module]], |
|
): |
|
""" |
|
This applies ``fn`` to every module in the module tree of ``root_module`` |
|
following a post-order traversal. If ``fn`` returns an :class:`nn.Module`, |
|
then this replaces the original module with the newly returned one in the |
|
tree. Otherwise, ``fn`` should return ``None``, in which case the module is |
|
not changed. |
|
""" |
|
|
|
visited_modules: set[nn.Module] = {root_module} |
|
|
|
def _post_order_apply_inner( |
|
module: nn.Module, |
|
module_name: str, |
|
parent_module: Optional[nn.Module], |
|
): |
|
for child_module_name, child_module in module.named_children(): |
|
if child_module not in visited_modules: |
|
visited_modules.add(child_module) |
|
_post_order_apply_inner(child_module, child_module_name, module) |
|
optional_module = fn(module) |
|
if optional_module is not None: |
|
assert isinstance(parent_module, nn.Module), ( |
|
"Non-root modules should have their parent module set but got " |
|
f"{parent_module} for {module}" |
|
) |
|
assert module_name, ( |
|
"Non-root modules should have their module name set but got " |
|
f"an empty module name for {module}" |
|
) |
|
assert isinstance(optional_module, nn.Module), ( |
|
f"fn should return None or an nn.Module but got {optional_module}" |
|
) |
|
setattr(parent_module, module_name, optional_module) |
|
|
|
_post_order_apply_inner(root_module, "", None) |
|
|
|
|
|
def _construct_wrap_fn( |
|
root_module: nn.Module, |
|
target_module_to_kwargs: dict[nn.Module, dict[str, Any]], |
|
fsdp_fn: Callable, |
|
) -> Callable[[nn.Module], Optional[nn.Module]]: |
|
""" |
|
This constructs the "wrap" function to pass to :func:`_post_order_apply` |
|
based on ``target_module_to_kwargs``, which should be constructed from the |
|
wrapping policy. |
|
""" |
|
|
|
def fn(module: nn.Module) -> Optional[nn.Module]: |
|
|
|
|
|
if module in target_module_to_kwargs and module is not root_module: |
|
kwargs = target_module_to_kwargs[module] |
|
return fsdp_fn(module, **kwargs) |
|
return None |
|
|
|
return fn |
|
|
|
|
|
def _run_mixed_precision_override_policy( |
|
root_module: nn.Module, |
|
module_classes: Iterable[type[nn.Module]], |
|
ignored_modules: set[nn.Module], |
|
root_kwargs: dict[str, Any], |
|
target_module_to_kwargs: dict[nn.Module, dict[str, Any]], |
|
): |
|
module_classes_tuple = tuple(set(module_classes)) |
|
for module in root_module.modules(): |
|
if module in ignored_modules: |
|
continue |
|
elif isinstance(module, module_classes_tuple): |
|
|
|
if module not in target_module_to_kwargs: |
|
|
|
target_module_to_kwargs[module] = root_kwargs |
|
target_module_to_kwargs[module]["mixed_precision"] = None |
|
return target_module_to_kwargs |
|
|
|
|
|
def always_wrap_policy(*args, **kwargs) -> bool: |
|
""" |
|
A simple recursive wrap policy that always returns ``True``. This means |
|
that every submodule is wrapped by the wrapper class in |
|
:func:`_recursive_wrap`. |
|
""" |
|
return True |
|
|
|
|
|
class _Policy(ABC): |
|
""" |
|
This defines an abstract base class that represents a policy for applying |
|
a module-level API. |
|
""" |
|
|
|
@abstractmethod |
|
def _run_policy( |
|
self, |
|
root_module: nn.Module, |
|
ignored_modules: set[nn.Module], |
|
root_kwargs: dict[str, Any], |
|
) -> dict[nn.Module, dict[str, Any]]: |
|
""" |
|
This should return a dict ``target_module_to_kwargs`` that maps from |
|
each target module to wrap to its kwargs. |
|
""" |
|
... |
|
|
|
|
|
def _module_wrap_policy( |
|
module: nn.Module, |
|
recurse: bool, |
|
nonwrapped_numel: int, |
|
module_classes: set[type[nn.Module]], |
|
) -> bool: |
|
""" |
|
This auto wrap policy wraps every module that is an instance of any type in |
|
``module_classes`` as its own FSDP instance. The root module given by |
|
``module`` is always wrapped as an FSDP instance regardless. Since the |
|
wrapping proceeds bottom up, each FSDP instance manages the parameters in |
|
its subtree excluding any already managed by a child FSDP instance. |
|
|
|
Args: |
|
module (nn.Module): Current module being considered. |
|
recurse (bool): If ``False``, then this function must decide whether |
|
``module`` should be wrapped as an FSDP instance or not. If |
|
``True``, then the function is still recursing down the module |
|
tree as a part of the DFS. |
|
nonwrapped_numel (int): Parameter numel not yet wrapped. |
|
module_classes (Set[Type[nn.Module]]): Set of module classes that are |
|
wrapped as FSDP instances. |
|
|
|
Returns: |
|
``True`` if ``recurse=True``, and whether ``module`` should be wrapped |
|
if ``recurse=False``. |
|
""" |
|
if recurse: |
|
return True |
|
return isinstance(module, tuple(module_classes)) |
|
|
|
|
|
class ModuleWrapPolicy(_Policy): |
|
""" |
|
This policy applies to every module of the specified module classes, |
|
passing in the kwargs given to the root. |
|
""" |
|
|
|
def __init__(self, module_classes: Iterable[type[nn.Module]]): |
|
module_classes_set = set(module_classes) |
|
self._module_classes = module_classes_set |
|
self._module_classes_str = str(module_classes_set) |
|
|
|
def _run_policy( |
|
self, |
|
root_module: nn.Module, |
|
ignored_modules: set[nn.Module], |
|
root_kwargs: dict[str, Any], |
|
) -> dict[nn.Module, dict[str, Any]]: |
|
module_classes = tuple(self._module_classes) |
|
target_module_to_kwargs: dict[nn.Module, dict[str, Any]] = {} |
|
for module in root_module.modules(): |
|
if module in ignored_modules: |
|
continue |
|
elif isinstance(module, module_classes): |
|
|
|
target_module_to_kwargs[module] = copy.copy(root_kwargs) |
|
return target_module_to_kwargs |
|
|
|
def __call__(self, module, recurse, *args, **kwargs): |
|
|
|
return _module_wrap_policy( |
|
module, recurse, nonwrapped_numel=-1, module_classes=self._module_classes |
|
) |
|
|
|
def __repr__(self) -> str: |
|
return super().__repr__() + f"({self._module_classes_str})" |
|
|
|
|
|
class CustomPolicy(_Policy): |
|
""" |
|
This policy takes in a lambda function that maps a given ``nn.Module`` to |
|
either ``False``, ``True``, or a kwarg dictionary. |
|
- If the function returns ``False`` or an empty dictionary, then the module |
|
does not have the API applied. |
|
- If the function returns ``True``, then the module has the API applied |
|
with the root's kwargs. |
|
- If the function returns a non-empty dictionary, then the module has the |
|
API applied, and the dictionary overrides the root's kwargs. |
|
|
|
Example:: |
|
|
|
>>> # xdoctest: +SKIP("undefined variables") |
|
>>> model = init_transformer_model(...) |
|
>>> def lambda_fn(module: nn.Module): |
|
>>> if module is model.lm_head: |
|
>>> return {"sharding_strategy": ShardingStrategy.SHARD_GRAD_OP} |
|
>>> elif isinstance(module, TransformerBlock): |
|
>>> return True |
|
>>> return False |
|
>>> policy = CustomPolicy(lambda_fn) |
|
>>> fsdp_model = FSDP(model, auto_wrap_policy=policy) |
|
""" |
|
|
|
def __init__(self, lambda_fn: Callable[[nn.Module], Union[bool, dict[str, Any]]]): |
|
self._lambda_fn = lambda_fn |
|
|
|
def _run_policy( |
|
self, |
|
root_module: nn.Module, |
|
ignored_modules: set[nn.Module], |
|
root_kwargs: dict[str, Any], |
|
) -> dict[nn.Module, dict[str, Any]]: |
|
target_module_to_kwargs: dict[nn.Module, dict[str, Any]] = {} |
|
for module in root_module.modules(): |
|
if module in ignored_modules: |
|
continue |
|
res = self._lambda_fn(module) |
|
if not isinstance(res, (dict, bool)): |
|
raise ValueError( |
|
"The lambda_fn passed to CustomPolicy should return " |
|
f"False/True or a kwarg dict, but it returned {res}" |
|
) |
|
if not res: |
|
continue |
|
kwargs = copy.copy(root_kwargs) |
|
if isinstance(res, dict): |
|
|
|
|
|
kwargs.update(res) |
|
target_module_to_kwargs[module] = kwargs |
|
return target_module_to_kwargs |
|
|
|
|
|
def lambda_auto_wrap_policy( |
|
module: nn.Module, recurse: bool, nonwrapped_numel: int, lambda_fn: Callable |
|
) -> bool: |
|
""" |
|
A convenient auto wrap policy to wrap submodules based on an arbitrary user |
|
function. If `lambda_fn(submodule) == True``, the submodule will be wrapped as |
|
a `wrapper_cls` unit. |
|
|
|
Return if a module should be wrapped during auto wrapping. |
|
|
|
The first three parameters are required by :func:`_recursive_wrap`. |
|
|
|
Args: |
|
module (nn.Module): Current module being considered. |
|
recurse (bool): If ``False``, then this function must decide whether |
|
``module`` should be wrapped as an FSDP instance or not. If |
|
``True``, then the function is still recursing down the module |
|
tree as a part of the DFS. |
|
nonwrapped_numel (int): Parameter numel not yet wrapped. |
|
|
|
lambda_fn (Callable[[nn.Module], bool]): If this returns ``True``, then |
|
this module will be wrapped. |
|
""" |
|
if recurse: |
|
return True |
|
return lambda_fn(module) |
|
|
|
|
|
def transformer_auto_wrap_policy( |
|
module: nn.Module, |
|
recurse: bool, |
|
nonwrapped_numel: int, |
|
transformer_layer_cls: set[type[nn.Module]], |
|
) -> bool: |
|
""" |
|
See :func:`_module_wrap_policy`, where ``transformer_layer_cls`` is the |
|
same as ``module_classes``. Note that shared parameters must be wrapped in |
|
the same FSDP instance, so this auto wrap policy can help wrap shared |
|
embeddings into the same FSDP instance for transformer models. |
|
""" |
|
return _module_wrap_policy(module, recurse, nonwrapped_numel, transformer_layer_cls) |
|
|
|
|
|
def _wrap_module_cls_individually( |
|
module: nn.Module, module_classes: Sequence[type], recurse: bool, *args, **kwargs |
|
): |
|
if recurse: |
|
|
|
return True |
|
else: |
|
|
|
|
|
return isinstance(module, tuple(module_classes)) |
|
|
|
|
|
def _or_policy( |
|
module: nn.Module, |
|
recurse: bool, |
|
nonwrapped_numel: int, |
|
policies, |
|
) -> bool: |
|
""" |
|
A policy that wraps ``module`` if any policy in the passed in iterable of |
|
``policies`` returns ``True``. |
|
""" |
|
return any( |
|
policy(module=module, recurse=recurse, nonwrapped_numel=nonwrapped_numel) |
|
for policy in policies |
|
) |
|
|
|
|
|
def size_based_auto_wrap_policy( |
|
module: nn.Module, |
|
recurse: bool, |
|
nonwrapped_numel: int, |
|
|
|
min_num_params: int = int(1e8), |
|
force_leaf_modules: Optional[set[type[nn.Module]]] = None, |
|
exclude_wrap_modules: Optional[set[type[nn.Module]]] = None, |
|
) -> bool: |
|
""" |
|
A size-based auto wrap policy. |
|
|
|
Args: |
|
module (nn.Module): Current module being considered. |
|
recurse (bool): If ``False``, then this function must decide whether |
|
``module`` should be wrapped as an FSDP instance or not. If |
|
``True``, then the function is still recursing down the module |
|
tree as a part of the DFS. |
|
nonwrapped_numel (int): Parameter numel not yet wrapped. |
|
|
|
min_num_params (int): Customizable policy input that controls the size |
|
threshold over which a module is ready to be wrapped. This is in |
|
units of numel. |
|
force_leaf_modules (Optional[set[type[nn.Module]]]): Set of module types to keep |
|
as leaves, i.e. their children will never be wrapped. |
|
exclude_wrap_modules (Optional[set[type[nn.Module]]]): Set of module types to be |
|
excluded in wrapping. |
|
|
|
Returns: |
|
Whether ``module`` should be wrapped. |
|
""" |
|
force_leaf_modules = ( |
|
size_based_auto_wrap_policy.FORCE_LEAF_MODULES |
|
if force_leaf_modules is None |
|
else force_leaf_modules |
|
) |
|
exclude_wrap_modules = ( |
|
size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES |
|
if exclude_wrap_modules is None |
|
else exclude_wrap_modules |
|
) |
|
|
|
|
|
|
|
min_nonwrapped_numel = min_num_params |
|
is_large = nonwrapped_numel >= min_nonwrapped_numel |
|
if recurse: |
|
|
|
return is_large and not isinstance(module, tuple(force_leaf_modules)) |
|
else: |
|
|
|
return is_large and not isinstance(module, tuple(exclude_wrap_modules)) |
|
|
|
|
|
|
|
size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES = {nn.ModuleList, nn.ModuleDict} |
|
size_based_auto_wrap_policy.FORCE_LEAF_MODULES = {nn.MultiheadAttention} |
|
|
|
|
|
@contextlib.contextmanager |
|
def enable_wrap( |
|
*, wrapper_cls: Any, **wrapper_kwargs: Any |
|
) -> Generator[None, None, None]: |
|
""" |
|
Context manager to wrap modules using a wrapper. |
|
|
|
Useful for when you'd like to apply the same configuration arguments to all |
|
child modules that you wrap. A particularly important use case is wrapping |
|
large layers so that they get sharded (in-place) during initialization, to |
|
avoid running out of system memory. Large layers can indicate that they |
|
should be sharded via the ``wrap`` annotation and this context manager can |
|
provide the exact configuration for these nested instances. |
|
|
|
Usage:: |
|
|
|
with enable_wrap(wrapper_cls, **params): |
|
# Wraps layer in FSDP by default if within context |
|
self.l1 = wrap(torch.nn.Linear(5, 5)) |
|
|
|
Args: |
|
wrapper_cls: |
|
Class that `wrap` annotation will `wrap` modules with, such as |
|
`FullyShardedDataParallel`. |
|
**wrapper_kwargs: |
|
Configuration settings that will be passed to all ``wrap`` |
|
instances inside the context |
|
""" |
|
kwargs = { |
|
"wrapper_cls": wrapper_cls, |
|
**wrapper_kwargs, |
|
} |
|
with _ConfigAutoWrap(**kwargs): |
|
yield |
|
|
|
|
|
def wrap(module: nn.Module, **wrap_overrides: Any) -> nn.Module: |
|
""" |
|
Annotate that a module should be wrapped. Annotated modules will only be |
|
wrapped if inside of an :func:`enable_wrap` context manager. This allows |
|
a module to be initialized both with and without a wrapper without code |
|
change. |
|
|
|
The class that this function wraps the passed in ``nn.Module`` with is the |
|
passed in ``wrapper_cls`` argument into ``enable_wrap``. Both |
|
``enable_wrap`` and ``wrap`` can take in kwargs specifying how to construct |
|
the ``wrapper_cls`` instance. In the case of duplicate kwargs in |
|
``enable_wrap`` and ``wrap``, the argument passed into ``wrap`` will be |
|
respected. |
|
|
|
Usage:: |
|
|
|
with enable_wrap(wrapper_cls=FSDP, **fsdp_config): |
|
# Wraps layer in FSDP by default if within context |
|
self.l1 = wrap(torch.nn.Linear(5, 5)) |
|
|
|
Args: |
|
module (nn.Module): module to wrap (if in :func:`enable_wrap` context) |
|
**wrap_overrides: configuration overrides that will take priority over |
|
the values provided by the :func:`enable_wrap` context |
|
""" |
|
if _ConfigAutoWrap.in_autowrap_context: |
|
assert _ConfigAutoWrap.wrapper_cls is not None |
|
|
|
wrap_overrides = {**_ConfigAutoWrap.kwargs, **wrap_overrides} |
|
return _wrap( |
|
module, |
|
_ConfigAutoWrap.wrapper_cls, |
|
**wrap_overrides, |
|
) |
|
return module |
|
|
|
|
|
def _wrap(module: nn.Module, wrapper_cls: Callable, **kwargs) -> nn.Module: |
|
assert wrapper_cls is not None |
|
if hasattr(module, "_wrap_overrides"): |
|
|
|
|
|
|
|
|
|
overrides = {**kwargs, **module._wrap_overrides} |
|
return wrapper_cls(module, **overrides) |
|
|
|
return wrapper_cls(module, **kwargs) |
|
|
|
|
|
def _recursive_wrap( |
|
module: nn.Module, |
|
auto_wrap_policy: Callable, |
|
wrapper_cls: Callable, |
|
ignored_modules: set[nn.Module], |
|
ignored_params: set[nn.Parameter], |
|
only_wrap_children: bool = False, |
|
**kwargs: Any, |
|
) -> tuple[nn.Module, int]: |
|
""" |
|
Wraps submodules of ``module`` for which ``auto_wrap_policy`` returns |
|
``True`` with ``wrapper_cls``. |
|
|
|
Args: |
|
module (nn.Module): Module to recursively wrap. |
|
auto_wrap_policy (Callable): A callable representing a policy that |
|
determines which modules to recursively wrap with ``wrapper_cls``. |
|
ignored_modules (set[torch.nn.Module]): Modules to ignore when |
|
wrapping. |
|
ignored_params (set[torch.nn.Parameter]): Parameters to ignore when |
|
wrapping; these should be the parameters contained in the modules |
|
in ``ignored_modules``. |
|
Returns: |
|
(nn.Module, int): |
|
``module`` after wrapping and the numel recursively wrapped. |
|
""" |
|
assert auto_wrap_policy is not None, "Must specify auto_wrap_policy." |
|
assert wrapper_cls is not None, "Must specify wrapper_cls" |
|
|
|
for _, child in module.named_modules(): |
|
if child in ignored_modules: |
|
continue |
|
try: |
|
assert not isinstance(child, cast(type, wrapper_cls)) |
|
except TypeError: |
|
|
|
pass |
|
|
|
|
|
nonwrapped_numel = sum( |
|
p.numel() for p in module.parameters() if p not in ignored_params |
|
) |
|
|
|
assert auto_wrap_policy is not None |
|
if auto_wrap_policy(module=module, recurse=True, nonwrapped_numel=nonwrapped_numel): |
|
total_wrapped_numel = 0 |
|
|
|
for name, child in module.named_children(): |
|
if child in ignored_modules: |
|
continue |
|
wrapped_child, num_wrapped_params = _recursive_wrap( |
|
module=child, |
|
auto_wrap_policy=auto_wrap_policy, |
|
wrapper_cls=wrapper_cls, |
|
ignored_modules=ignored_modules, |
|
ignored_params=ignored_params, |
|
**kwargs, |
|
) |
|
setattr(module, name, wrapped_child) |
|
|
|
total_wrapped_numel += num_wrapped_params |
|
|
|
|
|
remainder = nonwrapped_numel - total_wrapped_numel |
|
if not only_wrap_children and auto_wrap_policy( |
|
module=module, recurse=False, nonwrapped_numel=remainder |
|
): |
|
|
|
return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel |
|
else: |
|
return module, total_wrapped_numel |
|
return module, 0 |
|
|
|
|
|
class _ConfigAutoWrap: |
|
""" |
|
Helper class to wrap modules based on default config args via a context manager. |
|
See :func:`enable_wrap` for more information. |
|
""" |
|
|
|
in_autowrap_context: bool = False |
|
wrapper_cls: Optional[Callable] = None |
|
kwargs: dict[str, Any] = {} |
|
|
|
def __init__(self, **kwargs: dict[str, Any]): |
|
self.kwargs = kwargs |
|
|
|
@staticmethod |
|
def enable_autowrap_context(kwargs: Any) -> None: |
|
if _ConfigAutoWrap.in_autowrap_context: |
|
raise NotImplementedError( |
|
"You are already within an autowrap context and we currently do not supported nested autowrap." |
|
) |
|
_ConfigAutoWrap.in_autowrap_context = True |
|
|
|
assert "wrapper_cls" in kwargs.keys(), ( |
|
"Expected to pass in wrapper_cls arg into _ConfigAutoWrap." |
|
) |
|
_ConfigAutoWrap.wrapper_cls = cast(Callable, kwargs["wrapper_cls"]) |
|
del kwargs["wrapper_cls"] |
|
|
|
_ConfigAutoWrap.kwargs = kwargs |
|
|
|
@staticmethod |
|
def disable_autowrap_context() -> None: |
|
_ConfigAutoWrap.in_autowrap_context = False |
|
_ConfigAutoWrap.wrapper_cls = None |
|
_ConfigAutoWrap.kwargs = {} |
|
|
|
def __enter__(self) -> None: |
|
self.enable_autowrap_context(self.kwargs) |
|
|
|
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: |
|
self.disable_autowrap_context() |
|
|