import copy
import os
import shlex
import shutil
import subprocess
import time
from datetime import datetime
from typing import Any, Dict, Optional, Tuple

from torch.distributed import Store
from torch.distributed.elastic.agent.server.api import (
    RunResult,
    WorkerGroup,
    WorkerSpec,
    WorkerState,
)
from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent
from torch.distributed.elastic.metrics import put_metric
from torch.distributed.elastic.multiprocessing import start_processes
from torch.distributed.elastic.utils import macros
from torch.distributed.elastic.utils.distributed import get_free_port
from torch.distributed.elastic.utils.logging import get_logger

log = get_logger(__name__)


class DSElasticAgent(LocalElasticAgent):
    """LocalElasticAgent variant that launches each worker with a caller-provided
    (DeepSpeed) environment merged over the standard torchelastic worker variables."""

    def __init__(
        self,
        spec: WorkerSpec,
        env: Dict,
        start_method="spawn",
        exit_barrier_timeout: float = 300,
        log_dir: Optional[str] = None,
    ):
        super().__init__(spec, start_method, exit_barrier_timeout, log_dir)
        self.ds_env = env

    @staticmethod
    def _set_master_addr_port(store: Store,
                              master_addr: Optional[str],
                              master_port: Optional[int],
                              local_addr: Optional[str] = None):
        if master_port is None:
            # get_free_port() already returns an unused local port number,
            # so it can be used directly.
            master_port = get_free_port()

        if master_addr is None:
            # Use the first IP address reported by `hostname -I` as the master address.
            safe_cmd = shlex.split("hostname -I")
            result = subprocess.check_output(safe_cmd)
            master_addr = result.decode("utf-8").split()[0]

        store.set("MASTER_ADDR", master_addr.encode(encoding="UTF-8"))
        store.set("MASTER_PORT", str(master_port).encode(encoding="UTF-8"))

    def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
        spec = worker_group.spec
        store = worker_group.store
        assert store is not None
        master_addr, master_port = super()._get_master_addr_port(store)
        restart_count = spec.max_restarts - self._remaining_restarts

        use_agent_store = spec.rdzv_handler.get_backend() == "static"

        args: Dict[int, Tuple] = {}
        envs: Dict[int, Dict[str, str]] = {}
        for worker in worker_group.workers:
            local_rank = worker.local_rank

            # Start from the DeepSpeed-provided environment and overlay the
            # standard torchelastic variables for this worker.
            worker_env_ds = copy.deepcopy(self.ds_env)
            worker_env_elastic = {
                "LOCAL_RANK": str(local_rank),
                "RANK": str(worker.global_rank),
                "GROUP_RANK": str(worker_group.group_rank),
                "ROLE_RANK": str(worker.role_rank),
                "ROLE_NAME": spec.role,
                "LOCAL_WORLD_SIZE": str(spec.local_world_size),
                "WORLD_SIZE": str(worker.world_size),
                "GROUP_WORLD_SIZE": str(worker_group.group_world_size),
                "ROLE_WORLD_SIZE": str(worker.role_world_size),
                "MASTER_ADDR": master_addr,
                "MASTER_PORT": str(master_port),
                "TORCHELASTIC_RESTART_COUNT": str(restart_count),
                "TORCHELASTIC_MAX_RESTARTS": str(spec.max_restarts),
                "TORCHELASTIC_RUN_ID": spec.rdzv_handler.get_run_id(),
                "TORCHELASTIC_USE_AGENT_STORE": str(use_agent_store),
                "NCCL_ASYNC_ERROR_HANDLING": os.getenv("NCCL_ASYNC_ERROR_HANDLING", str(1)),
            }
            worker_env_ds.update(worker_env_elastic)
            if "OMP_NUM_THREADS" in os.environ:
                worker_env_ds["OMP_NUM_THREADS"] = os.environ["OMP_NUM_THREADS"]

            envs[local_rank] = worker_env_ds
            worker_args = list(spec.args)
            worker_args = macros.substitute(worker_args, str(local_rank))
            args[local_rank] = tuple(worker_args)

        # Recreate a clean log directory for this restart attempt.
        attempt_log_dir = os.path.join(self._log_dir, f"attempt_{restart_count}")
        shutil.rmtree(attempt_log_dir, ignore_errors=True)
        os.makedirs(attempt_log_dir)

        assert spec.entrypoint is not None
        self._pcontext = start_processes(
            name=spec.role,
            entrypoint=spec.entrypoint,
            args=args,
            envs=envs,
            log_dir=attempt_log_dir,
            start_method=self._start_method,
            redirects=spec.redirects,
            tee=spec.tee,
        )

        return self._pcontext.pids()

    def _invoke_run(self, role: str = "default") -> RunResult:
        spec = self._worker_group.spec
        role = spec.role

        log.info(f"[{role}] starting workers for entrypoint: {spec.get_entrypoint_name()}")

        self._initialize_workers(self._worker_group)
        monitor_interval = spec.monitor_interval
        rdzv_handler = spec.rdzv_handler

        participants = rdzv_handler._state_holder.state.participants

        while True:
            assert self._worker_group.state != WorkerState.INIT
            time.sleep(monitor_interval)
            run_result = self._monitor_workers(self._worker_group)
            state = run_result.state
            self._worker_group.state = state

            # Nodes whose last keep-alive heartbeat is older than the allowed window.
            expire_time = datetime.utcnow() - (rdzv_handler._settings.keep_alive_interval *
                                               rdzv_handler._settings.keep_alive_max_attempt)
            _dead_nodes = [
                node for node, last_heartbeat in rdzv_handler._state_holder.state.last_heartbeats.items()
                if last_heartbeat < expire_time
            ]

            put_metric(f"workers.{role}.remaining_restarts", self._remaining_restarts)
            put_metric(f"workers.{role}.{state.name.lower()}", 1)

            if state == WorkerState.SUCCEEDED:
                log.info(f"[{role}] worker group successfully finished."
                         f" Waiting {self._exit_barrier_timeout} seconds for other agents to finish.")
                self._exit_barrier()
                return run_result
            elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED} or len(participants) > len(
                    rdzv_handler._state_holder.state.participants):
                # A shrinking participant list means a node has dropped out of the rendezvous.
                if self._remaining_restarts > 0:
                    log.info(f"[{role}] Worker group {state.name}. "
                             f"{self._remaining_restarts}/{spec.max_restarts} attempts left;"
                             f" will restart worker group")
                    self._remaining_restarts -= 1
                    self._restart_workers(self._worker_group)
                    participants = rdzv_handler._state_holder.state.participants
                else:
                    self._stop_workers(self._worker_group)
                    self._worker_group.state = WorkerState.FAILED
                    self._exit_barrier()
                    return run_result
            elif state == WorkerState.HEALTHY:
                num_nodes_waiting = rdzv_handler.num_nodes_waiting()
                group_rank = self._worker_group.group_rank
                if num_nodes_waiting > 0:
                    log.info(f"[{role}] Detected {num_nodes_waiting} "
                             f"new nodes from group_rank={group_rank}; "
                             f"will restart worker group")
                    self._restart_workers(self._worker_group)
                    participants = rdzv_handler._state_holder.state.participants
            else:
                raise Exception(f"[{role}] Worker group in {state.name} state")
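

# Usage sketch: one way a DSElasticAgent *might* be driven directly, included here
# for illustration only. The c10d rendezvous endpoint, run id, node counts, restart
# policy and the `training_fn`/`ds_env` arguments are assumptions for this example,
# not values prescribed by this module.
def _example_launch(training_fn, ds_env: Dict[str, str], nproc_per_node: int = 2) -> RunResult:
    from torch.distributed.elastic.rendezvous import RendezvousParameters
    from torch.distributed.elastic.rendezvous.registry import get_rendezvous_handler

    # Single-node c10d rendezvous on localhost (illustrative settings).
    rdzv_params = RendezvousParameters(backend="c10d",
                                       endpoint="localhost:29400",
                                       run_id="ds_elastic_example",
                                       min_nodes=1,
                                       max_nodes=1)
    spec = WorkerSpec(role="default",
                      local_world_size=nproc_per_node,
                      entrypoint=training_fn,
                      args=(),
                      rdzv_handler=get_rendezvous_handler(rdzv_params),
                      max_restarts=3,
                      monitor_interval=5)
    # ds_env is forwarded to every worker on top of the torchelastic variables
    # that _start_workers() injects.
    agent = DSElasticAgent(spec, env=ds_env)
    return agent.run()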