import os
import time
from dataclasses import dataclass

import torch

from deepspeed.ops.op_builder import UtilsBuilder
from deepspeed.accelerator import get_accelerator

from .constants import *
from .base_file_writer import BaseFileWriter
from .single_io_buffer import Single_IO_Buffer
from .double_io_buffer import Double_IO_Buffer
from .utils import tensor_to_bytes, bytes_to_tensor, obj_serialization_details

FASTIO_STAT_KEYS = [
    AIO_WRITE_SEC_KEY,
    AIO_WRITE_BYTES_KEY,
    AIO_SPEED_KEY,
    SLOW_WRITE_BYTES_KEY,
    SLOW_WRITE_SEC_KEY,
    AIO_FILL_BUFFER_COUNT_KEY,
    AIO_FILL_BUFFER_SEC_KEY,
    AIO_FILL_BUFFER_SPEED_KEY,
    SAVE_STORAGE_KEY,
    SAVE_STORAGE_BYTES_KEY,
]


@dataclass
class FastFileWriterConfig:
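    """Configuration for FastFileWriter.

    Field descriptions are inferred from how the fields are used in this
    module:
        dnvme_handle: async I/O handle used to drain the pinned buffer; also
            supplies the required O_DIRECT alignment via get_alignment().
        pinned_tensor: page-locked byte tensor backing the staging I/O buffer.
        double_buffer: use Double_IO_Buffer (overlapping fill and drain)
            instead of Single_IO_Buffer.
        num_parallel_writers: number of ranks cooperatively writing the data.
        writer_rank: this writer's index among the parallel writers.
        global_rank: global process rank (stored; not referenced elsewhere in
            this file).
    """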
    dnvme_handle: object
    pinned_tensor: torch.Tensor
    double_buffer: bool = True
    num_parallel_writers: int = 1
    writer_rank: int = 0
    global_rank: int = 0


class FastFileWriter(BaseFileWriter):
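    """File writer that streams byte tensors to an O_DIRECT file descriptor
    through a pinned staging buffer.

    Bytes accumulate in the pinned I/O buffer and are drained with async I/O
    in alignment-sized chunks; any unaligned tail is flushed through a regular
    buffered append (the slow path in _unaligned_drain).
    """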

    def __init__(self, file_path, config):
        super().__init__(file_path)
        self._aio_fd = os.open(self._file_path, flags=os.O_DIRECT | os.O_CREAT | os.O_WRONLY)
        self._dnvme_handle = config.dnvme_handle
        self._file_offset = 0
        io_buffer_type = Double_IO_Buffer if config.double_buffer else Single_IO_Buffer
        self._io_buffer = io_buffer_type(config.pinned_tensor, self._dnvme_handle)
        self._cast_to_byte_tensor = UtilsBuilder().load().cast_to_byte_tensor
        self._get_serialization_details = obj_serialization_details()
        self._num_parallel_writers = config.num_parallel_writers
        self._writer_rank = config.writer_rank
        self._global_rank = config.global_rank

        for k in FASTIO_STAT_KEYS:
            self._stats[k] = 0

    def write(self, buffer):
        assert self._file_offset % self._dnvme_handle.get_alignment() == 0
        buffer_num_bytes = len(buffer)
        num_written_bytes = self._write_from_tensor(bytes_to_tensor(buffer))
        assert buffer_num_bytes == num_written_bytes
        return buffer_num_bytes

    def split_index_list(self, storage_obj_list, num_splits):
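        """Greedily split storage_obj_list into num_splits contiguous groups
        of roughly equal byte size.

        Returns a list of end indices, one per group; semantics are inferred
        from the greedy loop below. Entries for trailing groups that never
        filled remain -1, except the last, which is forced to cover the
        remaining objects.
        """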
        assert num_splits > 0
        split_list = [-1] * num_splits

        tensor_bytes_list = [len(t[0]) for t in storage_obj_list]
        total_bytes = sum(tensor_bytes_list)
        bytes_per_group = total_bytes / num_splits
        split_counter = 0
        tmp_size = 0
        for i in range(len(tensor_bytes_list)):
            tmp_size += tensor_bytes_list[i]
            if tmp_size > bytes_per_group:
                split_list[split_counter] = i
                tmp_size = 0
                split_counter += 1
        if split_list[num_splits - 1] == -1:
            split_list[num_splits - 1] = len(tensor_bytes_list)
        return split_list

    def save_torch_storage_object_list(self, storage_obj_list, save_size):
        assert self._file_offset % self._dnvme_handle.get_alignment() == 0
        num_bytes_written = self._save_storage_list(storage_obj_list, save_size)
        return num_bytes_written

    def close(self):
        self._fini()
        self._incr_stats(CLOSE_COUNT_KEY)

    def fileno(self):
        self._incr_stats(FILENO_COUNT_KEY)
        return INVALID_FD

    def flush(self):
        self._incr_stats(FLUSH_COUNT_KEY)

    def __del__(self):
        self._fini()
        assert self._aio_fd == INVALID_FD
        assert self._io_buffer.get_offset() == 0, \
            f'__del__ assert: pinned_offset {self._io_buffer.get_offset()} != 0'
        assert self._file_offset == self._stats[WRITE_BYTES_KEY], \
            f'__del__ assert: file_offset {self._file_offset} != write_bytes {self._stats[WRITE_BYTES_KEY]}'

    def _fini(self):
        if not self._io_buffer_is_empty():
            self._force_drain()
        self._io_buffer.reset()
        if self._aio_fd != INVALID_FD:
            # Release the O_DIRECT descriptor; the INVALID_FD guard makes
            # repeated _fini() calls (close() followed by __del__) safe.
            os.close(self._aio_fd)
        self._aio_fd = INVALID_FD

    def _fill_io_buffer(self, src_tensor, src_offset):
        st = time.time()
        copy_bytes = self._io_buffer.fill(src_tensor, src_offset)
        self._incr_stats(AIO_FILL_BUFFER_SEC_KEY, time.time() - st)
        self._incr_stats(AIO_FILL_BUFFER_COUNT_KEY)
        return copy_bytes

    def _drain_io_buffer(self, num_bytes):
        st = time.time()
        self._io_buffer.drain(num_bytes, self._aio_fd, self._file_offset)
        self._incr_stats(AIO_WRITE_SEC_KEY, time.time() - st)
        self._incr_stats(AIO_WRITE_BYTES_KEY, num_bytes)
        self._file_offset += num_bytes

    def _io_buffer_is_full(self):
        return self._io_buffer.is_full()

    def _io_buffer_is_empty(self):
        return self._io_buffer.is_empty()

    def _force_drain(self):
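        # O_DIRECT writes must be alignment-sized, so drain the aligned prefix
        # of the buffer through the async path and hand the unaligned tail to
        # the buffered slow path (_unaligned_drain).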
        st = time.time()
        aligned_num_bytes = self._io_buffer.get_aligned_num_bytes()

        unaligned_num_bytes = self._io_buffer.get_unaligned_num_bytes()
        unaligned_tensor = torch.narrow(self._io_buffer.get_buffer(), 0, aligned_num_bytes, unaligned_num_bytes)

        if aligned_num_bytes > 0:
            self._drain_io_buffer(aligned_num_bytes)

        self._io_buffer.complete_ongoing_drain()
        self._incr_stats(AIO_WRITE_SEC_KEY, time.time() - st)

        if unaligned_num_bytes > 0:
            self._unaligned_drain(unaligned_tensor)
        self._incr_stats(WRITE_SEC_KEY, time.time() - st)

    def _unaligned_drain(self, unaligned_tensor):
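        # The unaligned tail cannot go through the O_DIRECT descriptor, so
        # close it, append the tail with an ordinary buffered write, then
        # reopen the file with O_DIRECT in append mode for subsequent drains.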
        os.close(self._aio_fd)
        st = time.time()
        with open(self._file_path, 'ab') as fp:
            fp.write(tensor_to_bytes(unaligned_tensor.cpu()))
        self._file_offset += unaligned_tensor.numel()
        self._incr_stats(SLOW_WRITE_SEC_KEY, time.time() - st)
        self._incr_stats(SLOW_WRITE_BYTES_KEY, unaligned_tensor.numel())
        self._aio_fd = os.open(self._file_path, flags=os.O_DIRECT | os.O_WRONLY | os.O_APPEND)

    def _dump_state(self):
        # Speeds are reported in GB/s (bytes / seconds / 1024**3).
        if self._stats[AIO_WRITE_SEC_KEY] > 0:
            self._stats[AIO_SPEED_KEY] = \
                self._stats[AIO_WRITE_BYTES_KEY] / self._stats[AIO_WRITE_SEC_KEY] / (1024**3)
        if self._stats[AIO_FILL_BUFFER_SEC_KEY] > 0:
            self._stats[AIO_FILL_BUFFER_SPEED_KEY] = \
                self._stats[AIO_WRITE_BYTES_KEY] / self._stats[AIO_FILL_BUFFER_SEC_KEY] / (1024**3)
        super()._dump_state()

    def _update_write_stats(self, num_bytes, secs_latency):
        self._incr_stats(WRITE_COUNT_KEY)
        self._incr_stats(WRITE_BYTES_KEY, num_bytes)
        self._incr_stats(WRITE_SEC_KEY, secs_latency)

    def _write_from_tensor(self, buffer_tensor):
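        # Copy the source tensor into the pinned I/O buffer in chunks,
        # draining the buffer to disk whenever it fills up.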
        st = time.time()
        buffer_offset = 0
        while buffer_offset < buffer_tensor.numel():
            num_copied_bytes = self._fill_io_buffer(buffer_tensor, buffer_offset)
            if self._io_buffer_is_full():
                self._drain_io_buffer(self._io_buffer.get_offset())
            buffer_offset += num_copied_bytes

        self._update_write_stats(buffer_offset, time.time() - st)

        return buffer_offset

    def _save_storage_list(self, obj_list, save_size):
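        # Serialize the storage objects into byte tensors. With multiple
        # parallel writers, each writer only writes its byte-range partition
        # of the flattened stream.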
        byte_tensor_list, byte_tensor_nbytes = self._convert_to_byte_tensors(obj_list, save_size)
        if self._num_parallel_writers > 1:
            my_byte_tensor_list = self._partition_byte_tensors(byte_tensor_list, byte_tensor_nbytes,
                                                               self._num_parallel_writers, self._writer_rank)
        else:
            my_byte_tensor_list = byte_tensor_list

        num_object_bytes_written = 0
        for byte_tensor in my_byte_tensor_list:
            num_object_bytes_written += self._write_from_tensor(byte_tensor)

        self._incr_stats(SAVE_STORAGE_KEY, len(obj_list))
        self._incr_stats(SAVE_STORAGE_BYTES_KEY, num_object_bytes_written)
        return num_object_bytes_written

    def _convert_to_byte_tensors(self, obj_list, save_size):
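        # Build a flat list of tensors aliasing each storage (optionally
        # preceded by an int64 size header of STORAGE_OBJ_SIZE bytes) and
        # cast them all to byte tensors for writing.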
        tensor_list = []
        num_bytes = 0
        for storage_obj in obj_list:
            details = self._get_serialization_details(storage_obj)
            if save_size:
                tensor_list.append(torch.tensor(details.size, dtype=torch.int64).to(get_accelerator().device_name()))
            tensor_list.append(torch.empty(0, dtype=details.dtype, device=details.obj.device).set_(details.obj))
            num_bytes += details.nbytes
        if save_size:
            num_bytes += STORAGE_OBJ_SIZE * len(obj_list)

        return self._cast_to_byte_tensor(tensor_list), num_bytes

    def _partition_byte_tensors(self, byte_tensor_list, byte_tensor_nbytes, num_ranks, my_rank):
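        # Split the flattened byte stream into num_ranks contiguous,
        # nearly-equal partitions: the first (byte_tensor_nbytes % num_ranks)
        # ranks each take one extra byte. Returns narrowed views covering
        # this rank's [partition_start, partition_end) byte range.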
        assert my_rank >= 0, f'Invalid negative rank: {my_rank}'
        assert num_ranks > my_rank, f'Number of ranks {num_ranks} must be greater than rank {my_rank}'

        partition_size = byte_tensor_nbytes // num_ranks
        num_remainder_bytes = byte_tensor_nbytes % num_ranks
        if num_remainder_bytes == 0:
            partition_start = partition_size * my_rank
        else:
            if num_remainder_bytes > my_rank:
                # This rank absorbs one of the remainder bytes.
                partition_size += 1
                partition_start = partition_size * my_rank
            else:
                # All remainder bytes were absorbed by lower ranks.
                partition_start = (partition_size * my_rank) + num_remainder_bytes

        partition_end = min(partition_start + partition_size, byte_tensor_nbytes)
        partition_tensor_list = []
        current_offset = 0
        for byte_tensor in byte_tensor_list:
            byte_tensor_end = current_offset + byte_tensor.numel()
            if current_offset < partition_end and byte_tensor_end > partition_start:
                fragment_start = max(current_offset, partition_start)
                fragment_end = min(byte_tensor_end, partition_end)
                assert fragment_start < fragment_end, \
                    f'fragment_start {fragment_start} should be < fragment_end {fragment_end}'

                fragment_numel = fragment_end - fragment_start
                partition_tensor_list.append(byte_tensor.narrow(0, fragment_start - current_offset, fragment_numel))

            current_offset += byte_tensor.numel()

        actual_partition_nbytes = sum(t.numel() for t in partition_tensor_list)
        assert actual_partition_nbytes == partition_size, \
            f'Incorrect partition bytes for rank {my_rank}, expected = {partition_size} actual = {actual_partition_nbytes}'

        return partition_tensor_list
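

# Example usage (a minimal sketch; how the dnvme handle and pinned tensor are
# obtained depends on the surrounding I/O stack and is left abstract here):
#
#   config = FastFileWriterConfig(dnvme_handle=handle,
#                                 pinned_tensor=pinned_byte_tensor,
#                                 double_buffer=True,
#                                 num_parallel_writers=1,
#                                 writer_rank=0)
#   writer = FastFileWriter('/path/to/checkpoint.bin', config)
#   writer.write(serialized_bytes)  # raw bytes path
#   writer.close()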