import logging
from typing import Any, Optional

import torch
from torch.fx.node import map_aggregate
from torch.utils._pytree import tree_flatten, tree_unflatten


__all__ = [
    "TensorChunkSpec",
    "split_args_kwargs_into_chunks",
    "merge_chunks",
]

logger = logging.getLogger(__name__)

""" |
|
_debug_mask_minibatches specifies to send masked versions of the mini-batch |
|
through instead of micro-batch slices--this can be used for more stable |
|
numerical testing (see [A Note About Correctness Testing]) |
|
""" |
|
_debug_mask_minibatches = False |
|
|
|
|
|
class _CustomReducer:
    """
    Custom reducer class that can be used to specify a custom operation that
    reduces losses of multiple microbatches into one value.

    Example:
        >>> # xdoctest: +SKIP
        >>> sum_reducer = _CustomReducer(
        >>>     torch.tensor(0.0),
        >>>     lambda a, b: a + b
        >>> )
    """

    def __init__(self, init_value, reduce_fn):
        self.init_value = init_value
        self.reduce_fn = reduce_fn


class _LossReducer(_CustomReducer):
    pass


sum_reducer = _LossReducer(torch.tensor(0.0), lambda a, b: a + b)
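
# Illustrative sketch (comment only, values made up): `sum_reducer` folds the
# losses of several microbatches into a single value by repeatedly applying
# its `reduce_fn`:
#
#     total = sum_reducer.init_value
#     for loss in (torch.tensor(1.0), torch.tensor(2.0)):
#         total = sum_reducer.reduce_fn(total, loss)
#     # total is now tensor(3.)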
DEFAULT_CHUNK_DIM = 0


class TensorChunkSpec:
    """
    Class used to specify chunking of inputs along a given dimension (`split_dim`).
    """

    def __init__(self, split_dim):
        self.split_dim = split_dim

    split_dim: int

    def __repr__(self):
        return (
            f"{self.__class__.__module__}.{self.__class__.__name__}({self.split_dim})"
        )

    def __str__(self):
        return f"TensorChunkSpec({self.split_dim})"

    @staticmethod
    def from_tuple(
        chunk_dims: tuple[int, ...],
    ):
        """
        A helper for creating a tuple of `TensorChunkSpec` from a tuple of chunk
        dimensions (int's).

        Example:
            >>> # xdoctest: +SKIP
            >>> # There are three positional arguments to the model, and
            >>> # we are chunking them along dimension 0, 0 and 1, respectively
            >>> args_chunk_spec = TensorChunkSpec.from_tuple((0, 0, 1))
        """
        args_chunk_spec = map_aggregate(
            chunk_dims,
            lambda dim: TensorChunkSpec(dim),
        )
        return args_chunk_spec

    @staticmethod
    def from_dict(
        chunk_dims: dict[str, int],
    ):
        """
        A helper for creating a dictionary of `TensorChunkSpec` from a
        dictionary of chunk dimensions (int's).

        Example:
            >>> # xdoctest: +SKIP
            >>> # Chunk dimension 0 for the "id" argument, 1 for the "mask" argument
            >>> kwargs_chunk_spec = TensorChunkSpec.from_dict({"id": 0, "mask": 1})
        """
        kwargs_chunk_spec = map_aggregate(
            chunk_dims,
            lambda dim: TensorChunkSpec(dim),
        )
        return kwargs_chunk_spec


# Sentinel type used in a chunk spec to mark a value that should be
# replicated to every chunk rather than split.
class _Replicate:
    pass
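
# Illustrative sketch (comment only; "ids" and "temperature" are made-up
# names): a chunk spec may mix `TensorChunkSpec` and `_Replicate`, e.g. to
# split an "ids" tensor along dim 0 while sending a "temperature" value
# unchanged to every micro-batch:
#
#     kwargs_chunk_spec = {
#         "ids": TensorChunkSpec(0),
#         "temperature": _Replicate,
#     }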
def _shard_dict_of_args(
    args_dict,
    args_chunk_spec,
    num_chunks,
):
    """
    Given a dictionary of args, and a dictionary of chunking specs, shard the
    args according to the chunking specs.

    Args:
        args_dict: Dictionary of args
        args_chunk_spec: Dictionary of chunking specs
        num_chunks: Number of chunks to shard the args into

    Returns:
        args_split: List of sharded args
    """
    args_sharded_replicated = {}
    arg_specs = []

    real_num_chunks = num_chunks
    first_tensor = True

    assert len(args_dict) == len(args_chunk_spec), (
        f"args_dict.keys() = {list(args_dict.keys())} args_chunk_spec.keys() = {list(args_chunk_spec.keys())}"
    )

    for arg_key, arg in args_dict.items():
        flat, spec = tree_flatten(arg)
        arg_specs.append(spec)

        chunk_spec = args_chunk_spec[arg_key]
        assert chunk_spec is not None
        chunk_spec_flat, _ = tree_flatten(chunk_spec)
        if len(flat) != len(chunk_spec_flat):
            raise ValueError(
                f"Argument value {arg} did not have the same number of "
                f"values as chunk spec {chunk_spec}"
            )

        sharded_arg_flat = []

        for v, chunk_v in zip(flat, chunk_spec_flat):
            if chunk_v is _Replicate or not isinstance(v, torch.Tensor):
                sharded_arg_flat.append([v] * real_num_chunks)
            elif isinstance(chunk_v, TensorChunkSpec):
                assert isinstance(v, torch.Tensor), f"{v} is not a tensor"

                v_split_dim_size = v.size(chunk_v.split_dim)
                if v_split_dim_size < real_num_chunks:
                    if first_tensor:
                        # Only the first tensor seen may lower the number of
                        # chunks; later tensors must already be consistent.
                        logger.warning(
                            f"Tensor size on chunking dimension is {v_split_dim_size}, "
                            f"downsizing the number of chunks from {num_chunks} to {v_split_dim_size}."
                        )
                        real_num_chunks = v_split_dim_size
                    else:
                        raise RuntimeError(
                            f"Arg {arg_key} on chunking dimension has a size of {v_split_dim_size}, "
                            f"smaller than the number of chunks {num_chunks}. "
                            "PiPPy cannot reduce the number of chunks because "
                            "other arguments have bigger chunk-dimension sizes. "
                            "Please adjust your num_chunks setting."
                        )

                chunk_tensors = torch.tensor_split(
                    v, real_num_chunks, chunk_v.split_dim
                )

                if _debug_mask_minibatches:
                    expanded_chunks = []

                    split_dim_idx = 0
                    for chunk_tensor in chunk_tensors:
                        new_val = torch.zeros_like(v)
                        upper_idx = split_dim_idx + chunk_tensor.size(chunk_v.split_dim)

                        slice_indices = [slice(None, None, None)] * new_val.ndim
                        slice_indices[chunk_v.split_dim] = slice(
                            split_dim_idx, upper_idx
                        )
                        new_val[slice_indices] = chunk_tensor

                        expanded_chunks.append(new_val)

                        split_dim_idx += chunk_tensor.size(chunk_v.split_dim)

                    sharded_arg_flat.append(expanded_chunks)
                else:
                    sharded_arg_flat.append(chunk_tensors)

                first_tensor = False
            else:
                raise TypeError(f"Unrecognized chunk spec: {chunk_v}")

        args_sharded_replicated[arg_key] = sharded_arg_flat

    # Transpose: go from per-arg lists of chunk values to one dict of args
    # per chunk.
    chunks_flat = []
    for chunk_idx in range(real_num_chunks):
        chunk_args = {}
        for key, arg in args_sharded_replicated.items():
            arg_single_chunk = [v_flat[chunk_idx] for v_flat in arg]
            chunk_args[key] = arg_single_chunk
        chunks_flat.append(chunk_args)

    # Unflatten each chunk's leaf values back into the original arg structure.
    args_split = []

    for chunk in chunks_flat:
        per_chunk_args = {}
        assert len(arg_specs) == len(chunk)
        for (key, arg), arg_spec in zip(chunk.items(), arg_specs):
            per_chunk_args[key] = tree_unflatten(arg, arg_spec)
        args_split.append(per_chunk_args)

    return args_split
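
# Illustrative sketch (comment only; shapes are made up) of what
# `_shard_dict_of_args` produces:
#
#     args = {"x": torch.randn(4, 8)}
#     spec = {"x": TensorChunkSpec(0)}
#     shards = _shard_dict_of_args(args, spec, num_chunks=2)
#     # shards is a list of 2 dicts, each holding one slice of "x":
#     #   shards[0]["x"].shape == (2, 8)
#     #   shards[1]["x"].shape == (2, 8)
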
def split_args_kwargs_into_chunks(
    args: tuple[Any, ...],
    kwargs: Optional[dict[str, Any]],
    chunks: int,
    args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
    kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
) -> tuple[list[tuple], list[dict]]:
    """
    Given a sequence of args and kwargs, split them into a number of chunks
    according to their respective chunking specs.

    Args:
        args: Tuple of args
        kwargs: Dict of kwargs
        chunks: Number of chunks to split the args and kwargs into
        args_chunk_spec: chunking specs for args, in same shape as args
        kwargs_chunk_spec: chunking specs for kwargs, in same shape as kwargs

    Returns:
        args_split: List of sharded args
        kwargs_split: List of sharded kwargs
    """
    # If kwargs were not provided, treat them as an empty dict so the logic
    # below still produces one (empty) kwargs dict per chunk.
    if kwargs is None:
        kwargs = {}

    # If no chunk specs were provided, default to chunking every tensor along
    # DEFAULT_CHUNK_DIM (the batch dimension).
    if args_chunk_spec is None:
        args_chunk_spec = (TensorChunkSpec(DEFAULT_CHUNK_DIM),) * len(args)

    if kwargs_chunk_spec is None:
        kwargs_chunk_spec = dict.fromkeys(kwargs, TensorChunkSpec(DEFAULT_CHUNK_DIM))

    args_split_dict = _shard_dict_of_args(
        dict(enumerate(args)),
        dict(enumerate(args_chunk_spec)),
        chunks,
    )
    real_num_chunks = len(args_split_dict)

    kwargs_split = _shard_dict_of_args(
        kwargs,
        kwargs_chunk_spec,
        real_num_chunks,
    )

    if len(kwargs_split) < real_num_chunks:
        # The kwargs could only be sharded into fewer chunks than the args
        # (e.g. a smaller chunk-dimension size); re-shard the args to match.
        real_num_chunks = len(kwargs_split)

        args_split_dict = _shard_dict_of_args(
            dict(enumerate(args)),
            dict(enumerate(args_chunk_spec)),
            real_num_chunks,
        )

    if len(args_split_dict) != len(kwargs_split):
        raise RuntimeError(
            "args and kwargs are split into different number of chunks: "
            f"{len(args_split_dict)}, {len(kwargs_split)}"
        )

    args_split = [
        tuple(chunk_args[i] for i in range(len(chunk_args)))
        for chunk_args in args_split_dict
    ]

    return args_split, kwargs_split
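
# Usage sketch (comment only; shapes and values are made up):
#
#     args = (torch.randn(8, 16),)
#     kwargs = {"mask": torch.ones(8, dtype=torch.bool)}
#     args_split, kwargs_split = split_args_kwargs_into_chunks(args, kwargs, chunks=4)
#     # len(args_split) == len(kwargs_split) == 4
#     # args_split[0][0].shape == (2, 16)
#     # kwargs_split[0]["mask"].shape == (2,)
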
def merge_chunks(
    chunks: list[Any],
    chunk_spec,
):
    """
    Given a list of chunks, merge them into a single value according to
    the chunk spec.

    Args:
        chunks: list of chunks
        chunk_spec: Chunking spec for the chunks

    Returns:
        value: Merged value
    """
    if chunk_spec is not None:
        spec_flattened, flatten_spec = tree_flatten(chunk_spec)
    else:
        # If there is no chunk spec, flatten the first chunk to get the tree
        # structure and assume the default chunking dimension for every leaf.
        chunk0_flat, flatten_spec = tree_flatten(chunks[0])
        spec_flattened = [TensorChunkSpec(DEFAULT_CHUNK_DIM)] * len(chunk0_flat)

    chunks_flattened = []

    for chunk in chunks:
        chunk_flattened, _ = tree_flatten(chunk)
        if len(chunk_flattened) != len(spec_flattened):
            raise ValueError(f"Chunk {chunk} did not match chunk spec {chunk_spec}")

        chunks_flattened.append(chunk_flattened)

    args_flattened = []
    for arg_idx, arg in enumerate(spec_flattened):
        if isinstance(arg, TensorChunkSpec):
            partial_values = [
                chunks_flattened[chunk_idx][arg_idx]
                for chunk_idx in range(len(chunks_flattened))
            ]

            if _debug_mask_minibatches:
                # In debug-mask mode each chunk carries a full-size tensor with
                # only its slice populated, so recover the per-chunk slices
                # before concatenating them.
                overall_shape = partial_values[0].shape
                for val in partial_values[1:]:
                    assert val.shape == overall_shape
                meta_chunks = torch.tensor_split(
                    torch.empty(*overall_shape, device="meta"),
                    sections=len(partial_values),
                    dim=arg.split_dim,
                )

                values_to_cat = []
                chunk_start_idx = 0
                assert len(partial_values) == len(meta_chunks)
                for partial_value, meta_chunk in zip(partial_values, meta_chunks):
                    chunk_end_idx = chunk_start_idx + meta_chunk.size(arg.split_dim)

                    slice_indices = [slice(None, None, None)] * partial_value.ndim
                    slice_indices[arg.split_dim] = slice(chunk_start_idx, chunk_end_idx)
                    sliced = partial_value[slice_indices]
                    values_to_cat.append(sliced)

                    chunk_start_idx = chunk_end_idx

            else:
                values_to_cat = partial_values

            args_flattened.append(torch.cat(values_to_cat, dim=arg.split_dim))
        elif isinstance(arg, _CustomReducer):
            reduced_val = arg.init_value

            for chunk_idx in range(len(chunks_flattened)):
                reduced_val = arg.reduce_fn(
                    reduced_val, chunks_flattened[chunk_idx][arg_idx]
                )

            args_flattened.append(reduced_val)
        else:
            value = chunks_flattened[0][arg_idx]
            for chunk_idx in range(1, len(chunks_flattened)):
                assert chunks_flattened[chunk_idx][arg_idx] == value
            args_flattened.append(value)

    return tree_unflatten(args_flattened, flatten_spec)
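
# Usage sketch (comment only; shapes are made up): merging per-microbatch
# outputs back into a full-batch tensor with the default spec, which
# concatenates along dim 0:
#
#     outs = [torch.randn(2, 16) for _ in range(4)]
#     merged = merge_chunks(outs, chunk_spec=None)
#     # merged.shape == (8, 16)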