|
|
|
|
|
|
|
|
|
|
|
import contextlib |
|
from typing import Any, Callable, Dict, Generator, Optional, Set, Tuple, Type, cast |
|
|
|
import torch.nn as nn |
|
|
|
|
|
def default_auto_wrap_policy(
    module: nn.Module,
    recurse: bool,
    unwrapped_params: int,
    module_is_root: bool,
    # Everything below is policy-specific and can be bound with functools.partial.
    min_num_params: int = int(1e8),
    force_leaf_modules: Optional[Set[Type[nn.Module]]] = None,
    exclude_wrap_modules: Optional[Set[Type[nn.Module]]] = None,
    skip_params_check_for_root: bool = False,
) -> bool:
    """Size-based default policy for :func:`auto_wrap`.

    Decides either whether to descend into ``module`` (``recurse=True``) or
    whether to wrap ``module`` itself (``recurse=False``).

    :func:`auto_wrap` always supplies the first four parameters by keyword.
    A custom policy must accept at least those four; beyond that it is free
    to take any extra, policy-specific parameters.

    Args:
        module (nn.Module):
            The module the decision is about.
        recurse (bool):
            ``True`` when the question is "should the children of this
            module be visited?"; ``False`` when the question is "should
            this module itself be wrapped?".
        unwrapped_params (int):
            Number of parameters of this module that are not wrapped yet.
        module_is_root (bool):
            Whether ``module`` is the root module of the wrapping pass.

        min_num_params (int):
            Size threshold: a module counts as large when
            ``unwrapped_params >= min_num_params``.
        force_leaf_modules (Set[Type[nn.Module]]):
            Module types that are kept as leaves, i.e. their children are
            never visited. Defaults to
            ``default_auto_wrap_policy.FORCE_LEAF_MODULES``.
        exclude_wrap_modules (Set[Type[nn.Module]]):
            Module types that are never wrapped themselves. Defaults to
            ``default_auto_wrap_policy.EXCLUDE_WRAP_MODULES``.
        skip_params_check_for_root (bool):
            When ``True`` and ``module_is_root`` is ``True``, the size
            threshold is ignored for the wrap decision on the root.

    Returns:
        bool: the recurse decision or the wrap decision, depending on
        ``recurse``.
    """
    if force_leaf_modules is None:
        force_leaf_modules = default_auto_wrap_policy.FORCE_LEAF_MODULES  # type: ignore
    if exclude_wrap_modules is None:
        exclude_wrap_modules = default_auto_wrap_policy.EXCLUDE_WRAP_MODULES  # type: ignore

    meets_threshold = unwrapped_params >= min_num_params

    if recurse:
        # Descend only into modules that are large and not forced to stay leaves.
        if not meets_threshold:
            return False
        return not isinstance(module, tuple(force_leaf_modules))

    # Wrap decision: the root may bypass the size check when asked to.
    root_bypass = module_is_root and skip_params_check_for_root
    if not (root_bypass or meets_threshold):
        return False
    return not isinstance(module, tuple(exclude_wrap_modules))


# Default type sets, attached to the function so callers can inspect or
# override them without passing arguments on every call.
# Container types that are never wrapped themselves.
default_auto_wrap_policy.EXCLUDE_WRAP_MODULES = {nn.ModuleList, nn.ModuleDict}  # type: ignore
# Module types whose children are never visited.
default_auto_wrap_policy.FORCE_LEAF_MODULES = {nn.MultiheadAttention}  # type: ignore
|
|
|
|
|
def config_auto_wrap_policy(
    module: nn.Module,
    recurse: bool,
    unwrapped_params: int,
    module_is_root: bool,
) -> bool:
    """Attribute-driven policy for :func:`auto_wrap`.

    A module is selected for wrapping exactly when it carries a
    ``wrapper_config`` attribute; recursion into children is always allowed.

    Args:
        module (nn.Module):
            The module the decision is about.
        recurse (bool):
            ``True`` when deciding whether to visit the children of
            ``module``; ``False`` when deciding whether to wrap ``module``
            itself.
        unwrapped_params (int):
            Number of parameters not wrapped yet. Accepted for interface
            compatibility; not used by this policy.
        module_is_root (bool):
            Whether ``module`` is the root. Accepted for interface
            compatibility; not used by this policy.

    Returns:
        bool: ``True`` for every recurse query; for a wrap query, whether
        ``module`` has a ``wrapper_config`` attribute.
    """
    # Always descend; wrap only modules that were tagged ahead of time.
    return True if recurse else hasattr(module, "wrapper_config")
