|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""TF-specific utils import.""" |
|
|
|
import os |
|
import warnings |
|
from functools import partial |
|
from math import ceil |
|
from uuid import uuid4 |
|
|
|
import numpy as np |
|
import pyarrow as pa |
|
from multiprocess import get_context |
|
|
|
|
|
try: |
|
from multiprocess.shared_memory import SharedMemory |
|
except ImportError: |
|
SharedMemory = None |
|
|
|
from .. import config |
|
|
|
|
|
def minimal_tf_collate_fn(features): |
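    """Collate samples into a batch dict: a dict input passes through unchanged, while a
    list of per-sample dicts is stacked column-wise."""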
|
if isinstance(features, dict): |
|
return features |
|
elif config.TF_AVAILABLE: |
|
import tensorflow as tf |
|
else: |
|
raise ImportError("Called a Tensorflow-specific function but Tensorflow is not installed.") |
|
|
|
first = features[0] |
|
batch = {} |
|
for k, v in first.items(): |
|
if isinstance(v, np.ndarray): |
|
batch[k] = np.stack([f[k] for f in features]) |
|
elif isinstance(v, tf.Tensor): |
|
batch[k] = tf.stack([f[k] for f in features]) |
|
else: |
|
batch[k] = np.array([f[k] for f in features]) |
|
return batch |
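
# A minimal usage sketch (illustrative only, not part of the library API; assumes
# TensorFlow is installed, since minimal_tf_collate_fn imports it). NumPy columns are
# stacked with np.stack, tf.Tensor columns with tf.stack:
#
#     features = [
#         {"input_ids": np.array([1, 2]), "label": 0},
#         {"input_ids": np.array([3, 4]), "label": 1},
#     ]
#     batch = minimal_tf_collate_fn(features)
#     # batch["input_ids"].shape == (2, 2); batch["label"] == np.array([0, 1])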
|
|
|
|
|
def minimal_tf_collate_fn_with_renaming(features): |
|
batch = minimal_tf_collate_fn(features) |
|
if "label" in batch: |
|
batch["labels"] = batch["label"] |
|
del batch["label"] |
|
return batch |
|
|
|
|
|
def is_numeric_pa_type(pa_type): |
|
if pa.types.is_list(pa_type): |
|
return is_numeric_pa_type(pa_type.value_type) |
|
return pa.types.is_integer(pa_type) or pa.types.is_floating(pa_type) or pa.types.is_decimal(pa_type) |
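
# For example: is_numeric_pa_type(pa.int32()) and is_numeric_pa_type(pa.list_(pa.float64()))
# are True, while is_numeric_pa_type(pa.string()) is False.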
|
|
|
|
|
def np_get_batch( |
|
indices, dataset, cols_to_retain, collate_fn, collate_fn_args, columns_to_np_types, return_dict=False |
|
): |
|
    # Indices may arrive as a tf.Tensor (e.g. from tf.py_function); convert to numpy first
    if not isinstance(indices, np.ndarray):
|
indices = indices.numpy() |
|
|
|
is_batched = True |
|
|
|
    # Optimization: load contiguous batches with a slice instead of a list of indices
    if isinstance(indices, np.integer):
|
batch = dataset[indices.item()] |
|
is_batched = False |
|
elif np.all(np.diff(indices) == 1): |
|
batch = dataset[indices[0] : indices[-1] + 1] |
|
elif isinstance(indices, np.ndarray): |
|
batch = dataset[indices] |
|
else: |
|
raise RuntimeError(f"Unexpected type for indices: {type(indices)}") |
|
|
|
    if cols_to_retain is not None:
        # Label columns are always kept, even when not explicitly requested
|
batch = { |
|
key: value |
|
for key, value in batch.items() |
|
if key in cols_to_retain or key in ("label", "label_ids", "labels") |
|
} |
|
|
|
if is_batched: |
|
        actual_size = len(list(batch.values())[0])  # Grab the length of any column; all are the same
|
|
|
        # Collate functions expect a list of dicts, not a dict of lists/arrays, so invert the batch
        batch = [{key: value[i] for key, value in batch.items()} for i in range(actual_size)]

    batch = collate_fn(batch, **collate_fn_args)
|
|
|
if return_dict: |
|
out_batch = {} |
|
for col, cast_dtype in columns_to_np_types.items(): |
|
|
|
            # Cast to a plain numpy array, in case the collate_fn returned something else
            array = np.array(batch[col])
|
array = array.astype(cast_dtype) |
|
out_batch[col] = array |
|
else: |
|
out_batch = [] |
|
for col, cast_dtype in columns_to_np_types.items(): |
|
|
|
            # Cast to a plain numpy array, in case the collate_fn returned something else
            array = np.array(batch[col])
|
array = array.astype(cast_dtype) |
|
out_batch.append(array) |
|
return out_batch |
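
# Illustrative sketch of np_get_batch (assumes the `datasets` library is importable and
# TensorFlow is installed, since minimal_tf_collate_fn imports it):
#
#     from datasets import Dataset
#     ds = Dataset.from_dict({"x": [[1, 2], [3, 4], [5, 6]], "label": [0, 1, 0]})
#     arrays = np_get_batch(
#         indices=np.array([0, 1]),
#         dataset=ds,
#         cols_to_retain=["x"],  # "label" is kept anyway, via the label-column exception
#         collate_fn=minimal_tf_collate_fn,
#         collate_fn_args={},
#         columns_to_np_types={"x": np.int64, "label": np.int64},
#         return_dict=True,
#     )
#     # arrays["x"].shape == (2, 2) and arrays["label"].shape == (2,)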
|
|
|
|
|
def dataset_to_tf( |
|
dataset, |
|
cols_to_retain, |
|
collate_fn, |
|
collate_fn_args, |
|
columns_to_np_types, |
|
output_signature, |
|
shuffle, |
|
batch_size, |
|
drop_remainder, |
|
): |
|
"""Create a tf.data.Dataset from the underlying Dataset. This is a single-process method - the multiprocess |
|
equivalent is multiprocess_dataset_to_tf. |
|
|
|
Args: |
|
dataset (`Dataset`): Dataset to wrap with tf.data.Dataset. |
|
cols_to_retain (`List[str]`): Dataset column(s) to load in the |
|
tf.data.Dataset. It is acceptable to include column names that are created by the `collate_fn` and |
|
that do not exist in the original dataset. |
|
        collate_fn (`Callable`): A function or callable object (such as a `DataCollator`) that will collate
|
lists of samples into a batch. |
|
collate_fn_args (`Dict`): A `dict` of keyword arguments to be passed to the |
|
`collate_fn`. Can be empty. |
|
columns_to_np_types (`Dict[str, np.dtype]`): A `dict` mapping column names to numpy dtypes. |
|
output_signature (`Dict[str, tf.TensorSpec]`): A `dict` mapping column names to |
|
`tf.TensorSpec` objects. |
|
        shuffle (`bool`): Shuffle the dataset order when loading. Recommended `True` for training, `False` for
|
validation/evaluation. |
|
batch_size (`int`, default `None`): Size of batches to load from the dataset. Defaults to `None`, which implies that |
|
the dataset won't be batched, but the returned dataset can be batched later with `tf_dataset.batch(batch_size)`. |
|
        drop_remainder (`bool`, default `None`): Drop the last incomplete batch when loading. If not provided,
|
defaults to the same setting as shuffle. |
|
|
|
Returns: |
|
`tf.data.Dataset` |
|
""" |
|
if config.TF_AVAILABLE: |
|
import tensorflow as tf |
|
else: |
|
        raise ImportError("Called a TensorFlow-specific function but TensorFlow is not installed.")
|
|
|
|
|
|
|
    # Prefer a stateless index-shuffle op when this TF version provides one; it lets us
    # shuffle without materializing a full shuffle buffer (see the warning below)
    if hasattr(tf, "random_index_shuffle"):
|
random_index_shuffle = tf.random_index_shuffle |
|
elif hasattr(tf.random.experimental, "index_shuffle"): |
|
random_index_shuffle = tf.random.experimental.index_shuffle |
|
else: |
|
if len(dataset) > 10_000_000: |
|
warnings.warn( |
|
"to_tf_dataset() can be memory-inefficient on versions of TensorFlow older than 2.9. " |
|
"If you are iterating over a dataset with a very large number of samples, consider " |
|
"upgrading to TF >= 2.9." |
|
) |
|
random_index_shuffle = None |
|
|
|
getter_fn = partial( |
|
np_get_batch, |
|
dataset=dataset, |
|
cols_to_retain=cols_to_retain, |
|
collate_fn=collate_fn, |
|
collate_fn_args=collate_fn_args, |
|
columns_to_np_types=columns_to_np_types, |
|
return_dict=False, |
|
) |
|
|
|
|
|
tout = [tf.dtypes.as_dtype(dtype) for dtype in columns_to_np_types.values()] |
|
|
|
@tf.function(input_signature=[tf.TensorSpec(None, tf.int64)]) |
|
def fetch_function(indices): |
|
output = tf.py_function( |
|
getter_fn, |
|
inp=[indices], |
|
Tout=tout, |
|
) |
|
return {key: output[i] for i, key in enumerate(columns_to_np_types.keys())} |
|
|
|
tf_dataset = tf.data.Dataset.range(len(dataset)) |
|
|
|
if shuffle and random_index_shuffle is not None: |
|
        # An all -1 seed is a placeholder; scan_random_index swaps in a real random seed on first use
        base_seed = tf.fill((3,), value=tf.cast(-1, dtype=tf.int64))
|
|
|
def scan_random_index(state, index): |
|
if tf.reduce_all(state == -1): |
|
|
|
|
|
                # Generate a fresh random seed once per epoch only, so that each sample
                # is visited exactly once per epoch
                state = tf.random.uniform(shape=(3,), maxval=2**62, dtype=tf.int64)
|
shuffled_index = random_index_shuffle(index=index, seed=state, max_index=len(dataset) - 1) |
|
return state, shuffled_index |
|
|
|
tf_dataset = tf_dataset.scan(base_seed, scan_random_index) |
|
elif shuffle: |
|
        # Fallback for older TF: shuffle with a full-size buffer (memory-hungry, see warning above)
        tf_dataset = tf_dataset.shuffle(tf_dataset.cardinality())
|
|
|
if batch_size is not None: |
|
tf_dataset = tf_dataset.batch(batch_size, drop_remainder=drop_remainder) |
|
|
|
tf_dataset = tf_dataset.map(fetch_function) |
|
|
|
    # Outputs of tf.py_function lose their static shape, so re-attach it from the output signature
    if batch_size is not None:
|
|
|
def ensure_shapes(input_dict): |
|
return {key: tf.ensure_shape(val, output_signature[key].shape) for key, val in input_dict.items()} |
|
|
|
else: |
|
|
|
def ensure_shapes(input_dict): |
|
            # No batching here, so drop the leading batch dimension from the signature shapes
            return {key: tf.ensure_shape(val, output_signature[key].shape[1:]) for key, val in input_dict.items()}
|
|
|
return tf_dataset.map(ensure_shapes) |
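
# Sketch of how dataset_to_tf is typically driven (illustrative values; in practice users
# call Dataset.to_tf_dataset(), which assembles these arguments itself):
#
#     import tensorflow as tf
#     tf_ds = dataset_to_tf(
#         dataset=ds,  # a datasets.Dataset whose "x" column holds length-2 sequences
#         cols_to_retain=["x"],
#         collate_fn=minimal_tf_collate_fn,
#         collate_fn_args={},
#         columns_to_np_types={"x": np.int64},
#         output_signature={"x": tf.TensorSpec(shape=(None, 2), dtype=tf.int64)},
#         shuffle=True,
#         batch_size=8,
#         drop_remainder=True,
#     )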
|
|
|
|
|
class SharedMemoryContext: |
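    # A context manager for shared memory blocks that guarantees cleanup even if a process
    # is interrupted. The process that creates a block is responsible for unlinking it at
    # the end; blocks that were merely opened are only closed.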
|
|
|
|
|
def __init__(self): |
|
self.created_shms = [] |
|
self.opened_shms = [] |
|
|
|
def get_shm(self, name, size, create): |
|
shm = SharedMemory(size=int(size), name=name, create=create) |
|
if create: |
|
|
|
            # Blocks we created are ours to unlink when the context exits
            self.created_shms.append(shm)
|
else: |
|
|
|
            # Blocks we merely opened are closed on exit, never unlinked
            self.opened_shms.append(shm)
|
return shm |
|
|
|
def get_array(self, name, shape, dtype, create): |
|
shm = self.get_shm(name=name, size=np.prod(shape) * np.dtype(dtype).itemsize, create=create) |
|
return np.ndarray(shape, dtype=dtype, buffer=shm.buf) |
|
|
|
def __enter__(self): |
|
return self |
|
|
|
def __exit__(self, exc_type, exc_value, traceback): |
|
for shm in self.created_shms: |
|
shm.close() |
|
shm.unlink() |
|
for shm in self.opened_shms: |
|
shm.close() |
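
# Usage sketch for SharedMemoryContext (illustrative, with a made-up block name):
#
#     with SharedMemoryContext() as ctx:
#         arr = ctx.get_array("my_block", shape=(2, 3), dtype=np.float32, create=True)
#         arr[:] = 0.0  # any process that opens "my_block" sees this data
#     # leaving the context closes the block and, since we created it, unlinks it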
|
|
|
|
|
class NumpyMultiprocessingGenerator: |
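    """Generator that spawns worker processes to load batches in parallel and streams them
    back to the parent through shared memory; designed for tf.data.Dataset.from_generator()."""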
|
def __init__( |
|
self, |
|
dataset, |
|
cols_to_retain, |
|
collate_fn, |
|
collate_fn_args, |
|
columns_to_np_types, |
|
output_signature, |
|
shuffle, |
|
batch_size, |
|
drop_remainder, |
|
num_workers, |
|
): |
|
self.dataset = dataset |
|
self.cols_to_retain = cols_to_retain |
|
self.collate_fn = collate_fn |
|
self.collate_fn_args = collate_fn_args |
|
self.string_columns = [col for col, dtype in columns_to_np_types.items() if dtype is np.str_] |
|
|
|
        # Strings are sent as arrays of single unicode characters ("U1") so that the arrays
        # have a fixed itemsize, which the shared memory transfer requires
        self.columns_to_np_types = {
|
col: dtype if col not in self.string_columns else np.dtype("U1") |
|
for col, dtype in columns_to_np_types.items() |
|
} |
|
self.output_signature = output_signature |
|
self.shuffle = shuffle |
|
self.batch_size = batch_size |
|
self.drop_remainder = drop_remainder |
|
self.num_workers = num_workers |
|
|
|
        # Because strings become arrays of single characters, they need one extra dimension
        self.columns_to_ranks = {
|
col: int(spec.shape.rank) if col not in self.string_columns else int(spec.shape.rank) + 1 |
|
for col, spec in output_signature.items() |
|
} |
|
|
|
def __iter__(self): |
|
|
|
        # Never spawn more workers than there are batches for them to load
        num_workers = min(self.num_workers, int(ceil(len(self.dataset) / self.batch_size)))
|
|
|
per_worker_batches, final_batch, final_batch_worker = self.distribute_batches( |
|
self.dataset, self.batch_size, self.drop_remainder, num_workers, self.shuffle |
|
) |
|
        ctx = get_context("spawn")  # Spawn fresh processes; TensorFlow state generally doesn't survive fork()
|
names = [] |
|
shape_arrays = [] |
|
workers = [] |
|
        # Paired handshake events: a worker sets "ready" when a batch is in shared memory,
        # and the parent sets "loaded" once it has copied the batch out
        array_ready_events = [ctx.Event() for _ in range(num_workers)]
|
array_loaded_events = [ctx.Event() for _ in range(num_workers)] |
|
|
|
base_args = { |
|
"dataset": self.dataset, |
|
"cols_to_retain": self.cols_to_retain, |
|
"collate_fn": self.collate_fn, |
|
"collate_fn_args": self.collate_fn_args, |
|
"columns_to_np_types": self.columns_to_np_types, |
|
"columns_to_ranks": self.columns_to_ranks, |
|
"string_columns": self.string_columns, |
|
} |
|
with SharedMemoryContext() as shm_ctx: |
|
for i in range(num_workers): |
|
worker_random_id = str(uuid4()) |
|
worker_name = f"dw_{i}_{worker_random_id}"[:10] |
|
names.append(worker_name) |
|
|
|
                # Shared arrays through which each worker reports its current batch's shapes
                worker_shape_arrays = {
|
col: shm_ctx.get_array(f"{worker_name}_{col}_shape", shape=(rank,), dtype=np.int64, create=True) |
|
for col, rank in self.columns_to_ranks.items() |
|
} |
|
shape_arrays.append(worker_shape_arrays) |
|
|
|
worker_indices = per_worker_batches[i] |
|
if i == final_batch_worker and final_batch is not None: |
|
final_batch_arg = final_batch |
|
else: |
|
final_batch_arg = None |
|
worker_kwargs = { |
|
"worker_name": worker_name, |
|
"indices": worker_indices, |
|
"extra_batch": final_batch_arg, |
|
"array_ready_event": array_ready_events[i], |
|
"array_loaded_event": array_loaded_events[i], |
|
**base_args, |
|
} |
|
worker = ctx.Process(target=self.worker_loop, kwargs=worker_kwargs, daemon=True) |
|
worker.start() |
|
workers.append(worker) |
|
|
|
end_signal_received = False |
|
while not end_signal_received: |
|
for i in range(num_workers): |
|
if not array_ready_events[i].wait(timeout=60): |
|
raise TimeoutError("Data loading worker timed out!") |
|
array_ready_events[i].clear() |
|
array_shapes = shape_arrays[i] |
|
if any(np.any(shape < 0) for shape in array_shapes.values()): |
|
|
|
|
|
                        # Negative shapes are the worker's signal that it has no more data to send
                        end_signal_received = True
|
break |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
                    # Array shapes vary between batches, so fresh shared memory is opened for
                    # each batch and lives only long enough to copy the data out
                    with SharedMemoryContext() as batch_shm_ctx:
|
|
|
arrays = { |
|
col: batch_shm_ctx.get_array( |
|
f"{names[i]}_{col}", |
|
shape=shape, |
|
dtype=self.columns_to_np_types[col], |
|
create=False, |
|
) |
|
for col, shape in array_shapes.items() |
|
} |
|
|
|
|
|
                        # Copy everything out of shared memory before yielding, because the
                        # worker may unlink the block once it gets the "loaded" signal
                        arrays = {col: np.copy(arr) for col, arr in arrays.items()}
|
|
|
                        # Convert single-character arrays back into proper string arrays
                        for string_col in self.string_columns:
|
arrays[string_col] = ( |
|
arrays[string_col].view(f"U{arrays[string_col].shape[-1]}").squeeze(-1) |
|
) |
|
yield arrays |
|
                        # Tell the worker its shared memory has been read and can be released
                        array_loaded_events[i].set()
|
|
|
|
|
for worker in workers: |
|
worker.join() |
|
|
|
    def __call__(self):
        # tf.data.Dataset.from_generator() calls its generator argument, so return self here
        return self
|
|
|
@staticmethod |
|
def worker_loop( |
|
dataset, |
|
cols_to_retain, |
|
collate_fn, |
|
collate_fn_args, |
|
columns_to_np_types, |
|
columns_to_ranks, |
|
string_columns, |
|
indices, |
|
extra_batch, |
|
worker_name, |
|
array_ready_event, |
|
array_loaded_event, |
|
): |
|
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" |
|
|
|
if config.TF_AVAILABLE: |
|
import tensorflow as tf |
|
else: |
|
            raise ImportError("Called a TensorFlow-specific function but TensorFlow is not installed.")
|
|
|
        tf.config.set_visible_devices([], "GPU")  # Make sure worker processes don't try to grab GPU memory
|
|
|
def send_batch_to_parent(indices): |
|
batch = np_get_batch( |
|
indices=indices, |
|
dataset=dataset, |
|
cols_to_retain=cols_to_retain, |
|
collate_fn=collate_fn, |
|
collate_fn_args=collate_fn_args, |
|
columns_to_np_types=columns_to_np_types, |
|
return_dict=True, |
|
) |
|
|
|
|
|
out_arrays = {} |
|
            # This shared memory lives only until the parent signals that it has read the batch
            with SharedMemoryContext() as batch_shm_ctx:
|
|
|
|
|
for col, cast_dtype in columns_to_np_types.items(): |
|
|
|
array = batch[col] |
|
if col in string_columns: |
|
|
|
|
|
array = array.view("U1").reshape(array.shape + (-1,)) |
|
shape_arrays[col][:] = array.shape |
|
out_arrays[col] = batch_shm_ctx.get_array( |
|
f"{worker_name}_{col}", shape=array.shape, dtype=cast_dtype, create=True |
|
) |
|
out_arrays[col][:] = array |
|
|
|
                # Signal the parent that the batch is ready, then wait until it has been read
                array_ready_event.set()
|
array_loaded_event.wait() |
|
array_loaded_event.clear() |
|
|
|
with SharedMemoryContext() as shm_ctx: |
|
shape_arrays = { |
|
col: shm_ctx.get_array(f"{worker_name}_{col}_shape", shape=(rank,), dtype=np.int64, create=False) |
|
for col, rank in columns_to_ranks.items() |
|
} |
|
|
|
for batch in indices: |
|
send_batch_to_parent(batch) |
|
if extra_batch is not None: |
|
send_batch_to_parent(extra_batch) |
|
|
|
            # Write negative shapes: the signal to the parent that no more data is coming
            for col, array in shape_arrays.items():
|
array[:] = -1 |
|
array_ready_event.set() |
|
|
|
@staticmethod |
|
def distribute_batches(dataset, batch_size, drop_remainder, num_workers, shuffle): |
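        # A worked example (hypothetical numbers): with 11 samples, batch_size=2,
        # num_workers=3, shuffle=False and drop_remainder=False, the ten complete samples
        # form five batches [0,1] [2,3] [4,5] [6,7] [8,9]. Workers 0-2 take the first three
        # batches, workers 0 and 1 take the two leftover batches, and worker 2 takes the
        # incomplete batch [10], so reading the workers round-robin preserves the order.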
|
indices = np.arange(len(dataset)) |
|
if shuffle: |
|
np.random.shuffle(indices) |
|
num_samples = len(indices) |
|
|
|
|
|
        # Batches are laid out so that reading the workers in round-robin order reproduces
        # the exact order of `indices`
        incomplete_batch_cutoff = num_samples - (num_samples % batch_size)
|
indices, last_incomplete_batch = np.split(indices, [incomplete_batch_cutoff]) |
|
if drop_remainder or len(last_incomplete_batch) == 0: |
|
last_incomplete_batch = None |
|
|
|
indices = indices.reshape(-1, batch_size) |
|
num_batches = len(indices) |
|
final_batches_cutoff = num_batches - (num_batches % num_workers) |
|
indices, final_batches = np.split(indices, [final_batches_cutoff]) |
|
indices = indices.reshape(-1, num_workers, batch_size) |
|
|
|
per_worker_indices = np.split(indices, indices.shape[1], axis=1) |
|
per_worker_indices = [np.squeeze(worker_indices, 1) for worker_indices in per_worker_indices] |
|
|
|
        # Hand the leftover full batches to the first workers; len(final_batches) is always
        # smaller than num_workers, and may be zero
        for i in range(len(final_batches)):
|
|
|
per_worker_indices[i] = np.concatenate([per_worker_indices[i], final_batches[i].reshape(1, -1)], axis=0) |
|
|
|
        # The final incomplete batch, if any, goes to the next worker in the rotation
        if last_incomplete_batch is not None:
|
incomplete_batch_worker_idx = len(final_batches) |
|
else: |
|
incomplete_batch_worker_idx = None |
|
return per_worker_indices, last_incomplete_batch, incomplete_batch_worker_idx |
|
|
|
|
|
def multiprocess_dataset_to_tf( |
|
dataset, |
|
cols_to_retain, |
|
collate_fn, |
|
collate_fn_args, |
|
columns_to_np_types, |
|
output_signature, |
|
shuffle, |
|
batch_size, |
|
drop_remainder, |
|
num_workers, |
|
): |
|
"""Create a tf.data.Dataset from the underlying Dataset. This is a multi-process method - the single-process |
|
equivalent is dataset_to_tf. |
|
|
|
Args: |
|
dataset (`Dataset`): Dataset to wrap with tf.data.Dataset. |
|
cols_to_retain (`List[str]`): Dataset column(s) to load in the |
|
tf.data.Dataset. It is acceptable to include column names that are created by the `collate_fn` and |
|
that do not exist in the original dataset. |
|
        collate_fn (`Callable`): A function or callable object (such as a `DataCollator`) that will collate
|
lists of samples into a batch. |
|
collate_fn_args (`Dict`): A `dict` of keyword arguments to be passed to the |
|
`collate_fn`. Can be empty. |
|
columns_to_np_types (`Dict[str, np.dtype]`): A `dict` mapping column names to numpy dtypes. |
|
output_signature (`Dict[str, tf.TensorSpec]`): A `dict` mapping column names to |
|
`tf.TensorSpec` objects. |
|
        shuffle (`bool`): Shuffle the dataset order when loading. Recommended `True` for training, `False` for
|
validation/evaluation. |
|
batch_size (`int`, default `None`): Size of batches to load from the dataset. Defaults to `None`, which implies that |
|
the dataset won't be batched, but the returned dataset can be batched later with `tf_dataset.batch(batch_size)`. |
|
        drop_remainder (`bool`, default `None`): Drop the last incomplete batch when loading. If not provided,
|
defaults to the same setting as shuffle. |
|
num_workers (`int`): Number of workers to use for loading the dataset. Should be >= 1. |
|
|
|
Returns: |
|
`tf.data.Dataset` |
|
""" |
|
if config.TF_AVAILABLE: |
|
import tensorflow as tf |
|
else: |
|
        raise ImportError("Called a TensorFlow-specific function but TensorFlow is not installed.")
|
|
|
data_generator = NumpyMultiprocessingGenerator( |
|
dataset=dataset, |
|
cols_to_retain=cols_to_retain, |
|
collate_fn=collate_fn, |
|
collate_fn_args=collate_fn_args, |
|
columns_to_np_types=columns_to_np_types, |
|
output_signature=output_signature, |
|
shuffle=shuffle, |
|
batch_size=batch_size, |
|
drop_remainder=drop_remainder, |
|
num_workers=num_workers, |
|
) |
|
|
|
tf_dataset = tf.data.Dataset.from_generator(data_generator, output_signature=output_signature) |
|
    # from_generator() yields a dataset of unknown size, so assert the true number of batches
    if drop_remainder:
|
dataset_length = int(len(dataset) // batch_size) |
|
else: |
|
dataset_length = int(ceil(len(dataset) / batch_size)) |
|
return tf_dataset.apply(tf.data.experimental.assert_cardinality(dataset_length)) |
|
|