jamtur01's picture
Upload folder using huggingface_hub
9c6594c verified
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
from typing import Any, Callable, Dict, Generator, Optional, Set, Tuple, Type, cast
import torch.nn as nn
def default_auto_wrap_policy(
    module: nn.Module,
    recurse: bool,
    unwrapped_params: int,
    module_is_root: bool,
    # Everything below is a tunable knob specific to this default policy.
    min_num_params: int = int(1e8),
    force_leaf_modules: Optional[Set[Type[nn.Module]]] = None,
    exclude_wrap_modules: Optional[Set[Type[nn.Module]]] = None,
    skip_params_check_for_root: bool = False,
) -> bool:
    """Size-based policy function for :func:`auto_wrap`.

    The policy is queried in two modes, selected by ``recurse``: "should the
    traversal descend into this module?" and "should this module itself be
    wrapped?". :func:`auto_wrap` always supplies the first four arguments, so
    any custom replacement policy has to accept at least those four.

    Args:
        module (nn.Module):
            The module the decision is about.
        recurse (bool):
            If True, decide whether to descend into the module's children.
            If False, decide whether to wrap the module itself.
        unwrapped_params (int):
            Number of parameters in this module that are not wrapped yet.
        module_is_root (bool):
            Whether ``module`` is the root of the traversal.
        min_num_params (int):
            Size threshold; a module counts as large when
            ``unwrapped_params >= min_num_params``.
        force_leaf_modules (Set[Type[nn.Module]]):
            Module types that are never descended into. When None, falls
            back to ``default_auto_wrap_policy.FORCE_LEAF_MODULES``.
        exclude_wrap_modules (Set[Type[nn.Module]]):
            Module types that are never wrapped themselves. When None, falls
            back to ``default_auto_wrap_policy.EXCLUDE_WRAP_MODULES``.
        skip_params_check_for_root (bool):
            If True, the root module passes the size check regardless of
            how many unwrapped parameters it has.

    Returns:
        bool: the answer to the question selected by ``recurse``.
    """
    if force_leaf_modules is None:
        force_leaf_modules = default_auto_wrap_policy.FORCE_LEAF_MODULES  # type: ignore
    if exclude_wrap_modules is None:
        exclude_wrap_modules = default_auto_wrap_policy.EXCLUDE_WRAP_MODULES  # type: ignore

    meets_threshold = unwrapped_params >= min_num_params

    if recurse:
        # Descend only into large modules that are not pinned as leaves.
        return meets_threshold and not isinstance(module, tuple(force_leaf_modules))

    # Wrap decision: the root may bypass the size check when asked to, but an
    # excluded module type is never wrapped.
    root_bypass = module_is_root and skip_params_check_for_root
    size_ok = root_bypass or meets_threshold
    return size_ok and not isinstance(module, tuple(exclude_wrap_modules))


# The fallback sets are stored on the function object itself so that they can
# be imported (and adjusted) together with the policy.
default_auto_wrap_policy.FORCE_LEAF_MODULES = {nn.MultiheadAttention}  # type: ignore
default_auto_wrap_policy.EXCLUDE_WRAP_MODULES = {nn.ModuleList, nn.ModuleDict}  # type: ignore
def config_auto_wrap_policy(
    module: nn.Module,
    recurse: bool,
    unwrapped_params: int,
    module_is_root: bool,
) -> bool:
    """Attribute-driven policy function for :func:`auto_wrap`.

    A module is selected for wrapping exactly when it carries a
    ``wrapper_config`` attribute; traversal into children is never pruned.

    Args:
        module (nn.Module):
            The module the decision is about.
        recurse (bool):
            If True, the question is whether to descend into the module's
            children (always answered with True). If False, the question is
            whether to wrap the module itself.
        unwrapped_params (int):
            Number of parameters not wrapped yet. Ignored by this policy.
        module_is_root (bool):
            Whether ``module`` is the root. Ignored by this policy.

    Returns:
        bool: True when recursing, otherwise whether ``module`` has a
        ``wrapper_config`` attribute.
    """
    # Always keep walking the tree; only the wrap decision looks at the module.
    return True if recurse else hasattr(module, "wrapper_config")
