# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
from threading import Event
from typing import Dict, List, Optional, Union

import torch

from .async_schedule import AsyncEventLoop, ModuleWrapper
from .messages import MakeTransport
from .microbatch import Batch
from .skip.layout import SkipLayout
from .skip.tracker import SkipTrackerThroughPotals


class AsyncPipeline:
    """The async pipeline parallelism for Pipe."""

    def __init__(
        self,
        partitions: List[ModuleWrapper],
        skip_layout: SkipLayout,
        checkpoint_stop: int,
        group: torch.distributed.ProcessGroup,
        *,
        worker_map: Optional[Dict[int, str]] = None,
        input_device: Union[None, int, str, torch.device] = None,
        final_stage: bool = False,
    ) -> None:
        self.partitions = partitions
        self.skip_layout = skip_layout
        self.__checkpoint_stop = checkpoint_stop
        self.group = group
        self.training: bool
        # Use the RPC transport unless running under Open MPI (which sets
        # OMPI_COMM_WORLD_RANK); setting FORCE_RPC always selects RPC.
        self.transport = MakeTransport(
            use_rpc=("OMPI_COMM_WORLD_RANK" not in os.environ) or ("FORCE_RPC" in os.environ),
            worker_map=worker_map,
            input_device=input_device,
        )
        self.input_device = input_device
        self.final_stage = final_stage

    @property
    def checkpoint_stop(self) -> int:
        # Disable checkpointing if in eval mode.
        training = self.partitions[0].module.training
        if not training:
            return 0
        return self.__checkpoint_stop

    def run(self, training: bool, batches: List[Batch], event: Optional[Event]) -> None:
        """Runs pipeline parallelism.

        It modifies the given batches in place.
        """
        self.training = training

        # One skip tracker per microbatch, so skip tensors are tracked
        # independently for each chunk flowing through the pipeline.
        skip_trackers = [SkipTrackerThroughPotals(self.skip_layout, i) for i in range(len(batches))]

        rank = self.group.rank()
        event_loop = AsyncEventLoop(
            self.partitions,
            self.group,
            self.transport,
            self.training,
            self.checkpoint_stop,
        )
        # Each rank runs the event loop matching its position in the pipeline:
        # rank 0 feeds microbatches in (head), the last rank drains results
        # (tail), and every rank in between relays activations (body).
        if rank == 0 and not self.final_stage:
            logging.debug(f"{torch.distributed.get_rank()}: entered event head")
            event_loop.event_loop_head(batches, skip_trackers, event)
            logging.debug(f"{torch.distributed.get_rank()}: exited event head")
        elif self.final_stage:
            logging.debug(f"{torch.distributed.get_rank()}: entered event tail")
            event_loop.event_loop_tail(batches, skip_trackers)
            logging.debug(f"{torch.distributed.get_rank()}: exited event tail")
        else:
            logging.debug(f"{torch.distributed.get_rank()}: entered event loop")
            event_loop.event_loop(len(batches), skip_trackers)
            logging.debug(f"{torch.distributed.get_rank()}: exited event loop")

    def back_helper(self, output: List[Batch]) -> None:
        pass
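
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module's API). It shows
# how a rank might construct and drive an AsyncPipeline, assuming the
# surrounding Pipe machinery has already produced `partitions`, `skip_layout`,
# `group`, `input_device`, and the `batches` list of microbatches -- all
# hypothetical placeholders here:
#
#     batches: List[Batch] = ...  # microbatches for this iteration
#     pipeline = AsyncPipeline(
#         partitions,
#         skip_layout,
#         checkpoint_stop=len(batches),  # checkpoint every microbatch
#         group=group,
#         worker_map={0: "worker0", 1: "worker1"},  # rank -> RPC worker name
#         input_device=input_device,
#         final_stage=(group.rank() == group.size() - 1),
#     )
#     # Blocks until this rank's role (head/body/tail) in the schedule is
#     # done; on the tail rank, `batches` holds the final-stage outputs.
#     pipeline.run(training=True, batches=batches, event=None)
# ---------------------------------------------------------------------------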