|
"""Handle Manager.""" |
|
|
|
import json |
|
import logging |
|
import math |
|
import numbers |
|
import time |
|
from collections import defaultdict |
|
from queue import Queue |
|
from threading import Event |
|
from typing import ( |
|
TYPE_CHECKING, |
|
Any, |
|
Callable, |
|
Dict, |
|
Iterable, |
|
List, |
|
Optional, |
|
Sequence, |
|
Tuple, |
|
cast, |
|
) |
|
|
|
from wandb.errors.links import url_registry |
|
from wandb.proto.wandb_internal_pb2 import ( |
|
HistoryRecord, |
|
InternalMessages, |
|
MetricRecord, |
|
Record, |
|
Result, |
|
RunRecord, |
|
SampledHistoryItem, |
|
SummaryItem, |
|
SummaryRecord, |
|
SummaryRecordRequest, |
|
SystemMetricSample, |
|
SystemMetricsBuffer, |
|
) |
|
|
|
from ..interface.interface_queue import InterfaceQueue |
|
from ..lib import handler_util, proto_util |
|
from ..wandb_metadata import Metadata |
|
from . import context, sample, tb_watcher |
|
from .settings_static import SettingsStatic |
|
from .system.system_monitor import SystemMonitor |
|
|
|
if TYPE_CHECKING: |
|
from wandb.proto.wandb_internal_pb2 import MetricSummary |
|
|
|
|
|
# Mapping of summary keys to JSON-serializable values.
SummaryDict = Dict[str, Any]

logger = logging.getLogger(__name__)


# Feature flag: when enabled, media dicts stored in the summary have their
# "artifact_path" rewritten to the value of "_latest_artifact_path".
# Disabled by default.
REPLACE_SUMMARY_ART_PATH_WITH_LATEST = False
|
|
|
|
|
def _dict_nested_set(target: Dict[str, Any], key_list: Sequence[str], v: Any) -> None: |
|
|
|
|
|
for k in key_list[:-1]: |
|
target.setdefault(k, {}) |
|
new_target = target.get(k) |
|
if TYPE_CHECKING: |
|
new_target = cast(Dict[str, Any], new_target) |
|
target = new_target |
|
|
|
target[key_list[-1]] = v |
|
|
|
|
|
class HandleManager:
    """Consume records from the client, maintain derived run state, and
    forward records to the writer thread.

    Derived state includes the consolidated summary, sampled history,
    define_metric() definitions, and runtime accounting.
    """

    # Summary assembled from history rows and explicit summary records.
    _consolidated_summary: SummaryDict
    # Per-key reservoir of numeric history values for sampled_history.
    _sampled_history: Dict[str, sample.UniformSampleAccumulator]
    # Buffered items for the in-progress (partial) history row.
    _partial_history: Dict[str, Any]
    # Run proto captured at handle_run / run_start (None before a run exists).
    _run_proto: Optional[RunRecord]
    _settings: SettingsStatic
    # Inbound records to handle.
    _record_q: "Queue[Record]"
    # Results sent back to the client.
    _result_q: "Queue[Result]"
    # Set to request shutdown of the handler loop.
    _stopped: Event
    # Downstream queue feeding the writer thread.
    _writer_q: "Queue[Record]"
    _interface: InterfaceQueue
    # Created in handle_request_run_start; None before that.
    _system_monitor: Optional[SystemMonitor]
    _tb_watcher: Optional[tb_watcher.TBWatcher]
    # define_metric() definitions keyed by metric name.
    _metric_defines: Dict[str, MetricRecord]
    # define_metric() wildcard definitions keyed by glob pattern.
    _metric_globs: Dict[str, MetricRecord]
    # Tracked aggregates (last/max/min/tot/num) keyed by summary key path.
    _metric_track: Dict[Tuple[str, ...], float]
    # Last raw value per top-level key, used for step_sync backfill.
    _metric_copy: Dict[Tuple[str, ...], Any]
    # Wall-clock time when the runtime clock was last (re)started.
    _track_time: Optional[float]
    # Runtime accumulated across pause/resume cycles.
    _accumulate_time: float
    _run_start_time: Optional[float]
    _context_keeper: context.ContextKeeper
