# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
""" |
|
DeepSpeed runner is the main front-end for launching multi-worker
training jobs with DeepSpeed. By default it uses pdsh to ssh into
multiple worker nodes in parallel and launch all the necessary
processes per rank for training.
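
A typical invocation (illustrative; train.py is a placeholder for the
user's training script, and the flags are defined in parse_args below):

    deepspeed --hostfile=/job/hostfile --num_gpus=4 train.py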
|
""" |
|
|
|
import os |
|
import re |
|
import sys |
|
import json |
|
import base64 |
|
import argparse |
|
import subprocess |
|
import collections |
|
from copy import deepcopy |
|
import signal |
|
import time |
|
from typing import Tuple, List, Dict |
|
from collections import defaultdict |
|
import shlex |
|
|
|
from .multinode_runner import PDSHRunner, OpenMPIRunner, MVAPICHRunner, SlurmRunner, MPICHRunner, IMPIRunner |
|
from .constants import PDSH_LAUNCHER, OPENMPI_LAUNCHER, MVAPICH_LAUNCHER, SLURM_LAUNCHER, MPICH_LAUNCHER, IMPI_LAUNCHER |
|
from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT |
|
from ..nebula.constants import NEBULA_EXPORT_ENVS |
|
from ..utils import logger |
|
|
|
from ..autotuning import Autotuner |
|
from deepspeed.accelerator import get_accelerator |
|
|
|
DLTS_HOSTFILE = "/job/hostfile" |
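# Prefixes of environment variable names that are exported to the worker nodes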
|
EXPORT_ENVS = ['MLFLOW', 'PYTHON', 'MV2', 'UCX'] |
|
EXPORT_ENVS += NEBULA_EXPORT_ENVS |
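# Optional user env file (name overridable via DS_ENV_FILE) whose KEY=VALUE
# lines are exported to all ranks; searched for in $HOME and the cwd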
|
DEEPSPEED_ENVIRONMENT_NAME = os.getenv("DS_ENV_FILE", ".deepspeed_env") |
|
DEEPSPEED_ENVIRONMENT_PATHS = [os.path.expanduser("~"), '.'] |
|
PDSH_MAX_FAN_OUT = 1024

# If a key in this dict is set in the launcher's environment, the variables in
# its list are not exported to the workers (each node sets them independently).
|
EXCLUDE_ENVS = {'AISC_JOB_NAME': ['NCCL_IB_HCA', 'UCX_NET_DEVICES']} |
|
|
|
|
|
def parse_args(args=None): |
|
parser = argparse.ArgumentParser(description="DeepSpeed runner to help launch distributed " |
|
"multi-node/multi-gpu training jobs.", |
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter) |
|
|
|
parser.add_argument("-H", |
|
"--hostfile", |
|
type=str, |
|
default=DLTS_HOSTFILE, |
|
help="Hostfile path (in MPI style) that defines the " |
|
"resource pool available to the job (e.g., " |
|
"worker-0 slots=4)") |
|
|
|
parser.add_argument("-i", |
|
"--include", |
|
type=str, |
|
default="", |
|
help='''Specify hardware resources to use during execution. |
|
String format is |
|
NODE_SPEC[@NODE_SPEC ...], |
|
where |
|
NODE_SPEC=NAME[:SLOT[,SLOT ...]]. |
|
If :SLOT is omitted, include all slots on that host. |
|
Example: -i "worker-0@worker-1:0,2" will use all slots |
|
on worker-0 and slots [0, 2] on worker-1. |
|
''') |
|
|
|
parser.add_argument("-e", |
|
"--exclude", |
|
type=str, |
|
default="", |
|
help='''Specify hardware resources to NOT use during execution. |
|
Mutually exclusive with --include. Resource formatting |
|
is the same as --include. |
|
Example: -e "worker-1:0" will use all available |
|
resources except slot 0 on worker-1. |
|
''') |
|
|
|
parser.add_argument("--num_nodes", |
|
type=int, |
|
default=-1, |
|
                        help="Total number of worker nodes to run on; this will use "
                        "the top N hosts from the given hostfile.")
|
|
|
parser.add_argument("--min_elastic_nodes", |
|
type=int, |
|
default=-1, |
|
help="Minimum number of nodes to run elastic training on. " |
|
"Default is 1 when elastic training is enabled") |
|
|
|
parser.add_argument("--max_elastic_nodes", |
|
type=int, |
|
default=-1, |
|
help="Maximum number of nodes to run elastic training on. " |
|
"Default is num_nodes when elastic training is enabled") |
|
|
|
parser.add_argument("--num_gpus", |
|
"--num_accelerators", |
|
type=int, |
|
default=-1, |
|
                        help="Max number of GPUs to use on each node; will use "
                        "GPU ids [0:N) on each node.")
|
|
|
parser.add_argument("--master_port", |
|
default=TORCH_DISTRIBUTED_DEFAULT_PORT, |
|
type=int, |
|
help="(optional) Port used by PyTorch distributed for " |
|
"communication during training.") |
|
|
|
parser.add_argument("--master_addr", |
|
default="", |
|
type=str, |
|
help="(optional) IP address of node 0, will be " |
|
"inferred via 'hostname -I' if not specified.") |
|
|
|
parser.add_argument("--node_rank", |
|
default=-1, |
|
type=int, |
|
help="ID of each node in the range [0:N). " |
|
"Only required when --no_ssh is set.") |
|
|
|
parser.add_argument("--launcher", |
|
default=PDSH_LAUNCHER, |
|
type=str, |
|
help="(optional) choose launcher backend for multi-node " |
|
"training. Options currently include PDSH, OpenMPI, MVAPICH, SLURM, MPICH, IMPI.") |
|
|
|
parser.add_argument("--launcher_args", |
|
default="", |
|
type=str, |
|
help="(optional) pass launcher specific arguments as a " |
|
"single quoted argument.") |
|
|
|
parser.add_argument("--module", |
|
action="store_true", |
|
help="Change each process to interpret the launch " |
|
"script as a Python module, executing with the same " |
|
"behavior as 'python -m'.") |
|
|
|
parser.add_argument("--no_python", |
|
action="store_true", |
|
help="Skip prepending the training script with " |
|
"'python' - just execute it directly.") |
|
|
|
parser.add_argument("--no_local_rank", |
|
action="store_true", |
|
help="Do not pass local_rank as an argument when calling " |
|
"the user's training script.") |
|
|
|
parser.add_argument("--no_ssh", |
|
action="store_true", |
|
help="Launch training independently on each node without ssh setup.") |
|
|
|
parser.add_argument("--no_ssh_check", |
|
action="store_true", |
|
                        help="Do not perform ssh check in multi-node launcher mode")
|
|
|
parser.add_argument("--force_multi", |
|
action="store_true", |
|
                        help="Force multi-node launcher mode; helps in cases where the user "
                        "wants to launch on a single remote node.")
|
|
|
parser.add_argument("--save_pid", |
|
action="store_true", |
|
help="Save file containing launcher process id (pid) at /tmp/<main-pid>.ds, " |
|
"where <main-pid> is the pid of the first process that invoked `deepspeed`. " |
|
"Useful when launching deepspeed processes programmatically.") |
|
|
|
parser.add_argument("--enable_each_rank_log", |
|
default="None", |
|
type=str, |
|
                        help="Redirect the stdout and stderr from each rank into different log files.")
|
|
|
parser.add_argument("--autotuning", |
|
default="", |
|
choices=["tune", "run"], |
|
type=str, |
|
help="Run DeepSpeed autotuner to discover optimal configuration parameters " |
|
"before running job.") |
|
|
|
parser.add_argument("--elastic_training", |
|
action="store_true", |
|
help="Enable elastic training support in DeepSpeed.") |
|
|
|
parser.add_argument("user_script", type=str, help="User script to launch, followed by any required " |
|
"arguments.") |
|
|
|
parser.add_argument('user_args', nargs=argparse.REMAINDER) |
|
|
|
parser.add_argument("--bind_cores_to_rank", |
|
action="store_true", |
|
help="Bind each rank to different cores of the host") |
|
|
|
parser.add_argument("--bind_core_list", |
|
type=str, |
|
default=None, |
|
                        help="Comma-separated list of cores to bind to, given as "
                        "numbers and ranges, e.g. 1,3-5,7 => [1,3,4,5,7]. When not "
                        "specified, all cores on the system will be used for rank binding.")
|
|
|
parser.add_argument("--ssh_port", type=int, default=None, help="SSH port to use for remote connections") |
|
|
|
parser.add_argument("--venv_script", |
|
type=str, |
|
default=None, |
|
help="Python virtual environment activation script for job.") |
|
|
|
return parser.parse_args(args=args) |
|
|
|
|
|
def fetch_hostfile(hostfile_path): |
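    """Read a hostfile and return an ordered mapping of hostname -> slot count,
    or None (fall back to local resources) if the file does not exist."""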
|
if not os.path.isfile(hostfile_path): |
|
logger.warning("Unable to find hostfile, will proceed with training " |
|
"with local resources only.") |
|
        return None

    # e.g., worker-0 slots=16
|
with open(hostfile_path, 'r') as fd: |
|
hostfile_text = fd.readlines() |
|
|
|
return _parse_hostfile(hostfile_text) |
|
|
|
|
|
def _parse_hostfile(hostfile_lines):
    # Regex matches one or more non-whitespace characters (\S+) for the hostname,
    # followed by whitespace, then 'slots=', followed by one or more digits (\d+).
|
pattern = r'^(\S+)\s+slots=(\d+)' |
|
|
|
resource_pool = collections.OrderedDict() |
|
|
|
for line in hostfile_lines: |
|
line = line.strip() |
|
match = re.search(pattern, line) |
|
        if line.startswith("#") or line == "":
            # hostfile comment or empty line, ignore
            continue
|
elif match: |
|
host = match.group(1) |
|
num_slots = int(match.group(2)) |
|
if host in resource_pool: |
|
logger.error(f"Bad hostfile text: {hostfile_lines}") |
|
raise ValueError(f"Hostfile contains multiple entries for {host}, unable to proceed with launching") |
|
resource_pool[host] = num_slots |
|
else: |
|
logger.error(f"Bad hostfile text: {hostfile_lines}") |
|
raise ValueError(f"Hostfile contains a bad entry: {line}, unable to proceed with launching") |
|
|
|
if len(resource_pool) == 0: |
|
logger.error(f"Bad hostfile text: {hostfile_lines}") |
|
raise ValueError("Hostfile is empty or not formatted correctly, unable to proceed with launching.") |
|
|
|
return resource_pool |
|
|
|
|
|
def _stable_remove_duplicates(data):
    # Create a new list in the same order as the original but with duplicates
    # removed; slot lists are small, so a simple linear scan is sufficient.
    new_list = []
|
for x in data: |
|
if x not in new_list: |
|
new_list.append(x) |
|
return new_list |
|
|
|
|
|
def parse_node_config(node_config: str) -> Tuple[str, List[int]]: |
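    """Split a NODE_SPEC of the form NAME[:SLOT[,SLOT ...]] into (hostname,
    slot list); the slot list is empty when no slots are given, e.g.
    "worker-1:0,2" -> ("worker-1", [0, 2])."""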
|
SLOT_LIST_START = ':' |
|
SLOT_SEP = ',' |
|
|
|
if SLOT_LIST_START not in node_config: |
|
return node_config, [] |
|
|
|
hostname, slots = node_config.split(SLOT_LIST_START) |
|
slots = [int(x) for x in slots.split(SLOT_SEP)] |
|
|
|
return hostname, slots |
|
|
|
|
|
def parse_node_config_list(node_config_list: str) -> Dict[str, List[int]]:
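    """Parse NODE_SPEC[@NODE_SPEC ...] into a dict of hostname -> sorted,
    de-duplicated slot list, merging slots for repeated hostnames, e.g.
    "worker-0@worker-1:0,2" -> {"worker-0": [], "worker-1": [0, 2]}."""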
|
NODE_SEP = '@' |
|
|
|
node_configs = defaultdict(list) |
|
|
|
for node_config in node_config_list.split(NODE_SEP): |
|
hostname, slots = parse_node_config(node_config) |
|
node_configs[hostname] += slots |
|
|
|
return {k: sorted(list(set(v))) for k, v in node_configs.items()} |
|
|
|
|
|
def parse_resource_filter(host_info, include_str="", exclude_str=""): |
|
'''Parse an inclusion or exclusion string and filter a hostfile dictionary. |
|
|
|
String format is NODE_SPEC[@NODE_SPEC ...], where |
|
NODE_SPEC = NAME[:SLOT[,SLOT ...]]. |
|
If :SLOT is omitted, include/exclude all slots on that host. |
|
|
|
Examples: |
|
include_str="worker-0@worker-1:0,2" will use all slots on worker-0 and |
|
slots [0, 2] on worker-1. |
|
exclude_str="worker-1:0" will use all available resources except |
|
slot 0 on worker-1. |
|
    '''

    # Ensure include/exclude are mutually exclusive
    if (include_str != "") and (exclude_str != ""):
|
        raise ValueError('include_str and exclude_str are mutually exclusive.')

    # no-op: nothing to filter
    if (include_str == "") and (exclude_str == ""):
|
        return host_info

    # Either build the filtered list from scratch (include) or start from a
    # full copy and remove items from it (exclude)
    filtered_hosts = dict()
|
if include_str: |
|
parse_str = include_str |
|
if exclude_str != "": |
|
filtered_hosts = deepcopy(host_info) |
|
        parse_str = exclude_str

    # foreach node in the list
    for hostname, slots in parse_node_config_list(parse_str).items():
        # Node can either be listed alone or as node:slot,slot,...
        if len(slots) > 0:
            # sanity checks
            if hostname not in host_info:
|
raise ValueError(f"Hostname '{hostname}' not found in hostfile") |
|
for slot in slots: |
|
if slot not in host_info[hostname]: |
|
                    raise ValueError(f"No slot '{slot}' specified on host '{hostname}'")

            # If include string, build the list from here
            if include_str:
|
filtered_hosts[hostname] = slots |
|
elif exclude_str: |
|
for slot in slots: |
|
logger.info(f'removing {slot} from {hostname}') |
|
                    filtered_hosts[hostname].remove(slot)

        # User just specified the whole node
        else:
            if hostname not in host_info:
|
                raise ValueError(f"Hostname '{hostname}' not found in hostfile")

            # Include all slots on this host
            if include_str:
                filtered_hosts[hostname] = host_info[hostname]
            elif exclude_str:
                # Exclude all slots on this host
                filtered_hosts[hostname] = []

|
del_keys = [] |
|
    for hostname in filtered_hosts:
        # Remove duplicates
        filtered_hosts[hostname] = _stable_remove_duplicates(filtered_hosts[hostname])
        # Remove empty hosts
        if len(filtered_hosts[hostname]) == 0:
|
del_keys.append(hostname) |
|
    for name in del_keys:
        del filtered_hosts[name]

    # Lastly, convert filtered_hosts to an OrderedDict() to ensure the same
    # hosts -> slots mapping order as in the original host_info
    ordered_hosts = collections.OrderedDict()
|
for host in host_info: |
|
if host in filtered_hosts: |
|
ordered_hosts[host] = filtered_hosts[host] |
|
|
|
return ordered_hosts |
|
|
|
|
|
def parse_inclusion_exclusion(resource_pool, inclusion, exclusion): |
|
active_resources = collections.OrderedDict() |
|
node_configs = parse_node_config_list(inclusion) |
|
|
|
for hostname, slots in resource_pool.items(): |
|
active_resources[hostname] = node_configs[hostname] if hostname in node_configs else list(range(slots)) |
|
|
|
return parse_resource_filter(active_resources, include_str=inclusion, exclude_str=exclusion) |
|
|
|
|
|
def encode_world_info(world_info): |
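    """Serialize the world info mapping as JSON and url-safe base64 so that it
    can be passed on the command line to deepspeed.launcher.launch."""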
|
world_info_json = json.dumps(world_info).encode('utf-8') |
|
world_info_base64 = base64.urlsafe_b64encode(world_info_json).decode('utf-8') |
|
return world_info_base64 |
|
|
|
|
|
def run_autotuning(args, active_resources): |
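    """Run the DeepSpeed Autotuner over the active resources and record the
    optimal configuration; with --autotuning=run, also launch the job with it."""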
|
tuner = Autotuner(args, active_resources) |
|
logger.info("[Start] Running autotuning") |
|
|
|
tuner.tune() |
|
tuner.print_tuning_results() |
|
|
|
logger.info("[End] Running autotuning") |
|
tuner.write_optimal_config() |
|
|
|
if args.autotuning == "run": |
|
tuner.run_after_tuning() |
|
|
|
|
|
def parse_num_nodes(str_num_nodes: str, elastic_training: bool): |
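    """Parse a node count given as "N" or, for elastic training, "MIN:MAX".
    Returns (min_nodes, max_nodes), where max_nodes is -1 when no maximum was
    given, e.g. parse_num_nodes("1:4", True) -> (1, 4)."""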
|
node_list = str_num_nodes.split(":") |
|
|
|
if len(node_list) == 1: |
|
min_nodes, max_nodes = int(node_list[0]), -1 |
|
elif len(node_list) == 2 and elastic_training: |
|
min_nodes, max_nodes = int(node_list[0]), int(node_list[1]) |
|
elif len(node_list) == 2 and not elastic_training: |
|
raise RuntimeError("MIN:MAX format is only supported in elastic training") |
|
else: |
|
raise RuntimeError("num_nodes {} is not in MIN:MAX format".format(str_num_nodes)) |
|
|
|
return min_nodes, max_nodes |
|
|
|
|
|
def main(args=None): |
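    """Parse arguments, resolve the active resource pool, then either launch
    locally via deepspeed.launcher.launch or hand off to the selected
    multi-node runner."""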
|
args = parse_args(args) |
|
|
|
if args.elastic_training: |
|
assert args.master_addr != "", "Master Addr is required when elastic training is enabled" |
|
|
|
    resource_pool = fetch_hostfile(args.hostfile)

    # respect VISIBLE_DEVICES for a single node and no explicit resource filters
    visible_devices_env = get_accelerator().visible_devices_envs()[0]
|
visible_devices = os.environ.get(visible_devices_env, "") |
|
if not resource_pool and len(visible_devices): |
|
detected_str = f"Detected VISIBLE_DEVICES={visible_devices}" |
|
if len(args.include) or len(args.exclude) or args.num_nodes > 1 or args.num_gpus > 0: |
|
print( |
|
                f"{detected_str} but ignoring it because one or several of --include/--exclude/--num_gpus/--num_nodes command-line args were used. If you want to use CUDA_VISIBLE_DEVICES don't pass any of these arguments to deepspeed."
|
) |
|
else: |
|
args.include = f"localhost:{visible_devices}" |
|
print(f"{detected_str}: setting --include={args.include}") |
|
del os.environ[visible_devices_env] |
|
|
|
if args.num_nodes >= 0 or args.num_gpus >= 0: |
|
if args.include != "" or args.exclude != "": |
|
raise ValueError("Cannot specify num_nodes/gpus with include/exclude") |
|
|
|
multi_node_exec = True |
|
if not resource_pool: |
|
resource_pool = {} |
|
device_count = get_accelerator().device_count() |
|
if device_count == 0: |
|
raise RuntimeError("Unable to proceed, no GPU resources available") |
|
resource_pool['localhost'] = device_count |
|
args.master_addr = "127.0.0.1" |
|
multi_node_exec = False |
|
|
|
if not multi_node_exec and args.num_nodes > 1: |
|
raise ValueError("Num nodes is >1 but no extra nodes available via hostfile") |
|
|
|
active_resources = parse_inclusion_exclusion(resource_pool, args.include, args.exclude) |
|
    env = os.environ.copy()

    # validate that passwordless ssh is working with this hostfile
    if multi_node_exec and not args.no_ssh_check and not args.no_ssh:
|
first_host = list(active_resources.keys())[0] |
|
try: |
|
ssh_check_cmd = ("ssh -o PasswordAuthentication=no " + |
|
(f"-p {args.ssh_port} " if args.ssh_port is not None else "") + f"{first_host} hostname") |
|
safe_ssh_cmd = shlex.split(ssh_check_cmd) |
|
subprocess.check_call(safe_ssh_cmd, stderr=subprocess.DEVNULL, stdout=subprocess.DEVNULL) |
|
except subprocess.CalledProcessError: |
|
raise RuntimeError( |
|
f"Using hostfile at {args.hostfile} but host={first_host} was not reachable via ssh. If you are running with a single node please remove {args.hostfile} or setup passwordless ssh." |
|
) |
|
|
|
if not args.master_addr: |
|
assert multi_node_exec |
|
first_host = list(active_resources.keys())[0] |
|
ssh_check_cmd = "ssh " |
|
if args.ssh_port is not None: |
|
ssh_check_cmd += f" -p {args.ssh_port}" |
|
ssh_check_cmd += f" {first_host} hostname -I" |
|
hostname_cmd = shlex.split(ssh_check_cmd) |
|
try: |
|
result = subprocess.check_output(hostname_cmd) |
|
except subprocess.CalledProcessError as err: |
|
logger.error( |
|
"Unable to detect suitable master address via 'hostname -I', please manually specify one via --master_addr" |
|
) |
|
raise err |
|
args.master_addr = result.decode('utf-8').split()[0] |
|
if not args.master_addr: |
|
raise RuntimeError( |
|
"Unable to detect suitable master address via `hostname -I`, please manually specify one via --master_addr" |
|
) |
|
logger.info(f"Using IP address of {args.master_addr} for node {first_host}") |
|
|
|
if args.autotuning != "": |
|
run_autotuning(args, active_resources) |
|
return |
|
|
|
if args.num_nodes > 0: |
|
updated_active_resources = collections.OrderedDict() |
|
for count, hostname in enumerate(active_resources.keys()): |
|
if args.num_nodes == count: |
|
break |
|
updated_active_resources[hostname] = active_resources[hostname] |
|
active_resources = updated_active_resources |
|
|
|
if args.num_gpus > 0: |
|
updated_active_resources = collections.OrderedDict() |
|
for hostname in active_resources.keys(): |
|
updated_active_resources[hostname] = list(range(args.num_gpus)) |
|
active_resources = updated_active_resources |
|
|
|
if args.elastic_training: |
|
assert not args.no_local_rank, "--no_local_rank argument is not supported in Elastic training" |
|
|
|
if args.no_ssh: |
|
assert (0 <= args.node_rank < |
|
                len(active_resources)), "Launching training without ssh, but --node_rank is not set correctly."

    # encode world info as base64 to make it easier to pass via command line
    world_info_base64 = encode_world_info(active_resources)
|
|
|
multi_node_exec = (args.force_multi or len(active_resources) > 1) and not args.no_ssh |
|
|
|
if not multi_node_exec: |
|
deepspeed_launch = [ |
|
sys.executable, "-u", "-m", "deepspeed.launcher.launch", f"--world_info={world_info_base64}", |
|
f"--master_addr={args.master_addr}", f"--master_port={args.master_port}" |
|
] |
|
if args.no_ssh: |
|
deepspeed_launch.append(f"--node_rank={args.node_rank}") |
|
if args.no_python: |
|
deepspeed_launch.append("--no_python") |
|
if args.module: |
|
deepspeed_launch.append("--module") |
|
if args.no_local_rank: |
|
deepspeed_launch.append("--no_local_rank") |
|
if args.save_pid: |
|
deepspeed_launch += ["--save_pid", f"{os.getpid()}"] |
|
if args.enable_each_rank_log: |
|
deepspeed_launch.append(f"--enable_each_rank_log={args.enable_each_rank_log}") |
|
if args.elastic_training: |
|
deepspeed_launch.append("--enable_elastic_training") |
|
deepspeed_launch.append(f"--max_elastic_nodes={args.max_elastic_nodes}") |
|
deepspeed_launch.append(f"--min_elastic_nodes={args.min_elastic_nodes}") |
|
if args.bind_cores_to_rank: |
|
deepspeed_launch.append("--bind_cores_to_rank") |
|
if args.bind_core_list is not None: |
|
deepspeed_launch.append(f"--bind_core_list={args.bind_core_list}") |
|
cmd = deepspeed_launch + [args.user_script] + args.user_args |
|
else: |
|
args.launcher = args.launcher.lower() |
|
if args.launcher == PDSH_LAUNCHER: |
|
runner = PDSHRunner(args, world_info_base64) |
|
elif args.launcher == OPENMPI_LAUNCHER: |
|
runner = OpenMPIRunner(args, world_info_base64, resource_pool) |
|
elif args.launcher == MPICH_LAUNCHER: |
|
runner = MPICHRunner(args, world_info_base64, resource_pool) |
|
elif args.launcher == IMPI_LAUNCHER: |
|
runner = IMPIRunner(args, world_info_base64, resource_pool) |
|
elif args.launcher == MVAPICH_LAUNCHER: |
|
runner = MVAPICHRunner(args, world_info_base64, resource_pool) |
|
elif args.launcher == SLURM_LAUNCHER: |
|
runner = SlurmRunner(args, world_info_base64, resource_pool) |
|
else: |
|
raise NotImplementedError(f"Unknown launcher {args.launcher}") |
|
|
|
if not runner.backend_exists(): |
|
raise RuntimeError(f"launcher '{args.launcher}' not installed.") |
|
|
|
curr_path = os.path.abspath('.') |
|
if 'PYTHONPATH' in env: |
|
env['PYTHONPATH'] = curr_path + ":" + env['PYTHONPATH'] |
|
else: |
|
env['PYTHONPATH'] = curr_path |
|
|
|
excluded_vars = [] |
|
for exclude_key, var_list in EXCLUDE_ENVS.items(): |
|
            if exclude_key in env.keys():
                # key exists in launcher env -> var list should be used
                excluded_vars += var_list

        # export any env var whose name matches one of the export prefixes,
        # unless it was explicitly excluded above
        exports = EXPORT_ENVS + get_accelerator().export_envs()
|
for var in env.keys(): |
|
if any([var.startswith(name) for name in exports]): |
|
if not any([var == name for name in excluded_vars]): |
|
runner.add_export(var, env[var]) |
|
|
|
for environ_path in DEEPSPEED_ENVIRONMENT_PATHS: |
|
environ_file = os.path.join(environ_path, DEEPSPEED_ENVIRONMENT_NAME) |
|
if os.path.isfile(environ_file): |
|
logger.info(f"deepspeed_env file = {environ_file}") |
|
with open(environ_file, 'r') as fd: |
|
for var in fd.readlines(): |
|
key, val = var.split('=', maxsplit=1) |
|
runner.add_export(key, val) |
|
|
|
if args.launcher == PDSH_LAUNCHER: |
|
cmd, kill_cmd, env = runner.get_cmd(env, active_resources) |
|
else: |
|
cmd = runner.get_cmd(env, active_resources) |
|
|
|
logger.info(f"cmd = {' '.join(cmd)}") |
|
result = subprocess.Popen(cmd, env=env) |
|
|
|
def sigkill_handler(signum, frame): |
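        # Interrupt, then terminate the launcher subprocess, and finally run
        # the pdsh kill command so remote worker processes are cleaned up too.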
|
result.send_signal(signal.SIGINT) |
|
time.sleep(0.1) |
|
result.send_signal(signal.SIGTERM) |
|
result_kill = subprocess.Popen(kill_cmd, env=env) |
|
result_kill.wait() |
|
time.sleep(1) |
|
sys.exit(1) |
|
|
|
if args.launcher == PDSH_LAUNCHER and multi_node_exec: |
|
signal.signal(signal.SIGINT, sigkill_handler) |
|
signal.signal(signal.SIGTERM, sigkill_handler) |
|
|
|
    result.wait()

    # In case of failure must propagate the error-condition back to the caller (usually shell). The
    # actual error and traceback should have been printed in the subprocess, so in order to avoid
    # unnecessary noise we just quietly exit here with the same code as the subprocess
    if result.returncode > 0:
|
sys.exit(result.returncode) |
|
|
|
|
|
if __name__ == "__main__": |
|
main() |
|
|