File size: 16,188 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 |
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
from typing import Any, Optional
import torch
from torch.fx.node import map_aggregate
from torch.utils._pytree import tree_flatten, tree_unflatten
# Names exported by `from ... import *`; everything else in this module is
# internal.
__all__ = [
    "TensorChunkSpec",
    "split_args_kwargs_into_chunks",
    "merge_chunks",
]
# Module logger (used, e.g., to warn when the number of chunks is reduced
# while sharding arguments).
logger = logging.getLogger(__name__)
# NOTE: the string literal below is a no-op expression statement that serves
# as documentation for `_debug_mask_minibatches`.
"""
_debug_mask_minibatches specifies to send masked versions of the mini-batch
through instead of micro-batch slices--this can be used for more stable
numerical testing (see [A Note About Correctness Testing])
"""
# Debug switch, off by default. When True, argument sharding emits zero-filled,
# full-size copies of each split tensor with only that micro-batch's region
# populated, and `merge_chunks` slices those regions back out before
# concatenating.
_debug_mask_minibatches = False
class _CustomReducer:
    """
    Describes how per-microbatch values (e.g. losses) are folded into one value.

    ``init_value`` is the starting accumulator, and ``reduce_fn(acc, value)``
    must return the updated accumulator after combining in one microbatch's
    value.

    Example:
        >>> # xdoctest: +SKIP
        >>> sum_reducer = _CustomReducer(
        ...     torch.tensor(0.0),
        ...     lambda a, b: a + b,
        ... )
    """

    def __init__(self, init_value, reduce_fn):
        # Kept as plain public attributes; the merging code reads them directly.
        self.init_value, self.reduce_fn = init_value, reduce_fn
class _LossReducer(_CustomReducer):
    """A ``_CustomReducer`` subclass tagging reducers meant for losses; it adds no behavior."""


# Reducer that sums the per-microbatch values, starting from a 0.0 tensor.
sum_reducer = _LossReducer(
    init_value=torch.tensor(0.0),
    reduce_fn=lambda acc, value: acc + value,
)
# Default chunking dimension is 0. It is used whenever the caller did not
# specify a chunking dimension: missing args/kwargs chunk specs default to
# `TensorChunkSpec(DEFAULT_CHUNK_DIM)`, and `merge_chunks` with no spec
# concatenates every output leaf along this dimension.
DEFAULT_CHUNK_DIM: int = 0
class TensorChunkSpec:
    """
    Chunking spec for an input: split the value along ``split_dim``.

    Tensors with this spec are sliced with ``torch.tensor_split`` along
    ``split_dim`` when sharding and re-joined with ``torch.cat`` along the same
    dimension when merging.
    """

    split_dim: int

    def __init__(self, split_dim):
        self.split_dim = split_dim

    def __repr__(self):
        cls = type(self)
        return f"{cls.__module__}.{cls.__name__}({self.split_dim})"

    def __str__(self):
        return f"TensorChunkSpec({self.split_dim})"

    @staticmethod
    def from_tuple(
        chunk_dims: tuple[int, ...],
    ):
        """
        Build a tuple of `TensorChunkSpec` from a tuple of chunk dimensions.

        Nested containers are mapped element-wise by ``map_aggregate``.

        Example:
            >>> # xdoctest: +SKIP
            >>> # Three positional model arguments, chunked along dimension
            >>> # 0, 0 and 1 respectively
            >>> args_chunk_spec = TensorChunkSpec.from_tuple((0, 0, 1))
        """
        # The class itself is the per-leaf mapping function: dim -> TensorChunkSpec(dim).
        return map_aggregate(chunk_dims, TensorChunkSpec)  # type: ignore[arg-type,return-value]

    @staticmethod
    def from_dict(
        chunk_dims: dict[str, int],
    ):
        """
        Build a dictionary of `TensorChunkSpec` from a dictionary of chunk
        dimensions keyed by argument name.

        Example:
            >>> # xdoctest: +SKIP
            >>> # Chunk dimension 0 for the "id" argument, 1 for the "mask" argument
            >>> kwargs_chunk_spec = TensorChunkSpec.from_dict({"id": 0, "mask": 1})
        """
        return map_aggregate(chunk_dims, TensorChunkSpec)  # type: ignore[arg-type,return-value]