|
|
|
    def __init__(
        self,
        settings: SettingsStatic,
        record_q: "Queue[Record]",
        result_q: "Queue[Result]",
        stopped: Event,
        writer_q: "Queue[Record]",
        interface: InterfaceQueue,
        context_keeper: context.ContextKeeper,
    ) -> None:
        """Wire up queues and reset all derived-state accumulators.

        Args:
            settings: Static snapshot of the run settings.
            record_q: Inbound queue of records to handle.
            result_q: Outbound queue of results sent back to the client.
            stopped: Event set when the handler loop should shut down.
            writer_q: Downstream queue feeding the writer thread.
            interface: Interface used to construct outgoing request records.
            context_keeper: Tracks per-record cancellation contexts.
        """
        self._settings = settings
        self._record_q = record_q
        self._result_q = result_q
        self._stopped = stopped
        self._writer_q = writer_q
        self._interface = interface
        self._context_keeper = context_keeper

        # Created lazily in handle_request_run_start().
        self._tb_watcher = None
        self._system_monitor = None
        self._metadata: Optional[Metadata] = None
        # Next history step to assign when the client does not provide one.
        self._step = 0

        # Runtime clock: _track_time is the last (re)start instant,
        # _accumulate_time the banked runtime across pause/resume.
        self._track_time = None
        self._accumulate_time = 0
        self._run_start_time = None

        # Summary/history state derived from incoming records.
        self._consolidated_summary = dict()
        self._sampled_history = defaultdict(sample.UniformSampleAccumulator)
        self._run_proto = None
        self._partial_history = dict()
        self._metric_defines = defaultdict(MetricRecord)
        self._metric_globs = defaultdict(MetricRecord)
        self._metric_track = dict()
        self._metric_copy = dict()
        self._internal_messages = InternalMessages()

        # Latched after the first non-monotonic step so the generic
        # warning is only emitted once per run.
        self._dropped_history = False
|
|
|
    def __len__(self) -> int:
        """Return the number of records waiting in the inbound queue."""
        return self._record_q.qsize()

    def handle(self, record: Record) -> None:
        """Dispatch ``record`` to its ``handle_<record_type>`` method."""
        self._context_keeper.add_from_record(record)
        record_type = record.WhichOneof("record_type")
        assert record_type
        handler_str = "handle_" + record_type
        handler: Callable[[Record], None] = getattr(self, handler_str, None)
        assert handler, f"unknown handle: {handler_str}"
        handler(record)

    def handle_request(self, record: Record) -> None:
        """Dispatch a request record to its ``handle_request_<type>`` method."""
        request_type = record.request.WhichOneof("request_type")
        assert request_type
        handler_str = "handle_request_" + request_type
        handler: Callable[[Record], None] = getattr(self, handler_str, None)
        # network_status requests are frequent; keep them out of the debug log.
        if request_type != "network_status":
            logger.debug(f"handle_request: {request_type}")
        assert handler, f"unknown handle: {handler_str}"
        handler(record)
|
|
|
    def _dispatch_record(self, record: Record, always_send: bool = False) -> None:
        """Forward ``record`` to the writer thread.

        Args:
            record: The record to forward.
            always_send: Mark the record so it is transmitted even when
                syncing is otherwise suppressed.
        """
        if always_send:
            record.control.always_send = True
        self._writer_q.put(record)

    def _respond_result(self, result: Result) -> None:
        """Release the record's context and push ``result`` back to the client."""
        context_id = context.context_id_from_result(result)
        self._context_keeper.release(context_id)
        self._result_q.put(result)

    def debounce(self) -> None:
        """No-op; this handler has nothing to flush periodically."""
        pass
|
|
|
    def handle_request_cancel(self, record: Record) -> None:
        """Forward a cancel request downstream."""
        self._dispatch_record(record)

    def handle_request_metadata(self, record: Record) -> None:
        """Ignore metadata updates (not supported by the legacy service)."""
        logger.warning("Metadata updates are ignored when using the legacy service.")

    def handle_request_defer(self, record: Record) -> None:
        """Flush handler-owned state for one stage of the shutdown sequence.

        The defer request steps through shutdown states; this handler reacts
        to the states it owns and always forwards the record so the
        downstream threads advance through the same state machine.
        """
        defer = record.request.defer
        state = defer.state

        logger.info(f"handle defer: {state}")

        if state == defer.FLUSH_STATS:
            # Stop collecting system metrics.
            if self._system_monitor is not None:
                self._system_monitor.finish()
        elif state == defer.FLUSH_TB:
            if self._tb_watcher:
                # Shut down the tensorboard watcher.
                self._tb_watcher.finish()
                self._tb_watcher = None
        elif state == defer.FLUSH_PARTIAL_HISTORY:
            self._flush_partial_history()
        elif state == defer.FLUSH_SUM:
            self._save_summary(self._consolidated_summary, flush=True)

        # always_send so the defer state machine keeps progressing even
        # when syncing is suppressed.
        self._dispatch_record(record, always_send=True)

    def handle_request_python_packages(self, record: Record) -> None:
        """Forward the installed-packages snapshot downstream."""
        self._dispatch_record(record)
