File size: 14,079 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
from typing import Any, Callable, Dict, Generator, Optional, Set, Tuple, Type, cast
import torch.nn as nn
def default_auto_wrap_policy(
    module: nn.Module,
    recurse: bool,
    unwrapped_params: int,
    module_is_root: bool,
    # These are customizable for this default policy function.
    min_num_params: int = int(1e8),
    force_leaf_modules: Optional[Set[Type[nn.Module]]] = None,
    exclude_wrap_modules: Optional[Set[Type[nn.Module]]] = None,
    skip_params_check_for_root: bool = False,
) -> bool:
    """Default policy function for :func:`auto_wrap`.

    Decide whether :func:`auto_wrap` should recurse into, or wrap, a module.
    The first four parameters are supplied by :func:`auto_wrap`. A custom
    policy must accept at least those four and is otherwise free to decide
    however it likes.

    Args:
        module (nn.Module):
            The module being considered.
        recurse (bool):
            If True, the question is whether to recurse into the children of
            ``module``. If False, the question is whether to wrap ``module``
            itself.
        unwrapped_params (int):
            Number of parameters in ``module`` that are not yet wrapped.
        module_is_root (bool):
            Whether ``module`` is the root of the current auto_wrap call.
        min_num_params (int):
            Size threshold. A module qualifies only when
            ``unwrapped_params >= min_num_params``. (default: ``int(1e8)``)
        force_leaf_modules (Set[Type[nn.Module]], Optional):
            Module types that are never recursed into, so their children are
            never wrapped. When None, ``default_auto_wrap_policy.FORCE_LEAF_MODULES``
            is used.
        exclude_wrap_modules (Set[Type[nn.Module]], Optional):
            Module types that are never wrapped. When None,
            ``default_auto_wrap_policy.EXCLUDE_WRAP_MODULES`` is used.
        skip_params_check_for_root (bool):
            When True and ``module_is_root`` is True, the wrap decision
            (``recurse=False``) ignores the size threshold for the root. The
            exclusion list still applies.

    Returns:
        bool: True to recurse / wrap, False otherwise.
    """
    # Fall back to the defaults stored on the function object itself.
    if force_leaf_modules is None:
        force_leaf_modules = default_auto_wrap_policy.FORCE_LEAF_MODULES  # type: ignore
    if exclude_wrap_modules is None:
        exclude_wrap_modules = default_auto_wrap_policy.EXCLUDE_WRAP_MODULES  # type: ignore

    meets_threshold = unwrapped_params >= min_num_params

    if recurse:
        # Descend only when enough parameters remain and the type is not a forced leaf.
        return meets_threshold and not isinstance(module, tuple(force_leaf_modules))

    # Wrap decision: the root may bypass the size check when explicitly requested.
    size_ok = (module_is_root and skip_params_check_for_root) or meets_threshold
    return size_ok and not isinstance(module, tuple(exclude_wrap_modules))
# Default type sets consumed by default_auto_wrap_policy when its
# exclude_wrap_modules / force_leaf_modules arguments are None. They are stored
# on the function object so they can be imported and inspected from it.
setattr(default_auto_wrap_policy, "EXCLUDE_WRAP_MODULES", {nn.ModuleList, nn.ModuleDict})
setattr(default_auto_wrap_policy, "FORCE_LEAF_MODULES", {nn.MultiheadAttention})
def config_auto_wrap_policy(
    module: nn.Module,
    recurse: bool,
    unwrapped_params: int,
    module_is_root: bool,
) -> bool:
    """Config based policy function for :func:`auto_wrap`.

    Wrap exactly the modules that have already been tagged with a
    ``wrapper_config`` attribute.

    Args:
        module (nn.Module):
            The module being considered.
        recurse (bool):
            If True, the question is whether to recurse into the children of
            ``module``; this policy always answers yes. If False, the question
            is whether to wrap ``module`` itself.
        unwrapped_params (int):
            Number of not-yet-wrapped parameters. Ignored by this policy.
        module_is_root (bool):
            Whether ``module`` is the root. Ignored by this policy.

    Returns:
        bool: True when recursing, otherwise whether ``module`` has a
        ``wrapper_config`` attribute.
    """
    # Always descend so that tagged modules anywhere in the tree are reached.
    return True if recurse else hasattr(module, "wrapper_config")
