File size: 5,711 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
"""
Dynamo profiling implementation.
This module provides profiling functionality for Dynamo, including:
- ProfileMetrics: Class for collecting and aggregating performance metrics like
execution time, operator counts, and fusion statistics
- ProfileResult: Class for analyzing and reporting profiling results
- Utilities for tracking missed/uncaptured operations
- Functions for instrumenting FX graphs with profiling capabilities
The profiler helps measure and optimize the performance of Dynamo-compiled code
by tracking both captured and total operations, timing, and graph statistics.
"""
import dataclasses
import os
from typing import Any
from typing_extensions import Self
import torch
from .utils import print_once
@dataclasses.dataclass
class ProfileMetrics:
    """Aggregated performance counters for a profiling run.

    Fields: wall time in ``microseconds``, number of ``operators``,
    number of ``fusions``, and number of ``graphs`` (graph calls).
    NOTE(review): the addition operators intentionally do not accumulate
    ``graphs`` — it appears to be set once per result rather than summed;
    confirm before changing.
    """

    microseconds: float = 0.0
    operators: int = 0
    fusions: int = 0
    graphs: int = 0

    def __iadd__(self, other: "ProfileMetrics") -> "ProfileMetrics":
        # In-place accumulate; ``graphs`` is deliberately left untouched.
        self.microseconds += other.microseconds
        self.operators += other.operators
        self.fusions += other.fusions
        return self

    def __add__(self, other: "ProfileMetrics") -> "ProfileMetrics":
        assert isinstance(other, ProfileMetrics)
        us = self.microseconds + other.microseconds
        ops = self.operators + other.operators
        fus = self.fusions + other.fusions
        return ProfileMetrics(us, ops, fus)

    def __truediv__(self, other: Any) -> "ProfileMetrics":
        # Dividing by an int scales every field by the same denominator.
        divisor = ProfileMetrics(other, other, other) if isinstance(other, int) else other
        return ProfileMetrics(
            microseconds=self.microseconds / max(1, divisor.microseconds),
            operators=self.operators / max(1, divisor.operators),
            fusions=self.fusions / max(1, divisor.fusions),
        )

    def __str__(self) -> str:
        # Fields are expected to hold ratios here (after __truediv__).
        return "{:4.0%} ops {:4.0%} time".format(self.operators, self.microseconds)

    def tocsv(self) -> list[float]:
        """Return the [operators, microseconds] pair for CSV output."""
        return [self.operators, self.microseconds]
class ProfileResult:
    """Pairs captured-vs-total ProfileMetrics plus a unique-graph count."""

    def __init__(
        self, captured: ProfileMetrics, total: ProfileMetrics, unique_graphs: int
    ) -> None:
        # Falsy (e.g. None) metrics fall back to empty ProfileMetrics.
        self.captured: ProfileMetrics = captured if captured else ProfileMetrics()
        self.total: ProfileMetrics = total if total else ProfileMetrics()
        self.unique_graphs: int = unique_graphs

    def __iadd__(self, other: "ProfileResult") -> "ProfileResult":
        # Accumulate another result into this one, in place.
        self.captured += other.captured
        self.total += other.total
        self.unique_graphs += other.unique_graphs
        return self

    def percent(self) -> ProfileMetrics:
        """Return per-field ratios of captured over total work."""
        return self.captured / self.total

    def __str__(self) -> str:
        header = (
            f"{self.unique_graphs:2} graphs {self.captured.graphs:2} graph calls "
            f"{self.captured.operators:4}/{self.total.operators:4} = "
        )
        return header + str(self.percent())

    def tocsv(self) -> list[Any]:
        """Flatten counts plus the percent row for CSV output."""
        row: list[Any] = [
            self.unique_graphs,
            self.captured.graphs,
            self.captured.operators,
            self.total.operators,
        ]
        return row + self.percent().tocsv()
def should_print_missing() -> bool:
    """Return True when ``TORCHDYNAMO_PRINT_MISSING=1`` is set in the env."""
    return os.environ.get("TORCHDYNAMO_PRINT_MISSING", "") == "1"