|
|
|
def handle_run(self, record: Record) -> None: |
|
if self._settings._offline: |
|
self._run_proto = record.run |
|
result = proto_util._result_from_record(record) |
|
result.run_result.run.CopyFrom(record.run) |
|
self._respond_result(result) |
|
self._dispatch_record(record) |
|
|
|
    def handle_stats(self, record: Record) -> None:
        """Forward system stats downstream."""
        self._dispatch_record(record)

    def handle_config(self, record: Record) -> None:
        """Forward a config record downstream."""
        self._dispatch_record(record)

    def handle_output(self, record: Record) -> None:
        """Forward captured console output downstream."""
        self._dispatch_record(record)

    def handle_output_raw(self, record: Record) -> None:
        """Forward raw console output downstream."""
        self._dispatch_record(record)

    def handle_files(self, record: Record) -> None:
        """Forward a files record downstream."""
        self._dispatch_record(record)

    def handle_request_link_artifact(self, record: Record) -> None:
        """Forward an artifact-link request downstream."""
        self._dispatch_record(record)

    def handle_use_artifact(self, record: Record) -> None:
        """Forward a use-artifact record downstream."""
        self._dispatch_record(record)

    def handle_artifact(self, record: Record) -> None:
        """Forward an artifact record downstream."""
        self._dispatch_record(record)

    def handle_alert(self, record: Record) -> None:
        """Forward an alert record downstream."""
        self._dispatch_record(record)
|
|
|
    def _save_summary(self, summary_dict: SummaryDict, flush: bool = False) -> None:
        """Serialize ``summary_dict`` into a SummaryRecord and forward it.

        Args:
            summary_dict: Key/value pairs to persist; values are JSON-encoded.
            flush: When True, send a full summary Record (used at shutdown).
                Otherwise, when online, send the summary wrapped in a request.
        """
        summary = SummaryRecord()
        for k, v in summary_dict.items():
            update = summary.update.add()
            update.key = k
            update.value_json = json.dumps(v)
        if flush:
            record = Record(summary=summary)
            self._dispatch_record(record)
        elif not self._settings._offline:
            # Incremental online update: wrap the summary in a request.
            # NOTE(review): offline runs skip this path entirely until the
            # final flush — confirm that is intentional.
            summary_record = SummaryRecordRequest(summary=summary)
            request_record = self._interface._make_request(
                summary_record=summary_record
            )
            self._dispatch_record(request_record)
|
|
|
def _save_history( |
|
self, |
|
history: HistoryRecord, |
|
) -> None: |
|
for item in history.item: |
|
|
|
k = item.key |
|
v = json.loads(item.value_json) |
|
if isinstance(v, numbers.Real): |
|
self._sampled_history[k].add(v) |
|
|
|
    def _update_summary_metrics(
        self,
        s: "MetricSummary",
        kl: List[str],
        v: "numbers.Real",
        float_v: float,
        goal_max: Optional[bool],
    ) -> bool:
        """Update summary aggregates (last/best/max/min/mean) for one metric.

        Args:
            s: Summary options from the metric definition.
            kl: Key path of the metric within the summary.
            v: The raw logged value (JSON-decoded).
            float_v: ``v`` coerced to float, used for comparisons/tracking.
            goal_max: True/False when define_metric specified a goal,
                None when no goal was given.

        Returns:
            True if the consolidated summary was modified.
        """
        updated = False
        best_key: Optional[Tuple[str, ...]] = None
        if s.none:
            # summary="none": exclude this metric from the summary entirely.
            return False
        if s.copy:
            # summary="copy": mirror the raw value into the summary.
            # A top-level copy (len(kl) == 1) is handled by _update_summary_leaf.
            if len(kl) > 1:
                _dict_nested_set(self._consolidated_summary, kl, v)
            return True
        if s.last:
            last_key = tuple(kl + ["last"])
            old_last = self._metric_track.get(last_key)
            if old_last is None or float_v != old_last:
                self._metric_track[last_key] = float_v
                _dict_nested_set(self._consolidated_summary, last_key, v)
                updated = True
        if s.best:
            best_key = tuple(kl + ["best"])
        # "best" piggybacks on the max/min tracking below, chosen by goal.
        # (precedence: s.max or (best_key and goal_max))
        if s.max or best_key and goal_max:
            max_key = tuple(kl + ["max"])
            old_max = self._metric_track.get(max_key)
            if old_max is None or float_v > old_max:
                self._metric_track[max_key] = float_v
                if s.max:
                    _dict_nested_set(self._consolidated_summary, max_key, v)
                    updated = True
                if best_key:
                    _dict_nested_set(self._consolidated_summary, best_key, v)
                    updated = True
        # With goal_max falsy (including goal=None), "best" uses the min branch.
        if s.min or best_key and not goal_max:
            min_key = tuple(kl + ["min"])
            old_min = self._metric_track.get(min_key)
            if old_min is None or float_v < old_min:
                self._metric_track[min_key] = float_v
                if s.min:
                    _dict_nested_set(self._consolidated_summary, min_key, v)
                    updated = True
                if best_key:
                    _dict_nested_set(self._consolidated_summary, best_key, v)
                    updated = True
        if s.mean:
            # Running mean: track total and count, store tot/num.
            tot_key = tuple(kl + ["tot"])
            num_key = tuple(kl + ["num"])
            avg_key = tuple(kl + ["mean"])
            tot = self._metric_track.get(tot_key, 0.0)
            num = self._metric_track.get(num_key, 0)
            tot += float_v
            num += 1
            self._metric_track[tot_key] = tot
            self._metric_track[num_key] = num
            _dict_nested_set(self._consolidated_summary, avg_key, tot / num)
            updated = True
        return updated
