# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import os
import torch
try:
    from torch._subclasses.fake_tensor import unset_fake_temporarily
except ImportError:
    # Unsupported torch version
    pass
import deepspeed
import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator


def sync_all():
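    """Synchronize the local accelerator and barrier across all ranks."""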
    get_accelerator().synchronize()
    dist.barrier()


def get_bw(comm_op, size, duration):
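    """Return (algorithmic bandwidth, bus bandwidth) in bytes/second for a collective.

    The bus-bandwidth scaling factors follow the convention used by nccl-tests,
    e.g. (n - 1) / n for all_gather and 2 * (n - 1) / n for all_reduce.
    """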
    n = dist.get_world_size()
    tput = 0
    busbw = 0
    if duration == 0:
        raise ValueError("Duration is 0; cannot compute bandwidth.")

    if comm_op == "all_to_all":
        tput = (size / duration)
        busbw = (size / duration) * ((n - 1) / n)
    elif comm_op == "all_gather":
        size *= n
        tput = (size / duration)
        busbw = (size / duration) * ((n - 1) / n)
    elif comm_op == "all_reduce":
        tput = (size * 2 / duration)
        busbw = (size / duration) * (2 * (n - 1) / n)
    elif comm_op == "pt2pt" or comm_op == "broadcast":
        tput = (size / duration)
        busbw = tput
    else:
        raise ValueError(f"Unknown comm_op specified: {comm_op}")

    return tput, busbw


# Run all_gather and return the message size and averaged duration
def timed_all_gather(device, input, output, start_event, end_event, warmup, trials, async_op):
    sync_all()
    # Warmups, establish connections, etc.
    for i in range(warmup):
        dist.all_gather_into_tensor(output, input, async_op=async_op)
    sync_all()

    # time the actual comm op trials times and average it
    start_event.record()
    for i in range(trials):
        dist.all_gather_into_tensor(output, input, async_op=async_op)
    end_event.record()
    sync_all()
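    # elapsed_time() returns milliseconds; convert to seconds.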
    duration = start_event.elapsed_time(end_event) / 1000

    # average over trials and compute the gathered message size in bytes
    avg_duration = duration / trials
    size = input.element_size() * input.nelement() * dist.get_world_size()
    # tput, busbw = get_bw('all_gather', size, avg_duration)
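    # Average the measured duration across ranks so every rank records the same timing.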
    avg_duration_ten = torch.tensor([avg_duration], device=device)
    if dist.get_world_size() > 1:
        dist.all_reduce(avg_duration_ten, dist.ReduceOp.AVG)

    return size, avg_duration_ten.item()


def run_all_gather(device, dtype, maxsize, warmup=5, trials=10, async_op=False):
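    """Time all_gather_into_tensor over a sweep of power-of-two message sizes.

    Total message sizes sweep powers of two up to 2**(maxsize - 1) elements, split
    evenly across ranks. Returns a list of (message size in bytes, average seconds)
    pairs, seeded with (0, 0).
    """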
    # Prepare ranks and timing events
    global_rank = dist.get_rank()
    world_size = dist.get_world_size()

    start_event = get_accelerator().Event(enable_timing=True)
    end_event = get_accelerator().Event(enable_timing=True)

    # Create list of message sizes
    M_LIST = []
    for x in (2**p for p in range(1, maxsize)):
        m = x // world_size
        if m > 0:
            M_LIST.append(m)
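    # Seed the results with a (0 bytes, 0 seconds) point for the interpolation.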
    results = [(0, 0)]
    sync_all()
    # loop over various tensor sizes
    for M in M_LIST:
        global_rank = dist.get_rank()
        try:
            mat = torch.ones(M, dtype=dtype, device=device)
            sync_all()
            input = ((mat.mul_(float(global_rank))).view(-1))
            # Delete original mat to avoid OOM
            del mat
            get_accelerator().empty_cache()
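            # The output buffer holds the gathered input shards from all ranks.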
            output = torch.zeros(input.nelement() * world_size, dtype=dtype, device=device)
        except RuntimeError as e:
            if 'out of memory' in str(e):
                if dist.get_rank() == 0:
                    print('WARNING: Ran out of GPU memory. Exiting comm op.')
                sync_all()
                break
            else:
                raise e
        sync_all()
        results.append(timed_all_gather(device, input, output, start_event, end_event, warmup, trials, async_op))

    return results
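

# Cached profiling results; populated on the first call to create_predictor().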
profile_results = None


def create_predictor():
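    """Profile all_gather once (bfloat16, sizes up to 2**30 elements) and return a
    function that maps a message size in bytes to a predicted latency in seconds,
    using linear interpolation (with extrapolation) over the profiled points.
    """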
    global profile_results
    if profile_results is None:
        with unset_fake_temporarily():
            device = get_accelerator().current_device()
            profile_results = run_all_gather(device, torch.bfloat16, 31)
        if dist.get_rank() == 0:
            for size, avg_duration in profile_results:
                print(f"size: {size}, avg_duration: {avg_duration}")

    # Extract size and avg_duration from results
    sizes = [result[0] for result in profile_results]
    durations = [result[1] for result in profile_results]

    try:
        from scipy.interpolate import interp1d
    except ImportError:
        raise RuntimeError("Please install scipy to use the communication profiler in DeepCompile")

    # Create an interpolation function
    predictor = interp1d(sizes, durations, kind='linear', fill_value="extrapolate")

    def f(size):
        if size == 0:
            return 0
        return predictor(size)

    return f


if __name__ == "__main__":
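    # Standalone profiling run; launch with a distributed launcher (e.g. deepspeed
    # or torchrun) so that LOCAL_RANK is set for each process.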
    local_rank = int(os.environ['LOCAL_RANK'])
    get_accelerator().set_device(local_rank)
    print(f"local_rank={local_rank}")

    deepspeed.init_distributed(dist_backend='nccl')

    # Create predictor function
    predictor = create_predictor()

    # Predict time for a specific data size
    example_size = 1e9
    predicted_time = predictor(example_size)
    print(f"Predicted time for size {example_size}: {predicted_time:.6f} seconds")

    dist.destroy_process_group()