# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Arbitrary dependency between two autograd lanes."""
from typing import List, Tuple

import torch
from torch import Tensor

from .phony import get_phony

__all__: List[str] = []


def fork(input: Tensor) -> Tuple[Tensor, Tensor]:
    """Branches out from an autograd lane of the given tensor.

    Returns a pair ``(output, phony)``. When autograd is recording and
    ``input`` requires grad, both tensors are produced by a single
    :class:`Fork` node; the ``phony`` is meant to be handed to :func:`join`
    to attach another lane to this point of ``input``'s lane. Otherwise
    ``input`` is returned as-is, paired with
    ``get_phony(input.device, requires_grad=False)``.
    """
    grad_recorded = torch.is_grad_enabled() and input.requires_grad
    if not grad_recorded:
        # Nothing to wire into the graph: hand the input back untouched.
        return input, get_phony(input.device, requires_grad=False)

    output, phony = Fork.apply(input)
    return output, phony


class Fork(torch.autograd.Function):
    """Autograd node with one input and two outputs: the input and a phony.

    In backward, only the gradient for the first output is passed on to
    ``input``; whatever gradient arrives for the phony output is dropped.
    """

    @staticmethod
    def forward(ctx: "Fork", input: Tensor) -> Tuple[Tensor, Tensor]:  # type: ignore
        # detach() gives new tensor objects (sharing storage with their
        # sources) that apply() then registers as outputs of this node.
        # NOTE(review): get_phony lives in .phony; presumably it returns a
        # cached placeholder tensor for the device - confirm there.
        phony = get_phony(input.device, requires_grad=False).detach()
        return input.detach(), phony

    @staticmethod
    def backward(ctx: "Fork", grad_input: Tensor, grad_phony: Tensor) -> Tensor:  # type: ignore
        # Identity for the real lane; the phony's gradient is discarded.
        return grad_input


def join(input: Tensor, phony: Tensor) -> Tensor:
    """Merges two autograd lanes.

    When autograd is recording and either ``input`` or ``phony`` requires
    grad, ``input`` is routed through :class:`Join`, so the returned tensor
    depends on both arguments in the autograd graph. Otherwise ``input`` is
    returned unchanged.
    """
    if not torch.is_grad_enabled():
        return input
    if not (input.requires_grad or phony.requires_grad):
        return input
    return Join.apply(input, phony)


class Join(torch.autograd.Function):
    """Autograd node taking ``(input, phony)`` and outputting only ``input``.

    ``phony`` contributes no data; it is only recorded as an input of this
    node. In backward the gradient goes to ``input`` and ``None`` to
    ``phony``.
    """

    @staticmethod
    def forward(ctx: "Join", input: Tensor, phony: Tensor) -> Tensor:  # type: ignore
        # phony is intentionally unused here: only its presence as an
        # argument of this node matters.
        return input.detach()

    @staticmethod
    def backward(ctx: "Join", grad_input: Tensor) -> Tuple[Tensor, None]:  # type: ignore
        return (grad_input, None)