File size: 1,464 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import numpy
import torch
from dataclasses import dataclass


@dataclass
class serialize_details:
    obj: object
    dtype: torch.dtype
    size: int
    nbytes: int


def tensor_to_bytes(tensor):
    return tensor.numpy().tobytes()


def bytes_to_tensor(buffer):
    return torch.from_numpy(numpy.array(numpy.frombuffer(buffer, dtype=numpy.uint8)))


def required_minimum_torch_version(major_version, minor_version):
    TORCH_MAJOR = int(torch.__version__.split('.')[0])
    TORCH_MINOR = int(torch.__version__.split('.')[1])

    if TORCH_MAJOR < major_version:
        return False

    return TORCH_MAJOR > major_version or TORCH_MINOR >= minor_version


# torch < 1.12
def _legacy_obj_serialization_details(storage_obj):
    nbytes = storage_obj.element_size() * storage_obj.size()
    return serialize_details(obj=storage_obj, dtype=storage_obj.dtype, size=nbytes, nbytes=nbytes)


# torch >= 1.12
def _new_obj_serialization_details(storage_obj):
    obj, dtype = storage_obj
    return serialize_details(obj=obj,
                             dtype=dtype,
                             size=obj.size() // torch._utils._element_size(dtype),
                             nbytes=obj.size())


def obj_serialization_details():
    if required_minimum_torch_version(1, 12):
        return _new_obj_serialization_details

    return _legacy_obj_serialization_details