|
|
|
|
|
|
|
|
|
|
|
import numpy |
|
import torch |
|
from dataclasses import dataclass |
|
|
|
|
|
@dataclass |
|
class serialize_details: |
|
obj: object |
|
dtype: torch.dtype |
|
size: int |
|
nbytes: int |
|
|
|
|
|
def tensor_to_bytes(tensor): |
|
return tensor.numpy().tobytes() |
|
|
|
|
|
def bytes_to_tensor(buffer): |
|
return torch.from_numpy(numpy.array(numpy.frombuffer(buffer, dtype=numpy.uint8))) |
|
|
|
|
|
def required_minimum_torch_version(major_version, minor_version): |
|
TORCH_MAJOR = int(torch.__version__.split('.')[0]) |
|
TORCH_MINOR = int(torch.__version__.split('.')[1]) |
|
|
|
if TORCH_MAJOR < major_version: |
|
return False |
|
|
|
return TORCH_MAJOR > major_version or TORCH_MINOR >= minor_version |
|
|
|
|
|
|
|
def _legacy_obj_serialization_details(storage_obj): |
|
nbytes = storage_obj.element_size() * storage_obj.size() |
|
return serialize_details(obj=storage_obj, dtype=storage_obj.dtype, size=nbytes, nbytes=nbytes) |
|
|
|
|
|
|
|
def _new_obj_serialization_details(storage_obj): |
|
obj, dtype = storage_obj |
|
return serialize_details(obj=obj, |
|
dtype=dtype, |
|
size=obj.size() // torch._utils._element_size(dtype), |
|
nbytes=obj.size()) |
|
|
|
|
|
def obj_serialization_details(): |
|
if required_minimum_torch_version(1, 12): |
|
return _new_obj_serialization_details |
|
|
|
return _legacy_obj_serialization_details |
|
|