File size: 888 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from packaging import version as pkg_version
import torch
def required_torch_version(min_version=None, max_version=None):
    """Return True if the installed torch version lies within the given bounds.

    Both bounds are inclusive and either may be omitted, but at least one must
    be supplied. Note that ``max_version="2.1"`` means ``<= 2.1.0``, so it
    excludes ``2.1.1``.

    Bounds may be strings (``"2.1"``) or numbers (``2.1``). Numbers are
    converted with ``str()``, so prefer strings: the float ``2.10`` becomes
    ``"2.1"``.

    The comparison uses ``packaging.version`` ordering on the full
    ``torch.__version__`` string. Pre-release and local builds (e.g.
    ``"2.1.0a0+abc123"``) therefore sort below the corresponding release.

    Args:
        min_version: lowest acceptable torch version, or None for no lower bound.
        max_version: highest acceptable torch version, or None for no upper bound.

    Returns:
        bool: whether ``min_version <= torch.__version__ <= max_version``.

    Raises:
        ValueError: if neither bound is provided.
    """
    # An explicit raise (rather than ``assert``) keeps this check active under
    # ``python -O``, where asserts are stripped.
    if min_version is None and max_version is None:
        raise ValueError("Must provide a min_version or max_version argument")

    torch_version = pkg_version.parse(torch.__version__)

    if min_version is not None and pkg_version.parse(str(min_version)) > torch_version:
        return False

    if max_version is not None and pkg_version.parse(str(max_version)) < torch_version:
        return False

    return True
def register_grad_hook(param, hook):
    """Register ``hook`` to run after a gradient has been accumulated into ``param``.

    Two implementations are used, depending on the torch build:

    - Builds that provide ``Tensor.register_post_accumulate_grad_hook``
      (torch >= 2.1) use that API, and torch calls ``hook(param)``.
    - Older builds fall back to hooking the parameter's ``AccumulateGrad``
      autograd node. Those node hooks are called with the node's gradient
      inputs/outputs.

    Because the call signatures differ, ``hook`` should accept ``*args`` to
    work on both paths.

    Args:
        param: a leaf tensor with ``requires_grad=True`` (e.g. an ``nn.Parameter``).
        hook: callable invoked after the gradient has been accumulated.

    Returns:
        A removable handle; call ``.remove()`` to unregister the hook. On the
        legacy path, keep the handle alive for as long as the hook should fire.

    Raises:
        RuntimeError: if ``param`` does not require grad or is not a leaf tensor.
    """
    # Feature-detect instead of comparing version strings. Pre-release/local
    # builds such as "2.1.0a0+git..." parse as older than 2.1 even when this
    # API is present.
    if hasattr(param, "register_post_accumulate_grad_hook"):
        return param.register_post_accumulate_grad_hook(hook)

    # Mirror the native API's preconditions. On a non-leaf tensor the lookup
    # below would silently hook the producer node instead of AccumulateGrad.
    # A tensor that does not require grad has no grad_fn to walk at all.
    if not param.requires_grad or not param.is_leaf:
        raise RuntimeError("register_grad_hook requires a leaf tensor that requires grad")

    # expand_as produces a view whose grad_fn's first input edge is param's
    # AccumulateGrad node.
    param_tmp = param.expand_as(param)
    grad_acc = param_tmp.grad_fn.next_functions[0][0]
    handle = grad_acc.register_hook(hook)

    # Autograd only weakly references a leaf's AccumulateGrad node, and the
    # handle only weakly references the hook dict. Once param_tmp and grad_acc
    # go out of scope, the node (and the hook with it) can be freed, and a
    # hook-less node is then created on the next backward. Pin the node to the
    # handle so it lives as long as the caller keeps the handle.
    handle._grad_acc = grad_acc
    return handle
|