@contextlib.contextmanager
def enable_wrap(auto_wrap_policy: Optional[Callable] = None, **wrapper_kwargs: Any) -> Generator[None, None, None]:
    """
    Context manager that turns on :func:`wrap` / :func:`auto_wrap` with a
    shared wrapper configuration.

    Inside the context, every :func:`wrap` call wraps its module with the same
    wrapper class and arguments. A key use case is sharding large layers
    in-place as they are constructed, so that initialization does not run out
    of system memory. Outside the context, :func:`wrap` and :func:`auto_wrap`
    return the module unchanged.

    Usage::

        with enable_wrap(wrapper_cls=FSDP, **params):
            # Wraps layer in FSDP by default if within context
            self.l1 = wrap(torch.nn.Linear(5, 5))
            self.l2 = auto_wrap(
                TransformerBlock(),
                # Wraps children modules based on a different min_num_params
                auto_wrap_policy=functools.partial(default_auto_wrap_policy, min_num_params=1e7)
            )

    Args:
        auto_wrap_policy (Callable, Optional):
            Policy used by :func:`auto_wrap` to decide what to recurse into and
            what to wrap (e.g. by size, or to exclude unsupported modules).
            Modules passed directly to :func:`wrap` ignore it and are always
            wrapped. (default: :func:`default_auto_wrap_policy`)
        **wrapper_kwargs:
            Must include ``wrapper_cls``, the callable used to wrap modules.
            An optional ``move_module_cuda_half`` flag makes :func:`wrap`
            call ``.cuda().half()`` on each module before wrapping it. All
            remaining entries are passed to every wrapper created inside the
            context.

    Nested contexts are not supported. Entering one while another is active
    raises ``NotImplementedError``.
    """
    autowrap_config = ConfigAutoWrap(auto_wrap_policy, **wrapper_kwargs)
    with autowrap_config:
        yield
def wrap(module: nn.Module, **wrap_overrides: Any) -> nn.Module:
    """
    Annotate that a module should be wrapped.

    The module is wrapped only inside an :func:`enable_wrap` context and is
    returned unchanged otherwise. This lets the same model code run with or
    without a wrapper.

    Wrapper arguments are merged from three sources, with later ones
    taking precedence:

        1. the :func:`enable_wrap` context's kwargs (``ConfigAutoWrap.kwargs``)
        2. ``module.wrapper_config`` (must be a dict when present)
        3. the ``wrap_overrides`` passed to this function

    Usage::

        with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
            # Wraps layer in FSDP by default if within context
            self.l1 = wrap(torch.nn.Linear(5, 5))

    Args:
        module (nn.Module): module to wrap (if in :func:`enable_wrap` context)
        **wrap_overrides: configuration overrides that will take priority over
            the values provided by the :func:`enable_wrap` context

    Returns:
        nn.Module: the wrapper instance inside a context, else ``module`` itself.
    """
    if not ConfigAutoWrap.in_autowrap_context:
        # No active enable_wrap(): the annotation is a no-op.
        return module

    per_module_config = {}
    if hasattr(module, "wrapper_config"):
        per_module_config = module.wrapper_config
        assert isinstance(per_module_config, dict)

    # Layer the configuration sources; each update overrides the previous one.
    merged_config = dict(ConfigAutoWrap.kwargs)
    merged_config.update(per_module_config)
    merged_config.update(wrap_overrides)

    wrapper_factory = ConfigAutoWrap.wrapper_cls
    assert wrapper_factory is not None
    if ConfigAutoWrap.move_module_cuda_half:
        module = module.cuda().half()
    return wrapper_factory(module, **merged_config)
