# Owner(s): ["oncall: distributed"]
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from functools import partial
import pytest
import torch
import torch.cuda
from torch import nn
from torch.distributed.pipeline.sync.checkpoint import (
    checkpoint,
    Checkpointing,
    is_checkpointing,
    is_recomputing,
)
from torch.distributed.pipeline.sync.dependency import fork, join
from torch.distributed.pipeline.sync.microbatch import Batch
from torch.testing._internal.common_utils import run_tests

devices = ["cpu"]
if torch.cuda.is_available():
    devices.append("cuda")


@pytest.mark.parametrize("device", devices)
def test_serial_checkpoints(device):
    # Copied from https://github.com/pytorch/pytorch/pull/18568.
    timeline = []

    class Log(torch.autograd.Function):
        @staticmethod
        def forward(ctx, name, x):
            ctx.name = name
            timeline.append(f"{name}:forward")
            return x.detach()

        @staticmethod
        def backward(ctx, grad_output):
            name = ctx.name
            timeline.append(f"{name}:backward")
            return None, grad_output

    a = torch.rand(1, device=device, requires_grad=True)
    b = torch.rand(1, device=device, requires_grad=True)

    # Increase the next function sequence number.
    _ = a + 1 + 2 + 3 + 4 + 5

    a = checkpoint(partial(Log.apply, "a"), a)

    a, phony = fork(a)
    b = join(b, phony)

    b = checkpoint(partial(Log.apply, "b"), b)

    c = torch.cat((a, b))
    out = c.sum()

    #                         +--> {a} --Checkpoint(Log)--> {a}
    # {out} --Sum--> {c} --Cat     ^-----------------------------+
    #                         +--> {b} --Checkpoint(Log)--> {b} --First--> {b}
    out.backward()

    assert timeline == [
        "a:forward",
        "b:forward",
        "b:forward",
        "b:backward",
        "a:forward",
        "a:backward",
    ]
    # The timeline has three phases:
    #   forward pass:        a:forward, b:forward
    #   Checkpoint(Log[b]):  b:forward, b:backward  (recomputed first, due to the fork/join dependency)
    #   Checkpoint(Log[a]):  a:forward, a:backward


def test_not_requires_grad():
    x = Batch(torch.rand(1, requires_grad=False))
    assert not x[0].requires_grad

    def f(x):
        return x * 2

    chk = Checkpointing(f, x)
    x = chk.checkpoint()
    assert x[0].requires_grad

    chk.recompute(x)
    assert x[0].requires_grad

    x.tensor.backward()


def test_not_requires_grad_with_parameter():
    x = torch.rand(1, requires_grad=False)
    a = torch.rand(1, requires_grad=True)

    def f(x):
        return x * a

    y = checkpoint(f, x)
    y.backward()

    assert a.grad is not None


@pytest.mark.parametrize("device", devices)
def test_random_in_checkpoint(device):
    dropout = nn.Dropout(p=0.5)

    torch.manual_seed(0)
    x = torch.randn(3, 3, device=device, requires_grad=True)
    y = dropout(x)
    y.norm().backward()

    torch.manual_seed(0)
    chk_x = torch.randn(3, 3, device=device, requires_grad=True)
    chk_y = checkpoint(dropout, chk_x)
    chk_y.norm().backward()

    assert torch.allclose(x.grad, chk_x.grad)


def test_detect_checkpointing_recomputing():
    logs = []

    class Detect(nn.Module):
        def forward(self, input):
            logs.append((is_checkpointing(), is_recomputing()))
            return input

    model = Detect()
    input = torch.rand(1, requires_grad=True)

    output = checkpoint(model, input)
    output.backward()

    assert logs == [(True, False), (False, True)]


def test_detect_checkpointing_recomputing_without_checkpoint():
    logs = []

    class Detect(nn.Module):
        def forward(self, input):
            logs.append((is_checkpointing(), is_recomputing()))
            return input

    model = Detect()
    input = torch.rand(1, requires_grad=True)

    output = model(input)
    output.backward()

    assert logs == [(False, False)]


def test_non_grad_output():
    class ForkNonGrad(nn.Module):
        def forward(self, input):
            return (input * 2, torch.rand(1))

    model = ForkNonGrad()
    input = torch.rand(1, requires_grad=True)

    output = checkpoint(model, input)
    output[0].backward()


if __name__ == "__main__":
    run_tests()