# Copyright (c) Microsoft Corporation. # SPDX-License-Identifier: Apache-2.0 # DeepSpeed Team from packaging import version as pkg_version import torch def required_torch_version(min_version=None, max_version=None): assert min_version or max_version, "Must provide a min_version or max_version argument" torch_version = pkg_version.parse(torch.__version__) if min_version and pkg_version.parse(str(min_version)) > torch_version: return False if max_version and pkg_version.parse(str(max_version)) < torch_version: return False return True def register_grad_hook(param, hook): if required_torch_version(min_version=2.1): return param.register_post_accumulate_grad_hook(hook) else: param_tmp = param.expand_as(param) grad_acc = param_tmp.grad_fn.next_functions[0][0] return grad_acc.register_hook(hook)