# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
import functools

from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.compiler import is_compiling
# Module-level switch consulted by the wrappers that instrument_w_nvtx builds.
# It is read on every call (not captured at decoration time), so setting it to
# False skips the accelerator range_push/range_pop for all decorated functions,
# including ones that were decorated before the change.
enable_nvtx: bool = True
def instrument_w_nvtx(func):
    """Decorator that records an NVTX range for the duration of the function call.

    The range is named after ``func.__qualname__`` and is opened/closed through
    the active accelerator (``get_accelerator().range_push`` /
    ``range_pop``). Instrumentation is skipped when the module-level
    ``enable_nvtx`` flag is falsy or when ``is_compiling()`` reports that
    torch.compile is active, to avoid graph breaks.

    Args:
        func: The callable to wrap.

    Returns:
        A wrapper that forwards all positional and keyword arguments to
        ``func`` and returns its result. ``functools.wraps`` copies ``func``'s
        metadata (``__name__``, ``__qualname__``, ``__doc__``, ``__wrapped__``)
        onto the wrapper.
    """

    @functools.wraps(func)
    def wrapped_fn(*args, **kwargs):
        # Decide once per call so that a push is always matched by exactly one
        # pop, even if enable_nvtx or the compile state changes inside func.
        if not enable_nvtx or is_compiling():
            # Plain call with no try/finally, so the code seen while
            # torch.compile is tracing stays a direct call to func.
            return func(*args, **kwargs)

        accelerator = get_accelerator()
        accelerator.range_push(func.__qualname__)
        try:
            return func(*args, **kwargs)
        finally:
            # Pop even if func raises, so every range_push above is matched by
            # a range_pop on the same accelerator object.
            accelerator.range_pop()

    return wrapped_fn