File size: 1,841 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
class Base_IO_Buffer(object):
    """Base class for a staging buffer whose contents are written to a file via an I/O handle.

    The buffer is backed by ``pinned_tensor``; data is flushed with
    ``dnvme_handle.async_pwrite`` (see ``_drain``). Most public methods here are
    placeholders that do nothing and return ``None``; concrete subclasses are
    expected to override them.

    NOTE(review): ``pinned_tensor`` is presumably a page-locked, 1-D byte
    (``torch.uint8``) host tensor, and ``dnvme_handle`` presumably an async-I/O
    handle -- confirm against callers. The code below only relies on
    ``numel()`` / ``torch.narrow`` for the tensor (treating one element as one
    byte) and on ``get_alignment()``, ``async_pwrite()`` and ``wait()`` for the
    handle.
    """

    def __init__(self, pinned_tensor, dnvme_handle):
        """Store the backing tensor and the I/O handle.

        Args:
            pinned_tensor: Backing tensor; its ``numel()`` must be a multiple of
                ``dnvme_handle.get_alignment()``.
            dnvme_handle: Handle exposing ``get_alignment()``, ``async_pwrite()``
                and ``wait()``.

        Raises:
            AssertionError: (when assertions are enabled) if the tensor size is
                not a multiple of the handle's alignment.
        """
        assert pinned_tensor.numel() % dnvme_handle.get_alignment() == 0
        self._dnvme_handle = dnvme_handle
        self._pinned_tensor = pinned_tensor

    def fill(self, src_tensor, src_offset):
        """Placeholder for subclasses; does nothing here and returns ``None``.

        Presumably copies data from ``src_tensor`` starting at ``src_offset``
        into the buffer -- TODO confirm against subclasses.
        """
        pass

    def drain(self, num_bytes, fd, file_offset):
        """Placeholder for subclasses; does nothing here and returns ``None``.

        Presumably writes ``num_bytes`` of buffered data to ``fd`` at
        ``file_offset``, e.g. via ``_drain`` -- TODO confirm against subclasses.
        """
        pass

    def is_empty(self):
        """Placeholder for subclasses; returns ``None`` here."""
        pass

    def is_full(self):
        """Placeholder for subclasses; returns ``None`` here."""
        pass

    def get_buffer(self):
        """Placeholder for subclasses; returns ``None`` here.

        ``_drain`` expects the override to return a tensor that can be
        ``torch.narrow``-ed along dimension 0.
        """
        pass

    def get_offset(self):
        """Placeholder for subclasses; returns ``None`` here.

        ``_drain`` compares it against ``num_bytes``, so the override presumably
        returns the number of bytes currently held in the buffer.
        """
        pass

    def get_aligned_num_bytes(self):
        """Placeholder for subclasses; returns ``None`` here."""
        pass

    def get_unaligned_num_bytes(self):
        """Placeholder for subclasses; returns ``None`` here."""
        pass

    def reset(self):
        """Placeholder for subclasses; does nothing here and returns ``None``."""
        pass

    def complete_ongoing_drain(self):
        """Placeholder for subclasses; does nothing here and returns ``None``."""
        pass

    def _drain(self, num_bytes, fd, file_offset, blocking=False):
        """Submit a write of the first ``num_bytes`` of the buffer to ``fd``.

        Args:
            num_bytes: Number of leading buffer elements to write; must not
                exceed ``get_offset()`` and must be a multiple of the handle's
                alignment.
            fd: File handle passed through to ``async_pwrite``.
            file_offset: Offset in the file at which the write starts.
            blocking: If true, wait for completion via ``dnvme_handle.wait()``
                before returning.

        Raises:
            AssertionError: (when assertions are enabled) if the size checks
                fail, ``async_pwrite`` returns non-zero, or ``wait()`` does not
                return 1.
        """
        assert num_bytes <= self.get_offset()
        assert num_bytes % self._dnvme_handle.get_alignment() == 0
        buffer = self.get_buffer()
        submit_status = self._dnvme_handle.async_pwrite(torch.narrow(buffer, 0, 0, num_bytes), fd, file_offset)
        # A non-zero status from async_pwrite is treated as a submission failure.
        assert submit_status == 0
        if blocking:
            # Call wait() OUTSIDE the assert: under `python -O` every assert
            # statement is stripped, which would drop the wait() call itself and
            # silently turn a blocking drain into a non-blocking one.
            completed = self._dnvme_handle.wait()
            # Presumably the number of completed requests -- exactly one write
            # was submitted above.
            assert completed == 1

    @staticmethod
    def fill_buffer(src_tensor, src_offset, buffer_tensor, buffer_offset):
        """Copy as much of ``src_tensor[src_offset:]`` as fits into ``buffer_tensor[buffer_offset:]``.

        Args:
            src_tensor: Source tensor, sliced along dimension 0.
            src_offset: Start position in ``src_tensor``; must leave at least
                one element to copy.
            buffer_tensor: Destination tensor, sliced along dimension 0.
            buffer_offset: Start position in ``buffer_tensor``; must lie in
                ``[0, buffer_tensor.numel()]``.

        Returns:
            Number of elements copied: ``min(remaining source, free space)``,
            which is 0 when the buffer is already full.

        Raises:
            AssertionError: (when assertions are enabled) if there is nothing
                left in the source or ``buffer_offset`` is out of range.
        """
        src_bytes = src_tensor.numel() - src_offset
        assert src_bytes > 0
        # Validate the destination offset up front. An offset past the end would
        # make dst_bytes (and therefore the returned count) negative, corrupting
        # the caller's bookkeeping. A post-check on buffer_offset + copy_bytes
        # could never fire, because copy_bytes <= dst_bytes by construction.
        assert 0 <= buffer_offset <= buffer_tensor.numel()
        dst_bytes = buffer_tensor.numel() - buffer_offset
        copy_bytes = min(src_bytes, dst_bytes)
        if copy_bytes > 0:
            src_slice = torch.narrow(src_tensor, 0, src_offset, copy_bytes)
            dst_slice = torch.narrow(buffer_tensor, 0, buffer_offset, copy_bytes)
            # Copy through `.data` so the copy bypasses autograd tracking.
            dst_slice.data.copy_(src_slice.data)
        return copy_bytes
|