[jit] move torchbind tests to separate file (#37473)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/37473

Test Plan: Imported from OSS

Differential Revision: D21297541

Pulled By: suo

fbshipit-source-id: 65c48094b1f26fbbf251021957257ce04279922b
diff --git a/test/jit/test_torchbind.py b/test/jit/test_torchbind.py
new file mode 100644
index 0000000..b505531
--- /dev/null
+++ b/test/jit/test_torchbind.py
@@ -0,0 +1,241 @@
+import io
+import os
+import sys
+import torch
+from typing import Optional
+
+# Make the helper files in test/ importable
+pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+sys.path.append(pytorch_test_dir)
+from torch.testing._internal.jit_utils import JitTestCase
+from torch.testing._internal.common_utils import skipIfRocm
+from torch.testing import FileCheck
+
+if __name__ == "__main__":
+    raise RuntimeError(
+        "This test file is not meant to be run directly, use:\n\n"
+        "\tpython test/test_jit.py TESTNAME\n\n"
+        "instead."
+    )
+
+class TestTorchbind(JitTestCase):
+    @skipIfRocm
+    def test_torchbind(self):
+        def test_equality(f, cmp_key):
+            obj1 = f()
+            obj2 = torch.jit.script(f)()
+            return (cmp_key(obj1), cmp_key(obj2))
+
+        def f():
+            val = torch.classes._TorchScriptTesting._Foo(5, 3)
+            val.increment(1)
+            return val
+        test_equality(f, lambda x: x)
+
+        with self.assertRaisesRegex(RuntimeError, "Expected a value of type 'int'"):
+            val = torch.classes._TorchScriptTesting._Foo(5, 3)
+            val.increment('foo')
+
+        def f():
+            ss = torch.classes._TorchScriptTesting._StackString(["asdf", "bruh"])
+            return ss.pop()
+        test_equality(f, lambda x: x)
+
+        def f():
+            ss1 = torch.classes._TorchScriptTesting._StackString(["asdf", "bruh"])
+            ss2 = torch.classes._TorchScriptTesting._StackString(["111", "222"])
+            ss1.push(ss2.pop())
+            return ss1.pop() + ss2.pop()
+        test_equality(f, lambda x: x)
+
+    @skipIfRocm
+    def test_torchbind_take_as_arg(self):
+        global StackString  # see [local resolution in python]
+        StackString = torch.classes._TorchScriptTesting._StackString
+
+        def foo(stackstring):
+            # type: (StackString)
+            stackstring.push("lel")
+            return stackstring
+
+        script_input = torch.classes._TorchScriptTesting._StackString([])
+        scripted = torch.jit.script(foo)
+        script_output = scripted(script_input)
+        self.assertEqual(script_output.pop(), "lel")
+
+    @skipIfRocm
+    def test_torchbind_return_instance(self):
+        def foo():
+            ss = torch.classes._TorchScriptTesting._StackString(["hi", "mom"])
+            return ss
+
+        scripted = torch.jit.script(foo)
+        # Ensure we are creating the object and calling __init__
+        # rather than calling the __init__wrapper nonsense
+        fc = FileCheck().check('prim::CreateObject()')\
+                        .check('prim::CallMethod[name="__init__"]')
+        fc.run(str(scripted.graph))
+        out = scripted()
+        self.assertEqual(out.pop(), "mom")
+        self.assertEqual(out.pop(), "hi")
+
+    @skipIfRocm
+    def test_torchbind_return_instance_from_method(self):
+        def foo():
+            ss = torch.classes._TorchScriptTesting._StackString(["hi", "mom"])
+            clone = ss.clone()
+            ss.pop()
+            return ss, clone
+
+        scripted = torch.jit.script(foo)
+        out = scripted()
+        self.assertEqual(out[0].pop(), "hi")
+        self.assertEqual(out[1].pop(), "mom")
+        self.assertEqual(out[1].pop(), "hi")
+
+    @skipIfRocm
+    def test_torchbind_take_instance_as_method_arg(self):
+        def foo():
+            ss = torch.classes._TorchScriptTesting._StackString(["mom"])
+            ss2 = torch.classes._TorchScriptTesting._StackString(["hi"])
+            ss.merge(ss2)
+            return ss
+
+        scripted = torch.jit.script(foo)
+        out = scripted()
+        self.assertEqual(out.pop(), "hi")
+        self.assertEqual(out.pop(), "mom")
+
+    @skipIfRocm
+    def test_torchbind_return_tuple(self):
+        def f():
+            val = torch.classes._TorchScriptTesting._StackString(["3", "5"])
+            return val.return_a_tuple()
+
+        scripted = torch.jit.script(f)
+        tup = scripted()
+        self.assertEqual(tup, (1337.0, 123))
+
+    @skipIfRocm
+    def test_torchbind_save_load(self):
+        def foo():
+            ss = torch.classes._TorchScriptTesting._StackString(["mom"])
+            ss2 = torch.classes._TorchScriptTesting._StackString(["hi"])
+            ss.merge(ss2)
+            return ss
+
+        scripted = torch.jit.script(foo)
+        self.getExportImportCopy(scripted)
+
+    @skipIfRocm
+    def test_torchbind_lambda_method(self):
+        def foo():
+            ss = torch.classes._TorchScriptTesting._StackString(["mom"])
+            return ss.top()
+
+        scripted = torch.jit.script(foo)
+        self.assertEqual(scripted(), "mom")
+
+    @skipIfRocm
+    def test_torchbind_class_attribute(self):
+        class FooBar1234(torch.nn.Module):
+            def __init__(self):
+                super(FooBar1234, self).__init__()
+                self.f = torch.classes._TorchScriptTesting._StackString(["3", "4"])
+
+            def forward(self):
+                return self.f.top()
+
+        inst = FooBar1234()
+        scripted = torch.jit.script(inst)
+        eic = self.getExportImportCopy(scripted)
+        assert eic() == "deserialized"
+        for expected in ["deserialized", "was", "i"]:
+            assert eic.f.pop() == expected
+
+    @skipIfRocm
+    def test_torchbind_getstate(self):
+        class FooBar4321(torch.nn.Module):
+            def __init__(self):
+                super(FooBar4321, self).__init__()
+                self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])
+
+            def forward(self):
+                return self.f.top()
+
+        inst = FooBar4321()
+        scripted = torch.jit.script(inst)
+        eic = self.getExportImportCopy(scripted)
+        # NB: we expect the values {7, 3, 3, 1} as __getstate__ is defined to
+        # return {1, 3, 3, 7}. I tried to make this actually depend on the
+        # values at instantiation in the test with some transformation, but
+        # because it seems we serialize/deserialize multiple times, that
+        # transformation isn't as you would it expect it to be.
+        assert eic() == 7
+        for expected in [7, 3, 3, 1]:
+            assert eic.f.pop() == expected
+
+    @skipIfRocm
+    def test_torchbind_tracing(self):
+        class TryTracing(torch.nn.Module):
+            def __init__(self):
+                super(TryTracing, self).__init__()
+                self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])
+
+            def forward(self):
+                return torch.ops._TorchScriptTesting.take_an_instance(self.f)
+
+        traced = torch.jit.trace(TryTracing(), ())
+        self.assertEqual(torch.zeros(4, 4), traced())
+
+    @skipIfRocm
+    def test_torchbind_tracing_nested(self):
+        class TryTracingNest(torch.nn.Module):
+            def __init__(self):
+                super(TryTracingNest, self).__init__()
+                self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])
+
+        class TryTracing123(torch.nn.Module):
+            def __init__(self):
+                super(TryTracing123, self).__init__()
+                self.nest = TryTracingNest()
+
+            def forward(self):
+                return torch.ops._TorchScriptTesting.take_an_instance(self.nest.f)
+
+        traced = torch.jit.trace(TryTracing123(), ())
+        self.assertEqual(torch.zeros(4, 4), traced())
+
+    @skipIfRocm
+    def test_torchbind_pickle_serialization(self):
+        nt = torch.classes._TorchScriptTesting._PickleTester([3, 4])
+        b = io.BytesIO()
+        torch.save(nt, b)
+        b.seek(0)
+        nt_loaded = torch.load(b)
+        for exp in [7, 3, 3, 1]:
+            self.assertEqual(nt_loaded.pop(), exp)
+
+    @skipIfRocm
+    def test_torchbind_instantiate_missing_class(self):
+        with self.assertRaisesRegex(RuntimeError, 'Tried to instantiate class \'foo.IDontExist\', but it does not exist!'):
+            torch.classes.foo.IDontExist(3, 4, 5)
+
+    @skipIfRocm
+    def test_torchbind_optional_explicit_attr(self):
+        class TorchBindOptionalExplicitAttr(torch.nn.Module):
+            foo : Optional[torch.classes._TorchScriptTesting._StackString]
+
+            def __init__(self):
+                super().__init__()
+                self.foo = torch.classes._TorchScriptTesting._StackString(["test"])
+
+            def forward(self) -> str:
+                foo_obj = self.foo
+                if foo_obj is not None:
+                    return foo_obj.pop()
+                else:
+                    return '<None>'
+
+        mod = TorchBindOptionalExplicitAttr()
+        scripted = torch.jit.script(mod)
diff --git a/test/test_jit.py b/test/test_jit.py
index ee734e9..c23dad1 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -23,6 +23,7 @@
 from jit.test_save_load import TestSaveLoad  # noqa: F401
 from jit.test_python_ir import TestPythonIr  # noqa: F401
 from jit.test_functional_blocks import TestFunctionalBlocks  # noqa: F401
