[Reland][functorch] test for compiling functorch transforms (#100718)
Original PR over at #100151. It was reverted due to internal test failures.
I have since fixed the internal build system.
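
For context, the pattern under test wraps a functorch transform with
torch._dynamo.allow_in_graph before handing it to torch.compile, so that
dynamo keeps the transform call in the graph instead of tracing into it.
A minimal sketch of that pattern; the grad(torch.sin) example here is
illustrative only and not part of this PR:

    import torch
    from functools import wraps
    from torch._dynamo import allow_in_graph
    from torch.func import grad

    def traceable(f):
        # Mark f so dynamo places the call in the FX graph as-is
        # instead of tracing into the transform's Python internals.
        f = allow_in_graph(f)

        @wraps(f)
        def wrapper(*args, **kwargs):
            return f(*args, **kwargs)

        return wrapper

    # d/dx sin(x) == cos(x); the compiled transform should agree.
    opt_grad_sin = torch.compile(traceable(grad(torch.sin)))
    x = torch.randn([])
    assert torch.allclose(opt_grad_sin(x), torch.cos(x))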
Differential Revision: [D45608453](https://our.internmc.facebook.com/intern/diff/D45608453)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100718
Approved by: https://github.com/kshitij12345, https://github.com/atalman
diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py
index ce6ff33..be88114 100644
--- a/test/functorch/test_eager_transforms.py
+++ b/test/functorch/test_eager_transforms.py
@@ -9,7 +9,7 @@
import copy
from torch.testing._internal.common_utils import (
TestCase, run_tests, parametrize, subtest, instantiate_parametrized_tests,
- IS_FBCODE, freeze_rng_state, skipIfTorchDynamo,
+ IS_FBCODE, freeze_rng_state, skipIfTorchDynamo, IS_WINDOWS
)
import torch
import torch.nn as nn
@@ -20,10 +20,12 @@
import unittest
import warnings
import math
+from functools import wraps
from torch.testing._internal.common_device_type import instantiate_device_type_tests, onlyCPU, dtypes, onlyCUDA
from torch.testing._internal.common_dtype import get_all_fp_dtypes
-from torch.testing._internal.common_cuda import with_tf32_off
+from torch.testing._internal.common_cuda import with_tf32_off, SM70OrLater, TEST_CUDA
from torch.testing import make_tensor
+from torch._dynamo import allow_in_graph
from torch._subclasses.fake_tensor import FakeTensorMode
from functools import partial
from functorch.experimental import replace_all_batch_norm_modules_
@@ -44,6 +46,7 @@
from torch._functorch.utils import enable_single_level_autograd_function
import torch.autograd.forward_ad as fwAD
from torch.func import functional_call, stack_module_state, linearize
+from common_utils import expectedFailureIf
# NB: numpy is a testing dependency!
import numpy as np
@@ -4712,6 +4715,54 @@
params = ({'weight': torch.zeros(1, 1)}, {'bias': torch.ones(1)})
functional_call(mod, params, x)
+def traceable(f):
+    # allow_in_graph tells dynamo to put the call to f into the FX graph
+    # as-is, rather than tracing into the transform's Python internals.
+    f = allow_in_graph(f)
+
+    @wraps(f)
+    def wrapper(*args, **kwargs):
+        return f(*args, **kwargs)
+
+    return wrapper
+
+
+class TestCompileTransforms(TestCase):
+    # torch.compile is not supported on Windows.
+    # Triton only supports GPUs with SM70 or later.
+    @expectedFailureIf(IS_WINDOWS or (TEST_CUDA and not SM70OrLater))
+    def test_compile_vmap_hessian(self, device):
+        # The model and inputs are a smaller version of the code in the
+        # benchmark repo:
+        # https://github.com/pytorch/benchmark/blob/main/userbenchmark/functorch/vmap_hessian_fc.py
+        D = 2
+        B = 4
+
+        x = torch.randn(B, D, device=device)
+
+        model = nn.Sequential(nn.Linear(D, D), nn.ReLU()).to(device)
+
+        params_and_buffers = (dict(model.named_parameters()), dict(model.named_buffers()))
+
+        def predict(params_and_buffers, x):
+            out = torch.func.functional_call(model, params_and_buffers, x)
+            # Return the output twice: the second copy is the aux consumed
+            # by has_aux=True below.
+            return out, out
+
+        # Per-sample Hessian via forward-over-reverse: jacfwd(jacrev(...)),
+        # batched over x with vmap.
+        fn = vmap(
+            jacfwd(jacrev(predict, argnums=1, has_aux=True), argnums=1, has_aux=True),
+            in_dims=(None, 0),
+        )
+
+        expected = fn(params_and_buffers, x)
+
+        opt_fn = torch.compile(traceable(fn))
+        actual = opt_fn(params_and_buffers, x)
+        self.assertEqual(actual, expected)
+
only_for = ("cpu", "cuda")
instantiate_device_type_tests(
@@ -4787,6 +4838,11 @@
instantiate_parametrized_tests(
TestMakeFunctional,
)
+instantiate_device_type_tests(
+    TestCompileTransforms,
+    globals(),
+    only_for=only_for,
+)
if __name__ == '__main__':
run_tests()