|
|
|
    def _update_summary_leaf(
        self,
        kl: List[str],
        v: Any,
        d: Optional[MetricRecord] = None,
    ) -> bool:
        """Update the summary for a single (leaf) history value.

        Args:
            kl: Key path of the value.
            v: JSON-decoded value.
            d: Metric definition applying to this key, if any.

        Returns:
            True if the consolidated summary was modified.
        """
        has_summary = d and d.HasField("summary")
        if len(kl) == 1:
            # Top-level values are always remembered in _metric_copy so
            # step_sync backfill (_history_update_leaf) can reuse them.
            copy_key = tuple(kl)
            old_copy = self._metric_copy.get(copy_key)
            if old_copy is None or v != old_copy:
                self._metric_copy[copy_key] = v
                # Without an explicit summary definition (or with "copy"),
                # mirror the raw value into the summary.
                if not has_summary or (d and d.summary.copy):
                    self._consolidated_summary[kl[0]] = v
                    return True
        if not d:
            return False
        if not has_summary:
            return False
        # The aggregates below only apply to real, non-NaN numbers.
        if not isinstance(v, numbers.Real):
            return False
        if math.isnan(v):
            return False
        float_v = float(v)
        goal_max = None
        if d.goal:
            goal_max = d.goal == d.GOAL_MAXIMIZE
        if self._update_summary_metrics(
            d.summary, kl=kl, v=v, float_v=float_v, goal_max=goal_max
        ):
            return True
        return False
|
|
|
    def _update_summary_list(
        self,
        kl: List[str],
        v: Any,
        d: Optional[MetricRecord] = None,
    ) -> bool:
        """Recursively update the summary for a (possibly nested) value.

        Dots inside key components are escaped so the joined metric key can
        be matched against define_metric() names.
        """
        metric_key = ".".join([k.replace(".", "\\.") for k in kl])
        d = self._metric_defines.get(metric_key, d)
        # Plain (non-media) dicts are traversed item by item, inheriting
        # the closest matching metric definition.
        if isinstance(v, dict) and not handler_util.metric_is_wandb_dict(v):
            updated = False
            for nk, nv in v.items():
                if self._update_summary_list(kl=kl[:] + [nk], v=nv, d=d):
                    updated = True
            return updated
        # Feature-flagged: rewrite media artifact paths to the latest alias.
        elif (
            REPLACE_SUMMARY_ART_PATH_WITH_LATEST
            and isinstance(v, dict)
            and handler_util.metric_is_wandb_dict(v)
        ):
            if "_latest_artifact_path" in v and "artifact_path" in v:
                # Replace the artifact path with the latest artifact path.
                v["artifact_path"] = v["_latest_artifact_path"]
        updated = self._update_summary_leaf(kl=kl, v=v, d=d)
        return updated
|
|
|
    def _update_summary_media_objects(self, v: Dict[str, Any]) -> Dict[str, Any]:
        """Rewrite top-level media dicts to the latest artifact path.

        Active only when REPLACE_SUMMARY_ART_PATH_WITH_LATEST is enabled
        (currently False); mutates and returns ``v``.
        """
        for nk, nv in v.items():
            if REPLACE_SUMMARY_ART_PATH_WITH_LATEST and (
                isinstance(nv, dict)
                and handler_util.metric_is_wandb_dict(nv)
                and "_latest_artifact_path" in nv
                and "artifact_path" in nv
            ):
                # Replace the artifact path with the latest artifact path.
                nv["artifact_path"] = nv["_latest_artifact_path"]
                v[nk] = nv
        return v
|
|
|
def _update_summary(self, history_dict: Dict[str, Any]) -> List[str]: |
|
|
|
if not self._metric_defines: |
|
history_dict = self._update_summary_media_objects(history_dict) |
|
self._consolidated_summary.update(history_dict) |
|
return list(history_dict.keys()) |
|
updated_keys = [] |
|
for k, v in history_dict.items(): |
|
if self._update_summary_list(kl=[k], v=v): |
|
updated_keys.append(k) |
|
return updated_keys |
|
|
|
    def _history_assign_step(
        self,
        history: HistoryRecord,
        history_dict: Dict[str, Any],
    ) -> None:
        """Inject a ``_step`` item into the history record and dict.

        Uses the client-provided step when present, otherwise the handler's
        running counter; either way ``self._step`` is advanced past the step
        just used.
        """
        has_step = history.HasField("step")
        item = history.item.add()
        item.key = "_step"
        if has_step:
            step = history.step.num
            history_dict["_step"] = step
            item.value_json = json.dumps(step)
            self._step = step + 1
        else:
            history_dict["_step"] = self._step
            item.value_json = json.dumps(self._step)
            self._step += 1
