# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
from typing import Any, Callable, Dict, Generator, Optional, Set, Tuple, Type, cast

import torch.nn as nn


def default_auto_wrap_policy(
    module: nn.Module,
    recurse: bool,
    unwrapped_params: int,
    module_is_root: bool,
    # These are customizable for this default policy function.
    min_num_params: int = int(1e8),
    force_leaf_modules: Optional[Set[Type[nn.Module]]] = None,
    exclude_wrap_modules: Optional[Set[Type[nn.Module]]] = None,
    skip_params_check_for_root: bool = False,
) -> bool:
    """Default policy function for :func:`auto_wrap`.

       Return if a module should be wrapped during :func:`auto_wrap`.

       The first four parameters are used by :func:`auto_wrap`. If
       you write a custom version of this policy function, your version
       needs to at least accept the first four parameters and is free
       to do whatever it wants in the function.

    Args:
       module (nn.Module):
           The module to be considered in this decision.
       recurse (bool):
           Indicate if this is called to make a decision on whether we
           should recurse down a subgraph of the module structure.
           If False, it means this function is called to make a decision
           on whether we should wrap the said module.
       unwrapped_params (int):
           The number of parameters yet to be wrapped in this module.
       module_is_root (bool):
           Indicates if current module is the root.
       min_num_params (int):
           Customizable policy input. It controls the size threshold
           on how big should a module be to be considered wrapped.
           The comparison is inclusive (``unwrapped_params >= min_num_params``).
       force_leaf_modules (Set[Type[nn.Module]]):
           Set of module types to keep as leaves, i.e., their children
           will never be wrapped. Defaults to
           ``default_auto_wrap_policy.FORCE_LEAF_MODULES`` when None.
       exclude_wrap_modules (Set[Type[nn.Module]]):
           Customizable set of module types to be excluded in wrapping.
           Defaults to ``default_auto_wrap_policy.EXCLUDE_WRAP_MODULES``
           when None.
       skip_params_check_for_root (bool):
           If module_is_root is True, then this includes the root in
           wrapping regardless of its number of unwrapped params.

    Returns:
       bool: When ``recurse`` is True, whether to descend into the module's
       children; otherwise, whether the module itself should be wrapped.
    """
    force_leaf_modules = (
        default_auto_wrap_policy.FORCE_LEAF_MODULES  # type: ignore
        if force_leaf_modules is None
        else force_leaf_modules
    )
    exclude_wrap_modules = (
        default_auto_wrap_policy.EXCLUDE_WRAP_MODULES  # type: ignore
        if exclude_wrap_modules is None
        else exclude_wrap_modules
    )

    is_large = unwrapped_params >= min_num_params
    if recurse:
        # We should recurse if the module is big enough but not in force_leaf_modules list.
        return is_large and not isinstance(module, tuple(force_leaf_modules))
    else:
        # If we are not recursing, determine if we should wrap.
        return ((module_is_root and skip_params_check_for_root) or is_large) and not isinstance(
            module, tuple(exclude_wrap_modules)
        )


# Set those defaults to the default_auto_wrap_policy function. Make them easy to be imported.
default_auto_wrap_policy.EXCLUDE_WRAP_MODULES = {nn.ModuleList, nn.ModuleDict}  # type: ignore
default_auto_wrap_policy.FORCE_LEAF_MODULES = {nn.MultiheadAttention}  # type: ignore


def config_auto_wrap_policy(
    module: nn.Module,
    recurse: bool,
    unwrapped_params: int,
    module_is_root: bool,
) -> bool:
    """Config based policy function for :func:`auto_wrap`.

       Return true for a module to be wrapped if it is already tagged with
       a ``wrapper_config`` attribute.

    Args:
       module (nn.Module):
           The module to be considered in this decision.
       recurse (bool):
           Indicate if this is called to make a decision on whether we
           should recurse down a subgraph of the module structure.
           If False, it means this function is called to make a decision
           on whether we should wrap the said module.
       unwrapped_params (int):
           The number of parameters yet to be wrapped in this module.
           Unused by this function.
       module_is_root (bool):
           Indicates if current module is the root.
           Unused by this function.

    Returns:
       bool: Always True when ``recurse`` is True; otherwise whether
       ``module`` has a ``wrapper_config`` attribute.
    """
    if recurse:
        # We should always recurse.
        return True
    else:
        # If we are not recursing, determine if we should wrap.
        return hasattr(module, "wrapper_config")


