File size: 14,079 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
from typing import Any, Callable, Dict, Generator, Optional, Set, Tuple, Type, cast

import torch.nn as nn


def default_auto_wrap_policy(
    module: nn.Module,
    recurse: bool,
    unwrapped_params: int,
    module_is_root: bool,
    # These are customizable for this default policy function.
    min_num_params: int = int(1e8),
    force_leaf_modules: Optional[Set[Type[nn.Module]]] = None,
    exclude_wrap_modules: Optional[Set[Type[nn.Module]]] = None,
    skip_params_check_for_root: bool = False,
) -> bool:
    """Size-based default policy for :func:`auto_wrap`.

       Answers one of two questions about ``module``, selected by
       ``recurse``: "should we descend into its children?" or "should
       the module itself be wrapped?".

       :func:`auto_wrap` always supplies the first four parameters. A
       custom policy must accept at least those four; the remaining
       parameters are knobs specific to this default policy.

    Args:
       module (nn.Module):
           The module under consideration.
       recurse (bool):
           If True, decide whether to recurse into the children of
           ``module``. If False, decide whether ``module`` itself
           should be wrapped.
       unwrapped_params (int):
           Number of parameters in ``module`` that are not wrapped yet.
       module_is_root (bool):
           Whether ``module`` is the root of the wrapping pass.

       min_num_params (int):
           Size threshold. A module counts as "large" when
           ``unwrapped_params >= min_num_params``.
       force_leaf_modules (Set[Type[nn.Module]]):
           Module types treated as leaves: the policy never recurses
           into them, so their children are never wrapped. Falls back
           to ``default_auto_wrap_policy.FORCE_LEAF_MODULES`` when None.
       exclude_wrap_modules (Set[Type[nn.Module]]):
           Module types that are never wrapped themselves. Falls back
           to ``default_auto_wrap_policy.EXCLUDE_WRAP_MODULES`` when None.
       skip_params_check_for_root (bool):
           If True, the root module is eligible for wrapping regardless
           of its number of unwrapped parameters.

    Returns:
        bool: the answer to the question selected by ``recurse``.
    """
    if force_leaf_modules is None:
        force_leaf_modules = default_auto_wrap_policy.FORCE_LEAF_MODULES  # type: ignore
    if exclude_wrap_modules is None:
        exclude_wrap_modules = default_auto_wrap_policy.EXCLUDE_WRAP_MODULES  # type: ignore

    meets_size_threshold = unwrapped_params >= min_num_params

    if recurse:
        # Descend only into large modules that are not forced to stay leaves.
        leaf_types = tuple(force_leaf_modules)
        return meets_size_threshold and not isinstance(module, leaf_types)

    # Wrap decision: the root may bypass the size check; excluded types never wrap.
    size_check_waived = module_is_root and skip_params_check_for_root
    eligible = size_check_waived or meets_size_threshold
    return eligible and not isinstance(module, tuple(exclude_wrap_modules))


# Attach the default module-type sets to the policy function as attributes so they are
# easy to import. default_auto_wrap_policy falls back to these sets whenever its
# ``exclude_wrap_modules`` / ``force_leaf_modules`` argument is None.
default_auto_wrap_policy.EXCLUDE_WRAP_MODULES = {nn.ModuleList, nn.ModuleDict}  # type: ignore
default_auto_wrap_policy.FORCE_LEAF_MODULES = {nn.MultiheadAttention}  # type: ignore


def config_auto_wrap_policy(
    module: nn.Module,
    recurse: bool,
    unwrapped_params: int,
    module_is_root: bool,
) -> bool:
    """Tag-driven policy function for :func:`auto_wrap`.

       A module is wrapped exactly when it carries a ``wrapper_config``
       attribute; recursion into children is always allowed.

    Args:
       module (nn.Module):
           The module under consideration.
       recurse (bool):
           If True, the question is whether to recurse into the children
           of ``module`` (always answered with True). If False, the
           question is whether ``module`` itself should be wrapped.
       unwrapped_params (int):
           Number of parameters not wrapped yet. Ignored by this policy.
       module_is_root (bool):
           Whether ``module`` is the root. Ignored by this policy.

    Returns:
        bool: True when recursing, otherwise whether ``module`` has a
        ``wrapper_config`` attribute.
    """
    # Recursion is unconditional; the wrap decision hinges only on the tag.
    return True if recurse else hasattr(module, "wrapper_config")


