[Reland][functorch] test for compiling functorch transforms (#100718)

Original PR over at #100151. It was reverted due to internal test failures; I have since fixed the internal build system.

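For context, the new test exercises compiling a functorch transform end-to-end. Here is a minimal sketch of the pattern (the `traceable` helper mirrors the one added in this diff; the function `f` and the grad/vmap composition are illustrative, not part of the PR):

```python
from functools import wraps

import torch
from torch._dynamo import allow_in_graph
from torch.func import grad, vmap


def traceable(fn):
    # Mark fn as allowed in the graph, then hand dynamo a plain Python
    # frame it can compile.
    fn = allow_in_graph(fn)

    @wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)

    return wrapper


def f(x):
    return x.sin().sum()


per_sample_grads = vmap(grad(f))  # a composed functorch transform
compiled = torch.compile(traceable(per_sample_grads))

x = torch.randn(4, 3)
assert torch.allclose(compiled(x), per_sample_grads(x))
```
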
Differential Revision: [D45608453](https://our.internmc.facebook.com/intern/diff/D45608453)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100718
Approved by: https://github.com/kshitij12345, https://github.com/atalman
diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py
index ce6ff33..be88114 100644
--- a/test/functorch/test_eager_transforms.py
+++ b/test/functorch/test_eager_transforms.py
@@ -9,7 +9,7 @@
 import copy
 from torch.testing._internal.common_utils import (
     TestCase, run_tests, parametrize, subtest, instantiate_parametrized_tests,
-    IS_FBCODE, freeze_rng_state, skipIfTorchDynamo,
+    IS_FBCODE, freeze_rng_state, skipIfTorchDynamo, IS_WINDOWS
 )
 import torch
 import torch.nn as nn
@@ -20,10 +20,12 @@
 import unittest
 import warnings
 import math
+from functools import wraps
 from torch.testing._internal.common_device_type import instantiate_device_type_tests, onlyCPU, dtypes, onlyCUDA
 from torch.testing._internal.common_dtype import get_all_fp_dtypes
-from torch.testing._internal.common_cuda import with_tf32_off
+from torch.testing._internal.common_cuda import with_tf32_off, SM70OrLater, TEST_CUDA
 from torch.testing import make_tensor
+from torch._dynamo import allow_in_graph
 from torch._subclasses.fake_tensor import FakeTensorMode
 from functools import partial
 from functorch.experimental import replace_all_batch_norm_modules_
@@ -44,6 +46,7 @@
 from torch._functorch.utils import enable_single_level_autograd_function
 import torch.autograd.forward_ad as fwAD
 from torch.func import functional_call, stack_module_state, linearize
+from common_utils import expectedFailureIf
 
 # NB: numpy is a testing dependency!
 import numpy as np
@@ -4712,6 +4715,55 @@
         params = ({'weight': torch.zeros(1, 1)}, {'bias': torch.ones(1)})
         functional_call(mod, params, x)
 
+def traceable(f):
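+    # `allow_in_graph` tells TorchDynamo to place calls to `f` directly in
+    # the FX graph instead of tracing into it; the plain Python wrapper
+    # gives dynamo an outer frame that it can compile.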
+    f = allow_in_graph(f)
+
+    @wraps(f)
+    def wrapper(*args, **kwargs):
+        return f(*args, **kwargs)
+
+    return wrapper
+
+
+class TestCompileTransforms(TestCase):
+    # torch.compile is not supported on Windows.
+    # Triton only supports GPUs with SM70 or later.
+    @expectedFailureIf(IS_WINDOWS or (TEST_CUDA and not SM70OrLater))
+    def test_compile_vmap_hessian(self, device):
+        # The model and inputs are a smaller version of the code in the
+        # benchmark repo:
+        # https://github.com/pytorch/benchmark/blob/main/userbenchmark/functorch/vmap_hessian_fc.py
+        D = 2
+        B = 4
+
+        x = torch.randn(B, D, device=device)
+
+        model = nn.Sequential(nn.Linear(D, D), nn.ReLU()).to(device)
+
+        params_and_buffers = (dict(model.named_parameters()), dict(model.named_buffers()))
+
+        def predict(params_and_buffers, x):
+            out = torch.func.functional_call(model, params_and_buffers, x)
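+            # Return `out` twice: jacrev/jacfwd below are called with
+            # has_aux=True, so the second value rides along as the aux output.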
+            return out, out
+
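+        # Per-sample Hessian of the output w.r.t. the input x:
+        # forward-over-reverse (jacfwd of jacrev), vmapped over the batch.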
+        fn = vmap(
+            jacfwd(jacrev(predict, argnums=1, has_aux=True), argnums=1, has_aux=True),
+            in_dims=(None, 0),
+        )
+
+        expected = fn(params_and_buffers, x)
+
+        opt_fn = torch.compile(traceable(fn))
+        actual = opt_fn(params_and_buffers, x)
+        self.assertEqual(actual, expected)
+
 
 only_for = ("cpu", "cuda")
 instantiate_device_type_tests(
@@ -4787,6 +4832,11 @@
 instantiate_parametrized_tests(
     TestMakeFunctional,
 )
+instantiate_device_type_tests(
+    TestCompileTransforms,
+    globals(),
+    only_for=only_for,
+)
 
 if __name__ == '__main__':
     run_tests()