|
|
|
|
|
|
|
|
|
|
|
import logging |
|
import os |
|
from threading import Event |
|
from typing import Dict, List, Optional, Union |
|
|
|
import torch |
|
|
|
from .async_schedule import AsyncEventLoop, ModuleWrapper |
|
from .messages import MakeTransport |
|
from .microbatch import Batch |
|
from .skip.layout import SkipLayout |
|
from .skip.tracker import SkipTrackerThroughPotals |
|
|
|
|
|
class AsyncPipeline:
    """The async pipeline parallelism for Pipe.

    Holds the partitions owned by this process together with the transport
    used to exchange messages, and hands execution over to an
    ``AsyncEventLoop`` whenever :meth:`run` is called.
    """

    def __init__(
        self,
        partitions: List[ModuleWrapper],
        skip_layout: SkipLayout,
        checkpoint_stop: int,
        group: torch.distributed.ProcessGroup,
        *,
        worker_map: Optional[Dict[int, str]] = None,
        input_device: Union[None, int, str, torch.device] = None,
        final_stage: bool = False,
    ) -> None:
        """Store the pipeline configuration and build the transport.

        Args:
            partitions: module wrappers handled by this pipeline instance.
            skip_layout: layout handed to every skip tracker built in ``run``.
            checkpoint_stop: value reported by the ``checkpoint_stop``
                property while the first partition is in training mode.
            group: process group whose ``rank()`` selects the event loop
                entry point in ``run``.
            worker_map: forwarded unchanged to ``MakeTransport``.
            input_device: forwarded to ``MakeTransport`` and kept on the
                instance.
            final_stage: when true, ``run`` executes the tail event loop.
        """
        # RPC is selected unless an Open MPI rank variable is present, and
        # can always be forced through the FORCE_RPC environment variable.
        rpc_requested = "FORCE_RPC" in os.environ or "OMPI_COMM_WORLD_RANK" not in os.environ

        self.partitions = partitions
        self.skip_layout = skip_layout
        self.__checkpoint_stop = checkpoint_stop
        self.group = group
        # Declared here, assigned on every call to ``run``.
        self.training: bool
        self.transport = MakeTransport(
            use_rpc=rpc_requested,
            worker_map=worker_map,
            input_device=input_device,
        )
        self.input_device = input_device
        self.final_stage = final_stage

    @property
    def checkpoint_stop(self) -> int:
        """Return the configured checkpoint stop, or 0 outside training.

        The training flag is read from the module of the first partition.
        """
        return self.__checkpoint_stop if self.partitions[0].module.training else 0

    def run(self, training: bool, batches: List[Batch], event: Optional[Event]) -> None:
        """Runs pipeline parallelism.

        It modifies the given batches in place.

        The final stage runs the tail loop; otherwise group rank 0 runs the
        head loop, and every other rank runs the plain event loop, which is
        given the number of batches rather than the batches themselves.

        Args:
            training: stored on ``self.training`` and passed to the event loop.
            batches: micro-batches to process; one skip tracker is created
                per batch.
            event: passed through to the head event loop only.
        """
        self.training = training

        batch_count = len(batches)
        trackers = [SkipTrackerThroughPotals(self.skip_layout, index) for index in range(batch_count)]

        group_rank = self.group.rank()
        loop = AsyncEventLoop(
            self.partitions,
            self.group,
            self.transport,
            self.training,
            self.checkpoint_stop,
        )

        if self.final_stage:
            logging.debug(f"{torch.distributed.get_rank()}: entered event tail")
            loop.event_loop_tail(batches, trackers)
            logging.debug(f"{torch.distributed.get_rank()}: exited event tail")
            return

        if group_rank == 0:
            logging.debug(f"{torch.distributed.get_rank()}: entered event head")
            loop.event_loop_head(batches, trackers, event)
            logging.debug(f"{torch.distributed.get_rank()}: exited event head")
            return

        logging.debug(f"{torch.distributed.get_rank()}: entered event loop")
        loop.event_loop(batch_count, trackers)
        logging.debug(f"{torch.distributed.get_rank()}: exited event loop")

    def back_helper(self, output: List[Batch]) -> None:
        """Do nothing; this method currently has an empty body."""
        return None
|
|