File size: 8,446 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 |
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#
"""
Binary tensor encodings for PyTorch and NumPy.
This defines efficient binary encodings for tensors. The format is 8 byte
aligned and can be used directly for computations when transmitted, say,
via RDMA. The format is supported by WebDataset with the `.ten` filename
extension. It is also used by Tensorcom, Tensorcom RDMA, and can be used
for fast tensor storage with LMDB and in disk files (which can be memory
mapped).
Data is encoded as a series of chunks:
- magic number (int64)
- length in bytes (int64)
- bytes (multiple of 64 bytes long)
Arrays are a header chunk followed by a data chunk.
Header chunks have the following structure:
- dtype (int64)
- 8 byte array name
- ndim (int64)
- dim[0]
- dim[1]
- ...
"""
import struct
import sys
import numpy as np
def bytelen(a):
    """Return the size of ``a`` in bytes.

    Objects exposing an ``nbytes`` attribute report that value; ``bytes`` and
    ``bytearray`` objects report their length.

    :param a: buffer-like object
    :return: number of bytes in ``a``
    :raises ValueError: if the size cannot be determined
    """
    if hasattr(a, "nbytes"):
        return a.nbytes
    if not isinstance(a, (bytes, bytearray)):
        raise ValueError(a, "cannot determine nbytes")
    return len(a)
def bytedata(a):
    """Return the raw data buffer corresponding to ``a``.

    ``bytes``, ``bytearray`` and ``memoryview`` objects are returned as-is;
    any other object must expose its buffer through a ``data`` attribute.

    :param a: buffer-like object
    :return: ``a`` itself, or ``a.data``
    :raises ValueError: if no raw data can be obtained from ``a``
    """
    if isinstance(a, (bytearray, bytes, memoryview)):
        return a
    if not hasattr(a, "data"):
        raise ValueError(a, "cannot return bytedata")
    return a.data
# Tables for converting between long and short NumPy dtype names.
# Each row of the table is "<NumPy dtype name> <short code>"; the short code
# is what gets packed into the array header.
long_to_short = """
float16 f2
float32 f4
float64 f8
int8 i1
int16 i2
int32 i4
int64 i8
uint8 u1
uint16 u2
uint32 u4
uint64 u8
""".strip()
# Parse the table into a {long name: short code} mapping ...
long_to_short = dict(row.split() for row in long_to_short.splitlines())
# ... and build the inverse {short code: long name} mapping used when decoding.
short_to_long = {short: long for long, short in long_to_short.items()}
def check_acceptable_input_type(data, allow64):
    """Check that the data has an acceptable type for tensor encoding.

    :param data: iterable of arrays (each must expose ``dtype.name``)
    :param allow64: allow 64 bit types
    :raises ValueError: if an array has a dtype that is not in the dtype
        table, or has a 64 bit dtype while ``allow64`` is false
    """
    for a in data:
        name = a.dtype.name
        if name not in long_to_short:
            raise ValueError("unsupported datatype")
        # Only the 64 bit types are subject to the allow64 switch; narrower
        # types are always acceptable.
        if not allow64 and name in ("float64", "int64", "uint64"):
            raise ValueError("64 bit datatypes not allowed unless explicitly enabled")
def str64(s):
    """Pack a short ASCII string into a native int64.

    The string is padded with NUL characters up to 8 characters before being
    reinterpreted as a single native-order 64 bit integer.

    :param s: string to pack
    :return: the integer whose 8 native bytes spell the padded string
    """
    padded = s.ljust(8, "\0")
    (value,) = struct.unpack("@q", padded.encode("ascii"))
    return value
def unstr64(i):
    """Unpack a native int64 into the ASCII string it encodes.

    :param i: integer to unpack
    :return: the decoded string with leading and trailing NUL characters removed
    """
    return struct.pack("@q", i).decode("ascii").strip("\0")
def check_infos(data, infos, required_infos=None):
    """Verify the info strings and select what to return.

    :param data: decoded arrays
    :param infos: info strings that were decoded alongside ``data``
    :param required_infos: ``None`` or ``False`` to skip verification,
        ``True`` to return the infos together with the data, or a tuple/list
        of expected info strings
    :return: ``(data, infos)`` if ``required_infos`` is ``True``, otherwise
        ``data``
    :raises ValueError: if ``required_infos`` has an unsupported type, or an
        actual info string differs from the required one
    """
    if required_infos is False or required_infos is None:
        return data
    if required_infos is True:
        return data, infos
    if not isinstance(required_infos, (tuple, list)):
        raise ValueError("required_infos must be tuple or list")
    # Compare pairwise; zip stops at the shorter sequence, so surplus entries
    # on either side are not checked.
    for required, actual in zip(required_infos, infos):
        if required != actual:
            raise ValueError(f"actual info {actual} doesn't match required info {required}")
    return data