@contextlib.contextmanager
def enable_wrap(auto_wrap_policy: Optional[Callable] = None, **wrapper_kwargs: Any) -> Generator[None, None, None]:
    """
    Context manager that activates :func:`wrap` and :func:`auto_wrap`.

    Inside the ``with`` block, modules passed to :func:`wrap` or
    :func:`auto_wrap` are wrapped using the configuration given here, so the
    same settings do not have to be repeated for every child module. One
    important use is wrapping large layers as they are constructed, so that
    they are sharded (in-place) during initialization instead of exhausting
    system memory.

    Usage::

        with enable_wrap(**params):
            # Wraps layer in FSDP by default if within context
            self.l1 = wrap(torch.nn.Linear(5, 5))
            self.l2 = auto_wrap(
                TransformerBlock(),
                # Wraps children modules based on a different min_num_params
                auto_wrap_policy=functools.partial(default_auto_wrap_policy, min_num_params=1e7)
            )

    Args:
        auto_wrap_policy (Callable, Optional):
            Function controlling which modules :func:`auto_wrap` wraps, e.g.
            to skip unsupported modules or to wrap by size. Modules annotated
            with :func:`wrap` ignore the policy and are always wrapped.
            (default: :func:`default_auto_wrap_policy`)
        **wrapper_kwargs:
            Settings forwarded to every ``wrap`` call made inside the
            context. Must include ``wrapper_cls``; may include
            ``move_module_cuda_half``.
    """
    wrap_context = ConfigAutoWrap(auto_wrap_policy, **wrapper_kwargs)
    with wrap_context:
        yield
def wrap(module: nn.Module, **wrap_overrides: Any) -> nn.Module:
    """
    Mark a module as "to be wrapped".

    The annotation only takes effect inside an :func:`enable_wrap` context;
    outside of one the module is returned untouched, so the same model code
    can be built with or without a wrapper.

    The wrapper class always comes from the active context. The keyword
    arguments passed to it are merged from three sources, later ones taking
    priority:

        1. the kwargs of the active context
        2. ``module.wrapper_config`` (must be a dict when present)
        3. the ``wrap_overrides`` given to this call

    Usage::

        with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
            # Wraps layer in FSDP by default if within context
            self.l1 = wrap(torch.nn.Linear(5, 5))

    Args:
        module (nn.Module): module to wrap (if in :func:`enable_wrap` context)
        **wrap_overrides: configuration overrides that take priority over
            both the context's values and ``module.wrapper_config``

    Returns:
        nn.Module: the wrapped module, or ``module`` itself when no context
        is active.
    """
    if not ConfigAutoWrap.in_autowrap_context:
        # Outside of enable_wrap() the annotation is a no-op.
        return module

    per_module_config = module.wrapper_config if hasattr(module, "wrapper_config") else {}
    assert isinstance(per_module_config, dict)

    # Dict unpacking order implements the priority: context < module < call site.
    effective_config = {**ConfigAutoWrap.kwargs, **per_module_config, **wrap_overrides}

    wrapper_factory = ConfigAutoWrap.wrapper_cls
    assert wrapper_factory is not None
    if ConfigAutoWrap.move_module_cuda_half:
        # The context asked for the module to be moved to GPU in half
        # precision before it is handed to the wrapper.
        module = module.cuda().half()
    return wrapper_factory(module, **effective_config)
