# Owner(s): ["module: dynamo"]
import functools
import math
import sys
import unittest # noqa: F811
from importlib import import_module
import torch
import torch._dynamo.config
import torch._dynamo.test_case
import torch._functorch.config
import torch.utils.checkpoint
from functorch.compile import min_cut_rematerialization_partition
from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.testing import CompileCounterWithBackend
from torch._higher_order_ops.wrap import tag_activation_checkpoint
from torch.testing._internal.common_utils import IS_WINDOWS, skipIfRocm
from torch.testing._internal.inductor_utils import HAS_CUDA
from torch.testing._internal.two_tensor import TwoTensor
from torch.utils.checkpoint import _pt2_selective_checkpoint_context_fn_gen, checkpoint
requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
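# `checkpoint_wrapper` wraps a function in reentrant activation checkpointing:
# the wrapped forward is re-run during the backward pass instead of saving its
# intermediate activations.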
def checkpoint_wrapper(fn):
def inner(*args):
return torch.utils.checkpoint.checkpoint(fn, *args, use_reentrant=True)
return inner
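# `count_ops` is used as the fw_compiler / bw_compiler of an aot_autograd
# backend: it walks the FX graph, counts nodes whose target matches the given
# op(s) (including ops wrapped by the run_and_save_rng_state /
# run_with_rng_state higher-order RNG ops), asserts the expected exact
# (`freq`/`freqs`) or minimum (`freq_ge`/`freqs_ge`) counts, and returns the
# graph unchanged so compilation proceeds normally.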
def count_ops(
gm, args, freq=None, freq_ge=None, op=None, freqs=None, freqs_ge=None, ops=None
):
def match_rng_op(node, op):
if isinstance(node.target, torch._ops.HigherOrderOperator):
if node.name == "run_and_save_rng_state":
return node.args[0] == op
elif node.name == "run_with_rng_state":
return node.args[1] == op
return False
# assert ((freq or freq_ge) and op) or ((freqs or freqs_ge) and ops)
if op is not None:
ops = [op]
if freq is not None:
freqs = [freq]
if freq_ge is not None:
freqs_ge = [freq_ge]
if freqs:
for op, freq in zip(ops, freqs):
actual_count = 0
for node in gm.graph.nodes:
if match_rng_op(node, op) or node.target == op:
actual_count += 1
assert (
actual_count == freq
), f"In graph {gm}, expected {op} to have occurred {freq} times in the graph, but got {actual_count}."
else:
assert freqs_ge is not None
for op, freq_ge in zip(ops, freqs_ge):
actual_count = 0
for node in gm.graph.nodes:
if match_rng_op(node, op) or node.target == op:
actual_count += 1
assert (
actual_count >= freq_ge
), f"In graph {gm}, expected {op} to have occurred at least {freq_ge} times in the graph, but got {actual_count}."
return gm
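# `_InvalidContext` / `_invalid_context_gen` build a pair of plain context
# managers that are not TorchDispatchModes; they are used to check that
# selective checkpointing rejects an invalid `context_fn`.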
class _InvalidContext:
def __init__(self):
pass
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
pass
def _invalid_context_gen():
return _InvalidContext(), _InvalidContext()
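# Small FX graph helpers: `find_first_node` returns the first node whose
# target is the given function (or None if absent), and `op_count` counts the
# call_* nodes in a graph.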
def find_first_node(gm, func):
for node in gm.graph.nodes:
if node.target is func:
return node
return None
def op_count(gm):
result = 0
for node in gm.graph.nodes:
if "call" in node.op:
result += 1
return result
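# `_get_custom_policy` builds a selective-checkpointing policy: ops in
# `no_recompute_list` have their outputs saved during the forward pass (i.e.
# they are not recomputed in the backward pass); everything else is recomputed.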
def _get_custom_policy(no_recompute_list=None):
def _custom_policy(mode, func, *args, **kwargs):
return func in no_recompute_list
return _custom_policy
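# The tests below share a common pattern: build a checkpointed function, use
# `count_ops` as the fw/bw compiler of an aot_autograd backend to assert how
# often an op appears in the forward and backward graphs, and let `_validate`
# compare torch.compile against eager. A minimal sketch of that wiring
# (illustrative only, mirroring the tests below):
#
#   fw_compiler = functools.partial(count_ops, freq=1, op=torch.ops.aten.mm.default)
#   bw_compiler = functools.partial(count_ops, freq=3, op=torch.ops.aten.mm.default)
#   backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
#   torch.compile(fn, backend=backend, fullgraph=True)(x, y)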
class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase):
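# `_validate` runs `fn` once in eager mode and once under torch.compile with
# the given backend (same manual seed for both runs), then compares the
# outputs and the input gradients unless `skip_check` is set.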
def _validate(self, fn, backend, *args, skip_check=False, fullgraph=True):
cloned_args = []
for arg in args:
cloned_args.append(arg.clone().detach().requires_grad_(arg.requires_grad))
torch.manual_seed(0)
expected = fn(*args)
expected.sum().backward()
torch.manual_seed(0)
result = torch.compile(fn, fullgraph=fullgraph, backend=backend)(*cloned_args)
result.sum().backward()
if not skip_check:
self.assertEqual(
result,
expected,
msg="Output mismatch between torch.compile and eager versions",
)
for arg, cloned_arg in zip(args, cloned_args):
self.assertEqual(
arg.grad,
cloned_arg.grad,
msg="Gradient mismatch between torch.compile and eager versions",
)
def _compare_orig_and_checkpointed_fns(
self, orig_fn, checkpointed_fn, *args, fullgraph=True
):
# The original version and the checkpointed version of the same function
# should produce the same outputs and the same gradients under torch.compile.
# Run original version
cloned_args_orig_fn = []
for arg in args:
cloned_args_orig_fn.append(
arg.clone().detach().requires_grad_(arg.requires_grad)
)
torch.manual_seed(0)
compiled_orig_fn = torch.compile(
orig_fn, fullgraph=fullgraph, backend="inductor"
)
result_orig_fn = compiled_orig_fn(*cloned_args_orig_fn)
result_orig_fn.sum().backward()
# Run checkpointed version
cloned_args_checkpointed_fn = []
for arg in args:
cloned_args_checkpointed_fn.append(
arg.clone().detach().requires_grad_(arg.requires_grad)
)
torch.manual_seed(0)
compiled_checkpointed_fn = torch.compile(
checkpointed_fn, fullgraph=fullgraph, backend="inductor"
)
result_checkpointed_fn = compiled_checkpointed_fn(*cloned_args_checkpointed_fn)
result_checkpointed_fn.sum().backward()
# Check that outputs and gradients are equal
self.assertEqual(
result_orig_fn,
result_checkpointed_fn,
msg="Output mismatch between the original version and the checkpointed version of the same function",
)
for cloned_arg_orig_fn, cloned_arg_checkpointed_fn in zip(
cloned_args_orig_fn, cloned_args_checkpointed_fn
):
self.assertEqual(
cloned_arg_orig_fn.grad,
cloned_arg_checkpointed_fn.grad,
msg="Gradient mismatch between the original version and the checkpointed version of the same function",
)
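# The `test_tags_*` tests check that ops inside a reentrant checkpointed
# region are tagged for recomputation: e.g. a single matmul shows up once as
# `mm` in the forward graph and three times in the backward graph (two `mm`
# ops from the usual matmul backward plus one recompute).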
@requires_cuda
def test_tags_function(self):
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y))
def fn(x, y):
return torch.utils.checkpoint.checkpoint(
gn, torch.sin(x), y, use_reentrant=True
)
x = torch.randn(4, 4, device="cuda", requires_grad=True)
y = torch.randn(4, 4, device="cuda", requires_grad=True)
fw_compiler = functools.partial(count_ops, freq=1, op=torch.ops.aten.mm.default)
bw_compiler = functools.partial(
count_ops, freq=3, op=torch.ops.aten.mm.default
) # mm recomputed in the bwd
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
self._validate(fn, backend, x, y)
@requires_cuda
def test_tags_function_via_global_checkpoint(self):
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y))
def fn(x, y):
# This goes through VariableBuilder
return checkpoint(gn, torch.sin(x), y, use_reentrant=True)
x = torch.randn(4, 4, device="cuda", requires_grad=True)
y = torch.randn(4, 4, device="cuda", requires_grad=True)
fw_compiler = functools.partial(count_ops, freq=1, op=torch.ops.aten.mm.default)
bw_compiler = functools.partial(
count_ops, freq=3, op=torch.ops.aten.mm.default
) # mm recomputed in the bwd
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
self._validate(fn, backend, x, y)
@requires_cuda
def test_tags_function_with_kwargs(self):
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y))
def fn(x, y):
return torch.utils.checkpoint.checkpoint(
gn, torch.sin(x), y, use_reentrant=True, preserve_rng_state=False
)
x = torch.randn(4, 4, device="cuda", requires_grad=True)
y = torch.randn(4, 4, device="cuda", requires_grad=True)
fw_compiler = functools.partial(count_ops, freq=1, op=torch.ops.aten.mm.default)
bw_compiler = functools.partial(
count_ops, freq=3, op=torch.ops.aten.mm.default
) # mm recomputed in the bwd
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
self._validate(fn, backend, x, y)
@requires_cuda
def test_tags_multiple_checkpoints(self):
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y))
def fn(x, y):
x = torch.sin(x)
z = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
x = torch.sin(z)
z = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
return z
x = torch.randn(4, 4, device="cuda", requires_grad=True)
y = torch.randn(4, 4, device="cuda", requires_grad=True)
fw_compiler = functools.partial(count_ops, freq=2, op=torch.ops.aten.mm.default)
bw_compiler = functools.partial(
count_ops, freq=6, op=torch.ops.aten.mm.default
) # mm recomputed in the bwd
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
self._validate(fn, backend, x, y)
@requires_cuda
def test_tags_module(self):
class MockModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(10, 10)
def forward(self, x):
return torch.sigmoid(self.linear(x))
mod = MockModule().cuda()
def fn(x):
return torch.utils.checkpoint.checkpoint(
mod, torch.sin(x), use_reentrant=True
)
x = torch.randn(10, 10, device="cuda", requires_grad=True)
fw_compiler = functools.partial(
count_ops, freq=1, op=torch.ops.aten.sigmoid.default
)
bw_compiler = functools.partial(
count_ops, freq=1, op=torch.ops.aten.sigmoid.default
)
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
self._validate(fn, backend, x)
@requires_cuda
def test_tags_decomps(self):
# Ensures that tags are passed on through decompositions as well
class MockModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(10, 10)
def forward(self, x):
return torch.nn.functional.gelu(self.linear(x))
mod = MockModule().cuda()
def fn(x):
return torch.utils.checkpoint.checkpoint(
mod, torch.sin(x), use_reentrant=True
)
x = torch.randn(10, 10, device="cuda", requires_grad=True)
fw_compiler = functools.partial(
count_ops, freq=1, op=torch.ops.aten.erf.default
)
bw_compiler = functools.partial(
count_ops, freq=1, op=torch.ops.aten.erf.default
)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
decompositions=lambda: import_module(
"torch._inductor.compile_fx"
).select_decomp_table(),
)
self._validate(fn, backend, x)
@requires_cuda
@torch._inductor.config.patch(fallback_random=True)
def test_tags_recomputed_rand(self):
def gn(x, y):
return torch.sigmoid(torch.rand_like(x) * y) * x
def fn(x, y):
x = torch.sin(x)
x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
x = torch.sin(x)
z = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
return z
x = torch.randn(4, 4, device="cuda", requires_grad=True)
y = torch.randn(4, 4, device="cuda", requires_grad=True)
# fw_compiler = functools.partial(count_ops, freq=2, op=torch.ops.aten.mm.default)
# bw_compiler = functools.partial(
# count_ops, freq=6, op=torch.ops.aten.mm.default
# ) # mm recomputed in the bwd
# backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
backend = "inductor"
self._validate(fn, backend, x, y)
@requires_cuda
@torch._inductor.config.patch(fallback_random=True)
def test_tags_rand(self):
def gn(x, y):
x = torch.mm(x, y)
x = torch.mm(x, y)
return x
def fn(x, y):
x = torch.sin(x)
x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
x = torch.sin(x)
# x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
return x
x = torch.randn(4, 4, device="cuda", requires_grad=True)
y = torch.randn(4, 4, device="cuda", requires_grad=True)
# fw_compiler = functools.partial(count_ops, freq=2, op=torch.ops.aten.mm.default)
# bw_compiler = functools.partial(
# count_ops, freq=6, op=torch.ops.aten.mm.default
# ) # mm recomputed in the bwd
# backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
# backend = "aot_eager"
backend = "inductor"
self._validate(fn, backend, x, y)
@requires_cuda
@torch._inductor.config.patch(fallback_random=True)
def test_tags_dropout(self):
# TODO: figure out a way to test the number of inductor_random calls
class MockModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(10, 10)
self.dropout = torch.nn.Dropout(0.2)
def forward(self, x):
return self.dropout(self.linear(x))
mod = MockModule().cuda()
def fn(x):
return torch.utils.checkpoint.checkpoint(mod, x, use_reentrant=True)
x = torch.randn(10, 10, device="cuda", requires_grad=True)
backend = "inductor"
# rand decomps do not have the same numerical results as eager
self._validate(fn, backend, x, skip_check=True)
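# A graph break inside the checkpointed function prevents capturing the
# checkpoint as a single higher-order op; Dynamo falls back and compiles the
# surrounding torch.sin and torch.cos calls as two separate graphs.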
@requires_cuda
def test_fallback(self):
def gn(x, y):
torch._dynamo.graph_break()
a = torch.sigmoid(torch.matmul(x, y))
torch._dynamo.graph_break()
return torch.cos(a)
def fn(x, y):
return torch.cos(checkpoint(gn, torch.sin(x), y, use_reentrant=False))
x = torch.randn(4, 4, requires_grad=True)
y = torch.randn(4, 4, requires_grad=True)
args = (x, y)
backend = "aot_eager"
cnt = CompileCounterWithBackend(backend)
expected = fn(*args)
result = torch.compile(fn, backend=cnt)(*args)
self.assertEqual(result, expected)
# One graph for torch.sin on the input, and the other for torch.cos.
self.assertEqual(cnt.frame_count, 2)
self.assertEqual(cnt.op_count, 2)
self.assertEqual(len(cnt.graphs), 2)
@requires_cuda
def test_kwargs(self):
def gn(x, y, z=None):
a = torch.matmul(x, y)
if z is not None:
return torch.matmul(a, z)
return a
def fn(x, y, z):
return torch.cos(checkpoint(gn, x, y, use_reentrant=False, z=z))
x = torch.randn(4, 4, requires_grad=True)
y = torch.randn(4, 4, requires_grad=True)
z = torch.randn(4, 4, requires_grad=True)
args = (x, y, z)
backend = "aot_eager"
cnt = CompileCounterWithBackend(backend)
expected = fn(*args)
result = torch.compile(fn, backend=cnt)(*args)
self.assertEqual(result, expected)
self.assertEqual(cnt.frame_count, 1)
self.assertEqual(len(cnt.graphs), 1)
wrap_node = find_first_node(cnt.graphs[0], tag_activation_checkpoint)
# one arg for the checkpointed function, and 3 for x, y, z
self.assertEqual(len(wrap_node.args), 4)
body_function = getattr(cnt.graphs[0], wrap_node.args[0].name)
self.assertEqual(op_count(body_function), 2)
@requires_cuda
def test_symints_location(self):
def gn(x, y):
return torch.matmul(x, torch.nn.functional.dropout(y, 0.5))
def fn(x, y):
return torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
backend = "aot_eager"
cnt = CompileCounterWithBackend(backend)
opt_fn = torch.compile(fn, backend=cnt)
x = torch.randn(4, 4, requires_grad=True)
y = torch.randn(4, 4, requires_grad=True)
args = (x, y)
expected = fn(*args)
result = opt_fn(*args)
x = torch.randn(5, 5, requires_grad=True)
y = torch.randn(5, 5, requires_grad=True)
args = (x, y)
expected = fn(*args)
result = opt_fn(*args)
self.assertEqual(result.shape, expected.shape)
self.assertEqual(cnt.frame_count, 2)
self.assertEqual(len(cnt.graphs), 2)
wrap_node = find_first_node(cnt.graphs[0], tag_activation_checkpoint)
self.assertEqual(len(wrap_node.args), 3)
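# The `test_compile_selective_checkpoint_*` tests exercise selective
# activation checkpointing: `context_fn` returns a pair of TorchDispatchModes
# produced by `_pt2_selective_checkpoint_context_fn_gen`, and the policy
# decides per-op whether to save its output in the forward pass or recompute
# it in the backward pass. A minimal policy sketch (illustrative only,
# mirroring `_get_custom_policy` above):
#
#   def policy(mode, func, *args, **kwargs):
#       # save mm outputs, recompute everything else
#       return func == torch.ops.aten.mm.default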
@requires_cuda
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
@unittest.skipIf(
sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
)
@torch._dynamo.config.patch(
"_experimental_support_context_fn_in_torch_utils_checkpoint", True
)
def test_compile_selective_checkpoint_gemm_only(self):
def selective_checkpointing_context_fn():
no_recompute_list = [
torch.ops.aten.mm.default,
]
return _pt2_selective_checkpoint_context_fn_gen(
_get_custom_policy(no_recompute_list=no_recompute_list)
)
def gn(x, y):
return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y
def fn(x, y):
return torch.utils.checkpoint.checkpoint(
gn,
x,
y,
use_reentrant=False,
context_fn=selective_checkpointing_context_fn,
)
x = torch.randn(4, 4, requires_grad=True, device="cuda")
y = torch.randn(4, 4, requires_grad=True, device="cuda")
fw_compiler = functools.partial(
count_ops,
freq=2,
op=torch.ops.aten.mm.default,
)
bw_compiler = functools.partial(
count_ops,
# We would have expected 6 here if selective checkpointing were not enabled:
# 2 recomputed forward matmuls plus 2 backward mm ops per forward matmul,
# i.e. 2 + 2 * 2 = 6.
freq=4,
op=torch.ops.aten.mm.default,
)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
@unittest.skipIf(
sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
)
@torch._dynamo.config.patch(
"_experimental_support_context_fn_in_torch_utils_checkpoint", True
)
def test_compile_selective_checkpoint_tensor_subclass(self):
def selective_checkpointing_context_fn():
no_recompute_list = [
torch.ops.aten.mm.default,
]
return _pt2_selective_checkpoint_context_fn_gen(
_get_custom_policy(no_recompute_list=no_recompute_list)
)
def gn(x, y):
return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y
def fn(x, y):
return torch.utils.checkpoint.checkpoint(
gn,
x,
y,
use_reentrant=False,
context_fn=selective_checkpointing_context_fn,
)
rand_tensor = torch.randn(4, 4, requires_grad=True, device="cuda")
# tensor subclasses as inputs
x = TwoTensor(rand_tensor, rand_tensor.clone())
y = TwoTensor(rand_tensor.clone(), rand_tensor.clone())
fw_compiler = functools.partial(
count_ops,
freq=4,
op=torch.ops.aten.mm.default,
)
bw_compiler = functools.partial(
count_ops,
# We would have expected 12 here if selective checkpointing were not enabled:
# 4 recomputed forward mm ops plus 2 backward mm ops per forward mm,
# i.e. 4 + 2 * 4 = 12.
freq=8,
op=torch.ops.aten.mm.default,
)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
@unittest.skipIf(
sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
)
@torch._dynamo.config.patch(
"_experimental_support_context_fn_in_torch_utils_checkpoint", True
)
def test_compile_selective_checkpoint_custom_rule(self):
def _get_custom_policy(meta):
no_recompute_list = [
torch.ops.aten.mm.default,
]
def _custom_policy(mode, func, *args, **kwargs):
mm_count_key = f"{mode}_mm_count"
if mm_count_key not in meta:
meta[mm_count_key] = 0
if func == torch.ops.aten.mm.default:
meta[mm_count_key] += 1
# Save the output of every op in `no_recompute_list`, except the second mm
# (i.e. hint the partitioner to recompute the second mm in the backward pass).
return func in no_recompute_list and not (
func == torch.ops.aten.mm.default and meta[mm_count_key] == 2
)
return _custom_policy
def selective_checkpointing_context_fn():
meta = {}
return _pt2_selective_checkpoint_context_fn_gen(_get_custom_policy(meta))
def gn(x, y):
return torch.sigmoid(
torch.sigmoid(torch.matmul(torch.matmul(x, y) * y, y) * y)
)
def fn(x, y):
return torch.utils.checkpoint.checkpoint(
gn,
x,
y,
use_reentrant=False,
context_fn=selective_checkpointing_context_fn,
)
x = torch.randn(4, 4, requires_grad=True, device="cuda")
y = torch.randn(4, 4, requires_grad=True, device="cuda")
fw_compiler = functools.partial(
count_ops,
freq=2,
op=torch.ops.aten.mm.default,
)
bw_compiler = functools.partial(
count_ops,
# Q: How do we arrive at 4?
# A: There are 2 matmuls in the forward pass, and each contributes 2 `mm` ops
# to the backward pass, so there are at least 4 `mm` ops in the backward pass.
# It's "at least" because whether the second forward matmul is recomputed in
# the backward pass is up to the partitioner.
freq_ge=4,
op=torch.ops.aten.mm.default,
)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
@unittest.skipIf(
sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
)
@torch._dynamo.config.patch(
"_experimental_support_context_fn_in_torch_utils_checkpoint", True
)
def test_compile_selective_checkpoint_partial_ctx_fn(self):
def selective_checkpointing_context_fn(no_recompute_list):
return _pt2_selective_checkpoint_context_fn_gen(
_get_custom_policy(no_recompute_list=no_recompute_list)
)
def gn(x, y):
return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y
def fn(x, y):
return torch.utils.checkpoint.checkpoint(
gn,
x,
y,
use_reentrant=False,
context_fn=functools.partial(
selective_checkpointing_context_fn, [torch.ops.aten.mm.default]
),
)
x = torch.randn(4, 4, requires_grad=True, device="cuda")
y = torch.randn(4, 4, requires_grad=True, device="cuda")
fw_compiler = functools.partial(
count_ops,
freq=2,
op=torch.ops.aten.mm.default,
)
bw_compiler = functools.partial(
count_ops,
# We would have expected 6 here if selective checkpointing were not enabled:
# 2 recomputed forward matmuls plus 2 backward mm ops per forward matmul,
# i.e. 2 + 2 * 2 = 6.
freq=4,
op=torch.ops.aten.mm.default,
)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
@unittest.skipIf(
sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
)
@torch._dynamo.config.patch(
"_experimental_support_context_fn_in_torch_utils_checkpoint", True
)
def test_compile_selective_checkpoint_outplace_op(self):
def selective_checkpointing_context_fn():
no_recompute_list = [
torch.ops.aten.mm.default,
torch.ops.aten.sigmoid.default,
]
return _pt2_selective_checkpoint_context_fn_gen(
_get_custom_policy(no_recompute_list=no_recompute_list),
)
def gn(x, y):
return torch.sigmoid(torch.selu(torch.matmul(torch.matmul(x, y), y))).relu()
def fn(x, y):
return torch.utils.checkpoint.checkpoint(
gn,
x,
y,
use_reentrant=False,
context_fn=selective_checkpointing_context_fn,
)
x = torch.randn(4, 4, requires_grad=True, device="cuda")
y = torch.randn(4, 4, requires_grad=True, device="cuda")
fw_compiler = functools.partial(
count_ops,
freqs=[2, 1],
ops=[torch.ops.aten.mm.default, torch.ops.aten.sigmoid.default],
)
bw_compiler = functools.partial(
count_ops,
freqs=[4, 0],
ops=[torch.ops.aten.mm.default, torch.ops.aten.sigmoid.default],
)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
@unittest.skipIf(
sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
)
@unittest.skip(
"In-place op support in selective checkpointing + torch.compile "
"requires TorchDispatchMode + torch.compile work to complete"
)
@torch._dynamo.config.patch(
"_experimental_support_context_fn_in_torch_utils_checkpoint", True
)
def test_compile_selective_checkpoint_inplace_op(self):
def selective_checkpointing_context_fn():
no_recompute_list = [
torch.ops.aten.mm.default,
torch.ops.aten.sigmoid.default,
]
return _pt2_selective_checkpoint_context_fn_gen(
_get_custom_policy(no_recompute_list=no_recompute_list)
)
def gn(x, y):
return torch.sigmoid(
torch.selu_(torch.matmul(torch.matmul(x, y), y))
).relu_()
def fn(x, y):
return torch.utils.checkpoint.checkpoint(
gn,
x,
y,
use_reentrant=False,
context_fn=selective_checkpointing_context_fn,
)
x = torch.randn(4, 4, requires_grad=True, device="cuda")
y = torch.randn(4, 4, requires_grad=True, device="cuda")
fw_compiler = functools.partial(
count_ops,
freqs=[2, 1],
ops=[torch.ops.aten.mm.default, torch.ops.aten.sigmoid.default],
)
bw_compiler = functools.partial(
count_ops,
freqs=[4, 0],
ops=[torch.ops.aten.mm.default, torch.ops.aten.sigmoid.default],
)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
@unittest.skipIf(
sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
)
@torch._dynamo.config.patch(
"_experimental_support_context_fn_in_torch_utils_checkpoint", True
)
def test_compile_selective_checkpoint_random_op(self):
for preserve_rng_state in [True, False]:
def selective_checkpointing_context_fn():
no_recompute_list = [
torch.ops.aten.sigmoid.default,
]
return _pt2_selective_checkpoint_context_fn_gen(
_get_custom_policy(no_recompute_list=no_recompute_list)
)
def gn(x):
return torch.sigmoid(torch.dropout(torch.sigmoid(x), p=0.5, train=True))
def fn(x):
return torch.utils.checkpoint.checkpoint(
gn,
x,
use_reentrant=False,
# Regardless of whether `preserve_rng_state` is True or False,
# we will always preserve RNG state when using `torch.compile`.
preserve_rng_state=preserve_rng_state,
context_fn=selective_checkpointing_context_fn,
)
x = torch.randn(4, 4, requires_grad=True, device="cuda")
fw_compiler = functools.partial(
count_ops,
freqs=[2, 1],
ops=[
torch.ops.aten.sigmoid.default,
torch.ops.aten.native_dropout.default,
],
)
bw_compiler = functools.partial(
count_ops,
# NOTE: This unit test expects `dropout` to be recomputed (notice the count for `native_dropout` is 1).
freqs=[0, 1],
ops=[
torch.ops.aten.sigmoid.default,
torch.ops.aten.native_dropout.default,
],
)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
)
# NOTE: when `preserve_rng_state` is False, the gradients will not match
# between torch.compile and eager, because the eager version does not
# preserve RNG state while torch.compile still does.
# Hence, when `preserve_rng_state` is False, we skip the output and gradient
# comparison between torch.compile and eager.
self._validate(fn, backend, x, skip_check=not preserve_rng_state)
self._compare_orig_and_checkpointed_fns(gn, fn, x)
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
@unittest.skipIf(
sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
)
@torch._dynamo.config.patch(
"_experimental_support_context_fn_in_torch_utils_checkpoint", True
)
def test_compile_selective_checkpoint_invalid_context(self):
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y)) * y
def fn(x, y):
return torch.utils.checkpoint.checkpoint(
gn,
x,
y,
use_reentrant=False,
context_fn=_invalid_context_gen,
)
x = torch.randn(4, 4, requires_grad=True)
y = torch.randn(4, 4, requires_grad=True)
fw_compiler = functools.partial(
count_ops,
freq=1,
op=torch.ops.aten.mm.default,
)
bw_compiler = functools.partial(
count_ops,
freq_ge=2,
op=torch.ops.aten.mm.default,
)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
)
with self.assertRaisesRegex(
Exception, "must generate a tuple of two `TorchDispatchMode`s"
):
self._validate(fn, backend, x, y)
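# Checks that reentrant checkpointing composes with autocast and the
# scaled-dot-product efficient-attention op under torch.compile; with the
# same seed, the compiled result should match the eager result.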
@requires_cuda
@skipIfRocm
def test_autocast_flash_attention(self):
def fn(primals_1, primals_2, primals_3):
return torch.ops.aten._scaled_dot_product_efficient_attention.default(
primals_1, primals_2, primals_3, None, True, scale=0.17677669529663687
)[0]
def gn(*args):
return torch.utils.checkpoint.checkpoint(fn, *args, use_reentrant=True)
with torch.cuda.amp.autocast():
x = torch.randn(4, 2, 16, 32, device="cuda", requires_grad=True)
y = torch.randn(4, 2, 16, 32, device="cuda", requires_grad=True)
z = torch.randn(4, 2, 16, 32, device="cuda", requires_grad=True)
args = (x, y, z)
torch.manual_seed(0)
ref = gn(*args)
opt_gn = torch.compile(gn)
torch.manual_seed(0)
res = opt_gn(*args)
self.assertEqual(ref, res)
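# A graph break inside a reentrant-checkpointed module cannot be captured
# with fullgraph=True; this checks the resulting Unsupported error message.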
@requires_cuda
def test_error_msg(self):
class MockModule(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
x = torch.sin(x)
torch._dynamo.graph_break()
x = torch.cos(x)
return x
mod = MockModule().cuda()
def fn(x):
return torch.utils.checkpoint.checkpoint(mod, x, use_reentrant=True)
x = torch.randn(4, 4).cuda()
opt_fn = torch.compile(fn, fullgraph=True)
with self.assertRaisesRegex(
torch._dynamo.exc.Unsupported, "skip function graph_break in file"
):
opt_fn(x)
@requires_cuda
def test_list_inputs(self):
class MockModule(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, ys):
a = torch.sin(x)
b = torch.cos(ys[0])
c = torch.cos(ys[1])
return (x, [b, c])
mod = MockModule().cuda()
def fn(x, ys):
return torch.utils.checkpoint.checkpoint(mod, x, ys, use_reentrant=True)
x = torch.randn(4, 4).cuda()
y = torch.randn(4, 4).cuda()
z = torch.randn(4, 4).cuda()
ref = fn(x, [y, z])
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
res = opt_fn(x, [y, z])
self.assertEqual(ref, res)
@requires_cuda
def test_pattern_matcher(self):
# Check that the sdpa op is recomputed in the backward graph
# tests percolate_tags
@checkpoint_wrapper
def dot_prod_attention(
query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> torch.Tensor:
return (
torch.matmul(query, key.transpose(-2, -1))
.mul(1.0 / math.sqrt(key.shape[-1]))
.softmax(dim=-1)
.matmul(value)
)
def fn(query, key, value):
# Checks that sin is not recomputed in the backward graph
return dot_prod_attention(query.sin(), key, value)
tensor_shape = (4, 2, 16, 32)
dtype = torch.float16
args1 = [
torch.randn(tensor_shape, device="cuda", dtype=dtype, requires_grad=True),
torch.randn(tensor_shape, device="cuda", dtype=dtype, requires_grad=True),
torch.randn(tensor_shape, device="cuda", dtype=dtype, requires_grad=True),
]
# Save the AOT graphs
aot_graphs = []
from torch._inductor import compile_fx
def debug_compile_fx_inner(graph, example_inputs, *args, **kwargs):
aot_graphs.append(graph)
return compile_fx.compile_fx_inner(graph, example_inputs, *args, **kwargs)
backend = functools.partial(
compile_fx.compile_fx, inner_compile=debug_compile_fx_inner
)
opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
opt_fn(*args1).sum().backward()
fwd_graph = aot_graphs[0]
self.assertTrue(
count_ops(
fwd_graph,
[],
freq=1,
op=torch.ops.aten._scaled_dot_product_flash_attention.default,
)
)
bwd_graph = aot_graphs[1]
# Check that sin is not recomputed in the backward graph - checks percolate tags
self.assertTrue(count_ops(bwd_graph, [], freq=0, op=torch.ops.aten.sin.default))
# Check that the sdpa op is recomputed in the backward graph
self.assertTrue(
count_ops(
bwd_graph,
[],
freq=1,
op=torch.ops.aten._scaled_dot_product_flash_attention.default,
)
)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()