# Chunk spec marker requesting that an input be replicated, not split.
class _Replicate:
    """
    Sentinel chunk spec meaning "copy this value into every chunk".

    The class object itself is the marker: sharding code compares a spec leaf
    with ``is _Replicate`` rather than checking for an instance.
    """
def _debug_mask_chunks(value, chunk_tensors, split_dim):
    """
    Expand slices of ``value`` back into full-size, zero-filled tensors.

    ``chunk_tensors`` are consecutive slices of ``value`` along ``split_dim``
    (as produced by ``torch.tensor_split``). For each slice, return a tensor
    shaped like ``value`` that is zero everywhere except the region the slice
    was taken from. Used only when ``_debug_mask_minibatches`` is enabled.
    """
    expanded_chunks = []
    start_idx = 0
    for chunk_tensor in chunk_tensors:
        end_idx = start_idx + chunk_tensor.size(split_dim)
        new_val = torch.zeros_like(value)
        slice_indices = [slice(None, None, None)] * new_val.ndim
        slice_indices[split_dim] = slice(start_idx, end_idx)
        # A tuple is the canonical multi-dimensional index; a *list* of slices
        # is only accepted through legacy sequence-as-tuple handling.
        new_val[tuple(slice_indices)] = chunk_tensor
        expanded_chunks.append(new_val)
        start_idx = end_idx
    return expanded_chunks


