|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Extends `dill` to support pickling more types and produce more consistent dumps.""" |
|
|
|
import os |
|
import sys |
|
from io import BytesIO |
|
from types import CodeType, FunctionType |
|
|
|
import dill |
|
from packaging import version |
|
|
|
from .. import config |
|
|
|
|
|
class Pickler(dill.Pickler):
    """Custom `dill` pickler with lazily registered reducers for third-party types.

    Reducers for optional libraries (regex, spacy, tiktoken, torch, transformers)
    are only registered when the library is already imported by the user, so this
    module never imports an optional dependency itself.
    """

    # Own copy of dill's dispatch table so registrations via `pklregister`
    # don't leak into the global `dill.Pickler`.
    dispatch = dill._dill.MetaCatchingDict(dill.Pickler.dispatch.copy())
    # When True, `_batch_setitems` skips the deterministic dict-key sorting
    # (legacy behavior switch; default is to sort).
    _legacy_no_dict_keys_sorting = False

    def save(self, obj, save_persistent_id=True):
        """Save `obj`, registering a custom reducer for its type on first encounter."""
        obj_type = type(obj)
        if obj_type not in self.dispatch:
            # Only consider libraries that are already loaded ("x" in sys.modules):
            # the `import` below is then effectively free and never pulls in a new
            # optional dependency as a side effect of pickling.
            if "regex" in sys.modules:
                import regex

                if obj_type is regex.Pattern:
                    pklregister(obj_type)(_save_regexPattern)
            if "spacy" in sys.modules:
                import spacy

                if issubclass(obj_type, spacy.Language):
                    pklregister(obj_type)(_save_spacyLanguage)
            if "tiktoken" in sys.modules:
                import tiktoken

                if obj_type is tiktoken.Encoding:
                    pklregister(obj_type)(_save_tiktokenEncoding)
            if "torch" in sys.modules:
                import torch

                if issubclass(obj_type, torch.Tensor):
                    pklregister(obj_type)(_save_torchTensor)

                if obj_type is torch.Generator:
                    pklregister(obj_type)(_save_torchGenerator)

                # Unwrap `torch.compile`-ed modules: pickle the original module
                # stored in `_orig_mod` instead of the compiled wrapper.
                if issubclass(obj_type, torch.nn.Module):
                    obj = getattr(obj, "_orig_mod", obj)
            if "transformers" in sys.modules:
                import transformers

                if issubclass(obj_type, transformers.PreTrainedTokenizerBase):
                    pklregister(obj_type)(_save_transformersPreTrainedTokenizerBase)

        # Unwrap `torch.compile`-ed functions (no-op for plain functions:
        # getattr falls back to `obj` itself).
        if obj_type is FunctionType:
            obj = getattr(obj, "_torchdynamo_orig_callable", obj)
        dill.Pickler.save(self, obj, save_persistent_id=save_persistent_id)

    def _batch_setitems(self, items):
        """Pickle dict items in sorted key order so dumps don't depend on insertion order."""
        if self._legacy_no_dict_keys_sorting:
            return super()._batch_setitems(items)

        try:
            # Fast path: keys are directly comparable.
            items = sorted(items)
        except Exception:
            # Keys of mixed/unorderable types: order them by their datasets hash,
            # which is deterministic across sessions.
            from datasets.fingerprint import Hasher

            items = sorted(items, key=lambda x: Hasher.hash(x[0]))
        dill.Pickler._batch_setitems(self, items)

    def memoize(self, obj):
        """Memoize everything except strings, keeping dumps of equal strings identical."""
        # Skipping the memo for `str` means identical strings are re-pickled
        # in full instead of as memo back-references, so the dump does not
        # depend on object identity/ids of equal strings.
        if type(obj) is not str:
            dill.Pickler.memoize(self, obj)
|
|
|
|
|
def pklregister(t):
    """Return a decorator that registers a reducer for type ``t``.

    The decorated function is stored in ``Pickler.dispatch`` and handed back
    unchanged, so it remains usable as a plain function.
    """

    def _register(reducer):
        Pickler.dispatch[t] = reducer
        return reducer

    return _register
|
|
|
|
|
def dump(obj, file):
    """Pickle an object to a file."""
    pickler = Pickler(file, recurse=True)
    pickler.dump(obj)
|
|
|
|
|
def dumps(obj):
    """Pickle an object and return the pickled bytes."""
    buffer = BytesIO()
    dump(obj, buffer)
    return buffer.getvalue()
