# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import torch

try:
    from torch._subclasses.fake_tensor import unset_fake_temporarily
except ImportError:
    # Unsupported torch version: fall back to a no-op context manager so the
    # profiler can still run when fake-tensor utilities are unavailable.
    from contextlib import nullcontext as unset_fake_temporarily

import deepspeed
import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator


def sync_all():
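    # Flush pending device work, then barrier so all ranks start and stop timing together.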
    get_accelerator().synchronize()
    dist.barrier()


def get_bw(comm_op, size, duration):
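    """Return (throughput, bus bandwidth) in bytes/s for a collective.

    busbw applies the usual nccl-tests-style scaling factors:
    (n - 1) / n for all_gather / all_to_all, 2 * (n - 1) / n for all_reduce,
    and no scaling for pt2pt / broadcast.
    """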
    n = dist.get_world_size()
    tput = 0
    busbw = 0

    if duration == 0:
        raise ValueError("Error. Duration is 0.")

    if comm_op == "all_to_all":
        tput = (size / duration)
        busbw = (size / duration) * ((n - 1) / n)
    elif comm_op == "all_gather":
        size *= n
        tput = (size / duration)
        busbw = (size / duration) * ((n - 1) / n)
    elif comm_op == "all_reduce":
        tput = (size * 2 / duration)
        busbw = (size / duration) * (2 * (n - 1) / n)
    elif comm_op == "pt2pt" or comm_op == "broadcast":
        tput = (size / duration)
        busbw = tput
    else:
        raise ValueError("wrong comm_op specified")

    return tput, busbw


# Time all_gather_into_tensor and return (message size, average duration)
def timed_all_gather(device, input, output, start_event, end_event, warmup, trials, async_op):
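    """Time `trials` back-to-back all_gather_into_tensor calls with accelerator events.

    Returns (total gathered size in bytes, per-trial duration in seconds,
    averaged over trials and, when world_size > 1, across ranks).
    """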
    sync_all()
    # Warmups, establish connections, etc.
    for i in range(warmup):
        dist.all_gather_into_tensor(output, input, async_op=async_op)
    sync_all()

    # Time the actual comm op `trials` times and average the duration
    start_event.record()
    for i in range(trials):
        dist.all_gather_into_tensor(output, input, async_op=async_op)
    end_event.record()
    sync_all()
    duration = start_event.elapsed_time(end_event) / 1000

    # Compute the per-trial average duration and the total gathered message size in bytes
    avg_duration = duration / trials
    size = input.element_size() * input.nelement() * dist.get_world_size()
    # tput, busbw = get_bw('all_gather', size, avg_duration)

    avg_duration_ten = torch.tensor([avg_duration], device=device)
    if dist.get_world_size() > 1:
        dist.all_reduce(avg_duration_ten, dist.ReduceOp.AVG)

    return size, avg_duration_ten.item()


def run_all_gather(device, dtype, maxsize, warmup=5, trials=10, async_op=False):
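    """Profile all_gather over a sweep of message sizes.

    For each total element count 2**1 .. 2**(maxsize - 1), each rank contributes
    roughly count // world_size elements of `dtype`. Returns a list of
    (total gathered bytes, average duration in seconds) tuples, seeded with
    (0, 0); the sweep stops early if a size runs out of device memory.
    """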

    # Collect rank/world-size info and create timing events
    global_rank = dist.get_rank()
    world_size = dist.get_world_size()

    start_event = get_accelerator().Event(enable_timing=True)
    end_event = get_accelerator().Event(enable_timing=True)

    # Create list of per-rank element counts for the size sweep
    M_LIST = []
    for x in (2**p for p in range(1, maxsize)):
        m = x // world_size
        if m > 0:
            M_LIST.append(m)

    results = [(0, 0)]  # Seed with the origin so the interpolation covers size 0
    sync_all()
    # loop over various tensor sizes
    for M in M_LIST:
        global_rank = dist.get_rank()
        try:
            mat = torch.ones(M, dtype=dtype, device=device)
            sync_all()
            input = ((mat.mul_(float(global_rank))).view(-1))
            # Drop the extra reference to mat (input is a view sharing its storage)
            del mat
            get_accelerator().empty_cache()
            output = torch.zeros(input.nelement() * world_size, dtype=dtype, device=device)
        except RuntimeError as e:
            if 'out of memory' in str(e):
                if dist.get_rank() == 0:
                    print('WARNING: Ran out of GPU memory. Exiting comm op.')
                sync_all()
                break
            else:
                raise e
        sync_all()
        results.append(timed_all_gather(device, input, output, start_event, end_event, warmup, trials, async_op))

    return results


profile_results = None


def create_predictor():
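    """Build a function mapping all_gather message size (bytes) to predicted time (seconds).

    The first call profiles all_gather once (outside fake-tensor mode, since it
    may be invoked while FakeTensors are active) and caches the results in a
    module-level global; later calls reuse them. A linear scipy interp1d fit,
    extrapolating beyond the measured range, is returned wrapped so that size 0
    maps to 0.
    """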
    global profile_results
    if profile_results is None:
        with unset_fake_temporarily():
            device = get_accelerator().current_device()
            profile_results = run_all_gather(device, torch.bfloat16, 31)
        if dist.get_rank() == 0:
            for size, avg_duration in profile_results:
                print(f"size: {size}, avg_duration: {avg_duration}")

    # Extract size and avg_duration from results
    sizes = [result[0] for result in profile_results]
    durations = [result[1] for result in profile_results]

    try:
        from scipy.interpolate import interp1d
    except ImportError:
        raise RuntimeError("Please install scipy to use communication profiler in DeepCompile")

    predictor = interp1d(sizes, durations, kind='linear', fill_value="extrapolate")

    def f(size):
        if size == 0:
            return 0
        return predictor(size)

    # Return the size -> duration predictor
    return f


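# Standalone profiling run: launch under a distributed launcher that sets
# LOCAL_RANK (e.g. torchrun or the deepspeed launcher).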
if __name__ == "__main__":
    local_rank = int(os.environ['LOCAL_RANK'])
    get_accelerator().set_device(local_rank)
    print(f"local_rank={local_rank}")

    deepspeed.init_distributed(dist_backend='nccl')

    # Create predictor function
    predictor = create_predictor()

    # Predict time for a specific data size
    example_size = 1e9
    predicted_time = predictor(example_size)
    print(f"Predicted time for size {example_size}: {predicted_time:.6f} seconds")

    dist.destroy_process_group()