| # Copyright 2019 Kakao Brain |
| # |
| # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. |
| # |
| # This source code is licensed under the BSD license found in the |
| # LICENSE file in the root directory of this source tree. |
| """Arbitrary dependency between two autograd lanes.""" |
| from typing import List, Tuple |
| |
| import torch |
| from torch import Tensor |
| |
| from .phony import get_phony |
| |
| __all__: List[str] = ["fork", "Fork", "join", "Join"] |
| |
| |
def fork(input: Tensor) -> Tuple[Tensor, Tensor]:
    """Branch a phony tensor off the autograd lane of ``input``.

    Returns a ``(tensor, phony)`` pair. When grad mode is enabled and
    ``input`` requires grad, both values come from ``Fork.apply``. Otherwise
    ``input`` is returned as-is, paired with a phony obtained from
    ``get_phony`` on ``input``'s device with ``requires_grad=False``.
    """
    tracks_grad = torch.is_grad_enabled() and input.requires_grad

    if not tracks_grad:
        # Nothing to record in the graph: hand the tensor back untouched.
        return input, get_phony(input.device, requires_grad=False)

    forked, phony = Fork.apply(input)
    return forked, phony
| |
| |
class Fork(torch.autograd.Function):
    """Autograd function that emits ``input`` together with a phony tensor."""

    @staticmethod
    def forward(ctx: "Fork", input: Tensor) -> Tuple[Tensor, Tensor]:  # type: ignore[override]
        """Return detached versions of ``input`` and of a phony on its device."""
        device_phony = get_phony(input.device, requires_grad=False)
        detached_input = input.detach()
        detached_phony = device_phony.detach()
        return detached_input, detached_phony

    @staticmethod
    def backward(ctx: "Fork", grad_input: Tensor, grad_grad: Tensor) -> Tensor:  # type: ignore[override]
        """Pass the first output's gradient through unchanged.

        ``grad_grad`` (the gradient arriving for the phony output) is ignored.
        """
        return grad_input
| |
| |
def join(input: Tensor, phony: Tensor) -> Tensor:
    """Merge the autograd lane of ``phony`` into that of ``input``.

    ``input`` is returned unchanged when grad mode is disabled or when
    neither tensor requires grad; otherwise the result of
    ``Join.apply(input, phony)`` is returned.
    """
    if not torch.is_grad_enabled():
        return input

    if not (input.requires_grad or phony.requires_grad):
        return input

    return Join.apply(input, phony)
| |
| |
class Join(torch.autograd.Function):
    """Autograd function that takes ``input`` and ``phony`` and emits ``input``."""

    @staticmethod
    def forward(ctx: "Join", input: Tensor, phony: Tensor) -> Tensor:  # type: ignore[override]
        """Return ``input`` detached; ``phony`` is accepted but not read."""
        merged = input.detach()
        return merged

    @staticmethod
    def backward(ctx: "Join", grad_input: Tensor) -> Tuple[Tensor, None]:  # type: ignore[override]
        """Forward the gradient to ``input`` and give ``phony`` no gradient."""
        phony_grad = None
        return grad_input, phony_grad