@contextlib.contextmanager
def enable_wrap(auto_wrap_policy: Optional[Callable] = None, **wrapper_kwargs: Any) -> Generator[None, None, None]:
    """
    Context manager to wrap modules using a wrapper.

    Useful for when you'd like to apply the same parameters to all child modules
    that you wrap. A particularly important use case is wrapping large layers so
    that they get sharded (in-place) during initialization, to avoid running out of
    system memory. Large layers can indicate that they should be sharded via
    the ``wrap`` annotation and this context manager can provide the
    exact configuration for these nested instances.

    Usage::

        with enable_wrap(**params):
            # Wraps layer in FSDP by default if within context
            self.l1 = wrap(torch.nn.Linear(5, 5))
            self.l2 = auto_wrap(
                TransformerBlock(),
                # Wraps children modules based on a different min_num_params
                auto_wrap_policy=functools.partial(default_auto_wrap_policy, min_num_params=1e7)
            )

    Args:
        auto_wrap_policy (Callable, Optional):
            Custom function to control how to do :func:`auto_wrap`. This is
            useful to exclude unsupported modules or wrap based on sizes when
            wrapping recursively. Note: modules annotated with :func:`wrap`
            ignore this policy and will always be wrapped.
            (default: :func:`default_auto_wrap_policy`)
        **wrapper_kwargs:
            Configuration settings that will be passed to all ``wrap``
            instances inside the context. Must contain ``wrapper_cls``; may
            contain ``move_module_cuda_half``. Both are consumed by the context
            and are not forwarded to the wrapper.
    """
    with ConfigAutoWrap(auto_wrap_policy, **wrapper_kwargs):
        yield


def wrap(module: nn.Module, **wrap_overrides: Any) -> nn.Module:
    """
    Annotate that a module should be wrapped. Annotated modules will only be
    wrapped if inside of an :func:`enable_wrap` context manager. This allows
    a module to be initialized both with and without a wrapper without code
    change.

    The wrapper configuration can be taken from 3 sources with
    increasing priority:

        1. ConfigAutoWrap's context
        2. module.wrapper_config
        3. wrap_overrides argument of this function

    Usage::

        with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
            # Wraps layer in FSDP by default if within context
            self.l1 = wrap(torch.nn.Linear(5, 5))

    Args:
        module (nn.Module): module to wrap (if in :func:`enable_wrap` context)
        **wrap_overrides: configuration overrides that will take priority over
            the values provided by the :func:`enable_wrap` context

    Returns:
        nn.Module: The result of calling the context's wrapper on ``module``
        when inside an :func:`enable_wrap` context, otherwise ``module`` itself.
    """
    if ConfigAutoWrap.in_autowrap_context:
        module_overrides = {}
        if hasattr(module, "wrapper_config"):
            module_overrides = module.wrapper_config
            assert isinstance(module_overrides, dict)
        # Later sources win: context < module.wrapper_config < explicit overrides.
        wrap_overrides = {**ConfigAutoWrap.kwargs, **module_overrides, **wrap_overrides}
        assert ConfigAutoWrap.wrapper_cls is not None
        if ConfigAutoWrap.move_module_cuda_half:
            module = module.cuda().half()
        return ConfigAutoWrap.wrapper_cls(module, **wrap_overrides)
    return module


