"""Provides phony for arbitrary dependency in a autograd graph.""" |
|
from typing import Dict, List, Tuple |
|
|
|
import torch |
|
from torch import Tensor |
|
|
|
from .stream import default_stream, use_stream |
|
|
|
__all__: List[str] = [] |
|
|
|
|
|
_phonies: Dict[Tuple[torch.device, bool], Tensor] = {} |
|
|
|
|
|
def get_phony(device: torch.device, *, requires_grad: bool) -> Tensor: |
|
"""Gets a phony. Phony is tensor without space. It is useful to make |
|
arbitrary dependency in a autograd graph because it doesn't require any |
|
gradient accumulation. |
|
|
|
.. note:: |
|
|
|
Phonies for each device are cached. If an autograd function gets a phony |
|
internally, the phony must be detached to be returned. Otherwise, the |
|
autograd engine will mutate the cached phony in-place:: |
|
|
|
class Phonify(torch.autograd.Function): |
|
@staticmethod |
|
def forward(ctx, input): |
|
phony = get_phony(input.device, requires_grad=False) |
|
return phony.detach() # detach() is necessary. |
|
|
|
""" |
|
key = (device, requires_grad) |
|
|
|
try: |
|
phony = _phonies[key] |
|
except KeyError: |
|
with use_stream(default_stream(device)): |
            # Zero-sized on purpose: a phony carries no data, so no real
            # gradient is ever accumulated into it; it contributes only an
            # edge to the autograd graph.
            phony = torch.empty(0, device=device, requires_grad=requires_grad)

        _phonies[key] = phony

    return phony
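
# A minimal usage sketch (illustrative only; ``Depend`` and ``x`` are
# hypothetical names, not part of this module). A phony typically acts as a
# dummy input to a custom autograd function so that a dependency edge gets
# recorded in the graph without any data flowing along it:
#
#     class Depend(torch.autograd.Function):
#         @staticmethod
#         def forward(ctx, phony: Tensor, tensor: Tensor) -> Tensor:
#             # Consuming the phony records the edge; its value is unused.
#             return tensor.detach()
#
#         @staticmethod
#         def backward(ctx, grad: Tensor):
#             # No gradient flows back into the phony.
#             return None, grad
#
#     phony = get_phony(x.device, requires_grad=True)
#     y = Depend.apply(phony, x)  # y now depends on whatever produced phony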