File size: 11,049 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import torch
import time
from dataclasses import dataclass

from .constants import *
from .base_file_writer import BaseFileWriter
from .single_io_buffer import Single_IO_Buffer
from .double_io_buffer import Double_IO_Buffer
from deepspeed.ops.op_builder import UtilsBuilder
from deepspeed.accelerator import get_accelerator

from .utils import (tensor_to_bytes, bytes_to_tensor, obj_serialization_details)

# Statistic keys specific to the fast (direct-IO) writer. FastFileWriter.__init__
# sets each of these to 0 in self._stats; the methods of FastFileWriter accumulate them.
FASTIO_STAT_KEYS = [
    AIO_WRITE_SEC_KEY,
    AIO_WRITE_BYTES_KEY,
    AIO_SPEED_KEY,
    SLOW_WRITE_BYTES_KEY,
    SLOW_WRITE_SEC_KEY,
    AIO_FILL_BUFFER_COUNT_KEY,
    AIO_FILL_BUFFER_SEC_KEY,
    AIO_FILL_BUFFER_SPEED_KEY,
    SAVE_STORAGE_KEY,
    SAVE_STORAGE_BYTES_KEY,
]


@dataclass
class FastFileWriterConfig:
    """Construction parameters for ``FastFileWriter``."""
    # IO handle: FastFileWriter calls its get_alignment() and passes it to the IO buffer.
    dnvme_handle: object
    # Staging tensor handed to the IO buffer; presumably pinned host memory (per the name) — confirm at call site.
    pinned_tensor: torch.Tensor
    # True selects Double_IO_Buffer, False selects Single_IO_Buffer.
    double_buffer: bool = True
    # When > 1, each writer saves only its own byte partition of the serialized storage objects.
    num_parallel_writers: int = 1
    # Index of the partition this writer saves; must be >= 0 and < num_parallel_writers.
    writer_rank: int = 0
    # Stored by FastFileWriter; not otherwise used by the writer logic visible in this module.
    global_rank: int = 0


