# Owner(s): ["module: dynamo"]
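
# Tests for activation checkpointing (torch.utils.checkpoint) under
# torch.compile: Dynamo wraps checkpointed regions with
# tag_activation_checkpoint so that AOT Autograd can recompute the
# tagged ops in the backward graph.
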
import functools
import unittest
from importlib import import_module
import torch
import torch._dynamo.test_case
import torch._functorch.config
import torch._inductor.config  # used by the @torch._inductor.config.patch decorators below
import torch.utils.checkpoint
from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.testing import CompileCounterWithBackend
from torch._higher_order_ops.wrap import tag_activation_checkpoint
from torch.testing._internal.inductor_utils import HAS_CUDA
from torch.utils.checkpoint import checkpoint

requires_cuda = functools.partial(unittest.skipIf, not HAS_CUDA, "requires cuda")


def count_ops(gm, args, freq, op):
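    """Compiler hook: assert that `op` appears exactly `freq` times in `gm`'s graph."""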
assert [node.target for node in gm.graph.nodes].count(op) == freq
return gm


def find_first_node(gm, func):
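    """Return the first node in `gm` whose target is `func`, or None if absent."""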
for node in gm.graph.nodes:
if node.target is func:
return node
return None


def op_count(gm):
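    """Count the call_* nodes (function/method/module calls) in `gm`'s graph."""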
result = 0
for node in gm.graph.nodes:
if "call" in node.op:
result += 1
return result


class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase):
def _validate(self, fn, backend, *args, skip_check=False, fullgraph=True):
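        """Compare eager vs. compiled `fn`: same outputs and same input grads.

        Seeds the RNG identically for both runs; `skip_check` skips the
        numerical comparison (used when rand decomps change numerics).
        """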
cloned_args = []
for arg in args:
cloned_args.append(arg.clone().detach().requires_grad_(arg.requires_grad))
torch.manual_seed(0)
expected = fn(*args)
expected.sum().backward()
torch.manual_seed(0)
result = torch.compile(fn, fullgraph=fullgraph, backend=backend)(*cloned_args)
result.sum().backward()
if not skip_check:
self.assertEqual(result, expected)
for arg, cloned_arg in zip(args, cloned_args):
self.assertEqual(arg.grad, cloned_arg.grad)

    @requires_cuda()
def test_tags_function(self):
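        # These "tags" tests share a pattern: count how often an op shows up
        # in the forward vs. backward graphs. Checkpointing should leave one
        # mm in the forward and recompute it in the backward.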
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y))
def fn(x, y):
return torch.utils.checkpoint.checkpoint(gn, torch.sin(x), y)
x = torch.randn(4, 4, device="cuda", requires_grad=True)
y = torch.randn(4, 4, device="cuda", requires_grad=True)
fw_compiler = functools.partial(count_ops, freq=1, op=torch.ops.aten.mm.default)
        bw_compiler = functools.partial(
            count_ops, freq=3, op=torch.ops.aten.mm.default
        )  # 1 mm recomputed in the bwd + 2 mms for the input grads
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
self._validate(fn, backend, x, y)

    @requires_cuda()
def test_tags_function_via_global_checkpoint(self):
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y))
def fn(x, y):
# This goes through VariableBuilder
return checkpoint(gn, torch.sin(x), y)
x = torch.randn(4, 4, device="cuda", requires_grad=True)
y = torch.randn(4, 4, device="cuda", requires_grad=True)
fw_compiler = functools.partial(count_ops, freq=1, op=torch.ops.aten.mm.default)
        bw_compiler = functools.partial(
            count_ops, freq=3, op=torch.ops.aten.mm.default
        )  # 1 mm recomputed in the bwd + 2 mms for the input grads
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
self._validate(fn, backend, x, y)

    @requires_cuda()
def test_tags_function_with_kwargs(self):
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y))
def fn(x, y):
return torch.utils.checkpoint.checkpoint(
gn, torch.sin(x), y, use_reentrant=True, preserve_rng_state=False
)
x = torch.randn(4, 4, device="cuda", requires_grad=True)
y = torch.randn(4, 4, device="cuda", requires_grad=True)
fw_compiler = functools.partial(count_ops, freq=1, op=torch.ops.aten.mm.default)
        bw_compiler = functools.partial(
            count_ops, freq=3, op=torch.ops.aten.mm.default
        )  # 1 mm recomputed in the bwd + 2 mms for the input grads
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
self._validate(fn, backend, x, y)

    @requires_cuda()
def test_tags_multiple_checkpoints(self):
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y))
def fn(x, y):
x = torch.sin(x)
z = torch.utils.checkpoint.checkpoint(gn, x, y)
x = torch.sin(z)
z = torch.utils.checkpoint.checkpoint(gn, x, y)
return z
x = torch.randn(4, 4, device="cuda", requires_grad=True)
y = torch.randn(4, 4, device="cuda", requires_grad=True)
fw_compiler = functools.partial(count_ops, freq=2, op=torch.ops.aten.mm.default)
        bw_compiler = functools.partial(
            count_ops, freq=6, op=torch.ops.aten.mm.default
        )  # per checkpoint: 1 recomputed mm + 2 mms for grads; 2 checkpoints
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
self._validate(fn, backend, x, y)

    @requires_cuda()
def test_tags_module(self):
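        # sigmoid sits inside the checkpointed module, so it should appear
        # once in the forward graph and once more, recomputed, in the backward.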
class MockModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(10, 10)
def forward(self, x):
return torch.sigmoid(self.linear(x))
mod = MockModule().cuda()
def fn(x):
return torch.utils.checkpoint.checkpoint(mod, torch.sin(x))
x = torch.randn(10, 10, device="cuda", requires_grad=True)
fw_compiler = functools.partial(
count_ops, freq=1, op=torch.ops.aten.sigmoid.default
)
bw_compiler = functools.partial(
count_ops, freq=1, op=torch.ops.aten.sigmoid.default
)
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
self._validate(fn, backend, x)

    @requires_cuda()
def test_tags_decomps(self):
        # Ensures that tags are propagated through decompositions as well
class MockModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(10, 10)
def forward(self, x):
return torch.nn.functional.gelu(self.linear(x))
mod = MockModule().cuda()
def fn(x):
return torch.utils.checkpoint.checkpoint(mod, torch.sin(x))
x = torch.randn(10, 10, device="cuda", requires_grad=True)
fw_compiler = functools.partial(
count_ops, freq=1, op=torch.ops.aten.erf.default
)
bw_compiler = functools.partial(
count_ops, freq=1, op=torch.ops.aten.erf.default
)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
decompositions=lambda: import_module(
"torch._inductor.compile_fx"
).select_decomp_table(),
)
self._validate(fn, backend, x)

    @requires_cuda()
@torch._inductor.config.patch(fallback_random=True)
def test_tags_recomputed_rand(self):
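        # rand_like lives inside the checkpointed region; the recomputation in
        # the backward should replay the same random values (fallback_random
        # keeps inductor's RNG numerics aligned with eager).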
def gn(x, y):
return torch.sigmoid(torch.rand_like(x) * y) * x
def fn(x, y):
x = torch.sin(x)
x = torch.utils.checkpoint.checkpoint(gn, x, y)
x = torch.sin(x)
z = torch.utils.checkpoint.checkpoint(gn, x, y)
return z
x = torch.randn(4, 4, device="cuda", requires_grad=True)
y = torch.randn(4, 4, device="cuda", requires_grad=True)
        # Op-count checks are skipped here; instead, validate outputs and
        # grads against eager under the inductor backend.
        backend = "inductor"
self._validate(fn, backend, x, y)

    @requires_cuda()
@torch._inductor.config.patch(fallback_random=True)
def test_tags_rand(self):
def gn(x, y):
x = torch.mm(x, y)
x = torch.mm(x, y)
return x
def fn(x, y):
x = torch.sin(x)
x = torch.utils.checkpoint.checkpoint(gn, x, y)
x = torch.sin(x)
# x = torch.utils.checkpoint.checkpoint(gn, x, y)
return x
x = torch.randn(4, 4, device="cuda", requires_grad=True)
y = torch.randn(4, 4, device="cuda", requires_grad=True)
        # Op-count checks are skipped here; validate against eager via inductor.
        backend = "inductor"
self._validate(fn, backend, x, y)

    @requires_cuda()
@torch._inductor.config.patch(fallback_random=True)
def test_tags_dropout(self):
        # TODO: figure out a way to test the number of inductor_random calls
class MockModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(10, 10)
self.dropout = torch.nn.Dropout(0.2)
def forward(self, x):
return self.dropout(self.linear(x))
mod = MockModule().cuda()
def fn(x):
return torch.utils.checkpoint.checkpoint(mod, x)
x = torch.randn(10, 10, device="cuda", requires_grad=True)
backend = "inductor"
        # rand decomps do not produce the same numerical results as eager
self._validate(fn, backend, x, skip_check=True)

    @requires_cuda()
def test_fallback(self):
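        # The graph breaks inside gn mean the checkpointed function cannot be
        # captured, so checkpoint falls back to eager; only the surrounding
        # sin and cos get compiled, as two separate graphs.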
def gn(x, y):
torch._dynamo.graph_break()
a = torch.sigmoid(torch.matmul(x, y))
torch._dynamo.graph_break()
return torch.cos(a)
def fn(x, y):
return torch.cos(checkpoint(gn, torch.sin(x), y, use_reentrant=False))
x = torch.randn(4, 4, requires_grad=True)
y = torch.randn(4, 4, requires_grad=True)
args = (x, y)
backend = "aot_eager"
cnt = CompileCounterWithBackend(backend)
expected = fn(*args)
result = torch.compile(fn, backend=cnt)(*args)
self.assertEqual(result, expected)
        # One graph for torch.sin on the input, and another for torch.cos.
self.assertEqual(cnt.frame_count, 2)
self.assertEqual(cnt.op_count, 2)
self.assertEqual(len(cnt.graphs), 2)

    @requires_cuda()
def test_kwargs(self):
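        # z is a keyword argument for gn itself (not a checkpoint flag); it
        # should be routed into the checkpointed body function.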
def gn(x, y, z=None):
a = torch.matmul(x, y)
if z is not None:
return torch.matmul(a, z)
return a
def fn(x, y, z):
return torch.cos(checkpoint(gn, x, y, use_reentrant=False, z=z))
x = torch.randn(4, 4, requires_grad=True)
y = torch.randn(4, 4, requires_grad=True)
z = torch.randn(4, 4, requires_grad=True)
args = (x, y, z)
backend = "aot_eager"
cnt = CompileCounterWithBackend(backend)
expected = fn(*args)
result = torch.compile(fn, backend=cnt)(*args)
self.assertEqual(result, expected)
self.assertEqual(cnt.frame_count, 1)
self.assertEqual(len(cnt.graphs), 1)
wrap_node = find_first_node(cnt.graphs[0], tag_activation_checkpoint)
        # one arg for the checkpointed body function, plus 3 for x, y, z
self.assertEqual(len(wrap_node.args), 4)
body_function = getattr(cnt.graphs[0], wrap_node.args[0].name)
self.assertEqual(op_count(body_function), 2)

    @requires_cuda()
def test_symints_location(self):
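        # Run twice with different shapes to force a recompile, then check
        # that the wrap node's args are the body function plus x and y.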
def gn(x, y):
return torch.matmul(x, torch.nn.functional.dropout(y, 0.5))
def fn(x, y):
return torch.utils.checkpoint.checkpoint(gn, x, y)
backend = "aot_eager"
cnt = CompileCounterWithBackend(backend)
opt_fn = torch.compile(fn, backend=cnt)
x = torch.randn(4, 4, requires_grad=True)
y = torch.randn(4, 4, requires_grad=True)
args = (x, y)
expected = fn(*args)
result = opt_fn(*args)
x = torch.randn(5, 5, requires_grad=True)
y = torch.randn(5, 5, requires_grad=True)
args = (x, y)
expected = fn(*args)
result = opt_fn(*args)
self.assertEqual(result.shape, expected.shape)
self.assertEqual(cnt.frame_count, 2)
self.assertEqual(len(cnt.graphs), 2)
wrap_node = find_first_node(cnt.graphs[0], tag_activation_checkpoint)
self.assertEqual(len(wrap_node.args), 3)


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()