# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import numpy as np
import torch
import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import PackbitsBuilder


class CompressedBackend(object):
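    """1-bit compressed all-reduce backend.

    Tensors are reduced by exchanging only their element signs (bit-packed
    into uint8) plus one scalar scale per participant; persistent error
    buffers carry the compression residual across steps (error feedback).
    """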

    def __init__(self, mpu=None):
        # default to a group spanning all ranks; with a model-parallel unit
        # (mpu), reduce over its data-parallel group instead
        if mpu is None:
            self.world_group = dist.new_group(ranks=range(dist.get_world_size()))
        else:
            self.mpu = mpu
            self.world_group = self.mpu.get_data_parallel_group()
        self.size = dist.get_world_size(group=self.world_group)
        self.rank = dist.get_rank(group=self.world_group)
        self.packer = PackbitsBuilder().load()

    def my_igather(self, rank, size, group, sendbuf, recvbuf, root):
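        """Non-blocking gather built from point-to-point ops.

        The root posts an irecv for every other rank and copies its own
        contribution straight into recvbuf; non-root ranks isend to the
        root. Returns the list of outstanding requests to wait on.
        """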
        req = []
        if rank == root:
            for idx in range(size):
                if idx != rank:
                    req.append(dist.irecv(recvbuf[idx], src=idx, group=group))
                else:
                    recvbuf[rank] = sendbuf
        else:
            req.append(dist.isend(sendbuf, group=group, dst=root))
        return req

    def my_gather(self, rank, size, group, sendbuf, recvbuf, root):
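        """Blocking counterpart of my_igather, using send/recv."""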
        if rank == root:
            for idx in range(size):
                if idx != rank:
                    dist.recv(recvbuf[idx], src=idx, group=group)
                else:
                    recvbuf[rank] = sendbuf
        else:
            dist.send(sendbuf, group=group, dst=root)

    def pack(self, buffer, size):
        # bit-pack element signs of a float tensor into uint8 (one bit per
        # element), then split the packed buffer into `size` equal chunks
        packed = self.packer.packbits(buffer.float(), buffer.numel(), self.rank)
        return packed.reshape(size, -1)

    def unpack(self, buffer, size, dtype):
        # unpack bits back into a +/-1 sign tensor of the requested dtype,
        # one row per chunk
        unpacked = self.packer.unpackbits(buffer, buffer.numel(), self.rank)
        return unpacked.reshape(size, -1).to(dtype)

    def compressed_allreduce(self, buffer_m: torch.Tensor, worker_error, server_error, local_rank):
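        """Two-phase 1-bit compressed all-reduce with error feedback.

        Phase 1 routes sign chunks (all-to-all) and worker scales
        (all-gather) so that every rank reduces one chunk; phase 2
        all-gathers the re-compressed reduced chunks so every rank can
        reassemble the full result. worker_error and server_error hold
        the compression residuals between calls.
        """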
        original_shape = buffer_m.size()
        if len(original_shape) > 1:
            buffer_m = torch.flatten(buffer_m)

        # zero-pad buffer_m so its length matches the pre-aligned worker_error
        original_size = buffer_m.numel()
        worker_error_size = worker_error.numel()
        if original_size != worker_error_size:
            empty_tensor = torch.zeros(worker_error_size - original_size, device=buffer_m.device)
            buffer_m = torch.cat([buffer_m, empty_tensor])

        # error feedback: fold the residual from the previous step into the
        # tensor before compressing again
        buffer_m.add_(worker_error)
        worker_scale = torch.linalg.norm(buffer_m) / np.sqrt(torch.numel(buffer_m))

        # new residual = compensated tensor minus its 1-bit reconstruction;
        # the sign().add_(1).bool() chain maps {-1, 0, +1} to {-1, +1, +1}
        worker_error.set_(buffer_m - worker_scale * buffer_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))

        # compress to one sign bit per element, one packed chunk per rank
        sign_list_packed_tmp = self.pack(buffer_m, self.size).type(torch.int8)

        recvbuf_sign = torch.zeros([self.size, len(sign_list_packed_tmp[self.rank])],
                                   dtype=sign_list_packed_tmp[0].dtype,
                                   device=sign_list_packed_tmp.device)

        sign_list_packed = [sign_list_packed_tmp[idx] for idx in range(self.size)]

        recvbuf_scale = [
            torch.zeros(1, dtype=worker_scale.dtype, device=get_accelerator().current_device_name())
            for _ in range(self.size)
        ]

        # communication phase 1: all-to-all routes each sign chunk to the
        # rank responsible for reducing it
        dist.all_to_all_single(recvbuf_sign, torch.stack(sign_list_packed), group=self.world_group)
        # all-gather the per-worker scales
        dist.all_gather(recvbuf_scale, worker_scale, group=self.world_group)

        # decompress every worker's chunk, rescale by its worker scale, and
        # average across workers
        flattened_recvbuf_sign = recvbuf_sign.type(torch.uint8).flatten()
        compensated_server_m = self.unpack(flattened_recvbuf_sign, self.size, torch.float32) \
            .mul_(torch.stack(recvbuf_scale).mul_(1 / self.size)).sum(0)

        # server-side error feedback, mirroring the worker-side step above
        compensated_server_m.add_(server_error)

        server_scale = torch.linalg.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel())

        server_error.set_(compensated_server_m -
                          server_scale * compensated_server_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))

        # re-compress the reduced chunk before gathering it back
        server_sign_packed = self.pack(compensated_server_m, 1).type(torch.int8)

        # recvbuf_sign_server
        recvbuf_sign_server_tmp = torch.zeros([self.size, len(server_sign_packed[0])],
                                              dtype=recvbuf_sign.dtype,
                                              device=server_sign_packed.device)

        recvbuf_sign_server = [recvbuf_sign_server_tmp[idx] for idx in range(self.size)]

        # recvbuf_scale_server
        recvbuf_scale_server_tmp = torch.zeros([self.size, 1],
                                               dtype=worker_scale.dtype,
                                               device=server_sign_packed.device)

        recvbuf_scale_server = [recvbuf_scale_server_tmp[idx] for idx in range(self.size)]

        # communication phase 2: all-gather every rank's reduced chunk
        # (signs and scale) so each rank can reassemble the full tensor
        dist.all_gather(recvbuf_sign_server, server_sign_packed[0], group=self.world_group)
        dist.all_gather(recvbuf_scale_server, server_scale, group=self.world_group)

        recvbuf_sign_server = torch.stack(recvbuf_sign_server)

        flattened_recvbuf_sign_server = recvbuf_sign_server.type(torch.uint8).flatten()

        # decompress all gathered chunks, rescale each by its server scale,
        # and write the reassembled result back into buffer_m
        buffer_m.data.copy_(
            self.unpack(flattened_recvbuf_sign_server, self.size,
                        torch.float32).mul_(recvbuf_scale_server_tmp).flatten().data)

        # strip the alignment padding and restore the original shape
        if original_size != worker_error_size:
            buffer_m = buffer_m[0:original_size]
        if len(original_shape) > 1:
            buffer_m = buffer_m.reshape(original_shape)

        return buffer_m
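
# Minimal usage sketch (illustrative, not part of the original file). Under an
# initialized deepspeed.comm process group, the error buffers persist across
# steps; their sizes follow from the chunking logic in compressed_allreduce
# (worker_error pre-aligned, server_error holding one chunk). The tensor names
# below are hypothetical.
#
#   backend = CompressedBackend()
#   aligned = ...  # grad.numel() rounded up so it splits evenly across ranks
#   worker_error = torch.zeros(aligned, device=grad.device)
#   server_error = torch.zeros(aligned // backend.size, device=grad.device)
#   grad = backend.compressed_allreduce(grad, worker_error, server_error, local_rank)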