|
|
|
|
|
@contextlib.contextmanager
def enable_wrap(auto_wrap_policy: Optional[Callable] = None, **wrapper_kwargs: Any) -> Generator[None, None, None]:
    """
    Context manager that activates wrapping for :func:`wrap` and :func:`auto_wrap`.

    Inside the ``with`` block every :func:`wrap` / :func:`auto_wrap` call
    uses the configuration given here, so the same settings apply to all
    child modules that get wrapped. A typical use is wrapping large layers
    while the model is being constructed, with the layers marking
    themselves via ``wrap`` and this context supplying the actual
    configuration. Outside the block those calls return the module unchanged.

    Usage::

        with enable_wrap(wrapper_cls=FSDP, **params):
            # Wrapped with FSDP because we are inside the context
            self.l1 = wrap(torch.nn.Linear(5, 5))
            self.l2 = auto_wrap(
                TransformerBlock(),
                # Wraps children modules based on a different min_num_params
                auto_wrap_policy=functools.partial(default_auto_wrap_policy, min_num_params=1e7)
            )

    Args:
        auto_wrap_policy (Callable, Optional):
            Policy function used by :func:`auto_wrap` to decide which
            modules to visit and wrap. Modules passed to :func:`wrap` are
            wrapped regardless of this policy.
            (default: :func:`default_auto_wrap_policy`)
        **wrapper_kwargs:
            Settings for the context. ``wrapper_cls`` is required and names
            the wrapper to apply; ``move_module_cuda_half`` is an optional
            flag; all remaining entries are passed as keyword arguments to
            every wrapper created inside the context.
    """
    wrap_context = ConfigAutoWrap(auto_wrap_policy, **wrapper_kwargs)
    with wrap_context:
        yield
|
|
|
|
|
def wrap(module: nn.Module, **wrap_overrides: Any) -> nn.Module:
    """
    Mark a module for wrapping.

    The module is wrapped only when the call happens inside an
    :func:`enable_wrap` context; otherwise it is returned unchanged. This
    lets the same model code run both with and without a wrapper.

    The keyword arguments handed to the wrapper are merged from three
    sources, later ones taking priority:

        1. the keyword arguments of the active :func:`enable_wrap` context
        2. ``module.wrapper_config`` (must be a dict, if present)
        3. ``wrap_overrides`` given to this call

    Usage::

        with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
            # Wrapped with FSDP because we are inside the context
            self.l1 = wrap(torch.nn.Linear(5, 5))

    Args:
        module (nn.Module): module to wrap (if in :func:`enable_wrap` context)
        **wrap_overrides: keyword arguments that override both the context
            settings and ``module.wrapper_config``

    Returns:
        nn.Module: the wrapper's result inside a context, else ``module``.
    """
    if not ConfigAutoWrap.in_autowrap_context:
        return module

    # A missing attribute behaves like an empty per-module config.
    module_overrides = getattr(module, "wrapper_config", {})
    assert isinstance(module_overrides, dict)
    merged_kwargs = {**ConfigAutoWrap.kwargs, **module_overrides, **wrap_overrides}

    assert ConfigAutoWrap.wrapper_cls is not None
    if ConfigAutoWrap.move_module_cuda_half:
        # Convert before wrapping so the wrapper receives the converted module.
        module = module.cuda().half()
    return ConfigAutoWrap.wrapper_cls(module, **merged_kwargs)
