|
|
|
|
|
|
|
|
|
|
|
from .utils import check_tb_availability |
|
from .monitor import Monitor |
|
import os |
|
|
|
import deepspeed.comm as dist |
|
|
|
|
|
class TensorBoardMonitor(Monitor):
    """Monitor backend that logs scalar events through a TensorBoard ``SummaryWriter``.

    A writer is only created, written to, or flushed on the process for which
    ``dist.get_rank()`` returns 0, and only when the config enables the monitor.
    On every other process the public methods are no-ops.
    """

    def __init__(self, tensorboard_config):
        """Store the config values and, on rank 0, create the summary writer.

        Args:
            tensorboard_config: Object exposing ``enabled``, ``output_path`` and
                ``job_name`` attributes.
        """
        super().__init__(tensorboard_config)
        # NOTE(review): helper imported from .utils; presumably fails fast when
        # TensorBoard is not installed -- confirm against its definition.
        check_tb_availability()

        self.summary_writer = None
        self.enabled = tensorboard_config.enabled
        self.output_path = tensorboard_config.output_path
        self.job_name = tensorboard_config.job_name

        if self.enabled and dist.get_rank() == 0:
            self.get_summary_writer()

    def get_summary_writer(self, base=os.path.join(os.path.expanduser("~"), "tensorboard")):
        """Create the ``SummaryWriter`` (enabled rank 0 only) and return it.

        The log directory is ``<output_path>/<job_name>`` when ``output_path``
        is set. Otherwise it falls back to ``<base>/<infra_job_id>/logs/<job_name>``,
        where ``infra_job_id`` is read from the ``DLWS_JOB_ID`` or ``DLTS_JOB_ID``
        environment variable (``"unknown-job-id"`` if neither is present).

        Args:
            base: Root directory used only when ``output_path`` is ``None``.

        Returns:
            The current ``self.summary_writer``; ``None`` when the monitor is
            disabled or this is not rank 0 and no writer was created earlier.
        """
        if self.enabled and dist.get_rank() == 0:
            # Imported lazily so the torch TensorBoard module is only loaded on
            # the process that actually writes.
            from torch.utils.tensorboard import SummaryWriter

            if self.output_path is not None:
                log_dir = os.path.join(self.output_path, self.job_name)
            else:
                # DLWS_JOB_ID takes precedence over DLTS_JOB_ID.
                infra_job_id = os.environ.get(
                    "DLWS_JOB_ID",
                    os.environ.get("DLTS_JOB_ID", "unknown-job-id"),
                )
                summary_writer_dir_name = os.path.join(infra_job_id, "logs")
                # output_path is None on this branch, so it cannot be a path
                # component (os.path.join would raise TypeError); use the job
                # name as the leaf directory, as the explicit-path branch does.
                log_dir = os.path.join(base, summary_writer_dir_name, self.job_name)

            os.makedirs(log_dir, exist_ok=True)
            self.summary_writer = SummaryWriter(log_dir=log_dir)

        return self.summary_writer

    def write_events(self, event_list, flush=True):
        """Write scalar events to TensorBoard on the enabled rank-0 process.

        Args:
            event_list: Iterable of argument sequences; each one is unpacked
                positionally into ``SummaryWriter.add_scalar``.
            flush: When true, flush the writer after all events are added.
        """
        if self.enabled and self.summary_writer is not None and dist.get_rank() == 0:
            for event in event_list:
                self.summary_writer.add_scalar(*event)
            if flush:
                self.summary_writer.flush()

    def flush(self):
        """Flush the summary writer if one exists on the enabled rank-0 process."""
        if self.enabled and self.summary_writer is not None and dist.get_rank() == 0:
            self.summary_writer.flush()
|
|