#!/usr/bin/env python3 # mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import sys import uuid from dataclasses import dataclass, field from typing import Any, Callable, Optional, Union import torch.distributed.elastic.rendezvous.registry as rdzv_registry from torch.distributed.elastic import events, metrics from torch.distributed.elastic.agent.server.api import WorkerSpec from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent from torch.distributed.elastic.multiprocessing import ( DefaultLogsSpecs, LogsSpecs, SignalException, ) from torch.distributed.elastic.multiprocessing.errors import ChildFailedError from torch.distributed.elastic.rendezvous import RendezvousParameters from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint from torch.distributed.elastic.utils.logging import get_logger __all__ = ["LaunchConfig", "elastic_launch", "launch_agent"] logger = get_logger(__name__) @dataclass class LaunchConfig: """ Creates a rendezvous config. Args: min_nodes: Minimum amount of nodes that the user function will be launched on. Elastic agent ensures that the user function start only when the min_nodes amount enters the rendezvous. max_nodes: Maximum amount of nodes that the user function will be launched on. nproc_per_node: On each node the elastic agent will launch this amount of workers that will execute user defined function. rdzv_backend: rdzv_backend to use in the rendezvous (zeus-adapter, etcd). rdzv_endpoint: The endpoint of the rdzv sync. storage. rdzv_configs: Key, value pair that specifies rendezvous specific configuration. rdzv_timeout: Legacy argument that specifies timeout for the rendezvous. It is going to be removed in future versions, see the note below. The default timeout is 900 seconds. run_id: The unique run id of the job (if not passed a unique one will be deduced from run environment - flow workflow id in flow - or auto generated). role: User defined role of the worker (defaults to "trainer"). max_restarts: The maximum amount of restarts that elastic agent will conduct on workers before failure. monitor_interval: The interval in seconds that is used by the elastic_agent as a period of monitoring workers. start_method: The method is used by the elastic agent to start the workers (spawn, fork, forkserver). metrics_cfg: configuration to initialize metrics. local_addr: address of the local node if any. If not set, a lookup on the local machine's FQDN will be performed. local_ranks_filter: ranks for which to show logs in console. If not set, show from all. .. note:: `rdzv_timeout` is a legacy argument that will be removed in future. Set the timeout via `rdzv_configs['timeout']` """ min_nodes: int max_nodes: int nproc_per_node: int logs_specs: Optional[LogsSpecs] = None run_id: str = "" role: str = "default_role" rdzv_endpoint: str = "" rdzv_backend: str = "etcd" rdzv_configs: dict[str, Any] = field(default_factory=dict) rdzv_timeout: int = -1 max_restarts: int = 3 monitor_interval: float = 0.1 start_method: str = "spawn" log_line_prefix_template: Optional[str] = None metrics_cfg: dict[str, str] = field(default_factory=dict) local_addr: Optional[str] = None def __post_init__(self): default_timeout = 900 if self.rdzv_timeout != -1: self.rdzv_configs["timeout"] = self.rdzv_timeout elif "timeout" not in self.rdzv_configs: self.rdzv_configs["timeout"] = default_timeout # Post-processing to enable refactoring to introduce logs_specs due to non-torchrun API usage if self.logs_specs is None: self.logs_specs = DefaultLogsSpecs() class elastic_launch: """ Launches an torchelastic agent on the container that invoked the entrypoint. 1. Pass the ``entrypoint`` arguments as non ``kwargs`` (e.g. no named parameters)/ ``entrypoint`` can be a function or a command. 2. The return value is a map of each worker's output mapped by their respective global rank. Usage :: def worker_fn(foo): # ... def main(): # entrypoint is a function. outputs = elastic_launch(LaunchConfig, worker_fn)(foo) # return rank 0's output return outputs[0] # entrypoint is a command and ``script.py`` is the python module. outputs = elastic_launch(LaunchConfig, "script.py")(args) outputs = elastic_launch(LaunchConfig, "python")("script.py") """ def __init__( self, config: LaunchConfig, entrypoint: Union[Callable, str, None], ): self._config = config self._entrypoint = entrypoint def __call__(self, *args): return launch_agent(self._config, self._entrypoint, list(args)) def _get_entrypoint_name( entrypoint: Union[Callable, str, None], args: list[Any] ) -> str: """Retrieve entrypoint name with the rule: 1. If entrypoint is a function, use ``entrypoint.__qualname__``. 2. If entrypoint is a string, check its value: 2.1 if entrypoint equals to ``sys.executable`` (like "python"), use the first element from ``args`` which does not start with hifen letter (for example, "-u" will be skipped). 2.2 otherwise, use ``entrypoint`` value. 3. Otherwise, return empty string. """ if isinstance(entrypoint, Callable): # type: ignore[arg-type] return entrypoint.__name__ # type: ignore[union-attr] elif isinstance(entrypoint, str): if entrypoint == sys.executable: return next((arg for arg in args if arg[0] != "-"), "") else: return entrypoint else: return "" def _get_addr_and_port( rdzv_parameters: RendezvousParameters, ) -> tuple[Optional[str], Optional[int]]: if rdzv_parameters.backend != "static": return (None, None) endpoint = rdzv_parameters.endpoint endpoint = endpoint.strip() if not endpoint: raise ValueError( "Endpoint is missing in endpoint. Try to add --master-addr and --master-port" ) master_addr, master_port = parse_rendezvous_endpoint(endpoint, default_port=-1) if master_port == -1: raise ValueError( f"port is missing in endpoint: {endpoint}. Try to specify --master-port" ) return (master_addr, master_port) def launch_agent( config: LaunchConfig, entrypoint: Union[Callable, str, None], args: list[Any], ) -> dict[int, Any]: if not config.run_id: run_id = str(uuid.uuid4().int) logger.warning("config has no run_id, generated a random run_id: %s", run_id) config.run_id = run_id entrypoint_name = _get_entrypoint_name(entrypoint, args) logger.info( "Starting elastic_operator with launch configs:\n" " entrypoint : %(entrypoint)s\n" " min_nodes : %(min_nodes)s\n" " max_nodes : %(max_nodes)s\n" " nproc_per_node : %(nproc_per_node)s\n" " run_id : %(run_id)s\n" " rdzv_backend : %(rdzv_backend)s\n" " rdzv_endpoint : %(rdzv_endpoint)s\n" " rdzv_configs : %(rdzv_configs)s\n" " max_restarts : %(max_restarts)s\n" " monitor_interval : %(monitor_interval)s\n" " log_dir : %(log_dir)s\n" " metrics_cfg : %(metrics_cfg)s\n", { "entrypoint": entrypoint_name, "min_nodes": config.min_nodes, "max_nodes": config.max_nodes, "nproc_per_node": config.nproc_per_node, "run_id": config.run_id, "rdzv_backend": config.rdzv_backend, "rdzv_endpoint": config.rdzv_endpoint, "rdzv_configs": config.rdzv_configs, "max_restarts": config.max_restarts, "monitor_interval": config.monitor_interval, "log_dir": config.logs_specs.root_log_dir, # type: ignore[union-attr] "metrics_cfg": config.metrics_cfg, }, ) rdzv_parameters = RendezvousParameters( backend=config.rdzv_backend, endpoint=config.rdzv_endpoint, run_id=config.run_id, min_nodes=config.min_nodes, max_nodes=config.max_nodes, local_addr=config.local_addr, **config.rdzv_configs, ) master_addr, master_port = _get_addr_and_port(rdzv_parameters) spec = WorkerSpec( role=config.role, local_world_size=config.nproc_per_node, entrypoint=entrypoint, args=tuple(args), rdzv_handler=rdzv_registry.get_rendezvous_handler(rdzv_parameters), max_restarts=config.max_restarts, monitor_interval=config.monitor_interval, master_addr=master_addr, master_port=master_port, local_addr=config.local_addr, ) agent = LocalElasticAgent( spec=spec, logs_specs=config.logs_specs, # type: ignore[arg-type] start_method=config.start_method, log_line_prefix_template=config.log_line_prefix_template, ) shutdown_rdzv = True try: metrics.initialize_metrics(metrics.MetricsConfig(config.metrics_cfg)) result = agent.run() # records that agent.run() has succeeded NOT that workers have succeeded events.record(agent.get_event_succeeded()) if result.is_failed(): # ChildFailedError is treated specially by @record # if the error files for the failed children exist # @record will copy the first error (root cause) # to the error file of the launcher process. raise ChildFailedError( name=entrypoint_name, failures=result.failures, ) return result.return_values except ChildFailedError: raise except SignalException: # when the agent dies with a signal do NOT shutdown the rdzv_handler # since this closes the rendezvous on this rdzv_id permanently and # prevents any additional scaling events shutdown_rdzv = False events.record(agent.get_event_failed()) raise except Exception: events.record(agent.get_event_failed()) raise finally: if shutdown_rdzv: spec.rdzv_handler.shutdown()