def _shard_dict_of_args(
    args_dict,
    args_chunk_spec,
    num_chunks,
):
    """
    Shard a dictionary of args into per-chunk dictionaries.

    Each value in ``args_dict`` is pytree-flattened together with its entry in
    ``args_chunk_spec``. Every flat leaf is then either:

    * sliced with ``torch.tensor_split`` along ``split_dim``, if it is a tensor
      whose spec is a ``TensorChunkSpec``; or
    * repeated into every chunk, if its spec is ``_Replicate`` or the leaf is
      not a tensor.

    The per-chunk leaves are finally unflattened back into each arg's original
    structure.

    If the first split tensor is shorter along its split dimension than the
    current number of chunks, the number of chunks is reduced to that length
    (with a warning). A later tensor that is too short raises instead. Callers
    observe the actual number of chunks as the length of the returned list.

    Args:
        args_dict: Dictionary of args; its keys are reproduced in every chunk.
        args_chunk_spec: Dictionary of chunking specs, one entry per key of
            ``args_dict``. Each entry must flatten to as many leaves as the
            corresponding arg.
        num_chunks: Requested number of chunks to shard the args into.

    Returns:
        args_split: List with one dict per chunk actually produced. When
        ``args_dict`` is empty this is ``num_chunks`` empty dicts.

    Raises:
        ValueError: An arg and its chunk spec flatten to different leaf counts.
        RuntimeError: A tensor other than the first split tensor is too short
            to be split into the current number of chunks.
        TypeError: A tensor leaf has a spec that is neither ``_Replicate`` nor
            a ``TensorChunkSpec``.
    """
    # Stage 1+2: flatten and shard/replicate
    # args_sharded_replicated : [num args, num flat values, num chunks]
    args_sharded_replicated = {}
    arg_specs = []
    real_num_chunks = num_chunks
    # Only the first tensor that is actually split may shrink the chunk count;
    # by the time later tensors are seen, earlier ones are already split.
    first_tensor = True

    assert len(args_dict) == len(args_chunk_spec), (
        f"args_dict.keys() = {list(args_dict.keys())} args_chunk_spec.keys() = {list(args_chunk_spec.keys())}"
    )

    for arg_key, arg in args_dict.items():
        flat, spec = tree_flatten(arg)
        arg_specs.append(spec)

        chunk_spec = args_chunk_spec[arg_key]
        assert chunk_spec is not None  # Should have been set by caller
        chunk_spec_flat, _ = tree_flatten(chunk_spec)
        if len(flat) != len(chunk_spec_flat):
            raise ValueError(
                f"Argument value {arg} did not have the same number of "
                f"values as chunk spec {chunk_spec}"
            )

        sharded_arg_flat = []
        for v, chunk_v in zip(flat, chunk_spec_flat):
            if chunk_v is _Replicate or not isinstance(v, torch.Tensor):
                # Over-allocating (if the chunk count later shrinks) is harmless:
                # only the first `real_num_chunks` entries are read below.
                sharded_arg_flat.append([v] * real_num_chunks)
            elif isinstance(chunk_v, TensorChunkSpec):
                v_split_dim_size = v.size(chunk_v.split_dim)
                if v_split_dim_size < real_num_chunks:
                    if first_tensor:
                        logger.warning(
                            "Tensor size on chunking dimension is %s, "
                            "downsizing the number of chunks from %s to %s.",
                            v_split_dim_size,
                            num_chunks,
                            v_split_dim_size,
                        )
                        real_num_chunks = v_split_dim_size
                    else:
                        raise RuntimeError(
                            f"Arg {arg_key} on chunking dimension has a size of {v_split_dim_size}, "
                            f"smaller than the number of chunks {real_num_chunks}. "
                            "PiPPy cannot reduce the number of chunks because "
                            "other arguments have bigger chunk-dimension sizes. "
                            "Please adjust your num_chunks setting."
                        )

                chunk_tensors = torch.tensor_split(
                    v, real_num_chunks, chunk_v.split_dim
                )
                if _debug_mask_minibatches:
                    sharded_arg_flat.append(
                        _debug_mask_chunks(v, chunk_tensors, chunk_v.split_dim)
                    )
                else:
                    sharded_arg_flat.append(chunk_tensors)  # type: ignore[arg-type]
                first_tensor = False
            else:
                raise TypeError(f"Unrecognized chunk spec: {chunk_v}")

        args_sharded_replicated[arg_key] = sharded_arg_flat

    # Stage 3+4: rotate so chunks are the outer dimension, then unflatten each
    # arg of each chunk with its saved pytree spec.
    # args_split : [num chunks, num args]
    args_split = []
    for chunk_idx in range(real_num_chunks):
        per_chunk_args = {}
        for (key, sharded_flat), arg_spec in zip(
            args_sharded_replicated.items(), arg_specs
        ):
            arg_single_chunk = [v_flat[chunk_idx] for v_flat in sharded_flat]
            per_chunk_args[key] = tree_unflatten(arg_single_chunk, arg_spec)
        args_split.append(per_chunk_args)
    return args_split
