# Copyright (c) Microsoft Corporation. | |
# SPDX-License-Identifier: Apache-2.0 | |
# DeepSpeed Team | |
""" | |
Functionality of swapping tensors to/from (NVMe) storage devices. | |
""" | |
import subprocess | |
import shlex | |
class Job:
    """A command line to execute, plus where its output and working directory go.

    Args:
        cmd_line: The command to run, stored as given and returned by ``cmd()``.
        output_file: Optional path; when set, ``open_output_file()`` opens it
            for writing (truncating any existing content).
        work_dir: Optional working directory returned by ``get_cwd()``.
    """

    def __init__(self, cmd_line, output_file=None, work_dir=None):
        self.cmd_line, self.output_file, self.work_dir = cmd_line, output_file, work_dir
        # Populated only by open_output_file(); None means "no redirection".
        self.output_fd = None

    def cmd(self):
        """Return the stored command line unchanged."""
        return self.cmd_line

    def get_stdout(self):
        """Return the handle standard output should be written to (may be None)."""
        return self.output_fd

    def get_stderr(self):
        """Return the handle standard error should be written to.

        This is the same handle as ``get_stdout()``, so both streams are
        merged into one output file.
        """
        return self.output_fd

    def get_cwd(self):
        """Return the configured working directory (may be None)."""
        return self.work_dir

    def open_output_file(self):
        """Open ``output_file`` in text write mode, if one was configured."""
        if self.output_file is None:
            return
        self.output_fd = open(self.output_file, 'w')

    def close_output_file(self):
        """Close the output handle, if open, and forget it."""
        handle = self.output_fd
        if handle is None:
            return
        handle.close()
        self.output_fd = None
def run_job(job, verbose=False):
    """Execute ``job`` as a subprocess and fail if it exits with a non-zero status.

    The pieces returned by ``job.cmd()`` are joined with spaces and then
    re-tokenized with ``shlex.split``. As a result, a single element may
    contain several shell-style tokens, and quoting inside elements is honored.

    Args:
        job: Object exposing ``cmd()``, ``open_output_file()``,
            ``close_output_file()``, ``get_stdout()``, ``get_stderr()`` and
            ``get_cwd()``.
        verbose: If true, print the tokenized argument list before running.

    Raises:
        AssertionError: If the command exits with a non-zero return code.
        OSError: Propagated from ``subprocess.run`` if the command cannot be
            launched (for example, the executable does not exist).
    """
    args = shlex.split(' '.join(job.cmd()))
    if verbose:
        print(f'args = {args}')
    job.open_output_file()
    try:
        proc = subprocess.run(args=args, stdout=job.get_stdout(), stderr=job.get_stderr(), cwd=job.get_cwd())
    finally:
        # Always release the output handle, even if the process failed to start.
        job.close_output_file()
    # Explicit raise instead of an `assert` statement: asserts are stripped under
    # `python -O`, which would silently ignore failing commands. The exception
    # type and message are kept so existing callers see the same error.
    if proc.returncode != 0:
        raise AssertionError(f"This command failed: {job.cmd()}")