@contextlib.contextmanager
def enable_wrap(auto_wrap_policy: Optional[Callable] = None, **wrapper_kwargs: Any) -> Generator[None, None, None]:
    """
    Context manager that activates wrapping for :func:`wrap` and
    :func:`auto_wrap` calls made inside it.

    This is handy when the same wrapper configuration should apply to every
    child module being wrapped. An important use case is wrapping large layers
    so that they get sharded (in-place) during initialization, which avoids
    running out of system memory. Layers mark themselves with the ``wrap``
    annotation and this context supplies the exact configuration for those
    nested instances.

    Usage::

        with enable_wrap(**params):
            # Wraps layer in FSDP by default if within context
            self.l1 = wrap(torch.nn.Linear(5, 5))
            self.l2 = auto_wrap(
                TransformerBlock(),
                # Wraps children modules based on a different min_num_params
                auto_wrap_policy=functools.partial(default_auto_wrap_policy, min_num_params=1e7)
            )

    Args:
        auto_wrap_policy (Callable, Optional):
            Custom function controlling how :func:`auto_wrap` decides what to
            wrap, e.g. to exclude unsupported modules or to wrap based on
            size when wrapping recursively. Note: modules annotated with
            :func:`wrap` ignore this policy and will always be wrapped.
            (default: :func:`default_auto_wrap_policy`)
        **wrapper_kwargs:
            Configuration passed to every ``wrap`` instance inside the
            context. ``ConfigAutoWrap`` requires a ``wrapper_cls`` entry (the
            wrapper to apply) and also consumes an optional
            ``move_module_cuda_half`` flag; the rest is forwarded to the
            wrapper.

    Raises:
        NotImplementedError: if another autowrap context is already active
            (nesting is not supported).
    """
    # All state handling lives in ConfigAutoWrap; this is just the functional entry point.
    context = ConfigAutoWrap(auto_wrap_policy, **wrapper_kwargs)
    with context:
        yield


def wrap(module: nn.Module, **wrap_overrides: Any) -> nn.Module:
    """
    Mark a module as one that should be wrapped. The wrapping only happens
    inside an :func:`enable_wrap` context; outside of it the module is
    returned untouched. This lets the same model code be built both with and
    without a wrapper.

    The wrapper class always comes from the active context. The keyword
    arguments handed to it are merged from 3 sources with increasing
    priority:

        1. ConfigAutoWrap's context
        2. module.wrapper_config
        3. wrap_overrides argument of this function

    Usage::

        with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
            # Wraps layer in FSDP by default if within context
            self.l1 = wrap(torch.nn.Linear(5, 5))

    Args:
        module (nn.Module): module to wrap (if in :func:`enable_wrap` context)
        **wrap_overrides: configuration overrides that take priority over
            the values provided by the :func:`enable_wrap` context and by
            ``module.wrapper_config``

    Returns:
        nn.Module: the wrapped module when a context is active, otherwise
        ``module`` itself.
    """
    if not ConfigAutoWrap.in_autowrap_context:
        return module

    # Per-module settings are optional; when present they must be a plain dict.
    per_module_config = getattr(module, "wrapper_config", {})
    assert isinstance(per_module_config, dict)

    # Later sources win: context < module.wrapper_config < explicit overrides.
    merged_config = {**ConfigAutoWrap.kwargs, **per_module_config, **wrap_overrides}

    wrapper = ConfigAutoWrap.wrapper_cls
    assert wrapper is not None
    if ConfigAutoWrap.move_module_cuda_half:
        # Context asked for the module to be moved to GPU in half precision before wrapping.
        module = module.cuda().half()
    return wrapper(module, **merged_config)


