# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import pickle | |
from typing import Any | |
import torch | |
def pyobject_to_tensor(obj: Any, fixed_buffer_size: int = 0) -> torch.Tensor:
    """Serialize a Python object into a 1-D ``torch.uint8`` tensor.

    The object is pickled and the pickle bytes become the tensor's elements.
    When ``fixed_buffer_size`` is non-zero, the result is zero-padded at the
    end so that it holds exactly ``fixed_buffer_size`` elements.

    Args:
        obj: Any picklable object.
        fixed_buffer_size: Exact length of the returned tensor. ``0`` (the
            default) disables padding and returns a tensor as long as the
            pickled payload.

    Returns:
        A byte tensor holding the pickled payload, followed by zero padding
        when a fixed size was requested.

    Raises:
        ValueError: If the pickled payload is longer than a non-zero
            ``fixed_buffer_size``.
    """
    payload = bytearray(pickle.dumps(obj))
    if fixed_buffer_size:
        delta = fixed_buffer_size - len(payload)
        # Validate before building any tensor so an oversized message fails
        # without allocating one.
        if delta < 0:
            raise ValueError(
                f"message too big to send, increase `fixed_buffer_size`? - {len(payload)} > {fixed_buffer_size}"
            )
        # Pad the raw bytes rather than concatenating tensors: one tensor
        # allocation instead of three. ``bytes(delta)`` is ``delta`` zero bytes
        # (empty when the payload already fits exactly).
        payload.extend(bytes(delta))
    return torch.ByteTensor(payload)
def tensor_to_pyobject(tensor: torch.Tensor) -> Any:
    """Rebuild a Python object from a tensor of pickle bytes.

    The tensor is copied to the CPU if necessary, its raw bytes are extracted
    through NumPy, and the bytes are unpickled. Bytes after the end of the
    pickled representation (for example zero padding) are ignored by
    ``pickle.loads``.

    Args:
        tensor: Tensor whose raw bytes form a pickle stream, optionally
            followed by padding. Presumably a ``torch.uint8`` tensor built
            from ``pickle.dumps`` output; confirm against callers.

    Returns:
        The unpickled Python object.
    """
    raw_bytes = tensor.cpu().numpy().tobytes()
    # SECURITY: unpickling can execute arbitrary code. Only call this on
    # tensors received from trusted peers.
    return pickle.loads(raw_bytes)