def auto_wrap(module: nn.Module, auto_wrap_policy: Optional[Callable] = None, **kwargs: Any) -> nn.Module:
    """
    Recursively wrap ``module`` and its children with the *wrapper_cls* of the
    active :func:`enable_wrap` context.

    ``auto_wrap_policy`` decides which children are wrapped, and whether the
    remaining parameters of ``module`` itself are wrapped. This is useful for
    wrapping large, complex layers. Outside an :func:`enable_wrap` context the
    module is returned unchanged.

    .. note:: auto_wrap can only be applied to a module once. It assumes that
        none of the sub-modules are already wrapped and uses that assumption
        to count wrapped vs. unwrapped parameters. To wrap in multiple passes,
        pre-assign ``wrapper_config`` attributes to the sub-modules you want
        wrapped and then use ``config_auto_wrap_policy``.

    .. warning:: It is not recommended to use :func:`auto_wrap` with
        :class:`FullyShardedDataParallel` on modules that have shared
        parameters. Parameter sharing may be broken (i.e. end up not shared)
        if the shared parameters are not (auto-)wrapped under the same FSDP
        wrapper instance.

    Usage::

        with enable_wrap(wrapper_cls=FSDP, **params):
            # Wraps children modules.
            self.l1 = auto_wrap(TransformerBlock())

    Args:
        module (nn.Module):
            module to wrap (if in :func:`enable_wrap` context)
        auto_wrap_policy (Callable, Optional):
            Function deciding whether to recurse into / wrap each module. When
            None, the policy of the :func:`enable_wrap` context is used. That
            policy is :func:`default_auto_wrap_policy` unless the context was
            given another one.
        **kwargs:
            Wrapper overrides forwarded to :func:`wrap` for every module that
            gets wrapped.

    Returns:
        nn.Module: the (possibly) wrapped module.
    """
    if not ConfigAutoWrap.in_autowrap_context:
        # Outside enable_wrap() there is no wrapper to apply.
        return module

    # The second element (number of parameters wrapped) is not needed here.
    wrapped_module, _ = ConfigAutoWrap.recursive_wrap(
        module, auto_wrap_policy=auto_wrap_policy, module_is_root=True, **kwargs
    )
    return wrapped_module
