File size: 2,518 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
"""Context Keeper."""

import logging
import threading
from typing import Dict, Optional

from wandb.proto.wandb_internal_pb2 import Record, Result

logger = logging.getLogger(__name__)


class Context:
    """Cancellation handle for a single in-flight request.

    Wraps a ``threading.Event`` that is set once ``cancel()`` has been called.
    """

    _cancel_event: threading.Event
    # TODO(debug_context) add debug setting to enable this
    # _debug_record: Optional[Record]

    def __init__(self) -> None:
        event = threading.Event()
        self._cancel_event = event
        # TODO(debug_context) see above
        # self._debug_record = None

    @property
    def cancel_event(self) -> threading.Event:
        """The event that becomes set when this context is cancelled."""
        return self._cancel_event

    def cancel(self) -> None:
        """Mark this context as cancelled by setting its event."""
        self.cancel_event.set()


def context_id_from_record(record: Record) -> str:
    """Return the context id of ``record``: its ``control.mailbox_slot``."""
    return record.control.mailbox_slot


def context_id_from_result(result: Result) -> str:
    """Return the context id of ``result``: its ``control.mailbox_slot``."""
    return result.control.mailbox_slot


class ContextKeeper:
    """Registry of live ``Context`` objects keyed by context id.

    A context id is the ``control.mailbox_slot`` value of a record. Empty ids
    mean "no context": ``add_from_record`` skips them and ``release`` ignores
    them.
    """

    # context id -> Context, for every context added and not yet released.
    _active_items: Dict[str, Context]

    def __init__(self) -> None:
        self._active_items = {}

    def add_from_record(self, record: Record) -> Optional[Context]:
        """Register a new Context for the context id carried by ``record``.

        Returns:
            The newly created Context, or None when the record carries an
            empty context id.
        """
        context_id = context_id_from_record(record)
        if not context_id:
            return None
        context_obj = self.add(context_id)

        # TODO(debug_context) see above
        # context_obj._debug_record = record

        return context_obj

    def add(self, context_id: str) -> Context:
        """Create a fresh Context and register it under ``context_id``.

        Any Context previously registered under the same id is replaced.

        Raises:
            ValueError: if ``context_id`` is empty.
        """
        # Explicit check rather than ``assert`` so the guard is not stripped
        # when Python runs with ``-O``.
        if not context_id:
            raise ValueError("context_id must be a non-empty string")
        context_obj = Context()
        self._active_items[context_id] = context_obj
        return context_obj

    def get(self, context_id: str) -> Optional[Context]:
        """Return the Context registered under ``context_id``, or None."""
        return self._active_items.get(context_id)

    def release(self, context_id: str) -> None:
        """Forget the Context registered under ``context_id``, if any.

        Empty ids and unknown ids are ignored.
        """
        if not context_id:
            return
        self._active_items.pop(context_id, None)

    def cancel(self, context_id: str) -> bool:
        """Cancel the Context registered under ``context_id``.

        Returns:
            True if a Context was found and cancelled, False otherwise.
        """
        item = self.get(context_id)
        if item is None:
            return False
        item.cancel()
        return True

    # TODO(debug_context) see above
    # def _debug_print_orphans(self, print_to_stdout: bool) -> None:
    #     for context_id, context in self._active_items.items():
    #         record = context._debug_record
    #         record_type = record.WhichOneof("record_type") if record else "unknown"
    #         message = (
    #             f"Context: {context_id} {context.cancel_event.is_set()} {record_type}"
    #         )
    #         logger.warning(message)
    #         if print_to_stdout:
    #             print(message)