|
|
|
|
|
def auto_wrap(module: nn.Module, auto_wrap_policy: Optional[Callable] = None, **kwargs: Any) -> nn.Module:
    """
    Recursively wrap ``module`` and its children according to a policy.

    Inside an :func:`enable_wrap` context, every module selected by the
    policy is wrapped with the context's *wrapper_cls*; outside of a
    context ``module`` is returned unchanged. This is useful for wrapping
    large, complex layers.

    .. note:: auto_wrap can only be applied to a module once, because it
        assumes that no sub-module is wrapped yet and relies on that to
        count wrapped vs. unwrapped parameters. To work around this,
        pre-assign ``wrapper_config`` attributes to the sub-modules that
        should be wrapped (possibly over several passes) and then use
        ``config_auto_wrap_policy``.

    .. warning:: Using :func:`auto_wrap` with
        :class:`FullyShardedDataParallel` on modules with shared parameters
        is not recommended: the sharing may be broken (the parameters end
        up not shared) unless all shared parameters are (auto-)wrapped
        under the same FSDP wrapper instance.

    Usage::

        with enable_wrap(**params):
            # Wraps children modules.
            self.l1 = auto_wrap(TransformerBlock())

    Args:
        module (nn.Module):
            module to wrap (if in :func:`enable_wrap` context)
        auto_wrap_policy (Callable):
            policy function deciding which modules are visited and wrapped.
            When ``None``, the policy of the enclosing context is used.
        **kwargs:
            forwarded to the recursive wrapping step, which passes them on
            as overrides for each wrapper it creates.

    Returns:
        nn.Module: the (possibly wrapped) module.
    """
    if not ConfigAutoWrap.in_autowrap_context:
        return module

    # Only the module matters here; the wrapped-parameter count is bookkeeping
    # for the recursion itself.
    wrapped_module, _ = ConfigAutoWrap.recursive_wrap(
        module, auto_wrap_policy=auto_wrap_policy, module_is_root=True, **kwargs
    )
    return wrapped_module