def auto_wrap(module: nn.Module, auto_wrap_policy: Optional[Callable] = None, **kwargs: Any) -> nn.Module:
    """
    Mark a module for wrapping with the *wrapper_cls* of the surrounding
    :func:`enable_wrap` context (if there is one) and recursively wrap the
    children modules that satisfy :func:`auto_wrap_policy`. Useful for
    wrapping large complex layers.

    .. note:: auto_wrap can only be applied to a module once because it
        assumes none of the sub-modules is already wrapped and uses that
        assumption to compute the wrapped vs. unwrapped parameters.
        To get around this limitation, users can pre-assign ``wrapper_config``
        attributes to the sub-modules they want to wrap (in multiple passes)
        and then use the ``config_auto_wrap_policy``.

    .. warning:: It is not recommended to use :func:`auto_wrap` with
        :class:`FullyShardedDataParallel` on modules that have shared
        parameters, as the parameter sharing may be broken (i.e. end up not
        shared) if the shared parameters are not (auto-)wrapped under the same
        FSDP wrapper instance.

    Usage::

        with enable_wrap(**params):
            # Wraps children modules.
            self.l1 = auto_wrap(TransformerBlock())

    Args:
        module (nn.Module):
            module to wrap (if in :func:`enable_wrap` context)
        auto_wrap_policy (Callable):
            function deciding whether a module should be recursed into and
            whether it should be wrapped. When None, the policy of the
            active context is used.
        **kwargs:
            forwarded to ``ConfigAutoWrap.recursive_wrap``, which passes them
            on to :func:`wrap` as overrides for every module it wraps.

    Returns:
        nn.Module: the (possibly wrapped) module when a context is active,
        otherwise ``module`` unchanged.
    """
    if not ConfigAutoWrap.in_autowrap_context:
        return module

    # recursive_wrap also reports a wrapped-parameter count, which is not needed at the top level.
    wrapped_root, _ = ConfigAutoWrap.recursive_wrap(
        module, auto_wrap_policy=auto_wrap_policy, module_is_root=True, **kwargs
    )
    return wrapped_root


