# Owner(s): ["oncall: distributed"]

import unittest
from collections import deque
from copy import deepcopy

import torch
import torch.nn as nn

from torch.distributed._composable import checkpoint
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import run_tests, TestCase

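
# A small two-stage model; the tests apply activation checkpointing to the
# inner nn.Sequential (self.seq).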
class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(100, 100)
        self.seq = nn.Sequential(
            nn.ReLU(),
            nn.Linear(100, 100),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.seq(self.l1(x))


class TestCheckpoint(TestCase):
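    # Traverse the autograd graph rooted at out.grad_fn and count its nodes.
    # Nodes reachable via multiple paths may be counted more than once, which
    # is fine for the relative-size comparison below.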
    def _get_graph_size(self, out: torch.Tensor) -> int:
        q = deque([out.grad_fn])
        num_functions = 0
        while len(q):
            fn = q.pop()
            num_functions += 1
            for next_fn, _ in fn.next_functions:
                if next_fn:
                    q.append(next_fn)
        return num_functions

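    # Run two copies of the same model on identical inputs: one plain and one
    # with checkpoint() applied to its inner Sequential. Checkpointing should
    # shrink the recorded autograd graph without changing the gradients.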
    def _test_tensor_only(self, net: nn.Module, x: torch.Tensor) -> None:
        x1 = x.clone()
        x2 = x.clone()
        x1.requires_grad = True
        x2.requires_grad = True

        net1 = net
        net2 = deepcopy(net)

        # forward/backward without checkpointing
        loss1 = net1(x1).sum()
        graph_size1 = self._get_graph_size(loss1)
        loss1.backward()

        # forward/backward with checkpointing on the inner Sequential
        checkpoint(net2.seq)
        loss2 = net2(x2).sum()
        graph_size2 = self._get_graph_size(loss2)
        loss2.backward()

        # the checkpointed graph must be smaller, and gradients must match
        self.assertLess(graph_size2, graph_size1)
        for p1, p2 in zip(net1.parameters(), net2.parameters()):
            self.assertEqual(p1.grad, p2.grad)

    def test_tensor_only_cpu(self):
        x = torch.randn(20, 100)
        net = ToyModel()
        self._test_tensor_only(net, x)

    @unittest.skipIf(not TEST_CUDA, "no cuda")
    def test_tensor_only_gpu(self):
        x = torch.randn(20, 100, device="cuda:0")
        net = ToyModel().to("cuda:0")
        self._test_tensor_only(net, x)


if __name__ == "__main__":
    run_tests()