# jamtur01's picture
# Upload folder using huggingface_hub
# 9c6594c verified
"""Integration with pytorch profiler."""
import os
import wandb
from wandb.errors import Error, UsageError
from wandb.sdk.lib import telemetry
# Module names resolved lazily via `wandb.util.get_module` inside
# `torch_trace_handler`, rather than imported at the top of this file.
PYTORCH_MODULE = "torch"
PYTORCH_PROFILER_MODULE = "torch.profiler"
def torch_trace_handler():
    """Create a trace handler for traces generated by the profiler.

    Provide as an argument to `torch.profiler.profile`:

    ```python
    torch.profiler.profile(..., on_trace_ready=wandb.profiler.torch_trace_handler())
    ```

    Calling this function ensures that profiler charts & tables can be viewed in
    your run dashboard on wandb.ai.

    Please note that `wandb.init()` must be called before this function is
    invoked, and the reinit setting must not be set to "create_new". The PyTorch
    (torch) version must also be at least 1.9, in order to ensure stability of
    their Profiler API.

    Calling this function more than once within the same run is safe: the
    `pytorch_traces` directory is reused if it already exists.

    Returns:
        The handler produced by `torch.profiler.tensorboard_trace_handler`,
        configured to write traces into the `pytorch_traces` subdirectory of
        the current run's directory. Pass it as `on_trace_ready`.

    Raises:
        UsageError: If `wandb.init()` hasn't been called before profiling.
        Error: If the torch version is less than 1.9.0.

    Examples:
    ```python
    run = wandb.init()
    run.config.id = "profile_code"

    with torch.profiler.profile(
        schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=1),
        on_trace_ready=wandb.profiler.torch_trace_handler(),
        record_shapes=True,
        with_stack=True,
    ) as prof:
        for i, batch in enumerate(dataloader):
            if i >= 5:
                break
            train(batch)
            prof.step()
    ```
    """
    # Imported locally so `packaging` is only needed when the handler is used.
    from packaging.version import parse

    torch = wandb.util.get_module(PYTORCH_MODULE, required=True)
    torch_profiler = wandb.util.get_module(PYTORCH_PROFILER_MODULE, required=True)

    if parse(torch.__version__) < parse("1.9.0"):
        # Implicit concatenation instead of a backslash continuation, so source
        # indentation can never leak into the message.
        raise Error(
            "torch version must be at least 1.9 in order to use the PyTorch Profiler API."
            f"\nVersion of torch currently installed: {torch.__version__}"
        )

    # Only the attribute access can fail with AttributeError (when `wandb.run`
    # is None because `wandb.init()` was not called), so keep the try minimal.
    try:
        logdir = os.path.join(wandb.run.dir, "pytorch_traces")  # type: ignore
    except AttributeError:
        raise UsageError(
            "Please call `wandb.init()` before `wandb.profiler.torch_trace_handler()`"
        ) from None

    # `exist_ok=True` lets the handler be created repeatedly within one run
    # without failing on the already-present trace directory.
    os.makedirs(logdir, exist_ok=True)

    with telemetry.context() as tel:
        tel.feature.torch_profiler_trace = True

    return torch_profiler.tensorboard_trace_handler(logdir)