|
|
|
import multiprocessing |
|
import os |
|
import threading |
|
from multiprocessing import reduction |
|
from multiprocessing.util import register_after_fork |
|
from typing import Union |
|
|
|
import torch |
|
from torch._namedtensor_internals import check_serializing_named_tensor |
|
|
|
|
|
try:
    # Early load resource_sharer to prevent a partially initialized instance
    # from being inherited in a forked child process. The reduce_storage
    # method requires this module indirectly through DupFd(). The built-in
    # mp.Queue class pickles arguments in a background thread which may
    # overlap with the fork.
    import multiprocessing.resource_sharer
|
except ImportError: |
|
pass |
|
|
|
|
|
class StorageWeakRef: |
|
r"""A weak reference to a Storage. |
|
|
|
The cdata member is a Python number containing the integer representation of |
|
the Storage pointer. |
|
""" |
|
|
|
__slots__ = ["cdata", "_free_weak_ref"] |
|
|
|
    def __init__(self, storage):
        self.cdata = storage._weak_ref()
        # Save a direct reference to _free_weak_ref because the `torch` module
        # might be cleared during Python shutdown before this module is cleared.
        self._free_weak_ref = torch.Storage._free_weak_ref
|
|
|
@classmethod |
|
def from_weakref(cls, cdata): |
|
instance = cls.__new__(cls) |
|
instance.cdata = cdata |
|
instance._free_weak_ref = torch.Storage._free_weak_ref |
|
return instance |
|
|
|
def expired(self): |
|
return torch.Storage._expired(self.cdata) |
|
|
|
def __del__(self): |
|
self._free_weak_ref(self.cdata) |
|
|
|
def __hash__(self): |
|
return self.cdata |
|
|
|
def __eq__(self, other): |
|
if id(self) == id(other): |
|
return True |
|
return self.cdata == other.cdata |
|
|
|
|
|
class SharedCache(dict): |
|
"""Dictionary from multiprocessing handles to StorageWeakRef.""" |
|
|
|
    def __init__(self) -> None:
        # free_dead_references() is called if the len exceeds the current
        # limit. The limit scales with the number of remaining live objects.
        self.limit = 128
        # `fork` inherits lock state, so in case we fork when the lock is held,
        # we register a function to reset the lock to a new object to avoid
        # possible deadlocks, following python multiprocessing library design.
        self._after_fork()
        register_after_fork(self, SharedCache._after_fork)
|
|
|
def _after_fork(self): |
|
self.lock = threading.Lock() |
|
|
|
def get(self, key): |
|
with self.lock: |
|
return dict.get(self, key) |
|
|
|
def __setitem__(self, key, storage_ref): |
|
with self.lock: |
|
dict.__setitem__(self, key, storage_ref) |
|
if len(self) > self.limit: |
|
self.free_dead_references() |
|
|
|
def free_dead_references(self): |
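        """Frees any dead references to storages and updates the cache limit.

        The new limit is twice the number of live references, with a floor
        of 128.
        """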
|
live = 0 |
|
for key, storage_ref in list(self.items()): |
|
if storage_ref.expired(): |
|
del self[key] |
|
else: |
|
live += 1 |
|
self.limit = max(128, live * 2) |
|
|
|
|
|
|
|
shared_cache = SharedCache() |
|
|
|
|
|
def rebuild_event(device, handle): |
|
return torch.cuda.Event.from_ipc_handle(device, handle) |
|
|
|
|
|
def reduce_event(event): |
|
handle = event.ipc_handle() |
|
return (rebuild_event, (event.device, handle)) |
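
# Note: only CUDA events created with torch.cuda.Event(interprocess=True)
# carry an IPC handle, so only those events can be sent to other processes.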
|
|
|
|
|
def rebuild_tensor(cls, storage, metadata): |
|
storage_offset, size, stride, requires_grad = metadata |
|
t = torch._utils._rebuild_tensor(storage, storage_offset, size, stride) |
|
    if cls == torch.nn.parameter.Parameter:
        # we have to pass requires_grad into constructor, rather than set it
        # as an attribute later, because it's an important check for Integer
        # Tensors to have requires_grad=False (or else they raise an error)
        t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad)
|
else: |
|
t.requires_grad = requires_grad |
|
return t |
|
|
|
|
|
def rebuild_meta_tensor( |
|
tensor_cls, |
|
tensor_size, |
|
tensor_stride, |
|
tensor_offset, |
|
dtype, |
|
storage_size_bytes, |
|
requires_grad, |
|
): |
|
untyped_storage = torch.UntypedStorage(storage_size_bytes, device="meta") |
|
|
|
typed_storage = torch.TypedStorage( |
|
wrap_storage=untyped_storage, dtype=dtype, _internal=True |
|
) |
|
|
|
t = torch._utils._rebuild_tensor( |
|
typed_storage, |
|
tensor_offset, |
|
tensor_size, |
|
tensor_stride, |
|
) |
|
|
|
    if tensor_cls == torch.nn.parameter.Parameter:
        # It is crucial for integer tensors to receive
        # the requires_grad=False as an argument in the constructor
        t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad)
|
else: |
|
t.requires_grad = requires_grad |
|
|
|
return t |
|
|
|
|
|
def rebuild_cuda_tensor( |
|
tensor_cls, |
|
tensor_size, |
|
tensor_stride, |
|
tensor_offset, |
|
storage_cls, |
|
dtype, |
|
storage_device, |
|
storage_handle, |
|
storage_size_bytes, |
|
storage_offset_bytes, |
|
requires_grad, |
|
ref_counter_handle, |
|
ref_counter_offset, |
|
event_handle, |
|
event_sync_required, |
|
): |
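    """Rebuilds a CUDA tensor in the receiving process from CUDA IPC metadata.

    The backing storage is looked up in ``shared_cache`` first; if this
    (handle, offset) pair has already been received, the existing storage is
    reused and the producer's extra reference count is released. Otherwise
    the IPC handle is opened via ``_new_shared_cuda`` and cached.
    """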
|
|
|
    # If storage_handle is None, storage points to nullptr.
    if storage_handle is None or storage_size_bytes == 0:
|
storage = storage_cls(0, dtype=dtype, device=storage_device, _internal=True) |
|
else: |
|
storage = storage_from_cache( |
|
storage_cls, (storage_handle, storage_offset_bytes) |
|
) |
|
if storage is None: |
|
torch.cuda._lazy_init() |
|
storage = storage_cls._new_shared_cuda( |
|
storage_device, |
|
storage_handle, |
|
storage_size_bytes, |
|
storage_offset_bytes, |
|
ref_counter_handle, |
|
ref_counter_offset, |
|
event_handle, |
|
event_sync_required, |
|
) |
|
shared_cache[(storage_handle, storage_offset_bytes)] = StorageWeakRef( |
|
storage |
|
) |
|
        else:
            # We are already reference-counting this storage, but the producer
            # still needs its extra ref-counter slot to be released.
            storage_cls._release_ipc_counter(
|
ref_counter_handle, ref_counter_offset, device=storage_device |
|
) |
|
|
|
_storage = ( |
|
storage |
|
if isinstance(storage, torch.UntypedStorage) |
|
else storage._untyped_storage |
|
) |
|
|
|
t = torch._utils._rebuild_tensor( |
|
torch.storage.TypedStorage(wrap_storage=_storage, dtype=dtype, _internal=True), |
|
tensor_offset, |
|
tensor_size, |
|
tensor_stride, |
|
) |
|
|
|
    if tensor_cls == torch.nn.parameter.Parameter:
        # It is crucial for integer tensors to receive
        # the requires_grad=False as an argument in the constructor
        t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad)
|
else: |
|
t.requires_grad = requires_grad |
|
|
|
return t |
|
|
|
|
|
def reduce_tensor(tensor): |
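    """Pickles a tensor for transmission to another process.

    Dispatches on the tensor's properties: nested and sparse tensors are
    decomposed into their component tensors, CUDA tensors are shared via CUDA
    IPC, meta tensors are rebuilt from shape metadata alone, and CPU tensors
    fall back to sharing their storage (see ``reduce_storage``).
    """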
|
if tensor.requires_grad and not tensor.is_leaf: |
|
raise RuntimeError( |
|
"Cowardly refusing to serialize non-leaf tensor which requires_grad, " |
|
"since autograd does not support crossing process boundaries. " |
|
"If you just want to transfer the data, call detach() on the tensor " |
|
"before serializing (e.g., putting it on the queue)." |
|
) |
|
|
|
check_serializing_named_tensor(tensor) |
|
    torch.utils.hooks.warn_if_has_hooks(tensor)

    # Note [CUDA IPC and the caching allocator]
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # When a CUDA tensor is sent over IPC, the *entire* cudaMalloc allocation
    # backing its storage must be shared, because that is the granularity the
    # CUDA IPC API supports. Since the caching allocator may pack several
    # storages (possibly of different dtypes) into one allocation, the
    # receiver cannot simply wrap the mapped region in a single storage.
    # Instead, the sender transmits:
    #
    #   1. the cudaIpcMemHandle of the base allocation,
    #   2. the byte offset of the storage within that allocation, and
    #   3. the storage size in bytes,
    #
    # and the receiver reconstructs the storage as a view into the mapped
    # base pointer. A cudaIpcMemHandle can only be opened once per device in
    # a receiving process, and the mapping is invalidated when it is closed,
    # so the device pointer for each handle is cached on the C++ side and
    # reused by every storage that lives inside the same allocation.
    #
    # Lifetime is coordinated through a shared-memory reference counter
    # (ref_counter_handle / ref_counter_offset): the sender keeps the
    # allocation alive until all receivers have dropped their references.
    # Because the consumer may launch kernels on the shared memory before the
    # producer's pending work completes, an IPC event (event_handle /
    # event_sync_required) is used to synchronize the two sides.
|
from torch.nested._internal.nested_tensor import NestedTensor |
|
|
|
if tensor.is_nested and not isinstance(tensor, NestedTensor): |
|
return reduce_nested_tensor(tensor) |
|
|
|
if tensor.layout in { |
|
torch.sparse_coo, |
|
torch.sparse_csr, |
|
torch.sparse_bsr, |
|
torch.sparse_csc, |
|
torch.sparse_bsc, |
|
}: |
|
return reduce_sparse_tensor(tensor) |
|
|
|
storage = tensor._typed_storage() |
|
|
|
if storage._untyped_storage.device.type == "cuda": |
|
( |
|
device, |
|
handle, |
|
storage_size_bytes, |
|
storage_offset_bytes, |
|
ref_counter_handle, |
|
ref_counter_offset, |
|
event_handle, |
|
event_sync_required, |
|
) = storage._share_cuda_() |
|
tensor_offset = tensor.storage_offset() |
|
        shared_cache[handle] = StorageWeakRef(storage)
        # _backward_hooks purposely omitted here, see
        # Note [Don't serialize hooks]
        return (
|
rebuild_cuda_tensor, |
|
( |
|
type(tensor), |
|
tensor.size(), |
|
tensor.stride(), |
|
tensor_offset, |
|
type(storage), |
|
tensor.dtype, |
|
device, |
|
handle, |
|
storage_size_bytes, |
|
storage_offset_bytes, |
|
tensor.requires_grad, |
|
ref_counter_handle, |
|
ref_counter_offset, |
|
event_handle, |
|
event_sync_required, |
|
), |
|
) |
|
elif storage._untyped_storage.device.type == "meta": |
|
return ( |
|
rebuild_meta_tensor, |
|
( |
|
type(tensor), |
|
tensor.size(), |
|
tensor.stride(), |
|
tensor.storage_offset(), |
|
tensor.dtype, |
|
tensor.untyped_storage().size(), |
|
tensor.requires_grad, |
|
), |
|
) |
|
|
|
|
|
    # _backward_hooks purposely omitted here, see Note [Don't serialize hooks]
    metadata = (
|
tensor.storage_offset(), |
|
tensor.size(), |
|
tensor.stride(), |
|
tensor.requires_grad, |
|
) |
|
return (rebuild_tensor, (type(tensor), storage, metadata)) |
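
# A minimal usage sketch (not part of this module) of how these reducers are
# exercised in practice. It assumes a "spawn" context; the queue pickles the
# tensor with the reducers registered by init_reductions() below:
#
#     import torch
#     import torch.multiprocessing as mp
#
#     def worker(q):
#         t = q.get()   # rebuilt here, e.g. via rebuild_storage_fd + rebuild_tensor
#         t.add_(1)     # storage is shared, so the parent sees this write
#
#     if __name__ == "__main__":
#         ctx = mp.get_context("spawn")
#         q = ctx.Queue()
#         t = torch.zeros(4).share_memory_()
#         p = ctx.Process(target=worker, args=(q,))
#         p.start()
#         q.put(t)
#         p.join()
#         assert t.eq(1).all()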
|
|
|
|
|
def rebuild_nested_tensor( |
|
rebuild_buffer_func, |
|
rebuild_buffer_args, |
|
rebuild_sizes_func, |
|
rebuild_sizes_args, |
|
rebuild_strides_func, |
|
rebuild_strides_args, |
|
rebuild_offsets_func, |
|
rebuild_offsets_args, |
|
): |
|
buffer = rebuild_buffer_func(*rebuild_buffer_args) |
|
sizes = rebuild_sizes_func(*rebuild_sizes_args) |
|
strides = rebuild_strides_func(*rebuild_strides_args) |
|
offsets = rebuild_offsets_func(*rebuild_offsets_args) |
|
return torch._nested_view_from_buffer_copy(buffer, sizes, strides, offsets) |
|
|
|
|
|
def reduce_nested_tensor(nt): |
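    """Pickles a strided nested tensor by reducing its underlying buffer and
    its nested size/stride/offset metadata tensors individually."""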
|
rebuild_buffer_func, rebuild_buffer_args = reduce_tensor(nt.values()) |
|
rebuild_sizes_func, rebuild_sizes_args = reduce_tensor(nt._nested_tensor_size()) |
|
rebuild_strides_func, rebuild_strides_args = reduce_tensor( |
|
nt._nested_tensor_strides() |
|
) |
|
rebuild_offsets_func, rebuild_offsets_args = reduce_tensor( |
|
nt._nested_tensor_storage_offsets() |
|
) |
|
|
|
return ( |
|
rebuild_nested_tensor, |
|
( |
|
rebuild_buffer_func, |
|
rebuild_buffer_args, |
|
rebuild_sizes_func, |
|
rebuild_sizes_args, |
|
rebuild_strides_func, |
|
rebuild_strides_args, |
|
rebuild_offsets_func, |
|
rebuild_offsets_args, |
|
), |
|
) |
|
|
|
|
|
def rebuild_sparse_coo_tensor( |
|
rebuild_indices_func, |
|
rebuild_indices_args, |
|
rebuild_values_func, |
|
rebuild_values_args, |
|
shape, |
|
is_coalesced, |
|
): |
|
indices = rebuild_indices_func(*rebuild_indices_args) |
|
values = rebuild_values_func(*rebuild_values_args) |
|
return torch.sparse_coo_tensor(indices, values, shape, is_coalesced=is_coalesced) |
|
|
|
|
|
def rebuild_sparse_compressed_tensor( |
|
rebuild_compressed_indices_func, |
|
rebuild_compressed_indices_args, |
|
rebuild_plain_indices_func, |
|
rebuild_plain_indices_args, |
|
rebuild_values_func, |
|
rebuild_values_args, |
|
shape, |
|
layout, |
|
): |
|
compressed_indices = rebuild_compressed_indices_func( |
|
*rebuild_compressed_indices_args |
|
) |
|
plain_indices = rebuild_plain_indices_func(*rebuild_plain_indices_args) |
|
values = rebuild_values_func(*rebuild_values_args) |
|
return torch.sparse_compressed_tensor( |
|
compressed_indices, plain_indices, values, shape, layout=layout |
|
) |
|
|
|
|
|
def reduce_sparse_tensor(sparse): |
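    """Pickles a sparse tensor by reducing its index and value tensors.

    COO tensors are rebuilt from (indices, values); compressed layouts
    (CSR/CSC/BSR/BSC) from (compressed_indices, plain_indices, values).
    """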
|
if sparse.layout is torch.sparse_coo: |
|
rebuild_indices_func, rebuild_indices_args = reduce_tensor(sparse._indices()) |
|
rebuild_values_func, rebuild_values_args = reduce_tensor(sparse._values()) |
|
return ( |
|
rebuild_sparse_coo_tensor, |
|
( |
|
rebuild_indices_func, |
|
rebuild_indices_args, |
|
rebuild_values_func, |
|
rebuild_values_args, |
|
sparse.shape, |
|
sparse.is_coalesced(), |
|
), |
|
) |
|
else: |
|
if sparse.layout in {torch.sparse_csr, torch.sparse_bsr}: |
|
compressed_indices = sparse.crow_indices() |
|
plain_indices = sparse.col_indices() |
|
elif sparse.layout in {torch.sparse_csc, torch.sparse_bsc}: |
|
compressed_indices = sparse.ccol_indices() |
|
plain_indices = sparse.row_indices() |
|
else: |
|
raise NotImplementedError(sparse.layout) |
|
( |
|
rebuild_compressed_indices_func, |
|
rebuild_compressed_indices_args, |
|
) = reduce_tensor(compressed_indices) |
|
rebuild_plain_indices_func, rebuild_plain_indices_args = reduce_tensor( |
|
plain_indices |
|
) |
|
rebuild_values_func, rebuild_values_args = reduce_tensor(sparse.values()) |
|
return ( |
|
rebuild_sparse_compressed_tensor, |
|
( |
|
rebuild_compressed_indices_func, |
|
rebuild_compressed_indices_args, |
|
rebuild_plain_indices_func, |
|
rebuild_plain_indices_args, |
|
rebuild_values_func, |
|
rebuild_values_args, |
|
sparse.shape, |
|
sparse.layout, |
|
), |
|
) |
|
|
|
|
|
def fd_id(fd):
    # Returns a tuple which uniquely identifies a file descriptor. In Mac OS,
    # this doesn't work with shared memory handles, which is why we don't
    # support the "file_descriptor" sharing method on that platform.
    stat = os.fstat(fd)
    return (stat.st_ino, stat.st_dev)
|
|
|
|
|
def storage_from_cache(cls, key): |
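    """Revives the storage cached in ``shared_cache`` under ``key``, or
    returns None if no live storage is cached there."""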
|
storage_ref = shared_cache.get(key) |
|
if storage_ref is None: |
|
return None |
|
return torch.UntypedStorage._new_with_weak_ptr(storage_ref.cdata) |
|
|
|
|
|
def rebuild_storage_fd(cls, df, size): |
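    """Rebuilds (or looks up) a CPU storage from a duplicated file descriptor.

    The descriptor is only needed to map or identify the shared memory, so it
    is always closed before returning.
    """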
|
fd = df.detach() |
|
try: |
|
storage = storage_from_cache(cls, fd_id(fd)) |
|
if storage is not None: |
|
return storage |
|
storage = cls._new_shared_fd_cpu(fd, size) |
|
shared_cache[fd_id(fd)] = StorageWeakRef(storage) |
|
return storage |
|
finally: |
|
os.close(fd) |
|
|
|
|
|
def rebuild_storage_filename(cls, manager, handle, size, dtype=None): |
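    """Rebuilds a CPU storage shared via the "file_system" strategy.

    ``handle`` names the shared-memory file owned by ``manager``; ``dtype``
    is passed only when the sender shared a TypedStorage.
    """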
|
storage: Union[torch.TypedStorage, torch.UntypedStorage] = storage_from_cache( |
|
cls, handle |
|
) |
|
if storage is not None: |
|
return storage._shared_decref() |
|
if dtype is None: |
|
storage = torch.UntypedStorage._new_shared_filename_cpu(manager, handle, size) |
|
else: |
|
byte_size = size * torch._utils._element_size(dtype) |
|
untyped_storage: torch.UntypedStorage = ( |
|
torch.UntypedStorage._new_shared_filename_cpu(manager, handle, byte_size) |
|
) |
|
storage = torch.TypedStorage( |
|
wrap_storage=untyped_storage, dtype=dtype, _internal=True |
|
) |
|
shared_cache[handle] = StorageWeakRef(storage) |
|
return storage._shared_decref() |
|
|
|
|
|
def rebuild_storage_empty(cls): |
|
return cls() |
|
|
|
|
|
def rebuild_typed_storage(storage, dtype): |
|
return torch.storage.TypedStorage(wrap_storage=storage, dtype=dtype, _internal=True) |
|
|
|
|
|
|
|
# Use for torch.storage.TypedStorage
def reduce_typed_storage(storage):
|
return (rebuild_typed_storage, (storage._untyped_storage, storage.dtype)) |
|
|
|
|
|
def rebuild_typed_storage_child(storage, storage_type): |
|
return storage_type(wrap_storage=storage, _internal=True) |
|
|
|
|
|
|
|
# Use for child classes of torch.storage.TypedStorage
def reduce_typed_storage_child(storage):
|
return (rebuild_typed_storage_child, (storage._untyped_storage, type(storage))) |
|
|
|
|
|
def reduce_storage(storage): |
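    """Pickles a CPU storage using the currently selected sharing strategy."""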
|
from . import get_sharing_strategy |
|
|
|
if storage.is_cuda: |
|
raise RuntimeError( |
|
"Cannot pickle CUDA storage; try pickling a CUDA tensor instead" |
|
) |
|
elif storage.device.type == "meta": |
|
raise RuntimeError( |
|
"Cannot pickle meta storage; try pickling a meta tensor instead" |
|
) |
|
elif get_sharing_strategy() == "file_system": |
|
metadata = storage._share_filename_cpu_() |
|
cache_key = metadata[1] |
|
rebuild = rebuild_storage_filename |
|
if isinstance(storage, torch.TypedStorage): |
|
metadata += (storage.dtype,) |
|
storage._shared_incref() |
|
    elif storage.size() == 0:
        # This is special-cased because empty storages
        # (with size 0) cannot be mmapped.
        return (rebuild_storage_empty, (type(storage),))
|
else: |
|
fd, size = storage._share_fd_cpu_() |
|
df = multiprocessing.reduction.DupFd(fd) |
|
cache_key = fd_id(fd) |
|
metadata = (df, size) |
|
rebuild = rebuild_storage_fd |
|
|
|
shared_cache[cache_key] = StorageWeakRef(storage) |
|
return (rebuild, (type(storage),) + metadata) |
|
|
|
|
|
def init_reductions(): |
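    """Registers the reducers defined above with multiprocessing's
    ForkingPickler, so tensors, storages, and CUDA events placed on
    torch.multiprocessing queues are shared rather than copied."""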
|
reduction.register(torch.cuda.Event, reduce_event) |
|
|
|
for t in torch._storage_classes: |
|
if t.__name__ == "UntypedStorage": |
|
reduction.register(t, reduce_storage) |
|
else: |
|
reduction.register(t, reduce_typed_storage_child) |
|
|
|
reduction.register(torch.storage.TypedStorage, reduce_typed_storage) |
|
|
|
for t in torch._tensor_classes: |
|
reduction.register(t, reduce_tensor) |
|
|
|
|
|
    # TODO: Maybe this should be in tensor_classes? :)
    reduction.register(torch.Tensor, reduce_tensor)
|
|
|
from torch.nn.parameter import Parameter |
|
|
|
reduction.register(Parameter, reduce_tensor) |
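
# init_reductions() is expected to be called once by the torch.multiprocessing
# package at import time, installing these reducers for its queues/processes.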
|
|