File size: 712 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
from deepspeed.utils.torch import required_torch_version
try:
from torch.compiler import is_compiling as torch_is_compiling
except ImportError:
try:
from torch._dynamo.external_utils import is_compiling as torch_is_compiling
except ImportError:
# Torch does not have compiler support
torch_is_compiling = lambda: False
def is_compile_supported():
    """Report whether the installed torch meets the minimum version for compile helpers.

    Returns:
        The result of ``required_torch_version`` called with a minimum version of 2.1.
    """
    minimum_torch = 2.1
    return required_torch_version(min_version=minimum_torch)
def disable(func):
    """Exempt ``func`` from ``torch.compile`` when compilation is supported.

    Args:
        func: The callable to exclude from compilation.

    Returns:
        ``torch.compiler.disable(func)`` when ``is_compile_supported()`` is true;
        otherwise ``func`` itself, unchanged.
    """
    if not is_compile_supported():
        # No supported compiler on this torch, so there is nothing to opt out of.
        return func
    return torch.compiler.disable(func)
def is_compiling():
    """Return whether execution is currently inside a torch compile trace.

    Delegates to the module-level ``torch_is_compiling`` predicate, which is
    resolved at import time. On torch builds without compiler support, that
    predicate is a stub that always returns ``False``.
    """
    compiling = torch_is_compiling()
    return compiling
|