|
|
|
|
|
class ConfigAutoWrap:
    """
    Helper class to wrap modules based on default config args via a context manager.
    See :func:`enable_wrap` for more information.

    The active configuration lives in class attributes, so it is shared by
    all instances; only one context can be active at a time (nesting raises
    ``NotImplementedError``).
    """

    # True while a context is active.
    in_autowrap_context: bool = False
    # When True, :func:`wrap` calls ``module.cuda().half()`` before wrapping.
    move_module_cuda_half: bool = False
    # Callable applied to a module to wrap it; None outside of a context.
    wrapper_cls: Optional[Callable] = None
    # Default keyword arguments passed to ``wrapper_cls`` by :func:`wrap`.
    kwargs: Dict[str, Any] = {}
    # Policy used by :meth:`recursive_wrap` when none is passed explicitly.
    auto_wrap_policy: Optional[Callable] = None

    def __init__(self, auto_wrap_policy: Optional[Callable] = None, **kwargs: Any):
        """Remember the settings; nothing is activated until ``__enter__``.

        Args:
            auto_wrap_policy (Callable, Optional):
                policy for :meth:`recursive_wrap`; ``None`` selects
                :func:`default_auto_wrap_policy` when the context is entered.
            **kwargs:
                context settings. Must contain ``wrapper_cls``; may contain
                ``move_module_cuda_half``; everything else becomes default
                keyword arguments for the wrapper.
        """
        self.auto_wrap_policy = auto_wrap_policy
        self.kwargs = kwargs

    @staticmethod
    def enable_autowrap_context(auto_wrap_policy: Optional[Callable], kwargs: Any) -> None:
        """Activate the class-level wrapping configuration.

        Args:
            auto_wrap_policy (Callable, Optional):
                policy to install; ``None`` installs
                :func:`default_auto_wrap_policy`.
            kwargs:
                mapping of context settings (see ``__init__``). It is not
                modified.

        Raises:
            NotImplementedError: if a context is already active.
            AssertionError: if ``wrapper_cls`` is missing from ``kwargs``.
        """
        if ConfigAutoWrap.in_autowrap_context:
            raise NotImplementedError(
                "You are already within an autowrap context and we currently do not supported nested autowrap."
            )

        # Work on a copy: removing keys from the caller's dict would make it
        # impossible to enter the same ConfigAutoWrap instance a second time.
        wrapper_kwargs = dict(kwargs)

        # Validate and extract the wrapper before touching any class state, so
        # a failure here cannot leave ``in_autowrap_context`` stuck at True
        # (``__exit__`` is not called when ``__enter__`` raises).
        assert "wrapper_cls" in wrapper_kwargs, "enable_wrap requires a 'wrapper_cls' keyword argument"
        wrapper_cls = cast(Callable, wrapper_kwargs.pop("wrapper_cls"))

        ConfigAutoWrap.in_autowrap_context = True
        if "move_module_cuda_half" in wrapper_kwargs:
            ConfigAutoWrap.move_module_cuda_half = cast(bool, wrapper_kwargs.pop("move_module_cuda_half"))
        ConfigAutoWrap.wrapper_cls = wrapper_cls
        # Fall back to the default policy when none was given.
        ConfigAutoWrap.auto_wrap_policy = default_auto_wrap_policy if auto_wrap_policy is None else auto_wrap_policy
        # What is left are the default keyword arguments for the wrapper.
        ConfigAutoWrap.kwargs = wrapper_kwargs

    @staticmethod
    def disable_autowrap_context() -> None:
        """Reset every class-level setting to its inactive default."""
        ConfigAutoWrap.in_autowrap_context = False
        ConfigAutoWrap.move_module_cuda_half = False
        ConfigAutoWrap.wrapper_cls = None
        ConfigAutoWrap.kwargs = {}
        ConfigAutoWrap.auto_wrap_policy = None

    def __enter__(self) -> None:
        self.enable_autowrap_context(self.auto_wrap_policy, self.kwargs)

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        # Returns None, so an exception raised inside the block propagates.
        self.disable_autowrap_context()

    @staticmethod
    def recursive_wrap(
        module: nn.Module, auto_wrap_policy: Optional[Callable], module_is_root: bool, **kwargs: Any
    ) -> Tuple[nn.Module, int]:
        """
        Automatically wrap child modules of *module* that meet the given
        criteria with :func:`auto_wrap`.

        Args:
            module (nn.Module):
                module to recursively wrap
            auto_wrap_policy (Callable, Optional):
                optionally, override the :func:`auto_wrap_policy` from the context.
            module_is_root (bool):
                whether *module* is the root of the wrapping pass; passed
                through to the policy.
            **kwargs:
                forwarded to :func:`wrap` as overrides for every wrapper
                created during the recursion.

        Returns:
            (nn.Module, int):
                Wrapped module and the number parameters wrapped recursively.
                When the policy declines to recurse into *module*, the module
                is returned unchanged with a count of 0.
        """
        if auto_wrap_policy is None:
            auto_wrap_policy = ConfigAutoWrap.auto_wrap_policy

        # Make sure nothing in this subtree (including *module* itself) is
        # already wrapped. The check is only possible when the wrapper is a
        # real class; a factory callable cannot be used with isinstance().
        wrapper_cls = ConfigAutoWrap.wrapper_cls
        if isinstance(wrapper_cls, type):
            for _, child in module.named_modules():
                assert not isinstance(child, wrapper_cls)

        # Total number of parameters in this subtree.
        num_params = sum(p.numel() for p in module.parameters())

        assert auto_wrap_policy is not None
        if auto_wrap_policy(module=module, recurse=True, unwrapped_params=num_params, module_is_root=module_is_root):
            total_wrapped_params = 0
            # Wrap the children first and replace each one in place.
            for name, child in module.named_children():
                wrapped_child, num_wrapped_params = ConfigAutoWrap.recursive_wrap(
                    module=child, auto_wrap_policy=auto_wrap_policy, module_is_root=False, **kwargs
                )
                setattr(module, name, wrapped_child)
                # Keep track of how many parameters the children already cover.
                total_wrapped_params += num_wrapped_params

            # Decide on the module itself using only the parameters that no
            # child wrapper has taken care of.
            remainder = num_params - total_wrapped_params
            if auto_wrap_policy(
                module=module, recurse=False, unwrapped_params=remainder, module_is_root=module_is_root
            ):
                # Wrapping this module covers all of its parameters.
                return wrap(module, **kwargs), num_params
            else:
                return module, total_wrapped_params
        return module, 0
|
|