class ConfigAutoWrap:
    """
    Helper class to wrap modules based on default config args via a context manager.
    See :func:`enable_wrap` for more information.

    The active configuration is stored in class-level attributes, which
    :func:`wrap` and :func:`auto_wrap` read directly. Consequently only one
    autowrap context can be active at a time. Entering a second one raises
    ``NotImplementedError``.
    """

    in_autowrap_context: bool = False  # Context flag
    move_module_cuda_half: bool = False  # A flag to control the wrap() function.
    wrapper_cls: Optional[Callable] = None  # The wrapper class
    kwargs: Dict[str, Any] = {}  # Wrapper's args
    auto_wrap_policy: Optional[Callable] = None  # Used only in auto_wrap

    def __init__(self, auto_wrap_policy: Optional[Callable] = None, **kwargs: Any):
        # These instance attributes only record the requested configuration.
        # The class-level attributes above are populated on __enter__.
        self.auto_wrap_policy = auto_wrap_policy
        self.kwargs = kwargs

    @staticmethod
    def enable_autowrap_context(auto_wrap_policy: Optional[Callable], kwargs: Any) -> None:
        """Activate the class-level autowrap configuration.

        Args:
            auto_wrap_policy (Callable, Optional):
                Policy for :func:`auto_wrap`. None selects
                :func:`default_auto_wrap_policy`.
            kwargs (Mapping[str, Any]):
                Must contain ``wrapper_cls``. It may contain ``move_module_cuda_half``.
                Both keys are stripped from the kwargs stored for the wrappers.
                The caller's mapping is copied, not modified, so the same
                ConfigAutoWrap instance can be entered again after it exits.

        Raises:
            NotImplementedError: if an autowrap context is already active.
            AssertionError: if ``wrapper_cls`` is missing.
        """
        if ConfigAutoWrap.in_autowrap_context:
            raise NotImplementedError(
                "You are already within an autowrap context and we currently do not supported nested autowrap."
            )
        # Validate and extract everything before touching class-level state.
        # __exit__ is not run when __enter__ raises, so a partial update here
        # would leave the context stuck "on" and break every later enable_wrap().
        assert "wrapper_cls" in kwargs
        wrapper_kwargs = dict(kwargs)
        has_half_flag = "move_module_cuda_half" in wrapper_kwargs
        half_flag = wrapper_kwargs.pop("move_module_cuda_half", False)
        wrapper_cls = wrapper_kwargs.pop("wrapper_cls")

        # Commit the new configuration.
        ConfigAutoWrap.in_autowrap_context = True
        if has_half_flag:
            ConfigAutoWrap.move_module_cuda_half = cast(bool, half_flag)
        ConfigAutoWrap.wrapper_cls = cast(Callable, wrapper_cls)
        ConfigAutoWrap.auto_wrap_policy = default_auto_wrap_policy if auto_wrap_policy is None else auto_wrap_policy
        # Save the rest; these become the base kwargs for every wrap() call.
        ConfigAutoWrap.kwargs = wrapper_kwargs

    @staticmethod
    def disable_autowrap_context() -> None:
        """Reset every class-level attribute to its inactive default."""
        ConfigAutoWrap.in_autowrap_context = False
        ConfigAutoWrap.move_module_cuda_half = False
        ConfigAutoWrap.wrapper_cls = None
        ConfigAutoWrap.kwargs = {}
        ConfigAutoWrap.auto_wrap_policy = None

    def __enter__(self) -> None:
        self.enable_autowrap_context(self.auto_wrap_policy, self.kwargs)

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        # Always reset; returning None lets any exception propagate.
        self.disable_autowrap_context()

    @staticmethod
    def recursive_wrap(
        module: nn.Module, auto_wrap_policy: Optional[Callable], module_is_root: bool, **kwargs: Any
    ) -> Tuple[nn.Module, int]:
        """
        Automatically wrap child modules of *module* that meet the given
        criteria with :func:`auto_wrap`.

        Args:
            module (nn.Module):
                module to recursively wrap
            auto_wrap_policy (Callable, Optional):
                optionally, override the :func:`auto_wrap_policy` from the context.
            module_is_root (bool):
                whether ``module`` is the top-level module of this auto_wrap
                call; forwarded to the policy.
            **kwargs:
                wrapper overrides forwarded to :func:`wrap` for every module
                wrapped during this traversal.

        Returns:
            (nn.Module, int):
                Wrapped module and the number parameters wrapped recursively.
        """
        if auto_wrap_policy is None:
            auto_wrap_policy = ConfigAutoWrap.auto_wrap_policy

        # Make sure no module in this subtree is already wrapped: the parameter
        # accounting below assumes every parameter is still unwrapped.
        for child in module.modules():
            assert not isinstance(child, cast(type, ConfigAutoWrap.wrapper_cls))

        # We count all params, assuming none of them is already wrapped.
        num_params = sum(p.numel() for p in module.parameters())

        assert auto_wrap_policy is not None
        if not auto_wrap_policy(module=module, recurse=True, unwrapped_params=num_params, module_is_root=module_is_root):
            return module, 0

        total_wrapped_params = 0
        # Iterate through the children, recursively wrap if necessary
        for name, child in module.named_children():
            wrapped_child, num_wrapped_params = ConfigAutoWrap.recursive_wrap(
                module=child, auto_wrap_policy=auto_wrap_policy, module_is_root=False, **kwargs
            )
            setattr(module, name, wrapped_child)
            # Keep track of how many parameters have been wrapped
            total_wrapped_params += num_wrapped_params

        # Decide whether the parameters not claimed by any wrapped child are
        # enough to justify wrapping the current module as well.
        remainder = num_params - total_wrapped_params
        if auto_wrap_policy(module=module, recurse=False, unwrapped_params=remainder, module_is_root=module_is_root):
            # Leaf node or final wrapping of the remainder both happen here.
            return wrap(module, **kwargs), num_params
        return module, total_wrapped_params
|