|
|
|
def _history_define_metric(self, hkey: str) -> Optional[MetricRecord]: |
|
"""Check for hkey match in glob metrics and return the defined metric.""" |
|
|
|
if hkey.startswith("_"): |
|
return None |
|
for k, mglob in self._metric_globs.items(): |
|
if k.endswith("*"): |
|
if hkey.startswith(k[:-1]): |
|
m = MetricRecord() |
|
m.CopyFrom(mglob) |
|
m.ClearField("glob_name") |
|
m.options.defined = False |
|
m.name = hkey |
|
return m |
|
return None |
|
|
|
    def _history_update_leaf(
        self,
        kl: List[str],
        v: Any,
        history_dict: Dict[str, Any],
        update_history: Dict[str, Any],
    ) -> None:
        """Apply metric definitions to one history leaf value.

        When the key only matches a glob metric, a concrete metric is defined
        on the fly. When the metric uses step_sync and the row is missing its
        step metric, the last seen step value is added to ``update_history``.
        """
        hkey = ".".join([k.replace(".", "\\.") for k in kl])
        m = self._metric_defines.get(hkey)
        if not m:
            m = self._history_define_metric(hkey)
            if not m:
                return
            # Register the glob-derived metric as a locally defined metric.
            mr = Record()
            mr.metric.CopyFrom(m)
            mr.control.local = True
            self._handle_defined_metric(mr)

        if m.options.step_sync and m.step_metric:
            if m.step_metric not in history_dict:
                # Backfill the step metric from its last copied value
                # (recorded by _update_summary_leaf).
                copy_key = tuple([m.step_metric])
                step = self._metric_copy.get(copy_key)
                if step is not None:
                    update_history[m.step_metric] = step
|
|
|
    def _history_update_list(
        self,
        kl: List[str],
        v: Any,
        history_dict: Dict[str, Any],
        update_history: Dict[str, Any],
    ) -> None:
        """Recurse into nested dicts, applying metric handling at each leaf."""
        if isinstance(v, dict):
            for nk, nv in v.items():
                self._history_update_list(
                    kl=kl[:] + [nk],
                    v=nv,
                    history_dict=history_dict,
                    update_history=update_history,
                )
            return
        self._history_update_leaf(
            kl=kl, v=v, history_dict=history_dict, update_history=update_history
        )
|
|
|
    def _history_update(
        self,
        history: HistoryRecord,
        history_dict: Dict[str, Any],
    ) -> None:
        """Assign a step and apply metric definitions to a history row."""
        # Make sure the row carries a _step value.
        if history_dict.get("_step") is None:
            self._history_assign_step(history, history_dict)

        # Extra items (e.g. backfilled step metrics) collected while walking.
        update_history: Dict[str, Any] = {}
        # Only walk the row when define_metric() rules exist.
        if self._metric_defines or self._metric_globs:
            for hkey, hval in history_dict.items():
                self._history_update_list([hkey], hval, history_dict, update_history)

        # Merge the collected extras into both the dict and the proto.
        if update_history:
            history_dict.update(update_history)
            for k, v in update_history.items():
                item = history.item.add()
                item.key = k
                item.value_json = json.dumps(v)
|
|
|
    def handle_history(self, record: Record) -> None:
        """Process a history row: assign step/metrics, persist, sample, summarize."""
        history_dict = proto_util.dict_from_proto_list(record.history.item)

        if history_dict is not None:
            # Derive _runtime from _timestamp when the row lacks it.
            if "_runtime" not in history_dict:
                self._history_assign_runtime(record.history, history_dict)

            self._history_update(record.history, history_dict)
            self._dispatch_record(record)
            self._save_history(record.history)

            # Forward only the summary keys that actually changed.
            updated_keys = self._update_summary(history_dict)
            if updated_keys:
                updated_items = {k: self._consolidated_summary[k] for k in updated_keys}
                self._save_summary(updated_items)
|
|
|
    def _flush_partial_history(
        self,
        step: Optional[int] = None,
    ) -> None:
        """Convert accumulated partial-history items into one history record.

        Args:
            step: Explicit step to stamp on the record; when None the step
                is assigned by handle_history().
        """
        if not self._partial_history:
            return

        history = HistoryRecord()
        for k, v in self._partial_history.items():
            item = history.item.add()
            item.key = k
            item.value_json = json.dumps(v)
        if step is not None:
            history.step.num = step
        self.handle_history(Record(history=history))
        self._partial_history = {}
