PT2/TorchScript interoperability fix (#94678)
Allows torch.compile() to inline into ScriptFunction
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94678
Approved by: https://github.com/ezyang
diff --git a/test/dynamo/test_interop.py b/test/dynamo/test_interop.py
new file mode 100644
index 0000000..1576706
--- /dev/null
+++ b/test/dynamo/test_interop.py
@@ -0,0 +1,38 @@
+# Owner(s): ["module: dynamo"]
+import torch
+
+import torch._dynamo.test_case
+import torch._dynamo.testing
+import torch.onnx.operators
+from torch._dynamo.testing import same
+
+
+def fn(a, b):
+ return a + b * 0.67
+
+
+class InteropTests(torch._dynamo.test_case.TestCase):
+ def _common(self, fn):
+ inputs = [torch.randn(10), torch.randn(10)]
+ ref = fn(*inputs)
+ opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
+ res = opt_fn(*inputs)
+ self.assertTrue(same(ref, res))
+
+ def test_fx_fn(self):
+ fx_fn = torch.fx.symbolic_trace(fn)
+ self._common(lambda a, b: fx_fn(a, b) + 1)
+
+ def test_script_fn(self):
+ script_fn = torch.jit.script(fn)
+ self._common(lambda a, b: script_fn(a, b) + 1)
+
+ def test_trace_fn(self):
+ trace_fn = torch.jit.trace(fn, [torch.zeros(10), torch.zeros(10)])
+ self._common(lambda a, b: trace_fn(a, b) + 1)
+
+
+if __name__ == "__main__":
+ from torch._dynamo.test_case import run_tests
+
+ run_tests()
diff --git a/test/jit/test_autodiff.py b/test/jit/test_autodiff.py
index 3173e81..a77569f 100644
--- a/test/jit/test_autodiff.py
+++ b/test/jit/test_autodiff.py
@@ -2,9 +2,12 @@
import torch
+from torch.testing._internal.common_utils import skipIfTorchDynamo
from torch.testing._internal.jit_utils import JitTestCase
from typing import List
+
+@skipIfTorchDynamo()
class TestAutodiffJit(JitTestCase):
def test_undefined_tensor_lists(self):
def fn(tensor_list: List[torch.Tensor], add_tensor):
diff --git a/test/jit/test_profiler.py b/test/jit/test_profiler.py
index 81df055..5389751 100644
--- a/test/jit/test_profiler.py
+++ b/test/jit/test_profiler.py
@@ -4,6 +4,7 @@
import sys
import torch
+from torch.testing._internal.common_utils import skipIfTorchDynamo
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
@@ -15,6 +16,7 @@
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
+@skipIfTorchDynamo()
class TestProfiler(JitTestCase):
def setUp(self):
self.prev_exec = torch._C._jit_set_profiling_executor(True)
diff --git a/test/test_jit.py b/test/test_jit.py
index 530b448..3394768 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -342,6 +342,8 @@
super().__init__()
self.bar = torch.jit.ScriptModule()
+
+@skipIfTorchDynamo()
class TestJit(JitTestCase):
@unittest.skip("Requires a lot of RAM")
def test_big(self):
@@ -2982,6 +2984,7 @@
self.assertRegex(graph.__repr__(), source_range_regex)
+@skipIfTorchDynamo()
class TestFrontend(JitTestCase):
def test_instancing_error(self):
@@ -3038,6 +3041,7 @@
res_2 = traced_model_2(**{'x': torch.rand([2]), 'z': torch.rand([2])})
+@skipIfTorchDynamo()
class TestScript(JitTestCase):
# Tests that calling torch.jit.script repeated on function is allowed.
@@ -15989,10 +15993,12 @@
}
+@skipIfTorchDynamo()
class TestJitGeneratedModule(JitTestCase):
pass
+@skipIfTorchDynamo()
class TestJitGeneratedFunctional(JitTestCase):
pass
diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py
index 711a44b..b00588e 100644
--- a/test/test_jit_fuser_te.py
+++ b/test/test_jit_fuser_te.py
@@ -80,6 +80,8 @@
finally:
torch._C._debug_set_fusion_group_inlining(old_inlining)
+
+@skipIfTorchDynamo()
class TestTEFuser(JitTestCase):
def setUp(self):
super().setUp()
@@ -2622,6 +2624,7 @@
# super() [with no arguments] fails, presumably because of how instantiate_device_type_tests works.
# super(TestNNCOpInfo, self) fails because TestNNCOpInfo gets deleted from global scope.
# super(JitCommonTestCase, self).fn() would skip JitCommonTestCase.fn() implementation
+@skipIfTorchDynamo()
class TestNNCOpInfoParent(JitCommonTestCase):
pass
@@ -2739,6 +2742,7 @@
instantiate_device_type_tests(TestNNCOpInfo, globals(), only_for=only_for)
# Purpose of this class is to allow super() calls. (See TestNNCOpInfoParent)
+@skipIfTorchDynamo()
class TestLoopnestRandomizationParent(JitTestCase):
pass
diff --git a/test/test_tensorexpr.py b/test/test_tensorexpr.py
index e58b577..d60376f 100644
--- a/test/test_tensorexpr.py
+++ b/test/test_tensorexpr.py
@@ -7,7 +7,7 @@
import unittest
import itertools
-from torch.testing._internal.common_utils import suppress_warnings, num_profiled_runs, run_tests
+from torch.testing._internal.common_utils import suppress_warnings, num_profiled_runs, run_tests, skipIfTorchDynamo
from torch.testing._internal.jit_utils import JitTestCase, TensorExprTestOptions
@@ -34,6 +34,7 @@
return results
+@skipIfTorchDynamo()
class TestTensorExprFuser(BaseTestClass):
def test_easy(self):
def easy(x, y):
diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py
index 67a0a53..51838eb 100644
--- a/torch/_dynamo/variables/builder.py
+++ b/torch/_dynamo/variables/builder.py
@@ -459,7 +459,7 @@
source=self.source,
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
)
- elif istype(value, types.FunctionType):
+ elif istype(value, (types.FunctionType, torch.jit.ScriptFunction)):
return UserFunctionVariable(
value,
source=self.source,
diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py
index d59767d..31d2e15 100644
--- a/torch/_dynamo/variables/functions.py
+++ b/torch/_dynamo/variables/functions.py
@@ -112,7 +112,7 @@
self.is_constant = False
assert isinstance(
- fn, types.FunctionType
+ fn, (types.FunctionType, torch.jit.ScriptFunction)
), f"expected FunctionType found {typestr(fn)} {fn}"
# unpack @torch._dynamo.optimize()(fn) wrapped function
fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn)
diff --git a/torch/jit/_script.py b/torch/jit/_script.py
index cee7a24..fd0fa1f 100644
--- a/torch/jit/_script.py
+++ b/torch/jit/_script.py
@@ -1343,6 +1343,8 @@
)
# Forward docstrings
fn.__doc__ = obj.__doc__
+ # Allow torch.compile() to inline
+ fn._torchdynamo_inline = obj # type: ignore[attr-defined]
_set_jit_function_cache(obj, fn)
return fn
else:
diff --git a/torch/jit/_trace.py b/torch/jit/_trace.py
index f0da4a1..4afe734 100644
--- a/torch/jit/_trace.py
+++ b/torch/jit/_trace.py
@@ -893,6 +893,8 @@
example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict),
)
+ # Allow torch.compile() to inline
+ traced._torchdynamo_inline = func # type: ignore[attr-defined]
return traced