def encode_header(a, info=""):
    """Build the binary header describing array ``a``.

    The header is a sequence of int64 values: the short dtype code and the
    info string (both packed with ``str64``), the number of dimensions, and
    then each dimension.

    :param a: array to describe
    :param info: short string stored in the header next to the dtype
    :return: raw buffer over the int64 header values
    :raises ValueError: if the array has 10 or more dimensions, its byte size
        disagrees with its shape, or its dtype is not supported
    """
    if a.ndim >= 10:
        raise ValueError("too many dimensions")
    expected_nbytes = np.prod(a.shape) * a.itemsize
    if a.nbytes != expected_nbytes:
        raise ValueError("mismatch between size and shape")
    dtype_name = a.dtype.name
    if dtype_name not in long_to_short:
        raise ValueError("unsupported array type")
    fields = [str64(long_to_short[dtype_name]), str64(info), len(a.shape), *a.shape]
    return bytedata(np.array(fields, dtype="i8"))
def decode_header(h):
    """Parse a binary array header.

    :param h: buffer holding the int64 header values
    :return: tuple ``(shape, dtype, info)``
    :raises ValueError: if the dtype code in the header is not supported
    """
    words = np.frombuffer(h, dtype="i8")
    code = unstr64(words[0])
    if code not in short_to_long:
        raise ValueError("unsupported array type")
    dtype = np.dtype(short_to_long[code])
    info = unstr64(words[1])
    # The third word gives the number of dimensions that follow it.
    ndim = int(words[2])
    shape = tuple(words[3 : 3 + ndim])
    return shape, dtype, info
def encode_list(l, infos=None):
    """Encode a list of arrays as alternating header and data buffers.

    :param l: list of arrays
    :param infos: optional list of info strings, one per array; when omitted
        every array gets an empty info string
    :return: list ``[header0, data0, header1, data1, ...]``
    :raises ValueError: if ``infos`` is given and its length differs from ``l``
    """
    if infos is None:
        infos = [""]
    elif len(l) != len(infos):
        raise ValueError(f"length of list {l} must match length of infos {infos}")
    n_infos = len(infos)
    encoded = []
    for index, array in enumerate(l):
        # The modulo lets the single default info string serve every array.
        encoded.append(encode_header(array, infos[index % n_infos]))
        encoded.append(bytedata(array))
    return encoded
def decode_list(l, infos=False):
    """Given a list of byte arrays, decode them into arrays.

    :param l: list of alternating header and data buffers
    :param infos: forwarded to ``check_infos`` as ``required_infos``
    :return: whatever ``check_infos`` returns for the decoded arrays and
        their info strings
    """
    arrays = []
    found_infos = []
    for header, data in zip(l[::2], l[1::2]):
        shape, dtype, info = decode_header(header)
        # np.prod of an empty shape is the float 1.0, which frombuffer does
        # not accept as a count, and reshape(*shape) with an empty shape is a
        # call without any shape. Use an int count and pass the shape as a
        # tuple so that 0-d arrays can be decoded as well.
        count = int(np.prod(shape))
        arrays.append(np.frombuffer(data, dtype=dtype, count=count).reshape(shape))
        found_infos.append(info)
    return check_infos(arrays, found_infos, infos)
# Marker string that starts every encoded chunk; decoders raise if it is missing.
magic_str = "~TenBin~"
# The marker packed into a single int64 via str64.
magic = str64(magic_str)
# The marker as ASCII bytes, which is the form written to and compared
# against chunk buffers and streams.
magic_bytes = unstr64(magic).encode("ascii")
def roundup(n, k=64):
    """Round ``n`` up to the next multiple of ``k`` (64 by default).

    :param n: value to round
    :param k: granularity to round to
    :return: the smallest multiple of ``k`` that is not below ``n``, for
        positive integer ``k``
    """
    blocks = (n + k - 1) // k
    return k * blocks
def encode_chunks(l):
    """Encode a list of chunks into a single byte array, with lengths and magics.

    Each chunk is laid out as the 8 byte magic marker, the payload length as
    a native int64, and the payload padded up to the ``roundup`` boundary.

    :param l: list of buffers; anything ``bytelen`` can measure is accepted
    :return: ``bytearray`` containing all encoded chunks
    """
    # Measure every chunk with bytelen so that plain bytes/bytearray chunks
    # (which have no ``nbytes`` attribute) are handled like array buffers.
    sizes = [bytelen(b) for b in l]
    result = bytearray(sum(16 + roundup(size) for size in sizes))
    offset = 0
    for b, size in zip(l, sizes):
        result[offset : offset + 8] = magic_bytes
        offset += 8
        result[offset : offset + 8] = struct.pack("@q", size)
        offset += 8
        result[offset : offset + size] = b
        # Skip over the padding; the bytearray was created zero-filled.
        offset += roundup(size)
    return result