def print_missing(stack: list[str]) -> None:
    """Print (once) the innermost user frames of an uncaptured op's stack.

    Stacks that come from the autograd profiler itself are skipped
    entirely; built-in frames and torch-internal frames are filtered out
    before the last three remaining frames are printed.
    """
    for frame in stack:
        if "/torch/autograd/profiler.py" in frame:
            return
    user_frames = [
        frame
        for frame in stack
        if "<built-in" not in frame and "site-packages/torch/" not in frame
    ]
    print_once("MISSING", " >> ".join(user_frames[-3:]))
class Profiler:
    """Wraps ``torch.profiler`` to measure how much work Dynamo captured.

    Records CPU events (with stack traces only when missing-op printing
    is enabled) and post-processes them in :meth:`results` into a
    :class:`ProfileResult` comparing ops/time inside TORCHDYNAMO regions
    against the run's totals.
    """

    # Incremented by fx_insert_profiling() once per compiled graph;
    # read and reset by results().
    unique_graphs: int = 0

    def __init__(self) -> None:
        self.prof = torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU],
            # Stacks are only needed when we will print missing ops.
            with_stack=should_print_missing(),
        )

    def results(self) -> ProfileResult:
        """Aggregate the recorded events into a :class:`ProfileResult`.

        Walks events in start-time order, counting only top-level ops
        (those not nested inside a previously seen op) and classifying
        each as "captured" when it ends inside a TORCHDYNAMO region.
        """
        captured_regions = 0
        captured_ops = 0
        captured_microseconds = 0
        total_ops = 0
        total_microseconds = 0
        # End time of the most recent top-level op; anything starting
        # before this is a nested (recursive) op and is skipped.
        last_op_end_time = -1
        # End time of the current TORCHDYNAMO region; ops ending at or
        # before this timestamp count as captured.
        captured_region_end_time = -1
        events = sorted(self.prof.events(), key=lambda x: x.time_range.start)
        for e in events:
            if e.name == "TORCHDYNAMO":
                captured_region_end_time = e.time_range.end
                captured_regions += 1
                # ignore `handle = torch.zeros(1)` in record_function.__init__()
                total_ops -= 1
            elif e.time_range.start >= last_op_end_time:
                # Top-level op: starts after the previous one finished.
                last_op_end_time = e.time_range.end
                if e.time_range.end <= captured_region_end_time:
                    captured_ops += 1
                    captured_microseconds += e.time_range.elapsed_us()
                elif should_print_missing():
                    print_missing(e.stack)
                total_ops += 1
                total_microseconds += e.time_range.elapsed_us()
            else:
                pass  # ops recursively called from other ops (ignored)
        # Consume and reset the class-level counter for the next run.
        unique_graphs = Profiler.unique_graphs
        Profiler.unique_graphs = 0
        # we counted one extra op that is part of the profiler setup code
        total_ops -= 1
        return ProfileResult(
            captured=ProfileMetrics(
                microseconds=captured_microseconds,
                operators=captured_ops,
                # Each captured region merges its ops into one graph call,
                # so fusions = ops saved relative to eager execution.
                fusions=captured_ops - captured_regions,
                graphs=captured_regions,
            ),
            total=ProfileMetrics(
                microseconds=total_microseconds,
                operators=total_ops,
                fusions=total_ops - 1,
            ),
            unique_graphs=unique_graphs,
        )
def fx_insert_profiling(gm: torch.fx.GraphModule, example_inputs: list[Any]) -> Any:
    """Dynamo backend that wraps ``gm``'s forward in a TORCHDYNAMO region.

    Each call also bumps ``Profiler.unique_graphs`` so that
    ``Profiler.results()`` can report how many distinct graphs were
    compiled.  ``example_inputs`` is accepted for the backend interface
    but unused.
    """
    Profiler.unique_graphs += 1

    def _wrapped(*args: Any) -> Any:
        # Label the region so the profiler can attribute captured ops.
        with torch.profiler.record_function("TORCHDYNAMO"):
            return gm.forward(*args)

    return _wrapped
|