import math
from types import MethodType
from typing import Any, Optional, Union

from .state import PartialState
from .utils import (
    calculate_maximum_sizes,
    convert_bytes,
    copy_tensor_to_devices,
    ignorant_find_batch_size,
    infer_auto_device_map,
    is_pippy_available,
    pad_input_tensors,
    send_to_device,
)


def generate_device_map(model, num_processes: int = 1, no_split_module_classes=None, max_memory: dict = None):
    """
    Calculates the device map for `model` with an offset for PiPPy.
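
    When `max_memory` is not provided, the per-device budget is the total model size split evenly across processes
    plus roughly 10% headroom. Illustrative example (hypothetical numbers): a 10 GB model on 2 processes gives
    `ceil(5.0) * 1.1 = 5.5`, i.e. a `max_memory` of `"5.5 GB"` per device.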
    """
    if num_processes == 1:
        return infer_auto_device_map(model, no_split_module_classes=no_split_module_classes, clean_result=False)
    if max_memory is None:
        model_size, shared = calculate_maximum_sizes(model)

        # Split the total size (weights plus shared/tied parameters) evenly across processes
        memory = (model_size + shared[0]) / num_processes
        memory = convert_bytes(memory)
        value, ending = memory.split(" ")

        # Round up and add ~10% headroom per device so the split is not overly tight
        memory = math.ceil(float(value)) * 1.1
        memory = f"{memory} {ending}"
        max_memory = {i: memory for i in range(num_processes)}
    device_map = infer_auto_device_map(
        model,
        max_memory=max_memory,
        no_split_module_classes=no_split_module_classes,
        clean_result=False,
    )
    return device_map


def find_pippy_batch_size(args, kwargs):
    """
    Searches `args` and then `kwargs` for the first input a batch size can be inferred from (via
    `ignorant_find_batch_size`), returning `None` if nothing suitable is found.
    """
    found_batch_size = None
    if args is not None:
        for arg in args:
            found_batch_size = ignorant_find_batch_size(arg)
            if found_batch_size is not None:
                break
    if kwargs is not None and found_batch_size is None:
        for kwarg in kwargs.values():
            found_batch_size = ignorant_find_batch_size(kwarg)
            if found_batch_size is not None:
                break
    return found_batch_size


def build_pipeline(model, split_points, args, kwargs, num_chunks):
    """
    Attaches the split points to the model based on the passed-in `split_points` and generates a `PipelineStage`.
    Requires passing in the `args` and `kwargs` the model expects, placed on the CPU beforehand.

    Users can pass in a custom `num_chunks` as an optional hyperparameter. By default it will use
    `AcceleratorState.num_processes`.
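
    Each split point is the fully qualified name of a module inside `model` (for example `"transformer.h.4"`, a
    hypothetical name); each one becomes a `SplitPoint.BEGINNING` entry in the split spec, so `n` split points yield
    `n + 1` pipeline stages.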
    """
    # Imported lazily: `torch.distributed.pipelining` requires PyTorch 2.4.0 or later
    from torch.distributed.pipelining import ScheduleGPipe, SplitPoint, pipeline

    state = PartialState()
    # Each split point marks the beginning of a new pipeline stage
    split_spec = {split_point: SplitPoint.BEGINNING for split_point in split_points}
    pipe = pipeline(
        model,
        mb_args=args,
        mb_kwargs=kwargs,
        split_spec=split_spec,
    )
    # Build only this process's stage and schedule it with GPipe
    stage = pipe.build_stage(state.local_process_index, device=state.device)
    schedule = ScheduleGPipe(stage, num_chunks)

    return schedule


def pippy_forward(forward, num_chunks, gather_output, *args, **kwargs):
    state = PartialState()
    output = None

    if state.num_processes == 1:
        output = forward(*args, **kwargs)
    elif state.is_local_main_process:
        found_batch_size = find_pippy_batch_size(args, kwargs)
        if found_batch_size is None:
            raise ValueError("Could not find batch size from args or kwargs")
        if found_batch_size != num_chunks:
            args = pad_input_tensors(args, found_batch_size, num_chunks)
            kwargs = pad_input_tensors(kwargs, found_batch_size, num_chunks)
        forward(*args, **kwargs)
    elif state.is_last_process:
        output = forward()
    else:
        forward()
    if gather_output:
        # Each process gets a copy of the output, which otherwise only lives on the last process
        output = copy_tensor_to_devices(output)
    return output


def prepare_pippy(
    model,
    split_points: Optional[Union[str, list[str]]] = "auto",
    no_split_module_classes: Optional[list[str]] = None,
    example_args: Optional[tuple[Any]] = (),
    example_kwargs: Optional[dict[str, Any]] = None,
    num_chunks: Optional[int] = None,
    gather_output: Optional[bool] = False,
):
    """
    Wraps `model` for pipeline-parallel inference.

    Args:
        model (`torch.nn.Module`):
            A model we want to split for pipeline-parallel inference.
        split_points (`str` or `List[str]`, defaults to 'auto'):
            How to generate the split points and chunk the model across each GPU. 'auto' will find the best balanced
            split given any model. Otherwise, it should be a list of layer names in the model to split by.
        no_split_module_classes (`List[str]`):
            A list of class names for layers we don't want to be split.
        example_args (tuple of model inputs):
            The expected inputs for the model that uses order-based inputs, for a *single process*. Recommended to
            use this format if possible.
        example_kwargs (dict of model inputs):
            The expected inputs for the model that uses dictionary-based inputs, for a *single process*. This is a
            *highly* limiting structure that requires the same keys be present at *all* inference calls. Not
            recommended unless the prior condition is true for all cases.
        num_chunks (`int`, defaults to the number of available GPUs):
            The number of chunks (micro-batches) each input batch is split into for pipelining, which with
            `split_points="auto"` is also the number of pipeline stages. This can be tuned; in general one should
            have `num_chunks >= num_gpus`.
        gather_output (`bool`, defaults to `False`):
            If `True`, the output from the last GPU (which holds the true outputs) is sent across to all GPUs.
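
    Example (an illustrative sketch rather than an official recipe; `MyModel` and its inputs are hypothetical, and at
    least two GPUs launched via `accelerate launch` are assumed):

    ```python
    import torch

    from accelerate import prepare_pippy

    model = MyModel()  # hypothetical model, instantiated identically on every process
    example_inputs = (torch.randint(0, 1000, (2, 128)),)  # hypothetical batch used to trace and split the model

    model = prepare_pippy(model, split_points="auto", example_args=example_inputs, gather_output=True)
    with torch.no_grad():
        output = model(*example_inputs)  # with `gather_output=True`, every process receives the final output
    ```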
    """
    if not is_pippy_available():
        raise ImportError("Using `torch.distributed.pipelining` requires PyTorch 2.4.0 or later.")
    state = PartialState()
    example_args = send_to_device(example_args, "cpu")
    example_kwargs = send_to_device(example_kwargs, "cpu")
    if num_chunks is None:
        num_chunks = state.num_processes
    if split_points == "auto":
        device_map = generate_device_map(model, num_chunks, no_split_module_classes=no_split_module_classes)
        split_points = []
        for i in range(1, num_chunks):
            # The first module assigned to device `i` marks the start of a new pipeline stage
            split_points.append(next(k for k, v in device_map.items() if v == i))
    stage = build_pipeline(model, split_points, example_args, example_kwargs, num_chunks)
    model._original_forward = model.forward
    model._original_call = model.__call__
    model.pippy_stage = stage
    model.hf_split_points = split_points

    def forward(*args, **kwargs):
        return pippy_forward(stage.step, num_chunks, gather_output, *args, **kwargs)

    # Mimic a decorated method: bind the wrapper to the model and expose the bound version via `__wrapped__` so that
    # utilities which unwrap `model.forward` can find it
    model_forward = MethodType(forward, model)
    forward.__wrapped__ = model_forward
    model.forward = forward
    return model