Spaces:
Running
Running
# Copyright (c) Microsoft Corporation. | |
# Licensed under the MIT license. | |
import os | |
import sys | |
import json | |
import time | |
import subprocess | |
from ..common import init_logger | |
from ..env_vars import trial_env_vars | |
from nni.utils import to_json | |
_sysdir = trial_env_vars.NNI_SYS_DIR | |
if not os.path.exists(os.path.join(_sysdir, '.nni')): | |
os.makedirs(os.path.join(_sysdir, '.nni')) | |
_metric_file = open(os.path.join(_sysdir, '.nni', 'metrics'), 'wb') | |
_outputdir = trial_env_vars.NNI_OUTPUT_DIR | |
if not os.path.exists(_outputdir): | |
os.makedirs(_outputdir) | |
_nni_platform = trial_env_vars.NNI_PLATFORM | |
if _nni_platform == 'local': | |
_log_file_path = os.path.join(_outputdir, 'trial.log') | |
init_logger(_log_file_path) | |
_multiphase = trial_env_vars.MULTI_PHASE | |
_param_index = 0 | |
def request_next_parameter(): | |
metric = to_json({ | |
'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID, | |
'type': 'REQUEST_PARAMETER', | |
'sequence': 0, | |
'parameter_index': _param_index | |
}) | |
send_metric(metric) | |
def get_next_parameter(): | |
global _param_index | |
params_file_name = '' | |
if _multiphase in ('true', 'True'): | |
params_file_name = ('parameter_{}.cfg'.format(_param_index), 'parameter.cfg')[_param_index == 0] | |
else: | |
if _param_index > 0: | |
return None | |
elif _param_index == 0: | |
params_file_name = 'parameter.cfg' | |
else: | |
raise AssertionError('_param_index value ({}) should >=0'.format(_param_index)) | |
params_filepath = os.path.join(_sysdir, params_file_name) | |
if not os.path.isfile(params_filepath): | |
request_next_parameter() | |
while not (os.path.isfile(params_filepath) and os.path.getsize(params_filepath) > 0): | |
time.sleep(3) | |
params_file = open(params_filepath, 'r') | |
params = json.load(params_file) | |
_param_index += 1 | |
return params | |
def send_metric(string): | |
if _nni_platform != 'local': | |
assert len(string) < 1000000, 'Metric too long' | |
print("NNISDK_MEb'%s'" % (string), flush=True) | |
else: | |
data = (string + '\n').encode('utf8') | |
assert len(data) < 1000000, 'Metric too long' | |
_metric_file.write(b'ME%06d%b' % (len(data), data)) | |
_metric_file.flush() | |
if sys.platform == "win32": | |
file = open(_metric_file.name) | |
file.close() | |
else: | |
subprocess.run(['touch', _metric_file.name], check=True) | |
def get_experiment_id(): | |
return trial_env_vars.NNI_EXP_ID | |
def get_trial_id(): | |
return trial_env_vars.NNI_TRIAL_JOB_ID | |
def get_sequence_id(): | |
return int(trial_env_vars.NNI_TRIAL_SEQ_ID) | |