|
|
|
|
|
# `dill` changed its pickling-trace API across versions: before 0.3.6 there is
# a module-level logger, from 0.3.6 to 0.3.8 tracing goes through
# `logger.trace(pickler, ...)`. Define a matching `log` helper for the
# installed version.
if config.DILL_VERSION < version.parse("0.3.6"):

    def log(pickler, msg):
        # dill < 0.3.6: plain module-level logger; `pickler` is ignored.
        dill._dill.log.info(msg)

elif config.DILL_VERSION.release[:3] in [
    version.parse("0.3.6").release,
    version.parse("0.3.7").release,
    version.parse("0.3.8").release,
]:

    def log(pickler, msg):
        # dill 0.3.6-0.3.8: trace logger bound to the pickler instance.
        dill._dill.logger.trace(pickler, msg)

# NOTE(review): for dill versions outside these ranges `log` stays undefined and
# the reducers below would raise NameError — presumably `config` rejects
# unsupported dill versions at import time; confirm.
|
|
|
|
|
@pklregister(set)
def _save_set(pickler, obj):
    """Reducer for ``set``: pickle the elements in a deterministic order."""
    log(pickler, f"Se: {obj}")
    try:
        # Elements of mutually comparable types: a plain sort is deterministic.
        elements = sorted(obj)
    except Exception:
        # Unorderable/mixed elements: order them by their datasets hash instead.
        from datasets.fingerprint import Hasher

        elements = sorted(obj, key=Hasher.hash)

    pickler.save_reduce(set, (elements,), obj=obj)
    log(pickler, "# Se")
|
|
|
|
|
def _save_regexPattern(pickler, obj):
    """Reducer for `regex.Pattern`: rebuild by recompiling pattern + flags."""
    import regex

    log(pickler, f"Re: {obj}")
    pickler.save_reduce(regex.compile, (obj.pattern, obj.flags), obj=obj)
    log(pickler, "# Re")
|
|
|
|
|
def _save_tiktokenEncoding(pickler, obj):
    """Reducer for `tiktoken.Encoding`: rebuild from its constructor arguments."""
    import tiktoken

    log(pickler, f"Enc: {obj}")
    constructor_args = (
        obj.name,
        obj._pat_str,
        obj._mergeable_ranks,
        obj._special_tokens,
    )
    pickler.save_reduce(tiktoken.Encoding, constructor_args, obj=obj)
    log(pickler, "# Enc")
|
|
|
|
|
def _save_torchTensor(pickler, obj):
    """Reducer for `torch.Tensor`: serialize via a NumPy array for stable dumps."""
    import torch

    # NOTE(review): this inner function is passed to `save_reduce` and so is
    # itself serialized by dill into the dump — renaming it or its locals would
    # change the pickled bytes (and thus fingerprints); keep it stable.
    def create_torchTensor(np_array, dtype=None):
        tensor = torch.from_numpy(np_array)
        if dtype:
            tensor = tensor.type(dtype)
        return tensor

    log(pickler, f"To: {obj}")
    if obj.dtype == torch.bfloat16:
        # NumPy has no bfloat16: round-trip through float32 and pass the
        # original dtype so it is restored on load.
        args = (obj.detach().to(torch.float).cpu().numpy(), torch.bfloat16)
    else:
        args = (obj.detach().cpu().numpy(),)
    pickler.save_reduce(create_torchTensor, args, obj=obj)
    log(pickler, "# To")
|
|
|
|
|
def _save_torchGenerator(pickler, obj):
    """Reducer for `torch.Generator`: serialize via its RNG state tensor."""
    import torch

    # NOTE(review): this inner function is passed to `save_reduce` and so is
    # serialized by dill into the dump — keep its code stable so dumps stay
    # consistent.
    def create_torchGenerator(state):
        generator = torch.Generator()
        generator.set_state(state)
        return generator

    log(pickler, f"Ge: {obj}")
    args = (obj.get_state(),)
    pickler.save_reduce(create_torchGenerator, args, obj=obj)
    log(pickler, "# Ge")
