# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import functools

from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.compiler import is_compiling

# Module-level switch read on every call of a decorated function; when False,
# decorated functions run without emitting NVTX ranges.
enable_nvtx = True


def instrument_w_nvtx(func):
    """Decorator that records an NVTX range for the duration of the function call.

    The range is labelled with ``func.__qualname__``. Instrumentation is
    skipped when ``enable_nvtx`` is False, or when torch.compile is active
    (to avoid graph breaks).

    Args:
        func: The callable to instrument.

    Returns:
        A wrapper that forwards all positional and keyword arguments to
        ``func`` and returns its result unchanged. Exceptions raised by
        ``func`` propagate to the caller after the range has been closed.
    """

    @functools.wraps(func)
    def wrapped_fn(*args, **kwargs):
        # Decide once per call so a push is always paired with a pop, even if
        # ``enable_nvtx`` or the compile state changes while ``func`` runs.
        # Keeping the non-instrumented path a plain call also keeps the
        # try/finally below out of the way when compiling.
        if not enable_nvtx or is_compiling():
            return func(*args, **kwargs)

        accelerator = get_accelerator()
        accelerator.range_push(func.__qualname__)
        try:
            return func(*args, **kwargs)
        finally:
            # Close the range even when ``func`` raises; otherwise the NVTX
            # range stack is left unbalanced for every later range.
            accelerator.range_pop()

    return wrapped_fn