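"""Communication profiler used by DeepCompile.

Times ``dist.all_gather_into_tensor`` over a sweep of message sizes and builds a
linear-interpolation predictor that maps message size (bytes) to expected
collective time (seconds). Can also be run standalone (see the ``__main__`` block).
"""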
import os

import torch

try:
    from torch._subclasses.fake_tensor import unset_fake_temporarily
except ImportError:
    # Older torch builds do not expose unset_fake_temporarily; fall back to a
    # no-op context manager so create_predictor() still works outside fake-tensor tracing.
    from contextlib import nullcontext as unset_fake_temporarily

import deepspeed
import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator


def sync_all():
    get_accelerator().synchronize()
    dist.barrier()


def get_bw(comm_op, size, duration):
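    """Return (tput, busbw) for a collective moving ``size`` bytes in ``duration`` seconds.

    ``tput`` is the achieved algorithmic throughput and ``busbw`` the bus bandwidth,
    scaled by the usual (n - 1) / n collective factors. For example, with n = 4 ranks
    an all_reduce has busbw = 2 * (4 - 1) / 4 * size / duration = 1.5 * size / duration.
    """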
    n = dist.get_world_size()
    tput = 0
    busbw = 0

    if duration == 0:
        raise ValueError("Duration is 0; cannot compute bandwidth.")

    if comm_op == "all_to_all":
        tput = size / duration
        busbw = (size / duration) * ((n - 1) / n)
    elif comm_op == "all_gather":
        size *= n
        tput = size / duration
        busbw = (size / duration) * ((n - 1) / n)
    elif comm_op == "all_reduce":
        tput = size * 2 / duration
        busbw = (size / duration) * (2 * (n - 1) / n)
    elif comm_op == "pt2pt" or comm_op == "broadcast":
        tput = size / duration
        busbw = tput
    else:
        raise ValueError(f"Unsupported comm_op: {comm_op}")

    return tput, busbw


def timed_all_gather(device, input, output, start_event, end_event, warmup, trials, async_op):
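    """Time ``trials`` all_gather_into_tensor calls after ``warmup`` untimed iterations.

    Returns the total gathered message size in bytes and the per-iteration duration
    in seconds, averaged across all ranks.
    """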
    sync_all()

    # Warmup iterations are excluded from timing.
    for i in range(warmup):
        dist.all_gather_into_tensor(output, input, async_op=async_op)
    sync_all()

    # Timed trials.
    start_event.record()
    for i in range(trials):
        dist.all_gather_into_tensor(output, input, async_op=async_op)
    end_event.record()
    sync_all()
    duration = start_event.elapsed_time(end_event) / 1000  # elapsed_time() reports milliseconds

    avg_duration = duration / trials
    size = input.element_size() * input.nelement() * dist.get_world_size()

    # Average the measured duration across ranks.
    avg_duration_ten = torch.tensor([avg_duration], device=device)
    if dist.get_world_size() > 1:
        dist.all_reduce(avg_duration_ten, dist.ReduceOp.AVG)

    return size, avg_duration_ten.item()


def run_all_gather(device, dtype, maxsize, warmup=5, trials=10, async_op=False):
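    """Profile all_gather over a range of message sizes.

    Allocates input/output buffers for each size in turn, stops the sweep on OOM,
    and returns a list of (size_in_bytes, avg_duration_in_seconds) tuples,
    starting with (0, 0).
    """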
    global_rank = dist.get_rank()
    world_size = dist.get_world_size()

    start_event = get_accelerator().Event(enable_timing=True)
    end_event = get_accelerator().Event(enable_timing=True)

    # Per-rank element counts: powers of two up to 2**(maxsize - 1), split across ranks.
    M_LIST = []
    for x in (2**p for p in range(1, maxsize)):
        m = x // world_size
        if m > 0:
            M_LIST.append(m)

    # Seed the results with a zero-size, zero-duration point for interpolation.
    results = [(0, 0)]
    sync_all()

    for M in M_LIST:
        global_rank = dist.get_rank()
        try:
            mat = torch.ones(M, dtype=dtype, device=device)
            sync_all()
            input = mat.mul_(float(global_rank)).view(-1)

            del mat
            get_accelerator().empty_cache()
            output = torch.zeros(input.nelement() * world_size, dtype=dtype, device=device)
        except RuntimeError as e:
            if 'out of memory' in str(e):
                if dist.get_rank() == 0:
                    print('WARNING: Ran out of GPU memory. Exiting comm op.')
                sync_all()
                break
            else:
                raise e
        sync_all()
        results.append(timed_all_gather(device, input, output, start_event, end_event, warmup, trials, async_op))

    return results


# Cached (size, duration) measurements shared across create_predictor() calls.
profile_results = None


def create_predictor():
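    """Build a callable mapping message size (bytes) to predicted all_gather time (seconds).

    The first call profiles bfloat16 all_gather across message sizes and caches the
    measurements; subsequent calls reuse the cache. The returned callable linearly
    interpolates (and extrapolates) over the measured points and returns 0 for size 0.
    """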
    global profile_results
    if profile_results is None:
        # Profiling needs real tensors, so temporarily leave fake-tensor mode if active.
        with unset_fake_temporarily():
            device = get_accelerator().current_device()
            profile_results = run_all_gather(device, torch.bfloat16, 31)
            if dist.get_rank() == 0:
                for size, avg_duration in profile_results:
                    print(f"size: {size}, avg_duration: {avg_duration}")

    sizes = [result[0] for result in profile_results]
    durations = [result[1] for result in profile_results]

    try:
        from scipy.interpolate import interp1d
    except ImportError:
        raise RuntimeError("Please install scipy to use the communication profiler in DeepCompile")

    predictor = interp1d(sizes, durations, kind='linear', fill_value="extrapolate")

    def f(size):
        if size == 0:
            return 0
        return predictor(size)

    return f


if __name__ == "__main__":
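    # Standalone profiling run; assumes the process was started by a distributed
    # launcher (e.g. torchrun or the deepspeed launcher) that sets LOCAL_RANK.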
    local_rank = int(os.environ['LOCAL_RANK'])
    get_accelerator().set_device(local_rank)
    print(f"local_rank={local_rank}")

    deepspeed.init_distributed(dist_backend='nccl')

    predictor = create_predictor()

    # Example query: predicted all_gather time for a 1 GB message.
    example_size = 1e9
    predicted_time = float(predictor(example_size))  # interp1d returns a 0-d array
    print(f"Predicted time for size {example_size}: {predicted_time:.6f} seconds")

    dist.destroy_process_group()