|
import pickle |
|
from dataclasses import dataclass |
|
from io import BufferedIOBase |
|
from typing import Any |
|
|
|
import torch |
|
import torch._weights_only_unpickler as _weights_only_unpickler |
|
from torch.serialization import _load, _save, DEFAULT_PROTOCOL, MAP_LOCATION |
|
|
|
|
|
# Nothing is exported via ``from ... import *``; the helpers defined in this
# module are all underscore-prefixed (private).
__all__: list[str] = []
|
|
|
|
|
@dataclass |
|
class _Entry: |
|
key: str |
|
is_storage: bool |
|
length: int |
|
|
|
|
|
_weights_only_unpickler._add_safe_globals([_Entry]) |
|
|
|
|
|
class _PseudoZipFile: |
|
def __init__(self) -> None: |
|
self.records: dict[str, tuple[object, int]] = {} |
|
|
|
def write_record(self, key: str, data: object, length: int) -> None: |
|
self.records[key] = (data, length) |
|
|
|
def write_to(self, f: BufferedIOBase) -> None: |
|
entries = [] |
|
for key, (data, length) in self.records.items(): |
|
entries.append( |
|
_Entry( |
|
key=key, |
|
is_storage=isinstance(data, torch.UntypedStorage), |
|
length=length, |
|
) |
|
) |
|
|
|
pickle.dump(entries, f, protocol=DEFAULT_PROTOCOL) |
|
|
|
for key, (data, length) in self.records.items(): |
|
if isinstance(data, bytes): |
|
f.write(data) |
|
elif isinstance(data, str): |
|
f.write(data.encode("utf-8")) |
|
elif isinstance(data, torch.UntypedStorage): |
|
data._write_file(f, False, False, 1) |
|
else: |
|
raise TypeError(f"unknown type: {type(data)}") |
|
|
|
def read_from(self, f: BufferedIOBase) -> None: |
|
entries = _weights_only_unpickler.load(f) |
|
|
|
for entry in entries: |
|
data = f.read(entry.length) |
|
if entry.is_storage: |
|
storage = torch.frombuffer( |
|
data, |
|
dtype=torch.uint8, |
|
).untyped_storage() |
|
|
|
self.records[entry.key] = ( |
|
storage, |
|
entry.length, |
|
) |
|
else: |
|
self.records[entry.key] = (data, entry.length) |
|
|
|
def has_record(self, key: str) -> bool: |
|
return key in self.records |
|
|
|
def get_record(self, key: str) -> object: |
|
return self.records[key][0] |
|
|
|
def get_storage_from_record( |
|
self, key: str, _length: int, _type: int |
|
) -> torch.Tensor: |
|
return torch.tensor(self.records[key][0], dtype=torch.uint8) |
|
|
|
def serialization_id(self) -> str: |
|
return "torchft" |
|
|
|
|
|
def _streaming_save(
    obj: object,
    f: BufferedIOBase,
    pickle_module: Any = pickle,
    pickle_protocol: int = DEFAULT_PROTOCOL,
) -> None:
    """
    Serialize ``obj`` into a file-like object using a streaming-friendly
    layout (suitable, for example, for network sockets).

    Differences from :func:`torch.save`:

    * The result can be loaded back from a non-seekable file-like object.
    * The format carries no forwards/backwards compatibility guarantees; it is
      only meant to be read back by the same PyTorch version, over transient
      storage such as sockets or temporary files.
    * mmap is not supported.

    See :func:`torch.save` for details on the individual arguments.
    """
    # torch's regular save logic fills an in-memory record container, which
    # is then written to ``f`` sequentially.
    records = _PseudoZipFile()
    _save(
        obj,
        zip_file=records,
        pickle_module=pickle_module,
        pickle_protocol=pickle_protocol,
        _disable_byteorder_record=False,
    )
    records.write_to(f)
|
|
|
|
|
def _streaming_load(
    f: BufferedIOBase,
    map_location: MAP_LOCATION = None,
    pickle_module: Any = None,
    *,
    weights_only: bool = True,
    **pickle_load_args: Any,
) -> object:
    """
    Deserialize an object written by :func:`_streaming_save`, reading ``f``
    sequentially so that non-seekable streams such as network sockets work.

    See :func:`_streaming_save` for the limitations of the streaming format
    and :func:`torch.load` for details on the individual arguments.

    Raises:
        RuntimeError: if ``weights_only`` is true and an explicit
            ``pickle_module`` is also supplied.
    """
    if weights_only and pickle_module is not None:
        raise RuntimeError(
            "Can not safely load weights when explicit pickle_module is specified"
        )

    if weights_only:
        # Only allow-listed globals may be reconstructed.
        unpickler = _weights_only_unpickler
    elif pickle_module is None:
        unpickler = pickle
    else:
        unpickler = pickle_module

    pickle_load_args.setdefault("encoding", "utf-8")

    records = _PseudoZipFile()
    records.read_from(f)
    return _load(
        zip_file=records,
        map_location=map_location,
        pickle_module=unpickler,
        **pickle_load_args,
    )
|
|