# Owner(s): ["oncall: distributed"]
import unittest
from collections import deque, OrderedDict
from contextlib import ContextDecorator
from copy import deepcopy
from typing import Tuple
import torch
import torch.nn as nn
from torch.distributed._composable import checkpoint
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import run_tests, TestCase
class MemoryDelta(ContextDecorator):
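    """Context decorator that records active CUDA memory on enter and exit so a
    block's memory footprint can be compared; reports 0 on non-CUDA devices."""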
def __init__(self, device: torch.device):
self.device: torch.device = device
self.active_memory_enter: int = 0
self.active_memory_exit: int = 0
def __enter__(self):
self.active_memory_enter = (
torch.cuda.memory_stats()["active_bytes.all.current"]
if self.device.type == "cuda"
else 0
)
return self
def __exit__(self, *exc):
self.active_memory_exit = (
torch.cuda.memory_stats()["active_bytes.all.current"]
if self.device.type == "cuda"
else 0
)
def delta(self) -> int:
return self.active_memory_exit - self.active_memory_enter
class ToyModel(nn.Module):
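    """Small two-stage model; the tests wrap its ``seq`` submodule with checkpoint."""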
def __init__(self):
super().__init__()
self.l1 = nn.Linear(100, 100)
self.seq = nn.Sequential(
nn.ReLU(),
nn.Linear(100, 100),
nn.ReLU(),
)
def forward(self, x):
return self.seq(self.l1(x))
class RandomModel(nn.Module):
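    """Model whose forward draws fresh random values, so checkpointed recomputation
    must restore RNG state to reproduce the same activations."""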
def __init__(self):
super().__init__()
self.p = nn.Parameter(torch.randn(100, 100))
def forward(self, x):
y = torch.matmul(self.p, torch.randn(100, 100, device=self.p.device))
return torch.matmul(x, y)
class MultiOutputModel(nn.Module):
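    """Returns a tuple of tensors, so its backward receives multiple gradient inputs."""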
def __init__(self, device: torch.device):
super().__init__()
self.w1 = nn.Parameter(torch.randn((100, 100), device=device))
self.w2 = nn.Parameter(torch.randn((100, 100), device=device))
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
z = x @ self.w1
z = nn.functional.relu(z)
z = z @ self.w2
return z.sin(), z.cos()
class MultiInputModel(nn.Module):
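    """Consumes a tuple of two tensors, e.g. the outputs of ``MultiOutputModel``."""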
def __init__(self, device: torch.device):
super().__init__()
self.w = nn.Parameter(torch.randn((100, 100), device=device))
def forward(self, xs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
assert len(xs) == 2, f"Expects 2 args but got {len(xs)}"
x, y = xs
z = x + y
z = z @ self.w
return nn.functional.relu(z)
class TestCheckpoint(TestCase):
def _get_graph_size(self, out: torch.Tensor) -> int:
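        """Counts the autograd nodes reachable from ``out.grad_fn`` (a node reachable
        through multiple paths is counted once per path)."""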
q = deque([out.grad_fn])
num_functions = 0
while len(q):
fn = q.pop()
num_functions += 1
for next_fn, _ in fn.next_functions:
if next_fn:
q.append(next_fn)
return num_functions
def _test_tensor_only(
self,
net: nn.Module,
x: torch.Tensor,
) -> None:
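        """Runs ``net`` with and without checkpointing ``net.seq``, checking that
        checkpointing reduces active memory on CUDA and yields identical gradients."""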
x1 = x.clone()
x2 = x.clone()
x1.requires_grad = True
x2.requires_grad = True
net1 = net
net2 = deepcopy(net)
# no checkpoint
with MemoryDelta(x.device) as mem1:
loss1 = net1(x1).sum()
graph_size1 = self._get_graph_size(loss1)
loss1.backward()
# with checkpoint
checkpoint(net2.seq)
with MemoryDelta(x.device) as mem2:
loss2 = net2(x2).sum()
loss2.backward()
if x.is_cuda:
self.assertTrue(mem2.delta() < mem1.delta())
for p1, p2 in zip(net1.parameters(), net2.parameters()):
self.assertEqual(p1.grad, p2.grad)
def test_tensor_only_cpu(self):
x = torch.randn(20, 100)
net = ToyModel()
self._test_tensor_only(net, x)
@unittest.skipIf(not TEST_CUDA, "no cuda")
def test_tensor_only_gpu(self):
x = torch.randn(20, 100, device="cuda:0")
net = ToyModel().to("cuda:0")
self._test_tensor_only(net, x)
def test_random_cpu(self):
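        # Restore the RNG state before the checkpointed run so the recomputed
        # activations match the non-checkpointed run and the gradients agree.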
x1 = torch.randn(20, 100, requires_grad=True)
x2 = x1.clone()
net1 = RandomModel()
net2 = deepcopy(net1)
cpu_rng_state = torch.get_rng_state()
net1(x1).sum().backward()
torch.set_rng_state(cpu_rng_state)
checkpoint(net2)(x2).sum().backward()
for p1, p2 in zip(net1.parameters(), net2.parameters()):
self.assertEqual(p1.grad, p2.grad)
def test_multi_args(self):
"""
Tests checkpoint for modules with multiple output args and hence
multiple backward function input args.
"""
device = torch.device("cpu")
net1 = nn.Sequential(
MultiOutputModel(device),
MultiInputModel(device),
MultiOutputModel(device),
MultiInputModel(device),
)
net2 = deepcopy(net1)
checkpoint(net2[0])
checkpoint(net2[2])
x1 = torch.randn(20, 100, requires_grad=True)
x2 = x1.clone()
net1(x1).sum().backward()
net2(x2).sum().backward()
for p1, p2 in zip(net1.parameters(), net2.parameters()):
self.assertEqual(p1.grad, p2.grad)
def test_clears_state_on_error_in_forward(self):
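        """Checks that checkpoint clears its internal generator state when the
        forward pass, or its recomputation during backward, raises an error."""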
class MyModel(torch.nn.Module):
def __init__(self, raise_in_recomp):
super().__init__()
self.fwd_count = 0
self.raise_in_recomp = raise_in_recomp
self.a = torch.nn.Linear(2, 2)
            def forward(self, x):
                if self.raise_in_recomp and self.fwd_count == 1:
                    # raise during the recomputation triggered by backward
                    raise RuntimeError("foo")
                elif not self.raise_in_recomp:
                    # raise in the first forward
                    raise RuntimeError("foo")
                self.fwd_count += 1
                return self.a(x)
m = MyModel(raise_in_recomp=True)
m_seq = torch.nn.Sequential(OrderedDict({"m": m}))
checkpoint(m_seq.m)
inp = torch.randn(1, 2)
out = m_seq(inp).sum()
# Should raise in forward recomputation
with self.assertRaisesRegex(RuntimeError, "foo"):
out.backward()
# Check that _ac_generator is cleared out
self.assertEqual(None, checkpoint.state(m)._ac_generator)
m = MyModel(raise_in_recomp=False)
checkpoint(m)
inp = torch.randn(1, 2)
# Should raise in first forward
with self.assertRaises(RuntimeError):
m(inp)
self.assertEqual(None, checkpoint.state(m)._ac_generator)
if __name__ == "__main__":
run_tests()