|
|
|
import io
import pickle

import torch
from torch.package import Importer, OrderedImporter, PackageImporter, sys_importer
from torch.package._package_pickler import create_pickler
from torch.package._package_unpickler import PackageUnpickler
from torch.serialization import _maybe_decode_ascii
|
|
|
|
|
def _save_storages(importer, obj):
    """Pickle ``obj``, moving tensor storages out of the byte stream.

    Each storage met while pickling is replaced by the persistent id
    ``("storage", i)`` and returned separately. Objects that define
    ``__reduce_deploy__`` are replaced by a memoised
    ``("reduce_deploy", id(obj), *obj.__reduce_deploy__(importers))`` tuple.

    Args:
        importer: importer associated with ``obj``. It is consulted ahead of the
            system importer only when it is a ``torch.package.PackageImporter``;
            any other value is ignored.
        obj: the object to serialize.

    Returns:
        ``(data, storages, dtypes, zip_reader)``: the pickled bytes, the
        extracted storages, the dtype recorded for each storage (``torch.uint8``
        for storages that are not ``TypedStorage``), and the package importer's
        ``zip_reader``, or ``None`` when no package importer was given.
    """
    package_importer = (
        importer if isinstance(importer, torch.package.PackageImporter) else None
    )

    importers: Importer
    if package_importer is None:
        importers = sys_importer
    else:
        # Resolve modules from the package first, then fall back to sys.modules.
        importers = OrderedImporter(package_importer, sys_importer)

    storages = []
    dtypes = []

    def _persistent_id(item):
        if torch.is_storage(item) or isinstance(item, torch.storage.TypedStorage):
            # Non-typed storages carry no dtype of their own; record them as bytes.
            if isinstance(item, torch.storage.TypedStorage):
                storage_dtype = item.dtype
            else:
                storage_dtype = torch.uint8
            storages.append(item)
            dtypes.append(storage_dtype)
            return ("storage", len(storages) - 1)

        if not hasattr(item, "__reduce_deploy__"):
            # Let the pickler serialize this object normally.
            return None

        # NOTE(review): the memo is keyed by id(item) and nothing here clears it;
        # if item can be garbage-collected between calls, a new object reusing
        # the same id would receive a stale entry - confirm lifetime assumptions.
        key = id(item)
        entry = _serialized_reduces.get(key)
        if entry is None:
            entry = ("reduce_deploy", key, *item.__reduce_deploy__(importers))
            _serialized_reduces[key] = entry
        return entry

    buffer = io.BytesIO()
    pickler = create_pickler(buffer, importers)
    pickler.persistent_id = _persistent_id
    pickler.dump(obj)
    payload = buffer.getvalue()

    zip_reader = package_importer.zip_reader if package_importer is not None else None
    return payload, storages, dtypes, zip_reader
|
|
|
|
|
def _load_storages(id, zip_reader, obj_bytes, serialized_storages, serialized_dtypes):
    """Unpickle a payload produced by ``_save_storages`` and cache the result.

    Args:
        id: key under which the unpickled object is stored in ``_deploy_objects``.
        zip_reader: the ``zip_reader`` returned by ``_save_storages``. When not
            ``None``, the matching ``PackageImporter`` (obtained via
            ``_get_package``) is consulted before the system importer.
        obj_bytes: the pickled bytes returned by ``_save_storages``.
        serialized_storages: storages referenced by ``("storage", i)`` ids.
        serialized_dtypes: dtype to use for each entry of ``serialized_storages``.

    Returns:
        The unpickled object.

    Raises:
        pickle.UnpicklingError: if the payload contains a persistent id that is
            not a tuple, or whose tag is neither ``"storage"`` nor
            ``"reduce_deploy"``.
    """

    def persistent_load(saved_id):
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if not isinstance(saved_id, tuple):
            raise pickle.UnpicklingError(
                f"Expected a tuple persistent id, got {type(saved_id).__name__}"
            )
        typename = _maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]

        if typename == "storage":
            index = data[0]
            return torch.storage.TypedStorage(
                wrap_storage=serialized_storages[index].untyped(),
                dtype=serialized_dtypes[index],
            )

        if typename == "reduce_deploy":
            reduce_id, func, args = data
            if reduce_id not in _loaded_reduces:
                # _raw_packages[zip_reader] is filled by the _get_package() call
                # below before unpickling starts. NOTE(review): when zip_reader
                # is None this lookup raises KeyError - confirm that such
                # payloads never contain reduce_deploy ids.
                _loaded_reduces[reduce_id] = func(_raw_packages[zip_reader], *args)
            return _loaded_reduces[reduce_id]

        # Fail loudly: returning None would make pickle silently substitute None
        # for the referenced object.
        raise pickle.UnpicklingError(f"Unsupported persistent id type: {typename!r}")

    importer: Importer
    if zip_reader is not None:
        importer = OrderedImporter(_get_package(zip_reader), sys_importer)
    else:
        importer = sys_importer

    unpickler = PackageUnpickler(importer, io.BytesIO(obj_bytes))
    unpickler.persistent_load = persistent_load
    result = _deploy_objects[id] = unpickler.load()
    return result
|
|
|
|
|
def _get_package(zip_reader):
    """Return the ``PackageImporter`` for ``zip_reader``, creating it on first use.

    Importers are memoised in ``_raw_packages`` so every call with the same
    ``zip_reader`` returns the same instance.
    """
    if zip_reader in _raw_packages:
        return _raw_packages[zip_reader]
    package = PackageImporter(zip_reader)
    _raw_packages[zip_reader] = package
    return package
|
|
|
|
|
# Module-level caches used by the helpers above; nothing in this module clears them.
# zip_reader -> PackageImporter, populated lazily by _get_package().
_raw_packages: dict = dict()
# The ``id`` argument of _load_storages() -> the object it unpickled.
_deploy_objects: dict = dict()
# id(obj) -> ("reduce_deploy", id(obj), *obj.__reduce_deploy__(...)), memoised by _save_storages().
_serialized_reduces: dict = dict()
# reduce_id -> object rebuilt from a "reduce_deploy" persistent id by _load_storages().
_loaded_reduces: dict = dict()
|
|