import copy
import dataclasses
import logging
import os
from enum import Enum
from typing import Optional, Union

from torch._inductor.remote_cache import JsonDataTy, RemoteCacheJsonSerde
from torch._inductor.runtime.runtime_utils import cache_dir
from torch.utils._appending_byte_serializer import (
    AppendingByteSerializer,
    BytesReader,
    BytesWriter,
)
from torch.utils._ordered_set import OrderedSet


log = logging.getLogger(__name__)


class CacheArtifactType(Enum):
    """
    Type of cache artifact
    """

    INDUCTOR = 0
    AUTOTUNE = 1
    AOT_AUTOGRAD = 2
    PGO = 3


@dataclasses.dataclass(frozen=True)
class CacheArtifact:
    """
    Data for each cache artifact that will be serialized and deserialized
    """

    type: CacheArtifactType
    key: str
    content: bytes = dataclasses.field(repr=False)  # Do not display potential binary

    @staticmethod
    def serialize(writer: BytesWriter, cls: "CacheArtifact") -> None:
        writer.write_uint64(cls.type.value)
        writer.write_str(cls.key)
        writer.write_bytes(cls.content)

    @staticmethod
    def deserialize(reader: BytesReader) -> "CacheArtifact":
        artifact_type = reader.read_uint64()
        key = reader.read_str()
        content = reader.read_bytes()
        return CacheArtifact(CacheArtifactType(artifact_type), key, content)
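
    # Round-trip sketch using the same serializer wiring that
    # CacheArtifactManager uses below (key/payload values are made up
    # for illustration):
    #
    #   s = AppendingByteSerializer(serialize_fn=CacheArtifact.serialize)
    #   s.extend([CacheArtifact(CacheArtifactType.INDUCTOR, "key", b"payload")])
    #   artifacts = AppendingByteSerializer.to_list(
    #       s.to_bytes(), deserialize_fn=CacheArtifact.deserialize
    #   )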


@dataclasses.dataclass
class CacheInfo:
    """
    Return value of serialize() and deserialize(); records the keys of the
    processed artifacts, grouped by type, for instrumentation
    """

    inductor_artifacts: list[str] = dataclasses.field(default_factory=list)
    autotune_artifacts: list[str] = dataclasses.field(default_factory=list)
    aot_autograd_artifacts: list[str] = dataclasses.field(default_factory=list)
    pgo_artifacts: list[str] = dataclasses.field(default_factory=list)

    def add(self, artifact: CacheArtifact) -> None:
        if artifact.type == CacheArtifactType.INDUCTOR:
            self.inductor_artifacts.append(artifact.key)
        elif artifact.type == CacheArtifactType.AUTOTUNE:
            self.autotune_artifacts.append(artifact.key)
        elif artifact.type == CacheArtifactType.AOT_AUTOGRAD:
            self.aot_autograd_artifacts.append(artifact.key)
        elif artifact.type == CacheArtifactType.PGO:
            self.pgo_artifacts.append(artifact.key)
        else:
            log.warning(f"Unsupported artifact type {artifact.type}")  # noqa: G004

    def clear(self) -> None:
        self.inductor_artifacts.clear()
        self.autotune_artifacts.clear()
        self.aot_autograd_artifacts.clear()
        self.pgo_artifacts.clear()


class CacheArtifactManager:
    """
    Lightweight manager class for collecting and processing cache artifacts for
    hot loading

    Intended lifecycle:
    - Execute code via torch.compile; this calls
        CacheArtifactManager.record_artifact on each cache artifact
    - Call CacheArtifactManager.serialize to convert all the cache artifacts
        into a portable format
    - Call CacheArtifactManager.deserialize to hot load the cache artifacts on
        a potentially different process

    NOTE: There are no forward/backward compatibility (FB/FC) guarantees;
          cache artifacts will not be used unless the code version matches.
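
    Example (an illustrative sketch; ``fn`` and ``inp`` are placeholders, and
    serialize() returns None on failure):

        torch.compile(fn)(inp)  # cache artifacts get recorded as a side effect
        result = CacheArtifactManager.serialize()
        if result is not None:
            artifact_bytes, info = result
            # ... ship artifact_bytes to the target process ...
            CacheArtifactManager.deserialize(artifact_bytes)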
    """

    # Protected by the compile_lock
    _new_cache_artifacts: list[CacheArtifact] = []
    # Keep a separate set of seen artifacts to avoid unnecessary duplicates
    # This list will not be cleared between serialize() calls
    _seen_artifacts: OrderedSet[CacheArtifact] = OrderedSet()
    # When serialize() is called, artifacts are transferred from
    # _new_cache_artifacts to the internal data structure of _serializer
    # This allows us to only pay the cost of serialization if serialize() is called
    _serializer: AppendingByteSerializer[CacheArtifact] = AppendingByteSerializer(
        serialize_fn=CacheArtifact.serialize
    )
    _cache_info: CacheInfo = CacheInfo()

    @classmethod
    def clear(cls) -> None:
        cls._new_cache_artifacts.clear()
        cls._seen_artifacts.clear()
        cls._serializer.clear()
        cls._cache_info.clear()

    @classmethod
    def record_artifact(
        cls,
        artifact_type: CacheArtifactType,
        key: str,
        content: Union[bytes, JsonDataTy],
    ) -> None:
        """
        Called from each caching operation to record the artifact in this
        "mega" list
        """
        if artifact_type == CacheArtifactType.AUTOTUNE:
            assert not isinstance(content, bytes)
            serde = RemoteCacheJsonSerde()
            content = serde.encode(content)
        assert isinstance(content, bytes)
        artifact = CacheArtifact(artifact_type, key, content)
        if artifact in cls._seen_artifacts:
            return
        log.debug("Recording %s", str(artifact))
        cls._new_cache_artifacts.append(artifact)
        cls._seen_artifacts.add(artifact)

    @classmethod
    def need_serialize(cls) -> bool:
        """
        Have we seen new artifacts since last serialize call?
        """
        return len(cls._new_cache_artifacts) != 0

    @classmethod
    def serialize(cls) -> Optional[tuple[bytes, CacheInfo]]:
        """
        Converts the "mega" list into portable format
        """
        for artifact in cls._new_cache_artifacts:
            log.debug("saving: %s", artifact)
            cls._cache_info.add(artifact)
        try:
            # We deep copy cls._cache_info since later compilations
            # can keep adding to cache_info
            info = copy.deepcopy(cls._cache_info)
            cls._serializer.extend(cls._new_cache_artifacts)
            artifact_bytes = cls._serializer.to_bytes()
            cls._new_cache_artifacts.clear()
            return artifact_bytes, info
        except Exception:
            log.warning("Failed to pickle cache artifacts", exc_info=True)
        return None

    @staticmethod
    def deserialize(serialized_artifacts: bytes) -> Optional[CacheInfo]:
        """
        Converts the portable format back into various filesystem caches
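
        Example (illustrative; ``artifact_bytes`` is the output of serialize()):

            info = CacheArtifactManager.deserialize(artifact_bytes)
            if info is not None:
                log.debug("restored inductor keys: %s", info.inductor_artifacts)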
        """
        try:
            artifacts = AppendingByteSerializer.to_list(
                serialized_artifacts, deserialize_fn=CacheArtifact.deserialize
            )
        except Exception:
            log.warning("Failed to un-pickle cache artifacts", exc_info=True)
            return None

        from torch._dynamo.pgo import write_local_impl
        from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
        from torch._inductor.codecache import FxGraphCache
        from torch._inductor.runtime.autotune_cache import _LocalAutotuneCacheBackend

        autotune_cache = _LocalAutotuneCacheBackend()

        info = CacheInfo()
        for artifact in artifacts:
            log.debug("writing: %s", artifact)
            info.add(artifact)

            if artifact.type == CacheArtifactType.INDUCTOR:
                FxGraphCache._write_to_local_cache(artifact.key, artifact.content)
            elif artifact.type == CacheArtifactType.AUTOTUNE:
                key = os.path.join(cache_dir(), artifact.key)
                autotune_cache._put(key, artifact.content)
            elif artifact.type == CacheArtifactType.AOT_AUTOGRAD:
                AOTAutogradCache._write_to_local_cache(artifact.key, artifact.content)
            elif artifact.type == CacheArtifactType.PGO:
                meta = write_local_impl(artifact.key, artifact.content)
                assert meta is not None
            else:
                log.warning(f"Unsupported artifact type {artifact.type}")  # noqa: G004
        return info