|
|
|
from typing import Any, Union

import torch
from torch.utils._contextlib import (
    _DecoratorContextManager,
    _NoParamDecoratorContextManager,
    F,
)

__all__ = [
    "no_grad",
    "enable_grad",
    "set_grad_enabled",
    "inference_mode",
    "set_multithreading_enabled",
]
|
|
|
|
|
class no_grad(_NoParamDecoratorContextManager):
    r"""Context-manager that disables gradient calculation.

    Disabling gradient calculation is useful for inference, when you are sure
    that you will not call :meth:`Tensor.backward()`. It will reduce memory
    consumption for computations that would otherwise have `requires_grad=True`.

    In this mode, the result of every computation will have
    `requires_grad=False`, even when the inputs have `requires_grad=True`.
    There is an exception! All factory functions, or functions that create
    a new Tensor and take a ``requires_grad`` kwarg, will NOT be affected by
    this mode.

    This context manager is thread local; it will not affect computation
    in other threads.

    Also functions as a decorator.

    .. note::
        No-grad is one of several mechanisms that can enable or
        disable gradients locally; see :ref:`locally-disable-grad-doc` for
        more information on how they compare.

    .. note::
        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
        If you want to disable forward AD for a computation, you can unpack
        your dual tensors.

    Example::
        >>> # xdoctest: +SKIP
        >>> x = torch.tensor([1.], requires_grad=True)
        >>> with torch.no_grad():
        ...     y = x * 2
        >>> y.requires_grad
        False
        >>> @torch.no_grad()
        ... def doubler(x):
        ...     return x * 2
        >>> z = doubler(x)
        >>> z.requires_grad
        False
        >>> @torch.no_grad()
        ... def tripler(x):
        ...     return x * 3
        >>> z = tripler(x)
        >>> z.requires_grad
        False
        >>> # factory function exception
        >>> with torch.no_grad():
        ...     a = torch.nn.Parameter(torch.rand(10))
        >>> a.requires_grad
        True
    """
|
|
|
    def __init__(self) -> None:
        if not torch._jit_internal.is_scripting():
            super().__init__()
        self.prev = False

    def __enter__(self) -> None:
        self.prev = torch.is_grad_enabled()
        torch.set_grad_enabled(False)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch.set_grad_enabled(self.prev)
|
|
|
|
|
class enable_grad(_NoParamDecoratorContextManager):
    r"""Context-manager that enables gradient calculation.

    Enables gradient calculation, if it has been disabled via :class:`~no_grad`
    or :class:`~set_grad_enabled`.

    This context manager is thread local; it will not affect computation
    in other threads.

    Also functions as a decorator.

    .. note::
        enable_grad is one of several mechanisms that can enable or
        disable gradients locally; see :ref:`locally-disable-grad-doc` for
        more information on how they compare.

    .. note::
        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.

    Example::
        >>> # xdoctest: +SKIP
        >>> x = torch.tensor([1.], requires_grad=True)
        >>> with torch.no_grad():
        ...     with torch.enable_grad():
        ...         y = x * 2
        >>> y.requires_grad
        True
        >>> y.backward()
        >>> x.grad
        tensor([2.])
        >>> @torch.enable_grad()
        ... def doubler(x):
        ...     return x * 2
        >>> with torch.no_grad():
        ...     z = doubler(x)
        >>> z.requires_grad
        True
        >>> @torch.enable_grad()
        ... def tripler(x):
        ...     return x * 3
        >>> with torch.no_grad():
        ...     z = tripler(x)
        >>> z.requires_grad
        True
    """
|
|
|
    def __enter__(self) -> None:
        self.prev = torch.is_grad_enabled()
        torch._C._set_grad_enabled(True)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch._C._set_grad_enabled(self.prev)
|
|
|
|
|
class set_grad_enabled(_DecoratorContextManager):
    r"""Context-manager that sets gradient calculation on or off.

    ``set_grad_enabled`` will enable or disable grads based on its argument :attr:`mode`.
    It can be used as a context-manager or as a function.

    This context manager is thread local; it will not affect computation
    in other threads.

    Args:
        mode (bool): Flag whether to enable grad (``True``), or disable
            (``False``). This can be used to conditionally enable
            gradients.

    .. note::
        set_grad_enabled is one of several mechanisms that can enable or
        disable gradients locally; see :ref:`locally-disable-grad-doc` for
        more information on how they compare.

    .. note::
        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.

    Example::
        >>> # xdoctest: +SKIP
        >>> x = torch.tensor([1.], requires_grad=True)
        >>> is_train = False
        >>> with torch.set_grad_enabled(is_train):
        ...     y = x * 2
        >>> y.requires_grad
        False
        >>> _ = torch.set_grad_enabled(True)
        >>> y = x * 2
        >>> y.requires_grad
        True
        >>> _ = torch.set_grad_enabled(False)
        >>> y = x * 2
        >>> y.requires_grad
        False
    """
|
|
|
    def __init__(self, mode: bool) -> None:
        self.prev = torch.is_grad_enabled()
        self.mode = mode
        torch._C._set_grad_enabled(mode)

    def __call__(self, orig_func: F) -> F:
        # When used as a decorator, undo the global toggle done by __init__;
        # the wrapped function enters the desired mode on each call instead.
        torch._C._set_grad_enabled(self.prev)
        return super().__call__(orig_func)

    def __enter__(self) -> None:
        torch._C._set_grad_enabled(self.mode)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch._C._set_grad_enabled(self.prev)

    def clone(self) -> "set_grad_enabled":
        r"""
        Create a copy of this class
        """
        return self.__class__(self.mode)
|
|
|
|
|
class inference_mode(_DecoratorContextManager):
    r"""Context-manager that enables or disables inference mode.

    InferenceMode is a context manager analogous to :class:`~no_grad`
    to be used when you are certain your operations will have no interactions
    with autograd (e.g., you will not use the results for model training).
    Code run under this mode gets better performance by disabling view tracking
    and version counter bumps. Note that unlike some other mechanisms that
    locally enable or disable grad, entering inference_mode also disables
    :ref:`forward-mode AD <forward-mode-ad>`.

    This context manager is thread local; it will not affect computation
    in other threads.

    Also functions as a decorator.

    .. note::
        Inference mode is one of several mechanisms that can enable or
        disable gradients locally; see :ref:`locally-disable-grad-doc` for
        more information on how they compare.

    Args:
        mode (bool or function): Either a boolean flag whether to enable or
            disable inference mode or a Python function to decorate with
            inference mode enabled

    Example::
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> import torch
        >>> x = torch.ones(1, 2, 3, requires_grad=True)
        >>> with torch.inference_mode():
        ...     y = x * x
        >>> y.requires_grad
        False
        >>> # xdoctest: +SKIP("want string isn't quite right")
        >>> y._version
        Traceback (most recent call last):
        File "<stdin>", line 1, in <module>
        RuntimeError: Inference tensors do not track version counter.
        >>> @torch.inference_mode()
        ... def func(x):
        ...     return x * x
        >>> out = func(x)
        >>> out.requires_grad
        False
        >>> @torch.inference_mode()
        ... def doubler(x):
        ...     return x * 2
        >>> out = doubler(x)
        >>> out.requires_grad
        False
    """
|
|
|
    def __init__(self, mode: bool = True) -> None:
        if not torch._jit_internal.is_scripting():
            super().__init__()
        self.mode = mode

    def __new__(cls, mode=True):
        # Support both `@torch.inference_mode()` and bare `@torch.inference_mode`:
        # if `mode` is a function rather than a bool, decorate it directly.
        if isinstance(mode, bool):
            return super().__new__(cls)
        return cls()(mode)

    def __enter__(self) -> None:
        self._inference_mode_context = torch._C._InferenceMode(self.mode)
        self._inference_mode_context.__enter__()

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        self._inference_mode_context.__exit__(exc_type, exc_value, traceback)

    def clone(self) -> "inference_mode":
        r"""
        Create a copy of this class
        """
        return self.__class__(self.mode)
|
|
|
|
|
def _enter_inference_mode(mode):
    # Functional counterpart of `inference_mode.__enter__`: create the
    # underlying `torch._C._InferenceMode` guard, enter it, and return it so
    # the caller can exit it later.
    mode_context = torch._C._InferenceMode(mode)
    mode_context.__enter__()
    return mode_context


def _exit_inference_mode(mode):
    # Exit a guard previously returned by `_enter_inference_mode`.
    mode.__exit__(None, None, None)
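# Illustrative sketch (comment only; `model` and `inp` are hypothetical names):
# these helpers mirror `inference_mode.__enter__`/`__exit__` for callers that
# need explicit enter/exit calls rather than a `with` block, e.g.
#
#     ctx = _enter_inference_mode(True)
#     try:
#         out = model(inp)
#     finally:
#         _exit_inference_mode(ctx)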
|
|
|
|
|
class set_multithreading_enabled(_DecoratorContextManager):
    r"""Context-manager that sets multithreaded backwards on or off.

    ``set_multithreading_enabled`` will enable or disable multithreaded backwards based on its argument :attr:`mode`.
    It can be used as a context-manager or as a function.

    This context manager is thread local; it will not affect computation
    in other threads.

    Args:
        mode (bool): Flag whether to enable multithreaded backwards (``True``), or disable
            (``False``).

    .. note::
        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
|
|
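    A minimal usage sketch (illustrative only; it assumes this manager is
    reachable as ``torch.autograd.set_multithreading_enabled`` and only changes
    how the backward pass is scheduled, not its numerical result)::

        >>> # xdoctest: +SKIP
        >>> x = torch.ones(2, 2, requires_grad=True)
        >>> with torch.autograd.set_multithreading_enabled(False):
        ...     (x * 2).sum().backward()  # backward runs single-threaded here
        >>> x.grad
        tensor([[2., 2.],
                [2., 2.]])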
|
""" |
|
|
|
    def __init__(self, mode: bool) -> None:
        self.prev = torch._C._is_multithreading_enabled()
        torch._C._set_multithreading_enabled(mode)
        self.mode = mode

    def __enter__(self) -> None:
        pass

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch._C._set_multithreading_enabled(self.prev)

    def clone(self) -> "set_multithreading_enabled":
        r"""
        Create a copy of this class
        """
        return self.__class__(self.mode)
|
|
|
|
|
class _force_original_view_tracking(_DecoratorContextManager):
    r"""Context-manager that sets whether or not to always enable view-replay in autograd.

    ``set_view_replay_enabled`` will enable or disable view-replay based on its argument :attr:`mode`.
    It can be used as a context-manager or as a function.

    This context manager is thread local; it will not affect computation
    in other threads.

    When a tensor view is mutated, the autograd engine needs to decide whether or not
    to regenerate the "updated view" by either replaying the chain of views from the updated base,
    or with a single call to as_strided.

    If set_view_replay_enabled is set to True, then autograd will always use view replay.
    Otherwise, it will fall back to its existing logic.

    Args:
        mode (bool): Flag whether to enable view-replay (``True``), or disable
            (``False``).
|
|
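    A minimal usage sketch (illustrative only; it assumes this private manager
    is reachable as ``torch.autograd._force_original_view_tracking``)::

        >>> # xdoctest: +SKIP
        >>> base = torch.zeros(4, requires_grad=True).clone()
        >>> view = base[:2]
        >>> with torch.autograd._force_original_view_tracking(True):
        ...     view.add_(1)  # the updated view is rebuilt via view replay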
|
""" |
|
|
|
    def __init__(self, mode: bool) -> None:
        self.prev = torch._C._is_view_replay_enabled()
        torch._C._set_view_replay_enabled(mode)
        self.mode = mode

    def __enter__(self) -> None:
        pass

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch._C._set_view_replay_enabled(self.prev)

    def clone(self):
        return self.__class__(self.mode)
|
|
|
|
|
class _unsafe_preserve_version_counter(_DecoratorContextManager):
    r"""DO NOT USE THIS UNLESS YOU KNOW EXACTLY WHAT YOU'RE DOING.

    This context manager can lead to arbitrary silent-correctness issues in any other part of your code
    (even the ones not touched directly by the context manager)!

    Ordinarily, autograd will track mutations to tensors by incrementing their `._version` attribute.
    This is generally important for correctness, as, for example, mutating a tensor that autograd has saved
    for the backwards pass can result in incorrect gradients, and autograd uses the version counter to detect
    and error out in this situation.

    However, there are rare instances where it might be useful to hide mutations from autograd. For example:
    if a tensor is very large, and you'd like to free its memory by storing it elsewhere, and re-populate
    the tensor right before it is needed by autograd.

    Args:
        tensors (torch.Tensor or tuple of torch.Tensor): the tensor(s) whose version
            counters you would like to preserve.

    .. note::
        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
|
|
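    A minimal usage sketch (illustrative only; it assumes this private manager
    is reachable as ``torch.autograd._unsafe_preserve_version_counter``, and
    ``t`` stands for any tensor whose mutation you deliberately want to hide
    from autograd)::

        >>> # xdoctest: +SKIP
        >>> t = torch.randn(3)
        >>> with torch.autograd._unsafe_preserve_version_counter(t):
        ...     t.copy_(torch.randn(3))  # version counter is restored on exit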
|
""" |
|
|
|
    def __init__(self, tensors: Union[torch.Tensor, tuple[torch.Tensor, ...]]) -> None:
        self.tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tensors
        assert isinstance(self.tensors, tuple)
        self.prev_versions = tuple(t._version for t in self.tensors)

    def __enter__(self) -> None:
        pass

    def __exit__(self, *args) -> None:
        torch._C._autograd._unsafe_set_version_counter(self.tensors, self.prev_versions)
|
|