File size: 2,391 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
"""Integration with pytorch profiler."""

import os

import wandb
from wandb.errors import Error, UsageError
from wandb.sdk.lib import telemetry

# Import names of PyTorch and its profiler submodule. torch_trace_handler()
# passes them to wandb.util.get_module(..., required=True) when it is called.
PYTORCH_MODULE = "torch"
PYTORCH_PROFILER_MODULE = "torch.profiler"


def torch_trace_handler():
    """Create a trace handler for traces generated by the profiler.

    Provide as an argument to `torch.profiler.profile`:
    ```python
    torch.profiler.profile(..., on_trace_ready=wandb.profiler.torch_trace_handler())
    ```

    Calling this function ensures that profiler charts & tables can be viewed in
    your run dashboard on wandb.ai.

    Please note that `wandb.init()` must be called before this function is
    invoked, and the reinit setting must not be set to "create_new". The PyTorch
    (torch) version must also be at least 1.9, in order to ensure stability of
    their Profiler API.

    Traces are written to a `pytorch_traces` directory inside the current
    run's directory. The directory is created if it does not exist yet, so
    this function may be called more than once within the same run.

    Args:
        None

    Returns:
        The handler produced by `torch.profiler.tensorboard_trace_handler`
        for the `pytorch_traces` directory, suitable for passing as
        `on_trace_ready` to `torch.profiler.profile`.

    Raises:
        UsageError if wandb.init() hasn't been called before profiling.
        Error if torch version is less than 1.9.0.

    Examples:
    ```python
    run = wandb.init()
    run.config.id = "profile_code"

    with torch.profiler.profile(
        schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=1),
        on_trace_ready=wandb.profiler.torch_trace_handler(),
        record_shapes=True,
        with_stack=True,
    ) as prof:
        for i, batch in enumerate(dataloader):
            if i >= 5:
                break
            train(batch)
            prof.step()
    ```
    """
    from packaging.version import parse

    torch = wandb.util.get_module(PYTORCH_MODULE, required=True)
    torch_profiler = wandb.util.get_module(PYTORCH_PROFILER_MODULE, required=True)

    if parse(torch.__version__) < parse("1.9.0"):
        # Adjacent literals instead of a backslash continuation, so the source
        # indentation does not leak into the message as trailing spaces.
        raise Error(
            "torch version must be at least 1.9 in order to use the PyTorch "
            "Profiler API.\n"
            f"Version of torch currently installed: {torch.__version__}"
        )

    # Only the attribute lookup is guarded: `wandb.run` has no `dir` attribute
    # when no run is active, which surfaces as AttributeError.
    try:
        run_dir = wandb.run.dir  # type: ignore
    except AttributeError:
        raise UsageError(
            "Please call `wandb.init()` before `wandb.profiler.torch_trace_handler()`"
        ) from None

    logdir = os.path.join(run_dir, "pytorch_traces")
    # exist_ok=True: creating a second handler in the same run must not fail
    # with FileExistsError just because the trace directory is already there.
    os.makedirs(logdir, exist_ok=True)

    with telemetry.context() as tel:
        tel.feature.torch_profiler_trace = True

    return torch_profiler.tensorboard_trace_handler(logdir)