[Simple FSDP] Add unit test for torch.compile + reparameterization + SAC (#129641)
This reproduces the error in https://github.com/pytorch/pytorch/issues/129684. The new unit test holds the line on torch.compile + reparameterization + SAC (selective activation checkpointing) continuing to work, paving the way for Tianyu's intern's project.
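For context, the pattern under test looks roughly like the sketch below. This is illustrative only, not the test code: `SigmoidParametrization` and the surrounding scaffolding are made-up names, and it assumes the PyTorch 2.4+ selective-checkpoint APIs.

```python
import torch
import torch._dynamo  # for the config flag the unit test patches
import torch.nn as nn
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

torch._dynamo.config.inline_inbuilt_nn_modules = True  # the unit test patches this flag


def policy(ctx, func, *args, **kwargs):
    # Recompute the cheap pointwise ops in backward; save everything else.
    if func in {torch.ops.aten.mul.Tensor, torch.ops.aten.sigmoid.default}:
        return CheckpointPolicy.MUST_RECOMPUTE
    return CheckpointPolicy.MUST_SAVE


class SigmoidParametrization(nn.Module):  # hypothetical parametrization module
    def forward(self, w):
        # Run the reparameterization under selective activation checkpointing (SAC).
        return checkpoint(
            lambda t: torch.sigmoid(t * t),
            w,
            use_reentrant=False,
            context_fn=lambda: create_selective_checkpoint_contexts(policy),
        )


linear = nn.Linear(16, 16, bias=False)
nn.utils.parametrize.register_parametrization(
    linear, "weight", SigmoidParametrization(), unsafe=True
)
compiled = torch.compile(linear, fullgraph=True)
compiled(torch.randn(8, 16, requires_grad=True)).sum().backward()
```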
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129641
Approved by: https://github.com/tianyu-l
diff --git a/test/dynamo/test_activation_checkpointing.py b/test/dynamo/test_activation_checkpointing.py
index cc8e800..245f905 100644
--- a/test/dynamo/test_activation_checkpointing.py
+++ b/test/dynamo/test_activation_checkpointing.py
@@ -1,4 +1,5 @@
 # Owner(s): ["module: dynamo"]
+import copy
 import functools
 import math
 import unittest  # noqa: F811
@@ -10,6 +11,7 @@
 import torch._dynamo.test_case
 import torch._functorch.config
 import torch.distributed as dist
+import torch.nn as nn
 import torch.utils.checkpoint
 from functorch.compile import min_cut_rematerialization_partition
@@ -1000,6 +1002,98 @@
         ):
             self._validate(fn, backend, x, y)
+    @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
+    def test_compile_selective_checkpoint_parametrization(self):
+        def sac_policy():
+            def _recomp_policy():
+                def _custom_policy(ctx, func, *args, **kwargs):
+                    to_recompute = func in {
+                        torch.ops.aten.mul.Tensor,
+                        torch.ops.aten.sigmoid.default,
+                    }
+                    return (
+                        CheckpointPolicy.MUST_RECOMPUTE
+                        if to_recompute
+                        else CheckpointPolicy.MUST_SAVE
+                    )
+
+                return _custom_policy
+
+            return create_selective_checkpoint_contexts(_recomp_policy())
+
+        class Parametrization(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def parametrization(self, x):
+                return torch.sigmoid(torch.mul(x, x))
+
+            def forward(self, x):
+                return checkpoint(
+                    self.parametrization, x, use_reentrant=False, context_fn=sac_policy
+                )
+
+        def apply_parametrization(model):
+            modules = list(model.modules())
+
+            for mod in modules:
+                params_dict = dict(mod.named_parameters(recurse=False))
+                for p_name, p in params_dict.items():
+                    mod.register_parameter(p_name, nn.Parameter(p))
+                    nn.utils.parametrize.register_parametrization(
+                        mod, p_name, Parametrization(), unsafe=True
+                    )
+
+            return model
+
+        class MLPModule(nn.Module):
+            def __init__(self):
+                super().__init__()
+                torch.manual_seed(5)
+                self.net1 = nn.Linear(16, 16, bias=False)
+
+            def forward(self, x):
+                return self.net1(x)
+
+            def reset_parameters(self):
+                self.net1.reset_parameters()
+
+        fw_compiler = functools.partial(
+            count_ops,
+            freqs=[1, 1],  # 1 mul and 1 sigmoid in the forward graph
+            ops=[torch.ops.aten.mul.Tensor, torch.ops.aten.sigmoid.default],
+        )
+        bw_compiler = functools.partial(
+            count_ops,
+            freqs=[
+                2,  # 1 from mul recompute, 1 from mul backward
+                1,  # 1 from sigmoid recompute
+            ],
+            ops=[torch.ops.aten.mul.Tensor, torch.ops.aten.sigmoid.default],
+        )
+
+        backend = aot_autograd(
+            fw_compiler=fw_compiler,
+            bw_compiler=bw_compiler,
+            partition_fn=min_cut_rematerialization_partition,
+        )
+
+        model = MLPModule()
+        model = apply_parametrization(model)
+        model_compiled = torch.compile(
+            copy.deepcopy(model), backend=backend, fullgraph=True
+        )
+        input = torch.randn(8, 16, requires_grad=True)
+        input_compiled = copy.deepcopy(input)
+
+        out = model(input)
+        out.sum().backward()
+        out_compiled = model_compiled(input_compiled)
+        out_compiled.sum().backward()
+
+        self.assertEqual(out, out_compiled)
+        self.assertEqual(input.grad, input_compiled.grad)
+
     @requires_cuda
     @skipIfRocm
     def test_autocast_flash_attention(self):