def auto_wrap(module: nn.Module, auto_wrap_policy: Optional[Callable] = None, **kwargs: Any) -> nn.Module:
    """
    Annotate that a module should be wrapped with the *wrapper_cls* from the
    :func:`enable_wrap` context (if the context exists) and recursively wrap
    children modules that meet the criteria given by :func:`auto_wrap_policy`. This
    is useful for wrapping large complex layers.

    .. note:: auto_wrap can only be applied to a module once because it
        assumes none of the sub-modules is already wrapped and uses that
        assumption to compute the wrapped vs. unwrapped parameters.
        To get around this limitation, users can pre-assign ``wrapper_config``
        attributes to the sub-modules they want to wrap (in multiple passes)
        and then use the ``config_auto_wrap_policy``.

    .. warning:: It is not recommended to use :func:`auto_wrap` with
        :class:`FullyShardedDataParallel` on modules that have shared
        parameters, as the parameter sharing may be broken (i.e. end up not
        shared) if the shared parameters are not (auto-)wrapped under the same
        FSDP wrapper instance.

    Usage::

        with enable_wrap(**params):
            # Wraps children modules.
            self.l1 = auto_wrap(TransformerBlock())

    Args:
        module (nn.Module):
            module to wrap (if in :func:`enable_wrap` context)
        auto_wrap_policy (Callable):
            a function to determine whether a Module should be wrapped.
            (default: the policy of the enclosing :func:`enable_wrap` context)
        **kwargs:
            extra keyword arguments forwarded to :func:`wrap` for every
            module that gets wrapped.

    Returns:
        nn.Module: The (possibly wrapped) module when inside an
        :func:`enable_wrap` context, otherwise ``module`` unchanged.
    """
    if ConfigAutoWrap.in_autowrap_context:
        # The second element (number of wrapped params) is only needed by the recursion.
        wrapped_module, _ = ConfigAutoWrap.recursive_wrap(
            module, auto_wrap_policy=auto_wrap_policy, module_is_root=True, **kwargs
        )
        return wrapped_module
    return module


