File size: 746 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import functools

from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.compiler import is_compiling
# Module-wide switch consulted by ``instrument_w_nvtx`` wrappers each time a
# wrapped function is called; set it to False to stop recording NVTX ranges.
enable_nvtx: bool = True
def instrument_w_nvtx(func):
    """Decorator that records an NVTX range for the duration of the function call.

    The range is labelled with ``func.__qualname__``. It is opened with
    ``get_accelerator().range_push`` and closed with
    ``get_accelerator().range_pop``. Instrumentation is skipped when the
    module-level ``enable_nvtx`` flag is false. It is also skipped while
    torch.compile is active (``is_compiling()``), to avoid graph breaks.

    Args:
        func: The callable to wrap.

    Returns:
        A wrapper that forwards all positional and keyword arguments to
        ``func`` and returns its result. ``functools.wraps`` copies ``func``'s
        metadata (``__name__``, ``__qualname__``, ``__doc__``, ...) onto the
        wrapper and sets ``__wrapped__`` to ``func``.

    Raises:
        Whatever ``func`` raises; the NVTX range is still closed first.
    """

    @functools.wraps(func)
    def wrapped_fn(*args, **kwargs):
        # Decide once per call so the push and the pop are always paired, even
        # if ``enable_nvtx`` or the compile state changes while ``func`` runs.
        # The fast path also keeps try/finally out of the torch.compile path.
        if not enable_nvtx or is_compiling():
            return func(*args, **kwargs)

        get_accelerator().range_push(func.__qualname__)
        try:
            return func(*args, **kwargs)
        finally:
            # Pop even when ``func`` raises; otherwise the push/pop pairing is
            # broken for every range recorded afterwards.
            get_accelerator().range_pop()

    return wrapped_fn
|