|
"""Integration with pytorch profiler.""" |
|
|
|
import os |
|
|
|
import wandb |
|
from wandb.errors import Error, UsageError |
|
from wandb.sdk.lib import telemetry |
|
|
|
# Module names handed to `wandb.util.get_module` by `torch_trace_handler`;
# torch is looked up there at call time rather than imported by this module.
PYTORCH_MODULE = "torch"
PYTORCH_PROFILER_MODULE = f"{PYTORCH_MODULE}.profiler"
|
|
|
|
|
def torch_trace_handler():
    """Create a trace handler for traces generated by the PyTorch profiler.

    Provide the result as the `on_trace_ready` argument of
    `torch.profiler.profile`:

    ```python
    torch.profiler.profile(..., on_trace_ready=wandb.profiler.torch_trace_handler())
    ```

    Calling this function ensures that profiler charts & tables can be viewed in
    your run dashboard on wandb.ai.

    Please note that `wandb.init()` must be called before this function is
    invoked, and the reinit setting must not be set to "create_new". The PyTorch
    (torch) version must also be at least 1.9, in order to ensure stability of
    their Profiler API.

    Traces are written to a `pytorch_traces` directory inside the active run's
    `dir`. The directory is created if missing and reused if it already exists,
    so the handler may be created more than once during the same run.

    Args:
        None

    Returns:
        The handler returned by `torch.profiler.tensorboard_trace_handler` for
        the run's `pytorch_traces` directory.

    Raises:
        UsageError: if `wandb.init()` hasn't been called before profiling.
        Error: if the installed torch version is less than 1.9.0.

    Examples:
        ```python
        run = wandb.init()
        run.config.id = "profile_code"

        with torch.profiler.profile(
            schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=1),
            on_trace_ready=wandb.profiler.torch_trace_handler(),
            record_shapes=True,
            with_stack=True,
        ) as prof:
            for step, batch in enumerate(dataloader):
                if step >= 5:
                    break
                train(batch)
                prof.step()
        ```
    """
    from packaging.version import parse

    torch = wandb.util.get_module(PYTORCH_MODULE, required=True)
    torch_profiler = wandb.util.get_module(PYTORCH_PROFILER_MODULE, required=True)

    if parse(torch.__version__) < parse("1.9.0"):
        # Implicit concatenation keeps source indentation out of the message;
        # a backslash continuation inside the literal would embed it.
        raise Error(
            "torch version must be at least 1.9 in order to use the PyTorch "
            "Profiler API.\n"
            f"Version of torch currently installed: {torch.__version__}"
        )

    # Only the attribute access belongs in the try: when no run is active,
    # `wandb.run` has no `dir`, and that is the one failure that should be
    # reported as a missing `wandb.init()`.
    try:
        run_dir = wandb.run.dir
    except AttributeError:
        raise UsageError(
            "Please call `wandb.init()` before `wandb.profiler.torch_trace_handler()`"
        ) from None

    logdir = os.path.join(run_dir, "pytorch_traces")
    # exist_ok=True: creating a second handler in the same run (e.g. another
    # profiling session) must reuse the directory instead of failing with
    # FileExistsError, which plain os.mkdir would raise.
    os.makedirs(logdir, exist_ok=True)

    with telemetry.context() as tel:
        tel.feature.torch_profiler_trace = True

    return torch_profiler.tensorboard_trace_handler(logdir)
|
|