def split_args_kwargs_into_chunks(
    args: tuple[Any, ...],
    kwargs: Optional[dict[str, Any]],
    chunks: int,
    args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
    kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
) -> tuple[list[tuple], list[dict]]:
    """
    Split positional and keyword arguments into micro-batches.

    Args:
        args: Tuple of positional args.
        kwargs: Dict of keyword args, or ``None`` (treated as empty).
        chunks: Requested number of chunks. The actual number may be smaller
            if a tensor is too short along its chunking dimension.
        args_chunk_spec: chunking specs for args, in the same shape as
            ``args``. Defaults to ``TensorChunkSpec(DEFAULT_CHUNK_DIM)`` for
            every positional arg.
        kwargs_chunk_spec: chunking specs for kwargs, in the same shape as
            ``kwargs``. Defaults to ``TensorChunkSpec(DEFAULT_CHUNK_DIM)`` for
            every keyword arg.

    Returns:
        args_split: List of per-chunk positional-arg tuples.
        kwargs_split: List of per-chunk keyword-arg dicts, the same length as
            ``args_split``.

    Raises:
        RuntimeError: args and kwargs end up with different chunk counts.
    """
    # Each arg is sharded or replicated leaf-by-leaf according to its spec.
    # Writing S for a TensorChunkSpec and R for a replicated leaf:
    #
    # 1. Flatten every arg and its spec into a 1-D list of leaves:
    #      args = ([A, [B, C]], D)   spec = ([R, [R, S]], R)
    #   -> args = ([A, B, C], D)     spec = ([R, R, S], R)
    #    (kwargs are handled the same way.)
    # 2. Shard or replicate each leaf; with chunks = 2:
    #      ([[A, A], [B, B], [C_1, C_2]], [D, D])
    # 3. Rotate the nesting so chunks are the outer dimension:
    #      [([A, B, C_1], D), ([A, B, C_2], D)]
    # 4. Unflatten each chunk with the original structure:
    #      [([A, [B, C_1]], D), ([A, [B, C_2]], D)]
    # TODO: _debug_mask_minibatches
    kwargs = {} if kwargs is None else kwargs

    # Missing specs mean "split every argument along the default dimension".
    if args_chunk_spec is None:
        args_chunk_spec = (TensorChunkSpec(DEFAULT_CHUNK_DIM),) * len(args)
    if kwargs_chunk_spec is None:
        kwargs_chunk_spec = dict.fromkeys(kwargs, TensorChunkSpec(DEFAULT_CHUNK_DIM))

    # Positional args are sharded as a dict keyed by position.
    indexed_args = dict(enumerate(args))
    indexed_arg_specs = dict(enumerate(args_chunk_spec))

    args_split_dict = _shard_dict_of_args(indexed_args, indexed_arg_specs, chunks)
    num_real_chunks = len(args_split_dict)
    kwargs_split = _shard_dict_of_args(kwargs, kwargs_chunk_spec, num_real_chunks)

    if len(kwargs_split) < num_real_chunks:
        # kwargs could only be sharded into fewer chunks (e.g. args hold no
        # tensors, just values), so re-shard args to match.
        num_real_chunks = len(kwargs_split)
        args_split_dict = _shard_dict_of_args(
            indexed_args, indexed_arg_specs, num_real_chunks
        )

    if len(args_split_dict) != len(kwargs_split):
        raise RuntimeError(
            "args and kwargs are split into different number of chunks: "
            f"{len(args_split_dict)}, {len(kwargs_split)}"
        )

    # Turn each {position: value} dict back into a positional tuple.
    args_split = []
    for per_chunk in args_split_dict:
        args_split.append(tuple(per_chunk[pos] for pos in range(len(per_chunk))))
    return args_split, kwargs_split
def _debug_unmask_chunks(partial_values, split_dim):
    """
    Recover the real slices from zero-padded, full-size debug-mask chunks.

    Every entry of ``partial_values`` must have the same shape. The slice sizes
    are re-derived by running ``torch.tensor_split`` on a meta tensor of that
    shape. Each chunk's own region along ``split_dim`` is then cut out, in
    order, ready for concatenation.
    """
    overall_shape = partial_values[0].shape
    for val in partial_values[1:]:
        assert val.shape == overall_shape
    # Meta tensors carry shape only, so this costs no real allocation.
    meta_chunks = torch.tensor_split(
        torch.empty(overall_shape, device="meta"),
        sections=len(partial_values),
        dim=split_dim,
    )
    assert len(partial_values) == len(meta_chunks)

    values_to_cat = []
    chunk_start_idx = 0
    for partial_value, meta_chunk in zip(partial_values, meta_chunks):
        chunk_end_idx = chunk_start_idx + meta_chunk.size(split_dim)
        slice_indices = [slice(None, None, None)] * partial_value.ndim
        slice_indices[split_dim] = slice(chunk_start_idx, chunk_end_idx)
        # Index with a tuple of slices (the canonical multi-dim index form).
        values_to_cat.append(partial_value[tuple(slice_indices)])
        chunk_start_idx = chunk_end_idx
    return values_to_cat


def _replicated_values_equal(a, b):
    """
    Return whether two per-chunk copies of a replicated leaf are equal.

    Two tensors are compared with ``torch.equal``. Using ``==`` would yield an
    element-wise tensor whose truth value is ambiguous for more than one
    element. Any other pair is compared with ``==``, as before.
    """
    if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
        return torch.equal(a, b)
    return a == b