|
|
|
    def handle_request_sender_mark_report(self, record: Record) -> None:
        """Forward a sender mark report; always sent."""
        self._dispatch_record(record, always_send=True)

    def handle_request_status_report(self, record: Record) -> None:
        """Forward a status report; always sent."""
        self._dispatch_record(record, always_send=True)
|
|
|
    def handle_request_partial_history(self, record: Record) -> None:
        """Accumulate partial history items, flushing when a row completes.

        Steps must not decrease: a smaller step than the current one drops
        the entry (with a one-time generic warning), a larger step flushes
        the buffered row first.
        """
        partial_history = record.request.partial_history

        flush = None
        if partial_history.HasField("action"):
            flush = partial_history.action.flush

        step = None
        if partial_history.HasField("step"):
            step = partial_history.step.num

        history_dict = proto_util.dict_from_proto_list(partial_history.item)
        if step is not None:
            if step < self._step:
                # Non-monotonic step: warn (once generically, then per-drop)
                # and discard the entry.
                if not self._dropped_history:
                    message = (
                        "Step only supports monotonically increasing values, use define_metric to set a custom x "
                        f"axis. For details see: {url_registry.url('define-metric')}"
                    )
                    self._internal_messages.warning.append(message)
                    self._dropped_history = True
                message = (
                    f"(User provided step: {step} is less than current step: {self._step}. "
                    f"Dropping entry: {history_dict})."
                )
                self._internal_messages.warning.append(message)
                return
            elif step > self._step:
                # A new step implicitly completes the buffered row.
                self._flush_partial_history()
                self._step = step
        elif flush is None:
            # No step and no explicit action: flush by default.
            flush = True

        self._partial_history.update(history_dict)

        if flush:
            self._flush_partial_history(self._step)
|
|
|
def handle_summary(self, record: Record) -> None: |
|
summary = record.summary |
|
for item in summary.update: |
|
if len(item.nested_key) > 0: |
|
|
|
assert item.key == "" |
|
key = tuple(item.nested_key) |
|
else: |
|
|
|
|
|
key = (item.key,) |
|
|
|
target = self._consolidated_summary |
|
|
|
|
|
for prop in key[:-1]: |
|
target = target[prop] |
|
|
|
|
|
target[key[-1]] = json.loads(item.value_json) |
|
|
|
for item in summary.remove: |
|
if len(item.nested_key) > 0: |
|
|
|
assert item.key == "" |
|
key = tuple(item.nested_key) |
|
else: |
|
|
|
|
|
key = (item.key,) |
|
|
|
target = self._consolidated_summary |
|
|
|
|
|
for prop in key[:-1]: |
|
target = target[prop] |
|
|
|
|
|
del target[key[-1]] |
|
|
|
self._save_summary(self._consolidated_summary) |
|
|
|
    def handle_exit(self, record: Record) -> None:
        """Stamp the accumulated runtime onto the exit record and forward it."""
        if self._track_time is not None:
            self._accumulate_time += time.time() - self._track_time
        record.exit.runtime = int(self._accumulate_time)
        self._dispatch_record(record, always_send=True)

    def handle_final(self, record: Record) -> None:
        """Forward the final record; always sent."""
        self._dispatch_record(record, always_send=True)

    def handle_preempting(self, record: Record) -> None:
        """Forward a preemption notice downstream."""
        self._dispatch_record(record)

    def handle_header(self, record: Record) -> None:
        """Forward the header record downstream."""
        self._dispatch_record(record)

    def handle_footer(self, record: Record) -> None:
        """Forward the footer record downstream."""
        self._dispatch_record(record)
|
|
|
    def handle_request_attach(self, record: Record) -> None:
        """Reply with the stored run proto so another process can attach."""
        result = proto_util._result_from_record(record)
        attach_id = record.request.attach.attach_id
        assert attach_id
        # A run must already have been registered for attach to succeed.
        assert self._run_proto
        result.response.attach_response.run.CopyFrom(self._run_proto)
        self._respond_result(result)

    def handle_request_log_artifact(self, record: Record) -> None:
        """Forward an artifact-log request downstream."""
        self._dispatch_record(record)

    def handle_telemetry(self, record: Record) -> None:
        """Forward a telemetry record downstream."""
        self._dispatch_record(record)
