File size: 5,661 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import numpy as np
import torch
import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import PackbitsBuilder
class CompressedBackend(object):
def __init__(self, mpu=None):
if mpu is None:
self.world_group = dist.new_group(ranks=range(dist.get_world_size()))
else:
self.mpu = mpu
self.world_group = self.mpu.get_data_parallel_group()
self.size = dist.get_world_size(group=self.world_group)
self.rank = dist.get_rank(group=self.world_group)
self.packer = PackbitsBuilder().load()
def my_igather(self, rank, size, group, sendbuf, recvbuf, root):
req = []
if rank == root:
for idx in range(size):
if idx != rank:
req.append(dist.irecv(recvbuf[idx], src=idx, group=group))
else:
recvbuf[rank] = sendbuf
else:
req.append(dist.isend(sendbuf, group=group, dst=root))
return req
def my_gather(self, rank, size, group, sendbuf, recvbuf, root):
if rank == root:
for idx in range(size):
if idx != rank:
dist.recv(recvbuf[idx], src=idx, group=group)
else:
recvbuf[rank] = sendbuf
else:
dist.send(sendbuf, group=group, dst=root)
def pack(self, buffer, size):
# pack float tensor into uint8 tensor
packed = self.packer.packbits(buffer.float(), buffer.numel(), self.rank)
return packed.reshape(size, -1)
def unpack(self, buffer, size, dtype):
# unpack uint8 to float tensor
unpacked = self.packer.unpackbits(buffer, buffer.numel(), self.rank)
return unpacked.reshape(size, -1).to(dtype)
def compressed_allreduce(self, buffer_m: torch.tensor, worker_error, server_error, local_rank):
original_shape = buffer_m.size()
if len(original_shape) > 1:
buffer_m = torch.flatten(buffer_m)
# align size of original_buffer and error
original_size = buffer_m.numel()
worker_error_size = worker_error.numel()
if original_size != worker_error_size:
empty_tensor = torch.zeros(worker_error_size - original_size, device=buffer_m.device)
buffer_m = torch.cat([buffer_m, empty_tensor])
buffer_m.add_(worker_error)
worker_scale = torch.linalg.norm(buffer_m) / np.sqrt(torch.numel(buffer_m))
worker_error.set_(buffer_m - worker_scale * buffer_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))
sign_list_packed_tmp = self.pack(buffer_m, self.size).type(torch.int8)
recvbuf_sign = torch.zeros([self.size, len(sign_list_packed_tmp[self.rank])],
dtype=sign_list_packed_tmp[0].dtype,
device=sign_list_packed_tmp.device)
sign_list_packed = [sign_list_packed_tmp[idx] for idx in range(self.size)]
recvbuf_scale = [
torch.zeros(1, dtype=worker_scale.dtype, device=get_accelerator().current_device_name())
for _ in range(self.size)
]
# communication phase 1
# all to all for sign
dist.all_to_all_single(recvbuf_sign, torch.stack(sign_list_packed), group=self.world_group)
# all gather for scale
dist.all_gather(recvbuf_scale, worker_scale, group=self.world_group)
flattened_recvbuf_sign = recvbuf_sign.type(torch.uint8).flatten()
compensated_server_m = self.unpack(flattened_recvbuf_sign, self.size, torch.float32) \
.mul_(torch.stack(recvbuf_scale).mul_(1 / self.size)).sum(0)
compensated_server_m.add_(server_error)
server_scale = torch.linalg.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel())
server_error.set_(compensated_server_m -
server_scale * compensated_server_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))
server_sign_packed = self.pack(compensated_server_m, 1).type(torch.int8)
# recvbuf_sign_server
recvbuf_sign_server_tmp = torch.zeros([self.size, len(server_sign_packed[0])],
dtype=recvbuf_sign.dtype,
device=server_sign_packed.device)
recvbuf_sign_server = [recvbuf_sign_server_tmp[idx] for idx in range(self.size)]
# recvbuf_scale_server
recvbuf_scale_server_tmp = torch.zeros([self.size, 1],
dtype=worker_scale.dtype,
device=server_sign_packed.device)
recvbuf_scale_server = [recvbuf_scale_server_tmp[idx] for idx in range(self.size)]
# communication Phase 2
dist.all_gather(recvbuf_sign_server, server_sign_packed[0], group=self.world_group)
dist.all_gather(recvbuf_scale_server, server_scale, group=self.world_group)
recvbuf_sign_server = torch.stack(recvbuf_sign_server)
flattened_recvbuf_sign_server = recvbuf_sign_server.type(torch.uint8).flatten()
buffer_m.data.copy_(
self.unpack(flattened_recvbuf_sign_server, self.size,
torch.float32).mul_(recvbuf_scale_server_tmp).flatten().data)
if original_size != worker_error_size:
buffer_m = buffer_m[0:original_size]
if len(original_shape) > 1:
buffer_m = buffer_m.reshape(original_shape)
return buffer_m
|