+from jit.test_torchbind import TestTorchbind  # noqa: F401
 
 # Torch
 from torch import Tensor
@@ -4579,114 +4580,6 @@
         self.assertEqual(7, w(3))
         self.assertFalse("training" in w.state_dict())
 
-    @skipIfRocm
-    def test_torchbind(self):
-        def test_equality(f, cmp_key):
-            obj1 = f()
-            obj2 = torch.jit.script(f)()
-            return (cmp_key(obj1), cmp_key(obj2))
-
-        def f():
-            val = torch.classes._TorchScriptTesting._Foo(5, 3)
-            val.increment(1)
-            return val
-        test_equality(f, lambda x: x)
-
-        with self.assertRaisesRegex(RuntimeError, "Expected a value of type 'int'"):
-            val = torch.classes._TorchScriptTesting._Foo(5, 3)
-            val.increment('foo')
-
-        def f():
-            ss = torch.classes._TorchScriptTesting._StackString(["asdf", "bruh"])
-            return ss.pop()
-        test_equality(f, lambda x: x)
-
-        def f():
-            ss1 = torch.classes._TorchScriptTesting._StackString(["asdf", "bruh"])
-            ss2 = torch.classes._TorchScriptTesting._StackString(["111", "222"])
-            ss1.push(ss2.pop())
-            return ss1.pop() + ss2.pop()
-        test_equality(f, lambda x: x)
-
-    @skipIfRocm
-    def test_torchbind_take_as_arg(self):
-        global StackString  # see [local resolution in python]
-        StackString = torch.classes._TorchScriptTesting._StackString
-
-        def foo(stackstring):
-            # type: (StackString)
-            stackstring.push("lel")
-            return stackstring
-
-        script_input = torch.classes._TorchScriptTesting._StackString([])
-        scripted = torch.jit.script(foo)
-        script_output = scripted(script_input)
-        self.assertEqual(script_output.pop(), "lel")
-
-    @skipIfRocm
-    def test_torchbind_return_instance(self):
-        def foo():
-            ss = torch.classes._TorchScriptTesting._StackString(["hi", "mom"])
-            return ss
-
-        scripted = torch.jit.script(foo)
-        # Ensure we are creating the object and calling __init__
-        # rather than calling the __init__wrapper nonsense
-        fc = FileCheck().check('prim::CreateObject()')\
-                        .check('prim::CallMethod[name="__init__"]')
-        fc.run(str(scripted.graph))
-        out = scripted()
-        self.assertEqual(out.pop(), "mom")
-        self.assertEqual(out.pop(), "hi")
-
-    @skipIfRocm
-    def test_torchbind_return_instance_from_method(self):
-        def foo():
-            ss = torch.classes._TorchScriptTesting._StackString(["hi", "mom"])
-            clone = ss.clone()
-            ss.pop()
-            return ss, clone
-
-        scripted = torch.jit.script(foo)
-        out = scripted()
-        self.assertEqual(out[0].pop(), "hi")
-        self.assertEqual(out[1].pop(), "mom")
-        self.assertEqual(out[1].pop(), "hi")
-
-    @skipIfRocm
-    def test_torchbind_take_instance_as_method_arg(self):
-        def foo():
-            ss = torch.classes._TorchScriptTesting._StackString(["mom"])
-            ss2 = torch.classes._TorchScriptTesting._StackString(["hi"])
-            ss.merge(ss2)
-            return ss
-
-        scripted = torch.jit.script(foo)
-        out = scripted()
-        self.assertEqual(out.pop(), "hi")
-        self.assertEqual(out.pop(), "mom")
-
-    @skipIfRocm
-    def test_torchbind_return_tuple(self):
-        def f():
-            val = torch.classes._TorchScriptTesting._StackString(["3", "5"])
-            return val.return_a_tuple()
-
-        scripted = torch.jit.script(f)
-        tup = scripted()
-        self.assertEqual(tup, (1337.0, 123))
-
-    @skipIfRocm
-    def test_torchbind_save_load(self):
-        def foo():
-            ss = torch.classes._TorchScriptTesting._StackString(["mom"])
-            ss2 = torch.classes._TorchScriptTesting._StackString(["hi"])
-            ss.merge(ss2)
-            return ss
-
-        scripted = torch.jit.script(foo)
-        self.getExportImportCopy(scripted)
-
     def test_class_as_attribute(self):
         @torch.jit.script
         class Foo321(object):
