blob: e966d6541bf59a02519a8e3939b114e6fe0aeb54 [file] [log] [blame]
# Owner(s): ["oncall: distributed"]
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import weakref
import pytest
import torch
from torch.distributed.pipeline.sync.dependency import Fork, fork, Join, join
from torch.testing._internal.common_utils import run_tests
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def test_fork_join():
    """Check the backward order imposed by fork/join across two devices.

    A CPU tensor ``a`` is forked and the resulting phony is joined onto a
    CUDA tensor ``b``. Each tensor passes through a ``Log`` function that
    records its tag during backward; the test asserts that tag 2 (the joined
    branch) is recorded before tag 1 (the forked branch).
    """
    logs = []

    class Log(torch.autograd.Function):
        """Pass-through function that appends its tag to ``logs`` in backward."""

        @staticmethod
        def forward(ctx, number, tensor):
            ctx.number = number
            return tensor.detach()

        @staticmethod
        def backward(ctx, grad):
            logs.append(ctx.number)
            # No gradient for the integer tag; the tensor gradient passes through.
            return None, grad

    a = torch.rand(1, device="cpu", requires_grad=True)
    b = torch.rand(1, device="cuda", requires_grad=True)

    a = Log.apply(1, a)

    a, phony = fork(a)
    # Join the CUDA tensor (previously `a` was joined here, which left the
    # CUDA tensor unused and made the test CPU-only in practice).
    b = join(b, phony)

    b = Log.apply(2, b)
    b = b.to("cpu")

    (a + b).backward()

    assert logs == [2, 1]
def test_fork_join_enable_grad():
    """With grad enabled, fork and join return new tensors tracked by autograd.

    Asserts that the outputs are distinct objects from their inputs, that
    they require grad, and that their ``grad_fn`` classes are the backward
    classes of ``Fork`` and ``Join`` respectively.
    """
    source = torch.rand(1, requires_grad=True)

    with torch.enable_grad():
        forked, phony = fork(source)

    assert phony.requires_grad
    assert forked is not source

    assert forked.requires_grad
    assert phony.requires_grad
    # Both fork outputs come from the same autograd function.
    for produced in (forked, phony):
        assert type(produced.grad_fn) is Fork._backward_cls

    with torch.enable_grad():
        joined = join(forked, phony)

    assert joined is not forked
    assert joined.requires_grad
    assert type(joined.grad_fn) is Join._backward_cls
def test_fork_join_no_grad(monkeypatch):
    """Under ``no_grad``, fork and join return their input tensor unchanged.

    ``torch.autograd.Function.apply`` is patched to raise, so the test also
    fails if either helper goes through an autograd function.
    """

    def fail_on_apply(*args):
        raise AssertionError("Function.apply called")

    monkeypatch.setattr("torch.autograd.Function.apply", fail_on_apply)

    source = torch.rand(1, requires_grad=True)

    with torch.no_grad():
        forked, phony = fork(source)

    assert not phony.requires_grad
    assert forked is source

    with torch.no_grad():
        joined = join(forked, phony)

    assert joined is forked
def test_fork_leak():
    """The autograd context captured during backward is freed afterwards.

    A weak reference to the ``ctx`` of an upstream function is taken inside
    its backward; once the tensors are deleted, that reference must be dead.
    """
    leak = None

    class F(torch.autograd.Function):
        """Identity function whose backward stores a weak reference to ctx."""

        @staticmethod
        def forward(ctx, input):
            return input

        @staticmethod
        def backward(ctx, grad):
            nonlocal leak
            leak = weakref.ref(ctx)
            return grad

    # Rebind a single name at every step so no intermediate tensor stays
    # referenced from this frame.
    out = torch.rand(1, requires_grad=True)
    out = F.apply(out)
    out, phony = fork(out)
    out = join(out, phony)

    out.backward()
    del out, phony

    assert leak() is None
def test_join_when_fork_not_requires_grad():
    """join keeps ``requires_grad`` off when the forked tensor had it off."""
    base = torch.rand(2, 1)
    first, second = base.chunk(2)

    # Fork side: neither the tensor nor its phony should require grad.
    assert not first.requires_grad
    first, phony = fork(first)
    assert not first.requires_grad
    assert not phony.requires_grad

    # Join side: joining with that phony must not switch grad on.
    assert not second.requires_grad
    second = join(second, phony)
    assert not second.requires_grad
def test_join_when_fork_requires_grad():
    """join turns ``requires_grad`` on when the forked tensor required grad."""
    base = torch.rand(2, 1)
    first, second = base.chunk(2)

    # Fork side: the tensor requires grad before and after, as does the phony.
    first.requires_grad_()
    assert first.requires_grad
    first, phony = fork(first)
    assert first.requires_grad
    assert phony.requires_grad

    # Join side: a tensor without grad picks it up from the phony.
    assert not second.requires_grad
    second = join(second, phony)
    assert second.requires_grad
# When executed directly as a script, hand control to run_tests (imported
# from torch.testing._internal.common_utils).
if __name__ == "__main__":
    run_tests()