# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
from dataclasses import dataclass
from deepspeed import comm as dist
from deepspeed.constants import CROSS_RANK, CROSS_SIZE, LOCAL_RANK
from .data_parallel_writer_factory import DataParallelWriterFactory

# TODO: parse the per-machine socket count from the environment instead of hard-coding it.
SOCKETS_PER_MACHINE = 2


@dataclass
class MPUInfo(object):
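    """Per-rank view of a 3D (pipeline/tensor/data) model-parallel topology."""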
    pp_world_size: int
    pp_rank: int
    tp_world_size: int
    tp_rank: int
    dp_world_size: int
    dp_peer_ranks: list
    dp_rank: int


def _create_model_parallel_info(mpu):
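    """Collect this rank's pipeline, tensor, and data parallel coordinates from the given mpu object."""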
    return MPUInfo(pp_world_size=mpu.get_pipeline_model_parallel_world_size(),
                   pp_rank=mpu.get_pipeline_model_parallel_rank(),
                   tp_world_size=mpu.get_tensor_model_parallel_world_size(),
                   tp_rank=mpu.get_tensor_model_parallel_rank(),
                   dp_world_size=mpu.get_data_parallel_world_size(),
                   dp_peer_ranks=mpu.get_data_parallel_group_ranks(),
                   dp_rank=mpu.get_data_parallel_rank())


@dataclass
class ExpertParallelInfo(object):
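    """Per-rank view of the expert-parallel (MoE) topology and its data parallel replicas."""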
    ep_world_size: int
    ep_rank: int
    dp_world_size: int
    dp_peer_ranks: list
    dp_rank: int


def _create_expert_parallel_info(groups):
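    """Collect expert-parallel coordinates for the largest expert parallel group."""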
    group_name = groups._get_max_expert_size_name()
    return ExpertParallelInfo(ep_world_size=groups._get_expert_parallel_world_size(group_name),
                              ep_rank=groups._get_expert_parallel_rank(group_name),
                              dp_world_size=groups._get_expert_data_parallel_world_size(group_name),
                              dp_peer_ranks=groups._get_expert_data_parallel_group_ranks(group_name),
                              dp_rank=groups._get_expert_data_parallel_rank(group_name))


@dataclass
class UniversalParallelInfo(object):
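    """Aggregate view of the global, model-parallel, expert-parallel, and machine topology for one rank."""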
    global_world_size: int
    global_rank: int
    local_rank: int
    mpu_info: MPUInfo
    ep_info: ExpertParallelInfo
    pure_dp: bool
    num_machines: int
    machine_rank: int
    num_sockets: int


def create_universal_parallel_info(groups, has_moe_layers):
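    """Build a UniversalParallelInfo for the calling rank; mpu_info and ep_info are None when not applicable."""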
    return UniversalParallelInfo(global_world_size=dist.get_world_size(),
                                 global_rank=dist.get_rank(),
                                 local_rank=int(os.environ[LOCAL_RANK]),
                                 mpu_info=None if groups.mpu is None else _create_model_parallel_info(groups.mpu),
                                 ep_info=_create_expert_parallel_info(groups) if has_moe_layers else None,
                                 pure_dp=groups.mpu is None
                                 or groups.mpu.get_data_parallel_world_size() == dist.get_world_size(),
                                 num_machines=int(os.environ[CROSS_SIZE]),
                                 machine_rank=int(os.environ[CROSS_RANK]),
                                 num_sockets=int(os.environ[CROSS_SIZE]) * SOCKETS_PER_MACHINE)


def create_data_parallel_writer_config(groups, parallel_unit, zero_stage, has_moe_layers):
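    """Create a data parallel writer config for checkpoint I/O based on the current parallel topology."""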
    uni_parallel_info = create_universal_parallel_info(groups, has_moe_layers)
    writer_factory = DataParallelWriterFactory(uni_parallel_info, parallel_unit)
    return writer_factory.create_config(zero_stage, has_moe_layers)
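

# Minimal usage sketch (hypothetical caller, not part of this module; it
# assumes deepspeed.comm is initialized, `groups` is the initialized
# deepspeed.utils.groups module, and the zero_stage / has_moe_layers /
# parallel_unit values are illustrative placeholders):
#
#   from deepspeed.utils import groups
#   writer_config = create_data_parallel_writer_config(groups,
#                                                      parallel_unit=parallel_unit,
#                                                      zero_stage=3,
#                                                      has_moe_layers=False)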