def merge_chunks(
    chunks: list[Any],
    chunk_spec,
):
    """
    Merge a list of per-chunk values into a single value according to
    ``chunk_spec``.

    Each chunk is pytree-flattened and matched leaf-by-leaf against the
    flattened spec:

    * ``TensorChunkSpec`` leaves are concatenated with ``torch.cat`` along
      ``split_dim``. In debug-mask mode, each chunk's own region is sliced out
      first.
    * ``_CustomReducer`` leaves are folded in chunk order, starting from
      ``init_value`` and applying ``reduce_fn(acc, value)``.
    * Any other leaf is asserted to be equal across chunks, and the first
      chunk's value is used.

    Args:
        chunks: list of chunks; each must flatten to as many leaves as the spec.
        chunk_spec: Chunking spec for the chunks, or ``None`` to concatenate
            every leaf along ``DEFAULT_CHUNK_DIM``.

    Returns:
        value: Merged value, rebuilt with the pytree structure of
        ``chunk_spec`` (or of ``chunks[0]`` when ``chunk_spec`` is ``None``).

    Raises:
        ValueError: A chunk does not flatten to the same number of leaves as
            the spec.
    """
    # This is essentially the inverse of `split_args_kwargs_into_chunks`.
    # Writing S for a TensorChunkSpec and R for any other spec leaf (e.g. a
    # replicated value), given
    #
    #   chunks = [
    #       ([A, [B, C_1]], D),
    #       ([A, [B, C_2]], D),
    #   ]
    #   chunk_spec = ([R, [R, S]], R)
    #
    # 1. Flatten each chunk (outer structure shown only for readability):
    #      [([A, B, C_1], D), ([A, B, C_2], D)]
    # 2. Rotate the nesting so chunks are the inner dimension:
    #      ([A, B, [C_1, C_2]], D)
    # 3. Concatenate sharded leaves; replicated leaves are checked and kept once:
    #      ([A, B, C], D)
    # 4. Unflatten the combined leaves with the spec's structure:
    #      ([A, [B, C]], D)

    # Preliminary: flatten the chunk spec
    if chunk_spec is not None:
        spec_flattened, flatten_spec = tree_flatten(chunk_spec)
    else:
        # No spec: merge every output leaf along the default dimension, taking
        # the output structure from chunk 0.
        chunk0_flat, flatten_spec = tree_flatten(chunks[0])
        spec_flattened = [TensorChunkSpec(DEFAULT_CHUNK_DIM)] * len(chunk0_flat)

    # Stage 1: flatten chunks
    # chunks_flattened : [num chunks, num args]
    chunks_flattened = []
    for chunk in chunks:
        chunk_flattened, _ = tree_flatten(chunk)
        if len(chunk_flattened) != len(spec_flattened):
            raise ValueError(f"Chunk {chunk} did not match chunk spec {chunk_spec}")
        chunks_flattened.append(chunk_flattened)

    # Stages 2+3: gather each leaf position across chunks and combine it.
    # args_flattened : [num args]
    args_flattened = []
    for arg_idx, arg in enumerate(spec_flattened):
        partial_values = [chunk_flat[arg_idx] for chunk_flat in chunks_flattened]
        if isinstance(arg, TensorChunkSpec):
            if _debug_mask_minibatches:
                values_to_cat = _debug_unmask_chunks(partial_values, arg.split_dim)
            else:
                values_to_cat = partial_values
            args_flattened.append(torch.cat(values_to_cat, dim=arg.split_dim))
        elif isinstance(arg, _CustomReducer):
            reduced_val = arg.init_value
            for partial_value in partial_values:
                reduced_val = arg.reduce_fn(reduced_val, partial_value)
            args_flattened.append(reduced_val)
        else:
            value = partial_values[0]
            for other in partial_values[1:]:
                assert _replicated_values_equal(other, value)
            args_flattened.append(value)

    # Stage 4: Unflatten combined args
    return tree_unflatten(args_flattened, flatten_spec)
|