class FastFileWriter(BaseFileWriter):
    """File writer that stages data in an IO buffer and drains it through an O_DIRECT descriptor.

    Bytes given to ``write`` / ``save_torch_storage_object_list`` are copied into a single- or
    double-buffered IO buffer (chosen by ``FastFileWriterConfig.double_buffer``). Full buffers are
    drained to the file through ``dnvme_handle``. Direct IO requires aligned writes, so at close
    time any unaligned tail left in the buffer is appended through a regular buffered file object.
    """

    def __init__(self, file_path, config):
        """Open ``file_path`` for direct IO and set up the staging buffer.

        Args:
            file_path: Destination path. The file is created if missing and truncated if present.
            config: ``FastFileWriterConfig`` with the IO handle, staging tensor and
                parallel-writer settings.
        """
        super(FastFileWriter, self).__init__(file_path)
        # O_TRUNC: the unaligned tail is later written in append mode (see _unaligned_drain),
        # which only lands at self._file_offset if the file starts out empty.
        self._aio_fd = os.open(self._file_path, flags=os.O_DIRECT | os.O_CREAT | os.O_WRONLY | os.O_TRUNC)
        self._dnvme_handle = config.dnvme_handle
        self._file_offset = 0
        io_buffer_type = Double_IO_Buffer if config.double_buffer else Single_IO_Buffer
        self._io_buffer = io_buffer_type(config.pinned_tensor, self._dnvme_handle)
        self._cast_to_byte_tensor = UtilsBuilder().load().cast_to_byte_tensor
        self._get_serialization_details = obj_serialization_details()
        self._num_parallel_writers = config.num_parallel_writers
        self._writer_rank = config.writer_rank
        self._global_rank = config.global_rank

        for k in FASTIO_STAT_KEYS:
            self._stats[k] = 0

    def write(self, buffer):
        """Write a bytes-like ``buffer`` and return the number of bytes written.

        Raises:
            AssertionError: if the current file offset is not aligned for direct IO, or if
                fewer bytes than ``len(buffer)`` were consumed.
        """
        assert self._file_offset % self._dnvme_handle.get_alignment() == 0
        buffer_num_bytes = len(buffer)
        num_written_bytes = self._write_from_tensor(bytes_to_tensor(buffer))
        assert buffer_num_bytes == num_written_bytes
        return buffer_num_bytes

    def split_index_list(self, storage_obj_list, num_splits):
        """Choose boundary indices that split ``storage_obj_list`` into roughly equal-byte groups.

        Walks the per-object byte sizes (``len(t[0])``). Whenever the running count since the
        last boundary exceeds ``total_bytes / num_splits``, it records the current index in the
        next slot and resets the count. If the last slot is still unset it becomes
        ``len(storage_obj_list)``; any other unused slots stay ``-1``.

        Args:
            storage_obj_list: Sequence of items whose data is at ``t[0]``.
            num_splits: Number of groups; must be positive.

        Returns:
            A list of ``num_splits`` indices.
        """
        assert num_splits > 0
        split_list = [-1] * num_splits
        # t[0] is data, t[1] is data_type
        tensor_bytes_list = [len(t[0]) for t in storage_obj_list]
        total_bytes = sum(tensor_bytes_list)
        bytes_per_group = total_bytes / num_splits
        split_counter = 0
        tmp_size = 0
        for i, num_bytes in enumerate(tensor_bytes_list):
            tmp_size += num_bytes
            if tmp_size > bytes_per_group:
                split_list[split_counter] = i
                tmp_size = 0
                split_counter += 1
        if split_list[num_splits - 1] == -1:
            split_list[num_splits - 1] = len(tensor_bytes_list)
        return split_list

    def save_torch_storage_object_list(self, storage_obj_list, save_size):
        """Serialize and write storage objects; return the number of bytes this writer wrote.

        Args:
            storage_obj_list: Objects accepted by the serialization-details helper.
            save_size: When true, each object's size is written as an int64 ahead of its data.
        """
        assert self._file_offset % self._dnvme_handle.get_alignment() == 0
        num_bytes_written = self._save_storage_list(storage_obj_list, save_size)
        return num_bytes_written

    def close(self):
        """Flush all buffered data (including any unaligned tail) and release the descriptor."""
        self._fini()
        self._incr_stats(CLOSE_COUNT_KEY)

    def fileno(self):
        """Return ``INVALID_FD``; the underlying direct-IO descriptor is not exposed.

        NOTE(review): presumably hidden so callers cannot bypass the staging buffer by writing
        to the descriptor directly — confirm.
        """
        self._incr_stats(FILENO_COUNT_KEY)
        return INVALID_FD

    def flush(self):
        """Only record the call in the stats; data is written when buffers drain or on close."""
        self._incr_stats(FLUSH_COUNT_KEY)

    def __del__(self):
        # __init__ may have failed before the IO buffer existed; just release the fd (if any) then.
        if getattr(self, '_io_buffer', None) is None:
            self._close_aio_fd()
            return
        self._fini()
        assert self._aio_fd == INVALID_FD
        assert self._io_buffer.get_offset() == 0, \
            f'__del__ assert: pinned_offset {self._io_buffer.get_offset()} != 0'
        assert self._file_offset == self._stats[WRITE_BYTES_KEY], \
            f'__del__ assert: file_offset != write_bytes - {self._file_offset} != {self._stats[WRITE_BYTES_KEY]}'

    def _fini(self):
        """Drain remaining buffered bytes, reset the buffer and close the descriptor (idempotent)."""
        if not self._io_buffer_is_empty():
            self._force_drain()
        # An asynchronous drain may still be writing through the descriptor; wait before closing it.
        self._io_buffer.complete_ongoing_drain()
        self._io_buffer.reset()
        self._close_aio_fd()

    def _close_aio_fd(self):
        """Close the direct-IO descriptor if it is open, leaving ``_aio_fd`` as ``INVALID_FD``."""
        fd = getattr(self, '_aio_fd', INVALID_FD)
        if fd != INVALID_FD:
            self._aio_fd = INVALID_FD
            os.close(fd)

    def _fill_io_buffer(self, src_tensor, src_offset):
        """Copy from ``src_tensor[src_offset:]`` into the IO buffer; return the bytes copied."""
        st = time.time()
        copy_bytes = self._io_buffer.fill(src_tensor, src_offset)
        self._incr_stats(AIO_FILL_BUFFER_SEC_KEY, time.time() - st)
        self._incr_stats(AIO_FILL_BUFFER_COUNT_KEY)
        return copy_bytes

    def _drain_io_buffer(self, num_bytes):
        """Drain ``num_bytes`` from the IO buffer at the current file offset and advance it."""
        st = time.time()
        self._io_buffer.drain(num_bytes, self._aio_fd, self._file_offset)
        self._incr_stats(AIO_WRITE_SEC_KEY, time.time() - st)
        self._incr_stats(AIO_WRITE_BYTES_KEY, num_bytes)
        self._file_offset += num_bytes

    def _io_buffer_is_full(self):
        return self._io_buffer.is_full()

    def _io_buffer_is_empty(self):
        return self._io_buffer.is_empty()

    def _force_drain(self):
        """Write out everything in the IO buffer: the aligned prefix via direct IO, the rest slowly."""
        st = time.time()
        aligned_num_bytes = self._io_buffer.get_aligned_num_bytes()
        # Important to retrieve unaligned drain bytes and tensor before doing aligned drain because of the side effects.
        # TODO: Need to eliminate this dependency
        unaligned_num_bytes = self._io_buffer.get_unaligned_num_bytes()
        unaligned_tensor = torch.narrow(self._io_buffer.get_buffer(), 0, aligned_num_bytes, unaligned_num_bytes)

        if aligned_num_bytes > 0:
            # _drain_io_buffer records its own AIO_WRITE_SEC_KEY time.
            self._drain_io_buffer(aligned_num_bytes)

        # Time only the completion wait here so the aligned drain above is not counted twice.
        wait_st = time.time()
        self._io_buffer.complete_ongoing_drain()
        self._incr_stats(AIO_WRITE_SEC_KEY, time.time() - wait_st)

        if unaligned_num_bytes > 0:
            self._unaligned_drain(unaligned_tensor)
        self._incr_stats(WRITE_SEC_KEY, time.time() - st)

    def _unaligned_drain(self, unaligned_tensor):
        """Append a tail that cannot go through direct IO, then reopen the descriptor in append mode."""
        self._close_aio_fd()
        st = time.time()
        with open(self._file_path, 'ab') as fp:
            fp.write(tensor_to_bytes(unaligned_tensor.cpu()))
        num_bytes = unaligned_tensor.numel()
        self._file_offset += num_bytes
        self._incr_stats(SLOW_WRITE_SEC_KEY, time.time() - st)
        self._incr_stats(SLOW_WRITE_BYTES_KEY, num_bytes)
        self._aio_fd = os.open(self._file_path, flags=os.O_DIRECT | os.O_WRONLY | os.O_APPEND)

    def _dump_state(self):
        """Derive throughput stats (GiB/s: bytes / seconds / 1024**3), then defer to the base class."""
        if self._stats[AIO_WRITE_SEC_KEY] > 0:
            self._stats[AIO_SPEED_KEY] = (self._stats[AIO_WRITE_BYTES_KEY] / self._stats[AIO_WRITE_SEC_KEY] /
                                          (1024**3))
        if self._stats[AIO_FILL_BUFFER_SEC_KEY] > 0:
            self._stats[AIO_FILL_BUFFER_SPEED_KEY] = (self._stats[AIO_WRITE_BYTES_KEY] /
                                                      self._stats[AIO_FILL_BUFFER_SEC_KEY] / (1024**3))
        super()._dump_state()

    def _update_write_stats(self, num_bytes, secs_latency):
        self._incr_stats(WRITE_COUNT_KEY)
        self._incr_stats(WRITE_BYTES_KEY, num_bytes)
        self._incr_stats(WRITE_SEC_KEY, secs_latency)

    def _write_from_tensor(self, buffer_tensor):
        """Stream all of ``buffer_tensor`` through the IO buffer; return the number of elements written.

        Whenever the buffer fills, its contents are drained to the file. A partially filled
        buffer is left for a later write or for close().
        """
        st = time.time()
        buffer_offset = 0
        while buffer_offset < buffer_tensor.numel():
            num_copied_bytes = self._fill_io_buffer(buffer_tensor, buffer_offset)
            if self._io_buffer_is_full():
                self._drain_io_buffer(self._io_buffer.get_offset())
            buffer_offset += num_copied_bytes

        self._update_write_stats(buffer_offset, time.time() - st)

        return buffer_offset

    def _save_storage_list(self, obj_list, save_size):
        """Convert objects to byte tensors, keep this writer's partition, and write it."""
        byte_tensor_list, byte_tensor_nbytes = self._convert_to_byte_tensors(obj_list, save_size)
        if self._num_parallel_writers > 1:
            my_byte_tensor_list = self._partition_byte_tensors(byte_tensor_list, byte_tensor_nbytes,
                                                               self._num_parallel_writers, self._writer_rank)
        else:
            my_byte_tensor_list = byte_tensor_list

        num_object_bytes_written = 0
        for byte_tensor in my_byte_tensor_list:
            num_object_bytes_written += self._write_from_tensor(byte_tensor)

        self._incr_stats(SAVE_STORAGE_KEY, len(obj_list))
        self._incr_stats(SAVE_STORAGE_BYTES_KEY, num_object_bytes_written)
        return num_object_bytes_written

    # Convert list of storage objects into list of byte tensors of object and size bytes
    def _convert_to_byte_tensors(self, obj_list, save_size):
        """Return ``(byte tensors, total byte count)`` for ``obj_list``.

        When ``save_size`` is true, an int64 size tensor (on the accelerator device) precedes each
        object's data tensor, and ``STORAGE_OBJ_SIZE`` bytes per object are added to the count.
        NOTE(review): the byte-tensor conversion is the native ``cast_to_byte_tensor`` op; presumably
        it returns one uint8 view per input tensor — confirm.
        """
        tensor_list = []
        num_bytes = 0
        for storage_obj in obj_list:
            details = self._get_serialization_details(storage_obj)
            if save_size:
                tensor_list.append(
                    torch.tensor(
                        details.size,
                        dtype=torch.int64,
                    ).to(get_accelerator().device_name()))
            tensor_list.append(torch.empty(0, dtype=details.dtype, device=details.obj.device).set_(details.obj))
            num_bytes += details.nbytes
        if save_size:
            num_bytes += STORAGE_OBJ_SIZE * len(obj_list)

        return self._cast_to_byte_tensor(tensor_list), num_bytes

    def _partition_byte_tensors(self, byte_tensor_list, byte_tensor_nbytes, num_ranks, my_rank):
        """Return the narrowed fragments of ``byte_tensor_list`` that belong to ``my_rank``.

        The concatenated byte stream of ``byte_tensor_nbytes`` bytes is split into ``num_ranks``
        contiguous partitions. The first ``byte_tensor_nbytes % num_ranks`` ranks get one extra byte.
        """
        assert my_rank >= 0, f'Invalid for rank number to be negative: {my_rank}'
        assert num_ranks > my_rank, f'Number of ranks {num_ranks} must be greater than rank {my_rank}'

        partition_size = int(byte_tensor_nbytes // num_ranks)
        num_remainder_bytes = byte_tensor_nbytes % num_ranks
        if num_remainder_bytes == 0:
            partition_start = partition_size * my_rank
        else:
            # Spread extra bytes evenly among early ranks
            if num_remainder_bytes > my_rank:
                partition_size += 1
                partition_start = partition_size * my_rank
            else:
                # Account for allocation of extra bytes to earlier ranks
                partition_start = (partition_size * my_rank) + num_remainder_bytes

        partition_end = min(partition_start + partition_size, byte_tensor_nbytes)
        partition_tensor_list = []
        current_offset = 0
        for byte_tensor in byte_tensor_list:
            byte_tensor_end = current_offset + byte_tensor.numel()
            # Keep only the overlap of [current_offset, byte_tensor_end) with this rank's partition.
            if current_offset < partition_end and byte_tensor_end > partition_start:
                fragment_start = max(current_offset, partition_start)
                fragment_end = min(byte_tensor_end, partition_end)
                assert fragment_start < fragment_end, \
                    f'fragment start {fragment_start} should be < fragment_end {fragment_end}'

                fragment_numel = fragment_end - fragment_start
                partition_tensor_list.append(byte_tensor.narrow(0, fragment_start - current_offset, fragment_numel))

            current_offset += byte_tensor.numel()

        actual_partition_nbytes = sum(t.numel() for t in partition_tensor_list)
        assert actual_partition_nbytes == partition_size, \
            f'Incorrect partition bytes for rank {my_rank}, expected = {partition_size} actual = {actual_partition_nbytes}'

        return partition_tensor_list