|
|
|
    def handle_request_run_start(self, record: Record) -> None:
        """Initialize per-run state: timing, system monitor, TB watcher.

        Responds to the client once initialization is complete.
        """
        run_start = record.request.run_start
        assert run_start
        assert run_start.run

        self._run_proto = run_start.run

        self._run_start_time = run_start.run.start_time.ToMicroseconds() / 1e6

        # Start the runtime clock; resumed runs carry over prior runtime.
        self._track_time = time.time()
        if run_start.run.resumed and run_start.run.runtime:
            self._accumulate_time = run_start.run.runtime
        else:
            self._accumulate_time = 0

        # The system monitor is created unconditionally (its buffer backs
        # get_system_metrics) but only started when stats are enabled.
        self._system_monitor = SystemMonitor(
            self._settings,
            self._interface,
        )
        if not (
            self._settings.x_disable_stats or self._settings.x_disable_machine_info
        ):
            self._system_monitor.start()
        # Probe system metadata once for fresh (non-resumed) runs.
        if (
            not (self._settings.x_disable_meta or self._settings.x_disable_machine_info)
            and not run_start.run.resumed
        ):
            try:
                self._metadata = Metadata(**self._system_monitor.probe(publish=True))
            except Exception:
                logger.exception("Error probing system metadata.")

        self._tb_watcher = tb_watcher.TBWatcher(
            self._settings, interface=self._interface, run_proto=run_start.run
        )

        # Resumed/forked runs continue from their recorded step.
        if run_start.run.resumed or run_start.run.forked:
            self._step = run_start.run.starting_step
        result = proto_util._result_from_record(record)
        self._respond_result(result)
|
|
|
    def handle_request_resume(self, record: Record) -> None:
        """Restart system metrics and the runtime clock after a pause."""
        if self._system_monitor is not None:
            logger.info("starting system metrics thread")
            self._system_monitor.start()

        # Bank any elapsed time, then restart the clock.
        if self._track_time is not None:
            self._accumulate_time += time.time() - self._track_time
        self._track_time = time.time()

    def handle_request_pause(self, record: Record) -> None:
        """Stop system metrics and bank the elapsed runtime."""
        if self._system_monitor is not None:
            logger.info("stopping system metrics thread")
            self._system_monitor.finish()
        if self._track_time is not None:
            self._accumulate_time += time.time() - self._track_time
            self._track_time = None
|
|
|
    def handle_request_poll_exit(self, record: Record) -> None:
        """Forward a poll-exit request; always sent."""
        self._dispatch_record(record, always_send=True)

    def handle_request_stop_status(self, record: Record) -> None:
        """Forward a stop-status request downstream."""
        self._dispatch_record(record)

    def handle_request_network_status(self, record: Record) -> None:
        """Forward a network-status request downstream."""
        self._dispatch_record(record)

    def handle_request_internal_messages(self, record: Record) -> None:
        """Reply with buffered internal warnings, then clear the buffer."""
        result = proto_util._result_from_record(record)
        result.response.internal_messages_response.messages.CopyFrom(
            self._internal_messages
        )
        self._internal_messages.Clear()
        self._respond_result(result)

    def handle_request_status(self, record: Record) -> None:
        """Acknowledge a status request with an empty result."""
        result = proto_util._result_from_record(record)
        self._respond_result(result)
|
|
|
def handle_request_get_summary(self, record: Record) -> None: |
|
result = proto_util._result_from_record(record) |
|
for key, value in self._consolidated_summary.items(): |
|
item = SummaryItem() |
|
item.key = key |
|
item.value_json = json.dumps(value) |
|
result.response.get_summary_response.item.append(item) |
|
self._respond_result(result) |
|
|
|
    def handle_request_get_system_metrics(self, record: Record) -> None:
        """Reply with the buffered system metric samples.

        NOTE(review): when the system monitor was never created this returns
        without responding, leaving the request unanswered and its context
        unreleased — confirm callers tolerate a missing response.
        """
        result = proto_util._result_from_record(record)
        if self._system_monitor is None:
            return

        buffer = self._system_monitor.buffer
        for key, samples in buffer.items():
            buff = []
            for s in samples:
                sms = SystemMetricSample()
                # Each sample is indexed as (timestamp_seconds, value).
                sms.timestamp.FromMicroseconds(int(s[0] * 1e6))
                sms.value = s[1]
                buff.append(sms)

            result.response.get_system_metrics_response.system_metrics[key].CopyFrom(
                SystemMetricsBuffer(record=buff)
            )

        self._respond_result(result)

    def handle_request_get_system_metadata(self, record: Record) -> None:
        """Reply with probed system metadata.

        NOTE(review): like get_system_metrics, this returns without
        responding when no monitor/metadata exists — confirm callers
        tolerate a missing response.
        """
        result = proto_util._result_from_record(record)
        if self._system_monitor is None or self._metadata is None:
            return

        result.response.get_system_metadata_response.metadata.CopyFrom(
            self._metadata.to_proto()
        )
        self._respond_result(result)
|
|
|
    def handle_tbrecord(self, record: Record) -> None:
        """Register a tensorboard log dir with the watcher and forward."""
        logger.info("handling tbrecord: %s", record)
        if self._tb_watcher:
            tbrecord = record.tbrecord
            self._tb_watcher.add(tbrecord.log_dir, tbrecord.save, tbrecord.root_dir)
        self._dispatch_record(record)