|
|
|
|
|
def _save_spacyLanguage(pickler, obj):
    """Reducer for `spacy.Language`: serialize via its config + to_bytes() payload."""
    import spacy

    # NOTE(review): this inner function is passed to `save_reduce` and so is
    # serialized by dill into the dump — keep its code (including the `bytes`
    # parameter name, which shadows the builtin) stable so dumps stay consistent.
    def create_spacyLanguage(config, bytes):
        lang_cls = spacy.util.get_lang_class(config["nlp"]["lang"])
        lang_inst = lang_cls.from_config(config)
        return lang_inst.from_bytes(bytes)

    log(pickler, f"Sp: {obj}")
    args = (obj.config, obj.to_bytes())
    pickler.save_reduce(create_spacyLanguage, args, obj=obj)
    log(pickler, "# Sp")
|
|
|
|
|
def _save_transformersPreTrainedTokenizerBase(pickler, obj):
    """Reducer for `transformers` tokenizers: pickle the instance state without its cache.

    The `cache` dict depends on which texts were tokenized in this session, so it
    is replaced by an empty dict to keep dumps deterministic.
    """
    log(pickler, f"Tok: {obj}")
    # Ignore the `cache` attribute. Copy the instance dict BEFORE mutating it:
    # the original code assigned `state = obj.__dict__` and then overwrote
    # `state["cache"]`, which wiped the cache of the live tokenizer object as a
    # side effect of pickling it.
    state = obj.__dict__
    if "cache" in state and isinstance(state["cache"], dict):
        state = state.copy()
        state["cache"] = {}
    pickler.save_reduce(type(obj), (), state=state, obj=obj)
    log(pickler, "# Tok")
|
|
|
|
|
# Version-specific `CodeType` reducers. Both branches are adapted from dill's
# own `save_code`, with two normalizations that make dumps of interactively
# defined functions (notebooks, shells) reproducible across sessions:
#   * `co_filename` is blanked for code from `<...>` pseudo-files, ipykernel
#     temp files, and lambdas (otherwise only the basename is kept),
#   * `co_firstlineno` is forced to 1.
if config.DILL_VERSION < version.parse("0.3.6"):

    @pklregister(CodeType)
    def _save_code(pickler, obj):
        """
        From dill._dill.save_code
        This is a modified version that removes the origin (filename + line no.)
        of functions created in notebooks or shells for example.
        """
        dill._dill.log.info(f"Co: {obj}")

        # Strip non-reproducible origin info: blank the filename for code born
        # in "<...>" pseudo-files, ipykernel temp dirs, or lambdas; otherwise
        # keep just the basename (drops the machine-specific directory part).
        co_filename = (
            ""
            if obj.co_filename.startswith("<")
            or (
                len(obj.co_filename.split(os.path.sep)) > 1
                and obj.co_filename.split(os.path.sep)[-2].startswith("ipykernel_")
            )
            or obj.co_name == "<lambda>"
            else os.path.basename(obj.co_filename)
        )
        co_firstlineno = 1

        # Build the CodeType constructor args for the running interpreter.
        if dill._dill.PY3:
            # Python 3.8+ code objects add `co_posonlyargcount`.
            if hasattr(obj, "co_posonlyargcount"):
                args = (
                    obj.co_argcount,
                    obj.co_posonlyargcount,
                    obj.co_kwonlyargcount,
                    obj.co_nlocals,
                    obj.co_stacksize,
                    obj.co_flags,
                    obj.co_code,
                    obj.co_consts,
                    obj.co_names,
                    obj.co_varnames,
                    co_filename,
                    obj.co_name,
                    co_firstlineno,
                    obj.co_lnotab,
                    obj.co_freevars,
                    obj.co_cellvars,
                )
            else:
                args = (
                    obj.co_argcount,
                    obj.co_kwonlyargcount,
                    obj.co_nlocals,
                    obj.co_stacksize,
                    obj.co_flags,
                    obj.co_code,
                    obj.co_consts,
                    obj.co_names,
                    obj.co_varnames,
                    co_filename,
                    obj.co_name,
                    co_firstlineno,
                    obj.co_lnotab,
                    obj.co_freevars,
                    obj.co_cellvars,
                )
        else:
            # Python 2 layout (no kwonly/posonly counts).
            args = (
                obj.co_argcount,
                obj.co_nlocals,
                obj.co_stacksize,
                obj.co_flags,
                obj.co_code,
                obj.co_consts,
                obj.co_names,
                obj.co_varnames,
                co_filename,
                obj.co_name,
                co_firstlineno,
                obj.co_lnotab,
                obj.co_freevars,
                obj.co_cellvars,
            )
        pickler.save_reduce(CodeType, args, obj=obj)
        dill._dill.log.info("# Co")
        return