class ConfigAutoWrap:
    """
    Helper class to wrap modules based on default config args via a context manager.
    See :func:`enable_wrap` for more information.

    The active configuration is stored in class attributes, so only one
    context can be active at a time (nesting raises ``NotImplementedError``).
    """

    in_autowrap_context: bool = False  # Context flag
    move_module_cuda_half: bool = False  # A flag to control the wrap() function.
    wrapper_cls: Optional[Callable] = None  # The wrapper class
    kwargs: Dict[str, Any] = {}  # Wrapper's args
    auto_wrap_policy: Optional[Callable] = None  # Used only in auto_wrap

    def __init__(self, auto_wrap_policy: Optional[Callable] = None, **kwargs: Any):
        """Remember the policy and wrapper settings to install on ``__enter__``."""
        self.auto_wrap_policy = auto_wrap_policy
        self.kwargs = kwargs

    @staticmethod
    def enable_autowrap_context(auto_wrap_policy: Optional[Callable], kwargs: Any) -> None:
        """Install the given policy and wrapper settings as the active context.

        Args:
            auto_wrap_policy (Callable, Optional): policy used by
                :func:`auto_wrap`; :func:`default_auto_wrap_policy` when None.
            kwargs: mapping of wrapper settings. Must contain ``wrapper_cls``
                and may contain ``move_module_cuda_half``; everything else is
                passed to the wrapper by :func:`wrap`. The mapping itself is
                not modified.

        Raises:
            NotImplementedError: if a context is already active.
            AssertionError: if ``wrapper_cls`` is missing from ``kwargs``.
        """
        if ConfigAutoWrap.in_autowrap_context:
            raise NotImplementedError(
                "You are already within an autowrap context and we currently do not supported nested autowrap."
            )
        # Work on a copy so the caller's dict (e.g. ``self.kwargs``) stays intact
        # and the same ConfigAutoWrap instance can be entered more than once.
        settings = dict(kwargs)
        # Validate and extract everything *before* touching class state, so a
        # failed entry cannot leave ``in_autowrap_context`` stuck at True
        # (``__exit__`` is never called when ``__enter__`` raises).
        assert "wrapper_cls" in settings, "a `wrapper_cls` keyword argument is required"
        wrapper_cls = cast(Callable, settings.pop("wrapper_cls"))
        # When the flag is not given, keep whatever value is currently set.
        move_module_cuda_half = cast(
            bool, settings.pop("move_module_cuda_half", ConfigAutoWrap.move_module_cuda_half)
        )

        ConfigAutoWrap.in_autowrap_context = True
        ConfigAutoWrap.move_module_cuda_half = move_module_cuda_half
        ConfigAutoWrap.wrapper_cls = wrapper_cls
        # Save the rest.
        ConfigAutoWrap.auto_wrap_policy = default_auto_wrap_policy if auto_wrap_policy is None else auto_wrap_policy
        ConfigAutoWrap.kwargs = settings

    @staticmethod
    def disable_autowrap_context() -> None:
        """Reset all class-level context state back to its defaults."""
        ConfigAutoWrap.in_autowrap_context = False
        ConfigAutoWrap.move_module_cuda_half = False
        ConfigAutoWrap.wrapper_cls = None
        ConfigAutoWrap.kwargs = {}
        ConfigAutoWrap.auto_wrap_policy = None

    def __enter__(self) -> None:
        self.enable_autowrap_context(self.auto_wrap_policy, self.kwargs)

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        self.disable_autowrap_context()

    @staticmethod
    def recursive_wrap(
        module: nn.Module, auto_wrap_policy: Optional[Callable], module_is_root: bool, **kwargs: Any
    ) -> Tuple[nn.Module, int]:
        """
        Automatically wrap child modules of *module* that meet the given
        criteria with :func:`auto_wrap`.

        Args:
            module (nn.Module):
                module to recursively wrap
            auto_wrap_policy (Callable, Optional):
                optionally, override the :func:`auto_wrap_policy` from the context.
            module_is_root (bool):
                whether ``module`` is the root passed to :func:`auto_wrap`;
                forwarded to the policy.
            **kwargs:
                forwarded to :func:`wrap` for every module that gets wrapped.

        Returns:
            (nn.Module, int):
                Wrapped module and the number parameters wrapped recursively.
        """
        if auto_wrap_policy is None:
            auto_wrap_policy = ConfigAutoWrap.auto_wrap_policy

        # Make sure no sub-module (including ``module`` itself) is already wrapped.
        for _, child in module.named_modules():
            assert not isinstance(child, cast(type, ConfigAutoWrap.wrapper_cls))

        # We count all params, assuming none of them is already wrapped.
        num_params = sum(p.numel() for p in module.parameters())

        assert auto_wrap_policy is not None
        if auto_wrap_policy(module=module, recurse=True, unwrapped_params=num_params, module_is_root=module_is_root):
            total_wrapped_params = 0
            # Iterate through the children, recursively wrap if necessary.
            # Snapshot the children first since they are reassigned in the loop.
            for name, child in list(module.named_children()):
                wrapped_child, num_wrapped_params = ConfigAutoWrap.recursive_wrap(
                    module=child, auto_wrap_policy=auto_wrap_policy, module_is_root=False, **kwargs
                )
                setattr(module, name, wrapped_child)
                # Keep track of how many parameters have been wrapped
                total_wrapped_params += num_wrapped_params
            # Decide if we need to wrap the current module, i.e. whether the
            # left over parameters exceed the number of params to wrap.
            remainder = num_params - total_wrapped_params
            if auto_wrap_policy(
                module=module, recurse=False, unwrapped_params=remainder, module_is_root=module_is_root
            ):
                # Leaf node or final wrapping of the remainder both happen here.
                return wrap(module, **kwargs), num_params
            else:
                return module, total_wrapped_params
        return module, 0