|
|
|
    def _handle_defined_metric(self, record: Record) -> None:
        """Merge (or overwrite) a fully named metric definition and forward it."""
        metric = record.metric
        if metric._control.overwrite:
            self._metric_defines[metric.name].CopyFrom(metric)
        else:
            self._metric_defines[metric.name].MergeFrom(metric)

        # If the merged metric references a step metric that has not been
        # defined yet, auto-define it so it is tracked as well.
        metric = self._metric_defines[metric.name]
        if metric.step_metric and metric.step_metric not in self._metric_defines:
            m = MetricRecord(name=metric.step_metric)
            self._metric_defines[metric.step_metric] = m
            mr = Record()
            mr.metric.CopyFrom(m)
            mr.control.local = True
            self._dispatch_record(mr)

        self._dispatch_record(record)

    def _handle_glob_metric(self, record: Record) -> None:
        """Merge (or overwrite) a wildcard metric definition and forward it."""
        metric = record.metric
        if metric._control.overwrite:
            self._metric_globs[metric.glob_name].CopyFrom(metric)
        else:
            self._metric_globs[metric.glob_name].MergeFrom(metric)
        self._dispatch_record(record)
|
|
|
    def handle_metric(self, record: Record) -> None:
        """Handle MetricRecord.

        Walkthrough of the life of a MetricRecord:

        Metric defined:
        - run.define_metric() parses arguments to create a wandb_metric.Metric
        - build MetricRecord publish to interface
        - handler (this function) keeps list of metrics published:
          - self._metric_defines: Fully defined metrics
          - self._metric_globs: metrics that have a wildcard
        - dispatch writer and sender thread
          - writer: records are saved to persistent store
          - sender: fully defined metrics get mapped into metadata for UI

        History logged:
        - handle_history
          - check if metric matches _metric_defines
          - if not, check if metric matches _metric_globs
          - if _metric_globs match, generate defined metric and call _handle_metric

        Args:
            record (Record): Metric record to process
        """
        if record.metric.name:
            self._handle_defined_metric(record)
        elif record.metric.glob_name:
            self._handle_glob_metric(record)
|
|
|
    def handle_request_sampled_history(self, record: Record) -> None:
        """Reply with the down-sampled history values per key.

        All-integer series go into ``values_int``; any other all-real
        series go into ``values_float``.
        """
        result = proto_util._result_from_record(record)
        for key, sampled in self._sampled_history.items():
            item = SampledHistoryItem()
            item.key = key
            values: Iterable[Any] = sampled.get()
            if all(isinstance(i, numbers.Integral) for i in values):
                try:
                    item.values_int.extend(values)
                except ValueError:
                    # Values that don't fit the proto int field are dropped
                    # silently for this key.
                    pass
            elif all(isinstance(i, numbers.Real) for i in values):
                item.values_float.extend(values)
            result.response.sampled_history_response.item.append(item)
        self._respond_result(result)
|
|
|
    def handle_request_keepalive(self, record: Record) -> None:
        """Handle a keepalive request.

        Keepalive is a noop, we just want to verify transport is alive.
        """

    def handle_request_run_status(self, record: Record) -> None:
        """Forward a run-status request; always sent."""
        self._dispatch_record(record, always_send=True)

    def handle_request_shutdown(self, record: Record) -> None:
        """Acknowledge shutdown, then signal the handler loop to stop."""
        result = proto_util._result_from_record(record)
        self._respond_result(result)
        self._stopped.set()

    def handle_request_operations(self, record: Record) -> None:
        """No-op. Not implemented for the legacy-service."""
        self._respond_result(proto_util._result_from_record(record))
|
|
|
    def finish(self) -> None:
        """Shut down helper threads owned by the handler."""
        logger.info("shutting down handler")
        if self._system_monitor is not None:
            self._system_monitor.finish()
        if self._tb_watcher:
            self._tb_watcher.finish()

    def __next__(self) -> Record:
        """Block until the next inbound record is available and return it."""
        return self._record_q.get(block=True)

    # Allow `manager.next()` as an alias for __next__.
    next = __next__
|
|
|
    def _history_assign_runtime(
        self,
        history: HistoryRecord,
        history_dict: Dict[str, Any],
    ) -> None:
        """Derive ``_runtime`` from ``_timestamp`` and append it to the row."""
        # Without a timestamp there is nothing to derive from.
        if "_timestamp" not in history_dict:
            return

        # If the run-start time is unknown, anchor it to the first
        # timestamp seen so runtimes start at zero.
        if self._run_start_time is None:
            self._run_start_time = history_dict["_timestamp"]
        history_dict["_runtime"] = history_dict["_timestamp"] - self._run_start_time
        item = history.item.add()
        item.key = "_runtime"
        item.value_json = json.dumps(history_dict[item.key])
|
|