elif config.DILL_VERSION.release[:3] in [
    version.parse("0.3.6").release,
    version.parse("0.3.7").release,
    version.parse("0.3.8").release,
]:

    @pklregister(CodeType)
    def save_code(pickler, obj):
        # Adapted from dill 0.3.6+'s `save_code`; same filename / firstlineno
        # normalization as the pre-0.3.6 branch above, but reconstructing
        # through `dill._dill._create_code`.
        dill._dill.logger.trace(pickler, "Co: %s", obj)

        # Strip non-reproducible origin info (see branch above).
        co_filename = (
            ""
            if obj.co_filename.startswith("<")
            or (
                len(obj.co_filename.split(os.path.sep)) > 1
                and obj.co_filename.split(os.path.sep)[-2].startswith("ipykernel_")
            )
            or obj.co_name == "<lambda>"
            else os.path.basename(obj.co_filename)
        )
        co_firstlineno = 1

        # Pick the argument layout matching the running interpreter's code
        # object fields. NOTE(review): the hasattr probes track CPython
        # versions — co_endlinetable appeared only in early 3.11 alphas,
        # co_exceptiontable in 3.11, co_linetable in 3.10, co_posonlyargcount
        # in 3.8 — mirroring dill's own `_create_code`; confirm against the
        # installed dill version.
        if hasattr(obj, "co_endlinetable"):
            args = (
                obj.co_lnotab,
                obj.co_argcount,
                obj.co_posonlyargcount,
                obj.co_kwonlyargcount,
                obj.co_nlocals,
                obj.co_stacksize,
                obj.co_flags,
                obj.co_code,
                obj.co_consts,
                obj.co_names,
                obj.co_varnames,
                co_filename,
                obj.co_name,
                obj.co_qualname,
                co_firstlineno,
                obj.co_linetable,
                obj.co_endlinetable,
                obj.co_columntable,
                obj.co_exceptiontable,
                obj.co_freevars,
                obj.co_cellvars,
            )
        elif hasattr(obj, "co_exceptiontable"):
            args = (
                obj.co_lnotab,
                obj.co_argcount,
                obj.co_posonlyargcount,
                obj.co_kwonlyargcount,
                obj.co_nlocals,
                obj.co_stacksize,
                obj.co_flags,
                obj.co_code,
                obj.co_consts,
                obj.co_names,
                obj.co_varnames,
                co_filename,
                obj.co_name,
                obj.co_qualname,
                co_firstlineno,
                obj.co_linetable,
                obj.co_exceptiontable,
                obj.co_freevars,
                obj.co_cellvars,
            )
        elif hasattr(obj, "co_linetable"):
            args = (
                obj.co_lnotab,
                obj.co_argcount,
                obj.co_posonlyargcount,
                obj.co_kwonlyargcount,
                obj.co_nlocals,
                obj.co_stacksize,
                obj.co_flags,
                obj.co_code,
                obj.co_consts,
                obj.co_names,
                obj.co_varnames,
                co_filename,
                obj.co_name,
                co_firstlineno,
                obj.co_linetable,
                obj.co_freevars,
                obj.co_cellvars,
            )
        elif hasattr(obj, "co_posonlyargcount"):
            args = (
                obj.co_argcount,
                obj.co_posonlyargcount,
                obj.co_kwonlyargcount,
                obj.co_nlocals,
                obj.co_stacksize,
                obj.co_flags,
                obj.co_code,
                obj.co_consts,
                obj.co_names,
                obj.co_varnames,
                co_filename,
                obj.co_name,
                co_firstlineno,
                obj.co_lnotab,
                obj.co_freevars,
                obj.co_cellvars,
            )
        else:
            args = (
                obj.co_argcount,
                obj.co_kwonlyargcount,
                obj.co_nlocals,
                obj.co_stacksize,
                obj.co_flags,
                obj.co_code,
                obj.co_consts,
                obj.co_names,
                obj.co_varnames,
                co_filename,
                obj.co_name,
                co_firstlineno,
                obj.co_lnotab,
                obj.co_freevars,
                obj.co_cellvars,
            )

        pickler.save_reduce(dill._dill._create_code, args, obj=obj)
        dill._dill.logger.trace(pickler, "# Co")
        return
|
|