Spaces:
Running
Running
# Copyright (c) Microsoft Corporation. | |
# Licensed under the MIT license. | |
from azureml.core.run import Run # pylint: disable=import-error | |
from .base_channel import BaseChannel | |
from .log_utils import LogType, nni_log | |
class AMLChannel(BaseChannel):
    """Command channel that exchanges messages through Azure ML run metrics.

    Outgoing messages are logged on the current run under the metric name
    ``'trial_runner'``. Incoming messages are read back from the run's
    ``'nni_manager'`` metric, and only entries not returned before are
    delivered.
    """

    def __init__(self, args):
        self.args = args
        # Handle to the current Azure ML run; used both to send (run.log)
        # and to receive (run.get_metrics).
        self.run = Run.get_context()
        super().__init__(args)
        # Index of the last 'nni_manager' entry already returned by
        # _inner_receive; -1 means nothing has been consumed yet.
        self.current_message_index = -1

    def _inner_open(self):
        # Nothing to open: the run context is already obtained in __init__.
        pass

    def _inner_close(self):
        # Nothing to release.
        pass

    def _inner_send(self, message):
        """Log ``message`` on the run as a ``'trial_runner'`` metric.

        Args:
            message: UTF-8 encoded bytes; decoded to ``str`` before logging.

        Any exception raised while decoding or logging is reported through
        ``nni_log`` at error level and is not propagated (best-effort send).
        """
        try:
            self.run.log('trial_runner', message.decode('utf8'))
        except Exception as exception:
            nni_log(LogType.Error, 'meet unhandled exception when send message: %s' % exception)

    def _inner_receive(self):
        """Return the ``'nni_manager'`` messages not yet delivered, as bytes.

        Returns:
            A list of UTF-8 encoded messages, empty when there is nothing new.

        The metric value is handled in two shapes. A list is treated as the
        full message history, and the entries after
        ``current_message_index`` are returned. Any other non-empty value is
        treated as a single first message, returned only if nothing has been
        consumed yet. (Presumably AML reports a metric logged once as a
        scalar and one logged repeatedly as a list; confirm against the AML
        SDK docs.)
        """
        message_list = self.run.get_metrics().get('nni_manager')
        if not message_list:
            return []

        if isinstance(message_list, list):
            new_messages = message_list[self.current_message_index + 1:]
            # Advance only when something new arrived, so the index is never
            # moved backwards.
            if new_messages:
                self.current_message_index = len(message_list) - 1
        elif self.current_message_index == -1:
            new_messages = [message_list]
            self.current_message_index = 0
        else:
            new_messages = []

        # Received metric values are strings; encode them so callers get the
        # same bytes type that _inner_send accepts.
        return [message.encode('utf8') for message in new_messages]