jamtur01's picture
Upload folder using huggingface_hub
9c6594c verified
raw
history blame contribute delete
691 Bytes
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import cupy
from torch.utils.dlpack import to_dlpack
from torch.utils.dlpack import from_dlpack
class CupyBackend:
    """Helper that moves tensors between PyTorch and CuPy and bit-packs
    boolean arrays with CuPy.

    Tensor exchange in both directions goes through the DLPack protocol.
    """

    def __init__(self):
        """Create a backend; it holds no state."""

    def torch2cupy(self, tensor):
        """Return a CuPy array wrapping the torch ``tensor``.

        The tensor is exported as a DLPack capsule with
        ``torch.utils.dlpack.to_dlpack`` and imported with
        ``cupy.from_dlpack``. DLPack exchange is expected to share the
        underlying memory rather than copy it.
        """
        capsule = to_dlpack(tensor)
        return cupy.from_dlpack(capsule)

    def cupy2torch(self, cupy_tensor):
        """Return a torch tensor built from ``cupy_tensor``.

        The CuPy array is handed directly to
        ``torch.utils.dlpack.from_dlpack``. This relies on the array
        implementing the ``__dlpack__`` protocol.
        """
        return from_dlpack(cupy_tensor)

    def compress_by_chunk(self, cupy_bool_tensor, num_chunks):
        """Bit-pack a boolean array and split the packed bytes into chunks.

        Args:
            cupy_bool_tensor: CuPy array passed to ``cupy.packbits``. It is
                flattened and packed eight elements per output ``uint8``.
                The last byte is zero-padded.
            num_chunks: number of pieces requested from ``cupy.split``.
                With an integer argument, the packed length must divide
                evenly by it; otherwise ``cupy.split`` raises.

        Returns:
            List of CuPy arrays, one per chunk.

        The current CUDA stream is synchronized before returning, so the
        packing work has finished by the time callers use the chunks.
        """
        packed = cupy.packbits(cupy_bool_tensor)
        chunks = cupy.split(packed, num_chunks)
        # Block until all work queued on the current stream has completed.
        cupy.cuda.get_current_stream().synchronize()
        return chunks