PT2/TorchScript interoperability fix (#94678)

Allows torch.compile() to inline into functions produced by torch.jit.script()
and torch.jit.trace(). Both now stash the original Python callable on the
compiled artifact as _torchdynamo_inline; Dynamo already unwraps that attribute
(see torch/_dynamo/variables/functions.py), so it traces the original function
body instead of graph-breaking at the call. The JIT-specific test suites are
marked @skipIfTorchDynamo(), since under Dynamo their scripted functions would
now be inlined rather than run by the TorchScript executor.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94678
Approved by: https://github.com/ezyang
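
For example, with this change the following runs without a graph break
(a minimal sketch modeled on the new test/dynamo/test_interop.py below;
the "eager" backend and fullgraph=True are simply what the test uses):

    import torch

    def fn(a, b):
        return a + b * 0.67

    script_fn = torch.jit.script(fn)

    # Dynamo inlines the original Python body of fn rather than
    # graph-breaking at the ScriptFunction call.
    opt_fn = torch.compile(lambda a, b: script_fn(a, b) + 1,
                           backend="eager", fullgraph=True)
    opt_fn(torch.randn(10), torch.randn(10))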
diff --git a/test/dynamo/test_interop.py b/test/dynamo/test_interop.py
new file mode 100644
index 0000000..1576706
--- /dev/null
+++ b/test/dynamo/test_interop.py
@@ -0,0 +1,38 @@
+# Owner(s): ["module: dynamo"]
+import torch
+
+import torch._dynamo.test_case
+import torch._dynamo.testing
+import torch.onnx.operators
+from torch._dynamo.testing import same
+
+
+def fn(a, b):
+    return a + b * 0.67
+
+
+class InteropTests(torch._dynamo.test_case.TestCase):
+    def _common(self, fn):
+        inputs = [torch.randn(10), torch.randn(10)]
+        ref = fn(*inputs)
+        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
+        res = opt_fn(*inputs)
+        self.assertTrue(same(ref, res))
+
+    def test_fx_fn(self):
+        fx_fn = torch.fx.symbolic_trace(fn)
+        self._common(lambda a, b: fx_fn(a, b) + 1)
+
+    def test_script_fn(self):
+        script_fn = torch.jit.script(fn)
+        self._common(lambda a, b: script_fn(a, b) + 1)
+
+    def test_trace_fn(self):
+        trace_fn = torch.jit.trace(fn, [torch.zeros(10), torch.zeros(10)])
+        self._common(lambda a, b: trace_fn(a, b) + 1)
+
+
+if __name__ == "__main__":
+    from torch._dynamo.test_case import run_tests
+
+    run_tests()
diff --git a/test/jit/test_autodiff.py b/test/jit/test_autodiff.py
index 3173e81..a77569f 100644
--- a/test/jit/test_autodiff.py
+++ b/test/jit/test_autodiff.py
@@ -2,9 +2,12 @@
 
 import torch
 
+from torch.testing._internal.common_utils import skipIfTorchDynamo
 from torch.testing._internal.jit_utils import JitTestCase
 from typing import List
 
+
+@skipIfTorchDynamo()
 class TestAutodiffJit(JitTestCase):
     def test_undefined_tensor_lists(self):
         def fn(tensor_list: List[torch.Tensor], add_tensor):
diff --git a/test/jit/test_profiler.py b/test/jit/test_profiler.py
index 81df055..5389751 100644
--- a/test/jit/test_profiler.py
+++ b/test/jit/test_profiler.py
@@ -4,6 +4,7 @@
 import sys
 
 import torch
+from torch.testing._internal.common_utils import skipIfTorchDynamo
 
 # Make the helper files in test/ importable
 pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
@@ -15,6 +16,7 @@
                        "\tpython test/test_jit.py TESTNAME\n\n"
                        "instead.")
 
+@skipIfTorchDynamo()
 class TestProfiler(JitTestCase):
     def setUp(self):
         self.prev_exec = torch._C._jit_set_profiling_executor(True)
diff --git a/test/test_jit.py b/test/test_jit.py
index 530b448..3394768 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -342,6 +342,8 @@
         super().__init__()
         self.bar = torch.jit.ScriptModule()
 
+
+@skipIfTorchDynamo()
 class TestJit(JitTestCase):
     @unittest.skip("Requires a lot of RAM")
     def test_big(self):
@@ -2982,6 +2984,7 @@
         self.assertRegex(graph.__repr__(), source_range_regex)
 
 
+@skipIfTorchDynamo()
 class TestFrontend(JitTestCase):
 
     def test_instancing_error(self):
@@ -3038,6 +3041,7 @@
             res_2 = traced_model_2(**{'x': torch.rand([2]), 'z': torch.rand([2])})
 
 
+@skipIfTorchDynamo()
 class TestScript(JitTestCase):
 
     # Tests that calling torch.jit.script repeated on function is allowed.
@@ -15989,10 +15993,12 @@
 }
 
 
+@skipIfTorchDynamo()
 class TestJitGeneratedModule(JitTestCase):
     pass
 
 
+@skipIfTorchDynamo()
 class TestJitGeneratedFunctional(JitTestCase):
     pass
 
diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py
index 711a44b..b00588e 100644
--- a/test/test_jit_fuser_te.py
+++ b/test/test_jit_fuser_te.py
@@ -80,6 +80,8 @@
     finally:
         torch._C._debug_set_fusion_group_inlining(old_inlining)
 
+
+@skipIfTorchDynamo()
 class TestTEFuser(JitTestCase):
     def setUp(self):
         super().setUp()
@@ -2622,6 +2624,7 @@
 # super() [with no arguments] fails, presumably because of how instantiate_device_type_tests works.
 # super(TestNNCOpInfo, self) fails because TestNNCOpInfo gets deleted from global scope.
 # super(JitCommonTestCase, self).fn() would skip JitCommonTestCase.fn() implementation
+@skipIfTorchDynamo()
 class TestNNCOpInfoParent(JitCommonTestCase):
     pass
 
@@ -2739,6 +2742,7 @@
 instantiate_device_type_tests(TestNNCOpInfo, globals(), only_for=only_for)
 
 # Purpose of this class is to allow super() calls. (See TestNNCOpInfoParent)
+@skipIfTorchDynamo()
 class TestLoopnestRandomizationParent(JitTestCase):
     pass
 
diff --git a/test/test_tensorexpr.py b/test/test_tensorexpr.py
index e58b577..d60376f 100644
--- a/test/test_tensorexpr.py
+++ b/test/test_tensorexpr.py
@@ -7,7 +7,7 @@
 import unittest
 import itertools
 
-from torch.testing._internal.common_utils import suppress_warnings, num_profiled_runs, run_tests
+from torch.testing._internal.common_utils import suppress_warnings, num_profiled_runs, run_tests, skipIfTorchDynamo
 
 from torch.testing._internal.jit_utils import JitTestCase, TensorExprTestOptions
 
@@ -34,6 +34,7 @@
     return results
 
 
+@skipIfTorchDynamo()
 class TestTensorExprFuser(BaseTestClass):
     def test_easy(self):
         def easy(x, y):
diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py
index 67a0a53..51838eb 100644
--- a/torch/_dynamo/variables/builder.py
+++ b/torch/_dynamo/variables/builder.py
@@ -459,7 +459,7 @@
                 source=self.source,
                 guards=make_guards(GuardBuilder.FUNCTION_MATCH),
             )
-        elif istype(value, types.FunctionType):
+        elif istype(value, (types.FunctionType, torch.jit.ScriptFunction)):
             return UserFunctionVariable(
                 value,
                 source=self.source,
diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py
index d59767d..31d2e15 100644
--- a/torch/_dynamo/variables/functions.py
+++ b/torch/_dynamo/variables/functions.py
@@ -112,7 +112,7 @@
             self.is_constant = False
 
         assert isinstance(
-            fn, types.FunctionType
-        ), f"expected FunctionType found {typestr(fn)} {fn}"
+            fn, (types.FunctionType, torch.jit.ScriptFunction)
+        ), f"expected FunctionType or ScriptFunction, found {typestr(fn)} {fn}"
         # unpack @torch._dynamo.optimize()(fn) wrapped function
         fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn)
diff --git a/torch/jit/_script.py b/torch/jit/_script.py
index cee7a24..fd0fa1f 100644
--- a/torch/jit/_script.py
+++ b/torch/jit/_script.py
@@ -1343,6 +1343,8 @@
         )
         # Forward docstrings
         fn.__doc__ = obj.__doc__
+        # Stash the original Python function so torch.compile() can inline it
+        fn._torchdynamo_inline = obj  # type: ignore[attr-defined]
         _set_jit_function_cache(obj, fn)
         return fn
     else:
diff --git a/torch/jit/_trace.py b/torch/jit/_trace.py
index f0da4a1..4afe734 100644
--- a/torch/jit/_trace.py
+++ b/torch/jit/_trace.py
@@ -893,6 +893,8 @@
                 example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict),
             )
 
+    # Stash the original Python callable so torch.compile() can inline it
+    traced._torchdynamo_inline = func  # type: ignore[attr-defined]
     return traced
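
The same hook is set by torch.jit.trace, so traced callables inline as well
(a minimal sketch mirroring test_trace_fn above):

    import torch

    def fn(a, b):
        return a + b * 0.67

    trace_fn = torch.jit.trace(fn, [torch.zeros(10), torch.zeros(10)])
    opt_fn = torch.compile(lambda a, b: trace_fn(a, b) + 1,
                           backend="eager", fullgraph=True)
    opt_fn(torch.randn(10), torch.randn(10))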