[Simple FSDP] Add unit test for torch.compile + reparameterization + SAC (#129641)

This reproduces the error in https://github.com/pytorch/pytorch/issues/129684. Adding a unit test so that we hold the line on torch.compile + reparameterization + SAC continuing to work, which paves the way for Tianyu's intern's project.
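
For reference, below is a minimal standalone sketch (not part of this PR) of the pattern the new test exercises: a weight parametrization whose computation runs under selective activation checkpointing (SAC), then compiled with torch.compile. It assumes a PyTorch build that ships `CheckpointPolicy` / `create_selective_checkpoint_contexts`; the `SigmoidSquare` module, `_transform` helper, and the `aot_eager` backend are illustrative choices, not what the test uses (the test uses an op-counting AOT backend and patches `torch._dynamo.config.inline_inbuilt_nn_modules=True`, which older versions may also need here).

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)


def sac_policy():
    # Recompute mul/sigmoid during backward; save everything else.
    def _custom_policy(ctx, func, *args, **kwargs):
        to_recompute = func in {
            torch.ops.aten.mul.Tensor,
            torch.ops.aten.sigmoid.default,
        }
        return (
            CheckpointPolicy.MUST_RECOMPUTE
            if to_recompute
            else CheckpointPolicy.MUST_SAVE
        )

    return create_selective_checkpoint_contexts(_custom_policy)


class SigmoidSquare(nn.Module):
    # Illustrative parametrization: weight -> sigmoid(weight * weight),
    # with the computation wrapped in non-reentrant SAC checkpointing.
    def _transform(self, w):
        return torch.sigmoid(torch.mul(w, w))

    def forward(self, w):
        return checkpoint(
            self._transform, w, use_reentrant=False, context_fn=sac_policy
        )


model = nn.Linear(16, 16, bias=False)
nn.utils.parametrize.register_parametrization(
    model, "weight", SigmoidSquare(), unsafe=True
)
compiled = torch.compile(model, backend="aot_eager", fullgraph=True)

x = torch.randn(8, 16, requires_grad=True)
compiled(x).sum().backward()
```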

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129641
Approved by: https://github.com/tianyu-l
diff --git a/test/dynamo/test_activation_checkpointing.py b/test/dynamo/test_activation_checkpointing.py
index cc8e800..245f905 100644
--- a/test/dynamo/test_activation_checkpointing.py
+++ b/test/dynamo/test_activation_checkpointing.py
@@ -1,4 +1,5 @@
 # Owner(s): ["module: dynamo"]
+import copy
 import functools
 import math
 import unittest  # noqa: F811
@@ -10,6 +11,7 @@
 import torch._dynamo.test_case
 import torch._functorch.config
 import torch.distributed as dist
+import torch.nn as nn
 import torch.utils.checkpoint
 
 from functorch.compile import min_cut_rematerialization_partition
@@ -1000,6 +1002,98 @@
         ):
             self._validate(fn, backend, x, y)
 
+    @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
+    def test_compile_selective_checkpoint_parametrization(self):
+        def sac_policy():
+            def _recomp_policy():
+                def _custom_policy(ctx, func, *args, **kwargs):
+                    to_recompute = func in {
+                        torch.ops.aten.mul.Tensor,
+                        torch.ops.aten.sigmoid.default,
+                    }
+                    return (
+                        CheckpointPolicy.MUST_RECOMPUTE
+                        if to_recompute
+                        else CheckpointPolicy.MUST_SAVE
+                    )
+
+                return _custom_policy
+
+            return create_selective_checkpoint_contexts(_recomp_policy())
+
+        class Parametrization(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def parametrization(self, x):
+                return torch.sigmoid(torch.mul(x, x))
+
+            def forward(self, x):
+                return checkpoint(
+                    self.parametrization, x, use_reentrant=False, context_fn=sac_policy
+                )
+
+        def apply_parametrization(model):
+            modules = list(model.modules())
+
+            for mod in modules:
+                params_dict = dict(mod.named_parameters(recurse=False))
+                for p_name, p in params_dict.items():
+                    mod.register_parameter(p_name, nn.Parameter(p))
+                    nn.utils.parametrize.register_parametrization(
+                        mod, p_name, Parametrization(), unsafe=True
+                    )
+
+            return model
+
+        class MLPModule(nn.Module):
+            def __init__(self):
+                super().__init__()
+                torch.manual_seed(5)
+                self.net1 = nn.Linear(16, 16, bias=False)
+
+            def forward(self, x):
+                return self.net1(x)
+
+            def reset_parameters(self):
+                self.net1.reset_parameters()
+
+        fw_compiler = functools.partial(
+            count_ops,
+            freqs=[1, 1],
+            ops=[torch.ops.aten.mul.Tensor, torch.ops.aten.sigmoid.default],
+        )
+        bw_compiler = functools.partial(
+            count_ops,
+            freqs=[
+                2,  # 1 from mul recompute, 1 from mul backward
+                1,
+            ],
+            ops=[torch.ops.aten.mul.Tensor, torch.ops.aten.sigmoid.default],
+        )
+
+        backend = aot_autograd(
+            fw_compiler=fw_compiler,
+            bw_compiler=bw_compiler,
+            partition_fn=min_cut_rematerialization_partition,
+        )
+
+        model = MLPModule()
+        model = apply_parametrization(model)
+        model_compiled = torch.compile(
+            copy.deepcopy(model), backend=backend, fullgraph=True
+        )
+        input = torch.randn(8, 16, requires_grad=True)
+        input_compiled = copy.deepcopy(input)
+
+        out = model(input)
+        out.sum().backward()
+        out_compiled = model_compiled(input_compiled)
+        out_compiled.sum().backward()
+
+        self.assertEqual(out, out_compiled)
+        self.assertEqual(input.grad, input_compiled.grad)
+
     @requires_cuda
     @skipIfRocm
     def test_autocast_flash_attention(self):