File size: 736 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from typing import List

import torch.distributed as dist

from .checkpoint import checkpoint_wrapper
from .data_parallel import FullyShardedDataParallel

if dist.is_available():
    # Prevent import failure if dist is not available. #1057
    from .data_parallel import ShardedDataParallel
    from .moe import MOELayer, Top2Gate
    from .pipe import Pipe, PipeRPCWrapper

from .misc import FlattenParamsWrapper
from .wrap import auto_wrap, config_auto_wrap_policy, default_auto_wrap_policy, enable_wrap, wrap

# Deliberately left empty: `from <this package> import *` binds no names.
# The classes and helpers imported above are still reachable as attributes of
# this package. Import them by name; do not rely on a wildcard import.
# Some of them (ShardedDataParallel, MOELayer, Top2Gate, Pipe, PipeRPCWrapper)
# exist only when `dist.is_available()` is true. For that reason, listing them
# here unconditionally would make a wildcard import fail on builds without
# distributed support.
__all__: List[str] = list()