def auto_wrap(module: nn.Module, auto_wrap_policy: Optional[Callable] = None, **kwargs: Any) -> nn.Module:
    """
    Recursively wrap ``module`` and its children with the *wrapper_cls* of
    the active :func:`enable_wrap` context, wrapping whichever modules the
    policy selects. Outside of a context the module is returned unchanged.
    Handy for large, deeply nested layers.

    .. note:: auto_wrap can only be applied to a module once because it
        assumes none of the sub-modules is already wrapped and uses that
        assumption to compute the wrapped vs. unwrapped parameters.
        To work around this, pre-assign ``wrapper_config`` attributes to the
        sub-modules that should be wrapped (possibly in several passes) and
        then use ``config_auto_wrap_policy``.

    .. warning:: It is not recommended to use :func:`auto_wrap` with
        :class:`FullyShardedDataParallel` on modules that have shared
        parameters, as the parameter sharing may be broken (i.e. end up not
        shared) if the shared parameters are not (auto-)wrapped under the same
        FSDP wrapper instance.

    Usage::

        with enable_wrap(**params):
            # Wraps children modules.
            self.l1 = auto_wrap(TransformerBlock())

    Args:
        module (nn.Module):
            module to wrap (if in :func:`enable_wrap` context)
        auto_wrap_policy (Callable):
            function deciding which modules get wrapped; when None, the
            policy of the active context is used.
            (default: wrap if > 100M parameters)
        **kwargs:
            forwarded to ``ConfigAutoWrap.recursive_wrap``, which passes
            them on to :func:`wrap` as overrides.

    Returns:
        nn.Module: the (possibly wrapped) root module.
    """
    if not ConfigAutoWrap.in_autowrap_context:
        return module

    # The second element is the count of wrapped parameters; only the
    # recursion itself needs it.
    wrapped_root, _ = ConfigAutoWrap.recursive_wrap(
        module, auto_wrap_policy=auto_wrap_policy, module_is_root=True, **kwargs
    )
    return wrapped_root
