jamtur01's picture
Upload folder using huggingface_hub
9c6594c verified
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import numpy
import torch
from dataclasses import dataclass
@dataclass
class serialize_details:
obj: object
dtype: torch.dtype
size: int
nbytes: int
def tensor_to_bytes(tensor):
return tensor.numpy().tobytes()
def bytes_to_tensor(buffer):
return torch.from_numpy(numpy.array(numpy.frombuffer(buffer, dtype=numpy.uint8)))
def required_minimum_torch_version(major_version, minor_version):
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
if TORCH_MAJOR < major_version:
return False
return TORCH_MAJOR > major_version or TORCH_MINOR >= minor_version
# torch < 1.12
def _legacy_obj_serialization_details(storage_obj):
nbytes = storage_obj.element_size() * storage_obj.size()
return serialize_details(obj=storage_obj, dtype=storage_obj.dtype, size=nbytes, nbytes=nbytes)
# torch >= 1.12
def _new_obj_serialization_details(storage_obj):
obj, dtype = storage_obj
return serialize_details(obj=obj,
dtype=dtype,
size=obj.size() // torch._utils._element_size(dtype),
nbytes=obj.size())
def obj_serialization_details():
if required_minimum_torch_version(1, 12):
return _new_obj_serialization_details
return _legacy_obj_serialization_details