@@ -4706,124 +4599,6 @@
         x = torch.rand(3, 4)
         self.assertEqual(scripted(x), eic(x))
 
-    @skipIfRocm
-    def test_torchbind_lambda_method(self):
-        def foo():
-            ss = torch.classes._TorchScriptTesting._StackString(["mom"])
-            return ss.top()
-
-        scripted = torch.jit.script(foo)
-        self.assertEqual(scripted(), "mom")
-
-    @skipIfRocm
-    def test_torchbind_class_attribute(self):
-        class FooBar1234(torch.nn.Module):
-            def __init__(self):
-                super(FooBar1234, self).__init__()
-                self.f = torch.classes._TorchScriptTesting._StackString(["3", "4"])
-
-            def forward(self):
-                return self.f.top()
-
-        inst = FooBar1234()
-        scripted = torch.jit.script(inst)
-        eic = self.getExportImportCopy(scripted)
-        assert eic() == "deserialized"
-        for expected in ["deserialized", "was", "i"]:
-            assert eic.f.pop() == expected
-
-    @skipIfRocm
-    def test_torchbind_getstate(self):
-        class FooBar4321(torch.nn.Module):
-            def __init__(self):
-                super(FooBar4321, self).__init__()
-                self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])
-
-            def forward(self):
-                return self.f.top()
-
-        inst = FooBar4321()
-        scripted = torch.jit.script(inst)
-        eic = self.getExportImportCopy(scripted)
-        # NB: we expect the values {7, 3, 3, 1} as __getstate__ is defined to
-        # return {1, 3, 3, 7}. I tried to make this actually depend on the
-        # values at instantiation in the test with some transformation, but
-        # because it seems we serialize/deserialize multiple times, that
-        # transformation isn't as you would it expect it to be.
-        assert eic() == 7
-        for expected in [7, 3, 3, 1]:
-            assert eic.f.pop() == expected
-
-    @skipIfRocm
-    def test_torchbind_tracing(self):
-        class TryTracing(torch.nn.Module):
-            def __init__(self):
-                super(TryTracing, self).__init__()
-                self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])
-
-            def forward(self):
-                return torch.ops._TorchScriptTesting.take_an_instance(self.f)
-
-        traced = torch.jit.trace(TryTracing(), ())
-        self.assertEqual(torch.zeros(4, 4), traced())
-
-    @skipIfRocm
-    def test_torchbind_tracing_nested(self):
-        class TryTracingNest(torch.nn.Module):
-            def __init__(self):
-                super(TryTracingNest, self).__init__()
-                self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])
-
-        class TryTracing123(torch.nn.Module):
-            def __init__(self):
-                super(TryTracing123, self).__init__()
-                self.nest = TryTracingNest()
-
-            def forward(self):
-                return torch.ops._TorchScriptTesting.take_an_instance(self.nest.f)
-
-        traced = torch.jit.trace(TryTracing123(), ())
-        self.assertEqual(torch.zeros(4, 4), traced())
-
-    @skipIfRocm
-    def test_torchbind_pickle_serialization(self):
-        nt = torch.classes._TorchScriptTesting._PickleTester([3, 4])
-        b = io.BytesIO()
-        torch.save(nt, b)
-        b.seek(0)
-        nt_loaded = torch.load(b)
-        for exp in [7, 3, 3, 1]:
-            self.assertEqual(nt_loaded.pop(), exp)
-
-    @skipIfRocm
-    def test_torchbind_instantiate_missing_class(self):
-        with self.assertRaisesRegex(RuntimeError, 'Tried to instantiate class \'foo.IDontExist\', but it does not exist!'):
-            torch.classes.foo.IDontExist(3, 4, 5)
-
-    @skipIfRocm
-    def test_torchbind_optional_explicit_attr(self):
-        class TorchBindOptionalExplicitAttr(torch.nn.Module):
-            foo : Optional[torch.classes._TorchScriptTesting._StackString]
-
-            def __init__(self):
-                super().__init__()
-                self.foo = torch.classes._TorchScriptTesting._StackString(["test"])
-
-            def forward(self) -> str:
-                foo_obj = self.foo
-                if foo_obj is not None:
-                    return foo_obj.pop()
-                else:
-                    return '<None>'
-
-        mod = TorchBindOptionalExplicitAttr()
-        scripted = torch.jit.script(mod)
-
-    @skipIfRocm
-    def test_torchbind_str(self):
-        foo = torch.classes._TorchScriptTesting._StackString(["foo", "bar", "baz"])
-        self.assertEqual(str(foo), "[foo, bar, baz]")
-
     def test_module_str(self):
         class Foo(torch.nn.Module):
             def forward(self, x):
@@ -4832,12 +4607,6 @@
         f = torch.jit.script(Foo())
         self.assertEqual('ScriptObject', str(f._c))
 
-    @skipIfRocm
-    def test_torchbind_magic_unimplemented(self):
-        foo = torch.classes._TorchScriptTesting._StackString(["foo", "bar", "baz"])
-        with self.assertRaises(NotImplementedError):
-            foo[3]
-
     def _test_lower_graph_impl(self, model, data):
         model.qconfig = torch.quantization.default_qconfig
         model = torch.quantization.prepare(model)