#!/Users/james/src/MMaDA/venv/bin/python | |
from benchmarks.communication.run_all import main | |
from benchmarks.communication.constants import * | |
from benchmarks.communication.utils import * | |
import os | |
import sys | |
# Run the same file with the deepspeed launcher. This is required since setuptools will auto-detect python files and insert a python shebang for both 'scripts' and 'entry_points', and these benchmarks require the DS launcher.
# Environment variables whose presence is taken to mean this process was
# already started by a distributed launcher. If any one is missing, the
# script re-executes itself through the `deepspeed` command instead of
# running the benchmarks directly.
required_env = ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
if not all(name in os.environ for name in required_env):
    import shutil
    import subprocess

    # Resolve the installed script via PATH without spawning an external
    # `which` process (which is not available on every platform).
    ds_bench_bin = shutil.which("ds_bench")
    if ds_bench_bin is None:
        sys.exit("ds_bench: could not locate the 'ds_bench' executable on PATH")

    # Argument list (no shell) so user-supplied options are passed verbatim.
    safe_cmd = ["deepspeed", ds_bench_bin] + sys.argv[1:]

    # Propagate the launcher's exit status so that a failed benchmark run
    # is visible to the caller instead of always exiting with 0.
    sys.exit(subprocess.run(safe_cmd).returncode)
else:
    # Launcher environment is in place: parse the benchmark options and run
    # the suite, passing the parsed local rank through to `main`.
    args = benchmark_parser().parse_args()
    rank = args.local_rank
    main(args, rank)