class ConfigAutoWrap:
    """
    Helper class to wrap modules based on default config args via a context manager.

    The active configuration is kept in class attributes, which is how the
    module-level ``wrap`` / ``auto_wrap`` helpers find it. Entering a context
    while another one is active raises ``NotImplementedError``.
    See :func:`enable_wrap` for more information.
    """

    in_autowrap_context: bool = False  # Context flag
    move_module_cuda_half: bool = False  # A flag to control the wrap() function.
    wrapper_cls: Optional[Callable] = None  # The wrapper class
    kwargs: Dict[str, Any] = {}  # Wrapper's args
    auto_wrap_policy: Optional[Callable] = None  # Used only in auto_wrap

    def __init__(self, auto_wrap_policy: Optional[Callable] = None, **kwargs: Any):
        """Remember the policy and wrapper settings to activate on ``__enter__``.

        Args:
            auto_wrap_policy: policy for :func:`auto_wrap`; None selects
                :func:`default_auto_wrap_policy` when the context is entered.
            **kwargs: wrapper settings. ``wrapper_cls`` is required by the
                time the context is entered; ``move_module_cuda_half`` is
                optional. Everything else is passed to the wrapper.
        """
        self.auto_wrap_policy = auto_wrap_policy
        self.kwargs = kwargs

    @staticmethod
    def enable_autowrap_context(auto_wrap_policy: Optional[Callable], kwargs: Any) -> None:
        """Publish the given settings as the active autowrap context.

        Args:
            auto_wrap_policy: policy to activate; None selects
                :func:`default_auto_wrap_policy`.
            kwargs: mapping of wrapper settings; must contain ``wrapper_cls``.
                The mapping itself is left unmodified.

        Raises:
            NotImplementedError: if a context is already active.
            AssertionError: if ``wrapper_cls`` is missing from ``kwargs``.
        """
        if ConfigAutoWrap.in_autowrap_context:
            raise NotImplementedError(
                "You are already within an autowrap context and we currently do not supported nested autowrap."
            )

        # Validate and resolve everything BEFORE touching class-level state.
        # If this method raises, __exit__ is never called, so a flag set too
        # early would stay set and block every later context.
        assert "wrapper_cls" in kwargs.keys(), "enable_wrap requires a 'wrapper_cls' argument"
        # Work on a copy: popping from the caller's dict would make it
        # impossible to enter the same ConfigAutoWrap instance twice.
        settings = dict(kwargs)
        wrapper_cls = cast(Callable, settings.pop("wrapper_cls"))
        policy = default_auto_wrap_policy if auto_wrap_policy is None else auto_wrap_policy

        ConfigAutoWrap.in_autowrap_context = True
        if "move_module_cuda_half" in settings:
            ConfigAutoWrap.move_module_cuda_half = cast(bool, settings.pop("move_module_cuda_half"))
        ConfigAutoWrap.wrapper_cls = wrapper_cls
        # Whatever is left is forwarded to the wrapper by wrap().
        ConfigAutoWrap.auto_wrap_policy = policy
        ConfigAutoWrap.kwargs = settings

    @staticmethod
    def disable_autowrap_context() -> None:
        """Reset every class-level setting to its inactive default."""
        ConfigAutoWrap.in_autowrap_context = False
        ConfigAutoWrap.move_module_cuda_half = False
        ConfigAutoWrap.wrapper_cls = None
        ConfigAutoWrap.kwargs = {}
        ConfigAutoWrap.auto_wrap_policy = None

    def __enter__(self) -> None:
        self.enable_autowrap_context(self.auto_wrap_policy, self.kwargs)

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        self.disable_autowrap_context()

    @staticmethod
    def recursive_wrap(
        module: nn.Module, auto_wrap_policy: Optional[Callable], module_is_root: bool, **kwargs: Any
    ) -> Tuple[nn.Module, int]:
        """
        Automatically wrap child modules of *module* that meet the given
        criteria with :func:`auto_wrap`.

        Args:
            module (nn.Module):
                module to recursively wrap
            auto_wrap_policy (Callable, Optional):
                optionally, override the :func:`auto_wrap_policy` from the context.
            module_is_root (bool):
                whether *module* is the root of the traversal; forwarded to
                the policy.
            **kwargs:
                passed to :func:`wrap` as overrides for every module that
                gets wrapped.

        Returns:
            (nn.Module, int):
                Wrapped module and the number parameters wrapped recursively.
        """
        if auto_wrap_policy is None:
            auto_wrap_policy = ConfigAutoWrap.auto_wrap_policy

        # Make sure nothing in this subtree (the module itself included) is
        # already wrapped; the parameter accounting below relies on that.
        wrapper_type = cast(type, ConfigAutoWrap.wrapper_cls)
        for submodule in module.modules():
            assert not isinstance(submodule, wrapper_type)

        # We count all params, assuming none of them is already wrapped.
        num_params = sum(p.numel() for p in module.parameters())

        assert auto_wrap_policy is not None
        if auto_wrap_policy(module=module, recurse=True, unwrapped_params=num_params, module_is_root=module_is_root):
            total_wrapped_params = 0
            # Iterate through the children, recursively wrap if necessary
            for name, child in module.named_children():
                wrapped_child, num_wrapped_params = ConfigAutoWrap.recursive_wrap(
                    module=child, auto_wrap_policy=auto_wrap_policy, module_is_root=False, **kwargs
                )
                setattr(module, name, wrapped_child)
                # Keep track of how many parameters have been wrapped
                total_wrapped_params += num_wrapped_params
            # Decide whether the current module needs wrapping too, based on
            # the parameters that no child wrapper has claimed.
            remainder = num_params - total_wrapped_params
            if auto_wrap_policy(
                module=module, recurse=False, unwrapped_params=remainder, module_is_root=module_is_root
            ):
                # Leaf node or final wrapping of the remainder both happen here.
                return wrap(module, **kwargs), num_params
            else:
                return module, total_wrapped_params
        return module, 0