Revert "[inductor] Move things into torch/testing/_internal/inductor_utils.py (#113275)"

This reverts commit c967dc526a40f4b15003f9c99383acabe66367a6.

Reverted https://github.com/pytorch/pytorch/pull/113275 on behalf of https://github.com/PaliC because the diff this PR is stacked on top of appears to be causing inductor failures internally ([comment](https://github.com/pytorch/pytorch/pull/113275#issuecomment-1805131017))
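
For reference, the main change this revert restores is the in-tree import dance each inductor test file uses to pull helpers from test_torchinductor.py instead of torch/testing/_internal/inductor_utils.py. A minimal sketch of that pattern (illustrative only, not part of the patch; check_model stands in for whichever helpers a given file needs):

    import sys
    import unittest

    try:
        try:
            # package-relative import, used when run as `python -m inductor.test_foo`
            from .test_torchinductor import check_model
        except ImportError:
            # plain import, used when run as a standalone script from test/inductor/
            from test_torchinductor import check_model
    except (unittest.SkipTest, ImportError) as e:
        sys.stderr.write(f"{type(e)}: {e}\n")
        if __name__ == "__main__":
            sys.exit(0)  # exit quietly rather than erroring when run directly
        raise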
diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py
index 9a43aa4..0851845 100644
--- a/test/inductor/test_aot_inductor.py
+++ b/test/inductor/test_aot_inductor.py
@@ -25,13 +25,8 @@
     TEST_WITH_ROCM,
     TestCase,
 )
-from torch.testing._internal.inductor_utils import (
-    copy_tests,
-    HAS_CUDA,
-    requires_cuda,
-    requires_multigpu,
-    TestFailure,
-)
+
+from torch.testing._internal.triton_utils import HAS_CUDA, requires_cuda
 from torch.utils import _pytree as pytree
 
 if HAS_CUDA:
@@ -50,6 +45,16 @@
         sys.exit(0)
     raise unittest.SkipTest("requires sympy/functorch/filelock")
 
+try:
+    try:
+        from .test_torchinductor import copy_tests, requires_multigpu, TestFailure
+    except ImportError:
+        from test_torchinductor import copy_tests, requires_multigpu, TestFailure
+except (unittest.SkipTest, ImportError) as e:
+    if __name__ == "__main__":
+        sys.exit(0)
+    raise
+
 
 class AOTInductorModelRunner:
     @classmethod
diff --git a/test/inductor/test_benchmark_fusion.py b/test/inductor/test_benchmark_fusion.py
index 1348c16..237be76 100644
--- a/test/inductor/test_benchmark_fusion.py
+++ b/test/inductor/test_benchmark_fusion.py
@@ -32,11 +32,7 @@
         sys.exit(0)
     raise unittest.SkipTest("requires sympy/functorch/filelock")
 
-from torch.testing._internal.inductor_utils import (
-    check_model,
-    check_model_cuda,
-    copy_tests,
-)
+from inductor.test_torchinductor import check_model, check_model_cuda, copy_tests
 
 
 class TestCase(TorchTestCase):
diff --git a/test/inductor/test_binary_folding.py b/test/inductor/test_binary_folding.py
index 68d5ef0..136f07a 100644
--- a/test/inductor/test_binary_folding.py
+++ b/test/inductor/test_binary_folding.py
@@ -2,22 +2,20 @@
 import functools
 import importlib
 import itertools
+import os
 import sys
 import unittest
 
 import torch
 from torch import nn
-from torch._dynamo.testing import load_test_module
 from torch._inductor import config as inductor_config
 from torch.testing._internal.common_cuda import TEST_CUDNN
+
+# Make the helper files in test/ importable
+pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+sys.path.append(pytorch_test_dir)
+
 from torch.testing._internal.common_utils import IS_CI, IS_WINDOWS, TEST_WITH_ASAN
-from torch.testing._internal.inductor_utils import (
-    check_model,
-    check_model_cuda,
-    copy_tests,
-    HAS_CPU,
-    HAS_CUDA,
-)
 
 if IS_WINDOWS and IS_CI:
     sys.stderr.write(
@@ -27,11 +25,13 @@
         sys.exit(0)
     raise unittest.SkipTest("requires sympy/functorch/filelock")
 
-TestCase = load_test_module(__file__, "inductor.test_inductor_freezing").TestCase
+from inductor.test_inductor_freezing import TestCase
+from inductor.test_torchinductor import check_model, check_model_cuda, copy_tests
 
 importlib.import_module("functorch")
 importlib.import_module("filelock")
 
+from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
 
 aten = torch.ops.aten
 
diff --git a/test/inductor/test_compiled_optimizers.py b/test/inductor/test_compiled_optimizers.py
index 30a5586..2f0e455 100644
--- a/test/inductor/test_compiled_optimizers.py
+++ b/test/inductor/test_compiled_optimizers.py
@@ -1,4 +1,7 @@
 # Owner(s): ["module: inductor"]
+
+import sys
+import unittest
 import weakref
 
 from copy import deepcopy
@@ -12,16 +15,21 @@
 
 from torch.testing._internal.common_utils import TEST_WITH_ROCM, TestCase
 
-from torch.testing._internal.inductor_utils import (
-    check_model,
-    check_model_cuda,
-    HAS_CPU,
-    HAS_CUDA,
-    requires_cuda,
-)
+from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
 
 aten = torch.ops.aten
 
+try:
+    try:
+        from .test_torchinductor import check_model, check_model_cuda, requires_cuda
+    except ImportError:
+        from test_torchinductor import check_model, check_model_cuda, requires_cuda
+except (unittest.SkipTest, ImportError) as e:
+    sys.stderr.write(f"{type(e)}: {e}\n")
+    if __name__ == "__main__":
+        sys.exit(0)
+    raise
+
 
 def compile_opt(opt_compiled, closure=None):
     # run the patcher so that step has the expected structure
diff --git a/test/inductor/test_cpp_wrapper.py b/test/inductor/test_cpp_wrapper.py
index 78dbdca..d1d9623 100644
--- a/test/inductor/test_cpp_wrapper.py
+++ b/test/inductor/test_cpp_wrapper.py
@@ -1,8 +1,9 @@
 # Owner(s): ["module: inductor"]
+import sys
+import unittest
 from typing import NamedTuple
 
 import torch
-from torch._dynamo.testing import load_test_module
 from torch._inductor import config
 from torch.testing._internal.common_utils import (
     IS_MACOS,
@@ -11,24 +12,32 @@
     TEST_WITH_ROCM,
     TestCase as TorchTestCase,
 )
-from torch.testing._internal.inductor_utils import (
-    HAS_CPU,
-    HAS_CUDA,
-    run_and_get_cpp_code,
-    TestFailure,
-)
+from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
 
-test_cpu_repro = load_test_module(__file__, "inductor.test_cpu_repro")
-test_foreach = load_test_module(__file__, "inductor.test_foreach")
-test_mkldnn_pattern_matcher = load_test_module(
-    __file__, "inductor.test_mkldnn_pattern_matcher"
-)
-test_pattern_matcher = load_test_module(__file__, "inductor.test_pattern_matcher")
-test_select_algorithm = load_test_module(__file__, "inductor.test_select_algorithm")
-test_torchinductor = load_test_module(__file__, "inductor.test_torchinductor")
-test_torchinductor_dynamic_shapes = load_test_module(
-    __file__, "inductor.test_torchinductor_dynamic_shapes"
-)
+
+try:
+    try:
+        from . import (
+            test_cpu_repro,
+            test_foreach,
+            test_mkldnn_pattern_matcher,
+            test_pattern_matcher,
+            test_select_algorithm,
+            test_torchinductor,
+            test_torchinductor_dynamic_shapes,
+        )
+    except ImportError:
+        import test_cpu_repro
+        import test_foreach
+        import test_mkldnn_pattern_matcher
+        import test_pattern_matcher
+        import test_select_algorithm
+        import test_torchinductor
+        import test_torchinductor_dynamic_shapes
+except unittest.SkipTest:
+    if __name__ == "__main__":
+        sys.exit(0)
+    raise
 
 
 RUN_CPU = HAS_CPU and not torch.backends.mps.is_available() and not IS_MACOS
@@ -61,21 +70,25 @@
 
 test_failures_cpp_wrapper = {
     # conv2d will fallback for dynamic shapes; the fallback path is not yet supported
-    "test_conv2d_unary_cpu_dynamic_shapes": TestFailure(("cpp_wrapper",), is_skip=True),
-    "test_conv2d_binary_inplace_fusion_failed_cpu_dynamic_shapes": TestFailure(
+    "test_conv2d_unary_cpu_dynamic_shapes": test_torchinductor.TestFailure(
         ("cpp_wrapper",), is_skip=True
     ),
-    "test_conv2d_binary_inplace_fusion_pass_cpu_dynamic_shapes": TestFailure(
+    "test_conv2d_binary_inplace_fusion_failed_cpu_dynamic_shapes": test_torchinductor.TestFailure(
+        ("cpp_wrapper",), is_skip=True
+    ),
+    "test_conv2d_binary_inplace_fusion_pass_cpu_dynamic_shapes": test_torchinductor.TestFailure(
         ("cpp_wrapper",), is_skip=True
     ),
     # aten._native_multi_head_attention.default is not yet supported for dynamic shapes
-    "test_multihead_attention_cpu_dynamic_shapes": TestFailure(
+    "test_multihead_attention_cpu_dynamic_shapes": test_torchinductor.TestFailure(
         ("cpp_wrapper",), is_skip=True
     ),
 }
 
 test_failures_cuda_wrapper = {
-    "test_mm_plus_mm2_dynamic_shapes": TestFailure(("cuda_wrapper",), is_skip=True),
+    "test_mm_plus_mm2_dynamic_shapes": test_torchinductor.TestFailure(
+        ("cuda_wrapper",), is_skip=True
+    ),
 }
 
 
@@ -91,7 +104,9 @@
         tests.setUpClass()
         tests.setUp()
         try:
-            _, code = run_and_get_cpp_code(func, *func_inputs if func_inputs else [])
+            _, code = test_torchinductor.run_and_get_cpp_code(
+                func, *func_inputs if func_inputs else []
+            )
             self.assertEqual("CppWrapperCodeCache" in code, True)
         finally:
             tests.tearDown()
diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py
index 899a259..583a708 100644
--- a/test/inductor/test_cpu_repro.py
+++ b/test/inductor/test_cpu_repro.py
@@ -35,20 +35,28 @@
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.nn import functional as F
 from torch.testing._internal.common_utils import IS_MACOS, slowTest
-from torch.testing._internal.inductor_utils import (
-    check_model,
-    run_and_get_cpp_code,
-    TestCase,
-    vec_dtypes,
-)
 from torch.utils._python_dispatch import TorchDispatchMode
 
+try:
+    try:
+        from . import test_torchinductor
+    except ImportError:
+        import test_torchinductor
+except unittest.SkipTest:
+    if __name__ == "__main__":
+        sys.exit(0)
+    raise
 
+
+vec_dtypes = test_torchinductor.vec_dtypes
 _lowp_fp_dtypes = (
     torch.bfloat16,
     torch.float16,
 )
+run_and_get_cpp_code = test_torchinductor.run_and_get_cpp_code
+TestCase = test_torchinductor.TestCase
 aten = torch.ops.aten
+check_model = test_torchinductor.check_model
 
 
 class LstmModule(torch.nn.Module):
diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py
index 5dac564..c38ffa5 100644
--- a/test/inductor/test_cuda_repro.py
+++ b/test/inductor/test_cuda_repro.py
@@ -22,16 +22,27 @@
     skipIfRocm,
     TEST_WITH_ASAN,
 )
-from torch.testing._internal.inductor_utils import check_model_cuda, TestCase, ToTuple
 
 try:
-    import triton
-    from triton import language as tl
-except ImportError:
+    try:
+        import triton
+        from triton import language as tl
+    except ImportError:
+        raise unittest.SkipTest("requires triton")  # noqa: TRY200
+
+    try:
+        from . import test_torchinductor
+    except ImportError:
+        import test_torchinductor
+except unittest.SkipTest:
     if __name__ == "__main__":
         sys.exit(0)
     raise
 
+
+TestCase = test_torchinductor.TestCase
+ToTuple = test_torchinductor.ToTuple
+check_model_cuda = test_torchinductor.check_model_cuda
 aten = torch.ops.aten
 
 
diff --git a/test/inductor/test_debug_trace.py b/test/inductor/test_debug_trace.py
index a7892fa..50ea00b 100644
--- a/test/inductor/test_debug_trace.py
+++ b/test/inductor/test_debug_trace.py
@@ -1,16 +1,33 @@
 # Owner(s): ["module: inductor"]
 import logging
+import os
 import pathlib
 import re
 import shutil
+import sys
+import unittest
 
 import torch
 from torch._inductor import config, test_operators
-from torch.testing._internal.inductor_utils import filesize, TestCase
+
+try:
+    try:
+        from . import test_torchinductor
+    except ImportError:
+        import test_torchinductor
+except unittest.SkipTest:
+    if __name__ == "__main__":
+        sys.exit(0)
+    raise
+
+
+def filesize(filename: pathlib.Path):
+    assert filename.exists(), f"{filename} is missing"
+    return os.stat(filename).st_size
 
 
 @config.patch("trace.enabled", True)
-class TestDebugTrace(TestCase):
+class TestDebugTrace(test_torchinductor.TestCase):
     def test_debug_trace(self):
         @torch.compile
         def fn(a, b):
diff --git a/test/inductor/test_efficient_conv_bn_eval.py b/test/inductor/test_efficient_conv_bn_eval.py
index 9c2529f..fdb879f 100644
--- a/test/inductor/test_efficient_conv_bn_eval.py
+++ b/test/inductor/test_efficient_conv_bn_eval.py
@@ -1,19 +1,25 @@
 # Owner(s): ["module: inductor"]
 import copy
+import importlib
 import itertools
+import os
 import sys
 import unittest
 
 import torch
 from torch import nn
 
+# Make the helper files in test/ importable
+pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+sys.path.append(pytorch_test_dir)
+
 from torch._dynamo.test_case import TestCase
 from torch._dynamo.utils import counters
 from torch._inductor import config as inductor_config
 
 from torch.testing._internal.common_utils import IS_CI, IS_WINDOWS, TEST_WITH_ASAN
 
-from torch.testing._internal.inductor_utils import copy_tests, HAS_CPU, HAS_CUDA
+from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
 
 if IS_WINDOWS and IS_CI:
     sys.stderr.write(
@@ -23,6 +29,11 @@
         sys.exit(0)
     raise unittest.SkipTest("requires sympy/functorch/filelock")
 
+importlib.import_module("functorch")
+importlib.import_module("filelock")
+
+from inductor.test_torchinductor import copy_tests
+
 
 class ConvOp(nn.Module):
     expected_optimization_count = 1
diff --git a/test/inductor/test_extension_backend.py b/test/inductor/test_extension_backend.py
index 549dd42..54f0209 100644
--- a/test/inductor/test_extension_backend.py
+++ b/test/inductor/test_extension_backend.py
@@ -27,7 +27,20 @@
     register_backend_for_device,
 )
 from torch.testing._internal.common_utils import IS_FBCODE, IS_MACOS
-from torch.testing._internal.inductor_utils import run_and_get_cpp_code, TestCase
+
+try:
+    try:
+        from . import test_torchinductor
+    except ImportError:
+        import test_torchinductor
+except unittest.SkipTest:
+    if __name__ == "__main__":
+        sys.exit(0)
+    raise
+
+
+run_and_get_cpp_code = test_torchinductor.run_and_get_cpp_code
+TestCase = test_torchinductor.TestCase
 
 
 def remove_build_path():
diff --git a/test/inductor/test_foreach.py b/test/inductor/test_foreach.py
index 8998dcb..3e141e8 100644
--- a/test/inductor/test_foreach.py
+++ b/test/inductor/test_foreach.py
@@ -1,5 +1,6 @@
 # Owner(s): ["module: inductor"]
 
+import sys
 import unittest
 
 import torch
@@ -14,16 +15,21 @@
     TestCase,
 )
 
-from torch.testing._internal.inductor_utils import (
-    check_model,
-    check_model_cuda,
-    HAS_CPU,
-    HAS_CUDA,
-    requires_cuda,
-)
+from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
 
 aten = torch.ops.aten
 
+try:
+    try:
+        from .test_torchinductor import check_model, check_model_cuda, requires_cuda
+    except ImportError:
+        from test_torchinductor import check_model, check_model_cuda, requires_cuda
+except (unittest.SkipTest, ImportError) as e:
+    sys.stderr.write(f"{type(e)}: {e}\n")
+    if __name__ == "__main__":
+        sys.exit(0)
+    raise
+
 
 bin_ops_under_test = [
     torch._foreach_add,
diff --git a/test/inductor/test_group_batch_fusion.py b/test/inductor/test_group_batch_fusion.py
index 7117d1f..2637eaf 100644
--- a/test/inductor/test_group_batch_fusion.py
+++ b/test/inductor/test_group_batch_fusion.py
@@ -1,12 +1,13 @@
 # Owner(s): ["module: inductor"]
 
+import functools
 import unittest
 
 import torch
 import torch._inductor
 from torch._dynamo.test_case import run_tests, TestCase
 from torch._dynamo.utils import counters
-from torch.testing._internal.inductor_utils import requires_cuda
+from torch.testing._internal.inductor_utils import HAS_CUDA
 
 try:
     # importing this will register fbgemm lowerings for inductor
@@ -17,6 +18,8 @@
     has_fbgemm = False
     pass
 
+requires_cuda = functools.partial(unittest.skipIf, not HAS_CUDA, "requires cuda")
+
 
 class MyModule(torch.nn.Module):
     def __init__(self, z: int, has_bias: bool, device="cuda") -> None:
diff --git a/test/inductor/test_inductor_freezing.py b/test/inductor/test_inductor_freezing.py
index c7b6043..958b9a0 100644
--- a/test/inductor/test_inductor_freezing.py
+++ b/test/inductor/test_inductor_freezing.py
@@ -1,7 +1,9 @@
 # Owner(s): ["module: inductor"]
 import contextlib
 import functools
+import importlib
 import itertools
+import os
 import sys
 import unittest
 import weakref
@@ -14,6 +16,10 @@
 from torch.testing import FileCheck
 from torch.testing._internal.common_cuda import SM80OrLater
 
+# Make the helper files in test/ importable
+pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+sys.path.append(pytorch_test_dir)
+
 from torch.testing._internal.common_utils import (
     IS_CI,
     IS_WINDOWS,
@@ -21,13 +27,6 @@
     TEST_WITH_ASAN,
     TestCase as TorchTestCase,
 )
-from torch.testing._internal.inductor_utils import (
-    check_model,
-    check_model_cuda,
-    copy_tests,
-    HAS_CPU,
-    HAS_CUDA,
-)
 
 if IS_WINDOWS and IS_CI:
     sys.stderr.write(
@@ -37,6 +36,12 @@
         sys.exit(0)
     raise unittest.SkipTest("requires sympy/functorch/filelock")
 
+from inductor.test_torchinductor import check_model, check_model_cuda, copy_tests
+
+importlib.import_module("functorch")
+importlib.import_module("filelock")
+
+from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
 
 HAS_MULTIGPU = HAS_CUDA and torch.cuda.device_count() >= 2
 aten = torch.ops.aten
diff --git a/test/inductor/test_memory_planning.py b/test/inductor/test_memory_planning.py
index 7d8d7e6..2fdae87 100644
--- a/test/inductor/test_memory_planning.py
+++ b/test/inductor/test_memory_planning.py
@@ -1,19 +1,8 @@
 # Owner(s): ["module: inductor"]
 
 import sys
-import unittest
-from typing import List
-
-import torch
-from torch._C import FileCheck
-from torch._dynamo.test_case import run_tests, TestCase
-from torch._dynamo.testing import load_test_module
-from torch._dynamo.utils import same
-from torch._inductor import config
 
 from torch.testing._internal.common_utils import IS_CI, IS_WINDOWS, skipIfRocm
-from torch.testing._internal.inductor_utils import run_and_get_cpp_code
-from torch.utils._triton import has_triton
 
 if IS_WINDOWS and IS_CI:
     sys.stderr.write(
@@ -23,6 +12,17 @@
         sys.exit(0)
     raise unittest.SkipTest("requires sympy/functorch/filelock")
 
+import unittest
+from typing import List
+
+import torch
+from test_torchinductor import run_and_get_cpp_code
+from torch._C import FileCheck
+from torch._dynamo.test_case import run_tests, TestCase
+from torch._dynamo.utils import same
+from torch._inductor import config
+from torch.utils._triton import has_triton
+
 
 @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
 @config.patch(memory_planning=True)
@@ -80,9 +80,8 @@
 
     @skipIfRocm(msg="test_aot_inductor doesn't work on ROCm")
     def test_abi_compatible(self):
-        AOTInductorModelRunner = load_test_module(
-            __file__, "inductor.test_aot_inductor"
-        ).AOTInductorModelRunner
+        from test_aot_inductor import AOTInductorModelRunner
+
         f, args = self._generate(device="cuda")
         constraints: List[torch.export.Constraint] = [
             torch._export.dynamic_dim(args[0], 0) >= 1,
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 9190019..1111888 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -14,9 +14,11 @@
 import subprocess
 import sys
 import threading
+import time
 import typing
 import unittest
 import weakref
+from typing import Tuple
 from unittest.mock import patch
 
 import numpy as np
@@ -55,20 +57,25 @@
     DeterministicGuard,
     IS_CI,
     IS_FBCODE,
+    IS_MACOS,
     IS_WINDOWS,
     IS_X86,
     skipIfRocm,
     TEST_WITH_ASAN,
+    TestCase as TorchTestCase,
 )
 from torch.utils import _pytree as pytree
 from torch.utils._python_dispatch import TorchDispatchMode
+from torch.utils._pytree import tree_flatten, tree_unflatten
 from torch.utils.weak import WeakTensorKeyDictionary
 
 if IS_WINDOWS and IS_CI:
     sys.stderr.write(
         "Windows CI does not have necessary dependencies for test_torchinductor yet\n"
     )
-    sys.exit(0)
+    if __name__ == "__main__":
+        sys.exit(0)
+    raise unittest.SkipTest("requires sympy/functorch/filelock")
 
 importlib.import_module("functorch")
 importlib.import_module("filelock")
@@ -79,24 +86,19 @@
 from torch._inductor.utils import has_torchvision_roi_align
 
 from torch.testing._internal.common_utils import slowTest
-from torch.testing._internal.inductor_utils import (
-    check_model,
-    check_model_cuda,
-    copy_tests,
-    HAS_AVX2,
-    HAS_CPU,
-    HAS_CUDA,
-    requires_cuda,
-    requires_multigpu,
-    run_and_get_cpp_code,
-    skip_if_x86_mac,
-    TestCase,
-    ToTuple,
-)
+from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
 
-
+HAS_MULTIGPU = HAS_CUDA and torch.cuda.device_count() >= 2
+HAS_AVX2 = "fbgemm" in torch.backends.quantized.supported_engines
 aten = torch.ops.aten
-
+requires_cuda = functools.partial(unittest.skipIf, not HAS_CUDA, "requires cuda")
+requires_multigpu = functools.partial(
+    unittest.skipIf, not HAS_MULTIGPU, "requires multiple cuda devices"
+)
+skip_if_x86_mac = functools.partial(
+    unittest.skipIf, IS_MACOS and IS_X86, "Does not work on x86 Mac"
+)
+vec_dtypes = [torch.float, torch.bfloat16, torch.float16]
 
 libfoo = None
 
@@ -110,6 +112,48 @@
     return run_and_get_code(run_with_backward)
 
 
+class TestCase(TorchTestCase):
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+        cls._stack = contextlib.ExitStack()
+        cls._stack.enter_context(
+            config.patch(
+                {
+                    "debug": True,
+                    "debug_index_asserts": True,
+                    "cpp.min_chunk_size": 1,
+                    "triton.autotune_pointwise": False,  # too slow
+                    "implicit_fallbacks": False,
+                    "generate_intermediate_hooks": True,
+                }
+            )
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        cls._stack.close()
+        super().tearDownClass()
+
+    def setUp(self):
+        torch._dynamo.reset()
+        torch._inductor.metrics.reset()
+        super().setUp()
+        self._start = time.perf_counter()
+
+    def tearDown(self):
+        super().tearDown()
+        torch._dynamo.reset()
+        if os.environ.get("ERROR_ON_SLOW") == "1":
+            elapsed = time.perf_counter() - self._start
+            assert elapsed < 120
+
+
+class ToTuple(torch.nn.Module):
+    def forward(self, x):
+        return (x,)
+
+
 @dataclasses.dataclass
 class InputGen:
     n: int
@@ -142,6 +186,319 @@
         return torch.arange(self.n, device=self.device, dtype=torch.int32)
 
 
+def compute_grads(args, kwargs, results, grads):
+    def gather_leaf_tensors(args, kwargs):
+        args = pytree.arg_tree_leaves(*args, **kwargs)
+        leaf_tensors = [
+            arg for arg in args if isinstance(arg, torch.Tensor) and arg.requires_grad
+        ]
+        return leaf_tensors
+
+    flat_results = pytree.tree_leaves(results)
+    flat_diff_results = [r for r in flat_results if r.requires_grad]
+    assert len(flat_diff_results) > 0
+
+    leaf_tensors = gather_leaf_tensors(args, kwargs)
+    assert len(leaf_tensors) > 0
+    return torch.autograd.grad(
+        flat_diff_results,
+        leaf_tensors,
+        grads,
+        allow_unused=True,
+        retain_graph=True,
+    )
+
+
+def clone_preserve_strides(x, device=None):
+    if not isinstance(x, torch.Tensor):
+        return x
+    buffer = torch.as_strided(
+        x, (x.untyped_storage().size() // x.element_size(),), (1,), 0
+    )
+    if not device:
+        buffer = buffer.clone()
+    else:
+        buffer = buffer.to(device, copy=True)
+    out = torch.as_strided(buffer, x.size(), x.stride(), x.storage_offset())
+    return out
+
+
+def run_and_get_cpp_code(fn, *args, **kwargs):
+    # We use the patch context manager instead of using it as a decorator.
+    # In this way, we can ensure that the attribute is patched and unpatched correctly
+    # even if this run_and_get_cpp_code function is called multiple times.
+    with patch.object(config, "debug", True):
+        torch._dynamo.reset()
+        import io
+        import logging
+
+        log_capture_string = io.StringIO()
+        ch = logging.StreamHandler(log_capture_string)
+        from torch._inductor.graph import output_code_log
+
+        output_code_log.addHandler(ch)
+        prev_level = output_code_log.level
+        output_code_log.setLevel(logging.DEBUG)
+        result = fn(*args, **kwargs)
+        s = log_capture_string.getvalue()
+        output_code_log.setLevel(prev_level)
+        output_code_log.removeHandler(ch)
+    return result, s
+
+
+def check_model(
+    self: TestCase,
+    model,
+    example_inputs,
+    kwargs=None,
+    *,
+    atol=None,
+    rtol=None,
+    check_lowp=True,
+    exact_dtype=True,
+    nopython=True,
+    copy_to_cuda=True,
+    reference_in_float=True,
+    assert_equal=True,
+    check_gradient=False,
+    check_has_compiled=True,
+    output_process_fn_grad=lambda x: x,
+):
+    kwargs = kwargs or {}
+    torch._dynamo.reset()
+
+    ref_inputs = [clone_preserve_strides(x) for x in example_inputs]
+    ref_kwargs = kwargs
+    has_lowp_args = False
+    original_lowp_dtype = torch.half
+
+    if reference_in_float:
+        # check_lowp is ignored here, it's kept just to be able to call `common` with extra arg
+        def upcast_fn(x):
+            nonlocal has_lowp_args
+            if isinstance(x, torch.Tensor) and (
+                x.dtype == torch.float16 or x.dtype == torch.bfloat16
+            ):
+                has_lowp_args = True
+                return x.float()
+            else:
+                return x
+
+        def get_original_lowp_dtype(example_inputs):
+            dtypes = [x.dtype for x in example_inputs if isinstance(x, torch.Tensor)]
+            dtype_set = set(dtypes)
+            return dtype_set.pop() if len(dtype_set) == 1 else torch.half
+
+        ref_inputs = list(map(upcast_fn, example_inputs))
+        ref_kwargs = {k: upcast_fn(v) for k, v in kwargs.items()}
+        if has_lowp_args:
+            original_lowp_dtype = get_original_lowp_dtype(example_inputs)
+            if hasattr(model, "to"):
+                model = model.to(torch.float)
+
+    torch.manual_seed(0)
+
+    correct = model(*ref_inputs, **ref_kwargs)
+    # downcast the model back if needed
+    if reference_in_float and has_lowp_args:
+        if hasattr(model, "to"):
+            model = model.to(original_lowp_dtype)
+
+    torch._inductor.metrics.reset()
+
+    called = False
+
+    def compile_fx_wrapper(model_, example_inputs_):
+        nonlocal called
+        called = True
+        return compile_fx(model_, example_inputs_)
+
+    def run(*ex, **kwargs):
+        return model(*ex, **kwargs)
+
+    run = torch._dynamo.optimize(compile_fx_wrapper, nopython=nopython)(run)
+
+    torch.manual_seed(0)
+    actual = run(*example_inputs, **kwargs)
+    # if not called:
+    #     exp = torch._dynamo.explain(run)(*example_inputs)
+    #     print("Explain:", exp[0])
+    #     for graph in exp[2]:
+    #         print("Graph", graph)
+    if check_has_compiled:
+        assert called, "Ran graph without calling compile_fx"
+    assert type(actual) == type(correct)
+
+    correct_flat, correct_spec = tree_flatten(correct)
+    actual_flat = pytree.tree_leaves(actual)
+
+    def reference_to_expect(actual_flat, correct_flat):
+        return tuple(
+            y.to(x.dtype)
+            if isinstance(y, torch.Tensor) and y.dtype.is_floating_point
+            else y
+            for x, y in zip(actual_flat, correct_flat)
+        )
+
+    if reference_in_float:
+        correct_flat = reference_to_expect(actual_flat, correct_flat)
+        correct = tree_unflatten(correct_flat, correct_spec)
+
+    if assert_equal:
+        self.assertEqual(
+            actual,
+            correct,
+            atol=atol,
+            rtol=rtol,
+            equal_nan=True,
+            exact_dtype=exact_dtype,
+        )
+        # In case of input mutations, check that inputs are the same
+        self.assertEqual(
+            ref_inputs,
+            example_inputs,
+            atol=atol,
+            rtol=rtol,
+            equal_nan=True,
+            # our testing sometimes uses higher precision inputs for the reference
+            exact_dtype=False,
+        )
+    else:
+        for correct_val, actual_val in zip(correct_flat, actual_flat):
+            if isinstance(correct_val, torch.Tensor):
+                assert correct_val.device == actual_val.device
+                assert correct_val.size() == actual_val.size()
+                assert correct_val.stride() == actual_val.stride()
+                assert correct_val.layout == actual_val.layout
+                if exact_dtype:
+                    assert correct_val.dtype == actual_val.dtype
+
+    if check_gradient:
+        actual = output_process_fn_grad(actual)
+        correct = output_process_fn_grad(correct)
+        actual_flat = pytree.tree_leaves(actual)
+        correct_flat = pytree.tree_leaves(correct)
+
+        # generate random unit norm gradients
+        grads = [
+            torch.rand(r.shape, device=r.device, dtype=r.dtype)
+            for r in correct_flat
+            if r.requires_grad
+        ]
+        for g in grads:
+            g /= g.norm()
+
+        correct_grad = compute_grads(ref_inputs, ref_kwargs, correct, grads)
+        all_none_grads = all(x is None for x in correct_grad)
+        if all_none_grads:
+            # See Note [Detaching inputs that never need gradients]
+            # There are a handful of ops that can return None gradients instead of zero gradients.
+            # If all inputs to an AOTAutograd graph are supposed to get None gradients,
+            # AOTAutograd will end up forcing all of the outputs of the forward to not require grad.
+            # There's no easy fix to this (see the note above), although one option is to
+            # force any derivative formulas in core to return tensors of zeros instead of None.
+            flat_results = pytree.tree_leaves(actual)
+            results_that_require_grad = [
+                x
+                for x in flat_results
+                if isinstance(x, torch.Tensor) and x.requires_grad
+            ]
+            self.assertEqual(len(results_that_require_grad), 0)
+        else:
+            actual_grad = compute_grads(example_inputs, kwargs, actual, grads)
+
+            if reference_in_float:
+                expect_grad = reference_to_expect(actual_grad, correct_grad)
+            else:
+                expect_grad = correct_grad
+
+            self.assertEqual(
+                actual_grad,
+                expect_grad,
+                atol=atol,
+                rtol=rtol,
+                equal_nan=True,
+                exact_dtype=exact_dtype,
+            )
+
+    torch._dynamo.reset()
+
+
+@torch._inductor.config.patch("triton.cudagraphs", False)
+def check_model_cuda(
+    self: TestCase,
+    model,
+    example_inputs,
+    kwargs=None,
+    *,
+    atol=None,
+    rtol=None,
+    check_lowp=True,
+    exact_dtype=True,
+    nopython=True,
+    copy_to_cuda=True,
+    reference_in_float=True,
+    assert_equal=True,
+    check_gradient=False,
+    check_has_compiled=True,
+    output_process_fn_grad=lambda x: x,
+):
+    kwargs = kwargs or {}
+    if hasattr(model, "to"):
+        model = model.to("cuda")
+
+    if copy_to_cuda:
+        example_inputs = tuple(
+            clone_preserve_strides(x, device="cuda") for x in example_inputs
+        )
+
+    check_model(
+        self,
+        model,
+        example_inputs,
+        kwargs,
+        atol=atol,
+        rtol=rtol,
+        exact_dtype=exact_dtype,
+        nopython=nopython,
+        reference_in_float=reference_in_float,
+        assert_equal=assert_equal,
+        check_gradient=check_gradient,
+        check_has_compiled=check_has_compiled,
+        output_process_fn_grad=output_process_fn_grad,
+    )
+
+    if check_lowp:
+
+        def downcast_fn(x):
+            if not isinstance(x, torch.Tensor) or not x.dtype == torch.float:
+                return x
+            return torch.empty_strided(
+                x.size(), x.stride(), device="cuda", dtype=torch.half
+            ).copy_(x)
+
+        example_inputs = list(map(downcast_fn, example_inputs))
+        if hasattr(model, "to"):
+            model = model.to(torch.half)
+        if rtol is not None:
+            rtol = max(2e-3, rtol)
+        check_model(
+            self,
+            model,
+            example_inputs,
+            kwargs,
+            atol=atol,
+            rtol=rtol,
+            exact_dtype=exact_dtype,
+            nopython=nopython,
+            reference_in_float=reference_in_float,
+            assert_equal=assert_equal,
+            check_gradient=check_gradient,
+            check_has_compiled=check_has_compiled,
+            output_process_fn_grad=output_process_fn_grad,
+        )
+
+
 def _run_and_assert_no_indirect_indexing(test_case, func, *args, **kwargs):
     result, source_codes = run_and_get_code(func, *args, **kwargs)
 
@@ -7314,6 +7671,46 @@
         self.common(fn, (x,))
 
 
+@dataclasses.dataclass
+class TestFailure:
+    suffixes: Tuple[str]
+    is_skip: bool = False
+    __test__: bool = False
+
+
+def copy_tests(
+    my_cls, other_cls, suffix, test_failures=None, xfail_prop=None
+):  # noqa: B902
+    for name, value in my_cls.__dict__.items():
+        if name.startswith("test_"):
+            # You cannot copy functions in Python, so we use closures here to
+            # create objects with different ids. Otherwise, unittest.skip
+            # would modify all methods sharing the same object id. Also, by
+            # using a default argument, we create a copy instead of a
+            # reference. Otherwise, we would lose access to the value.
+
+            @functools.wraps(value)
+            def new_test(self, value=value):
+                return value(self)
+
+            # Copy __dict__ which may contain test metadata
+            new_test.__dict__ = copy.deepcopy(value.__dict__)
+
+            if xfail_prop is not None and hasattr(value, xfail_prop):
+                new_test = unittest.expectedFailure(new_test)
+
+            tf = test_failures and test_failures.get(name)
+            if tf is not None and suffix in tf.suffixes:
+                skip_func = (
+                    unittest.skip("Skipped!")
+                    if tf.is_skip
+                    else unittest.expectedFailure
+                )
+                new_test = skip_func(new_test)
+
+            setattr(other_cls, f"{name}_{suffix}", new_test)
+
+
 if HAS_CPU and not torch.backends.mps.is_available():
 
     class SweepInputsCpuTest(SweepInputs2, TestCase):
diff --git a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py
index b49b5e5..1813588 100644
--- a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py
+++ b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py
@@ -1,33 +1,23 @@
 # Owner(s): ["module: inductor"]
+import importlib
+import os
 import sys
 import unittest
 
 import torch
-from torch._dynamo.testing import load_test_module
 from torch._inductor.compile_fx import compile_fx
-from torch._inductor.utils import run_and_get_triton_code
 from torch.testing._internal.common_utils import (
     IS_CI,
     IS_WINDOWS,
     TEST_WITH_ASAN,
     TestCase,
 )
-
 from torch.testing._internal.inductor_utils import (
     _check_has_dynamic_shape,
-    copy_tests,
     HAS_CPU,
     HAS_CUDA,
-    make_dynamic_cls,
-    run_and_get_cpp_code,
-    TestFailure,
 )
 
-CommonTemplate = load_test_module(
-    __file__, "inductor.test_torchinductor"
-).CommonTemplate
-
-
 if IS_WINDOWS and IS_CI:
     sys.stderr.write(
         "Windows CI does not have necessary dependencies for test_torchinductor_codegen_dynamic_shapes yet\n"
@@ -36,6 +26,20 @@
         sys.exit(0)
     raise unittest.SkipTest("requires sympy/functorch/filelock")
 
+importlib.import_module("filelock")
+
+# Make the helper files in test/ importable
+pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+sys.path.append(pytorch_test_dir)
+from inductor.test_torchinductor import (
+    CommonTemplate,
+    copy_tests,
+    run_and_get_cpp_code,
+    run_and_get_triton_code,
+    TestFailure,
+)
+from inductor.test_torchinductor_dynamic_shapes import make_dynamic_cls
+
 
 # Checks for patterns in generated C++/Triton code to see if it's dynamic
 def check_codegen(
diff --git a/test/inductor/test_torchinductor_dynamic_shapes.py b/test/inductor/test_torchinductor_dynamic_shapes.py
index 73f4586..9cc9a18 100644
--- a/test/inductor/test_torchinductor_dynamic_shapes.py
+++ b/test/inductor/test_torchinductor_dynamic_shapes.py
@@ -1,6 +1,8 @@
 # Owner(s): ["module: inductor"]
 import contextlib
+import importlib
 import math
+import os
 import sys
 import unittest
 from functools import partial
@@ -8,7 +10,7 @@
 import torch
 import torch._custom_ops as custom_ops
 import torch.library
-from torch._dynamo.testing import load_test_module
+from torch._dynamo.testing import make_test_cls_with_patches
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests,
     onlyCPU,
@@ -21,20 +23,7 @@
     TEST_WITH_ROCM,
     TestCase,
 )
-from torch.testing._internal.inductor_utils import (
-    check_model,
-    check_model_cuda,
-    copy_tests,
-    HAS_CPU,
-    HAS_CUDA,
-    make_dynamic_cls,
-    TestFailure,
-)
-
-
-CommonTemplate = load_test_module(
-    __file__, "inductor.test_torchinductor"
-).CommonTemplate
+from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
 
 if IS_WINDOWS and IS_CI:
     sys.stderr.write(
@@ -44,6 +33,18 @@
         sys.exit(0)
     raise unittest.SkipTest("requires sympy/functorch/filelock")
 
+# Make the helper files in test/ importable
+pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+sys.path.append(pytorch_test_dir)
+from inductor.test_torchinductor import (
+    check_model,
+    check_model_cuda,
+    CommonTemplate,
+    copy_tests,
+    TestFailure,
+)
+
+importlib.import_module("filelock")
 
 # xfail by default, set is_skip=True to skip
 test_failures = {
@@ -65,6 +66,16 @@
     )
 
 
+def make_dynamic_cls(cls, xfail_prop="_expected_failure_dynamic"):
+    return make_test_cls_with_patches(
+        cls,
+        "DynamicShapes",
+        "_dynamic_shapes",
+        (torch._dynamo.config, "assume_static_by_default", False),
+        xfail_prop=xfail_prop,
+    )
+
+
 DynamicShapesCommonTemplate = make_dynamic_cls(CommonTemplate)
 
 
diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py
index b7a1ea6..6839100 100644
--- a/test/inductor/test_torchinductor_opinfo.py
+++ b/test/inductor/test_torchinductor_opinfo.py
@@ -3,6 +3,7 @@
 import contextlib
 import functools
 import os
+import sys
 import unittest
 from collections import defaultdict
 from enum import Enum
@@ -41,15 +42,20 @@
     TEST_WITH_ROCM,
     TestCase,
 )
-from torch.testing._internal.inductor_utils import (
-    check_model,
-    check_model_cuda,
-    HAS_CPU,
-    HAS_CUDA,
-)
+from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
 from torch.utils._python_dispatch import TorchDispatchMode
 from torch.utils._pytree import tree_map
 
+try:
+    try:
+        from .test_torchinductor import check_model, check_model_cuda
+    except ImportError:
+        from test_torchinductor import check_model, check_model_cuda
+except (unittest.SkipTest, ImportError) as e:
+    sys.stderr.write(f"{type(e)}: {e}\n")
+    if __name__ == "__main__":
+        sys.exit(0)
+    raise
 
 bf16 = torch.bfloat16  # not tested
 f64 = torch.float64
diff --git a/torch/testing/_internal/inductor_utils.py b/torch/testing/_internal/inductor_utils.py
index 791c302..89b2a4d 100644
--- a/torch/testing/_internal/inductor_utils.py
+++ b/torch/testing/_internal/inductor_utils.py
@@ -1,33 +1,16 @@
-import contextlib
-import os
-import pathlib
-import time
 from subprocess import CalledProcessError
 
-from torch.testing._internal.common_utils import (
-    TestCase as TorchTestCase,
-)
 from torch._inductor.codecache import CppCodeCache
 from torch.utils._triton import has_triton
 from torch.testing._internal.common_utils import (
     LazyVal,
     IS_FBCODE,
-    IS_MACOS,
-    IS_X86,
 )
 from torch._dynamo.backends.registry import register_backend
 from torch._inductor.compile_fx import compile_fx, count_bytes_inner
 from torch.testing._internal.common_utils import TestCase
 import torch
 import re
-import functools
-import unittest
-import dataclasses
-import copy
-from torch.utils import _pytree as pytree
-from torch.utils._pytree import tree_flatten, tree_unflatten
-from typing import Tuple
-from torch._dynamo.testing import make_test_cls_with_patches
 
 def test_cpu():
     try:
@@ -66,416 +49,3 @@
         has_dynamic, msg=f"Failed to find dynamic for loop variable\n{code}"
     )
     self.assertTrue(for_loop_found, f"Failed to find for loop\n{code}")
-
-HAS_MULTIGPU = HAS_CUDA and torch.cuda.device_count() >= 2
-HAS_AVX2 = "fbgemm" in torch.backends.quantized.supported_engines
-requires_cuda = functools.partial(unittest.skipIf, not HAS_CUDA, "requires cuda")
-requires_multigpu = functools.partial(
-    unittest.skipIf, not HAS_MULTIGPU, "requires multiple cuda devices"
-)
-skip_if_x86_mac = functools.partial(
-    unittest.skipIf, IS_MACOS and IS_X86, "Does not work on x86 Mac"
-)
-vec_dtypes = [torch.float, torch.bfloat16, torch.float16]
-
-@dataclasses.dataclass
-class TestFailure:
-    suffixes: Tuple[str]
-    is_skip: bool = False
-    __test__: bool = False
-
-def copy_tests(
-        my_cls, other_cls, suffix, test_failures=None, xfail_prop=None
-):  # noqa: B902
-    for name, value in my_cls.__dict__.items():
-        if name.startswith("test_"):
-            # You cannot copy functions in Python, so we use closures here to
-            # create objects with different ids. Otherwise, unittest.skip
-            # would modify all methods sharing the same object id. Also, by
-            # using a default argument, we create a copy instead of a
-            # reference. Otherwise, we would lose access to the value.
-
-            @functools.wraps(value)
-            def new_test(self, value=value):
-                return value(self)
-
-            # Copy __dict__ which may contain test metadata
-            new_test.__dict__ = copy.deepcopy(value.__dict__)
-
-            if xfail_prop is not None and hasattr(value, xfail_prop):
-                new_test = unittest.expectedFailure(new_test)
-
-            tf = test_failures and test_failures.get(name)
-            if tf is not None and suffix in tf.suffixes:
-                skip_func = (
-                    unittest.skip("Skipped!")
-                    if tf.is_skip
-                    else unittest.expectedFailure
-                )
-                new_test = skip_func(new_test)
-
-            setattr(other_cls, f"{name}_{suffix}", new_test)
-
-
-def clone_preserve_strides(x, device=None):
-    if not isinstance(x, torch.Tensor):
-        return x
-    buffer = torch.as_strided(
-        x, (x.untyped_storage().size() // x.element_size(),), (1,), 0
-    )
-    if not device:
-        buffer = buffer.clone()
-    else:
-        buffer = buffer.to(device, copy=True)
-    out = torch.as_strided(buffer, x.size(), x.stride(), x.storage_offset())
-    return out
-
-
-
-def compute_grads(args, kwrags, results, grads):
-    def gather_leaf_tensors(args, kwargs):
-        args = pytree.arg_tree_leaves(*args, **kwargs)
-        leaf_tensors = [
-            arg for arg in args if isinstance(arg, torch.Tensor) and arg.requires_grad
-        ]
-        return leaf_tensors
-
-    flat_results = pytree.tree_leaves(results)
-    flat_diff_results = [r for r in flat_results if r.requires_grad]
-    assert len(flat_diff_results) > 0
-
-    leaf_tensors = gather_leaf_tensors(args, kwrags)
-    assert len(leaf_tensors) > 0
-    return torch.autograd.grad(
-        flat_diff_results,
-        leaf_tensors,
-        grads,
-        allow_unused=True,
-        retain_graph=True,
-    )
-
-
-def check_model(
-        self: TestCase,
-        model,
-        example_inputs,
-        kwargs=None,
-        *,
-        atol=None,
-        rtol=None,
-        check_lowp=True,
-        exact_dtype=True,
-        nopython=True,
-        copy_to_cuda=True,
-        reference_in_float=True,
-        assert_equal=True,
-        check_gradient=False,
-        check_has_compiled=True,
-        output_process_fn_grad=lambda x: x,
-):
-    kwargs = kwargs or {}
-    torch._dynamo.reset()
-
-    ref_inputs = [clone_preserve_strides(x) for x in example_inputs]
-    ref_kwargs = kwargs
-    has_lowp_args = False
-    original_lowp_dtype = torch.half
-
-    if reference_in_float:
-        # check_lowp is ignored here, it's kept just to be able to call `common` with extra arg
-        def upcast_fn(x):
-            nonlocal has_lowp_args
-            if isinstance(x, torch.Tensor) and (
-                    x.dtype == torch.float16 or x.dtype == torch.bfloat16
-            ):
-                has_lowp_args = True
-                return x.float()
-            else:
-                return x
-
-        def get_original_lowp_dtype(example_inputs):
-            dtypes = [x.dtype for x in example_inputs if isinstance(x, torch.Tensor)]
-            dtype_set = set(dtypes)
-            return dtype_set.pop() if len(dtype_set) == 1 else torch.half
-
-        ref_inputs = list(map(upcast_fn, example_inputs))
-        ref_kwargs = {k: upcast_fn(v) for k, v in kwargs.items()}
-        if has_lowp_args:
-            original_lowp_dtype = get_original_lowp_dtype(example_inputs)
-            if hasattr(model, "to"):
-                model = model.to(torch.float)
-
-    torch.manual_seed(0)
-
-    correct = model(*ref_inputs, **ref_kwargs)
-    # downcast the model back if needed
-    if reference_in_float and has_lowp_args:
-        if hasattr(model, "to"):
-            model = model.to(original_lowp_dtype)
-
-    torch._inductor.metrics.reset()
-
-    called = False
-
-    def compile_fx_wrapper(model_, example_inputs_):
-        nonlocal called
-        called = True
-        return compile_fx(model_, example_inputs_)
-
-    def run(*ex, **kwargs):
-        return model(*ex, **kwargs)
-
-    run = torch._dynamo.optimize(compile_fx_wrapper, nopython=nopython)(run)
-
-    torch.manual_seed(0)
-    actual = run(*example_inputs, **kwargs)
-    # if not called:
-    #     exp = torch._dynamo.explain(run)(*example_inputs)
-    #     print("Explain:", exp[0])
-    #     for graph in exp[2]:
-    #         print("Graph", graph)
-    if check_has_compiled:
-        assert called, "Ran graph without calling compile_fx"
-    assert type(actual) == type(correct)
-
-    correct_flat, correct_spec = tree_flatten(correct)
-    actual_flat = pytree.tree_leaves(actual)
-
-    def reference_to_expect(actual_flat, correct_flat):
-        return tuple(
-            y.to(x.dtype)
-            if isinstance(y, torch.Tensor) and y.dtype.is_floating_point
-            else y
-            for x, y in zip(actual_flat, correct_flat)
-        )
-
-    if reference_in_float:
-        correct_flat = reference_to_expect(actual_flat, correct_flat)
-        correct = tree_unflatten(correct_flat, correct_spec)
-
-    if assert_equal:
-        self.assertEqual(
-            actual,
-            correct,
-            atol=atol,
-            rtol=rtol,
-            equal_nan=True,
-            exact_dtype=exact_dtype,
-        )
-        # In case of input mutations, check that inputs are the same
-        self.assertEqual(
-            ref_inputs,
-            example_inputs,
-            atol=atol,
-            rtol=rtol,
-            equal_nan=True,
-            # our testing sometimes uses higher precision inputs for the reference
-            exact_dtype=False,
-        )
-    else:
-        for correct_val, actual_val in zip(correct_flat, actual_flat):
-            if isinstance(correct_val, torch.Tensor):
-                assert correct_val.device == actual_val.device
-                assert correct_val.size() == actual_val.size()
-                assert correct_val.stride() == actual_val.stride()
-                assert correct_val.layout == actual_val.layout
-                if exact_dtype:
-                    assert correct_val.dtype == actual_val.dtype
-
-    if check_gradient:
-        actual = output_process_fn_grad(actual)
-        correct = output_process_fn_grad(correct)
-        actual_flat = pytree.tree_leaves(actual)
-        correct_flat = pytree.tree_leaves(correct)
-
-        # generate random unit norm gradients
-        grads = [
-            torch.rand(r.shape, device=r.device, dtype=r.dtype)
-            for r in correct_flat
-            if r.requires_grad
-        ]
-        for g in grads:
-            g /= g.norm()
-
-        correct_grad = compute_grads(ref_inputs, ref_kwargs, correct, grads)
-        all_none_grads = all(x is None for x in correct_grad)
-        if all_none_grads:
-            # See Note [Detaching inputs that never need gradients]
-            # There are a handful of ops that can return None gradients, into of zero gradients.
-            # If all inputs to an AOTAutograd graph are supposed to get None gradients,
-            # AOTAutograd will end up forcing all of the outputs of the forward to not require grad.
-            # There's no easy fix to this (see the note above), although one option is to
-            # force any derivative formulas in core to return tensors of zeros instead of None.
-            flat_results = pytree.tree_leaves(actual)
-            results_that_require_grad = [
-                x
-                for x in flat_results
-                if isinstance(x, torch.Tensor) and x.requires_grad
-            ]
-            self.assertEqual(len(results_that_require_grad), 0)
-        else:
-            actual_grad = compute_grads(example_inputs, kwargs, actual, grads)
-
-            if reference_in_float:
-                expect_grad = reference_to_expect(actual_grad, correct_grad)
-            else:
-                expect_grad = correct_grad
-
-            self.assertEqual(
-                actual_grad,
-                expect_grad,
-                atol=atol,
-                rtol=rtol,
-                equal_nan=True,
-                exact_dtype=exact_dtype,
-            )
-
-    torch._dynamo.reset()
-
-
-@torch._inductor.config.patch("triton.cudagraphs", False)
-def check_model_cuda(
-        self: TestCase,
-        model,
-        example_inputs,
-        kwargs=None,
-        *,
-        atol=None,
-        rtol=None,
-        check_lowp=True,
-        exact_dtype=True,
-        nopython=True,
-        copy_to_cuda=True,
-        reference_in_float=True,
-        assert_equal=True,
-        check_gradient=False,
-        check_has_compiled=True,
-        output_process_fn_grad=lambda x: x,
-):
-    kwargs = kwargs or {}
-    if hasattr(model, "to"):
-        model = model.to("cuda")
-
-    if copy_to_cuda:
-        example_inputs = tuple(
-            clone_preserve_strides(x, device="cuda") for x in example_inputs
-        )
-
-    check_model(
-        self,
-        model,
-        example_inputs,
-        kwargs,
-        atol=atol,
-        rtol=rtol,
-        exact_dtype=exact_dtype,
-        nopython=nopython,
-        reference_in_float=reference_in_float,
-        assert_equal=assert_equal,
-        check_gradient=check_gradient,
-        check_has_compiled=check_has_compiled,
-        output_process_fn_grad=output_process_fn_grad,
-    )
-
-    if check_lowp:
-
-        def downcast_fn(x):
-            if not isinstance(x, torch.Tensor) or not x.dtype == torch.float:
-                return x
-            return torch.empty_strided(
-                x.size(), x.stride(), device="cuda", dtype=torch.half
-            ).copy_(x)
-
-        example_inputs = list(map(downcast_fn, example_inputs))
-        if hasattr(model, "to"):
-            model = model.to(torch.half)
-        if rtol is not None:
-            rtol = max(2e-3, rtol)
-        check_model(
-            self,
-            model,
-            example_inputs,
-            kwargs,
-            atol=atol,
-            rtol=rtol,
-            exact_dtype=exact_dtype,
-            nopython=nopython,
-            reference_in_float=reference_in_float,
-            assert_equal=assert_equal,
-            check_gradient=check_gradient,
-            check_has_compiled=check_has_compiled,
-            output_process_fn_grad=output_process_fn_grad,
-        )
-def run_and_get_cpp_code(fn, *args, **kwargs):
-    # We use the patch context manager instead of using it as a decorator.
-    # In this way, we can ensure that the attribute is patched and unpatched correctly
-    # even if this run_and_get_cpp_code function is called multiple times.
-    with torch._inductor.config.patch(debug=True):
-        torch._dynamo.reset()
-        import io
-        import logging
-
-        log_capture_string = io.StringIO()
-        ch = logging.StreamHandler(log_capture_string)
-        from torch._inductor.graph import output_code_log
-
-        output_code_log.addHandler(ch)
-        prev_level = output_code_log.level
-        output_code_log.setLevel(logging.DEBUG)
-        result = fn(*args, **kwargs)
-        s = log_capture_string.getvalue()
-        output_code_log.setLevel(prev_level)
-        output_code_log.removeHandler(ch)
-    return result, s
-
-class TestCase(TorchTestCase):
-    @classmethod
-    def setUpClass(cls):
-        super().setUpClass()
-        cls._stack = contextlib.ExitStack()
-        cls._stack.enter_context(
-            torch._inductor.config.patch(
-                {
-                    "debug": True,
-                    "debug_index_asserts": True,
-                    "cpp.min_chunk_size": 1,
-                    "triton.autotune_pointwise": False,  # too slow
-                    "implicit_fallbacks": False,
-                    "generate_intermediate_hooks": True,
-                }
-            )
-        )
-
-    @classmethod
-    def tearDownClass(cls):
-        cls._stack.close()
-        super().tearDownClass()
-
-    def setUp(self):
-        torch._dynamo.reset()
-        torch._inductor.metrics.reset()
-        super().setUp()
-        self._start = time.perf_counter()
-
-    def tearDown(self):
-        super().tearDown()
-        torch._dynamo.reset()
-        if os.environ.get("ERROR_ON_SLOW") == "1":
-            elapsed = time.perf_counter() - self._start
-            assert elapsed < 120
-class ToTuple(torch.nn.Module):
-    def forward(self, x):
-        return (x,)
-
-
-def make_dynamic_cls(cls, xfail_prop="_expected_failure_dynamic"):
-    return make_test_cls_with_patches(
-        cls,
-        "DynamicShapes",
-        "_dynamic_shapes",
-        (torch._dynamo.config, "assume_static_by_default", False),
-        xfail_prop=xfail_prop,
-    )
-def filesize(filename: pathlib.Path):
-    assert filename.exists(), f"{filename} is missing"
-    return os.stat(filename).st_size