Spaces:
Running
Running
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os | |
import argparse | |
import logging | |
import json | |
import base64 | |
from .runtime.common import enable_multi_thread, enable_multi_phase | |
from .runtime.msg_dispatcher import MsgDispatcher | |
from .tools.package_utils import create_builtin_class_instance, create_customized_class_instance | |
# Module-level logger used by every function in this entry-point module.
logger = logging.getLogger('nni.main')
logger.debug('START')

# Start coverage measurement only when the COVERAGE_PROCESS_START environment
# variable is set to a non-empty value; the coverage package is imported
# lazily so it is not required otherwise.
if os.getenv('COVERAGE_PROCESS_START'):
    import coverage
    coverage.process_startup()
def main():
    """Parse the experiment parameters and run the configured dispatcher.

    The ``--exp_params`` command-line argument is required and must hold a
    base64-encoded, UTF-8 JSON document. Unknown command-line arguments are
    ignored. Depending on the decoded parameters this either runs an advisor
    (``advisor`` key present) or builds a tuner, an optional assessor, and a
    ``MsgDispatcher`` around them.

    Raises:
        AssertionError: if neither ``advisor`` nor ``tuner`` is configured.
        Exception: any error raised while the dispatcher runs (or by the exit
            hooks) is logged, the ``_on_error`` hooks are called, and the
            error is re-raised.
    """
    parser = argparse.ArgumentParser(description='Dispatcher command line parser')
    parser.add_argument('--exp_params', type=str, required=True)
    args, _ = parser.parse_known_args()

    exp_params_decode = base64.b64decode(args.exp_params).decode('utf-8')
    logger.debug('decoded exp_params: [%s]', exp_params_decode)
    exp_params = json.loads(exp_params_decode)
    logger.debug('exp_params json obj: [%s]', json.dumps(exp_params, indent=4))

    if exp_params.get('multiThread'):
        enable_multi_thread()
    if exp_params.get('multiPhase'):
        enable_multi_phase()

    if exp_params.get('advisor') is not None:
        # advisor is enabled and starts to run
        _run_advisor(exp_params)
    else:
        # tuner (and assessor) is enabled and starts to run
        # Raise explicitly instead of using `assert`: assert statements are
        # removed under `python -O`, which would let a missing tuner slip
        # through and fail later with an unrelated AttributeError.
        if exp_params.get('tuner') is None:
            raise AssertionError('exp_params must configure either an advisor or a tuner')
        tuner = _create_tuner(exp_params)
        if exp_params.get('assessor') is not None:
            assessor = _create_assessor(exp_params)
        else:
            assessor = None
        dispatcher = MsgDispatcher(tuner, assessor)

        try:
            dispatcher.run()
            # Exit hooks run inside the try block, so a failure in them is
            # handled the same way as a failure of the dispatcher itself.
            tuner._on_exit()
            if assessor is not None:
                assessor._on_exit()
        except Exception as exception:
            logger.exception(exception)
            tuner._on_error()
            if assessor is not None:
                assessor._on_error()
            raise
def _run_advisor(exp_params): | |
if exp_params.get('advisor').get('builtinAdvisorName'): | |
dispatcher = create_builtin_class_instance( | |
exp_params.get('advisor').get('builtinAdvisorName'), | |
exp_params.get('advisor').get('classArgs'), | |
'advisors') | |
else: | |
dispatcher = create_customized_class_instance(exp_params.get('advisor')) | |
if dispatcher is None: | |
raise AssertionError('Failed to create Advisor instance') | |
try: | |
dispatcher.run() | |
except Exception as exception: | |
logger.exception(exception) | |
raise | |
def _create_tuner(exp_params): | |
if exp_params.get('tuner').get('builtinTunerName'): | |
tuner = create_builtin_class_instance( | |
exp_params.get('tuner').get('builtinTunerName'), | |
exp_params.get('tuner').get('classArgs'), | |
'tuners') | |
else: | |
tuner = create_customized_class_instance(exp_params.get('tuner')) | |
if tuner is None: | |
raise AssertionError('Failed to create Tuner instance') | |
return tuner | |
def _create_assessor(exp_params): | |
if exp_params.get('assessor').get('builtinAssessorName'): | |
assessor = create_builtin_class_instance( | |
exp_params.get('assessor').get('builtinAssessorName'), | |
exp_params.get('assessor').get('classArgs'), | |
'assessors') | |
else: | |
assessor = create_customized_class_instance(exp_params.get('assessor')) | |
if assessor is None: | |
raise AssertionError('Failed to create Assessor instance') | |
return assessor | |
if __name__ == '__main__':
    # Script entry point: run the dispatcher, and make sure any error that
    # escapes main() is written to the log before it propagates.
    try:
        main()
    except Exception as err:
        logger.exception(err)
        raise