# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch


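class Base_IO_Buffer(object):
    """Base class for staging buffers that write to files through an async I/O handle.

    Subclasses are expected to override the interface methods that are stubbed
    out with ``pass`` below; ``_drain`` relies on ``get_offset`` and ``get_buffer``.
    """

    def __init__(self, pinned_tensor, dnvme_handle):
        # The buffer size must be a multiple of the handle's I/O alignment.
        assert pinned_tensor.numel() % dnvme_handle.get_alignment() == 0
        self._dnvme_handle = dnvme_handle
        self._pinned_tensor = pinned_tensor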

    def fill(self, src_tensor, src_offset):
        pass

    def drain(self, num_bytes, fd, file_offset):
        pass

    def is_empty(self):
        pass

    def is_full(self):
        pass

    def get_buffer(self):
        pass

    def get_offset(self):
        pass

    def get_aligned_num_bytes(self):
        pass

    def get_unaligned_num_bytes(self):
        pass

    def reset(self):
        pass

    def complete_ongoing_drain(self):
        pass

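    def _drain(self, num_bytes, fd, file_offset, blocking=False):
        """Submit an async write of the first ``num_bytes`` of the buffer to ``fd``."""
        assert num_bytes <= self.get_offset()
        assert num_bytes % self._dnvme_handle.get_alignment() == 0
        buffer = self.get_buffer()
        r = self._dnvme_handle.async_pwrite(torch.narrow(buffer, 0, 0, num_bytes), fd, file_offset)
        assert 0 == r
        if blocking:
            # Call wait() outside the assert so it still runs under `python -O`.
            completed = self._dnvme_handle.wait()
            assert 1 == completed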

    @staticmethod
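    def fill_buffer(src_tensor, src_offset, buffer_tensor, buffer_offset):
        """Copy as much of ``src_tensor`` (from ``src_offset``) as fits into
        ``buffer_tensor`` (from ``buffer_offset``) and return the amount copied.

        Sizes are element counts (``numel()``), so they equal byte counts only
        for one-byte dtypes.
        """
        src_bytes = src_tensor.numel() - src_offset
        assert src_bytes > 0

        # Copy whichever is smaller: what remains in the source or the free space in the buffer.
        dst_bytes = buffer_tensor.numel() - buffer_offset
        copy_bytes = min(src_bytes, dst_bytes)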
        assert (buffer_offset + copy_bytes) <= buffer_tensor.numel()

        if copy_bytes > 0:
            src_slice = torch.narrow(src_tensor, 0, src_offset, copy_bytes)
            dst_slice = torch.narrow(buffer_tensor, 0, buffer_offset, copy_bytes)
            dst_slice.data.copy_(src_slice.data)

        return copy_bytes
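

# Usage sketch (not part of the original module): the partial-copy contract of
# fill_buffer lets a caller stream a large source through a small staging
# buffer. The version below swaps the tensors for plain bytes/bytearray objects
# so that it runs without torch or an I/O handle. stream_through_buffer is a
# hypothetical driver, and the "drain" step only collects chunks in memory
# instead of calling async_pwrite.

```python
def fill_buffer(src, src_offset, buffer, buffer_offset):
    # Same arithmetic as Base_IO_Buffer.fill_buffer, with slices instead of torch.narrow.
    src_bytes = len(src) - src_offset
    assert src_bytes > 0

    dst_bytes = len(buffer) - buffer_offset
    copy_bytes = min(src_bytes, dst_bytes)

    if copy_bytes > 0:
        buffer[buffer_offset:buffer_offset + copy_bytes] = src[src_offset:src_offset + copy_bytes]

    return copy_bytes


def stream_through_buffer(src, buffer_size):
    # Hypothetical driver: fill the staging buffer, "drain" it, and repeat
    # until the whole source has been consumed.
    buffer = bytearray(buffer_size)
    chunks = []
    src_offset = 0
    while src_offset < len(src):
        copied = fill_buffer(src, src_offset, buffer, 0)
        chunks.append(bytes(buffer[:copied]))  # stand-in for a real drain
        src_offset += copied
    return chunks


src = bytes(range(10))
chunks = stream_through_buffer(src, buffer_size=4)
chunk_sizes = [len(c) for c in chunks]
```

# A real drain would also have to respect the handle's alignment, as the
# asserts in _drain require; this sketch ignores alignment entirely.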