def decode_chunks(buf):
    """Decode a byte array into a list of chunks.

    :param buf: buffer holding consecutive chunks, each made of the magic
        marker, a native int64 length, and a padded payload
    :return: list of payload slices of ``buf``, without padding
    :raises ValueError: if a chunk does not start with the magic marker or
        declares a negative length
    """
    chunks = []
    offset = 0
    total = bytelen(buf)
    while offset < total:
        if magic_bytes != buf[offset : offset + 8]:
            raise ValueError("magic bytes mismatch")
        offset += 8
        nbytes = struct.unpack("@q", buf[offset : offset + 8])[0]
        # Same validation as read_chunk: a negative length from corrupt input
        # would otherwise move the offset backwards.
        if nbytes < 0:
            raise ValueError("negative nbytes")
        offset += 8
        chunks.append(buf[offset : offset + nbytes])
        # Payloads are padded, so advance by the rounded-up length.
        offset += roundup(nbytes)
    return chunks
def encode_buffer(l, infos=None):
    """Encode a list of arrays into one contiguous byte array.

    :param l: list of arrays (must be an actual ``list``)
    :param infos: optional info strings, passed on to ``encode_list``
    :return: the result of ``encode_chunks`` over the encoded list
    :raises ValueError: if ``l`` is not a list
    """
    if isinstance(l, list):
        chunks = encode_list(l, infos=infos)
        return encode_chunks(chunks)
    raise ValueError("requires list")
def decode_buffer(buf, infos=False):
    """Decode a single byte array back into a list of arrays.

    :param buf: buffer to split with ``decode_chunks``
    :param infos: passed on to ``decode_list``
    :return: the result of ``decode_list`` over the decoded chunks
    """
    chunks = decode_chunks(buf)
    return decode_list(chunks, infos=infos)
def write_chunk(stream, buf):
    """Write one chunk to ``stream``: magic marker, length, payload, padding.

    :param stream: object with a ``write`` method accepting bytes-like data
    :param buf: payload to write
    """
    size = bytelen(buf)
    stream.write(magic_bytes)
    stream.write(struct.pack("@q", size))
    stream.write(bytedata(buf))
    pad = roundup(size) - size
    if pad <= 0:
        return
    # Zero-fill up to the roundup boundary.
    stream.write(b"\0" * pad)
def read_chunk(stream):
    """Read one chunk from ``stream``: magic marker, length, payload, padding.

    :param stream: object with a ``read(n)`` method returning bytes
    :return: the payload, or ``None`` if the stream is already at its end
    :raises ValueError: if the magic marker does not match or the length is
        negative
    """
    marker = stream.read(8)
    if marker == b"":
        # Nothing left to read: report end of stream instead of an error.
        return None
    if marker != magic_bytes:
        raise ValueError("magic number does not match")
    (size,) = struct.unpack("@q", stream.read(8))
    if size < 0:
        raise ValueError("negative nbytes")
    payload = stream.read(size)
    pad = roundup(size) - size
    if pad > 0:
        # Consume the padding so the next read starts at the next chunk.
        stream.read(pad)
    return payload
def write(stream, l, infos=None):
    """Write a sequence of arrays to ``stream`` as header and data chunks.

    :param stream: object with a ``write`` method
    :param l: sequence of arrays
    :param infos: optional info strings, passed on to ``encode_list``
    """
    chunks = encode_list(l, infos=infos)
    for chunk in chunks:
        write_chunk(stream, chunk)
def read(stream, n=sys.maxsize, infos=False):
    """Read up to ``n`` arrays from ``stream``.

    :param stream: object with a ``read(n)`` method returning bytes
    :param n: maximum number of arrays to read
    :param infos: passed on to ``decode_list``
    :return: the result of ``decode_list`` over the chunks that were read
    :raises ValueError: if the stream ends between a header and its data
    """
    pieces = []
    for _ in range(n):
        head = read_chunk(stream)
        if head is None:
            # Clean end of stream on an array boundary.
            break
        body = read_chunk(stream)
        if body is None:
            raise ValueError("premature EOF")
        pieces.extend((head, body))
    return decode_list(pieces, infos=infos)
def save(fname, *args, infos=None, nocheck=False):
    """Save arrays to the file ``fname``.

    :param fname: path of the file to create
    :param args: arrays to store
    :param infos: optional info strings, passed on to ``write``
    :param nocheck: skip the check that ``fname`` ends in ``.ten``
    :raises ValueError: if the extension check is enabled and fails
    """
    if not (nocheck or fname.endswith(".ten")):
        raise ValueError("file name should end in .ten")
    with open(fname, "wb") as stream:
        write(stream, args, infos=infos)
def load(fname, infos=False, nocheck=False):
    """Load arrays from the file ``fname``.

    :param fname: path of the file to read
    :param infos: passed on to ``read``
    :param nocheck: skip the check that ``fname`` ends in ``.ten``
    :return: the result of ``read`` on the opened file
    :raises ValueError: if the extension check is enabled and fails
    """
    if not (nocheck or fname.endswith(".ten")):
        raise ValueError("file name should end in .ten")
    with open(fname, "rb") as stream:
        return read(stream, infos=infos)
|