class ConfigAutoWrap:
    """
    Helper class to wrap modules based on default config args via a context manager.
    See :func:`enable_wrap` for more information.

    The active configuration is stored in class attributes, so it is shared by
    every user of this class and only one context can be active at a time.
    """

    in_autowrap_context: bool = False  # Context flag
    move_module_cuda_half: bool = False  # A flag to control the wrap() function.
    wrapper_cls: Optional[Callable] = None  # The wrapper class
    kwargs: Dict[str, Any] = {}  # Wrapper's args
    auto_wrap_policy: Optional[Callable] = None  # Used only in auto_wrap

    def __init__(self, auto_wrap_policy: Optional[Callable] = None, **kwargs: Any):
        """
        Record the configuration to activate on ``__enter__``.

        Args:
            auto_wrap_policy (Callable, Optional):
                policy used by :func:`auto_wrap`; None selects
                :func:`default_auto_wrap_policy` when the context is entered.
            **kwargs:
                must contain ``wrapper_cls``; may contain
                ``move_module_cuda_half``; everything else is forwarded to
                the wrapper by :func:`wrap`.
        """
        self.auto_wrap_policy = auto_wrap_policy
        self.kwargs = kwargs

    @staticmethod
    def enable_autowrap_context(auto_wrap_policy: Optional[Callable], kwargs: Any) -> None:
        """
        Activate the class-level autowrap configuration.

        Args:
            auto_wrap_policy (Callable, Optional):
                policy to install; None installs :func:`default_auto_wrap_policy`.
            kwargs (dict):
                configuration mapping. ``wrapper_cls`` is required,
                ``move_module_cuda_half`` is optional; both are consumed here
                and the remaining entries become the wrapper's keyword
                arguments. The mapping passed in is not modified.

        Raises:
            NotImplementedError: if a context is already active.
            AssertionError: if ``wrapper_cls`` is missing from ``kwargs``.
        """
        if ConfigAutoWrap.in_autowrap_context:
            raise NotImplementedError(
                "You are already within an autowrap context and we currently do not supported nested autowrap."
            )
        # Validate before touching any class-level state: if this raised after the
        # context flag was set, __exit__ would never run and the flag would stay stuck.
        assert "wrapper_cls" in kwargs, "enable_wrap requires a wrapper_cls keyword argument"
        # Work on a copy so the caller's mapping (e.g. the instance's kwargs) stays
        # intact and the same ConfigAutoWrap instance can be entered again.
        kwargs = dict(kwargs)
        # Get and save the wrapper cls for the context.
        ConfigAutoWrap.wrapper_cls = cast(Callable, kwargs.pop("wrapper_cls"))
        if "move_module_cuda_half" in kwargs:
            ConfigAutoWrap.move_module_cuda_half = cast(bool, kwargs.pop("move_module_cuda_half"))
        # Save the rest.
        ConfigAutoWrap.auto_wrap_policy = default_auto_wrap_policy if auto_wrap_policy is None else auto_wrap_policy
        ConfigAutoWrap.kwargs = kwargs
        # Flip the flag last so a failure above never leaves a half-enabled context.
        ConfigAutoWrap.in_autowrap_context = True

    @staticmethod
    def disable_autowrap_context() -> None:
        """Reset every class-level setting back to its inactive default."""
        ConfigAutoWrap.in_autowrap_context = False
        ConfigAutoWrap.move_module_cuda_half = False
        ConfigAutoWrap.wrapper_cls = None
        ConfigAutoWrap.kwargs = {}
        ConfigAutoWrap.auto_wrap_policy = None

    def __enter__(self) -> None:
        self.enable_autowrap_context(self.auto_wrap_policy, self.kwargs)

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        # Returns None, so exceptions raised inside the context propagate.
        self.disable_autowrap_context()

    @staticmethod
    def recursive_wrap(
        module: nn.Module, auto_wrap_policy: Optional[Callable], module_is_root: bool, **kwargs: Any
    ) -> Tuple[nn.Module, int]:
        """
        Automatically wrap child modules of *module* that meet the given
        criteria with :func:`auto_wrap`.

        Args:
            module (nn.Module):
                module to recursively wrap
            auto_wrap_policy (Callable, Optional):
                optionally, override the :func:`auto_wrap_policy` from the context.
            module_is_root (bool):
                whether *module* is the root of the wrapping pass; passed
                through to the policy.
            **kwargs:
                overrides handed to :func:`wrap` for every module wrapped here.

        Returns:
            (nn.Module, int):
                Wrapped module and the number parameters wrapped recursively.
        """
        if auto_wrap_policy is None:
            auto_wrap_policy = ConfigAutoWrap.auto_wrap_policy

        # Make sure no child is already wrapped.
        wrapper_type = cast(type, ConfigAutoWrap.wrapper_cls)
        for _, child in module.named_modules():
            assert not isinstance(child, wrapper_type)

        # We count all params, assuming none of them is already wrapped.
        num_params = sum(p.numel() for p in module.parameters())

        assert auto_wrap_policy is not None
        if not auto_wrap_policy(
            module=module, recurse=True, unwrapped_params=num_params, module_is_root=module_is_root
        ):
            # The policy declined to look inside this module: nothing gets wrapped here.
            return module, 0

        total_wrapped_params = 0
        # Iterate through the children, recursively wrap if necessary
        for name, child in module.named_children():
            wrapped_child, num_wrapped_params = ConfigAutoWrap.recursive_wrap(
                module=child, auto_wrap_policy=auto_wrap_policy, module_is_root=False, **kwargs
            )
            setattr(module, name, wrapped_child)
            # Keep track of how many parameters have been wrapped
            total_wrapped_params += num_wrapped_params

        # Decide if the current module itself needs wrapping, based on the
        # parameters that are left over after wrapping its children.
        remainder = num_params - total_wrapped_params
        if auto_wrap_policy(module=module, recurse=False, unwrapped_params=remainder, module_is_root=module_is_root):
            # Leaf node or final wrapping of the remainder both happen here.
            return wrap(module, **kwargs), num_params
        return module, total_wrapped_params